diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..8c624ff --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,271 @@ +# This code is licensed under the terms of the MIT license https://opensource.org/license/mit +# Copyright (c) 2021 Marat Reymers + +run: + # Timeout for analysis, e.g. 30s, 5m. + # Default: 1m + timeout: 3m + + +# This file contains only configs which differ from defaults. +# All possible options can be found here https://github.com/golangci/golangci-lint/blob/master/.golangci.reference.yml +linters-settings: + cyclop: + # The maximal code complexity to report. + # Default: 10 + max-complexity: 30 + # The maximal average package complexity. + # If it's higher than 0.0 (float) the check is enabled + # Default: 0.0 + package-average: 10.0 + + errcheck: + # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. + # Such cases aren't reported by default. + # Default: false + check-type-assertions: true + + exhaustive: + # Program elements to check for exhaustiveness. + # Default: [ switch ] + check: + - switch + - map + + exhaustruct: + # List of regular expressions to exclude struct packages and names from check. + # Default: [] + exclude: + # std libs + - "^net/http.Client$" + - "^net/http.Cookie$" + - "^net/http.Request$" + - "^net/http.Response$" + - "^net/http.Server$" + - "^net/http.Transport$" + - "^net/url.URL$" + - "^os/exec.Cmd$" + - "^reflect.StructField$" + # public libs + - "^github.com/Shopify/sarama.Config$" + - "^github.com/Shopify/sarama.ProducerMessage$" + - "^github.com/mitchellh/mapstructure.DecoderConfig$" + - "^github.com/prometheus/client_golang/.+Opts$" + - "^github.com/spf13/cobra.Command$" + - "^github.com/spf13/cobra.CompletionOptions$" + - "^github.com/stretchr/testify/mock.Mock$" + - "^github.com/testcontainers/testcontainers-go.+Request$" + - "^github.com/testcontainers/testcontainers-go.FromDockerfile$" + - "^golang.org/x/tools/go/analysis.Analyzer$" + - "^google.golang.org/protobuf/.+Options$" + - "^gopkg.in/yaml.v3.Node$" + + funlen: + # Checks the number of lines in a function. + # If lower than 0, disable the check. + # Default: 60 + lines: 100 + # Checks the number of statements in a function. + # If lower than 0, disable the check. + # Default: 40 + statements: 50 + # Ignore comments when counting lines. + # Default false + ignore-comments: true + + gocognit: + # Minimal code complexity to report. + # Default: 30 (but we recommend 10-20) + min-complexity: 20 + + gocritic: + # Settings passed to gocritic. + # The settings key is the name of a supported gocritic checker. + # The list of supported checkers can be find in https://go-critic.github.io/overview. + settings: + captLocal: + # Whether to restrict checker to params only. + # Default: true + paramsOnly: false + underef: + # Whether to skip (*x).method() calls where x is a pointer receiver. + # Default: true + skipRecvDeref: false + + gomnd: + # List of function patterns to exclude from analysis. + # Values always ignored: `time.Date`, + # `strconv.FormatInt`, `strconv.FormatUint`, `strconv.FormatFloat`, + # `strconv.ParseInt`, `strconv.ParseUint`, `strconv.ParseFloat`. + # Default: [] + ignored-functions: + - flag.Arg + - flag.Duration.* + - flag.Float.* + - flag.Int.* + - flag.Uint.* + - os.Chmod + - os.Mkdir.* + - os.OpenFile + - os.WriteFile + - prometheus.ExponentialBuckets.* + - prometheus.LinearBuckets + + gomodguard: + blocked: + # List of blocked modules. + # Default: [] + modules: + - github.com/golang/protobuf: + recommendations: + - google.golang.org/protobuf + reason: "see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules" + - github.com/satori/go.uuid: + recommendations: + - github.com/google/uuid + reason: "satori's package is not maintained" + - github.com/gofrs/uuid: + recommendations: + - github.com/google/uuid + reason: "gofrs' package is not go module" + + govet: + # Enable all analyzers. + # Default: false + enable-all: true + # Disable analyzers by name. + # Run `go tool vet help` to see all analyzers. + # Default: [] + disable: + - fieldalignment # too strict + # Settings per analyzer. + settings: + shadow: + # Whether to be strict about shadowing; can be noisy. + # Default: false + strict: true + + nakedret: + # Make an issue if func has more lines of code than this setting, and it has naked returns. + # Default: 30 + max-func-lines: 0 + + nolintlint: + # Exclude following linters from requiring an explanation. + # Default: [] + allow-no-explanation: [ funlen, gocognit, lll ] + # Enable to require an explanation of nonzero length after each nolint directive. + # Default: false + require-explanation: true + # Enable to require nolint directives to mention the specific linter being suppressed. + # Default: false + require-specific: true + + rowserrcheck: + # database/sql is always checked + # Default: [] + packages: + - github.com/jmoiron/sqlx + + tenv: + # The option `all` will run against whole test files (`_test.go`) regardless of method/function signatures. + # Otherwise, only methods that take `*testing.T`, `*testing.B`, and `testing.TB` as arguments are checked. + # Default: false + all: true + + +linters: + disable-all: true + enable: + ## enabled by default + - errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases + - gosimple # specializes in simplifying a code + - govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string + - ineffassign # detects when assignments to existing variables are not used + - staticcheck # is a go vet on steroids, applying a ton of static analysis checks + - typecheck # like the front-end of a Go compiler, parses and type-checks Go code + - unused # checks for unused constants, variables, functions and types + ## disabled by default + - asasalint # checks for pass []any as any in variadic func(...any) + - asciicheck # checks that your code does not contain non-ASCII identifiers + - bidichk # checks for dangerous unicode character sequences + - bodyclose # checks whether HTTP response body is closed successfully + - cyclop # checks function and package cyclomatic complexity + - dupl # tool for code clone detection + - durationcheck # checks for two durations multiplied together + - errname # checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error + - errorlint # finds code that will cause problems with the error wrapping scheme introduced in Go 1.13 + - execinquery # checks query string in Query function which reads your Go src files and warning it finds + - exhaustive # checks exhaustiveness of enum switch statements + - exportloopref # checks for pointers to enclosing loop variables + - forbidigo # forbids identifiers + - funlen # tool for detection of long functions + - gocheckcompilerdirectives # validates go compiler directive comments (//go:) + - gochecknoglobals # checks that no global variables exist + - gochecknoinits # checks that no init functions are present in Go code + - gochecksumtype # checks exhaustiveness on Go "sum types" + - gocognit # computes and checks the cognitive complexity of functions + - goconst # finds repeated strings that could be replaced by a constant + - gocritic # provides diagnostics that check for bugs, performance and style issues + - gocyclo # computes and checks the cyclomatic complexity of functions + - godot # checks if comments end in a period + - goimports # in addition to fixing imports, goimports also formats your code in the same style as gofmt + - gomnd # detects magic numbers + - gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod + - gomodguard # allow and block lists linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations + - goprintffuncname # checks that printf-like functions are named with f at the end + - gosec # inspects source code for security problems + - lll # reports long lines + - loggercheck # checks key value pairs for common logger libraries (kitlog,klog,logr,zap) + - makezero # finds slice declarations with non-zero initial length + - mirror # reports wrong mirror patterns of bytes/strings usage + - musttag # enforces field tags in (un)marshaled structs + - nakedret # finds naked returns in functions greater than a specified function length + - nestif # reports deeply nested if statements + - nilerr # finds the code that returns nil even if it checks that the error is not nil + - nilnil # checks that there is no simultaneous return of nil error and an invalid value + - noctx # finds sending http request without context.Context + - nolintlint # reports ill-formed or insufficient nolint directives + - nonamedreturns # reports all named returns + - nosprintfhostport # checks for misuse of Sprintf to construct a host with port in a URL + - perfsprint # checks that fmt.Sprintf can be replaced with a faster alternative + - predeclared # finds code that shadows one of Go's predeclared identifiers + - promlinter # checks Prometheus metrics naming via promlint + - protogetter # reports direct reads from proto message fields when getters should be used + - reassign # checks that package variables are not reassigned + - revive # fast, configurable, extensible, flexible, and beautiful linter for Go, drop-in replacement of golint + - rowserrcheck # checks whether Err of rows is checked successfully + - sloglint # ensure consistent code style when using log/slog + - sqlclosecheck # checks that sql.Rows and sql.Stmt are closed + - stylecheck # is a replacement for golint + - tenv # detects using os.Setenv instead of t.Setenv since Go1.17 + - testableexamples # checks if examples are testable (have an expected output) + - testifylint # checks usage of github.com/stretchr/testify + - testpackage # makes you use a separate _test package + - tparallel # detects inappropriate usage of t.Parallel() method in your Go test codes + - unconvert # removes unnecessary type conversions + - unparam # reports unused function parameters + - usestdlibvars # detects the possibility to use variables/constants from the Go standard library + - wastedassign # finds wasted assignment statements + - whitespace # detects leading and trailing whitespace + +issues: + # Maximum count of issues with the same text. + # Set to 0 to disable. + # Default: 3 + max-same-issues: 50 + + exclude-rules: + - source: "(noinspection|TODO)" + linters: [ godot ] + - source: "//noinspection" + linters: [ gocritic ] + - path: "_test\\.go" + linters: + - bodyclose + - dupl + - funlen + - goconst + - gosec + - noctx + - wrapcheck \ No newline at end of file diff --git a/fastembed.go b/fastembed.go index 11e0146..936831d 100644 --- a/fastembed.go +++ b/fastembed.go @@ -20,7 +20,7 @@ import ( ort "github.com/yalue/onnxruntime_go" ) -// Enum-type representing the available embedding models +// Enum-type representing the available embedding models. type EmbeddingModel string const ( @@ -36,7 +36,7 @@ const ( // MLE5Large EmbeddingModel = "fast-multilingual-e5-large" ) -// Struct to interface with a FastEmbed model +// Struct to interface with a FastEmbed model. type FlagEmbedding struct { tokenizer *tokenizer.Tokenizer model EmbeddingModel @@ -55,7 +55,7 @@ type FlagEmbedding struct { // not setting this flag and the user setting it to false. We want the default value to be true. // As Go assigns a default(empty) value of "false" to bools, we can't distinguish // if the user set it to false or not set at all. -// A pointer to bool will be nil if not set explicitly +// A pointer to bool will be nil if not set explicitly. type InitOptions struct { Model EmbeddingModel ExecutionProviders []string @@ -64,14 +64,14 @@ type InitOptions struct { ShowDownloadProgress *bool } -// Struct to represent FastEmbed model information +// Struct to represent FastEmbed model information. type ModelInfo struct { Model EmbeddingModel Dim int Description string } -// Function to initialize a FastEmbed model +// Function to initialize a FastEmbed model. func NewFlagEmbedding(options *InitOptions) (*FlagEmbedding, error) { if options == nil { options = &InitOptions{} @@ -120,17 +120,15 @@ func NewFlagEmbedding(options *InitOptions) (*FlagEmbedding, error) { maxLength: options.MaxLength, modelPath: modelPath, }, nil - } // Function to cleanup the internal onnxruntime environment when it is no longer needed. -func (f *FlagEmbedding) Destroy() { - ort.DestroyEnvironment() +func (f *FlagEmbedding) Destroy() error { + return ort.DestroyEnvironment() } -// Private function to embed a batch of input strings +// Private function to embed a batch of input strings. func (f *FlagEmbedding) onnxEmbed(input []string) ([]([]float32), error) { - inputs := make([]tokenizer.EncodeInput, len(input)) for index, v := range input { sequence := tokenizer.NewInputSequence(v) @@ -212,7 +210,7 @@ func (f *FlagEmbedding) onnxEmbed(input []string) ([]([]float32), error) { // The batchSize parameter controls the number of inputs to embed in a single batch // The batches are processed in parallel // Returns the first error encountered if any -// Default batch size is 256 +// Default batch size is 256. func (f *FlagEmbedding) Embed(input []string, batchSize int) ([]([]float32), error) { if batchSize <= 0 { batchSize = 256 @@ -220,7 +218,7 @@ func (f *FlagEmbedding) Embed(input []string, batchSize int) ([]([]float32), err embeddings := make([]([]float32), len(input)) var wg sync.WaitGroup errorCh := make(chan error, len(input)) - //var resultsMutex sync.Mutex + // var resultsMutex sync.Mutex for i := 0; i < len(input); i += batchSize { wg.Add(1) @@ -236,9 +234,8 @@ func (f *FlagEmbedding) Embed(input []string, batchSize int) ([]([]float32), err } // resultsMutex.Lock() // defer resultsMutex.Unlock() - //Removed the mutex as the slice positions being accessed are unique for each goroutine and there is no overlap + // Removed the mutex as the slice positions being accessed are unique for each goroutine and there is no overlap copy(embeddings[i:end], batchOut) - }(i) } wg.Wait() @@ -252,7 +249,7 @@ func (f *FlagEmbedding) Embed(input []string, batchSize int) ([]([]float32), err } // Function to embed a single input string prefixed with "query: " -// Recommended for generating query embeddings for semantic search +// Recommended for generating query embeddings for semantic search. func (f *FlagEmbedding) QueryEmbed(input string) ([]float32, error) { query := "query: " + input data, err := f.onnxEmbed([]string{query}) @@ -262,7 +259,7 @@ func (f *FlagEmbedding) QueryEmbed(input string) ([]float32, error) { return data[0], nil } -// Function to embed string prefixed with "passage: " +// Function to embed string prefixed with "passage: ". func (f *FlagEmbedding) PassageEmbed(input []string, batchSize int) ([]([]float32), error) { processedInput := make([]string, len(input)) for i, v := range input { @@ -271,7 +268,7 @@ func (f *FlagEmbedding) PassageEmbed(input []string, batchSize int) ([]([]float3 return f.Embed(processedInput, batchSize) } -// Function to list the supported FastEmbed models +// Function to list the supported FastEmbed models. func ListSupportedModels() []ModelInfo { return []ModelInfo{ { @@ -378,7 +375,6 @@ func loadTokenizer(modelPath string, maxLength int) (*tokenizer.Tokenizer, error specialTokens := make([]tokenizer.AddedToken, 0) for _, v := range tokensMap { - switch t := v.(type) { case map[string]interface{}: { @@ -399,14 +395,13 @@ func loadTokenizer(modelPath string, maxLength int) (*tokenizer.Tokenizer, error default: panic(fmt.Sprintf("unknown type for special_tokens_map.json%T", t)) } - } tknzer.AddSpecialTokens(specialTokens) return tknzer, nil } -// Private function to get model information from the model name +// Private function to get model information from the model name. func getModelInfo(model EmbeddingModel) (ModelInfo, error) { for _, m := range ListSupportedModels() { if m.Model == model { @@ -417,7 +412,7 @@ func getModelInfo(model EmbeddingModel) (ModelInfo, error) { } // Private function to retrieve the model from the cache or download it -// Returns the path to the model +// Returns the path to the model. func retrieveModel(model EmbeddingModel, cacheDir string, showDownloadProgress bool) (string, error) { if _, err := os.Stat(filepath.Join(cacheDir, string(model))); !errors.Is(err, fs.ErrNotExist) { return filepath.Join(cacheDir, string(model)), nil @@ -425,7 +420,7 @@ func retrieveModel(model EmbeddingModel, cacheDir string, showDownloadProgress b return downloadFromGcs(model, cacheDir, showDownloadProgress) } -// Private function to download the model from Google Cloud Storage +// Private function to download the model from Google Cloud Storage. func downloadFromGcs(model EmbeddingModel, cacheDir string, showDownloadProgress bool) (string, error) { // The MLE5Large model URL doesn't follow the same naming convention as the other models // So, we tranform "fast-multilingual-e5-large" -> "intfloat-multilingual-e5-large" in the download URL @@ -466,7 +461,7 @@ func downloadFromGcs(model EmbeddingModel, cacheDir string, showDownloadProgress return filepath.Join(cacheDir, string(model)), nil } -// Private function to untar the downloaded model from a .tar.gz file +// Private function to untar the downloaded model from a .tar.gz file. func untar(tarball io.Reader, target string) error { archive, err := gzip.NewReader(tarball) if err != nil { @@ -527,7 +522,7 @@ func normalize(v []float32) []float32 { return normalized } -// Private function to return the normalized embeddings from a flattened array with the given dimensions +// Private function to return the normalized embeddings from a flattened array with the given dimensions. func getEmbeddings(data []float32, dimensions []int64) []([]float32) { x, y, z := dimensions[0], dimensions[1], dimensions[2] embeddings := make([][]float32, x) @@ -541,18 +536,18 @@ func getEmbeddings(data []float32, dimensions []int64) []([]float32) { } // Private function to convert multiple int32 slices to int64 slices as required by the onnxruntime API -// With a linear time complexity -func encodingToInt32(inputA, inputB, inputC []int) (outputA, outputB, outputC []int64) { +// With a linear time complexity. +func encodingToInt32(inputA, inputB, inputC []int) ([]int64, []int64, []int64) { if len(inputA) != len(inputB) || len(inputB) != len(inputC) { panic("input lengths do not match") } - outputA = make([]int64, len(inputA)) - outputB = make([]int64, len(inputB)) - outputC = make([]int64, len(inputC)) + outputA := make([]int64, len(inputA)) + outputB := make([]int64, len(inputB)) + outputC := make([]int64, len(inputC)) for i := range inputA { outputA[i] = int64(inputA[i]) outputB[i] = int64(inputB[i]) outputC[i] = int64(inputC[i]) } - return + return outputA, outputB, outputC } diff --git a/fastembed_test.go b/fastembed_test.go index 9e59efb..da4576b 100644 --- a/fastembed_test.go +++ b/fastembed_test.go @@ -1,44 +1,45 @@ -package fastembed +package fastembed_test import ( "math" "testing" + + fastembed "github.com/anush008/fastembed-go" ) -func TestCanonicalValues(T *testing.T) { - canonicalValues := map[EmbeddingModel]([]float32){ - AllMiniLML6V2: []float32{0.02591, 0.00573, 0.01147, 0.03796, -0.02328}, - BGESmallEN: []float32{-0.02313, -0.02552, 0.017357, -0.06393, -0.00061}, - BGEBaseEN: []float32{0.01140, 0.03722, 0.02941, 0.01230, 0.03451}, - BGEBaseENV15: []float32{0.01129394, 0.05493144, 0.02615099, 0.00328772, 0.02996045}, - BGESmallENV15: []float32{0.01522374, -0.02271799, 0.00860278, -0.07424029, 0.00386434}, - BGESmallZH: []float32{-0.01023294, 0.07634465, 0.0691722, -0.04458365, -0.03160762}, +func TestCanonicalValues(t *testing.T) { + canonicalValues := map[fastembed.EmbeddingModel]([]float32){ + fastembed.AllMiniLML6V2: []float32{0.02591, 0.00573, 0.01147, 0.03796, -0.02328}, + fastembed.BGESmallEN: []float32{-0.02313, -0.02552, 0.017357, -0.06393, -0.00061}, + fastembed.BGEBaseEN: []float32{0.01140, 0.03722, 0.02941, 0.01230, 0.03451}, + fastembed.BGEBaseENV15: []float32{0.01129394, 0.05493144, 0.02615099, 0.00328772, 0.02996045}, + fastembed.BGESmallENV15: []float32{0.01522374, -0.02271799, 0.00860278, -0.07424029, 0.00386434}, + fastembed.BGESmallZH: []float32{-0.01023294, 0.07634465, 0.0691722, -0.04458365, -0.03160762}, } for model, expected := range canonicalValues { - fe, err := NewFlagEmbedding(&InitOptions{ + fe, err := fastembed.NewFlagEmbedding(&fastembed.InitOptions{ Model: model, }) defer fe.Destroy() if err != nil { - T.Fatalf("Expected no error, got %v", err) + t.Fatalf("Expected no error, got %v", err) } input := []string{"hello world"} result, err := fe.Embed(input, 1) if err != nil { - T.Fatalf("Expected no error, got %v", err) + t.Fatalf("Expected no error, got %v", err) } if len(result) != len(input) { - T.Errorf("Expected result length %v, got %v", len(input), len(result)) + t.Errorf("Expected result length %v, got %v", len(input), len(result)) } epsilon := float64(1e-4) for i, v := range expected { - if math.Abs(float64(result[0][i]-v)) > float64(epsilon) { - T.Errorf("Element %d mismatch for %s: expected %.6f, got %.6f", i, model, v, result[0][i]) + if math.Abs(float64(result[0][i]-v)) > epsilon { + t.Errorf("Element %d mismatch for %s: expected %.6f, got %.6f", i, model, v, result[0][i]) } } } - }