From f3c85c0ebe3e54dc91836fe998c2d7c8e3373a52 Mon Sep 17 00:00:00 2001 From: Scott Winkler Date: Mon, 13 Nov 2023 01:16:50 -0800 Subject: [PATCH] fix: provider muxing for plugin-framework (#2130) * provider config refactor * gen docs * update provider * update docs * update provider * framework initial work * update rm * update resource monitor * update docs * go fmt * update provider config * add to framework repo --- docs/index.md | 2 +- framework/planmodifiers/stringplanmodifier.go | 35 + framework/provider/helpers.go | 84 ++ framework/provider/provider.go | 815 ++++++++++++++++++ framework/provider/provider_helpers.go | 186 ++++ .../provider/resource_monitor_resource.go | 807 +++++++++++++++++ go.mod | 2 + go.sum | 4 + main.go | 60 +- pkg/provider/provider.go | 39 +- pkg/provider/provider_helpers.go | 49 ++ pkg/sdk/client.go | 30 + 12 files changed, 2074 insertions(+), 39 deletions(-) create mode 100644 framework/planmodifiers/stringplanmodifier.go create mode 100644 framework/provider/helpers.go create mode 100644 framework/provider/provider.go create mode 100644 framework/provider/provider_helpers.go create mode 100644 framework/provider/resource_monitor_resource.go diff --git a/docs/index.md b/docs/index.md index 2063a87819..f1ceeb044b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -73,8 +73,8 @@ provider "snowflake" { - `oauth_endpoint` (String, Sensitive, Deprecated) Required when `oauth_refresh_token` is used. Can also be sourced from `SNOWFLAKE_OAUTH_ENDPOINT` environment variable. - `oauth_redirect_url` (String, Sensitive, Deprecated) Required when `oauth_refresh_token` is used. Can also be sourced from `SNOWFLAKE_OAUTH_REDIRECT_URL` environment variable. - `oauth_refresh_token` (String, Sensitive, Deprecated) Token for use with OAuth. Setup and generation of the token is left to other tools. Should be used in conjunction with `oauth_client_id`, `oauth_client_secret`, `oauth_endpoint`, `oauth_redirect_url`. 
Cannot be used with `browser_auth`, `private_key_path`, `oauth_access_token` or `password`. Can also be sourced from `SNOWFLAKE_OAUTH_REFRESH_TOKEN` environment variable. +- `ocsp_fail_open` (Boolean) True represents OCSP fail open mode. False represents OCSP fail closed mode. Fail open true by default. Can also be sourced from the `SNOWFLAKE_OCSP_FAIL_OPEN` environment variable. - `okta_url` (String) The URL of the Okta server. e.g. https://example.okta.com. Can also be sourced from the `SNOWFLAKE_OKTA_URL` environment variable. -- `oscp_fail_open` (Boolean) True represents OCSP fail open mode. False represents OCSP fail closed mode. Fail open true by default. Can also be sourced from the `SNOWFLAKE_OCSP_FAIL_OPEN` environment variable. - `params` (Map of String) Sets other connection (i.e. session) parameters. [Parameters](https://docs.snowflake.com/en/sql-reference/parameters) - `passcode` (String) Specifies the passcode provided by Duo when using multi-factor authentication (MFA) for login. Can also be sourced from the `SNOWFLAKE_PASSCODE` environment variable. - `passcode_in_password` (Boolean) False by default. Set to true if the MFA passcode is embedded in the login password. Appends the MFA passcode to the end of the password. Can also be sourced from the `SNOWFLAKE_PASSCODE_IN_PASSWORD` environment variable. diff --git a/framework/planmodifiers/stringplanmodifier.go b/framework/planmodifiers/stringplanmodifier.go new file mode 100644 index 0000000000..6af3cc23a1 --- /dev/null +++ b/framework/planmodifiers/stringplanmodifier.go @@ -0,0 +1,35 @@ +package stringplanmodifiers + +import ( + "context" + + "github.com/hashicorp/terraform-plugin-framework/resource/schema/planmodifier" +) + +// useStateForUnknownModifier implements the plan modifier. +type suppressDiffIfModifier struct { + f func(old, new string) bool +} + +// Description returns a human-readable description of the plan modifier. 
+func (m suppressDiffIfModifier) Description(_ context.Context) string { + return "Suppresses diff if values based on function." +} + +// MarkdownDescription returns a markdown description of the plan modifier. +func (m suppressDiffIfModifier) MarkdownDescription(_ context.Context) string { + return "Suppresses diff if values based on function." +} + +// PlanModifyBool implements the plan modification logic. +func (m suppressDiffIfModifier) PlanModifyString(_ context.Context, req planmodifier.StringRequest, resp *planmodifier.StringResponse) { + if m.f(req.StateValue.ValueString(), req.PlanValue.ValueString()) { + resp.PlanValue = req.StateValue + } +} + +func SuppressDiffIf(f func(old, new string) bool) planmodifier.String { + return suppressDiffIfModifier{ + f: f, + } +} diff --git a/framework/provider/helpers.go b/framework/provider/helpers.go new file mode 100644 index 0000000000..b2b8f154f9 --- /dev/null +++ b/framework/provider/helpers.go @@ -0,0 +1,84 @@ +package provider + +import ( + "os" + "path/filepath" + "regexp" + "strings" + "sync" + + "github.com/gookit/color" +) + +type tfOperation string + +const ( + CreateOperation tfOperation = "CREATE" + ReadOperation tfOperation = "READ" + UpdateOperation tfOperation = "UPDATE" + DeleteOperation tfOperation = "DELETE" +) + +func formatSQLPreview(operation tfOperation, resourceName string, id string, commands []string) string { + var c color.Color + switch operation { + case CreateOperation: + c = color.HiGreen + case ReadOperation: + c = color.HiBlue + case UpdateOperation: + c = color.HiYellow + case DeleteOperation: + c = color.HiRed + } + var sb strings.Builder + sb.WriteString(c.Sprintf("\n[ %s %s %s ]", operation, resourceName, id)) + for _, command := range commands { + sb.WriteString(c.Sprintf("\n - %s", command)) + } + sb.WriteString("\n") + return sb.String() +} + +type sensitiveAttributes struct { + m map[string]bool +} + +var ( + sa *sensitiveAttributes + lock = sync.Mutex{} +) + +func isSensitive(s 
string) bool { + if sa == nil { + lock.Lock() + defer lock.Unlock() + if sa == nil { + sa = &sensitiveAttributes{ + m: make(map[string]bool), + } + dir, err := os.UserHomeDir() + if err != nil { + return false + } + // sensitive path is ~/.snowflake/sensitive. + f := filepath.Join(dir, ".snowflake", "sensitive") + dat, err := os.ReadFile(f) + if err != nil { + return false + } + lines := strings.Split(string(dat), "\n") + r := regexp.MustCompile("(data[.])?snowflake_(.*)[.](.+)[.](.+)") + for _, line := range lines { + strippedLine := strings.TrimSpace(line) + if r.MatchString(strippedLine) { + sa.m[strippedLine] = true + } + } + } + } + if _, ok := sa.m[s]; ok { + return true + } + return false +} diff --git a/framework/provider/provider.go b/framework/provider/provider.go new file mode 100644 index 0000000000..235a6f3dd1 --- /dev/null +++ b/framework/provider/provider.go @@ -0,0 +1,815 @@ +package provider + +import ( + "context" + "net" + "net/url" + "os" + "time" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/hashicorp/terraform-plugin-framework-validators/boolvalidator" + "github.com/hashicorp/terraform-plugin-framework-validators/listvalidator" + "github.com/hashicorp/terraform-plugin-framework-validators/stringvalidator" + "github.com/hashicorp/terraform-plugin-framework/datasource" + "github.com/hashicorp/terraform-plugin-framework/path" + "github.com/hashicorp/terraform-plugin-framework/provider" + "github.com/hashicorp/terraform-plugin-framework/provider/schema" + "github.com/hashicorp/terraform-plugin-framework/resource" + "github.com/hashicorp/terraform-plugin-framework/schema/validator" + "github.com/hashicorp/terraform-plugin-framework/types" + "github.com/snowflakedb/gosnowflake" +) + +// Ensure SnowflakeProvider satisfies various provider interfaces. +var _ provider.Provider = new(SnowflakeProvider) + +// SnowflakeProvider defines the provider implementation. 
+type SnowflakeProvider struct { + // version is set to the provider version on release, "dev" when the + // provider is built and ran locally, and "test" when running acceptance + // testing. + version string +} + +// SnowflakeProviderModel describes the provider data model. +type snowflakeProviderModelV0 struct { + Account types.String `tfsdk:"account"` + User types.String `tfsdk:"user"` + Password types.String `tfsdk:"password"` + Warehouse types.String `tfsdk:"warehouse"` + Role types.String `tfsdk:"role"` + Region types.String `tfsdk:"region"` + ValidateDefaultParameters types.Bool `tfsdk:"validate_default_parameters"` + Params types.Map `tfsdk:"params"` + ClientIP types.String `tfsdk:"client_ip"` + Protocol types.String `tfsdk:"protocol"` + Host types.String `tfsdk:"host"` + Port types.Int64 `tfsdk:"port"` + Authenticator types.String `tfsdk:"authenticator"` + Passcode types.String `tfsdk:"passcode"` + PasscodeInPassword types.Bool `tfsdk:"passcode_in_password"` + OktaURL types.String `tfsdk:"okta_url"` + LoginTimeout types.Int64 `tfsdk:"login_timeout"` + RequestTimeout types.Int64 `tfsdk:"request_timeout"` + JWTExpireTimeout types.Int64 `tfsdk:"jwt_expire_timeout"` + ClientTimeout types.Int64 `tfsdk:"client_timeout"` + JWTClientTimeout types.Int64 `tfsdk:"jwt_client_timeout"` + ExternalBrowserTimeout types.Int64 `tfsdk:"external_browser_timeout"` + InsecureMode types.Bool `tfsdk:"insecure_mode"` + OCSPFailOpen types.Bool `tfsdk:"ocsp_fail_open"` + Token types.String `tfsdk:"token"` + TokenAccessor types.List `tfsdk:"token_accessor"` + KeepSessionAlive types.Bool `tfsdk:"keep_session_alive"` + PrivateKey types.String `tfsdk:"private_key"` + PrivateKeyPassphrase types.String `tfsdk:"private_key_passphrase"` + DisableTelemetry types.Bool `tfsdk:"disable_telemetry"` + ClientRequestMFAToken types.Bool `tfsdk:"client_request_mfa_token"` + ClientStoreTemporaryCredential types.Bool `tfsdk:"client_store_temporary_credential"` + DisableQueryContextCache types.Bool 
`tfsdk:"disable_query_context_cache"` + Profile types.String `tfsdk:"profile"` + // Deprecated Attributes + Username types.String `tfsdk:"username"` + OauthAccessToken types.String `tfsdk:"oauth_access_token"` + OauthRefreshToken types.String `tfsdk:"oauth_refresh_token"` + OauthClientID types.String `tfsdk:"oauth_client_id"` + OauthClientSecret types.String `tfsdk:"oauth_client_secret"` + OauthEndpoint types.String `tfsdk:"oauth_endpoint"` + OauthRedirectURL types.String `tfsdk:"oauth_redirect_url"` + BrowserAuth types.Bool `tfsdk:"browser_auth"` + PrivateKeyPath types.String `tfsdk:"private_key_path"` + SessionParams types.Map `tfsdk:"session_params"` +} + +type RefreshTokenAccesor struct { + TokenEndpoint types.String `tfsdk:"token_endpoint"` + RefreshToken types.String `tfsdk:"refresh_token"` + ClientID types.String `tfsdk:"client_id"` + ClientSecret types.String `tfsdk:"client_secret"` + RedirectURI types.String `tfsdk:"redirect_uri"` +} + +func (p *SnowflakeProvider) Metadata(ctx context.Context, req provider.MetadataRequest, resp *provider.MetadataResponse) { + resp.TypeName = "snowflake" + resp.Version = p.version +} + +func (p *SnowflakeProvider) Schema(ctx context.Context, req provider.SchemaRequest, resp *provider.SchemaResponse) { + resp.Schema = schema.Schema{ + Attributes: map[string]schema.Attribute{ + "account": schema.StringAttribute{ + Description: "Specifies your Snowflake account identifier assigned, by Snowflake. For information about account identifiers, see the [Snowflake documentation](https://docs.snowflake.com/en/user-guide/admin-account-identifier.html). Can also be sourced from the `SNOWFLAKE_ACCOUNT` environment variable. Required unless using `profile`.", + Optional: true, + }, + "user": schema.StringAttribute{ + Description: "Username. Can also be sourced from the `SNOWFLAKE_USER` environment variable. 
Required unless using `profile`.", + Optional: true, + }, + "username": schema.StringAttribute{ + Description: "Username for username+password authentication. Can also be sourced from the `SNOWFLAKE_USER` environment variable. Required unless using `profile`.", + Optional: true, + DeprecationMessage: "Use `user` instead", + }, + "password": schema.StringAttribute{ + Description: "Password for username+password auth. Cannot be used with `browser_auth` or `private_key_path`. Can also be sourced from the `SNOWFLAKE_PASSWORD` environment variable.", + Optional: true, + Sensitive: true, + Validators: []validator.String{ + stringvalidator.ConflictsWith(path.MatchRoot("browser_auth"), path.MatchRoot("private_key_path"), path.MatchRoot("private_key"), path.MatchRoot("private_key_passphrase"), path.MatchRoot("oauth_access_token"), path.MatchRoot("oauth_refresh_token")), + }, + }, + "warehouse": schema.StringAttribute{ + Description: "Specifies the virtual warehouse to use by default for queries, loading, etc. in the client session. Can also be sourced from the `SNOWFLAKE_WAREHOUSE` environment variable.", + Optional: true, + }, + "role": schema.StringAttribute{ + Description: "Specifies the role to use by default for accessing Snowflake objects in the client session. Can also be sourced from the `SNOWFLAKE_ROLE` environment variable. .", + Optional: true, + }, + "validate_default_parameters": schema.BoolAttribute{ + Description: "True by default. If false, disables the validation checks for Database, Schema, Warehouse and Role at the time a connection is established. Can also be sourced from the `SNOWFLAKE_VALIDATE_DEFAULT_PARAMETERS` environment variable.", + Optional: true, + }, + "params": schema.MapAttribute{ + Description: "Sets other connection (i.e. session) parameters. 
[Parameters](https://docs.snowflake.com/en/sql-reference/parameters)", + Optional: true, + ElementType: types.StringType, + }, + "client_ip": schema.StringAttribute{ + Description: "IP address for network checks. Can also be sourced from the `SNOWFLAKE_CLIENT_IP` environment variable.", + Optional: true, + }, + "protocol": schema.StringAttribute{ + Description: "Either http or https, defaults to https. Can also be sourced from the `SNOWFLAKE_PROTOCOL` environment variable.", + Optional: true, + Validators: []validator.String{ + stringvalidator.OneOf("http", "https"), + }, + }, + "host": schema.StringAttribute{ + Description: "Supports passing in a custom host value to the snowflake go driver for use with privatelink. Can also be sourced from the `SNOWFLAKE_HOST` environment variable. ", + Optional: true, + }, + "port": schema.Int64Attribute{ + Description: "Support custom port values to snowflake go driver for use with privatelink. Can also be sourced from the `SNOWFLAKE_PORT` environment variable. ", + Optional: true, + }, + "authenticator": schema.StringAttribute{ + Description: "Specifies the [authentication type](https://pkg.go.dev/github.com/snowflakedb/gosnowflake#AuthType) to use when connecting to Snowflake. Valid values include: Snowflake, OAuth, ExternalBrowser, Okta, JWT, TokenAccessor, UsernamePasswordMFA. Can also be sourced from the `SNOWFLAKE_AUTHENTICATOR` environment variable.", + Optional: true, + Validators: []validator.String{ + stringvalidator.OneOf("Snowflake", "OAuth", "ExternalBrowser", "Okta", "JWT", "TokenAccessor", "UsernamePasswordMFA"), + }, + }, + "passcode": schema.StringAttribute{ + Description: "Specifies the passcode provided by Duo when using multi-factor authentication (MFA) for login. Can also be sourced from the `SNOWFLAKE_PASSCODE` environment variable. 
", + Optional: true, + Validators: []validator.String{ + stringvalidator.ConflictsWith(path.MatchRoot("passcode_in_password")), + }, + }, + "passcode_in_password": schema.BoolAttribute{ + Description: "False by default. Set to true if the MFA passcode is embedded in the login password. Appends the MFA passcode to the end of the password. Can also be sourced from the `SNOWFLAKE_PASSCODE_IN_PASSWORD` environment variable. ", + Optional: true, + Validators: []validator.Bool{ + boolvalidator.ConflictsWith(path.MatchRoot("passcode")), + }, + }, + "okta_url": schema.StringAttribute{ + Description: "The URL of the Okta server. e.g. https://example.okta.com. Can also be sourced from the `SNOWFLAKE_OKTA_URL` environment variable.", + Optional: true, + }, + "login_timeout": schema.Int64Attribute{ + Description: "Login retry timeout EXCLUDING network roundtrip and read out http response. Can also be sourced from the `SNOWFLAKE_LOGIN_TIMEOUT` environment variable.", + Optional: true, + }, + "request_timeout": schema.Int64Attribute{ + Description: "request retry timeout EXCLUDING network roundtrip and read out http response. Can also be sourced from the `SNOWFLAKE_REQUEST_TIMEOUT` environment variable.", + Optional: true, + }, + "jwt_expire_timeout": schema.Int64Attribute{ + Description: "JWT expire after timeout in seconds. Can also be sourced from the `SNOWFLAKE_JWT_EXPIRE_TIMEOUT` environment variable.", + Optional: true, + }, + "client_timeout": schema.Int64Attribute{ + Description: "The timeout in seconds for the client to complete the authentication. Default is 900 seconds. Can also be sourced from the `SNOWFLAKE_CLIENT_TIMEOUT` environment variable.", + Optional: true, + }, + "jwt_client_timeout": schema.Int64Attribute{ + Description: "The timeout in seconds for the JWT client to complete the authentication. Default is 10 seconds. 
Can also be sourced from the `SNOWFLAKE_JWT_CLIENT_TIMEOUT` environment variable.", + Optional: true, + }, + "external_browser_timeout": schema.Int64Attribute{ + Description: "The timeout in seconds for the external browser to complete the authentication. Default is 120 seconds. Can also be sourced from the `SNOWFLAKE_EXTERNAL_BROWSER_TIMEOUT` environment variable.", + Optional: true, + }, + "insecure_mode": schema.BoolAttribute{ + Description: "If true, bypass the Online Certificate Status Protocol (OCSP) certificate revocation check. IMPORTANT: Change the default value for testing or emergency situations only. Can also be sourced from the `SNOWFLAKE_INSECURE_MODE` environment variable.", + Optional: true, + }, + "ocsp_fail_open": schema.BoolAttribute{ + Description: "True represents OCSP fail open mode. False represents OCSP fail closed mode. Fail open true by default. Can also be sourced from the `SNOWFLAKE_OCSP_FAIL_OPEN` environment variable.", + Optional: true, + }, + "token": schema.StringAttribute{ + Description: "Token to use for OAuth and other forms of token based auth. Can also be sourced from the `SNOWFLAKE_TOKEN` environment variable.", + Sensitive: true, + Optional: true, + }, + "keep_session_alive": schema.BoolAttribute{ + Optional: true, + Description: "Enables the session to persist even after the connection is closed. Can also be sourced from the `SNOWFLAKE_KEEP_SESSION_ALIVE` environment variable.", + }, + "private_key": schema.StringAttribute{ + Description: "Private Key for username+private-key auth. Cannot be used with `browser_auth` or `password`. 
Can also be sourced from `SNOWFLAKE_PRIVATE_KEY` environment variable.", + Optional: true, + Sensitive: true, + Validators: []validator.String{ + stringvalidator.ConflictsWith(path.MatchRoot("browser_auth"), path.MatchRoot("password"), path.MatchRoot("private_key_path"), path.MatchRoot("oauth_access_token"), path.MatchRoot("oauth_refresh_token")), + }, + }, + "private_key_passphrase": schema.StringAttribute{ + Description: "Supports the encryption ciphers aes-128-cbc, aes-128-gcm, aes-192-cbc, aes-192-gcm, aes-256-cbc, aes-256-gcm, and des-ede3-cbc. Can also be sourced from `SNOWFLAKE_PRIVATE_KEY_PASSPHRASE` environment variable.", + Optional: true, + Sensitive: true, + Validators: []validator.String{ + stringvalidator.ConflictsWith(path.MatchRoot("browser_auth"), path.MatchRoot("password"), path.MatchRoot("private_key_path"), path.MatchRoot("oauth_access_token"), path.MatchRoot("oauth_refresh_token")), + }, + }, + "disable_telemetry": schema.BoolAttribute{ + Description: "Indicates whether to disable telemetry. Can also be sourced from the `SNOWFLAKE_DISABLE_TELEMETRY` environment variable.", + Optional: true, + }, + "client_request_mfa_token": schema.BoolAttribute{ + Description: "When true the MFA token is cached in the credential manager. True by default in Windows/OSX. False for Linux. Can also be sourced from the `SNOWFLAKE_CLIENT_REQUEST_MFA_TOKEN` environment variable.", + Optional: true, + }, + "client_store_temporary_credential": schema.BoolAttribute{ + Description: "When true the ID token is cached in the credential manager. True by default in Windows/OSX. False for Linux. Can also be sourced from the `SNOWFLAKE_CLIENT_STORE_TEMPORARY_CREDENTIAL` environment variable.", + Optional: true, + }, + "disable_query_context_cache": schema.BoolAttribute{ + Description: "Should HTAP query context cache be disabled. 
Can also be sourced from the `SNOWFLAKE_DISABLE_QUERY_CONTEXT_CACHE` environment variable.", + Optional: true, + }, + "profile": schema.StringAttribute{ + Description: "Sets the profile to read from ~/.snowflake/config file. Can also be sourced from the `SNOWFLAKE_PROFILE` environment variable.", + Optional: true, + }, + /* + Feature not yet released as of latest gosnowflake release + https://github.com/snowflakedb/gosnowflake/blob/master/dsn.go#L103 + "include_retry_reason": schema.BoolAttribute { + Description: "Should retried request contain retry reason. Can also be sourced from the `SNOWFLAKE_INCLUDE_RETRY_REASON` environment variable.", + Optional: true, + }, + */ + // Deprecated Attributes + "region": schema.StringAttribute{ + Description: "Snowflake region, such as \"eu-central-1\", with this parameter. However, since this parameter is deprecated, it is best to specify the region as part of the account parameter. For details, see the description of the account parameter. [Snowflake region](https://docs.snowflake.com/en/user-guide/intro-regions.html) to use. Required if using the [legacy format for the `account` identifier](https://docs.snowflake.com/en/user-guide/admin-account-identifier.html#format-2-legacy-account-locator-in-a-region) in the form of `.`. Can also be sourced from the `SNOWFLAKE_REGION` environment variable. ", + Optional: true, + DeprecationMessage: "Specify the region as part of the account parameter", + }, + "session_params": schema.MapAttribute{ + Description: "Sets session parameters. [Parameters](https://docs.snowflake.com/en/sql-reference/parameters)", + Optional: true, + ElementType: types.StringType, + DeprecationMessage: "Use `params` instead", + }, + "oauth_access_token": schema.StringAttribute{ + Description: "Token for use with OAuth. Generating the token is left to other tools. Cannot be used with `browser_auth`, `private_key_path`, `oauth_refresh_token` or `password`. 
Can also be sourced from `SNOWFLAKE_OAUTH_ACCESS_TOKEN` environment variable.", + Optional: true, + Sensitive: true, + Validators: []validator.String{ + stringvalidator.ConflictsWith(path.MatchRoot("browser_auth"), path.MatchRoot("private_key_path"), path.MatchRoot("private_key"), path.MatchRoot("private_key_passphrase"), path.MatchRoot("password"), path.MatchRoot("oauth_refresh_token")), + }, + DeprecationMessage: "Use `token` instead", + }, + "oauth_refresh_token": schema.StringAttribute{ + Description: "Token for use with OAuth. Setup and generation of the token is left to other tools. Should be used in conjunction with `oauth_client_id`, `oauth_client_secret`, `oauth_endpoint`, `oauth_redirect_url`. Cannot be used with `browser_auth`, `private_key_path`, `oauth_access_token` or `password`. Can also be sourced from `SNOWFLAKE_OAUTH_REFRESH_TOKEN` environment variable.", + Optional: true, + Sensitive: true, + Validators: []validator.String{ + stringvalidator.ConflictsWith(path.MatchRoot("browser_auth"), path.MatchRoot("private_key_path"), path.MatchRoot("private_key"), path.MatchRoot("private_key_passphrase"), path.MatchRoot("password"), path.MatchRoot("oauth_access_token")), + stringvalidator.AlsoRequires(path.MatchRoot("oauth_client_id"), path.MatchRoot("oauth_client_secret"), path.MatchRoot("oauth_endpoint"), path.MatchRoot("oauth_redirect_url")), + }, + DeprecationMessage: "Use `token_accessor.0.refresh_token` instead", + }, + "oauth_client_id": schema.StringAttribute{ + Description: "Required when `oauth_refresh_token` is used. 
Can also be sourced from `SNOWFLAKE_OAUTH_CLIENT_ID` environment variable.", + Optional: true, + Sensitive: true, + Validators: []validator.String{ + stringvalidator.ConflictsWith(path.MatchRoot("browser_auth"), path.MatchRoot("private_key_path"), path.MatchRoot("private_key"), path.MatchRoot("private_key_passphrase"), path.MatchRoot("password"), path.MatchRoot("oauth_access_token")), + stringvalidator.AlsoRequires(path.MatchRoot("oauth_refresh_token"), path.MatchRoot("oauth_client_secret"), path.MatchRoot("oauth_endpoint"), path.MatchRoot("oauth_redirect_url")), + }, + DeprecationMessage: "Use `token_accessor.0.client_id` instead", + }, + "oauth_client_secret": schema.StringAttribute{ + Description: "Required when `oauth_refresh_token` is used. Can also be sourced from `SNOWFLAKE_OAUTH_CLIENT_SECRET` environment variable.", + Optional: true, + Sensitive: true, + Validators: []validator.String{ + stringvalidator.ConflictsWith(path.MatchRoot("browser_auth"), path.MatchRoot("private_key_path"), path.MatchRoot("private_key"), path.MatchRoot("private_key_passphrase"), path.MatchRoot("password"), path.MatchRoot("oauth_access_token")), + stringvalidator.AlsoRequires(path.MatchRoot("oauth_refresh_token"), path.MatchRoot("oauth_client_id"), path.MatchRoot("oauth_endpoint"), path.MatchRoot("oauth_redirect_url")), + }, + DeprecationMessage: "Use `token_accessor.0.client_secret` instead", + }, + "oauth_endpoint": schema.StringAttribute{ + Description: "Required when `oauth_refresh_token` is used. 
Can also be sourced from `SNOWFLAKE_OAUTH_ENDPOINT` environment variable.", + Optional: true, + Sensitive: true, + Validators: []validator.String{ + stringvalidator.ConflictsWith(path.MatchRoot("browser_auth"), path.MatchRoot("private_key_path"), path.MatchRoot("private_key"), path.MatchRoot("private_key_passphrase"), path.MatchRoot("password"), path.MatchRoot("oauth_access_token")), + stringvalidator.AlsoRequires(path.MatchRoot("oauth_refresh_token"), path.MatchRoot("oauth_client_id"), path.MatchRoot("oauth_client_secret"), path.MatchRoot("oauth_redirect_url")), + }, + DeprecationMessage: "Use `token_accessor.0.token_endpoint` instead", + }, + "oauth_redirect_url": schema.StringAttribute{ + Description: "Required when `oauth_refresh_token` is used. Can also be sourced from `SNOWFLAKE_OAUTH_REDIRECT_URL` environment variable.", + Optional: true, + Sensitive: true, + Validators: []validator.String{ + stringvalidator.ConflictsWith(path.MatchRoot("browser_auth"), path.MatchRoot("private_key_path"), path.MatchRoot("private_key"), path.MatchRoot("private_key_passphrase"), path.MatchRoot("password"), path.MatchRoot("oauth_access_token")), + stringvalidator.AlsoRequires(path.MatchRoot("oauth_refresh_token"), path.MatchRoot("oauth_client_id"), path.MatchRoot("oauth_client_secret"), path.MatchRoot("oauth_endpoint")), + }, + DeprecationMessage: "Use `token_accessor.0.redirect_uri` instead", + }, + "browser_auth": schema.BoolAttribute{ + Description: "Required when `oauth_refresh_token` is used. Can also be sourced from `SNOWFLAKE_USE_BROWSER_AUTH` environment variable.", + Optional: true, + Sensitive: false, + DeprecationMessage: "Use `authenticator` instead", + }, + "private_key_path": schema.StringAttribute{ + Description: "Path to a private key for using keypair authentication. Cannot be used with `browser_auth`, `oauth_access_token` or `password`. 
Can also be sourced from `SNOWFLAKE_PRIVATE_KEY_PATH` environment variable.", + Optional: true, + Sensitive: true, + Validators: []validator.String{ + stringvalidator.ConflictsWith(path.MatchRoot("browser_auth"), path.MatchRoot("oauth_access_token"), path.MatchRoot("password")), + }, + DeprecationMessage: "use the [file Function](https://developer.hashicorp.com/terraform/language/functions/file) instead", + }, + }, + Blocks: map[string]schema.Block{ + "token_accessor": schema.ListNestedBlock{ + Validators: []validator.List{ + listvalidator.SizeAtMost(1), + }, + NestedObject: schema.NestedBlockObject{ + Attributes: map[string]schema.Attribute{ + "token_endpoint": schema.StringAttribute{ + Description: "The token endpoint for the OAuth provider e.g. https://{yourDomain}/oauth/token when using a refresh token to renew access token. Can also be sourced from the `SNOWFLAKE_TOKEN_ACCESSOR_TOKEN_ENDPOINT` environment variable.", + Required: true, + Sensitive: true, + }, + "refresh_token": schema.StringAttribute{ + Description: "The refresh token for the OAuth provider when using a refresh token to renew access token. Can also be sourced from the `SNOWFLAKE_TOKEN_ACCESSOR_REFRESH_TOKEN` environment variable.", + Required: true, + Sensitive: true, + }, + "client_id": schema.StringAttribute{ + Description: "The client ID for the OAuth provider when using a refresh token to renew access token. Can also be sourced from the `SNOWFLAKE_TOKEN_ACCESSOR_CLIENT_ID` environment variable.", + Required: true, + Sensitive: true, + }, + "client_secret": schema.StringAttribute{ + Description: "The client secret for the OAuth provider when using a refresh token to renew access token. Can also be sourced from the `SNOWFLAKE_TOKEN_ACCESSOR_CLIENT_SECRET` environment variable.", + Required: true, + Sensitive: true, + }, + "redirect_uri": schema.StringAttribute{ + Description: "The redirect URI for the OAuth provider when using a refresh token to renew access token. 
Can also be sourced from the `SNOWFLAKE_TOKEN_ACCESSOR_REDIRECT_URI` environment variable.", + Required: true, + Sensitive: true, + }, + }, + }, + }, + }, + } +} + +func (p *SnowflakeProvider) Configure(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) { + var data snowflakeProviderModelV0 + + // Read configuration data into model + resp.Diagnostics.Append(req.Config.Get(ctx, &data)...) + + config := &gosnowflake.Config{ + Application: "terraform-provider-snowflake", + } + + account := os.Getenv("SNOWFLAKE_ACCOUNT") + if data.Account.ValueString() != "" { + account = data.Account.ValueString() + } + if account != "" { + config.Account = account + } + + user := os.Getenv("SNOWFLAKE_USER") + if user == "" { + user = os.Getenv("SNOWFLAKE_USERNAME") + } + if data.Username.ValueString() != "" { + user = data.Username.ValueString() + } + if data.User.ValueString() != "" { + user = data.User.ValueString() + } + if user != "" { + config.User = user + } + + password := os.Getenv("SNOWFLAKE_PASSWORD") + if data.Password.ValueString() != "" { + password = data.Password.ValueString() + } + if password != "" { + config.Password = password + } + + warehouse := os.Getenv("SNOWFLAKE_WAREHOUSE") + if data.Warehouse.ValueString() != "" { + warehouse = data.Warehouse.ValueString() + } + if warehouse != "" { + config.Warehouse = warehouse + } + + role := os.Getenv("SNOWFLAKE_ROLE") + if data.Role.ValueString() != "" { + role = data.Role.ValueString() + } + if role != "" { + config.Role = role + } + + validateDefaultParameters := getBoolEnv("SNOWFLAKE_VALIDATE_DEFAULT_PARAMETERS", true) + if !data.ValidateDefaultParameters.IsNull() && !data.ValidateDefaultParameters.IsUnknown() { + validateDefaultParameters = data.ValidateDefaultParameters.ValueBool() + } + if validateDefaultParameters { + config.ValidateDefaultParameters = gosnowflake.ConfigBoolTrue + } else { + config.ValidateDefaultParameters = gosnowflake.ConfigBoolFalse + } + + clientIP := 
os.Getenv("SNOWFLAKE_CLIENT_IP") + if data.ClientIP.ValueString() != "" { + clientIP = data.ClientIP.ValueString() + } + if clientIP != "" { + config.ClientIP = net.ParseIP(clientIP) + } + + protocol := os.Getenv("SNOWFLAKE_PROTOCOL") + if data.Protocol.ValueString() != "" { + protocol = data.Protocol.ValueString() + } + if protocol != "" { + config.Protocol = protocol + } + + host := os.Getenv("SNOWFLAKE_HOST") + if data.Host.ValueString() != "" { + host = data.Host.ValueString() + } + if host != "" { + config.Host = host + } + + port := getInt64Env("SNOWFLAKE_PORT", -1) + if !data.Port.IsNull() && !data.Port.IsUnknown() { + port = data.Port.ValueInt64() + } + if port > 0 { + config.Port = int(port) + } + + browserAuth := getBoolEnv("SNOWFLAKE_USE_BROWSER_AUTH", false) + if !data.BrowserAuth.IsNull() && !data.BrowserAuth.IsUnknown() { + browserAuth = data.BrowserAuth.ValueBool() + } + if browserAuth { + config.Authenticator = gosnowflake.AuthTypeExternalBrowser + } + + authenticator := os.Getenv("SNOWFLAKE_AUTHENTICATOR") + if data.Authenticator.ValueString() != "" { + authenticator = data.Authenticator.ValueString() + } + if authenticator != "" { + config.Authenticator = toAuthenticatorType(authenticator) + } + + passcode := os.Getenv("SNOWFLAKE_PASSCODE") + if data.Passcode.ValueString() != "" { + passcode = data.Passcode.ValueString() + } + if passcode != "" { + config.Passcode = passcode + } + + passcodeInPassword := getBoolEnv("SNOWFLAKE_PASSCODE_IN_PASSWORD", false) + if !data.PasscodeInPassword.IsNull() && !data.PasscodeInPassword.IsUnknown() { + passcodeInPassword = data.PasscodeInPassword.ValueBool() + } + config.PasscodeInPassword = passcodeInPassword + + oktaURL := os.Getenv("SNOWFLAKE_OKTA_URL") + if data.OktaURL.ValueString() != "" { + oktaURL = data.OktaURL.ValueString() + } + if oktaURL != "" { + parsedOktaURL, err := url.Parse(oktaURL) + if err != nil { + resp.Diagnostics.AddError("Error parsing Okta URL", err.Error()) + } + config.OktaURL = 
parsedOktaURL
+	}
+
+	loginTimeout := getInt64Env("SNOWFLAKE_LOGIN_TIMEOUT", -1)
+	if !data.LoginTimeout.IsNull() && !data.LoginTimeout.IsUnknown() {
+		loginTimeout = data.LoginTimeout.ValueInt64()
+	}
+	if loginTimeout > 0 {
+		config.LoginTimeout = time.Second * time.Duration(loginTimeout)
+	}
+
+	requestTimeout := getInt64Env("SNOWFLAKE_REQUEST_TIMEOUT", -1)
+	if !data.RequestTimeout.IsNull() && !data.RequestTimeout.IsUnknown() {
+		requestTimeout = data.RequestTimeout.ValueInt64()
+	}
+	if requestTimeout > 0 {
+		config.RequestTimeout = time.Second * time.Duration(requestTimeout)
+	}
+
+	jwtExpireTimeout := getInt64Env("SNOWFLAKE_JWT_EXPIRE_TIMEOUT", -1)
+	if !data.JWTExpireTimeout.IsNull() && !data.JWTExpireTimeout.IsUnknown() {
+		// fix: read from JWTExpireTimeout (was a copy-paste of JWTClientTimeout)
+		jwtExpireTimeout = data.JWTExpireTimeout.ValueInt64()
+	}
+	if jwtExpireTimeout > 0 {
+		// fix: set JWTExpireTimeout; JWTClientTimeout is set by its own block below
+		config.JWTExpireTimeout = time.Second * time.Duration(jwtExpireTimeout)
+	}
+
+	clientTimeout := getInt64Env("SNOWFLAKE_CLIENT_TIMEOUT", -1)
+	if !data.ClientTimeout.IsNull() && !data.ClientTimeout.IsUnknown() {
+		clientTimeout = data.ClientTimeout.ValueInt64()
+	}
+	if clientTimeout > 0 {
+		config.ClientTimeout = time.Second * time.Duration(clientTimeout)
+	}
+
+	jwtClientTimeout := getInt64Env("SNOWFLAKE_JWT_CLIENT_TIMEOUT", -1)
+	if !data.JWTClientTimeout.IsNull() && !data.JWTClientTimeout.IsUnknown() {
+		jwtClientTimeout = data.JWTClientTimeout.ValueInt64()
+	}
+	if jwtClientTimeout > 0 {
+		config.JWTClientTimeout = time.Second * time.Duration(jwtClientTimeout)
+	}
+
+	externalBrowserTimeout := getInt64Env("SNOWFLAKE_EXTERNAL_BROWSER_TIMEOUT", -1)
+	if !data.ExternalBrowserTimeout.IsNull() && !data.ExternalBrowserTimeout.IsUnknown() {
+		externalBrowserTimeout = data.ExternalBrowserTimeout.ValueInt64()
+	}
+	if externalBrowserTimeout > 0 {
+		config.ExternalBrowserTimeout = time.Second * time.Duration(externalBrowserTimeout)
+	}
+
+	insecureMode := getBoolEnv("SNOWFLAKE_INSECURE_MODE", false)
+	if !data.InsecureMode.IsNull() &&
!data.InsecureMode.IsUnknown() {
+		insecureMode = data.InsecureMode.ValueBool()
+	}
+	config.InsecureMode = insecureMode
+
+	ocspFailOpen := getBoolEnv("SNOWFLAKE_OCSP_FAIL_OPEN", true)
+	if !data.OCSPFailOpen.IsNull() && !data.OCSPFailOpen.IsUnknown() {
+		ocspFailOpen = data.OCSPFailOpen.ValueBool()
+	}
+	if ocspFailOpen {
+		config.OCSPFailOpen = gosnowflake.OCSPFailOpenTrue
+	} else {
+		config.OCSPFailOpen = gosnowflake.OCSPFailOpenFalse
+	}
+
+	token := os.Getenv("SNOWFLAKE_TOKEN")
+	if data.Token.ValueString() != "" {
+		token = data.Token.ValueString()
+	}
+	if token != "" {
+		config.Token = token
+	}
+
+	keepSessionAlive := getBoolEnv("SNOWFLAKE_KEEP_SESSION_ALIVE", false)
+	if !data.KeepSessionAlive.IsNull() && !data.KeepSessionAlive.IsUnknown() {
+		keepSessionAlive = data.KeepSessionAlive.ValueBool()
+	}
+	config.KeepSessionAlive = keepSessionAlive
+
+	privateKey := os.Getenv("SNOWFLAKE_PRIVATE_KEY")
+	if data.PrivateKey.ValueString() != "" {
+		privateKey = data.PrivateKey.ValueString()
+	}
+	privateKeyPath := os.Getenv("SNOWFLAKE_PRIVATE_KEY_PATH")
+	if data.PrivateKeyPath.ValueString() != "" {
+		privateKeyPath = data.PrivateKeyPath.ValueString()
+	}
+	privateKeyPassphrase := os.Getenv("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE")
+	if data.PrivateKeyPassphrase.ValueString() != "" {
+		privateKeyPassphrase = data.PrivateKeyPassphrase.ValueString()
+	}
+	if privateKey != "" || privateKeyPath != "" {
+		// fix: the old condition (err != nil && v != nil) could never set the key
+		// and silently dropped parse errors; surface the error and set on success.
+		v, err := getPrivateKey(privateKeyPath, privateKey, privateKeyPassphrase)
+		if err != nil {
+			resp.Diagnostics.AddError("Error retrieving private key", err.Error())
+		}
+		if v != nil {
+			config.PrivateKey = v
+		}
+	}
+	disableTelemetry := getBoolEnv("SNOWFLAKE_DISABLE_TELEMETRY", false)
+	if !data.DisableTelemetry.IsNull() && !data.DisableTelemetry.IsUnknown() {
+		disableTelemetry = data.DisableTelemetry.ValueBool()
+	}
+	config.DisableTelemetry = disableTelemetry
+
+	clientRequestMFAToken := getBoolEnv("SNOWFLAKE_CLIENT_REQUEST_MFA_TOKEN", true)
+	if !data.ClientRequestMFAToken.IsNull() && !data.ClientRequestMFAToken.IsUnknown() {
+
clientRequestMFAToken = data.ClientRequestMFAToken.ValueBool() + } + if clientRequestMFAToken { + config.ClientRequestMfaToken = gosnowflake.ConfigBoolTrue + } else { + config.ClientRequestMfaToken = gosnowflake.ConfigBoolFalse + } + + clientStoreTemporaryCredential := getBoolEnv("SNOWFLAKE_CLIENT_STORE_TEMPORARY_CREDENTIAL", true) + if !data.ClientStoreTemporaryCredential.IsNull() && !data.ClientStoreTemporaryCredential.IsUnknown() { + clientStoreTemporaryCredential = data.ClientStoreTemporaryCredential.ValueBool() + } + if clientStoreTemporaryCredential { + config.ClientStoreTemporaryCredential = gosnowflake.ConfigBoolTrue + } else { + config.ClientStoreTemporaryCredential = gosnowflake.ConfigBoolFalse + } + + disableQueryContextCache := getBoolEnv("SNOWFLAKE_DISABLE_QUERY_CONTEXT_CACHE", false) + if !data.DisableQueryContextCache.IsNull() && !data.DisableQueryContextCache.IsUnknown() { + disableQueryContextCache = data.DisableQueryContextCache.ValueBool() + } + config.DisableQueryContextCache = disableQueryContextCache + + tokenEndpoint := os.Getenv("SNOWFLAKE_TOKEN_ACCESSOR_TOKEN_ENDPOINT") + if tokenEndpoint == "" { + tokenEndpoint = os.Getenv("SNOWFLAKE_OAUTH_ENDPOINT") + } + if data.OauthEndpoint.ValueString() != "" { + tokenEndpoint = data.OauthEndpoint.ValueString() + } + refreshToken := os.Getenv("SNOWFLAKE_TOKEN_ACCESSOR_REFRESH_TOKEN") + if refreshToken == "" { + refreshToken = os.Getenv("SNOWFLAKE_OAUTH_REFRESH_TOKEN") + } + if data.OauthRefreshToken.ValueString() != "" { + refreshToken = data.OauthRefreshToken.ValueString() + } + clientID := os.Getenv("SNOWFLAKE_TOKEN_ACCESSOR_CLIENT_ID") + if clientID == "" { + clientID = os.Getenv("SNOWFLAKE_OAUTH_CLIENT_ID") + } + if data.OauthClientID.ValueString() != "" { + clientID = data.OauthClientID.ValueString() + } + clientSecret := os.Getenv("SNOWFLAKE_TOKEN_ACCESSOR_CLIENT_SECRET") + if clientSecret == "" { + clientSecret = os.Getenv("SNOWFLAKE_OAUTH_CLIENT_SECRET") + } + if 
data.OauthClientSecret.ValueString() != "" { + clientSecret = data.OauthClientSecret.ValueString() + } + redirectURI := os.Getenv("SNOWFLAKE_TOKEN_ACCESSOR_REDIRECT_URI") + if redirectURI == "" { + redirectURI = os.Getenv("SNOWFLAKE_OAUTH_REDIRECT_URL") + } + if data.OauthRedirectURL.ValueString() != "" { + redirectURI = data.OauthRedirectURL.ValueString() + } + var tokenAccesors []RefreshTokenAccesor + data.TokenAccessor.ElementsAs(ctx, &tokenAccesors, false) + if len(tokenAccesors) > 0 { + tokenAccessor := tokenAccesors[0] + if tokenAccessor.TokenEndpoint.ValueString() != "" { + tokenEndpoint = tokenAccessor.TokenEndpoint.ValueString() + } + if tokenAccessor.RefreshToken.ValueString() != "" { + refreshToken = tokenAccessor.RefreshToken.ValueString() + } + if tokenAccessor.ClientID.ValueString() != "" { + clientID = tokenAccessor.ClientID.ValueString() + } + if tokenAccessor.ClientSecret.ValueString() != "" { + clientSecret = tokenAccessor.ClientSecret.ValueString() + } + if tokenAccessor.RedirectURI.ValueString() != "" { + redirectURI = tokenAccessor.RedirectURI.ValueString() + } + } + + if tokenEndpoint != "" && refreshToken != "" && clientID != "" && clientSecret != "" && redirectURI != "" { + accessToken, err := GetAccessTokenWithRefreshToken(tokenEndpoint, clientID, clientSecret, refreshToken, redirectURI) + if err != nil { + resp.Diagnostics.AddError("Error retrieving access token from refresh token", err.Error()) + } + config.Token = accessToken + config.Authenticator = gosnowflake.AuthTypeOAuth + } + + region := os.Getenv("SNOWFLAKE_REGION") + if data.Region.ValueString() != "" { + region = data.Region.ValueString() + } + if region != "" { + config.Region = region + } + + if !data.SessionParams.IsNull() && !data.SessionParams.IsUnknown() { + var m map[string]interface{} + params := make(map[string]*string, 0) + data.SessionParams.ElementsAs(ctx, m, false) + for k, v := range m { + s := v.(string) + params[k] = &s + } + config.Params = params + } + + if 
!data.Params.IsNull() && !data.Params.IsUnknown() { + var m map[string]interface{} + params := make(map[string]*string, 0) + data.Params.ElementsAs(ctx, m, false) + for k, v := range m { + s := v.(string) + params[k] = &s + } + config.Params = params + } + + profile := os.Getenv("SNOWFLAKE_PROFILE") + if data.Profile.ValueString() != "" { + profile = data.Profile.ValueString() + } + + if profile != "" { + if profile == "default" { + defaultConfig := sdk.DefaultConfig() + if defaultConfig.Account == "" || defaultConfig.User == "" { + resp.Diagnostics.AddError("Error retrieving default profile config", "default profile not found in config file") + } + config = sdk.MergeConfig(config, defaultConfig) + } else { + profileConfig, err := sdk.ProfileConfig(profile) + if err != nil { + resp.Diagnostics.AddError("Error retrieving profile config", err.Error()) + } + if profileConfig == nil { + resp.Diagnostics.AddError("Error retrieving profile config", "profile with name: "+profile+" not found in config file") + } + // merge any credentials found in profile with config + config = sdk.MergeConfig(config, profileConfig) + } + } + + client, err := sdk.NewClient(config) + if err != nil { + resp.Diagnostics.AddError("Error creating Snowflake client", err.Error()) + } + + if resp.Diagnostics.HasError() { + return + } + providerData := &ProviderData{ + client: client, + } + resp.DataSourceData = providerData + resp.ResourceData = providerData +} + +type ProviderData struct { + client *sdk.Client +} + +func (p *SnowflakeProvider) Resources(ctx context.Context) []func() resource.Resource { + return []func() resource.Resource{ + // NewResourceMonitorResource, + } +} + +func (p *SnowflakeProvider) DataSources(ctx context.Context) []func() datasource.DataSource { + return []func() datasource.DataSource{} +} + +func New(version string) func() provider.Provider { + return func() provider.Provider { + return &SnowflakeProvider{ + version: version, + } + } +} diff --git 
a/framework/provider/provider_helpers.go b/framework/provider/provider_helpers.go new file mode 100644 index 0000000000..8d61b79f49 --- /dev/null +++ b/framework/provider/provider_helpers.go @@ -0,0 +1,186 @@ +package provider + +import ( + "crypto/rsa" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strconv" + "strings" + + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + "github.com/mitchellh/go-homedir" + "github.com/snowflakedb/gosnowflake" + "github.com/youmark/pkcs8" + "golang.org/x/crypto/ssh" +) + +func mergeSchemas(schemaCollections ...map[string]*schema.Resource) map[string]*schema.Resource { + out := map[string]*schema.Resource{} + for _, schemaCollection := range schemaCollections { + for name, s := range schemaCollection { + out[name] = s + } + } + return out +} + +func getPrivateKey(privateKeyPath, privateKeyString, privateKeyPassphrase string) (*rsa.PrivateKey, error) { + privateKeyBytes := []byte(privateKeyString) + var err error + if len(privateKeyBytes) == 0 && privateKeyPath != "" { + privateKeyBytes, err = readFile(privateKeyPath) + if err != nil { + return nil, fmt.Errorf("private Key file could not be read err = %w", err) + } + } + return parsePrivateKey(privateKeyBytes, []byte(privateKeyPassphrase)) +} + +func toAuthenticatorType(authenticator string) gosnowflake.AuthType { + switch authenticator { + case "Snowflake": + return gosnowflake.AuthTypeSnowflake + case "OAuth": + return gosnowflake.AuthTypeOAuth + case "ExternalBrowser": + return gosnowflake.AuthTypeExternalBrowser + case "Okta": + return gosnowflake.AuthTypeOkta + case "JWT": + return gosnowflake.AuthTypeJwt + case "TokenAccessor": + return gosnowflake.AuthTypeTokenAccessor + case "UsernamePasswordMFA": + return gosnowflake.AuthTypeUsernamePasswordMFA + default: + return gosnowflake.AuthTypeSnowflake + } +} + +func getInt64Env(key string, defaultValue int64) int64 { + s := os.Getenv(key) + if s == "" { + return 
defaultValue + } + i, err := strconv.Atoi(s) + if err != nil { + return defaultValue + } + return int64(i) +} + +func getBoolEnv(key string, defaultValue bool) bool { + s := strings.ToLower(os.Getenv(key)) + if s == "" { + return defaultValue + } + switch s { + case "true", "1": + return true + case "false", "0": + return false + default: + return defaultValue + } +} + +func readFile(privateKeyPath string) ([]byte, error) { + expandedPrivateKeyPath, err := homedir.Expand(privateKeyPath) + if err != nil { + return nil, fmt.Errorf("invalid Path to private key err = %w", err) + } + + privateKeyBytes, err := os.ReadFile(expandedPrivateKeyPath) + if err != nil { + return nil, fmt.Errorf("could not read private key err = %w", err) + } + + if len(privateKeyBytes) == 0 { + return nil, errors.New("private key is empty") + } + + return privateKeyBytes, nil +} + +func parsePrivateKey(privateKeyBytes []byte, passhrase []byte) (*rsa.PrivateKey, error) { + privateKeyBlock, _ := pem.Decode(privateKeyBytes) + if privateKeyBlock == nil { + return nil, fmt.Errorf("could not parse private key, key is not in PEM format") + } + + if privateKeyBlock.Type == "ENCRYPTED PRIVATE KEY" { + if len(passhrase) == 0 { + return nil, fmt.Errorf("private key requires a passphrase, but private_key_passphrase was not supplied") + } + privateKey, err := pkcs8.ParsePKCS8PrivateKeyRSA(privateKeyBlock.Bytes, passhrase) + if err != nil { + return nil, fmt.Errorf("could not parse encrypted private key with passphrase, only ciphers aes-128-cbc, aes-128-gcm, aes-192-cbc, aes-192-gcm, aes-256-cbc, aes-256-gcm, and des-ede3-cbc are supported err = %w", err) + } + return privateKey, nil + } + + privateKey, err := ssh.ParseRawPrivateKey(privateKeyBytes) + if err != nil { + return nil, fmt.Errorf("could not parse private key err = %w", err) + } + + rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("privateKey not of type RSA") + } + return rsaPrivateKey, nil +} + +type 
GetRefreshTokenResponseBody struct {
+	AccessToken string `json:"access_token"`
+	TokenType   string `json:"token_type"`
+	ExpiresIn   int    `json:"expires_in"`
+}
+
+func GetAccessTokenWithRefreshToken(
+	tokenEndPoint string,
+	clientID string,
+	clientSecret string,
+	refreshToken string,
+	redirectURI string,
+) (string, error) {
+	client := &http.Client{}
+
+	data := url.Values{}
+	data.Set("grant_type", "refresh_token")
+	data.Set("refresh_token", refreshToken)
+	data.Set("redirect_uri", redirectURI)
+	body := strings.NewReader(data.Encode())
+
+	request, err := http.NewRequest("POST", tokenEndPoint, body)
+	if err != nil {
+		return "", fmt.Errorf("request to the endpoint could not be completed %w", err)
+	}
+	request.SetBasicAuth(clientID, clientSecret)
+	request.Header.Set("Content-Type", "application/x-www-form-urlencoded;charset=UTF-8")
+
+	response, err := client.Do(request)
+	if err != nil {
+		return "", fmt.Errorf("response status returned an err = %w", err)
+	}
+	// fix: close the body on every path (it previously leaked on non-200 responses)
+	defer response.Body.Close()
+	if response.StatusCode != 200 {
+		// fix: err is always nil here; wrapping it with %w produced a malformed message
+		return "", fmt.Errorf("response status code: %s: %s", strconv.Itoa(response.StatusCode), http.StatusText(response.StatusCode))
+	}
+	dat, err := io.ReadAll(response.Body)
+	if err != nil {
+		return "", fmt.Errorf("response body was not able to be parsed err = %w", err)
+	}
+	var result GetRefreshTokenResponseBody
+	err = json.Unmarshal(dat, &result)
+	if err != nil {
+		return "", fmt.Errorf("error parsing JSON from Snowflake err = %w", err)
+	}
+	return result.AccessToken, nil
+}
diff --git a/framework/provider/resource_monitor_resource.go b/framework/provider/resource_monitor_resource.go
new file mode 100644
index 0000000000..921387a376
--- /dev/null
+++ b/framework/provider/resource_monitor_resource.go
@@ -0,0 +1,807 @@
+package provider
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk"
+	"github.com/hashicorp/terraform-plugin-framework-validators/int64validator"
+
"github.com/hashicorp/terraform-plugin-framework-validators/setvalidator" + "github.com/hashicorp/terraform-plugin-framework-validators/stringvalidator" + "github.com/hashicorp/terraform-plugin-framework/attr" + "github.com/hashicorp/terraform-plugin-framework/diag" + "github.com/hashicorp/terraform-plugin-framework/path" + "github.com/hashicorp/terraform-plugin-framework/resource" + "github.com/hashicorp/terraform-plugin-framework/resource/schema" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/booldefault" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/boolplanmodifier" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/float64default" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/float64planmodifier" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/planmodifier" + "github.com/hashicorp/terraform-plugin-framework/resource/schema/stringplanmodifier" + "github.com/hashicorp/terraform-plugin-framework/schema/validator" + "github.com/hashicorp/terraform-plugin-framework/types" + "github.com/hashicorp/terraform-plugin-log/tflog" +) + +var ( + _ resource.Resource = &ResourceMonitorResource{} + _ resource.ResourceWithImportState = &ResourceMonitorResource{} +) + +func NewResourceMonitorResource() resource.Resource { + return &ResourceMonitorResource{} +} + +type ResourceMonitorResource struct { + client *sdk.Client +} + +type resourceMonitorModelV0 struct { + Name types.String `tfsdk:"name"` + NotifyUsers types.Set `tfsdk:"notify_users"` + CreditQuota types.Float64 `tfsdk:"credit_quota"` + Frequency types.String `tfsdk:"frequency"` + StartTimestamp types.String `tfsdk:"start_timestamp"` + EndTimestamp types.String `tfsdk:"end_timestamp"` + SuspendTrigger types.Int64 `tfsdk:"suspend_trigger"` + SuspendTriggers types.Set `tfsdk:"suspend_triggers"` + SuspendImmediateTrigger types.Int64 `tfsdk:"suspend_immediate_trigger"` + SuspendImmediateTriggers types.Set 
`tfsdk:"suspend_immediate_triggers"`
+	NotifyTriggers           types.Set    `tfsdk:"notify_triggers"`
+	SetForAccount            types.Bool   `tfsdk:"set_for_account"`
+	Warehouses               types.Set    `tfsdk:"warehouses"`
+	Id                       types.String `tfsdk:"id"`
+}
+
+func resourceMonitorSchemaV0() schema.Schema {
+	return schema.Schema{
+		Version: 0,
+		Attributes: map[string]schema.Attribute{
+			"name": schema.StringAttribute{
+				Description: "Identifier for the resource monitor; must be unique for your account.",
+				Required:    true,
+				PlanModifiers: []planmodifier.String{
+					stringplanmodifier.RequiresReplace(),
+				},
+			},
+			"notify_users": schema.SetAttribute{
+				Description: "Specifies the list of users to receive email notifications on resource monitors.",
+				Optional:    true,
+				ElementType: types.StringType,
+			},
+			// fix: the V0 model declares CreditQuota as types.Float64; decoding prior
+			// state against an Int64Attribute would fail during the state upgrade.
+			"credit_quota": schema.Float64Attribute{
+				Description: "The number of credits allocated monthly to the resource monitor.",
+				Optional:    true,
+				Computed:    true,
+			},
+			"frequency": schema.StringAttribute{
+				Description: "The frequency interval at which the credit usage resets to 0. 
If you set a frequency for a resource monitor, you must also set START_TIMESTAMP.", + Optional: true, + Computed: true, + Validators: []validator.String{ + stringvalidator.OneOfCaseInsensitive([]string{"MONTHLY", "DAILY", "WEEKLY", "YEARLY", "NEVER"}...), + }, + }, + "start_timestamp": schema.StringAttribute{ + Description: "The date and time when the resource monitor starts monitoring credit usage for the assigned warehouses.", + Optional: true, + Computed: true, + }, + "end_timestamp": schema.StringAttribute{ + Description: "The date and time when the resource monitor suspends the assigned warehouses.", + Optional: true, + Computed: true, + }, + "suspend_trigger": schema.Int64Attribute{ + Description: "The number that represents the percentage threshold at which to suspend all warehouses.", + Optional: true, + Validators: []validator.Int64{ + int64validator.ConflictsWith(path.MatchRoot("suspend_triggers")), + }, + }, + "suspend_triggers": schema.SetAttribute{ + Description: "A list of percentage thresholds at which to suspend all warehouses.", + Optional: true, + ElementType: types.Int64Type, + Validators: []validator.Set{ + setvalidator.ConflictsWith(path.MatchRoot("suspend_trigger")), + }, + DeprecationMessage: "Use suspend_trigger instead", + }, + "suspend_immediate_trigger": schema.Int64Attribute{ + Description: "The number that represents the percentage threshold at which to immediately suspend all warehouses.", Optional: true, + Validators: []validator.Int64{ + int64validator.ConflictsWith(path.MatchRoot("suspend_immediate_triggers")), + }, + }, + "suspend_immediate_triggers": schema.SetAttribute{ + Description: "A list of percentage thresholds at which to suspend all warehouses.", + Optional: true, + ElementType: types.Int64Type, + Validators: []validator.Set{ + setvalidator.ConflictsWith(path.MatchRoot("suspend_immediate_trigger")), + }, + DeprecationMessage: "Use suspend_immediate_trigger instead", + }, + "notify_triggers": schema.SetAttribute{ + 
Description: "A list of percentage thresholds at which to send an alert to subscribed users.", + Optional: true, + ElementType: types.Int64Type, + }, + "set_for_account": schema.BoolAttribute{ + Description: "Specifies whether the resource monitor should be applied globally to your Snowflake account (defaults to false).", + Optional: true, + Default: booldefault.StaticBool(false), + // todo: create a snowflake_resource_monitor_association resource + // DeprecationMessage: "Use snowflake_resource_monitor_association instead", + }, + "warehouses": schema.SetAttribute{ + Description: "A list of warehouses to apply the resource monitor to.", + Optional: true, + ElementType: types.StringType, + // todo: add the `resource_monitor` attribute to the `snowflake_warehouse` resource + // DeprecationMessage: "Set the `resource_monitor` attribute on the `snowflake_warehouse` resource instead", + }, + }, + } +} + +func upgradeResourceMonitorStateV0toV1(ctx context.Context, req resource.UpgradeStateRequest, resp *resource.UpgradeStateResponse) { + var resourceMonitorDataV0 resourceMonitorModelV0 + resp.Diagnostics.Append(req.State.Get(ctx, &resourceMonitorDataV0)...) 
+ if resp.Diagnostics.HasError() { + return + } + + name := resourceMonitorDataV0.Name + notifyUsers := resourceMonitorDataV0.NotifyUsers + creditQuota := resourceMonitorDataV0.CreditQuota + frequency := resourceMonitorDataV0.Frequency + startTimestamp := resourceMonitorDataV0.StartTimestamp + endTimestamp := resourceMonitorDataV0.EndTimestamp + suspendTrigger := resourceMonitorDataV0.SuspendTrigger + suspendTriggers := resourceMonitorDataV0.SuspendTriggers + if !suspendTriggers.IsNull() { + suspendTriggersElements := make([]types.Int64, 0, len(suspendTriggers.Elements())) + suspendTriggers.ElementsAs(ctx, &suspendTriggersElements, false) + if len(suspendTriggersElements) > 0 { + suspendTrigger = suspendTriggersElements[0] + } + } + suspendImmediateTrigger := resourceMonitorDataV0.SuspendImmediateTrigger + suspendImmediateTriggers := resourceMonitorDataV0.SuspendImmediateTriggers + if !suspendImmediateTriggers.IsNull() { + suspendImmediateTriggersElements := make([]types.Int64, 0, len(suspendImmediateTriggers.Elements())) + suspendImmediateTriggers.ElementsAs(ctx, &suspendImmediateTriggersElements, false) + if len(suspendImmediateTriggersElements) > 0 { + suspendImmediateTrigger = suspendImmediateTriggersElements[0] + } + } + notifyTriggers := resourceMonitorDataV0.NotifyTriggers + + trigggers := make([]resourceMonitorTriggerModel, 0) + if !suspendTrigger.IsNull() { + trigggers = append(trigggers, resourceMonitorTriggerModel{ + Threshold: suspendTrigger, + TriggerAction: types.StringValue("SUSPEND"), + }) + } + if !suspendImmediateTrigger.IsNull() { + trigggers = append(trigggers, resourceMonitorTriggerModel{ + Threshold: suspendImmediateTrigger, + TriggerAction: types.StringValue("SUSPEND_IMMEDIATE"), + }) + } + if !notifyTriggers.IsNull() { + notifyTriggersElements := make([]types.Int64, 0, len(notifyTriggers.Elements())) + notifyTriggers.ElementsAs(ctx, ¬ifyTriggersElements, false) + for _, v := range notifyTriggersElements { + trigggers = append(trigggers, 
resourceMonitorTriggerModel{ + Threshold: v, + TriggerAction: types.StringValue("NOTIFY"), + }) + } + } + triggersObjectType := types.ObjectType{}.WithAttributeTypes(map[string]attr.Type{ + "threshold": types.Int64Type, + "trigger_action": types.StringType, + }) + triggersSet, diags := types.SetValueFrom(ctx, triggersObjectType, trigggers) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + + resourceMonitorV1 := &resourceMonitorModelV1{ + Name: name, + NotifyUsers: notifyUsers, + CreditQuota: creditQuota, + Frequency: frequency, + StartTimestamp: startTimestamp, + EndTimestamp: endTimestamp, + Triggers: triggersSet, + Id: resourceMonitorDataV0.Id, + } + resp.Diagnostics.Append(resp.State.Set(ctx, resourceMonitorV1)...) +} + +type resourceMonitorModelV1 struct { + OrReplace types.Bool `tfsdk:"or_replace"` + Name types.String `tfsdk:"name"` + CreditQuota types.Float64 `tfsdk:"credit_quota"` + UsedCredits types.Float64 `tfsdk:"used_credits"` + RemainingCredits types.Float64 `tfsdk:"remaining_credits"` + Frequency types.String `tfsdk:"frequency"` + StartTimestamp types.String `tfsdk:"start_timestamp"` + EndTimestamp types.String `tfsdk:"end_timestamp"` + Level types.String `tfsdk:"level"` + NotifyUsers types.Set `tfsdk:"notify_users"` + Triggers types.Set `tfsdk:"triggers"` + Id types.String `tfsdk:"id"` +} + +func resourceMonitorSchemaV1() schema.Schema { + return schema.Schema{ + Description: "Snowflake resource monitor resource", + Version: 1, + Attributes: map[string]schema.Attribute{ + "id": schema.StringAttribute{ + Computed: true, + MarkdownDescription: "ID of the database", + PlanModifiers: []planmodifier.String{ + stringplanmodifier.UseStateForUnknown(), + }, + }, + "or_replace": schema.BoolAttribute{ + Description: "Specifies whether to replace the resource monitor if it exists and has the same name as the one being created", + Optional: true, + Computed: true, + Sensitive: 
isSensitive("snowflake_resource_monitor.*.or_replace"), + Default: booldefault.StaticBool(false), + PlanModifiers: []planmodifier.Bool{ + boolplanmodifier.UseStateForUnknown(), + }, + }, + "name": schema.StringAttribute{ + Description: "Specifies the object identifier for the database", + Required: true, + Sensitive: isSensitive("snowflake_resource_monitor.*.name"), + PlanModifiers: []planmodifier.String{ + stringplanmodifier.RequiresReplace(), + }, + }, + "credit_quota": schema.Float64Attribute{ + Description: "The number of credits allocated to the resource monitor per frequency interval.", + Optional: true, + Computed: true, + Default: float64default.StaticFloat64(0), + Sensitive: isSensitive("snowflake_resource_monitor.*.credit_quota"), + }, + "used_credits": schema.Float64Attribute{ + Description: "The number of credits used by the resource monitor.", + Computed: true, + PlanModifiers: []planmodifier.Float64{ + float64planmodifier.UseStateForUnknown(), + }, + }, + "remaining_credits": schema.Float64Attribute{ + Description: "The number of credits remaining for the resource monitor.", + Computed: true, + PlanModifiers: []planmodifier.Float64{ + float64planmodifier.UseStateForUnknown(), + }, + }, + "level": schema.StringAttribute{ + Description: "resource monitor level", + Computed: true, + PlanModifiers: []planmodifier.String{ + stringplanmodifier.UseStateForUnknown(), + }, + }, + "frequency": schema.StringAttribute{ + Description: "Specifies the maximum number of days to extend the Fail-safe storage retention period for the database", + Optional: true, + Computed: true, + Sensitive: isSensitive("snowflake_resource_monitor.*.frequency"), + Validators: []validator.String{ + stringvalidator.OneOf([]string{"MONTHLY", "DAILY", "WEEKLY", "YEARLY", "NEVER"}...), + stringvalidator.AlsoRequires(path.MatchRoot("start_timestamp")), + }, + }, + "start_timestamp": schema.StringAttribute{ + Description: "Specifies the start time of the resource monitor", + Optional: true, + 
Computed: true, + Validators: []validator.String{ + stringvalidator.AlsoRequires(path.MatchRoot("frequency")), + }, + }, + "end_timestamp": schema.StringAttribute{ + Description: "Specifies the end time of the resource monitor", + Optional: true, + }, + "notify_users": schema.SetAttribute{ + Description: "Specifies the list of users to receive email notifications on resource monitors", + Optional: true, + ElementType: types.StringType, + }, + "triggers": schema.SetNestedAttribute{ + Description: "Specifies the list of triggers to receive email notifications on resource monitors", + Optional: true, + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "threshold": schema.Int64Attribute{ + Description: "Specifies the percentage of credits used to trigger an email notification", + Required: true, + Validators: []validator.Int64{ + int64validator.AtLeast(0), + }, + }, + "trigger_action": schema.StringAttribute{ + Description: "Specifies the action to take when the trigger is activated", + Required: true, + Validators: []validator.String{ + stringvalidator.OneOf([]string{"SUSPEND", "SUSPEND_IMMEDIATE", "NOTIFY"}...), + }, + }, + }, + }, + }, + }, + } +} + +type resourceMonitorTriggerModel struct { + Threshold types.Int64 `tfsdk:"threshold"` + TriggerAction types.String `tfsdk:"trigger_action"` +} + +func (old *resourceMonitorModelV1) Equals(new *resourceMonitorModelV1, ctx context.Context) bool { + if old == nil || new == nil { + return false + } + if !old.Id.Equal(new.Id) { + return false + } + if !old.OrReplace.Equal(new.OrReplace) { + return false + } + if !old.Name.Equal(new.Name) { + return false + } + if !old.CreditQuota.Equal(new.CreditQuota) { + return false + } + if !old.Frequency.Equal(new.Frequency) { + return false + } + if !old.StartTimestamp.Equal(new.StartTimestamp) { + return false + } + if !old.EndTimestamp.Equal(new.EndTimestamp) { + return false + } + if !old.Triggers.Equal(new.Triggers) { + return false + } + if 
!old.NotifyUsers.Equal(new.NotifyUsers) { + return false + } + + return true +} + +func (r *ResourceMonitorResource) Metadata(ctx context.Context, req resource.MetadataRequest, resp *resource.MetadataResponse) { + resp.TypeName = req.ProviderTypeName + "_resource_monitor" +} + +func (r *ResourceMonitorResource) Schema(ctx context.Context, req resource.SchemaRequest, resp *resource.SchemaResponse) { + resp.Schema = resourceMonitorSchemaV1() +} + +func (r *ResourceMonitorResource) UpgradeState(ctx context.Context) map[int64]resource.StateUpgrader { + schemaV0 := resourceMonitorSchemaV0() + return map[int64]resource.StateUpgrader{ + // State upgrade implementation from 0 to 1 + 0: { + PriorSchema: &schemaV0, + StateUpgrader: upgradeResourceMonitorStateV0toV1, + }, + } +} + +func (r *ResourceMonitorResource) Configure(ctx context.Context, req resource.ConfigureRequest, resp *resource.ConfigureResponse) { + // Prevent panic if the provider has not been configured. + if req.ProviderData == nil { + return + } + + providerData, ok := req.ProviderData.(*ProviderData) + + if !ok { + resp.Diagnostics.AddError( + "Unexpected Resource Configure Type", + fmt.Sprintf("Expected *sdk.Client, got: %T. Please report this issue to the provider developers.", req.ProviderData), + ) + + return + } + + r.client = providerData.client +} + +func (r *ResourceMonitorResource) ModifyPlan(ctx context.Context, req resource.ModifyPlanRequest, resp *resource.ModifyPlanResponse) { + // we aren't really modifying the plan, just logging what the plan intends to do + resp.Plan = req.Plan + var plan, state *resourceMonitorModelV1 + resp.Diagnostics.Append(req.Plan.Get(ctx, &plan)...) + resp.Diagnostics.Append(req.State.Get(ctx, &state)...) 
+ if resp.Diagnostics.HasError() { + return + } + resourceName := "snowflake_database" + // DELETE + if req.Plan.Raw.IsNull() { + _, readLogs, _ := r.read(ctx, state, true) + _, deleteLogs, _ := r.delete(ctx, state, true) + deleteLogs = append(deleteLogs, readLogs...) + tflog.Debug(ctx, formatSQLPreview(DeleteOperation, resourceName, state.Id.ValueString(), deleteLogs)) + return + } + + // CREATE + if plan.Id.IsUnknown() { + _, createLogs, _ := r.create(ctx, plan, true) + plan.Id = types.StringValue(sdk.NewAccountObjectIdentifier(plan.Name.ValueString()).FullyQualifiedName()) + _, readLogs, _ := r.read(ctx, plan, true) + createLogs = append(createLogs, readLogs...) + tflog.Debug(ctx, formatSQLPreview(CreateOperation, resourceName, "", createLogs)) + return + } + + if plan.Equals(state, ctx) { + // READ + _, logs, _ := r.read(ctx, state, true) + tflog.Debug(ctx, formatSQLPreview(ReadOperation, resourceName, state.Id.ValueString(), logs)) + return + } else { + // UPDATE + _, updateLogs, _ := r.update(ctx, plan, state, true) + _, readLogs, _ := r.read(ctx, plan, true) + updateLogs = append(updateLogs, readLogs...) + tflog.Debug(ctx, formatSQLPreview(UpdateOperation, resourceName, state.Id.ValueString(), updateLogs)) + } +} + +func (r *ResourceMonitorResource) Create(ctx context.Context, req resource.CreateRequest, resp *resource.CreateResponse) { + var data *resourceMonitorModelV1 + resp.Diagnostics.Append(req.Plan.Get(ctx, &data)...) + data, _, diags := r.create(ctx, data, false) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + resp.Diagnostics.Append(resp.State.Set(ctx, &data)...) 
+}
+
+func (r *ResourceMonitorResource) create(ctx context.Context, data *resourceMonitorModelV1, dryRun bool) (*resourceMonitorModelV1, []string, diag.Diagnostics) {
+ diags := diag.Diagnostics{}
+ client := r.client
+ if dryRun {
+ client = sdk.NewDryRunClient()
+ }
+
+ name := data.Name.ValueString()
+
+ id := sdk.NewAccountObjectIdentifier(name)
+
+ opts := &sdk.CreateResourceMonitorOptions{
+ OrReplace: data.OrReplace.ValueBoolPointer(),
+ }
+
+ with := &sdk.ResourceMonitorWith{}
+ setWith := false
+ if !data.CreditQuota.IsNull() && !data.CreditQuota.IsUnknown() && data.CreditQuota.ValueFloat64() > 0 {
+ setWith = true
+ with.CreditQuota = sdk.Int(int(data.CreditQuota.ValueFloat64()))
+ }
+ if !data.Frequency.IsNull() && data.Frequency.ValueString() != "" {
+ setWith = true
+ frequency, err := sdk.FrequencyFromString(data.Frequency.ValueString())
+ if err != nil {
+ diags.AddError("Client Error", fmt.Sprintf("Unable to create resource monitor, got error: %s", err))
+ }
+ with.Frequency = frequency
+ }
+ if !data.StartTimestamp.IsNull() && data.StartTimestamp.ValueString() != "" {
+ setWith = true
+ with.StartTimestamp = data.StartTimestamp.ValueStringPointer()
+ }
+
+ if !data.EndTimestamp.IsNull() && data.EndTimestamp.ValueString() != "" {
+ setWith = true
+ with.EndTimestamp = data.EndTimestamp.ValueStringPointer()
+ }
+
+ if !data.NotifyUsers.IsNull() && len(data.NotifyUsers.Elements()) > 0 {
+ setWith = true
+ elements := make([]types.String, 0, len(data.NotifyUsers.Elements())); data.NotifyUsers.ElementsAs(ctx, &elements, false)
+ var notifiedUsers []sdk.NotifiedUser
+ for _, e := range elements {
+ notifiedUsers = append(notifiedUsers, sdk.NotifiedUser{Name: e.ValueString()})
+ }
+ with.NotifyUsers = &sdk.NotifyUsers{
+ Users: notifiedUsers,
+ }
+ }
+
+ if !data.Triggers.IsNull() && len(data.Triggers.Elements()) > 0 {
+ setWith = true
+ elements := make([]resourceMonitorTriggerModel, 0, len(data.Triggers.Elements()))
+ data.Triggers.ElementsAs(ctx, &elements, false)
+ var triggers []sdk.TriggerDefinition
+ for _, e := range elements { + triggers = append(triggers, sdk.TriggerDefinition{ + Threshold: int(e.Threshold.ValueInt64()), + TriggerAction: sdk.TriggerAction(e.TriggerAction.ValueString()), + }) + } + with.Triggers = triggers + } + + if setWith { + opts.With = with + } + err := client.ResourceMonitors.Create(ctx, id, opts) + + if dryRun { + return data, client.TraceLogs(), diags + } + if err != nil { + diags.AddError("Client Error", fmt.Sprintf("Unable to create resource monitor, got error: %s", err)) + } + + data.Id = types.StringValue(id.FullyQualifiedName()) + r.read(ctx, data, false) + return data, nil, diags +} + +func (r *ResourceMonitorResource) Read(ctx context.Context, req resource.ReadRequest, resp *resource.ReadResponse) { + var data *resourceMonitorModelV1 + resp.Diagnostics.Append(req.State.Get(ctx, &data)...) + if resp.Diagnostics.HasError() { + return + } + data, _, diags := r.read(ctx, data, false) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + diags.Append(resp.State.Set(ctx, &data)...) 
+}
+
+func (r *ResourceMonitorResource) read(ctx context.Context, data *resourceMonitorModelV1, dryRun bool) (*resourceMonitorModelV1, []string, diag.Diagnostics) {
+ diags := diag.Diagnostics{}
+ client := r.client
+ if dryRun {
+ client = sdk.NewDryRunClient()
+ }
+
+ id := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(data.Id.ValueString())
+ resourceMonitor, err := client.ResourceMonitors.ShowByID(ctx, id)
+ if dryRun {
+ return data, client.TraceLogs(), diags
+ }
+ if err != nil {
+ diags.AddError("Client Error", fmt.Sprintf("Unable to read resource monitor, got error: %s", err))
+ return data, nil, diags
+ }
+
+ data.CreditQuota = types.Float64Value(resourceMonitor.CreditQuota)
+ data.Frequency = types.StringValue(string(resourceMonitor.Frequency))
+ switch resourceMonitor.Level {
+ case sdk.ResourceMonitorLevelAccount:
+ data.Level = types.StringValue("ACCOUNT")
+ case sdk.ResourceMonitorLevelWarehouse:
+ data.Level = types.StringValue("WAREHOUSE")
+ case sdk.ResourceMonitorLevelNull:
+ data.Level = types.StringValue("NULL")
+ }
+ data.UsedCredits = types.Float64Value(resourceMonitor.UsedCredits)
+ data.RemainingCredits = types.Float64Value(resourceMonitor.RemainingCredits)
+
+ if resourceMonitor.StartTime != "" {
+ if data.StartTimestamp.ValueString() != "IMMEDIATELY" {
+ data.StartTimestamp = types.StringValue(resourceMonitor.StartTime)
+ }
+ } else {
+ data.StartTimestamp = types.StringNull()
+ }
+ if resourceMonitor.EndTime != "" {
+ data.EndTimestamp = types.StringValue(resourceMonitor.EndTime)
+ }
+ if len(resourceMonitor.NotifyUsers) == 0 {
+ data.NotifyUsers = types.SetNull(types.StringType)
+ } else {
+ var notifyUsers []types.String
+ for _, e := range resourceMonitor.NotifyUsers {
+ notifyUsers = append(notifyUsers, types.StringValue(e))
+ }
+ var diag diag.Diagnostics
+ data.NotifyUsers, diag = types.SetValueFrom(ctx, types.StringType, notifyUsers)
+ diags = append(diags, diag...)
+ } + + triggersObjectType := types.ObjectType{}.WithAttributeTypes(map[string]attr.Type{ + "threshold": types.Int64Type, + "trigger_action": types.StringType, + }) + if len(resourceMonitor.NotifyTriggers) == 0 && resourceMonitor.SuspendAt == nil && resourceMonitor.SuspendImmediateAt == nil { + data.Triggers = types.SetNull(triggersObjectType) + } else { + var triggers []resourceMonitorTriggerModel + for _, e := range resourceMonitor.NotifyTriggers { + triggers = append(triggers, resourceMonitorTriggerModel{ + Threshold: types.Int64Value(int64(e)), + TriggerAction: types.StringValue(string(sdk.TriggerActionNotify)), + }) + } + if resourceMonitor.SuspendAt != nil { + triggers = append(triggers, resourceMonitorTriggerModel{ + Threshold: types.Int64Value(int64(*resourceMonitor.SuspendAt)), + TriggerAction: types.StringValue(string(sdk.TriggerActionSuspend)), + }) + } + + var diag diag.Diagnostics + data.Triggers, diag = types.SetValueFrom(ctx, triggersObjectType, triggers) + diags = append(diags, diag...) + } + + data.Id = types.StringValue(id.FullyQualifiedName()) + return data, nil, diags +} + +func (r *ResourceMonitorResource) Update(ctx context.Context, req resource.UpdateRequest, resp *resource.UpdateResponse) { + var plan, state resourceMonitorModelV1 + resp.Diagnostics.Append(req.Plan.Get(ctx, &plan)...) + resp.Diagnostics.Append(req.State.Get(ctx, &state)...) + if resp.Diagnostics.HasError() { + return + } + + data, _, diags := r.update(ctx, &plan, &state, false) + resp.Diagnostics.Append(diags...) + if resp.Diagnostics.HasError() { + return + } + diags.Append(resp.State.Set(ctx, &data)...) 
+} + +func (r *ResourceMonitorResource) update(ctx context.Context, plan *resourceMonitorModelV1, state *resourceMonitorModelV1, dryRun bool) (*resourceMonitorModelV1, []string, diag.Diagnostics) { + diags := diag.Diagnostics{} + client := r.client + if dryRun { + client = sdk.NewDryRunClient() + } + id := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(state.Id.ValueString()) + opts := &sdk.AlterResourceMonitorOptions{} + runUpdate := false + if !plan.CreditQuota.Equal(state.CreditQuota) { + runUpdate = true + if opts.Set == nil { + opts.Set = &sdk.ResourceMonitorSet{} + } + opts.Set.CreditQuota = sdk.Int(int(plan.CreditQuota.ValueFloat64())) + } + if !plan.Frequency.Equal(state.Frequency) { + runUpdate = true + if opts.Set == nil { + opts.Set = &sdk.ResourceMonitorSet{} + } + frequency, err := sdk.FrequencyFromString(plan.Frequency.ValueString()) + if err != nil { + diags.AddError("Client Error", fmt.Sprintf("Unable to update resource monitor, got error: %s", err)) + return plan, nil, diags + } + opts.Set.Frequency = frequency + opts.Set.StartTimestamp = plan.StartTimestamp.ValueStringPointer() + } + if !plan.StartTimestamp.Equal(state.StartTimestamp) { + runUpdate = true + if opts.Set == nil { + opts.Set = &sdk.ResourceMonitorSet{} + } + frequency, err := sdk.FrequencyFromString(plan.Frequency.ValueString()) + if err != nil { + diags.AddError("Client Error", fmt.Sprintf("Unable to update resource monitor, got error: %s", err)) + return plan, nil, diags + } + opts.Set.Frequency = frequency + opts.Set.StartTimestamp = plan.StartTimestamp.ValueStringPointer() + } + if !plan.EndTimestamp.Equal(state.EndTimestamp) && plan.EndTimestamp.ValueString() != "" { + runUpdate = true + if opts.Set == nil { + opts.Set = &sdk.ResourceMonitorSet{} + } + opts.Set.EndTimestamp = plan.EndTimestamp.ValueStringPointer() + } + if !plan.NotifyUsers.Equal(state.NotifyUsers) { + runUpdate = true + var notifiedUsers []sdk.NotifiedUser + elements := make([]types.String, 0, 
len(plan.NotifyUsers.Elements())) + plan.NotifyUsers.ElementsAs(ctx, &elements, false) + for _, e := range elements { + notifiedUsers = append(notifiedUsers, sdk.NotifiedUser{Name: e.ValueString()}) + } + opts.NotifyUsers = &sdk.NotifyUsers{ + Users: notifiedUsers, + } + } + + if !plan.Triggers.Equal(state.Triggers) { + runUpdate = true + var triggers []sdk.TriggerDefinition + elements := make([]resourceMonitorTriggerModel, 0, len(plan.Triggers.Elements())) + plan.Triggers.ElementsAs(ctx, &elements, false) + for _, e := range elements { + triggers = append(triggers, sdk.TriggerDefinition{ + Threshold: int(e.Threshold.ValueInt64()), + TriggerAction: sdk.TriggerAction(e.TriggerAction.ValueString()), + }) + } + opts.Triggers = triggers + } + + if runUpdate { + err := client.ResourceMonitors.Alter(ctx, id, opts) + if dryRun { + return plan, client.TraceLogs(), diags + } + if err != nil { + diags.AddError("Client Error", fmt.Sprintf("Unable to update resource monitor, got error: %s", err)) + return plan, nil, diags + } + } + data, _, readDiags := r.read(ctx, plan, false) + diags.Append(readDiags...) + return data, nil, diags +} + +func (r *ResourceMonitorResource) Delete(ctx context.Context, req resource.DeleteRequest, resp *resource.DeleteResponse) { + var data *resourceMonitorModelV1 + resp.Diagnostics.Append(req.State.Get(ctx, &data)...) + if resp.Diagnostics.HasError() { + return + } + _, _, diags := r.delete(ctx, data, false) + resp.Diagnostics.Append(diags...) 
+ if resp.Diagnostics.HasError() {
+ return
+ }
+}
+
+func (r *ResourceMonitorResource) delete(ctx context.Context, data *resourceMonitorModelV1, dryRun bool) (*resourceMonitorModelV1, []string, diag.Diagnostics) {
+ client := r.client
+ if dryRun {
+ client = sdk.NewDryRunClient()
+ }
+
+ diags := diag.Diagnostics{}
+ id := sdk.NewAccountObjectIdentifierFromFullyQualifiedName(data.Id.ValueString())
+ err := client.ResourceMonitors.Drop(ctx, id)
+ if dryRun {
+ return data, client.TraceLogs(), diags
+ }
+ if err != nil {
+ diags.AddError("Client Error", fmt.Sprintf("Unable to delete resource monitor, got error: %s", err))
+ return data, nil, diags
+ }
+ return data, nil, diags
+}
+
+func (r *ResourceMonitorResource) ImportState(ctx context.Context, req resource.ImportStateRequest, resp *resource.ImportStateResponse) {
+ resource.ImportStatePassthroughID(ctx, path.Root("id"), req, resp)
+}
diff --git a/go.mod b/go.mod
index f5c52822c6..15cef39951 100644
--- a/go.mod
+++ b/go.mod
@@ -71,6 +71,7 @@ require (
 github.com/golang/snappy v0.0.4 // indirect
 github.com/google/flatbuffers v23.5.26+incompatible // indirect
 github.com/google/go-cmp v0.5.9 // indirect
+ github.com/gookit/color v1.5.4 // indirect
 github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect
 github.com/hashicorp/errwrap v1.1.0 // indirect
 github.com/hashicorp/go-checkpoint v0.5.0 // indirect
@@ -118,6 +119,7 @@ require (
 github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect
 github.com/vmihailenco/msgpack/v5 v5.4.0 // indirect
 github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
+ github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 // indirect
 github.com/zclconf/go-cty v1.14.0 // indirect
 github.com/zeebo/xxh3 v1.0.2 // indirect
 golang.org/x/mod v0.13.0 // indirect
diff --git a/go.sum b/go.sum
index 8f217b4c9b..e882f361cd 100644
--- a/go.sum
+++ b/go.sum
@@ -140,6 +140,8 @@ github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0= +github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w= github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c h1:6rhixN/i8ZofjG1Y75iExal34USq5p+wiN1tpie8IrU= github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c/go.mod h1:NMPJylDgVpX0MLRlPy15sqSwOFv/U1GZ2m21JhFfek0= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= @@ -312,6 +314,8 @@ github.com/vmihailenco/msgpack/v5 v5.4.0/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21 github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= +github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778 h1:QldyIu/L63oPpyvQmHgvgickp1Yw510KJOqX7H24mg8= +github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs= github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a h1:fZHgsYlfvtyqToslyjUt3VOPF4J7aK/3MPcK7xp3PDk= github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a/go.mod h1:ul22v+Nro/R083muKhosV54bj5niojjWZvU8xrevuH4= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= diff --git a/main.go b/main.go index 18cc74af8c..d0118612ae 100644 --- a/main.go +++ b/main.go @@ -1,21 +1,63 @@ package main import ( + "context" "flag" + "log" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider" - "github.com/hashicorp/terraform-plugin-sdk/v2/plugin" + oldprovider 
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider" + + "github.com/hashicorp/terraform-plugin-go/tfprotov6" + "github.com/hashicorp/terraform-plugin-go/tfprotov6/tf6server" + "github.com/hashicorp/terraform-plugin-mux/tf5to6server" + "github.com/hashicorp/terraform-plugin-mux/tf6muxserver" ) -const ProviderAddr = "registry.terraform.io/Snowflake-Labs/snowflake" +var version string = "dev" // goreleaser can pass other information to the main package, such as the specific commit +// https://goreleaser.com/cookbooks/using-main.version/ func main() { - debug := flag.Bool("debug", false, "set to true to run the provider with support for debuggers like delve") + ctx := context.Background() + + var debug bool + + flag.BoolVar(&debug, "debug", false, "set to true to run the provider with support for debuggers like delve") flag.Parse() - plugin.Serve(&plugin.ServeOpts{ - Debug: *debug, - ProviderAddr: ProviderAddr, - ProviderFunc: provider.Provider, - }) + upgradedSdkServer, err := tf5to6server.UpgradeServer( + ctx, + oldprovider.Provider().GRPCProvider, + ) + if err != nil { + log.Fatal(err) + } + + providers := []func() tfprotov6.ProviderServer{ + // disabled until ready to start using + // providerserver.NewProtocol6(provider.New(version)()), + func() tfprotov6.ProviderServer { + return upgradedSdkServer + }, + } + + muxServer, err := tf6muxserver.NewMuxServer(ctx, providers...) 
+ if err != nil { + log.Fatal(err) + } + + var serveOpts []tf6server.ServeOpt + + if debug { + serveOpts = append(serveOpts, tf6server.WithManagedDebug()) + } + + err = tf6server.Serve( + "registry.terraform.io/Snowflake-Labs/snowflake", + muxServer.ProviderServer, + serveOpts..., + ) + + if err != nil { + log.Fatal(err) + } } diff --git a/pkg/provider/provider.go b/pkg/provider/provider.go index 3453616873..e2270cf62c 100644 --- a/pkg/provider/provider.go +++ b/pkg/provider/provider.go @@ -61,7 +61,7 @@ func Provider() *schema.Provider { }, "validate_default_parameters": { Type: schema.TypeBool, - Description: "If true, disables the validation checks for Database, Schema, Warehouse and Role at the time a connection is established. Can also be sourced from the `SNOWFLAKE_VALIDATE_DEFAULT_PARAMETERS` environment variable.", + Description: "True by default. If false, disables the validation checks for Database, Schema, Warehouse and Role at the time a connection is established. Can also be sourced from the `SNOWFLAKE_VALIDATE_DEFAULT_PARAMETERS` environment variable.", Optional: true, DefaultFunc: schema.EnvDefaultFunc("SNOWFLAKE_VALIDATE_DEFAULT_PARAMETERS", nil), }, @@ -180,7 +180,7 @@ func Provider() *schema.Provider { Optional: true, DefaultFunc: schema.EnvDefaultFunc("SNOWFLAKE_INSECURE_MODE", nil), }, - "oscp_fail_open": { + "ocsp_fail_open": { Type: schema.TypeBool, Description: "True represents OCSP fail open mode. False represents OCSP fail closed mode. Fail open true by default. Can also be sourced from the `SNOWFLAKE_OCSP_FAIL_OPEN` environment variable.", Optional: true, @@ -372,13 +372,12 @@ func Provider() *schema.Provider { Deprecated: "Use `token_accessor.0.redirect_uri` instead", }, "browser_auth": { - Type: schema.TypeBool, - Description: "Required when `oauth_refresh_token` is used. 
Can also be sourced from `SNOWFLAKE_USE_BROWSER_AUTH` environment variable.", - Optional: true, - Sensitive: false, - DefaultFunc: schema.EnvDefaultFunc("SNOWFLAKE_USE_BROWSER_AUTH", nil), - Deprecated: "Use `authenticator` instead", - ConflictsWith: []string{"password", "private_key_path", "private_key", "private_key_passphrase", "oauth_access_token", "oauth_refresh_token"}, + Type: schema.TypeBool, + Description: "Required when `oauth_refresh_token` is used. Can also be sourced from `SNOWFLAKE_USE_BROWSER_AUTH` environment variable.", + Optional: true, + Sensitive: false, + DefaultFunc: schema.EnvDefaultFunc("SNOWFLAKE_USE_BROWSER_AUTH", nil), + Deprecated: "Use `authenticator` instead", }, "private_key_path": { Type: schema.TypeString, @@ -610,25 +609,7 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) { } if v, ok := s.GetOk("authenticator"); ok && v.(string) != "" { - authenticator := v.(string) - switch authenticator { - case "Snowflake": - config.Authenticator = gosnowflake.AuthTypeSnowflake - case "OAuth": - config.Authenticator = gosnowflake.AuthTypeOAuth - case "ExternalBrowser": - config.Authenticator = gosnowflake.AuthTypeExternalBrowser - case "Okta": - config.Authenticator = gosnowflake.AuthTypeOkta - case "JWT": - config.Authenticator = gosnowflake.AuthTypeJwt - case "TokenAccessor": - config.Authenticator = gosnowflake.AuthTypeTokenAccessor - case "UsernamePasswordMFA": - config.Authenticator = gosnowflake.AuthTypeUsernamePasswordMFA - default: - return nil, fmt.Errorf("invalid authenticator %s", authenticator) - } + config.Authenticator = toAuthenticatorType(v.(string)) } if v, ok := s.GetOk("passcode"); ok && v.(string) != "" { @@ -674,7 +655,7 @@ func ConfigureProvider(s *schema.ResourceData) (interface{}, error) { config.InsecureMode = v.(bool) } - if v, ok := s.GetOk("oscp_fail_open"); ok && v.(bool) { + if v, ok := s.GetOk("ocsp_fail_open"); ok && v.(bool) { config.OCSPFailOpen = gosnowflake.OCSPFailOpenTrue } diff --git 
a/pkg/provider/provider_helpers.go b/pkg/provider/provider_helpers.go index a7a18c6485..dfa0494c44 100644 --- a/pkg/provider/provider_helpers.go +++ b/pkg/provider/provider_helpers.go @@ -15,6 +15,7 @@ import ( "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/mitchellh/go-homedir" + "github.com/snowflakedb/gosnowflake" "github.com/youmark/pkcs8" "golang.org/x/crypto/ssh" ) @@ -44,6 +45,54 @@ func getPrivateKey(privateKeyPath, privateKeyString, privateKeyPassphrase string return parsePrivateKey(privateKeyBytes, []byte(privateKeyPassphrase)) } +func toAuthenticatorType(authenticator string) gosnowflake.AuthType { + switch authenticator { + case "Snowflake": + return gosnowflake.AuthTypeSnowflake + case "OAuth": + return gosnowflake.AuthTypeOAuth + case "ExternalBrowser": + return gosnowflake.AuthTypeExternalBrowser + case "Okta": + return gosnowflake.AuthTypeOkta + case "JWT": + return gosnowflake.AuthTypeJwt + case "TokenAccessor": + return gosnowflake.AuthTypeTokenAccessor + case "UsernamePasswordMFA": + return gosnowflake.AuthTypeUsernamePasswordMFA + default: + return gosnowflake.AuthTypeSnowflake + } +} + +func getInt64Env(key string, defaultValue int64) int64 { + s := os.Getenv(key) + if s == "" { + return defaultValue + } + i, err := strconv.Atoi(s) + if err != nil { + return defaultValue + } + return int64(i) +} + +func getBoolEnv(key string, defaultValue bool) bool { + s := strings.ToLower(os.Getenv(key)) + if s == "" { + return defaultValue + } + switch s { + case "true", "1": + return true + case "false", "0": + return false + default: + return defaultValue + } +} + func readFile(privateKeyPath string) ([]byte, error) { expandedPrivateKeyPath, err := homedir.Expand(privateKeyPath) if err != nil { diff --git a/pkg/sdk/client.go b/pkg/sdk/client.go index eb056c0ec3..777479e95d 100644 --- a/pkg/sdk/client.go +++ b/pkg/sdk/client.go @@ -18,6 +18,8 @@ type Client struct { db *sqlx.DB sessionID string accountLocator string + dryRun bool + 
traceLogs []string // System-Defined Functions ContextFunctions ContextFunctions @@ -72,6 +74,15 @@ func NewDefaultClient() (*Client, error) { return NewClient(nil) } +func NewDryRunClient() *Client { + client := &Client{ + dryRun: true, + traceLogs: []string{}, + } + client.initialize() + return client +} + func NewClient(cfg *gosnowflake.Config) (*Client, error) { var err error if cfg == nil { @@ -175,6 +186,10 @@ func (c *Client) initialize() { c.Warehouses = &warehouses{client: c} } +func (c *Client) TraceLogs() []string { + return c.traceLogs +} + func (c *Client) Ping() error { return c.db.Ping() } @@ -194,6 +209,11 @@ const ( // Exec executes a query that does not return rows. func (c *Client) exec(ctx context.Context, sql string) (sql.Result, error) { + if c.dryRun { + c.traceLogs = append(c.traceLogs, sql) + log.Printf("[DEBUG] sql-conn-exec-dry: %v\n", sql) + return nil, nil + } ctx = context.WithValue(ctx, snowflakeAccountLocatorContextKey, c.accountLocator) result, err := c.db.ExecContext(ctx, sql) return result, decodeDriverError(err) @@ -201,12 +221,22 @@ func (c *Client) exec(ctx context.Context, sql string) (sql.Result, error) { // query runs a query and returns the rows. dest is expected to be a slice of structs. func (c *Client) query(ctx context.Context, dest interface{}, sql string) error { + if c.dryRun { + c.traceLogs = append(c.traceLogs, sql) + log.Printf("[DEBUG] sql-conn-query-dry: %v\n", sql) + return nil + } ctx = context.WithValue(ctx, snowflakeAccountLocatorContextKey, c.accountLocator) return decodeDriverError(c.db.SelectContext(ctx, dest, sql)) } // queryOne runs a query and returns one row. dest is expected to be a pointer to a struct. 
func (c *Client) queryOne(ctx context.Context, dest interface{}, sql string) error { + if c.dryRun { + c.traceLogs = append(c.traceLogs, sql) + log.Printf("[DEBUG] sql-conn-query-one-dry: %v\n", sql) + return nil + } ctx = context.WithValue(ctx, snowflakeAccountLocatorContextKey, c.accountLocator) return decodeDriverError(c.db.GetContext(ctx, dest, sql)) }