diff --git a/concurrency/dir/dir.go b/concurrency/dir/dir.go new file mode 100644 index 0000000..3af9e85 --- /dev/null +++ b/concurrency/dir/dir.go @@ -0,0 +1,90 @@ +/* +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package dir + +import ( + "fmt" + "os" + "path/filepath" + "time" + + "github.com/dapr/kit/logger" +) + +type Options struct { + Log logger.Logger + Target string +} + +// Dir atomically writes files to a given directory. +type Dir struct { + log logger.Logger + + base string + target string + targetDir string + + prev *string +} + +func New(opts Options) *Dir { + return &Dir{ + log: opts.Log, + base: filepath.Dir(opts.Target), + target: opts.Target, + targetDir: filepath.Base(opts.Target), + } +} + +func (d *Dir) Write(files map[string][]byte) error { + newDir := filepath.Join(d.base, fmt.Sprintf("%d-%s", time.Now().UTC().UnixNano(), d.targetDir)) + + if err := os.MkdirAll(d.base, os.ModePerm); err != nil { + return err + } + + if err := os.MkdirAll(newDir, os.ModePerm); err != nil { + return err + } + + for file, b := range files { + path := filepath.Join(newDir, file) + if err := os.WriteFile(path, b, os.ModePerm); err != nil { + return err + } + d.log.Infof("Written file %s", file) + } + + if err := os.Symlink(newDir, d.target+".new"); err != nil { + return err + } + + d.log.Infof("Syslink %s to %s.new", newDir, d.target) + + if err := os.Rename(d.target+".new", d.target); err != nil { + return err + } + + d.log.Infof("Atomic write to %s", d.target) + + if d.prev != nil { + if err := os.RemoveAll(*d.prev); err != nil { + return err + } + } + + d.prev = &newDir + + return nil +} diff --git a/crypto/spiffe/spiffe.go b/crypto/spiffe/spiffe.go index 4904272..51109aa 100644 --- a/crypto/spiffe/spiffe.go +++ b/crypto/spiffe/spiffe.go @@ -28,6 +28,9 @@ import ( "github.com/spiffe/go-spiffe/v2/svid/x509svid" "k8s.io/utils/clock" + "github.com/dapr/kit/concurrency/dir" + "github.com/dapr/kit/crypto/pem" + "github.com/dapr/kit/crypto/spiffe/trustanchors" "github.com/dapr/kit/logger" ) @@ -38,6 +41,14 @@ type ( type Options struct { Log logger.Logger RequestSVIDFn RequestSVIDFn + + // WriteIdentityToFile is used to write the identity private key and + // certificate chain to file. The certificate chain and private key will be + // written to the `tls.cert` and `tls.key` files respectively in the given + // directory. + WriteIdentityToFile *string + + TrustAnchors trustanchors.Interface } // SPIFFE is a readable/writeable store of a SPIFFE X.509 SVID. @@ -46,6 +57,9 @@ type SPIFFE struct { currentSVID *x509svid.SVID requestSVIDFn RequestSVIDFn + dir *dir.Dir + trustAnchors trustanchors.Interface + log logger.Logger lock sync.RWMutex clock clock.Clock @@ -54,8 +68,18 @@ type SPIFFE struct { } func New(opts Options) *SPIFFE { + var sdir *dir.Dir + if opts.WriteIdentityToFile != nil { + sdir = dir.New(dir.Options{ + Log: opts.Log, + Target: *opts.WriteIdentityToFile, + }) + } + return &SPIFFE{ requestSVIDFn: opts.RequestSVIDFn, + dir: sdir, + trustAnchors: opts.TrustAnchors, log: opts.Log, clock: clock.RealClock{}, readyCh: make(chan struct{}), @@ -165,6 +189,31 @@ func (s *SPIFFE) fetchIdentityCertificate(ctx context.Context) (*x509svid.SVID, return nil, fmt.Errorf("error parsing spiffe id from newly signed certificate: %w", err) } + if s.dir != nil { + pkPEM, err := pem.EncodePrivateKey(key) + if err != nil { + return nil, err + } + + certPEM, err := pem.EncodeX509Chain(workloadcert) + if err != nil { + return nil, err + } + + td, err := s.trustAnchors.CurrentTrustAnchors(ctx) + if err != nil { + return nil, err + } + + if err := s.dir.Write(map[string][]byte{ + "key.pem": pkPEM, + "cert.pem": certPEM, + "ca.pem": td, + }); err != nil { + return nil, err + } + } + return &x509svid.SVID{ ID: spiffeID, Certificates: workloadcert,