From 0d49c7e8e76f9a4ccb3313162d9994c083865ffa Mon Sep 17 00:00:00 2001 From: liurong Date: Mon, 18 Dec 2023 19:46:09 +0800 Subject: [PATCH] change configFilterChain from public sigleton to private owned by ConfigClient --- clients/config_client/config_client.go | 84 ++++----------- common/encryption/const.go | 4 +- common/encryption/handler.go | 118 +++++++++++++------- common/encryption/kms_client.go | 68 ++++-------- common/encryption/kms_plugins.go | 126 +++++++++++++--------- common/filter/config_encryption_filter.go | 42 ++++---- common/filter/config_filter.go | 34 +----- example/config-mse-kmsv3/main.go | 36 ++----- 8 files changed, 233 insertions(+), 279 deletions(-) diff --git a/clients/config_client/config_client.go b/clients/config_client/config_client.go index 3886f864..0df404db 100644 --- a/clients/config_client/config_client.go +++ b/clients/config_client/config_client.go @@ -20,12 +20,9 @@ import ( "context" "fmt" "os" - "strings" "sync" "time" - "github.com/alibabacloud-go/tea/tea" - dkms_api "github.com/aliyun/alibabacloud-dkms-gcs-go-sdk/openapi" "github.com/nacos-group/nacos-sdk-go/v2/clients/cache" "github.com/nacos-group/nacos-sdk-go/v2/clients/nacos_client" "github.com/nacos-group/nacos-sdk-go/v2/common/constant" @@ -52,15 +49,15 @@ type ConfigClient struct { ctx context.Context cancel context.CancelFunc nacos_client.INacosClient - kmsClient *nacos_inner_encryption.KmsClient - localConfigs []vo.ConfigParam - mutex sync.Mutex - configProxy IConfigProxy - configCacheDir string - lastAllSyncTime time.Time - cacheMap cache.ConcurrentMap - uid string - listenExecute chan struct{} + configFilterChainManager filter.IConfigFilterChain + localConfigs []vo.ConfigParam + mutex sync.Mutex + configProxy IConfigProxy + configCacheDir string + lastAllSyncTime time.Time + cacheMap cache.ConcurrentMap + uid string + listenExecute chan struct{} } type cacheData struct { @@ -94,7 +91,7 @@ func (cacheData *cacheData) executeListener() { EncryptedDataKey: cacheData.encryptedDataKey, UsageType: vo.ResponseType, } - if err := filter.GetDefaultConfigFilterChainManager().DoFilters(param); err != nil { + if err := cacheData.configClient.configFilterChainManager.DoFilters(param); err != nil { logger.Errorf("do filters failed ,dataId=%s,group=%s,tenant=%s,err:%+v ", cacheData.dataId, cacheData.group, cacheData.tenant, err) return @@ -130,22 +127,16 @@ func NewConfigClient(nc nacos_client.INacosClient) (*ConfigClient, error) { return nil, err } + config.configFilterChainManager = filter.NewConfigFilterChainManager() + if clientConfig.OpenKMS { - filter.RegisterDefaultConfigEncryptionFilter() - nacos_inner_encryption.RegisterConfigEncryptionKmsPlugins() - var kmsClient *nacos_inner_encryption.KmsClient - switch clientConfig.KMSVersion { - case constant.KMSv1, constant.DEFAULT_KMS_VERSION: - kmsClient, err = initKmsV1Client(clientConfig) - case constant.KMSv3: - kmsClient, err = initKmsV3Client(clientConfig) - default: - err = fmt.Errorf("init kms client failed. unknown kms version:%s\n", clientConfig.KMSVersion) - } + kmsEncryptionHandler := nacos_inner_encryption.NewKmsHandler() + nacos_inner_encryption.RegisterConfigEncryptionKmsPlugins(kmsEncryptionHandler, clientConfig) + encryptionFilter := filter.NewDefaultConfigEncryptionFilter(kmsEncryptionHandler) + err := filter.RegisterConfigFilterToChain(config.configFilterChainManager, encryptionFilter) if err != nil { - return nil, err + logger.Error(err) } - config.kmsClient = kmsClient } uid, err := uuid.NewV4() @@ -164,19 +155,6 @@ func initLogger(clientConfig constant.ClientConfig) error { return logger.InitLogger(logger.BuildLoggerConfig(clientConfig)) } -func initKmsV1Client(clientConfig constant.ClientConfig) (*nacos_inner_encryption.KmsClient, error) { - return nacos_inner_encryption.InitDefaultKmsV1ClientWithAccessKey(clientConfig.RegionId, clientConfig.AccessKey, clientConfig.SecretKey) -} - -func initKmsV3Client(clientConfig constant.ClientConfig) (*nacos_inner_encryption.KmsClient, error) { - return nacos_inner_encryption.InitDefaultKmsV3ClientWithConfig(&dkms_api.Config{ - Protocol: tea.String("https"), - Endpoint: tea.String(clientConfig.KMSv3Config.Endpoint), - ClientKeyContent: tea.String(clientConfig.KMSv3Config.ClientKeyContent), - Password: tea.String(clientConfig.KMSv3Config.Password), - }, clientConfig.KMSv3Config.CaContent) -} - func (client *ConfigClient) GetConfig(param vo.ConfigParam) (content string, err error) { content, encryptedDataKey, err := client.getConfigInner(param) if err != nil { @@ -186,37 +164,13 @@ func (client *ConfigClient) GetConfig(param vo.ConfigParam) (content string, err deepCopyParam.EncryptedDataKey = encryptedDataKey deepCopyParam.Content = content deepCopyParam.UsageType = vo.ResponseType - if err = filter.GetDefaultConfigFilterChainManager().DoFilters(deepCopyParam); err != nil { + if err = client.configFilterChainManager.DoFilters(deepCopyParam); err != nil { return "", err } content = deepCopyParam.Content return content, nil } -func (client *ConfigClient) decrypt(dataId, content string) (string, error) { - var plainContent string - var err error - if client.kmsClient != nil && strings.HasPrefix(dataId, "cipher-") { - plainContent, err = client.kmsClient.Decrypt(content) - if err != nil { - return "", fmt.Errorf("kms decrypt failed: %v", err) - } - } - return plainContent, nil -} - -func (client *ConfigClient) encrypt(dataId, content, kmsKeyId string) (string, error) { - var cipherContent string - var err error - if client.kmsClient != nil && strings.HasPrefix(dataId, "cipher-") { - cipherContent, err = client.kmsClient.Encrypt(content, kmsKeyId) - if err != nil { - return "", fmt.Errorf("kms encrypt failed: %v", err) - } - } - return cipherContent, nil -} - func (client *ConfigClient) getConfigInner(param vo.ConfigParam) (content, encryptedDataKey string, err error) { if len(param.DataId) <= 0 { err = errors.New("[client.GetConfig] param.dataId can not be empty") @@ -278,7 +232,7 @@ func (client *ConfigClient) PublishConfig(param vo.ConfigParam) (published bool, } param.UsageType = vo.RequestType - if err = filter.GetDefaultConfigFilterChainManager().DoFilters(¶m); err != nil { + if err = client.configFilterChainManager.DoFilters(¶m); err != nil { return false, err } diff --git a/common/encryption/const.go b/common/encryption/const.go index e3a769bc..6dfe9aee 100644 --- a/common/encryption/const.go +++ b/common/encryption/const.go @@ -29,12 +29,14 @@ const ( kmsAes256KeySpec = "AES_256" kmsScheme = "https" - kmsAcceptFormat = "JSON" + kmsAcceptFormat = "XML" kmsCipherAlgorithm = "AES/ECB/PKCS5Padding" maskUnit8Width = 8 maskUnit32Width = 32 + + KmsHandlerName = "KmsHandler" ) var ( diff --git a/common/encryption/handler.go b/common/encryption/handler.go index f13fb2df..11dae772 100644 --- a/common/encryption/handler.go +++ b/common/encryption/handler.go @@ -18,14 +18,12 @@ package encryption import ( "fmt" + "github.com/alibabacloud-go/tea/tea" + dkms_api "github.com/aliyun/alibabacloud-dkms-gcs-go-sdk/openapi" + "github.com/nacos-group/nacos-sdk-go/v2/common/constant" "github.com/nacos-group/nacos-sdk-go/v2/common/logger" + "github.com/pkg/errors" "strings" - "sync" -) - -var ( - initDefaultHandlerOnce = &sync.Once{} - defaultHandler *DefaultHandler ) type HandlerParam struct { @@ -49,33 +47,52 @@ type Handler interface { EncryptionHandler(*HandlerParam) error DecryptionHandler(*HandlerParam) error RegisterPlugin(Plugin) error + GetHandlerName() string +} + +func NewKmsHandler() Handler { + return newKmsHandler() } -func GetDefaultHandler() Handler { - if defaultHandler == nil { - initDefaultHandler() +func newKmsHandler() *KmsHandler { + kmsHandler := &KmsHandler{ + encryptionPlugins: make(map[string]Plugin, 2), } - return defaultHandler + logger.Debug("successfully create encryption KmsHandler") + return kmsHandler } -func initDefaultHandler() { - initDefaultHandlerOnce.Do(func() { - defaultHandler = &DefaultHandler{ - encryptionPlugins: make(map[string]Plugin, 2), - } - logger.Debug("successfully create encryption defaultHandler") - }) +func RegisterConfigEncryptionKmsPlugins(encryptionHandler Handler, clientConfig constant.ClientConfig) { + innerKmsClient, err := innerNewKmsClient(clientConfig) + if err == nil && innerKmsClient == nil { + err = errors.New("create kms client failed.") + } + if err != nil { + logger.Error(err) + } + if err := encryptionHandler.RegisterPlugin(&KmsAes128Plugin{kmsPlugin{kmsClient: innerKmsClient}}); err != nil { + logger.Errorf("failed to register encryption plugin[%s] to %s", KmsAes128AlgorithmName, encryptionHandler.GetHandlerName()) + } else { + logger.Debugf("successfully register encryption plugin[%s] to %s", KmsAes128AlgorithmName, encryptionHandler.GetHandlerName()) + } + if err := encryptionHandler.RegisterPlugin(&KmsAes256Plugin{kmsPlugin{kmsClient: innerKmsClient}}); err != nil { + logger.Errorf("failed to register encryption plugin[%s] to %s", KmsAes256AlgorithmName, encryptionHandler.GetHandlerName()) + } else { + logger.Debugf("successfully register encryption plugin[%s] to %s", KmsAes256AlgorithmName, encryptionHandler.GetHandlerName()) + } + if err := encryptionHandler.RegisterPlugin(&KmsBasePlugin{kmsPlugin{kmsClient: innerKmsClient}}); err != nil { + logger.Errorf("failed to register encryption plugin[%s] to %s", KmsAlgorithmName, encryptionHandler.GetHandlerName()) + } else { + logger.Debugf("successfully register encryption plugin[%s] to %s", KmsAlgorithmName, encryptionHandler.GetHandlerName()) + } } -type DefaultHandler struct { +type KmsHandler struct { encryptionPlugins map[string]Plugin } -func (d *DefaultHandler) EncryptionHandler(param *HandlerParam) error { +func (d *KmsHandler) EncryptionHandler(param *HandlerParam) error { if err := d.encryptionParamCheck(*param); err != nil { - if err == DataIdParamCheckError || err == ContentParamCheckError { - return nil - } return err } plugin, err := d.getPluginByDataIdPrefix(param.DataId) @@ -90,11 +107,8 @@ func (d *DefaultHandler) EncryptionHandler(param *HandlerParam) error { return plugin.Encrypt(param) } -func (d *DefaultHandler) DecryptionHandler(param *HandlerParam) error { +func (d *KmsHandler) DecryptionHandler(param *HandlerParam) error { if err := d.decryptionParamCheck(*param); err != nil { - if err == DataIdParamCheckError || err == ContentParamCheckError { - return nil - } return err } plugin, err := d.getPluginByDataIdPrefix(param.DataId) @@ -109,7 +123,7 @@ func (d *DefaultHandler) DecryptionHandler(param *HandlerParam) error { return plugin.Decrypt(param) } -func (d *DefaultHandler) getPluginByDataIdPrefix(dataId string) (Plugin, error) { +func (d *KmsHandler) getPluginByDataIdPrefix(dataId string) (Plugin, error) { var ( matchedCount int matchedPlugin Plugin @@ -128,7 +142,7 @@ func (d *DefaultHandler) getPluginByDataIdPrefix(dataId string) (Plugin, error) return matchedPlugin, nil } -func (d *DefaultHandler) RegisterPlugin(plugin Plugin) error { +func (d *KmsHandler) RegisterPlugin(plugin Plugin) error { if _, v := d.encryptionPlugins[plugin.AlgorithmName()]; v { logger.Warnf("encryption algorithm [%s] has already registered to defaultHandler, will be update", plugin.AlgorithmName()) } else { @@ -138,7 +152,11 @@ func (d *DefaultHandler) RegisterPlugin(plugin Plugin) error { return nil } -func (d *DefaultHandler) encryptionParamCheck(param HandlerParam) error { +func (d *KmsHandler) GetHandlerName() string { + return KmsHandlerName +} + +func (d *KmsHandler) encryptionParamCheck(param HandlerParam) error { if err := d.dataIdParamCheck(param.DataId); err != nil { return DataIdParamCheckError } @@ -148,26 +166,52 @@ func (d *DefaultHandler) encryptionParamCheck(param HandlerParam) error { return nil } -func (d *DefaultHandler) decryptionParamCheck(param HandlerParam) error { - if err := d.dataIdParamCheck(param.DataId); err != nil { - return DataIdParamCheckError - } - if err := d.contentParamCheck(param.Content); err != nil { - return ContentParamCheckError +func (d *KmsHandler) decryptionParamCheck(param HandlerParam) error { + return d.encryptionParamCheck(param) +} + +func (d *KmsHandler) keyIdParamCheck(keyId string) error { + if len(keyId) == 0 { + return fmt.Errorf("cipher dataId using kmsService need to set keyId, but keyId is nil") } return nil } -func (d *DefaultHandler) dataIdParamCheck(dataId string) error { +func (d *KmsHandler) dataIdParamCheck(dataId string) error { if !strings.Contains(dataId, CipherPrefix) { return fmt.Errorf("dataId prefix should start with: %s", CipherPrefix) } return nil } -func (d *DefaultHandler) contentParamCheck(content string) error { +func (d *KmsHandler) contentParamCheck(content string) error { if len(content) == 0 { return fmt.Errorf("content need to encrypt is nil") } return nil } + +func innerNewKmsClient(clientConfig constant.ClientConfig) (kmsClient *KmsClient, err error) { + switch clientConfig.KMSVersion { + case constant.KMSv1, constant.DEFAULT_KMS_VERSION: + kmsClient, err = newKmsV1Client(clientConfig) + case constant.KMSv3: + kmsClient, err = newKmsV3Client(clientConfig) + default: + err = fmt.Errorf("init kms client failed. unknown kms version:%s\n", clientConfig.KMSVersion) + } + return kmsClient, err +} + +func newKmsV1Client(clientConfig constant.ClientConfig) (*KmsClient, error) { + return NewKmsV1ClientWithAccessKey(clientConfig.RegionId, clientConfig.AccessKey, clientConfig.SecretKey) +} + +func newKmsV3Client(clientConfig constant.ClientConfig) (*KmsClient, error) { + return NewKmsV3ClientWithConfig(&dkms_api.Config{ + Protocol: tea.String("https"), + Endpoint: tea.String(clientConfig.KMSv3Config.Endpoint), + ClientKeyContent: tea.String(clientConfig.KMSv3Config.ClientKeyContent), + Password: tea.String(clientConfig.KMSv3Config.Password), + }, clientConfig.KMSv3Config.CaContent) +} diff --git a/common/encryption/kms_client.go b/common/encryption/kms_client.go index 398699cd..219ff3a3 100644 --- a/common/encryption/kms_client.go +++ b/common/encryption/kms_client.go @@ -26,12 +26,6 @@ import ( "github.com/pkg/errors" "net/http" "strings" - "sync" -) - -var ( - initKmsClientOnce = &sync.Once{} - kmsClient *KmsClient ) type KmsClient struct { @@ -39,24 +33,18 @@ type KmsClient struct { kmsVersion constant.KMSVersion } -func InitDefaultKmsV1ClientWithAccessKey(regionId, ak, sk string) (*KmsClient, error) { +func NewKmsV1ClientWithAccessKey(regionId, ak, sk string) (*KmsClient, error) { var rErr error - if GetDefaultKmsClient() != nil { - return GetDefaultKmsClient(), rErr - } if rErr = checkKmsV1InitParam(regionId, ak, sk); rErr != nil { return nil, rErr } - initKmsClientOnce.Do(func() { - client, err := NewKmsV1ClientWithAccessKey(regionId, ak, sk) - if err != nil { - rErr = errors.Wrap(err, "init kms v1 client with ak/sk failed") - } else { - client.SetKmsVersion(constant.KMSv1) - kmsClient = client - } - }) - return GetDefaultKmsClient(), rErr + kmsClient, err := newKmsV1ClientWithAccessKey(regionId, ak, sk) + if err != nil { + rErr = errors.Wrap(err, "init kms v1 client with ak/sk failed") + } else { + kmsClient.setKmsVersion(constant.KMSv1) + } + return kmsClient, rErr } func checkKmsV1InitParam(regionId, ak, sk string) error { @@ -72,30 +60,24 @@ func checkKmsV1InitParam(regionId, ak, sk string) error { return nil } -func InitDefaultKmsV3ClientWithConfig(config *dkms_api.Config, caVerify string) (*KmsClient, error) { +func NewKmsV3ClientWithConfig(config *dkms_api.Config, caVerify string) (*KmsClient, error) { var rErr error - if GetDefaultKmsClient() != nil { - return GetDefaultKmsClient(), rErr - } if rErr = checkKmsV3InitParam(config, caVerify); rErr != nil { return nil, rErr } - initKmsClientOnce.Do(func() { - client, err := NewKmsV3ClientWithConfig(config) - if err != nil { - rErr = errors.Wrap(err, "init kms v3 client with config failed") + kmsClient, err := newKmsV3ClientWithConfig(config) + if err != nil { + rErr = errors.Wrap(err, "init kms v3 client with config failed") + } else { + if len(strings.TrimSpace(caVerify)) != 0 { + logger.Debugf("set kms client Ca with content: %s\n", caVerify[:len(caVerify)/maskUnit32Width]) + kmsClient.SetVerify(caVerify) } else { - if len(strings.TrimSpace(caVerify)) != 0 { - logger.Debugf("set kms client Ca with content: %s\n", caVerify[:len(caVerify)/maskUnit32Width]) - client.SetVerify(caVerify) - } else { - client.SetHTTPSInsecure(true) - } - client.SetKmsVersion(constant.KMSv3) - kmsClient = client + kmsClient.SetHTTPSInsecure(true) } - }) - return GetDefaultKmsClient(), rErr + kmsClient.setKmsVersion(constant.KMSv3) + } + return kmsClient, rErr } func checkKmsV3InitParam(config *dkms_api.Config, caVerify string) error { @@ -114,17 +96,13 @@ func checkKmsV3InitParam(config *dkms_api.Config, caVerify string) error { return nil } -func GetDefaultKmsClient() *KmsClient { - return kmsClient -} - -func NewKmsV1ClientWithAccessKey(regionId, ak, sk string) (*KmsClient, error) { +func newKmsV1ClientWithAccessKey(regionId, ak, sk string) (*KmsClient, error) { logger.Debugf("init kms client with region:[%s], ak:[%s]xxx, sk:[%s]xxx\n", regionId, ak[:len(ak)/maskUnit8Width], sk[:len(sk)/maskUnit8Width]) return newKmsClient(regionId, ak, sk, nil) } -func NewKmsV3ClientWithConfig(config *dkms_api.Config) (*KmsClient, error) { +func newKmsV3ClientWithConfig(config *dkms_api.Config) (*KmsClient, error) { logger.Debugf("init kms client with endpoint:[%s], clientKeyContent:[%s], password:[%s]\n", config.Endpoint, (*config.ClientKeyContent)[:len(*config.ClientKeyContent)/maskUnit8Width], (*config.Password)[:len(*config.Password)/maskUnit8Width]) @@ -145,7 +123,7 @@ func (kmsClient *KmsClient) GetKmsVersion() constant.KMSVersion { return kmsClient.kmsVersion } -func (kmsClient *KmsClient) SetKmsVersion(kmsVersion constant.KMSVersion) { +func (kmsClient *KmsClient) setKmsVersion(kmsVersion constant.KMSVersion) { logger.Debug("successfully set kms client version to " + kmsVersion) kmsClient.kmsVersion = kmsVersion } diff --git a/common/encryption/kms_plugins.go b/common/encryption/kms_plugins.go index 1d3b7b5a..4d38cca7 100644 --- a/common/encryption/kms_plugins.go +++ b/common/encryption/kms_plugins.go @@ -17,39 +17,20 @@ package encryption import ( + "fmt" + "github.com/nacos-group/nacos-sdk-go/v2/common/constant" inner_encoding "github.com/nacos-group/nacos-sdk-go/v2/common/encoding" - "github.com/nacos-group/nacos-sdk-go/v2/common/logger" "strings" ) -func RegisterConfigEncryptionKmsPlugins() { - if err := GetDefaultHandler().RegisterPlugin(&KmsAes128Plugin{}); err != nil { - logger.Errorf("failed to register encryption plugin[%s] to defaultHandler", KmsAes128AlgorithmName) - } else { - logger.Debugf("successfully register encryption plugin[%s] to defaultHandler", KmsAes128AlgorithmName) - } - if err := GetDefaultHandler().RegisterPlugin(&KmsAes256Plugin{}); err != nil { - logger.Errorf("failed to register encryption plugin[%s] to defaultHandler", KmsAes256AlgorithmName) - } else { - logger.Debugf("successfully register encryption plugin[%s] to defaultHandler", KmsAes256AlgorithmName) - } - if err := GetDefaultHandler().RegisterPlugin(&KmsBasePlugin{}); err != nil { - logger.Errorf("failed to register encryption plugin[%s] to defaultHandler", KmsAlgorithmName) - } else { - logger.Debugf("successfully register encryption plugin[%s] to defaultHandler", KmsAlgorithmName) - - } -} - type kmsPlugin struct { + kmsClient *KmsClient } func (k *kmsPlugin) Encrypt(param *HandlerParam) error { - if len(param.Content) == 0 { - return nil - } - if len(param.PlainDataKey) == 0 { - return EmptyPlainDataKeyError + err := k.encryptionParamCheck(*param) + if err != nil { + return err } secretKeyBase64Decoded, err := inner_encoding.DecodeBase64(inner_encoding.DecodeString2Utf8Bytes(param.PlainDataKey)) if err != nil { @@ -69,11 +50,9 @@ func (k *kmsPlugin) Encrypt(param *HandlerParam) error { } func (k *kmsPlugin) Decrypt(param *HandlerParam) error { - if len(param.Content) == 0 { - return nil - } - if len(param.PlainDataKey) == 0 { - return EmptyPlainDataKeyError + err := k.decryptionParamCheck(*param) + if err != nil { + return err } secretKeyBase64Decoded, err := inner_encoding.DecodeBase64(inner_encoding.DecodeString2Utf8Bytes(param.PlainDataKey)) if err != nil { @@ -100,13 +79,15 @@ func (k *kmsPlugin) GenerateSecretKey(param *HandlerParam) (string, error) { } func (k *kmsPlugin) EncryptSecretKey(param *HandlerParam) (string, error) { - if err := keyIdParamCheck(param.KeyId); err != nil { + var keyId string + var err error + if keyId, err = k.keyIdParamCheck(param.KeyId); err != nil { return "", err } if len(param.PlainDataKey) == 0 { - return "", nil + return "", EmptyPlainDataKeyError } - encryptedDataKey, err := GetDefaultKmsClient().Encrypt(param.PlainDataKey, param.KeyId) + encryptedDataKey, err := k.kmsClient.Encrypt(param.PlainDataKey, keyId) if err != nil { return "", err } @@ -119,9 +100,9 @@ func (k *kmsPlugin) EncryptSecretKey(param *HandlerParam) (string, error) { func (k *kmsPlugin) DecryptSecretKey(param *HandlerParam) (string, error) { if len(param.EncryptedDataKey) == 0 { - return "", nil + return "", EmptyEncryptedDataKeyError } - plainDataKey, err := GetDefaultKmsClient().Decrypt(param.EncryptedDataKey) + plainDataKey, err := k.kmsClient.Decrypt(param.EncryptedDataKey) if err != nil { return "", err } @@ -132,6 +113,51 @@ func (k *kmsPlugin) DecryptSecretKey(param *HandlerParam) (string, error) { return plainDataKey, nil } +func (k *kmsPlugin) encryptionParamCheck(param HandlerParam) error { + if err := k.plainDataKeyParamCheck(param.PlainDataKey); err != nil { + return KeyIdParamCheckError + } + if err := k.contentParamCheck(param.Content); err != nil { + return ContentParamCheckError + } + return nil +} + +func (k *kmsPlugin) decryptionParamCheck(param HandlerParam) error { + return k.encryptionParamCheck(param) +} + +func (k *kmsPlugin) plainDataKeyParamCheck(plainDataKey string) error { + if len(plainDataKey) == 0 { + return EmptyPlainDataKeyError + } + return nil +} + +func (k *kmsPlugin) dataIdParamCheck(dataId string) error { + if !strings.Contains(dataId, CipherPrefix) { + return fmt.Errorf("dataId prefix should start with: %s", CipherPrefix) + } + return nil +} + +func (k *kmsPlugin) keyIdParamCheck(keyId string) (string, error) { + if len(strings.TrimSpace(keyId)) == 0 { + if k.kmsClient.GetKmsVersion() == constant.KMSv1 { + return GetDefaultKMSv1KeyId(), nil + } + return "", KeyIdParamCheckError + } + return keyId, nil +} + +func (k *kmsPlugin) contentParamCheck(content string) error { + if len(content) == 0 { + return fmt.Errorf("content need to encrypt is nil") + } + return nil +} + type KmsAes128Plugin struct { kmsPlugin } @@ -149,10 +175,12 @@ func (k *KmsAes128Plugin) AlgorithmName() string { } func (k *KmsAes128Plugin) GenerateSecretKey(param *HandlerParam) (string, error) { - if err := keyIdParamCheck(param.KeyId); err != nil { + var keyId string + var err error + if keyId, err = k.keyIdParamCheck(param.KeyId); err != nil { return "", err } - plainSecretKey, encryptedSecretKey, err := GetDefaultKmsClient().GenerateDataKey(param.KeyId, kmsAes128KeySpec) + plainSecretKey, encryptedSecretKey, err := k.kmsClient.GenerateDataKey(keyId, kmsAes128KeySpec) if err != nil { return "", err } @@ -193,10 +221,12 @@ func (k *KmsAes256Plugin) AlgorithmName() string { } func (k *KmsAes256Plugin) GenerateSecretKey(param *HandlerParam) (string, error) { - if err := keyIdParamCheck(param.KeyId); err != nil { + var keyId string + var err error + if keyId, err = k.keyIdParamCheck(param.KeyId); err != nil { return "", err } - plainSecretKey, encryptedSecretKey, err := GetDefaultKmsClient().GenerateDataKey(param.KeyId, kmsAes256KeySpec) + plainSecretKey, encryptedSecretKey, err := k.kmsClient.GenerateDataKey(keyId, kmsAes256KeySpec) if err != nil { return "", err } @@ -220,16 +250,19 @@ func (k *KmsAes256Plugin) DecryptSecretKey(param *HandlerParam) (string, error) } type KmsBasePlugin struct { + kmsPlugin } func (k *KmsBasePlugin) Encrypt(param *HandlerParam) error { - if err := keyIdParamCheck(param.KeyId); err != nil { + var keyId string + var err error + if keyId, err = k.keyIdParamCheck(param.KeyId); err != nil { return err } if len(param.Content) == 0 { - return nil + return EmptyContentError } - encryptedContent, err := GetDefaultKmsClient().Encrypt(param.Content, param.KeyId) + encryptedContent, err := k.kmsClient.Encrypt(param.Content, keyId) if err != nil { return err } @@ -241,7 +274,7 @@ func (k *KmsBasePlugin) Decrypt(param *HandlerParam) error { if len(param.Content) == 0 { return nil } - plainContent, err := GetDefaultKmsClient().Decrypt(param.Content) + plainContent, err := k.kmsClient.Decrypt(param.Content) if err != nil { return err } @@ -264,10 +297,3 @@ func (k *KmsBasePlugin) EncryptSecretKey(param *HandlerParam) (string, error) { func (k *KmsBasePlugin) DecryptSecretKey(param *HandlerParam) (string, error) { return "", nil } - -func keyIdParamCheck(keyId string) error { - if len(strings.TrimSpace(keyId)) == 0 { - return KeyIdParamCheckError - } - return nil -} diff --git a/common/filter/config_encryption_filter.go b/common/filter/config_encryption_filter.go index 010ceb7b..283fc410 100644 --- a/common/filter/config_encryption_filter.go +++ b/common/filter/config_encryption_filter.go @@ -17,13 +17,10 @@ package filter import ( - "fmt" - "github.com/nacos-group/nacos-sdk-go/v2/common/constant" nacos_inner_encryption "github.com/nacos-group/nacos-sdk-go/v2/common/encryption" - "github.com/nacos-group/nacos-sdk-go/v2/common/logger" "github.com/nacos-group/nacos-sdk-go/v2/vo" + "github.com/pkg/errors" "strings" - "sync" ) const ( @@ -31,29 +28,22 @@ const ( ) var ( - initDefaultConfigEncryptionFilterOnce = &sync.Once{} - defaultConfigEncryptionFilter IConfigFilter + noNeedEncryptionError = errors.New("dataId doesn't need to encrypt/decrypt.") ) type DefaultConfigEncryptionFilter struct { + handler nacos_inner_encryption.Handler } -func GetDefaultConfigEncryptionFilter() IConfigFilter { - if defaultConfigEncryptionFilter == nil { - initDefaultConfigEncryptionFilterOnce.Do(func() { - defaultConfigEncryptionFilter = &DefaultConfigEncryptionFilter{} - logger.Debugf("successfully create ConfigFilter[%s]", defaultConfigEncryptionFilter.GetFilterName()) - }) - } - return defaultConfigEncryptionFilter +func NewDefaultConfigEncryptionFilter(handler nacos_inner_encryption.Handler) IConfigFilter { + return &DefaultConfigEncryptionFilter{handler} } func (d *DefaultConfigEncryptionFilter) DoFilter(param *vo.ConfigParam) error { - if !strings.HasPrefix(param.DataId, nacos_inner_encryption.CipherPrefix) { - return nil - } - if nacos_inner_encryption.GetDefaultKmsClient() == nil { - return fmt.Errorf("kms client hasn't inited, can't publish config dataId start with: %s", nacos_inner_encryption.CipherPrefix) + if err := d.paramCheck(*param); err != nil { + if errors.Is(err, noNeedEncryptionError) { + return nil + } } if param.UsageType == vo.RequestType { encryptionParam := &nacos_inner_encryption.HandlerParam{ @@ -61,10 +51,7 @@ func (d *DefaultConfigEncryptionFilter) DoFilter(param *vo.ConfigParam) error { Content: param.Content, KeyId: param.KmsKeyId, } - if len(encryptionParam.KeyId) == 0 && nacos_inner_encryption.GetDefaultKmsClient().GetKmsVersion() == constant.KMSv1 { - encryptionParam.KeyId = nacos_inner_encryption.GetDefaultKMSv1KeyId() - } - if err := nacos_inner_encryption.GetDefaultHandler().EncryptionHandler(encryptionParam); err != nil { + if err := d.handler.EncryptionHandler(encryptionParam); err != nil { return err } param.Content = encryptionParam.Content @@ -76,7 +63,7 @@ func (d *DefaultConfigEncryptionFilter) DoFilter(param *vo.ConfigParam) error { Content: param.Content, EncryptedDataKey: param.EncryptedDataKey, } - if err := nacos_inner_encryption.GetDefaultHandler().DecryptionHandler(decryptionParam); err != nil { + if err := d.handler.DecryptionHandler(decryptionParam); err != nil { return err } param.Content = decryptionParam.Content @@ -91,3 +78,10 @@ func (d *DefaultConfigEncryptionFilter) GetOrder() int { func (d *DefaultConfigEncryptionFilter) GetFilterName() string { return defaultConfigEncryptionFilterName } +func (d *DefaultConfigEncryptionFilter) paramCheck(param vo.ConfigParam) error { + if !strings.HasPrefix(param.DataId, nacos_inner_encryption.CipherPrefix) || + len(strings.TrimSpace(param.Content)) == 0 { + return noNeedEncryptionError + } + return nil +} diff --git a/common/filter/config_filter.go b/common/filter/config_filter.go index 97a3379d..fec71855 100644 --- a/common/filter/config_filter.go +++ b/common/filter/config_filter.go @@ -18,14 +18,7 @@ package filter import ( "fmt" - "github.com/nacos-group/nacos-sdk-go/v2/common/logger" "github.com/nacos-group/nacos-sdk-go/v2/vo" - "sync" -) - -var ( - initConfigFilterChainManagerOnce = &sync.Once{} - defaultConfigFilterChainManagerInstance IConfigFilterChain ) type IConfigFilterChain interface { @@ -41,28 +34,15 @@ type IConfigFilter interface { GetFilterName() string } -func RegisterDefaultConfigEncryptionFilter() { - err := RegisterConfigFilter(GetDefaultConfigFilterChainManager(), GetDefaultConfigEncryptionFilter()) - if err != nil { - logger.Errorf("failed to register configFilter[%s] to DefaultConfigFilterChainManager", - GetDefaultConfigEncryptionFilter().GetFilterName()) - return - } else { - logger.Debugf("successfully register ConfigFilter[%s] to DefaultConfigFilterChainManager", GetDefaultConfigEncryptionFilter().GetFilterName()) - } +func RegisterConfigFilterToChain(chain IConfigFilterChain, filter IConfigFilter) error { + return chain.AddFilter(filter) } -func GetDefaultConfigFilterChainManager() IConfigFilterChain { - if defaultConfigFilterChainManagerInstance == nil { - initConfigFilterChainManagerOnce.Do(func() { - defaultConfigFilterChainManagerInstance = newDefaultConfigFilterChainManager() - logger.Debug("successfully create DefaultConfigFilterChainManager") - }) - } - return defaultConfigFilterChainManagerInstance +func NewConfigFilterChainManager() IConfigFilterChain { + return newConfigFilterChainManager() } -func newDefaultConfigFilterChainManager() *DefaultConfigFilterChainManager { +func newConfigFilterChainManager() *DefaultConfigFilterChainManager { return &DefaultConfigFilterChainManager{ configFilterPriorityQueue: make([]IConfigFilter, 0, 2), } @@ -101,10 +81,6 @@ func (m *DefaultConfigFilterChainManager) DoFilterByName(param *vo.ConfigParam, return fmt.Errorf("cannot find the filter[%s]", name) } -func RegisterConfigFilter(chain IConfigFilterChain, filter IConfigFilter) error { - return chain.AddFilter(filter) -} - type configFilterPriorityQueue []IConfigFilter func (c *configFilterPriorityQueue) addFilter(filter IConfigFilter) error { diff --git a/example/config-mse-kmsv3/main.go b/example/config-mse-kmsv3/main.go index 55e5bf55..fa81ced1 100644 --- a/example/config-mse-kmsv3/main.go +++ b/example/config-mse-kmsv3/main.go @@ -21,7 +21,6 @@ import ( "github.com/nacos-group/nacos-sdk-go/v2/clients/config_client" "github.com/nacos-group/nacos-sdk-go/v2/clients/nacos_client" "github.com/nacos-group/nacos-sdk-go/v2/common/constant" - "github.com/nacos-group/nacos-sdk-go/v2/common/filter" "github.com/nacos-group/nacos-sdk-go/v2/common/http_agent" "github.com/nacos-group/nacos-sdk-go/v2/common/logger" "github.com/nacos-group/nacos-sdk-go/v2/vo" @@ -32,7 +31,7 @@ import ( ) var localServerConfigWithOptions = constant.NewServerConfig( - "mse-d12e6112-p.nacos-ans.mse.aliyuncs.com", + "mse-1a3d3840-p.nacos-ans.mse.aliyuncs.com", 8848, ) @@ -42,7 +41,7 @@ var localClientConfigWithOptions = constant.NewClientConfig( constant.WithNotLoadCacheAtStart(true), constant.WithAccessKey(getFileContent(path.Join(getWDR(), "ak"))), constant.WithSecretKey(getFileContent(path.Join(getWDR(), "sk"))), - constant.WithNamespaceId("791fd262-3735-40df-a605-e3236f8ff495"), + //constant.WithNamespaceId("791fd262-3735-40df-a605-e3236f8ff495"), constant.WithOpenKMS(true), constant.WithKMSVersion(constant.KMSv3), constant.WithKMSv3Config(&constant.KMSv3Config{ @@ -56,27 +55,26 @@ var localClientConfigWithOptions = constant.NewClientConfig( var localConfigList = []vo.ConfigParam{ { - DataId: "common-config", - Group: "default", - Content: "common", - KmsKeyId: "key-xxx", //可以识别 + DataId: "common-config", + Group: "default", + Content: "common普通&&", }, { DataId: "cipher-crypt", Group: "default", - Content: "cipher", + Content: "cipher加密&&", KmsKeyId: "key-xxx", //可以识别 }, { DataId: "cipher-kms-aes-128-crypt", Group: "default", - Content: "cipher-aes-128", + Content: "cipher-aes-128加密&&", KmsKeyId: "key-xxx", //可以识别 }, { DataId: "cipher-kms-aes-256-crypt", Group: "default", - Content: "cipher-aes-256", + Content: "cipher-aes-256加密&&", KmsKeyId: "key-xxx", //可以识别 }, } @@ -143,24 +141,6 @@ func usingKMSv3ClientAndStoredByNacos() { } } -func onlyUsingFilters() error { - createConfigClient() - for _, param := range localConfigList { - param.UsageType = vo.RequestType - fmt.Println("param = ", param) - if err := filter.GetDefaultConfigFilterChainManager().DoFilters(¶m); err != nil { - return err - } - fmt.Println("after encrypt param = ", param) - param.UsageType = vo.ResponseType - if err := filter.GetDefaultConfigFilterChainManager().DoFilters(¶m); err != nil { - return err - } - fmt.Println("after decrypt param = ", param) - } - return nil -} - func createConfigClient() *config_client.ConfigClient { nc := nacos_client.NacosClient{} _ = nc.SetServerConfig([]constant.ServerConfig{*localServerConfigWithOptions})