# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script is used to get the list of test folders (every folder under `tests/models`, followed by the other folders
directly under `tests`) and split that list into `--num_splits` slices.
The main use case is a GitHub Actions workflow file calling this script to get the (nested) list of folders, allowing it
to split the jobs to run into multiple slices, each containing a smaller number of jobs. This way, we can bypass the
maximum of 256 jobs in a matrix.

See the `setup` and `run_tests_gpu` jobs defined in the workflow file `.github/workflows/self-scheduled.yml` for more
details.

Usage:

This script must be run from the `tests` folder of the `transformers` root directory.

Assuming we are in the `transformers` root directory:
```bash
cd tests
python ../utils/split_model_tests.py --num_splits 64
```
"""

import argparse
import os


def list_test_folders(tests_dir):
    """Return the test folders found under `tests_dir`, one entry per future CI job.

    The result lists every subdirectory of `tests_dir/models` as `"models/<name>"` (sorted), followed by every other
    direct subdirectory of `tests_dir` except `models` itself (sorted). Plain files are ignored. Model entries always
    use `/` as separator, so the printed output does not depend on the platform.

    Args:
        tests_dir (`str`): Path to the `tests` folder of the repository.

    Returns:
        `List[str]`: Folder names, relative to `tests_dir`.
    """
    models_dir = os.path.join(tests_dir, "models")
    model_folders = sorted(
        f"models/{name}" for name in os.listdir(models_dir) if os.path.isdir(os.path.join(models_dir, name))
    )
    # `models` itself is already expanded above, so it must not also appear as a single job.
    other_folders = sorted(
        name for name in os.listdir(tests_dir) if name != "models" and os.path.isdir(os.path.join(tests_dir, name))
    )
    return model_folders + other_folders


def split_into_slices(items, num_splits):
    """Split `items` into `num_splits` contiguous slices whose sizes differ by at most one.

    The first `len(items) % num_splits` slices receive one extra element. When `num_splits` is larger than
    `len(items)`, the trailing slices are empty lists.

    Args:
        items (`List`): The list to split.
        num_splits (`int`): Number of slices to produce; must be at least 1.

    Returns:
        `List[List]`: Exactly `num_splits` lists whose concatenation equals `items`.

    Raises:
        `ValueError`: If `num_splits` is smaller than 1.
    """
    if num_splits < 1:
        raise ValueError(f"`num_splits` must be a positive integer, got {num_splits}.")

    base_size, remainder = divmod(len(items), num_splits)
    slices = []
    end = 0
    for idx in range(num_splits):
        start = end
        # The first `remainder` slices each absorb one of the leftover items.
        end = start + base_size + (1 if idx < remainder else 0)
        slices.append(items[start:end])
    return slices


def main():
    """Parse the command line and print the nested list of folder slices."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_splits",
        type=int,
        default=1,
        help="the number of splits into which the (flat) list of folders will be split.",
    )
    args = parser.parse_args()
    if args.num_splits < 1:
        # Without this check, 0 crashes with `ZeroDivisionError` and negative values silently print `[]`.
        parser.error(f"--num_splits must be a positive integer, got {args.num_splits}.")

    # The script is expected to be launched from inside the `tests` folder.
    folders = list_test_folders(os.getcwd())
    print(split_into_slices(folders, args.num_splits))


if __name__ == "__main__":
    main()