logger api #50
@@ -0,0 +1,142 @@
from pathlib import Path
from typing import Optional, Union

import jax
import wandb
from torch.utils.tensorboard import SummaryWriter


class BaseLogger:
    """Base logger interface that defines common methods for logging metrics and parameters."""

    def __init__(self, log_dir: Optional[Union[str, Path]] = None):
        """Initializes the BaseLogger with a specified log directory.

        Args:
            log_dir (str): Directory to store the logs. If None, uses default log directory.
        """
        self.log_dir = log_dir

    def log_metric(self, key, value, step=None):
        """Logs a metric value.

        Args:
            key (str): The name of the metric.
            value (float): The value of the metric.
            step (int, optional): The step number at which the metric is logged.
        """
        raise NotImplementedError
Comment on lines +17 to +25

Question: do you want every child class to have
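If the intent is that every concrete logger must override these methods, one alternative to raising NotImplementedError is an abstract base class, so that an incomplete subclass fails at construction time rather than when the method is first called. This is only a sketch of that option, not code from this PR:

from abc import ABC, abstractmethod


class BaseLogger(ABC):
    """Variant in which every subclass is forced to implement the logging methods."""

    def __init__(self, log_dir=None):
        self.log_dir = log_dir

    @abstractmethod
    def log_metric(self, key, value, step=None):
        """Subclasses must override this; otherwise instantiating them raises TypeError."""

    @abstractmethod
    def log_params(self, params):
        """Subclasses must override this as well."""

TensorBoardLogger and WandbLogger would need no changes under this variant, since they already override both methods.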
    def log_params(self, params):
        """Logs parameters.

        Args:
            params (dict): A dictionary containing parameter names and their values.
        """
        raise NotImplementedError

    def log_progress(self, step, state_info):
        """Logs the state of a process using the log_metric method.

        Args:
            step (int, optional): The step number at which the state is logged.
            state_info (dict): A dictionary containing state information.
        """
        for key, value in state_info.items():
            if isinstance(value, jax.Array):
                value = float(value)  # we need floats for logging
            self.log_metric(key, value, step)


class TensorBoardLogger(BaseLogger):
    """Logger that implements logging functionality using TensorBoard.

    Inherits from BaseLogger and implements its methods for TensorBoard-specific logging.
    """

    def __init__(self, log_dir=None):
        """Initializes the TensorBoardLogger with a specified log directory.

        Args:
            log_dir (str): Directory to store TensorBoard logs. If None, uses default log directory.
        """
        super().__init__(log_dir)
        self.writer = SummaryWriter(log_dir)

    def log_metric(self, key, value, step=None):
        """Logs a metric to TensorBoard.

        Args:
            key (str): The name of the metric.
            value (float): The value of the metric.
            step (int, optional): The step number at which the metric is logged.
        """
        self.writer.add_scalar(key, value, step)

    def log_params(self, params):
        """Logs parameters to TensorBoard.

        Args:
            params (dict): A dictionary of parameters to log.
        """
        # add_hparams requires both an hparam dict and a metric dict; pass an empty metric dict.
        self.writer.add_hparams(params, {})


class WandbLogger(BaseLogger):
    """Logger that implements logging functionality using Weights & Biases (wandb).

    Inherits from BaseLogger and implements its methods for wandb-specific logging.
    """

    def __init__(self, log_dir=None, project_name=None):
        """Initializes the WandbLogger with a specified log directory and project name.

        Args:
            log_dir (str): Directory to store local wandb logs. If None, uses default wandb directory.
            project_name (str): Name of the wandb project. If None, a default project is used.
        """
        super().__init__(log_dir)
        wandb.init(dir=log_dir, project=project_name)

    def log_metric(self, key, value, step=None):
        """Logs a metric to wandb.

        Args:
            key (str): The name of the metric.
            value (float): The value of the metric.
            step (int, optional): The step number at which the metric is logged.
        """
        wandb.log({key: value}, step=step)

    def log_params(self, params):
        """Logs parameters to wandb.

        Args:
            params (dict): A dictionary of parameters to log.
        """
        wandb.config.update(params)


class LoggerFactory:
    """Factory class to create logger instances based on the specified logger type.

    Supports creation of different types of loggers, such as TensorBoardLogger and WandbLogger.
    """

    @staticmethod
    def get_logger(logger_type, log_dir=None):
        """Creates and returns a logger instance based on the specified logger type.

        Args:
            logger_type (str): The type of logger to create ('tensorboard' or 'wandb').
            log_dir (str, optional): Directory to store the logs. Specific to the logger type.

        Returns:
            BaseLogger: An instance of the requested logger type.

        Raises:
            ValueError: If an unsupported logger type is specified.
        """
        if logger_type == "tensorboard":
            return TensorBoardLogger(log_dir)
        elif logger_type == "wandb":
            return WandbLogger(log_dir)
        else:
            raise ValueError("Unsupported logger type")
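For reference, a minimal usage sketch of the factory as it stands in this diff (the import path matches the example below; the directory name and logged values are made up for illustration):

from ambersim.logger.logger import LoggerFactory

# Create a TensorBoard-backed logger, record hyperparameters, and log a few scalars.
logger = LoggerFactory.get_logger("tensorboard", log_dir="./logs")
logger.log_params({"learning_rate": 3e-4, "batch_size": 512})
for step in range(3):
    logger.log_metric("reward", 0.1 * step, step=step)

Swapping "tensorboard" for "wandb" routes the same calls to Weights & Biases instead.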
@@ -0,0 +1,64 @@
import functools
import os
from datetime import datetime

import jax
from brax import envs
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo

from ambersim.logger.logger import LoggerFactory
from ambersim.rl.pendulum.swingup import PendulumSwingupEnv

"""
A pendulum swingup example that uses a custom logger to log training
progress in real time.
"""

if __name__ == "__main__":
    # Initialize the environment
    envs.register_environment("pendulum_swingup", PendulumSwingupEnv)
    env = envs.get_environment("pendulum_swingup")

    # Define the training function
    network_factory = functools.partial(
        ppo_networks.make_ppo_networks,
        policy_hidden_layer_sizes=(64,) * 3,
    )
    train_fn = functools.partial(
        ppo.train,
        num_timesteps=100_000,
        num_evals=50,
        reward_scaling=0.1,
        episode_length=200,
        normalize_observations=True,
        action_repeat=1,
        unroll_length=10,
        num_minibatches=32,
        num_updates_per_batch=8,
        discounting=0.97,
        learning_rate=3e-4,
        entropy_cost=0,
        num_envs=1024,
        batch_size=512,
        network_factory=network_factory,
        seed=0,
    )

    # Save the log in the current directory
    log_dir = os.path.join(os.path.abspath(os.getcwd()), "logs")
nit: in most of this repo, we use
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    print(f"Setting up Tensorboard logging in {log_dir}")
    logger = LoggerFactory.get_logger("tensorboard", log_dir)

    # Define a callback to log progress
    times = [datetime.now()]
FYI: at time of review, I don't see a callback here. I also haven't run the code myself since this isn't ready for review yet, so I could just be misunderstanding how the logging works in this example.
    # Do the training
    print("Training...")
    make_inference_fn, params, _ = train_fn(
        environment=env,
        progress_fn=logger.log_progress,
    )
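On the callback question above: as I understand brax's API, ppo.train calls its progress_fn argument with the current step count and a metrics dict at each evaluation, so logger.log_progress is serving as the callback here. If the times list was also meant to track wall-clock time, a small wrapper (a sketch, not part of this diff) could do both:

    # Hypothetical wrapper: record wall-clock time at each evaluation, then
    # forward the metrics to the logger. Pass progress_fn=progress to train_fn instead.
    def progress(step, metrics):
        times.append(datetime.now())
        logger.log_progress(step, metrics)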
If you add type hints in the function signatures, we can remove the redundant types in the docstring.
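As a sketch of that suggestion (keeping the current behavior, and with Optional imported from typing as at the top of logger.py), the base-class method could carry the types so the docstring only documents meaning:

    def log_metric(self, key: str, value: float, step: Optional[int] = None) -> None:
        """Logs a metric value.

        Args:
            key: The name of the metric.
            value: The value of the metric.
            step: The step number at which the metric is logged.
        """
        raise NotImplementedError

The same pattern would apply to log_params, log_progress, and the subclass overrides.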