Skip to content

Commit b696c9d

Browse files
committed
test: refactor tests
TODO: Canonical output values
1 parent f17ae7f commit b696c9d

File tree

1 file changed

+23
-87
lines changed

1 file changed

+23
-87
lines changed

fastembed_test.go

+23-87
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,15 @@ package fastembed
22

33
import (
44
"fmt"
5-
"reflect"
65
"testing"
76
)
87

9-
func TestNewFlagEmbedding(t *testing.T) {
10-
// Test with default options
11-
_, err := NewFlagEmbedding(&InitOptions{})
12-
if err != nil {
13-
t.Fatalf("Expected no error, got %v", err)
14-
}
15-
}
16-
17-
func TestEmbed(t *testing.T) {
8+
// TODO: Added canonical tests for all models
9+
func TestEmbedBGESmallEN(t *testing.T) {
1810
// Test with a single input
19-
fe, err := NewFlagEmbedding(&InitOptions{})
11+
fe, err := NewFlagEmbedding(&InitOptions{
12+
Model: BGESmallEN,
13+
})
2014
defer fe.Destroy()
2115
if err != nil {
2216
t.Fatalf("Expected no error, got %v", err)
@@ -32,100 +26,42 @@ func TestEmbed(t *testing.T) {
3226
}
3327
}
3428

35-
func TestQueryEmbed(t *testing.T) {
29+
func TestEmbedBGEBaseEN(t *testing.T) {
3630
// Test with a single input
37-
fe, err := NewFlagEmbedding(&InitOptions{})
31+
fe, err := NewFlagEmbedding(&InitOptions{
32+
Model: BGEBaseEN,
33+
})
34+
defer fe.Destroy()
3835
if err != nil {
3936
t.Fatalf("Expected no error, got %v", err)
4037
}
41-
input := "Hello, world!"
42-
result, err := fe.QueryEmbed(input)
38+
input := []string{"Is the world doing okay?"}
39+
result, err := fe.Embed(input, 1)
4340
if err != nil {
4441
t.Fatalf("Expected no error, got %v", err)
4542
}
46-
if len(result) == 0 {
47-
t.Errorf("Expected non-empty result")
43+
fmt.Println(result[0][0])
44+
if len(result) != len(input) {
45+
t.Errorf("Expected result length %v, got %v", len(input), len(result))
4846
}
4947
}
5048

51-
func TestPassageEmbed(t *testing.T) {
49+
func TestEmbedAllMiniLML6V2(t *testing.T) {
5250
// Test with a single input
53-
fe, err := NewFlagEmbedding(&InitOptions{})
51+
fe, err := NewFlagEmbedding(&InitOptions{
52+
Model: AllMiniLML6V2,
53+
})
54+
defer fe.Destroy()
5455
if err != nil {
5556
t.Fatalf("Expected no error, got %v", err)
5657
}
57-
input := []string{"Hello, world!"}
58-
result, err := fe.PassageEmbed(input, 1)
58+
input := []string{"Is the world doing okay?"}
59+
result, err := fe.Embed(input, 1)
5960
if err != nil {
6061
t.Fatalf("Expected no error, got %v", err)
6162
}
63+
fmt.Println(result[0][0])
6264
if len(result) != len(input) {
6365
t.Errorf("Expected result length %v, got %v", len(input), len(result))
6466
}
6567
}
66-
67-
func TestEncodingToInt32(t *testing.T) {
68-
inputA := []int{1, 2, 3}
69-
inputB := []int{4, 5, 6}
70-
inputC := []int{7, 8, 9}
71-
outputA, outputB, outputC := encodingToInt32(inputA, inputB, inputC)
72-
expectedA := []int64{1, 2, 3}
73-
expectedB := []int64{4, 5, 6}
74-
expectedC := []int64{7, 8, 9}
75-
if !reflect.DeepEqual(outputA, expectedA) {
76-
t.Errorf("Expected %v, got %v", expectedA, outputA)
77-
}
78-
if !reflect.DeepEqual(outputB, expectedB) {
79-
t.Errorf("Expected %v, got %v", expectedB, outputB)
80-
}
81-
if !reflect.DeepEqual(outputC, expectedC) {
82-
t.Errorf("Expected %v, got %v", expectedC, outputC)
83-
}
84-
}
85-
86-
// // Define the canonical vector values as a map
87-
// var canonicalVectorValues = map[string][]float64{
88-
// "BAAI/bge-small-en": {-0.0232, -0.0255, 0.0174, -0.0639, -0.0006},
89-
// "BAAI/bge-base-en": {0.0115, 0.0372, 0.0295, 0.0121, 0.0346},
90-
// "sentence-transformers/all-MiniLM-L6-v2": {0.0259, 0.0058, 0.0114, 0.0380, -0.0233},
91-
// "intfloat/multilingual-e5-large": {0.0098, 0.0045, 0.0066, -0.0354, 0.0070},
92-
// }
93-
94-
// // Define the test for default embedding
95-
// func TestDefaultEmbedding(t *testing.T) {
96-
// for _, modelDesc := range Embedding.ListSupportedModels() {
97-
// dim := modelDesc["dim"]
98-
// model := DefaultEmbedding(modelDesc["model"])
99-
100-
// docs := []string{"hello world", "flag embedding"}
101-
// embeddings := model.Embed(docs)
102-
// if len(embeddings) != 2 || len(embeddings[0]) != dim {
103-
// t.Errorf("Expected embeddings shape (2, %v), got (%v, %v)", dim, len(embeddings), len(embeddings[0]))
104-
// }
105-
106-
// canonicalVector := canonicalVectorValues[modelDesc["model"]]
107-
// for i, val := range embeddings[0][:len(canonicalVector)] {
108-
// if math.Abs(val-canonicalVector[i]) > 1e-3 {
109-
// t.Errorf("Expected %v, got %v", canonicalVector[i], val)
110-
// }
111-
// }
112-
// }
113-
// }
114-
115-
// // Define the test for batch embedding
116-
// func TestBatchEmbedding(t *testing.T) {
117-
// model := DefaultEmbedding()
118-
119-
// docs := make([]string, 200)
120-
// for i := range docs {
121-
// if i%2 == 0 {
122-
// docs[i] = "hello world"
123-
// } else {
124-
// docs[i] = "flag embedding"
125-
// }
126-
// }
127-
// embeddings := model.Embed(docs, 10)
128-
// if len(embeddings) != 200 || len(embeddings[0]) != 384 {
129-
// t.Errorf("Expected embeddings shape (200, 384), got (%v, %v)", len(embeddings), len(embeddings[0]))
130-
// }
131-
// }

0 commit comments

Comments
 (0)