Skip to content

Commit a659a06

Browse files
committed
a
1 parent d5b2414 commit a659a06

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

extract.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
parser.add_argument('--num_decoding_thread', type=int, default=4, help='Num parallel thread for video decoding')
2424
parser.add_argument('--l2_normalize', type=int, default=1, help='l2 normalize feature')
2525
parser.add_argument('--resnext101_model_path', type=str, default='model/resnext101.pth', help='Resnext model path')
26-
parser.add_argument('--device', type=str, default='cuda', help='cuda or cpu')
26+
parser.add_argument('--device', type=str, default='cpu', help='cuda or cpu')
2727
args = parser.parse_args()
2828

2929
if args.device not in ["cpu", "cuda"] :
@@ -52,6 +52,7 @@
5252
def path_leaf(path: str):
5353
"""
5454
Returns the name of a file given its path
55+
Thanks https://github.com/Tikquuss/eulascript/blob/master/utils.py
5556
https://stackoverflow.com/a/8384788/11814682
5657
"""
5758
head, tail = ntpath.split(path)
@@ -89,7 +90,7 @@ def get_number_of_frames(filename : str, output_dir : str = None):
8990
input_file_length = get_length(input_file)
9091
input_file_stream_number = get_number_of_frames(input_file)
9192

92-
file_name, extension = os.path.splitext(input_file)
93+
file_name, extension = os.path.splitext(path_leaf(input_file))
9394

9495
df.append([file_name, input_file_stream_number, input_file_length])
9596

@@ -100,12 +101,12 @@ def get_number_of_frames(filename : str, output_dir : str = None):
100101
if len(video.shape) == 4:
101102
video = preprocess(video)
102103
n_chunk = len(video)
103-
features = th.cuda.FloatTensor(n_chunk, 2048).fill_(0)
104+
features = th.FloatTensor(n_chunk, 2048).to(args.device).fill_(0)
104105
n_iter = int(math.ceil(n_chunk / float(args.batch_size)))
105106
for i in range(n_iter):
106107
min_ind = i * args.batch_size
107108
max_ind = (i + 1) * args.batch_size
108-
video_batch = video[min_ind:max_ind].cuda()
109+
video_batch = video[min_ind:max_ind].to(args.device)
109110
batch_features = model(video_batch)
110111
if args.l2_normalize:
111112
batch_features = F.normalize(batch_features, dim=1)
@@ -123,4 +124,4 @@ def get_number_of_frames(filename : str, output_dir : str = None):
123124
os.path.splitext(path_leaf(args.csv))[0]+"_frame_duration.csv"
124125
),
125126
index = False,
126-
header = ["file_name", "video_frame_nbr", "video_length"])
127+
header = ["file_name", "video_frames", "duration"])

model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def get_model(args):
1919
print('Loading 2D-ResNet-152 ...')
2020
model = models.resnet152(pretrained=True)
2121
model = nn.Sequential(*list(model.children())[:-2], GlobalAvgPool())
22-
model = model.cuda()
22+
model = model.to(args.device)
2323
else:
2424
print('Loading 3D-ResneXt-101 ...')
2525
model = resnext.resnet101(

0 commit comments

Comments
 (0)