skypilot

Форк
0
53 строки · 1.9 Кб
1
# Example script that uses Keras and tf.distribute.MultiWorkerMirroredStrategy to train a model on multiple workers.
2
#
3
# Usage (two separate processes):
4
#   TF_CONFIG='{"cluster":{"worker":["localhost:12345","localhost:23456"]},"task":{"type":"worker","index":0}}' python train.py
5
#   TF_CONFIG='{"cluster":{"worker":["localhost:12345","localhost:23456"]},"task":{"type":"worker","index":1}}' python train.py
6

7
import json
8
import os
9

10
import numpy as np
11
import tensorflow as tf
12

13

14
def mnist_dataset(batch_size):
    """Build an endlessly repeating, shuffled, batched MNIST training dataset.

    Args:
        batch_size: Number of examples per emitted batch.

    Returns:
        A `tf.data.Dataset` yielding `(images, labels)` batches, where the
        images have been scaled by 1/255 and the labels cast to int64.
    """
    # Only the training split is used; the test split is discarded.
    (images, labels), _ = tf.keras.datasets.mnist.load_data()

    # The raw pixels are uint8 in [0, 255]; dividing by a float32 scalar
    # yields float32 values in [0, 1].
    images = images / np.float32(255)
    labels = labels.astype(np.int64)

    pipeline = tf.data.Dataset.from_tensor_slices((images, labels))
    pipeline = pipeline.shuffle(60000)
    # Repeat indefinitely: the caller bounds training with steps_per_epoch.
    pipeline = pipeline.repeat()
    return pipeline.batch(batch_size)
23

24

25
def build_and_compile_cnn_model():
    """Create and compile a small convolutional classifier for 28x28 inputs.

    Returns:
        A compiled `tf.keras.Sequential` model that outputs 10 logits and is
        trained with SGD on sparse categorical cross-entropy.
    """
    keras = tf.keras
    stack = [
        keras.layers.InputLayer(input_shape=(28, 28)),
        # Add an explicit single channel axis for the convolution.
        keras.layers.Reshape(target_shape=(28, 28, 1)),
        keras.layers.Conv2D(32, 3, activation='relu'),
        keras.layers.Flatten(),
        keras.layers.Dense(128, activation='relu'),
        # No activation on the final layer: the loss consumes raw logits.
        keras.layers.Dense(10),
    ]
    cnn = keras.Sequential(stack)

    logits_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    sgd = keras.optimizers.SGD(learning_rate=0.001)
    cnn.compile(loss=logits_loss, optimizer=sgd, metrics=['accuracy'])
    return cnn
39

40

41
# Number of examples each worker processes per training step.
per_worker_batch_size = 64

# TF_CONFIG carries the cluster layout as JSON. When it is unset or empty,
# fall back to an empty config so the script runs as a single worker instead
# of failing with a bare KeyError.
tf_config = json.loads(os.environ.get('TF_CONFIG') or '{}')
_cluster = tf_config.get('cluster', {})
# A declared chief takes part in training alongside the regular workers, so
# both roles count towards the global batch size. Never drop below one.
num_workers = max(
    1, sum(len(_cluster.get(role, [])) for role in ('chief', 'worker')))

# Keep strategy creation ahead of dataset and model construction, matching
# the original statement order.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# The dataset is batched at the global size; the strategy splits each batch
# across the participating workers.
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)

# Model building and compiling must happen inside the strategy scope so the
# variables are created under the distribution strategy.
with strategy.scope():
    multi_worker_model = build_and_compile_cnn_model()

# The dataset repeats forever, so steps_per_epoch defines the epoch length.
multi_worker_model.fit(multi_worker_dataset, epochs=10, steps_per_epoch=70)
54

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.