ORT trainer backend #69

Open · wants to merge 1 commit into main
25 changes: 12 additions & 13 deletions examples/cnndailymail_text_summarization/azureml/submit_ortds.py
@@ -1,19 +1,18 @@
-from azureml.core import Experiment, Workspace, ScriptRunConfig
+from azureml.core import Experiment, Workspace, ScriptRunConfig, Datastore
 from azureml.core.compute import AmlCompute
 from azureml.core.runconfig import MpiConfiguration

 # put your AML workspace config.json in this directory!
 ws = Workspace.from_config()
 ws_details = ws.get_details()
-ds = ws.get_default_datastore()
+ds = Datastore(ws, 'ws2_ds')

-gpu_compute_target = AmlCompute(workspace=ws, name='sriovdedicated1')
+gpu_compute_target = AmlCompute(workspace=ws, name='LoRA-ND')
 print(gpu_compute_target.status.serialize())

 from azureml.core import Dataset
 from azureml.data import OutputFileDatasetConfig
-

 # create input/output datasets
 def get_input_dataset(datastore, path_on_datastore, dataset_name):
     dataset = Dataset.File.from_files(path=[(datastore, path_on_datastore)])
@@ -25,29 +24,30 @@ def get_output_dataset(datastore, path_on_datastore, dataset_name):
 def get_args(outputSuffix="deepspeed_ort_amp_nopadding_v100_8"):
     all_params_default = [
         '--data_path', get_input_dataset(ds, f'datasets/cnn_dm/preprocessed/bart/', "data_path"),
-        '--config_path', 'config-ortds.yaml',
+        '--config_path', 'config-ort.yaml',
     ]

     return all_params_default

 from azureml.core import Environment

 # Creates the environment inside a Docker container.
-pytorch_env = Environment(name='myEnv')
+pytorch_env = Environment(name='pymarlin-ort-ds')
 pytorch_env.docker.enabled = True
 # docker file in this directory built for your convenience
-pytorch_env.docker.base_image = "pymarlin/base-gpu:cuda11.1.cudnn8.ds.ort"
+pytorch_env.docker.base_image = "pymarlin/pymarlin.cuda11.1"

 pytorch_env.python.user_managed_dependencies = True
 pytorch_env.python.interpreter_path = '/opt/miniconda/bin/python'

 mpi = MpiConfiguration()
-#NCv3_24rs - 4 16GB V100 GPU's per node
-mpi.process_count_per_node = 4
-mpi.node_count = 2
+# NDv2, 8 GPU's per node
+mpi.process_count_per_node = 8
+mpi.node_count = 1

+# ds.upload_files(['local path to preprocessed data'], 'datasets/cnn_dm/preprocessed/bart')

-script = "train_ortds.py"
+script = "train.py"
 codepath = '..'

 config = ScriptRunConfig(source_directory=codepath,
@@ -57,14 +57,13 @@ def get_args(outputSuffix="deepspeed_ort_amp_nopadding_v100_8"):
                          environment=pytorch_env,
                          distributed_job_config=mpi)

-experiment_name = 'pymarlin_summarization_bart_ortds'
+experiment_name = 'summarization_bart_ort_backend'
 experiment = Experiment(ws, name=experiment_name)

 run = experiment.submit(config)

 run.tag('nodes', f'{mpi.node_count}')
 run.tag('process_count_per_node', f'{mpi.process_count_per_node}')
-run.tag('notes', '2 node with ort+ds')

 print("Submitted run")
 print(f"\n{run.get_portal_url()}")
33 changes: 33 additions & 0 deletions examples/cnndailymail_text_summarization/config-ort.yaml
@@ -0,0 +1,33 @@
data_path: 'D:/data/cnn_cln'

trainer:
  max_train_steps_per_epoch : null # Maximum train steps per epoch.
  max_val_steps_per_epoch : null # Maximum validation steps per epoch.
  train_batch_size: 32 # Training global batch size.
  val_batch_size: 32 # Validation batch size per GPU.
  epochs: 3 # Total epochs to run.
  gpu_batch_size_limit : 4 # Max limit for GPU batch size during training.
  disable_tqdm : False
  writers: ["stdout", "aml", "tensorboard"]
  backend: 'ddp-amp-ort'
module:
  max_length_encoder : 1024
  max_length_decoder : 128
wrt:
  tb_log_dir : 'logs'
stat:
  log_steps : 50
chkp:
  checkpoint : True
  delete_existing_checkpoints: False
  save_dir: 'outputs' # aml output path. does not require mounting
  load_dir: null
  load_filename: null

# add more from BartForConditionalGeneration.generate?
generate:
  max_length: 128
  do_sample : False
  num_beams : 5
# support everything in a yaml. ignore (print warning) everything that's not present.
# Do not add the requirement to define anything in the parser other than yamls
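
For orientation, a minimal sketch of how a config like this reaches the new factory. This is a stand-in using plain PyYAML rather than pymarlin's own config parser; only the `trainer.backend` key and the `build_trainer_backend` entry point come from this PR, the rest is assumed.

import yaml
from pymarlin.core.trainer_backend_factory import build_trainer_backend

with open('config-ort.yaml') as f:
    config = yaml.safe_load(f)

# 'ddp-amp-ort' maps to DDPORTTrainerBackendFactory(SingleProcessAmp):
# an AMP single-process backend, wrapped by ORT, wrapped by DDP.
# Any single-process constructor arguments would pass through *args/**kwargs.
backend = build_trainer_backend(config['trainer']['backend'])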
1 change: 1 addition & 0 deletions examples/cnndailymail_text_summarization/config-ortds.yaml
@@ -9,6 +9,7 @@ trainer:
   train_batch_size: 32 # Training global batch size.
   val_batch_size: 32 # Validation batch size per GPU.
   epochs: 3 # Total epochs to run.
+  ort: True
   gpu_batch_size_limit : 4 # Max limit for GPU batch size during training.
   disable_tqdm : True
   writers: ["stdout", "aml", "tensorboard"]
@@ -46,6 +46,6 @@ def get_core_model(model, deepspeed_flag=False, ort_flag=False):
     if deepspeed_flag:
         module = module.module
     if ort_flag:
-        module = module._original_module
+        module = module._module_metadata.original_module

     return module
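
A hypothetical call for context (`trained_module` and the output path are illustrative, not from this PR): after training with the ORT backend, the raw Hugging Face model can be recovered and saved.

# Unwrap the ORTModule wrapper to reach the original BART model.
bart = get_core_model(trained_module, ort_flag=True)
bart.save_pretrained('outputs/bart-large')  # illustrative path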
4 changes: 2 additions & 2 deletions examples/cnndailymail_text_summarization/train.py
@@ -40,8 +40,8 @@ def __init__(
         generate_kwargs = {}
     ):
         super().__init__()
-        self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
-        self.tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-base")
+        self.model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
+        self.tokenizer = BartTokenizerFast.from_pretrained("facebook/bart-large")
         self.max_lr = max_lr
         self.max_length_encoder = max_length_encoder
         self.max_length_decoder = max_length_decoder
52 changes: 52 additions & 0 deletions pymarlin/core/ort_trainer_backend.py
@@ -0,0 +1,52 @@
from .trainer_backend import *
import sys
from pymarlin.utils.logger import getlogger
import torch.nn as nn

class ORTTrainerBackend(AbstractTrainerBackendDecorator):
    def __init__(self, trainer_backend):
        super().__init__(trainer_backend)
        self.logger = getlogger(__file__, log_level='DEBUG')

    # TODO: add these under TrainerBackendDecoratorPassThrough, which ORT, Opacus can inherit from
    # so that the DDP backend can get/set from the wrapped SingleProcess*
    def __getattribute__(self, name):
        # self.logger.debug(f'__getattribute__(name={name})')
        if name in ('trainer_backend', 'init', '__init__', 'logger', '_core_model', 'core_model'):
            return super().__getattribute__(name)
        else:
            return self.trainer_backend.__getattribute__(name)

    def __setattr__(self, name, value):
        # self.logger.debug(f'__setattr__(name={name}, value={value})')
        if name in ('trainer_backend', 'init', '__init__', 'logger', '_core_model', 'core_model'):
            super().__setattr__(name, value)
        else:
            self.trainer_backend.__setattr__(name, value)

    @property
    def core_model(self):
        return self._core_model

    @core_model.setter
    def core_model(self, model):
        self._core_model = model

    def init(self, args: TrainerBackendArguments):
        super().init(args)
        try:
            from torch_ort import ORTModule
        except ImportError:  # a bare except would also swallow SystemExit/KeyboardInterrupt
            self.logger.error("could not import ORTModule")
            sys.exit(1)

        # assert on the condition directly: assert(cond, msg) builds a 2-tuple, which is always truthy
        assert hasattr(self.trainer_backend.model, 'model'), 'self.trainer_backend.model.model does not exist'
        assert isinstance(self.trainer_backend.model.model, nn.Module), "expected module_interface.model of type torch.nn.Module"

        # get the reference and save it before the ORTModule wrap
        self.core_model = self.trainer_backend.model.model
        module = self.trainer_backend.model  # TODO: should we change trainer_backend.model to module?
        module.get_core_model = lambda: self.core_model

        self.logger.info("Wrapping trainer_backend.model.model")
        self.trainer_backend.model.model = ORTModule(self.trainer_backend.model.model)
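
Since `__getattribute__` intercepts every lookup, the pass-through pattern above is easy to get wrong. A self-contained toy version (not pymarlin code, names are illustrative) showing why the wrapper's own attributes must be resolved via `object.__getattribute__` while everything else forwards to the wrapped object:

class Inner:
    def __init__(self):
        self.steps = 0

class PassThrough:
    _own = ('inner',)

    def __init__(self, inner):
        # bypass our own __setattr__ so 'inner' lands on the wrapper itself
        object.__setattr__(self, 'inner', inner)

    def __getattribute__(self, name):
        if name in PassThrough._own or name.startswith('__'):
            return object.__getattribute__(self, name)
        # everything else is served by the wrapped object
        return getattr(object.__getattribute__(self, 'inner'), name)

    def __setattr__(self, name, value):
        # all writes go to the wrapped object
        setattr(object.__getattribute__(self, 'inner'), name, value)

wrapper = PassThrough(Inner())
wrapper.steps = 5  # forwarded write
assert wrapper.inner.steps == 5 and wrapper.steps == 5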
35 changes: 3 additions & 32 deletions pymarlin/core/trainer_backend.py
@@ -13,7 +13,7 @@
 Alternatively a user can provide a custom `TrainerBackend`.
 """
 from tqdm.auto import tqdm
-from abc import ABC, abstractmethod
+from abc import ABC, abstractmethod, abstractproperty
 import dataclasses
 from typing import Iterable, Optional, Union
 import warnings
@@ -31,32 +31,12 @@
     SequentialDistributedSampler,
 )

-
 try:
     from apex import amp
 except ImportError:
     amp = None
 from functools import wraps

-def build_trainer_backend(trainer_backend_name, *args, **kwargs):
-    """Factory for trainer_backends
-
-    Args:
-        trainer_backend_name (str): TrainerBackend Name. Possible choices are currently: sp, sp-amp, sp-amp-apex, ddp, ddp-amp, ddp-amp-apex
-        args (sequence): TrainerBackend positional arguments
-        kwargs (dict): TrainerBackend keyword arguments
-    """
-    factory_dict = {
-        "sp": SingleProcess,
-        "sp-amp": SingleProcessAmp,
-        "sp-amp-apex": SingleProcessApexAmp,
-        "ddp": DDPTrainerBackendFactory(SingleProcess),
-        "ddp-amp": DDPTrainerBackendFactory(SingleProcessAmp),
-        "ddp-amp-apex": DDPTrainerBackendFactory(SingleProcessApexAmp),
-    }
-    return factory_dict[trainer_backend_name](*args, **kwargs)
-
-
 @dataclasses.dataclass
 class TrainerBackendArguments:
     """
@@ -106,13 +86,11 @@ def get_batches_completed(self):
     def get_global_steps_completed(self):
         pass

-    @property
-    @abstractmethod
+    @abstractproperty
     def train_sampler(self):
         return RandomSampler

-    @property
-    @abstractmethod
+    @abstractproperty
     def val_sampler(self):
         return SequentialSampler
@@ -712,10 +690,3 @@ def train_sampler(self):
     @property
     def val_sampler(self):
         return SequentialDistributedSampler
-
-def DDPTrainerBackendFactory(trainer_backend_cls): # pylint: disable=invalid-name
-    def create(*args, gather_frequency: Optional[int] = None, **kwargs):
-        # pull out args to DDPTrainerBackend if needed here.
-        return DDPTrainerBackend(trainer_backend_cls(*args, **kwargs), gather_frequency=gather_frequency)
-
-    return create
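
For reference, the two sampler-property spellings are functionally equivalent; `abc.abstractproperty` is simply the shorter form, and has been deprecated since Python 3.3 in favor of the stacked decorators. A minimal illustration (names are illustrative):

from abc import ABC, abstractmethod, abstractproperty

class BackendA(ABC):
    @abstractproperty          # shorter, deprecated spelling
    def train_sampler(self): ...

class BackendB(ABC):
    @property                  # modern equivalent
    @abstractmethod
    def train_sampler(self): ...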
37 changes: 37 additions & 0 deletions pymarlin/core/trainer_backend_factory.py
@@ -0,0 +1,37 @@
from .trainer_backend import *
from .ort_trainer_backend import ORTTrainerBackend

def build_trainer_backend(trainer_backend_name, *args, **kwargs):
    """Factory for trainer_backends

    Args:
        trainer_backend_name (str): TrainerBackend name. Possible choices are currently: sp, sp-amp, sp-amp-apex, ddp, ddp-amp, ddp-amp-ort, ddp-amp-apex
        args (sequence): TrainerBackend positional arguments
        kwargs (dict): TrainerBackend keyword arguments
    """
    factory_dict = {
        "sp": SingleProcess,
        "sp-amp": SingleProcessAmp,
        "sp-amp-apex": SingleProcessApexAmp,
        "ddp": DDPTrainerBackendFactory(SingleProcess),
        "ddp-amp-ort": DDPORTTrainerBackendFactory(SingleProcessAmp),
        "ddp-amp": DDPTrainerBackendFactory(SingleProcessAmp),
        "ddp-amp-apex": DDPTrainerBackendFactory(SingleProcessApexAmp),
    }
    return factory_dict[trainer_backend_name](*args, **kwargs)

def DDPTrainerBackendFactory(trainer_backend_cls):  # pylint: disable=invalid-name
    def create(*args, gather_frequency: Optional[int] = None, **kwargs):
        # pull out args to DDPTrainerBackend if needed here.
        return DDPTrainerBackend(trainer_backend_cls(*args, **kwargs), gather_frequency=gather_frequency)

    return create

# testing TODO: refactor factory logic to do hierarchical decoration (sp->ort->ddp/deepspeed)
def DDPORTTrainerBackendFactory(trainer_backend_cls):  # pylint: disable=invalid-name
    def create(*args, gather_frequency: Optional[int] = None, **kwargs):
        return DDPTrainerBackend(
            ORTTrainerBackend(trainer_backend_cls(*args, **kwargs)),
            gather_frequency=gather_frequency)

    return create
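
A usage sketch for the new key (assuming, as with the existing backends, that any single-process constructor arguments pass through `*args`/`**kwargs`):

from pymarlin.core.trainer_backend_factory import build_trainer_backend

# DDP(ORT(SingleProcessAmp)): the ORT decorator wraps model.model in
# ORTModule at init time, and DDP handles distribution on top.
backend = build_trainer_backend('ddp-amp-ort')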