reduce memory usage
zzhhjjj committed May 3, 2024
1 parent 673237b commit cb5d5af
Showing 3 changed files with 45 additions and 10 deletions.
6 changes: 3 additions & 3 deletions examples/config_train_llama.py
@@ -22,7 +22,6 @@
 from nanotron.logging import human_format
 
 model_config = LlamaConfig(
-    # Config for a tiny model model with 1.62M parameters
     bos_token_id=1,
     eos_token_id=2,
     hidden_act="silu",
@@ -82,8 +81,9 @@
 )
 
 # Tokens per batch = micro_batch_size * dp * sequence_length * batch_accumulation_per_replica
-# 128 * 4 * 512 * 4 = 1,048,576. A global batch-size of 1M tokens.
-tokens = TokensArgs(sequence_length=512, train_steps=200, micro_batch_size=128, batch_accumulation_per_replica=4)
+# 16 * 4 * 512 * 32 = 1,048,576. -> A global batch-size of 1M tokens.
+# train 200 steps to observe the loss
+tokens = TokensArgs(sequence_length=512, train_steps=200, micro_batch_size=16, batch_accumulation_per_replica=32)
 
 checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints"
 os.makedirs(checkpoints_path, exist_ok=True)
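The arithmetic behind this change is easy to check: with dp=4 data-parallel replicas (the value used in the config comment above, not something introduced by this diff), both the old and the new settings keep the global batch at 1,048,576 tokens per step. Only the per-forward-pass footprint shrinks: 16 sequences per replica instead of 128, traded against 8x more gradient-accumulation steps. A minimal sketch of that check in plain Python, with no nanotron imports:

```python
# Rough check of the global-batch arithmetic from the config comment.
# Assumes dp = 4 data-parallel replicas, as stated in that comment.
def tokens_per_step(micro_batch_size: int, dp: int, sequence_length: int, grad_accum: int) -> int:
    return micro_batch_size * dp * sequence_length * grad_accum

old = tokens_per_step(micro_batch_size=128, dp=4, sequence_length=512, grad_accum=4)
new = tokens_per_step(micro_batch_size=16, dp=4, sequence_length=512, grad_accum=32)

assert old == new == 1_048_576  # same 1M-token global batch before and after
# Memory drops because activations are held for 16 sequences at a time instead of 128.
```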
4 changes: 2 additions & 2 deletions examples/config_train_llama.yaml
@@ -88,10 +88,10 @@ tokenizer:
   tokenizer_name_or_path: gpt2
   tokenizer_revision: null
 tokens:
-  batch_accumulation_per_replica: 4
+  batch_accumulation_per_replica: 32
   limit_test_batches: 0
   limit_val_batches: 0
-  micro_batch_size: 128
+  micro_batch_size: 16
   sequence_length: 512
   train_steps: 200
   val_check_interval: -1
45 changes: 40 additions & 5 deletions tests/test_llama.py
@@ -46,9 +46,19 @@ def extract_loss(line):
 
 
 def test_train_llama():
-    # create CONFIG_FILE
+    # create config file
     cmd = f"python {CREATE_CONFIG_FILE}"
-    subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    # Read and print output in real-time
+    while True:
+        line = process.stdout.readline()
+        if process.poll() is not None and line == b"":
+            break
+        if line:
+            print(line.decode("utf-8"), end="")
+
+    process.wait()  # Wait for the process to finish
+    assert process.returncode == 0
 
     # run training
     # set DISABLE_FLASH_ATTENTION=1 to replace flash attention implementations
@@ -81,9 +91,18 @@ def test_train_llama():
 
 # also run the tiny llama example. Only want to assert it can be ran.
 def test_tiny_llama():
-    # create CONFIG_FILE
+    # create config file
     cmd = f"python {TINY_LLLAMA_CREATE_CONFIG_FILE}"
-    subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    # Read and print output in real-time
+    while True:
+        line = process.stdout.readline()
+        if process.poll() is not None and line == b"":
+            break
+        if line:
+            print(line.decode("utf-8"), end="")
+    process.wait()  # Wait for the process to finish
+    assert process.returncode == 0
 
     # run training
     # set DISABLE_FLASH_ATTENTION=1 to replace flash attention implementations
@@ -104,8 +123,24 @@ def test_tiny_llama():
 
 
 if __name__ == "__main__":
+    # create config file
     cmd = f"python {CREATE_CONFIG_FILE}"
-    subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    try:
+        # Read and print output in real-time
+        while True:
+            line = process.stdout.readline()
+            if process.poll() is not None and line == b"":
+                break
+            if line:
+                print(line.decode("utf-8"), end="")
+        process.wait()  # Wait for the process to finish
+        assert process.returncode == 0
+    except AssertionError:
+        print("Command failed with exit code:", process.returncode)
+        exit()
+    else:
+        print("Config created successfully.")
 
     # run training
     # set DISABLE_FLASH_ATTENTION=1 to replace flash attention implementations
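The same read-stream-assert loop now appears three times in this test file. As a side note (not part of this commit), it could be factored into one helper; a minimal sketch under that assumption, with the name run_and_stream and the echo demo purely hypothetical, following the diff's own approach of polling the process and echoing each stdout line as it arrives:

```python
import subprocess

def run_and_stream(cmd: str) -> int:
    """Run a shell command, echo its output line by line, and return its exit code."""
    process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    while True:
        line = process.stdout.readline()
        # readline() returns b"" only at EOF; also require the process to have exited.
        if process.poll() is not None and line == b"":
            break
        if line:
            print(line.decode("utf-8"), end="")
    return process.wait()

if __name__ == "__main__":
    # Trivial self-contained check of the helper itself.
    assert run_and_stream("echo streaming works") == 0
```

In the tests this would reduce each block to a single call, e.g. assert run_and_stream(f"python {CREATE_CONFIG_FILE}") == 0, matching the return-code assertions the commit adds.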
