Skip to content

Commit afe2ef3

Browse files
committed
make sure to take warehouse ID from environment
1 parent d0b86de commit afe2ef3

File tree

4 files changed

+36
-30
lines changed

4 files changed

+36
-30
lines changed

experimental/apps-mcp/cmd/apps_mcp.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package mcp
22

33
import (
4+
"errors"
5+
"os"
6+
47
"github.com/databricks/cli/cmd/root"
58
mcplib "github.com/databricks/cli/experimental/apps-mcp/lib"
69
"github.com/databricks/cli/experimental/apps-mcp/lib/server"
@@ -38,6 +41,13 @@ The server communicates via stdio using the Model Context Protocol.`,
3841
RunE: func(cmd *cobra.Command, args []string) error {
3942
ctx := cmd.Context()
4043

44+
if warehouseID == "" {
45+
warehouseID = os.Getenv("DATABRICKS_WAREHOUSE_ID")
46+
if warehouseID == "" {
47+
return errors.New("DATABRICKS_WAREHOUSE_ID environment variable is required")
48+
}
49+
}
50+
4151
w := cmdctx.WorkspaceClient(ctx)
4252

4353
// Build MCP config from flags

experimental/apps-mcp/lib/providers/databricks/databricks.go

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package databricks
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"net/url"
78
"strconv"
@@ -47,7 +48,7 @@ func applyPagination[T any](items []T, limit, offset int) ([]T, int, int) {
4748
// Allows alphanumeric, underscore, hyphen, and dot (for qualified names).
4849
func validateIdentifier(id string) error {
4950
if id == "" {
50-
return fmt.Errorf("identifier cannot be empty")
51+
return errors.New("identifier cannot be empty")
5152
}
5253

5354
// Allow alphanumeric, underscore, hyphen, and dot for qualified names
@@ -209,7 +210,7 @@ func (r *ListCatalogsResultResponse) Display() string {
209210
lines = append(lines, "")
210211

211212
for _, catalog := range r.Catalogs {
212-
lines = append(lines, fmt.Sprintf("• %s", catalog))
213+
lines = append(lines, "• "+catalog)
213214
}
214215

215216
return strings.Join(lines, "\n")
@@ -227,7 +228,7 @@ func (r *ListSchemasResultResponse) Display() string {
227228
lines = append(lines, "")
228229

229230
for _, schema := range r.Schemas {
230-
lines = append(lines, fmt.Sprintf("• %s", schema))
231+
lines = append(lines, "• "+schema)
231232
}
232233

233234
return strings.Join(lines, "\n")
@@ -247,10 +248,10 @@ func (r *ListTablesResultResponse) Display() string {
247248
for _, table := range r.Tables {
248249
info := fmt.Sprintf("• %s (%s)", table.FullName, table.TableType)
249250
if table.Owner != nil {
250-
info += fmt.Sprintf(" - Owner: %s", *table.Owner)
251+
info += " - Owner: " + *table.Owner
251252
}
252253
if table.Comment != nil {
253-
info += fmt.Sprintf(" - %s", *table.Comment)
254+
info += " - " + *table.Comment
254255
}
255256
lines = append(lines, info)
256257
}
@@ -261,23 +262,23 @@ func (r *ListTablesResultResponse) Display() string {
261262
func (r *TableDetailsResponse) Display() string {
262263
var lines []string
263264

264-
lines = append(lines, fmt.Sprintf("Table: %s", r.FullName))
265-
lines = append(lines, fmt.Sprintf("Table Type: %s", r.TableType))
265+
lines = append(lines, "Table: "+r.FullName)
266+
lines = append(lines, "Table Type: "+r.TableType)
266267

267268
if r.Owner != nil {
268-
lines = append(lines, fmt.Sprintf("Owner: %s", *r.Owner))
269+
lines = append(lines, "Owner: "+*r.Owner)
269270
}
270271
if r.Comment != nil {
271-
lines = append(lines, fmt.Sprintf("Comment: %s", *r.Comment))
272+
lines = append(lines, "Comment: "+*r.Comment)
272273
}
273274
if r.RowCount != nil {
274275
lines = append(lines, fmt.Sprintf("Row Count: %d", *r.RowCount))
275276
}
276277
if r.StorageLocation != nil {
277-
lines = append(lines, fmt.Sprintf("Storage Location: %s", *r.StorageLocation))
278+
lines = append(lines, "Storage Location: "+*r.StorageLocation)
278279
}
279280
if r.DataSourceFormat != nil {
280-
lines = append(lines, fmt.Sprintf("Data Source Format: %s", *r.DataSourceFormat))
281+
lines = append(lines, "Data Source Format: "+*r.DataSourceFormat)
281282
}
282283

283284
if len(r.Columns) > 0 {
@@ -289,7 +290,7 @@ func (r *TableDetailsResponse) Display() string {
289290
}
290291
colInfo := fmt.Sprintf(" - %s: %s (%s)", col.Name, col.DataType, nullableStr)
291292
if col.Comment != nil {
292-
colInfo += fmt.Sprintf(" - %s", *col.Comment)
293+
colInfo += " - " + *col.Comment
293294
}
294295
lines = append(lines, colInfo)
295296
}
@@ -328,7 +329,7 @@ func (r *ExecuteSqlResultResponse) Display() string {
328329
for k := range r.Rows[0] {
329330
columns = append(columns, k)
330331
}
331-
lines = append(lines, fmt.Sprintf("Columns: %s", strings.Join(columns, ", ")))
332+
lines = append(lines, "Columns: "+strings.Join(columns, ", "))
332333
lines = append(lines, "")
333334
lines = append(lines, "Results:")
334335
}
@@ -376,7 +377,7 @@ func NewDatabricksRestClient(ctx context.Context, cfg *mcp.Config) (*DatabricksR
376377

377378
warehouseID := cfg.WarehouseID
378379
if warehouseID == "" {
379-
return nil, fmt.Errorf("DATABRICKS_WAREHOUSE_ID not configured")
380+
return nil, errors.New("DATABRICKS_WAREHOUSE_ID not configured")
380381
}
381382

382383
return &DatabricksRestClient{
@@ -494,7 +495,7 @@ func (c *DatabricksRestClient) processStatementResult(ctx context.Context, resul
494495

495496
if result.Manifest == nil || result.Manifest.Schema == nil {
496497
log.Debugf(ctx, "No schema in response")
497-
return nil, fmt.Errorf("no schema in response")
498+
return nil, errors.New("no schema in response")
498499
}
499500

500501
schema := result.Manifest.Schema
@@ -657,7 +658,7 @@ func (c *DatabricksRestClient) ListTables(ctx context.Context, request *ListTabl
657658
func (c *DatabricksRestClient) listTablesViaInformationSchema(ctx context.Context, request *ListTablesRequest) (*ListTablesResultResponse, error) {
658659
// Validate invalid combination
659660
if request.CatalogName == nil && request.SchemaName != nil {
660-
return nil, fmt.Errorf("schema_name requires catalog_name to be specified")
661+
return nil, errors.New("schema_name requires catalog_name to be specified")
661662
}
662663

663664
// Validate identifiers for SQL safety
@@ -741,25 +742,25 @@ func (c *DatabricksRestClient) listTablesViaInformationSchema(ctx context.Contex
741742
for _, row := range rows {
742743
catalogVal, ok := row["table_catalog"]
743744
if !ok {
744-
return nil, fmt.Errorf("missing table_catalog in row")
745+
return nil, errors.New("missing table_catalog in row")
745746
}
746747
catalog := fmt.Sprintf("%v", catalogVal)
747748

748749
schemaVal, ok := row["table_schema"]
749750
if !ok {
750-
return nil, fmt.Errorf("missing table_schema in row")
751+
return nil, errors.New("missing table_schema in row")
751752
}
752753
schema := fmt.Sprintf("%v", schemaVal)
753754

754755
nameVal, ok := row["table_name"]
755756
if !ok {
756-
return nil, fmt.Errorf("missing table_name in row")
757+
return nil, errors.New("missing table_name in row")
757758
}
758759
name := fmt.Sprintf("%v", nameVal)
759760

760761
tableTypeVal, ok := row["table_type"]
761762
if !ok {
762-
return nil, fmt.Errorf("missing table_type in row")
763+
return nil, errors.New("missing table_type in row")
763764
}
764765
tableType := fmt.Sprintf("%v", tableTypeVal)
765766

@@ -786,7 +787,7 @@ func (c *DatabricksRestClient) listTablesViaInformationSchema(ctx context.Contex
786787
}
787788

788789
func (c *DatabricksRestClient) listTablesImpl(ctx context.Context, catalogName, schemaName string, excludeInaccessible bool) ([]TableInfoResponse, error) {
789-
tables := []TableInfoResponse{}
790+
var tables []TableInfoResponse
790791

791792
w := cmdctx.WorkspaceClient(ctx)
792793
clientCfg, err := config.HTTPClientConfigFromConfig(w.Config)
@@ -880,7 +881,7 @@ func (c *DatabricksRestClient) DescribeTable(ctx context.Context, request *Descr
880881
}
881882

882883
var rowCount *int64
883-
countQuery := fmt.Sprintf("SELECT COUNT(*) as count FROM %s", tableName)
884+
countQuery := "SELECT COUNT(*) as count FROM " + tableName
884885
if rows, err := c.executeSqlImpl(ctx, countQuery); err == nil && len(rows) > 0 {
885886
if countVal, ok := rows[0]["count"]; ok {
886887
switch v := countVal.(type) {

experimental/apps-mcp/lib/providers/databricks/deployment.go

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ package databricks
22

33
import (
44
"context"
5-
"errors"
65
"fmt"
7-
"os"
86
"os/exec"
97
"time"
108

@@ -103,11 +101,8 @@ func DeployApp(ctx context.Context, cfg *mcp.Config, appInfo *apps.App) error {
103101
return nil
104102
}
105103

106-
func ResourcesFromEnv() (*apps.AppResource, error) {
107-
warehouseID := os.Getenv("DATABRICKS_WAREHOUSE_ID")
108-
if warehouseID == "" {
109-
return nil, errors.New("DATABRICKS_WAREHOUSE_ID environment variable is required for app deployment. Set this to your Databricks SQL warehouse ID")
110-
}
104+
func ResourcesFromEnv(cfg *mcp.Config) (*apps.AppResource, error) {
105+
warehouseID := cfg.WarehouseID
111106

112107
return &apps.AppResource{
113108
Name: "base",

experimental/apps-mcp/lib/providers/deployment/provider.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ func (p *Provider) getOrCreateApp(ctx context.Context, name, description string,
289289

290290
log.Infof(ctx, "App not found, creating new app: name=%s", name)
291291

292-
resources, err := databricks.ResourcesFromEnv()
292+
resources, err := databricks.ResourcesFromEnv(p.config)
293293
if err != nil {
294294
return nil, err
295295
}

0 commit comments

Comments
 (0)