Skip to content

Commit 3153759

Browse files
author
abdul dakkak
authored
Merge pull request #2 from rai-project/feature/dataset/vision/coco
Feature/dataset/vision/coco
2 parents 010143a + 0e7c2fa commit 3153759

30 files changed

+2916
-1
lines changed

dldataset.go

+4
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,15 @@ import (
44
"io"
55

66
context "context"
7+
8+
"github.com/rai-project/dlframework"
79
)
810

911
// LabeledData ...
1012
type LabeledData interface {
1113
Label() string
14+
Feature() *dlframework.Feature
15+
Features() dlframework.Features
1216
Data() (interface{}, error)
1317
}
1418

reader/mxnet_recordio.go

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ func NewRecordIOReader(path string) (*RecordIOReader, error) {
3232
}, nil
3333
}
3434

35+
// Next ...
3536
func (r *RecordIOReader) Next(ctx context.Context) (*ImageRecord, error) {
3637
f := r.r
3738

reader/record.go

+8
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,16 @@ package reader
22

33
import "github.com/rai-project/image/types"
44

5+
// ImageRecord ...
56
type ImageRecord struct {
67
ID uint64
78
LabelIndex float32
89
Image *types.RGBImage
910
}
11+
12+
// ImageSegmentationRecord ...
13+
type ImageSegmentationRecord struct {
14+
ID uint64
15+
LabelIndex float32
16+
Image *types.RGBImage
17+
}

reader/tfrecord.go

+14
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@ import (
1212
"github.com/rai-project/image"
1313
"github.com/rai-project/image/types"
1414
"github.com/ubccr/terf"
15+
protobuf "github.com/ubccr/terf/protobuf"
1516
)
1617

18+
// TFRecordReader ...
1719
type TFRecordReader struct {
1820
r io.ReadCloser
1921
*terf.Reader
2022
}
2123

24+
// NewTFRecordReader ...
2225
func NewTFRecordReader(path string) (*TFRecordReader, error) {
2326
r, err := os.Open(path)
2427
if err != nil {
@@ -30,6 +33,16 @@ func NewTFRecordReader(path string) (*TFRecordReader, error) {
3033
}, nil
3134
}
3235

36+
// NextRecord ...
37+
func (r *TFRecordReader) NextRecord(ctx context.Context) (*protobuf.Example, error) {
38+
nxt, err := r.Reader.Next()
39+
if err != nil {
40+
return nil, err
41+
}
42+
return nxt, nil
43+
}
44+
45+
// Next ...
3346
func (r *TFRecordReader) Next(ctx context.Context) (*ImageRecord, error) {
3447
nxt, err := r.Reader.Next()
3548
if err != nil {
@@ -82,6 +95,7 @@ func (r *TFRecordReader) Next(ctx context.Context) (*ImageRecord, error) {
8295
}, nil
8396
}
8497

98+
// Close ...
8599
func (r *TFRecordReader) Close() error {
86100
return r.r.Close()
87101
}

reader/tfrecord/features.go

+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package tfrecord
2+
3+
import (
4+
"github.com/ubccr/terf"
5+
protobuf "github.com/ubccr/terf/protobuf"
6+
)
7+
8+
// FeatureBool ...
9+
func FeatureBool(rec *protobuf.Example, key string) bool {
10+
return FeatureInt(rec, key) == 1
11+
}
12+
13+
// FeatureInt64 ...
14+
func FeatureInt64(rec *protobuf.Example, key string) int64 {
15+
f, ok := rec.Features.Feature[key]
16+
if !ok {
17+
return 0
18+
}
19+
20+
val, ok := f.Kind.(*protobuf.Feature_Int64List)
21+
if !ok {
22+
return 0
23+
}
24+
25+
return val.Int64List.Value[0]
26+
}
27+
28+
// FeatureInt ...
29+
func FeatureInt(rec *protobuf.Example, key string) int {
30+
return int(FeatureInt64(rec, key))
31+
}
32+
33+
// FeatureInt32 ...
34+
func FeatureInt32(rec *protobuf.Example, key string) int32 {
35+
return int32(FeatureInt64(rec, key))
36+
}
37+
38+
// FeatureFloat64 ...
39+
func FeatureFloat64(rec *protobuf.Example, key string) float64 {
40+
return float64(FeatureFloat32(rec, key))
41+
}
42+
43+
// FeatureFloat32 ...
44+
func FeatureFloat32(rec *protobuf.Example, key string) float32 {
45+
f, ok := rec.Features.Feature[key]
46+
if !ok {
47+
return 0
48+
}
49+
50+
val, ok := f.Kind.(*protobuf.Feature_FloatList)
51+
if !ok {
52+
return 0
53+
}
54+
55+
return val.FloatList.Value[0]
56+
}
57+
58+
// FeatureBytes ...
59+
func FeatureBytes(rec *protobuf.Example, key string) []byte {
60+
return terf.ExampleFeatureBytes(rec, key)
61+
}
62+
63+
// FeatureString ...
64+
func FeatureString(rec *protobuf.Example, key string) string {
65+
return string(FeatureBytes(rec, key))
66+
}
67+
68+
// FeatureBytesSlice ...
69+
func FeatureBytesSlice(rec *protobuf.Example, key string) [][]byte {
70+
// TODO: return error if key is not found?
71+
f, ok := rec.Features.Feature[key]
72+
if !ok {
73+
return nil
74+
}
75+
76+
val, ok := f.Kind.(*protobuf.Feature_BytesList)
77+
if !ok {
78+
return nil
79+
}
80+
return val.BytesList.Value
81+
}
82+
83+
// FeatureStringSlice ...
84+
func FeatureStringSlice(rec *protobuf.Example, key string) []string {
85+
slice := FeatureBytesSlice(rec, key)
86+
if slice == nil {
87+
return nil
88+
}
89+
90+
res := make([]string, len(slice))
91+
for ii, val := range slice {
92+
res[ii] = string(val)
93+
}
94+
95+
return res
96+
}
97+
98+
// FeatureInt64Slice ...
99+
func FeatureInt64Slice(rec *protobuf.Example, key string) []int64 {
100+
101+
f, ok := rec.Features.Feature[key]
102+
if !ok {
103+
return nil
104+
}
105+
106+
val, ok := f.Kind.(*protobuf.Feature_Int64List)
107+
if !ok {
108+
return nil
109+
}
110+
111+
return val.Int64List.Value
112+
}
113+
114+
// FeatureIntSlice ...
115+
func FeatureIntSlice(rec *protobuf.Example, key string) []int {
116+
slice := FeatureInt64Slice(rec, key)
117+
if slice == nil {
118+
return nil
119+
}
120+
121+
res := make([]int, len(slice))
122+
for ii, val := range slice {
123+
res[ii] = int(val)
124+
}
125+
126+
return res
127+
}
128+
129+
// FeatureInt32Slice ...
130+
func FeatureInt32Slice(rec *protobuf.Example, key string) []int32 {
131+
slice := FeatureInt64Slice(rec, key)
132+
if slice == nil {
133+
return nil
134+
}
135+
136+
res := make([]int32, len(slice))
137+
for ii, val := range slice {
138+
res[ii] = int32(val)
139+
}
140+
141+
return res
142+
}
143+
144+
// FeatureFloat64Slice ...
145+
func FeatureFloat64Slice(rec *protobuf.Example, key string) []float64 {
146+
slice := FeatureFloat32Slice(rec, key)
147+
if slice == nil {
148+
return nil
149+
}
150+
151+
res := make([]float64, len(slice))
152+
for ii, val := range slice {
153+
res[ii] = float64(val)
154+
}
155+
156+
return res
157+
}
158+
159+
// FeatureFloat32Slice ...
160+
func FeatureFloat32Slice(rec *protobuf.Example, key string) []float32 {
161+
f, ok := rec.Features.Feature[key]
162+
if !ok {
163+
return nil
164+
}
165+
166+
val, ok := f.Kind.(*protobuf.Feature_FloatList)
167+
if !ok {
168+
return nil
169+
}
170+
171+
return val.FloatList.Value
172+
}

vision/cifar10.go

+16
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@ import (
1111
"strconv"
1212
"strings"
1313

14+
"github.com/rai-project/dlframework"
15+
"github.com/rai-project/dlframework/framework/feature"
16+
1417
context "context"
18+
1519
"github.com/Unknwon/com"
1620
"github.com/pkg/errors"
1721
"github.com/rai-project/config"
@@ -52,6 +56,18 @@ func (l CIFAR10LabeledImage) Label() string {
5256
return l.label
5357
}
5458

59+
// Feature ...
60+
func (l CIFAR10LabeledImage) Feature() *dlframework.Feature {
61+
return feature.New(
62+
feature.ClassificationLabel(l.label),
63+
)
64+
}
65+
66+
// Features ...
67+
func (l CIFAR10LabeledImage) Features() dlframework.Features {
68+
return dlframework.Features([]*dlframework.Feature{l.Feature()})
69+
}
70+
5571
// Data ...
5672
func (l CIFAR10LabeledImage) Data() (interface{}, error) {
5773
return l.data, nil

vision/cifar100.go

+15
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@ import (
1212
"strings"
1313

1414
context "context"
15+
1516
"github.com/Unknwon/com"
1617
"github.com/pkg/errors"
1718
"github.com/rai-project/config"
1819
"github.com/rai-project/dldataset"
20+
"github.com/rai-project/dlframework"
21+
"github.com/rai-project/dlframework/framework/feature"
1922
"github.com/rai-project/downloadmanager"
2023
"github.com/rai-project/image/types"
2124
"github.com/rai-project/utils"
@@ -66,6 +69,18 @@ func (l CIFAR100LabeledImage) Label() string {
6669
return l.FineLabel()
6770
}
6871

72+
// Feature ...
73+
func (l CIFAR100LabeledImage) Feature() *dlframework.Feature {
74+
return feature.New(
75+
feature.ClassificationLabel(l.fineLabel),
76+
)
77+
}
78+
79+
// Features ...
80+
func (l CIFAR100LabeledImage) Features() dlframework.Features {
81+
return dlframework.Features([]*dlframework.Feature{l.Feature()})
82+
}
83+
6984
// Data ...
7085
func (l CIFAR100LabeledImage) Data() (interface{}, error) {
7186
return l.data, nil

vision/cityscape.go

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
package vision

0 commit comments

Comments
 (0)