transformers
63 строки · 2.5 Кб
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16
17import glob18import os19
20from get_test_info import get_tester_classes21
22
# Upper bounds for integer config attributes, keyed by attribute name. A config
# value strictly greater than its bound is reported as a failure.
SIZE_LIMITS = {
    "vocab_size": 100,
    "max_position_embeddings": 128,
    "hidden_size": 40,
    "d_model": 40,
    "num_layers": 5,
    "num_hidden_layers": 5,
    "num_encoder_layers": 5,
    "num_decoder_layers": 5,
}


def is_tf_or_flax_test_file(path):
    """Return True if the file name of `path` starts with a TF or Flax modeling-test prefix.

    Only the final path component is inspected: `glob` returns paths such as
    `tests/models/<model>/test_modeling_tf_<model>.py`, so a prefix check on the
    whole path would never match.
    """
    file_name = os.path.basename(path)
    return file_name.startswith(("test_modeling_tf_", "test_modeling_flax_"))


def find_oversized_values(config_dict):
    """Return `(key, value, limit)` for every integer entry of `config_dict` above its limit.

    Args:
        config_dict: mapping of config attribute names to values.

    Returns:
        A list of `(key, value, limit)` tuples, in the iteration order of
        `config_dict`, for keys listed in `SIZE_LIMITS` whose value is an `int`
        strictly greater than the limit. Non-integer values and unknown keys
        are ignored.
    """
    oversized = []
    for key, value in config_dict.items():
        if not isinstance(value, int):
            continue
        limit = SIZE_LIMITS.get(key)
        if limit is not None and value > limit:
            oversized.append((key, value, limit))
    return oversized


if __name__ == "__main__":
    failures = []

    # NOTE: `glob.glob` is called without `recursive=True`, so `**` behaves like `*`
    # here and matches a single directory level below `tests/models`.
    pattern = os.path.join("tests", "models", "**", "test_modeling_*.py")
    # TODO: deal with TF/Flax too
    test_files = [x for x in glob.glob(pattern) if not is_tf_or_flax_test_file(x)]

    for test_file in test_files:
        for tester_class in get_tester_classes(test_file):
            # A few tester classes don't have `parent` parameter in `__init__`.
            # TODO: deal this better
            try:
                tester = tester_class(parent=None)
            except Exception:
                # Deliberate best-effort: testers that cannot be built this way are skipped.
                continue
            if not hasattr(tester, "get_config"):
                continue
            config = tester.get_config()
            for k, v, target in find_oversized_values(config.to_dict()):
                failures.append(
                    f"{tester_class.__name__} will produce a `config` of type `{config.__class__.__name__}`"
                    f' with config["{k}"] = {v} which is too large for testing! Set its value to be smaller'
                    f" than {target}."
                )

    if failures:
        raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))