Improve the performance of recursive sampling function

Question


As a follow-up to my previous question, I'm interested in improving the performance of the existing recursive sampling function.

By recursive sampling I mean randomly choosing up to n unique unexposed IDs for a given exposed ID, and then randomly choosing up to n unique unexposed IDs from the remaining unexposed IDs for another exposed ID. If there are no remaining unexposed IDs for a given exposed ID, then that exposed ID is left out.
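A minimal base-R sketch of these semantics, assuming exposed IDs are processed in order of first appearance. The function `recursive_sample_ref` and the toy data are hypothetical and only meant to make the definition concrete: exposed ID 2 ends up left out because its only candidate was already taken by exposed ID 1.

```R
# Hypothetical reference sketch of "recursive sampling" (not from the original post)
recursive_sample_ref <- function(data, n) {
  used <- data$unexposed[0]   # unexposed IDs already assigned (same type as the column)
  picks <- list()
  for (g in unique(data$exposed)) {
    # distinct candidates for this exposed ID that are still unused
    cand <- unique(data$unexposed[data$exposed == g])
    cand <- cand[!cand %in% used]
    if (length(cand) == 0) next  # nothing left: leave this exposed ID out
    # random subset of size min(length(cand), n) drawn from all candidates
    chosen <- cand[sample.int(length(cand), min(length(cand), n))]
    used <- c(used, chosen)
    picks[[length(picks) + 1]] <- data.frame(exposed = g, unexposed = chosen)
  }
  do.call(rbind, picks)
}

toy <- data.frame(exposed   = c(1, 1, 2, 3, 3),
                  unexposed = c("a", "b", "a", "b", "c"))
set.seed(1)
res <- recursive_sample_ref(toy, n = 2)
res
```

This loop is only meant to pin down the expected output; it has the same sequential structure as the faster versions, so it is not a performance candidate itself.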

The original function is as follows:

```R
library(dplyr)

recursive_sample <- function(data, n) {
  groups <- unique(data[["exposed"]])
  out <- data.frame(exposed = character(), unexposed = character())
  for (group in groups) {
    chosen <- data %>%
      filter(exposed == group,
             !unexposed %in% out$unexposed) %>%
      group_by(unexposed) %>%
      slice(1) %>%
      ungroup() %>%
      sample_n(size = min(n, nrow(.)))
    out <- rbind(out, chosen)
  }
  out
}
```

I was able to create a more efficient one as follows:

```R
recursive_sample2 <- function(data, n) {
  groups <- unique(data[["exposed"]])
  out <- tibble(exposed = integer(), unexposed = integer())
  for (group in groups) {
    chosen <- data %>%
      filter(exposed == group,
             !unexposed %in% out$unexposed) %>%
      filter(!duplicated(unexposed)) %>%
      sample_n(size = min(n, nrow(.)))
    out <- bind_rows(out, chosen)
  }
  out
}
```

Sample data and benchmarking:

```R
library(microbenchmark)

set.seed(123)
df <- tibble(exposed = rep(1:100, each = 100),
             unexposed = sample(1:7000, 10000, replace = TRUE))
microbenchmark(f1 = recursive_sample(df, 5),
               f2 = recursive_sample2(df, 5),
               times = 10)
# Unit: milliseconds
#  expr       min        lq      mean    median        uq      max neval cld
#    f1 1307.7198 1316.5276 1379.0533 1371.3952 1416.6360 1540.955    10   b
#    f2  839.0086  865.2547  914.8327  901.2288  970.9518 1036.170    10  a
```

However, for my actual dataset, I would need an even more efficient (i.e., quicker) function. Any ideas for a more efficient version, whether in data.table, involving parallelisation or other approaches are welcome.

Answer 1

Score: 2

# Update, with More Improvement

A more concise solution might be to use `Reduce` + `split`: we first shuffle the rows of `data` and then sample by group **iteratively**.

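```R
ftic <- function(data, n) {
  Reduce(
    \(x, y) {
      # keep rows of this group whose unexposed ID is unused so far and not
      # repeated within the group, then take the first n (rows are shuffled)
      rbind(x, head(subset(y, !unexposed %in% x$unexposed & !duplicated(unexposed)), n))
    },
    split(data[sample(1:nrow(data)), ], ~exposed),
    data[0, ]  # empty start, so the first group is also capped at n rows
  )
}
```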
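A small deterministic illustration of the fold, with no shuffling; the toy data and the helper `take` are hypothetical. Without an `init` argument, `Reduce()` uses the first list element as the starting accumulator, so the first group is passed through whole instead of being capped at `n`; starting from an empty data frame avoids that.

```R
# Toy groups, as split(data, ~exposed) would produce them (hypothetical data)
groups <- list(
  data.frame(exposed = 1L, unexposed = c(10L, 11L, 12L)),
  data.frame(exposed = 2L, unexposed = c(11L, 13L))
)

# Same step as in ftic: drop used / repeated unexposed IDs, keep at most n rows
take <- function(acc, grp, n = 2L) {
  grp <- grp[!grp$unexposed %in% acc$unexposed & !duplicated(grp$unexposed), ]
  rbind(acc, head(grp, n))
}

empty <- data.frame(exposed = integer(), unexposed = integer())
with_init <- Reduce(take, groups, empty)  # group 1 capped at 2 rows
no_init   <- Reduce(take, groups)         # group 1 returned whole (3 rows)

table(with_init$exposed)
table(no_init$exposed)
```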

Below is a ***tougher*** stress test, i.e., **`data` with `1e6` rows**, comparing the following approaches:

```R
library(dplyr)

ftmfmnk <- function(data, n) {
  groups <- unique(data[["exposed"]])
  out <- tibble(exposed = integer(), unexposed = integer())
  for (group in groups) {
    chosen <- data %>%
      filter(
        exposed == group,
        !unexposed %in% out$unexposed
      ) %>%
      filter(!duplicated(unexposed)) %>%
      sample_n(size = min(n, nrow(.)))
    out <- bind_rows(out, chosen)
  }
  out
}

fminem <- function(data, n) {
  groups <- unique(data[["exposed"]])
  # working on vectors is faster
  id <- 1:nrow(data)
  i <- vector("integer")
  unexposed2 <- vector(class(data$unexposed))
  ex <- data$exposed
  ux <- data$unexposed
  for (group in groups) {
    f1 <- ex == group # first filter
    f2 <- !ux[f1] %in% unexposed2 # 2nd filter (only on those that match 1st)
    id3 <- id[f1][f2][!duplicated(ux[f1][f2])] # check duplicates only on needed
    # and select necessary row ids
    is <- sample(id3, size = min(length(id3), n)) # sample row ids
    i <- c(i, is) # add to list
    unexposed2 <- ux[i] # resave unexposed2
  }
  out <- data[i, ] # only one data.frame subset
  out$id <- NULL
  out
}

ftic <- function(data, n) {
  Reduce(
    \(x, y) {
      # unused, non-repeated unexposed IDs of this group; first n of the shuffled rows
      rbind(x, head(subset(y, !unexposed %in% x$unexposed & !duplicated(unexposed)), n))
    },
    split(data[sample(1:nrow(data)), ], ~exposed),
    data[0, ]  # empty start, so the first group is also capped at n rows
  )
}
```

The benchmark is as follows:

```R
library(microbenchmark)

set.seed(123)
df <- tibble(
  exposed = rep(1:1000, each = 1000),
  unexposed = sample(1:70000, 1000000, replace = TRUE)
)

mbm <- microbenchmark(
  tmfmnk = ftmfmnk(df, 5),
  minem = fminem(df, 5),
  tic = ftic(df, 5),
  times = 10
)

boxplot(mbm)
```

and we see that:

```
> mbm
Unit: milliseconds
   expr        min         lq       mean     median         uq        max neval
 tmfmnk 36809.9563 44276.3545 43780.8407 44897.2661 46175.1031 46948.8906    10
  minem  5361.2796  5932.7752  5923.8811  6010.7775  6047.3716  6233.2919    10
    tic   504.5749   519.5997   641.7935   607.2825   729.4545   868.1283    10
```

[![Boxplot of `mbm`][1]][1]

---

# Previous Naïve Approach

I don't have any advanced technique here, just a dynamic programming scheme with `for` loops, and I believe there must be more performant approaches than mine.

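```R
dp <- function(df, n) {
  d <- table(df)            # exposed x unexposed count table
  out <- list()
  rnm <- row.names(d)
  cnm <- colnames(d)
  for (i in 1:nrow(d)) {
    v <- which(d[i, ] > 0)  # unexposed IDs still available for this exposed ID
    l <- length(v)
    if (l == 0) next        # none left: leave this exposed ID out
    idx <- v[sample(l, min(l, n))]
    out[[i]] <- data.frame(exposed = rnm[i], unexposed = cnm[idx])
    d[, idx] <- 0           # mark the chosen unexposed IDs as used
  }
  do.call(rbind, out)
}
```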
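To see what `table(df)` gives `dp`, here is a toy run on hypothetical data: the table is an exposed × unexposed count matrix, `which(d[i, ] > 0)` lists the candidates for row `i`, and zeroing a column removes that unexposed ID from every later row. That can leave a row with no candidates at all, which is the case the question says should be skipped.

```R
# Hypothetical toy data with two columns, like df in the benchmarks
toy <- data.frame(exposed   = c(1, 1, 2, 2, 3),
                  unexposed = c("a", "b", "b", "c", "a"))
d <- table(toy)   # rows: exposed "1", "2", "3"; columns: unexposed "a", "b", "c"

cand1 <- names(which(d["1", ] > 0))   # candidates for exposed 1: "a", "b"
d[, cand1] <- 0                       # suppose both are chosen: mark them as used
cand2 <- names(which(d["2", ] > 0))   # only "c" is left for exposed 2
cand3 <- names(which(d["3", ] > 0))   # exposed 3 has nothing left
```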

and the benchmark:

```R
set.seed(123)
df <- tibble(
  exposed = rep(1:100, each = 100),
  unexposed = sample(1:7000, 10000, replace = TRUE)
)

mbm <- microbenchmark(
  f1 = recursive_sample(df, 5),
  f2 = recursive_sample2(df, 5),
  f3 = dp(df, 5),
  times = 10
)

boxplot(mbm)
```

shows:

```
> mbm
Unit: milliseconds
 expr       min        lq      mean    median        uq       max neval
   f1 1271.0135 1302.4310 1449.2193 1326.7630 1686.4329 1888.4549    10
   f2  507.9350  516.8854  617.0313  559.0422  706.4300  801.0124    10
   f3  212.8944  247.0066  278.1792  271.9010  309.7377  354.4320    10
```

[![Boxplot of `mbm`][2]][2]

Also, to check the result `res <- dp(df, 5)`, we can use:

```
> table(res$exposed)

 1 10 100 11 12 13 14 15 16 17 18 19 2 20 21 22 23 24 25 26
 5  5   5  5  5  5  5  5  5  5  5  5 5  5  5  5  5  5  5  5
27 28 29 3 30 31 32 33 34 35 36 37 38 39 4 40 41 42 43 44
 5  5  5 5  5  5  5  5  5  5  5  5  5  5 5  5  5  5  5  5
45 46 47 48 49 5 50 51 52 53 54 55 56 57 58 59 6 60 61 62
 5  5  5  5  5 5  5  5  5  5  5  5  5  5  5  5 5  5  5  5
63 64 65 66 67 68 69 7 70 71 72 73 74 75 76 77 78 79 8 80
 5  5  5  5  5  5  5 5  5  5  5  5  5  5  5  5  5  5 5  5
81 82 83 84 85 86 87 88 89 9 90 91 92 93 94 95 96 97 98 99
 5  5  5  5  5  5  5  5  5 5  5  5  5  5  5  5  5  5  5  5

> anyDuplicated(res$unexposed)
[1] 0
```
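The two checks above can be wrapped into one reusable test. This is a hypothetical helper (`is_valid_sample` is not from any package), assuming the input and the result both have `exposed` and `unexposed` columns. It verifies that no unexposed ID is reused, that no exposed ID gets more than `n`, and that every returned pair occurs in the input. Because `dp` returns IDs as character, pairs are compared through their character form.

```R
# Hypothetical helper: validate a recursive-sampling result `res` against `data`
is_valid_sample <- function(data, res, n) {
  no_reuse   <- anyDuplicated(res$unexposed) == 0   # each unexposed ID used once
  within_cap <- all(table(res$exposed) <= n)        # at most n per exposed ID
  # every (exposed, unexposed) pair in the result must exist in the input
  key_in   <- paste(data$exposed, data$unexposed, sep = "\t")
  key_out  <- paste(res$exposed, res$unexposed, sep = "\t")
  pairs_ok <- all(key_out %in% key_in)
  no_reuse && within_cap && pairs_ok
}

toy  <- data.frame(exposed = c(1, 1, 2, 2), unexposed = c(5, 6, 6, 7))
good <- data.frame(exposed = c(1, 2), unexposed = c(5, 6))
bad  <- data.frame(exposed = c(1, 2), unexposed = c(6, 6))  # 6 used twice
```

With the benchmark data, `is_valid_sample(df, res, 5)` is expected to be `TRUE` for a correct implementation.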

[1]: https://i.stack.imgur.com/HUvI0.png
[2]: https://i.stack.imgur.com/60wE3.png

Answer 2

Score: 2

Working on vectors is much faster:

```R
recursive_sample3 <- function(data, n) {
  groups <- unique(data[["exposed"]])
  # working on vectors is faster
  id <- 1:nrow(data)
  i <- vector('integer')
  unexposed2 <- vector(class(data$unexposed))
  ex <- data$exposed
  ux <- data$unexposed
  for (group in groups) {
    f1 <- ex == group # first filter
    f2 <- !ux[f1] %in% unexposed2 # 2nd filter (only on those that match 1st)
    id3 <- id[f1][f2][!duplicated(ux[f1][f2])] # check duplicates only on needed
    # and select necessary row ids
    is <- sample(id3, size = min(length(id3), n)) # sample row ids
    i <- c(i, is) # add to list
    unexposed2 <- ux[i] # resave unexposed2
  }
  out <- data[i, ] # only one data.frame subset
  out$id <- NULL
  out
}
```

Benchmarks:

```R
microbenchmark(f1 = recursive_sample(df, 5),
               f2 = recursive_sample2(df, 5),
               f3 = recursive_sample3(df, 5),
               times = 3)
# Unit: milliseconds
#  expr       min        lq        mean    median         uq       max neval cld
#    f1 1399.8988 1407.1939 1422.008133 1414.4889 1433.06280 1451.6367     3   a
#    f2  667.0813  673.7229  678.106400  680.3645  683.61895  686.8734     3    b
#    f3    6.2399    6.2625    9.531267    6.2851   11.17695   16.0688     3     c
```

Iterating on `recursive_sample3` and addressing the concerns about `sample`:

```R
f_minem <- function(data, n) {
  i <- vector('integer')
  unexposed2 <- vector(class(data$unexposed))
  ux <- data$unexposed
  exl <- split(1:nrow(data), data$exposed)  # row indices per exposed ID
  for (ii in exl) {
    f2 <- !ux[ii] %in% unexposed2           # drop already-used unexposed IDs
    f12 <- ii[f2]
    dn <- !duplicated(ux[f12])              # one row per unexposed ID
    id3 <- f12[dn]
    # draw min(length(id3), n) positions out of all of id3
    is <- id3[sample.int(length(id3), min(length(id3), n))]
    i <- c(i, is)
    unexposed2 <- ux[i]
  }
  out <- data[i, ]
  out
}
```
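The main change in `f_minem` is that `split(1:nrow(data), data$exposed)` computes the row indices of every exposed group once, instead of rescanning the whole `exposed` column with `ex == group` on each iteration. A toy illustration on a hypothetical column; note that the groups come out in sorted factor-level order rather than in order of first appearance, so the processing order can differ from `unique(data$exposed)` when the column is unsorted.

```R
exposed <- c(2, 1, 2, 1, 3, 3)              # hypothetical, unsorted exposed column
idx <- split(seq_along(exposed), exposed)   # one integer vector of row numbers per group
idx                                         # `1`: 2 4, `2`: 1 3, `3`: 5 6
names(idx)                                  # "1" "2" "3": sorted levels
unique(exposed)                             # 2 1 3: order of first appearance
```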

Second round of benchmarks:

```R
library(data.table)  # setDT(), copy()

data <- copy(df)  # untouched copy; `setup` restores df from it before each run
microbenchmark::microbenchmark(
  recursive_sample3 = recursive_sample3(df, 5L),
  recursive_sample4 = recursive_sample4(setDT(df), 5L),  # from the data.table answer
  f_minem = f_minem(df, 5L),
  setup = {df <- copy(data)},
  times = 10
)
# Unit: milliseconds
#               expr    min     lq    mean  median      uq     max neval cld
#  recursive_sample3 6.2102 6.2974 9.63296 6.43245 16.3367 17.0746    10   a
#  recursive_sample4 3.5145 3.6249 3.67077 3.67075  3.7513  3.7970    10    b
#            f_minem 2.1705 2.1920 2.27510 2.23215  2.3784  2.4585    10    b
```

Answer 3

Score: 2


A `data.table` solution that keeps a running list of sampled values, which are excluded via `setdiff` (or `%!in%` from `collapse`):

```R
library(data.table)
library(collapse) # for %!in%

recursive_sample4 <- function(data, n) {
  sampled <- vector("list", uniqueN(data$exposed))
  data[
    , .(
      unexposed = {
        # setdiff() also drops duplicates within the group
        x <- setdiff(unexposed, unlist(sampled))
        # random subset of size min(length(x), n) drawn from all of x
        sampled[[.GRP]] <- x[sample.int(length(x), min(length(x), n))]
      }
    ), exposed
  ]
}

recursive_sample5 <- function(data, n) {
  sampled <- vector("list", uniqueN(data$exposed))
  data[
    , .(
      unexposed = {
        # %!in% keeps repeated values, so deduplicate explicitly
        x <- unique(unexposed[unexposed %!in% unlist(sampled)])
        sampled[[.GRP]] <- x[sample.int(length(x), min(length(x), n))]
      }
    ), exposed
  ]
}
```
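One subtle difference between the two exclusion idioms: `setdiff(x, y)` returns the *unique* elements of `x` not in `y`, whereas subsetting with `%in%` (or its negation, `collapse`'s `%!in%`) keeps repeated values, so without extra deduplication the same unexposed ID could be sampled twice for one exposed ID. A base-R illustration with hypothetical values, using `!x %in% y` as the negated form:

```R
x <- c(3, 3, 5, 7)   # unexposed IDs of one group, with a repeat
used <- 7            # IDs already sampled for earlier groups
setdiff(x, used)             # 3 5: duplicates dropped
x[!x %in% used]              # 3 3 5: duplicates kept
unique(x[!x %in% used])      # 3 5: same set as setdiff()
```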

Timing (including `recursive_sample3` by @minem):

```R
data <- copy(df)
microbenchmark::microbenchmark(
  recursive_sample2 = recursive_sample2(df, 5L),
  recursive_sample3 = recursive_sample3(df, 5L),
  recursive_sample4 = recursive_sample4(setDT(df), 5L),
  recursive_sample5 = recursive_sample5(setDT(df), 5L),
  setup = {df <- copy(data)}
)
#> Unit: milliseconds
#>               expr      min        lq       mean    median        uq      max neval
#>  recursive_sample2 416.5425 427.38700 452.520780 436.58280 459.79430 614.6392   100
#>  recursive_sample3   4.5211   5.16330   6.765060   5.79820   6.95425  14.0693   100
#>  recursive_sample4   3.2038   3.57650   4.676284   4.41120   4.90855  11.6975   100
#>  recursive_sample5   2.2327   2.58255   3.384131   3.27405   3.93265   8.7091   100
```

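Note that `recursive_sample3` can give erroneous results because of how `sample` behaves when its first argument has length 1: for a single number `x >= 1`, `sample(x, size)` samples from `1:x` rather than returning `x`. With only 700 possible unexposed IDs in the data below, a correct result cannot have more than 700 rows: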

```R
set.seed(123)
df <- tibble(exposed = rep(1:100, each = 100),
             unexposed = sample(1:700, 10000, replace = TRUE))
nrow(recursive_sample3(df, 10L))
#> [1] 704
```
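A defensive idiom, sketched in base R with hypothetical values (`safe_pick` is an illustrative name, not from any package): index the candidates with `sample.int(length(x), size)` instead of calling `sample(x, size)` directly, which always draws from the elements of `x`, even when `x` has length 1. Passing `length(x)` as the first argument also matters: `sample.int(m)` on its own is only a random permutation of `1:m`, so `x[sample.int(min(length(x), n))]` would always return the first `min(length(x), n)` elements of `x`, merely shuffled.

```R
x <- 5L                              # a single remaining candidate ID
set.seed(1)
sample(x, 1)                         # draws from 1:5, not necessarily 5

# Hypothetical helper: random subset of size min(length(x), n) from all of x
safe_pick <- function(x, n) x[sample.int(length(x), min(length(x), n))]
safe_pick(x, 1)                      # always 5

y <- c(10L, 20L, 30L, 40L)
# sample.int(2) is a permutation of 1:2, so only 10 and 20 can ever be chosen
sort(y[sample.int(2)])
```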

huangapple
  • Posted on June 5, 2023, 19:00:20
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76405781.html