diff --git a/DDT/generate_box_imagenet_crop.py b/DDT/generate_box_imagenet_crop.py index 09478fe..915a815 100644 --- a/DDT/generate_box_imagenet_crop.py +++ b/DDT/generate_box_imagenet_crop.py @@ -22,7 +22,7 @@ parser = argparse.ArgumentParser(description='Parameters for DDT generate box') parser.add_argument('--input_size',default=448,dest='input_size') -parser.add_argument('--data',default="data/PSOL_imgs",help='path to imagenet dataset') +parser.add_argument('--data',default="data/DDT_imgs",help='path to imagenet dataset') parser.add_argument('--gpu',help='which gpu to use',default='0',dest='gpu') parser.add_argument('--output_path',default='data/DDT_crop/',dest='output_path') parser.add_argument('--batch_size',default=64,dest='batch_size') diff --git a/README_cn.md b/README_cn.md new file mode 100644 index 0000000..8a8d1f3 --- /dev/null +++ b/README_cn.md @@ -0,0 +1,113 @@ +# 字节跳动安全AI挑战赛低分辨率抖音号识别赛道总结 + +这次比赛最终成绩为0.66465, 排名第三, 虽然和第一0.70125还有一定差距, 但大体还是很满意的. (主要归功于比赛内容和以前的经历相关) + +![image-20220908193916466](imgs/image-20220908193916466.png) + +## 任务分析 + +赛题内容可参见[比赛官网](https://security.bytedance.com/fe/2022/ai-challenge#/challenge). 简单说, 就是对视频截图中的抖音号进行识别, 并且测试集的分辨率低于训练集. 没拿到数据前, 我以为这个比赛主要是一个domain adaptation问题, 如何去消除训练集和测试集的domain shift. 但其实比赛的重点在于如何定位抖音号位置, 一般来说OCR需要先进行目标检测, 再进行识别, 而比赛并没有提供标注框, 也不能使用预训练模型, 如何采取合适的方法裁剪出抖音号所在区域对于下一阶段的文本识别影响很大. + +![image-20220908195621821](imgs/image-20220908195621821.png) + +## 定位方案 + +### 直接裁剪 + +这个方案很粗暴. 浏览数据集我们可以发现, 抖音号绝大部分出现在图片的上边或下边, 所以我们可以直接裁剪出上下边并拼接在一起. 这里我选取的是上边1/10和下边1/10, 这个界限可以囊括绝大部分抖音号. + +![image-20220908204941516](imgs/image-20220908204941516.png) + +虽然抖音号仍然只占了裁剪后图片的一小部分, 但是相比原来无用信息已经少了很多了, 这个时候直接使用OCR模型训练已经可以取得不错的结果了. 考虑到这个背景无用信息较多, 所以当时采用的是attention机制SAR模型, 在测试集结果是0.18. + +当然, 我们希望可以更准确的定位抖音号位置. + +### 传统文本定位方法 + +既然不能使用预训练深度学习模型, 那就用传统方法呗. 我第一天就尝试了Maximally Stable Extremal Regions (MSER)等传统算法, 效果并不好, 文字也很难定位准确. 我没有在这个方向上做过多尝试, 当然我相信结合比赛数据调整传统算法/手写规则也可以在本任务中很好的做到抖音号定位. 
+ +![image-20220908211320041](imgs/image-20220908211320041.png) + +### 更精细定位位置 + +#### Attention map定位左右 + +之前提到, 我们采用attention机制的SAR模型已经可以进行不错的文字识别了, 自然就想到我们可以将其attention可视化, 通过热图去做定位. 不过效果也不是很好, attention map无法精确地对准字符, 我猜想原因是SAR是对feature map做的attention, 而feature map和原图并不是完全对应的. + +![image-20220908223057062](imgs/image-20220908223057062.png) + +就比如在该任务中的attention map的峰值始终在图像中间部分, 导致我们无法通过attention map区分抖音号在上边还是下边. 不过, 基于attention map我们还是可以区分出抖音号是位于图片左边还是右边的. + +#### 人工构建数据集训练分类模型定位上下 + +对于区分在上边还是下边, 我们可以人为构建一个数据集训练一个是否存在抖音号的分类模型, 如图我们将原图10等分, 为了方便描述, 我们将其由上到下命名为1-10, 其中1和10组合为有抖音号的正样本, 随机选取2-9中的两个拼接构成负样本. + +![image-20220908213507774](imgs/image-20220908213507774.png) + +预测时, 将2-9中任选一个分别与1, 10拼接, 通过模型对其分别进行预测, 若1所在图片得分高, 则抖音号在图片上边, 若10所在图片得分高, 则抖音号在图片下边. + +#### 小结 + +这个方案最后没有做下去, 因为这个方案的构思让我想到了更好的解决办法, 即弱监督目标定位. + +### 弱监督目标定位 + +做attention map定位时, 我觉得这个过程很像弱监督目标定位, 而上述训练分类模型中我们引入的人造标注就是个很好的弱监督信号. 所以我基于上述的分类模型, 使用了以前用过的DDT方法完成弱监督目标定位. + +![image-20220908225608295](imgs/image-20220908225608295.png) + +整个效果还是不错的. + +![image-20220908231522492](imgs/image-20220908231522492.png) + +这里其实还有提升空间, 比如Rethinking the Route Towards Weakly Supervised Object Localization (PSOL) 中利用DDT生成的伪标签重新训练网络定位, 我也尝试了, 肉眼效果区别不大, 精力有限没有继续做实验比较结果. + +我这里尝试过对于弱监督目标定位后的图片用人工规则进一步裁剪, 但是最终文本识别效果下降, 原因不明. + +## 文本识别方案 + +文本识别算法可以看我之前的总结: https://zhuanlan.zhihu.com/p/540347287, 前面提到的SAR模型也有介绍. + +在经过弱监督目标定位后的图片上, 我主要尝试的是当前SOTA的SVTR模型 (SAR跑了一次效果一般就没有再调了). + +训练500个epoch实验结果如下: + +| 方法 | 测试集准确度 | +| ------------------------------------------------------------ | ------------ | +| SVTR-Tiny | 0.484 | +| SVTR-Tiny + 常规增强 | 0.568 | +| SVTR-Tiny + 常规增强 + Text-Image-Augmentation | 0.606 | +| SVTR-Tiny + 常规增强 + Text-Image-Augmentation + Pseudo Label | 0.618 | +| SVTR-Large + 常规增强 + Text-Image-Augmentation | 0.620 | +| SVTR-Large + 常规增强 + Text-Image-Augmentation + Pseudo Label | 0.626 | +| SVTR-Large (removing STN) + 常规增强 + Text-Image-Augmentation | 0.646 | + +常规增强: 裁剪, 模糊 (作用很大, 减少了训练集和测试集的domain shift), 色彩变换等. 
开始没用增强是因为发现一使用增强模型就无法训练, 后来排查发现是由于图像分辨率很低, 高斯模糊时5×5的高斯核太大, 会完全模糊图像, 故无法训练, 改为3×3后才正常. + +Text-Image-Augmentation: 主要是下述三种变化, 是针对文本识别的图像增强手段. + +![img](imgs/distort.gif)![img](imgs/stretch.gif)![img](imgs/perspective.gif) + +Pseudo Label: 将测试集置信度大于一定阈值的标签作为伪标签, 重新训练模型. + +removing STN: STN本身不应该对于性能有影响, 这里移去STN主要是发现默认尺寸不适合, 且考虑到该任务文本比较规则, 所以移去. + +由于最后几天服务器宕机了, 没有来得及将Pseudo Label和SVTR-Large (removing STN) 结合, 模型集成也只有一个0.646的SVTR-Large (removing STN) 模型, 其他都是原来0.625左右的SVTR-Large模型, 最后投票结果即为最终成绩0.66465. + +## 可能的提升空间 + +1. 多个结合Pseudo Label的SVTR-Large (removing STN) 集成. 这个应该可以有1-2个点的提升. +2. SOTA的弱监督目标定位. DDT是比较老的算法了, 本来是打算用PSOL的, DDT是PSOL算法的第一步, 但是发现效果不错, 就没有继续用别的了. 这里提升未知, 不一定有效果. +3. Domain adaptation. 这个作用可能不是很大, 因为数据增强中的模糊也可以起到类似的效果, 整体看这题的domain shift并不严重, 最后在验证集上的准确率0.75和测试集0.646差距不大. +4. 调参. 这个影响应该比较大, SVTR模型可调的参数比较多, 而且实验下来影响较大, 没有机会细调. +5. RotNet自监督预训练. 根据CVPR2021论文[What if we only use real datasets for scene text recognition? toward scene text recognition with fewer labels](http://openaccess.thecvf.com/content/CVPR2021/html/Baek_What_if_We_Only_Use_Real_Datasets_for_Scene_Text_CVPR_2021_paper.html), 这个方法的效果甚至好于MoCo. 
+ + + +## 相关代码 + +比赛代码: https://github.com/eshoyuan/Tiktok_OCR + +文本识别参考代码: https://github.com/PaddlePaddle/PaddleOCR + +弱监督目标定位参考代码: https://github.com/tzzcl/PSOL diff --git a/imgs/distort.gif b/imgs/distort.gif new file mode 100644 index 0000000..e4245c5 Binary files /dev/null and b/imgs/distort.gif differ diff --git a/imgs/douyu-frame-example_0da1dc04.png b/imgs/douyu-frame-example_0da1dc04.png new file mode 100644 index 0000000..5cc3fa5 Binary files /dev/null and b/imgs/douyu-frame-example_0da1dc04.png differ diff --git a/imgs/image-20220908193916466.png b/imgs/image-20220908193916466.png new file mode 100644 index 0000000..f1e52a7 Binary files /dev/null and b/imgs/image-20220908193916466.png differ diff --git a/imgs/image-20220908195621821.png b/imgs/image-20220908195621821.png new file mode 100644 index 0000000..86c3734 Binary files /dev/null and b/imgs/image-20220908195621821.png differ diff --git a/imgs/image-20220908204913076.png b/imgs/image-20220908204913076.png new file mode 100644 index 0000000..0f3f6c1 Binary files /dev/null and b/imgs/image-20220908204913076.png differ diff --git a/imgs/image-20220908204941516.png b/imgs/image-20220908204941516.png new file mode 100644 index 0000000..1359875 Binary files /dev/null and b/imgs/image-20220908204941516.png differ diff --git a/imgs/image-20220908210111319.png b/imgs/image-20220908210111319.png new file mode 100644 index 0000000..6a92eff Binary files /dev/null and b/imgs/image-20220908210111319.png differ diff --git a/imgs/image-20220908211320041.png b/imgs/image-20220908211320041.png new file mode 100644 index 0000000..4a0aa41 Binary files /dev/null and b/imgs/image-20220908211320041.png differ diff --git a/imgs/image-20220908212615794.png b/imgs/image-20220908212615794.png new file mode 100644 index 0000000..4bc33f5 Binary files /dev/null and b/imgs/image-20220908212615794.png differ diff --git a/imgs/image-20220908213152888.png b/imgs/image-20220908213152888.png new file mode 100644 index 
0000000..869938a Binary files /dev/null and b/imgs/image-20220908213152888.png differ diff --git a/imgs/image-20220908213432357.png b/imgs/image-20220908213432357.png new file mode 100644 index 0000000..521706f Binary files /dev/null and b/imgs/image-20220908213432357.png differ diff --git a/imgs/image-20220908213507774.png b/imgs/image-20220908213507774.png new file mode 100644 index 0000000..5b819d4 Binary files /dev/null and b/imgs/image-20220908213507774.png differ diff --git a/imgs/image-20220908223057062.png b/imgs/image-20220908223057062.png new file mode 100644 index 0000000..a6c5ded Binary files /dev/null and b/imgs/image-20220908223057062.png differ diff --git a/imgs/image-20220908225608295.png b/imgs/image-20220908225608295.png new file mode 100644 index 0000000..926445f Binary files /dev/null and b/imgs/image-20220908225608295.png differ diff --git a/imgs/image-20220908231522492.png b/imgs/image-20220908231522492.png new file mode 100644 index 0000000..3a04f96 Binary files /dev/null and b/imgs/image-20220908231522492.png differ diff --git a/imgs/img1.png b/imgs/img1.png new file mode 100644 index 0000000..5cc3fa5 Binary files /dev/null and b/imgs/img1.png differ diff --git a/imgs/perspective.gif b/imgs/perspective.gif new file mode 100644 index 0000000..df7b92f Binary files /dev/null and b/imgs/perspective.gif differ diff --git a/imgs/stretch.gif b/imgs/stretch.gif new file mode 100644 index 0000000..bf1e4b3 Binary files /dev/null and b/imgs/stretch.gif differ diff --git a/preprocess.sh b/preprocess.sh index 62935e7..aedc4ed 100644 --- a/preprocess.sh +++ b/preprocess.sh @@ -1,7 +1,7 @@ -python3 crop.py -python3 create_classification_dataset.py -python3 classification_vgg.py -python3 create_DDT_dataset.py -python3 DDT/generate_box_imagenet_crop.py -python3 tools/train.py -c svtr_large_train_stn.yml # Distributed Training: python -m paddle.distributed.launch --gpus '0,1,2,3' -c svtr_large_train_stn.yml -python3 tools/train.py -c svtr_large_train_stn.yml 
-o Global.pretrained_model=svtr_large_stn/best_accuracy \ No newline at end of file +python crop.py +python create_classification_dataset.py +python classification_vgg.py +python create_DDT_dataset.py +python DDT/generate_box_imagenet_crop.py +python tools/train.py -c svtr_large_train_stn.yml # Distributed Training: python -m paddle.distributed.launch --gpus '0,1,2,3' -c svtr_large_train_stn.yml +python tools/infer_rec.py -c svtr_large_train_stn.yml -o Global.pretrained_model=svtr_large_stn/best_accuracy \ No newline at end of file diff --git a/tools/export_center.py b/tools/export_center.py deleted file mode 100644 index 30b9c33..0000000 --- a/tools/export_center.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import sys -import pickle - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) - -from ppocr.data import build_dataloader -from ppocr.modeling.architectures import build_model -from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import load_model -from ppocr.utils.utility import print_dict -import tools.program as program - - -def main(): - global_config = config['Global'] - # build dataloader - config['Eval']['dataset']['name'] = config['Train']['dataset']['name'] - config['Eval']['dataset']['data_dir'] = config['Train']['dataset'][ - 'data_dir'] - config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][ - 'label_file_list'] - eval_dataloader = build_dataloader(config, 'Eval', device, logger) - - # build post process - post_process_class = build_post_process(config['PostProcess'], - global_config) - - # build model - # for rec algorithm - if hasattr(post_process_class, 'character'): - char_num = len(getattr(post_process_class, 'character')) - config['Architecture']["Head"]['out_channels'] = char_num - - #set return_features = True - config['Architecture']["Head"]["return_feats"] = True - - model = build_model(config['Architecture']) - - best_model_dict = load_model(config, model) - if len(best_model_dict): - logger.info('metric in ckpt ***************') - for k, v in best_model_dict.items(): - logger.info('{}:{}'.format(k, v)) - - # get features from train data - char_center = program.get_center(model, eval_dataloader, post_process_class) - - #serialize to disk - with open("train_center.pkl", 'wb') as f: - pickle.dump(char_center, f) - return - - -if __name__ == '__main__': - config, device, logger, vdl_writer = program.preprocess() - main() diff --git a/tools/export_model.py b/tools/export_model.py 
deleted file mode 100644 index 84a822e..0000000 --- a/tools/export_model.py +++ /dev/null @@ -1,187 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, ".."))) - -import argparse - -import paddle -from paddle.jit import to_static - -from ppocr.modeling.architectures import build_model -from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import load_model -from ppocr.utils.logging import get_logger -from tools.program import load_config, merge_config, ArgsParser - - -def export_single_model(model, - arch_config, - save_path, - logger, - input_shape=None, - quanter=None): - if arch_config["algorithm"] == "SRN": - max_text_length = arch_config["Head"]["max_text_length"] - other_shape = [ - paddle.static.InputSpec( - shape=[None, 1, 64, 256], dtype="float32"), [ - paddle.static.InputSpec( - shape=[None, 256, 1], - dtype="int64"), paddle.static.InputSpec( - shape=[None, max_text_length, 1], dtype="int64"), - paddle.static.InputSpec( - shape=[None, 8, max_text_length, max_text_length], - dtype="int64"), paddle.static.InputSpec( - shape=[None, 8, max_text_length, max_text_length], - dtype="int64") - ] - ] - model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] == "SAR": - other_shape = [ - 
paddle.static.InputSpec( - shape=[None, 3, 48, 160], dtype="float32"), - [paddle.static.InputSpec( - shape=[None], dtype="float32")] - ] - model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] == "SVTR": - if arch_config["Head"]["name"] == 'MultiHead': - other_shape = [ - paddle.static.InputSpec( - shape=[None, 3, 48, -1], dtype="float32"), - ] - else: - other_shape = [ - paddle.static.InputSpec( - shape=[None] + input_shape, dtype="float32"), - ] - model = to_static(model, input_spec=other_shape) - elif arch_config["algorithm"] == "PREN": - other_shape = [ - paddle.static.InputSpec( - shape=[None, 3, 64, 512], dtype="float32"), - ] - model = to_static(model, input_spec=other_shape) - else: - infer_shape = [3, -1, -1] - if arch_config["model_type"] == "rec": - infer_shape = [3, 32, -1] # for rec model, H must be 32 - if "Transform" in arch_config and arch_config[ - "Transform"] is not None and arch_config["Transform"][ - "name"] == "TPS": - logger.info( - "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training" - ) - infer_shape[-1] = 100 - if arch_config["algorithm"] == "NRTR": - infer_shape = [1, 32, 100] - elif arch_config["model_type"] == "table": - infer_shape = [3, 488, 488] - model = to_static( - model, - input_spec=[ - paddle.static.InputSpec( - shape=[None] + infer_shape, dtype="float32") - ]) - - if quanter is None: - paddle.jit.save(model, save_path) - else: - quanter.save_quantized_model(model, save_path) - logger.info("inference model is saved to {}".format(save_path)) - return - - -def main(): - FLAGS = ArgsParser().parse_args() - config = load_config(FLAGS.config) - config = merge_config(config, FLAGS.opt) - logger = get_logger() - # build post process - - post_process_class = build_post_process(config["PostProcess"], - config["Global"]) - - # build model - # for rec algorithm - if hasattr(post_process_class, "character"): - char_num = 
len(getattr(post_process_class, "character")) - if config["Architecture"]["algorithm"] in ["Distillation", - ]: # distillation model - for key in config["Architecture"]["Models"]: - if config["Architecture"]["Models"][key]["Head"][ - "name"] == 'MultiHead': # multi head - out_channels_list = {} - if config['PostProcess'][ - 'name'] == 'DistillationSARLabelDecode': - char_num = char_num - 2 - out_channels_list['CTCLabelDecode'] = char_num - out_channels_list['SARLabelDecode'] = char_num + 2 - config['Architecture']['Models'][key]['Head'][ - 'out_channels_list'] = out_channels_list - else: - config["Architecture"]["Models"][key]["Head"][ - "out_channels"] = char_num - # just one final tensor needs to exported for inference - config["Architecture"]["Models"][key][ - "return_all_feats"] = False - elif config['Architecture']['Head'][ - 'name'] == 'MultiHead': # multi head - out_channels_list = {} - char_num = len(getattr(post_process_class, 'character')) - if config['PostProcess']['name'] == 'SARLabelDecode': - char_num = char_num - 2 - out_channels_list['CTCLabelDecode'] = char_num - out_channels_list['SARLabelDecode'] = char_num + 2 - config['Architecture']['Head'][ - 'out_channels_list'] = out_channels_list - else: # base rec model - config["Architecture"]["Head"]["out_channels"] = char_num - - model = build_model(config["Architecture"]) - load_model(config, model) - model.eval() - - save_path = config["Global"]["save_inference_dir"] - - arch_config = config["Architecture"] - - if arch_config["algorithm"] == "SVTR" and arch_config["Head"][ - "name"] != 'MultiHead': - input_shape = config["Eval"]["dataset"]["transforms"][-2][ - 'SVTRRecResizeImg']['image_shape'] - else: - input_shape = None - - if arch_config["algorithm"] in ["Distillation", ]: # distillation model - archs = list(arch_config["Models"].values()) - for idx, name in enumerate(model.model_name_list): - sub_model_save_path = os.path.join(save_path, name, "inference") - 
export_single_model(model.model_list[idx], archs[idx], - sub_model_save_path, logger) - else: - save_path = os.path.join(save_path, "inference") - export_single_model( - model, arch_config, save_path, logger, input_shape=input_shape) - - -if __name__ == "__main__": - main() diff --git a/tools/infer_cls.py b/tools/infer_cls.py deleted file mode 100644 index 7fd6b53..0000000 --- a/tools/infer_cls.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) - -os.environ["FLAGS_allocator_strategy"] = 'auto_growth' - -import paddle - -from ppocr.data import create_operators, transform -from ppocr.modeling.architectures import build_model -from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import load_model -from ppocr.utils.utility import get_image_file_list -import tools.program as program - - -def main(): - global_config = config['Global'] - - # build post process - post_process_class = build_post_process(config['PostProcess'], - global_config) - - # build model - model = build_model(config['Architecture']) - - load_model(config, model) - - # create data ops - transforms = [] - for op in config['Eval']['dataset']['transforms']: - op_name = list(op)[0] - if 'Label' in op_name: - continue - elif op_name == 'KeepKeys': - op[op_name]['keep_keys'] = ['image'] - elif op_name == "SSLRotateResize": - op[op_name]["mode"] = "test" - transforms.append(op) - global_config['infer_mode'] = True - ops = create_operators(transforms, global_config) - - model.eval() - for file in get_image_file_list(config['Global']['infer_img']): - logger.info("infer_img: {}".format(file)) - with open(file, 'rb') as f: - img = f.read() - data = {'image': img} - batch = transform(data, ops) - - images = np.expand_dims(batch[0], axis=0) - images = paddle.to_tensor(images) - preds = model(images) - post_result = post_process_class(preds) - for rec_result in post_result: - logger.info('\t result: {}'.format(rec_result)) - logger.info("success!") - - -if __name__ == '__main__': - config, device, logger, vdl_writer = program.preprocess() - main() diff --git a/tools/infer_det.py b/tools/infer_det.py deleted file mode 100644 
index 1aceced..0000000 --- a/tools/infer_det.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) - -os.environ["FLAGS_allocator_strategy"] = 'auto_growth' - -import cv2 -import json -import paddle - -from ppocr.data import create_operators, transform -from ppocr.modeling.architectures import build_model -from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import load_model -from ppocr.utils.utility import get_image_file_list -import tools.program as program - - -def draw_det_res(dt_boxes, config, img, img_name, save_path): - if len(dt_boxes) > 0: - import cv2 - src_im = img - for box in dt_boxes: - box = box.astype(np.int32).reshape((-1, 1, 2)) - cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) - if not os.path.exists(save_path): - os.makedirs(save_path) - save_path = os.path.join(save_path, os.path.basename(img_name)) - cv2.imwrite(save_path, src_im) - logger.info("The detected Image saved in {}".format(save_path)) - - -@paddle.no_grad() -def main(): - global_config = config['Global'] - - # build model - model = 
build_model(config['Architecture']) - - load_model(config, model) - # build post process - post_process_class = build_post_process(config['PostProcess']) - - # create data ops - transforms = [] - for op in config['Eval']['dataset']['transforms']: - op_name = list(op)[0] - if 'Label' in op_name: - continue - elif op_name == 'KeepKeys': - op[op_name]['keep_keys'] = ['image', 'shape'] - transforms.append(op) - - ops = create_operators(transforms, global_config) - - save_res_path = config['Global']['save_res_path'] - if not os.path.exists(os.path.dirname(save_res_path)): - os.makedirs(os.path.dirname(save_res_path)) - - model.eval() - with open(save_res_path, "wb") as fout: - for file in get_image_file_list(config['Global']['infer_img']): - logger.info("infer_img: {}".format(file)) - with open(file, 'rb') as f: - img = f.read() - data = {'image': img} - batch = transform(data, ops) - - images = np.expand_dims(batch[0], axis=0) - shape_list = np.expand_dims(batch[1], axis=0) - images = paddle.to_tensor(images) - preds = model(images) - post_result = post_process_class(preds, shape_list) - - src_img = cv2.imread(file) - - dt_boxes_json = [] - # parser boxes if post_result is dict - if isinstance(post_result, dict): - det_box_json = {} - for k in post_result.keys(): - boxes = post_result[k][0]['points'] - dt_boxes_list = [] - for box in boxes: - tmp_json = {"transcription": ""} - tmp_json['points'] = box.tolist() - dt_boxes_list.append(tmp_json) - det_box_json[k] = dt_boxes_list - save_det_path = os.path.dirname(config['Global'][ - 'save_res_path']) + "/det_results_{}/".format(k) - draw_det_res(boxes, config, src_img, file, save_det_path) - else: - boxes = post_result[0]['points'] - dt_boxes_json = [] - # write result - for box in boxes: - tmp_json = {"transcription": ""} - tmp_json['points'] = box.tolist() - dt_boxes_json.append(tmp_json) - save_det_path = os.path.dirname(config['Global'][ - 'save_res_path']) + "/det_results/" - draw_det_res(boxes, config, src_img, file, 
save_det_path) - otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n" - fout.write(otstr.encode()) - - logger.info("success!") - - -if __name__ == '__main__': - config, device, logger, vdl_writer = program.preprocess() - main() diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py deleted file mode 100644 index d3e6b28..0000000 --- a/tools/infer_e2e.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) - -os.environ["FLAGS_allocator_strategy"] = 'auto_growth' - -import cv2 -import json -import paddle - -from ppocr.data import create_operators, transform -from ppocr.modeling.architectures import build_model -from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import load_model -from ppocr.utils.utility import get_image_file_list -import tools.program as program - - -def draw_e2e_res(dt_boxes, strs, config, img, img_name): - if len(dt_boxes) > 0: - src_im = img - for box, str in zip(dt_boxes, strs): - box = box.astype(np.int32).reshape((-1, 1, 2)) - cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2) - cv2.putText( - src_im, - str, - org=(int(box[0, 0, 0]), int(box[0, 0, 1])), - fontFace=cv2.FONT_HERSHEY_COMPLEX, - fontScale=0.7, - color=(0, 255, 0), - thickness=1) - save_det_path = os.path.dirname(config['Global'][ - 'save_res_path']) + "/e2e_results/" - if not os.path.exists(save_det_path): - os.makedirs(save_det_path) - save_path = os.path.join(save_det_path, os.path.basename(img_name)) - cv2.imwrite(save_path, src_im) - logger.info("The e2e Image saved in {}".format(save_path)) - - -def main(): - global_config = config['Global'] - - # build model - model = build_model(config['Architecture']) - - load_model(config, model) - - # build post process - post_process_class = build_post_process(config['PostProcess'], - global_config) - - # create data ops - transforms = [] - for op in config['Eval']['dataset']['transforms']: - op_name = list(op)[0] - if 'Label' in op_name: - continue - elif op_name == 'KeepKeys': - op[op_name]['keep_keys'] = ['image', 'shape'] - transforms.append(op) - - ops = create_operators(transforms, 
global_config) - - save_res_path = config['Global']['save_res_path'] - if not os.path.exists(os.path.dirname(save_res_path)): - os.makedirs(os.path.dirname(save_res_path)) - - model.eval() - with open(save_res_path, "wb") as fout: - for file in get_image_file_list(config['Global']['infer_img']): - logger.info("infer_img: {}".format(file)) - with open(file, 'rb') as f: - img = f.read() - data = {'image': img} - batch = transform(data, ops) - images = np.expand_dims(batch[0], axis=0) - shape_list = np.expand_dims(batch[1], axis=0) - images = paddle.to_tensor(images) - preds = model(images) - post_result = post_process_class(preds, shape_list) - points, strs = post_result['points'], post_result['texts'] - # write result - dt_boxes_json = [] - for poly, str in zip(points, strs): - tmp_json = {"transcription": str} - tmp_json['points'] = poly.tolist() - dt_boxes_json.append(tmp_json) - otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n" - fout.write(otstr.encode()) - src_img = cv2.imread(file) - draw_e2e_res(points, strs, config, src_img, file) - logger.info("success!") - - -if __name__ == '__main__': - config, device, logger, vdl_writer = program.preprocess() - main() diff --git a/tools/infer_kie.py b/tools/infer_kie.py deleted file mode 100644 index 187b27a..0000000 --- a/tools/infer_kie.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import paddle.nn.functional as F - -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) - -os.environ["FLAGS_allocator_strategy"] = 'auto_growth' - -import cv2 -import paddle - -from ppocr.data import create_operators, transform -from ppocr.modeling.architectures import build_model -from ppocr.utils.save_load import load_model -import tools.program as program -import time - - -def read_class_list(filepath): - dict = {} - with open(filepath, "r") as f: - lines = f.readlines() - for line in lines: - key, value = line.split(" ") - dict[key] = value.rstrip() - return dict - - -def draw_kie_result(batch, node, idx_to_cls, count): - img = batch[6].copy() - boxes = batch[7] - h, w = img.shape[:2] - pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255 - max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1) - node_pred_label = max_idx.numpy().tolist() - node_pred_score = max_value.numpy().tolist() - - for i, box in enumerate(boxes): - if i >= len(node_pred_label): - break - new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]], - [box[0], box[3]]] - Pts = np.array([new_box], np.int32) - cv2.polylines( - img, [Pts.reshape((-1, 1, 2))], - True, - color=(255, 255, 0), - thickness=1) - x_min = int(min([point[0] for point in new_box])) - y_min = int(min([point[1] for point in new_box])) - - pred_label = str(node_pred_label[i]) - if pred_label in idx_to_cls: - pred_label = idx_to_cls[pred_label] - pred_score = '{:.2f}'.format(node_pred_score[i]) - text = pred_label + '(' + pred_score + ')' - cv2.putText(pred_img, text, (x_min * 2, y_min), - cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1) - vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255 - vis_img[:, :w] = img - vis_img[:, w:] = pred_img - 
save_kie_path = os.path.dirname(config['Global'][ - 'save_res_path']) + "/kie_results/" - if not os.path.exists(save_kie_path): - os.makedirs(save_kie_path) - save_path = os.path.join(save_kie_path, str(count) + ".png") - cv2.imwrite(save_path, vis_img) - logger.info("The Kie Image saved in {}".format(save_path)) - -def write_kie_result(fout, node, data): - """ - Write infer result to output file, sorted by the predict label of each line. - The format keeps the same as the input with additional score attribute. - """ - import json - label = data['label'] - annotations = json.loads(label) - max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1) - node_pred_label = max_idx.numpy().tolist() - node_pred_score = max_value.numpy().tolist() - res = [] - for i, label in enumerate(node_pred_label): - pred_score = '{:.2f}'.format(node_pred_score[i]) - pred_res = { - 'label': label, - 'transcription': annotations[i]['transcription'], - 'score': pred_score, - 'points': annotations[i]['points'], - } - res.append(pred_res) - res.sort(key=lambda x: x['label']) - fout.writelines([json.dumps(res, ensure_ascii=False) + '\n']) - -def main(): - global_config = config['Global'] - - # build model - model = build_model(config['Architecture']) - load_model(config, model) - - # create data ops - transforms = [] - for op in config['Eval']['dataset']['transforms']: - transforms.append(op) - - data_dir = config['Eval']['dataset']['data_dir'] - - ops = create_operators(transforms, global_config) - - save_res_path = config['Global']['save_res_path'] - class_path = config['Global']['class_path'] - idx_to_cls = read_class_list(class_path) - if not os.path.exists(os.path.dirname(save_res_path)): - os.makedirs(os.path.dirname(save_res_path)) - - model.eval() - - warmup_times = 0 - count_t = [] - with open(save_res_path, "w") as fout: - with open(config['Global']['infer_img'], "rb") as f: - lines = f.readlines() - for index, data_line in enumerate(lines): - if index == 10: - warmup_t = 
time.time() - data_line = data_line.decode('utf-8') - substr = data_line.strip("\n").split("\t") - img_path, label = data_dir + "/" + substr[0], substr[1] - data = {'img_path': img_path, 'label': label} - with open(data['img_path'], 'rb') as f: - img = f.read() - data['image'] = img - st = time.time() - batch = transform(data, ops) - batch_pred = [0] * len(batch) - for i in range(len(batch)): - batch_pred[i] = paddle.to_tensor( - np.expand_dims( - batch[i], axis=0)) - st = time.time() - node, edge = model(batch_pred) - node = F.softmax(node, -1) - count_t.append(time.time() - st) - draw_kie_result(batch, node, idx_to_cls, index) - write_kie_result(fout, node, data) - fout.close() - logger.info("success!") - logger.info("It took {} s for predict {} images.".format( - np.sum(count_t), len(count_t))) - ips = len(count_t[warmup_times:]) / np.sum(count_t[warmup_times:]) - logger.info("The ips is {} images/s".format(ips)) - - -if __name__ == '__main__': - config, device, logger, vdl_writer = program.preprocess() - main() diff --git a/tools/infer_rec.py b/tools/infer_rec.py index a08fa25..3ec18d0 100644 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -15,6 +15,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from cgi import print_arguments import numpy as np @@ -48,6 +49,7 @@ def main(): # build model if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) + if config['Architecture']["algorithm"] in ["Distillation", ]: # distillation model for key in config['Architecture']["Models"]: @@ -103,15 +105,16 @@ def main(): ops = create_operators(transforms, global_config) save_res_path = config['Global'].get('save_res_path', - "./output/rec/predicts_rec.txt") + "./usr/yyx/data/predicts_rec.txt") if not os.path.exists(os.path.dirname(save_res_path)): os.makedirs(os.path.dirname(save_res_path)) model.eval() - with open(save_res_path, "w") as fout: - for file in 
get_image_file_list(config['Global']['infer_img']): - logger.info("infer_img: {}".format(file)) + with open(save_res_path, "w",encoding = "utf-8") as fout: + from tqdm import tqdm + for file in tqdm(get_image_file_list(config['Global']['infer_img'])): + # logger.info("infer_img: {}".format(file)) with open(file, 'rb') as f: img = f.read() data = {'image': img} @@ -141,6 +144,10 @@ def main(): else: preds = model(images) post_result = post_process_class(preds) + # print("#"*20) + # print(preds) + # print(post_result) + # print("#"*20) info = None if isinstance(post_result, dict): rec_info = dict() @@ -153,11 +160,12 @@ def main(): info = json.dumps(rec_info, ensure_ascii=False) else: if len(post_result[0]) >= 2: - info = post_result[0][0] + "\t" + str(post_result[0][1]) + info = post_result[0][0] if info is not None: - logger.info("\t result: {}".format(info)) - fout.write(file + "\t" + info + "\n") + # logger.info("\t result: {}".format(info)) + # print(info) + fout.write(file + "," + info + "\n") logger.info("success!") diff --git a/tools/infer_rec_my.py b/tools/infer_rec_my.py deleted file mode 100644 index 3ec18d0..0000000 --- a/tools/infer_rec_my.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from cgi import print_arguments - -import numpy as np - -import os -import sys -import json - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) - -os.environ["FLAGS_allocator_strategy"] = 'auto_growth' - -import paddle - -from ppocr.data import create_operators, transform -from ppocr.modeling.architectures import build_model -from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import load_model -from ppocr.utils.utility import get_image_file_list -import tools.program as program - - -def main(): - global_config = config['Global'] - - # build post process - post_process_class = build_post_process(config['PostProcess'], - global_config) - - # build model - if hasattr(post_process_class, 'character'): - char_num = len(getattr(post_process_class, 'character')) - - if config['Architecture']["algorithm"] in ["Distillation", - ]: # distillation model - for key in config['Architecture']["Models"]: - if config['Architecture']['Models'][key]['Head'][ - 'name'] == 'MultiHead': # for multi head - out_channels_list = {} - if config['PostProcess'][ - 'name'] == 'DistillationSARLabelDecode': - char_num = char_num - 2 - out_channels_list['CTCLabelDecode'] = char_num - out_channels_list['SARLabelDecode'] = char_num + 2 - config['Architecture']['Models'][key]['Head'][ - 'out_channels_list'] = out_channels_list - else: - config['Architecture']["Models"][key]["Head"][ - 'out_channels'] = char_num - elif config['Architecture']['Head'][ - 'name'] == 'MultiHead': # for multi head loss - out_channels_list = {} - if config['PostProcess']['name'] == 'SARLabelDecode': - char_num = char_num - 2 - out_channels_list['CTCLabelDecode'] = char_num - out_channels_list['SARLabelDecode'] = char_num + 2 - config['Architecture']['Head'][ - 'out_channels_list'] = 
out_channels_list - else: # base rec model - config['Architecture']["Head"]['out_channels'] = char_num - - model = build_model(config['Architecture']) - - load_model(config, model) - - # create data ops - transforms = [] - for op in config['Eval']['dataset']['transforms']: - op_name = list(op)[0] - if 'Label' in op_name: - continue - elif op_name in ['RecResizeImg']: - op[op_name]['infer_mode'] = True - elif op_name == 'KeepKeys': - if config['Architecture']['algorithm'] == "SRN": - op[op_name]['keep_keys'] = [ - 'image', 'encoder_word_pos', 'gsrm_word_pos', - 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2' - ] - elif config['Architecture']['algorithm'] == "SAR": - op[op_name]['keep_keys'] = ['image', 'valid_ratio'] - else: - op[op_name]['keep_keys'] = ['image'] - transforms.append(op) - global_config['infer_mode'] = True - ops = create_operators(transforms, global_config) - - save_res_path = config['Global'].get('save_res_path', - "./usr/yyx/data/predicts_rec.txt") - if not os.path.exists(os.path.dirname(save_res_path)): - os.makedirs(os.path.dirname(save_res_path)) - - model.eval() - - with open(save_res_path, "w",encoding = "utf-8") as fout: - from tqdm import tqdm - for file in tqdm(get_image_file_list(config['Global']['infer_img'])): - # logger.info("infer_img: {}".format(file)) - with open(file, 'rb') as f: - img = f.read() - data = {'image': img} - batch = transform(data, ops) - if config['Architecture']['algorithm'] == "SRN": - encoder_word_pos_list = np.expand_dims(batch[1], axis=0) - gsrm_word_pos_list = np.expand_dims(batch[2], axis=0) - gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0) - gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0) - - others = [ - paddle.to_tensor(encoder_word_pos_list), - paddle.to_tensor(gsrm_word_pos_list), - paddle.to_tensor(gsrm_slf_attn_bias1_list), - paddle.to_tensor(gsrm_slf_attn_bias2_list) - ] - if config['Architecture']['algorithm'] == "SAR": - valid_ratio = np.expand_dims(batch[-1], axis=0) - img_metas 
= [paddle.to_tensor(valid_ratio)] - - images = np.expand_dims(batch[0], axis=0) - images = paddle.to_tensor(images) - if config['Architecture']['algorithm'] == "SRN": - preds = model(images, others) - elif config['Architecture']['algorithm'] == "SAR": - preds = model(images, img_metas) - else: - preds = model(images) - post_result = post_process_class(preds) - # print("#"*20) - # print(preds) - # print(post_result) - # print("#"*20) - info = None - if isinstance(post_result, dict): - rec_info = dict() - for key in post_result: - if len(post_result[key][0]) >= 2: - rec_info[key] = { - "label": post_result[key][0][0], - "score": float(post_result[key][0][1]), - } - info = json.dumps(rec_info, ensure_ascii=False) - else: - if len(post_result[0]) >= 2: - info = post_result[0][0] - - if info is not None: - # logger.info("\t result: {}".format(info)) - # print(info) - fout.write(file + "," + info + "\n") - logger.info("success!") - - -if __name__ == '__main__': - config, device, logger, vdl_writer = program.preprocess() - main() diff --git a/tools/infer_table.py b/tools/infer_table.py deleted file mode 100644 index 66c2da4..0000000 --- a/tools/infer_table.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -import os -import sys -import json - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) - -os.environ["FLAGS_allocator_strategy"] = 'auto_growth' - -import paddle -from paddle.jit import to_static - -from ppocr.data import create_operators, transform -from ppocr.modeling.architectures import build_model -from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import load_model -from ppocr.utils.utility import get_image_file_list -import tools.program as program -import cv2 - - -def main(config, device, logger, vdl_writer): - global_config = config['Global'] - - # build post process - post_process_class = build_post_process(config['PostProcess'], - global_config) - - # build model - if hasattr(post_process_class, 'character'): - config['Architecture']["Head"]['out_channels'] = len( - getattr(post_process_class, 'character')) - - model = build_model(config['Architecture']) - - load_model(config, model) - - # create data ops - transforms = [] - use_padding = False - for op in config['Eval']['dataset']['transforms']: - op_name = list(op)[0] - if 'Label' in op_name: - continue - if op_name == 'KeepKeys': - op[op_name]['keep_keys'] = ['image'] - if op_name == "ResizeTableImage": - use_padding = True - padding_max_len = op['ResizeTableImage']['max_len'] - transforms.append(op) - - global_config['infer_mode'] = True - ops = create_operators(transforms, global_config) - - model.eval() - for file in get_image_file_list(config['Global']['infer_img']): - logger.info("infer_img: {}".format(file)) - with open(file, 'rb') as f: - img = f.read() - data = {'image': img} - batch = transform(data, ops) - images = np.expand_dims(batch[0], axis=0) - images = paddle.to_tensor(images) - preds = model(images) - post_result = 
post_process_class(preds) - res_html_code = post_result['res_html_code'] - res_loc = post_result['res_loc'] - img = cv2.imread(file) - imgh, imgw = img.shape[0:2] - res_loc_final = [] - for rno in range(len(res_loc[0])): - x0, y0, x1, y1 = res_loc[0][rno] - left = max(int(imgw * x0), 0) - top = max(int(imgh * y0), 0) - right = min(int(imgw * x1), imgw - 1) - bottom = min(int(imgh * y1), imgh - 1) - cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2) - res_loc_final.append([left, top, right, bottom]) - res_loc_str = json.dumps(res_loc_final) - logger.info("result: {}, {}".format(res_html_code, res_loc_final)) - logger.info("success!") - - -if __name__ == '__main__': - config, device, logger, vdl_writer = program.preprocess() - main(config, device, logger, vdl_writer) diff --git a/tools/infer_vqa_token_ser.py b/tools/infer_vqa_token_ser.py deleted file mode 100644 index 83ed72b..0000000 --- a/tools/infer_vqa_token_ser.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) - -os.environ["FLAGS_allocator_strategy"] = 'auto_growth' -import cv2 -import json -import paddle - -from ppocr.data import create_operators, transform -from ppocr.modeling.architectures import build_model -from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import load_model -from ppocr.utils.visual import draw_ser_results -from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps -import tools.program as program - - -def to_tensor(data): - import numbers - from collections import defaultdict - data_dict = defaultdict(list) - to_tensor_idxs = [] - for idx, v in enumerate(data): - if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)): - if idx not in to_tensor_idxs: - to_tensor_idxs.append(idx) - data_dict[idx].append(v) - for idx in to_tensor_idxs: - data_dict[idx] = paddle.to_tensor(data_dict[idx]) - return list(data_dict.values()) - - -class SerPredictor(object): - def __init__(self, config): - global_config = config['Global'] - - # build post process - self.post_process_class = build_post_process(config['PostProcess'], - global_config) - - # build model - self.model = build_model(config['Architecture']) - - load_model( - config, self.model, model_type=config['Architecture']["model_type"]) - - from paddleocr import PaddleOCR - - self.ocr_engine = PaddleOCR(use_angle_cls=False, show_log=False) - - # create data ops - transforms = [] - for op in config['Eval']['dataset']['transforms']: - op_name = list(op)[0] - if 'Label' in op_name: - op[op_name]['ocr_engine'] = self.ocr_engine - elif op_name == 'KeepKeys': - op[op_name]['keep_keys'] = [ - 'input_ids', 'labels', 'bbox', 'image', 'attention_mask', - 
'token_type_ids', 'segment_offset_id', 'ocr_info', - 'entities' - ] - - transforms.append(op) - global_config['infer_mode'] = True - self.ops = create_operators(config['Eval']['dataset']['transforms'], - global_config) - self.model.eval() - - def __call__(self, img_path): - with open(img_path, 'rb') as f: - img = f.read() - data = {'image': img} - batch = transform(data, self.ops) - batch = to_tensor(batch) - preds = self.model(batch) - post_result = self.post_process_class( - preds, - attention_masks=batch[4], - segment_offset_ids=batch[6], - ocr_infos=batch[7]) - return post_result, batch - - -if __name__ == '__main__': - config, device, logger, vdl_writer = program.preprocess() - os.makedirs(config['Global']['save_res_path'], exist_ok=True) - - ser_engine = SerPredictor(config) - - infer_imgs = get_image_file_list(config['Global']['infer_img']) - with open( - os.path.join(config['Global']['save_res_path'], - "infer_results.txt"), - "w", - encoding='utf-8') as fout: - for idx, img_path in enumerate(infer_imgs): - save_img_path = os.path.join( - config['Global']['save_res_path'], - os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg") - logger.info("process: [{}/{}], save result to {}".format( - idx, len(infer_imgs), save_img_path)) - - result, _ = ser_engine(img_path) - result = result[0] - fout.write(img_path + "\t" + json.dumps( - { - "ocr_info": result, - }, ensure_ascii=False) + "\n") - img_res = draw_ser_results(img_path, result) - cv2.imwrite(save_img_path, img_res) diff --git a/tools/infer_vqa_token_ser_re.py b/tools/infer_vqa_token_ser_re.py deleted file mode 100644 index 6210f7f..0000000 --- a/tools/infer_vqa_token_ser_re.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -import os -import sys - -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..'))) - -os.environ["FLAGS_allocator_strategy"] = 'auto_growth' -import cv2 -import json -import paddle -import paddle.distributed as dist - -from ppocr.data import create_operators, transform -from ppocr.modeling.architectures import build_model -from ppocr.postprocess import build_post_process -from ppocr.utils.save_load import load_model -from ppocr.utils.visual import draw_re_results -from ppocr.utils.logging import get_logger -from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict -from tools.program import ArgsParser, load_config, merge_config, check_gpu -from tools.infer_vqa_token_ser import SerPredictor - - -class ReArgsParser(ArgsParser): - def __init__(self): - super(ReArgsParser, self).__init__() - self.add_argument( - "-c_ser", "--config_ser", help="ser configuration file to use") - self.add_argument( - "-o_ser", - "--opt_ser", - nargs='+', - help="set ser configuration options ") - - def parse_args(self, argv=None): - args = super(ReArgsParser, self).parse_args(argv) - assert args.config_ser is not None, \ - "Please specify --config_ser=ser_configure_file_path." 
- args.opt_ser = self._parse_opt(args.opt_ser) - return args - - -def make_input(ser_inputs, ser_results): - entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2} - - entities = ser_inputs[8][0] - ser_results = ser_results[0] - assert len(entities) == len(ser_results) - - # entities - start = [] - end = [] - label = [] - entity_idx_dict = {} - for i, (res, entity) in enumerate(zip(ser_results, entities)): - if res['pred'] == 'O': - continue - entity_idx_dict[len(start)] = i - start.append(entity['start']) - end.append(entity['end']) - label.append(entities_labels[res['pred']]) - entities = dict(start=start, end=end, label=label) - - # relations - head = [] - tail = [] - for i in range(len(entities["label"])): - for j in range(len(entities["label"])): - if entities["label"][i] == 1 and entities["label"][j] == 2: - head.append(i) - tail.append(j) - - relations = dict(head=head, tail=tail) - - batch_size = ser_inputs[0].shape[0] - entities_batch = [] - relations_batch = [] - entity_idx_dict_batch = [] - for b in range(batch_size): - entities_batch.append(entities) - relations_batch.append(relations) - entity_idx_dict_batch.append(entity_idx_dict) - - ser_inputs[8] = entities_batch - ser_inputs.append(relations_batch) - # remove ocr_info segment_offset_id and label in ser input - ser_inputs.pop(7) - ser_inputs.pop(6) - ser_inputs.pop(1) - return ser_inputs, entity_idx_dict_batch - - -class SerRePredictor(object): - def __init__(self, config, ser_config): - self.ser_engine = SerPredictor(ser_config) - - # init re model - global_config = config['Global'] - - # build post process - self.post_process_class = build_post_process(config['PostProcess'], - global_config) - - # build model - self.model = build_model(config['Architecture']) - - load_model( - config, self.model, model_type=config['Architecture']["model_type"]) - - self.model.eval() - - def __call__(self, img_path): - ser_results, ser_inputs = self.ser_engine(img_path) - paddle.save(ser_inputs, 
'ser_inputs.npy') - paddle.save(ser_results, 'ser_results.npy') - re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results) - preds = self.model(re_input) - post_result = self.post_process_class( - preds, - ser_results=ser_results, - entity_idx_dict_batch=entity_idx_dict_batch) - return post_result - - -def preprocess(): - FLAGS = ReArgsParser().parse_args() - config = load_config(FLAGS.config) - config = merge_config(config, FLAGS.opt) - - ser_config = load_config(FLAGS.config_ser) - ser_config = merge_config(ser_config, FLAGS.opt_ser) - - logger = get_logger() - - # check if set use_gpu=True in paddlepaddle cpu version - use_gpu = config['Global']['use_gpu'] - check_gpu(use_gpu) - - device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' - device = paddle.set_device(device) - - logger.info('{} re config {}'.format('*' * 10, '*' * 10)) - print_dict(config, logger) - logger.info('\n') - logger.info('{} ser config {}'.format('*' * 10, '*' * 10)) - print_dict(ser_config, logger) - logger.info('train with paddle {} and device {}'.format(paddle.__version__, - device)) - return config, ser_config, device, logger - - -if __name__ == '__main__': - config, ser_config, device, logger = preprocess() - os.makedirs(config['Global']['save_res_path'], exist_ok=True) - - ser_re_engine = SerRePredictor(config, ser_config) - - infer_imgs = get_image_file_list(config['Global']['infer_img']) - with open( - os.path.join(config['Global']['save_res_path'], - "infer_results.txt"), - "w", - encoding='utf-8') as fout: - for idx, img_path in enumerate(infer_imgs): - save_img_path = os.path.join( - config['Global']['save_res_path'], - os.path.splitext(os.path.basename(img_path))[0] + "_ser.jpg") - logger.info("process: [{}/{}], save result to {}".format( - idx, len(infer_imgs), save_img_path)) - - result = ser_re_engine(img_path) - result = result[0] - fout.write(img_path + "\t" + json.dumps( - { - "ser_result": result, - }, ensure_ascii=False) + "\n") - img_res 
= draw_re_results(img_path, result) - cv2.imwrite(save_img_path, img_res) diff --git a/tools/test_hubserving.py b/tools/test_hubserving.py deleted file mode 100644 index ec17a94..0000000 --- a/tools/test_hubserving.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import sys -__dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) - -from ppocr.utils.logging import get_logger -logger = get_logger() - -import cv2 -import numpy as np -import time -from PIL import Image -from ppocr.utils.utility import get_image_file_list -from tools.infer.utility import draw_ocr, draw_boxes, str2bool -from ppstructure.utility import draw_structure_result -from ppstructure.predict_system import to_excel - -import requests -import json -import base64 - - -def cv2_to_base64(image): - return base64.b64encode(image).decode('utf8') - - -def draw_server_result(image_file, res): - img = cv2.imread(image_file) - image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) - if len(res) == 0: - return np.array(image) - keys = res[0].keys() - if 'text_region' not in keys: # for ocr_rec, draw function is invalid - logger.info("draw function is invalid for ocr_rec!") - return None - elif 'text' not in keys: # for ocr_det - logger.info("draw text boxes only!") - boxes = [] - for dno in range(len(res)): - 
boxes.append(res[dno]['text_region']) - boxes = np.array(boxes) - draw_img = draw_boxes(image, boxes) - return draw_img - else: # for ocr_system - logger.info("draw boxes and texts!") - boxes = [] - texts = [] - scores = [] - for dno in range(len(res)): - boxes.append(res[dno]['text_region']) - texts.append(res[dno]['text']) - scores.append(res[dno]['confidence']) - boxes = np.array(boxes) - scores = np.array(scores) - draw_img = draw_ocr( - image, boxes, texts, scores, draw_txt=True, drop_score=0.5) - return draw_img - - -def save_structure_res(res, save_folder, image_file): - img = cv2.imread(image_file) - excel_save_folder = os.path.join(save_folder, os.path.basename(image_file)) - os.makedirs(excel_save_folder, exist_ok=True) - # save res - with open( - os.path.join(excel_save_folder, 'res.txt'), 'w', - encoding='utf8') as f: - for region in res: - if region['type'] == 'Table': - excel_path = os.path.join(excel_save_folder, - '{}.xlsx'.format(region['bbox'])) - to_excel(region['res'], excel_path) - elif region['type'] == 'Figure': - x1, y1, x2, y2 = region['bbox'] - print(region['bbox']) - roi_img = img[y1:y2, x1:x2, :] - img_path = os.path.join(excel_save_folder, - '{}.jpg'.format(region['bbox'])) - cv2.imwrite(img_path, roi_img) - else: - for text_result in region['res']: - f.write('{}\n'.format(json.dumps(text_result))) - - -def main(args): - image_file_list = get_image_file_list(args.image_dir) - is_visualize = False - headers = {"Content-type": "application/json"} - cnt = 0 - total_time = 0 - for image_file in image_file_list: - img = open(image_file, 'rb').read() - if img is None: - logger.info("error in loading image:{}".format(image_file)) - continue - img_name = os.path.basename(image_file) - # seed http request - starttime = time.time() - data = {'images': [cv2_to_base64(img)]} - r = requests.post( - url=args.server_url, headers=headers, data=json.dumps(data)) - elapse = time.time() - starttime - total_time += elapse - logger.info("Predict time of %s: 
%.3fs" % (image_file, elapse)) - res = r.json()["results"][0] - logger.info(res) - - if args.visualize: - draw_img = None - if 'structure_table' in args.server_url: - to_excel(res['html'], './{}.xlsx'.format(img_name)) - elif 'structure_system' in args.server_url: - save_structure_res(res['regions'], args.output, image_file) - else: - draw_img = draw_server_result(image_file, res) - if draw_img is not None: - if not os.path.exists(args.output): - os.makedirs(args.output) - cv2.imwrite( - os.path.join(args.output, os.path.basename(image_file)), - draw_img[:, :, ::-1]) - logger.info("The visualized image saved in {}".format( - os.path.join(args.output, os.path.basename(image_file)))) - cnt += 1 - if cnt % 100 == 0: - logger.info("{} processed".format(cnt)) - logger.info("avg time cost: {}".format(float(total_time) / cnt)) - - -def parse_args(): - import argparse - parser = argparse.ArgumentParser(description="args for hub serving") - parser.add_argument("--server_url", type=str, required=True) - parser.add_argument("--image_dir", type=str, required=True) - parser.add_argument("--visualize", type=str2bool, default=False) - parser.add_argument("--output", type=str, default='./hubserving_result') - args = parser.parse_args() - return args - - -if __name__ == '__main__': - args = parse_args() - main(args)