-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
06f5400
commit f611c35
Showing
16 changed files
with
514 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
# posenetv2-pythontf | ||
This is a Python and TensorFlow implementation of PoseNet v2, originally released by Google for TensorFlow.js. | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
MOBILENET FULL PRECISION STRIDE 16 MODEL | ||
# MODEL LINK | ||
https://storage.googleapis.com/tfjs-models/savedmodel/posenet/mobilenet/float/100/model-stride16.json | ||
|
||
MOBILENET FULL PRECISION STRIDE 8 MODEL | ||
https://storage.googleapis.com/tfjs-models/savedmodel/posenet/mobilenet/float/100/model-stride8.json | ||
|
||
# MODEL WEIGHT | ||
https://storage.googleapis.com/tfjs-models/savedmodel/posenet/mobilenet/float/100/group1-shard1of4.bin | ||
https://storage.googleapis.com/tfjs-models/savedmodel/posenet/mobilenet/float/100/group1-shard2of4.bin | ||
https://storage.googleapis.com/tfjs-models/savedmodel/posenet/mobilenet/float/100/group1-shard3of4.bin | ||
https://storage.googleapis.com/tfjs-models/savedmodel/posenet/mobilenet/float/100/group1-shard4of4.bin | ||
|
||
#MORE LINKS WILL BE UPDATED SOON |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import tensorflow as tf | ||
import cv2 | ||
import time | ||
import argparse | ||
import os | ||
|
||
import posenet | ||
|
||
|
||
# Default directory in which load_model looks for '<model>.pb' files.
MODEL_DIR = './models'
# When true, load_model prints the name of every node in the loaded graph.
DEBUG_OUTPUT = False

# Command-line options. They are parsed once at import time and read by
# main() through the module-level ``args`` namespace.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model',
    type=str,
    default='model-resnet_v2',
)
parser.add_argument(
    '--output_stride',
    type=int,
    default=16,
)
parser.add_argument(
    '--scale_factor',
    type=float,
    default=1.0,
)
parser.add_argument(
    '--notxt',
    action='store_true',
)
parser.add_argument(
    '--image_dir',
    type=str,
    default='./images',
)
parser.add_argument(
    '--output_dir',
    type=str,
    default='./output',
)
args = parser.parse_args()
|
||
def load_model(model_name, sess, model_dir=MODEL_DIR):
    """Import a frozen graph into ``sess`` and return its four output tensors.

    The graph is read from ``<model_dir>/<model_name>.pb``.

    Args:
        model_name: File name of the frozen graph, without the ``.pb`` suffix.
        sess: TensorFlow session whose graph receives the imported nodes.
        model_dir: Directory containing the ``.pb`` file.

    Returns:
        ``[heatmaps, offsets, displacement_fwd, displacement_bwd]`` tensors
        looked up by name in ``sess.graph``.

    Raises:
        FileNotFoundError: If the model file does not exist.
    """
    model_path = os.path.join(model_dir, '%s.pb' % model_name)
    if not os.path.exists(model_path):
        # Fail fast with a clear message instead of printing and then
        # crashing inside the file read below.
        raise FileNotFoundError('Cannot find model file %s' % model_path)

    with tf.gfile.GFile(model_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # as_default() only has an effect when used as a context manager; a bare
    # call does not make sess.graph the import target.
    with sess.graph.as_default():
        tf.import_graph_def(graph_def, name='')

    if DEBUG_OUTPUT:
        for node in graph_def.node:
            print('Loaded graph node:', node.name)

    graph = sess.graph
    # For Mobilenet Version
    offsets = graph.get_tensor_by_name('MobilenetV1/offset_2/BiasAdd:0')
    displacement_fwd = graph.get_tensor_by_name('MobilenetV1/displacement_fwd_2/BiasAdd:0')
    displacement_bwd = graph.get_tensor_by_name('MobilenetV1/displacement_bwd_2/BiasAdd:0')
    heatmaps = graph.get_tensor_by_name('MobilenetV1/heatmap_2/BiasAdd:0')
    # For Resnet50 Version the tensor names noted by the original author are:
    #   offsets:          'float_short_offsets:0'
    #   displacement_fwd: 'resnet_v1_50/displacement_fwd_2/BiasAdd:0'
    #   displacement_bwd: 'resnet_v1_50/displacement_bwd_2/BiasAdd:0'
    #   heatmaps:         'float_heatmaps:0'

    return [heatmaps, offsets, displacement_fwd, displacement_bwd]
|
||
|
||
def main():
    """Run pose estimation on every .png/.jpg file in ``args.image_dir``.

    For each image the model outputs are evaluated with ``'sub_2:0'`` as the
    fed tensor, decoded with ``posenet.decode_multiple_poses`` and rescaled
    by the scale returned from ``posenet.read_imgfile``. If
    ``args.output_dir`` is set, an annotated copy of the image is written
    there. Unless ``args.notxt`` is set, pose and keypoint scores are
    printed. The average number of images processed per second is printed
    at the end.
    """
    with tf.Session() as sess:
        model_outputs = load_model(args.model, sess)
        output_stride = args.output_stride  # must match the loaded model

        if args.output_dir and not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

        filenames = []
        for entry in os.scandir(args.image_dir):
            if entry.is_file() and entry.path.endswith(('.png', '.jpg')):
                filenames.append(entry.path)

        start = time.time()
        for f in filenames:
            input_image, draw_image, output_scale = posenet.read_imgfile(
                f, scale_factor=args.scale_factor, output_stride=output_stride)

            results = sess.run(model_outputs, feed_dict={'sub_2:0': input_image})
            heatmaps_result, offsets_result, displacement_fwd_result, displacement_bwd_result = results

            pose_scores, keypoint_scores, keypoint_coords = posenet.decode_multiple_poses(
                heatmaps_result.squeeze(axis=0),
                offsets_result.squeeze(axis=0),
                displacement_fwd_result.squeeze(axis=0),
                displacement_bwd_result.squeeze(axis=0),
                output_stride=output_stride,
                max_pose_detections=10,
                min_pose_score=0.25)

            keypoint_coords *= output_scale

            if args.output_dir:
                draw_image = posenet.draw_skel_and_kp(
                    draw_image, pose_scores, keypoint_scores, keypoint_coords,
                    min_pose_score=0.25, min_part_score=0.25)

                # Keep the image's path relative to the input directory.
                relative_name = os.path.relpath(f, args.image_dir)
                cv2.imwrite(os.path.join(args.output_dir, relative_name), draw_image)

            if not args.notxt:
                print()
                print("Results for image: %s" % f)
                for pose_index, pose_score in enumerate(pose_scores):
                    # Stop at the first pose whose score is exactly zero.
                    if pose_score == 0.:
                        break
                    print('Pose #%d, score = %f' % (pose_index, pose_score))
                    part_scores = keypoint_scores[pose_index, :]
                    part_coords = keypoint_coords[pose_index, :, :]
                    for part_index, (s, c) in enumerate(zip(part_scores, part_coords)):
                        print('Keypoint %s, score = %f, coord = %s' % (posenet.PART_NAMES[part_index], s, c))

        elapsed = time.time() - start
        print('Average FPS:', len(filenames) / elapsed)
|
||
|
||
if __name__ == "__main__": | ||
main() |
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from posenet.constants import * | ||
from posenet.decode_multi import decode_multiple_poses | ||
from posenet.utils import * |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
|
||
# Keypoint names; a name's position in this list is its keypoint id.
PART_NAMES = [
    "nose",
    "leftEye",
    "rightEye",
    "leftEar",
    "rightEar",
    "leftShoulder",
    "rightShoulder",
    "leftElbow",
    "rightElbow",
    "leftWrist",
    "rightWrist",
    "leftHip",
    "rightHip",
    "leftKnee",
    "rightKnee",
    "leftAnkle",
    "rightAnkle",
]

NUM_KEYPOINTS = len(PART_NAMES)

# Reverse lookup: keypoint name -> keypoint id.
PART_IDS = dict(zip(PART_NAMES, range(NUM_KEYPOINTS)))

# Pairs of keypoint names that are connected to each other.
CONNECTED_PART_NAMES = [
    ("leftHip", "leftShoulder"),
    ("leftElbow", "leftShoulder"),
    ("leftElbow", "leftWrist"),
    ("leftHip", "leftKnee"),
    ("leftKnee", "leftAnkle"),
    ("rightHip", "rightShoulder"),
    ("rightElbow", "rightShoulder"),
    ("rightElbow", "rightWrist"),
    ("rightHip", "rightKnee"),
    ("rightKnee", "rightAnkle"),
    ("leftShoulder", "rightShoulder"),
    ("leftHip", "rightHip"),
]

# The same pairs expressed as keypoint ids.
CONNECTED_PART_INDICES = [
    (PART_IDS[first], PART_IDS[second])
    for first, second in CONNECTED_PART_NAMES
]

LOCAL_MAXIMUM_RADIUS = 1

# (parent, child) keypoint-name edges; an edge's position is its edge id.
POSE_CHAIN = [
    ("nose", "leftEye"),
    ("leftEye", "leftEar"),
    ("nose", "rightEye"),
    ("rightEye", "rightEar"),
    ("nose", "leftShoulder"),
    ("leftShoulder", "leftElbow"),
    ("leftElbow", "leftWrist"),
    ("leftShoulder", "leftHip"),
    ("leftHip", "leftKnee"),
    ("leftKnee", "leftAnkle"),
    ("nose", "rightShoulder"),
    ("rightShoulder", "rightElbow"),
    ("rightElbow", "rightWrist"),
    ("rightShoulder", "rightHip"),
    ("rightHip", "rightKnee"),
    ("rightKnee", "rightAnkle"),
]

# POSE_CHAIN expressed as (parent id, child id) tuples.
PARENT_CHILD_TUPLES = [
    (PART_IDS[parent_name], PART_IDS[child_name])
    for parent_name, child_name in POSE_CHAIN
]

PART_CHANNELS = [
    'left_face', 'right_face',
    'right_upper_leg_front', 'right_lower_leg_back',
    'right_upper_leg_back', 'left_lower_leg_front',
    'left_upper_leg_front', 'left_upper_leg_back',
    'left_lower_leg_back', 'right_feet',
    'right_lower_leg_front', 'left_feet',
    'torso_front', 'torso_back',
    'right_upper_arm_front', 'right_upper_arm_back',
    'right_lower_arm_back', 'left_lower_arm_front',
    'left_upper_arm_front', 'left_upper_arm_back',
    'left_lower_arm_back', 'right_hand',
    'right_lower_arm_front', 'left_hand',
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import numpy as np | ||
|
||
from posenet.constants import * | ||
|
||
|
||
def traverse_to_targ_keypoint(
        edge_id, source_keypoint, target_keypoint_id, scores, offsets, output_stride, displacements
):
    """Follow one displacement edge from a source keypoint to its target.

    The source position is snapped to a cell of the ``scores`` grid and the
    displacement stored there for ``edge_id`` is added to it. The displaced
    position is snapped to the grid again; the target's score and offset are
    read from that cell.

    Args:
        edge_id: Index into the third axis of ``displacements``.
        source_keypoint: Position of the source keypoint in image
            coordinates (presumably a 2-element array in row, column order;
            confirm against the caller).
        target_keypoint_id: Index into the third axis of ``scores`` and
            ``offsets``.
        scores: Array whose first two axes are the grid height and width.
        offsets: Array indexed as ``[row, col, keypoint_id]``.
        output_stride: Size of one grid cell in image coordinates.
        displacements: Array indexed as ``[row, col, edge_id]``.

    Returns:
        ``(score, image_coord)`` for the target keypoint.
    """
    grid_rows, grid_cols = scores.shape[0], scores.shape[1]
    last_cell = [grid_rows - 1, grid_cols - 1]

    def to_cell(point):
        # Nearest grid cell, clamped so it never leaves the grid.
        nearest = np.round(point / output_stride)
        return np.clip(nearest, a_min=0, a_max=last_cell).astype(np.int32)

    src_row, src_col = to_cell(source_keypoint)
    displaced_point = source_keypoint + displacements[src_row, src_col, edge_id]

    target_cell = to_cell(displaced_point)
    tgt_row, tgt_col = target_cell

    score = scores[tgt_row, tgt_col, target_keypoint_id]
    refinement = offsets[tgt_row, tgt_col, target_keypoint_id]
    image_coord = target_cell * output_stride + refinement

    return score, image_coord
|
||
|
||
def decode_pose(
        root_score, root_id, root_image_coord,
        scores,
        offsets,
        output_stride,
        displacements_fwd,
        displacements_bwd
):
    """Build one pose by propagating from a root keypoint along the pose chain.

    The root keypoint is stored first. The edges in ``PARENT_CHILD_TUPLES``
    are then walked in reverse order (child to parent, using
    ``displacements_bwd``) and afterwards in forward order (parent to child,
    using ``displacements_fwd``). An edge is followed only when its source
    keypoint already has a positive score and its target score is still
    zero.

    Returns:
        ``(instance_keypoint_scores, instance_keypoint_coords)``: an array
        of ``scores.shape[2]`` scores and an array of ``(scores.shape[2], 2)``
        coordinates.
    """
    part_count = scores.shape[2]

    instance_keypoint_scores = np.zeros(part_count)
    instance_keypoint_coords = np.zeros((part_count, 2))
    instance_keypoint_scores[root_id] = root_score
    instance_keypoint_coords[root_id] = root_image_coord

    def propagate(edge, from_id, to_id, displacements):
        # Only fill in a keypoint that is still unset, from one that is set.
        if not (instance_keypoint_scores[from_id] > 0.0 and
                instance_keypoint_scores[to_id] == 0.0):
            return
        found_score, found_coords = traverse_to_targ_keypoint(
            edge,
            instance_keypoint_coords[from_id],
            to_id,
            scores, offsets, output_stride, displacements)
        instance_keypoint_scores[to_id] = found_score
        instance_keypoint_coords[to_id] = found_coords

    indexed_edges = list(enumerate(PARENT_CHILD_TUPLES))

    # Backward pass: last edge first, moving from child to parent.
    for edge, (parent_id, child_id) in reversed(indexed_edges):
        propagate(edge, child_id, parent_id, displacements_bwd)

    # Forward pass: first edge first, moving from parent to child.
    for edge, (parent_id, child_id) in indexed_edges:
        propagate(edge, parent_id, child_id, displacements_fwd)

    return instance_keypoint_scores, instance_keypoint_coords
Oops, something went wrong.