Skip to content

Commit

Permalink
SSDTrafficLightsClassifier Wrapped in Class (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergey Morozov authored Jan 25, 2019
1 parent 206641d commit 342eade
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 29 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import rospy
from unittest import TestCase
from light_classification.tl_classifier import TLClassifier, OpenCVTrafficLightsClassifier
from light_classification.tl_classifier import TLClassifier, OpenCVTrafficLightsClassifier, SSDTrafficLightsClassifier
from roslaunch.parent import ROSLaunchParent


Expand All @@ -24,16 +24,18 @@ def setUp(self):
TLClassifier.INSTANCE = None
TLClassifier.KNOWN_TRAFFIC_LIGHT_CLASSIFIERS = {}
TLClassifier.register_subclass("opencv")(OpenCVTrafficLightsClassifier)
TLClassifier.register_subclass("ssd")(SSDTrafficLightsClassifier)

def tearDown(self):
TLClassifier.INSTANCE = None
TLClassifier.KNOWN_TRAFFIC_LIGHT_CLASSIFIERS = {}
TLClassifier.register_subclass("opencv")(OpenCVTrafficLightsClassifier)
TLClassifier.register_subclass("ssd")(SSDTrafficLightsClassifier)

def test_get_instance_of(self):
instance = TLClassifier.get_instance_of("opencv")
self.assertIsInstance(instance, OpenCVTrafficLightsClassifier)
self.assertEqual(1, len(TLClassifier.KNOWN_TRAFFIC_LIGHT_CLASSIFIERS))
self.assertEqual(2, len(TLClassifier.KNOWN_TRAFFIC_LIGHT_CLASSIFIERS))

def test_classify(self):

Expand All @@ -45,6 +47,9 @@ def __init__(self):
def _classify(self, image):
pass

def get_state_count_threshold(self, last_state):
pass

rospy.init_node('test_tl_classifier', anonymous=True)
mock_instance = TLClassifier.get_instance_of('mock')
for i in range(100):
Expand Down
81 changes: 81 additions & 0 deletions ros/src/tl_detector/light_classification/tl_classifier.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import rospy
import rospkg
import os
import cv2
import sys
import threading
import numpy as np
import tensorflow as tf
from styx_msgs.msg import TrafficLight
from abc import ABCMeta, abstractmethod

Expand Down Expand Up @@ -107,6 +110,16 @@ def classify(self, image):

return tl_state

@abstractmethod
def get_state_count_threshold(self, last_state):
"""
Returns state count threshold value based on the last state.
:param last_state: last traffic lights state
:return: threshold value
:rtype: int
"""
raise NotImplementedError()

@abstractmethod
def __init__(self, cls_name):
"""
Expand All @@ -128,6 +141,9 @@ class OpenCVTrafficLightsClassifier(TLClassifier):
Detects and classifies traffic lights on images with Computer Vision techniques.
"""

def get_state_count_threshold(self, last_state):
return 3

def _classify(self, image):
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

Expand Down Expand Up @@ -158,3 +174,68 @@ def _classify(self, image):

def __init__(self):
super(OpenCVTrafficLightsClassifier, self).__init__(self.__class__.__name__)


@TLClassifier.register_subclass("ssd")
class SSDTrafficLightsClassifier(TLClassifier):

def get_state_count_threshold(self, last_state):
if last_state == TrafficLight.RED:
# High threshold for accelerating
return 3

# Low threshold for stopping
return 1

@staticmethod
def load_graph(graph_file):
"""Loads a frozen inference graph"""
graph = tf.Graph()
with graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(graph_file, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')

return graph

def _classify(self, image):
image_np = np.expand_dims(np.asarray(image, dtype=np.uint8), 0)
# Actual detection
(boxes, scores, classes) = self.sess.run([self.detection_boxes, self.detection_scores, self.detection_classes],
feed_dict={self.image_tensor: image_np})

# Remove unnecessary dimensions
scores = np.squeeze(scores)
classes = np.squeeze(classes)

for i, clazz in enumerate(classes):
rospy.logdebug('class = %s, score = %s', self.labels_dict[classes[i]], str(scores[i]))
# if red or yellow light with confidence more than 10%
if (clazz == 2 or clazz == 3) and scores[i] > 0.1:
return TrafficLight.RED

return TrafficLight.UNKNOWN

def __init__(self):
super(SSDTrafficLightsClassifier, self).__init__(self.__class__.__name__)

# Model path
package_root_path = rospkg.RosPack().get_path('tl_detector')
model_path = os.path.join(package_root_path, 'models/ssd.pb')

# Labels dictionary
self.labels_dict = {1: 'Green', 2: 'Red', 3: 'Yellow', 4: 'Unknown'}

# Load frozen graph of trained model
self.detection_graph = self.load_graph(model_path)

# Get tensors
self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
self.detection_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
self.detection_scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
self.detection_classes = self.detection_graph.get_tensor_by_name('detection_classes:0')

# Create session
self.sess = tf.Session(graph=self.detection_graph)
Empty file.
2 changes: 1 addition & 1 deletion ros/src/tl_detector/sim_traffic_light_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ stop_line_positions:
- [161.76, 2303.82]
- [351.84, 1574.65]
is_site: False
classifier: "opencv" # opencv, dl_ssd, simple_cnn, yolo
classifier: "opencv" # opencv, ssd, simple_cnn, yolo
2 changes: 1 addition & 1 deletion ros/src/tl_detector/site_traffic_light_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ camera_info:
stop_line_positions:
- [8.0, 16.2]
is_site: True
classifier: "dl_ssd" # opencv, dl_ssd, simple_cnn, yolo
classifier: "ssd" # opencv, ssd, simple_cnn, yolo
39 changes: 14 additions & 25 deletions ros/src/tl_detector/tl_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,8 @@
from scipy.spatial import KDTree
from light_classification.tl_classifier import TLClassifier
import tf
import cv2
import yaml

STATE_COUNT_THRESHOLD = 3


class TLDetector(object):
def __init__(self):
Expand Down Expand Up @@ -45,7 +42,6 @@ def __init__(self):
rospy.Subscriber('/current_pose', PoseStamped, self.pose_cb)
rospy.Subscriber('/base_waypoints', Lane, self.waypoints_cb)


# /vehicle/traffic_lights provides you with the location of the traffic light in 3D map space and
# helps you acquire an accurate ground truth data source for the traffic light
# classifier by sending the current color state of all traffic lights in the
Expand Down Expand Up @@ -92,24 +88,21 @@ def image_cb(self, msg):
Callback function for /image_color topic subscriber.
Identifies red lights in the incoming camera image and publishes the index
of the waypoint closest to the red light's stop line to /traffic_waypoint.
:param msg: image from car-mounted camera
:type msg: Image
"""
self.has_image = True
self.camera_image_msg = msg
light_wp, state = self.process_traffic_lights()

'''
Publish upcoming red lights at camera frequency.
Each predicted state has to occur `STATE_COUNT_THRESHOLD` number
of times till we start using it. Otherwise the previous stable state is
used.
'''
# Publish upcoming red lights at camera frequency.
# Each predicted state has to occur `self.light_classifier.get_state_count_threshold(self.last_state)` number
# of times till we start using it. Otherwise the previous stable state is used.

if self.state != state:
self.state_count = 0
self.state = state
elif self.state_count >= STATE_COUNT_THRESHOLD:
elif self.state_count >= self.light_classifier.get_state_count_threshold(self.last_state):
self.last_state = self.state
light_wp = light_wp if state == TrafficLight.RED else -1
self.last_wp = light_wp
Expand All @@ -130,17 +123,14 @@ def get_closest_waypoint(self, x, y):
return self.waypoint_tree.query((x, y), 1)[1]

def get_light_state(self, light):
"""Determines the current color of the traffic light
Args:
light (TrafficLight): light to classify
Returns:
int: ID of traffic light color (specified in styx_msgs/TrafficLight)
"""
Determines the current color of the traffic light.
:param light: light to classify
:type light: TrafficLight
:return: ID of traffic light color (specified in styx_msgs/TrafficLight)
:rtype: int
"""
if not self.has_image:
self.prev_light_loc = None
return False

cv_image = self.bridge.imgmsg_to_cv2(self.camera_image_msg, "bgr8")
Expand All @@ -150,10 +140,9 @@ def get_light_state(self, light):
def process_traffic_lights(self):
"""
Finds closest visible traffic light, if one exists, and determines its location and color.
Returns:
int: index of waypoint closes to the upcoming stop line for a traffic light (-1 if none exists)
int: ID of traffic light color (specified in styx_msgs/TrafficLight)
:return - index of waypoint closes to the upcoming stop line for a traffic light (-1 if none exists)
- ID of traffic light color (specified in styx_msgs/TrafficLight)
:rtype tuple(int, int)
"""
closest_light = None
line_wp_idx = None
Expand Down

0 comments on commit 342eade

Please sign in to comment.