2024-01-02 rust borrow checker
In some TCP connection code some values need to be parsed, but if there is not enough input more input is read. Basically I want this:
fn decode_header<'a>(reader: &'a mut [u8]) -> Result<(u64, &'a [u8]), MoreInputExpected> {
todo!();
}
let mut b = MyBuffer::new(stream);
let (version, hostname) = b.step(decode_header).await?;
Some helpers, such as MoreInputExpected
and MyBuffer
are needed, which are defined in the complete code listing at the bottom. The MyBuffer::step()
function is the central piece here. I would expect it to be similar to this:
impl MyBuffer {
async fn step<'b,F,R>(&'b mut self, decode: F) -> Result<R,ConnectionError>
where
F: for<'c> Fn(&'c mut [u8]) -> Result<R, MoreInputExpected>,
{
loop {
if self.length == self.buffer.len() {
return Err(ConnectionError::MoreInputExpected);
}
let read = async_read(&mut self.stream, &mut self.buffer[self.length..]).await?;
if read == 0 {
return Err(ConnectionError::Closed);
}
self.length += read;
if let Ok(retval) = decode(&mut self.buffer[self.start..self.length]) {
return Ok(retval);
}
}
}
}
But this yield the following violations of the borrow checker:
error[E0308]: mismatched types
--> <source>:77:31
|
77 | let (version, hostname) = b.step(decode_header).await?;
| ^^^^^^^^^^^^^^^^^^^^^ one type is more general than the other
|
= note: expected enum `Result<(_, &'c _), _>`
found enum `Result<(_, &_), _>`
note: the lifetime requirement is introduced here
--> <source>:54:40
|
54 | F: for<'c> Fn(&'c mut [u8]) -> Result<R, MoreInputExpected>,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The for<'c>
is needed to avoid borrowing twice the buffer. Does anybody have an idea how to solve this using stable Rust?
Complete code listing
use std::net::TcpStream;
#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)]
pub struct MoreInputExpected {
}
impl std::error::Error for MoreInputExpected {
}
impl std::fmt::Display for MoreInputExpected {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "MoreInputExpected")
}
}
#[derive(Debug)]
enum ConnectionError {
IO(std::io::Error),
MoreInputExpected,
Closed,
IncompatiableVersion,
}
impl From<std::io::Error> for ConnectionError {
fn from(value: std::io::Error) -> Self {
Self::IO(value)
}
}
impl From<MoreInputExpected> for ConnectionError {
fn from(_: MoreInputExpected) -> Self {
Self::MoreInputExpected
}
}
struct MyBuffer {
buffer: [u8; 512],
start: usize,
length: usize,
stream: TcpStream,
}
async fn async_read(stream: &mut TcpStream, buffer: &mut [u8]) -> Result<usize,ConnectionError> {
todo!();
}
impl MyBuffer {
fn new(stream: TcpStream) -> Self { Self { buffer: [0u8; 512], start: 0, length: 0, stream } }
}
impl MyBuffer {
async fn step<'b,F,R>(&'b mut self, decode: F) -> Result<R,ConnectionError>
where
F: for<'c> Fn(&'c mut [u8]) -> Result<R, MoreInputExpected>,
{
loop {
if self.length == self.buffer.len() {
return Err(ConnectionError::MoreInputExpected);
}
let read = async_read(&mut self.stream, &mut self.buffer[self.length..]).await?;
if read == 0 {
return Err(ConnectionError::Closed);
}
self.length += read;
if let Ok(retval) = decode(&mut self.buffer[self.start..self.length]) {
return Ok(retval);
}
}
}
}
fn decode_header<'a>(reader: &'a mut [u8]) -> Result<(u64, &'a [u8]), MoreInputExpected> {
todo!();
}
pub async fn connection(stream: TcpStream) -> Result<(),ConnectionError> {
let mut b = MyBuffer::new(stream);
let (version, hostname) = b.step(decode_header).await?;
Ok(())
}