Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 30 additions & 11 deletions xmitgcm/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,17 +957,21 @@ def test_get_extra_metadata(domain, nx):
em = get_extra_metadata(domain='notinlist', nx=nx)


@pytest.mark.parametrize("outer", [True, False])
@pytest.mark.parametrize("usedask", [True, False])
def test_get_grid_from_input(all_grid_datadirs, usedask):
def test_get_grid_from_input(all_grid_datadirs, usedask, outer):
from xmitgcm.utils import get_grid_from_input, get_extra_metadata
from xmitgcm.utils import read_raw_data
dirname, expected = all_grid_datadirs
md = get_extra_metadata(domain=expected['domain'], nx=expected['nx'])

ds = get_grid_from_input(dirname + '/' + expected['gridfile'],
geometry=expected['geometry'],
dtype=np.dtype('d'), endian='>',
use_dask=usedask,
extra_metadata=md)
extra_metadata=md,
outer=outer)

# test types
assert type(ds) == xarray.Dataset
assert type(ds['XC']) == xarray.core.dataarray.DataArray
Expand All @@ -980,9 +984,23 @@ def test_get_grid_from_input(all_grid_datadirs, usedask):
'XG', 'YG', 'DXV', 'DYU', 'RAZ',
'DXC', 'DYC', 'RAW', 'RAS', 'DXG', 'DYG']

outerx_vars = ['DXC', 'RAW', 'DYG']
outery_vars = ['DYC', 'RAS', 'DXG']
outerxy_vars = ['XG', 'YG', 'RAZ']

for var in expected_variables:
assert type(ds[var]) == xarray.core.dataarray.DataArray
assert ds[var].values.shape == expected['shape']
expected_shape_outer = list(expected['shape'])
if var in outerx_vars or var in outerxy_vars:
expected_shape_outer[-1] = expected_shape_outer[-1] + 1
if var in outery_vars or var in outerxy_vars:
expected_shape_outer[-2] = expected_shape_outer[-2] + 1

if outer:
assert type(ds[var]) == xarray.core.dataarray.DataArray
Comment thread
AaronDavidSchneider marked this conversation as resolved.
Outdated
assert ds[var].values.shape == tuple(expected_shape_outer)
else:
assert type(ds[var]) == xarray.core.dataarray.DataArray
assert ds[var].values.shape == expected['shape']

# check we don't leave points behind
if expected['geometry'] == 'llc':
Comment thread
AaronDavidSchneider marked this conversation as resolved.
Expand Down Expand Up @@ -1033,13 +1051,13 @@ def test_get_grid_from_input(all_grid_datadirs, usedask):
order='F', partial_read=True,
offset=nx*ny5*sizeofd)

xc = np.concatenate([xc1[:-1, :-1].flatten(), xc2[:-1, :-1].flatten(),
xc3[:-1, :-1].flatten(), xc4[:-1, :-1].flatten(),
xc5[:-1, :-1].flatten()])
xc = np.concatenate([xc1.flatten(), xc2.flatten(),
xc3.flatten(), xc4.flatten(),
xc5.flatten()])

yc = np.concatenate([yc1[:-1, :-1].flatten(), yc2[:-1, :-1].flatten(),
yc3[:-1, :-1].flatten(), yc4[:-1, :-1].flatten(),
yc5[:-1, :-1].flatten()])
yc = np.concatenate([yc1.flatten(), yc2.flatten(),
yc3.flatten(), yc4.flatten(),
yc5.flatten()])

xc_from_ds = ds['XC'].values.flatten()
yc_from_ds = ds['YC'].values.flatten()
Expand All @@ -1056,7 +1074,8 @@ def test_get_grid_from_input(all_grid_datadirs, usedask):
geometry=expected['geometry'],
dtype=np.dtype('d'), endian='>',
use_dask=False,
extra_metadata=None)
extra_metadata=None,
outer=outer)


@pytest.mark.parametrize("dtype", [np.dtype('d'), np.dtype('f')])
Expand Down
55 changes: 40 additions & 15 deletions xmitgcm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,7 +1248,7 @@ def _pad_array(data, file_metadata, face=0):


def get_extra_metadata(domain='llc', nx=90):
"""
"""
Return the extra_metadata dictionay for selected domains

PARAMETERS
Expand Down Expand Up @@ -1308,9 +1308,9 @@ def get_extra_metadata(domain='llc', nx=90):


def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
dtype=np.dtype('d'), endian='>', use_dask=False,
dtype=np.dtype('d'), endian='>', use_dask=False, outer=False,
extra_metadata=None):
"""
"""
Read grid variables from grid input files, this is especially useful
for llc and cube sphere configurations used with land tiles
elimination. Reading the input grid files (e.g. tile00[1-5].mitgrid)
Expand All @@ -1332,11 +1332,13 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
endianness of input data
use_dask : bool
use dask or not
outer : bool
include outer boundary or not
extra_metadata : dict
dictionary of extra metadata, needed for llc configurations

RETURNS
-------
-------
grid : xarray.Dataset
all grid variables
"""
Expand All @@ -1347,6 +1349,10 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
'XG', 'YG', 'DXV', 'DYU', 'RAZ',
'DXC', 'DYC', 'RAW', 'RAS', 'DXG', 'DYG']

outerx_vars = ['DXC', 'RAW', 'DYG'] if outer else []
outery_vars = ['DYC', 'RAS', 'DXG'] if outer else []
outerxy_vars = ['XG', 'YG', 'RAZ'] if outer else []

file_metadata['vars'] = file_metadata['fldList']
dims_vars_list = []
for var in file_metadata['fldList']:
Expand Down Expand Up @@ -1399,6 +1405,7 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
nxgrid = file_metadata['ny_facets'][kfacet] + 1
nygrid = file_metadata['nx'] + 1


grid_metadata.update({'nx': nxgrid, 'ny': nygrid,
'has_faces': False})

Expand All @@ -1412,8 +1419,9 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
{file_metadata['fldList'][kfield]: raw[kfield]})

for field in file_metadata['fldList']:
# symetrize
tmp = rawfields[field][:, :, :-1, :-1].squeeze()

# get the full array
tmp = rawfields[field].squeeze()
# transpose
if grid_metadata['facet_orders'][kfacet] == 'F':
tmp = tmp.transpose()
Expand All @@ -1423,15 +1431,30 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
if grid_metadata['face_facets'][face] == kfacet:
# get offset of face from facet
offset = file_metadata['face_offsets'][face]
nx = file_metadata['nx']

nx = file_metadata['nx'] + 1
nxm1 = file_metadata['nx']
pad_metadata = file_metadata.copy()
pad_metadata['nx'] = file_metadata['nx'] + 1
# pad data, if needed (would trigger eager data eval)
# needs a new array not to pad multiple times
padded = _pad_array(tmp, file_metadata, face=face)
padded = _pad_array(tmp, pad_metadata, face=face)
# extract the data
dataface = padded[offset*nx:(offset+1)*nx, :]
dataface = padded[offset*nxm1:offset*nxm1 + nx, :]
# transpose, if needed
if file_metadata['transpose_face'][face]:
dataface = dataface.transpose()

# remove irrelevant data
if field in outerx_vars:
dataface = dataface[..., :-1, :].squeeze()
elif field in outery_vars:
dataface = dataface[..., :-1].squeeze()
elif field in outerxy_vars:
dataface = dataface.squeeze()
else:
dataface = dataface[..., :-1, :-1].squeeze()

# assign values
dataface = dsa.stack([dataface], axis=0)
if face == 0:
Expand All @@ -1441,6 +1464,7 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
[gridfields[field], dataface], axis=0)

# create the dataset
nxouter = file_metadata['nx'] + 1 if outer else file_metadata['nx']
if geometry == 'llc':
grid = xr.Dataset({'XC': (['face', 'j', 'i'], gridfields['XC']),
'YC': (['face', 'j', 'i'], gridfields['YC']),
Expand All @@ -1462,9 +1486,9 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
coords={'i': (['i'], np.arange(file_metadata['nx'])),
'j': (['j'], np.arange(file_metadata['nx'])),
'i_g': (['i_g'],
np.arange(file_metadata['nx'])),
np.arange(nxouter)),
'j_g': (['j_g'],
np.arange(file_metadata['nx'])),
np.arange(nxouter)),
Comment thread
AaronDavidSchneider marked this conversation as resolved.
'face': (['face'], np.arange(nfaces))
}
)
Expand All @@ -1489,13 +1513,14 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
coords={'i': (['i'], np.arange(file_metadata['nx'])),
'j': (['j'], np.arange(file_metadata['nx'])),
'i_g': (['i_g'],
np.arange(file_metadata['nx'])),
np.arange(nxouter)),
'j_g': (['j_g'],
np.arange(file_metadata['nx'])),
np.arange(nxouter)),
'face': (['face'], np.arange(nfaces))
}
)
else: # pragma: no cover
nyouter = file_metadata['ny'] + 1 if outer else file_metadata['ny']
grid = xr.Dataset({'XC': (['j', 'i'], gridfields['XC']),
'YC': (['j', 'i'], gridfields['YC']),
'DXF': (['j', 'i'], gridfields['DXF']),
Expand All @@ -1516,9 +1541,9 @@ def get_grid_from_input(gridfile, nx=None, ny=None, geometry='llc',
coords={'i': (['i'], np.arange(file_metadata['nx'])),
'j': (['j'], np.arange(file_metadata['ny'])),
'i_g': (['i_g'],
np.arange(file_metadata['nx'])),
np.arange(nxouter)),
'j_g': (['j_g'],
np.arange(file_metadata['ny']))
np.arange(nyouter))
}
)

Expand Down