Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 35 additions & 14 deletions sidecar/tasks/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sei-protocol/seictl/sidecar/engine"
"github.com/sei-protocol/seilog"
Expand All @@ -38,6 +39,26 @@ func DefaultS3ClientFactory(ctx context.Context, region string) (S3GetObjectAPI,
return s3.NewFromConfig(cfg), nil
}

// S3TransferClient abstracts the transfer manager GetObject call for testing.
// The transfer manager's GetObject downloads byte ranges in parallel and
// reassembles them into a sequential io.Reader, giving us parallel throughput
// with a streaming API.
type S3TransferClient interface {
GetObject(ctx context.Context, input *transfermanager.GetObjectInput, opts ...func(*transfermanager.Options)) (*transfermanager.GetObjectOutput, error)
}

// S3TransferClientFactory builds an S3TransferClient for a given region.
type S3TransferClientFactory func(ctx context.Context, region string) (S3TransferClient, error)

// DefaultS3TransferClientFactory creates a transfermanager.Client backed by a real S3 client.
func DefaultS3TransferClientFactory(ctx context.Context, region string) (S3TransferClient, error) {
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
if err != nil {
return nil, fmt.Errorf("loading AWS config: %w", err)
}
return transfermanager.New(s3.NewFromConfig(cfg)), nil
}

// SnapshotConfig holds S3 coordinates for snapshot download.
type SnapshotConfig struct {
Bucket string
Expand All @@ -48,18 +69,18 @@ type SnapshotConfig struct {

// SnapshotRestorer downloads and extracts a snapshot archive from S3.
type SnapshotRestorer struct {
homeDir string
s3ClientFactory S3ClientFactory
homeDir string
s3TransferClientFactory S3TransferClientFactory
}

// NewSnapshotRestorer creates a restorer targeting the given home directory.
func NewSnapshotRestorer(homeDir string, factory S3ClientFactory) *SnapshotRestorer {
func NewSnapshotRestorer(homeDir string, factory S3TransferClientFactory) *SnapshotRestorer {
if factory == nil {
factory = DefaultS3ClientFactory
factory = DefaultS3TransferClientFactory
}
return &SnapshotRestorer{
homeDir: homeDir,
s3ClientFactory: factory,
homeDir: homeDir,
s3TransferClientFactory: factory,
}
}

Expand All @@ -77,31 +98,32 @@ func (r *SnapshotRestorer) Handler() engine.TaskHandler {

// Restore downloads and extracts the snapshot, skipping if the marker file exists.
// It reads latest.txt from the prefix to resolve the current snapshot object key.
// The transfer manager downloads byte ranges in parallel behind the scenes,
// presenting the data as a sequential io.Reader for streaming extraction.
func (r *SnapshotRestorer) Restore(ctx context.Context, cfg SnapshotConfig) error {
if markerExists(r.homeDir, snapshotMarkerFile) {
restoreLog.Debug("already completed, skipping")
return nil
}

s3Client, err := r.s3ClientFactory(ctx, cfg.Region)
tmClient, err := r.s3TransferClientFactory(ctx, cfg.Region)
if err != nil {
return fmt.Errorf("building S3 client: %w", err)
return fmt.Errorf("building S3 transfer client: %w", err)
}

snapshotKey, err := resolveSnapshotKey(ctx, s3Client, cfg)
snapshotKey, err := resolveSnapshotKey(ctx, tmClient, cfg)
if err != nil {
return err
}

restoreLog.Info("downloading snapshot", "bucket", cfg.Bucket, "key", snapshotKey)
output, err := s3Client.GetObject(ctx, &s3.GetObjectInput{
output, err := tmClient.GetObject(ctx, &transfermanager.GetObjectInput{
Bucket: aws.String(cfg.Bucket),
Key: aws.String(snapshotKey),
})
if err != nil {
return fmt.Errorf("s3 GetObject %s: %w", snapshotKey, err)
}
defer func() { _ = output.Body.Close() }()

snapshotDir := filepath.Join(r.homeDir, "data", "snapshots")
if err := os.MkdirAll(snapshotDir, 0o755); err != nil {
Expand All @@ -118,15 +140,14 @@ func (r *SnapshotRestorer) Restore(ctx context.Context, cfg SnapshotConfig) erro
}

// resolveSnapshotKey reads <prefix>latest.txt to find the current snapshot object key.
func resolveSnapshotKey(ctx context.Context, s3Client S3GetObjectAPI, cfg SnapshotConfig) (string, error) {
out, err := s3Client.GetObject(ctx, &s3.GetObjectInput{
func resolveSnapshotKey(ctx context.Context, tmClient S3TransferClient, cfg SnapshotConfig) (string, error) {
out, err := tmClient.GetObject(ctx, &transfermanager.GetObjectInput{
Bucket: aws.String(cfg.Bucket),
Key: aws.String(cfg.Prefix + "latest.txt"),
})
if err != nil {
return "", fmt.Errorf("reading latest.txt: %w", err)
}
defer func() { _ = out.Body.Close() }()

data, err := io.ReadAll(out.Body)
if err != nil {
Expand Down
31 changes: 15 additions & 16 deletions sidecar/tasks/snapshot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,37 @@ import (
"compress/gzip"
"context"
"fmt"
"io"
"os"
"path/filepath"
"testing"

"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/feature/s3/transfermanager"
)

// mockS3Client implements S3GetObjectAPI for testing.
// mockTransferClient implements S3TransferClient for testing.
// responses maps S3 object keys to their body bytes.
// If a key is not in responses, errDefault is returned.
type mockS3Client struct {
type mockTransferClient struct {
responses map[string][]byte
errDefault error
}

func (m *mockS3Client) GetObject(_ context.Context, in *s3.GetObjectInput, _ ...func(*s3.Options)) (*s3.GetObjectOutput, error) {
func (m *mockTransferClient) GetObject(_ context.Context, in *transfermanager.GetObjectInput, _ ...func(*transfermanager.Options)) (*transfermanager.GetObjectOutput, error) {
key := ""
if in.Key != nil {
key = *in.Key
}
if body, ok := m.responses[key]; ok {
return &s3.GetObjectOutput{Body: io.NopCloser(bytes.NewReader(body))}, nil
return &transfermanager.GetObjectOutput{Body: bytes.NewReader(body)}, nil
}
if m.errDefault != nil {
return nil, m.errDefault
}
return nil, fmt.Errorf("unexpected key: %s", key)
}

func mockS3Factory(client S3GetObjectAPI) S3ClientFactory {
return func(_ context.Context, _ string) (S3GetObjectAPI, error) {
func mockTransferFactory(client S3TransferClient) S3TransferClientFactory {
return func(_ context.Context, _ string) (S3TransferClient, error) {
return client, nil
}
}
Expand Down Expand Up @@ -76,13 +75,13 @@ func TestSnapshotRestoreExtractsArchive(t *testing.T) {
"config.toml": "[p2p]\npersistent_peers = \"\"",
})

client := &mockS3Client{
client := &mockTransferClient{
responses: map[string][]byte{
"snapshots/latest.txt": []byte("100000000"),
"snapshots/snapshot_100000000_testchain_us-east-1.tar.gz": archive,
},
}
restorer := NewSnapshotRestorer(homeDir, mockS3Factory(client))
restorer := NewSnapshotRestorer(homeDir, mockTransferFactory(client))
err := restorer.Restore(context.Background(), SnapshotConfig{
Bucket: "test-bucket",
Prefix: "snapshots/",
Expand Down Expand Up @@ -119,7 +118,7 @@ func TestSnapshotRestoreSkipsWhenMarkerExists(t *testing.T) {
}

// S3 client that would fail if called — proves we skip the download.
restorer := NewSnapshotRestorer(homeDir, mockS3Factory(&mockS3Client{
restorer := NewSnapshotRestorer(homeDir, mockTransferFactory(&mockTransferClient{
errDefault: fmt.Errorf("should not be called"),
}))

Expand All @@ -137,7 +136,7 @@ func TestSnapshotRestoreSkipsWhenMarkerExists(t *testing.T) {
func TestSnapshotRestoreNoMarkerOnLatestTxtError(t *testing.T) {
homeDir := t.TempDir()

restorer := NewSnapshotRestorer(homeDir, mockS3Factory(&mockS3Client{
restorer := NewSnapshotRestorer(homeDir, mockTransferFactory(&mockTransferClient{
errDefault: fmt.Errorf("access denied"),
}))

Expand All @@ -159,13 +158,13 @@ func TestSnapshotRestoreNoMarkerOnLatestTxtError(t *testing.T) {
func TestSnapshotRestoreNoMarkerOnDownloadError(t *testing.T) {
homeDir := t.TempDir()

client := &mockS3Client{
client := &mockTransferClient{
responses: map[string][]byte{
"snapshots/latest.txt": []byte("100000000"),
},
errDefault: fmt.Errorf("access denied"),
}
restorer := NewSnapshotRestorer(homeDir, mockS3Factory(client))
restorer := NewSnapshotRestorer(homeDir, mockTransferFactory(client))

err := restorer.Restore(context.Background(), SnapshotConfig{
Bucket: "test-bucket",
Expand All @@ -188,13 +187,13 @@ func TestSnapshotRestoreRejectsPathTraversal(t *testing.T) {
"../../etc/passwd": "malicious",
})

client := &mockS3Client{
client := &mockTransferClient{
responses: map[string][]byte{
"snapshots/latest.txt": []byte("100000000"),
"snapshots/snapshot_100000000_testchain_us-east-1.tar.gz": archive,
},
}
restorer := NewSnapshotRestorer(homeDir, mockS3Factory(client))
restorer := NewSnapshotRestorer(homeDir, mockTransferFactory(client))
err := restorer.Restore(context.Background(), SnapshotConfig{
Bucket: "test-bucket",
Prefix: "snapshots/",
Expand Down
Loading