add measure bandwidth
NouamaneTazi committed Nov 27, 2024
1 parent 8aa249e commit e3b886c
Showing 5 changed files with 190 additions and 34 deletions.
2 changes: 1 addition & 1 deletion src/nanotron/config/config.py
@@ -364,7 +364,7 @@ def __post_init__(self):

# Some final sanity checks across separate arguments sections:
if self.profiler is not None and self.profiler.profiler_export_path is not None:
assert self.tokens.train_steps < 10
assert self.tokens.train_steps < 11

if self.optimizer is not None and self.optimizer.learning_rate_scheduler.lr_decay_steps is None:
self.optimizer.learning_rate_scheduler.lr_decay_steps = (
82 changes: 58 additions & 24 deletions src/nanotron/helpers.py
@@ -472,7 +472,7 @@ def get_profiler(config: Config):
if config.profiler is not None:
if config.profiler.profiler_export_path is not None:
on_trace_ready = tensorboard_trace_handler(
config.profiler.profiler_export_path / datetime.now().strftime("%Y%m%d-%H%M%S")
config.profiler.profiler_export_path / datetime.now().strftime("%Y%m%d-%H%M%S-" + config.general.run)
)
else:
on_trace_ready = None
@@ -605,7 +605,10 @@ def create_table_log(
LogItem("mTFLOPs", model_tflops, ".2f"),
LogItem("hTFLOPs", hardware_tflops, ".2f"),
LogItem("tok/s/gpu", tokens_per_sec / parallel_context.world_pg.size(), ".2f"),
LogItem("Bandwidth (GB/s)", bandwidth, ".2f"),
LogItem("AllReduce (GB/s)", bandwidth["all_reduce"], ".2f"),
LogItem("ReduceScatter (GB/s)", bandwidth["reduce_scatter"], ".2f"),
LogItem("AR Intra-node (GB/s)", bandwidth["all_reduce_intranode"], ".2f"),
LogItem("RS Intra-node (GB/s)", bandwidth["reduce_scatter_intranode"], ".2f"),
LogItem("Mem Alloc (GB)", torch.cuda.max_memory_allocated() / 1024**3, ".2f"),
LogItem("Mem Res (GB)", torch.cuda.max_memory_reserved() / 1024**3, ".2f"),
]
@@ -625,27 +628,57 @@ def create_table_output(table_log, column_widths):


def write_to_csv(csv_filename, table_log, model_tflops, slurm_job_id):
if not os.path.exists(csv_filename):
os.makedirs(os.path.dirname(csv_filename), exist_ok=True)
with open(csv_filename, mode="w") as fo:
writer = csv.writer(fo)
writer.writerow([item.tag for item in table_log])
writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
# elif model_tflops > 0:
# # replace line with same job_id
# with open(csv_filename, mode="r") as fi:
# lines = fi.readlines()
# with open(csv_filename, mode="w") as fo:
# writer = csv.writer(fo)
# for line in lines:
# if line.startswith(slurm_job_id):
# writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
# else:
# fo.write(line)
else:
with open(csv_filename, mode="a") as fo:
writer = csv.writer(fo)
writer.writerow([f"{item.scalar_value:{item.log_format}}" for item in table_log])
"""Write benchmark results to a CSV file with file locking."""
try:
# Check if csv_filename is valid
if not csv_filename:
logger.warning("No benchmark CSV path specified - skipping CSV output")
return

# Create output directory if needed
csv_dir = os.path.dirname(csv_filename)
if csv_dir: # Only try to create directory if path has a directory component
os.makedirs(csv_dir, exist_ok=True)

# Format row data
header = [item.tag for item in table_log]
row = [f"{item.scalar_value:{item.log_format}}" for item in table_log]

# Use file locking to handle concurrent writes
lock_file = f"{csv_filename}.lock"
max_attempts = 10
attempt = 0

while attempt < max_attempts:
try:
# Try to create lock file
with open(lock_file, "x") as _:
try:
# We got the lock, do the write
write_mode = "w" if not os.path.exists(csv_filename) else "a"
with open(csv_filename, mode=write_mode, newline="") as f:
writer = csv.writer(f)
if write_mode == "w":
writer.writerow(header)
writer.writerow(row)
f.flush()
os.fsync(f.fileno())
break
finally:
# Always remove lock file
os.remove(lock_file)
except FileExistsError:
# Another process has the lock, wait and retry
attempt += 1
time.sleep(0.1) # Wait 100ms before retrying

if attempt == max_attempts:
logger.error(f"Failed to acquire lock for {csv_filename} after {max_attempts} attempts")

except OSError as e:
logger.error(f"Failed to write benchmark results to {csv_filename}: {str(e)}")
# Don't raise the error - just log it and continue
return
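
For reference, the locking scheme above relies on the fact that opening a file with mode "x" (exclusive create) succeeds in exactly one process and raises FileExistsError everywhere else, which is what lets several benchmark jobs append to the same CSV safely. A minimal standalone sketch of that pattern (the with_file_lock helper below is hypothetical, not part of this commit):

import os
import time

def with_file_lock(lock_path, fn, max_attempts=10, wait_s=0.1):
    """Run fn() while holding an exclusive lock based on exclusive file creation."""
    for _ in range(max_attempts):
        try:
            with open(lock_path, "x"):      # only one process can create the lock file
                try:
                    return fn()
                finally:
                    os.remove(lock_path)    # always release the lock
        except FileExistsError:
            time.sleep(wait_s)              # another process holds the lock; retry
    raise TimeoutError(f"could not acquire lock {lock_path}")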


def log_throughput(
@@ -654,7 +687,7 @@ def log_throughput(
model_tflops=0,
hardware_tflops=0,
tokens_per_sec=0,
bandwidth=0,
bandwidth={"all_reduce": 0, "reduce_scatter": 0, "all_reduce_intranode": 0, "reduce_scatter_intranode": 0},
):
slurm_job_id = os.environ.get("SLURM_JOB_ID", "N/A")

@@ -673,6 +706,7 @@

if dist.get_rank(parallel_context.world_pg) == 0:
write_to_csv(config.general.benchmark_csv_path, table_log, model_tflops, slurm_job_id)
dist.barrier(group=parallel_context.world_pg)


def compute_remain_train_steps_of_a_data_stage_from_ckp(
13 changes: 8 additions & 5 deletions src/nanotron/parallel/context.py
@@ -1,5 +1,5 @@
import os
from typing import Literal, Tuple, Annotated
from typing import Literal, Tuple

import numpy as np
import torch
@@ -21,6 +21,7 @@ def __init__(
"""Initialize parallel context."""
num_gpus_per_model = tensor_parallel_size * pipeline_parallel_size * expert_parallel_size
world_size = int(os.environ["WORLD_SIZE"])
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "8")) if world_size > 8 else world_size

assert (
world_size % data_parallel_size == 0
@@ -42,6 +43,8 @@ def __init__(
self.pipeline_parallel_size = pipeline_parallel_size
self.data_parallel_size = data_parallel_size
self.expert_parallel_size = expert_parallel_size
self.world_size = world_size
self.local_world_size = local_world_size

self._groups = {}

@@ -52,7 +55,6 @@ def __init__(
if not dist.is_initialized():
dist.initialize_torch_distributed()

world_size = int(os.getenv("WORLD_SIZE", "1"))
ranks = list(range(world_size))
process_group = dist.new_group(
ranks=ranks,
@@ -65,8 +67,7 @@
def _init_parallel_groups(self):
"""Initialize 3D parallelism's all process groups."""
dist.barrier()
world_size = int(os.environ["WORLD_SIZE"])
ranks = np.arange(0, world_size).reshape(
ranks = np.arange(0, self.world_size).reshape(
(
self.expert_parallel_size,
self.pipeline_parallel_size,
@@ -75,6 +76,8 @@ def _init_parallel_groups(self):
)
)
self.world_ranks_to_pg = {}
self.local_pg = self.create_new_group(ranks.reshape((-1, self.local_world_size)))
assert int(os.environ.get("LOCAL_RANK")) == dist.get_rank(self.local_pg), "Local rank mismatch"

# Relevant process groups containing the current rank
self.tp_pg = self.create_new_group(ranks.transpose((0, 1, 2, 3)).reshape((-1, self.tensor_parallel_size)))
@@ -152,4 +155,4 @@ def get_global_rank(
:return: numpy.int64, The global rank.
"""
return self.world_rank_matrix[ep_rank, pp_rank, dp_rank, tp_rank]
return self.world_rank_matrix[ep_rank, pp_rank, dp_rank, tp_rank]
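
For reference, the new local_pg is built by reshaping the flat list of global ranks into one row per node, which assumes ranks are assigned node by node (global_rank = node_id * local_world_size + local_rank). A small hedged sketch of that grouping, with illustrative sizes:

import numpy as np

world_size, local_world_size = 16, 8          # e.g. 2 nodes with 8 GPUs each (illustrative)
ranks = np.arange(world_size)
node_groups = ranks.reshape((-1, local_world_size))
# node_groups[0] -> [0 1 2 3 4 5 6 7]          ranks on node 0
# node_groups[1] -> [8 9 10 11 12 13 14 15]    ranks on node 1
# Each row becomes one intra-node process group; a rank's position inside its row
# must match the LOCAL_RANK exported by the launcher, which is what the new
# assert in _init_parallel_groups verifies.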
5 changes: 3 additions & 2 deletions src/nanotron/parallel/tensor_parallel/nn.py
@@ -292,11 +292,12 @@ def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
if self.mode is TensorParallelLinearMode.ALL_REDUCE:
out = differentiable_all_reduce_sum(out, group=self.pg)
elif self.mode is TensorParallelLinearMode.REDUCE_SCATTER:
out = differentiable_reduce_scatter_sum(out, group=self.pg)
# assert that first dim of out is sequence_length
out = differentiable_reduce_scatter_sum(out, group=self.pg) # this should scatter along the sequence dimension (s)
else:
raise ValueError(f"Got unexpected mode: {self.mode}.")

return out
return out # [*input_ids.shape, embedding_dim]

def extra_repr(self) -> str:
return f"tp_rank={dist.get_rank(self.pg)}, {super().extra_repr()}, unsharded_num_embeddings={self.original_num_embeddings}"
122 changes: 120 additions & 2 deletions src/nanotron/trainer.py
@@ -157,8 +157,8 @@ def __init__(
set_ranks_logging_level(parallel_context=self.parallel_context, logging_config=self.config.logging)

# Log benchmark info
if os.environ.get("NANOTRON_BENCHMARK", "0") == "1":
log_throughput(self.config, self.parallel_context)
# if os.environ.get("NANOTRON_BENCHMARK", "0") == "1":
# log_throughput(self.config, self.parallel_context)

########################################
## Setting up our model, optimizers, schedulers, etc.
@@ -260,6 +260,9 @@ def __init__(
def pre_init(self):
self.init_checkpoint_path = parse_ckpt_path(config=self.config, parallel_context=self.parallel_context)

# Calculate cluster bandwidth
self.BANDWIDTHS = measure_bandwidth(self.parallel_context)

def post_init(self):
# S3 Mover and save initial state
if self.config.s3_upload is not None:
@@ -663,6 +666,7 @@ def train_step_logs(
model_tflops,
hardware_tflops,
tokens_per_sec,
bandwidth=self.BANDWIDTHS,
)
log_rank("Throughput logging complete", logger=logger, level=logging.INFO)
if "SLURM_JOB_ID" in os.environ:
@@ -1050,3 +1054,117 @@ def mark_unsharded_params_as_tied_across_expert(
tie_parameters(
root_module=model, ties=shared_weights, parallel_context=parallel_context, reduce_op=reduce_op
)


def measure_bandwidth(parallel_context: ParallelContext):
"""Measure inter-GPU and intra-node bandwidth using NCCL all_reduce and reduce_scatter."""
import time

import torch
import torch.distributed as dist

# Size of the transfer: 256M float32 elements (~1 GiB)
size = 256 * 1024 * 1024 # Number of elements
tensor = torch.ones(size).cuda()
element_size = tensor.element_size() # Size of each element in bytes
data_size_bytes = size * element_size # Total data size in bytes

# For reduce_scatter, we need a tensor list where each GPU gets a chunk
world_size = dist.get_world_size(parallel_context.world_pg)
local_world_size = dist.get_world_size(parallel_context.local_pg)
chunk_size = size // world_size
local_chunk_size = size // local_world_size
tensor_rs = torch.ones(size).cuda()
output_rs = torch.empty(chunk_size, dtype=tensor_rs.dtype, device=tensor_rs.device)
output_rs_local = torch.empty(local_chunk_size, dtype=tensor_rs.dtype, device=tensor_rs.device)

# Get process groups for inter and intra node communication
inter_node_group = parallel_context.world_pg
intra_node_group = parallel_context.local_pg

dist.barrier(group=inter_node_group)
dist.barrier(group=intra_node_group)

# Warmup both operations for both inter and intra node
for _ in range(5):
# Inter-node warmup
dist.all_reduce(tensor.clone(), group=inter_node_group)
dist.reduce_scatter(output_rs, list(tensor_rs.split(chunk_size)), group=inter_node_group)

# Intra-node warmup
dist.all_reduce(tensor.clone(), group=intra_node_group)
dist.reduce_scatter(output_rs_local, list(tensor_rs.split(local_chunk_size)), group=intra_node_group)

torch.cuda.synchronize()
dist.barrier(group=inter_node_group)

# Measure inter-node all_reduce bandwidth
tic = time.time()
iters = 10

for _ in range(iters):
dist.all_reduce(tensor, group=inter_node_group)

torch.cuda.synchronize()
dist.barrier(group=inter_node_group)
toc = time.time()

# Calculate inter-node all_reduce bandwidth in GB/s
# Include algorithm factor in bandwidth calculation
ar_algo_factor = 2 * (world_size - 1) / world_size
ar_bandwidth = ar_algo_factor * data_size_bytes * iters / (toc - tic) / 1e9

dist.barrier(group=inter_node_group)

# Measure inter-node reduce_scatter bandwidth
tic = time.time()

for _ in range(iters):
dist.reduce_scatter(output_rs, list(tensor_rs.split(chunk_size)), group=inter_node_group)

torch.cuda.synchronize()
dist.barrier(group=inter_node_group)
toc = time.time()

# Calculate inter-node reduce_scatter bandwidth in GB/s
rs_algo_factor = (world_size - 1) / world_size
rs_bandwidth = rs_algo_factor * data_size_bytes * iters / (toc - tic) / 1e9

dist.barrier(group=intra_node_group)

# Measure intra-node all_reduce bandwidth
tic = time.time()

for _ in range(iters):
dist.all_reduce(tensor, group=intra_node_group)

torch.cuda.synchronize()
dist.barrier(group=intra_node_group)
toc = time.time()

# Calculate intra-node all_reduce bandwidth in GB/s
ar_local_algo_factor = 2 * (local_world_size - 1) / local_world_size
ar_bandwidth_intra = ar_local_algo_factor * data_size_bytes * iters / (toc - tic) / 1e9

dist.barrier(group=intra_node_group)

# Measure intra-node reduce_scatter bandwidth
tic = time.time()

for _ in range(iters):
dist.reduce_scatter(output_rs_local, list(tensor_rs.split(local_chunk_size)), group=intra_node_group)

torch.cuda.synchronize()
dist.barrier(group=intra_node_group)
toc = time.time()

# Calculate intra-node reduce_scatter bandwidth in GB/s
rs_local_algo_factor = (local_world_size - 1) / local_world_size
rs_bandwidth_intra = rs_local_algo_factor * data_size_bytes * iters / (toc - tic) / 1e9

return {
"all_reduce": ar_bandwidth,
"reduce_scatter": rs_bandwidth,
"all_reduce_intranode": ar_bandwidth_intra,
"reduce_scatter_intranode": rs_bandwidth_intra,
}
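
For context, the algo factors above follow the nccl-tests "bus bandwidth" convention: a ring all_reduce moves 2*(n-1)/n of the buffer per rank and a reduce_scatter moves (n-1)/n, so multiplying the raw algorithm bandwidth by these factors makes results comparable across group sizes. A short hedged sanity check with made-up numbers:

# Illustrative numbers only; not measurements from this commit.
data_size_bytes = 256 * 1024 * 1024 * 4      # 256M float32 elements = 1 GiB
world_size = 64
elapsed_s = 0.05                             # pretend one all_reduce took 50 ms

algbw = data_size_bytes / elapsed_s / 1e9                    # raw algorithm bandwidth, GB/s
busbw = algbw * 2 * (world_size - 1) / world_size            # bus bandwidth, as logged above
print(f"algbw={algbw:.1f} GB/s  busbw={busbw:.1f} GB/s")     # algbw=21.5 GB/s  busbw=42.3 GB/s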
