From 0936e40eb3c6f92b590b01738c3067e1f1b5a76e Mon Sep 17 00:00:00 2001
From: crisostomi
Date: Tue, 25 Jun 2024 11:35:53 +0200
Subject: [PATCH] add typing

---
 src/scripts/__init__.py            |  0
 src/scripts/combine_checkpoints.py |  7 ++-
 tests/test_combine_checkpoints.py  | 99 ++++++++++++++++++++++++++++++
 3 files changed, 105 insertions(+), 1 deletion(-)
 create mode 100644 src/scripts/__init__.py
 create mode 100644 tests/test_combine_checkpoints.py

diff --git a/src/scripts/__init__.py b/src/scripts/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/scripts/combine_checkpoints.py b/src/scripts/combine_checkpoints.py
index ac8368c..2497e36 100644
--- a/src/scripts/combine_checkpoints.py
+++ b/src/scripts/combine_checkpoints.py
@@ -5,11 +5,16 @@
 # LICENSE file in the root directory of this source tree.
 
 
+from pathlib import Path
+from typing import Union
+
 import torch
 
 
 def combine_checkpoints(
-    generator_checkpoint: str, detector_checkpoint: str, output_checkpoint: str
+    generator_checkpoint: Union[str, Path],
+    detector_checkpoint: Union[str, Path],
+    output_checkpoint: Union[str, Path],
 ):
     """Combine split generator and detector checkpoints into a single checkpoint that can be further trained."""
     gen_ckpt = torch.load(generator_checkpoint)
diff --git a/tests/test_combine_checkpoints.py b/tests/test_combine_checkpoints.py
new file mode 100644
index 0000000..343e6b8
--- /dev/null
+++ b/tests/test_combine_checkpoints.py
@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import urllib
+from pathlib import Path
+
+import pytest
+import torch
+import torchaudio
+
+from audioseal import AudioSeal
+from audioseal.builder import (
+    AudioSealDetectorConfig,
+    AudioSealWMConfig,
+    create_detector,
+    create_generator,
+)
+from audioseal.models import AudioSealDetector, AudioSealWM
+from scripts.combine_checkpoints import combine_checkpoints
+
+
+@pytest.fixture
+def ckpts_dir() -> Path:
+    path = Path("TMP")
+    path.mkdir(exist_ok=True, parents=True)
+
+    return path
+
+
+@pytest.fixture
+def generator_ckpt_path(ckpts_dir: Path) -> Path:
+
+    checkpoint, config = AudioSeal.parse_model(
+        "audioseal_wm_16bits",
+        AudioSealWMConfig,
+        nbits=16,
+    )
+
+    model = create_generator(config)
+    model.load_state_dict(checkpoint)
+
+    checkpoint = {"xp.cfg": config, "model": model.state_dict()}
+    path = ckpts_dir / "generator_checkpoint.pth"
+
+    torch.save(checkpoint, path)
+
+    return path
+
+
+@pytest.fixture
+def detector_ckpt_path(ckpts_dir: Path) -> Path:
+
+    checkpoint, config = AudioSeal.parse_model(
+        "audioseal_detector_16bits",
+        AudioSealDetectorConfig,
+        nbits=16,
+    )
+
+    model = create_detector(config)
+    model.load_state_dict(checkpoint)
+
+    checkpoint = {"xp.cfg": config, "model": model.state_dict()}
+    path = ckpts_dir / "detector_checkpoint.pth"
+
+    torch.save(checkpoint, path)
+
+    return path
+
+
+def test_combine_checkpoints(
+    generator_ckpt_path: Path, detector_ckpt_path: Path, ckpts_dir: Path
+):
+
+    combined_ckpt_path = ckpts_dir / "combined.pth"
+
+    combine_checkpoints(generator_ckpt_path, detector_ckpt_path, combined_ckpt_path)
+
+    assert combined_ckpt_path.exists()
+
+    generator = torch.load(generator_ckpt_path)
+    detector = torch.load(detector_ckpt_path)
+
+    combined = torch.load(combined_ckpt_path)
+
+    for key in generator["model"]:
+        assert f"generator.{key}" in combined["model"]
+
+    for key in detector["model"]:
+        assert f"detector.{key}" in combined["model"]
+
+    # clean up
+    combined_ckpt_path.unlink()
+    generator_ckpt_path.unlink()
+    detector_ckpt_path.unlink()
+    ckpts_dir.rmdir()