How to implement iter() for a type that wraps an NdArray?

Question

As part of building a Rust-based tensor library, I have implemented a 2D Tensor type as follows:

use ndarray::prelude::*;
use std::cell::{Ref, RefCell};
use std::ops::Deref;
use std::rc::Rc;

pub struct Tensor(Rc<RefCell<TensorData>>);

pub struct TensorData {
    pub data: Array2<f64>,
    pub grad: Array2<f64>,
    // other fields...
}

impl TensorData {
    fn new(data: Array2<f64>) -> TensorData {
        let shape = data.raw_dim();
        TensorData {
            data,
            grad: Array2::zeros(shape),
            // other fields...
        }
    }
}

impl Tensor {
    pub fn new(array: Array2<f64>) -> Tensor {
        Tensor(Rc::new(RefCell::new(TensorData::new(array))))
    }

    // Returns a `Ref` guard narrowed to the `data` field; the RefCell stays
    // shared-borrowed for as long as the guard is alive.
    pub fn data(&self) -> impl Deref<Target = Array2<f64>> + '_ {
        Ref::map((*self.0).borrow(), |mi| &mi.data)
    }
}
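To make the role of the returned guard concrete, here is a minimal std-only sketch of the same accessor pattern. To avoid depending on the `ndarray` crate, it substitutes a hypothetical `Matrix` struct (row-major `Vec<f64>` plus a column count) for `Array2<f64>` and omits `grad`; `Rc`, `RefCell`, `Ref::map` and `try_borrow_mut` are the real standard-library items. It shows that whatever `data()` returns keeps the `RefCell` borrowed until it is dropped, which is exactly the value an iterator built from it would need to keep alive.

```rust
use std::cell::{Ref, RefCell};
use std::ops::Deref;
use std::rc::Rc;

// Hypothetical stand-in for `Array2<f64>`: row-major values plus a column count.
pub struct Matrix {
    pub values: Vec<f64>,
    pub cols: usize,
}

pub struct TensorData {
    pub data: Matrix,
    // `grad` and other fields omitted in this sketch.
}

pub struct Tensor(Rc<RefCell<TensorData>>);

impl Tensor {
    pub fn new(data: Matrix) -> Tensor {
        Tensor(Rc::new(RefCell::new(TensorData { data })))
    }

    // Same shape as the question's accessor: `Ref::map` narrows the guard from
    // `TensorData` to its `data` field, but the shared borrow stays active.
    pub fn data(&self) -> impl Deref<Target = Matrix> + '_ {
        Ref::map((*self.0).borrow(), |d| &d.data)
    }
}

fn main() {
    let t = Tensor::new(Matrix { values: vec![1.0, 2.0, 3.0, 4.0], cols: 2 });

    let guard = t.data();
    assert_eq!(guard.values.len() / guard.cols, 2); // two rows
    // While `guard` is alive, the RefCell refuses a mutable borrow.
    assert!(t.0.try_borrow_mut().is_err());

    drop(guard);
    // Once the guard is gone, mutable access works again.
    assert!(t.0.try_borrow_mut().is_ok());
}
```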

Now I want to be able to iterate over the rows of a tensor (e.g. for implementing stochastic gradient descent). This is what I have so far:

impl Tensor {
    // other methods...

    pub fn iter(&self) -> impl Iterator<Item = Tensor> + '_ {
        self.data().outer_iter().map(|el| {
            let reshaped_and_cloned_el = el.into_shape((el.shape()[0], 1)).unwrap().mapv(|el| el.clone());
            reshaped_and_cloned_el
        }).map(|el| Tensor::new(el))
    }
}

The issue with this is that (1) it is visually unpleasant and (2) it does not compile: `self.data()` creates a temporary guard that is dropped when `iter()` returns, but the returned iterator still borrows from it:

error[E0515]: cannot return value referencing temporary value
   --> src/tensor/mod.rs:348:9
    |
348 |           self.data().outer_iter().map(|el| {
    |           ^----------
    |           |
    |  _________temporary value created here
    | |
349 | |             let reshaped_and_cloned_el = el.into_shape((el.shape()[0], 1)).unwrap().mapv(...
350 | |             reshaped_and_cloned_el
351 | |         }).map(|el| Tensor::new(el))
    | |____________________________________^ returns a value referencing data owned by the current function
    |
    = help: use `.collect()` to allocate the iterator
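The compiler's hint to use `.collect()` points at the simplest workaround: copy the rows out while the guard is held, and return an iterator that owns the copies. Below is a minimal sketch of that eager approach. So that it runs without `ndarray`, it uses a hypothetical `Matrix` (row-major `Vec<f64>` plus a column count) in place of `Array2<f64>`, with hypothetical `rows()`/`row()` helpers standing in for `shape()[0]`/`index_axis(Axis(0), i)`; with ndarray, the per-row step would be the question's own reshape-and-clone expression. Returning the concrete `std::vec::IntoIter<Tensor>` makes it explicit that the result does not borrow `self`.

```rust
use std::cell::{Ref, RefCell};
use std::ops::Deref;
use std::rc::Rc;

// Hypothetical stand-in for `Array2<f64>`: row-major values plus a column count.
pub struct Matrix {
    pub values: Vec<f64>,
    pub cols: usize,
}

impl Matrix {
    pub fn rows(&self) -> usize {
        self.values.len() / self.cols
    }

    // Borrowed slice of row `i`, playing the role of `index_axis(Axis(0), i)`.
    pub fn row(&self, i: usize) -> &[f64] {
        &self.values[i * self.cols..(i + 1) * self.cols]
    }
}

pub struct TensorData {
    pub data: Matrix,
}

pub struct Tensor(Rc<RefCell<TensorData>>);

impl Tensor {
    pub fn new(data: Matrix) -> Tensor {
        Tensor(Rc::new(RefCell::new(TensorData { data })))
    }

    pub fn data(&self) -> impl Deref<Target = Matrix> + '_ {
        Ref::map((*self.0).borrow(), |d| &d.data)
    }

    // Eager variant: copy each row into its own column-vector tensor while the
    // guard is alive, then hand back an iterator that owns the copies. The guard
    // is dropped when this function returns, so the result does not borrow `self`.
    pub fn iter(&self) -> std::vec::IntoIter<Tensor> {
        let data = self.data();
        let rows: Vec<Tensor> = (0..data.rows())
            .map(|i| Tensor::new(Matrix { values: data.row(i).to_vec(), cols: 1 }))
            .collect();
        rows.into_iter()
    }
}

fn main() {
    let t = Tensor::new(Matrix { values: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], cols: 3 });
    let mut it = t.iter();

    // The source can be mutated while `it` is still alive (no panic).
    t.0.borrow_mut().data.values[0] = 10.0;

    // Rows were copied before the mutation, so the first row is unchanged.
    let first = it.next().unwrap();
    assert_eq!(first.data().values, vec![1.0, 2.0, 3.0]);
    assert_eq!(it.count(), 1);
}
```

The trade-off is that all rows are copied up front, even if the caller stops early; in exchange, nothing holds a borrow on the source tensor while the rows are consumed, so it can be mutated (for example, to update `grad`) inside the loop.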

What would be an alternative implementation of iter() that would not have these issues?

Answer 1

Score: 1

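Unfortunately, you cannot return one of the axis iterators (such as `outer_iter()`) here, because the returned value would have to contain both the data guard and an iterator borrowing from that guard, i.e. a self-referential struct. But you can move the guard into the closure and index into it to access one row at a time: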

pub fn iter(&self) -> impl Iterator<Item = Tensor> + '_ {
    // The guard is moved into the closure below, so the returned iterator owns it.
    let data = self.data();
    (0..data.shape()[0]).map(move |i| {
        // Borrow row `i` from the guard only for the duration of this call.
        let el = data.index_axis(Axis(0), i);
        let reshaped_and_cloned_el = el
            .into_shape((el.shape()[0], 1))
            .unwrap()
            .mapv(|el| el.clone());
        Tensor::new(reshaped_and_cloned_el)
    })
}
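The closure above compiles because it captures the guard by value and only borrows from it inside each call, so nothing it stores points into the guard. The same idea can be written as a named iterator type, which helps if you want to store the iterator in a struct or implement extra traits on it. This is a std-only sketch: a hypothetical `Matrix` (row-major `Vec<f64>` plus a column count) stands in for `Array2<f64>`, and `Rows`, `rows()` and `row()` are illustrative names, not ndarray API. Note that `data()` returns the concrete `Ref<'_, Matrix>` here, because an opaque `impl Deref` type cannot be named in a struct field.

```rust
use std::cell::{Ref, RefCell};
use std::rc::Rc;

// Hypothetical stand-in for `Array2<f64>`: row-major values plus a column count.
pub struct Matrix {
    pub values: Vec<f64>,
    pub cols: usize,
}

impl Matrix {
    pub fn rows(&self) -> usize {
        self.values.len() / self.cols
    }

    // Borrowed slice of row `i`, playing the role of `index_axis(Axis(0), i)`.
    pub fn row(&self, i: usize) -> &[f64] {
        &self.values[i * self.cols..(i + 1) * self.cols]
    }
}

pub struct TensorData {
    pub data: Matrix,
}

pub struct Tensor(Rc<RefCell<TensorData>>);

// Named counterpart of the answer's closure: it owns the `Ref` guard and a plain
// index, and never stores a reference *into* the guard, so it is not self-referential.
pub struct Rows<'a> {
    data: Ref<'a, Matrix>,
    index: usize,
}

impl Iterator for Rows<'_> {
    type Item = Tensor;

    fn next(&mut self) -> Option<Tensor> {
        if self.index >= self.data.rows() {
            return None;
        }
        // Borrow the guard only for this call, copying the row out as an n x 1 tensor.
        let values = self.data.row(self.index).to_vec();
        self.index += 1;
        Some(Tensor::new(Matrix { values, cols: 1 }))
    }
}

impl Tensor {
    pub fn new(data: Matrix) -> Tensor {
        Tensor(Rc::new(RefCell::new(TensorData { data })))
    }

    // Returns the concrete `Ref` type so that `Rows` can name it in a field.
    pub fn data(&self) -> Ref<'_, Matrix> {
        Ref::map((*self.0).borrow(), |d| &d.data)
    }

    pub fn iter(&self) -> Rows<'_> {
        Rows { data: self.data(), index: 0 }
    }
}

fn main() {
    let t = Tensor::new(Matrix { values: vec![1.0, 2.0, 3.0, 4.0], cols: 2 });
    let mut it = t.iter();
    let first = it.next().unwrap();
    assert_eq!(first.data().values, vec![1.0, 2.0]);

    // The iterator still holds a shared borrow, so a mutable borrow is refused
    // (`borrow_mut()` would panic here instead of returning an error).
    assert!(t.0.try_borrow_mut().is_err());

    drop(it);
    assert!(t.0.try_borrow_mut().is_ok());
}
```

With either lazy form, the iterator keeps the `RefCell` shared-borrowed until it is dropped, so calling `borrow_mut()` on the same tensor inside the loop (e.g. to write gradients) will panic at runtime.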

huangapple
  • Published on 2023-07-20 14:58:11
  • When reposting, please keep the link to this article: https://go.coder-hub.com/76727378.html