Skip to content

Commit 271f48a

Browse files
author
Abdul Dakkak
committed
start work on recordio
1 parent f59de1d commit 271f48a

8 files changed

+346
-50
lines changed

vision/cifar10.go

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,17 +82,10 @@ func (d *CIFAR10) Download(ctx context.Context) error {
8282
}
8383
workingDir := d.workingDir()
8484
downloadedFileName := filepath.Join(workingDir, d.fileName)
85-
downloadedFileName, err := downloadmanager.DownloadFile(d.url, downloadedFileName, downloadmanager.Context(ctx))
85+
downloadedFileName, err := downloadmanager.DownloadFile(d.url, downloadedFileName, downloadmanager.Context(ctx), downloadmanager.MD5Sum(d.md5sum))
8686
if err != nil {
8787
return err
8888
}
89-
ok, err := utils.MD5Sum.CheckFile(downloadedFileName, d.md5sum)
90-
if err != nil {
91-
return errors.Wrapf(err, "unable to perform md5sum on %s", downloadedFileName)
92-
}
93-
if !ok {
94-
return errors.Wrapf(err, "the md5 sum for %s did not match expected %s", downloadedFileName, d.md5sum)
95-
}
9689
if err := downloadmanager.Unarchive(workingDir, downloadedFileName); err != nil {
9790
return err
9891
}

vision/cifar100.go

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,17 +96,10 @@ func (d *CIFAR100) Download(ctx context.Context) error {
9696
}
9797
workingDir := d.workingDir()
9898
downloadedFileName := filepath.Join(workingDir, d.fileName)
99-
downloadedFileName, err := downloadmanager.DownloadFile(d.url, downloadedFileName, downloadmanager.Context(ctx))
99+
downloadedFileName, err := downloadmanager.DownloadFile(d.url, downloadedFileName, downloadmanager.Context(ctx), downloadmanager.MD5Sum(d.md5sum))
100100
if err != nil {
101101
return err
102102
}
103-
ok, err := utils.MD5Sum.CheckFile(downloadedFileName, d.md5sum)
104-
if err != nil {
105-
return errors.Wrapf(err, "unable to perform md5sum on %s", downloadedFileName)
106-
}
107-
if !ok {
108-
return errors.Wrapf(err, "the md5 sum for %s did not match expected %s", downloadedFileName, d.md5sum)
109-
}
110103
if err := downloadmanager.Unarchive(workingDir, downloadedFileName); err != nil {
111104
return err
112105
}

vision/ilsvrc2012_image.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package vision
2+
3+
import "github.com/rai-project/image/types"
4+
5+
// ILSVRC2012ValidationLabeledImage ...
6+
type ILSVRC2012ValidationLabeledImage struct {
7+
label string
8+
data *types.RGBImage
9+
}
10+
11+
// Label ...
12+
func (l ILSVRC2012ValidationLabeledImage) Label() string {
13+
return l.label
14+
}
15+
16+
// Data ...
17+
func (l ILSVRC2012ValidationLabeledImage) Data() (interface{}, error) {
18+
return l.data, nil
19+
}

vision/ilsvrc2012_validation.go

Lines changed: 264 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,268 @@
11
package vision
22

33
import (
4-
_ "github.com/PaddlePaddle/recordio"
4+
"bytes"
5+
"io/ioutil"
6+
"os"
7+
"path"
8+
"path/filepath"
9+
"strings"
10+
11+
"github.com/PaddlePaddle/recordio"
12+
"github.com/Unknwon/com"
13+
"github.com/pkg/errors"
14+
"github.com/rai-project/config"
15+
"github.com/rai-project/dldataset"
16+
"github.com/rai-project/downloadmanager"
17+
"github.com/spf13/cast"
18+
context "golang.org/x/net/context"
19+
"golang.org/x/sync/errgroup"
20+
)
21+
22+
var (
23+
iLSVRC2012ValidationRecordIO *ILSVRC2012ValidationRecordIO
24+
iLSVRC2012Validation224RecordIO *ILSVRC2012ValidationRecordIO
25+
iLSVRC2012Validation227RecordIO *ILSVRC2012ValidationRecordIO
526
)
27+
28+
// ILSVRC2012ValidationFolder ...
29+
type ILSVRC2012ValidationRecordIO struct {
30+
base
31+
imageSize int
32+
baseURL string
33+
listFileName string
34+
indexFileName string
35+
recordFileName string
36+
index *recordio.Index
37+
recordScanner *recordio.RangeScanner
38+
recordFile *os.File
39+
fileOffsetMapping map[string]recordIoOffset
40+
data map[string]ILSVRC2012ValidationLabeledImage
41+
}
42+
43+
// recordIoOffset holds the two integers parsed from the first two fields of a
// list-file line. loadRecord hands them to recordio.NewRangeScanner as its
// last two arguments.
type recordIoOffset struct {
	start, end int
}
47+
48+
func (d *ILSVRC2012ValidationRecordIO) New(ctx context.Context) (dldataset.Dataset, error) {
49+
return nil, nil
50+
}
51+
func (d *ILSVRC2012ValidationRecordIO) Name() string {
52+
if d.imageSize == 0 {
53+
return "ilsvrc2012_validation"
54+
}
55+
return "ilsvrc2012_validation_" + cast.ToString(d.imageSize)
56+
}
57+
58+
func (d *ILSVRC2012ValidationRecordIO) CanonicalName() string {
59+
category := strings.ToLower(d.Category())
60+
name := strings.ToLower(d.Name())
61+
key := path.Join(category, name)
62+
return key
63+
}
64+
65+
func (d *ILSVRC2012ValidationRecordIO) workingDir() string {
66+
category := strings.ToLower(d.Category())
67+
name := strings.ToLower(d.Name())
68+
return filepath.Join(d.baseWorkingDir, category, name)
69+
}
70+
71+
func (d *ILSVRC2012ValidationRecordIO) Download(ctx context.Context) error {
72+
grp, ctx := errgroup.WithContext(ctx)
73+
files := []string{d.listFileName, d.indexFileName, d.recordFileName}
74+
workingDir := d.workingDir()
75+
for ii := range files {
76+
fileName := files[ii]
77+
grp.Go(func() error {
78+
downloadedFileName := filepath.Join(workingDir, fileName)
79+
if com.IsFile(downloadedFileName) {
80+
return nil
81+
}
82+
downloadedFileName, err := downloadmanager.DownloadFile(
83+
urlJoin(d.baseURL, fileName),
84+
downloadedFileName,
85+
downloadmanager.Context(ctx),
86+
)
87+
if err != nil {
88+
return errors.Wrapf(err, "failed to download %v", fileName)
89+
}
90+
return nil
91+
})
92+
}
93+
err := grp.Wait()
94+
if err != nil {
95+
return err
96+
}
97+
_, err = d.populate(ctx)
98+
if err != nil {
99+
return err
100+
}
101+
return nil
102+
}
103+
104+
func keysFileOffset(s map[string]recordIoOffset) []string {
105+
keys := make([]string, len(s))
106+
107+
ii := 0
108+
for k := range s {
109+
keys[ii] = k
110+
ii++
111+
}
112+
return keys
113+
}
114+
115+
func (d *ILSVRC2012ValidationRecordIO) populate(ctx context.Context) ([]string, error) {
116+
117+
workingDir := d.workingDir()
118+
listFileName := filepath.Join(workingDir, d.listFileName)
119+
if !com.IsFile(listFileName) {
120+
return nil, errors.Errorf("unable to find the list file in %v make sure to download the dataset first", listFileName)
121+
}
122+
123+
bts, err := ioutil.ReadFile(listFileName)
124+
if err != nil {
125+
return nil, errors.Wrapf(err, "failed to read %v", listFileName)
126+
}
127+
128+
fileContent := strings.TrimSpace(string(bts))
129+
lines := strings.Split(fileContent, "\n")
130+
files := make([]string, len(lines))
131+
d.fileOffsetMapping = make(map[string]recordIoOffset)
132+
for ii, line := range lines {
133+
fields := strings.Fields(line)
134+
fileName := fields[len(fields)-1]
135+
d.fileOffsetMapping[fileName] = recordIoOffset{
136+
start: cast.ToInt(fields[0]),
137+
end: cast.ToInt(fields[1]),
138+
}
139+
files[ii] = fileName
140+
}
141+
142+
return files, nil
143+
}
144+
145+
func (d *ILSVRC2012ValidationRecordIO) List(ctx context.Context) ([]string, error) {
146+
147+
if len(d.fileOffsetMapping) != 0 {
148+
return d.populate(ctx)
149+
}
150+
151+
return keysFileOffset(d.fileOffsetMapping), nil
152+
}
153+
154+
func (d *ILSVRC2012ValidationRecordIO) loadIndex(ctx context.Context) error {
155+
workingDir := d.workingDir()
156+
indexFileName := filepath.Join(workingDir, d.indexFileName)
157+
if !com.IsFile(indexFileName) {
158+
return errors.Errorf("unable to find the index file in %v make sure to download the dataset first", indexFileName)
159+
}
160+
161+
bts, err := ioutil.ReadFile(indexFileName)
162+
if err != nil {
163+
return errors.Wrapf(err, "failed to read %v", indexFileName)
164+
}
165+
166+
idx, err := recordio.LoadIndex(bytes.NewReader(bts))
167+
if err != nil {
168+
return errors.Wrapf(err, "failed to load index from %v", indexFileName)
169+
}
170+
d.index = idx
171+
return nil
172+
}
173+
174+
func (d *ILSVRC2012ValidationRecordIO) loadRecord(ctx context.Context, offset recordIoOffset) error {
175+
workingDir := d.workingDir()
176+
recordFileName := filepath.Join(workingDir, d.recordFileName)
177+
if !com.IsFile(recordFileName) {
178+
return errors.Errorf("unable to find the record file in %v make sure to download the dataset first", recordFileName)
179+
}
180+
181+
if d.recordFile == nil {
182+
f, err := os.Open(recordFileName)
183+
if err != nil {
184+
return errors.Wrapf(err, "failed to open %v", recordFileName)
185+
}
186+
d.recordFile = f
187+
}
188+
189+
rng := recordio.NewRangeScanner(d.recordFile, d.index, offset.start, offset.end)
190+
if rng == nil {
191+
return errors.Errorf("failed to load record from %v", recordFileName)
192+
}
193+
d.recordScanner = rng
194+
return nil
195+
}
196+
197+
func (d *ILSVRC2012ValidationRecordIO) Get(ctx context.Context, name string) (dldataset.LabeledData, error) {
198+
fileOffsetMapping := d.fileOffsetMapping
199+
offset, ok := fileOffsetMapping[name]
200+
if !ok {
201+
return nil, errors.Errorf("file %v not found", name)
202+
}
203+
if d.index == nil {
204+
if err := d.loadIndex(ctx); err != nil {
205+
return nil, err
206+
}
207+
}
208+
if d.recordScanner == nil {
209+
if err := d.loadRecord(ctx, offset); err != nil {
210+
return nil, err
211+
}
212+
}
213+
d.recordScanner.Record()
214+
return nil, nil
215+
}
216+
217+
func (d *ILSVRC2012ValidationRecordIO) Close() error {
218+
if d.recordFile != nil {
219+
d.recordFile.Close()
220+
}
221+
return nil
222+
}
223+
224+
func init() {
225+
config.AfterInit(func() {
226+
227+
const fileListPath = "/vision/support/ilsvrc2012_validation_file_list.txt"
228+
const baseURL = "http://store.carml.org.s3.amazonaws.com/datasets/imagenet1k-val-"
229+
230+
iLSVRC2012ValidationRecordIO = &ILSVRC2012ValidationRecordIO{
231+
base: base{
232+
ctx: context.Background(),
233+
baseWorkingDir: filepath.Join(dldataset.Config.WorkingDirectory, "dldataset"),
234+
},
235+
baseURL: "http://store.carml.org.s3.amazonaws.com/datasets/ilsvrc2012_validation_recordio",
236+
listFileName: "imagenet1k-val.lst",
237+
indexFileName: "imagenet1k-val.idx",
238+
recordFileName: "imagenet1k-val.rec",
239+
}
240+
241+
iLSVRC2012Validation224RecordIO = &ILSVRC2012ValidationRecordIO{
242+
base: base{
243+
ctx: context.Background(),
244+
baseWorkingDir: filepath.Join(dldataset.Config.WorkingDirectory, "dldataset"),
245+
},
246+
imageSize: 224,
247+
baseURL: "http://store.carml.org.s3.amazonaws.com/datasets/imagenet1k-val-224",
248+
listFileName: "imagenet1k-val-224.lst",
249+
indexFileName: "imagenet1k-val-224.idx",
250+
recordFileName: "imagenet1k-val-224.rec",
251+
}
252+
253+
iLSVRC2012Validation227RecordIO = &ILSVRC2012ValidationRecordIO{
254+
base: base{
255+
ctx: context.Background(),
256+
baseWorkingDir: filepath.Join(dldataset.Config.WorkingDirectory, "dldataset"),
257+
},
258+
imageSize: 227,
259+
baseURL: "http://store.carml.org.s3.amazonaws.com/datasets/imagenet1k-val-227",
260+
listFileName: "imagenet1k-val-227.lst",
261+
indexFileName: "imagenet1k-val-227.idx",
262+
recordFileName: "imagenet1k-val-227.rec",
263+
}
264+
dldataset.Register(iLSVRC2012ValidationRecordIO)
265+
dldataset.Register(iLSVRC2012Validation224RecordIO)
266+
dldataset.Register(iLSVRC2012Validation227RecordIO)
267+
})
268+
}

0 commit comments

Comments
 (0)