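# Vicuna-style chatbot fine-tuning on Llama-2, as a SkyPilot task YAML.
# A hypothetical launch command (the file name train.yaml and cluster name
# "vicuna" are illustrative, not from the original):
#   sky launch -c vicuna train.yaml --env HF_TOKEN=<token> --env ARTIFACT_BUCKET_NAME=<bucket>
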
envs:
  HF_TOKEN: <your-huggingface-token> # Change to your own huggingface token
  ARTIFACT_BUCKET_NAME: YOUR_OWN_BUCKET_NAME # Change to your own bucket name
  WANDB_API_KEY: "" # Change to your own wandb api key
  MODEL_SIZE: 7
  USE_XFORMERS: 1

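# MODEL_SIZE feeds the meta-llama/Llama-2-${MODEL_SIZE}b-hf template in the
# run section; e.g. `--env MODEL_SIZE=13` at launch would fine-tune the 13B
# model instead (13 is an illustrative value, not from the original).
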
resources:
  accelerators: A100-80GB:8
  disk_size: 1024
  use_spot: true

num_nodes: 1

file_mounts:
  /artifacts:
    name: $ARTIFACT_BUCKET_NAME
    mode: MOUNT

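# The bucket mounted at /artifacts persists checkpoints outside the VM, which
# is what makes `use_spot: true` above reasonable: a preempted run can be
# relaunched and pick up from the last saved checkpoint (assuming the training
# script looks for one in its output_dir).
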
workdir: .

setup: |
  # Download the ShareGPT dataset.
  # Change to your OWN dataset if you want to train your own model.
  mkdir -p $HOME/data
  wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -O $HOME/data/sharegpt.json

  # Set up the environment
  conda activate chatbot
  if [ $? -ne 0 ]; then
    conda create -n chatbot python=3.10 -y
    conda activate chatbot
  fi
  cd ./scripts
  # Use an older version of FastChat that pins transformers==4.28.1, as
  # transformers>=4.31 has issues with checkpoint saving -- it saves
  # additional large files in the checkpoint folder.
  pip install git+https://github.com/lm-sys/FastChat.git@cfc73bf3e13c22ded81e89675e0d7b228cf4b342
  if [ $USE_XFORMERS -eq 1 ]; then
    pip install -U xformers
  fi
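  # hardcoded_questions.py (from this repo's scripts/) generates identity Q&A
  # pairs so the fine-tuned model can answer questions about itself; they are
  # merged into the ShareGPT data below.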
  python hardcoded_questions.py
  python -m fastchat.data.merge --in $HOME/data/sharegpt.json hardcoded.json --out $HOME/data/mydata.json
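
  # Log in to Hugging Face so the gated meta-llama weights can be downloaded
  # in the run section (assumes the HF_TOKEN account has been granted
  # Llama-2 access).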
  python -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

run: |
  cd scripts
  conda activate chatbot
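  # train_xformers.py is understood to be a variant of train.py that patches
  # attention to use xFormers memory-efficient kernels, cutting GPU memory use.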
  if [ $USE_XFORMERS -eq 1 ]; then
    TRAIN_SCRIPT=train_xformers.py
  else
    TRAIN_SCRIPT=train.py
  fi

  PER_DEVICE_BATCH_SIZE=4
  SEQ_LEN=2048
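  # SKYPILOT_NODE_IPS lists one IP address per line, so `wc -l` counts the
  # nodes and the first line (the rank-0 node) serves as torchrun's
  # rendezvous address.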
  NUM_NODES=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
  HOST_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)

  # Turn off wandb if no API key is provided; quote the variable so the test
  # is valid when it is empty, and export so torchrun's workers see it.
  if [ "$WANDB_API_KEY" == "" ]; then
    export WANDB_MODE="offline"
  fi

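  # The gradient_accumulation_steps expression below keeps the effective
  # global batch fixed at 128 sequences of length 512 (65,536 tokens): with
  # SEQ_LEN=2048 that is 32 sequences per step, so on one 8-GPU node with
  # per-device batch 4 it evaluates to 32 / (4 * 1 * 8) = 1.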
  torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
    --master_port=12375 \
    --master_addr=$HOST_ADDR \
    --node_rank=${SKYPILOT_NODE_RANK} \
    $TRAIN_SCRIPT \
    --model_name_or_path meta-llama/Llama-2-${MODEL_SIZE}b-hf \
    --data_path $HOME/data/mydata.json \
    --bf16 True \
    --output_dir /artifacts/chatbot/${MODEL_SIZE}b \
    --num_train_epochs 3 \
    --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
    --per_device_eval_batch_size $PER_DEVICE_BATCH_SIZE \
    --gradient_accumulation_steps $((128 * 512 / $SEQ_LEN / $PER_DEVICE_BATCH_SIZE / $NUM_NODES / $SKYPILOT_NUM_GPUS_PER_NODE)) \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 600 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length ${SEQ_LEN} \
    --run_name $SKYPILOT_TASK_ID \
    --gradient_checkpointing True \
    --lazy_preprocess True

  returncode=$?
  exit $returncode