Skip to content

Commit 67fad86

Browse files
Yancey0623 and wangkuiyi
authored and committed
Refactor PAI/Alisa submitter (#1765)
* refactor pai/alisa submitter * update * fix unit test * fix ut
1 parent 4e6ac3a commit 67fad86

File tree

8 files changed

+339
-528
lines changed

8 files changed

+339
-528
lines changed

cmd/repl/repl.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ func getDatabaseName(datasource string) string {
331331
re = regexp.MustCompile(`[^/].*/api[?].*curr_project=(\w*).*`)
332332
case "mysql":
333333
case "hive":
334+
case "alisa": // TODO(yaney1989): using go drivers to parse the database
334335
default:
335336
log.Fatalf("unknown database '%s' in data source'%s'", driver, datasource)
336337
}
@@ -348,7 +349,7 @@ func getDataSource(dataSource, db string) string {
348349
}
349350
pieces := strings.Split(other, "?")
350351
switch driver {
351-
case "maxcompute":
352+
case "maxcompute", "alisa":
352353
var v url.Values = url.Values{}
353354
if len(pieces) == 2 {
354355
v, e = url.ParseQuery(pieces[1])

pkg/sql/alisa_submitter.go

Lines changed: 96 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -14,31 +14,31 @@
1414
package sql
1515

1616
import (
17-
"encoding/json"
1817
"fmt"
1918
"os"
2019
"os/exec"
2120
"path/filepath"
2221
"regexp"
23-
"strconv"
2422
"strings"
2523

2624
"github.com/aliyun/aliyun-oss-go-sdk/oss"
2725
"sqlflow.org/goalisa"
26+
"sqlflow.org/gomaxcompute"
2827
"sqlflow.org/sqlflow/pkg/database"
2928
"sqlflow.org/sqlflow/pkg/ir"
29+
pb "sqlflow.org/sqlflow/pkg/proto"
3030
"sqlflow.org/sqlflow/pkg/sql/codegen/pai"
3131
)
3232

33-
var tarball = "task.tar.gz"
33+
var resourceName = "job.tar.gz"
3434
var entryFile = "entry.py"
3535
var reOSS = regexp.MustCompile(`oss://([^/]+).*host=([^&]+)`)
3636

3737
type alisaSubmitter struct {
3838
*defaultSubmitter
3939
}
4040

41-
func (s *alisaSubmitter) submitAlisaTask(code, resourceName string) error {
41+
func (s *alisaSubmitter) submitAlisaTask(code, resourceURL string) error {
4242
_, dsName, err := database.ParseURL(s.Session.DbConnStr)
4343
if err != nil {
4444
return err
@@ -48,8 +48,7 @@ func (s *alisaSubmitter) submitAlisaTask(code, resourceName string) error {
4848
return e
4949
}
5050

51-
ossURL := fmt.Sprintf("https://%s.%s", os.Getenv("SQLFLOW_OSS_BUCKET"), os.Getenv("SQLFLOW_OSS_ENDPOINT"))
52-
cfg.Env["RES_DOWNLOAD_URL"] = fmt.Sprintf(`[{\"downloadUrl\":\"%s/%s\", \"resourceName\":\"%s\"}]`, ossURL, resourceName, tarball)
51+
cfg.Env["RES_DOWNLOAD_URL"] = fmt.Sprintf(`[{\"downloadUrl\":\"%s\", \"resourceName\":\"%s\"}]`, resourceURL, resourceName)
5352
cfg.Verbose = true
5453
newDatasource := cfg.FormatDSN()
5554

@@ -61,90 +60,35 @@ func (s *alisaSubmitter) submitAlisaTask(code, resourceName string) error {
6160
return e
6261
}
6362

64-
func (s *alisaSubmitter) getModelPath(modelName string) (string, error) {
65-
_, dsName, err := database.ParseURL(s.Session.DbConnStr)
66-
if err != nil {
67-
return "", err
68-
}
69-
cfg, err := goalisa.ParseDSN(dsName)
70-
if err != nil {
71-
return "", err
72-
}
73-
userID := s.Session.UserId
74-
if userID == "" {
75-
userID = "unkown"
76-
}
77-
return strings.Join([]string{cfg.Project, userID, modelName}, "/"), nil
78-
}
79-
8063
func (s *alisaSubmitter) ExecuteTrain(ts *ir.TrainStmt) (e error) {
8164
ts.TmpTrainTable, ts.TmpValidateTable, e = createTempTrainAndValTable(ts.Select, ts.ValidationSelect, s.Session.DbConnStr)
8265
if e != nil {
8366
return e
8467
}
8568
defer dropTmpTables([]string{ts.TmpTrainTable, ts.TmpValidateTable}, s.Session.DbConnStr)
8669

87-
cc, e := pai.GetClusterConfig(ts.Attributes)
70+
ossModelPath, e := getModelPath(ts.Into, s.Session)
8871
if e != nil {
8972
return e
9073
}
9174

92-
modelPath, e := s.getModelPath(ts.Into)
75+
// cleanup saved model on OSS before training
76+
modelBucket, e := getModelBucket()
9377
if e != nil {
9478
return e
9579
}
96-
97-
paiCmd, e := getPAIcmd(cc, ts.Into, modelPath, ts.TmpTrainTable, ts.TmpValidateTable, "")
98-
if e != nil {
80+
if e := modelBucket.DeleteObject(ossModelPath); e != nil {
9981
return e
10082
}
10183

102-
code, e := pai.TFTrainAndSave(ts, s.Session, modelPath, cc)
84+
// Alisa resource should be prefix with @@, alisa source would replace it with the RES_DOWN_URL.resourceName in alisa env.
85+
scriptPath := fmt.Sprintf("file://@@%s", resourceName)
86+
code, paiCmd, requirements, e := pai.Train(ts, s.Session, scriptPath, ts.Into, ossModelPath, s.Cwd)
10387
if e != nil {
10488
return e
10589
}
106-
107-
if e := s.cleanUpModel(modelPath); e != nil {
108-
return e
109-
}
110-
111-
return s.submit(code, paiCmd)
112-
}
113-
114-
func (s *alisaSubmitter) submit(program, alisaCode string) error {
115-
if e := s.achieveResource(program, tarball); e != nil {
116-
return e
117-
}
118-
119-
// upload Alisa resource file to OSS
120-
resourceName := randStringRunes(16)
121-
bucket, err := getBucket(os.Getenv("SQLFLOW_OSS_ENDPOINT"),
122-
os.Getenv("SQLFLOW_OSS_AK"), os.Getenv("SQLFLOW_OSS_SK"), os.Getenv("SQLFLOW_OSS_BUCKET"))
123-
if err != nil {
124-
return err
125-
}
126-
if e := bucket.PutObjectFromFile(resourceName, filepath.Join(s.Cwd, tarball)); e != nil {
127-
return err
128-
}
129-
defer bucket.DeleteObject(resourceName)
130-
131-
return s.submitAlisaTask(alisaCode, resourceName)
132-
}
133-
134-
func (s *alisaSubmitter) cleanUpModel(modelPath string) error {
135-
ossCkptDir := os.Getenv("SQLFLOW_OSS_CHECKPOINT_DIR")
136-
sub := reOSS.FindStringSubmatch(ossCkptDir)
137-
if len(sub) != 3 {
138-
return fmt.Errorf("SQLFLOW_OSS_CHECKPOINT_DIR should be format: oss://bucket/?role_arn=xxx&host=xxx")
139-
}
140-
bucket, e := getBucket(sub[2], os.Getenv("SQLFLOW_OSS_AK"), os.Getenv("SQLFLOW_OSS_SK"), sub[1])
141-
if e != nil {
142-
return e
143-
}
144-
if e := bucket.DeleteObject(modelPath); e != nil {
145-
return e
146-
}
147-
return nil
90+
// upload generated program to OSS and submit an Alisa task.
91+
return s.uploadResourceAndSubmitAlisaTask(code, requirements, paiCmd)
14892
}
14993

15094
func (s *alisaSubmitter) ExecutePredict(ps *ir.PredictStmt) error {
@@ -159,23 +103,37 @@ func (s *alisaSubmitter) ExecutePredict(ps *ir.PredictStmt) error {
159103
return e
160104
}
161105

162-
cc, e := pai.GetClusterConfig(ps.Attributes)
106+
ossModelPath, e := getModelPath(ps.Using, s.Session)
163107
if e != nil {
164108
return e
165109
}
166-
modelPath, e := s.getModelPath(ps.Using)
110+
isDeepModel, e := ossModelFileExists(ossModelPath)
111+
if e != nil {
112+
return e
113+
}
114+
115+
scriptPath := fmt.Sprintf("file://@@%s", resourceName)
116+
code, paiCmd, requirements, e := pai.Predict(ps, s.Session, scriptPath, ps.Using, ossModelPath, s.Cwd, isDeepModel)
167117
if e != nil {
168118
return e
169119
}
170-
paiCmd, e := getPAIcmd(cc, ps.Using, modelPath, ps.TmpPredictTable, "", ps.ResultTable)
120+
return s.uploadResourceAndSubmitAlisaTask(code, requirements, paiCmd)
121+
}
122+
123+
func (s *alisaSubmitter) uploadResourceAndSubmitAlisaTask(entryCode, requirements, alisaExecCode string) error {
124+
// achieve and upload alisa Resource
125+
ossObjectName := randStringRunes(16)
126+
alisaBucket, e := getAlisaBucket()
171127
if e != nil {
172128
return e
173129
}
174-
code, e := pai.TFLoadAndPredict(ps, s.Session, modelPath)
130+
resourceURL, e := tarAndUploadResource(s.Cwd, entryCode, requirements, ossObjectName, alisaBucket)
175131
if e != nil {
176132
return e
177133
}
178-
return s.submit(code, paiCmd)
134+
defer alisaBucket.DeleteObject(ossObjectName)
135+
// upload generated program to OSS and submit an Alisa task.
136+
return s.submitAlisaTask(alisaExecCode, resourceURL)
179137
}
180138

181139
func (s *alisaSubmitter) ExecuteExplain(cl *ir.ExplainStmt) error {
@@ -184,29 +142,6 @@ func (s *alisaSubmitter) ExecuteExplain(cl *ir.ExplainStmt) error {
184142

185143
func (s *alisaSubmitter) GetTrainStmtFromModel() bool { return false }
186144

187-
func (s *alisaSubmitter) achieveResource(entryCode, tarball string) error {
188-
if err := writeFile(filepath.Join(s.Cwd, entryFile), entryCode); err != nil {
189-
return err
190-
}
191-
192-
path, err := findPyModulePath("sqlflow_submitter")
193-
if err != nil {
194-
return err
195-
}
196-
cmd := exec.Command("cp", "-r", path, ".")
197-
cmd.Dir = s.Cwd
198-
if _, err := cmd.CombinedOutput(); err != nil {
199-
return fmt.Errorf("failed %s, %v", cmd, err)
200-
}
201-
202-
cmd = exec.Command("tar", "czf", tarball, "./sqlflow_submitter", entryFile)
203-
cmd.Dir = s.Cwd
204-
if _, err := cmd.CombinedOutput(); err != nil {
205-
return fmt.Errorf("failed %s, %v", cmd, err)
206-
}
207-
return nil
208-
}
209-
210145
func findPyModulePath(pyModuleName string) (string, error) {
211146
cmd := exec.Command("python", "-c", fmt.Sprintf(`import %s;print(%s.__path__[0])`, pyModuleName, pyModuleName))
212147
out, err := cmd.CombinedOutput()
@@ -216,8 +151,38 @@ func findPyModulePath(pyModuleName string) (string, error) {
216151
return strings.TrimSpace(string(out)), nil
217152
}
218153

219-
func getBucket(endpoint, ak, sk, bucketName string) (*oss.Bucket, error) {
220-
cli, err := oss.New(endpoint, ak, sk)
154+
func getModelBucket() (*oss.Bucket, error) {
155+
ossCkptDir := os.Getenv("SQLFLOW_OSS_CHECKPOINT_DIR")
156+
ak := os.Getenv("SQLFLOW_OSS_AK")
157+
sk := os.Getenv("SQLFLOW_OSS_SK")
158+
ep := os.Getenv("SQLFLOW_OSS_MODEL_ENDPOINT")
159+
if ak == "" || sk == "" || ep == "" || ossCkptDir == "" {
160+
return nil, fmt.Errorf("should define SQLFLOW_OSS_MODEL_ENDPOINT, SQLFLOW_OSS_CHECKPOINT_DIR, SQLFLOW_OSS_AK, SQLFLOW_OSS_SK when using submitter alisa")
161+
}
162+
163+
sub := reOSS.FindStringSubmatch(ossCkptDir)
164+
if len(sub) != 3 {
165+
return nil, fmt.Errorf("SQLFLOW_OSS_CHECKPOINT_DIR should be format: oss://bucket/?role_arn=xxx&host=xxx")
166+
}
167+
bucketName := sub[1]
168+
cli, e := oss.New(ep, ak, sk)
169+
if e != nil {
170+
return nil, e
171+
}
172+
return cli.Bucket(bucketName)
173+
}
174+
175+
func getAlisaBucket() (*oss.Bucket, error) {
176+
ep := os.Getenv("SQLFLOW_OSS_ALISA_ENDPOINT")
177+
ak := os.Getenv("SQLFLOW_OSS_AK")
178+
sk := os.Getenv("SQLFLOW_OSS_SK")
179+
bucketName := os.Getenv("SQLFLOW_OSS_ALISA_BUCKET")
180+
181+
if ep == "" || ak == "" || sk == "" {
182+
return nil, fmt.Errorf("should define SQLFLOW_OSS_ALISA_ENDPOINT, SQLFLOW_OSS_ALISA_BUCKET, SQLFLOW_OSS_AK, SQLFLOW_OSS_SK when using submitter alisa")
183+
}
184+
185+
cli, err := oss.New(ep, ak, sk)
221186
if err != nil {
222187
return nil, err
223188
}
@@ -234,48 +199,41 @@ func writeFile(filePath, program string) error {
234199
return nil
235200
}
236201

237-
func odpsTables(table string) (string, error) {
238-
parts := strings.Split(table, ".")
239-
if len(parts) != 2 {
240-
return "", fmt.Errorf("odps table: %s should be format db.table", table)
202+
func getModelPath(modelName string, session *pb.Session) (string, error) {
203+
driverName, dsName, e := database.ParseURL(session.DbConnStr)
204+
if e != nil {
205+
return "", e
206+
}
207+
userID := session.UserId
208+
var projectName string
209+
if driverName == "maxcompute" {
210+
cfg, e := gomaxcompute.ParseDSN(dsName)
211+
if e != nil {
212+
return "", e
213+
}
214+
projectName = cfg.Project
215+
} else if driverName == "alisa" {
216+
cfg, e := goalisa.ParseDSN(dsName)
217+
if e != nil {
218+
return "", e
219+
}
220+
projectName = cfg.Project
241221
}
242-
return fmt.Sprintf("odps://%s/tables/%s", parts[0], parts[1]), nil
222+
if userID == "" {
223+
userID = "unkown"
224+
}
225+
return strings.Join([]string{projectName, userID, modelName}, "/"), nil
243226
}
244227

245-
func getPAIcmd(cc *pai.ClusterConfig, modelName, ossModelPath, trainTable, valTable, resTable string) (string, error) {
246-
jobName := strings.Replace(strings.Join([]string{"sqlflow", modelName}, "_"), ".", "_", 0)
247-
cfString, err := json.Marshal(cc)
248-
if err != nil {
249-
return "", err
250-
}
251-
cfQuote := strconv.Quote(string(cfString))
252-
ckpDir, err := pai.FormatCkptDir(ossModelPath)
253-
if err != nil {
254-
return "", err
228+
func tarAndUploadResource(cwd, entryCode, requirements, ossObjectName string, bucket *oss.Bucket) (string, error) {
229+
tarball := "job.tar.gz"
230+
if e := achieveResource(cwd, entryCode, requirements, tarball); e != nil {
231+
return "", e
255232
}
233+
resourceURL := fmt.Sprintf("https://%s.%s/%s", bucket.BucketName, bucket.Client.Config.Endpoint, ossObjectName)
256234

257-
// submit table should format as: odps://<project>/tables/<table>,odps://<project>/tables/<table>...
258-
submitTables, err := odpsTables(trainTable)
259-
if err != nil {
260-
return "", err
261-
}
262-
if trainTable != valTable && valTable != "" {
263-
valTable, err := odpsTables(valTable)
264-
if err != nil {
265-
return "", err
266-
}
267-
submitTables = fmt.Sprintf("%s,%s", submitTables, valTable)
268-
}
269-
outputTables := ""
270-
if resTable != "" {
271-
table, err := odpsTables(resTable)
272-
if err != nil {
273-
return "", err
274-
}
275-
outputTables = fmt.Sprintf("-Doutputs=%s", table)
276-
}
277-
if cc.Worker.Count > 1 {
278-
return fmt.Sprintf("pai -name tensorflow1120 -DjobName=%s -Dtags=dnn -Dscript=file://@@%s -DentryFile=entry.py -Dtables=%s %s -DcheckpointDir=\"%s\" -Dcluster=%s", jobName, tarball, submitTables, outputTables, ckpDir, cfQuote), nil
235+
if e := bucket.PutObjectFromFile(ossObjectName, filepath.Join(cwd, tarball)); e != nil {
236+
return "", e
279237
}
280-
return fmt.Sprintf("pai -name tensorflow1120 -DjobName=%s -Dtags=dnn -Dscript=file://@@%s -DentryFile=entry.py -Dtables=%s %s -DcheckpointDir=\"%s\"", jobName, tarball, submitTables, outputTables, ckpDir), nil
238+
return resourceURL, nil
281239
}

pkg/sql/alisa_submitter_test.go

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,9 @@
1414
package sql
1515

1616
import (
17-
"fmt"
18-
"os"
1917
"testing"
2018

2119
"github.com/stretchr/testify/assert"
22-
"sqlflow.org/sqlflow/pkg/sql/codegen/pai"
2320
)
2421

2522
func TestAlisaSubmitter(t *testing.T) {
@@ -33,27 +30,3 @@ func TestFindPyModulePath(t *testing.T) {
3330
_, err := findPyModulePath("sqlflow_submitter")
3431
a.NoError(err)
3532
}
36-
37-
func TestGetPAICmd(t *testing.T) {
38-
a := assert.New(t)
39-
cc := &pai.ClusterConfig{
40-
Worker: pai.WorkerConfig{
41-
Count: 1,
42-
CPU: 2,
43-
GPU: 0,
44-
},
45-
PS: pai.PSConfig{
46-
Count: 2,
47-
CPU: 4,
48-
GPU: 0,
49-
},
50-
}
51-
os.Setenv("SQLFLOW_OSS_CHECKPOINT_DIR", "oss://bucket/?role_arn=xxx&host=xxx")
52-
defer os.Unsetenv("SQLFLOW_OSS_CHECKPOINT_DIR")
53-
paiCmd, err := getPAIcmd(cc, "my_model", "project/12345/my_model", "testdb.test", "", "testdb.result")
54-
a.NoError(err)
55-
ckpDir, err := pai.FormatCkptDir("project/12345/my_model")
56-
a.NoError(err)
57-
expected := fmt.Sprintf("pai -name tensorflow1120 -DjobName=sqlflow_my_model -Dtags=dnn -Dscript=file://@@task.tar.gz -DentryFile=entry.py -Dtables=odps://testdb/tables/test -Doutputs=odps://testdb/tables/result -DcheckpointDir=\"%s\"", ckpDir)
58-
a.Equal(expected, paiCmd)
59-
}

0 commit comments

Comments
 (0)