Skip to content

Commit 963d34c

Browse files
committed
Merge branch 'infer' of https://github.com/horcham/mindocr into new_pipe
2 parents d1c2610 + 83dfa9e commit 963d34c

File tree

22 files changed, +119 -70 lines changed

configs/layout/yolov8/yolov8n.yaml

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -159,7 +159,7 @@ predict:
159159
max_device_memory: 8GB
160160
amp_level: O0
161161
mode: 0
162-
ckpt_load_path: /root/.mindspore/models/dbnet_resnet50-c3a4aa24.ckpt
162+
ckpt_load_path: /root/.mindspore/models/yolov8n-4b9e8004.ckpt
163163
dataset_sink_mode: False
164164
dataset:
165165
type: PublayNetDataset
@@ -169,17 +169,12 @@ predict:
169169
transform_pipeline:
170170
- func_name: letterbox
171171
scaleup: False
172-
- func_name: label_norm
173-
xyxy2xywh_: True
174-
- func_name: label_pad
175-
padding_size: 160
176-
padding_value: -1
177172
- func_name: image_norm
178173
scale: 255.
179174
- func_name: image_transpose
180175
bgr2rgb: True
181176
hwc2chw: True
182-
batch_size: &refine_batch_size 13
177+
batch_size: *refine_batch_size
183178
stride: 64
184179
output_columns: ['image', 'labels', 'image_ids', 'hw_ori', 'hw_scale', 'pad']
185180
net_input_column_index: [ 0 ] # input indices for network forward func in output_columns

deploy/py_infer/src/core/model/model.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -106,8 +106,8 @@ def warmup(self):
106106
height, width = hw_list[0]
107107
warmup_shape = [(*other_shape, height, width)] # Only single input
108108

109-
dummy_tensor = [np.random.randn(*shape).astype(dtype) for shape, dtype in zip(warmup_shape, self.input_dtype)]
110-
self.model.infer(dummy_tensor)
109+
# dummy_tensor = [np.random.randn(*shape).astype(dtype) for shape, dtype in zip(warmup_shape, self.input_dtype)]
110+
# self.model.infer(dummy_tensor)
111111

112112
def __del__(self):
113113
if hasattr(self, "model") and self.model:

deploy/py_infer/src/data_process/postprocess/builder.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@ def get_device_status():
4444
def _get_status():
4545
nonlocal status
4646
try:
47+
ms.set_context(max_device_memory="0.01GB")
4748
status = ms.Tensor([0])[0:].asnumpy()[0]
4849
except RuntimeError:
4950
status = 1

deploy/py_infer/src/parallel/module/detection/det_post_node.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -28,10 +28,9 @@ def concat_crops(self, crops: list):
2828
Returns:
2929
numpy.ndarray: A horizontally concatenated image array.
3030
"""
31-
crops_sorted = sorted(crops, key=lambda points: (points[0][1], points[0][0]))
32-
max_height = max(crop.shape[0] for crop in crops_sorted)
31+
max_height = max(crop.shape[0] for crop in crops)
3332
resized_crops = []
34-
for crop in crops_sorted:
33+
for crop in crops:
3534
h, w, c = crop.shape
3635
new_h = max_height
3736
new_w = int((w / h) * new_h)
@@ -48,6 +47,8 @@ def process(self, input_data):
4847

4948
data = input_data.data
5049
boxes = self.text_detector.postprocess(data["pred"], data["shape_list"])
50+
if self.is_concat:
51+
boxes = sorted(boxes, key=lambda points: (points[0][1], points[0][0]))
5152

5253
infer_res_list = []
5354
for box in boxes:
@@ -65,7 +66,7 @@ def process(self, input_data):
6566
sub_image = cv_utils.crop_box_from_image(image, np.array(box))
6667
sub_image_list.append(sub_image)
6768
if self.is_concat:
68-
sub_image_list = [self.concat_crops(sub_image_list)]
69+
sub_image_list = len(sub_image_list) * [self.concat_crops(sub_image_list)]
6970
input_data.sub_image_list = sub_image_list
7071

7172
input_data.data = None

mindocr/infer/classification/classification.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,6 @@ def __init__(self, args):
2929
self.transforms = create_transforms(self.yaml_cfg.predict.dataset.transform_pipeline)
3030

3131
def __call__(self, img):
32-
print(img)
3332
data = {"image": img}
3433
data = run_transforms(data, self.transforms[1:])
3534
return data

mindocr/infer/classification/cls_post_node.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -49,7 +49,7 @@ def process(self, input_data):
4949
scores = np.array(output["scores"]).tolist()
5050

5151
batch = input_data.sub_image_size
52-
if self.task_type.value in (TaskType.DET_CLS_REC.value, TaskType.Layout_DET_CLS_REC.value):
52+
if self.task_type.value in (TaskType.DET_CLS_REC.value, TaskType.LAYOUT_DET_CLS_REC.value):
5353
sub_images = input_data.sub_image_list
5454
for i in range(batch):
5555
angle, score = angles[i], scores[i]

mindocr/infer/common/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
1-
from .collect_node import CollectNode
1+
from .collect_node2 import CollectNode
22
from .decode_node import DecodeNode
33
from .handout_node import HandoutNode

mindocr/infer/detection/det_post_node.py

Lines changed: 4 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -9,14 +9,14 @@
99

1010
from pipeline.framework.module_base import ModuleBase
1111
from pipeline.tasks import TaskType
12-
from .detection import DetPostProcess
12+
from .detection import DetPostprocess
1313
from tools.infer.text.utils import crop_text_region
1414
from pipeline.data_process.utils.cv_utils import crop_box_from_image
1515

1616
class DetPostNode(ModuleBase):
1717
def __init__(self, args, msg_queue, tqdm_info):
1818
super(DetPostNode, self).__init__(args, msg_queue, tqdm_info)
19-
self.det_postprocess = DetPostProcess(args)
19+
self.det_postprocess = DetPostprocess(args)
2020
self.task_type = self.args.task_type
2121
self.is_concat = self.args.is_concat
2222

@@ -51,17 +51,11 @@ def process(self, input_data):
5151
return
5252

5353
pred = input_data.data["det_infer_res"]
54-
# print("pred:", len(pred))
5554
pred = pred[0]
5655
data_dict = {"shape_list": input_data.data["det_pre_res"]["shape_list"]}
5756
boxes = self.det_postprocess(pred, data_dict)
5857

59-
60-
6158
boxes = boxes['polys'][0]
62-
63-
# TODO ZHQ 对齐 tools/infer/text/postprocess.py?
64-
# print(boxes)
6559

6660
if self.is_concat:
6761
boxes = sorted(boxes, key=lambda points: (points[0][1], points[0][0]))
@@ -72,11 +66,9 @@ def process(self, input_data):
7266

7367
input_data.infer_result = infer_res_list
7468

75-
# ZHQ TODO
76-
77-
# input_data.sub_image_total = len(infer_res_list)
78-
# input_data.sub_image_size = len(infer_res_list)
7969
if self.task_type.value in (TaskType.DET.value, TaskType.DET_REC.value, TaskType.DET_CLS_REC.value):
70+
if len(input_data.frame) == 0:
71+
return
8072
image = input_data.frame[0] # bs=1 for det
8173
else:
8274
image = input_data.data["layout_images"][0]
@@ -88,9 +80,6 @@ def process(self, input_data):
8880
sub_image_list = len(sub_image_list) * [self.concat_crops(sub_image_list)]
8981
input_data.sub_image_list = sub_image_list
9082

91-
# if not (self.args.crop_save_dir or self.args.vis_det_save_dir or self.args.vis_pipeline_save_dir):
92-
# input_data.frame = None
93-
9483
if not infer_res_list:
9584
input_data.skip = True
9685

mindocr/infer/detection/det_pre_node.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -35,17 +35,14 @@ def process(self, input_data):
3535
return
3636
image = input_data.data["layout_images"][0] # bs = 1 for det
3737
data = self.det_preprocesser({"image": image})
38-
# print(data)
3938

4039
if len(data["image"].shape) == 3:
4140
data["image"] = np.expand_dims(data["image"], 0)
4241
data["shape_list"] = np.expand_dims(data["shape_list"], 0)
43-
# print(data["image"].shape)
44-
# time.sleep(1000)
45-
if self.task_type.value == TaskType.DET.value and not (self.args.crop_save_dir or self.args.vis_det_save_dir):
46-
input_data.frame = None
42+
# if self.task_type.value == TaskType.DET.value and not (self.args.crop_save_dir or self.args.vis_det_save_dir):
43+
# input_data.frame = None
4744

48-
if self.task_type.value in (TaskType.LAYOUT_DET.value, TaskType.LAYOUT_DET_REC, TaskType.LAYOUT_DET_CLS_REC):
45+
if self.task_type.value in (TaskType.LAYOUT_DET.value, TaskType.LAYOUT_DET_REC.value, TaskType.LAYOUT_DET_CLS_REC.value):
4946
input_data.data["det_pre_res"] = data
5047
else:
5148
input_data.data = {"det_pre_res": data}

mindocr/infer/detection/detection.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -36,8 +36,7 @@ def __init__(self, args):
3636
break
3737
self.transforms = create_transforms(self.yaml_cfg.predict.dataset.transform_pipeline)
3838

39-
def __call__(self, img):
40-
data = {"image": img}
39+
def __call__(self, data):
4140
data = run_transforms(data, self.transforms[1:])
4241
return data
4342

mindocr/infer/layout/layout_pre_node.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,6 @@ def process(self, input_data):
4747
"target_size": [800, 800],
4848
}
4949
data = self.layout_preprocesser(data)
50-
# print(data)
5150

5251
if len(data["image"].shape) == 3:
5352
data["image"] = np.expand_dims(data["image"], 0)

mindocr/infer/node_config.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,8 @@
1717

1818
__all__ = ["MODEL_DICT_v2",
1919
"DET_DESC_v2", "CLS_DESC_v2", "REC_DESC_v2",
20-
"DET_REC_DESC_v2", "DET_CLS_REC_DESC_v2"]
20+
"DET_REC_DESC_v2", "DET_CLS_REC_DESC_v2",
21+
"LAYOUT_DESC_v2", "LAYOUT_DET_REC_DESC_v2", "LAYOUT_DET_CLS_REC_DESC_v2"]
2122

2223
DET_DESC_v2 = [
2324
(("HandoutNode", "0", 1), ("DecodeNode", "0", 1)),

mindocr/infer/recognition/rec_pre_node.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ def process(self, input_data):
3131
image = input_data.frame[0]
3232
data = [self.rec_preprocesser(image)["image"]]
3333
input_data.sub_image_size = 1
34-
input_data.data["rec_pre_res"] = data
34+
input_data.data = {"rec_pre_res": data}
3535
self.send_to_next_module(input_data)
3636
else:
3737
sub_image_list = input_data.sub_image_list

mindocr/losses/det_loss.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import os
23
from math import pi
34
from typing import Tuple, Union
45

@@ -10,6 +11,8 @@
1011
__all__ = ["DBLoss", "PSEDiceLoss", "EASTLoss", "FCELoss"]
1112
_logger = logging.getLogger(__name__)
1213

14+
OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)
15+
1316

1417
class DBLoss(nn.LossBase):
1518
"""
@@ -165,7 +168,13 @@ def construct(self, pred: Tensor, gt: Tensor, mask: Tensor) -> Tensor:
165168
neg_loss = (loss * negative).view(loss.shape[0], -1)
166169

167170
neg_vals, _ = ops.sort(neg_loss)
168-
neg_index = ops.stack((mnp.arange(loss.shape[0]), neg_vals.shape[1] - neg_count), axis=1)
171+
172+
if OFFLINE_MODE is None:
173+
neg_index = ops.stack((mnp.arange(loss.shape[0]), neg_vals.shape[1] - neg_count), axis=1)
174+
else:
175+
neg_index = ops.stack(
176+
(ops.arange(loss.shape[0], dtype=neg_count.dtype), neg_vals.shape[1] - neg_count), axis=1
177+
)
169178
min_neg_score = ops.expand_dims(ops.gather_nd(neg_vals, neg_index), axis=1)
170179

171180
neg_loss_mask = (neg_loss >= min_neg_score).astype(ms.float32) # filter values less than top k

mindocr/losses/rec_loss.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import numpy as np
24

35
import mindspore as ms
@@ -6,6 +8,8 @@
68

79
__all__ = ["CTCLoss", "AttentionLoss", "VisionLANLoss"]
810

11+
OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)
12+
913

1014
class CTCLoss(LossBase):
1115
"""
@@ -147,14 +151,21 @@ class AttentionLoss(LossBase):
147151
def __init__(self, reduction: str = "mean", ignore_index: int = 0) -> None:
148152
super().__init__()
149153
# ignore <GO> symbol, assume it is placed at 0th index
150-
self.criterion = nn.CrossEntropyLoss(reduction=reduction, ignore_index=ignore_index)
154+
if OFFLINE_MODE is None:
155+
self.criterion = nn.CrossEntropyLoss(reduction=reduction, ignore_index=ignore_index)
156+
else:
157+
self.reduction = reduction
158+
self.ignore_index = ignore_index
151159

152160
def construct(self, logits: Tensor, labels: Tensor) -> Tensor:
153161
labels = labels[:, 1:] # without <GO> symbol
154162
num_classes = logits.shape[-1]
155163
logits = ops.reshape(logits, (-1, num_classes))
156164
labels = ops.reshape(labels, (-1,))
157-
return self.criterion(logits, labels)
165+
if OFFLINE_MODE is None:
166+
return self.criterion(logits, labels)
167+
else:
168+
return ops.cross_entropy(logits, labels, reduction=self.reduction, ignore_index=self.ignore_index)
158169

159170

160171
class SARLoss(LossBase):

mindocr/models/necks/fpn.py

Lines changed: 24 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import os
12
from typing import List, Tuple
23

34
from mindspore import Tensor, nn, ops
@@ -7,14 +8,20 @@
78
from ..utils.attention_cells import SEModule
89
from .asf import AdaptiveScaleFusion
910

11+
OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)
1012

11-
def _resize_nn(x: Tensor, scale: int = 0, shape: Tuple[int] = None):
12-
if scale == 1 or shape == x.shape[2:]:
13-
return x
1413

15-
if scale:
16-
shape = (x.shape[2] * scale, x.shape[3] * scale)
17-
return ops.ResizeNearestNeighbor(shape)(x)
14+
if OFFLINE_MODE is None:
15+
def _resize_nn(x: Tensor, scale: int = 0, shape: Tuple[int] = None):
16+
if scale == 1 or shape == x.shape[2:]:
17+
return x
18+
19+
if scale:
20+
shape = (x.shape[2] * scale, x.shape[3] * scale)
21+
return ops.ResizeNearestNeighbor(shape)(x)
22+
else:
23+
def _resize_nn(x: Tensor, shape: Tensor):
24+
return ops.ResizeNearestNeighborV2()(x, shape)
1825

1926

2027
class FPN(nn.Cell):
@@ -64,11 +71,18 @@ def construct(self, features: List[Tensor]) -> Tensor:
6471
for i, uc_op in enumerate(self.unify_channels):
6572
features[i] = uc_op(features[i])
6673

67-
for i in range(2, -1, -1):
68-
features[i] += _resize_nn(features[i + 1], shape=features[i].shape[2:])
74+
if OFFLINE_MODE is None:
75+
for i in range(2, -1, -1):
76+
features[i] += _resize_nn(features[i + 1], shape=features[i].shape[2:])
77+
78+
for i, out in enumerate(self.out):
79+
features[i] = _resize_nn(out(features[i]), shape=features[0].shape[2:])
80+
else:
81+
for i in range(2, -1, -1):
82+
features[i] += _resize_nn(features[i + 1], shape=ops.dyn_shape(features[i])[2:])
6983

70-
for i, out in enumerate(self.out):
71-
features[i] = _resize_nn(out(features[i]), shape=features[0].shape[2:])
84+
for i, out in enumerate(self.out):
85+
features[i] = _resize_nn(out(features[i]), shape=ops.dyn_shape(features[0])[2:])
7286

7387
return self.fuse(features[::-1]) # matching the reverse order of the original work
7488

mindocr/models/transforms/tps_spatial_transformer.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import itertools
2+
import os
23
from typing import Optional, Tuple
34

45
import numpy as np
@@ -8,6 +9,8 @@
89
import mindspore.ops as ops
910
from mindspore import Tensor
1011

12+
OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)
13+
1114

1215
def grid_sample(input: Tensor, grid: Tensor, canvas: Optional[Tensor] = None) -> Tensor:
1316
out_type = input.dtype
@@ -112,15 +115,22 @@ def __init__(
112115
self.target_coordinate_repr = Tensor(target_coordinate_repr, dtype=ms.float32)
113116
self.target_control_points = Tensor(target_control_points, dtype=ms.float32)
114117

118+
if OFFLINE_MODE is not None:
119+
self.matmul = ops.BatchMatMul()
120+
115121
def construct(
116122
self, input: Tensor, source_control_points: Tensor
117123
) -> Tuple[Tensor, Tensor]:
118124
batch_size = ops.shape(source_control_points)[0]
119125

120126
padding_matrix = ops.tile(self.padding_matrix, (batch_size, 1, 1))
121127
Y = ops.concat([source_control_points, padding_matrix], axis=1)
122-
mapping_matrix = ops.matmul(self.inverse_kernel, Y)
123-
source_coordinate = ops.matmul(self.target_coordinate_repr, mapping_matrix)
128+
if OFFLINE_MODE is None:
129+
mapping_matrix = ops.matmul(self.inverse_kernel, Y)
130+
source_coordinate = ops.matmul(self.target_coordinate_repr, mapping_matrix)
131+
else:
132+
mapping_matrix = self.matmul(self.inverse_kernel[None, ...], Y)
133+
source_coordinate = self.matmul(self.target_coordinate_repr[None, ...], mapping_matrix)
124134
grid = ops.reshape(
125135
source_coordinate,
126136
(-1, self.target_height, self.target_width, 2),

0 commit comments

Comments (0)