Commit 3d716b6
Rectify code of the LayoutLM series models and adjust it to amp_level mode.

Bourn3z committed Apr 17, 2024
1 parent c698646 commit 3d716b6
Showing 17 changed files with 70 additions and 120 deletions.
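
The theme of the diff: every LayoutLM-series component used to carry its own `use_float16` switch and hand-rolled fp16 casts; those are removed below, and precision is now controlled in one place by the `amp_level` setting. As a rough sketch of what that setting does, using MindSpore's public amp API (the exact wiring inside mindocr's trainer may differ):

```python
import mindspore as ms
from mindspore import nn

net = nn.Dense(768, 7)  # toy stand-in for a LayoutLM classification head
# amp_level "O0" keeps float32 everywhere; "O3" casts the whole network
# to float16, which is the behavior the configs below now request.
net = ms.amp.auto_mixed_precision(net, amp_level="O3")
```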
configs/kie/layoutlmv3/README.md (3 changes: 1 addition & 2 deletions)
@@ -144,7 +144,7 @@ Apart from the dataset setting, please also check the following important args:
 system:
   mode:
   distribute: False # `True` for distributed training, `False` for standalone training
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   val_while_train: True # Validate while training
   drop_overflow_update: False
@@ -157,7 +157,6 @@ model:
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: True
-    use_float16: True
   pretrained:
 ...
 train:
configs/kie/layoutlmv3/ser_layoutlmv3_xfund_zh.yaml (3 changes: 1 addition & 2 deletions)
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: "O0"
+  amp_level: "O3"
   seed: 42
   log_interval: 10
   val_start_epoch: 50
@@ -17,7 +17,6 @@ model:
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: True
-    use_float16: True
   pretrained:

 postprocess:
configs/kie/layoutxlm/re_layoutxlm_xfund_zh.yaml (4 changes: 1 addition & 3 deletions)
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   log_interval: 10
   val_while_train: True
@@ -16,11 +16,9 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: True
-    use_float16: True
   head:
     name: RelationExtractionHead
     use_visual_backbone: True
-    use_float16: True
   pretrained:

 postprocess:
configs/kie/layoutxlm/ser_layoutxlm_xfund_zh.yaml (4 changes: 1 addition & 3 deletions)
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   log_interval: 10
   val_while_train: True
@@ -15,12 +15,10 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: True
-    use_float16: True
   head :
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: True
-    use_float16: True
   pretrained:

 postprocess:
configs/kie/vi_layoutxlm/README.md (4 changes: 1 addition & 3 deletions)
@@ -159,7 +159,7 @@ Apart from the dataset setting, please also check the following important args:
 system:
   mode:
   distribute: False # `True` for distributed training, `False` for standalone training
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   val_while_train: True # Validate while training
   drop_overflow_update: False
@@ -171,12 +171,10 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: False
-    use_float16: True
   head :
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: False
-    use_float16: True
   pretrained:
 ...
 train:
configs/kie/vi_layoutxlm/README_CN.md (4 changes: 1 addition & 3 deletions)
@@ -156,7 +156,7 @@ eval:
 system:
   mode:
   distribute: False # `True` for distributed training, `False` for standalone training
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   val_while_train: True # Validate while training
   drop_overflow_update: False
@@ -168,12 +168,10 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: False
-    use_float16: True
   head :
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: False
-    use_float16: True
   pretrained:
 ...
 train:
configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yaml (6 changes: 2 additions & 4 deletions)
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: "O0"
+  amp_level: "O3"
   seed: 42
   log_interval: 10
   val_while_train: True
@@ -16,11 +16,9 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: False
-    use_float16: True
   head:
     name: RelationExtractionHead
     use_visual_backbone: False
-    use_float16: True
   pretrained:

 postprocess:
@@ -90,11 +88,11 @@ train:
       "bbox",
       "attention_mask",
       "token_type_ids",
-      "image",
       "question",
       "question_label",
       "answer",
       "answer_label",
+      "image",
       "relation_label",
     ]
   net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns
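
Note the `"image"` entry moving to the end of `output_columns`: `net_input_column_index` picks network inputs by position, not by name, so the reorder changes which tensor reaches each forward-function slot. A quick plain-Python illustration; the first column, assumed here to be `"input_ids"`, sits above the visible hunk:

```python
# Columns as listed in the config after this change; "input_ids" is an
# assumption, since the hunk above starts at "bbox".
output_columns = [
    "input_ids", "bbox", "attention_mask", "token_type_ids",
    "question", "question_label", "answer", "answer_label",
    "image", "relation_label",
]
net_input_column_index = [0, 1, 2, 3, 4, 5, 6, 7, 8]
# Everything except "relation_label" is fed to the network, with the
# image tensor now arriving last.
print([output_columns[i] for i in net_input_column_index])
```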
configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh.yaml (4 changes: 1 addition & 3 deletions)
@@ -1,7 +1,7 @@
 system:
   mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
   distribute: False
-  amp_level: 'O0'
+  amp_level: 'O3'
   seed: 42
   log_interval: 10
   val_while_train: True
@@ -15,12 +15,10 @@ model:
     pretrained: True
     num_classes: &num_classes 7
     use_visual_backbone: False
-    use_float16: True
   head :
     name: TokenClassificationHead
     num_classes: 7
     use_visual_backbone: False
-    use_float16: True
   pretrained:

 postprocess:
mindocr/models/backbones/layoutlmv3/configuration.py (3 changes: 1 addition & 2 deletions)
@@ -3,9 +3,8 @@

 @dataclass
 class LayoutLMv3PretrainedConfig:
-    def __init__(self, use_float16=False):
+    def __init__(self):
         pretrained_config = {
-            "use_float16": use_float16,
             "fast_qkv": False,
             "vocab_size": 250002,
             "hidden_size": 768,
mindocr/models/backbones/layoutlmv3/layoutlmv3.py (21 changes: 12 additions & 9 deletions)
@@ -175,7 +175,7 @@ def construct(

         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
-            attention_scores = attention_scores + attention_mask.astype(self.dense_dtype)
+            attention_scores = attention_scores + attention_mask.astype(attention_scores.dtype)

         # Normalize the attention scores to probabilities.
         # Use the trick of the CogView paper to stablize training
@@ -227,11 +227,8 @@ def __init__(self, config):
         self.has_relative_attention_bias = config.has_relative_attention_bias
         self.has_spatial_attention_bias = config.has_spatial_attention_bias
         self.patch_size = config.patch_size
-        self.use_float16 = config.use_float16
-        self.dense_dtype = mstype.float32
-        if self.use_float16 is True:
-            self.dense_dtype = mstype.float16
-        self.min = finfo(self.dense_dtype)
+        self.float32_min = finfo(mstype.float32)
+        self.float16_min = finfo(mstype.float16)
         self.out_channels = 1
         self.use_visual_backbone = True

@@ -342,7 +339,13 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape, dtype
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely. # fp16 compatibility
         extended_attention_mask = extended_attention_mask.astype(dtype)
-        extended_attention_mask = (1.0 - extended_attention_mask) * self.min
+
+        if dtype == mstype.float32:
+            minimum = self.float32_min
+        elif dtype == mstype.float16:
+            minimum = self.float16_min
+
+        extended_attention_mask = (1.0 - extended_attention_mask) * minimum
         return extended_attention_mask

     def get_head_mask(self, head_mask, num_hidden_layers: int, is_attention_chunked: bool = False):
@@ -518,7 +521,7 @@ def construct(


 @register_backbone
-def layoutlmv3(use_float16: bool = True, **kwargs):
-    pretrained_config = LayoutLMv3PretrainedConfig(use_float16)
+def layoutlmv3(**kwargs):
+    pretrained_config = LayoutLMv3PretrainedConfig()
     model = LayoutLMv3Model(pretrained_config)
     return model
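
The backbone refactor above precomputes the finfo minima for both dtypes and selects whichever matches the mask's target dtype, so masked positions still underflow to the softmax floor when the whole network runs in fp16 under `amp_level: O3`. A standalone sketch of the idea (a minimal reimplementation, not the repository's exact helper):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

def extend_attention_mask(mask: Tensor, dtype) -> Tensor:
    # [batch, seq] -> [batch, 1, 1, seq] so it broadcasts over heads and queries.
    ext = ops.expand_dims(ops.expand_dims(mask, 1), 1).astype(dtype)
    # Use the minimum of the *current* dtype: float32's minimum would
    # overflow to -inf if the scores are float16 under O3.
    minimum = np.finfo(np.float16).min if dtype == ms.float16 else np.finfo(np.float32).min
    # Kept positions (mask == 1) contribute 0; masked positions get the
    # dtype minimum, which softmax turns into near-zero probability.
    return (1.0 - ext) * float(minimum)
```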
mindocr/models/backbones/layoutxlm/configuration.py (4 changes: 1 addition & 3 deletions)
@@ -3,13 +3,11 @@

 @dataclass
 class LayoutXLMPretrainedConfig:
-    def __init__(self, use_visual_backbone=True, use_float16=False):
+    def __init__(self, use_visual_backbone=True):
         pretrained_config = {
             "use_visual_backbone": use_visual_backbone,
-            "use_float16": use_float16,
             "attention_probs_dropout_prob": 0.1,
             "use_visual_backbone": use_visual_backbone,
-            "use_float16": use_float16,
             "bos_token_id": 0,
             "coordinate_size": 128,
             "eos_token_id": 2,
mindocr/models/backbones/layoutxlm/layoutxlm.py (12 changes: 3 additions & 9 deletions)
@@ -99,18 +99,12 @@ def __init__(self, config):
         self.has_visual_segment_embedding = config.has_visual_segment_embedding
         self.embeddings = LayoutXLMEmbeddings(config)
         self.use_visual_backbone = config.use_visual_backbone
-        self.use_float16 = config.use_float16
-        self.dense_dtype = mstype.float32
-        if self.use_float16 is True:
-            self.dense_dtype = mstype.float16

         if self.use_visual_backbone is True:
             set_context(jit_syntax_level=0)
             self.visual = VisualBackbone(config)
             self.visual.freeze()
-            self.visual_proj = nn.Dense(config.image_feature_pool_shape[-1], config.hidden_size).to_float(
-                self.dense_dtype
-            )
+            self.visual_proj = nn.Dense(config.image_feature_pool_shape[-1], config.hidden_size)
         if self.has_visual_segment_embedding:
             self.visual_segment_embedding = Parameter(nn.Embedding(1, config.hidden_size).embedding_table[0])
         self.visual_LayerNorm = nn.LayerNorm((config.hidden_size,), epsilon=config.layer_norm_eps)
@@ -302,8 +296,8 @@ def construct(


 @register_backbone
-def layoutxlm(pretrained: bool = True, use_visual_backbone: bool = True, use_float16: bool = False, **kwargs):
-    pretrained_config = LayoutXLMPretrainedConfig(use_visual_backbone, use_float16)
+def layoutxlm(pretrained: bool = True, use_visual_backbone: bool = True, **kwargs):
+    pretrained_config = LayoutXLMPretrainedConfig(use_visual_backbone)
     model = LayoutXLMModel(pretrained_config)
     if pretrained:
         if use_visual_backbone is True:
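
Dropping `.to_float(...)` on `visual_proj` is the same principle applied to layers: a plain `nn.Dense` stays float32 at `O0` and is cast by the framework at `O3`, instead of being pinned to fp16 regardless of the configured level. A minimal before/after sketch (assuming standard `Cell.to_float` semantics):

```python
import mindspore as ms
from mindspore import nn

# Before: pinned to float16 even when the config asked for amp_level "O0".
proj_pinned = nn.Dense(256, 768).to_float(ms.float16)

# After: the layer follows whatever amp level the trainer applies.
proj = nn.Dense(256, 768)
proj = ms.amp.auto_mixed_precision(proj, amp_level="O3")  # fp16 here
```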
mindocr/models/backbones/layoutxlm/visual_backbone.py (2 changes: 1 addition & 1 deletion)
@@ -130,7 +130,7 @@ def construct(self, x):
             top_block_in_feature = bottom_up_features[self.top_block.in_feature]
         else:
             top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)]
-        results.extend(self.top_block(top_block_in_feature.astype(ms.float16)))
+        results.extend(self.top_block(top_block_in_feature))

         assert len(self._out_features) == len(results)
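
With casting delegated to amp, layer-local conversions like the removed `astype(ms.float16)` become redundant and can even fight the configured level, forcing fp16 inputs into an fp32 run at `O0`. A sketch of the pattern, with a hypothetical `TopBlock` standing in for the real FPN top block:

```python
import mindspore as ms
from mindspore import nn

class TopBlock(nn.Cell):  # hypothetical stand-in for the FPN top block
    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=1, stride=2)

    def construct(self, x):
        # No manual x.astype(ms.float16) here: under amp_level "O3" the
        # whole network already runs in float16, and under "O0" it
        # should stay float32.
        return [self.pool(x)]
```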