skypilot

Форк
0
/
bert_qa_train_eval.yaml 
87 строк · 2.3 Кб
1
name: bert_qa
2

3
---
4

5
name: train
6

7
resources:
8
    accelerators: V100:1
9

10
# Assume your working directory is under `~/transformers`.
11
# To make this example work, please run the following command:
12
# git clone https://github.com/huggingface/transformers.git ~/transformers -b v4.30.1
13
workdir: ~/transformers
14

15
file_mounts:
16
    /checkpoint:
17
        name: test-bert-train-eval # NOTE: Fill in your bucket name
18
        mode: MOUNT
19

20
setup: |
21
    pip install -e .
22
    cd examples/pytorch/question-answering/
23
    pip install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
24
    pip install wandb
25

26
run: |
27
    cd examples/pytorch/question-answering/
28
    python run_qa.py \
29
    --model_name_or_path bert-base-uncased \
30
    --dataset_name squad \
31
    --do_train \
32
    --per_device_train_batch_size 12 \
33
    --learning_rate 3e-5 \
34
    --num_train_epochs 1 \
35
    --max_seq_length 384 \
36
    --doc_stride 128 \
37
    --report_to wandb \
38
    --run_name $SKYPILOT_TASK_ID \
39
    --output_dir /checkpoint/bert_qa/$SKYPILOT_TASK_ID \
40
    --save_total_limit 10 \
41
    --save_steps 1000
42
    echo Model saved to /checkpoint/bert_qa/$SKYPILOT_TASK_ID
43

44
envs:
45
    WANDB_API_KEY: # NOTE: Fill in your wandb key
46

47
---
48

49
name: eval
50

51
resources:
52
    accelerators: T4:1
53

54
workdir: ~/transformers
55

56
file_mounts:
57
    /checkpoint:
58
        name: test-bert-train-eval # NOTE: Fill in your bucket name
59
        mode: MOUNT
60

61
setup: |
62
    pip install -e .
63
    cd examples/pytorch/question-answering/
64
    pip install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
65
    pip install wandb
66

67
run: |
68
    FIRST_TASK_UNIQUE_ID=$(echo "$SKYPILOT_TASK_IDS" | sed -n 1p)
69
    echo Load model from /checkpoint/bert_qa/$FIRST_TASK_UNIQUE_ID
70
    cd examples/pytorch/question-answering/
71
    python run_qa.py \
72
    --model_name_or_path /checkpoint/bert_qa/$FIRST_TASK_UNIQUE_ID \
73
    --dataset_name squad \
74
    --do_eval \
75
    --per_device_train_batch_size 12 \
76
    --learning_rate 3e-5 \
77
    --num_train_epochs 50 \
78
    --max_seq_length 384 \
79
    --doc_stride 128 \
80
    --report_to wandb \
81
    --run_name $SKYPILOT_TASK_ID \
82
    --output_dir /checkpoint/bert_qa/$FIRST_TASK_UNIQUE_ID \
83
    --save_total_limit 10 \
84
    --save_steps 1000
85

86
envs:
87
    WANDB_API_KEY: # NOTE: Fill in your wandb key
88

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.