Skip to content

Commit 483b867

Browse files
authored
A new format of RoleArn for PAI to access OSS (#2639)
* pai access oss by checkpoint * rename SQLFLOW_OSS_CHECKPOINT_DIR to SQLFLOW_OSS_CHECKPOINT_CONFIG
1 parent 7f5dd8a commit 483b867

File tree

5 files changed

+65
-27
lines changed

5 files changed

+65
-27
lines changed

go/cmd/sqlflowserver/e2e_alisa_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626
// below environment variables to run them:
2727
// SQLFLOW_submitter=alisa
2828
// SQLFLOW_TEST_DATASOURCE="xxx"
29-
// SQLFLOW_OSS_CHECKPOINT_DIR="xxx"
29+
// SQLFLOW_OSS_CHECKPOINT_CONFIG="xxx"
3030
// SQLFLOW_OSS_ALISA_ENDPOINT="xxx"
3131
// SQLFLOW_OSS_AK="xxx"
3232
// SQLFLOW_OSS_SK="xxx"

go/cmd/sqlflowserver/e2e_pai_maxcompute_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ USING e2etest_keras_dnn;`, caseTestTable, caseDB)
444444
// SQLFLOW_TEST_DB_MAXCOMPUTE_ENDPOINT="xxx"
445445
// SQLFLOW_TEST_DB_MAXCOMPUTE_AK="xxx"
446446
// SQLFLOW_TEST_DB_MAXCOMPUTE_SK="xxx"
447-
// SQLFLOW_OSS_CHECKPOINT_DIR="xxx"
447+
// SQLFLOW_OSS_CHECKPOINT_CONFIG="xxx"
448448
// SQLFLOW_OSS_ENDPOINT="xxx"
449449
// SQLFLOW_OSS_AK="xxx"
450450
// SQLFLOW_OSS_SK="xxx"

go/codegen/pai/codegen_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ func TestTrainCodegen(t *testing.T) {
118118
a := assert.New(t)
119119
trainStmt := ir.MockTrainStmt(false)
120120

121-
os.Setenv("SQLFLOW_OSS_CHECKPOINT_DIR", "{\"project\": \"oss://bucket/?role_arn=xxx&host=xxx\"}")
122-
defer os.Unsetenv("SQLFLOW_OSS_CHECKPOINT_DIR")
121+
os.Setenv("SQLFLOW_OSS_CHECKPOINT_CONFIG", "{\"host\": \"h.com\", \"arn\": \"acs:ram::9527:role\"}")
122+
defer os.Unsetenv("SQLFLOW_OSS_CHECKPOINT_CONFIG")
123123

124124
sess := mockSession()
125125
ossModelPath := "iris/sqlflow/my_dnn_model"
@@ -135,16 +135,16 @@ func TestTrainCodegen(t *testing.T) {
135135
a.True(hasExportedLocal(tfCode))
136136
a.False(hasUnknownParameters(paiTFCode, knownTrainParams))
137137

138-
expectedPAICmd := fmt.Sprintf("pai -name tensorflow1150 -project algo_public_dev -DmaxHungTimeBeforeGCInSeconds=0 -DjobName=sqlflow_my_dnn_model -Dtags=dnn -Dscript=%s -DentryFile=entry.py -Dtables=odps://iris/tables/train,odps://iris/tables/test -DhyperParameters=\"%s\" -DcheckpointDir='oss://sqlflow-models/iris/sqlflow/my_dnn_model/?role_arn=xxx&host=xxx' -DgpuRequired='0'", scriptPath, paramsPath)
138+
expectedPAICmd := fmt.Sprintf("pai -name tensorflow1150 -project algo_public_dev -DmaxHungTimeBeforeGCInSeconds=0 -DjobName=sqlflow_my_dnn_model -Dtags=dnn -Dscript=%s -DentryFile=entry.py -Dtables=odps://iris/tables/train,odps://iris/tables/test -DhyperParameters=\"%s\" -DcheckpointDir='oss://sqlflow-models/iris/sqlflow/my_dnn_model/?role_arn=acs:ram::9527:role/pai2oss_project&host=h.com' -DgpuRequired='0'", scriptPath, paramsPath)
139139
a.Equal(expectedPAICmd, paiCmd)
140140
}
141141

142142
func TestPredictCodegen(t *testing.T) {
143143
a := assert.New(t)
144144
ir := ir.MockPredStmt(ir.MockTrainStmt(false))
145145

146-
os.Setenv("SQLFLOW_OSS_CHECKPOINT_DIR", "{\"project\": \"oss://bucket/?role_arn=xxx&host=xxx\"}")
147-
defer os.Unsetenv("SQLFLOW_OSS_CHECKPOINT_DIR")
146+
os.Setenv("SQLFLOW_OSS_CHECKPOINT_CONFIG", "{\"host\": \"h.com\", \"arn\": \"acs:ram::9527:role\"}")
147+
defer os.Unsetenv("SQLFLOW_OSS_CHECKPOINT_CONFIG")
148148
sess := mockSession()
149149
ossModelPath := "iris/sqlflow/my_dnn_model"
150150
scriptPath := "file:///tmp/task.tar.gz"
@@ -157,6 +157,6 @@ func TestPredictCodegen(t *testing.T) {
157157

158158
a.True(hasExportedLocal(tfCode))
159159
a.False(hasUnknownParameters(tfCode, knownPredictParams))
160-
expectedPAICmd := fmt.Sprintf("pai -name tensorflow1150 -project algo_public_dev -DmaxHungTimeBeforeGCInSeconds=0 -DjobName=sqlflow_my_dnn_model -Dtags=dnn -Dscript=%s -DentryFile=entry.py -Dtables=odps://iris/tables/predict -Doutputs=odps://iris/tables/predict -DhyperParameters=\"%s\" -DcheckpointDir='oss://sqlflow-models/iris/sqlflow/my_dnn_model/?role_arn=xxx&host=xxx' -DgpuRequired='0'", scriptPath, paramsPath)
160+
expectedPAICmd := fmt.Sprintf("pai -name tensorflow1150 -project algo_public_dev -DmaxHungTimeBeforeGCInSeconds=0 -DjobName=sqlflow_my_dnn_model -Dtags=dnn -Dscript=%s -DentryFile=entry.py -Dtables=odps://iris/tables/predict -Doutputs=odps://iris/tables/predict -DhyperParameters=\"%s\" -DcheckpointDir='oss://sqlflow-models/iris/sqlflow/my_dnn_model/?role_arn=acs:ram::9527:role/pai2oss_project&host=h.com' -DgpuRequired='0'", scriptPath, paramsPath)
161161
a.Equal(expectedPAICmd, paiCmd)
162162
}

go/codegen/pai/tensorflow.go

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -207,27 +207,11 @@ func getTFPAICmd(cc *ClusterConfig, tarball, paramsFile, modelName, ossModelPath
207207
cmd := fmt.Sprintf("pai -name tensorflow1150 -project algo_public_dev -DmaxHungTimeBeforeGCInSeconds=0 -DjobName=%s -Dtags=dnn -Dscript=%s -DentryFile=entry.py -Dtables=%s %s -DhyperParameters=\"%s\"",
208208
jobName, tarball, submitTables, outputTables, paramsFile)
209209

210-
// format the oss checkpoint path with ARN authorization.
211-
ossCheckpointConfigs := os.Getenv("SQLFLOW_OSS_CHECKPOINT_DIR")
212-
if ossCheckpointConfigs == "" {
213-
return "", fmt.Errorf("need to configure SQLFLOW_OSS_CHECKPOINT_DIR when submitting to PAI")
214-
}
215-
ossJSONConfigs := make(map[string]string)
216-
if err := json.Unmarshal([]byte(ossCheckpointConfigs), &ossJSONConfigs); err != nil {
210+
chkpoint, err := getCheckpointDir(ossModelPath, project)
211+
if err != nil {
217212
return "", err
218213
}
219-
currProjectOSS, ok := ossJSONConfigs[project]
220-
if !ok {
221-
return "", fmt.Errorf("project %s not configured in SQLFLOW_OSS_CHECKPOINT_DIR", project)
222-
}
223-
arnSplited := strings.Split(currProjectOSS, "?")
224-
if len(arnSplited) != 2 {
225-
return "", fmt.Errorf("need to configure SQLFLOW_OSS_CHECKPOINT_DIR when submitting to PAI")
226-
}
227-
arn := arnSplited[1]
228-
ossURI := OSSModelURL(ossModelPath)
229-
ossCheckpointPath := fmt.Sprintf("%s/?%s", ossURI, arn)
230-
cmd = fmt.Sprintf("%s -DcheckpointDir='%s'", cmd, ossCheckpointPath)
214+
cmd = fmt.Sprintf("%s -DcheckpointDir='%s'", cmd, chkpoint)
231215

232216
if cc.Worker.Count > 1 {
233217
cmd = fmt.Sprintf("%s -Dcluster=%s", cmd, cfQuote)
@@ -236,3 +220,23 @@ func getTFPAICmd(cc *ClusterConfig, tarball, paramsFile, modelName, ossModelPath
236220
}
237221
return cmd, nil
238222
}
223+
224+
type roleArn struct {
225+
Host string `json:"host"`
226+
Arn string `json:"arn"`
227+
}
228+
229+
func getCheckpointDir(ossModelPath, project string) (string, error) {
230+
ckpJSONStr := os.Getenv("SQLFLOW_OSS_CHECKPOINT_CONFIG")
231+
if ckpJSONStr == "" {
232+
return "", fmt.Errorf("need to configure SQLFLOW_OSS_CHECKPOINT_CONFIG when submitting to PAI")
233+
}
234+
ra := roleArn{}
235+
if err := json.Unmarshal([]byte(ckpJSONStr), &ra); err != nil {
236+
return "", err
237+
}
238+
ossURL := OSSModelURL(ossModelPath)
239+
roleName := fmt.Sprintf("pai2oss_%s", project)
240+
// format the oss checkpoint path with ARN authorization.
241+
return fmt.Sprintf("%s/?role_arn=%s/%s&host=%s", ossURL, ra.Arn, roleName, ra.Host), nil
242+
}

go/codegen/pai/tensorflow_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Copyright 2020 The SQLFlow Authors. All rights reserved.
2+
// Licensed under the Apache License, Version 2.0 (the "License");
3+
// you may not use this file except in compliance with the License.
4+
// You may obtain a copy of the License at
5+
//
6+
// http://www.apache.org/licenses/LICENSE-2.0
7+
//
8+
// Unless required by applicable law or agreed to in writing, software
9+
// distributed under the License is distributed on an "AS IS" BASIS,
10+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
// See the License for the specific language governing permissions and
12+
// limitations under the License.
13+
14+
package pai
15+
16+
import (
17+
"fmt"
18+
"os"
19+
"testing"
20+
21+
"github.com/stretchr/testify/assert"
22+
)
23+
24+
func TestGetCheckpointDir(t *testing.T) {
25+
a := assert.New(t)
26+
os.Setenv("SQLFLOW_OSS_CHECKPOINT_CONFIG", "{\"host\": \"h.com\", \"arn\": \"acs:ram::9527:role\"}")
27+
defer os.Unsetenv("SQLFLOW_OSS_CHECKPOINT_CONFIG")
28+
ossModelPath, project := "p/t/m", "pr0j"
29+
30+
ckpoint, err := getCheckpointDir(ossModelPath, project)
31+
a.NoError(err)
32+
expectedCkp := fmt.Sprintf("oss://%s/p/t/m/?role_arn=acs:ram::9527:role/pai2oss_pr0j&host=h.com", BucketName)
33+
a.Equal(expectedCkp, ckpoint)
34+
}

0 commit comments

Comments
 (0)