Skip to content

API Reference

agave_chem

agave_chem initialization.

IdenticalFragmentMapper

Bases: ReactionMapper

Reaction mapper that identifies and atom-maps fragments appearing identically on both sides of a reaction.

Source code in agave_chem/mappers/identical_fragments/identical_fragment_mapper.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
class IdenticalFragmentMapper(ReactionMapper):
    """
    Reaction mapper that identifies and atom-maps fragments appearing
    identically on both sides of a reaction.

    A fragment is one ``.``-separated component of a reaction side. Two
    fragments are considered identical when ``canonicalize_smiles`` returns the
    same string for both, so differently written SMILES of the same molecule
    are paired as well.
    """

    # First atom-map number assigned to atoms of identical fragments.
    # NOTE(review): presumably chosen high so that it does not collide with
    # the map numbers produced by downstream mappers - confirm.
    _FIRST_IDENTICAL_FRAGMENT_ATOM_MAP_NUM = 500

    def __init__(self, mapper_name: str, mapper_weight: float = 1):
        """
        Initialize the mapper with the fixed mapper type ``"identical_fragment"``.

        Args:
            mapper_name (str): Name forwarded to the ``ReactionMapper`` base.
            mapper_weight (float): Weight forwarded to the ``ReactionMapper``
                base. Defaults to 1.
        """
        super().__init__("identical_fragment", mapper_name, mapper_weight)

    def _atom_map_identical_fragments(
        self, reaction_smiles: str
    ) -> Tuple[List[str], str]:
        """
        Atom map identical fragments in a reaction SMILES string.

        Every reactant fragment is paired with at most one product fragment
        that has the same canonical SMILES. Paired fragments are removed from
        both sides and returned once, with consecutive atom-map numbers
        starting at ``_FIRST_IDENTICAL_FRAGMENT_ATOM_MAP_NUM``. Pairing is
        one-to-one, so a fragment occurring twice in the reactants but once in
        the products is only stripped once.

        Args:
            reaction_smiles (str): A reaction SMILES string.

        Returns:
            Tuple[List[str], str]:
                - First element: A list of mapped identical fragment SMILES.
                - Second element: The remaining reaction SMILES with identical
                  fragments removed from both sides.
        """
        reactants, products = self._split_reaction_components(reaction_smiles)

        products_smiles_list = products.split(".")
        # Kept index-aligned with ``products_smiles_list`` so that a canonical
        # match can be removed from both lists at the same position.
        canonicalized_products_smiles_list = [
            canonicalize_smiles(smiles) for smiles in products_smiles_list
        ]

        remaining_reactants_smiles_list: List[str] = []
        atom_mapped_identical_reactants_products: List[str] = []
        atom_map_num = self._FIRST_IDENTICAL_FRAGMENT_ATOM_MAP_NUM
        for reactant in reactants.split("."):
            canonicalized_reactant = canonicalize_smiles(reactant)
            # An empty/None canonical form (e.g. from an empty reaction side)
            # can never be an identical fragment; leave it where it is.
            if (
                not canonicalized_reactant
                or canonicalized_reactant not in canonicalized_products_smiles_list
            ):
                remaining_reactants_smiles_list.append(reactant)
                continue

            reactant_mol = Chem.MolFromSmiles(canonicalized_reactant)
            if reactant_mol is None:
                # RDKit could not parse the fragment; keep it on both sides
                # instead of failing on ``None.GetAtoms()``.
                remaining_reactants_smiles_list.append(reactant)
                continue

            # Remove the matched product by position rather than by the
            # reactant's raw string: the product may be spelled differently,
            # and consuming it prevents it from being matched a second time.
            product_idx = canonicalized_products_smiles_list.index(
                canonicalized_reactant
            )
            del products_smiles_list[product_idx]
            del canonicalized_products_smiles_list[product_idx]

            for atom in reactant_mol.GetAtoms():
                atom.SetAtomMapNum(atom_map_num)
                atom_map_num += 1
            atom_mapped_identical_reactants_products.append(
                Chem.MolToSmiles(reactant_mol)
            )

        return (
            atom_mapped_identical_reactants_products,
            ".".join(remaining_reactants_smiles_list)
            + ">>"
            + ".".join(products_smiles_list),
        )

    def _add_identical_fragments_to_mapping(
        self,
        mapped_reaction_smiles: str,
        atom_mapped_identical_reactants_products: List[str],
    ) -> str:
        """
        Add identical fragments back to a mapped reaction SMILES string.

        Args:
            mapped_reaction_smiles (str): A mapped reaction SMILES string.
            atom_mapped_identical_reactants_products (List[str]): A list of
                atom-mapped identical fragment SMILES to append to both sides.

        Returns:
            str: A mapped reaction SMILES string with identical fragments added.
        """
        reactants, products = self._split_reaction_components(mapped_reaction_smiles)

        # Drop empty components: when every fragment of a side was identical
        # the side is "", and "".split(".") yields [""], which would otherwise
        # produce a SMILES with a leading "." after the fragments are appended.
        reactants_smiles_list = [smiles for smiles in reactants.split(".") if smiles]
        products_smiles_list = [smiles for smiles in products.split(".") if smiles]

        reactants_smiles_list.extend(atom_mapped_identical_reactants_products)
        products_smiles_list.extend(atom_mapped_identical_reactants_products)

        return ".".join(reactants_smiles_list) + ">>" + ".".join(products_smiles_list)

    def create_identical_fragments_mapping_list(
        self,
        reaction_smiles_list: List[str],
    ) -> Tuple[List[str], List[List[str]]]:
        """
        Strip identical fragments from a list of reactions for downstream mapping.

        Args:
            reaction_smiles_list (List[str]): A list of reaction SMILES strings.

        Returns:
            Tuple[List[str], List[List[str]]]:
                - First element: Reaction SMILES with identical fragments removed.
                - Second element: Per-reaction lists of atom-mapped identical
                  fragment SMILES to be re-added after downstream mapping.
        """
        new_rxns: List[str] = []
        identical_fragments_mapping_list: List[List[str]] = []
        for reaction_smiles in reaction_smiles_list:
            atom_mapped_identical_fragments, new_rxn = (
                self._atom_map_identical_fragments(reaction_smiles)
            )
            identical_fragments_mapping_list.append(atom_mapped_identical_fragments)
            new_rxns.append(new_rxn)
        return new_rxns, identical_fragments_mapping_list

    def resolve_identical_fragments_mapping_list(
        self,
        mapped_reaction_smiles_list: List[str],
        identical_fragments_mapping_list: List[List[str]],
    ) -> List[str]:
        """
        Re-add identical fragments to a list of already-mapped reaction SMILES.

        The two lists are walked in lockstep with ``zip``; if their lengths
        differ, the surplus entries of the longer list are ignored.

        Args:
            mapped_reaction_smiles_list (List[str]): Mapped reaction SMILES strings
                (from a downstream mapper).
            identical_fragments_mapping_list (List[List[str]]): Per-reaction lists of
                atom-mapped identical fragment SMILES (produced by
                ``create_identical_fragments_mapping_list``).

        Returns:
            List[str]: Final reaction SMILES strings with identical fragments restored.
        """
        return [
            self._add_identical_fragments_to_mapping(
                mapped_reaction_smiles, identical_fragments_mapping
            )
            for mapped_reaction_smiles, identical_fragments_mapping in zip(
                mapped_reaction_smiles_list, identical_fragments_mapping_list
            )
        ]

    def map_reaction(self, reaction_smiles: str) -> ReactionMapperResult:
        """
        Map a single reaction by atom-mapping its identical fragments.

        Args:
            reaction_smiles (str): A reaction SMILES string.

        Returns:
            ReactionMapperResult: Mapping result. If the input is invalid, or the
                resulting mapping fails validation, the default result from
                ``_return_default_mapping_dict`` is returned.
        """
        if not self._reaction_smiles_valid(reaction_smiles):
            return self._return_default_mapping_dict(reaction_smiles)

        atom_mapped_fragments, remaining_rxn = self._atom_map_identical_fragments(
            reaction_smiles
        )

        if atom_mapped_fragments:
            mapped_reaction_smiles = self._add_identical_fragments_to_mapping(
                remaining_rxn, atom_mapped_fragments
            )
        else:
            # Nothing to map: hand the input back unchanged.
            mapped_reaction_smiles = reaction_smiles

        # Only identical fragments carry map numbers, so a full mapping of
        # every atom is not expected here.
        if not self._verify_validity_of_mapping(
            mapped_reaction_smiles, expect_full_mapping=False
        ):
            logger.warning("Invalid mapping")
            return self._return_default_mapping_dict(reaction_smiles)

        return ReactionMapperResult(
            original_smiles=reaction_smiles,
            selected_mapping=mapped_reaction_smiles,
            possible_mappings={},
            mapping_type=self._mapper_type,
            mapping_score=None,
            additional_info=[{}],
        )

    def map_reactions(self, reaction_list: List[str]) -> List[ReactionMapperResult]:
        """
        Map a list of reaction SMILES strings using the identical-fragment mapper.

        Args:
            reaction_list (List[str]): List of reaction SMILES strings to map.

        Returns:
            List[ReactionMapperResult]: The mapping results in the same order as the
                input reactions.
        """
        return [self.map_reaction(reaction) for reaction in reaction_list]

create_identical_fragments_mapping_list(reaction_smiles_list)

Strip identical fragments from a list of reactions for downstream mapping.

Parameters:

Name Type Description Default
reaction_smiles_list List[str]

A list of reaction SMILES strings.

required

Returns:

Type Description
Tuple[List[str], List[List[str]]]

A pair of lists. The first element holds the reaction SMILES with identical fragments removed; the second element holds, per reaction, the list of atom-mapped identical fragment SMILES to be re-added after downstream mapping.

Source code in agave_chem/mappers/identical_fragments/identical_fragment_mapper.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def create_identical_fragments_mapping_list(
    self,
    reaction_smiles_list: List[str],
) -> Tuple[List[str], List[List[str]]]:
    """
    Remove identical fragments from each reaction ahead of downstream mapping.

    Each reaction is passed, in input order, to
    ``self._atom_map_identical_fragments``, which yields the atom-mapped
    identical fragments and the reaction with those fragments stripped.

    Args:
        reaction_smiles_list (List[str]): Reaction SMILES strings to process.

    Returns:
        Tuple[List[str], List[List[str]]]:
            - First element: The stripped reaction SMILES, one per input.
            - Second element: For each input, the list of atom-mapped
              identical fragment SMILES that should be restored once the
              downstream mapping is done.
    """
    per_reaction = [
        self._atom_map_identical_fragments(rxn) for rxn in reaction_smiles_list
    ]
    stripped_reactions = [stripped for _, stripped in per_reaction]
    fragments_per_reaction = [fragments for fragments, _ in per_reaction]
    return stripped_reactions, fragments_per_reaction

map_reaction(reaction_smiles)

Map a single reaction by atom-mapping its identical fragments.

Parameters:

Name Type Description Default
reaction_smiles str

A reaction SMILES string.

required

Returns:

Name Type Description
ReactionMapperResult ReactionMapperResult

Mapping result. If the input is invalid, an empty default result is returned.

Source code in agave_chem/mappers/identical_fragments/identical_fragment_mapper.py
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def map_reaction(self, reaction_smiles: str) -> ReactionMapperResult:
    """
    Atom-map the identical fragments of one reaction.

    Args:
        reaction_smiles (str): The reaction SMILES string to map.

    Returns:
        ReactionMapperResult: The mapping result. When the input fails
            ``_reaction_smiles_valid`` or the produced mapping fails
            ``_verify_validity_of_mapping``, the value of
            ``_return_default_mapping_dict`` is returned instead.
    """
    if not self._reaction_smiles_valid(reaction_smiles):
        return self._return_default_mapping_dict(reaction_smiles)

    fragments, stripped_rxn = self._atom_map_identical_fragments(reaction_smiles)

    # With no identical fragment found, the input is used as the mapping as-is.
    candidate_mapping = (
        self._add_identical_fragments_to_mapping(stripped_rxn, fragments)
        if fragments
        else reaction_smiles
    )

    mapping_is_valid = self._verify_validity_of_mapping(
        candidate_mapping, expect_full_mapping=False
    )
    if mapping_is_valid:
        return ReactionMapperResult(
            original_smiles=reaction_smiles,
            selected_mapping=candidate_mapping,
            possible_mappings={},
            mapping_type=self._mapper_type,
            mapping_score=None,
            additional_info=[{}],
        )

    logger.warning("Invalid mapping")
    return self._return_default_mapping_dict(reaction_smiles)

map_reactions(reaction_list)

Map a list of reaction SMILES strings using the identical-fragment mapper.

Parameters:

Name Type Description Default
reaction_list List[str]

List of reaction SMILES strings to map.

required

Returns:

Type Description
List[ReactionMapperResult]

The mapping results, in the same order as the input reactions.

Source code in agave_chem/mappers/identical_fragments/identical_fragment_mapper.py
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
def map_reactions(self, reaction_list: List[str]) -> List[ReactionMapperResult]:
    """
    Apply ``map_reaction`` to every reaction SMILES in a list.

    Args:
        reaction_list (List[str]): Reaction SMILES strings to map.

    Returns:
        List[ReactionMapperResult]: One result per input reaction, in input
            order.
    """
    return [self.map_reaction(rxn) for rxn in reaction_list]

resolve_identical_fragments_mapping_list(mapped_reaction_smiles_list, identical_fragments_mapping_list)

Re-add identical fragments to a list of already-mapped reaction SMILES.

Parameters:

Name Type Description Default
mapped_reaction_smiles_list List[str]

Mapped reaction SMILES strings (from a downstream mapper).

required
identical_fragments_mapping_list List[List[str]]

Per-reaction lists of atom-mapped identical fragment SMILES (produced by create_identical_fragments_mapping_list).

required

Returns:

Type Description
List[str]

The final reaction SMILES strings, with the identical fragments restored on both sides.

Source code in agave_chem/mappers/identical_fragments/identical_fragment_mapper.py
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def resolve_identical_fragments_mapping_list(
    self,
    mapped_reaction_smiles_list: List[str],
    identical_fragments_mapping_list: List[List[str]],
) -> List[str]:
    """
    Restore identical fragments on a batch of already-mapped reactions.

    The two inputs are paired positionally with ``zip``, so pairing stops at
    the end of the shorter list.

    Args:
        mapped_reaction_smiles_list (List[str]): Mapped reaction SMILES
            strings (from a downstream mapper).
        identical_fragments_mapping_list (List[List[str]]): For each reaction,
            the atom-mapped identical fragment SMILES (produced by
            ``create_identical_fragments_mapping_list``).

    Returns:
        List[str]: The reaction SMILES strings after
            ``_add_identical_fragments_to_mapping`` has been applied to each
            pair.
    """
    paired = zip(mapped_reaction_smiles_list, identical_fragments_mapping_list)
    return [
        self._add_identical_fragments_to_mapping(mapped_rxn, fragments)
        for mapped_rxn, fragments in paired
    ]

MCSReactionMapper

Bases: ReactionMapper

MCS reaction classification and atom-mapping.

Source code in agave_chem/mappers/mcs/mcs_mapper.py
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
class MCSReactionMapper(ReactionMapper):
    """
    MCS reaction classification and atom-mapping.
    """

    def __init__(self, mapper_name: str, mapper_weight: float = 3):
        """
        Initialize the mapper with the fixed mapper type ``"mcs"`` and create
        the RDKit standardization helpers used by ``_normalize_mol``.

        Args:
            mapper_name (str): Name forwarded to the ``ReactionMapper`` base.
            mapper_weight (float): Weight forwarded to the ``ReactionMapper``
                base. Defaults to 3.
        """
        super().__init__("mcs", mapper_name, mapper_weight)
        # Built once per mapper instance and reused for every molecule.
        self._tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
        self._uncharger = rdMolStandardize.Uncharger()

    def _normalize_mol(self, mol: Chem.Mol) -> Chem.Mol:
        """
        Return a neutralized, tautomer-canonicalized form of a molecule for
        matching purposes.

        Args:
            mol (Chem.Mol): Molecule to normalize. It is wrapped in a new
                ``Chem.RWMol`` first so the work is done on a copy.

        Returns:
            Chem.Mol: The result of running the uncharger and then the
                tautomer canonicalizer on the copy.
        """
        working_copy = Chem.RWMol(mol)
        neutralized = self._uncharger.uncharge(working_copy)
        return self._tautomer_enumerator.Canonicalize(neutralized)

    @staticmethod
    def _compute_skeleton(mol: Chem.Mol, atom_idx: int, radius: int) -> Skeleton:
        """
        Build the canonically ordered bond skeleton of one atom's environment.

        The skeleton only records which atoms and bonds of *mol* take part in
        the environment, the order they are listed in, and each atom's bond
        distance from the root. No chemical properties are stored here; they
        are combined with the skeleton later, in ``_compute_fingerprint``.

        Args:
            mol (Chem.Mol): RDKit molecule.
            atom_idx (int): Index of the root atom in *mol*.
            radius (int): Bond-radius of the environment to extract.

        Returns:
            Skeleton: List of
                ``(begin_atom_idx, bond_idx, end_atom_idx, begin_dist, end_dist)``
                tuples. Atom and bond indices refer to *mol*; the distances
                are BFS hop counts from the root inside the extracted
                environment. An empty list is returned when RDKit reports no
                bonds for this root and radius.
        """
        env_bond_ids = Chem.rdmolops.FindAtomEnvironmentOfRadiusN(
            mol,
            radius=radius,
            rootedAtAtom=atom_idx,
        )
        if not env_bond_ids:
            return []

        # PathToSubmol fills this dict with parent-index -> submol-index.
        parent_to_sub: Dict[int, int] = {}
        submol = Chem.PathToSubmol(mol, env_bond_ids, atomMap=parent_to_sub)
        root = parent_to_sub[atom_idx]

        # Only the root carries a (non-zero) map number, so the canonical
        # ranking below, which includes atom maps, is anchored on the root.
        for sub_atom in submol.GetAtoms():
            sub_atom.SetAtomMapNum(1 if sub_atom.GetIdx() == root else 0)

        # Breadth-first hop counts from the root within the submol.
        hops: Dict[int, int] = {root: 0}
        frontier = deque([root])
        while frontier:
            current = frontier.popleft()
            next_hop = hops[current] + 1
            for neighbor in submol.GetAtomWithIdx(current).GetNeighbors():
                neighbor_idx = neighbor.GetIdx()
                if neighbor_idx in hops:
                    continue
                hops[neighbor_idx] = next_hop
                frontier.append(neighbor_idx)

        ranks = list(
            Chem.CanonicalRankAtoms(submol, breakTies=True, includeAtomMaps=True)
        )

        sub_to_parent = {sub_i: parent_i for parent_i, sub_i in parent_to_sub.items()}

        # Orient every bond from the lower-ranked to the higher-ranked atom
        # and attach the key that fixes the canonical bond order.
        keyed_bonds: List[Tuple[Tuple[int, ...], int, int]] = []
        for sub_bond in submol.GetBonds():
            s_begin, s_end = sub_bond.GetBeginAtomIdx(), sub_bond.GetEndAtomIdx()
            if ranks[s_begin] > ranks[s_end]:
                s_begin, s_end = s_end, s_begin
            sort_key = (
                ranks[s_begin],
                ranks[s_end],
                int(sub_bond.GetBondTypeAsDouble()),
                int(sub_bond.GetIsAromatic()),
                int(sub_bond.IsInRing()),
            )
            keyed_bonds.append((sort_key, s_begin, s_end))

        keyed_bonds.sort(key=lambda item: item[0])

        # Translate submol indices back to parent-molecule indices.
        skeleton: Skeleton = []
        for _, s_begin, s_end in keyed_bonds:
            p_begin = sub_to_parent[s_begin]
            p_end = sub_to_parent[s_end]
            parent_bond = mol.GetBondBetweenAtoms(p_begin, p_end)
            skeleton.append(
                (p_begin, parent_bond.GetIdx(), p_end, hops[s_begin], hops[s_end])
            )

        return skeleton

    @staticmethod
    def _compute_fingerprint(
        cache: MolPropertyCache, skeleton: Skeleton
    ) -> Fingerprint:
        """
        Turn a skeleton into a hashable fingerprint using cached properties.

        For every skeleton entry the fingerprint stores both atoms' distances
        from the root, both atoms' cached properties extended with their
        current atom-map number, and the cached bond properties. Atom and bond
        indices themselves are left out, so the result does not depend on how
        the molecule is indexed.

        Args:
            cache (MolPropertyCache): Property cache of the molecule the
                skeleton was computed from; ``atom_props``, ``atom_map_nums``
                and ``bond_props`` are read from it.
            skeleton (Skeleton): List of
                ``(begin_atom_idx, bond_idx, end_atom_idx, begin_dist, end_dist)``
                entries.

        Returns:
            Fingerprint: Sorted tuple of per-bond entries
                ``(begin_dist, begin_atom_props, bond_props, end_dist,
                end_atom_props)``; an empty tuple for an empty skeleton.
        """

        def _atom_signature(idx: int):
            # Cached atom properties with the current map number appended last.
            return cache.atom_props[idx] + (cache.atom_map_nums[idx],)

        return tuple(
            sorted(
                (
                    begin_dist,
                    _atom_signature(begin_idx),
                    cache.bond_props[bond_idx],
                    end_dist,
                    _atom_signature(end_idx),
                )
                for begin_idx, bond_idx, end_idx, begin_dist, end_dist in skeleton
            )
        )

    def _build_skeletons_at_radius(
        self,
        mols: List[Chem.Mol],
        caches: List[MolPropertyCache],
        radius: int,
        skeletons: Dict[EnvKey, Skeleton],
        fingerprints: Dict[EnvKey, Fingerprint],
        atom_to_entries: Dict[Tuple[int, int], Set[EnvKey]],
        skip_atoms: Optional[Set[Tuple[int, int]]] = None,
    ) -> None:
        """
        Add the skeleton and fingerprint of every eligible atom at one radius.

        Molecules with fewer atoms than *radius* are skipped entirely, as are
        atoms listed in *skip_atoms*. For every remaining atom the entry is
        stored under the key ``(mol_idx, atom_idx, radius)``, even when the
        skeleton is empty.

        Args:
            mols (List[Chem.Mol]): Molecules to process.
            caches (List[MolPropertyCache]): Property caches, indexed like
                *mols*.
            radius (int): Bond-radius to compute.
            skeletons (Dict[EnvKey, Skeleton]): Receives the new skeletons.
            fingerprints (Dict[EnvKey, Fingerprint]): Receives the new
                fingerprints.
            atom_to_entries (Dict[Tuple[int, int], Set[EnvKey]]): Reverse
                index from ``(mol_idx, atom_idx)`` to the keys whose skeleton
                references that atom; extended for both ends of every
                skeleton bond.
            skip_atoms (Optional[Set[Tuple[int, int]]]): ``(mol_idx, atom_idx)``
                pairs to leave out. ``None`` skips nothing.

        Returns:
            None: The three dicts are updated in place.
        """
        for mol_idx, mol in enumerate(mols):
            if mol.GetNumAtoms() >= radius:
                mol_cache = caches[mol_idx]
                for atom in mol.GetAtoms():
                    atom_idx = atom.GetIdx()
                    if skip_atoms is not None and (mol_idx, atom_idx) in skip_atoms:
                        continue

                    env_key: EnvKey = (mol_idx, atom_idx, radius)
                    skeleton = self._compute_skeleton(mol, atom_idx, radius)
                    skeletons[env_key] = skeleton
                    fingerprints[env_key] = self._compute_fingerprint(
                        mol_cache, skeleton
                    )

                    # Register the environment under every atom it touches.
                    for begin_idx, _bond, end_idx, _begin_dist, _end_dist in skeleton:
                        atom_to_entries[(mol_idx, begin_idx)].add(env_key)
                        atom_to_entries[(mol_idx, end_idx)].add(env_key)

    @staticmethod
    def _find_matches(
        r_fps: Dict[EnvKey, Fingerprint],
        p_fps: Dict[EnvKey, Fingerprint],
        radius: int,
        min_radius_to_anchor_new_mapping: int = 3,
        require_anchor: bool = False,
    ) -> List[Tuple[EnvKey, EnvKey]]:
        """
        Pair reactant and product environments with equal fingerprints at one
        radius.

        Reactant keys are grouped by fingerprint in a dict, and each product
        fingerprint is then looked up in that dict rather than compared
        against every reactant. Entries of other radii and entries with an
        empty fingerprint are ignored on both sides.

        Args:
            r_fps (Dict[EnvKey, Fingerprint]): Reactant fingerprints.
            p_fps (Dict[EnvKey, Fingerprint]): Product fingerprints.
            radius (int): Radius to match at (compared with ``key[2]``).
            min_radius_to_anchor_new_mapping (int): When
                ``radius < min_radius_to_anchor_new_mapping - 1``, only
                environments containing at least one atom with a non-zero
                atom-map number are considered.
                NOTE(review): the earlier docstring described this as "below
                this radius"; the code subtracts one - confirm which is
                intended.
            require_anchor (bool): If True, the mapped-atom requirement is
                applied regardless of radius.

        Returns:
            List[Tuple[EnvKey, EnvKey]]: ``(reactant_key, product_key)`` pairs
                with equal fingerprints, ordered by product iteration order
                and, within one product, by reactant insertion order.
        """
        anchored_only = (
            require_anchor or radius < min_radius_to_anchor_new_mapping - 1
        )

        def _is_anchored(fp: Fingerprint) -> bool:
            # Each entry is (begin_dist, begin_atom_props, bond_props,
            # end_dist, end_atom_props); the map number is the last item of
            # either atom-props tuple.
            return any(entry[1][-1] != 0 or entry[4][-1] != 0 for entry in fp)

        def _eligible(key: EnvKey, fp: Fingerprint) -> bool:
            if key[2] != radius or not fp:
                return False
            return not anchored_only or _is_anchored(fp)

        # Reactant side: fingerprint -> keys, in insertion order.
        reactant_index: Dict[Fingerprint, List[EnvKey]] = {}
        for r_key, r_fp in r_fps.items():
            if _eligible(r_key, r_fp):
                reactant_index.setdefault(r_fp, []).append(r_key)

        # Product side: emit one pair per reactant sharing the fingerprint.
        return [
            (r_key, p_key)
            for p_key, p_fp in p_fps.items()
            if _eligible(p_key, p_fp)
            for r_key in reactant_index.get(p_fp, ())
        ]

    def _build_skeletons_and_find_radius(
        self,
        reactant_mols: List[Chem.Mol],
        product_mols: List[Chem.Mol],
        r_caches: List[MolPropertyCache],
        p_caches: List[MolPropertyCache],
        min_radius: int,
        max_radius: int,
    ) -> Tuple[
        Dict[EnvKey, Skeleton],
        Dict[EnvKey, Skeleton],
        Dict[EnvKey, Fingerprint],
        Dict[EnvKey, Fingerprint],
        Dict[Tuple[int, int], Set[EnvKey]],
        Dict[Tuple[int, int], Set[EnvKey]],
        int,
    ]:
        """
        Grow skeletons and fingerprints one radius at a time and pick the
        radius at which the mapping phase should start.

        For each radius in ``range(min_radius, max_radius)`` the product and
        reactant environments are built and matched. The loop ends when
        exactly one match is found (that radius is selected) or when no match
        is found (the previous radius, ``radius - 1``, is selected). If the
        loop runs to the end, the last radius tried is selected; if the range
        is empty, ``min_radius`` is returned.

        Args:
            reactant_mols (List[Chem.Mol]): Reactant molecules.
            product_mols (List[Chem.Mol]): Product molecules.
            r_caches (List[MolPropertyCache]): Reactant property caches.
            p_caches (List[MolPropertyCache]): Product property caches.
            min_radius (int): Minimum bond-radius (inclusive).
            max_radius (int): Maximum bond-radius (exclusive).

        Returns:
            Tuple containing, in order:
                - Reactant skeletons dict.
                - Product skeletons dict.
                - Reactant fingerprints dict.
                - Product fingerprints dict.
                - Reactant atom-to-entries reverse index.
                - Product atom-to-entries reverse index.
                - The radius selected for the mapping phase.
        """
        r_skeletons: Dict[EnvKey, Skeleton] = {}
        p_skeletons: Dict[EnvKey, Skeleton] = {}
        r_fps: Dict[EnvKey, Fingerprint] = {}
        p_fps: Dict[EnvKey, Fingerprint] = {}
        r_a2e: Dict[Tuple[int, int], Set[EnvKey]] = defaultdict(set)
        p_a2e: Dict[Tuple[int, int], Set[EnvKey]] = defaultdict(set)

        # (mol_idx, atom_idx) pairs that found no counterpart on the other
        # side at an earlier radius; they are not rebuilt at later radii.
        r_skip: Set[Tuple[int, int]] = set()
        p_skip: Set[Tuple[int, int]] = set()

        final_radius = min_radius
        for radius in range(min_radius, max_radius):
            final_radius = radius

            # Product environments are built before reactant environments.
            self._build_skeletons_at_radius(
                product_mols,
                p_caches,
                radius,
                p_skeletons,
                p_fps,
                p_a2e,
                skip_atoms=p_skip,
            )
            self._build_skeletons_at_radius(
                reactant_mols,
                r_caches,
                radius,
                r_skeletons,
                r_fps,
                r_a2e,
                skip_atoms=r_skip,
            )

            # A threshold of 0 disables the mapped-atom requirement here.
            match_count = len(
                self._find_matches(
                    r_fps,
                    p_fps,
                    radius,
                    min_radius_to_anchor_new_mapping=0,
                )
            )

            if match_count == 0:
                final_radius = radius - 1
                break
            if match_count == 1:
                break

            # Several matches remain: remember every atom whose fingerprint
            # at this radius is empty or absent from the opposite side, so it
            # is skipped at larger radii.
            r_current = [(key, fp) for key, fp in r_fps.items() if key[2] == radius]
            p_current = [(key, fp) for key, fp in p_fps.items() if key[2] == radius]
            r_seen: Set[Fingerprint] = {fp for _, fp in r_current if fp}
            p_seen: Set[Fingerprint] = {fp for _, fp in p_current if fp}

            r_skip.update(
                (key[0], key[1])
                for key, fp in r_current
                if not fp or fp not in p_seen
            )
            p_skip.update(
                (key[0], key[1])
                for key, fp in p_current
                if not fp or fp not in r_seen
            )

        return r_skeletons, p_skeletons, r_fps, p_fps, r_a2e, p_a2e, final_radius

    def _recompute_affected_fingerprints(
        self,
        mol_idx: int,
        atom_idx: int,
        skeletons: Dict[EnvKey, Skeleton],
        fps: Dict[EnvKey, Fingerprint],
        caches: List[MolPropertyCache],
        atom_to_entries: Dict[Tuple[int, int], Set[EnvKey]],
    ) -> None:
        """
        Recompute fingerprints for every skeleton entry that references a
        given atom.

        Called after updating an atom's map number so that fingerprints
        reflect the new mapping state.

        Args:
            mol_idx (int): Index of the molecule containing the updated atom.
            atom_idx (int): Index of the updated atom within the molecule.
            skeletons (Dict[EnvKey, Skeleton]): All precomputed skeletons.
            fps (Dict[EnvKey, Fingerprint]): Fingerprint dict to update in
                place.
            caches (List[MolPropertyCache]): Property caches indexed by
                *mol_idx*.
            atom_to_entries (Dict[Tuple[int, int], Set[EnvKey]]): Reverse
                index.

        Returns:
            None: *fps* is mutated in place.
        """
        affected = atom_to_entries.get((mol_idx, atom_idx), set())
        for key in affected:
            if key in skeletons:
                fps[key] = self._compute_fingerprint(
                    caches[key[0]],
                    skeletons[key],
                )

    def _assign_single_mapping(
        self,
        r_key: EnvKey,
        p_key: EnvKey,
        reactant_mols: List[Chem.Mol],
        product_mols: List[Chem.Mol],
        r_caches: List[MolPropertyCache],
        p_caches: List[MolPropertyCache],
        r_skeletons: Dict[EnvKey, Skeleton],
        p_skeletons: Dict[EnvKey, Skeleton],
        r_fps: Dict[EnvKey, Fingerprint],
        p_fps: Dict[EnvKey, Fingerprint],
        r_atom_to_entries: Dict[Tuple[int, int], Set[EnvKey]],
        p_atom_to_entries: Dict[Tuple[int, int], Set[EnvKey]],
        product_atom_source: Dict[Tuple[int, int], int],
        atom_map_num: int,
    ) -> int:
        """
        Assign an atom-mapping number to one matched reactant–product atom
        pair.

        Updates molecule objects, property caches, and fingerprints in place.
        Deletes skeleton / fingerprint entries whose root is the newly mapped
        atom (at all radii) so they cannot match again.

        Also enforces adjacency consistency: a mapping is rejected if the
        product atom has an already-mapped neighbor whose mapping originated
        from a different reactant molecule.

        Args:
            r_key (EnvKey): Reactant environment key ``(mol_idx, atom_idx, radius)``.
            p_key (EnvKey): Product environment key ``(mol_idx, atom_idx, radius)``.
            reactant_mols (List[Chem.Mol]): Reactant molecule objects.
            product_mols (List[Chem.Mol]): Product molecule objects.
            r_caches (List[MolPropertyCache]): Reactant property caches.
            p_caches (List[MolPropertyCache]): Product property caches.
            r_skeletons (Dict[EnvKey, Skeleton]): Reactant skeletons (mutated).
            p_skeletons (Dict[EnvKey, Skeleton]): Product skeletons (mutated).
            r_fps (Dict[EnvKey, Fingerprint]): Reactant fingerprints (mutated).
            p_fps (Dict[EnvKey, Fingerprint]): Product fingerprints (mutated).
            r_atom_to_entries (Dict[Tuple[int, int], Set[EnvKey]]): Reactant
                reverse index.
            p_atom_to_entries (Dict[Tuple[int, int], Set[EnvKey]]): Product
                reverse index.
            product_atom_source (Dict[Tuple[int, int], int]): Mapping from
                ``(product_mol_idx, product_atom_idx)`` to the reactant
                molecule index that the atom was mapped from.  Updated in
                place on successful assignment.
            atom_map_num (int): The atom-map number to assign.

        Returns:
            int: The next ``atom_map_num`` to use — incremented by 1 when the
                assignment succeeds, unchanged otherwise.
        """
        r_mol_idx, r_atom_idx, _ = r_key
        p_mol_idx, p_atom_idx, _ = p_key

        # Skip if either root atom is already mapped. Mapping must be assigned
        # atomically to both sides to avoid creating a one-sided mapping state
        # that can prevent downstream anchored matching.
        if reactant_mols[r_mol_idx].GetAtomWithIdx(r_atom_idx).GetAtomMapNum() != 0:
            return atom_map_num
        if product_mols[p_mol_idx].GetAtomWithIdx(p_atom_idx).GetAtomMapNum() != 0:
            return atom_map_num

        # Adjacency consistency: reject if any already-mapped neighbor of the
        # product atom was mapped from a different reactant molecule.
        p_atom = product_mols[p_mol_idx].GetAtomWithIdx(p_atom_idx)
        for nb in p_atom.GetNeighbors():
            nb_key = (p_mol_idx, nb.GetIdx())
            if (
                nb_key in product_atom_source
                and product_atom_source[nb_key] != r_mol_idx
            ):
                return atom_map_num

        # Assign on both molecules
        reactant_mols[r_mol_idx].GetAtomWithIdx(r_atom_idx).SetAtomMapNum(atom_map_num)
        product_mols[p_mol_idx].GetAtomWithIdx(p_atom_idx).SetAtomMapNum(atom_map_num)

        # Update caches
        r_caches[r_mol_idx].atom_map_nums[r_atom_idx] = atom_map_num
        p_caches[p_mol_idx].atom_map_nums[p_atom_idx] = atom_map_num

        # Delete entries whose root IS the newly mapped atom (all radii)
        for k in [k for k in r_skeletons if k[0] == r_mol_idx and k[1] == r_atom_idx]:
            del r_skeletons[k]
            r_fps.pop(k, None)

        for k in [k for k in p_skeletons if k[0] == p_mol_idx and k[1] == p_atom_idx]:
            del p_skeletons[k]
            p_fps.pop(k, None)

        # Recompute fingerprints that reference the mapped atoms
        self._recompute_affected_fingerprints(
            r_mol_idx,
            r_atom_idx,
            r_skeletons,
            r_fps,
            r_caches,
            r_atom_to_entries,
        )
        self._recompute_affected_fingerprints(
            p_mol_idx,
            p_atom_idx,
            p_skeletons,
            p_fps,
            p_caches,
            p_atom_to_entries,
        )

        product_atom_source[(p_mol_idx, p_atom_idx)] = r_mol_idx

        return atom_map_num + 1

    def _assign_atom_map_nums(
        self,
        reactant_mols: List[Chem.Mol],
        product_mols: List[Chem.Mol],
        r_caches: List[MolPropertyCache],
        p_caches: List[MolPropertyCache],
        r_skeletons: Dict[EnvKey, Skeleton],
        p_skeletons: Dict[EnvKey, Skeleton],
        r_fps: Dict[EnvKey, Fingerprint],
        p_fps: Dict[EnvKey, Fingerprint],
        r_atom_to_entries: Dict[Tuple[int, int], Set[EnvKey]],
        p_atom_to_entries: Dict[Tuple[int, int], Set[EnvKey]],
        final_radius: int,
        min_radius_to_anchor_new_mapping: int,
    ) -> Tuple[List[Chem.Mol], List[Chem.Mol]]:
        """
        Assign atom-map numbers using an anchor-extend strategy.

        The algorithm alternates between two phases:

        1. **Extend phase** — scan all radii (high to low), only matching
           environments that contain at least one already-mapped atom.  Repeat
           the full sweep until no more anchored matches are found.
        2. **New-anchor phase** — find one unanchored match at the highest
           feasible radius (>= *min_radius_to_anchor_new_mapping*) and assign
           it.  Return to the extend phase.

        This ensures that each anchor site is fully propagated before a new
        one is created, preventing gaps caused by split-source mapping.

        Args:
            reactant_mols (List[Chem.Mol]): Reactant molecule objects.
            product_mols (List[Chem.Mol]): Product molecule objects.
            r_caches (List[MolPropertyCache]): Reactant property caches.
            p_caches (List[MolPropertyCache]): Product property caches.
            r_skeletons (Dict[EnvKey, Skeleton]): Reactant skeletons.
            p_skeletons (Dict[EnvKey, Skeleton]): Product skeletons.
            r_fps (Dict[EnvKey, Fingerprint]): Reactant fingerprints.
            p_fps (Dict[EnvKey, Fingerprint]): Product fingerprints.
            r_atom_to_entries (Dict[Tuple[int, int], Set[EnvKey]]): Reactant
                reverse index.
            p_atom_to_entries (Dict[Tuple[int, int], Set[EnvKey]]): Product
                reverse index.
            final_radius (int): Largest radius to start assigning from.
            min_radius_to_anchor_new_mapping (int): Passed through to
                ``_find_matches``.

        Returns:
            Tuple[List[Chem.Mol], List[Chem.Mol]]: The (mutated) reactant and
                product molecule lists with atom-map numbers assigned.
        """
        atom_map_num = 1
        product_atom_source: Dict[Tuple[int, int], int] = {}

        def _try_assign_first_viable(
            matches: List[Tuple[EnvKey, EnvKey]],
        ) -> bool:
            """
            Attempt to assign the first viable match from *matches*.

            Args:
                matches (List[Tuple[EnvKey, EnvKey]]): Candidate
                    reactant-product environment pairs.

            Returns:
                bool: True if a mapping was successfully assigned.
            """
            nonlocal atom_map_num
            for r_key, p_key in matches:
                prev = atom_map_num
                atom_map_num = self._assign_single_mapping(
                    r_key,
                    p_key,
                    reactant_mols,
                    product_mols,
                    r_caches,
                    p_caches,
                    r_skeletons,
                    p_skeletons,
                    r_fps,
                    p_fps,
                    r_atom_to_entries,
                    p_atom_to_entries,
                    product_atom_source,
                    atom_map_num,
                )
                if atom_map_num > prev:
                    return True
            return False

        while True:
            # --- Extend phase: propagate from existing anchors ---
            extended = True
            while extended:
                extended = False
                for radius in range(final_radius, 0, -1):
                    any_at_radius = True
                    while any_at_radius:
                        any_at_radius = False
                        matches = self._find_matches(
                            r_fps,
                            p_fps,
                            radius,
                            min_radius_to_anchor_new_mapping=min_radius_to_anchor_new_mapping,
                            require_anchor=True,
                        )
                        if matches and _try_assign_first_viable(matches):
                            any_at_radius = True
                            extended = True

            # --- New-anchor phase: create one new unanchored mapping ---
            new_anchor = False
            for radius in range(final_radius, 0, -1):
                matches = self._find_matches(
                    r_fps,
                    p_fps,
                    radius,
                    min_radius_to_anchor_new_mapping=min_radius_to_anchor_new_mapping,
                    require_anchor=False,
                )
                if matches and _try_assign_first_viable(matches):
                    new_anchor = True
                    break

            if not new_anchor:
                break

        return reactant_mols, product_mols

    def map_reaction(
        self,
        reaction_smiles: str,
        min_radius: int = 1,
        min_radius_to_anchor_new_mapping: int = 3,
        max_radius: Optional[int] = None,
    ) -> ReactionMapperResult:
        """
        Atom-map one reaction SMILES via MCS-based environment matching.

        The reaction is validated, canonicalized (tautomers untouched), and
        split into per-molecule objects.  Skeletons and fingerprints are then
        built radius by radius, atom-map numbers are assigned, and the mapped
        reaction is re-serialized and checked (a full mapping is not
        required).

        Args:
            reaction_smiles (str): Reaction SMILES of the form
                ``"reactants>>products"``.
            min_radius (int): Smallest bond-radius to consider.
            min_radius_to_anchor_new_mapping (int): Below this radius,
                environments are only matched when they already contain at
                least one mapped atom.
            max_radius (Optional[int]): Upper bound (exclusive) handed to the
                radius search.  Any falsy value (``None`` or ``0``) selects
                the atom count of the largest reactant or product molecule.

        Returns:
            ReactionMapperResult: Result holding the original and the mapped
                SMILES.  The default result for *reaction_smiles* is returned
                instead when the input is invalid, a molecule fails to
                parse, or the produced mapping does not verify.
        """
        if not self._reaction_smiles_valid(reaction_smiles):
            return self._return_default_mapping_dict(reaction_smiles)

        canonical_rxn = canonicalize_reaction_smiles(
            reaction_smiles,
            canonicalize_tautomer=False,
        )
        reactants_str, products_str = self._split_reaction_components(canonical_rxn)

        # Parse one side at a time so a bad reactant short-circuits before
        # the products are touched.
        reactant_mols = [Chem.MolFromSmiles(smi) for smi in reactants_str.split(".")]
        if None in reactant_mols:
            logger.warning(f"Failed to parse reactant SMILES: {reactants_str}")
            return self._return_default_mapping_dict(reaction_smiles)
        product_mols = [Chem.MolFromSmiles(smi) for smi in products_str.split(".")]
        if None in product_mols:
            logger.warning(f"Failed to parse product SMILES: {products_str}")
            return self._return_default_mapping_dict(reaction_smiles)

        if not max_radius:
            atom_counts = [
                mol.GetNumAtoms() for mol in (*reactant_mols, *product_mols)
            ]
            max_radius = max(atom_counts)

        # Per-molecule property caches, reactants first.
        r_caches = [MolPropertyCache(mol) for mol in reactant_mols]
        p_caches = [MolPropertyCache(mol) for mol in product_mols]

        # Grow skeletons / fingerprints radius by radius and pick the radius
        # at which assignment should start.
        search_state = self._build_skeletons_and_find_radius(
            reactant_mols,
            product_mols,
            r_caches,
            p_caches,
            min_radius,
            max_radius,
        )
        (
            r_skeletons,
            p_skeletons,
            r_fps,
            p_fps,
            r_a2e,
            p_a2e,
            final_radius,
        ) = search_state

        # Hand out atom-map numbers, sweeping from final_radius down to 1.
        reactant_mols, product_mols = self._assign_atom_map_nums(
            reactant_mols,
            product_mols,
            r_caches,
            p_caches,
            r_skeletons,
            p_skeletons,
            r_fps,
            p_fps,
            r_a2e,
            p_a2e,
            final_radius,
            min_radius_to_anchor_new_mapping,
        )

        def _joined_smiles(mols) -> str:
            """Serialize *mols* in their current atom order, dot-separated."""
            return ".".join(
                Chem.MolToSmiles(mol, isomericSmiles=True, canonical=False)
                for mol in mols
            )

        mapped_reaction_smiles = (
            f"{_joined_smiles(reactant_mols)}>>{_joined_smiles(product_mols)}"
        )

        mapping_ok = self._verify_validity_of_mapping(
            mapped_reaction_smiles,
            expect_full_mapping=False,
        )
        if not mapping_ok:
            logger.warning("Invalid mapping")
            return self._return_default_mapping_dict(reaction_smiles)

        return ReactionMapperResult(
            original_smiles=reaction_smiles,
            selected_mapping=mapped_reaction_smiles,
            possible_mappings={},
            mapping_type=self._mapper_type,
            mapping_score=None,
            additional_info=[{}],
        )

    def map_reactions(
        self,
        reaction_list: List[str],
        min_radius: int = 1,
        min_radius_to_anchor_new_mapping: int = 3,
    ) -> List[ReactionMapperResult]:
        """
        Run :meth:`map_reaction` over every reaction SMILES in a list.

        ``max_radius`` is not forwarded, so each reaction uses the default
        chosen by :meth:`map_reaction`.

        Args:
            reaction_list (List[str]): List of reaction SMILES strings to map.
            min_radius (int): Smallest bond-radius to consider.
            min_radius_to_anchor_new_mapping (int): Below this radius,
                environments are only matched when they already contain at
                least one mapped atom.

        Returns:
            List[ReactionMapperResult]: One result per input reaction, in
                input order.
        """
        return [
            self.map_reaction(
                rxn_smiles,
                min_radius=min_radius,
                min_radius_to_anchor_new_mapping=min_radius_to_anchor_new_mapping,
            )
            for rxn_smiles in reaction_list
        ]

map_reaction(reaction_smiles, min_radius=1, min_radius_to_anchor_new_mapping=3, max_radius=None)

Atom-map a single reaction SMILES using MCS-based environment matching.

Parameters:

Name Type Description Default
reaction_smiles str

Reaction SMILES of the form "reactants>>products".

required
min_radius int

Smallest bond-radius to consider.

1
min_radius_to_anchor_new_mapping int

Below this radius, environments are only matched when they already contain at least one mapped atom.

3
max_radius Optional[int]

Upper bound (exclusive) on the bond-radius search. Defaults to the atom count of the largest molecule.

None

Returns:

Name Type Description
ReactionMapperResult ReactionMapperResult

Mapping result containing the original and mapped SMILES. Falls back to the default (empty) result on invalid input or failed mapping.

Source code in agave_chem/mappers/mcs/mcs_mapper.py
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
def map_reaction(
    self,
    reaction_smiles: str,
    min_radius: int = 1,
    min_radius_to_anchor_new_mapping: int = 3,
    max_radius: Optional[int] = None,
) -> ReactionMapperResult:
    """
    Atom-map one reaction SMILES via MCS-based environment matching.

    The reaction is validated, canonicalized (tautomers untouched), and
    split into per-molecule objects.  Skeletons and fingerprints are then
    built radius by radius, atom-map numbers are assigned, and the mapped
    reaction is re-serialized and checked (a full mapping is not
    required).

    Args:
        reaction_smiles (str): Reaction SMILES of the form
            ``"reactants>>products"``.
        min_radius (int): Smallest bond-radius to consider.
        min_radius_to_anchor_new_mapping (int): Below this radius,
            environments are only matched when they already contain at
            least one mapped atom.
        max_radius (Optional[int]): Upper bound (exclusive) handed to the
            radius search.  Any falsy value (``None`` or ``0``) selects
            the atom count of the largest reactant or product molecule.

    Returns:
        ReactionMapperResult: Result holding the original and the mapped
            SMILES.  The default result for *reaction_smiles* is returned
            instead when the input is invalid, a molecule fails to
            parse, or the produced mapping does not verify.
    """
    if not self._reaction_smiles_valid(reaction_smiles):
        return self._return_default_mapping_dict(reaction_smiles)

    canonical_rxn = canonicalize_reaction_smiles(
        reaction_smiles,
        canonicalize_tautomer=False,
    )
    reactants_str, products_str = self._split_reaction_components(canonical_rxn)

    # Parse one side at a time so a bad reactant short-circuits before
    # the products are touched.
    reactant_mols = [Chem.MolFromSmiles(smi) for smi in reactants_str.split(".")]
    if None in reactant_mols:
        logger.warning(f"Failed to parse reactant SMILES: {reactants_str}")
        return self._return_default_mapping_dict(reaction_smiles)
    product_mols = [Chem.MolFromSmiles(smi) for smi in products_str.split(".")]
    if None in product_mols:
        logger.warning(f"Failed to parse product SMILES: {products_str}")
        return self._return_default_mapping_dict(reaction_smiles)

    if not max_radius:
        atom_counts = [
            mol.GetNumAtoms() for mol in (*reactant_mols, *product_mols)
        ]
        max_radius = max(atom_counts)

    # Per-molecule property caches, reactants first.
    r_caches = [MolPropertyCache(mol) for mol in reactant_mols]
    p_caches = [MolPropertyCache(mol) for mol in product_mols]

    # Grow skeletons / fingerprints radius by radius and pick the radius
    # at which assignment should start.
    search_state = self._build_skeletons_and_find_radius(
        reactant_mols,
        product_mols,
        r_caches,
        p_caches,
        min_radius,
        max_radius,
    )
    (
        r_skeletons,
        p_skeletons,
        r_fps,
        p_fps,
        r_a2e,
        p_a2e,
        final_radius,
    ) = search_state

    # Hand out atom-map numbers, sweeping from final_radius down to 1.
    reactant_mols, product_mols = self._assign_atom_map_nums(
        reactant_mols,
        product_mols,
        r_caches,
        p_caches,
        r_skeletons,
        p_skeletons,
        r_fps,
        p_fps,
        r_a2e,
        p_a2e,
        final_radius,
        min_radius_to_anchor_new_mapping,
    )

    def _joined_smiles(mols) -> str:
        """Serialize *mols* in their current atom order, dot-separated."""
        return ".".join(
            Chem.MolToSmiles(mol, isomericSmiles=True, canonical=False)
            for mol in mols
        )

    mapped_reaction_smiles = (
        f"{_joined_smiles(reactant_mols)}>>{_joined_smiles(product_mols)}"
    )

    mapping_ok = self._verify_validity_of_mapping(
        mapped_reaction_smiles,
        expect_full_mapping=False,
    )
    if not mapping_ok:
        logger.warning("Invalid mapping")
        return self._return_default_mapping_dict(reaction_smiles)

    return ReactionMapperResult(
        original_smiles=reaction_smiles,
        selected_mapping=mapped_reaction_smiles,
        possible_mappings={},
        mapping_type=self._mapper_type,
        mapping_score=None,
        additional_info=[{}],
    )

map_reactions(reaction_list, min_radius=1, min_radius_to_anchor_new_mapping=3)

Map a list of reaction SMILES strings using the MCS mapper.

Parameters:

Name Type Description Default
reaction_list List[str]

List of reaction SMILES strings to map.

required
min_radius int

Smallest bond-radius to consider.

1
min_radius_to_anchor_new_mapping int

Below this radius, environments are only matched when they already contain at least one mapped atom.

3

Returns:

Type Description
List[ReactionMapperResult]

List[ReactionMapperResult]: The mapping results in the same order as the input reactions.

Source code in agave_chem/mappers/mcs/mcs_mapper.py
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
def map_reactions(
    self,
    reaction_list: List[str],
    min_radius: int = 1,
    min_radius_to_anchor_new_mapping: int = 3,
) -> List[ReactionMapperResult]:
    """
    Run :meth:`map_reaction` over every reaction SMILES in a list.

    ``max_radius`` is not forwarded, so each reaction uses the default
    chosen by :meth:`map_reaction`.

    Args:
        reaction_list (List[str]): List of reaction SMILES strings to map.
        min_radius (int): Smallest bond-radius to consider.
        min_radius_to_anchor_new_mapping (int): Below this radius,
            environments are only matched when they already contain at
            least one mapped atom.

    Returns:
        List[ReactionMapperResult]: One result per input reaction, in
            input order.
    """
    return [
        self.map_reaction(
            rxn_smiles,
            min_radius=min_radius,
            min_radius_to_anchor_new_mapping=min_radius_to_anchor_new_mapping,
        )
        for rxn_smiles in reaction_list
    ]

MappingScorer

Comprehensive scorer for atom-to-atom mappings.

This class computes all the metrics used to evaluate and compare different mapping solutions, following the approach in RDTool.

Metrics include:

- Bond energy cost (formation/breaking)
- Number of bond changes
- Number of fragments affected
- Stereochemistry changes
- Ring opening/closing events
- Overall similarity score

Source code in agave_chem/scoring/scoring.py
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
class MappingScorer:
    """
    Comprehensive scorer for atom-to-atom mappings.

    This class computes all the metrics used to evaluate and compare
    different mapping solutions, following the approach in RDTool.

    Metrics include:
    - Bond energy cost (formation/breaking)
    - Number of bond changes
    - Number of fragments affected
    - Stereochemistry changes
    - Ring opening/closing events
    - Overall similarity score
    """

    def __init__(
        self,
        energy_penalty_weight: float = 1.0,
        bond_change_weight: float = 10.0,
        fragment_weight: float = 20.0,
        stereo_weight: float = 15.0,
        ring_weight: float = 25.0,
    ):
        """
        Initialize the scorer with custom weights.

        The weights are stored in ``self.weights``, keyed by the name of the
        metric they apply to.

        Args:
            energy_penalty_weight: Weight for bond energy cost
            bond_change_weight: Weight for number of bond changes
            fragment_weight: Weight for fragment changes
            stereo_weight: Weight for stereo changes
            ring_weight: Weight for ring changes
        """
        self.weights = dict(
            bond_energy_cost=energy_penalty_weight,
            num_bond_changes=bond_change_weight,
            num_fragments=fragment_weight,
            stereo_changes=stereo_weight,
            ring_changes=ring_weight,
        )

    def _parse_mapped_reaction_smiles(
        self, atom_mapped_rxn_smiles: str
    ) -> Tuple[List[Chem.Mol], List[Chem.Mol], FrozenSet[AtomMapping]]:
        """
        Parse an atom-mapped reaction SMILES into reactants, products, and atom mappings.

        Only the two-part ``reactants>>products`` form is accepted. Components
        on each side are separated by ``.``; empty components are ignored.

        A map number that appears on only one side of the reaction is skipped:
        it produces no AtomMapping and is not treated as an error.

        Args:
            atom_mapped_rxn_smiles: Atom-mapped reaction SMILES string
                (e.g., "[CH3:1][OH:2]>>[CH3:1][O:2][H:3]")

        Returns:
            Tuple of (reactant_mols, product_mols, atom_mapping_set)

        Raises:
            ValueError: If the string does not split into exactly two parts
                on ">>", if a component cannot be parsed, or if a map number
                is used more than once on the same side.
        """
        sides = atom_mapped_rxn_smiles.strip().split(">>")
        if len(sides) != 2:
            raise ValueError(f"Invalid reaction SMILES: {atom_mapped_rxn_smiles}")

        reactant_smiles_list = [s for s in sides[0].split(".") if s]
        product_smiles_list = [s for s in sides[1].split(".") if s]

        reactants = [Chem.MolFromSmiles(s) for s in reactant_smiles_list]
        products = [Chem.MolFromSmiles(s) for s in product_smiles_list]

        # MolFromSmiles signals failure by returning None; report the first
        # offending component, reactants before products.
        for role, smiles_list, mols in (
            ("reactant", reactant_smiles_list, reactants),
            ("product", product_smiles_list, products),
        ):
            for smiles, mol in zip(smiles_list, mols):
                if mol is None:
                    raise ValueError(f"Could not parse {role} SMILES: {smiles}")

        def index_map_numbers(mols, side):
            """Return map number -> (mol_idx, atom_idx) for one reaction side."""
            located: Dict[int, Tuple[int, int]] = {}
            for mol_idx, mol in enumerate(mols):
                for atom in mol.GetAtoms():
                    map_num = atom.GetAtomMapNum()
                    if map_num <= 0:
                        # Atoms without a positive map number are unmapped.
                        continue
                    if map_num in located:
                        raise ValueError(
                            f"Duplicate atom map number {map_num} in {side}"
                        )
                    located[map_num] = (mol_idx, atom.GetIdx())
            return located

        reactant_map_dict = index_map_numbers(reactants, "reactants")
        product_map_dict = index_map_numbers(products, "products")

        # Pair up atoms that carry the same map number on both sides.
        mapping_set: Set[AtomMapping] = set()
        for map_num, (r_mol_idx, r_atom_idx) in reactant_map_dict.items():
            product_location = product_map_dict.get(map_num)
            if product_location is None:
                # Mapped in the reactants only: deliberately tolerated.
                continue
            p_mol_idx, p_atom_idx = product_location
            mapping_set.add(
                AtomMapping(
                    reactant_mol_idx=r_mol_idx,
                    reactant_atom_idx=r_atom_idx,
                    product_mol_idx=p_mol_idx,
                    product_atom_idx=p_atom_idx,
                )
            )

        return reactants, products, frozenset(mapping_set)

    def score_mapping(
        self,
        atom_mapped_rxn_smiles: str,
    ) -> MappingScore:
        """
        Compute comprehensive score for an atom-mapped reaction SMILES.

        Args:
            atom_mapped_rxn_smiles: Atom-mapped reaction SMILES string
                (e.g., "[CH3:1][OH:2]>>[CH3:1][O:2][H:3]")

        Returns:
            MappingScore object with all metrics

        Raises:
            ValueError: If the reaction SMILES cannot be parsed (raised by
                ``_parse_mapped_reaction_smiles``).
        """
        reactants, products, mapping = self._parse_mapped_reaction_smiles(
            atom_mapped_rxn_smiles
        )
        bond_changes = self.compute_bond_changes(reactants, products, mapping)

        # Tally the bond changes by type in a single pass.
        num_formed = 0
        num_broken = 0
        num_order_changes = 0
        for bc in bond_changes:
            if bc.change_type == BondChangeType.FORMED:
                num_formed += 1
            elif bc.change_type == BondChangeType.BROKEN:
                num_broken += 1
            elif bc.change_type == BondChangeType.ORDER_CHANGE:
                num_order_changes += 1

        # Total energy cost over every bond change.
        energy_cost = sum(bc.energy_cost for bc in bond_changes)

        num_fragments = self._count_fragment_changes(reactants, products, mapping)
        stereo_changes = self._count_stereo_changes(reactants, products, mapping)
        ring_changes = self._count_ring_changes(
            reactants, products, mapping, bond_changes
        )
        similarity = self._calculate_similarity(reactants, products, mapping)

        return MappingScore(
            bond_energy_cost=energy_cost,
            num_bond_changes=num_formed + num_broken + num_order_changes,
            num_bonds_formed=num_formed,
            num_bonds_broken=num_broken,
            num_fragments=num_fragments,
            stereo_changes=stereo_changes,
            similarity_score=similarity,
            ring_changes=ring_changes,
        )

    def compute_bond_changes(
        self,
        reactants: List[Chem.Mol],
        products: List[Chem.Mol],
        mapping: FrozenSet[AtomMapping],
    ) -> List[BondChange]:
        """
        Compute all bond changes in the reaction.

        Only bonds whose two atoms are both covered by ``mapping`` are
        compared. The ``atom1_map`` / ``atom2_map`` values of the returned
        objects are temporary 1-based labels assigned here in the iteration
        order of ``mapping``; they are not the map numbers written in the
        reaction SMILES.

        Args:
            reactants: List of reactant molecules
            products: List of product molecules
            mapping: Set of atom mappings

        Returns:
            List of BondChange objects: broken bonds first, then formed
            bonds, then bonds whose order changed.
        """
        # Label each mapped pair so that the reactant-side atom and the
        # product-side atom of the same pair share one integer.
        reactant_to_map: Dict[Tuple[int, int], int] = {}
        product_to_map: Dict[Tuple[int, int], int] = {}
        for label, am in enumerate(mapping, start=1):
            reactant_to_map[(am.reactant_mol_idx, am.reactant_atom_idx)] = label
            product_to_map[(am.product_mol_idx, am.product_atom_idx)] = label

        def mapped_bonds(mols, atom_to_map):
            """Return {frozenset of two labels: (order, symbol1, symbol2)}."""
            table: Dict[FrozenSet[int], Tuple[float, str, str]] = {}
            for mol_idx, mol in enumerate(mols):
                for bond in mol.GetBonds():
                    begin_idx = bond.GetBeginAtomIdx()
                    end_idx = bond.GetEndAtomIdx()
                    begin_key = (mol_idx, begin_idx)
                    end_key = (mol_idx, end_idx)
                    if begin_key not in atom_to_map or end_key not in atom_to_map:
                        # Bonds touching an unmapped atom cannot be compared.
                        continue
                    bond_key = frozenset(
                        [atom_to_map[begin_key], atom_to_map[end_key]]
                    )
                    table[bond_key] = (
                        bond.GetBondTypeAsDouble(),
                        mol.GetAtomWithIdx(begin_idx).GetSymbol(),
                        mol.GetAtomWithIdx(end_idx).GetSymbol(),
                    )
            return table

        reactant_bonds = mapped_bonds(reactants, reactant_to_map)
        product_bonds = mapped_bonds(products, product_to_map)

        changes: List[BondChange] = []

        # Broken bonds: present in the reactants, absent from the products.
        for bond_key, (order, sym1, sym2) in reactant_bonds.items():
            if bond_key in product_bonds:
                continue
            map1, map2 = sorted(bond_key)
            changes.append(
                BondChange(
                    atom1_map=map1,
                    atom2_map=map2,
                    change_type=BondChangeType.BROKEN,
                    old_order=order,
                    new_order=None,
                    energy_cost=get_bond_energy(sym1, sym2, order),
                )
            )

        # Formed bonds: present in the products, absent from the reactants.
        for bond_key, (order, sym1, sym2) in product_bonds.items():
            if bond_key in reactant_bonds:
                continue
            map1, map2 = sorted(bond_key)
            changes.append(
                BondChange(
                    atom1_map=map1,
                    atom2_map=map2,
                    change_type=BondChangeType.FORMED,
                    old_order=None,
                    new_order=order,
                    energy_cost=get_bond_energy(sym1, sym2, order),
                )
            )

        # Order changes: the bond exists on both sides with a different order.
        for bond_key in reactant_bonds.keys() & product_bonds.keys():
            r_order, r_sym1, r_sym2 = reactant_bonds[bond_key]
            p_order = product_bonds[bond_key][0]

            # Orders are floats (e.g. 1.5 for aromatic), hence the tolerance.
            if abs(r_order - p_order) <= 0.1:
                continue

            map1, map2 = sorted(bond_key)
            old_energy = get_bond_energy(r_sym1, r_sym2, r_order)
            new_energy = get_bond_energy(r_sym1, r_sym2, p_order)
            changes.append(
                BondChange(
                    atom1_map=map1,
                    atom2_map=map2,
                    change_type=BondChangeType.ORDER_CHANGE,
                    old_order=r_order,
                    new_order=p_order,
                    energy_cost=abs(new_energy - old_energy),
                )
            )

        return changes

    def _count_fragment_changes(
        self,
        reactants: List[Chem.Mol],
        products: List[Chem.Mol],
        mapping: FrozenSet[AtomMapping],
    ) -> int:
        """
        Count the number of molecular fragments that change.

        Simple heuristic: the absolute difference between the number of
        reactant molecules and the number of product molecules. ``mapping``
        is accepted for signature parity with the other metrics and is not
        consulted.
        """
        return abs(len(reactants) - len(products))

    def _count_stereo_changes(
        self,
        reactants: List[Chem.Mol],
        products: List[Chem.Mol],
        mapping: FrozenSet[AtomMapping],
    ) -> int:
        """
        Count stereochemistry changes in the reaction.

        A mapped atom counts as changed when the chiral tag of its reactant
        atom differs from the chiral tag of its product atom.

        NOTE(review): the tags are compared verbatim. RDKit chiral tags are
        presumably relative to each molecule's own atom ordering, so equal
        configurations could still compare unequal -- confirm.
        """
        changes = 0
        for am in mapping:
            r_atom = reactants[am.reactant_mol_idx].GetAtomWithIdx(
                am.reactant_atom_idx
            )
            p_atom = products[am.product_mol_idx].GetAtomWithIdx(
                am.product_atom_idx
            )
            if r_atom.GetChiralTag() != p_atom.GetChiralTag():
                changes += 1
        return changes

    def _count_ring_changes(
        self,
        reactants: List[Chem.Mol],
        products: List[Chem.Mol],
        mapping: FrozenSet[AtomMapping],
        bond_changes: List[BondChange],
    ) -> int:
        """
        Count ring opening and closing events.

        Heuristic: a BROKEN or FORMED bond change counts as one event when
        both of its atoms lie in the same reactant molecule and both are ring
        atoms there. Ring membership is only checked on the reactant side;
        ``products`` is accepted for signature parity and is not consulted.

        ``bond_changes`` must come from ``compute_bond_changes`` called with
        the same ``mapping`` object: the temporary atom labels are rebuilt
        here from the iteration order of ``mapping``.
        """
        # Same 1-based labelling as compute_bond_changes.
        label_to_reactant_atom: Dict[int, Tuple[int, int]] = {
            label: (am.reactant_mol_idx, am.reactant_atom_idx)
            for label, am in enumerate(mapping, start=1)
        }

        changes = 0
        for bc in bond_changes:
            first = label_to_reactant_atom.get(bc.atom1_map)
            second = label_to_reactant_atom.get(bc.atom2_map)
            if first is None or second is None:
                continue

            mol_idx1, atom_idx1 = first
            mol_idx2, atom_idx2 = second
            if mol_idx1 != mol_idx2:
                # Atoms from different reactant molecules share no ring.
                continue

            mol = reactants[mol_idx1]
            if not (
                mol.GetAtomWithIdx(atom_idx1).IsInRing()
                and mol.GetAtomWithIdx(atom_idx2).IsInRing()
            ):
                continue

            # Order changes keep the bond in place, so only breaking (ring
            # opening) and forming (ring closing) are counted.
            if (
                bc.change_type == BondChangeType.BROKEN
                or bc.change_type == BondChangeType.FORMED
            ):
                changes += 1

        return changes

    def _calculate_similarity(
        self,
        reactants: List[Chem.Mol],
        products: List[Chem.Mol],
        mapping: FrozenSet[AtomMapping],
    ) -> float:
        """
        Calculate overall similarity based on the mapping.

        Returns the mean of two coverage fractions: mapped atoms over all
        reactant atoms, and mapped atoms over all product atoms. Returns 0.0
        when either side has no atoms.
        """
        total_reactant_atoms = sum(mol.GetNumAtoms() for mol in reactants)
        total_product_atoms = sum(mol.GetNumAtoms() for mol in products)

        if not total_reactant_atoms or not total_product_atoms:
            return 0.0

        mapped_atoms = len(mapping)
        reactant_coverage = mapped_atoms / total_reactant_atoms
        product_coverage = mapped_atoms / total_product_atoms

        return (reactant_coverage + product_coverage) / 2

__init__(energy_penalty_weight=1.0, bond_change_weight=10.0, fragment_weight=20.0, stereo_weight=15.0, ring_weight=25.0)

Initialize the scorer with custom weights.

Parameters:

Name Type Description Default
energy_penalty_weight float

Weight for bond energy cost

1.0
bond_change_weight float

Weight for number of bond changes

10.0
fragment_weight float

Weight for fragment changes

20.0
stereo_weight float

Weight for stereo changes

15.0
ring_weight float

Weight for ring changes

25.0
Source code in agave_chem/scoring/scoring.py
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
def __init__(
    self,
    energy_penalty_weight: float = 1.0,
    bond_change_weight: float = 10.0,
    fragment_weight: float = 20.0,
    stereo_weight: float = 15.0,
    ring_weight: float = 25.0,
):
    """
    Initialize the scorer with custom weights.

    The weights are stored in ``self.weights``, keyed by the name of the
    metric they apply to.

    Args:
        energy_penalty_weight: Weight for bond energy cost
        bond_change_weight: Weight for number of bond changes
        fragment_weight: Weight for fragment changes
        stereo_weight: Weight for stereo changes
        ring_weight: Weight for ring changes
    """
    self.weights = dict(
        bond_energy_cost=energy_penalty_weight,
        num_bond_changes=bond_change_weight,
        num_fragments=fragment_weight,
        stereo_changes=stereo_weight,
        ring_changes=ring_weight,
    )

compute_bond_changes(reactants, products, mapping)

Compute all bond changes in the reaction.

Parameters:

Name Type Description Default
reactants List[Mol]

List of reactant molecules

required
products List[Mol]

List of product molecules

required
mapping FrozenSet[AtomMapping]

Set of atom mappings

required

Returns:

Type Description
List[BondChange]

List of BondChange objects

Source code in agave_chem/scoring/scoring.py
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
def compute_bond_changes(
    self,
    reactants: List[Chem.Mol],
    products: List[Chem.Mol],
    mapping: FrozenSet[AtomMapping],
) -> List[BondChange]:
    """
    Compute all bond changes in the reaction.

    Only bonds whose two atoms are both covered by ``mapping`` are
    compared. The ``atom1_map`` / ``atom2_map`` values of the returned
    objects are temporary 1-based labels assigned here in the iteration
    order of ``mapping``; they are not the map numbers written in the
    reaction SMILES.

    Args:
        reactants: List of reactant molecules
        products: List of product molecules
        mapping: Set of atom mappings

    Returns:
        List of BondChange objects: broken bonds first, then formed
        bonds, then bonds whose order changed.
    """
    # Label each mapped pair so that the reactant-side atom and the
    # product-side atom of the same pair share one integer.
    reactant_to_map: Dict[Tuple[int, int], int] = {}
    product_to_map: Dict[Tuple[int, int], int] = {}
    for label, am in enumerate(mapping, start=1):
        reactant_to_map[(am.reactant_mol_idx, am.reactant_atom_idx)] = label
        product_to_map[(am.product_mol_idx, am.product_atom_idx)] = label

    def mapped_bonds(mols, atom_to_map):
        """Return {frozenset of two labels: (order, symbol1, symbol2)}."""
        table: Dict[FrozenSet[int], Tuple[float, str, str]] = {}
        for mol_idx, mol in enumerate(mols):
            for bond in mol.GetBonds():
                begin_idx = bond.GetBeginAtomIdx()
                end_idx = bond.GetEndAtomIdx()
                begin_key = (mol_idx, begin_idx)
                end_key = (mol_idx, end_idx)
                if begin_key not in atom_to_map or end_key not in atom_to_map:
                    # Bonds touching an unmapped atom cannot be compared.
                    continue
                bond_key = frozenset(
                    [atom_to_map[begin_key], atom_to_map[end_key]]
                )
                table[bond_key] = (
                    bond.GetBondTypeAsDouble(),
                    mol.GetAtomWithIdx(begin_idx).GetSymbol(),
                    mol.GetAtomWithIdx(end_idx).GetSymbol(),
                )
        return table

    reactant_bonds = mapped_bonds(reactants, reactant_to_map)
    product_bonds = mapped_bonds(products, product_to_map)

    changes: List[BondChange] = []

    # Broken bonds: present in the reactants, absent from the products.
    for bond_key, (order, sym1, sym2) in reactant_bonds.items():
        if bond_key in product_bonds:
            continue
        map1, map2 = sorted(bond_key)
        energy = get_bond_energy(sym1, sym2, order)
        changes.append(
            BondChange(
                atom1_map=map1,
                atom2_map=map2,
                change_type=BondChangeType.BROKEN,
                old_order=order,
                new_order=None,
                energy_cost=energy,
            )
        )

    # Formed bonds: present in the products, absent from the reactants.
    for bond_key, (order, sym1, sym2) in product_bonds.items():
        if bond_key in reactant_bonds:
            continue
        map1, map2 = sorted(bond_key)
        energy = get_bond_energy(sym1, sym2, order)
        changes.append(
            BondChange(
                atom1_map=map1,
                atom2_map=map2,
                change_type=BondChangeType.FORMED,
                old_order=None,
                new_order=order,
                energy_cost=energy,
            )
        )

    # Order changes: the bond exists on both sides with a different order.
    for bond_key in reactant_bonds.keys() & product_bonds.keys():
        r_order, r_sym1, r_sym2 = reactant_bonds[bond_key]
        p_order = product_bonds[bond_key][0]

        # Orders are floats (e.g. 1.5 for aromatic), hence the tolerance.
        if abs(r_order - p_order) <= 0.1:
            continue

        map1, map2 = sorted(bond_key)
        old_energy = get_bond_energy(r_sym1, r_sym2, r_order)
        new_energy = get_bond_energy(r_sym1, r_sym2, p_order)
        changes.append(
            BondChange(
                atom1_map=map1,
                atom2_map=map2,
                change_type=BondChangeType.ORDER_CHANGE,
                old_order=r_order,
                new_order=p_order,
                energy_cost=abs(new_energy - old_energy),
            )
        )

    return changes

score_mapping(atom_mapped_rxn_smiles)

Compute comprehensive score for an atom-mapped reaction SMILES.

Parameters:

Name Type Description Default
atom_mapped_rxn_smiles str

Atom-mapped reaction SMILES string (e.g., "[CH3:1][OH:2]>>[CH3:1][O:2][H:3]")

required

Returns:

Type Description
MappingScore

MappingScore object with all metrics

Source code in agave_chem/scoring/scoring.py
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
def score_mapping(
    self,
    atom_mapped_rxn_smiles: str,
) -> MappingScore:
    """
    Compute comprehensive score for an atom-mapped reaction SMILES.

    The reaction is parsed, its bond changes are computed once, and every
    metric is then derived from the parsed molecules, the atom mapping and
    those bond changes.

    Args:
        atom_mapped_rxn_smiles: Atom-mapped reaction SMILES string
            (e.g., "[CH3:1][OH:2]>>[CH3:1][O:2][H:3]")

    Returns:
        MappingScore object with all metrics
    """
    reactants, products, mapping = self._parse_mapped_reaction_smiles(
        atom_mapped_rxn_smiles
    )
    bond_changes = self.compute_bond_changes(reactants, products, mapping)

    # Tally the bond changes by type in a single pass.
    num_formed = 0
    num_broken = 0
    num_order_changes = 0
    for bc in bond_changes:
        if bc.change_type == BondChangeType.FORMED:
            num_formed += 1
        elif bc.change_type == BondChangeType.BROKEN:
            num_broken += 1
        elif bc.change_type == BondChangeType.ORDER_CHANGE:
            num_order_changes += 1

    # Total energy cost over every bond change.
    energy_cost = sum(bc.energy_cost for bc in bond_changes)

    # Remaining metrics, each computed by its own helper.
    num_fragments = self._count_fragment_changes(reactants, products, mapping)
    stereo_changes = self._count_stereo_changes(reactants, products, mapping)
    ring_changes = self._count_ring_changes(
        reactants, products, mapping, bond_changes
    )
    similarity = self._calculate_similarity(reactants, products, mapping)

    total_bond_changes = num_formed + num_broken + num_order_changes
    return MappingScore(
        bond_energy_cost=energy_cost,
        num_bond_changes=total_bond_changes,
        num_bonds_formed=num_formed,
        num_bonds_broken=num_broken,
        num_fragments=num_fragments,
        stereo_changes=stereo_changes,
        similarity_score=similarity,
        ring_changes=ring_changes,
    )

NeuralReactionMapper

Bases: ReactionMapper

Neural network-based reaction atom-mapping

Source code in agave_chem/mappers/neural/neural_mapper.py
  78
  79
  80
  81
  82
  83
  84
  85
  86
  87
  88
  89
  90
  91
  92
  93
  94
  95
  96
  97
  98
  99
 100
 101
 102
 103
 104
 105
 106
 107
 108
 109
 110
 111
 112
 113
 114
 115
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
class NeuralReactionMapper(ReactionMapper):
    """
    Neural network-based reaction atom-mapping
    """

    def __init__(
        self,
        mapper_name: str,
        mapper_weight: float = 3,
        checkpoint_path: Optional[str] = None,
        use_supervised: bool = True,
        supervised_config: SupervisedConfig | None = None,
        sequence_max_length: int = 512,
    ):
        """
        Set up the neural mapper: choose a device, load the model and build
        the tokenizer.

        Args:
            mapper_name (str): The name of the mapper.
            mapper_weight (float): The weight of the mapper.
            checkpoint_path (Optional[str]): Checkpoint directory passed to
                ``load_neural_albert_model``. When falsy (``None`` or an empty
                string) the ``supervised_albert_model`` entry resolved through
                ``files("agave_chem.datafiles.models")`` is used instead.
            use_supervised (bool): Stored on the instance and forwarded
                unchanged to ``load_neural_albert_model``.
            supervised_config (SupervisedConfig | None): Forwarded to
                ``load_neural_albert_model``; a default ``SupervisedConfig()``
                is created when nothing (or a falsy value) is given.
            sequence_max_length (int): Stored on the instance and forwarded to
                ``load_neural_albert_model`` as ``max_length``.
        """
        super().__init__("neural", mapper_name, mapper_weight)

        # Plain configuration first; none of these depend on each other.
        self._sequence_max_length = sequence_max_length
        self._use_supervised = use_supervised
        self._supervised_config = supervised_config or SupervisedConfig()

        # Prefer CUDA when torch reports it as available.
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Fall back to the default checkpoint location when none was supplied.
        resolved_checkpoint = checkpoint_path or str(
            files("agave_chem.datafiles.models").joinpath("supervised_albert_model")
        )

        self._model = load_neural_albert_model(
            checkpoint_dir=resolved_checkpoint,
            device=self._device,
            use_supervised=use_supervised,
            max_length=sequence_max_length,
            supervised_config=self._supervised_config,
        )

        self._tokenizer = CustomTokenizer(smiles_token_to_id_dict)

    def _encode_atom(self, atom: Chem.Atom) -> List[int]:
        """
        Turn an RDKit atom into a fixed-length list of six integers.

        The features appear in this order:
        1. atomic number (``GetAtomicNum``)
        2. formal charge (``GetFormalCharge``)
        3. aromatic flag, 1 or 0 (``GetIsAromatic``)
        4. ring-membership flag, 1 or 0 (``IsInRing``)
        5. total hydrogen count (``GetTotalNumHs``)
        6. degree (``GetDegree``)

        Args:
            atom (Chem.Atom): The RDKit Atom object to encode.

        Returns:
            List[int]: ``[z, charge, aromatic, in_ring, num_hs, degree]``.
        """
        return [
            atom.GetAtomicNum(),
            atom.GetFormalCharge(),
            int(bool(atom.GetIsAromatic())),  # normalise truthy flag to 0/1
            int(bool(atom.IsInRing())),  # normalise truthy flag to 0/1
            atom.GetTotalNumHs(),
            atom.GetDegree(),
        ]

    def get_attention_matrix_for_head(
        self,
        text: str,
        layer: int,
        head: int,
        max_length: int = 512,
        trim_padding: bool = True,
    ) -> Tuple[np.ndarray, List[str]]:
        """
        Compute the log-attention matrix for a single input string.

        When the loaded model is an ``AlbertWithAttentionAlignment`` the matrix
        comes from ``predict_attention_probs`` and ``layer`` / ``head`` are not
        used (and not validated). For any other model the attention of the
        requested layer and head is taken from a forward pass run with
        ``output_attentions=True``.

        Args:
            text: input reaction SMILES string (passed as-is to the tokenizer)
            layer: 0-based layer index (non-supervised models only)
            head: 0-based head index (non-supervised models only)
            max_length: tokenization length; the input is padded/truncated to it
            trim_padding: if True, keep only the first ``k`` rows/columns and
                tokens, where ``k`` is the sum of the attention mask

        Returns:
            attn: NumPy array holding the element-wise natural log of the
                attention values, shape (seq_len, seq_len) (trimmed if requested)
            tokens: list[str] tokens aligned to attn axes (trimmed if requested)

        Raises:
            ValueError: if ``layer`` or ``head`` is out of range for a
                non-supervised model.
        """
        self._model.eval()

        encoded = self._tokenizer(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        ids_on_device = encoded["input_ids"].to(self._device)
        mask_on_device = encoded["attention_mask"].to(self._device)
        # Tokenizers that do not emit segment ids get an all-zero tensor.
        segment_ids = encoded.get(
            "token_type_ids", torch.zeros_like(encoded["input_ids"])
        ).to(self._device)

        with torch.no_grad():
            if isinstance(self._model, AlbertWithAttentionAlignment):
                batch_probs = self._model.predict_attention_probs(
                    input_ids=ids_on_device,
                    attention_mask=mask_on_device,
                    token_type_ids=segment_ids,
                )  # (B,S,S)
                attn = batch_probs[0].detach().cpu()  # (S,S)
            else:
                model_out = self._model(
                    input_ids=ids_on_device,
                    attention_mask=mask_on_device,
                    output_attentions=True,
                    return_dict=True,
                )
                per_layer = model_out.attentions  # tuple[num_layers] of (B,H,S,S)

                layer_count = len(per_layer)
                if layer < 0 or layer >= layer_count:
                    raise ValueError(
                        f"layer must be in [0, {layer_count - 1}], got {layer}"
                    )

                head_count = per_layer[layer].shape[1]
                if head < 0 or head >= head_count:
                    raise ValueError(
                        f"head must be in [0, {head_count - 1}], got {head}"
                    )

                attn = per_layer[layer][0, head].detach().cpu()  # (S,S)

        # Token strings for the axes (useful for inspection/plotting).
        tokens = self._tokenizer.convert_ids_to_tokens(
            encoded["input_ids"][0].tolist()
        )

        if trim_padding:
            # Number of unmasked positions; slicing a prefix presumes padding
            # sits at the end of the sequence.
            n_real = int(encoded["attention_mask"][0].sum().item())
            attn = attn[:n_real, :n_real]
            tokens = tokens[:n_real]

        # Downstream code expects log-attention, so the log is taken here.
        return torch.log(attn).numpy(), tokens

    def get_reactants_products_dict(
        self,
        tokens: List[str],
    ) -> StringInfoDict:
        """
        Split a tokenized reaction SMILES into reactant/product sides and
        classify every token as atom or non-atom.

        Tokens before the first ``">>"`` belong to the reactant side, tokens
        after it to the product side. ``">>"`` itself is recorded as a non-atom
        token and is placed on neither side. A token counts as an atom when
        ``token_atom_identity_dict`` maps it to a non-zero value.

        Args:
            tokens: A list of tokens in a reaction SMILES string.

        Returns:
            A ``StringInfoDict`` with the keys:
                reactants_dict: token index -> token string, reactant side.
                products_dict: token index -> token string, product side.
                reactants_start_index: always 0.
                reactants_end_index: largest reactant-side token index.
                products_start_index: smallest product-side token index.
                products_end_index: largest product-side token index.
                atom_tokens_dict: atom identity -> list of token indices
                    (both sides together).
                non_atom_tokens: indices of all non-atom tokens, ``">>"``
                    included.

        Raises:
            ValueError: if either side of the reaction has no tokens (for
                example when ``">>"`` is missing).
        """
        reactants_dict: Dict[int, str] = {}
        products_dict: Dict[int, str] = {}
        atom_tokens_dict: Dict[int, List[int]] = {}
        non_atom_tokens: List[int] = []

        found_reaction_symbol = False
        for i, token in enumerate(tokens):
            if token == ">>":
                found_reaction_symbol = True
                non_atom_tokens.append(i)
                continue

            # Look the identity up once; 0 (or a missing entry) means non-atom.
            atom_identity = token_atom_identity_dict.get(token, 0)
            if atom_identity == 0:
                non_atom_tokens.append(i)
            else:
                atom_tokens_dict.setdefault(atom_identity, []).append(i)

            if found_reaction_symbol:
                products_dict[i] = token
            else:
                reactants_dict[i] = token

        # Fail with a readable message instead of max()/min() complaining
        # about an empty sequence.
        if not reactants_dict or not products_dict:
            raise ValueError(
                "tokens must contain at least one token on each side of '>>'"
            )

        string_info_dict: StringInfoDict = {
            "reactants_dict": reactants_dict,
            "products_dict": products_dict,
            "reactants_start_index": 0,
            "reactants_end_index": max(reactants_dict),
            "products_start_index": min(products_dict),
            "products_end_index": max(products_dict),
            "atom_tokens_dict": atom_tokens_dict,
            "non_atom_tokens": non_atom_tokens,
        }

        return string_info_dict

    def mask_attn_matrix(
        self,
        attn: np.ndarray,
        string_info_dict: StringInfoDict,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Masks the attention matrix to set the attention probability for certain tokens to 0.

        Args:
            attn: The attention matrix to be masked.
            reactants_start_index: The index of the first reactant token.
            reactants_end_index: The index of the last reactant token.
            products_start_index: The index of the first product token.
            products_end_index: The index of the last product token.
            non_atom_tokens: A list of indices of non-atom tokens.
            atom_tokens_dict: A dictionary mapping atom numbers to a list of token indices.

        Returns:
            The masked attention matrix.
        """
        attn[
            string_info_dict["reactants_start_index"] : string_info_dict[
                "products_start_index"
            ]
            - 1,
            string_info_dict["reactants_start_index"] : string_info_dict[
                "products_start_index"
            ]
            - 1,
        ] = -1e6  # Set attention logits for reactant tokens to other reactant tokens to very small value
        attn[
            string_info_dict["products_start_index"] : string_info_dict[
                "products_end_index"
            ]
            + 1,
            string_info_dict["products_start_index"] : string_info_dict[
                "products_end_index"
            ]
            + 1,
        ] = -1e6  # Set attention logits for product tokens to other product tokens to very small value
        for i in string_info_dict[
            "non_atom_tokens"
        ][
            :-1
        ]:  # Set attention logits for reactant or product tokens to non-atom tokens to very small value
            attn[i] = -1e6
            attn[:, i] = -1e6

        for token_indices in string_info_dict[
            "atom_tokens_dict"
        ].values():  # Set attention logits for reactant and product tokens of different atom numbers to very small value
            idx = np.asarray(token_indices, dtype=np.int64)
            last = attn.shape[0] - 1
            idx = idx[idx != last]  # protect last row/column from mask

            diff_atom_mask = np.ones(attn.shape[1], dtype=bool)
            diff_atom_mask[idx] = False
            diff_atom_mask[last] = False  # protect last row/column from mask

            attn[np.ix_(idx, diff_atom_mask)] = -1e6
            attn[np.ix_(diff_atom_mask, idx)] = -1e6

        row_max = np.max(attn, axis=1, keepdims=True)  # max per row
        exp_logits = np.exp(attn - row_max)
        probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)

        probs[
            string_info_dict["reactants_start_index"] : string_info_dict[
                "products_start_index"
            ]
            - 1,
            string_info_dict["reactants_start_index"] : string_info_dict[
                "products_start_index"
            ]
            - 1,
        ] = 0  # Set attention probability for reactant tokens to other reactant tokens to 0
        probs[
            string_info_dict["products_start_index"] : string_info_dict[
                "products_end_index"
            ]
            + 1,
            string_info_dict["products_start_index"] : string_info_dict[
                "products_end_index"
            ]
            + 1,
        ] = 0  # Set attention probability for product tokens to other product tokens to 0
        for i in string_info_dict[
            "non_atom_tokens"
        ][
            :-1
        ]:  # Set attention probability for reactant or product tokens to non-atom tokens to 0
            probs[i] = 0
            probs[:, i] = 0

        for token_indices in string_info_dict[
            "atom_tokens_dict"
        ].values():  # Set attention probability for reactant and product tokens of different atom numbers to 0
            idx = np.asarray(token_indices, dtype=np.int64)

            diff_atom_mask = np.ones(probs.shape[1], dtype=bool)
            diff_atom_mask[idx] = False
            probs[np.ix_(idx, diff_atom_mask)] = 0
            probs[np.ix_(diff_atom_mask, idx)] = 0

        return probs, exp_logits

    def get_aligned_attn_scores(
        self,
        out: np.ndarray,
        reactants_start_index: int,
        reactants_end_index: int,
        products_start_index: int,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Extract and align cross-attention scores between reactant and product tokens.

        Slices the full attention matrix `out` to obtain two cross-attention
        sub-matrices: one representing how each product token attends to reactant
        tokens, and one representing how each reactant token attends to product
        tokens. The latter is transposed so that both returned arrays share the
        same index orientation (rows = product tokens, columns = reactant tokens).

        Args:
            out (np.ndarray): Square attention probability matrix of shape
                ``(sequence_length, sequence_length)``, where entry ``[i, j]``
                is the attention weight from token ``i`` to token ``j``.
            reactants_start_index (int): Index of the first reactant atom token
                in the sequence.
            reactants_end_index (int): Index of the last reactant atom token
                in the sequence (inclusive).
            products_start_index (int): Index of the first product token in the
                sequence; all tokens from this index onward are product tokens.

        Returns:
            Tuple[np.ndarray, np.ndarray]:
                - **products_to_reactants_attn** (np.ndarray): Sub-matrix of shape
                  ``(n_product_tokens, n_reactant_tokens)`` giving the attention
                  weights from each product token to each reactant token.
                - **reactants_to_products_attn** (np.ndarray): Transposed sub-matrix
                  of shape ``(n_product_tokens, n_reactant_tokens)`` giving the
                  attention weights from each reactant token to each product token,
                  transposed so that rows correspond to product tokens and columns
                  to reactant tokens, matching the orientation of
                  ``products_to_reactants_attn``.
        """
        products_to_reactants_attn = out[
            products_start_index:,
            reactants_start_index : reactants_end_index + 1,
        ]  # products to reactants attention
        reactants_to_products_attn = out[
            reactants_start_index : reactants_end_index + 1, products_start_index:
        ].T  # reactants to products attention, transposed so indices align
        return products_to_reactants_attn, reactants_to_products_attn

    def remove_non_atom_rows_and_columns(
        self, attn: np.ndarray, string_info_dict: StringInfoDict
    ) -> np.ndarray:
        """
        Remove non-atom tokens from attention matrix.

        Args:
            attn (np.ndarray): The attention matrix.
            string_info_dict (StringInfoDict): A dictionary containing information about the tokens in the reaction SMILES string.

        Returns:
            np.ndarray: The attention matrix with non-atom tokens removed.
        """
        reactants_non_atom_tokens = [
            ele
            for ele in string_info_dict["non_atom_tokens"]
            if ele <= string_info_dict["reactants_end_index"]
        ]  # Get non-atom tokens in reactants
        products_non_atom_tokens = [
            ele - string_info_dict["products_start_index"]
            for ele in string_info_dict["non_atom_tokens"]
            if ele >= string_info_dict["products_start_index"]
        ]  # Get non-atom tokens in products with offset of products start index

        idx = np.asarray(reactants_non_atom_tokens, dtype=int)
        attn = np.delete(attn, idx, axis=1)

        idx = np.asarray(products_non_atom_tokens, dtype=int)
        attn = np.delete(attn, idx, axis=0)

        return attn

    def get_duplicate_indices(
        self, list_of_lists: List[List[int]]
    ) -> Dict[int, List[int]]:
        """
        Find indices of duplicate values across a list of sublists using globally
        offset indices.

        For each element, returns a mapping to all other elements in the same
        sublist that share the same value. Elements without duplicates are omitted.

        Args:
            list_of_lists (List[List[int]]): A list of sublists, where each sublist
                contains integer values (e.g., canonical atom ranks per molecule).

        Returns:
            Dict[int, List[int]]: A dictionary mapping each globally-offset index
                to a list of other globally-offset indices within the same sublist
                that share the same value. Only indices with at least one duplicate
                are included.
        """
        result = {}
        offset = 0

        for sublist in list_of_lists:
            # Group flattened indices by value within this sublist
            value_to_indices = defaultdict(list)
            for i, val in enumerate(sublist):
                value_to_indices[val].append(offset + i)

            # For each item, map to OTHER items with same value in the same sublist
            for i, val in enumerate(sublist):
                flat_idx = offset + i
                others = [idx for idx in value_to_indices[val] if idx != flat_idx]
                if others:  # only include entries that actually have duplicates
                    result[flat_idx] = others

            offset += len(sublist)

        return result

    def _build_atom_dict(
        self, mols: List[Chem.Mol]
    ) -> Tuple[Dict[int, Chem.Atom], Dict[int, List[Tuple[int, List[int]]]]]:
        """
        Number the atoms of several molecules consecutively and record each
        atom's neighbours under that numbering.

        Atoms are visited molecule by molecule in ``GetAtoms()`` order; the
        first atom of the first molecule gets index 0 and the counter keeps
        running across molecules. For every atom the neighbours are stored as
        ``(global_neighbor_index, self._encode_atom(neighbor))``.

        Args:
            mols (List[Chem.Mol]): A list of RDKit molecule objects to process.

        Returns:
            Tuple[Dict[int, Chem.Atom], Dict[int, List[Tuple[int, List[int]]]]]:
                - First dict: global atom index -> RDKit Atom object.
                - Second dict: global atom index -> list of
                  (global_neighbor_index, encoded_neighbor) tuples, one per
                  neighbouring atom.
        """
        atom_dict: Dict[int, Chem.Atom] = {}
        atom_dict_neighbors: Dict[int, List[Tuple[int, List[int]]]] = {}
        next_global_idx = 0

        for mol in mols:
            # Pass 1: hand out global indices and remember how the molecule's
            # own atom indices translate to them.
            global_idx_by_local: Dict[int, int] = {}
            numbered_atoms: List[Tuple[int, Chem.Atom]] = []
            for atom in mol.GetAtoms():
                global_idx_by_local[atom.GetIdx()] = next_global_idx
                numbered_atoms.append((next_global_idx, atom))
                next_global_idx += 1

            # Pass 2: neighbours can only be translated once every atom of the
            # molecule has its global index.
            for global_idx, atom in numbered_atoms:
                neighbor_entries: List[Tuple[int, List[int]]] = []
                for neighbor in atom.GetNeighbors():
                    neighbor_entries.append(
                        (
                            global_idx_by_local[neighbor.GetIdx()],
                            self._encode_atom(neighbor),
                        )
                    )
                atom_dict_neighbors[global_idx] = neighbor_entries
                atom_dict[global_idx] = atom

        return atom_dict, atom_dict_neighbors

    def _get_symmetric_atom_indices(self, mols: List[Chem.Mol]) -> Dict[int, List[int]]:
        """
        Identify sets of equivalent atoms across a list of molecules.

        Each molecule's per-atom symmetry classes come from
        ``get_symmetry_class_from_mol``. The classes of the ``i``-th molecule
        are shifted by ``(i + 1) * 1000`` so that classes of different
        molecules do not coincide. A molecule whose ``Chem.MolToSmiles`` output
        was already seen reuses the (shifted) class list of the first such
        molecule, which makes atoms at the same position in the two molecules
        count as equivalent. All class lists are then concatenated and atoms
        sharing a class value are reported as symmetric partners, so partners
        can be found both within one molecule and across identical molecules.

        Args:
            mols (List[Chem.Mol]): A list of RDKit molecule objects.

        Returns:
            Dict[int, List[int]]: A mapping from each global atom index
                (atoms numbered consecutively across ``mols``) to the other
                global atom indices with the same class. Only atoms with at
                least one partner are included.
        """
        ranks: List[int] = []
        classes_by_smiles: Dict[str, List[int]] = {}
        for i, mol in enumerate(mols):
            # Computed once per molecule and used as the cache key.
            smiles = Chem.MolToSmiles(mol)

            mol_symmetry_classes = classes_by_smiles.get(smiles)
            if mol_symmetry_classes is None:
                # The shift keeps molecules apart only while every class id is
                # in [0, 1000).
                # NOTE(review): confirm get_symmetry_class_from_mol stays in
                # that range for all supported inputs.
                mol_symmetry_classes = [
                    ele + (i + 1) * 1000 for ele in get_symmetry_class_from_mol(mol)
                ]
                classes_by_smiles[smiles] = mol_symmetry_classes
            # NOTE(review): reusing a cached list pairs atoms by position, which
            # presumes molecules with equal SMILES also share atom ordering —
            # confirm the inputs are canonicalised upstream.

            ranks.extend(mol_symmetry_classes)
        return self.get_duplicate_indices([ranks])

    def _apply_symmetric_attention(
        self,
        attn: np.ndarray,
        symmetric_indices: Dict[int, List[int]],
        axis: int,
    ) -> np.ndarray:
        """
        Sum attention scores for topologically equivalent (symmetric) atoms.

        For each group of symmetric atoms, replaces each member's attention slice
        (row or column) with the summed values across the group. This prevents
        symmetric atoms from receiving artificially low attention scores caused by
        probability mass being split equally among equivalent positions.

        Sums are computed from the input array before any modifications are applied,
        ensuring groups do not double-count one another.

        Args:
            attn (np.ndarray): Attention matrix of shape
                (n_product_atoms, n_reactant_atoms).
            symmetric_indices (Dict[int, List[int]]): Output of
                _get_symmetric_atom_indices — maps each atom index to its
                symmetric partners.
            axis (int): Axis along which to aggregate. Use 1 for reactant atoms
                (columns) and 0 for product atoms (rows).

        Returns:
            np.ndarray: A copy of attn with symmetric atom slices replaced by
                their group sum. The original array is not modified.
        """
        identical_groups: List[Tuple[int, ...]] = list(
            {tuple(sorted([k] + v)) for k, v in symmetric_indices.items()}
        )

        result = attn.copy()
        new_val_mapping: Dict[int, np.ndarray] = {}
        for group in identical_groups:
            idx = list(group)
            if axis == 1:
                summed = np.sum(attn[:, idx], axis=1)
                for i in idx:
                    new_val_mapping[i] = summed
            else:
                summed = np.sum(attn[idx, :], axis=0)
                for i in idx:
                    new_val_mapping[i] = summed

        for i, val in new_val_mapping.items():
            if axis == 1:
                result[:, i] = val
            else:
                result[i, :] = val

        return result

    def assign_atom_maps(
        self,
        rxn_smiles: str,
        aligned_attn_scores: Tuple[np.ndarray, np.ndarray],
        one_to_one_correspondence: bool = True,
        adjacent_atom_multiplier: float = 30,
        identical_adjacent_atom_multiplier: float = 10,
        used_atom_divisor: float = 10,
        reactants_atom_idx_to_orig_mapping: Optional[Dict[int, int]] = None,
        products_atom_idx_to_orig_mapping: Optional[Dict[int, int]] = None,
    ) -> Tuple[str, float, Dict[str, int]]:
        """
        Assign atom-to-atom map numbers to a reaction SMILES using pre-computed
        attention matrices.

        One map number is handed out per row of the attention matrix, starting
        at 1. If the map number already exists in the supplied product
        pre-mapping, that product/reactant pair is used as an anchor.
        Otherwise the pair holding the current global maximum of the working
        attention matrix is chosen. After every assignment the attention
        between the neighbours of the two chosen atoms is multiplied up, so
        the mapping tends to grow along bonds.

        Confidence values are read from a symmetry-aggregated copy of the
        attention (see ``_apply_symmetric_attention``), so equivalent atoms do
        not dilute each other's scores.

        Args:
            rxn_smiles (str): Unmapped reaction SMILES string of the form
                "reactants>>products".
            aligned_attn_scores (Tuple[np.ndarray, np.ndarray]): Two matrices
                of shape (n_product_atoms, n_reactant_atoms), unpacked as
                ``(reactants_to_products_attn, products_to_reactants_attn)``.
            one_to_one_correspondence (bool): If True, the average of both
                matrices is used and the chosen row and column are zeroed after
                each assignment. If False, only the products-to-reactants
                matrix is used and the chosen column is divided by
                ``used_atom_divisor`` instead of being zeroed, so a reactant
                atom can be chosen again.
            adjacent_atom_multiplier (float): Multiplier applied to attention
                scores of atoms neighboring an already-mapped pair.
            identical_adjacent_atom_multiplier (float): Additional multiplier
                applied when a neighboring pair shares the same atom encoding.
            used_atom_divisor (float): Divisor applied to the attention column
                of a reactant atom once it has been chosen, when
                one_to_one_correspondence is False.
            reactants_atom_idx_to_orig_mapping (Optional[Dict[int, int]]): Maps
                global reactant atom indices to existing atom map numbers
                (0 = unmapped), used to anchor partially pre-mapped reactions.
            products_atom_idx_to_orig_mapping (Optional[Dict[int, int]]): Maps
                global product atom indices to existing atom map numbers
                (0 = unmapped), used to anchor partially pre-mapped reactions.

        Returns:
            Tuple[str, float, Dict[str, int]]:
                - Mapped reaction SMILES string with atom map numbers assigned.
                - Confidence score: the product of the per-assignment
                  probabilities (1.0 for anchored pairs).
                - Dictionary mapping the SMILES (atom maps removed) of each
                  reactant that had an atom chosen more than once to the
                  largest number of re-assignments seen on any of its atoms.
                  Always empty when one_to_one_correspondence is True.

        Raises:
            KeyError: if a map number present in the product pre-mapping is
                missing from the reactant pre-mapping.
        """
        if not reactants_atom_idx_to_orig_mapping:
            reactants_atom_idx_to_orig_mapping = {}
        if not products_atom_idx_to_orig_mapping:
            products_atom_idx_to_orig_mapping = {}

        reactants_str, products_str = self._split_reaction_components(rxn_smiles)
        reactants_mols = [
            Chem.MolFromSmiles(reactant) for reactant in reactants_str.split(".")
        ]
        products_mols = [
            Chem.MolFromSmiles(product) for product in products_str.split(".")
        ]

        reactants_atom_dict, reactants_atom_dict_neighbors = self._build_atom_dict(
            reactants_mols
        )
        products_atom_dict, products_atom_dict_neighbors = self._build_atom_dict(
            products_mols
        )

        # Invert the pre-mappings: existing map number -> global atom index.
        products_orig_mapping_to_idx = {
            value: key
            for key, value in products_atom_idx_to_orig_mapping.items()
            if value != 0
        }
        reactants_orig_mapping_to_idx = {
            value: key
            for key, value in reactants_atom_idx_to_orig_mapping.items()
            if value != 0
        }

        # NOTE(review): the tuple is unpacked as (reactants_to_products,
        # products_to_reactants), whereas get_aligned_attn_scores returns
        # (products_to_reactants, reactants_to_products) — confirm the caller
        # builds the tuple in the order expected here.
        (reactants_to_products_attn, products_to_reactants_attn) = aligned_attn_scores

        reactants_symmetric_indices = self._get_symmetric_atom_indices(reactants_mols)
        products_symmetric_indices = self._get_symmetric_atom_indices(products_mols)

        # Symmetry-aggregated copies, used only to read confidence values.
        orig_products_to_reactants_attn = self._apply_symmetric_attention(
            products_to_reactants_attn.copy(), reactants_symmetric_indices, axis=1
        )
        orig_reactants_to_products_attn = self._apply_symmetric_attention(
            reactants_to_products_attn.copy(), products_symmetric_indices, axis=0
        )

        ## If not a one-to-one correspondence, multiple product atoms could map to the same
        ## reactant atom. That reactant atom signal would be split, producing inaccurate
        ## assignment probabilities. So we use only product to reactant attention
        if one_to_one_correspondence:
            orig_attn = (
                orig_reactants_to_products_attn + orig_products_to_reactants_attn
            ) / 2
            attn = (reactants_to_products_attn + products_to_reactants_attn) / 2
        else:
            orig_attn = orig_products_to_reactants_attn
            # Working matrix is mutated below, so it must not alias the input.
            attn = products_to_reactants_attn.copy()

        def _assign_pair(product_idx: int, reactant_idx: int, map_number: int) -> None:
            """Write map_number on both atoms, counting reactant re-use."""
            reactant_atom = reactants_atom_dict[reactant_idx]
            if reactant_atom.GetAtomMapNum():
                # The reactant atom was already chosen before: bump its counter.
                previous = (
                    reactant_atom.GetIntProp("oversubscribed_count")
                    if reactant_atom.HasProp("oversubscribed_count")
                    else 0
                )
                reactant_atom.SetIntProp("oversubscribed_count", previous + 1)
            products_atom_dict[product_idx].SetAtomMapNum(map_number)
            reactant_atom.SetAtomMapNum(map_number)

        identical_neighbor_boost = (
            adjacent_atom_multiplier * identical_adjacent_atom_multiplier
        )

        assignment_probs = []
        for map_num in range(attn.shape[0]):
            map_number = map_num + 1

            # Membership test rather than truthiness of the looked-up index:
            # global atom index 0 is a valid anchor.
            if map_number in products_orig_mapping_to_idx:
                row_highest_attn = products_orig_mapping_to_idx[map_number]
                col_highest_attn = reactants_orig_mapping_to_idx[map_number]

                _assign_pair(row_highest_attn, col_highest_attn, map_number)

                attn[row_highest_attn] = 0
                attn[:, col_highest_attn] = 0
                assignment_probs.append(1.0)
            else:
                # Greedy choice: first position (row-major) of the global max.
                highest_attn_score = attn.max()
                max_rows, max_cols = np.where(attn == highest_attn_score)
                row_highest_attn = max_rows[0]
                col_highest_attn = max_cols[0]

                _assign_pair(row_highest_attn, col_highest_attn, map_number)

                attn[row_highest_attn] = 0
                if one_to_one_correspondence:
                    attn[:, col_highest_attn] = 0
                else:
                    # Keep the reactant atom available, but less attractive.
                    attn[:, col_highest_attn] /= used_atom_divisor

                assignment_probs.append(orig_attn[row_highest_attn, col_highest_attn])

            # Boost every neighbour pair of the atoms that were just mapped;
            # pairs with identical encodings get the larger boost.
            for (
                product_atom_idx,
                product_atom_env,
            ) in products_atom_dict_neighbors[row_highest_attn]:
                for (
                    reactant_atom_idx,
                    reactant_atom_env,
                ) in reactants_atom_dict_neighbors[col_highest_attn]:
                    if product_atom_env == reactant_atom_env:
                        attn[product_atom_idx, reactant_atom_idx] *= (
                            identical_neighbor_boost
                        )
                    else:
                        attn[product_atom_idx, reactant_atom_idx] *= (
                            adjacent_atom_multiplier
                        )

        mapped_reactants_str = ".".join(
            [Chem.MolToSmiles(reactant, canonical=False) for reactant in reactants_mols]
        )
        mapped_products_str = ".".join(
            [Chem.MolToSmiles(product, canonical=False) for product in products_mols]
        )
        mapped_rxn_smiles = mapped_reactants_str + ">>" + mapped_products_str

        confidence = float(np.prod(assignment_probs))

        if one_to_one_correspondence:
            return mapped_rxn_smiles, confidence, {}

        oversubscribed_dict: Dict[str, int] = {}
        for reactant in reactants_mols:
            max_oversubscribed_count = max(
                (
                    reactant_atom.GetIntProp("oversubscribed_count")
                    for reactant_atom in reactant.GetAtoms()
                    if reactant_atom.HasProp("oversubscribed_count")
                ),
                default=0,
            )
            if max_oversubscribed_count == 0:
                continue
            # Strip map numbers so the dictionary key is the unmapped SMILES.
            # The mapped reaction string was already built above.
            for reactant_atom in reactant.GetAtoms():
                reactant_atom.SetAtomMapNum(0)
            oversubscribed_dict[Chem.MolToSmiles(reactant)] = max_oversubscribed_count

        return mapped_rxn_smiles, confidence, oversubscribed_dict

    def get_data_from_partially_mapped_smiles(
        self, rxn_smiles: str
    ) -> Tuple[str, Dict[int, int], Dict[int, int]]:
        """
        Strip existing atom map numbers from a reaction SMILES and record them.

        Each side of the reaction is split on "." and every fragment is parsed
        with Chem.MolFromSmiles. Atoms are numbered with a running index that
        continues across the fragments of one side (fragment order, then atom
        order within the fragment). The map number reported by GetAtomMapNum is
        stored under that index and the atom's map number is then reset to 0.

        NOTE: fragments are not checked for parse failures here; map_reactions
        validates the reaction before calling this method.

        Args:
            rxn_smiles (str): A reaction SMILES string that may carry atom map
                numbers.

        Returns:
            Tuple[str, Dict[int, int], Dict[int, int]]:
                - The reaction SMILES with all atom map numbers cleared, written
                  with canonical=False.
                - Running reactant atom index -> atom map number found in the
                  input.
                - Running product atom index -> atom map number found in the
                  input.
        """
        reactants_str, products_str = self._split_reaction_components(rxn_smiles)

        def _clear_atom_maps(side_smiles: str) -> Tuple[List[str], Dict[int, int]]:
            """Parse one reaction side, clear its atom maps, return (SMILES, maps)."""
            mols = [Chem.MolFromSmiles(frag) for frag in side_smiles.split(".")]
            idx_to_orig_mapping: Dict[int, int] = {}
            atom_num = 0  # running index shared by all fragments of this side
            for mol in mols:
                for atom in mol.GetAtoms():
                    # Record the map number before wiping it from the atom.
                    idx_to_orig_mapping[atom_num] = atom.GetAtomMapNum()
                    atom.SetAtomMapNum(0)
                    atom_num += 1
            unmapped_strings = [Chem.MolToSmiles(mol, canonical=False) for mol in mols]
            return unmapped_strings, idx_to_orig_mapping

        unmapped_reactants_strings, reactants_atom_idx_to_orig_mapping = (
            _clear_atom_maps(reactants_str)
        )
        unmapped_products_strings, products_atom_idx_to_orig_mapping = (
            _clear_atom_maps(products_str)
        )

        unmapped_rxn = (
            ".".join(unmapped_reactants_strings)
            + ">>"
            + ".".join(unmapped_products_strings)
        )

        return (
            unmapped_rxn,
            reactants_atom_idx_to_orig_mapping,
            products_atom_idx_to_orig_mapping,
        )

    def _get_attention_matrices_batch(
        self,
        texts: List[str],
        layer: int = 11,
        head: int = 7,
        max_length: int = 512,
    ) -> List[Tuple[np.ndarray, List[str]]]:
        """
        Compute log-attention matrices for a batch of reaction SMILES strings.

        All inputs are tokenized together (padded to max_length, truncated if
        longer), the model is run once under torch.no_grad(), and each per-input
        result is cut down to the number of positions flagged in its attention
        mask before the element-wise logarithm is taken.

        Args:
            texts (List[str]): Reaction SMILES strings to encode. Must be non-empty.
            layer (int): 0-based layer index. Consulted only when the model is not
                an AlbertWithAttentionAlignment instance.
            head (int): 0-based head index. Consulted only when the model is not
                an AlbertWithAttentionAlignment instance.
            max_length (int): Maximum tokenization length. Must match the value
                used during training.

        Returns:
            List[Tuple[np.ndarray, List[str]]]: One entry per input string:
                - Log-attention matrix of shape (real_seq_len, real_seq_len) as a
                  numpy array, with padding positions removed.
                - Token strings aligned to the axes of that matrix.

        Raises:
            ValueError: If layer or head is outside the range offered by the
                model's attention output (non-alignment models only).
        """
        self._model.eval()

        encoded = self._tokenizer(
            texts,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        ids_cpu = encoded["input_ids"]
        mask_cpu = encoded["attention_mask"]

        target_device = self._device
        input_ids = ids_cpu.to(target_device)
        attention_mask = mask_cpu.to(target_device)
        # Tokenizers that emit no segment ids get an all-zero tensor instead.
        token_type_ids = encoded.get(
            "token_type_ids", torch.zeros_like(ids_cpu)
        ).to(target_device)

        with torch.no_grad():
            if not isinstance(self._model, AlbertWithAttentionAlignment):
                outputs = self._model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_attentions=True,
                    return_dict=True,
                )
                attentions = outputs.attentions  # tuple[num_layers] of (B, H, S, S)
                if layer < 0 or layer >= len(attentions):
                    raise ValueError(
                        f"layer must be in [0, {len(attentions) - 1}], got {layer}"
                    )
                layer_attn = attentions[layer]
                num_heads = layer_attn.shape[1]
                if head < 0 or head >= num_heads:
                    raise ValueError(
                        f"head must be in [0, {num_heads - 1}], got {head}"
                    )
                attn_batch = layer_attn[:, head].detach().cpu()  # (B, S, S)
            else:
                # The alignment model exposes a single attention map per input.
                attn_batch = (
                    self._model.predict_attention_probs(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids,
                    )
                    .detach()
                    .cpu()
                )  # (B, S, S)

        def _trim_row(row: int) -> Tuple[np.ndarray, List[str]]:
            """Drop padding for one batch row and return (log-attention, tokens)."""
            real_len = int(mask_cpu[row].sum().item())
            row_tokens = self._tokenizer.convert_ids_to_tokens(ids_cpu[row].tolist())
            log_attn = torch.log(attn_batch[row, :real_len, :real_len]).numpy()
            return log_attn, row_tokens[:real_len]

        return [_trim_row(row) for row in range(len(texts))]

    def _map_from_attention(
        self,
        rxn_smiles: str,
        attn: np.ndarray,
        tokens: List[str],
        sequence_max_length: int = 512,
        adjacent_atom_multiplier: float = 10,
        identical_adjacent_atom_multiplier: float = 10,
        one_to_one_correspondence: bool = True,
        reactants_atom_idx_to_orig_mapping: Optional[Dict[int, int]] = None,
        products_atom_idx_to_orig_mapping: Optional[Dict[int, int]] = None,
    ) -> Tuple[ReactionMapperResult, Optional[str]]:
        """
        Turn a pre-computed log-attention matrix into an atom-mapped reaction.

        Covers everything that happens after model inference: token checks,
        attention masking, alignment of the cross-attention scores, removal of
        non-atom rows/columns, and the atom map assignment itself. If the
        assignment reports oversubscribed reactant fragments, an expanded
        reaction SMILES with extra copies of those fragments appended to the
        reactants is also produced so the caller can retry.

        Args:
            rxn_smiles (str): An unmapped reaction SMILES string.
            attn (np.ndarray): Log-attention matrix of shape (seq_len, seq_len).
            tokens (List[str]): Token strings aligned to the attention matrix axes.
            sequence_max_length (int): Token sequences of this length or longer
                are rejected.
            adjacent_atom_multiplier (float): Forwarded to assign_atom_maps;
                multiplier for atoms neighboring an already-mapped pair.
            identical_adjacent_atom_multiplier (float): Forwarded to
                assign_atom_maps; extra multiplier when neighboring atom encodings
                match.
            one_to_one_correspondence (bool): Forwarded to assign_atom_maps;
                selects one-to-one assignment when True.
            reactants_atom_idx_to_orig_mapping (Optional[Dict[int, int]]):
                Existing reactant atom map numbers used to anchor a partial map.
            products_atom_idx_to_orig_mapping (Optional[Dict[int, int]]):
                Existing product atom map numbers used to anchor a partial map.

        Returns:
            Tuple[ReactionMapperResult, Optional[str]]:
                - The mapping result. When the tokens contain "[UNK]", lack ">>",
                  are too long, or the produced mapping fails validation, this is
                  a result with empty original_smiles and selected_mapping.
                - The expanded reaction SMILES, or None when assign_atom_maps
                  reported no oversubscribed fragments.
        """
        failed_result = ReactionMapperResult(
            original_smiles="",
            selected_mapping="",
            possible_mappings={},
            mapping_type=self._mapper_type,
            mapping_score=None,
            additional_info=[{}],
        )

        if "[UNK]" in tokens:
            logger.warning("Unknown token in sequence")
            return failed_result, None

        # A missing ">>" token is reported the same way as an over-long sequence.
        if ">>" not in tokens or len(tokens) >= sequence_max_length:
            logger.warning("Sequence too long")
            return failed_result, None

        token_layout = self.get_reactants_products_dict(tokens)
        masked_attn, _ = self.mask_attn_matrix(attn, token_layout)

        product_rows_attn, reactant_rows_attn = self.get_aligned_attn_scores(
            masked_attn,
            token_layout["reactants_start_index"],
            token_layout["reactants_end_index"],
            token_layout["products_start_index"],
        )

        reactant_rows_attn = self.remove_non_atom_rows_and_columns(
            reactant_rows_attn, token_layout
        )
        product_rows_attn = self.remove_non_atom_rows_and_columns(
            product_rows_attn, token_layout
        )

        mapped_rxn_smiles, confidence, oversubscribed_dict = self.assign_atom_maps(
            rxn_smiles,
            (reactant_rows_attn, product_rows_attn),
            one_to_one_correspondence=one_to_one_correspondence,
            adjacent_atom_multiplier=adjacent_atom_multiplier,
            identical_adjacent_atom_multiplier=identical_adjacent_atom_multiplier,
            reactants_atom_idx_to_orig_mapping=reactants_atom_idx_to_orig_mapping,
            products_atom_idx_to_orig_mapping=products_atom_idx_to_orig_mapping,
        )

        expanded_rxn_smiles: Optional[str] = None
        if oversubscribed_dict:
            orig_reactants, orig_products = rxn_smiles.split(">>")
            # One extra copy of a fragment per unit of its oversubscription count.
            extra_copies: List[str] = [
                fragment
                for fragment, copies in oversubscribed_dict.items()
                for _ in range(copies)
            ]
            expanded_rxn_smiles = (
                f"{orig_reactants}.{'.'.join(extra_copies)}>>{orig_products}"
            )

        if self._verify_validity_of_mapping(mapped_rxn_smiles):
            return ReactionMapperResult(
                original_smiles=rxn_smiles,
                selected_mapping=mapped_rxn_smiles,
                possible_mappings={},
                mapping_type=self._mapper_type,
                mapping_score=confidence,
                additional_info=[{}],
            ), expanded_rxn_smiles

        return failed_result, expanded_rxn_smiles

    def _strip_unmapped_reactant_fragments(
        self,
        mapped_rxn_smiles: str,
        orig_rxn_smiles: str,
    ) -> str:
        """
        Drop surplus, fully unmapped reactant copies from an expanded mapped reaction.

        Fragments are matched by type rather than by position: the key of a
        fragment is its SMILES after all atom map numbers have been cleared. The
        number of fragments of each type in orig_rxn_smiles defines how many
        "original slots" that type has. Walking the mapped reactants in order, a
        fragment that can still claim a slot is always kept (so spectators
        survive); once the slots of its type are used up, a fragment is kept only
        if at least one of its atoms has a non-zero atom map number.

        Args:
            mapped_rxn_smiles (str): Mapped reaction SMILES from the retry pass,
                holding the original reactants plus any appended extra copies.
            orig_rxn_smiles (str): The reaction SMILES before expansion, used to
                count the original fragments.

        Returns:
            str: The reaction SMILES without the unused extra reactant fragments.
                mapped_rxn_smiles is returned unchanged when either reactant side
                fails to parse.
        """
        orig_reactants_str, _ = self._split_reaction_components(orig_rxn_smiles)
        mapped_reactants_str, products_str = self._split_reaction_components(
            mapped_rxn_smiles
        )
        orig_mol = Chem.MolFromSmiles(orig_reactants_str)
        mapped_mol = Chem.MolFromSmiles(mapped_reactants_str)
        if orig_mol is None or mapped_mol is None:
            return mapped_rxn_smiles

        def _fragment_type(fragment: Chem.Mol) -> str:
            """Return the fragment's SMILES with every atom map number cleared."""
            editable = Chem.RWMol(fragment)  # work on a copy, keep the maps on fragment
            for atom in editable.GetAtoms():
                atom.SetAtomMapNum(0)
            return Chem.MolToSmiles(editable)

        # Number of still-unclaimed original fragments per fragment type.
        open_slots: Dict[str, int] = defaultdict(int)
        for fragment in Chem.GetMolFrags(orig_mol, asMols=True):
            open_slots[_fragment_type(fragment)] += 1

        kept_frags: List[str] = []
        for fragment in Chem.GetMolFrags(mapped_mol, asMols=True):
            fragment_type = _fragment_type(fragment)
            if open_slots[fragment_type] > 0:
                open_slots[fragment_type] -= 1
            elif all(atom.GetAtomMapNum() == 0 for atom in fragment.GetAtoms()):
                # Surplus copy that took part in no mapping: discard it.
                continue
            kept_frags.append(Chem.MolToSmiles(fragment, canonical=False))

        return ">>".join((".".join(kept_frags), products_str))

    def map_reaction(
        self,
        rxn_smiles: str,
        layer: int = 11,
        head: int = 7,
        sequence_max_length: int = 512,
        adjacent_atom_multiplier: float = 10,
        identical_adjacent_atom_multiplier: float = 10,
        one_to_one_correspondence: bool = True,
        start_from_partial_map: bool = False,
    ) -> ReactionMapperResult:
        """
        Map one reaction SMILES string with the neural mapper.

        Delegates to map_reactions with a one-element list and returns the only
        entry of its result.

        Args:
            rxn_smiles (str): A reaction SMILES string.
            layer (int): 0-based layer index to use for attention.
            head (int): 0-based head index to use for attention.
            sequence_max_length (int): Maximum allowed sequence length.
            adjacent_atom_multiplier (float): Multiplier for adjacent atom
                attention scores.
            identical_adjacent_atom_multiplier (float): Additional multiplier when
                neighboring atom encodings match.
            one_to_one_correspondence (bool): If True, enforces greedy one-to-one
                assignment.
            start_from_partial_map (bool): If True, existing atom map numbers in
                the input SMILES are extracted and preserved before remapping.

        Returns:
            ReactionMapperResult: The mapping result; a failed mapping has an
                empty selected_mapping.
        """
        forwarded_options = dict(
            layer=layer,
            head=head,
            sequence_max_length=sequence_max_length,
            adjacent_atom_multiplier=adjacent_atom_multiplier,
            identical_adjacent_atom_multiplier=identical_adjacent_atom_multiplier,
            one_to_one_correspondence=one_to_one_correspondence,
            start_from_partial_map=start_from_partial_map,
        )
        batch_results = self.map_reactions([rxn_smiles], **forwarded_options)
        return batch_results[0]

    def map_reactions(
        self,
        reaction_list: List[str],
        layer: int = 11,
        head: int = 7,
        sequence_max_length: int = 512,
        adjacent_atom_multiplier: float = 10,
        identical_adjacent_atom_multiplier: float = 10,
        one_to_one_correspondence: bool = True,
        start_from_partial_map: bool = False,
        batch_size: int = 32,
    ) -> List[ReactionMapperResult]:
        """
        Map a list of reaction SMILES strings using batched neural network inference.

        Tokenizes and runs the model in batches of batch_size for efficiency, then
        assigns atom mappings for each reaction individually from the resulting
        attention matrices.

        Args:
            reaction_list (List[str]): A list of unmapped reaction SMILES strings.
            layer (int): 0-based layer index. Only used for the base AlbertForMaskedLM;
                ignored for AlbertWithAttentionAlignment.
            head (int): 0-based head index. Only used for the base AlbertForMaskedLM;
                ignored for AlbertWithAttentionAlignment.
            sequence_max_length (int): Maximum tokenization length.
            adjacent_atom_multiplier (float): Multiplier for adjacent atom attention scores.
            identical_adjacent_atom_multiplier (float): Additional multiplier when
                neighboring atom encodings match.
            one_to_one_correspondence (bool): If True, enforces greedy one-to-one assignment.
            start_from_partial_map (bool): If True, extracts and preserves existing atom
                map numbers before remapping.
            batch_size (int): Number of reactions to process in a single forward pass.
                Must be at least 1.

        Returns:
            List[ReactionMapperResult]: A list of mapping results, one per input
                reaction. Failed mappings return a result with an empty
                selected_mapping. When one_to_one_correspondence is False and
                oversubscribed reactant atoms are detected, a second batched
                inference pass is run on expanded reactions (extra reactant copies
                appended) with one_to_one_correspondence=True; successful retry
                results replace the first-pass results. The retry pass does not
                receive the partial-map anchors.

        Raises:
            ValueError: If batch_size is smaller than 1 and at least one reaction
                passed validation (i.e. inference would actually have to run).
        """
        results: List[ReactionMapperResult] = [
            ReactionMapperResult(
                original_smiles="",
                selected_mapping="",
                possible_mappings={},
                mapping_type=self._mapper_type,
                mapping_score=None,
                additional_info=[{}],
            )
            for _ in reaction_list
        ]

        # Preprocess: validate and optionally strip existing partial maps
        prepared: List[
            Optional[Tuple[str, Optional[Dict[int, int]], Optional[Dict[int, int]]]]
        ] = []
        for rxn_smiles in reaction_list:
            rxn_smiles = canonicalize_reaction_smiles(rxn_smiles)
            if not self._reaction_smiles_valid(rxn_smiles):
                prepared.append(None)
                continue
            reactants_atom_idx_to_orig_mapping = None
            products_atom_idx_to_orig_mapping = None
            if start_from_partial_map:
                (
                    rxn_smiles,
                    reactants_atom_idx_to_orig_mapping,
                    products_atom_idx_to_orig_mapping,
                ) = self.get_data_from_partially_mapped_smiles(rxn_smiles)
            prepared.append(
                (
                    rxn_smiles,
                    reactants_atom_idx_to_orig_mapping,
                    products_atom_idx_to_orig_mapping,
                )
            )

        valid_pairs = [(i, p) for i, p in enumerate(prepared) if p is not None]
        if not valid_pairs:
            return results

        # A non-positive step would either make range() raise a cryptic error (0)
        # or silently yield no batches (negative), so reject it up front. The
        # check sits after the early return so that inputs with nothing to infer
        # keep returning the default results regardless of batch_size.
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        def _infer_in_batches(
            smiles_list: List[str],
        ) -> List[Tuple[np.ndarray, List[str]]]:
            """Run batched inference; returns one (attention, tokens) pair per input."""
            attn_tokens: List[Tuple[np.ndarray, List[str]]] = []
            for batch_start in range(0, len(smiles_list), batch_size):
                attn_tokens.extend(
                    self._get_attention_matrices_batch(
                        texts=smiles_list[batch_start : batch_start + batch_size],
                        layer=layer,
                        head=head,
                        max_length=sequence_max_length,
                    )
                )
            return attn_tokens

        # Batched neural network inference
        attn_tokens_list = _infer_in_batches([p[0] for _, p in valid_pairs])

        # Assign atom maps per reaction from pre-computed attention matrices
        oversubscribed_cases: List[Tuple[int, str, str]] = []
        for (
            orig_idx,
            (rxn_smiles, reactants_map, products_map),
        ), (attn, tokens) in zip(valid_pairs, attn_tokens_list):
            result, expanded_rxn_smiles = self._map_from_attention(
                rxn_smiles=rxn_smiles,
                attn=attn,
                tokens=tokens,
                sequence_max_length=sequence_max_length,
                adjacent_atom_multiplier=adjacent_atom_multiplier,
                identical_adjacent_atom_multiplier=identical_adjacent_atom_multiplier,
                one_to_one_correspondence=one_to_one_correspondence,
                reactants_atom_idx_to_orig_mapping=reactants_map,
                products_atom_idx_to_orig_mapping=products_map,
            )
            results[orig_idx] = result
            if expanded_rxn_smiles is not None:
                expanded_rxn_smiles = canonicalize_reaction_smiles(expanded_rxn_smiles)
                oversubscribed_cases.append((orig_idx, rxn_smiles, expanded_rxn_smiles))

        if not oversubscribed_cases:
            return results

        # Second pass: batch-map expanded reactions with one_to_one_correspondence=True
        expanded_attn_tokens = _infer_in_batches(
            [expanded for _, _, expanded in oversubscribed_cases]
        )

        for (orig_idx, orig_rxn_smiles, expanded_rxn), (attn, tokens) in zip(
            oversubscribed_cases, expanded_attn_tokens
        ):
            retry_result, _ = self._map_from_attention(
                rxn_smiles=expanded_rxn,
                attn=attn,
                tokens=tokens,
                sequence_max_length=sequence_max_length,
                adjacent_atom_multiplier=adjacent_atom_multiplier,
                identical_adjacent_atom_multiplier=identical_adjacent_atom_multiplier,
                one_to_one_correspondence=True,
            )
            # Keep the first-pass result unless the retry produced a mapping.
            if retry_result["selected_mapping"]:
                retry_result["original_smiles"] = orig_rxn_smiles
                retry_result["selected_mapping"] = (
                    self._strip_unmapped_reactant_fragments(
                        retry_result["selected_mapping"],
                        orig_rxn_smiles,
                    )
                )
                results[orig_idx] = retry_result

        return results

__init__(mapper_name, mapper_weight=3, checkpoint_path=None, use_supervised=True, supervised_config=None, sequence_max_length=512)

Initialize the NeuralReactionMapper instance.

Parameters:

Name Type Description Default
mapper_name str

The name of the mapper.

required
mapper_weight float

The weight of the mapper.

3
checkpoint_path Optional[str]

The path to the checkpoint file.

None
Source code in agave_chem/mappers/neural/neural_mapper.py
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def __init__(
    self,
    mapper_name: str,
    mapper_weight: float = 3,
    checkpoint_path: Optional[str] = None,
    use_supervised: bool = True,
    supervised_config: SupervisedConfig | None = None,
    sequence_max_length: int = 512,
):
    """
    Set up a NeuralReactionMapper: pick a device, load the model, build the tokenizer.

    Args:
        mapper_name (str): The name of the mapper.
        mapper_weight (float): The weight of the mapper.
        checkpoint_path (Optional[str]): The path to the checkpoint file. When
            empty or None, the "supervised_albert_model" resource under
            agave_chem.datafiles.models is used.
        use_supervised (bool): Stored on the instance and forwarded to
            load_neural_albert_model.
        supervised_config (SupervisedConfig | None): Configuration forwarded to
            load_neural_albert_model; a default SupervisedConfig() is created
            when none is given.
        sequence_max_length (int): Stored on the instance and forwarded to
            load_neural_albert_model as max_length.
    """

    super().__init__("neural", mapper_name, mapper_weight)

    # Fall back to the packaged checkpoint when the caller supplies no path.
    model_dir = checkpoint_path or str(
        files("agave_chem.datafiles.models").joinpath("supervised_albert_model")
    )

    # Use the GPU whenever torch reports one as available.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    self._device = torch.device(device_name)

    self._sequence_max_length = sequence_max_length
    self._use_supervised = use_supervised
    if not supervised_config:
        supervised_config = SupervisedConfig()
    self._supervised_config = supervised_config

    self._model = load_neural_albert_model(
        checkpoint_dir=model_dir,
        device=self._device,
        use_supervised=use_supervised,
        max_length=sequence_max_length,
        supervised_config=self._supervised_config,
    )

    self._tokenizer = CustomTokenizer(smiles_token_to_id_dict)

assign_atom_maps(rxn_smiles, aligned_attn_scores, one_to_one_correspondence=True, adjacent_atom_multiplier=30, identical_adjacent_atom_multiplier=10, used_atom_divisor=10, reactants_atom_idx_to_orig_mapping=None, products_atom_idx_to_orig_mapping=None)

Assign atom-to-atom map numbers to a reaction SMILES using a pre-computed attention matrix.

Handles symmetric atoms in both reactants and products by summing their attention contributions, preventing artificially low confidence scores caused by equivalent atoms splitting probability mass.

Parameters:

Name Type Description Default
rxn_smiles str

Unmapped reaction SMILES string of the form "reactants>>products".

required
aligned_attn_scores Tuple[ndarray, ndarray]

Tuple of attention matrices of shape (n_product_atoms, n_reactant_atoms).

required
one_to_one_correspondence bool

If True, enforces a one-to-one assignment using greedy selection of the global attention maximum. If False, assigns each product atom independently to its highest-attention reactant atom.

True
adjacent_atom_multiplier float

Multiplier applied to attention scores of atoms neighboring an already-mapped pair.

30
identical_adjacent_atom_multiplier float

Additional multiplier applied when a neighboring pair shares the same atom encoding.

10
used_atom_divisor float

Divisor applied to the attention scores of reactant atoms that are already mapped. Only used when one_to_one_correspondence is False.

10
reactants_atom_idx_to_orig_mapping Optional[Dict[int, int]]

Maps global reactant atom indices to existing atom map numbers, used to anchor partially pre-mapped reactions.

None
products_atom_idx_to_orig_mapping Optional[Dict[int, int]]

Maps global product atom indices to existing atom map numbers, used to anchor partially pre-mapped reactions.

None

Returns:

Type Description
Tuple[str, float, Dict[str, int]]

Tuple[str, float, Dict[str, int]]: A three-element tuple containing (1) the mapped reaction SMILES string with atom map numbers assigned; (2) a confidence score computed as the product of the per-atom assignment probabilities; and (3) a dictionary mapping each oversubscribed reactant SMILES (atom maps removed) to the maximum number of times any atom in that fragment was assigned to multiple product atoms. The dictionary is empty when one_to_one_correspondence is True or when no oversubscription occurs.

Source code in agave_chem/mappers/neural/neural_mapper.py
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
def assign_atom_maps(
    self,
    rxn_smiles: str,
    aligned_attn_scores: Tuple[np.ndarray, np.ndarray],
    one_to_one_correspondence: bool = True,
    adjacent_atom_multiplier: float = 30,
    identical_adjacent_atom_multiplier: float = 10,
    used_atom_divisor: float = 10,
    reactants_atom_idx_to_orig_mapping: Optional[Dict[int, int]] = None,
    products_atom_idx_to_orig_mapping: Optional[Dict[int, int]] = None,
) -> Tuple[str, float, Dict[str, int]]:
    """
    Assign atom-to-atom map numbers to a reaction SMILES using a pre-computed
    pair of attention matrices.

    One map number is handed out per row of the attention matrix (i.e. per
    product atom), starting at 1. For each map number the method either
    reuses a pre-existing anchor (when that number is present in
    ``products_atom_idx_to_orig_mapping``) or greedily picks the
    (product, reactant) pair holding the current global attention maximum.
    After every assignment, the scores of neighbouring (product, reactant)
    atom pairs are multiplied so that the mapping tends to grow along bonds.

    Confidence is computed from symmetry-adjusted copies of the attention
    matrices (see ``_get_symmetric_atom_indices`` and
    ``_apply_symmetric_attention``), which is intended to prevent equivalent
    atoms from splitting probability mass and producing artificially low
    scores.

    Args:
        rxn_smiles (str): Unmapped reaction SMILES string of the form
            "reactants>>products".
        aligned_attn_scores (Tuple[np.ndarray, np.ndarray]): Tuple of two
            attention matrices, both indexed ``[product_atom, reactant_atom]``.
            The tuple is unpacked as
            ``(reactants_to_products_attn, products_to_reactants_attn)``.
        one_to_one_correspondence (bool): If True, the two matrices are
            averaged and a chosen product row and reactant column are both
            zeroed, enforcing a one-to-one assignment. If False, only the
            second matrix is used and a chosen reactant column is divided by
            ``used_atom_divisor`` instead of being zeroed, so it can be
            selected again.
        adjacent_atom_multiplier (float): Multiplier applied to attention
            scores of atom pairs neighboring an already-mapped pair.
        identical_adjacent_atom_multiplier (float): Additional multiplier
            applied when a neighboring pair shares the same atom encoding.
        used_atom_divisor (float): Divisor applied to attention scores
            of reactant atoms that are already mapped if
            one_to_one_correspondence is False.
        reactants_atom_idx_to_orig_mapping (Optional[Dict[int, int]]): Maps
            global reactant atom indices to existing atom map numbers, used
            to anchor partially pre-mapped reactions. Entries with map
            number 0 are ignored.
        products_atom_idx_to_orig_mapping (Optional[Dict[int, int]]): Maps
            global product atom indices to existing atom map numbers, used
            to anchor partially pre-mapped reactions. Entries with map
            number 0 are ignored.

    Returns:
        Tuple[str, float, Dict[str, int]]:
            - Mapped reaction SMILES string with atom map numbers assigned.
            - Confidence score computed as the product of per-atom assignment
              probabilities (anchored assignments contribute 1.0).
            - Dictionary mapping oversubscribed reactant SMILES (atom maps
              removed) to the largest per-atom reuse count in that fragment.
              Empty when one_to_one_correspondence is True or when no
              oversubscription occurs.

    Raises:
        KeyError: If a map number is anchored on the product side but has no
            entry on the reactant side.
    """
    if not reactants_atom_idx_to_orig_mapping:
        reactants_atom_idx_to_orig_mapping = {}
    if not products_atom_idx_to_orig_mapping:
        products_atom_idx_to_orig_mapping = {}

    def _count_reuse(reactant_atom) -> None:
        """Bump ``oversubscribed_count`` on a reactant atom that already carries a map number."""
        if not reactant_atom.GetAtomMapNum():
            return
        if reactant_atom.HasProp("oversubscribed_count"):
            previous_count = reactant_atom.GetIntProp("oversubscribed_count")
            reactant_atom.SetIntProp("oversubscribed_count", previous_count + 1)
        else:
            reactant_atom.SetIntProp("oversubscribed_count", 1)

    reactants_str, products_str = self._split_reaction_components(rxn_smiles)
    reactants_mols = [
        Chem.MolFromSmiles(reactant) for reactant in reactants_str.split(".")
    ]
    products_mols = [
        Chem.MolFromSmiles(product) for product in products_str.split(".")
    ]

    reactants_atom_dict, reactants_atom_dict_neighbors = self._build_atom_dict(
        reactants_mols
    )
    products_atom_dict, products_atom_dict_neighbors = self._build_atom_dict(
        products_mols
    )

    # Invert the idx -> map-number dicts so anchors can be looked up by map
    # number; 0 means "unmapped" and is dropped.
    products_orig_mapping_to_idx = {
        value: key
        for key, value in products_atom_idx_to_orig_mapping.items()
        if value != 0
    }
    reactants_orig_mapping_to_idx = {
        value: key
        for key, value in reactants_atom_idx_to_orig_mapping.items()
        if value != 0
    }

    # NOTE(review): get_aligned_attn_scores returns
    # (products_to_reactants, reactants_to_products), i.e. the opposite name
    # order to this unpacking. Harmless when the two are averaged, but the
    # one_to_one_correspondence=False path uses only the second element --
    # confirm the ordering the caller actually passes.
    (reactants_to_products_attn, products_to_reactants_attn) = aligned_attn_scores

    orig_reactants_to_products_attn = reactants_to_products_attn.copy()
    orig_products_to_reactants_attn = products_to_reactants_attn.copy()

    reactants_symmetric_indices = self._get_symmetric_atom_indices(reactants_mols)
    products_symmetric_indices = self._get_symmetric_atom_indices(products_mols)

    # Symmetry-adjusted copies are only used to report confidence; the greedy
    # search below runs on the unadjusted scores.
    orig_products_to_reactants_attn = self._apply_symmetric_attention(
        orig_products_to_reactants_attn, reactants_symmetric_indices, axis=1
    )
    orig_reactants_to_products_attn = self._apply_symmetric_attention(
        orig_reactants_to_products_attn, products_symmetric_indices, axis=0
    )

    # If not a one-to-one correspondence, multiple product atoms could map to the
    # same reactant atom. That reactant atom signal would be split, producing
    # inaccurate assignment probabilities. So we use only product to reactant
    # attention.
    if one_to_one_correspondence:
        orig_attn = (
            orig_reactants_to_products_attn.copy()
            + orig_products_to_reactants_attn.copy()
        ) / 2
        attn = (
            reactants_to_products_attn.copy() + products_to_reactants_attn.copy()
        ) / 2
    else:
        orig_attn = orig_products_to_reactants_attn.copy()
        attn = products_to_reactants_attn.copy()

    assignment_probs = []
    for map_num in range(attn.shape[0]):
        new_map_num = map_num + 1

        # Membership test rather than truthiness of the looked-up index: a
        # pre-mapped product atom sitting at global index 0 is a valid anchor
        # and must not be mistaken for "no anchor".
        if new_map_num in products_orig_mapping_to_idx:
            row_highest_attn = products_orig_mapping_to_idx[new_map_num]
            col_highest_attn = reactants_orig_mapping_to_idx[new_map_num]

            _count_reuse(reactants_atom_dict[col_highest_attn])

            products_atom_dict[row_highest_attn].SetAtomMapNum(new_map_num)
            reactants_atom_dict[col_highest_attn].SetAtomMapNum(new_map_num)
            attn[row_highest_attn] = 0
            attn[:, col_highest_attn] = 0
            # Anchored pairs are taken as given, so they do not lower confidence.
            assignment_probs.append(1.0)
        else:
            # Greedy choice: first (row-major) position of the global maximum.
            highest_attn_score = attn.max()
            highest_attn_score_indices = np.where(attn == highest_attn_score)
            row_highest_attn = highest_attn_score_indices[0][0]
            col_highest_attn = highest_attn_score_indices[1][0]

            _count_reuse(reactants_atom_dict[col_highest_attn])

            products_atom_dict[row_highest_attn].SetAtomMapNum(new_map_num)
            reactants_atom_dict[col_highest_attn].SetAtomMapNum(new_map_num)

            # The product row is always retired; the reactant column is retired
            # only in one-to-one mode, otherwise merely penalised.
            attn[row_highest_attn] = 0
            if one_to_one_correspondence:
                attn[:, col_highest_attn] = 0
            else:
                attn[:, col_highest_attn] /= used_atom_divisor

            assignment_probs.append(orig_attn[row_highest_attn, col_highest_attn])

        # Boost every (product neighbour, reactant neighbour) pair of the pair
        # just mapped, with an extra factor when their encodings match.
        for (
            product_atom_idx,
            product_atom_env,
        ) in products_atom_dict_neighbors[row_highest_attn]:
            for (
                reactant_atom_idx,
                reactant_atom_env,
            ) in reactants_atom_dict_neighbors[col_highest_attn]:
                if product_atom_env == reactant_atom_env:
                    attn[product_atom_idx, reactant_atom_idx] *= (
                        adjacent_atom_multiplier
                        * identical_adjacent_atom_multiplier
                    )
                else:
                    attn[product_atom_idx, reactant_atom_idx] *= (
                        adjacent_atom_multiplier
                    )

    mapped_reactants_str = ".".join(
        [Chem.MolToSmiles(reactant, canonical=False) for reactant in reactants_mols]
    )
    mapped_products_str = ".".join(
        [Chem.MolToSmiles(product, canonical=False) for product in products_mols]
    )
    mapped_rxn_smiles = mapped_reactants_str + ">>" + mapped_products_str

    confidence = float(np.prod(assignment_probs))

    if one_to_one_correspondence:
        return mapped_rxn_smiles, confidence, {}

    oversubscribed_dict = {}
    for reactant in reactants_mols:
        max_oversubscribed_count = 0
        for reactant_atom in reactant.GetAtoms():
            if not reactant_atom.HasProp("oversubscribed_count"):
                continue
            oversubscribed_count = reactant_atom.GetIntProp("oversubscribed_count")
            if oversubscribed_count > max_oversubscribed_count:
                max_oversubscribed_count = oversubscribed_count
        if max_oversubscribed_count == 0:
            continue
        # The mapped SMILES was already rendered above, so map numbers can be
        # cleared here to obtain an unmapped key for the fragment.
        for atom in reactant.GetAtoms():
            atom.SetAtomMapNum(0)
        oversubscribed_dict[Chem.MolToSmiles(reactant)] = max_oversubscribed_count

    return mapped_rxn_smiles, confidence, oversubscribed_dict

get_aligned_attn_scores(out, reactants_start_index, reactants_end_index, products_start_index)

Extract and align cross-attention scores between reactant and product tokens.

Slices the full attention matrix out to obtain two cross-attention sub-matrices: one representing how each product token attends to reactant tokens, and one representing how each reactant token attends to product tokens. The latter is transposed so that both returned arrays share the same index orientation (rows = product tokens, columns = reactant tokens).

Parameters:

Name Type Description Default
out ndarray

Square attention probability matrix of shape (sequence_length, sequence_length), where entry [i, j] is the attention weight from token i to token j.

required
reactants_start_index int

Index of the first reactant atom token in the sequence.

required
reactants_end_index int

Index of the last reactant atom token in the sequence (inclusive).

required
products_start_index int

Index of the first product token in the sequence; all tokens from this index onward are product tokens.

required

Returns:

Type Description
Tuple[ndarray, ndarray]

Tuple[np.ndarray, np.ndarray]: - products_to_reactants_attn (np.ndarray): Sub-matrix of shape (n_product_tokens, n_reactant_tokens) giving the attention weights from each product token to each reactant token. - reactants_to_products_attn (np.ndarray): Transposed sub-matrix of shape (n_product_tokens, n_reactant_tokens) giving the attention weights from each reactant token to each product token, transposed so that rows correspond to product tokens and columns to reactant tokens, matching the orientation of products_to_reactants_attn.

Source code in agave_chem/mappers/neural/neural_mapper.py
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
def get_aligned_attn_scores(
    self,
    out: np.ndarray,
    reactants_start_index: int,
    reactants_end_index: int,
    products_start_index: int,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Cut the two reactant/product cross-attention blocks out of a full
    attention matrix and bring them to a common orientation.

    Both blocks are plain slices of `out`; the reactant-row block is
    transposed so that, in both returned arrays, rows index product tokens
    and columns index reactant tokens.

    Args:
        out (np.ndarray): Attention matrix indexed ``[i, j]`` (token ``i``
            attending to token ``j``).
        reactants_start_index (int): Index of the first reactant token to
            include.
        reactants_end_index (int): Index of the last reactant token to
            include (inclusive).
        products_start_index (int): Index of the first product token; every
            position from here to the end of the sequence is included.

    Returns:
        Tuple[np.ndarray, np.ndarray]:
            - **products_to_reactants_attn** (np.ndarray): Rows
              ``products_start_index:`` and the reactant columns of `out`.
            - **reactants_to_products_attn** (np.ndarray): The reactant rows
              and columns ``products_start_index:`` of `out`, transposed so
              its shape matches the first array.
    """
    # The reactant end index is inclusive, hence the +1 on the slice stop.
    reactant_span = slice(reactants_start_index, reactants_end_index + 1)
    # Open-ended: runs from the first product token to the end of the sequence.
    product_span = slice(products_start_index, None)

    product_rows_block = out[product_span, reactant_span]
    # Transposed so that rows are product tokens here as well.
    reactant_rows_block_t = out[reactant_span, product_span].T

    return product_rows_block, reactant_rows_block_t

get_attention_matrix_for_head(text, layer, head, max_length=512, trim_padding=True)

Returns the attention matrix for a given layer/head for a single input string.

Parameters:

Name Type Description Default
text str

input reaction SMILES string (raw is fine; CustomTokenizer preprocesses)

required
layer int

0-based layer index

required
head int

0-based head index

required
max_length int

tokenization length (should match training, e.g. 256)

512
trim_padding bool

if True, slices matrix down to non-pad tokens only

True

Returns:

Name Type Description
attn ndarray

NumPy array of shape (seq_len, seq_len) holding the natural log of the attention probabilities (trimmed if requested)

tokens List[str]

list[str] tokens aligned to attn axes (trimmed if requested)

Source code in agave_chem/mappers/neural/neural_mapper.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
def get_attention_matrix_for_head(
    self,
    text: str,
    layer: int,
    head: int,
    max_length: int = 512,
    trim_padding: bool = True,
) -> Tuple[np.ndarray, List[str]]:
    """
    Compute the log-attention matrix of one input string.

    The model is switched to eval mode and run once without gradients. When
    the model is an ``AlbertWithAttentionAlignment``, its
    ``predict_attention_probs`` output is used and ``layer``/``head`` are not
    consulted. Otherwise the model is run with ``output_attentions=True`` and
    the requested layer/head is selected.

    Args:
        text: input reaction SMILES string, passed straight to the tokenizer
        layer: 0-based layer index (base-model path only)
        head: 0-based head index (base-model path only)
        max_length: length the input is padded/truncated to
        trim_padding: if True, keep only the leading positions counted by the
            attention mask

    Returns:
        attn: NumPy array with the natural log of the attention
            probabilities, shape (seq_len, seq_len) (trimmed if requested)
        tokens: list[str] tokens aligned to attn axes (trimmed if requested)

    Raises:
        ValueError: on the base-model path, if ``layer`` or ``head`` is out
            of range.
    """
    self._model.eval()

    encoding = self._tokenizer(
        text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    ids_on_device = encoding["input_ids"].to(self._device)
    mask_on_device = encoding["attention_mask"].to(self._device)
    # Fall back to all-zero segment ids when the tokenizer does not emit any.
    segment_ids = encoding.get(
        "token_type_ids", torch.zeros_like(encoding["input_ids"])
    ).to(self._device)

    with torch.no_grad():
        if isinstance(self._model, AlbertWithAttentionAlignment):
            batch_probs = self._model.predict_attention_probs(
                input_ids=ids_on_device,
                attention_mask=mask_on_device,
                token_type_ids=segment_ids,
            )  # (B,S,S)
            attn = batch_probs[0].detach().cpu()  # (S,S)
        else:
            model_out = self._model(
                input_ids=ids_on_device,
                attention_mask=mask_on_device,
                output_attentions=True,
                return_dict=True,
            )
            per_layer = model_out.attentions  # tuple[num_layers] of (B,H,S,S)

            if layer < 0 or layer >= len(per_layer):
                raise ValueError(
                    f"layer must be in [0, {len(per_layer) - 1}], got {layer}"
                )

            head_count = per_layer[layer].shape[1]
            if head < 0 or head >= head_count:
                raise ValueError(
                    f"head must be in [0, {head_count - 1}], got {head}"
                )

            attn = per_layer[layer][0, head].detach().cpu()  # (S,S)

    # Token strings for inspection/plotting, aligned with the matrix axes.
    tokens = self._tokenizer.convert_ids_to_tokens(
        encoding["input_ids"][0].tolist()
    )

    if trim_padding:
        # Number of positions flagged by the attention mask.
        # NOTE(review): slicing a prefix assumes padding sits at the end -- confirm
        # the tokenizer pads on the right.
        unpadded_len = int(encoding["attention_mask"][0].sum().item())
        attn = attn[:unpadded_len, :unpadded_len]
        tokens = tokens[:unpadded_len]

    # IMPORTANT: downstream code expects log-attention, not raw probabilities.
    log_attn = torch.log(attn).numpy()
    return log_attn, tokens

get_duplicate_indices(list_of_lists)

Find indices of duplicate values across a list of sublists using globally offset indices.

For each element, returns a mapping to all other elements in the same sublist that share the same value. Elements without duplicates are omitted.

Parameters:

Name Type Description Default
list_of_lists List[List[int]]

A list of sublists, where each sublist contains integer values (e.g., canonical atom ranks per molecule).

required

Returns:

Type Description
Dict[int, List[int]]

Dict[int, List[int]]: A dictionary mapping each globally-offset index to a list of other globally-offset indices within the same sublist that share the same value. Only indices with at least one duplicate are included.

Source code in agave_chem/mappers/neural/neural_mapper.py
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
def get_duplicate_indices(
    self, list_of_lists: List[List[int]]
) -> Dict[int, List[int]]:
    """
    Map every repeated element to the positions of its twins.

    Positions are "flat" indices: the sublists are numbered as if they had
    been concatenated. Two elements are twins only when they have the same
    value AND live in the same sublist; equal values in different sublists
    are never paired.

    Args:
        list_of_lists (List[List[int]]): A list of sublists of integer values
            (e.g., canonical atom ranks per molecule).

    Returns:
        Dict[int, List[int]]: For each flat index that has at least one twin,
            the ascending list of the twins' flat indices (the index itself
            excluded). Indices without twins do not appear.
    """
    duplicates: Dict[int, List[int]] = {}
    base = 0  # flat index of the first element of the current sublist

    for group in list_of_lists:
        # value -> flat positions holding that value, for this sublist only
        positions_by_value = defaultdict(list)
        for flat_pos, value in enumerate(group, start=base):
            positions_by_value[value].append(flat_pos)

        for flat_pos, value in enumerate(group, start=base):
            same_value = positions_by_value[value]
            # A bucket of size one contains only the element itself.
            if len(same_value) > 1:
                duplicates[flat_pos] = [p for p in same_value if p != flat_pos]

        base += len(group)

    return duplicates

get_reactants_products_dict(tokens)

Extracts reactants and products from a list of tokens in a reaction SMILES string.

Parameters:

Name Type Description Default
tokens List[str]

A list of tokens in a reaction SMILES string.

required

Returns:

Type Description
StringInfoDict

A dictionary (StringInfoDict) with the following keys: reactants_dict: A dictionary where the keys are token indices and the values are the corresponding token strings. products_dict: A dictionary where the keys are token indices and the values are the corresponding token strings. atom_tokens_dict: A dictionary where the keys are atom identities and the values are lists of token indices. non_atom_tokens: A list of token indices that correspond to non-atom tokens. reactants_start_index: The index of the first reactant token (always 0). reactants_end_index: The index of the last reactant token. products_start_index: The index of the first product token. products_end_index: The index of the last product token.

Source code in agave_chem/mappers/neural/neural_mapper.py
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
def get_reactants_products_dict(
    self,
    tokens: List[str],
) -> StringInfoDict:
    """
    Split a tokenized reaction SMILES into its reactant and product sides and
    index its atom tokens.

    Tokens seen before the first ">>" are filed as reactants, tokens after it
    as products. The ">>" token itself belongs to neither side and is recorded
    as a non-atom token. A token counts as an atom when
    ``token_atom_identity_dict`` yields a non-zero identity for it.

    Args:
        tokens: A list of tokens in a reaction SMILES string.

    Returns:
        A dictionary with the keys:
            reactants_dict: token index -> token string, reactant side.
            products_dict: token index -> token string, product side.
            reactants_start_index: always 0.
            reactants_end_index: largest index in reactants_dict.
            products_start_index: smallest index in products_dict.
            products_end_index: largest index in products_dict.
            atom_tokens_dict: atom identity -> list of token indices.
            non_atom_tokens: indices of tokens that are not atoms.

    Raises:
        ValueError: if either side ends up empty (``max``/``min`` of an
            empty collection), e.g. when no ">>" token is present.
    """
    reactants_dict: Dict[int, str] = {}
    products_dict: Dict[int, str] = {}
    atom_tokens_dict: Dict[int, List[int]] = {}
    non_atom_tokens: List[int] = []

    past_arrow = False
    for position, token in enumerate(tokens):
        if token == ">>":
            # The separator flips the side and is itself never stored on a side.
            past_arrow = True
            non_atom_tokens.append(position)
            continue

        identity = token_atom_identity_dict.get(token, 0)
        if identity == 0:
            non_atom_tokens.append(position)
        else:
            atom_tokens_dict.setdefault(identity, []).append(position)

        side = products_dict if past_arrow else reactants_dict
        side[position] = token

    string_info_dict: StringInfoDict = {
        "reactants_dict": reactants_dict,
        "products_dict": products_dict,
        "reactants_start_index": 0,
        "reactants_end_index": max(reactants_dict),
        "products_start_index": min(products_dict),
        "products_end_index": max(products_dict),
        "atom_tokens_dict": atom_tokens_dict,
        "non_atom_tokens": non_atom_tokens,
    }

    return string_info_dict

map_reaction(rxn_smiles, layer=11, head=7, sequence_max_length=512, adjacent_atom_multiplier=10, identical_adjacent_atom_multiplier=10, one_to_one_correspondence=True, start_from_partial_map=False)

Map a single reaction SMILES string using the neural mapper.

Convenience wrapper around map_reactions for single-reaction use.

Parameters:

Name Type Description Default
rxn_smiles str

A reaction SMILES string.

required
layer int

0-based layer index to use for attention.

11
head int

0-based head index to use for attention.

7
sequence_max_length int

Maximum allowed sequence length.

512
adjacent_atom_multiplier float

Multiplier for adjacent atom attention scores.

10
identical_adjacent_atom_multiplier float

Additional multiplier when neighboring atom encodings match.

10
one_to_one_correspondence bool

If True, enforces greedy one-to-one assignment.

True
start_from_partial_map bool

If True, extracts and preserves existing atom map numbers from the input SMILES before remapping.

False

Returns:

Name Type Description
ReactionMapperResult ReactionMapperResult

Mapping result. On failure returns a result with an empty selected_mapping.

Source code in agave_chem/mappers/neural/neural_mapper.py
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
def map_reaction(
    self,
    rxn_smiles: str,
    layer: int = 11,
    head: int = 7,
    sequence_max_length: int = 512,
    adjacent_atom_multiplier: float = 10,
    identical_adjacent_atom_multiplier: float = 10,
    one_to_one_correspondence: bool = True,
    start_from_partial_map: bool = False,
) -> ReactionMapperResult:
    """
    Map one reaction SMILES string with the neural mapper.

    Thin single-item front end for ``map_reactions``: the reaction is wrapped
    in a one-element list, every option is forwarded by keyword, and the only
    element of the returned list is handed back. ``batch_size`` is not
    forwarded, so ``map_reactions`` uses its own default for it.

    Args:
        rxn_smiles (str): A reaction SMILES string.
        layer (int): 0-based layer index to use for attention.
        head (int): 0-based head index to use for attention.
        sequence_max_length (int): Maximum allowed sequence length.
        adjacent_atom_multiplier (float): Multiplier for adjacent atom attention scores.
        identical_adjacent_atom_multiplier (float): Additional multiplier when
            neighboring atom encodings match.
        one_to_one_correspondence (bool): If True, enforces greedy one-to-one assignment.
        start_from_partial_map (bool): If True, extracts and preserves existing atom
            map numbers from the input SMILES before remapping.

    Returns:
        ReactionMapperResult: The first (and only) entry of the list returned
            by ``map_reactions`` for this reaction.
    """
    single_item_batch = [rxn_smiles]
    batch_results = self.map_reactions(
        single_item_batch,
        layer=layer,
        head=head,
        sequence_max_length=sequence_max_length,
        adjacent_atom_multiplier=adjacent_atom_multiplier,
        identical_adjacent_atom_multiplier=identical_adjacent_atom_multiplier,
        one_to_one_correspondence=one_to_one_correspondence,
        start_from_partial_map=start_from_partial_map,
    )
    return batch_results[0]

map_reactions(reaction_list, layer=11, head=7, sequence_max_length=512, adjacent_atom_multiplier=10, identical_adjacent_atom_multiplier=10, one_to_one_correspondence=True, start_from_partial_map=False, batch_size=32)

Map a list of reaction SMILES strings using batched neural network inference.

Tokenizes and runs the model in batches of batch_size for efficiency, then assigns atom mappings for each reaction individually from the resulting attention matrices.

Parameters:

Name Type Description Default
reaction_list List[str]

A list of unmapped reaction SMILES strings.

required
layer int

0-based layer index. Only used for the base AlbertForMaskedLM; ignored for AlbertWithAttentionAlignment.

11
head int

0-based head index. Only used for the base AlbertForMaskedLM; ignored for AlbertWithAttentionAlignment.

7
sequence_max_length int

Maximum tokenization length.

512
adjacent_atom_multiplier float

Multiplier for adjacent atom attention scores.

10
identical_adjacent_atom_multiplier float

Additional multiplier when neighboring atom encodings match.

10
one_to_one_correspondence bool

If True, enforces greedy one-to-one assignment.

True
start_from_partial_map bool

If True, extracts and preserves existing atom map numbers before remapping.

False
batch_size int

Number of reactions to process in a single forward pass.

32

Returns:

Type Description
List[ReactionMapperResult]

List[ReactionMapperResult]: A list of mapping results, one per input reaction. Failed mappings return a result with an empty selected_mapping. When one_to_one_correspondence is False and oversubscribed reactant atoms are detected, a second batched inference pass is run on expanded reactions (extra reactant copies appended) with one_to_one_correspondence=True; successful retry results replace the first-pass results.

Source code in agave_chem/mappers/neural/neural_mapper.py
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
def map_reactions(
    self,
    reaction_list: List[str],
    layer: int = 11,
    head: int = 7,
    sequence_max_length: int = 512,
    adjacent_atom_multiplier: float = 10,
    identical_adjacent_atom_multiplier: float = 10,
    one_to_one_correspondence: bool = True,
    start_from_partial_map: bool = False,
    batch_size: int = 32,
) -> List[ReactionMapperResult]:
    """
    Atom-map many reaction SMILES strings with batched model inference.

    The work happens in three stages:

    1. Every input is canonicalized and validated. With
       ``start_from_partial_map`` the existing atom-map data is extracted via
       ``get_data_from_partially_mapped_smiles`` and carried along.
    2. Attention matrices are computed for all valid reactions, ``batch_size``
       at a time, and ``_map_from_attention`` turns each matrix into a result.
    3. Whenever ``_map_from_attention`` hands back an expanded reaction SMILES
       (its second return value is not None), that expanded reaction is
       canonicalized, sent through a second batched inference pass and mapped
       again with ``one_to_one_correspondence=True``. A retry that produces a
       non-empty ``selected_mapping`` replaces the first-pass result.

    Args:
        reaction_list (List[str]): Reaction SMILES strings to map. The
            sequence is iterated twice (placeholders, then preprocessing).
        layer (int): Layer index forwarded to ``_get_attention_matrices_batch``.
        head (int): Head index forwarded to ``_get_attention_matrices_batch``.
        sequence_max_length (int): Maximum tokenization length; forwarded to
            both the inference and the mapping step.
        adjacent_atom_multiplier (float): Forwarded to ``_map_from_attention``.
        identical_adjacent_atom_multiplier (float): Forwarded to
            ``_map_from_attention``.
        one_to_one_correspondence (bool): Forwarded to ``_map_from_attention``
            on the first pass; the retry pass always uses True.
        start_from_partial_map (bool): If True, existing atom-map data is
            extracted from each valid reaction before mapping.
        batch_size (int): Number of reactions per inference call.

    Returns:
        List[ReactionMapperResult]: One result per input, in input order.
            Inputs that fail validation keep the placeholder result (empty
            ``original_smiles`` and ``selected_mapping``).
    """

    def _blank_result() -> ReactionMapperResult:
        # Placeholder kept for every reaction that never gets a real result.
        return ReactionMapperResult(
            original_smiles="",
            selected_mapping="",
            possible_mappings={},
            mapping_type=self._mapper_type,
            mapping_score=None,
            additional_info=[{}],
        )

    def _attention_in_batches(smiles):
        # Run the model over `smiles` in consecutive slices of `batch_size`.
        collected = []
        for offset in range(0, len(smiles), batch_size):
            collected.extend(
                self._get_attention_matrices_batch(
                    texts=smiles[offset : offset + batch_size],
                    layer=layer,
                    head=head,
                    max_length=sequence_max_length,
                )
            )
        return collected

    results: List[ReactionMapperResult] = [_blank_result() for _ in reaction_list]

    # Stage 1: canonicalize, validate, optionally pull out partial maps.
    # Each entry: (index in reaction_list, SMILES, reactant map, product map).
    pending = []
    for input_position, raw_smiles in enumerate(reaction_list):
        canonical = canonicalize_reaction_smiles(raw_smiles)
        if not self._reaction_smiles_valid(canonical):
            continue
        reactant_partial_map = None
        product_partial_map = None
        if start_from_partial_map:
            (
                canonical,
                reactant_partial_map,
                product_partial_map,
            ) = self.get_data_from_partially_mapped_smiles(canonical)
        pending.append(
            (input_position, canonical, reactant_partial_map, product_partial_map)
        )

    if not pending:
        return results

    # Arguments shared by the first-pass and retry mapping calls.
    scoring_kwargs = {
        "sequence_max_length": sequence_max_length,
        "adjacent_atom_multiplier": adjacent_atom_multiplier,
        "identical_adjacent_atom_multiplier": identical_adjacent_atom_multiplier,
    }

    # Stage 2: first inference pass and per-reaction mapping.
    first_pass_attention = _attention_in_batches([entry[1] for entry in pending])
    retry_queue = []  # (index in reaction_list, canonical SMILES, expanded SMILES)
    for position, entry in enumerate(pending):
        input_position, canonical, reactant_partial_map, product_partial_map = entry
        attn, tokens = first_pass_attention[position]
        first_result, expanded = self._map_from_attention(
            rxn_smiles=canonical,
            attn=attn,
            tokens=tokens,
            one_to_one_correspondence=one_to_one_correspondence,
            reactants_atom_idx_to_orig_mapping=reactant_partial_map,
            products_atom_idx_to_orig_mapping=product_partial_map,
            **scoring_kwargs,
        )
        results[input_position] = first_result
        if expanded is not None:
            retry_queue.append(
                (input_position, canonical, canonicalize_reaction_smiles(expanded))
            )

    if not retry_queue:
        return results

    # Stage 3: re-run the expanded reactions with strict one-to-one assignment.
    retry_attention = _attention_in_batches(
        [expanded for _, _, expanded in retry_queue]
    )
    for position, (input_position, canonical, expanded) in enumerate(retry_queue):
        attn, tokens = retry_attention[position]
        retry_result, _ = self._map_from_attention(
            rxn_smiles=expanded,
            attn=attn,
            tokens=tokens,
            one_to_one_correspondence=True,
            **scoring_kwargs,
        )
        if not retry_result["selected_mapping"]:
            # Retry failed: keep the first-pass result.
            continue
        # Report against the un-expanded reaction the caller actually asked about.
        retry_result["original_smiles"] = canonical
        retry_result["selected_mapping"] = self._strip_unmapped_reactant_fragments(
            retry_result["selected_mapping"],
            canonical,
        )
        results[input_position] = retry_result

    return results

mask_attn_matrix(attn, string_info_dict)

Masks the attention matrix to set the attention probability for certain tokens to 0.

Parameters:

Name Type Description Default
attn ndarray

The attention matrix to be masked.

required
string_info_dict StringInfoDict

Token-layout information for the reaction SMILES. The function reads the keys reactants_start_index (the index of the first reactant token), products_start_index (the index of the first product token), products_end_index (the index of the last product token), non_atom_tokens (a list of indices of non-atom tokens) and atom_tokens_dict (a dictionary mapping atom numbers to a list of token indices). The reactants_end_index key is not read by this function.

required

Returns:

Type Description
Tuple[ndarray, ndarray]

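A tuple (probs, exp_logits): the masked, row-normalized attention probabilities, and the exponentiated masked logits from which they were computed.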

Source code in agave_chem/mappers/neural/neural_mapper.py
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
def mask_attn_matrix(
    self,
    attn: np.ndarray,
    string_info_dict: StringInfoDict,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Mask disallowed token pairs in an attention-logit matrix and renormalize.

    The same set of token pairs is suppressed twice: first in the logits
    (filled with -1e6) so they carry no weight in the row-wise softmax, then
    in the resulting probabilities (filled with 0). Suppressed pairs are:

    - reactant tokens paired with reactant tokens,
    - product tokens paired with product tokens,
    - every row and column of a non-atom token, except the last entry of
      ``non_atom_tokens``, which is skipped,
    - tokens belonging to different groups of ``atom_tokens_dict``.

    In the logit pass the last row and column of the matrix are left out of
    the atom-group masking; in the probability pass they are not.

    Args:
        attn: Attention logits. Modified in place by the logit pass.
        string_info_dict: Token-layout information. The keys read are
            ``reactants_start_index``, ``products_start_index``,
            ``products_end_index``, ``non_atom_tokens`` and
            ``atom_tokens_dict``.

    Returns:
        A tuple ``(probs, exp_logits)``: the masked, row-normalized
        probabilities, and the exponentiated (row-max shifted) masked logits
        they were computed from. ``exp_logits`` is not re-masked.
    """
    reactant_span = slice(
        string_info_dict["reactants_start_index"],
        string_info_dict["products_start_index"] - 1,
    )
    product_span = slice(
        string_info_dict["products_start_index"],
        string_info_dict["products_end_index"] + 1,
    )
    # The final non-atom token is deliberately left unmasked.
    skipped_tokens = np.asarray(
        string_info_dict["non_atom_tokens"][:-1], dtype=np.int64
    )
    atom_groups = [
        np.asarray(token_indices, dtype=np.int64)
        for token_indices in string_info_dict["atom_tokens_dict"].values()
    ]

    def _blank_out(matrix: np.ndarray, fill: float, keep_last: bool) -> None:
        # Write `fill` into every disallowed (row, column) pair of `matrix`.
        matrix[reactant_span, reactant_span] = fill
        matrix[product_span, product_span] = fill
        matrix[skipped_tokens, :] = fill
        matrix[:, skipped_tokens] = fill

        final = matrix.shape[0] - 1
        for members in atom_groups:
            others = np.ones(matrix.shape[1], dtype=bool)
            if keep_last:
                # Leave the last row/column untouched by the atom-group mask.
                members = members[members != final]
                others[members] = False
                others[final] = False
            else:
                others[members] = False
            matrix[np.ix_(members, others)] = fill
            matrix[np.ix_(others, members)] = fill

    # Pass 1: suppress in logit space (mutates the caller's array).
    _blank_out(attn, -1e6, keep_last=True)

    # Row-wise softmax; subtracting the row maximum keeps exp() from overflowing.
    row_peak = np.max(attn, axis=1, keepdims=True)
    exp_logits = np.exp(attn - row_peak)
    probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)

    # Pass 2: force the suppressed pairs to exactly zero probability.
    _blank_out(probs, 0, keep_last=False)

    return probs, exp_logits

remove_non_atom_rows_and_columns(attn, string_info_dict)

Remove non-atom tokens from attention matrix.

Parameters:

Name Type Description Default
attn ndarray

The attention matrix.

required
string_info_dict StringInfoDict

A dictionary containing information about the tokens in the reaction SMILES string.

required

Returns:

Type Description
ndarray

np.ndarray: The attention matrix with non-atom tokens removed.

Source code in agave_chem/mappers/neural/neural_mapper.py
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
def remove_non_atom_rows_and_columns(
    self, attn: np.ndarray, string_info_dict: StringInfoDict
) -> np.ndarray:
    """
    Drop the rows and columns of non-atom tokens from an attention matrix.

    Non-atom token indices up to and including ``reactants_end_index`` are
    removed as columns, using the index as-is. Non-atom token indices at or
    after ``products_start_index`` are removed as rows, after subtracting
    ``products_start_index``.

    Args:
        attn (np.ndarray): The attention matrix. Presumably rows are product
            tokens and columns are reactant tokens; confirm against the
            caller.
        string_info_dict (StringInfoDict): Token-layout information. The keys
            read are ``non_atom_tokens``, ``reactants_end_index`` and
            ``products_start_index``.

    Returns:
        np.ndarray: A new array with the selected columns and rows deleted;
            the input array is not modified.
    """
    non_atom = np.asarray(string_info_dict["non_atom_tokens"], dtype=int)
    last_reactant = string_info_dict["reactants_end_index"]
    first_product = string_info_dict["products_start_index"]

    # Reactant-side indices address columns directly.
    reactant_columns = non_atom[non_atom <= last_reactant]
    # Product-side indices are shifted so they count from the first product row.
    product_rows = non_atom[non_atom >= first_product] - first_product

    columns_removed = np.delete(attn, reactant_columns, axis=1)
    return np.delete(columns_removed, product_rows, axis=0)

TemplateReactionMapper

Bases: ReactionMapper

Expert template reaction classification and atom-mapping

Source code in agave_chem/mappers/template/template_mapper.py
 116
 117
 118
 119
 120
 121
 122
 123
 124
 125
 126
 127
 128
 129
 130
 131
 132
 133
 134
 135
 136
 137
 138
 139
 140
 141
 142
 143
 144
 145
 146
 147
 148
 149
 150
 151
 152
 153
 154
 155
 156
 157
 158
 159
 160
 161
 162
 163
 164
 165
 166
 167
 168
 169
 170
 171
 172
 173
 174
 175
 176
 177
 178
 179
 180
 181
 182
 183
 184
 185
 186
 187
 188
 189
 190
 191
 192
 193
 194
 195
 196
 197
 198
 199
 200
 201
 202
 203
 204
 205
 206
 207
 208
 209
 210
 211
 212
 213
 214
 215
 216
 217
 218
 219
 220
 221
 222
 223
 224
 225
 226
 227
 228
 229
 230
 231
 232
 233
 234
 235
 236
 237
 238
 239
 240
 241
 242
 243
 244
 245
 246
 247
 248
 249
 250
 251
 252
 253
 254
 255
 256
 257
 258
 259
 260
 261
 262
 263
 264
 265
 266
 267
 268
 269
 270
 271
 272
 273
 274
 275
 276
 277
 278
 279
 280
 281
 282
 283
 284
 285
 286
 287
 288
 289
 290
 291
 292
 293
 294
 295
 296
 297
 298
 299
 300
 301
 302
 303
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
class TemplateReactionMapper(ReactionMapper):
    """
    Expert template reaction classification and atom-mapping
    """

    def __init__(
        self,
        mapper_name: str,
        mapper_weight: float = 3,
        custom_smirks_patterns: List[SmirksPattern] | None = None,
        use_default_smirks_patterns: bool = True,
        max_transforms: int = 1000,
        max_tautomers: int = 1000,
        use_mcs_mapping: bool = True,
    ):
        """
        Initialize the TemplateReactionMapper instance.

        Args:
            mapper_name (str): Name passed through to the ReactionMapper base class.
            mapper_weight (float): Weight passed through to the ReactionMapper base
                class.
            custom_smirks_patterns (Optional[List[SmirksPattern]]): Custom SMIRKS
                patterns. Each entry must have exactly the keys 'name', 'smirks'
                and 'superclass_id'; 'name' and 'smirks' must be strings and
                'superclass_id' an integer or None.
            use_default_smirks_patterns (bool): Whether to use the default SMIRKS
                patterns loaded from ``smirks_patterns_with_children.json``.
            max_transforms (int): Value given to the tautomer enumerator's
                ``SetMaxTransforms``.
            max_tautomers (int): Value given to the tautomer enumerator's
                ``SetMaxTautomers``.
            use_mcs_mapping (bool): When True an MCSReactionMapper is created and
                stored on ``_mcs_mapper``; otherwise ``_mcs_mapper`` stays None.

        Raises:
            TypeError: If ``custom_smirks_patterns`` is not a list, if an entry is
                not dictionary-like, or if an entry has unexpected keys or value
                types.
        """

        super().__init__("template", mapper_name, mapper_weight)

        if custom_smirks_patterns is not None:
            if not isinstance(custom_smirks_patterns, list):
                raise TypeError(
                    "Invalid input: custom_smirks_patterns must be a list of dictionaries."
                )
            required_keys = {"name", "smirks", "superclass_id"}
            for pattern in custom_smirks_patterns:
                # Entries without a ``keys`` method (e.g. strings) must surface as
                # the documented TypeError rather than a raw AttributeError.
                try:
                    pattern_keys = set(pattern.keys())
                except AttributeError:
                    raise TypeError(
                        "Invalid input: custom_smirks_patterns must be a list of dictionaries."
                    ) from None
                if pattern_keys != required_keys:
                    raise TypeError(
                        "Invalid input: each dictionary in custom_smirks_patterns must have 'name', 'smirks', and 'superclass_id' keys."
                    )
                for key, value in pattern.items():
                    if key == "superclass_id":
                        if value is not None and not isinstance(value, int):
                            raise TypeError(
                                "Invalid input: 'superclass_id' value must be an integer or None."
                            )
                    else:
                        if not isinstance(value, str):
                            raise TypeError(
                                "Invalid input: 'name' and 'smirks' values must be strings."
                            )

        self._custom_smirks_patterns = custom_smirks_patterns
        self._use_default_smirks_patterns = use_default_smirks_patterns

        # The default patterns are always loaded; whether they are actually used
        # is decided later in _initialize_smirks_patterns.
        smirks_patterns_file = files("agave_chem.datafiles.smirks_patterns").joinpath(
            "smirks_patterns_with_children.json"
        )
        # JSON is UTF-8 by specification; do not depend on the locale encoding.
        with smirks_patterns_file.open("r", encoding="utf-8") as f:
            self._uninitialized_smirks_patterns = json.load(f)
        # Filled lazily by _initialize_smirks_patterns.
        self._initialized_smirks_patterns: Optional[List[InitializedSmirksPattern]] = (
            None
        )

        self._tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
        self._tautomer_enumerator.SetMaxTransforms(max_transforms)
        self._tautomer_enumerator.SetMaxTautomers(max_tautomers)

        self._mcs_mapper = None
        if use_mcs_mapping:
            self._mcs_mapper = MCSReactionMapper(
                mapper_name="mcs_for_template", mapper_weight=1
            )

    def _initialize_smirks_patterns(self) -> None:
        """Initialize SMIRKS patterns."""
        if self._initialized_smirks_patterns is not None:
            return

        if self._use_default_smirks_patterns and self._custom_smirks_patterns is None:
            smirks_patterns = self._uninitialized_smirks_patterns
        elif self._custom_smirks_patterns and not self._use_default_smirks_patterns:
            smirks_patterns = self._custom_smirks_patterns
        elif self._custom_smirks_patterns and self._use_default_smirks_patterns:
            smirks_patterns = (
                self._custom_smirks_patterns + self._uninitialized_smirks_patterns
            )
        else:
            raise TypeError(
                "Attempting to initialize AgaveChem with no SMIRKS patterns"
            )

        initialized_smirks_patterns: List[InitializedSmirksPattern] = []
        for pattern in smirks_patterns:
            pattern_priority = pattern.get(
                "priority", {"priority_class": None, "priority": None}
            )

            pattern_priority_tuple = (
                pattern_priority.get("priority_class", 0),
                pattern_priority.get("priority", 0),
            )

            if None in pattern_priority_tuple:
                pattern_priority_tuple = (0, 0)

            for child_pattern in pattern.get("child_smirks", []):
                reactants_smarts, products_smarts, rdc_rxn = (
                    self._initialize_template_data_from_child_patterns(child_pattern)
                )
                if (
                    reactants_smarts is None
                    or products_smarts is None
                    or rdc_rxn is None
                ):
                    continue

                initialized_smirks_patterns.append(
                    InitializedSmirksPattern(
                        name=str(pattern.get("name", "")),
                        superclass_id=str(pattern.get("superclass_id", "")),
                        class_id=str(pattern.get("class_id", "")),
                        subclass_id=str(pattern.get("subclass_id", "")),
                        class_str=str(pattern.get("class_str", "")),
                        products_smarts=products_smarts,
                        reactants_smarts=reactants_smarts,
                        products_fps=[
                            PatternFingerprint(frag) for frag in products_smarts
                        ],
                        reactants_fps=[
                            PatternFingerprint(frag) for frag in reactants_smarts
                        ],
                        rdc_rxn=rdc_rxn,
                        parent_smirks=str(pattern.get("smirks", "")),
                        child_smirks=str(child_pattern),
                        template_name=str(pattern.get("name", "")),
                        priority=pattern_priority_tuple,
                    )
                )

        self._initialized_smirks_patterns = initialized_smirks_patterns

        return

    def _initialize_template_data_from_child_patterns(
        self,
        child_smirks: str,
    ) -> Tuple[
        Optional[List[Chem.Mol]],
        Optional[List[Chem.Mol]],
        Optional[rdc.rdchiralReaction],
    ]:
        """
        Initialize template data from child SMIRKS pattern.

        Args:
            child_smirks (str): Child SMIRKS pattern.

        Returns:
            Tuple[Optional[List[Chem.Mol]], Optional[List[Chem.Mol]], Optional[rdc.rdchiralReaction]]:
            Tuple of reactants SMARTS, products SMARTS, and rdchiral reaction.
        """
        products_smarts = [
            Chem.MolFromSmarts(smarts)
            for smarts in child_smirks.split(">>")[0].split(".")
        ]

        if None in products_smarts:
            return None, None, None

        reactants_smarts = [
            Chem.MolFromSmarts(smarts)
            for smarts in child_smirks.split(">>")[1].split(".")
        ]

        if None in reactants_smarts:
            return None, None, None

        try:
            rdc_rxn = rdc.rdchiralReaction(child_smirks)
        except Exception as e:
            logger.warning(f"Error converting smirks to rdchiral reaction: {e}")
            return None, None, None

        return reactants_smarts, products_smarts, rdc_rxn

    def _prepare_reaction_data(
        self,
        reactants_str: str,
        products_str: str,
        unmapped_product_atom_islands: Optional[Dict[int, Set[int]]] = None,
    ) -> ReactionData:
        """
        Bundle everything the template machinery needs for one reaction.

        Args:
            reactants_str (str): Dot-delimited reactants SMILES string.
            products_str (str): Dot-delimited products SMILES string.
            unmapped_product_atom_islands (Optional[Dict[int, Set[int]]]): Mapping
                from island ID to a set of product atom indices that are still
                unmapped. None is replaced by an empty dictionary.

        Returns:
            ReactionData: RDKit molecules parsed from each reactant and product
            fragment, an rdchiral reactants object built from `products_str`,
            the reactant tautomer SMILES dictionary, the reactant fragment-count
            dictionary, the island dictionary, and pattern fingerprints for the
            product and reactant molecules.
        """
        islands = (
            {}
            if unmapped_product_atom_islands is None
            else unmapped_product_atom_islands
        )

        # One molecule per "."-separated fragment; products are parsed first.
        parsed_products = [
            Chem.MolFromSmiles(fragment) for fragment in products_str.split(".")
        ]
        parsed_reactants = [
            Chem.MolFromSmiles(fragment) for fragment in reactants_str.split(".")
        ]

        # rdchiral works retrosynthetically, so its "reactants" object is built
        # from the product SMILES.
        rdchiral_products = rdc.rdchiralReactants(products_str)
        reactant_tautomers = self._enumerate_tautomer_smiles(reactants_str)
        reactant_fragment_counts = self._get_fragment_count_dict(reactants_str)
        product_fingerprints = [PatternFingerprint(mol) for mol in parsed_products]
        reactant_fingerprints = [PatternFingerprint(mol) for mol in parsed_reactants]

        return ReactionData(
            products_mols=parsed_products,
            reactants_mols=parsed_reactants,
            rdc_products=rdchiral_products,
            tautomers_reactants=reactant_tautomers,
            fragment_count_reactants=reactant_fragment_counts,
            unmapped_product_atom_islands=islands,
            product_mol_fps=product_fingerprints,
            reactant_mol_fps=reactant_fingerprints,
        )

    def _enumerate_tautomer_smiles(self, smiles: str) -> Dict[str, List[str]]:
        """
        Collect canonical tautomer SMILES for every fragment of a SMILES string.

        Args:
            smiles (str): A SMILES string whose fragments are separated by `.`.

        Returns:
            Dict[str, List[str]]: Maps each fragment string exactly as it appears
                in `smiles` to the de-duplicated list of canonicalized SMILES of
                its enumerated tautomers plus the fragment itself. Fragments
                that RDKit cannot parse map to an empty list.

        """
        tautomers_by_fragment: Dict[str, List[str]] = {}
        for fragment in smiles.split("."):
            parsed = Chem.MolFromSmiles(fragment)
            if parsed is None:
                tautomers_by_fragment[fragment] = []
                continue

            candidates = [
                Chem.MolToSmiles(tautomer)
                for tautomer in self._tautomer_enumerator.Enumerate(parsed)
            ]
            # The input fragment itself is always part of the candidate list.
            candidates.append(fragment)

            canonical_candidates = []
            for candidate in candidates:
                if not candidate:
                    continue
                canonical_candidates.append(
                    canonicalize_smiles(candidate, canonicalize_tautomer=False)
                )

            # Going through a set removes duplicates.
            tautomers_by_fragment[fragment] = list(set(canonical_candidates))
        return tautomers_by_fragment

    def _get_fragment_count_dict(self, smiles: str) -> Dict[str, int]:
        """
        Build a fragment-count mapping from a dot-delimited SMILES string.

        Args:
            smiles (str): A SMILES string where disconnected fragments are delimited by `.`.

        Returns:
            Dict[str, int]: A mapping from the RDKit-canonical SMILES of each
            fragment to the number of times that fragment appears in `smiles`.
            Fragments that RDKit cannot parse are left out.
        """
        fragment_count_dict: Dict[str, int] = {}
        for fragment_str in smiles.split("."):
            mol = Chem.MolFromSmiles(fragment_str)
            # MolFromSmiles signals a parse failure with None; the check has to
            # happen before MolToSmiles, which does not accept None.
            if mol is None:
                continue
            canonical_fragment_str = Chem.MolToSmiles(mol)
            fragment_count_dict[canonical_fragment_str] = (
                fragment_count_dict.get(canonical_fragment_str, 0) + 1
            )

        return fragment_count_dict

    def _apply_templates_and_collect_outcomes(
        self,
        reaction_smiles_data: ReactionData,
        apply_multiple_smirks: bool = True,
        num_smirks_to_apply: int = 2,
    ) -> Dict[str, List[InitializedSmirksPattern]]:
        """
        Apply template SMIRKS patterns to a reaction and collect mapped outcomes.

        Every outcome returned by `_apply_templates` is converted into mapped
        reaction SMILES with `_build_reaction_smiles_from_outcome`. The result of
        that conversion is remembered per mapped outcome SMILES, so an outcome
        SMILES that was already converted (or already failed) is not converted
        again.

        Args:
            reaction_smiles_data (ReactionData): Reaction data containing reactants,
                products, and precomputed helper structures used for template
                application.
            apply_multiple_smirks (bool): Forwarded to `_apply_templates`.
            num_smirks_to_apply (int): Forwarded to `_apply_templates`.

        Returns:
            Dict[str, List[InitializedSmirksPattern]]: A mapping from mapped reaction
            SMILES to the list of SMIRKS patterns that produced it.
        """
        mapped_outcomes_smirks_dict: Dict[str, List[InitializedSmirksPattern]] = {}

        atom_mapped_product = self._generate_mapped_product_smiles(reaction_smiles_data)
        outcomes_and_applied_smirks = self._apply_templates(
            reaction_smiles_data,
            apply_multiple_smirks=apply_multiple_smirks,
            num_smirks_to_apply=num_smirks_to_apply,
        )

        # Mapped outcome SMILES -> mapped reaction SMILES it was resolved to.
        successfully_processed_reactants: Dict[str, str] = {}
        # Mapped outcome SMILES for which no reaction SMILES could be built;
        # a set keeps the repeated membership tests constant-time.
        unsuccessfully_processed_reactants: Set[str] = set()

        for outcome_and_applied_smirk in outcomes_and_applied_smirks:
            outcome_mapped_smiles = outcome_and_applied_smirk.get(
                "outcome_mapped_smiles"
            )
            outcome_applied_smirk = outcome_and_applied_smirk.get("applied_smirk")
            if outcome_mapped_smiles is None or outcome_applied_smirk is None:
                continue

            if outcome_mapped_smiles in unsuccessfully_processed_reactants:
                continue

            if outcome_mapped_smiles in successfully_processed_reactants:
                # Already resolved: only record the additional template.
                good_reaction_smiles = successfully_processed_reactants[
                    outcome_mapped_smiles
                ]
                mapped_outcomes_smirks_dict[good_reaction_smiles].append(
                    outcome_applied_smirk
                )
                continue

            outcome_reaction_smiles_dict = self._build_reaction_smiles_from_outcome(
                outcome_and_applied_smirk,
                reaction_smiles_data,
                atom_mapped_product,
            )

            if not outcome_reaction_smiles_dict:
                # Nothing to merge; also avoids calling .items() on a falsy
                # (possibly None) result.
                unsuccessfully_processed_reactants.add(outcome_mapped_smiles)
                continue

            for mapped_smiles, smirks_list in outcome_reaction_smiles_dict.items():
                if mapped_smiles in mapped_outcomes_smirks_dict:
                    mapped_outcomes_smirks_dict[mapped_smiles].extend(smirks_list)
                else:
                    mapped_outcomes_smirks_dict[mapped_smiles] = smirks_list
                # When several reaction SMILES come back, the last one wins.
                successfully_processed_reactants[outcome_mapped_smiles] = mapped_smiles

        return mapped_outcomes_smirks_dict

    def _generate_mapped_product_smiles(
        self, reaction_smiles_data: ReactionData
    ) -> str:
        """
        Produce the atom-mapped SMILES of the reaction product.

        Each atom of the molecule held by the rdchiral object is given the
        atom-map number that rdchiral associates with its index, and the
        molecule is then written out as SMILES. The atom-map numbers are set on
        the molecule in place.

        Args:
            reaction_smiles_data (ReactionData): Reaction data; only its
                `rdc_products` entry is read.

        Returns:
            str: SMILES of the product molecule with atom-map numbers applied.
        """
        rdchiral_wrapper = reaction_smiles_data["rdc_products"]

        # rdchiral is written from the retrosynthetic point of view, so the
        # attribute it calls "reactants" holds our product molecule.
        product_mol = rdchiral_wrapper.reactants

        for product_atom in product_mol.GetAtoms():
            map_number = rdchiral_wrapper.idx_to_mapnum(product_atom.GetIdx())
            product_atom.SetAtomMapNum(map_number)

        return Chem.MolToSmiles(product_mol)

    def _fragment_fits_some_island(
        self,
        product_mols: List[Chem.Mol],
        unmapped_product_atom_islands: Dict[int, Set[int]],
        products_smarts_fragment: Chem.Mol,
    ) -> bool:
        """
        Check whether a SMARTS fragment has any overlap with an unmapped island.
        We can't do a full subset check because the SMARTS for the template
        may be overspecified, and include atoms that aren't actually changing in the
        reaction, and thus *are* mapped with the MCS mapper.

        Args:
            product_mols (List[Chem.Mol]): List of product mols.
            unmapped_product_atom_islands (Dict[int, Set[int]]): Dictionary mapping island IDs to sets of atom indices.
            products_smarts_fragment (Chem.Mol): SMARTS query fragment used to find
                substructure matches in the product molecules.

        Returns:
            bool: True if any substructure match has any overlap with an unmapped island;
            otherwise False.
        """
        for mol in product_mols:
            matches = mol.GetSubstructMatches(products_smarts_fragment)
            matches_set = [set(match) for match in matches]
            for match_set in matches_set:
                for island in unmapped_product_atom_islands.values():
                    if match_set & island:
                        return True
        return False

    def _matching_island_ids(
        self,
        unmapped_product_atom_islands_for_rdchiral: Dict[int, Set[int]],
        outcome_atom_map_indices: List[int],
    ) -> List[int]:
        """
        Find island IDs that contain any atom map indices of a given outcome.
        Can't do a full subset check for similar reasons as in
        _fragment_fits_some_island.

        Args:
            unmapped_product_atom_islands_for_rdchiral (Dict[int, Set[int]]): A dictionary
                mapping island IDs to sets of atom map indices for unmapped atoms in
                the product molecules (1-based indexing for rdchiral).
            outcome_atom_map_indices (List[int]): A list of atom map indices for
                a reaction outcome.

        Returns:
            List[int]: A list of island IDs where any atom map indices of the
                outcome are contained.
        """
        return [
            island_id
            for island_id, island in unmapped_product_atom_islands_for_rdchiral.items()
            if set(outcome_atom_map_indices) & island
        ]

    def _passes_fingerprint_screen(
        self,
        products_fps: List[DataStructs.ExplicitBitVect],
        reactants_fps: List[DataStructs.ExplicitBitVect],
        product_mol_fps: List[DataStructs.ExplicitBitVect],
        reactant_mol_fps: List[DataStructs.ExplicitBitVect],
    ) -> bool:
        """
        Check whether a template passes the fingerprint pre-screen against reaction molecule fingerprints.

        Fast bit-level check that eliminates templates whose required structural bits are
        absent from the reaction molecule fingerprints before running the more expensive
        substructure search.

        Args:
            products_fps (List[DataStructs.ExplicitBitVect]): Pattern fingerprints for the template product fragments.
            reactants_fps (List[DataStructs.ExplicitBitVect]): Pattern fingerprints for the template reactant fragments.
            product_mol_fps (List[DataStructs.ExplicitBitVect]): Pattern fingerprints for the reaction product molecules.
            reactant_mol_fps (List[DataStructs.ExplicitBitVect]): Pattern fingerprints for the reaction reactant molecules.

        Returns:
            bool: True if every template fragment fingerprint is subsumed by at least one
            reaction molecule fingerprint for both products and reactants; False otherwise.
        """
        if not all(
            any(
                DataStructs.AllProbeBitsMatch(q_fp, mol_fp)
                for mol_fp in product_mol_fps
            )
            for q_fp in products_fps
        ):
            return False

        if not all(
            any(
                DataStructs.AllProbeBitsMatch(q_fp, mol_fp)
                for mol_fp in reactant_mol_fps
            )
            for q_fp in reactants_fps
        ):
            return False

        return True

    def _passes_substructure_check(
        self,
        products_smarts: List[Chem.Mol],
        reactants_smarts: List[Chem.Mol],
        product_mols: List[Chem.Mol],
        reactant_mols: List[Chem.Mol],
        unmapped_product_atom_islands: Dict[int, Set[int]],
        has_islands: bool,
    ) -> bool:
        """
        Check whether a template's SMARTS fragments match the reaction molecules via substructure search.

        When unmapped product atom islands are present, the product check uses
        _fragment_fits_some_island to ensure each template fragment overlaps at least one
        unmapped island, rather than performing a plain substructure match.

        Args:
            products_smarts (List[Chem.Mol]): Parsed SMARTS fragments for the template products.
            reactants_smarts (List[Chem.Mol]): Parsed SMARTS fragments for the template reactants.
            product_mols (List[Chem.Mol]): RDKit molecule objects for the reaction products.
            reactant_mols (List[Chem.Mol]): RDKit molecule objects for the reaction reactants.
            unmapped_product_atom_islands (Dict[int, Set[int]]): Mapping from island ID to sets
                of unmapped atom indices in the product molecules (0-based RDKit indexing).
            has_islands (bool): Whether any unmapped product atom islands exist.

        Returns:
            bool: True if all template fragments match the respective reaction molecules;
            False otherwise.
        """
        # When islands exist, _fragment_fits_some_island subsumes the plain
        # substruct-match check (empty matches → no island overlap → False),
        # avoiding a second round of substructure searches.
        if has_islands:
            product_check_passes = all(
                self._fragment_fits_some_island(
                    product_mols, unmapped_product_atom_islands, frag
                )
                for frag in products_smarts
            )
        else:
            product_check_passes = all(
                any(mol.HasSubstructMatch(frag) for mol in product_mols)
                for frag in products_smarts
            )

        if not product_check_passes:
            return False

        return all(
            any(mol.HasSubstructMatch(frag) for mol in reactant_mols)
            for frag in reactants_smarts
        )

    def _collect_outcomes_for_template(
        self,
        template: InitializedSmirksPattern,
        rdc_products: rdc.rdchiralReactants,
        has_islands: bool,
        unmapped_product_atom_islands_for_rdchiral: Dict[int, Set[int]],
        num_smirks_applied: int,
    ) -> List[AppliedSmirkData]:
        """
        Apply one template with rdchiral and gather its island-matching outcomes.

        For each outcome reported by rdchiral, one AppliedSmirkData entry is
        created per unmapped island that overlaps the outcome's atom map
        indices. Outcomes without an overlapping island are dropped.

        NOTE(review): when `has_islands` is False no island can match, so this
        method returns an empty list for every template - confirm that this is
        intended.

        Args:
            template (InitializedSmirksPattern): The initialized SMIRKS template to apply.
            rdc_products (rdc.rdchiralReactants): The rdchiral reactants object derived from
                the reaction product SMILES.
            has_islands (bool): Whether any unmapped product atom islands exist.
            unmapped_product_atom_islands_for_rdchiral (Dict[int, Set[int]]): Mapping from
                island ID to sets of atom map indices for unmapped product atoms
                (1-based rdchiral indexing).
            num_smirks_applied (int): Number of SMIRKS patterns already applied in the
                current mapping chain; each outcome records this value plus one.

        Returns:
            List[AppliedSmirkData]: Outcome data for each valid template application.
            If an exception is raised while running rdchiral or processing its
            results, a warning is logged and whatever was collected up to that
            point is returned.
        """
        reaction = template["rdc_rxn"]
        collected: List[AppliedSmirkData] = []
        applied_count = num_smirks_applied + 1
        try:
            _, mapped_outcomes = rdc.rdchiralRun(
                reaction, rdc_products, return_mapped=True, combine_enantiomers=False
            )

            for unmapped_smiles, mapped_info in mapped_outcomes.items():
                # mapped_info[0] is the mapped SMILES, mapped_info[1] the atom
                # map indices reported for this outcome.
                island_ids = (
                    self._matching_island_ids(
                        unmapped_product_atom_islands_for_rdchiral, mapped_info[1]
                    )
                    if has_islands
                    else []
                )

                # TODO: check whether unmapped_smiles parses with
                # Chem.MolFromSmiles to identify bad templates.
                collected.extend(
                    [
                        AppliedSmirkData(
                            outcome_unmapped_smiles=unmapped_smiles,
                            outcome_mapped_smiles=mapped_info[0],
                            outcome_atom_map_indices=mapped_info[1],
                            applied_smirk=template,
                            outcome_to_island_id=island_id,
                            num_smirks_applied=applied_count,
                        )
                        for island_id in island_ids
                    ]
                )

        except Exception as e:
            logger.warning(f"Error applying templates: {e}")

        return collected

    def _find_multi_smirks_outcomes(
        self,
        single_island_outcomes: List[AppliedSmirkData],
        all_island_ids: Set[int],
        rdc_products: rdc.rdchiralReactants,
        unmapped_product_atom_islands_for_rdchiral: Dict[int, Set[int]],
        max_combinations: int = 500,
    ) -> List[AppliedSmirkData]:
        """
        Generate combined multi-SMIRKS outcomes by merging child SMIRKS patterns from
        single-island outcomes across all unmapped product atom islands.

        For each combination of one outcome per island, the corresponding child SMIRKS
        are combined into a single multi-reaction SMIRKS (with atom map numbers offset
        by multiples of 100 to prevent collisions) and applied to the original product
        via rdchiral. Outcomes that share the same child_smirks for a given island are
        deduplicated before combining to reduce redundant rdchiral calls. Each rdchiral
        outcome is validated to confirm its changed atom indices intersect every island,
        discarding results where the combined template only fired on a subset of islands.

        Args:
            single_island_outcomes (List[AppliedSmirkData]): Outcomes from single-template
                applications, each associated with exactly one island via outcome_to_island_id.
            all_island_ids (Set[int]): The complete set of unmapped product atom island IDs
                that must all be covered.
            rdc_products (rdc.rdchiralReactants): The rdchiral reactants object for the
                original product, reused for each combined SMIRKS application.
            unmapped_product_atom_islands_for_rdchiral (Dict[int, Set[int]]): Mapping from
                island ID to sets of 1-based atom indices used to validate that each outcome
                covers every island.
            max_combinations (int): Maximum number of template combinations to attempt
                before stopping, to prevent combinatorial explosion.

        Returns:
            List[AppliedSmirkData]: Combined outcomes where num_smirks_applied equals
            the number of islands and outcome_to_island_id is None. Returns an empty list
            if any island has no applicable templates or all combinations fail.
        """
        if not all_island_ids:
            return []

        by_island: Dict[int, List[AppliedSmirkData]] = defaultdict(list)
        seen_per_island_smirks: Dict[int, Set[str]] = defaultdict(set)
        seen_per_island_reactants: Dict[int, Set[str]] = defaultdict(set)

        for outcome in single_island_outcomes:
            iid = outcome["outcome_to_island_id"]
            if iid not in all_island_ids:
                continue
            child_smirks = outcome["applied_smirk"]["child_smirks"]
            unmapped_reactants = outcome["outcome_unmapped_smiles"]
            if (
                child_smirks not in seen_per_island_smirks[iid]
                and unmapped_reactants not in seen_per_island_reactants[iid]
            ):
                seen_per_island_smirks[iid].add(child_smirks)
                seen_per_island_reactants[iid].add(unmapped_reactants)
                by_island[iid].append(outcome)

        if not all(by_island[iid] for iid in all_island_ids):
            return []

        multi_outcomes: List[AppliedSmirkData] = []
        island_order = sorted(all_island_ids)

        num_applied_templates = 0
        for n, combo in enumerate(
            iterproduct(*[by_island[iid] for iid in island_order])
        ):
            if n >= max_combinations:
                logger.warning(
                    "max_combinations cap reached in _find_multi_smirks_outcomes"
                )
                break

            combined_smirks = _combine_child_smirks(
                [o["applied_smirk"]["child_smirks"] for o in combo]
            )

            try:
                combined_rxn = rdc.rdchiralReaction(combined_smirks)
                _, outcomes_dict = rdc.rdchiralRun(
                    combined_rxn,
                    rdc_products,
                    return_mapped=True,
                    combine_enantiomers=False,
                )
            except Exception as e:
                logger.warning(f"Error applying combined SMIRKS: {e}")
                continue

            num_applied_templates += len(outcomes_dict)
            if num_applied_templates > max_combinations:
                logger.warning(
                    "max_combinations cap reached in _find_multi_smirks_outcomes"
                )
                break

            composite = _make_composite_smirks_pattern(
                [o["applied_smirk"] for o in combo],
                combined_smirks,
                combined_rxn,
            )

            for k, v in outcomes_dict.items():
                outcome_indices = set(v[1])
                if not all(
                    outcome_indices & unmapped_product_atom_islands_for_rdchiral[iid]
                    for iid in all_island_ids
                ):
                    continue
                multi_outcomes.append(
                    AppliedSmirkData(
                        outcome_unmapped_smiles=k,
                        outcome_mapped_smiles=v[0],
                        outcome_atom_map_indices=list(v[1]),
                        applied_smirk=composite,
                        outcome_to_island_id=None,
                        num_smirks_applied=len(combo),
                    )
                )

        return multi_outcomes

    def _apply_templates(
        self,
        reaction_smiles_data: ReactionData,
        num_smirks_applied: int = 0,
        apply_multiple_smirks: bool = True,
        num_smirks_to_apply: int = 2,
    ) -> List[AppliedSmirkData]:
        """
        Apply all initialized SMIRKS templates to a reaction and collect mapped outcomes.

        Each template is screened with a fast fingerprint pre-check, then a substructure
        check, and finally run through rdchiral via _collect_outcomes_for_template. Outcomes
        are collected only when they overlap at least one unmapped product atom island (if
        any islands exist).

        Args:
            reaction_smiles_data (ReactionData): Reaction data containing molecule objects,
                rdchiral reactants, fingerprints, and unmapped product atom islands.
            num_smirks_applied (int): Number of SMIRKS patterns already applied in the
                current mapping chain.
            apply_multiple_smirks (bool): Whether to recursively apply multiple SMIRKS
                patterns (not yet implemented).
            num_smirks_to_apply (int): Maximum number of SMIRKS patterns to apply
                (not yet used).

        Returns:
            List[AppliedSmirkData]: A list of outcome data for all valid template applications.

        Raises:
            ValueError: If SMIRKS patterns were not initialized before calling this method.
        """
        product_mols = reaction_smiles_data["products_mols"]
        reactant_mols = reaction_smiles_data["reactants_mols"]
        rdc_products = reaction_smiles_data["rdc_products"]
        unmapped_product_atom_islands = reaction_smiles_data[
            "unmapped_product_atom_islands"
        ]
        product_mol_fps = reaction_smiles_data["product_mol_fps"]
        reactant_mol_fps = reaction_smiles_data["reactant_mol_fps"]

        # rdchiral uses 1-based indexing, but rdkit uses 0-based indexing
        # so we need another dictionary specifically for rdchiral
        unmapped_product_atom_islands_for_rdchiral = {
            key: {ele + 1 for ele in value}
            for key, value in unmapped_product_atom_islands.items()
        }

        has_islands = bool(unmapped_product_atom_islands)

        if self._initialized_smirks_patterns is None:
            raise ValueError("SMIRKS patterns were not initialized correctly.")

        outcomes_and_applied_smirks: List[AppliedSmirkData] = []

        if num_smirks_to_apply < len(unmapped_product_atom_islands_for_rdchiral):
            return outcomes_and_applied_smirks

        for template in self._initialized_smirks_patterns:
            ## TODO: Check for tautomer matches?

            if not self._passes_fingerprint_screen(
                template["products_fps"],
                template["reactants_fps"],
                product_mol_fps,
                reactant_mol_fps,
            ):
                continue

            if not self._passes_substructure_check(
                template["products_smarts"],
                template["reactants_smarts"],
                product_mols,
                reactant_mols,
                unmapped_product_atom_islands,
                has_islands,
            ):
                continue

            outcomes_and_applied_smirks.extend(
                self._collect_outcomes_for_template(
                    template,
                    rdc_products,
                    has_islands,
                    unmapped_product_atom_islands_for_rdchiral,
                    num_smirks_applied,
                )
            )

        if not apply_multiple_smirks:
            return outcomes_and_applied_smirks

        if len(unmapped_product_atom_islands_for_rdchiral) == 1:
            return outcomes_and_applied_smirks

        multi_outcomes = self._find_multi_smirks_outcomes(
            single_island_outcomes=outcomes_and_applied_smirks,
            all_island_ids=set(unmapped_product_atom_islands.keys()),
            rdc_products=rdc_products,
            unmapped_product_atom_islands_for_rdchiral=unmapped_product_atom_islands_for_rdchiral,
        )

        return multi_outcomes

    def _remove_spectator_mappings(self, smiles: str) -> str:
        """
        Strip spectator atom-map numbers from every fragment of a SMILES string.

        The input is split on ``.``; each fragment is parsed with RDKit and any
        atom whose map number is 900 or higher has that map number reset to 0.
        Fragments that RDKit cannot parse are left out of the result.

        Args:
            smiles (str): A SMILES string, possibly holding several fragments
                separated by `.`.

        Returns:
            str: The re-serialized fragments joined with `.`, with spectator
                atom mappings cleared.
        """
        cleaned_fragments = []
        for fragment_smiles in smiles.split("."):
            fragment_mol = Chem.MolFromSmiles(fragment_smiles)
            if fragment_mol is None:
                # Unparseable fragments are dropped rather than raising.
                continue
            spectator_atoms = [
                atom for atom in fragment_mol.GetAtoms() if atom.GetAtomMapNum() >= 900
            ]
            for atom in spectator_atoms:
                atom.SetAtomMapNum(0)
            cleaned_fragments.append(Chem.MolToSmiles(fragment_mol))
        return ".".join(cleaned_fragments)

    def _build_reaction_smiles_from_outcome(
        self,
        outcome_and_applied_smirk: AppliedSmirkData,
        reaction_smiles_data: ReactionData,
        atom_mapped_product: str,
    ) -> Dict[str, List[InitializedSmirksPattern]]:
        """
        Process a single SMIRKS application outcome to construct a finalized reaction SMILES.

        Removes spectator mappings from the mapped outcome, identifies and handles missing
        fragments by matching against original reactant tautomers, determines spectator
        molecules that should be added back, and assembles the final reaction SMILES string.

        Args:
            outcome_and_applied_smirk (AppliedSmirkData): Data containing the mapped outcome
                SMILES and the SMIRKS pattern that was applied.
            reaction_smiles_data (ReactionData): Reaction data containing tautomer
                dictionaries and fragment count information for the original reactants.
            atom_mapped_product (str): Atom-mapped product SMILES string.

        Returns:
            Dict[str, List[InitializedSmirksPattern]]:
                A dictionary mapping the finalized reaction SMILES (reactants + spectators >> product)
                to a list containing the applied SMIRKS pattern. Returns an empty dict if
                processing fails (e.g., missing fragments cannot be resolved).
        """
        mapped_outcome = outcome_and_applied_smirk.get("outcome_mapped_smiles")
        applied_smirk = outcome_and_applied_smirk.get("applied_smirk")

        if mapped_outcome is None or applied_smirk is None:
            return {}

        tautomers_reactants = reaction_smiles_data["tautomers_reactants"]
        fragment_count_dict = reaction_smiles_data["fragment_count_reactants"]

        mapped_outcome = self._remove_spectator_mappings(mapped_outcome)

        missing_fragments, found_fragments = self._find_missing_fragments(
            mapped_outcome, tautomers_reactants
        )

        if len(missing_fragments) != 0:
            fragment_mapped_dict = self._validate_and_map_missing_fragments(
                missing_fragments,
                found_fragments,
                tautomers_reactants,
            )

            if len(fragment_mapped_dict) != len(missing_fragments):
                return {}

            for k, v in fragment_mapped_dict.items():
                if k not in mapped_outcome:
                    return {}
                mapped_outcome = mapped_outcome.replace(k, v)

        unmapped_canonical_smiles_for_mapped_smiles = [
            canonicalize_smiles(fragment_smiles, canonicalize_tautomer=False)
            for fragment_smiles in mapped_outcome.split(".")
        ]

        spectators = []
        for fragment, fragment_count in fragment_count_dict.items():
            num_occurrences_mapped = unmapped_canonical_smiles_for_mapped_smiles.count(
                fragment
            )
            missing_fragment_count = fragment_count - num_occurrences_mapped
            if missing_fragment_count > 0:
                spectators.extend([fragment] * missing_fragment_count)

        reactants = mapped_outcome.split(".")
        reactants_and_spectators = reactants + spectators

        finalized_reaction_smiles = (
            ".".join(reactants_and_spectators) + ">>" + atom_mapped_product
        )

        return {finalized_reaction_smiles: [applied_smirk]}

    def _find_missing_fragments(
        self, mapped_outcome: str, unmapped_reactants: Dict[str, List[str]]
    ) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
        """
        Partition the fragments of a mapped outcome by whether they match a reactant.

        Every `.`-separated fragment of ``mapped_outcome`` is passed through
        ``canonicalize_smiles`` (tautomer canonicalization off) and looked up among
        all SMILES listed in the values of ``unmapped_reactants``.

        Args:
            mapped_outcome (str): Atom-mapped outcome SMILES with fragments
                separated by dots.
            unmapped_reactants (Dict[str, List[str]]): Original reactant SMILES
                mapped to lists of enumerated tautomer SMILES.

        Returns:
            Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
                - First list: (canonicalized_fragment, original_mapped_fragment)
                    pairs for fragments NOT present among the reactant SMILES.
                - Second list: pairs of the same shape for fragments that ARE
                    present among the reactant SMILES.
        """
        known_reactant_smiles = {
            tautomer
            for tautomers in unmapped_reactants.values()
            for tautomer in tautomers
        }

        missing: List[Tuple[str, str]] = []
        found: List[Tuple[str, str]] = []
        for mapped_fragment in mapped_outcome.split("."):
            canonical_fragment = canonicalize_smiles(
                mapped_fragment, canonicalize_tautomer=False
            )
            bucket = found if canonical_fragment in known_reactant_smiles else missing
            bucket.append((canonical_fragment, mapped_fragment))

        return missing, found

    def _validate_and_map_missing_fragments(
        self,
        missing_fragments: List[Tuple[str, str]],
        found_fragments: List[Tuple[str, str]],
        unmapped_reactants: Dict[str, List[str]],
    ) -> Dict[str, str]:
        """
        Identify and map missing reaction fragments to unmapped reactants.

        Validates that missing fragments are substructures of unmapped reactants, then
        attempts to map each missing fragment to its corresponding location in the
        unmapped reactant molecules by transferring atom mapping information.

        Args:
            missing_fragments (List[Tuple[str, str]]): List of tuples containing
                (unmapped_fragment, mapped_fragment) for fragments not found in reactants.
            found_fragments (List[Tuple[str, str]]): List of tuples containing
                (unmapped_fragment, mapped_fragment) for fragments already mapped to reactants.
            unmapped_reactants (Dict[str, List[str]]): Dictionary mapping original reactant SMILES
                to lists of enumerated tautomer SMILES for unmapped reactants.

        Returns:
            Dict[str, str]: Mapping from mapped fragment SMILES to their corresponding
                identified unmapped fragment SMILES with atom mapping transferred.
                Returns an empty dict if fragments cannot be mapped or if multiple
                possible mappings exist for any fragment.

        Note:
            If multiple possible fragments are identified for any reaction SMARTS
            substructure, a warning is logged and an empty dict is returned to
            indicate ambiguous mapping.
        """
        all_fragments_substructs = self._are_fragments_substructures(
            missing_fragments, found_fragments, unmapped_reactants
        )
        if not all_fragments_substructs:
            return {}

        fragment_mapped_dict = self._identify_and_map_fragments(
            missing_fragments,
            found_fragments,
            unmapped_reactants,
        )

        filtered_fragment_mapped_dict = {}
        for k, v in fragment_mapped_dict.items():
            if len(v) > 1:
                logger.warning(
                    "Multiple possible fragments identified for reaction SMARTS substructure"
                )
                return {}
            filtered_fragment_mapped_dict[k] = v[0]

        return filtered_fragment_mapped_dict

    def _are_fragments_substructures(
        self,
        missing_fragments: List[Tuple[str, str]],
        found_fragments: List[Tuple[str, str]],
        unmapped_reactants: Dict[str, List[str]],
    ) -> bool:
        """
        Check if missing fragments with wildcards are substructures of unmapped reactants.

        For each missing fragment that contains a wildcard ("*"), this method checks
        whether the fragment pattern matches as a substructure within any of the
        unmapped reactant fragments. This is used to validate whether missing fragments
        can be accounted for as parts of larger reactant molecules.

        Args:
            missing_fragments (List[Tuple[str, str]]): List of tuples containing
                (fragment_smarts, mapped_fragment) for fragments not yet mapped to reactants.
            found_fragments (List[Tuple[str, str]]): List of tuples containing
                (fragment_smarts, mapped_fragment) for fragments already mapped to reactants.
            unmapped_reactants (Dict[str, List[str]]): Dictionary mapping original reactant
                SMILES to lists of tautomer SMILES that have not been mapped to fragments.

        Returns:
            bool: True if all missing fragments with wildcards are found as substructures
                within unmapped reactants; False otherwise. Returns True if no wildcards
                are present in missing fragments.

        Note:
            Fragments without wildcards ("*") are skipped from substructure matching.
            If any SMARTS parsing fails, the method returns False immediately.
        """
        unmapped_found_fragments = [
            unmapped_fragment[0] for unmapped_fragment in found_fragments
        ]
        for fragment_str, _ in missing_fragments:
            if "*" not in fragment_str:
                continue

            query_mol = Chem.MolFromSmarts(fragment_str)
            if not query_mol:
                return False

            found_match = False
            for reactant_group in unmapped_reactants.values():
                for reactant_fragment_str in reactant_group:
                    if reactant_fragment_str in unmapped_found_fragments:
                        continue
                    reactant_mol = Chem.MolFromSmarts(reactant_fragment_str)
                    if not reactant_mol:
                        return False
                    reactant_mol.UpdatePropertyCache()
                    if reactant_mol.HasSubstructMatch(query_mol):
                        found_match = True
                        break

                if found_match:
                    break

            if not found_match:
                return False

        return True

    def _identify_and_map_fragments(
        self,
        missing_fragments: List[Tuple[str, str]],
        found_fragments: List[Tuple[str, str]],
        unmapped_reactants: Dict[str, List[str]],
    ) -> Dict[str, List[str]]:
        """ """
        unmapped_found_fragments = [ele[0] for ele in found_fragments]
        fragment_mapped_dict = {}
        for _, mapped_reactant_fragment in missing_fragments:
            fragment_found = False
            for orig_tautomer, tautomer_list in unmapped_reactants.items():
                for tautomer in tautomer_list:
                    if tautomer in unmapped_found_fragments:
                        continue

                    out = self._transfer_mapping(mapped_reactant_fragment, tautomer)

                    if not out:
                        continue

                    # TODO: Is this even needed if we just take all possible fragments in _validate_and_map_missing_fragments?
                    if len(tautomer_list) > 1:
                        mapped_enumerated_tautomers = list(
                            self._tautomer_enumerator.Enumerate(Chem.MolFromSmiles(out))
                        )
                        for mapped_enumerated_tautomer in mapped_enumerated_tautomers:
                            unmapped_tautomer_copy = Chem.Mol(
                                mapped_enumerated_tautomer
                            )
                            [
                                atom.SetAtomMapNum(0)
                                for atom in unmapped_tautomer_copy.GetAtoms()
                            ]
                            if Chem.MolToSmiles(
                                unmapped_tautomer_copy
                            ) == Chem.MolToSmiles(Chem.MolFromSmiles(orig_tautomer)):
                                out = Chem.MolToSmiles(mapped_enumerated_tautomer)
                                break

                    fragment_found = True

                    if mapped_reactant_fragment not in fragment_mapped_dict:
                        fragment_mapped_dict[mapped_reactant_fragment] = [out]
                    else:
                        existing_mapped_fragments = fragment_mapped_dict[
                            mapped_reactant_fragment
                        ]
                        existing_mapped_fragments.append(out)
                        fragment_mapped_dict[mapped_reactant_fragment] = sorted(
                            list(set(existing_mapped_fragments))
                        )

            if not fragment_found:
                return {}

        return fragment_mapped_dict

    def _transfer_mapping(
        self, mapped_substructure_smarts: str, full_molecule_smiles: str
    ) -> str | None:
        """
        Copy atom-map numbers from a substructure query onto a full molecule.

        The query is matched against the molecule. If there are several matches,
        they are only accepted when corresponding atoms of every pair of matches
        share the same canonical rank (computed with ``breakTies=False``); the
        first match is then used. Existing map numbers on the molecule are
        cleared, and query map numbers in the range [1, 899] are written to the
        matched atoms. Map numbers of 0 and of 900 or more are not transferred.

        Args:
            mapped_substructure_smarts (str): SMARTS of the substructure query,
                carrying the atom-map numbers to transfer.
            full_molecule_smiles (str): SMILES of the molecule that receives the
                atom-map numbers.

        Returns:
            str | None: SMILES of the molecule with transferred map numbers, or
                None if either string fails to parse, there is no substructure
                match, or there are multiple matches that are not equivalent.
        """
        query_mol = Chem.MolFromSmarts(mapped_substructure_smarts)
        if not query_mol:
            return None

        mol = Chem.MolFromSmiles(full_molecule_smiles)
        if not mol:
            return None

        matches = mol.GetSubstructMatches(query_mol)
        if not matches:
            return None

        # Atoms sharing a rank here are treated as symmetry-equivalent.
        rank_of_atom = list(Chem.rdmolfiles.CanonicalRankAtoms(mol, breakTies=False))
        matches_equivalent = all(
            rank_of_atom[idx_a] == rank_of_atom[idx_b]
            for match_a in matches
            for match_b in matches
            for idx_a, idx_b in zip(match_a, match_b)
        )

        if len(matches) != 1 and not matches_equivalent:
            return None

        chosen_match = matches[0]

        # Start from a clean slate so only transferred numbers remain.
        for atom in mol.GetAtoms():
            if atom.GetAtomMapNum() != 0:
                atom.SetAtomMapNum(0)

        for query_atom in query_mol.GetAtoms():
            map_num = query_atom.GetAtomMapNum()
            if not (0 < map_num < 900):
                continue
            target_atom = mol.GetAtomWithIdx(chosen_match[query_atom.GetIdx()])
            target_atom.SetAtomMapNum(map_num)

        return Chem.MolToSmiles(mol)

    def _get_unmapped_product_atom_islands(self, smiles: str) -> Dict[int, Set[int]]:
        """
        Group the unmapped atoms of a product SMILES into connected "islands".

        An atom counts as unmapped when its atom map number is 0. Two unmapped
        atoms belong to the same island when they are joined by a path of bonds
        that runs through unmapped atoms only.

        Args:
            smiles (str): Product SMILES string to analyze.

        Returns:
            Dict[int, Set[int]]: Island index (0..N-1, in discovery order) mapped
            to the set of RDKit atom indices that make up that island.

        Raises:
            ValueError: If the SMILES cannot be parsed into an RDKit molecule.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Could not parse SMILES: {smiles}")

        unmapped = {
            atom.GetIdx() for atom in mol.GetAtoms() if atom.GetAtomMapNum() == 0
        }

        seen = set()
        islands: Dict[int, Set[int]] = {}

        for start_idx in unmapped:
            if start_idx in seen:
                continue

            # Breadth-first flood fill restricted to unmapped atoms.
            seen.add(start_idx)
            pending = deque([start_idx])
            component: Set[int] = set()

            while pending:
                atom_idx = pending.popleft()
                component.add(atom_idx)

                neighbor_indices = (
                    neighbor.GetIdx()
                    for neighbor in mol.GetAtomWithIdx(atom_idx).GetNeighbors()
                )
                for neighbor_idx in neighbor_indices:
                    if neighbor_idx in seen or neighbor_idx not in unmapped:
                        continue
                    seen.add(neighbor_idx)
                    pending.append(neighbor_idx)

            islands[len(islands)] = component

        return islands

    def _select_preferred_mapping(
        self, possible_outcomes: Dict[str, List[InitializedSmirksPattern]]
    ) -> str:
        """
        Select the preferred mapping from a list of mappings.

        Args:
            possible_outcomes (Dict[str, List[InitializedSmirksPattern]]): Dictionary of possible mappings.

        Returns:
            str: The selected preferred mapping.
        """
        selected_mapping = ""
        max_num_mapped_product_atoms = 0
        highest_priority_class = (0, 0)
        for canonicalized_mapping, possible_mappings in possible_outcomes.items():
            for possible_mapping in possible_mappings:
                mapping_priority = possible_mapping.get("priority", (0, 0))

                if highest_priority_class[0] != 0 or mapping_priority[0] != 0:
                    if highest_priority_class[0] > mapping_priority[0]:
                        continue
                    if (
                        highest_priority_class[0] == mapping_priority[0]
                        and highest_priority_class[1] > mapping_priority[1]
                    ):
                        continue
                    if highest_priority_class[0] < mapping_priority[0]:
                        highest_priority_class = mapping_priority
                        selected_mapping = canonicalized_mapping
                        continue
                    if (
                        highest_priority_class[0] == mapping_priority[0]
                        and mapping_priority[1] > highest_priority_class[1]
                    ):
                        highest_priority_class = mapping_priority
                        selected_mapping = canonicalized_mapping
                        continue

                mapping_num_fragments = len(possible_mapping["reactants_smarts"])
                mapping_num_atoms = mapping_num_fragments
                for mol in possible_mapping["reactants_smarts"]:
                    mapping_num_atoms += len(mol.GetAtoms())
                for mol in possible_mapping["products_smarts"]:
                    mapping_num_atoms += len(mol.GetAtoms())
                if max_num_mapped_product_atoms < mapping_num_atoms:
                    selected_mapping = canonicalized_mapping
                    max_num_mapped_product_atoms = mapping_num_atoms

        return selected_mapping

    def _filter_and_deduplicate_outcomes(
        self,
        mapped_outcomes_smirks_dict: Dict[str, List[InitializedSmirksPattern]],
        canonicalized_input_smiles: str,
        original_smiles: str,
    ) -> Optional[ReactionMapperResult]:
        """
        Filter, validate, and deduplicate mapped reaction outcomes.

        Args:
            mapped_outcomes_smirks_dict (Dict[str, InitializedSmirksPattern]): Map of
                candidate mapped reaction SMIRKS/SMILES strings to their originating
                initialized pattern metadata.
            canonicalized_input_smiles (str): Canonicalized representation of the
                input reaction (used to discard outcomes that do not match).
            original_smiles (str): Original (non-canonicalized) input reaction string
                to return in the result.

        Returns:
            Optional[ReactionMapperResult]: A single selected mapping and associated
                candidate metadata if exactly one valid, matching outcome remains;
                otherwise, returns None.
        """
        deduplicated_mapped_outcomes: Dict[str, List[InitializedSmirksPattern]] = {}
        for (
            candidate_reaction_smiles,
            applied_patterns,
        ) in mapped_outcomes_smirks_dict.items():
            if not candidate_reaction_smiles:
                continue
            canonicalized_candidate_reaction_smiles = canonicalize_atom_mapping(
                canonicalize_reaction_smiles(
                    candidate_reaction_smiles,
                    canonicalize_tautomer=False,
                    remove_mapping=False,
                )
            )
            if not canonicalized_candidate_reaction_smiles:
                continue
            if (
                canonicalize_reaction_smiles(
                    canonicalized_candidate_reaction_smiles, canonicalize_tautomer=False
                )
                != canonicalized_input_smiles
            ):
                continue
            if not self._verify_validity_of_mapping(
                canonicalized_candidate_reaction_smiles
            ):
                continue

            if (
                canonicalized_candidate_reaction_smiles
                not in deduplicated_mapped_outcomes
            ):
                deduplicated_mapped_outcomes[
                    canonicalized_candidate_reaction_smiles
                ] = applied_patterns
            else:
                deduplicated_mapped_outcomes[
                    canonicalized_candidate_reaction_smiles
                ].extend(applied_patterns)

        if len(deduplicated_mapped_outcomes) == 0:
            return None
        if len(deduplicated_mapped_outcomes) > 1:
            logger.warning("Multiple possible mappings")
            selected_mapping = self._select_preferred_mapping(
                deduplicated_mapped_outcomes
            )
        else:
            selected_mapping = list(deduplicated_mapped_outcomes.keys())[0]

        return ReactionMapperResult(
            original_smiles=original_smiles,
            selected_mapping=selected_mapping,
            possible_mappings=deduplicated_mapped_outcomes,
            mapping_type=self._mapper_type,
            mapping_score=None,
            additional_info=[],
        )

    def map_reaction_with_mcs_optimization(
        self,
        reaction_smiles: str,
        apply_multiple_smirks: bool = True,
        num_smirks_to_apply: int = 2,
    ) -> Tuple[ReactionMapperResult, ReactionMapperResult]:
        """
        Map a reaction SMILES string using template-based atom mapping with optimization
        that uses MCS to identify probable reaction center.

        When an MCS mapper is configured, its selected mapping is used to find the
        islands of unmapped product atoms, which are then passed on to the
        template application step.

        Args:
            reaction_smiles (str): Reaction SMILES to map.
            apply_multiple_smirks (bool): Whether to apply multiple SMIRKS patterns to the same reaction.
            num_smirks_to_apply (int): Number of SMIRKS patterns to apply to the same reaction.

        Returns:
            Tuple[ReactionMapperResult, ReactionMapperResult]: A tuple containing the template-based
                mapping result and the MCS mapping result. Either element is the default
                (empty) result when the corresponding mapping is unavailable: both for an
                invalid input, the first when no template outcome survives filtering, and
                the second when no MCS mapper is configured.
        """
        self._initialize_smirks_patterns()

        if not self._reaction_smiles_valid(reaction_smiles):
            return (
                self._return_default_mapping_dict(reaction_smiles),
                self._return_default_mapping_dict(reaction_smiles),
            )

        canonicalized_reaction_smiles = canonicalize_reaction_smiles(
            reaction_smiles, canonicalize_tautomer=False
        )

        # Always bind mcs_result: without an MCS mapper the name would otherwise
        # be undefined at the return statements below.
        mcs_result = self._return_default_mapping_dict(reaction_smiles)
        unmapped_product_atom_islands = {}
        if self._mcs_mapper is not None:
            mcs_result = self._mcs_mapper.map_reaction(canonicalized_reaction_smiles)

            if mcs_result["selected_mapping"] != "":
                # Only the product side (after ">>") is inspected for islands.
                unmapped_product_atom_islands = self._get_unmapped_product_atom_islands(
                    mcs_result["selected_mapping"].split(">>")[1]
                )

        reactants_str, products_str = self._split_reaction_components(
            canonicalized_reaction_smiles
        )

        reaction_data = self._prepare_reaction_data(
            reactants_str, products_str, unmapped_product_atom_islands
        )

        mapped_outcomes_smirks_dict = self._apply_templates_and_collect_outcomes(
            reaction_data,
            apply_multiple_smirks=apply_multiple_smirks,
            num_smirks_to_apply=num_smirks_to_apply,
        )

        result = self._filter_and_deduplicate_outcomes(
            mapped_outcomes_smirks_dict, canonicalized_reaction_smiles, reaction_smiles
        )
        if not result:
            return self._return_default_mapping_dict(reaction_smiles), mcs_result

        return result, mcs_result

    def map_reaction(
        self,
        reaction_smiles: str,
        apply_multiple_smirks: bool = True,
        num_smirks_to_apply: int = 2,
    ) -> ReactionMapperResult:
        """
        Map one reaction SMILES string with the template-based mapper.

        Thin wrapper around ``map_reaction_with_mcs_optimization`` that keeps the
        template-based result and discards the accompanying MCS result.

        Args:
            reaction_smiles (str): Reaction SMILES to map.
            apply_multiple_smirks (bool): Whether to apply multiple SMIRKS patterns to the same reaction.
            num_smirks_to_apply (int): Number of SMIRKS patterns to apply to the same reaction.

        Returns:
            ReactionMapperResult: The first element of the tuple returned by
            ``map_reaction_with_mcs_optimization``.
        """
        template_result, _mcs_result = self.map_reaction_with_mcs_optimization(
            reaction_smiles, apply_multiple_smirks, num_smirks_to_apply
        )
        return template_result

    def map_reactions(
        self,
        reaction_list: List[str],
        apply_multiple_smirks: bool = True,
        num_smirks_to_apply: int = 2,
    ) -> List[ReactionMapperResult]:
        """
        Map each reaction SMILES string in a list by calling ``map_reaction``.

        Args:
            reaction_list (List[str]): Reaction SMILES strings to map.
            apply_multiple_smirks (bool): Whether to apply multiple SMIRKS patterns to the same reaction.
            num_smirks_to_apply (int): Number of SMIRKS patterns to apply to the same reaction.

        Returns:
            List[ReactionMapperResult]: One result per input reaction, in input order.
        """
        return [
            self.map_reaction(reaction, apply_multiple_smirks, num_smirks_to_apply)
            for reaction in reaction_list
        ]

__init__(mapper_name, mapper_weight=3, custom_smirks_patterns=None, use_default_smirks_patterns=True, max_transforms=1000, max_tautomers=1000, use_mcs_mapping=True)

Initialize the TemplateMapper instance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| custom_smirks_patterns | List[Dict] | A list of dictionaries containing custom SMIRKS patterns. Each dictionary should have a 'name' key, a 'smirks' key, and a 'superclass_id' key. | None |
| use_default_smirks_patterns | bool | Whether to use the default SMIRKS patterns. | True |
Source code in agave_chem/mappers/template/template_mapper.py
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def __init__(
    self,
    mapper_name: str,
    mapper_weight: float = 3,
    custom_smirks_patterns: List[SmirksPattern] | None = None,
    use_default_smirks_patterns: bool = True,
    max_transforms: int = 1000,
    max_tautomers: int = 1000,
    use_mcs_mapping: bool = True,
):
    """
    Initialize the template mapper instance.

    Args:
        mapper_name (str): Name of this mapper; forwarded to the base-class
            initializer.
        mapper_weight (float): Weight of this mapper; forwarded to the
            base-class initializer.
        custom_smirks_patterns (List[SmirksPattern] | None): Optional list of
            dictionaries containing custom SMIRKS patterns. Each dictionary
            must have exactly the keys 'name', 'smirks', and 'superclass_id'.
            'name' and 'smirks' must be strings; 'superclass_id' must be an
            integer or None.
        use_default_smirks_patterns (bool): Whether to use the default SMIRKS
            patterns.
        max_transforms (int): Value passed to ``SetMaxTransforms`` on the
            tautomer enumerator.
        max_tautomers (int): Value passed to ``SetMaxTautomers`` on the
            tautomer enumerator.
        use_mcs_mapping (bool): Whether to create an internal
            ``MCSReactionMapper``; when False, ``self._mcs_mapper`` stays None.

    Raises:
        TypeError: If ``custom_smirks_patterns`` is not a list, if an entry is
            not dictionary-like, if an entry's key set is wrong, or if a value
            has the wrong type.
    """

    super().__init__("template", mapper_name, mapper_weight)

    if custom_smirks_patterns is not None:
        if not isinstance(custom_smirks_patterns, list):
            raise TypeError(
                "Invalid input: custom_smirks_patterns must be a list of dictionaries."
            )
        required_keys = {"name", "smirks", "superclass_id"}
        for pattern in custom_smirks_patterns:
            # An entry without .keys() (e.g. a str or tuple) used to escape as
            # an AttributeError; report it as the TypeError the rest of this
            # validation uses.
            try:
                pattern_keys = set(pattern.keys())
            except AttributeError:
                raise TypeError(
                    "Invalid input: custom_smirks_patterns must be a list of dictionaries."
                ) from None
            if pattern_keys != required_keys:
                raise TypeError(
                    "Invalid input: each dictionary in custom_smirks_patterns must have 'name', 'smirks', and 'superclass_id' keys."
                )
            for key, value in pattern.items():
                if key == "superclass_id":
                    if value is not None and not isinstance(value, int):
                        raise TypeError(
                            "Invalid input: 'superclass_id' value must be an integer or None."
                        )
                else:
                    if not isinstance(value, str):
                        raise TypeError(
                            "Invalid input: 'name' and 'smirks' values must be strings."
                        )

    self._custom_smirks_patterns = custom_smirks_patterns
    self._use_default_smirks_patterns = use_default_smirks_patterns

    # The packaged pattern file is loaded eagerly, but only as raw JSON; the
    # initialized form starts as None and is filled in elsewhere.
    smirks_patterns_file = files("agave_chem.datafiles.smirks_patterns").joinpath(
        "smirks_patterns_with_children.json"
    )
    with smirks_patterns_file.open("r") as f:
        self._uninitialized_smirks_patterns = json.load(f)
    self._initialized_smirks_patterns: Optional[List[InitializedSmirksPattern]] = (
        None
    )

    self._tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
    self._tautomer_enumerator.SetMaxTransforms(max_transforms)
    self._tautomer_enumerator.SetMaxTautomers(max_tautomers)

    # The helper MCS mapper is optional; code that uses it must handle None.
    self._mcs_mapper = None
    if use_mcs_mapping:
        self._mcs_mapper = MCSReactionMapper(
            mapper_name="mcs_for_template", mapper_weight=1
        )

map_reaction(reaction_smiles, apply_multiple_smirks=True, num_smirks_to_apply=2)

Map a reaction SMILES string using template-based atom mapping.

This is a convenience method that calls map_reaction_with_mcs_optimization and returns only the main mapping result.

Parameters:

Name Type Description Default
reaction_smiles str

Reaction SMILES to map.

required
apply_multiple_smirks bool

Whether to apply multiple SMIRKS patterns to the same reaction.

True
num_smirks_to_apply int

Number of SMIRKS patterns to apply to the same reaction.

2

Returns:

Type Description
ReactionMapperResult

A mapping result containing the selected mapping and related metadata. If the input is invalid or no unique valid mapping can be produced, an "empty" default result is returned.

Source code in agave_chem/mappers/template/template_mapper.py
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
def map_reaction(
    self,
    reaction_smiles: str,
    apply_multiple_smirks: bool = True,
    num_smirks_to_apply: int = 2,
) -> ReactionMapperResult:
    """
    Map a single reaction SMILES string with template-based atom mapping.

    Thin wrapper around ``map_reaction_with_mcs_optimization`` that keeps the
    first element of the returned pair and discards the second.

    Args:
        reaction_smiles (str): Reaction SMILES to map.
        apply_multiple_smirks (bool): Whether to apply multiple SMIRKS
            patterns to the same reaction.
        num_smirks_to_apply (int): Number of SMIRKS patterns to apply to the
            same reaction.

    Returns:
        ReactionMapperResult: A mapping result containing the selected mapping
            and related metadata. If the input is invalid or no unique valid
            mapping can be produced, an "empty" default result is returned.
    """
    # Unpack (rather than index) so the callee's 2-tuple contract is enforced.
    template_result, _mcs_result = self.map_reaction_with_mcs_optimization(
        reaction_smiles,
        apply_multiple_smirks,
        num_smirks_to_apply,
    )
    return template_result

map_reaction_with_mcs_optimization(reaction_smiles, apply_multiple_smirks=True, num_smirks_to_apply=2)

Map a reaction SMILES string using template-based atom mapping, with an optimization that uses MCS to identify the probable reaction center.

Parameters:

Name Type Description Default
reaction_smiles str

Reaction SMILES to map.

required
apply_multiple_smirks bool

Whether to apply multiple SMIRKS patterns to the same reaction.

True
num_smirks_to_apply int

Number of SMIRKS patterns to apply to the same reaction.

2

Returns:

Type Description
Tuple[ReactionMapperResult, ReactionMapperResult]

Tuple[ReactionMapperResult, ReactionMapperResult]: A tuple containing the template-based mapping result and the MCS mapping result.

Source code in agave_chem/mappers/template/template_mapper.py
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
def map_reaction_with_mcs_optimization(
    self,
    reaction_smiles: str,
    apply_multiple_smirks: bool = True,
    num_smirks_to_apply: int = 2,
) -> Tuple[ReactionMapperResult, ReactionMapperResult]:
    """
    Map a reaction SMILES string using template-based atom mapping, with an
    optimization that uses MCS to identify the probable reaction center.

    When an MCS mapper is configured, its mapping of the canonicalized
    reaction is computed first and the product side of that mapping is used
    to derive the "unmapped product atom islands" passed on to the template
    stage. When no MCS mapper is configured, the template stage runs with no
    islands and the MCS element of the returned tuple is the default
    ("empty") mapping result.

    Args:
        reaction_smiles (str): Reaction SMILES to map.
        apply_multiple_smirks (bool): Whether to apply multiple SMIRKS
            patterns to the same reaction.
        num_smirks_to_apply (int): Number of SMIRKS patterns to apply to the
            same reaction.

    Returns:
        Tuple[ReactionMapperResult, ReactionMapperResult]: A tuple containing
            the template-based mapping result and the MCS mapping result.
            Both elements are default results if the input is invalid; the
            first is a default result if the template stage yields nothing.
    """
    self._initialize_smirks_patterns()

    if not self._reaction_smiles_valid(reaction_smiles):
        return (
            self._return_default_mapping_dict(reaction_smiles),
            self._return_default_mapping_dict(reaction_smiles),
        )

    canonicalized_reaction_smiles = canonicalize_reaction_smiles(
        reaction_smiles, canonicalize_tautomer=False
    )

    unmapped_product_atom_islands = {}
    if self._mcs_mapper is not None:
        mcs_result = self._mcs_mapper.map_reaction(canonicalized_reaction_smiles)

        # Only a non-empty MCS mapping has a product side to inspect.
        if mcs_result["selected_mapping"] != "":
            unmapped_product_atom_islands = self._get_unmapped_product_atom_islands(
                mcs_result["selected_mapping"].split(">>")[1]
            )
    else:
        # Without an MCS mapper (use_mcs_mapping=False) there is no MCS result;
        # previously this left `mcs_result` unbound and the returns below
        # raised UnboundLocalError. Fall back to the default result, matching
        # what the invalid-input branch returns.
        mcs_result = self._return_default_mapping_dict(reaction_smiles)

    reactants_str, products_str = self._split_reaction_components(
        canonicalized_reaction_smiles
    )

    reaction_data = self._prepare_reaction_data(
        reactants_str, products_str, unmapped_product_atom_islands
    )

    mapped_outcomes_smirks_dict = self._apply_templates_and_collect_outcomes(
        reaction_data,
        apply_multiple_smirks=apply_multiple_smirks,
        num_smirks_to_apply=num_smirks_to_apply,
    )

    result = self._filter_and_deduplicate_outcomes(
        mapped_outcomes_smirks_dict, canonicalized_reaction_smiles, reaction_smiles
    )
    if not result:
        return self._return_default_mapping_dict(reaction_smiles), mcs_result

    return result, mcs_result

map_reactions(reaction_list, apply_multiple_smirks=True, num_smirks_to_apply=2)

Map a list of reaction SMILES strings using this mapper.

Parameters:

Name Type Description Default
reaction_list List[str]

Reaction SMILES strings to map.

required
apply_multiple_smirks bool

Whether to apply multiple SMIRKS patterns to the same reaction.

True
num_smirks_to_apply int

Number of SMIRKS patterns to apply to the same reaction.

2

Returns:

Type Description
List[ReactionMapperResult]

List[ReactionMapperResult]: Mapping results in the same order as the input.

Source code in agave_chem/mappers/template/template_mapper.py
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
def map_reactions(
    self,
    reaction_list: List[str],
    apply_multiple_smirks: bool = True,
    num_smirks_to_apply: int = 2,
) -> List[ReactionMapperResult]:
    """
    Run ``map_reaction`` on every reaction SMILES string in a list.

    Args:
        reaction_list (List[str]): Reaction SMILES strings to map.
        apply_multiple_smirks (bool): Whether to apply multiple SMIRKS
            patterns to the same reaction.
        num_smirks_to_apply (int): Number of SMIRKS patterns to apply to the
            same reaction.

    Returns:
        List[ReactionMapperResult]: One mapping result per input reaction, in
            the same order as the input.
    """
    # Both options are forwarded positionally, unchanged, for every reaction.
    return [
        self.map_reaction(smiles, apply_multiple_smirks, num_smirks_to_apply)
        for smiles in reaction_list
    ]

map_reactions(reaction_list, mappers_list=[], mapping_selection_mode='weighted', batch_size=500)

Source code in agave_chem/main.py
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def map_reactions(
    reaction_list: List[str],
    mappers_list: List[ReactionMapper] = [],
    mapping_selection_mode: str = "weighted",
    batch_size: int = 500,
) -> List[AgaveChemMapperResult]:
    """
    Validate the inputs and map reactions with a set of reaction mappers.

    Args:
        reaction_list (List[str]): Reaction SMILES strings to map. A single
            string is accepted and treated as a one-element list. Duplicate
            entries are removed (with a warning), keeping the first
            occurrence of each reaction in its original position.
        mappers_list (List[ReactionMapper]): Mappers to run. If empty, an
            ``MCSReactionMapper`` named "mcs_default" and a
            ``TemplateReactionMapper`` named "expert_default" are used.
            Mapper names must be unique.
        mapping_selection_mode (str): A string or a callable. Only its type
            is checked in this function.
        batch_size (int): Batch size forwarded to
            ``map_reactions_using_mappers``; must be between 1 and 1000.

    Returns:
        List[AgaveChemMapperResult]: The results returned by
            ``map_reactions_using_mappers``.

    Raises:
        ValueError: If ``reaction_list`` is not a string or a non-empty list
            of strings, if ``mappers_list`` is not a list of uniquely named
            ``ReactionMapper`` instances, if ``mapping_selection_mode`` is
            neither a string nor callable, if ``batch_size`` is outside
            1-1000, or if mapping produced no results.
        TypeError: If ``batch_size`` is not an integer.
    """
    # The `[]` default is never mutated: an empty value is rebound to a fresh
    # list here, so sharing the default object across calls is harmless.
    if not mappers_list:
        mappers_list = [
            MCSReactionMapper("mcs_default"),
            TemplateReactionMapper("expert_default"),
        ]

    if isinstance(reaction_list, str):
        reaction_list = [reaction_list]

    if (
        not isinstance(reaction_list, list)
        or len(reaction_list) == 0
        or not all(isinstance(reaction, str) for reaction in reaction_list)
    ):
        raise ValueError(
            "Invalid input: reaction_list must be a string or a non-empty list of strings."
        )

    # dict.fromkeys de-duplicates while preserving first-seen order; the
    # previous list(set(...)) returned the reactions in arbitrary order.
    unique_reactions = list(dict.fromkeys(reaction_list))
    if len(unique_reactions) != len(reaction_list):
        logger.warning("Removing duplicate reactions from reaction_list.")
        reaction_list = unique_reactions

    if not isinstance(mappers_list, list) or len(mappers_list) == 0:
        raise ValueError(
            "Invalid input: mappers_list must be a non-empty list of ReactionMapper instances."
        )

    seen_mapper_names = set()
    for mapper in mappers_list:
        if not isinstance(mapper, ReactionMapper):
            raise ValueError(
                f"Invalid mapper: {mapper} is not an instance of ReactionMapper."
            )
        if mapper.mapper_name in seen_mapper_names:
            raise ValueError(f"Duplicate mapper name: {mapper.mapper_name}.")
        seen_mapper_names.add(mapper.mapper_name)

    # NOTE(review): mapping_selection_mode is validated here but is not passed
    # to map_reactions_using_mappers below — confirm whether that is intended.
    if not isinstance(mapping_selection_mode, str) and not callable(
        mapping_selection_mode
    ):
        raise ValueError(
            "Invalid input: mapping_selection_mode must be a string or function."
        )

    if not isinstance(batch_size, int):
        raise TypeError("Invalid input: batch_size must be an integer.")
    if batch_size <= 0 or batch_size > 1000:
        raise ValueError("Invalid input: batch_size must be an integer between 1-1000.")

    mapper_results = map_reactions_using_mappers(
        reaction_list, mappers_list, batch_size
    )

    if not mapper_results:
        # Previously this reused the batch_size message, which misreported
        # the cause; the exception type is unchanged for existing callers.
        raise ValueError(
            "Mapping failed: no mapping results were produced for the given reactions."
        )

    return mapper_results