Skip to content

Differential Binarization model #2095

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
312 changes: 312 additions & 0 deletions keras_hub/src/models/diffbin/db_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,312 @@
import keras


class Point:
    """A minimal 2D point / vector with a few arithmetic helpers.

    Args:
        x: Horizontal coordinate.
        y: Vertical coordinate.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        # Readable representation for debugging and log output.
        return f"{type(self).__name__}({self.x!r}, {self.y!r})"

    def __add__(self, other):
        """Return the component-wise sum of two points."""
        return Point(self.x + other.x, self.y + other.y)

    def __sub__(self, other):
        """Return the component-wise difference of two points."""
        return Point(self.x - other.x, self.y - other.y)

    def __neg__(self):
        """Return the point mirrored through the origin."""
        return Point(-self.x, -self.y)

    def cross(self, other):
        """Return the z-component of the 2D cross product with `other`."""
        return self.x * other.y - self.y * other.x

    def to_tuple(self):
        """Return the coordinates as an `(x, y)` tuple."""
        return (self.x, self.y)


def shrink_polygan(polygon, offset):
    """Pull every vertex of a polygon toward the vertex centroid.

    Each vertex is moved along the line to the centroid (the mean of the
    vertices) by `offset`; a vertex closer than `offset` to the centroid is
    placed on the centroid itself.

    Args:
        polygon: Sequence of `Point` objects or of indexable `(x, y)` pairs.
        offset: Distance each vertex is moved toward the centroid.

    Returns:
        A list of `Point` objects, or `polygon` itself (unchanged and
        unconverted) when it has fewer than three vertices.
    """
    count = len(polygon)
    if count < 3:
        return polygon

    points = polygon
    if not isinstance(points[0], Point):
        points = [Point(pt[0], pt[1]) for pt in points]

    center_x = sum(pt.x for pt in points) / count
    center_y = sum(pt.y for pt in points) / count

    def _pull_in(pt):
        off_x = pt.x - center_x
        off_y = pt.y - center_y
        # Clamp the distance so a vertex on the centroid cannot divide by 0.
        distance = max((off_x**2 + off_y**2) ** 0.5, 1e-6)
        scale = max(0, 1 - offset / distance)
        return Point(center_x + off_x * scale, center_y + off_y * scale)

    return [_pull_in(pt) for pt in points]


# Polygon area
def Polygon(coords):
    """Return the area enclosed by a polygon, using the Shoelace formula.

    Args:
        coords: Vertices indexable as `coords[:, 0]` (x) and `coords[:, 1]`
            (y) once converted to a float32 tensor.

    Returns:
        A scalar tensor holding the absolute (unsigned) area.
    """
    vertices = keras.ops.convert_to_tensor(coords, dtype="float32")
    xs, ys = vertices[:, 0], vertices[:, 1]

    # Pair every vertex with its successor, wrapping around at the end.
    next_xs = keras.ops.roll(xs, shift=-1, axis=0)
    next_ys = keras.ops.roll(ys, shift=-1, axis=0)

    twice_signed_area = keras.ops.sum(xs * next_ys - next_xs * ys)
    return 0.5 * keras.ops.abs(twice_signed_area)


# Binary search for the largest usable shrink offset
def binary_search_smallest_width(poly):
    """Estimate how far a polygon can be shrunk before it degenerates.

    Binary-searches the largest offset for which `shrink_polygan(poly,
    offset)` still encloses an area above 0.1.

    The upper bound of the search is the distance from the vertex centroid
    to the farthest vertex: at that offset `shrink_polygan` has collapsed
    every vertex onto the centroid. (A fixed upper bound of 1 would make the
    truncated result always 0.)

    Args:
        poly: Sequence of `Point` objects or indexable `(x, y)` pairs.

    Returns:
        The offset truncated to an `int`, or 0 when the polygon has fewer
        than three vertices or the offset is below 0.1.
    """
    if len(poly) < 3:
        return 0

    points = [p if isinstance(p, Point) else Point(p[0], p[1]) for p in poly]
    cx = sum(p.x for p in points) / len(points)
    cy = sum(p.y for p in points) / len(points)

    low = 0.0
    high = float(
        max(((p.x - cx) ** 2 + (p.y - cy) ** 2) ** 0.5 for p in points)
    )

    while high - low > 0.01:
        mid = (high + low) / 2
        mid_poly = shrink_polygan(points, mid)
        area = Polygon([[p.x, p.y] for p in mid_poly])

        # NOTE(review): assumes the area does not grow as the offset grows,
        # which holds for the radial shrink of simple convex shapes.
        if area > 0.1:
            low = mid
        else:
            high = mid

    height = (low + high) / 2
    return int(height) if height >= 0.1 else 0


# Project a point onto an infinite line
def project_point_to_line(x, u, v, axis=0):
    """Project point(s) `x` onto the line passing through `u` and `v`.

    Args:
        x: Point(s) to project.
        u: First point(s) defining the line.
        v: Second point(s) defining the line.
        axis: Axis that holds the coordinate components.

    Returns:
        A float32 tensor with the projection of `x` onto the line.
    """
    point = keras.ops.convert_to_tensor(x, dtype="float32")
    start = keras.ops.convert_to_tensor(u, dtype="float32")
    end = keras.ops.convert_to_tensor(v, dtype="float32")

    direction = end - start
    length = keras.ops.norm(direction, axis=axis, keepdims=True)
    # The epsilon keeps a degenerate line (u == v) from dividing by zero.
    unit = direction / (length + keras.backend.epsilon())

    along = keras.ops.sum((point - start) * unit, axis=axis, keepdims=True)
    return start + unit * along


# Project a point onto a line segment
def project_point_to_segment(x, u, v, axis=0):
    """Project point(s) `x` onto the segment with endpoints `u` and `v`.

    The point is first projected onto the infinite line through `u` and
    `v`. If that projection does not lie strictly between the endpoints,
    the nearer endpoint is returned instead.

    Args:
        x: Point(s) to project.
        u: Segment start point(s).
        v: Segment end point(s).
        axis: Axis that holds the coordinate components.

    Returns:
        The closest point(s) on the segment.
    """
    foot = project_point_to_line(x, u, v, axis=axis)
    to_u = u - foot
    to_v = v - foot

    # A non-negative dot product means both endpoints lie on the same side
    # of the projection, i.e. it falls on or beyond an endpoint.
    beyond_ends = keras.ops.greater_equal(
        keras.ops.sum(to_u * to_v, axis=axis, keepdims=True), 0
    )
    u_is_closer = keras.ops.less_equal(
        keras.ops.norm(to_u, axis=axis, keepdims=True),
        keras.ops.norm(to_v, axis=axis, keepdims=True),
    )
    nearest_endpoint = keras.ops.where(u_is_closer, u, v)
    return keras.ops.where(beyond_ends, nearest_endpoint, foot)


# Get the height of a text line polygon
def get_line_height(poly):
    """Return the line height of `poly`.

    Thin wrapper: delegates to `binary_search_smallest_width` and returns
    its result unchanged.
    """
    return binary_search_smallest_width(poly)


# Polygon rasterization with keras.ops (stand-in for cv2.fillPoly)
def fill_poly_keras(vertices, image_shape):
    """Rasterize a polygon into a 0/1 mask using even-odd ray casting.

    For every pixel a ray is cast in the +x direction and the number of
    polygon edges it crosses is counted modulo two. Each edge covers the
    half-open row range `[min(y1, y2), max(y1, y2))`, so a scanline passing
    exactly through a vertex is counted once instead of being dropped.

    Note that this is not pixel-identical to `cv2.fillPoly`: pixels on the
    right/bottom boundary are not included.

    Args:
        vertices: Polygon vertices as `(x, y)` pairs.
        image_shape: Tuple `(height, width)` of the output mask.

    Returns:
        An int32 tensor of shape `(height, width)` with 1 inside the
        polygon and 0 elsewhere.
    """
    height, width = image_shape
    xx, yy = keras.ops.meshgrid(
        keras.ops.arange(width), keras.ops.arange(height)
    )
    xx = keras.ops.cast(xx, "float32")
    yy = keras.ops.cast(yy, "float32")

    result = keras.ops.zeros((height, width), dtype="float32")

    vertices = keras.ops.convert_to_tensor(vertices, dtype="float32")
    num_vertices = vertices.shape[0]

    for i in range(num_vertices):
        start = vertices[i]
        end = vertices[(i + 1) % num_vertices]
        x1, y1 = start[0], start[1]
        x2, y2 = end[0], end[1]

        # Half-open crossing test: true for rows where exactly one endpoint
        # lies strictly below the scanline.
        crosses = keras.ops.logical_xor(
            keras.ops.greater(y1, yy), keras.ops.greater(y2, yy)
        )

        # Horizontal edges never satisfy `crosses`; substitute 1 for their
        # zero height so the division below cannot produce inf/NaN.
        dy = y2 - y1
        safe_dy = keras.ops.where(
            keras.ops.equal(dy, 0.0), keras.ops.ones_like(dy), dy
        )
        x_at_row = x1 + (yy - y1) * (x2 - x1) / safe_dy
        left_of_edge = keras.ops.less(xx, x_at_row)

        # Toggle the inside/outside state of pixels whose ray hits the edge.
        result = keras.ops.where(
            keras.ops.logical_and(crosses, left_of_edge),
            1.0 - result,
            result,
        )

    return keras.ops.cast(result, "int32")


# Build the training mask
def get_mask(w, h, polys, ignores):
    """Build an `(h, w)` float32 mask that zeroes out ignored polygons.

    The mask starts as all ones. Polygons flagged in `ignores` are set to 0;
    polygons that are not ignored are set (back) to 1, so the result depends
    on the order of `polys`. Polygons with fewer than three vertices are
    skipped with a printed message.

    Args:
        w: Mask width.
        h: Mask height.
        polys: Iterable of polygons; each is cast to int32 coordinates.
        ignores: Iterable of truthy/falsy flags, paired with `polys`.

    Returns:
        A float32 tensor of shape `(h, w)`.
    """
    mask = keras.ops.ones((h, w), dtype="float32")

    for poly, ignore in zip(polys, ignores):
        poly = keras.ops.cast(keras.ops.convert_to_numpy(poly), dtype="int32")

        if poly.shape[0] < 3:
            print("Skipping invalid polygon:", poly)
            continue

        # Cast once so both branches combine float32 with float32.
        poly_mask = keras.ops.cast(fill_poly_keras(poly, (h, w)), "float32")

        if ignore:
            mask = keras.ops.where(
                poly_mask == 1.0, keras.ops.zeros_like(mask), mask
            )
        else:
            mask = keras.ops.maximum(mask, poly_mask)
    return mask


# Project coordinates onto the nearest polygon edge
def get_coords_poly_projection(coords, poly):
    """Project each point onto the polygon edge closest to it.

    Every point is projected onto every edge (vertex `i` to vertex `i + 1`,
    wrapping around), and the projection with the smallest distance is
    kept.

    Args:
        coords: Points to project; assumed to be `(N, 2)` `(x, y)` pairs.
        poly: Polygon vertices; assumed to be `(M, 2)` `(x, y)` pairs.

    Returns:
        A float32 tensor with one projected point per input point.
    """
    edge_starts = keras.ops.array(poly, dtype="float32")
    # Each edge ends at the following vertex; the last edge closes the loop.
    edge_ends = keras.ops.concatenate(
        [edge_starts[1:], edge_starts[:1]], axis=0
    )
    points = keras.ops.expand_dims(
        keras.ops.array(coords, dtype="float32"), axis=1
    )

    # Broadcast points against edges to get one candidate per pair.
    candidates = project_point_to_segment(
        points,
        keras.ops.expand_dims(edge_starts, axis=0),
        keras.ops.expand_dims(edge_ends, axis=0),
        axis=2,
    )
    distances = keras.ops.norm(points - candidates, axis=2)

    nearest_edge = keras.ops.argmin(distances, axis=1)
    gather_index = keras.ops.expand_dims(nearest_edge, axis=-1)[..., None]
    picked = keras.ops.take_along_axis(candidates, gather_index, axis=1)
    return picked[:, 0, :]


# Distance from coordinates to the polygon outline
def get_coords_poly_distance(coords, poly):
    """Return the distance from each point to the nearest polygon edge.

    Args:
        coords: Points as `(x, y)` pairs.
        poly: Polygon vertices as `(x, y)` pairs.

    Returns:
        A tensor with one Euclidean distance per input point.
    """
    nearest = get_coords_poly_projection(coords, poly)
    offsets = nearest - coords
    return keras.ops.linalg.norm(offsets, axis=1)


# Per-pixel loss weights
def get_normalized_weight(heatmap, mask, background_weight=3.0):
    """Compute per-pixel weights balancing positive and negative pixels.

    Pixels with `heatmap >= 0.5` are positives, all others negatives; both
    sets are restricted to `mask`. Negatives receive `background_weight`,
    positives receive a smoothed negative/positive count ratio, and pixels
    outside the mask receive 0.

    Args:
        heatmap: Array of scores compared against 0.5.
        mask: Array marking the pixels that take part in the weighting.
        background_weight: Weight assigned to masked negative pixels.

    Returns:
        A tensor of weights shaped like `heatmap`.
    """
    above = keras.ops.greater_equal(heatmap, 0.5)
    below = keras.ops.ones_like(above, dtype="float32") - keras.ops.cast(
        above, dtype="float32"
    )
    positives = keras.ops.logical_and(above, mask)
    negatives = keras.ops.logical_and(below, mask)

    num_pos = keras.ops.cast(keras.ops.sum(positives), dtype="float32")
    num_neg = keras.ops.cast(keras.ops.sum(negatives), dtype="float32")

    # Smoothing keeps the ratio finite when one of the counts is zero.
    smooth = (num_pos + num_neg + 1) * 0.05
    positive_weight = (num_neg + smooth) / (num_pos + smooth)

    weight = keras.ops.where(
        keras.ops.cast(negatives, "bool"),
        background_weight,
        keras.ops.zeros_like(heatmap),
    )
    return keras.ops.where(
        keras.ops.cast(positives, "bool"), positive_weight, weight
    )


# Region coordinates per text line
def get_region_coordinate(w, h, poly, heights, shrink):
    """Return the pixel coordinates of every shrunk text-line region.

    Each polygon with a positive height is shrunk by `height * shrink`,
    rasterized, and written into a label map using the label
    `line_id + 1`, so that separate text lines stay separate regions. Where
    regions overlap, the later polygon wins.

    Args:
        w: Width of the label map.
        h: Height of the label map.
        poly: Iterable of polygons, each an iterable of `(x, y)` rows.
        heights: Iterable of line heights, paired with `poly`.
        shrink: Factor applied to each height to obtain the shrink offset.

    Returns:
        A list of NumPy arrays of `(x, y)` pixel coordinates, one per text
        line that covers at least one pixel, in input order.
    """
    label_map = keras.ops.zeros((h, w), dtype="int32")
    labels = []

    for line_id, (p, height) in enumerate(zip(poly, heights)):
        if height <= 0:
            continue
        label = line_id + 1
        shrunk = shrink_polygan(
            [Point(row[0], row[1]) for row in p], height * shrink
        )
        shrunk_tensor = keras.ops.convert_to_tensor(
            [point.to_tuple() for point in shrunk], dtype="float32"
        )
        filled = fill_poly_keras(shrunk_tensor, (h, w))
        # Write this line's own label instead of a shared 1, otherwise all
        # lines would collapse into a single region.
        label_map = keras.ops.where(
            keras.ops.equal(filled, 1), label, label_map
        )
        labels.append(label)

    regions_coords = []
    for label in labels:
        ys, xs = keras.ops.nonzero(keras.ops.equal(label_map, label))
        coords = keras.ops.convert_to_numpy(
            keras.ops.stack([xs, ys], axis=-1)
        )
        # Skip lines that vanished after shrinking or were fully overdrawn.
        if coords.shape[0] > 0:
            regions_coords.append(coords)

    return regions_coords
172 changes: 172 additions & 0 deletions keras_hub/src/models/diffbin/db_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import keras
import numpy as np

from keras_hub.src.models.diffbin.db_utils import Point
from keras_hub.src.models.diffbin.db_utils import Polygon
from keras_hub.src.models.diffbin.db_utils import binary_search_smallest_width
from keras_hub.src.models.diffbin.db_utils import fill_poly_keras
from keras_hub.src.models.diffbin.db_utils import get_coords_poly_distance
from keras_hub.src.models.diffbin.db_utils import get_coords_poly_projection
from keras_hub.src.models.diffbin.db_utils import get_line_height
from keras_hub.src.models.diffbin.db_utils import get_mask
from keras_hub.src.models.diffbin.db_utils import get_normalized_weight
from keras_hub.src.models.diffbin.db_utils import get_region_coordinate
from keras_hub.src.models.diffbin.db_utils import project_point_to_line
from keras_hub.src.models.diffbin.db_utils import project_point_to_segment
from keras_hub.src.models.diffbin.db_utils import shrink_polygan
from keras_hub.src.tests.test_case import TestCase


class TestDBUtils(TestCase):
    """Unit tests for the geometry helpers in `db_utils`."""

    def test_point_operations(self):
        p1 = Point(1, 2)
        p2 = Point(3, 4)

        p_add = p1 + p2
        assert p_add.x == 4 and p_add.y == 6

        p_sub = p2 - p1
        assert p_sub.x == 2 and p_sub.y == 2

        p_neg = -p1
        assert p_neg.x == -1 and p_neg.y == -2

        cross_product = p1.cross(p2)
        assert cross_product == (1 * 4 - 2 * 3)

        p_tuple = p1.to_tuple()
        assert p_tuple == (1, 2)

    def test_shrink_polygan(self):
        polygon = [Point(0, 0), Point(2, 0), Point(2, 2), Point(0, 2)]
        offset = 0.5
        shrunk_polygon = shrink_polygan(polygon, offset)
        assert len(shrunk_polygon) == 4
        assert all(isinstance(p, Point) for p in shrunk_polygon)

        polygon_np = np.array([[0, 0], [2, 0], [2, 2], [0, 2]])
        shrunk_polygon_np = shrink_polygan(polygon_np, offset)
        assert len(shrunk_polygon_np) == 4
        assert all(isinstance(p, Point) for p in shrunk_polygon_np)

        # Degenerate inputs are returned unchanged.
        empty_polygon = []
        assert shrink_polygan(empty_polygon, offset) == []

        small_polygon = [Point(0, 0), Point(1, 1)]
        assert shrink_polygan(small_polygon, offset) == small_polygon

    def test_polygon_area(self):
        coords = [[0, 0], [2, 0], [2, 2], [0, 2]]
        area = Polygon(coords)
        assert area == 4.0

        coords_np = np.array([[0, 0], [2, 0], [2, 2], [0, 2]])
        area_np = Polygon(coords_np)
        assert area_np == 4.0

        triangle = [[0, 0], [1, 0], [0, 1]]
        area_triangle = Polygon(triangle)
        assert area_triangle == 0.5

    def test_binary_search_smallest_width(self):
        # Example polygon, the exact shrunk amount depends on the algorithm
        # We can only test if it returns a non-negative integer.
        polygon = [[0, 0], [2, 0], [2, 2], [0, 2]]
        width = binary_search_smallest_width(polygon)
        assert isinstance(width, int)
        assert width >= 0

        small_polygon = [[0, 0], [1, 1]]
        assert binary_search_smallest_width(small_polygon) == 0

    def test_project_point_to_line(self):
        x = [1, 1]
        u = [0, 0]
        v = [2, 0]
        projection = project_point_to_line(x, u, v)
        assert np.allclose(projection, [1, 0])

        x_np = np.array([1, 1])
        u_np = np.array([0, 0])
        v_np = np.array([2, 0])
        projection_np = project_point_to_line(x_np, u_np, v_np)
        assert np.allclose(projection_np, [1, 0])

    def test_project_point_to_segment(self):
        x = [1, 1]
        u = [0, 0]
        v = [2, 0]
        projection = project_point_to_segment(x, u, v)
        assert np.allclose(projection, [1, 0])

        # A point beyond the segment end snaps to the nearest endpoint.
        x_off = [3, 1]
        projection_off = project_point_to_segment(x_off, u, v)
        assert np.allclose(projection_off, [2, 0])

    def test_get_line_height(self):
        polygon = [[0, 0], [2, 0], [2, 2], [0, 2]]
        height = get_line_height(polygon)
        assert isinstance(height, int)
        assert height >= 0

    def test_fill_poly_keras(self):
        vertices = [[0, 0], [2, 0], [2, 2], [0, 2]]
        image_shape = (3, 3)
        mask = fill_poly_keras(vertices, image_shape)
        assert mask.shape == image_shape
        # Every pixel must be 0 or 1 (`any` would accept almost anything).
        assert keras.ops.all(mask >= 0) and keras.ops.all(mask <= 1)

    def test_get_mask(self):
        w, h = 3, 3
        polys = [[[0, 0], [2, 0], [2, 2], [0, 2]]]
        ignores = [False]
        mask = get_mask(w, h, polys, ignores)
        assert mask.shape == (h, w)
        # Every pixel must be 0 or 1 (`any` would accept almost anything).
        assert keras.ops.all(mask >= 0) and keras.ops.all(mask <= 1)

    def test_get_coords_poly_projection(self):
        coords = [[1, 1], [3, 3]]
        poly = [[0, 0], [2, 0], [2, 2], [0, 2]]
        projection = get_coords_poly_projection(coords, poly)
        assert projection.shape == (len(coords), 2)

    def test_get_coords_poly_distance(self):
        coords = [[1, 1], [3, 3]]
        poly = [[0, 0], [2, 0], [2, 2], [0, 2]]
        distances = get_coords_poly_distance(coords, poly)
        assert distances.shape == (len(coords),)
        assert keras.ops.all(distances >= 0)

    def test_get_normalized_weight(self):
        heatmap = np.array([[0.1, 0.6], [0.4, 0.8]])
        mask = np.array([[1, 1], [1, 1]])
        weight = get_normalized_weight(heatmap, mask)
        assert weight.shape == heatmap.shape
        assert np.all(weight >= 0)

        mask_partial = np.array([[1, 0], [1, 1]])
        weight_partial = get_normalized_weight(heatmap, mask_partial)
        assert np.all(weight_partial >= 0)

    def test_get_region_coordinate(self):
        w, h = 10, 10
        poly = [[[1, 1], [8, 1], [8, 3], [1, 3]]]
        heights = [1]
        shrink = 0.2
        regions = get_region_coordinate(w, h, poly, heights, shrink)
        assert isinstance(regions, list)
        if regions:
            assert isinstance(regions[0], np.ndarray)

        poly_multiple = [
            [[1, 1], [8, 1], [8, 3], [1, 3]],
            [[2, 5], [7, 5], [7, 7], [2, 7]],
        ]
        heights_multiple = [1, 0.8]
        regions_multiple = get_region_coordinate(
            w, h, poly_multiple, heights_multiple, shrink
        )
        assert isinstance(regions_multiple, list)
        assert len(regions_multiple) <= len(poly_multiple)
        for region in regions_multiple:
            assert isinstance(region, np.ndarray)
220 changes: 220 additions & 0 deletions keras_hub/src/models/diffbin/diffbin_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import keras
from keras import layers

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone


@keras_hub_export("keras_hub.models.DiffBinBackbone")
class DiffBinBackbone(Backbone):
    """Differentiable Binarization architecture for scene text detection.

    This class implements the Differentiable Binarization architecture for
    detecting text in natural images, described in
    [Real-time Scene Text Detection with Differentiable Binarization](
    https://arxiv.org/abs/1911.08947).

    The backbone architecture in this class contains the feature pyramid
    network and model heads.

    Args:
        image_encoder: A `keras_hub.models.ResNetBackbone` instance. Its
            `pyramid_outputs` must provide the levels `"P2"` to `"P5"`.
        fpn_channels: int. The number of channels to output by the feature
            pyramid network. Defaults to 256.
        head_kernel_list: list of ints. The kernel sizes of probability map
            and threshold map heads. Defaults to `None`, which selects
            `[3, 2, 2]`.
        dtype: `None` or str or `keras.mixed_precision.DTypePolicy`. The
            dtype to use for the model's computations and weights.
    """

    def __init__(
        self,
        image_encoder,
        fpn_channels=256,
        head_kernel_list=None,
        dtype=None,
        **kwargs,
    ):
        # Build a fresh list per instance rather than sharing one mutable
        # default object across all instances and their configs.
        if head_kernel_list is None:
            head_kernel_list = [3, 2, 2]
        else:
            head_kernel_list = list(head_kernel_list)

        # === Functional Model ===
        inputs = image_encoder.input
        x = image_encoder.pyramid_outputs
        x = diffbin_fpn_model(x, out_channels=fpn_channels, dtype=dtype)

        probability_maps = diffbin_head(
            x,
            in_channels=fpn_channels,
            kernel_list=head_kernel_list,
            name="head_prob",
            dtype=dtype,
        )
        threshold_maps = diffbin_head(
            x,
            in_channels=fpn_channels,
            kernel_list=head_kernel_list,
            name="head_thresh",
            dtype=dtype,
        )

        outputs = {
            "probability_maps": probability_maps,
            "threshold_maps": threshold_maps,
        }

        super().__init__(inputs=inputs, outputs=outputs, dtype=dtype, **kwargs)

        # === Config ===
        self.image_encoder = image_encoder
        self.fpn_channels = fpn_channels
        self.head_kernel_list = head_kernel_list

    def get_config(self):
        config = super().get_config()
        config["fpn_channels"] = self.fpn_channels
        config["head_kernel_list"] = self.head_kernel_list
        config["image_encoder"] = keras.layers.serialize(self.image_encoder)
        return config

    @classmethod
    def from_config(cls, config):
        # Work on a copy so the caller's config dict is left untouched.
        config = dict(config)
        image_encoder = config["image_encoder"]
        # Only deserialize when given a serialized dict; an already-built
        # encoder is passed through as is.
        if isinstance(image_encoder, dict):
            config["image_encoder"] = keras.layers.deserialize(image_encoder)
        return cls(**config)


def diffbin_fpn_model(inputs, out_channels, dtype=None):
    """Build the feature pyramid network neck.

    Args:
        inputs: Dict of pyramid feature maps with keys `"P2"` to `"P5"`.
        out_channels: int. Channel count of the lateral convolutions; each
            merged feature map gets `out_channels // 4` channels.
        dtype: dtype or dtype policy forwarded to every layer.

    Returns:
        The concatenation (along the last axis) of the four per-level
        feature maps, with P3, P4 and P5 upsampled by 2, 4 and 8.
    """
    # lateral layers composing the FPN's bottom-up pathway using
    # pointwise convolutions of ResNet's pyramid outputs
    laterals = {}
    for level in ("p2", "p3", "p4", "p5"):
        laterals[level] = layers.Conv2D(
            out_channels,
            kernel_size=1,
            use_bias=False,
            name=f"neck_lateral_{level}",
            dtype=dtype,
        )(inputs[level.upper()])

    # top-down fusion pathway consisting of upsampling layers with
    # skip connections
    topdown = {"p5": laterals["p5"]}
    for coarse, fine in (("p5", "p4"), ("p4", "p3"), ("p3", "p2")):
        merge = layers.Add(name=f"neck_topdown_{fine}", dtype=dtype)
        upsampled = layers.UpSampling2D(dtype=dtype)(topdown[coarse])
        topdown[fine] = merge([upsampled, laterals[fine]])

    # construct merged feature maps for each pyramid level
    merge_order = ("p5", "p4", "p3", "p2")
    featuremaps = {}
    for level in merge_order:
        featuremaps[level] = layers.Conv2D(
            out_channels // 4,
            kernel_size=3,
            padding="same",
            use_bias=False,
            name=f"neck_featuremap_{level}",
            dtype=dtype,
        )(topdown[level])

    # bring the coarser levels up to the resolution of P2
    for level, factor in (("p5", 8), ("p4", 4), ("p3", 2)):
        featuremaps[level] = layers.UpSampling2D(
            (factor, factor), dtype=dtype
        )(featuremaps[level])

    return layers.Concatenate(axis=-1, dtype=dtype)(
        [featuremaps[level] for level in merge_order]
    )


def diffbin_head(inputs, in_channels, kernel_list, name, dtype=None):
    """Build one prediction head (probability or threshold map).

    The head is a convolution followed by two stride-2 transposed
    convolutions; the last one has a single channel and a sigmoid
    activation.

    Args:
        inputs: Feature map produced by the FPN.
        in_channels: int. Channel count of `inputs`; the hidden layers use
            `in_channels // 4` channels.
        kernel_list: Sequence of three ints with the kernel sizes of the
            three (transposed) convolutions.
        name: str. Prefix for all layer names of this head.
        dtype: dtype or dtype policy forwarded to every layer. Defaults to
            `None`.

    Returns:
        The single-channel output map of the head.
    """
    hidden_channels = in_channels // 4
    bias_bound = 1.0 / (hidden_channels * 1.0) ** 0.5

    def _bias_initializer():
        # One initializer instance per layer.
        return keras.initializers.RandomUniform(
            minval=-bias_bound, maxval=bias_bound
        )

    def _batch_norm(layer_name):
        return layers.BatchNormalization(
            beta_initializer=keras.initializers.Constant(1e-4),
            gamma_initializer=keras.initializers.Constant(1.0),
            name=layer_name,
            dtype=dtype,
        )

    x = layers.Conv2D(
        hidden_channels,
        kernel_size=kernel_list[0],
        padding="same",
        use_bias=False,
        name=f"{name}_conv0_weights",
        dtype=dtype,
    )(inputs)
    x = _batch_norm(f"{name}_conv0_bn")(x)
    x = layers.ReLU(name=f"{name}_conv0_relu", dtype=dtype)(x)
    x = layers.Conv2DTranspose(
        hidden_channels,
        kernel_size=kernel_list[1],
        strides=2,
        padding="valid",
        bias_initializer=_bias_initializer(),
        name=f"{name}_conv1_weights",
        dtype=dtype,
    )(x)
    x = _batch_norm(f"{name}_conv1_bn")(x)
    x = layers.ReLU(name=f"{name}_conv1_relu", dtype=dtype)(x)
    x = layers.Conv2DTranspose(
        1,
        kernel_size=kernel_list[2],
        strides=2,
        padding="valid",
        activation="sigmoid",
        bias_initializer=_bias_initializer(),
        name=f"{name}_conv2_weights",
        dtype=dtype,
    )(x)
    return x
8 changes: 8 additions & 0 deletions keras_hub/src/models/diffbin/diffbin_image_converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone


@keras_hub_export("keras_hub.layers.DiffBinImageConverter")
class DiffBinImageConverter(ImageConverter):
    """Image converter registered for `DiffBinBackbone`.

    Adds no behavior of its own; it only sets `backbone_cls`.
    """

    # Backbone class this converter is associated with.
    backbone_cls = DiffBinBackbone
14 changes: 14 additions & 0 deletions keras_hub/src/models/diffbin/diffbin_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone
from keras_hub.src.models.diffbin.diffbin_image_converter import (
DiffBinImageConverter,
)
from keras_hub.src.models.image_segmenter_preprocessor import (
ImageSegmenterPreprocessor,
)


@keras_hub_export("keras_hub.models.DiffBinPreprocessor")
class DiffBinPreprocessor(ImageSegmenterPreprocessor):
    """Preprocessor registered for `DiffBinBackbone`.

    Adds no behavior of its own; it only sets `backbone_cls` and
    `image_converter_cls` on top of `ImageSegmenterPreprocessor`.
    """

    # Backbone class this preprocessor is associated with.
    backbone_cls = DiffBinBackbone
    # Image converter class used together with this preprocessor.
    image_converter_cls = DiffBinImageConverter
17 changes: 17 additions & 0 deletions keras_hub/src/models/diffbin/diffbin_presets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Differentiable Binarization preset configurations."""

# Preset registry, keyed by preset name.
backbone_presets = {
    "diffbin_r50vd_icdar2015": {
        "metadata": {
            # Adjacent string literals are concatenated verbatim, so the
            # separating space has to be part of one of them.
            "description": (
                "Differentiable Binarization using 50-layer "
                "ResNetVD trained on the ICDAR2015 dataset."
            ),
            "params": 25482722,
            "official_name": "DifferentiableBinarization",
            "path": "diffbin",
            "model_card": "https://arxiv.org/abs/1911.08947",
        },
        "kaggle_handle": "",  # TODO
    }
}
124 changes: 124 additions & 0 deletions keras_hub/src/models/diffbin/diffbin_textdetector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import keras
from keras import layers

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.diffbin.diffbin_backbone import DiffBinBackbone
from keras_hub.src.models.diffbin.diffbin_preprocessor import (
DiffBinPreprocessor,
)
from keras_hub.src.models.diffbin.losses import DiffBinLoss
from keras_hub.src.models.image_text_detector_preprocessor import (
ImageTextDetectorPreprocessor,
)


@keras_hub_export("keras_hub.models.DiffBinImageTextDetector")
class DiffBinImageTextDetector(ImageTextDetectorPreprocessor):
    """Scene text detection task built on Differentiable Binarization.

    Wraps a `keras_hub.models.DiffBinBackbone` and an optional preprocessor
    into a single model. The model output concatenates, along the last
    axis, the probability map, the threshold map and the binary map derived
    from the two.

    Args:
        backbone: A `keras_hub.models.DiffBinBackbone` instance.
        preprocessor: `None`, a `keras_hub.models.Preprocessor` instance,
            a `keras.Layer` instance, or a callable. If `None` no
            preprocessing will be applied to the inputs.

    Examples:
    ```python
    input_data = np.ones(shape=(8, 224, 224, 3))

    image_encoder = keras_hub.models.ResNetBackbone.from_preset(
        "resnet_vd_50_imagenet"
    )
    backbone = keras_hub.models.DiffBinBackbone(image_encoder)
    detector = keras_hub.models.DiffBinImageTextDetector(
        backbone=backbone
    )

    map_output = detector(input_data)
    ```

    `map_output` now holds a 8x224x224x3 tensor, where the last dimension
    corresponds to the model's probability map, threshold map and binary
    map outputs. The original documentation refers to a
    `postprocess_to_polygons()` method for turning the probability map into
    polygons; that method is not defined in this class.
    """

    # NOTE(review): the base class is a preprocessing layer, yet `__init__`
    # passes `inputs`/`outputs` and `compile()` is forwarded to `super()`
    # as for a functional model. Confirm the intended task base class.
    backbone_cls = DiffBinBackbone
    preprocessor_cls = DiffBinPreprocessor

    def __init__(
        self,
        backbone,
        preprocessor=None,
        **kwargs,
    ):
        # === Functional Model ===
        inputs = backbone.input
        maps = backbone(inputs)
        probability = maps["probability_maps"]
        threshold = maps["threshold_maps"]
        # Differentiable approximation of `probability > threshold`.
        binary = step_function(probability, threshold)
        merge = layers.Concatenate(axis=-1)
        outputs = merge([probability, threshold, binary])

        super().__init__(inputs=inputs, outputs=outputs, **kwargs)

        # === Config ===
        self.backbone = backbone
        self.preprocessor = preprocessor

    def compile(
        self,
        optimizer="auto",
        loss="auto",
        **kwargs,
    ):
        """Configure the `DiffBinImageTextDetector` task for training.

        Extends the compilation signature of `keras.Model.compile` with
        `"auto"` defaults for `optimizer` and `loss`. Pass any other value
        to override them.

        Args:
            optimizer: `"auto"`, an optimizer name, or a `keras.Optimizer`
                instance. With `"auto"`, SGD with learning rate 0.007,
                weight decay 0.0001 and momentum 0.9 is used.
            loss: `"auto"`, a loss name, or a `keras.losses.Loss` instance.
                With `"auto"`, a `DiffBinLoss` instance is used.
            **kwargs: Forwarded to the parent `compile` method.
        """
        use_default_optimizer = optimizer == "auto"
        use_default_loss = loss == "auto"

        if use_default_optimizer:
            # parameters from https://arxiv.org/abs/1911.08947
            optimizer = keras.optimizers.SGD(
                learning_rate=0.007, weight_decay=0.0001, momentum=0.9
            )
        if use_default_loss:
            loss = DiffBinLoss()

        super().compile(optimizer=optimizer, loss=loss, **kwargs)


def step_function(x, y, k=50.0):
    """Differentiable approximation of the step function `x > y`.

    Computes `sigmoid(k * (x - y))`, which is mathematically equal to
    `1 / (1 + exp(-k * (x - y)))`. The built-in sigmoid is used because the
    explicit exponential can overflow for large `k * (y - x)`, in
    particular in reduced-precision dtypes.

    Args:
        x: Tensor of values (e.g. a probability map).
        y: Tensor of thresholds, broadcastable against `x`.
        k: Steepness of the transition. Defaults to 50.0.

    Returns:
        A tensor with values in `[0, 1]`.
    """
    return keras.ops.sigmoid(k * (x - y))
168 changes: 168 additions & 0 deletions keras_hub/src/models/diffbin/losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import keras
from keras import ops


class DiceLoss:
    """Dice loss over masked predictions.

    Measures the overlap between prediction and ground truth as
    `1 - 2 * intersection / union`, restricted to the pixels selected by
    `mask`.

    This class does not subclass `keras.losses.Loss`, as it expects an
    additional `mask` argument for loss computation.

    Args:
        eps: float. A small constant to avoid zero division. Defaults to
            1e-6.
    """

    def __init__(self, eps=1e-6):
        self.eps = eps

    def __call__(self, y_true, y_pred, mask, weights=None):
        # Optional per-pixel weights are folded into the mask.
        effective_mask = mask if weights is None else weights * mask

        overlap = ops.sum(y_pred * y_true * effective_mask)
        total = (
            ops.sum(y_pred * effective_mask)
            + ops.sum(y_true * effective_mask)
            + self.eps
        )
        return 1 - 2.0 * overlap / total


class MaskL1Loss:
    """Computes the L1 loss of masked predictions.

    The loss is the mask-weighted sum of absolute errors divided by the sum
    of the mask, or 0 when the mask sums to zero.

    This class does not subclass `keras.losses.Loss`, as it expects an
    additional `mask` argument for loss computation.
    """

    def __call__(self, y_true, y_pred, mask):
        mask_sum = ops.sum(mask)
        is_empty = ops.equal(mask_sum, 0.0)

        # Divide by a safe denominator: `where` evaluates both branches, so
        # a 0/0 in the unselected branch could still leak NaN gradients.
        safe_sum = ops.where(is_empty, 1.0, mask_sum)
        masked_error = ops.sum(ops.absolute(y_pred - y_true) * mask)

        return ops.where(is_empty, 0.0, masked_error / safe_sum)


class BalanceCrossEntropyLoss:
    """Binary cross entropy with negatives balanced against positives.

    Uses hard negative mining, as described in
    [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/abs/1911.08947):
    only the negatives with the highest loss contribute, and their number is
    capped relative to the number of positives.

    This class does not subclass `keras.losses.Loss`, as it expects an
    additional `mask` argument for loss computation.

    Args:
        negative_ratio: float. The upper bound for the number of negatives
            we consider for loss computation, relative to the number of
            positives. Defaults to 3.0.
        eps: float. A small constant to avoid zero division. Defaults to
            1e-6.
    """

    def __init__(self, negative_ratio=3.0, eps=1e-6):
        self.negative_ratio = negative_ratio
        self.eps = eps

    def __call__(self, y_true, y_pred, mask, return_origin=False):
        valid = ops.cast(mask, "bool")
        positive = ops.cast((y_true > 0.5) & valid, "uint8")
        negative = ops.cast((y_true < 0.5) & valid, "uint8")

        num_positive = ops.sum(ops.cast(positive, "int32"))
        num_negative_available = ops.sum(ops.cast(negative, "int32"))
        num_negative_cap = ops.cast(
            ops.cast(num_positive, "float32") * self.negative_ratio, "int32"
        )
        # Never use more negatives than the cap allows.
        num_negative = ops.where(
            num_negative_available > num_negative_cap,
            num_negative_cap,
            num_negative_available,
        )

        # Keras' losses reduce some axis. Since we don't want that here, we
        # add a dummy dimension to y_true and y_pred
        bce = keras.losses.BinaryCrossentropy(
            from_logits=False,
            label_smoothing=0.0,
            axis=-1,
            reduction=None,
        )
        loss = bce(y_true=y_true[..., None], y_pred=y_pred[..., None])

        positive_loss = loss * ops.cast(positive, "float32")
        negative_loss = loss * ops.cast(negative, "float32")

        # Hard negative mining: instead of sorting, take the quantile that
        # corresponds to the wanted number of negatives and zero out every
        # negative loss at or below it.
        total_elements = ops.prod(ops.cast(ops.shape(y_pred), "float32"))
        negative_fraction = ops.cast(num_negative, "float32") / total_elements
        threshold = ops.quantile(negative_loss, 1.0 - negative_fraction)
        hard_negative_loss = negative_loss * ops.cast(
            negative_loss > threshold, "float32"
        )

        numerator = ops.sum(positive_loss) + ops.sum(hard_negative_loss)
        denominator = (
            ops.cast(num_positive + num_negative, "float32") + self.eps
        )
        balance_loss = numerator / denominator

        if return_origin:
            return balance_loss, loss
        return balance_loss


class DiffBinLoss(keras.losses.Loss):
    """Computes the loss for the Differentiable Binarization model.

    The total loss is
    `dice(binary map) + l1_scale * L1(threshold map) + bce_scale *
    balanced_BCE(probability map)`.

    `y_pred` is unstacked along its last axis into the probability,
    threshold and binary maps; `y_true` is unstacked along its last axis
    into the shrink map, shrink mask, threshold map and threshold mask.

    Args:
        eps: float. A small constant to avoid zero division in the dice
            loss. Defaults to 1e-6.
        l1_scale: float. The scaling factor for the threshold map output's
            L1 loss contribution to the total loss. Defaults to 10.0.
        bce_scale: float. The scaling factor for the probability map's
            balance cross entropy loss contribution to the total loss.
            Defaults to 5.0.
        **kwargs: Forwarded to `keras.losses.Loss` (e.g. `name`,
            `reduction`).
    """

    def __init__(self, eps=1e-6, l1_scale=10.0, bce_scale=5.0, **kwargs):
        # Must be `**kwargs`: `*kwargs` would pass the dict's keys as
        # positional arguments to the base class.
        super().__init__(**kwargs)
        self.eps = eps
        self.dice_loss = DiceLoss(eps=eps)
        self.l1_loss = MaskL1Loss()
        self.bce_loss = BalanceCrossEntropyLoss()

        self.l1_scale = l1_scale
        self.bce_scale = bce_scale

    def call(self, y_true, y_pred):
        p_map_pred, t_map_pred, b_map_pred = ops.unstack(y_pred, 3, axis=-1)
        shrink_map, shrink_mask, thresh_map, thresh_mask = ops.unstack(
            y_true, 4, axis=-1
        )

        # we here implement L1BalanceCELoss from PyTorch's
        # Differentiable Binarization implementation
        Ls = self.bce_loss(
            y_true=shrink_map,
            y_pred=p_map_pred,
            mask=shrink_mask,
            return_origin=False,
        )
        Lt = self.l1_loss(
            y_true=thresh_map,
            y_pred=t_map_pred,
            mask=thresh_mask,
        )
        dice_loss = self.dice_loss(
            y_true=shrink_map,
            y_pred=b_map_pred,
            mask=shrink_mask,
        )
        loss = dice_loss + self.l1_scale * Lt + Ls * self.bce_scale
        return loss

    def get_config(self):
        # Include the constructor arguments so the loss can be re-created
        # from its config.
        config = super().get_config()
        config.update(
            {
                "eps": self.eps,
                "l1_scale": self.l1_scale,
                "bce_scale": self.bce_scale,
            }
        )
        return config
68 changes: 68 additions & 0 deletions keras_hub/src/models/image_text_detector_preprocessor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import cv2
import keras
import numpy as np

from keras_hub.src.models.preprocessor import Preprocessor
from keras_hub.src.utils.tensor_utils import preprocessing_function


class ImageTextDetectorPreprocessor(Preprocessor):
    """Base class for image text detector preprocessing layers.

    Args:
        image_converter: Optional callable applied to the images (and, when
            labels are given, to a dict with `"images"` and
            `"bounding_boxes"`). If `None`, inputs are passed through.
        target_size: Tuple stored on the layer. Defaults to `(640, 640)`.
            NOTE(review): not used by `call` — confirm intent.
        shrink_ratio: float. Factor by which the label mask is resized when
            no text pixel touches the image border. Defaults to 0.3.
    """

    def __init__(
        self,
        image_converter=None,
        target_size=(640, 640),
        shrink_ratio=0.3,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.image_converter = image_converter
        self.target_size = target_size
        self.shrink_ratio = shrink_ratio

    @preprocessing_function
    def call(self, x, y=None, sample_weight=None):
        """Convert images and turn polygon labels into a binary text mask.

        Args:
            x: Input image(s).
            y: Optional polygon labels. Each entry is converted to an int32
                array and filled into the mask.
            sample_weight: Optional sample weights, passed through.

        Returns:
            The converted images when `y` is `None`; otherwise the result
            of `keras.utils.pack_x_y_sample_weight(x, mask, sample_weight)`.
        """
        if y is None:
            if self.image_converter is None:
                return x
            return self.image_converter(x)

        if self.image_converter is not None:
            # Pass bounding boxes through image converter in the dictionary
            # with keys format standardized by core Keras.
            # NOTE(review): a collaborator commented on this call that
            # bounding boxes should not be needed for this task.
            output = self.image_converter(
                {
                    "images": x,
                    "bounding_boxes": y,
                }
            )
            # The images are converted exactly once; converting the already
            # converted images again would resize/rescale them twice.
            x = output["images"]
            y = output["bounding_boxes"]

        # Text region: convert polygon/bounding box labels to a binary
        # mask. Pixels within a text region are 1, all others 0.
        # NOTE(review): assumes `x` is a single (height, width, channels)
        # image here — confirm for batched inputs.
        mask = np.zeros(x.shape[:2], dtype=np.uint8)
        for poly in y:
            poly = np.array(poly, dtype=np.int32)
            cv2.fillPoly(mask, [poly], 1)

        # check if edge pixels are 1
        touches_border = (
            np.any(mask[0, :])
            or np.any(mask[-1, :])
            or np.any(mask[:, 0])
            or np.any(mask[:, -1])
        )

        if not touches_border:
            # Shrink the mask by a ratio. The mask (not the raw polygon
            # labels) is what gets resized.
            y = cv2.resize(
                mask,
                (
                    int(mask.shape[1] * self.shrink_ratio),
                    int(mask.shape[0] * self.shrink_ratio),
                ),
            )
        else:
            y = mask

        return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)