Skip to content

Commit 010143a

Browse files
author
abdul dakkak
authored
Merge pull request #1 from rai-project/feature/tfrecord
implements tfrecord reader
2 parents c28e4f8 + b09ac62 commit 010143a

10 files changed

+271
-10
lines changed

reader/.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11

22
image.png
3+
validation.tfrecord
4+
*tar.gz
5+
cifar-10-batches-py

reader/_fixtures/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
test.png
30.8 MB
Binary file not shown.
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Read CIFAR-10 data from pickled numpy arrays and writes TFRecords.
16+
17+
Generates tf.train.Example protos and writes them to TFRecord files from the
18+
python version of the CIFAR-10 dataset downloaded from
19+
https://www.cs.toronto.edu/~kriz/cifar.html.
20+
"""
21+
22+
from __future__ import absolute_import
23+
from __future__ import division
24+
from __future__ import print_function
25+
26+
import argparse
27+
import os
28+
import sys
29+
30+
import tarfile
31+
from six.moves import cPickle as pickle
32+
from six.moves import xrange # pylint: disable=redefined-builtin
33+
import tensorflow as tf
34+
35+
# Name of the CIFAR-10 archive, the URL it is fetched from, and the folder
# name that main() expects to find under data_dir after extraction.
CIFAR_FILENAME = 'cifar-10-python.tar.gz'
CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME
CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py'
38+
39+
40+
def download_and_extract(data_dir):
  """Downloads the CIFAR-10 archive into data_dir and unpacks it there.

  Args:
    data_dir: Directory that receives the archive and its extracted contents.
  """
  # download CIFAR-10 if not already downloaded.
  tf.contrib.learn.datasets.base.maybe_download(CIFAR_FILENAME, data_dir,
                                                CIFAR_DOWNLOAD_URL)
  # Open the archive in a `with` block so the file handle is closed even if
  # extraction fails.
  # NOTE(review): extractall() trusts the member paths stored in the archive;
  # only use this with an archive from a trusted source.
  with tarfile.open(os.path.join(data_dir, CIFAR_FILENAME), 'r:gz') as archive:
    archive.extractall(data_dir)
46+
47+
48+
def _int64_feature(value):
  """Wraps an int, or a list of ints, in a tf.train.Feature int64 list.

  A value that is not already a list is wrapped in a single-element list.
  """
  if not isinstance(value, list):
    value = [value]
  return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
52+
53+
54+
def _bytes_feature(value):
  """Wraps a single bytes value in a tf.train.Feature bytes list."""
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
56+
57+
def _get_file_names():
58+
"""Returns the file names expected to exist in the input_dir."""
59+
file_names = {}
60+
# file_names['train'] = ['data_batch_%d' % i for i in xrange(1, 5)]
61+
file_names['validation'] = ['data_batch_5']
62+
# file_names['eval'] = ['test_batch']
63+
return file_names
64+
65+
66+
def read_pickle_from_file(filename):
  """Loads and returns the pickled object stored in filename.

  On Python 3 the pickle is loaded with encoding='bytes', so string keys of
  the loaded dict are bytes objects (callers index it with b'...' keys).
  """
  # NOTE(review): unpickling executes code embedded in the file; only load
  # batch files that come from a trusted source.
  with tf.gfile.Open(filename, 'rb') as f:
    if sys.version_info >= (3, 0):
      data_dict = pickle.load(f, encoding='bytes')
    else:
      data_dict = pickle.load(f)
  return data_dict
73+
74+
75+
def convert_to_tfrecord(input_files, output_file):
  """Converts pickled CIFAR batch files to a single TFRecord file.

  Args:
    input_files: Paths of the pickled batch files to read, in order.
    output_file: Path of the TFRecord file to write.

  Each image becomes one tf.train.Example carrying its id, raw bytes, format
  tag ('cifar'), label, and the fixed 32x32 dimensions.
  """
  print('Generating %s' % output_file)
  # Ids run across all input files so they stay unique within output_file;
  # restarting at 0 for every batch file would produce duplicate ids.
  next_id = 0
  with tf.python_io.TFRecordWriter(output_file) as record_writer:
    for input_file in input_files:
      data_dict = read_pickle_from_file(input_file)
      data = data_dict[b'data']
      labels = data_dict[b'labels']
      num_entries_in_batch = len(labels)
      for i in range(num_entries_in_batch):
        example = tf.train.Example(features=tf.train.Features(
            feature={
                'image/id': _int64_feature(next_id),
                'image/encoded': _bytes_feature(data[i].tobytes()),
                'image/format': _bytes_feature(b'cifar'),
                'image/label': _int64_feature(labels[i]),
                'image/height': _int64_feature(32),
                'image/width': _int64_feature(32),
            }))
        record_writer.write(example.SerializeToString())
        next_id += 1
95+
96+
97+
def main(data_dir):
  """Downloads CIFAR-10 into data_dir and writes one TFRecord file per split.

  Args:
    data_dir: Directory used for the download, the extracted batches, and the
      generated '<split>.tfrecords' files.
  """
  print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL))
  download_and_extract(data_dir)
  file_names = _get_file_names()
  input_dir = os.path.join(data_dir, CIFAR_LOCAL_FOLDER)
  for mode, files in file_names.items():
    input_files = [os.path.join(input_dir, f) for f in files]
    output_file = os.path.join(data_dir, mode + '.tfrecords')
    # Best-effort removal of a stale output file; it is fine if none exists.
    try:
      os.remove(output_file)
    except OSError:
      pass
    # Convert to tf.train.Example and write them to TFRecords.
    convert_to_tfrecord(input_files, output_file)
  print('Done!')
112+
113+
114+
if __name__ == '__main__':
  # Command-line entry point: parse --data-dir and run the conversion.
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data-dir',
      type=str,
      default='',
      help='Directory to download and extract CIFAR-10 to.')

  args = parser.parse_args()
  main(args.data_dir)

reader/mxnet_recordio.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,6 @@ type RecordIOReader struct {
2222
r io.ReadCloser
2323
}
2424

25-
type Record struct {
26-
ID uint64
27-
LabelIndex float32
28-
Image *types.RGBImage
29-
}
30-
3125
func NewRecordIOReader(path string) (*RecordIOReader, error) {
3226
r, err := os.Open(path)
3327
if err != nil {
@@ -38,7 +32,7 @@ func NewRecordIOReader(path string) (*RecordIOReader, error) {
3832
}, nil
3933
}
4034

41-
func (r *RecordIOReader) Next(ctx context.Context) (*Record, error) {
35+
func (r *RecordIOReader) Next(ctx context.Context) (*ImageRecord, error) {
4236
f := r.r
4337

4438
var magic uint32
@@ -113,7 +107,7 @@ func (r *RecordIOReader) Next(ctx context.Context) (*Record, error) {
113107
return nil, errors.Errorf("expecting an rgb image")
114108
}
115109

116-
return &Record{
110+
return &ImageRecord{
117111
ID: imageId1,
118112
LabelIndex: label,
119113
Image: rgbImage,

reader/record.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package reader
2+
3+
import "github.com/rai-project/image/types"
4+
5+
type ImageRecord struct {
6+
ID uint64
7+
LabelIndex float32
8+
Image *types.RGBImage
9+
}

reader/tfrecord.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package reader
2+
3+
import (
4+
"bytes"
5+
context "context"
6+
goimage "image"
7+
"io"
8+
"os"
9+
"strings"
10+
11+
"github.com/pkg/errors"
12+
"github.com/rai-project/image"
13+
"github.com/rai-project/image/types"
14+
"github.com/ubccr/terf"
15+
)
16+
17+
type TFRecordReader struct {
18+
r io.ReadCloser
19+
*terf.Reader
20+
}
21+
22+
func NewTFRecordReader(path string) (*TFRecordReader, error) {
23+
r, err := os.Open(path)
24+
if err != nil {
25+
return nil, errors.Wrapf(err, "cannot open %v", path)
26+
}
27+
return &TFRecordReader{
28+
r: r,
29+
Reader: terf.NewReader(r),
30+
}, nil
31+
}
32+
33+
func (r *TFRecordReader) Next(ctx context.Context) (*ImageRecord, error) {
34+
nxt, err := r.Reader.Next()
35+
if err != nil {
36+
return nil, err
37+
}
38+
39+
imgRecord := new(terf.Image)
40+
err = imgRecord.UnmarshalExample(nxt)
41+
if err != nil {
42+
return nil, errors.Wrap(err, "unable to unmarshal image")
43+
}
44+
45+
if strings.ToLower(imgRecord.Format) == "cifar" {
46+
img := types.NewRGBImage(goimage.Rect(0, 0, imgRecord.Width, imgRecord.Height))
47+
imgPix := img.Pix
48+
inputPix := imgRecord.Raw
49+
channels := 3
50+
51+
for h := 0; h < imgRecord.Height; h++ {
52+
for w := 0; w < imgRecord.Width; w++ {
53+
for c := 0; c < channels; c++ {
54+
imgPix[channels*(h*imgRecord.Width+w)+c] =
55+
// inputPix format = The first 1024 entries contain the red channel values,
56+
// the next 1024 the green, and the final 1024 the blue.
57+
inputPix[c*imgRecord.Height*imgRecord.Width+h*imgRecord.Width+w]
58+
}
59+
}
60+
}
61+
return &ImageRecord{
62+
ID: uint64(imgRecord.ID),
63+
LabelIndex: float32(imgRecord.LabelID),
64+
Image: img,
65+
}, nil
66+
}
67+
68+
img, err := image.Read(bytes.NewBuffer(imgRecord.Raw), image.Context(nil))
69+
if err != nil {
70+
return nil, err
71+
}
72+
73+
rgbImage, ok := img.(*types.RGBImage)
74+
if !ok {
75+
return nil, errors.Errorf("expecting an rgb image")
76+
}
77+
78+
return &ImageRecord{
79+
ID: uint64(imgRecord.ID),
80+
LabelIndex: float32(imgRecord.LabelID),
81+
Image: rgbImage,
82+
}, nil
83+
}
84+
85+
func (r *TFRecordReader) Close() error {
86+
return r.r.Close()
87+
}

reader/tfrecord_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package reader
2+
3+
import (
4+
context "context"
5+
"image/png"
6+
"os"
7+
"path/filepath"
8+
"testing"
9+
10+
"github.com/GeertJohan/go-sourcepath"
11+
"github.com/stretchr/testify/assert"
12+
)
13+
14+
var (
15+
fixturesPath = filepath.Join(sourcepath.MustAbsoluteDir(), "_fixtures")
16+
)
17+
18+
func TestTFRecord(t *testing.T) {
19+
reader, err := NewTFRecordReader(filepath.Join(fixturesPath, "cifar10_validation.tfrecord"))
20+
assert.NoError(t, err)
21+
assert.NotEmpty(t, reader)
22+
23+
defer reader.Close()
24+
25+
rec, err := reader.Next(context.Background())
26+
assert.NoError(t, err)
27+
assert.NotEmpty(t, rec)
28+
29+
assert.Equal(t, rec.ID, uint64(0))
30+
31+
rec, err = reader.Next(context.Background())
32+
assert.NoError(t, err)
33+
assert.NotEmpty(t, rec)
34+
35+
assert.Equal(t, rec.ID, uint64(1))
36+
assert.NotEmpty(t, rec.Image.Pix)
37+
38+
out, _ := os.Create("_fixtures/test.png")
39+
defer out.Close()
40+
41+
err = png.Encode(out, rec.Image.ToRGBAImage())
42+
assert.NoError(t, err)
43+
}

reader/validation.tfrecords

Whitespace-only changes.

vision/ilsvrc2012_validation.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"strings"
88

99
context "context"
10+
1011
"github.com/Unknwon/com"
1112
"github.com/pkg/errors"
1213
"github.com/rai-project/config"
@@ -37,7 +38,7 @@ type ILSVRC2012ValidationRecordIO struct {
3738
}
3839

3940
type iLSVRC2012ValidationRecordIOLabeledData struct {
40-
*reader.Record
41+
*reader.ImageRecord
4142
}
4243

4344
type recordIoOffset struct {
@@ -189,7 +190,7 @@ func (d *ILSVRC2012ValidationRecordIO) Next(ctx context.Context) (dldataset.Labe
189190
}
190191

191192
return &iLSVRC2012ValidationRecordIOLabeledData{
192-
Record: rec,
193+
ImageRecord: rec,
193194
}, nil
194195
}
195196

0 commit comments

Comments
 (0)