# Source code for deepretro.models.hallucination_classifier

"""XGBoost hallucination classifier built on DeepChem's GBDTModel.

Provides a single class that handles training, evaluation, threshold
optimisation, single-reaction prediction, and persistence using
DeepChem APIs end-to-end.  ``GBDTModel`` wraps an ``XGBClassifier``
and adds automatic early-stopping with an 80/20 internal split.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
from deepchem.data import Dataset, NumpyDataset
from deepchem.metrics import Metric, accuracy_score, f1_score, roc_auc_score
from deepchem.models import GBDTModel
from rdkit import Chem
from xgboost import XGBClassifier

from deepretro.featurizers import ReactionStepFeaturizer
from deepretro.utils.metrics import find_optimal_threshold


# Helper functions (kept outside the class per DeepChem convention)


def probability_scores(dataset: Dataset, model) -> dict[str, float]:
    """Compute probability-based scores for a fitted classifier.

    Parameters
    ----------
    dataset : Dataset
        Labelled dataset with ``y`` ground-truth.
    model : XGBClassifier
        Fitted sklearn-compatible model with ``predict_proba``.

    Returns
    -------
    dict
        Keys: ``roc_auc``, ``optimal_threshold``, ``optimal_f1``.
    """
    labels = dataset.y.flatten()
    # Positive-class (hallucination) probability for every sample.
    pos_probs = model.predict_proba(dataset.X)[:, 1]
    threshold, best_f1 = find_optimal_threshold(labels, pos_probs)
    return {
        "roc_auc": roc_auc_score(labels, pos_probs),
        "optimal_threshold": threshold,
        "optimal_f1": best_f1,
    }
def predict_single_reaction(
    clf: "HallucinationClassifier",
    product_smiles: str,
    reactants_smiles: str,
) -> dict[str, Any]:
    """Predict whether a single reaction step is hallucinated.

    This is a module-level helper so it can be used independently of
    the class method.

    Parameters
    ----------
    clf : HallucinationClassifier
        A fitted classifier instance.
    product_smiles : str
        SMILES of the target product.
    reactants_smiles : str
        SMILES of the proposed reactants (dot-separated).

    Returns
    -------
    result : dict
        Keys: ``is_hallucination`` (bool), ``probability`` (float).
        On invalid SMILES an ``error`` key is added instead.

    Examples
    --------
    >>> from deepretro.models import HallucinationClassifier
    >>> clf = HallucinationClassifier()
    >>> clf.load("saved_model/")  # doctest: +SKIP
    >>> predict_single_reaction(clf, "CCO", "CC.O")  # doctest: +SKIP
    {'is_hallucination': False, 'probability': 0.12}
    >>> predict_single_reaction(clf, "GARBAGE", "CC.O")  # doctest: +SKIP
    {'error': 'Invalid SMILES', 'is_hallucination': None, 'probability': None}
    """
    # Reject the request if either side fails RDKit parsing.
    for smiles in (product_smiles, reactants_smiles):
        if Chem.MolFromSmiles(smiles) is None:
            return {
                "error": "Invalid SMILES",
                "is_hallucination": None,
                "probability": None,
            }

    # Lazily create a featurizer if the classifier was loaded without one.
    if clf.featurizer is None:
        clf.featurizer = ReactionStepFeaturizer()

    features = clf.featurizer.featurize([(product_smiles, reactants_smiles)])
    prob = float(clf.predict_probability(NumpyDataset(X=features))[0])
    return {
        "is_hallucination": prob >= clf.threshold,
        "probability": prob,
    }
class HallucinationClassifier(GBDTModel):
    """
    Binary classifier for detecting hallucinated retrosynthesis reactions.

    Inherits from DeepChem's ``GBDTModel`` which wraps an ``XGBClassifier``
    and adds automatic early-stopping via an internal 80/20
    train/validation split.

    Training data
    -------------
    Prepare a CSV with at least these columns:

    * ``product`` — SMILES of the target product.
    * ``reactants`` — SMILES of proposed reactants (dot-separated for
      multiple reactants, e.g. ``"CC.O"``).
    * ``label`` — ``1`` if the reaction is hallucinated, ``0`` if real.

    Then load and train:

    .. code-block:: python

        from deepretro.data import ReactionDataLoader, stratified_split
        from deepretro.models import HallucinationClassifier

        loader = ReactionDataLoader()
        ds = loader.create_dataset("path/to/your_dataset.csv")
        train, valid, test = stratified_split(ds)

        clf = HallucinationClassifier(model_dir="my_models/")
        clf.fit(train)
        scores = clf.evaluate(test)
        print(scores)

    The trained model is persisted via DeepChem's standard joblib
    serialisation.  To reload later::

        clf = HallucinationClassifier(model_dir="my_models/")
        clf.load("my_models/")

    Parameters
    ----------
    model_dir : str, optional
        Directory for DeepChem model checkpoints. If ``None``, a
        temporary directory is used (see ``deepchem.models.Model``).
    early_stopping_rounds : int, optional
        Rounds for early stopping during ``fit()``. Default ``50``.
    **xgb_kwargs
        Forwarded to the underlying ``XGBClassifier`` via ``GBDTModel``.
        Defaults are tuned for the hallucination detection task.

    Examples
    --------
    >>> from deepretro.models import HallucinationClassifier
    >>> clf = HallucinationClassifier()
    >>> clf.threshold
    0.5
    """

    # Default XGBoost hyper-parameters; any of these can be overridden
    # through ``**xgb_kwargs`` in the constructor.
    _DEFAULT_XGB = {
        "max_depth": 6,
        "learning_rate": 0.05,
        "n_estimators": 300,
        "subsample": 0.8,
        "colsample_bytree": 0.8,
        "min_child_weight": 3,
        "gamma": 0.1,
        "random_state": 42,
    }

    def __init__(
        self,
        model_dir: str | None = None,
        early_stopping_rounds: int = 50,
        **xgb_kwargs: Any,
    ) -> None:
        # User-supplied kwargs take precedence over the tuned defaults.
        hyperparams = dict(self._DEFAULT_XGB)
        hyperparams.update(xgb_kwargs)
        super().__init__(
            model=XGBClassifier(**hyperparams),
            model_dir=model_dir,
            early_stopping_rounds=early_stopping_rounds,
            eval_metric="logloss",
        )
        # Decision threshold applied to positive-class probabilities;
        # updated by evaluate() to the F1-optimal value.
        self.threshold: float = 0.5
        # Created lazily on first single-reaction prediction or load().
        self.featurizer: ReactionStepFeaturizer | None = None

    # Training

    def fit(self, train_dataset: Dataset) -> None:
        """
        Train the model on a DeepChem ``Dataset``.

        ``GBDTModel`` automatically performs an internal 80/20
        train/validation split for early stopping. The model is
        auto-saved to ``model_dir`` after training.

        Parameters
        ----------
        train_dataset : Dataset
            Training data produced by ``deepretro.data.loader``.

        Examples
        --------
        >>> clf.fit(train_ds)  # doctest: +SKIP
        """
        super().fit(train_dataset)
        super().save()

    # Evaluation

    def evaluate(self, test_dataset: Dataset, metrics=None) -> dict[str, float]:
        """
        Evaluate using DeepChem ``Metric`` objects.

        Returns label-based metrics, plus probability-based ROC-AUC and
        the optimal threshold. Updates ``self.threshold`` to the optimal
        value and auto-saves the model state.

        Parameters
        ----------
        test_dataset : Dataset
            Held-out test data.
        metrics : list of dc.metrics.Metric, optional
            Label-based metrics to compute. If ``None``, defaults to:
            ``[Metric(accuracy_score, name="accuracy"),``
            ``Metric(f1_score, name="f1")]``.
            Any ``sklearn.metrics`` function that accepts
            ``(y_true, y_pred)`` can be wrapped with
            ``dc.metrics.Metric``, e.g.
            ``Metric(precision_score, name="precision")``.

        Returns
        -------
        scores : dict
            Contains each requested metric name (or ``accuracy``/``f1``
            when defaults are used), plus ``roc_auc``,
            ``optimal_threshold``, and ``optimal_f1``.

        Examples
        --------
        >>> scores = clf.evaluate(test_ds)  # doctest: +SKIP
        >>> scores["roc_auc"]  # doctest: +SKIP
        0.92
        """
        if metrics is None:
            metrics = [
                Metric(accuracy_score, name="accuracy"),
                Metric(f1_score, name="f1"),
            ]
        label_scores = super().evaluate(test_dataset, metrics)
        prob_scores = probability_scores(test_dataset, self.model)

        # Adopt the F1-optimal threshold, then persist it with the model.
        self.threshold = prob_scores["optimal_threshold"]
        self.save(self.model_dir)

        return {**label_scores, **prob_scores}

    # Prediction

    def predict_probability(self, dataset: Dataset) -> np.ndarray:
        """
        Return hallucination probabilities for each sample.

        Parameters
        ----------
        dataset : Dataset
            Data to score.

        Returns
        -------
        probabilities : np.ndarray, shape (n_samples,)
            Probability of the positive class (hallucination).
        """
        class_probs = self.model.predict_proba(dataset.X)
        return class_probs[:, 1]

    def predict_with_threshold(self, dataset: Dataset) -> tuple[np.ndarray, np.ndarray]:
        """
        Predict binary labels using the current threshold.

        Unlike the inherited ``predict()`` (which returns raw model
        output), this applies ``self.threshold`` to produce binary
        labels.

        Parameters
        ----------
        dataset : Dataset
            Data to classify.

        Returns
        -------
        labels : np.ndarray, shape (n_samples,)
            Binary predictions (0 or 1).
        probabilities : np.ndarray, shape (n_samples,)
            Hallucination probabilities.
        """
        probs = self.predict_probability(dataset)
        labels = (probs >= self.threshold).astype(int)
        return labels, probs

    def predict_single(
        self, product_smiles: str, reactants_smiles: str
    ) -> dict[str, Any]:
        """Thin wrapper around :func:`predict_single_reaction`."""
        return predict_single_reaction(self, product_smiles, reactants_smiles)

    # Persistence

    def save(self, save_dir: str) -> None:
        """
        Save model and threshold via DeepChem's joblib persistence.

        Parameters
        ----------
        save_dir : str
            Directory to write artifacts into.

        Examples
        --------
        >>> clf.save("saved_model/")  # doctest: +SKIP
        """
        target = Path(save_dir)
        target.mkdir(parents=True, exist_ok=True)
        # Stash the threshold on the sklearn estimator so joblib
        # serialisation carries it along with the model weights.
        self.model.optimal_threshold_ = self.threshold
        self.model_dir = str(target)
        super().save()

    def load(self, save_dir: str) -> None:
        """
        Reload a previously saved model.

        Parameters
        ----------
        save_dir : str
            Directory containing saved artifacts.

        Examples
        --------
        >>> clf = HallucinationClassifier()
        >>> clf.load("saved_model/")  # doctest: +SKIP
        """
        self.model_dir = str(Path(save_dir))
        self.reload()
        # Restore the persisted threshold; fall back to the 0.5 default
        # for models saved before evaluate() was run.
        self.threshold = getattr(self.model, "optimal_threshold_", 0.5)
        self.featurizer = ReactionStepFeaturizer()