Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: add more tests and sentinel errors #66

Merged
merged 6 commits into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions test/dummy.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,19 @@ import (
"bytes"
"context"
"crypto/sha512"
"fmt"
"regexp"
"slices"
"sync"
"time"

"github.com/rollkit/go-execution/types"
)

var validChainIDRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9-]*`)

// DummyExecutor is a dummy implementation of the DummyExecutor interface for testing
type DummyExecutor struct {
mu sync.RWMutex // Add mutex for thread safety
mu sync.RWMutex
stateRoot types.Hash
pendingRoots map[uint64]types.Hash
maxBytes uint64
Expand All @@ -36,6 +38,22 @@ func (e *DummyExecutor) InitChain(ctx context.Context, genesisTime time.Time, in
e.mu.Lock()
defer e.mu.Unlock()

if initialHeight == 0 {
return types.Hash{}, 0, types.ErrZeroInitialHeight
}
if chainID == "" {
return types.Hash{}, 0, types.ErrEmptyChainID
}
if !validChainIDRegex.MatchString(chainID) {
return types.Hash{}, 0, types.ErrInvalidChainID
}
if genesisTime.After(time.Now()) {
return types.Hash{}, 0, types.ErrFutureGenesisTime
}
if len(chainID) > 32 {
return types.Hash{}, 0, types.ErrChainIDTooLong
}

hash := sha512.New()
hash.Write(e.stateRoot)
e.stateRoot = hash.Sum(nil)
Expand All @@ -48,7 +66,7 @@ func (e *DummyExecutor) GetTxs(context.Context) ([]types.Tx, error) {
defer e.mu.RUnlock()

txs := make([]types.Tx, len(e.injectedTxs))
copy(txs, e.injectedTxs) // Create a copy to avoid external modifications
copy(txs, e.injectedTxs)
return txs, nil
}

Expand All @@ -65,6 +83,28 @@ func (e *DummyExecutor) ExecuteTxs(ctx context.Context, txs []types.Tx, blockHei
e.mu.Lock()
defer e.mu.Unlock()

if bytes.Equal(prevStateRoot, types.Hash{}) {
return types.Hash{}, 0, types.ErrEmptyStateRoot
}

// Don't really allow future block times, but allow up to 5 minutes in the future
// for testing purposes.
if timestamp.After(time.Now().Add(5 * time.Minute)) {
return types.Hash{}, 0, types.ErrFutureBlockTime
}
if blockHeight == 0 {
return types.Hash{}, 0, types.ErrInvalidBlockHeight
}

for _, tx := range txs {
if len(tx) == 0 {
return types.Hash{}, 0, types.ErrEmptyTx
}
if uint64(len(tx)) > e.maxBytes {
return types.Hash{}, 0, types.ErrTxTooLarge
}
}

hash := sha512.New()
hash.Write(prevStateRoot)
for _, tx := range txs {
Expand All @@ -86,7 +126,7 @@ func (e *DummyExecutor) SetFinal(ctx context.Context, blockHeight uint64) error
delete(e.pendingRoots, blockHeight)
return nil
}
return fmt.Errorf("cannot set finalized block at height %d", blockHeight)
return types.ErrBlockNotFound
}

func (e *DummyExecutor) removeExecutedTxs(txs []types.Tx) {
Expand Down
174 changes: 172 additions & 2 deletions test/dummy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package test

import (
"context"
"fmt"
"strings"
"sync"
"testing"
"time"

Expand All @@ -25,7 +28,8 @@ func TestDummySuite(t *testing.T) {
suite.Run(t, new(DummyTestSuite))
}

func TestTxRemoval(t *testing.T) {
func (s *DummyTestSuite) TestTxRemoval() {
t := s.T()
exec := NewDummyExecutor()
tx1 := types.Tx([]byte{1, 2, 3})
tx2 := types.Tx([]byte{3, 2, 1})
Expand All @@ -47,7 +51,8 @@ func TestTxRemoval(t *testing.T) {
require.Contains(t, txs, tx1)
require.Contains(t, txs, tx2)

state, _, err := exec.ExecuteTxs(context.Background(), []types.Tx{tx1}, 1, time.Now(), nil)
dummyStateRoot := []byte("dummy-state-root")
state, _, err := exec.ExecuteTxs(context.Background(), []types.Tx{tx1}, 1, time.Now(), dummyStateRoot)
require.NoError(t, err)
require.NotEmpty(t, state)

Expand All @@ -58,3 +63,168 @@ func TestTxRemoval(t *testing.T) {
require.NotContains(t, txs, tx1)
require.Contains(t, txs, tx2)
}

func (s *DummyTestSuite) TestExecuteTxsComprehensive() {
t := s.T()
tests := []struct {
name string
txs []types.Tx
blockHeight uint64
timestamp time.Time
prevStateRoot types.Hash
expectedErr error
}{
{
name: "valid multiple transactions",
txs: []types.Tx{[]byte("tx1"), []byte("tx2"), []byte("tx3")},
blockHeight: 1,
timestamp: time.Now().UTC(),
prevStateRoot: types.Hash{1, 2, 3},
expectedErr: nil,
},
{
name: "empty state root",
txs: []types.Tx{[]byte("tx1")},
blockHeight: 1,
timestamp: time.Now().UTC(),
prevStateRoot: types.Hash{},
expectedErr: types.ErrEmptyStateRoot,
},
{
name: "future timestamp",
txs: []types.Tx{[]byte("tx1")},
blockHeight: 1,
timestamp: time.Now().Add(24 * time.Hour),
prevStateRoot: types.Hash{1, 2, 3},
expectedErr: types.ErrFutureBlockTime,
},
{
name: "empty transaction",
txs: []types.Tx{[]byte("")},
blockHeight: 1,
timestamp: time.Now().UTC(),
prevStateRoot: types.Hash{1, 2, 3},
expectedErr: types.ErrEmptyTx,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
stateRoot, maxBytes, err := s.Exec.ExecuteTxs(context.Background(), tt.txs, tt.blockHeight, tt.timestamp, tt.prevStateRoot)
if tt.expectedErr != nil {
require.ErrorIs(t, err, tt.expectedErr)
return
}
require.NoError(t, err)
require.NotEqual(t, types.Hash{}, stateRoot)
require.Greater(t, maxBytes, uint64(0))
})
}
}

func (s *DummyTestSuite) TestInitChain() {
t := s.T()
tests := []struct {
name string
genesisTime time.Time
initialHeight uint64
chainID string
expectedErr error
}{
{
name: "valid case",
genesisTime: time.Now().UTC(),
initialHeight: 1,
chainID: "test-chain",
expectedErr: nil,
},
{
name: "very large initial height",
genesisTime: time.Now().UTC(),
initialHeight: 1000000,
chainID: "test-chain",
expectedErr: nil,
},
{
name: "zero height",
genesisTime: time.Now().UTC(),
initialHeight: 0,
chainID: "test-chain",
expectedErr: types.ErrZeroInitialHeight,
},
{
name: "empty chain ID",
genesisTime: time.Now().UTC(),
initialHeight: 1,
chainID: "",
expectedErr: types.ErrEmptyChainID,
},
{
name: "future genesis time",
genesisTime: time.Now().Add(1 * time.Hour),
initialHeight: 1,
chainID: "test-chain",
expectedErr: types.ErrFutureGenesisTime,
},
{
name: "invalid chain ID characters",
genesisTime: time.Now().UTC(),
initialHeight: 1,
chainID: "@invalid",
expectedErr: types.ErrInvalidChainID,
},
{
name: "invalid chain ID length",
genesisTime: time.Now().UTC(),
initialHeight: 1,
chainID: strings.Repeat("a", 50),
expectedErr: types.ErrChainIDTooLong,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
stateRoot, maxBytes, err := s.Exec.InitChain(context.Background(), tt.genesisTime, tt.initialHeight, tt.chainID)
if tt.expectedErr != nil {
require.ErrorIs(t, err, tt.expectedErr)
return
}
require.NoError(t, err)
require.NotEqual(t, types.Hash{}, stateRoot)
require.Greater(t, maxBytes, uint64(0))
})
}
}

func (s *DummyTestSuite) TestGetTxsWithConcurrency() {
t := s.T()
const numGoroutines = 10
const txsPerGoroutine = 100

var wg sync.WaitGroup
wg.Add(numGoroutines)

// Inject transactions concurrently
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
for j := 0; j < txsPerGoroutine; j++ {
tx := types.Tx([]byte(fmt.Sprintf("tx-%d-%d", id, j)))
s.TxInjector.InjectTx(tx)
}
}(i)
}
wg.Wait()

// Verify all transactions are retrievable
txs, err := s.Exec.GetTxs(context.Background())
require.NoError(t, err)
require.Len(t, txs, numGoroutines*txsPerGoroutine)

// Verify transaction uniqueness
txMap := make(map[string]struct{})
for _, tx := range txs {
txMap[string(tx)] = struct{}{}
}
require.Len(t, txMap, numGoroutines*txsPerGoroutine)
}
56 changes: 56 additions & 0 deletions types/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package types

import "errors"

var (
// Chain initialization errors

// ErrZeroInitialHeight is returned when the initial height is zero
ErrZeroInitialHeight = errors.New("initial height cannot be zero")
// ErrEmptyChainID is returned when the chain ID is empty
ErrEmptyChainID = errors.New("chain ID cannot be empty")
// ErrInvalidChainID is returned when the chain ID contains invalid characters
ErrInvalidChainID = errors.New("chain ID contains invalid characters")
// ErrChainIDTooLong is returned when the chain ID exceeds maximum length
ErrChainIDTooLong = errors.New("chain ID exceeds maximum length")
// ErrFutureGenesisTime is returned when the genesis time is in the future
ErrFutureGenesisTime = errors.New("genesis time cannot be in the future")

// Transaction execution errors

// ErrEmptyStateRoot is returned when the previous state root is empty
ErrEmptyStateRoot = errors.New("previous state root cannot be empty")
// ErrFutureBlockTime is returned when the block timestamp is in the future
ErrFutureBlockTime = errors.New("block timestamp cannot be in the future")
// ErrInvalidBlockHeight is returned when the block height is invalid
ErrInvalidBlockHeight = errors.New("invalid block height")
// ErrTxTooLarge is returned when the transaction size exceeds maximum allowed
ErrTxTooLarge = errors.New("transaction size exceeds maximum allowed")
// ErrEmptyTx is returned when the transaction is empty
ErrEmptyTx = errors.New("transaction cannot be empty")

// Block finalization errors

// ErrBlockNotFound is returned when the block is not found
ErrBlockNotFound = errors.New("block not found")
// ErrBlockAlreadyExists is returned when the block already exists
ErrBlockAlreadyExists = errors.New("block already exists")
// ErrNonSequentialBlock is returned when the block height is not sequential
ErrNonSequentialBlock = errors.New("non-sequential block height")

// Transaction pool errors

// ErrTxAlreadyExists is returned when the transaction already exists in pool
ErrTxAlreadyExists = errors.New("transaction already exists in pool")
// ErrTxPoolFull is returned when the transaction pool is full
ErrTxPoolFull = errors.New("transaction pool is full")
// ErrInvalidTxFormat is returned when the transaction format is invalid
ErrInvalidTxFormat = errors.New("invalid transaction format")

// Context errors

// ErrContextCanceled is returned when the context is canceled
ErrContextCanceled = errors.New("context canceled")
// ErrContextTimeout is returned when the context deadline is exceeded
ErrContextTimeout = errors.New("context deadline exceeded")
)