如何从一个 TensorFlow 张量的概率分布中抽样 5 个索引及其对应的概率?

huangapple go评论84阅读模式
英文:

How to sample 5 index and their probabilities from a tensorflow tensor of probability distribution?

问题

Sure, here's the translation of the code-related part:

我有一个概率分布(经过Softmax处理后),其中每行的值相加等于1。

probs = tf.constant([
    [0.0, 0.1, 0.2, 0.3, 0.4],
    [0.5, 0.3, 0.2, 0.0, 0.0]])

我想要使用tensorflow操作从中抽样k个索引及其相应的概率值。

3个索引的期望输出:

index: [
  [4, 3, 4],
  [0, 1, 0]
]

probs: [
  [0.4, 0.3, 0.4],
  [0.5, 0.3, 0.5]
]

我该如何实现这个?

英文:

I have a probability distribution (after applying Softmax) where values in each row sums up to 1

probs = tf.constant([
    [0.0, 0.1, 0.2, 0.3, 0.4],
    [0.5, 0.3, 0.2, 0.0, 0.0]])

I want to sample k index from it and their respective probability values using tensorflow operations.

The expected output for 3 index:

index: [
  [4, 3, 4],
  [0, 1, 0]
       ]

probs: [
  [0.4, 0.3, 0.4],
  [0.5, 0.3, 0.5]
       ]

How can I achieve this?

答案1

得分: 1

使用tf.random.uniform生成随机索引的张量:

k = 3 # 设置要抽样的元素数量
# 获取概率张量的维度
row, cols = tf.unstack(tf.shape(probs))
# 为每行生成0到cols - 1之间的k个值
idx = tf.random.uniform((row, k), 0, cols, dtype=tf.int32)

使用tf.gather_nd根据索引获取概率张量的值:

tf.gather(probs, idx, batch_dims=1)
英文:

Generating a random tensor of indexes with tf.random.uniform :

k = 3 # setting up the number of element to sample
# getting the dimension of the prob tensor
row, cols = tf.unstack(tf.shape(probs))
# Generating k values for each row between 0 and cols - 1
idx = tf.random.uniform((row, k), 0, cols, dtype=tf.int32)
>>> idx
<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[0, 0, 4],
       [3, 1, 0]], dtype=int32)>

Using tf.gather_nd to index the probabilities tensor with the indexes:

>>> tf.gather(probs, idx, batch_dims=1)
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0. , 0. , 0.4],
       [0. , 0.3, 0.5]], dtype=float32)>

huangapple
  • 本文由 发表于 2023年6月6日 14:45:59
  • 转载请务必保留本文链接:https://go.coder-hub.com/76412067.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定