英文:
How to unroll the training loop so that JAX can train multiple steps on GPU/TPU
问题
Here's the translated code portion you requested:
当使用强大的硬件,尤其是TPU时,通常最好进行多步训练。例如,在TensorFlow中,这是可能的。
# NOTE(review): indentation was lost in the paste; the extent of the
# `strategy.scope()` block below follows the usual TensorFlow pattern
# (model, optimizers and metrics created inside the scope) -- confirm.
with strategy.scope():
    model = create_model()
    # The optimizers are wrapped inside-out: AdamW -> SWA -> Lookahead.
    optimizer_inner = AdamW(weight_decay=1e-6)
    optimizer_middle = SWA(optimizer_inner)
    optimizer = Lookahead(optimizer_middle)
    training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
    training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'training_accuracy', dtype=tf.float32)

# Calculate the per-replica batch size, and distribute the `tf.data.Dataset`s
# on each TPU worker.
actual_batch_size = 128
gradient_accumulation_step = 1
batch_size = actual_batch_size * gradient_accumulation_step
# 60000 and 10000 are the hard-coded training / validation example counts.
steps_per_epoch = 60000 // batch_size
validation_steps = 10000 // batch_size
per_replica_batch_size = batch_size // strategy.num_replicas_in_sync
# A global-batch `get_dataset(batch_size, ...)` used to be assigned to
# `train_dataset` here and was overwritten on the very next statement; that
# dead assignment has been removed.
# NOTE(review): newer TensorFlow releases expose this method without the
# `experimental_` prefix -- confirm against the TensorFlow version in use.
train_dataset = strategy.experimental_distribute_datasets_from_function(
    lambda _: get_dataset(per_replica_batch_size, is_training=True))
@tf.function(jit_compile=True)
def train_multiple_steps(iterator, steps):
    """Run `steps` consecutive training steps inside one `tf.function` call.

    Args:
        iterator: iterator over the distributed training dataset; one element
            is consumed per step.
        steps: number of training steps to execute in this call.
    """

    def step_fn(inputs):
        """Compute the loss and apply one optimizer update for one input batch."""
        images, labels = inputs
        with tf.GradientTape() as grad_tape:
            logits = model(images, training=True)
            per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True)
            loss = tf.nn.compute_average_loss(
                per_example_loss, global_batch_size=batch_size)
        variables = model.trainable_variables
        gradients = grad_tape.gradient(loss, variables)
        optimizer.apply_gradients(list(zip(gradients, variables)))
        # `loss` was averaged over the global batch size, so it is scaled by
        # the replica count before being recorded in the metric.
        training_loss.update_state(loss * strategy.num_replicas_in_sync)
        training_accuracy.update_state(labels, logits)

    for _ in tf.range(steps):
        strategy.run(step_fn, args=(next(iterator),))
train_iterator = iter(train_dataset)
# Convert `steps_per_epoch` to a `tf.Tensor` so the `tf.function` won't get
# retraced if the value changes. The value is loop-invariant, so the
# conversion is done once instead of once per epoch.
steps_tensor = tf.convert_to_tensor(steps_per_epoch)
for epoch in range(10):
    print('Epoch: {}/10'.format(epoch))
    train_multiple_steps(train_iterator, steps_tensor)
If you need further assistance or have additional questions, please feel free to ask.
英文:
When using powerful hardware, especially TPUs, it is often preferable to run multiple training steps per call. For example, this is possible in TensorFlow:
# NOTE(review): indentation was lost in the paste; the extent of the
# `strategy.scope()` block below follows the usual TensorFlow pattern
# (model, optimizers and metrics created inside the scope) -- confirm.
with strategy.scope():
    model = create_model()
    # The optimizers are wrapped inside-out: AdamW -> SWA -> Lookahead.
    optimizer_inner = AdamW(weight_decay=1e-6)
    optimizer_middle = SWA(optimizer_inner)
    optimizer = Lookahead(optimizer_middle)
    training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
    training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        'training_accuracy', dtype=tf.float32)

# Calculate the per-replica batch size, and distribute the `tf.data.Dataset`s
# on each TPU worker.
actual_batch_size = 128
gradient_accumulation_step = 1
batch_size = actual_batch_size * gradient_accumulation_step
# 60000 and 10000 are the hard-coded training / validation example counts.
steps_per_epoch = 60000 // batch_size
validation_steps = 10000 // batch_size
per_replica_batch_size = batch_size // strategy.num_replicas_in_sync
# A global-batch `get_dataset(batch_size, ...)` used to be assigned to
# `train_dataset` here and was overwritten on the very next statement; that
# dead assignment has been removed.
# NOTE(review): newer TensorFlow releases expose this method without the
# `experimental_` prefix -- confirm against the TensorFlow version in use.
train_dataset = strategy.experimental_distribute_datasets_from_function(
    lambda _: get_dataset(per_replica_batch_size, is_training=True))
@tf.function(jit_compile=True)
def train_multiple_steps(iterator, steps):
    """Run `steps` consecutive training steps inside one `tf.function` call.

    Args:
        iterator: iterator over the distributed training dataset; one element
            is consumed per step.
        steps: number of training steps to execute in this call.
    """

    def step_fn(inputs):
        """Compute the loss and apply one optimizer update for one input batch."""
        images, labels = inputs
        with tf.GradientTape() as grad_tape:
            logits = model(images, training=True)
            per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
                labels, logits, from_logits=True)
            loss = tf.nn.compute_average_loss(
                per_example_loss, global_batch_size=batch_size)
        variables = model.trainable_variables
        gradients = grad_tape.gradient(loss, variables)
        optimizer.apply_gradients(list(zip(gradients, variables)))
        # `loss` was averaged over the global batch size, so it is scaled by
        # the replica count before being recorded in the metric.
        training_loss.update_state(loss * strategy.num_replicas_in_sync)
        training_accuracy.update_state(labels, logits)

    for _ in tf.range(steps):
        strategy.run(step_fn, args=(next(iterator),))
train_iterator = iter(train_dataset)
# Convert `steps_per_epoch` to a `tf.Tensor` so the `tf.function` won't get
# retraced if the value changes. The value is loop-invariant, so the
# conversion is done once instead of once per epoch.
steps_tensor = tf.convert_to_tensor(steps_per_epoch)
for epoch in range(10):
    print('Epoch: {}/10'.format(epoch))
    train_multiple_steps(train_iterator, steps_tensor)
In JAX or Flax, however, I haven't seen a complete working example of doing so. I guess it would be something like this:
@jax.jit
def train_for_n_steps(train_state, batches):
    """Apply `train_step_fn` once per element of `batches` and return the result.

    `batches` is iterated with a plain Python loop, so under `jax.jit` every
    iteration is traced separately.
    """
    current_state = train_state
    for one_batch in batches:
        current_state = train_step_fn(current_state, one_batch)
    return current_state
However, when I try to adapt the complete example, I am not sure how to create multiple batches to pass in. Here is a working example that uses a GPU but trains only one step per call. The relevant code is probably this loop:
for step, batch in enumerate(train_ds.as_numpy_iterator()):
    # One optimization step, then fold this batch into the running metrics.
    state = train_step(state, batch)
    state = compute_metrics(state=state, batch=batch)

    # Everything below runs only on the last step of a training epoch.
    if (step + 1) % num_steps_per_epoch != 0:
        continue

    # Record the epoch's training metrics, then reset them for the next epoch.
    for metric, value in state.metrics.compute().items():
        metrics_history[f'train_{metric}'].append(value)
    state = state.replace(metrics=state.metrics.empty())

    # Accumulate test-set metrics on `test_state` so that `state` is not
    # reassigned by evaluation (assumes `compute_metrics` returns a new state
    # rather than mutating its argument -- confirm).
    test_state = state
    for test_batch in test_ds.as_numpy_iterator():
        test_state = compute_metrics(state=test_state, batch=test_batch)
    for metric, value in test_state.metrics.compute().items():
        metrics_history[f'test_{metric}'].append(value)

    print(f"train epoch: {(step+1) // num_steps_per_epoch}, "
          f"loss: {metrics_history['train_loss'][-1]}, "
          f"accuracy: {metrics_history['train_accuracy'][-1] * 100}")
    print(f"test epoch: {(step+1) // num_steps_per_epoch}, "
          f"loss: {metrics_history['test_loss'][-1]}, "
          f"accuracy: {metrics_history['test_accuracy'][-1] * 100}")
My goal is to unroll 5 training steps into a single jitted call.
Any suggestions are welcome.
答案1
得分: 1
你可以使用 more_itertools.chunked
来实现类似以下的代码:
# `chunked` (from `more_itertools`) needs the chunk size as its second
# argument and yields plain lists, so `enumerate` supplies the `step` index.
# NOTE: if the number of batches is not a multiple of 5, the final chunk is
# shorter than 5.
for step, five_batches in enumerate(chunked(train_ds.as_numpy_iterator(), 5)):
    state = five_steps(state, five_batches)
然后进行展开:
@jax.jit
def five_steps(state, batches):
    """Apply `train_step` once per element of `batches` and return the new state.

    `batches` is iterated with a plain Python loop, so the loop is unrolled
    while the function is traced: one `train_step` per batch in the sequence.
    """
    current_state = state
    for one_batch in batches:
        current_state = train_step(current_state, one_batch)
    return current_state
这种方法之所以有效是因为 batches
的长度不依赖于数据,所以追踪时循环将执行 5 次。
这可能会使 jit 编译时间比您期望的长,所以更优选但更复杂的方法是将批次打包成 [N x batch_size x ...]
张量,然后使用 scan 来在输入上循环执行更新函数。
英文:
You could use more_itertools.chunked
to get something like this:
# `chunked` (from `more_itertools`) needs the chunk size as its second
# argument and yields plain lists, so `enumerate` supplies the `step` index.
# NOTE: if the number of batches is not a multiple of 5, the final chunk is
# shorter than 5.
for step, five_batches in enumerate(chunked(train_ds.as_numpy_iterator(), 5)):
    state = five_steps(state, five_batches)
Then do the unrolling
@jax.jit
def five_steps(state, batches):
    """Apply `train_step` once per element of `batches` and return the new state.

    `batches` is iterated with a plain Python loop, so the loop is unrolled
    while the function is traced: one `train_step` per batch in the sequence.
    """
    current_state = state
    for one_batch in batches:
        current_state = train_step(current_state, one_batch)
    return current_state
The reason this works is that batches
has a length that isn't data dependent, so the loop will just get executed 5 times during tracing.
This will likely make jitting take much longer than you want, so the preferred but more difficult way is to pack the batches into [N x batch_size x ...]
tensors, then use `jax.lax.scan` to loop your update function over the inputs.
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论