23
23
parser .add_argument ('--num_decoding_thread' , type = int , default = 4 , help = 'Num parallel thread for video decoding' )
24
24
parser .add_argument ('--l2_normalize' , type = int , default = 1 , help = 'l2 normalize feature' )
25
25
parser .add_argument ('--resnext101_model_path' , type = str , default = 'model/resnext101.pth' , help = 'Resnext model path' )
26
- parser .add_argument ('--device' , type = str , default = 'cuda ' , help = 'cuda or cpu' )
26
+ parser .add_argument ('--device' , type = str , default = 'cpu ' , help = 'cuda or cpu' )
27
27
args = parser .parse_args ()
28
28
29
29
if args .device not in ["cpu" , "cuda" ] :
52
52
def path_leaf (path : str ):
53
53
"""
54
54
Returns the name of a file given its path
55
+ Thanks https://github.com/Tikquuss/eulascript/blob/master/utils.py
55
56
https://stackoverflow.com/a/8384788/11814682
56
57
"""
57
58
head , tail = ntpath .split (path )
@@ -89,7 +90,7 @@ def get_number_of_frames(filename : str, output_dir : str = None):
89
90
input_file_length = get_length (input_file )
90
91
input_file_stream_number = get_number_of_frames (input_file )
91
92
92
- file_name , extension = os .path .splitext (input_file )
93
+ file_name , extension = os .path .splitext (path_leaf ( input_file ))
93
94
94
95
df .append ([file_name , input_file_stream_number , input_file_length ])
95
96
@@ -100,12 +101,12 @@ def get_number_of_frames(filename : str, output_dir : str = None):
100
101
if len (video .shape ) == 4 :
101
102
video = preprocess (video )
102
103
n_chunk = len (video )
103
- features = th .cuda . FloatTensor (n_chunk , 2048 ).fill_ (0 )
104
+ features = th .FloatTensor (n_chunk , 2048 ). to ( args . device ).fill_ (0 )
104
105
n_iter = int (math .ceil (n_chunk / float (args .batch_size )))
105
106
for i in range (n_iter ):
106
107
min_ind = i * args .batch_size
107
108
max_ind = (i + 1 ) * args .batch_size
108
- video_batch = video [min_ind :max_ind ].cuda ( )
109
+ video_batch = video [min_ind :max_ind ].to ( args . device )
109
110
batch_features = model (video_batch )
110
111
if args .l2_normalize :
111
112
batch_features = F .normalize (batch_features , dim = 1 )
@@ -123,4 +124,4 @@ def get_number_of_frames(filename : str, output_dir : str = None):
123
124
os .path .splitext (path_leaf (args .csv ))[0 ]+ "_frame_duration.csv"
124
125
),
125
126
index = False ,
126
- header = ["file_name" , "video_frame_nbr " , "video_length " ])
127
+ header = ["file_name" , "video_frames " , "duration " ])
0 commit comments