Rust可变迭代器遍历具有平面表示的矩阵的列

huangapple go评论64阅读模式
英文:

Rust Mutable Iterator over Cols of Matrix with flat Representation

问题

我有一个矩阵,它存储在一个Vec<T>中,采用行优先表示法。
在我的代码中,我经常会写类似以下的内容:

for i in 0..width {
    for j in 0..height {
        let element = matrix.get_mut(i, j);
        // 做一些操作
    }
}

这当然可以工作,但我认为如果我可以使用迭代器会更加方便。想象一个返回类型为 fn cols_mut() -> impl ExactSizeIterator<Item = impl ExactSizeIterator<Item = &mut T>> 的函数。然后我可以这样做:

matrix.cols_mut().for_each(|col| {
    col.fold(0, |sum, val| {
        *val += sum;
        *val
    });
});

编写一个不可变版本相当简单:

fn cols(&self) -> impl ExactSizeIterator<Item = impl ExactSizeIterator<Item = &T>> {
    (0..self.width()).map(move |i| {
        (0..self.height()).map(move |j| {
            self.get(i, j).unwrap()
        })
    })
}

但是,当我尝试使其可变时,我无法弄清楚如何满足借用检查器的要求。

英文:

I have a matrix that is stored in an Vec&lt;T&gt; with a row-major representation.
In my code I often end up writing something like

for i in 0..width {
    for j in 0..height {
        let element = matrix.get_mut(i, j);
        // do something
    }
}

That does of course work, but I think the ergonomics would be nicer if I could use iterators. Imagine a function with the return type of fn cols_mut() -&gt; impl ExactSizeIterator&lt;Item = impl ExactSizeIterator&lt;Item = &amp;mut T&gt;&gt;.
Then I could do something like

matrix.cols_mut.for_each(|col| {
    col.fold(0, |sum, val| {
        *val += sum;
        *val
    });
});

Writing a non mutable version is pretty straightforward

fn cols(&amp;self) -&gt; impl ExactSizeIterator&lt;Item = impl ExactSizeIterator&lt;Item = &amp;T&gt;&gt; {
    (0..self.width()).map(move |i| {
        (0..self.height()).map(move |j| {
            self.get(i, j).unwrap()
        })
    })
}

but I just can't figure out how to satisfy the borrow checker when making this mutable.

答案1

得分: 1

如果您确实经常以这种方式进行处理,我建议您切换到列主表示法,因为对矩阵进行处理需要大量的缓存未命中,并且阻止了自动矢量化。

至于您的问题,可以使用不安全的方式和指针算术来实现。

// 这是为了显示我们有行主矩阵。
fn calc_idx(row: usize, col: usize, total_rows: usize, total_cols: usize) -> usize {
    assert!((row < total_rows) & (col < total_cols));
    row * total_cols + col
}

fn make_column_iterators<'matrix, T: Sized + 'matrix>(
    matrix: &'matrix mut [T],
    rows: usize,
    cols: usize,
) -> impl ExactSizeIterator<Item = impl ExactSizeIterator<Item = &'matrix mut T> + 'matrix> {
    assert_eq!(matrix.len(), rows * cols);
    let ptr = matrix.as_mut_ptr();
    (0..cols).map(move |col_idx| unsafe {
        // 安全性:我们在范围内,因为 `col_idx < cols`。
        let ptr = ptr.add(col_idx);
        (0..rows)
            // 安全性:索引 (i < rows, j < cols)
            // 在范围内。
            .map(move |row_idx| ptr.add(row_idx * cols))
            // 函数生命周期参数防止别名
            // 矩阵和单个引用。
            // 我们只引用每个元素一次。
            .map(|p| &mut *p)
    })
}
英文:

If you really often do your processing that way, I recommend you to switch to column-major representation because processing matrix against it representation lead to a lot of cache misses and prevents autovectorization.

As for your question, it is possible with unsafe and pointer arithmetics.

// This is to show that we have row-major matrix.
fn calc_idx(row: usize, col: usize, total_rows: usize, total_cols: usize) -&gt; usize {
    assert!((row &lt; total_rows) &amp; (col &lt; total_cols));
    row * total_cols + col
}

fn make_column_iterators&lt;&#39;matrix, T: Sized + &#39;matrix&gt;(
    matrix: &amp;&#39;matrix mut [T],
    rows: usize,
    cols: usize,
) -&gt; impl ExactSizeIterator&lt;Item = impl ExactSizeIterator&lt;Item = &amp;&#39;matrix mut T&gt; + &#39;matrix&gt; {
    assert_eq!(matrix.len(), rows * cols);
    let ptr = matrix.as_mut_ptr();
    (0..cols).map(move |col_idx| unsafe {
        // SAFETY: We are in range because `col_idx &lt; cols`.
        let ptr = ptr.add(col_idx);
        (0..rows)
            // SAFETY: Indices (i &lt; rows, j &lt; cols)
            // in range.
            .map(move |row_idx| ptr.add(row_idx*cols))
            // Function lifetime arguments prevent aliasing
            // matrix and individual references.
            // We take reference to each element exactly once.
            .map(|p| &amp;mut *p)
    })
}

huangapple
  • 本文由 发表于 2023年5月29日 21:17:24
  • 转载请务必保留本文链接:https://go.coder-hub.com/76357728.html
匿名

发表评论

匿名网友

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen:

确定