# SkyPilot task: train the Flax MNIST example on a Cloud TPU VM.
name: tpuvm_mnist2

resources:
  # TPU v2 slice with 8 cores. The extracted source read "tpu-v2-85",
  # which is not a valid TPU shape (it was "tpu-v2-8" followed by a
  # fused line number).
  accelerators: tpu-v2-8

# The setup command. It runs under the working directory.
setup: |
  git clone https://github.com/google/flax.git --branch v0.6.11

  # Reuse the conda env if it already exists (e.g. on a re-launch of the
  # same cluster); otherwise create it and install dependencies.
  if conda activate flax; then
    echo 'conda env exists'
  else
    conda create -n flax python=3.10 -y
    conda activate flax
    # Install TPU-related packages inside a conda env to avoid package conflicts.
    pip install "jax[tpu]==0.4.23" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip install --upgrade clu
    pip install -e flax
    pip install tensorflow tensorflow-datasets
  fi


# The command to run. It runs under the working directory.
run: |
  conda activate flax
  cd flax/examples/mnist
  python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=10