Error using JAX, Array slice indices must have static start/stop/step.

Question

I want to create a 2D Gaussian patch for each value in the darkField array. Ideally the patch size would be 2 * np.ceil(3 * sigma) + 1, where sigma is the corresponding value in darkField; in this example I have fixed the size to 10 to avoid errors.

After normalizing the Gaussian patch so that it sums to 1, I multiply it by the corresponding value in the intensityRefracted2DF array to obtain the blur, and then add this blur patch to the intensityRefracted3 array.
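Because JAX must know every array shape when a function is traced, the size formula 2 * np.ceil(3 * sigma) + 1 can only be evaluated on concrete values, not on a sigma read from a traced array. A minimal sketch, assuming one shared patch size for all pixels is acceptable, computes the size once from the largest sigma in plain NumPy and passes it to the jitted Gaussian as a static argument (static_patch_size and gaussian_patch are hypothetical names):

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


def static_patch_size(dark_field):
    # Runs eagerly in NumPy, so the result is a plain Python int that can
    # define an array shape (i.e. be passed as a static argument).
    sigma_max = float(np.max(dark_field))
    return int(2 * np.ceil(3 * sigma_max) + 1)


@partial(jax.jit, static_argnums=(1,))
def gaussian_patch(sigma, size):
    # With an odd static size, the peak lands exactly on index size // 2.
    x = jnp.arange(size) - size // 2
    g = jnp.exp(-(x ** 2) / (2 * sigma ** 2))
    patch = jnp.outer(g, g)
    return patch / jnp.sum(patch)  # normalise so the patch sums to 1


dark_field = np.array([[0.5, 1.0], [1.5, 2.0]])
size = static_patch_size(dark_field)  # 2 * ceil(3 * 2.0) + 1 = 13
patch = gaussian_patch(dark_field[1, 1], size)
print(size, patch.shape)  # 13 (13, 13)
```

An odd size also keeps the Gaussian centred: with size=10, the x grid in gaussian_shape runs from -5 to 4, so the peak sits half a pixel off the patch centre, and the slice i - size2:i + size2 + 1 in apply_dark_field spans 11 rows while the patch has only 10.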

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from jax.scipy.signal import convolve2d
from functools import partial

@partial(jax.jit, static_argnums=(1,))
def gaussian_shape(sigma, size):
    """
    Generate a Gaussian shape.

    Args:
        sigma (float or 2D numpy array): Standard deviation(s) of the Gaussian shape.
        size (int): Size of the Gaussian shape.

    Returns:
        exponent (2D numpy array): Gaussian shape.
    """

    x = jnp.arange(0, size) - jnp.floor(size / 2)
    exponent = jnp.exp(-(x ** 2) / (2 * sigma ** 2))
    exponent = jnp.outer(exponent, exponent)
    exponent /= jnp.sum(exponent)
    return exponent

@partial(jax.jit)
def apply_dark_field(i, j, intensityRefracted2DF, intensityRefracted3, darkField):
    currDF_ij = darkField[i, j]
    patch = gaussian_shape(currDF_ij, 10)
    size2 = patch.shape[0] // 2
    patch = patch * intensityRefracted2DF[i, j]

    intensityRefracted3 = intensityRefracted3.at[i - size2:i + size2 + 1, j - size2:j + size2 + 1].add(patch * intensityRefracted2DF[i, j])
    # intensityRefracted3 = jax.ops.index_add(intensityRefracted3, (i, j), intensityRefracted2DF[i, j] * (darkField[i, j] == 0))
    return intensityRefracted3

@jax.jit
def darkFieldLoop(intensityRefracted2DF, intensityRefracted3, darkField):
    currDF = jnp.zeros_like(intensityRefracted3)
    currDF = jnp.where(intensityRefracted2DF != 0, darkField, 0)

    i = jnp.nonzero(currDF, size=currDF.shape[0])
    indices_i = i[0]
    indices_j = i[1]
    intensityRefracted3 = jnp.zeros_like(intensityRefracted3)

    intensityRefracted3 = jax.vmap(apply_dark_field, in_axes=(0, 0, None, None, None))(indices_i, indices_j, intensityRefracted2DF, intensityRefracted3, darkField)

    return intensityRefracted3

intensityRefracted2DF = np.random.rand(10, 10)
intensityRefracted3 = np.zeros((10, 10))
darkField = np.random.rand(10, 10)

a = darkFieldLoop(intensityRefracted2DF, intensityRefracted3, darkField)

for i in range(a.shape[0]):
    plt.imshow(a[i])
    plt.show()

And here is the error message:

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=3/0)>, Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=3/0)>, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

I've also tried marking i and j as static with static_argnums (via partial):

@partial(jax.jit, static_argnums=(0, 1))
def apply_dark_field(i, j, intensityRefracted2DF, intensityRefracted3, darkField):
    ...  # body unchanged

and got this error:

ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 0) of type <class 'jax._src.interpreters.batching.BatchTracer'> for function apply_dark_field is non-hashable.

Answer 1

Score: 0

The issue is that JAX arrays cannot have dynamic shapes, so dynamic start and end indices cannot be used in slicing expressions: the length of the resulting slice would not be known at trace time.
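The same constraint is why jnp.nonzero in darkFieldLoop needs its size argument: the number of nonzero entries depends on the data, so under jit the output length has to be fixed in advance, and the result is truncated or padded (with fill_value) to that length. A short sketch with hypothetical function names; note that size=currDF.shape[0] in the question keeps at most 10 of the up to 100 nonzero positions of a 10x10 array, and pads with index 0 when there are fewer:

```python
from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def nonzero_unsized(mask):
    # The output length depends on the *values* in mask, so this cannot trace.
    return jnp.nonzero(mask)


@partial(jax.jit, static_argnums=(1,))
def nonzero_sized(mask, size):
    # A static size fixes the output shape: extra hits are dropped and
    # missing ones are padded with fill_value.
    return jnp.nonzero(mask, size=size, fill_value=-1)


mask = jnp.array([[0, 1], [1, 1]])  # three nonzero entries

try:
    nonzero_unsized(mask)
except jax.errors.ConcretizationTypeError:
    print("ConcretizationTypeError: size must be static under jit")

rows, cols = nonzero_sized(mask, 2)    # truncated: rows [0 1], cols [1 0]
rows4, cols4 = nonzero_sized(mask, 4)  # padded: rows [0 1 1 -1], cols [1 0 1 -1]
```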

Your solution of marking i and j as static would work, except that you are vmapping across these values, so by definition they cannot be static.

The best solution here is probably to use lax.dynamic_slice and lax.dynamic_update_slice, which are operations designed exactly for the case that you have (where indices are dynamic, but slice sizes are static).

You can replace this line:

intensityRefracted3 = intensityRefracted3.at[i - size2:i + size2 + 1, j - size2:j + size2 + 1].add(patch * intensityRefracted2DF[i, j])

with this:

start_indices = (i - size2, j - size2)
update = jax.lax.dynamic_slice(intensityRefracted3, start_indices, patch.shape)
update += patch * intensityRefracted2DF[i, j]
intensityRefracted3 = jax.lax.dynamic_update_slice(
    intensityRefracted3, update, start_indices)

and it should work correctly with dynamic i and j. Be careful, though: if any of the specified start indices are out of bounds, dynamic_slice and dynamic_update_slice clip them into the valid range, so a patch near the edge of the array is shifted inward rather than truncated.
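The clipping caveat matters here: with a 10x10 image and a 10x10 patch, every start index gets clamped to (0, 0), so all patches would land in the same corner. One possible workaround, sketched below under the assumption that blur falling outside the image can simply be discarded, is to pad the image by half the patch size, write each patch at (i, j) in padded coordinates (which centres it on pixel (i, j) of the original), sum the vmapped results and crop the padding off. The name dark_field_blur and the treatment of sigma == 0 as an unblurred point are assumptions, not part of the answer. The sketch also scales each patch by the intensity only once; the question's apply_dark_field multiplies by intensityRefracted2DF[i, j] twice (once when assigning patch and again inside .add).

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

# Clipping demo: start 8 with size 4 does not fit in a length-10 array,
# so dynamic_slice clamps the start to 6.
print(jax.lax.dynamic_slice(jnp.arange(10), (8,), (4,)))  # [6 7 8 9]


@partial(jax.jit, static_argnums=(2,))
def dark_field_blur(intensity, dark_field, size):
    """Centre a normalised Gaussian of width dark_field[i, j] on each pixel,
    scale it by intensity[i, j] and sum all patches. `size` must be odd."""
    half = size // 2
    h, w = intensity.shape
    x = jnp.arange(size) - half
    # Pad by `half` on every side so no patch is ever clipped.
    canvas = jnp.zeros((h + 2 * half, w + 2 * half), dtype=intensity.dtype)

    def one_pixel(i, j):
        # Assumption: sigma == 0 is treated as an unblurred point source.
        sigma = jnp.maximum(dark_field[i, j], 1e-6)
        g = jnp.exp(-(x ** 2) / (2 * sigma ** 2))
        patch = jnp.outer(g, g)
        patch = patch / jnp.sum(patch) * intensity[i, j]  # scaled once
        # Start (i, j) in padded coordinates centres the patch on pixel (i, j).
        return jax.lax.dynamic_update_slice(canvas, patch, (i, j))

    ii, jj = jnp.meshgrid(jnp.arange(h), jnp.arange(w), indexing="ij")
    stacked = jax.vmap(one_pixel)(ii.ravel(), jj.ravel())  # one canvas per pixel
    return stacked.sum(axis=0)[half:half + h, half:half + w]  # crop padding


rng = np.random.default_rng(0)
intensity = rng.random((10, 10))
dark_field = rng.random((10, 10))
size = int(2 * np.ceil(3 * dark_field.max()) + 1)  # odd by construction
blurred = dark_field_blur(intensity, dark_field, size)
print(blurred.shape)  # (10, 10)
```

Because vmap materialises one padded canvas per pixel, this is fine for small images but memory-hungry for large ones; accumulating into a single canvas with jax.lax.fori_loop would avoid that.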

huangapple
  • Posted on 29 June 2023 at 14:14:12
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76578461.html