From 2f31df04dfb0f8f6fe8d6928b55183d75dfb191e Mon Sep 17 00:00:00 2001 From: Cong Wu <320971769@qq.com> Date: Thu, 21 Mar 2024 22:25:06 +0800 Subject: [PATCH] Add files via upload --- README.md | 75 ++- action_classification.py | 427 +++++++++++++ action_retrieval.py | 220 +++++++ data_gen/ntu_gendata.py | 194 ++++++ data_gen/preprocess.py | 88 +++ data_gen/rotation.py | 60 ++ dataset.py | 24 + feeder/__init__.py | 1 + feeder/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 191 bytes .../__pycache__/augmentations.cpython-310.pyc | Bin 0 -> 5493 bytes .../feeder_downstream.cpython-310.pyc | Bin 0 -> 3074 bytes .../feeder_pretraining.cpython-310.pyc | Bin 0 -> 3855 bytes feeder/augmentations.py | 181 ++++++ feeder/feeder_downstream.py | 101 ++++ feeder/feeder_pretraining.py | 126 ++++ graph/__init__.py | 3 + graph/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 214 bytes graph/__pycache__/__init__.cpython-36.pyc | Bin 0 -> 297 bytes graph/__pycache__/ntu_rgb_d.cpython-310.pyc | Bin 0 -> 1690 bytes graph/__pycache__/ntu_rgb_d.cpython-36.pyc | Bin 0 -> 3758 bytes graph/__pycache__/tools.cpython-310.pyc | Bin 0 -> 2565 bytes graph/__pycache__/tools.cpython-36.pyc | Bin 0 -> 2563 bytes graph/__pycache__/ucla.cpython-310.pyc | Bin 0 -> 1654 bytes graph/__pycache__/ucla.cpython-36.pyc | Bin 0 -> 1963 bytes graph/ntu_rgb_d.py | 33 + graph/tools.py | 80 +++ log/n60_cb_recognition | 572 ++++++++++++++++++ log/n60_cb_retrieval | 226 +++++++ model/__init__.py | 0 model/__pycache__/__init__.cpython-310.pyc | Bin 0 -> 152 bytes model/__pycache__/__init__.cpython-36.pyc | Bin 0 -> 141 bytes model/__pycache__/ctrgcn.cpython-310.pyc | Bin 0 -> 9737 bytes model/__pycache__/ctrgcn.cpython-36.pyc | Bin 0 -> 11458 bytes model/ctrgcn.py | 307 ++++++++++ .../options_classification.cpython-310.pyc | Bin 0 -> 4151 bytes .../options_pretraining.cpython-310.pyc | Bin 0 -> 2576 bytes options/options_classification.py | 183 ++++++ options/options_pretraining.py | 96 +++ options/options_retrieval.py | 178 ++++++ pretraining.py | 293 +++++++++ requirements.txt | 96 +++ scd/__pycache__/builder.cpython-310.pyc | Bin 0 -> 4216 bytes scd/__pycache__/hi_encoder.cpython-310.pyc | Bin 0 -> 3163 bytes scd/builder.py | 149 +++++ scd/scd_encoder.py | 135 +++++ 45 files changed, 3846 insertions(+), 2 deletions(-) create mode 100644 action_classification.py create mode 100644 action_retrieval.py create mode 100644 data_gen/ntu_gendata.py create mode 100644 data_gen/preprocess.py create mode 100644 data_gen/rotation.py create mode 100644 dataset.py create mode 100644 feeder/__init__.py create mode 100644 feeder/__pycache__/__init__.cpython-310.pyc create mode 100644 feeder/__pycache__/augmentations.cpython-310.pyc create mode 100644 feeder/__pycache__/feeder_downstream.cpython-310.pyc create mode 100644 feeder/__pycache__/feeder_pretraining.cpython-310.pyc create mode 100644 feeder/augmentations.py create mode 100644 feeder/feeder_downstream.py create mode 100644 feeder/feeder_pretraining.py create mode 100644 graph/__init__.py create mode 100644 graph/__pycache__/__init__.cpython-310.pyc create mode 100644 graph/__pycache__/__init__.cpython-36.pyc create mode 100644 graph/__pycache__/ntu_rgb_d.cpython-310.pyc create mode 100644 graph/__pycache__/ntu_rgb_d.cpython-36.pyc create mode 100644 graph/__pycache__/tools.cpython-310.pyc create mode 100644 graph/__pycache__/tools.cpython-36.pyc create mode 100644 graph/__pycache__/ucla.cpython-310.pyc create mode 100644 graph/__pycache__/ucla.cpython-36.pyc create mode 100644 graph/ntu_rgb_d.py create mode 100644 graph/tools.py create mode 100644 log/n60_cb_recognition create mode 100644 log/n60_cb_retrieval create mode 100644 model/__init__.py create mode 100644 model/__pycache__/__init__.cpython-310.pyc create mode 100644 model/__pycache__/__init__.cpython-36.pyc create mode 100644 model/__pycache__/ctrgcn.cpython-310.pyc create mode 100644 model/__pycache__/ctrgcn.cpython-36.pyc create mode 100644 model/ctrgcn.py create mode 100644 options/__pycache__/options_classification.cpython-310.pyc create mode 100644 options/__pycache__/options_pretraining.cpython-310.pyc create mode 100644 options/options_classification.py create mode 100644 options/options_pretraining.py create mode 100644 options/options_retrieval.py create mode 100644 pretraining.py create mode 100644 requirements.txt create mode 100644 scd/__pycache__/builder.cpython-310.pyc create mode 100644 scd/__pycache__/hi_encoder.cpython-310.pyc create mode 100644 scd/builder.py create mode 100644 scd/scd_encoder.py diff --git a/README.md b/README.md index 4cc431c..f05331c 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,74 @@ # SCD-Net -## An new SOTA method for Unsupervised Skeleton-based Action Understanding -(Code will be published soon!) +The Official implementation for 'SCD-Net: Spatiotemporal Clues Disentanglement Network for Self-supervised Skeleton-based Action Recognition' (AAAI 2024). + + +- [Prerequisite](#Prerequisite) +- [Data](#Data) +- [Training&Testing](#Training&Testing) +- [Log files](#Log) + + + + +# Prerequisite + +- Pytorch + +- We provided requirement file to install all packages, just by running + + +`pip install -r requirements.txt` + + + + +# Data + +## Generate the data + +**Download the raw data** + +- [NTU-RGB+D](https://rose1.ntu.edu.sg/dataset/actionRecognition/). +- [PKU-MMD](https://www.icst.pku.edu.cn/struct/Projects/PKUMMD.html). + +**Preprocess** + +- Preprocess data with `python ntu_gendata.py`. + + + + +# Training&Testing + +## Training + +- To train on NTU-RGB+D 60 under Cross-Subject evaluation, you can run + + + python ./pretraining.py --lr 0.01 --batch-size 64 --encoder-t 0.2 --encoder-k 8192 \ + --checkpoint-path ./checkpoints/pretrain/ \ + --schedule 351 --epochs 451 --pre-dataset ntu60 \ + --protocol cross_subject --skeleton-representation joint + +## Testing + + +- For action recognition on NTU-RGB+D 60 under Cross-Subject evaluation, you can run + + + python ./action_classification.py --lr 2 --batch-size 1024 \ + --pretrained ./checkpoints/pretrain/checkpoint.pth.tar \ + --finetune-dataset ntu60 --protocol cross_subject --finetune_skeleton_representation joint + +- For action retrieval on NTU-RGB+D 60 under Cross-Subject evaluation, you can run + + + python ./action_retrieval.py --knn-neighbours 1 \ + --pretrained ./checkpoints/pretrain/checkpoint.pth.tar \ + --finetune-dataset ntu60 --protocol cross_subject --finetune-skeleton-representation joint + + + +# Log files + +We also provided some the testing logs in ./log diff --git a/action_classification.py b/action_classification.py new file mode 100644 index 0000000..1a47967 --- /dev/null +++ b/action_classification.py @@ -0,0 +1,427 @@ +import argparse +import os +import random +import time +import warnings + +import torch +import torch.nn as nn + +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data + +from scd.scd_encoder import DownstreamEncoder + +from dataset import get_finetune_training_set, get_finetune_validation_set + +parser = argparse.ArgumentParser(description='Classification') +parser.add_argument('-j', '--workers', default=24, type=int, metavar='N', + help='number of data loading workers (default: 32)') +parser.add_argument('--epochs', default=80, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument('--lr', '--learning-rate', default=30., type=float, + metavar='LR', help='initial learning rate', dest='lr') +parser.add_argument('--schedule', default=[50, 70,], nargs='*', type=int, + help='learning rate schedule (when to drop lr by a ratio)') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') +parser.add_argument('--wd', '--weight-decay', default=0., type=float, + metavar='W', help='weight decay (default: 0.)', + dest='weight_decay') +parser.add_argument('-p', '--print-freq', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training.') +parser.add_argument('--pretrained', default='', type=str, + help='path to moco pretrained checkpoint') +parser.add_argument('--finetune-dataset', default='ntu60', type=str, + help='which dataset to use for finetuning') +parser.add_argument('--protocol', default='cross_view', type=str, + help='traiining protocol of ntu') +parser.add_argument('--finetune_skeleton_representation', default='joint', type=str, + help='which skeleton-representation to use for downstream training') + + +best_acc1 = 0 + + +# initilize weight +def weights_init(model): + with torch.no_grad(): + for child in list(model.children()): + print("init ", child) + for param in list(child.parameters()): + if param.dim() == 2: + nn.init.xavier_uniform_(param) + print('Weight initial finished!') + + +def load_moco_encoder_q(model,pretrained): + + if os.path.isfile(pretrained): + print("=> loading checkpoint '{}'".format(pretrained)) + checkpoint = torch.load(pretrained, map_location="cpu") + + # rename moco pre-trained keys + state_dict = checkpoint['state_dict'] + for k in list(state_dict.keys()): + # retain only encoder_q up to before the embedding layer + if k.startswith('encoder_q') and not k.startswith('encoder_q.fc'): + # remove prefix + state_dict[k[len("encoder_q."):]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + + msg = model.load_state_dict(state_dict, strict=False) + print("message", msg) + # assert set(msg.missing_keys) == {"fc.weight", "fc.bias"} + + print("=> loaded pre-trained model '{}'".format(pretrained)) + else: + print("=> no checkpoint found at '{}'".format(pretrained)) + + +def main(): + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + global best_acc1 + + # create model + + # training dataset + from options import options_classification as options + if args.finetune_dataset== 'ntu60' and args.protocol == 'cross_view': + opts = options.opts_ntu_60_cross_view() + elif args.finetune_dataset== 'ntu60' and args.protocol == 'cross_subject': + opts = options.opts_ntu_60_cross_subject() + elif args.finetune_dataset== 'ntu120' and args.protocol == 'cross_setup': + opts = options.opts_ntu_120_cross_setup() + elif args.finetune_dataset== 'ntu120' and args.protocol == 'cross_subject': + opts = options.opts_ntu_120_cross_subject() + elif args.finetune_dataset == 'pku_part1' and args.protocol == 'cross_subject': + opts = options.opts_pku_part1_cross_subject() + elif args.finetune_dataset == 'pku_part2' and args.protocol == 'cross_subject': + opts = options.opts_pku_part2_cross_subject() + + opts.train_feeder_args['input_representation'] = args.finetune_skeleton_representation + opts.test_feeder_args['input_representation'] = args.finetune_skeleton_representation + + model = DownstreamEncoder(**opts.encoder_args) + print(model) + print("options", opts.encoder_args, opts.train_feeder_args, opts.test_feeder_args) + if not args.pretrained: + weights_init(model) + + if args.pretrained: + # freeze all layers but the last fc + for name, param in model.named_parameters(): + if name not in ['fc.weight', 'fc.bias']: + param.requires_grad = False + else: + print('params', name) + # init the fc layer + model.fc.weight.data.normal_(mean=0.0, std=0.01) + model.fc.bias.data.zero_() + + # load from pre-trained model + load_moco_encoder_q(model, args.pretrained) + + model = model.cuda() + + total = sum([param.nelement() for param in model.parameters()]) + print("Number of parameter: %.2fM" % (total/1e6)) + + # define loss function (criterion) and optimizer + criterion = nn.CrossEntropyLoss().cuda() + + # optimize only the linear classifier + parameters = list(filter(lambda p: p.requires_grad, model.parameters())) + if args.pretrained: + assert len(parameters) == 2 # fc.weight, fc.bias + optimizer = torch.optim.SGD(parameters, args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + if True: + for parm in optimizer.param_groups: + print("optimize parameters lr", parm['lr']) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. + loc = 'cuda:{}'.format(args.gpu) + checkpoint = torch.load(args.resume, map_location=loc) + args.start_epoch = checkpoint['epoch'] + best_acc1 = checkpoint['best_acc1'] + if args.gpu is not None: + # best_acc1 may be from a checkpoint from a different GPU + best_acc1 = best_acc1.to(args.gpu) + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + # cudnn.benchmark = True + + ## Data loading code + + train_dataset = get_finetune_training_set(opts) + val_dataset = get_finetune_validation_set(opts) + + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler,drop_last=False) + + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True,drop_last=False) + + for epoch in range(args.start_epoch, args.epochs): + + adjust_learning_rate(optimizer, epoch, args) + + # train for one epoch + train(train_loader, model, criterion, optimizer, epoch, args) + + # evaluate on validation set + acc1 = validate(val_loader, model, criterion, args) + + # remember best acc@1 and save checkpoint + is_best = acc1 > best_acc1 + if is_best: + print("found new best accuracy:= ", acc1) + best_acc1 = max(acc1, best_acc1) + + save_checkpoint({ + 'epoch': epoch + 1, + 'state_dict': model.state_dict(), + 'best_acc1': best_acc1, + 'optimizer': optimizer.state_dict(), + }, filename=args.finetune_skeleton_representation + '_model_best.pth.tar') + + # sanity check + if epoch == args.start_epoch: + sanity_check_encoder_q(model.state_dict(), args.pretrained) + print("Final best accuracy", best_acc1) + + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses, top1, top5], + prefix="Epoch: [{}]".format(epoch)) + + """ + Switch to eval mode: + Under the protocol of linear classification on frozen features/models, + it is not legitimate to change any part of the pre-trained model. + BatchNorm in train mode may revise running mean/std (even if it receives + no gradient), which are part of the model parameters too. + """ + model.eval() + + end = time.time() + for i, (q_input, target) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + q_input = q_input.float().cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + # compute output + output = model(q_input) + + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), q_input.size(0)) + top1.update(acc1[0], q_input.size(0)) + top5.update(acc5[0], q_input.size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + +def validate(val_loader, model, criterion, args): + batch_time = AverageMeter('Time', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(val_loader), + [batch_time, losses, top1, top5], + prefix='Test: ') + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + for i, (q_input, target) in enumerate(val_loader): + + q_input = q_input.float().cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + # compute output + output = model(q_input) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), q_input.size(0)) + top1.update(acc1[0], q_input.size(0)) + top5.update(acc5[0], q_input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + # TODO: this should also be done with the ProgressMeter + print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' + .format(top1=top1, top5=top5)) + + return top1.avg + + +def save_checkpoint(state, filename='checkpoint.pth.tar'): + torch.save(state, './checkpoints/recognition/'+filename) + + +def sanity_check_encoder_q(state_dict, pretrained_weights): + """ + Linear classifier should not change any weights other than the linear layer. + This sanity check asserts nothing wrong happens (e.g., BN stats updated). + """ + print("=> loading '{}' for sanity check".format(pretrained_weights)) + checkpoint = torch.load(pretrained_weights, map_location="cpu") + state_dict_pre = checkpoint['state_dict'] + + for k in list(state_dict.keys()): + # only ignore fc layer + if 'fc.weight' in k or 'fc.bias' in k: + continue + + # name in pretrained model + k_pre = 'encoder_q.' + k + + assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \ + '{} is changed in linear classifier training.'.format(k) + + print("=> sanity check passed.") + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries), flush=True) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate based on schedule""" + lr = args.lr + for milestone in args.schedule: + lr *= 0.1 if epoch >= milestone else 1. + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == '__main__': + main() diff --git a/action_retrieval.py b/action_retrieval.py new file mode 100644 index 0000000..c88818a --- /dev/null +++ b/action_retrieval.py @@ -0,0 +1,220 @@ +import argparse +import os +import random + +import warnings +import torch +import torch.nn as nn + +import torch.backends.cudnn as cudnn + +import torch.optim + +import torch.utils.data + +import numpy as np +from sklearn.neighbors import KNeighborsClassifier +from sklearn import preprocessing +from sklearn.metrics import accuracy_score + +from scd.scd_encoder import DownstreamEncoder + + +from dataset import get_finetune_training_set, get_finetune_validation_set + + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') +parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', + help='number of data loading workers (default: 32)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') + +parser.add_argument('-p', '--print-freq', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. ') + +parser.add_argument('--pretrained', default='', type=str, + help='path to moco pretrained checkpoint') +parser.add_argument('--finetune-dataset', default='ntu60', type=str, + help='which dataset to use for finetuning') +parser.add_argument('--protocol', default='cross_view', type=str, + help='traiining protocol of ntu') +parser.add_argument('--finetune-skeleton-representation', default='joint', type=str, + help='which skeleton-representation to use for downstream training') + +parser.add_argument('--knn-neighbours', default=None, type=int, + help='number of neighbours used for KNN.') + +best_acc1 = 0 + +# initilize weight +def weights_init(model): + with torch.no_grad(): + for child in list(model.children()): + print("init ", child) + for param in list(child.parameters()): + if param.dim() == 2: + nn.init.xavier_uniform_(param) + print('PC weight initial finished!') + + +def load_moco_encoder_q(model, pretrained): + + if os.path.isfile(pretrained): + print("=> loading checkpoint '{}'".format(pretrained)) + checkpoint = torch.load(pretrained, map_location="cpu") + + # rename moco pre-trained keys + state_dict = checkpoint['state_dict'] + for k in list(state_dict.keys()): + # retain only encoder_q up to before the embedding layer + if k.startswith('encoder_q') and not k.startswith('encoder_q.fc'): + # remove prefix + state_dict[k[len("encoder_q."):]] = state_dict[k] + # delete renamed or unused k + del state_dict[k] + + msg = model.load_state_dict(state_dict, strict=False) + print("message", msg) + + print("=> loaded pre-trained model '{}'".format(pretrained)) + else: + print("=> no checkpoint found at '{}'".format(pretrained)) + + +def knn(data_train, data_test, label_train, label_test, nn=9): + label_train = np.asarray(label_train) + label_test = np.asarray(label_test) + print("Number of KNN Neighbours = ", nn) + print("training feature and labels", data_train.shape, len(label_train)) + print("test feature and labels", data_test.shape, len(label_test)) + + Xtr_Norm = preprocessing.normalize(data_train) + Xte_Norm = preprocessing.normalize(data_test) + + knn = KNeighborsClassifier(n_neighbors=nn, + metric='cosine') + knn.fit(Xtr_Norm, label_train) + pred = knn.predict(Xte_Norm) + acc = accuracy_score(pred, label_test) + + return acc + + +def test_extract_hidden(model, data_train, data_eval): + model.eval() + for ith, (ith_data, label) in enumerate(data_train): + print(ith) + input_tensor = ith_data.cuda() + + en_hi = model(input_tensor, knn_eval=True) + + if ith == 0: + label_train = label + hidden_array_train = en_hi + + else: + label_train = torch.cat((label_train, label)) + hidden_array_train = torch.cat((hidden_array_train, en_hi)) + + model.eval() + for ith, (ith_data, label) in enumerate(data_eval): + print(ith) + input_tensor = ith_data.cuda() + + en_hi = model(input_tensor, knn_eval=True) + en_hi = en_hi + if ith == 0: + hidden_array_eval = en_hi + label_eval = label + else: + label_eval = torch.cat((label_eval, label)) + hidden_array_eval = torch.cat((hidden_array_eval, en_hi)) + + return hidden_array_train, hidden_array_eval, label_train, label_eval + + +def clustering_knn_acc(model, train_loader, eval_loader, knn_neighbours=1): + _train, _eval, label_train, label_eval = test_extract_hidden(model, train_loader, eval_loader) + + knn_acc_1 = knn(_train.cpu(), _eval.cpu(), label_train, label_eval, nn=knn_neighbours) + + return knn_acc_1 + + +def main(): + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + global best_acc1 + + + # training dataset + from options import options_retrieval as options + if args.finetune_dataset == 'ntu60' and args.protocol == 'cross_view': + opts = options.opts_ntu_60_cross_view() + elif args.finetune_dataset == 'ntu60' and args.protocol == 'cross_subject': + opts = options.opts_ntu_60_cross_subject() + elif args.finetune_dataset == 'ntu120' and args.protocol == 'cross_setup': + opts = options.opts_ntu_120_cross_setup() + elif args.finetune_dataset == 'ntu120' and args.protocol == 'cross_subject': + opts = options.opts_ntu_120_cross_subject() + + opts.train_feeder_args['input_representation'] = args.finetune_skeleton_representation + opts.test_feeder_args['input_representation'] = args.finetune_skeleton_representation + + # create model + model = DownstreamEncoder(**opts.encoder_args) + print(model) + print("options", opts.encoder_args, opts.train_feeder_args, opts.test_feeder_args) + if not args.pretrained: + weights_init(model) + + if args.pretrained: + # freeze all layers + for name, param in model.named_parameters(): + param.requires_grad = False + + # load from pre-trained model + load_moco_encoder_q(model, args.pretrained) + + model = model.cuda() + + # cudnn.benchmark = True + + # Data loading code + train_dataset = get_finetune_training_set(opts) + val_dataset = get_finetune_validation_set(opts) + + train_sampler = None + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=False) + + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True,drop_last=False) + + # Extract frozen features of the pre-trained query encoder + # evaluate a KNN classifier on extracted features + acc1 = clustering_knn_acc(model, train_loader, val_loader, knn_neighbours=args.knn_neighbours) + + print("KNN retrieval acc = ", acc1) + +if __name__ == '__main__': + main() + diff --git a/data_gen/ntu_gendata.py b/data_gen/ntu_gendata.py new file mode 100644 index 0000000..a65bd25 --- /dev/null +++ b/data_gen/ntu_gendata.py @@ -0,0 +1,194 @@ +import argparse +import pickle +from tqdm import tqdm +import sys +from numpy.lib.format import open_memmap + +sys.path.extend(['../']) +from data_gen.preprocess import pre_normalization + +# ntu 60 +training_subjects = [ + 1, 2, 4, 5, 8, 9, 13, 14, 15, 16, 17, 18, 19, 25, 27, 28, 31, 34, 35, 38 +] + +# ntu 120 +training_subjects = [ + 1, 2, 4, 5, 8, 9, 13, 14, 15, 16, 17, 18, 19, 25, 27, 28, 31, 34, 35, 38, + 45, 46, 47, 49, 50, 52, 53, 54, 55, 56, 57, 58, 59, 70, 74, 78,80, 81, 82, + 83, 84, 85, 86, 89, 91, 92, 93, 94, 95, 97, 98, 100, 103 +] +training_setups = [ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32] + +training_cameras = [2, 3] +max_body_true = 2 +max_body_kinect = 4 +num_joint = 25 +max_frame = 300 + +import numpy as np +import os + + +def read_skeleton_filter(file): + with open(file, 'r') as f: + skeleton_sequence = {} + skeleton_sequence['numFrame'] = int(f.readline()) + skeleton_sequence['frameInfo'] = [] + # num_body = 0 + for t in range(skeleton_sequence['numFrame']): + frame_info = {} + frame_info['numBody'] = int(f.readline()) + frame_info['bodyInfo'] = [] + + for m in range(frame_info['numBody']): + body_info = {} + body_info_key = [ + 'bodyID', 'clipedEdges', 'handLeftConfidence', + 'handLeftState', 'handRightConfidence', 'handRightState', + 'isResticted', 'leanX', 'leanY', 'trackingState' + ] + body_info = { + k: float(v) + for k, v in zip(body_info_key, f.readline().split()) + } + body_info['numJoint'] = int(f.readline()) + body_info['jointInfo'] = [] + for v in range(body_info['numJoint']): + joint_info_key = [ + 'x', 'y', 'z', 'depthX', 'depthY', 'colorX', 'colorY', + 'orientationW', 'orientationX', 'orientationY', + 'orientationZ', 'trackingState' + ] + joint_info = { + k: float(v) + for k, v in zip(joint_info_key, f.readline().split()) + } + body_info['jointInfo'].append(joint_info) + frame_info['bodyInfo'].append(body_info) + skeleton_sequence['frameInfo'].append(frame_info) + + return skeleton_sequence + + +def get_nonzero_std(s): # tvc + index = s.sum(-1).sum(-1) != 0 # select valid frames + s = s[index] + if len(s) != 0: + s = s[:, :, 0].std() + s[:, :, 1].std() + s[:, :, 2].std() # three channels + else: + s = 0 + return s + + +def read_xyz(file, max_body=4, num_joint=25): # 取了前两个body + seq_info = read_skeleton_filter(file) + data = np.zeros((max_body, seq_info['numFrame'], num_joint, 3)) + for n, f in enumerate(seq_info['frameInfo']): + for m, b in enumerate(f['bodyInfo']): + for j, v in enumerate(b['jointInfo']): + if m < max_body and j < num_joint: + data[m, n, j, :] = [v['x'], v['y'], v['z']] + else: + pass + + # select two max energy body + energy = np.array([get_nonzero_std(x) for x in data]) + index = energy.argsort()[::-1][0:max_body_true] + data = data[index] + + data = data.transpose(3, 1, 2, 0) + return data + + +def gendata(data_path, out_path, ignored_sample_path=None, benchmark='xview', part='eval'): + if ignored_sample_path != None: + with open(ignored_sample_path, 'r') as f: + ignored_samples = [ + line.strip() + '.skeleton' for line in f.readlines() + ] + else: + ignored_samples = [] + sample_name = [] + sample_label = [] + for filename in os.listdir(data_path): + if filename in ignored_samples: + continue + action_class = int( + filename[filename.find('A') + 1:filename.find('A') + 4]) + subject_id = int( + filename[filename.find('P') + 1:filename.find('P') + 4]) + camera_id = int( + filename[filename.find('C') + 1:filename.find('C') + 4]) + setup_id = int( + filename[filename.find('S') + 1:filename.find('S') + 4]) + + if benchmark == 'xview': + istraining = (camera_id in training_cameras) + elif benchmark == 'xsub': + istraining = (subject_id in training_subjects) + elif benchmark == 'xsetup': + istraining = (setup_id in training_setups) + else: + raise ValueError() + + if part == 'train': + issample = istraining + elif part == 'val': + issample = not (istraining) + else: + raise ValueError() + + if issample: + sample_name.append(filename) + sample_label.append(action_class - 1) + + with open('{}/{}_label.pkl'.format(out_path, part), 'wb') as f: + pickle.dump((sample_name, list(sample_label)), f) + + fl = open_memmap( + '{}/{}_num_frame.npy'.format(out_path, part), + dtype='int', + mode='w+', + shape=(len(sample_label),)) + + fp = np.zeros((len(sample_label), 3, max_frame, num_joint, max_body_true), dtype=np.float32) + + for i, s in enumerate(tqdm(sample_name)): + data = read_xyz(os.path.join(data_path, s), max_body=max_body_kinect, num_joint=num_joint) + fp[i, :, 0:data.shape[1], :, :] = data + fl[i] = data.shape[1] # num_frame + + fp = pre_normalization(fp) + np.save('{}/{}_data_joint.npy'.format(out_path, part), fp) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='NTU-RGB-D Data Converter.') + parser.add_argument('--data_path', default='../data/nturgbd_raw/nturgb+d_skeletons/') + parser.add_argument('--ignored_sample_path', + default='../data/nturgbd_raw/samples_with_missing_skeletons.txt') + parser.add_argument('--out_folder', default='../data/NTU-RGB-D-60-AGCN/') + benchmark = ['xsub', 'xview'] + + #parser.add_argument('--data_path', default='../data/nturgbd_raw_120/nturgb+d_skeletons/') + #parser.add_argument('--ignored_sample_path', + # default='../data/nturgbd_raw_120/samples_with_missing_skeletons.txt') + #parser.add_argument('--out_folder', default='../data/NTU-RGB-D-120-AGCN/') + #benchmark = ['xsub','xsetup', ] + + part = ['train', 'val'] + arg = parser.parse_args() + + for b in benchmark: + for p in part: + out_path = os.path.join(arg.out_folder, b) + if not os.path.exists(out_path): + os.makedirs(out_path) + print(b, p) + gendata( + arg.data_path, + out_path, + arg.ignored_sample_path, + benchmark=b, + part=p) diff --git a/data_gen/preprocess.py b/data_gen/preprocess.py new file mode 100644 index 0000000..69702a9 --- /dev/null +++ b/data_gen/preprocess.py @@ -0,0 +1,88 @@ +import sys + +sys.path.extend(['../']) +from data_gen.rotation import * +from tqdm import tqdm + + +def pre_normalization(data, zaxis=[0, 1], xaxis=[8, 4]): + N, C, T, V, M = data.shape + s = np.transpose(data, [0, 4, 2, 3, 1]) # N, C, T, V, M to N, M, T, V, C + + print('pad the null frames with the previous frames') + for i_s, skeleton in enumerate(tqdm(s)): # pad + if skeleton.sum() == 0: + print(i_s, ' has no skeleton') + for i_p, person in enumerate(skeleton): + if person.sum() == 0: + continue + if person[0].sum() == 0: + index = (person.sum(-1).sum(-1) != 0) + tmp = person[index].copy() + person *= 0 + person[:len(tmp)] = tmp + for i_f, frame in enumerate(person): + if frame.sum() == 0: + if person[i_f:].sum() == 0: + rest = len(person) - i_f + num = int(np.ceil(rest / i_f)) + pad = np.concatenate([person[0:i_f] for _ in range(num)], 0)[:rest] + s[i_s, i_p, i_f:] = pad + break + + print('sub the center joint #1 (spine joint in ntu and neck joint in kinetics)') + for i_s, skeleton in enumerate(tqdm(s)): + if skeleton.sum() == 0: + continue + main_body_center = skeleton[0][:, 1:2, :].copy() + for i_p, person in enumerate(skeleton): + if person.sum() == 0: + continue + mask = (person.sum(-1) != 0).reshape(T, V, 1) + s[i_s, i_p] = (s[i_s, i_p] - main_body_center) * mask + + print('parallel the bone between hip(jpt 0) and spine(jpt 1) of the first person to the z axis') + for i_s, skeleton in enumerate(tqdm(s)): + if skeleton.sum() == 0: + continue + joint_bottom = skeleton[0, 0, zaxis[0]] + joint_top = skeleton[0, 0, zaxis[1]] + axis = np.cross(joint_top - joint_bottom, [0, 0, 1]) + angle = angle_between(joint_top - joint_bottom, [0, 0, 1]) + matrix_z = rotation_matrix(axis, angle) + for i_p, person in enumerate(skeleton): + if person.sum() == 0: + continue + for i_f, frame in enumerate(person): + if frame.sum() == 0: + continue + for i_j, joint in enumerate(frame): + s[i_s, i_p, i_f, i_j] = np.dot(matrix_z, joint) + + print( + 'parallel the bone between right shoulder(jpt 8) and left shoulder(jpt 4) of the first person to the x axis') + for i_s, skeleton in enumerate(tqdm(s)): + if skeleton.sum() == 0: + continue + joint_rshoulder = skeleton[0, 0, xaxis[0]] + joint_lshoulder = skeleton[0, 0, xaxis[1]] + axis = np.cross(joint_rshoulder - joint_lshoulder, [1, 0, 0]) + angle = angle_between(joint_rshoulder - joint_lshoulder, [1, 0, 0]) + matrix_x = rotation_matrix(axis, angle) + for i_p, person in enumerate(skeleton): + if person.sum() == 0: + continue + for i_f, frame in enumerate(person): + if frame.sum() == 0: + continue + for i_j, joint in enumerate(frame): + s[i_s, i_p, i_f, i_j] = np.dot(matrix_x, joint) + + data = np.transpose(s, [0, 4, 2, 3, 1]) + return data + + +if __name__ == '__main__': + data = np.load('../data/ntu/xview/val_data.npy') + pre_normalization(data) + np.save('../data/ntu/xview/data_val_pre.npy', data) diff --git a/data_gen/rotation.py b/data_gen/rotation.py new file mode 100644 index 0000000..6e8aaa0 --- /dev/null +++ b/data_gen/rotation.py @@ -0,0 +1,60 @@ +import numpy as np +import math + + +def rotation_matrix(axis, theta): + """ + Return the rotation matrix associated with counterclockwise rotation about + the given axis by theta radians. + """ + if np.abs(axis).sum() < 1e-6 or np.abs(theta) < 1e-6: + return np.eye(3) + axis = np.asarray(axis) + axis = axis / math.sqrt(np.dot(axis, axis)) + a = math.cos(theta / 2.0) + b, c, d = -axis * math.sin(theta / 2.0) + aa, bb, cc, dd = a * a, b * b, c * c, d * d + bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d + return np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)], + [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)], + [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]) + + +def unit_vector(vector): + """ Returns the unit vector of the vector. """ + return vector / np.linalg.norm(vector) + + +def angle_between(v1, v2): + """ Returns the angle in radians between vectors 'v1' and 'v2':: + + >>> angle_between((1, 0, 0), (0, 1, 0)) + 1.5707963267948966 + >>> angle_between((1, 0, 0), (1, 0, 0)) + 0.0 + >>> angle_between((1, 0, 0), (-1, 0, 0)) + 3.141592653589793 + """ + if np.abs(v1).sum() < 1e-6 or np.abs(v2).sum() < 1e-6: + return 0 + v1_u = unit_vector(v1) + v2_u = unit_vector(v2) + return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0)) + + +def x_rotation(vector, theta): + """Rotates 3-D vector around x-axis""" + R = np.array([[1, 0, 0], [0, np.cos(theta), -np.sin(theta)], [0, np.sin(theta), np.cos(theta)]]) + return np.dot(R, vector) + + +def y_rotation(vector, theta): + """Rotates 3-D vector around y-axis""" + R = np.array([[np.cos(theta), 0, np.sin(theta)], [0, 1, 0], [-np.sin(theta), 0, np.cos(theta)]]) + return np.dot(R, vector) + + +def z_rotation(vector, theta): + """Rotates 3-D vector around z-axis""" + R = np.array([[np.cos(theta), -np.sin(theta), 0], [np.sin(theta), np.cos(theta), 0], [0, 0, 1]]) + return np.dot(R, vector) diff --git a/dataset.py b/dataset.py new file mode 100644 index 0000000..f8d95e6 --- /dev/null +++ b/dataset.py @@ -0,0 +1,24 @@ + +def get_pretraining_set(opts): + + from feeder.feeder_pretraining import Feeder + training_data = Feeder(**opts.train_feeder_args) + + return training_data + + +def get_finetune_training_set(opts): + + from feeder.feeder_downstream import Feeder + + data = Feeder(**opts.train_feeder_args) + + return data + + +def get_finetune_validation_set(opts): + + from feeder.feeder_downstream import Feeder + data = Feeder(**opts.test_feeder_args) + + return data diff --git a/feeder/__init__.py b/feeder/__init__.py new file mode 100644 index 0000000..e5b3d4f --- /dev/null +++ b/feeder/__init__.py @@ -0,0 +1 @@ +from . import augmentations diff --git a/feeder/__pycache__/__init__.cpython-310.pyc b/feeder/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06c14c7f25363ae42a865b4f8ecf2f9b8f98105d GIT binary patch literal 191 zcmd1j<>g`kf~h@eDKbF%F^Gc(44TX@8G*u@ zjJJ3bOVe{x^GXs+GV}9_{WO_wF%*Hcu4E`;0a9S%myLdIUWtBMVsVLnUVc(ya&~Ef zera)PQL%n?JoA#{>@~UOR*!!s6H@7$U(2#SuZHgrPmvGu!K#?s0XG z9h)90VuOT43=Jzi5bnQ)75o7^?m1@nxs}M z8~A+saYt64G>m`H!^w|{hs*e-e}RyO6g{IaOhXubvuTN<#%|jD#%VfSx=jzI+4K9Q zW@*x<+$^Kc>XrMIW~E#Kf(I3g+i4I-6p+K=+23Hm-JIbJ@r+m3*9GurO;`J={RIswO-WIaB8#A ztE4|w=rz)5V63f9`WFhFrcqxk^d-`tD|DK2+J2ezFBkPRy=oNtD(PP>^feU6w1$Ss z%bKiBU7$9Ty>C#PVr}YWs(%KeO-7o2$O`m(9l!Jih|Cy^9N&(4C`M*(ykWdwe$O0P zBU_5x%G#joe;0Mqch=mjcM+Q5(McCc5iV8umIZ-lQAFw*w`$mkcvj#GBfD z6R>=1n5kFMn0H)=OVsytH_Te;06PhKt*#0Ndvr)CA*w9sTctf*#bcWgJ^c{M2${Tx zze15>B%~>=k9ao+;tk^f=hX%I?u(HH`r%D(ajh*~ezN(=$>@CPB%K3$aO zlbJF1bjA~dzBwD7C!#apIcU>7$A+N@yX;XLmbQAyZEf#G-92rm(cLim=Yt0i8YS(p zBWJ9uSI}O&;B$Z2(=J}fq^~_Hqd0@7AL1!c-84~|w%f_zPNSmBQ9KxCEg571^HN== zL6!s4KLV?~ouuHf!nLu+!;<#anH6dO^2Y-%l_RRYmP6Otl3*pV%_v z?9@Ym+QMilfx#Bb!WVVB4hK|)UKdMvewv8<6TE)xQ8ST==5Rb9Gf^3YGcvJn?pxBP z_$fWqz)_8Dkj#`mL4Zg`hq6QcOj&~KGT9ZEckgbrx6Wl-vaxk*tH1T)rLFcBp6EeQ zEJE}YQhf;;N4KkZ_Eb3N1?`Yei@HzR7oBHgH5^5h;xkfjLDOz~FNx5P@Y=A{J1Cl` zv#>u%6ujEhi0WPHwFpMNG;21Y<=4h4Anf)YDju;LgKX04LC+vc98Fa|5OVy=b@v%pFQ$?i7~vhH=zMR9{;%$#f|L;?qHrh7+r5 z=h_}J9_<2-!1_bRRBen(_D#uG*VhF?{j;p+Lp%@x zKQfqX29d3Rz{twP*rbC82$1<87ix?g`Pu{oFhdnMJZ|0qLw&j0|T+ zHH>NtJuUQnm3kI$<}J~N6NaKEapj?&4tl;j?>TJ;CN0b*X*NT;_ff(KR&MRvFdn!B zKDl_~F`q0pqm6XJhryAgfJ=66BS9$bGP-cUj4!4P3wlKDOfNKM)1-F(>ZfSS@9NSbk(WXt7oqj07%-r+K*imlm8}H$GN6yF{A)_4m z6xzsH!?BtY`6HyYiJWG2q9NZdfiE_GPUb>-;8JPTI1~pW_YV>F0G~3C;!D6fFK>Vd z>*Sy-QwNq;@GY;;GSL%`teTbf%Q?x4wC|atYF-^z34{o==Ak{V9U9|0vSJtO@$f5S zT(qGP=9#TD*eO=CQ~OJKHOFo`CS7rnE!)(3|1?G`-xDY;DtVUK$hntWn2Bb^oCJTK zbr~K9eK}7)z6?XoNpK{zBeOe$kgv21ZtY)>aP>e`xO!x%AD}s4LAR`0(2KfptDUGA zS1I(gWYj-e(N4M-U;xUnikM?lTBW^y6i0o^tNMTnDHf$#PN$b7iZQ31EK%(}?Pg(& z#EV}Anf5SPf0%{3+)m;w>JF1(s_k1*cw39Bx`IJMHAs43jP}x-!!W!XY6o{g=x>eA-KkG1%rK84k`+A0~edSToJ37izU ze5hAv>BQnoq=Z>>o@>xb!w(*gpL%7Xw+dw-C!Z3sXIb=(V@CVezCcp}BI+oUq3N}*#QMPzLp#oa|PgI}_6Hx&d*3Ur& zXGR6&Hh{AIxv5}-3LpYtKc#{vD@+AJ;m+o}vrxNurCq zvhg^g*YesvtbyF6%#^P&<^NCj*{BwgkD!jv-(OP-Md1Ka(l|nJ9y!G0RdtIh?~wR8 zM5E4(QXf;rJrchm@mmr!r1~9+Pe}Zp#2-lfk;JDE&DtXV2lV3Me4Tlf4zDA$48yn` z(%9(*h`LoZ!4NTF2v{;_$cA$k0b^;QPW>638tcs!_A2M&_;*JD{n*an6G$QSCq-|>EtZhlFZ$ozk`-xIA@_}we?!aq%lzc;CpA!6xG|(QC zPl>RF^OOij(!SrX^n(MdT@`eVwA^FRLsnZx&X1EK6k!^>mY$jbp zYJw@lNHjNP^ZM#S^T9%MeWCgFLi5X(*>}C;!_iL~;&}59pl81TC}=@ADHuRnSO8h! z0JI7hpk3GiVSeZq-VvpwAbqdzZ=gg4dH}r&y#{>>csSO_v+? z4yYOD59onKNj@V$ci%Z+_>OQ-)@tiwiyc!q+c5*sVhRs@{EESk@j;YV;x^)F5y(4d z=qRty`+$z3s#@YGs%j-(L43RqimG}^M^Od-ifd6+HA;L6aa00y6jie&j-qO=#8Fhu zm-r0gV|f%+uaEri_d;>3TdI@U996vYZFKI34CZ>y4zIFL5PPh>4RpP;cVKl_&`X$1SyO zpkw?_(^X9$Xu7WH*P4E*>6)hZwS7N{rFK%)i*s#7GS&^yx+xV;x~1++JLXI+U)!0C zy6P=7p^4ya1gKM8LGTU$Z4U-v`kbhbVDk4DOM^IH>V{dq6epWuXJ?o$4Kt~-rOuQ0 zmv1iLTIwW1E`8Q}klYxAz4*hWu8Fnrz(w*T&T=Kg!D70nea?GvFXtRDI=ckGt+@2$ z3#`l3hcBR)odz9Emr=WTvZjMhdD74u(W^{2{RT)uRV^odaQi-_2}oNQSXS)i^$0?j2e`%rnF~*-Fcsn@^8wu5$X-l$BF)q$V32NFHO=Bw z+vo%3OI-!?%9mP3tO-EdNeVCRroGNiB-_<29Hf!tF?gCxjqx)~e6oWX;-n#t-Mj;aKimP> z`kX?(SJzZIe!G?3?J5|qx$5*4eS5)VEQRQi@G=2^!@}hiqJaa zZ8^jN{~TsUiG9wi4DXy*VYpAc>M|Vb#Cr;QksQ%|Oqdqp`{+5g!R*|DB+W2|_tA^| zBiqPT1l?ikDsb#BTf=*)2nLj2Vb1`)Yu_mt6j1UbgK~zj#0~7)7ojvvDE7a(7f9j! zKmQJJ538^*6iDI0iQK|Hgd)-Rp<}_Q&*O=v=gAQ}YzQAtnc0!7ZZhJ`%r%7A^CP?gY8%L!RE<%!P=rCtP5Z(TKAi_Eyb`Q+sPgZDMaJ{j$BaqFU5Sdrd3{tJ>)&n0DPk zf`2il>>R&V!&#d&2{;O0NXK9j2?com={G+J%aH|Cb*Kt1=GWZ^l7m;!U zKx@jxrFMF;kUz8oQ^25bqWjX%5Db|71{?V<}70>#*}8?FGF8xdmw^_0g@Wz zr9`8voNvjzm&*a?U%<-jeEQk7t)I$Vo2u^df5V^it(s-Ht(y7-c-qMm)!Eigd5PwM8=1_xsmw`=x0Y$E7kArk z*VM@JrJZ5ki!$Q`+It8!AJ&+C|LZu~zf)TG3A$B=|2I%EEg$ZRu?AGqfL^lx0XXxw A7ytkO literal 0 HcmV?d00001 diff --git a/feeder/__pycache__/feeder_pretraining.cpython-310.pyc b/feeder/__pycache__/feeder_pretraining.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26d99695686f63c5db58da9a89693ddca7c991db GIT binary patch literal 3855 zcmai0OK%*<5uTp+&c24^+7_KSSu+o16I!&SydVg+Y}xY3JcNXT3=~X8+q2{@cV|{T zvrb41Pf~y&h!I}`@yd9W9W?Bz-8KR zwFHn}wFD29VRRJv52$dbjUHON0Xyb zoa9m7Pm`=mGRu#F$peh+K0vIIhJnsmU{rM8;08BeumC4=`-K)*+~F?1Hm`6GUx!zD z4PTenaeh9iXEN()8*zg0%aHYyj9}>MPWyE(#h7QNMBzm>2aPbz==@97* zLqTcVY01hwb7Ui#jKZFXMzP8y>GhNGBoDLxSu9&6F5)pRz%1#5YFDJ=5EO`r@6pbc zVH)v}_LQcAmdZHE_9c5L*`Z{QCHqve1Ia#==7(t#OKU9pNiL0HoJb!kJdH(|_DV&v zjw8)19`?ldN%=MbI!C-sU z?jPU1{ma|G+Uuq~-uw8oXs@TP)bta=tXxF>q@Nt!8lOrx41woiNdC_7Gtz9=X0z8b zzN@>;Wgcr;K09lcxAWF)O_y?T_7F}+eLQZP}=wE@z5fZqXrQ*?oh#L2Y_MZ2WL zrvx{-c4COPr*u=%ElPN-w!t?B4mY-G=5q6Mom(wjvQSBMY`9y{4~v`=j$pH(+X0Hj z^V4yBb%$mZsV8G;QpigG{t(pC$?pP_Cc=}69}#LHM0fAtV%;VYTCE&jf}hj$2LQy( z=n9x-i#2r5*kn8UnLqWD*w6Nbo` zFnj=VMqakrSz{gv#ro!RC?W+;5$^#hH+_GOse{NSl?FM|F;vQ_P_Beq@m9i3`9l8; zwcy`4W){|hZEKh89KjwKgr6I^d12x2Nuj22I+y0TnY$P6B`Vtqk{kR=fv~So-vQNI zqy8Gy>KgSSsI@ifKB)C6bzK8%NrNsVD}@COmZ=G--Wv53)an}b1*o+(>fb@FmsHr@ z;QEO!{)q}=_4EUZKC0u1{sj{WeAiT(+668H*Z!=dMiVS?1$*YUHs_*xuXAr`g*{)f z#=j5lg^zu0s=Od1!b?90RS}pBTzj%#T;0;HuAbC@*a0^$hSV%jstZM!zsZuyw^~wJvBWF9Yb@~!TCuRCT2!yGq*^$ITX;paXHH|~Ay&G0JdYJB z`QY`BwBO%-qMz5#8%wc4DH^e=xmsTW1bShL)p?V?=qHg8!Xq<{@7*g}F)H|q}E^QGdJRM2< zHz_*nSelQ9{jszTj^jwkW`@SLkENq1J3*p^Q*JXd2ybrgA#v;E(G0HLS1Jly^xxf$&m0`7(Sas!zl;l zo~PX~6u*G>s@V_^2>XD*0f9w-FfZVO>O8gHnGfJ_;&$rN%=;s> zp?NC0$I>dptP c t d v m', d=4) + temporal_indicies = np.random.choice(16, 6, replace=False) + out = input_data.copy() + out[:, temporal_indicies] = 0 + out = rearrange(out, 'c t d v m -> c (t d) v m') + return out + + +def Shear(input_data): + Shear = np.array([ + [1, random.uniform(-1, 1), random.uniform(-1, 1)], + [random.uniform(-1, 1), 1, random.uniform(-1, 1)], + [random.uniform(-1, 1), random.uniform(-1, 1), 1] + ]) + # c t v m + output = np.dot(input_data.transpose([1, 2, 3, 0]), Shear.transpose()) + output = output.transpose(3, 0, 1, 2) + return output + + +def Flip(input_data): + order = [0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 16, 17, 18, 19, 12, 13, 14, 15, 20, 23, 24, 21, 22] + output = input_data[:, :, order, :] + return output + + +def Rotate(data): + def rotate(seq, axis, angle): + # x + if axis == 0: + R = np.array([[1, 0, 0], + [0, cos(angle), sin(angle)], + [0, -sin(angle), cos(angle)]]) + # y + if axis == 1: + R = np.array([[cos(angle), 0, -sin(angle)], + [0, 1, 0], + [sin(angle), 0, cos(angle)]]) + + # z + if axis == 2: + R = np.array([[cos(angle), sin(angle), 0], + [-sin(angle), cos(angle), 0], + [0, 0, 1]]) + R = R.T + output = np.dot(seq.transpose([1, 2, 3, 0]), R) + output = output.transpose(3, 0, 1, 2) + return output + + # c t v m + new_seq = data.copy() + total_axis = [0, 1, 2] + main_axis = random.randint(0, 2) + for axis in total_axis: + if axis == main_axis: + rotate_angle = random.uniform(0, 30) + rotate_angle = math.radians(rotate_angle) + new_seq = rotate(new_seq, axis, rotate_angle) + else: + rotate_angle = random.uniform(0, 1) + rotate_angle = math.radians(rotate_angle) + new_seq = rotate(new_seq, axis, rotate_angle) + + return new_seq + + +def temporal_cropresize(input_data,num_of_frames,l_ratio,output_size): + + + C, T, V, M =input_data.shape + + # Temporal crop + min_crop_length = 64 + + scale = np.random.rand(1)*(l_ratio[1]-l_ratio[0])+l_ratio[0] + temporal_crop_length = np.minimum(np.maximum(int(np.floor(num_of_frames*scale)),min_crop_length),num_of_frames) + + start = np.random.randint(0,num_of_frames-temporal_crop_length+1) + temporal_context = input_data[:,start:start+temporal_crop_length, :, :] + + # interpolate + temporal_context = torch.tensor(temporal_context,dtype=torch.float) + temporal_context=temporal_context.permute(0, 2, 3, 1).contiguous().view(C * V * M,temporal_crop_length) + temporal_context=temporal_context[None, :, :, None] + temporal_context= F.interpolate(temporal_context, size=(output_size, 1), mode='bilinear',align_corners=False) + temporal_context = temporal_context.squeeze(dim=3).squeeze(dim=0) + temporal_context=temporal_context.contiguous().view(C, V, M, output_size).permute(0, 3, 1, 2).contiguous().numpy() + + return temporal_context + +def crop_subsequence(input_data,num_of_frames,l_ratio,output_size): + + + C, T, V, M =input_data.shape + + if l_ratio[0] == 0.5: + # if training , sample a random crop + + min_crop_length = 64 + scale = np.random.rand(1)*(l_ratio[1]-l_ratio[0])+l_ratio[0] + temporal_crop_length = np.minimum(np.maximum(int(np.floor(num_of_frames*scale)),min_crop_length),num_of_frames) + + start = np.random.randint(0,num_of_frames-temporal_crop_length+1) + temporal_crop = input_data[:,start:start+temporal_crop_length, :, :] + + temporal_crop= torch.tensor(temporal_crop,dtype=torch.float) + temporal_crop=temporal_crop.permute(0, 2, 3, 1).contiguous().view(C * V * M,temporal_crop_length) + temporal_crop=temporal_crop[None, :, :, None] + temporal_crop= F.interpolate(temporal_crop, size=(output_size, 1), mode='bilinear',align_corners=False) + temporal_crop=temporal_crop.squeeze(dim=3).squeeze(dim=0) + temporal_crop=temporal_crop.contiguous().view(C, V, M, output_size).permute(0, 3, 1, 2).contiguous().numpy() + + return temporal_crop + + else: + # if testing , sample a center crop + + start = int((1-l_ratio[0]) * num_of_frames/2) + data =input_data[:,start:num_of_frames-start, :, :] + temporal_crop_length = data.shape[1] + + temporal_crop= torch.tensor(data,dtype=torch.float) + temporal_crop=temporal_crop.permute(0, 2, 3, 1).contiguous().view(C * V * M,temporal_crop_length) + temporal_crop=temporal_crop[None, :, :, None] + temporal_crop= F.interpolate(temporal_crop, size=(output_size, 1), mode='bilinear',align_corners=False) + temporal_crop=temporal_crop.squeeze(dim=3).squeeze(dim=0) + temporal_crop=temporal_crop.contiguous().view(C, V, M, output_size).permute(0, 3, 1, 2).contiguous().numpy() + + return temporal_crop diff --git a/feeder/feeder_downstream.py b/feeder/feeder_downstream.py new file mode 100644 index 0000000..588aaa5 --- /dev/null +++ b/feeder/feeder_downstream.py @@ -0,0 +1,101 @@ +# sys +import pickle + +# torch +import torch + +import numpy as np +np.set_printoptions(threshold=np.inf) + +try: + from feeder import augmentations +except: + import augmentations + + +class Feeder(torch.utils.data.Dataset): + """ + Arguments: + data_path: the path to '.npy' data, the shape of data should be (N, C, T, V, M) + """ + + def __init__(self, + data_path, + label_path, + num_frame_path, + l_ratio, + input_size, + input_representation, + mmap=True): + + self.data_path = data_path + self.label_path = label_path + self.num_frame_path = num_frame_path + self.input_size = input_size + self.input_representation = input_representation + self.l_ratio = l_ratio + + self.load_data(mmap) + self.N, self.C, self.T, self.V, self.M = self.data.shape + self.S = self.V + self.B = self.V + self.Bone = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), (7, 6), (8, 7), (9, 21), + (10, 9), (11, 10), (12, 11), (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), + (18, 17), (19, 18), (20, 19), (21, 21), (22, 23), (23, 8), (24, 25), (25, 12)] + + print(self.data.shape, len(self.number_of_frames), len(self.label)) + print("l_ratio", self.l_ratio) + + def load_data(self, mmap): + # data: N C V T M + + # load data + if mmap: + self.data = np.load(self.data_path, mmap_mode='r') + else: + self.data = np.load(self.data_path) + + # load num of valid frame length + self.number_of_frames= np.load(self.num_frame_path) + + # load label + if '.pkl' in self.label_path: + with open(self.label_path, 'rb') as f: + self.sample_name, self.label = pickle.load(f) + elif '.npy' in self.label_path: + self.label = np.load(self.label_path).tolist() + + def __len__(self): + return self.N + + def __iter__(self): + return self + + def __getitem__(self, index): + + # get raw input + + # input: C, T, V, M + data_numpy = np.array(self.data[index]) + number_of_frames = self.number_of_frames[index] + label = self.label[index] + + # crop a sub-sequnce + data_numpy = augmentations.crop_subsequence(data_numpy, number_of_frames, self.l_ratio, self.input_size) + + if self.input_representation == "motion": + # motion + motion = np.zeros_like(data_numpy) + motion[:, :-1, :, :] = data_numpy[:, 1:, :, :] - data_numpy[:, :-1, :, :] + + data_numpy = motion + + elif self.input_representation == "bone": + # bone + bone = np.zeros_like(data_numpy) + for v1, v2 in self.Bone: + bone[:, :, v1 - 1, :] = data_numpy[:, :, v1 - 1, :] - data_numpy[:, :, v2 - 1, :] + + data_numpy = bone + + return data_numpy, label \ No newline at end of file diff --git a/feeder/feeder_pretraining.py b/feeder/feeder_pretraining.py new file mode 100644 index 0000000..de4be9e --- /dev/null +++ b/feeder/feeder_pretraining.py @@ -0,0 +1,126 @@ +import time +import torch + +import numpy as np +np.set_printoptions(threshold=np.inf) +import random + +try: + from feeder import augmentations +except: + import augmentations + + +class Feeder(torch.utils.data.Dataset): + """ + Arguments: + data_path: the path to '.npy' data, the shape of data should be (N, C, T, V, M) + """ + + def __init__(self, + data_path, + num_frame_path, + l_ratio, + input_size, + input_representation, + mmap=True): + + self.data_path = data_path + self.num_frame_path = num_frame_path + self.input_size = input_size + self.input_representation = input_representation + self.crop_resize = True + self.l_ratio = l_ratio + + self.load_data(mmap) + + self.N, self.C, self.T, self.V, self.M = self.data.shape + self.S = self.V + self.B = self.V + self.Bone = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), (7, 6), (8, 7), (9, 21), + (10, 9), (11, 10), (12, 11), (13, 1), (14, 13), (15, 14), (16, 15), (17, 1), + (18, 17), (19, 18), (20, 19), (21, 21), (22, 23), (23, 8), (24, 25), (25, 12)] + + print(self.data.shape, len(self.number_of_frames)) + print("l_ratio", self.l_ratio) + + def load_data(self, mmap): + # data: N C T V M + + # load data + if mmap: + self.data = np.load(self.data_path, mmap_mode='r') + else: + self.data = np.load(self.data_path) + + # load num of valid frame length + if self.num_frame_path != None: + self.number_of_frames = np.load(self.num_frame_path) + else: + self.number_of_frames = np.ones(self.data.shape[0], dtype=np.int32)*50 + + def __len__(self): + return self.N + + def __iter__(self): + return self + + def __getitem__(self, index): + + # get raw input + + # input: C, T, V, M + data_numpy = np.array(self.data[index]) + + number_of_frames = self.number_of_frames[index] + + # temporal crop-resize + data_numpy_v1 = augmentations.temporal_cropresize(data_numpy, number_of_frames, self.l_ratio, self.input_size) + data_numpy_v2 = augmentations.temporal_cropresize(data_numpy,number_of_frames, self.l_ratio, self.input_size) + + + if self.input_representation == "motion": + # motion + motion_v1 = np.zeros_like(data_numpy_v1) + motion_v1[:, :-1, :, :] = data_numpy_v1[:, 1:, :, :] - data_numpy_v1[:, :-1, :, :] + motion_v2 = np.zeros_like(data_numpy_v2) + motion_v2[:, :-1, :, :] = data_numpy_v2[:, 1:, :, :] - data_numpy_v2[:, :-1, :, :] + + data_numpy_v1 = motion_v1 + data_numpy_v2 = motion_v2 + + elif self.input_representation == "bone": + # bone + bone_v1 = np.zeros_like(data_numpy_v1) + for v1, v2 in self.Bone: + bone_v1[:, :, v1 - 1, :] = data_numpy_v1[:, :, v1 - 1, :] - data_numpy_v1[:, :, v2 - 1, :] + bone_v2 = np.zeros_like(data_numpy_v2) + for v1, v2 in self.Bone: + bone_v2[:, :, v1 - 1, :] = data_numpy_v2[:, :, v1 - 1, :] - data_numpy_v2[:, :, v2 - 1, :] + + data_numpy_v1 = bone_v1 + data_numpy_v2 = bone_v2 + + if random.random() < 0.5: + data_numpy_v1 = augmentations.Rotate(data_numpy_v1) + if random.random() < 0.5: + data_numpy_v1 = augmentations.Flip(data_numpy_v1) + if random.random() < 0.5: + data_numpy_v1 = augmentations.Shear(data_numpy_v1) + if random.random() < 0.5: + data_numpy_v1 = augmentations.spatial_masking(data_numpy_v1) + if random.random() < 0.5: + data_numpy_v1 = augmentations.temporal_masking(data_numpy_v1) + + if random.random() < 0.5: + data_numpy_v2 = augmentations.Rotate(data_numpy_v2) + if random.random() < 0.5: + data_numpy_v2 = augmentations.Flip(data_numpy_v2) + if random.random() < 0.5: + data_numpy_v2 = augmentations.Shear(data_numpy_v2) + if random.random() < 0.5: + data_numpy_v2 = augmentations.spatial_masking(data_numpy_v2) + if random.random() < 0.5: + data_numpy_v2 = augmentations.temporal_masking(data_numpy_v2) + + return data_numpy_v1, data_numpy_v2 \ No newline at end of file diff --git a/graph/__init__.py b/graph/__init__.py new file mode 100644 index 0000000..ede2cee --- /dev/null +++ b/graph/__init__.py @@ -0,0 +1,3 @@ +from . import tools +from . import ntu_rgb_d + diff --git a/graph/__pycache__/__init__.cpython-310.pyc b/graph/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b33f85c232a6d841394daa5204c3beaeae3dc038 GIT binary patch literal 214 zcmd1j<>g`kg1)}A6kQSUb>k6b4|h!F8khAYRJQt*WA&pi4Fu>SAQ0V6Xs7%^jM9{C&K62EIt<>T#SpuB6iQjREA=>^*ZO4&{o>3X L$4%Jof&TA5I*3R8 literal 0 HcmV?d00001 diff --git a/graph/__pycache__/ntu_rgb_d.cpython-310.pyc b/graph/__pycache__/ntu_rgb_d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c691a26a476f7489c3dd12a30ae620d513d6c02 GIT binary patch literal 1690 zcmZuxOK;pZ5GJXoch?)obs9G<&=h^DkCk(2FGi3e4RZ4(Ko3F-L0-~p*;;8c;x@Li zZh_jzUV7-==`DYW*Piw-^wQ3dyYadW1&%%rsgE-shicZ>I}F<&za7Y5e8&Ew!D{nx z@CbYRGaAVxuUO{ZJQIo?c~`y^BmRCG_)ws+9CP1eezS0j-Z$vrnR2tFxt=MU~+0YyH(lh zVKbgQj41$$W{UygWIGXaigZrSL;y)Yzhm*7godue6dCj!4p#gj;lA@ngl{=ax-zL0RPtI?s(vR@)H5>MRY((qYgHA(&5HAl9#f7jJ>^&#-3yE2XW5N#f(d~&1Qn^@=OzM z1GA_2$^)w$NoM%b_3?rgjufNJ?eX0I|Guy5hr+b`#*2M%=jDDYq#xk2mEr3>NKw1T zHX9i?JKswp6W7ymUSn;WwN_OJc~#cOCM@e0 z1xOP$Q*A+9t8#R>-+*#|0$`i2%AK`aN3}4_2NI!F%@>!^x+p7&9)8r>)J|*)RE$e(=R8#c? zpn=4LFUkKE8B1iZt~Is>KF0~VAU3@i(~J2A_r36+vwJSRH%YjMg=~WhNTa$Kt2{TI zoJ5>fgnN1Zd|Fg)2I#nfHd+z*I)V>zg05Ukdc~~6`!SF&xn;cev^4=9I{vmqJuyBO z)}aa2iB`2Vfm48#u`-wdih8818rg=cioAgm4^gUcsIZCuLtt>fB*_RUBOgq=fdwHg_-W261 z|57MdP4(qt(U;ydy-l06MR#bM>U5Xhp?B#$`WgM4enI!>eY#J-qzCjXdPqC;0X?F} z&5Bo`YgS6ZN|>RlZOzzLQO&60qVQ&^?3Gx1fr?b3IV#fvU8V{x(@Jy3yFhEH9i`Nc zIr@-3q9^n*{hEG5pU`jVceL9qc=L4Kj-0pB%d~D=W!t)_%IYHPTA-qL$!eOV3pC%% zd6(%%D*sX{e~~_=&*=B`2YO1MH}hVF-moJoR?-sPw5=uET2@PHndPoL5p;_-UJBai zJKn0gLbU_uJE{L+eSWod&8u2!&0ABe^k!eSuJ?tvesn{vKNerDdN)+H^@h4ZZ_({< z9VxhVvoEChPCono3!!eQb-P3Isi^&NfbHFVp4r^os!5Yey4`j>z& zU%bH36er@V;tP4g@7K=XoM+jm;_pIZOwBRQ<{cwL+p_S(`&;|nj@mjtYV50MD?U(d zm2@LNXe1&3v}$ylQ3!ggse|7BR+Jq2y1D1mX7Bii!u@s_Cyj2W_h1_OVN70Vmwy8o zH~?dC36KH6h(QjJ2TTF%V9XvA0DP#$pa_@+lmHh1bAWk38E_GR^$sor76BE8v16bR zs;pzGOjyTMzI0y57tVxf$sZQ29(se4ZHlazcK{x~qVXO-dIIHXb+W1ppHRp6ir`*%!z?ElJ{A4ct(S>2} z&eQrjoQNZGR@YDL#MN_DynZKM{SmFN!->}>bTK)8u*!%0#Z?9+JdQ8VA(t|geJQ;R zO;etxn95R)^IE~=o*o|ni$pTY#9?%j3}YtA zGV<5Top4!WI-g9n++;c_Br{2|kKS3!DJ2(@xn$mQ%FMZ#WE0Vcqy=tWVhIZ*lgq6| z=E?=J$jE&a?pr$HQLP0#b{W)Y)S_T&WyHb1wMxzs&s^lAqvM{l}0zRmomFjC-eo0+i+Zn#4wOiib1m~?iwCuYZdd6JcBzFHC3H$ z@!TDMFN^|xY?h<0?gZ`dxq8)HC+eAuI#wpX!)w=_q1^|5+z8rgXt!4wN%B82nfS0{ za`Ar9V>g8ng$)(^+g^6Zx8|n4VNPym+h=>mW9qex$zsuFwjJy#c5c(}z}Y#g>P$v8 zuJ4CjHhq7cof#u72$>(hMOo&zBv<8}b5B;}^OB8*&C#UY2KE3CB4J#_>fAu#I&gJ~ z&kdgqZpi~jZzQtC5y59Vm*o4|7f!@+&=-G^Ck~$r^ak1z%LCLC?8u_f|4ci@Y`a%G{c(SqPk z5JtXlHJU6R<72)>WgDHK$7a}KkH})$1{kg$Ok*`VrEv$b4CbTYkN7Camce{nOf(ow z<9=cpOyim|8jstGWiTI=R)fJbu11!@G-~0{EsjR5vJB?q9$^_w_-_y@vR z&MW>&`IG*!c@Ba{Hl(z*A*HPiDLu}MBRkPq+u&F~j!lk}R_|EfW0uKst=Ttub(E-x zOwMMtVK^WXOil+;Q`vy<=F*hb)gcEmsiV^n(Qoo#y$N^-*kbU?XUF^7`VM-x0Xu-Z zfDZugFq|`^t`SW7J@ghvY1*&bv#dX6u9@e=*lTx_cDTp4Z%`(m{xiLHVw`R~43gUr zgbK<9pcC(Rzv0{k%=L5_CFjMEaqka9?*ZHyBi>xcbxN}6WNof0$|7ehye@N=nO7ON UO5B=DTcs-nS&(j~$cWMZ1=hdb#{d8T literal 0 HcmV?d00001 diff --git a/graph/__pycache__/tools.cpython-310.pyc b/graph/__pycache__/tools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa2e6c1b1b22d0bfd598edb8899072152cb19327 GIT binary patch literal 2565 zcmZuz$!-)!7_O>bXE6-MU=zr$ES+s)AP_}Xf)x^p9Gn9vhgPCdGu1QhS$c-*9*j|q z6vZ6FA&1=UBTtd{s8ghz{0g}w-(Nj8*ifUcrK`KT{%`qfCQD0ohVtj{J^4q-*grIQ zu{apq!d3i)N;1huEap;3=Lw62bft&Zk-iMjx-yhCw4SWX23lV>OYIQ zJNG2lj&`-z6C3P7O*>|PCilNXG|?&)qlTi zkMpwKONz3cPwpq(;cVKT6{#xP-JQ?ZZ>-;JcPBD!fA=tH_f<08ZkLnEsQ7HUYa0Eu z>=gZJXPlG-LqV*Z;Z1&quW*;Evlw;-s-xneLzl(p6Lc&-e1VV9Y=Mt6cY%)pKF$Ik z0W8FpqQxSiXbnX`N1g)z%!Cwbj%}(MP0>D7gDQ%!s?MNi!c_LttJp3*QP8W^cCp!W z7`(Ju7dP!@@1dh1q@hs02L08>L8?GcXg&aSml%;;x|;2A3S^G$Yu;i+ymoG~xp=T@ z!xB;q-3_KGD5na_5{zglGE=v7c1GDSHBP#l8n4(+rf5`>_tVJ3xSNcsSYxpXrZ+p& z$xf;aAD)i*9C0-4*cUpMzZO1GETam+QJuHISBSPrjIH5zm$Y_NNYTfP-N&F(wB!e3 z4rLDn)GQD*+jkHw*a8GDAaE=M%X3e2?Oh_^T=n6iFU~Vk{zCf~@vaAtxLC2$JRqSh z9R7zA`4!o;54A!S(XK9_GLD>-QD7O5#qE*d+lKF;ms|b;%~Y$WnFZrA z7P4S80Rd^D3F(SoEf}w3u1g9yDx{ZVFuuZ|gtfso3|(^cA>ci+$LF3MIg62B3V=fa z`#lHn1Uk6Op7E{7wVG4(nTmI>C}F=SR78=f73ru45x}~BmhU7=ngDLL9e&4%J9+H< zI4di~pVGNyy_hCtmW(Ra-^Y^HpA_a9unxB0!@U!1Q?SR9-@rh#eQ<|O^OBRyr2tC; zeUdrf!;BUyBya@VV%lke%aRZL1Im!#@e0<9hvr3nilaAIt7y3!%p!)m2ry*rOASFkjfsNX8g71j3;uC zbkn@MyLj|dtZq{PX@42CtB-K!dz)iU&`ydY>TM#>=zwa5D8-&A!XMagz-F+pJsg3| z{uF{BtdK0$w2NfH_kBDQ?1!ZDRfy&QdH#TKJau0>Ypj}EWAKhp8;}jev%-)*;25Cl zZbjU9lCZKisd8e?)UzVXA)vgQR%!}znJAU)T8z~P*thx^l?ldKo{eW?<3n4?sBfAT zbX5w0;cNX;N&ylx!Nq&N0f@@k~LEZ2?+JHEa9| zcmH#R_0~EjPJm1anbcuH9-#Arw-(NVuOKn&H^8O=G@KWp=i6Rb3hi0<3GJhW?{t7R zKpUbBOK0D$e5eB*_B_bCrfc?KIkX3hR~eG__S}`~O$?*D^`v@>rrxIN9jXWiBR1D@ zT?HQ3a9w|;AFogr9?wQ)Rwd#S*8B#m6ceTv37gW3B_xLI z!y&CI1Hq7YnR)|LI4y00q2mpfqRe8Wh{HV1`rG%xrsexp31pUMJ@S{$lv{YSpf4{| z>bOD6iOMWd!+1Dy&A%y_5x)BC<=oU{||=>X6Mo3`6hLHQCgNW04xR6r|e#DdgIdH_)rfsDyGi4(_3#}2z& zITz9+{UQ7c4*ZKbA#vIhcP{X}<1}rzwKbo1#`8WO@8tH@R{gi|&!7JB8T*?(cUkD) zMajQIC7I+Y>v1Wh^_ca9w55aAlCJd7+R~Rbw2rLH23l7(H%oEEer*3BmiE$- zytGQYbOz#(J#0TXeY0(~~{lt2WE#1<^SDfw{$6AW2Zv&z2=p@OtKZ)bHOs08Y zX_3vf^C(ucTwD2K8hAZ7TTH_w(;l_6MWO99$tIe2G*2`iYqt|-Gd2C4sco$K@14ET zY#Q&Kp7cj?wwF)hG%jXY81;(;jf(rTVU|#TZ>XaAXs?*f(tLM*svE<&2=n1QoJIwq zkk?T$zKy%Z+g!CU>I<5~<*bW=XCDIt4`0J001sQkV=ZmKBLI)Jfkyxe(N(lqAQY{o z2UK^sA^J00an!o^t2z#VSIq?(rYn&iELtU1x^lTZLVsM}aJ{{gSuX(n2WK9kvt?_a9&tjD!>u`;e6s zlvM>~_(e1nnXVf)!!()1+KNwO?c}3qjz&e^N4MPtHet zkvN)!<_DqSk66hmS|)sM^E$ZN=00!YR@X7xC#lgzf~Nf#5C-U&z!AbqEFtW%fRx3O zgli#Ga0LkLlYKy783^84Iwdci>ja&xcGrFDG9%%yrN8@tR&`(#8>?25x4Q-nL-yH! zxB|CgoHnGcP(|FUtEjXkXGP!{+I!+C)qJG+F?yL10nj;m2@%qMlqQejP$mYEn^-6( zSQs#<8~rx<4&m`S=K3TAx=1b?QGA9$0ds?H7`ou-3+e*{$BeA?$Snl$p@`i9a9Mh# z2ikczu#M!*Co0;Vyomafz*iTswyx!II)DaXTsO&%BPF#5FPjd(r^UUjXWd^EmEh0G z+>oBnqaumYiuCueS=-9Ynn^auDdR2s26M!Bku7qpeenRcVRIEC81 zK{|_y+Xn!O+}dlBze5=hRUGHNEp&1uxnIKkp(=VUB9$KIl#Enza-Q`Y&sdM-IO@k) z|8)K8E>>UXf;vdQ#MS$l`P%4jLp~{viovHu2ZS?1W+n1gVf{)d*#(3QT-ex^{0U?M z)C97|erY3X@Ue&Dk833MYmm+gxqd~6o|~7}4unL{b{H%q)Ct4`3AGur2#f=K?QX!e zBMGNxUJ2^Dp5#de8D;&r(v+9UNU7*lZ+G%MR+PoRQfj6Q-ilPAK0u|tX_6(=#ZJ~kEld89`}k!}-(+N! zPtgXVGE30V4i4h^X_c-iqgD{ou~L1Va*nZ=q1(h_mUzytW&_n}eK?)TMH+ujCz6qk SFMQ8;{ifeQ>(v|nh5rGMTpQ#7 literal 0 HcmV?d00001 diff --git a/graph/__pycache__/ucla.cpython-310.pyc b/graph/__pycache__/ucla.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d545dedef9359883b5154ec309b05e8a72f3eac GIT binary patch literal 1654 zcmZux&2HQ_5GJW#Yky)pPUEIU>-1kJvT}f;xdcWEw@6RE1n5CXAjl%UmaUaGB!$?< zI)~aX&^K^;%PaBP)4oD4?F{9OH*F|z^l?bda6W!iGZ;h!+FJC8E0>VJaIoAQ7(9Wl zUW4I;(~6|so2E>VBj?KHY)s!!ie9JkCVuJzcGG}++`l9u6u|}knOx9xWlVVRCru+B z+L_3Ddwj)?dUn(oJ<-2Q(!o>0Bi_3tym!IUA*ats+@Hf9$ljSV8YM9`o^F~-%{RbT zP@K7L4Bk)**@dofc1&jEg?~GiDOWpCiZM$Nk7&hrK22BOY}A0P(QH7*O8o@CxGVZ!xA^ z>-^}F>mBYRzTIIAC-Ec5i{a#*4r4g^QHNI%-|aAlTdc=$^5f2r;p9DwC6*yS0rTEz zpckvLr?CaVo<{#UJE!MtMmgCeGlq6?GBLG@>{+~Zq-6prF!bo=kfbF78Nf;`#asW+ z`%VR$i@f1=UdX)2w-oeKShh5LmjfBHIW|cF@V$ohJ%O&azzC8;MJWS;JKVjbse|I? z0U(!qAmPY(Pf;u;P}5wOd1Y%YFvy}I#=Z+(jlgJ-Z$>l)OEm+_v;)@BE?DSS-VE*^ zV4o8ZHrLr`Rzr1~7vv-xm;NoJ}yQr;5B1|KZAnSytAi&N7sRLcRHP@vPFbp)Gy1 zShQE*C0H~p8IRePy$t}h(GnTV$PttUBYMo_9Zj#GeC=va`)^!O5Zq-QT#(;s%=YBx z@Le3rZ{Z~|xLL!xuBp!KHI?^q@(VOyqWKD!MZe`$D}Ir(k+CDcLw}6M2D10k8Y_rz zU;kFa2kgihK@&7ry*VXqSD&eZ*-a2J(0g63jtIj7P%S@CZt8In2pJgxGyy|A; z_xKTJT%!7BE1$sx-bRJQg(RZGP@*rsWgK;?j0Y-!9}=sl#s#?@8eg1fQFG&UR3Jr` z597(a9t%s0S?y$!HL`@`eZ34M3Pnd02+IG=JR=sCr*=Z0GJagvG%ze$hd@?sz{U!SIr{6us{$?LM z%fz44SNuyLndB)8CnFDqVm+tlO3~pX&LMSDukVLGaW||;Px@y}1*$UQFW88OGxCp^2^RrSd|<0Y}KBdmDlWS*3N2bR@K&6 zIQNptO}TZ(OSb-&+V5En5ngR9)T}@!s-?g)4_UYUKrcy&0~m z#on4)mD}>>uYxn0t&bRI@7Td#KQMJwt=O3OE7tgGOtG`wW2f0{HMsH0JWq=;GSlW* zp-0f?()TTW1y(xjguSkv@)LUBh~LD4`-=U+H1;%v5zR-2Cw6Ax{nxE--dC-o!**9? ztzuuLs?4(}ZkGuSTD9{|mVmc9Iv#Xe!*&`s2S@L#Pt&9*+j)QRYzC&`htG6C!x+Ae zAs9OVm>$#hvat`S0PH>-99IFefEr*9PzPKA%mWqxi-0A9i_oDQYHjG-L~Q8mKZ;X+ zDo(g$>+D1Tr)g+#4U>~xI2va)Tr#|OnGp^DY#i|7Jg)!Qyh!&tZLl`y&ux;`sjW+e zjSnp8O(=(g30&Ahuv1$^FCN7BoW5e4K(Uab3WW@$Bi%C|I&wz(vO>(29@R_LcrS5F zrcw;zGKo_gZyQSRV8pf9gf)qjR4WlHIUz3vp(q_fr*sMF>v<>CV8ngNNIl& z9OFWPnM>n6MOr#hA6mVl9jEG}9GM`B5~{-}LWvis6vq9F_qcAuZcaDtwh<;CgnUsW zk7Ub>}M`Xl#LWPVv0v7h;gps>ZRzD_3_fxbh7hHA5b zsLqod+0M{6u(%Dl3AjaIs^7=yP<^L$uBi;}LVgcm6@63&JG6lEIFBx%=mNede15z! ziSAR0YzXU{y=UByhfHO0Uqz9rM#$(e1+PWX!7xrIE4l$6VDE!8TW@SVZqWi7t$--t z)j&f6#r-(RqR3h_E=7GrKP9)_j|b$2?JP7iaSR$gO&oVM75P_)knsE6dui4kR literal 0 HcmV?d00001 diff --git a/graph/ntu_rgb_d.py b/graph/ntu_rgb_d.py new file mode 100644 index 0000000..8c53132 --- /dev/null +++ b/graph/ntu_rgb_d.py @@ -0,0 +1,33 @@ +import sys +import numpy as np + +sys.path.extend(['../']) +from graph import tools + +num_node = 25 +self_link = [(i, i) for i in range(num_node)] +inward_ori_index = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), (7, 6), + (8, 7), (9, 21), (10, 9), (11, 10), (12, 11), (13, 1), + (14, 13), (15, 14), (16, 15), (17, 1), (18, 17), (19, 18), + (20, 19), (22, 23), (23, 8), (24, 25), (25, 12)] +inward = [(i - 1, j - 1) for (i, j) in inward_ori_index] +outward = [(j, i) for (i, j) in inward] +neighbor = inward + outward + +class Graph: + def __init__(self, labeling_mode='spatial'): + self.num_node = num_node + self.self_link = self_link + self.inward = inward + self.outward = outward + self.neighbor = neighbor + self.A = self.get_adjacency_matrix(labeling_mode) + + def get_adjacency_matrix(self, labeling_mode='spatial'): + if labeling_mode is None: + return self.A + if labeling_mode == 'spatial': + A = tools.get_spatial_graph(num_node, self_link, inward, outward) + else: + raise ValueError() + return A diff --git a/graph/tools.py b/graph/tools.py new file mode 100644 index 0000000..76e7cbe --- /dev/null +++ b/graph/tools.py @@ -0,0 +1,80 @@ +import numpy as np + +def get_sgp_mat(num_in, num_out, link): + A = np.zeros((num_in, num_out)) + for i, j in link: + A[i, j] = 1 + A_norm = A / np.sum(A, axis=0, keepdims=True) + return A_norm + +def edge2mat(link, num_node): + A = np.zeros((num_node, num_node)) + for i, j in link: + A[j, i] = 1 + return A + +def get_k_scale_graph(scale, A): + if scale == 1: + return A + An = np.zeros_like(A) + A_power = np.eye(A.shape[0]) + for k in range(scale): + A_power = A_power @ A + An += A_power + An[An > 0] = 1 + return An + +def normalize_digraph(A): + Dl = np.sum(A, 0) + h, w = A.shape + Dn = np.zeros((w, w)) + for i in range(w): + if Dl[i] > 0: + Dn[i, i] = Dl[i] ** (-1) + AD = np.dot(A, Dn) + return AD + + +def get_spatial_graph(num_node, self_link, inward, outward): + I = edge2mat(self_link, num_node) + In = normalize_digraph(edge2mat(inward, num_node)) + Out = normalize_digraph(edge2mat(outward, num_node)) + A = np.stack((I, In, Out)) + return A + +def normalize_adjacency_matrix(A): + node_degrees = A.sum(-1) + degs_inv_sqrt = np.power(node_degrees, -0.5) + norm_degs_matrix = np.eye(len(node_degrees)) * degs_inv_sqrt + return (norm_degs_matrix @ A @ norm_degs_matrix).astype(np.float32) + + +def k_adjacency(A, k, with_self=False, self_factor=1): + assert isinstance(A, np.ndarray) + I = np.eye(len(A), dtype=A.dtype) + if k == 0: + return I + Ak = np.minimum(np.linalg.matrix_power(A + I, k), 1) \ + - np.minimum(np.linalg.matrix_power(A + I, k - 1), 1) + if with_self: + Ak += (self_factor * I) + return Ak + +def get_multiscale_spatial_graph(num_node, self_link, inward, outward): + I = edge2mat(self_link, num_node) + A1 = edge2mat(inward, num_node) + A2 = edge2mat(outward, num_node) + A3 = k_adjacency(A1, 2) + A4 = k_adjacency(A2, 2) + A1 = normalize_digraph(A1) + A2 = normalize_digraph(A2) + A3 = normalize_digraph(A3) + A4 = normalize_digraph(A4) + A = np.stack((I, A1, A2, A3, A4)) + return A + + + +def get_uniform_graph(num_node, self_link, neighbor): + A = normalize_digraph(edge2mat(neighbor + self_link, num_node)) + return A \ No newline at end of file diff --git a/log/n60_cb_recognition b/log/n60_cb_recognition new file mode 100644 index 0000000..9a67cbd --- /dev/null +++ b/log/n60_cb_recognition @@ -0,0 +1,572 @@ +Epoch: [0][ 0/40] Time 35.689 (35.689) Data 32.362 (32.362) Loss 4.0975e+00 (4.0975e+00) Acc@1 1.46 ( 1.46) Acc@5 8.01 ( 8.01) +Epoch: [0][10/40] Time 1.351 ( 4.476) Data 0.000 ( 2.942) Loss 7.1505e-01 (1.5806e+00) Acc@1 77.15 ( 62.13) Acc@5 96.29 ( 84.71) +Epoch: [0][20/40] Time 1.351 ( 2.988) Data 0.000 ( 1.541) Loss 6.8699e-01 (1.1710e+00) Acc@1 79.98 ( 70.27) Acc@5 96.39 ( 90.35) +Epoch: [0][30/40] Time 1.351 ( 2.460) Data 0.000 ( 1.044) Loss 6.2460e-01 (1.0027e+00) Acc@1 83.11 ( 73.87) Acc@5 97.07 ( 92.46) +Test: [ 0/17] Time 9.444 ( 9.444) Loss 6.6350e-01 (6.6350e-01) Acc@1 80.57 ( 80.57) Acc@5 96.39 ( 96.39) +Test: [10/17] Time 1.351 ( 2.087) Loss 6.8147e-01 (6.6886e-01) Acc@1 82.62 ( 81.63) Acc@5 96.39 ( 96.66) + * Acc@1 81.907 Acc@5 96.864 +found new best accuracy:= tensor(81.9070, device='cuda:0') +Epoch: [1][ 0/40] Time 8.275 ( 8.275) Data 6.908 ( 6.908) Loss 4.7345e-01 (4.7345e-01) Acc@1 86.52 ( 86.52) Acc@5 98.05 ( 98.05) +Epoch: [1][10/40] Time 1.352 ( 1.981) Data 0.000 ( 0.628) Loss 4.7622e-01 (4.9925e-01) Acc@1 86.04 ( 85.10) Acc@5 97.46 ( 97.59) +Epoch: [1][20/40] Time 1.352 ( 1.682) Data 0.000 ( 0.329) Loss 4.5610e-01 (4.7294e-01) Acc@1 86.52 ( 85.74) Acc@5 97.66 ( 97.73) +Epoch: [1][30/40] Time 1.352 ( 1.575) Data 0.000 ( 0.223) Loss 3.7709e-01 (4.5355e-01) Acc@1 87.30 ( 86.14) Acc@5 98.54 ( 97.87) +Test: [ 0/17] Time 6.312 ( 6.312) Loss 5.2210e-01 (5.2210e-01) Acc@1 83.59 ( 83.59) Acc@5 97.17 ( 97.17) +Test: [10/17] Time 1.350 ( 1.802) Loss 5.4265e-01 (5.4488e-01) Acc@1 84.96 ( 84.27) Acc@5 96.78 ( 97.19) + * Acc@1 84.406 Acc@5 97.331 +found new best accuracy:= tensor(84.4059, device='cuda:0') +Epoch: [2][ 0/40] Time 8.661 ( 8.661) Data 7.282 ( 7.282) Loss 3.4213e-01 (3.4213e-01) Acc@1 88.67 ( 88.67) Acc@5 98.73 ( 98.73) +Epoch: [2][10/40] Time 1.351 ( 2.018) Data 0.000 ( 0.662) Loss 3.6840e-01 (3.4258e-01) Acc@1 87.99 ( 89.01) Acc@5 98.24 ( 98.49) +Epoch: [2][20/40] Time 1.351 ( 1.701) Data 0.000 ( 0.347) Loss 3.2181e-01 (3.4519e-01) Acc@1 90.23 ( 88.90) Acc@5 98.83 ( 98.53) +Epoch: [2][30/40] Time 1.351 ( 1.588) Data 0.000 ( 0.235) Loss 4.0143e-01 (3.5110e-01) Acc@1 86.91 ( 88.77) Acc@5 98.14 ( 98.49) +Test: [ 0/17] Time 6.294 ( 6.294) Loss 5.3062e-01 (5.3062e-01) Acc@1 83.89 ( 83.89) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.350 ( 1.800) Loss 5.7867e-01 (5.6785e-01) Acc@1 83.79 ( 83.78) Acc@5 97.36 ( 97.21) + * Acc@1 83.842 Acc@5 97.349 +Epoch: [3][ 0/40] Time 9.246 ( 9.246) Data 7.878 ( 7.878) Loss 3.4537e-01 (3.4537e-01) Acc@1 88.77 ( 88.77) Acc@5 98.83 ( 98.83) +Epoch: [3][10/40] Time 1.351 ( 2.069) Data 0.000 ( 0.716) Loss 2.9641e-01 (3.3695e-01) Acc@1 91.50 ( 89.46) Acc@5 98.54 ( 98.68) +Epoch: [3][20/40] Time 1.351 ( 1.727) Data 0.000 ( 0.375) Loss 3.4254e-01 (3.1933e-01) Acc@1 89.94 ( 90.01) Acc@5 98.14 ( 98.73) +Epoch: [3][30/40] Time 1.353 ( 1.606) Data 0.000 ( 0.254) Loss 3.3116e-01 (3.1419e-01) Acc@1 88.38 ( 90.16) Acc@5 98.73 ( 98.75) +Test: [ 0/17] Time 6.418 ( 6.418) Loss 4.6881e-01 (4.6881e-01) Acc@1 85.45 ( 85.45) Acc@5 97.36 ( 97.36) +Test: [10/17] Time 1.351 ( 1.811) Loss 5.1866e-01 (5.1727e-01) Acc@1 85.74 ( 84.73) Acc@5 97.27 ( 97.28) + * Acc@1 84.934 Acc@5 97.459 +found new best accuracy:= tensor(84.9336, device='cuda:0') +Epoch: [4][ 0/40] Time 9.080 ( 9.080) Data 7.699 ( 7.699) Loss 2.8721e-01 (2.8721e-01) Acc@1 90.82 ( 90.82) Acc@5 98.54 ( 98.54) +Epoch: [4][10/40] Time 1.351 ( 2.054) Data 0.000 ( 0.700) Loss 2.4891e-01 (2.7707e-01) Acc@1 92.19 ( 91.10) Acc@5 99.22 ( 99.00) +Epoch: [4][20/40] Time 1.352 ( 1.719) Data 0.000 ( 0.367) Loss 2.7459e-01 (2.6765e-01) Acc@1 91.50 ( 91.56) Acc@5 98.83 ( 99.03) +Epoch: [4][30/40] Time 1.351 ( 1.601) Data 0.000 ( 0.248) Loss 2.4207e-01 (2.6700e-01) Acc@1 93.07 ( 91.72) Acc@5 99.22 ( 99.01) +Test: [ 0/17] Time 6.436 ( 6.436) Loss 4.6032e-01 (4.6032e-01) Acc@1 85.55 ( 85.55) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.351 ( 1.813) Loss 5.2012e-01 (4.9625e-01) Acc@1 84.38 ( 85.09) Acc@5 97.27 ( 97.44) + * Acc@1 85.291 Acc@5 97.610 +found new best accuracy:= tensor(85.2914, device='cuda:0') +Epoch: [5][ 0/40] Time 9.937 ( 9.937) Data 8.560 ( 8.560) Loss 2.1351e-01 (2.1351e-01) Acc@1 93.85 ( 93.85) Acc@5 99.32 ( 99.32) +Epoch: [5][10/40] Time 1.351 ( 2.132) Data 0.000 ( 0.778) Loss 2.5409e-01 (2.4660e-01) Acc@1 92.38 ( 92.29) Acc@5 99.02 ( 99.09) +Epoch: [5][20/40] Time 1.351 ( 1.760) Data 0.000 ( 0.408) Loss 2.1825e-01 (2.3921e-01) Acc@1 92.48 ( 92.47) Acc@5 99.61 ( 99.20) +Epoch: [5][30/40] Time 1.351 ( 1.628) Data 0.000 ( 0.276) Loss 2.2156e-01 (2.4245e-01) Acc@1 92.58 ( 92.34) Acc@5 99.12 ( 99.17) +Test: [ 0/17] Time 6.366 ( 6.366) Loss 4.7040e-01 (4.7040e-01) Acc@1 85.16 ( 85.16) Acc@5 97.95 ( 97.95) +Test: [10/17] Time 1.351 ( 1.807) Loss 5.1714e-01 (5.0729e-01) Acc@1 85.16 ( 85.15) Acc@5 97.36 ( 97.51) + * Acc@1 85.316 Acc@5 97.628 +found new best accuracy:= tensor(85.3157, device='cuda:0') +Epoch: [6][ 0/40] Time 10.249 (10.249) Data 8.838 ( 8.838) Loss 1.9628e-01 (1.9628e-01) Acc@1 94.43 ( 94.43) Acc@5 99.71 ( 99.71) +Epoch: [6][10/40] Time 1.351 ( 2.160) Data 0.000 ( 0.804) Loss 2.0008e-01 (2.1030e-01) Acc@1 93.95 ( 93.63) Acc@5 99.32 ( 99.30) +Epoch: [6][20/40] Time 1.352 ( 1.775) Data 0.000 ( 0.421) Loss 2.1386e-01 (2.1734e-01) Acc@1 93.85 ( 93.43) Acc@5 99.22 ( 99.22) +Epoch: [6][30/40] Time 1.351 ( 1.638) Data 0.000 ( 0.285) Loss 1.8164e-01 (2.1621e-01) Acc@1 95.02 ( 93.53) Acc@5 99.51 ( 99.24) +Test: [ 0/17] Time 6.352 ( 6.352) Loss 4.6154e-01 (4.6154e-01) Acc@1 86.43 ( 86.43) Acc@5 97.56 ( 97.56) +Test: [10/17] Time 1.350 ( 1.805) Loss 4.9606e-01 (4.8977e-01) Acc@1 86.04 ( 85.96) Acc@5 97.17 ( 97.29) + * Acc@1 86.050 Acc@5 97.477 +found new best accuracy:= tensor(86.0496, device='cuda:0') +Epoch: [7][ 0/40] Time 8.999 ( 8.999) Data 7.622 ( 7.622) Loss 1.7763e-01 (1.7763e-01) Acc@1 95.31 ( 95.31) Acc@5 99.51 ( 99.51) +Epoch: [7][10/40] Time 1.351 ( 2.047) Data 0.000 ( 0.693) Loss 2.0463e-01 (2.0229e-01) Acc@1 94.24 ( 93.88) Acc@5 99.22 ( 99.41) +Epoch: [7][20/40] Time 1.351 ( 1.715) Data 0.000 ( 0.363) Loss 2.0656e-01 (2.0377e-01) Acc@1 92.97 ( 93.72) Acc@5 99.41 ( 99.44) +Epoch: [7][30/40] Time 1.351 ( 1.598) Data 0.000 ( 0.246) Loss 2.5290e-01 (2.0924e-01) Acc@1 92.19 ( 93.60) Acc@5 99.22 ( 99.41) +Test: [ 0/17] Time 6.367 ( 6.367) Loss 4.6642e-01 (4.6642e-01) Acc@1 85.84 ( 85.84) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.350 ( 1.807) Loss 4.9533e-01 (4.9617e-01) Acc@1 86.23 ( 85.72) Acc@5 97.66 ( 97.55) + * Acc@1 85.886 Acc@5 97.732 +Epoch: [8][ 0/40] Time 8.609 ( 8.609) Data 7.219 ( 7.219) Loss 1.7401e-01 (1.7401e-01) Acc@1 95.21 ( 95.21) Acc@5 99.71 ( 99.71) +Epoch: [8][10/40] Time 1.352 ( 2.011) Data 0.000 ( 0.656) Loss 1.9994e-01 (1.8653e-01) Acc@1 94.14 ( 94.66) Acc@5 99.41 ( 99.52) +Epoch: [8][20/40] Time 1.351 ( 1.698) Data 0.000 ( 0.344) Loss 1.9359e-01 (1.8811e-01) Acc@1 94.82 ( 94.62) Acc@5 99.32 ( 99.51) +Epoch: [8][30/40] Time 1.351 ( 1.586) Data 0.000 ( 0.233) Loss 1.8080e-01 (1.8750e-01) Acc@1 94.14 ( 94.63) Acc@5 99.41 ( 99.48) +Test: [ 0/17] Time 6.465 ( 6.465) Loss 4.6426e-01 (4.6426e-01) Acc@1 86.13 ( 86.13) Acc@5 98.14 ( 98.14) +Test: [10/17] Time 1.351 ( 1.816) Loss 4.9818e-01 (4.9181e-01) Acc@1 86.04 ( 85.83) Acc@5 97.56 ( 97.42) + * Acc@1 85.953 Acc@5 97.580 +Epoch: [9][ 0/40] Time 11.221 (11.221) Data 9.825 ( 9.825) Loss 1.5254e-01 (1.5254e-01) Acc@1 96.00 ( 96.00) Acc@5 99.80 ( 99.80) +Epoch: [9][10/40] Time 1.351 ( 2.248) Data 0.000 ( 0.893) Loss 1.5869e-01 (1.6591e-01) Acc@1 96.00 ( 95.49) Acc@5 99.61 ( 99.52) +Epoch: [9][20/40] Time 1.351 ( 1.821) Data 0.000 ( 0.468) Loss 1.4354e-01 (1.6924e-01) Acc@1 96.09 ( 95.29) Acc@5 99.41 ( 99.53) +Epoch: [9][30/40] Time 1.351 ( 1.670) Data 0.000 ( 0.317) Loss 1.6189e-01 (1.7138e-01) Acc@1 95.61 ( 95.21) Acc@5 99.61 ( 99.54) +Test: [ 0/17] Time 6.412 ( 6.412) Loss 4.4322e-01 (4.4322e-01) Acc@1 86.04 ( 86.04) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.811) Loss 4.8196e-01 (4.8424e-01) Acc@1 86.72 ( 86.04) Acc@5 97.56 ( 97.38) + * Acc@1 86.292 Acc@5 97.574 +found new best accuracy:= tensor(86.2922, device='cuda:0') +Epoch: [10][ 0/40] Time 8.679 ( 8.679) Data 7.295 ( 7.295) Loss 1.2514e-01 (1.2514e-01) Acc@1 96.39 ( 96.39) Acc@5 99.90 ( 99.90) +Epoch: [10][10/40] Time 1.352 ( 2.018) Data 0.000 ( 0.663) Loss 1.8374e-01 (1.6020e-01) Acc@1 95.02 ( 95.39) Acc@5 99.61 ( 99.73) +Epoch: [10][20/40] Time 1.351 ( 1.701) Data 0.000 ( 0.348) Loss 1.8136e-01 (1.6220e-01) Acc@1 94.82 ( 95.44) Acc@5 99.61 ( 99.65) +Epoch: [10][30/40] Time 1.351 ( 1.588) Data 0.000 ( 0.235) Loss 1.5710e-01 (1.6173e-01) Acc@1 95.02 ( 95.46) Acc@5 99.80 ( 99.64) +Test: [ 0/17] Time 6.468 ( 6.468) Loss 4.7388e-01 (4.7388e-01) Acc@1 85.55 ( 85.55) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.351 ( 1.816) Loss 4.9233e-01 (4.9981e-01) Acc@1 85.64 ( 85.56) Acc@5 97.46 ( 97.53) + * Acc@1 85.734 Acc@5 97.695 +Epoch: [11][ 0/40] Time 10.308 (10.308) Data 8.916 ( 8.916) Loss 1.4831e-01 (1.4831e-01) Acc@1 96.48 ( 96.48) Acc@5 99.61 ( 99.61) +Epoch: [11][10/40] Time 1.351 ( 2.166) Data 0.000 ( 0.811) Loss 1.7252e-01 (1.6085e-01) Acc@1 95.21 ( 95.58) Acc@5 99.51 ( 99.71) +Epoch: [11][20/40] Time 1.351 ( 1.778) Data 0.000 ( 0.425) Loss 1.8609e-01 (1.5844e-01) Acc@1 94.14 ( 95.67) Acc@5 99.51 ( 99.73) +Epoch: [11][30/40] Time 1.351 ( 1.641) Data 0.000 ( 0.288) Loss 1.4364e-01 (1.5660e-01) Acc@1 96.88 ( 95.76) Acc@5 99.61 ( 99.69) +Test: [ 0/17] Time 6.495 ( 6.495) Loss 4.6515e-01 (4.6515e-01) Acc@1 85.94 ( 85.94) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.818) Loss 4.9548e-01 (4.9606e-01) Acc@1 86.43 ( 85.73) Acc@5 97.56 ( 97.51) + * Acc@1 85.874 Acc@5 97.665 +Epoch: [12][ 0/40] Time 8.899 ( 8.899) Data 7.523 ( 7.523) Loss 1.3309e-01 (1.3309e-01) Acc@1 95.90 ( 95.90) Acc@5 100.00 (100.00) +Epoch: [12][10/40] Time 1.352 ( 2.133) Data 0.000 ( 0.779) Loss 1.4172e-01 (1.4629e-01) Acc@1 96.19 ( 95.89) Acc@5 99.71 ( 99.75) +Epoch: [12][20/40] Time 1.352 ( 1.761) Data 0.000 ( 0.408) Loss 1.2625e-01 (1.4751e-01) Acc@1 96.68 ( 95.92) Acc@5 100.00 ( 99.70) +Epoch: [12][30/40] Time 1.352 ( 1.629) Data 0.000 ( 0.277) Loss 1.4298e-01 (1.4550e-01) Acc@1 96.29 ( 96.10) Acc@5 99.90 ( 99.74) +Test: [ 0/17] Time 6.520 ( 6.520) Loss 4.5639e-01 (4.5639e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.821) Loss 4.9424e-01 (4.9610e-01) Acc@1 86.62 ( 85.99) Acc@5 97.66 ( 97.49) + * Acc@1 86.110 Acc@5 97.586 +Epoch: [13][ 0/40] Time 10.682 (10.682) Data 9.308 ( 9.308) Loss 1.3167e-01 (1.3167e-01) Acc@1 96.48 ( 96.48) Acc@5 99.80 ( 99.80) +Epoch: [13][10/40] Time 1.351 ( 2.199) Data 0.000 ( 0.846) Loss 1.2454e-01 (1.3549e-01) Acc@1 97.07 ( 96.43) Acc@5 99.71 ( 99.77) +Epoch: [13][20/40] Time 1.351 ( 1.796) Data 0.000 ( 0.443) Loss 1.0856e-01 (1.3258e-01) Acc@1 97.46 ( 96.52) Acc@5 99.90 ( 99.80) +Epoch: [13][30/40] Time 1.351 ( 1.652) Data 0.000 ( 0.300) Loss 1.1665e-01 (1.3460e-01) Acc@1 97.36 ( 96.48) Acc@5 99.90 ( 99.77) +Test: [ 0/17] Time 6.345 ( 6.345) Loss 4.5755e-01 (4.5755e-01) Acc@1 86.13 ( 86.13) Acc@5 97.95 ( 97.95) +Test: [10/17] Time 1.352 ( 1.806) Loss 5.0011e-01 (4.9670e-01) Acc@1 86.52 ( 86.03) Acc@5 97.46 ( 97.60) + * Acc@1 86.232 Acc@5 97.744 +Epoch: [14][ 0/40] Time 10.155 (10.155) Data 8.772 ( 8.772) Loss 1.3170e-01 (1.3170e-01) Acc@1 96.00 ( 96.00) Acc@5 99.61 ( 99.61) +Epoch: [14][10/40] Time 1.351 ( 2.152) Data 0.000 ( 0.798) Loss 1.2912e-01 (1.2418e-01) Acc@1 96.68 ( 96.90) Acc@5 99.71 ( 99.76) +Epoch: [14][20/40] Time 1.352 ( 1.771) Data 0.000 ( 0.418) Loss 1.1824e-01 (1.2530e-01) Acc@1 96.78 ( 96.90) Acc@5 100.00 ( 99.76) +Epoch: [14][30/40] Time 1.352 ( 1.636) Data 0.000 ( 0.283) Loss 1.3060e-01 (1.2589e-01) Acc@1 96.58 ( 96.82) Acc@5 99.71 ( 99.76) +Test: [ 0/17] Time 6.832 ( 6.832) Loss 4.5302e-01 (4.5302e-01) Acc@1 86.33 ( 86.33) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.849) Loss 4.8819e-01 (4.9498e-01) Acc@1 86.43 ( 85.93) Acc@5 97.36 ( 97.41) + * Acc@1 86.086 Acc@5 97.580 +Epoch: [15][ 0/40] Time 10.550 (10.550) Data 9.159 ( 9.159) Loss 1.1391e-01 (1.1391e-01) Acc@1 97.75 ( 97.75) Acc@5 99.90 ( 99.90) +Epoch: [15][10/40] Time 1.351 ( 2.188) Data 0.000 ( 0.833) Loss 1.2942e-01 (1.1580e-01) Acc@1 96.88 ( 97.21) Acc@5 99.90 ( 99.88) +Epoch: [15][20/40] Time 1.351 ( 1.789) Data 0.000 ( 0.436) Loss 1.1644e-01 (1.1702e-01) Acc@1 96.88 ( 97.17) Acc@5 99.90 ( 99.88) +Epoch: [15][30/40] Time 1.351 ( 1.648) Data 0.000 ( 0.296) Loss 1.2905e-01 (1.1784e-01) Acc@1 97.27 ( 97.17) Acc@5 99.71 ( 99.86) +Test: [ 0/17] Time 6.464 ( 6.464) Loss 4.6340e-01 (4.6340e-01) Acc@1 86.52 ( 86.52) Acc@5 97.95 ( 97.95) +Test: [10/17] Time 1.351 ( 1.815) Loss 4.8530e-01 (4.9179e-01) Acc@1 86.43 ( 86.04) Acc@5 97.46 ( 97.47) + * Acc@1 86.250 Acc@5 97.653 +Epoch: [16][ 0/40] Time 11.729 (11.729) Data 10.347 (10.347) Loss 8.7147e-02 (8.7147e-02) Acc@1 98.54 ( 98.54) Acc@5 99.90 ( 99.90) +Epoch: [16][10/40] Time 1.351 ( 2.294) Data 0.000 ( 0.941) Loss 1.2203e-01 (1.1019e-01) Acc@1 96.58 ( 97.33) Acc@5 99.80 ( 99.88) +Epoch: [16][20/40] Time 1.351 ( 1.845) Data 0.000 ( 0.493) Loss 1.0802e-01 (1.1416e-01) Acc@1 97.56 ( 97.13) Acc@5 99.80 ( 99.87) +Epoch: [16][30/40] Time 1.351 ( 1.686) Data 0.000 ( 0.334) Loss 1.0128e-01 (1.1475e-01) Acc@1 97.66 ( 97.14) Acc@5 100.00 ( 99.85) +Test: [ 0/17] Time 6.522 ( 6.522) Loss 4.5606e-01 (4.5606e-01) Acc@1 86.33 ( 86.33) Acc@5 97.95 ( 97.95) +Test: [10/17] Time 1.351 ( 1.821) Loss 4.8517e-01 (4.9495e-01) Acc@1 86.33 ( 85.82) Acc@5 97.56 ( 97.59) + * Acc@1 86.104 Acc@5 97.701 +Epoch: [17][ 0/40] Time 9.638 ( 9.638) Data 8.262 ( 8.262) Loss 1.0914e-01 (1.0914e-01) Acc@1 97.46 ( 97.46) Acc@5 99.90 ( 99.90) +Epoch: [17][10/40] Time 1.351 ( 2.105) Data 0.000 ( 0.751) Loss 9.2738e-02 (1.0557e-01) Acc@1 98.14 ( 97.59) Acc@5 100.00 ( 99.90) +Epoch: [17][20/40] Time 1.353 ( 1.746) Data 0.000 ( 0.394) Loss 1.1911e-01 (1.0766e-01) Acc@1 97.56 ( 97.51) Acc@5 99.71 ( 99.85) +Epoch: [17][30/40] Time 1.354 ( 1.620) Data 0.000 ( 0.267) Loss 8.9753e-02 (1.0751e-01) Acc@1 98.05 ( 97.56) Acc@5 100.00 ( 99.85) +Test: [ 0/17] Time 6.313 ( 6.313) Loss 4.6047e-01 (4.6047e-01) Acc@1 86.23 ( 86.23) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.350 ( 1.802) Loss 4.9527e-01 (4.9735e-01) Acc@1 86.62 ( 85.96) Acc@5 97.36 ( 97.48) + * Acc@1 86.201 Acc@5 97.610 +Epoch: [18][ 0/40] Time 9.413 ( 9.413) Data 8.035 ( 8.035) Loss 8.0262e-02 (8.0262e-02) Acc@1 98.63 ( 98.63) Acc@5 99.90 ( 99.90) +Epoch: [18][10/40] Time 1.351 ( 2.084) Data 0.000 ( 0.731) Loss 9.1729e-02 (9.8773e-02) Acc@1 98.34 ( 97.98) Acc@5 99.80 ( 99.88) +Epoch: [18][20/40] Time 1.351 ( 1.735) Data 0.000 ( 0.383) Loss 1.1155e-01 (1.0553e-01) Acc@1 98.05 ( 97.70) Acc@5 99.80 ( 99.87) +Epoch: [18][30/40] Time 1.351 ( 1.611) Data 0.000 ( 0.259) Loss 1.1009e-01 (1.0364e-01) Acc@1 98.14 ( 97.75) Acc@5 99.71 ( 99.86) +Test: [ 0/17] Time 8.387 ( 8.387) Loss 4.6193e-01 (4.6193e-01) Acc@1 86.13 ( 86.13) Acc@5 97.85 ( 97.85) +Test: [10/17] Time 1.351 ( 1.991) Loss 5.0826e-01 (5.0131e-01) Acc@1 85.84 ( 86.01) Acc@5 97.56 ( 97.50) + * Acc@1 86.274 Acc@5 97.653 +Epoch: [19][ 0/40] Time 10.394 (10.394) Data 9.009 ( 9.009) Loss 1.0752e-01 (1.0752e-01) Acc@1 96.97 ( 96.97) Acc@5 99.90 ( 99.90) +Epoch: [19][10/40] Time 1.351 ( 2.174) Data 0.000 ( 0.819) Loss 9.8765e-02 (9.7896e-02) Acc@1 97.66 ( 97.86) Acc@5 99.90 ( 99.90) +Epoch: [19][20/40] Time 1.352 ( 1.782) Data 0.000 ( 0.429) Loss 8.7590e-02 (9.7002e-02) Acc@1 98.54 ( 97.98) Acc@5 100.00 ( 99.90) +Epoch: [19][30/40] Time 1.351 ( 1.643) Data 0.000 ( 0.291) Loss 9.6886e-02 (9.6541e-02) Acc@1 98.14 ( 97.94) Acc@5 99.80 ( 99.90) +Test: [ 0/17] Time 6.566 ( 6.566) Loss 4.7203e-01 (4.7203e-01) Acc@1 86.13 ( 86.13) Acc@5 97.56 ( 97.56) +Test: [10/17] Time 1.351 ( 1.825) Loss 5.0751e-01 (5.0678e-01) Acc@1 85.94 ( 85.85) Acc@5 97.46 ( 97.43) + * Acc@1 86.159 Acc@5 97.634 +Epoch: [20][ 0/40] Time 12.609 (12.609) Data 11.227 (11.227) Loss 9.2383e-02 (9.2383e-02) Acc@1 97.85 ( 97.85) Acc@5 100.00 (100.00) +Epoch: [20][10/40] Time 1.351 ( 2.375) Data 0.000 ( 1.021) Loss 9.3862e-02 (9.3753e-02) Acc@1 98.14 ( 98.19) Acc@5 99.71 ( 99.89) +Epoch: [20][20/40] Time 1.352 ( 1.888) Data 0.000 ( 0.535) Loss 9.2689e-02 (9.1840e-02) Acc@1 98.14 ( 98.25) Acc@5 99.80 ( 99.92) +Epoch: [20][30/40] Time 1.352 ( 1.715) Data 0.000 ( 0.362) Loss 1.0695e-01 (9.3090e-02) Acc@1 97.66 ( 98.14) Acc@5 99.90 ( 99.92) +Test: [ 0/17] Time 6.517 ( 6.517) Loss 4.5616e-01 (4.5616e-01) Acc@1 86.82 ( 86.82) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.351 ( 1.820) Loss 4.9661e-01 (4.9907e-01) Acc@1 86.62 ( 86.16) Acc@5 97.36 ( 97.47) + * Acc@1 86.420 Acc@5 97.622 +found new best accuracy:= tensor(86.4196, device='cuda:0') +Epoch: [21][ 0/40] Time 8.614 ( 8.614) Data 7.240 ( 7.240) Loss 8.4987e-02 (8.4987e-02) Acc@1 98.34 ( 98.34) Acc@5 100.00 (100.00) +Epoch: [21][10/40] Time 1.351 ( 2.012) Data 0.000 ( 0.658) Loss 7.5419e-02 (8.6969e-02) Acc@1 98.93 ( 98.38) Acc@5 100.00 ( 99.92) +Epoch: [21][20/40] Time 1.352 ( 1.697) Data 0.000 ( 0.345) Loss 9.6718e-02 (8.7747e-02) Acc@1 98.14 ( 98.35) Acc@5 99.90 ( 99.92) +Epoch: [21][30/40] Time 1.351 ( 1.586) Data 0.000 ( 0.234) Loss 1.0456e-01 (8.9153e-02) Acc@1 97.85 ( 98.30) Acc@5 99.90 ( 99.93) +Test: [ 0/17] Time 6.431 ( 6.431) Loss 4.6804e-01 (4.6804e-01) Acc@1 85.84 ( 85.84) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.352 ( 1.813) Loss 5.0303e-01 (5.0722e-01) Acc@1 86.23 ( 85.94) Acc@5 97.36 ( 97.38) + * Acc@1 86.189 Acc@5 97.531 +Epoch: [22][ 0/40] Time 9.156 ( 9.156) Data 7.781 ( 7.781) Loss 7.6330e-02 (7.6330e-02) Acc@1 98.24 ( 98.24) Acc@5 100.00 (100.00) +Epoch: [22][10/40] Time 1.351 ( 2.061) Data 0.000 ( 0.708) Loss 7.2892e-02 (8.3032e-02) Acc@1 98.54 ( 98.30) Acc@5 100.00 ( 99.97) +Epoch: [22][20/40] Time 1.352 ( 1.723) Data 0.000 ( 0.371) Loss 8.0852e-02 (8.3633e-02) Acc@1 98.54 ( 98.35) Acc@5 100.00 ( 99.95) +Epoch: [22][30/40] Time 1.351 ( 1.603) Data 0.000 ( 0.251) Loss 7.7457e-02 (8.3986e-02) Acc@1 98.73 ( 98.32) Acc@5 100.00 ( 99.95) +Test: [ 0/17] Time 9.036 ( 9.036) Loss 4.6071e-01 (4.6071e-01) Acc@1 86.82 ( 86.82) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.351 ( 2.050) Loss 5.0788e-01 (5.0517e-01) Acc@1 86.62 ( 86.22) Acc@5 97.27 ( 97.44) + * Acc@1 86.444 Acc@5 97.568 +found new best accuracy:= tensor(86.4439, device='cuda:0') +Epoch: [23][ 0/40] Time 10.294 (10.294) Data 8.911 ( 8.911) Loss 9.3961e-02 (9.3961e-02) Acc@1 98.44 ( 98.44) Acc@5 99.90 ( 99.90) +Epoch: [23][10/40] Time 1.351 ( 2.164) Data 0.000 ( 0.810) Loss 7.5571e-02 (8.2124e-02) Acc@1 98.83 ( 98.61) Acc@5 100.00 ( 99.95) +Epoch: [23][20/40] Time 1.351 ( 1.777) Data 0.000 ( 0.424) Loss 6.9884e-02 (8.2503e-02) Acc@1 99.12 ( 98.59) Acc@5 99.90 ( 99.94) +Epoch: [23][30/40] Time 1.351 ( 1.640) Data 0.000 ( 0.288) Loss 8.2688e-02 (8.1821e-02) Acc@1 98.24 ( 98.60) Acc@5 100.00 ( 99.95) +Test: [ 0/17] Time 6.506 ( 6.506) Loss 4.7467e-01 (4.7467e-01) Acc@1 86.04 ( 86.04) Acc@5 97.56 ( 97.56) +Test: [10/17] Time 1.353 ( 1.820) Loss 5.0348e-01 (5.0457e-01) Acc@1 86.43 ( 86.05) Acc@5 97.17 ( 97.44) + * Acc@1 86.262 Acc@5 97.634 +Epoch: [24][ 0/40] Time 14.269 (14.269) Data 12.885 (12.885) Loss 7.5534e-02 (7.5534e-02) Acc@1 98.73 ( 98.73) Acc@5 100.00 (100.00) +Epoch: [24][10/40] Time 1.352 ( 2.526) Data 0.000 ( 1.171) Loss 7.5895e-02 (7.7038e-02) Acc@1 98.73 ( 98.78) Acc@5 100.00 ( 99.96) +Epoch: [24][20/40] Time 1.352 ( 1.967) Data 0.000 ( 0.614) Loss 8.1626e-02 (7.7530e-02) Acc@1 98.44 ( 98.70) Acc@5 100.00 ( 99.97) +Epoch: [24][30/40] Time 1.351 ( 1.768) Data 0.000 ( 0.416) Loss 8.8632e-02 (7.7817e-02) Acc@1 98.44 ( 98.70) Acc@5 100.00 ( 99.97) +Test: [ 0/17] Time 6.504 ( 6.504) Loss 4.7577e-01 (4.7577e-01) Acc@1 85.74 ( 85.74) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.352 ( 1.820) Loss 5.0028e-01 (5.1029e-01) Acc@1 86.72 ( 86.18) Acc@5 97.27 ( 97.50) + * Acc@1 86.395 Acc@5 97.628 +Epoch: [25][ 0/40] Time 10.754 (10.754) Data 9.389 ( 9.389) Loss 7.4723e-02 (7.4723e-02) Acc@1 98.93 ( 98.93) Acc@5 100.00 (100.00) +Epoch: [25][10/40] Time 1.351 ( 2.206) Data 0.000 ( 0.854) Loss 8.5663e-02 (7.9952e-02) Acc@1 98.24 ( 98.49) Acc@5 100.00 ( 99.96) +Epoch: [25][20/40] Time 1.352 ( 1.799) Data 0.000 ( 0.447) Loss 7.9646e-02 (7.8899e-02) Acc@1 98.93 ( 98.60) Acc@5 99.90 ( 99.97) +Epoch: [25][30/40] Time 1.353 ( 1.655) Data 0.001 ( 0.303) Loss 7.6560e-02 (7.7527e-02) Acc@1 98.63 ( 98.64) Acc@5 99.80 ( 99.97) +Test: [ 0/17] Time 6.531 ( 6.531) Loss 4.7433e-01 (4.7433e-01) Acc@1 85.64 ( 85.64) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.351 ( 1.822) Loss 5.0044e-01 (5.0410e-01) Acc@1 86.62 ( 86.17) Acc@5 97.07 ( 97.50) + * Acc@1 86.486 Acc@5 97.568 +found new best accuracy:= tensor(86.4863, device='cuda:0') +Epoch: [26][ 0/40] Time 10.466 (10.466) Data 9.075 ( 9.075) Loss 5.7828e-02 (5.7828e-02) Acc@1 99.61 ( 99.61) Acc@5 100.00 (100.00) +Epoch: [26][10/40] Time 1.351 ( 2.180) Data 0.000 ( 0.825) Loss 8.4694e-02 (7.7107e-02) Acc@1 98.63 ( 98.57) Acc@5 99.90 ( 99.97) +Epoch: [26][20/40] Time 1.351 ( 1.786) Data 0.000 ( 0.432) Loss 7.4485e-02 (7.5974e-02) Acc@1 98.54 ( 98.65) Acc@5 100.00 ( 99.96) +Epoch: [26][30/40] Time 1.351 ( 1.646) Data 0.000 ( 0.293) Loss 6.6370e-02 (7.4340e-02) Acc@1 99.41 ( 98.73) Acc@5 100.00 ( 99.97) +Test: [ 0/17] Time 10.185 (10.185) Loss 4.7806e-01 (4.7806e-01) Acc@1 85.74 ( 85.74) Acc@5 97.85 ( 97.85) +Test: [10/17] Time 1.350 ( 2.154) Loss 5.0855e-01 (5.1179e-01) Acc@1 87.01 ( 86.02) Acc@5 97.27 ( 97.45) + * Acc@1 86.304 Acc@5 97.622 +Epoch: [27][ 0/40] Time 11.963 (11.963) Data 10.595 (10.595) Loss 6.5065e-02 (6.5065e-02) Acc@1 99.22 ( 99.22) Acc@5 100.00 (100.00) +Epoch: [27][10/40] Time 1.351 ( 2.316) Data 0.000 ( 0.963) Loss 6.7011e-02 (7.0286e-02) Acc@1 98.73 ( 98.88) Acc@5 99.90 ( 99.97) +Epoch: [27][20/40] Time 1.351 ( 1.857) Data 0.000 ( 0.505) Loss 6.0937e-02 (6.9264e-02) Acc@1 99.41 ( 98.92) Acc@5 100.00 ( 99.97) +Epoch: [27][30/40] Time 1.352 ( 1.694) Data 0.000 ( 0.342) Loss 6.7789e-02 (7.1002e-02) Acc@1 99.61 ( 98.86) Acc@5 100.00 ( 99.97) +Test: [ 0/17] Time 6.487 ( 6.487) Loss 4.8722e-01 (4.8722e-01) Acc@1 86.13 ( 86.13) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.351 ( 1.818) Loss 5.2172e-01 (5.1692e-01) Acc@1 85.94 ( 86.06) Acc@5 97.36 ( 97.40) + * Acc@1 86.316 Acc@5 97.562 +Epoch: [28][ 0/40] Time 13.501 (13.501) Data 12.129 (12.129) Loss 6.0516e-02 (6.0516e-02) Acc@1 99.61 ( 99.61) Acc@5 100.00 (100.00) +Epoch: [28][10/40] Time 1.352 ( 2.457) Data 0.000 ( 1.103) Loss 6.3679e-02 (6.6120e-02) Acc@1 99.12 ( 99.15) Acc@5 100.00 ( 99.99) +Epoch: [28][20/40] Time 1.352 ( 1.930) Data 0.000 ( 0.578) Loss 8.2098e-02 (6.6854e-02) Acc@1 98.54 ( 99.11) Acc@5 99.80 ( 99.98) +Epoch: [28][30/40] Time 1.352 ( 1.744) Data 0.000 ( 0.391) Loss 6.8748e-02 (6.6713e-02) Acc@1 98.54 ( 99.09) Acc@5 100.00 ( 99.98) +Test: [ 0/17] Time 6.495 ( 6.495) Loss 4.7674e-01 (4.7674e-01) Acc@1 86.52 ( 86.52) Acc@5 97.85 ( 97.85) +Test: [10/17] Time 1.351 ( 1.819) Loss 5.0813e-01 (5.1223e-01) Acc@1 86.72 ( 86.27) Acc@5 97.66 ( 97.50) + * Acc@1 86.456 Acc@5 97.647 +Epoch: [29][ 0/40] Time 8.969 ( 8.969) Data 7.592 ( 7.592) Loss 5.7755e-02 (5.7755e-02) Acc@1 99.02 ( 99.02) Acc@5 100.00 (100.00) +Epoch: [29][10/40] Time 1.352 ( 2.045) Data 0.000 ( 0.690) Loss 6.4470e-02 (6.2039e-02) Acc@1 99.02 ( 99.23) Acc@5 99.80 ( 99.96) +Epoch: [29][20/40] Time 1.353 ( 1.715) Data 0.000 ( 0.362) Loss 6.2173e-02 (6.3690e-02) Acc@1 99.12 ( 99.14) Acc@5 100.00 ( 99.98) +Epoch: [29][30/40] Time 1.352 ( 1.598) Data 0.000 ( 0.245) Loss 6.2651e-02 (6.4069e-02) Acc@1 99.32 ( 99.13) Acc@5 100.00 ( 99.98) +Test: [ 0/17] Time 7.886 ( 7.886) Loss 4.8419e-01 (4.8419e-01) Acc@1 85.45 ( 85.45) Acc@5 97.56 ( 97.56) +Test: [10/17] Time 1.352 ( 1.945) Loss 5.1116e-01 (5.1623e-01) Acc@1 86.23 ( 86.05) Acc@5 97.36 ( 97.44) + * Acc@1 86.347 Acc@5 97.598 +Epoch: [30][ 0/40] Time 9.340 ( 9.340) Data 7.961 ( 7.961) Loss 5.2260e-02 (5.2260e-02) Acc@1 99.71 ( 99.71) Acc@5 99.90 ( 99.90) +Epoch: [30][10/40] Time 1.352 ( 2.079) Data 0.000 ( 0.724) Loss 4.9381e-02 (6.0678e-02) Acc@1 99.22 ( 99.23) Acc@5 100.00 ( 99.98) +Epoch: [30][20/40] Time 1.352 ( 1.733) Data 0.000 ( 0.379) Loss 7.5032e-02 (6.3239e-02) Acc@1 98.54 ( 99.09) Acc@5 100.00 ( 99.99) +Epoch: [30][30/40] Time 1.351 ( 1.610) Data 0.000 ( 0.257) Loss 6.7343e-02 (6.2979e-02) Acc@1 98.83 ( 99.13) Acc@5 100.00 ( 99.99) +Test: [ 0/17] Time 8.007 ( 8.007) Loss 4.8071e-01 (4.8071e-01) Acc@1 86.04 ( 86.04) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.352 ( 1.957) Loss 5.0567e-01 (5.1391e-01) Acc@1 86.72 ( 86.09) Acc@5 97.17 ( 97.41) + * Acc@1 86.371 Acc@5 97.574 +Epoch: [31][ 0/40] Time 13.006 (13.006) Data 11.617 (11.617) Loss 5.2463e-02 (5.2463e-02) Acc@1 99.61 ( 99.61) Acc@5 100.00 (100.00) +Epoch: [31][10/40] Time 1.353 ( 2.412) Data 0.000 ( 1.056) Loss 6.2020e-02 (5.5594e-02) Acc@1 99.41 ( 99.44) Acc@5 99.90 ( 99.98) +Epoch: [31][20/40] Time 1.352 ( 1.908) Data 0.000 ( 0.553) Loss 6.2775e-02 (5.7458e-02) Acc@1 98.73 ( 99.31) Acc@5 100.00 ( 99.99) +Epoch: [31][30/40] Time 1.351 ( 1.728) Data 0.000 ( 0.375) Loss 7.3734e-02 (5.9253e-02) Acc@1 98.44 ( 99.27) Acc@5 100.00 ( 99.98) +Test: [ 0/17] Time 6.394 ( 6.394) Loss 4.8360e-01 (4.8360e-01) Acc@1 86.23 ( 86.23) Acc@5 97.46 ( 97.46) +Test: [10/17] Time 1.352 ( 1.810) Loss 5.0900e-01 (5.1401e-01) Acc@1 86.52 ( 86.26) Acc@5 97.27 ( 97.43) + * Acc@1 86.426 Acc@5 97.604 +Epoch: [32][ 0/40] Time 12.835 (12.835) Data 11.483 (11.483) Loss 4.7411e-02 (4.7411e-02) Acc@1 99.80 ( 99.80) Acc@5 100.00 (100.00) +Epoch: [32][10/40] Time 1.351 ( 2.396) Data 0.000 ( 1.044) Loss 5.5085e-02 (5.4519e-02) Acc@1 99.41 ( 99.46) Acc@5 100.00 (100.00) +Epoch: [32][20/40] Time 1.351 ( 1.899) Data 0.000 ( 0.547) Loss 5.9401e-02 (5.6149e-02) Acc@1 98.73 ( 99.40) Acc@5 100.00 (100.00) +Epoch: [32][30/40] Time 1.351 ( 1.722) Data 0.000 ( 0.371) Loss 5.6052e-02 (5.7199e-02) Acc@1 99.51 ( 99.38) Acc@5 100.00 ( 99.99) +Test: [ 0/17] Time 6.511 ( 6.511) Loss 4.8071e-01 (4.8071e-01) Acc@1 86.52 ( 86.52) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.350 ( 1.820) Loss 5.0439e-01 (5.1567e-01) Acc@1 86.82 ( 86.08) Acc@5 97.46 ( 97.51) + * Acc@1 86.359 Acc@5 97.653 +Epoch: [33][ 0/40] Time 9.402 ( 9.402) Data 8.025 ( 8.025) Loss 4.8324e-02 (4.8324e-02) Acc@1 99.80 ( 99.80) Acc@5 100.00 (100.00) +Epoch: [33][10/40] Time 1.352 ( 2.083) Data 0.000 ( 0.730) Loss 5.3162e-02 (5.7132e-02) Acc@1 99.41 ( 99.34) Acc@5 100.00 ( 99.98) +Epoch: [33][20/40] Time 1.351 ( 1.735) Data 0.000 ( 0.382) Loss 6.3359e-02 (5.7581e-02) Acc@1 99.22 ( 99.27) Acc@5 100.00 ( 99.99) +Epoch: [33][30/40] Time 1.351 ( 1.611) Data 0.000 ( 0.259) Loss 4.8371e-02 (5.6778e-02) Acc@1 99.41 ( 99.31) Acc@5 100.00 ( 99.99) +Test: [ 0/17] Time 7.985 ( 7.985) Loss 4.8574e-01 (4.8574e-01) Acc@1 86.13 ( 86.13) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.350 ( 1.954) Loss 5.1638e-01 (5.2129e-01) Acc@1 85.94 ( 85.97) Acc@5 97.27 ( 97.43) + * Acc@1 86.238 Acc@5 97.592 +Epoch: [34][ 0/40] Time 8.733 ( 8.733) Data 7.342 ( 7.342) Loss 5.2660e-02 (5.2660e-02) Acc@1 99.61 ( 99.61) Acc@5 99.90 ( 99.90) +Epoch: [34][10/40] Time 1.352 ( 2.022) Data 0.000 ( 0.667) Loss 5.4550e-02 (5.8132e-02) Acc@1 99.41 ( 99.13) Acc@5 100.00 ( 99.98) +Epoch: [34][20/40] Time 1.352 ( 1.703) Data 0.000 ( 0.350) Loss 4.9662e-02 (5.5981e-02) Acc@1 99.51 ( 99.24) Acc@5 100.00 ( 99.99) +Epoch: [34][30/40] Time 1.351 ( 1.589) Data 0.000 ( 0.237) Loss 6.5467e-02 (5.5945e-02) Acc@1 99.32 ( 99.29) Acc@5 99.90 ( 99.99) +Test: [ 0/17] Time 8.216 ( 8.216) Loss 4.7950e-01 (4.7950e-01) Acc@1 86.52 ( 86.52) Acc@5 97.56 ( 97.56) +Test: [10/17] Time 1.350 ( 1.975) Loss 5.1184e-01 (5.1616e-01) Acc@1 86.72 ( 86.31) Acc@5 97.17 ( 97.49) + * Acc@1 86.480 Acc@5 97.665 +Epoch: [35][ 0/40] Time 14.549 (14.549) Data 13.164 (13.164) Loss 5.0899e-02 (5.0899e-02) Acc@1 99.51 ( 99.51) Acc@5 100.00 (100.00) +Epoch: [35][10/40] Time 1.351 ( 2.551) Data 0.000 ( 1.197) Loss 4.7211e-02 (4.9652e-02) Acc@1 99.90 ( 99.58) Acc@5 100.00 (100.00) +Epoch: [35][20/40] Time 1.351 ( 1.979) Data 0.000 ( 0.627) Loss 5.0148e-02 (5.1086e-02) Acc@1 99.51 ( 99.52) Acc@5 100.00 ( 99.99) +Epoch: [35][30/40] Time 1.351 ( 1.777) Data 0.000 ( 0.425) Loss 5.8193e-02 (5.2212e-02) Acc@1 99.61 ( 99.46) Acc@5 100.00 ( 99.99) +Test: [ 0/17] Time 6.444 ( 6.444) Loss 4.7981e-01 (4.7981e-01) Acc@1 86.43 ( 86.43) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.814) Loss 5.2128e-01 (5.2098e-01) Acc@1 86.72 ( 86.16) Acc@5 97.17 ( 97.46) + * Acc@1 86.347 Acc@5 97.610 +Epoch: [36][ 0/40] Time 10.983 (10.983) Data 9.595 ( 9.595) Loss 5.1993e-02 (5.1993e-02) Acc@1 99.51 ( 99.51) Acc@5 100.00 (100.00) +Epoch: [36][10/40] Time 1.352 ( 2.228) Data 0.000 ( 0.872) Loss 4.7952e-02 (4.8969e-02) Acc@1 99.71 ( 99.61) Acc@5 100.00 ( 99.99) +Epoch: [36][20/40] Time 1.351 ( 1.810) Data 0.000 ( 0.457) Loss 5.1830e-02 (4.9777e-02) Acc@1 99.71 ( 99.60) Acc@5 100.00 (100.00) +Epoch: [36][30/40] Time 1.352 ( 1.662) Data 0.000 ( 0.310) Loss 5.4987e-02 (5.0629e-02) Acc@1 99.41 ( 99.55) Acc@5 100.00 ( 99.99) +Test: [ 0/17] Time 6.521 ( 6.521) Loss 4.9052e-01 (4.9052e-01) Acc@1 86.52 ( 86.52) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.351 ( 1.821) Loss 5.2048e-01 (5.2262e-01) Acc@1 86.43 ( 86.23) Acc@5 97.36 ( 97.48) + * Acc@1 86.426 Acc@5 97.622 +Epoch: [37][ 0/40] Time 10.251 (10.251) Data 8.866 ( 8.866) Loss 3.8700e-02 (3.8700e-02) Acc@1 99.71 ( 99.71) Acc@5 100.00 (100.00) +Epoch: [37][10/40] Time 1.351 ( 2.160) Data 0.000 ( 0.806) Loss 4.5669e-02 (4.5062e-02) Acc@1 99.51 ( 99.64) Acc@5 100.00 (100.00) +Epoch: [37][20/40] Time 1.351 ( 1.775) Data 0.000 ( 0.422) Loss 3.6044e-02 (4.7086e-02) Acc@1 99.71 ( 99.59) Acc@5 100.00 (100.00) +Epoch: [37][30/40] Time 1.351 ( 1.638) Data 0.000 ( 0.286) Loss 4.5578e-02 (4.8629e-02) Acc@1 99.61 ( 99.54) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 9.642 ( 9.642) Loss 4.9065e-01 (4.9065e-01) Acc@1 86.33 ( 86.33) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.351 ( 2.104) Loss 5.1805e-01 (5.2722e-01) Acc@1 86.52 ( 86.19) Acc@5 97.07 ( 97.48) + * Acc@1 86.389 Acc@5 97.647 +Epoch: [38][ 0/40] Time 10.007 (10.007) Data 8.621 ( 8.621) Loss 4.9500e-02 (4.9500e-02) Acc@1 99.41 ( 99.41) Acc@5 100.00 (100.00) +Epoch: [38][10/40] Time 1.351 ( 2.138) Data 0.000 ( 0.784) Loss 4.4748e-02 (4.5641e-02) Acc@1 99.61 ( 99.60) Acc@5 100.00 (100.00) +Epoch: [38][20/40] Time 1.352 ( 1.764) Data 0.000 ( 0.411) Loss 4.9806e-02 (4.7257e-02) Acc@1 99.80 ( 99.56) Acc@5 100.00 (100.00) +Epoch: [38][30/40] Time 1.351 ( 1.631) Data 0.000 ( 0.278) Loss 4.4002e-02 (4.8157e-02) Acc@1 99.90 ( 99.56) Acc@5 100.00 ( 99.99) +Test: [ 0/17] Time 6.559 ( 6.559) Loss 4.8840e-01 (4.8840e-01) Acc@1 86.23 ( 86.23) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.350 ( 1.824) Loss 5.2144e-01 (5.2551e-01) Acc@1 86.72 ( 86.25) Acc@5 97.07 ( 97.42) + * Acc@1 86.480 Acc@5 97.580 +Epoch: [39][ 0/40] Time 14.437 (14.437) Data 13.057 (13.057) Loss 4.7557e-02 (4.7557e-02) Acc@1 99.61 ( 99.61) Acc@5 100.00 (100.00) +Epoch: [39][10/40] Time 1.352 ( 2.541) Data 0.000 ( 1.187) Loss 4.7544e-02 (4.4068e-02) Acc@1 99.80 ( 99.67) Acc@5 99.90 ( 99.99) +Epoch: [39][20/40] Time 1.351 ( 1.975) Data 0.000 ( 0.622) Loss 4.8958e-02 (4.5828e-02) Acc@1 99.32 ( 99.62) Acc@5 100.00 (100.00) +Epoch: [39][30/40] Time 1.352 ( 1.774) Data 0.000 ( 0.421) Loss 4.7160e-02 (4.6677e-02) Acc@1 99.51 ( 99.61) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.580 ( 6.580) Loss 4.8909e-01 (4.8909e-01) Acc@1 85.94 ( 85.94) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.826) Loss 5.1834e-01 (5.2379e-01) Acc@1 86.91 ( 86.20) Acc@5 97.46 ( 97.50) + * Acc@1 86.480 Acc@5 97.659 +Epoch: [40][ 0/40] Time 9.911 ( 9.911) Data 8.539 ( 8.539) Loss 4.7164e-02 (4.7164e-02) Acc@1 99.61 ( 99.61) Acc@5 100.00 (100.00) +Epoch: [40][10/40] Time 1.351 ( 2.130) Data 0.000 ( 0.776) Loss 3.8234e-02 (4.4304e-02) Acc@1 99.80 ( 99.68) Acc@5 100.00 (100.00) +Epoch: [40][20/40] Time 1.351 ( 1.759) Data 0.000 ( 0.407) Loss 4.0974e-02 (4.4805e-02) Acc@1 99.80 ( 99.70) Acc@5 100.00 (100.00) +Epoch: [40][30/40] Time 1.351 ( 1.628) Data 0.000 ( 0.276) Loss 4.6440e-02 (4.5176e-02) Acc@1 99.71 ( 99.68) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.551 ( 6.551) Loss 4.9738e-01 (4.9738e-01) Acc@1 85.94 ( 85.94) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.351 ( 1.824) Loss 5.1453e-01 (5.2668e-01) Acc@1 86.52 ( 86.17) Acc@5 97.27 ( 97.43) + * Acc@1 86.407 Acc@5 97.598 +Epoch: [41][ 0/40] Time 9.851 ( 9.851) Data 8.466 ( 8.466) Loss 3.7495e-02 (3.7495e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [41][10/40] Time 1.351 ( 2.125) Data 0.000 ( 0.770) Loss 3.7721e-02 (4.2016e-02) Acc@1 99.71 ( 99.71) Acc@5 100.00 (100.00) +Epoch: [41][20/40] Time 1.351 ( 1.757) Data 0.000 ( 0.403) Loss 4.6517e-02 (4.3299e-02) Acc@1 99.80 ( 99.70) Acc@5 100.00 (100.00) +Epoch: [41][30/40] Time 1.351 ( 1.626) Data 0.001 ( 0.273) Loss 4.5306e-02 (4.4237e-02) Acc@1 99.51 ( 99.68) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 9.699 ( 9.699) Loss 4.8524e-01 (4.8524e-01) Acc@1 86.33 ( 86.33) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.350 ( 2.110) Loss 5.1944e-01 (5.2560e-01) Acc@1 86.62 ( 86.13) Acc@5 97.07 ( 97.42) + * Acc@1 86.310 Acc@5 97.628 +Epoch: [42][ 0/40] Time 10.336 (10.336) Data 8.951 ( 8.951) Loss 3.7365e-02 (3.7365e-02) Acc@1 99.90 ( 99.90) Acc@5 100.00 (100.00) +Epoch: [42][10/40] Time 1.351 ( 2.168) Data 0.000 ( 0.814) Loss 3.8243e-02 (4.1055e-02) Acc@1 99.71 ( 99.72) Acc@5 100.00 ( 99.99) +Epoch: [42][20/40] Time 1.351 ( 1.779) Data 0.000 ( 0.426) Loss 4.4096e-02 (4.1189e-02) Acc@1 99.90 ( 99.72) Acc@5 100.00 (100.00) +Epoch: [42][30/40] Time 1.351 ( 1.641) Data 0.000 ( 0.289) Loss 4.6960e-02 (4.2185e-02) Acc@1 99.90 ( 99.72) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.464 ( 6.464) Loss 4.9757e-01 (4.9757e-01) Acc@1 86.52 ( 86.52) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.350 ( 1.816) Loss 5.1730e-01 (5.2998e-01) Acc@1 86.72 ( 86.29) Acc@5 97.36 ( 97.47) + * Acc@1 86.571 Acc@5 97.641 +found new best accuracy:= tensor(86.5712, device='cuda:0') +Epoch: [43][ 0/40] Time 15.035 (15.035) Data 13.653 (13.653) Loss 4.2313e-02 (4.2313e-02) Acc@1 99.32 ( 99.32) Acc@5 100.00 (100.00) +Epoch: [43][10/40] Time 1.351 ( 2.595) Data 0.000 ( 1.241) Loss 3.7204e-02 (4.2424e-02) Acc@1 99.90 ( 99.64) Acc@5 100.00 (100.00) +Epoch: [43][20/40] Time 1.351 ( 2.003) Data 0.000 ( 0.650) Loss 4.3874e-02 (4.1530e-02) Acc@1 99.61 ( 99.69) Acc@5 100.00 (100.00) +Epoch: [43][30/40] Time 1.351 ( 1.793) Data 0.000 ( 0.441) Loss 4.4127e-02 (4.1900e-02) Acc@1 99.51 ( 99.68) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.450 ( 6.450) Loss 4.9518e-01 (4.9518e-01) Acc@1 86.04 ( 86.04) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.815) Loss 5.1564e-01 (5.2840e-01) Acc@1 86.91 ( 86.27) Acc@5 97.27 ( 97.46) + * Acc@1 86.492 Acc@5 97.628 +Epoch: [44][ 0/40] Time 9.807 ( 9.807) Data 8.415 ( 8.415) Loss 3.8246e-02 (3.8246e-02) Acc@1 99.71 ( 99.71) Acc@5 100.00 (100.00) +Epoch: [44][10/40] Time 1.351 ( 2.124) Data 0.000 ( 0.768) Loss 4.2884e-02 (4.0515e-02) Acc@1 99.51 ( 99.72) Acc@5 100.00 (100.00) +Epoch: [44][20/40] Time 1.351 ( 1.756) Data 0.000 ( 0.403) Loss 4.2921e-02 (4.0521e-02) Acc@1 99.51 ( 99.73) Acc@5 100.00 (100.00) +Epoch: [44][30/40] Time 1.352 ( 1.626) Data 0.000 ( 0.273) Loss 4.2584e-02 (4.0450e-02) Acc@1 99.71 ( 99.74) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.485 ( 6.485) Loss 4.9641e-01 (4.9641e-01) Acc@1 86.43 ( 86.43) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.350 ( 1.817) Loss 5.1975e-01 (5.3550e-01) Acc@1 86.62 ( 86.13) Acc@5 97.07 ( 97.47) + * Acc@1 86.365 Acc@5 97.604 +Epoch: [45][ 0/40] Time 9.807 ( 9.807) Data 8.430 ( 8.430) Loss 4.0320e-02 (4.0320e-02) Acc@1 99.71 ( 99.71) Acc@5 99.90 ( 99.90) +Epoch: [45][10/40] Time 1.351 ( 2.120) Data 0.000 ( 0.766) Loss 4.3473e-02 (3.9954e-02) Acc@1 99.51 ( 99.70) Acc@5 100.00 ( 99.99) +Epoch: [45][20/40] Time 1.351 ( 1.754) Data 0.000 ( 0.402) Loss 4.6406e-02 (3.9719e-02) Acc@1 99.51 ( 99.71) Acc@5 100.00 (100.00) +Epoch: [45][30/40] Time 1.352 ( 1.624) Data 0.000 ( 0.272) Loss 3.9206e-02 (4.0039e-02) Acc@1 99.71 ( 99.70) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 11.067 (11.067) Loss 4.9305e-01 (4.9305e-01) Acc@1 86.43 ( 86.43) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 2.234) Loss 5.2160e-01 (5.2792e-01) Acc@1 86.33 ( 86.22) Acc@5 97.36 ( 97.54) + * Acc@1 86.383 Acc@5 97.653 +Epoch: [46][ 0/40] Time 10.382 (10.382) Data 9.005 ( 9.005) Loss 3.6970e-02 (3.6970e-02) Acc@1 99.90 ( 99.90) Acc@5 100.00 (100.00) +Epoch: [46][10/40] Time 1.352 ( 2.173) Data 0.000 ( 0.819) Loss 4.1443e-02 (3.9197e-02) Acc@1 99.51 ( 99.78) Acc@5 100.00 (100.00) +Epoch: [46][20/40] Time 1.352 ( 1.782) Data 0.000 ( 0.429) Loss 3.6320e-02 (3.9275e-02) Acc@1 100.00 ( 99.75) Acc@5 100.00 (100.00) +Epoch: [46][30/40] Time 1.351 ( 1.643) Data 0.000 ( 0.291) Loss 3.8798e-02 (3.9425e-02) Acc@1 99.80 ( 99.76) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.408 ( 6.408) Loss 4.9662e-01 (4.9662e-01) Acc@1 86.33 ( 86.33) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.351 ( 1.810) Loss 5.2594e-01 (5.3643e-01) Acc@1 86.91 ( 86.22) Acc@5 97.07 ( 97.49) + * Acc@1 86.414 Acc@5 97.641 +Epoch: [47][ 0/40] Time 18.419 (18.419) Data 17.029 (17.029) Loss 4.2062e-02 (4.2062e-02) Acc@1 99.51 ( 99.51) Acc@5 100.00 (100.00) +Epoch: [47][10/40] Time 1.352 ( 2.904) Data 0.000 ( 1.548) Loss 3.5231e-02 (3.7321e-02) Acc@1 99.90 ( 99.75) Acc@5 100.00 (100.00) +Epoch: [47][20/40] Time 1.352 ( 2.165) Data 0.000 ( 0.811) Loss 3.6443e-02 (3.7523e-02) Acc@1 99.80 ( 99.73) Acc@5 99.90 (100.00) +Epoch: [47][30/40] Time 1.351 ( 1.903) Data 0.000 ( 0.550) Loss 4.2151e-02 (3.8208e-02) Acc@1 99.51 ( 99.71) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.546 ( 6.546) Loss 5.0181e-01 (5.0181e-01) Acc@1 86.33 ( 86.33) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.823) Loss 5.3043e-01 (5.3698e-01) Acc@1 86.52 ( 86.20) Acc@5 97.27 ( 97.51) + * Acc@1 86.420 Acc@5 97.653 +Epoch: [48][ 0/40] Time 8.935 ( 8.935) Data 7.580 ( 7.580) Loss 3.3293e-02 (3.3293e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [48][10/40] Time 1.352 ( 2.042) Data 0.000 ( 0.689) Loss 3.5823e-02 (3.5089e-02) Acc@1 99.90 ( 99.85) Acc@5 100.00 (100.00) +Epoch: [48][20/40] Time 1.351 ( 1.713) Data 0.000 ( 0.361) Loss 4.0357e-02 (3.6256e-02) Acc@1 99.71 ( 99.82) Acc@5 100.00 (100.00) +Epoch: [48][30/40] Time 1.352 ( 1.596) Data 0.000 ( 0.245) Loss 3.9522e-02 (3.7135e-02) Acc@1 99.51 ( 99.79) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 7.017 ( 7.017) Loss 4.9703e-01 (4.9703e-01) Acc@1 86.33 ( 86.33) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.351 ( 1.866) Loss 5.3226e-01 (5.3760e-01) Acc@1 87.11 ( 86.24) Acc@5 97.17 ( 97.48) + * Acc@1 86.456 Acc@5 97.628 +Epoch: [49][ 0/40] Time 8.943 ( 8.943) Data 7.556 ( 7.556) Loss 3.2373e-02 (3.2373e-02) Acc@1 99.90 ( 99.90) Acc@5 100.00 (100.00) +Epoch: [49][10/40] Time 1.352 ( 2.136) Data 0.001 ( 0.780) Loss 3.1569e-02 (3.5076e-02) Acc@1 99.90 ( 99.77) Acc@5 100.00 (100.00) +Epoch: [49][20/40] Time 1.351 ( 1.762) Data 0.000 ( 0.409) Loss 3.7089e-02 (3.5775e-02) Acc@1 99.90 ( 99.79) Acc@5 100.00 (100.00) +Epoch: [49][30/40] Time 1.352 ( 1.630) Data 0.000 ( 0.277) Loss 3.6067e-02 (3.6395e-02) Acc@1 99.90 ( 99.79) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 9.446 ( 9.446) Loss 5.0239e-01 (5.0239e-01) Acc@1 86.62 ( 86.62) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.351 ( 2.087) Loss 5.3299e-01 (5.3874e-01) Acc@1 86.43 ( 86.14) Acc@5 97.07 ( 97.47) + * Acc@1 86.365 Acc@5 97.610 +Epoch: [50][ 0/40] Time 10.221 (10.221) Data 8.835 ( 8.835) Loss 2.9383e-02 (2.9383e-02) Acc@1 99.90 ( 99.90) Acc@5 100.00 (100.00) +Epoch: [50][10/40] Time 1.352 ( 2.158) Data 0.000 ( 0.803) Loss 3.2696e-02 (3.2296e-02) Acc@1 100.00 ( 99.84) Acc@5 100.00 (100.00) +Epoch: [50][20/40] Time 1.352 ( 1.774) Data 0.000 ( 0.421) Loss 3.2766e-02 (3.2255e-02) Acc@1 99.90 ( 99.89) Acc@5 100.00 (100.00) +Epoch: [50][30/40] Time 1.352 ( 1.638) Data 0.000 ( 0.285) Loss 3.6840e-02 (3.2797e-02) Acc@1 99.90 ( 99.89) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.504 ( 6.504) Loss 4.9716e-01 (4.9716e-01) Acc@1 86.13 ( 86.13) Acc@5 97.66 ( 97.66) +Test: [10/17] Time 1.351 ( 1.820) Loss 5.3044e-01 (5.3526e-01) Acc@1 86.72 ( 86.16) Acc@5 97.07 ( 97.49) + * Acc@1 86.395 Acc@5 97.616 +Epoch: [51][ 0/40] Time 8.910 ( 8.910) Data 7.520 ( 7.520) Loss 3.4881e-02 (3.4881e-02) Acc@1 99.80 ( 99.80) Acc@5 100.00 (100.00) +Epoch: [51][10/40] Time 1.351 ( 2.039) Data 0.000 ( 0.684) Loss 2.8161e-02 (3.1784e-02) Acc@1 99.90 ( 99.92) Acc@5 100.00 (100.00) +Epoch: [51][20/40] Time 1.351 ( 1.712) Data 0.000 ( 0.358) Loss 3.3999e-02 (3.1961e-02) Acc@1 100.00 ( 99.94) Acc@5 100.00 (100.00) +Epoch: [51][30/40] Time 1.351 ( 1.595) Data 0.000 ( 0.243) Loss 3.0054e-02 (3.2230e-02) Acc@1 99.90 ( 99.92) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.344 ( 6.344) Loss 4.9827e-01 (4.9827e-01) Acc@1 86.23 ( 86.23) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.805) Loss 5.2919e-01 (5.3532e-01) Acc@1 86.62 ( 86.23) Acc@5 97.07 ( 97.51) + * Acc@1 86.444 Acc@5 97.653 +Epoch: [52][ 0/40] Time 9.178 ( 9.178) Data 7.787 ( 7.787) Loss 3.2781e-02 (3.2781e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [52][10/40] Time 1.351 ( 2.063) Data 0.000 ( 0.708) Loss 2.9533e-02 (3.1903e-02) Acc@1 100.00 ( 99.94) Acc@5 100.00 (100.00) +Epoch: [52][20/40] Time 1.351 ( 1.724) Data 0.000 ( 0.371) Loss 2.8467e-02 (3.1638e-02) Acc@1 99.90 ( 99.95) Acc@5 100.00 (100.00) +Epoch: [52][30/40] Time 1.351 ( 1.604) Data 0.000 ( 0.251) Loss 3.4305e-02 (3.2005e-02) Acc@1 99.90 ( 99.94) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.395 ( 6.395) Loss 4.9814e-01 (4.9814e-01) Acc@1 86.23 ( 86.23) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.809) Loss 5.2854e-01 (5.3550e-01) Acc@1 86.72 ( 86.25) Acc@5 97.07 ( 97.50) + * Acc@1 86.456 Acc@5 97.653 +Epoch: [53][ 0/40] Time 9.594 ( 9.594) Data 8.222 ( 8.222) Loss 2.9264e-02 (2.9264e-02) Acc@1 99.90 ( 99.90) Acc@5 100.00 (100.00) +Epoch: [53][10/40] Time 1.351 ( 2.101) Data 0.000 ( 0.748) Loss 3.3114e-02 (3.0093e-02) Acc@1 100.00 ( 99.97) Acc@5 100.00 (100.00) +Epoch: [53][20/40] Time 1.351 ( 1.744) Data 0.000 ( 0.392) Loss 2.9850e-02 (3.1354e-02) Acc@1 99.80 ( 99.94) Acc@5 100.00 (100.00) +Epoch: [53][30/40] Time 1.351 ( 1.617) Data 0.000 ( 0.265) Loss 3.3346e-02 (3.1456e-02) Acc@1 100.00 ( 99.95) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.361 ( 6.361) Loss 4.9863e-01 (4.9863e-01) Acc@1 86.23 ( 86.23) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.806) Loss 5.2880e-01 (5.3568e-01) Acc@1 86.62 ( 86.19) Acc@5 97.07 ( 97.50) + * Acc@1 86.432 Acc@5 97.659 +Epoch: [54][ 0/40] Time 10.974 (10.974) Data 9.588 ( 9.588) Loss 3.0708e-02 (3.0708e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [54][10/40] Time 1.351 ( 2.227) Data 0.000 ( 0.872) Loss 3.1826e-02 (3.0996e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [54][20/40] Time 1.351 ( 1.810) Data 0.000 ( 0.457) Loss 3.1372e-02 (3.0839e-02) Acc@1 100.00 ( 99.95) Acc@5 100.00 (100.00) +Epoch: [54][30/40] Time 1.351 ( 1.662) Data 0.000 ( 0.309) Loss 3.3552e-02 (3.1233e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.398 ( 6.398) Loss 4.9949e-01 (4.9949e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.350 ( 1.810) Loss 5.2905e-01 (5.3560e-01) Acc@1 86.52 ( 86.26) Acc@5 97.07 ( 97.50) + * Acc@1 86.474 Acc@5 97.653 +Epoch: [55][ 0/40] Time 9.008 ( 9.008) Data 7.626 ( 7.626) Loss 3.2593e-02 (3.2593e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [55][10/40] Time 1.351 ( 2.047) Data 0.000 ( 0.693) Loss 2.9940e-02 (3.1733e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [55][20/40] Time 1.351 ( 1.716) Data 0.000 ( 0.363) Loss 3.0961e-02 (3.1416e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [55][30/40] Time 1.351 ( 1.598) Data 0.000 ( 0.246) Loss 3.1477e-02 (3.1746e-02) Acc@1 99.90 ( 99.94) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 7.447 ( 7.447) Loss 4.9864e-01 (4.9864e-01) Acc@1 86.33 ( 86.33) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.905) Loss 5.2964e-01 (5.3608e-01) Acc@1 86.82 ( 86.32) Acc@5 97.07 ( 97.50) + * Acc@1 86.529 Acc@5 97.665 +Epoch: [56][ 0/40] Time 8.848 ( 8.848) Data 7.480 ( 7.480) Loss 2.8944e-02 (2.8944e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [56][10/40] Time 1.353 ( 2.034) Data 0.000 ( 0.680) Loss 3.0403e-02 (3.1094e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [56][20/40] Time 1.351 ( 1.709) Data 0.000 ( 0.356) Loss 3.3629e-02 (3.1174e-02) Acc@1 99.90 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [56][30/40] Time 1.351 ( 1.594) Data 0.000 ( 0.241) Loss 3.5022e-02 (3.1686e-02) Acc@1 99.80 ( 99.94) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.507 ( 6.507) Loss 4.9880e-01 (4.9880e-01) Acc@1 86.23 ( 86.23) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.820) Loss 5.2859e-01 (5.3538e-01) Acc@1 86.72 ( 86.26) Acc@5 97.17 ( 97.51) + * Acc@1 86.474 Acc@5 97.659 +Epoch: [57][ 0/40] Time 9.633 ( 9.633) Data 8.247 ( 8.247) Loss 3.1392e-02 (3.1392e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [57][10/40] Time 1.352 ( 2.105) Data 0.000 ( 0.750) Loss 3.2836e-02 (3.0844e-02) Acc@1 99.80 ( 99.95) Acc@5 100.00 (100.00) +Epoch: [57][20/40] Time 1.351 ( 1.746) Data 0.000 ( 0.393) Loss 2.8042e-02 (3.1171e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [57][30/40] Time 1.351 ( 1.619) Data 0.000 ( 0.266) Loss 2.6365e-02 (3.1476e-02) Acc@1 100.00 ( 99.95) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.417 ( 6.417) Loss 4.9885e-01 (4.9885e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.811) Loss 5.3015e-01 (5.3592e-01) Acc@1 86.52 ( 86.27) Acc@5 97.07 ( 97.51) + * Acc@1 86.462 Acc@5 97.665 +Epoch: [58][ 0/40] Time 8.696 ( 8.696) Data 7.315 ( 7.315) Loss 2.7963e-02 (2.7963e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [58][10/40] Time 1.351 ( 2.019) Data 0.000 ( 0.665) Loss 3.5170e-02 (3.1982e-02) Acc@1 99.90 ( 99.89) Acc@5 100.00 (100.00) +Epoch: [58][20/40] Time 1.352 ( 1.701) Data 0.000 ( 0.348) Loss 3.3645e-02 (3.1805e-02) Acc@1 99.80 ( 99.92) Acc@5 100.00 (100.00) +Epoch: [58][30/40] Time 1.353 ( 1.589) Data 0.000 ( 0.236) Loss 3.0708e-02 (3.1493e-02) Acc@1 100.00 ( 99.93) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.598 ( 6.598) Loss 4.9911e-01 (4.9911e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.828) Loss 5.2968e-01 (5.3614e-01) Acc@1 86.72 ( 86.27) Acc@5 97.07 ( 97.49) + * Acc@1 86.474 Acc@5 97.634 +Epoch: [59][ 0/40] Time 9.296 ( 9.296) Data 7.916 ( 7.916) Loss 2.8221e-02 (2.8221e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [59][10/40] Time 1.352 ( 2.074) Data 0.000 ( 0.720) Loss 3.2042e-02 (3.0567e-02) Acc@1 99.90 ( 99.95) Acc@5 100.00 (100.00) +Epoch: [59][20/40] Time 1.351 ( 1.730) Data 0.000 ( 0.377) Loss 3.6304e-02 (3.0686e-02) Acc@1 99.90 ( 99.95) Acc@5 100.00 (100.00) +Epoch: [59][30/40] Time 1.351 ( 1.608) Data 0.000 ( 0.256) Loss 2.9484e-02 (3.1130e-02) Acc@1 100.00 ( 99.94) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.581 ( 6.581) Loss 4.9937e-01 (4.9937e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.827) Loss 5.2928e-01 (5.3628e-01) Acc@1 86.72 ( 86.28) Acc@5 97.17 ( 97.51) + * Acc@1 86.480 Acc@5 97.671 +Epoch: [60][ 0/40] Time 9.899 ( 9.899) Data 8.520 ( 8.520) Loss 2.9403e-02 (2.9403e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [60][10/40] Time 1.352 ( 2.129) Data 0.000 ( 0.775) Loss 3.1012e-02 (3.0274e-02) Acc@1 99.90 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [60][20/40] Time 1.354 ( 1.759) Data 0.000 ( 0.406) Loss 2.9814e-02 (3.0068e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [60][30/40] Time 1.351 ( 1.628) Data 0.000 ( 0.275) Loss 3.5166e-02 (3.0696e-02) Acc@1 99.90 ( 99.95) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.466 ( 6.466) Loss 4.9925e-01 (4.9925e-01) Acc@1 86.23 ( 86.23) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.816) Loss 5.2948e-01 (5.3677e-01) Acc@1 86.52 ( 86.27) Acc@5 97.07 ( 97.50) + * Acc@1 86.474 Acc@5 97.653 +Epoch: [61][ 0/40] Time 8.831 ( 8.831) Data 7.465 ( 7.465) Loss 2.7945e-02 (2.7945e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [61][10/40] Time 1.352 ( 2.032) Data 0.000 ( 0.679) Loss 3.0836e-02 (2.9632e-02) Acc@1 99.80 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [61][20/40] Time 1.352 ( 1.708) Data 0.000 ( 0.356) Loss 2.8180e-02 (3.0864e-02) Acc@1 100.00 ( 99.93) Acc@5 100.00 (100.00) +Epoch: [61][30/40] Time 1.351 ( 1.593) Data 0.000 ( 0.241) Loss 2.8866e-02 (3.1025e-02) Acc@1 99.90 ( 99.94) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.506 ( 6.506) Loss 5.0013e-01 (5.0013e-01) Acc@1 86.23 ( 86.23) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.819) Loss 5.3019e-01 (5.3687e-01) Acc@1 86.72 ( 86.27) Acc@5 97.07 ( 97.50) + * Acc@1 86.474 Acc@5 97.653 +Epoch: [62][ 0/40] Time 8.664 ( 8.664) Data 7.277 ( 7.277) Loss 3.2699e-02 (3.2699e-02) Acc@1 99.90 ( 99.90) Acc@5 100.00 (100.00) +Epoch: [62][10/40] Time 1.351 ( 2.017) Data 0.000 ( 0.662) Loss 2.8632e-02 (3.0899e-02) Acc@1 99.90 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [62][20/40] Time 1.351 ( 1.700) Data 0.000 ( 0.347) Loss 3.1006e-02 (3.0882e-02) Acc@1 100.00 ( 99.97) Acc@5 100.00 (100.00) +Epoch: [62][30/40] Time 1.351 ( 1.588) Data 0.000 ( 0.235) Loss 3.0478e-02 (3.0904e-02) Acc@1 100.00 ( 99.97) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.454 ( 6.454) Loss 4.9959e-01 (4.9959e-01) Acc@1 86.04 ( 86.04) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.815) Loss 5.3001e-01 (5.3695e-01) Acc@1 86.72 ( 86.26) Acc@5 97.07 ( 97.49) + * Acc@1 86.456 Acc@5 97.647 +Epoch: [63][ 0/40] Time 9.065 ( 9.065) Data 7.673 ( 7.673) Loss 3.1245e-02 (3.1245e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [63][10/40] Time 1.351 ( 2.053) Data 0.000 ( 0.698) Loss 3.3649e-02 (3.2033e-02) Acc@1 100.00 ( 99.93) Acc@5 100.00 (100.00) +Epoch: [63][20/40] Time 1.352 ( 1.719) Data 0.000 ( 0.366) Loss 3.3206e-02 (3.1617e-02) Acc@1 99.90 ( 99.93) Acc@5 100.00 (100.00) +Epoch: [63][30/40] Time 1.351 ( 1.601) Data 0.000 ( 0.248) Loss 3.1896e-02 (3.1131e-02) Acc@1 99.90 ( 99.94) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.483 ( 6.483) Loss 4.9997e-01 (4.9997e-01) Acc@1 86.23 ( 86.23) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.817) Loss 5.3003e-01 (5.3637e-01) Acc@1 86.82 ( 86.30) Acc@5 97.07 ( 97.51) + * Acc@1 86.468 Acc@5 97.665 +Epoch: [64][ 0/40] Time 10.172 (10.172) Data 8.780 ( 8.780) Loss 3.3712e-02 (3.3712e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [64][10/40] Time 1.351 ( 2.153) Data 0.000 ( 0.798) Loss 2.9179e-02 (3.0601e-02) Acc@1 99.90 ( 99.97) Acc@5 100.00 (100.00) +Epoch: [64][20/40] Time 1.353 ( 1.772) Data 0.000 ( 0.418) Loss 3.1574e-02 (3.0677e-02) Acc@1 99.80 ( 99.95) Acc@5 100.00 (100.00) +Epoch: [64][30/40] Time 1.351 ( 1.636) Data 0.000 ( 0.283) Loss 2.9972e-02 (3.0763e-02) Acc@1 100.00 ( 99.95) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.553 ( 6.553) Loss 5.0083e-01 (5.0083e-01) Acc@1 86.04 ( 86.04) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.350 ( 1.824) Loss 5.2899e-01 (5.3660e-01) Acc@1 86.72 ( 86.26) Acc@5 97.17 ( 97.50) + * Acc@1 86.462 Acc@5 97.665 +Epoch: [65][ 0/40] Time 8.884 ( 8.884) Data 7.502 ( 7.502) Loss 2.9734e-02 (2.9734e-02) Acc@1 99.90 ( 99.90) Acc@5 100.00 (100.00) +Epoch: [65][10/40] Time 1.352 ( 2.037) Data 0.000 ( 0.682) Loss 2.8531e-02 (3.0759e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [65][20/40] Time 1.352 ( 1.710) Data 0.000 ( 0.357) Loss 3.0440e-02 (3.0689e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [65][30/40] Time 1.351 ( 1.595) Data 0.000 ( 0.242) Loss 3.0986e-02 (3.0855e-02) Acc@1 99.90 ( 99.96) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.563 ( 6.563) Loss 5.0090e-01 (5.0090e-01) Acc@1 86.23 ( 86.23) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.824) Loss 5.3018e-01 (5.3702e-01) Acc@1 86.52 ( 86.22) Acc@5 97.17 ( 97.53) + * Acc@1 86.414 Acc@5 97.677 +Epoch: [66][ 0/40] Time 8.724 ( 8.724) Data 7.354 ( 7.354) Loss 3.0802e-02 (3.0802e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [66][10/40] Time 1.352 ( 2.032) Data 0.000 ( 0.678) Loss 3.0812e-02 (3.0806e-02) Acc@1 100.00 ( 99.92) Acc@5 100.00 (100.00) +Epoch: [66][20/40] Time 1.352 ( 1.708) Data 0.000 ( 0.355) Loss 3.1649e-02 (3.0861e-02) Acc@1 99.90 ( 99.92) Acc@5 100.00 (100.00) +Epoch: [66][30/40] Time 1.351 ( 1.593) Data 0.000 ( 0.241) Loss 3.0635e-02 (3.0882e-02) Acc@1 100.00 ( 99.93) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.500 ( 6.500) Loss 5.0002e-01 (5.0002e-01) Acc@1 86.04 ( 86.04) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.350 ( 1.819) Loss 5.2885e-01 (5.3698e-01) Acc@1 86.91 ( 86.26) Acc@5 97.17 ( 97.51) + * Acc@1 86.438 Acc@5 97.671 +Epoch: [67][ 0/40] Time 9.188 ( 9.188) Data 7.794 ( 7.794) Loss 2.6679e-02 (2.6679e-02) Acc@1 99.90 ( 99.90) Acc@5 100.00 (100.00) +Epoch: [67][10/40] Time 1.351 ( 2.064) Data 0.000 ( 0.709) Loss 3.0460e-02 (3.0430e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [67][20/40] Time 1.351 ( 1.725) Data 0.000 ( 0.371) Loss 2.9934e-02 (3.0648e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [67][30/40] Time 1.352 ( 1.605) Data 0.000 ( 0.252) Loss 3.3262e-02 (3.1244e-02) Acc@1 99.80 ( 99.94) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.498 ( 6.498) Loss 5.0023e-01 (5.0023e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.352 ( 1.819) Loss 5.3067e-01 (5.3723e-01) Acc@1 86.62 ( 86.23) Acc@5 97.17 ( 97.54) + * Acc@1 86.450 Acc@5 97.689 +Epoch: [68][ 0/40] Time 9.433 ( 9.433) Data 8.048 ( 8.048) Loss 3.3265e-02 (3.3265e-02) Acc@1 99.80 ( 99.80) Acc@5 100.00 (100.00) +Epoch: [68][10/40] Time 1.351 ( 2.086) Data 0.000 ( 0.732) Loss 2.9544e-02 (3.1097e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [68][20/40] Time 1.352 ( 1.736) Data 0.000 ( 0.383) Loss 2.8613e-02 (3.0776e-02) Acc@1 100.00 ( 99.95) Acc@5 100.00 (100.00) +Epoch: [68][30/40] Time 1.352 ( 1.613) Data 0.000 ( 0.260) Loss 3.6062e-02 (3.0564e-02) Acc@1 99.80 ( 99.95) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.546 ( 6.546) Loss 5.0005e-01 (5.0005e-01) Acc@1 86.23 ( 86.23) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.350 ( 1.823) Loss 5.3133e-01 (5.3680e-01) Acc@1 86.82 ( 86.26) Acc@5 97.07 ( 97.51) + * Acc@1 86.432 Acc@5 97.665 +Epoch: [69][ 0/40] Time 10.232 (10.232) Data 8.863 ( 8.863) Loss 2.9488e-02 (2.9488e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [69][10/40] Time 1.352 ( 2.159) Data 0.000 ( 0.806) Loss 2.7870e-02 (2.9907e-02) Acc@1 100.00 ( 99.94) Acc@5 100.00 (100.00) +Epoch: [69][20/40] Time 1.352 ( 1.774) Data 0.000 ( 0.422) Loss 3.0940e-02 (3.0478e-02) Acc@1 100.00 ( 99.95) Acc@5 100.00 (100.00) +Epoch: [69][30/40] Time 1.353 ( 1.638) Data 0.000 ( 0.286) Loss 2.8601e-02 (3.0579e-02) Acc@1 99.90 ( 99.95) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.423 ( 6.423) Loss 5.0018e-01 (5.0018e-01) Acc@1 86.04 ( 86.04) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.350 ( 1.812) Loss 5.3075e-01 (5.3755e-01) Acc@1 86.91 ( 86.22) Acc@5 97.07 ( 97.51) + * Acc@1 86.450 Acc@5 97.659 +Epoch: [70][ 0/40] Time 8.779 ( 8.779) Data 7.400 ( 7.400) Loss 2.9655e-02 (2.9655e-02) Acc@1 99.90 ( 99.90) Acc@5 100.00 (100.00) +Epoch: [70][10/40] Time 1.351 ( 2.027) Data 0.000 ( 0.673) Loss 2.7534e-02 (3.0493e-02) Acc@1 99.90 ( 99.94) Acc@5 100.00 (100.00) +Epoch: [70][20/40] Time 1.351 ( 1.705) Data 0.000 ( 0.352) Loss 2.9834e-02 (3.1003e-02) Acc@1 99.90 ( 99.94) Acc@5 100.00 (100.00) +Epoch: [70][30/40] Time 1.351 ( 1.591) Data 0.000 ( 0.239) Loss 2.8321e-02 (3.0638e-02) Acc@1 100.00 ( 99.94) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.957 ( 6.957) Loss 5.0020e-01 (5.0020e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.860) Loss 5.3063e-01 (5.3737e-01) Acc@1 86.91 ( 86.27) Acc@5 97.07 ( 97.51) + * Acc@1 86.468 Acc@5 97.671 +Epoch: [71][ 0/40] Time 9.488 ( 9.488) Data 8.109 ( 8.109) Loss 2.8361e-02 (2.8361e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [71][10/40] Time 1.356 ( 2.092) Data 0.000 ( 0.737) Loss 3.3129e-02 (2.9744e-02) Acc@1 100.00 ( 99.97) Acc@5 100.00 (100.00) +Epoch: [71][20/40] Time 1.351 ( 1.739) Data 0.000 ( 0.386) Loss 3.4954e-02 (2.9980e-02) Acc@1 99.90 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [71][30/40] Time 1.351 ( 1.614) Data 0.000 ( 0.262) Loss 3.2394e-02 (3.0643e-02) Acc@1 100.00 ( 99.95) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.483 ( 6.483) Loss 5.0031e-01 (5.0031e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.817) Loss 5.3062e-01 (5.3732e-01) Acc@1 86.91 ( 86.23) Acc@5 97.07 ( 97.51) + * Acc@1 86.444 Acc@5 97.665 +Epoch: [72][ 0/40] Time 10.296 (10.296) Data 8.907 ( 8.907) Loss 3.0989e-02 (3.0989e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [72][10/40] Time 1.352 ( 2.165) Data 0.000 ( 0.810) Loss 3.2618e-02 (3.0994e-02) Acc@1 99.90 ( 99.97) Acc@5 100.00 (100.00) +Epoch: [72][20/40] Time 1.352 ( 1.778) Data 0.000 ( 0.424) Loss 3.1270e-02 (3.0653e-02) Acc@1 100.00 ( 99.95) Acc@5 100.00 (100.00) +Epoch: [72][30/40] Time 1.351 ( 1.640) Data 0.000 ( 0.287) Loss 3.2131e-02 (3.0594e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.487 ( 6.487) Loss 5.0031e-01 (5.0031e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.817) Loss 5.3061e-01 (5.3722e-01) Acc@1 86.82 ( 86.21) Acc@5 97.07 ( 97.51) + * Acc@1 86.432 Acc@5 97.665 +Epoch: [73][ 0/40] Time 10.463 (10.463) Data 9.072 ( 9.072) Loss 3.0279e-02 (3.0279e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [73][10/40] Time 1.351 ( 2.180) Data 0.000 ( 0.825) Loss 2.6930e-02 (3.0245e-02) Acc@1 100.00 ( 99.98) Acc@5 100.00 (100.00) +Epoch: [73][20/40] Time 1.351 ( 1.785) Data 0.000 ( 0.432) Loss 2.9132e-02 (3.0048e-02) Acc@1 99.90 ( 99.97) Acc@5 100.00 (100.00) +Epoch: [73][30/40] Time 1.351 ( 1.645) Data 0.000 ( 0.293) Loss 2.9238e-02 (3.0441e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.514 ( 6.514) Loss 5.0038e-01 (5.0038e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.820) Loss 5.3062e-01 (5.3727e-01) Acc@1 86.91 ( 86.23) Acc@5 97.07 ( 97.51) + * Acc@1 86.438 Acc@5 97.665 +Epoch: [74][ 0/40] Time 8.801 ( 8.801) Data 7.424 ( 7.424) Loss 3.0766e-02 (3.0766e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [74][10/40] Time 1.351 ( 2.029) Data 0.000 ( 0.675) Loss 3.0983e-02 (3.0309e-02) Acc@1 99.90 ( 99.97) Acc@5 100.00 (100.00) +Epoch: [74][20/40] Time 1.351 ( 1.706) Data 0.000 ( 0.354) Loss 2.7376e-02 (3.0411e-02) Acc@1 99.90 ( 99.97) Acc@5 100.00 (100.00) +Epoch: [74][30/40] Time 1.351 ( 1.592) Data 0.000 ( 0.240) Loss 3.2483e-02 (3.0302e-02) Acc@1 99.90 ( 99.97) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.497 ( 6.497) Loss 5.0047e-01 (5.0047e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.819) Loss 5.3071e-01 (5.3730e-01) Acc@1 86.91 ( 86.23) Acc@5 97.07 ( 97.51) + * Acc@1 86.444 Acc@5 97.665 +Epoch: [75][ 0/40] Time 8.740 ( 8.740) Data 7.351 ( 7.351) Loss 2.8989e-02 (2.8989e-02) Acc@1 99.90 ( 99.90) Acc@5 100.00 (100.00) +Epoch: [75][10/40] Time 1.351 ( 2.023) Data 0.000 ( 0.668) Loss 3.0388e-02 (3.0511e-02) Acc@1 99.90 ( 99.97) Acc@5 100.00 (100.00) +Epoch: [75][20/40] Time 1.351 ( 1.703) Data 0.000 ( 0.350) Loss 2.9898e-02 (3.0060e-02) Acc@1 99.90 ( 99.97) Acc@5 100.00 (100.00) +Epoch: [75][30/40] Time 1.351 ( 1.590) Data 0.000 ( 0.237) Loss 3.0593e-02 (3.0128e-02) Acc@1 99.90 ( 99.96) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.562 ( 6.562) Loss 5.0035e-01 (5.0035e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.350 ( 1.824) Loss 5.3072e-01 (5.3728e-01) Acc@1 86.91 ( 86.22) Acc@5 97.07 ( 97.51) + * Acc@1 86.432 Acc@5 97.665 +Epoch: [76][ 0/40] Time 9.532 ( 9.532) Data 8.140 ( 8.140) Loss 3.1997e-02 (3.1997e-02) Acc@1 99.90 ( 99.90) Acc@5 100.00 (100.00) +Epoch: [76][10/40] Time 1.352 ( 2.095) Data 0.000 ( 0.740) Loss 3.1135e-02 (3.0363e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [76][20/40] Time 1.351 ( 1.741) Data 0.000 ( 0.388) Loss 2.8467e-02 (3.0590e-02) Acc@1 99.90 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [76][30/40] Time 1.351 ( 1.615) Data 0.000 ( 0.263) Loss 3.0943e-02 (3.0248e-02) Acc@1 100.00 ( 99.97) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.462 ( 6.462) Loss 5.0036e-01 (5.0036e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.350 ( 1.815) Loss 5.3055e-01 (5.3728e-01) Acc@1 87.01 ( 86.23) Acc@5 97.07 ( 97.51) + * Acc@1 86.444 Acc@5 97.665 +Epoch: [77][ 0/40] Time 8.648 ( 8.648) Data 7.274 ( 7.274) Loss 3.1342e-02 (3.1342e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [77][10/40] Time 1.351 ( 2.015) Data 0.000 ( 0.661) Loss 2.8359e-02 (3.1518e-02) Acc@1 99.90 ( 99.95) Acc@5 100.00 (100.00) +Epoch: [77][20/40] Time 1.352 ( 1.699) Data 0.000 ( 0.346) Loss 2.5942e-02 (3.0884e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [77][30/40] Time 1.351 ( 1.587) Data 0.000 ( 0.235) Loss 3.1415e-02 (3.0260e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.537 ( 6.537) Loss 5.0023e-01 (5.0023e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.351 ( 1.822) Loss 5.3044e-01 (5.3722e-01) Acc@1 87.01 ( 86.22) Acc@5 97.07 ( 97.51) + * Acc@1 86.432 Acc@5 97.665 +Epoch: [78][ 0/40] Time 8.825 ( 8.825) Data 7.440 ( 7.440) Loss 2.4183e-02 (2.4183e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [78][10/40] Time 1.352 ( 2.031) Data 0.000 ( 0.676) Loss 3.1926e-02 (2.9425e-02) Acc@1 100.00 ( 99.99) Acc@5 100.00 (100.00) +Epoch: [78][20/40] Time 1.352 ( 1.707) Data 0.000 ( 0.354) Loss 3.1820e-02 (3.0009e-02) Acc@1 99.90 ( 99.98) Acc@5 100.00 (100.00) +Epoch: [78][30/40] Time 1.351 ( 1.593) Data 0.000 ( 0.240) Loss 3.0308e-02 (3.0179e-02) Acc@1 100.00 ( 99.97) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.499 ( 6.499) Loss 5.0037e-01 (5.0037e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.350 ( 1.819) Loss 5.3065e-01 (5.3730e-01) Acc@1 87.01 ( 86.25) Acc@5 97.07 ( 97.51) + * Acc@1 86.450 Acc@5 97.665 +Epoch: [79][ 0/40] Time 10.329 (10.329) Data 8.953 ( 8.953) Loss 3.0395e-02 (3.0395e-02) Acc@1 100.00 (100.00) Acc@5 100.00 (100.00) +Epoch: [79][10/40] Time 1.351 ( 2.168) Data 0.000 ( 0.814) Loss 3.0754e-02 (3.0475e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [79][20/40] Time 1.351 ( 1.779) Data 0.000 ( 0.426) Loss 2.7916e-02 (3.0284e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Epoch: [79][30/40] Time 1.351 ( 1.641) Data 0.000 ( 0.289) Loss 2.9172e-02 (3.0278e-02) Acc@1 100.00 ( 99.96) Acc@5 100.00 (100.00) +Test: [ 0/17] Time 6.559 ( 6.559) Loss 5.0048e-01 (5.0048e-01) Acc@1 86.13 ( 86.13) Acc@5 97.75 ( 97.75) +Test: [10/17] Time 1.350 ( 1.824) Loss 5.3059e-01 (5.3730e-01) Acc@1 87.01 ( 86.25) Acc@5 97.07 ( 97.51) + * Acc@1 86.456 Acc@5 97.665 +Final best accuracy tensor(86.5712, device='cuda:0') diff --git a/log/n60_cb_retrieval b/log/n60_cb_retrieval new file mode 100644 index 0000000..2bc95a9 --- /dev/null +++ b/log/n60_cb_retrieval @@ -0,0 +1,226 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +65 +66 +67 +68 +69 +70 +71 +72 +73 +74 +75 +76 +77 +78 +79 +80 +81 +82 +83 +84 +85 +86 +87 +88 +89 +90 +91 +92 +93 +94 +95 +96 +97 +98 +99 +100 +101 +102 +103 +104 +105 +106 +107 +108 +109 +110 +111 +112 +113 +114 +115 +116 +117 +118 +119 +120 +121 +122 +123 +124 +125 +126 +127 +128 +129 +130 +131 +132 +133 +134 +135 +136 +137 +138 +139 +140 +141 +142 +143 +144 +145 +146 +147 +148 +149 +150 +151 +152 +153 +154 +155 +156 +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +32 +33 +34 +35 +36 +37 +38 +39 +40 +41 +42 +43 +44 +45 +46 +47 +48 +49 +50 +51 +52 +53 +54 +55 +56 +57 +58 +59 +60 +61 +62 +63 +64 +Number of KNN Neighbours = 1 +training feature and labels torch.Size([40091, 4096]) 40091 +test feature and labels torch.Size([16487, 4096]) 16487 +KNN retrieval acc = 0.7616303754473221 diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/model/__pycache__/__init__.cpython-310.pyc b/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fc5fb7b580f71aa7b7f3020190ad8b802b104a2 GIT binary patch literal 152 zcmd1j<>g`kf~h@eDIoeWh(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o6vSKR2&LKP|Di zL_aS-DKR;_v_QYKIJKx)Ke^n%(9FR$kC6F&&fjOe>i~w2p)q77+?f49Dul(1xTbY1T$zd`mJOr0tq9CUsn1V`MIh3l@-Yu zsd@Uv*{M0HCHZ;riOD6IK&mJ;IX^uw6U5ff%}+_q(T|VM1oGqK^$IF)ao7OGN^?@} KKxP&LF#`ZRs3Wle literal 0 HcmV?d00001 diff --git a/model/__pycache__/ctrgcn.cpython-310.pyc b/model/__pycache__/ctrgcn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0c71fdefb9978c1db1138794e6e48ad77ce3343 GIT binary patch literal 9737 zcmb7K>vJ4eR`2)pj7Fo;cq~8SBubpePS~|$CnVVr9A}*en@vPmi?GZl46W&wH1bT3 ze7nbSv?Ev|gS`~mEm(FbV1dy>0WK(tD&R9;fWKj@zX6|AK-E(D0AxeV?{~U;G_tKE zw5spiKKI$@p2zQ;R)2cBq~ZF`Z!G)YeN5B-l@F7bg@^OFqklo+Yrfvqn(43J)aBV| zn(}KkE%~*3PSerZ$LJP%Zqr5G^sQ@Jv*_EtgWr;0@Ll{)`9;5k-?BgDm+?FOisn!I zm1~+`37oacmf4&EWyY_9Qq3q;P-guaD7B0-3(6tC4oW?v)Id4x&w(vFW`=zLNU}v;L47^ zsmJx9IJ+vSfSSol2W^mwlE0eC^36|Ke0%-si`vN_~7%WdtrQf z*^A=SVSmYMUmC2R9z=nPPPgBA{LE8lo<7~~`@!j#u6U<0Ht3#i$7-b=o?3r5DQ9c8 zyIvF(F$VcB=f$gNzO3S;*zSi>?1eGNOC2w2IAmjZ zC2^OO7q(Y}NL4X!qKAob=?w0(!rru)W=oabTWRkW7n%h%;u$!4N&X(u>$XEW!A@B~6sSoXqJe-Pi)@s2oD zVN%19H&8V!p$?%)W-fW1UME~>WeibA*>nU&Ls!RG=k6z7kxaTG?{*vwnKab!Zgt(d zrA~sjPa1e2<2QQ<53#nUj|}YD9LgDvpRK(!y=kOp+OYR_Dtk(I&n-5L#8?WU8*Q%} z+`h-cQYbro039Q)UM1HbMP(n?_wDTmXeF1|ps+*ikh1CUAG&--Q;XBv0 zWt_inUMs?l;k4m~tm_cvrVX1@NXm-=jj`7yTgCRYY54N1fQpOT#2qc6h&6ZtZOa%L zI~r8j9Gbp<6b6oET{{9#aTI>>xVGuU=9)Fs=@oo^*%&&$i4w|z61|rV-`a)z?mJ*Apj-RmTq!*Y9+_xYG|CE;MPd z9w_xB2J>`jyjm?e&cxw6p7q5XBvWG<0p(%z1~aTE=q@Kyov_tj^+K!{C1seN@ss*- zZs}@@?K>ame-y zoQkqdM+{p-z1Bw1ZK=Vf-_#cxP-^s$_qb`OpF};;uih?ch2_3_$5Z|nI3rJ)rrgpu z>iag6b;6W#3M_6q3eZWT>xKy)bg3`pR_pS>>!#Wx@2p;8pXp4w+5u$0HL-;k(5i$% zJR9D!2Nx`=PRn`SA<{NF@ru}cEYc9?6&SMb>kwyS*%kpO?xKLFxNA?qi^-AK2i>@H zk><5E`JTJvt-NBxnvl0Ifx63gUTjnD^NV$!Fb73nfl$WU77jslzyS8Kw*U%`XmgPL zOGb3)RT^6TJ$^JX5`zM;hc23vp#mkjkd`mio@VC?BvwfFR}Ff>;(T3rLW zaz=&g+I9W9v79Qg8@tdwd&_tcXiRj6#=1Bvq32ZWtd(!*H#FaY;up3p(57zydVpT> zUB9@^+KgYi0rLPo1UB8~3esoCrs2K&Tv`9=ZF}4sdW+sp^i99w&s@W}Ry90w16q$a z`yTuA*RPJ-it#L4hXu6FZd>dRVXno8a&zCPE$k`&#^WdFp?~waHqJ+@{Xy5CUkc`X zJk|AX5Y6|O=TFY(z-FIct8^S1Ruo={!RI4)kAqISs6J*oPjKI3Q#=djo5Z)JYzfWCc;G!NvMa&5- z?u&VS&+I1db6u~uh*>Otf+K|MObu@Q>fh5W7+|py|qA8P~?=pne{-_egJ| z$c2MWqfu9%C!G!=F}p$7EQf<$2D5`GvEp~vgTx|xGeLP>tnF(cB<4nET^;90$6H?y zLSMbXxyJINeuizu*S)K6^!weje&VVi>i96xsV}rKs{5`$hlVR2A{C1@*B7#!pf3ifio`nRWV$;w6m@UszM7NqboeaZ^50Rp4K)($i8pNFbbK9N?=Uk9@D(R~oL$RC;Hk_C3+yI5sOmQazk*5 zJ!lQ~RIyoBfVsJyZ|oZRCE1zXj~H{(sY+0?F%z~2S3VoI<13##hjP#;sy-T#r3k!Gt+U@{7BPv% zK^Vd32O9`YNEAEb1c)RDy+&cel;|%cw%1)>^%CPMQZrZ2HjS%KOep)6EW~2%gfHdE+doI?#4d7Qf3m$$NtG+F?c7ffnQaN6Dvo)8SB&<20Y4Q+adMr6}$pQ1vDrGjKDVdhcImht!34!qicx zURWY9(XNNtMf|Y^jT)qS_b7;WsNA0CRp&YAF6#BSphv`wh?9^4fO6d`n@kQ=^^Mt_ zt7Bc-=Oqb(Xop3W>Zph+Wx-<}J_?!va+6KVJFxrRMDFiF&(cs3w(5B8J$W0S?ZG~Z z?E4!|l*?IlN|AP;oT*AkA;=^ z`0Xw9#@n3ro^0F++B)ffZ@wy8??e559~n8+=i}nq1I)XL20REcIGH)PV*Jpsh}J4x z*(iQ^SQH2M?Tk7L>R-f<42wSA84m8-sB3HY`N#-n8dILgS|5m<)mL#op8_@ifZduX zHcpCgF?XNmMkeErPXOxb=Qz(pEW}eC1k#gyxtD?W;QKBQIskI573KNr3ffY=U)lbT zzZ?D6KmGN&#-phL5NuBa&_tvc2ddR>+L9$|A}8N!C$8sv>v87_@`KT;2Si^Y%8v$1 zQIMI7U&P?lTStkFl>ADNqD#R8Z=i7x)+3+TeNIAkWNgWEXnN;B?C_5m6rD#=VMLJq5hqk& z>MX&-Wwf{-mF<$@AeVN+s0o^==w)LMQ;bdFq{Qqqg|s$fQ^>9Q{FZ|}xnkl(-!wyN zxG@5n7+E{S3ua7!2tOe}5a?m8FtlJJfFI-+nd&VQVMB@}KFcuTh!z5K;1k#!G^|MN z=l3Gi0c3Y{1R1PP<@79&Y?M>heG8jPPs?f#d@m{C+gQsFu5{Xgy1}+zV{x6u_#73Z z2syaiv%CmGqQ5H>so#=^jl0?Me?yZbC@p#vy=qkSqkwBynYHrCdA;)Dg@x8DxM?W_Pb}{8fW>#WiM{iS zb;c_r9A8EjavX)LWl%(yQWJErVf~br z&Vox!4eO(mZ2K7Fnj$m|8(*yRW>)6Y8uUWGu6hfjAnB6-V8D=uJ(P~bf|Q++=oCQ< zs{}g}Cdua`Ap#J##N6Y6C6)50*)E;P$drr1z!&%^9zaktP4@rd&hqlKBL=`!1@Ska z6%=63L|%SEs>bZuXGthV#2Obk)k7$t)4U-TX=Fl=OFOcIgm7vBa!t2V8~ghhq<)VD zca1dRIFq2hL#k|fZ{IJ0Rt6&ldd=&AUiE#u{KT{X`TZg8OfOI9)fq+F9eQbj$SZmb zMGozYda>Y_1>PffED3>cN(%gj6#S^mU zCou~89GxV!F9Z^pVq z4S=`RcC@$PGdJ`2T_SsN1W4KJ*!<84ZNdjc`7a6{l7*<8I)>y?8pP!7^voE6loAZJ zi8=O-rx-vSki(bJh#zG%vjTaokk7~re}-)3K)*zW@Ib}AAb~oIdQt!rdxKd1Z3XH4 zl|g@iEbf&~@Q(UZ_Wd&!f6n42iv4ju{ypYGqAw=;n~DDV-Bi9oUNdu4(?5e`%9HTR z^1aU3m`*vIJLZD=11|B0EdGdvd^jDk#%+I$!=_Zuav7=~@`cP4-l`r^q&Mc&Vf^45 zr@xfgf;81O2r2$_<;(Y}Fcf3mtfY_LAnrq&{SBduDT_=q;n#JGZo#!{m0G>FP`kG_TYF~i G@c#i3)cDK* literal 0 HcmV?d00001 diff --git a/model/__pycache__/ctrgcn.cpython-36.pyc b/model/__pycache__/ctrgcn.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46b2a848df4447da2872323f373f0b915c60e267 GIT binary patch literal 11458 zcmb_ieQX?OTA#0-on5cj_GaTaZZ373G>wm5+Bj*Nav@EcTza6^6k<}e^lm%enYDMl zo^{@tO&#C*)T`TT_{gaU5O)Fu5(rv8j^oftRsZ16`wJw5_~SGZ2sngfA*2H#f$k6B z_j_h`cAdB%2dq8syz_p|JJ0*PzvuTnGiS!fi?{yr+poNPLDRmgjr=U+FX0IP4vD9E zdPl1zXT7G&wNW$WY}G6|=h}A7M$YKu*PWW9Yo6&@Z)mlGm-B3#i(cMya31ptUJ>V# zH|CXa9)C{r#=Y_znpgJic6rCFO`v4Ln?%WES~7`}DX)T(N?J07k^^29CDpW~f|7&Y zG)kt^k^?9?Eg<{#L_Q)-LzdhFbARm)xO_zj&!_)!u( zuHXoN7)f7iLmoT&wjSwy-7^kr$QzqSk%vUy#PwC}bLDL_GWtf}+%dL|o5x?$`UdV} zEynsO#gTWJ%$sp8+~~BTxLB{Z);GE;s@LQEiXXXAq~@*I3fy%+E;hQT3;ZAoV{^Uh z#TE;%eErXfgjdy84~R8t4O4Ok zWMN6pL@pN@ZL^JQPmiDT67x;d?(bfv;NTjzRmqd>|<1U^UdF&$8os+tpTtjzKSwL^{O*g1_dl7Wh z;zEjW=53j|s^DJg0FroO&26o>f|YtI2sOj1LrCUzb(ne1dgRe4oLlltv#1CiB-DTk z4)rkd+4^${OdPWQ$B~J&Hl71BHT$w^*|jyYibk@E^SR+1Wj4vYIk|ZwHkJcygofMk z@10$KIgq(MgoYudsp*rtdIaZ8&YV{!=Rd;*<=V#X>|=J+OmA$>RLuzT_8?KU5Ls>L zB3k7=%ga&vQfH&?eUv(~4(~?==A&YJjOEg&B(-_2O0XoOI~cu$(cjNT|Ei43m6Tk2 zyj>>VvCn>Y;fQE6;pKPqbgjo1Yt|L4i70*|rU&s)2fh3eOz`}?h5erET-@lz+*Z!#{_>$1gxEl#~oNI}kOwgNBCueu?zW@L3J>f!3d z#a7so-P7>nT+~&K)!1J0gRrY)9b?OLBUe??Db6)polYI=sjp+t5mb1udzZA7Y~4`S zl`{;dk!ZRD#G&3jEZ9@UI#+eIfSLxY2S3pPT!ifDYqUi*!?WMgnm}UDe4_xXgjI!I zvEG7O)N(+Qd|X=ciLP!Z6KsQU#vdV#N>9}emytvoU|riW+QyECjcN8xPoDu~F|BHc z0R1yCaW5;WCT!Q$#gPqdb@Rd4)eq2w&QN1n;7~6cl1y445 zJeKJPN{;a^c>to{X~tu%px#(@1Kcl+O8}MZN}c7D&M|q4&G+45I5whoR8{M2BQd$5H2Rlw`kw?uF-3lqoS47%J@P zp0xrM_Kapu)Y|s)qSlUA*rV3P_|OZzPSm<*i*g`2dStH| z>J_wMUdBM*ATnRqgL2iHyrsRRzhyL&t?B@3P2APEN;<$@QpDH*UbHOX z7{8<6(L5Wn$=|hbUA}{MI8S(x$X(_py{S6@0qnSnH+GkIkTyGb5;3j&jFq-8-OXjK zu|v_?ep?TZCWI6-nF4T}y@}}32?0GWQ4hoo;X=Q6b0r9wwL-x0L4=KkHJOC72Lgf$}ln_Yg zLdRWS_S}oFu_q2~YvBiZlBDwo&!bPqtS^4Z6x5I7rmqj~J`e_ljy%OM5RxFHXu?57 zd=mVN&5j?$<;!8{E1IFtDAiT9QqWsZS+O6+R&;yAk1a|))r=PeP@hLZY;LtS6d7i0 zyBiyR;He+w*qOp8YIIz9!QFhZ+wClPv7`L3xty=>X2uj(bElG~d5sCCV(Ak6@!eFcZ?XaGlI1&tK|<>nl&b@-Cg@TfpE z=uh=cFm8dho6|3WQG$TbV8Dk!U9N9DrSS}1Fc*{c)ML|Yt%H739I^oF26K-giOq%^ zVH?w=iq^&^#8n$hww6J3PZ#PvODEmra8(L1F7q!~rffD(zEQPv$}>%k9_ z#J~(ErdQEsD+*y~YSmR^|O!iwKruTunDrQ#fZH=dN!F72oYxsl^N z@Z*9qPpA&RhvXhEN!tcD87Y?2(_tY$$p8W@=XnqANj(P#S*zPY6|B;nrG3dE2P?}` z*eM4JLANo-kjh>5DC}EW94RG}NZQh#a)iXW9F&TM-QgX|n-@?!vBU7i6L8%p`?QS? z+@HUSrc$MQFD3lUol{(A0 z*M|GJR#G6q1FKLgt+%}FN)SchBF=FwPA8{@;1Y23){49?2xgiaf$Mkby!8q9H~E)WQ{}zoX;ImMcaMbW=0n zOl{}4shn2dhI+#r?&ooae?7zK8ohP9&psw1M}~Ol9YmROq^3z4fI34wc0Z@SB8qN% z`T?z7@CqcAq_^ALEJbWFyYK9n(()bT3r$4k`u6oifr}#o71wNQB53q(Jr*>gTc5g! zw72_TbgjF1I1%eGH$SEtC`EY2cLU|IfzM=>g@s-af+qP}zG||N_^YN_V}`x;*!JOM z_txi&Loiky*5=&K#;S|B6=H&FjU~n=LR6azIL^dI=iX42K=ONN7m}Q5xl{VA@kz_l zXP|=?L3w2T1~3fbj;6^bwGz9rcnwzeBug`Jj`Q`p*KNS>kzE_G@-maFO!kLrZt}0; zkg$?t!g?NbY{gPaD*I7o_}lB$=LMpn`QrT~LFA!R@5D}zo^R9$l#9OiBQZUDlcPpR z_98JtG2+#G*hw_24DXD1@=&XRCy#vxyaf7M?R@IPJIG5=`>yyZ(HNDJY(Gi@M8gm! zb|)WgNdNC+mSEUIMCZp7RQv*)Q@*j;3KANXaHPYm5coRAoUkPekkw0o?GY5-hiOAz zq@H2@zJB~)pe8Zj&=AI4%0>w}xXx1p`Q8sz20Br+SJ=6V1R~%ep8H_dW zC6d=?gCo%c_tsEUcRA{jwBILizoh-q!Cg`NAo6owNyd2y`G*-XmCbc*^I6>YxKNUF z+r{XSei8LkJ9NhvqZ9oiw_PxkmV=a_d?9+YU&MVtOWxLRE`TLJiZLGa5Xa3Ttiv}8 z|7Vx#->|faLw8J6mhV;u@vc!{^rJkukLgK@8~5*ndtQTs4&rjN*8*{`uPE0OtQnxn z%H4nY)^GjCx4&^wjMV(e1Y^F4-m+n8@a-dC)f=^(!~kn}hf;6Ej_bJ_QR|i;7Z4hz zXWwObe$|C{SER=m_Lf6G1*atpmXNP7&f#s!il6uwLhtsvcnB4DUk&2Ii>^Xg7&lJ% zNU^EFuQ4v(NZD>EPNWfR;Dtx0>qci_=w7_62so)hgaz$}ikL1z%{YI(+l^xM+9zYy%8hQbS@+kMeTllN*Vyw*OvH&Sr(zmJO!d}>@WH=D zw~&mIe9#d_m_$X%04L!@!U~If2^8i8Eh`>1}I=7TMgqh`6HIRq;mP7s38o06$PbIpf6% zNE}>VA%itKzlXE{&M{uMo&w+j5(#R2L5X~kL=(<2^J9Z@U^c?brDVB(flt{%f_H;G z&y9S%U@k9N#66bxd9+4_elEO`7;0P$9neh^$gO|8JF+Sa3^ z()$p5lCWNMlONAKvhd``3=5te+!s%L7X2mYiH&id`Cnl|N-CNXTTsf=5@~^(%`E+l z5!9trgjWrGnY~z>gn6aWYiudE6LEeCm-|!H6P#WZ)fxb@Y7$5pi7R2tluW#88d0=N zR2lYAvB>JK;0TW+0Rr&IoKeyad4-L49tmC~xL_q#|e%S_zOw5_v~L72jAS$ z7!l&fIL(6aT7HhhAlSHkkfV<<+H#yj2-olIa->~xK^SdFVm!j3z;OpVxr~}@C;t;B zxpyaHcWG;mIvQaDcqvqk5YvRXi~V8~FIT_9T}!~noB3@CRPoI{{PltdWAUcZ#8)-j z3{@5TrSxxX1+U1k`;BexU#$*UM_P*T#+`V>}#BdQ8NKP4!+#agHeA6 z?T_^tlhNwO;Sc9DwvX?Q{8#zl@fd=Wh^qSmj<`U!EFVH7*^TbNsULYce;Qw5@EeNyav*yd zL7mg;Hb*u(jPRVr@g$COIG)0B9>>#oA9MPN*nYkh_^yhLW<$RIkNll89<}bOl`sj?{wdqO#)PXG z=bp<%QOjMfLk_jlvK#t?pzZ!>eu>L*9{qvlh6C#30LdK$sAfte%2goyv{^+UqH4_I z!-6BmQKoi9KMIO$1NpbgdzZ06i06-9Hi1!~{)@Ox1iFaGA3~_G`Oh-{vbk+;TcC~n zatTC|(2rQ8C2vScaV+F>cpVOa1Uey}(a!H^FXzz9MnA?VK~cf7@qvf5&Pi*I<@EE| zmBbwj#1wl4!FJqR#z^adK-Nx9`WxsE-x^oEo7A&6W|ZHtq<10hosw56!!DH(;p1ny z@hn%)a`67sc9U5S{(f3kNptj;Pxkd~+zH=8kt~BICv$-Z?Z`az_5mU>Qr7dS@NzxFuH2*hmsOJu%45PpcQ1e;r9|-D>%-ib}?LsN;GM zI|-ja!Vg{Pb$g+@$(H;8PyH^F-($l4qTXThE|cG9@&`=*kjWo0`8tz7X0pLVbooz^ zL)bQpOV!HCk9sPoH+>gpKi2WV=2FeLaV6IAeh&wwmcpHeJ?u!#=%w#QqJbK-a_B3a^&JMxX!sUS$+;g*w2w-V@)MOam8r^SE3^Lt61iCU literal 0 HcmV?d00001 diff --git a/model/ctrgcn.py b/model/ctrgcn.py new file mode 100644 index 0000000..217f54a --- /dev/null +++ b/model/ctrgcn.py @@ -0,0 +1,307 @@ +import math +import pdb + +import numpy as np +import torch +import torch.nn as nn +from torch.autograd import Variable + +def import_class(name): + components = name.split('.') + mod = __import__(components[0]) + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + + +def conv_branch_init(conv, branches): + weight = conv.weight + n = weight.size(0) + k1 = weight.size(1) + k2 = weight.size(2) + nn.init.normal_(weight, 0, math.sqrt(2. / (n * k1 * k2 * branches))) + nn.init.constant_(conv.bias, 0) + + +def conv_init(conv): + if conv.weight is not None: + nn.init.kaiming_normal_(conv.weight, mode='fan_out') + if conv.bias is not None: + nn.init.constant_(conv.bias, 0) + + +def bn_init(bn, scale): + nn.init.constant_(bn.weight, scale) + nn.init.constant_(bn.bias, 0) + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + if hasattr(m, 'weight'): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if hasattr(m, 'bias') and m.bias is not None and isinstance(m.bias, torch.Tensor): + nn.init.constant_(m.bias, 0) + elif classname.find('BatchNorm') != -1: + if hasattr(m, 'weight') and m.weight is not None: + m.weight.data.normal_(1.0, 0.02) + if hasattr(m, 'bias') and m.bias is not None: + m.bias.data.fill_(0) + + +class TemporalConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1): + super(TemporalConv, self).__init__() + pad = (kernel_size + (kernel_size-1) * (dilation-1) - 1) // 2 + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=(kernel_size, 1), + padding=(pad, 0), + stride=(stride, 1), + dilation=(dilation, 1)) + + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class MultiScale_TemporalConv(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + dilations=[1,2,3,4], + residual=True, + residual_kernel_size=1): + + super().__init__() + assert out_channels % (len(dilations) + 2) == 0, '# out channels should be multiples of # branches' + + # Multiple branches of temporal convolution + self.num_branches = len(dilations) + 2 + branch_channels = out_channels // self.num_branches + if type(kernel_size) == list: + assert len(kernel_size) == len(dilations) + else: + kernel_size = [kernel_size]*len(dilations) + # Temporal Convolution branches + self.branches = nn.ModuleList([ + nn.Sequential( + nn.Conv2d( + in_channels, + branch_channels, + kernel_size=1, + padding=0), + nn.BatchNorm2d(branch_channels), + nn.ReLU(inplace=True), + TemporalConv( + branch_channels, + branch_channels, + kernel_size=ks, + stride=stride, + dilation=dilation), + ) + for ks, dilation in zip(kernel_size, dilations) + ]) + + # Additional Max & 1x1 branch + self.branches.append(nn.Sequential( + nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0), + nn.BatchNorm2d(branch_channels), + nn.ReLU(inplace=True), + nn.MaxPool2d(kernel_size=(3,1), stride=(stride,1), padding=(1,0)), + nn.BatchNorm2d(branch_channels) # 为什么还要加bn + )) + + self.branches.append(nn.Sequential( + nn.Conv2d(in_channels, branch_channels, kernel_size=1, padding=0, stride=(stride,1)), + nn.BatchNorm2d(branch_channels) + )) + + # Residual connection + if not residual: + self.residual = lambda x: 0 + elif (in_channels == out_channels) and (stride == 1): + self.residual = lambda x: x + else: + self.residual = TemporalConv(in_channels, out_channels, kernel_size=residual_kernel_size, stride=stride) + + # initialize + self.apply(weights_init) + + def forward(self, x): + # Input dim: (N,C,T,V) + res = self.residual(x) + branch_outs = [] + for tempconv in self.branches: + out = tempconv(x) + branch_outs.append(out) + + out = torch.cat(branch_outs, dim=1) + out += res + return out + + +class CTRGC(nn.Module): + def __init__(self, in_channels, out_channels, rel_reduction=8, mid_reduction=1): + super(CTRGC, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + if in_channels == 3 or in_channels == 9: + self.rel_channels = 8 + self.mid_channels = 16 + else: + self.rel_channels = in_channels // rel_reduction + self.mid_channels = in_channels // mid_reduction + self.conv1 = nn.Conv2d(self.in_channels, self.rel_channels, kernel_size=1) + self.conv2 = nn.Conv2d(self.in_channels, self.rel_channels, kernel_size=1) + self.conv3 = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1) + self.conv4 = nn.Conv2d(self.rel_channels, self.out_channels, kernel_size=1) + self.tanh = nn.Tanh() + for m in self.modules(): + if isinstance(m, nn.Conv2d): + conv_init(m) + elif isinstance(m, nn.BatchNorm2d): + bn_init(m, 1) + + def forward(self, x, A=None, alpha=1): + x1, x2, x3 = self.conv1(x).mean(-2), self.conv2(x).mean(-2), self.conv3(x) + x1 = self.tanh(x1.unsqueeze(-1) - x2.unsqueeze(-2)) + x1 = self.conv4(x1) * alpha + (A.unsqueeze(0).unsqueeze(0) if A is not None else 0) # N,C,V,V + x1 = torch.einsum('ncuv,nctv->nctu', x1, x3) + return x1 + +class unit_tcn(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=9, stride=1): + super(unit_tcn, self).__init__() + pad = int((kernel_size - 1) / 2) + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(kernel_size, 1), padding=(pad, 0), + stride=(stride, 1)) + + self.bn = nn.BatchNorm2d(out_channels) + self.relu = nn.ReLU(inplace=True) + conv_init(self.conv) + bn_init(self.bn, 1) + + def forward(self, x): + x = self.bn(self.conv(x)) + return x + + +class unit_gcn(nn.Module): + def __init__(self, in_channels, out_channels, A, coff_embedding=4, adaptive=True, residual=True): + super(unit_gcn, self).__init__() + inter_channels = out_channels // coff_embedding + self.inter_c = inter_channels + self.out_c = out_channels + self.in_c = in_channels + self.adaptive = adaptive + self.num_subset = A.shape[0] + self.convs = nn.ModuleList() + for i in range(self.num_subset): + self.convs.append(CTRGC(in_channels, out_channels)) + + if residual: + if in_channels != out_channels: + self.down = nn.Sequential( + nn.Conv2d(in_channels, out_channels, 1), + nn.BatchNorm2d(out_channels) + ) + else: + self.down = lambda x: x + else: + self.down = lambda x: 0 + if self.adaptive: + self.PA = nn.Parameter(torch.from_numpy(A.astype(np.float32))) + else: + self.A = Variable(torch.from_numpy(A.astype(np.float32)), requires_grad=False) + self.alpha = nn.Parameter(torch.zeros(1)) + self.bn = nn.BatchNorm2d(out_channels) + self.soft = nn.Softmax(-2) + self.relu = nn.ReLU(inplace=True) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + conv_init(m) + elif isinstance(m, nn.BatchNorm2d): + bn_init(m, 1) + bn_init(self.bn, 1e-6) + + def forward(self, x): + y = None + if self.adaptive: + A = self.PA + else: + A = self.A.cuda(x.get_device()) + for i in range(self.num_subset): + z = self.convs[i](x, A[i], self.alpha) + y = z + y if y is not None else z + y = self.bn(y) + y += self.down(x) + y = self.relu(y) + + + return y + + +class TCN_GCN_unit(nn.Module): + def __init__(self, in_channels, out_channels, A, stride=1, residual=True, adaptive=True, kernel_size=5, dilations=[1,2]): + super(TCN_GCN_unit, self).__init__() + self.gcn1 = unit_gcn(in_channels, out_channels, A, adaptive=adaptive) + self.tcn1 = MultiScale_TemporalConv(out_channels, out_channels, kernel_size=kernel_size, stride=stride, dilations=dilations, + residual=False) + self.relu = nn.ReLU(inplace=True) + if not residual: + self.residual = lambda x: 0 + + elif (in_channels == out_channels) and (stride == 1): + self.residual = lambda x: x + + else: + self.residual = unit_tcn(in_channels, out_channels, kernel_size=1, stride=stride) + + def forward(self, x): + y = self.relu(self.tcn1(self.gcn1(x)) + self.residual(x)) + return y + + +class Model(nn.Module): + def __init__(self, base_channel=64, + num_class=60, num_point=25, num_person=2, graph='graph.ntu_rgb_d.Graph', in_channels=3, adaptive=True): + super(Model, self).__init__() + + if graph is None: + raise ValueError() + else: + Graph = import_class(graph) + self.graph = Graph() + + A = self.graph.A # 3,25,25 + + self.num_class = num_class + self.num_point = num_point + self.data_bn = nn.BatchNorm1d(num_person * in_channels * num_point) + + self.l1 = TCN_GCN_unit(in_channels, base_channel, A, residual=False, adaptive=adaptive) + self.l2 = TCN_GCN_unit(base_channel, base_channel*4, A, adaptive=adaptive) + self.l3 = TCN_GCN_unit(base_channel*4, base_channel, A, adaptive=adaptive) + + bn_init(self.data_bn, 1) + + + def forward(self, x): + N, C, T, V, M = x.size() + + x = x.permute(0, 4, 3, 1, 2).contiguous().view(N, M * V * C, T) + x = self.data_bn(x) + x = x.view(N, M, V, C, T).permute(0, 1, 3, 4, 2).contiguous().view(N * M, C, T, V) + x = self.l1(x) + x = self.l2(x) # N*M,C,T,V + x = self.l3(x) + + return x \ No newline at end of file diff --git a/options/__pycache__/options_classification.cpython-310.pyc b/options/__pycache__/options_classification.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6f3cdeea8c099489663c3432ca36f5d89331fe8 GIT binary patch literal 4151 zcmbVPPjlNu6qjU0{wF(snrYfjXDAIx>J(c^2;tI!K$vOa;6n))PiJ-%ubilnC9#s* z#KR4~0&vL}$Y)@f;dAJ~g#+UYmvEtQ;JsCB$%-tyD6=p7cHi!MZ{Kg+8Wh;y$yDaeG+i{(5{w&7(z<~tEJH9R*vRClN{W3T;C(M-y?fl zq|F@9Bd<*3bz6ib!w>rvXa~@)C=lEOy#*h^OYlub|In)o^;>}mzDUa}NL5ZND*~+^ zi*%kYT+G3|1T%|#W{J~ViTNBAz_#V);78Ck_!Y?<60~E0P7MU00~vZj`H<7LiGYt6&2>(8HlR{y-d zw^jf2=|NLF!=u%F1`#6|G4j$eZNFi==lw@h#4Y{Euo~{E)qj7AB-+?vde^{x|LE}V z&rb&Xcs6rEk?VDJ|D;vodlKHSj#h>;dq^ZYIE*RBTF z#B^+L_@D3~n;o-_r(`$W^PoTgXQodGV)beVh$Ks5+F#)>Wl<3oQHI&VKT2uXSKsvC zNxzPb;oeO+15ji`cnJwggmfL+v#_rc@=Z^-!Ws4+sOPT6$SRKE^}$Sbg}0|5s1hHU zm^v8VIY<+CZ`$Scj$RsV|3;=zUczUaogEb30>afQb(7fwBB{CeD3XwZ$S`EMF-bhpb7a*(w zc7y=YVwWyvKQJ9lY z)}u-5Fyvzz*0Lm{{WF2_6Zkem1trBQsIU(e)LK?(XMxieDE|Qum%|CUw}jKqjBsMx zuy<|tkTa|)lOf*;YyFu*o+=xKB#}X)kSEqISBQn7l7*D9+p78IU@752wi2zyl^vQU+(>(vL< Jwd%&w`hO8y5WoNc literal 0 HcmV?d00001 diff --git a/options/__pycache__/options_pretraining.cpython-310.pyc b/options/__pycache__/options_pretraining.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90098c97bd7d4522779a930053f1ba9bd551526e GIT binary patch literal 2576 zcmb_eOK%e~5cY01uQo{uAX*;c{ZN2(+X~7BL7_l7eX^>AC9B2?VJVh=OIZ4iSv3J=SQbzgPy)~dn*=lo$O4pOc|dtU zNk9cQ1!xLT3eW*o1XKi+W(OaW(jk|CJ3G+Hj)CSin35GuuDgtJk4m@2ql_0e={k2< zl*LiQ+2W!ztTh}dTi5KSr|h~TmF@XYo!YZ7utUj(v}-Tt<`?I0*fpPV`=Pt+mz$33 z+00Q+4K~`>Oh8M}+lRq$QZ*QpsG)#snt(N@2}-+WvLs7=vf#G_Bfy@nw90m6^-1~B%H8t4^5R_i z&dPGd-oy=TCBScre$x%#_0+5vY_+bAn2uuAh129X``gD4*goGD80gLtm~AV{YRbS- z>(xSUI@LJcpu$nEA5FMk5Gs8Xu%gI4R8vq#tc6YPDH-KrpQI=Nb5!g&T<3gms@ygr zOY%m2n}{PY?HG!R_VyWZMtuoYz}bv;*6ng`x?4ou2SSt%oB zm_}~5b!zaL&30E+ASj^lKqpG02}+wj3mZ7jQTjY|8l7Ln3=nJAV`dt^Wx-5BZ(Cq8 zBxmX`L-R{uvO_K%{nalSKHcCoH8hQF6cMPT zKuyH~GfE z_+^9Td$+AG%Wp2*py~gpVhMxRDxMo2)HpJp2BJ$yre#DlQpTA{`e(!@?Rz#xp>39jUa8h(^kILhlzm3*r3~YJRUijLOIMS>TSB z2u85>LV2)QGFBc^mDXjuPp;qf>vQ&IKVbu3y7*B3-8FsyZuA~cv-@~n_+3on=!ntX zgWk~*(I4u#Fs6>;9B`NRskG3ke*aUYRSe}knAUz}e)4i0Wi9~ABc#(AhNx1?h$4Q$ z5~_=ceg?&9gqVT?^%FQX_7L}8&?s}WQ@U>9zaSFdX45pG2*PRZTIqsW=(MSWg&!=2 BRU`la literal 0 HcmV?d00001 diff --git a/options/options_classification.py b/options/options_classification.py new file mode 100644 index 0000000..99199a9 --- /dev/null +++ b/options/options_classification.py @@ -0,0 +1,183 @@ +data_path = "./data" + +class opts_ntu_60_cross_view(): + + def __init__(self): + + # Sequence based model + self.encoder_args = { + "hidden_size": 512*4, + "num_head": 8, + "num_layer": 1, + "num_class": 60 + } + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/NTU-RGB-D-60-AGCN/xview/train_data_joint.npy", + "label_path": data_path + "/NTU-RGB-D-60-AGCN/xview/train_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-60-AGCN/xview/train_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + self.test_feeder_args = { + + 'data_path': data_path + "/NTU-RGB-D-60-AGCN/xview/val_data_joint.npy", + 'label_path': data_path + "/NTU-RGB-D-60-AGCN/xview/val_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-60-AGCN/xview/val_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + +class opts_ntu_60_cross_subject(): + + def __init__(self): + + # Sequence based model + self.encoder_args = { + "hidden_size": 512 * 4, + "num_head": 8, + "num_layer": 1, + "num_class": 60 + } + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/NTU-RGB-D-60-AGCN/xsub/train_data_joint.npy", + "label_path": data_path + "/NTU-RGB-D-60-AGCN/xsub/train_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-60-AGCN/xsub/train_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + self.test_feeder_args = { + + 'data_path': data_path + "/NTU-RGB-D-60-AGCN/xsub/val_data_joint.npy", + 'label_path': data_path + "/NTU-RGB-D-60-AGCN/xsub/val_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-60-AGCN/xsub/val_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + + +class opts_ntu_120_cross_subject(): + def __init__(self): + + # Sequence based model + self.encoder_args = { + "hidden_size": 512 * 4, + "num_head": 8, + "num_layer": 1, + "num_class": 120 + } + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/NTU-RGB-D-120-AGCN/xsub/train_data_joint.npy", + "label_path": data_path + "/NTU-RGB-D-120-AGCN/xsub/train_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-120-AGCN/xsub/train_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + self.test_feeder_args = { + + 'data_path': data_path + "/NTU-RGB-D-120-AGCN/xsub/val_data_joint.npy", + 'label_path': data_path + "/NTU-RGB-D-120-AGCN/xsub/val_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-120-AGCN/xsub/val_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + +class opts_ntu_120_cross_setup(): + + def __init__(self): + + # Sequence based model + self.encoder_args = { + "hidden_size": 512 * 4, + "num_head": 8, + "num_layer": 1, + "num_class": 120 + } + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/NTU-RGB-D-120-AGCN/xsetup/train_data_joint.npy", + "label_path": data_path + "/NTU-RGB-D-120-AGCN/xsetup/train_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-120-AGCN/xsetup/train_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + self.test_feeder_args = { + + 'data_path': data_path + "/NTU-RGB-D-120-AGCN/xsetup/val_data_joint.npy", + 'label_path': data_path + "/NTU-RGB-D-120-AGCN/xsetup/val_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-120-AGCN/xsetup/val_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + +class opts_pku_part1_cross_subject(): + + def __init__(self): + + # Sequence based model + self.encoder_args = { + "hidden_size": 512 * 4, + "num_head": 8, + "num_layer": 1, + "num_class": 51 + } + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/pku_v1/xsub/train_data_joint.npy", + "label_path": data_path + "/pku_v1/xsub/train_label.pkl", + "num_frame_path": data_path + "/pku_v1/xsub/train_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + self.test_feeder_args = { + + "data_path": data_path + "/pku_v1/xsub/val_data_joint.npy", + "label_path": data_path + "/pku_v1/xsub/val_label.pkl", + "num_frame_path": data_path + "/pku_v1/xsub/val_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + +class opts_pku_part2_cross_subject(): + def __init__(self): + + # Sequence based model + self.encoder_args = { + "hidden_size": 512 * 4, + "num_head": 8, + "num_layer": 1, + "num_class": 51 + } + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/pku_v2/xsub/train_data_joint.npy", + "label_path": data_path + "/pku_v2/xsub/train_label.pkl", + "num_frame_path": data_path + "/pku_v2/xsub/train_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + self.test_feeder_args = { + + "data_path": data_path + "/pku_v2/xsub/val_data_joint.npy", + "label_path": data_path + "/pku_v2/xsub/val_label.pkl", + "num_frame_path": data_path + "/pku_v2/xsub/val_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } diff --git a/options/options_pretraining.py b/options/options_pretraining.py new file mode 100644 index 0000000..cf525d6 --- /dev/null +++ b/options/options_pretraining.py @@ -0,0 +1,96 @@ +# Sequence based model arguments +encoder_arguments = { + "hidden_size": 512*4, + "num_head": 8, + "num_layer": 1, + "num_class": 128 +} + +data_path = "./data" + + +class opts_ntu_60_cross_view(): + + def __init__(self): + + self.encoder_args = encoder_arguments + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/NTU-RGB-D-60-AGCN/xview/train_data_joint.npy", + "num_frame_path": data_path + "/NTU-RGB-D-60-AGCN/xview/train_num_frame.npy", + "l_ratio": [0.1, 1], + "input_size": 64 + } + +class opts_ntu_60_cross_subject(): + + def __init__(self): + + self.encoder_args = encoder_arguments + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/NTU-RGB-D-60-AGCN/xsub/train_data_joint.npy", + "num_frame_path": data_path + "/NTU-RGB-D-60-AGCN/xsub/train_num_frame.npy", + "l_ratio": [0.1, 1], + "input_size": 64 + } + +class opts_ntu_120_cross_subject(): + + def __init__(self): + + self.encoder_args = encoder_arguments + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/NTU-RGB-D-120-AGCN/xsub/train_data_joint.npy", + "num_frame_path": data_path + "/NTU-RGB-D-120-AGCN/xsub/train_num_frame.npy", + "l_ratio": [0.1, 1], + "input_size": 64 + } + +class opts_ntu_120_cross_setup(): + + def __init__(self): + + self.encoder_args = encoder_arguments + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/NTU-RGB-D-120-AGCN/xsetup/train_data_joint.npy", + "num_frame_path": data_path + "/NTU-RGB-D-120-AGCN/xsetup/train_num_frame.npy", + "l_ratio": [0.1, 1], + "input_size": 64 + } + + +class opts_pku_part1_cross_subject(): + + def __init__(self): + + self.encoder_args = encoder_arguments + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/pku_v1/xsub/train_data_joint.npy", + "num_frame_path": data_path + "/pku_v1/xsub/train_num_frame.npy", + "l_ratio": [0.1, 1], + "input_size": 64 + } + +class opts_pku_part2_cross_subject(): + + def __init__(self): + + self.encoder_args = encoder_arguments + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/pku_v2/xsub/train_data_joint.npy", + "num_frame_path": data_path + "/pku_v2/xsub/train_num_frame.npy", + "l_ratio": [0.1, 1], + "input_size": 64 + } + diff --git a/options/options_retrieval.py b/options/options_retrieval.py new file mode 100644 index 0000000..ce52262 --- /dev/null +++ b/options/options_retrieval.py @@ -0,0 +1,178 @@ +data_path = "./data" + +class opts_ntu_60_cross_view(): + + def __init__(self): + # Sequence based model + self.encoder_args = { + "hidden_size": 512 * 4, + "num_head": 8, + "num_layer": 1, + "num_class": 60 + } + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/NTU-RGB-D-60-AGCN/xview/train_data_joint.npy", + "label_path": data_path + "/NTU-RGB-D-60-AGCN/xview/train_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-60-AGCN/xview/train_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + self.test_feeder_args = { + + 'data_path': data_path + "/NTU-RGB-D-60-AGCN/xview/val_data_joint.npy", + 'label_path': data_path + "/NTU-RGB-D-60-AGCN/xview/val_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-60-AGCN/xview/val_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + +class opts_ntu_60_cross_subject(): + + def __init__(self): + # Sequence based model + self.encoder_args = { + "hidden_size": 512 * 4, + "num_head": 8, + "num_layer": 1, + "num_class": 60 + } + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/NTU-RGB-D-60-AGCN/xsub/train_data_joint.npy", + "label_path": data_path + "/NTU-RGB-D-60-AGCN/xsub/train_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-60-AGCN/xsub/train_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + self.test_feeder_args = { + + 'data_path': data_path + "/NTU-RGB-D-60-AGCN/xsub/val_data_joint.npy", + 'label_path': data_path + "/NTU-RGB-D-60-AGCN/xsub/val_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-60-AGCN/xsub/val_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + +class opts_ntu_120_cross_subject(): + def __init__(self): + # Sequence based model + self.encoder_args = { + "hidden_size": 512 * 4, + "num_head": 8, + "num_layer": 1, + "num_class": 120 + } + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/NTU-RGB-D-120-AGCN/xsub/train_data_joint.npy", + "label_path": data_path + "/NTU-RGB-D-120-AGCN/xsub/train_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-120-AGCN/xsub/train_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + self.test_feeder_args = { + + 'data_path': data_path + "/NTU-RGB-D-120-AGCN/xsub/val_data_joint.npy", + 'label_path': data_path + "/NTU-RGB-D-120-AGCN/xsub/val_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-120-AGCN/xsub/val_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + +class opts_ntu_120_cross_setup(): + + def __init__(self): + # Sequence based model + self.encoder_args = { + "hidden_size": 512 * 4, + "num_head": 8, + "num_layer": 1, + "num_class": 120 + } + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/NTU-RGB-D-120-AGCN/xsetup/train_data_joint.npy", + "label_path": data_path + "/NTU-RGB-D-120-AGCN/xsetup/train_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-120-AGCN/xsetup/train_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + self.test_feeder_args = { + + 'data_path': data_path + "/NTU-RGB-D-120-AGCN/xsetup/val_data_joint.npy", + 'label_path': data_path + "/NTU-RGB-D-120-AGCN/xsetup/val_label.pkl", + 'num_frame_path': data_path + "/NTU-RGB-D-120-AGCN/xsetup/val_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + +class opts_pku_part1_cross_subject(): + + def __init__(self): + # Sequence based model + self.encoder_args = { + "hidden_size": 512 * 4, + "num_head": 8, + "num_layer": 1, + "num_class": 51 + } + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/pku_v1/xsub/train_data_joint.npy", + "label_path": data_path + "/pku_v1/xsub/train_label.pkl", + "num_frame_path": data_path + "/pku_v1/xsub/train_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + self.test_feeder_args = { + + "data_path": data_path + "/pku_v1/xsub/val_data_joint.npy", + "label_path": data_path + "/pku_v1/xsub/val_label.pkl", + "num_frame_path": data_path + "/pku_v1/xsub/val_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + +class opts_pku_part2_cross_subject(): + def __init__(self): + # Sequence based model + self.encoder_args = { + "hidden_size": 512 * 4, + "num_head": 8, + "num_layer": 1, + "num_class": 51 + } + + # feeder + self.train_feeder_args = { + "data_path": data_path + "/pku_v2/xsub/train_data_joint.npy", + "label_path": data_path + "/pku_v2/xsub/train_label.pkl", + "num_frame_path": data_path + "/pku_v2/xsub/train_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } + + self.test_feeder_args = { + + "data_path": data_path + "/pku_v2/xsub/val_data_joint.npy", + "label_path": data_path + "/pku_v2/xsub/val_label.pkl", + "num_frame_path": data_path + "/pku_v2/xsub/val_num_frame.npy", + 'l_ratio': [0.95], + 'input_size': 64 + } diff --git a/pretraining.py b/pretraining.py new file mode 100644 index 0000000..a58b162 --- /dev/null +++ b/pretraining.py @@ -0,0 +1,293 @@ +import argparse +import math +import os +import random +import shutil +import time +import warnings + +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data + +import scd.builder +from torch.utils.tensorboard import SummaryWriter +from dataset import get_pretraining_set + + + +parser = argparse.ArgumentParser(description='Training') +parser.add_argument('-j', '--workers', default=24, type=int, metavar='N', + help='number of data loading workers (default: 32)') +parser.add_argument('--epochs', default=200, type=int, metavar='N', + help='number of total epochs to run') +parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') +parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, + metavar='LR', help='initial learning rate', dest='lr') +parser.add_argument('--schedule', default=[100, 160], nargs='*', type=int, + help='learning rate schedule (when to drop lr by 10x)') +parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum of SGD solver') +parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') +parser.add_argument('-p', '--print-freq', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. ') + +parser.add_argument('--checkpoint-path', default='./checkpoints/pretrain/', type=str) +parser.add_argument('--skeleton-representation', type=str, + help='input skeleton-representation for self supervised training (joint or motion or bone)') +parser.add_argument('--pre-dataset', default='ntu60', type=str, + help='which dataset to use for self supervised training (ntu60 or ntu120)') +parser.add_argument('--protocol', default='cross_subject', type=str, + help='training protocol cross_view/cross_subject/cross_setup') + +# specific configs: +parser.add_argument('--encoder-dim', default=128, type=int, + help='feature dimension (default: 128)') +parser.add_argument('--encoder-k', default=16384, type=int, + help='queue size; number of negative keys (default: 16384)') +parser.add_argument('--encoder-m', default=0.999, type=float, + help='momentum of updating key encoder (default: 0.999)') +parser.add_argument('--encoder-t', default=0.07, type=float, + help='softmax temperature (default: 0.07)') + +parser.add_argument('--cos', action='store_true', + help='use cosine lr schedule') + +parser.add_argument('--gpu', default=0) + + +def main(): + args = parser.parse_args() + + if args.seed is not None: + random.seed(args.seed) + torch.manual_seed(args.seed) + cudnn.deterministic = True + warnings.warn('You have chosen to seed training. ' + 'This will turn on the CUDNN deterministic setting, ' + 'which can slow down your training considerably! ' + 'You may see unexpected behavior when restarting ' + 'from checkpoints.') + + # pretraining dataset and protocol + from options import options_pretraining as options + if args.pre_dataset == 'ntu60' and args.protocol == 'cross_view': + opts = options.opts_ntu_60_cross_view() + elif args.pre_dataset == 'ntu60' and args.protocol == 'cross_subject': + opts = options.opts_ntu_60_cross_subject() + elif args.pre_dataset == 'ntu120' and args.protocol == 'cross_setup': + opts = options.opts_ntu_120_cross_setup() + elif args.pre_dataset == 'ntu120' and args.protocol == 'cross_subject': + opts = options.opts_ntu_120_cross_subject() + elif args.pre_dataset == 'pku_part1' and args.protocol == 'cross_subject': + opts = options.opts_pku_part1_cross_subject() + elif args.pre_dataset == 'pku_part2' and args.protocol == 'cross_subject': + opts = options.opts_pku_part2_cross_subject() + + opts.train_feeder_args['input_representation'] = args.skeleton_representation + + # create model + print("=> creating model") + + model = scd.builder.SCD_Net(opts.encoder_args, args.encoder_dim, args.encoder_k, args.encoder_m, args.encoder_t) + print("options",opts.train_feeder_args) + print(model) + + # single gpu training + model = model.cuda() + + criterion = nn.CrossEntropyLoss().cuda() + optimizer = torch.optim.SGD(model.parameters(), args.lr, + momentum=args.momentum, + weight_decay=args.weight_decay) + + # optionally resume from a checkpoint + if args.resume: + if os.path.isfile(args.resume): + print("=> loading checkpoint '{}'".format(args.resume)) + if args.gpu is None: + checkpoint = torch.load(args.resume) + else: + # Map model to be loaded to specified single gpu. + checkpoint = torch.load(args.resume) + args.start_epoch = checkpoint['epoch'] + model.load_state_dict(checkpoint['state_dict']) + optimizer.load_state_dict(checkpoint['optimizer']) + print("=> loaded checkpoint '{}' (epoch {})" + .format(args.resume, checkpoint['epoch'])) + del checkpoint + torch.cuda.empty_cache() + else: + print("=> no checkpoint found at '{}'".format(args.resume)) + + # Data loading code + train_dataset = get_pretraining_set(opts) + + train_sampler = None + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) + + writer = SummaryWriter(args.checkpoint_path) + + for epoch in range(args.start_epoch, args.epochs): + + adjust_learning_rate(optimizer, epoch, args) + + # train for one epoch + loss, acc1 = train(train_loader, model, criterion, optimizer, epoch, args) + writer.add_scalar('train_loss', loss.avg, global_step=epoch) + writer.add_scalar('acc', acc1.avg, global_step=epoch) + + if epoch % 50 == 0: + save_checkpoint({ + 'epoch': epoch + 1, + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict(), + }, is_best=False, filename=args.checkpoint_path+'/checkpoint.pth.tar') + + +def train(train_loader, model, criterion, optimizer, epoch, args): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + progress = ProgressMeter( + len(train_loader), + [batch_time, losses, top1,], + prefix="Epoch: [{}] Lr_rate [{}]".format(epoch, optimizer.param_groups[0]['lr'])) + + # switch to train mode + model.train() + + end = time.time() + for i, (q_input, k_input) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + q_input = q_input.float().cuda(non_blocking=True) + k_input = k_input.float().cuda(non_blocking=True) + + # compute output + output1, output2, output3, output4, target1, target2, target3, target4 = model(q_input, k_input) + + batch_size = output2.size(0) + + # interactive level loss + loss = criterion(output1, target1) + criterion(output2, target2) + criterion(output3, target3) \ + + criterion(output4, target4) + + losses.update(loss.item(), batch_size) + + # measure accuracy of model m1 and m2 individually + # acc1/acc5 are (K+1)-way contrast classifier accuracy + # measure accuracy and record loss + acc1, _ = accuracy(output2, target2, topk=(1, 5)) + top1.update(acc1[0], batch_size) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + return losses, top1 + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries), flush=True) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' + + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate based on schedule""" + lr = args.lr + if args.cos: # cosine lr schedule + lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs)) + else: # stepwise lr schedule + for milestone in args.schedule: + lr *= 0.1 if epoch >= milestone else 1. + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..d811954 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,96 @@ +absl-py==1.4.0 +addict==2.4.0 +anykeystore==0.2 +brotlipy==0.7.0 +cachetools==5.3.0 +click==8.1.3 +clip==0.2.0 +colorama==0.4.6 +commonmark==0.9.1 +contourpy==1.0.6 +cryptacular==1.6.2 +cycler==0.11.0 +decord==0.6.0 +defusedxml==0.7.1 +einops==0.5.0 +filelock==3.8.0 +fonttools==4.38.0 +fvcore==0.1.5.post20221221 +google-auth==2.17.1 +google-auth-oauthlib==1.0.0 +greenlet==2.0.0 +grpcio==1.53.0 +huggingface-hub==0.10.1 +hupper==1.10.3 +imageio==2.22.3 +imgaug==0.4.0 +iopath==0.1.10 +kiwisolver==1.4.4 +Markdown==3.4.1 +MarkupSafe==2.1.1 +matplotlib==3.6.1 +mkl-fft==1.3.1 +mkl-service==2.4.0 +mmcv==1.5.1 +model-index==0.1.11 +networkx==2.8.8 +oauthlib==3.2.2 +opencv-contrib-python==4.6.0.66 +opencv-python==4.6.0.66 +openmim==0.3.2 +ordered-set==4.1.0 +packaging==21.3 +pandas==1.5.1 +PasteDeploy==3.0.1 +pbkdf2==1.3 +Pillow==9.2.0 +plaster==1.0 +plaster-pastedeploy==0.7 +portalocker==2.7.0 +protobuf==4.22.1 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +Pygments==2.13.0 +pyparsing==3.0.9 +pyramid==2.0 +pyramid-mailer==0.15.1 +python-dateutil==2.8.2 +python3-openid==3.2.0 +pytz==2022.6 +PyWavelets==1.4.1 +PyYAML==6.0 +repoze.sendmail==4.4.1 +requests-oauthlib==1.3.1 +rich==12.6.0 +roi-align==0.0.2 +rsa==4.9 +scikit-image==0.19.3 +scipy==1.9.3 +Shapely==1.8.5.post1 +spatial-correlation-sampler==0.4.0 +SQLAlchemy==1.4.42 +tabulate==0.9.0 +tensorboard==2.12.1 +tensorboard-data-server==0.7.0 +tensorboard-plugin-wit==1.8.1 +termcolor==2.3.0 +thop==0.1.1.post2209072238 +tifffile==2022.10.10 +timm==0.6.11 +torch==1.13.0 +torchaudio==0.13.0 +torchvision==0.14.0 +tqdm==4.64.1 +transaction==3.0.1 +translationstring==1.4 +velruse==1.1.1 +venusian==3.0.0 +WebOb==1.8.7 +Werkzeug==2.2.3 +WTForms==3.0.1 +wtforms-recaptcha==0.3.2 +yacs==0.1.8 +yapf==0.32.0 +zope.deprecation==4.4.0 +zope.interface==5.5.0 +zope.sqlalchemy==1.6 diff --git a/scd/__pycache__/builder.cpython-310.pyc b/scd/__pycache__/builder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00bc9a63cfa0c5fd0bfff03e6570ff3db3b10530 GIT binary patch literal 4216 zcmahMO^@5gb!PY}t)$)6%GzCf4ZEn5v`vt0iv%btxFBc_rKLV7 z?Qmqr0tFOZ2fozJsdov;#mCzKz{)8h1F(y*Vx9OU3B1=yfz_2(jx!V9A?YQGox10Dz;hvt_Chz-we zc;X=dNv3p0K7XAO$r3WvX5^T@OMbmMp+M8Ap3{{HOW4$y!DhQmT##nRn%W7QQA&e_eL=`sm@9Vq$Nba`0SnfuH zyPn@24R=Q|Ouy^j-@E+I0O<a1L(r! zU;lW^H`S^cz-?RbpqIo=@Eb`aO~^5uvKeS6T0&ec(E;j-0nm^pr@#g37+h6ayb_Xx zenEc^TCV*ZU693ubZqnlg%uRGg1l*IM&Mn*_b~(itK4A|b7Cb{r#!7BW~Ypv##v&y zt4BA$PuI96*QeElCEAS5AY|l5r`9uuJ{W5l~?GWUFe@x`lT23c^S@T)vJSni%x!dq8+jzr)95kxTt&l@~wd83}(cP?M~eq$m1;l6V;3PynwhvVP}PBiM@4TLjjIZ@Dt zsJj<9-QYpIqXJGv)gSl+0I-f6^-*&)d40+pdaa*G=h%Deb596Hx)az}QF& z`FXfUCump2b*hcSp(vn`4i@qt&kTkS+_Wr$qfsb=ILl zi+Chf5HGeyk&hYC>!pPVQW&PlKtMeKIY0(6q{Bhf7T*AlxB#GGr+OUpTFuh3Gbuxf zZ==xb2)+X#wUlLH_c_u9yA@p5?|E_Ty2>141Es$O05fC!CV+}-~b7^kE7fKL*ttvr30z_iyQ%o`N*6^)Tzy_oMie;w*38!P! zAhC3GNH>6Qks8?~33K(M70$4zjIS=7{iC^;Di;mVS>>Vf@VTEZMBA`BT>rRr&@csl zp?CuUo}OqRcnblVt)Yplh<^{kI|$G!Cx%LQ+jWfC)Zz4oY676UppbeqyIZ+7q~axQh3DEhLOM_ga&1WP6Fbq= z+J#)(PV6OZh3DEh0tEoIReQh=RDfggj}JYFM)cfNgVHT_08{R(fXe$MMC->*GtS?kdZ+XtrMH`g#Gr2)arO%F_55+xTn_OHDbt0Bgkx0ryRc4-yC||M zE3zGMeh7k7?+P7nXNUU|eIIkj{MrU&(TQ~+cR{s)gPqYk+C;;nm9)b;S^~$y*<&`* zavDNF(RxlBIc-SQ9en;a)M+v`RlO!tI7T%i)Y6HaYrs*d8O)r9MpDsMPQzg++J>%9 zJJr>Fe5$L{r#9E2`rB=&s?!38y-y0T3*_pjxW$}nB<7Or%7SYxaZ8G8V6V|BPFH3G zdjcGvl=BgE>eC~Yd<4Uxd`f3^X{2&$WKHUw)ufP=+bsXyb;T(rl?A7&I4jBOg0rDG zrDScvsU_>M>e_2$vYu3PteR}(*hW&zu^MQuC+nRmH@6Ak8{B%#rnMQFoKYjymyFca z$i_=XHq}V&B_msK`ez`Zt;C)|Hx9eQ-PDpZNj=$2wpuo%OMCp6MGpLFG#ttO1&Qs5 zgnr^;-|GxSz#Ts5!|w!i=^*CB7!;8oKtJI_c|;b(htti8Q?tZF&;S3^440?LiN!QO z8_5?o*Bi7$_-nvrCDge<#tU)P0f=)DSO;S}^52g9i#MZj0WiM{SXX5h@fwUb>RG18 zN-OpdTtjG`DmK7i7TsRrsgRNo*qc_2OiFiT08jZ)3` zWR|9`;3WG9zKa98KjNOif4;`H%2TOzlzn|_bu*;Nj%3P?Vt9us>jK^dyc;&{o;w`G zu3TX80t@AQG>j30o-UU|ER#bllT%pSEM#Ux!8r<&a}-1uD9UF xW$j(Fco@#R1JzPy&7=8|+NDxlvZY!QS4-_1*>~K;E;z;>*e39I$!xt|_%D&zAC>?B literal 0 HcmV?d00001 diff --git a/scd/__pycache__/hi_encoder.cpython-310.pyc b/scd/__pycache__/hi_encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6b20b3557ef8ca62afcd3f625ea7efea3430573 GIT binary patch literal 3163 zcmb7G%W~XA6jiq#qmjpBXDZ|kB`B!!5Mq-6Nq`VSUSyF>QN*xNQKeGU8jl=}#%XEn zWNN%}sBFlRy}k5?FX0QcW#wO>fN*YWUKJ-)&{SXDzPjpZ^*!fw_vYsv4cD)~w8g@L zru{*k*~>uZ8d|mqBQ(Jh%`5M$%RSD>=LzfTp58S)L-qB<>{_0sd?TrJZO>-fx0*18 z^;8p9#Mi4ky5}HL5jGVctNb4tojPoajC8NAA(^fd11+3zs9wxjK`J&kEYHq?B~Zky%$knK>< zn0RJx?N-oi!H4XqHq_Sn9!I~bJOdy42&^{SotRVpKsRkc{P?AP3XdS zZfqNQeP|5z9mM>F8Glh|SU0nEa1GydZ~rqxjoqeM7+I9Gi@8oLM1=p12T@_C{jT4M0#Q`SO9*Pk z{OIknyAR~~ewYMV=KK5FpI278X};19vV0}&JqW^w{f(7=7RhWS-1_+Z#q*a|2+1qA zHiH$QrUx^g!t%xwS;GXWWlJy)vzTEx%wYy!gtb|nS$`Rosejb2u|+m$Ou)H3UdoU@ zb$k)vXhm$Y=V8ZSV z&K%03oTj_9N_h-n`AVfrqkO9;AG&AV z4cS|d!n`N5gDd#Qkl;4H$fp@Ajr$Ewt^riq4&UVyozrszdIy;709gJ+-Jmz^1WX!_9D<9n3rjP>jda;Mry8m24V{I9WVO2zcns(cUAZqCa&+$5=4S-5%fs49gs z<*{6(=n*mvGDqpTmHV_rg_Vu&vAXY@3X<{^C7vd82F6orqclaT>bvwrL4d^JCYln6 zx^8jKEoQwi%sS+u#;m`M%G7J9nDh7G^x=xWS}CYZ#3ZE@ci@2KN-2Kg+gx!TMFRV= zkw7U(2h3bwH>hxHN=KgaZA&SGg=ZCP4d$@Y*6(ZkdP}~K@#F_EO-p`C{%2$^!W1Tu z6Ltzc403siGVF3)mMDYRC~3#`_yDvWWT0*gKb|3pS{&k$K+9;AKwRBn4esDKIQ5DJ z7%_LB$xo0ji6}DzPFY9g{st|h!xOQN>-86lu7bhqdAE$@mv~p6F~+>Bs#6A=VV+YQ zx!K!Fvm7Ux?wb-%p(JkO9HKts5b>y*ig;hZc{Ab(kW-NZ6a&F*MFQR`IZ!P*@HIwz zi34R#75EAF1#gG)GNru&!gRnat7MEJe8L=@I`sJyl->^MBgPj|ZUR%P&kRg;oJkl? zJWz-UB(YJ+qeJsP<>^Cn4+l;8IbOq#pnyngBflhbjiS~Re4hOB_xy#7< literal 0 HcmV?d00001 diff --git a/scd/builder.py b/scd/builder.py new file mode 100644 index 0000000..a84b9ed --- /dev/null +++ b/scd/builder.py @@ -0,0 +1,149 @@ +import torch +import torch.nn as nn + +from .scd_encoder import PretrainingEncoder + + +# initilize weight +def weights_init(model): + with torch.no_grad(): + for child in list(model.children()): + for param in list(child.parameters()): + if param.dim() == 2: + nn.init.xavier_uniform_(param) + print('weights initialization finished!') + + +class SCD_Net(nn.Module): + def __init__(self, args_encoder, dim=3072, K=65536, m=0.999, T=0.07): + """ + args_encoder: model parameters encoder + dim: feature dimension (default: 128) + K: queue size; number of negative keys (default: 2048) + m: moco momentum of updating key encoder (default: 0.999) + T: softmax temperature (default: 0.07) + """ + super(SCD_Net, self).__init__() + + self.K = K + self.m = m + self.T = T + + print(" moco parameters", K, m, T) + + self.encoder_q = PretrainingEncoder(**args_encoder) + self.encoder_k = PretrainingEncoder(**args_encoder) + weights_init(self.encoder_q) + weights_init(self.encoder_k) + + for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): + param_k.data.copy_(param_q.data) # initialize + param_k.requires_grad = False # not update by gradient + + # create the queue + # domain level queues + # temporal domain queue + self.register_buffer("t_queue", torch.randn(dim, K)) + self.t_queue = nn.functional.normalize(self.t_queue, dim=0) + self.register_buffer("t_queue_ptr", torch.zeros(1, dtype=torch.long)) + + # spatial domain queue + self.register_buffer("s_queue", torch.randn(dim, K)) + self.s_queue = nn.functional.normalize(self.s_queue, dim=0) + self.register_buffer("s_queue_ptr", torch.zeros(1, dtype=torch.long)) + + # instance level queue + self.register_buffer("i_queue", torch.randn(dim, K)) + self.i_queue = nn.functional.normalize(self.i_queue, dim=0) + self.register_buffer("i_queue_ptr", torch.zeros(1, dtype=torch.long)) + + @torch.no_grad() + def _momentum_update_key_encoder(self): + """ + Momentum update of the key encoder + """ + for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): + param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) + + @torch.no_grad() + def _dequeue_and_enqueue(self, t_keys, s_keys, i_keys): + N, C = t_keys.shape + + assert self.K % N == 0 # for simplicity + + t_ptr = int(self.t_queue_ptr) + # replace the keys at ptr (dequeue and enqueue) + self.t_queue[:, t_ptr:t_ptr + N] = t_keys.T + t_ptr = (t_ptr + N) % self.K # move pointer + self.t_queue_ptr[0] = t_ptr + + s_ptr = int(self.s_queue_ptr) + # replace the keys at ptr (dequeue and enqueue) + self.s_queue[:, s_ptr:s_ptr + N] = s_keys.T + s_ptr = (s_ptr + N) % self.K # move pointer + self.s_queue_ptr[0] = s_ptr + + i_ptr = int(self.i_queue_ptr) + # replace the keys at ptr (dequeue and enqueue) + self.i_queue[:, i_ptr:i_ptr + N] = i_keys.T + i_ptr = (i_ptr + N) % self.K # move pointer + self.i_queue_ptr[0] = i_ptr + + def forward(self, q_input, k_input): + """ + Input: + time-majored domain input sequence: qc_input and kc_input + space-majored domain input sequence: qp_input and kp_input + Output: + logits and targets + """ + + # compute temporal domain level, spatial domain level and instance level features + qt, qs, qi = self.encoder_q(q_input) # queries: NxC + + qt = nn.functional.normalize(qt, dim=1) + qs = nn.functional.normalize(qs, dim=1) + qi = nn.functional.normalize(qi, dim=1) + + # compute key features + with torch.no_grad(): # no gradient to keys + self._momentum_update_key_encoder() # update the key encoder + + kt, ks, ki = self.encoder_k(k_input) # keys: NxC + + kt = nn.functional.normalize(kt, dim=1) + ks = nn.functional.normalize(ks, dim=1) + ki = nn.functional.normalize(ki, dim=1) + + # interactive loss + + l_pos_ti = torch.einsum('nc,nc->n', [qt, ki]).unsqueeze(1) + l_pos_si = torch.einsum('nc,nc->n', [qs, ki]).unsqueeze(1) + l_pos_it = torch.einsum('nc,nc->n', [qi, kt]).unsqueeze(1) + l_pos_is = torch.einsum('nc,nc->n', [qi, ks]).unsqueeze(1) + + l_neg_ti = torch.einsum('nc,ck->nk', [qt, self.i_queue.clone().detach()]) + l_neg_si = torch.einsum('nc,ck->nk', [qs, self.i_queue.clone().detach()]) + l_neg_it = torch.einsum('nc,ck->nk', [qi, self.t_queue.clone().detach()]) + l_neg_is = torch.einsum('nc,ck->nk', [qi, self.s_queue.clone().detach()]) + + logits_ti = torch.cat([l_pos_ti, l_neg_ti], dim=1) + logits_si = torch.cat([l_pos_si, l_neg_si], dim=1) + logits_it = torch.cat([l_pos_it, l_neg_it], dim=1) + logits_is = torch.cat([l_pos_is, l_neg_is], dim=1) + + logits_ti /= self.T + logits_si /= self.T + logits_it /= self.T + logits_is /= self.T + + labels_ti = torch.zeros(logits_ti.shape[0], dtype=torch.long).cuda() + labels_si = torch.zeros(logits_si.shape[0], dtype=torch.long).cuda() + labels_it = torch.zeros(logits_it.shape[0], dtype=torch.long).cuda() + labels_is = torch.zeros(logits_is.shape[0], dtype=torch.long).cuda() + + # dequeue and enqueue + self._dequeue_and_enqueue(kt, ks, ki) + + return logits_ti, logits_si, logits_it, logits_is, \ + labels_ti, labels_si, labels_it, labels_is, \ No newline at end of file diff --git a/scd/scd_encoder.py b/scd/scd_encoder.py new file mode 100644 index 0000000..1fe9d2b --- /dev/null +++ b/scd/scd_encoder.py @@ -0,0 +1,135 @@ +import torch +import torch.nn as nn +from torch.nn import TransformerEncoder, TransformerEncoderLayer +from einops import rearrange + +from model.ctrgcn import Model + + +class Encoder(nn.Module): + def __init__(self, hidden_size, num_head, num_layer) -> None: + super().__init__() + self.d_model = hidden_size + + hidden_size = 64 + self.gcn_t = Model(hidden_size) + self.gcn_s = Model(hidden_size) + + self.channel_t = nn.Sequential( + nn.Linear(50*hidden_size, self.d_model), + nn.LayerNorm(self.d_model), + nn.ReLU(True), + nn.Linear(self.d_model, self.d_model), + ) + + self.channel_s = nn.Sequential( + nn.Linear(64 * hidden_size, self.d_model), + nn.LayerNorm(self.d_model), + nn.ReLU(True), + nn.Linear(self.d_model, self.d_model), + ) + + encoder_layer = TransformerEncoderLayer(self.d_model, num_head, self.d_model, batch_first=True) + self.t_encoder = TransformerEncoder(encoder_layer, num_layer) + self.s_encoder = TransformerEncoder(encoder_layer, num_layer) + + def forward(self, x): + + vt = self.gcn_t(x) + + vt = rearrange(vt, '(B M) C T V -> B T (M V C)', M=2) + vt = self.channel_t(vt) + + vs = self.gcn_s(x) + + vs = rearrange(vs, '(B M) C T V -> B (M V) (T C)', M=2) + vs = self.channel_s(vs) + + vt = self.t_encoder(vt) # B T C + + vs = self.s_encoder(vs) + + # implementation using amax for the TMP runs faster than using MaxPool1D + # not support pytorch < 1.7.0 + vt = vt.amax(dim=1) + vs = vs.amax(dim=1) + + return vt, vs + + +class PretrainingEncoder(nn.Module): + def __init__(self, hidden_size, num_head, num_layer, + num_class=60, + ): + super(PretrainingEncoder, self).__init__() + + self.d_model = hidden_size + + self.encoder = Encoder( + hidden_size, num_head, num_layer, + ) + + # temporal feature projector + self.t_proj = nn.Sequential( + nn.Linear(self.d_model, self.d_model), + nn.ReLU(True), + nn.Linear(self.d_model, num_class) + ) + + # spatial feature projector + self.s_proj = nn.Sequential( + nn.Linear(self.d_model, self.d_model), + nn.ReLU(True), + nn.Linear(self.d_model, num_class) + ) + + # instance level feature projector + self.i_proj = nn.Sequential( + nn.Linear(2 * self.d_model, self.d_model), + nn.ReLU(True), + nn.Linear(self.d_model, num_class) + ) + + def forward(self, x): + + vt, vs = self.encoder(x) + + # projection + zt = self.t_proj(vt) + zs = self.s_proj(vs) + + vi = torch.cat([vt, vs], dim=1) + + zi = self.i_proj(vi) + + return zt, zs, zi + + +class DownstreamEncoder(nn.Module): + """hierarchical encoder network + classifier""" + + def __init__(self, + hidden_size, num_head, num_layer, + num_class=60, + ): + super(DownstreamEncoder, self).__init__() + + self.d_model = hidden_size + + self.encoder = Encoder( + hidden_size, num_head, num_layer, + ) + + # linear classifier + self.fc = nn.Linear(2 * self.d_model, num_class) + + def forward(self, x, knn_eval=False): + + vt, vs = self.encoder(x) + + vi = torch.cat([vt, vs], dim=1) + + if knn_eval: # return last layer features during KNN evaluation (action retrieval) + return vi + else: + return self.fc(vi) \ No newline at end of file