diff --git a/test/dummy.go b/test/dummy.go index b9b5699..f290826 100644 --- a/test/dummy.go +++ b/test/dummy.go @@ -4,7 +4,7 @@ import ( "bytes" "context" "crypto/sha512" - "fmt" + "regexp" "slices" "sync" "time" @@ -12,9 +12,11 @@ import ( "github.com/rollkit/go-execution/types" ) +var validChainIDRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9-]*`) + // DummyExecutor is a dummy implementation of the DummyExecutor interface for testing type DummyExecutor struct { - mu sync.RWMutex // Add mutex for thread safety + mu sync.RWMutex stateRoot types.Hash pendingRoots map[uint64]types.Hash maxBytes uint64 @@ -36,6 +38,22 @@ func (e *DummyExecutor) InitChain(ctx context.Context, genesisTime time.Time, in e.mu.Lock() defer e.mu.Unlock() + if initialHeight == 0 { + return types.Hash{}, 0, types.ErrZeroInitialHeight + } + if chainID == "" { + return types.Hash{}, 0, types.ErrEmptyChainID + } + if !validChainIDRegex.MatchString(chainID) { + return types.Hash{}, 0, types.ErrInvalidChainID + } + if genesisTime.After(time.Now()) { + return types.Hash{}, 0, types.ErrFutureGenesisTime + } + if len(chainID) > 32 { + return types.Hash{}, 0, types.ErrChainIDTooLong + } + hash := sha512.New() hash.Write(e.stateRoot) e.stateRoot = hash.Sum(nil) @@ -48,7 +66,7 @@ func (e *DummyExecutor) GetTxs(context.Context) ([]types.Tx, error) { defer e.mu.RUnlock() txs := make([]types.Tx, len(e.injectedTxs)) - copy(txs, e.injectedTxs) // Create a copy to avoid external modifications + copy(txs, e.injectedTxs) return txs, nil } @@ -65,6 +83,28 @@ func (e *DummyExecutor) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHei e.mu.Lock() defer e.mu.Unlock() + if bytes.Equal(prevStateRoot, types.Hash{}) { + return types.Hash{}, 0, types.ErrEmptyStateRoot + } + + // Don't really allow future block times, but allow up to 5 minutes in the future + // for testing purposes. + if timestamp.After(time.Now().Add(5 * time.Minute)) { + return types.Hash{}, 0, types.ErrFutureBlockTime + } + if blockHeight == 0 { + return types.Hash{}, 0, types.ErrInvalidBlockHeight + } + + for _, tx := range txs { + if len(tx) == 0 { + return types.Hash{}, 0, types.ErrEmptyTx + } + if uint64(len(tx)) > e.maxBytes { + return types.Hash{}, 0, types.ErrTxTooLarge + } + } + hash := sha512.New() hash.Write(prevStateRoot) for _, tx := range txs { @@ -86,7 +126,7 @@ func (e *DummyExecutor) SetFinal(ctx context.Context, blockHeight uint64) error delete(e.pendingRoots, blockHeight) return nil } - return fmt.Errorf("cannot set finalized block at height %d", blockHeight) + return types.ErrBlockNotFound } func (e *DummyExecutor) removeExecutedTxs(txs []types.Tx) { diff --git a/test/dummy_test.go b/test/dummy_test.go index 4c44618..2f7b9e6 100644 --- a/test/dummy_test.go +++ b/test/dummy_test.go @@ -2,6 +2,9 @@ package test import ( "context" + "fmt" + "strings" + "sync" "testing" "time" @@ -25,7 +28,8 @@ func TestDummySuite(t *testing.T) { suite.Run(t, new(DummyTestSuite)) } -func TestTxRemoval(t *testing.T) { +func (s *DummyTestSuite) TestTxRemoval() { + t := s.T() exec := NewDummyExecutor() tx1 := types.Tx([]byte{1, 2, 3}) tx2 := types.Tx([]byte{3, 2, 1}) @@ -47,7 +51,8 @@ func TestTxRemoval(t *testing.T) { require.Contains(t, txs, tx1) require.Contains(t, txs, tx2) - state, _, err := exec.ExecuteTxs(context.Background(), []types.Tx{tx1}, 1, time.Now(), nil) + dummyStateRoot := []byte("dummy-state-root") + state, _, err := exec.ExecuteTxs(context.Background(), []types.Tx{tx1}, 1, time.Now(), dummyStateRoot) require.NoError(t, err) require.NotEmpty(t, state) @@ -58,3 +63,168 @@ func TestTxRemoval(t *testing.T) { require.NotContains(t, txs, tx1) require.Contains(t, txs, tx2) } + +func (s *DummyTestSuite) TestExecuteTxsComprehensive() { + t := s.T() + tests := []struct { + name string + txs []types.Tx + blockHeight uint64 + timestamp time.Time + prevStateRoot types.Hash + expectedErr error + }{ + { + name: "valid multiple transactions", + txs: []types.Tx{[]byte("tx1"), []byte("tx2"), []byte("tx3")}, + blockHeight: 1, + timestamp: time.Now().UTC(), + prevStateRoot: types.Hash{1, 2, 3}, + expectedErr: nil, + }, + { + name: "empty state root", + txs: []types.Tx{[]byte("tx1")}, + blockHeight: 1, + timestamp: time.Now().UTC(), + prevStateRoot: types.Hash{}, + expectedErr: types.ErrEmptyStateRoot, + }, + { + name: "future timestamp", + txs: []types.Tx{[]byte("tx1")}, + blockHeight: 1, + timestamp: time.Now().Add(24 * time.Hour), + prevStateRoot: types.Hash{1, 2, 3}, + expectedErr: types.ErrFutureBlockTime, + }, + { + name: "empty transaction", + txs: []types.Tx{[]byte("")}, + blockHeight: 1, + timestamp: time.Now().UTC(), + prevStateRoot: types.Hash{1, 2, 3}, + expectedErr: types.ErrEmptyTx, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stateRoot, maxBytes, err := s.Exec.ExecuteTxs(context.Background(), tt.txs, tt.blockHeight, tt.timestamp, tt.prevStateRoot) + if tt.expectedErr != nil { + require.ErrorIs(t, err, tt.expectedErr) + return + } + require.NoError(t, err) + require.NotEqual(t, types.Hash{}, stateRoot) + require.Greater(t, maxBytes, uint64(0)) + }) + } +} + +func (s *DummyTestSuite) TestInitChain() { + t := s.T() + tests := []struct { + name string + genesisTime time.Time + initialHeight uint64 + chainID string + expectedErr error + }{ + { + name: "valid case", + genesisTime: time.Now().UTC(), + initialHeight: 1, + chainID: "test-chain", + expectedErr: nil, + }, + { + name: "very large initial height", + genesisTime: time.Now().UTC(), + initialHeight: 1000000, + chainID: "test-chain", + expectedErr: nil, + }, + { + name: "zero height", + genesisTime: time.Now().UTC(), + initialHeight: 0, + chainID: "test-chain", + expectedErr: types.ErrZeroInitialHeight, + }, + { + name: "empty chain ID", + genesisTime: time.Now().UTC(), + initialHeight: 1, + chainID: "", + expectedErr: types.ErrEmptyChainID, + }, + { + name: "future genesis time", + genesisTime: time.Now().Add(1 * time.Hour), + initialHeight: 1, + chainID: "test-chain", + expectedErr: types.ErrFutureGenesisTime, + }, + { + name: "invalid chain ID characters", + genesisTime: time.Now().UTC(), + initialHeight: 1, + chainID: "@invalid", + expectedErr: types.ErrInvalidChainID, + }, + { + name: "invalid chain ID length", + genesisTime: time.Now().UTC(), + initialHeight: 1, + chainID: strings.Repeat("a", 50), + expectedErr: types.ErrChainIDTooLong, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + stateRoot, maxBytes, err := s.Exec.InitChain(context.Background(), tt.genesisTime, tt.initialHeight, tt.chainID) + if tt.expectedErr != nil { + require.ErrorIs(t, err, tt.expectedErr) + return + } + require.NoError(t, err) + require.NotEqual(t, types.Hash{}, stateRoot) + require.Greater(t, maxBytes, uint64(0)) + }) + } +} + +func (s *DummyTestSuite) TestGetTxsWithConcurrency() { + t := s.T() + const numGoroutines = 10 + const txsPerGoroutine = 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines) + + // Inject transactions concurrently + for i := 0; i < numGoroutines; i++ { + go func(id int) { + defer wg.Done() + for j := 0; j < txsPerGoroutine; j++ { + tx := types.Tx([]byte(fmt.Sprintf("tx-%d-%d", id, j))) + s.TxInjector.InjectTx(tx) + } + }(i) + } + wg.Wait() + + // Verify all transactions are retrievable + txs, err := s.Exec.GetTxs(context.Background()) + require.NoError(t, err) + require.Len(t, txs, numGoroutines*txsPerGoroutine) + + // Verify transaction uniqueness + txMap := make(map[string]struct{}) + for _, tx := range txs { + txMap[string(tx)] = struct{}{} + } + require.Len(t, txMap, numGoroutines*txsPerGoroutine) +} diff --git a/types/errors.go b/types/errors.go new file mode 100644 index 0000000..5f6c5c3 --- /dev/null +++ b/types/errors.go @@ -0,0 +1,56 @@ +package types + +import "errors" + +var ( + // Chain initialization errors + + // ErrZeroInitialHeight is returned when the initial height is zero + ErrZeroInitialHeight = errors.New("initial height cannot be zero") + // ErrEmptyChainID is returned when the chain ID is empty + ErrEmptyChainID = errors.New("chain ID cannot be empty") + // ErrInvalidChainID is returned when the chain ID contains invalid characters + ErrInvalidChainID = errors.New("chain ID contains invalid characters") + // ErrChainIDTooLong is returned when the chain ID exceeds maximum length + ErrChainIDTooLong = errors.New("chain ID exceeds maximum length") + // ErrFutureGenesisTime is returned when the genesis time is in the future + ErrFutureGenesisTime = errors.New("genesis time cannot be in the future") + + // Transaction execution errors + + // ErrEmptyStateRoot is returned when the previous state root is empty + ErrEmptyStateRoot = errors.New("previous state root cannot be empty") + // ErrFutureBlockTime is returned when the block timestamp is in the future + ErrFutureBlockTime = errors.New("block timestamp cannot be in the future") + // ErrInvalidBlockHeight is returned when the block height is invalid + ErrInvalidBlockHeight = errors.New("invalid block height") + // ErrTxTooLarge is returned when the transaction size exceeds maximum allowed + ErrTxTooLarge = errors.New("transaction size exceeds maximum allowed") + // ErrEmptyTx is returned when the transaction is empty + ErrEmptyTx = errors.New("transaction cannot be empty") + + // Block finalization errors + + // ErrBlockNotFound is returned when the block is not found + ErrBlockNotFound = errors.New("block not found") + // ErrBlockAlreadyExists is returned when the block already exists + ErrBlockAlreadyExists = errors.New("block already exists") + // ErrNonSequentialBlock is returned when the block height is not sequential + ErrNonSequentialBlock = errors.New("non-sequential block height") + + // Transaction pool errors + + // ErrTxAlreadyExists is returned when the transaction already exists in pool + ErrTxAlreadyExists = errors.New("transaction already exists in pool") + // ErrTxPoolFull is returned when the transaction pool is full + ErrTxPoolFull = errors.New("transaction pool is full") + // ErrInvalidTxFormat is returned when the transaction format is invalid + ErrInvalidTxFormat = errors.New("invalid transaction format") + + // Context errors + + // ErrContextCanceled is returned when the context is canceled + ErrContextCanceled = errors.New("context canceled") + // ErrContextTimeout is returned when the context deadline is exceeded + ErrContextTimeout = errors.New("context deadline exceeded") +)