Skip to content

Commit 3d44bd5

Browse files
author
Abdul Dakkak
committed
start work on pascal dataset
1 parent 457cffb commit 3d44bd5

File tree

3 files changed

+271
-19
lines changed

3 files changed

+271
-19
lines changed

vision/coco.go

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package vision
22

33
import (
4-
"bytes"
54
context "context"
65
"path"
76
"path/filepath"
@@ -16,7 +15,6 @@ import (
1615
"github.com/rai-project/dlframework"
1716
"github.com/rai-project/dlframework/framework/feature"
1817
"github.com/rai-project/downloadmanager"
19-
"github.com/rai-project/image"
2018
"github.com/rai-project/image/types"
2119
protobuf "github.com/ubccr/terf/protobuf"
2220
)
@@ -157,24 +155,10 @@ func (d *CocoValidationTFRecord) Next(ctx context.Context) (dldataset.LabeledDat
157155
return nil, err
158156
}
159157

160-
return nextCocoFromRecord(rec), nil
158+
return NewCocoLabeledImageFromRecord(rec), nil
161159
}
162160

163-
func getImageRecord(data []byte, format string) (*types.RGBImage, error) {
164-
img, err := image.Read(bytes.NewBuffer(data), image.Context(nil))
165-
if err != nil {
166-
return nil, err
167-
}
168-
169-
rgbImage, ok := img.(*types.RGBImage)
170-
if !ok {
171-
return nil, errors.Errorf("expecting an rgb image")
172-
}
173-
174-
return rgbImage, nil
175-
}
176-
177-
func nextCocoFromRecord(rec *protobuf.Example) *CocoLabeledImage {
161+
func NewCocoLabeledImageFromRecord(rec *protobuf.Example) *CocoLabeledImage {
178162
height := tfrecord.FeatureInt64(rec, "image/height")
179163
width := tfrecord.FeatureInt64(rec, "image/width")
180164
fileName := tfrecord.FeatureString(rec, "image/filename")
@@ -203,6 +187,8 @@ func nextCocoFromRecord(rec *protobuf.Example) *CocoLabeledImage {
203187
feature.BoundingBoxYmin(bboxYmin[ii]),
204188
feature.BoundingBoxYmax(bboxYmax[ii]),
205189
feature.BoundingBoxLabel(class[ii]),
190+
feature.AppendMetadata("isCrowd", isCrowd[ii]),
191+
feature.AppendMetadata("area", area[ii]),
206192
)
207193
}
208194

vision/pascal.go

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
package vision
2+
3+
import (
4+
context "context"
5+
"path"
6+
"path/filepath"
7+
"strings"
8+
9+
"github.com/Unknwon/com"
10+
"github.com/pkg/errors"
11+
"github.com/rai-project/config"
12+
"github.com/rai-project/dldataset"
13+
"github.com/rai-project/dldataset/reader"
14+
"github.com/rai-project/dldataset/reader/tfrecord"
15+
"github.com/rai-project/dlframework"
16+
"github.com/rai-project/dlframework/framework/feature"
17+
"github.com/rai-project/downloadmanager"
18+
"github.com/rai-project/image/types"
19+
protobuf "github.com/ubccr/terf/protobuf"
20+
)
21+
22+
// PascalLabeledImage ...
23+
type PascalLabeledImage struct {
24+
width int64
25+
height int64
26+
fileName string
27+
sourceID string
28+
sha256 string
29+
difficult []int64
30+
truncated []int64
31+
pose []byte
32+
features []*dlframework.Feature
33+
data *types.RGBImage
34+
}
35+
36+
// PascalValidationTFRecord ...
37+
type PascalValidationTFRecord struct {
38+
base
39+
name string
40+
baseURL string
41+
recordFileName string
42+
md5sum string
43+
recordReader *reader.TFRecordReader
44+
}
45+
46+
var (
47+
Pascal2007ValidationTFRecord *PascalValidationTFRecord
48+
Pascal2012ValidationTFRecord *PascalValidationTFRecord
49+
)
50+
51+
func NewPascalLabeledImageFromRecord(rec *protobuf.Example) *PascalLabeledImage {
52+
height := tfrecord.FeatureInt64(rec, "image/height")
53+
width := tfrecord.FeatureInt64(rec, "image/width")
54+
fileName := tfrecord.FeatureString(rec, "image/filename")
55+
sourceID := tfrecord.FeatureString(rec, "image/source_id")
56+
sha256 := tfrecord.FeatureString(rec, "image/key/sha256")
57+
imgFormat := tfrecord.FeatureString(rec, "image/format")
58+
img, err := getImageRecord(tfrecord.FeatureBytes(rec, "image/encoded"), imgFormat)
59+
if err != nil {
60+
panic(err)
61+
}
62+
bboxXmin := tfrecord.FeatureFloat32Slice(rec, "image/object/bbox/xmin")
63+
bboxXmax := tfrecord.FeatureFloat32Slice(rec, "image/object/bbox/xmax")
64+
bboxYmin := tfrecord.FeatureFloat32Slice(rec, "image/object/bbox/ymin")
65+
bboxYmax := tfrecord.FeatureFloat32Slice(rec, "image/object/bbox/ymax")
66+
classText := tfrecord.FeatureStringSlice(rec, "image/object/class/text")
67+
classesLabels := tfrecord.FeatureInt64Slice(rec, "image/object/class/label")
68+
difficult := tfrecord.FeatureInt64Slice(rec, "image/object/difficult")
69+
truncated := tfrecord.FeatureInt64Slice(rec, "image/object/truncated")
70+
pose := tfrecord.FeatureBytes(rec, "image/object/view")
71+
72+
numBBoxes := len(bboxXmax)
73+
features := make([]*dlframework.Feature, numBBoxes)
74+
for ii := 0; ii < numBBoxes; ii++ {
75+
features[ii] = feature.New(
76+
feature.BoundingBoxType(),
77+
feature.BoundingBoxXmin(bboxXmin[ii]),
78+
feature.BoundingBoxXmax(bboxXmax[ii]),
79+
feature.BoundingBoxYmin(bboxYmin[ii]),
80+
feature.BoundingBoxYmax(bboxYmax[ii]),
81+
feature.BoundingBoxIndex(int32(classesLabels[ii])),
82+
feature.BoundingBoxLabel(classText[ii]),
83+
feature.AppendMetadata("difficult", difficult[ii]),
84+
feature.AppendMetadata("truncated", truncated[ii]),
85+
feature.AppendMetadata("pose", pose[ii]),
86+
)
87+
}
88+
89+
return &PascalLabeledImage{
90+
width: width,
91+
height: height,
92+
fileName: fileName,
93+
sourceID: sourceID,
94+
sha256: sha256,
95+
difficult: difficult,
96+
truncated: truncated,
97+
pose: pose,
98+
features: features,
99+
data: img,
100+
}
101+
}
102+
103+
// Label ...
104+
func (l *PascalLabeledImage) Label() string {
105+
return "<undefined>"
106+
}
107+
108+
// Data ...
109+
func (l *PascalLabeledImage) Data() (interface{}, error) {
110+
return l.data, nil
111+
}
112+
113+
// Feature ...
114+
func (d *PascalLabeledImage) Feature() *dlframework.Feature {
115+
return d.features[0]
116+
}
117+
118+
// Features ...
119+
func (d *PascalLabeledImage) Features() dlframework.Features {
120+
return d.features
121+
}
122+
123+
// Name ...
124+
func (d *PascalValidationTFRecord) Name() string {
125+
return d.name
126+
}
127+
128+
// CanonicalName ...
129+
func (d *PascalValidationTFRecord) CanonicalName() string {
130+
category := strings.ToLower(d.Category())
131+
name := strings.ToLower(d.Name())
132+
key := path.Join(category, name)
133+
return key
134+
}
135+
136+
func (d *PascalValidationTFRecord) workingDir() string {
137+
category := strings.ToLower(d.Category())
138+
name := strings.ToLower(d.Name())
139+
return filepath.Join(d.baseWorkingDir, category, name)
140+
}
141+
142+
// Download ...
143+
func (d *PascalValidationTFRecord) Download(ctx context.Context) error {
144+
workingDir := d.workingDir()
145+
fileName := d.recordFileName
146+
downloadedFileName := filepath.Join(workingDir, fileName)
147+
if com.IsFile(downloadedFileName) {
148+
return nil
149+
}
150+
downloadedFileName, err := downloadmanager.DownloadFile(
151+
urlJoin(d.baseURL, fileName),
152+
downloadedFileName,
153+
downloadmanager.Context(ctx),
154+
)
155+
if err != nil {
156+
return errors.Wrapf(err, "failed to download %v", fileName)
157+
}
158+
return nil
159+
}
160+
161+
// New ...
162+
func (d *PascalValidationTFRecord) New(ctx context.Context) (dldataset.Dataset, error) {
163+
return nil, nil
164+
}
165+
166+
// Get ...
167+
func (d *PascalValidationTFRecord) Get(ctx context.Context, name string) (dldataset.LabeledData, error) {
168+
return nil, errors.New("get is not implemented for " + d.CanonicalName())
169+
}
170+
171+
// List ...
172+
func (d *PascalValidationTFRecord) List(ctx context.Context) ([]string, error) {
173+
return nil, errors.New("list is not implemented for " + d.CanonicalName())
174+
}
175+
176+
func (d *PascalValidationTFRecord) loadRecord(ctx context.Context) error {
177+
workingDir := d.workingDir()
178+
recordFileName := filepath.Join(workingDir, d.recordFileName)
179+
if !com.IsFile(recordFileName) {
180+
return errors.Errorf("unable to find the record file in %v make sure to download the dataset first", recordFileName)
181+
}
182+
183+
recordIOReader, err := reader.NewTFRecordReader(recordFileName)
184+
if err != nil {
185+
return errors.Wrapf(err, "failed to load record from %v", recordFileName)
186+
}
187+
d.recordReader = recordIOReader
188+
return nil
189+
}
190+
191+
// Load ...
192+
func (d *PascalValidationTFRecord) Load(ctx context.Context) error {
193+
return d.loadRecord(ctx)
194+
}
195+
196+
// Next ...
197+
func (d *PascalValidationTFRecord) Next(ctx context.Context) (dldataset.LabeledData, error) {
198+
rec, err := d.recordReader.NextRecord(ctx)
199+
if err != nil {
200+
return nil, err
201+
}
202+
203+
return NewPascalLabeledImageFromRecord(rec), nil
204+
}
205+
206+
// Close ...
207+
func (d *PascalValidationTFRecord) Close() error {
208+
if d.recordReader != nil {
209+
d.recordReader.Close()
210+
}
211+
return nil
212+
}
213+
214+
func init() {
215+
config.AfterInit(func() {
216+
217+
const baseURLPrefix = "https://s3.amazonaws.com/store.carml.org/datasets"
218+
219+
baseWorkingDir := filepath.Join(dldataset.Config.WorkingDirectory, "dldataset")
220+
Pascal2007ValidationTFRecord = &PascalValidationTFRecord{
221+
base: base{
222+
ctx: context.Background(),
223+
baseWorkingDir: baseWorkingDir,
224+
},
225+
name: "Pascal2007",
226+
baseURL: baseURLPrefix + "/pascal2007",
227+
recordFileName: "Pascal_val.record-00000-of-00001",
228+
md5sum: "b1f63512f72d3c84792a1f53ec40062a",
229+
}
230+
231+
Pascal2012ValidationTFRecord = &PascalValidationTFRecord{
232+
base: base{
233+
ctx: context.Background(),
234+
baseWorkingDir: baseWorkingDir,
235+
},
236+
name: "Pascal2012",
237+
baseURL: baseURLPrefix + "/pascal2012",
238+
recordFileName: "Pascal_val.record-00000-of-00001",
239+
md5sum: "b8a0cfed5ad569d4572b4ad8645acb5b",
240+
}
241+
242+
dldataset.Register(Pascal2007ValidationTFRecord)
243+
dldataset.Register(Pascal2012ValidationTFRecord)
244+
})
245+
}

vision/utils.go

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,31 @@
11
package vision
22

3-
import "strings"
3+
import (
4+
"bytes"
5+
"strings"
6+
7+
"github.com/pkg/errors"
8+
"github.com/rai-project/image"
9+
"github.com/rai-project/image/types"
10+
)
411

512
func urlJoin(base string, n string) string {
613
if strings.HasPrefix(base, "/") {
714
base = strings.TrimPrefix(base, "/")
815
}
916
return strings.Join([]string{base, n}, "/")
1017
}
18+
19+
func getImageRecord(data []byte, format string) (*types.RGBImage, error) {
20+
img, err := image.Read(bytes.NewBuffer(data), image.Context(nil))
21+
if err != nil {
22+
return nil, err
23+
}
24+
25+
rgbImage, ok := img.(*types.RGBImage)
26+
if !ok {
27+
return nil, errors.Errorf("expecting an rgb image")
28+
}
29+
30+
return rgbImage, nil
31+
}

0 commit comments

Comments
 (0)