Skip to content

Commit df46075

Browse files
erkamarkphelps
andauthored
fix(go): improve wasm memory allocations and deallocations (#790)
Signed-off-by: Mark Phelps <[email protected]> Signed-off-by: Roman Dmytrenko <[email protected]> Co-authored-by: Mark Phelps <[email protected]>
1 parent a531bdd commit df46075

File tree

2 files changed

+71
-51
lines changed

2 files changed

+71
-51
lines changed

Diff for: flipt-client-go/benchmark_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func init() {
7373

7474
func generateLargeContext(size int) map[string]string {
7575
context := make(map[string]string)
76-
for i := 0; i < size; i++ {
76+
for i := range size {
7777
context[fmt.Sprintf("key%d", i)] = fmt.Sprintf("value%d", i)
7878
}
7979
return context

Diff for: flipt-client-go/evaluation.go

+70-50
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@ import (
1919
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
2020
)
2121

22-
var (
23-
//go:embed ext/flipt_engine_wasm.wasm
24-
wasm []byte
25-
)
22+
//go:embed ext/flipt_engine_wasm.wasm
23+
var wasm []byte
2624

2725
const (
2826
statusSuccess = "success"
@@ -65,11 +63,6 @@ type EvaluationClient struct {
6563
errChan chan error
6664
snapshotChan chan snapshot
6765

68-
// cached WASM functions
69-
allocFunc api.Function
70-
deallocFunc api.Function
71-
snapshotFunc api.Function
72-
7366
closeOnce sync.Once
7467
}
7568

@@ -178,13 +171,6 @@ func NewEvaluationClient(ctx context.Context, opts ...ClientOption) (_ *Evaluati
178171
return
179172
}
180173

181-
var (
182-
// cache common WASM functions
183-
allocFunc = mod.ExportedFunction(fAllocate)
184-
deallocFunc = mod.ExportedFunction(fDeallocate)
185-
snapshotFunc = mod.ExportedFunction(fSnapshot)
186-
)
187-
188174
ctx, cancel := context.WithCancel(ctx)
189175

190176
client := &EvaluationClient{
@@ -202,11 +188,6 @@ func NewEvaluationClient(ctx context.Context, opts ...ClientOption) (_ *Evaluati
202188
cancel: cancel,
203189
errChan: make(chan error, 1),
204190
snapshotChan: make(chan snapshot, 1),
205-
206-
// cache WASM functions
207-
allocFunc: allocFunc,
208-
deallocFunc: deallocFunc,
209-
snapshotFunc: snapshotFunc,
210191
}
211192

212193
for _, opt := range opts {
@@ -250,7 +231,10 @@ func NewEvaluationClient(ctx context.Context, opts ...ClientOption) (_ *Evaluati
250231
client.httpClient.Timeout = client.requestTimeout
251232
}
252233

253-
var initializeEngine = mod.ExportedFunction(fInitializeEngine)
234+
var (
235+
initializeEngine = mod.ExportedFunction(fInitializeEngine)
236+
allocFunc = mod.ExportedFunction(fAllocate)
237+
)
254238

255239
// allocate namespace
256240
nsPtr, err := allocFunc.Call(ctx, uint64(len(client.namespace)))
@@ -463,11 +447,11 @@ func (e *EvaluationClient) EvaluateBatch(ctx context.Context, requests []*Evalua
463447
// ListFlags lists all flags.
464448
func (e *EvaluationClient) ListFlags(ctx context.Context) ([]Flag, error) {
465449
e.mu.RLock()
450+
defer e.mu.RUnlock()
451+
466452
if e.err != nil && e.errorStrategy == ErrorStrategyFail {
467-
e.mu.RUnlock()
468453
return nil, e.err
469454
}
470-
e.mu.RUnlock()
471455

472456
if e.engine == 0 {
473457
return nil, errors.New("engine not initialized")
@@ -484,27 +468,35 @@ func (e *EvaluationClient) ListFlags(ctx context.Context) ([]Flag, error) {
484468
}
485469

486470
ptr, length := decodePtr(res[0])
487-
defer e.deallocFunc.Call(ctx, uint64(ptr), uint64(length))
471+
deallocFunc := e.mod.ExportedFunction(fDeallocate)
488472

489473
b, ok := e.mod.Memory().Read(ptr, length)
490474
if !ok {
475+
deallocFunc.Call(ctx, uint64(ptr), uint64(length))
491476
return nil, fmt.Errorf("failed to read result from memory")
492477
}
493478

494-
var result *ListFlagsResult
495-
if err := json.Unmarshal(b, &result); err != nil {
479+
// make a copy of the result before deallocating
480+
result := make([]byte, len(b))
481+
copy(result, b)
482+
483+
// clean up memory
484+
deallocFunc.Call(ctx, uint64(ptr), uint64(length))
485+
486+
var listResult *ListFlagsResult
487+
if err := json.Unmarshal(result, &listResult); err != nil {
496488
return nil, fmt.Errorf("failed to unmarshal flags: %w", err)
497489
}
498490

499-
if result == nil {
491+
if listResult == nil {
500492
return nil, errors.New("failed to unmarshal flags: nil")
501493
}
502494

503-
if result.Status != statusSuccess {
504-
return nil, errors.New(result.ErrorMessage)
495+
if listResult.Status != statusSuccess {
496+
return nil, errors.New(listResult.ErrorMessage)
505497
}
506498

507-
return *result.Result, nil
499+
return *listResult.Result, nil
508500
}
509501

510502
// Close cleans up the allocated resources.
@@ -544,47 +536,57 @@ type snapshot struct {
544536
}
545537

546538
func (e *EvaluationClient) handleUpdates(ctx context.Context) error {
539+
var (
540+
allocFunc = e.mod.ExportedFunction(fAllocate)
541+
deallocFunc = e.mod.ExportedFunction(fDeallocate)
542+
snapshotFunc = e.mod.ExportedFunction(fSnapshot)
543+
)
544+
547545
for {
548546
select {
549547
case <-ctx.Done():
550548
close(e.snapshotChan)
551549
return nil
552550
case s, ok := <-e.snapshotChan:
553551
if !ok {
554-
// we are likely shutting down
555552
return nil
556553
}
557554

558555
e.mu.Lock()
559556
e.etag = s.etag
560557
e.mu.Unlock()
561558

562-
// skip update if no changes (304 response) or error
563559
if len(s.payload) == 0 {
564560
continue
565561
}
566562

563+
e.mu.Lock()
567564
// allocate memory for the new payload
568-
pmPtr, err := e.allocFunc.Call(ctx, uint64(len(s.payload)))
565+
pmPtr, err := allocFunc.Call(ctx, uint64(len(s.payload)))
569566
if err != nil {
567+
e.mu.Unlock()
570568
return fmt.Errorf("failed to allocate memory for payload: %w", err)
571569
}
572570

573571
// write the new payload to memory
574572
if !e.mod.Memory().Write(uint32(pmPtr[0]), s.payload) {
575-
e.deallocFunc.Call(ctx, uint64(pmPtr[0]), uint64(len(s.payload)))
573+
e.mu.Unlock()
574+
deallocFunc.Call(ctx, uint64(pmPtr[0]), uint64(len(s.payload)))
576575
return fmt.Errorf("failed to write payload to memory")
577576
}
578577

579-
// update the engine with the new snapshot
580-
_, err = e.snapshotFunc.Call(ctx, uint64(e.engine), pmPtr[0], uint64(len(s.payload)))
578+
// update the engine with the new snapshot while holding the lock
579+
res, err := snapshotFunc.Call(ctx, uint64(e.engine), pmPtr[0], uint64(len(s.payload)))
580+
581+
ptr, length := decodePtr(res[0])
582+
// always deallocate the memory after we're done with it
583+
deallocFunc.Call(ctx, uint64(pmPtr[0]), uint64(len(s.payload)))
584+
deallocFunc.Call(ctx, uint64(ptr), uint64(length))
585+
586+
e.mu.Unlock()
581587
if err != nil {
582-
e.deallocFunc.Call(ctx, uint64(pmPtr[0]), uint64(len(s.payload)))
583588
return fmt.Errorf("failed to update engine: %w", err)
584589
}
585-
586-
// clean up the memory we allocated for the payload
587-
e.deallocFunc.Call(ctx, uint64(pmPtr[0]), uint64(len(s.payload)))
588590
}
589591
}
590592
}
@@ -665,7 +667,7 @@ func (e *EvaluationClient) fetch(ctx context.Context, etag string) (snapshot, er
665667
}
666668

667669
if resp.StatusCode == http.StatusNotModified {
668-
return snapshot{}, nil
670+
return snapshot{etag: etag}, nil
669671
}
670672

671673
if resp.StatusCode != http.StatusOK {
@@ -734,13 +736,13 @@ func (e *EvaluationClient) startStreaming(ctx context.Context) {
734736
case <-ctx.Done():
735737
return
736738
default:
737-
// Create a channel to receive the read result
739+
// create a channel to receive the read result
738740
readChan := make(chan struct {
739741
line []byte
740742
err error
741743
})
742744

743-
// Start a goroutine to perform the blocking read
745+
// start a goroutine to perform the blocking read
744746
go func() {
745747
line, err := reader.ReadBytes('\n')
746748
readChan <- struct {
@@ -749,7 +751,7 @@ func (e *EvaluationClient) startStreaming(ctx context.Context) {
749751
}{line, err}
750752
}()
751753

752-
// Wait for either the read to complete or context cancellation
754+
// wait for either the read to complete or context cancellation
753755
select {
754756
case <-ctx.Done():
755757
return
@@ -784,42 +786,60 @@ func (e *EvaluationClient) evaluateWASM(ctx context.Context, funcName string, re
784786
return nil, errors.New("engine not initialized")
785787
}
786788

789+
var (
790+
allocFunc = e.mod.ExportedFunction(fAllocate)
791+
deallocFunc = e.mod.ExportedFunction(fDeallocate)
792+
)
793+
787794
reqBytes, err := json.Marshal(request)
788795
if err != nil {
789796
return nil, fmt.Errorf("failed to marshal request: %w", err)
790797
}
791798

792-
reqPtr, err := e.allocFunc.Call(ctx, uint64(len(reqBytes)))
799+
e.mu.Lock()
800+
reqPtr, err := allocFunc.Call(ctx, uint64(len(reqBytes)))
793801
if err != nil {
802+
e.mu.Unlock()
794803
return nil, fmt.Errorf("failed to allocate memory for request: %w", err)
795804
}
796-
defer e.deallocFunc.Call(ctx, reqPtr[0], uint64(len(reqBytes)))
797805

798806
if !e.mod.Memory().Write(uint32(reqPtr[0]), reqBytes) {
807+
deallocFunc.Call(ctx, reqPtr[0], uint64(len(reqBytes)))
808+
e.mu.Unlock()
799809
return nil, fmt.Errorf("failed to write request to memory")
800810
}
801811

802812
evalFunc := e.mod.ExportedFunction(funcName)
803813
res, err := evalFunc.Call(ctx, uint64(e.engine), reqPtr[0], uint64(len(reqBytes)))
804814
if err != nil {
815+
deallocFunc.Call(ctx, reqPtr[0], uint64(len(reqBytes)))
816+
e.mu.Unlock()
805817
return nil, fmt.Errorf("failed to call %s: %w", funcName, err)
806818
}
807819

820+
// clean up request memory
821+
deallocFunc.Call(ctx, reqPtr[0], uint64(len(reqBytes)))
822+
808823
if len(res) < 1 {
824+
e.mu.Unlock()
809825
return nil, fmt.Errorf("failed to call %s: no result returned", funcName)
810826
}
811827

812828
ptr, length := decodePtr(res[0])
813-
defer e.deallocFunc.Call(ctx, uint64(ptr), uint64(length))
814-
815829
b, ok := e.mod.Memory().Read(ptr, length)
816830
if !ok {
831+
deallocFunc.Call(ctx, uint64(ptr), uint64(length))
832+
e.mu.Unlock()
817833
return nil, fmt.Errorf("failed to read result from memory")
818834
}
819835

820-
// Make a copy of the result before deallocating
836+
// make a copy of the result before deallocating
821837
result := make([]byte, len(b))
822838
copy(result, b)
823839

840+
// clean up result memory
841+
deallocFunc.Call(ctx, uint64(ptr), uint64(length))
842+
e.mu.Unlock()
843+
824844
return result, nil
825845
}

0 commit comments

Comments
 (0)