# SkyPilot task: fine-tune a Llama-2 chatbot on the ShareGPT dataset with FastChat.
envs:
  HF_TOKEN: <your-huggingface-token> # Change to your own huggingface token
  ARTIFACT_BUCKET_NAME: YOUR_OWN_BUCKET_NAME # Change to your own bucket name
  WANDB_API_KEY: "" # Change to your own wandb api key; leave empty to disable wandb
  MODEL_SIZE: "7" # Llama-2 size in billions of params; substituted into model name and output dir
  USE_XFORMERS: "1" # "1" trains with memory-efficient xformers attention (train_xformers.py)

resources:
  accelerators: A100-80GB:8
  disk_size: 1024
  use_spot: true

num_nodes: 1

file_mounts:
  # Bucket-backed mount so checkpoints in /artifacts survive spot preemptions.
  /artifacts:
    name: $ARTIFACT_BUCKET_NAME
    mode: MOUNT

workdir: .

setup: |
  # Download the ShareGPT dataset
  # Change to your OWN dataset if you want to train your own model
  mkdir -p $HOME/data
  wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -O $HOME/data/sharegpt.json

  # Setup the environment (create the conda env only on first run)
  conda activate chatbot
  if [ $? -ne 0 ]; then
    conda create -n chatbot python=3.10 -y
    conda activate chatbot
  fi
  cd ./scripts
  # Use an older version of fastchat to install transformers==4.28.1, as the transformers>=4.31
  # has issues with checkpoint saving -- saving additional large files in the checkpoint folder
  pip install git+https://github.com/lm-sys/FastChat.git@cfc73bf3e13c22ded81e89675e0d7b228cf4b342
  if [ "$USE_XFORMERS" -eq 1 ]; then
    pip install -U xformers
  fi
  # Merge the hardcoded identity questions into the ShareGPT data.
  python hardcoded_questions.py
  python -m fastchat.data.merge --in $HOME/data/sharegpt.json hardcoded.json --out $HOME/data/mydata.json

  # Log in so the gated meta-llama weights can be downloaded at train time.
  python -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

run: |
  cd scripts
  conda activate chatbot
  if [ "$USE_XFORMERS" -eq 1 ]; then
    TRAIN_SCRIPT=train_xformers.py
  else
    TRAIN_SCRIPT=train.py
  fi

  PER_DEVICE_BATCH_SIZE=4
  SEQ_LEN=2048
  NUM_NODES=`echo "$SKYPILOT_NODE_IPS" | wc -l`
  HOST_ADDR=`echo "$SKYPILOT_NODE_IPS" | head -n1`

  # Turn off wandb if no api key is provided.
  # Use -z with a quoted expansion: `[ $WANDB_API_KEY == "" ]` fails when the
  # variable is empty. Export WANDB_MODE so the torchrun child process sees it.
  if [ -z "$WANDB_API_KEY" ]; then
    export WANDB_MODE="offline"
  fi

  # Global batch size is held at 128*512 tokens; gradient_accumulation_steps
  # is derived from it so changing SEQ_LEN/batch/world size keeps it constant.
  torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
    --master_port=12375 \
    --master_addr=$HOST_ADDR \
    --node_rank=${SKYPILOT_NODE_RANK} \
    $TRAIN_SCRIPT \
    --model_name_or_path meta-llama/Llama-2-${MODEL_SIZE}b-hf \
    --data_path $HOME/data/mydata.json \
    --bf16 True \
    --output_dir /artifacts/chatbot/${MODEL_SIZE}b \
    --num_train_epochs 3 \
    --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
    --per_device_eval_batch_size $PER_DEVICE_BATCH_SIZE \
    --gradient_accumulation_steps $((128 * 512 / $SEQ_LEN / $PER_DEVICE_BATCH_SIZE / $NUM_NODES / $SKYPILOT_NUM_GPUS_PER_NODE)) \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 600 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length ${SEQ_LEN} \
    --run_name $SKYPILOT_TASK_ID \
    --gradient_checkpointing True \
    --lazy_preprocess True

  returncode=$?
  exit $returncode