Skip to content

Commit 17361fa

Browse files
committed
Install TensorFlow from a prebuilt binary when possible
1 parent 3e078d3 commit 17361fa

File tree

3 files changed

+101
-1
lines changed

3 files changed

+101
-1
lines changed

.travis.yml

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
language: rust
22
sudo: false
3+
dist: trusty # still in beta, but required for the prebuilt TF binaries
34

45
cache:
56
cargo: true
@@ -13,6 +14,7 @@ install:
1314
- source travis-ci/install.sh
1415

1516
script:
17+
- export RUST_BACKTRACE=1
1618
- cargo test -vv -j 2 --features tensorflow_unstable
1719
- cargo run --example regression
1820
- cargo run --features tensorflow_unstable --example expressions

tensorflow-sys/Cargo.toml

+6
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,11 @@ links = "tensorflow"
1717
libc = "0.2"
1818

1919
[build-dependencies]
20+
curl = "0.4"
21+
flate2 = "0.2"
2022
pkg-config = "0.3"
2123
semver = "0.5"
24+
tar = "0.4"
25+
26+
[features]
27+
tensorflow_gpu = []

tensorflow-sys/build.rs

+93-1
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,28 @@
1+
extern crate curl;
2+
extern crate flate2;
13
extern crate pkg_config;
24
extern crate semver;
5+
extern crate tar;
36

47
use std::error::Error;
58
use std::fs::File;
9+
use std::io::BufWriter;
10+
use std::io::Write;
611
use std::path::{Path, PathBuf};
712
use std::process;
813
use std::process::Command;
914
use std::{env, fs};
15+
16+
use curl::easy::Easy;
17+
use flate2::read::GzDecoder;
1018
use semver::Version;
19+
use tar::Archive;
1120

1221
const LIBRARY: &'static str = "tensorflow";
1322
const REPOSITORY: &'static str = "https://github.com/tensorflow/tensorflow.git";
1423
const TARGET: &'static str = "tensorflow:libtensorflow.so";
24+
// `VERSION` and `TAG` are separate because the tag is not always `'v' + VERSION`.
25+
const VERSION: &'static str = "1.0.0";
1526
const TAG: &'static str = "v1.0.0";
1627
const MIN_BAZEL: &'static str = "0.3.2";
1728

@@ -30,6 +41,84 @@ fn main() {
3041
return;
3142
}
3243

44+
if env::consts::ARCH == "x86_64" && (env::consts::OS == "linux" || env::consts::OS == "macos") {
45+
install_prebuilt();
46+
} else {
47+
build_from_src();
48+
}
49+
}
50+
51+
fn remove_suffix(value: &mut String, suffix: &str) {
52+
if value.ends_with(suffix) {
53+
let n = value.len();
54+
value.truncate(n - suffix.len());
55+
}
56+
}
57+
58+
fn extract<P: AsRef<Path>, P2: AsRef<Path>>(archive_path: P, extract_to: P2) {
59+
let file = File::open(archive_path).unwrap();
60+
let unzipped = GzDecoder::new(file).unwrap();
61+
let mut a = Archive::new(unzipped);
62+
a.unpack(extract_to).unwrap();
63+
}
64+
65+
// Downloads and unpacks a prebuilt binary. Only works for certain platforms.
66+
fn install_prebuilt() {
67+
// Figure out the file names.
68+
let os = match env::consts::OS {
69+
"macos" => "darwin",
70+
x => x,
71+
};
72+
let proc_type = if cfg!(feature = "tensorflow_gpu") {"gpu"} else {"cpu"};
73+
let binary_url = format!(
74+
"https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-{}-{}-{}-{}.tar.gz",
75+
proc_type, os, env::consts::ARCH, VERSION);
76+
log_var!(binary_url);
77+
let short_file_name = binary_url.split("/").last().unwrap();
78+
let mut base_name = short_file_name.to_string();
79+
remove_suffix(&mut base_name, ".tar.gz");
80+
log_var!(base_name);
81+
let target_dir = PathBuf::from(&get!("CARGO_MANIFEST_DIR")).join("target");
82+
if !target_dir.exists() {
83+
fs::create_dir(&target_dir).unwrap();
84+
}
85+
let file_name = target_dir.join(short_file_name);
86+
log_var!(file_name);
87+
88+
// Download the tarball.
89+
if !file_name.exists() {
90+
let f = File::create(&file_name).unwrap();
91+
let mut writer = BufWriter::new(f);
92+
let mut easy = Easy::new();
93+
easy.url(&binary_url).unwrap();
94+
easy.write_function(move |data| {
95+
Ok(writer.write(data).unwrap())
96+
}).unwrap();
97+
easy.perform().unwrap();
98+
99+
let response_code = easy.response_code().unwrap();
100+
if response_code != 200 {
101+
panic!("Unexpected response code {} for {}", response_code, binary_url);
102+
}
103+
}
104+
105+
// Extract the tarball.
106+
let unpacked_dir = target_dir.join(base_name);
107+
let lib_dir = unpacked_dir.join("lib");
108+
if !lib_dir.join(format!("lib{}.so", LIBRARY)).exists() {
109+
extract(file_name, &unpacked_dir);
110+
}
111+
112+
//run("find", |command| command); // TODO: remove
113+
run("ls", |command| {
114+
command.arg("-l").arg(lib_dir.to_str().unwrap())
115+
}); // TODO: remove
116+
117+
println!("cargo:rustc-link-lib=dylib={}", LIBRARY);
118+
println!("cargo:rustc-link-search={}", lib_dir.display());
119+
}
120+
121+
fn build_from_src() {
33122
let output = PathBuf::from(&get!("OUT_DIR"));
34123
log_var!(output);
35124
let source = PathBuf::from(&get!("CARGO_MANIFEST_DIR")).join(format!("target/source-{}", TAG));
@@ -71,7 +160,10 @@ fn main() {
71160
let configure_hint_file = Path::new(&configure_hint_file_pb);
72161
if !configure_hint_file.exists() {
73162
run("bash",
74-
|command| command.current_dir(&source).arg("-c").arg("yes ''|./configure"));
163+
|command| command.current_dir(&source)
164+
.env("TF_NEED_CUDA", if cfg!(feature = "tensorflow_gpu") {"1"} else {"0"})
165+
.arg("-c")
166+
.arg("yes ''|./configure"));
75167
File::create(configure_hint_file).unwrap();
76168
}
77169
run("bazel", |command| {

0 commit comments

Comments
 (0)