@pytest.mark.parametrize("usedask", [True, False])
def test_get_xg_yg_from_input(all_grid_datadirs, usedask):
    """Check that get_xg_yg_from_input returns tiled XG/YG of the right shape.

    The grid is cut into 30 x 30 tiles with tiles 1-3 declared blank; every
    remaining tile must carry one extra row and column of corner points.
    """
    from xmitgcm.utils import get_xg_yg_from_input, get_extra_metadata

    dirname, expected = all_grid_datadirs
    md = get_extra_metadata(domain=expected['domain'], nx=expected['nx'])
    tx = 30
    ty = 30
    bl = [1, 2, 3]
    ds = get_xg_yg_from_input(dirname + '/' + expected['gridfile'],
                              geometry=expected['geometry'],
                              dtype=np.dtype('d'), endian='>',
                              use_dask=usedask,
                              extra_metadata=md,
                              tilex=tx, tiley=ty,
                              blankList=bl)
    # test types
    assert isinstance(ds, xarray.Dataset)
    assert isinstance(ds['XG'], xarray.DataArray)

    if usedask:
        ds.load()

    # check all variables are in, with dimensions (tile, j_g, i_g):
    # axis 1 runs along y (tiley + 1 points), axis 2 along x (tilex + 1)
    expected_variables = ['XG', 'YG']

    for var in expected_variables:
        assert isinstance(ds[var], xarray.DataArray)
        assert ds[var].dims == ('tile', 'j_g', 'i_g')
        assert ds[var].values.shape[1] == ty + 1
        assert ds[var].values.shape[2] == tx + 1

    # passing llc without metadata should fail
    if expected['geometry'] == 'llc':
        with pytest.raises(ValueError):
            get_xg_yg_from_input(dirname + '/' + expected['gridfile'],
                                 geometry=expected['geometry'],
                                 dtype=np.dtype('d'), endian='>',
                                 use_dask=False,
                                 extra_metadata=None)
file_metadata['fldList'] = ['XC', 'YC', 'DXF', 'DYF', 'RAC', + 'XG', 'YG', 'DXV', 'DYU', 'RAZ', + 'DXC', 'DYC', 'RAW', 'RAS', 'DXG', 'DYG'] + + file_metadata['vars'] = file_metadata['fldList'] + dims_vars_list = [] + for var in file_metadata['fldList']: + dims_vars_list.append(('ny', 'nx')) + file_metadata['dims_vars'] = dims_vars_list + + # no vertical levels or time records + file_metadata['nz'] = 1 + file_metadata['nt'] = 1 + +# for curvilinear non-facet grids (TO DO) +# if nx is not None: +# file_metadata['nx'] = nx +# if ny is not None: +# file_metadata['ny'] = ny + if extra_metadata is not None: + file_metadata.update(extra_metadata) + + # numeric representation + file_metadata['endian'] = endian + file_metadata['dtype'] = dtype + return file_metadata + +def _nxgrid_nygrid(file_metadata,kfacet): + if file_metadata['facet_orders'][kfacet] == 'C': + nxgrid = file_metadata['nx'] + 1 + nygrid = file_metadata['ny_facets'][kfacet] + 1 + elif file_metadata['facet_orders'][kfacet] == 'F': + nxgrid = file_metadata['ny_facets'][kfacet] + 1 + nygrid = file_metadata['nx'] + 1 + return nxgrid,nygrid + + def read_all_variables(variable_list, file_metadata, use_mmap=False, use_dask=False, chunks="3D"): """ @@ -1278,33 +1318,7 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc', all grid variables """ - file_metadata = {} - # grid variables are stored in this order - file_metadata['fldList'] = ['XC', 'YC', 'DXF', 'DYF', 'RAC', - 'XG', 'YG', 'DXV', 'DYU', 'RAZ', - 'DXC', 'DYC', 'RAW', 'RAS', 'DXG', 'DYG'] - - file_metadata['vars'] = file_metadata['fldList'] - dims_vars_list = [] - for var in file_metadata['fldList']: - dims_vars_list.append(('ny', 'nx')) - file_metadata['dims_vars'] = dims_vars_list - - # no vertical levels or time records - file_metadata['nz'] = 1 - file_metadata['nt'] = 1 - -# for curvilinear non-facet grids (TO DO) -# if nx is not None: -# file_metadata['nx'] = nx -# if ny is not None: -# file_metadata['ny'] = ny - if extra_metadata is 
def get_xg_yg_from_input(gridfile, nx=None, ny=None, geometry='llc',
                         dtype=np.dtype('d'), endian='>', use_dask=False,
                         extra_metadata=None, tilex=30, tiley=30,
                         blankList=None):
    """
    Read XG and YG from grid input files and cut them into tiles.

    Each facet file is split into tiles of tilex by tiley cells. Tiles are
    numbered consecutively starting at 1, facet after facet and, within a
    facet, along x first and then along y; tiles listed in blankList are
    skipped. Every tile keeps one extra row and column, so that all values
    stored in the input grid file are output (including the rightmost and
    uppermost XG and YG values). This is useful for finding where a lat/lon
    point is on the llc grid.

    PARAMETERS
    ----------
    gridfile : str
        gridfile must contain <NFACET> as wildcard
        (e.g. tile<NFACET>.mitgrid)
    nx : int
        size of the face in the x direction (currently unused)
    ny : int
        size of the face in the y direction (currently unused)
    geometry : str
        domain geometry; only llc is supported, cs and cartesian are not
        supported yet
    dtype : np.dtype
        numeric precision (single/double) of input data
    endian : string
        endianness of input data
    use_dask : bool
        use dask or not
    extra_metadata : dict
        dictionary of extra metadata, needed for llc configurations
    tilex : int
        size of tile in the x direction
    tiley : int
        size of tile in the y direction
    blankList : arraylike or None
        List of blank tiles (indexing starts at 1 so that you can copy
        directly from data.exch2). None means that no tile is blank.

    RETURNS
    -------
    grid : xarray.Dataset
        XG and YG with dimensions (tile, j_g, i_g); the tile coordinate
        counts the retained tiles from 0, it is not the original tile number

    RAISES
    ------
    NotImplementedError
        if geometry is not 'llc'
    ValueError
        if the tile sizes are not positive, if the llc metadata is
        incomplete, if a facet is not an exact multiple of the tile size, or
        if every tile is blank
    """
    # NOTE(review): '<NFACET>' is assumed to be the facet wildcard; replacing
    # the empty string instead would insert the facet number between every
    # character of the path -- confirm against the grid file naming.
    facet_wildcard = '<NFACET>'
    fields = ('XG', 'YG')

    if geometry != 'llc':  # pragma: no cover
        raise NotImplementedError(
            "'%s' geometry is not supported yet" % geometry)
    if tilex < 1 or tiley < 1:
        raise ValueError('tilex and tiley must be positive')

    file_metadata = _file_metadata(endian, dtype, extra_metadata)

    required = ('face_facets', 'facet_orders', 'nx', 'ny_facets')
    missing = [key for key in required if key not in file_metadata]
    if missing:
        raise ValueError('metadata must contain ' + ', '.join(missing))

    nfacets = 5
    # flattening lets lists, tuples and arrays be given alike
    if blankList is None:
        blank_tiles = set()
    else:
        blank_tiles = set(np.ravel(blankList).tolist())

    # retained tiles of each field, in tile order
    tiles = {field: [] for field in fields}
    tileno = 0
    for kfacet in range(nfacets):
        # we need to adapt the metadata to the grid file
        grid_metadata = file_metadata.copy()
        grid_metadata['filename'] = gridfile.replace(
            facet_wildcard, str(kfacet + 1).zfill(3))

        nxgrid, nygrid = _nxgrid_nygrid(file_metadata, kfacet)
        grid_metadata.update({'nx': nxgrid, 'ny': nygrid,
                              'has_faces': False})

        tile_in_x, rest_x = divmod(nxgrid - 1, tilex)
        tile_in_y, rest_y = divmod(nygrid - 1, tiley)
        if rest_x or rest_y:
            raise ValueError(
                'facet %d (%d x %d cells) cannot be split into tiles of '
                '%d x %d cells'
                % (kfacet + 1, nxgrid - 1, nygrid - 1, tilex, tiley))

        raw = read_all_variables(grid_metadata['vars'], grid_metadata,
                                 use_dask=use_dask)
        rawfields = dict(zip(file_metadata['fldList'], raw))

        # Drop the singleton axes. No transpose is applied: the tiles are
        # cut in the orientation in which the facet was read, and nxgrid and
        # nygrid are already swapped for 'F' ordered facets.
        facet_data = {field: rawfields[field].squeeze() for field in fields}

        for tileon in range(tile_in_x * tile_in_y):
            tileno += 1
            if tileno in blank_tiles:
                continue
            offsety, offsetx = divmod(tileon, tile_in_x)
            # one extra point in each direction closes the tile
            rows = slice(offsety * tiley, (offsety + 1) * tiley + 1)
            cols = slice(offsetx * tilex, (offsetx + 1) * tilex + 1)
            for field in fields:
                tiles[field].append(facet_data[field][rows, cols])

    if not tiles['XG']:
        raise ValueError('no tile left, all tiles are in blankList')

    # stacking once avoids re-concatenating the whole array for every tile
    gridfields = {field: dsa.stack(tiles[field], axis=0) for field in fields}

    # create the dataset
    ntile = gridfields['XG'].shape[0]
    grid = xr.Dataset({'XG': (['tile', 'j_g', 'i_g'], gridfields['XG']),
                       'YG': (['tile', 'j_g', 'i_g'], gridfields['YG']),
                       },
                      coords={'i_g': (['i_g'], np.arange(tilex + 1)),
                              'j_g': (['j_g'], np.arange(tiley + 1)),
                              'tile': (['tile'], np.arange(ntile)),
                              }
                      )

    return grid