英文:
cmdstanR: extracting draws from stan model fit
问题
You can extract only the draw variable values from the fit object using the following code:
draws <- fit$draws()[, , "lambda"]
This code will give you a matrix containing the draw values for the "lambda" variable from all chains and iterations in your fit object.
英文:
I am trying to extract draws from a stan model. The stan file is as follows:
data {
int<lower=0> N;
int<lower=0, upper=1> obs_data[N];
}
parameters {
real<lower=0, upper=1> lambda;
}
model {
target += uniform_lpdf(lambda | 0,1);
for (n in 1:N) {
target += bernoulli_logit_lpmf(obs_data[n] | lambda);
}
}
I am using cmdstanR to compile and sample from the model.
dl <- list(N = 10, obs_data = c(1,0,1,1,1,0,0,1,1,1))
mod <- cmdstan_model("model.stan") // file pasted above
fit <- mod$sample(data, data = dl, num_chains = 4, num_cores = 4)
The resulting fit objects is as follows:
> fit$draws()
, , variable = lambda
chain
iteration 1 2 3 4
1 0.419819000 0.85642500 0.319154000 0.73338700
2 0.807612000 0.78189500 0.737518000 0.73338700
3 0.609196000 0.65826000 0.601450000 0.37992200
4 0.390631000 0.84544000 0.601450000 0.17992400
From the fit object, which is a R6 object, I want to get only the draw variable values. how do I do that?
Based on a followup comment, I am adding additional information on the internal structure of the R6 object for greater clarity:
> str(fit$draws())
'draws_array' num [1:1000, 1:4, 1:2] 0.42 0.808 0.609 0.391 0.391 ...
- attr(*, "dimnames")=List of 3
..$ iteration: chr [1:1000] "1" "2" "3" "4" ...
..$ chain : chr [1:4] "1" "2" "3" "4"
..$ variable : chr [1:2] "lambda" "lp__"
答案1
得分: 2
感谢来自@StéphaneLaurent和@RomanLuštrik的提示/建议,我能够得到我想要的结果。我想要绘制来自stan的采样结果的值。
根据@RomanLuštrik的建议,
str(fit$draws())提供了不同索引的映射;第一个是“迭代”,然后是“链”,最后是“变量”。
因为我想要提取“lambda”参数(“变量”)的采样结果;我可以通过以下方式访问它
fit$draws()[,,1].
要从第1和第3链中提取参数“lambda”的前25个样本,我需要使用
fit$draws()[1:25,c(1,3),1]
英文:
Thanks for the prompts / suggestions from @StéphaneLaurent and @RomanLuštrik, I was able to get to what I wanted. I was looking to draws values of the sampling result from stan.
Based on the suggestions from @RomanLuštrik,
str(fit$draws()), gives the map to the different indices; the first being the "iteration", then the "chain" and finally the "variable".
Since I was looking to extract draws for the "lambda" parameter (the "variable"); I could access it through
fit$draws()[,,1].
To draw the samples from the 1 and 3rd chains, and the first 25 samples for the parameter "lambda", I need to use
fit$draws()[1:25,c(1,3),1]
通过集体智慧和协作来改善编程学习和解决问题的方式。致力于成为全球开发者共同参与的知识库,让每个人都能够通过互相帮助和分享经验来进步。
评论