Skip to content

Commit 6e2d859

Browse files
author
Abdul Dakkak
committed
return features
1 parent 3f943cf commit 6e2d859

File tree

8 files changed

+268
-0
lines changed

8 files changed

+268
-0
lines changed

dldataset.go

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
type LabeledData interface {
1313
Label() string
1414
Feature() *dlframework.Feature
15+
Features() dlframework.Features
1516
Data() (interface{}, error)
1617
}
1718

reader/tfrecord/features.go

+157
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package tfrecord
2+
3+
import (
4+
"github.com/ubccr/terf"
5+
protobuf "github.com/ubccr/terf/protobuf"
6+
)
7+
8+
func FeatureBool(rec *protobuf.Example, key string) bool {
9+
return FeatureInt(rec, key) == 1
10+
}
11+
12+
func FeatureInt64(rec *protobuf.Example, key string) int64 {
13+
f, ok := rec.Features.Feature[key]
14+
if !ok {
15+
return 0
16+
}
17+
18+
val, ok := f.Kind.(*protobuf.Feature_Int64List)
19+
if !ok {
20+
return 0
21+
}
22+
23+
return val.Int64List.Value[0]
24+
}
25+
26+
func FeatureInt(rec *protobuf.Example, key string) int {
27+
return int(FeatureInt64(rec, key))
28+
}
29+
30+
func FeatureInt32(rec *protobuf.Example, key string) int32 {
31+
return int32(FeatureInt64(rec, key))
32+
}
33+
34+
func FeatureFloat64(rec *protobuf.Example, key string) float64 {
35+
return float64(FeatureFloat32(rec, key))
36+
}
37+
38+
func FeatureFloat32(rec *protobuf.Example, key string) float32 {
39+
f, ok := rec.Features.Feature[key]
40+
if !ok {
41+
return 0
42+
}
43+
44+
val, ok := f.Kind.(*protobuf.Feature_FloatList)
45+
if !ok {
46+
return 0
47+
}
48+
49+
return val.FloatList.Value[0]
50+
}
51+
52+
func FeatureBytes(rec *protobuf.Example, key string) []byte {
53+
return terf.ExampleFeatureBytes(rec, key)
54+
}
55+
56+
func FeatureString(rec *protobuf.Example, key string) string {
57+
return string(FeatureBytes(rec, key))
58+
}
59+
60+
func FeatureBytesSlice(rec *protobuf.Example, key string) [][]byte {
61+
// TODO: return error if key is not found?
62+
f, ok := rec.Features.Feature[key]
63+
if !ok {
64+
return nil
65+
}
66+
67+
val, ok := f.Kind.(*protobuf.Feature_BytesList)
68+
if !ok {
69+
return nil
70+
}
71+
return val.BytesList.Value
72+
}
73+
74+
func FeatureStringSlice(rec *protobuf.Example, key string) []string {
75+
slice := FeatureBytesSlice(rec, key)
76+
if slice == nil {
77+
return nil
78+
}
79+
80+
res := make([]string, len(slice))
81+
for ii, val := range slice {
82+
res[ii] = string(val)
83+
}
84+
85+
return res
86+
}
87+
88+
func FeatureInt64Slice(rec *protobuf.Example, key string) []int64 {
89+
90+
f, ok := rec.Features.Feature[key]
91+
if !ok {
92+
return nil
93+
}
94+
95+
val, ok := f.Kind.(*protobuf.Feature_Int64List)
96+
if !ok {
97+
return nil
98+
}
99+
100+
return val.Int64List.Value
101+
}
102+
103+
func FeatureIntSlice(rec *protobuf.Example, key string) []int {
104+
slice := FeatureInt64Slice(rec, key)
105+
if slice == nil {
106+
return nil
107+
}
108+
109+
res := make([]int, len(slice))
110+
for ii, val := range slice {
111+
res[ii] = int(val)
112+
}
113+
114+
return res
115+
}
116+
117+
func FeatureInt32Slice(rec *protobuf.Example, key string) []int32 {
118+
slice := FeatureInt64Slice(rec, key)
119+
if slice == nil {
120+
return nil
121+
}
122+
123+
res := make([]int32, len(slice))
124+
for ii, val := range slice {
125+
res[ii] = int32(val)
126+
}
127+
128+
return res
129+
}
130+
131+
func FeatureFloat64Slice(rec *protobuf.Example, key string) []float64 {
132+
slice := FeatureFloat32Slice(rec, key)
133+
if slice == nil {
134+
return nil
135+
}
136+
137+
res := make([]float64, len(slice))
138+
for ii, val := range slice {
139+
res[ii] = float64(val)
140+
}
141+
142+
return res
143+
}
144+
145+
func FeatureFloat32Slice(rec *protobuf.Example, key string) []float32 {
146+
f, ok := rec.Features.Feature[key]
147+
if !ok {
148+
return nil
149+
}
150+
151+
val, ok := f.Kind.(*protobuf.Feature_FloatList)
152+
if !ok {
153+
return nil
154+
}
155+
156+
return val.FloatList.Value
157+
}

vision/cifar10.go

+5
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ func (l CIFAR10LabeledImage) Feature() *dlframework.Feature {
6363
)
6464
}
6565

66+
// Features ...
67+
func (l CIFAR10LabeledImage) Features() dlframework.Features {
68+
return dlframework.Features([]*dlframework.Feature{l.Feature()})
69+
}
70+
6671
// Data ...
6772
func (l CIFAR10LabeledImage) Data() (interface{}, error) {
6873
return l.data, nil

vision/cifar100.go

+5
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ func (l CIFAR100LabeledImage) Feature() *dlframework.Feature {
7676
)
7777
}
7878

79+
// Features ...
80+
func (l CIFAR100LabeledImage) Features() dlframework.Features {
81+
return dlframework.Features([]*dlframework.Feature{l.Feature()})
82+
}
83+
7984
// Data ...
8085
func (l CIFAR100LabeledImage) Data() (interface{}, error) {
8186
return l.data, nil

vision/coco.go

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package vision
2+
3+
import (
4+
"bytes"
5+
6+
"github.com/pkg/errors"
7+
"github.com/rai-project/dldataset/reader/tfrecord"
8+
"github.com/rai-project/dlframework"
9+
"github.com/rai-project/dlframework/framework/feature"
10+
"github.com/rai-project/image"
11+
"github.com/rai-project/image/types"
12+
protobuf "github.com/ubccr/terf/protobuf"
13+
)
14+
15+
// CIFAR100LabeledImage ...
16+
type CocoLabeledImage struct {
17+
width int64
18+
height int64
19+
fileName string
20+
sourceID string
21+
sha256 string
22+
area []float32
23+
isCrowd []int64
24+
features []*dlframework.Feature
25+
data *types.RGBImage
26+
}
27+
28+
func getImageRecord(data []byte, format string) (*types.RGBImage, error) {
29+
img, err := image.Read(bytes.NewBuffer(data), image.Context(nil))
30+
if err != nil {
31+
return nil, err
32+
}
33+
34+
rgbImage, ok := img.(*types.RGBImage)
35+
if !ok {
36+
return nil, errors.Errorf("expecting an rgb image")
37+
}
38+
39+
return rgbImage, nil
40+
}
41+
42+
func nextCocoFromRecord(rec *protobuf.Example) *CocoLabeledImage {
43+
height := tfrecord.FeatureInt64(rec, "image/height")
44+
width := tfrecord.FeatureInt64(rec, "image/width")
45+
fileName := tfrecord.FeatureString(rec, "image/filename")
46+
sourceID := tfrecord.FeatureString(rec, "image/source_id")
47+
sha256 := tfrecord.FeatureString(rec, "image/key/sha256")
48+
imgFormat := tfrecord.FeatureString(rec, "image/format")
49+
img, err := getImageRecord(tfrecord.FeatureBytes(rec, "image/encoded"), imgFormat)
50+
if err != nil {
51+
panic(err)
52+
}
53+
bboxXmin := tfrecord.FeatureFloat32Slice(rec, "image/object/bbox/xmin")
54+
bboxXmax := tfrecord.FeatureFloat32Slice(rec, "image/object/bbox/xmax")
55+
bboxYmin := tfrecord.FeatureFloat32Slice(rec, "image/object/bbox/ymin")
56+
bboxYmax := tfrecord.FeatureFloat32Slice(rec, "image/object/bbox/ymax")
57+
class := tfrecord.FeatureStringSlice(rec, "image/object/class/text")
58+
isCrowd := tfrecord.FeatureInt64Slice(rec, "image/object/is_crowd")
59+
area := tfrecord.FeatureFloat32Slice(rec, "image/object/area")
60+
61+
numBBoxes := len(bboxXmax)
62+
features := make([]*dlframework.Feature, numBBoxes)
63+
for ii := 0; ii < numBBoxes; ii++ {
64+
features[ii] = feature.New(
65+
feature.BoundingBoxType(),
66+
feature.BoundingBoxXmin(bboxXmin[ii]),
67+
feature.BoundingBoxXmax(bboxXmax[ii]),
68+
feature.BoundingBoxYmin(bboxYmin[ii]),
69+
feature.BoundingBoxYmax(bboxYmax[ii]),
70+
feature.BoundingBoxLabel(class[ii]),
71+
)
72+
}
73+
74+
return &CocoLabeledImage{
75+
width: width,
76+
height: height,
77+
fileName: fileName,
78+
sourceID: sourceID,
79+
sha256: sha256,
80+
area: area,
81+
isCrowd: isCrowd,
82+
features: features,
83+
data: img,
84+
}
85+
}

vision/ilsvrc2012_image.go

+5
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,8 @@ func (d *ILSVRC2012ValidationLabeledImage) Feature() *dlframework.Feature {
2828
feature.ClassificationLabel(d.Label()),
2929
)
3030
}
31+
32+
// Features ...
33+
func (l ILSVRC2012ValidationLabeledImage) Features() dlframework.Features {
34+
return dlframework.Features([]*dlframework.Feature{l.Feature()})
35+
}

vision/ilsvrc2012_validation.go

+5
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ func (d *iLSVRC2012ValidationRecordIOLabeledData) Feature() *dlframework.Feature
6060
)
6161
}
6262

63+
// Features ...
64+
func (l iLSVRC2012ValidationRecordIOLabeledData) Features() dlframework.Features {
65+
return dlframework.Features([]*dlframework.Feature{l.Feature()})
66+
}
67+
6368
func (d *iLSVRC2012ValidationRecordIOLabeledData) Data() (interface{}, error) {
6469
return d.Image, nil
6570
}

vision/mnist.go

+5
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ func (l MNISTLabeledImage) Feature() *dlframework.Feature {
4444
)
4545
}
4646

47+
// Features ...
48+
func (l MNISTLabeledImage) Features() dlframework.Features {
49+
return dlframework.Features([]*dlframework.Feature{l.Feature()})
50+
}
51+
4752
// Data ...
4853
func (l MNISTLabeledImage) Data() (interface{}, error) {
4954
return l.data, nil

0 commit comments

Comments
 (0)