diff --git a/Dockerfile b/Dockerfile index 162acec..2abacf8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,13 +1,13 @@ -ARG GO_VERSION=1.22.3 -ARG RUST_VERSION=1.78 +ARG GO_VERSION=1.22.5 +ARG RUST_VERSION=1.79 ARG ONNXRUNTIME_VERSION=1.18.0 ARG BUILD_PLATFORM=linux/amd64 - +ARG CGO_LDFLAGS="-L./usr/lib/libtokenizers.a" #--- rust build of tokenizer --- FROM --platform=$BUILD_PLATFORM rust:$RUST_VERSION AS tokenizer -RUN git clone https://github.com/knights-analytics/tokenizers -b main && \ +RUN git clone https://github.com/knights-analytics/tokenizers -b rebase && \ cd tokenizers && \ cargo build --release @@ -16,6 +16,7 @@ RUN git clone https://github.com/knights-analytics/tokenizers -b main && \ FROM --platform=$BUILD_PLATFORM public.ecr.aws/amazonlinux/amazonlinux:2023 AS hugot-build ARG GO_VERSION ARG ONNXRUNTIME_VERSION +ARG CGO_LDFLAGS RUN dnf -y install gcc jq bash tar xz gzip glibc-static libstdc++ wget zip git && \ ln -s /usr/lib64/libstdc++.so.6 /usr/lib64/libstdc++.so && \ diff --git a/README.md b/README.md index 7556711..e0cc1d7 100644 --- a/README.md +++ b/README.md @@ -83,13 +83,13 @@ Hugot can be used in two ways: as a library in your go application, or as a comm To use Hugot as a library in your application, you will need the following dependencies on your system: -- the tokenizers.a file obtained from building the [tokenizer](https://github.com/Knights-Analytics/tokenizers) go library (which is itself a fork of https://github.com/daulet/tokenizers). This file should be at /usr/lib/tokenizers.a so that hugot can load it. +- the libtokenizers.a file obtained from building the [tokenizer](https://github.com/daulet/tokenizers) go library (hugot's go.mod replaces this module with the fork at https://github.com/knights-analytics/tokenizers). This file should be at /usr/lib/libtokenizers.a so that hugot can load it. - the onnxruntime.go file obtained from the onnxruntime project. 
This is dynamically linked by hugot and used by the onnxruntime inference library [onnxruntime_go](https://github.com/yalue/onnxruntime_go). This file should be at /usr/lib/onnxruntime.so or /usr/lib64/onnxruntime.so You can get the libtokenizers.a in two ways. Assuming you have rust installed, you can compile the tokenizers library and get the required libtokenizers.a: ``` -git clone https://github.com/Knights-Analytics/tokenizers -b main && \ +git clone https://github.com/daulet/tokenizers -b main && \ cd tokenizers && \ cargo build --release mv target/release/libtokenizers.a /usr/lib/libtokenizers.a diff --git a/cmd/main_test.go b/cmd/main_test.go index 829d859..690cf0b 100644 --- a/cmd/main_test.go +++ b/cmd/main_test.go @@ -87,7 +87,7 @@ func TestFeatureExtractionCli(t *testing.T) { } baseArgs := os.Args[0:1] - testModel := path.Join("../models", "KnightsAnalytics_all-MiniLM-L6-v2") + testModel := path.Join("../models", "sentence-transformers_all-MiniLM-L6-v2") testDataDir := path.Join(os.TempDir(), "hugoTestData") err := os.MkdirAll(testDataDir, os.ModePerm) diff --git a/downloader.go b/downloader.go index 8edcdca..1cfbd6b 100644 --- a/downloader.go +++ b/downloader.go @@ -14,8 +14,6 @@ import ( "time" hfd "github.com/bodaay/HuggingFaceModelDownloader/hfdownloader" - - util "github.com/knights-analytics/hugot/utils" ) // DownloadOptions is a struct of options that can be passed to DownloadModel @@ -153,25 +151,3 @@ func checkURL(client *http.Client, url string, authToken string) (bool, bool, er return tokenizerFound, onnxFound, nil } - -func downloadModelIfNotExists(session *Session, modelName string, destination string) string { - modelNameFS := modelName - if strings.Contains(modelNameFS, ":") { - modelNameFS = strings.Split(modelName, ":")[0] - } - modelNameFS = path.Join(destination, strings.Replace(modelNameFS, "/", "_", -1)) - - fullModelPath := path.Join(destination, modelNameFS) - exists, err := util.FileSystem.Exists(context.Background(), 
fullModelPath) - if err != nil { - panic(err) - } - if exists { - return fullModelPath - } - fullModelPath, err = session.DownloadModel(modelName, destination, NewDownloadOptions()) - if err != nil { - panic(err) - } - return fullModelPath -} diff --git a/go.mod b/go.mod index 67cfd2c..bf876e9 100644 --- a/go.mod +++ b/go.mod @@ -4,21 +4,23 @@ go 1.20 replace github.com/viant/afsc => github.com/knights-analytics/afsc v0.0.0-20240425201009-7e46526445df +replace github.com/daulet/tokenizers => github.com/knights-analytics/tokenizers v0.0.0-20240717085127-ca3ae0687267 + require ( github.com/bodaay/HuggingFaceModelDownloader v0.0.0-20240307153905-2f38356a6d6c + github.com/daulet/tokenizers v0.8.0 github.com/json-iterator/go v1.1.12 - github.com/knights-analytics/tokenizers v0.12.1 github.com/mattn/go-isatty v0.0.20 github.com/stretchr/testify v1.9.0 github.com/urfave/cli/v2 v2.27.2 github.com/viant/afs v1.25.1 github.com/viant/afsc v1.9.2 github.com/yalue/onnxruntime_go v1.10.0 - golang.org/x/exp v0.0.0-20240529005216-23cca8864a10 + golang.org/x/exp v0.0.0-20240716175740-e3f259677ff7 ) require ( - github.com/aws/aws-sdk-go v1.53.12 // indirect + github.com/aws/aws-sdk-go v1.54.19 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/fatih/color v1.17.0 // indirect @@ -32,8 +34,8 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect - golang.org/x/crypto v0.23.0 // indirect - golang.org/x/oauth2 v0.20.0 // indirect - golang.org/x/sys v0.20.0 // indirect + golang.org/x/crypto v0.25.0 // indirect + golang.org/x/oauth2 v0.21.0 // indirect + golang.org/x/sys v0.22.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 4b8b1d4..826cfcf 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/aws/aws-sdk-go v1.53.12 
h1:8f8K+YaTy2qwtGwVIo2Ftq22UCH96xQAX7Q0lyZKDiA= -github.com/aws/aws-sdk-go v1.53.12/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/aws/aws-sdk-go v1.54.19 h1:tyWV+07jagrNiCcGRzRhdtVjQs7Vy41NwsuOcl0IbVI= +github.com/aws/aws-sdk-go v1.54.19/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/bodaay/HuggingFaceModelDownloader v0.0.0-20240307153905-2f38356a6d6c h1:3TPq2BhzOquTGmbS53KeGcM1yalBUb/4zQM1wmaINrE= github.com/bodaay/HuggingFaceModelDownloader v0.0.0-20240307153905-2f38356a6d6c/go.mod h1:p6JQ7mJjWx82F+SrFfj9RkoHlKEGXR4959uX/vkMbzE= github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4= @@ -21,8 +21,8 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/knights-analytics/afsc v0.0.0-20240425201009-7e46526445df h1:rVna1iJaI7gj5RonGys0dZ0iLy7upULdcbRQd9F2qg8= github.com/knights-analytics/afsc v0.0.0-20240425201009-7e46526445df/go.mod h1:yZo80n1EB2eMwmmec7BekX6clpd7uY+joUpDRIBbeYs= -github.com/knights-analytics/tokenizers v0.12.1 h1:5bIxk3SQKXIHKxlzAOmqPXgFeKE+LCvbXS3hpTgOAX4= -github.com/knights-analytics/tokenizers v0.12.1/go.mod h1:TD+zVXlFlS4QyP6/RN8SPSAKkT2hpMmF64WdrdbBfts= +github.com/knights-analytics/tokenizers v0.0.0-20240717085127-ca3ae0687267 h1:M2jdyK5zl/AUe1ZBLUWqAAjSu6LwF9ZFegk+UBMjVjY= +github.com/knights-analytics/tokenizers v0.0.0-20240717085127-ca3ae0687267/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= @@ -54,19 +54,17 @@ github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGC github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= 
github.com/yalue/onnxruntime_go v1.10.0 h1:om1yzOQYv/4GlsSP5HIZvS6G3WF3THv4x5rhO5AFERU= github.com/yalue/onnxruntime_go v1.10.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/exp v0.0.0-20240529005216-23cca8864a10 h1:vpzMC/iZhYFAjJzHU0Cfuq+w1vLLsF2vLkDrPjzKYck= -golang.org/x/exp v0.0.0-20240529005216-23cca8864a10/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc= -golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= -golang.org/x/oauth2 v0.20.0 h1:4mQdhULixXKP1rwYBW0vAijoXnkTG0BLCDRzfe1idMo= -golang.org/x/oauth2 v0.20.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= +golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/exp v0.0.0-20240716175740-e3f259677ff7 h1:wDLEX9a7YQoKdKNQt88rtydkqDxeGaBUTnIYc3iG/mA= +golang.org/x/exp v0.0.0-20240716175740-e3f259677ff7/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= +golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= +golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= +golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= +golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.22.0 
h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/hugot.go b/hugot.go index c100b93..1d25bed 100644 --- a/hugot.go +++ b/hugot.go @@ -258,10 +258,8 @@ func NewPipeline[T pipelines.Pipeline](s *Session, pipelineConfig pipelines.Pipe config := any(pipelineConfig).(pipelines.PipelineConfig[*pipelines.ZeroShotClassificationPipeline]) pipelineInitialised, err := pipelines.NewZeroShotClassificationPipeline(config, s.ortOptions) if err != nil { - fmt.Println("error in if statement") return pipeline, err } - fmt.Println("config name:", config.Name) s.zeroShotClassifcationPipelines[config.Name] = pipelineInitialised pipeline = any(pipelineInitialised).(T) default: @@ -310,6 +308,7 @@ func (s *Session) Destroy() error { s.featureExtractionPipelines.Destroy(), s.tokenClassificationPipelines.Destroy(), s.textClassificationPipelines.Destroy(), + s.zeroShotClassifcationPipelines.Destroy(), s.ortOptions.Destroy(), ort.DestroyEnvironment(), ) @@ -324,9 +323,10 @@ func (s *Session) Destroy() error { // the average time per onnxruntime inference batch call func (s *Session) GetStats() []string { // slices.Concat() is not implemented in experimental x/exp/slices package - return append(append( + return append(append(append( s.tokenClassificationPipelines.GetStats(), s.textClassificationPipelines.GetStats()...), - s.featureExtractionPipelines.GetStats()..., + s.featureExtractionPipelines.GetStats()...), + s.zeroShotClassifcationPipelines.GetStats()..., ) } diff --git a/hugot_test.go b/hugot_test.go index 314accf..75a4770 100644 --- a/hugot_test.go +++ b/hugot_test.go @@ -49,7 +49,7 @@ func TestFeatureExtractionPipelineValidation(t *testing.T) { check(t, err) }(session) - modelPath := 
downloadModelIfNotExists(session, "KnightsAnalytics/all-MiniLM-L6-v2", "./models") + modelPath := "./models/sentence-transformers_all-MiniLM-L6-v2" config := FeatureExtractionConfig{ ModelPath: modelPath, Name: "testPipeline", @@ -75,7 +75,7 @@ func TestFeatureExtractionPipeline(t *testing.T) { check(t, err) }(session) - modelPath := downloadModelIfNotExists(session, "sentence-transformers/all-MiniLM-L6-v2", "./models") + modelPath := "./models/sentence-transformers_all-MiniLM-L6-v2" config := FeatureExtractionConfig{ ModelPath: modelPath, @@ -187,12 +187,10 @@ func TestFeatureExtractionPipeline(t *testing.T) { } pipelineToken, err := NewPipeline(session, configSentence) check(t, err) - out, err := pipelineToken.RunPipeline([]string{"Onnxruntime is a great inference backend"}) + _, err = pipelineToken.RunPipeline([]string{"Onnxruntime is a great inference backend"}) if err != nil { t.FailNow() } - fmt.Println(out) - // TODO: assert the result here } // Text classification @@ -211,7 +209,9 @@ func TestTextClassificationPipeline(t *testing.T) { errDestroy := session.Destroy() check(t, errDestroy) }(session) - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-base-uncased-finetuned-sst-2-english", "./models") + modelPath := "./models/KnightsAnalytics_distilbert-base-uncased-finetuned-sst-2-english" + modelPathMulti := "./models/SamLowe_roberta-base-go_emotions-onnx" + config := TextClassificationConfig{ ModelPath: modelPath, Name: "testPipelineSimple", @@ -222,7 +222,6 @@ func TestTextClassificationPipeline(t *testing.T) { sentimentPipeline, err := NewPipeline(session, config) check(t, err) - modelPathMulti := downloadModelIfNotExists(session, "SamLowe/roberta-base-go_emotions-onnx", "./models") configMulti := TextClassificationConfig{ ModelPath: modelPathMulti, Name: "testPipelineSimpleMulti", @@ -408,7 +407,7 @@ func TestTextClassificationPipelineValidation(t *testing.T) { err := session.Destroy() check(t, err) }(session) - modelPath := 
downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-base-uncased-finetuned-sst-2-english", "./models") + modelPath := "./models/KnightsAnalytics_distilbert-base-uncased-finetuned-sst-2-english" config := TextClassificationConfig{ ModelPath: modelPath, @@ -449,7 +448,7 @@ func TestZeroShotClassificationPipeline(t *testing.T) { check(t, err) }(session) - modelPath := downloadModelIfNotExists(session, "protectai/deberta-v3-base-zeroshot-v1-onnx", "./models") + modelPath := "./models/protectai_deberta-v3-base-zeroshot-v1-onnx" config := ZeroShotClassificationConfig{ ModelPath: modelPath, @@ -661,8 +660,6 @@ func TestZeroShotClassificationPipeline(t *testing.T) { assert.Equal(t, len(expectedResult), len(testResult)) assert.Equal(t, tt.expected.ClassificationOutputs[ind].Sequence, batchResult.ClassificationOutputs[ind].Sequence) for i := range testResult { - fmt.Println(testResult[i].Key, expectedResult[i].Key) - assert.True(t, almostEqual(testResult[i].Value, expectedResult[i].Value)) } } @@ -677,7 +674,7 @@ func TestZeroShotClassificationPipelineValidation(t *testing.T) { err := session.Destroy() check(t, err) }(session) - modelPath := downloadModelIfNotExists(session, "protectai/deberta-v3-base-zeroshot-v1-onnx", "./models") + modelPath := "./models/protectai_deberta-v3-base-zeroshot-v1-onnx" config := TextClassificationConfig{ ModelPath: modelPath, @@ -717,7 +714,7 @@ func TestTokenClassificationPipeline(t *testing.T) { check(t, err) }(session) - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-NER", "./models") + modelPath := "./models/KnightsAnalytics_distilbert-NER" configSimple := TokenClassificationConfig{ ModelPath: modelPath, Name: "testPipelineSimple", @@ -794,7 +791,7 @@ func TestTokenClassificationPipelineValidation(t *testing.T) { check(t, err) }(session) - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-NER", "./models") + modelPath := "./models/KnightsAnalytics_distilbert-NER" configSimple 
:= TokenClassificationConfig{ ModelPath: modelPath, Name: "testPipelineSimple", @@ -835,7 +832,7 @@ func TestNoSameNamePipeline(t *testing.T) { check(t, err) }(session) - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-NER", "./models") + modelPath := "./models/KnightsAnalytics_distilbert-NER" configSimple := TokenClassificationConfig{ ModelPath: modelPath, Name: "testPipelineSimple", @@ -930,7 +927,7 @@ func TestCuda(t *testing.T) { } }(session) - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/all-MiniLM-L6-v2", "./models") + modelPath := "./models/KnightsAnalytics_all-MiniLM-L6-v2" config := FeatureExtractionConfig{ ModelPath: modelPath, Name: "benchmarkEmbedding", @@ -969,7 +966,7 @@ func runBenchmarkEmbedding(strings *[]string, cuda bool) { } }(session) - modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/all-MiniLM-L6-v2", "./models") + modelPath := "./models/KnightsAnalytics_all-MiniLM-L6-v2" config := FeatureExtractionConfig{ ModelPath: modelPath, Name: "benchmarkEmbedding", diff --git a/pipelines/featureExtraction.go b/pipelines/featureExtraction.go index 54f86fe..bd9a868 100644 --- a/pipelines/featureExtraction.go +++ b/pipelines/featureExtraction.go @@ -10,8 +10,8 @@ import ( ort "github.com/yalue/onnxruntime_go" + "github.com/daulet/tokenizers" util "github.com/knights-analytics/hugot/utils" - "github.com/knights-analytics/tokenizers" ) // FeatureExtractionPipeline A feature extraction pipeline is a go version of @@ -131,7 +131,7 @@ func (p *FeatureExtractionPipeline) GetMetadata() PipelineMetadata { OutputsInfo: []OutputInfo{ { Name: p.OutputName, - Dimensions: []int64(p.Output.Dimensions), + Dimensions: p.Output.Dimensions, }, }, } diff --git a/pipelines/pipeline.go b/pipelines/pipeline.go index 0cc2bb3..cc7cd40 100644 --- a/pipelines/pipeline.go +++ b/pipelines/pipeline.go @@ -8,7 +8,7 @@ import ( "os" "strings" - "github.com/knights-analytics/tokenizers" + "github.com/daulet/tokenizers" ort 
"github.com/yalue/onnxruntime_go" util "github.com/knights-analytics/hugot/utils" diff --git a/pipelines/textClassification.go b/pipelines/textClassification.go index 92fc8e8..5435d41 100644 --- a/pipelines/textClassification.go +++ b/pipelines/textClassification.go @@ -9,8 +9,8 @@ import ( util "github.com/knights-analytics/hugot/utils" + "github.com/daulet/tokenizers" jsoniter "github.com/json-iterator/go" - "github.com/knights-analytics/tokenizers" ort "github.com/yalue/onnxruntime_go" ) @@ -162,7 +162,7 @@ func (p *TextClassificationPipeline) GetMetadata() PipelineMetadata { OutputsInfo: []OutputInfo{ { Name: p.OutputsMeta[0].Name, - Dimensions: []int64(p.OutputsMeta[0].Dimensions), + Dimensions: p.OutputsMeta[0].Dimensions, }, }, } diff --git a/pipelines/tokenClassification.go b/pipelines/tokenClassification.go index 6c0386e..7128e9d 100644 --- a/pipelines/tokenClassification.go +++ b/pipelines/tokenClassification.go @@ -13,8 +13,8 @@ import ( util "github.com/knights-analytics/hugot/utils" + "github.com/daulet/tokenizers" jsoniter "github.com/json-iterator/go" - "github.com/knights-analytics/tokenizers" ) // TokenClassificationPipeline is a go version of huggingface tokenClassificationPipeline. 
@@ -166,7 +166,7 @@ func (p *TokenClassificationPipeline) GetMetadata() PipelineMetadata { OutputsInfo: []OutputInfo{ { Name: p.OutputsMeta[0].Name, - Dimensions: []int64(p.OutputsMeta[0].Dimensions), + Dimensions: p.OutputsMeta[0].Dimensions, }, }, } diff --git a/pipelines/zeroShotClassification.go b/pipelines/zeroShotClassification.go index 0bb8e77..9743013 100644 --- a/pipelines/zeroShotClassification.go +++ b/pipelines/zeroShotClassification.go @@ -15,8 +15,8 @@ import ( util "github.com/knights-analytics/hugot/utils" ort "github.com/yalue/onnxruntime_go" + "github.com/daulet/tokenizers" jsoniter "github.com/json-iterator/go" - "github.com/knights-analytics/tokenizers" ) /** @@ -207,17 +207,14 @@ func NewZeroShotClassificationPipeline(config PipelineConfig[*ZeroShotClassifica } } - // TODO: figure out logging - // if pipeline.entailmentID == -1 { - // fmt.Println("Failed to determine `entailment` label id from the id2label mapping in the model config. Setting to -1. Define a descriptive id2labelmapping in the model config to ensure correct outputs") - // } - configPath1 := util.PathJoinSafe(pipeline.ModelPath, "special_tokens_map.json") file, err := os.Open(configPath1) if err != nil { return nil, fmt.Errorf("cannot read special_tokens_map.json at %s", pipeline.ModelPath) } - defer file.Close() + defer func() { + err = file.Close() + }() byteValue, _ := io.ReadAll(file) var result map[string]interface{} @@ -269,7 +266,7 @@ func NewZeroShotClassificationPipeline(config PipelineConfig[*ZeroShotClassifica pipeline.PipelineTimings = &timings{} pipeline.TokenizerTimings = &timings{} - return pipeline, nil + return pipeline, err } func (p *ZeroShotClassificationPipeline) Preprocess(batch *PipelineBatch, inputs []string) error { @@ -481,7 +478,7 @@ func (p *ZeroShotClassificationPipeline) GetMetadata() PipelineMetadata { OutputsInfo: []OutputInfo{ { Name: p.OutputsMeta[0].Name, - Dimensions: []int64(p.OutputsMeta[0].Dimensions), + Dimensions: 
p.OutputsMeta[0].Dimensions, }, }, } diff --git a/testData/downloadModels.go b/testData/downloadModels.go index a54fa21..3dc39f8 100644 --- a/testData/downloadModels.go +++ b/testData/downloadModels.go @@ -31,7 +31,8 @@ func main() { } downloadOptions := hugot.NewDownloadOptions() for _, modelName := range []string{ - "KnightsAnalytics/all-MiniLM-L6-v2", + "sentence-transformers/all-MiniLM-L6-v2", + "protectai/deberta-v3-base-zeroshot-v1-onnx", "KnightsAnalytics/distilbert-base-uncased-finetuned-sst-2-english", "KnightsAnalytics/distilbert-NER", "SamLowe/roberta-base-go_emotions-onnx"} {