forked from knights-analytics/hugot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhugot_training.go
114 lines (96 loc) · 2.67 KB
/
hugot_training.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
package hugot
import (
"fmt"
"github.com/knights-analytics/hugot/datasets"
"github.com/knights-analytics/hugot/options"
"github.com/knights-analytics/hugot/pipelineBackends"
"github.com/knights-analytics/hugot/pipelines"
)
// TrainingSession bundles what is needed to train a model: the runtime name,
// the initialized pipeline and the configuration supplied by the caller.
type TrainingSession struct {
// runtime is the backend name; newTrainingSession and Train accept only "XLA".
runtime string
// pipeline is the pipeline created by newTrainingSession; Save reads its model.
pipeline pipelineBackends.Pipeline
// config is the caller-supplied training configuration.
config TrainingConfig
}
// TrainingOption is a function that receives a *TrainingSession and may return
// an error.
// NOTE(review): presumably a functional option applied while building a
// session; no use is visible in this part of the file — confirm against callers.
type TrainingOption func(eo *TrainingSession) error
// TrainingConfig holds the caller-supplied settings for a TrainingSession.
type TrainingConfig struct {
// ModelPath is passed to pipelineBackends.LoadModel as the model location.
ModelPath string
// OnnxFilename is passed to pipelineBackends.LoadModel alongside ModelPath.
OnnxFilename string
// Cuda is copied into the XLA options (XLAOptions.Cuda) when the runtime is "XLA".
Cuda bool
// Epochs is the requested number of training epochs.
// NOTE(review): consumed by the runtime-specific training code, which is not
// visible here — confirm how non-positive values are treated.
Epochs int
// XlaTrainingOptions carries XLA-specific training options; it is not read in
// this part of the file.
XlaTrainingOptions *XLATrainingOptions
// Dataset supplies the training data. For a FeatureExtractionPipeline,
// newTrainingSession requires it to be a *datasets.SemanticSimilarityDataset.
Dataset datasets.Dataset
// Verbose, when true, makes newTrainingSession call Dataset.SetVerbose(true).
Verbose bool
}
// newTrainingSession builds a TrainingSession for pipeline type T on the given
// runtime.
//
// It loads the model identified by config.ModelPath and config.OnnxFilename,
// initializes the pipeline on that model, and hands the pipeline to the dataset
// so the dataset can use it for tokenization. Only the "XLA" runtime and
// *pipelines.FeatureExtractionPipeline (paired with a
// *datasets.SemanticSimilarityDataset) are supported; any other combination
// returns an error. A non-positive config.Epochs is normalized to 1.
func newTrainingSession[T pipelineBackends.Pipeline](runtime string, config TrainingConfig) (*TrainingSession, error) {
	// Normalize before config is copied into the session: TrainingConfig is a
	// value type, so a change made after the copy would never reach
	// session.config.
	if config.Epochs <= 0 {
		config.Epochs = 1
	}

	opts := options.Defaults()
	opts.Runtime = runtime
	switch runtime {
	case "XLA":
		opts.XLAOptions.Cuda = config.Cuda
	default:
		return nil, fmt.Errorf("runtime %s is not supported", runtime)
	}

	session := &TrainingSession{
		config:  config,
		runtime: runtime,
	}

	model, err := pipelineBackends.LoadModel(config.ModelPath, config.OnnxFilename, opts)
	if err != nil {
		return nil, err
	}

	// T is only known at instantiation time, so dispatch on the dynamic type of
	// its zero value.
	var trainingPipeline T
	switch pipeline := any(trainingPipeline).(type) {
	case *pipelines.FeatureExtractionPipeline:
		// The second return value of InitializePipeline is not needed here.
		pipeline, _, err = InitializePipeline(pipeline, FeatureExtractionConfig{}, opts, model)
		if err != nil {
			return nil, err
		}
		session.pipeline = pipeline

		// Hook the dataset up with the pipeline for tokenization.
		dataset, ok := session.config.Dataset.(*datasets.SemanticSimilarityDataset)
		if !ok {
			// Report the dataset's real dynamic type; the failed assertion's
			// result is always a nil pointer of the expected type.
			return nil, fmt.Errorf("expected SemanticSimilarityDataset, got %T", session.config.Dataset)
		}
		if err := dataset.SetTokenizationPipeline(pipeline); err != nil {
			return nil, err
		}
	default:
		return nil, fmt.Errorf("training for pipeline type is not supported")
	}

	if session.config.Verbose {
		session.config.Dataset.SetVerbose(true)
	}
	return session, nil
}
// Train runs training for the session on its configured runtime. Only the
// "XLA" runtime is handled, by delegating to TrainXLA; any other runtime
// yields an error.
func (s *TrainingSession) Train() error {
	if s.runtime == "XLA" {
		return TrainXLA(s)
	}
	return fmt.Errorf("training runtime %s is not supported", s.runtime)
}
// Save hands path to the Save method of the XLA model behind the session's
// pipeline. It returns an error when the pipeline has no model, when the
// session runtime is not "XLA", or when the model carries no XLA model.
func (s *TrainingSession) Save(path string) error {
	trained := s.pipeline.GetModel()
	if trained == nil {
		return fmt.Errorf("pipeline model is nil")
	}
	if s.runtime != "XLA" {
		return fmt.Errorf("XLA runtime is required for saving a training model")
	}

	target := trained.XLAModel
	if target == nil {
		return fmt.Errorf("xla model is nil")
	}
	return target.Save(path)
}