Source code for deepretro.algorithms.hallucination_checker

"""Heuristic hallucination checker for retrosynthetic reaction steps.

When a retrosynthetic step (product → reactant) is proposed, the
predicted reactant may contain structural mistakes: atoms appearing or vanishing, rings changing size,
substituents jumping to a different position on an aromatic ring, and so on.

This module catches those mistakes automatically by comparing the
reactant and product. Two main entry points are provided:

* `hallucination_compare_molecules` — runs every check and returns a
  detailed breakdown of what (if anything) looks wrong.
* `calculate_hallucination_score` — distils the breakdown into a single
  0–100 score (100 = looks fine, 0 = almost certainly hallucinated) with
  a severity label (low / medium / high / critical).
"""

from collections import Counter
from typing import Any

from rdkit import Chem
from rdkit.Chem import rdmolops


# Position mapping for consistent naming
pos_map: dict[str, str] = {
    "1": "position 1",
    "2": "position 2",
    "3": "position 3",
    "4": "position 4",
    "5": "position 5",
    "6": "position 6",
    "ortho": "ortho",
    "meta": "meta",
    "para": "para",
}


[docs] def hallucination_compare_molecules( reactant_smiles: str, product_smiles: str, ) -> dict[str, Any]: """Compare a reactant and product molecule to detect potential hallucinations. Given two SMILES strings, this function parses both molecules and checks atom-count consistency, ring-size changes, substituent position swaps, aromaticity shifts, and unnecessary bond formations. Parameters ---------- reactant_smiles : str SMILES string of the reactant molecule. product_smiles : str SMILES string of the product molecule. Returns ------- results : dict[str, Any] * ``valid_reactant`` (bool) — reactant SMILES parsed OK. * ``valid_product`` (bool) — product SMILES parsed OK. * ``atom_count_consistent`` (bool) — all elements match. * ``ring_size_changes`` (list[str]) — rings added/removed. * ``substituent_position_changes`` (list[dict]) — position swaps. * ``detected_issues`` (list[str]) — all issues found (empty if clean). Examples -------- >>> from deepretro.algorithms import hallucination_compare_molecules >>> res = hallucination_compare_molecules("c1ccccc1", "c1ccccc1OC") >>> res["valid_reactant"] and res["valid_product"] True """ results = { "valid_reactant": False, "valid_product": False, "atom_count_consistent": False, "ring_size_changes": [], "substituent_position_changes": [], "detected_issues": [], } # Check if SMILES strings are valid reactant_mol = Chem.MolFromSmiles(reactant_smiles) product_mol = Chem.MolFromSmiles(product_smiles) if reactant_mol is None: results["detected_issues"].append("Invalid reactant SMILES string") return results else: results["valid_reactant"] = True if product_mol is None: results["detected_issues"].append("Invalid product SMILES string") return results else: results["valid_product"] = True # Get basic molecule properties reactant_atoms = Counter([atom.GetSymbol() for atom in reactant_mol.GetAtoms()]) product_atoms = Counter([atom.GetSymbol() for atom in product_mol.GetAtoms()]) # Check atom count consistency for atom_symbol in set(list(reactant_atoms.keys()) + list(product_atoms.keys())): if reactant_atoms.get(atom_symbol, 0) != product_atoms.get(atom_symbol, 0): results["detected_issues"].append( f"Atom count mismatch for {atom_symbol}: " f"Reactant has {reactant_atoms.get(atom_symbol, 0)}, " f"Product has {product_atoms.get(atom_symbol, 0)}" ) if not any("Atom count mismatch" in issue for issue in results["detected_issues"]): results["atom_count_consistent"] = True # Check for ring size changes reactant_rings = Chem.GetSSSR(reactant_mol) product_rings = Chem.GetSSSR(product_mol) reactant_ring_sizes = [len(ring) for ring in reactant_rings] product_ring_sizes = [len(ring) for ring in product_rings] # Sort ring sizes for easier comparison reactant_ring_sizes.sort() product_ring_sizes.sort() if reactant_ring_sizes != product_ring_sizes: results["detected_issues"].append( f"Ring size change detected: Reactant rings {reactant_ring_sizes}, " f"Product rings {product_ring_sizes}" ) # Report specific ring changes for r_size in reactant_ring_sizes: if reactant_ring_sizes.count(r_size) > product_ring_sizes.count(r_size): results["ring_size_changes"].append(f"{r_size}-membered ring removed") for p_size in product_ring_sizes: if product_ring_sizes.count(p_size) > reactant_ring_sizes.count(p_size): results["ring_size_changes"].append(f"{p_size}-membered ring added") # Check for aromatic ring changes reactant_aromatic_atoms = set( [atom.GetIdx() for atom in reactant_mol.GetAtoms() if atom.GetIsAromatic()] ) product_aromatic_atoms = set( [atom.GetIdx() for atom in product_mol.GetAtoms() if atom.GetIsAromatic()] ) # Check if the number of aromatic atoms changed significantly if abs(len(reactant_aromatic_atoms) - len(product_aromatic_atoms)) > 2: results["detected_issues"].append( f"Significant change in aromaticity: Reactant has {len(reactant_aromatic_atoms)} " f"aromatic atoms, Product has {len(product_aromatic_atoms)}" ) # Advanced check for substituent position changes on rings check_ring_substituent_positions(reactant_mol, product_mol, results) # Check for unnecessary bond formations reactant_bonds = Counter([bond.GetBondType() for bond in reactant_mol.GetBonds()]) product_bonds = Counter([bond.GetBondType() for bond in product_mol.GetBonds()]) if sum(reactant_bonds.values()) < sum(product_bonds.values()): results["detected_issues"].append( f"Possible unnecessary bonds formed: Reactant has {sum(reactant_bonds.values())} bonds, " f"Product has {sum(product_bonds.values())} bonds" ) return results
[docs] def check_ring_substituent_positions( reactant_mol: Chem.Mol, product_mol: Chem.Mol, results: dict[str, Any], ) -> None: """Detect changes in the position of substituents on aromatic rings. For each aromatic ring that appears in both the reactant and the product, this function figures out what groups are attached and where (ortho / meta / para). If the same group shows up at a different position in the product, that is flagged, it almost always means the LLM hallucinated the position. Findings are written directly into *results*. Parameters ---------- reactant_mol : rdkit.Chem.Mol RDKit molecule object of the reactant. product_mol : rdkit.Chem.Mol RDKit molecule object of the product. results : dict Results dictionary to update with findings. Examples -------- >>> from rdkit import Chem >>> r_mol = Chem.MolFromSmiles("c1ccc(O)cc1") # phenol >>> p_mol = Chem.MolFromSmiles("c1ccc(O)cc1") # same phenol >>> res = {"detected_issues": [], "substituent_position_changes": []} >>> check_ring_substituent_positions(r_mol, p_mol, res) >>> res["substituent_position_changes"] [] """ # Get all ring systems in both molecules reactant_ring_info = identify_ring_systems(reactant_mol) product_ring_info = identify_ring_systems(product_mol) # If ring counts mismatch, this is already caught in the main function if len(reactant_ring_info) != len(product_ring_info): return # For each aromatic ring, identify and compare substituent patterns for r_idx, reactant_ring in enumerate(reactant_ring_info): if not reactant_ring["is_aromatic"]: continue # Find a matching aromatic ring in the product matching_rings = [ p for p in product_ring_info if p["is_aromatic"] and p["size"] == reactant_ring["size"] and not p["matched"] ] if not matching_rings: continue product_ring = matching_rings[0] product_ring["matched"] = True # Mark this ring as matched # Identify substituents and their positions for both rings reactant_substituents = identify_substituents(reactant_mol, reactant_ring) product_substituents = identify_substituents(product_mol, product_ring) # Create signature of each substituent reactant_sig = {} product_sig = {} for subst in reactant_substituents: sig = get_substituent_signature(reactant_mol, subst) if sig not in reactant_sig: reactant_sig[sig] = [] reactant_sig[sig].append(subst) for subst in product_substituents: sig = get_substituent_signature(product_mol, subst) if sig not in product_sig: product_sig[sig] = [] product_sig[sig].append(subst) # Check for position changes of similar substituents for sig in set(reactant_sig.keys()).intersection(set(product_sig.keys())): r_positions = [pos_map[s["position"]] for s in reactant_sig[sig]] p_positions = [pos_map[s["position"]] for s in product_sig[sig]] # Sort positions for easier comparison r_positions.sort() p_positions.sort() if r_positions != p_positions: # We found a substituent that has changed position subst_name = get_friendly_substituent_name(sig) results["detected_issues"].append( f"Substituent position change detected: {subst_name} moved from " f"{', '.join(r_positions)} to {', '.join(p_positions)} position(s)" ) results["substituent_position_changes"].append( { "substituent": subst_name, "from_positions": r_positions, "to_positions": p_positions, } )
[docs] def identify_ring_systems(mol: Chem.Mol) -> list[dict[str, Any]]: """Identify all ring systems in a molecule and their properties. Walks the SSSR (Smallest Set of Smallest Rings) that RDKit computes and, for each ring, notes how many atoms it has, which atom indices belong to it, and whether every atom in the ring is aromatic. The ``matched`` flag starts as ``False`` and is used later when pairing up rings between reactant and product. Parameters ---------- mol : rdkit.Chem.Mol RDKit molecule object. Returns ------- rings : list of dict Each dict has keys ``id``, ``atoms``, ``size``, ``is_aromatic``, and ``matched``. Examples -------- >>> from rdkit import Chem >>> mol = Chem.MolFromSmiles("c1ccccc1") >>> rings = identify_ring_systems(mol) >>> len(rings) 1 >>> rings[0]["size"] 6 >>> rings[0]["is_aromatic"] True """ rings = [] ring_info = Chem.GetSSSR(mol) for idx, ring in enumerate(ring_info): ring_atoms = list(ring) is_aromatic = all( mol.GetAtomWithIdx(atom_idx).GetIsAromatic() for atom_idx in ring_atoms ) rings.append( { "id": idx, "atoms": ring_atoms, "size": len(ring_atoms), "is_aromatic": is_aromatic, "matched": False, # Used later for matching rings between reactant and product } ) return rings
[docs] def identify_substituents( mol: Chem.Mol, ring_info: dict[str, Any], ) -> list[dict[str, Any]]: """Identify all substituents attached to a ring and their positions. Walks the atoms of the ring and, for every neighbour that is *not* part of the ring, traces out the full substituent group and labels its attachment point as ortho / meta / para (for 6-membered rings) or a numbered position (for other ring sizes). Parameters ---------- mol : rdkit.Chem.Mol RDKit molecule object. ring_info : dict Ring descriptor as returned by `identify_ring_systems`. Returns ------- substituents : list of dict Each dict has keys ``attachment_point``, ``first_atom``, ``atoms``, and ``position``. Examples -------- >>> from rdkit import Chem >>> mol = Chem.MolFromSmiles("c1ccc(O)cc1") # phenol >>> rings = identify_ring_systems(mol) >>> subs = identify_substituents(mol, rings[0]) >>> len(subs) >= 1 True >>> subs[0]["position"] in ("1", "ortho", "meta", "para") True """ substituents = [] ring_atoms = set(ring_info["atoms"]) # Get connections from ring atoms to non-ring atoms for ring_atom_idx in ring_atoms: ring_atom = mol.GetAtomWithIdx(ring_atom_idx) for neighbor in ring_atom.GetNeighbors(): neighbor_idx = neighbor.GetIdx() # Skip atoms that are part of the ring if neighbor_idx in ring_atoms: continue # Determine the position (ortho, meta, para) relative to other substituents position = determine_ring_position( mol, ring_atom_idx, ring_atoms, ring_info["size"] ) # Find the entire substituent group connected to this point subst_atoms = get_connected_atoms(mol, neighbor_idx, ring_atoms) substituents.append( { "attachment_point": ring_atom_idx, "first_atom": neighbor_idx, "atoms": subst_atoms, "position": position, } ) return substituents
[docs] def determine_ring_position( mol: Chem.Mol, atom_idx: int, ring_atoms: set[int], ring_size: int, ) -> str: """ Determine the position of a substituent on a ring. For 6-membered rings uses ortho/meta/para nomenclature. For other ring sizes returns numbered positions. Parameters ---------- mol : rdkit.Chem.Mol RDKit molecule object. atom_idx : int Index of the ring atom the substituent is bonded to. ring_atoms : set[int] All atom indices that belong to the ring. ring_size : int Size of the ring. Returns ------- position : str ``"ortho"``, ``"meta"``, ``"para"``, or a numbered position. Examples -------- >>> from rdkit import Chem >>> mol = Chem.MolFromSmiles("c1ccc(O)cc1") # phenol >>> ring_atoms = set(range(6)) >>> pos = determine_ring_position(mol, 3, ring_atoms, 6) >>> pos in ("1", "ortho", "meta", "para") True """ # For 6-membered rings, use ortho/meta/para nomenclature if ring_size == 6: # Find other substituents on the ring other_subst = [] for ring_atom in ring_atoms: if ring_atom == atom_idx: continue atom = mol.GetAtomWithIdx(ring_atom) for neighbor in atom.GetNeighbors(): if neighbor.GetIdx() not in ring_atoms: other_subst.append(ring_atom) break # If no other substituents, just return position number if not other_subst: return "1" # Calculate distance to other substituents distances = {} for other in other_subst: # Use shortest path through the ring path = rdmolops.GetShortestPath(mol, atom_idx, other) if path: path_len = ( len(path) - 1 ) # Subtract 1 because path includes both endpoints # Convert distance to position name if path_len == 1: pos = "ortho" elif path_len == 2: pos = "meta" elif path_len == 3: pos = "para" else: pos = str(path_len) distances[other] = pos # Return the closest position if multiple are found if distances: positions = list(distances.values()) # Prioritize ortho, then meta, then para for consistent naming if "ortho" in positions: return "ortho" elif "meta" in positions: return "meta" elif "para" in positions: return "para" else: return positions[0] # For other ring sizes, use numbered positions (1, 2, 3, etc.) return "1" # Default for now
[docs] def get_connected_atoms( mol: Chem.Mol, start_idx: int, exclude_atoms: set[int], ) -> list[int]: """ Get all atoms connected to a starting atom, excluding a set of atoms. Starting from *start_idx* (typically the first atom outside a ring), this does a breadth-first walk along bonds and collects every atom it reaches. It will *not* cross into any atom listed in *exclude_atoms*, this is how we stop at the ring boundary and only get the substituent itself. Parameters ---------- mol : rdkit.Chem.Mol RDKit molecule object. start_idx : int Atom index to start the walk from. exclude_atoms : set[int] Atom indices to treat as barriers (usually the ring atoms). Returns ------- atoms : list of int List of atom indices that form the connected component. Examples -------- >>> from rdkit import Chem >>> mol = Chem.MolFromSmiles("c1ccc(OC)cc1") # methoxybenzene >>> ring_atoms = set(range(6)) >>> # atom 6 is the O attached to the ring; BFS from there excluding ring >>> connected = get_connected_atoms(mol, 6, ring_atoms) >>> len(connected) >= 1 True """ visited = set([start_idx]) queue = [start_idx] while queue: current = queue.pop(0) atom = mol.GetAtomWithIdx(current) for neighbor in atom.GetNeighbors(): neighbor_idx = neighbor.GetIdx() if neighbor_idx not in visited and neighbor_idx not in exclude_atoms: visited.add(neighbor_idx) queue.append(neighbor_idx) return list(visited)
[docs] def get_substituent_signature( mol: Chem.Mol, substituent: dict[str, Any], ) -> str: """ Generate a signature for a substituent to identify similar groups. Counts element types in the substituent atoms and returns a sorted dot-separated string (e.g. ``"C2.O1"``). Parameters ---------- mol : rdkit.Chem.Mol RDKit molecule object. substituent : dict Substituent descriptor (must contain an ``atoms`` key with a list of atom indices). Returns ------- signature : str Signature string for the substituent. Examples -------- >>> from rdkit import Chem >>> mol = Chem.MolFromSmiles("c1ccc(O)cc1") # phenol >>> subst = {"atoms": [6]} # the oxygen atom >>> get_substituent_signature(mol, subst) 'O1' """ # Create a fragment of just the substituent atoms = substituent["atoms"] if not atoms: return "" # Get the SMILES of the fragment # This is a simplified approach, a more robust one would create a proper fragment atom_symbols = [] for atom_idx in atoms: atom = mol.GetAtomWithIdx(atom_idx) atom_symbols.append(atom.GetSymbol()) # Count elements as a basic signature elem_counts = Counter(atom_symbols) signature = ".".join( f"{elem}{count}" for elem, count in sorted(elem_counts.items()) ) # For more complex substituents, we could use a more sophisticated approach # like a Morgan fingerprint or a proper SMILES fragment return signature
[docs] def get_friendly_substituent_name(signature: str) -> str: """ Convert a substituent signature to a friendly name when possible. Parameters ---------- signature : str Element-count signature (e.g. ``"C1"``, ``"N1.O2"``). Returns ------- name : str Friendly name (e.g. ``"Methyl"``), or ``"Group (<signature>)"`` if no match is found. Examples -------- >>> get_friendly_substituent_name("C1") 'Methyl' >>> get_friendly_substituent_name("Br1") 'Bromo' >>> get_friendly_substituent_name("X99") 'Group (X99)' """ # Map of common substituent signatures to friendly names common_substituents = { "C1": "Methyl", "C2": "Ethyl", "C3": "Propyl", "N1": "Amino", "O1": "Hydroxy", "O2": "Carboxyl", "O2.C1": "Carboxyl acid", "Cl1": "Chloro", "Br1": "Bromo", "F1": "Fluoro", "I1": "Iodo", "N1.C1": "Methylamino", "C1.O1": "Hydroxy methyl", "C1.N1": "Aminomethyl", "N1.O1": "Nitro", "N1.O2": "Nitro", "S1": "Thiol", } return common_substituents.get(signature, f"Group ({signature})")
[docs] def calculate_hallucination_score( reactant_smiles: str, product_smiles: str, ) -> dict[str, Any]: """ Calculate a hallucination score for a chemical transformation. This is the high-level entry point. It runs `hallucination_compare_molecules` under the hood and then converts each kind of issue into a point deduction from a perfect score of 100. Bigger problems cost more points (e.g. a substituent jumping position costs 60, while one extra bond costs only 5). The final score is clamped to 0–100 and labelled with a severity: * **≥ 80** → ``"low"`` — looks plausible * **40–79** → ``"medium"`` — worth a second look * **20–39** → ``"high"`` — likely hallucinated * **< 20** → ``"critical"`` — almost certainly wrong If either SMILES string cannot be parsed, the score is 0 / critical. Parameters ---------- reactant_smiles : str SMILES string of the reactant molecule. product_smiles : str SMILES string of the product molecule. Returns ------- result : dict Dictionary with keys ``score`` (int, 0–100), ``severity`` (``"low"`` / ``"medium"`` / ``"high"`` / ``"critical"``), ``penalties`` (list of str), and ``message`` (str). Examples -------- >>> from deepretro.algorithms import calculate_hallucination_score >>> result = calculate_hallucination_score("c1ccccc1", "c1ccccc1OC") >>> result["severity"] 'low' >>> result["score"] >= 80 True """ # Get the detailed comparison results first comparison_results = hallucination_compare_molecules( reactant_smiles, product_smiles ) # Initialize the score at 100 (no hallucinations) base_score = 100 penalty_factors = [] penalty_descriptions = [] # Check if molecules are valid - severe penalty if not if ( not comparison_results["valid_reactant"] or not comparison_results["valid_product"] ): return { "score": 0, "severity": "critical", "message": "Invalid SMILES string detected - cannot assess transformation", } # Apply penalties based on detected issues # 1. Atom count consistency - Critical issue if not comparison_results["atom_count_consistent"]: atom_mismatch_penalties = [] for issue in comparison_results["detected_issues"]: if "Atom count mismatch" in issue: # Extract the difference in atom counts parts = issue.split(":")[1].strip() reactant_count = int(parts.split(",")[0].split()[-1]) product_count = int(parts.split(",")[1].split()[-1]) difference = abs(reactant_count - product_count) # Penalty: 5 points per atom mismatch penalty = min(5 * difference, 100) atom_mismatch_penalties.append(penalty) penalty_descriptions.append( f"Atom count inconsistency: -{penalty} points" ) # Take the maximum penalty from atom mismatches if atom_mismatch_penalties: penalty_factors.append(max(atom_mismatch_penalties)) # 2. Ring size changes - Potential issue, but could be valid in some reactions if comparison_results["ring_size_changes"]: # Check how many ring changes occurred num_ring_changes = len(comparison_results["ring_size_changes"]) # Penalty: 25 points per ring change ring_penalty = min(25 * num_ring_changes, 50) penalty_factors.append(ring_penalty) penalty_descriptions.append(f"Ring structure changes: -{ring_penalty} points") # 3. Substituent position changes - Usually suspicious if comparison_results["substituent_position_changes"]: # Check how many substituent position changes num_position_changes = len(comparison_results["substituent_position_changes"]) # Penalty: 60 points per substituent position change position_penalty = min(60 * num_position_changes, 100) penalty_factors.append(position_penalty) penalty_descriptions.append( f"Substituent position changes: -{position_penalty} points" ) # 4. Aromaticity changes - Significant structural change for issue in comparison_results["detected_issues"]: if "Significant change in aromaticity" in issue: aromaticity_penalty = 40 penalty_factors.append(aromaticity_penalty) penalty_descriptions.append( f"Significant aromaticity changes: -{aromaticity_penalty} points" ) # 5. Unnecessary bond formations - Could indicate hallucination for issue in comparison_results["detected_issues"]: if "Possible unnecessary bonds formed" in issue: # Extract number of additional bonds parts = issue.split(":")[1].strip() reactant_bonds = int(parts.split(",")[0].split()[-2]) product_bonds = int(parts.split(",")[1].split()[-2]) additional_bonds = product_bonds - reactant_bonds # Penalty: 5 points per additional bond bond_penalty = min(5 * additional_bonds, 30) penalty_factors.append(bond_penalty) penalty_descriptions.append( f"Unnecessary bond formations: -{bond_penalty} points" ) # Calculate the final score by applying all penalties final_score = base_score for penalty in penalty_factors: final_score -= penalty # Ensure score doesn't go below 0 final_score = max(0, final_score) # Determine severity level based on score if final_score >= 80: severity = "low" elif final_score >= 40: severity = "medium" elif final_score >= 20: severity = "high" else: severity = "critical" return { "score": final_score, "severity": severity, "penalties": penalty_descriptions, "message": interpret_score(final_score), }
[docs] def interpret_score(score: int) -> str: """Turn a numeric hallucination score into a sentence a non-expert can read. This is called automatically by `calculate_hallucination_score` to fill the ``message`` field, but you can also use it standalone if you already have a score. Parameters ---------- score : int Hallucination score (0 = worst, 100 = best). Returns ------- message : str One-sentence plain-English interpretation. Examples -------- >>> from deepretro.algorithms import interpret_score >>> interpret_score(95) 'Highly reliable transformation with minimal or no structural inconsistencies' """ if score >= 90: return "Highly reliable transformation with minimal or no structural inconsistencies" elif score >= 80: return "Generally reliable transformation with minor structural inconsistencies" elif score >= 70: return "Mostly reliable transformation with some structural inconsistencies" elif score >= 50: return "Questionable transformation with significant structural inconsistencies" elif score >= 30: return "Likely hallucination with major structural inconsistencies" elif score >= 10: return "Severe hallucination with critical structural inconsistencies" else: return "Complete hallucination or invalid transformation"