@@ -2,21 +2,15 @@ package fastembed
2
2
3
3
import (
4
4
"fmt"
5
- "reflect"
6
5
"testing"
7
6
)
8
7
9
- func TestNewFlagEmbedding (t * testing.T ) {
10
- // Test with default options
11
- _ , err := NewFlagEmbedding (& InitOptions {})
12
- if err != nil {
13
- t .Fatalf ("Expected no error, got %v" , err )
14
- }
15
- }
16
-
17
- func TestEmbed (t * testing.T ) {
8
+ // TODO: Added canonical tests for all models
9
+ func TestEmbedBGESmallEN (t * testing.T ) {
18
10
// Test with a single input
19
- fe , err := NewFlagEmbedding (& InitOptions {})
11
+ fe , err := NewFlagEmbedding (& InitOptions {
12
+ Model : BGESmallEN ,
13
+ })
20
14
defer fe .Destroy ()
21
15
if err != nil {
22
16
t .Fatalf ("Expected no error, got %v" , err )
@@ -32,100 +26,42 @@ func TestEmbed(t *testing.T) {
32
26
}
33
27
}
34
28
35
- func TestQueryEmbed (t * testing.T ) {
29
+ func TestEmbedBGEBaseEN (t * testing.T ) {
36
30
// Test with a single input
37
- fe , err := NewFlagEmbedding (& InitOptions {})
31
+ fe , err := NewFlagEmbedding (& InitOptions {
32
+ Model : BGEBaseEN ,
33
+ })
34
+ defer fe .Destroy ()
38
35
if err != nil {
39
36
t .Fatalf ("Expected no error, got %v" , err )
40
37
}
41
- input := "Hello, world!"
42
- result , err := fe .QueryEmbed (input )
38
+ input := [] string { "Is the world doing okay?" }
39
+ result , err := fe .Embed (input , 1 )
43
40
if err != nil {
44
41
t .Fatalf ("Expected no error, got %v" , err )
45
42
}
46
- if len (result ) == 0 {
47
- t .Errorf ("Expected non-empty result" )
43
+ fmt .Println (result [0 ][0 ])
44
+ if len (result ) != len (input ) {
45
+ t .Errorf ("Expected result length %v, got %v" , len (input ), len (result ))
48
46
}
49
47
}
50
48
51
- func TestPassageEmbed (t * testing.T ) {
49
+ func TestEmbedAllMiniLML6V2 (t * testing.T ) {
52
50
// Test with a single input
53
- fe , err := NewFlagEmbedding (& InitOptions {})
51
+ fe , err := NewFlagEmbedding (& InitOptions {
52
+ Model : AllMiniLML6V2 ,
53
+ })
54
+ defer fe .Destroy ()
54
55
if err != nil {
55
56
t .Fatalf ("Expected no error, got %v" , err )
56
57
}
57
- input := []string {"Hello, world! " }
58
- result , err := fe .PassageEmbed (input , 1 )
58
+ input := []string {"Is the world doing okay? " }
59
+ result , err := fe .Embed (input , 1 )
59
60
if err != nil {
60
61
t .Fatalf ("Expected no error, got %v" , err )
61
62
}
63
+ fmt .Println (result [0 ][0 ])
62
64
if len (result ) != len (input ) {
63
65
t .Errorf ("Expected result length %v, got %v" , len (input ), len (result ))
64
66
}
65
67
}
66
-
67
- func TestEncodingToInt32 (t * testing.T ) {
68
- inputA := []int {1 , 2 , 3 }
69
- inputB := []int {4 , 5 , 6 }
70
- inputC := []int {7 , 8 , 9 }
71
- outputA , outputB , outputC := encodingToInt32 (inputA , inputB , inputC )
72
- expectedA := []int64 {1 , 2 , 3 }
73
- expectedB := []int64 {4 , 5 , 6 }
74
- expectedC := []int64 {7 , 8 , 9 }
75
- if ! reflect .DeepEqual (outputA , expectedA ) {
76
- t .Errorf ("Expected %v, got %v" , expectedA , outputA )
77
- }
78
- if ! reflect .DeepEqual (outputB , expectedB ) {
79
- t .Errorf ("Expected %v, got %v" , expectedB , outputB )
80
- }
81
- if ! reflect .DeepEqual (outputC , expectedC ) {
82
- t .Errorf ("Expected %v, got %v" , expectedC , outputC )
83
- }
84
- }
85
-
86
- // // Define the canonical vector values as a map
87
- // var canonicalVectorValues = map[string][]float64{
88
- // "BAAI/bge-small-en": {-0.0232, -0.0255, 0.0174, -0.0639, -0.0006},
89
- // "BAAI/bge-base-en": {0.0115, 0.0372, 0.0295, 0.0121, 0.0346},
90
- // "sentence-transformers/all-MiniLM-L6-v2": {0.0259, 0.0058, 0.0114, 0.0380, -0.0233},
91
- // "intfloat/multilingual-e5-large": {0.0098, 0.0045, 0.0066, -0.0354, 0.0070},
92
- // }
93
-
94
- // // Define the test for default embedding
95
- // func TestDefaultEmbedding(t *testing.T) {
96
- // for _, modelDesc := range Embedding.ListSupportedModels() {
97
- // dim := modelDesc["dim"]
98
- // model := DefaultEmbedding(modelDesc["model"])
99
-
100
- // docs := []string{"hello world", "flag embedding"}
101
- // embeddings := model.Embed(docs)
102
- // if len(embeddings) != 2 || len(embeddings[0]) != dim {
103
- // t.Errorf("Expected embeddings shape (2, %v), got (%v, %v)", dim, len(embeddings), len(embeddings[0]))
104
- // }
105
-
106
- // canonicalVector := canonicalVectorValues[modelDesc["model"]]
107
- // for i, val := range embeddings[0][:len(canonicalVector)] {
108
- // if math.Abs(val-canonicalVector[i]) > 1e-3 {
109
- // t.Errorf("Expected %v, got %v", canonicalVector[i], val)
110
- // }
111
- // }
112
- // }
113
- // }
114
-
115
- // // Define the test for batch embedding
116
- // func TestBatchEmbedding(t *testing.T) {
117
- // model := DefaultEmbedding()
118
-
119
- // docs := make([]string, 200)
120
- // for i := range docs {
121
- // if i%2 == 0 {
122
- // docs[i] = "hello world"
123
- // } else {
124
- // docs[i] = "flag embedding"
125
- // }
126
- // }
127
- // embeddings := model.Embed(docs, 10)
128
- // if len(embeddings) != 200 || len(embeddings[0]) != 384 {
129
- // t.Errorf("Expected embeddings shape (200, 384), got (%v, %v)", len(embeddings), len(embeddings[0]))
130
- // }
131
- // }
0 commit comments