Skip to content

Commit c86827c

Browse files
author
Abdul Dakkak
committed
Add helper functions
1 parent 54207de commit c86827c

File tree

3 files changed

+69
-13
lines changed

3 files changed

+69
-13
lines changed

vision/coco.go

+32-13
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ package vision
22

33
import (
44
context "context"
5+
"fmt"
56
"path"
67
"path/filepath"
78
"strings"
89

10+
"github.com/rai-project/dldataset/vision/support/object_detection"
11+
912
"github.com/Unknwon/com"
1013
"github.com/pkg/errors"
1114
"github.com/rai-project/config"
@@ -35,11 +38,13 @@ type CocoLabeledImage struct {
3538
// CocoValidationTFRecord ...
3639
type CocoValidationTFRecord struct {
3740
base
38-
name string
39-
baseURL string
40-
recordFileName string
41-
md5sum string
42-
recordReader *reader.TFRecordReader
41+
name string
42+
baseURL string
43+
recordFileName string
44+
md5sum string
45+
labelMap object_detection.StringIntLabelMap
46+
completeLabelMap object_detection.StringIntLabelMap
47+
recordReader *reader.TFRecordReader
4348
}
4449

4550
var (
@@ -210,27 +215,41 @@ func init() {
210215

211216
const baseURLPrefix = "https://s3.amazonaws.com/store.carml.org/datasets"
212217

218+
labelMap, err := object_detection.Get("mscoco_label_map.pbtxt")
219+
if err != nil {
220+
panic(fmt.Sprintf("failed to get mscoco_label_map.pbtxt due to %v", err))
221+
}
222+
223+
completeLabelMap, err := object_detection.Get("mscoco_complete_label_map.pbtxt")
224+
if err != nil {
225+
panic(fmt.Sprintf("failed to get mscoco_complete_label_map.pbtxt due to %v", err))
226+
}
227+
213228
baseWorkingDir := filepath.Join(dldataset.Config.WorkingDirectory, "dldataset")
214229
coco2014ValidationTFRecord = &CocoValidationTFRecord{
215230
base: base{
216231
ctx: context.Background(),
217232
baseWorkingDir: baseWorkingDir,
218233
},
219-
name: "coco2014",
220-
baseURL: baseURLPrefix + "/coco2014",
221-
recordFileName: "coco_val.record-00000-of-00001",
222-
md5sum: "b1f63512f72d3c84792a1f53ec40062a",
234+
name: "coco2014",
235+
baseURL: baseURLPrefix + "/coco2014",
236+
labelMap: labelMap,
237+
completeLabelMap: completeLabelMap,
238+
recordFileName: "coco_val.record-00000-of-00001",
239+
md5sum: "b1f63512f72d3c84792a1f53ec40062a",
223240
}
224241

225242
coco2017ValidationTFRecord = &CocoValidationTFRecord{
226243
base: base{
227244
ctx: context.Background(),
228245
baseWorkingDir: baseWorkingDir,
229246
},
230-
name: "coco2017",
231-
baseURL: baseURLPrefix + "/coco2017",
232-
recordFileName: "coco_val.record-00000-of-00001",
233-
md5sum: "b8a0cfed5ad569d4572b4ad8645acb5b",
247+
name: "coco2017",
248+
baseURL: baseURLPrefix + "/coco2017",
249+
labelMap: labelMap,
250+
completeLabelMap: completeLabelMap,
251+
recordFileName: "coco_val.record-00000-of-00001",
252+
md5sum: "b8a0cfed5ad569d4572b4ad8645acb5b",
234253
}
235254

236255
dldataset.Register(coco2014ValidationTFRecord)

vision/pascal.go

+10
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package vision
22

33
import (
44
context "context"
5+
"fmt"
56
"path"
67
"path/filepath"
78
"strings"
@@ -12,6 +13,7 @@ import (
1213
"github.com/rai-project/dldataset"
1314
"github.com/rai-project/dldataset/reader"
1415
"github.com/rai-project/dldataset/reader/tfrecord"
16+
"github.com/rai-project/dldataset/vision/support/object_detection"
1517
"github.com/rai-project/dlframework"
1618
"github.com/rai-project/dlframework/framework/feature"
1719
"github.com/rai-project/downloadmanager"
@@ -40,6 +42,7 @@ type PascalValidationTFRecord struct {
4042
baseURL string
4143
recordFileName string
4244
md5sum string
45+
labelMap object_detection.StringIntLabelMap
4346
recordReader *reader.TFRecordReader
4447
}
4548

@@ -216,13 +219,19 @@ func init() {
216219

217220
const baseURLPrefix = "https://s3.amazonaws.com/store.carml.org/datasets"
218221

222+
labelMap, err := object_detection.Get("pascal_label_map.pbtxt")
223+
if err != nil {
224+
panic(fmt.Sprintf("failed to get pascal_label_map.pbtxt due to %v", err))
225+
}
226+
219227
baseWorkingDir := filepath.Join(dldataset.Config.WorkingDirectory, "dldataset")
220228
Pascal2007ValidationTFRecord = &PascalValidationTFRecord{
221229
base: base{
222230
ctx: context.Background(),
223231
baseWorkingDir: baseWorkingDir,
224232
},
225233
name: "Pascal2007",
234+
labelMap: labelMap,
226235
baseURL: baseURLPrefix + "/pascal2007",
227236
recordFileName: "validation.tfrecord",
228237
md5sum: "e646ecf0bf838fa39d34e58d87c3e914",
@@ -234,6 +243,7 @@ func init() {
234243
baseWorkingDir: baseWorkingDir,
235244
},
236245
name: "Pascal2012",
246+
labelMap: labelMap,
237247
baseURL: baseURLPrefix + "/pascal2012",
238248
recordFileName: "validation.tfrecord",
239249
md5sum: "9a59d26492103b8635ba0c916d68535a",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package object_detection
2+
3+
func (s *StringIntLabelMap) Get(idx int) *StringIntLabelMapItem {
4+
items := s.GetItem()
5+
if idx < 0 {
6+
return nil
7+
}
8+
if idx >= len(items) {
9+
return nil
10+
}
11+
return items[idx]
12+
}
13+
14+
func (s *StringIntLabelMap) GetName(idx int) string {
15+
item := s.Get(idx)
16+
return item.GetName()
17+
}
18+
19+
func (s *StringIntLabelMap) GetId(idx int) int32 {
20+
item := s.Get(idx)
21+
return item.GetId()
22+
}
23+
24+
func (s *StringIntLabelMap) GetDisplayName(idx int) string {
25+
item := s.Get(idx)
26+
return item.GetDisplayName()
27+
}

0 commit comments

Comments
 (0)