Skip to content

Commit

Permalink
add typing
Browse files Browse the repository at this point in the history
  • Loading branch information
crisostomi committed Jun 25, 2024
1 parent 8c0562e commit 0936e40
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 1 deletion.
Empty file added src/scripts/__init__.py
Empty file.
7 changes: 6 additions & 1 deletion src/scripts/combine_checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,16 @@
# LICENSE file in the root directory of this source tree.


from pathlib import Path
from typing import Union

import torch


def combine_checkpoints(
generator_checkpoint: str, detector_checkpoint: str, output_checkpoint: str
generator_checkpoint: Union[str, Path],
detector_checkpoint: Union[str, Path],
output_checkpoint: Union[str, Path],
):
"""Combine split generator and detector checkpoints into a single checkpoint that can be further trained."""
gen_ckpt = torch.load(generator_checkpoint)
Expand Down
99 changes: 99 additions & 0 deletions tests/test_combine_checkpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import urllib
from pathlib import Path

import pytest
import torch
import torchaudio

from audioseal import AudioSeal
from audioseal.builder import (
AudioSealDetectorConfig,
AudioSealWMConfig,
create_detector,
create_generator,
)
from audioseal.models import AudioSealDetector, AudioSealWM
from scripts.combine_checkpoints import combine_checkpoints


@pytest.fixture
def ckpts_dir(tmp_path: Path) -> Path:
    """Directory for the temporary checkpoint files used by these tests.

    Uses pytest's managed ``tmp_path`` instead of a hard-coded ``TMP``
    directory in the CWD, so leftovers are cleaned up automatically even
    when a test fails, and parallel test runs cannot collide.
    """
    return tmp_path


@pytest.fixture
def generator_ckpt_path(ckpts_dir: Path) -> Path:
    """Save a 16-bit watermark generator as a training-style checkpoint.

    Returns the path of the saved file inside ``ckpts_dir``.
    """
    state_dict, cfg = AudioSeal.parse_model(
        "audioseal_wm_16bits",
        AudioSealWMConfig,
        nbits=16,
    )

    generator = create_generator(cfg)
    generator.load_state_dict(state_dict)

    # Training checkpoints carry the model config under the "xp.cfg" key.
    out_path = ckpts_dir / "generator_checkpoint.pth"
    torch.save({"xp.cfg": cfg, "model": generator.state_dict()}, out_path)

    return out_path


@pytest.fixture
def detector_ckpt_path(ckpts_dir: Path) -> Path:
    """Save a 16-bit watermark detector as a training-style checkpoint.

    Returns the path of the saved file inside ``ckpts_dir``.
    """
    state_dict, cfg = AudioSeal.parse_model(
        "audioseal_detector_16bits",
        AudioSealDetectorConfig,
        nbits=16,
    )

    detector = create_detector(cfg)
    detector.load_state_dict(state_dict)

    # Training checkpoints carry the model config under the "xp.cfg" key.
    out_path = ckpts_dir / "detector_checkpoint.pth"
    torch.save({"xp.cfg": cfg, "model": detector.state_dict()}, out_path)

    return out_path


def test_combine_checkpoints(
    generator_ckpt_path: Path, detector_ckpt_path: Path, ckpts_dir: Path
):
    """combine_checkpoints merges the generator and detector state dicts into
    one checkpoint, with each source key re-namespaced under a component
    prefix ("generator." / "detector.").
    """
    combined_ckpt_path = ckpts_dir / "combined.pth"

    try:
        combine_checkpoints(
            generator_ckpt_path, detector_ckpt_path, combined_ckpt_path
        )

        assert combined_ckpt_path.exists()

        generator = torch.load(generator_ckpt_path)
        detector = torch.load(detector_ckpt_path)
        combined = torch.load(combined_ckpt_path)

        # Every source parameter must reappear under its component prefix.
        for key in generator["model"]:
            assert f"generator.{key}" in combined["model"]

        for key in detector["model"]:
            assert f"detector.{key}" in combined["model"]
    finally:
        # Clean up even when an assertion fails, so a rerun starts fresh;
        # missing_ok covers the case where combine failed before writing.
        for path in (combined_ckpt_path, generator_ckpt_path, detector_ckpt_path):
            path.unlink(missing_ok=True)
        ckpts_dir.rmdir()

0 comments on commit 0936e40

Please sign in to comment.