Skip to content

Commit

Permalink
test: add more tests and sentinel errors (#66)
Browse files Browse the repository at this point in the history
* test: add more tests and sentinel errors

* missing godoc

* missing godoc

* Update types/errors.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* linter

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Marko <[email protected]>
  • Loading branch information
3 people authored Feb 25, 2025
1 parent 0f31609 commit 15978b7
Show file tree
Hide file tree
Showing 3 changed files with 272 additions and 6 deletions.
48 changes: 44 additions & 4 deletions test/dummy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@ import (
"bytes"
"context"
"crypto/sha512"
"fmt"
"regexp"
"slices"
"sync"
"time"

"github.com/rollkit/go-execution/types"
)

var validChainIDRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9-]*`)

// DummyExecutor is a dummy implementation of the DummyExecutor interface for testing
type DummyExecutor struct {
mu sync.RWMutex // Add mutex for thread safety
mu sync.RWMutex
stateRoot types.Hash
pendingRoots map[uint64]types.Hash
maxBytes uint64
Expand All @@ -36,6 +38,22 @@ func (e *DummyExecutor) InitChain(ctx context.Context, genesisTime time.Time, in
e.mu.Lock()
defer e.mu.Unlock()

if initialHeight == 0 {
return types.Hash{}, 0, types.ErrZeroInitialHeight
}
if chainID == "" {
return types.Hash{}, 0, types.ErrEmptyChainID
}
if !validChainIDRegex.MatchString(chainID) {
return types.Hash{}, 0, types.ErrInvalidChainID
}
if genesisTime.After(time.Now()) {
return types.Hash{}, 0, types.ErrFutureGenesisTime
}
if len(chainID) > 32 {
return types.Hash{}, 0, types.ErrChainIDTooLong
}

hash := sha512.New()
hash.Write(e.stateRoot)
e.stateRoot = hash.Sum(nil)
Expand All @@ -48,7 +66,7 @@ func (e *DummyExecutor) GetTxs(context.Context) ([]types.Tx, error) {
defer e.mu.RUnlock()

txs := make([]types.Tx, len(e.injectedTxs))
copy(txs, e.injectedTxs) // Create a copy to avoid external modifications
copy(txs, e.injectedTxs)
return txs, nil
}

Expand All @@ -65,6 +83,28 @@ func (e *DummyExecutor) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHei
e.mu.Lock()
defer e.mu.Unlock()

if bytes.Equal(prevStateRoot, types.Hash{}) {
return types.Hash{}, 0, types.ErrEmptyStateRoot
}

// Don't really allow future block times, but allow up to 5 minutes in the future
// for testing purposes.
if timestamp.After(time.Now().Add(5 * time.Minute)) {
return types.Hash{}, 0, types.ErrFutureBlockTime
}
if blockHeight == 0 {
return types.Hash{}, 0, types.ErrInvalidBlockHeight
}

for _, tx := range txs {
if len(tx) == 0 {
return types.Hash{}, 0, types.ErrEmptyTx
}
if uint64(len(tx)) > e.maxBytes {
return types.Hash{}, 0, types.ErrTxTooLarge
}
}

hash := sha512.New()
hash.Write(prevStateRoot)
for _, tx := range txs {
Expand All @@ -86,7 +126,7 @@ func (e *DummyExecutor) SetFinal(ctx context.Context, blockHeight uint64) error
delete(e.pendingRoots, blockHeight)
return nil
}
return fmt.Errorf("cannot set finalized block at height %d", blockHeight)
return types.ErrBlockNotFound
}

func (e *DummyExecutor) removeExecutedTxs(txs []types.Tx) {
Expand Down
174 changes: 172 additions & 2 deletions test/dummy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package test

import (
"context"
"fmt"
"strings"
"sync"
"testing"
"time"

Expand All @@ -25,7 +28,8 @@ func TestDummySuite(t *testing.T) {
suite.Run(t, new(DummyTestSuite))
}

func TestTxRemoval(t *testing.T) {
func (s *DummyTestSuite) TestTxRemoval() {
t := s.T()
exec := NewDummyExecutor()
tx1 := types.Tx([]byte{1, 2, 3})
tx2 := types.Tx([]byte{3, 2, 1})
Expand All @@ -47,7 +51,8 @@ func TestTxRemoval(t *testing.T) {
require.Contains(t, txs, tx1)
require.Contains(t, txs, tx2)

state, _, err := exec.ExecuteTxs(context.Background(), []types.Tx{tx1}, 1, time.Now(), nil)
dummyStateRoot := []byte("dummy-state-root")
state, _, err := exec.ExecuteTxs(context.Background(), []types.Tx{tx1}, 1, time.Now(), dummyStateRoot)
require.NoError(t, err)
require.NotEmpty(t, state)

Expand All @@ -58,3 +63,168 @@ func TestTxRemoval(t *testing.T) {
require.NotContains(t, txs, tx1)
require.Contains(t, txs, tx2)
}

func (s *DummyTestSuite) TestExecuteTxsComprehensive() {
t := s.T()
tests := []struct {
name string
txs []types.Tx
blockHeight uint64
timestamp time.Time
prevStateRoot types.Hash
expectedErr error
}{
{
name: "valid multiple transactions",
txs: []types.Tx{[]byte("tx1"), []byte("tx2"), []byte("tx3")},
blockHeight: 1,
timestamp: time.Now().UTC(),
prevStateRoot: types.Hash{1, 2, 3},
expectedErr: nil,
},
{
name: "empty state root",
txs: []types.Tx{[]byte("tx1")},
blockHeight: 1,
timestamp: time.Now().UTC(),
prevStateRoot: types.Hash{},
expectedErr: types.ErrEmptyStateRoot,
},
{
name: "future timestamp",
txs: []types.Tx{[]byte("tx1")},
blockHeight: 1,
timestamp: time.Now().Add(24 * time.Hour),
prevStateRoot: types.Hash{1, 2, 3},
expectedErr: types.ErrFutureBlockTime,
},
{
name: "empty transaction",
txs: []types.Tx{[]byte("")},
blockHeight: 1,
timestamp: time.Now().UTC(),
prevStateRoot: types.Hash{1, 2, 3},
expectedErr: types.ErrEmptyTx,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
stateRoot, maxBytes, err := s.Exec.ExecuteTxs(context.Background(), tt.txs, tt.blockHeight, tt.timestamp, tt.prevStateRoot)
if tt.expectedErr != nil {
require.ErrorIs(t, err, tt.expectedErr)
return
}
require.NoError(t, err)
require.NotEqual(t, types.Hash{}, stateRoot)
require.Greater(t, maxBytes, uint64(0))
})
}
}

func (s *DummyTestSuite) TestInitChain() {
t := s.T()
tests := []struct {
name string
genesisTime time.Time
initialHeight uint64
chainID string
expectedErr error
}{
{
name: "valid case",
genesisTime: time.Now().UTC(),
initialHeight: 1,
chainID: "test-chain",
expectedErr: nil,
},
{
name: "very large initial height",
genesisTime: time.Now().UTC(),
initialHeight: 1000000,
chainID: "test-chain",
expectedErr: nil,
},
{
name: "zero height",
genesisTime: time.Now().UTC(),
initialHeight: 0,
chainID: "test-chain",
expectedErr: types.ErrZeroInitialHeight,
},
{
name: "empty chain ID",
genesisTime: time.Now().UTC(),
initialHeight: 1,
chainID: "",
expectedErr: types.ErrEmptyChainID,
},
{
name: "future genesis time",
genesisTime: time.Now().Add(1 * time.Hour),
initialHeight: 1,
chainID: "test-chain",
expectedErr: types.ErrFutureGenesisTime,
},
{
name: "invalid chain ID characters",
genesisTime: time.Now().UTC(),
initialHeight: 1,
chainID: "@invalid",
expectedErr: types.ErrInvalidChainID,
},
{
name: "invalid chain ID length",
genesisTime: time.Now().UTC(),
initialHeight: 1,
chainID: strings.Repeat("a", 50),
expectedErr: types.ErrChainIDTooLong,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
stateRoot, maxBytes, err := s.Exec.InitChain(context.Background(), tt.genesisTime, tt.initialHeight, tt.chainID)
if tt.expectedErr != nil {
require.ErrorIs(t, err, tt.expectedErr)
return
}
require.NoError(t, err)
require.NotEqual(t, types.Hash{}, stateRoot)
require.Greater(t, maxBytes, uint64(0))
})
}
}

func (s *DummyTestSuite) TestGetTxsWithConcurrency() {
t := s.T()
const numGoroutines = 10
const txsPerGoroutine = 100

var wg sync.WaitGroup
wg.Add(numGoroutines)

// Inject transactions concurrently
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
for j := 0; j < txsPerGoroutine; j++ {
tx := types.Tx([]byte(fmt.Sprintf("tx-%d-%d", id, j)))
s.TxInjector.InjectTx(tx)
}
}(i)
}
wg.Wait()

// Verify all transactions are retrievable
txs, err := s.Exec.GetTxs(context.Background())
require.NoError(t, err)
require.Len(t, txs, numGoroutines*txsPerGoroutine)

// Verify transaction uniqueness
txMap := make(map[string]struct{})
for _, tx := range txs {
txMap[string(tx)] = struct{}{}
}
require.Len(t, txMap, numGoroutines*txsPerGoroutine)
}
56 changes: 56 additions & 0 deletions types/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package types

import "errors"

var (
// Chain initialization errors

// ErrZeroInitialHeight is returned when the initial height is zero
ErrZeroInitialHeight = errors.New("initial height cannot be zero")
// ErrEmptyChainID is returned when the chain ID is empty
ErrEmptyChainID = errors.New("chain ID cannot be empty")
// ErrInvalidChainID is returned when the chain ID contains invalid characters
ErrInvalidChainID = errors.New("chain ID contains invalid characters")
// ErrChainIDTooLong is returned when the chain ID exceeds maximum length
ErrChainIDTooLong = errors.New("chain ID exceeds maximum length")
// ErrFutureGenesisTime is returned when the genesis time is in the future
ErrFutureGenesisTime = errors.New("genesis time cannot be in the future")

// Transaction execution errors

// ErrEmptyStateRoot is returned when the previous state root is empty
ErrEmptyStateRoot = errors.New("previous state root cannot be empty")
// ErrFutureBlockTime is returned when the block timestamp is in the future
ErrFutureBlockTime = errors.New("block timestamp cannot be in the future")
// ErrInvalidBlockHeight is returned when the block height is invalid
ErrInvalidBlockHeight = errors.New("invalid block height")
// ErrTxTooLarge is returned when the transaction size exceeds maximum allowed
ErrTxTooLarge = errors.New("transaction size exceeds maximum allowed")
// ErrEmptyTx is returned when the transaction is empty
ErrEmptyTx = errors.New("transaction cannot be empty")

// Block finalization errors

// ErrBlockNotFound is returned when the block is not found
ErrBlockNotFound = errors.New("block not found")
// ErrBlockAlreadyExists is returned when the block already exists
ErrBlockAlreadyExists = errors.New("block already exists")
// ErrNonSequentialBlock is returned when the block height is not sequential
ErrNonSequentialBlock = errors.New("non-sequential block height")

// Transaction pool errors

// ErrTxAlreadyExists is returned when the transaction already exists in pool
ErrTxAlreadyExists = errors.New("transaction already exists in pool")
// ErrTxPoolFull is returned when the transaction pool is full
ErrTxPoolFull = errors.New("transaction pool is full")
// ErrInvalidTxFormat is returned when the transaction format is invalid
ErrInvalidTxFormat = errors.New("invalid transaction format")

// Context errors

// ErrContextCanceled is returned when the context is canceled
ErrContextCanceled = errors.New("context canceled")
// ErrContextTimeout is returned when the context deadline is exceeded
ErrContextTimeout = errors.New("context deadline exceeded")
)

0 comments on commit 15978b7

Please sign in to comment.