Skip to content

Commit 5267a41

Browse files
brandon-b-millerleofangrwgk
authored
Add more ObjectCode constructors (#652)
* add ctors * updates * add CU_JIT_INPUT_LIBRARY * lib -> library everywhere * fix * Add missing `@staticmethod` --------- Co-authored-by: Leo Fang <[email protected]> Co-authored-by: Ralf W. Grosse-Kunstleve <[email protected]> Co-authored-by: Ralf W. Grosse-Kunstleve <[email protected]>
1 parent 3c588e8 commit 5267a41

File tree

2 files changed

+67
-1
lines changed

2 files changed

+67
-1
lines changed

cuda_core/cuda/core/experimental/_linker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,15 @@ def _lazy_init():
7777
"fatbin": _nvjitlink.InputType.FATBIN,
7878
"ltoir": _nvjitlink.InputType.LTOIR,
7979
"object": _nvjitlink.InputType.OBJECT,
80+
"library": _nvjitlink.InputType.LIBRARY,
8081
}
8182
else:
8283
_driver_input_types = {
8384
"ptx": _driver.CUjitInputType.CU_JIT_INPUT_PTX,
8485
"cubin": _driver.CUjitInputType.CU_JIT_INPUT_CUBIN,
8586
"fatbin": _driver.CUjitInputType.CU_JIT_INPUT_FATBINARY,
8687
"object": _driver.CUjitInputType.CU_JIT_INPUT_OBJECT,
88+
"library": _driver.CUjitInputType.CU_JIT_INPUT_LIBRARY,
8789
}
8890
_inited = True
8991

cuda_core/cuda/core/experimental/_module.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ class ObjectCode:
276276
"""
277277

278278
__slots__ = ("_handle", "_backend_version", "_code_type", "_module", "_loader", "_sym_map")
279-
_supported_code_type = ("cubin", "ptx", "ltoir", "fatbin")
279+
_supported_code_type = ("cubin", "ptx", "ltoir", "fatbin", "object", "library")
280280

281281
def __new__(self, *args, **kwargs):
282282
raise RuntimeError(
@@ -342,6 +342,70 @@ def from_ptx(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None
342342
"""
343343
return ObjectCode._init(module, "ptx", symbol_mapping=symbol_mapping)
344344

345+
@staticmethod
346+
def from_ltoir(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None) -> "ObjectCode":
347+
"""Create an :class:`ObjectCode` instance from an existing LTOIR.
348+
349+
Parameters
350+
----------
351+
module : Union[bytes, str]
352+
Either a bytes object containing the in-memory ltoir code to load, or
353+
a file path string pointing to the on-disk ltoir file to load.
354+
symbol_mapping : Optional[dict]
355+
A dictionary specifying how the unmangled symbol names (as keys)
356+
should be mapped to the mangled names before trying to retrieve
357+
them (default to no mappings).
358+
"""
359+
return ObjectCode._init(module, "ltoir", symbol_mapping=symbol_mapping)
360+
361+
@staticmethod
362+
def from_fatbin(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None) -> "ObjectCode":
363+
"""Create an :class:`ObjectCode` instance from an existing fatbin.
364+
365+
Parameters
366+
----------
367+
module : Union[bytes, str]
368+
Either a bytes object containing the in-memory fatbin to load, or
369+
a file path string pointing to the on-disk fatbin to load.
370+
symbol_mapping : Optional[dict]
371+
A dictionary specifying how the unmangled symbol names (as keys)
372+
should be mapped to the mangled names before trying to retrieve
373+
them (default to no mappings).
374+
"""
375+
return ObjectCode._init(module, "fatbin", symbol_mapping=symbol_mapping)
376+
377+
@staticmethod
378+
def from_object(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None) -> "ObjectCode":
379+
"""Create an :class:`ObjectCode` instance from an existing object code.
380+
381+
Parameters
382+
----------
383+
module : Union[bytes, str]
384+
Either a bytes object containing the in-memory object code to load, or
385+
a file path string pointing to the on-disk object code to load.
386+
symbol_mapping : Optional[dict]
387+
A dictionary specifying how the unmangled symbol names (as keys)
388+
should be mapped to the mangled names before trying to retrieve
389+
them (default to no mappings).
390+
"""
391+
return ObjectCode._init(module, "object", symbol_mapping=symbol_mapping)
392+
393+
@staticmethod
394+
def from_library(module: Union[bytes, str], *, symbol_mapping: Optional[dict] = None) -> "ObjectCode":
395+
"""Create an :class:`ObjectCode` instance from an existing library.
396+
397+
Parameters
398+
----------
399+
module : Union[bytes, str]
400+
Either a bytes object containing the in-memory library to load, or
401+
a file path string pointing to the on-disk library to load.
402+
symbol_mapping : Optional[dict]
403+
A dictionary specifying how the unmangled symbol names (as keys)
404+
should be mapped to the mangled names before trying to retrieve
405+
them (default to no mappings).
406+
"""
407+
return ObjectCode._init(module, "library", symbol_mapping=symbol_mapping)
408+
345409
# TODO: do we want to unload in a finalizer? Probably not..
346410

347411
def _lazy_load_module(self, *args, **kwargs):

0 commit comments

Comments
 (0)