Model Training Patterns - Distribution Strategy

With machine learning problems, as the size of models and data increases, the computation and memory demands increase proportionately.

To tackle this, we can use one of the following strategies:

  1. Data parallelism (synchronous/asynchronous) - different workers train on different subsets of the training data
  2. Model parallelism - different workers carry out computation for different parts of the model

Data parallelism

Data parallelism is useful when the computation involved per weight is high.

Synchronous training

  1. For each step of SGD, a mini-batch of data is split across the workers
  2. Each worker runs a forward and backward pass on its subset to compute gradients for every parameter of the model
  3. These gradients are collected and aggregated (typically averaged) into a single gradient update for each parameter of the model
  4. The most recent copy of the model is held centrally (conceptually, a central server), where the aggregated update is applied to all parameters of the model; in practice, MirroredStrategy performs this aggregation with an all-reduce across devices rather than through a central server
  5. Once the parameters are updated, the new model is sent back to the workers for the next SGD step on the next mini-batch of data

```python
# synchronous
import tensorflow as tf

mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
  # model parameters are created as mirrored variables instead of regular variables:
  # each variable is replicated on every device (replica) and kept in sync
  model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, input_shape=(5,)),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(1)
  ])
  model.compile(loss='mse', optimizer='sgd')

# train_dataset is assumed to be a batched tf.data.Dataset of (features, label)
# pairs with 5 features; each global batch is split across the replicas
model.fit(train_dataset, epochs=2)
model.evaluate(train_dataset)
```
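
The five steps above can be traced without any framework. The sketch below is a hypothetical, single-process simulation (the helpers shard, local_gradients and sync_step are illustrative names, not TensorFlow APIs) for a toy linear model y = w * x + b with MSE loss. It shows why synchronous training is equivalent to ordinary mini-batch SGD: the size-weighted average of the workers' gradients equals the gradient of the whole mini-batch.

```python
# Framework-free simulation of synchronous data-parallel SGD (illustrative only)
# for a 1-D linear model y = w * x + b trained with mean squared error.

def shard(batch, num_workers):
    """Split a mini-batch into num_workers contiguous, near-equal shards."""
    size, extra = divmod(len(batch), num_workers)
    shards, start = [], 0
    for i in range(num_workers):
        end = start + size + (1 if i < extra else 0)
        shards.append(batch[start:end])
        start = end
    return shards


def local_gradients(params, examples):
    """One worker's forward + backward pass: gradient of the mean MSE on its shard."""
    w, b = params
    errors = [(w * x + b - y, x) for x, y in examples]
    dw = sum(2 * e * x for e, x in errors) / len(examples)
    db = sum(2 * e for e, _ in errors) / len(examples)
    return dw, db


def sync_step(params, batch, num_workers, lr):
    """One synchronous step: every worker finishes before the single update."""
    shards = [s for s in shard(batch, num_workers) if s]  # drop empty shards
    grads = [local_gradients(params, s) for s in shards]  # would run in parallel
    # aggregate: the size-weighted average equals the full mini-batch gradient
    dw = sum(g[0] * len(s) for g, s in zip(grads, shards)) / len(batch)
    db = sum(g[1] * len(s) for g, s in zip(grads, shards)) / len(batch)
    w, b = params
    return (w - lr * dw, b - lr * db)  # this new copy goes back to every worker


data = [(float(x), 3.0 * x + 1.0) for x in range(8)]  # toy data from y = 3x + 1
params = (0.0, 0.0)
for _ in range(2000):
    params = sync_step(params, data, num_workers=4, lr=0.02)
print(params)  # approaches (3.0, 1.0)
```

Because every step waits for all shards, the result does not depend on how many workers split the batch; only the wall-clock time would change on real hardware.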

Asynchronous training

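Asynchronous training follows almost the same steps as synchronous training. The difference is that the workers are not put on hold while the parameters are being updated. Instead, it uses a parameter server architecture: the parameter servers hold the model parameters, and the model is updated as soon as a gradient update is received from any one of the workers, without waiting for the others. A slow worker therefore does not stall the rest, but a worker may compute its gradients on a slightly stale copy of the parameters.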

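```python
# asynchronous
import tensorflow as tf

# assumes the cluster (chief, workers, parameter servers) is described
# by the TF_CONFIG environment variable on every task
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
async_strategy = tf.distribute.ParameterServerStrategy(cluster_resolver)
with async_strategy.scope():
  # variables are placed on the parameter servers; workers push updates independently
  model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, input_shape=(5,)),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(1)
  ])
  model.compile(loss='mse', optimizer='sgd')

# train_dataset is assumed to be a repeated tf.data.Dataset; parameter server
# training in Keras requires an explicit steps_per_epoch (100 is a placeholder)
model.fit(train_dataset, epochs=2, steps_per_epoch=100)
model.evaluate(train_dataset)
```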
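
The effect of dropping the synchronization barrier can be illustrated with a small, framework-free simulation. Everything here is hypothetical: the ParameterServer class, the async_round helper, and the schedule in which all workers pull the parameters before any of them pushes (a worst case for staleness). Each pushed gradient is applied immediately, and the server records how many updates the worker missed since it read the parameters.

```python
class ParameterServer:
    """Hypothetical parameter server: applies every pushed gradient immediately."""

    def __init__(self, params, lr):
        self.params = tuple(params)
        self.lr = lr
        self.version = 0  # number of updates applied so far

    def pull(self):
        return self.params, self.version

    def push(self, grads, read_version):
        """Apply one worker's gradient; return how many updates it missed."""
        staleness = self.version - read_version
        self.params = tuple(p - self.lr * g for p, g in zip(self.params, grads))
        self.version += 1
        return staleness


def local_gradients(params, examples):
    """Gradient of the mean MSE of y = w * x + b on one worker's shard."""
    w, b = params
    errors = [(w * x + b - y, x) for x, y in examples]
    dw = sum(2 * e * x for e, x in errors) / len(examples)
    db = sum(2 * e for e, _ in errors) / len(examples)
    return dw, db


def async_round(server, shards):
    """Hypothetical schedule: every worker pulls, then they push one at a time."""
    snapshots = [server.pull() for _ in shards]
    staleness = []
    for (params, version), examples in zip(snapshots, shards):
        grads = local_gradients(params, examples)  # computed on a possibly stale copy
        staleness.append(server.push(grads, version))  # no barrier, no averaging
    return staleness


data = [(float(x), 3.0 * x + 1.0) for x in range(8)]  # toy data from y = 3x + 1
shards = [data[i:i + 2] for i in range(0, 8, 2)]  # 4 workers, 2 examples each
server = ParameterServer((0.0, 0.0), lr=0.005)
print(async_round(server, shards))  # [0, 1, 2, 3]
for _ in range(2000):
    async_round(server, shards)
print(server.params)  # approaches (3.0, 1.0)
```

On this toy problem the stale updates still converge, but the staleness counts show the trade-off: the later a worker pushes, the older the parameters its gradient was computed on.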

To get more out of data parallelism when the model is large, we can increase the mini-batch size. For a fixed number of epochs, a larger mini-batch means fewer training iterations in total, which leads to a faster training process.
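
The reduction in iterations is simple arithmetic: one epoch needs ceil(N / B) steps for N examples and a global batch size B. (With tf.distribute, the global batch is usually the per-replica batch times strategy.num_replicas_in_sync.) The dataset size of 1,000,000 examples below is a hypothetical figure chosen for illustration.

```python
import math


def steps_per_epoch(num_examples, global_batch_size):
    """SGD iterations needed to see every example once (last batch may be partial)."""
    return math.ceil(num_examples / global_batch_size)


# hypothetical dataset of 1,000,000 examples
for batch_size in (256, 1024, 8192):
    print(batch_size, steps_per_epoch(1_000_000, batch_size))
# 256 3907
# 1024 977
# 8192 123
```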

However, increasing the mini-batch size can adversely affect the convergence of SGD. To avoid this, when we increase the batch size, the learning rate should be scaled linearly with it; this helps maintain a low validation error while still decreasing the time taken by distributed training.
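
A minimal sketch of this linear scaling rule follows: the learning rate is multiplied by the same factor as the batch size. The baseline of lr = 0.1 at a batch size of 256 is a hypothetical starting point, not a recommendation.

```python
def scaled_learning_rate(base_lr, base_batch_size, batch_size):
    """Linear scaling rule: grow the learning rate by the same factor as the batch."""
    return base_lr * (batch_size / base_batch_size)


# hypothetical baseline: lr = 0.1 tuned for a batch size of 256
for batch_size in (256, 1024, 8192):
    print(batch_size, scaled_learning_rate(0.1, 256, batch_size))
```

Going from 256 to 8192 examples per batch (32x) would thus raise the learning rate from 0.1 to 3.2.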

Model Parallelism

Model parallelism is useful when the neural network is too large to fit in a single device's memory.

  1. In this case, the model is partitioned across multiple devices.
  2. Each device operates over the same mini-batch of data but carries out computation only for the portion of the model assigned to it.
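
The two steps above can be sketched in plain Python. This is an illustrative simulation, not a TensorFlow API: Device, dense and partition are hypothetical helpers, and a "device" is just an object that holds and runs its own contiguous slice of a sequential model. Every device processes the same mini-batch, but only through its own layers, and then hands the activations to the next device.

```python
def relu(v):
    return max(0.0, v)


def dense(weights, bias, activation=None):
    """Return a tiny fully connected layer as a function of one input vector."""
    def layer(x):
        out = [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]
        return [activation(v) for v in out] if activation else out
    return layer


class Device:
    """Hypothetical device that stores and runs only its own part of the model."""

    def __init__(self, name, layers):
        self.name = name
        self.layers = layers

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def partition(layers, num_devices):
    """Split a sequential model into num_devices contiguous stages."""
    size = -(-len(layers) // num_devices)  # ceiling division
    return [Device(f"device:{i}", layers[i * size:(i + 1) * size]) for i in range(num_devices)]


def model_parallel_forward(devices, batch):
    activations = batch
    for device in devices:
        # the whole mini-batch passes through this device's layers only,
        # then its activations are sent on to the next device
        activations = [device.forward(x) for x in activations]
    return activations


layers = [
    dense([[1, 2], [3, 4], [0, -1]], [0, 0, 1], relu),  # 2 -> 3
    dense([[1, 1, 1], [1, -1, 0]], [0, 0], relu),        # 3 -> 2
    dense([[2, -1]], [0.5]),                             # 2 -> 1
]
devices = partition(layers, num_devices=2)  # device:0 gets 2 layers, device:1 gets 1
print(model_parallel_forward(devices, [[1, 1], [-1, 2]]))  # [[20.5], [16.5]]
```

The partitioned forward pass produces the same result as running all layers on one device; what model parallelism buys is that no single device has to hold every layer's weights.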

ASICs (Application-specific integrated circuits)

ASICs such as TPUs are designed specifically for machine learning workloads and are often faster than GPUs on them, which helps speed up the training process.

To run distributed training jobs on TPUs:

  1. Set up a TPUClusterResolver, which points to the location of the TPUs
  2. Connect to the TPU cluster and initialize the TPU system
  3. Create a TPUStrategy

When working in Google Colab, we don't need to specify a TPU address: passing an empty string as tpu_address lets the resolver find the TPU attached to the runtime.

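```python
import tensorflow as tf

tpu_address = ''  # '' works in Colab; elsewhere, the TPU name or its grpc:// address
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address)
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
tpu_strategy = tf.distribute.TPUStrategy(cluster_resolver)  # non-experimental since TF 2.3
```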

Making the most of data parallelism

To make the most of data parallelism, we need to reduce the time spent on I/O. Since distribution strategies speed up the training computation, it becomes extremely important to keep the data flowing so that the devices are not left idle waiting for input. To achieve this, we can use tf.data.Dataset.prefetch(buffer_size), which overlaps the preprocessing and the model execution of training steps: while the model is executing training step N, the input pipeline is preparing the data for training step N + 1.
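
In TensorFlow this is typically a one-liner such as train_dataset.prefetch(tf.data.AUTOTUNE) (TF 2.4+). The overlap it buys can be demonstrated with a conceptual, standard-library analogue: a background thread fills a bounded buffer with upcoming batches while the training loop consumes the current one. The prefetch, load_batches and train helpers and the 20 ms delays are hypothetical stand-ins for real I/O, preprocessing and training steps; this is not how tf.data is implemented internally.

```python
import queue
import threading
import time


def prefetch(source, buffer_size):
    """Yield items from source while a background thread prepares the next ones."""
    buf = queue.Queue(maxsize=buffer_size)
    end = object()  # sentinel marking the end of the input

    def producer():
        try:
            for item in source:
                buf.put(item)  # blocks while the buffer is full
        finally:
            buf.put(end)  # always signal the consumer, even on error

    threading.Thread(target=producer, daemon=True).start()
    while (item := buf.get()) is not end:
        yield item


def load_batches(n, delay):
    for i in range(n):
        time.sleep(delay)  # simulated read + preprocessing of batch i
        yield i


def train(batches, delay):
    seen = []
    for b in batches:
        time.sleep(delay)  # simulated training step on batch b
        seen.append(b)
    return seen


start = time.perf_counter()
train(load_batches(10, 0.02), 0.02)
sequential = time.perf_counter() - start

start = time.perf_counter()
trained = train(prefetch(load_batches(10, 0.02), buffer_size=1), 0.02)
overlapped = time.perf_counter() - start
print(f"sequential: {sequential:.2f}s, with prefetch: {overlapped:.2f}s")
```

With equal preparation and training costs, prefetching roughly halves the total time, because preparing batch N + 1 is hidden behind the training step on batch N.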
