package sql

import (
-	"encoding/json"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
-	"strconv"
	"strings"

	"github.com/aliyun/aliyun-oss-go-sdk/oss"
	"sqlflow.org/goalisa"
+	"sqlflow.org/gomaxcompute"
	"sqlflow.org/sqlflow/pkg/database"
	"sqlflow.org/sqlflow/pkg/ir"
+	pb "sqlflow.org/sqlflow/pkg/proto"
	"sqlflow.org/sqlflow/pkg/sql/codegen/pai"
)

-var tarball = "task.tar.gz"
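+// resourceName is the file name of the tarball that packages the generated
+// entry program; Alisa downloads it via the RES_DOWNLOAD_URL environment entry.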
+var resourceName = "job.tar.gz"
var entryFile = "entry.py"
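+// reOSS captures the bucket name and the host from an OSS URL of the form
+// oss://bucket/?role_arn=xxx&host=xxx.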
var reOSS = regexp.MustCompile(`oss://([^/]+).*host=([^&]+)`)

type alisaSubmitter struct {
	*defaultSubmitter
}

-func (s *alisaSubmitter) submitAlisaTask(code, resourceName string) error {
+func (s *alisaSubmitter) submitAlisaTask(code, resourceURL string) error {
	_, dsName, err := database.ParseURL(s.Session.DbConnStr)
	if err != nil {
		return err
@@ -48,8 +48,7 @@ func (s *alisaSubmitter) submitAlisaTask(code, resourceName string) error {
		return e
	}

-	ossURL := fmt.Sprintf("https://%s.%s", os.Getenv("SQLFLOW_OSS_BUCKET"), os.Getenv("SQLFLOW_OSS_ENDPOINT"))
-	cfg.Env["RES_DOWNLOAD_URL"] = fmt.Sprintf(`[{\"downloadUrl\":\"%s/%s\", \"resourceName\":\"%s\"}]`, ossURL, resourceName, tarball)
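+	// RES_DOWNLOAD_URL is a JSON list that tells the Alisa runtime where to
+	// fetch the task resource, e.g.
+	// [{"downloadUrl":"https://<bucket>.<endpoint>/<object>", "resourceName":"job.tar.gz"}]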
+	cfg.Env["RES_DOWNLOAD_URL"] = fmt.Sprintf(`[{\"downloadUrl\":\"%s\", \"resourceName\":\"%s\"}]`, resourceURL, resourceName)
	cfg.Verbose = true
	newDatasource := cfg.FormatDSN()

@@ -61,90 +60,35 @@ func (s *alisaSubmitter) submitAlisaTask(code, resourceName string) error {
	return e
}

-func (s *alisaSubmitter) getModelPath(modelName string) (string, error) {
-	_, dsName, err := database.ParseURL(s.Session.DbConnStr)
-	if err != nil {
-		return "", err
-	}
-	cfg, err := goalisa.ParseDSN(dsName)
-	if err != nil {
-		return "", err
-	}
-	userID := s.Session.UserId
-	if userID == "" {
-		userID = "unkown"
-	}
-	return strings.Join([]string{cfg.Project, userID, modelName}, "/"), nil
-}
-
func (s *alisaSubmitter) ExecuteTrain(ts *ir.TrainStmt) (e error) {
	ts.TmpTrainTable, ts.TmpValidateTable, e = createTempTrainAndValTable(ts.Select, ts.ValidationSelect, s.Session.DbConnStr)
	if e != nil {
		return e
	}
	defer dropTmpTables([]string{ts.TmpTrainTable, ts.TmpValidateTable}, s.Session.DbConnStr)

-	cc, e := pai.GetClusterConfig(ts.Attributes)
+	ossModelPath, e := getModelPath(ts.Into, s.Session)
	if e != nil {
		return e
	}

-	modelPath, e := s.getModelPath(ts.Into)
+	// clean up the previously saved model on OSS before training
+	modelBucket, e := getModelBucket()
	if e != nil {
		return e
	}
-
-	paiCmd, e := getPAIcmd(cc, ts.Into, modelPath, ts.TmpTrainTable, ts.TmpValidateTable, "")
-	if e != nil {
+	if e := modelBucket.DeleteObject(ossModelPath); e != nil {
		return e
	}

-	code, e := pai.TFTrainAndSave(ts, s.Session, modelPath, cc)
+	// An Alisa resource path must be prefixed with "@@"; the Alisa runtime
+	// replaces the prefix with the resource named by
+	// RES_DOWNLOAD_URL.resourceName, e.g. file://@@job.tar.gz.
+	scriptPath := fmt.Sprintf("file://@@%s", resourceName)
+	code, paiCmd, requirements, e := pai.Train(ts, s.Session, scriptPath, ts.Into, ossModelPath, s.Cwd)
	if e != nil {
		return e
	}
-
-	if e := s.cleanUpModel(modelPath); e != nil {
-		return e
-	}
-
-	return s.submit(code, paiCmd)
-}
-
-func (s *alisaSubmitter) submit(program, alisaCode string) error {
-	if e := s.achieveResource(program, tarball); e != nil {
-		return e
-	}
-
-	// upload Alisa resource file to OSS
-	resourceName := randStringRunes(16)
-	bucket, err := getBucket(os.Getenv("SQLFLOW_OSS_ENDPOINT"),
-		os.Getenv("SQLFLOW_OSS_AK"), os.Getenv("SQLFLOW_OSS_SK"), os.Getenv("SQLFLOW_OSS_BUCKET"))
-	if err != nil {
-		return err
-	}
-	if e := bucket.PutObjectFromFile(resourceName, filepath.Join(s.Cwd, tarball)); e != nil {
-		return err
-	}
-	defer bucket.DeleteObject(resourceName)
-
-	return s.submitAlisaTask(alisaCode, resourceName)
-}
-
-func (s *alisaSubmitter) cleanUpModel(modelPath string) error {
-	ossCkptDir := os.Getenv("SQLFLOW_OSS_CHECKPOINT_DIR")
-	sub := reOSS.FindStringSubmatch(ossCkptDir)
-	if len(sub) != 3 {
-		return fmt.Errorf("SQLFLOW_OSS_CHECKPOINT_DIR should be format: oss://bucket/?role_arn=xxx&host=xxx")
-	}
-	bucket, e := getBucket(sub[2], os.Getenv("SQLFLOW_OSS_AK"), os.Getenv("SQLFLOW_OSS_SK"), sub[1])
-	if e != nil {
-		return e
-	}
-	if e := bucket.DeleteObject(modelPath); e != nil {
-		return e
-	}
-	return nil
+	// upload the generated program to OSS and submit an Alisa task
+	return s.uploadResourceAndSubmitAlisaTask(code, requirements, paiCmd)
}

func (s *alisaSubmitter) ExecutePredict(ps *ir.PredictStmt) error {
@@ -159,23 +103,37 @@ func (s *alisaSubmitter) ExecutePredict(ps *ir.PredictStmt) error {
		return e
	}

-	cc, e := pai.GetClusterConfig(ps.Attributes)
+	ossModelPath, e := getModelPath(ps.Using, s.Session)
	if e != nil {
		return e
	}
-	modelPath, e := s.getModelPath(ps.Using)
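+	// the presence of a saved model file on OSS indicates a deep model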
+	isDeepModel, e := ossModelFileExists(ossModelPath)
+	if e != nil {
+		return e
+	}
+
+	scriptPath := fmt.Sprintf("file://@@%s", resourceName)
+	code, paiCmd, requirements, e := pai.Predict(ps, s.Session, scriptPath, ps.Using, ossModelPath, s.Cwd, isDeepModel)
	if e != nil {
		return e
	}
-	paiCmd, e := getPAIcmd(cc, ps.Using, modelPath, ps.TmpPredictTable, "", ps.ResultTable)
+	return s.uploadResourceAndSubmitAlisaTask(code, requirements, paiCmd)
+}
+
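+// uploadResourceAndSubmitAlisaTask archives the generated entry program,
+// uploads the tarball to OSS, and submits alisaExecCode as an Alisa task.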
+func (s *alisaSubmitter) uploadResourceAndSubmitAlisaTask(entryCode, requirements, alisaExecCode string) error {
+	// archive the program and upload it as an Alisa resource
+	ossObjectName := randStringRunes(16)
+	alisaBucket, e := getAlisaBucket()
	if e != nil {
		return e
	}
-	code, e := pai.TFLoadAndPredict(ps, s.Session, modelPath)
+	resourceURL, e := tarAndUploadResource(s.Cwd, entryCode, requirements, ossObjectName, alisaBucket)
	if e != nil {
		return e
	}
-	return s.submit(code, paiCmd)
+	defer alisaBucket.DeleteObject(ossObjectName)
+	// submit the Alisa task; it downloads the resource from resourceURL
+	return s.submitAlisaTask(alisaExecCode, resourceURL)
}

func (s *alisaSubmitter) ExecuteExplain(cl *ir.ExplainStmt) error {
@@ -184,29 +142,6 @@ func (s *alisaSubmitter) ExecuteExplain(cl *ir.ExplainStmt) error {

func (s *alisaSubmitter) GetTrainStmtFromModel() bool { return false }

-func (s *alisaSubmitter) achieveResource(entryCode, tarball string) error {
-	if err := writeFile(filepath.Join(s.Cwd, entryFile), entryCode); err != nil {
-		return err
-	}
-
-	path, err := findPyModulePath("sqlflow_submitter")
-	if err != nil {
-		return err
-	}
-	cmd := exec.Command("cp", "-r", path, ".")
-	cmd.Dir = s.Cwd
-	if _, err := cmd.CombinedOutput(); err != nil {
-		return fmt.Errorf("failed %s, %v", cmd, err)
-	}
-
-	cmd = exec.Command("tar", "czf", tarball, "./sqlflow_submitter", entryFile)
-	cmd.Dir = s.Cwd
-	if _, err := cmd.CombinedOutput(); err != nil {
-		return fmt.Errorf("failed %s, %v", cmd, err)
-	}
-	return nil
-}
-
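+// findPyModulePath asks the Python interpreter for the install path of the
+// given module.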
func findPyModulePath(pyModuleName string) (string, error) {
	cmd := exec.Command("python", "-c", fmt.Sprintf(`import %s;print(%s.__path__[0])`, pyModuleName, pyModuleName))
	out, err := cmd.CombinedOutput()
@@ -216,8 +151,38 @@ func findPyModulePath(pyModuleName string) (string, error) {
	return strings.TrimSpace(string(out)), nil
}

-func getBucket(endpoint, ak, sk, bucketName string) (*oss.Bucket, error) {
-	cli, err := oss.New(endpoint, ak, sk)
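+// getModelBucket returns the OSS bucket that stores trained models; the
+// bucket name is parsed from SQLFLOW_OSS_CHECKPOINT_DIR.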
+func getModelBucket() (*oss.Bucket, error) {
+	ossCkptDir := os.Getenv("SQLFLOW_OSS_CHECKPOINT_DIR")
+	ak := os.Getenv("SQLFLOW_OSS_AK")
+	sk := os.Getenv("SQLFLOW_OSS_SK")
+	ep := os.Getenv("SQLFLOW_OSS_MODEL_ENDPOINT")
+	if ak == "" || sk == "" || ep == "" || ossCkptDir == "" {
+		return nil, fmt.Errorf("should define SQLFLOW_OSS_MODEL_ENDPOINT, SQLFLOW_OSS_CHECKPOINT_DIR, SQLFLOW_OSS_AK, SQLFLOW_OSS_SK when using submitter alisa")
+	}
+
+	sub := reOSS.FindStringSubmatch(ossCkptDir)
+	if len(sub) != 3 {
+		return nil, fmt.Errorf("SQLFLOW_OSS_CHECKPOINT_DIR should be format: oss://bucket/?role_arn=xxx&host=xxx")
+	}
+	bucketName := sub[1]
+	cli, e := oss.New(ep, ak, sk)
+	if e != nil {
+		return nil, e
+	}
+	return cli.Bucket(bucketName)
+}
+
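+// getAlisaBucket returns the OSS bucket that stages Alisa task resources.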
+func getAlisaBucket() (*oss.Bucket, error) {
+	ep := os.Getenv("SQLFLOW_OSS_ALISA_ENDPOINT")
+	ak := os.Getenv("SQLFLOW_OSS_AK")
+	sk := os.Getenv("SQLFLOW_OSS_SK")
+	bucketName := os.Getenv("SQLFLOW_OSS_ALISA_BUCKET")
+
+	if ep == "" || ak == "" || sk == "" || bucketName == "" {
+		return nil, fmt.Errorf("should define SQLFLOW_OSS_ALISA_ENDPOINT, SQLFLOW_OSS_ALISA_BUCKET, SQLFLOW_OSS_AK, SQLFLOW_OSS_SK when using submitter alisa")
+	}
+
+
+	cli, err := oss.New(ep, ak, sk)
	if err != nil {
		return nil, err
	}
@@ -234,48 +199,41 @@ func writeFile(filePath, program string) error {
	return nil
}

-func odpsTables(table string) (string, error) {
-	parts := strings.Split(table, ".")
-	if len(parts) != 2 {
-		return "", fmt.Errorf("odps table: %s should be format db.table", table)
+func getModelPath(modelName string, session *pb.Session) (string, error) {
+	driverName, dsName, e := database.ParseURL(session.DbConnStr)
+	if e != nil {
+		return "", e
+	}
+	userID := session.UserId
+	var projectName string
+	if driverName == "maxcompute" {
+		cfg, e := gomaxcompute.ParseDSN(dsName)
+		if e != nil {
+			return "", e
+		}
+		projectName = cfg.Project
+	} else if driverName == "alisa" {
+		cfg, e := goalisa.ParseDSN(dsName)
+		if e != nil {
+			return "", e
+		}
+		projectName = cfg.Project
	}
-	return fmt.Sprintf("odps://%s/tables/%s", parts[0], parts[1]), nil
+	if userID == "" {
+		userID = "unknown"
+	}
+	return strings.Join([]string{projectName, userID, modelName}, "/"), nil
}

-func getPAIcmd(cc *pai.ClusterConfig, modelName, ossModelPath, trainTable, valTable, resTable string) (string, error) {
-	jobName := strings.Replace(strings.Join([]string{"sqlflow", modelName}, "_"), ".", "_", 0)
-	cfString, err := json.Marshal(cc)
-	if err != nil {
-		return "", err
-	}
-	cfQuote := strconv.Quote(string(cfString))
-	ckpDir, err := pai.FormatCkptDir(ossModelPath)
-	if err != nil {
-		return "", err
+func tarAndUploadResource(cwd, entryCode, requirements, ossObjectName string, bucket *oss.Bucket) (string, error) {
+	tarball := "job.tar.gz"
+	if e := achieveResource(cwd, entryCode, requirements, tarball); e != nil {
+		return "", e
	}
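+	// the uploaded object is addressable as https://<bucket>.<endpoint>/<object>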
+	resourceURL := fmt.Sprintf("https://%s.%s/%s", bucket.BucketName, bucket.Client.Config.Endpoint, ossObjectName)

-	// submit table should format as: odps://<project>/tables/<table>,odps://<project>/tables/<table>...
-	submitTables, err := odpsTables(trainTable)
-	if err != nil {
-		return "", err
-	}
-	if trainTable != valTable && valTable != "" {
-		valTable, err := odpsTables(valTable)
-		if err != nil {
-			return "", err
-		}
-		submitTables = fmt.Sprintf("%s,%s", submitTables, valTable)
-	}
-	outputTables := ""
-	if resTable != "" {
-		table, err := odpsTables(resTable)
-		if err != nil {
-			return "", err
-		}
-		outputTables = fmt.Sprintf("-Doutputs=%s", table)
-	}
-	if cc.Worker.Count > 1 {
-		return fmt.Sprintf("pai -name tensorflow1120 -DjobName=%s -Dtags=dnn -Dscript=file://@@%s -DentryFile=entry.py -Dtables=%s %s -DcheckpointDir=\"%s\" -Dcluster=%s", jobName, tarball, submitTables, outputTables, ckpDir, cfQuote), nil
+	if e := bucket.PutObjectFromFile(ossObjectName, filepath.Join(cwd, tarball)); e != nil {
+		return "", e
	}
-	return fmt.Sprintf("pai -name tensorflow1120 -DjobName=%s -Dtags=dnn -Dscript=file://@@%s -DentryFile=entry.py -Dtables=%s %s -DcheckpointDir=\"%s\"", jobName, tarball, submitTables, outputTables, ckpDir), nil
+	return resourceURL, nil
}