How to close a custom read/write stream in Rust

Question

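Here is a readable and writable stream that I have implemented. I clone it with try_clone several times and use the clones in multiple threads. How can I safely close this VirtualStream so that HTTP requests get a normal response? If there is a better way to implement VirtualStream, even better. Thanks for your attention.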

use std::io;
use std::io::{Read, Write};
use std::net::TcpListener;
use std::sync::Arc;
use std::thread;

use std::borrow::BorrowMut;
use std::sync::mpsc::{channel, Receiver, RecvError, Sender};

pub struct VirtualStream {
    tx1: Sender<Vec<u8>>,
    rp1: Arc<Receiver<Vec<u8>>>,
    tx2: Sender<Vec<u8>>,
    rp2: Arc<Receiver<Vec<u8>>>,
}

unsafe impl Sync for VirtualStream {}

unsafe impl Send for VirtualStream {}

impl VirtualStream {
    pub fn new() -> Self {
        let (tx1, rp1): (Sender<Vec<u8>>, Receiver<Vec<u8>>) = channel();
        let (tx2, rp2): (Sender<Vec<u8>>, Receiver<Vec<u8>>) = channel();

        VirtualStream {
            rp1: Arc::new(rp1),
            tx1,
            rp2: Arc::new(rp2),
            tx2,
        }
    }

    pub fn produce(&self, buf: &[u8]) {
        // provide data
        let _ = self.tx1.send(buf.to_vec());
    }

    pub fn accept(&self) -> Result<Vec<u8>, RecvError> {
        // consume data
        self.rp2.recv().map_err(|e| e.into())
    }
    pub fn shutdown(&mut self) -> std::io::Result<()> {
        // close stream
        println!("shutdown....");
        Ok(())
    }
    pub fn try_clone(&self) -> Option<Self> {
        let tx1 = self.tx1.clone();
        let tx2 = self.tx2.clone();
        let rp1 = self.rp1.clone();
        let rp2 = self.rp2.clone();
        let cloned = VirtualStream { rp1, tx1, rp2, tx2 };
        Some(cloned)
    }
}

impl Read for VirtualStream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        match self.rp1.recv() {
            Ok(data) => {
                let len = data.len().min(buf.len());
                buf[0..len].copy_from_slice(&data[0..len]);
                println!("read ok");
                Ok(len)
            }
            Err(_) => {
                println!("read error");
                Err(io::ErrorKind::WouldBlock.into())
            }
        }
    }
}

impl Write for VirtualStream {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let writer = self.tx2.borrow_mut();
        writer.send(buf.to_vec()).unwrap();
        Ok(buf.len())
    }

    fn flush(&mut self) -> io::Result<()> {
        // no buffering, so nothing to flush
        Ok(())
    }
}

fn main() {
    let listener = TcpListener::bind("127.0.0.1:4000").unwrap();
    println!("Forward 127.0.0.1:4000 to 127.0.0.1:8000");

    for stream in listener.incoming() {
        if let Ok(mut client) = stream {
            thread::spawn(move || {
                let mut client2 = client.try_clone().unwrap();
                let mut client3 = client.try_clone().unwrap();
                let mut v_stream0 = VirtualStream::new();
                let mut virtual_stream2 = v_stream0.try_clone().unwrap();
                let mut virtual_stream3 = v_stream0.try_clone().unwrap();

                thread::spawn(move || loop {
                    println!("listen v_stream0 data");
                    match v_stream0.accept() {
                        Ok(data) => {
                            let string_result = std::str::from_utf8(&data).unwrap();
                            println!("http request packet:{:?}", string_result);
                            v_stream0.produce("HTTP/1.1 200 OK\n\nhello world123".as_bytes());
                            v_stream0.shutdown().expect("Failed to shutdown");
                            println!("after shutdown called");
                        }
                        Err(err) => {
                            // stream closed, so shut down the client
                            print!("close detect...{:?}", err);
                            client
                                .shutdown(std::net::Shutdown::Write)
                                .expect("failed to close stream");
                        }
                    }
                });
                thread::spawn(move || {
                    let _ = io::copy(&mut virtual_stream2, &mut client2);
                    println!("copy virtual stream to client complete or error...1");
                });
                let _ = io::copy(&mut client3, &mut virtual_stream3);
                println!("copy client to virtual stream complete or error...2");
            });
        }
    }
}

Testing with telnet:

☁  ~  telnet 127.0.0.1 4000
Trying 127.0.0.1...
Connected to localhost.
Escape character is '^]'.
GET / HTTP/1.1
HTTP/1.1 200 OK

hello world123

Testing with curl:

☁  examples [master] ⚡  curl -vs http://127.0.0.1:4000/
*   Trying 127.0.0.1...
* TCP_NODELAY set
* Connected to 127.0.0.1 (127.0.0.1) port 4000 (#0)
> GET / HTTP/1.1
> Host: 127.0.0.1:4000
> User-Agent: curl/7.64.1
> Accept: */*
>
< HTTP/1.1 200 OK
* no chunk, no close, no size. Assume close to signal end
<
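The curl trace hints at part of the problem: the response carries no Content-Length header and is not chunked, so curl can only detect the end of the body when the connection closes, and in this program it never does. (The headers are also terminated with "\n\n" rather than the "\r\n\r\n" that HTTP/1.1 specifies.) This does not replace closing the stream properly, but a response that declares its own length lets the client finish reading on its own. A minimal sketch, assuming a plain-text body; http_response is a hypothetical helper, not part of the original code:

```rust
/// Hypothetical helper: builds a minimal HTTP/1.1 response that tells the
/// client exactly how many body bytes to expect.
fn http_response(body: &str) -> Vec<u8> {
    format!(
        // Header lines end in CRLF; an empty line separates the headers from the body.
        // "Connection: close" announces that the server will close after this response.
        "HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
        body.len(), // Content-Length counts bytes, which is what str::len returns
        body
    )
    .into_bytes()
}

fn main() {
    // In the question's code, these bytes would be passed to produce() instead.
    let response = http_response("hello world123");
    print!("{}", String::from_utf8_lossy(&response));
}
```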

Answer 1

Score: 2

First, your unsafe impl blocks for both Send and Sync are unsound. Neither Sender nor Receiver is Sync, and your unsafe impls allow unsynchronized access to both of them.

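Generally, you should not implement these traits manually unless you are absolutely certain that it is safe to do so. If you aren't sure what "soundness" means, you should not be using the unsafe keyword.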
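If several threads really do need to pull from the same Receiver, the safe alternative to unsafe impl Sync is to put the Receiver behind a Mutex: for a Send message type such as Vec<u8>, Arc<Mutex<Receiver<T>>> is both Send and Sync, so the compiler auto-implements both traits for a wrapper around it. A minimal sketch; SharedRx and drain_with_workers are hypothetical names. Each message is delivered to exactly one worker, and dropping the last Sender ends every worker's loop:

```rust
use std::sync::mpsc::{channel, Receiver};
use std::sync::{Arc, Mutex};
use std::thread;

// Compile-time check: this call only compiles if T is Send + Sync.
fn assert_send_sync<T: Send + Sync>() {}

/// Hypothetical shared receiver: the Mutex provides the synchronization
/// that `unsafe impl Sync` merely claimed to exist.
#[derive(Clone)]
struct SharedRx(Arc<Mutex<Receiver<Vec<u8>>>>);

impl SharedRx {
    /// Blocks for the next message; None once every Sender has been dropped.
    fn recv(&self) -> Option<Vec<u8>> {
        self.0.lock().unwrap().recv().ok()
    }
}

/// Sends `messages` chunks to `workers` threads and returns how many were received in total.
fn drain_with_workers(messages: usize, workers: usize) -> usize {
    let (tx, rx) = channel::<Vec<u8>>();
    let shared = SharedRx(Arc::new(Mutex::new(rx)));
    let handles: Vec<_> = (0..workers)
        .map(|_| {
            let rx = shared.clone();
            thread::spawn(move || {
                let mut count = 0;
                while rx.recv().is_some() {
                    count += 1;
                }
                count
            })
        })
        .collect();
    for i in 0..messages {
        tx.send(vec![i as u8]).unwrap();
    }
    drop(tx); // closing the channel makes every worker's recv() return None
    handles.into_iter().map(|h| h.join().unwrap()).sum()
}

fn main() {
    assert_send_sync::<SharedRx>();
    println!("received {}", drain_with_workers(100, 4));
}
```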

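It looks like you're just trying to create a Read and Write pair where whatever is written to the Write end comes out of the Read end. You can implement this directly on top of Sender and Receiver: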

use std::{
    io::{Read, Write},
    sync::mpsc::{channel, Receiver, Sender},
    time::Duration,
};

#[derive(Debug, Clone)]
pub struct ChannelWrite(Sender<Vec<u8>>);

impl Write for ChannelWrite {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.0
            .send(buf.into())
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))?;

        Ok(buf.len())
    }

    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}

#[derive(Debug)]
pub struct ChannelRead {
    rx: Receiver<Vec<u8>>,
    current: Vec<u8>,
    current_pos: usize,
}

impl Read for ChannelRead {
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if buf.len() == 0 {
            return Ok(0);
        }

        loop {
            let remaining = self.current.len() - self.current_pos;
            if remaining > 0 {
                let to_fill = std::cmp::min(remaining, buf.len());
                buf[..to_fill]
                    .copy_from_slice(&self.current[self.current_pos..(self.current_pos + to_fill)]);
                self.current_pos += to_fill;
                return Ok(to_fill);
            }

            match self.rx.recv() {
                Ok(b) => {
                    self.current = b;
                    self.current_pos = 0;
                }
                Err(_) => return Ok(0),
            };
        }
    }
}

pub fn channel_pipe() -> (ChannelWrite, ChannelRead) {
    let (tx, rx) = channel();

    (
        ChannelWrite(tx),
        ChannelRead {
            rx,
            current: vec![],
            current_pos: 0,
        },
    )
}

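It's very rare that you actually need both read and write capabilities at once. Usually you have a unidirectional stream of bytes, which is what is implemented here.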
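When a two-way stream really is needed, as VirtualStream tried to provide, one option that keeps this design is to pair two one-way pipes, one per direction. A sketch, assuming the compact copy of ChannelWrite, ChannelRead and channel_pipe included below; DuplexEnd, duplex and round_trip are hypothetical names. Each direction is closed independently by dropping its writer, much like a TCP half-close:

```rust
use std::io::{self, Read, Write};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

// Compact copy of the answer's ChannelWrite / ChannelRead / channel_pipe (same behavior).
#[derive(Clone)]
struct ChannelWrite(Sender<Vec<u8>>);

impl Write for ChannelWrite {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.0
            .send(buf.to_vec())
            .map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))?;
        Ok(buf.len())
    }
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

struct ChannelRead {
    rx: Receiver<Vec<u8>>,
    current: Vec<u8>,
    pos: usize,
}

impl Read for ChannelRead {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        // Pull chunks until one has unread bytes; a recv error means every writer is gone (EOF).
        while self.pos == self.current.len() {
            match self.rx.recv() {
                Ok(chunk) => {
                    self.current = chunk;
                    self.pos = 0;
                }
                Err(_) => return Ok(0),
            }
        }
        let n = (self.current.len() - self.pos).min(buf.len());
        buf[..n].copy_from_slice(&self.current[self.pos..self.pos + n]);
        self.pos += n;
        Ok(n)
    }
}

fn channel_pipe() -> (ChannelWrite, ChannelRead) {
    let (tx, rx) = channel();
    (ChannelWrite(tx), ChannelRead { rx, current: Vec::new(), pos: 0 })
}

/// Hypothetical endpoint of an in-memory two-way connection.
struct DuplexEnd {
    reader: ChannelRead,
    writer: ChannelWrite,
}

fn duplex() -> (DuplexEnd, DuplexEnd) {
    let (a_to_b_w, a_to_b_r) = channel_pipe();
    let (b_to_a_w, b_to_a_r) = channel_pipe();
    (
        DuplexEnd { reader: b_to_a_r, writer: a_to_b_w },
        DuplexEnd { reader: a_to_b_r, writer: b_to_a_w },
    )
}

/// Sends `request` from one end, echoes it back from the other, and returns the reply.
fn round_trip(request: &[u8]) -> Vec<u8> {
    let (mut client, server) = duplex();
    let handle = thread::spawn(move || {
        let DuplexEnd { mut reader, mut writer } = server;
        let mut buf = [0u8; 1024];
        let n = reader.read(&mut buf).unwrap();
        writer.write_all(&buf[..n]).unwrap();
        // `writer` is dropped here, closing the server-to-client direction.
    });
    client.writer.write_all(request).unwrap();
    let mut reply = Vec::new();
    // read_to_end returns at EOF, i.e. once the server's writer has been dropped.
    client.reader.read_to_end(&mut reply).unwrap();
    handle.join().unwrap();
    reply
}

fn main() {
    println!("{:?}", String::from_utf8_lossy(&round_trip(b"ping")));
}
```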

Note that ChannelWrite derives Clone, so you can simply clone the writer and hand the clones to multiple threads.

ChannelRead cannot be cloned because Receiver cannot.
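To illustrate handing writer clones to several threads, here is a sketch that reuses a compact copy of the answer's types; collect_from_writers is a hypothetical helper. Chunks arrive in whatever order the threads happen to run, but each write_all call here forwards its bytes as a single chunk, so chunks do not interleave. Note that the original writer has to be dropped as well, otherwise the reader never sees end-of-stream:

```rust
use std::io::{self, Read, Write};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

// Compact copy of the answer's ChannelWrite / ChannelRead / channel_pipe (same behavior).
#[derive(Clone)]
struct ChannelWrite(Sender<Vec<u8>>);

impl Write for ChannelWrite {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.0
            .send(buf.to_vec())
            .map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))?;
        Ok(buf.len())
    }
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

struct ChannelRead {
    rx: Receiver<Vec<u8>>,
    current: Vec<u8>,
    pos: usize,
}

impl Read for ChannelRead {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        // Pull chunks until one has unread bytes; a recv error means every writer is gone (EOF).
        while self.pos == self.current.len() {
            match self.rx.recv() {
                Ok(chunk) => {
                    self.current = chunk;
                    self.pos = 0;
                }
                Err(_) => return Ok(0),
            }
        }
        let n = (self.current.len() - self.pos).min(buf.len());
        buf[..n].copy_from_slice(&self.current[self.pos..self.pos + n]);
        self.pos += n;
        Ok(n)
    }
}

fn channel_pipe() -> (ChannelWrite, ChannelRead) {
    let (tx, rx) = channel();
    (ChannelWrite(tx), ChannelRead { rx, current: Vec::new(), pos: 0 })
}

/// Hypothetical helper: `n` threads each write one tagged chunk through
/// their own clone of the writer; the reader collects everything until EOF.
fn collect_from_writers(n: usize) -> Vec<u8> {
    let (write, mut read) = channel_pipe();
    let handles: Vec<_> = (0..n)
        .map(|i| {
            let mut w = write.clone(); // each thread owns its own clone
            thread::spawn(move || {
                w.write_all(format!("[{}]", i).as_bytes()).unwrap();
            }) // `w` is dropped when its thread finishes
        })
        .collect();
    drop(write); // drop the original too, or read_to_end would block forever
    let mut out = Vec::new();
    read.read_to_end(&mut out).unwrap();
    for h in handles {
        h.join().unwrap();
    }
    out
}

fn main() {
    println!("{}", String::from_utf8_lossy(&collect_from_writers(3)));
}
```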

To use this, you would do something like let (mut write, mut read) = channel_pipe();, distribute the writer(s) wherever they are needed, and then read from the read side.

Note that:

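  • Send errors are converted to an std::io::Error with the kind BrokenPipe. You will get this error if you try to write after the reader has been dropped.
  • Receive errors can only happen when all senders have been dropped, which simply means there is no more data to read, so read returns Ok(0) to signal this. The reader therefore never produces any I/O errors at all.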
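Both rules can be checked directly. A sketch reusing a compact copy of the answer's types; write_after_reader_dropped and reads_after_writers_dropped are hypothetical helper names:

```rust
use std::io::{self, Read, Write};
use std::sync::mpsc::{channel, Receiver, Sender};

// Compact copy of the answer's ChannelWrite / ChannelRead / channel_pipe (same behavior).
#[derive(Clone)]
struct ChannelWrite(Sender<Vec<u8>>);

impl Write for ChannelWrite {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.0
            .send(buf.to_vec())
            .map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))?;
        Ok(buf.len())
    }
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

struct ChannelRead {
    rx: Receiver<Vec<u8>>,
    current: Vec<u8>,
    pos: usize,
}

impl Read for ChannelRead {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        // Pull chunks until one has unread bytes; a recv error means every writer is gone (EOF).
        while self.pos == self.current.len() {
            match self.rx.recv() {
                Ok(chunk) => {
                    self.current = chunk;
                    self.pos = 0;
                }
                Err(_) => return Ok(0),
            }
        }
        let n = (self.current.len() - self.pos).min(buf.len());
        buf[..n].copy_from_slice(&self.current[self.pos..self.pos + n]);
        self.pos += n;
        Ok(n)
    }
}

fn channel_pipe() -> (ChannelWrite, ChannelRead) {
    let (tx, rx) = channel();
    (ChannelWrite(tx), ChannelRead { rx, current: Vec::new(), pos: 0 })
}

/// Reader dropped first: returns the error kind of the next write.
fn write_after_reader_dropped() -> io::ErrorKind {
    let (mut w, r) = channel_pipe();
    drop(r);
    w.write(b"late").unwrap_err().kind()
}

/// Writers dropped first: returns the byte counts of two consecutive reads.
fn reads_after_writers_dropped() -> (usize, usize) {
    let (mut w, mut r) = channel_pipe();
    w.write_all(b"abc").unwrap();
    drop(w);
    let mut buf = [0u8; 8];
    let first = r.read(&mut buf).unwrap(); // queued data is still delivered
    let second = r.read(&mut buf).unwrap(); // then Ok(0): end of stream
    (first, second)
}

fn main() {
    println!("{:?}", write_after_reader_dropped());
    println!("{:?}", reads_after_writers_dropped());
}
```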

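So, the way you "close" this kind of "virtual stream" is simply to drop all of the writers (implicitly by letting them go out of scope, or explicitly with drop()). Once you do, the reader signals end-of-stream by returning Ok(0) (after all queued data has been read, of course).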
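Applied to the setup in the question, this suggests that the thread producing the response should drop its writer when it is done; the io::copy that forwards the response to the client then returns by itself, and that is the point at which the real socket could be shut down. A sketch under simplifying assumptions: serve_once is a hypothetical name, the request arrives as a single chunk, and a Vec<u8> stands in for the client's TcpStream:

```rust
use std::io::{self, Read, Write};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::thread;

// Compact copy of the answer's ChannelWrite / ChannelRead / channel_pipe (same behavior).
#[derive(Clone)]
struct ChannelWrite(Sender<Vec<u8>>);

impl Write for ChannelWrite {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.0
            .send(buf.to_vec())
            .map_err(|e| io::Error::new(io::ErrorKind::BrokenPipe, e))?;
        Ok(buf.len())
    }
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}

struct ChannelRead {
    rx: Receiver<Vec<u8>>,
    current: Vec<u8>,
    pos: usize,
}

impl Read for ChannelRead {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        // Pull chunks until one has unread bytes; a recv error means every writer is gone (EOF).
        while self.pos == self.current.len() {
            match self.rx.recv() {
                Ok(chunk) => {
                    self.current = chunk;
                    self.pos = 0;
                }
                Err(_) => return Ok(0),
            }
        }
        let n = (self.current.len() - self.pos).min(buf.len());
        buf[..n].copy_from_slice(&self.current[self.pos..self.pos + n]);
        self.pos += n;
        Ok(n)
    }
}

fn channel_pipe() -> (ChannelWrite, ChannelRead) {
    let (tx, rx) = channel();
    (ChannelWrite(tx), ChannelRead { rx, current: Vec::new(), pos: 0 })
}

/// Hypothetical one-shot handler: returns what the client would receive.
fn serve_once(request: &[u8]) -> Vec<u8> {
    let (mut req_w, mut req_r) = channel_pipe(); // client -> handler
    let (mut resp_w, mut resp_r) = channel_pipe(); // handler -> client

    let handler = thread::spawn(move || {
        let mut buf = [0u8; 1024];
        // Assumption: the whole request arrives in one chunk.
        if req_r.read(&mut buf).unwrap() > 0 {
            let body = "hello world123";
            let response = format!("HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}", body.len(), body);
            resp_w.write_all(response.as_bytes()).unwrap();
        }
        // `resp_w` is dropped here: this is the "close" of the response stream.
    });

    req_w.write_all(request).unwrap();
    drop(req_w);

    let mut client = Vec::new(); // stand-in for the client TcpStream
    io::copy(&mut resp_r, &mut client).unwrap(); // returns once `resp_w` is dropped
    // With a real socket, client.shutdown(Shutdown::Write) would go here.
    handler.join().unwrap();
    client
}

fn main() {
    let reply = serve_once(b"GET / HTTP/1.1\r\n\r\n");
    print!("{}", String::from_utf8_lossy(&reply));
}
```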

Here is a simple example where the reader is given to a thread:

fn main() {
    let (mut write, mut read) = channel_pipe();

    let th = std::thread::spawn(move || {
        let mut buf = [0; 1024];
        loop {
            match read.read(&mut buf).unwrap() {
                // Ok(0) means every writer has been dropped: end of stream.
                0 => break,
                v => println!("read: {:?}", std::str::from_utf8(&buf[..v]).unwrap()),
            };
        }
    });

    for i in 0..10 {
        write.write_all(format!("hello, {}", i).as_bytes()).unwrap();
        std::thread::sleep(Duration::from_millis(500));
    }

    // Dropping the only writer closes the stream, so the reader thread leaves its loop.
    drop(write);

    th.join().unwrap();
}

huangapple
  • Posted on July 3, 2023 09:28:26
  • If you repost this article, please keep the link: https://go.coder-hub.com/76601380.html