Source code for intake_esm.cat

from __future__ import annotations

import builtins
import datetime
import enum
import functools
import json
import os
import typing

import fsspec
import packaging.version
import pandas as pd
import polars as pl
import pydantic
import tlz
from pydantic import ConfigDict
from typing_extensions import Self

from ._search import search, search_apply_require_all_on

# Dataframe libraries used by this module for reading catalogs (polars and pandas).
# NOTE(review): not referenced elsewhere in this view — confirm external users.
__framereaders__ = [pl, pd]
# Catalog-table file suffixes supported by ``CatalogFileDataReader``; this list is
# quoted in its "Unsupported file type" error message.
__filetypes__ = ['csv', 'csv.bz2', 'csv.gz', 'parquet']


def _allnan_or_nonan(df, column: str) -> bool:
    """Check if all values in a column are NaN or not NaN

    Returns
    -------
    bool
        Whether the dataframe column has all NaNs or no NaN valles

    Raises
    ------
    ValueError
        When the column has a mix of NaNs non NaN values
    """
    if df[column].isnull().all():
        return False
    if df[column].isnull().any():
        raise ValueError(
            f'The data in the {column} column should either be all NaN or there should be no NaNs'
        )
    return True


class AggregationType(str, enum.Enum):
    """Allowed values for ``Aggregation.type``.

    No ``model_config`` is defined here. This is a plain ``Enum``, not a
    pydantic model, so any such class attribute would become an extra enum
    member, and therefore an accepted aggregation type.
    """

    join_new = 'join_new'
    join_existing = 'join_existing'
    union = 'union'


class DataFormat(str, enum.Enum):
    """Allowed values for ``Assets.format``.

    No ``model_config`` is defined here. This is a plain ``Enum``, not a
    pydantic model, so any such class attribute would become an extra enum
    member, and therefore an accepted data format.
    """

    netcdf = 'netcdf'
    zarr = 'zarr'
    zarr2 = 'zarr2'
    zarr3 = 'zarr3'
    reference = 'reference'
    opendap = 'opendap'


class Attribute(pydantic.BaseModel):
    """A catalog attribute: the name of a catalog-table column and an optional
    vocabulary string (empty by default)."""

    model_config = ConfigDict(validate_assignment=True)

    column_name: pydantic.StrictStr
    vocabulary: pydantic.StrictStr = ''


class Assets(pydantic.BaseModel):
    """Describe the catalog's asset column (``column_name``) and its data format.

    Exactly one of ``format`` (a single ``DataFormat``) or ``format_column_name``
    (the name of a column, presumably holding per-row formats) must be given.
    """

    model_config = ConfigDict(validate_assignment=True)

    column_name: pydantic.StrictStr
    format: DataFormat | None = None
    format_column_name: pydantic.StrictStr | None = None

    @pydantic.model_validator(mode='after')
    def _validate_data_format(self) -> Self:
        """Enforce that exactly one of ``format``/``format_column_name`` is set."""
        has_format = self.format is not None
        has_format_column = self.format_column_name is not None
        if has_format and has_format_column:
            raise ValueError('Cannot set both format and format_column_name')
        if not (has_format or has_format_column):
            raise ValueError('Must set one of format or format_column_name')
        return self


class Aggregation(pydantic.BaseModel):
    """One aggregation entry: its ``type``, the attribute it applies to, and a
    free-form ``options`` dict (empty by default)."""

    model_config = ConfigDict(validate_assignment=True)

    type: AggregationType
    attribute_name: pydantic.StrictStr
    options: dict = {}


class AggregationControl(pydantic.BaseModel):
    """Aggregation settings for a catalog.

    Holds the variable column's name, the attributes used to group catalog
    rows (see ``ESMCatalogModel.grouped``), and the list of aggregations.
    """

    # validate_default: the default ``aggregations`` value is validated too.
    model_config = ConfigDict(validate_default=True, validate_assignment=True)

    variable_column_name: pydantic.StrictStr
    groupby_attrs: list[pydantic.StrictStr]
    aggregations: list[Aggregation] = []


class ESMCatalogModel(pydantic.BaseModel):
    """
    Pydantic model for the ESM data catalog defined in https://git.io/JBWoW

    Catalog metadata lives in the model fields. The catalog table is held in
    private state (``_frames``) and exposed through the ``df``, ``pl_df`` and
    ``lf`` properties.
    """

    esmcat_version: pydantic.StrictStr
    attributes: list[Attribute]
    assets: Assets
    aggregation_control: AggregationControl | None = None
    id: str = ''
    catalog_dict: list[dict] | None = None
    catalog_file: pydantic.StrictStr | None = None
    description: pydantic.StrictStr | None = None
    title: pydantic.StrictStr | None = None
    last_updated: datetime.datetime | datetime.date | None = None
    _df: pd.DataFrame | None = pydantic.PrivateAttr()
    _frames: FramesModel | None = pydantic.PrivateAttr()
    # Column name -> name of a builtin (e.g. 'tuple'). ``save`` applies
    # ``getattr(builtins, name)`` to every cell of that column before writing.
    _iterable_dtype_map: dict[str, str] = pydantic.PrivateAttr(default_factory=dict)

    model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True)

    @pydantic.model_validator(mode='after')
    def validate_catalog(self) -> Self:
        """Reject catalogs that set both ``catalog_dict`` and ``catalog_file``."""
        catalog_dict, catalog_file = self.catalog_dict, self.catalog_file
        if catalog_dict is not None and catalog_file is not None:
            raise ValueError('catalog_dict and catalog_file cannot be set at the same time')
        return self

    def __setattr__(self, name, value):
        """If we manually set _df, we need to propagate the change to _frames"""
        if name == '_df':
            self._frames = FramesModel(df=value)
        return super().__setattr__(name, value)

    @classmethod
    def from_dict(cls, data: dict) -> ESMCatalogModel:
        """Build a catalog from ``{'esmcat': <catalog metadata>, 'df': <pandas DataFrame>}``.

        ``last_updated`` defaults to ``None`` when absent. The caller's
        ``data['esmcat']`` mapping is not modified.
        """
        esmcat = {'last_updated': None, **data['esmcat']}
        cat = cls.model_validate(esmcat)
        # Assigning ``_df`` goes through ``__setattr__``, which also builds ``_frames``.
        cat._df = data['df']
        return cat

    def save(
        self,
        name: str,
        *,
        directory: str | None = None,
        catalog_type: str = 'dict',
        to_csv_kwargs: dict | None = None,
        json_dump_kwargs: dict | None = None,
        storage_options: dict[str, typing.Any] | None = None,
    ) -> None:
        """
        Save the catalog to a file.

        Parameters
        -----------
        name: str
            The name of the file to save the catalog to.
        directory: str
            The directory or cloud storage bucket to save the catalog to.
            If None, use the current directory.
        catalog_type: str
            The type of catalog to save. Whether to save the catalog table as a dictionary
            in the JSON file or as a separate CSV file. Valid options are 'dict' and 'file'.
        to_csv_kwargs : dict, optional
            Additional keyword arguments passed through to the :py:meth:`~pandas.DataFrame.to_csv` method.
            ``compression`` may be a string or a dict with a ``'method'`` key; known
            methods (gzip, bz2, zip, xz) append the matching extension to the CSV name.
        json_dump_kwargs : dict, optional
            Additional keyword arguments passed through to the :py:func:`~json.dump` function.
        storage_options: dict
            fsspec parameters passed to the backend file-system such as Google Cloud Storage,
            Amazon Web Service S3.

        Raises
        ------
        ValueError
            If ``catalog_type`` is neither ``'dict'`` nor ``'file'``.

        Notes
        -----
        Large catalogs can result in large JSON files. To keep the JSON file size manageable, call with
        `catalog_type='file'` to save catalog as a separate CSV file.
        The ``last_updated`` field is set to the current UTC time.
        """
        if catalog_type not in {'file', 'dict'}:
            raise ValueError(
                f'catalog_type must be either "dict" or "file". Received catalog_type={catalog_type}'
            )

        # Check if the directory is None, and if it is, set it to the current directory
        if directory is None:
            directory = os.getcwd()

        # Configure the fsspec mapper and associated filenames
        storage_options = storage_options if storage_options is not None else {}
        mapper = fsspec.get_mapper(f'{directory}', **storage_options)
        fs = mapper.fs
        csv_file_name = fs.unstrip_protocol(f'{mapper.root}/{name}.csv')
        json_file_name = fs.unstrip_protocol(f'{mapper.root}/{name}.json')

        # model_dump already returns a fresh dict.
        data = self.model_dump()
        for key in {'catalog_dict', 'catalog_file'}:
            data.pop(key, None)
        data['id'] = name
        # ``datetime.utcnow()`` is deprecated since Python 3.12. An aware UTC
        # timestamp renders identically with this format string.
        data['last_updated'] = datetime.datetime.now(datetime.timezone.utc).strftime(
            '%Y-%m-%dT%H:%M:%SZ'
        )

        # Restore the builtin iterable type recorded for each such column.
        _tmp_df = self.df.copy(deep=True)
        for colname, dtype in self._iterable_dtype_map.items():
            _tmp_df[colname] = _tmp_df[colname].apply(getattr(builtins, dtype))

        if catalog_type == 'file':
            csv_kwargs: dict[str, typing.Any] = {'index': False}
            csv_kwargs |= to_csv_kwargs or {}
            compression = csv_kwargs.get('compression', '')
            # pandas also accepts a dict such as {'method': 'gzip', 'compresslevel': 1};
            # a dict is unhashable and cannot be used as a lookup key below.
            if isinstance(compression, dict):
                compression = compression.get('method', '')
            extensions = {'gzip': '.gz', 'bz2': '.bz2', 'zip': '.zip', 'xz': '.xz'}
            csv_file_name = f'{csv_file_name}{extensions.get(compression, "")}'

            data['catalog_file'] = str(csv_file_name)

            with fs.open(csv_file_name, 'wb') as csv_outfile:
                _tmp_df.to_csv(csv_outfile, **csv_kwargs)
        else:
            data['catalog_dict'] = _tmp_df.to_dict(orient='records')

        with fs.open(json_file_name, 'w') as outfile:
            json_kwargs = {'indent': 2}
            json_kwargs |= json_dump_kwargs or {}
            json.dump(data, outfile, **json_kwargs)  # type: ignore[arg-type]

        print(f'Successfully wrote ESM catalog json file to: {json_file_name}')

    @classmethod
    def load(
        cls,
        json_file: str | pydantic.FilePath | pydantic.AnyUrl,
        storage_options: dict[str, typing.Any] | None = None,
        read_kwargs: dict[str, typing.Any] | None = None,
    ) -> ESMCatalogModel:
        """
        Loads the catalog from a file

        Parameters
        -----------
        json_file: str or pathlib.Path
            The path to the json file containing the catalog
        storage_options: dict
            fsspec parameters passed to the backend file-system such as Google Cloud Storage,
            Amazon Web Service S3.
        read_kwargs : dict, optional
            Additional keyword arguments passed through to the :py:func:`~pandas.read_csv` function, if the
            datastore is saved in csv format, or :py:func:`~pandas.read_parquet` if the datastore is saved in
            parquet format.

        Returns
        -------
        ESMCatalogModel
            The catalog. Its table is read from ``catalog_file`` when set,
            otherwise it is built from the inline ``catalog_dict`` records.
        """
        storage_options = storage_options if storage_options is not None else {}
        read_kwargs = read_kwargs or {}
        json_file = str(json_file)  # We accept Path, but fsspec doesn't.
        _mapper = fsspec.get_mapper(json_file, **storage_options)

        with fsspec.open(json_file, **storage_options) as fobj:
            data = json.loads(fobj.read())
        if 'last_updated' not in data:
            data['last_updated'] = None
        cat = cls.model_validate(data)
        if cat.catalog_file:
            cat._frames = cat._df_from_file(cat, _mapper, storage_options, read_kwargs)
        else:
            # Build the polars frame once and derive the lazy and pandas views from it.
            pl_df = pl.DataFrame(cat.catalog_dict)
            cat._frames = FramesModel(lf=pl_df.lazy(), pl_df=pl_df, df=pl_df.to_pandas())

        return cat

    def _df_from_file(
        self,
        cat: ESMCatalogModel,
        _mapper: fsspec.FSMap,
        storage_options: dict[str, typing.Any],
        read_kwargs: dict[str, typing.Any],
    ) -> FramesModel:
        """
        Read the catalog file from disk, falling back to pandas for bz2 files which
        polars can't read.

        If ``cat.catalog_file`` does not exist as given, it is treated as relative
        to the JSON file's directory, and ``cat.catalog_file`` is updated to the
        resolved path. Also records the reader's iterable dtype map on ``self``.

        Parameters
        ----------
        cat: ESMCatalogModel
            The catalog model
        _mapper: fsspec mapper
            A fsspec mapper object (rooted at the catalog JSON file)
        storage_options: dict
            fsspec parameters passed to the backend file-system such as Google Cloud Storage,
            Amazon Web Service S3.
        read_kwargs: dict
            Additional keyword arguments passed through to the :py:func:`~pandas.read_csv` function.

        Returns
        -------
        FramesModel:
            A pydantic model containing at least one of a pandas/polars dataframe and a polars lazyframe
        """
        if _mapper.fs.exists(cat.catalog_file):
            csv_path = cat.catalog_file
        else:
            csv_path = f'{os.path.dirname(_mapper.root)}/{cat.catalog_file}'
        cat.catalog_file = csv_path
        reader = CatalogFileDataReader(cat.catalog_file, storage_options, **read_kwargs)
        self._iterable_dtype_map = reader.dtype_map
        return reader.frames

    @property
    def lf(self) -> pl.LazyFrame:
        """Return a `pl.LazyFrame` containing the catalog, creating it if necessary"""
        return self._frames.lazy  # type: ignore[union-attr]

    @property
    def pl_df(self) -> pl.DataFrame:
        """Return a `pl.DataFrame` containing the catalog, creating it if necessary"""
        return self._frames.polars  # type: ignore[union-attr]

    @property
    def df(self) -> pd.DataFrame:
        """Return the `pd.DataFrame` containing the catalog, creating it if necessary"""
        return self._frames.pandas  # type: ignore[union-attr]

    @property
    def columns_with_iterables(self) -> set[str]:
        """Return a set of columns that have iterables."""
        return self._frames.columns_with_iterables  # type: ignore[union-attr]

    @property
    def has_multiple_variable_assets(self) -> bool:
        """Return True if the catalog has multiple variable assets."""
        if self.aggregation_control:
            return self.aggregation_control.variable_column_name in self.columns_with_iterables
        return False

    @property
    def grouped(self) -> pd.core.groupby.DataFrameGroupBy | pd.DataFrame:
        """Group the catalog dataframe.

        With an aggregation control, all-NaN columns are first dropped from
        ``aggregation_control.groupby_attrs``; this updates the control in place.
        The remaining attributes are used for grouping unless they cover every
        column. Otherwise the frame is grouped by every column that has no NaNs.
        A column mixing NaN and non-NaN values raises ``ValueError``.
        """
        if self.aggregation_control:
            if self.aggregation_control.groupby_attrs:
                self.aggregation_control.groupby_attrs = list(
                    filter(
                        functools.partial(_allnan_or_nonan, self.df),
                        self.aggregation_control.groupby_attrs,
                    )
                )

            if self.aggregation_control.groupby_attrs and set(
                self.aggregation_control.groupby_attrs
            ) != set(self.df.columns):
                return self.df.groupby(self.aggregation_control.groupby_attrs)

        cols = list(
            filter(
                functools.partial(_allnan_or_nonan, self.df),
                self.df.columns,
            )
        )
        return self.df.groupby(cols)

    def _construct_group_keys(self, sep: str = '.') -> dict[str, str | tuple[str]]:
        """Map public group keys (tuple keys joined with ``sep``) to the internal groupby keys."""
        internal_keys = self.grouped.groups.keys()
        public_keys = map(
            lambda key: key if isinstance(key, str) else sep.join(str(value) for value in key),
            internal_keys,
        )

        return dict(zip(public_keys, internal_keys))

    def _unique(self) -> dict:
        """Return ``{column: [unique non-null values]}``; iterable columns are flattened first."""

        def _find_unique(series):
            values = series.dropna()
            if series.name in self.columns_with_iterables:
                values = tlz.concat(values)
            return list(tlz.unique(values))

        data = self.df[self.df.columns]
        if data.empty:
            return {col: [] for col in self.df.columns}
        else:
            return data.apply(_find_unique, result_type='reduce').to_dict()

    def unique(self) -> pd.Series:
        """Return a series of unique values for each column in the catalog."""
        return pd.Series(self._unique())

    def nunique(self) -> pd.Series:
        """Return a series of the number of unique values for each column in the catalog."""
        return self._frames.nunique()  # type: ignore[union-attr]

    def search(
        self,
        *,
        query: QueryModel | dict[str, typing.Any],
        require_all_on: str | list[str] | None = None,
    ) -> pd.DataFrame:
        """
        Search for entries in the catalog.

        Parameters
        ----------
        query: dict or QueryModel
            A dictionary of query parameters to execute against the dataframe,
            or an already-built ``QueryModel``.
        require_all_on : list, str, optional
            A dataframe column or a list of dataframe columns across
            which all entries must satisfy the query criteria.
            If None, return entries that fulfill any of the criteria specified
            in the query, by default None. Ignored when ``query`` is a
            ``QueryModel``, whose own ``require_all_on`` is used instead.

        Returns
        -------
        pd.DataFrame
            The rows of the catalog dataframe satisfying the query criteria.
        """
        _query = (
            query
            if isinstance(query, QueryModel)
            else QueryModel(
                query=query, require_all_on=require_all_on, columns=self.df.columns.tolist()
            )
        )

        results = search(
            df=self.df, query=_query.query, columns_with_iterables=self.columns_with_iterables
        )
        if _query.require_all_on is not None and not results.empty:
            results = search_apply_require_all_on(
                df=results,
                query=_query.query,
                require_all_on=_query.require_all_on,
                columns_with_iterables=self.columns_with_iterables,
            )
        return results
class QueryModel(pydantic.BaseModel):
    """A Pydantic model to represent a query to be executed against a catalog.

    After validation:
    - every query key and every ``require_all_on`` column is known to exist in
      ``columns``;
    - a string ``require_all_on`` is wrapped in a list;
    - scalar query values (str/int/float/bool, ``None`` or ``pd.NA``) are
      wrapped in single-element lists.
    """

    query: dict[pydantic.StrictStr, typing.Any | list[typing.Any]]
    columns: list[str]
    require_all_on: str | list[typing.Any] | None = None

    # TODO: Seem to be unable to modify fields in model_validator with
    # validate_assignment=True since it leads to recursion
    model_config = ConfigDict(validate_assignment=False)

    @pydantic.model_validator(mode='after')
    def validate_query(self) -> Self:
        """Check column names and normalise ``require_all_on`` and query values."""
        known_columns = self.columns

        for key in self.query or ():
            if key not in known_columns:
                raise ValueError(f'Column {key} not in columns {known_columns}')

        if isinstance(self.require_all_on, str):
            self.require_all_on = [self.require_all_on]
        if self.require_all_on is not None:
            for key in self.require_all_on:
                if key not in known_columns:
                    raise ValueError(f'Column {key} not in columns {known_columns}')

        normalised = {}
        for key, value in self.query.items():
            is_scalar = (
                isinstance(value, (str, int, float, bool)) or value is None or value is pd.NA
            )
            normalised[key] = [value] if is_scalar else value
        self.query = normalised
        return self
class FramesModel(pydantic.BaseModel):
    """A Pydantic model to represent our collection of dataframes - pandas, polars, and lazyframe.

    Any subset of the three representations may be supplied, but at least one is
    required. The missing ones are derived on first access through the
    ``pandas``, ``polars`` and ``lazy`` properties and stored on the model.
    """

    df: pd.DataFrame | None = None
    pl_df: pl.DataFrame | None = None
    lf: pl.LazyFrame | None = None

    model_config = ConfigDict(validate_assignment=True, arbitrary_types_allowed=True)

    @pydantic.model_validator(mode='after')
    def ensure_some(self) -> Self:
        """
        Make sure that at least one of the dataframes is not `None` when the model is instantiated.
        """
        if self.df is None and self.pl_df is None and self.lf is None:
            raise AssertionError('At least one of df, pl_df, or lf must be set')
        return self

    @property
    def pandas(self) -> pd.DataFrame:
        """Return the pandas DataFrame, instantiating it if necessary.

        The frame is converted from the polars DataFrame, collecting the
        LazyFrame first when only that exists. Conversion uses pyarrow-backed
        columns, and the cells of list-typed columns are then turned into tuples.
        """
        if self.df is not None:
            return self.df

        if self.pl_df is None:
            # Only the LazyFrame is available: materialise it first.
            self.pl_df = self.lf.collect()  # type: ignore[union-attr]
        self.df = self.pl_df.to_pandas(use_pyarrow_extension_array=True)
        # Use one conversion path for both sources. `columns_with_iterables`
        # (which collects the LazyFrame's first row) is evaluated only once.
        for colname in self.columns_with_iterables:
            self.df[colname] = self.df[colname].apply(tuple)
        return self.df

    @property
    def polars(self) -> pl.DataFrame:
        """Return the polars DataFrame, instantiating it if necessary."""
        if self.pl_df is not None:
            return self.pl_df

        if self.lf is not None:
            self.pl_df = self.lf.collect()
            return self.pl_df

        self.pl_df = pl.from_pandas(self.df)
        self.lf = self.pl_df.lazy()
        return self.pl_df

    @property
    def lazy(self) -> pl.LazyFrame:
        """Return the polars LazyFrame, instantiating it if necessary."""
        if self.lf is not None:
            return self.lf

        # Otherwise, it must be none - so lets create the lazyframe now. We use the
        # self.polars property, so we can cascade to creating it from the pandas dataframe
        # if necessary.
        self.lf = self.polars.lazy()
        return self.lf

    @property
    def columns_with_iterables(self) -> set[str]:
        """Return a set of columns that have iterables, preferentially using `self.lazy` >
        `self.polars` > `self.pandas` to minimise overhead.

        A column counts as iterable when its polars dtype is ``pl.List``. Empty
        frames report no iterable columns.
        """
        if (trunc_df := self.lazy.head(1).collect()).is_empty():
            return set()

        if self.df is not None and self.df.empty:
            return set()

        colnames, dtypes = trunc_df.columns, trunc_df.dtypes
        return {colname for colname, dtype in zip(colnames, dtypes) if dtype == pl.List}

    def nunique(self) -> pd.Series:
        """Return a series of the number of unique values for each column in the catalog.

        List-typed columns are exploded first, so their individual elements are counted.
        """
        frame = self.polars
        counts = {}
        for colname, dtype in frame.schema.items():
            column = frame.get_column(colname)
            if dtype == pl.List:
                column = column.explode()
            counts[colname] = column.n_unique()
        return pd.Series(counts)
class CatalogFileDataReader:
    """Abstracts away some of the complexity related to reading dataframes

    The reader is picked from the file suffix: polars for ``.csv``/``.csv.gz``
    and ``.parquet``, pandas for ``.csv.bz2``. The file is read on construction
    and the result is available as ``self.frames``.
    """

    def __init__(
        self,
        catalog_file: pydantic.StrictStr | None,
        storage_options: dict[str, typing.Any],
        **read_kwargs,
    ):
        self.catalog_file = catalog_file
        self.storage_options = storage_options
        self.read_kwargs = read_kwargs

        if self.catalog_file is None:
            raise AssertionError('catalog_file must be set to a valid file path or URL')

        # I think we want to replace this with a dict lookup.
        if self.catalog_file.endswith(('.csv.gz', '.csv')):
            self.driver = 'polars'
            self.filetype = 'csv'
        elif self.catalog_file.endswith('.parquet'):
            self.driver = 'polars'
            self.filetype = 'parquet'
        elif self.catalog_file.endswith('.csv.bz2'):
            self.driver = 'pandas'
            self.filetype = 'csv'
        else:
            raise ValueError(
                f'Unsupported file type for catalog_file {self.catalog_file}. '
                f'Expected one of {__filetypes__}'
            )

        # Column name -> builtin type name (e.g. 'list', 'tuple', 'set') for the
        # columns listed in ``converters``; filled in by the csv readers.
        self._dtype_map: dict[str, str] = {}
        self.frames = self._read()

    def _read_csv_pd(self) -> FramesModel:
        """Read a catalog file stored as a csv using pandas"""
        df = pd.read_csv(
            self.catalog_file,
            storage_options=self.storage_options,
            **self.read_kwargs,
        )
        # Record the builtin type name of each converter column's first value,
        # matching the ``dict[str, str]`` contract of ``dtype_map`` and the polars
        # reader, which also samples the first row. With no rows there is nothing
        # to sample, as in the polars reader.
        if not df.empty:
            self._dtype_map = {
                colname: type(df[colname].iloc[0]).__name__
                for colname in self.read_kwargs.get('converters', {})
            }
        return FramesModel(df=df)

    def _read_csv_pl(self) -> FramesModel:
        """Read a catalog file stored as a csv using polars"""
        # Hack: ``converters`` is a pandas option that polars does not accept. Only
        # its column names are used here; those columns are JSON-decoded below.
        converters = self.read_kwargs.pop('converters', {})
        iterable_cols = list(converters)
        # For polars < 1.34, we need to use fsspec here. For >= 1.34, we can pass the raw
        # url. See https://github.com/pola-rs/polars/pull/24450 & https://github.com/intake/intake-esm/issues/744
        if packaging.version.Version(pl.__version__) < packaging.version.Version('1.34'):
            with fsspec.open(self.catalog_file, **self.storage_options) as fobj:
                lf = pl.scan_csv(
                    fobj,  # type: ignore[arg-type]
                    storage_options=self.storage_options,
                    infer_schema=False,
                    **self.read_kwargs,
                )
        else:
            lf = pl.scan_csv(
                self.catalog_file,
                storage_options=self.storage_options,
                infer_schema=False,
                **self.read_kwargs,
            )

        # Map the first character of each converter column's first value to a
        # builtin name: '[' -> 'list', '(' -> 'tuple', '{' -> 'set'.
        if dtype_map := (
            lf.head(1)
            .select(iterable_cols)
            .with_columns(
                [
                    pl.col(colname)
                    .str.head(1)
                    .str.replace_many(
                        ['[', '(', '{'],
                        ['list', 'tuple', 'set'],
                    )
                    for colname in iterable_cols
                ]
            )
            .collect()
            .to_dicts()
        ):  # Returns an empty list if no rows - hence walrus
            self._dtype_map = dtype_map[0]
            lf = lf.with_columns(
                [
                    pl.col(colname)
                    # Rewrite set/tuple literals as list literals by replacing the
                    # first char with '[' and the last char with ']'.
                    .str.replace('^.', '[')
                    .str.replace('.$', ']')
                    # Remove trailing commas, e.g. from one-element tuples "('a',)".
                    .str.replace(',]$', ']')
                    # JSON strings must use double quotes, not single quotes.
                    .str.replace_all("'", '"')
                    .str.json_decode()
                    for colname in iterable_cols
                ]
            )

        return FramesModel(lf=lf)

    def _read_parquet_pl(self) -> FramesModel:
        """Read a catalog file stored as a parquet using polars"""
        lf = pl.scan_parquet(
            self.catalog_file,  # type: ignore[arg-type]
            storage_options=self.storage_options,
            **self.read_kwargs,
        )
        return FramesModel(lf=lf)

    def _read(self) -> FramesModel:
        """Dispatch to the reader matching ``self.driver`` and ``self.filetype``."""
        if self.driver == 'polars':
            if self.filetype == 'csv':
                return self._read_csv_pl()
            elif self.filetype == 'parquet':
                return self._read_parquet_pl()
            else:
                raise ValueError(f'Unsupported file type {self.filetype} for polars reader')
        if self.driver == 'pandas':
            if self.filetype == 'csv':
                return self._read_csv_pd()
            else:
                raise ValueError(f'Unsupported file type {self.filetype} for pandas reader')
        # Defensive: __init__ only ever sets 'polars' or 'pandas'.
        raise ValueError(f'Unsupported driver {self.driver}')

    @property
    def dtype_map(self) -> dict[str, str]:
        """Return a map of column names to their dtypes for columns with iterables."""
        return self._dtype_map