Source code for deepretro.utils.domain_features

"""Domain feature extraction utilities for reaction-step featurization."""

import numpy as np
from collections import Counter
from rdkit import Chem
from rdkit.Chem import Descriptors

# Length of the feature vector returned by extract_domain_features_single,
# and of the all-NaN vector it returns when featurization fails.
NUM_DOMAIN_FEATURES = 15


# Elements whose (reactant - product) atom-count deltas occupy feature slots 0-4.
_DELTA_ELEMENTS = ("C", "N", "O", "Cl", "Br")


def _summarize_mol(mol):
    """Return (element counts, bonds, rings, aromatic atoms, MolWt) for one molecule."""
    return (
        Counter(atom.GetSymbol() for atom in mol.GetAtoms()),
        mol.GetNumBonds(),
        # Ring count is taken from a fresh SSSR perception, as in the original
        # implementation, rather than from the molecule's cached ring info.
        len(Chem.GetSSSR(mol)),
        sum(1 for atom in mol.GetAtoms() if atom.GetIsAromatic()),
        Descriptors.MolWt(mol),
    )


def extract_domain_features_single(
    product_smiles: str, reactants_smiles: str
) -> np.ndarray:
    """
    Extract hand-crafted domain features for one product-reactant pair.

    Every delta is computed as ``reactants - product``, where the reactant
    side is the sum over all reactant molecules that parsed successfully.

    Feature layout (index: meaning):

    - 0-4: atom-count deltas for C, N, O, Cl, Br
    - 5: bond-count delta
    - 6: ring-count delta (SSSR)
    - 7: aromatic-atom-count delta
    - 8: molecular-weight delta
    - 9: number of reactant molecules that parsed successfully
    - 10: total atom-count delta (atoms present in the molecular graph)
    - 11: product molecular weight
    - 12: summed reactant molecular weight
    - 13: product ring count
    - 14: summed reactant ring count

    Parameters
    ----------
    product_smiles : str
        SMILES of the target product.
    reactants_smiles : str
        SMILES of the proposed reactants (dot-separated when multiple).
        Empty fragments (e.g. from a trailing or doubled ``.``) are ignored.

    Returns
    -------
    features : np.ndarray, shape (NUM_DOMAIN_FEATURES,)
        1-D float feature vector. An all-NaN vector is returned when the
        product cannot be parsed, when no reactant can be parsed, or when
        any exception is raised during featurization, so invalid rows are
        distinguishable from real data downstream. Individual reactant
        fragments that fail to parse are skipped; feature 9 records how
        many fragments were actually used.

    Examples
    --------
    >>> from deepretro.utils import extract_domain_features_single
    >>> feats = extract_domain_features_single("CCO", "CC.O")
    >>> feats.shape
    (15,)
    """
    # Deliberately broad: this is a best-effort featurizer, and any failure
    # (bad input type, parser error, descriptor error) maps to the NaN vector.
    try:
        product_mol = Chem.MolFromSmiles(product_smiles)
        if product_mol is None:
            return np.full(NUM_DOMAIN_FEATURES, np.nan)

        # Reactants may be several molecules joined by '.'. Blank fragments are
        # dropped before parsing so they can never be counted as a parsed
        # reactant (which would inflate feature 9).
        reactant_mols = [
            Chem.MolFromSmiles(fragment)
            for fragment in reactants_smiles.split(".")
            if fragment.strip()
        ]

        p_atoms, p_bonds, p_rings, p_arom, p_mw = _summarize_mol(product_mol)

        # Aggregate reactant-side descriptors over every fragment that parsed.
        r_atoms = Counter()
        r_bonds = r_rings = r_arom = n_valid = 0
        r_mw = 0.0
        for mol in reactant_mols:
            if mol is None:
                continue
            atoms, bonds, rings, arom, mw = _summarize_mol(mol)
            r_atoms.update(atoms)
            r_bonds += bonds
            r_rings += rings
            r_arom += arom
            r_mw += mw
            n_valid += 1

        # With no usable reactant the deltas would just be minus the product,
        # which looks like real data; report the row as a failure instead.
        if n_valid == 0:
            return np.full(NUM_DOMAIN_FEATURES, np.nan)

        element_deltas = [
            float(r_atoms[symbol] - p_atoms[symbol]) for symbol in _DELTA_ELEMENTS
        ]
        return np.array(
            element_deltas
            + [
                float(r_bonds - p_bonds),  # bond count delta
                float(r_rings - p_rings),  # ring count delta
                float(r_arom - p_arom),  # aromatic atom count delta
                float(r_mw - p_mw),  # molecular weight delta
                float(n_valid),  # number of parsed reactant molecules
                float(sum(r_atoms.values()) - sum(p_atoms.values())),  # total atom delta
                float(p_mw),  # absolute product molecular weight
                float(r_mw),  # absolute reactant molecular weight (sum)
                float(p_rings),  # absolute product ring count
                float(r_rings),  # absolute reactant ring count (sum)
            ]
        )
    except Exception:
        return np.full(NUM_DOMAIN_FEATURES, np.nan)