1
+ extern crate curl;
2
+ extern crate flate2;
1
3
extern crate pkg_config;
2
4
extern crate semver;
5
+ extern crate tar;
3
6
4
7
use std:: error:: Error ;
5
8
use std:: fs:: File ;
9
+ use std:: io:: BufWriter ;
10
+ use std:: io:: Write ;
6
11
use std:: path:: { Path , PathBuf } ;
7
12
use std:: process;
8
13
use std:: process:: Command ;
9
14
use std:: { env, fs} ;
15
+
16
+ use curl:: easy:: Easy ;
17
+ use flate2:: read:: GzDecoder ;
10
18
use semver:: Version ;
19
+ use tar:: Archive ;
11
20
12
21
const LIBRARY : & ' static str = "tensorflow" ;
13
22
const REPOSITORY : & ' static str = "https://github.com/tensorflow/tensorflow.git" ;
14
23
const TARGET : & ' static str = "tensorflow:libtensorflow.so" ;
24
+ // `VERSION` and `TAG` are separate because the tag is not always `'v' + VERSION`.
25
+ const VERSION : & ' static str = "1.0.0" ;
15
26
const TAG : & ' static str = "v1.0.0" ;
16
27
const MIN_BAZEL : & ' static str = "0.3.2" ;
17
28
@@ -30,6 +41,84 @@ fn main() {
30
41
return ;
31
42
}
32
43
44
+ if env:: consts:: ARCH == "x86_64" && ( env:: consts:: OS == "linux" || env:: consts:: OS == "macos" ) {
45
+ install_prebuilt ( ) ;
46
+ } else {
47
+ build_from_src ( ) ;
48
+ }
49
+ }
50
+
51
+ fn remove_suffix ( value : & mut String , suffix : & str ) {
52
+ if value. ends_with ( suffix) {
53
+ let n = value. len ( ) ;
54
+ value. truncate ( n - suffix. len ( ) ) ;
55
+ }
56
+ }
57
+
58
+ fn extract < P : AsRef < Path > , P2 : AsRef < Path > > ( archive_path : P , extract_to : P2 ) {
59
+ let file = File :: open ( archive_path) . unwrap ( ) ;
60
+ let unzipped = GzDecoder :: new ( file) . unwrap ( ) ;
61
+ let mut a = Archive :: new ( unzipped) ;
62
+ a. unpack ( extract_to) . unwrap ( ) ;
63
+ }
64
+
65
+ // Downloads and unpacks a prebuilt binary. Only works for certain platforms.
66
+ fn install_prebuilt ( ) {
67
+ // Figure out the file names.
68
+ let os = match env:: consts:: OS {
69
+ "macos" => "darwin" ,
70
+ x => x,
71
+ } ;
72
+ let proc_type = if cfg ! ( feature = "tensorflow_gpu" ) { "gpu" } else { "cpu" } ;
73
+ let binary_url = format ! (
74
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-{}-{}-{}-{}.tar.gz" ,
75
+ proc_type, os, env:: consts:: ARCH , VERSION ) ;
76
+ log_var ! ( binary_url) ;
77
+ let short_file_name = binary_url. split ( "/" ) . last ( ) . unwrap ( ) ;
78
+ let mut base_name = short_file_name. to_string ( ) ;
79
+ remove_suffix ( & mut base_name, ".tar.gz" ) ;
80
+ log_var ! ( base_name) ;
81
+ let target_dir = PathBuf :: from ( & get ! ( "CARGO_MANIFEST_DIR" ) ) . join ( "target" ) ;
82
+ if !target_dir. exists ( ) {
83
+ fs:: create_dir ( & target_dir) . unwrap ( ) ;
84
+ }
85
+ let file_name = target_dir. join ( short_file_name) ;
86
+ log_var ! ( file_name) ;
87
+
88
+ // Download the tarball.
89
+ if !file_name. exists ( ) {
90
+ let f = File :: create ( & file_name) . unwrap ( ) ;
91
+ let mut writer = BufWriter :: new ( f) ;
92
+ let mut easy = Easy :: new ( ) ;
93
+ easy. url ( & binary_url) . unwrap ( ) ;
94
+ easy. write_function ( move |data| {
95
+ Ok ( writer. write ( data) . unwrap ( ) )
96
+ } ) . unwrap ( ) ;
97
+ easy. perform ( ) . unwrap ( ) ;
98
+
99
+ let response_code = easy. response_code ( ) . unwrap ( ) ;
100
+ if response_code != 200 {
101
+ panic ! ( "Unexpected response code {} for {}" , response_code, binary_url) ;
102
+ }
103
+ }
104
+
105
+ // Extract the tarball.
106
+ let unpacked_dir = target_dir. join ( base_name) ;
107
+ let lib_dir = unpacked_dir. join ( "lib" ) ;
108
+ if !lib_dir. join ( format ! ( "lib{}.so" , LIBRARY ) ) . exists ( ) {
109
+ extract ( file_name, & unpacked_dir) ;
110
+ }
111
+
112
+ //run("find", |command| command); // TODO: remove
113
+ run ( "ls" , |command| {
114
+ command. arg ( "-l" ) . arg ( lib_dir. to_str ( ) . unwrap ( ) )
115
+ } ) ; // TODO: remove
116
+
117
+ println ! ( "cargo:rustc-link-lib=dylib={}" , LIBRARY ) ;
118
+ println ! ( "cargo:rustc-link-search={}" , lib_dir. display( ) ) ;
119
+ }
120
+
121
+ fn build_from_src ( ) {
33
122
let output = PathBuf :: from ( & get ! ( "OUT_DIR" ) ) ;
34
123
log_var ! ( output) ;
35
124
let source = PathBuf :: from ( & get ! ( "CARGO_MANIFEST_DIR" ) ) . join ( format ! ( "target/source-{}" , TAG ) ) ;
@@ -71,7 +160,10 @@ fn main() {
71
160
let configure_hint_file = Path :: new ( & configure_hint_file_pb) ;
72
161
if !configure_hint_file. exists ( ) {
73
162
run ( "bash" ,
74
- |command| command. current_dir ( & source) . arg ( "-c" ) . arg ( "yes ''|./configure" ) ) ;
163
+ |command| command. current_dir ( & source)
164
+ . env ( "TF_NEED_CUDA" , if cfg ! ( feature = "tensorflow_gpu" ) { "1" } else { "0" } )
165
+ . arg ( "-c" )
166
+ . arg ( "yes ''|./configure" ) ) ;
75
167
File :: create ( configure_hint_file) . unwrap ( ) ;
76
168
}
77
169
run ( "bazel" , |command| {
0 commit comments