MellowMax operator returning +INF
Question
MellowMax is a softmax operator that can be used in place of max in the context of deep Q-learning. Using MellowMax has been shown to remove the need for a target network. Link to paper: https://arxiv.org/abs/1612.05628
To estimate a target Q value, you perform MellowMax over the Q values of the next state. The MellowMax function looks like this:

mm_w(x) = log( (1/n) * sum_{i=1..n} exp(w * x_i) ) / w

where x is the tensor of Q values, n is the number of actions, and w is a temperature parameter.
My implementation is:
import tensorflow as tf

def mellow_max(q_values):
    q_values = tf.cast(q_values, tf.float64)
    powers = tf.multiply(q_values, DEEP_MELLOW_TEMPERATURE_VALUE)
    summation_values = tf.math.exp(powers)  # this is the line that overflows to +inf
    summation = tf.math.reduce_sum(summation_values, axis=1)
    val_for_log = tf.multiply(summation, (1 / NUM_ACTIONS))
    numerator = tf.math.log(val_for_log)
    mellow_val = tf.math.divide(numerator, DEEP_MELLOW_TEMPERATURE_VALUE).numpy()
    return mellow_val
My issue is that the third line of this function, tf.math.exp(powers), returns values of +inf when using a temperature value w of 1,000. I'm using w = 1,000 because that's what the paper above found to be optimal on the Atari Breakout testbed.
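For reference, here is a minimal repro of the overflow (the Q values are made up for illustration): float64 tops out around 1.8e308, so exp overflows once its argument exceeds roughly 709, and with w = 1,000 any Q value above about 0.71 produces inf:

import tensorflow as tf

# exp overflows float64 once its argument exceeds ~709.78
q = tf.constant([[0.5, 0.8]], dtype=tf.float64)
print(tf.math.exp(q * 1000.0))  # -> [[1.40e+217, inf]]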
Any suggestions on how I can keep that third line from breaking the calculation would be appreciated. Maybe taking the limit of the function as w goes to 1,000 would work. Any suggestions on how I could do that in TensorFlow?
Answer 1

Score: 0
You cannot compute mellowmax like this, because the exp function will overflow/underflow quickly when w*x_i is large. Thus you have to do something smarter, for example subtract the maximum m = max_i x_i before exponentiating:

mm_w(x) = m + ( log sum_{i=1..n} exp(w * (x_i - m)) - log(n) ) / w

Here the logsumexp part only ever exponentiates non-positive values, so it solves the overflow issue.
Notice that there is a logsumexp (LSE) term. We know the LSE becomes log(K) when w is very large, where K is the number of times the maximum value appears in x. You can use this to manually verify your result a bit.
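As a quick sanity check of that limit (a made-up example, not from the original answer): with two entries tied for the maximum, K = 2, so the stabilized LSE should approach log(2) for large w:

import torch

# two entries tie for the max, so K = 2 and LSE -> log(2) as w grows
a = torch.tensor([1.0, 3.0, 3.0, 0.5])
m, w = a.max(), 1e3
lse = torch.exp((a - m) * w).sum().log()
print(lse.item())  # ~0.6931 = log(2)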
<s>If you wish to use very small w <<1, you have to take care the underflow. In this case, you use similar technique. But first calculate the mean value, then do the logsumexp around the mean value instead of max value.</s>
I was wrong, there is no underflow risk here.
Here is my example:
import torch

def mellowmax(a: torch.Tensor, w: float):
    m = torch.max(a)
    n = torch.Tensor([len(a)])
    # since a - m is non-positive everywhere, exp cannot overflow
    lse = torch.exp((a - m) * w).sum().log_()
    return m + (lse - n.log_()) / w

N = 10
a = torch.randn((N,), dtype=torch.float) * N

for k in range(-4, 5):
    w = 10**k
    mwm = mellowmax(a, w)
    print(mwm, a.max(), a.mean())
The result is:
tensor([2.1293]) tensor(17.7385) tensor(2.1235)
tensor([2.1791]) tensor(17.7385) tensor(2.1235)
tensor([2.6696]) tensor(17.7385) tensor(2.1235)
tensor([6.6293]) tensor(17.7385) tensor(2.1235)
tensor([15.4587]) tensor(17.7385) tensor(2.1235)
tensor([17.5083]) tensor(17.7385) tensor(2.1235)
tensor([17.7155]) tensor(17.7385) tensor(2.1235)
tensor([17.7362]) tensor(17.7385) tensor(2.1235)
tensor([17.7383]) tensor(17.7385) tensor(2.1235)
We can see that the mellowmax is first very close to the mean, and then becomes very close to the max as w increases.
Please note that meaningful values of w usually lie within about 10, so your w = 1,000 could be the result of some other issue. Nevertheless, depending on your x, the naive way of computing mellowmax can still overflow quite often.
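Since the question uses TensorFlow, here is a minimal sketch of the same computation there, assuming batched Q values of shape (batch, NUM_ACTIONS) as in the question's code; the NUM_ACTIONS value is made up for illustration. tf.math.reduce_logsumexp applies the max-subtraction trick internally, so nothing overflows:

import tensorflow as tf

DEEP_MELLOW_TEMPERATURE_VALUE = 1000.0  # w, as in the question
NUM_ACTIONS = 4  # assumed value, for illustration only

def mellow_max(q_values):
    q = tf.cast(q_values, tf.float64)
    w = DEEP_MELLOW_TEMPERATURE_VALUE
    # reduce_logsumexp subtracts the per-row max internally, so exp never overflows
    lse = tf.math.reduce_logsumexp(w * q, axis=1)
    return (lse - tf.math.log(tf.cast(NUM_ACTIONS, tf.float64))) / w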