diff --git a/pkg/lifecycle/binary.go b/pkg/lifecycle/binary.go index 1087f5d449..4c83fed420 100644 --- a/pkg/lifecycle/binary.go +++ b/pkg/lifecycle/binary.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "net/http" + "net/url" "os" "os/exec" "path/filepath" @@ -104,7 +105,7 @@ func RunBinary(ctx context.Context, execPath string, args []string) (*Command, e // DownloadBinary downloads a file from the given URL into the specified path // this also marks it executable and returns its full path. -func DownloadBinary(url, destDir, destFile string, logger *zap.Logger) (string, error) { +func DownloadBinary(sourceURL, destDir, destFile string, logger *zap.Logger) (string, error) { if err := os.MkdirAll(destDir, 0755); err != nil { return "", fmt.Errorf("could not create directory %s (%w)", destDir, err) } @@ -132,32 +133,52 @@ func DownloadBinary(url, destDir, destFile string, logger *zap.Logger) (string, } }() - logger.Info("downloading binary", zap.String("url", url)) + logger.Info("downloading binary", zap.String("url", sourceURL)) - req, err := http.NewRequest("GET", url, nil) + u, err := url.Parse(sourceURL) if err != nil { - return "", fmt.Errorf("could not create request (%w)", err) + return "", fmt.Errorf("could not parse URL %s (%w)", sourceURL, err) } - client := &http.Client{} - resp, err := client.Do(req) - if err != nil { - return "", fmt.Errorf("HTTP GET %s failed (%w)", url, err) - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("HTTP GET %s failed with error %d", url, resp.StatusCode) - } + switch u.Scheme { + case "http", "https": + req, err := http.NewRequest("GET", sourceURL, nil) + if err != nil { + return "", fmt.Errorf("could not create request (%w)", err) + } + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("HTTP GET %s failed (%w)", sourceURL, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("HTTP GET %s failed with error %d", sourceURL, resp.StatusCode) + } - if _, err = io.Copy(tmpFile, resp.Body); err != nil { - return "", fmt.Errorf("could not copy from %s to %s (%w)", url, tmpName, err) + if _, err = io.Copy(tmpFile, resp.Body); err != nil { + return "", fmt.Errorf("could not copy from %s to %s (%w)", sourceURL, tmpName, err) + } + + case "file": + data, err := os.ReadFile(u.Path) + if err != nil { + return "", fmt.Errorf("could not read file %s (%w)", u.Path, err) + } + + if _, err = tmpFile.Write(data); err != nil { + return "", fmt.Errorf("could not write to %s (%w)", tmpName, err) + } + + default: + return "", fmt.Errorf("unsupported file scheme %s", u.Scheme) } if err := os.Chmod(tmpName, 0755); err != nil { return "", fmt.Errorf("could not chmod file %s (%w)", tmpName, err) } - tmpFile.Close() if err := os.Rename(tmpName, destPath); err != nil { return "", fmt.Errorf("could not move %s to %s (%w)", tmpName, destPath, err) } diff --git a/pkg/lifecycle/binary_test.go b/pkg/lifecycle/binary_test.go index 28d8bdc958..5132e01b78 100644 --- a/pkg/lifecycle/binary_test.go +++ b/pkg/lifecycle/binary_test.go @@ -18,6 +18,8 @@ import ( "context" "net/http" "net/http/httptest" + "os" + "path" "strconv" "testing" "time" @@ -91,10 +93,10 @@ func TestDownloadBinary(t *testing.T) { defer server.Close() logger := zaptest.NewLogger(t) - destDir := t.TempDir() - destFile := "test-binary" t.Run("successful download", func(t *testing.T) { + destDir := t.TempDir() + destFile := "test-binary" url := server.URL + "/binary" path, err := DownloadBinary(url, destDir, destFile, logger) require.NoError(t, err) @@ -102,6 +104,8 @@ func TestDownloadBinary(t *testing.T) { }) t.Run("file already exists", func(t *testing.T) { + destDir := t.TempDir() + destFile := "test-binary" url := server.URL + "/binary" path, err := DownloadBinary(url, destDir, destFile, logger) require.NoError(t, err) @@ -112,6 +116,35 @@ func TestDownloadBinary(t *testing.T) { require.NoError(t, err) assert.FileExists(t, path) }) + + t.Run("file on local", func(t *testing.T) { + sourceDir := t.TempDir() + sourceFile := "test-binary" + sourcePath := path.Join(sourceDir, sourceFile) + err := os.WriteFile(sourcePath, []byte("test binary content"), 0755) + require.NoError(t, err) + + destDir := t.TempDir() + destFile := "test-binary" + url := "file://" + sourcePath + + path, err := DownloadBinary(url, destDir, destFile, logger) + require.NoError(t, err) + assert.FileExists(t, path) + content, err := os.ReadFile(path) + require.NoError(t, err) + assert.Equal(t, "test binary content", string(content)) + }) + + t.Run("not valid source url given", func(t *testing.T) { + destDir := t.TempDir() + destFile := "test-binary" + url := "ftp://invalid-url" + + path, err := DownloadBinary(url, destDir, destFile, logger) + require.Error(t, err) + assert.Empty(t, path) + }) } func httpTestServer() *httptest.Server {