How to unroll the training loop so that Jax can train multiple steps in GPU/TPU

Question

When using powerful hardware, especially TPUs, it is often preferable to run multiple training steps per call to the device, which cuts down the per-step host overhead. In TensorFlow, for example, this is possible:

import tensorflow as tf
# Assumed: these optimizers come from TensorFlow Addons; `strategy`,
# `create_model` and `get_dataset` are defined elsewhere in the script.
from tensorflow_addons.optimizers import AdamW, Lookahead, SWA

with strategy.scope():
  model = create_model()
  optimizer_inner = AdamW(weight_decay=1e-6)
  optimizer_middle = SWA(optimizer_inner)
  optimizer = Lookahead(optimizer_middle)
  training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
  training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'training_accuracy', dtype=tf.float32)

# Calculate per replica batch size, and distribute the `tf.data.Dataset`s
# on each TPU worker.
actual_batch_size = 128
gradient_accumulation_step = 1
batch_size = actual_batch_size * gradient_accumulation_step
steps_per_epoch = 60000 // batch_size
validation_steps = 10000 // batch_size

train_dataset = get_dataset(batch_size, is_training=True)
per_replica_batch_size = batch_size // strategy.num_replicas_in_sync

train_dataset = strategy.experimental_distribute_datasets_from_function(
    lambda _: get_dataset(per_replica_batch_size, is_training=True))

@tf.function(jit_compile=True)
def train_multiple_steps(iterator, steps):
  """Run `steps` training steps inside a single compiled call."""

  def step_fn(inputs):
    """The computation to run on each TPU device."""
    images, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)
    grads = tape.gradient(loss, model.trainable_variables)

    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
    training_loss.update_state(loss * strategy.num_replicas_in_sync)
    training_accuracy.update_state(labels, logits)

  for _ in tf.range(steps):
    strategy.run(step_fn, args=(next(iterator),))

train_iterator = iter(train_dataset)
# Convert `steps_per_epoch` to `tf.Tensor` so the `tf.function` won't get
# retraced if the value changes.
for epoch in range(10):
  print('Epoch: {}/10'.format(epoch))
  train_multiple_steps(train_iterator, tf.convert_to_tensor(steps_per_epoch))

In JAX or Flax, however, I haven't seen a complete working example of this. I guess it would look something like:

import jax

@jax.jit
def train_for_n_steps(train_state, batches):
    # train_step_fn performs a single training update (defined elsewhere).
    for batch in batches:
        train_state = train_step_fn(train_state, batch)
    return train_state

However, when I try to test this on a complete example, I am not sure how to create the multiple batches. Here is a working example on GPU that trains one step at a time. The relevant code is probably this part:

for step, batch in enumerate(train_ds.as_numpy_iterator()):
    # Run optimization steps over training batches and compute batch metrics
    state = train_step(state, batch)  # get updated train state (which contains the updated parameters)
    state = compute_metrics(state=state, batch=batch)  # aggregate batch metrics

    if (step + 1) % num_steps_per_epoch == 0:  # one training epoch has passed
        for metric, value in state.metrics.compute().items():  # compute metrics
            metrics_history[f'train_{metric}'].append(value)  # record metrics
        state = state.replace(metrics=state.metrics.empty())  # reset train_metrics for next training epoch

        # Compute metrics on the test set after each training epoch
        test_state = state
        for test_batch in test_ds.as_numpy_iterator():
            test_state = compute_metrics(state=test_state, batch=test_batch)

        for metric, value in test_state.metrics.compute().items():
            metrics_history[f'test_{metric}'].append(value)

        print(f"train epoch: {(step+1) // num_steps_per_epoch}, "
              f"loss: {metrics_history['train_loss'][-1]}, "
              f"accuracy: {metrics_history['train_accuracy'][-1] * 100}")
        print(f"test epoch: {(step+1) // num_steps_per_epoch}, "
              f"loss: {metrics_history['test_loss'][-1]}, "
              f"accuracy: {metrics_history['test_accuracy'][-1] * 100}")

My goal is to unroll 5 training steps per compiled call.

Any suggestions are welcome.

Answer 1

Score: 1

You could use more_itertools.chunked to group the batches five at a time, something like this:

from more_itertools import chunked
for step, five_batches in enumerate(chunked(train_ds.as_numpy_iterator(), 5)):
    state = five_steps(state, five_batches)
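If you would rather not add the more_itertools dependency, the chunking step can be sketched with the standard library alone. This is a minimal sketch under assumptions: `chunked_stdlib` and its `drop_last` flag are hypothetical names, and the integer "batches" stand in for the dicts that `train_ds.as_numpy_iterator()` yields. The `drop_last` option matters because a short final chunk hands the jitted function a list of a different length, which makes JAX trace and compile it again.

```python
import itertools
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def chunked_stdlib(iterable: Iterable[T], n: int, drop_last: bool = False) -> Iterator[List[T]]:
    """Yield successive lists of up to `n` items (hypothetical stand-in for more_itertools.chunked)."""
    it = iter(iterable)
    while True:
        chunk = list(itertools.islice(it, n))
        if not chunk:
            return
        if drop_last and len(chunk) < n:
            # A shorter final chunk would force a retrace of the jitted multi-step function.
            return
        yield chunk


# Toy stand-in for train_ds.as_numpy_iterator(): 12 "batches".
batches = range(12)
chunks = list(chunked_stdlib(batches, 5))
full_chunks = list(chunked_stdlib(batches, 5, drop_last=True))

for step, five_batches in enumerate(full_chunks):
    print(step, five_batches)
```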

Then do the unrolling inside a jitted function:

import jax

@jax.jit
def five_steps(state, batches):
    # train_step is the single-step update from the question's code.
    for batch in batches:
        state = train_step(state, batch)
    return state

This works because batches is a Python list whose length (5) does not depend on the data, so during tracing the loop simply executes 5 times and the compiled function contains 5 unrolled copies of train_step.
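To see the unrolled version end to end, here is a self-contained toy sketch. It assumes a hypothetical linear-regression model with a plain SGD `train_step` operating on a dict of parameters, rather than the Flax `TrainState` and MNIST model from the question, and it checks that one call to the jitted `five_steps` matches five separate `train_step` calls.

```python
import jax
import jax.numpy as jnp
import numpy as np

LEARNING_RATE = 0.1  # hypothetical hyperparameter for this toy model


def loss_fn(params, batch):
    # Mean squared error of a toy linear model: pred = x @ w + b.
    pred = batch["x"] @ params["w"] + params["b"]
    return jnp.mean((pred - batch["y"]) ** 2)


def train_step(params, batch):
    # One plain SGD update; a stand-in for the Flax train_step in the question.
    grads = jax.grad(loss_fn)(params, batch)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


@jax.jit
def five_steps(params, batches):
    # `batches` is a Python list, so its length is static: tracing runs this loop
    # once per element and the compiled program holds that many copies of train_step.
    for batch in batches:
        params = train_step(params, batch)
    return params


# Five synthetic batches of 32 examples each (toy data, not MNIST).
rng = np.random.default_rng(0)
true_w = np.array([[2.0], [-1.0]], dtype=np.float32)
five_batches = []
for _ in range(5):
    x = rng.normal(size=(32, 2)).astype(np.float32)
    five_batches.append({"x": x, "y": x @ true_w})

init_params = {"w": jnp.zeros((2, 1)), "b": jnp.zeros((1,))}
unrolled = five_steps(init_params, five_batches)

# Reference: the same five updates applied one call at a time.
sequential = init_params
for batch in five_batches:
    sequential = train_step(sequential, batch)
```

Each extra step in the list adds another traced copy of `train_step`, which is where the compile-time cost mentioned next comes from.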

This will likely make JIT compilation take much longer than you want, because the traced program grows with every unrolled step. The preferred, though more involved, way is to pack the batches into arrays of shape [N, batch_size, ...] and then use jax.lax.scan to loop your update function over the leading axis.
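A minimal sketch of the scan approach, using the same kind of hypothetical toy model (linear regression, plain SGD, parameters in a dict) instead of the question's Flax `TrainState`. The helper `stack_batches` is a hypothetical name: it stacks N batch dicts into one dict of `[N, batch_size, ...]` arrays, and `jax.lax.scan` then runs `train_step` once per slice of the leading axis while tracing the loop body only once.

```python
import jax
import jax.numpy as jnp
import numpy as np

LEARNING_RATE = 0.1  # hypothetical hyperparameter for this toy model


def loss_fn(params, batch):
    # Mean squared error of a toy linear model: pred = x @ w + b.
    pred = batch["x"] @ params["w"] + params["b"]
    return jnp.mean((pred - batch["y"]) ** 2)


def train_step(params, batch):
    # One plain SGD update; a stand-in for the Flax train_step in the question.
    grads = jax.grad(loss_fn)(params, batch)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


def stack_batches(batches):
    # List of N batch dicts -> one dict whose arrays have shape [N, batch_size, ...].
    return jax.tree_util.tree_map(lambda *xs: np.stack(xs), *batches)


@jax.jit
def n_steps(params, stacked):
    def body(carry, batch):
        # scan passes one slice of the leading (N) axis per iteration;
        # this body is traced once regardless of N.
        return train_step(carry, batch), None

    params, _ = jax.lax.scan(body, params, stacked)
    return params


# Five synthetic batches of 32 examples each (toy data, not MNIST).
rng = np.random.default_rng(0)
true_w = np.array([[2.0], [-1.0]], dtype=np.float32)
batches = []
for _ in range(5):
    x = rng.normal(size=(32, 2)).astype(np.float32)
    batches.append({"x": x, "y": x @ true_w})

stacked = stack_batches(batches)
init_params = {"w": jnp.zeros((2, 1)), "b": jnp.zeros((1,))}
scanned = n_steps(init_params, stacked)

# Reference: the same five updates applied one call at a time.
sequential = init_params
for batch in batches:
    sequential = train_step(sequential, batch)
```

`jax.lax.scan` also accepts an `unroll` argument (for example `jax.lax.scan(body, params, stacked, unroll=5)`), which unrolls that many iterations inside each loop body and sits between the two approaches. If the input is an already-batched `tf.data.Dataset`, calling `.batch(5, drop_remainder=True)` on it again should yield elements already shaped `[5, batch_size, ...]`, so the manual stacking step may not be needed.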

  • Posted on May 7, 2023 at 10:47:46
  • Please keep this link when reposting: https://go.coder-hub.com/76191996.html