# Source code for deepretro.utils.az

"""AiZynthFinder integration for template-based retrosynthesis.

Runs AiZynthFinder on target molecules, with optional image export.
Uses ZINC stock and USPTO expansion/filter policies by default.
Requires ``AZ_MODEL_CONFIG_PATH`` or ``AZ_MODELS_PATH`` environment variables.
Caching is opt-in through an explicit ``CacheManager`` argument.
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Sequence, cast

import structlog
from rdkit import Chem
from rdkit.Chem import rdqueries

from deepretro.utils.cache import CacheManager, make_cache_key
from deepretro.utils.variables import BASIC_MOLECULES

if TYPE_CHECKING:
    from PIL.Image import Image

# AiZynthFinder is an optional dependency. When it cannot be imported the
# name is bound to ``None`` so this module still imports; ``_run_az_core``
# checks for that and raises ``ImportError`` at call time instead.
try:
    from aizynthfinder.aizynthfinder import AiZynthFinder
except ImportError:
    AiZynthFinder = None  # type: ignore[assignment, misc]

# Module-level structlog logger shared by the helpers below.
logger = structlog.get_logger()

# Directory two levels above the directory that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parents[2]

# Both paths are built once, at import time, by appending the environment
# value to PROJECT_ROOT as plain text.
# NOTE(review): if a variable is unset, ``os.getenv`` returns ``None`` and the
# resulting string ends in the literal text "None"; an absolute value is also
# still prefixed with PROJECT_ROOT. Presumably the variables are expected to
# hold project-relative paths -- confirm.
AZ_MODEL_CONFIG_PATH = f"{PROJECT_ROOT}/{os.getenv('AZ_MODEL_CONFIG_PATH')}"
AZ_MODELS_PATH = f"{PROJECT_ROOT}/{os.getenv('AZ_MODELS_PATH')}"


def _basic_molecule_route(smiles: str) -> list[Dict[str, Any]]:
    """Build the single-node route used for a feedstock/basic molecule.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule.

    Returns
    -------
    list[Dict[str, Any]]
        A one-element list whose only entry is a ``"mol"`` node for
        *smiles*, flagged as a chemical that is in stock and not hidden.
    """
    leaf: Dict[str, Any] = dict(
        type="mol",
        hide=False,
        smiles=smiles,
        is_chemical=True,
        in_stock=True,
    )
    return [leaf]


def _resolve_config(az_model: str | None = None) -> str:
    """Pick the AiZynthFinder config file to use.

    Parameters
    ----------
    az_model : str | None, optional
        Model variant. When given, ``<AZ_MODELS_PATH>/<az_model>/config.yml``
        is tried first.

    Returns
    -------
    str
        The model-specific config path if it can be opened, otherwise
        ``AZ_MODEL_CONFIG_PATH``.

    Raises
    ------
    FileNotFoundError
        If the model-specific file (when requested) is missing and the
        fallback ``AZ_MODEL_CONFIG_PATH`` is missing as well.
    """
    if az_model is not None:
        model_config = f"{AZ_MODELS_PATH}/{az_model}/config.yml"
        try:
            # Opening the file is only an existence probe; nothing is read.
            with open(model_config, "r"):
                pass
        except FileNotFoundError:
            logger.warning("AZ config not found, trying fallback", path=model_config)
        else:
            return model_config

    try:
        with open(AZ_MODEL_CONFIG_PATH, "r"):
            pass
    except FileNotFoundError:
        raise FileNotFoundError(
            f"AZ_MODEL_CONFIG_PATH not found at {AZ_MODEL_CONFIG_PATH}"
        )
    return AZ_MODEL_CONFIG_PATH


def _run_az_core(
    smiles: str, az_model: str | None = None
) -> tuple[bool, Sequence[Dict[str, Any]], Any]:
    """Run one AiZynthFinder search; shared by the public entry points.

    Kept private: ``run_az`` and ``run_az_with_img`` wrap this function,
    adding optional caching and trimming the return tuple to what each
    of them exposes.

    Parameters
    ----------
    smiles : str
        SMILES string of the target molecule.
    az_model : str | None, optional
        Model variant passed to ``_resolve_config``. ``None`` selects the
        global fallback ``AZ_MODEL_CONFIG_PATH``.

    Returns
    -------
    tuple[bool, Sequence[Dict[str, Any]], Any]
        ``(status, result_dict, finder)``. *finder* is the
        ``AiZynthFinder`` instance that ran the search, or ``None`` when
        the molecule was short-circuited as a basic/feedstock molecule.

    Raises
    ------
    FileNotFoundError
        Propagated from ``_resolve_config`` when no config file is found.
    ImportError
        If the optional ``aizynthfinder`` package is not installed.
    """
    # Feedstock molecules are reported as solved without running a search.
    is_feedstock = smiles in BASIC_MOLECULES or is_basic_molecule(smiles)
    if is_feedstock:
        return True, _basic_molecule_route(smiles), None

    config_path = _resolve_config(az_model)

    if AiZynthFinder is None:
        raise ImportError(
            "AiZynthFinder support requires optional dependencies. "
            "Install the package with `deepretro[az]`."
        )

    # Configure the finder: ZINC stock, USPTO expansion and filter policies.
    finder = AiZynthFinder(configfile=config_path)
    finder.stock.select("zinc")
    finder.expansion_policy.select("uspto")
    finder.filter_policy.select("uspto")
    finder.target_smiles = smiles

    # Search first, then build routes; statistics are read afterwards.
    finder.tree_search()
    finder.build_routes()

    solved = bool(finder.extract_statistics()["is_solved"])
    routes = finder.routes.dict_with_extra(
        include_metadata=True, include_scores=True
    )
    return solved, routes, finder


def run_az(
    smiles: str,
    az_model: str = "USPTO",
    cache: CacheManager | None = None,
) -> tuple[bool, Sequence[Dict[str, Any]]]:
    """Run the retrosynthesis using AiZynthFinder.

    Example
    -------
    >>> from deepretro.utils.az import run_az
    >>> status, result_dict = run_az("C1CCCCC1", "USPTO")  # doctest: +SKIP
    >>> isinstance(status, bool) and isinstance(result_dict, list)  # doctest: +SKIP
    True

    Parameters
    ----------
    smiles : str
        SMILES string of the target molecule.
    az_model : str, optional
        AiZynthFinder model variant (e.g. ``"USPTO"``, ``"Pistachio_50"``),
        by default ``"USPTO"``.
    cache : CacheManager | None, optional
        Explicit cache instance used to memoize results for this call.
        When ``None``, no cache is read or written.

    Returns
    -------
    tuple[bool, Sequence[Dict[str, Any]]]
        ``(solved, routes)`` — whether a route was found and the route data.

    Notes
    -----
    Install the package with ``deepretro[az]``.
    Caching is disabled unless an explicit ``cache=CacheManager(...)`` is
    supplied.
    """
    cache_key = make_cache_key("run_az", smiles, az_model=az_model, version=1)
    # A private sentinel tells "no entry" apart from any stored value.
    cache_miss = object()

    if cache is not None:
        cached_result = cache.get(cache_key, default=cache_miss)
        if cached_result is not cache_miss:
            return cast(tuple[bool, Sequence[Dict[str, Any]]], cached_result)

    # The finder instance is not part of this function's result.
    status, result_dict, _ = _run_az_core(smiles, az_model)
    result = (status, result_dict)

    if cache is not None:
        cache.set(cache_key, result, tag=smiles)
    return result


def run_az_with_img(
    smiles: str,
    cache: CacheManager | None = None,
) -> tuple[bool, Sequence[Dict[str, Any]], Sequence[Image | None] | None]:
    """Run the retrosynthesis using AiZynthFinder, including route images.

    Example
    -------
    >>> from deepretro.utils.az import run_az_with_img
    >>> status, result_dict, images = run_az_with_img("C1CCCCC1")  # doctest: +SKIP
    >>> isinstance(status, bool)  # doctest: +SKIP
    True

    Parameters
    ----------
    smiles : str
        SMILES string of the target molecule.
    cache : CacheManager | None, optional
        Explicit cache instance used to memoize results for this call.
        When ``None``, no cache is read or written.

    Returns
    -------
    tuple[bool, Sequence[Dict[str, Any]], Sequence[Image | None] | None]
        ``(solved, routes, images)`` — solved status, route data, and the
        route images taken from ``finder.routes.images``. *images* is
        ``None`` when the target was short-circuited as a basic molecule
        and no finder was created. Uses ``AZ_MODEL_CONFIG_PATH``.

    Notes
    -----
    Install the package with ``deepretro[az]``.
    Caching is disabled unless an explicit ``cache=CacheManager(...)`` is
    supplied.
    """
    cache_key = make_cache_key("run_az_with_img", smiles, version=1)
    # A private sentinel tells "no entry" apart from any stored value.
    cache_miss = object()

    if cache is not None:
        cached_result = cache.get(cache_key, default=cache_miss)
        if cached_result is not cache_miss:
            # ``cast`` evaluates its type argument at runtime and ``Image``
            # is imported only under TYPE_CHECKING, so the image slot is
            # spelled ``Any`` here.
            return cast(
                tuple[bool, Sequence[Dict[str, Any]], Any],
                cached_result,
            )

    # No model variant is passed, so the fallback config is used.
    status, result_dict, finder = _run_az_core(smiles)
    images = finder.routes.images if finder is not None else None
    result = (status, result_dict, images)

    if cache is not None:
        cache.set(cache_key, result, tag=smiles)
    return result


def is_basic_molecule(smiles: str) -> bool:
    """Check whether a molecule counts as a basic molecule.

    A molecule is basic when it has fewer than 5 carbon atoms or fewer
    than 5 atoms in total.

    Parameters
    ----------
    smiles : str
        SMILES string of the target molecule

    Returns
    -------
    bool
        True if the molecule is a basic molecule, False otherwise
        (including when the SMILES cannot be parsed).

    Examples
    --------
    >>> from deepretro.utils.az import is_basic_molecule
    >>> is_basic_molecule("C")
    True
    >>> is_basic_molecule("CC")
    True
    >>> is_basic_molecule("C1CCCCC1")
    False
    >>> is_basic_molecule("CCO")
    True
    >>> is_basic_molecule("invalid_smiles!!")
    False
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
    except Exception:
        # Deliberately broad: any failure raised by the parser means the
        # input is treated as "not basic" rather than crashing the caller.
        return False
    if mol is None:
        return False

    carbon_query = rdqueries.AtomNumEqualsQueryAtom(6)
    num_c_atoms = len(mol.GetAtomsMatchingQuery(carbon_query))

    # Basic when the total atom count is below 5 or the carbon count is
    # below 5.
    # NOTE(review): a molecule that parses with zero atoms also satisfies
    # this test and is reported as basic -- confirm that is intended.
    return mol.GetNumAtoms() < 5 or num_c_atoms < 5