21 changes: 21 additions & 0 deletions aiu_fms_testing_utils/scripts/inference.py
@@ -257,6 +257,12 @@
default=0,
help="Timeout to use for messaging in minutes. Default set by PyTorch dist.init_process_group",
)
parser.add_argument(
    "--numa",
    action="store_true",
    help="NUMA-aware task distribution (requires the distributed option)",
)

args = parser.parse_args()

attention_map = {
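As a quick illustration (a minimal, self-contained sketch, not part of the PR), the new option is a plain boolean switch, so it defaults to `False` and flips to `True` only when passed on the command line:

```python
import argparse

parser = argparse.ArgumentParser()
# Mirrors the option added above: a boolean flag that is only
# meaningful together with the script's distributed mode.
parser.add_argument(
    "--numa",
    action="store_true",
    help="NUMA-aware task distribution (requires the distributed option)",
)

args = parser.parse_args(["--numa"])
assert args.numa  # True when the flag is passed, False otherwise
```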
@@ -327,6 +333,21 @@
dist.init_process_group()
# Fix until PT 2.3
torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
if args.numa:
    try:
        # NUMA-aware process placement: split the distributed world
        # into contiguous, equally sized rank groups, one group per
        # NUMA node, then pin this process and its memory
        # allocations to its node.
        from numa import info, memory, schedule

        numa_num_nodes = info.get_num_configured_nodes()
        numa_world_size = dist.get_world_size()
        numa_size_per_node = numa_world_size // numa_num_nodes
        numa_rank = dist.get_rank()
        numa_node = numa_rank // numa_size_per_node
        schedule.run_on_nodes(numa_node)
        memory.set_local_alloc()
        dprint(f"NUMA: process {numa_rank} set to node {numa_node}")
    except ImportError:
        dprint("NUMA is not available on this machine; please install the libnuma libraries")
aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())

if args.device_type == "cuda":
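For reference, a minimal standalone sketch (not part of the PR, and needing no libnuma) of the rank-to-node mapping the hunk above performs; the world size and node count below are illustrative:

```python
def numa_node_for_rank(rank: int, world_size: int, num_nodes: int) -> int:
    """Map a distributed rank to a NUMA node by splitting ranks into
    contiguous, equally sized groups via integer division."""
    size_per_node = world_size // num_nodes
    return rank // size_per_node

# Example: 8 processes on a 2-node machine.
# Ranks 0-3 land on node 0, ranks 4-7 on node 1.
assert [numa_node_for_rank(r, 8, 2) for r in range(8)] == [0, 0, 0, 0, 1, 1, 1, 1]
```

Note that this mapping assumes the world size divides evenly across NUMA nodes; with a non-divisible world size, the trailing ranks would spill onto an out-of-range node index.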