Skip to content

Commit b50616b

Browse files
committed
Added test and clarified errors
Signed-off-by: nithinkdb <[email protected]>
1 parent 35ae3e3 commit b50616b

File tree

3 files changed

+47
-9
lines changed

3 files changed

+47
-9
lines changed

connection.go

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,18 @@ func (c *conn) HandleStagingDelete(ctx context.Context, presignedUrl string, hea
253253
return driver.ResultNoRows, nil
254254
}
255255

256-
func localPathIsAllowed(ctx context.Context, localFile string) bool {
257-
stagingAllowedLocalPaths := driverctx.StagingPathsFromContext(ctx)
256+
func localPathIsAllowed(stagingAllowedLocalPaths []string, localFile string) bool {
258257
for i := range stagingAllowedLocalPaths {
259-
path := stagingAllowedLocalPaths[i]
258+
// Convert both filepaths to absolute paths to avoid potential issues.
259+
//
260+
path, err := filepath.Abs(stagingAllowedLocalPaths[i])
261+
if err != nil {
262+
return false
263+
}
264+
localFile, err := filepath.Abs(localFile)
265+
if err != nil {
266+
return false
267+
}
260268
relativePath, err := filepath.Rel(path, localFile)
261269
if err != nil {
262270
return false
@@ -275,13 +283,16 @@ func (c *conn) ExecStagingOperation(
275283
var sqlRow []driver.Value
276284
colNames := row.Columns()
277285
sqlRow = make([]driver.Value, len(colNames))
278-
row.Next(sqlRow)
286+
err := row.Next(sqlRow)
287+
if err != nil {
288+
return nil, dbsqlerrint.NewDriverError(ctx, "Error fetching staging operation results", err)
289+
}
279290
var stringValues []string = make([]string, 4)
280291
for i := range stringValues {
281292
if s, ok := sqlRow[i].(string); ok {
282293
stringValues[i] = s
283294
} else {
284-
return nil, fmt.Errorf("local file operations are restricted to paths within the configured staging_allowed_local_path")
295+
return nil, dbsqlerrint.NewDriverError(ctx, "Received unexpected response from the server.", nil)
285296
}
286297
}
287298
operation := stringValues[0]
@@ -292,18 +303,19 @@ func (c *conn) ExecStagingOperation(
292303
return nil, err
293304
}
294305
localFile := stringValues[3]
306+
stagingAllowedLocalPaths := driverctx.StagingPathsFromContext(ctx)
295307
switch operation {
296308
case "PUT":
297-
if localPathIsAllowed(ctx, localFile) {
309+
if localPathIsAllowed(stagingAllowedLocalPaths, localFile) {
298310
return c.HandleStagingPut(ctx, presignedUrl, headers, localFile)
299311
} else {
300-
return nil, fmt.Errorf("local file operations are restricted to paths within the configured staging_allowed_local_path")
312+
return nil, dbsqlerrint.NewDriverError(ctx, "local file operations are restricted to paths within the configured stagingAllowedLocalPath", nil)
301313
}
302314
case "GET":
303-
if localPathIsAllowed(ctx, localFile) {
315+
if localPathIsAllowed(stagingAllowedLocalPaths, localFile) {
304316
return c.HandleStagingGet(ctx, presignedUrl, headers, localFile)
305317
} else {
306-
return nil, dbsqlerrint.NewDriverError(ctx, "local file operations are restricted to paths within the configured staging_allowed_local_path", nil)
318+
return nil, dbsqlerrint.NewDriverError(ctx, "local file operations are restricted to paths within the configured stagingAllowedLocalPath", nil)
307319
}
308320
case "DELETE":
309321
return c.HandleStagingDelete(ctx, presignedUrl, headers)

examples/staging/main.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,10 @@ func main() {
5454
return
5555
}
5656

57+
_, err1 = db.ExecContext(ctx, `GET '/Volumes/main/staging_test/e2etests/file1.csv' TO 'staging/newfile.csv'`)
58+
if err1 != nil {
59+
fmt.Println(err1.Error())
60+
return
61+
}
62+
5763
}

staging_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package dbsql
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestPathAllowed(t *testing.T) {
10+
t.Run("Should not allow paths that don't share directory", func(t *testing.T) {
11+
stagingAllowedLocalPath := []string{"/var/www/html"}
12+
localFile := "/var/www/html/../html1/not_allowed.html"
13+
assert.False(t, localPathIsAllowed(stagingAllowedLocalPath, localFile))
14+
})
15+
t.Run("Should allow multiple specified allowed local staging paths", func(t *testing.T) {
16+
stagingAllowedLocalPath := []string{"/foo", "/var/www/html"}
17+
localFile := "/var/www/html/allowed.html"
18+
assert.True(t, localPathIsAllowed(stagingAllowedLocalPath, localFile))
19+
})
20+
}

0 commit comments

Comments
 (0)