How to vectorize JAX functions using jit compilation and vmap auto-vectorization

Question

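You can use jit and vmap in JAX to vectorize and speed up this computation, provided the ragged index tuples are first padded into a rectangular array; vmap maps over an array axis and cannot batch tuples of different lengths. To fix the issue you mentioned, you can modify the code as follows:

import jax
import jax.numpy as jnp
from jax import jit
import numpy as np

@jit
def compute_metrics(i, X, Y):
    # i: 1-D integer array of column indices, padded with -1
    mask = i >= 0
    # padded slots index the last column; the mask zeroes them out
    diffs = jnp.abs(X[:, i] - Y[:, i]) * mask
    count = jnp.sum(mask) * X.shape[0]
    # a row with no valid indices (an empty tuple in the question) gives 0.0
    return jnp.where(count > 0, jnp.sum(diffs) / jnp.maximum(count, 1), 0.0)

# data
X = np.random.rand(600, 10)
Y = np.random.rand(600, 10)
# indices: the tuples from the question, padded with -1 to a common length
idxs = jnp.array([[7, 8, -1], [1, 7, 9], [-1, -1, -1], [1, -1, -1], [-1, -1, -1]])

# vectorize over the rows of idxs; X and Y are shared across calls
print(jax.vmap(compute_metrics, in_axes=(0, None, None))(idxs, X, Y))

In this modified code, the index tuples are padded with -1 to a common length and passed as an array, compute_metrics masks out the padded slots, and a row with no valid indices returns 0.0 instead of raising an IndexError. vmap is applied over the rows of idxs with in_axes=(0, None, None), so X and Y are not mapped.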

Original question:

How can I use jit and vmap in JAX to vectorize and speed up the following computation:

@jit
def distance(X, Y):
    """Compute distance between two matrices X and Y.

    Args:
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        float: distance
    """
    return jnp.mean(jnp.abs(X - Y))

@jit
def compute_metrics(idxs, X, Y):
    results = []
    # Iterate over idxs
    for i in idxs:
        if i:
            results.append(distance(X[:, i], Y[:, i]))
    return results

#data
X = np.random.rand(600, 10)
Y = np.random.rand(600, 10)
#indices
idxs = ((7,8), (1,7,9), (), (1), ())

# call the regular function
print(compute_metrics(idxs, X, Y)) # works
# call the function with vmap
print(vmap(compute_metrics, in_axes=(None, 0, 0))(idxs, X, Y)) # doesn't work

I followed the JAX website and tutorials, but I can't figure out how to make this work. The non-vmap version works. However, I get an IndexError for the vmap version (last line above) that looks like this:

jax._src.traceback_util.UnfilteredStackTrace: IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

IndexError: Too many indices for array: 2 non-None/Ellipsis indices for dim 1.

Any idea how I could get this to work? Also, idxs might change and be any arbitrary valid combination of indices, e.g.

idxs = ((1,3,4,5), (3,9), (3,2,5), (), (5,8))

As explained above, I tried this code with and without vmap and couldn't get the vmap version to work.

Answer 1

Score: 1

I don't think vmap is going to work with a tuple of scalars. What you need is to put the indices into an array and vmap over it.

I am not sure this solution will satisfy you, because it requires dropping the empty index tuples ().

import jax.numpy as jnp
from jax import jit, vmap

idxs_pairs = jnp.array([[7, 8], [7, 9]])  # put the index pairs into an array

@jit
def distance(X, Y):
    """Compute distance between two matrices X and Y.

    Args:
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        float: distance
    """
    return jnp.mean(jnp.abs(X - Y))

@jit
def compute_metrics(idxs, X, Y):
    return distance(X[:,idxs], Y[:,idxs])

vmap(compute_metrics, in_axes=(0, None, None))(idxs_pairs, X, Y)

You can also jit everything:

jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)
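To see why in_axes=(0, None, None) is the mapping that fits here, it can help to look at the shapes vmap hands to the function. The sketch below is only illustrative: record_shapes and seen are made-up names, not part of the original answer, and the data is random as in the question.

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap

X = np.random.rand(600, 10)
Y = np.random.rand(600, 10)
idxs_pairs = jnp.array([[7, 8], [7, 9]])

seen = []

def record_shapes(idxs, X, Y):
    # Hypothetical probe: runs once while vmap traces the function and
    # notes the per-call shapes of each argument.
    seen.append((jnp.shape(idxs), jnp.shape(X), jnp.shape(Y)))
    return jnp.zeros(())

# Mapping used in the answer: one row of indices per call, X and Y shared.
vmap(record_shapes, in_axes=(0, None, None))(idxs_pairs, X, Y)
# Mapping used in the question: indices shared, one row of X and Y per call.
vmap(record_shapes, in_axes=(None, 0, 0))(idxs_pairs, X, Y)

for shapes in seen:
    print(shapes)
# ((2,), (600, 10), (600, 10))
# ((2, 2), (10,), (10,))
```

With the question's in_axes=(None, 0, 0), X is one-dimensional inside the function, so X[:, i] asks for two indices on a 1-D array. That is consistent with the IndexError quoted in the question.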

Update 19/05/2023:

The question is how to make this more general, i.e. how to support a variable number of indices per row. The problem is that JAX needs static shapes for inputs and outputs, so we need a trick to deal with this. The most obvious one is to use jnp.where to handle the conditional behavior; the other choice is jax.lax.cond. As before, we put the indices into an array, but this time we use -1 as a special flag marking an empty slot in the matrix (this is like zero-padding, but with -1 instead of 0). Because arrays have a static shape, the number of columns in idxs_pairs should be the maximum number of indices in any row.

For example:

# 7, 8, -1  -> we only use indices 7, 8
# 7, 9, -1  -> we only use indices 7, 9
# 7, 5, 6   -> we use indices 7, 5, 6
# 1, -1, -1 -> we only use index 1
idxs_pairs = jnp.array([[7, 8, -1], [7, 9, -1], [7, 5, 6], [1, -1, -1]])  # put the padded index rows into an array
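The padded array above is written by hand. If the indices arrive as ragged tuples, as in the question, the padding could be built with a small helper along these lines. This is a sketch: pad_indices and its fill parameter are hypothetical names, and treating a bare integer such as (1) as a one-element tuple is an assumption about what the question intended.

```python
import jax.numpy as jnp

def pad_indices(idxs, fill=-1):
    """Hypothetical helper: pad ragged index tuples into a rectangular int array."""
    # (1) in Python is the integer 1, not a tuple, so wrap bare integers.
    rows = [(i,) if isinstance(i, int) else tuple(i) for i in idxs]
    # Use at least one column so that all-empty input still gives a 2-D array.
    width = max(1, max(len(r) for r in rows))
    padded = [r + (fill,) * (width - len(r)) for r in rows]
    return jnp.array(padded, dtype=jnp.int32)

idxs = ((7, 8), (1, 7, 9), (), (1), ())
idxs_padded = pad_indices(idxs)
print(idxs_padded.shape)  # (5, 3)
```

Empty tuples become rows made up entirely of -1. Such a row has zero valid indices, so any code that divides by the number of valid indices has to handle that case separately.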

We now redefine the functions:

def distance_vectors(idx, X, Y):
    """Compute the element-wise absolute difference between column idx of X and Y.

    Args:
        idx (jax.numpy.ndarray): scalar indicating index of column
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        jax.numpy.ndarray: vector of shape (n,) with the absolute differences
    """
    return jnp.abs(X[:,idx] - Y[:,idx])

def compute_metrics(idxs, X, Y):
    distances = vmap(distance_vectors, in_axes=(0, None, None))(idxs, X, Y)
    distances = distances.T * jnp.where(idxs >= 0, 1, 0)
    n_of_actual_indices = jnp.sum(jnp.where(idxs >= 0, 1, 0))
    output = 1/n_of_actual_indices * 1/X.shape[0] * jnp.sum(distances)
    return output

output = jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)
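One way to gain confidence in the masked version is to compare it with a plain NumPy loop that drops the -1 padding by hand. The sketch below repeats the two functions from above without docstrings; the fixed seed and the reference loop are additions for illustration only.

```python
import numpy as np
import jax.numpy as jnp
from jax import jit, vmap

def distance_vectors(idx, X, Y):
    return jnp.abs(X[:, idx] - Y[:, idx])

def compute_metrics(idxs, X, Y):
    distances = vmap(distance_vectors, in_axes=(0, None, None))(idxs, X, Y)
    distances = distances.T * jnp.where(idxs >= 0, 1, 0)
    n_of_actual_indices = jnp.sum(jnp.where(idxs >= 0, 1, 0))
    return 1 / n_of_actual_indices * 1 / X.shape[0] * jnp.sum(distances)

rng = np.random.default_rng(0)  # fixed seed: an assumption, for reproducibility
X = rng.random((600, 10))
Y = rng.random((600, 10))
idxs_pairs = jnp.array([[7, 8, -1], [7, 9, -1], [7, 5, 6], [1, -1, -1]])

output = jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)

# Reference: loop over the rows in NumPy and remove the padding explicitly.
reference = []
for row in np.asarray(idxs_pairs):
    cols = row[row >= 0]
    reference.append(np.mean(np.abs(X[:, cols] - Y[:, cols])))

# Loose tolerance because JAX computes in float32 by default.
print(np.allclose(output, reference, rtol=1e-4))
```

The comparison only covers rows with at least one valid index, which matches the example array used in this answer.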

I am not sure this is the optimal way of doing it; it depends on whether the XLA compiler can recognize that the distances for the -1 indices are set to zero, and I am not an XLA expert. I will later provide another solution based on jax.lax.cond, which may be faster, so we can benchmark.

Update 22/05/2023:

With jax.lax.cond, the implementation can look like this:

from jax import lax

def distance_vectors(idx, X, Y):
    """Compute the element-wise absolute difference between column idx of X and Y.

    Args:
        idx (jax.numpy.ndarray): scalar indicating index of column
        X (jax.numpy.ndarray): matrix of shape (n, m)
        Y (jax.numpy.ndarray): matrix of shape (n, m)

    Returns:
        jax.numpy.ndarray: vector of shape (n,); all zeros when idx is negative (padding)
    """
    return lax.cond(idx >= 0, lambda: jnp.abs(X[:,idx] - Y[:,idx]), lambda: jnp.zeros_like(X[:,idx]))

def compute_metrics(idxs, X, Y):
    distances = vmap(distance_vectors, in_axes=(0, None, None))(idxs, X, Y)
    n_of_actual_indices = jnp.sum(jnp.where(idxs >= 0, 1, 0))
    output = 1/n_of_actual_indices * 1/X.shape[0] * jnp.sum(distances)
    return output

output = jit(vmap(compute_metrics, in_axes=(0, None, None)))(idxs_pairs, X, Y)

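I tested it, and the execution times are the same as for the jnp.where case. That is consistent with how JAX handles lax.cond under vmap: when the predicate is batched, the cond is converted to a select, so both branches are computed.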
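For anyone who wants to repeat the comparison, a timing sketch could look like the following. The function names (column_where, column_cond, metrics_where, metrics_cond) and the repetition count are made up for this example, and the numbers will depend on hardware, backend, and JAX version, so it is a starting point and not a definitive benchmark.

```python
import timeit
import numpy as np
import jax.numpy as jnp
from jax import jit, lax, vmap

def column_where(idx, X, Y):
    # Always computes the difference; padded columns are masked afterwards.
    return jnp.abs(X[:, idx] - Y[:, idx])

def column_cond(idx, X, Y):
    # Returns zeros for padded (negative) indices.
    return lax.cond(idx >= 0,
                    lambda: jnp.abs(X[:, idx] - Y[:, idx]),
                    lambda: jnp.zeros_like(X[:, idx]))

def metrics_where(idxs, X, Y):
    mask = idxs >= 0
    d = vmap(column_where, in_axes=(0, None, None))(idxs, X, Y)
    return jnp.sum(d.T * mask) / (jnp.sum(mask) * X.shape[0])

def metrics_cond(idxs, X, Y):
    d = vmap(column_cond, in_axes=(0, None, None))(idxs, X, Y)
    return jnp.sum(d) / (jnp.sum(idxs >= 0) * X.shape[0])

X = jnp.asarray(np.random.rand(600, 10))
Y = jnp.asarray(np.random.rand(600, 10))
idxs_pairs = jnp.array([[7, 8, -1], [7, 9, -1], [7, 5, 6], [1, -1, -1]])

f_where = jit(vmap(metrics_where, in_axes=(0, None, None)))
f_cond = jit(vmap(metrics_cond, in_axes=(0, None, None)))

# Warm-up calls so that compilation time is not part of the measurement.
out_where = f_where(idxs_pairs, X, Y).block_until_ready()
out_cond = f_cond(idxs_pairs, X, Y).block_until_ready()

# block_until_ready() waits for the asynchronous computation to finish.
t_where = timeit.timeit(lambda: f_where(idxs_pairs, X, Y).block_until_ready(), number=200)
t_cond = timeit.timeit(lambda: f_cond(idxs_pairs, X, Y).block_until_ready(), number=200)
print(f"where: {t_where:.4f}s  cond: {t_cond:.4f}s")
```

The warm-up calls and block_until_ready() matter here: JAX dispatches work asynchronously and compiles on first use, so timing without them would mostly measure compilation or dispatch.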

huangapple
  • Posted on 2023-05-13 09:06:51
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76240674.html