Skip to content

Commit

Permalink
Add cuquantum_configure rule in Bazel WORKSPACE
Browse files Browse the repository at this point in the history
  • Loading branch information
jaeyoo committed Mar 30, 2023
1 parent 2212128 commit 2d9eed9
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 21 deletions.
21 changes: 2 additions & 19 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -65,23 +65,6 @@ cc_library(
],
)

# TODO(jaeyoo) : add bzl rule to configure cuquantum if there exists `CUQUANTUM_DIR` environ.
CUQUANTUM_DIR = "/usr/local/google/home/jaeyoo/workspace/cuquantum-linux-x86_64-22.11.0.13-archive"
load("//third_party/cuquantum:cuquantum_configure.bzl", "cuquantum_configure")

new_local_repository(
name = "cuquantum_libs",
path = CUQUANTUM_DIR,
build_file_content = """
cc_library(
name = "custatevec_headers",
srcs = ["include/custatevec.h"],
visibility = ["//visibility:public"],
)
cc_library(
name = "custatevec",
srcs = ["lib/libcustatevec.so"],
visibility = ["//visibility:public"],
)
""",
)
cuquantum_configure(name = "local_config_cuquantum")
4 changes: 2 additions & 2 deletions lib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ cuda_library(
"vectorspace_cuda.h",
],
deps = [
"@cuquantum_libs//:custatevec",
"@cuquantum_libs//:custatevec_headers",
"@local_config_cuquantum//:cuquantum_headers",
"@local_config_cuquantum//:libcuquantum",
],
)

Expand Down
Empty file added third_party/BUILD
Empty file.
Empty file added third_party/cuquantum/BUILD
Empty file.
21 changes: 21 additions & 0 deletions third_party/cuquantum/BUILD.tpl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package(default_visibility = ["//visibility:public"])

cc_library(
name = "cuquantum_headers",
linkstatic = 1,
srcs = [":cuquantum_header_include"],
visibility = ["//visibility:public"],
)

cc_library(
name = "libcuquantum",
srcs = [
":libcustatevec.so",
":libcutensornet.so",
],
visibility = ["//visibility:public"],
)

%{CUQUANTUM_HEADER_GENRULE}
%{CUSTATEVEC_SHARED_LIBRARY_GENRULE}
%{CUTENSORNET_SHARED_LIBRARY_GENRULE}
210 changes: 210 additions & 0 deletions third_party/cuquantum/cuquantum_configure.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
"""Setup cuQuantum as external dependency"""
_CUQUANTUM_ROOT = "CUQUANTUM_ROOT"


def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
if not out:
out = tpl
repository_ctx.template(
out,
Label("//third_party/cuquantum:%s.tpl" % tpl),
substitutions,
)


def _fail(msg):
"""Output failure message when auto configuration fails."""
red = "\033[0;31m"
no_color = "\033[0m"
fail("%sPython Configuration Error:%s %s\n" % (red, no_color, msg))


def _execute(
repository_ctx,
cmdline,
error_msg = None,
error_details = None,
empty_stdout_fine = False):
"""Executes an arbitrary shell command.
Args:
repository_ctx: the repository_ctx object
cmdline: list of strings, the command to execute
error_msg: string, a summary of the error if the command fails
error_details: string, details about the error or steps to fix it
empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise
it's an error
Return:
the result of repository_ctx.execute(cmdline)
"""
result = repository_ctx.execute(cmdline)
if result.stderr or not (empty_stdout_fine or result.stdout):
_fail("\n".join([
error_msg.strip() if error_msg else "Repository command failed",
result.stderr.strip(),
error_details if error_details else "",
]))
return result


def _read_dir(repository_ctx, src_dir):
"""Returns a string with all files in a directory.
Finds all files inside a directory, traversing subfolders and following
symlinks. The returned string contains the full path of all files
separated by line breaks.
"""
find_result = _execute(
repository_ctx,
["find", src_dir, "-follow", "-type", "f"],
empty_stdout_fine = True,
)
result = find_result.stdout
return result

def _genrule(genrule_name, command, outs):
"""Returns a string with a genrule.
Genrule executes the given command and produces the given outputs.
Args:
genrule_name: A unique name for genrule target.
command: The command to run.
outs: A list of files generated by this rule.
Returns:
A genrule target.
"""
return (
"genrule(\n" +
' name = "' +
genrule_name + '",\n' +
" outs = [\n" +
outs +
"\n ],\n" +
' cmd = """\n' +
command +
'\n """,\n' +
")\n"
)

def _norm_path(path):
"""Returns a path with '/' and remove the trailing slash."""
path = path.replace("\\", "/")
if path[-1] == "/":
path = path[:-1]
return path


def _symlink_genrule_for_dir(
repository_ctx,
src_dir,
dest_dir,
genrule_name,
src_files = [],
dest_files = [],
is_empty_genrule = False):
"""Returns a genrule to symlink(or copy if on Windows) a set of files.
If src_dir is passed, files will be read from the given directory; otherwise
we assume files are in src_files and dest_files.
Args:
repository_ctx: the repository_ctx object.
src_dir: source directory.
dest_dir: directory to create symlink in.
genrule_name: genrule name.
src_files: list of source files instead of src_dir.
dest_files: list of corresonding destination files.
is_empty_genrule: True if CUQUANTUM_ROOT is not set.
Returns:
genrule target that creates the symlinks.
"""
if is_empty_genrule:
genrule = _genrule(
genrule_name,
"echo 'this genrule is empty because CUQUANTUM_ROOT is not set.' && touch %s.h" % genrule_name,
"'%s.h'" % genrule_name,
)
return genrule

if src_dir != None:
src_dir = _norm_path(src_dir)
dest_dir = _norm_path(dest_dir)
files = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))

dest_files = files.replace(src_dir, "").splitlines()
src_files = files.splitlines()
command = []
outs = []

for i in range(len(dest_files)):
if dest_files[i] != "":
# If we have only one file to link we do not want to use the dest_dir, as
# $(@D) will include the full path to the file.
dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i]

# Copy the headers to create a sandboxable setup.
cmd = "cp -f"
command.append(cmd + ' "%s" "%s"' % (src_files[i], dest))
outs.append(' "' + dest_dir + dest_files[i] + '",')

genrule = _genrule(
genrule_name,
" && ".join(command),
"\n".join(outs),
)
return genrule


def _cuquantum_pip_imple(repository_ctx):
cuquantum_root = repository_ctx.os.environ[_CUQUANTUM_ROOT]

is_empty_genrule = cuquantum_root == ""

cuquantum_header_path = "%s/include" % cuquantum_root

cuquantum_header_rule = _symlink_genrule_for_dir(
repository_ctx,
cuquantum_header_path,
"include",
"cuquantum_header_include",
is_empty_genrule=is_empty_genrule,
)
custatevec_shared_library_path = "%s/lib/libcustatevec.so" % (cuquantum_root)

custatevec_shared_library_rule = _symlink_genrule_for_dir(
repository_ctx,
None,
"",
"libcustatevec.so",
[custatevec_shared_library_path],
["libcustatevec.so"],
is_empty_genrule=is_empty_genrule,
)

cutensornet_shared_library_path = "%s/lib/libcutensornet.so" % (cuquantum_root)

cutensornet_shared_library_rule = _symlink_genrule_for_dir(
repository_ctx,
None,
"",
"libcutensornet.so",
[cutensornet_shared_library_path],
["libcutensornet.so"],
is_empty_genrule=is_empty_genrule,
)

_tpl(repository_ctx, "BUILD", {
"%{CUQUANTUM_HEADER_GENRULE}": cuquantum_header_rule,
"%{CUSTATEVEC_SHARED_LIBRARY_GENRULE}": custatevec_shared_library_rule,
"%{CUTENSORNET_SHARED_LIBRARY_GENRULE}": cutensornet_shared_library_rule,
})



cuquantum_configure = repository_rule(
implementation = _cuquantum_pip_imple,
environ = [
_CUQUANTUM_ROOT,
],
)

0 comments on commit 2d9eed9

Please sign in to comment.