Source code for DARTassembler.src.assembler.isomer

"""
"""
import sys
import warnings
from typing import Dict, Any, Optional, Union, Iterable
from ase.visualize import view
from ase import Atoms, Atom
import numpy as np
import itertools
import ase
import networkx as nx
from copy import deepcopy
from typing import List
import logging
from scipy.spatial.distance import cdist, pdist
from scipy.spatial.transform import Rotation as R
from scipy.optimize import linear_sum_assignment, differential_evolution
import pandas as pd
from scipy.optimize import brute
from DARTassembler.src.assembler.utils import are_atoms_equal, get_list_with_all_possible_swappings, \
    remove_haptic_dummy_atom, get_complex_name, join_duplicate_groups_by_union, get_target_vector
from DARTassembler.src.metalig.archetype import try_all_geometrical_isomer_possibilities, all_archetypes, align_vectors, align_donor_atoms
from DARTassembler.src.constants.chem import Element
from DARTassembler.src.metalig.mol import BaseMolecule, Ligand
from DARTassembler.src.metalig.utils_molecule import get_atomic_props_from_ase_atoms
from DARTassembler.src.metalig.utils_graph import  graph_to_dict_with_node_labels, get_graph_hash

try:    # optional imports for visualization. Don't print if not available (since not essential for the main functionality).
    import plotly.graph_objects as go
    import plotly.express as px
    import plotly.io as pio
    import dash
    pio.renderers.default = 'browser'
except ImportError:
    pass


[docs] class AssembledIsomer(BaseMolecule): """ Represent an assembled transition-metal complex isomer. This class wraps the assembled ASE Atoms object together with connectivity graph, ligand/metal indices and metadata for downstream filtering and output. :param atomic_props: ASE Atoms object or dict of atomic properties describing the complex. :type atomic_props: Union[ase.Atoms, dict[str, Any]] :param graph: NetworkX graph representing atomic connectivity and node labels. :type graph: nx.Graph :param metal_idc: Indices of metal center atoms in the ASE Atoms object. :type metal_idc: list[int] :param donor_idc: Flattened indices of donor atoms in the merged graph. :type donor_idc: list[list[int]] :param ligand_idc: Indices of atoms belonging to each ligand in the merged graph. :type ligand_idc: list[list[int]] :param ligand_info: Dictionary containing ligand metadata required to reconstruct Ligand objects. :type ligand_info: dict[str, Any] | None :param global_props: Global properties for the molecule. :type global_props: dict[str, Any] | None :param validity_check: If True, run transition-metal-complex specific validity checks. :type validity_check: bool :param target_vectors: Target donor vectors used to assemble the isomer (one list per ligand). :type target_vectors: list[list[list[float]]] | None :param ligand_origins: Origin coordinates (shape (3,)) used to translate each ligand to its metal center. :type ligand_origins: list[list[float]] | None :param warning: Warning string (e.g., 'clashing' or 'duplicate(...)') to annotate the isomer. :type warning: str :param isomer_name: Unique name assigned to the isomer for identification. :type isomer_name: str | None """ def __init__(self, atomic_props: Union[ase.Atoms, Dict[str, Any]], graph: nx.Graph, metal_idc: List[int], donor_idc: List[List[int]], ligand_idc: List[List[int]], ligand_info: Dict[str, Any] = None, global_props: Dict[str, Any] = None, validity_check: bool = False, target_vectors=None, ligand_origins=None, warning: str = '', isomer_name: str = None ): if global_props is None: global_props = {} if ligand_info is None: ligand_info = {} super().__init__( atomic_props=atomic_props, global_props=global_props, graph=graph, validity_check=False # will be performed later if required ) self.metal_idc = metal_idc self.donor_idc = donor_idc self.ligand_idc = ligand_idc self.ligand_info = ligand_info self.warning = warning self.target_vectors = target_vectors self.ligand_origins = ligand_origins self.isomer_name = isomer_name self.metals = [self.atoms[idx].symbol for idx in self.metal_idc] # A little expensive, but small compared to the mono-axial optimization. self.ligands = self._get_ligands() # Ligand() objects for convenient access to ligands. Won't be saved to disk in the DART workflow. if validity_check: self._tmc_validity_checks()
[docs] def update_positions(self, newatoms: ase.Atoms) -> None: """ Update the atomic positions of the isomer and its ligands. Ensures atom symbols match. Should be used instead of manually updating atoms/atomic_props. :param newatoms: ASE Atoms object with updated positions. :return: None :rtype: None """ assert all(self.atoms.symbols == newatoms.symbols), "Cannot update positions: atom symbols do not match." self.atoms = newatoms self.atomic_props = get_atomic_props_from_ase_atoms(atoms=newatoms) for ligand, idc in zip(self.ligands, self.ligand_idc): ligand.atoms = deepcopy(newatoms[idc]) ligand.atomic_props = get_atomic_props_from_ase_atoms(ligand.atoms)
[docs] def round_positions(self, decimals: int = 6) -> None: """ Round the atomic positions of the isomer to a specified number of decimal places. Useful to avoid tiny numerical differences in tests. :param decimals: Number of decimal places to round to. :return: None :rtype: None """ rounded_atoms = deepcopy(self.atoms) rounded_atoms.set_positions(np.round(rounded_atoms.get_positions(), decimals=decimals)) self.update_positions(newatoms=rounded_atoms)
def _get_ligands(self, validity_check: bool = True) -> List[Ligand]: """ Construct Ligand objects from stored ligand metadata and ASE Atoms slices. The method builds Ligand() wrappers for each ligand using the ligand indices and ligand_info provided during assembly. These Ligand objects are convenient for downstream geometry operations and will not be persisted by default. :param validity_check: Whether to run validity checks when instantiating Ligand objects. :type validity_check: bool :return: List of Ligand objects corresponding to each ligand in the isomer. :rtype: list[ Ligand ] """ ligands = [] for idx, (isomer_donor_idc, isomer_ligand_indices) in enumerate(zip(self.donor_idc, self.ligand_idc)): ligand = Ligand( atomic_props=self.atoms[isomer_ligand_indices], donor_idc=self.ligand_info['donor_idcs'][idx], global_props={'archetype': self.ligand_info['archetypes'][idx]}, graph=self.graph.subgraph(isomer_ligand_indices), unique_name=self.ligand_info['unique_names'][idx], charge=self.ligand_info['charges'], hapdent_idc=self.ligand_info['hapdent_idcs'][idx], geometric_isomers_hapdent_idc=self.ligand_info['geometric_isomers_hapdent_idcs'][idx], validity_check=validity_check, ) ligands.append(ligand) return ligands
[docs] def get_metal_symbols(self) -> List[str]: """ Return the chemical symbols of the metal center atoms. :return: Chemical symbols of metals in the same order as metal_idc. :rtype: list[str] """ return self.atoms[self.metal_idc].get_chemical_symbols()
def _tmc_validity_checks(self) -> None: """ Run transition-metal-complex specific validity checks. Checks basic molecular validity via the BaseMolecule helper and warns if any declared metal center is not classified as a metal element by the Element utility. :raises AssertionError: If base molecule validity checks fail. :return: None :rtype: None """ self._check_if_molecule_valid() # Checks basic molecular properties like atomic_props and graph # Doublecheck if all the metals are really metals. Don't raise an error in case it's intentional. all_metals = all(Element(metal).is_metal for metal in self.metals) if not all_metals: warnings.warn(f"Any of the metal centers {self.metals} in AssembledIsomer() is not a metal. Providing a chemical element as `metal center` that is not a metal is not a problem for DART, this warning is just to make you aware.") return
[docs] def to_dict(self) -> Dict[str, Any]: """ Convert the AssembledIsomer to a serializable dictionary. The returned dictionary includes atomic properties, connectivity graph (handled by the base class), and assembly-specific metadata such as metal/donor/ligand indices. :return: Dictionary representation suitable for JSON or other serialization. :rtype: dict[str, Any] """ d = super().to_dict() # Base class takes care of atomic_props, global_props, and graph d.update({ "metal_idc": self.metal_idc, "donor_idc": self.donor_idc, "ligand_idc": self.ligand_idc, "ligand_info": self.ligand_info, }) return d
# todo: update these methods using the new output format of the AssembledComplex.to_dict() method. # @classmethod # def from_json(cls, filepath) -> 'AssembledIsomer': # """ # Loads an AssembledIsomer object from a .json file. # :param filepath: Path to the .json file # :return: AssembledIsomer object # """ # data = load_json(filepath) # return cls.from_dict(data) # # @classmethod # def from_dict(cls, data: Dict[str, Any]) -> 'AssembledIsomer': # """ # Creates an AssembledIsomer object from a dictionary in the correct format. # :param data: Dictionary containing the AssembledIsomer data # :return: AssembledIsomer object # """ # data['graph'] = graph_from_graph_dict(data['graph']) # data['charge'] = 0 # todo # return cls(**data)
[docs] class AssembledComplex(object): """ Assemble transition-metal complex isomers from Ligand objects, target vectors and metal centers. The class stores the input ligands, their target donor vectors and metal center definitions and provides methods to generate assembled isomers, perform mono-axial optimization and filter duplicates. """ def __init__( self, ligands: List[Ligand], target_vectors: list[list[list[float]]], metal_centers: Union[List[List[Union[str, List[float]]]], str], ligand_origins: List[List[float]] = None, ): """ The assembler will align and translate each ligand such that its donor atoms point along the provided target_vectors and place metal center atoms at the specified coordinates. If metal_centers is a string (element symbol), a single metal center at the origin is assumed for each ligand (mono-metallic complex). :param ligands: Sequence of Ligand objects to assemble. :type ligands: list[ Ligand ] :param target_vectors: For each ligand, a list of donor target vectors (each vector length 3). :type target_vectors: list[list[list[float]]] # shape: (n_ligands, n_donors_per_ligand, 3) :param metal_centers: Either a symbol string for a single metal at origin, or a list where each entry corresponds to the metal center(s) the ligand binds to. For the latter, each metal is a tuple [element_symbol, [x,y,z]]. :type metal_centers: Union[str, list[list[ Union[str, list[float]] ]]] :param ligand_origins: Optional list of origin coordinates (3 floats) for each ligand; if None, defaults to the corresponding metal center positions (or the mean for bridging ligands). :type ligand_origins: list[list[float]] | None :param ligand_origins: Optional list of origin coordinates (3 floats) for each ligand; if None, defaults to the corresponding metal center positions (or the mean for bridging ligands). :type ligand_origins: list[list[float]] | None Example usage for assembling a mono-metallic square-planar Pd complex with 2 cis bidentate ligands in the xy-plane: .. code-block:: python factory = AssembledComplex( ligands=... # (list of two 2-cis Ligand() objects from the MetaLig) metal_centers='Pd', target_vectors=[ [[1, 0, 0], [0, 1, 0]], # the donor atoms of the first bidentate ligand are oriented in (+x,+y) direction [[-1, 0, 0], [0, -1, 0]], # the donor atoms of the second bidentate ligand are oriented in (-x,-y) direction ], ) Example usage for assembling a bi-metallic complex with three monodentate ligands, each one at either end and the third one bridging: .. code-block:: python ligands = ... # (list of three 1-mono Ligand() objects from the MetaLig) ru = ['Ru', [1, 0, 0]] fe = ['Fe', [-1, 0, 0]] metal_centers = [ [ru], # metal center for the first ligand [fe] # metal center for the second ligand [ru, fe], # metal centers for the third, bridging ligand. By default, the ligand will be placed at the center between both metals. ] target_vectors = [ [[1, 0, 0]], # bound to Ru, pointing away from center [[-1, 0, 0]], # bound to Fe, pointing away from center [[0, 0, 1]], # bridging between Ru and Fe, pointing up ] factory = AssembledComplex( ligands=ligands, target_vectors=target_vectors, metal_centers=metal_centers ) isomers = factory.generate_isomers() """ metal_centers, ligand_origins, ligands, target_vectors = self._check_input_and_handle_defaults(metal_centers, ligand_origins, ligands, target_vectors) self.ligands = ligands self.target_vectors = target_vectors self.ligand_origins = ligand_origins self.metal_centers = metal_centers def _check_input_and_handle_defaults(self, metal_centers, ligand_origins, ligands, target_vectors): """ Validate inputs and set sensible defaults for metal_centers and ligand_origins. This method: - Expands a single-element metal_centers string into ASE Atoms at the origin for each ligand. - Converts provided metal center specifications into ASE Atom objects. - Sets default ligand_origins to the metal center(s) position (or mean position for bridging ligands). - Validates dimensions and numeric types of target_vectors and matches them against ligand archetypes, issuing warnings if the provided vectors do not match the ligand archetype. :param metal_centers: Metal center specification (see __init__). :type metal_centers: Union[str, list[list[ Union[str, list[float]] ]]] :param ligand_origins: Optional ligand origins to use for translation. :type ligand_origins: list[list[float]] | None # shape: (n_ligands, 3) :param ligands: List of Ligand objects. :type ligands: list[ Ligand ] :param target_vectors: Target donor vectors per ligand. :type target_vectors: list[list[list[float]]] # shape: (n_ligands, n_donors, 3) :raises ValueError: If list lengths mismatch or target vector shapes are invalid. :raises TypeError: If target_vectors has an invalid type structure. :return: Normalized (metal_centers, ligand_origins, ligands, target_vectors) :rtype: tuple[ list[list[ase.Atom]], list[list[float]], list[ Ligand ], list[list[list[float]]] ] """ if isinstance(metal_centers, str): # If the metal center is provided as a chemical element, it's a mono-metallic complex at the origin metal_centers = [[ase.Atom(symbol=metal_centers, position=[0, 0, 0])] for _ in ligands] else: # If the metal center is provided as a list of elements and positions, convert to ASE Atoms objects metal_centers = [[ase.Atom(symbol=metal[0], position=metal[1]) for metal in metal_list] for metal_list in metal_centers] if ligand_origins is None: # The `ligand_origins` defaults for ligands that are connected to only one metal to its metal center coordinates, and for ligands that are connected to multiple metal centers to the average of the metal center coordinates (i.e. in the center of them). ligand_origins = [] for lig, centers in zip(ligands, metal_centers): # If the ligand is connected to multiple metal centers, use the average position of the metal centers avg_position = np.mean([metal.position for metal in centers], axis=0) ligand_origins.append(avg_position.tolist()) # Check input format of lists input_lengths = (len(ligands), len(target_vectors), len(ligand_origins), len(metal_centers)) all_same_length = all(length == input_lengths[0] for length in input_lengths) if not all_same_length: raise ValueError(f'Input lists must all have the same length. Got lengths: {input_lengths}') try: target_vectors = [[get_target_vector(v) for v in target_vector_list] for target_vector_list in target_vectors] except: pass # Check target vectors format try: for target_vector_list in target_vectors: try: array = np.array(target_vector_list) except ValueError as e: raise ValueError(f"Target vector is not a list of lists of floats: {target_vector_list}. Error: {e}") if array.ndim != 2: raise ValueError(f"Target vector must have 2 dimensions (list of list), got {array.shape}: {target_vector_list}") elif array.shape[1] != 3: raise ValueError(f"Target vector must have 3 elements (list of list of floats), got {array.shape[1]}: {target_vector_list}") elif not np.issubdtype(array.dtype, np.floating) and not np.issubdtype(array.dtype, np.integer): raise ValueError(f"Target vector must be a list of lists of floats, got {array.dtype}: {target_vector_list}") except TypeError as e: raise TypeError(f"Target vectors must be a list of lists of lists of floats. Error: {e}") # Check if the target vectors are compatible with the ligand archetypes. for target_vector_list, ligand in zip(target_vectors, ligands): ligand_arch_vector = np.asarray(all_archetypes[ligand.n_eff_denticities][ligand.archetype][0]) _, rssd = align_vectors(target_vectors=target_vector_list, donor_vectors=ligand_arch_vector) if not np.isclose(rssd, 0.0): # Check if any combination of the target vectors matches the archetype of the ligand. target_vector_list_permutations = list(itertools.permutations(target_vector_list)) any_other_order_matches = False for target_vector_list_test in target_vector_list_permutations: _, rssd = align_vectors(target_vectors=target_vector_list_test, donor_vectors=ligand_arch_vector) if np.isclose(rssd, 0.0): any_other_order_matches = True break if any_other_order_matches: print_target_vector_list_test = np.array(target_vector_list_test).tolist() warn_string = f'Most likely, you provided the target vectors in the wrong order (see documentation), because if we change the order to e.g. `{print_target_vector_list_test}`, they fit perfectly.' else: warn_string = f'Most likely, you simply provided erroneous ligands/target vectors.' print_target_vector_list = np.array(target_vector_list).tolist() logging.warning( f"DART ASSEMBLER WARNING: The user-provided target vectors `{print_target_vector_list}` do not perfectly match the ideal target vectors of the ligand archetype `{ligand.archetype}`. A typical choice of `{ligand.archetype}` ligands would be the target vectors `{ligand_arch_vector.tolist()}` (or any 3D rotated equivalent). The DART assembler will continue and try to align the ligand's donor vectors with the target vectors you provided, but the assembled complexes may not have the geometry you intended. If this is intended, you can ignore this warning. {warn_string}") return metal_centers, ligand_origins, ligands, target_vectors # This is the method used in the DART workflow to generate isomers.
[docs] def generate_isomers(self, permutable_ligands: Optional[List[int]] = None, monoaxial_optimization: Optional[bool] = True, force_all_isomers: bool = False, clashing_tolerance: float = -0.3, clashing_metal: bool = False, duplicate_tolerance: float = 0.5, complex_name_length: int = 8, complex_name_suffix: str = '', avoid_names: Optional[Iterable[str]] = None, ): """ Generate assembled isomers from the provided ligands and metal centers. The method constructs all geometric combinations by assigning rotated ligand instances to metal centers, applies optional mono-axial optimization, filters clashing isomers, and groups duplicates via fingerprint or alignment-based comparison. :param permutable_ligands: Indices of ligands allowed to be permuted when generating isomers. :type permutable_ligands: list[int] | None :param monoaxial_optimization: Whether to perform mono-axial rotation optimization for mono-coordinating ligands. :type monoaxial_optimization: bool :param force_all_isomers: If True, generate all geometrical isomers for each ligand archetype. :type force_all_isomers: bool :param clashing_tolerance: Distance buffer below covalent radii sum allowed before flagging a clash (Å). :type clashing_tolerance: float | None :param clashing_metal: If True, also check ligand-metal and metal-metal clashes. :type clashing_metal: bool :param duplicate_tolerance: Cutoff used by duplicate grouping (fingerprint threshold). :type duplicate_tolerance: float :param complex_name_length: Length of the generated complex name seed. :type complex_name_length: int :param complex_name_suffix: Optional suffix appended to generated complex names. :type complex_name_suffix: str :param avoid_names: Iterable of names to avoid when generating unique complex names. :type avoid_names: Iterable[str] | None :return: None. Results are stored on the AssembledComplex instance as `isomers`, `successful_isomers`, `unsuccessful_isomers` and `success` boolean. :rtype: None """ if avoid_names is None: avoid_names = set() self.clashing_tolerance = clashing_tolerance self.clashing_metal = clashing_metal self.duplicate_tolerance = duplicate_tolerance self.permutable_ligands = permutable_ligands self.monoaxial_optimization = monoaxial_optimization self.force_all_isomers = force_all_isomers self.complex_name_length = complex_name_length self.complex_name_suffix = complex_name_suffix self.avoid_names = avoid_names unique_metal_centers = self._get_all_unique_metal_centers() self.metal_idc = [idx for idx in range(len(unique_metal_centers))] self.graph, self.ligand_indices, self.donor_indices = self._get_merged_graph_from_ligands_and_metal_centers() self.graph_hash = get_graph_hash(self.graph) self.complex_name = self._get_complex_name(avoid_names=avoid_names) # Generate all possible geometric isomers to be generated via exchanging ligands (or, as here implemented, exchanging the target vectors of the ligands). target_vector_combs = get_list_with_all_possible_swappings(objects=self.target_vectors, permutable_ligands=self.permutable_ligands) ligand_origin_combs = get_list_with_all_possible_swappings(objects=self.ligand_origins, permutable_ligands=self.permutable_ligands) isomers = [] same_length_target_vectors = [] same_length_ligand_origins = [] isomer_idx = 1 for target_vectors, ligand_origins in zip(target_vector_combs, ligand_origin_combs): rotated_ligands = self._get_rotated_ligands(target_vectors=target_vectors, ligand_origins=ligand_origins) # Generate all combinations. Each combination is a tuple with one isomer per ligand. combinations = list(itertools.product(*rotated_ligands)) ase_isomers = [] for combo in combinations: combined = Atoms() # Start with an empty Atoms object. for atom in unique_metal_centers: combined += atom for ligand in combo: # Iterate over the ligands in the combination. combined += ligand # combining Atoms objects. ase_isomers.append(combined) # Store all the new isomers for ase_isomer in ase_isomers: isomer = AssembledIsomer( atomic_props=ase_isomer, graph=self.graph, metal_idc=self.metal_idc, ligand_idc=self.ligand_indices, donor_idc=self.donor_indices, global_props={}, ligand_info=self._get_ligandinfo(), target_vectors=target_vectors, ligand_origins=ligand_origins, warning='', # Initially no warning, will be updated later if needed validity_check=True, isomer_name=self.complex_name + str(isomer_idx) # Assign a name based on the complex name and index ) # Round coordinates to avoid numerical tiny differences in pytests when running with different OS/python/package versions. Makes many things much more stable. isomer.round_positions() isomers.append(isomer) same_length_ligand_origins.append(ligand_origins) same_length_target_vectors.append(target_vectors) isomer_idx += 1 pre_isomers_duplicate_groups = _DuplicateIsomerFilter(isomers=isomers, fingerprint_grouping_cutoff=self.duplicate_tolerance, metal_centers=self.metal_centers).get_duplicate_groups() pre_isomers_duplicate_group_names = [set([isomer.isomer_name for isomer in isomer_group]) for isomer_group in pre_isomers_duplicate_groups] # Do a mono-axial optimization of the isomers and afterward check for clashing ligands. for idx, isomer, target_vectors, ligand_origins in zip(range(len(isomers)), isomers, same_length_target_vectors, same_length_ligand_origins): isomer = _AxialOptModifier(isomers=[isomer], opt=self.monoaxial_optimization).modify(target_vectors_list=[target_vectors], ligand_origins_list=[ligand_origins])[0] # Round coordinates to avoid numerical tiny differences in pytests when running with different OS/python/package versions. Makes many things much more stable. isomer.round_positions() isomers[idx] = isomer # important: copy changed object over to list if isomer.warning == '' and pd.notna(self.clashing_tolerance): clashfilter = _IsomerClashFilter(buffer=self.clashing_tolerance, check_metal_clashes=self.clashing_metal) clashing = clashfilter.has_clashing_atoms(atoms=isomer.atoms, ligand_idc=isomer.ligand_idc, metal_idc=isomer.metal_idc) if clashing: isomer.warning = 'clashing' # this now updates the optimized object # Check for duplicates again after the mono-axial optimization. joined_isomers_duplicate_group_names = self._get_duplicate_isomers_group_names(isomers=isomers, pre_isomers_duplicate_group_names=pre_isomers_duplicate_group_names, duplicate_tolerance=self.duplicate_tolerance, metal_centers=self.metal_centers) self.successful_isomers, self.unsuccessful_isomers = self._divide_into_successful_and_unsuccessful_isomers(isomers, joined_isomers_duplicate_group_names) self.success = len(self.successful_isomers) > 0 self.isomers = isomers # for isomer in self.isomers: # isomer.atoms.set_positions(np.round(isomer.atoms.get_positions(), decimals=0)) return
@staticmethod def _get_duplicate_isomers_group_names(isomers: list[Any], pre_isomers_duplicate_group_names: list[set[Union[str, None]]], duplicate_tolerance: float, metal_centers: list[list[Atom]]) -> list[list[str]]: """ Join pre- and post-optimization duplicate groups and preserve isomer ordering. Uses union of pre- and post-monoaxial-optimization duplicate sets and sorts each joined group according to the original `isomers` order so the first element can be retained. :param isomers: List of AssembledIsomer objects. :type isomers: list[Any] :param pre_isomers_duplicate_group_names: List of sets containing isomer names flagged as duplicates before optimization. :type pre_isomers_duplicate_group_names: list[ set[str] ] :param duplicate_tolerance: Cutoff used by duplicate grouping (fingerprint threshold). :type duplicate_tolerance: float :param metal_centers: Metal center atom definitions used for duplicate comparison. :type metal_centers: list[list[ase.Atom]] :return: Ordered list of duplicate groups where each group is a list of isomer names. :rtype: list[list[str]] """ post_isomers_duplicate_groups = _DuplicateIsomerFilter(isomers=isomers, fingerprint_grouping_cutoff=duplicate_tolerance, metal_centers=metal_centers).get_duplicate_groups() post_isomers_duplicate_group_names = [set([isomer.isomer_name for isomer in isomer_group]) for isomer_group in post_isomers_duplicate_groups] # Join the pre- and post-isomers duplicate groups. If an isomer is a duplicate in either the pre- or post-isomers duplicate groups, it is considered a duplicate. joined_isomers_duplicate_group_names = join_duplicate_groups_by_union(pre_isomers_duplicate_group_names, post_isomers_duplicate_group_names, mode='post') assert sorted(name for group in joined_isomers_duplicate_group_names for name in group) == sorted(name for isomer in isomers for name in [isomer.isomer_name]), "Joined isomer groups do not contain all isomers." # Sort the joint isomer names by the order of `isomers` and convert to lists, so that the output order of isomers is preserved. That is particularly important so that the duplicate filter always keeps the same, "first" isomer in the group. isomer_names_order = [isomer.isomer_name for isomer in isomers] joined_isomers_duplicate_group_names = [sorted(list(group), key=lambda x: isomer_names_order.index(x)) for group in joined_isomers_duplicate_group_names] return joined_isomers_duplicate_group_names def _divide_into_successful_and_unsuccessful_isomers(self, isomers: list[Any], joined_isomers_duplicate_group_names: list[list[str]]) -> tuple[list[Any], list[Any]]: """ Split isomers into successful and unsuccessful sets based on warnings and duplicate groups. Successful isomers are the first non-clashing member of each duplicate group. Others are marked unsuccessful and labelled with a suitable warning ('clashing' or 'duplicate(...)'). :param isomers: List of AssembledIsomer objects (each must have 'isomer_name' and 'warning' attributes). :type isomers: list[Any] :param joined_isomers_duplicate_group_names: Joined duplicate groups with isomer names (post optimization). :type joined_isomers_duplicate_group_names: list[list[str]] :return: Tuple (successful_isomers, unsuccessful_isomers). :rtype: tuple[list[Any], list[Any]] :raises AssertionError: If duplicate group contains duplicate names or is empty. """ successful_isomers = [] unsuccessful_isomers = [] isomer_dict = {isomer.isomer_name: isomer for isomer in isomers} for isomer_group in joined_isomers_duplicate_group_names: assert len(isomer_group) == len(set(isomer_group)) and len( isomer_group) > 0, f"Duplicate isomer group contains duplicates or is empty: {isomer_group}" first_isomer_in_group_added = False for isomer_name in isomer_group: if isomer_dict[isomer_name].warning == 'clashing': # If the isomer is clashing, add it to the unsuccessful isomers. unsuccessful_isomers.append(isomer_dict[isomer_name]) elif isomer_dict[isomer_name].warning == '': if not first_isomer_in_group_added: # If the isomer is not clashing and is the first in the group, add it to the successful isomers. successful_isomers.append(isomer_dict[isomer_name]) first_isomer_in_group_added = True else: # If the isomer is not clashing and is not the first in the group, add it to the unsuccessful isomers and mark it as a duplicate. duplicate_indices = ','.join(name.removeprefix(self.complex_name) for name in isomer_group) isomer_dict[isomer_name].warning = f'duplicate({duplicate_indices})' unsuccessful_isomers.append(isomer_dict[isomer_name]) else: raise ValueError( f"Isomer {isomer_name} has an unexpected warning: {isomer_dict[isomer_name].warning}. Expected 'clashing' or ''.") return successful_isomers, unsuccessful_isomers
[docs] def to_dict(self): """ Convert the AssembledComplex factory state to a dictionary. The resulting dictionary contains the generated isomers (by name), the merged connectivity graph, graph hash, index mappings and the input parameters used to assemble. :return: Dictionary containing assembly metadata and serialized isomer entries. :rtype: dict[str, Any] """ isomer_data = {} for isomer in self.isomers: # Assert atomic_props and ase.Atoms positions are equivalent. all_atomic_props_coordinates = np.array([isomer.atomic_props[x] for x in ['x', 'y', 'z']]).T assert np.allclose(all_atomic_props_coordinates, isomer.atoms.get_positions()), f"Mismatch between atomic_props and ase.Atoms positions in isomer {isomer.isomer_name}." isomer_data[isomer.isomer_name] = { 'atomic_props': isomer.atomic_props, 'warning': isomer.warning, 'target_vectors': isomer.target_vectors, 'ligand_origins': isomer.ligand_origins, } return { "complex_name": self.complex_name, "isomers": isomer_data, "graph": graph_to_dict_with_node_labels(self.graph), "graph_hash": self.graph_hash, "metal_idc": self.metal_idc, "donor_idc": self.donor_indices, "ligand_idc": self.ligand_indices, "ligand_info": self._get_ligandinfo(), "input": { "clashing_tolerance": self.clashing_tolerance, "clashing_metal": self.clashing_metal, "duplicate_tolerance": self.duplicate_tolerance, "permutable_ligands": self.permutable_ligands, "monoaxial_optimization": self.monoaxial_optimization, "force_all_isomers": self.force_all_isomers, "complex_name_length": self.complex_name_length, "complex_name_suffix": self.complex_name_suffix, } }
def _get_complex_name(self, avoid_names: Optional[Iterable[str]]) -> str: """ Generate a (pseudo-)random complex name based on the graph hash. :param avoid_names: Iterable of names that should be avoided when generating a new name. :type avoid_names: Iterable[str] | None :return: Generated complex name string. :rtype: str """ return get_complex_name(seed=self.graph_hash, length=self.complex_name_length, suffix=self.complex_name_suffix, avoid_names=avoid_names) def _get_ligandinfo(self) -> Dict[str, Any]: """ Return ligand metadata required to reconstruct Ligand objects in AssembledIsomer. :return: Dictionary with keys such as 'unique_names', 'archetypes', 'donor_idcs', etc. :rtype: dict[str, Any] """ return { # Important info for making Ligands() objects in the AssembledIsomer(). 'unique_names': [lig.unique_name for lig in self.ligands], 'archetypes': [lig.archetype for lig in self.ligands], 'donor_idcs': [lig.donor_idc for lig in self.ligands], 'charges': [lig.charge for lig in self.ligands], 'stoichiometries': [lig.stoichiometry for lig in self.ligands], 'hapdent_idcs': [lig.hapdent_idc for lig in self.ligands], 'geometric_isomers_hapdent_idcs': [lig.geometric_isomers_hapdent_idc for lig in self.ligands], # Convenience information for the output csv. 'donors': ['-'.join(sorted(lig.donor_elements)) for lig in self.ligands] } def _get_rotated_ligands(self, target_vectors, ligand_origins) -> list[list[Atoms]]: """ Produce rotated and translated ASE Atoms for each ligand according to target vectors. Each ligand yields a list of ASE Atoms objects representing the possible geometric isomers compatible with the supplied target_vectors. Returned structure is a list of lists: outer list over ligands, inner list over isomer instances (ASE Atoms). :param target_vectors: List of target vectors for each ligand. :type target_vectors: list[list[list[float]]] # shape: (n_ligands, n_donors, 3) :param ligand_origins: List of origin coordinates used to translate each ligand instance. :type ligand_origins: list[list[float]] # shape: (n_ligands, 3) :return: Rotated and translated ligand isomer instances. :rtype: list[list[ ase.Atoms ]] """ rotated_ligands = [] for ligand, target_vector_list, origin in zip(self.ligands, target_vectors, ligand_origins): # Extract the archetype and donor atoms of the effective ligand, potentially with 'Cu' dummy atoms for haptic ligands. atoms, donor_atoms = ligand.get_isomers_effective_ligand_atoms_with_effective_donor_indices(dummy='Cu') # Cast the target vectors to numpy arrays target_vector_list = [np.array(v) for v in target_vector_list] # Align the donor atoms of the ligand to the target vectors. Either make all possible geometrical isomers or just the ones specified in the MetaLig. For most cases these two should be identical, but making all combinations has two consequences: (a) the order of input target vectors does not matter (the one with the lowest error is always assembled) and (b) some archetypes have more isomers, e.g. the `trigonal` archetype has in theory three isomers in which simply the ligand is rotated, but for the MetaLig we had decided to filter out these isomers so that only one is kept. if self.force_all_isomers: ligand_isomers, _, _ = try_all_geometrical_isomer_possibilities(atoms=atoms, donor_idc=donor_atoms[0], target_vectors=target_vector_list) else: ligand_isomers = [align_donor_atoms(atoms, donor_idc=idc, target_vectors=target_vector_list, return_rssd=False) for idc in donor_atoms] # Remove the dummy atom from the haptic ligands if ligand.n_haptic_atoms > 0: ligand_isomers = [remove_haptic_dummy_atom(atoms=atoms, dummy_atom='Cu') for atoms in ligand_isomers] # Translate the ligand to its correct location in the complex for ligand_isomer in ligand_isomers: # Note: This method assumes that the ligand has been pre-translated to 0,0,0, which is the case for all ligands in the MetaLig database. ligand_isomer.set_positions(ligand_isomer.get_positions() + np.array(origin)) # Append the rotated ligands to the list rotated_ligands.append(ligand_isomers) return rotated_ligands def _get_all_unique_metal_centers(self) -> List[ase.Atom]: """ Return a list of unique ASE Atom metal centers used across all ligand definitions. The method deduplicates identical metal atoms (same element and position). :return: List of unique ASE Atom objects for all metal centers. :rtype: list[ ase.Atom ] """ unique_metal_centers = [self.metal_centers[0][0]] # initialize the list with the first metal center for metal_list in self.metal_centers: for metal in metal_list: metal_in_list = any([are_atoms_equal(metal, m) for m in unique_metal_centers]) if not metal_in_list: unique_metal_centers.append(metal) return unique_metal_centers def _get_merged_graph_from_ligands_and_metal_centers(self) -> tuple[nx.Graph, list, list]: """ Merge per-ligand connectivity graphs and connect donor atoms to metal nodes. Returns a merged NetworkX Graph in which metal centers occupy the first node indices followed by ligand atom nodes. Also returns ligand atom index lists (per ligand) and flattened donor index list. :return: Tuple (graph, ligand_indices, donor_idc) :rtype: tuple[ nx.Graph, list[list[int]], list[int] ] :raises AssertionError: If merged graph is not connected or node labels mismatch expected atom list. """ ligand_graphs = [deepcopy(lig.graph) for lig in self.ligands] unique_metal_centers = self._get_all_unique_metal_centers() # Create the new graph by merging everything graph = nx.Graph() for i, unique_metal_center in enumerate(unique_metal_centers): graph.add_nodes_from([(i, {"node_label": unique_metal_center.symbol})]) # Relabel the nodes of the old graphs so that they are unique for the next step i = len(unique_metal_centers) # start after the metals ligand_indices = [] for ligand_graph in ligand_graphs: node_mapping = {node: i + k for k, node in enumerate(sorted(ligand_graph.nodes))} nx.relabel_nodes(ligand_graph, mapping=node_mapping, copy=False) ligand_indices.append(list(node_mapping.values())) i += len(ligand_graph.nodes) # Copy the ligand graphs for ligand_graph in ligand_graphs: graph.add_nodes_from(ligand_graph.nodes(data=True)) # add ligand nodes graph.add_edges_from(ligand_graph.edges()) # add ligand edges # Connect the metal centers to the ligands ligand_donor_indices = [[] for _ in self.ligands] for i, (ligand, ligand_metal_centers, ligand_graph) in enumerate(zip(self.ligands, self.metal_centers, ligand_graphs)): for metal_center in ligand_metal_centers: unique_metal_center_idx = \ [i for i, atom in enumerate(unique_metal_centers) if are_atoms_equal(atom, metal_center)][0] for atomic_donor_idx in ligand.donor_idc: assert ligand.atomic_props['atoms'][ atomic_donor_idx] in ligand.donor_elements, f"Atom {ligand.atomic_props['atoms'][atomic_donor_idx]} is not a donor atom of ligand." graph_donor_idx = sorted(ligand_graph.nodes)[atomic_donor_idx] graph.add_edge(unique_metal_center_idx, graph_donor_idx) if graph_donor_idx not in ligand_donor_indices[i]: ligand_donor_indices[i].append(graph_donor_idx) # Check if everything is valid assert nx.is_connected(graph), "The graph is not fully connected!" assert all([set(ligand_donor_indices[i]).issubset(set(ligand_indices[i])) for i in range(len(ligand_indices))]), "The ligand donor indices are not subset of the ligand indices!" assert sorted(graph.nodes) == list( range(len(graph.nodes))), f"The graphs indices are not in order: {list(graph.nodes)}" all_atomic_elements = [unique_metal_center.symbol for unique_metal_center in unique_metal_centers] for ligand in self.ligands: all_atomic_elements += ligand.atomic_props['atoms'] all_graph_elements = [graph.nodes[node]['node_label'] for node in sorted(graph.nodes)] assert all_graph_elements == all_atomic_elements, f"The graph elements do not match the atomic elements: {all_graph_elements} vs {all_atomic_elements}!" atomic_donor_elements = sorted([el for lig in self.ligands for el in lig.donor_elements]) graph_donor_elements = sorted( [graph.nodes[node]['node_label'] for idc in ligand_donor_indices for node in sorted(graph.nodes) if node in idc]) assert atomic_donor_elements == graph_donor_elements, f"The atomic donor elements do not match the graph donor elements: {atomic_donor_elements} vs {graph_donor_elements}!" # For debugging: Plot the graph only for the metals and the coordination atoms # plot_graph = deepcopy(graph) # keep_idc = list(range(len(unique_metal_centers))) + [idx for idc in ligand_donor_indices for idx in idc] # for node in list(plot_graph.nodes): # if node not in keep_idc: # plot_graph.remove_node(node) # view_graph(plot_graph) # Flatten the ligand donor indices donor_idc = [idx for idc in ligand_donor_indices for idx in idc] return graph, ligand_indices, donor_idc
class _AxialOptModifier: """ Optimize mono-coordinating ligand rotations around their coordination axis. Uses a global differential evolution optimizer to find per-ligand rotation angles minimizing a distance-based penalty (short interatomic distances penalized). """ def __init__(self, isomers: List['AssembledIsomer'], opt: bool = True, distance_cutoff: Optional[float] = 4.0, use_cutoff: bool = False): """ Initialize the mono-axial optimization modifier. :param isomers: List of AssembledIsomer objects to optimize. :type isomers: list[ AssembledIsomer ] :param opt: Whether to perform optimization or return inputs unchanged. :type opt: bool :param distance_cutoff: Optional Å cutoff to restrict penalty to interatomic distances below this value. :type distance_cutoff: float | None :param use_cutoff: Whether to apply the provided distance_cutoff. :type use_cutoff: bool """ self.input_isomers = isomers self.opt_command = opt # Optional distance cutoff for the objective function (None = use all pairs) self.distance_cutoff: Optional[float] = distance_cutoff if use_cutoff else None self.output_isomers = [] logging.debug(f"AxialOpt initialized with {len(self.input_isomers)} AssembledIsomer objects.") def modify(self, target_vectors_list, ligand_origins_list, maxiter=1000, popsize=15) -> List['AssembledIsomer']: """ Optimize rotation angles for each provided isomer independently. For each isomer, a separate differential evolution run is executed to find the set of ligand rotation angles that minimize the interatomic distance penalty. Only ligands with archetypes '1-mono' or '2-trans' are rotated. :param target_vectors_list: List of target_vectors corresponding to each isomer in input_isomers. :type target_vectors_list: list[list[list[list[float]]]] # shape: (n_isomers, n_ligands, n_donors, 3) :param ligand_origins_list: Corresponding ligand origins for each isomer. :type ligand_origins_list: list[list[list[float]]] # shape: (n_isomers, n_ligands, 3) :param maxiter: Maximum iterations passed to differential_evolution. :type maxiter: int :param popsize: Population size passed to differential_evolution. :type popsize: int :return: List of AssembledIsomer objects with updated atom positions and atomic_props. :rtype: list[ AssembledIsomer ] :raises ValueError: If lengths of input_isomers, target_vectors_list, and ligand_origins_list differ. """ if not self.opt_command: return self.input_isomers # Clear output isomers each run self.output_isomers = [] # Sanity check lengths if len(self.input_isomers) != len(target_vectors_list) or len(self.input_isomers) != len(ligand_origins_list): raise ValueError("Each isomer must have its own set of target_vectors and ligand_origins.") # Run the optimizer. # Optimize each isomer independently for isomer, target_vectors, ligand_origins in zip(self.input_isomers, target_vectors_list, ligand_origins_list): atoms = isomer.atoms.copy() # Each ligand rotation angle gets its own bound bounds = [[0, 360] for _ in target_vectors] archetypes = [ligand.archetype for ligand in isomer.ligands] # Run the optimizer result = differential_evolution( self.objective_function, bounds=bounds, args=(target_vectors, ligand_origins, atoms.copy(), isomer.ligand_idc, archetypes), seed=42, maxiter=maxiter, popsize=popsize, polish=True ) best_ligand_angles = list(result.x) # Correctly apply rotations to this isomer's ligands for angle, axis, origin, idc, ligand in zip(best_ligand_angles, target_vectors, ligand_origins, isomer.ligand_idc, isomer.ligands): if ligand.archetype not in ['1-mono', '2-trans']: continue self.rotate(atoms=atoms, vector=axis, origin=origin, idc=idc, angle=angle) # Copy isomer before modification to avoid unintended side effects new_isomer = deepcopy(isomer) new_isomer.update_positions(atoms) self.output_isomers.append(new_isomer) logging.debug(f"Optimized {len(self.output_isomers)} complexes correctly.") return self.output_isomers def objective_function(self, x: np.ndarray, vectors_in: List[np.array], origins_in: List[np.array], TMC_in: ase.Atoms, ligand_idc: list[list[int]], archetypes: List[str]) -> float: """ Compute the penalty for a given set of rotation angles. The penalty is computed as the sum over 1.0 / d^2 for all atomic pairs (or pairs within the distance_cutoff), where d is the interatomic distance after applying ligand rotations defined by x. Rotations are only applied to ligands whose archetype is in ['1-mono', '2-trans']. :param x: Array of rotation angles in degrees for each ligand. :type x: np.ndarray # shape: (n_rotations,) :param vectors_in: List of rotation axes (unit or non-unit vectors) per ligand. :type vectors_in: list[ np.ndarray ] # each of shape (3,) or (2,3) for 2-trans :param origins_in: Origins (3 floats) for each ligand rotation. :type origins_in: list[list[float]] # shape: (n_rotations, 3) :param TMC_in: ASE Atoms object representing the assembled complex to act upon. :type TMC_in: ase.Atoms :param ligand_idc: List of lists containing atom indices for each ligand in the complex. :type ligand_idc: list[list[int]] :param archetypes: List of ligand archetypes in the same order as ligand_idc. :type archetypes: list[str] :return: Scalar penalty value to minimize. :rtype: float """ # Generate a copy of the input complex TMC_worker = TMC_in.copy() for angle, axis, origin, idc, archetype in zip(list(x), vectors_in, origins_in, ligand_idc, archetypes): if archetype not in ['1-mono', '2-trans']: continue self.rotate(atoms=TMC_worker, vector=axis, origin=origin, idc=idc, angle=angle) # Vectorised distance computation using condensed matrix positions = TMC_worker.positions d = pdist(positions) # condensed distance matrix as flat array (each pair once) # Apply optional distance cutoff if self.distance_cutoff is not None: d = d[d <= self.distance_cutoff] # Guard against zero distances d = d[d > 0.0] if d.size == 0: return 0.0 penalty = np.sum(1.0 / d ** 2) return penalty @staticmethod def rotate(atoms: Atoms, vector: np.array, origin: np.array, idc: List[int], angle: int): """ Rotate selected atoms around a given axis by the specified angle (degrees). Accepts axis definitions of shape (3,), (1,3) or (2,3). For (2,3) the first row is used (2-trans). Rotation is performed in-place on the provided ASE Atoms object. :param atoms: ASE Atoms object whose subset of atoms will be rotated in-place. :type atoms: ase.Atoms :param vector: Rotation axis vector (or matrix of axis vectors). :type vector: np.ndarray # shape: (3,) or (1,3) or (2,3) :param origin: Rotation center coordinates. :type origin: list[float] # length 3 :param idc: Indices of atoms to rotate. :type idc: list[int] :param angle: Rotation angle in degrees. :type angle: int | float :return: The modified ASE Atoms object (same object passed in). :rtype: ase.Atoms :raises ValueError: If the axis has zero length or unexpected shape. """ # Normalize vector shape and cast the list to a numpy array vector = np.asarray(vector, dtype=float) # If vector is 2D, reduce to 1D if vector.ndim == 2: if vector.shape == (2, 3): # If 2-trans archetype we should have two near opposite donor atom vectors # Check if donor atom vectors are opposite # if not np.allclose(vector[0], -vector[1], atol=1e-3): # logging.warning("rotate(): 2-trans target vectors not perfectly opposite; using first vector.") vector = vector[0] # warn if not opposite but still use the first vector elif vector.shape == (1, 3): vector = vector[0] else: raise ValueError(f"rotate(): Unexpected axis shape {vector.shape}. Expected (3,), (1,3), or (2,3).") elif vector.ndim != 1 or vector.shape[0] != 3: raise ValueError(f"rotate(): Axis must be a single 3D vector, got shape {vector.shape}") # --- Normalize direction (unit vector) --- norm = np.linalg.norm(vector) if norm == 0: raise ValueError("rotate(): Rotation axis has zero length.") vector = vector / norm origin = np.asarray(origin, dtype=float) # --- Perform rotation --- rotation = R.from_rotvec(np.radians(angle) * vector) idc_arr = np.asarray(idc, dtype=int) rel = atoms.positions[idc_arr] - origin # translate atoms.positions[idc_arr] = rotation.apply(rel) + origin # rotate & translate back return atoms def visualize_structures(self): """ Visualize input and optimized output complexes interleaved using ASE's viewer. The ASE viewer will show alternating frames: Input_0, Optimized_0, Input_1, Optimized_1, ... If no optimization has been performed, the method prints a message and returns. :return: None :rtype: None """ if not self.output_isomers: print("No output complexes found. Run opt_mono_rotation() first.") return structures_to_view = [] for i, (in_isomer, out_isomer) in enumerate(zip(self.input_isomers, self.output_isomers)): input_copy = in_isomer.atoms.copy() input_copy.info["label"] = f"Input {i}" output_copy = out_isomer.atoms.copy() output_copy.info["label"] = f"Optimized {i}" structures_to_view.extend([input_copy, output_copy]) # interleave input/output print(f"Launching viewer for {len(structures_to_view)} structures...") view(structures_to_view) class _DuplicateIsomerFilter: """ Reduce the number of assembled isomers by detecting duplicates via fingerprint or alignment. The class supports two main deduplication strategies: - 'distances' (fingerprint-based, fast) - 'alignment' (rotational alignment and heuristic comparison) """ def __init__(self, isomers: List['AssembledIsomer'], method: str = "distances", grid_size=9, isomer_comparison_mode: str = "max_diff", isomer_comparison_grouping_mode: str = "cutoff", # 'cluster' or 'cutoff' fingerprint_grouping_cutoff: float = 0.5, metal_centers: List[List[ase.Atom]] = None, energy_heuristic_mode: str = "max", ): """ Initialize the duplicate isomer filter parameters. :param isomers: List of AssembledIsomer objects to analyze. :type isomers: list[ AssembledIsomer ] :param method: Deduplication strategy: 'alignment' or 'distances'. :type method: str :param grid_size: Grid density used by brute-force alignment search (points per angle). :type grid_size: int :param isomer_comparison_mode: Mode for fingerprint comparison ('max_diff', 'sum_diff', 'mean_diff', 'rmsd'). :type isomer_comparison_mode: str :param isomer_comparison_grouping_mode: Grouping approach: 'cluster' or 'cutoff'. :type isomer_comparison_grouping_mode: str :param fingerprint_grouping_cutoff: Numeric cutoff used in 'cutoff' grouping mode. Use NaN to disable. :type fingerprint_grouping_cutoff: float :param metal_centers: Metal center definitions used to determine rotation degrees of freedom. :type metal_centers: list[list[ ase.Atom ]] | None :param energy_heuristic_mode: Mode for energy heuristic in alignment ('max' or 'sum'). :type energy_heuristic_mode: str """ if metal_centers is None: metal_centers = [] self.isomers = isomers self.method = method.lower() self.grid_size = grid_size # The number of grid points when scanning from 0 to 360 for exact alignment self.isomer_comparison_mode = isomer_comparison_mode self.isomer_comparison_grouping_mode = isomer_comparison_grouping_mode self.fingerprint_grouping_cutoff = fingerprint_grouping_cutoff self.check_duplicate = pd.notna(fingerprint_grouping_cutoff) self.metal_centres = metal_centers self.unique_metal_centers = list({(atom.symbol, tuple(atom.position)): atom for sublist in self.metal_centres for atom in sublist}.values()) self.diff_matrix = None # Placeholder for the fingerprint difference matrix self.energy_heuristic_mode = energy_heuristic_mode self.isomer_group = [] self.similarity_cutoff_used = None def get_duplicate_groups(self) -> List[List['AssembledIsomer']]: """ Group input isomers into duplicate clusters. If duplicate checking is disabled or only one isomer exists, returns trivial groups. Otherwise, delegates to the configured reduction method. :return: List of groups; each group is a list of AssembledIsomer objects considered duplicates. :rtype: list[list[ AssembledIsomer ]] """ if len(self.isomers) <= 1: return [self.isomers] if not self.check_duplicate: # If duplicate checking is disabled, return each isomer as its own group. return [[isomer] for isomer in self.isomers] if self.method == "alignment": self.isomer_group = self._reduce_by_alignment() elif self.method == "distances": self.isomer_group = self._reduce_by_fingerprint() else: raise ValueError(f"Fatal Error: Unsupported reduction method '{self.method}. Supported methods are 'alignment' and 'distances'.") return self.isomer_group def _reduce_by_alignment(self) -> List[List['AssembledIsomer']]: """ Reduce isomers by searching for an alignment that minimizes a distance heuristic. Rotational alignment between pairs is performed using a brute-force grid over allowed rotation axes and angles. The resulting pairwise scores are clustered to form duplicate groups. :return: Duplicate groups as lists of AssembledIsomer objects. :rtype: list[list[ AssembledIsomer ]] """ n = len(self.isomers) diff_matrix = np.zeros((n, n)) counter = 0 total = n * (n - 1) // 2 for i in range(n): isomer1 = self.isomers[i] for j in range(i + 1, n): isomer2 = self.isomers[j] logging.debug(f"Comparing isomers [{i}] and [{j}] ... [{counter}/{total}]") score = self.align_isomers(isomer1.atoms, isomer2.atoms) diff_matrix[i, j] = score diff_matrix[j, i] = score counter += 1 self.diff_matrix = diff_matrix group_labels_matrix = self._analyze_similarity( self.diff_matrix, quantile=0.2, method=self.isomer_comparison_grouping_mode, cutoff=self.fingerprint_grouping_cutoff ) self.isomer_group = self._group_isomers(group_labels_matrix, self.isomers) return self.isomer_group def energy_heuristic(self, stat_atoms: ase.Atoms, rot_atoms: ase.Atoms): """ Compute a heuristic 'energy' that quantifies similarity between two isomers. For each element type the method solves an optimal assignment between like atoms (Hungarian algorithm) and either returns the sum of assigned distances ('sum' mode) or the maximum assigned distance among element-types ('max' mode). :param stat_atoms: Stationary ASE Atoms object (reference). :type stat_atoms: ase.Atoms :param rot_atoms: Rotated ASE Atoms object (to be compared). :type rot_atoms: ase.Atoms :return: Energy heuristic: scalar representing sum or maximum of assigned distances. :rtype: float :raises ValueError: If elements mismatch or counts per element differ between isomers. """ # Initialize the total energy and overall max distance. total_energy = 0.0 overall_max_distance = 0.0 # Get chemical symbols from both isomers stat_symbols = stat_atoms.get_chemical_symbols() rot_symbols = rot_atoms.get_chemical_symbols() # Precompute indices for each element. stat_index_dict = {element: [i for i, sym in enumerate(stat_symbols) if sym == element] for element in set(stat_symbols)} rot_index_dict = {element: [i for i, sym in enumerate(rot_symbols) if sym == element] for element in set(rot_symbols)} # Loop over each unique element in the stationary isomer. for element, stat_indices in stat_index_dict.items(): if element not in rot_index_dict: raise ValueError(f"Element {element} missing in rotated isomer") rot_indices = rot_index_dict[element] if len(stat_indices) != len(rot_indices): raise ValueError(f"Mismatch in number of '{element}' atoms between isomers: " f"{len(stat_indices)} vs. {len(rot_indices)}") # Get positions for the current element. stat_positions = stat_atoms.positions[stat_indices] rot_positions = rot_atoms.positions[rot_indices] # Compute the distance matrix using C-optimized function. distance_matrix = cdist(stat_positions, rot_positions) # Solve the assignment problem to pair atoms optimally row_ind, col_ind = linear_sum_assignment(distance_matrix) if self.energy_heuristic_mode == "sum": total_energy += distance_matrix[row_ind, col_ind].sum() elif self.energy_heuristic_mode == "max": element_max = distance_matrix[row_ind, col_ind].max() overall_max_distance = max(overall_max_distance, element_max) else: raise ValueError("Invalid mode. Use 'sum' or 'max'.") return total_energy if self.energy_heuristic_mode == "sum" else overall_max_distance def align_isomers(self, stationary_atoms: ase.Atoms, rotated_atoms: ase.Atoms) -> float: """ Align two isomers by searching for an alignment that minimizes a distance heuristic. Depending on the number of unique metal centres, the function reduces the rotational degrees of freedom (3D rotation for one metal, 1D rotation around metal-metal axis for two metals, and direct heuristic for three or more metals). :param stationary_atoms: ASE Atoms kept fixed during alignment. :type stationary_atoms: ase.Atoms :param rotated_atoms: ASE Atoms that will be rotated to align with stationary_atoms. :type rotated_atoms: ase.Atoms :return: Alignment score computed by objective_function and energy_heuristic. :rtype: float :raises AssertionError: If metal_centers was not set prior to calling this method. """ assert hasattr(self, "metal_centers"), ValueError("Fatal Error: metal_centers must be defined before calling align_isomers. ") logging.debug(f"Aligning isomers based on {len(self.unique_metal_centers)} metal centre(s).") # Here the number of metal centres and the fact that metal centres of different isomers # must always be aligned is taken into account to determine if the isomers are similar or not if len(self.unique_metal_centers) == 1: # There are 3 axes which an isomer can be rotated around to align it with another isomer bounds = [[0, 360] for _ in range(3)] # The 3 cardinal axes, properly centered on the metal centre axes = np.eye(3) # Standard Cartesian axes (x, y, z) logging.debug("Performing 3D brute-force alignment over x, y, z axes.") elif len(self.unique_metal_centers) == 2: # The isomer can only be rotated around the metal-metal axis to determine if the isomers are similar or not bounds = [[0, 360] for _ in range(1)] axis_vector = self.unique_metal_centers[1].position - self.unique_metal_centers[0].position axis_vector /= np.linalg.norm(axis_vector) # Normalize axes = [axis_vector] logging.debug(f"Performing 1D brute-force alignment around axis: {axis_vector.tolist()}") elif len(self.unique_metal_centers) >= 3: # 3 metal centres means each isomer is fixed in space and their archetypes can be directly compared # We simply return the energy heuristic logging.debug("Three or more metal centres detected — skipping alignment and using direct heuristic.") return self.energy_heuristic(stat_atoms=stationary_atoms, rot_atoms=rotated_atoms) else: raise ValueError(f"Fatal Error: Unsupported number of metal centres ({len(self.unique_metal_centers)}). ") # Perform brute-force global optimization using the configured grid density result_angles = brute( func=self.objective_function, ranges=bounds, args=(stationary_atoms, rotated_atoms, axes), Ns=self.grid_size ) min_score = self.objective_function(np.array(result_angles), stationary_atoms, rotated_atoms, axes) logging.debug(f"Best alignment score after brute-force search: {min_score:.4f}") return min_score def objective_function(self, x: np.ndarray, atoms1: ase.Atoms, atoms2: ase.Atoms, axes: np.array): """ Objective used by the brute-force alignment routine. Applies combined rotations described by x around axes and returns the energy heuristic comparing atoms1 (reference) and the rotated version of atoms2. :param x: Rotation angles in degrees. :type x: np.ndarray :param atoms1: Stationary reference ASE Atoms. :type atoms1: ase.Atoms :param atoms2: ASE Atoms to be rotated. :type atoms2: ase.Atoms :param axes: List/array of rotation axes. :type axes: np.ndarray | list[np.ndarray] :return: Scalar value from energy_heuristic to minimize. :rtype: float """ # Copy the input atoms to avoid modifying the original stationary_isomer = atoms1.copy() rotated_isomer = atoms2.copy() # Precompute the combined rotation matrix. R_total = self.combined_rotation_matrix(x, axes) # Apply the rotation to the rotated isomer self.apply_combined_rotation(atoms=rotated_isomer, R_total=R_total, center=np.array(self.unique_metal_centers[0].position)) # calculate the energy heuristic that will be minimized val = self.energy_heuristic(stat_atoms=stationary_isomer, rot_atoms=rotated_isomer) return val @staticmethod def combined_rotation_matrix(angles, axes): """ Compute a combined 3x3 rotation matrix from sequences of angles and axes. :param angles: Iterable of angles in degrees. :type angles: Iterable[ float ] :param axes: Corresponding iterable of 3D axis vectors. :type axes: Iterable[ list[float] ] | np.ndarray :return: Combined rotation matrix (3x3). :rtype: np.ndarray """ # Start with identity matrix. R_total = np.eye(3) for angle, axis in zip(angles, axes): # Create a rotation for this angle and axis. r = R.from_rotvec(np.deg2rad(angle) * np.array(axis)) # Combine rotations. Note that the order matters! R_total = r.as_matrix() @ R_total return R_total @staticmethod def apply_combined_rotation(atoms, R_total, center): """ Apply a combined rotation matrix to an ASE Atoms object about a center. Positions are rotated in-place. :param atoms: ASE Atoms to rotate. :type atoms: ase.Atoms :param R_total: 3x3 rotation matrix. :type R_total: np.ndarray :param center: Center of rotation (3 floats). :type center: list[float] | np.ndarray :return: None :rtype: None """ # Shift positions relative to center, apply rotation, then shift back. shifted = atoms.positions - center atoms.positions = center + (shifted @ R_total.T) def _reduce_by_fingerprint(self): """ Reduce isomers using a sorted interatomic-distance fingerprint and clustering. Computes a symmetric difference matrix and clusters pairs labelled 'Close' to form groups. :return: List of duplicate groups (lists of AssembledIsomer objects). :rtype: list[list[ AssembledIsomer ]] """ self.diff_matrix = self._compute_fingerprint_matrix(self.isomers) # Cluster labels: "Close" or "Far" for each pair group_labels_matrix = self._analyze_similarity(self.diff_matrix, quantile=0.2, method=self.isomer_comparison_grouping_mode, cutoff=self.fingerprint_grouping_cutoff) self.isomer_group = self._group_isomers(group_labels_matrix, self.isomers) return self.isomer_group @staticmethod def _group_isomers(group_labels_matrix: np.ndarray, isomers: List['AssembledIsomer']) -> List[List['AssembledIsomer']]: """ Group isomers into connected components based on a pairwise label matrix. The input matrix should contain labels 'Close'/'Far' for each pair; connected 'Close' entries are grouped into duplicate clusters. The method ensures that groups preserve original ordering. :param group_labels_matrix: Square 2D numpy array with values 'Close' or 'Far'. :type group_labels_matrix: np.ndarray :param isomers: List of AssembledIsomer objects to group. :type isomers: list[ AssembledIsomer ] :return: Grouped isomers as a list of lists preserving original order within groups. :rtype: list[list[ AssembledIsomer ]] :raises ValueError: If group_labels_matrix is not square or does not match number of isomers. """ n = len(isomers) if group_labels_matrix.shape != (n, n): raise ValueError( "group_labels_matrix must be a square matrix with the same " "dimension as the number of isomers." ) # Build an undirected adjacency list adjacency = [[] for _ in range(n)] for i in range(n): for j in range(i + 1, n): if ( group_labels_matrix[i, j] == "Close" or group_labels_matrix[j, i] == "Close" ): adjacency[i].append(j) adjacency[j].append(i) # Depth-first search for connected components visited = [False] * n groups: List[List["AssembledIsomer"]] = [] for start in range(n): if visited[start]: continue stack = [start] component_indices: List[int] = [] while stack: node = stack.pop() if visited[node]: continue visited[node] = True component_indices.append(node) stack.extend(adjacency[node]) # Maintain original order inside the component component_indices.sort() groups.append([isomers[idx] for idx in component_indices]) assert len([_ for sublist in groups for _ in sublist]) == n, "Grouped indices do not cover all isomers or have duplicates." return groups def _analyze_similarity(self, matrix: np.ndarray, quantile: float = 0.2, method: str = "cluster", cutoff: Optional[float] = None) -> np.ndarray: """ Convert a pairwise difference matrix to a matrix of labels ('Close'/'Far'). Two modes are supported: - 'cluster': MeanShift clustering on upper-triangle values to identify the 'Close' cluster. - 'cutoff': Hard cutoff thresholding of matrix values. :param matrix: Square pairwise difference matrix. :type matrix: np.ndarray :param quantile: Quantile used to estimate bandwidth in clustering mode (unused default kept). :type quantile: float :param method: 'cluster' or 'cutoff' defining grouping strategy. :type method: str :param cutoff: Numeric cutoff required when method='cutoff'. :type cutoff: float | None :return: Square matrix with entries 'Close' or 'Far'. :rtype: np.ndarray :raises ValueError: If cutoff is None when method='cutoff'. """ if method == "cluster": from sklearn.cluster import MeanShift # Flatten the matrix values for clustering. # Upper triangle indices are used to avoid redundancy. triu_indices = np.triu_indices_from(matrix, k=1) values = matrix[triu_indices].reshape(-1, 1) # bandwidth = float(estimate_bandwidth(values, quantile=quantile, n_samples=len(values))) ms = MeanShift(bandwidth=0.5) ms.fit(values) cluster_labels = ms.labels_ cluster_centers = ms.cluster_centers_ # Identify the cluster whose center is closest to zero. close_cluster_idx = np.argmin(np.abs(cluster_centers)) # Extract the true cutoff as the largest value still labeled "Close" close_values = values[np.where(cluster_labels == close_cluster_idx)] self.similarity_cutoff_used = float(np.max(close_values)) # ensure it's a float, not a 0-d array # Reconstruct group label matrix full_labels = np.full(matrix.shape, "Far", dtype=object) full_labels[triu_indices] = np.where(cluster_labels == close_cluster_idx, "Close", "Far") full_labels[(triu_indices[1], triu_indices[0])] = full_labels[triu_indices] # mirror np.fill_diagonal(full_labels, "Far") # enforce diagonal is "Far" return full_labels elif method == "cutoff": if cutoff is None: raise ValueError("Fatal Error: Cutoff value must be provided when using 'cutoff' method.") group_labels = np.where(matrix <= cutoff, "Close", "Far") self.similarity_cutoff_used = cutoff return group_labels else: raise ValueError( f"Fatal Error: Unsupported similarity analysis method '{method}'. " f"Supported options: 'cluster', 'cutoff'.") def _compute_fingerprint_matrix(self, isomers: list) -> np.ndarray: """ Compute symmetric fingerprint difference matrix for a list of isomers. :param isomers: List of AssembledIsomer objects. :type isomers: list[ AssembledIsomer ] :return: Symmetric numpy array (n,n) of pairwise fingerprint differences. :rtype: np.ndarray """ # Generate fingerprints for each isomer fingerprints = [self._compute_sorted_distance_fingerprint(isomer)[0] for isomer in isomers] # Ensure all fingerprints have the same length assert all(len(fp) == len(fingerprints[0]) for fp in fingerprints) # Initialize a square matrix to hold the differences n = len(fingerprints) diff_matrix = np.zeros((n, n)) for i in range(n): for j in range(i + 1, n): # Only compute upper triangle diff = self._fingerprint_comparison(fingerprints[i], fingerprints[j], mode=self.isomer_comparison_mode) diff_matrix[i, j] = diff diff_matrix[j, i] = diff # Reflect to lower triangle return diff_matrix @staticmethod def _compute_sorted_distance_fingerprint(isomer) -> tuple[np.ndarray, list[tuple[str, str]]]: """ Compute a sorted inter-atomic distance fingerprint for an isomer. The fingerprint vector is the ordered list of pairwise distances between atoms, sorted first by element-pair (lexicographically) and then by ascending distance. Also returns the corresponding ordered element-pair labels. :param isomer: AssembledIsomer object with an `.atoms` ASE Atoms attribute. :type isomer: Any :return: Tuple (sorted_distances, sorted_element_pairs). :rtype: tuple[ np.ndarray, list[ tuple[str, str] ] ] """ positions = isomer.atoms.get_positions() # (N, 3) elements = np.asarray(isomer.atoms.get_chemical_symbols()) # Get pairwise distances and elements in a fully vectorized manner dists = pdist(positions) # length N*(N-1)//2 i_idx, j_idx = np.triu_indices(len(elements), k=1) # matching i<j indices e_i, e_j = elements[i_idx], elements[j_idx] # Sort elements to ensure that (C-H) and (H-C) are treated the same swap = e_i > e_j first_elem = np.where(swap, e_j, e_i) second_elem = np.where(swap, e_i, e_j) # Global sort: primary key = elements, secondary key = distance order = np.lexsort((dists, second_elem, first_elem)) # last key is primary sorted_dists = dists[order] # sort distances according to the order # also get sorted element pairs sorted_pairs = list(zip(first_elem[order], second_elem[order])) return sorted_dists, sorted_pairs @staticmethod def _fingerprint_comparison(fp1: np.ndarray, fp2: np.ndarray, mode: str = "max_diff"): """ Compare two fingerprint vectors with various metrics. Supported modes: - 'max_diff': maximum absolute element-wise difference - 'sum_diff': sum of absolute differences - 'mean_diff': mean absolute difference - 'rmsd': root mean square difference :param fp1: Fingerprint array for isomer 1. :type fp1: np.ndarray :param fp2: Fingerprint array for isomer 2. :type fp2: np.ndarray :param mode: Comparison mode string. :type mode: str :return: Scalar score representing fingerprint difference according to mode. :rtype: float :raises ValueError: If an unsupported mode is supplied. """ logging.debug(f"Comparing fingerprints with mode '{mode}'.") if mode == "max_diff": return np.max(np.abs(fp1 - fp2)) elif mode == "sum_diff": return np.sum(np.abs(fp1 - fp2)) elif mode == "mean_diff": return np.mean(np.abs(fp1 - fp2)) elif mode == "rmsd": return np.sqrt(np.mean((fp1 - fp2) ** 2)) else: raise ValueError( f"Fatal Error: Unsupported fingerprint comparison mode '{mode}'. " f"Supported modes are 'max_diff', 'sum_diff', 'mean_diff', and 'rmsd'.") def plot_fingerprint_difference_matrix(self, write_svg: bool = True, plot_plotly: bool = True, colorscale_min: float = 0.0, colorscale_mid: float = 0.5, color_scale_max: float = 1.0, min_color: Optional[dict] = None, mid_color: Optional[dict] = None, max_color: Optional[dict] = None, cell_label_mode: str = "value", # "value", "group", or "none" ) -> None: """ Visualize the fingerprint difference matrix as a heatmap (Matplotlib and optional Plotly). Produces an SVG file and optionally launches an interactive Plotly Dash app. The function will compute similarity grouping if not already available. :param write_svg: Whether to save a Matplotlib SVG of the heatmap. :type write_svg: bool :param plot_plotly: Whether to launch an interactive Plotly/Dash visualization. :type plot_plotly: bool :param colorscale_min: Lower bound for custom colorscale interpolation. :type colorscale_min: float :param colorscale_mid: Midpoint for colorscale interpolation. :type colorscale_mid: float :param color_scale_max: Upper bound for colorscale interpolation. :type color_scale_max: float :param min_color: Dictionary with 'r','g','b' ints for minimum color. :type min_color: dict | None :param mid_color: Dictionary with 'r','g','b' ints for mid color. :type mid_color: dict | None :param max_color: Dictionary with 'r','g','b' ints for maximum color. :type max_color: dict | None :param cell_label_mode: 'value' to show numeric differences, 'group' to show Close/Far, or 'none'. :type cell_label_mode: str :return: None :rtype: None """ from matplotlib.colors import LinearSegmentedColormap import matplotlib.pyplot as plt assert cell_label_mode in {"value", "group", "none"}, ValueError( f"Fatal Error: Unsupported cell label mode '{cell_label_mode}'. " "Supported modes are 'value', 'group', and 'none'.") # Fallback defaults for colors min_color = {"r": 238, "g": 100, "b": 97} if min_color is None else min_color mid_color = {"r": 255, "g": 255, "b": 255} if mid_color is None else mid_color max_color = {"r": 12, "g": 171, "b": 185} if max_color is None else max_color if self.diff_matrix is None: self.get_duplicate_groups() for isomer_group in self.isomer_group: for i in range(1, len(isomer_group)): isomer_group[i].warning = 'duplicate' # Mark all but the first isomer in the group as duplicate. labels_list = [f"{i}" for i in range(len(self.isomers))] df = pd.DataFrame(self.diff_matrix, index=labels_list, columns=labels_list) group_labels_matrix = self._analyze_similarity( matrix=self.diff_matrix, method=self.isomer_comparison_grouping_mode, cutoff=self.fingerprint_grouping_cutoff ) # --- Matplotlib --- if write_svg: cmap = LinearSegmentedColormap.from_list("custom_rwb", [ (colorscale_min, tuple(np.array(list(min_color.values())) / 255)), (colorscale_mid, tuple(np.array(list(mid_color.values())) / 255)), (color_scale_max, tuple(np.array(list(max_color.values())) / 255)) ]) fig, ax = plt.subplots(figsize=(8, 8)) cax = ax.imshow(df.values, cmap=cmap, vmin=0, vmax=1.0) def get_contrasting_text_color(value, vmin=0.0, vmax=0.5): norm_val = (value - vmin) / (vmax - vmin) r, g, b = cmap(norm_val)[:3] luminance = 0.299 * r + 0.587 * g + 0.114 * b return 'black' if luminance > 0.5 else 'white' for i in range(len(df)): for j in range(len(df)): label = "" if cell_label_mode == "value": label = f"{df.values[i, j]:.2f}" elif cell_label_mode == "group": label = group_labels_matrix[i, j] if label: color = get_contrasting_text_color(df.values[i, j]) ax.text(j, i, label, ha='center', va='center', color=color, fontsize=8) ax.set_xticks(np.arange(len(df))) ax.set_yticks(np.arange(len(df))) ax.set_xticklabels(labels_list, rotation=0, ha='center', fontsize=12) ax.set_yticklabels(labels_list, rotation=0, ha='right', fontsize=12) ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) fig.colorbar(cax, ax=ax, fraction=0.046, pad=0.04, label="Difference") ax.set_title("Comparison Matrix (Matplotlib)") plt.tight_layout() plt.savefig("Isomer_comparison_matrix.svg", format="svg") plt.close() # --- Plotly --- if plot_plotly: color_scale = [ [colorscale_min, f"rgb({min_color['r']},{min_color['g']},{min_color['b']})"], [colorscale_mid, f"rgb({mid_color['r']},{mid_color['g']},{mid_color['b']})"], [color_scale_max, f"rgb({max_color['r']},{max_color['g']},{max_color['b']})"] ] self.launch_interactive_heatmap(df, group_labels_matrix, color_scale, cell_label_mode) def launch_interactive_heatmap(self, df: pd.DataFrame, group_labels_matrix: np.ndarray, color_scale: list, cell_label_mode: str = "value"): """ Launch an interactive Dash application showing the difference matrix and histogram. Clicking on a heatmap cell opens the ASE viewer with the corresponding isomer alignment. :param df: Pandas DataFrame representing the difference matrix (square). :type df: pd.DataFrame :param group_labels_matrix: Square matrix of 'Close'/'Far' labels. :type group_labels_matrix: np.ndarray :param color_scale: Plotly-compatible color scale definition. :type color_scale: list :param cell_label_mode: Mode for cell labels ('value', 'group', 'none'). :type cell_label_mode: str :return: None :rtype: None """ if "plotly" not in sys.modules: print("Plotly is not installed. Please install it to use the interactive heatmap feature.") return None app = dash.Dash(__name__) fig = px.imshow( df, color_continuous_scale=color_scale, zmin=0, zmax=1.0, title="Comparison Matrix (Interactive)", labels={"x": "AssembledIsomer Index", "y": "AssembledIsomer Index", "color": "Difference"} ) # Get the upper triangle indices (excluding the diagonal) triu_indices = np.triu_indices_from(df.values, k=1) # Extract the values from the upper triangle (excluding diagonal) upper_triangle_values = df.values[triu_indices] # Create the histogram fig_hist = px.histogram( x=upper_triangle_values, nbins=400, title="Histogram of Comparison Values (Upper Triangle Only)", labels={'x': 'Difference Value', 'y': 'Count'} ) if self.similarity_cutoff_used is not None: fig_hist.add_vline( x=self.similarity_cutoff_used, line_width=3, line_dash="dash", line_color="red", annotation_text=f"Cutoff = {self.similarity_cutoff_used:.2f}", annotation_position="top left" ) annotations = [] for i in range(len(df)): for j in range(len(df)): value = df.values[i, j] label = "" if cell_label_mode == "value": label = f"{value:.2f}" elif cell_label_mode == "group": label = group_labels_matrix[i, j] if label: # Compute contrast-aware color r, g, b = [int(c) for c in color_scale[min(2, int(3 * value))][1][4:-1].split(',')] luminance = 0.299 * r + 0.587 * g + 0.114 * b font_color = "black" if luminance > 186 else "white" annotations.append(dict( x=j, y=i, text=label, showarrow=False, font=dict(color=font_color, size=10), xanchor="center", yanchor="middle" )) fig.update_layout( annotations=annotations, dragmode="zoom", hovermode="closest", height=800, width=800 ) app.layout = dash.html.Div([ dash.dcc.Graph(id="heatmap", figure=fig), dash.dcc.Graph(id="histogram", figure=fig_hist), ]) @app.callback( dash.Output("heatmap", "figure"), dash.Input("heatmap", "clickData") ) def display_click(clickData): if clickData: point = clickData["points"][0] i_idx = int(point["y"]) j_idx = int(point["x"]) print(f"Clicked cell: ({i_idx}, {j_idx}) — opening viewer.") # Direct call to viewer self.view_isomer_alignment(i_idx, j_idx, grid_size=self.grid_size) return fig app.run( debug=True, use_reloader=False, **{ "threaded": False, "processes": 1, "use_debugger": True, "dev_tools_silence_routes_logging": False, "dev_tools_prune_errors": False } ) def view_isomer_alignment(self, index1: int, index2: int, grid_size=None) -> None: """ Visualize two isomers and their optimal alignment using ASE viewer. Displays three frames: reference isomer (index1), unaligned isomer (index2), and overlaid result after aligning isomer2 to isomer1. :param index1: Index of the reference isomer in self.isomers. :type index1: int :param index2: Index of the isomer to align and overlay. :type index2: int :param grid_size: Optional grid density for brute alignment; defaults to object's grid_size. :type grid_size: int | None :return: None :rtype: None :raises AssertionError: If indices are out of range. """ assert 0 <= index1 < len(self.isomers), f"Index1 out of range: {index1}" assert 0 <= index2 < len(self.isomers), f"Index2 out of range: {index2}" print("1") isomer1 = self.isomers[index1].atoms.copy() print("2") isomer2 = self.isomers[index2].atoms.copy() print("3") isomer2_aligned = self.isomers[index2].atoms.copy() print("4") # Determine alignment axes if len(self.unique_metal_centers) == 1: bounds = [[0, 360] for _ in range(3)] axes = np.eye(3) elif len(self.unique_metal_centers) == 2: bounds = [[0, 360]] axis_vector = self.unique_metal_centers[1].position - self.unique_metal_centers[0].position axis_vector /= np.linalg.norm(axis_vector) axes = [axis_vector] elif len(self.unique_metal_centers) >= 3: bounds = None axes = None print("Three or more metal centres — skipping rotation; showing structures unaligned.") else: raise ValueError("Fatal Error: Invalid number of metal centres for alignment.") # Perform rotation if applicable if len(self.unique_metal_centers) < 3: result_angles = brute( func=self.objective_function, ranges=bounds, args=(isomer1, isomer2_aligned, axes), Ns=grid_size if grid_size else self.grid_size ) R_total = self.combined_rotation_matrix(result_angles, axes) self.apply_combined_rotation(atoms=isomer2_aligned, R_total=R_total, center=np.array(self.unique_metal_centers[0].position)) # Create overlaid image: isomer1 + rotated isomer2 overlaid = isomer1.copy() + isomer2_aligned.copy() # Assign tags so colors are distinguishable for atom in isomer1: atom.tag = 1 for atom in isomer2: atom.tag = 2 for atom in overlaid: atom.tag = 3 print("Launching ASE viewer with aligned isomers...") view([isomer1, isomer2, overlaid], viewer="ase") def debug_fingerprints(self, idx1, idx2): """ Display and return a DataFrame comparing two isomer fingerprints for debugging. The DataFrame contains distances, element pairs and absolute differences to help interpret which entries contribute most to the fingerprint score. :param idx1: Index of the first isomer to compare. :type idx1: int :param idx2: Index of the second isomer to compare. :type idx2: int :return: Pandas DataFrame summarizing distances, pairs and differences. :rtype: pd.DataFrame :raises ValueError: If fingerprints differ in length. """ # Get the fingerprints of the two isomers fp1, pairs1 = self._compute_sorted_distance_fingerprint(self.isomers[idx1]) fp2, pairs2 = self._compute_sorted_distance_fingerprint(self.isomers[idx2]) # Ensure both fingerprints are of the same length if len(fp1) != len(fp2): raise ValueError(f"Fingerprints of isomers {idx1} and {idx2} have different lengths: {len(fp1)} vs {len(fp2)}") # Create a DataFrame for visualization df = pd.DataFrame({ "Distance": np.arange(len(fp1)), f"Isomer {idx1}": fp1, f"Isomer {idx2}": fp2, "Pair": [f"{pair[0]}-{pair[1]}" for pair in pairs1], "Difference": np.abs(fp1 - fp2), }) # display the results print(df) return df class _IsomerClashFilter: """ Filter assembled isomers for interatomic clashes based on covalent radii. Clashes are determined by comparing pairwise distances to the sum of covalent radii plus a buffer. Pairs within the same ligand are ignored. Metal-involving pairs can optionally be excluded from checks. """ def __init__( self, buffer: float = -0.3, check_metal_clashes: bool = False ): """ Initialize clash filter configuration. :param buffer: Distance buffer (Å) added to sum of covalent radii. Negative allows closer approach. :type buffer: float :param check_metal_clashes: If False, pairs involving any metal atom are ignored. :type check_metal_clashes: bool """ self.buffer = buffer self.check_metal_clashes = check_metal_clashes def has_clashing_atoms( self, atoms: ase.Atoms, ligand_idc: list[list[int]], metal_idc: list[int], ) -> bool: """ Determine whether the provided ASE Atoms contains any clashing atom pairs. Ignores intra-ligand pairs and optionally ignores any pair involving metal atoms. Clashing is True if any remaining interatomic distance is smaller than the corresponding sum of covalent radii plus the configured buffer. :param atoms: ASE Atoms object representing the assembled isomer. :type atoms: ase.Atoms :param ligand_idc: List of lists of atom indices corresponding to each ligand. :type ligand_idc: list[list[int]] :param metal_idc: List of indices of metal center atoms within `atoms`. :type metal_idc: list[int] :return: True if any non-ignored pair is closer than allowed, otherwise False. :rtype: bool """ n = len(atoms) if n <= 1: # nothing to clash return False pos = atoms.positions # (N, 3) symbols = atoms.get_chemical_symbols() # list[str] # Get vectorized covalent radii radii_map = {s: Element(s).covalent_radius_angstrom for s in set(symbols)} radii = np.fromiter((radii_map[s] for s in symbols), dtype=float) # Get arrays of pairwise distances and minimum allowed distances dists = pdist(pos) i_idx, j_idx = np.triu_indices(n, k=1) # matches pdist order min_allowed = radii[i_idx] + radii[j_idx] + self.buffer # Build masks to exclude intra-ligand distances and optionally metal distances mask = np.ones_like(dists, dtype=bool) # (a) skip intra-ligand pairs ligand_id_mask = np.full(n, -1, dtype=np.int32) for ligand_id, ligand_indices in enumerate(ligand_idc): ligand_id_mask[ligand_indices] = ligand_id same_ligand = (ligand_id_mask[i_idx] != -1) & (ligand_id_mask[i_idx] == ligand_id_mask[j_idx]) mask &= ~same_ligand # (b) optionally skip any pair that involves a metal atom, i.e. ligand-metal or metal-metal clashes metal_idc = np.asarray(metal_idc, dtype=np.int32) if not self.check_metal_clashes and metal_idc.size: metal_pair = np.isin(i_idx, metal_idc) | np.isin(j_idx, metal_idc) mask &= ~metal_pair if not np.any(mask): # nothing left to check return False # Check if any distances are below the allowed minimum which means atoms are clashing has_clashes = np.any(dists[mask] < min_allowed[mask]) return has_clashes