How to implement iter() for a type that wraps an NdArray?
As part of making a Rust-based tensor library, I have implemented a 2D `Tensor` type as:
use std::cell::{Ref, RefCell};
use std::ops::Deref;
use std::rc::Rc;

use ndarray::prelude::*;

pub struct Tensor(Rc<RefCell<TensorData>>);

pub struct TensorData {
    pub data: Array2<f64>,
    pub grad: Array2<f64>,
    // other fields...
}

impl TensorData {
    fn new(data: Array2<f64>) -> TensorData {
        let shape = data.raw_dim();
        TensorData {
            data,
            grad: Array2::zeros(shape),
            // other fields...
        }
    }
}

impl Tensor {
    pub fn new(array: Array2<f64>) -> Tensor {
        Tensor(Rc::new(RefCell::new(TensorData::new(array))))
    }

    pub fn data(&self) -> impl Deref<Target = Array2<f64>> + '_ {
        Ref::map((*self.0).borrow(), |mi| &mi.data)
    }
}
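The `data()` accessor above narrows a whole-cell `RefCell` borrow guard down to a single field via `Ref::map`. As a minimal, self-contained sketch of that pattern, using a plain `Vec<f64>` and hypothetical `Wrapper`/`Inner` types in place of the `ndarray`-backed ones:

```rust
use std::cell::{Ref, RefCell};
use std::ops::Deref;
use std::rc::Rc;

struct Inner {
    data: Vec<f64>,
}

struct Wrapper(Rc<RefCell<Inner>>);

impl Wrapper {
    fn new(data: Vec<f64>) -> Wrapper {
        Wrapper(Rc::new(RefCell::new(Inner { data })))
    }

    // Borrow the whole cell, then narrow the guard to just the `data` field.
    fn data(&self) -> impl Deref<Target = Vec<f64>> + '_ {
        Ref::map(self.0.borrow(), |inner| &inner.data)
    }
}

fn main() {
    let w = Wrapper::new(vec![1.0, 2.0, 3.0]);
    // The returned guard derefs straight to the inner Vec.
    assert_eq!(w.data().len(), 3);
    assert_eq!(*w.data(), vec![1.0, 2.0, 3.0]);
}
```

Note that the returned guard keeps the `RefCell` immutably borrowed for as long as the caller holds it.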
Now I want to be able to iterate over the rows of a tensor (e.g. for implementing stochastic gradient descent). This is what I have so far:
impl Tensor {
    // other methods...
    pub fn iter(&self) -> impl Iterator<Item = Tensor> + '_ {
        self.data().outer_iter().map(|el| {
            let reshaped_and_cloned_el = el.into_shape((el.shape()[0], 1)).unwrap().mapv(|el| el.clone());
            reshaped_and_cloned_el
        }).map(|el| Tensor::new(el))
    }
}
The issue with this is that it is (1) visually unpleasant and (2) does not compile, because the iterator borrows from the temporary guard returned by `data()`, which is dropped at the end of the expression:
error[E0515]: cannot return value referencing temporary value
--> src/tensor/mod.rs:348:9
|
348 | self.data().outer_iter().map(|el| {
| ^----------
| |
| _________temporary value created here
| |
349 | | let reshaped_and_cloned_el = el.into_shape((el.shape()[0], 1)).unwrap().mapv(...
350 | | reshaped_and_cloned_el
351 | | }).map(|el| Tensor::new(el))
| |____________________________________^ returns a value referencing data owned by the current function
|
= help: use `.collect()` to allocate the iterator
What would be an alternative implementation of `iter()` that would not have these issues?
Answer 1
Score: 1
Unfortunately, you cannot use the axis iterators here, because that would create a self-referential value (the iterator borrowing from the data guard it is returned alongside). But you can index into the array to access each row instead:
pub fn iter(&self) -> impl Iterator<Item = Tensor> + '_ {
    let data = self.data();
    (0..data.shape()[0]).map(move |i| {
        let el = data.index_axis(Axis(0), i);
        let reshaped_and_cloned_el = el
            .into_shape((el.shape()[0], 1))
            .unwrap()
            .mapv(|el| el.clone());
        Tensor::new(reshaped_and_cloned_el)
    })
}
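The index-based trick can be demonstrated with standard-library types only: `move` transfers the guard returned by `data()` into the closure, so the returned iterator owns its borrow instead of referencing a temporary. A minimal sketch, again with a hypothetical `Wrapper` around a `Vec<f64>` standing in for `Array2<f64>`:

```rust
use std::cell::{Ref, RefCell};
use std::ops::Deref;
use std::rc::Rc;

struct Wrapper(Rc<RefCell<Vec<f64>>>);

impl Wrapper {
    fn data(&self) -> impl Deref<Target = Vec<f64>> + '_ {
        self.0.borrow()
    }

    // Index instead of using the Vec's own iterator: `move` transfers the
    // guard into the closure, so the returned iterator keeps it alive.
    fn iter(&self) -> impl Iterator<Item = f64> + '_ {
        let data = self.data();
        (0..data.len()).map(move |i| (*data)[i])
    }
}

fn main() {
    let w = Wrapper(Rc::new(RefCell::new(vec![1.0, 2.0, 3.0])));
    let rows: Vec<f64> = w.iter().collect();
    assert_eq!(rows, vec![1.0, 2.0, 3.0]);
}
```

One trade-off to be aware of: while such an iterator is alive, the `RefCell` stays immutably borrowed, so any `borrow_mut()` in the meantime will panic.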