transformers

Форк
0
/
split_model_tests.py 
65 строк · 2.2 Кб
1
# Copyright 2024 The HuggingFace Team. All rights reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
"""
16
This script is used to get the list of folders under `tests/models` and split the list into `NUM_SLICES` splits.
17
The main use case is a GitHub Actions workflow file calling this script to get the (nested) list of folders allowing it
18
to split the list of jobs to run into multiple slices each containing a smaller number of jobs. This way, we can bypass
19
the maximum of 256 jobs in a matrix.
20

21
See the `setup` and `run_tests_gpu` jobs defined in the workflow file `.github/workflows/self-scheduled.yml` for more
22
details.
23

24
Usage:
25

26
This script is required to be run under `tests` folder of `transformers` root directory.
27

28
Assume we are under `transformers` root directory:
29
```bash
30
cd tests
31
python ../utils/split_model_tests.py --num_splits 64
32
```
33
"""
34

35
import argparse
36
import os
37

38

39
if __name__ == "__main__":
40
    parser = argparse.ArgumentParser()
41
    parser.add_argument(
42
        "--num_splits",
43
        type=int,
44
        default=1,
45
        help="the number of splits into which the (flat) list of folders will be split.",
46
    )
47
    args = parser.parse_args()
48

49
    tests = os.getcwd()
50
    model_tests = os.listdir(os.path.join(tests, "models"))
51
    d1 = sorted(filter(os.path.isdir, os.listdir(tests)))
52
    d2 = sorted(filter(os.path.isdir, [f"models/{x}" for x in model_tests]))
53
    d1.remove("models")
54
    d = d2 + d1
55

56
    num_jobs = len(d)
57
    num_jobs_per_splits = num_jobs // args.num_splits
58

59
    model_splits = []
60
    end = 0
61
    for idx in range(args.num_splits):
62
        start = end
63
        end = start + num_jobs_per_splits + (1 if idx < num_jobs % args.num_splits else 0)
64
        model_splits.append(d[start:end])
65
    print(model_splits)
66

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.