[WIP] Hack my way to get OPT running #301

Draft · wants to merge 25 commits into main
1 change: 1 addition & 0 deletions megatron/arguments.py
@@ -388,6 +388,7 @@ def _add_network_size_args(parser):
group.add_argument('--onnx-safe', type=bool, required=False,
help='Use workarounds for known problems with '
'Torch ONNX exporter')
group.add_argument('--relu', action='store_true', help='Use ReLU instead of GELU as the MLP activation (OPT was trained with ReLU).')
group.add_argument('--bert-no-binary-head', action='store_false',
help='Disable BERT binary head.',
dest='bert_binary_head')
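OPT's feed-forward layers use ReLU rather than GELU, which is what this new flag enables; the matching activation switch is in the megatron/model/transformer.py hunk further down. A minimal sketch of the flag-to-activation wiring, with illustrative defaults only (Megatron's real default GELU lives in megatron.model.utils):

```python
# Illustrative sketch only; the real wiring is split across megatron/arguments.py
# and megatron/model/transformer.py in this PR.
import argparse
import torch.nn.functional as F

parser = argparse.ArgumentParser()
parser.add_argument('--relu', action='store_true',
                    help='use ReLU as the MLP activation')
args = parser.parse_args(['--relu'])

# Mirrors the selection in ParallelMLP.__init__:
activation_func = F.relu if args.relu else F.gelu
print(activation_func.__name__)  # relu
```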
8 changes: 7 additions & 1 deletion megatron/checkpointing.py
@@ -273,7 +273,13 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True

if args.deepspeed:
load_optimizer_states = False if args.no_load_optim else True
loaded_dir, state_dict = model[0].load_checkpoint(load_dir, load_optimizer_states=load_optimizer_states)
loaded_dir, state_dict = model[0].load_checkpoint(
load_dir,
tag=".",
load_optimizer_states=load_optimizer_states,
load_lr_scheduler_states=load_optimizer_states,
load_module_only=not load_optimizer_states
)
if loaded_dir is None:
print_rank_0('WARNING: could not find the metadata file {} '.format(
load_dir))
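The extra keyword arguments map onto DeepSpeed's `DeepSpeedEngine.load_checkpoint` API: `tag` names the checkpoint sub-directory (so `"."` effectively points at files sitting directly under `load_dir`), and `load_module_only=True` restores just the module weights when optimizer and LR-scheduler state are skipped. A hedged sketch of the same pattern, assuming `engine` is an already-initialised DeepSpeed engine:

```python
# Sketch, not this repo's code: module-only vs. full resume with a DeepSpeed engine.
def load_ds_checkpoint(engine, load_dir: str, resume_training: bool):
    load_optim = resume_training  # for evaluation we only need the weights
    loaded_dir, client_state = engine.load_checkpoint(
        load_dir,
        tag=".",                           # checkpoint files sit directly in load_dir
        load_optimizer_states=load_optim,
        load_lr_scheduler_states=load_optim,
        load_module_only=not load_optim,   # weights only when not resuming training
    )
    if loaded_dir is None:
        raise FileNotFoundError(f"no DeepSpeed checkpoint found under {load_dir}")
    return client_state
```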
2 changes: 2 additions & 0 deletions megatron/model/transformer.py
@@ -86,6 +86,8 @@ def __init__(self, init_method, output_layer_init_method):
self.activation_func = openai_gelu
elif args.onnx_safe:
self.activation_func = erf_gelu
elif args.relu:
self.activation_func = F.relu

# Project back to h.
self.dense_4h_to_h = mpu.RowParallelLinear(
2 changes: 1 addition & 1 deletion megatron/tokenizer/tokenizer.py
@@ -327,7 +327,7 @@ def __init__(self, tokenizer_name_or_path, vocab_extra_ids):
if vocab_extra_ids > 0:
# TODO @thomasw21 we might need to concatenate to a pre-existing list?
hf_tokenizer_kwargs["additional_special_tokens"] = [f"<extra_id_{_id}>" for _id in range(vocab_extra_ids)]
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, **hf_tokenizer_kwargs)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, **hf_tokenizer_kwargs, use_fast=False)
self.encoder = self.tokenizer.get_vocab()
self.decoder = {v: k for k, v in self.encoder.items()}
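Forcing the slow (pure-Python) tokenizer avoids discrepancies that can appear between a checkpoint's fast and slow tokenizers, for example around added special tokens. A quick way to check whether the two agree for a given checkpoint (the model name below is only a placeholder):

```python
from transformers import AutoTokenizer

name = "facebook/opt-125m"  # placeholder; use the tokenizer_name_or_path passed to Megatron
slow = AutoTokenizer.from_pretrained(name, use_fast=False)
fast = AutoTokenizer.from_pretrained(name, use_fast=True)

sample = "Hack my way to get OPT running"
print(slow.encode(sample) == fast.encode(sample))  # a mismatch here would justify use_fast=False
```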

15 changes: 14 additions & 1 deletion tasks/eval_harness/download.py
@@ -14,7 +14,20 @@

def main():
task_list = ALL_TASKS if args.task_list == 'all' else args.task_list.split(',')
tasks.get_task_dict(task_list)
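# Try each task on its own and collect failures, so one bad download does not abort the rest.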
task_and_exceptions = []
for task in task_list:
print("--------")
print(f"Downloading dataset for task: {task}")
try:
tasks.get_task_dict([task])
except Exception as e:
task_and_exceptions.append((task, e))

for task, exception in task_and_exceptions:
print("=======================================")
print(task)
print(exception)


if __name__ == '__main__':
main()
67 changes: 15 additions & 52 deletions tasks/eval_harness/evaluate.py
@@ -3,6 +3,9 @@
import os
import sys
import datetime

from megatron.checkpointing import load_checkpoint

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir,os.path.pardir)))

@@ -82,7 +85,7 @@ def loglikelihood(self, requests):
# end of text as context
context_enc = [self.EOT_TOKEN_ID]
else:
context_enc = self.tokenizer_encode(context)
# OPT was trained with its </s> token prepended to every sequence, so keep the
# end-of-text id in front of the encoded context as well.
context_enc = [self.EOT_TOKEN_ID, *self.tokenizer_encode(context)]

continuation_enc = self.tokenizer_encode(continuation)

@@ -194,6 +197,8 @@ def create_model_inputs(self, tokens):
prefix_indices=None,
loss_on_targets_only=False)

# OPT's learned positional embeddings carry an offset of 2 (two extra leading rows),
# so shift the position ids to match the converted weights.
position_ids = position_ids + 2

return (tokens, position_ids, attention_mask), (tokens, loss_mask)

def _model_call(self, inps):
@@ -217,9 +222,10 @@ def _model_call(self, inps):


if output is not None:
output = torch.cat(output, 0)[:len(inps)]
else:
output = None
# Taking log_softmax and moving each shard to CPU right away keeps the
# full-vocabulary logits off the GPU, which allows a larger micro_bs_multiplier.
if self.args.offloadearly:
output = torch.cat([F.log_softmax(o, dim=-1).cpu() for o in output[:len(inps)]], 0)
else:
output = torch.cat(output, 0)[:len(inps)]

# hack #2 for adaptive_seq_len to work as total_loss gets appended to and shapes aren't the same
if args.adaptive_seq_len:
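Rough arithmetic on why offloading the log-probabilities per micro-batch helps (assumed shapes, for illustration only): the full-vocabulary activations are by far the largest tensors kept alive during evaluation.

```python
# Back-of-the-envelope memory for the logits of a single micro-batch (illustrative values):
micro_batch = 16      # e.g. micro batch size * micro_bs_multiplier
seq_len = 2048
vocab = 50272         # OPT vocabulary size
bytes_per_value = 4   # fp32 log-probs after log_softmax

gib = micro_batch * seq_len * vocab * bytes_per_value / 2**30
print(f"{gib:.1f} GiB held on GPU without --offloadearly")  # ~6.1 GiB
```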
@@ -288,58 +294,17 @@ def override_args(args, override_args, skip_keys, skip_if_specified_keys):
# We then use the megatron deepspeed converter to load the deepspeed checkpoints as if they were megatron checkpoints.
def load_ds_checkpoint_and_setup_megatron(args):
_print_args = megatron.arguments._print_args
megatron.arguments._print_args = lambda *_args, **kwarg: None

if not os.path.exists(args.load):
raise ValueError(f"checkpoint path {args.load} doesn't exist")

ds_checkpoint = DeepSpeedCheckpoint(args.load,
tp_degree=args.tensor_model_parallel_size,
pp_degree=args.pipeline_model_parallel_size)


cp_args = ds_checkpoint.get_args()
# Merge the current args with the checkpoint args.
skip_keys = [
'abort_on_unmet_fused_kernel_constraints',
'batch_size',
'data_parallel_size',
'deepspeed',
'deepspeed_config',
'device_count',
'global_batch_size',
'inference',
'iteration',
'load',
'local_rank',
'micro_batch_size',
'pipeline_model_parallel_size',
'rampup_batch_size',
'rank',
'tensor_model_parallel_size',
'tensorboard_dir',
'world_size',
]

skip_if_specified = ['merge_file', 'vocab_file']

if args.eval_fp32:
cp_args.fp16 = False
cp_args.bf16 = False
cp_args.params_dtype = torch.float32

override_args(args, cp_args, skip_keys, skip_if_specified)

# stop megatron from reparsing the arguments.
megatron.global_vars._parse_args = lambda *_args, **kwarg: args
megatron.global_vars._GLOBAL_ARGS = args

initialize_megatron()
torch.distributed.barrier()

# Initializing megatron will update eg. tokenizer size. Override again.
override_args(args, cp_args, skip_keys, skip_if_specified)

# print final arguments.
_print_args(args)
if args.deepspeed:
@@ -354,15 +319,12 @@ def load_ds_checkpoint_and_setup_megatron(args):
import deepspeed
deepspeed.runtime.state_dict_factory.MegatronSDLoader.sanity_check = lambda self, ckpt_file_name: None


cp_path = args.load
args.load = None
model, _, _ = setup_model_and_optimizer(model_provider)
zero_enabled = model[0]._config.zero_enabled
model[0]._config.zero_enabled = False
load_checkpoint(model, optimizer=None, lr_scheduler=None)
model[0]._config.zero_enabled = zero_enabled
model = model[0]
zero_enabled = model._config.zero_enabled
model._config.zero_enabled = False
# Restore module weights only, straight from the checkpoint files under cp_path
# (tag='.' points directly at that directory); optimizer and LR-scheduler state are not needed for eval.
_, _ = model.load_checkpoint(cp_path, tag='.', load_optimizer_states=False, load_lr_scheduler_states=False, load_module_only=True)
model._config.zero_enabled = zero_enabled
else:
model = get_model(model_provider)[0]
# Initialize megatron model using the parsed state dict.
@@ -390,6 +352,7 @@ def tasks_args(parser):
group.add_argument('--intermed_results', default = False, action='store_true', help='Whether to print & write intermediate results for each task')
group.add_argument('--bootstrap_iters', type=int, default=100000, help='How many iterations to use for stderr estimation')
group.add_argument('--micro_bs_multiplier', type=int, default=1, help='Increase the global batch size to remove bubble when pipeline parallel')
group.add_argument('--offloadearly', default = False, action='store_true', help='Offloads logits to CPU earlier to allow using a higher micro_bs_multiplier - Speeds up eval by up to 1.5x for 176B')
return parser

from megatron.global_vars import _parse_args
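The `position_ids + 2` shift earlier in this file matches how OPT's learned positional embeddings are laid out: the embedding table has two extra leading rows (a padding artifact from training), so position i is stored at row i + 2, and the converted Megatron weights keep that layout. An illustrative sketch:

```python
import torch

# OPT-style positional lookup (illustrative): the table is built with 2 extra rows,
# and every position id is shifted by 2 before indexing into it.
max_positions, hidden = 2048, 8
pos_table = torch.nn.Embedding(max_positions + 2, hidden)

position_ids = torch.arange(5)         # positions 0..4
pos_emb = pos_table(position_ids + 2)  # rows 2..6 hold the embeddings for positions 0..4
print(pos_emb.shape)                   # torch.Size([5, 8])
```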