How to generate a random number in the range of a truncated normal distribution

Question

I need to generate a value from a truncated normal distribution. In Python, for example, you could use scipy.stats.truncnorm() like this:

from scipy.stats import truncnorm

def get_truncated_normal(mean=.0, sd=1., low=.0, upp=10.):
    return truncnorm((low - mean) / sd, (upp - mean) / sd, loc=mean, scale=sd)

This is described here.

Is there a package that does something similar in Go, or should I write this function myself?

I tried the following, as the documentation suggests, but it produces numbers outside the needed range:

func GenerateTruncatedNormal(mean, sd uint64) float64 {
	return rand.NormFloat64() * (float64)(sd + mean)
}
Calling GenerateTruncatedNormal(10, 5) produces values such as 16.61, -14.54, or even 32.8, but I expect only a small chance of getting 15, since with mean = 10 the maximum value should be 10 + 5 = 15. What is wrong here? 😅

Answer 1

Score: 1

One way to achieve this is rejection sampling:

  • generate a number x from the normal distribution with the desired mean and standard deviation;
  • if x is outside the range [low..high], throw it away and try again.

This preserves the shape of the normal distribution's probability density function inside the range, effectively cutting off the left and right tails.
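Stated precisely: rejecting samples outside the range leaves the normal density unchanged inside it, up to a constant that makes it integrate to 1. With μ = mean, σ = stdDev, and φ, Φ the standard normal PDF and CDF, the accepted values should have the density below. Here a and b are the same standardized bounds that the question's get_truncated_normal passes as the first two arguments of scipy's truncnorm.

```latex
f(x) =
\begin{cases}
\dfrac{\varphi\!\left(\frac{x-\mu}{\sigma}\right)}{\sigma\,\bigl(\Phi(b)-\Phi(a)\bigr)}, & \mathrm{low} \le x \le \mathrm{high},\\[6pt]
0, & \text{otherwise,}
\end{cases}
\qquad
a = \frac{\mathrm{low}-\mu}{\sigma},\quad
b = \frac{\mathrm{high}-\mu}{\sigma}.
```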

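import "math/rand"

// TruncatedNormal draws from a normal distribution with the given mean and
// standard deviation, rejecting samples outside [low, high) until one fits.
func TruncatedNormal(mean, stdDev, low, high float64) float64 {
	if low >= high {
		panic("high must be greater than low")
	}

	for {
		// Scale a standard normal sample to the desired mean and std dev.
		x := rand.NormFloat64()*stdDev + mean
		// Accept the sample only if it falls inside the interval; otherwise retry.
		if low <= x && x < high {
			return x
		}
	}
}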


If the [low..high] interval covers only a small part of the distribution (because it is very narrow or lies far out in a tail), it takes more computation time, since more generated numbers are thrown away. In practice, however, it still converges very quickly for typical intervals.

I checked the code above by plotting its output against the output of scipy's truncnorm, and the two produce equivalent charts.

huangapple
  • This article was published on June 16, 2021 at 20:31:57
  • When reposting, please keep the link to this article: https://go.coder-hub.com/68002731.html