Skip to content

Commit

Permalink
switch to latest tokenizer and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RJKeevil authored and riccardopinosio committed Jul 17, 2024
1 parent bd661db commit 59b2996
Show file tree
Hide file tree
Showing 14 changed files with 62 additions and 90 deletions.
9 changes: 5 additions & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
ARG GO_VERSION=1.22.3
ARG RUST_VERSION=1.78
ARG GO_VERSION=1.22.5
ARG RUST_VERSION=1.79
ARG ONNXRUNTIME_VERSION=1.18.0
ARG BUILD_PLATFORM=linux/amd64

ARG CGO_LDFLAGS="-L./usr/lib/libtokenizers.a"
#--- rust build of tokenizer ---

FROM --platform=$BUILD_PLATFORM rust:$RUST_VERSION AS tokenizer

RUN git clone https://github.com/knights-analytics/tokenizers -b main && \
RUN git clone https://github.com/knights-analytics/tokenizers -b rebase && \
cd tokenizers && \
cargo build --release

Expand All @@ -16,6 +16,7 @@ RUN git clone https://github.com/knights-analytics/tokenizers -b main && \
FROM --platform=$BUILD_PLATFORM public.ecr.aws/amazonlinux/amazonlinux:2023 AS hugot-build
ARG GO_VERSION
ARG ONNXRUNTIME_VERSION
ARG CGO_LDFLAGS

RUN dnf -y install gcc jq bash tar xz gzip glibc-static libstdc++ wget zip git && \
ln -s /usr/lib64/libstdc++.so.6 /usr/lib64/libstdc++.so && \
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,13 @@ Hugot can be used in two ways: as a library in your go application, or as a comm

To use Hugot as a library in your application, you will need the following dependencies on your system:

- the tokenizers.a file obtained from building the [tokenizer](https://github.com/Knights-Analytics/tokenizers) go library (which is itself a fork of https://github.com/daulet/tokenizers). This file should be at /usr/lib/tokenizers.a so that hugot can load it.
- the tokenizers.a file obtained from building the [tokenizer](https://github.com/daulet/tokenizers) go library (which is itself a fork of https://github.com/daulet/tokenizers). This file should be at /usr/lib/tokenizers.a so that hugot can load it.
- the onnxruntime.go file obtained from the onnxruntime project. This is dynamically linked by hugot and used by the onnxruntime inference library [onnxruntime_go](https://github.com/yalue/onnxruntime_go). This file should be at /usr/lib/onnxruntime.so or /usr/lib64/onnxruntime.so

You can get the libtokenizers.a in two ways. Assuming you have rust installed, you can compile the tokenizers library and get the required libtokenizers.a:

```
git clone https://github.com/Knights-Analytics/tokenizers -b main && \
git clone https://github.com/daulet/tokenizers -b main && \
cd tokenizers && \
cargo build --release
mv target/release/libtokenizers.a /usr/lib/libtokenizers.a
Expand Down
2 changes: 1 addition & 1 deletion cmd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func TestFeatureExtractionCli(t *testing.T) {
}
baseArgs := os.Args[0:1]

testModel := path.Join("../models", "KnightsAnalytics_all-MiniLM-L6-v2")
testModel := path.Join("../models", "sentence-transformers_all-MiniLM-L6-v2")

testDataDir := path.Join(os.TempDir(), "hugoTestData")
err := os.MkdirAll(testDataDir, os.ModePerm)
Expand Down
24 changes: 0 additions & 24 deletions downloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ import (
"time"

hfd "github.com/bodaay/HuggingFaceModelDownloader/hfdownloader"

util "github.com/knights-analytics/hugot/utils"
)

// DownloadOptions is a struct of options that can be passed to DownloadModel
Expand Down Expand Up @@ -153,25 +151,3 @@ func checkURL(client *http.Client, url string, authToken string) (bool, bool, er

return tokenizerFound, onnxFound, nil
}

func downloadModelIfNotExists(session *Session, modelName string, destination string) string {
modelNameFS := modelName
if strings.Contains(modelNameFS, ":") {
modelNameFS = strings.Split(modelName, ":")[0]
}
modelNameFS = path.Join(destination, strings.Replace(modelNameFS, "/", "_", -1))

fullModelPath := path.Join(destination, modelNameFS)
exists, err := util.FileSystem.Exists(context.Background(), fullModelPath)
if err != nil {
panic(err)
}
if exists {
return fullModelPath
}
fullModelPath, err = session.DownloadModel(modelName, destination, NewDownloadOptions())
if err != nil {
panic(err)
}
return fullModelPath
}
14 changes: 8 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,23 @@ go 1.20

replace github.com/viant/afsc => github.com/knights-analytics/afsc v0.0.0-20240425201009-7e46526445df

replace github.com/daulet/tokenizers => github.com/knights-analytics/tokenizers v0.0.0-20240717085127-ca3ae0687267

require (
github.com/bodaay/HuggingFaceModelDownloader v0.0.0-20240307153905-2f38356a6d6c
github.com/daulet/tokenizers v0.8.0
github.com/json-iterator/go v1.1.12
github.com/knights-analytics/tokenizers v0.12.1
github.com/mattn/go-isatty v0.0.20
github.com/stretchr/testify v1.9.0
github.com/urfave/cli/v2 v2.27.2
github.com/viant/afs v1.25.1
github.com/viant/afsc v1.9.2
github.com/yalue/onnxruntime_go v1.10.0
golang.org/x/exp v0.0.0-20240529005216-23cca8864a10
golang.org/x/exp v0.0.0-20240716175740-e3f259677ff7
)

require (
github.com/aws/aws-sdk-go v1.53.12 // indirect
github.com/aws/aws-sdk-go v1.54.19 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fatih/color v1.17.0 // indirect
Expand All @@ -32,8 +34,8 @@ require (
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect
golang.org/x/crypto v0.23.0 // indirect
golang.org/x/oauth2 v0.20.0 // indirect
golang.org/x/sys v0.20.0 // indirect
golang.org/x/crypto v0.25.0 // indirect
golang.org/x/oauth2 v0.21.0 // indirect
golang.org/x/sys v0.22.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
28 changes: 13 additions & 15 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
github.com/aws/aws-sdk-go v1.53.12 h1:8f8K+YaTy2qwtGwVIo2Ftq22UCH96xQAX7Q0lyZKDiA=
github.com/aws/aws-sdk-go v1.53.12/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
github.com/aws/aws-sdk-go v1.54.19 h1:tyWV+07jagrNiCcGRzRhdtVjQs7Vy41NwsuOcl0IbVI=
github.com/aws/aws-sdk-go v1.54.19/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
github.com/bodaay/HuggingFaceModelDownloader v0.0.0-20240307153905-2f38356a6d6c h1:3TPq2BhzOquTGmbS53KeGcM1yalBUb/4zQM1wmaINrE=
github.com/bodaay/HuggingFaceModelDownloader v0.0.0-20240307153905-2f38356a6d6c/go.mod h1:p6JQ7mJjWx82F+SrFfj9RkoHlKEGXR4959uX/vkMbzE=
github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4=
Expand All @@ -21,8 +21,8 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/knights-analytics/afsc v0.0.0-20240425201009-7e46526445df h1:rVna1iJaI7gj5RonGys0dZ0iLy7upULdcbRQd9F2qg8=
github.com/knights-analytics/afsc v0.0.0-20240425201009-7e46526445df/go.mod h1:yZo80n1EB2eMwmmec7BekX6clpd7uY+joUpDRIBbeYs=
github.com/knights-analytics/tokenizers v0.12.1 h1:5bIxk3SQKXIHKxlzAOmqPXgFeKE+LCvbXS3hpTgOAX4=
github.com/knights-analytics/tokenizers v0.12.1/go.mod h1:TD+zVXlFlS4QyP6/RN8SPSAKkT2hpMmF64WdrdbBfts=
github.com/knights-analytics/tokenizers v0.0.0-20240717085127-ca3ae0687267 h1:M2jdyK5zl/AUe1ZBLUWqAAjSu6LwF9ZFegk+UBMjVjY=
github.com/knights-analytics/tokenizers v0.0.0-20240717085127-ca3ae0687267/go.mod h1:tGnMdZthXdcWY6DGD07IygpwJqiPvG85FQUnhs/wSCs=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
Expand Down Expand Up @@ -54,19 +54,17 @@ github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGC
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
github.com/yalue/onnxruntime_go v1.10.0 h1:om1yzOQYv/4GlsSP5HIZvS6G3WF3THv4x5rhO5AFERU=
github.com/yalue/onnxruntime_go v1.10.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/exp v0.0.0-20240529005216-23cca8864a10 h1:vpzMC/iZhYFAjJzHU0Cfuq+w1vLLsF2vLkDrPjzKYck=
golang.org/x/exp v0.0.0-20240529005216-23cca8864a10/go.mod h1:XtvwrStGgqGPLc4cjQfWqZHG1YFdYs6swckp8vpsjnc=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/oauth2 v0.20.0 h1:4mQdhULixXKP1rwYBW0vAijoXnkTG0BLCDRzfe1idMo=
golang.org/x/oauth2 v0.20.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30=
golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M=
golang.org/x/exp v0.0.0-20240716175740-e3f259677ff7 h1:wDLEX9a7YQoKdKNQt88rtydkqDxeGaBUTnIYc3iG/mA=
golang.org/x/exp v0.0.0-20240716175740-e3f259677ff7/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs=
golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI=
golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
Expand Down
8 changes: 4 additions & 4 deletions hugot.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,10 +258,8 @@ func NewPipeline[T pipelines.Pipeline](s *Session, pipelineConfig pipelines.Pipe
config := any(pipelineConfig).(pipelines.PipelineConfig[*pipelines.ZeroShotClassificationPipeline])
pipelineInitialised, err := pipelines.NewZeroShotClassificationPipeline(config, s.ortOptions)
if err != nil {
fmt.Println("error in if statement")
return pipeline, err
}
fmt.Println("config name:", config.Name)
s.zeroShotClassifcationPipelines[config.Name] = pipelineInitialised
pipeline = any(pipelineInitialised).(T)
default:
Expand Down Expand Up @@ -310,6 +308,7 @@ func (s *Session) Destroy() error {
s.featureExtractionPipelines.Destroy(),
s.tokenClassificationPipelines.Destroy(),
s.textClassificationPipelines.Destroy(),
s.zeroShotClassifcationPipelines.Destroy(),
s.ortOptions.Destroy(),
ort.DestroyEnvironment(),
)
Expand All @@ -324,9 +323,10 @@ func (s *Session) Destroy() error {
// the average time per onnxruntime inference batch call
func (s *Session) GetStats() []string {
// slices.Concat() is not implemented in experimental x/exp/slices package
return append(append(
return append(append(append(
s.tokenClassificationPipelines.GetStats(),
s.textClassificationPipelines.GetStats()...),
s.featureExtractionPipelines.GetStats()...,
s.featureExtractionPipelines.GetStats()...),
s.zeroShotClassifcationPipelines.GetStats()...,
)
}
31 changes: 14 additions & 17 deletions hugot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func TestFeatureExtractionPipelineValidation(t *testing.T) {
check(t, err)
}(session)

modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/all-MiniLM-L6-v2", "./models")
modelPath := "./models/sentence-transformers_all-MiniLM-L6-v2"
config := FeatureExtractionConfig{
ModelPath: modelPath,
Name: "testPipeline",
Expand All @@ -75,7 +75,7 @@ func TestFeatureExtractionPipeline(t *testing.T) {
check(t, err)
}(session)

modelPath := downloadModelIfNotExists(session, "sentence-transformers/all-MiniLM-L6-v2", "./models")
modelPath := "./models/sentence-transformers_all-MiniLM-L6-v2"

config := FeatureExtractionConfig{
ModelPath: modelPath,
Expand Down Expand Up @@ -187,12 +187,10 @@ func TestFeatureExtractionPipeline(t *testing.T) {
}
pipelineToken, err := NewPipeline(session, configSentence)
check(t, err)
out, err := pipelineToken.RunPipeline([]string{"Onnxruntime is a great inference backend"})
_, err = pipelineToken.RunPipeline([]string{"Onnxruntime is a great inference backend"})
if err != nil {
t.FailNow()
}
fmt.Println(out)
// TODO: assert the result here
}

// Text classification
Expand All @@ -211,7 +209,9 @@ func TestTextClassificationPipeline(t *testing.T) {
errDestroy := session.Destroy()
check(t, errDestroy)
}(session)
modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-base-uncased-finetuned-sst-2-english", "./models")
modelPath := "./models/KnightsAnalytics_distilbert-base-uncased-finetuned-sst-2-english"
modelPathMulti := "./models/SamLowe_roberta-base-go_emotions-onnx"

config := TextClassificationConfig{
ModelPath: modelPath,
Name: "testPipelineSimple",
Expand All @@ -222,7 +222,6 @@ func TestTextClassificationPipeline(t *testing.T) {
sentimentPipeline, err := NewPipeline(session, config)
check(t, err)

modelPathMulti := downloadModelIfNotExists(session, "SamLowe/roberta-base-go_emotions-onnx", "./models")
configMulti := TextClassificationConfig{
ModelPath: modelPathMulti,
Name: "testPipelineSimpleMulti",
Expand Down Expand Up @@ -408,7 +407,7 @@ func TestTextClassificationPipelineValidation(t *testing.T) {
err := session.Destroy()
check(t, err)
}(session)
modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-base-uncased-finetuned-sst-2-english", "./models")
modelPath := "./models/KnightsAnalytics_distilbert-base-uncased-finetuned-sst-2-english"

config := TextClassificationConfig{
ModelPath: modelPath,
Expand Down Expand Up @@ -449,7 +448,7 @@ func TestZeroShotClassificationPipeline(t *testing.T) {
check(t, err)
}(session)

modelPath := downloadModelIfNotExists(session, "protectai/deberta-v3-base-zeroshot-v1-onnx", "./models")
modelPath := "./models/protectai_deberta-v3-base-zeroshot-v1-onnx"

config := ZeroShotClassificationConfig{
ModelPath: modelPath,
Expand Down Expand Up @@ -661,8 +660,6 @@ func TestZeroShotClassificationPipeline(t *testing.T) {
assert.Equal(t, len(expectedResult), len(testResult))
assert.Equal(t, tt.expected.ClassificationOutputs[ind].Sequence, batchResult.ClassificationOutputs[ind].Sequence)
for i := range testResult {
fmt.Println(testResult[i].Key, expectedResult[i].Key)

assert.True(t, almostEqual(testResult[i].Value, expectedResult[i].Value))
}
}
Expand All @@ -677,7 +674,7 @@ func TestZeroShotClassificationPipelineValidation(t *testing.T) {
err := session.Destroy()
check(t, err)
}(session)
modelPath := downloadModelIfNotExists(session, "protectai/deberta-v3-base-zeroshot-v1-onnx", "./models")
modelPath := "./models/protectai_deberta-v3-base-zeroshot-v1-onnx"

config := TextClassificationConfig{
ModelPath: modelPath,
Expand Down Expand Up @@ -717,7 +714,7 @@ func TestTokenClassificationPipeline(t *testing.T) {
check(t, err)
}(session)

modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-NER", "./models")
modelPath := "./models/KnightsAnalytics_distilbert-NER"
configSimple := TokenClassificationConfig{
ModelPath: modelPath,
Name: "testPipelineSimple",
Expand Down Expand Up @@ -794,7 +791,7 @@ func TestTokenClassificationPipelineValidation(t *testing.T) {
check(t, err)
}(session)

modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-NER", "./models")
modelPath := "./models/KnightsAnalytics_distilbert-NER"
configSimple := TokenClassificationConfig{
ModelPath: modelPath,
Name: "testPipelineSimple",
Expand Down Expand Up @@ -835,7 +832,7 @@ func TestNoSameNamePipeline(t *testing.T) {
check(t, err)
}(session)

modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/distilbert-NER", "./models")
modelPath := "./models/KnightsAnalytics_distilbert-NER"
configSimple := TokenClassificationConfig{
ModelPath: modelPath,
Name: "testPipelineSimple",
Expand Down Expand Up @@ -930,7 +927,7 @@ func TestCuda(t *testing.T) {
}
}(session)

modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/all-MiniLM-L6-v2", "./models")
modelPath := "./models/KnightsAnalytics_all-MiniLM-L6-v2"
config := FeatureExtractionConfig{
ModelPath: modelPath,
Name: "benchmarkEmbedding",
Expand Down Expand Up @@ -969,7 +966,7 @@ func runBenchmarkEmbedding(strings *[]string, cuda bool) {
}
}(session)

modelPath := downloadModelIfNotExists(session, "KnightsAnalytics/all-MiniLM-L6-v2", "./models")
modelPath := "./models/KnightsAnalytics_all-MiniLM-L6-v2"
config := FeatureExtractionConfig{
ModelPath: modelPath,
Name: "benchmarkEmbedding",
Expand Down
4 changes: 2 additions & 2 deletions pipelines/featureExtraction.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (

ort "github.com/yalue/onnxruntime_go"

"github.com/daulet/tokenizers"
util "github.com/knights-analytics/hugot/utils"
"github.com/knights-analytics/tokenizers"
)

// FeatureExtractionPipeline A feature extraction pipeline is a go version of
Expand Down Expand Up @@ -131,7 +131,7 @@ func (p *FeatureExtractionPipeline) GetMetadata() PipelineMetadata {
OutputsInfo: []OutputInfo{
{
Name: p.OutputName,
Dimensions: []int64(p.Output.Dimensions),
Dimensions: p.Output.Dimensions,
},
},
}
Expand Down
2 changes: 1 addition & 1 deletion pipelines/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
"os"
"strings"

"github.com/knights-analytics/tokenizers"
"github.com/daulet/tokenizers"
ort "github.com/yalue/onnxruntime_go"

util "github.com/knights-analytics/hugot/utils"
Expand Down
4 changes: 2 additions & 2 deletions pipelines/textClassification.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (

util "github.com/knights-analytics/hugot/utils"

"github.com/daulet/tokenizers"
jsoniter "github.com/json-iterator/go"
"github.com/knights-analytics/tokenizers"
ort "github.com/yalue/onnxruntime_go"
)

Expand Down Expand Up @@ -162,7 +162,7 @@ func (p *TextClassificationPipeline) GetMetadata() PipelineMetadata {
OutputsInfo: []OutputInfo{
{
Name: p.OutputsMeta[0].Name,
Dimensions: []int64(p.OutputsMeta[0].Dimensions),
Dimensions: p.OutputsMeta[0].Dimensions,
},
},
}
Expand Down
Loading

0 comments on commit 59b2996

Please sign in to comment.