diff --git a/flops.py b/flops.py
new file mode 100644
index 0000000..5511600
--- /dev/null
+++ b/flops.py
@@ -0,0 +1,32 @@
+import torch
+# from ptflops import get_model_complexity_info
+# from model import NetworkCIFAR as Network
+from thop import profile
+from model_search import Network
+from model import NetworkCIFAR as Network_CIFAR
+import genotypes
+
+# torch.cuda.set_device(0)
+init_channels = 36
+CIFAR_CLASSES = 10
+layers = 20
+auxiliary = True
+input = torch.randn(1, 3, 32, 32)
+# input = input.cuda()
+genotype = eval("genotypes.%s" % 'SWD_NAS')
+
+# genotype = eval("Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_5x5', 2)], normal_concat=range(2, 6), reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 0), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('avg_pool_3x3', 1)], reduce_concat=range(2, 6))")
+
+criterion = torch.nn.CrossEntropyLoss()
+criterion = criterion.cuda()
+
+with torch.cuda.device(0):
+    model = Network_CIFAR(init_channels, CIFAR_CLASSES, layers, auxiliary, genotype)
+    # net = Network(init_channels, CIFAR_CLASSES, layers, criterion)
+    # model = model.cuda()
+    model.drop_path_prob = 0.2
+    # NOTE: thop reports multiply-accumulates (MACs); this repo refers to them as FLOPs.
+    flops, params = profile(model, inputs=(input,))
+    # macs, params = get_model_complexity_info(net, (3, 32, 32), as_strings=True,
+    #                                          print_per_layer_stat=True, verbose=True)
+print(f"FLOPs: {flops / 1e6} M FLOPs")
+print(f"Number of Parameters: {params}")
\ No newline at end of file
diff --git a/genotypes.py b/genotypes.py
new file mode 100644
index 0000000..2529101
--- /dev/null
+++ b/genotypes.py
@@ -0,0 +1,95 @@
+from collections import namedtuple
+
+Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
+Normal_Genotype = namedtuple('Normal_Genotype', 'normal normal_concat')
+Reduce_Genotype = namedtuple('Reduce_Genotype', 'reduce reduce_concat')
+
+PARAMS = {'conv_3x1_1x3': 864, 'conv_7x1_1x7': 2016, 'sep_conv_7x7': 1464, 'conv_3x3': 1296, 'sep_conv_5x5': 888, 'sep_conv_3x3': 504, 'dil_conv_5x5': 444, 'conv_1x1': 144, 'dil_conv_3x3': 252, 'skip_connect': 0, 'none': 0, 'max_pool_3x3': 0, 'avg_pool_3x3': 0}
+
+PRIMITIVES = [
+    'sep_conv_3x3',
+    'sep_conv_5x5',
+    'dil_conv_3x3',
+    'dil_conv_5x5',
+    'skip_connect',
+    'avg_pool_3x3',
+    'max_pool_3x3',
+    'none'
+]
+
+NASNet = Genotype(
+    normal=[
+        ('sep_conv_5x5', 1),
+        ('sep_conv_3x3', 0),
+        ('sep_conv_5x5', 0),
+        ('sep_conv_3x3', 0),
+        ('avg_pool_3x3', 1),
+        ('skip_connect', 0),
+        ('avg_pool_3x3', 0),
+        ('avg_pool_3x3', 0),
+        ('sep_conv_3x3', 1),
+        ('skip_connect', 1),
+    ],
+    normal_concat=[2, 3, 4, 5, 6],
+    reduce=[
+        ('sep_conv_5x5', 1),
+        ('sep_conv_7x7', 0),
+        ('max_pool_3x3', 1),
+        ('sep_conv_7x7', 0),
+        ('avg_pool_3x3', 1),
+        ('sep_conv_5x5', 0),
+        ('skip_connect', 3),
+        ('avg_pool_3x3', 2),
+        ('sep_conv_3x3', 2),
+        ('max_pool_3x3', 1),
+    ],
+    reduce_concat=[4, 5, 6],
+)
+
+AmoebaNet = Genotype(
+    normal=[
+        ('avg_pool_3x3', 0),
+        ('max_pool_3x3', 1),
+        ('sep_conv_3x3', 0),
+        ('sep_conv_5x5', 2),
+        ('sep_conv_3x3', 0),
+        ('avg_pool_3x3', 3),
+        ('sep_conv_3x3', 1),
+        ('skip_connect', 1),
+        ('skip_connect', 0),
+        ('avg_pool_3x3', 1),
+    ],
+    normal_concat=[4, 5, 6],
+    reduce=[
+        ('avg_pool_3x3', 0),
+        ('sep_conv_3x3', 1),
+        ('max_pool_3x3', 0),
+        ('sep_conv_7x7', 2),
+        ('sep_conv_7x7', 0),
+        ('avg_pool_3x3', 1),
+        ('max_pool_3x3', 0),
+        ('max_pool_3x3', 1),
+        ('conv_7x1_1x7', 0),
+        ('sep_conv_3x3', 5),
+    ],
+    reduce_concat=[3, 4, 6],
+)
+
+DARTS_V1 = Genotype(normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 0), ('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 0), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('avg_pool_3x3', 0)], reduce_concat=[2, 3, 4, 5])
+DARTS_V2 = Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], reduce_concat=[2, 3, 4, 5])
+
+DARTS = DARTS_V1
+
+# An earlier search result; overridden by the Cell_0..Cell_4 definitions below.
+Cell_0 = Normal_Genotype(normal=[('sep_conv_3x3', 0), ('dil_conv_3x3', 1), ('dil_conv_5x5', 0), ('skip_connect', 1), ('sep_conv_3x3', 0), ('sep_conv_5x5', 3), ('dil_conv_5x5', 0), ('sep_conv_3x3', 3)], normal_concat=range(2, 6))
+Cell_1 = Reduce_Genotype(reduce=[('max_pool_3x3', 0), ('sep_conv_5x5', 1), ('skip_connect', 0), ('sep_conv_5x5', 1), ('max_pool_3x3', 0), ('skip_connect', 1), ('sep_conv_5x5', 0), ('sep_conv_3x3', 1)], reduce_concat=range(2, 6))
+Cell_2 = Normal_Genotype(normal=[('skip_connect', 0), ('sep_conv_3x3', 1), ('dil_conv_5x5', 0), ('sep_conv_3x3', 1), ('dil_conv_3x3', 0), ('sep_conv_5x5', 1), ('sep_conv_3x3', 1), ('dil_conv_5x5', 4)], normal_concat=range(2, 6))
+Cell_3 = Reduce_Genotype(reduce=[('skip_connect', 1), ('skip_connect', 0), ('dil_conv_3x3', 1), ('dil_conv_3x3', 0), ('avg_pool_3x3', 0), ('max_pool_3x3', 1), ('avg_pool_3x3', 0), ('max_pool_3x3', 1)], reduce_concat=range(2, 6))
+Cell_4 = Normal_Genotype(normal=[('skip_connect', 0), ('dil_conv_5x5', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_5x5', 0), ('max_pool_3x3', 2), ('sep_conv_5x5', 0), ('sep_conv_3x3', 1)], normal_concat=range(2, 6))
+
+Cell_0 = Normal_Genotype(normal=[('skip_connect', 0), ('sep_conv_5x5', 1), ('max_pool_3x3', 0), ('dil_conv_3x3', 1), ('sep_conv_3x3', 0), ('dil_conv_5x5', 1), ('sep_conv_5x5', 0), ('avg_pool_3x3', 1)], normal_concat=range(2, 6))
+Cell_1 = Reduce_Genotype(reduce=[('dil_conv_3x3', 0), ('dil_conv_3x3', 1), ('max_pool_3x3', 0), ('sep_conv_5x5', 1), ('dil_conv_3x3', 1), ('sep_conv_5x5', 3), ('sep_conv_5x5', 0), ('skip_connect', 1)], reduce_concat=range(2, 6))
+Cell_2 = Normal_Genotype(normal=[('skip_connect', 0), ('sep_conv_3x3', 1), ('avg_pool_3x3', 0), ('sep_conv_3x3', 2), ('sep_conv_5x5', 0), ('sep_conv_3x3', 1), ('avg_pool_3x3', 0), ('skip_connect', 1)], normal_concat=range(2, 6))
+Cell_3 = Reduce_Genotype(reduce=[('avg_pool_3x3', 0), ('sep_conv_3x3', 1), ('skip_connect', 0), ('avg_pool_3x3', 1), ('dil_conv_3x3', 0), ('dil_conv_3x3', 1), ('sep_conv_3x3', 0), ('dil_conv_5x5', 1)], reduce_concat=range(2, 6))
+Cell_4 = Normal_Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_5x5', 1), ('skip_connect', 0), ('skip_connect', 1), ('sep_conv_5x5', 0), ('dil_conv_5x5', 1), ('dil_conv_5x5', 0), ('max_pool_3x3', 1)], normal_concat=range(2, 6))
+
+SWD_NAS = [Cell_0, Cell_1, Cell_2, Cell_3, Cell_4]
\ No newline at end of file
diff --git a/inference_time.py b/inference_time.py
new file mode 100644
index 0000000..756ad42
--- /dev/null
+++ b/inference_time.py
@@ -0,0 +1,27 @@
+import torch
+from model_search import Network
+
+iteration = 50
+model = Network(16, 10, 5, None)
+# torch.cuda.Event requires the model and input to live on the GPU
+model = model.cuda()
+
+input = torch.randn(1, 3, 32, 32)
+input = input.cuda()
+
+starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
+# GPU warm-up
+with torch.no_grad():
+    for _ in range(iteration):
+        _ = model(input)
+
+times = torch.zeros(iteration)
+with torch.no_grad():
+    for i in range(iteration):
+        starter.record()
+        _ = model(input)
+        ender.record()
+        # wait for all queued GPU work to finish before reading the timer
+        torch.cuda.synchronize()
+        times[i] = starter.elapsed_time(ender)
+
+mean_time = times.mean().item()
+print("Inference time: {:.6f} ms, FPS: {:.2f}".format(mean_time, 1000 / mean_time))
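For reference, a minimal CPU-portable variant of the timing loop above, written as a standalone sketch; it assumes the same model_search.Network signature used in this repo and falls back to wall-clock timing around an explicit synchronize, since torch.cuda.Event is GPU-only:

import time
import torch
from model_search import Network

model = Network(16, 10, 5, None)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device).eval()
x = torch.randn(1, 3, 32, 32, device=device)

with torch.no_grad():
    for _ in range(10):  # warm-up
        model(x)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(50):
        model(x)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - t0) * 1000 / 50
print("Inference time: {:.3f} ms/iter".format(elapsed_ms))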
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..fbea2bb
--- /dev/null
+++ b/model.py
@@ -0,0 +1,241 @@
+import torch
+import torch.nn as nn
+from operations import *
+from utils import drop_path
+
+
+class Cell(nn.Module):
+
+    def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
+        super(Cell, self).__init__()
+        print(C_prev_prev, C_prev, C)
+
+        if reduction_prev:
+            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
+        else:
+            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
+        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
+
+        if reduction:
+            op_names, indices = zip(*genotype.reduce)
+            concat = genotype.reduce_concat
+        else:
+            op_names, indices = zip(*genotype.normal)
+            concat = genotype.normal_concat
+        self._compile(C, op_names, indices, concat, reduction)
+
+    def _compile(self, C, op_names, indices, concat, reduction):
+        assert len(op_names) == len(indices)
+        self._steps = len(op_names) // 2
+        self._concat = concat
+        self.multiplier = len(concat)
+
+        self._ops = nn.ModuleList()
+        for name, index in zip(op_names, indices):
+            stride = 2 if reduction and index < 2 else 1
+            op = OPS[name](C, stride, True)
+            self._ops += [op]
+        self._indices = indices
+
+    def forward(self, s0, s1, drop_prob):
+        s0 = self.preprocess0(s0)
+        s1 = self.preprocess1(s1)
+
+        states = [s0, s1]
+        for i in range(self._steps):
+            h1 = states[self._indices[2 * i]]
+            h2 = states[self._indices[2 * i + 1]]
+            op1 = self._ops[2 * i]
+            op2 = self._ops[2 * i + 1]
+            h1 = op1(h1)
+            h2 = op2(h2)
+            if self.training and drop_prob > 0.:
+                if not isinstance(op1, Identity):
+                    h1 = drop_path(h1, drop_prob)
+                if not isinstance(op2, Identity):
+                    h2 = drop_path(h2, drop_prob)
+            s = h1 + h2
+            states += [s]
+        return torch.cat([states[i] for i in self._concat], dim=1)
+
+
+class AuxiliaryHeadCIFAR(nn.Module):
+
+    def __init__(self, C, num_classes):
+        """assuming input size 8x8"""
+        super(AuxiliaryHeadCIFAR, self).__init__()
+        self.features = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False),  # image size = 2 x 2
+            nn.Conv2d(C, 128, 1, bias=False),
+            nn.BatchNorm2d(128),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 768, 2, bias=False),
+            nn.BatchNorm2d(768),
+            nn.ReLU(inplace=True)
+        )
+        self.classifier = nn.Linear(768, num_classes)
+
+    def forward(self, x):
+        x = self.features(x)
+        x = self.classifier(x.view(x.size(0), -1))
+        return x
+
+
+class AuxiliaryHeadImageNet(nn.Module):
+
+    def __init__(self, C, num_classes):
+        """assuming input size 14x14"""
+        super(AuxiliaryHeadImageNet, self).__init__()
+        self.features = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
+            nn.Conv2d(C, 128, 1, bias=False),
+            nn.BatchNorm2d(128),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 768, 2, bias=False),
+            # NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
+            # Commenting it out for consistency with the experiments in the paper.
+            # nn.BatchNorm2d(768),
+            nn.ReLU(inplace=True)
+        )
+        self.classifier = nn.Linear(768, num_classes)
+
+    def forward(self, x):
+        x = self.features(x)
+        x = self.classifier(x.view(x.size(0), -1))
+        return x
+
+
+class NetworkCIFAR(nn.Module):
+
+    def __init__(self, C, num_classes, layers, auxiliary, genotype):
+        super(NetworkCIFAR, self).__init__()
+        self._layers = layers
+        self._auxiliary = auxiliary
+
+        stem_multiplier = 3
+        C_curr = stem_multiplier * C
+        self.stem = nn.Sequential(
+            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
+            nn.BatchNorm2d(C_curr)
+        )
+
+        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
+        self.cells = nn.ModuleList()
+        reduction_prev = False
+        for i in range(layers):
+            if i in [layers // 3, 2 * layers // 3]:
+                C_curr *= 2
+                reduction = True
+            else:
+                reduction = False
+
+            # genotype is a list of five cell genotypes (see genotypes.SWD_NAS):
+            # three normal cells and two reduction cells, assigned by network stage.
+            if i < layers // 3:
+                cell = Cell(genotype[0], C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
+            elif i == layers // 3:
+                cell = Cell(genotype[1], C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
+            elif layers // 3 < i < 2 * layers // 3:
+                cell = Cell(genotype[2], C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
+            elif i == 2 * layers // 3:
+                cell = Cell(genotype[3], C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
+            else:
+                cell = Cell(genotype[4], C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
+
+            reduction_prev = reduction
+            self.cells += [cell]
+            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
+            if i == 2 * layers // 3:
+                C_to_auxiliary = C_prev
+
+        if auxiliary:
+            self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
+        self.global_pooling = nn.AdaptiveAvgPool2d(1)
+        self.classifier = nn.Linear(C_prev, num_classes)
+
+    def forward(self, input):
+        logits_aux = None
+        s0 = s1 = self.stem(input)
+        for i, cell in enumerate(self.cells):
+            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
+            if i == 2 * self._layers // 3:
+                if self._auxiliary and self.training:
+                    logits_aux = self.auxiliary_head(s1)
+        out = self.global_pooling(s1)
+        logits = self.classifier(out.view(out.size(0), -1))
+        return logits, logits_aux
+
+
+class NetworkImageNet(nn.Module):
+
+    def __init__(self, C, num_classes, layers, auxiliary, genotype):
+        super(NetworkImageNet, self).__init__()
+        self._layers = layers
+        self._auxiliary = auxiliary
+
+        self.stem0 = nn.Sequential(
+            nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(C // 2),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(C),
+        )
+
+        self.stem1 = nn.Sequential(
+            nn.ReLU(inplace=True),
+            nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(C),
+        )
+
+        C_prev_prev, C_prev, C_curr = C, C, C
+
+        self.cells = nn.ModuleList()
+        reduction_prev = True
+        for i in range(layers):
+            if i in [layers // 3, 2 * layers // 3]:
+                C_curr *= 2
+                reduction = True
+            else:
+                reduction = False
+            # Same staged genotype assignment as NetworkCIFAR.
+            if i < layers // 3:
+                cell = Cell(genotype[0], C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
+            elif i == layers // 3:
+                cell = Cell(genotype[1], C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
+            elif layers // 3 < i < 2 * layers // 3:
+                cell = Cell(genotype[2], C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
+            elif i == 2 * layers // 3:
+                cell = Cell(genotype[3], C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
+            else:
+                cell = Cell(genotype[4], C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
+
+            reduction_prev = reduction
+            self.cells += [cell]
+            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
+            if i == 2 * layers // 3:
+                C_to_auxiliary = C_prev
+
+        if auxiliary:
+            self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
+        self.global_pooling = nn.AvgPool2d(7)
+        self.classifier = nn.Linear(C_prev, num_classes)
+
+    def forward(self, input):
+        logits_aux = None
+        s0 = self.stem0(input)
+        s1 = self.stem1(s0)
+        for i, cell in enumerate(self.cells):
+            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
+            if i == 2 * self._layers // 3:
+                if self._auxiliary and self.training:
+                    logits_aux = self.auxiliary_head(s1)
+        out = self.global_pooling(s1)
+        logits = self.classifier(out.view(out.size(0), -1))
+        return logits, logits_aux
\ No newline at end of file
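A quick shape check for the evaluation network above; a sketch that assumes the five-element SWD_NAS genotype list from genotypes.py (cells 0/2/4 normal, 1/3 reduction):

import torch
import genotypes
from model import NetworkCIFAR

model = NetworkCIFAR(36, 10, 20, True, genotypes.SWD_NAS)
model.drop_path_prob = 0.0  # forward() reads this attribute
model.eval()
with torch.no_grad():
    logits, logits_aux = model(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 10])
print(logits_aux)    # None: the auxiliary head only runs in training mode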
diff --git a/model_search.py b/model_search.py
new file mode 100644
index 0000000..59b8763
--- /dev/null
+++ b/model_search.py
@@ -0,0 +1,227 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from operations import *
+from genotypes import PRIMITIVES
+
+
+class AttentionModule(nn.Module):
+    def __init__(self, channel, ratio=16):
+        super(AttentionModule, self).__init__()
+        self.channel = channel
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // ratio, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // ratio, channel, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        attention_out = x * y.expand_as(x)
+
+        op_attention = []
+        op_channel = c // 8  # number of channels per candidate operation
+        for i in range(8):
+            temp = y[:, i * op_channel:op_channel * (i + 1), :, :]  # attention weights of the i-th operation
+            op_i_atten = torch.sum(temp)  # summed attention weight of this operation
+            op_attention.append(op_i_atten.item())
+
+        return attention_out, op_attention
+
+
+class ChannelAttention(nn.Module):
+    def __init__(self, channel=16, reduction=2):
+        super().__init__()
+        self.maxpool = nn.AdaptiveMaxPool2d(1)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.se = nn.Sequential(
+            nn.Conv2d(channel, channel // reduction, 1, bias=False),
+            nn.ReLU(),
+            nn.Conv2d(channel // reduction, channel, 1, bias=False)
+        )
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        max_result = self.maxpool(x)
+        avg_result = self.avgpool(x)
+        max_out = self.se(max_result)
+        avg_out = self.se(avg_result)
+        output = max_out + avg_out
+        output = self.sigmoid(output)
+        return output
+
+
+class MixedOp(nn.Module):
+
+    def __init__(self, C, stride):
+        super(MixedOp, self).__init__()
+        self._ops = nn.ModuleList()
+        self.stride = stride
+
+        self.mp = nn.MaxPool2d(2, 2)
+        self.k = 16  # only C // k channels are routed through the candidate operations
+        self.ca = ChannelAttention(C)
+
+        if self.stride == 2:
+            self.auxiliary_op = FactorizedReduce(C, C, affine=False)
+        else:
+            self.auxiliary_op = Identity()
+
+        self.channel = C // self.k
+        for primitive in PRIMITIVES:
+            op = OPS[primitive](C // self.k, stride, False)
+            if 'pool' in primitive:
+                op = nn.Sequential(op, nn.BatchNorm2d(C // self.k, affine=False))
+            self._ops.append(op)
+
+        self.attention = AttentionModule(C * 8 // self.k, ratio=8)
+
+    def forward(self, x):
+        dim_2 = x.shape[1]
+        num_list = self.ca(x)
+        aux_out = self.auxiliary_op(x)
+        x = x * num_list
+        slist = torch.sum(num_list, dim=0, keepdim=True)
+        values, max_num_index = slist.topk(dim_2 // self.k, dim=1, largest=True, sorted=True)
+        # flatten to a 1-D index; squeeze() yields a 0-d tensor when only one channel is selected
+        max_num_index = max_num_index.view(-1)
+        num_dict = max_num_index
+        xtemp = torch.index_select(x, 1, max_num_index)  # the most informative C // k channels
+
+        temp = []
+        for op in self._ops:
+            temp.append(op(xtemp))
+        temp = torch.cat(temp, dim=1)  # concatenate feature maps along the channel dimension
+
+        attention_out, op_attention = self.attention(temp)  # calculate attention weights
+
+        out = 0
+        for i in range(8):  # integrate all feature maps by element-wise addition
+            out += attention_out[:, i * self.channel:self.channel * (i + 1), :, :]
+
+        # write the processed channels back into the (possibly downsampled) feature map
+        if out.shape[2] == x.shape[2]:
+            x[:, num_dict, :, :] = out[:, :, :, :]
+        else:
+            x = self.mp(x)
+            x[:, num_dict, :, :] = out[:, :, :, :]
+
+        x = x + aux_out
+        return x, op_attention
+
+
+class Cell(nn.Module):
+
+    def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
+        super(Cell, self).__init__()
+        self.reduction = reduction
+        self.reduction_prev = reduction_prev
+        if reduction_prev:
+            self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=False)
+        else:
+            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=False)
+        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=False)
+        self._steps = steps
+        self._multiplier = multiplier
+
+        self._ops = nn.ModuleList()
+        self._bns = nn.ModuleList()
+        for i in range(self._steps):
+            for j in range(2 + i):
+                stride = 2 if reduction and j < 2 else 1
+                op = MixedOp(C, stride)
+                self._ops.append(op)
+
+    def forward(self, s0, s1):
+        s0 = self.preprocess0(s0)
+        s1 = self.preprocess1(s1)
+
+        states = [s0, s1]
+        offset = 0
+        op_Attention = []
+        for i in range(self._steps):
+            s = 0
+            for j, h in enumerate(states):
+                temp, op_attention = self._ops[offset + j](h)
+                s += temp
+                op_Attention.append(op_attention)  # builds the 14x8 attention weight matrix, one row per edge
+            offset += len(states)
+            states.append(s)
+
+        if not self.reduction and not self.reduction_prev:
+            states.append(s1)
+        return torch.cat(states[-self._multiplier:], dim=1), op_Attention  # self._multiplier = 4
+
+
+class Network(nn.Module):
+
+    def __init__(self, C, num_classes, layers, criterion, steps=4, multiplier=4, stem_multiplier=3):
+        super(Network, self).__init__()
+        self._C = C
+        self._num_classes = num_classes
+        self._layers = layers
+        self._criterion = criterion
+        self._steps = steps
+        self._multiplier = multiplier + 2
+
+        C_curr = stem_multiplier * C  # 48
+        self.stem = nn.Sequential(
+            nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
+            nn.BatchNorm2d(C_curr)
+        )
+
+        C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
+        self.cells = nn.ModuleList()
+        reduction_prev = False
+        for i in range(layers):
+            if i in [layers // 3, 2 * layers // 3]:
+                C_curr *= 2
+                reduction = True
+            else:
+                reduction = False
+            cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
+            reduction_prev = reduction
+            self.cells += [cell]
+            C_prev_prev, C_prev = C_prev, multiplier * C_curr  # 16, 16*4
+
+        self.global_pooling = nn.AdaptiveAvgPool2d(1)
+        self.classifier = nn.Linear(C_prev, num_classes)
+
+    def new(self):
+        # Legacy DARTS helper: this attention-based network defines no arch_parameters(),
+        # so new() is unused here.
+        model_new = Network(self._C, self._num_classes, self._layers, self._criterion).cuda()
+        for x, y in zip(model_new.arch_parameters(), self.arch_parameters()):
+            x.data.copy_(y.data)
+        return model_new
+
+    def forward(self, input):
+        s0 = s1 = self.stem(input)
+        op_Attention_normal_all = []
+        op_Attention_reduce_all = []
+        for i, cell in enumerate(self.cells):
+            if cell.reduction:
+                s2, op_Attention_reduce = cell(s0, s1)
+                op_Attention_reduce_all.append(op_Attention_reduce)  # cells at different positions learn different topologies
+            else:
+                s2, op_Attention_normal = cell(s0, s1)
+                op_Attention_normal_all.append(op_Attention_normal)  # cells at different positions learn different topologies
+            s0, s1 = s1, s2
+        out = self.global_pooling(s1)
+        logits = self.classifier(out.view(out.size(0), -1))
+        return logits, op_Attention_normal_all, op_Attention_reduce_all
+
+    def _loss(self, input, target):
+        logits, _, _ = self(input)
+        return self._criterion(logits, target)
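A minimal sketch of what the search network returns, assuming the default search setting above (16 initial channels, 5 cells): the logits plus per-cell lists of 14x8 operation-attention rows that train_search later decodes into genotypes.

import torch
from model_search import Network

net = Network(16, 10, 5, None)
with torch.no_grad():
    logits, attn_normal, attn_reduce = net(torch.randn(2, 3, 32, 32))
print(logits.shape)                         # torch.Size([2, 10])
print(len(attn_normal), len(attn_reduce))   # 3 normal cells, 2 reduction cells
print(len(attn_normal[0]), len(attn_normal[0][0]))  # 14 edges x 8 candidate ops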
diff --git a/operations.py b/operations.py
new file mode 100644
index 0000000..b0c62c5
--- /dev/null
+++ b/operations.py
@@ -0,0 +1,105 @@
+import torch
+import torch.nn as nn
+
+OPS = {
+    'none': lambda C, stride, affine: Zero(stride),
+    'avg_pool_3x3': lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
+    'max_pool_3x3': lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
+    'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
+    'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
+    'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
+    'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
+    'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
+    'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
+    'conv_7x1_1x7': lambda C, stride, affine: nn.Sequential(
+        nn.ReLU(inplace=False),
+        nn.Conv2d(C, C, (1, 7), stride=(1, stride), padding=(0, 3), bias=False),
+        nn.Conv2d(C, C, (7, 1), stride=(stride, 1), padding=(3, 0), bias=False),
+        nn.BatchNorm2d(C, affine=affine)
+    ),
+}
+
+
+class ReLUConvBN(nn.Module):
+
+    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
+        super(ReLUConvBN, self).__init__()
+        self.op = nn.Sequential(
+            nn.ReLU(inplace=False),
+            nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
+            nn.BatchNorm2d(C_out, affine=affine)
+        )
+
+    def forward(self, x):
+        return self.op(x)
+
+
+class DilConv(nn.Module):
+
+    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
+        super(DilConv, self).__init__()
+        self.op = nn.Sequential(
+            nn.ReLU(inplace=False),
+            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
+            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
+            nn.BatchNorm2d(C_out, affine=affine),
+        )
+
+    def forward(self, x):
+        return self.op(x)
+
+
+class SepConv(nn.Module):
+
+    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
+        super(SepConv, self).__init__()
+        self.op = nn.Sequential(
+            nn.ReLU(inplace=False),
+            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
+            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
+            nn.BatchNorm2d(C_in, affine=affine),
+            nn.ReLU(inplace=False),
+            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
+            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
+            nn.BatchNorm2d(C_out, affine=affine),
+        )
+
+    def forward(self, x):
+        return self.op(x)
+
+
+class Identity(nn.Module):
+
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return x
+
+
+class Zero(nn.Module):
+
+    def __init__(self, stride):
+        super(Zero, self).__init__()
+        self.stride = stride
+
+    def forward(self, x):
+        if self.stride == 1:
+            return x.mul(0.)
+        return x[:, :, ::self.stride, ::self.stride].mul(0.)
+
+
+class FactorizedReduce(nn.Module):
+
+    def __init__(self, C_in, C_out, affine=True):
+        super(FactorizedReduce, self).__init__()
+        assert C_out % 2 == 0
+        self.relu = nn.ReLU(inplace=False)
+        self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
+        self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
+        self.bn = nn.BatchNorm2d(C_out, affine=affine)
+
+    def forward(self, x):
+        x = self.relu(x)
+        out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)
+        out = self.bn(out)
+        return out
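A small sanity-check sketch (standalone, not part of the repo scripts) that every candidate operation in PRIMITIVES keeps the spatial size at stride 1 and halves it at stride 2, which the cell wiring relies on:

import torch
from genotypes import PRIMITIVES
from operations import OPS

C = 16
x = torch.randn(1, C, 8, 8)
for name in PRIMITIVES:
    for stride in (1, 2):
        y = OPS[name](C, stride, False)(x)
        expected = 8 // stride
        assert y.shape == (1, C, expected, expected), (name, stride, y.shape)
print("all ops preserve/halve spatial size as expected")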
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..6f84325
--- /dev/null
+++ b/test.py
@@ -0,0 +1,104 @@
+import sys
+import numpy as np
+import torch
+import utils
+import logging
+import argparse
+import torch.nn as nn
+import genotypes
+import torch.utils
+import torchvision.datasets as dset
+import torch.backends.cudnn as cudnn
+
+from model import NetworkCIFAR as Network
+
+
+parser = argparse.ArgumentParser("cifar")
+parser.add_argument('--data', type=str, default='../datasets/cifar-10', help='location of the data corpus')
+parser.add_argument('--batch_size', type=int, default=96, help='batch size')
+parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
+parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
+parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
+parser.add_argument('--layers', type=int, default=20, help='total number of layers')
+parser.add_argument('--model_path', type=str, default='./models/cifar10.pt', help='path of pretrained model')
+parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
+parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
+parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
+parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
+parser.add_argument('--seed', type=int, default=0, help='random seed')
+parser.add_argument('--arch', type=str, default='SWD_NAS', help='which architecture to use')
+args = parser.parse_args()
+
+log_format = '%(asctime)s %(message)s'
+logging.basicConfig(stream=sys.stdout, level=logging.INFO,
+                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
+
+CIFAR_CLASSES = 10
+
+
+def main():
+    if not torch.cuda.is_available():
+        logging.info('no gpu device available')
+        sys.exit(1)
+
+    np.random.seed(args.seed)
+    torch.cuda.set_device(args.gpu)
+    cudnn.benchmark = True
+    torch.manual_seed(args.seed)
+    cudnn.enabled = True
+    torch.cuda.manual_seed(args.seed)
+    logging.info('gpu device = %d' % args.gpu)
+    logging.info("args = %s", args)
+
+    genotype = eval("genotypes.%s" % args.arch)
+    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
+    model = model.cuda()
+    utils.load(model, args.model_path)
+
+    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
+
+    criterion = nn.CrossEntropyLoss()
+    criterion = criterion.cuda()
+
+    _, test_transform = utils._data_transforms_cifar10(args)
+    test_data = dset.CIFAR10(root=args.data, train=False, download=False, transform=test_transform)
+
+    test_queue = torch.utils.data.DataLoader(
+        test_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2)
+
+    model.drop_path_prob = args.drop_path_prob
+    test_acc, test_obj = infer(test_queue, model, criterion)
+    logging.info('test_acc %f', test_acc)
+
+
+def infer(test_queue, model, criterion):
+    objs = utils.AvgrageMeter()
+    top1 = utils.AvgrageMeter()
+    top5 = utils.AvgrageMeter()
+    model.eval()
+
+    with torch.no_grad():  # inference only; no gradients needed
+        for step, (input, target) in enumerate(test_queue):
+            input = input.cuda()
+            target = target.cuda(non_blocking=True)
+
+            logits, _ = model(input)
+            loss = criterion(logits, target)
+
+            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
+            n = input.size(0)
+            objs.update(loss.item(), n)
+            top1.update(prec1.item(), n)
+            top5.update(prec5.item(), n)
+
+            if step % args.report_freq == 0:
+                logging.info('test %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
+
+    return top1.avg, objs.avg
+
+
+if __name__ == '__main__':
+    main()
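A tiny worked example of how the top-1/top-5 numbers logged above are computed (it assumes utils.accuracy as defined in utils.py at the end of this diff):

import torch
import utils

logits = torch.tensor([[0.1, 0.9, 0.0],
                       [0.8, 0.1, 0.1],
                       [0.3, 0.3, 0.4],
                       [0.2, 0.5, 0.3]])
target = torch.tensor([1, 0, 0, 2])
prec1, prec2 = utils.accuracy(logits, target, topk=(1, 2))
print(prec1.item(), prec2.item())  # 50.0 100.0: 2/4 correct at top-1, 4/4 within top-2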
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..1728a82
--- /dev/null
+++ b/train.py
@@ -0,0 +1,199 @@
+import os
+import sys
+import time
+import glob
+import numpy as np
+from numpy import random
+import torch
+import utils
+import logging
+import argparse
+import torch.nn as nn
+import genotypes
+import torch.utils
+import torchvision.datasets as dset
+import torch.backends.cudnn as cudnn
+
+from model import NetworkCIFAR as Network
+
+
+parser = argparse.ArgumentParser("cifar")
+parser.add_argument('--data', type=str, default='../dataset/cifar10', help='location of the data corpus')
+parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10 or cifar100')
+parser.add_argument('--batch_size', type=int, default=96, help='batch size')
+parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
+parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
+parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
+parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
+parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
+parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
+parser.add_argument('--epochs', type=int, default=600, help='num of training epochs')
+parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
+parser.add_argument('--layers', type=int, default=20, help='total number of layers')
+parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
+parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
+parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
+parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
+parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
+parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
+parser.add_argument('--save', type=str, default='EXP', help='experiment name')
+parser.add_argument('--seed', type=int, default=0, help='random seed')
+parser.add_argument('--arch', type=str, default='SWD_NAS', help='which architecture to use')
+parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
+args = parser.parse_args()
+
+args.save = 'eval-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
+utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
+
+log_format = '%(asctime)s %(message)s'
+logging.basicConfig(stream=sys.stdout, level=logging.INFO,
+                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
+fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
+fh.setFormatter(logging.Formatter(log_format))
+logging.getLogger().addHandler(fh)
+
+if args.dataset == 'cifar10':
+    CIFAR_CLASSES = 10
+elif args.dataset == 'cifar100':
+    CIFAR_CLASSES = 100
+
+
+def main():
+    if not torch.cuda.is_available():
+        logging.info('no gpu device available')
+        sys.exit(1)
+
+    torch.cuda.set_device(args.gpu)
+    # fix seeds
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    os.environ['PYTHONHASHSEED'] = str(args.seed)
+    cudnn.enabled = True
+    cudnn.benchmark = False
+    torch.cuda.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+    torch.backends.cudnn.deterministic = True
+    logging.info('gpu device = %d' % args.gpu)
+    logging.info("args = %s", args)
+
+    genotype = eval("genotypes.%s" % args.arch)
+    logging.info('genotype = %s', genotype)
+    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
+    model = model.cuda()
+
+    logging.info(model)
+    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
+
+    criterion = nn.CrossEntropyLoss()
+    criterion = criterion.cuda()
+    optimizer = torch.optim.SGD(
+        model.parameters(),
+        args.learning_rate,
+        momentum=args.momentum,
+        weight_decay=args.weight_decay
+    )
+
+    if args.dataset == 'cifar10':
+        dataset_class = dset.CIFAR10
+        train_transform, valid_transform = utils._data_transforms_cifar10(args)
+    elif args.dataset == 'cifar100':
+        dataset_class = dset.CIFAR100
+        train_transform, valid_transform = utils._data_transforms(args)
+
+    train_data = dataset_class(root=args.data, train=True, download=True, transform=train_transform)
+    valid_data = dataset_class(root=args.data, train=False, download=True, transform=valid_transform)
+
+    train_queue = torch.utils.data.DataLoader(
+        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=2)
+
+    valid_queue = torch.utils.data.DataLoader(
+        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2)
+
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs), eta_min=args.learning_rate_min)
+
+    best_valid_acc = 0.0
+    for epoch in range(args.epochs):
+        logging.info('epoch %d lr %e', epoch, scheduler.get_last_lr()[0])
+        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
+
+        train_acc, train_obj = train(train_queue, model, criterion, optimizer)
+        logging.info('train_acc %f', train_acc)
+
+        with torch.no_grad():
+            valid_acc, valid_obj = infer(valid_queue, model, criterion)
+        logging.info('valid_acc %f', valid_acc)
+
+        # step the scheduler after the epoch's optimizer updates (PyTorch >= 1.1 order)
+        scheduler.step()
+
+        if valid_acc > best_valid_acc:
+            best_valid_acc = valid_acc
+            utils.save(model, os.path.join(args.save, 'best_weights.pt'))
+        logging.info('best_valid_acc %f', best_valid_acc)
+
+        utils.save(model, os.path.join(args.save, 'weights.pt'))
+
+
+def train(train_queue, model, criterion, optimizer):
+    objs = utils.AvgrageMeter()
+    top1 = utils.AvgrageMeter()
+    top5 = utils.AvgrageMeter()
+    model.train()
+
+    for step, (input, target) in enumerate(train_queue):
+        input = input.cuda()
+        target = target.cuda(non_blocking=True)
+
+        optimizer.zero_grad()
+        logits, logits_aux = model(input)
+        loss = criterion(logits, target)
+        if args.auxiliary:
+            loss_aux = criterion(logits_aux, target)
+            loss += args.auxiliary_weight * loss_aux
+        loss.backward()
+        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
+        optimizer.step()
+
+        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
+        n = input.size(0)
+        objs.update(loss.item(), n)
+        top1.update(prec1.item(), n)
+        top5.update(prec5.item(), n)
+
+        if step % args.report_freq == 0:
+            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
+
+    return top1.avg, objs.avg
+
+
+def infer(valid_queue, model, criterion):
+    objs = utils.AvgrageMeter()
+    top1 = utils.AvgrageMeter()
+    top5 = utils.AvgrageMeter()
+    model.eval()
+
+    for step, (input, target) in enumerate(valid_queue):
+        input = input.cuda()
+        target = target.cuda(non_blocking=True)
+
+        logits, _ = model(input)
+        loss = criterion(logits, target)
+
+        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
+        n = input.size(0)
+        objs.update(loss.item(), n)
+        top1.update(prec1.item(), n)
+        top5.update(prec5.item(), n)
+
+        if step % args.report_freq == 0:
+            logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
+
+    return top1.avg, objs.avg
+
+
+if __name__ == '__main__':
+    main()
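For reference, the cosine learning-rate schedule train.py uses can be evaluated in closed form; a short sketch with the argparse defaults above (0.025 decaying to 0.001 over 600 epochs, stepping once per epoch):

import math

lr_max, lr_min, T = 0.025, 0.001, 600
for epoch in (0, 150, 300, 450, 599):
    lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * epoch / T))
    print(epoch, round(lr, 5))  # 0.025 at epoch 0, 0.013 mid-run, ~0.001 at the end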
diff --git a/train_imagenet.py b/train_imagenet.py
new file mode 100644
index 0000000..3b4a3ab
--- /dev/null
+++ b/train_imagenet.py
@@ -0,0 +1,252 @@
+import os
+import sys
+import time
+import glob
+import numpy as np
+import torch
+import utils
+import logging
+import argparse
+import torch.nn as nn
+import genotypes
+import torch.utils
+import torchvision.datasets as dset
+import torch.backends.cudnn as cudnn
+import torchvision.transforms as transforms
+
+from model import NetworkImageNet as Network
+
+os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
+
+parser = argparse.ArgumentParser("imagenet")
+parser.add_argument('--data', type=str, default='../datasets/', help='location of the data corpus')
+parser.add_argument('--dataset', type=str, default='imagenet', help='imagenet')
+parser.add_argument('--batch_size', type=int, default=256, help='batch size')
+parser.add_argument('--learning_rate', type=float, default=0.1, help='init learning rate')
+parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
+parser.add_argument('--weight_decay', type=float, default=3e-5, help='weight decay')
+parser.add_argument('--report_freq', type=float, default=100, help='report frequency')
+parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
+parser.add_argument('--epochs', type=int, default=250, help='num of training epochs')
+parser.add_argument('--init_channels', type=int, default=48, help='num of init channels')
+parser.add_argument('--layers', type=int, default=20, help='total number of layers')
+parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
+parser.add_argument('--auxiliary', action='store_true', default=True, help='use auxiliary tower')
+parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
+parser.add_argument('--cutout', action='store_true', default=True, help='use cutout')
+parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
+parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
+parser.add_argument('--save', type=str, default='EXP', help='experiment name')
+parser.add_argument('--seed', type=int, default=1, help='random seed')
+parser.add_argument('--arch', type=str, default='SWD_NAS', help='which architecture to use')
+parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
+parser.add_argument('--workers', type=int, default=20, help='number of workers to load dataset')
+parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
+args = parser.parse_args()
+
+args.save = 'eval-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
+utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
+
+log_format = '%(asctime)s %(message)s'
+logging.basicConfig(stream=sys.stdout, level=logging.INFO,
+                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
+fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
+fh.setFormatter(logging.Formatter(log_format))
+logging.getLogger().addHandler(fh)
+
+CLASSES = 1000
+
+
+class CrossEntropyLabelSmooth(nn.Module):
+
+    def __init__(self, num_classes, epsilon):
+        super(CrossEntropyLabelSmooth, self).__init__()
+        self.num_classes = num_classes
+        self.epsilon = epsilon
+        self.logsoftmax = nn.LogSoftmax(dim=1)
+
+    def forward(self, inputs, targets):
+        log_probs = self.logsoftmax(inputs)
+        targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
+        targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
+        loss = (-targets * log_probs).mean(0).sum()
+        return loss
+
+
+def main():
+    if not torch.cuda.is_available():
+        logging.info('no gpu device available')
+        sys.exit(1)
+
+    torch.cuda.set_device(args.gpu)
+    # fix seeds
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    os.environ['PYTHONHASHSEED'] = str(args.seed)
+    cudnn.enabled = True
+    cudnn.benchmark = True
+    torch.cuda.manual_seed(args.seed)
+    torch.cuda.manual_seed_all(args.seed)
+    torch.backends.cudnn.deterministic = True
+    logging.info('gpu device = %d' % args.gpu)
+    logging.info("args = %s", args)
+
+    genotype = eval("genotypes.%s" % args.arch)
+    logging.info('genotype = %s', genotype)
+    model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
+    model = model.cuda()
+
+    logging.info(model)
+    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
+
+    criterion = nn.CrossEntropyLoss()
+    criterion = criterion.cuda()
+    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
+    criterion_smooth = criterion_smooth.cuda()
+
+    optimizer = torch.optim.SGD(
+        model.parameters(),
+        args.learning_rate,
+        momentum=args.momentum,
+        weight_decay=args.weight_decay
+    )
+
+    data_dir = os.path.join(args.data, 'imagenet2012')
+    traindir = os.path.join(data_dir, 'train')
+    validdir = os.path.join(data_dir, 'val')
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+    train_data = dset.ImageFolder(
+        traindir,
+        transforms.Compose([
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+            transforms.ColorJitter(
+                brightness=0.4,
+                contrast=0.4,
+                saturation=0.4,
+                hue=0.2),
+            transforms.ToTensor(),
+            normalize,
+        ]))
+    valid_data = dset.ImageFolder(
+        validdir,
+        transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            normalize,
+        ]))
+
+    train_queue = torch.utils.data.DataLoader(
+        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.workers)
+
+    valid_queue = torch.utils.data.DataLoader(
+        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)
+
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))
+
+    best_valid_acc = 0.0
+    best_valid_acc_r5 = 0.0
+    for epoch in range(args.epochs):
+        logging.info('epoch %d lr %e', epoch, scheduler.get_last_lr()[0])
+        model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
+
+        # train with label smoothing; evaluate with plain cross entropy
+        train_acc, train_obj = train(train_queue, model, criterion_smooth, optimizer)
+        logging.info('train_acc %f', train_acc)
+
+        with torch.no_grad():
+            valid_acc, valid_acc_r5, valid_obj = infer(valid_queue, model, criterion)
+        logging.info('valid_acc %f, valid_acc_r5 %f', valid_acc, valid_acc_r5)
+
+        scheduler.step()
+
+        if valid_acc > best_valid_acc:
+            best_valid_acc = valid_acc
+            best_valid_acc_r5 = valid_acc_r5
+            utils.save(model, os.path.join(args.save, 'best_weights.pt'))
+        logging.info('best_valid_acc %f, best_valid_acc_r5 %f', best_valid_acc, best_valid_acc_r5)
+
+        utils.save(model, os.path.join(args.save, 'weights.pt'))
+
+
+def train(train_queue, model, criterion, optimizer):
+    objs = utils.AvgrageMeter()
+    top1 = utils.AvgrageMeter()
+    top5 = utils.AvgrageMeter()
+    batch_time = utils.AvgrageMeter()
+    model.train()
+
+    for step, (input, target) in enumerate(train_queue):
+        input = input.cuda()
+        target = target.cuda(non_blocking=True)
+
+        b_start = time.time()
+
+        optimizer.zero_grad()
+        logits, logits_aux = model(input)
+        loss = criterion(logits, target)
+        if args.auxiliary:
+            loss_aux = criterion(logits_aux, target)
+            loss += args.auxiliary_weight * loss_aux
+        loss.backward()
+        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
+        optimizer.step()
+
+        batch_time.update(time.time() - b_start)
+
+        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
+        n = input.size(0)
+        objs.update(loss.data.item(), n)
+        top1.update(prec1.data.item(), n)
+        top5.update(prec5.data.item(), n)
+
+        if step % args.report_freq == 0:
+            end_time = time.time()
+            if step == 0:
+                duration = 0
+                start_time = time.time()
+            else:
+                duration = end_time - start_time
+                start_time = time.time()
+            logging.info('TRAIN Step: %03d Objs: %e R1: %f R5: %f Duration: %ds BTime: %.3fs',
+                         step, objs.avg, top1.avg, top5.avg, duration, batch_time.avg)
+
+    return top1.avg, objs.avg
+
+
+def infer(valid_queue, model, criterion):
+    objs = utils.AvgrageMeter()
+    top1 = utils.AvgrageMeter()
+    top5 = utils.AvgrageMeter()
+    model.eval()
+
+    for step, (input, target) in enumerate(valid_queue):
+        input = input.cuda()
+        target = target.cuda(non_blocking=True)
+
+        logits, _ = model(input)
+        loss = criterion(logits, target)
+
+        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
+        n = input.size(0)
+        objs.update(loss.item(), n)
+        top1.update(prec1.item(), n)
+        top5.update(prec5.item(), n)
+
+        if step % args.report_freq == 0:
+            end_time = time.time()
+            if step == 0:
+                duration = 0
+                start_time = time.time()
+            else:
+                duration = end_time - start_time
+                start_time = time.time()
+            logging.info('VALID Step: %03d Objs: %e R1: %f R5: %f Duration: %ds', step, objs.avg, top1.avg, top5.avg,
+                         duration)
+
+    return top1.avg, top5.avg, objs.avg
+
+
+if __name__ == '__main__':
+    main()
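A standalone numerical check (a sketch, not one of the repo scripts) that the label-smoothing loss above reduces to the standard cross entropy at epsilon = 0; the smoothed-target computation is copied from CrossEntropyLabelSmooth.forward:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))

log_probs = F.log_softmax(logits, dim=1)
one_hot = torch.zeros_like(log_probs).scatter_(1, target.unsqueeze(1), 1)
for eps in (0.0, 0.1):
    smoothed = (1 - eps) * one_hot + eps / 10
    loss = (-smoothed * log_probs).mean(0).sum()
    print(eps, loss.item())
print(F.cross_entropy(logits, target).item())  # matches the eps=0.0 line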
diff --git a/train_search.py b/train_search.py
new file mode 100644
index 0000000..cbb270a
--- /dev/null
+++ b/train_search.py
@@ -0,0 +1,253 @@
+import os
+import sys
+import time
+import glob
+import numpy as np
+import torch
+import utils
+import logging
+import argparse
+import torch.nn as nn
+import torch.utils
+import torch.nn.functional as F
+import torchvision.datasets as dset
+import torch.backends.cudnn as cudnn
+from model_search import Network
+from genotypes import PRIMITIVES
+from genotypes import Normal_Genotype, Reduce_Genotype
+
+
+parser = argparse.ArgumentParser("cifar")
+parser.add_argument('--data', type=str, default='../datasets/cifar-10', help='location of the data corpus')
+parser.add_argument('--batch_size', type=int, default=64, help='batch size')
+parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
+parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
+parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
+parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
+parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
+parser.add_argument('--gpu', type=int, default=1, help='gpu device id')
+parser.add_argument('--epochs', type=int, default=30, help='num of training epochs')
+parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
+parser.add_argument('--layers', type=int, default=5, help='total number of layers')
+parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
+parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
+parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
+parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path probability')
+parser.add_argument('--save', type=str, default='EXP', help='experiment name')
+parser.add_argument('--seed', type=int, default=0, help='random seed')
+parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
+parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')  # fraction of the training set used for weight training
+parser.add_argument('--unrolled', action='store_true', default=False, help='use one-step unrolled validation loss')
+parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
+parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
+args = parser.parse_args()
+
+args.save = 'search-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
+utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
+
+log_format = '%(asctime)s %(message)s'
+logging.basicConfig(stream=sys.stdout, level=logging.INFO,
+                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
+fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
+fh.setFormatter(logging.Formatter(log_format))
+logging.getLogger().addHandler(fh)
+
+
+CIFAR_CLASSES = 10
+
+
+def main():
+    if not torch.cuda.is_available():
+        logging.info('no gpu device available')
+        sys.exit(1)
+
+    np.random.seed(args.seed)
+    torch.cuda.set_device(args.gpu)
+    cudnn.benchmark = True
+    torch.manual_seed(args.seed)
+    cudnn.enabled = True
+    torch.cuda.manual_seed(args.seed)
+    logging.info('gpu device = %d' % args.gpu)
+    logging.info("args = %s", args)
+
+    criterion = nn.CrossEntropyLoss()
+    criterion = criterion.cuda()
+    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
+    logging.info(model)
+    model = model.cuda()
+    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
+
+    optimizer = torch.optim.SGD(
+        model.parameters(),
+        args.learning_rate,
+        momentum=args.momentum,
+        weight_decay=args.weight_decay)
+
+    train_transform, valid_transform = utils._data_transforms_cifar10(args)
+    train_data = dset.CIFAR10(root=args.data, train=True, download=False, transform=train_transform)
+
+    num_train = len(train_data)
+    indices = list(range(num_train))
+    split = int(np.floor(args.train_portion * num_train))
+
+    train_queue = torch.utils.data.DataLoader(
+        train_data, batch_size=args.batch_size,
+        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
+        pin_memory=True, num_workers=2)
+
+    valid_queue = torch.utils.data.DataLoader(
+        train_data, batch_size=args.batch_size,
+        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
+        pin_memory=True, num_workers=2)
+
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+        optimizer, float(args.epochs), eta_min=args.learning_rate_min)
+
+    for epoch in range(args.epochs):
+        lr = scheduler.get_last_lr()[0]
+        logging.info('epoch %d lr %e', epoch, lr)
+
+        # training
+        train_acc, train_obj = train(train_queue, valid_queue, model, criterion, optimizer, lr)
+        logging.info('train_acc %f', train_acc)
+
+        # validation (also decodes and logs genotypes from the accumulated attention)
+        valid_acc, valid_obj = infer(valid_queue, model, criterion)
+        logging.info('valid_acc %f', valid_acc)
+
+        scheduler.step()
+
+        utils.save(model, os.path.join(args.save, 'weights.pt'))
+
+
+def train(train_queue, valid_queue, model, criterion, optimizer, lr):
+    objs = utils.AvgrageMeter()
+    top1 = utils.AvgrageMeter()
+    top5 = utils.AvgrageMeter()
+    op_Attention_normal_all = []
+    op_Attention_reduce_all = []
+
+    for step, (input, target) in enumerate(train_queue):
+        model.train()
+        n = input.size(0)
+
+        input = input.cuda()
+        target = target.cuda(non_blocking=True)
+
+        optimizer.zero_grad()
+        logits, op_Attention_normal, op_Attention_reduce = model(input)
+        loss = criterion(logits, target)
+
+        # accumulate the per-batch attention statistics (kept for debugging here;
+        # the genotypes are decoded from the accumulation done in infer())
+        if len(op_Attention_normal_all) == 0:
+            op_Attention_normal_all = np.asarray(op_Attention_normal, dtype=np.float64)
+        else:
+            op_Attention_normal_all += np.asarray(op_Attention_normal, dtype=np.float64)
+        if len(op_Attention_reduce_all) == 0:
+            op_Attention_reduce_all = np.asarray(op_Attention_reduce, dtype=np.float64)
+        else:
+            op_Attention_reduce_all += np.asarray(op_Attention_reduce, dtype=np.float64)
+
+        loss.backward()
+        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
+        optimizer.step()
+
+        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
+        objs.update(loss.item(), n)
+        top1.update(prec1.item(), n)
+        top5.update(prec5.item(), n)
+
+        if step % args.report_freq == 0:
+            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
+
+    return top1.avg, objs.avg
+
+
+def infer(valid_queue, model, criterion):
+    objs = utils.AvgrageMeter()
+    top1 = utils.AvgrageMeter()
+    top5 = utils.AvgrageMeter()
+    model.eval()
+
+    op_Attention_normal_all = []
+    op_Attention_reduce_all = []
+    for step, (input, target) in enumerate(valid_queue):
+        input = input.cuda()
+        target = target.cuda(non_blocking=True)
+        with torch.no_grad():
+            logits, op_Attention_normal, op_Attention_reduce = model(input)
+            loss = criterion(logits, target)
+
+        if len(op_Attention_normal_all) == 0:
+            op_Attention_normal_all = np.asarray(op_Attention_normal, dtype=np.float64)
+        else:
+            op_Attention_normal_all += np.asarray(op_Attention_normal, dtype=np.float64)
+        if len(op_Attention_reduce_all) == 0:
+            op_Attention_reduce_all = np.asarray(op_Attention_reduce, dtype=np.float64)
+        else:
+            op_Attention_reduce_all += np.asarray(op_Attention_reduce, dtype=np.float64)
+
+        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
+        n = input.size(0)
+        objs.update(loss.item(), n)
+        top1.update(prec1.item(), n)
+        top5.update(prec5.item(), n)
+
+        if step % args.report_freq == 0:
+            logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
+
+    # With layers=5 the network has three normal cells (0, 2, 4) and two reduction cells (1, 3).
+    # logging.info('op_Attention_normal_all %s', op_Attention_normal_all)
+    # logging.info('op_Attention_reduce_all %s', op_Attention_reduce_all)
+    normal_genotype_0 = Normal_Genotype(normal=_parse(op_Attention_normal_all[0]), normal_concat=range(2, 6))
+    normal_genotype_1 = Normal_Genotype(normal=_parse(op_Attention_normal_all[1]), normal_concat=range(2, 6))
+    normal_genotype_2 = Normal_Genotype(normal=_parse(op_Attention_normal_all[2]), normal_concat=range(2, 6))
+    reduce_genotype_0 = Reduce_Genotype(reduce=_parse(op_Attention_reduce_all[0]), reduce_concat=range(2, 6))
+    reduce_genotype_1 = Reduce_Genotype(reduce=_parse(op_Attention_reduce_all[1]), reduce_concat=range(2, 6))
+    logging.info('Cell_0 = %s', normal_genotype_0)
+    logging.info('Cell_1 = %s', reduce_genotype_0)
+    logging.info('Cell_2 = %s', normal_genotype_1)
+    logging.info('Cell_3 = %s', reduce_genotype_1)
+    logging.info('Cell_4 = %s', normal_genotype_2)
+
+    return top1.avg, objs.avg
+
+
+def _parse(Atten_weights):
+    gene = []
+    start = 0
+    n = 2
+    Atten_weights = torch.Tensor(Atten_weights)
+    weights = F.softmax(Atten_weights, dim=-1).data.cpu().numpy()
+    for i in range(4):  # intermediate nodes
+        end = start + n
+        A = weights[start:end].copy()
+        edges = sorted(range(i + 2), key=lambda x: -max(A[x][k] for k in range(len(A[x])) if k != PRIMITIVES.index('none')))[:2]
+        for j in edges:  # edges
+            k_best = None
+            for k in range(len(A[j])):  # operations
+                if k != PRIMITIVES.index('none'):
+                    if k_best is None or A[j][k] > A[j][k_best]:
+                        k_best = k
+            gene.append((PRIMITIVES[k_best], j))
+        start = end
+        n += 1
+    return gene
+
+
+def _parse_norm(Atten_weights):
+    gene = []
+    start = 0
+    n = 2
+    Atten_weights = torch.Tensor(Atten_weights)
+    weights = torch.zeros_like(Atten_weights)
+    for i in range(len(Atten_weights)):
+        weights[i] = Atten_weights[i] / torch.sum(Atten_weights[i])
+    weights = weights.numpy()
+    # weights = F.softmax(Atten_weights, dim=-1).data.cpu().numpy()
+    for i in range(4):  # nodes
+        end = start + n
+        A = weights[start:end].copy()
+        # TODO: try selecting edges by the sum of attention weights instead of the max
+        edges = sorted(range(i + 2), key=lambda x: -max(A[x][k] for k in range(len(A[x])) if k != PRIMITIVES.index('none')))[:2]
+        for j in edges:  # edges
+            k_best = None
+            for k in range(len(A[j])):  # operations
+                if k != PRIMITIVES.index('none'):
+                    if k_best is None or A[j][k] > A[j][k_best]:
+                        k_best = k
+            gene.append((PRIMITIVES[k_best], j))
+        start = end
+        n += 1
+    return gene
+
+
+if __name__ == '__main__':
+    main()
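Usage sketch for _parse above (hypothetical weights; in a real run they are the 14x8 attention sums accumulated in infer): each of the 4 intermediate nodes keeps its 2 strongest incoming edges, and each kept edge takes its best non-'none' operation.

#   import numpy as np
#   gene = _parse(np.random.rand(14, 8))
#   # -> 8 (op_name, input_node) pairs, e.g. [('sep_conv_3x3', 0), ('dil_conv_5x5', 1), ...]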
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..d8b60bd
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,144 @@
+import os
+import numpy as np
+import torch
+import shutil
+import torchvision.transforms as transforms
+
+
+class AvgrageMeter(object):
+    # NOTE: "Avgrage" (sic) is kept because every training script imports it by this name.
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.avg = 0
+        self.sum = 0
+        self.cnt = 0
+
+    def update(self, val, n=1):
+        self.sum += val * n
+        self.cnt += n
+        self.avg = self.sum / self.cnt
+
+
+def accuracy(output, target, topk=(1,)):
+    maxk = max(topk)
+    batch_size = target.size(0)
+
+    _, pred = output.topk(maxk, 1, True, True)
+    pred = pred.t()
+    correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+    res = []
+    for k in topk:
+        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
+        res.append(correct_k.mul_(100.0 / batch_size))
+    return res
+
+
+class Cutout(object):
+    def __init__(self, length):
+        self.length = length
+
+    def __call__(self, img):
+        h, w = img.size(1), img.size(2)
+        mask = np.ones((h, w), np.float32)
+        y = np.random.randint(h)
+        x = np.random.randint(w)
+
+        y1 = np.clip(y - self.length // 2, 0, h)
+        y2 = np.clip(y + self.length // 2, 0, h)
+        x1 = np.clip(x - self.length // 2, 0, w)
+        x2 = np.clip(x + self.length // 2, 0, w)
+
+        mask[y1: y2, x1: x2] = 0.
+        mask = torch.from_numpy(mask)
+        mask = mask.expand_as(img)
+        img *= mask
+        return img
+
+
+def _data_transforms_cifar10(args):
+    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
+    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
+
+    train_transform = transforms.Compose([
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
+    ])
+    if args.cutout:
+        train_transform.transforms.append(Cutout(args.cutout_length))
+
+    valid_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
+    ])
+    return train_transform, valid_transform
+
+
+def _data_transforms(args):
+    if args.dataset == 'cifar10':
+        DATA_MEAN = [0.49139968, 0.48215827, 0.44653124]
+        DATA_STD = [0.24703233, 0.24348505, 0.26158768]
+    elif args.dataset == 'cifar100':
+        DATA_MEAN = [0.5071, 0.4867, 0.4408]
+        DATA_STD = [0.2675, 0.2565, 0.2761]
+    else:
+        raise ValueError('No Defined Dataset!')
+    train_transform = transforms.Compose([
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        transforms.Normalize(DATA_MEAN, DATA_STD),
+    ])
+    if args.cutout:
+        train_transform.transforms.append(Cutout(args.cutout_length))
+
+    valid_transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize(DATA_MEAN, DATA_STD),
+    ])
+    return train_transform, valid_transform
+
+
+def count_parameters_in_MB(model):
+    return sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6
+
+
+def save_checkpoint(state, is_best, save):
+    filename = os.path.join(save, 'checkpoint.pth.tar')
+    torch.save(state, filename)
+    if is_best:
+        best_filename = os.path.join(save, 'model_best.pth.tar')
+        shutil.copyfile(filename, best_filename)
+
+
+def save(model, model_path):
+    torch.save(model.state_dict(), model_path)
+
+
+def load(model, model_path):
+    model.load_state_dict(torch.load(model_path))
+
+
+def drop_path(x, drop_prob):
+    if drop_prob > 0.:
+        keep_prob = 1. - drop_prob
+        # sample a per-example keep mask on the same device as x
+        mask = torch.empty(x.size(0), 1, 1, 1, device=x.device).bernoulli_(keep_prob)
+        x.div_(keep_prob)
+        x.mul_(mask)
+    return x
+
+
+def create_exp_dir(path, scripts_to_save=None):
+    if not os.path.exists(path):
+        os.mkdir(path)
+    print('Experiment dir : {}'.format(path))
+
+    if scripts_to_save is not None:
+        os.mkdir(os.path.join(path, 'scripts'))
+        for script in scripts_to_save:
+            dst_file = os.path.join(path, 'scripts', os.path.basename(script))
+            shutil.copyfile(script, dst_file)
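A standalone sketch checking that drop_path above preserves the expected activation magnitude: surviving examples are rescaled by 1/keep_prob, so the mean stays close to the original in expectation.

import torch
from utils import drop_path

torch.manual_seed(0)
x = torch.ones(10000, 1, 1, 1)
y = drop_path(x.clone(), drop_prob=0.2)
print(y.mean().item())                    # close to 1.0 in expectation
print(sorted(set(y.view(-1).tolist())))   # two values: 0.0 (dropped) and ~1.25 (1/0.8)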