Add BN-Inception #2

Open
wants to merge 17 commits into master
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,4 +1,7 @@
save/*
logs/*
dump/*

ckpts/*

*.pyc
3 changes: 3 additions & 0 deletions __init__.py
@@ -0,0 +1,3 @@
from .inceptionresnetv2.pytorch_load import inceptionresnetv2
from .inceptionv4.pytorch_load import inceptionv4
from .bninception.pytorch_load import BNInception, InceptionV3
Empty file added bninception/__init__.py
Empty file.
558 changes: 558 additions & 0 deletions bninception/bn_inception.yaml

Large diffs are not rendered by default.

5,607 changes: 5,607 additions & 0 deletions bninception/caffe_pb2.py

Large diffs are not rendered by default.

821 changes: 821 additions & 0 deletions bninception/inceptionv3.yaml

Large diffs are not rendered by default.

84 changes: 84 additions & 0 deletions bninception/layer_factory.py
@@ -0,0 +1,84 @@
import torch
from torch import nn


LAYER_BUILDER_DICT = dict()


def parse_expr(expr):
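    # e.g. parse_expr('conv1<=Convolution<=data') -> (['conv1'], 'Convolution', ['data'])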
parts = expr.split('<=')
return parts[0].split(','), parts[1], parts[2].split(',')


def get_basic_layer(info, channels=None, conv_bias=False):
id = info['id']
    attr = info['attrs'] if 'attrs' in info else dict()

out, op, in_vars = parse_expr(info['expr'])
assert(len(out) == 1)
assert(len(in_vars) == 1)
    mod, out_channel = LAYER_BUILDER_DICT[op](attr, channels, conv_bias)

return id, out[0], mod, out_channel, in_vars[0]


def build_conv(attr, channels=None, conv_bias=False):
out_channels = attr['num_output']
ks = attr['kernel_size'] if 'kernel_size' in attr else (attr['kernel_h'], attr['kernel_w'])
    if 'pad' in attr or ('pad_w' in attr and 'pad_h' in attr):
padding = attr['pad'] if 'pad' in attr else (attr['pad_h'], attr['pad_w'])
else:
padding = 0
    if 'stride' in attr or ('stride_w' in attr and 'stride_h' in attr):
stride = attr['stride'] if 'stride' in attr else (attr['stride_h'], attr['stride_w'])
else:
stride = 1

conv = nn.Conv2d(channels, out_channels, ks, stride, padding, bias=conv_bias)

return conv, out_channels


def build_pooling(attr, channels=None, conv_bias=False):
method = attr['mode']
pad = attr['pad'] if 'pad' in attr else 0
    if method == 'max':
        pool = nn.MaxPool2d(attr['kernel_size'], attr['stride'], pad,
                            ceil_mode=True)  # all Caffe pooling uses ceil mode
    elif method == 'ave':
        pool = nn.AvgPool2d(attr['kernel_size'], attr['stride'], pad,
                            ceil_mode=True)  # all Caffe pooling uses ceil mode
else:
raise ValueError("Unknown pooling method: {}".format(method))

return pool, channels


def build_relu(attr, channels=None, conv_bias=False):
return nn.ReLU(inplace=True), channels


def build_bn(attr, channels=None, conv_bias=False):
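    # Caffe's custom 'BN' layer (scale, shift, running mean/var) maps onto BatchNorm2d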
return nn.BatchNorm2d(channels, momentum=0.1), channels


def build_linear(attr, channels=None, conv_bias=False):
return nn.Linear(channels, attr['num_output']), channels


def build_dropout(attr, channels=None, conv_bias=False):
return nn.Dropout(p=attr['dropout_ratio']), channels


LAYER_BUILDER_DICT['Convolution'] = build_conv

LAYER_BUILDER_DICT['Pooling'] = build_pooling

LAYER_BUILDER_DICT['ReLU'] = build_relu

LAYER_BUILDER_DICT['Dropout'] = build_dropout

LAYER_BUILDER_DICT['BN'] = build_bn

LAYER_BUILDER_DICT['InnerProduct'] = build_linear
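
For orientation, a rough sketch of how these builders are driven by a manifest entry; the 'conv1' id and attribute values here are illustrative, not taken from the shipped yaml files:

# hypothetical layer entry, mirroring the structure emitted by parse_caffe.py
info = {'id': 'conv1',
        'expr': 'conv1<=Convolution<=data',
        'attrs': {'num_output': 64, 'kernel_size': 7, 'stride': 2, 'pad': 3}}
id, out_name, module, out_channels, in_name = get_basic_layer(info, channels=3, conv_bias=True)
# module is nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=True); out_channels == 64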

159 changes: 159 additions & 0 deletions bninception/parse_caffe.py
@@ -0,0 +1,159 @@
#!/usr/bin/env python

import argparse

parser = argparse.ArgumentParser(description="Convert a Caffe model and its learned parameters to torch")
parser.add_argument('model', help='network spec, usually a ProtoBuf text message')
parser.add_argument('weights', help='network parameters, usually a file named like *.caffemodel')
parser.add_argument('--model_yaml', help="translated model spec yaml file")
parser.add_argument('--dump_weights', help="translated model parameters to be used by torch")
parser.add_argument('--model_version', help="the version of Caffe's model spec, usually 2", default=2)

args = parser.parse_args()

from . import caffe_pb2
from google.protobuf import text_format
from pprint import pprint
import yaml
import numpy as np
import torch


class CaffeVendor(object):
def __init__(self, net_name, weight_name, version=2):
print("loading model spec...")
self._net_pb = caffe_pb2.NetParameter()
        with open(net_name) as f:
            text_format.Merge(f.read(), self._net_pb)
self._weight_dict = {}
self._init_dict = []

if weight_name is not None:
print("loading weights...")
self._weight_pb = caffe_pb2.NetParameter()
            with open(weight_name, 'rb') as f:
                self._weight_pb.ParseFromString(f.read())
for l in self._weight_pb.layer:
self._weight_dict[l.name] = l

print("parsing...")
self._parse_net(version)

def _parse_net(self, version):
self._name = str(self._net_pb.name)
self._layers = self._net_pb.layer if version == 2 else self._net_pb.layers
self._parsed_layers = [self._layer2dict(x, version) for x in self._layers]

self._net_dict = {
'name': self._name,
'inputs': [],
'layers': [],
}

self._weight_array_dict = {}

for info, blob, is_data in self._parsed_layers:
if not is_data and info is not None:
self._net_dict['layers'].append(info)

self._weight_array_dict.update(blob)

@staticmethod
def _parse_blob(blob):
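        # reshapes Caffe's flat repeated-float data to its declared shape,
        # e.g. shape.dim == [64, 3, 7, 7] -> ndarray of shape (64, 3, 7, 7)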
flat_data = np.array(blob.data)
shaped_data = flat_data.reshape(list(blob.shape.dim))
return shaped_data

def _layer2dict(self, layer, version):
attr_dict = {}
params = []
weight_params = []
fillers = []

for field, value in layer.ListFields():
if field.name == 'top':
tops = [v.replace('-', '_').replace('/', '_') for v in value]
elif field.name == 'name':
layer_name = str(value).replace('-', '_').replace('/', '_')
elif field.name == 'bottom':
bottoms = [v.replace('-', '_').replace('/', '_') for v in value]
            elif field.name == 'include':
                # assumes the 'type' field was seen first so that op is bound
                if value[0].phase == 1 and op == 'Data':
                    print('found 1 testing data layer')
                    return None, dict(), False
elif field.name == 'type':
if version == 2:
op = value
else:
                    raise NotImplementedError
elif field.name == 'loss_weight':
pass
elif field.name == 'param':
pass
else:
# other params
try:
for f, v in value.ListFields():
if 'filler' in f.name:
pass
elif f.name == 'pool':
attr_dict['mode'] = 'max' if v == 0 else 'ave'
else:
attr_dict[f.name] = v

                except Exception:
print(field.name, value)
raise

expr_temp = '{top}<={op}<={input}'

if layer.name in self._weight_dict:
blobs = [self._parse_blob(x) for x in self._weight_dict[layer.name].blobs]
else:
blobs = []

blob_dict = dict()
if len(blobs) > 0:
blob_dict['{}.weight'.format(layer_name)] = torch.from_numpy(blobs[0])
blob_dict['{}.bias'.format(layer_name)] = torch.from_numpy(blobs[1])
if op == 'BN':
blob_dict['{}.running_mean'.format(layer_name)] = torch.from_numpy(blobs[2])
blob_dict['{}.running_var'.format(layer_name)] = torch.from_numpy(blobs[3])

expr = expr_temp.format(top=','.join(tops), input=','.join(bottoms), op=op)

out_dict = {
'id': layer_name,
'expr': expr,
}

if len(attr_dict) > 0:
out_dict['attrs'] = attr_dict

return out_dict, blob_dict, False

@property
def text_form(self):
return str(self._net_pb)

@property
def info(self):
return {
'name': self._name,
'layers': [x.name for x in self._layers]
}

@property
def yaml(self):
return yaml.dump(self._net_dict)

def dump_weights(self, filename):
        torch.save(self._weight_array_dict, filename)

# build output
cv = CaffeVendor(args.model, args.weights, int(args.model_version))

if args.model_yaml is not None:
    with open(args.model_yaml, 'w') as f:
        f.write(cv.yaml)

if args.dump_weights is not None:
cv.dump_weights(args.dump_weights)
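
For reference, the converter is invoked roughly like this; the prototxt/caffemodel file names below are placeholders:

python parse_caffe.py deploy.prototxt net.caffemodel --model_yaml net.yaml --dump_weights net.pth --model_version 2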
67 changes: 67 additions & 0 deletions bninception/pytorch_load.py
@@ -0,0 +1,67 @@
import torch
from torch import nn
from .layer_factory import get_basic_layer, parse_expr
import torch.utils.model_zoo as model_zoo
import yaml


class BNInception(nn.Module):
def __init__(self, model_path='model_zoo/bninception/bn_inception.yaml', num_classes=101,
weight_url='https://yjxiong.blob.core.windows.net/models/bn_inception-9f5701afb96c8044.pth'):
super(BNInception, self).__init__()

        with open(model_path) as f:
            manifest = yaml.load(f, Loader=yaml.FullLoader)

layers = manifest['layers']

self._channel_dict = dict()

self._op_list = list()
for l in layers:
out_var, op, in_var = parse_expr(l['expr'])
if op != 'Concat':
id, out_name, module, out_channel, in_name = get_basic_layer(l,
3 if len(self._channel_dict) == 0 else self._channel_dict[in_var[0]],
conv_bias=True)

self._channel_dict[out_name] = out_channel
setattr(self, id, module)
self._op_list.append((id, op, out_name, in_name))
            else:
                # Concat has no module of its own; record it under the layer's own id
                self._op_list.append((l['id'], op, out_var[0], in_var))
                channel = sum([self._channel_dict[x] for x in in_var])
                self._channel_dict[out_var[0]] = channel

        self.load_state_dict(model_zoo.load_url(weight_url))

def forward(self, input):
data_dict = dict()
data_dict[self._op_list[0][-1]] = input

def get_hook(name):
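            # debugging helper for the commented-out register_backward_hook call below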

def hook(m, grad_in, grad_out):
print(name, grad_out[0].data.abs().mean())

return hook
for op in self._op_list:
if op[1] != 'Concat' and op[1] != 'InnerProduct':
data_dict[op[2]] = getattr(self, op[0])(data_dict[op[-1]])
# getattr(self, op[0]).register_backward_hook(get_hook(op[0]))
elif op[1] == 'InnerProduct':
x = data_dict[op[-1]]
data_dict[op[2]] = getattr(self, op[0])(x.view(x.size(0), -1))
            else:
                try:
                    data_dict[op[2]] = torch.cat(tuple(data_dict[x] for x in op[-1]), 1)
                except Exception:
                    # report the offending blob sizes before re-raising
                    for x in op[-1]:
                        print(x, data_dict[x].size())
                    raise
return data_dict[self._op_list[-1][2]]


class InceptionV3(BNInception):
def __init__(self, model_path='model_zoo/bninception/inceptionv3.yaml', num_classes=101,
weight_url='https://yjxiong.blob.core.windows.net/models/inceptionv3-cuhk-0e09b300b493bc74c.pth'):
super(InceptionV3, self).__init__(model_path=model_path, weight_url=weight_url, num_classes=num_classes)
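
A minimal smoke test of the loader, assuming bn_inception.yaml is reachable at the default model_path and a PyTorch version where plain tensors can be fed to a module directly (older versions would wrap the input in torch.autograd.Variable); the 224x224 input is BN-Inception's usual crop size:

import torch
from bninception.pytorch_load import BNInception

net = BNInception()  # fetches the converted weights from weight_url on first use
net.eval()
out = net(torch.randn(1, 3, 224, 224))
print(out.size())  # the class count comes from the fc layer recorded in the yaml/weights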