import ast
import concurrent.futures
import os
import typing
import warnings
from copy import deepcopy
import dask
import packaging.version
import xarray as xr
# ``DataTree`` ships inside xarray from 2024.10 onward; older xarray releases
# need the standalone ``xarray-datatree`` package. Record whether either one
# could be imported so ``to_datatree`` can fail with a helpful message.
try:
    if packaging.version.Version(xr.__version__) >= packaging.version.Version('2024.10'):
        from xarray import DataTree
    else:
        from datatree import DataTree
except ImportError:
    _DATATREE_AVAILABLE = False
else:
    _DATATREE_AVAILABLE = True
import itables
import pandas as pd
import polars as pl
import pydantic
from fastprogress.fastprogress import progress_bar
from intake.catalog import Catalog
from .cat import ESMCatalogModel
from .derived import DerivedVariableRegistry, default_registry
from .source import ESMDataSource
from .utils import MinimalExploder
class esm_datastore(Catalog):
    """
    An intake plugin for parsing an ESM (Earth System Model) Catalog
    and loading assets (netCDF files and/or Zarr stores) into xarray datasets.
    The in-memory representation for the catalog is a Pandas DataFrame.

    Parameters
    ----------
    obj : str, dict, ESMCatalogModel
        The ESM Catalog to use, or a path to a JSON file containing the catalog.
        If string, this must be a path or URL to an ESM catalog JSON file.
        If dict, this must be a dict representation of an ESM catalog.
        This dict must have two keys: 'esmcat' and 'df'. The 'esmcat' key must be a
        dict representation of the ESM catalog and the 'df' key must
        be a Pandas DataFrame containing content that would otherwise be in a CSV file.
    progressbar : bool, optional
        Whether to print the key template and show a progress bar when loading
        datasets, by default True. Can be overridden per call of
        :py:meth:`to_dataset_dict`. Catalogs returned by :py:meth:`search` inherit it.
    sep : str, optional
        Delimiter to use when constructing a key for a query, by default '.'.
        Catalogs returned by :py:meth:`search` inherit it.
    registry : DerivedVariableRegistry, optional
        Registry of derived variables to use, by default None. If None, uses the
        default registry. An explicitly passed registry (even an empty one) is used as-is.
    read_kwargs : dict, optional
        Additional keyword arguments passed through to the :py:func:`~polars.scan_csv` function, if the
        datastore is saved in csv format, or :py:func:`~polars.scan_parquet` if the datastore is saved in
        parquet format. The dict passed in is not modified.
    read_csv_kwargs : dict, optional
        Deprecated alias for `read_kwargs`.
    columns_with_iterables : list of str, optional
        A list of columns in the catalog table containing iterables. An ``ast.literal_eval``
        converter is added to ``read_kwargs['converters']`` for each of these columns
        (i.e., this is a shortcut to passing converters to `read_kwargs`).
    storage_options : dict, optional
        Parameters passed to the backend file-system such as Google Cloud Storage,
        Amazon Web Service S3.
    threaded : bool, optional
        Whether datasets are loaded concurrently. If None, the ``ITK_ESM_THREADING``
        environment variable decides, defaulting to True when it is unset.
    intake_kwargs: dict, optional
        Additional keyword arguments are passed through to the :py:class:`~intake.catalog.Catalog` base class.

    Raises
    ------
    ValueError
        If both `read_kwargs` and `read_csv_kwargs` are given, if a converter in
        `read_kwargs` conflicts with `columns_with_iterables`, or if the derived
        variable registry does not match the catalog.

    Examples
    --------
    At import time, this plugin is available in intake's registry as `esm_datastore` and
    can be accessed with `intake.open_esm_datastore()`:

    >>> import intake
    >>> url = 'https://storage.googleapis.com/cmip6/pangeo-cmip6.json'
    >>> cat = intake.open_esm_datastore(url)
    >>> cat.df.head()
    activity_id institution_id source_id experiment_id ... variable_id grid_label zstore dcpp_init_year
    0 AerChemMIP BCC BCC-ESM1 ssp370 ... pr gn gs://cmip6/AerChemMIP/BCC/BCC-ESM1/ssp370/r1i1... NaN
    1 AerChemMIP BCC BCC-ESM1 ssp370 ... prsn gn gs://cmip6/AerChemMIP/BCC/BCC-ESM1/ssp370/r1i1... NaN
    2 AerChemMIP BCC BCC-ESM1 ssp370 ... tas gn gs://cmip6/AerChemMIP/BCC/BCC-ESM1/ssp370/r1i1... NaN
    3 AerChemMIP BCC BCC-ESM1 ssp370 ... tasmax gn gs://cmip6/AerChemMIP/BCC/BCC-ESM1/ssp370/r1i1... NaN
    4 AerChemMIP BCC BCC-ESM1 ssp370 ... tasmin gn gs://cmip6/AerChemMIP/BCC/BCC-ESM1/ssp370/r1i1... NaN
    """

    name = 'esm_datastore'
    container = 'xarray'

    def __init__(
        self,
        obj: pydantic.FilePath | pydantic.AnyUrl | dict[str, typing.Any] | ESMCatalogModel,
        *,
        progressbar: bool = True,
        sep: str = '.',
        registry: DerivedVariableRegistry | None = None,
        read_kwargs: dict[str, typing.Any] | None = None,
        read_csv_kwargs: dict[str, typing.Any] | None = None,
        columns_with_iterables: list[str] | None = None,
        storage_options: dict[str, typing.Any] | None = None,
        threaded: bool | None = None,
        **intake_kwargs: typing.Any,
    ):
        """Intake Catalog representing an ESM Collection."""
        super().__init__(**intake_kwargs)
        self.storage_options = storage_options or {}

        if read_csv_kwargs is not None:
            warnings.warn(
                'read_csv_kwargs is deprecated and will be removed in a future version. '
                'Please use read_kwargs instead.',
                DeprecationWarning,
                stacklevel=2,
            )
            if read_kwargs is not None:
                raise ValueError(
                    'Cannot provide both `read_csv_kwargs` and `read_kwargs`. '
                    'Please use `read_kwargs`.'
                )
            read_kwargs = read_csv_kwargs

        # Shallow-copy so neither the caller's dict nor its nested ``converters``
        # dict is mutated by the converter injection below.
        read_kwargs = dict(read_kwargs or {})
        if columns_with_iterables:
            converter = ast.literal_eval
            converters = dict(read_kwargs.get('converters') or {})
            for col in columns_with_iterables:
                # setdefault returns any converter the user already supplied for
                # this column; a different one is an ambiguous request.
                if converters.setdefault(col, converter) is not converter:
                    raise ValueError(
                        f"Cannot provide converter for '{col}' via `read_kwargs` when '{col}' is also specified in `columns_with_iterables`"
                    )
            read_kwargs['converters'] = converters
        self.read_kwargs = read_kwargs

        self.progressbar = progressbar
        self.sep = sep
        # Explicit value wins; otherwise the ITK_ESM_THREADING env var decides.
        self.threaded = _get_threaded(threaded)

        if isinstance(obj, ESMCatalogModel):
            self.esmcat = obj
        elif isinstance(obj, dict):
            self.esmcat = ESMCatalogModel.from_dict(obj)
        else:
            self.esmcat = ESMCatalogModel.load(
                obj, storage_options=self.storage_options, read_kwargs=read_kwargs
            )

        # Only ``None`` falls back to the default registry; an explicitly passed
        # empty registry must not be replaced just because it is falsy.
        self.derivedcat = registry if registry is not None else default_registry
        self._entries = {}
        self._requested_variables = []
        self._columns_with_iterables = columns_with_iterables or []
        self.datasets = {}
        self._validate_derivedcat()

    def _validate_derivedcat(self) -> None:
        """
        Check that every registered derived variable can be resolved against this catalog.

        Raises
        ------
        ValueError
            If derived variables are registered but the catalog has no
            ``aggregation_control``, if a derived variable's query omits the
            variable column, or if a query references a column that is not in
            the catalog table.
        """
        if self.esmcat.aggregation_control is None and len(self.derivedcat):
            raise ValueError(
                'Variable derivation requires `aggregation_control` to be specified in the catalog.'
            )
        for key, entry in self.derivedcat.items():
            variable_column_name = self.esmcat.aggregation_control.variable_column_name
            if variable_column_name not in entry.query.keys():
                raise ValueError(
                    f'Variable derivation requires `{variable_column_name}` to be specified in query: {entry.query} for derived variable {key}.'
                )
            for col in entry.query:
                if col not in self.esmcat.df.columns:
                    raise ValueError(
                        f'Derived variable {key} depends on unknown column {col} in query: {entry.query}. Valid ESM catalog columns: {self.esmcat.df.columns.tolist()}.'
                    )

    def keys(self) -> list[str]:
        """
        Get keys for the catalog entries

        Returns
        -------
        list
            keys for the catalog entries
        """
        return list(self.esmcat._construct_group_keys(sep=self.sep).keys())

    def keys_info(self) -> pd.DataFrame:
        """
        Get keys for the catalog entries and their metadata

        Returns
        -------
        pandas.DataFrame
            keys for the catalog entries and their metadata

        Examples
        --------
        >>> import intake
        >>> cat = intake.open_esm_datastore('./tests/sample-catalogs/cesm1-lens-netcdf.json')
        >>> cat.keys_info()
                        component experiment stream
        key
        ocn.20C.pop.h         ocn        20C  pop.h
        ocn.CTRL.pop.h        ocn       CTRL  pop.h
        ocn.RCP85.pop.h       ocn      RCP85  pop.h
        """
        results = self.esmcat._construct_group_keys(sep=self.sep)
        if self.esmcat.aggregation_control and self.esmcat.aggregation_control.groupby_attrs:
            groupby_attrs = self.esmcat.aggregation_control.groupby_attrs
        else:
            groupby_attrs = list(self.df.columns)
        data = {key: dict(zip(groupby_attrs, results[key])) for key in results}
        data = pd.DataFrame.from_dict(data, orient='index')
        data.index.name = 'key'
        return data

    @property
    def key_template(self) -> str:
        """
        Return string template used to create catalog entry keys

        Returns
        -------
        str
            string template used to create catalog entry keys
        """
        if self.esmcat.aggregation_control and self.esmcat.aggregation_control.groupby_attrs:
            return self.sep.join(self.esmcat.aggregation_control.groupby_attrs)
        else:
            return self.sep.join(self.esmcat.df.columns)

    @property
    def df(self) -> pd.DataFrame:
        """
        Return pandas :py:class:`~pandas.DataFrame`.
        """
        return self.esmcat.df

    @property
    def interactive(self) -> None:
        """
        Use itables to display the catalog in an interactive table. Use polars
        for performance ideally. Fall back to pandas if not.

        We have to explode columns with iterables, otherwise javascript stringification
        can cause ellipsis to be rendered directly into the interactive table,
        losing actual data and inserting junk.
        """
        try:
            pl_df = self.esmcat._frames.polars  # type:ignore[union-attr]
        except AttributeError:
            pl_df = pl.from_pandas(self.df)

        exploded_df = MinimalExploder(pl_df)()

        return itables.show(
            exploded_df,
            search={'regex': True, 'caseInsensitive': True},
            layout={'top1': 'searchPanes'},
            searchPanes={
                'layout': 'columns-3',
                'cascadePanes': True,
                'columns': list(range(len(pl_df.columns))),
            },
            maxBytes=0,
        )

    def __len__(self) -> int:
        """Return the number of datasets (distinct keys) in the catalog."""
        return len(self.keys())

    def _get_entries(self) -> dict[str, ESMDataSource]:
        """Return all entries, creating any that have not been accessed yet."""
        # Due to just-in-time entry creation, we may not have all entries loaded
        # We need to make sure to create entries missing from self._entries
        missing = set(self.keys()) - set(self._entries.keys())
        for key in missing:
            _ = self[key]
        return self._entries

    @pydantic.validate_call
    def __getitem__(self, key: str) -> ESMDataSource:
        """
        This method takes a key argument and return a data source
        corresponding to assets (files) that will be aggregated into a
        single xarray dataset.

        Parameters
        ----------
        key : str
            key to use for catalog entry lookup

        Returns
        -------
        intake_esm.source.ESMDataSource
            A data source by name (key)

        Raises
        ------
        KeyError
            if key is not found.

        Examples
        --------
        >>> cat = intake.open_esm_datastore('mycatalog.json')
        >>> data_source = cat['AerChemMIP.BCC.BCC-ESM1.piClim-control.AERmon.gn']
        """
        # The canonical unique key is the key of a compatible group of assets.
        # Entries are created lazily on first access and cached in ``self._entries``.
        try:
            return self._entries[key]
        except KeyError as e:
            # Build the key mapping once and use it for both the membership test
            # and the lookup (``self.keys()`` would rebuild it a second time).
            keys_dict = self.esmcat._construct_group_keys(sep=self.sep)
            if key not in keys_dict:
                raise KeyError(
                    f'key={key} not found in catalog. You can access the list of valid keys via the .keys() method.'
                ) from e

            grouped = self.esmcat.grouped
            internal_key = keys_dict[key]
            if isinstance(grouped, pd.DataFrame):
                records = [grouped.loc[internal_key].to_dict()]
            else:
                records = grouped.get_group(internal_key).to_dict(orient='records')

            if self.esmcat.aggregation_control:
                variable_column_name = self.esmcat.aggregation_control.variable_column_name
                aggregations = self.esmcat.aggregation_control.aggregations
            else:
                variable_column_name = None
                aggregations = []

            # Create a new entry
            entry = ESMDataSource(
                key=key,
                records=records,
                variable_column_name=variable_column_name,
                path_column_name=self.esmcat.assets.column_name,
                data_format=self.esmcat.assets.format,
                format_column_name=self.esmcat.assets.format_column_name,
                aggregations=aggregations,
                intake_kwargs={'metadata': {}},
                threaded=self.threaded,
            )
            self._entries[key] = entry
            return entry

    def __contains__(self, key) -> bool:
        """
        Return whether ``key`` names a dataset in this catalog.

        Python falls back to iterating over the entire catalog if this method is
        not defined, so it is implemented via a direct lookup instead. Keys that
        fail ``__getitem__``'s ``str`` validation are reported as absent rather
        than raising.
        """
        try:
            self[key]
        except (KeyError, pydantic.ValidationError):
            return False
        return True

    def __repr__(self) -> str:
        """Make string representation of object."""
        return f'<{self.esmcat.id or ""} catalog with {len(self)} dataset(s) from {len(self.df)} asset(s)>'

    def _repr_html_(self) -> str:
        """
        Return an html representation for the catalog object.
        Mainly for IPython notebook
        """
        uniques = pd.DataFrame(self.nunique(), columns=['unique'])
        text = uniques._repr_html_()
        return f'<p><strong>{self.esmcat.id or ""} catalog with {len(self)} dataset(s) from {len(self.df)} asset(s)</strong>:</p> {text}'

    def _ipython_display_(self):
        """
        Display the entry as a rich object in an IPython session
        """
        from IPython.display import HTML, display

        contents = self._repr_html_()
        display(HTML(contents))

    def __dir__(self) -> list[str]:
        """Return instance attributes plus the public catalog API, sorted."""
        rv = [
            'df',
            'to_dataset_dict',
            'to_datatree',
            'to_dask',
            'keys',
            'keys_info',
            'serialize',
            'datasets',
            'search',
            'unique',
            'nunique',
            'key_template',
        ]
        return sorted(list(self.__dict__.keys()) + rv)

    def _ipython_key_completions_(self):
        """Return the dataset keys offered by IPython when completing ``cat[...]``."""
        # IPython uses this hook for subscript completion, so it must return valid
        # keys for ``__getitem__`` rather than attribute names.
        return self.keys()

    @pydantic.validate_call
    def search(
        self,
        require_all_on: str | list[str] | None = None,
        **query: typing.Any,
    ):
        """Search for entries in the catalog.

        The returned catalog keeps this catalog's ``sep``, ``progressbar``,
        ``threaded`` and ``storage_options`` settings.

        Parameters
        ----------
        require_all_on : list, str, optional
            A dataframe column or a list of dataframe columns across
            which all entries must satisfy the query criteria.
            If None, return entries that fulfill any of the criteria specified
            in the query, by default None.
        **query:
            keyword arguments corresponding to user's query to execute against the dataframe.

        Returns
        -------
        cat : :py:class:`~intake_esm.core.esm_datastore`
            A new Catalog with a subset of the entries in this Catalog.

        Examples
        --------
        >>> import intake
        >>> cat = intake.open_esm_datastore('pangeo-cmip6.json')
        >>> cat.df.head(3)
        activity_id institution_id source_id ... grid_label zstore dcpp_init_year
        0 AerChemMIP BCC BCC-ESM1 ... gn gs://cmip6/AerChemMIP/BCC/BCC-ESM1/ssp370/r1i1... NaN
        1 AerChemMIP BCC BCC-ESM1 ... gn gs://cmip6/AerChemMIP/BCC/BCC-ESM1/ssp370/r1i1... NaN
        2 AerChemMIP BCC BCC-ESM1 ... gn gs://cmip6/AerChemMIP/BCC/BCC-ESM1/ssp370/r1i1... NaN
        >>> sub_cat = cat.search(
        ...     source_id=['BCC-CSM2-MR', 'CNRM-CM6-1', 'CNRM-ESM2-1'],
        ...     experiment_id=['historical', 'ssp585'],
        ...     variable_id='pr',
        ...     table_id='Amon',
        ...     grid_label='gn',
        ... )
        >>> sub_cat.df.head(3)
        activity_id institution_id source_id ... grid_label zstore dcpp_init_year
        260 CMIP BCC BCC-CSM2-MR ... gn gs://cmip6/CMIP/BCC/BCC-CSM2-MR/historical/r1i... NaN
        346 CMIP BCC BCC-CSM2-MR ... gn gs://cmip6/CMIP/BCC/BCC-CSM2-MR/historical/r2i... NaN
        401 CMIP BCC BCC-CSM2-MR ... gn gs://cmip6/CMIP/BCC/BCC-CSM2-MR/historical/r3i... NaN

        The search method also accepts compiled regular expression objects
        from :py:func:`~re.compile` as patterns.

        >>> import re
        >>> # Let's search for variables containing "Frac" in their name
        >>> pat = re.compile(r'Frac')  # Define a regular expression
        >>> cat.search(variable_id=pat)
        >>> cat.df.head().variable_id
        0 residualFrac
        1 landCoverFrac
        2 landCoverFrac
        3 residualFrac
        4 landCoverFrac
        """
        # step 1: Search in the base/main catalog (query still holds the variable column here)
        esmcat_results = self.esmcat.search(require_all_on=require_all_on, query=query)

        # step 2: Search for entries required to derive variables in the derived catalogs.
        # This requires a bit of a hack, i.e. the user has to specify the variable in the query.
        derivedcat_results = []
        if self.esmcat.aggregation_control:
            variables = query.pop(self.esmcat.aggregation_control.variable_column_name, None)
        else:
            variables = None
        dependents = []
        derived_cat_subset = {}
        if variables:
            if isinstance(variables, str):
                variables = [variables]
            for key, value in self.derivedcat.items():
                if key in variables:
                    res = self.esmcat.search(
                        require_all_on=require_all_on, query={**value.query, **query}
                    )
                    if not res.empty:
                        derivedcat_results.append(res)
                        dependents.extend(
                            value.dependent_variables(
                                self.esmcat.aggregation_control.variable_column_name
                            )
                        )
                        derived_cat_subset[key] = value

        if derivedcat_results:
            # Merge results from the main and the derived catalogs and drop duplicate rows.
            # Rows are compared as strings, presumably because cells may hold
            # unhashable iterables (see ``columns_with_iterables``).
            esmcat_results = pd.concat([esmcat_results, *derivedcat_results])
            esmcat_results = esmcat_results[~esmcat_results.astype(str).duplicated()]

        # Propagate the user-facing settings so the sub-catalog builds the same
        # keys (``sep``) and loads data the same way as this catalog.
        cat = self.__class__(
            {'esmcat': self.esmcat.model_dump(), 'df': esmcat_results},
            progressbar=self.progressbar,
            sep=self.sep,
            storage_options=self.storage_options,
            threaded=self.threaded,
        )
        cat.esmcat.catalog_file = None  # Don't save the catalog file
        if self.esmcat.has_multiple_variable_assets:
            requested_variables = list(set(variables or []).union(dependents))
        else:
            requested_variables = []
        cat._requested_variables = requested_variables

        # step 3: Subset the derived catalog,
        # but only if variables were looked up, otherwise transfer the whole catalog.
        if variables is not None:
            cat.derivedcat = DerivedVariableRegistry()
            cat.derivedcat._registry.update(derived_cat_subset)
        else:
            cat.derivedcat = self.derivedcat
        return cat

    @pydantic.validate_call
    def serialize(
        self,
        name: pydantic.StrictStr,
        directory: pydantic.DirectoryPath | pydantic.StrictStr | None = None,
        catalog_type: str = 'dict',
        to_csv_kwargs: dict[typing.Any, typing.Any] | None = None,
        json_dump_kwargs: dict[typing.Any, typing.Any] | None = None,
        storage_options: dict[str, typing.Any] | None = None,
    ) -> None:
        """Serialize catalog to corresponding json and csv files.

        Parameters
        ----------
        name : str
            name to use when creating ESM catalog json file and csv catalog.
        directory : str, PathLike, default None
            The path to the local directory. If None, use the current directory
        catalog_type: str, default 'dict'
            Whether to save the catalog table as a dictionary in the JSON file or as a separate CSV file.
        to_csv_kwargs : dict, optional
            Additional keyword arguments passed through to the :py:meth:`~pandas.DataFrame.to_csv` method.
        json_dump_kwargs : dict, optional
            Additional keyword arguments passed through to the :py:func:`~json.dump` function.
        storage_options: dict
            fsspec parameters passed to the backend file-system such as Google Cloud Storage,
            Amazon Web Service S3.

        Notes
        -----
        Large catalogs can result in large JSON files. To keep the JSON file size manageable, call with
        `catalog_type='file'` to save catalog as a separate CSV file.

        Examples
        --------
        >>> import intake
        >>> cat = intake.open_esm_datastore('pangeo-cmip6.json')
        >>> cat_subset = cat.search(
        ...     source_id='BCC-ESM1',
        ...     grid_label='gn',
        ...     table_id='Amon',
        ...     experiment_id='historical',
        ... )
        >>> cat_subset.serialize(name='cmip6_bcc_esm1', catalog_type='file')
        """
        self.esmcat.save(
            name,
            directory=directory,
            catalog_type=catalog_type,
            to_csv_kwargs=to_csv_kwargs,
            json_dump_kwargs=json_dump_kwargs,
            storage_options=storage_options,
        )

    def nunique(self) -> pd.Series:
        """Count distinct observations across dataframe columns
        in the catalog.

        When the catalog has ``aggregation_control``, an extra
        ``derived_<variable column>`` entry holds the number of registered
        derived variables.

        Examples
        --------
        >>> import intake
        >>> cat = intake.open_esm_datastore('pangeo-cmip6.json')
        >>> cat.nunique()
        activity_id 10
        institution_id 23
        source_id 48
        experiment_id 29
        member_id 86
        table_id 19
        variable_id 187
        grid_label 7
        zstore 27437
        dcpp_init_year 59
        dtype: int64
        """
        nunique = self.esmcat.nunique()
        if self.esmcat.aggregation_control:
            nunique[f'derived_{self.esmcat.aggregation_control.variable_column_name}'] = len(
                self.derivedcat.keys()
            )
        return nunique

    def unique(self) -> pd.Series:
        """Return unique values for given columns in the
        catalog.

        When the catalog has ``aggregation_control``, an extra
        ``derived_<variable column>`` entry lists the registered derived variables.
        """
        unique = self.esmcat.unique()
        if self.esmcat.aggregation_control:
            unique[f'derived_{self.esmcat.aggregation_control.variable_column_name}'] = list(
                self.derivedcat.keys()
            )
        return unique

    @pydantic.validate_call
    def to_dataset_dict(
        self,
        xarray_open_kwargs: dict[str, typing.Any] | None = None,
        xarray_combine_by_coords_kwargs: dict[str, typing.Any] | None = None,
        preprocess: typing.Callable | None = None,
        storage_options: dict[pydantic.StrictStr, typing.Any] | None = None,
        progressbar: pydantic.StrictBool | None = None,
        aggregate: pydantic.StrictBool | None = None,
        skip_on_error: pydantic.StrictBool = False,
        threaded: bool | None = None,
        **kwargs,
    ) -> dict[str, xr.Dataset]:
        """
        Load catalog entries into a dictionary of xarray datasets.

        Column values, dataset keys and requested variables are added as global
        attributes on the returned datasets. The names of these attributes can be
        customized with :py:class:`intake_esm.utils.set_options`.

        Parameters
        ----------
        xarray_open_kwargs : dict
            Keyword arguments to pass to :py:func:`~xarray.open_dataset` function.
            The dict passed in is not modified.
        xarray_combine_by_coords_kwargs: : dict
            Keyword arguments to pass to :py:func:`~xarray.combine_by_coords` function.
        preprocess : callable, optional
            If provided, call this function on each dataset prior to aggregation.
        storage_options : dict, optional
            fsspec Parameters passed to the backend file-system such as Google Cloud Storage,
            Amazon Web Service S3.
        progressbar : bool
            If True, will print a progress bar to standard error (stderr)
            when loading assets into :py:class:`~xarray.Dataset`. If None, use the
            catalog's ``progressbar`` setting. The setting applies to this call only.
        aggregate : bool, optional
            If False, no aggregation will be done.
        skip_on_error : bool, optional
            If True, skip datasets that cannot be loaded and/or variables we are unable to derive.
        threaded : bool, optional
            If True, load datasets concurrently in a thread pool with up to
            ``dask.system.CPU_COUNT`` workers. If False, load datasets one at a time.
            If None, the environment variable `ITK_ESM_THREADING` will be used to determine the threading behavior,
            defaulting to True if the variable is not set. If a value is provided, it will override the environment
            variable determined default. The resolved value is also forwarded to each data source.
        **kwargs
            Deprecated ``cdf_kwargs`` / ``zarr_kwargs`` are merged into ``xarray_open_kwargs``.

        Returns
        -------
        dsets : dict
            A dictionary of xarray :py:class:`~xarray.Dataset`.

        Examples
        --------
        >>> import intake
        >>> cat = intake.open_esm_datastore('glade-cmip6.json')
        >>> sub_cat = cat.search(
        ...     source_id=['BCC-CSM2-MR', 'CNRM-CM6-1', 'CNRM-ESM2-1'],
        ...     experiment_id=['historical', 'ssp585'],
        ...     variable_id='pr',
        ...     table_id='Amon',
        ...     grid_label='gn',
        ... )
        >>> dsets = sub_cat.to_dataset_dict()
        >>> dsets.keys()
        dict_keys(['CMIP.BCC.BCC-CSM2-MR.historical.Amon.gn', 'ScenarioMIP.BCC.BCC-CSM2-MR.ssp585.Amon.gn'])
        >>> dsets['CMIP.BCC.BCC-CSM2-MR.historical.Amon.gn']
        <xarray.Dataset>
        Dimensions:    (bnds: 2, lat: 160, lon: 320, member_id: 3, time: 1980)
        Coordinates:
        * lon        (lon) float64 0.0 1.125 2.25 3.375 ... 355.5 356.6 357.8 358.9
        * lat        (lat) float64 -89.14 -88.03 -86.91 -85.79 ... 86.91 88.03 89.14
        * time       (time) object 1850-01-16 12:00:00 ... 2014-12-16 12:00:00
        * member_id  (member_id) <U8 'r1i1p1f1' 'r2i1p1f1' 'r3i1p1f1'
        Dimensions without coordinates: bnds
        Data variables:
            lat_bnds   (lat, bnds) float64 dask.array<chunksize=(160, 2), meta=np.ndarray>
            lon_bnds   (lon, bnds) float64 dask.array<chunksize=(320, 2), meta=np.ndarray>
            time_bnds  (time, bnds) object dask.array<chunksize=(1980, 2), meta=np.ndarray>
            pr         (member_id, time, lat, lon) float32 dask.array<chunksize=(1, 600, 160, 320), meta=np.ndarray>
        """
        # Return fast
        if not self.keys():
            warnings.warn(
                'There are no datasets to load! Returning an empty dictionary.',
                UserWarning,
                stacklevel=2,
            )
            return {}

        if (
            self.esmcat.aggregation_control
            and (
                self.esmcat.aggregation_control.variable_column_name
                in self.esmcat.aggregation_control.groupby_attrs
            )
            and len(self.derivedcat) > 0
        ):
            raise NotImplementedError(
                f'The `{self.esmcat.aggregation_control.variable_column_name}` column name is used as a groupby attribute: {self.esmcat.aggregation_control.groupby_attrs}. '
                'This is not yet supported when computing derived variables.'
            )

        threaded = _get_threaded(threaded)

        # Copy so merging the deprecated kwargs below never mutates the caller's dict.
        xarray_open_kwargs = dict(xarray_open_kwargs or {})
        xarray_combine_by_coords_kwargs = xarray_combine_by_coords_kwargs or {}

        cdf_kwargs, zarr_kwargs = kwargs.get('cdf_kwargs'), kwargs.get('zarr_kwargs')
        if cdf_kwargs or zarr_kwargs:
            warnings.warn(
                'cdf_kwargs and zarr_kwargs are deprecated and will be removed in a future version. '
                'Please use xarray_open_kwargs instead.',
                DeprecationWarning,
                stacklevel=2,
            )
        if cdf_kwargs:
            xarray_open_kwargs.update(cdf_kwargs)
        if zarr_kwargs:
            xarray_open_kwargs.update(zarr_kwargs)

        source_kwargs = dict(
            xarray_open_kwargs=xarray_open_kwargs,
            xarray_combine_by_coords_kwargs=xarray_combine_by_coords_kwargs,
            preprocess=preprocess,
            requested_variables=self._requested_variables,
            storage_options=storage_options,
            threaded=threaded,
        )

        if aggregate is not None and not aggregate and self.esmcat.aggregation_control:
            # Work on a copy so disabling aggregation does not alter this catalog.
            self = deepcopy(self)
            self.esmcat.aggregation_control.groupby_attrs = []

        # Resolve the progress setting for this call only; assigning it to
        # ``self.progressbar`` would leak into every later call.
        show_progress = self.progressbar if progressbar is None else progressbar
        if show_progress:
            print(
                f"""\n--> The keys in the returned dictionary of datasets are constructed as follows:\n\t'{self.key_template}'"""
            )

        sources = {key: source(**source_kwargs) for key, source in self.items()}
        datasets = {}
        # A single worker loads the sources sequentially when threading is disabled.
        max_workers = dask.system.CPU_COUNT if threaded else 1
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            future_tasks = [
                executor.submit(_load_source, key, source) for key, source in sources.items()
            ]
            if show_progress:
                gen = progress_bar(
                    concurrent.futures.as_completed(future_tasks), total=len(sources)
                )
            else:
                gen = concurrent.futures.as_completed(future_tasks)
            for task in gen:
                try:
                    key, ds = task.result()
                    datasets[key] = ds
                except Exception:
                    if not skip_on_error:
                        raise
        self.datasets = self._create_derived_variables(datasets, skip_on_error)
        return self.datasets

    @pydantic.validate_call
    def to_datatree(
        self,
        xarray_open_kwargs: dict[str, typing.Any] | None = None,
        xarray_combine_by_coords_kwargs: dict[str, typing.Any] | None = None,
        preprocess: typing.Callable | None = None,
        storage_options: dict[pydantic.StrictStr, typing.Any] | None = None,
        progressbar: pydantic.StrictBool | None = None,
        aggregate: pydantic.StrictBool | None = None,
        skip_on_error: pydantic.StrictBool = False,
        levels: list[str] | None = None,
        **kwargs,
    ):
        """
        Load catalog entries into a tree of xarray datasets.

        Parameters
        ----------
        xarray_open_kwargs : dict
            Keyword arguments to pass to :py:func:`~xarray.open_dataset` function
        xarray_combine_by_coords_kwargs: : dict
            Keyword arguments to pass to :py:func:`~xarray.combine_by_coords` function.
        preprocess : callable, optional
            If provided, call this function on each dataset prior to aggregation.
        storage_options : dict, optional
            Parameters passed to the backend file-system such as Google Cloud Storage,
            Amazon Web Service S3.
        progressbar : bool
            If True, will print a progress bar to standard error (stderr)
            when loading assets into :py:class:`~xarray.Dataset`.
        aggregate : bool, optional
            If False, no aggregation will be done.
        skip_on_error : bool, optional
            If True, skip datasets that cannot be loaded and/or variables we are unable to derive.
        levels : list[str], optional
            List of fields to use as the datatree nodes. WARNING: This will overwrite the fields
            used to create the unique aggregation keys (on a copy of the catalog).
        **kwargs
            Forwarded to :py:meth:`to_dataset_dict` (e.g. ``threaded``).

        Returns
        -------
        dsets : DataTree
            A tree of xarray :py:class:`~xarray.Dataset`.

        Raises
        ------
        ImportError
            If no DataTree implementation is available.
        ValueError
            If `levels` is given but the catalog has no ``aggregation_control``.

        Examples
        --------
        >>> import intake
        >>> cat = intake.open_esm_datastore('glade-cmip6.json')
        >>> sub_cat = cat.search(
        ...     source_id=['BCC-CSM2-MR', 'CNRM-CM6-1', 'CNRM-ESM2-1'],
        ...     experiment_id=['historical', 'ssp585'],
        ...     variable_id='pr',
        ...     table_id='Amon',
        ...     grid_label='gn',
        ... )
        >>> dsets = sub_cat.to_datatree()
        >>> dsets['CMIP/BCC/BCC-CSM2-MR/historical/Amon/gn'].ds
        <xarray.Dataset>
        Dimensions:    (bnds: 2, lat: 160, lon: 320, member_id: 3, time: 1980)
        Coordinates:
        * lon        (lon) float64 0.0 1.125 2.25 3.375 ... 355.5 356.6 357.8 358.9
        * lat        (lat) float64 -89.14 -88.03 -86.91 -85.79 ... 86.91 88.03 89.14
        * time       (time) object 1850-01-16 12:00:00 ... 2014-12-16 12:00:00
        * member_id  (member_id) <U8 'r1i1p1f1' 'r2i1p1f1' 'r3i1p1f1'
        Dimensions without coordinates: bnds
        Data variables:
            lat_bnds   (lat, bnds) float64 dask.array<chunksize=(160, 2), meta=np.ndarray>
            lon_bnds   (lon, bnds) float64 dask.array<chunksize=(320, 2), meta=np.ndarray>
            time_bnds  (time, bnds) object dask.array<chunksize=(1980, 2), meta=np.ndarray>
            pr         (member_id, time, lat, lon) float32 dask.array<chunksize=(1, 600, 160, 320), meta=np.ndarray>
        """
        if not _DATATREE_AVAILABLE:
            raise ImportError(
                '.to_datatree() requires the xarray-datatree package to be installed. '
                'To proceed please install xarray-datatree using: '
                ' `python -m pip install xarray-datatree` or `conda install -c conda-forge xarray-datatree`.'
            )

        # Change the groupby controls if necessary, used to assemble the tree
        if levels is not None:
            if self.esmcat.aggregation_control is None:
                raise ValueError(
                    '`levels` requires `aggregation_control` to be specified in the catalog.'
                )
            self = deepcopy(self)
            self.esmcat.aggregation_control.groupby_attrs = levels

        # Temporarily use '/' so keys become DataTree paths; restore it even if
        # loading fails so the catalog is not left with a modified separator.
        self.sep, old_sep = '/', self.sep
        try:
            datasets = self.to_dataset_dict(
                xarray_open_kwargs=xarray_open_kwargs,
                xarray_combine_by_coords_kwargs=xarray_combine_by_coords_kwargs,
                preprocess=preprocess,
                storage_options=storage_options,
                progressbar=progressbar,
                aggregate=aggregate,
                skip_on_error=skip_on_error,
                **kwargs,
            )
        finally:
            self.sep = old_sep

        # Convert the dictionary of datasets to a datatree
        self.datasets = DataTree.from_dict(datasets)
        return self.datasets

    def to_dask(self, **kwargs) -> xr.Dataset:
        """
        Convert result to an xarray dataset.

        This is only possible if the search returned exactly one result.

        Parameters
        ----------
        kwargs: dict
            Parameters forwarded to :py:func:`~intake_esm.esm_datastore.to_dataset_dict`
            (``progressbar`` is always forced to False).

        Returns
        -------
        :py:class:`~xarray.Dataset`

        Raises
        ------
        ValueError
            If the catalog, or the loaded result, does not contain exactly one dataset.
        """
        if len(self) != 1:  # quick check to fail more quickly if there are many results
            raise ValueError(
                f'Expected exactly one dataset. Received {len(self)} datasets. Please refine your search or use `.to_dataset_dict()`.'
            )
        res = self.to_dataset_dict(**{**kwargs, 'progressbar': False})
        if len(res) != 1:  # extra check in case kwargs did modify something
            raise ValueError(
                f'Expected exactly one dataset. Received {len(res)} datasets. Please refine your search or use `.to_dataset_dict()`.'
            )
        _, ds = res.popitem()
        return ds

    def _create_derived_variables(self, datasets, skip_on_error):
        """Add registered derived variables to ``datasets`` (no-op if none are registered)."""
        if len(self.derivedcat) > 0:
            datasets = self.derivedcat.update_datasets(
                datasets=datasets,
                variable_key_name=self.esmcat.aggregation_control.variable_column_name,
                skip_on_error=skip_on_error,
            )
        return datasets
def _load_source(key, source):
return key, source.to_dask()
def _get_threaded(threaded: bool | None) -> bool:
"""
Read the threading option from the environment variable & passed value
"""
if threaded is None:
try:
threaded = ast.literal_eval(os.getenv('ITK_ESM_THREADING', 'True'))
except ValueError as e:
raise ValueError(
'The environment variable ITK_ESM_THREADING must be a boolean, if set.'
) from e
return threaded