# Source code for intake_esm.utils

"""Helper functions for fetching and loading catalog"""

import importlib
import sys
from collections import defaultdict

import polars as pl
import zarr

# Names picked up by a star-import of this module
# (``from intake_esm.utils import *``). Other names defined in this module,
# for example ``show_versions``, are not listed here and so are not exported
# by a star-import; they can still be imported explicitly by name.
__all__ = [
    'OPTIONS',
    'set_options',
    '_set_async_flag',
]


def show_versions(file=sys.stdout):  # pragma: no cover
    """Print the versions of intake-esm and its dependencies.

    Adapted from xarray/util/print_versions.py

    Parameters
    ----------
    file : file-like, optional
        print to the given file-like object. Defaults to sys.stdout.
    """

    package_names = (
        'xarray',
        'pandas',
        'intake',
        'intake_esm',
        'fsspec',
        's3fs',
        'gcsfs',
        'fastprogress',
        'dask',
        'zarr',
        'cftime',
        'netCDF4',
        'requests',
    )

    report = []
    for name in package_names:
        # Reuse an already-imported module when there is one; otherwise try
        # to import it. Any failure here is reported as ``None``.
        try:
            module = sys.modules[name] if name in sys.modules else importlib.import_module(name)
        except Exception:
            report.append((name, None))
            continue

        # The module is available; if its version cannot be read, report it
        # as merely 'installed'.
        try:
            version = module.__version__
        except Exception:
            version = 'installed'
        report.append((name, version))

    for header_line in ('\nINSTALLED VERSIONS', '------------------', ''):
        print(header_line, file=file)

    # One "<package>: <version>" line per package, in sorted order.
    report.sort()
    for name, version in report:
        print(f'{name}: {version}', file=file)


def _zarr_async() -> bool:
    """
    Tell whether the installed zarr is version 3 or newer.

    Zarr went all async in version 3.0.0, so the result is used as the value
    of the ``asynchronous`` flag placed in the storage options.

    Returns
    -------
    bool
        True when the major component of ``zarr.__version__`` is 3 or higher,
        False otherwise.
    """
    # Only the text before the first dot (the major version) matters.
    major_version, _, _ = zarr.__version__.partition('.')
    return int(major_version) >= 3


def _set_async_flag(data_format: str, xarray_open_kwargs: dict) -> dict:
    """
    If we have the data format set to either zarr2 or zarr3, the async flag in
    `xarray_open_kwargs['storage_options']['remote_opetions']` is constrained to
    be either False or True, respectively.

    Parameters
    ----------
    data_format : str

    xarray_open_kwargs : dict
        The xarray open kwargs dictionary that may contain storage options.
    Returns
    -------
    dict
        The updated xarray open kwargs with the async flag set appropriately.
    """
    if data_format not in {'zarr2', 'zarr3'}:
        return xarray_open_kwargs

    storage_opts_template = {
        'backend_kwargs': {'storage_options': {'remote_options': {'asynchronous': _zarr_async()}}}
    }
    if (
        xarray_open_kwargs.get('backend_kwargs', {})
        .get('storage_options', {})
        .get('remote_options', None)
        is not None
    ):
        xarray_open_kwargs['backend_kwargs']['storage_options']['remote_options'][
            'asynchronous'
        ] = _zarr_async()
    elif xarray_open_kwargs.get('backend_kwargs', {}).get('storage_options', None) is not None:
        xarray_open_kwargs['backend_kwargs']['storage_options'] = storage_opts_template[
            'backend_kwargs'
        ]['storage_options']
    elif xarray_open_kwargs.get('backend_kwargs', None) is not None:
        xarray_open_kwargs['backend_kwargs'] = storage_opts_template['backend_kwargs']
    else:
        xarray_open_kwargs = storage_opts_template

    return xarray_open_kwargs


# Module-level option registry. ``set_options`` validates new values against
# these keys and updates this dictionary in place.
OPTIONS = {
    # Prefix used in the names of dataset attributes built from catalog columns.
    'attrs_prefix': 'intake_esm_attrs',
    # Name of the global attribute that stores the dataset's key.
    'dataset_key': 'intake_esm_dataset_key',
    # Name of the global attribute that stores the list of requested variables.
    'vars_key': 'intake_esm_vars',
}


class set_options:
    """Set options for intake_esm in a controlled context.

    Currently-supported options:

    - ``attrs_prefix``:
      The prefix to use in the names of attributes constructed from the
      catalog's columns when returning xarray Datasets.
      Default: ``intake_esm_attrs``.
    - ``dataset_key``:
      Name of the global attribute where to store the dataset's key.
      Default: ``intake_esm_dataset_key``.
    - ``vars_key``:
      Name of the global attribute where to store the list of requested
      variables when opening a dataset.
      Default: ``intake_esm_vars``.

    Examples
    --------
    You can use ``set_options`` either as a context manager:

    >>> import intake
    >>> import intake_esm
    >>> cat = intake.open_esm_datastore('catalog.json')
    >>> with intake_esm.set_options(attrs_prefix='cat'):
    ...     out = cat.to_dataset_dict()

    Or to set global options:

    >>> intake_esm.set_options(attrs_prefix='cat', vars_key='cat_vars')
    """

    def __init__(self, **kwargs):
        self.old = {}

        # Validate every requested option before anything is changed, so an
        # invalid argument leaves OPTIONS untouched.
        for k, v in kwargs.items():
            if k not in OPTIONS:
                raise ValueError(
                    f'argument name {k} is not in the set of valid options {set(OPTIONS)}'
                )
            if not isinstance(v, str):
                raise ValueError(f'option {k} given an invalid value: {v}')

        # Snapshot the current values so ``__exit__`` can restore them, then
        # apply the new ones.
        self.old = {k: OPTIONS[k] for k in kwargs}
        self._update(kwargs)

    def __enter__(self):
        """Context management."""
        return None

    def _update(self, kwargs):
        """Update values."""
        OPTIONS.update(kwargs)

    def __exit__(self, type, value, traceback):
        """Context management."""
        self._update(self.old)
class MinimalExploder:
    """
    Analyze a DataFrame with multiple list columns and explode it using as
    few explode operations as possible.

    List columns whose per-row list lengths are identical are grouped and
    exploded together in a single operation. Calling the instance returns the
    exploded DataFrame.
    """

    def __init__(self, df: pl.DataFrame):
        self.df = df
        # Lazily computed caches backing the properties below.
        self._list_cols: list[str] | None = None
        self._length_patterns: dict[str, tuple[int, ...]] | None = None
        self._explodable_groups: list[list[str]] | None = None

    @property
    def list_columns(self) -> list[str]:
        """Get all list-type columns in the DataFrame."""
        if self._list_cols is not None:
            return self._list_cols
        frame = self.df
        self._list_cols = [name for name in frame.columns if frame[name].dtype == pl.List]
        return self._list_cols

    @property
    def length_patterns(self) -> dict[str, tuple[int, ...]]:
        """Get length patterns for all list columns.

        This is stored as a dictionary mapping each list column to a tuple of
        its list lengths, ie
            'a' : (1,3,2),
            'b' : (2,2,2),
        """
        if self._length_patterns is not None:
            return self._length_patterns
        self._length_patterns = self._analyze_patterns()
        return self._length_patterns

    @property
    def explodable_groups(self) -> list[list[str]]:
        """Get groups of columns that can be exploded together."""
        if self._explodable_groups is not None:
            return self._explodable_groups
        self._explodable_groups = self._compute_groups()
        return self._explodable_groups

    def _analyze_patterns(self) -> dict[str, tuple[int, ...]]:
        """Analyze length patterns of all list columns.

        Returns a value rather than setting self._length_patterns to shut up mypy."""
        return {
            name: tuple(self.df.select(pl.col(name).list.len()).to_series().to_list())
            for name in self.list_columns
        }

    def _compute_groups(self) -> list[list[str]]:
        """Compute explodable groups based on length patterns.

        Returns a value rather than setting self._explodable_groups to shut up mypy."""
        # Columns sharing the exact same tuple of lengths land in one group.
        columns_by_pattern = defaultdict(list)
        for name, lengths in self.length_patterns.items():
            columns_by_pattern[lengths].append(name)
        return list(columns_by_pattern.values())

    @property
    def summary(self) -> dict:
        """Get a summary of the explosion analysis."""
        groups = self.explodable_groups
        group_count = len(groups)
        return {
            'total_columns': len(self.df.columns),
            'list_columns': len(self.list_columns),
            'unique_patterns': len(set(self.length_patterns.values())),
            'explodable_groups': group_count,
            # One explode call is issued per group.
            'explosion_operations_needed': group_count,
            'groups': groups,
        }

    def __call__(self) -> pl.DataFrame:
        """Perform the minimal explosion."""
        # Nothing to explode: hand back the original frame untouched.
        if not self.list_columns:
            return self.df

        exploded = self.df
        for columns in self.explodable_groups:
            exploded = exploded.explode(*columns)
        return exploded