From 2e02cc060db181b7e6d2a145d76c9df2d7143568 Mon Sep 17 00:00:00 2001 From: Olli Saarikivi Date: Wed, 18 Oct 2023 23:03:28 +0000 Subject: [PATCH] More bindings and better test parametrization --- python/mscclpp/__init__.py | 3 +++ python/mscclpp/core_py.cpp | 4 ++++ python/test/test_mscclpp.py | 30 ++++++++++++++++++++++-------- 3 files changed, 29 insertions(+), 8 deletions(-) diff --git a/python/mscclpp/__init__.py b/python/mscclpp/__init__.py index 5165e95cb..f14cbaa4b 100644 --- a/python/mscclpp/__init__.py +++ b/python/mscclpp/__init__.py @@ -19,6 +19,9 @@ Transport, TransportFlags, version, + get_ib_device_count, + get_ib_device_name, + get_ib_transport_by_device_name, ) __version__ = version() diff --git a/python/mscclpp/core_py.cpp b/python/mscclpp/core_py.cpp index 086c241ef..c12b2d959 100644 --- a/python/mscclpp/core_py.cpp +++ b/python/mscclpp/core_py.cpp @@ -166,6 +166,10 @@ void register_core(nb::module_& m) { .def("connect", &Communicator::connect, nb::arg("remoteRank"), nb::arg("tag"), nb::arg("localConfig")) .def("remote_rank_of", &Communicator::remoteRankOf) .def("tag_of", &Communicator::tagOf); + + m.def("get_ib_device_count", &getIBDeviceCount); + m.def("get_ib_device_name", &getIBDeviceName, nb::arg("ib_transport")); + m.def("get_ib_transport_by_device_name", &getIBTransportByDeviceName, nb::arg("ib_device_name")); } NB_MODULE(_mscclpp, m) { diff --git a/python/test/test_mscclpp.py b/python/test/test_mscclpp.py index 3af1580a4..e5538553c 100644 --- a/python/test/test_mscclpp.py +++ b/python/test/test_mscclpp.py @@ -9,7 +9,7 @@ import netifaces as ni import pytest -from mscclpp import Fifo, Host2DeviceSemaphore, Host2HostSemaphore, ProxyService, SmDevice2DeviceSemaphore, Transport +from mscclpp import Fifo, Host2DeviceSemaphore, Host2HostSemaphore, ProxyService, SmDevice2DeviceSemaphore, Transport, get_ib_device_count from ._cpp import _ext from .mscclpp_group import MscclppGroup from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group @@ -17,6 +17,19 @@ ethernet_interface_name = "eth0" +skipif_ib = pytest.mark.skipif(get_ib_device_count() == 0, reason="no IB device") + +def parametrize_transport(*transports: list): + def decorator(func): + params = [] + for transport in transports: + if transport == "IB": + params.append(pytest.param(transport, marks=skipif_ib)) + else: + params.append(transport) + return pytest.mark.parametrize("transport", params)(func) + return decorator + def all_ranks_on_the_same_node(mpi_group: MpiGroup): if (ethernet_interface_name in ni.interfaces()) is False: @@ -81,13 +94,13 @@ def create_and_connect(mpi_group: MpiGroup, transport: str): @parametrize_mpi_groups(2, 4, 8, 16) -@pytest.mark.parametrize("transport", ["IB", "NVLink"]) +@parametrize_transport("IB", "NVLink") def test_group_with_connections(mpi_group: MpiGroup, transport: str): create_and_connect(mpi_group, transport) @parametrize_mpi_groups(2, 4, 8, 16) -@pytest.mark.parametrize("transport", ["IB", "NVLink"]) +@parametrize_transport("IB", "NVLink") @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]]) def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int): group, connections = create_and_connect(mpi_group, transport) @@ -122,7 +135,7 @@ def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int) @parametrize_mpi_groups(2, 4, 8, 16) -@pytest.mark.parametrize("transport", ["IB", "NVLink"]) +@parametrize_transport("IB", "NVLink") @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20, 27]]) @pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, nelem: int, device: str): @@ -174,6 +187,7 @@ def test_connection_write_and_signal(mpi_group: MpiGroup, transport: Transport, @parametrize_mpi_groups(2, 4, 8, 16) +@skipif_ib def test_h2h_semaphores(mpi_group: MpiGroup): group, connections = create_and_connect(mpi_group, "IB") @@ -262,7 +276,7 @@ def __call__(self): @parametrize_mpi_groups(2, 4, 8, 16) -@pytest.mark.parametrize("transport", ["NVLink", "IB"]) +@parametrize_transport("NVLink", "IB") def test_h2d_semaphores(mpi_group: MpiGroup, transport: str): def signal(semaphores): for rank in semaphores: @@ -295,7 +309,7 @@ def test_d2d_semaphores(mpi_group: MpiGroup): @parametrize_mpi_groups(2, 4, 8, 16) -@pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]]) +@pytest.mark.parametrize("nelem", [2**i for i in [10]]) @pytest.mark.parametrize("use_packet", [False, True]) def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool): group, connections = create_and_connect(mpi_group, "NVLink") @@ -344,7 +358,7 @@ def test_fifo( @parametrize_mpi_groups(2, 4, 8, 16) @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]]) -@pytest.mark.parametrize("transport", ["IB", "NVLink"]) +@parametrize_transport("IB", "NVLink") def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str): group, connections = create_and_connect(mpi_group, transport) @@ -393,7 +407,7 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str): @parametrize_mpi_groups(2, 4, 8, 16) @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]]) -@pytest.mark.parametrize("transport", ["NVLink", "IB"]) +@parametrize_transport("NVLink", "IB") @pytest.mark.parametrize("use_packet", [False, True]) def test_simple_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool): group, connections = create_and_connect(mpi_group, transport)