142 changes: 142 additions & 0 deletions ambersim/logger/logger.py
@@ -0,0 +1,142 @@
from pathlib import Path
from typing import Optional, Union

import jax
import wandb
from torch.utils.tensorboard import SummaryWriter


class BaseLogger:
"""Base logger interface that defines common methods for logging metrics and parameters."""

def __init__(self, log_dir: Optional[Union[str, Path]] = None):
"""Initializes the BaseLogger with a specified log directory.

Args:
log_dir (str): Directory to store the logs. If None, uses default log directory.
Contributor: If you add type hints in the function signatures, we can remove the redundant types in the docstring.
"""
self.log_dir = log_dir

def log_metric(self, key, value, step=None):
"""Logs a metric value.

Args:
key (str): The name of the metric.
value (float): The value of the metric.
step (int, optional): The step number at which the metric is logged.
"""
raise NotImplementedError
Contributor (on lines +17 to +25): Question: do you want every child class to have log_metric and log_params? If so, I would make BaseLogger abstract and decorate these abstractmethod instead. If, however, there are some situations where child classes shouldn't necessarily implement these functions, then it's fine to leave it as-is. I don't know what vision we have for the logging API.
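A minimal sketch of the abstract-base-class variant the reviewer describes (illustrative only, not part of this PR):

from abc import ABC, abstractmethod


class BaseLogger(ABC):
    """Base logger interface; subclasses must implement log_metric and log_params."""

    @abstractmethod
    def log_metric(self, key, value, step=None):
        """Logs a metric value."""

    @abstractmethod
    def log_params(self, params):
        """Logs parameters."""

With this variant, instantiating a subclass that is missing an implementation fails at construction time with a TypeError, rather than at call time with NotImplementedError.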


def log_params(self, params):
"""Logs parameters.

Args:
params (dict): A dictionary containing parameter names and their values.
"""
raise NotImplementedError

def log_progress(self, step, state_info):
"""Logs the state of a process using the log_metric method.

Args:
step (int): The step number at which the state is logged.
state_info (dict): A dictionary containing state information.
"""
for key, value in state_info.items():
if isinstance(value, jax.Array):
value = float(value) # we need floats for logging
self.log_metric(key, value, step)


class TensorBoardLogger(BaseLogger):
"""Logger that implements logging functionality using TensorBoard.

Inherits from BaseLogger and implements its methods for TensorBoard specific logging.
"""

def __init__(self, log_dir=None):
"""Initializes the TensorBoardLogger with a specified log directory.

Args:
log_dir (str): Directory to store TensorBoard logs. If None, uses default log directory.
"""
super().__init__(log_dir)
self.writer = SummaryWriter(log_dir)

def log_metric(self, key, value, step=None):
"""Logs a metric to TensorBoard.

Args:
key (str): The name of the metric.
value (float): The value of the metric.
step (int, optional): The step number at which the metric is logged.
"""
self.writer.add_scalar(key, value, step)

def log_params(self, params):
"""Logs parameters to TensorBoard.

Args:
params (dict): A dictionary of parameters to log.
"""
# add_hparams also requires a metric dict; pass an empty one since only params are logged here
self.writer.add_hparams(params, {})


class WandbLogger(BaseLogger):
"""Logger that implements logging functionality using Weights & Biases (wandb).

Inherits from BaseLogger and implements its methods for wandb specific logging.
"""

def __init__(self, log_dir=None, project_name=None):
"""Initializes the WandbLogger with a specified log directory and project name.

Args:
log_dir (str): Directory to store local wandb logs. If None, uses default wandb directory.
project_name (str): Name of the wandb project. If None, a default project is used.
"""
super().__init__(log_dir)
wandb.init(dir=log_dir, project=project_name)

def log_metric(self, key, value, step=None):
"""Logs a metric to wandb.

Args:
key (str): The name of the metric.
value (float): The value of the metric.
step (int, optional): The step number at which the metric is logged.
"""
wandb.log({key: value}, step=step)

def log_params(self, params):
"""Logs parameters to wandb.

Args:
params (dict): A dictionary of parameters to log.
"""
wandb.config.update(params)


class LoggerFactory:
"""Factory class to create logger instances based on specified logger type.

Supports creation of different types of loggers like TensorBoardLogger and WandbLogger.
"""

@staticmethod
def get_logger(logger_type, log_dir=None):
"""Creates and returns a logger instance based on the specified logger type.

Args:
logger_type (str): The type of logger to create ('tensorboard' or 'wandb').
log_dir (str, optional): Directory to store the logs. Specific to the logger type.

Returns:
BaseLogger: An instance of the requested logger type.

Raises:
ValueError: If an unsupported logger type is specified.
"""
if logger_type == "tensorboard":
return TensorBoardLogger(log_dir)
elif logger_type == "wandb":
return WandbLogger(log_dir)
else:
raise ValueError(f"Unsupported logger type: {logger_type}")
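For illustration, a short usage sketch of the factory and the progress logger (the log directory and metric keys here are invented):

import jax.numpy as jnp

from ambersim.logger.logger import LoggerFactory

# Hypothetical usage; "logs" and the metric keys are made-up examples
logger = LoggerFactory.get_logger("tensorboard", log_dir="logs")
logger.log_metric("loss", 0.25, step=10)

# log_progress accepts jax arrays and casts them to float before logging
logger.log_progress(step=10, state_info={"eval/episode_reward": jnp.asarray(42.0)})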
64 changes: 64 additions & 0 deletions examples/rl/pendulum/ex_logger.py
@@ -0,0 +1,64 @@
import functools
import os
from datetime import datetime

import jax
from brax import envs
from brax.training.agents.ppo import networks as ppo_networks
from brax.training.agents.ppo import train as ppo

from ambersim.logger.logger import LoggerFactory
from ambersim.rl.pendulum.swingup import PendulumSwingupEnv

"""
A pendulum swingup example that uses a custom logger to log training
progress in real time.
"""

if __name__ == "__main__":
# Initialize the environment
envs.register_environment("pendulum_swingup", PendulumSwingupEnv)
env = envs.get_environment("pendulum_swingup")

# Define the training function
network_factory = functools.partial(
ppo_networks.make_ppo_networks,
policy_hidden_layer_sizes=(64,) * 3,
)
train_fn = functools.partial(
ppo.train,
num_timesteps=100_000,
num_evals=50,
reward_scaling=0.1,
episode_length=200,
normalize_observations=True,
action_repeat=1,
unroll_length=10,
num_minibatches=32,
num_updates_per_batch=8,
discounting=0.97,
learning_rate=3e-4,
entropy_cost=0,
num_envs=1024,
batch_size=512,
network_factory=network_factory,
seed=0,
)

# Save the log in the current directory
log_dir = os.path.join(os.path.abspath(os.getcwd()), "logs")
Contributor: nit: in most of this repo, we use pathlib.Path. For consistency, it would be nice to change the os-based code to instead use Path. If you don't want to do it, it's also totally fine.

if not os.path.exists(log_dir):
os.makedirs(log_dir)
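A sketch of the pathlib version the reviewer suggests, behaviorally equivalent to the os-based setup above:

from pathlib import Path

log_dir = Path.cwd() / "logs"
log_dir.mkdir(parents=True, exist_ok=True)

Path objects are os.PathLike, so downstream consumers can generally accept them directly; str(log_dir) is a safe fallback where a plain string is required.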

print(f"Setting up Tensorboard logging in {log_dir}")
logger = LoggerFactory.get_logger("tensorboard", log_dir)

# Define a callback to log progress
times = [datetime.now()]
Contributor: FYI: at time of review, I don't see a callback here. I also haven't run the code myself since this isn't ready for review yet, so I could just be misunderstanding how the logging works in this example.

# Do the training
print("Training...")
make_inference_fn, params, _ = train_fn(
environment=env,
progress_fn=logger.log_progress,
)
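On the reviewer's callback question above: the progress callback is logger.log_progress itself, passed as progress_fn. Brax's PPO trainer invokes progress_fn periodically with the current timestep count and a dict of evaluation metrics, which matches log_progress's (step, state_info) signature. Roughly, each evaluation then amounts to a call like this (the step count, key, and value are invented):

# Illustrative only; the key and numbers are made up
logger.log_progress(2000, {"eval/episode_reward": 105.3})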