-
Notifications
You must be signed in to change notification settings - Fork 277
Differential Binarization model #2095
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
mehtamansi29
wants to merge
4
commits into
keras-team:master
Choose a base branch
from
mehtamansi29:diffbin
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,103
−0
Draft
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
ed97271
ImageText detector preprocessor for Differential Binarization model
mehtamansi29 d97f362
db_utils functions and testfile
mehtamansi29 de3aaae
Diffbin utils function and test file
mehtamansi29 9a3cf2a
diffbin utils function and testfile
mehtamansi29 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,312 @@ | ||
import keras | ||
|
||
|
||
class Point:
    """Minimal 2-D point with component-wise arithmetic helpers."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __add__(self, other):
        """Return a new point holding the component-wise sum."""
        summed_x = self.x + other.x
        summed_y = self.y + other.y
        return Point(summed_x, summed_y)

    def __sub__(self, other):
        """Return a new point holding the component-wise difference."""
        delta_x = self.x - other.x
        delta_y = self.y - other.y
        return Point(delta_x, delta_y)

    def __neg__(self):
        """Return a new point with both components negated."""
        return Point(-self.x, -self.y)

    def cross(self, other):
        """Return the scalar 2-D cross product `self x other`."""
        return self.x * other.y - self.y * other.x

    def to_tuple(self):
        """Return the point as an `(x, y)` tuple."""
        return self.x, self.y
|
||
|
||
def shrink_polygan(polygon, offset):
    """Pull every vertex of a polygon toward the mean of its vertices.

    Args:
        polygon: sequence of `Point` objects, or of indexable `(x, y)`
            pairs (these are wrapped in `Point`).
        offset: distance by which each vertex is moved toward the vertex
            mean. A vertex closer than `offset` lands on the mean itself.

    Returns:
        A list of `Point` objects, or `polygon` itself (unchanged) when it
        has fewer than three vertices.
    """
    vertex_count = len(polygon)
    if vertex_count < 3:
        return polygon

    if isinstance(polygon[0], Point):
        points = polygon
    else:
        points = [Point(p[0], p[1]) for p in polygon]

    center_x = sum(p.x for p in points) / vertex_count
    center_y = sum(p.y for p in points) / vertex_count

    def _pull_toward_center(vertex):
        dx = vertex.x - center_x
        dy = vertex.y - center_y
        # Floor the distance so a vertex sitting on the center does not
        # cause a division by zero.
        distance = max((dx**2 + dy**2) ** 0.5, 1e-6)
        ratio = max(0, 1 - offset / distance)
        return Point(center_x + dx * ratio, center_y + dy * ratio)

    return [_pull_toward_center(vertex) for vertex in points]
|
||
|
||
# Polygon Area | ||
def Polygon(coords):
    """Return the area enclosed by a polygon via the shoelace formula.

    Args:
        coords: array-like of vertices; after conversion to a float32
            tensor, column 0 is read as x and column 1 as y.

    Returns:
        A scalar tensor holding the (non-negative) polygon area.
    """
    vertices = keras.ops.convert_to_tensor(coords, dtype="float32")
    xs, ys = vertices[:, 0], vertices[:, 1]

    # Pair each vertex with its successor; the roll wraps the last vertex
    # around to the first.
    next_xs = keras.ops.roll(xs, shift=-1, axis=0)
    next_ys = keras.ops.roll(ys, shift=-1, axis=0)

    twice_signed_area = keras.ops.sum(xs * next_ys - next_xs * ys)
    return 0.5 * keras.ops.abs(twice_signed_area)
|
||
|
||
# binary search smallest width | ||
def binary_search_smallest_width(poly):
    """Binary-search the largest shrink offset that keeps a polygon non-empty.

    The polygon is repeatedly shrunk with `shrink_polygan`; an offset is
    accepted while the shrunk polygon's area stays above 0.1.

    Args:
        poly: sequence of vertices accepted by `shrink_polygan`.

    Returns:
        An `int`: 0 for polygons with fewer than three vertices, otherwise
        the truncated midpoint of the final search interval (0 if that
        midpoint is below 0.1).
    """
    if len(poly) < 3:
        return 0

    # NOTE(review): the search interval is fixed to [0, 1] and the result is
    # truncated with int(), so the midpoint is always < 1 and this function
    # currently always returns 0. Confirm the intended upper bound / return
    # type before relying on the value.
    low, high = 0, 1

    while high - low > 0.01:
        mid = (high + low) / 2
        mid_poly = shrink_polygan(poly, mid)
        mid_poly = keras.ops.cast(
            keras.ops.stack([[p.x, p.y] for p in mid_poly]), dtype="float32"
        )
        area = Polygon(mid_poly)

        if area > 0.1:
            low = mid
        else:
            high = mid

    height = (low + high) / 2
    return int(height) if height >= 0.1 else 0
|
||
|
||
# project point to line | ||
def project_point_to_line(x, u, v, axis=0):
    """Orthogonally project `x` onto the infinite line through `u` and `v`.

    Args:
        x: point(s) to project; converted to a float32 tensor.
        u: first point on the line; converted to a float32 tensor.
        v: second point on the line; converted to a float32 tensor.
        axis: axis holding the coordinate components.

    Returns:
        Tensor with the projected point(s).
    """
    point = keras.ops.convert_to_tensor(x, dtype="float32")
    start = keras.ops.convert_to_tensor(u, dtype="float32")
    end = keras.ops.convert_to_tensor(v, dtype="float32")

    direction = end - start
    length = keras.ops.norm(direction, axis=axis, keepdims=True)
    # epsilon keeps the normalisation finite when u and v coincide.
    unit = direction / (length + keras.backend.epsilon())

    along = keras.ops.sum((point - start) * unit, axis=axis, keepdims=True)
    return start + unit * along
|
||
|
||
# project_point_to_segment | ||
def project_point_to_segment(x, u, v, axis=0):
    """Project `x` onto the line segment with endpoints `u` and `v`.

    The point is first projected onto the infinite line through `u` and
    `v`; if that projection falls outside the segment it is replaced by the
    nearer endpoint.

    Args:
        x: point(s) to project.
        u: segment start point(s).
        v: segment end point(s).
        axis: axis holding the coordinate components.

    Returns:
        Tensor with the projected point(s), clamped to the segment.
    """
    # Convert the endpoints up front so the arithmetic and `where` calls
    # below work for plain lists / integer arrays, not only float tensors.
    u = keras.ops.convert_to_tensor(u, dtype="float32")
    v = keras.ops.convert_to_tensor(v, dtype="float32")

    p = project_point_to_line(x, u, v, axis=axis)
    # p lies on the line through u and v, so (u - p) and (v - p) are
    # collinear; a non-negative dot product means both endpoints lie on the
    # same side of p, i.e. p is not strictly inside the segment.
    outer = keras.ops.greater_equal(
        keras.ops.sum((u - p) * (v - p), axis=axis, keepdims=True), 0
    )
    near_u = keras.ops.less_equal(
        keras.ops.norm(u - p, axis=axis, keepdims=True),
        keras.ops.norm(v - p, axis=axis, keepdims=True),
    )
    return keras.ops.where(outer, keras.ops.where(near_u, u, v), p)
|
||
|
||
# get line of height | ||
def get_line_height(poly):
    """Return the line height of `poly`.

    Thin wrapper that delegates to `binary_search_smallest_width`.
    """
    line_height = binary_search_smallest_width(poly)
    return line_height
|
||
|
||
# cv2.fillpoly function with keras.ops | ||
def fill_poly_keras(vertices, image_shape):
    """Rasterise a polygon into a 0/1 mask using `keras.ops` only.

    Intended as a stand-in for `cv2.fillPoly`. Each pixel is classified
    with an even-odd ray-casting test: every polygon edge that a horizontal
    ray from the pixel crosses toggles the pixel's state.

    Args:
        vertices: array-like of `(x, y)` vertices; converted to float32.
        image_shape: `(height, width)` of the output mask.

    Returns:
        int32 tensor of shape `(height, width)` with 1 inside the polygon
        and 0 elsewhere.
    """
    height, width = image_shape
    xx, yy = keras.ops.meshgrid(
        keras.ops.arange(width), keras.ops.arange(height)
    )
    xx = keras.ops.cast(xx, "float32")
    yy = keras.ops.cast(yy, "float32")

    result = keras.ops.zeros((height, width), dtype="float32")

    vertices = keras.ops.convert_to_tensor(vertices, dtype="float32")
    num_vertices = vertices.shape[0]

    for i in range(num_vertices):
        x1, y1 = vertices[i]
        x2, y2 = vertices[(i + 1) % num_vertices]

        # Rows whose y lies in the half-open span (min(y1, y2), max(y1, y2)].
        in_span = (yy > keras.ops.minimum(y1, y2)) & (
            yy <= keras.ops.maximum(y1, y2)
        )

        # Horizontal edges (y1 == y2) would divide by zero here. `in_span`
        # is always False for them, so substituting 1 for the denominator
        # leaves the result unchanged while avoiding inf/NaN intermediates.
        dy = y2 - y1
        safe_dy = keras.ops.where(
            keras.ops.equal(dy, 0), keras.ops.ones_like(dy), dy
        )
        left_of_edge = xx < (x1 + (yy - y1) * (x2 - x1) / safe_dy)

        result = keras.ops.where(
            in_span & left_of_edge & ((y1 > yy) != (y2 > yy)),
            1 - result,
            result,
        )

    return keras.ops.cast(result, "int32")
|
||
|
||
# get mask | ||
def get_mask(w, h, polys, ignores):
    """Build a training mask from polygons and their ignore flags.

    The mask starts as all ones. Polygons are applied in order:
    - an ignored polygon sets the pixels it covers to 0;
    - a non-ignored polygon sets the pixels it covers (back) to 1.

    Args:
        w: mask width in pixels.
        h: mask height in pixels.
        polys: iterable of polygons; each is cast to int32 vertices.
        ignores: iterable of truthy/falsy flags, one per polygon.

    Returns:
        Tensor of shape `(h, w)` with values 0 and 1.
    """
    mask = keras.ops.ones((h, w), dtype="float32")

    for poly, ignore in zip(polys, ignores):
        poly = keras.ops.cast(keras.ops.convert_to_numpy(poly), dtype="int32")

        # Fewer than three vertices cannot enclose any area.
        if poly.shape[0] < 3:
            print("Skipping invalid polygon:", poly)
            continue

        poly_mask = fill_poly_keras(poly, (h, w))

        if ignore:
            mask = keras.ops.where(
                keras.ops.cast(poly_mask, "float32") == 1.0,
                keras.ops.zeros_like(mask),
                mask,
            )
        else:
            mask = keras.ops.maximum(mask, poly_mask)
    return mask
|
||
|
||
# get polygon coordinates projection | ||
def get_coords_poly_projection(coords, poly):
    """Project each point onto its nearest polygon edge.

    Every point in `coords` is projected onto every edge of `poly` (edges
    join consecutive vertices, with the last vertex joined to the first),
    and the projection with the smallest distance is kept.

    Args:
        coords: array-like of points.
        poly: sliceable sequence of polygon vertices.

    Returns:
        Tensor with one projected point per input point.
    """
    edge_starts = keras.ops.array(poly, dtype="float32")
    edge_ends = keras.ops.concatenate(
        [
            keras.ops.array(poly[1:], dtype="float32"),
            keras.ops.array(poly[:1], dtype="float32"),
        ],
        axis=0,
    )
    points = keras.ops.array(coords, dtype="float32")

    # Insert an edge axis on the points and a point axis on the edges so the
    # projection broadcasts over every (point, edge) pair.
    points_per_edge = keras.ops.expand_dims(points, axis=1)
    candidates = project_point_to_segment(
        points_per_edge,
        keras.ops.expand_dims(edge_starts, axis=0),
        keras.ops.expand_dims(edge_ends, axis=0),
        axis=2,
    )

    distances = keras.ops.norm(points_per_edge - candidates, axis=2)

    nearest_edge = keras.ops.argmin(distances, axis=1)
    gather_index = keras.ops.expand_dims(nearest_edge, axis=-1)[..., None]
    nearest = keras.ops.take_along_axis(candidates, gather_index, axis=1)

    return nearest[:, 0, :]
|
||
|
||
# get polygon coordinates distance | ||
def get_coords_poly_distance(coords, poly):
    """Return the distance from each point to the polygon outline.

    Args:
        coords: array-like of points.
        poly: polygon vertices, as accepted by `get_coords_poly_projection`.

    Returns:
        Tensor with one distance per point, measured to the point's nearest
        projection on the polygon's edges.
    """
    nearest_on_poly = get_coords_poly_projection(coords, poly)
    offsets = nearest_on_poly - coords
    return keras.ops.linalg.norm(offsets, axis=1)
|
||
|
||
# get normalized weight | ||
def get_normalized_weight(heatmap, mask, background_weight=3.0):
    """Compute per-pixel weights that balance positives against negatives.

    Pixels with `heatmap >= 0.5` are positives, all others negatives; both
    sets are restricted to `mask`. Negatives receive `background_weight`,
    positives receive a smoothed negative-to-positive count ratio, and
    pixels outside `mask` receive 0.

    Args:
        heatmap: array-like of scores.
        mask: array-like combined with the positive/negative sets via
            `logical_and`.
        background_weight: weight assigned to negative pixels.

    Returns:
        Tensor shaped like `heatmap` holding the weights.
    """
    is_pos = keras.ops.greater_equal(heatmap, 0.5)
    pos_as_float = keras.ops.cast(is_pos, dtype="float32")
    is_neg = keras.ops.ones_like(is_pos, dtype="float32") - pos_as_float

    is_pos = keras.ops.logical_and(is_pos, mask)
    is_neg = keras.ops.logical_and(is_neg, mask)

    pos_count = keras.ops.cast(keras.ops.sum(is_pos), dtype="float32")
    neg_count = keras.ops.cast(keras.ops.sum(is_neg), dtype="float32")

    # Smoothing proportional to the number of valid pixels keeps the ratio
    # finite when there are no positives.
    smooth = (pos_count + neg_count + 1) * 0.05
    pos_weight = (neg_count + smooth) / (pos_count + smooth)

    weight = keras.ops.zeros_like(heatmap)
    weight = keras.ops.where(
        keras.ops.cast(is_neg, "bool"), background_weight, weight
    )
    weight = keras.ops.where(
        keras.ops.cast(is_pos, "bool"), pos_weight, weight
    )
    return weight
|
||
|
||
# Getting region coordinates | ||
def get_region_coordinate(w, h, poly, heights, shrink):
    """Extract pixel coordinates of the shrunk region of each text line.

    Each polygon with a positive height is shrunk by `height * shrink`,
    rasterised, and written into a label map under its own label
    (`line_id + 1`). The pixel coordinates of every label present in the
    map are then collected.

    Args:
        w: label-map width in pixels.
        h: label-map height in pixels.
        poly: iterable of polygons, each an iterable of `(x, y)` rows.
        heights: iterable of line heights, one per polygon; polygons whose
            height is not > 0 are skipped.
        shrink: factor multiplied with the height to get the shrink offset.

    Returns:
        A list of numpy arrays, one per non-empty region, each holding the
        region's `(x, y)` pixel coordinates.
    """
    label_map = keras.ops.zeros((h, w), dtype="float32")
    for line_id, (p, height) in enumerate(zip(poly, heights)):
        if height > 0:
            poly_points = [Point(row[0], row[1]) for row in p]
            shrinked_poly = shrink_polygan(poly_points, height * shrink)
            shrunk_poly_tuples = [point.to_tuple() for point in shrinked_poly]
            shrunk_poly_tensor = keras.ops.convert_to_tensor(
                shrunk_poly_tuples, dtype="float32"
            )
            filled_polygon = fill_poly_keras(shrunk_poly_tensor, (h, w))
            # Give every line its own label; writing the raw 0/1 mask would
            # merge all lines into a single region labelled 1. Later
            # polygons overwrite earlier ones where they overlap.
            label_map = keras.ops.where(
                filled_polygon > 0,
                keras.ops.full_like(label_map, line_id + 1),
                label_map,
            )

    # Find the distinct labels: sort the flattened map and keep the first
    # element of every run of equal values.
    label_map = keras.ops.convert_to_tensor(label_map)
    sorted_tensor = keras.ops.sort(keras.ops.reshape(label_map, (-1,)))
    diff = keras.ops.concatenate(
        [
            keras.ops.convert_to_tensor([True]),
            (sorted_tensor[1:] != sorted_tensor[:-1]),
        ]
    )
    diff = keras.ops.reshape(diff, (-1,))
    indices = keras.ops.convert_to_tensor(keras.ops.where(diff))
    indices = keras.ops.reshape(indices, (-1,))
    unique_labels = keras.ops.take(sorted_tensor, indices)
    # Label 0 is background.
    unique_labels = unique_labels[unique_labels != 0]

    regions_coords = []
    for label in unique_labels:
        mask = keras.ops.equal(label_map, label)
        y, x = keras.ops.nonzero(mask)
        coords = keras.ops.stack([x, y], axis=-1)
        regions_coords.append(keras.ops.convert_to_numpy(coords))

    return regions_coords
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
import keras | ||
import numpy as np | ||
|
||
from keras_hub.src.models.diffbin.db_utils import Point | ||
from keras_hub.src.models.diffbin.db_utils import Polygon | ||
from keras_hub.src.models.diffbin.db_utils import binary_search_smallest_width | ||
from keras_hub.src.models.diffbin.db_utils import fill_poly_keras | ||
from keras_hub.src.models.diffbin.db_utils import get_coords_poly_distance | ||
from keras_hub.src.models.diffbin.db_utils import get_coords_poly_projection | ||
from keras_hub.src.models.diffbin.db_utils import get_line_height | ||
from keras_hub.src.models.diffbin.db_utils import get_mask | ||
from keras_hub.src.models.diffbin.db_utils import get_normalized_weight | ||
from keras_hub.src.models.diffbin.db_utils import get_region_coordinate | ||
from keras_hub.src.models.diffbin.db_utils import project_point_to_line | ||
from keras_hub.src.models.diffbin.db_utils import project_point_to_segment | ||
from keras_hub.src.models.diffbin.db_utils import shrink_polygan | ||
from keras_hub.src.tests.test_case import TestCase | ||
|
||
|
||
class TestDBUtils(TestCase):
    """Unit tests for the geometry and mask helpers in `db_utils`."""

    def test_point_operations(self):
        p1 = Point(1, 2)
        p2 = Point(3, 4)

        p_add = p1 + p2
        assert p_add.x == 4 and p_add.y == 6

        p_sub = p2 - p1
        assert p_sub.x == 2 and p_sub.y == 2

        p_neg = -p1
        assert p_neg.x == -1 and p_neg.y == -2

        cross_product = p1.cross(p2)
        assert cross_product == (1 * 4 - 2 * 3)

        p_tuple = p1.to_tuple()
        assert p_tuple == (1, 2)

    def test_shrink_polygan(self):
        polygon = [Point(0, 0), Point(2, 0), Point(2, 2), Point(0, 2)]
        offset = 0.5
        shrunk_polygon = shrink_polygan(polygon, offset)
        assert len(shrunk_polygon) == 4
        assert all(isinstance(p, Point) for p in shrunk_polygon)

        polygon_np = np.array([[0, 0], [2, 0], [2, 2], [0, 2]])
        shrunk_polygon_np = shrink_polygan(polygon_np, offset)
        assert len(shrunk_polygon_np) == 4
        assert all(isinstance(p, Point) for p in shrunk_polygon_np)

        # Degenerate inputs are returned unchanged.
        empty_polygon = []
        assert shrink_polygan(empty_polygon, offset) == []

        small_polygon = [Point(0, 0), Point(1, 1)]
        assert shrink_polygan(small_polygon, offset) == small_polygon

    def test_polygon_area(self):
        coords = [[0, 0], [2, 0], [2, 2], [0, 2]]
        area = Polygon(coords)
        assert area == 4.0

        coords_np = np.array([[0, 0], [2, 0], [2, 2], [0, 2]])
        area_np = Polygon(coords_np)
        assert area_np == 4.0

        triangle = [[0, 0], [1, 0], [0, 1]]
        area_triangle = Polygon(triangle)
        assert area_triangle == 0.5

    def test_binary_search_smallest_width(self):
        # Example polygon, the exact shrunk amount depends on the algorithm
        # We can only test if it returns a non-negative integer.
        polygon = [[0, 0], [2, 0], [2, 2], [0, 2]]
        width = binary_search_smallest_width(polygon)
        assert isinstance(width, int)
        assert width >= 0

        small_polygon = [[0, 0], [1, 1]]
        assert binary_search_smallest_width(small_polygon) == 0

    def test_project_point_to_line(self):
        x = [1, 1]
        u = [0, 0]
        v = [2, 0]
        projection = project_point_to_line(x, u, v)
        assert np.allclose(projection, [1, 0])

        x_np = np.array([1, 1])
        u_np = np.array([0, 0])
        v_np = np.array([2, 0])
        projection_np = project_point_to_line(x_np, u_np, v_np)
        assert np.allclose(projection_np, [1, 0])

    def test_project_point_to_segment(self):
        x = [1, 1]
        u = [0, 0]
        v = [2, 0]
        projection = project_point_to_segment(x, u, v)
        assert np.allclose(projection, [1, 0])

        # A point beyond the segment end is clamped to the nearer endpoint.
        x_off = [3, 1]
        projection_off = project_point_to_segment(x_off, u, v)
        assert np.allclose(projection_off, [2, 0])

    def test_get_line_height(self):
        polygon = [[0, 0], [2, 0], [2, 2], [0, 2]]
        height = get_line_height(polygon)
        assert isinstance(height, int)
        assert height >= 0

    def test_fill_poly_keras(self):
        vertices = [[0, 0], [2, 0], [2, 2], [0, 2]]
        image_shape = (3, 3)
        mask = fill_poly_keras(vertices, image_shape)
        assert mask.shape == image_shape
        # Every pixel must be 0 or 1; `any` would pass on almost any mask.
        assert keras.ops.all(mask >= 0) and keras.ops.all(mask <= 1)

    def test_get_mask(self):
        w, h = 3, 3
        polys = [[[0, 0], [2, 0], [2, 2], [0, 2]]]
        ignores = [False]
        mask = get_mask(w, h, polys, ignores)
        assert mask.shape == (h, w)
        # Every pixel must be 0 or 1; `any` would pass on almost any mask.
        assert keras.ops.all(mask >= 0) and keras.ops.all(mask <= 1)

    def test_get_coords_poly_projection(self):
        coords = [[1, 1], [3, 3]]
        poly = [[0, 0], [2, 0], [2, 2], [0, 2]]
        projection = get_coords_poly_projection(coords, poly)
        assert projection.shape == (len(coords), 2)

    def test_get_coords_poly_distance(self):
        coords = [[1, 1], [3, 3]]
        poly = [[0, 0], [2, 0], [2, 2], [0, 2]]
        distances = get_coords_poly_distance(coords, poly)
        assert distances.shape == (len(coords),)
        assert keras.ops.all(distances >= 0)

    def test_get_normalized_weight(self):
        heatmap = np.array([[0.1, 0.6], [0.4, 0.8]])
        mask = np.array([[1, 1], [1, 1]])
        weight = get_normalized_weight(heatmap, mask)
        assert weight.shape == heatmap.shape
        assert np.all(weight >= 0)

        mask_partial = np.array([[1, 0], [1, 1]])
        weight_partial = get_normalized_weight(heatmap, mask_partial)
        assert np.all(weight_partial >= 0)

    def test_get_region_coordinate(self):
        w, h = 10, 10
        poly = [[[1, 1], [8, 1], [8, 3], [1, 3]]]
        heights = [1]
        shrink = 0.2
        regions = get_region_coordinate(w, h, poly, heights, shrink)
        assert isinstance(regions, list)
        if regions:
            assert isinstance(regions[0], np.ndarray)

        poly_multiple = [
            [[1, 1], [8, 1], [8, 3], [1, 3]],
            [[2, 5], [7, 5], [7, 7], [2, 7]],
        ]
        heights_multiple = [1, 0.8]
        regions_multiple = get_region_coordinate(
            w, h, poly_multiple, heights_multiple, shrink
        )
        assert isinstance(regions_multiple, list)
        assert len(regions_multiple) <= len(poly_multiple)
        for region in regions_multiple:
            assert isinstance(region, np.ndarray)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
import keras | ||
from keras import layers | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.backbone import Backbone | ||
|
||
|
||
@keras_hub_export("keras_hub.models.DiffBinBackbone")
class DiffBinBackbone(Backbone):
    """Differentiable Binarization architecture for scene text detection.

    This class implements the Differentiable Binarization architecture for
    detecting text in natural images, described in
    [Real-time Scene Text Detection with Differentiable Binarization](
    https://arxiv.org/abs/1911.08947).

    The backbone architecture in this class contains the feature pyramid
    network and model heads.

    Args:
        image_encoder: A `keras_hub.models.ResNetBackbone` instance.
        fpn_channels: int. The number of channels to output by the feature
            pyramid network. Defaults to 256.
        head_kernel_list: list of ints. The kernel sizes of probability map
            and threshold map heads. Defaults to `[3, 2, 2]` when `None`.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The
            dtype to use for the model's computations and weights.
    """

    def __init__(
        self,
        image_encoder,
        fpn_channels=256,
        head_kernel_list=None,
        dtype=None,
        **kwargs,
    ):
        # Resolve the default here instead of in the signature so no list
        # object is shared between instances; copy caller-supplied values so
        # later mutation by the caller cannot change the stored config.
        if head_kernel_list is None:
            head_kernel_list = [3, 2, 2]
        else:
            head_kernel_list = list(head_kernel_list)

        # === Functional Model ===
        inputs = image_encoder.input
        x = image_encoder.pyramid_outputs
        x = diffbin_fpn_model(x, out_channels=fpn_channels, dtype=dtype)

        # NOTE(review): unlike the FPN, the heads are built without `dtype`;
        # confirm whether that is intentional.
        probability_maps = diffbin_head(
            x,
            in_channels=fpn_channels,
            kernel_list=head_kernel_list,
            name="head_prob",
        )
        threshold_maps = diffbin_head(
            x,
            in_channels=fpn_channels,
            kernel_list=head_kernel_list,
            name="head_thresh",
        )

        outputs = {
            "probability_maps": probability_maps,
            "threshold_maps": threshold_maps,
        }

        super().__init__(inputs=inputs, outputs=outputs, dtype=dtype, **kwargs)

        # === Config ===
        self.image_encoder = image_encoder
        self.fpn_channels = fpn_channels
        self.head_kernel_list = head_kernel_list

    def get_config(self):
        """Return the config, with the image encoder serialized inline."""
        config = super().get_config()
        config["fpn_channels"] = self.fpn_channels
        config["head_kernel_list"] = self.head_kernel_list
        config["image_encoder"] = keras.layers.serialize(self.image_encoder)
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the backbone from a config produced by `get_config`."""
        # Work on a shallow copy so the caller's dict is left untouched.
        config = dict(config)
        image_encoder = config.get("image_encoder")
        # Only deserialize when given a serialized config; an already-built
        # encoder instance is passed through unchanged.
        if isinstance(image_encoder, dict):
            config["image_encoder"] = keras.layers.deserialize(image_encoder)
        return cls(**config)
|
||
|
||
def diffbin_fpn_model(inputs, out_channels, dtype=None):
    """Build the feature pyramid network on top of the encoder outputs.

    Args:
        inputs: mapping with the encoder's pyramid outputs under the keys
            `"P2"`, `"P3"`, `"P4"` and `"P5"`.
        out_channels: number of channels of the lateral layers; each merged
            feature map uses `out_channels // 4` channels.
        dtype: dtype passed to the created layers.

    Returns:
        The concatenation (last axis) of the P5, P4, P3 and P2 feature maps,
        with P5/P4/P3 upsampled by 8/4/2.
    """
    levels = ("P2", "P3", "P4", "P5")

    # lateral layers: pointwise convolutions of the encoder's pyramid outputs
    lateral = {}
    for level in levels:
        lateral[level] = layers.Conv2D(
            out_channels,
            kernel_size=1,
            use_bias=False,
            name=f"neck_lateral_{level.lower()}",
            dtype=dtype,
        )(inputs[level])

    # top-down fusion pathway: upsample the coarser level and add the
    # lateral skip connection of the next finer level
    topdown = {"P5": lateral["P5"]}
    for coarse, fine in (("P5", "P4"), ("P4", "P3"), ("P3", "P2")):
        merge = layers.Add(name=f"neck_topdown_{fine.lower()}")
        upsampled = layers.UpSampling2D(dtype=dtype)(topdown[coarse])
        topdown[fine] = merge([upsampled, lateral[fine]])

    # merged feature maps for each pyramid level, coarsest first
    features = {}
    for level in reversed(levels):
        features[level] = layers.Conv2D(
            out_channels // 4,
            kernel_size=3,
            padding="same",
            use_bias=False,
            name=f"neck_featuremap_{level.lower()}",
            dtype=dtype,
        )(topdown[level])

    # bring every level to the resolution of P2 before concatenating
    for level, factor in (("P5", 8), ("P4", 4), ("P3", 2)):
        features[level] = layers.UpSampling2D((factor, factor), dtype=dtype)(
            features[level]
        )

    return layers.Concatenate(axis=-1, dtype=dtype)(
        [features["P5"], features["P4"], features["P3"], features["P2"]]
    )
|
||
|
||
def diffbin_head(inputs, in_channels, kernel_list, name):
    """Build one prediction head (conv -> deconv -> deconv with sigmoid).

    Args:
        inputs: feature map tensor fed to the head.
        in_channels: channel count of the features; the hidden layers use
            `in_channels // 4` channels.
        kernel_list: three kernel sizes, one per (transposed) convolution.
        name: prefix for all layer names of this head.

    Returns:
        A single-channel tensor with sigmoid activation; the two stride-2
        transposed convolutions upsample the input by a factor of 4.
    """
    hidden_channels = in_channels // 4
    bias_bound = 1.0 / (hidden_channels * 1.0) ** 0.5

    def _bias_initializer():
        # A fresh initializer instance per layer, as in the unrolled form.
        return keras.initializers.RandomUniform(
            minval=-bias_bound,
            maxval=bias_bound,
        )

    def _batch_norm(stage):
        return layers.BatchNormalization(
            beta_initializer=keras.initializers.Constant(1e-4),
            gamma_initializer=keras.initializers.Constant(1.0),
            name=f"{name}_{stage}_bn",
        )

    x = layers.Conv2D(
        hidden_channels,
        kernel_size=kernel_list[0],
        padding="same",
        use_bias=False,
        name=f"{name}_conv0_weights",
    )(inputs)
    x = _batch_norm("conv0")(x)
    x = layers.ReLU(name=f"{name}_conv0_relu")(x)

    x = layers.Conv2DTranspose(
        hidden_channels,
        kernel_size=kernel_list[1],
        strides=2,
        padding="valid",
        bias_initializer=_bias_initializer(),
        name=f"{name}_conv1_weights",
    )(x)
    x = _batch_norm("conv1")(x)
    x = layers.ReLU(name=f"{name}_conv1_relu")(x)

    return layers.Conv2DTranspose(
        1,
        kernel_size=kernel_list[2],
        strides=2,
        padding="valid",
        activation="sigmoid",
        bias_initializer=_bias_initializer(),
        name=f"{name}_conv2_weights",
    )(x)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter | ||
from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone | ||
|
||
|
||
@keras_hub_export("keras_hub.layers.DiffBinImageConverter")
class DiffBinImageConverter(ImageConverter):
    """Image converter associated with `DiffBinBackbone`.

    Adds no behavior of its own; it only binds `backbone_cls`.
    """

    backbone_cls = DiffBinBackbone
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone | ||
from keras_hub.src.models.diffbin.diffbin_image_converter import ( | ||
DiffBinImageConverter, | ||
) | ||
from keras_hub.src.models.image_segmenter_preprocessor import ( | ||
ImageSegmenterPreprocessor, | ||
) | ||
|
||
|
||
@keras_hub_export("keras_hub.models.DiffBinPreprocessor")
class DiffBinPreprocessor(ImageSegmenterPreprocessor):
    """Preprocessor associated with `DiffBinBackbone`.

    Adds no behavior of its own; it only binds `backbone_cls` and
    `image_converter_cls`.
    """

    backbone_cls = DiffBinBackbone
    image_converter_cls = DiffBinImageConverter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
"""Differentiable Binarization preset configurations.""" | ||
|
||
# Preset registry keyed by preset name.
backbone_presets = {
    "diffbin_r50vd_icdar2015": {
        "metadata": {
            # Adjacent string literals are concatenated verbatim, so the
            # first piece must end with a space.
            "description": (
                "Differentiable Binarization using 50-layer "
                "ResNetVD trained on the ICDAR2015 dataset."
            ),
            "params": 25482722,
            "official_name": "DifferentiableBinarization",
            "path": "diffbin",
            "model_card": "https://arxiv.org/abs/1911.08947",
        },
        "kaggle_handle": "",  # TODO
    }
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import keras | ||
from keras import layers | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone | ||
from keras_hub.src.models.diffbin.diffbin_preprocessor import ( | ||
DiffBinPreprocessor, | ||
) | ||
from keras_hub.src.models.diffbin.losses import DiffBinLoss | ||
from keras_hub.src.models.image_text_detector_preprocessor import ( | ||
ImageTextDetectorPreprocessor, | ||
) | ||
|
||
|
||
@keras_hub_export("keras_hub.models.DiffBinImageTextDetector")
class DiffBinImageTextDetector(ImageTextDetectorPreprocessor):
    """Differentiable Binarization scene text detection task.

    `DiffBinImageTextDetector` tasks wrap a `keras_hub.models.DiffBinBackbone`
    and a `keras_hub.models.Preprocessor` to create a model that can be used
    for detecting text in natural images.

    The probability map output generated by `predict()` can be translated into
    polygon representation using `model.postprocess_to_polygons()`.

    Args:
        backbone: A `keras_hub.models.DiffBinBackbone`
            instance.
        preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
            a `keras.Layer` instance, or a callable. If `None` no preprocessing
            will be applied to the inputs.

    Examples:
    ```python
    input_data = np.ones(shape=(8, 224, 224, 3))

    image_encoder = keras_hub.models.ResNetBackbone.from_preset(
        "resnet_vd_50_imagenet"
    )
    backbone = keras_hub.models.DiffBinBackbone(image_encoder)
    detector = keras_hub.models.DiffBinImageTextDetector(
        backbone=backbone
    )

    map_output = detector(input_data)
    ```

    `map_output` now holds a 8x224x224x3 tensor, where the last dimension
    corresponds to the model's probability map, threshold map and binary map
    outputs. Use `postprocess_to_polygons()` to obtain a polygon
    representation:
    ```python
    detector.postprocess_to_polygons(map_output[...,0])
    ```
    """

    # NOTE(review): the base class is a *preprocessor* type although this
    # class is built as a functional model (`inputs=`/`outputs=`), and
    # `postprocess_to_polygons` is not defined here -- confirm the intended
    # base class provides both.
    backbone_cls = DiffBinBackbone
    preprocessor_cls = DiffBinPreprocessor

    def __init__(
        self,
        backbone,
        preprocessor=None,
        **kwargs,
    ):
        # === Functional Model ===
        inputs = backbone.input
        maps = backbone(inputs)
        probability_maps = maps["probability_maps"]
        threshold_maps = maps["threshold_maps"]
        # Soft binarisation of the probability map against the threshold map.
        binary_maps = step_function(probability_maps, threshold_maps)
        stacked_maps = [probability_maps, threshold_maps, binary_maps]
        outputs = layers.Concatenate(axis=-1)(stacked_maps)

        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        # === Config ===
        self.backbone = backbone
        self.preprocessor = preprocessor

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        **kwargs,
    ):
        """Configures the `DiffBinImageTextDetector` task for training.

        `DiffBinImageTextDetector` extends the default compilation signature
        of `keras.Model.compile` with defaults for `optimizer` and `loss`. To
        override these defaults, pass any value to these arguments during
        compilation.

        Args:
            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
                instance. Defaults to `"auto"`, which uses SGD with
                learning rate 0.007, weight decay 0.0001 and momentum 0.9.
                See `keras.Model.compile` and `keras.optimizers` for more
                info on possible `optimizer` values.
            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
                Defaults to `"auto"`, in which case `DiffBinLoss` is used.
                See `keras.Model.compile` and `keras.losses` for more info
                on possible `loss` values.
            **kwargs: See `keras.Model.compile` for a full list of arguments
                supported by the compile method.
        """
        if optimizer == "auto":
            # parameters from https://arxiv.org/abs/1911.08947
            optimizer = keras.optimizers.SGD(
                learning_rate=0.007,
                weight_decay=0.0001,
                momentum=0.9,
            )
        if loss == "auto":
            loss = DiffBinLoss()
        super().compile(optimizer=optimizer, loss=loss, **kwargs)
|
||
|
||
def step_function(x, y, k=50.0):
    """Approximate a binarization step with a steep sigmoid.

    Evaluates `1 / (1 + exp(-k * (x - y)))`, a smooth, differentiable
    stand-in for the hard threshold `x > y`.

    Args:
        x: Tensor of values to binarize. In this file it is called with the
            probability maps.
        y: Tensor of thresholds, broadcastable against `x`. In this file it
            is called with the threshold maps.
        k: float. Amplification factor controlling how sharp the transition
            around `x == y` is. Defaults to 50.0.

    Returns:
        Tensor with the broadcast shape of `x - y` and values in `[0, 1]`.
    """
    # `sigmoid(z)` is mathematically the same as `1 / (1 + exp(-z))`, but the
    # built-in op avoids overflowing `exp` for large negative `z` (easy to
    # hit with k=50 in reduced precision), which can otherwise produce
    # non-finite gradients.
    return keras.ops.sigmoid(k * (x - y))
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
import keras | ||
from keras import ops | ||
|
||
|
||
class DiceLoss:
    """Dice loss restricted to a masked region.

    The loss is `1 - 2 * |pred * true| / (|pred| + |true| + eps)`, where every
    sum only counts elements weighted by `mask` (optionally multiplied by
    `weights`).

    This is a plain callable rather than a `keras.losses.Loss` subclass
    because it needs the extra `mask` argument.

    Args:
        eps: float. A small constant added to the denominator to avoid zero
            division. Defaults to 1e-6.
    """

    def __init__(self, eps=1e-6):
        self.eps = eps

    def __call__(self, y_true, y_pred, mask, weights=None):
        # Optional per-element weights are folded into the mask.
        effective_mask = mask if weights is None else weights * mask

        overlap = ops.sum(y_pred * y_true * effective_mask)
        predicted_total = ops.sum(y_pred * effective_mask)
        target_total = ops.sum(y_true * effective_mask)

        denominator = predicted_total + target_total + self.eps
        return 1 - 2.0 * overlap / denominator
|
||
|
||
class MaskL1Loss:
    """Computes the L1 loss of masked predictions.

    The result is `sum(|y_pred - y_true| * mask) / sum(mask)`, or `0.0` when
    `sum(mask)` is zero.

    This class does not subclass `keras.losses.Loss`, as it expects an
    additional `mask` argument for loss computation.
    """

    def __call__(self, y_true, y_pred, mask):
        """Return the masked mean absolute error as a scalar.

        Args:
            y_true: Tensor of target values.
            y_pred: Tensor of predictions, same shape as `y_true`.
            mask: Tensor weighting each element, same shape as `y_true`.

        Returns:
            Scalar loss; `0.0` if the mask sums to zero.
        """
        mask_sum = ops.sum(mask)
        mask_is_empty = mask_sum == 0.0
        masked_error = ops.sum(ops.absolute(y_pred - y_true) * mask)
        # `where` evaluates both branches, so dividing by a zero `mask_sum`
        # would still produce NaN/inf in the discarded branch and can leak
        # into gradients. Divide by a harmless stand-in instead; the outer
        # `where` still returns 0.0 for the empty-mask case.
        safe_mask_sum = ops.where(mask_is_empty, 1.0, mask_sum)
        return ops.where(mask_is_empty, 0.0, masked_error / safe_mask_sum)
|
||
|
||
class BalanceCrossEntropyLoss:
    """Binary cross entropy with a capped share of negative samples.

    Negatives are balanced against positives through hard negative mining, as
    described in
    [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/abs/1911.08947).
    Only the highest-loss negatives contribute, and their number is bounded
    relative to the number of positives.

    This is a plain callable rather than a `keras.losses.Loss` subclass
    because it needs the extra `mask` argument.

    Args:
        negative_ratio: float. Upper bound on the number of negatives used
            for the loss, expressed relative to the number of positives.
            Defaults to 3.0.
        eps: float. A small constant to avoid zero division. Defaults to 1e-6.
    """

    def __init__(self, negative_ratio=3.0, eps=1e-6):
        self.negative_ratio = negative_ratio
        self.eps = eps

    def __call__(self, y_true, y_pred, mask, return_origin=False):
        # Split the masked-in elements into positives and negatives.
        valid = ops.cast(mask, "bool")
        positive = ops.cast((y_true > 0.5) & valid, "uint8")
        negative = ops.cast((y_true < 0.5) & valid, "uint8")
        positive_count = ops.sum(ops.cast(positive, "int32"))
        available_negatives = ops.sum(ops.cast(negative, "int32"))

        # Never use more than `negative_ratio` negatives per positive.
        negative_cap = ops.cast(
            ops.cast(positive_count, "float32") * self.negative_ratio,
            "int32",
        )
        negative_count = ops.where(
            available_negatives > negative_cap,
            negative_cap,
            available_negatives,
        )

        # Keras losses reduce over `axis`; appending a dummy trailing axis
        # makes that reduction a no-op so we keep one loss value per element.
        elementwise_bce = keras.losses.BinaryCrossentropy(
            from_logits=False,
            label_smoothing=0.0,
            axis=-1,
            reduction=None,
        )
        loss = elementwise_bce(
            y_true=y_true[..., None], y_pred=y_pred[..., None]
        )

        positive_loss = loss * ops.cast(positive, "float32")
        negative_loss = loss * ops.cast(negative, "float32")

        # Hard negative mining: instead of sorting, derive a loss threshold
        # from a quantile and zero out every negative loss at or below it.
        # Express the number of kept negatives as a fraction of all elements
        # of y_pred, which gives the quantile to look up.
        element_count = ops.prod(ops.cast(ops.shape(y_pred), "float32"))
        kept_fraction = ops.cast(negative_count, "float32") / element_count
        threshold = ops.quantile(negative_loss, 1.0 - kept_fraction)
        is_hard_negative = ops.cast(negative_loss > threshold, "float32")
        negative_loss = negative_loss * is_hard_negative

        total_loss = ops.sum(positive_loss) + ops.sum(negative_loss)
        sample_count = ops.cast(positive_count + negative_count, "float32")
        balance_loss = total_loss / (sample_count + self.eps)

        # Optionally also hand back the unreduced per-element loss.
        return (balance_loss, loss) if return_origin else balance_loss
|
||
|
||
class DiffBinLoss(keras.losses.Loss):
    """Computes the loss for the Differentiable Binarization model.

    The total loss is
    `dice(binary map) + l1_scale * L1(threshold map) + bce_scale * BCE(probability map)`.

    Args:
        eps: float. A small constant to avoid zero division, passed to the
            Dice loss term. Defaults to 1e-6.
        l1_scale: float. The scaling factor for the threshold map output's L1
            loss contribution to the total loss. Defaults to 10.0.
        bce_scale: float. The scaling factor for the probability map's balance
            cross entropy loss contribution to the total loss. Defaults to 5.0.
        **kwargs: Keyword arguments forwarded to `keras.losses.Loss`.
    """

    def __init__(self, eps=1e-6, l1_scale=10.0, bce_scale=5.0, **kwargs):
        # Forward as keyword arguments. Unpacking with a single `*` would
        # pass the dictionary's keys as positional arguments instead.
        super().__init__(**kwargs)
        self.dice_loss = DiceLoss(eps=eps)
        self.l1_loss = MaskL1Loss()
        self.bce_loss = BalanceCrossEntropyLoss()

        self.l1_scale = l1_scale
        self.bce_scale = bce_scale

    def call(self, y_true, y_pred):
        """Combine the three loss terms.

        Args:
            y_true: Tensor whose last axis stacks four maps, in order: shrink
                map, shrink mask, threshold map, threshold mask.
            y_pred: Tensor whose last axis stacks three maps, in order:
                probability map, threshold map, binary map.

        Returns:
            The weighted sum of the Dice, L1 and balanced cross entropy
            losses.
        """
        p_map_pred, t_map_pred, b_map_pred = ops.unstack(y_pred, 3, axis=-1)
        shrink_map, shrink_mask, thresh_map, thresh_mask = ops.unstack(
            y_true, 4, axis=-1
        )

        # we here implement L1BalanceCELoss from PyTorch's
        # Differentiable Binarization implementation
        Ls = self.bce_loss(
            y_true=shrink_map,
            y_pred=p_map_pred,
            mask=shrink_mask,
            return_origin=False,
        )
        Lt = self.l1_loss(
            y_true=thresh_map,
            y_pred=t_map_pred,
            mask=thresh_mask,
        )
        dice_loss = self.dice_loss(
            y_true=shrink_map,
            y_pred=b_map_pred,
            mask=shrink_mask,
        )
        loss = dice_loss + self.l1_scale * Lt + Ls * self.bce_scale
        return loss
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import cv2 | ||
import keras | ||
import numpy as np | ||
|
||
from keras_hub.src.models.preprocessor import Preprocessor | ||
from keras_hub.src.utils.tensor_utils import preprocessing_function | ||
|
||
|
||
class ImageTextDetectorPreprocessor(Preprocessor):
    """Base class for image text detector preprocessing layers.

    Without labels, the layer simply applies `image_converter` to the images.
    With labels, the polygon labels are rasterized into a binary text-region
    mask that is returned as the target.

    Args:
        image_converter: Callable applied to the images, and to a dictionary
            with `"images"` and `"bounding_boxes"` keys when labels are
            given. Defaults to `None`.
        target_size: Tuple of two ints. Stored on the layer; not used by the
            code in this class. Defaults to `(640, 640)`.
        shrink_ratio: float. Scale factor applied to the mask's height and
            width when the mask is resized. Defaults to 0.3.
        **kwargs: Keyword arguments forwarded to `Preprocessor`.
    """

    def __init__(
        self,
        image_converter=None,
        target_size=(640, 640),
        shrink_ratio=0.3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_converter = image_converter
        self.target_size = target_size
        self.shrink_ratio = shrink_ratio

    @preprocessing_function
    def call(self, x, y=None, sample_weight=None):
        if y is None:
            return self.image_converter(x)

        # Pass bounding boxes through image converter in the dictionary
        # with keys format standardized by core Keras.
        output = self.image_converter(
            {
                "images": x,
                "bounding_boxes": y,
            }
        )
        x = output["images"]
        y = output["bounding_boxes"]

        # NOTE(review): the images were already converted above; this second
        # pass is kept from the original code. Confirm it is intended.
        x = self.image_converter(x)

        # Text region: convert polygon/bounding box labels to a binary mask.
        # Pixels within a text region are 1, all others 0.
        # Initialize an empty mask with zeros.
        mask = np.zeros(x.shape[:2], dtype=np.uint8)
        for poly in y:
            poly = np.array(poly, dtype=np.int32)
            cv2.fillPoly(mask, [poly], 1)

        # Check whether any text region touches the image border.
        top_edge = np.any(mask[0, :])
        bottom_edge = np.any(mask[-1, :])
        left_edge = np.any(mask[:, 0])
        right_edge = np.any(mask[:, -1])

        if not (top_edge or bottom_edge or left_edge or right_edge):
            # Shrink the mask by a ratio. The rasterized `mask` is what must
            # be resized here, not the raw polygon labels in `y`.
            # `cv2.resize` takes the target size as (width, height).
            y = cv2.resize(
                mask,
                (
                    int(mask.shape[1] * self.shrink_ratio),
                    int(mask.shape[0] * self.shrink_ratio),
                ),
            )
        else:
            y = mask

        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need any bounding boxes for this task.