5 changes: 3 additions & 2 deletions doctr/datasets/vocabs.py
Expand Up @@ -13,7 +13,7 @@
# Arabic & Persian
"arabic_diacritics": "ًٌٍَُِّْ",
"arabic_digits": "٠١٢٣٤٥٦٧٨٩",
"arabic_letters": "ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي",
"arabic_letters": "- ء آ أ ؤ إ ئ ا ٪ ب ت ث ج ح خ د ذ ر ز س ش ص ض ط ظ ع غ ف ق ك ٰیٕ٪ ل م ن ه ة و ي پ چ ڢ ڤ گ ﻻ ﻷ ﻹ ﻵ ﺀ ﺁ ﺃ ﺅ ﺇ ﺉ ﺍ ﺏ ﺕ ﺙ ﺝ ﺡ ﺥ ﺩ ﺫ ﺭ ﺯ ﺱ ﺵ ﺹ ﺽ ﻁ ﻅ ﻉ ﻍ ﻑ ﻕ ﻙ ﻝ ﻡ ﻥ ﻩ ﻩ ﻭ ﻱ ﺑ ﺗ ﺛ ﺟ ﺣ ﺧ ﺳ ﺷ ﺻ ﺿ ﻃ ﻇ ﻋ ﻏ ﻓ ﻗ ﻛ ﻟ ﻣ ﻧ ﻫ ﻳ ﺒ ﺘ ﺜ ﺠ ﺤ ﺨ ﺴ ﺸ ﺼ ﺾ ﻄ ﻈ ﻌ ﻐ ﻔ ﻘ ﻜ ﻠ ﻤ ﻨ ﻬ ﻴ ﺎ ﺐ ﺖ ﺚ ﺞ ﺢ ﺦ ﺪ ﺬ ﺮ ﺰ ﺲ ﺶ ﺺ ﺾ ﻂ ﻆ ﻊ ﻎ ﻒ ﻖ ﻚ ﻞ ﻢ ﻦ ﻪ ﺔ ﺓﺋ ﺓﺋ ى ﻼوفرّٕ ﺊ ﻯ ﻀ ﻯ ﻼ ﺋ ﺊﺓى ﻀال ص ح x ـ ـوx ﻰ ﻮ ﻲ ً ٌ ؟ ؛ « » — ! # $ % & ' ( ) * + , - . / : ; < = > ? @ [ ] ^ _ { | } ~",
Collaborator

I would suggest only extending the existing arabic_letters with any characters that are missing, and adding any additional Arabic-specific punctuation to arabic_punctuation, because the arabic entry already includes Western punctuation :)

Collaborator

Additionally, it should not include whitespace: our models don't work well with whitespace, so please remove it. If we want to make it more readable, we can use:

"arabic_letters": "".join(["د", "غ" ...])

Author

Hi,

Thanks for your feedback!

Just to clarify: in Arabic, letters change shape depending on their position in the word (beginning, middle, or end).
The characters I included cover all these contextual forms, which makes them more suitable for training the model accurately.

Also, the whitespaces between characters are not meant for natural spacing but are used intentionally to differentiate between the different forms of each letter during training.
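
To illustrate the contextual forms mentioned above: Unicode encodes them as separate code points in the Arabic Presentation Forms blocks, and compatibility normalization (NFKC) folds each one back to its base letter. The sketch below uses only the standard `unicodedata` module. Whether to keep or fold these forms in the vocab is a decision for the maintainers and is not settled here.

```python
import unicodedata

# Four positional variants of the letter beh: isolated, final, initial, medial
forms = ["\uFE8F", "\uFE90", "\uFE91", "\uFE92"]

for ch in forms:
    base = unicodedata.normalize("NFKC", ch)
    print(f"U+{ord(ch):04X} {unicodedata.name(ch)} -> U+{ord(base):04X}")

# All four variants normalize to the same base letter, U+0628
bases = {unicodedata.normalize("NFKC", ch) for ch in forms}
```

Because the presentation forms are distinct code points, each one counts as its own vocab entry unless it is normalized first.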

Collaborator

Hmm, understood.
Could we split this into vowels, consonants, and diacritics?

In the end, each character needs to be unique and, as mentioned, whitespace is not allowed. To keep adjacent characters from visually merging in the source, we can use

"".join(["A", "B", ...])

Punctuation should be removed because it is added later on in the arabic entry :)

If I merge both I get this:

['ء', 'آ', 'أ', 'ؤ', 'إ', 'ئ', 'ا', 'ب', 'ة', 'ت', 'ث', 'ج', 'ح', 'خ', 'د', 'ذ', 'ر', 'ز', 'س', 'ش', 'ص', 'ض', 'ط', 'ظ', 'ع', 'غ', 'ـ', 'ف', 'ق', 'ك', 'ل', 'م', 'ن', 'ه', 'و', 'ى', 'ي', 'ﻰ', 'ﻚ', 'ﻟ', 'ﺱ', 'ﻦ', 'ٰ', 'ﺞ', 'ﻛ', 'ﺩ', 'ﺀ', 'ﺨ', 'ﻋ', 'x', 'ﺺ', 'ﻫ', 'ﻱ', 'ﺲ', 'ﻝ', 'ﺕ', 'ڢ', 'ﻳ', 'ڤ', 'ﺬ', '؛', 'ﺶ', 'ﺟ', 'ﺔ', 'گ', 'ﻙ', 'ﺦ', 'ﺁ', 'ﺋ', 'ﻞ', 'ﺷ', 'ﺚ', 'ﺃ', 'ﻈ', 'ﻨ', 'ﺴ', 'ﻹ', 'ﺉ', 'ﻊ', 'ﺪ', 'ﻉ', 'ﺝ', 'ﺳ', 'ﻷ', 'ﻓ', 'ﺍ', 'ﺊ', 'ﻖ', 'ﻠ', 'ً', 'ﻍ', 'ﻣ', 'ﻇ', 'ﺾ', 'ٌ', 'چ', 'ﺿ', 'ﻧ', 'ﺡ', 'ﻗ', 'ﺙ', 'ﺼ', 'ﺑ', 'ﻅ', 'ﺓ', 'ﻯ', 'ﻭ', 'ﺒ', 'ﻤ', 'ﻔ', 'پ', 'ﺯ', 'ﻩ', 'ﻑ', 'ﻜ', 'ﺖ', 'ﺛ', 'ﺧ', 'ﺫ', 'ﺠ', 'ﻡ', 'ﻵ', 'ﻌ', 'ﺰ', 'ﻴ', 'ﻘ', 'ﻄ', 'ﻒ', '٪', 'ﺮ', 'ﺇ', 'ﺘ', 'ﺽ', 'ﻢ', 'ﻐ', 'ﻏ', 'ﻃ', 'ی', 'ﺵ', 'ﺸ', 'ﻲ', 'ﻮ', 'ﺻ', 'ﻆ', 'ﻁ', 'ﺏ', 'ﺎ', 'ﻕ', 'ﺹ', 'ﻻ', 'ﻂ', 'ﺣ', 'ﻼ', 'ﺭ', 'ﻪ', '؟', 'ﺐ', 'ﺤ', 'ﻬ', 'ٕ', 'ّ', 'ﻀ', 'ﺗ', 'ﻥ', 'ﻎ', 'ﺥ', 'ﺅ', 'ﺜ', 'ﺢ']
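
A rough sketch of how such a merge could be done while keeping every character unique, dropping whitespace and punctuation, and separating diacritics from the rest. It relies on `unicodedata.category` (`Mn` for combining marks, categories starting with `P` for punctuation). The function name `merge_vocab` and the short sample strings are illustrative assumptions, and the split does not distinguish vowels from consonants.

```python
import unicodedata


def merge_vocab(*parts: str) -> tuple[str, str]:
    """Merge vocab strings into (letters, diacritics); hypothetical helper, not docTR API."""
    letters, diacritics = [], []
    # dict.fromkeys removes duplicates while preserving first-seen order
    for ch in dict.fromkeys("".join(parts)):
        category = unicodedata.category(ch)
        if ch.isspace() or category.startswith("P"):
            continue  # whitespace and punctuation are excluded
        # Combining marks go to diacritics; everything else is treated as a letter here
        (diacritics if category == "Mn" else letters).append(ch)
    return "".join(letters), "".join(diacritics)


# Short samples only: hamza, alef with madda, alef with hamza, then a spaced string
# that repeats two of them and adds beh (isolated form), fathatan and a question mark
letters, diacritics = merge_vocab("\u0621\u0622\u0623", "\u0621 \u0622 \uFE8F \u064B \u061F")
```

With these samples, `letters` holds four unique characters and `diacritics` holds the single combining mark; the space and the Arabic question mark are dropped.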

"arabic_punctuation": "؟؛«»—",
"persian_letters": "پچڢڤگ",
# Bangla
Expand Down Expand Up @@ -786,7 +786,8 @@
VOCABS["multilingual"] = "".join(
dict.fromkeys(
# latin_based
VOCABS["english"]
VOCABS["arabic"]
Collaborator

Let's revert this for the moment, we will add this if we have a multilingual dataset including arabic 👍

+VOCABS["english"]
+ VOCABS["albanian"]
+ VOCABS["afrikaans"]
+ VOCABS["azerbaijani"]
Expand Down
30 changes: 27 additions & 3 deletions doctr/models/detection/differentiable_binarization/base.py
Collaborator

In general, adding a sanity check is a really good idea 👍

But we need to rethink the implementation a bit: your current code only covers the db_ models, whereas such a check should be more generic and controllable, so I would suggest the following:

Here we can add a boolean argument, sanity_check or something like that, which defaults to False. If True, it should do the following before formatting and appending the data:

- Check that the coordinates are in the image ranges
- Check that the coordinates are absolute, i.e. not in the range 0-1

This logic can be added as a private method to the class and called before polygon formatting.

Afterwards a test needs to be added here:

def test_detection_dataset(mock_image_folder, mock_detection_label):

and
def test_detection_dataset(mock_image_folder, mock_detection_label):

Once these parts are done, we can add an extra arg to the detection training scripts:

parser.add_argument("--check-dataset", dest="check_dataset", action="store_true", help="Check the dataset for possible issues")

and update the DetectionDataset instances correspondingly:

        val_set = DetectionDataset(
            img_folder=os.path.join(args.val_path, "images"),
            label_path=os.path.join(args.val_path, "labels.json"),
            sanity_check=args.check_dataset,
            ....
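
How the flag and the dataset argument fit together can be sketched as follows. The `sanity_check` keyword does not exist yet (it is the proposal above), so a stand-in function `make_dataset_kwargs` is used here instead of the real `DetectionDataset`. The `val_path` positional argument is likewise an assumption made to get a runnable example.

```python
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("val_path", type=str)
parser.add_argument(
    "--check-dataset", dest="check_dataset", action="store_true", help="Check the dataset for possible issues"
)


def make_dataset_kwargs(args: argparse.Namespace) -> dict:
    """Collect the keyword arguments that would be passed to DetectionDataset (stand-in)."""
    return {
        "img_folder": os.path.join(args.val_path, "images"),
        "label_path": os.path.join(args.val_path, "labels.json"),
        "sanity_check": args.check_dataset,  # proposed argument, False unless the flag is given
    }


args = parser.parse_args(["data/val", "--check-dataset"])
kwargs = make_dataset_kwargs(args)
```

Because the flag uses `action="store_true"`, omitting it leaves `check_dataset` at `False`, which matches the proposed default.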

Expand Up @@ -28,6 +28,17 @@
bin_thresh: threshold used to binzarized p_map at inference time

"""
class InvalidCoordinatesError(Exception):
def __init__(self, image_name, class_name, min_val, max_val):
self.image_name = image_name
self.class_name = class_name
self.min_val = min_val
self.max_val = max_val
message = (
f"Invalid box coordinates in {image_name}, class '{class_name}': "
f"values should be between 0 & 1, but found range [{min_val:.4f}, {max_val:.4f}]."
)
super().__init__(message)
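
The class above has lost its indentation in this view, and the later hunk raises a plain `ValueError` rather than this class. A sketch with indentation restored, plus an example of raising and catching it, follows. Deriving from `ValueError` instead of `Exception` is an assumption made here so that existing `except ValueError` handlers would still catch it.

```python
class InvalidCoordinatesError(ValueError):  # assumption: ValueError base; the diff uses Exception
    """Raised when box coordinates fall outside the expected 0-1 range."""

    def __init__(self, image_name: str, class_name: str, min_val: float, max_val: float) -> None:
        self.image_name = image_name
        self.class_name = class_name
        self.min_val = min_val
        self.max_val = max_val
        message = (
            f"Invalid box coordinates in {image_name}, class '{class_name}': "
            f"values should be between 0 & 1, but found range [{min_val:.4f}, {max_val:.4f}]."
        )
        super().__init__(message)


try:
    raise InvalidCoordinatesError("page_01.jpg", "words", -0.2, 1.5)
except ValueError as err:
    caught = err  # keep a reference; 'err' is cleared when the except block ends
```

The attributes let a caller inspect the offending image and value range without parsing the message string.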

def __init__(
self,
Expand Down Expand Up @@ -270,11 +281,24 @@
target: list[dict[str, np.ndarray]],
output_shape: tuple[int, int, int],
channels_last: bool = True,
image_names: list[str] = None, # Add optional parameter for image names
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")

Codacy Static Code Analysis notice on doctr/models/detection/differentiable_binarization/base.py#L288: Trailing whitespace
# Enhanced error checking with image identification
for idx, tgt in enumerate(target):
for class_name, t in tgt.items():
if np.any((t[:, :4] > 1) | (t[:, :4] < 0)):
image_id = f"image #{idx}" if image_names is None else image_names[idx]
# Find the actual values that are out of range for better debugging
min_val = t[:, :4].min()
max_val = t[:, :4].max()
raise ValueError(
f"Invalid box coordinates in {image_id}, class '{class_name}': "
f"values should be between 0 & 1, but found range [{min_val:.4f}, {max_val:.4f}]. "
f"Please normalize your coordinates by dividing x by image width and y by image height."
)

input_dtype = next(iter(target[0].values())).dtype if len(target) > 0 else np.float32
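
The normalization that the new error message asks for (divide x by the image width and y by the image height) can be sketched like this. The function name `normalize_boxes` and the `(xmin, ymin, xmax, ymax)` row layout are assumptions for illustration, and the sketch uses plain lists rather than the `np.float32` arrays that `build_target` expects.

```python
def normalize_boxes(boxes: list[list[float]], width: int, height: int) -> list[list[float]]:
    """Convert absolute (xmin, ymin, xmax, ymax) boxes to relative 0-1 coordinates."""
    return [[xmin / width, ymin / height, xmax / width, ymax / height] for xmin, ymin, xmax, ymax in boxes]


relative = normalize_boxes([[100, 50, 300, 150]], width=400, height=200)
print(relative)  # [[0.25, 0.25, 0.75, 0.75]]
```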

Expand Down Expand Up @@ -362,4 +386,4 @@
thresh_target = thresh_target.astype(input_dtype)
thresh_mask = thresh_mask.astype(bool)

return seg_target, seg_mask, thresh_target, thresh_mask
return seg_target, seg_mask, thresh_target, thresh_mask
27 changes: 21 additions & 6 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Expand Up @@ -4,8 +4,7 @@
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from collections.abc import Callable
from typing import Any

from typing import Any, List, Optional
import numpy as np
import torch
from torch import nn
Expand Down Expand Up @@ -185,6 +184,7 @@
target: list[np.ndarray] | None = None,
return_model_output: bool = False,
return_preds: bool = False,
image_names: Optional[List[str]] = None, # Added parameter for image names
) -> dict[str, torch.Tensor]:
# Extract feature maps at different stages
feats = self.feat_extractor(x)
Expand Down Expand Up @@ -218,7 +218,7 @@

if target is not None:
thresh_map = self.thresh_head(feat_concat)
loss = self.compute_loss(logits, thresh_map, target)
loss = self.compute_loss(logits, thresh_map, target, image_names=image_names) # Pass image_names to compute_loss
out["loss"] = loss

return out
Expand All @@ -231,6 +231,7 @@
gamma: float = 2.0,
alpha: float = 0.5,
eps: float = 1e-8,
image_names: Optional[List[str]] = None, # Parameter for image names
) -> torch.Tensor:
"""Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
and a list of masks for each image. From there it computes the loss with the model output
Expand All @@ -242,6 +243,7 @@
gamma: modulating factor in the focal loss formula
alpha: balancing factor in the focal loss formula
eps: epsilon factor in dice loss
image_names: list of image filenames for error reporting

Returns:
A loss tensor
Expand All @@ -252,13 +254,26 @@
prob_map = torch.sigmoid(out_map)
thresh_map = torch.sigmoid(thresh_map)

targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]
try:
targets = self.build_target(target, out_map.shape[1:], False, image_names=image_names)
except ValueError as e:
# Re-raise with more context about which images caused the problem
if "Invalid box coordinates" in str(e) and image_names:
batch_info = ", ".join(image_names)
raise ValueError(f"{str(e)} Images in batch: {batch_info}")
else:
raise
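
One variation on the re-raise above, offered as a suggestion rather than as what the PR does: using `raise ... from e` keeps the original exception attached as `__cause__`. The `build_target` stub and the `build_with_context` wrapper below are stand-ins so the sketch can run on its own.

```python
from typing import Optional


def build_target(target: list, image_names: Optional[list[str]] = None) -> None:
    """Stand-in that always fails, mimicking the validation error in the diff."""
    raise ValueError("Invalid box coordinates in image #0, class 'words'.")


def build_with_context(target: list, image_names: Optional[list[str]] = None) -> None:
    try:
        build_target(target, image_names=image_names)
    except ValueError as e:
        if "Invalid box coordinates" in str(e) and image_names:
            batch_info = ", ".join(image_names)
            # 'from e' chains the exceptions instead of discarding the original
            raise ValueError(f"{e} Images in batch: {batch_info}") from e
        raise


try:
    build_with_context([], image_names=["a.jpg", "b.jpg"])
except ValueError as err:
    wrapped = err  # keep a reference; 'err' is cleared when the except block ends
```

Matching on the message text is kept from the diff; a dedicated exception type would make that check less fragile.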

seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3])
thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device)

# Initialize all loss components
focal_loss = torch.tensor(0.0, device=out_map.device)
dice_loss = torch.tensor(0.0, device=out_map.device)
l1_loss = torch.tensor(0.0, device=out_map.device)

Codacy Static Code Analysis notice on doctr/models/detection/differentiable_binarization/pytorch.py#L276: Trailing whitespace
if torch.any(seg_mask):
# Focal loss
focal_scale = 10.0
Expand All @@ -269,7 +284,7 @@
# Unreduced version
focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
# Class reduced
focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / seg_mask.sum((0, 1, 2, 3))
focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / (seg_mask.sum((0, 1, 2, 3)) + eps)
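
The effect of adding `eps` to the denominator can be seen with plain floats. This is a simplified sketch, not the tensor code from the diff: `masked_mean` is a hypothetical helper, and plain Python raises `ZeroDivisionError` on 0/0 where a tensor division would produce `nan`. In the diff the division appears to sit under `if torch.any(seg_mask):`, in which case the mask sum is already non-zero and `eps` acts as an extra safeguard.

```python
def masked_mean(values: list[float], mask: list[bool], eps: float = 1e-8) -> float:
    """Mean of values where mask is True; eps keeps the division defined for an empty mask."""
    total = sum(v for v, m in zip(values, mask) if m)
    count = sum(mask)
    return total / (count + eps)


regular = masked_mean([0.2, 0.4, 0.9], [True, True, False])  # close to 0.3
empty = masked_mean([0.2, 0.4, 0.9], [False, False, False])  # 0.0 instead of ZeroDivisionError
```

The bias introduced by `eps` is negligible as long as the mask sum is much larger than `eps`.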

# Compute dice loss for each class or for approx binary_map
if len(self.class_names) > 1:
Expand Down Expand Up @@ -429,4 +444,4 @@
"thresh_head.6.bias",
],
**kwargs,
)
)