From 38534cb9f5de89d36d5497ade9bd250c688a315a Mon Sep 17 00:00:00 2001 From: Rob Keevil Date: Mon, 25 Mar 2024 09:45:09 +0100 Subject: [PATCH] Allow all onnxruntime_go tuning options to be set in a new session --- .github/workflows/release.yaml | 4 +- README.md | 21 ++++- cmd/main.go | 8 +- go.mod | 14 +-- go.sum | 48 +++++----- hugot.go | 137 +++++++++++++++------------- hugot_test.go | 26 +++--- options.go | 64 +++++++++++++ pipelines/tokenClassification.go | 2 +- scripts/run-unit-tests-container.sh | 3 +- 10 files changed, 210 insertions(+), 117 deletions(-) create mode 100644 options.go diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 0c90f2b..10b2bee 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -17,7 +17,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: '1.20.0' + go-version: '1.22.1' - name: Checkout code uses: actions/checkout@v4 - name: Install dependencies @@ -73,4 +73,4 @@ jobs: with: artifacts: "libtokenizers.a, onnxruntime.so, hugot-cli-linux-amd64" generateReleaseNotes: true - skipIfReleaseExists: true \ No newline at end of file + skipIfReleaseExists: true diff --git a/README.md b/README.md index ac9b103..3af780c 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ For the golang developer or ML engineer who wants to run transformer piplines on ## What is already there -Currently we have implementations for the following transfomer pipelines: +Currently, we have implementations for the following transfomer pipelines: - [featureExtraction](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.FeatureExtractionPipeline) - [textClassification](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.TextClassificationPipeline) (single label classification only) @@ -167,6 +167,23 @@ Note that the --model parameter can be: 1. the full path to a model to load 2. the name of a huggingface model. Hugot will first try to look for the model at $HOME/hugot, or will try to download the model from huggingface. +## Performance Tuning + +The library defaults to onnxruntime's default tuning settings. These are optimised for latency over throughput, and will attempt to parallelize single threaded calls to onnxruntime over multiple cores. + +For maximum throughput, it is best to call a single shared hugot pipeline from multiple goroutines (1 per core), using channels to pass the input data. In this scenario, the following settings will greatly increase inference throughput. + +```go +session, err := hugot.NewSession( + hugot.WithInterOpNumThreads(1), + hugot.WithIntraOpNumThreads(1), + hugot.WithCpuMemArena(false), + hugot.WithMemPattern(false), +) +``` + +InterOpNumThreads and IntraOpNumThreads constricts each goroutine's call to a single core, greatly reducing locking and cache penalties. Disabling CpuMemArena and MemPattern skips preallocation of some memory structures, increasing latency, but also throughput efficiency. + ## Contributing ### Development environment @@ -205,7 +222,7 @@ If you prefer to develop on bare metal, you will need to download the tokenizers ### Run the tests -The full test suite can be ran as follows. From the source folder: +The full test suite can be run as follows. From the source folder: ```bash make clean run-tests diff --git a/cmd/main.go b/cmd/main.go index 830c3c2..0fa841e 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -94,7 +94,7 @@ var runCommand = &cli.Command{ }, Action: func(ctx *cli.Context) error { - var onnxLibraryPathOpt hugot.SessionOption + var opts []hugot.WithOption if modelsDir == "" { userDir, err := os.UserHomeDir() @@ -105,17 +105,17 @@ var runCommand = &cli.Command{ } if sharedLibraryPath != "" { - onnxLibraryPathOpt = hugot.WithOnnxLibraryPath(sharedLibraryPath) + opts = append(opts, hugot.WithOnnxLibraryPath(sharedLibraryPath)) } else { homeDir, err := os.UserHomeDir() if err != nil { if exists, err := util.FileSystem.Exists(ctx.Context, path.Join(homeDir, "lib", "hugot", "onnxruntime.so")); err != nil && exists { - onnxLibraryPathOpt = hugot.WithOnnxLibraryPath(path.Join(homeDir, "lib", "hugot", "onnxruntime.so")) + opts = append(opts, hugot.WithOnnxLibraryPath(path.Join(homeDir, "lib", "hugot", "onnxruntime.so"))) } } } - session, err := hugot.NewSession(onnxLibraryPathOpt) + session, err := hugot.NewSession(opts...) if err != nil { return err } diff --git a/go.mod b/go.mod index 08ad91f..3a68505 100644 --- a/go.mod +++ b/go.mod @@ -15,13 +15,13 @@ require ( ) require ( - cloud.google.com/go/storage v1.39.0 // indirect - github.com/aws/aws-sdk-go v1.51.2 // indirect - github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect + cloud.google.com/go/iam v1.1.7 // indirect + cloud.google.com/go/storage v1.39.1 // indirect + github.com/aws/aws-sdk-go v1.51.6 // indirect + github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/fatih/color v1.16.0 // indirect github.com/go-errors/errors v1.5.1 // indirect - github.com/golang/protobuf v1.5.4 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect @@ -35,7 +35,9 @@ require ( golang.org/x/net v0.22.0 // indirect golang.org/x/oauth2 v0.18.0 // indirect golang.org/x/sys v0.18.0 // indirect - google.golang.org/api v0.169.0 // indirect - google.golang.org/protobuf v1.33.0 // indirect + google.golang.org/api v0.171.0 // indirect + google.golang.org/genproto v0.0.0-20240318140521-94a12d6c2237 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 6e23b86..166da03 100644 --- a/go.sum +++ b/go.sum @@ -1,19 +1,19 @@ -cloud.google.com/go v0.112.0 h1:tpFCD7hpHFlQ8yPwT3x+QeXqc2T6+n6T+hmABHfDUSM= -cloud.google.com/go v0.112.0/go.mod h1:3jEEVwZ/MHU4djK5t5RHuKOA/GbLddgTdVubX1qnPD4= -cloud.google.com/go/compute v1.24.0 h1:phWcR2eWzRJaL/kOiJwfFsPs4BaKq1j6vnpZrc1YlVg= -cloud.google.com/go/compute v1.24.0/go.mod h1:kw1/T+h/+tK2LJK0wiPPx1intgdAM3j/g3hFDlscY40= +cloud.google.com/go v0.112.1 h1:uJSeirPke5UNZHIb4SxfZklVSiWWVqW4oXlETwZziwM= +cloud.google.com/go v0.112.1/go.mod h1:+Vbu+Y1UU+I1rjmzeMOb/8RfkKJK2Gyxi1X6jJCZLo4= +cloud.google.com/go/compute v1.25.1 h1:ZRpHJedLtTpKgr3RV1Fx23NuaAEN1Zfx9hw1u4aJdjU= +cloud.google.com/go/compute v1.25.1/go.mod h1:oopOIR53ly6viBYxaDhBfJwzUAxf1zE//uf3IB011ls= cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= -cloud.google.com/go/iam v1.1.6 h1:bEa06k05IO4f4uJonbB5iAgKTPpABy1ayxaIZV/GHVc= -cloud.google.com/go/iam v1.1.6/go.mod h1:O0zxdPeGBoFdWW3HWmBxJsk0pfvNM/p/qa82rWOGTwI= -cloud.google.com/go/storage v1.39.0 h1:brbjUa4hbDHhpQf48tjqMaXEV+f1OGoaTmQau9tmCsA= -cloud.google.com/go/storage v1.39.0/go.mod h1:OAEj/WZwUYjA3YHQ10/YcN9ttGuEpLwvaoyBXIPikEk= -github.com/aws/aws-sdk-go v1.51.2 h1:Ruwgz5aqIXin5Yfcgc+PCzoqW5tEGb9aDL/JWDsre7k= -github.com/aws/aws-sdk-go v1.51.2/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +cloud.google.com/go/iam v1.1.7 h1:z4VHOhwKLF/+UYXAJDFwGtNF0b6gjsW1Pk9Ml0U/IoM= +cloud.google.com/go/iam v1.1.7/go.mod h1:J4PMPg8TtyurAUvSmPj8FF3EDgY1SPRZxcUGrn7WXGA= +cloud.google.com/go/storage v1.39.1 h1:MvraqHKhogCOTXTlct/9C3K3+Uy2jBmFYb3/Sp6dVtY= +cloud.google.com/go/storage v1.39.1/go.mod h1:xK6xZmxZmo+fyP7+DEF6FhNc24/JAe95OLyOHCXFH1o= +github.com/aws/aws-sdk-go v1.51.6 h1:Ld36dn9r7P9IjU8WZSaswQ8Y/XUCRpewim5980DwYiU= +github.com/aws/aws-sdk-go v1.51.6/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= github.com/bodaay/HuggingFaceModelDownloader v0.0.0-20240307153905-2f38356a6d6c h1:3TPq2BhzOquTGmbS53KeGcM1yalBUb/4zQM1wmaINrE= github.com/bodaay/HuggingFaceModelDownloader v0.0.0-20240307153905-2f38356a6d6c/go.mod h1:p6JQ7mJjWx82F+SrFfj9RkoHlKEGXR4959uX/vkMbzE= -github.com/cpuguy83/go-md2man/v2 v2.0.3 h1:qMCsGGgs+MAzDFyp9LpAe1Lqy/fY/qCovCm0qnXZOBM= -github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4= +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -39,8 +39,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= -github.com/googleapis/gax-go/v2 v2.12.2 h1:mhN09QQW1jEWeMF74zGR81R30z4VJzjZsfkUhuHF+DA= -github.com/googleapis/gax-go/v2 v2.12.2/go.mod h1:61M8vcyyXR2kqKFxKrfA22jaA8JGF7Dc8App1U3H6jc= +github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/cLCKqA= +github.com/googleapis/gax-go/v2 v2.12.3/go.mod h1:AKloxT6GtNbaLm8QTNSidHUVsHYcBHwWRvkNFJUQcS4= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= @@ -115,18 +115,18 @@ golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -google.golang.org/api v0.169.0 h1:QwWPy71FgMWqJN/l6jVlFHUa29a7dcUy02I8o799nPY= -google.golang.org/api v0.169.0/go.mod h1:gpNOiMA2tZ4mf5R9Iwf4rK/Dcz0fbdIgWYWVoxmsyLg= +google.golang.org/api v0.171.0 h1:w174hnBPqut76FzW5Qaupt7zY8Kql6fiVjgys4f58sU= +google.golang.org/api v0.171.0/go.mod h1:Hnq5AHm4OTMt2BUVjael2CWZFD6vksJdWCWiUAmjC9o= google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= -google.golang.org/genproto v0.0.0-20240213162025-012b6fc9bca9 h1:9+tzLLstTlPTRyJTh+ah5wIMsBW5c4tQwGTN3thOW9Y= -google.golang.org/genproto v0.0.0-20240213162025-012b6fc9bca9/go.mod h1:mqHbVIp48Muh7Ywss/AD6I5kNVKZMmAa/QEW58Gxp2s= -google.golang.org/genproto/googleapis/api v0.0.0-20240221002015-b0ce06bbee7c h1:9g7erC9qu44ks7UK4gDNlnk4kOxZG707xKm4jVniy6o= -google.golang.org/genproto/googleapis/api v0.0.0-20240221002015-b0ce06bbee7c/go.mod h1:5iCWqnniDlqZHrd3neWVTOwvh/v6s3232omMecelax8= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240304161311-37d4d3c04a78 h1:Xs9lu+tLXxLIfuci70nG4cpwaRC+mRQPUL7LoIeDJC4= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240304161311-37d4d3c04a78/go.mod h1:UCOku4NytXMJuLQE5VuqA5lX3PcHCBo8pxNyvkf4xBs= -google.golang.org/grpc v1.62.0 h1:HQKZ/fa1bXkX1oFOvSjmZEUL8wLSaZTjCcLAlmZRtdk= -google.golang.org/grpc v1.62.0/go.mod h1:IWTG0VlJLCh1SkC58F7np9ka9mx/WNkjl4PGJaiq+QE= +google.golang.org/genproto v0.0.0-20240318140521-94a12d6c2237 h1:PgNlNSx2Nq2/j4juYzQBG0/Zdr+WP4z5N01Vk4VYBCY= +google.golang.org/genproto v0.0.0-20240318140521-94a12d6c2237/go.mod h1:9sVD8c25Af3p0rGs7S7LLsxWKFiJt/65LdSyqXBkX/Y= +google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 h1:RFiFrvy37/mpSpdySBDrUdipW/dHwsRwh3J3+A9VgT4= +google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237/go.mod h1:Z5Iiy3jtmioajWHDGFk7CeugTyHtPvMHA4UTmUkyalE= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 h1:NnYq6UN9ReLM9/Y01KWNOWyI5xQ9kbIms5GGJVwS/Yc= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY= +google.golang.org/grpc v1.62.1 h1:B4n+nfKzOICUXMgyrNd19h/I9oH0L1pizfk1d4zSgTk= +google.golang.org/grpc v1.62.1/go.mod h1:IWTG0VlJLCh1SkC58F7np9ka9mx/WNkjl4PGJaiq+QE= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/hugot.go b/hugot.go index aac93c3..aabb3eb 100644 --- a/hugot.go +++ b/hugot.go @@ -4,10 +4,10 @@ import ( "context" "errors" "fmt" + util "github.com/knights-analytics/hugot/utils" "slices" "github.com/knights-analytics/hugot/pipelines" - util "github.com/knights-analytics/hugot/utils" ort "github.com/yalue/onnxruntime_go" ) @@ -45,61 +45,16 @@ func (m pipelineMap[T]) GetStats() []string { return stats } -type SessionOption func() error - -func WithOnnxLibraryPath(ortLibraryPath string) SessionOption { - return func() error { - if ortLibraryPath == "" { - return fmt.Errorf("path to the ort library cannot be empty") - } - ortPathExists, err := util.FileSystem.Exists(context.Background(), ortLibraryPath) - if err != nil { - return err - } - if !ortPathExists { - return fmt.Errorf("cannot find the ort library at: %s", ortLibraryPath) - } - ort.SetSharedLibraryPath(ortLibraryPath) - return nil - } -} - -func (s *Session) setSessionOptions() error { - options, optionsError := ort.NewSessionOptions() - if optionsError != nil { - return optionsError - } - err1 := options.SetIntraOpNumThreads(1) - if err1 != nil { - return err1 - } - err2 := options.SetInterOpNumThreads(1) - if err2 != nil { - return err2 - } - err3 := options.SetCpuMemArena(true) - if err3 != nil { - return err3 - } - s.ortOptions = options - return nil -} - // NewSession is the main entrypoint to hugot and is used to create a new hugot session object. // ortLibraryPath should be the path to onnxruntime.so. If it's the empty string, hugot will try // to load the library from the default location (/usr/lib/onnxruntime.so). // A new session must be destroyed when it's not needed anymore to avoid memory leaks. See the Destroy method. // Note moreover that there can be at most one hugot session active (i.e., the Session object is a singleton), // otherwise NewSession will return an error. -func NewSession(options ...SessionOption) (*Session, error) { +func NewSession(options ...WithOption) (*Session, error) { if ort.IsInitialized() { - return nil, errors.New("another session is currently active and only one session can be active at one time") - } else { - err := ort.InitializeEnvironment() - if err != nil { - return nil, err - } + return nil, errors.New("another session is currently active, and only one session can be active at one time") } session := &Session{ @@ -108,29 +63,81 @@ func NewSession(options ...SessionOption) (*Session, error) { textClassificationPipelines: map[string]*pipelines.TextClassificationPipeline{}, } - telemetryErr := ort.DisableTelemetry() - if telemetryErr != nil { - destroyErr := session.Destroy() - return nil, errors.Join(telemetryErr, destroyErr) + // set session options and initialise + if initialised, err := session.initialiseORT(options...); err != nil { + if initialised { + destroyErr := session.Destroy() + return nil, errors.Join(err, destroyErr) + } + return nil, err + } + + return session, nil +} + +func (s *Session) initialiseORT(options ...WithOption) (bool, error) { + + // Collect options into a struct, so they can be applied in the correct order later + o := &ortOptions{} + for _, option := range options { + option(o) } - // set session options - optionsErr := session.setSessionOptions() - if optionsErr != nil { - destroyErr := session.Destroy() - return nil, errors.Join(optionsErr, destroyErr) + // Set pre-initialisation options + if o.libraryPath != "" { + ortPathExists, err := util.FileSystem.Exists(context.Background(), o.libraryPath) + if err != nil { + return false, err + } + if !ortPathExists { + return false, fmt.Errorf("cannot find the ort library at: %s", o.libraryPath) + } + ort.SetSharedLibraryPath(o.libraryPath) + } + + // Start OnnxRuntime + if err := ort.InitializeEnvironment(); err != nil { + return false, err } - for _, opt := range options { - if opt != nil { - optSetErr := opt() - if optSetErr != nil { - destroyErr := session.Destroy() - return nil, errors.Join(optSetErr, destroyErr) - } + if o.telemetry { + if err := ort.EnableTelemetry(); err != nil { + return true, err + } + } else { + if err := ort.DisableTelemetry(); err != nil { + return true, err } } - return session, nil + + // Create session options for use in all pipelines + sessionOptions, optionsError := ort.NewSessionOptions() + if optionsError != nil { + return true, optionsError + } + if o.intraOpNumThreads != 0 { + if err := sessionOptions.SetIntraOpNumThreads(o.intraOpNumThreads); err != nil { + return true, err + } + } + if o.interOpNumThreads != 0 { + if err := sessionOptions.SetInterOpNumThreads(o.interOpNumThreads); err != nil { + return true, err + } + } + if !o.cpuMemArenaSet { + if err := sessionOptions.SetCpuMemArena(o.cpuMemArena); err != nil { + return true, err + } + } + if !o.memPatternSet { + if err := sessionOptions.SetMemPattern(o.memPattern); err != nil { + return true, err + } + } + + s.ortOptions = sessionOptions + return true, nil } // NewTokenClassificationPipeline creates and returns a new token classification pipeline object. diff --git a/hugot_test.go b/hugot_test.go index 64c7e1d..1c94e53 100644 --- a/hugot_test.go +++ b/hugot_test.go @@ -22,7 +22,7 @@ var tokenExpectedByte []byte var resultsByte []byte // use the system library for the tests -var onnxruntimeSharedLibrary = "/usr/lib64/onnxruntime.so" +const onnxRuntimeSharedLibrary = "/usr/lib64/onnxruntime.so" // test download validation @@ -37,7 +37,14 @@ func TestDownloadValidation(t *testing.T) { // Text classification func TestTextClassificationPipeline(t *testing.T) { - session, err := NewSession(WithOnnxLibraryPath(onnxruntimeSharedLibrary)) + session, err := NewSession( + WithOnnxLibraryPath(onnxRuntimeSharedLibrary), + WithTelemetry(), + WithCpuMemArena(true), + WithMemPattern(true), + WithIntraOpNumThreads(1), + WithInterOpNumThreads(1), + ) check(t, err) defer func(session *Session) { err := session.Destroy() @@ -90,13 +97,8 @@ func TestTextClassificationPipeline(t *testing.T) { session.GetStats() } -func TestNewSessionErrors(t *testing.T) { - _, err := NewSession(WithOnnxLibraryPath("")) - assert.Error(t, err) -} - func TestTextClassificationPipelineValidation(t *testing.T) { - session, err := NewSession(WithOnnxLibraryPath(onnxruntimeSharedLibrary)) + session, err := NewSession(WithOnnxLibraryPath(onnxRuntimeSharedLibrary)) check(t, err) defer func(session *Session) { err := session.Destroy() @@ -124,7 +126,7 @@ func TestTextClassificationPipelineValidation(t *testing.T) { // Token classification func TestTokenClassificationPipeline(t *testing.T) { - session, err := NewSession(WithOnnxLibraryPath(onnxruntimeSharedLibrary)) + session, err := NewSession(WithOnnxLibraryPath(onnxRuntimeSharedLibrary)) check(t, err) defer func(session *Session) { err := session.Destroy() @@ -185,7 +187,7 @@ func TestTokenClassificationPipeline(t *testing.T) { } func TestTokenClassificationPipelineValidation(t *testing.T) { - session, err := NewSession(WithOnnxLibraryPath(onnxruntimeSharedLibrary)) + session, err := NewSession(WithOnnxLibraryPath(onnxRuntimeSharedLibrary)) check(t, err) defer func(session *Session) { err := session.Destroy() @@ -215,7 +217,7 @@ func TestTokenClassificationPipelineValidation(t *testing.T) { // feature extraction func TestFeatureExtractionPipeline(t *testing.T) { - session, err := NewSession(WithOnnxLibraryPath(onnxruntimeSharedLibrary)) + session, err := NewSession(WithOnnxLibraryPath(onnxRuntimeSharedLibrary)) check(t, err) defer func(session *Session) { err := session.Destroy() @@ -289,7 +291,7 @@ func TestFeatureExtractionPipeline(t *testing.T) { } func TestFeatureExtractionPipelineValidation(t *testing.T) { - session, err := NewSession(WithOnnxLibraryPath(onnxruntimeSharedLibrary)) + session, err := NewSession(WithOnnxLibraryPath(onnxRuntimeSharedLibrary)) check(t, err) defer func(session *Session) { err := session.Destroy() diff --git a/options.go b/options.go new file mode 100644 index 0000000..4d337fb --- /dev/null +++ b/options.go @@ -0,0 +1,64 @@ +package hugot + +type ortOptions struct { + libraryPath string + telemetry bool + intraOpNumThreads int + interOpNumThreads int + cpuMemArena bool + cpuMemArenaSet bool + memPattern bool + memPatternSet bool +} + +// WithOption is the interface for all option functions +type WithOption func(o *ortOptions) + +// WithOnnxLibraryPath Use this function to set the path to the "onnxruntime.so" or "onnxruntime.dll" function. +// By default, it will be set to "onnxruntime.so" on non-Windows systems, and "onnxruntime.dll" on Windows. +func WithOnnxLibraryPath(ortLibraryPath string) WithOption { + return func(o *ortOptions) { + o.libraryPath = ortLibraryPath + } +} + +// WithTelemetry Enables telemetry events for the onnxruntime environment. Default is off. +func WithTelemetry() WithOption { + return func(o *ortOptions) { + o.telemetry = true + } +} + +// WithIntraOpNumThreads Sets the number of threads used to parallelize execution within onnxruntime +// graph nodes. If unspecified, onnxruntime uses the number of physical CPU cores. +func WithIntraOpNumThreads(numThreads int) WithOption { + return func(o *ortOptions) { + o.intraOpNumThreads = numThreads + } +} + +// WithInterOpNumThreads Sets the number of threads used to parallelize execution across separate +// onnxruntime graph nodes. If unspecified, onnxruntime uses the number of physical CPU cores. +func WithInterOpNumThreads(numThreads int) WithOption { + return func(o *ortOptions) { + o.interOpNumThreads = numThreads + } +} + +// WithCpuMemArena Enable/Disable the usage of the memory arena on CPU. +// Arena may pre-allocate memory for future usage. Default is true. +func WithCpuMemArena(enable bool) WithOption { + return func(o *ortOptions) { + o.cpuMemArena = enable + o.cpuMemArenaSet = true + } +} + +// WithMemPattern Enable/Disable the memory pattern optimization. +// If this is enabled memory is preallocated if all shapes are known. Default is true. +func WithMemPattern(enable bool) WithOption { + return func(o *ortOptions) { + o.memPattern = enable + o.memPatternSet = true + } +} diff --git a/pipelines/tokenClassification.go b/pipelines/tokenClassification.go index 70febfc..2745efb 100644 --- a/pipelines/tokenClassification.go +++ b/pipelines/tokenClassification.go @@ -289,7 +289,7 @@ func (p *TokenClassificationPipeline) getTag(entityName string) (string, string) bi = "I" tag = entityName[2:] } else { - // defaulting to I if string is not in B- I- format + // defaulting to "I" if string is not in B- I- format bi = "I" tag = entityName } diff --git a/scripts/run-unit-tests-container.sh b/scripts/run-unit-tests-container.sh index 88ecd96..f5651f5 100755 --- a/scripts/run-unit-tests-container.sh +++ b/scripts/run-unit-tests-container.sh @@ -5,6 +5,7 @@ set -e cd /build && \ mkdir -p /test/unit && \ go run ./testData/downloadModels.go && \ -gotestsum --junitfile=/test/unit/unit.xml --jsonfile=/test/unit/unit.json -- -coverprofile=/test/unit/cover.out -race -covermode=atomic ./... +gotestsum --junitfile=/test/unit/unit.xml --jsonfile=/test/unit/unit.json -- -coverprofile=/test/unit/cover.out.pre -race -covermode=atomic ./... +cat /test/unit/cover.out.pre | grep -v "_downloadModels.go" > /test/unit/cover.out echo Done.