
Commit

enhance demo program; autocrlf
yelantf committed Sep 24, 2020
1 parent bcb6b16 commit 8b43521
Showing 8 changed files with 1,683 additions and 1,592 deletions.
2 changes: 1 addition & 1 deletion demo/README.md
@@ -8,7 +8,7 @@ To run this demo, make sure that you install all requirements following [INSTALL

1. Download the object detection model manually: **yolov3-spp.weights** ([Google Drive](https://drive.google.com/open?id=1T13mXnPLu8JRelwh60BRR21f2TlGWBAM)). Place it into `data/models/detector_models`.
2. Download the person tracking model manually: **jde.uncertainty.pt** ([Google Drive](https://drive.google.com/open?id=1IJSp_t5SRlQarFClrRolQzSJ4K5xZIqm)). Place it into `data/models/detector_models`.
3. Please download our action models. Place them into ```data/models/aia_models```. All models are available in the [Model Zoo](../README.md#model-zoo).
3. Please download our action models. Place them into ```data/models/aia_models```. All models are available in [MODEL_ZOO.md](../MODEL_ZOO.md).
4. We also provide a practical model ([Google Drive](https://drive.google.com/open?id=1CudK8w0d2_5r73_tnyAY1Fnwd78hce3M)) trained on 15 common action categories in AVA. This model achieves about 70 mAP on these categories. Please use [`resnet101_8x8f_denseserial.yaml`](../config_files/resnet101_8x8f_denseserial.yaml) and enable `--common-cate` to apply this model.
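   A typical invocation of this practical model from the `demo/` directory might look like `python demo.py --video-path input.mp4 --weight-path <path-to-practical-model> --common-cate`, where the weight path is a placeholder for wherever you saved the downloaded model; `--cfg-path` already defaults to the config above, and `demo.py` below lists all available flags.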
812 changes: 447 additions & 365 deletions demo/action_predictor.py

Large diffs are not rendered by default.
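Since the `action_predictor.py` diff is not rendered, the worker's interface can only be inferred from its call sites in `demo.py` below. Here is a minimal sketch of that inferred interface; the names, signatures, and return conventions are assumptions, not the file's actual contents:

```python
# Interface sketch for AVAPredictorWorker, inferred from demo.py's call sites.
from typing import Any, Optional, Tuple


class AVAPredictorWorkerSketch:
    def __init__(self, args: Any) -> None:
        """Spawn the tracking and action-prediction worker processes."""

    def read_track(self) -> Tuple[Optional[Any], Any, Any, Any]:
        """Return (orig_img, boxes, scores, ids) for the next frame.
        orig_img is None once the input video is exhausted."""

    def compute_prediction(self) -> None:
        """Offline mode only: run action prediction over the buffered tracks."""

    def read(self) -> Optional[Any]:
        """Return the next action-prediction result, None if nothing is
        ready yet, or the string "done" after the final result."""

    def terminate(self) -> None:
        """Shut down all worker processes."""
```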

379 changes: 194 additions & 185 deletions demo/demo.py
@@ -1,185 +1,194 @@
# -----------------------------------------------------
# Copyright (c) Shanghai Jiao Tong University. All rights reserved.
# Written by Xinzhi MU ([email protected])
# -----------------------------------------------------

import argparse
from time import sleep
from itertools import count
from tqdm import tqdm

import torch

from visualizer import AVAVisualizer
from action_predictor import AVAPredictorWorker

# PyTorch issue #973: raise the soft limit on open file descriptors to the
# hard limit, to avoid "Too many open files" errors when worker processes
# share tensors.
import resource

rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (rlimit[1], rlimit[1]))

def main():
    parser = argparse.ArgumentParser(description='Action Detection Demo')
    parser.add_argument(
        "--webcam",
        dest="webcam",
        help="Use the webcam as input",
        action="store_true",
    )
    parser.add_argument(
        "--video-path",
        default="input.mp4",
        help="The path to the input video",
        type=str,
    )
    parser.add_argument(
        "--output-path",
        default="output.mp4",
        help="The path to the video output",
        type=str,
    )
    parser.add_argument(
        "--cpu",
        dest="cpu",
        help="Use the CPU",
        action="store_true",
    )
    parser.add_argument(
        "--cfg-path",
        default="../config_files/resnet101_8x8f_denseserial.yaml",
        help="The path to the cfg file",
        type=str,
    )
    parser.add_argument(
        "--weight-path",
        default="../data/models/aia_models/resnet101_8x8f_denseserial.pth",
        help="The path to the model weights",
        type=str,
    )
    parser.add_argument(
        "--visual-threshold",
        default=0.5,
        help="The confidence threshold for the visualizer",
        type=float,
    )
    parser.add_argument(
        "--start",
        default=0,
        help="Start reading the video at this millisecond",
        type=int,
    )
    parser.add_argument(
        "--duration",
        default=-1,
        help="The duration of detection",
        type=int,
    )
    parser.add_argument(
        "--detect-rate",
        default=4,
        help="Rate (fps) at which to update action labels",
        type=int,
    )
    parser.add_argument(
        "--common-cate",
        default=False,
        help="Use the common-category model",
        action="store_true",
    )
    parser.add_argument(
        "--hide-time",
        default=False,
        help="Do not show the timestamp in the corner",
        action="store_true",
    )
    parser.add_argument(
        "--tracker-box-thres",
        default=0.1,
        help="The box threshold for the tracker",
        type=float,
    )
    parser.add_argument(
        "--tracker-nms-thres",
        default=0.4,
        help="The NMS threshold for the tracker",
        type=float,
    )

    args = parser.parse_args()

    args.input_path = 0 if args.webcam else args.video_path
    args.device = torch.device("cpu" if args.cpu else "cuda")
    args.realtime = args.webcam

    # Configuration for the tracker. Multi-GPU is currently not supported.
    args.gpus = "0"
    args.gpus = [int(i) for i in args.gpus.split(',')] if torch.cuda.device_count() >= 1 else [-1]
    args.min_box_area = 0
    args.tracking = True
    args.detector = "tracker"
    args.debug = False

    if args.webcam:
        print('Starting webcam demo, press Ctrl + C to terminate...')
    else:
        print('Starting video demo, video path: {}'.format(args.video_path))

    # Initialise the visualizer, which also writes the output video.
    video_writer = AVAVisualizer(
        args.input_path,
        args.output_path,
        args.realtime,
        args.start,
        args.duration,
        (not args.hide_time),
        confidence_threshold=args.visual_threshold,
        common_cate=args.common_cate,
    )

    # Use forkserver so CUDA can be initialised safely in child processes, and
    # share tensors through the file system to avoid open-file-descriptor limits.
    torch.multiprocessing.set_start_method('forkserver', force=True)
    torch.multiprocessing.set_sharing_strategy('file_system')

    ava_predictor_worker = AVAPredictorWorker(args)

    try:
        for i in tqdm(count()):
            with torch.no_grad():
                (orig_img, boxes, scores, ids) = ava_predictor_worker.read_track()

                if orig_img is None:
                    if not args.realtime:
                        ava_predictor_worker.compute_prediction()
                    break

                if args.realtime:
                    result = ava_predictor_worker.read()
                    flag = video_writer.realtime_write_frame(result, orig_img, boxes, scores, ids)
                    if not flag:
                        break
                else:
                    video_writer.send_track((boxes, ids))
    except KeyboardInterrupt:
        print("Keyboard Interrupted")

    if not args.realtime:
        video_writer.send_track("DONE")
        while True:
            result = ava_predictor_worker.read()
            if result is None:
                sleep(0.1)
                continue
            if result == "done":
                break

            video_writer.send(result)

        video_writer.send("DONE")
        print("Wait for the writer process to finish...")
        video_writer.progress_bar(i)

    video_writer.close()
    ava_predictor_worker.terminate()

if __name__ == "__main__":
    main()
# -----------------------------------------------------
# Copyright (c) Shanghai Jiao Tong University. All rights reserved.
# Written by Xinzhi MU ([email protected])
# -----------------------------------------------------

import argparse
from time import sleep
from itertools import count
from tqdm import tqdm

import torch

from visualizer import AVAVisualizer
from action_predictor import AVAPredictorWorker

# PyTorch issue #973: raise the soft limit on open file descriptors to the
# hard limit, to avoid "Too many open files" errors when worker processes
# share tensors.
import resource

rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (rlimit[1], rlimit[1]))

def main():
    parser = argparse.ArgumentParser(description='Action Detection Demo')
    parser.add_argument(
        "--webcam",
        dest="webcam",
        help="Use the webcam as input",
        action="store_true",
    )
    parser.add_argument(
        "--video-path",
        default="input.mp4",
        help="The path to the input video",
        type=str,
    )
    parser.add_argument(
        "--output-path",
        default="output.mp4",
        help="The path to the video output",
        type=str,
    )
    parser.add_argument(
        "--cpu",
        dest="cpu",
        help="Use the CPU",
        action="store_true",
    )
    parser.add_argument(
        "--cfg-path",
        default="../config_files/resnet101_8x8f_denseserial.yaml",
        help="The path to the cfg file",
        type=str,
    )
    parser.add_argument(
        "--weight-path",
        default="../data/models/aia_models/resnet101_8x8f_denseserial.pth",
        help="The path to the model weights",
        type=str,
    )
    parser.add_argument(
        "--visual-threshold",
        default=0.5,
        help="The confidence threshold for the visualizer",
        type=float,
    )
    parser.add_argument(
        "--start",
        default=0,
        help="Start reading the video at this millisecond",
        type=int,
    )
    parser.add_argument(
        "--duration",
        default=-1,
        help="The duration of detection",
        type=int,
    )
    parser.add_argument(
        "--detect-rate",
        default=4,
        help="Rate (fps) at which to update action labels",
        type=int,
    )
    parser.add_argument(
        "--common-cate",
        default=False,
        help="Use the common-category model",
        action="store_true",
    )
    parser.add_argument(
        "--hide-time",
        default=False,
        help="Do not show the timestamp in the corner",
        action="store_true",
    )
    parser.add_argument(
        "--tracker-box-thres",
        default=0.1,
        help="The box threshold for the tracker",
        type=float,
    )
    parser.add_argument(
        "--tracker-nms-thres",
        default=0.4,
        help="The NMS threshold for the tracker",
        type=float,
    )

    args = parser.parse_args()

    args.input_path = 0 if args.webcam else args.video_path
    args.device = torch.device("cpu" if args.cpu else "cuda")
    args.realtime = args.webcam

    # Configuration for the tracker. Multi-GPU is currently not supported.
    args.gpus = "0"
    args.gpus = [int(i) for i in args.gpus.split(',')] if torch.cuda.device_count() >= 1 else [-1]
    args.min_box_area = 0
    args.tracking = True
    args.detector = "tracker"
    args.debug = False

    if args.webcam:
        print('Starting webcam demo, press Ctrl + C to terminate...')
    else:
        print('Starting video demo, video path: {}'.format(args.video_path))

    # Initialise the visualizer, which also writes the output video.
    video_writer = AVAVisualizer(
        args.input_path,
        args.output_path,
        args.realtime,
        args.start,
        args.duration,
        (not args.hide_time),
        confidence_threshold=args.visual_threshold,
        common_cate=args.common_cate,
    )

    # Use forkserver so CUDA can be initialised safely in child processes, and
    # share tensors through the file system to avoid open-file-descriptor limits.
    torch.multiprocessing.set_start_method('forkserver', force=True)
    torch.multiprocessing.set_sharing_strategy('file_system')

    ava_predictor_worker = AVAPredictorWorker(args)
    pred_done_flag = False

    print("Showing tracking progress bar (in fps). Other processes are running in the background.")
    try:
        for i in tqdm(count(), desc="Tracker Progress", unit=" frame"):
            with torch.no_grad():
                (orig_img, boxes, scores, ids) = ava_predictor_worker.read_track()

                if orig_img is None:
                    if not args.realtime:
                        ava_predictor_worker.compute_prediction()
                    break

                if args.realtime:
                    result = ava_predictor_worker.read()
                    flag = video_writer.realtime_write_frame(result, orig_img, boxes, scores, ids)
                    if not flag:
                        break
                else:
                    video_writer.send_track((boxes, ids))
                    # Drain whatever predictions are already available so the
                    # result queue does not back up while tracking continues.
                    while not pred_done_flag:
                        result = ava_predictor_worker.read()
                        if result is None:
                            break
                        elif result == "done":
                            pred_done_flag = True
                        else:
                            video_writer.send(result)
    except KeyboardInterrupt:
        print("Keyboard Interrupted")

    if not args.realtime:
        video_writer.send_track("DONE")
        while not pred_done_flag:
            result = ava_predictor_worker.read()
            if result is None:
                sleep(0.1)
            elif result == "done":
                pred_done_flag = True
            else:
                video_writer.send(result)

        video_writer.send("DONE")
        tqdm.write("Showing video writer progress (in fps).")
        video_writer.progress_bar(i)

    video_writer.close()
    ava_predictor_worker.terminate()

if __name__ == "__main__":
    main()
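The functional change in this commit sits in the offline branch of the main loop: the old demo read action predictions only after tracking finished, while the new one also drains whatever results are already available on each tracking iteration (guarded by `pred_done_flag`), so predictions flow to the writer while the video is still being processed. A minimal sketch of this non-blocking drain pattern, using a standard library queue rather than the repository's worker API:

```python
import queue


def drain_ready_results(results: queue.Queue, writer) -> bool:
    """Forward every result that is already available, without blocking.
    Return True once the producer has signalled completion with "done"."""
    while True:
        try:
            result = results.get_nowait()  # non-blocking, like read() returning None
        except queue.Empty:
            return False  # nothing ready yet; the caller resumes tracking
        if result == "done":
            return True  # producer finished; no further draining needed
        writer.send(result)  # hand the prediction to the video writer
```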