Skip to content

Commit fc4f9c8

Browse files
author
Abdul Dakkak
committed
updates
1 parent cfc5fc9 commit fc4f9c8

File tree

3 files changed

+34
-44
lines changed

3 files changed

+34
-44
lines changed

vision/cifar10.go

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ package vision
22

33
import (
44
"bufio"
5-
"bytes"
65
"encoding/binary"
6+
"image"
77
"io"
88
"os"
99
"path"
@@ -16,7 +16,7 @@ import (
1616
"github.com/rai-project/config"
1717
"github.com/rai-project/dldataset"
1818
"github.com/rai-project/downloadmanager"
19-
"github.com/rai-project/image"
19+
"github.com/rai-project/image/types"
2020
context "golang.org/x/net/context"
2121
)
2222

@@ -41,15 +41,15 @@ type CIFAR10 struct {
4141

4242
type CIFAR10LabeledImage struct {
4343
label string
44-
data image.RGBImage
44+
data *types.RGBImage
4545
}
4646

4747
func (l CIFAR10LabeledImage) Label() string {
4848
return l.label
4949
}
5050

51-
func (l CIFAR10LabeledImage) Data() (io.Reader, error) {
52-
return bytes.NewBuffer(l.data), nil
51+
func (l CIFAR10LabeledImage) Data() (interface{}, error) {
52+
return l.data, nil
5353
}
5454

5555
func (*CIFAR10) Name() string {
@@ -235,26 +235,17 @@ func (d *CIFAR10) readEntry(ctx context.Context, reader io.Reader) (*CIFAR10Labe
235235

236236
pixelByteSize := int64(d.pixelByteSize)
237237
pixelBytesReader := io.LimitReader(reader, pixelByteSize)
238-
pixelBytes := make([]byte, pixelByteSize)
239-
err = binary.Read(pixelBytesReader, binary.LittleEndian, &pixelBytes)
238+
239+
img := types.NewRGBImage(image.Rect(0, 0, d.imageDimensions[0], d.imageDimensions[1]))
240+
241+
err = binary.Read(pixelBytesReader, binary.LittleEndian, img.Pix)
240242
if err == io.EOF {
241243
return nil, err
242244
}
243245
if err != nil {
244246
return nil, errors.New("unable to read label")
245247
}
246248

247-
img := image.NewRGBImage(image.Rect(0, 0, d.imageDimensions[0], d.imageDimensions[1]))
248-
var idx int
249-
for y := 0; y < d.imageDimensions[0]; y++ {
250-
for x := 0; x < d.imageDimensions[1]; x++ {
251-
for c := 0; c < d.imageDimensions[2]; c++ {
252-
img.Pix[ii] = float32(pixelBytes[idx])
253-
idx++
254-
}
255-
}
256-
}
257-
258249
return &CIFAR10LabeledImage{
259250
label: d.labels[labelIdx],
260251
data: img,

vision/cifar100.go

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ package vision
22

33
import (
44
"bufio"
5-
"bytes"
65
"encoding/binary"
6+
"image"
77
"io"
88
"os"
99
"path"
@@ -16,6 +16,7 @@ import (
1616
"github.com/rai-project/config"
1717
"github.com/rai-project/dldataset"
1818
"github.com/rai-project/downloadmanager"
19+
"github.com/rai-project/image/types"
1920
context "golang.org/x/net/context"
2021
)
2122

@@ -44,7 +45,7 @@ type CIFAR100 struct {
4445
type CIFAR100LabeledImage struct {
4546
coarseLabel string
4647
fineLabel string
47-
data image.RGBImage
48+
data *types.RGBImage
4849
}
4950

5051
func (l CIFAR100LabeledImage) CoarseLabel() string {
@@ -59,8 +60,8 @@ func (l CIFAR100LabeledImage) Label() string {
5960
return l.FineLabel()
6061
}
6162

62-
func (l CIFAR100LabeledImage) Data() (io.Reader, error) {
63-
return bytes.NewBuffer(l.data), nil
63+
func (l CIFAR100LabeledImage) Data() (interface{}, error) {
64+
return l.data, nil
6465
}
6566

6667
func (*CIFAR100) Name() string {
@@ -270,26 +271,17 @@ func (d *CIFAR100) readEntry(ctx context.Context, reader io.Reader) (*CIFAR100La
270271

271272
pixelByteSize := int64(d.pixelByteSize)
272273
pixelBytesReader := io.LimitReader(reader, pixelByteSize)
273-
pixelBytes := make([]byte, pixelByteSize)
274-
err = binary.Read(pixelBytesReader, binary.LittleEndian, &pixelBytes)
274+
275+
img := types.NewRGBImage(image.Rect(0, 0, d.imageDimensions[0], d.imageDimensions[1]))
276+
277+
err = binary.Read(pixelBytesReader, binary.LittleEndian, img.Pix)
275278
if err == io.EOF {
276279
return nil, err
277280
}
278281
if err != nil {
279282
return nil, errors.Wrap(err, "unable to read label")
280283
}
281284

282-
img := image.NewRGBImage(image.Rect(0, 0, d.imageDimensions[0], d.imageDimensions[1]))
283-
var idx int
284-
for y := 0; y < d.imageDimensions[0]; y++ {
285-
for x := 0; x < d.imageDimensions[1]; x++ {
286-
for c := 0; c < d.imageDimensions[2]; c++ {
287-
img.Pix[ii] = float32(pixelBytes[idx])
288-
idx++
289-
}
290-
}
291-
}
292-
293285
return &CIFAR100LabeledImage{
294286
coarseLabel: d.coarseLabels[coarseLabelIdx],
295287
fineLabel: d.fineLabels[fineLabelIdx],

vision/mnist.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
package vision
22

33
import (
4-
"bytes"
5-
"io"
4+
"image"
65
"path"
76
"strconv"
87
"strings"
@@ -12,6 +11,7 @@ import (
1211
"github.com/pkg/errors"
1312
"github.com/rai-project/config"
1413
"github.com/rai-project/dldataset"
14+
"github.com/rai-project/image/types"
1515
mnistLoader "github.com/unixpickle/mnist"
1616
)
1717

@@ -25,15 +25,15 @@ var mnist *MNIST
2525

2626
type MNISTLabeledImage struct {
2727
label string
28-
data []byte
28+
data *types.RGBImage
2929
}
3030

3131
func (l MNISTLabeledImage) Label() string {
3232
return l.label
3333
}
3434

35-
func (l MNISTLabeledImage) Data() (io.Reader, error) {
36-
return bytes.NewBuffer(l.data), nil
35+
func (l MNISTLabeledImage) Data() (interface{}, error) {
36+
return l.data, nil
3737
}
3838

3939
func (*MNIST) Name() string {
@@ -86,17 +86,24 @@ func (d *MNIST) Get(ctx context.Context, name string) (dldataset.LabeledData, er
8686
}
8787

8888
elem := dataset.Samples[idx]
89-
data := make([]byte, len(elem.Intensities))
89+
90+
img := types.NewRGBImage(image.Rect(0, 0, dataset.Width, dataset.Height))
91+
data := img.Pix
92+
9093
for ii, intensity := range elem.Intensities {
9194
if intensity == 1 {
92-
data[ii] = byte(1)
95+
data[3*ii+0] = byte(1)
96+
data[3*ii+1] = byte(1)
97+
data[3*ii+2] = byte(1)
9398
} else {
94-
data[ii] = byte(0)
99+
data[3*ii+0] = byte(0)
100+
data[3*ii+1] = byte(0)
101+
data[3*ii+2] = byte(0)
95102
}
96103
}
97104

98105
return &MNISTLabeledImage{
99-
data: data,
106+
data: img,
100107
label: strconv.Itoa(elem.Label),
101108
}, nil
102109
}

0 commit comments

Comments
 (0)