
Commit b8b77d5: update
1 parent 49180c4

File tree: 2 files changed, +88 −27 lines


README.md

Lines changed: 38 additions & 15 deletions
````diff
@@ -17,35 +17,58 @@ https://huggingface.co/junnyu/roformer_chinese_base
 ## 使用
 ```python
 import torch
-from roformer import RoFormerModel, RoFormerTokenizer
+from roformer import RoFormerModel, RoFormerTokenizer, TFRoFormerModel
 tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
-model = RoFormerModel.from_pretrained("junnyu/roformer_chinese_base")
+pt_model = RoFormerModel.from_pretrained("junnyu/roformer_chinese_base")
+tf_model = TFRoFormerModel.from_pretrained("junnyu/roformer_chinese_base",
+                                           from_pt=True)
 text = "这里基本保留了唐宋遗留下来的坊巷格局和大量明清古建筑,其中各级文保单位29处,被誉为“里坊制度的活化石”“明清建筑博物馆”!"
-inputs = tokenizer(text, return_tensors="pt")
+pt_inputs = tokenizer(text, return_tensors="pt")
+tf_inputs = tokenizer(text, return_tensors="tf")
 with torch.no_grad():
-    outputs = model(**inputs).last_hidden_state
-    print(outputs.shape)
+    pt_outputs = pt_model(**pt_inputs).last_hidden_state
+    print(pt_outputs.shape)
+tf_outputs = tf_model(**tf_inputs, training=False).last_hidden_state
+print(tf_outputs.shape)
 ```
 ## MLM测试
 ```python
 import torch
-from roformer import RoFormerForMaskedLM, RoFormerTokenizer
+import tensorflow as tf
+from roformer import RoFormerForMaskedLM, RoFormerTokenizer, TFRoFormerForMaskedLM
 text = "今天[MASK]很好,我[MASK]去公园玩。"
 tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
-model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
-inputs = tokenizer(text, return_tensors="pt")
+pt_model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_base")
+tf_model = TFRoFormerForMaskedLM.from_pretrained(
+    "junnyu/roformer_chinese_base", from_pt=True)
+pt_inputs = tokenizer(text, return_tensors="pt")
+tf_inputs = tokenizer(text, return_tensors="tf")
+# pytorch
 with torch.no_grad():
-    outputs = model(**inputs).logits[0]
-outputs_sentence = ""
+    pt_outputs = pt_model(**pt_inputs).logits[0]
+pt_outputs_sentence = "pytorch: "
 for i, id in enumerate(tokenizer.encode(text)):
     if id == tokenizer.mask_token_id:
-        tokens = tokenizer.convert_ids_to_tokens(outputs[i].topk(k=5)[1])
-        outputs_sentence += "[" + "||".join(tokens) + "]"
+        tokens = tokenizer.convert_ids_to_tokens(pt_outputs[i].topk(k=5)[1])
+        pt_outputs_sentence += "[" + "||".join(tokens) + "]"
     else:
-        outputs_sentence += "".join(
+        pt_outputs_sentence += "".join(
             tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True))
-print(outputs_sentence)
-# 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。
+print(pt_outputs_sentence)
+# tf
+tf_outputs = tf_model(**tf_inputs, training=False).logits[0]
+tf_outputs_sentence = "tf: "
+for i, id in enumerate(tokenizer.encode(text)):
+    if id == tokenizer.mask_token_id:
+        tokens = tokenizer.convert_ids_to_tokens(
+            tf.math.top_k(tf_outputs[i], k=5)[1])
+        tf_outputs_sentence += "[" + "||".join(tokens) + "]"
+    else:
+        tf_outputs_sentence += "".join(
+            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True))
+print(tf_outputs_sentence)
+# pytorch: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。
+# tf: 今天[天气||天||心情||阳光||空气]很好,我[想||要||打算||准备||喜欢]去公园玩。
 ```
 
 ## 手动权重转换
````
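In the updated examples above, every `TFRoFormerModel.from_pretrained(..., from_pt=True)` call converts the PyTorch checkpoint to TensorFlow weights at load time. As a rough sketch of how to avoid repeating that conversion (this is not the README's own 手动权重转换 section, and the local directory name is an assumption), the converted model can be saved once with the standard `save_pretrained` API and reloaded natively afterwards:

```python
# Sketch only: cache the PT->TF conversion so later loads skip it.
# "roformer_chinese_base_tf" is a hypothetical local directory, not part of this repo.
from roformer import RoFormerTokenizer, TFRoFormerModel

tf_model = TFRoFormerModel.from_pretrained("junnyu/roformer_chinese_base",
                                           from_pt=True)   # convert PyTorch weights in memory
tf_model.save_pretrained("roformer_chinese_base_tf")        # writes TF weights + config.json

tokenizer = RoFormerTokenizer.from_pretrained("junnyu/roformer_chinese_base")
tokenizer.save_pretrained("roformer_chinese_base_tf")

# Subsequent loads read the saved TF weights directly, no conversion needed.
tf_model = TFRoFormerModel.from_pretrained("roformer_chinese_base_tf")
```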

src/roformer/__init__.py

Lines changed: 50 additions & 12 deletions
```diff
@@ -15,6 +15,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import TYPE_CHECKING
 
 from transformers.file_utils import (
     _BaseLazyModule,
@@ -61,21 +62,58 @@
         "TFRoFormerModel",
         "TFRoFormerPreTrainedModel",
     ]
-import importlib
-import os
-import sys
+if TYPE_CHECKING:
+    from .configuration_roformer import ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, RoFormerConfig
+    from .tokenization_roformer import CustomBasicTokenizer, RoFormerTokenizer
 
+    if is_torch_available():
+        from .modeling_roformer import (
+            ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+            RoFormerForMaskedLM,
+            RoFormerForMultipleChoice,
+            RoFormerForNextSentencePrediction,
+            RoFormerForPreTraining,
+            RoFormerForQuestionAnswering,
+            RoFormerForSequenceClassification,
+            RoFormerForTokenClassification,
+            RoFormerLayer,
+            RoFormerLMHeadModel,
+            RoFormerModel,
+            RoFormerPreTrainedModel,
+            load_tf_weights_in_roformer,
+        )
 
-class _LazyModule(_BaseLazyModule):
-    """
-    Module class that surfaces all objects but only performs associated imports when the objects are requested.
-    """
+    if is_tf_available():
+        from .modeling_tf_roformer import (
+            TF_ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
+            TFRoFormerEmbeddings,
+            TFRoFormerForMaskedLM,
+            TFRoFormerForMultipleChoice,
+            TFRoFormerForNextSentencePrediction,
+            TFRoFormerForPreTraining,
+            TFRoFormerForQuestionAnswering,
+            TFRoFormerForSequenceClassification,
+            TFRoFormerForTokenClassification,
+            TFRoFormerLMHeadModel,
+            TFRoFormerMainLayer,
+            TFRoFormerModel,
+            TFRoFormerPreTrainedModel,
+        )
 
-    __file__ = globals()["__file__"]
-    __path__ = [os.path.dirname(__file__)]
+else:
+    import importlib
+    import os
+    import sys
 
-    def _get_module(self, module_name: str):
-        return importlib.import_module("." + module_name, self.__name__)
+    class _LazyModule(_BaseLazyModule):
+        """
+        Module class that surfaces all objects but only performs associated imports when the objects are requested.
+        """
 
+        __file__ = globals()["__file__"]
+        __path__ = [os.path.dirname(__file__)]
 
-sys.modules[__name__] = _LazyModule(__name__, _import_structure)
+        def _get_module(self, module_name: str):
+            return importlib.import_module("." + module_name, self.__name__)
+
+    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
```
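For context, the `else` branch at the bottom of this diff is the standard lazy-import pattern used by transformers at the time: `_import_structure` maps each submodule to the names it exports, and the module object placed into `sys.modules` only imports a submodule when one of those names is first accessed. A minimal, self-contained sketch of the idea follows; it uses a plain `ModuleType` subclass with `__getattr__` rather than transformers' `_BaseLazyModule`, and the names in the commented usage are illustrative only.

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Minimal sketch of a lazy module: a submodule is imported only when
    one of the names it exports is first requested."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        self._import_structure = import_structure
        # Reverse map: exported name -> submodule that defines it.
        self._name_to_module = {
            attr: module
            for module, attrs in import_structure.items()
            for attr in attrs
        }
        self.__all__ = list(self._name_to_module)

    def __getattr__(self, name):
        if name not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        # Import the owning submodule on first access, then cache the attribute
        # so the import machinery is not touched again for this name.
        submodule = importlib.import_module("." + self._name_to_module[name], self.__name__)
        value = getattr(submodule, name)
        setattr(self, name, value)
        return value


# Illustrative usage inside a package's __init__.py (names are hypothetical):
# _import_structure = {
#     "modeling_roformer": ["RoFormerModel", "RoFormerForMaskedLM"],
#     "tokenization_roformer": ["RoFormerTokenizer"],
# }
# import sys
# sys.modules[__name__] = LazyModule(__name__, _import_structure)
```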
