Commit 4f4f1b0 (0 parents): "First commit"

27 files changed: +2194 -0 lines

.gitignore (+4)

logging/
error.log

__pycache__/

README.md (+31)

# StegNet Paper: Mega-Image-Steganography-Capacity-with-Deep-Convolutional-Network
---

## How to create ImageNet Dataset used by StegNet

[Read the LMDB Creator Doc](./lmdb_creator/README.md)


## How to run the StegNet Model

Step 1. Set up environment variables:

```bash
export ILSVRC2012_MDB_PATH="<Your Path to Created 'ILSVRC2012_image_train.mdb' Directory>"
```

Step 2. Run the code:

```bash
python ./main.py
```

The command-line arguments can be tweaked:
```
-h, --help
--train_max_epoch TRAIN_MAX_EPOCH
--batch_size BATCH_SIZE
--restart                        # Restart from scratch
--global_mode {train,inference}
```
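main.py is part of this commit but not shown in this excerpt; a parser consistent with the flags listed above might look roughly like the following sketch (the defaults are illustrative assumptions, not values taken from the commit):

```python
# Hypothetical sketch of a CLI matching the flags above; defaults are
# assumptions for illustration only, not values from main.py.
import argparse

parser = argparse.ArgumentParser(description='StegNet training / inference')
parser.add_argument('--train_max_epoch', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--restart', action='store_true',
                    help='Restart from scratch')
parser.add_argument('--global_mode', choices=['train', 'inference'],
                    default='train')
args = parser.parse_args()
print(args)
```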

dataset_tools/__init__.py (+17)

import pathlib
import io
import math
import multiprocessing as mp

import numpy as np
from PIL import Image

from . import ilsvrc2012

_dispatcher = {
    'ILSVRC2012': ilsvrc2012.DatasetILSVRC2012
}


def get_dataset_by_name(name):
    return _dispatcher[name]
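A minimal usage sketch of this dispatcher (it assumes ILSVRC2012_MDB_PATH is exported, since dataset_tools/ilsvrc2012.py reads that variable at import time):

```python
# Usage sketch: look up a dataset class by name and instantiate it.
# Requires ILSVRC2012_MDB_PATH to be set before the import.
from dataset_tools import get_dataset_by_name

DatasetCls = get_dataset_by_name('ILSVRC2012')
dataset = DatasetCls(train_ratio=0.8, seed=42)
print(dataset.get_name(), dataset.get_shape())  # ILSVRC2012 (64, 64, 3)
```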

dataset_tools/dataset.py (+24)

class Dataset(object):
    def __init__(self, train_ratio=0.8, seed=42):
        raise NotImplementedError

    def get_name(self):
        raise NotImplementedError

    def get_shape(self):
        raise NotImplementedError

    def get_whole_size(self):
        raise NotImplementedError

    def get_train_size(self):
        raise NotImplementedError

    def get_valid_size(self):
        raise NotImplementedError

    def fetch_train_data(self, batch_size):
        raise NotImplementedError

    def fetch_valid_data(self, batch_size):
        raise NotImplementedError

dataset_tools/ilsvrc2012.py (+61)

import io
import os

import numpy as np
from PIL import Image

import lmdb
import msgpack

from .dataset import Dataset

ILSVRC2012_MDB_PATH = os.environ['ILSVRC2012_MDB_PATH']

class DatasetILSVRC2012(Dataset):
    def __init__(self, train_ratio=0.8, seed=42):
        self.seed = seed
        self.mdb_path = ILSVRC2012_MDB_PATH
        self.env = lmdb.open(self.mdb_path, readonly=True)
        self.whole_size = self.env.stat()['entries']
        self.train_size = int(self.whole_size * train_ratio)
        self.valid_size = self.whole_size - self.train_size
        self.inrows, self.incols, self.incnls = 64, 64, 3

    def get_name(self):
        return 'ILSVRC2012'

    def get_shape(self):
        return self.inrows, self.incols, self.incnls

    def get_whole_size(self):
        return self.whole_size

    def get_train_size(self):
        return self.train_size

    def get_valid_size(self):
        return self.valid_size

    def _fetch_data_in_range(self, batch_size, lower_bound, upper_bound):
        # Image is normalized to [-1, 1]
        np.random.seed(self.seed)
        rand_range = np.arange(lower_bound, upper_bound)
        with self.env.begin() as txn:
            while True:
                image_v = np.zeros(shape=(batch_size, self.inrows, self.incols, self.incnls))
                image_idx = np.random.choice(rand_range, size=batch_size)
                for index in range(batch_size):
                    image_rawd = txn.get('{:08d}'.format(image_idx[index]).encode())
                    # raw=False replaces the `encoding='utf-8'` option removed in msgpack 1.0
                    image_info = msgpack.unpackb(image_rawd, raw=False)
                    with Image.open(io.BytesIO(image_info['image'])) as im:
                        # ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
                        im = im.resize((self.inrows, self.incols), Image.LANCZOS)
                        image_data = np.array(im)
                    image_v[index, :, :, :] = image_data
                image_v = image_v / 255. * 2 - 1
                yield image_v

    def fetch_train_data(self, batch_size):
        return self._fetch_data_in_range(batch_size, 0, self.train_size)

    def fetch_valid_data(self, batch_size):
        return self._fetch_data_in_range(batch_size, self.train_size, self.whole_size)
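The fetch methods return infinite generators of normalized batches; a minimal consumption sketch, assuming the LMDB built by lmdb_creator exists at ILSVRC2012_MDB_PATH:

```python
# Consumption sketch: pull one training batch and inspect it.
from dataset_tools.ilsvrc2012 import DatasetILSVRC2012

ds = DatasetILSVRC2012(train_ratio=0.8)
batches = ds.fetch_train_data(batch_size=8)   # infinite generator
batch = next(batches)
print(batch.shape)                 # (8, 64, 64, 3)
print(batch.min(), batch.max())    # values lie within [-1, 1]
```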

generators/__init__.py (+5)

'''
Generators
'''

from .dataset_generator import dataset_generator

generators/dataset_generator.py (+32)

'''
Input Image Data Generators
'''

import contextlib
import queue as Q
import time

import params
import utils


def dataset_generator(queue, dataset, mode, role, batch_size):
    '''
    Generate image data from dataset
    '''
    if mode == 'train':
        dgen = dataset.fetch_train_data
    else:
        dgen = dataset.fetch_valid_data

    # queue-name: one of ['covr/train', 'hide/train', 'covr/valid', 'hide/valid']
    qname = '{}/{}'.format(role, mode)

    for image in dgen(batch_size):
        if params.SHOULD_FINISH.value:
            break
        with contextlib.suppress(Q.Full):
            queue[qname].put(image, timeout=params.QUEUE_TIMEOUT)
    # Setup queue to allow exit without flushing all the data to the pipe
    queue[qname].cancel_join_thread()
    utils.eprint('dataset_generator(%s/%s): exit' % (role, mode))
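params and utils are project modules not shown in this excerpt, and the actual wiring lives in main.py; the sketch below is only an assumption of how such a generator is typically driven from the main process, based on the queue names and function signature seen above.

```python
# Hypothetical wiring sketch -- not the commit's main.py. Assumes
# params.SHOULD_FINISH is a shared flag and params.QUEUE_TIMEOUT a timeout
# in seconds, as implied by the reads in dataset_generator above.
import multiprocessing as mp

from dataset_tools import get_dataset_by_name
from generators import dataset_generator

if __name__ == '__main__':
    dataset = get_dataset_by_name('ILSVRC2012')()
    queues = {name: mp.Queue(maxsize=8)
              for name in ('covr/train', 'hide/train', 'covr/valid', 'hide/valid')}
    workers = [mp.Process(target=dataset_generator,
                          args=(queues, dataset, mode, role, 32))
               for mode in ('train', 'valid')
               for role in ('covr', 'hide')]
    for w in workers:
        w.start()
```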

lmdb_creator/README.md (+19)

# Create LMDB files for the ImageNet 2012 Contest

Step 1. Download and extract the ILSVRC2012 dataset

Step 2. Make sure that you have the `ILSVRC2012_devkit_t12` and `ILSVRC2012_img_train` directories

Step 3. Set up environment variables:

```bash
export IMAGE_DIR="<Your ILSVRC2012_img_train Directory Path>"
export DK_DIR="<Your ILSVRC2012_devkit_t12 Directory Path>"
export MDB_OUT_DIR="<Your Expected Directory for Generating LMDB File>" # Note: reserve at least 60 GB
```

Step 4. Run the Python script:

```bash
python ./images2lmdb.py
```
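After the script finishes, the output can be sanity-checked with a few lines of Python. This is only a sketch; it assumes the MDB_OUT_DIR exported above and relies on the key/packing format used by images2lmdb.py below.

```python
# Sanity-check sketch for the generated LMDB and meta.msgpack.
import os
import lmdb
import msgpack

out_dir = os.environ['MDB_OUT_DIR']
env = lmdb.open(os.path.join(out_dir, 'ILSVRC2012_image_train.mdb'), readonly=True)
print('entries:', env.stat()['entries'])

with open(os.path.join(out_dir, 'meta.msgpack'), 'rb') as f:
    meta = msgpack.unpackb(f.read(), raw=False)
print('synsets:', len(meta), '- first WNID:', meta[0]['WNID'])
```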

lmdb_creator/images2lmdb.py (+171)

'''
Generate ILSVRC2012 Dataset LMDB file
'''
import io
import os
import pathlib
import struct
import sys
import time

import PIL.Image
import lmdb
import msgpack

import scipy.io
import cytoolz as tz
import numpy as np

# Prepare
IMAGE_DIR = os.environ['IMAGE_DIR']
DK_DIR = os.environ['DK_DIR']
MDB_OUT_DIR = os.environ['MDB_OUT_DIR']

seed = 42
np.random.seed(seed)

lmdb_map_size = 50*1024*1024*1024
lmdb_txn_size = 500

# Setup PATHs
META_PATH = os.path.join(DK_DIR, 'data', 'meta.mat')
META_MP_PATH = os.path.join(MDB_OUT_DIR, 'meta.msgpack')
LMDB_PATH = os.path.join(MDB_OUT_DIR, 'ILSVRC2012_image_train.mdb')

# Generate meta.msgpack
meta = scipy.io.loadmat(META_PATH, squeeze_me=True)
synsets = meta['synsets']

meta_info = [{
    'ILSVRC2012_ID': int(s['ILSVRC2012_ID']),
    'WNID': str(s['WNID']),
    'words': str(s['words']),
    'gloss': str(s['gloss']),
    'wordnet_height': int(s['wordnet_height']),
    'num_train_images': int(s['num_train_images'])
} for s in synsets]

meta_info_packed = msgpack.packb(meta_info, use_bin_type=True)

with open(META_MP_PATH, 'wb') as f:
    f.write(meta_info_packed)

# Generate LMDB
def make_context():
    return {
        'image_id': 0,
        'clock_beg': time.time(),
        'clock_end': time.time(),
    }


def process_image_one(txn, image_id, wordnet_id, label, image_abspath):
    '''
    txn: lmdb transaction object
    image_id: int
        The image id, increasing index
    wordnet_id: str
        The wordnet id, i.e. n07711569
    image_abspath: str
        The image's absolute path
    '''
    with PIL.Image.open(image_abspath) as im, io.BytesIO() as bio:
        if im.mode != 'RGB':
            im = im.convert('RGB')
        cols, rows = im.size  # PIL's Image.size is (width, height)
        cnls = 3
        im = im.resize((256, 256))  # resize() returns a new image; keep the result
        im.save(bio, format='webp')
        image_bytes = bio.getvalue()

    # splitext drops only the extension; rstrip('.JPEG') would also strip trailing J/P/E/G characters
    filename = os.path.splitext(os.path.basename(image_abspath))[0]

    info = {
        'wordnet_id': wordnet_id,
        'filename': filename,
        'image': image_bytes,
        'rows': rows,
        'cols': cols,
        'cnls': cnls,
        'label': label,
    }
    key = '{:08d}'.format(image_id).encode()
    txn.put(key, msgpack.packb(info, use_bin_type=True))


def imagenet_walk(wnid_meta_map, image_Dir):
    def get_category_image_abspaths(Path):
        return [str(f.absolute()) for f in Path.iterdir() if f.is_file()]

    def process_category_one(count, category_Path):
        wordnet_id = category_Path.name
        metainfo = wnid_meta_map[wordnet_id]
        words = metainfo['words']
        gloss = metainfo['gloss']
        label = metainfo['ILSVRC2012_ID']

        print('Process count=%d, label=%d, wordnet_id=%s' % (count, label, wordnet_id))
        print(' %s: %s' % (words, gloss))
        for image_abspath in get_category_image_abspaths(category_Path):
            yield {
                'label': label,
                'wordnet_id': wordnet_id,
                'image_abspath': image_abspath
            }

    categories = [d for d in image_Dir.iterdir() if d.is_dir()]

    image_files = [
        image_info
        for count, category_Path in enumerate(categories)
        for image_info in process_category_one(count, category_Path)
    ]
    return image_files


def process_images(ctx, lmdb_env, image_infos, image_total):
    image_id = ctx['image_id']

    with lmdb_env.begin(write=True) as txn:
        for image_info in image_infos:
            wordnet_id = image_info['wordnet_id']
            label = image_info['label']
            image_abspath = image_info['image_abspath']
            process_image_one(txn, image_id, wordnet_id, label, image_abspath)
            image_id = image_id + 1

    clock_beg = ctx['clock_beg']
    clock_end = time.time()

    elapse = clock_end - clock_beg
    elapse_h = int(elapse) // 60 // 60
    elapse_m = int(elapse) // 60 % 60
    elapse_s = int(elapse) % 60

    estmt = (image_total - image_id) / image_id * elapse
    estmt_h = int(estmt) // 60 // 60
    estmt_m = int(estmt) // 60 % 60
    estmt_s = int(estmt) % 60

    labels = [image_info['label'] for image_info in image_infos]
    print('ImageId: {:8d}/{:8d}, time: {:2d}h/{:2d}m/{:2d}s, remain: {:2d}h/{:2d}m/{:2d}s, Sample: {} ...'.format(
        image_id, image_total,
        elapse_h, elapse_m, elapse_s,
        estmt_h, estmt_m, estmt_s,
        str(labels)[:80]))

    ctx['image_id'] = image_id
    ctx['clock_end'] = clock_end


wnid_meta_map = { m['WNID']: m for m in meta_info }

image_train_env = lmdb.open(LMDB_PATH, map_size=lmdb_map_size)

image_infos = imagenet_walk(wnid_meta_map, pathlib.Path(IMAGE_DIR))
image_total = len(image_infos)
np.random.shuffle(image_infos)

ctx = make_context()
for image_infos_partial in tz.partition_all(lmdb_txn_size, image_infos):
    process_images(ctx, image_train_env, image_infos_partial, image_total)
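A record written by process_image_one can be read back and decoded as follows. This is a sketch, not part of the commit; it reuses the '{:08d}' key format and msgpack layout defined above and assumes MDB_OUT_DIR is still exported.

```python
# Read-back sketch: fetch the first stored record and decode its webp payload.
import io
import os
import lmdb
import msgpack
import PIL.Image

mdb_path = os.path.join(os.environ['MDB_OUT_DIR'], 'ILSVRC2012_image_train.mdb')
env = lmdb.open(mdb_path, readonly=True)
with env.begin() as txn:
    info = msgpack.unpackb(txn.get(b'00000000'), raw=False)
with PIL.Image.open(io.BytesIO(info['image'])) as im:
    print(info['wordnet_id'], info['label'], im.format, im.size)
```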
