# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import sys
import tempfile
import unittest
from contextlib import contextmanager
from pathlib import Path


git_repo_path = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
sys.path.append(os.path.join(git_repo_path, "utils"))

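# `check_copies` lives in the repo's root-level `utils` folder, which was just added to
# `sys.path`; it therefore has to be imported after that path setup (hence `noqa: E402`).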
import check_copies  # noqa: E402
from check_copies import convert_to_localized_md, find_code_in_transformers, is_copy_consistent  # noqa: E402


# This is the reference code that will be used in the tests.
# If BertLMPredictionHead is changed in modeling_bert.py, this code needs to be manually updated.
REFERENCE_CODE = """    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states
"""
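
# Mock model files used to build a fake repo in the tests below. `MOCK_BERT_CODE` plays the
# role of the original model; `MOCK_BERT_COPY_CODE` mirrors it via `# Copied from` markers.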

MOCK_BERT_CODE = """from ...modeling_utils import PreTrainedModel

def bert_function(x):
    return x


class BertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()


class BertModel(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__()
        self.bert = BertEncoder(config)

    @add_docstring(BERT_DOCSTRING)
    def forward(self, x):
        return self.bert(x)
"""

MOCK_BERT_COPY_CODE = """from ...modeling_utils import PreTrainedModel

# Copied from transformers.models.bert.modeling_bert.bert_function
def bert_copy_function(x):
    return x


# Copied from transformers.models.bert.modeling_bert.BertAttention
class BertCopyAttention(nn.Module):
    def __init__(self, config):
        super().__init__()


# Copied from transformers.models.bert.modeling_bert.BertModel with Bert->BertCopy all-casing
class BertCopyModel(BertCopyPreTrainedModel):
    def __init__(self, config):
        super().__init__()
        self.bertcopy = BertCopyEncoder(config)

    @add_docstring(BERTCOPY_DOCSTRING)
    def forward(self, x):
        return self.bertcopy(x)
"""
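
# The `with Bert->BertCopy all-casing` suffix above asks the checker to apply the replacement
# in every casing (bert->bertcopy, Bert->BertCopy, BERT->BERTCOPY) when comparing the files.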

MOCK_DUMMY_BERT_CODE_MATCH = """
class BertDummyModel:
    attr_1 = 1
    attr_2 = 2

    def __init__(self, a=1, b=2):
        self.a = a
        self.b = b

    # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
    def forward(self, c):
        return 1

    def existing_common(self, c):
        return 4

    def existing_diff_to_be_ignored(self, c):
        return 9
"""
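
# `MOCK_DUMMY_ROBERTA_CODE_MATCH` below diverges from its `# Copied from` source only inside
# blocks marked `# Ignore copy`, so `is_copy_consistent` should report no diffs for the pair.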

MOCK_DUMMY_ROBERTA_CODE_MATCH = """
# Copied from transformers.models.dummy_bert_match.modeling_dummy_bert_match.BertDummyModel with BertDummy->RobertaBertDummy
class RobertaBertDummyModel:

    attr_1 = 1
    attr_2 = 2

    def __init__(self, a=1, b=2):
        self.a = a
        self.b = b

    # Ignore copy
    def only_in_roberta_to_be_ignored(self, c):
        return 3

    # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
    def forward(self, c):
        return 1

    def existing_common(self, c):
        return 4

    # Ignore copy
    def existing_diff_to_be_ignored(self, c):
        return 6
"""
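
# The `_NO_MATCH` pair below differs in places that are *not* protected by `# Ignore copy`
# (`attr_2`, `only_in_roberta_not_ignored`, `existing_diff_not_ignored`), so the checker must
# flag the file and, with `overwrite=True`, rewrite it to `EXPECTED_REPLACED_CODE`.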

MOCK_DUMMY_BERT_CODE_NO_MATCH = """
class BertDummyModel:
    attr_1 = 1
    attr_2 = 2

    def __init__(self, a=1, b=2):
        self.a = a
        self.b = b

    # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
    def forward(self, c):
        return 1

    def only_in_bert(self, c):
        return 7

    def existing_common(self, c):
        return 4

    def existing_diff_not_ignored(self, c):
        return 8

    def existing_diff_to_be_ignored(self, c):
        return 9
"""


MOCK_DUMMY_ROBERTA_CODE_NO_MATCH = """
# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy
class RobertaBertDummyModel:

    attr_1 = 1
    attr_2 = 3

    def __init__(self, a=1, b=2):
        self.a = a
        self.b = b

    # Ignore copy
    def only_in_roberta_to_be_ignored(self, c):
        return 3

    # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
    def forward(self, c):
        return 1

    def only_in_roberta_not_ignored(self, c):
        return 2

    def existing_common(self, c):
        return 4

    def existing_diff_not_ignored(self, c):
        return 5

    # Ignore copy
    def existing_diff_to_be_ignored(self, c):
        return 6
"""
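
# Expected result of `is_copy_consistent(..., overwrite=True)` on the `_NO_MATCH` file: the
# body is resynced with the bert source, while `# Ignore copy` blocks keep their local version
# (`existing_diff_to_be_ignored` still returns 6, and `only_in_roberta_to_be_ignored` is kept,
# appended at the end).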

EXPECTED_REPLACED_CODE = """
# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy
class RobertaBertDummyModel:
    attr_1 = 1
    attr_2 = 2

    def __init__(self, a=1, b=2):
        self.a = a
        self.b = b

    # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
    def forward(self, c):
        return 1

    def only_in_bert(self, c):
        return 7

    def existing_common(self, c):
        return 4

    def existing_diff_not_ignored(self, c):
        return 8

    # Ignore copy
    def existing_diff_to_be_ignored(self, c):
        return 6

    # Ignore copy
    def only_in_roberta_to_be_ignored(self, c):
        return 3
"""


def replace_in_file(filename, old, new):
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

    content = content.replace(old, new)

    with open(filename, "w", encoding="utf-8", newline="\n") as f:
        f.write(content)


def create_tmp_repo(tmp_dir):
    """
    Creates a mock repository in a temporary folder for testing.
    """
    tmp_dir = Path(tmp_dir)
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir)
    tmp_dir.mkdir(exist_ok=True)

    model_dir = tmp_dir / "src" / "transformers" / "models"
    model_dir.mkdir(parents=True, exist_ok=True)

    models = {
        "bert": MOCK_BERT_CODE,
        "bertcopy": MOCK_BERT_COPY_CODE,
        "dummy_bert_match": MOCK_DUMMY_BERT_CODE_MATCH,
        "dummy_roberta_match": MOCK_DUMMY_ROBERTA_CODE_MATCH,
        "dummy_bert_no_match": MOCK_DUMMY_BERT_CODE_NO_MATCH,
        "dummy_roberta_no_match": MOCK_DUMMY_ROBERTA_CODE_NO_MATCH,
    }
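    # Write each mock model where `check_copies` expects to find it:
    #     src/transformers/models/<model>/modeling_<model>.py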
    for model, code in models.items():
        model_subdir = model_dir / model
        model_subdir.mkdir(exist_ok=True)
        with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8", newline="\n") as f:
            f.write(code)


@contextmanager
def patch_transformer_repo_path(new_folder):
    """
    Temporarily patches the variables defined in `check_copies` to use a different location for the repo.
    """
    old_repo_path = check_copies.REPO_PATH
    old_doc_path = check_copies.PATH_TO_DOCS
    old_transformer_path = check_copies.TRANSFORMERS_PATH
    repo_path = Path(new_folder).resolve()
    check_copies.REPO_PATH = str(repo_path)
    check_copies.PATH_TO_DOCS = str(repo_path / "docs" / "source" / "en")
    check_copies.TRANSFORMERS_PATH = str(repo_path / "src" / "transformers")
    try:
        yield
    finally:
        check_copies.REPO_PATH = old_repo_path
        check_copies.PATH_TO_DOCS = old_doc_path
        check_copies.TRANSFORMERS_PATH = old_transformer_path
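
# Typical usage of the two helpers above, as exercised in the tests below:
#
#     with tempfile.TemporaryDirectory() as tmp_folder:
#         create_tmp_repo(tmp_folder)
#         with patch_transformer_repo_path(tmp_folder):
#             diffs = is_copy_consistent(file_to_check)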


class CopyCheckTester(unittest.TestCase):
    def test_find_code_in_transformers(self):
        with tempfile.TemporaryDirectory() as tmp_folder:
            create_tmp_repo(tmp_folder)
            with patch_transformer_repo_path(tmp_folder):
                code = find_code_in_transformers("models.bert.modeling_bert.BertAttention")

        reference_code = (
            "class BertAttention(nn.Module):\n    def __init__(self, config):\n        super().__init__()\n"
        )
        self.assertEqual(code, reference_code)

    def test_is_copy_consistent(self):
        path_to_check = ["src", "transformers", "models", "bertcopy", "modeling_bertcopy.py"]
        with tempfile.TemporaryDirectory() as tmp_folder:
            # Base check
            create_tmp_repo(tmp_folder)
            with patch_transformer_repo_path(tmp_folder):
                file_to_check = os.path.join(tmp_folder, *path_to_check)
                diffs = is_copy_consistent(file_to_check)
                self.assertEqual(diffs, [])

            # Base check with an inconsistency
            create_tmp_repo(tmp_folder)
            with patch_transformer_repo_path(tmp_folder):
                file_to_check = os.path.join(tmp_folder, *path_to_check)

                replace_in_file(file_to_check, "self.bertcopy(x)", "self.bert(x)")
                diffs = is_copy_consistent(file_to_check)
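                # Line 22 of the mock file is `return self.bertcopy(x)`, the statement we just
                # replaced with `self.bert(x)`.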
                self.assertEqual(diffs, [["models.bert.modeling_bert.BertModel", 22]])

                _ = is_copy_consistent(file_to_check, overwrite=True)

                with open(file_to_check, "r", encoding="utf-8") as f:
                    self.assertEqual(f.read(), MOCK_BERT_COPY_CODE)

    def test_is_copy_consistent_with_ignored_match(self):
        path_to_check = ["src", "transformers", "models", "dummy_roberta_match", "modeling_dummy_roberta_match.py"]
        with tempfile.TemporaryDirectory() as tmp_folder:
            # Base check
            create_tmp_repo(tmp_folder)
            with patch_transformer_repo_path(tmp_folder):
                file_to_check = os.path.join(tmp_folder, *path_to_check)
                diffs = is_copy_consistent(file_to_check)
                self.assertEqual(diffs, [])

    def test_is_copy_consistent_with_ignored_no_match(self):
        path_to_check = [
            "src",
            "transformers",
            "models",
            "dummy_roberta_no_match",
            "modeling_dummy_roberta_no_match.py",
        ]
        with tempfile.TemporaryDirectory() as tmp_folder:
            # Base check with an inconsistency
            create_tmp_repo(tmp_folder)
            with patch_transformer_repo_path(tmp_folder):
                file_to_check = os.path.join(tmp_folder, *path_to_check)

                diffs = is_copy_consistent(file_to_check)
                # Line 6 is `attr_2 = 3` in `MOCK_DUMMY_ROBERTA_CODE_NO_MATCH` (the string's
                # leading `\n` means its blank first line counts as line 1).
                self.assertEqual(
                    diffs, [["models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel", 6]]
                )

                _ = is_copy_consistent(file_to_check, overwrite=True)

                with open(file_to_check, "r", encoding="utf-8") as f:
                    self.assertEqual(f.read(), EXPECTED_REPLACED_CODE)

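    # `convert_to_localized_md` syncs a localized README model list with the English one:
    # models missing from the localized list are added using the language's
    # `format_model_list` template, stale links are updated, and the returned boolean says
    # whether the two lists already contained the same number of models.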
    def test_convert_to_localized_md(self):
        localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]

        md_list = (
            "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the"
            " Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for"
            " Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong"
            " Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut.\n1."
            " **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (from HuggingFace),"
            " released together with the paper [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and"
            " lighter](https://arxiv.org/abs/1910.01108) by Victor Sanh, Lysandre Debut and Thomas Wolf. The same"
            " method has been applied to compress GPT2 into"
            " [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into"
            " [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation),"
            " Multilingual BERT into"
            " [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German"
            " version of DistilBERT.\n1. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)**"
            " (from Google Research/Stanford University) released with the paper [ELECTRA: Pre-training text encoders"
            " as discriminators rather than generators](https://arxiv.org/abs/2003.10555) by Kevin Clark, Minh-Thang"
            " Luong, Quoc V. Le, Christopher D. Manning."
        )
        localized_md_list = (
            "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (来自 Google Research and the"
            " Toyota Technological Institute at Chicago) 伴随论文 [ALBERT: A Lite BERT for Self-supervised Learning of"
            " Language Representations](https://arxiv.org/abs/1909.11942), 由 Zhenzhong Lan, Mingda Chen, Sebastian"
            " Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n"
        )
        converted_md_list_sample = (
            "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (来自 Google Research and the"
            " Toyota Technological Institute at Chicago) 伴随论文 [ALBERT: A Lite BERT for Self-supervised Learning of"
            " Language Representations](https://arxiv.org/abs/1909.11942), 由 Zhenzhong Lan, Mingda Chen, Sebastian"
            " Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n1."
            " **[DistilBERT](https://huggingface.co/transformers/model_doc/distilbert.html)** (来自 HuggingFace) 伴随论文"
            " [DistilBERT, a distilled version of BERT: smaller, faster, cheaper and"
            " lighter](https://arxiv.org/abs/1910.01108) 由 Victor Sanh, Lysandre Debut and Thomas Wolf 发布。 The same"
            " method has been applied to compress GPT2 into"
            " [DistilGPT2](https://github.com/huggingface/transformers/tree/main/examples/distillation), RoBERTa into"
            " [DistilRoBERTa](https://github.com/huggingface/transformers/tree/main/examples/distillation),"
            " Multilingual BERT into"
            " [DistilmBERT](https://github.com/huggingface/transformers/tree/main/examples/distillation) and a German"
            " version of DistilBERT.\n1. **[ELECTRA](https://huggingface.co/transformers/model_doc/electra.html)** (来自"
            " Google Research/Stanford University) 伴随论文 [ELECTRA: Pre-training text encoders as discriminators rather"
            " than generators](https://arxiv.org/abs/2003.10555) 由 Kevin Clark, Minh-Thang Luong, Quoc V. Le,"
            " Christopher D. Manning 发布。\n"
        )

        num_models_equal, converted_md_list = convert_to_localized_md(
            md_list, localized_md_list, localized_readme["format_model_list"]
        )

        self.assertFalse(num_models_equal)
        self.assertEqual(converted_md_list, converted_md_list_sample)

        num_models_equal, converted_md_list = convert_to_localized_md(
            md_list, converted_md_list, localized_readme["format_model_list"]
        )

        # Check that, after conversion, the number of models matches the English README.md.
        self.assertTrue(num_models_equal)

        link_changed_md_list = (
            "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (from Google Research and the"
            " Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for"
            " Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong"
            " Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut."
        )
        link_unchanged_md_list = (
            "1. **[ALBERT](https://huggingface.co/transformers/main/model_doc/albert.html)** (来自 Google Research and"
            " the Toyota Technological Institute at Chicago) 伴随论文 [ALBERT: A Lite BERT for Self-supervised Learning of"
            " Language Representations](https://arxiv.org/abs/1909.11942), 由 Zhenzhong Lan, Mingda Chen, Sebastian"
            " Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n"
        )
        converted_md_list_sample = (
            "1. **[ALBERT](https://huggingface.co/transformers/model_doc/albert.html)** (来自 Google Research and the"
            " Toyota Technological Institute at Chicago) 伴随论文 [ALBERT: A Lite BERT for Self-supervised Learning of"
            " Language Representations](https://arxiv.org/abs/1909.11942), 由 Zhenzhong Lan, Mingda Chen, Sebastian"
            " Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut 发布。\n"
        )

        num_models_equal, converted_md_list = convert_to_localized_md(
            link_changed_md_list, link_unchanged_md_list, localized_readme["format_model_list"]
        )

        # Check that the localized model link got synchronized with the English one.
        self.assertEqual(converted_md_list, converted_md_list_sample)