Skip to content
11 changes: 8 additions & 3 deletions src/lightning/pytorch/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,16 @@ def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None:
params = _convert_params(params)
params = _flatten_dict(params)

import mlflow.utils.validation
from mlflow.entities import Param

# Truncate parameter values to 250 characters.
# TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]
try: # Check maximum param value length is available and use it
param_length_limit = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
except Exception: # Fallback (in case of MAX_PARAM_VAL_LENGTH not available)
param_length_limit = 250 # Historical default value

# Use mlflow default limit or truncate parameter values to 250 characters if limit is not available
params_list = [Param(key=k, value=str(v)[:param_length_limit]) for k, v in params.items()]

# Log in chunks of 100 parameters (the maximum allowed by MLflow).
for idx in range(0, len(params_list), 100):
Expand Down
15 changes: 9 additions & 6 deletions tests/tests_pytorch/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,14 @@ def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path):
)
param.assert_called_with(key="test", value="test_param")

long_params = {"test": "test_param" * 50}
logger.log_hyperparams(long_params)

logger.experiment.log_batch.assert_called_with(
run_id=logger.run_id, params=[param(key="test", value="test_param" * 50)]
)
param.assert_called_with(key="test", value="test_param" * 50)

metrics = {"some_metric": 10}
logger.log_metrics(metrics)

Expand Down Expand Up @@ -317,12 +325,7 @@ def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path):

@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock())
def test_mlflow_logger_with_long_param_value(mlflow_mock, tmp_path):
"""Test that long parameter values are truncated to 250 characters."""

def _check_value_length(value, *args, **kwargs):
assert len(value) <= 250

mlflow_mock.entities.Param.side_effect = _check_value_length
"""Test that long parameter values are handled correctly."""

logger = MLFlowLogger("test", save_dir=str(tmp_path))

Expand Down
Loading