How to vectorize JAX functions using jit compilation and vmap auto-vectorization
Question
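You can use jit and vmap in JAX to vectorize and speed up the computation, but the index tuples first have to be turned into a rectangular array, because vmap maps over array axes and cannot iterate over tuples of different lengths. One way to modify the code is as follows:
import jax.numpy as jnp
from jax import vmap, jit
import numpy as np

def compute_metrics(i, X, Y):
    # i is one row of column indices, padded with -1 to a fixed length
    mask = i >= 0
    diffs = jnp.abs(X[:, i] - Y[:, i])  # shape (n, k); padded columns are masked below
    total = jnp.sum(jnp.where(mask, diffs, 0.0))
    count = jnp.sum(mask) * X.shape[0]
    # a row of only -1 stands for an empty tuple and yields 0.0
    return jnp.where(count > 0, total / jnp.maximum(count, 1), 0.0)

# data
X = np.random.rand(600, 10)
Y = np.random.rand(600, 10)
# indices ((7, 8), (1, 7, 9), (), (1,), ()) padded with -1
idxs = jnp.array([[7, 8, -1], [1, 7, 9], [-1, -1, -1], [1, -1, -1], [-1, -1, -1]])
# map over the rows of idxs while sharing X and Y
print(jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs, X, Y))
In this modified code, compute_metrics takes one fixed-length row of indices, masks out the -1 padding, and returns 0.0 for a row that stands for an empty tuple, which avoids the IndexError. The data-dependent "if i:" is replaced by jnp.where, because Python control flow cannot branch on a traced value. Finally, vmap is applied with in_axes=(0, None, None) so that it runs over the rows of idxs while X and Y are shared, and jit wraps the whole vmapped function.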
How can I use jit and vmap in JAX to vectorize and speed up the following computation:
@jit
def distance(X, Y):
    """Compute distance between two matrices X and Y.

    Args:
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        float: distance
    """
    return jnp.mean(jnp.abs(X - Y))

@jit
def compute_metrics(idxs, X, Y):
    results = []
    # Iterate over idxs
    for i in idxs:
        if i:
            results.append(distance(X[:, i], Y[:, i]))
    return results

# data
X = np.random.rand(600, 10)
Y = np.random.rand(600, 10)
# indices
idxs = ((7,8), (1,7,9), (), (1), ())
# call the regular function
print(compute_metrics(idxs, X, Y)) # works
# call the function with vmap
print(vmap(compute_metrics, in_axes=(None, 0, 0))(idxs, X, Y)) # doesn't work
I followed the JAX website and tutorials, but I can't figure out how to make this work. The non-vmap version works. However, I get an IndexError for the vmap version (last line above) that looks like this:
jax._src.traceback_util.UnfilteredStackTrace: IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.
Any idea how I could get this to work? Also idxs might change and be any arbitrary valid combination of indices e.g.
idxs = ((1,3,4,5), (3,9), (3,2,5), (), (5,8))
As explained above, I tried the above version with and without vmap and couldn't get the latter, vmap, version to work.
Answer 1
Score: 1
I don't think vmap is going to work with a tuple of scalars. What you need is to put the indices into an array and vmap over it.
I am not sure this solution will satisfy you, because we have to get rid of the empty index tuples ().
idxs_pairs = jnp.array([[7,8],[7,9]]) # put the indices pairs into array

@jit
def distance(X, Y):
    """Compute distance between two matrices X and Y.

    Args:
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        float: distance
    """
    return jnp.mean(jnp.abs(X - Y))

@jit
def compute_metrics(idxs, X, Y):
    return distance(X[:,idxs], Y[:,idxs])

vmap(compute_metrics, in_axes=(0, None, None))(idxs_pairs, X, Y)
You can also jit everything:
jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)
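For context on the IndexError in the question, here is a minimal sketch of what each in_axes choice hands to the mapped function. The helper names (peek_row, peek_shared) and the shape-recording lists are only for illustration and are not part of the original code. With in_axes=0 on X, the function sees a single row of X, so a two-index expression such as X[:, i] has one index too many; with in_axes=(0, None), X stays two-dimensional and only the index array is split up.

```python
import jax
import jax.numpy as jnp
import numpy as np

X = jnp.asarray(np.random.rand(600, 10))
idxs_pairs = jnp.array([[7, 8], [7, 9]])

row_shapes = []     # shapes seen when mapping over X itself
shared_shapes = []  # shapes seen when mapping over the index array only

def peek_row(x):
    # Illustrative helper: the mapped axis is removed, so x is 1-D here.
    row_shapes.append(x.shape)
    return x.sum()

def peek_shared(idx, x):
    # Illustrative helper: idx is one row of indices, x is the full matrix.
    shared_shapes.append((idx.shape, x.shape))
    return jnp.mean(x[:, idx])

jax.vmap(peek_row, in_axes=0)(X)
out = jax.vmap(peek_shared, in_axes=(0, None))(idxs_pairs, X)

print(row_shapes)     # per-example shape when X is mapped
print(shared_shapes)  # per-example shapes when only the indices are mapped
print(out.shape)      # one result per row of idxs_pairs
```

The shapes are static, so they are available while vmap traces the function, which is why plain Python lists can record them.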
Update 19/05/2023:
The question is how to make this more general, that is, how to support a variable number of indices per entry. The problem is that JAX needs static shapes for inputs and outputs, so we need a trick to handle this. The most obvious one is to use jnp.where to express the conditional behavior; the other choice is jax.lax.cond. As before, we put the indices into an array, but this time we use -1 as a special flag marking a padding entry (this is like zero-padding, but with -1 instead of 0). Because arrays have a static shape, the number of columns in idxs_pairs must equal the largest number of indices in any entry.
For example:
# 7, 8, -1 -> we only use indices: 7, 8
# 7, 9, -1 -> we only use indices: 7, 9
# 7, 5, 6 -> we use indices: 7, 5, 6
# 1, -1, -1 -> we use only index: 1
idxs_pairs = jnp.array([[7, 8, -1], [7, 9, -1], [7, 5, 6], [1, -1, -1]]) # put the indices pairs into array
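Since idxs may be any combination of index tuples, the padded array does not have to be written by hand. The following is a small sketch of one way to build it; pad_indices is a hypothetical helper name, not a JAX function, and it uses plain NumPy because the padding happens before any tracing.

```python
import numpy as np

def pad_indices(idxs, fill=-1):
    """Hypothetical helper: pad ragged index tuples to a rectangular int array."""
    # atleast_1d turns a bare int such as (1) into a length-1 array
    rows = [np.atleast_1d(np.asarray(i, dtype=np.int32)) for i in idxs]
    width = max((r.size for r in rows), default=0)
    out = np.full((len(rows), width), fill, dtype=np.int32)
    for k, r in enumerate(rows):
        out[k, : r.size] = r  # leading entries are real indices, the rest stay at fill
    return out

padded = pad_indices(((7, 8), (1, 7, 9), (), (1), ()))
print(padded)
```

Rows that consist only of -1 correspond to empty tuples. Any function that divides by the number of valid indices would divide by zero for such rows, so they need to be dropped beforehand or guarded explicitly.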
We now redefine the functions:
def distance_vectors(idx, X, Y):
    """Compute distance between two vectors of matrices X and Y.

    Args:
        idx (jax.numpy.ndarray): scalar indicating index of column
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        float: distance
    """
    return jnp.abs(X[:,idx] - Y[:,idx])

def compute_metrics(idxs, X, Y):
    distances = vmap(distance_vectors, in_axes=(0, None, None))(idxs, X, Y)
    distances = distances.T * jnp.where(idxs >= 0, 1, 0)
    n_of_actual_indices = jnp.sum(jnp.where(idxs >= 0, 1, 0))
    output = 1/n_of_actual_indices * 1/X.shape[0] * jnp.sum(distances)
    return output

output = jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)
I am not sure this is the most efficient way of doing it. It depends on whether the XLA compiler can detect that the distances for the -1 indices are set to zero, but I am not an XLA expert. I will later provide another solution based on jax.lax.cond, which may be faster, so we can benchmark the two.
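Before benchmarking, it may be worth checking that the masked version gives the same numbers as a plain loop. The sketch below repeats the two functions from this answer (docstring omitted) and compares them with a NumPy reference that selects only the valid columns; the seeded random data and the tolerance are assumptions for the check, not part of the original code.

```python
import jax.numpy as jnp
import numpy as np
from jax import jit, vmap

def distance_vectors(idx, X, Y):
    return jnp.abs(X[:, idx] - Y[:, idx])

def compute_metrics(idxs, X, Y):
    distances = vmap(distance_vectors, in_axes=(0, None, None))(idxs, X, Y)
    distances = distances.T * jnp.where(idxs >= 0, 1, 0)  # zero out padded columns
    n_of_actual_indices = jnp.sum(jnp.where(idxs >= 0, 1, 0))
    return 1 / n_of_actual_indices * 1 / X.shape[0] * jnp.sum(distances)

rng = np.random.default_rng(0)  # seeded only to make the check repeatable
X = jnp.asarray(rng.random((600, 10)))
Y = jnp.asarray(rng.random((600, 10)))
idxs_pairs = jnp.array([[7, 8, -1], [7, 9, -1], [7, 5, 6], [1, -1, -1]])

vectorized = jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)

# Reference: loop over the rows and keep only the non-negative indices.
X_np, Y_np = np.asarray(X), np.asarray(Y)
reference = np.array([
    np.mean(np.abs(X_np[:, row[row >= 0]] - Y_np[:, row[row >= 0]]))
    for row in np.asarray(idxs_pairs)
])

matches = bool(np.allclose(vectorized, reference, rtol=1e-4))
print(matches)
```

The comparison uses a relative tolerance because JAX computes in float32 by default and the two versions sum the terms in a different order.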
Update: 22/05/2023
With jax.lax.cond, the implementation can look like this:
from jax import lax  # needed for lax.cond below

def distance_vectors(idx, X, Y):
    """Compute distance between two vectors of matrices X and Y.

    Args:
        idx (jax.numpy.ndarray): scalar indicating index of column
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        float: distance
    """
    # padded entries (idx < 0) contribute a column of zeros
    return lax.cond(idx >= 0, lambda: jnp.abs(X[:,idx] - Y[:,idx]), lambda: jnp.zeros_like(X[:,idx]))

def compute_metrics(idxs, X, Y):
    distances = vmap(distance_vectors, in_axes=(0, None, None))(idxs, X, Y)
    n_of_actual_indices = jnp.sum(jnp.where(idxs >= 0, 1, 0))
    output = 1/n_of_actual_indices * 1/X.shape[0] * jnp.sum(distances)
    return output

output = jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)
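I tested it, and the execution times are the same as in the jnp.where case. This is to be expected: under vmap the predicate idx >= 0 is batched, and JAX then converts lax.cond into a select, so both branches are evaluated just as with jnp.where.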
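For anyone who wants to repeat the comparison, the following is a rough timing sketch, not the benchmark that was actually run. The function names (metrics_where, metrics_cond), the compact restatement of the two variants, the seeded data, and the repetition count are all assumptions made for illustration. Each function is called once before timing so that compilation is excluded, and block_until_ready is used because JAX dispatches work asynchronously.

```python
import timeit

import jax.numpy as jnp
import numpy as np
from jax import jit, lax, vmap

def _normalize(distances, idxs, n_rows):
    # Mean over the valid entries only.
    return jnp.sum(distances) / (jnp.sum(idxs >= 0) * n_rows)

def metrics_where(idxs, X, Y):
    d = vmap(lambda i: jnp.abs(X[:, i] - Y[:, i]))(idxs)  # shape (k, n)
    d = d.T * (idxs >= 0)                                  # mask padded columns
    return _normalize(d, idxs, X.shape[0])

def metrics_cond(idxs, X, Y):
    def one(i):
        return lax.cond(i >= 0,
                        lambda: jnp.abs(X[:, i] - Y[:, i]),
                        lambda: jnp.zeros_like(X[:, i]))
    d = vmap(one)(idxs)
    return _normalize(d, idxs, X.shape[0])

rng = np.random.default_rng(0)
X = jnp.asarray(rng.random((600, 10)))
Y = jnp.asarray(rng.random((600, 10)))
idxs_pairs = jnp.array([[7, 8, -1], [7, 9, -1], [7, 5, 6], [1, -1, -1]])

f_where = jit(vmap(metrics_where, in_axes=(0, None, None)))
f_cond = jit(vmap(metrics_cond, in_axes=(0, None, None)))

# First calls trigger compilation and are kept out of the timing.
out_where = f_where(idxs_pairs, X, Y).block_until_ready()
out_cond = f_cond(idxs_pairs, X, Y).block_until_ready()

t_where = timeit.timeit(lambda: f_where(idxs_pairs, X, Y).block_until_ready(), number=200)
t_cond = timeit.timeit(lambda: f_cond(idxs_pairs, X, Y).block_until_ready(), number=200)
print(f"where: {t_where:.4f}s  cond: {t_cond:.4f}s for 200 calls")
```

On inputs this small, the timings are dominated by dispatch overhead, so larger X, Y, and idxs_pairs would be needed to see any real difference between the two variants.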