Skip to content

I added support for Accelerated PyTorch training on Mac #425

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions ctgan/synthesizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,13 @@ def __setstate__(self, state):
state['random_states'] = (current_numpy_state, current_torch_state)

self.__dict__ = state
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Prioritize CUDA if available, then MPS, finally CPU
if torch.cuda.is_available():
device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
device = torch.device('mps')
else:
device = torch.device('cpu')
self.set_device(device)

def save(self, path):
Expand All @@ -118,11 +124,33 @@ def save(self, path):
@classmethod
def load(cls, path):
    """Load the model stored in the passed `path`.

    The model is placed on the best available accelerator:
    CUDA first, then Apple MPS, finally CPU.

    Args:
        path (str): Path of the file where the model was saved.

    Returns:
        The deserialized synthesizer, moved to the selected device.
    """
    # Prioritize CUDA if available, then MPS, finally CPU
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')

    # map_location remaps tensors saved on a device that is not present
    # here (e.g. a CUDA-trained model loaded on a CPU/MPS-only machine),
    # which would otherwise make torch.load raise.
    model = torch.load(path, map_location=device)
    model.set_device(device)
    return model

def set_device(self, device):
    """Set the `device` to be used ('GPU' or 'CPU').

    Args:
        device (torch.device): Target device, e.g. ``cuda:0``, ``mps``
            or ``cpu``.
    """
    self._device = device
    # ``nn.Module.to`` already moves every parameter and buffer of the
    # module, for CUDA, MPS and CPU alike, so no per-tensor copying is
    # needed. Moving unconditionally also fixes the case where the
    # target is ``cpu`` and the generator was previously on another
    # device, which the branched version silently skipped.
    if self._generator is not None:
        self._generator.to(self._device)

def set_random_state(self, random_state):
"""Set the random state.

Expand All @@ -148,4 +176,4 @@ def set_random_state(self, random_state):
raise TypeError(
f'`random_state` {random_state} expected to be an int or a tuple of '
'(`np.random.RandomState`, `torch.Generator`)'
)
)
13 changes: 11 additions & 2 deletions ctgan/synthesizers/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ class CTGAN(BaseSynthesizer):
Whether to attempt to use cuda for GPU computation.
If this is False or CUDA is not available, CPU will be used.
Defaults to ``True``.
mps (bool):
Whether to attempt to use mps for GPU computation.
If this is False or MPS is not available, CPU will be used.
Defaults to ``False``.
"""

def __init__(
Expand All @@ -160,6 +164,7 @@ def __init__(
epochs=300,
pac=10,
cuda=True,
mps=False,
):
assert batch_size % 2 == 0

Expand All @@ -179,12 +184,16 @@ def __init__(
self._epochs = epochs
self.pac = pac

if not cuda or not torch.cuda.is_available():
if not cuda and not mps:
device = 'cpu'
elif mps and torch.backends.mps.is_available():
device = 'mps'
elif cuda and torch.cuda.is_available():
device = 'cuda'
elif isinstance(cuda, str):
device = cuda
else:
device = 'cuda'
device = 'cpu'

self._device = torch.device(device)

Expand Down
104 changes: 104 additions & 0 deletions tests/integration/synthesizer/test_ctgan_apple_mps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Integration tests for ctgan.

These tests only ensure that the software does not crash and that
the API works as expected in terms of input and output data formats,
but correctness of the data values and the internal behavior of the
model are not checked.
"""

import tempfile as tf

import numpy as np
import pandas as pd
import pytest
import torch
import os

from ctgan.synthesizers.ctgan import CTGAN

@pytest.fixture
def random_state():
    """Fixed seed shared by the tests for reproducible synthesizer runs."""
    seed = 42
    return seed

@pytest.fixture
def train_data():
    """Small mixed-type dataset: one continuous and two discrete columns."""
    n_rows = 100
    return pd.DataFrame({
        'continuous': np.random.normal(size=n_rows),
        'categorical': np.random.choice(['a', 'b', 'c'], size=n_rows),
        'binary': np.random.choice([0, 1], size=n_rows).astype(int),
    })

@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
def test_ctgan_fit_sample_apple_mps_hardware(tmpdir, train_data, random_state):
    """Test the CTGAN can fit and sample.

    Runs only on Apple MPS hardware; verifies fit/sample plus a
    save/load round trip.
    """
    discrete_columns = ['categorical', 'binary']
    # The test is guarded by an MPS skipif, so actually request the MPS
    # backend; with the default ``mps=False`` it would silently run on CPU.
    ctgan = CTGAN(cuda=False, mps=True, epochs=1)
    ctgan.set_random_state(random_state)
    ctgan.fit(train_data, discrete_columns=discrete_columns)
    sampled = ctgan.sample(1000)
    assert sampled.shape == (1000, train_data.shape[1])

    # Round-trip through save/load and make sure sampling still works.
    path = os.path.join(tmpdir, 'test_ctgan.pkl')
    ctgan.save(path)
    ctgan = CTGAN.load(path)

    sampled = ctgan.sample(1000)
    assert sampled.shape == (1000, train_data.shape[1])



@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
def test_mps_training_apple_mps_hardware(tmpdir, train_data, random_state):
    """Test CTGAN training on MPS device."""
    discrete_columns = ['categorical', 'binary']
    model = CTGAN(cuda=False, mps=True, epochs=1)
    model.set_random_state(random_state)

    # The synthesizer must have selected the MPS device before training.
    assert model._device.type == 'mps'

    model.fit(train_data, discrete_columns=discrete_columns)

    # After fitting, the generator parameters must live on MPS as well.
    assert next(model._generator.parameters()).device.type == 'mps'

    samples = model.sample(100)
    assert samples.shape == (100, train_data.shape[1])


@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
def test_save_load_apple_mps_hardware(tmpdir, train_data, random_state):
    """Test the CTGAN saves and loads correctly."""
    discrete_columns = ['categorical', 'binary']
    model = CTGAN(cuda=False, epochs=1)
    model.set_random_state(random_state)
    model.fit(train_data, discrete_columns=discrete_columns)

    # Persist the fitted synthesizer and restore it from disk.
    path = os.path.join(tmpdir, 'test_ctgan.pkl')
    model.save(path)
    model = CTGAN.load(path)

    # ``load`` re-detects the best available device, so the expected
    # placement depends on the hardware running this test.
    if torch.backends.mps.is_available():
        expected = 'mps'
    elif torch.cuda.is_available():
        expected = 'cuda'
    else:
        expected = 'cpu'
    assert model._device.type == expected
    assert next(model._generator.parameters()).device.type == expected

    samples = model.sample(1000)
    assert samples.shape == (1000, train_data.shape[1])