|
1 | 1 | package vision
|
2 | 2 |
|
3 | 3 | import (
|
4 |
| - _ "github.com/PaddlePaddle/recordio" |
| 4 | + "bytes" |
| 5 | + "io/ioutil" |
| 6 | + "os" |
| 7 | + "path" |
| 8 | + "path/filepath" |
| 9 | + "strings" |
| 10 | + |
| 11 | + "github.com/PaddlePaddle/recordio" |
| 12 | + "github.com/Unknwon/com" |
| 13 | + "github.com/pkg/errors" |
| 14 | + "github.com/rai-project/config" |
| 15 | + "github.com/rai-project/dldataset" |
| 16 | + "github.com/rai-project/downloadmanager" |
| 17 | + "github.com/spf13/cast" |
| 18 | + context "golang.org/x/net/context" |
| 19 | + "golang.org/x/sync/errgroup" |
| 20 | +) |
| 21 | + |
| 22 | +var ( |
| 23 | + iLSVRC2012ValidationRecordIO *ILSVRC2012ValidationRecordIO |
| 24 | + iLSVRC2012Validation224RecordIO *ILSVRC2012ValidationRecordIO |
| 25 | + iLSVRC2012Validation227RecordIO *ILSVRC2012ValidationRecordIO |
5 | 26 | )
|
| 27 | + |
| 28 | +// ILSVRC2012ValidationFolder ... |
| 29 | +type ILSVRC2012ValidationRecordIO struct { |
| 30 | + base |
| 31 | + imageSize int |
| 32 | + baseURL string |
| 33 | + listFileName string |
| 34 | + indexFileName string |
| 35 | + recordFileName string |
| 36 | + index *recordio.Index |
| 37 | + recordScanner *recordio.RangeScanner |
| 38 | + recordFile *os.File |
| 39 | + fileOffsetMapping map[string]recordIoOffset |
| 40 | + data map[string]ILSVRC2012ValidationLabeledImage |
| 41 | +} |
| 42 | + |
// recordIoOffset is the pair of integers that populate reads from the first
// two fields of a list-file line and that loadRecord hands to the range
// scanner.
type recordIoOffset struct {
	start, end int
}
| 47 | + |
| 48 | +func (d *ILSVRC2012ValidationRecordIO) New(ctx context.Context) (dldataset.Dataset, error) { |
| 49 | + return nil, nil |
| 50 | +} |
| 51 | +func (d *ILSVRC2012ValidationRecordIO) Name() string { |
| 52 | + if d.imageSize == 0 { |
| 53 | + return "ilsvrc2012_validation" |
| 54 | + } |
| 55 | + return "ilsvrc2012_validation_" + cast.ToString(d.imageSize) |
| 56 | +} |
| 57 | + |
| 58 | +func (d *ILSVRC2012ValidationRecordIO) CanonicalName() string { |
| 59 | + category := strings.ToLower(d.Category()) |
| 60 | + name := strings.ToLower(d.Name()) |
| 61 | + key := path.Join(category, name) |
| 62 | + return key |
| 63 | +} |
| 64 | + |
| 65 | +func (d *ILSVRC2012ValidationRecordIO) workingDir() string { |
| 66 | + category := strings.ToLower(d.Category()) |
| 67 | + name := strings.ToLower(d.Name()) |
| 68 | + return filepath.Join(d.baseWorkingDir, category, name) |
| 69 | +} |
| 70 | + |
| 71 | +func (d *ILSVRC2012ValidationRecordIO) Download(ctx context.Context) error { |
| 72 | + grp, ctx := errgroup.WithContext(ctx) |
| 73 | + files := []string{d.listFileName, d.indexFileName, d.recordFileName} |
| 74 | + workingDir := d.workingDir() |
| 75 | + for ii := range files { |
| 76 | + fileName := files[ii] |
| 77 | + grp.Go(func() error { |
| 78 | + downloadedFileName := filepath.Join(workingDir, fileName) |
| 79 | + if com.IsFile(downloadedFileName) { |
| 80 | + return nil |
| 81 | + } |
| 82 | + downloadedFileName, err := downloadmanager.DownloadFile( |
| 83 | + urlJoin(d.baseURL, fileName), |
| 84 | + downloadedFileName, |
| 85 | + downloadmanager.Context(ctx), |
| 86 | + ) |
| 87 | + if err != nil { |
| 88 | + return errors.Wrapf(err, "failed to download %v", fileName) |
| 89 | + } |
| 90 | + return nil |
| 91 | + }) |
| 92 | + } |
| 93 | + err := grp.Wait() |
| 94 | + if err != nil { |
| 95 | + return err |
| 96 | + } |
| 97 | + _, err = d.populate(ctx) |
| 98 | + if err != nil { |
| 99 | + return err |
| 100 | + } |
| 101 | + return nil |
| 102 | +} |
| 103 | + |
| 104 | +func keysFileOffset(s map[string]recordIoOffset) []string { |
| 105 | + keys := make([]string, len(s)) |
| 106 | + |
| 107 | + ii := 0 |
| 108 | + for k := range s { |
| 109 | + keys[ii] = k |
| 110 | + ii++ |
| 111 | + } |
| 112 | + return keys |
| 113 | +} |
| 114 | + |
| 115 | +func (d *ILSVRC2012ValidationRecordIO) populate(ctx context.Context) ([]string, error) { |
| 116 | + |
| 117 | + workingDir := d.workingDir() |
| 118 | + listFileName := filepath.Join(workingDir, d.listFileName) |
| 119 | + if !com.IsFile(listFileName) { |
| 120 | + return nil, errors.Errorf("unable to find the list file in %v make sure to download the dataset first", listFileName) |
| 121 | + } |
| 122 | + |
| 123 | + bts, err := ioutil.ReadFile(listFileName) |
| 124 | + if err != nil { |
| 125 | + return nil, errors.Wrapf(err, "failed to read %v", listFileName) |
| 126 | + } |
| 127 | + |
| 128 | + fileContent := strings.TrimSpace(string(bts)) |
| 129 | + lines := strings.Split(fileContent, "\n") |
| 130 | + files := make([]string, len(lines)) |
| 131 | + d.fileOffsetMapping = make(map[string]recordIoOffset) |
| 132 | + for ii, line := range lines { |
| 133 | + fields := strings.Fields(line) |
| 134 | + fileName := fields[len(fields)-1] |
| 135 | + d.fileOffsetMapping[fileName] = recordIoOffset{ |
| 136 | + start: cast.ToInt(fields[0]), |
| 137 | + end: cast.ToInt(fields[1]), |
| 138 | + } |
| 139 | + files[ii] = fileName |
| 140 | + } |
| 141 | + |
| 142 | + return files, nil |
| 143 | +} |
| 144 | + |
| 145 | +func (d *ILSVRC2012ValidationRecordIO) List(ctx context.Context) ([]string, error) { |
| 146 | + |
| 147 | + if len(d.fileOffsetMapping) != 0 { |
| 148 | + return d.populate(ctx) |
| 149 | + } |
| 150 | + |
| 151 | + return keysFileOffset(d.fileOffsetMapping), nil |
| 152 | +} |
| 153 | + |
| 154 | +func (d *ILSVRC2012ValidationRecordIO) loadIndex(ctx context.Context) error { |
| 155 | + workingDir := d.workingDir() |
| 156 | + indexFileName := filepath.Join(workingDir, d.indexFileName) |
| 157 | + if !com.IsFile(indexFileName) { |
| 158 | + return errors.Errorf("unable to find the index file in %v make sure to download the dataset first", indexFileName) |
| 159 | + } |
| 160 | + |
| 161 | + bts, err := ioutil.ReadFile(indexFileName) |
| 162 | + if err != nil { |
| 163 | + return errors.Wrapf(err, "failed to read %v", indexFileName) |
| 164 | + } |
| 165 | + |
| 166 | + idx, err := recordio.LoadIndex(bytes.NewReader(bts)) |
| 167 | + if err != nil { |
| 168 | + return errors.Wrapf(err, "failed to load index from %v", indexFileName) |
| 169 | + } |
| 170 | + d.index = idx |
| 171 | + return nil |
| 172 | +} |
| 173 | + |
| 174 | +func (d *ILSVRC2012ValidationRecordIO) loadRecord(ctx context.Context, offset recordIoOffset) error { |
| 175 | + workingDir := d.workingDir() |
| 176 | + recordFileName := filepath.Join(workingDir, d.recordFileName) |
| 177 | + if !com.IsFile(recordFileName) { |
| 178 | + return errors.Errorf("unable to find the record file in %v make sure to download the dataset first", recordFileName) |
| 179 | + } |
| 180 | + |
| 181 | + if d.recordFile == nil { |
| 182 | + f, err := os.Open(recordFileName) |
| 183 | + if err != nil { |
| 184 | + return errors.Wrapf(err, "failed to open %v", recordFileName) |
| 185 | + } |
| 186 | + d.recordFile = f |
| 187 | + } |
| 188 | + |
| 189 | + rng := recordio.NewRangeScanner(d.recordFile, d.index, offset.start, offset.end) |
| 190 | + if rng == nil { |
| 191 | + return errors.Errorf("failed to load record from %v", recordFileName) |
| 192 | + } |
| 193 | + d.recordScanner = rng |
| 194 | + return nil |
| 195 | +} |
| 196 | + |
| 197 | +func (d *ILSVRC2012ValidationRecordIO) Get(ctx context.Context, name string) (dldataset.LabeledData, error) { |
| 198 | + fileOffsetMapping := d.fileOffsetMapping |
| 199 | + offset, ok := fileOffsetMapping[name] |
| 200 | + if !ok { |
| 201 | + return nil, errors.Errorf("file %v not found", name) |
| 202 | + } |
| 203 | + if d.index == nil { |
| 204 | + if err := d.loadIndex(ctx); err != nil { |
| 205 | + return nil, err |
| 206 | + } |
| 207 | + } |
| 208 | + if d.recordScanner == nil { |
| 209 | + if err := d.loadRecord(ctx, offset); err != nil { |
| 210 | + return nil, err |
| 211 | + } |
| 212 | + } |
| 213 | + d.recordScanner.Record() |
| 214 | + return nil, nil |
| 215 | +} |
| 216 | + |
| 217 | +func (d *ILSVRC2012ValidationRecordIO) Close() error { |
| 218 | + if d.recordFile != nil { |
| 219 | + d.recordFile.Close() |
| 220 | + } |
| 221 | + return nil |
| 222 | +} |
| 223 | + |
| 224 | +func init() { |
| 225 | + config.AfterInit(func() { |
| 226 | + |
| 227 | + const fileListPath = "/vision/support/ilsvrc2012_validation_file_list.txt" |
| 228 | + const baseURL = "http://store.carml.org.s3.amazonaws.com/datasets/imagenet1k-val-" |
| 229 | + |
| 230 | + iLSVRC2012ValidationRecordIO = &ILSVRC2012ValidationRecordIO{ |
| 231 | + base: base{ |
| 232 | + ctx: context.Background(), |
| 233 | + baseWorkingDir: filepath.Join(dldataset.Config.WorkingDirectory, "dldataset"), |
| 234 | + }, |
| 235 | + baseURL: "http://store.carml.org.s3.amazonaws.com/datasets/ilsvrc2012_validation_recordio", |
| 236 | + listFileName: "imagenet1k-val.lst", |
| 237 | + indexFileName: "imagenet1k-val.idx", |
| 238 | + recordFileName: "imagenet1k-val.rec", |
| 239 | + } |
| 240 | + |
| 241 | + iLSVRC2012Validation224RecordIO = &ILSVRC2012ValidationRecordIO{ |
| 242 | + base: base{ |
| 243 | + ctx: context.Background(), |
| 244 | + baseWorkingDir: filepath.Join(dldataset.Config.WorkingDirectory, "dldataset"), |
| 245 | + }, |
| 246 | + imageSize: 224, |
| 247 | + baseURL: "http://store.carml.org.s3.amazonaws.com/datasets/imagenet1k-val-224", |
| 248 | + listFileName: "imagenet1k-val-224.lst", |
| 249 | + indexFileName: "imagenet1k-val-224.idx", |
| 250 | + recordFileName: "imagenet1k-val-224.rec", |
| 251 | + } |
| 252 | + |
| 253 | + iLSVRC2012Validation227RecordIO = &ILSVRC2012ValidationRecordIO{ |
| 254 | + base: base{ |
| 255 | + ctx: context.Background(), |
| 256 | + baseWorkingDir: filepath.Join(dldataset.Config.WorkingDirectory, "dldataset"), |
| 257 | + }, |
| 258 | + imageSize: 227, |
| 259 | + baseURL: "http://store.carml.org.s3.amazonaws.com/datasets/imagenet1k-val-227", |
| 260 | + listFileName: "imagenet1k-val-227.lst", |
| 261 | + indexFileName: "imagenet1k-val-227.idx", |
| 262 | + recordFileName: "imagenet1k-val-227.rec", |
| 263 | + } |
| 264 | + dldataset.Register(iLSVRC2012ValidationRecordIO) |
| 265 | + dldataset.Register(iLSVRC2012Validation224RecordIO) |
| 266 | + dldataset.Register(iLSVRC2012Validation227RecordIO) |
| 267 | + }) |
| 268 | +} |
0 commit comments