Skip to content

Commit bc5b8e0

Browse files
authored
Merge pull request #65 from adamcrume/binary
Install TensorFlow from a prebuilt binary when possible
2 parents e9c0927 + 2d7e6a5 commit bc5b8e0

File tree

3 files changed

+110
-3
lines changed

3 files changed

+110
-3
lines changed

.travis.yml

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
11
language: rust
22
sudo: false
3+
dist: trusty # still in beta, but required for the prebuilt TF binaries
34

45
cache:
56
cargo: true
67
directories:
78
- $HOME/.cache/bazel
89

9-
rust: nightly
10+
rust: stable
1011

1112
install:
1213
- export CC="gcc-4.9" CXX="g++-4.9"
1314
- source travis-ci/install.sh
1415

1516
script:
17+
- export RUST_BACKTRACE=1
1618
- cargo test -vv -j 2 --features tensorflow_unstable
1719
- cargo run --example regression
1820
- cargo run --features tensorflow_unstable --example expressions
1921
- cargo doc -vv --features tensorflow_unstable
20-
- (cd tensorflow-sys && cargo test -vv -j 1)
22+
- # TODO(#66): Re-enable: (cd tensorflow-sys && cargo test -vv -j 1)
2123
- (cd tensorflow-sys && cargo doc -vv)
2224

2325
addons:

tensorflow-sys/Cargo.toml

+6
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,11 @@ links = "tensorflow"
1717
libc = "0.2"
1818

1919
[build-dependencies]
20+
curl = "0.4"
21+
flate2 = "0.2"
2022
pkg-config = "0.3"
2123
semver = "0.5"
24+
tar = "0.4"
25+
26+
[features]
27+
tensorflow_gpu = []

tensorflow-sys/build.rs

+100-1
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,28 @@
1+
extern crate curl;
2+
extern crate flate2;
13
extern crate pkg_config;
24
extern crate semver;
5+
extern crate tar;
36

47
use std::error::Error;
58
use std::fs::File;
9+
use std::io::BufWriter;
10+
use std::io::Write;
611
use std::path::{Path, PathBuf};
712
use std::process;
813
use std::process::Command;
914
use std::{env, fs};
15+
16+
use curl::easy::Easy;
17+
use flate2::read::GzDecoder;
1018
use semver::Version;
19+
use tar::Archive;
1120

1221
const LIBRARY: &'static str = "tensorflow";
1322
const REPOSITORY: &'static str = "https://github.com/tensorflow/tensorflow.git";
1423
const TARGET: &'static str = "tensorflow:libtensorflow.so";
24+
// `VERSION` and `TAG` are separate because the tag is not always `'v' + VERSION`.
25+
const VERSION: &'static str = "1.0.0";
1526
const TAG: &'static str = "v1.0.0";
1627
const MIN_BAZEL: &'static str = "0.3.2";
1728

@@ -30,6 +41,91 @@ fn main() {
3041
return;
3142
}
3243

44+
let force_src = match env::var("TF_RUST_BUILD_FROM_SRC") {
45+
Ok(s) => s == "true",
46+
Err(_) => false,
47+
};
48+
if !force_src && env::consts::ARCH == "x86_64" && (env::consts::OS == "linux" || env::consts::OS == "macos") {
49+
install_prebuilt();
50+
} else {
51+
build_from_src();
52+
}
53+
}
54+
55+
fn remove_suffix(value: &mut String, suffix: &str) {
56+
if value.ends_with(suffix) {
57+
let n = value.len();
58+
value.truncate(n - suffix.len());
59+
}
60+
}
61+
62+
fn extract<P: AsRef<Path>, P2: AsRef<Path>>(archive_path: P, extract_to: P2) {
63+
let file = File::open(archive_path).unwrap();
64+
let unzipped = GzDecoder::new(file).unwrap();
65+
let mut a = Archive::new(unzipped);
66+
a.unpack(extract_to).unwrap();
67+
}
68+
69+
// Downloads and unpacks a prebuilt binary. Only works for certain platforms.
70+
fn install_prebuilt() {
71+
// Figure out the file names.
72+
let os = match env::consts::OS {
73+
"macos" => "darwin",
74+
x => x,
75+
};
76+
let proc_type = if cfg!(feature = "tensorflow_gpu") {"gpu"} else {"cpu"};
77+
let binary_url = format!(
78+
"https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-{}-{}-{}-{}.tar.gz",
79+
proc_type, os, env::consts::ARCH, VERSION);
80+
log_var!(binary_url);
81+
let short_file_name = binary_url.split("/").last().unwrap();
82+
let mut base_name = short_file_name.to_string();
83+
remove_suffix(&mut base_name, ".tar.gz");
84+
log_var!(base_name);
85+
let download_dir = match env::var("TF_RUST_DOWNLOAD_DIR") {
86+
Ok(s) => PathBuf::from(s),
87+
Err(_) => PathBuf::from(&get!("CARGO_MANIFEST_DIR")).join("target"),
88+
};
89+
if !download_dir.exists() {
90+
fs::create_dir(&download_dir).unwrap();
91+
}
92+
let file_name = download_dir.join(short_file_name);
93+
log_var!(file_name);
94+
95+
// Download the tarball.
96+
if !file_name.exists() {
97+
let f = File::create(&file_name).unwrap();
98+
let mut writer = BufWriter::new(f);
99+
let mut easy = Easy::new();
100+
easy.url(&binary_url).unwrap();
101+
easy.write_function(move |data| {
102+
Ok(writer.write(data).unwrap())
103+
}).unwrap();
104+
easy.perform().unwrap();
105+
106+
let response_code = easy.response_code().unwrap();
107+
if response_code != 200 {
108+
panic!("Unexpected response code {} for {}", response_code, binary_url);
109+
}
110+
}
111+
112+
// Extract the tarball.
113+
let unpacked_dir = download_dir.join(base_name);
114+
let lib_dir = unpacked_dir.join("lib");
115+
if !lib_dir.join(format!("lib{}.so", LIBRARY)).exists() {
116+
extract(file_name, &unpacked_dir);
117+
}
118+
119+
//run("find", |command| command); // TODO: remove
120+
run("ls", |command| {
121+
command.arg("-l").arg(lib_dir.to_str().unwrap())
122+
}); // TODO: remove
123+
124+
println!("cargo:rustc-link-lib=dylib={}", LIBRARY);
125+
println!("cargo:rustc-link-search={}", lib_dir.display());
126+
}
127+
128+
fn build_from_src() {
33129
let output = PathBuf::from(&get!("OUT_DIR"));
34130
log_var!(output);
35131
let source = PathBuf::from(&get!("CARGO_MANIFEST_DIR")).join(format!("target/source-{}", TAG));
@@ -71,7 +167,10 @@ fn main() {
71167
let configure_hint_file = Path::new(&configure_hint_file_pb);
72168
if !configure_hint_file.exists() {
73169
run("bash",
74-
|command| command.current_dir(&source).arg("-c").arg("yes ''|./configure"));
170+
|command| command.current_dir(&source)
171+
.env("TF_NEED_CUDA", if cfg!(feature = "tensorflow_gpu") {"1"} else {"0"})
172+
.arg("-c")
173+
.arg("yes ''|./configure"));
75174
File::create(configure_hint_file).unwrap();
76175
}
77176
run("bazel", |command| {

0 commit comments

Comments
 (0)