Skip to content

[WIP] Add support for intra-session mutable state. #609

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: development
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion compiler_gym/envs/llvm/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ cc_binary(
name = "compiler_gym-llvm-service-prelinked",
srcs = ["RunService.cc"],
deps = [
":BenchmarkFactory",
":LlvmServiceContext",
":LlvmSession",
"//compiler_gym/service/runtime:cc_runtime",
],
Expand Down Expand Up @@ -207,6 +207,18 @@ cc_library(
],
)

cc_library(
name = "LlvmServiceContext",
srcs = ["LlvmServiceContext.cc"],
hdrs = ["LlvmServiceContext.h"],
deps = [
":BenchmarkFactory",
"//compiler_gym/service:CompilerGymServiceContext",
"//compiler_gym/util:GrpcStatusMacros",
"@llvm//10.0.0",
],
)

cc_library(
name = "LlvmSession",
srcs = ["LlvmSession.cc"],
Expand All @@ -224,6 +236,7 @@ cc_library(
":Benchmark",
":BenchmarkFactory",
":Cost",
":LlvmServiceContext",
":Observation",
":ObservationSpaces",
"//compiler_gym/service:CompilationSession",
Expand Down
44 changes: 14 additions & 30 deletions compiler_gym/envs/llvm/service/BenchmarkFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,22 @@ constexpr size_t kMaxLoadedBenchmarksCount = 128;
class BenchmarkFactory {
public:
/**
* Return the global benchmark factory singleton.
* Construct a benchmark factory.
*
* @param workingDirectory The working directory.
* @param rand An optional random number generator. This is used for cache
* evictions.
* @param maxLoadedBenchmarksCount The maximum number of benchmarks to cache.
* @return The benchmark factory singleton instance.
* @param workingDirectory A filesystem directory to use for storing temporary
* files.
* @param rand is a random seed used to control the selection of random
* benchmarks.
* @param maxLoadedBenchmarksCount is the maximum combined size of the bitcodes
* that may be cached in memory. Once this size is reached, benchmarks are
* offloaded so that they must be re-read from disk.
*/
static BenchmarkFactory& getSingleton(
const boost::filesystem::path& workingDirectory,
std::optional<std::mt19937_64> rand = std::nullopt,
size_t maxLoadedBenchmarksCount = kMaxLoadedBenchmarksCount) {
static BenchmarkFactory instance(workingDirectory, rand, maxLoadedBenchmarksCount);
return instance;
}
BenchmarkFactory(const boost::filesystem::path& workingDirectory,
std::optional<std::mt19937_64> rand = std::nullopt,
size_t maxLoadedBenchmarksCount = kMaxLoadedBenchmarksCount);

BenchmarkFactory(const BenchmarkFactory&) = delete;
BenchmarkFactory& operator=(const BenchmarkFactory&) = delete;

~BenchmarkFactory();

Expand All @@ -86,23 +87,6 @@ class BenchmarkFactory {
const std::string& uri, const boost::filesystem::path& path,
std::optional<compiler_gym::BenchmarkDynamicConfig> dynamicConfig = std::nullopt);

/**
* Construct a benchmark factory.
*
* @param workingDirectory A filesystem directory to use for storing temporary
* files.
* @param rand is a random seed used to control the selection of random
* benchmarks.
* @param maxLoadedBenchmarksCount is the maximum combined size of the bitcodes
* that may be cached in memory. Once this size is reached, benchmarks are
* offloaded so that they must be re-read from disk.
*/
BenchmarkFactory(const boost::filesystem::path& workingDirectory,
std::optional<std::mt19937_64> rand, size_t maxLoadedBenchmarksCount);

BenchmarkFactory(const BenchmarkFactory&) = delete;
BenchmarkFactory& operator=(const BenchmarkFactory&) = delete;

/**
* A mapping from URI to benchmarks which have been loaded into memory.
*/
Expand Down
2 changes: 1 addition & 1 deletion compiler_gym/envs/llvm/service/ComputeObservation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ int main(int argc, char** argv) {
benchmarkMessage.set_uri("user");
benchmarkMessage.mutable_program()->set_uri(fmt::format("file:///{}", argv[2]));

auto& benchmarkFactory = BenchmarkFactory::getSingleton(workingDirectory);
BenchmarkFactory benchmarkFactory{workingDirectory};
std::unique_ptr<::llvm_service::Benchmark> benchmark;
{
const auto status = benchmarkFactory.getBenchmark(benchmarkMessage, &benchmark);
Expand Down
73 changes: 73 additions & 0 deletions compiler_gym/envs/llvm/service/LlvmServiceContext.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include "compiler_gym/envs/llvm/service/LlvmServiceContext.h"

#include "compiler_gym/util/GrpcStatusMacros.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/TargetSelect.h"

using grpc::Status;

namespace {

void initLlvm() {
llvm::InitializeNativeTarget();

// Initialize passes.
llvm::PassRegistry& Registry = *llvm::PassRegistry::getPassRegistry();
llvm::initializeCore(Registry);
llvm::initializeCoroutines(Registry);
llvm::initializeScalarOpts(Registry);
llvm::initializeObjCARCOpts(Registry);
llvm::initializeVectorization(Registry);
llvm::initializeIPO(Registry);
llvm::initializeAnalysis(Registry);
llvm::initializeTransformUtils(Registry);
llvm::initializeInstCombine(Registry);
llvm::initializeAggressiveInstCombine(Registry);
llvm::initializeInstrumentation(Registry);
llvm::initializeTarget(Registry);
llvm::initializeExpandMemCmpPassPass(Registry);
llvm::initializeScalarizeMaskedMemIntrinPass(Registry);
llvm::initializeCodeGenPreparePass(Registry);
llvm::initializeAtomicExpandPass(Registry);
llvm::initializeRewriteSymbolsLegacyPassPass(Registry);
llvm::initializeWinEHPreparePass(Registry);
llvm::initializeDwarfEHPreparePass(Registry);
llvm::initializeSafeStackLegacyPassPass(Registry);
llvm::initializeSjLjEHPreparePass(Registry);
llvm::initializePreISelIntrinsicLoweringLegacyPassPass(Registry);
llvm::initializeGlobalMergePass(Registry);
llvm::initializeIndirectBrExpandPassPass(Registry);
llvm::initializeInterleavedAccessPass(Registry);
llvm::initializeEntryExitInstrumenterPass(Registry);
llvm::initializePostInlineEntryExitInstrumenterPass(Registry);
llvm::initializeUnreachableBlockElimLegacyPassPass(Registry);
llvm::initializeExpandReductionsPass(Registry);
llvm::initializeWasmEHPreparePass(Registry);
llvm::initializeWriteBitcodePassPass(Registry);
}

} // anonymous namespace

namespace compiler_gym::llvm_service {

LlvmServiceContext::LlvmServiceContext(const boost::filesystem::path& workingDirectory)
: CompilerGymServiceContext(workingDirectory), benchmarkFactory_(workingDirectory) {}

Status LlvmServiceContext::init() {
RETURN_IF_ERROR(CompilerGymServiceContext::init());
initLlvm();
return Status::OK;
}

Status LlvmServiceContext::shutdown() {
Status status = CompilerGymServiceContext::shutdown();
benchmarkFactory().close();
return status;
}

} // namespace compiler_gym::llvm_service
29 changes: 29 additions & 0 deletions compiler_gym/envs/llvm/service/LlvmServiceContext.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#pragma once

#include <grpcpp/grpcpp.h>

#include "boost/filesystem.hpp"
#include "compiler_gym/envs/llvm/service/BenchmarkFactory.h"
#include "compiler_gym/service/CompilerGymServiceContext.h"

namespace compiler_gym::llvm_service {

class LlvmServiceContext final : public CompilerGymServiceContext {
public:
LlvmServiceContext(const boost::filesystem::path& workingDirectory);

[[nodiscard]] virtual grpc::Status init() final override;

[[nodiscard]] virtual grpc::Status shutdown() final override;

BenchmarkFactory& benchmarkFactory() { return benchmarkFactory_; }

private:
BenchmarkFactory benchmarkFactory_;
};

} // namespace compiler_gym::llvm_service
20 changes: 12 additions & 8 deletions compiler_gym/envs/llvm/service/LlvmSession.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "compiler_gym/envs/llvm/service/Benchmark.h"
#include "compiler_gym/envs/llvm/service/BenchmarkFactory.h"
#include "compiler_gym/envs/llvm/service/Cost.h"
#include "compiler_gym/envs/llvm/service/LlvmServiceContext.h"
#include "compiler_gym/envs/llvm/service/Observation.h"
#include "compiler_gym/envs/llvm/service/ObservationSpaces.h"
#include "compiler_gym/envs/llvm/service/passes/10.0.0/ActionHeaders.h"
Expand Down Expand Up @@ -76,18 +77,19 @@ std::vector<ObservationSpace> LlvmSession::getObservationSpaces() const {
return getLlvmObservationSpaceList();
}

LlvmSession::LlvmSession(const boost::filesystem::path& workingDirectory)
: CompilationSession(workingDirectory),
LlvmSession::LlvmSession(CompilerGymServiceContext* const context)
: CompilationSession(context),
observationSpaceNames_(util::createPascalCaseToEnumLookupTable<LlvmObservationSpace>()) {
// TODO: Move CPUInfo initialize to context setup!
cpuinfo_initialize();
}

Status LlvmSession::init(const ActionSpace& actionSpace, const BenchmarkProto& benchmark) {
BenchmarkFactory& benchmarkFactory = BenchmarkFactory::getSingleton(workingDirectory());
LlvmServiceContext* const ctx = static_cast<LlvmServiceContext*>(context());

// Get the benchmark or return an error.
std::unique_ptr<Benchmark> llvmBenchmark;
RETURN_IF_ERROR(benchmarkFactory.getBenchmark(benchmark, &llvmBenchmark));
RETURN_IF_ERROR(ctx->benchmarkFactory().getBenchmark(benchmark, &llvmBenchmark));

// Verify the benchmark now to catch errors early.
RETURN_IF_ERROR(llvmBenchmark->verify_module());
Expand All @@ -101,7 +103,8 @@ Status LlvmSession::init(const ActionSpace& actionSpace, const BenchmarkProto& b
Status LlvmSession::init(CompilationSession* other) {
// TODO: Static cast?
auto llvmOther = static_cast<LlvmSession*>(other);
return init(llvmOther->actionSpace(), llvmOther->benchmark().clone(workingDirectory()));
return init(llvmOther->actionSpace(),
llvmOther->benchmark().clone(context()->workingDirectory()));
}

Status LlvmSession::init(const LlvmActionSpace& actionSpace, std::unique_ptr<Benchmark> benchmark) {
Expand Down Expand Up @@ -156,7 +159,8 @@ Status LlvmSession::computeObservation(const ObservationSpace& observationSpace,
}
const LlvmObservationSpace observationSpaceEnum = it->second;

return setObservation(observationSpaceEnum, workingDirectory(), benchmark(), observation);
return setObservation(observationSpaceEnum, context()->workingDirectory(), benchmark(),
observation);
}

Status LlvmSession::handleSessionParameter(const std::string& key, const std::string& value,
Expand Down Expand Up @@ -256,8 +260,8 @@ bool LlvmSession::runPass(llvm::FunctionPass* pass) {

Status LlvmSession::runOptWithArgs(const std::vector<std::string>& optArgs) {
// Create temporary files for `opt` to read from and write to.
const auto before_path = fs::unique_path(workingDirectory() / "module-%%%%%%%%.bc");
const auto after_path = fs::unique_path(workingDirectory() / "module-%%%%%%%%.bc");
const auto before_path = fs::unique_path(context()->workingDirectory() / "module-%%%%%%%%.bc");
const auto after_path = fs::unique_path(context()->workingDirectory() / "module-%%%%%%%%.bc");
RETURN_IF_ERROR(writeBitcodeFile(benchmark().module(), before_path));

// Build a command line invocation: `opt input.bc -o output.bc <optArgs...>`.
Expand Down
2 changes: 1 addition & 1 deletion compiler_gym/envs/llvm/service/LlvmSession.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ namespace compiler_gym::llvm_service {
*/
class LlvmSession final : public CompilationSession {
public:
LlvmSession(const boost::filesystem::path& workingDirectory);
LlvmSession(CompilerGymServiceContext* const context);

std::string getCompilerVersion() const final override;

Expand Down
61 changes: 2 additions & 59 deletions compiler_gym/envs/llvm/service/RunService.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,72 +2,15 @@
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include "compiler_gym/envs/llvm/service/BenchmarkFactory.h"
#include "compiler_gym/envs/llvm/service/LlvmServiceContext.h"
#include "compiler_gym/envs/llvm/service/LlvmSession.h"
#include "compiler_gym/service/runtime/Runtime.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/TargetSelect.h"

const char* usage = R"(LLVM CompilerGym service)";

using namespace compiler_gym::runtime;
using namespace compiler_gym::llvm_service;

namespace {

void initLlvm() {
llvm::InitializeNativeTarget();

// Initialize passes.
llvm::PassRegistry& Registry = *llvm::PassRegistry::getPassRegistry();
llvm::initializeCore(Registry);
llvm::initializeCoroutines(Registry);
llvm::initializeScalarOpts(Registry);
llvm::initializeObjCARCOpts(Registry);
llvm::initializeVectorization(Registry);
llvm::initializeIPO(Registry);
llvm::initializeAnalysis(Registry);
llvm::initializeTransformUtils(Registry);
llvm::initializeInstCombine(Registry);
llvm::initializeAggressiveInstCombine(Registry);
llvm::initializeInstrumentation(Registry);
llvm::initializeTarget(Registry);
llvm::initializeExpandMemCmpPassPass(Registry);
llvm::initializeScalarizeMaskedMemIntrinPass(Registry);
llvm::initializeCodeGenPreparePass(Registry);
llvm::initializeAtomicExpandPass(Registry);
llvm::initializeRewriteSymbolsLegacyPassPass(Registry);
llvm::initializeWinEHPreparePass(Registry);
llvm::initializeDwarfEHPreparePass(Registry);
llvm::initializeSafeStackLegacyPassPass(Registry);
llvm::initializeSjLjEHPreparePass(Registry);
llvm::initializePreISelIntrinsicLoweringLegacyPassPass(Registry);
llvm::initializeGlobalMergePass(Registry);
llvm::initializeIndirectBrExpandPassPass(Registry);
llvm::initializeInterleavedAccessPass(Registry);
llvm::initializeEntryExitInstrumenterPass(Registry);
llvm::initializePostInlineEntryExitInstrumenterPass(Registry);
llvm::initializeUnreachableBlockElimLegacyPassPass(Registry);
llvm::initializeExpandReductionsPass(Registry);
llvm::initializeWasmEHPreparePass(Registry);
llvm::initializeWriteBitcodePassPass(Registry);
}

} // anonymous namespace

int main(int argc, char** argv) {
initLlvm();
const auto ret = createAndRunCompilerGymService<LlvmSession>(argc, argv, usage);

// NOTE(github.com/facebookresearch/CompilerGym/issues/582): We need to make
// sure that BenchmarkFactory::close() is called on the global singleton
// instance, so that the temporary scratch directories are tidied up.
//
// TODO(github.com/facebookresearch/CompilerGym/issues/591): Once the runtime
// has been refactored to support intra-session mutable state, this singleton
// can be replaced by a member variable that is closed on
// CompilerGymServiceContext::shutdown().
BenchmarkFactory::getSingleton(FLAGS_working_dir).close();

return ret;
return createAndRunCompilerGymService<LlvmSession, LlvmServiceContext>(argc, argv, usage);
}
2 changes: 1 addition & 1 deletion compiler_gym/envs/llvm/service/StripOptNoneAttribute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ int main(int argc, char** argv) {
google::InitGoogleLogging(argv[0]);

const fs::path workingDirectory{"."};
auto& benchmarkFactory = BenchmarkFactory::getSingleton(workingDirectory);
BenchmarkFactory benchmarkFactory(workingDirectory);

for (int i = 1; i < argc; ++i) {
stripOptNoneAttributesOrDie(argv[i], benchmarkFactory);
Expand Down
13 changes: 13 additions & 0 deletions compiler_gym/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ cc_library(
srcs = ["CompilationSession.cc"],
hdrs = ["CompilationSession.h"],
visibility = ["//visibility:public"],
deps = [
":CompilerGymServiceContext",
"//compiler_gym/service/proto:compiler_gym_service_cc",
"@boost//:filesystem",
"@com_github_grpc_grpc//:grpc++",
],
)

cc_library(
name = "CompilerGymServiceContext",
srcs = ["CompilerGymServiceContext.cc"],
hdrs = ["CompilerGymServiceContext.h"],
visibility = ["//visibility:public"],
deps = [
"//compiler_gym/service/proto:compiler_gym_service_cc",
"@boost//:filesystem",
Expand Down
4 changes: 2 additions & 2 deletions compiler_gym/service/CompilationSession.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Status CompilationSession::handleSessionParameter(const std::string& key, const
return Status::OK;
}

CompilationSession::CompilationSession(const boost::filesystem::path& workingDirectory)
: workingDirectory_(workingDirectory) {}
CompilationSession::CompilationSession(CompilerGymServiceContext* const context)
: context_(context) {}

} // namespace compiler_gym
Loading