Differential Binarization model #2095


Open · wants to merge 88 commits into base: master

Conversation

@mehtamansi29 (Collaborator) commented Feb 12, 2025

@sachinprasadhs sachinprasadhs added the WIP Pull requests which are work in progress and not ready yet for review. label Apr 11, 2025
@sachinprasadhs (Collaborator) left a comment


Took a high-level pass and left some comments.
Also, make all the file names follow the same format as the other files, for db_utils and losses.py.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces the Differential Binarization (DiffBin) model, including the model's backbone, text detector task, loss function, preprocessor, and associated utilities and tests. Additionally, there's a widespread cleanup of import statements across the API, removing redundant aliases, which improves code readability.

The implementation of the new DiffBin model is a great start, but there are several areas that need attention to improve correctness, maintainability, and completeness.

Comment on lines 32 to 48
```python
def call(self, y_true, y_pred):
    prob_map_true = y_true[..., 0:1]  # Channel 0
    binary_map_true = y_true[..., 1:2]  # Channel 1
    thresh_map_true = y_true[..., 2:3]  # Channel 2
    dilated_mask = y_true[..., 3:4]  # Channel 3

    prob_map_pred = y_pred[..., 0:1]  # Channel 0 - probability maps
    thresh_map_pred = y_pred[..., 1:2]  # Channel 1 - threshold maps
    binary_map_pred = y_pred[..., 2:3]

    ls = self.hard_negative_mining_bce(prob_map_true, prob_map_pred)
    lb = self.hard_negative_mining_bce(thresh_map_true, thresh_map_pred)
    lt = self.threshold_map_loss(
        binary_map_true, binary_map_pred, dilated_mask
    )
    total_loss = ls + (self.alpha * lb) + (self.beta * lt)
    return total_loss
```


high

The implementation of the loss function is confusing and seems to deviate from the original "Differentiable Binarization" paper, which could lead to incorrect behavior or difficulty in maintenance.

Specifically:

  • Variable Naming: The variable names lb and lt seem to be swapped compared to their purpose. lb is calculated on threshold maps, while lt is calculated on binary maps. The paper uses L_b for binary map loss and L_t for threshold map loss.
  • Hyperparameter Roles: The roles of alpha and beta are also swapped compared to the paper's formula L = L_s + alpha * L_b + beta * L_t. In the code, alpha weights the threshold map loss (lb) and beta weights the binary map loss (lt).
  • Loss Functions: The paper suggests L1 loss for the threshold map and Dice loss for the binary map. This implementation uses BCE for the threshold map and L1 for the binary map.

While deviations from the paper can be valid, the current implementation is hard to follow due to the inconsistencies in naming. I recommend refactoring for clarity.
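To make the recommended naming concrete, here is a minimal numpy sketch of the combination the paper describes, with Dice loss for the binary map and masked L1 for the threshold map. The helper names and default weights below are illustrative, not the PR's actual API:

```python
import numpy as np

def dice_loss(y_true, y_pred, eps=1e-6):
    # Paper's choice for the (approximate) binary map: Dice loss.
    intersection = np.sum(y_true * y_pred)
    return 1.0 - 2.0 * intersection / (np.sum(y_true) + np.sum(y_pred) + eps)

def masked_l1_loss(y_true, y_pred, mask, eps=1e-6):
    # Paper's choice for the threshold map: L1, restricted to the
    # dilated text mask.
    return np.sum(np.abs(y_true - y_pred) * mask) / (np.sum(mask) + eps)

def total_loss(l_s, binary_true, binary_pred, thresh_true, thresh_pred,
               dilated_mask, alpha=1.0, beta=10.0):
    # L = L_s + alpha * L_b + beta * L_t, matching the paper's formula:
    # alpha weights the binary-map loss, beta the threshold-map loss.
    l_b = dice_loss(binary_true, binary_pred)
    l_t = masked_l1_loss(thresh_true, thresh_pred, dilated_mask)
    return l_s + alpha * l_b + beta * l_t
```

With this naming, `l_b` is computed on binary maps and `l_t` on threshold maps, so the variable names line up with the paper's symbols.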

Comment on lines +46 to +52
`map_output` now holds an 8x224x224x3 tensor, where the last dimension
corresponds to the model's probability map, threshold map, and binary map
outputs. Use `postprocess_to_polygons()` to obtain a polygon
representation:
```python
detector.postprocess_to_polygons(map_output[...,0])
```


high

The example in the docstring demonstrates a call to detector.postprocess_to_polygons(map_output[...,0]). However, the postprocess_to_polygons method is not implemented in the DiffBinTextDetector class. This method should be implemented, or the docstring should be updated to reflect the available functionality.
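For illustration only, a crude numpy sketch of what such a method could do: binarize the probability map, find connected regions with a flood fill, and return each region's axis-aligned bounding box as a 4-point polygon. This is a hypothetical stand-in, not the KerasHub API; the paper's actual postprocessing uses contour extraction plus polygon unclipping:

```python
import numpy as np
from collections import deque

def postprocess_to_polygons(prob_map, threshold=0.3):
    # Hypothetical sketch: threshold the probability map, group
    # 4-connected foreground pixels via BFS flood fill, and emit the
    # bounding box of each group as a clockwise 4-point polygon.
    binary = prob_map > threshold
    seen = np.zeros_like(binary, dtype=bool)
    polygons = []
    for y, x in zip(*np.nonzero(binary)):
        if seen[y, x]:
            continue
        queue = deque([(y, x)])
        seen[y, x] = True
        y0 = y1 = y
        x0 = x1 = x
        while queue:
            cy, cx = queue.popleft()
            y0, y1 = min(y0, cy), max(y1, cy)
            x0, x1 = min(x0, cx), max(x1, cx)
            for ny, nx in ((cy - 1, cx), (cy + 1, cx),
                           (cy, cx - 1), (cy, cx + 1)):
                if (0 <= ny < binary.shape[0] and 0 <= nx < binary.shape[1]
                        and binary[ny, nx] and not seen[ny, nx]):
                    seen[ny, nx] = True
                    queue.append((ny, nx))
        polygons.append([(x0, y0), (x1, y0), (x1, y1), (x0, y1)])
    return polygons
```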

Comment on lines 31 to 36
```python
self,
image_encoder,
fpn_channels=256,
head_kernel_list=[3, 2, 2],
image_shape=(640, 640, 3),
dtype=None,
```


medium

Using mutable default arguments like lists is a common pitfall in Python and can lead to unexpected behavior, since the default object is created once and shared across all calls. It's a best practice to default to None and then assign the mutable object inside the function body.

Consider adding the following logic at the beginning of __init__:

```python
if head_kernel_list is None:
    head_kernel_list = [3, 2, 2]
if image_shape is None:
    image_shape = (640, 640, 3)
```
Suggested change
```diff
 self,
 image_encoder,
 fpn_channels=256,
-head_kernel_list=[3, 2, 2],
-image_shape=(640, 640, 3),
+head_kernel_list=None,
+image_shape=None,
 dtype=None,
```
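A two-line demonstration of why the list default is risky: the default list is created once at function definition time and mutated by every call (`append_item` is a made-up name for illustration):

```python
def append_item(item, items=[]):  # one list object, created at def time
    items.append(item)
    return items

first = append_item("a")
second = append_item("b")
# Both calls returned the same shared list, now ["a", "b"].
```

Defaulting to `None` and creating the list inside the body gives each call its own fresh list instead.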


You can keep `image_shape=(None, None, 3)` for consistency.

Comment on lines 120 to 126
```python
# lateral_p5 = layers.Conv2D(
#     out_channels,
#     kernel_size=1,
#     use_bias=False,
#     name="neck_lateral_p5",
#     dtype=dtype,
# )(inputs["P5"])
```


medium

This block of commented-out code should be removed to improve code clarity and maintainability.

```python
)(topdown_p2)
featuremap_p4 = layers.UpSampling2D((4, 4), dtype=dtype)(featuremap_p4)
featuremap_p3 = layers.UpSampling2D((2, 2), dtype=dtype)(featuremap_p3)
featuremap_p2 = layers.UpSampling2D((1, 1), dtype=dtype)(featuremap_p2)
```


medium

This upsampling layer with a size of (1, 1) is a no-op and can be removed to improve code clarity and avoid an unnecessary operation.
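For reference, nearest-neighbor upsampling (UpSampling2D's default) amounts to repeating pixels along the spatial axes, so a (1, 1) factor repeats each pixel once and returns the input unchanged. A quick numpy stand-in illustrates this (`upsample_nearest` is illustrative, not the Keras implementation):

```python
import numpy as np

def upsample_nearest(x, size):
    # Nearest-neighbor upsampling of a (batch, height, width, channels)
    # array: repeat rows by size[0] and columns by size[1].
    h, w = size
    return np.repeat(np.repeat(x, h, axis=1), w, axis=2)

x = np.arange(24.0).reshape(1, 2, 3, 4)
assert upsample_nearest(x, (2, 2)).shape == (1, 4, 6, 4)
assert np.array_equal(upsample_nearest(x, (1, 1)), x)  # (1, 1) is a no-op
```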

Comment on lines 19 to 20
```python
loss = self.loss_fn(y_true, y_pred)
self.assertGreaterEqual(loss.numpy(), 0.0)
```


medium

The assertion self.assertGreaterEqual(loss.numpy(), 0.0) is very weak for this test case. When the predicted values are identical to the true values, the loss should be zero (or very close to it). A more precise assertion would make this test more meaningful and robust.

Suggested change
```diff
 loss = self.loss_fn(y_true, y_pred)
-self.assertGreaterEqual(loss.numpy(), 0.0)
+self.assertAllClose(loss.numpy(), 0.0, atol=1e-6)
```

Comment on lines 55 to 56
```python
if y is None:
    return self.image_converter(x)
```


medium

There is a redundant call to self.image_converter(x). The variable x has already been updated with the result of self.image_converter(x) on line 54. This line can be simplified to return x.

```python
return x
```

@sachinprasadhs (Collaborator)

Rebase the code onto the latest master and resolve the Gemini-suggested comments.

@sachinprasadhs sachinprasadhs moved this to In Progress in KerasHub Jul 16, 2025
@mehtamansi29 mehtamansi29 marked this pull request as ready for review July 17, 2025 04:21
@mehtamansi29 mehtamansi29 changed the title [WIP] Differential Binarization model Differential Binarization model Jul 21, 2025
mehtamansi29 and others added 23 commits July 22, 2025 11:44
The inputs to `generate` are `"prompts"`, not `"text"`.

Fixes keras-team#1685
* routine HF sync

* code reformat
Bumps the python group with 2 updates: torch and torchvision.


Updates `torch` from 2.6.0+cu126 to 2.7.0+cu126

Updates `torchvision` from 0.21.0+cu126 to 0.22.0+cu126

---
updated-dependencies:
- dependency-name: torch
  dependency-version: 2.7.0+cu126
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: python
- dependency-name: torchvision
  dependency-version: 0.22.0+cu126
  dependency-type: direct:production
  update-type: version-update:semver-minor
  dependency-group: python
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
* Modify TransformerEncoder masking documentation

* Added space before parenthesis
* Fix Mistral conversion script

This commit addresses several issues in the Mistral checkpoint conversion script:

- Adds `dropout` to the model initialization to match the Hugging Face model.
- Replaces `requests.get` with `hf_hub_download` for more reliable tokenizer downloads.
- Adds support for both `tokenizer.model` and `tokenizer.json` to handle different Mistral versions.
- Fixes a `TypeError` in the `save_to_preset` function call.

* address format issues

* adopted to latest hub style

* address format issues

---------

Co-authored-by: laxmareddyp <laxmareddyp@laxma-n2-highmem-256gbram.us-central1-f.c.gtech-rmi-dev.internal>
Updates the requirements on [tensorflow-cpu](https://github.com/tensorflow/tensorflow), [tensorflow](https://github.com/tensorflow/tensorflow), [tensorflow-text](https://github.com/tensorflow/text), torch, torchvision and [tensorflow[and-cuda]](https://github.com/tensorflow/tensorflow) to permit the latest version.

Updates `tensorflow-cpu` to 2.19.0
- [Release notes](https://github.com/tensorflow/tensorflow/releases)
- [Changelog](https://github.com/tensorflow/tensorflow/blob/master/RELEASE.md)
- [Commits](tensorflow/tensorflow@v2.18.1...v2.19.0)

Updates `tensorflow` to 2.19.0
- [Release notes](https://github.com/tensorflow/tensorflow/releases)
- [Changelog](https://github.com/tensorflow/tensorflow/blob/master/RELEASE.md)
- [Commits](tensorflow/tensorflow@v2.18.1...v2.19.0)

Updates `tensorflow-text` to 2.19.0
- [Release notes](https://github.com/tensorflow/text/releases)
- [Commits](tensorflow/text@v2.18.0...v2.19.0)

Updates `torch` from 2.7.0+cu126 to 2.7.1+cu126

Updates `torchvision` from 0.22.0+cu126 to 0.22.1+cu126

Updates `tensorflow[and-cuda]` to 2.19.0
- [Release notes](https://github.com/tensorflow/tensorflow/releases)
- [Changelog](https://github.com/tensorflow/tensorflow/blob/master/RELEASE.md)
- [Commits](tensorflow/tensorflow@v2.18.0...v2.19.0)

---
updated-dependencies:
- dependency-name: tensorflow-cpu
  dependency-version: 2.19.0
  dependency-type: direct:production
  dependency-group: python
- dependency-name: tensorflow
  dependency-version: 2.19.0
  dependency-type: direct:production
  dependency-group: python
- dependency-name: tensorflow-text
  dependency-version: 2.19.0
  dependency-type: direct:production
  dependency-group: python
- dependency-name: torch
  dependency-version: 2.7.1+cu126
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: python
- dependency-name: torchvision
  dependency-version: 0.22.1+cu126
  dependency-type: direct:production
  update-type: version-update:semver-patch
  dependency-group: python
- dependency-name: tensorflow[and-cuda]
  dependency-version: 2.19.0
  dependency-type: direct:production
  dependency-group: python
...

Signed-off-by: dependabot[bot] <[email protected]>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
* init

* update

* bug fixes

* add qwen causal lm test

* fix qwen3 tests
* support flash-attn at torch backend

* fix

* fix

* fix

* fix conflit

* fix conflit

* fix conflit

* fix conflit

* fix conflit

* fix conflit

* format
* init: Add initial project structure and files

* bug: Small bug related to weight loading in the conversion script

* finalizing: Add TIMM preprocessing layer

* incorporate reviews: Consolidate stage configurations and improve API consistency

* bug: Unexpected argument error in JAX with Keras 3.5

* small addition for the D-FINE to come: No changes to the existing HGNetV2

* D-FINE JIT compile: Remove non-essential conditional statement

* refactor: Address reviews and fix some nits
* Register qwen3 presets

* fix format
Labels
WIP Pull requests which are work in progress and not ready yet for review.
Projects
Status: In Progress