Skip to content

Commit c6b645d

Browse files
committed
additional similarity metrics included
1 parent d06570a commit c6b645d

File tree

1 file changed

+130
-2
lines changed

1 file changed

+130
-2
lines changed

src/rxn_insight/utils.py

+130-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@
2121
from rdkit.Chem.rdchem import Atom, BondType, Mol
2222
from rdkit.Chem.Scaffolds.MurckoScaffold import GetScaffoldForMol
2323
from rxnmapper import RXNMapper
24-
from scipy.spatial.distance import jaccard
24+
from scipy.spatial.distance import braycurtis, canberra, chebyshev, \
25+
cityblock, correlation, cosine, euclidean, minkowski, \
26+
dice, hamming, jaccard, kulczynski1, rogerstanimoto, \
27+
russellrao, sokalmichener, sokalsneath, yule
2528

2629
# Turn off pandas' chained-assignment check (neither warn nor raise).
# This is a process-wide pandas option set as a side effect of importing this module.
pd.options.mode.chained_assignment = None
2730

@@ -640,11 +643,136 @@ def make_rdkit_fp(rxn: str, fp: str = "MACCS", concatenate: bool = True) -> str:
640643
return bfp
641644

642645

643-
def get_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
646+
def get_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any], metric: str = "jaccard") -> float:
    """Score how alike two fingerprints are using the requested metric.

    The metric name is matched case-insensitively. Supported names are
    `jaccard`, `dice`, `kulczynski1`, `rogerstanimoto`, `russellrao`,
    `sokalmichener`, `sokalsneath`, `yule`, `braycurtis`, `canberra`,
    `chebyshev`, `manhattan`, `correlation`, `cosine`, `euclidean` and
    `minkowski`. Each name is forwarded to the matching
    ``calculate_<name>_similarity`` helper of this module.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :param metric: Name of the metric to calculate the similarity with.
    :return: similarity value as float
    :raises ValueError: if ``metric`` is not one of the supported names
    """
    # Lookup table replacing a long if/elif chain: metric name -> helper.
    scorers = {
        "jaccard": calculate_jaccard_similarity,
        "dice": calculate_dice_similarity,
        "kulczynski1": calculate_kulczynksi1_similarity,
        "rogerstanimoto": calculate_rogerstanimoto_similarity,
        "russellrao": calculate_russellrao_similarity,
        "sokalmichener": calculate_sokalmichener_similarity,
        "sokalsneath": calculate_sokalsneath_similarity,
        "yule": calculate_yule_similarity,
        "braycurtis": calculate_braycurtis_similarity,
        "canberra": calculate_canberra_similarity,
        "chebyshev": calculate_chebyshev_similarity,
        "manhattan": calculate_manhattan_similarity,
        "correlation": calculate_correlation_similarity,
        "cosine": calculate_cosine_similarity,
        "euclidean": calculate_euclidean_similarity,
        "minkowski": calculate_minkowski_similarity,
    }

    metric = metric.lower()
    scorer = scorers.get(metric)
    if scorer is None:
        raise ValueError(f"Unknown metric '{metric}'. Please choose a valid metric.")

    return scorer(v1, v2)
694+
695+
696+
def calculate_jaccard_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return the Jaccard similarity, i.e. one minus SciPy's Jaccard dissimilarity.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float
    """
    distance = jaccard(v1, v2)
    score: float = 1 - distance
    return score
646699

647700

701+
def calculate_dice_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return the Dice similarity, i.e. one minus SciPy's Dice dissimilarity.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float
    """
    distance = dice(v1, v2)
    score: float = 1 - distance
    return score
704+
705+
706+
def calculate_kulczynksi1_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return the Kulczynski 1 coefficient of two boolean fingerprints.

    SciPy's ``kulczynski1`` evaluates ``c11 / (c01 + c10)``: the number of
    positions set in both vectors divided by the number of positions where
    they differ. That value already grows with overlap (0 for fingerprints
    sharing no set bit, no upper limit), so it is returned as-is. Subtracting
    it from 1 would invert the ranking and give disjoint fingerprints the
    highest score.

    The coefficient is undefined when the two fingerprints do not differ in
    any position (division by zero).

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float (not bounded to [0, 1])
    """
    # NOTE(review): recent SciPy releases deprecate `kulczynski1`; confirm the
    # SciPy version range this package supports.
    similarity: float = kulczynski1(v1, v2)
    return similarity
709+
710+
711+
def calculate_rogerstanimoto_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return one minus SciPy's Rogers-Tanimoto dissimilarity.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float
    """
    distance = rogerstanimoto(v1, v2)
    score: float = 1 - distance
    return score
714+
715+
716+
def calculate_russellrao_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return one minus SciPy's Russell-Rao dissimilarity.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float
    """
    distance = russellrao(v1, v2)
    score: float = 1 - distance
    return score
719+
720+
721+
def calculate_sokalmichener_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return one minus SciPy's Sokal-Michener dissimilarity.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float
    """
    # NOTE(review): recent SciPy releases deprecate `sokalmichener`; confirm the
    # SciPy version range this package supports.
    distance = sokalmichener(v1, v2)
    score: float = 1 - distance
    return score
724+
725+
726+
def calculate_sokalsneath_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return one minus SciPy's Sokal-Sneath dissimilarity.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float
    """
    distance = sokalsneath(v1, v2)
    score: float = 1 - distance
    return score
729+
730+
731+
def calculate_yule_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return one minus SciPy's Yule dissimilarity of two boolean fingerprints.

    SciPy's ``yule`` lies in [0, 2], so the returned value lies in [-1, 1]
    (it equals Yule's Q coefficient): 1 for perfectly associated vectors and
    -1 for perfectly opposed ones.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float
    """
    # Must call `yule` here; the previous body called `kulczynski1` by mistake,
    # so "yule" silently returned the Kulczynski-based score.
    similarity: float = 1 - yule(v1, v2)
    return similarity
734+
735+
736+
def calculate_braycurtis_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return one minus SciPy's Bray-Curtis distance.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float
    """
    distance = braycurtis(v1, v2)
    score: float = 1 - distance
    return score
739+
740+
741+
def calculate_canberra_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return one minus SciPy's Canberra distance.

    The Canberra distance can exceed 1, so the result is at most 1 and
    becomes negative once the distance is larger than 1.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float (may be negative)
    """
    distance = canberra(v1, v2)
    score: float = 1 - distance
    return score
744+
745+
746+
def calculate_chebyshev_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return one minus SciPy's Chebyshev distance.

    The Chebyshev distance is the largest absolute component difference and
    has no upper bound in general, so the result is at most 1 and becomes
    negative once the distance exceeds 1.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float (may be negative)
    """
    distance = chebyshev(v1, v2)
    score: float = 1 - distance
    return score
749+
750+
751+
def calculate_manhattan_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return one minus the Manhattan distance (SciPy's ``cityblock``).

    The Manhattan distance has no upper bound, so the result is at most 1
    and becomes negative once the distance exceeds 1.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float (may be negative)
    """
    distance = cityblock(v1, v2)
    score: float = 1 - distance
    return score
754+
755+
756+
def calculate_correlation_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return one minus SciPy's correlation distance.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float
    """
    distance = correlation(v1, v2)
    score: float = 1 - distance
    return score
759+
760+
761+
def calculate_cosine_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return the cosine similarity, i.e. one minus SciPy's cosine distance.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float
    """
    distance = cosine(v1, v2)
    score: float = 1 - distance
    return score
764+
765+
766+
def calculate_euclidean_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return one minus SciPy's Euclidean distance.

    The Euclidean distance has no upper bound, so the result is at most 1
    and becomes negative once the distance exceeds 1.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float (may be negative)
    """
    distance = euclidean(v1, v2)
    score: float = 1 - distance
    return score
769+
770+
771+
def calculate_minkowski_similarity(v1: npt.NDArray[Any], v2: npt.NDArray[Any]) -> float:
    """Return one minus SciPy's Minkowski distance.

    ``minkowski`` is called without an explicit order, so SciPy's default
    order applies. The distance has no upper bound, so the result is at most
    1 and becomes negative once the distance exceeds 1.

    :param v1: Reference fingerprint in np.ndarray format
    :param v2: Fingerprint to compare in np.ndarray format
    :return: similarity value as float (may be negative)
    """
    distance = minkowski(v1, v2)
    score: float = 1 - distance
    return score
774+
775+
648776
def get_solvent_ranking(df: pd.DataFrame) -> pd.DataFrame:
649777
solvent_dict: dict[str, list[str]] = {"NAME": [], "COUNT": []}
650778
solvents = df["SOLVENT"].tolist()

0 commit comments

Comments
 (0)