Showing 205 changed files with 35,239 additions and 7 deletions.
@@ -0,0 +1,2 @@
*.pyc
data/*
@@ -0,0 +1,13 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,109 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import sys
import numpy as np
import skimage
import paddle
import signal
import random

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

import copy
from paddle.io import Dataset, DataLoader, BatchSampler, DistributedBatchSampler
import paddle.distributed as dist

from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
from ppocr.data.lmdb_dataset import LMDBDataSet
from ppocr.data.pgnet_dataset import PGDataSet
from ppocr.data.pubtab_dataset import PubTabDataSet

__all__ = ['build_dataloader', 'transform', 'create_operators']


def term_mp(sig_num, frame):
    """ kill all child processes
    """
    pid = os.getpid()
    pgid = os.getpgid(os.getpid())
    print("main proc {} exit, kill process group " "{}".format(pid, pgid))
    os.killpg(pgid, signal.SIGKILL)


def build_dataloader(config, mode, device, logger, seed=None):
    config = copy.deepcopy(config)

    support_dict = [
        'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet'
    ]
    module_name = config[mode]['dataset']['name']
    assert module_name in support_dict, Exception(
        'DataSet only support {}'.format(support_dict))
    assert mode in ['Train', 'Eval', 'Test'
                    ], "Mode should be Train, Eval or Test."

    dataset = eval(module_name)(config, mode, logger, seed)
    loader_config = config[mode]['loader']
    batch_size = loader_config['batch_size_per_card']
    drop_last = loader_config['drop_last']
    shuffle = loader_config['shuffle']
    num_workers = loader_config['num_workers']
    if 'use_shared_memory' in loader_config.keys():
        use_shared_memory = loader_config['use_shared_memory']
    else:
        use_shared_memory = True

    if mode == "Train":
        # Distribute data to multiple cards
        batch_sampler = DistributedBatchSampler(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last)
    else:
        # Distribute data to single card
        batch_sampler = BatchSampler(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=drop_last)

    if 'collate_fn' in loader_config:
        from . import collate_fn
        collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
    else:
        collate_fn = None
    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=device,
        num_workers=num_workers,
        return_list=True,
        use_shared_memory=use_shared_memory,
        collate_fn=collate_fn)

    # support exit using ctrl+c
    signal.signal(signal.SIGINT, term_mp)
    signal.signal(signal.SIGTERM, term_mp)

    return data_loader
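For orientation only, here is a rough sketch of how build_dataloader might be invoked. The loader keys mirror exactly what the function reads; the dataset-level keys (data_dir, label_file_list, transforms) and the Global section are assumptions about SimpleDataSet, which is not part of this diff, so treat them as illustrative rather than authoritative.

import logging
import paddle

# build_dataloader is the function defined above; the module path is omitted in this diff.
logger = logging.getLogger("ppocr")

config = {
    'Global': {},                                    # assumed: dataset classes may read global options
    'Train': {
        'dataset': {
            'name': 'SimpleDataSet',                 # must be one of the supported dataset classes
            'data_dir': './train_data/',             # assumed SimpleDataSet keys
            'label_file_list': ['./train_data/train_list.txt'],
            'transforms': [],
        },
        'loader': {
            'batch_size_per_card': 8,                # keys read by build_dataloader
            'drop_last': True,
            'shuffle': True,
            'num_workers': 4,
            # 'collate_fn': 'DictCollator',          # optional, resolved from the collate_fn module
        },
    },
}

device = paddle.set_device('cpu')                    # or 'gpu:0'
train_loader = build_dataloader(config, 'Train', device, logger)
for batch in train_loader:                           # iterate as with any paddle.io.DataLoader
    break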
@@ -0,0 +1,72 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numbers
import numpy as np
from collections import defaultdict


class DictCollator(object):
    """
    Collate a batch of dict samples: values that are numpy arrays, paddle
    Tensors or numbers are gathered per key and converted to paddle Tensors.
    """

    def __call__(self, batch):
        # todo:support batch operators
        data_dict = defaultdict(list)
        to_tensor_keys = []
        for sample in batch:
            for k, v in sample.items():
                if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
                    if k not in to_tensor_keys:
                        to_tensor_keys.append(k)
                data_dict[k].append(v)
        for k in to_tensor_keys:
            data_dict[k] = paddle.to_tensor(data_dict[k])
        return data_dict


class ListCollator(object):
    """
    Collate a batch of list/tuple samples: positions holding numpy arrays,
    paddle Tensors or numbers are gathered and converted to paddle Tensors;
    the result is returned as a list ordered by position.
    """

    def __call__(self, batch):
        # todo:support batch operators
        data_dict = defaultdict(list)
        to_tensor_idxs = []
        for sample in batch:
            for idx, v in enumerate(sample):
                if isinstance(v, (np.ndarray, paddle.Tensor, numbers.Number)):
                    if idx not in to_tensor_idxs:
                        to_tensor_idxs.append(idx)
                data_dict[idx].append(v)
        for idx in to_tensor_idxs:
            data_dict[idx] = paddle.to_tensor(data_dict[idx])
        return list(data_dict.values())


class SSLRotateCollate(object):
    """
    batch: [
        [(4*3xH*W), (4,)]
        [(4*3xH*W), (4,)]
        ...
    ]
    """

    def __call__(self, batch):
        output = [np.concatenate(d, axis=0) for d in zip(*batch)]
        return output
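As a quick illustration (not part of the commit), the dict collator can be exercised on a hand-built batch; the keys and shapes below are arbitrary.

import numpy as np

# DictCollator is the class defined above; the module path is omitted in this diff.
batch = [
    {'image': np.zeros((3, 32, 100), dtype='float32'), 'label': 1},
    {'image': np.ones((3, 32, 100), dtype='float32'), 'label': 0},
]
collated = DictCollator()(batch)
print(collated['image'].shape)   # [2, 3, 32, 100], now a paddle.Tensor
print(collated['label'])         # Tensor holding [1, 0]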
@@ -0,0 +1,26 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.vision.transforms import ColorJitter as pp_ColorJitter

__all__ = ['ColorJitter']

class ColorJitter(object):
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0, **kwargs):
        self.aug = pp_ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, data):
        image = data['image']
        image = self.aug(image)
        data['image'] = image
        return data
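A small usage sketch; it assumes the incoming 'image' is an HWC uint8 numpy array that paddle.vision's ColorJitter accepts (convert to a PIL.Image first if a given Paddle version requires it).

import numpy as np

# ColorJitter here is the wrapper class defined above.
data = {'image': np.random.randint(0, 255, (48, 320, 3), dtype='uint8')}  # assumed HWC uint8 input
jitter = ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)
data = jitter(data)              # only the 'image' entry is modified
print(data['image'].shape)       # (48, 320, 3)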
@@ -0,0 +1,75 @@
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from .iaa_augment import IaaAugment
from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, RandomCropImgMask
from .make_pse_gt import MakePseGt

from .rec_img_aug import RecAug, RecConAug, RecResizeImg, ClsResizeImg, \
    SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg, PRENResizeImg, \
    SVTRRecResizeImg
from .ssl_img_aug import SSLRotateResize
from .randaugment import RandAugment
from .copy_paste import CopyPaste
from .ColorJitter import ColorJitter
from .operators import *
from .label_ops import *

from .east_process import *
from .sast_process import *
from .pg_process import *
from .gen_table_mask import *

from .vqa import *

from .fce_aug import *
from .fce_targets import FCENetTargets


def transform(data, ops=None):
    """ transform """
    if ops is None:
        ops = []
    for op in ops:
        data = op(data)
        if data is None:
            return None
    return data


def create_operators(op_param_list, global_config=None):
    """
    create operators based on the config
    Args:
        op_param_list(list): a dict list, used to create some operators
        global_config(dict): params applied to every operator
    """
    assert isinstance(op_param_list, list), ('operator config should be a list')
    ops = []
    for operator in op_param_list:
        assert isinstance(operator,
                          dict) and len(operator) == 1, "yaml format error"
        op_name = list(operator)[0]
        param = {} if operator[op_name] is None else operator[op_name]
        if global_config is not None:
            param.update(global_config)
        op = eval(op_name)(**param)
        ops.append(op)
    return ops
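To show how the two helpers fit together, here is a minimal sketch that builds a one-operator pipeline from the yaml-style config list; only ColorJitter comes from this commit, and its parameters are illustrative.

import numpy as np

# transform and create_operators are the helpers defined above.
op_config = [
    {'ColorJitter': {'brightness': 0.5, 'contrast': 0.5}},   # one single-key dict per operator
]
ops = create_operators(op_config, global_config=None)

data = {'image': np.random.randint(0, 255, (48, 320, 3), dtype='uint8')}  # assumed HWC uint8 image
data = transform(data, ops)      # returns None if any operator rejects the sample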