How to implement iter() for a type that wraps an NdArray?

Question

As part of making a Rust-based tensor library, I have implemented a 2D Tensor type as:

    use std::cell::{Ref, RefCell};
    use std::ops::Deref;
    use std::rc::Rc;

    use ndarray::prelude::*;

    // Shared, interior-mutable handle to the tensor's storage.
    pub struct Tensor(Rc<RefCell<TensorData>>);

    pub struct TensorData {
        pub data: Array2<f64>,
        pub grad: Array2<f64>,
        // other fields...
    }

    impl TensorData {
        fn new(data: Array2<f64>) -> TensorData {
            let shape = data.raw_dim();
            TensorData {
                data,
                // The gradient starts as zeros with the same shape as the data.
                grad: Array2::zeros(shape),
                // other fields...
            }
        }
    }

    impl Tensor {
        pub fn new(array: Array2<f64>) -> Tensor {
            Tensor(Rc::new(RefCell::new(TensorData::new(array))))
        }

        // Returns a borrow guard that dereferences to the underlying array.
        pub fn data(&self) -> impl Deref<Target = Array2<f64>> + '_ {
            Ref::map((*self.0).borrow(), |mi| &mi.data)
        }
    }

Now I want to be able to iterate over the rows of a tensor (e.g. for implementing stochastic gradient descent). This is what I have so far:

    impl Tensor {
        // other methods...
        pub fn iter(&self) -> impl Iterator<Item = Tensor> + '_ {
            self.data().outer_iter().map(|el| {
                let reshaped_and_cloned_el = el.into_shape((el.shape()[0], 1)).unwrap().mapv(|el| el.clone());
                reshaped_and_cloned_el
            }).map(|el| Tensor::new(el))
        }
    }

There are two problems with this: (1) it is visually unpleasant, and (2) it does not compile, because the iterator borrows from a temporary value (the guard returned by data()) that is dropped when the function returns:

    error[E0515]: cannot return value referencing temporary value
       --> src/tensor/mod.rs:348:9
        |
    348 |           self.data().outer_iter().map(|el| {
        |           ^----------
        |           |
        |  _________temporary value created here
        | |
    349 | |             let reshaped_and_cloned_el = el.into_shape((el.shape()[0], 1)).unwrap().mapv(...
    350 | |             reshaped_and_cloned_el
    351 | |         }).map(|el| Tensor::new(el))
        | |____________________________________^ returns a value referencing data owned by the current function
        |
        = help: use `.collect()` to allocate the iterator

What would be an alternative implementation of iter() that would not have these issues?

Answer 1

Score: 1

Unfortunately, you cannot use the axis iterators here: the returned value would have to hold both the data guard and an iterator borrowing from that guard, which makes it a self-referential struct. You can, however, move the guard into the closure and index along the axis instead:

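    pub fn iter(&self) -> impl Iterator<Item = Tensor> + '_ {
        // Take the guard once; the `move` closure below owns it, so the
        // returned iterator keeps the borrow alive for as long as it exists.
        let data = self.data();
        (0..data.shape()[0]).map(move |i| {
            // View of row `i`; it borrows from the guard owned by the closure.
            let el = data.index_axis(Axis(0), i);
            // Reshape the row to (n, 1) and copy it into an owned array.
            let reshaped_and_cloned_el = el
                .into_shape((el.shape()[0], 1))
                .unwrap()
                .mapv(|el| el.clone());
            Tensor::new(reshaped_and_cloned_el)
        })
    }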
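
The same pattern can be tried without ndarray. The sketch below is only an illustration under stated assumptions: `Matrix` and `row_sums` are hypothetical names that do not appear in the question, and the rows are stored as a plain `Vec<Vec<f64>>` rather than an `Array2<f64>` so that the example builds with the standard library alone. What it is meant to show is the ownership structure: the borrow guard is moved into the closure, and rows are reached by index rather than through a borrowing iterator.

```rust
use std::cell::RefCell;
use std::ops::Deref;
use std::rc::Rc;

// Hypothetical stand-in for `Tensor`: same Rc<RefCell<..>> wrapper,
// but backed by Vec<Vec<f64>> instead of ndarray's Array2<f64>.
pub struct Matrix(Rc<RefCell<Vec<Vec<f64>>>>);

impl Matrix {
    pub fn new(rows: Vec<Vec<f64>>) -> Matrix {
        Matrix(Rc::new(RefCell::new(rows)))
    }

    // Returns the RefCell borrow guard, which dereferences to the rows.
    pub fn data(&self) -> impl Deref<Target = Vec<Vec<f64>>> + '_ {
        (*self.0).borrow()
    }

    // The guard is moved into the closure, so the iterator owns it and
    // nothing in the return value borrows from a temporary.
    pub fn iter(&self) -> impl Iterator<Item = Matrix> + '_ {
        let data = self.data();
        (0..data.len()).map(move |i| Matrix::new(vec![data[i].clone()]))
    }
}

// Sums each row by iterating over the single-row matrices from `iter()`.
pub fn row_sums(m: &Matrix) -> Vec<f64> {
    m.iter()
        .map(|row| {
            let sum: f64 = row.data()[0].iter().sum();
            sum
        })
        .collect()
}
```

One consequence of this design is worth keeping in mind: the guard lives as long as the iterator does, so the underlying `RefCell` stays immutably borrowed during iteration, and a `borrow_mut()` on the same cell in that window will panic.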

Posted by huangapple on July 20, 2023, 14:58:11
When reposting, please keep the link to this article: https://go.coder-hub.com/76727378.html