Skip to content

API Reference

agave_chem

agave_chem initialization.

MCSReactionMapper

Bases: ReactionMapper

MCS reaction classification and atom-mapping

Source code in agave_chem/mappers/mcs/mcs_mapper.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
class MCSReactionMapper(ReactionMapper):
    """
    MCS reaction classification and atom-mapping.

    Atoms are matched between reactants and products by comparing encoded
    bond environments: the radius is grown until at most one environment
    pair matches, then map numbers are assigned from the largest useful
    radius down to radius 1, propagating assigned numbers as anchors.
    """

    def __init__(self, mapper_name: str, mapper_weight: float = 3):
        super().__init__("mcs", mapper_name, mapper_weight)

    def _get_atoms_in_radius(
        self, mol: Chem.Mol, atom: Chem.Atom, radius: int
    ) -> List[list]:
        """
        Encode the bond environment within a radius of a root atom.

        Args:
            mol (Chem.Mol): RDKit molecule containing the environment.
            atom (Chem.Atom): Root atom in `mol` to build the environment around.
            radius (int): Bond-radius to use when extracting the local environment.

        Returns:
            List[List[List[str | int]]]: A deterministic, sorted encoding of the
            radius-limited bond environment. Each element is a 3-item list of
            `[encoded_atom_begin, encoded_bond, encoded_atom_end]`. Returns an
            empty list when no bonds are found within the given radius.
        """
        bond_ids = Chem.rdmolops.FindAtomEnvironmentOfRadiusN(
            mol,
            radius=radius,
            rootedAtAtom=atom.GetIdx(),
        )

        if not bond_ids:
            return []

        amap: Dict[int, int] = {}
        submol = Chem.PathToSubmol(mol, bond_ids, atomMap=amap)

        root_parent = atom.GetIdx()
        root_sub = amap[root_parent]

        # Clear every map number and flag only the root so that canonical
        # ranking is rooted deterministically at the environment's center.
        for a in submol.GetAtoms():
            a.SetAtomMapNum(0)
        submol.GetAtomWithIdx(root_sub).SetAtomMapNum(1)

        ranks = list(
            Chem.CanonicalRankAtoms(submol, breakTies=True, includeAtomMaps=True)
        )

        inv_amap = {sub_i: parent_i for parent_i, sub_i in amap.items()}

        # Collect one sortable key per bond so the final encoding does not
        # depend on RDKit's bond enumeration order.
        bond_items = []
        for sb in submol.GetBonds():
            sa1, sa2 = sb.GetBeginAtomIdx(), sb.GetEndAtomIdx()
            r1, r2 = ranks[sa1], ranks[sa2]
            lo, hi = (r1, r2) if r1 <= r2 else (r2, r1)

            key = (
                lo,
                hi,
                int(sb.GetBondTypeAsDouble()),
                int(sb.GetIsAromatic()),
                int(sb.IsInRing()),
            )
            bond_items.append((key, sa1, sa2))

        bond_items.sort(key=lambda x: x[0])

        encoded_environment = []
        for _, sa1, sa2 in bond_items:
            # Orient every bond from the lower-ranked atom to the higher one.
            if ranks[sa1] <= ranks[sa2]:
                s_begin, s_end = sa1, sa2
            else:
                s_begin, s_end = sa2, sa1

            p_begin = inv_amap[s_begin]
            p_end = inv_amap[s_end]

            pbond = mol.GetBondBetweenAtoms(p_begin, p_end)
            pbidx = pbond.GetIdx()

            # Encode against the PARENT molecule so ring/aromatic/H-count
            # properties reflect the full molecule, not the clipped fragment.
            encoded_environment.append(
                [
                    self._encode_atom(mol, p_begin),
                    self._encode_bond(mol, pbidx),
                    self._encode_atom(mol, p_end),
                ]
            )

        return encoded_environment

    def _encode_atom(self, mol: Chem.Mol, idx: int) -> List[str | int]:
        """
        Encode an RDKit Atom object into a list of integers.

        The encoding is as follows:
        - encoding_type: Atom encoding.
        - z: The atomic number of the atom.
        - chg: The formal charge of the atom.
        - arom: 1 if the atom is aromatic, 0 otherwise.
        - ring: 1 if the atom is in a ring, 0 otherwise.
        - h: The total number of hydrogen atoms bonded to the atom.
        - d: The degree of the atom.
        - chiral_type: The chiral type of the atom.
        - atom_map_num: The atom map number of the atom.
        - idx: The index of the atom in the molecule.

        Args:
            mol (Chem.Mol): The RDKit Molecule object containing the atom.
            idx (int): The index of the atom in the molecule.

        Returns:
            List[int]: A list of integers encoding the atom.
        """
        a = mol.GetAtomWithIdx(idx)
        z = a.GetAtomicNum()
        chg = a.GetFormalCharge()
        arom = 1 if a.GetIsAromatic() else 0
        ring = 1 if a.IsInRing() else 0
        h = a.GetTotalNumHs()
        d = a.GetDegree()
        atom_map_num = a.GetAtomMapNum()
        chiral_type = int(a.GetChiralTag())
        # NOTE: `atom_map_num` and `idx` are deliberately the last two items;
        # _remove_atom_idx_for_comparison and _assign_atom_mapping rely on
        # the [-2]/[-1] positions.
        return [
            "atom_encoding",
            z,
            chg,
            arom,
            ring,
            h,
            d,
            chiral_type,
            atom_map_num,
            idx,
        ]

    def _encode_bond(self, mol: Chem.Mol, idx: int) -> List[str | int]:
        """
        Encode an RDKit Bond object into a list of integers.

        The encoding is as follows:
        - encoding_type: Bond encoding.
        - bond_order: The bond order of the bond multiplied by 10.
        - stereo: The stereochemistry of the bond.

        Args:
            mol (Chem.Mol): The RDKit Molecule object containing the bond.
            idx (int): The index of the bond in the molecule.

        Returns:
            List[int]: A list of integers encoding the bond.
        """
        b = mol.GetBondWithIdx(idx)
        stereo = b.GetStereo()
        # x10 keeps aromatic bond order (1.5) integral after truncation.
        return ["bond_encoding", int(b.GetBondTypeAsDouble() * 10), int(stereo)]

    def _assign_mapping(self, mol: Chem.Mol, atom_idx: int, atom_map_num: int) -> None:
        """Assign an atom-mapping number to a specific atom in a molecule.

        Args:
            mol (Chem.Mol): Molecule whose atom will be updated.
            atom_idx (int): Index of the atom in `mol` to update.
            atom_map_num (int): Atom-mapping number to assign to the atom.

        Returns:
            None: This method updates `mol` in place.
        """
        mol.GetAtomWithIdx(atom_idx).SetAtomMapNum(atom_map_num)

    def _remove_atom_idx_for_comparison(
        self, atom_map_list: List[List[List[int]]]
    ) -> List[List[List[int]]]:
        """Remove the trailing atom index from atom-encoding entries for equality checks.

        Args:
            atom_map_list (List[List[List[int]]]): Nested list structure describing a
                mapping in terms of atom/bond encodings, where sub-lists whose first
                element equals ``"atom_encoding"`` have a trailing atom index that
                should be ignored for comparison.

        Returns:
            List[List[List[int]]]: A new nested list with the last element removed
                from any ``["atom_encoding", ...]`` sub-list, leaving all other
                sub-lists unchanged.
        """
        new_list = []
        for init_list in atom_map_list:
            atom_bond_list = []
            for sub_list in init_list:
                if sub_list[0] == "atom_encoding":
                    atom_bond_list.append(sub_list[:-1])
                else:
                    atom_bond_list.append(sub_list)

            new_list.append(atom_bond_list)
        return new_list

    def _filter_encodings_for_radius(
        self,
        encoding_dict: Dict[str, List[List[List[int]]]],
        radius: int,
        min_radius_to_anchor_new_mapping: int,
    ) -> Dict[str, List[List[List[int]]]]:
        """Select comparable entries at `radius`, with atom indices stripped.

        Below `min_radius_to_anchor_new_mapping`, environments containing no
        previously mapped atom are skipped, so small radii can only extend
        existing mappings rather than anchor brand-new ones.

        Args:
            encoding_dict: Encodings keyed by ``"{mol}_{atom}_{radius}"``.
            radius (int): Only entries computed at this radius are kept.
            min_radius_to_anchor_new_mapping (int): Anchoring threshold.

        Returns:
            Dict[str, List[List[List[int]]]]: Filtered encodings suitable for
                direct equality comparison.
        """
        filtered = {}
        for key, encoding in encoding_dict.items():
            if str(radius) != key.split("_")[-1]:
                continue
            if not encoding:
                continue
            if radius < min_radius_to_anchor_new_mapping:
                mapping_nums = [
                    entry[-2]
                    for env in encoding
                    for entry in env
                    if entry[0] == "atom_encoding"
                ]
                # All map numbers zero => nothing already mapped to anchor on.
                if list(set(mapping_nums)) == [0]:
                    continue
            filtered[key] = self._remove_atom_idx_for_comparison(encoding)
        return filtered

    def _get_num_matching_atom_envs(
        self,
        reactant_encoding_dict: Dict[str, List[List[List[int]]]],
        product_encoding_dict: Dict[str, List[List[List[int]]]],
        radius: int,
        min_radius_to_anchor_new_mapping: int = 2,
    ) -> List[List[str]]:
        """Find reactant/product atom environments with identical encodings.

        Args:
            reactant_encoding_dict: Reactant encodings keyed by
                ``"{mol}_{atom}_{radius}"``.
            product_encoding_dict: Product encodings keyed the same way.
            radius (int): Only entries computed at this radius are compared.
            min_radius_to_anchor_new_mapping (int): Below this radius, only
                environments containing an already-mapped atom participate.

        Returns:
            List[List[str]]: ``[reactant_key, product_key]`` pairs whose
                encodings compare equal.
        """
        reactant_encoding_dict_no_idx = self._filter_encodings_for_radius(
            reactant_encoding_dict, radius, min_radius_to_anchor_new_mapping
        )
        product_encoding_dict_no_idx = self._filter_encodings_for_radius(
            product_encoding_dict, radius, min_radius_to_anchor_new_mapping
        )

        matches = []
        for k1, v1 in reactant_encoding_dict_no_idx.items():
            for k2, v2 in product_encoding_dict_no_idx.items():
                if v1 == v2:
                    matches.append([k1, k2])

        return matches

    def _propagate_mapping(
        self,
        encoding_dict: Dict[str, List[List[List[int]]]],
        mol_idx: int,
        atom_idx: int,
        atom_map_num: int,
    ) -> None:
        """Write `atom_map_num` into every encoding referencing the atom.

        Also deletes the mapped atom's own environment keys so it can no
        longer anchor further matches. Mutates `encoding_dict` in place.

        Args:
            encoding_dict: Encodings keyed by ``"{mol}_{atom}_{radius}"``.
            mol_idx (int): Molecule index of the newly mapped atom.
            atom_idx (int): Atom index of the newly mapped atom.
            atom_map_num (int): Map number that was just assigned.
        """
        keys_to_delete = []
        for key, encoding in encoding_dict.items():
            key_parts = key.split("_")
            if int(key_parts[0]) != mol_idx:
                continue
            if int(key_parts[1]) == atom_idx:
                keys_to_delete.append(key)
            for env in encoding:
                for entry in env:
                    # entry[-1] is the atom index, entry[-2] the map number
                    # (see _encode_atom).
                    if entry[0] == "atom_encoding" and entry[-1] == atom_idx:
                        entry[-2] = atom_map_num
        for key in keys_to_delete:
            del encoding_dict[key]

    def _assign_atom_mapping(
        self,
        matches: List[List[str]],
        reactants_mols: List[Chem.Mol],
        products_mols: List[Chem.Mol],
        reactant_encoding_dict: Dict[str, List[List[List[int]]]],
        product_encoding_dict: Dict[str, List[List[List[int]]]],
        atom_map_num: int,
    ) -> Tuple[int, Dict[str, List[List[List[int]]]], Dict[str, List[List[List[int]]]]]:
        """Apply the first match as an atom mapping and update bookkeeping.

        Updates the matched reactant/product atoms in place, rewrites the new
        map number into every encoding referencing them, and removes their own
        environment entries. If either atom is already mapped, nothing is
        changed.

        Args:
            matches: Match list; only ``matches[0]`` (a
                ``[reactant_key, product_key]`` pair) is consumed.
            reactants_mols: Reactant molecules (mutated in place).
            products_mols: Product molecules (mutated in place).
            reactant_encoding_dict: Reactant encodings (mutated in place).
            product_encoding_dict: Product encodings (mutated in place).
            atom_map_num (int): Map number to assign next.

        Returns:
            Tuple[int, dict, dict]: The next unused map number and the two
                (updated) encoding dictionaries.
        """
        reactant_mol_idx = int(matches[0][0].split("_")[0])
        reactant_atom_idx = int(matches[0][0].split("_")[1])
        product_mol_idx = int(matches[0][1].split("_")[0])
        product_atom_idx = int(matches[0][1].split("_")[1])

        # BUG FIX: check BOTH atoms before assigning either. The original
        # assigned the reactant atom first and returned early when the
        # product atom turned out to be mapped already, leaving a one-sided
        # mapping whose number would then be reused.
        if (
            reactants_mols[reactant_mol_idx]
            .GetAtomWithIdx(reactant_atom_idx)
            .GetAtomMapNum()
            != 0
            or products_mols[product_mol_idx]
            .GetAtomWithIdx(product_atom_idx)
            .GetAtomMapNum()
            != 0
        ):
            return atom_map_num, reactant_encoding_dict, product_encoding_dict

        self._assign_mapping(
            reactants_mols[reactant_mol_idx], reactant_atom_idx, atom_map_num
        )
        self._assign_mapping(
            products_mols[product_mol_idx], product_atom_idx, atom_map_num
        )

        self._propagate_mapping(
            reactant_encoding_dict, reactant_mol_idx, reactant_atom_idx, atom_map_num
        )
        self._propagate_mapping(
            product_encoding_dict, product_mol_idx, product_atom_idx, atom_map_num
        )

        return atom_map_num + 1, reactant_encoding_dict, product_encoding_dict

    def map_reaction(
        self,
        reaction_smiles: str,
        min_radius: int = 1,
        min_radius_to_anchor_new_mapping: int = 3,
    ) -> ReactionMapperResult:
        """Atom-map a reaction SMILES by matching encoded atom environments.

        Environments are grown radius-by-radius until at most one
        reactant/product pair matches; mappings are then assigned from the
        largest useful radius down to radius 1.

        Args:
            reaction_smiles (str): Reaction SMILES to map.
            min_radius (int): Smallest environment radius to consider.
            min_radius_to_anchor_new_mapping (int): Below this radius, matches
                may only extend existing mappings, not anchor new ones.

        Returns:
            ReactionMapperResult: The mapping result; a default (empty) result
                is returned when `reaction_smiles` is invalid.
        """
        default_mapping_dict = ReactionMapperResult(
            original_smiles="",
            selected_mapping="",
            possible_mappings={},
            mapping_type=self._mapper_type,
            mapping_score=None,
            additional_info=[{}],
        )
        if not self._reaction_smiles_valid(reaction_smiles):
            return default_mapping_dict

        canonicalized_reaction_smiles = canonicalize_reaction_smiles(
            reaction_smiles, canonicalize_tautomer=False
        )
        reactants, products = self._split_reaction_components(
            canonicalized_reaction_smiles
        )

        reactant_mols = [Chem.MolFromSmiles(r) for r in reactants.split(".")]
        product_mols = [Chem.MolFromSmiles(p) for p in products.split(".")]

        largest_num_atoms = max(
            mol.GetNumAtoms() for mol in reactant_mols + product_mols
        )

        reactant_encoding_dict = {}
        product_encoding_dict = {}

        for radius in range(min_radius, largest_num_atoms):
            for i, reactant_mol in enumerate(reactant_mols):
                if reactant_mol.GetNumAtoms() < radius:
                    continue
                for atom in reactant_mol.GetAtoms():
                    reactant_encoding_dict[
                        str(i) + "_" + str(atom.GetIdx()) + "_" + str(radius)
                    ] = self._get_atoms_in_radius(reactant_mol, atom, radius)

            for i, product_mol in enumerate(product_mols):
                if product_mol.GetNumAtoms() < radius:
                    continue
                for atom in product_mol.GetAtoms():
                    product_encoding_dict[
                        str(i) + "_" + str(atom.GetIdx()) + "_" + str(radius)
                    ] = self._get_atoms_in_radius(product_mol, atom, radius)

            # For the purpose of finding final radius and populating encoding
            # dictionaries, we don't care if radius < min_radius_to_anchor_new_mapping
            matches = self._get_num_matching_atom_envs(
                reactant_encoding_dict,
                product_encoding_dict,
                radius,
                min_radius_to_anchor_new_mapping=0,
            )
            if len(matches) == 1:
                final_radius = radius
                break
            if len(matches) == 0:
                final_radius = radius - 1
                break
        else:
            # BUG FIX: the original left `final_radius` unbound (NameError)
            # when every radius produced multiple matches or the radius range
            # was empty; fall back to the largest radius considered.
            final_radius = largest_num_atoms - 1

        atom_map_num = 1
        for radius in range(final_radius, 0, -1):
            matches = self._get_num_matching_atom_envs(
                reactant_encoding_dict,
                product_encoding_dict,
                radius,
            )
            num_matches = len(matches)
            # Re-match after each assignment: propagated map numbers change
            # the encodings and may invalidate the remaining matches.
            for _ in range(num_matches):
                atom_map_num, reactant_encoding_dict, product_encoding_dict = (
                    self._assign_atom_mapping(
                        [matches[0]],
                        reactant_mols,
                        product_mols,
                        reactant_encoding_dict,
                        product_encoding_dict,
                        atom_map_num,
                    )
                )
                matches = self._get_num_matching_atom_envs(
                    reactant_encoding_dict,
                    product_encoding_dict,
                    radius,
                )
                if len(matches) == 0:
                    break

        # canonical=False preserves atom order so map numbers stay attached
        # to the atoms we just annotated.
        mapped_reactant_smiles = ".".join(
            Chem.MolToSmiles(mol, isomericSmiles=True, canonical=False)
            for mol in reactant_mols
        )
        mapped_product_smiles = ".".join(
            Chem.MolToSmiles(mol, isomericSmiles=True, canonical=False)
            for mol in product_mols
        )
        mapped_reaction_smiles = mapped_reactant_smiles + ">>" + mapped_product_smiles

        return ReactionMapperResult(
            original_smiles=reaction_smiles,
            selected_mapping=mapped_reaction_smiles,
            possible_mappings={},
            mapping_type=self._mapper_type,
            mapping_score=None,
            additional_info=[{}],
        )

    def map_reactions(self, reaction_list: List[str]) -> List[ReactionMapperResult]:
        """
        Map a list of reaction SMILES strings using the MCS mapper.

        Args:
            reaction_list (List[str]): List of reaction SMILES strings to map.

        Returns:
            List[ReactionMapperResult]: The mapping results in the same order as the
                input reactions.
        """
        return [self.map_reaction(reaction) for reaction in reaction_list]

map_reaction(reaction_smiles, min_radius=1, min_radius_to_anchor_new_mapping=3)

Source code in agave_chem/mappers/mcs/mcs_mapper.py
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
def map_reaction(
    self,
    reaction_smiles: str,
    min_radius: int = 1,
    min_radius_to_anchor_new_mapping: int = 3,
) -> ReactionMapperResult:
    """Atom-map a reaction SMILES by matching encoded atom environments.

    Environments are grown radius-by-radius until at most one
    reactant/product pair matches; mappings are then assigned from the
    largest useful radius down to radius 1.

    Args:
        reaction_smiles (str): Reaction SMILES to map.
        min_radius (int): Smallest environment radius to consider.
        min_radius_to_anchor_new_mapping (int): Below this radius, matches
            may only extend existing mappings, not anchor new ones.

    Returns:
        ReactionMapperResult: The mapping result; a default (empty) result
            is returned when `reaction_smiles` is invalid.
    """
    default_mapping_dict = ReactionMapperResult(
        original_smiles="",
        selected_mapping="",
        possible_mappings={},
        mapping_type=self._mapper_type,
        mapping_score=None,
        additional_info=[{}],
    )
    if not self._reaction_smiles_valid(reaction_smiles):
        return default_mapping_dict

    canonicalized_reaction_smiles = canonicalize_reaction_smiles(
        reaction_smiles, canonicalize_tautomer=False
    )
    reactants, products = self._split_reaction_components(
        canonicalized_reaction_smiles
    )

    reactant_mols = [Chem.MolFromSmiles(r) for r in reactants.split(".")]
    product_mols = [Chem.MolFromSmiles(p) for p in products.split(".")]

    largest_num_atoms = max(
        mol.GetNumAtoms() for mol in reactant_mols + product_mols
    )

    reactant_encoding_dict = {}
    product_encoding_dict = {}

    for radius in range(min_radius, largest_num_atoms):
        for i, reactant_mol in enumerate(reactant_mols):
            if reactant_mol.GetNumAtoms() < radius:
                continue
            for atom in reactant_mol.GetAtoms():
                reactant_encoding_dict[
                    str(i) + "_" + str(atom.GetIdx()) + "_" + str(radius)
                ] = self._get_atoms_in_radius(reactant_mol, atom, radius)

        for i, product_mol in enumerate(product_mols):
            if product_mol.GetNumAtoms() < radius:
                continue
            for atom in product_mol.GetAtoms():
                product_encoding_dict[
                    str(i) + "_" + str(atom.GetIdx()) + "_" + str(radius)
                ] = self._get_atoms_in_radius(product_mol, atom, radius)

        # For the purpose of finding final radius and populating encoding
        # dictionaries, we don't care if radius < min_radius_to_anchor_new_mapping
        matches = self._get_num_matching_atom_envs(
            reactant_encoding_dict,
            product_encoding_dict,
            radius,
            min_radius_to_anchor_new_mapping=0,
        )
        if len(matches) == 1:
            final_radius = radius
            break
        if len(matches) == 0:
            final_radius = radius - 1
            break
    else:
        # BUG FIX: the original left `final_radius` unbound (NameError) when
        # every radius produced multiple matches or the radius range was
        # empty; fall back to the largest radius considered.
        final_radius = largest_num_atoms - 1

    atom_map_num = 1
    for radius in range(final_radius, 0, -1):
        matches = self._get_num_matching_atom_envs(
            reactant_encoding_dict,
            product_encoding_dict,
            radius,
        )
        num_matches = len(matches)
        # Re-match after each assignment: propagated map numbers change the
        # encodings and may invalidate the remaining matches.
        for _ in range(num_matches):
            atom_map_num, reactant_encoding_dict, product_encoding_dict = (
                self._assign_atom_mapping(
                    [matches[0]],
                    reactant_mols,
                    product_mols,
                    reactant_encoding_dict,
                    product_encoding_dict,
                    atom_map_num,
                )
            )
            matches = self._get_num_matching_atom_envs(
                reactant_encoding_dict,
                product_encoding_dict,
                radius,
            )
            if len(matches) == 0:
                break

    # canonical=False preserves atom order so map numbers stay attached to
    # the atoms we just annotated.
    mapped_reactant_smiles = ".".join(
        Chem.MolToSmiles(mol, isomericSmiles=True, canonical=False)
        for mol in reactant_mols
    )
    mapped_product_smiles = ".".join(
        Chem.MolToSmiles(mol, isomericSmiles=True, canonical=False)
        for mol in product_mols
    )
    mapped_reaction_smiles = mapped_reactant_smiles + ">>" + mapped_product_smiles

    return ReactionMapperResult(
        original_smiles=reaction_smiles,
        selected_mapping=mapped_reaction_smiles,
        possible_mappings={},
        mapping_type=self._mapper_type,
        mapping_score=None,
        additional_info=[{}],
    )

map_reactions(reaction_list)

Map a list of reaction SMILES strings using the MCS mapper.

Parameters:

Name Type Description Default
reaction_list List[str]

List of reaction SMILES strings to map.

required

Returns:

Type Description
List[ReactionMapperResult]

List[ReactionMapperResult]: The mapping results in the same order as the input reactions.

Source code in agave_chem/mappers/mcs/mcs_mapper.py
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
def map_reactions(self, reaction_list: List[str]) -> List[ReactionMapperResult]:
    """
    Map a list of reaction SMILES strings using the MCS mapper.

    Each reaction is mapped independently via `map_reaction`.

    Args:
        reaction_list (List[str]): List of reaction SMILES strings to map.

    Returns:
        List[ReactionMapperResult]: The mapping results in the same order as the
            input reactions.
    """
    return [self.map_reaction(reaction) for reaction in reaction_list]

MappingScorer

Comprehensive scorer for atom-to-atom mappings.

This class computes all the metrics used to evaluate and compare different mapping solutions, following the approach in RDTool.

Metrics include:

- Bond energy cost (formation/breaking)
- Number of bond changes
- Number of fragments affected
- Stereochemistry changes
- Ring opening/closing events
- Overall similarity score

Source code in agave_chem/scoring.py
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
class MappingScorer:
    """
    Comprehensive scorer for atom-to-atom mappings.

    This class computes all the metrics used to evaluate and compare
    different mapping solutions, following the approach in RDTool.

    Metrics include:
    - Bond energy cost (formation/breaking)
    - Number of bond changes
    - Number of fragments affected
    - Stereochemistry changes
    - Ring opening/closing events
    - Overall similarity score
    """

    def __init__(
        self,
        energy_penalty_weight: float = 1.0,
        bond_change_weight: float = 10.0,
        fragment_weight: float = 20.0,
        stereo_weight: float = 15.0,
        ring_weight: float = 25.0,
    ):
        """
        Initialize the scorer with custom weights.

        Args:
            energy_penalty_weight: Weight for bond energy cost
            bond_change_weight: Weight for number of bond changes
            fragment_weight: Weight for fragment changes
            stereo_weight: Weight for stereo changes
            ring_weight: Weight for ring changes
        """
        # NOTE(review): these weights are stored but never applied inside this
        # class; presumably consumers combine MappingScore fields with them -
        # confirm against callers before removing.
        self.weights = {
            "bond_energy_cost": energy_penalty_weight,
            "num_bond_changes": bond_change_weight,
            "num_fragments": fragment_weight,
            "stereo_changes": stereo_weight,
            "ring_changes": ring_weight,
        }

    def score_mapping(
        self,
        reactants: List[Chem.Mol],
        products: List[Chem.Mol],
        mapping: FrozenSet[AtomMapping],
        bond_changes: Optional[List[BondChange]] = None,
    ) -> MappingScore:
        """
        Compute comprehensive score for a mapping.

        Args:
            reactants: List of reactant molecules
            products: List of product molecules
            mapping: Set of atom mappings
            bond_changes: Pre-computed bond changes (optional). When omitted,
                they are derived from `mapping` via `compute_bond_changes`.

        Returns:
            MappingScore object with all metrics
        """
        if bond_changes is None:
            bond_changes = self.compute_bond_changes(reactants, products, mapping)

        # Count bond changes by type
        num_formed = sum(
            1 for bc in bond_changes if bc.change_type == BondChangeType.FORMED
        )
        num_broken = sum(
            1 for bc in bond_changes if bc.change_type == BondChangeType.BROKEN
        )
        num_order_changes = sum(
            1 for bc in bond_changes if bc.change_type == BondChangeType.ORDER_CHANGE
        )

        # Total energy cost over every change, including order changes.
        energy_cost = sum(bc.energy_cost for bc in bond_changes)

        # Fragment-count difference between the two sides.
        num_fragments = self._count_fragment_changes(reactants, products, mapping)

        # Mapped atoms whose chiral tag differs across the reaction.
        stereo_changes = self._count_stereo_changes(reactants, products, mapping)

        # Ring opening/closing events implied by the bond changes.
        ring_changes = self._count_ring_changes(
            reactants, products, mapping, bond_changes
        )

        # Fraction of atoms covered by the mapping.
        similarity = self._calculate_similarity(reactants, products, mapping)

        return MappingScore(
            bond_energy_cost=energy_cost,
            num_bond_changes=num_formed + num_broken + num_order_changes,
            num_bonds_formed=num_formed,
            num_bonds_broken=num_broken,
            num_fragments=num_fragments,
            stereo_changes=stereo_changes,
            similarity_score=similarity,
            ring_changes=ring_changes,
        )

    @staticmethod
    def _collect_mapped_bonds(
        mols: List[Chem.Mol],
        atom_to_map: Dict[Tuple[int, int], int],
    ) -> Dict[FrozenSet[int], Tuple[float, str, str]]:
        """
        Collect every bond whose two endpoints are mapped atoms.

        Args:
            mols: Molecules on one side of the reaction.
            atom_to_map: (mol_idx, atom_idx) -> temporary map number.

        Returns:
            Dict keyed by frozenset({map1, map2}) with values
            (bond order, begin-atom symbol, end-atom symbol).
        """
        bonds: Dict[FrozenSet[int], Tuple[float, str, str]] = {}
        for mol_idx, mol in enumerate(mols):
            for bond in mol.GetBonds():
                begin_idx = bond.GetBeginAtomIdx()
                end_idx = bond.GetEndAtomIdx()
                begin_key = (mol_idx, begin_idx)
                end_key = (mol_idx, end_idx)

                # Bonds touching unmapped atoms cannot be compared across
                # sides and are skipped.
                if begin_key in atom_to_map and end_key in atom_to_map:
                    pair = frozenset(
                        (atom_to_map[begin_key], atom_to_map[end_key])
                    )
                    bonds[pair] = (
                        bond.GetBondTypeAsDouble(),
                        mol.GetAtomWithIdx(begin_idx).GetSymbol(),
                        mol.GetAtomWithIdx(end_idx).GetSymbol(),
                    )
        return bonds

    def compute_bond_changes(
        self,
        reactants: List[Chem.Mol],
        products: List[Chem.Mol],
        mapping: FrozenSet[AtomMapping],
    ) -> List[BondChange]:
        """
        Compute all bond changes in the reaction.

        Args:
            reactants: List of reactant molecules
            products: List of product molecules
            mapping: Set of atom mappings

        Returns:
            List of BondChange objects (broken first, then formed, then
            order changes).
        """
        changes: List[BondChange] = []

        # Assign a shared temporary map number to each mapped atom pair.
        # NOTE: the numbers follow frozenset enumeration order; the same
        # enumeration is repeated in _count_ring_changes, which is what keeps
        # the map numbers consistent between the two methods.
        reactant_to_map: Dict[Tuple[int, int], int] = {}
        product_to_map: Dict[Tuple[int, int], int] = {}
        for map_num, am in enumerate(mapping, start=1):
            reactant_to_map[(am.reactant_mol_idx, am.reactant_atom_idx)] = map_num
            product_to_map[(am.product_mol_idx, am.product_atom_idx)] = map_num

        reactant_bonds = self._collect_mapped_bonds(reactants, reactant_to_map)
        product_bonds = self._collect_mapped_bonds(products, product_to_map)

        # Broken bonds: present in reactants but not products.
        for bond_key, (order, sym1, sym2) in reactant_bonds.items():
            if bond_key not in product_bonds:
                map1, map2 = sorted(bond_key)
                changes.append(
                    BondChange(
                        atom1_map=map1,
                        atom2_map=map2,
                        change_type=BondChangeType.BROKEN,
                        old_order=order,
                        new_order=None,
                        energy_cost=get_bond_energy(sym1, sym2, order),
                    )
                )

        # Formed bonds: present in products but not reactants.
        for bond_key, (order, sym1, sym2) in product_bonds.items():
            if bond_key not in reactant_bonds:
                map1, map2 = sorted(bond_key)
                changes.append(
                    BondChange(
                        atom1_map=map1,
                        atom2_map=map2,
                        change_type=BondChangeType.FORMED,
                        old_order=None,
                        new_order=order,
                        energy_cost=get_bond_energy(sym1, sym2, order),
                    )
                )

        # Order changes: present on both sides with a different bond order.
        for bond_key in reactant_bonds.keys() & product_bonds.keys():
            r_order, r_sym1, r_sym2 = reactant_bonds[bond_key]
            p_order, _, _ = product_bonds[bond_key]

            # 0.1 tolerance absorbs aromatic (1.5) vs. float rounding noise.
            if abs(r_order - p_order) > 0.1:
                map1, map2 = sorted(bond_key)
                old_energy = get_bond_energy(r_sym1, r_sym2, r_order)
                new_energy = get_bond_energy(r_sym1, r_sym2, p_order)
                changes.append(
                    BondChange(
                        atom1_map=map1,
                        atom2_map=map2,
                        change_type=BondChangeType.ORDER_CHANGE,
                        old_order=r_order,
                        new_order=p_order,
                        energy_cost=abs(new_energy - old_energy),
                    )
                )

        return changes

    def _count_fragment_changes(
        self,
        reactants: List[Chem.Mol],
        products: List[Chem.Mol],
        mapping: FrozenSet[AtomMapping],
    ) -> int:
        """Count the number of molecular fragments that change.

        Simple heuristic: the absolute difference in molecule count between
        the two sides. `mapping` is accepted for interface symmetry but not
        consulted.
        """
        return abs(len(reactants) - len(products))

    def _count_stereo_changes(
        self,
        reactants: List[Chem.Mol],
        products: List[Chem.Mol],
        mapping: FrozenSet[AtomMapping],
    ) -> int:
        """Count stereochemistry changes in the reaction.

        Compares the chiral tag of each mapped atom on the reactant side with
        its counterpart on the product side. Bond (cis/trans) stereochemistry
        is not inspected here.
        """
        changes = 0

        for am in mapping:
            r_mol = reactants[am.reactant_mol_idx]
            p_mol = products[am.product_mol_idx]

            r_atom = r_mol.GetAtomWithIdx(am.reactant_atom_idx)
            p_atom = p_mol.GetAtomWithIdx(am.product_atom_idx)

            if r_atom.GetChiralTag() != p_atom.GetChiralTag():
                changes += 1

        return changes

    def _count_ring_changes(
        self,
        reactants: List[Chem.Mol],
        products: List[Chem.Mol],
        mapping: FrozenSet[AtomMapping],
        bond_changes: List[BondChange],
    ) -> int:
        """Count ring opening and closing events.

        A BROKEN bond between two reactant ring atoms of the same molecule
        counts as a ring opening; a FORMED bond between two such atoms counts
        as a ring closing. ORDER_CHANGE events are ignored.
        """
        changes = 0

        # Rebuild map number -> reactant atom with the SAME enumeration order
        # used by compute_bond_changes, so the map numbers in bond_changes
        # line up. This only holds when bond_changes came from
        # compute_bond_changes with the same `mapping` object.
        map_to_reactant: Dict[int, Tuple[int, int]] = {
            map_num: (am.reactant_mol_idx, am.reactant_atom_idx)
            for map_num, am in enumerate(mapping, start=1)
        }

        # NOTE(review): the ring lists built from get_ring_info in the
        # original implementation were never consulted; the calls are kept in
        # case get_ring_info forces ring perception needed by Atom.IsInRing -
        # TODO confirm and drop if unnecessary.
        for mol in reactants:
            get_ring_info(mol)
        for mol in products:
            get_ring_info(mol)

        for bc in bond_changes:
            r_info1 = map_to_reactant.get(bc.atom1_map)
            r_info2 = map_to_reactant.get(bc.atom2_map)
            if r_info1 is None or r_info2 is None:
                continue

            r_mol_idx1, r_atom_idx1 = r_info1
            r_mol_idx2, r_atom_idx2 = r_info2

            # Only bonds within a single reactant molecule can open a ring.
            if r_mol_idx1 != r_mol_idx2:
                continue

            mol = reactants[r_mol_idx1]
            atom1 = mol.GetAtomWithIdx(r_atom_idx1)
            atom2 = mol.GetAtomWithIdx(r_atom_idx2)

            if atom1.IsInRing() and atom2.IsInRing():
                if bc.change_type == BondChangeType.BROKEN:
                    changes += 1  # Ring opening
                elif bc.change_type == BondChangeType.FORMED:
                    changes += 1  # Ring closing

        return changes

    def _calculate_similarity(
        self,
        reactants: List[Chem.Mol],
        products: List[Chem.Mol],
        mapping: FrozenSet[AtomMapping],
    ) -> float:
        """
        Calculate overall similarity based on the mapping.

        Returns the mean of reactant-side and product-side atom coverage
        (fraction of atoms that participate in the mapping), or 0.0 when
        either side has no atoms.
        """
        total_reactant_atoms = sum(mol.GetNumAtoms() for mol in reactants)
        total_product_atoms = sum(mol.GetNumAtoms() for mol in products)

        if total_reactant_atoms == 0 or total_product_atoms == 0:
            return 0.0

        mapped_atoms = len(mapping)

        reactant_coverage = mapped_atoms / total_reactant_atoms
        product_coverage = mapped_atoms / total_product_atoms

        return (reactant_coverage + product_coverage) / 2

__init__(energy_penalty_weight=1.0, bond_change_weight=10.0, fragment_weight=20.0, stereo_weight=15.0, ring_weight=25.0)

Initialize the scorer with custom weights.

Parameters:

Name Type Description Default
energy_penalty_weight float

Weight for bond energy cost

1.0
bond_change_weight float

Weight for number of bond changes

10.0
fragment_weight float

Weight for fragment changes

20.0
stereo_weight float

Weight for stereo changes

15.0
ring_weight float

Weight for ring changes

25.0
Source code in agave_chem/scoring.py
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def __init__(
    self,
    energy_penalty_weight: float = 1.0,
    bond_change_weight: float = 10.0,
    fragment_weight: float = 20.0,
    stereo_weight: float = 15.0,
    ring_weight: float = 25.0,
):
    """
    Initialize the scorer with custom weights.

    Args:
        energy_penalty_weight: Weight for bond energy cost
        bond_change_weight: Weight for number of bond changes
        fragment_weight: Weight for fragment changes
        stereo_weight: Weight for stereo changes
        ring_weight: Weight for ring changes
    """
    self.weights = {
        "bond_energy_cost": energy_penalty_weight,
        "num_bond_changes": bond_change_weight,
        "num_fragments": fragment_weight,
        "stereo_changes": stereo_weight,
        "ring_changes": ring_weight,
    }

compute_bond_changes(reactants, products, mapping)

Compute all bond changes in the reaction.

Parameters:

Name Type Description Default
reactants List[Mol]

List of reactant molecules

required
products List[Mol]

List of product molecules

required
mapping FrozenSet[AtomMapping]

Set of atom mappings

required

Returns:

Type Description
List[BondChange]

List of BondChange objects

Source code in agave_chem/scoring.py
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
def compute_bond_changes(
    self,
    reactants: List[Chem.Mol],
    products: List[Chem.Mol],
    mapping: FrozenSet[AtomMapping],
) -> List[BondChange]:
    """
    Compute all bond changes in the reaction.

    Args:
        reactants: List of reactant molecules
        products: List of product molecules
        mapping: Set of atom mappings

    Returns:
        List of BondChange objects
    """
    # Give each mapped atom pair one shared temporary map number so bonds on
    # the two sides can be compared by map-number pair.
    reactant_to_map: Dict[Tuple[int, int], int] = {}
    product_to_map: Dict[Tuple[int, int], int] = {}
    for map_num, am in enumerate(mapping, start=1):
        reactant_to_map[(am.reactant_mol_idx, am.reactant_atom_idx)] = map_num
        product_to_map[(am.product_mol_idx, am.product_atom_idx)] = map_num

    def collect(
        mols: List[Chem.Mol], atom_to_map: Dict[Tuple[int, int], int]
    ) -> Dict[FrozenSet[int], Tuple[float, str, str]]:
        """frozenset({map1, map2}) -> (order, begin symbol, end symbol)."""
        table: Dict[FrozenSet[int], Tuple[float, str, str]] = {}
        for mol_idx, mol in enumerate(mols):
            for bond in mol.GetBonds():
                begin_idx = bond.GetBeginAtomIdx()
                end_idx = bond.GetEndAtomIdx()
                begin_key = (mol_idx, begin_idx)
                end_key = (mol_idx, end_idx)
                if begin_key not in atom_to_map or end_key not in atom_to_map:
                    continue  # bond touches an unmapped atom
                pair = frozenset((atom_to_map[begin_key], atom_to_map[end_key]))
                table[pair] = (
                    bond.GetBondTypeAsDouble(),
                    mol.GetAtomWithIdx(begin_idx).GetSymbol(),
                    mol.GetAtomWithIdx(end_idx).GetSymbol(),
                )
        return table

    reactant_bonds = collect(reactants, reactant_to_map)
    product_bonds = collect(products, product_to_map)

    changes: List[BondChange] = []

    # Bonds present only on the reactant side were broken.
    for pair, (order, sym1, sym2) in reactant_bonds.items():
        if pair in product_bonds:
            continue
        map1, map2 = sorted(pair)
        changes.append(
            BondChange(
                atom1_map=map1,
                atom2_map=map2,
                change_type=BondChangeType.BROKEN,
                old_order=order,
                new_order=None,
                energy_cost=get_bond_energy(sym1, sym2, order),
            )
        )

    # Bonds present only on the product side were formed.
    for pair, (order, sym1, sym2) in product_bonds.items():
        if pair in reactant_bonds:
            continue
        map1, map2 = sorted(pair)
        changes.append(
            BondChange(
                atom1_map=map1,
                atom2_map=map2,
                change_type=BondChangeType.FORMED,
                old_order=None,
                new_order=order,
                energy_cost=get_bond_energy(sym1, sym2, order),
            )
        )

    # Bonds present on both sides but with a different order.
    for pair in reactant_bonds.keys() & product_bonds.keys():
        r_order, r_sym1, r_sym2 = reactant_bonds[pair]
        p_order = product_bonds[pair][0]
        if abs(r_order - p_order) <= 0.1:
            continue  # same order within tolerance
        map1, map2 = sorted(pair)
        old_energy = get_bond_energy(r_sym1, r_sym2, r_order)
        new_energy = get_bond_energy(r_sym1, r_sym2, p_order)
        changes.append(
            BondChange(
                atom1_map=map1,
                atom2_map=map2,
                change_type=BondChangeType.ORDER_CHANGE,
                old_order=r_order,
                new_order=p_order,
                energy_cost=abs(new_energy - old_energy),
            )
        )

    return changes

score_mapping(reactants, products, mapping, bond_changes=None)

Compute comprehensive score for a mapping.

Parameters:

Name Type Description Default
reactants List[Mol]

List of reactant molecules

required
products List[Mol]

List of product molecules

required
mapping FrozenSet[AtomMapping]

Set of atom mappings

required
bond_changes Optional[List[BondChange]]

Pre-computed bond changes (optional)

None

Returns:

Type Description
MappingScore

MappingScore object with all metrics

Source code in agave_chem/scoring.py
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def score_mapping(
    self,
    reactants: List[Chem.Mol],
    products: List[Chem.Mol],
    mapping: FrozenSet[AtomMapping],
    bond_changes: Optional[List[BondChange]] = None,
) -> MappingScore:
    """
    Compute comprehensive score for a mapping.

    Args:
        reactants: List of reactant molecules
        products: List of product molecules
        mapping: Set of atom mappings
        bond_changes: Pre-computed bond changes (optional)

    Returns:
        MappingScore object with all metrics
    """
    if bond_changes is None:
        bond_changes = self.compute_bond_changes(reactants, products, mapping)

    # Tally the three change types in a single pass; other change types
    # (if any) contribute only to the energy total.
    tally = {
        BondChangeType.FORMED: 0,
        BondChangeType.BROKEN: 0,
        BondChangeType.ORDER_CHANGE: 0,
    }
    energy_cost = 0
    for bc in bond_changes:
        if bc.change_type in tally:
            tally[bc.change_type] += 1
        energy_cost += bc.energy_cost

    num_formed = tally[BondChangeType.FORMED]
    num_broken = tally[BondChangeType.BROKEN]
    num_order_changes = tally[BondChangeType.ORDER_CHANGE]

    # Remaining metrics, each delegated to its own helper.
    num_fragments = self._count_fragment_changes(reactants, products, mapping)
    stereo_changes = self._count_stereo_changes(reactants, products, mapping)
    ring_changes = self._count_ring_changes(
        reactants, products, mapping, bond_changes
    )
    similarity = self._calculate_similarity(reactants, products, mapping)

    return MappingScore(
        bond_energy_cost=energy_cost,
        num_bond_changes=num_formed + num_broken + num_order_changes,
        num_bonds_formed=num_formed,
        num_bonds_broken=num_broken,
        num_fragments=num_fragments,
        stereo_changes=stereo_changes,
        similarity_score=similarity,
        ring_changes=ring_changes,
    )

NeuralReactionMapper

Bases: ReactionMapper

Neural network-based reaction atom-mapping

Source code in agave_chem/mappers/neural/neural_mapper.py
 304
 305
 306
 307
 308
 309
 310
 311
 312
 313
 314
 315
 316
 317
 318
 319
 320
 321
 322
 323
 324
 325
 326
 327
 328
 329
 330
 331
 332
 333
 334
 335
 336
 337
 338
 339
 340
 341
 342
 343
 344
 345
 346
 347
 348
 349
 350
 351
 352
 353
 354
 355
 356
 357
 358
 359
 360
 361
 362
 363
 364
 365
 366
 367
 368
 369
 370
 371
 372
 373
 374
 375
 376
 377
 378
 379
 380
 381
 382
 383
 384
 385
 386
 387
 388
 389
 390
 391
 392
 393
 394
 395
 396
 397
 398
 399
 400
 401
 402
 403
 404
 405
 406
 407
 408
 409
 410
 411
 412
 413
 414
 415
 416
 417
 418
 419
 420
 421
 422
 423
 424
 425
 426
 427
 428
 429
 430
 431
 432
 433
 434
 435
 436
 437
 438
 439
 440
 441
 442
 443
 444
 445
 446
 447
 448
 449
 450
 451
 452
 453
 454
 455
 456
 457
 458
 459
 460
 461
 462
 463
 464
 465
 466
 467
 468
 469
 470
 471
 472
 473
 474
 475
 476
 477
 478
 479
 480
 481
 482
 483
 484
 485
 486
 487
 488
 489
 490
 491
 492
 493
 494
 495
 496
 497
 498
 499
 500
 501
 502
 503
 504
 505
 506
 507
 508
 509
 510
 511
 512
 513
 514
 515
 516
 517
 518
 519
 520
 521
 522
 523
 524
 525
 526
 527
 528
 529
 530
 531
 532
 533
 534
 535
 536
 537
 538
 539
 540
 541
 542
 543
 544
 545
 546
 547
 548
 549
 550
 551
 552
 553
 554
 555
 556
 557
 558
 559
 560
 561
 562
 563
 564
 565
 566
 567
 568
 569
 570
 571
 572
 573
 574
 575
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
class NeuralReactionMapper(ReactionMapper):
    """
    Neural network-based reaction atom-mapping.

    Uses attention matrices from a pre-trained ALBERT model over reaction
    SMILES tokens to assign atom map numbers between reactants and products.
    """

    def __init__(
        self,
        mapper_name: str,
        mapper_weight: float = 3,
        checkpoint_path: Optional[str] = None,
        use_supervised: bool = True,
        supervised_config: SupervisedConfig | None = None,
        sequence_max_length: int = 256,
    ):
        """
        Initialize the NeuralReactionMapper instance.

        Args:
            mapper_name (str): The name of the mapper.
            mapper_weight (float): The weight of the mapper.
            checkpoint_path (Optional[str]): The path to the checkpoint file.
                Defaults to the bundled ``supervised_albert_model`` datafile.
            use_supervised (bool): If True, load the supervised attention-alignment
                model; otherwise load the masked-LM model.
            supervised_config (SupervisedConfig | None): Configuration for the
                supervised model head; a default instance is created when None.
            sequence_max_length (int): Tokenizer/model maximum sequence length
                (should match the length used during training).
        """

        super().__init__("neural", mapper_name, mapper_weight)

        if not checkpoint_path:
            checkpoint_path = str(
                files("agave_chem.datafiles").joinpath("supervised_albert_model")
            )

        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self._sequence_max_length = sequence_max_length
        self._use_supervised = use_supervised
        self._supervised_config = supervised_config or SupervisedConfig()

        self._model = load_neural_albert_model(
            checkpoint_dir=checkpoint_path,
            device=self._device,
            use_supervised=use_supervised,
            max_length=sequence_max_length,
            supervised_config=self._supervised_config,
        )

        self._tokenizer = CustomTokenizer(smiles_token_to_id_dict)

    def _encode_atom(self, atom: Chem.Atom) -> List[int]:
        """
        Encode an RDKit Atom object into a list of integers.

        The encoding is as follows:
        - z: The atomic number of the atom.
        - chg: The formal charge of the atom.
        - arom: 1 if the atom is aromatic, 0 otherwise.
        - ring: 1 if the atom is in a ring, 0 otherwise.
        - h: The total number of hydrogen atoms bonded to the atom.
        - d: The degree of the atom.

        Args:
            atom (Chem.Atom): The RDKit Atom object to encode.

        Returns:
            List[int]: A list of integers encoding the atom.
        """
        z = atom.GetAtomicNum()
        chg = atom.GetFormalCharge()
        arom = 1 if atom.GetIsAromatic() else 0
        ring = 1 if atom.IsInRing() else 0
        h = atom.GetTotalNumHs()
        d = atom.GetDegree()
        return [z, chg, arom, ring, h, d]

    def get_attention_matrix_for_head(
        self,
        text: str,
        layer: int,
        head: int,
        max_length: int = 256,
        trim_padding: bool = True,
    ) -> Tuple[np.ndarray, List[str]]:
        """
        Returns the attention matrix for a given layer/head for a single input string.

        Args:
            text: input reaction SMILES string (raw is fine; CustomTokenizer preprocesses)
            layer: 0-based layer index (ignored when the supervised model is used,
                which predicts a single attention map directly)
            head: 0-based head index (likewise ignored for the supervised model)
            max_length: tokenization length (should match training, e.g. 256)
            trim_padding: if True, slices matrix down to non-pad tokens only

        Returns:
            attn: np.ndarray of LOG-attention values with the first and last
                (special) token rows/columns removed, trimmed to non-pad
                tokens when requested
            tokens: list[str] tokens aligned to the pre-slice axes; note this
                still includes the special tokens that the matrix excludes
        """
        self._model.eval()

        enc = self._tokenizer(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        input_ids = enc["input_ids"].to(self._device)
        attention_mask = enc["attention_mask"].to(self._device)

        token_type_ids = enc.get("token_type_ids", torch.zeros_like(enc["input_ids"]))
        token_type_ids = token_type_ids.to(self._device)

        with torch.no_grad():
            if self._use_supervised:
                # self._model is AlbertWithAttentionAlignment
                attn_probs = self._model.predict_attention_probs(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids,
                )  # (B,S,S)
                attn = attn_probs[0].detach().cpu()  # (S,S)
            else:
                # self._model is AlbertForMaskedLM
                outputs = self._model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_attentions=True,
                    return_dict=True,
                )
                attentions = outputs.attentions  # tuple[num_layers] of (B,H,S,S)

                if layer < 0 or layer >= len(attentions):
                    raise ValueError(
                        f"layer must be in [0, {len(attentions) - 1}], got {layer}"
                    )

                num_heads = attentions[layer].shape[1]
                if head < 0 or head >= num_heads:
                    raise ValueError(
                        f"head must be in [0, {num_heads - 1}], got {head}"
                    )

                attn = attentions[layer][0, head].detach().cpu()  # (S,S)

        # Tokens for inspection/plotting
        token_ids = enc["input_ids"][0].tolist()
        tokens = self._tokenizer.convert_ids_to_tokens(token_ids)

        if trim_padding:
            real_len = int(enc["attention_mask"][0].sum().item())
            attn = attn[:real_len, :real_len]
            tokens = tokens[:real_len]

        # IMPORTANT: keep downstream behavior identical by returning log-attn
        return torch.log(attn)[1:-1, 1:-1].numpy(), tokens

    def get_reactants_products_dict(
        self,
        tokens: List[str],
    ) -> StringInfoDict:
        """
        Extracts reactants and products from a list of tokens in a reaction SMILES string.

        The leading and trailing (special) tokens are skipped, matching the
        rows/columns removed from the attention matrix in
        get_attention_matrix_for_head.

        Args:
            tokens: A list of tokens in a reaction SMILES string.

        Returns:
            StringInfoDict: A dictionary containing:
                reactants_dict: Maps token indices to the corresponding reactant token strings.
                products_dict: Maps token indices to the corresponding product token strings.
                atom_tokens_dict: Maps atom identities to lists of token indices.
                non_atom_tokens: A list of token indices that correspond to non-atom tokens.
                reactants_start_index: The index of the first reactant token.
                reactants_end_index: The index of the last reactant token.
                products_start_index: The index of the first product token.
                products_end_index: The index of the last product token.
        """
        reactants_dict: Dict[int, str] = {}
        products_dict: Dict[int, str] = {}
        atom_tokens_dict: Dict[int, List[int]] = {}
        non_atom_tokens: List[int] = []

        found_reaction_symbol = False
        for i, token in enumerate(tokens[1:-1]):
            if token == ">>":
                found_reaction_symbol = True
                non_atom_tokens.append(i)
                continue
            if token_atom_identity_dict.get(token, 0) == 0:
                non_atom_tokens.append(i)
            else:
                if token_atom_identity_dict.get(token, 0) not in atom_tokens_dict:
                    atom_tokens_dict[token_atom_identity_dict.get(token, 0)] = [i]
                else:
                    atom_tokens_dict[token_atom_identity_dict.get(token, 0)].append(i)
            if found_reaction_symbol:
                products_dict[i] = token
            else:
                reactants_dict[i] = token

        string_info_dict: StringInfoDict = {
            "reactants_dict": reactants_dict,
            "products_dict": products_dict,
            "reactants_start_index": 0,
            "reactants_end_index": max(reactants_dict.keys()),
            "products_start_index": min(products_dict.keys()),
            "products_end_index": max(products_dict.keys()),
            "atom_tokens_dict": atom_tokens_dict,
            "non_atom_tokens": non_atom_tokens,
        }

        return string_info_dict

    def mask_attn_matrix(
        self,
        attn: np.ndarray,
        string_info_dict: StringInfoDict,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Masks the attention matrix to set the attention probability for certain tokens to 0.

        Masks out reactant-to-reactant attention, product-to-product attention,
        all attention involving non-atom tokens, and attention between tokens
        of different atomic identities; then row-softmaxes the masked logits
        and re-applies the same masks as hard zeros on the probabilities.

        Args:
            attn: The attention (logit) matrix to be masked; modified in place.
            string_info_dict: Token bookkeeping from get_reactants_products_dict
                (reactant/product index ranges, non-atom token indices, and
                atom-identity groupings).

        Returns:
            Tuple[np.ndarray, np.ndarray]: The masked, row-normalized attention
            probabilities, and the intermediate exponentiated (shifted) logits.
        """
        attn[
            string_info_dict["reactants_start_index"] : string_info_dict[
                "products_start_index"
            ]
            - 1,
            string_info_dict["reactants_start_index"] : string_info_dict[
                "products_start_index"
            ]
            - 1,
        ] = -1e6  # Set attention logits for reactant tokens to other reactant tokens to very small value
        attn[
            string_info_dict["products_start_index"] : string_info_dict[
                "products_end_index"
            ]
            + 1,
            string_info_dict["products_start_index"] : string_info_dict[
                "products_end_index"
            ]
            + 1,
        ] = -1e6  # Set attention logits for product tokens to other product tokens to very small value
        for i in string_info_dict[
            "non_atom_tokens"
        ]:  # Set attention logits for reactant or product tokens to non-atom tokens to very small value
            attn[i] = -1e6
            attn[:, i] = -1e6
        for token_indices in string_info_dict[
            "atom_tokens_dict"
        ].values():  # Set attention logits for reactant and product tokens of different atom numbers to very small value
            idx = np.asarray(token_indices, dtype=np.int64)

            diff_atom_mask = np.ones(attn.shape[1], dtype=bool)
            diff_atom_mask[idx] = False
            attn[np.ix_(idx, diff_atom_mask)] = -1e6
            attn[np.ix_(diff_atom_mask, idx)] = -1e6

        # Numerically-stable row softmax over the masked logits
        row_max = np.max(attn, axis=1, keepdims=True)  # max per row
        exp_logits = np.exp(attn - row_max)
        probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)

        probs[
            string_info_dict["reactants_start_index"] : string_info_dict[
                "products_start_index"
            ]
            - 1,
            string_info_dict["reactants_start_index"] : string_info_dict[
                "products_start_index"
            ]
            - 1,
        ] = 0  # Set attention probability for reactant tokens to other reactant tokens to 0
        probs[
            string_info_dict["products_start_index"] : string_info_dict[
                "products_end_index"
            ]
            + 1,
            string_info_dict["products_start_index"] : string_info_dict[
                "products_end_index"
            ]
            + 1,
        ] = 0  # Set attention probability for product tokens to other product tokens to 0
        for i in string_info_dict[
            "non_atom_tokens"
        ]:  # Set attention probability for reactant or product tokens to non-atom tokens to 0
            probs[i] = 0
            probs[:, i] = 0

        for token_indices in string_info_dict[
            "atom_tokens_dict"
        ].values():  # Set attention probability for reactant and product tokens of different atom numbers to 0
            idx = np.asarray(token_indices, dtype=np.int64)

            diff_atom_mask = np.ones(probs.shape[1], dtype=bool)
            diff_atom_mask[idx] = False
            probs[np.ix_(idx, diff_atom_mask)] = 0
            probs[np.ix_(diff_atom_mask, idx)] = 0

        return probs, exp_logits

    def average_attn_scores(
        self,
        out: np.ndarray,
        reactants_start_index: int,
        reactants_end_index: int,
        products_start_index: int,
    ) -> np.ndarray:
        """
        Compute the average attention scores between reactants and products.

        Args:
            out (np.ndarray): The attention matrix.
            reactants_start_index (int): The index of the first reactant token.
            reactants_end_index (int): The index of the last reactant token.
            products_start_index (int): The index of the first product token.

        Returns:
            np.ndarray: The average attention scores, shaped
            (num_product_tokens, num_reactant_tokens).
        """
        reactants_to_products_attn = out[
            products_start_index:,
            reactants_start_index : reactants_end_index + 1,
        ]  # product-token rows attending to reactant-token columns
        products_to_reactants_attn = out[
            reactants_start_index : reactants_end_index + 1, products_start_index:
        ].T  # reactant-token rows attending to product-token columns, transposed so indices align
        avg_attn = (reactants_to_products_attn + products_to_reactants_attn) / 2
        return avg_attn

    def remove_non_atom_rows_and_columns(
        self, attn: np.ndarray, string_info_dict: StringInfoDict
    ) -> np.ndarray:
        """
        Remove non-atom tokens from attention matrix.

        Rows correspond to product tokens and columns to reactant tokens
        (the layout produced by average_attn_scores), so product non-atom
        indices are deleted along axis 0 and reactant ones along axis 1.

        Args:
            attn (np.ndarray): The attention matrix.
            string_info_dict (StringInfoDict): A dictionary containing information about the tokens in the reaction SMILES string.

        Returns:
            np.ndarray: The attention matrix with non-atom tokens removed.
        """
        reactants_non_atom_tokens = [
            ele
            for ele in string_info_dict["non_atom_tokens"]
            if ele <= string_info_dict["reactants_end_index"]
        ]  # Get non-atom tokens in reactants
        products_non_atom_tokens = [
            ele - string_info_dict["products_start_index"]
            for ele in string_info_dict["non_atom_tokens"]
            if ele >= string_info_dict["products_start_index"]
        ]  # Get non-atom tokens in products with offset of products start index

        idx = np.asarray(reactants_non_atom_tokens, dtype=int)
        attn = np.delete(attn, idx, axis=1)

        idx = np.asarray(products_non_atom_tokens, dtype=int)
        attn = np.delete(attn, idx, axis=0)

        return attn

    def get_duplicate_indices(self, list_of_lists: list) -> Dict[int, List[int]]:
        """
        Find positions that share a value with another position in the same sublist.

        Indices are flattened across sublists (offset by the cumulative length
        of preceding sublists). Used with canonical atom ranks to identify
        symmetry-equivalent atoms within each molecule.

        Args:
            list_of_lists (list): Sublists of hashable values (e.g. canonical ranks).

        Returns:
            Dict[int, List[int]]: Flattened index -> flattened indices of the
            OTHER positions with the same value in the same sublist. Positions
            with no duplicates are omitted.
        """
        result = {}
        offset = 0

        for sublist in list_of_lists:
            # Group flattened indices by value within this sublist
            value_to_indices = defaultdict(list)
            for i, val in enumerate(sublist):
                value_to_indices[val].append(offset + i)

            # For each item, map to OTHER items with same value in the same sublist
            for i, val in enumerate(sublist):
                flat_idx = offset + i
                others = [idx for idx in value_to_indices[val] if idx != flat_idx]
                if others:  # only include entries that actually have duplicates
                    result[flat_idx] = others

            offset += len(sublist)

        return result

    def assign_atom_maps(
        self,
        rxn_smiles: str,
        attn: np.ndarray,
        one_to_one_correspondence: bool = False,
        adjacent_atom_multiplier: float = 30,
        identical_adjacent_atom_multiplier: float = 10,
        reactants_atom_idx_to_orig_mapping: Optional[Dict[int, int]] = None,
        products_atom_idx_to_orig_mapping: Optional[Dict[int, int]] = None,
    ) -> Tuple[str, float]:
        """
        Assign atom map numbers to reactant and product atoms from attention scores.

        Attention over symmetry-equivalent reactant atoms (identical canonical
        ranks) is pooled before scoring so that symmetric choices do not split
        probability mass.

        Args:
            rxn_smiles (str): Unmapped reaction SMILES string.
            attn (np.ndarray): Attention matrix with product atoms as rows and
                reactant atoms as columns; modified in place.
            one_to_one_correspondence (bool): If True, greedily pick the globally
                highest-attention (product, reactant) pair each step, zero out the
                used row/column so each reactant atom is assigned at most once, and
                boost attention between neighbors of the matched pair.
            adjacent_atom_multiplier (float): Boost applied to attention between
                atoms adjacent to an already-matched pair (one-to-one mode only).
            identical_adjacent_atom_multiplier (float): Extra boost when those
                adjacent atoms also have identical encoded environments.
            reactants_atom_idx_to_orig_mapping (Optional[Dict[int, int]]):
                Pre-existing reactant map numbers by atom index (partial mapping).
            products_atom_idx_to_orig_mapping (Optional[Dict[int, int]]):
                Pre-existing product map numbers by atom index (partial mapping).

        Returns:
            Tuple[str, float]: The mapped reaction SMILES and a confidence score
            (product of the per-assignment attention probabilities; pre-mapped
            atoms contribute 1.0).
        """
        if not reactants_atom_idx_to_orig_mapping:
            reactants_atom_idx_to_orig_mapping = {}
        if not products_atom_idx_to_orig_mapping:
            products_atom_idx_to_orig_mapping = {}

        reactants_str, products_str = self._split_reaction_components(rxn_smiles)

        reactants_mols = [
            Chem.MolFromSmiles(reactant) for reactant in reactants_str.split(".")
        ]
        products_mols = [
            Chem.MolFromSmiles(product) for product in products_str.split(".")
        ]

        # Global atom numbering across all reactant molecules, plus each atom's
        # neighbor indices and encoded environments (for the adjacency boost).
        reactants_atom_dict = {}
        reactants_atom_dict_neighbors = {}
        reactant_atom_num = 0
        for mol in reactants_mols:
            mol_reactants_atom_dict = {}
            mol_idx_to_atom_num = {}
            for atom in mol.GetAtoms():
                mol_reactants_atom_dict[reactant_atom_num] = atom
                mol_idx_to_atom_num[atom.GetIdx()] = reactant_atom_num
                reactant_atom_num += 1
            for atom_reactant_atom_num, atom in mol_reactants_atom_dict.items():
                reactants_atom_dict_neighbors[atom_reactant_atom_num] = [
                    (
                        mol_idx_to_atom_num[neighbor.GetIdx()],
                        self._encode_atom(neighbor),
                    )
                    for neighbor in atom.GetNeighbors()
                ]
            reactants_atom_dict.update(mol_reactants_atom_dict)

        products_atom_dict = {}
        products_atom_dict_neighbors = {}
        product_atom_num = 0
        for mol in products_mols:
            mol_products_atom_dict = {}
            mol_idx_to_atom_num = {}
            for atom in mol.GetAtoms():
                mol_products_atom_dict[product_atom_num] = atom
                mol_idx_to_atom_num[atom.GetIdx()] = product_atom_num
                product_atom_num += 1
            for atom_product_atom_num, atom in mol_products_atom_dict.items():
                products_atom_dict_neighbors[atom_product_atom_num] = [
                    (
                        mol_idx_to_atom_num[neighbor.GetIdx()],
                        self._encode_atom(neighbor),
                    )
                    for neighbor in atom.GetNeighbors()
                ]
            products_atom_dict.update(mol_products_atom_dict)

        products_orig_mapping_to_idx = {
            value: key
            for key, value in products_atom_idx_to_orig_mapping.items()
            if value != 0
        }
        reactants_orig_mapping_to_idx = {
            value: key
            for key, value in reactants_atom_idx_to_orig_mapping.items()
            if value != 0
        }

        orig_attn = attn.copy()

        reactants_mols_canonical = [
            list(Chem.rdmolfiles.CanonicalRankAtoms(mol, breakTies=False))
            for mol in reactants_mols
        ]
        products_mols_canonical = [
            list(Chem.rdmolfiles.CanonicalRankAtoms(mol, breakTies=False))
            for mol in products_mols
        ]

        reactants_symmetric_atom_indices = self.get_duplicate_indices(
            reactants_mols_canonical
        )
        # products_symmetric_atom_indices = self.get_duplicate_indices(products_mols_canonical)

        reactants_identical_atoms_indices = []
        for k, v in reactants_symmetric_atom_indices.items():
            reactants_identical_atoms_indices.append(tuple(sorted([k] + v)))

        reactants_identical_atoms_indices = list(set(reactants_identical_atoms_indices))

        # Pool attention across symmetry-equivalent reactant columns so any
        # member of a symmetric set carries the full group's score.
        symmetric_atom_new_val_mapping = {}
        for symmetric_atoms in reactants_identical_atoms_indices:
            summed_vals = np.sum(
                orig_attn[
                    :,
                    symmetric_atoms,
                ],
                axis=1,
            )
            for val in symmetric_atoms:
                symmetric_atom_new_val_mapping[val] = summed_vals

        for k, v in symmetric_atom_new_val_mapping.items():
            orig_attn[:, k] = v

        assignment_probs = []
        if one_to_one_correspondence:
            for map_num in range(attn.shape[0]):
                if products_orig_mapping_to_idx.get(map_num + 1, 0):
                    # Map number already fixed by a partial mapping: honor it.
                    row_highest_attn = products_orig_mapping_to_idx[map_num + 1]
                    col_highest_attn = reactants_orig_mapping_to_idx[map_num + 1]
                    products_atom_dict[row_highest_attn].SetAtomMapNum(map_num + 1)
                    reactants_atom_dict[col_highest_attn].SetAtomMapNum(map_num + 1)
                    attn[row_highest_attn] = 0
                    attn[:, col_highest_attn] = 0
                    assignment_probs.append(1.0)

                else:
                    highest_attn_score = attn.max()
                    highest_attn_score_indices = np.where(attn == highest_attn_score)
                    row_highest_attn = highest_attn_score_indices[0][0]
                    col_highest_attn = highest_attn_score_indices[1][0]
                    products_atom_dict[row_highest_attn].SetAtomMapNum(map_num + 1)
                    reactants_atom_dict[col_highest_attn].SetAtomMapNum(map_num + 1)
                    attn[row_highest_attn] = 0
                    attn[:, col_highest_attn] = 0
                    assignment_probs.append(
                        orig_attn[row_highest_attn, col_highest_attn]
                    )

                # Update neighbors
                for (
                    product_atom_idx,
                    product_atom_env,
                ) in products_atom_dict_neighbors[row_highest_attn]:
                    for (
                        reactant_atom_idx,
                        reactant_atom_env,
                    ) in reactants_atom_dict_neighbors[col_highest_attn]:
                        if product_atom_env == reactant_atom_env:
                            attn[product_atom_idx, reactant_atom_idx] *= (
                                adjacent_atom_multiplier
                                * identical_adjacent_atom_multiplier
                            )
                        else:
                            attn[product_atom_idx, reactant_atom_idx] *= (
                                adjacent_atom_multiplier
                            )

        else:
            for map_num, row in enumerate(attn):
                # get partial mapping working for now
                # if products_orig_mapping_to_idx.get(map_num + 1, 0):
                #     row_highest_attn = products_orig_mapping_to_idx[map_num + 1]
                #     col_highest_attn = reactants_orig_mapping_to_idx[map_num + 1]
                #     products_atom_dict[row_highest_attn].SetAtomMapNum(map_num + 1)
                #     reactants_atom_dict[col_highest_attn].SetAtomMapNum(map_num + 1)
                #     print(map_num + 1, row_highest_attn, col_highest_attn)
                # else:
                highest_attn_score = row.max()
                highest_attn_indices = int(np.where(row == highest_attn_score)[0][0])
                products_atom_dict[map_num].SetAtomMapNum(map_num + 1)
                reactants_atom_dict[highest_attn_indices].SetAtomMapNum(map_num + 1)
                assignment_probs.append(orig_attn[map_num, highest_attn_indices])

        mapped_reactants_str = ".".join(
            [Chem.MolToSmiles(reactant, canonical=False) for reactant in reactants_mols]
        )
        mapped_products_str = ".".join(
            [Chem.MolToSmiles(product, canonical=False) for product in products_mols]
        )
        mapped_rxn_smiles = mapped_reactants_str + ">>" + mapped_products_str

        confidence = float(np.prod(assignment_probs))

        return mapped_rxn_smiles, confidence

    def get_data_from_partially_mapped_smiles(self, rxn_smiles):
        """
        Strip atom map numbers from a partially mapped reaction SMILES.

        Records the original map number of every reactant and product atom
        (0 for unmapped atoms) before clearing it, so assign_atom_maps can
        later honor the pre-existing assignments.

        Args:
            rxn_smiles: A (partially) atom-mapped reaction SMILES string.

        Returns:
            Tuple of (unmapped reaction SMILES,
            reactant atom index -> original map number,
            product atom index -> original map number).
        """
        reactants_str, products_str = self._split_reaction_components(rxn_smiles)
        reactants_mols = [
            Chem.MolFromSmiles(reactant) for reactant in reactants_str.split(".")
        ]
        products_mols = [
            Chem.MolFromSmiles(product) for product in products_str.split(".")
        ]

        reactants_atom_idx_to_orig_mapping = {}
        reactants_atom_dict = {}
        reactants_atom_dict_neighbors = {}
        reactant_atom_num = 0
        for mol in reactants_mols:
            for atom in mol.GetAtoms():
                reactants_atom_dict[reactant_atom_num] = atom
                reactants_atom_idx_to_orig_mapping[reactant_atom_num] = (
                    atom.GetAtomMapNum()
                )
                reactants_atom_dict_neighbors[reactant_atom_num] = [
                    neighbor.GetIdx() for neighbor in atom.GetNeighbors()
                ]
                atom.SetAtomMapNum(0)
                reactant_atom_num += 1

        products_atom_idx_to_orig_mapping = {}
        products_atom_dict = {}
        products_atom_dict_neighbors = {}
        product_atom_num = 0
        for mol in products_mols:
            for atom in mol.GetAtoms():
                products_atom_dict[product_atom_num] = atom
                products_atom_idx_to_orig_mapping[product_atom_num] = (
                    atom.GetAtomMapNum()
                )
                products_atom_dict_neighbors[product_atom_num] = [
                    neighbor.GetIdx() for neighbor in atom.GetNeighbors()
                ]
                atom.SetAtomMapNum(0)
                product_atom_num += 1

        unmapped_reactants_strings = [
            Chem.MolToSmiles(reactant, canonical=False) for reactant in reactants_mols
        ]

        unmapped_products_strings = [
            Chem.MolToSmiles(product, canonical=False) for product in products_mols
        ]

        unmapped_rxn = (
            ".".join(unmapped_reactants_strings)
            + ">>"
            + ".".join(unmapped_products_strings)
        )

        return (
            unmapped_rxn,
            reactants_atom_idx_to_orig_mapping,
            products_atom_idx_to_orig_mapping,
        )

    def map_reaction(
        self,
        rxn_smiles: str,
        layer: int = 9,
        head: int = 0,
        sequence_max_length: int = 256,
        adjacent_atom_multiplier: float = 10,
        identical_adjacent_atom_multiplier: float = 10,
        one_to_one_correspondence: bool = False,
        start_from_partial_map: bool = False,
    ) -> ReactionMapperResult:
        """
        Maps a reaction SMILES string using a pre-trained Albert model.

        Args:
            rxn_smiles (str): A reaction SMILES string.
            layer (int): 0-based layer index to use for attention.
            head (int): 0-based head index to use for attention.
            sequence_max_length (int): Maximum token length for the model input.
            adjacent_atom_multiplier (float): See assign_atom_maps.
            identical_adjacent_atom_multiplier (float): See assign_atom_maps.
            one_to_one_correspondence (bool): See assign_atom_maps.
            start_from_partial_map (bool): If True, respect atom map numbers
                already present in rxn_smiles as fixed assignments.

        Returns:
            ReactionMapperResult: Result with the mapped reaction SMILES and a
            confidence score; a default empty result on any failure (invalid
            SMILES, unknown token, missing reaction symbol, or over-long input).
        """
        default_mapping_dict = ReactionMapperResult(
            original_smiles="",
            selected_mapping="",
            possible_mappings={},
            mapping_type=self._mapper_type,
            mapping_score=None,
            additional_info=[{}],
        )
        if not self._reaction_smiles_valid(rxn_smiles):
            return default_mapping_dict

        reactants_atom_idx_to_orig_mapping = None
        products_atom_idx_to_orig_mapping = None
        if start_from_partial_map:
            (
                rxn_smiles,
                reactants_atom_idx_to_orig_mapping,
                products_atom_idx_to_orig_mapping,
            ) = self.get_data_from_partially_mapped_smiles(rxn_smiles)

        attn, tokens = self.get_attention_matrix_for_head(
            text=rxn_smiles,
            layer=layer,
            head=head,
            max_length=sequence_max_length,
            trim_padding=True,
        )

        if "[UNK]" in tokens:
            logger.warning("Unknown token in sequence")
            return default_mapping_dict

        if ">>" not in tokens:
            # The reaction symbol can be lost to truncation; without it the
            # reactant/product split below is impossible.
            logger.warning("No reaction symbol ('>>') in token sequence")

            return default_mapping_dict

        if len(tokens) >= sequence_max_length:
            logger.warning("Sequence too long")
            return default_mapping_dict

        string_info_dict = self.get_reactants_products_dict(tokens)
        attn_probs, _ = self.mask_attn_matrix(attn, string_info_dict)
        attn = self.average_attn_scores(
            attn_probs,
            string_info_dict["reactants_start_index"],
            string_info_dict["reactants_end_index"],
            string_info_dict["products_start_index"],
        )

        attn = self.remove_non_atom_rows_and_columns(attn, string_info_dict)

        mapped_rxn_smiles, confidence = self.assign_atom_maps(
            rxn_smiles,
            attn,
            one_to_one_correspondence=one_to_one_correspondence,
            adjacent_atom_multiplier=adjacent_atom_multiplier,
            identical_adjacent_atom_multiplier=identical_adjacent_atom_multiplier,
            reactants_atom_idx_to_orig_mapping=reactants_atom_idx_to_orig_mapping,
            products_atom_idx_to_orig_mapping=products_atom_idx_to_orig_mapping,
        )

        if not self._verify_validity_of_mapping(mapped_rxn_smiles):
            return default_mapping_dict

        return ReactionMapperResult(
            original_smiles=rxn_smiles,
            selected_mapping=mapped_rxn_smiles,
            possible_mappings={},
            mapping_type=self._mapper_type,
            mapping_score=confidence,
            additional_info=[{}],
        )

    def map_reactions(self, reaction_list: List[str]) -> List[ReactionMapperResult]:
        """
        Map a list of reaction SMILES strings with default settings.

        Args:
            reaction_list (List[str]): Reaction SMILES strings to map.

        Returns:
            List[ReactionMapperResult]: One result per input reaction, in order.
        """
        mapped_reactions = []
        for reaction in reaction_list:
            mapped_reactions.append(self.map_reaction(reaction))
        return mapped_reactions

__init__(mapper_name, mapper_weight=3, checkpoint_path=None, use_supervised=True, supervised_config=None, sequence_max_length=256)

Initialize the NeuralReactionMapper instance.

Parameters:

Name Type Description Default
mapper_name str

The name of the mapper.

required
mapper_weight float

The weight of the mapper.

3
checkpoint_path Optional[str]

The path to the checkpoint file.

None
Source code in agave_chem/mappers/neural/neural_mapper.py
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
def __init__(
    self,
    mapper_name: str,
    mapper_weight: float = 3,
    checkpoint_path: Optional[str] = None,
    use_supervised: bool = True,
    supervised_config: SupervisedConfig | None = None,
    sequence_max_length: int = 256,
):
    """
    Initialize the NeuralReactionMapper instance.

    Args:
        mapper_name (str): The name of the mapper.
        mapper_weight (float): The weight of the mapper.
        checkpoint_path (Optional[str]): The path to the checkpoint file.
            Defaults to the bundled ``supervised_albert_model`` datafile.
        use_supervised (bool): If True, load the supervised attention-alignment
            model; otherwise load the masked-LM model.
        supervised_config (SupervisedConfig | None): Configuration for the
            supervised model head; a default instance is created when None.
        sequence_max_length (int): Tokenizer/model maximum sequence length.
    """

    super().__init__("neural", mapper_name, mapper_weight)

    # Fall back to the model checkpoint shipped with the package
    if not checkpoint_path:
        checkpoint_path = str(
            files("agave_chem.datafiles").joinpath("supervised_albert_model")
        )

    self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    self._sequence_max_length = sequence_max_length
    self._use_supervised = use_supervised
    self._supervised_config = supervised_config or SupervisedConfig()

    self._model = load_neural_albert_model(
        checkpoint_dir=checkpoint_path,
        device=self._device,
        use_supervised=use_supervised,
        max_length=sequence_max_length,
        supervised_config=self._supervised_config,
    )

    self._tokenizer = CustomTokenizer(smiles_token_to_id_dict)

assign_atom_maps(rxn_smiles, attn, one_to_one_correspondence=False, adjacent_atom_multiplier=30, identical_adjacent_atom_multiplier=10, reactants_atom_idx_to_orig_mapping=None, products_atom_idx_to_orig_mapping=None)

Source code in agave_chem/mappers/neural/neural_mapper.py
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
def assign_atom_maps(
    self,
    rxn_smiles: str,
    attn: np.ndarray,
    one_to_one_correspondence: bool = False,
    adjacent_atom_multiplier: float = 30,
    identical_adjacent_atom_multiplier: float = 10,
    reactants_atom_idx_to_orig_mapping: Optional[Dict[int, int]] = None,
    products_atom_idx_to_orig_mapping: Optional[Dict[int, int]] = None,
) -> Tuple[str, float]:
    """
    Assign atom-map numbers to a reaction SMILES from an attention matrix.

    Rows of ``attn`` index product atoms and columns index reactant atoms
    (evidenced by the row/column usage against ``products_atom_dict`` /
    ``reactants_atom_dict`` below).

    Args:
        rxn_smiles (str): Reaction SMILES ("reactants>>products").
        attn (np.ndarray): Product-atom x reactant-atom attention scores.
            NOTE: mutated in place in the one-to-one branch.
        one_to_one_correspondence (bool): If True, greedily pick the global
            best (product, reactant) pair one map number at a time, zeroing
            the assigned row/column and boosting neighbor scores; otherwise
            each product atom is independently matched to its best reactant
            (a reactant atom may then be used more than once).
        adjacent_atom_multiplier (float): Boost applied to attention between
            atoms adjacent to an assigned pair (one-to-one mode only).
        identical_adjacent_atom_multiplier (float): Extra boost when the
            adjacent atoms have identical encodings (via ``_encode_atom``).
        reactants_atom_idx_to_orig_mapping (Optional[Dict[int, int]]):
            Pre-existing reactant atom-index -> map-number assignments from a
            partial mapping; a value of 0 means unmapped.
        products_atom_idx_to_orig_mapping (Optional[Dict[int, int]]):
            Pre-existing product atom-index -> map-number assignments.

    Returns:
        Tuple[str, float]: The mapped reaction SMILES and a confidence score
        (product of the per-assignment attention probabilities; pre-mapped
        atoms contribute 1.0).
    """
    if not reactants_atom_idx_to_orig_mapping:
        reactants_atom_idx_to_orig_mapping = {}
    if not products_atom_idx_to_orig_mapping:
        products_atom_idx_to_orig_mapping = {}

    reactants_str, products_str = self._split_reaction_components(rxn_smiles)

    reactants_mols = [
        Chem.MolFromSmiles(reactant) for reactant in reactants_str.split(".")
    ]
    products_mols = [
        Chem.MolFromSmiles(product) for product in products_str.split(".")
    ]

    # Number reactant atoms sequentially ACROSS all reactant molecules and
    # record, per atom, its neighbors as (global index, encoded identity)
    # pairs; the global index is what the attention columns refer to.
    reactants_atom_dict = {}
    reactants_atom_dict_neighbors = {}
    reactant_atom_num = 0
    for mol in reactants_mols:
        mol_reactants_atom_dict = {}
        mol_idx_to_atom_num = {}
        for atom in mol.GetAtoms():
            mol_reactants_atom_dict[reactant_atom_num] = atom
            mol_idx_to_atom_num[atom.GetIdx()] = reactant_atom_num
            reactant_atom_num += 1
        for atom_reactant_atom_num, atom in mol_reactants_atom_dict.items():
            reactants_atom_dict_neighbors[atom_reactant_atom_num] = [
                (
                    mol_idx_to_atom_num[neighbor.GetIdx()],
                    self._encode_atom(neighbor),
                )
                for neighbor in atom.GetNeighbors()
            ]
        reactants_atom_dict.update(mol_reactants_atom_dict)

    # Same global numbering for product atoms (attention rows).
    products_atom_dict = {}
    products_atom_dict_neighbors = {}
    product_atom_num = 0
    for mol in products_mols:
        mol_products_atom_dict = {}
        mol_idx_to_atom_num = {}
        for atom in mol.GetAtoms():
            mol_products_atom_dict[product_atom_num] = atom
            mol_idx_to_atom_num[atom.GetIdx()] = product_atom_num
            product_atom_num += 1
        for atom_product_atom_num, atom in mol_products_atom_dict.items():
            products_atom_dict_neighbors[atom_product_atom_num] = [
                (
                    mol_idx_to_atom_num[neighbor.GetIdx()],
                    self._encode_atom(neighbor),
                )
                for neighbor in atom.GetNeighbors()
            ]
        products_atom_dict.update(mol_products_atom_dict)

    # Invert the partial-mapping dicts (map number -> atom index);
    # entries with value 0 are unmapped and are dropped.
    products_orig_mapping_to_idx = {
        value: key
        for key, value in products_atom_idx_to_orig_mapping.items()
        if value != 0
    }
    reactants_orig_mapping_to_idx = {
        value: key
        for key, value in reactants_atom_idx_to_orig_mapping.items()
        if value != 0
    }

    # Keep the unmodified scores for reporting assignment probabilities,
    # since `attn` is zeroed/boosted during greedy assignment.
    orig_attn = attn.copy()

    # breakTies=False gives equal canonical ranks to symmetry-equivalent
    # atoms within each molecule.
    reactants_mols_canonical = [
        list(Chem.rdmolfiles.CanonicalRankAtoms(mol, breakTies=False))
        for mol in reactants_mols
    ]
    products_mols_canonical = [
        list(Chem.rdmolfiles.CanonicalRankAtoms(mol, breakTies=False))
        for mol in products_mols
    ]

    # NOTE(review): assumes get_duplicate_indices maps the per-molecule rank
    # lists to global atom indices usable as attention columns — confirm.
    reactants_symmetric_atom_indices = self.get_duplicate_indices(
        reactants_mols_canonical
    )
    # products_symmetric_atom_indices = self.get_duplicate_indices(products_mols_canonical)

    # Collapse the duplicate-index mapping into a deduplicated list of
    # equivalence classes (sorted tuples of mutually symmetric atoms).
    reactants_identical_atoms_indices = []
    for k, v in reactants_symmetric_atom_indices.items():
        reactants_identical_atoms_indices.append(tuple(sorted([k] + v)))

    reactants_identical_atoms_indices = list(set(reactants_identical_atoms_indices))

    # Sum attention over each symmetry class and write the summed column
    # back to every member, so equivalent reactant atoms share one score.
    symmetric_atom_new_val_mapping = {}
    for symmetric_atoms in reactants_identical_atoms_indices:
        summed_vals = np.sum(
            orig_attn[
                :,
                symmetric_atoms,
            ],
            axis=1,
        )
        for val in symmetric_atoms:
            symmetric_atom_new_val_mapping[val] = summed_vals

    for k, v in symmetric_atom_new_val_mapping.items():
        orig_attn[:, k] = v

    assignment_probs = []
    if one_to_one_correspondence:
        # Greedy global assignment: one map number per iteration, each
        # product/reactant atom used at most once.
        for map_num in range(attn.shape[0]):
            if products_orig_mapping_to_idx.get(map_num + 1, 0):
                # Honor a pre-existing (partial) mapping for this number.
                row_highest_attn = products_orig_mapping_to_idx[map_num + 1]
                col_highest_attn = reactants_orig_mapping_to_idx[map_num + 1]
                products_atom_dict[row_highest_attn].SetAtomMapNum(map_num + 1)
                reactants_atom_dict[col_highest_attn].SetAtomMapNum(map_num + 1)
                attn[row_highest_attn] = 0
                attn[:, col_highest_attn] = 0
                assignment_probs.append(1.0)

            else:
                # Pick the global best remaining (product, reactant) pair,
                # then retire its row and column.
                highest_attn_score = attn.max()
                highest_attn_score_indices = np.where(attn == highest_attn_score)
                row_highest_attn = highest_attn_score_indices[0][0]
                col_highest_attn = highest_attn_score_indices[1][0]
                products_atom_dict[row_highest_attn].SetAtomMapNum(map_num + 1)
                reactants_atom_dict[col_highest_attn].SetAtomMapNum(map_num + 1)
                attn[row_highest_attn] = 0
                attn[:, col_highest_attn] = 0
                assignment_probs.append(
                    orig_attn[row_highest_attn, col_highest_attn]
                )

            # Update neighbors: boost attention between atoms adjacent to
            # the pair just assigned, extra if their encodings match.
            for (
                product_atom_idx,
                product_atom_env,
            ) in products_atom_dict_neighbors[row_highest_attn]:
                for (
                    reactant_atom_idx,
                    reactant_atom_env,
                ) in reactants_atom_dict_neighbors[col_highest_attn]:
                    if product_atom_env == reactant_atom_env:
                        attn[product_atom_idx, reactant_atom_idx] *= (
                            adjacent_atom_multiplier
                            * identical_adjacent_atom_multiplier
                        )
                    else:
                        attn[product_atom_idx, reactant_atom_idx] *= (
                            adjacent_atom_multiplier
                        )

    else:
        # Independent per-product-atom assignment: row argmax; a reactant
        # atom may receive several map numbers.
        for map_num, row in enumerate(attn):
            # get partial mapping working for now
            # if products_orig_mapping_to_idx.get(map_num + 1, 0):
            #     row_highest_attn = products_orig_mapping_to_idx[map_num + 1]
            #     col_highest_attn = reactants_orig_mapping_to_idx[map_num + 1]
            #     products_atom_dict[row_highest_attn].SetAtomMapNum(map_num + 1)
            #     reactants_atom_dict[col_highest_attn].SetAtomMapNum(map_num + 1)
            #     print(map_num + 1, row_highest_attn, col_highest_attn)
            # else:
            highest_attn_score = row.max()
            highest_attn_indices = int(np.where(row == highest_attn_score)[0][0])
            products_atom_dict[map_num].SetAtomMapNum(map_num + 1)
            reactants_atom_dict[highest_attn_indices].SetAtomMapNum(map_num + 1)
            assignment_probs.append(orig_attn[map_num, highest_attn_indices])

    # canonical=False preserves the input atom order so the map numbers we
    # just set line up with the original SMILES layout.
    mapped_reactants_str = ".".join(
        [Chem.MolToSmiles(reactant, canonical=False) for reactant in reactants_mols]
    )
    mapped_products_str = ".".join(
        [Chem.MolToSmiles(product, canonical=False) for product in products_mols]
    )
    mapped_rxn_smiles = mapped_reactants_str + ">>" + mapped_products_str

    confidence = float(np.prod(assignment_probs))

    return mapped_rxn_smiles, confidence

average_attn_scores(out, reactants_start_index, reactants_end_index, products_start_index)

Compute the average attention scores between reactants and products.

Parameters:

Name Type Description Default
out ndarray

The attention matrix.

required
reactants_start_index int

The index of the first reactant token.

required
reactants_end_index int

The index of the last reactant token.

required
products_start_index int

The index of the first product token.

required

Returns:

Type Description
ndarray

np.ndarray: The average attention scores between reactants and products.

Source code in agave_chem/mappers/neural/neural_mapper.py
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
def average_attn_scores(
    self,
    out: np.ndarray,
    reactants_start_index: int,
    reactants_end_index: int,
    products_start_index: int,
) -> np.ndarray:
    """
    Average the two cross-attention directions between reactants and products.

    The products->reactants sub-matrix is transposed so both operands share
    the same (product rows x reactant columns) orientation before averaging.

    Args:
        out (np.ndarray): The full (square) attention matrix.
        reactants_start_index (int): Index of the first reactant token.
        reactants_end_index (int): Index of the last reactant token (inclusive).
        products_start_index (int): Index of the first product token.

    Returns:
        np.ndarray: Averaged product-x-reactant attention scores.
    """
    reactant_span = slice(reactants_start_index, reactants_end_index + 1)
    reactants_to_products = out[products_start_index:, reactant_span]
    products_to_reactants = out[reactant_span, products_start_index:]
    return (reactants_to_products + products_to_reactants.T) / 2

get_attention_matrix_for_head(text, layer, head, max_length=256, trim_padding=True)

Returns the attention matrix for a given layer/head for a single input string.

Parameters:

Name Type Description Default
text str

input reaction SMILES string (raw is fine; CustomTokenizer preprocesses)

required
layer int

0-based layer index

required
head int

0-based head index

required
max_length int

tokenization length (should match training, e.g. 256)

256
trim_padding bool

if True, slices matrix down to non-pad tokens only

True

Returns:

Name Type Description
attn ndarray

Tensor of shape (seq_len, seq_len) (trimmed if requested)

tokens List[str]

list[str] tokens aligned to attn axes (trimmed if requested)

Source code in agave_chem/mappers/neural/neural_mapper.py
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
def get_attention_matrix_for_head(
    self,
    text: str,
    layer: int,
    head: int,
    max_length: int = 256,
    trim_padding: bool = True,
) -> Tuple[np.ndarray, List[str]]:
    """
    Returns the attention matrix for a given layer/head for a single input string.

    Args:
        text: input reaction SMILES string (raw is fine; CustomTokenizer preprocesses)
        layer: 0-based layer index (ignored when the supervised model is used,
            which returns its own alignment probabilities)
        head: 0-based head index (likewise ignored in the supervised branch)
        max_length: tokenization length (should match training, e.g. 256)
        trim_padding: if True, slices matrix down to non-pad tokens only

    Returns:
        attn: log-attention matrix of shape (seq_len-2, seq_len-2) — the
            first and last token rows/columns are stripped (presumably the
            [CLS]/[SEP]-style specials added by the tokenizer — confirm)
        tokens: list[str] tokens aligned to the PRE-stripped axes
            (trimmed of padding if requested, but NOT of the specials)

    Raises:
        ValueError: If `layer` or `head` is out of range (unsupervised branch).
    """
    # Inference only — disable dropout etc.
    self._model.eval()

    enc = self._tokenizer(
        text,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = enc["input_ids"].to(self._device)
    attention_mask = enc["attention_mask"].to(self._device)

    # Some tokenizers omit token_type_ids; default to all-zeros (single segment).
    token_type_ids = enc.get("token_type_ids", torch.zeros_like(enc["input_ids"]))
    token_type_ids = token_type_ids.to(self._device)

    with torch.no_grad():
        if self._use_supervised:
            # self._model is AlbertWithAttentionAlignment
            # The supervised head produces the alignment matrix directly,
            # so `layer`/`head` are not consulted here.
            attn_probs = self._model.predict_attention_probs(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
            )  # (B,S,S)
            attn = attn_probs[0].detach().cpu()  # (S,S)
        else:
            # self._model is AlbertForMaskedLM
            outputs = self._model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=True,
                return_dict=True,
            )
            attentions = outputs.attentions  # tuple[num_layers] of (B,H,S,S)

            if layer < 0 or layer >= len(attentions):
                raise ValueError(
                    f"layer must be in [0, {len(attentions) - 1}], got {layer}"
                )

            num_heads = attentions[layer].shape[1]
            if head < 0 or head >= num_heads:
                raise ValueError(
                    f"head must be in [0, {num_heads - 1}], got {head}"
                )

            attn = attentions[layer][0, head].detach().cpu()  # (S,S)

    # Tokens for inspection/plotting
    token_ids = enc["input_ids"][0].tolist()
    tokens = self._tokenizer.convert_ids_to_tokens(token_ids)

    if trim_padding:
        # Non-pad length = number of 1s in the attention mask.
        real_len = int(enc["attention_mask"][0].sum().item())
        attn = attn[:real_len, :real_len]
        tokens = tokens[:real_len]

    # IMPORTANT: keep downstream behavior identical by returning log-attn
    # NOTE(review): zero attention entries become -inf under log — presumed
    # acceptable since downstream only compares/maximizes scores; confirm.
    return torch.log(attn)[1:-1, 1:-1].numpy(), tokens

get_reactants_products_dict(tokens)

Extracts reactants and products from a list of tokens in a reaction SMILES string.

Parameters:

Name Type Description Default
tokens List[str]

A list of tokens in a reaction SMILES string.

required

Returns:

Type Description
StringInfoDict

A tuple containing: reactants_dict: A dictionary where the keys are token indices and the values are the corresponding token strings. products_dict: A dictionary where the keys are token indices and the values are the corresponding token strings. atom_tokens_dict: A dictionary where the keys are atom identities and the values are lists of token indices. non_atom_tokens: A list of token indices that correspond to non-atom tokens. reactants_start_index: The index of the first reactant token. reactants_end_index: The index of the last reactant token. products_start_index: The index of the first product token. products_end_index: The index of the last product token.

Source code in agave_chem/mappers/neural/neural_mapper.py
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
def get_reactants_products_dict(
    self,
    tokens: List[str],
) -> StringInfoDict:
    """
    Extracts reactant and product token information from a tokenized reaction SMILES.

    Indices are relative to ``tokens[1:-1]`` (the leading/trailing special
    tokens are skipped), matching the trimmed attention matrix returned by
    ``get_attention_matrix_for_head``.

    Args:
        tokens: A list of tokens in a reaction SMILES string, including the
            special first and last tokens.

    Returns:
        A StringInfoDict containing:
            reactants_dict: token index -> token string for reactant tokens.
            products_dict: token index -> token string for product tokens.
            atom_tokens_dict: atom identity -> list of token indices.
            non_atom_tokens: indices of non-atom tokens (including ">>").
            reactants_start_index: index of the first reactant token (0).
            reactants_end_index: index of the last reactant token.
            products_start_index: index of the first product token.
            products_end_index: index of the last product token.

    Raises:
        ValueError: Implicitly via ``min``/``max`` if no ">>" token is present
            (products_dict would be empty); callers check for ">>" first.
    """
    reactants_dict: Dict[int, str] = {}
    products_dict: Dict[int, str] = {}
    atom_tokens_dict: Dict[int, List[int]] = {}
    non_atom_tokens: List[int] = []

    found_reaction_symbol = False
    for i, token in enumerate(tokens[1:-1]):
        if token == ">>":
            found_reaction_symbol = True
            non_atom_tokens.append(i)
            continue
        # Hoist the identity lookup (was repeated four times per token);
        # 0 marks a non-atom token.
        atom_identity = token_atom_identity_dict.get(token, 0)
        if atom_identity == 0:
            non_atom_tokens.append(i)
        else:
            atom_tokens_dict.setdefault(atom_identity, []).append(i)
        if found_reaction_symbol:
            products_dict[i] = token
        else:
            reactants_dict[i] = token

    string_info_dict: StringInfoDict = {
        "reactants_dict": reactants_dict,
        "products_dict": products_dict,
        "reactants_start_index": 0,
        "reactants_end_index": max(reactants_dict.keys()),
        "products_start_index": min(products_dict.keys()),
        "products_end_index": max(products_dict.keys()),
        "atom_tokens_dict": atom_tokens_dict,
        "non_atom_tokens": non_atom_tokens,
    }

    return string_info_dict

map_reaction(rxn_smiles, layer=9, head=0, sequence_max_length=256, adjacent_atom_multiplier=10, identical_adjacent_atom_multiplier=10, one_to_one_correspondence=False, start_from_partial_map=False)

Maps a reaction SMILES string using a pre-trained Albert model.

Parameters:

Name Type Description Default
rxn_smiles str

A reaction SMILES string.

required
layer int

0-based layer index to use for attention.

9
head int

0-based head index to use for attention.

0

Returns:

Type Description
ReactionMapperResult

A ReactionMapperResult containing the mapped reaction SMILES string with atom map numbers assigned, together with the mapping score and additional metadata.

Source code in agave_chem/mappers/neural/neural_mapper.py
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
def map_reaction(
    self,
    rxn_smiles: str,
    layer: int = 9,
    head: int = 0,
    sequence_max_length: int = 256,
    adjacent_atom_multiplier: float = 10,
    identical_adjacent_atom_multiplier: float = 10,
    one_to_one_correspondence: bool = False,
    start_from_partial_map: bool = False,
) -> ReactionMapperResult:
    """
    Maps a reaction SMILES string using a pre-trained Albert model.

    Args:
        rxn_smiles (str): A reaction SMILES string.
        layer (int): 0-based layer index to use for attention.
        head (int): 0-based head index to use for attention.
        sequence_max_length (int): Maximum tokenized sequence length; longer
            inputs are rejected with the default (empty) result.
        adjacent_atom_multiplier (float): Attention boost for atoms adjacent
            to an already-assigned pair (one-to-one mode).
        identical_adjacent_atom_multiplier (float): Extra boost when those
            adjacent atoms have identical encodings.
        one_to_one_correspondence (bool): If True, assign atom maps greedily
            so each reactant atom is used at most once.
        start_from_partial_map (bool): If True, seed the assignment from
            atom-map numbers already present in the input SMILES.

    Returns:
        ReactionMapperResult: The mapping result. On invalid input, unknown
        tokens, a missing ">>" symbol, an overlong sequence, or an invalid
        final mapping, a default result with empty SMILES is returned.
    """
    default_mapping_dict = ReactionMapperResult(
        original_smiles="",
        selected_mapping="",
        possible_mappings={},
        mapping_type=self._mapper_type,
        mapping_score=None,
        additional_info=[{}],
    )
    if not self._reaction_smiles_valid(rxn_smiles):
        return default_mapping_dict

    reactants_atom_idx_to_orig_mapping = None
    products_atom_idx_to_orig_mapping = None
    if start_from_partial_map:
        (
            rxn_smiles,
            reactants_atom_idx_to_orig_mapping,
            products_atom_idx_to_orig_mapping,
        ) = self.get_data_from_partially_mapped_smiles(rxn_smiles)

    attn, tokens = self.get_attention_matrix_for_head(
        text=rxn_smiles,
        layer=layer,
        head=head,
        max_length=sequence_max_length,
        trim_padding=True,
    )

    if "[UNK]" in tokens:
        logger.warning("Unknown token in sequence")
        return default_mapping_dict

    if ">>" not in tokens:
        # BUG FIX: this branch previously logged "Sequence too long"
        # (copy-paste from the length check below).
        logger.warning("Reaction symbol '>>' not found in token sequence")
        return default_mapping_dict

    if len(tokens) >= sequence_max_length:
        logger.warning("Sequence too long")
        return default_mapping_dict

    # Mask chemically impossible pairs, symmetrize the two attention
    # directions, and drop non-atom tokens before assignment.
    string_info_dict = self.get_reactants_products_dict(tokens)
    attn_probs, _ = self.mask_attn_matrix(attn, string_info_dict)
    attn = self.average_attn_scores(
        attn_probs,
        string_info_dict["reactants_start_index"],
        string_info_dict["reactants_end_index"],
        string_info_dict["products_start_index"],
    )

    attn = self.remove_non_atom_rows_and_columns(attn, string_info_dict)

    mapped_rxn_smiles, confidence = self.assign_atom_maps(
        rxn_smiles,
        attn,
        one_to_one_correspondence=one_to_one_correspondence,
        adjacent_atom_multiplier=adjacent_atom_multiplier,
        identical_adjacent_atom_multiplier=identical_adjacent_atom_multiplier,
        reactants_atom_idx_to_orig_mapping=reactants_atom_idx_to_orig_mapping,
        products_atom_idx_to_orig_mapping=products_atom_idx_to_orig_mapping,
    )

    if not self._verify_validity_of_mapping(mapped_rxn_smiles):
        return default_mapping_dict

    return ReactionMapperResult(
        original_smiles=rxn_smiles,
        selected_mapping=mapped_rxn_smiles,
        possible_mappings={},
        mapping_type=self._mapper_type,
        mapping_score=confidence,
        additional_info=[{}],
    )

map_reactions(reaction_list)

Source code in agave_chem/mappers/neural/neural_mapper.py
1035
1036
1037
1038
1039
1040
def map_reactions(self, reaction_list: List[str]) -> List[ReactionMapperResult]:
    """
    Map each reaction SMILES in a list using default mapping settings.

    Args:
        reaction_list: Reaction SMILES strings to map.

    Returns:
        One ReactionMapperResult per input reaction, in input order.
    """
    return [self.map_reaction(rxn_smiles) for rxn_smiles in reaction_list]

mask_attn_matrix(attn, string_info_dict)

Masks the attention matrix to set the attention probability for certain tokens to 0.

Parameters:

Name Type Description Default
attn ndarray

The attention matrix to be masked.

required
string_info_dict StringInfoDict

A dictionary containing token information for the reaction SMILES string: the reactant/product token index ranges, the indices of non-atom tokens, and token indices grouped by atom identity.

required

Returns:

Type Description
Tuple[ndarray, ndarray]

The masked attention matrix.

Source code in agave_chem/mappers/neural/neural_mapper.py
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
def mask_attn_matrix(
    self,
    attn: np.ndarray,
    string_info_dict: StringInfoDict,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Masks the attention matrix so only chemically plausible pairs keep weight.

    Disallowed pairs are: reactant-to-reactant, product-to-product, anything
    involving a non-atom token, and atom pairs of differing atom identity.
    The identical mask is applied twice: first to the logits (large negative
    fill) before a row-wise softmax, then to the resulting probabilities
    (zero fill) so masked entries are exactly 0. The duplicated masking code
    of the original is factored into one helper applied with two fill values.

    Note:
        ``attn`` is modified in place by the logit masking; pass a copy if
        the caller needs the original values.

    Args:
        attn: Square attention-logit matrix over the trimmed token sequence.
        string_info_dict: Token index/identity info for the reaction string
            (index ranges, non-atom token indices, atom-identity groups).

    Returns:
        Tuple of (masked probability matrix, unnormalized exp-logits).
    """

    def _apply_mask(matrix: np.ndarray, fill: float) -> None:
        """Write `fill` into every disallowed (row, col) of `matrix`, in place."""
        reactants_slice = slice(
            string_info_dict["reactants_start_index"],
            string_info_dict["products_start_index"] - 1,
        )
        products_slice = slice(
            string_info_dict["products_start_index"],
            string_info_dict["products_end_index"] + 1,
        )
        # Reactant tokens must not attend to other reactant tokens.
        matrix[reactants_slice, reactants_slice] = fill
        # Product tokens must not attend to other product tokens.
        matrix[products_slice, products_slice] = fill
        # Non-atom tokens (bonds, brackets, ">>") carry no mapping signal.
        for i in string_info_dict["non_atom_tokens"]:
            matrix[i] = fill
            matrix[:, i] = fill
        # Atoms may only be matched to atoms of the same identity.
        for token_indices in string_info_dict["atom_tokens_dict"].values():
            idx = np.asarray(token_indices, dtype=np.int64)
            diff_atom_mask = np.ones(matrix.shape[1], dtype=bool)
            diff_atom_mask[idx] = False
            matrix[np.ix_(idx, diff_atom_mask)] = fill
            matrix[np.ix_(diff_atom_mask, idx)] = fill

    # Mask logits with a very small value so softmax drives them to ~0.
    _apply_mask(attn, -1e6)

    # Numerically stable row-wise softmax (subtract the row max first).
    row_max = np.max(attn, axis=1, keepdims=True)
    exp_logits = np.exp(attn - row_max)
    probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)

    # Re-mask the probabilities so disallowed entries are exactly 0.
    _apply_mask(probs, 0)

    return probs, exp_logits

remove_non_atom_rows_and_columns(attn, string_info_dict)

Remove non-atom tokens from attention matrix.

Parameters:

Name Type Description Default
attn ndarray

The attention matrix.

required
string_info_dict StringInfoDict

A dictionary containing information about the tokens in the reaction SMILES string.

required

Returns:

Type Description
ndarray

np.ndarray: The attention matrix with non-atom tokens removed.

Source code in agave_chem/mappers/neural/neural_mapper.py
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
def remove_non_atom_rows_and_columns(
    self, attn: np.ndarray, string_info_dict: StringInfoDict
) -> np.ndarray:
    """
    Drop non-atom token rows/columns from the product-x-reactant attention matrix.

    Columns (reactants) are addressed by absolute token index; rows (products)
    are addressed relative to the first product token.

    Args:
        attn (np.ndarray): The attention matrix (product rows x reactant columns).
        string_info_dict (StringInfoDict): A dictionary containing information about the tokens in the reaction SMILES string.

    Returns:
        np.ndarray: The attention matrix restricted to atom tokens.
    """
    non_atom = string_info_dict["non_atom_tokens"]
    reactants_end = string_info_dict["reactants_end_index"]
    products_start = string_info_dict["products_start_index"]

    # Non-atom reactant tokens keep their absolute index (column axis).
    reactant_cols = [t for t in non_atom if t <= reactants_end]
    # Non-atom product tokens are shifted so the first product token is row 0.
    product_rows = [t - products_start for t in non_atom if t >= products_start]

    trimmed = np.delete(attn, np.asarray(reactant_cols, dtype=int), axis=1)
    trimmed = np.delete(trimmed, np.asarray(product_rows, dtype=int), axis=0)
    return trimmed

TemplateReactionMapper

Bases: ReactionMapper

Expert template reaction classification and atom-mapping

Source code in agave_chem/mappers/template/template_mapper.py
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
class TemplateReactionMapper(ReactionMapper):
    """
    Expert template reaction classification and atom-mapping
    """

    def __init__(
        self,
        mapper_name: str,
        mapper_weight: float = 3,
        custom_smirks_patterns: List[SmirksPattern] | None = None,
        use_default_smirks_patterns: bool = True,
        max_transforms: int = 1000,
        max_tautomers: int = 1000,
        use_mcs_mapping: bool = True,
    ):
        """
        Initialize the TemplateReactionMapper instance.

        Args:
            mapper_name (str): Human-readable name for this mapper instance.
            mapper_weight (float): Relative weight of this mapper when results
                from multiple mappers are combined. Defaults to 3.
            custom_smirks_patterns (List[SmirksPattern] | None): Optional list
                of dictionaries describing custom SMIRKS patterns. Each
                dictionary must contain 'name', 'smirks', and 'superclass_id'
                keys and may additionally carry 'class_id' and 'subclass_id'.
            use_default_smirks_patterns (bool): Whether to include the packaged
                default SMIRKS patterns. Defaults to True.
            max_transforms (int): Maximum number of transforms for the RDKit
                tautomer enumerator. Defaults to 1000.
            max_tautomers (int): Maximum number of tautomers enumerated per
                fragment. Defaults to 1000.
            use_mcs_mapping (bool): Whether to construct an auxiliary
                MCSReactionMapper. Defaults to True.

        Raises:
            TypeError: If `custom_smirks_patterns` is malformed, or if the
                configuration would leave the mapper with no SMIRKS patterns.
        """

        super().__init__("template", mapper_name, mapper_weight)

        required_keys = {"name", "smirks", "superclass_id"}
        if custom_smirks_patterns is not None:
            if not isinstance(custom_smirks_patterns, list):
                raise TypeError(
                    "Invalid input: custom_smirks_patterns must be a list of dictionaries."
                )
            for pattern in custom_smirks_patterns:
                # Require the mandatory keys but tolerate the optional
                # 'class_id'/'subclass_id' keys, which are consumed below when
                # building the name dictionary. (An exact-keys check used to
                # reject patterns carrying those keys, while patterns without
                # them crashed later with a KeyError.)
                if not required_keys <= set(pattern.keys()):
                    raise TypeError(
                        "Invalid input: each dictionary in custom_smirks_patterns must have 'name', 'smirks', and 'superclass_id' keys."
                    )
                for key, value in pattern.items():
                    if key == "superclass_id":
                        if value is not None and not isinstance(value, int):
                            raise TypeError(
                                "Invalid input: 'superclass_id' value must be an integer or None."
                            )
                    elif key in ("name", "smirks"):
                        if not isinstance(value, str):
                            raise TypeError(
                                "Invalid input: 'name' and 'smirks' values must be strings."
                            )

        default_smirks_patterns: List[SmirksPattern] = []
        if use_default_smirks_patterns:
            # Only read the packaged patterns file when it is actually needed.
            smirks_patterns_file = files("agave_chem.datafiles").joinpath(
                "smirks_patterns.json"
            )
            with smirks_patterns_file.open("r") as f:
                default_smirks_patterns = json.load(f)

        self._smirks_patterns: List[SmirksPattern] = []
        if use_default_smirks_patterns and custom_smirks_patterns is None:
            self._smirks_patterns = default_smirks_patterns
        elif custom_smirks_patterns and not use_default_smirks_patterns:
            self._smirks_patterns = custom_smirks_patterns
        elif custom_smirks_patterns and use_default_smirks_patterns:
            self._smirks_patterns = custom_smirks_patterns + default_smirks_patterns
        else:
            raise TypeError(
                "Attempting to initialize AgaveChem with no SMIRKS patterns"
            )

        # Lookup from (retro-direction) SMIRKS string to naming/classification
        # metadata. Custom patterns may omit class/subclass identifiers, so use
        # .get() rather than direct indexing for those keys.
        self._smirks_name_dictionary: Dict[str, SmirksNameDict] = {
            pattern["smirks"]: {
                "name": pattern["name"],
                "superclass_id": pattern["superclass_id"],
                "class_id": pattern.get("class_id"),
                "subclass_id": pattern.get("subclass_id"),
                "class_str": f"{pattern['superclass_id']}.{pattern.get('class_id')}.{pattern.get('subclass_id')}",
            }
            for pattern in self._smirks_patterns
        }
        self._initialized_smirks_patterns: List[InitializedSmirksPattern] = (
            initialize_template_data(self._smirks_patterns)
        )

        self._tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
        self._tautomer_enumerator.SetMaxTransforms(max_transforms)
        self._tautomer_enumerator.SetMaxTautomers(max_tautomers)

        # Optional MCS-based mapper used alongside template mapping.
        self._mcs_mapper = None
        if use_mcs_mapping:
            self._mcs_mapper = MCSReactionMapper(
                mapper_name="mcs_for_template", mapper_weight=1
            )

    def _validate_smirks_patterns(self, smirks_patterns: List[SmirksPattern]) -> None:
        """Validate SMIRKS patterns.

        NOTE(review): currently a no-op stub — pattern validation is performed
        inline in ``__init__`` instead. Confirm whether this placeholder should
        be implemented (and called) or removed.
        """
        pass

    def _prepare_reaction_data(
        self,
        reactants_str: str,
        products_str: str,
        unmapped_product_atom_islands: Optional[Dict[int, Set[int]]] = None,
    ) -> ReactionData:
        """Prepare reaction mapping inputs from reactant and product SMILES strings.

        Args:
            reactants_str (str): Reactants SMILES string (fragments separated by `.`).
            products_str (str): Products SMILES string (fragments separated by `.`).
            unmapped_product_atom_islands (Optional[Dict[int, Set[int]]]): Mapping
                from island id to the set of product atom indices that are
                intentionally left unmapped. Defaults to None (treated as an
                empty dict).

        Returns:
            ReactionData: Mapping input data containing RDKit molecule objects for
            reactants/products, an RDChiral reactants object derived from
            `products_str`, a tautomer SMILES dictionary, a fragment count dictionary,
            and the normalized `unmapped_product_atom_islands` dict.
        """

        # Normalize the default here rather than in the signature to avoid a
        # shared mutable default argument.
        if unmapped_product_atom_islands is None:
            unmapped_product_atom_islands = {}

        return ReactionData(
            products_mols=[
                Chem.MolFromSmiles(product_str)
                for product_str in products_str.split(".")
            ],
            reactants_mols=[
                Chem.MolFromSmiles(reactant_str)
                for reactant_str in reactants_str.split(".")
            ],
            rdc_products=rdc.rdchiralReactants(products_str),
            tautomers_reactants=self._enumerate_tautomer_smiles(reactants_str),
            fragment_count_reactants=self._get_fragment_count_dict(reactants_str),
            unmapped_product_atom_islands=unmapped_product_atom_islands,
        )

    def _enumerate_tautomer_smiles(self, smiles: str) -> Dict[str, List[str]]:
        """Enumerate tautomer SMILES strings for each fragment of a SMILES string.

        Args:
            smiles (str): A SMILES string representing one or more molecular
                fragments separated by `.`.

        Returns:
            Dict[str, List[str]]: Keys are the input fragments; values are the
            deduplicated, canonicalized tautomer SMILES enumerated for each
            fragment (including the fragment itself). Unparseable fragments map
            to an empty list.
        """
        tautomers_by_fragment: Dict[str, List[str]] = {}
        for fragment_smiles in smiles.split("."):
            fragment_mol = Chem.MolFromSmiles(fragment_smiles)

            if fragment_mol is None:
                tautomers_by_fragment[fragment_smiles] = []
                continue

            candidate_smiles = [
                Chem.MolToSmiles(tautomer_mol)
                for tautomer_mol in self._tautomer_enumerator.Enumerate(fragment_mol)
            ]
            # The input fragment itself is always a candidate.
            candidate_smiles.append(fragment_smiles)
            canonical_candidates = {
                canonicalize_smiles(candidate, canonicalize_tautomer=False)
                for candidate in candidate_smiles
                if candidate
            }
            tautomers_by_fragment[fragment_smiles] = list(canonical_candidates)
        return tautomers_by_fragment

    def _get_fragment_count_dict(self, smiles: str) -> Dict[str, int]:
        """
        Build a fragment-count mapping from a dot-delimited SMILES string.

        Args:
            smiles (str): A SMILES string where disconnected fragments are delimited by `.`.

        Returns:
            Dict[str, int]: A mapping of each fragment string to the number of times it
            appears in `smiles`.
        """
        fragment_count_dict = {}
        for fragment_str in smiles.split("."):
            if fragment_str not in fragment_count_dict:
                fragment_count_dict[fragment_str] = 1
            else:
                fragment_count_dict[fragment_str] += 1

        return fragment_count_dict

    def _process_templates(
        self, reaction_smiles_data: ReactionData
    ) -> Dict[str, InitializedSmirksPattern]:
        """
        Apply template SMIRKS patterns to a reaction and collect mapped outcomes.

        Args:
            reaction_smiles_data (ReactionData): Reaction data containing reactants,
                products, and precomputed helper structures used for template
                application.

        Returns:
            Dict[str, InitializedSmirksPattern]: Mapping from finalized mapped
            reaction SMILES to the SMIRKS pattern that produced it.
        """
        # The mapped product must be computed before templates are applied:
        # _get_mapped_product annotates atom-map numbers on the rdchiral
        # products molecule.
        atom_mapped_product = self._get_mapped_product(reaction_smiles_data)

        collected_outcomes: Dict[str, InitializedSmirksPattern] = {}
        for applied_smirk_data in self._apply_templates(reaction_smiles_data):
            collected_outcomes.update(
                self._process_single_outcome(
                    applied_smirk_data,
                    reaction_smiles_data,
                    atom_mapped_product,
                )
            )
        return collected_outcomes

    def _get_mapped_product(self, reaction_smiles_data: ReactionData) -> str:
        """
        Generate an atom-mapped product SMILES string from the rdchiral products.

        Args:
            reaction_smiles_data (ReactionData): Reaction data containing
                `rdc_products`; its molecule is annotated in place with atom-map
                numbers derived from rdchiral's index-to-mapnum table.

        Returns:
            str: SMILES for the products molecule with atom-map numbers applied.
        """
        rdc_products = reaction_smiles_data["rdc_products"]

        # rdchiral names this attribute "reactants" because it operates from a
        # retrosynthetic perspective; here it holds the product molecule.
        product_mol = rdc_products.reactants
        for atom in product_mol.GetAtoms():
            atom.SetAtomMapNum(rdc_products.idx_to_mapnum(atom.GetIdx()))

        return Chem.MolToSmiles(product_mol)

    def _apply_templates(
        self,
        reaction_smiles_data: ReactionData,
        num_smirks_applied: int = 0,
        apply_multiple_smirks: bool = True,
        num_smirks_to_apply: int = 1,
    ) -> List[AppliedSmirkData]:
        """
        Run all initialized SMIRKS templates against the reaction and collect outcomes.

        Each template is pre-screened with substructure matches against both the
        product and reactant molecules before the more expensive rdchiral run.
        When unmapped product-atom islands are provided, a template is kept only
        if each of its product SMARTS fragments overlaps at least one island, and
        each outcome is recorded once per island it touches.

        Args:
            reaction_smiles_data (ReactionData): Prepared reaction data (product
                and reactant molecules, rdchiral products object, and the
                unmapped product-atom islands dict).
            num_smirks_applied (int): Number of SMIRKS applications already made
                on the path leading to this call; recorded (+1) on each outcome.
            apply_multiple_smirks (bool): Whether multiple SMIRKS applications
                are allowed. Recursive application is not yet implemented (see
                TODO below), so this currently does not change the result.
            num_smirks_to_apply (int): Target number of SMIRKS applications for
                the planned recursive mode.

        Returns:
            List[AppliedSmirkData]: One entry per (outcome, template[, island])
            combination produced by the applicable templates.
        """
        product_mols = reaction_smiles_data["products_mols"]
        reactant_mols = reaction_smiles_data["reactants_mols"]
        rdc_products = reaction_smiles_data["rdc_products"]
        unmapped_product_atom_islands = reaction_smiles_data[
            "unmapped_product_atom_islands"
        ]

        # rdchiral uses 1-based indexing, but rdkit uses 0-based indexing
        # so we need another dictionary specifically for rdchiral
        unmapped_product_atom_islands_for_rdchiral = {
            key: {ele + 1 for ele in value}
            for key, value in unmapped_product_atom_islands.items()
        }

        outcomes_and_applied_smirks = []

        for template in self._initialized_smirks_patterns:
            products_smarts = template["products_smarts"]
            reactants_smarts = template["reactants_smarts"]
            rdc_rxn = template["rdc_rxn"]

            # Cheap pre-filter: every product SMARTS fragment must match at
            # least one product molecule.
            product_mol_has_substruct_match = all(
                any(
                    product_mol.HasSubstructMatch(products_smarts_fragment)
                    for product_mol in product_mols
                )
                for products_smarts_fragment in products_smarts
            )

            if not product_mol_has_substruct_match:
                continue

            # Same pre-filter on the reactant side.
            reactant_mol_has_substruct_match = all(
                any(
                    reactant_mol.HasSubstructMatch(reactant_smarts_fragment)
                    for reactant_mol in reactant_mols
                )
                for reactant_smarts_fragment in reactants_smarts
            )

            if not reactant_mol_has_substruct_match:
                continue

            if not unmapped_product_atom_islands:
                template_applies_to_unmapped_product_atoms = True
            else:

                def _fragment_fits_some_island(
                    products_smarts_fragment: Chem.Mol,
                ) -> bool:
                    """
                    Check whether a SMARTS fragment has any overlap with an unmapped island.
                    We can't do a full subset check because the SMARTS for the template
                    may be overspecified, and include atoms that aren't actually changing in the
                    reaction, and thus *are* mapped with the MCS mapper.

                    Args:
                        products_smarts_fragment (Chem.Mol): SMARTS query fragment used to find
                            substructure matches in the product molecules.

                    Returns:
                        bool: True if any substructure match has any overlap with an unmapped island;
                        otherwise False.
                    """

                    for mol in product_mols:
                        matches = mol.GetSubstructMatches(products_smarts_fragment)
                        matches_set = [set(match) for match in matches]
                        for match_set in matches_set:
                            for island in unmapped_product_atom_islands.values():
                                if match_set & island:
                                    return True
                    return False

                template_applies_to_unmapped_product_atoms = all(
                    _fragment_fits_some_island(products_smarts_fragment)
                    for products_smarts_fragment in products_smarts
                )

            if not template_applies_to_unmapped_product_atoms:
                continue

            try:
                _, outcomes_dict = rdc.rdchiralRun(
                    rdc_rxn, rdc_products, return_mapped=True
                )

                if unmapped_product_atom_islands:

                    def _matching_island_ids(v: Tuple[str, List[int]]) -> List[int]:
                        """
                        Find island IDs that contain any atom map indices of a given outcome.
                        Can't do a full subset check for similar reasons as in
                        _fragment_fits_some_island.

                        Args:
                            v (Tuple[str, List[int]]): A tuple where the second element is a list
                                of atom map indices for a reaction outcome.

                        Returns:
                            List[int]: A list of island IDs where any atom map indices of the
                                outcome are contained.
                        """
                        return [
                            island_id
                            for island_id, island in unmapped_product_atom_islands_for_rdchiral.items()
                            if set(v[1]) & island
                        ]

                    # Record each outcome once per island it overlaps.
                    for k, v in outcomes_dict.items():
                        matching_island_ids = _matching_island_ids(v)
                        if not matching_island_ids:
                            continue

                        for matching_island_id in matching_island_ids:
                            outcomes_and_applied_smirks.append(
                                AppliedSmirkData(
                                    outcome_unmapped_smiles=k,
                                    outcome_mapped_smiles=v[0],
                                    outcome_atom_map_indices=v[1],
                                    applied_smirk=template,
                                    outcome_to_island_id=matching_island_id,
                                    num_smirks_applied=num_smirks_applied + 1,
                                )
                            )

                else:
                    for k, v in outcomes_dict.items():
                        outcomes_and_applied_smirks.append(
                            AppliedSmirkData(
                                outcome_unmapped_smiles=k,
                                outcome_mapped_smiles=v[0],
                                outcome_atom_map_indices=v[1],
                                applied_smirk=template,
                                outcome_to_island_id=None,
                                num_smirks_applied=num_smirks_applied + 1,
                            )
                        )

            except Exception as e:
                # Best-effort: a template that fails to apply is skipped, not fatal.
                logger.warning(f"Error applying templates: {e}")
                pass

        if not apply_multiple_smirks:
            return outcomes_and_applied_smirks

        ## TODO: Implement recursive application of SMIRKS for multiple applications
        # for outcome_and_applied_smirk in outcomes_and_applied_smirks:
        #     if outcome_and_applied_smirk.num_smirks_applied >= num_smirks_to_apply:
        #         continue
        #     recursively_applied_smirks_and_outcomes = self._apply_templates(
        #         reaction_smiles_data=outcome_and_applied_smirk,
        #         num_smirks_applied=outcome_and_applied_smirk.num_smirks_applied,
        #         apply_multiple_smirks=apply_multiple_smirks,
        #         num_smirks_to_apply=num_smirks_to_apply,
        #     )

        #     outcomes_and_applied_smirks.extend(recursively_applied_smirks_and_outcomes)

        return outcomes_and_applied_smirks

    def _remove_spectator_mappings(self, smiles: str) -> str:
        """Clear spectator atom-map numbers (>= 900) from a SMILES string.

        Map numbers of 900 and above mark spectator fragments; they are reset
        to 0 on every parseable fragment. Unparseable fragments are dropped.

        Args:
            smiles (str): A SMILES string, potentially containing multiple
                fragments separated by `.`.

        Returns:
            str: A SMILES string with spectator atom mappings removed.
        """
        cleaned_fragments = []
        for fragment_smiles in smiles.split("."):
            fragment_mol = Chem.MolFromSmiles(fragment_smiles)
            if fragment_mol is None:
                continue
            for atom in fragment_mol.GetAtoms():
                if atom.GetAtomMapNum() >= 900:
                    atom.SetAtomMapNum(0)
            cleaned_fragments.append(Chem.MolToSmiles(fragment_mol))
        return ".".join(cleaned_fragments)

    def _process_single_outcome(
        self,
        outcome_and_applied_smirk: AppliedSmirkData,
        reaction_smiles_data: ReactionData,
        atom_mapped_product: str,
    ) -> Dict[str, InitializedSmirksPattern]:
        """Finalize one template outcome into a full mapped reaction SMILES.

        Resolves fragments of the outcome that do not correspond to any known
        reactant tautomer, re-adds spectator fragments to balance the fragment
        counts, and joins everything with the mapped product.

        Args:
            outcome_and_applied_smirk (AppliedSmirkData): One outcome from
                `_apply_templates` plus the template that produced it.
            reaction_smiles_data (ReactionData): Prepared reaction data.
            atom_mapped_product (str): Atom-mapped product SMILES.

        Returns:
            Dict[str, InitializedSmirksPattern]: `{finalized_reaction_smiles:
            applied_smirk}`, or an empty dict if the outcome cannot be resolved.
        """
        mapped_outcome = outcome_and_applied_smirk.get("outcome_mapped_smiles")
        applied_smirk = outcome_and_applied_smirk.get("applied_smirk")

        tautomers_reactants = reaction_smiles_data["tautomers_reactants"]
        fragment_count_dict = reaction_smiles_data["fragment_count_reactants"]

        mapped_outcome = self._remove_spectator_mappings(mapped_outcome)

        missing_fragments, found_fragments = self._find_missing_fragments(
            mapped_outcome, tautomers_reactants
        )

        if missing_fragments:
            fragment_mapped_dict = self._handle_missing_fragments(
                missing_fragments,
                found_fragments,
                tautomers_reactants,
            )

            # Every missing fragment must be resolved, or the outcome is discarded.
            if len(fragment_mapped_dict) != len(missing_fragments):
                return {}

            for fragment, mapped_replacement in fragment_mapped_dict.items():
                if fragment not in mapped_outcome:
                    return {}
                mapped_outcome = mapped_outcome.replace(fragment, mapped_replacement)

        canonical_outcome_fragments = [
            canonicalize_smiles(fragment, canonicalize_tautomer=False)
            for fragment in mapped_outcome.split(".")
        ]

        # Re-add reactant fragments the templates consumed no atoms from.
        spectators: List[str] = []
        for fragment, expected_count in fragment_count_dict.items():
            canonical_fragment = canonicalize_smiles(
                fragment, canonicalize_tautomer=False
            )
            deficit = expected_count - canonical_outcome_fragments.count(
                canonical_fragment
            )
            if deficit > 0:
                spectators.extend([canonical_fragment] * deficit)

        all_reactants = mapped_outcome.split(".") + spectators
        finalized_reaction_smiles = (
            ".".join(all_reactants) + ">>" + atom_mapped_product
        )

        return {finalized_reaction_smiles: applied_smirk}

    def _find_missing_fragments(
        self, mapped_outcome: str, unmapped_reactants: Dict[str, List[str]]
    ) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
        """Partition outcome fragments by whether they match a known reactant tautomer.

        Args:
            mapped_outcome (str): Mapped outcome SMILES (fragments joined by `.`).
            unmapped_reactants (Dict[str, List[str]]): Reactant fragment ->
                enumerated canonical tautomer SMILES.

        Returns:
            Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]: `(missing, found)`
            lists of `(canonical_unmapped, mapped)` fragment pairs.
        """
        known_fragments = {
            tautomer
            for tautomer_list in unmapped_reactants.values()
            for tautomer in tautomer_list
        }

        missing_fragments: List[Tuple[str, str]] = []
        found_fragments: List[Tuple[str, str]] = []
        for mapped_fragment in mapped_outcome.split("."):
            unmapped_fragment = canonicalize_smiles(
                mapped_fragment, canonicalize_tautomer=False
            )
            bucket = (
                found_fragments
                if unmapped_fragment in known_fragments
                else missing_fragments
            )
            bucket.append((unmapped_fragment, mapped_fragment))

        return missing_fragments, found_fragments

    def _handle_missing_fragments(
        self,
        missing_fragments: List[Tuple[str, str]],
        found_fragments: List[Tuple[str, str]],
        unmapped_reactants: Dict[str, List[str]],
    ) -> Dict[str, str]:
        """Resolve missing outcome fragments to uniquely mapped reactant fragments.

        Args:
            missing_fragments (List[Tuple[str, str]]): `(canonical, mapped)` pairs
                not found among the reactant tautomers.
            found_fragments (List[Tuple[str, str]]): `(canonical, mapped)` pairs
                already accounted for.
            unmapped_reactants (Dict[str, List[str]]): Reactant fragment ->
                enumerated tautomer SMILES.

        Returns:
            Dict[str, str]: Mapped fragment -> uniquely resolved replacement, or
            an empty dict if any fragment is unresolvable or ambiguous.
        """
        if not self._are_fragments_substructures(
            missing_fragments, found_fragments, unmapped_reactants
        ):
            return {}

        candidate_mappings = self._identify_and_map_fragments(
            missing_fragments,
            found_fragments,
            unmapped_reactants,
        )

        resolved: Dict[str, str] = {}
        for mapped_fragment, options in candidate_mappings.items():
            if len(options) > 1:
                # Ambiguity means we cannot trust any single assignment.
                logger.warning(
                    "Multiple possible fragments identified for reaction SMARTS substructure"
                )
                return {}
            resolved[mapped_fragment] = options[0]

        return resolved

    def _are_fragments_substructures(
        self,
        missing_fragments: List[Tuple[str, str]],
        found_fragments: List[Tuple[str, str]],
        unmapped_reactants: Dict[str, List[str]],
    ) -> bool:
        """Check that each wildcard-bearing missing fragment matches some reactant.

        Only fragments containing `*` are checked; each must be a substructure
        of at least one reactant tautomer that is not already accounted for.

        Args:
            missing_fragments (List[Tuple[str, str]]): `(canonical, mapped)` pairs
                to check.
            found_fragments (List[Tuple[str, str]]): Pairs already accounted for;
                their canonical forms are excluded as match targets.
            unmapped_reactants (Dict[str, List[str]]): Reactant fragment ->
                enumerated tautomer SMILES.

        Returns:
            bool: True if every wildcard fragment finds a match (and all inputs
            parse); otherwise False.
        """
        already_found = {canonical for canonical, _ in found_fragments}

        for query_smarts, _ in missing_fragments:
            if "*" not in query_smarts:
                continue

            query_mol = Chem.MolFromSmarts(query_smarts)
            if not query_mol:
                return False

            matched = False
            for tautomer_group in unmapped_reactants.values():
                for candidate_smarts in tautomer_group:
                    if candidate_smarts in already_found:
                        continue
                    candidate_mol = Chem.MolFromSmarts(candidate_smarts)
                    if not candidate_mol:
                        return False
                    candidate_mol.UpdatePropertyCache()
                    if candidate_mol.HasSubstructMatch(query_mol):
                        matched = True
                        break

                if matched:
                    break

            if not matched:
                return False

        return True

    def _identify_and_map_fragments(
        self,
        missing_fragments: List[Tuple[str, str]],
        found_fragments: List[Tuple[str, str]],
        unmapped_reactants: Dict[str, List[str]],
    ) -> Dict[str, List[str]]:
        """Transfer atom maps from missing fragments onto matching reactant tautomers.

        Args:
            missing_fragments (List[Tuple[str, str]]): `(canonical, mapped)` pairs
                that need a reactant assignment.
            found_fragments (List[Tuple[str, str]]): Pairs already accounted for;
                their canonical forms are skipped as transfer targets.
            unmapped_reactants (Dict[str, List[str]]): Reactant fragment ->
                enumerated tautomer SMILES.

        Returns:
            Dict[str, List[str]]: Mapped fragment -> deduplicated candidate
            mapped SMILES, or an empty dict if any fragment has no candidate.
        """
        already_found = {canonical for canonical, _ in found_fragments}
        mapped_candidates: Dict[str, List[str]] = {}

        for _, mapped_fragment in missing_fragments:
            any_transfer_succeeded = False
            for tautomer_list in unmapped_reactants.values():
                for tautomer_smiles in tautomer_list:
                    if tautomer_smiles in already_found:
                        continue

                    transferred = self._transfer_mapping(
                        mapped_fragment, tautomer_smiles
                    )
                    if not transferred:
                        continue

                    any_transfer_succeeded = True

                    if mapped_fragment in mapped_candidates:
                        # Merge, dedupe, and keep a stable sorted order.
                        merged = set(mapped_candidates[mapped_fragment])
                        merged.add(transferred)
                        mapped_candidates[mapped_fragment] = sorted(merged)
                    else:
                        mapped_candidates[mapped_fragment] = [transferred]

            if not any_transfer_succeeded:
                return {}

        return mapped_candidates

    def _transfer_mapping(
        self, mapped_substructure_smarts: str, full_molecule_smiles: str
    ) -> str | None:
        """Copy atom-map numbers from a mapped SMARTS onto a full molecule.

        The transfer is only performed when the substructure match is
        unambiguous: either exactly one match exists, or all matches are
        equivalent under the molecule's atom symmetry classes. Spectator map
        numbers (>= 900) are not transferred.

        Args:
            mapped_substructure_smarts (str): Atom-mapped SMARTS query.
            full_molecule_smiles (str): Target molecule SMILES.

        Returns:
            str | None: Mapped SMILES of the target, or None if parsing fails
            or the match is absent/ambiguous.
        """
        query = Chem.MolFromSmarts(mapped_substructure_smarts)
        if not query:
            return None

        target = Chem.MolFromSmiles(full_molecule_smiles)
        if not target:
            return None

        matches = target.GetSubstructMatches(query)
        if not matches:
            return None

        # Canonical ranks without tie-breaking give symmetry classes: atoms in
        # the same class are interchangeable, so symmetric matches are equivalent.
        symmetry_rank = list(Chem.rdmolfiles.CanonicalRankAtoms(target, breakTies=False))
        symmetric = all(
            symmetry_rank[a] == symmetry_rank[b]
            for first in matches
            for second in matches
            for a, b in zip(first, second)
        )

        if len(matches) != 1 and not symmetric:
            return None

        chosen_match = matches[0]

        # Wipe any pre-existing map numbers on the target.
        for atom in target.GetAtoms():
            if atom.GetAtomMapNum() != 0:
                atom.SetAtomMapNum(0)

        # Transfer non-spectator map numbers from query atoms to their matches.
        for query_atom in query.GetAtoms():
            map_num = query_atom.GetAtomMapNum()
            if 0 < map_num < 900:
                target_atom = target.GetAtomWithIdx(chosen_match[query_atom.GetIdx()])
                target_atom.SetAtomMapNum(map_num)

        return Chem.MolToSmiles(target)

    def _get_unmapped_product_atom_islands(self, smiles: str) -> Dict[int, Set[int]]:
        """
        Find connected components ("islands") of unmapped atoms in a product SMILES.

        Args:
            smiles (str): Product SMILES string to analyze.

        Returns:
            Dict[int, Set[int]]: Mapping from island index (0..N-1) to a set of RDKit
            atom indices belonging to that connected component, considering only atoms
            with atom map number equal to 0. Islands are numbered deterministically
            in order of their lowest atom index.

        Raises:
            ValueError: If the SMILES cannot be parsed into an RDKit molecule.
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Could not parse SMILES: {smiles}")

        unmapped = {
            atom.GetIdx() for atom in mol.GetAtoms() if atom.GetAtomMapNum() == 0
        }

        visited: Set[int] = set()
        islands: Dict[int, Set[int]] = {}

        # Iterate seeds in sorted order so island IDs are deterministic and do
        # not depend on set iteration order.
        for start_idx in sorted(unmapped):
            if start_idx in visited:
                continue

            # BFS over bonds restricted to unmapped atoms.
            island: Set[int] = set()
            queue = deque([start_idx])
            visited.add(start_idx)

            while queue:
                current = queue.popleft()
                island.add(current)

                for neighbor in mol.GetAtomWithIdx(current).GetNeighbors():
                    n_idx = neighbor.GetIdx()
                    if n_idx in unmapped and n_idx not in visited:
                        visited.add(n_idx)
                        queue.append(n_idx)

            islands[len(islands)] = island

        return islands

    def _select_preferred_mapping(
        self, possible_outcomes: Dict[str, List[InitializedSmirksPattern]]
    ) -> str:
        """
        Select the preferred mapping from the candidate mappings.

        The candidate whose template covers the most SMARTS atoms (plus one per
        reactant fragment) wins; ties keep the earlier candidate.

        Args:
            possible_outcomes (Dict[str, List[InitializedSmirksPattern]]): Dictionary of possible mappings.

        Returns:
            str: The selected preferred mapping (empty string if no candidates).
        """
        best_mapping = ""
        best_score = 0
        for candidate_smiles, candidate_patterns in possible_outcomes.items():
            for candidate in candidate_patterns:
                score = len(candidate["reactants_smarts"])
                score += sum(
                    len(mol.GetAtoms()) for mol in candidate["reactants_smarts"]
                )
                score += sum(
                    len(mol.GetAtoms()) for mol in candidate["products_smarts"]
                )
                if score > best_score:
                    best_mapping = candidate_smiles
                    best_score = score
        return best_mapping

    def _post_process_mapped_outcomes(
        self,
        mapped_outcomes_smirks_dict: Dict[str, InitializedSmirksPattern],
        canonicalized_input_smiles: str,
        original_smiles: str,
    ) -> Optional[ReactionMapperResult]:
        """
        Filter, validate, and deduplicate mapped reaction outcomes.

        Args:
            mapped_outcomes_smirks_dict (Dict[str, InitializedSmirksPattern]): Map of
                candidate mapped reaction SMIRKS/SMILES strings to their originating
                initialized pattern metadata.
            canonicalized_input_smiles (str): Canonicalized representation of the
                input reaction (used to discard outcomes that do not match).
            original_smiles (str): Original (non-canonicalized) input reaction string
                to return in the result.

        Returns:
            Optional[ReactionMapperResult]: The selected mapping plus all surviving
                candidates, or None when no valid, matching outcome remains. When
                several distinct mappings survive, the preferred one is chosen via
                `_select_preferred_mapping`.
        """
        deduplicated_mapped_outcomes: Dict[str, List[InitializedSmirksPattern]] = {}
        for k, v in mapped_outcomes_smirks_dict.items():
            # Skip empty candidates *before* the expensive canonicalization calls.
            if not k:
                continue
            canonicalized_k = canonicalize_atom_mapping(
                canonicalize_reaction_smiles(
                    k, canonicalize_tautomer=False, remove_mapping=False
                )
            )
            # The unmapped form of the candidate must reproduce the input reaction.
            if (
                canonicalize_reaction_smiles(
                    canonicalized_k, canonicalize_tautomer=False
                )
                != canonicalized_input_smiles
            ):
                continue
            if not self._verify_validity_of_mapping(canonicalized_k):
                continue

            # parent_smirks is stored retro-direction; swap halves to get the
            # forward SMIRKS used as the name-dictionary key.
            smirks_halves = v["parent_smirks"].split(">>")
            applied_smirk_forward = smirks_halves[1] + ">>" + smirks_halves[0]

            v["template_name"] = self._smirks_name_dictionary[applied_smirk_forward][
                "name"
            ]

            deduplicated_mapped_outcomes.setdefault(canonicalized_k, []).append(v)

        if not deduplicated_mapped_outcomes:
            return None

        if len(deduplicated_mapped_outcomes) > 1:
            logger.warning("Multiple possible mappings")
            selected_mapping = self._select_preferred_mapping(
                deduplicated_mapped_outcomes
            )
        else:
            selected_mapping = next(iter(deduplicated_mapped_outcomes))

        return ReactionMapperResult(
            original_smiles=original_smiles,
            selected_mapping=selected_mapping,
            possible_mappings=deduplicated_mapped_outcomes,
            mapping_type=self._mapper_type,
            mapping_score=None,
            additional_info=[],
        )

    def map_reaction(
        self, reaction_smiles: str, return_mcs_result: bool = False
    ) -> ReactionMapperResult | Tuple[ReactionMapperResult, ReactionMapperResult]:
        """
        Map a reaction SMILES string using template-based atom mapping.

        Args:
            reaction_smiles (str): Reaction SMILES to map.
            return_mcs_result (bool): Whether to also return the MCS mapping result.

        Returns:
            ReactionMapperResult: A mapping result containing the selected mapping and
            related metadata. If the input is invalid or no unique valid mapping can be
            produced, an "empty" default result is returned.
            If return_mcs_result is True, a tuple of
            (ReactionMapperResult, ReactionMapperResult) is returned, where the second
            element is the MCS mapping result.
        """
        # "Empty" result returned on invalid input or when no unique mapping exists.
        default_mapping_dict = ReactionMapperResult(
            original_smiles="",
            selected_mapping="",
            possible_mappings={},
            mapping_type=self._mapper_type,
            mapping_score=None,
            additional_info=[{}],
        )

        if not self._reaction_smiles_valid(reaction_smiles):
            if return_mcs_result:
                return default_mapping_dict, default_mapping_dict
            return default_mapping_dict

        canonicalized_reaction_smiles = canonicalize_reaction_smiles(
            reaction_smiles, canonicalize_tautomer=True
        )

        # BUG FIX: mcs_result was previously assigned only inside the
        # `self._mcs_mapper is not None` branch, so return_mcs_result=True with
        # MCS mapping disabled raised NameError at the return statements below.
        mcs_result = default_mapping_dict
        unmapped_product_atom_islands = {}
        if self._mcs_mapper is not None:
            mcs_result = self._mcs_mapper.map_reaction(canonicalized_reaction_smiles)

            if mcs_result["selected_mapping"] != "":
                unmapped_product_atom_islands = self._get_unmapped_product_atom_islands(
                    mcs_result["selected_mapping"].split(">>")[1]
                )

        reactants_str, products_str = self._split_reaction_components(
            canonicalized_reaction_smiles
        )

        reaction_data = self._prepare_reaction_data(
            reactants_str, products_str, unmapped_product_atom_islands
        )

        mapped_outcomes_smirks_dict = self._process_templates(
            reaction_data,
        )

        result = self._post_process_mapped_outcomes(
            mapped_outcomes_smirks_dict, canonicalized_reaction_smiles, reaction_smiles
        )
        if not result:
            if return_mcs_result:
                return default_mapping_dict, mcs_result
            return default_mapping_dict

        if return_mcs_result:
            return result, mcs_result

        return result

    def map_reactions(
        self, reaction_list: List[str], return_mcs_results: bool = False
    ) -> (
        List[ReactionMapperResult]
        | List[Tuple[ReactionMapperResult, ReactionMapperResult]]
    ):
        """
        Map a list of reaction SMILES strings using this mapper.

        Args:
            reaction_list (List[str]): Reaction SMILES strings to map.
            return_mcs_results (bool): Whether to return MCS mapping results.

        Returns:
            List[ReactionMapperResult]: Mapping results in the same order as the
            input; when return_mcs_results is True, each element is instead a
            (mapping result, MCS result) tuple.
        """
        # Delegate each reaction to map_reaction, preserving input order.
        return [
            self.map_reaction(smiles, return_mcs_results) for smiles in reaction_list
        ]

__init__(mapper_name, mapper_weight=3, custom_smirks_patterns=None, use_default_smirks_patterns=True, max_transforms=1000, max_tautomers=1000, use_mcs_mapping=True)

Initialize the TemplateMapper instance.

Parameters:

Name Type Description Default
custom_smirks_patterns List[Dict]

A list of dictionaries containing custom SMIRKS patterns. Each dictionary should have a 'name' key, a 'smirks' key, and a 'superclass_id' key.

None
use_default_smirks_patterns bool

Whether to use the default SMIRKS patterns.

True
Source code in agave_chem/mappers/template/template_mapper.py
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
def __init__(
    self,
    mapper_name: str,
    mapper_weight: float = 3,
    custom_smirks_patterns: List[SmirksPattern] | None = None,
    use_default_smirks_patterns: bool = True,
    max_transforms: int = 1000,
    max_tautomers: int = 1000,
    use_mcs_mapping: bool = True,
):
    """
    Initialize the TemplateMapper instance.

    Args:
        mapper_name (str): Human-readable name for this mapper instance.
        mapper_weight (float): Weight used when combining results across mappers.
        custom_smirks_patterns (List[Dict]): A list of dictionaries containing
            custom SMIRKS patterns. Each dictionary should have a 'name' key,
            a 'smirks' key, and a 'superclass_id' key.
        use_default_smirks_patterns (bool): Whether to use the default SMIRKS
            patterns.
        max_transforms (int): Maximum tautomer transforms for the enumerator.
        max_tautomers (int): Maximum tautomers for the enumerator.
        use_mcs_mapping (bool): Whether to create an internal MCS mapper used to
            pre-compute unmapped product atom islands.

    Raises:
        TypeError: On malformed custom_smirks_patterns, or when both pattern
            sources are disabled/absent.
    """

    super().__init__("template", mapper_name, mapper_weight)

    if custom_smirks_patterns is not None:
        if not isinstance(custom_smirks_patterns, list):
            raise TypeError(
                "Invalid input: custom_smirks_patterns must be a list of dictionaries."
            )
        for pattern in custom_smirks_patterns:
            if set(pattern.keys()) != set(["name", "smirks", "superclass_id"]):
                raise TypeError(
                    "Invalid input: each dictionary in custom_smirks_patterns must have 'name', 'smirks', and 'superclass_id' keys."
                )
            for key, value in pattern.items():
                if key == "superclass_id":
                    if value is not None and not isinstance(value, int):
                        raise TypeError(
                            "Invalid input: 'superclass_id' value must be an integer or None."
                        )
                else:
                    if not isinstance(value, str):
                        raise TypeError(
                            "Invalid input: 'name' and 'smirks' values must be strings."
                        )

    SMIRKS_PATTERNS_FILE = files("agave_chem.datafiles").joinpath(
        "smirks_patterns.json"
    )
    default_smirks_patterns = []
    with SMIRKS_PATTERNS_FILE.open("r") as f:
        default_smirks_patterns = json.load(f)

    self._smirks_patterns: List[SmirksPattern] = []
    if use_default_smirks_patterns and custom_smirks_patterns is None:
        self._smirks_patterns = default_smirks_patterns
    elif custom_smirks_patterns and not use_default_smirks_patterns:
        self._smirks_patterns = custom_smirks_patterns
    elif custom_smirks_patterns and use_default_smirks_patterns:
        self._smirks_patterns = custom_smirks_patterns + default_smirks_patterns
    else:
        raise TypeError(
            "Attempting to initialize AgaveChem with no SMIRKS patterns"
        )

    # BUG FIX: custom patterns are validated above to contain ONLY the keys
    # 'name', 'smirks', and 'superclass_id', so indexing pattern["class_id"] /
    # pattern["subclass_id"] raised KeyError for every custom pattern. Use
    # .get() so missing class/subclass ids default to None instead.
    self._smirks_name_dictionary: Dict[str, SmirksNameDict] = {
        pattern["smirks"]: {
            "name": pattern["name"],
            "superclass_id": pattern["superclass_id"],
            "class_id": pattern.get("class_id"),
            "subclass_id": pattern.get("subclass_id"),
            "class_str": f"{pattern['superclass_id']}.{pattern.get('class_id')}.{pattern.get('subclass_id')}",
        }
        for pattern in self._smirks_patterns
    }
    self._initialized_smirks_patterns: List[InitializedSmirksPattern] = (
        initialize_template_data(self._smirks_patterns)
    )

    # Tautomer enumerator shared by the mapping pipeline; limits bound runtime.
    self._tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
    self._tautomer_enumerator.SetMaxTransforms(max_transforms)
    self._tautomer_enumerator.SetMaxTautomers(max_tautomers)

    self._mcs_mapper = None
    if use_mcs_mapping:
        self._mcs_mapper = MCSReactionMapper(
            mapper_name="mcs_for_template", mapper_weight=1
        )

map_reaction(reaction_smiles, return_mcs_result=False)

Map a reaction SMILES string using template-based atom mapping.

Parameters:

Name Type Description Default
reaction_smiles str

Reaction SMILES to map.

required
return_mcs_result bool

Whether to return the MCS mapping result.

False

Returns:

Name Type Description
ReactionMapperResult ReactionMapperResult | Tuple[ReactionMapperResult, ReactionMapperResult]

A mapping result containing the selected mapping and

ReactionMapperResult | Tuple[ReactionMapperResult, ReactionMapperResult]

related metadata. If the input is invalid or no unique valid mapping can be

ReactionMapperResult | Tuple[ReactionMapperResult, ReactionMapperResult]

produced, an "empty" default result is returned.

ReactionMapperResult | Tuple[ReactionMapperResult, ReactionMapperResult]

If return_mcs_result is True, a tuple of (ReactionMapperResult, ReactionMapperResult) is returned,

ReactionMapperResult | Tuple[ReactionMapperResult, ReactionMapperResult]

where the second element is the MCS mapping result.

Source code in agave_chem/mappers/template/template_mapper.py
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
def map_reaction(
    self, reaction_smiles: str, return_mcs_result: bool = False
) -> ReactionMapperResult | Tuple[ReactionMapperResult, ReactionMapperResult]:
    """
    Map a reaction SMILES string using template-based atom mapping.

    Args:
        reaction_smiles (str): Reaction SMILES to map.
        return_mcs_result (bool): Whether to also return the MCS mapping result.

    Returns:
        ReactionMapperResult: A mapping result containing the selected mapping and
        related metadata. If the input is invalid or no unique valid mapping can be
        produced, an "empty" default result is returned.
        If return_mcs_result is True, a tuple of
        (ReactionMapperResult, ReactionMapperResult) is returned, where the second
        element is the MCS mapping result.
    """
    # "Empty" result returned on invalid input or when no unique mapping exists.
    default_mapping_dict = ReactionMapperResult(
        original_smiles="",
        selected_mapping="",
        possible_mappings={},
        mapping_type=self._mapper_type,
        mapping_score=None,
        additional_info=[{}],
    )

    if not self._reaction_smiles_valid(reaction_smiles):
        if return_mcs_result:
            return default_mapping_dict, default_mapping_dict
        return default_mapping_dict

    canonicalized_reaction_smiles = canonicalize_reaction_smiles(
        reaction_smiles, canonicalize_tautomer=True
    )

    # BUG FIX: mcs_result was previously assigned only inside the
    # `self._mcs_mapper is not None` branch, so return_mcs_result=True with
    # MCS mapping disabled raised NameError at the return statements below.
    mcs_result = default_mapping_dict
    unmapped_product_atom_islands = {}
    if self._mcs_mapper is not None:
        mcs_result = self._mcs_mapper.map_reaction(canonicalized_reaction_smiles)

        if mcs_result["selected_mapping"] != "":
            unmapped_product_atom_islands = self._get_unmapped_product_atom_islands(
                mcs_result["selected_mapping"].split(">>")[1]
            )

    reactants_str, products_str = self._split_reaction_components(
        canonicalized_reaction_smiles
    )

    reaction_data = self._prepare_reaction_data(
        reactants_str, products_str, unmapped_product_atom_islands
    )

    mapped_outcomes_smirks_dict = self._process_templates(
        reaction_data,
    )

    result = self._post_process_mapped_outcomes(
        mapped_outcomes_smirks_dict, canonicalized_reaction_smiles, reaction_smiles
    )
    if not result:
        if return_mcs_result:
            return default_mapping_dict, mcs_result
        return default_mapping_dict

    if return_mcs_result:
        return result, mcs_result

    return result

map_reactions(reaction_list, return_mcs_results=False)

Map a list of reaction SMILES strings using this mapper.

Parameters:

Name Type Description Default
reaction_list List[str]

Reaction SMILES strings to map.

required
return_mcs_results bool

Whether to return MCS mapping results.

False

Returns:

Type Description
List[ReactionMapperResult] | List[Tuple[ReactionMapperResult, ReactionMapperResult]]

List[ReactionMapperResult]: Mapping results in the same order as the input.

List[ReactionMapperResult] | List[Tuple[ReactionMapperResult, ReactionMapperResult]]

List[Tuple[ReactionMapperResult, ReactionMapperResult]]: When return_mcs_results is True, pairs of (mapping result, MCS result) in the same order as the input.

Source code in agave_chem/mappers/template/template_mapper.py
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
def map_reactions(
    self, reaction_list: List[str], return_mcs_results: bool = False
) -> (
    List[ReactionMapperResult]
    | List[Tuple[ReactionMapperResult, ReactionMapperResult]]
):
    """
    Map a list of reaction SMILES strings using this mapper.

    Args:
        reaction_list (List[str]): Reaction SMILES strings to map.
        return_mcs_results (bool): Whether to return MCS mapping results.

    Returns:
        List[ReactionMapperResult]: Mapping results in the same order as the
        input; when return_mcs_results is True, each element is instead a
        (mapping result, MCS result) tuple.
    """
    # Delegate each reaction to map_reaction, preserving input order.
    return [
        self.map_reaction(smiles, return_mcs_results) for smiles in reaction_list
    ]

map_reactions(reaction_list, mappers_list=[], mapping_selection_mode='weighted', batch_size=500)

Source code in agave_chem/main.py
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
def map_reactions(
    reaction_list: List[str],
    mappers_list: List[ReactionMapper] | None = None,
    mapping_selection_mode: str = "weighted",
    batch_size: int = 500,
) -> List[AgaveChemMapperResult]:
    """
    Map a batch of reaction SMILES strings using one or more reaction mappers.

    Args:
        reaction_list (List[str]): Reaction SMILES strings to map. A single
            string is accepted and treated as a one-element list. Duplicates are
            removed (with a warning) before mapping, keeping first-occurrence order.
        mappers_list (List[ReactionMapper] | None): Mappers to apply. When empty
            or None, a default MCS mapper and template mapper are used.
            (BUG FIX: the original used a mutable `[]` default argument.)
        mapping_selection_mode (str): Selection strategy name or callable.
            NOTE(review): validated here but not forwarded to
            map_reactions_using_mappers — confirm intended usage downstream.
        batch_size (int): Number of reactions processed per batch (1-1000).

    Returns:
        List[AgaveChemMapperResult]: One result per (deduplicated) input reaction.

    Raises:
        ValueError: On invalid reaction_list, mappers_list, selection mode,
            out-of-range batch_size, or when mapping yields no results.
        TypeError: When batch_size is not an integer.
    """
    if not mappers_list:
        mappers_list = [
            MCSReactionMapper("mcs_default"),
            TemplateReactionMapper("expert_default"),
        ]

    if isinstance(reaction_list, str):
        reaction_list = [reaction_list]

    # Single combined validation pass (the original had a redundant nested
    # `isinstance(reaction_list, list)` re-check after the negative check).
    if not isinstance(reaction_list, list) or len(reaction_list) == 0:
        raise ValueError(
            "Invalid input: reaction_list must be a string or a non-empty list of strings."
        )
    for reaction in reaction_list:
        if not isinstance(reaction, str):
            raise ValueError(
                "Invalid input: reaction_list must be a string or a non-empty list of strings."
            )
    if len(reaction_list) != len(set(reaction_list)):
        logger.warning("Removing duplicate reactions from reaction_list.")
        # BUG FIX: list(set(...)) scrambled the input order, breaking the
        # documented "results in the same order as the input" contract.
        # dict.fromkeys deduplicates while preserving first-occurrence order.
        reaction_list = list(dict.fromkeys(reaction_list))

    if not isinstance(mappers_list, list) or len(mappers_list) == 0:
        raise ValueError(
            "Invalid input: mappers_list must be a non-empty list of ReactionMapper instances."
        )

    # Reject non-mapper entries and duplicate mapper names up front.
    seen_mappers = []
    for mapper in mappers_list:
        if not isinstance(mapper, ReactionMapper):
            raise ValueError(
                f"Invalid mapper: {mapper} is not an instance of ReactionMapper."
            )
        if mapper.mapper_name in seen_mappers:
            raise ValueError(f"Duplicate mapper name: {mapper.mapper_name}.")
        seen_mappers.append(mapper.mapper_name)

    if not isinstance(mapping_selection_mode, str) and not callable(
        mapping_selection_mode
    ):
        raise ValueError(
            "Invalid input: mapping_selection_mode must be a string or function."
        )

    if not isinstance(batch_size, int):
        raise TypeError("Invalid input: batch_size must be an integer.")
    if batch_size <= 0 or batch_size > 1000:
        raise ValueError("Invalid input: batch_size must be an integer between 1-1000.")

    mappers_out_dict = map_reactions_using_mappers(
        reaction_list, mappers_list, batch_size
    )

    if not mappers_out_dict:
        # BUG FIX: the original raised a copy-pasted batch_size message here,
        # which misdescribed the actual failure.
        raise ValueError("Mapping failed: no results were produced for the given reactions.")

    return mappers_out_dict