Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding iSIM as a trackable value during training #192

Merged
merged 9 commits into from
Mar 11, 2025
6 changes: 5 additions & 1 deletion configs/toml/staged_learning.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@ summary_csv_prefix = "staged_learning" # prefix for the CSV file
use_checkpoint = false # if true read diversity filter from agent_file
purge_memories = false # if true purge all diversity filter memories after each stage


## Reinvent
prior_file = "priors/reinvent.prior"
agent_file = "priors/reinvent.prior"

##ISIM
tb_isim = false # if true track isim value of smilies across training epochs in tensorboard

## LibInvent
#prior_file = "priors/libinvent.prior"
#agent_file = "priors/libinvent.prior"
Expand Down Expand Up @@ -176,7 +180,7 @@ max_steps = 100

[stage.scoring] # the scoring components can be read from a score file
type = "geometric_mean" # aggregation function
filename = "stage2_scoring.toml" # file with scoring setup for this stage
filename = "configs/toml/stage2_scoring.toml" # file with scoring setup for this stage
filetype = "toml" # file format: TOML or JSON, no default, must be present


Expand Down
15 changes: 14 additions & 1 deletion reinvent/runmodes/RL/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from torch.utils.tensorboard import SummaryWriter
import numpy as np

#ISIM imports
from iSIM.comp import calculate_isim
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, but we would also need to update the requirements file and the pyproject.toml file

from iSIM.utils import binary_fps

from .reports import RLTBReporter, RLCSVReporter, RLRemoteReporter, RLReportData
from reinvent.runmodes.RL.data_classes import ModelState
from reinvent.models.model_factory.sample_batch import SmilesState
Expand Down Expand Up @@ -52,6 +56,7 @@ def __init__(
inception: Inception = None,
responder_config: dict = None,
tb_logdir: str = None,
tb_isim: bool = False,
):
"""Setup of the common framework"""

Expand Down Expand Up @@ -94,13 +99,14 @@ def __init__(
self.reporters = []
self.tb_reporter = None
self._setup_reporters(tb_logdir)
self.tb_isim = tb_isim

self.start_time = 0

def optimize(self, converged: terminator_callable) -> bool:
"""Run the multistep optimization loop

Sample from the agent, score the SNILES, update the agent parameters.
Sample from the agent, score the SMILES, update the agent parameters.
Log some key characteristics of the current step.

:param converged: a callable that determines convergence
Expand Down Expand Up @@ -310,6 +316,12 @@ def report(
fract_duplicate_smiles = num_duplicate_smiles / len(mask_duplicates)

smilies = np.array(self.sampled.smilies)[mask_valid]

isim = None
if self.tb_isim:
fingerprints = binary_fps(smilies, fp_type='RDKIT', n_bits=None) #Use isim utilities to compute RDKIT binary fingerprints
isim = calculate_isim(fingerprints, n_ary ='JT') #Use isim calculator for average Tanimoto similarity

if self.prior.model_type == "Libinvent":
smilies = normalize(smilies, keep_all=True)
mask_idx = (np.argwhere(mask_valid).flatten(),)
Expand All @@ -318,6 +330,7 @@ def report(
step=step_no,
stage=self.stage_no,
smilies=smilies,
isim=isim, #Add isim to report_data
scaffolds=scaffolds,
sampled=self.sampled,
score_results=score_results,
Expand Down
1 change: 1 addition & 0 deletions reinvent/runmodes/RL/reports/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class RLReportData:
step: int
stage: int
smilies: list
isim: Optional[float]
scaffolds: list
sampled: SampleBatch
score_results: ScoreResults
Expand Down
3 changes: 3 additions & 0 deletions reinvent/runmodes/RL/reports/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def submit(self, data: RLReportData) -> None:
self.reporter.add_scalar(f"{name} (raw)", np.nanmean(_scores[mask_idx]), step)

self.reporter.add_scalar(f"Loss", data.loss, step)
#Add iSIM to board as scalar per step
if data.isim:
self.reporter.add_scalar(f"iSIM: Average similarity", data.isim, step)

# NOTE: for some reason this breaks on Windows because the necessary
# subdirectory cannot be created
Expand Down
7 changes: 4 additions & 3 deletions reinvent/runmodes/RL/run_staged_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def run_staged_learning(
)

parameters = config.parameters

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please, set up your IDE or editor to not add spurious whitespace. This distracts from actual code review. Thanks.

# NOTE: The model files are a dictionary with model attributes from
# Reinvent and a set of tensors, each with an attribute for the
# device (CPU or GPU) and if gradients are required
Expand Down Expand Up @@ -301,7 +301,7 @@ def run_staged_learning(
distance_threshold = parameters.distance_threshold

model_learning = getattr(RL, f"{model_type}Learning")

if callable(write_config):
write_config(config.model_dump())

Expand Down Expand Up @@ -329,7 +329,7 @@ def run_staged_learning(
else:
state = ModelState(agent, package.diversity_filter)
logger.debug(f"Using stage DF")

optimize = model_learning(
max_steps=package.max_steps,
stage_no=stage_no,
Expand All @@ -344,6 +344,7 @@ def run_staged_learning(
inception=inception,
responder_config=responder_config,
tb_logdir=logdir,
tb_isim=parameters.tb_isim,
)

if device.type == "cuda" and torch.cuda.is_available():
Expand Down
1 change: 1 addition & 0 deletions reinvent/runmodes/RL/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class SectionParameters(GlobalConfig):
randomize_smiles: bool = True
unique_sequences: bool = False
temperature: float = 1.0
tb_isim: Optional[bool] = False #Add iSIM tracking as optional parameter


class SectionLearningStrategy(GlobalConfig):
Expand Down
4 changes: 2 additions & 2 deletions reinvent/utils/config_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,14 @@ def read_config(filename: Optional[Path], fmt: str) -> dict:
"""

pkg = FMT_CONVERT[fmt]

if isinstance(filename, (str, Path)):
with open(filename, "rb") as tf:
config = pkg.load(tf)
else:
config_str = "\n".join(sys.stdin.readlines())
config = pkg.loads(config_str)

return config


Expand Down
1 change: 1 addition & 0 deletions reinvent/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class ReinventConfig(GlobalConfig):
use_cuda: Optional[bool] = Field(True, deprecated="use 'device' instead")
tb_logdir: Optional[str] = None
json_out_config: Optional[str] = None
tb_isim: Optional[bool] = False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not needed as global parameter because it is only relevant in the RL parameter section.

seed: Optional[int] = None
parameters: dict

Expand Down