-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsubmission.go
78 lines (68 loc) · 1.5 KB
/
submission.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
// Copyright 2018 The go-trackml Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package trackml
import (
"compress/gzip"
"encoding/csv"
"os"
"strconv"
"github.com/pkg/errors"
)
// Submission writes track-assignment results as a gzip-compressed CSV file
// intended for submission to Kaggle.
//
// A Submission is obtained from NewSubmission, receives rows through Append,
// and is finalized with Close, which flushes the CSV encoder, closes the gzip
// stream and closes the file.
type Submission struct {
	f   *os.File     // output file on disk
	gw  *gzip.Writer // gzip compressor writing into f
	csv *csv.Writer  // CSV encoder writing into gw
}
// NewSubmission creates (or truncates) the file "submission.csv.gz" in the
// current working directory, wraps it in a gzip-compressed CSV writer and
// writes the header row "event_id,hit_id,track_id".
//
// On success the caller owns the returned Submission and must call Close on
// it. On failure a nil *Submission is returned and any file that was opened
// has already been closed.
func NewSubmission() (*Submission, error) {
	f, err := os.Create("submission.csv.gz")
	if err != nil {
		return nil, err
	}
	sub := &Submission{f: f}
	sub.gw = gzip.NewWriter(f)
	sub.csv = csv.NewWriter(sub.gw)

	err = sub.csv.Write([]string{"event_id", "hit_id", "track_id"})
	if err == nil {
		// Flush pushes the header through to the gzip writer; failures
		// from that step are only observable through Error afterwards.
		sub.csv.Flush()
		err = sub.csv.Error()
	}
	if err != nil {
		// Release the file descriptor instead of leaking it. The header
		// error is the one worth reporting, so close errors are ignored.
		_ = sub.gw.Close()
		_ = f.Close()
		return nil, err
	}
	return sub, nil
}
// Close finalizes the submission. It flushes any rows still buffered in the
// CSV encoder, closes the gzip stream, and then closes the underlying file.
// Every step is attempted regardless of earlier failures. The first error
// encountered, in that order, is returned.
func (sub *Submission) Close() error {
	sub.csv.Flush()

	// The elements of a composite literal are evaluated left to right, so the
	// steps run in the order csv -> gzip -> file, as required.
	results := [...]error{
		sub.csv.Error(),
		sub.gw.Close(),
		sub.f.Close(),
	}

	for _, e := range results {
		if e != nil {
			return e
		}
	}
	return nil
}
// Append writes one CSV row (event_id, hit_id, track_id) for every hit of evt.
// Each evt.Hits[i] is paired with trkIDs[i]. The two slices must have the same
// length, otherwise nothing is written and an error is returned.
//
// Rows are flushed to the gzip stream before Append returns. Any error from
// writing or flushing is reported to the caller.
func (sub *Submission) Append(evt Event, trkIDs []int) error {
	if len(evt.Hits) != len(trkIDs) {
		return errors.Errorf(
			"length mismatch: event %v has %d hits but %d track IDs were given",
			evt.ID, len(evt.Hits), len(trkIDs),
		)
	}

	var rec [3]string
	// The event ID is identical for every row of this event, so it is
	// formatted once. csv.Writer.Write does not retain or modify the slice,
	// which makes reusing rec across iterations safe.
	rec[0] = strconv.Itoa(evt.ID)
	for i, tid := range trkIDs {
		rec[1] = strconv.Itoa(evt.Hits[i].HitID)
		rec[2] = strconv.Itoa(tid)
		if err := sub.csv.Write(rec[:]); err != nil {
			// Best effort: push out the rows that were accepted so far.
			sub.csv.Flush()
			return errors.Wrapf(err, "could not write row %d of event %v", i, evt.ID)
		}
	}

	// Flush must run before Error is consulted. A deferred Flush would execute
	// after the return value was computed and its error would be lost.
	sub.csv.Flush()
	return sub.csv.Error()
}