Commit 1c54e2e

numa support

Signed-off-by: Mauricio J. Serrano <[email protected]>
Parent: 30e55c9

1 file changed: 19 additions, 0 deletions

aiu_fms_testing_utils/scripts/inference.py

@@ -257,6 +257,12 @@
     default=0,
     help="Timeout to use for messaging in minutes. Default set by PyTorch dist.init_process_group",
 )
+parser.add_argument(
+    "--numa",
+    action="store_true",
+    help="NUMA aware task distribution (requires distributed option)",
+)
+
 args = parser.parse_args()
 
 attention_map = {
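
The help text says --numa only applies together with the script's distributed mode, but this hunk does not enforce that dependency by itself. Below is a minimal, hedged sketch of such a check; "--distributed" is an assumed placeholder name, since the script's real distributed flag is not shown in this diff.

# Sketch only: enforcing the "--numa requires distributed" dependency at parse time.
# "--distributed" is an assumed stand-in for the script's actual flag name.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--distributed", action="store_true")  # assumed flag name
parser.add_argument(
    "--numa",
    action="store_true",
    help="NUMA aware task distribution (requires distributed option)",
)
args = parser.parse_args(["--distributed", "--numa"])  # simulated CLI input
if args.numa and not args.distributed:
    # Reject --numa when distributed mode is off.
    parser.error("--numa requires the distributed option")
print(args.numa)  # True
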
@@ -327,6 +333,19 @@
     dist.init_process_group()
     # Fix until PT 2.3
     torch._C._distributed_c10d._register_process_group("default", dist.group.WORLD)
+    if args.numa:
+        try:
+            from numa import info
+            numa_num_nodes = info.get_num_configured_nodes()
+            numa_world_size = dist.get_world_size()
+            numa_size_per_node = numa_world_size // numa_num_nodes
+            from numa import schedule
+            numa_rank = dist.get_rank()
+            numa_node = numa_rank // numa_size_per_node
+            schedule.run_on_nodes(numa_node)
+            dprint(f"NUMA: process {numa_rank} set to node {numa_node}")
+        except Exception:
+            dprint("NUMA not available on this machine, please install the libnuma libraries")
     aiu_setup.aiu_dist_setup(dist.get_rank(), dist.get_world_size())
 
 if args.device_type == "cuda":
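
The second hunk splits the distributed world evenly across the host's configured NUMA nodes and pins each rank to its node through py-libnuma's schedule.run_on_nodes. Here is a minimal standalone sketch of the same mapping, assuming the py-libnuma package (imported as "numa") is installed and torch.distributed is already initialized; the helper name pin_rank_to_numa_node is illustrative, not part of the commit.

# Standalone sketch of the rank-to-NUMA-node pinning shown in the diff above.
# Assumes: py-libnuma installed and torch.distributed already initialized
# (e.g. launched via a distributed launcher that sets rank/world size).
import torch.distributed as dist


def pin_rank_to_numa_node() -> int:
    from numa import info, schedule

    num_nodes = info.get_num_configured_nodes()  # NUMA nodes on this host
    world_size = dist.get_world_size()           # total ranks in the job
    ranks_per_node = world_size // num_nodes     # even split of ranks over nodes
    node = dist.get_rank() // ranks_per_node     # contiguous ranks share a node
    schedule.run_on_nodes(node)                  # restrict this process to that node
    return node

With 16 ranks on a 2-node machine, ranks 0-7 run on node 0 and ranks 8-15 on node 1. The even split assumes the world size is at least, and ideally a multiple of, the number of configured NUMA nodes; otherwise the integer division leaves the load unbalanced (or fails, which the except branch above then reports as NUMA being unavailable).
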
