# Source code for intake_esm.derived

import importlib
import inspect
import typing

import pydantic
import tlz
import xarray as xr


class DerivedVariableError(Exception):
    """Raised when the function behind a derived variable fails."""


class DerivedVariable(pydantic.BaseModel):
    """A function paired with the query naming the variables it depends on.

    Instances are callable: calling one forwards the arguments to ``func``
    and wraps any failure in a :class:`DerivedVariableError`.
    """

    # Callable invoked by ``__call__``; it receives the caller's args/kwargs unchanged.
    func: typing.Callable
    # Name of the variable this entry derives.
    variable: pydantic.StrictStr
    # Maps query keys to required value(s); scalar values are wrapped into
    # single-item lists by ``validate_query``.
    query: dict[pydantic.StrictStr, typing.Any | list[typing.Any]]
    # Whether to derive even when a dataset already holds a variable of this name.
    prefer_derived: bool

    @pydantic.field_validator('query')
    @classmethod
    def validate_query(cls, values):
        """Wrap scalar query values (str, int, float, bool) in single-item lists.

        The incoming mapping is left untouched; a new dict is returned.
        """
        return {
            key: [value] if isinstance(value, str | int | float | bool) else value
            for key, value in values.items()
        }

    def dependent_variables(self, variable_key_name: str) -> list[pydantic.StrictStr]:
        """Return a list of dependent variables for a given variable

        Raises
        ------
        KeyError
            If `variable_key_name` is not a key of the query.
        """
        return self.query[variable_key_name]

    def __call__(self, *args, variable_key_name: str | None = None, **kwargs) -> xr.Dataset:
        """Call the function and return the result

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to ``func``.
        variable_key_name : str, optional
            Query key whose values are reported as the dependent variables
            in the error message. It is not passed to ``func``.

        Raises
        ------
        DerivedVariableError
            If ``func`` raises; the original exception is chained as the cause.
        """
        try:
            return self.func(*args, **kwargs)
        except Exception as exc:
            # Look the key up with .get(): a key missing from the query must not
            # raise KeyError here and replace the DerivedVariableError below.
            dependent_variables = (
                self.query.get(variable_key_name, []) if variable_key_name else []
            )
            raise DerivedVariableError(
                f'Unable to derived variable: {self.variable} with dependent: {dependent_variables} using args:{args} and kwargs:{kwargs}'
            ) from exc
@pydantic.dataclasses.dataclass
class DerivedVariableRegistry:
    """Registry of derived variables

    Maps a variable name to the :class:`DerivedVariable` that computes it and
    exposes a read-only, dict-like interface over that mapping.
    """

    def __post_init__(self):
        # Backing store: variable name -> DerivedVariable.
        self._registry = {}

    @classmethod
    def load(cls, name: str, package: str | None = None) -> 'DerivedVariableRegistry':
        """Load a DerivedVariableRegistry from a Python module/file

        Parameters
        ----------
        name : str
            The name of the module to load the DerivedVariableRegistry from.
        package : str, optional
            The package to load the module from. This argument is required when performing
            a relative import. It specifies the package to use as the anchor point from which
            to resolve the relative import to an absolute import.

        Returns
        -------
        DerivedVariableRegistry
            A DerivedVariableRegistry loaded from the Python module.

        Raises
        ------
        ValueError
            If the module has no member that is a DerivedVariableRegistry instance.

        Notes
        -----
        If you have a folder: /home/foo/pythonfiles, and you want to load a registry
        defined in registry.py, located in that directory, ensure to add your folder to the
        $PYTHONPATH before calling this function.

        >>> import sys
        >>> sys.path.insert(0, '/home/foo/pythonfiles')
        >>> from intake_esm.derived import DerivedVariableRegistry
        >>> registry = DerivedVariableRegistry.load('registry')
        """
        module = importlib.import_module(name, package=package)
        # getmembers() returns (name, value) pairs sorted by name, so when the
        # module holds several registries the alphabetically first one is returned.
        if candidates := inspect.getmembers(
            module, lambda x: isinstance(x, DerivedVariableRegistry)
        ):
            return candidates[0][1]
        else:
            raise ValueError(f'No DerivedVariableRegistry found in {name} module')

    @tlz.curry
    def register(
        self,
        func: typing.Callable,
        *,
        variable: str,
        query: dict[pydantic.StrictStr, typing.Any | list[typing.Any]],
        prefer_derived: bool = False,
    ) -> typing.Callable:
        """Register a derived variable

        Parameters
        ----------
        func : typing.Callable
            The function to apply to the dependent variables.
        variable : str
            The name of the variable to derive.
        query : dict[str, typing.Any | list[typing.Any]]
            The query to use to retrieve dependent variables required to derive `variable`.
        prefer_derived: bool, optional (default=False)
            Specify whether to compute this variable on datasets that already contain a variable
            of the same name. Default (False) is to leave the existing variable.

        Returns
        -------
        typing.Callable
            The function that was registered.
        """
        # An existing entry under the same variable name is replaced.
        self._registry[variable] = DerivedVariable(
            func=func, variable=variable, query=query, prefer_derived=prefer_derived
        )
        return func

    def __contains__(self, item: str) -> bool:
        return item in self._registry

    def __getitem__(self, item: str) -> DerivedVariable:
        return self._registry[item]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self._registry)

    def __repr__(self) -> str:
        return f'DerivedVariableRegistry({self._registry})'

    def __len__(self) -> int:
        return len(self._registry)

    def items(self) -> list[tuple[str, DerivedVariable]]:
        """Return the (name, DerivedVariable) pairs as a new list."""
        return list(self._registry.items())

    def keys(self) -> list[str]:
        """Return the registered variable names as a new list."""
        return list(self._registry.keys())

    def values(self) -> list[DerivedVariable]:
        """Return the registered DerivedVariable objects as a new list."""
        return list(self._registry.values())

    def search(self, variable: str | list[str]) -> 'DerivedVariableRegistry':
        """Search for a derived variable by name or list of names

        Parameters
        ----------
        variable : str | list[str]
            The name of the variable to search for.

        Returns
        -------
        DerivedVariableRegistry
            A DerivedVariableRegistry with the found variables.
        """
        if isinstance(variable, str):
            variable = [variable]
        results = tlz.dicttoolz.keyfilter(lambda x: x in variable, self._registry)
        reg = DerivedVariableRegistry()
        reg._registry = results
        return reg

    def update_datasets(
        self,
        *,
        datasets: dict[str, xr.Dataset],
        variable_key_name: str,
        skip_on_error: bool = False,
    ) -> dict[str, xr.Dataset]:
        """Given a dictionary of datasets, return a dictionary of datasets with the derived variables

        A derived variable is computed for a dataset when the dataset holds all
        of its dependent variables and either lacks the derived variable or the
        entry was registered with ``prefer_derived=True``. Derived variables are
        applied in registry order, each one to the output of the previous
        successful derivation for that dataset.

        Parameters
        ----------
        datasets : dict[str, xr.Dataset]
            A dictionary of datasets to apply the derived variables to.
            The dictionary is updated in place and also returned.
        variable_key_name : str
            The name of the variable key used in the derived variable query
        skip_on_error : bool, optional
            If True, skip variables that fail variable derivation.

        Returns
        -------
        dict[str, xr.Dataset]
            A dictionary of datasets with the derived variables applied.
        """
        for dset_key, dataset in datasets.items():
            for derived_variable in self.values():
                if not set(dataset.variables).issuperset(
                    derived_variable.dependent_variables(variable_key_name)
                ):
                    continue
                if (
                    derived_variable.variable in dataset.variables
                    and not derived_variable.prefer_derived
                ):
                    continue
                try:
                    # Assumes all dependent variables are in the same dataset
                    # TODO: Make this more robust to support datasets with variables from different datasets
                    result = derived_variable(dataset, variable_key_name=variable_key_name)
                except Exception:
                    if not skip_on_error:
                        raise
                else:
                    # Rebind the working dataset so the next derived variable
                    # builds on this result rather than on the stale input,
                    # which would discard what was just derived.
                    dataset = datasets[dset_key] = result
        return datasets
# Module-level registry instance, created empty when this module is imported.
# NOTE(review): presumably the registry used when callers do not supply their
# own — confirm against the call sites.
default_registry = DerivedVariableRegistry()