# skypilot
# 53 lines · 1.9 KB
# Example script that uses Keras and tf.distribute.MultiWorkerMirroredStrategy to train a model on multiple workers.
#
# Usage (run each command in its own process, e.g. in two separate terminals):
# TF_CONFIG='{"cluster":{"worker":["localhost:12345","localhost:23456"]},"task":{"type":"worker","index":0}}' python train.py
# TF_CONFIG='{"cluster":{"worker":["localhost:12345","localhost:23456"]},"task":{"type":"worker","index":1}}' python train.py

import json
import os

import numpy as np
import tensorflow as tf

def mnist_dataset(batch_size):
    """Build a shuffled, infinitely repeating, batched MNIST training pipeline.

    Args:
      batch_size: Number of examples in each batch emitted by the dataset.

    Returns:
      A ``tf.data.Dataset`` of ``(images, labels)`` batches. Images are float32
      values in [0, 1]; labels are int64. Because the dataset repeats forever,
      callers must bound training themselves (e.g. with ``steps_per_epoch``).
    """
    train_split, _unused_test_split = tf.keras.datasets.mnist.load_data()
    images, labels = train_split

    # The raw image arrays are uint8 in [0, 255]; dividing by a float32 scalar
    # rescales them to float32 values in [0, 1].
    images = images / np.float32(255)
    labels = labels.astype(np.int64)

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    # Buffer size presumably chosen to match the 60,000-example MNIST
    # training split, giving a full shuffle.
    dataset = dataset.shuffle(60000)
    dataset = dataset.repeat()
    return dataset.batch(batch_size)


def build_and_compile_cnn_model():
    """Create and compile a small CNN classifier for 28x28 single-channel images.

    When used with a ``tf.distribute`` strategy, call this inside
    ``strategy.scope()`` so the model's variables are created under the
    strategy.

    Returns:
      A compiled ``tf.keras.Sequential`` model whose final layer emits 10
      unnormalized logits per example (hence ``from_logits=True`` in the loss).
    """
    model = tf.keras.Sequential([
        # tf.keras.Input(shape=...) works on both Keras 2 and Keras 3, whereas
        # the `input_shape` argument of InputLayer is deprecated in Keras 3.
        tf.keras.Input(shape=(28, 28)),
        # Append a channel axis so Conv2D receives (28, 28, 1) inputs.
        tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(32, 3, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(10),
    ])
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
        metrics=['accuracy'])
    return model


# Examples processed per worker per step; the global batch scales with the
# number of workers listed in TF_CONFIG.
per_worker_batch_size = 64

# TF_CONFIG describes the cluster and this process's role in it. Fail with an
# actionable message (same exception type as before) when it is missing.
try:
    tf_config = json.loads(os.environ['TF_CONFIG'])
except KeyError as err:
    raise KeyError(
        "TF_CONFIG environment variable is not set; it must describe the "
        "worker cluster, e.g. "
        '{"cluster":{"worker":["localhost:12345","localhost:23456"]},'
        '"task":{"type":"worker","index":0}}') from err
num_workers = len(tf_config['cluster']['worker'])

# MultiWorkerMirroredStrategy parses TF_CONFIG when it is constructed, so the
# variable must already be set at this point.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)

# Model building/compiling must happen inside strategy.scope() so the model's
# variables are created under the distribution strategy.
with strategy.scope():
    multi_worker_model = build_and_compile_cnn_model()

# The dataset repeats indefinitely, so steps_per_epoch bounds each epoch.
multi_worker_model.fit(multi_worker_dataset, epochs=10, steps_per_epoch=70)