Skip to content

Commit

Permalink
change configFilterChain from public sigleton to private owned by Con…
Browse files Browse the repository at this point in the history
…figClient
  • Loading branch information
robynron committed Dec 18, 2023
1 parent 8677a56 commit 0d49c7e
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 279 deletions.
84 changes: 19 additions & 65 deletions clients/config_client/config_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,9 @@ import (
"context"
"fmt"
"os"
"strings"
"sync"
"time"

"github.com/alibabacloud-go/tea/tea"
dkms_api "github.com/aliyun/alibabacloud-dkms-gcs-go-sdk/openapi"
"github.com/nacos-group/nacos-sdk-go/v2/clients/cache"
"github.com/nacos-group/nacos-sdk-go/v2/clients/nacos_client"
"github.com/nacos-group/nacos-sdk-go/v2/common/constant"
Expand All @@ -52,15 +49,15 @@ type ConfigClient struct {
ctx context.Context
cancel context.CancelFunc
nacos_client.INacosClient
kmsClient *nacos_inner_encryption.KmsClient
localConfigs []vo.ConfigParam
mutex sync.Mutex
configProxy IConfigProxy
configCacheDir string
lastAllSyncTime time.Time
cacheMap cache.ConcurrentMap
uid string
listenExecute chan struct{}
configFilterChainManager filter.IConfigFilterChain
localConfigs []vo.ConfigParam
mutex sync.Mutex
configProxy IConfigProxy
configCacheDir string
lastAllSyncTime time.Time
cacheMap cache.ConcurrentMap
uid string
listenExecute chan struct{}
}

type cacheData struct {
Expand Down Expand Up @@ -94,7 +91,7 @@ func (cacheData *cacheData) executeListener() {
EncryptedDataKey: cacheData.encryptedDataKey,
UsageType: vo.ResponseType,
}
if err := filter.GetDefaultConfigFilterChainManager().DoFilters(param); err != nil {
if err := cacheData.configClient.configFilterChainManager.DoFilters(param); err != nil {
logger.Errorf("do filters failed ,dataId=%s,group=%s,tenant=%s,err:%+v ", cacheData.dataId,
cacheData.group, cacheData.tenant, err)
return
Expand Down Expand Up @@ -130,22 +127,16 @@ func NewConfigClient(nc nacos_client.INacosClient) (*ConfigClient, error) {
return nil, err
}

config.configFilterChainManager = filter.NewConfigFilterChainManager()

if clientConfig.OpenKMS {
filter.RegisterDefaultConfigEncryptionFilter()
nacos_inner_encryption.RegisterConfigEncryptionKmsPlugins()
var kmsClient *nacos_inner_encryption.KmsClient
switch clientConfig.KMSVersion {
case constant.KMSv1, constant.DEFAULT_KMS_VERSION:
kmsClient, err = initKmsV1Client(clientConfig)
case constant.KMSv3:
kmsClient, err = initKmsV3Client(clientConfig)
default:
err = fmt.Errorf("init kms client failed. unknown kms version:%s\n", clientConfig.KMSVersion)
}
kmsEncryptionHandler := nacos_inner_encryption.NewKmsHandler()
nacos_inner_encryption.RegisterConfigEncryptionKmsPlugins(kmsEncryptionHandler, clientConfig)
encryptionFilter := filter.NewDefaultConfigEncryptionFilter(kmsEncryptionHandler)
err := filter.RegisterConfigFilterToChain(config.configFilterChainManager, encryptionFilter)
if err != nil {
return nil, err
logger.Error(err)
}
config.kmsClient = kmsClient
}

uid, err := uuid.NewV4()
Expand All @@ -164,19 +155,6 @@ func initLogger(clientConfig constant.ClientConfig) error {
return logger.InitLogger(logger.BuildLoggerConfig(clientConfig))
}

func initKmsV1Client(clientConfig constant.ClientConfig) (*nacos_inner_encryption.KmsClient, error) {
return nacos_inner_encryption.InitDefaultKmsV1ClientWithAccessKey(clientConfig.RegionId, clientConfig.AccessKey, clientConfig.SecretKey)
}

func initKmsV3Client(clientConfig constant.ClientConfig) (*nacos_inner_encryption.KmsClient, error) {
return nacos_inner_encryption.InitDefaultKmsV3ClientWithConfig(&dkms_api.Config{
Protocol: tea.String("https"),
Endpoint: tea.String(clientConfig.KMSv3Config.Endpoint),
ClientKeyContent: tea.String(clientConfig.KMSv3Config.ClientKeyContent),
Password: tea.String(clientConfig.KMSv3Config.Password),
}, clientConfig.KMSv3Config.CaContent)
}

func (client *ConfigClient) GetConfig(param vo.ConfigParam) (content string, err error) {
content, encryptedDataKey, err := client.getConfigInner(param)
if err != nil {
Expand All @@ -186,37 +164,13 @@ func (client *ConfigClient) GetConfig(param vo.ConfigParam) (content string, err
deepCopyParam.EncryptedDataKey = encryptedDataKey
deepCopyParam.Content = content
deepCopyParam.UsageType = vo.ResponseType
if err = filter.GetDefaultConfigFilterChainManager().DoFilters(deepCopyParam); err != nil {
if err = client.configFilterChainManager.DoFilters(deepCopyParam); err != nil {
return "", err
}
content = deepCopyParam.Content
return content, nil
}

func (client *ConfigClient) decrypt(dataId, content string) (string, error) {
var plainContent string
var err error
if client.kmsClient != nil && strings.HasPrefix(dataId, "cipher-") {
plainContent, err = client.kmsClient.Decrypt(content)
if err != nil {
return "", fmt.Errorf("kms decrypt failed: %v", err)
}
}
return plainContent, nil
}

func (client *ConfigClient) encrypt(dataId, content, kmsKeyId string) (string, error) {
var cipherContent string
var err error
if client.kmsClient != nil && strings.HasPrefix(dataId, "cipher-") {
cipherContent, err = client.kmsClient.Encrypt(content, kmsKeyId)
if err != nil {
return "", fmt.Errorf("kms encrypt failed: %v", err)
}
}
return cipherContent, nil
}

func (client *ConfigClient) getConfigInner(param vo.ConfigParam) (content, encryptedDataKey string, err error) {
if len(param.DataId) <= 0 {
err = errors.New("[client.GetConfig] param.dataId can not be empty")
Expand Down Expand Up @@ -278,7 +232,7 @@ func (client *ConfigClient) PublishConfig(param vo.ConfigParam) (published bool,
}

param.UsageType = vo.RequestType
if err = filter.GetDefaultConfigFilterChainManager().DoFilters(&param); err != nil {
if err = client.configFilterChainManager.DoFilters(&param); err != nil {
return false, err
}

Expand Down
4 changes: 3 additions & 1 deletion common/encryption/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ const (
kmsAes256KeySpec = "AES_256"

kmsScheme = "https"
kmsAcceptFormat = "JSON"
kmsAcceptFormat = "XML"

kmsCipherAlgorithm = "AES/ECB/PKCS5Padding"

maskUnit8Width = 8
maskUnit32Width = 32

KmsHandlerName = "KmsHandler"
)

var (
Expand Down
118 changes: 81 additions & 37 deletions common/encryption/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@ package encryption

import (
"fmt"
"github.com/alibabacloud-go/tea/tea"
dkms_api "github.com/aliyun/alibabacloud-dkms-gcs-go-sdk/openapi"
"github.com/nacos-group/nacos-sdk-go/v2/common/constant"
"github.com/nacos-group/nacos-sdk-go/v2/common/logger"
"github.com/pkg/errors"
"strings"
"sync"
)

var (
initDefaultHandlerOnce = &sync.Once{}
defaultHandler *DefaultHandler
)

type HandlerParam struct {
Expand All @@ -49,33 +47,52 @@ type Handler interface {
EncryptionHandler(*HandlerParam) error
DecryptionHandler(*HandlerParam) error
RegisterPlugin(Plugin) error
GetHandlerName() string
}

func NewKmsHandler() Handler {
return newKmsHandler()
}

func GetDefaultHandler() Handler {
if defaultHandler == nil {
initDefaultHandler()
func newKmsHandler() *KmsHandler {
kmsHandler := &KmsHandler{
encryptionPlugins: make(map[string]Plugin, 2),
}
return defaultHandler
logger.Debug("successfully create encryption KmsHandler")
return kmsHandler
}

func initDefaultHandler() {
initDefaultHandlerOnce.Do(func() {
defaultHandler = &DefaultHandler{
encryptionPlugins: make(map[string]Plugin, 2),
}
logger.Debug("successfully create encryption defaultHandler")
})
func RegisterConfigEncryptionKmsPlugins(encryptionHandler Handler, clientConfig constant.ClientConfig) {
innerKmsClient, err := innerNewKmsClient(clientConfig)
if err == nil && innerKmsClient == nil {
err = errors.New("create kms client failed.")
}
if err != nil {
logger.Error(err)
}
if err := encryptionHandler.RegisterPlugin(&KmsAes128Plugin{kmsPlugin{kmsClient: innerKmsClient}}); err != nil {
logger.Errorf("failed to register encryption plugin[%s] to %s", KmsAes128AlgorithmName, encryptionHandler.GetHandlerName())
} else {
logger.Debugf("successfully register encryption plugin[%s] to %s", KmsAes128AlgorithmName, encryptionHandler.GetHandlerName())
}
if err := encryptionHandler.RegisterPlugin(&KmsAes256Plugin{kmsPlugin{kmsClient: innerKmsClient}}); err != nil {
logger.Errorf("failed to register encryption plugin[%s] to %s", KmsAes256AlgorithmName, encryptionHandler.GetHandlerName())
} else {
logger.Debugf("successfully register encryption plugin[%s] to %s", KmsAes256AlgorithmName, encryptionHandler.GetHandlerName())
}
if err := encryptionHandler.RegisterPlugin(&KmsBasePlugin{kmsPlugin{kmsClient: innerKmsClient}}); err != nil {
logger.Errorf("failed to register encryption plugin[%s] to %s", KmsAlgorithmName, encryptionHandler.GetHandlerName())
} else {
logger.Debugf("successfully register encryption plugin[%s] to %s", KmsAlgorithmName, encryptionHandler.GetHandlerName())
}
}

type DefaultHandler struct {
type KmsHandler struct {
encryptionPlugins map[string]Plugin
}

func (d *DefaultHandler) EncryptionHandler(param *HandlerParam) error {
func (d *KmsHandler) EncryptionHandler(param *HandlerParam) error {
if err := d.encryptionParamCheck(*param); err != nil {
if err == DataIdParamCheckError || err == ContentParamCheckError {
return nil
}
return err
}
plugin, err := d.getPluginByDataIdPrefix(param.DataId)
Expand All @@ -90,11 +107,8 @@ func (d *DefaultHandler) EncryptionHandler(param *HandlerParam) error {
return plugin.Encrypt(param)
}

func (d *DefaultHandler) DecryptionHandler(param *HandlerParam) error {
func (d *KmsHandler) DecryptionHandler(param *HandlerParam) error {
if err := d.decryptionParamCheck(*param); err != nil {
if err == DataIdParamCheckError || err == ContentParamCheckError {
return nil
}
return err
}
plugin, err := d.getPluginByDataIdPrefix(param.DataId)
Expand All @@ -109,7 +123,7 @@ func (d *DefaultHandler) DecryptionHandler(param *HandlerParam) error {
return plugin.Decrypt(param)
}

func (d *DefaultHandler) getPluginByDataIdPrefix(dataId string) (Plugin, error) {
func (d *KmsHandler) getPluginByDataIdPrefix(dataId string) (Plugin, error) {
var (
matchedCount int
matchedPlugin Plugin
Expand All @@ -128,7 +142,7 @@ func (d *DefaultHandler) getPluginByDataIdPrefix(dataId string) (Plugin, error)
return matchedPlugin, nil
}

func (d *DefaultHandler) RegisterPlugin(plugin Plugin) error {
func (d *KmsHandler) RegisterPlugin(plugin Plugin) error {
if _, v := d.encryptionPlugins[plugin.AlgorithmName()]; v {
logger.Warnf("encryption algorithm [%s] has already registered to defaultHandler, will be update", plugin.AlgorithmName())
} else {
Expand All @@ -138,7 +152,11 @@ func (d *DefaultHandler) RegisterPlugin(plugin Plugin) error {
return nil
}

func (d *DefaultHandler) encryptionParamCheck(param HandlerParam) error {
func (d *KmsHandler) GetHandlerName() string {
return KmsHandlerName
}

func (d *KmsHandler) encryptionParamCheck(param HandlerParam) error {
if err := d.dataIdParamCheck(param.DataId); err != nil {
return DataIdParamCheckError
}
Expand All @@ -148,26 +166,52 @@ func (d *DefaultHandler) encryptionParamCheck(param HandlerParam) error {
return nil
}

func (d *DefaultHandler) decryptionParamCheck(param HandlerParam) error {
if err := d.dataIdParamCheck(param.DataId); err != nil {
return DataIdParamCheckError
}
if err := d.contentParamCheck(param.Content); err != nil {
return ContentParamCheckError
func (d *KmsHandler) decryptionParamCheck(param HandlerParam) error {
return d.encryptionParamCheck(param)
}

func (d *KmsHandler) keyIdParamCheck(keyId string) error {
if len(keyId) == 0 {
return fmt.Errorf("cipher dataId using kmsService need to set keyId, but keyId is nil")
}
return nil
}

func (d *DefaultHandler) dataIdParamCheck(dataId string) error {
func (d *KmsHandler) dataIdParamCheck(dataId string) error {
if !strings.Contains(dataId, CipherPrefix) {
return fmt.Errorf("dataId prefix should start with: %s", CipherPrefix)
}
return nil
}

func (d *DefaultHandler) contentParamCheck(content string) error {
func (d *KmsHandler) contentParamCheck(content string) error {
if len(content) == 0 {
return fmt.Errorf("content need to encrypt is nil")
}
return nil
}

func innerNewKmsClient(clientConfig constant.ClientConfig) (kmsClient *KmsClient, err error) {
switch clientConfig.KMSVersion {
case constant.KMSv1, constant.DEFAULT_KMS_VERSION:
kmsClient, err = newKmsV1Client(clientConfig)
case constant.KMSv3:
kmsClient, err = newKmsV3Client(clientConfig)
default:
err = fmt.Errorf("init kms client failed. unknown kms version:%s\n", clientConfig.KMSVersion)
}
return kmsClient, err
}

func newKmsV1Client(clientConfig constant.ClientConfig) (*KmsClient, error) {
return NewKmsV1ClientWithAccessKey(clientConfig.RegionId, clientConfig.AccessKey, clientConfig.SecretKey)
}

func newKmsV3Client(clientConfig constant.ClientConfig) (*KmsClient, error) {
return NewKmsV3ClientWithConfig(&dkms_api.Config{
Protocol: tea.String("https"),
Endpoint: tea.String(clientConfig.KMSv3Config.Endpoint),
ClientKeyContent: tea.String(clientConfig.KMSv3Config.ClientKeyContent),
Password: tea.String(clientConfig.KMSv3Config.Password),
}, clientConfig.KMSv3Config.CaContent)
}
Loading

0 comments on commit 0d49c7e

Please sign in to comment.