# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15"""
16Utility that checks the supports of 3rd party libraries are listed in the documentation file. Currently, this includes:
17- flash attention support
18- SDPA support
19
20Use from the root of the repo with (as used in `make repo-consistency`):
21
22```bash
23python utils/check_support_list.py
24```
25
26It has no auto-fix mode.
27"""
import os
from glob import glob


# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_support_list.py
REPO_PATH = "."


def check_flash_support_list():
    """Check that every architecture declaring `_supports_flash_attn_2 = True` is listed in the documentation."""
    with open(os.path.join(REPO_PATH, "docs/source/en/perf_infer_gpu_one.md"), "r") as f:
        doctext = f.read()

    # Keep only the section of the doc that lists the FlashAttention-2 architectures.
    doctext = doctext.split("FlashAttention-2 is currently supported for the following architectures:")[1]
    doctext = doctext.split("You can request to add FlashAttention-2 support")[0]

    # Collect the PyTorch modeling files, excluding the TF and Flax variants.
    patterns = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_*.py"))
    patterns_tf = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_tf_*.py"))
    patterns_flax = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_flax_*.py"))
    patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
    archs_supporting_fa2 = []
    for filename in patterns:
        with open(filename, "r") as f:
            text = f.read()

        if "_supports_flash_attn_2 = True" in text:
            model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
            archs_supporting_fa2.append(model_name)

    for arch in archs_supporting_fa2:
        if arch not in doctext:
            raise ValueError(
                f"{arch} should be listed in the flash attention documentation but is not. Please update the documentation."
            )
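

# NOTE: the `arch not in doctext` tests (here and in the SDPA check below) are plain
# substring checks. They work because the documentation lists each architecture as a
# link whose URL contains the module name, roughly like this (illustrative, not
# verbatim from the docs):
#
#     * [GPT-NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
#
# so a module name such as "gpt_neox" matches the link target.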


def check_sdpa_support_list():
    """Check that every architecture declaring `_supports_sdpa = True` is listed in the documentation."""
    with open(os.path.join(REPO_PATH, "docs/source/en/perf_infer_gpu_one.md"), "r") as f:
        doctext = f.read()

    # Keep only the section of the doc that lists the SDPA architectures.
    doctext = doctext.split(
        "For now, Transformers supports SDPA inference and training for the following architectures:"
    )[1]
    doctext = doctext.split("Note that FlashAttention can only be used for models using the")[0]

    # Collect the PyTorch modeling files, excluding the TF and Flax variants.
    patterns = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_*.py"))
    patterns_tf = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_tf_*.py"))
    patterns_flax = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_flax_*.py"))
    patterns = list(set(patterns) - set(patterns_tf) - set(patterns_flax))
    archs_supporting_sdpa = []
    for filename in patterns:
        with open(filename, "r") as f:
            text = f.read()

        if "_supports_sdpa = True" in text:
            model_name = os.path.basename(filename).replace(".py", "").replace("modeling_", "")
            archs_supporting_sdpa.append(model_name)

    for arch in archs_supporting_sdpa:
        if arch not in doctext:
            raise ValueError(
                f"{arch} should be listed in the SDPA documentation but is not. Please update the documentation."
            )
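

# Both checks above share the same file-scanning logic. A minimal sketch of a shared
# helper, assuming the same repo layout (hypothetical: `_archs_with_flag` is not part
# of this script and is not called by it):
def _archs_with_flag(flag):
    """Return the module names of all PyTorch modeling files that set `flag` to True."""
    patterns = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_*.py"))
    patterns_tf = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_tf_*.py"))
    patterns_flax = glob(os.path.join(REPO_PATH, "src/transformers/models/**/modeling_flax_*.py"))
    archs = []
    # Keep only the PyTorch modeling files (drop the TF and Flax variants).
    for filename in set(patterns) - set(patterns_tf) - set(patterns_flax):
        with open(filename, "r") as f:
            if f"{flag} = True" in f.read():
                archs.append(os.path.basename(filename).replace(".py", "").replace("modeling_", ""))
    return archs


# With such a helper, e.g. `_archs_with_flag("_supports_flash_attn_2")` could replace
# the scan inside check_flash_support_list().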


if __name__ == "__main__":
    check_flash_support_list()
    check_sdpa_support_list()