Skip to content

Commit

Permalink
Allow all onnxruntime_go tuning options to be set in a new session
Browse files Browse the repository at this point in the history
  • Loading branch information
RJKeevil committed Mar 25, 2024
1 parent 6835ee7 commit 38534cb
Show file tree
Hide file tree
Showing 10 changed files with 210 additions and 117 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v2
with:
go-version: '1.20.0'
go-version: '1.22.1'
- name: Checkout code
uses: actions/checkout@v4
- name: Install dependencies
Expand Down Expand Up @@ -73,4 +73,4 @@ jobs:
with:
artifacts: "libtokenizers.a, onnxruntime.so, hugot-cli-linux-amd64"
generateReleaseNotes: true
skipIfReleaseExists: true
skipIfReleaseExists: true
21 changes: 19 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ For the golang developer or ML engineer who wants to run transformer piplines on

## What is already there

Currently we have implementations for the following transfomer pipelines:
Currently, we have implementations for the following transfomer pipelines:

- [featureExtraction](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.FeatureExtractionPipeline)
- [textClassification](https://huggingface.co/docs/transformers/en/main_classes/pipelines#transformers.TextClassificationPipeline) (single label classification only)
Expand Down Expand Up @@ -167,6 +167,23 @@ Note that the --model parameter can be:
1. the full path to a model to load
2. the name of a huggingface model. Hugot will first try to look for the model at $HOME/hugot, or will try to download the model from huggingface.

## Performance Tuning

The library defaults to onnxruntime's default tuning settings. These are optimised for latency over throughput, and will attempt to parallelize single threaded calls to onnxruntime over multiple cores.

For maximum throughput, it is best to call a single shared hugot pipeline from multiple goroutines (1 per core), using channels to pass the input data. In this scenario, the following settings will greatly increase inference throughput.

```go
session, err := hugot.NewSession(
hugot.WithInterOpNumThreads(1),
hugot.WithIntraOpNumThreads(1),
hugot.WithCpuMemArena(false),
hugot.WithMemPattern(false),
)
```

InterOpNumThreads and IntraOpNumThreads constricts each goroutine's call to a single core, greatly reducing locking and cache penalties. Disabling CpuMemArena and MemPattern skips preallocation of some memory structures, increasing latency, but also throughput efficiency.

## Contributing

### Development environment
Expand Down Expand Up @@ -205,7 +222,7 @@ If you prefer to develop on bare metal, you will need to download the tokenizers

### Run the tests

The full test suite can be ran as follows. From the source folder:
The full test suite can be run as follows. From the source folder:

```bash
make clean run-tests
Expand Down
8 changes: 4 additions & 4 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ var runCommand = &cli.Command{
},
Action: func(ctx *cli.Context) error {

var onnxLibraryPathOpt hugot.SessionOption
var opts []hugot.WithOption

if modelsDir == "" {
userDir, err := os.UserHomeDir()
Expand All @@ -105,17 +105,17 @@ var runCommand = &cli.Command{
}

if sharedLibraryPath != "" {
onnxLibraryPathOpt = hugot.WithOnnxLibraryPath(sharedLibraryPath)
opts = append(opts, hugot.WithOnnxLibraryPath(sharedLibraryPath))
} else {
homeDir, err := os.UserHomeDir()
if err != nil {
if exists, err := util.FileSystem.Exists(ctx.Context, path.Join(homeDir, "lib", "hugot", "onnxruntime.so")); err != nil && exists {
onnxLibraryPathOpt = hugot.WithOnnxLibraryPath(path.Join(homeDir, "lib", "hugot", "onnxruntime.so"))
opts = append(opts, hugot.WithOnnxLibraryPath(path.Join(homeDir, "lib", "hugot", "onnxruntime.so")))
}
}
}

session, err := hugot.NewSession(onnxLibraryPathOpt)
session, err := hugot.NewSession(opts...)
if err != nil {
return err
}
Expand Down
14 changes: 8 additions & 6 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ require (
)

require (
cloud.google.com/go/storage v1.39.0 // indirect
github.com/aws/aws-sdk-go v1.51.2 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.3 // indirect
cloud.google.com/go/iam v1.1.7 // indirect
cloud.google.com/go/storage v1.39.1 // indirect
github.com/aws/aws-sdk-go v1.51.6 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/fatih/color v1.16.0 // indirect
github.com/go-errors/errors v1.5.1 // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/mattn/go-colorable v0.1.13 // indirect
Expand All @@ -35,7 +35,9 @@ require (
golang.org/x/net v0.22.0 // indirect
golang.org/x/oauth2 v0.18.0 // indirect
golang.org/x/sys v0.18.0 // indirect
google.golang.org/api v0.169.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect
google.golang.org/api v0.171.0 // indirect
google.golang.org/genproto v0.0.0-20240318140521-94a12d6c2237 // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
48 changes: 24 additions & 24 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
cloud.google.com/go v0.112.0 h1:tpFCD7hpHFlQ8yPwT3x+QeXqc2T6+n6T+hmABHfDUSM=
cloud.google.com/go v0.112.0/go.mod h1:3jEEVwZ/MHU4djK5t5RHuKOA/GbLddgTdVubX1qnPD4=
cloud.google.com/go/compute v1.24.0 h1:phWcR2eWzRJaL/kOiJwfFsPs4BaKq1j6vnpZrc1YlVg=
cloud.google.com/go/compute v1.24.0/go.mod h1:kw1/T+h/+tK2LJK0wiPPx1intgdAM3j/g3hFDlscY40=
cloud.google.com/go v0.112.1 h1:uJSeirPke5UNZHIb4SxfZklVSiWWVqW4oXlETwZziwM=
cloud.google.com/go v0.112.1/go.mod h1:+Vbu+Y1UU+I1rjmzeMOb/8RfkKJK2Gyxi1X6jJCZLo4=
cloud.google.com/go/compute v1.25.1 h1:ZRpHJedLtTpKgr3RV1Fx23NuaAEN1Zfx9hw1u4aJdjU=
cloud.google.com/go/compute v1.25.1/go.mod h1:oopOIR53ly6viBYxaDhBfJwzUAxf1zE//uf3IB011ls=
cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA=
cloud.google.com/go/iam v1.1.6 h1:bEa06k05IO4f4uJonbB5iAgKTPpABy1ayxaIZV/GHVc=
cloud.google.com/go/iam v1.1.6/go.mod h1:O0zxdPeGBoFdWW3HWmBxJsk0pfvNM/p/qa82rWOGTwI=
cloud.google.com/go/storage v1.39.0 h1:brbjUa4hbDHhpQf48tjqMaXEV+f1OGoaTmQau9tmCsA=
cloud.google.com/go/storage v1.39.0/go.mod h1:OAEj/WZwUYjA3YHQ10/YcN9ttGuEpLwvaoyBXIPikEk=
github.com/aws/aws-sdk-go v1.51.2 h1:Ruwgz5aqIXin5Yfcgc+PCzoqW5tEGb9aDL/JWDsre7k=
github.com/aws/aws-sdk-go v1.51.2/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
cloud.google.com/go/iam v1.1.7 h1:z4VHOhwKLF/+UYXAJDFwGtNF0b6gjsW1Pk9Ml0U/IoM=
cloud.google.com/go/iam v1.1.7/go.mod h1:J4PMPg8TtyurAUvSmPj8FF3EDgY1SPRZxcUGrn7WXGA=
cloud.google.com/go/storage v1.39.1 h1:MvraqHKhogCOTXTlct/9C3K3+Uy2jBmFYb3/Sp6dVtY=
cloud.google.com/go/storage v1.39.1/go.mod h1:xK6xZmxZmo+fyP7+DEF6FhNc24/JAe95OLyOHCXFH1o=
github.com/aws/aws-sdk-go v1.51.6 h1:Ld36dn9r7P9IjU8WZSaswQ8Y/XUCRpewim5980DwYiU=
github.com/aws/aws-sdk-go v1.51.6/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk=
github.com/bodaay/HuggingFaceModelDownloader v0.0.0-20240307153905-2f38356a6d6c h1:3TPq2BhzOquTGmbS53KeGcM1yalBUb/4zQM1wmaINrE=
github.com/bodaay/HuggingFaceModelDownloader v0.0.0-20240307153905-2f38356a6d6c/go.mod h1:p6JQ7mJjWx82F+SrFfj9RkoHlKEGXR4959uX/vkMbzE=
github.com/cpuguy83/go-md2man/v2 v2.0.3 h1:qMCsGGgs+MAzDFyp9LpAe1Lqy/fY/qCovCm0qnXZOBM=
github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
Expand All @@ -39,8 +39,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs=
github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0=
github.com/googleapis/gax-go/v2 v2.12.2 h1:mhN09QQW1jEWeMF74zGR81R30z4VJzjZsfkUhuHF+DA=
github.com/googleapis/gax-go/v2 v2.12.2/go.mod h1:61M8vcyyXR2kqKFxKrfA22jaA8JGF7Dc8App1U3H6jc=
github.com/googleapis/gax-go/v2 v2.12.3 h1:5/zPPDvw8Q1SuXjrqrZslrqT7dL/uJT2CQii/cLCKqA=
github.com/googleapis/gax-go/v2 v2.12.3/go.mod h1:AKloxT6GtNbaLm8QTNSidHUVsHYcBHwWRvkNFJUQcS4=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
Expand Down Expand Up @@ -115,18 +115,18 @@ golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
google.golang.org/api v0.169.0 h1:QwWPy71FgMWqJN/l6jVlFHUa29a7dcUy02I8o799nPY=
google.golang.org/api v0.169.0/go.mod h1:gpNOiMA2tZ4mf5R9Iwf4rK/Dcz0fbdIgWYWVoxmsyLg=
google.golang.org/api v0.171.0 h1:w174hnBPqut76FzW5Qaupt7zY8Kql6fiVjgys4f58sU=
google.golang.org/api v0.171.0/go.mod h1:Hnq5AHm4OTMt2BUVjael2CWZFD6vksJdWCWiUAmjC9o=
google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM=
google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds=
google.golang.org/genproto v0.0.0-20240213162025-012b6fc9bca9 h1:9+tzLLstTlPTRyJTh+ah5wIMsBW5c4tQwGTN3thOW9Y=
google.golang.org/genproto v0.0.0-20240213162025-012b6fc9bca9/go.mod h1:mqHbVIp48Muh7Ywss/AD6I5kNVKZMmAa/QEW58Gxp2s=
google.golang.org/genproto/googleapis/api v0.0.0-20240221002015-b0ce06bbee7c h1:9g7erC9qu44ks7UK4gDNlnk4kOxZG707xKm4jVniy6o=
google.golang.org/genproto/googleapis/api v0.0.0-20240221002015-b0ce06bbee7c/go.mod h1:5iCWqnniDlqZHrd3neWVTOwvh/v6s3232omMecelax8=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240304161311-37d4d3c04a78 h1:Xs9lu+tLXxLIfuci70nG4cpwaRC+mRQPUL7LoIeDJC4=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240304161311-37d4d3c04a78/go.mod h1:UCOku4NytXMJuLQE5VuqA5lX3PcHCBo8pxNyvkf4xBs=
google.golang.org/grpc v1.62.0 h1:HQKZ/fa1bXkX1oFOvSjmZEUL8wLSaZTjCcLAlmZRtdk=
google.golang.org/grpc v1.62.0/go.mod h1:IWTG0VlJLCh1SkC58F7np9ka9mx/WNkjl4PGJaiq+QE=
google.golang.org/genproto v0.0.0-20240318140521-94a12d6c2237 h1:PgNlNSx2Nq2/j4juYzQBG0/Zdr+WP4z5N01Vk4VYBCY=
google.golang.org/genproto v0.0.0-20240318140521-94a12d6c2237/go.mod h1:9sVD8c25Af3p0rGs7S7LLsxWKFiJt/65LdSyqXBkX/Y=
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 h1:RFiFrvy37/mpSpdySBDrUdipW/dHwsRwh3J3+A9VgT4=
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237/go.mod h1:Z5Iiy3jtmioajWHDGFk7CeugTyHtPvMHA4UTmUkyalE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 h1:NnYq6UN9ReLM9/Y01KWNOWyI5xQ9kbIms5GGJVwS/Yc=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY=
google.golang.org/grpc v1.62.1 h1:B4n+nfKzOICUXMgyrNd19h/I9oH0L1pizfk1d4zSgTk=
google.golang.org/grpc v1.62.1/go.mod h1:IWTG0VlJLCh1SkC58F7np9ka9mx/WNkjl4PGJaiq+QE=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
137 changes: 72 additions & 65 deletions hugot.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"context"
"errors"
"fmt"
util "github.com/knights-analytics/hugot/utils"
"slices"

"github.com/knights-analytics/hugot/pipelines"
util "github.com/knights-analytics/hugot/utils"
ort "github.com/yalue/onnxruntime_go"
)

Expand Down Expand Up @@ -45,61 +45,16 @@ func (m pipelineMap[T]) GetStats() []string {
return stats
}

type SessionOption func() error

func WithOnnxLibraryPath(ortLibraryPath string) SessionOption {
return func() error {
if ortLibraryPath == "" {
return fmt.Errorf("path to the ort library cannot be empty")
}
ortPathExists, err := util.FileSystem.Exists(context.Background(), ortLibraryPath)
if err != nil {
return err
}
if !ortPathExists {
return fmt.Errorf("cannot find the ort library at: %s", ortLibraryPath)
}
ort.SetSharedLibraryPath(ortLibraryPath)
return nil
}
}

func (s *Session) setSessionOptions() error {
options, optionsError := ort.NewSessionOptions()
if optionsError != nil {
return optionsError
}
err1 := options.SetIntraOpNumThreads(1)
if err1 != nil {
return err1
}
err2 := options.SetInterOpNumThreads(1)
if err2 != nil {
return err2
}
err3 := options.SetCpuMemArena(true)
if err3 != nil {
return err3
}
s.ortOptions = options
return nil
}

// NewSession is the main entrypoint to hugot and is used to create a new hugot session object.
// ortLibraryPath should be the path to onnxruntime.so. If it's the empty string, hugot will try
// to load the library from the default location (/usr/lib/onnxruntime.so).
// A new session must be destroyed when it's not needed anymore to avoid memory leaks. See the Destroy method.
// Note moreover that there can be at most one hugot session active (i.e., the Session object is a singleton),
// otherwise NewSession will return an error.
func NewSession(options ...SessionOption) (*Session, error) {
func NewSession(options ...WithOption) (*Session, error) {

if ort.IsInitialized() {
return nil, errors.New("another session is currently active and only one session can be active at one time")
} else {
err := ort.InitializeEnvironment()
if err != nil {
return nil, err
}
return nil, errors.New("another session is currently active, and only one session can be active at one time")
}

session := &Session{
Expand All @@ -108,29 +63,81 @@ func NewSession(options ...SessionOption) (*Session, error) {
textClassificationPipelines: map[string]*pipelines.TextClassificationPipeline{},
}

telemetryErr := ort.DisableTelemetry()
if telemetryErr != nil {
destroyErr := session.Destroy()
return nil, errors.Join(telemetryErr, destroyErr)
// set session options and initialise
if initialised, err := session.initialiseORT(options...); err != nil {
if initialised {
destroyErr := session.Destroy()
return nil, errors.Join(err, destroyErr)
}
return nil, err
}

return session, nil
}

func (s *Session) initialiseORT(options ...WithOption) (bool, error) {

// Collect options into a struct, so they can be applied in the correct order later
o := &ortOptions{}
for _, option := range options {
option(o)
}

// set session options
optionsErr := session.setSessionOptions()
if optionsErr != nil {
destroyErr := session.Destroy()
return nil, errors.Join(optionsErr, destroyErr)
// Set pre-initialisation options
if o.libraryPath != "" {
ortPathExists, err := util.FileSystem.Exists(context.Background(), o.libraryPath)
if err != nil {
return false, err
}
if !ortPathExists {
return false, fmt.Errorf("cannot find the ort library at: %s", o.libraryPath)
}
ort.SetSharedLibraryPath(o.libraryPath)
}

// Start OnnxRuntime
if err := ort.InitializeEnvironment(); err != nil {
return false, err
}

for _, opt := range options {
if opt != nil {
optSetErr := opt()
if optSetErr != nil {
destroyErr := session.Destroy()
return nil, errors.Join(optSetErr, destroyErr)
}
if o.telemetry {
if err := ort.EnableTelemetry(); err != nil {
return true, err
}
} else {
if err := ort.DisableTelemetry(); err != nil {
return true, err
}
}
return session, nil

// Create session options for use in all pipelines
sessionOptions, optionsError := ort.NewSessionOptions()
if optionsError != nil {
return true, optionsError
}
if o.intraOpNumThreads != 0 {
if err := sessionOptions.SetIntraOpNumThreads(o.intraOpNumThreads); err != nil {
return true, err
}
}
if o.interOpNumThreads != 0 {
if err := sessionOptions.SetInterOpNumThreads(o.interOpNumThreads); err != nil {
return true, err
}
}
if !o.cpuMemArenaSet {
if err := sessionOptions.SetCpuMemArena(o.cpuMemArena); err != nil {
return true, err
}
}
if !o.memPatternSet {
if err := sessionOptions.SetMemPattern(o.memPattern); err != nil {
return true, err
}
}

s.ortOptions = sessionOptions
return true, nil
}

// NewTokenClassificationPipeline creates and returns a new token classification pipeline object.
Expand Down
Loading

0 comments on commit 38534cb

Please sign in to comment.