Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 33 additions & 4 deletions shim_et/xplat/executorch/build/runtime_wrapper.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,43 @@ def _patch_build_mode_flags(kwargs):

return kwargs

def _has_pytorch_dep(dep_list):
"""Check if a dependency list contains PyTorch/ATen dependencies."""
if not dep_list:
return False
for dep in dep_list:
if type(dep) == "string":
if "torch" in dep or "libtorch" in dep or "caffe2" in dep:
return True
return False

def _patch_test_compiler_flags(kwargs):
if "compiler_flags" not in kwargs:
kwargs["compiler_flags"] = []

# Required globally by all c++ tests.
kwargs["compiler_flags"] += [
"-std=c++17",
]
# Determine C++ standard based on whether this is an aten test.
# Aten tests require at least C++20 to compile against PyTorch, while
# non-aten tests are pinned to C++17 for embedded.
name = kwargs.get("name", "")
external_deps = kwargs.get("external_deps", [])
deps = kwargs.get("deps", [])
xplat_deps = kwargs.get("xplat_deps", [])
fbcode_deps = kwargs.get("fbcode_deps", [])
is_aten_test = (
"_aten" in name or
"aten_" in name or
"libtorch" in external_deps or
"gtest_aten" in external_deps or
"gmock_aten" in external_deps or
_has_pytorch_dep(deps) or
_has_pytorch_dep(xplat_deps) or
_has_pytorch_dep(fbcode_deps)
)

if not is_aten_test:
kwargs["compiler_flags"] += [
"-std=c++17",
]

# Relaxing some constraints for tests
kwargs["compiler_flags"] += [
Expand Down
3 changes: 1 addition & 2 deletions third-party/gtest_defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_aten_mode_opt
COMPILER_FLAGS = [
"-std=c++17",
]
COMPILER_FLAGS_ATEN = [
"-std=c++17",]
COMPILER_FLAGS_ATEN = []

# define_gtest_targets
def define_gtest_targets():
Expand Down
Loading