Skip to content

Commit

Permalink
creating separate validation method for pipelines and tests for it
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardopinosio committed Mar 11, 2024
1 parent 99aeef5 commit a0a9770
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 25 deletions.
81 changes: 81 additions & 0 deletions hugot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/knights-analytics/hugot"
"github.com/knights-analytics/hugot/pipelines"
util "github.com/knights-analytics/hugot/utils"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -88,6 +89,35 @@ func TestTextClassificationPipeline(t *testing.T) {
session.GetStats()
}

// TestTextClassificationPipelineValidation checks that Validate on a text
// classification pipeline reports every violated constraint as one joined
// error, first with an empty label map and then with a zero output dimension
// as well.
func TestTextClassificationPipelineValidation(t *testing.T) {
	session, err := hugot.NewSession(onnxruntimeSharedLibrary)
	check(err)
	defer func(session *hugot.Session) {
		check(session.Destroy())
	}(session)
	modelFolder := os.Getenv("TEST_MODELS_FOLDER")
	if modelFolder == "" {
		modelFolder = "./models"
	}
	modelPath := path.Join(modelFolder, "distilbert-base-uncased-finetuned-sst-2-english")
	sentimentPipeline, err := session.NewTextClassificationPipeline(modelPath, "testPipeline", pipelines.WithAggregationFunction(util.SoftMax))
	check(err)

	// assertJoinedErrors fails the test, rather than panicking on a bare type
	// assertion, when err is nil or does not expose Unwrap() []error.
	assertJoinedErrors := func(err error, expected int) {
		t.Helper()
		if !assert.Error(t, err) {
			return
		}
		joined, ok := err.(interface{ Unwrap() []error })
		if !assert.True(t, ok, "validation error should expose Unwrap() []error") {
			return
		}
		assert.Equal(t, expected, len(joined.Unwrap()))
	}

	// An empty label map is expected to produce three validation errors.
	sentimentPipeline.IdLabelMap = map[int]string{}
	assertJoinedErrors(sentimentPipeline.Validate(), 3)

	// Zeroing the output dimension as well is still expected to produce three.
	sentimentPipeline.OutputDim = 0
	assertJoinedErrors(sentimentPipeline.Validate(), 3)
}

// Token classification

func TestTokenClassificationPipeline(t *testing.T) {
Expand Down Expand Up @@ -158,6 +188,37 @@ func TestTokenClassificationPipeline(t *testing.T) {
}
}

// TestTokenClassificationPipelineValidation checks that Validate on a token
// classification pipeline reports every violated constraint as one joined
// error, first with an empty label map and then with a zero output dimension
// as well.
func TestTokenClassificationPipelineValidation(t *testing.T) {
	session, err := hugot.NewSession(onnxruntimeSharedLibrary)
	check(err)
	defer func(session *hugot.Session) {
		check(session.Destroy())
	}(session)

	modelFolder := os.Getenv("TEST_MODELS_FOLDER")
	if modelFolder == "" {
		modelFolder = "./models"
	}
	modelPath := path.Join(modelFolder, "distilbert-NER")
	pipelineSimple, err := session.NewTokenClassificationPipeline(modelPath, "testPipelineSimple", pipelines.WithSimpleAggregation())
	check(err)

	// assertJoinedErrors fails the test, rather than panicking on a bare type
	// assertion, when err is nil or does not expose Unwrap() []error.
	assertJoinedErrors := func(err error, expected int) {
		t.Helper()
		if !assert.Error(t, err) {
			return
		}
		joined, ok := err.(interface{ Unwrap() []error })
		if !assert.True(t, ok, "validation error should expose Unwrap() []error") {
			return
		}
		assert.Equal(t, expected, len(joined.Unwrap()))
	}

	// An empty label map is expected to produce two validation errors.
	pipelineSimple.IdLabelMap = map[int]string{}
	assertJoinedErrors(pipelineSimple.Validate(), 2)

	// Zeroing the output dimension as well is still expected to produce two.
	pipelineSimple.OutputDim = 0
	assertJoinedErrors(pipelineSimple.Validate(), 2)
}

// feature extraction

func TestFeatureExtractionPipeline(t *testing.T) {
Expand Down Expand Up @@ -252,6 +313,26 @@ func TestFeatureExtractionPipeline(t *testing.T) {
assert.Greater(t, pipeline.TokenizerTimings.TotalNS, zero, "TokenizerTimings.TotalNS should be greater than 0")
}

// TestFeatureExtractionPipelineValidation verifies that a feature extraction
// pipeline whose output dimension has been set to zero fails validation.
func TestFeatureExtractionPipelineValidation(t *testing.T) {
	session, sessionErr := hugot.NewSession(onnxruntimeSharedLibrary)
	check(sessionErr)
	defer func() {
		check(session.Destroy())
	}()

	// Fall back to the local models directory when the environment does not name one.
	modelsRoot := "./models"
	if fromEnv := os.Getenv("TEST_MODELS_FOLDER"); fromEnv != "" {
		modelsRoot = fromEnv
	}

	extractor, buildErr := session.NewFeatureExtractionPipeline(path.Join(modelsRoot, "all-MiniLM-L6-v2"), "testPipeline")
	check(buildErr)

	extractor.OutputDim = 0
	assert.Error(t, extractor.Validate())
}

// utilities

// Returns an error if any elements of a and b don't match.
Expand Down
15 changes: 12 additions & 3 deletions pipelines/featureExtraction.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,23 @@ func NewFeatureExtractionPipeline(modelPath string, name string) (*FeatureExtrac
// the dimension of the output is taken from the output meta. For the moment we assume that there is only one output
pipeline.OutputDim = int(pipeline.OutputsMeta[0].Dimensions[2])

// output dimension
if pipeline.OutputDim <= 0 {
return nil, errors.New("pipeline configuration invalid: outputDim parameter must be greater than zero")
err = pipeline.Validate()
if err != nil {
return nil, err
}

return pipeline, nil
}

// Validate checks that the pipeline configuration is usable.
// It returns nil when OutputDim is greater than zero; otherwise it returns an
// error built with errors.Join that describes the invalid output dimension.
func (p *FeatureExtractionPipeline) Validate() error {
	if p.OutputDim > 0 {
		return nil
	}
	return errors.Join(errors.New("pipeline configuration invalid: outputDim parameter must be greater than zero"))
}

// Postprocess parses the results of the forward pass into the output. Token embeddings are mean pooled.
func (p *FeatureExtractionPipeline) Postprocess(batch PipelineBatch) (PipelineBatchOutput, error) {

Expand Down
1 change: 1 addition & 0 deletions pipelines/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ type Pipeline interface {
Destroy() error
GetStats() []string
GetOutputDim() int
Validate() error
Run([]string) (PipelineBatchOutput, error)
}

Expand Down
34 changes: 21 additions & 13 deletions pipelines/textClassification.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,25 +93,33 @@ func NewTextClassificationPipeline(modelPath string, name string, opts ...TextCl

// we only support single label classification for now
pipeline.OutputDim = int(pipeline.OutputsMeta[0].Dimensions[1])
if len(pipeline.IdLabelMap) < 1 {
return nil, fmt.Errorf("only single label classification models are currently supported and more than one label is required")
}

// output dimension
if pipeline.OutputDim <= 0 {
return nil, fmt.Errorf("pipeline configuration invalid: outputDim parameter must be greater than zero")
// validate
validationErrors := pipeline.Validate()
if validationErrors != nil {
return nil, validationErrors
}

if len(pipeline.IdLabelMap) <= 0 {
return nil, fmt.Errorf("pipeline configuration invalid: length of id2label map for token classification pipeline must be greater than zero")
}
if len(pipeline.IdLabelMap) != pipeline.OutputDim {
return nil, fmt.Errorf("pipeline configuration invalid: length of id2label map does not match model output dimension")
}
return pipeline, nil
}

// Validate checks the pipeline configuration and reports every violated
// constraint at once. It returns nil when the configuration is valid;
// otherwise it returns an error built with errors.Join holding one entry per
// failed check.
//
// TODO: perhaps this can be unified with the other pipelines
func (p *TextClassificationPipeline) Validate() error {
	var validationErrors []error

	if len(p.IdLabelMap) < 1 {
		validationErrors = append(validationErrors, fmt.Errorf("only single label classification models are currently supported and more than one label is required"))
	}
	if p.OutputDim <= 0 {
		validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: outputDim parameter must be greater than zero"))
	}
	// NOTE(review): this condition is equivalent to the first check, so an empty
	// map is reported twice. It is kept so the number of joined errors stays the
	// same for anything that counts them.
	if len(p.IdLabelMap) <= 0 {
		validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map for text classification pipeline must be greater than zero"))
	}
	if len(p.IdLabelMap) != p.OutputDim {
		validationErrors = append(validationErrors, fmt.Errorf("pipeline configuration invalid: length of id2label map does not match model output dimension"))
	}
	return errors.Join(validationErrors...)
}

func (p *TextClassificationPipeline) Forward(batch PipelineBatch) (PipelineBatch, error) {
start := time.Now()
Expand Down
25 changes: 16 additions & 9 deletions pipelines/tokenClassification.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,19 +125,26 @@ func NewTokenClassificationPipeline(modelPath string, name string, opts ...Token
// the dimension of the output is taken from the output meta.
pipeline.OutputDim = int(pipeline.OutputsMeta[0].Dimensions[2])

// output dimension
if pipeline.OutputDim <= 0 {
return nil, fmt.Errorf("pipeline configuration invalid: outputDim parameter must be greater than zero")
err = pipeline.Validate()
if err != nil {
return nil, err
}
return pipeline, nil
}

// checks
if len(pipeline.IdLabelMap) <= 0 {
return nil, fmt.Errorf("pipeline configuration invalid: length of id2label map for token classification pipeline must be greater than zero")
func (p *TokenClassificationPipeline) Validate() error {
var validationErrors []error

if p.OutputDim <= 0 {
validationErrors = append(validationErrors, fmt.Errorf("p configuration invalid: outputDim parameter must be greater than zero"))
}
if len(pipeline.IdLabelMap) != pipeline.OutputDim {
return nil, fmt.Errorf("pipeline configuration invalid: length of id2label map does not match model output dimension")
if len(p.IdLabelMap) <= 0 {
validationErrors = append(validationErrors, fmt.Errorf("p configuration invalid: length of id2label map for token classification p must be greater than zero"))
}
return pipeline, nil
if len(p.IdLabelMap) != p.OutputDim {
validationErrors = append(validationErrors, fmt.Errorf("p configuration invalid: length of id2label map does not match model output dimension"))
}
return errors.Join(validationErrors...)
}

// Postprocess function for a token classification pipeline
Expand Down

0 comments on commit a0a9770

Please sign in to comment.