How to unroll the training loop so that JAX can run multiple training steps per call on GPU/TPU
Here's the translated code portion you requested:
# Build the model, optimizer stack and metrics under the distribution
# strategy scope so that their variables are created by the strategy.
with strategy.scope():
    model = create_model()
    optimizer_inner = AdamW(weight_decay=1e-6)
    optimizer_middle = SWA(optimizer_inner)
    optimizer = Lookahead(optimizer_middle)
    training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
    training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'training_accuracy', dtype=tf.float32)

# Calculate the per-replica batch size, and distribute the `tf.data.Dataset`s
# on each TPU worker.
actual_batch_size = 128
gradient_accumulation_step = 1
batch_size = actual_batch_size * gradient_accumulation_step
steps_per_epoch = 60000 // batch_size
validation_steps = 10000 // batch_size
per_replica_batch_size = batch_size // strategy.num_replicas_in_sync
# The dataset is created only through the distribution function; a separate
# global-batch dataset assigned to the same name would be discarded at once.
# NOTE(review): newer TensorFlow releases expose this method as
# `strategy.distribute_datasets_from_function` - confirm against the
# installed version.
train_dataset = strategy.experimental_distribute_datasets_from_function(
    lambda _: get_dataset(per_replica_batch_size, is_training=True))


@tf.function
def train_multiple_steps(iterator, steps):
    """Run `steps` consecutive training steps inside one `tf.function` call.

    Args:
        iterator: Iterator over the distributed training dataset; each element
            is an `(images, labels)` pair.
        steps: Number of training steps to run. Pass a `tf.Tensor` so that a
            different value does not trigger a retrace.
    """

    def step_fn(inputs):
        """Run forward pass, backward pass and metric updates on one replica."""
        images, labels = inputs
        with tf.GradientTape() as tape:
            logits = model(images, training=True)
            loss = tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True)
            loss = tf.nn.compute_average_loss(
                loss, global_batch_size=batch_size)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
        # The loss was averaged over the global batch size, so scale it by
        # the replica count before recording it in the metric.
        training_loss.update_state(loss * strategy.num_replicas_in_sync)
        training_accuracy.update_state(labels, logits)

    # `tf.range` keeps this a graph-level loop instead of a Python loop that
    # would be unrolled during tracing.
    for _ in tf.range(steps):
        strategy.run(step_fn, args=(next(iterator),))


train_iterator = iter(train_dataset)
# Convert `steps_per_epoch` to `tf.Tensor` so the `tf.function` won't get
# retraced if the value changes.
for epoch in range(10):
    print('Epoch: {}/10'.format(epoch))
    train_multiple_steps(train_iterator, tf.convert_to_tensor(steps_per_epoch))
If you need further assistance or have additional questions, please feel free to ask.
When using powerful hardware, especially TPUs, it is often preferable to run multiple training steps in a single call to the device. For example, this is possible in TensorFlow:
# Build the model, optimizer stack and metrics under the distribution
# strategy scope so that their variables are created by the strategy.
with strategy.scope():
    model = create_model()
    optimizer_inner = AdamW(weight_decay=1e-6)
    optimizer_middle = SWA(optimizer_inner)
    optimizer = Lookahead(optimizer_middle)
    training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
    training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'training_accuracy', dtype=tf.float32)

# Calculate per replica batch size, and distribute the `tf.data.Dataset`s
# on each TPU worker.
actual_batch_size = 128
gradient_accumulation_step = 1
batch_size = actual_batch_size * gradient_accumulation_step
steps_per_epoch = 60000 // batch_size
validation_steps = 10000 // batch_size
per_replica_batch_size = batch_size // strategy.num_replicas_in_sync
# The dataset is created only through the distribution function; a separate
# global-batch dataset assigned to the same name would be discarded at once.
# NOTE(review): newer TensorFlow releases expose this method as
# `strategy.distribute_datasets_from_function` - confirm against the
# installed version.
train_dataset = strategy.experimental_distribute_datasets_from_function(
    lambda _: get_dataset(per_replica_batch_size, is_training=True))


@tf.function
def train_multiple_steps(iterator, steps):
    """Run `steps` consecutive training steps inside one `tf.function` call.

    Args:
        iterator: Iterator over the distributed training dataset; each element
            is an `(images, labels)` pair.
        steps: Number of training steps to run. Pass a `tf.Tensor` so that a
            different value does not trigger a retrace.
    """

    def step_fn(inputs):
        """The computation to run on each TPU device."""
        images, labels = inputs
        with tf.GradientTape() as tape:
            logits = model(images, training=True)
            loss = tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True)
            loss = tf.nn.compute_average_loss(
                loss, global_batch_size=batch_size)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
        # The loss was averaged over the global batch size, so scale it by
        # the replica count before recording it in the metric.
        training_loss.update_state(loss * strategy.num_replicas_in_sync)
        training_accuracy.update_state(labels, logits)

    # `tf.range` keeps this a graph-level loop instead of a Python loop that
    # would be unrolled during tracing.
    for _ in tf.range(steps):
        strategy.run(step_fn, args=(next(iterator),))


train_iterator = iter(train_dataset)
# Convert `steps_per_epoch` to `tf.Tensor` so the `tf.function` won't get
# retraced if the value changes.
for epoch in range(10):
    print('Epoch: {}/10'.format(epoch))
    train_multiple_steps(train_iterator, tf.convert_to_tensor(steps_per_epoch))
In Jax or Flax, however, I haven't seen a complete working example of doing so. I guess it would be something like
def train_for_n_steps(train_state, batches):
    """Feed every batch in `batches` through `train_step_fn`, in order.

    Args:
        train_state: Training state passed to the first step.
        batches: Iterable of batches; one step is run per element.

    Returns:
        The state returned by the last step, or `train_state` unchanged when
        `batches` is empty.
    """
    current_state = train_state
    for single_batch in batches:
        # Each step consumes the state produced by the previous one.
        current_state = train_step_fn(current_state, single_batch)
    return current_state
However, in my case when I am trying to test the complete example, I am not sure how one can create multiple batches. Here is a working example using GPU without training multiple steps. The relevant code should probably be here:
# Training loop: one optimization step per batch, with metrics recorded and
# reported every `num_steps_per_epoch` steps.
for step, batch in enumerate(train_ds.as_numpy_iterator()):
    # Run optimization steps over training batches and compute batch metrics
    state = train_step(state, batch)  # get updated train state (which contains the updated parameters)
    state = compute_metrics(state=state, batch=batch)  # aggregate batch metrics

    if (step + 1) % num_steps_per_epoch == 0:  # one training epoch has passed
        for metric, value in state.metrics.compute().items():  # compute metrics
            metrics_history[f'train_{metric}'].append(value)  # record metrics
        state = state.replace(metrics=state.metrics.empty())  # reset train_metrics for next training epoch

        # Compute metrics on the test set after each training epoch
        test_state = state
        for test_batch in test_ds.as_numpy_iterator():
            test_state = compute_metrics(state=test_state, batch=test_batch)

        # Record the test metrics; the prints below read the latest
        # `test_loss` / `test_accuracy` entries from `metrics_history`.
        for metric, value in test_state.metrics.compute().items():
            metrics_history[f'test_{metric}'].append(value)

        print(f"train epoch: {(step+1) // num_steps_per_epoch}, "
              f"loss: {metrics_history['train_loss'][-1]}, "
              f"accuracy: {metrics_history['train_accuracy'][-1] * 100}")
        print(f"test epoch: {(step+1) // num_steps_per_epoch}, "
              f"loss: {metrics_history['test_loss'][-1]}, "
              f"accuracy: {metrics_history['test_accuracy'][-1] * 100}")
My goal is to unroll 5 training steps into a single call when training.
Any suggestions are welcomed.
得分: 1
你可以使用 more_itertools.chunked
# Group the batch stream into lists of 5 batches and run one multi-step
# update per group. `chunked` needs the chunk size as its second argument,
# and `enumerate` supplies the `step` counter.
# NOTE: the final chunk may hold fewer than 5 batches when the number of
# batches is not a multiple of 5.
for step, five_batches in enumerate(chunked(train_ds.as_numpy_iterator(), 5)):
    state = five_steps(state, five_batches)
def five_steps(state, batches):
    """Apply `train_step` once per element of `batches` and return the result.

    This is a plain Python loop, so the number of steps equals the number of
    elements in `batches`.

    Args:
        state: Training state passed to the first step.
        batches: Sequence of batches to train on, in order.

    Returns:
        The state returned by the last `train_step` call, or `state`
        unchanged when `batches` is empty.
    """
    updated_state = state
    for one_batch in batches:
        updated_state = train_step(updated_state, one_batch)
    return updated_state
这种方法之所以有效是因为 batches
的长度不依赖于数据,所以追踪时循环将执行 5 次。
这可能会使 jit 编译时间比您期望的长,所以更优选但更复杂的方法是将批次打包成 [N x batch_size x ...]
张量,然后使用 scan 来在输入上循环执行更新函数。
You could use more_itertools.chunked
to get something like this:
# Group the batch stream into lists of 5 batches and run one multi-step
# update per group. `chunked` needs the chunk size as its second argument,
# and `enumerate` supplies the `step` counter.
# NOTE: the final chunk may hold fewer than 5 batches when the number of
# batches is not a multiple of 5.
for step, five_batches in enumerate(chunked(train_ds.as_numpy_iterator(), 5)):
    state = five_steps(state, five_batches)
Then do the unrolling
def five_steps(state, batches):
    """Apply `train_step` once per element of `batches` and return the result.

    This is a plain Python loop, so the number of steps equals the number of
    elements in `batches`.

    Args:
        state: Training state passed to the first step.
        batches: Sequence of batches to train on, in order.

    Returns:
        The state returned by the last `train_step` call, or `state`
        unchanged when `batches` is empty.
    """
    updated_state = state
    for one_batch in batches:
        updated_state = train_step(updated_state, one_batch)
    return updated_state
The reason this works is that batches
has a length that isn't data dependent, so the loop will just get executed 5 times during tracing.
This will likely make jitting take much longer than you want, so the preferred but more difficult way is to pack the batches into [N x batch_size x ...]
tensors, then use `jax.lax.scan` to loop your update function over the inputs.