# Source code for deepretro.data.loader

"""Dataset loading pipeline using DeepChem data structures.

Provides :class:`ReactionDataLoader`, a subclass of DeepChem's
:class:`~deepchem.data.DataLoader` that converts a reaction CSV
(product, reactants, label) into a ``DiskDataset`` with automatic
sharding, plus a convenience ``stratified_split`` function for
stratified train/valid/test splitting.
"""

from __future__ import annotations

import warnings
from typing import Any, Iterator, List, Optional

import numpy as np
import pandas as pd
from deepchem.data import Dataset, DiskDataset
from deepchem.data.data_loader import DataLoader
from deepchem.splits import SingletaskStratifiedSplitter

from deepretro.featurizers import ReactionStepFeaturizer


class ReactionDataLoader(DataLoader):
    """Load a two-column reaction CSV into a DeepChem ``DiskDataset``.

    Inherits from :class:`~deepchem.data.DataLoader` and overrides
    :meth:`_get_shards` and :meth:`_featurize_shard` to handle paired
    SMILES columns (product + reactants). The parent's ``create_dataset``
    method is overridden to add automatic NaN-row dropping (invalid
    SMILES) with a summary warning.

    Parameters
    ----------
    featurizer : ReactionStepFeaturizer or None, optional
        Pre-configured featurizer. A default one (radius=2, size=2048,
        domain features on) is created when ``None``.
    product_col : str, optional
        Column name for product SMILES. Default ``"product"``.
    reactants_col : str, optional
        Column name for reactant SMILES. Default ``"reactants"``.
    label_col : str, optional
        Column name for binary labels. Default ``"label"``.
    id_field : str or None, optional
        Column name to use as sample identifiers. When ``None``
        (default), globally unique sequential integer IDs are generated.
    log_every_n : int, optional
        Log a progress message every *n* shards. Default ``1000``.

    Examples
    --------
    >>> from deepretro.data import ReactionDataLoader
    >>> loader = ReactionDataLoader()
    >>> ds = loader.create_dataset("data/dataset.csv")  # doctest: +SKIP
    >>> len(ds)  # doctest: +SKIP
    808
    """

    def __init__(
        self,
        featurizer: ReactionStepFeaturizer | None = None,
        product_col: str = "product",
        reactants_col: str = "reactants",
        label_col: str = "label",
        id_field: str | None = None,
        log_every_n: int = 1000,
    ) -> None:
        # Fall back to the default featurizer configuration when none given.
        feat = featurizer or ReactionStepFeaturizer()
        super().__init__(
            tasks=[label_col],
            featurizer=feat,
            id_field=id_field,
            log_every_n=log_every_n,
        )
        self.product_col = product_col
        self.reactants_col = reactants_col
        self.label_col = label_col

    # ------------------------------------------------------------------
    # DataLoader interface overrides
    # ------------------------------------------------------------------
    def _get_shards(
        self,
        inputs: List[str],
        shard_size: Optional[int],
    ) -> Iterator[pd.DataFrame]:
        """Yield DataFrame chunks from one or more CSV files.

        Parameters
        ----------
        inputs : list of str
            Paths to CSV files.
        shard_size : int or None
            Rows per shard. ``None`` reads the whole file at once.
        """
        for input_file in inputs:
            if shard_size is not None:
                yield from pd.read_csv(input_file, chunksize=shard_size)
            else:
                yield pd.read_csv(input_file)

    def _featurize_shard(
        self,
        shard: pd.DataFrame,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Featurize a shard of reaction pairs.

        Parameters
        ----------
        shard : pd.DataFrame
            A chunk of the input CSV.

        Returns
        -------
        X : np.ndarray
            Feature matrix (only valid rows).
        valid_inds : np.ndarray
            Integer indices of rows that were successfully featurized
            (all-NaN rows are excluded).
        """
        products = shard[self.product_col].tolist()
        reactants = shard[self.reactants_col].tolist()
        pairs = list(zip(products, reactants))
        X = self.featurizer.featurize(pairs)
        # The featurizer signals failure (e.g. unparsable SMILES) by
        # returning an all-NaN row; keep only the rows that featurized.
        nan_mask = np.isnan(X).all(axis=1)
        valid_inds = np.where(~nan_mask)[0]
        X = X[valid_inds]
        return X, valid_inds

    # ------------------------------------------------------------------
    # create_dataset : customised to handle missing id_field & NaN warn
    # ------------------------------------------------------------------
    def create_dataset(
        self,
        inputs: str | List[str],
        data_dir: Optional[str] = None,
        shard_size: Optional[int] = 1000,
    ) -> DiskDataset:
        """Read, featurize, and write a reaction CSV to a ``DiskDataset``.

        Follows the same contract as
        :meth:`~deepchem.data.DataLoader.create_dataset` but adds
        automatic dropping of rows where featurization fails (invalid
        SMILES), with a summary warning at the end.

        Parameters
        ----------
        inputs : str or list of str
            Path(s) to CSV file(s).
        data_dir : str or None, optional
            Directory to write shards into. When ``None`` a temporary
            directory is used (cleaned up when the ``DiskDataset``
            object is garbage-collected).
        shard_size : int or None, optional
            Number of rows per shard. Default ``1000``.

        Returns
        -------
        dataset : DiskDataset
            Disk-backed DeepChem dataset ready for splitting / training.

        Examples
        --------
        >>> loader = ReactionDataLoader()
        >>> ds = loader.create_dataset("data/dataset.csv")  # doctest: +SKIP
        """
        if not isinstance(inputs, list):
            inputs = [inputs]

        total_dropped = 0

        # Generator that streams featurized shards (X, y, w, ids) into
        # ``DiskDataset.create_dataset`` while tracking how many rows are
        # dropped due to all-NaN features from failed featurization.
        def shard_generator():
            nonlocal total_dropped
            # Running count of rows read so far, so fallback IDs are
            # globally unique across shards (previously np.arange
            # restarted at 0 per shard, yielding duplicate IDs).
            row_offset = 0
            for shard in self._get_shards(inputs, shard_size):
                X, valid_inds = self._featurize_shard(shard)
                dropped = len(shard) - len(valid_inds)
                if dropped > 0:
                    total_dropped += dropped
                # Extract labels from task columns
                y = shard[self.label_col].values.reshape(-1, 1)
                y = y[valid_inds]
                w = np.ones_like(y, dtype=np.float32)
                # Use id_field if provided, else global row-position IDs
                if self.id_field is not None and self.id_field in shard.columns:
                    ids = shard[self.id_field].values[valid_inds]
                else:
                    ids = valid_inds + row_offset
                row_offset += len(shard)
                yield X, y, w, ids

        dataset = DiskDataset.create_dataset(shard_generator(), data_dir, self.tasks)
        if total_dropped > 0:
            warnings.warn(
                f"Dropped {total_dropped} rows with NaN features (invalid SMILES).",
                stacklevel=2,
            )
        return dataset
def stratified_split(
    dataset: Dataset,
    frac_train: float = 0.7,
    frac_valid: float = 0.15,
    frac_test: float = 0.15,
    seed: int = 42,
) -> tuple[Dataset, Dataset, Dataset]:
    """Stratified split into train / valid / test sets.

    Uses DeepChem's ``SingletaskStratifiedSplitter`` to preserve class
    balance across all three splits. Works with any DeepChem
    ``Dataset`` subclass (``DiskDataset``, ``NumpyDataset``, etc.).

    Parameters
    ----------
    dataset : Dataset
        Full dataset to split (``DiskDataset`` or ``NumpyDataset``).
    frac_train : float, optional
        Training fraction. Default 0.7.
    frac_valid : float, optional
        Validation fraction. Default 0.15.
    frac_test : float, optional
        Test fraction. Default 0.15.
    seed : int, optional
        Random seed. Default 42.

    Returns
    -------
    train_ds : Dataset
    valid_ds : Dataset
    test_ds : Dataset

    Examples
    --------
    >>> import numpy as np
    >>> from deepchem.data import NumpyDataset
    >>> from deepretro.data import stratified_split
    >>> ds = NumpyDataset(X=np.random.rand(100, 10), y=np.array([0]*50 + [1]*50).reshape(-1,1))
    >>> train, valid, test = stratified_split(ds)
    >>> len(train) + len(valid) + len(test) == 100
    True
    """
    splitter = SingletaskStratifiedSplitter()
    index_groups = splitter.split(
        dataset,
        frac_train=frac_train,
        frac_valid=frac_valid,
        frac_test=frac_test,
        seed=seed,
    )
    # Materialize each index group into its own dataset view; cast to
    # int because ``Dataset.select`` requires integer indices.
    train_ds, valid_ds, test_ds = (
        dataset.select(inds.astype(int)) for inds in index_groups
    )
    return train_ds, valid_ds, test_ds