Skip to content

Commit

Permalink
Code for YOLOv3-tiny Classifier (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
Karthikeya108 authored and Sergey Morozov committed Jan 28, 2019
1 parent fd5b617 commit 61bf056
Show file tree
Hide file tree
Showing 95 changed files with 42,575 additions and 176 deletions.
6 changes: 4 additions & 2 deletions PROJECT_README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ polygon detection is performed and checked if the the number of sides is more th
* This is not robust enough, the thresholds need to be adjusted always.
* Doesnt work properly on real world data as there is lot of noise.

### Real World (Test Lot) --- YOLOv3-tiny (You Only Look Once)
We used this approach for real world.
TODO:write about it

### Real World (Test Lot) --- SSD (Single Shot Detection)
We need to solve both object detection - where in the image is the object,
and object classification --- given detections on an image, classify traffic lights.
Expand All @@ -59,8 +63,6 @@ For example, SSD (Single Shot Multibox Detection) and YOLO (You Only Look Once).
We attempted transfer learning using the pre-trained SSD_inception_v2 model trained on COCO dataset,
and retrain it on our own dataset for NUM_EPOCHS, achieving a final loss of FINAL_LOSS.

@Segey, pls elaborate how you collated the dataset.

Here is a sample of the dataset.
![Udacity Test Site training images](report/udacity_visualization.png)

Expand Down
2 changes: 1 addition & 1 deletion ros/launch/site.launch
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
<include file="$(find waypoint_updater)/launch/waypoint_updater.launch"/>

<!--Traffic Light Locations and Camera Config -->
<param name="traffic_light_config" textfile="$(find tl_detector)/site_traffic_light_config.yaml" />
<param name="traffic_light_config" textfile="$(find tl_detector)/config/site_traffic_light_config.yaml" />

<!--Traffic Light Detector Node -->
<include file="$(find tl_detector)/launch/tl_detector_site.launch"/>
Expand Down
2 changes: 1 addition & 1 deletion ros/launch/styx.launch
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
<include file="$(find tl_detector)/launch/tl_detector.launch"/>

<!--Traffic Light Locations and Camera Config -->
<param name="traffic_light_config" textfile="$(find tl_detector)/sim_traffic_light_config.yaml" />
<param name="traffic_light_config" textfile="$(find tl_detector)/config/sim_traffic_light_config.yaml" />
</launch>
8 changes: 4 additions & 4 deletions ros/src/styx/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

conf = AttrDict({
'subscribers': [
{'topic':'/vehicle/steering_cmd', 'type': 'steer_cmd', 'name': 'steering'},
{'topic':'/vehicle/throttle_cmd', 'type': 'throttle_cmd', 'name': 'throttle'},
{'topic':'/vehicle/brake_cmd', 'type': 'brake_cmd', 'name': 'brake'},
{'topic':'/final_waypoints', 'type': 'path_draw', 'name': 'path'},
{'topic': '/vehicle/steering_cmd', 'type': 'steer_cmd', 'name': 'steering'},
{'topic': '/vehicle/throttle_cmd', 'type': 'throttle_cmd', 'name': 'throttle'},
{'topic': '/vehicle/brake_cmd', 'type': 'brake_cmd', 'name': 'brake'},
{'topic': '/final_waypoints', 'type': 'path_draw', 'name': 'path'},
],
'publishers': [
{'topic': '/current_pose', 'type': 'pose', 'name': 'current_pose'},
Expand Down
Empty file added ros/src/tl_detector/__init__.py
Empty file.
Binary file added ros/src/tl_detector/config/FiraMono-Medium.otf
Binary file not shown.
45 changes: 45 additions & 0 deletions ros/src/tl_detector/config/SIL Open Font License.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
Copyright (c) 2014, Mozilla Foundation https://mozilla.org/ with Reserved Font Name Fira Mono.

Copyright (c) 2014, Telefonica S.A.

This Font Software is licensed under the SIL Open Font License, Version 1.1.
This license is copied below, and is also available with a FAQ at: http://scripts.sil.org/OFL

-----------------------------------------------------------
SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007
-----------------------------------------------------------

PREAMBLE
The goals of the Open Font License (OFL) are to stimulate worldwide development of collaborative font projects, to support the font creation efforts of academic and linguistic communities, and to provide a free and open framework in which fonts may be shared and improved in partnership with others.

The OFL allows the licensed fonts to be used, studied, modified and redistributed freely as long as they are not sold by themselves. The fonts, including any derivative works, can be bundled, embedded, redistributed and/or sold with any software provided that any reserved names are not used by derivative works. The fonts and derivatives, however, cannot be released under any other type of license. The requirement for fonts to remain under this license does not apply to any document created using the fonts or their derivatives.

DEFINITIONS
"Font Software" refers to the set of files released by the Copyright Holder(s) under this license and clearly marked as such. This may include source files, build scripts and documentation.

"Reserved Font Name" refers to any names specified as such after the copyright statement(s).

"Original Version" refers to the collection of Font Software components as distributed by the Copyright Holder(s).

"Modified Version" refers to any derivative made by adding to, deleting, or substituting -- in part or in whole -- any of the components of the Original Version, by changing formats or by porting the Font Software to a new environment.

"Author" refers to any designer, engineer, programmer, technical writer or other person who contributed to the Font Software.

PERMISSION & CONDITIONS
Permission is hereby granted, free of charge, to any person obtaining a copy of the Font Software, to use, study, copy, merge, embed, modify, redistribute, and sell modified and unmodified copies of the Font Software, subject to the following conditions:

1) Neither the Font Software nor any of its individual components, in Original or Modified Versions, may be sold by itself.

2) Original or Modified Versions of the Font Software may be bundled, redistributed and/or sold with any software, provided that each copy contains the above copyright notice and this license. These can be included either as stand-alone text files, human-readable headers or in the appropriate machine-readable metadata fields within text or binary files as long as those fields can be easily viewed by the user.

3) No Modified Version of the Font Software may use the Reserved Font Name(s) unless explicit written permission is granted by the corresponding Copyright Holder. This restriction only applies to the primary font name as presented to the users.

4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font Software shall not be used to promote, endorse or advertise any Modified Version, except to acknowledge the contribution(s) of the Copyright Holder(s) and the Author(s) or with their explicit written permission.

5) The Font Software, modified or unmodified, in part or in whole, must be distributed entirely under this license, and must not be distributed under any other license. The requirement for fonts to remain under this license does not apply to any document created using the Font Software.

TERMINATION
This license becomes null and void if any of the above conditions are not met.

DISCLAIMER
THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM OTHER DEALINGS IN THE FONT SOFTWARE.
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ stop_line_positions:
- [161.76, 2303.82]
- [351.84, 1574.65]
is_site: False
classifier: "opencv" # opencv, ssd, simple_cnn, yolo
classifier: "opencv" # opencv, ssd, simple_cnn, yolo-tiny
is_debug: false
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ camera_info:
stop_line_positions:
- [8.0, 16.2]
is_site: True
classifier: "ssd" # opencv, ssd, simple_cnn, yolo
classifier: "yolo-tiny" # opencv, ssd, simple_cnn, yolo-tiny
is_debug: false
1 change: 1 addition & 0 deletions ros/src/tl_detector/config/tiny_yolo_anchors.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
10,14, 23,27, 37,58, 81,82, 135,169, 344,319
3 changes: 3 additions & 0 deletions ros/src/tl_detector/config/traffic_lights_classes.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
red
yellow
green
4 changes: 4 additions & 0 deletions ros/src/tl_detector/light_classification/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from tl_classifier import TLClassifier
from opencv_tl_classifier import OpenCVTLClassifier
from ssd_tl_classifier import SSDTLClassifier
from light_classification.yolo.yolo_tiny_tl_classifier import YOLOTinyTLClassifier
52 changes: 52 additions & 0 deletions ros/src/tl_detector/light_classification/opencv_tl_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import cv2
import rospy

import numpy as np

from styx_msgs.msg import TrafficLight
from light_classification.tl_classifier import TLClassifier


@TLClassifier.register_subclass('opencv')
class OpenCVTLClassifier(TLClassifier):
"""
Detects and classifies traffic lights on images with Computer Vision techniques.
"""

def get_state_count_threshold(self, last_state):
return 3

def _classify(self, image):
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

lower_red1 = np.array([0, 100, 100])
upper_red1 = np.array([10, 255, 255])
lower_red2 = np.array([160, 100, 100])
upper_red2 = np.array([179, 255, 255])

mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
red_img = cv2.addWeighted(mask1, 1.0, mask2, 1.0, 0)

im, contours, hierarchy = cv2.findContours(red_img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
red_count = 0
for x, contour in enumerate(contours):
contourarea = cv2.contourArea(contour) # get area of contour
if 18 < contourarea < 900: # Discard contours with a too large area as this may just be noise
arclength = cv2.arcLength(contour, True)
approxcontour = cv2.approxPolyDP(contour, 0.01 * arclength, True)
# Check for Square
if len(approxcontour) > 5:
red_count += 1
rospy.logdebug("Red count: %d", red_count)

tl_id = TrafficLight.RED if red_count > 0 else TrafficLight.UNKNOWN

if self.is_debug:
# TODO: create a debug image
return tl_id, None

return tl_id, None

def __init__(self, is_debug):
super(OpenCVTLClassifier, self).__init__(self.__class__.__name__, is_debug)
74 changes: 74 additions & 0 deletions ros/src/tl_detector/light_classification/ssd_tl_classifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import os
import rospy
import rospkg

import tensorflow as tf
import numpy as np

from styx_msgs.msg import TrafficLight
from light_classification.tl_classifier import TLClassifier


@TLClassifier.register_subclass("ssd")
class SSDTLClassifier(TLClassifier):

def get_state_count_threshold(self, last_state):
if last_state == TrafficLight.RED:
# High threshold for accelerating
return 3

# Low threshold for stopping
return 1

@staticmethod
def load_graph(graph_file):
"""Loads a frozen inference graph"""
graph = tf.Graph()
with graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(graph_file, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')

return graph

def _classify(self, image):
image_np = np.expand_dims(np.asarray(image, dtype=np.uint8), 0)
# Actual detection
(boxes, scores, classes) = self.sess.run([self.detection_boxes, self.detection_scores, self.detection_classes],
feed_dict={self.image_tensor: image_np})

# Remove unnecessary dimensions
scores = np.squeeze(scores)
classes = np.squeeze(classes)

for i, clazz in enumerate(classes):
rospy.logdebug('class = %s, score = %s', self.labels_dict[classes[i]], str(scores[i]))
# if red or yellow light with confidence more than 10%
if (clazz == 2 or clazz == 3) and scores[i] > 0.1:
return TrafficLight.RED, None

return TrafficLight.UNKNOWN, None

def __init__(self, is_debug):
super(SSDTLClassifier, self).__init__(self.__class__.__name__, is_debug)

# Model path
package_root_path = rospkg.RosPack().get_path('tl_detector')
model_path = os.path.join(package_root_path, 'models/ssd.pb')

# Labels dictionary
self.labels_dict = {1: 'Green', 2: 'Red', 3: 'Yellow', 4: 'Unknown'}

# Load frozen graph of trained model
self.detection_graph = self.load_graph(model_path)

# Get tensors
self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
self.detection_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
self.detection_scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
self.detection_classes = self.detection_graph.get_tensor_by_name('detection_classes:0')

# Create session
self.sess = tf.Session(graph=self.detection_graph)
33 changes: 21 additions & 12 deletions ros/src/tl_detector/light_classification/test_tl_classifier.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import rospy
from unittest import TestCase
from light_classification.tl_classifier import TLClassifier, OpenCVTrafficLightsClassifier, SSDTrafficLightsClassifier
from light_classification.tl_classifier import TLClassifier
from light_classification.ssd_tl_classifier import SSDTLClassifier
from light_classification.yolo.yolo_tiny_tl_classifier import YOLOTinyTLClassifier
from light_classification.opencv_tl_classifier import OpenCVTLClassifier
from roslaunch.parent import ROSLaunchParent


Expand All @@ -23,35 +26,41 @@ def tearDownClass(cls):
def setUp(self):
TLClassifier.INSTANCE = None
TLClassifier.KNOWN_TRAFFIC_LIGHT_CLASSIFIERS = {}
TLClassifier.register_subclass("opencv")(OpenCVTrafficLightsClassifier)
TLClassifier.register_subclass("ssd")(SSDTrafficLightsClassifier)
TLClassifier.register_subclass("opencv")(OpenCVTLClassifier)
TLClassifier.register_subclass("ssd")(SSDTLClassifier)
TLClassifier.register_subclass("yolo-tiny")(YOLOTinyTLClassifier)

def tearDown(self):
TLClassifier.INSTANCE = None
TLClassifier.KNOWN_TRAFFIC_LIGHT_CLASSIFIERS = {}
TLClassifier.register_subclass("opencv")(OpenCVTrafficLightsClassifier)
TLClassifier.register_subclass("ssd")(SSDTrafficLightsClassifier)
TLClassifier.register_subclass("opencv")(OpenCVTLClassifier)
TLClassifier.register_subclass("ssd")(SSDTLClassifier)
TLClassifier.register_subclass("yolo-tiny")(YOLOTinyTLClassifier)

def test_get_instance_of(self):
instance = TLClassifier.get_instance_of("opencv")
self.assertIsInstance(instance, OpenCVTrafficLightsClassifier)
self.assertEqual(2, len(TLClassifier.KNOWN_TRAFFIC_LIGHT_CLASSIFIERS))
self.assertIsInstance(instance, OpenCVTLClassifier)
self.assertEqual(3, len(TLClassifier.KNOWN_TRAFFIC_LIGHT_CLASSIFIERS))

def test_classify(self):

@TLClassifier.register_subclass('mock')
class MockTLClassifier(TLClassifier):
def __init__(self):
super(MockTLClassifier, self).__init__(self.__class__.__name__)
def __init__(self, is_debug):
super(MockTLClassifier, self).__init__(self.__class__.__name__, is_debug)

def _classify(self, image):
pass
return None, None

def get_state_count_threshold(self, last_state):
pass

rospy.init_node('test_tl_classifier', anonymous=True)
mock_instance = TLClassifier.get_instance_of('mock')
for i in range(100):
for i in range(20):
mock_instance.classify(None)
self.assertEqual(20, len(mock_instance._start_time_circular_buffer))

for i in range(200):
mock_instance.classify(None)
self.assertEqual(100, mock_instance._counter)
self.assertEqual(100, len(mock_instance._start_time_circular_buffer))
Loading

0 comments on commit 61bf056

Please sign in to comment.