Skip to content

Commit

Permalink
fix dataaug test (PaddlePaddle#5532)
Browse files Browse the repository at this point in the history
* fix dataaug test

* decrease length
  • Loading branch information
lugimzzz authored Apr 4, 2023
1 parent 95475e3 commit 2ee2948
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/dataaug.md
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ SentenceBackTranslate 参数介绍:
“beam_search”策略中的beam值。 默认为 4。
use_faster (bool):
- 是否使用FasterGeneration进行加速。默认为True
+ 是否使用FasterGeneration进行加速。默认为False
decode_strategy (str):
生成中的解码策略。 目前支持三种解码策略:“greedy_search”、“sampling”和“beam_search”。 默认为“beam_search”。
Expand Down
4 changes: 2 additions & 2 deletions paddlenlp/dataaug/sentence.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class SentenceBackTranslate:
num_beams (int): The number of beams in the "beam_search"
strategy. Default to 4.
use_faster: (bool): Whether to use faster entry of model
- for FasterGeneration. Default to True.
+ for FasterGeneration. Default to False.
decode_strategy (str, optional): The decoding strategy in generation.
Currently, there are three decoding strategies supported:
"greedy_search", "sampling" and "beam_search". Default to
Expand All @@ -209,7 +209,7 @@ def __init__(
max_length=128,
batch_size=1,
num_beams=4,
- use_faster=True,
+ use_faster=False,
decode_strategy="beam_search",
from_model_name=None,
to_model_name=None,
Expand Down
6 changes: 4 additions & 2 deletions tests/dataaug/test_char_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_char_substitute(self, create_n):
for t in self.types:
if t == "mlm":
aug = CharSubstitute(
- "mlm", create_n=create_n, model_name="__internal_testing__/ernie", vocab="test_vocab"
+ "mlm", create_n=create_n, model_name="__internal_testing__/tiny-random-ernie", vocab="test_vocab"
)
augmented = aug.augment(self.sequences)
self.assertEqual(len(self.sequences), len(augmented))
Expand All @@ -79,7 +79,9 @@ def test_char_substitute(self, create_n):
def test_char_insert(self, create_n):
for t in self.types:
if t == "mlm":
- aug = CharInsert("mlm", create_n=create_n, model_name="__internal_testing__/ernie", vocab="test_vocab")
+ aug = CharInsert(
+ "mlm", create_n=create_n, model_name="__internal_testing__/tiny-random-ernie", vocab="test_vocab"
+ )
augmented = aug.augment(self.sequences)
self.assertEqual(len(self.sequences), len(augmented))
continue
Expand Down
14 changes: 9 additions & 5 deletions tests/dataaug/test_sentence_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,21 @@

class TestSentAug(unittest.TestCase):
def setUp(self):
- self.sequences = ["人类语言是抽象的信息符号,其中蕴含着丰富的语义信息,人类可以很轻松地理解其中的含义。", "而计算机只能处理数值化的信息,无法直接理解人类语言,所以需要将人类语言进行数值化转换。"]
+ self.sequences = ["人类语言是抽象的信息符号。", "而计算机只能处理数值化的信息。"]
+ self.max_length = 3

def test_sent_generate(self):
- aug = SentenceGenerate(model_name="__internal_testing__/tiny-random-roformer-sim")
+ aug = SentenceGenerate(model_name="__internal_testing__/tiny-random-roformer-sim", max_length=self.max_length)
augmented = aug.augment(self.sequences)
self.assertEqual(len(self.sequences), len(augmented))
self.assertEqual(aug.create_n, len(augmented[0]))
self.assertEqual(aug.create_n, len(augmented[1]))

def test_sent_summarize(self):
- model = AutoModelForConditionalGeneration.from_pretrained("__internal_testing__/tiny-random-pegasus")
- tokenizer = AutoTokenizer.from_pretrained("__internal_testing__/tiny-random-pegasus")
+ model = AutoModelForConditionalGeneration.from_pretrained(
+ "__internal_testing__/tiny-random-mbart", max_length=self.max_length
+ )
+ tokenizer = AutoTokenizer.from_pretrained("__internal_testing__/tiny-random-mbart")
model_path = os.path.join(TemporaryDirectory().name, "model")
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)
Expand All @@ -52,14 +55,15 @@ def test_sent_backtranslate(self):
aug = SentenceBackTranslate(
from_model_name="__internal_testing__/tiny-random-mbart",
to_model_name="__internal_testing__/tiny-random-mbart",
+ max_length=self.max_length,
)
augmented = aug.augment(self.sequences)
self.assertEqual(len(self.sequences), len(augmented))
self.assertEqual(1, len(augmented[0]))
self.assertEqual(1, len(augmented[1]))

def test_sent_continue(self):
- aug = SentenceContinue(model_name="__internal_testing__/tiny-random-gpt")
+ aug = SentenceContinue(model_name="__internal_testing__/tiny-random-gpt", max_length=self.max_length)
augmented = aug.augment(self.sequences)
self.assertEqual(len(self.sequences), len(augmented))
self.assertEqual(aug.create_n, len(augmented[0]))
Expand Down
6 changes: 4 additions & 2 deletions tests/dataaug/test_word_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_word_substitute(self, create_n):
for t in self.types:
if t == "mlm":
aug = WordSubstitute(
- "mlm", create_n=create_n, model_name="__internal_testing__/ernie", vocab="test_vocab"
+ "mlm", create_n=create_n, model_name="__internal_testing__/tiny-random-ernie", vocab="test_vocab"
)
augmented = aug.augment(self.sequences)
self.assertEqual(len(self.sequences), len(augmented))
Expand All @@ -72,7 +72,9 @@ def test_word_substitute(self, create_n):
def test_word_insert(self, create_n):
for t in self.types:
if t == "mlm":
- aug = WordInsert("mlm", create_n=create_n, model_name="__internal_testing__/ernie", vocab="test_vocab")
+ aug = WordInsert(
+ "mlm", create_n=create_n, model_name="__internal_testing__/tiny-random-ernie", vocab="test_vocab"
+ )
augmented = aug.augment(self.sequences)
self.assertEqual(len(self.sequences), len(augmented))
continue
Expand Down

0 comments on commit 2ee2948

Please sign in to comment.