|
2 | 2 | from typing import ClassVar |
3 | 3 |
|
4 | 4 | import datasets |
5 | | -import nltk |
6 | | -from nltk.corpus import wordnet |
7 | 5 | from pydantic import model_validator |
8 | 6 |
|
9 | 7 | from shared.base import BaseDataset, ChatEntry |
10 | 8 |
|
11 | | -nltk.download("wordnet") |
12 | | - |
13 | 9 |
|
14 | 10 | class SN13Dataset(BaseDataset): |
15 | 11 | _url: ClassVar[str] = "arrmlet/x_dataset_218" |
@@ -41,51 +37,10 @@ def sample(self) -> ChatEntry: |
41 | 37 | if self.exception is not None: |
42 | 38 | raise self.exception |
43 | 39 | # Randomly select a sample from the dataset. |
44 | | - sample_idx = random.randint(0, len(self.dataset) - 1) |
45 | | - message = self.dataset[sample_idx]["text"] |
46 | | - role = ["user"] |
47 | | - |
48 | | - # Augment the messages by modifying words and introducing errors. |
49 | | - messages = [self._augment_message(role, message)] |
50 | | - |
51 | | - return ChatEntry(roles=role, messages=messages, organic=False, source=self._url) |
52 | | - |
53 | | - def _augment_message(self, role: str, message: str) -> str: |
54 | | - if role == "assistant": |
55 | | - return message |
56 | | - |
57 | | - words = message.split() |
58 | | - num_words_to_modify = random.randint(1, max(1, int(len(words) * self._chance_word_synonym))) |
59 | | - words_to_modify = random.sample(range(len(words)), num_words_to_modify) |
60 | | - |
61 | | - for idx in words_to_modify: |
62 | | - synonym = self._get_synonym(words[idx]) |
63 | | - if synonym: |
64 | | - words[idx] = synonym |
65 | | - |
66 | | - message = " ".join(words) |
67 | | - message = self._introduce_typos(message) |
68 | | - return message |
69 | | - |
70 | | - def _get_synonym(self, word: str) -> str: |
71 | | - synonyms = wordnet.synsets(word) |
72 | | - if synonyms: |
73 | | - # Choose a synonym that is not the word itself. |
74 | | - synonym_words = [lemma.name() for lemma in synonyms[0].lemmas() if lemma.name() != word] |
75 | | - if synonym_words: |
76 | | - return random.choice(synonym_words) |
77 | | - return word |
78 | | - |
79 | | - def _introduce_typos(self, message: str) -> str: |
80 | | - message = list(message) |
81 | | - num_errors = random.randint(0, max(1, int(len(message) * self._chance_char_typo))) |
82 | | - for _ in range(num_errors): |
83 | | - error_type = random.choice(["remove", "add_space"]) |
84 | | - error_position = random.randint(0, len(message) - 1) |
85 | | - |
86 | | - if error_type == "remove": |
87 | | - message.pop(error_position) |
88 | | - elif error_type == "add_space": |
89 | | - message.insert(error_position, " ") |
| 40 | + messages = [] |
| 41 | + for _ in range(4): |
| 42 | + sample_idx = random.randint(0, len(self.dataset) - 1) |
| 43 | + if message := self.dataset[sample_idx]["text"]: |
| 44 | + messages.append({"role": random.choice(["user", "assistant"]), "content": message}) |
90 | 45 |
|
91 | | - return "".join(message) |
| 46 | + return ChatEntry(messages=messages, organic=False, source=self._url) |
0 commit comments