#![allow(unused)]
use core::fmt;
use managed::{ManagedMap, ManagedSlice};
use crate::config::{REASSEMBLY_BUFFER_COUNT, REASSEMBLY_BUFFER_SIZE};
use crate::storage::Assembler;
use crate::time::{Duration, Instant};
/// Storage type used for the payload of a packet being reassembled.
///
/// With the `alloc` feature enabled this is a heap-allocated `Vec<u8>`.
#[cfg(feature = "alloc")]
type Buffer = alloc::vec::Vec<u8>;
/// Storage type used for the payload of a packet being reassembled.
///
/// Without the `alloc` feature this is a fixed-size array of
/// `REASSEMBLY_BUFFER_SIZE` bytes.
#[cfg(not(feature = "alloc"))]
type Buffer = [u8; REASSEMBLY_BUFFER_SIZE];
/// Error returned by [`PacketAssembler`] operations that cannot be carried
/// out, e.g. a fragment that does not fit in the buffer or a total size that
/// conflicts with one set earlier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct AssemblerError;

impl fmt::Display for AssemblerError {
    /// Writes the fixed text `AssemblerError`; the error carries no payload.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AssemblerError")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for AssemblerError {}
/// Error returned by [`PacketAssemblerSet::get`] when no assembler matches
/// the requested key and no free assembler is left to claim.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct AssemblerFullError;

impl fmt::Display for AssemblerFullError {
    /// Writes the fixed text `AssemblerFullError`; the error carries no payload.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AssemblerFullError")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for AssemblerFullError {}
/// Reassembly state for a single fragmented packet.
///
/// Fragment payloads are copied into `buffer`, and the offset and length of
/// every fragment is reported to `assembler`. The packet is complete once the
/// assembler's front matches the expected `total_size`.
#[derive(Debug)]
pub struct PacketAssembler<K> {
    /// Key of the packet currently being reassembled; `None` while unused.
    key: Option<K>,
    /// Storage the fragment payloads are written into.
    buffer: Buffer,
    /// Bookkeeping of the `(offset, length)` ranges added so far.
    assembler: Assembler,
    /// Expected size of the complete packet, once it has been set.
    total_size: Option<usize>,
    /// Instant compared against by `PacketAssemblerSet::remove_expired`.
    expires_at: Instant,
}

impl<K> PacketAssembler<K> {
    /// Creates an unused assembler with no key, no total size and an expiry
    /// of `Instant::ZERO`.
    pub const fn new() -> Self {
        Self {
            key: None,

            #[cfg(feature = "alloc")]
            buffer: Buffer::new(),
            #[cfg(not(feature = "alloc"))]
            buffer: [0u8; REASSEMBLY_BUFFER_SIZE],

            assembler: Assembler::new(),
            total_size: None,
            expires_at: Instant::ZERO,
        }
    }

    /// Marks the assembler as unused and forgets all fragment bookkeeping.
    ///
    /// The bytes in `buffer` are intentionally left untouched: `assemble`
    /// resets the state first and then hands out a slice of the buffer.
    pub(crate) fn reset(&mut self) {
        self.key = None;
        self.assembler.clear();
        self.total_size = None;
        self.expires_at = Instant::ZERO;
    }

    /// Sets the expected total size of the packet.
    ///
    /// # Errors
    ///
    /// Returns `AssemblerError` if a different total size was set before, or
    /// (without the `alloc` feature) if `size` exceeds the fixed buffer.
    /// With `alloc`, the buffer is grown to `size` instead.
    pub(crate) fn set_total_size(&mut self, size: usize) -> Result<(), AssemblerError> {
        // A size that was already announced must never change.
        if matches!(self.total_size, Some(old_size) if old_size != size) {
            return Err(AssemblerError);
        }

        #[cfg(not(feature = "alloc"))]
        if self.buffer.len() < size {
            return Err(AssemblerError);
        }

        #[cfg(feature = "alloc")]
        if self.buffer.len() < size {
            self.buffer.resize(size, 0);
        }

        self.total_size = Some(size);
        Ok(())
    }

    /// Returns the instant stored as this assembler's expiry.
    pub(crate) fn expires_at(&self) -> Instant {
        self.expires_at
    }

    /// Adds a fragment whose payload is produced by `f`.
    ///
    /// `f` receives the part of the buffer starting at `offset` and must
    /// return the number of bytes it wrote there.
    ///
    /// # Errors
    ///
    /// Returns `AssemblerError` if `offset` lies beyond the end of the
    /// buffer, or propagates the error returned by `f`.
    ///
    /// # Panics
    ///
    /// Panics if `f` reports more bytes than the slice it was given can hold.
    pub(crate) fn add_with(
        &mut self,
        offset: usize,
        f: impl Fn(&mut [u8]) -> Result<usize, AssemblerError>,
    ) -> Result<(), AssemblerError> {
        if self.buffer.len() < offset {
            return Err(AssemblerError);
        }

        let window = &mut self.buffer[offset..];
        let capacity = window.len();
        let len = f(window)?;
        // Same condition as `offset + len <= buffer.len()`, but written so
        // that it cannot overflow for a bogus `len`.
        assert!(len <= capacity);

        net_debug!(
            "frag assembler: receiving {} octets at offset {}",
            len,
            offset
        );

        self.assembler.add(offset, len);
        Ok(())
    }

    /// Adds a fragment by copying `data` into the buffer at `offset`.
    ///
    /// # Errors
    ///
    /// Returns `AssemblerError` if `offset + data.len()` overflows `usize`,
    /// or (without the `alloc` feature) if the fragment does not fit in the
    /// fixed buffer. With `alloc`, the buffer is grown to fit instead.
    pub(crate) fn add(&mut self, data: &[u8], offset: usize) -> Result<(), AssemblerError> {
        // `offset` may originate from the wire; reject ranges that overflow
        // instead of panicking or wrapping around.
        let end = offset.checked_add(data.len()).ok_or(AssemblerError)?;

        #[cfg(not(feature = "alloc"))]
        if self.buffer.len() < end {
            return Err(AssemblerError);
        }

        #[cfg(feature = "alloc")]
        if self.buffer.len() < end {
            self.buffer.resize(end, 0);
        }

        let len = data.len();
        self.buffer[offset..end].copy_from_slice(data);

        net_debug!(
            "frag assembler: receiving {} octets at offset {}",
            len,
            offset
        );

        self.assembler.add(offset, len);
        Ok(())
    }

    /// Returns the reassembled packet if it is complete, otherwise `None`.
    ///
    /// On success the assembler is reset (and thereby becomes free again)
    /// before the first `total_size` bytes of the buffer are returned.
    pub(crate) fn assemble(&mut self) -> Option<&'_ [u8]> {
        if !self.is_complete() {
            return None;
        }

        // `is_complete` only holds when a total size is set, so `?` never
        // actually returns here; it merely avoids an `unwrap`.
        let total_size = self.total_size?;
        self.reset();
        Some(&self.buffer[..total_size])
    }

    /// Returns `true` when a total size is set and equals the value reported
    /// by `Assembler::peek_front`.
    pub(crate) fn is_complete(&self) -> bool {
        self.total_size == Some(self.assembler.peek_front())
    }

    /// Returns `true` while no key is assigned to this assembler.
    fn is_free(&self) -> bool {
        self.key.is_none()
    }
}
/// A fixed-size pool of `REASSEMBLY_BUFFER_COUNT` packet assemblers, each
/// identified by a key of type `K` while in use.
#[derive(Debug)]
pub struct PacketAssemblerSet<K: Eq + Copy> {
    /// The pooled assemblers; an entry without a key is free.
    assemblers: [PacketAssembler<K>; REASSEMBLY_BUFFER_COUNT],
}

impl<K: Eq + Copy> PacketAssemblerSet<K> {
    /// Constant used to repeat-initialize the array, since
    /// `PacketAssembler` is not `Copy`.
    const NEW_PA: PacketAssembler<K> = PacketAssembler::new();

    /// Creates a set in which every assembler is free.
    pub fn new() -> Self {
        let assemblers = [Self::NEW_PA; REASSEMBLY_BUFFER_COUNT];
        Self { assemblers }
    }

    /// Looks up the assembler for `key`, claiming a free one if none exists.
    ///
    /// An assembler that already carries `key` is returned unchanged (its
    /// expiry is not updated). Otherwise the last free assembler in the
    /// array is assigned `key` and `expires_at` and returned.
    ///
    /// # Errors
    ///
    /// Returns `AssemblerFullError` when no assembler matches `key` and none
    /// is free.
    pub(crate) fn get(
        &mut self,
        key: &K,
        expires_at: Instant,
    ) -> Result<&mut PacketAssembler<K>, AssemblerFullError> {
        let existing = self
            .assemblers
            .iter()
            .position(|candidate| candidate.key.as_ref() == Some(key));
        if let Some(index) = existing {
            return Ok(&mut self.assemblers[index]);
        }

        // No match: take the free assembler closest to the end of the array.
        let index = self
            .assemblers
            .iter()
            .rposition(|candidate| candidate.is_free())
            .ok_or(AssemblerFullError)?;

        let claimed = &mut self.assemblers[index];
        claimed.key = Some(*key);
        claimed.expires_at = expires_at;
        Ok(claimed)
    }

    /// Resets every in-use assembler whose expiry lies strictly before
    /// `timestamp`.
    pub fn remove_expired(&mut self, timestamp: Instant) {
        self.assemblers
            .iter_mut()
            .filter(|entry| !entry.is_free() && entry.expires_at < timestamp)
            .for_each(|entry| entry.reset());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal key type used to tell assemblers apart in the tests.
    #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
    struct Key {
        id: usize,
    }

    /// Overlapping fragments: the later write wins for the shared bytes.
    #[test]
    fn packet_assembler_overlap() {
        let mut p_assembler = PacketAssembler::<Key>::new();

        p_assembler.set_total_size(5).unwrap();

        let data = b"Rust";
        // Both fragments must be accepted; do not silently drop a failure.
        p_assembler.add(&data[..], 0).unwrap();
        p_assembler.add(&data[..], 1).unwrap();

        assert_eq!(p_assembler.assemble(), Some(&b"RRust"[..]));
    }

    /// In-order fragments yield the packet only once all bytes are present.
    #[test]
    fn packet_assembler_assemble() {
        let mut p_assembler = PacketAssembler::<Key>::new();

        let data = b"Hello World!";

        p_assembler.set_total_size(data.len()).unwrap();

        p_assembler.add(b"Hello ", 0).unwrap();
        assert_eq!(p_assembler.assemble(), None);

        p_assembler.add(b"World!", b"Hello ".len()).unwrap();

        assert_eq!(p_assembler.assemble(), Some(&b"Hello World!"[..]));
    }

    /// Fragments arriving out of order are still reassembled correctly.
    #[test]
    fn packet_assembler_out_of_order_assemble() {
        let mut p_assembler = PacketAssembler::<Key>::new();

        let data = b"Hello World!";

        p_assembler.set_total_size(data.len()).unwrap();

        p_assembler.add(b"World!", b"Hello ".len()).unwrap();
        assert_eq!(p_assembler.assemble(), None);

        p_assembler.add(b"Hello ", 0).unwrap();

        assert_eq!(p_assembler.assemble(), Some(&b"Hello World!"[..]));
    }

    /// A fresh set can hand out an assembler.
    #[test]
    fn packet_assembler_set() {
        let key = Key { id: 1 };

        let mut set = PacketAssemblerSet::new();

        assert!(set.get(&key, Instant::ZERO).is_ok());
    }

    /// Once every assembler is claimed, a new key is rejected.
    #[test]
    fn packet_assembler_set_full() {
        let mut set = PacketAssemblerSet::new();
        for i in 0..REASSEMBLY_BUFFER_COUNT {
            set.get(&Key { id: i }, Instant::ZERO).unwrap();
        }

        // Keys `0..REASSEMBLY_BUFFER_COUNT` are taken, so use the first id
        // outside that range. A hard-coded id could collide with an existing
        // key when the configured count is larger than that id.
        assert!(set
            .get(
                &Key {
                    id: REASSEMBLY_BUFFER_COUNT
                },
                Instant::ZERO
            )
            .is_err());
    }

    /// Assemblers are released on completion and can be reused repeatedly.
    #[test]
    fn packet_assembler_set_assembling_many() {
        let mut set = PacketAssemblerSet::new();

        let key = Key { id: 0 };
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assert_eq!(assr.assemble(), None);
        assr.set_total_size(0).unwrap();
        assr.assemble().unwrap();

        // The assembler for this key was freed above; claim one again.
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assert_eq!(assr.assemble(), None);
        assr.set_total_size(0).unwrap();
        assr.assemble().unwrap();

        let key = Key { id: 1 };
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assr.set_total_size(0).unwrap();
        assr.assemble().unwrap();

        let key = Key { id: 2 };
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assr.set_total_size(0).unwrap();
        assr.assemble().unwrap();

        let key = Key { id: 2 };
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assr.set_total_size(2).unwrap();
        assr.add(&[0x00], 0).unwrap();
        assert_eq!(assr.assemble(), None);
        // Looking the key up again must return the same, partially filled
        // assembler.
        let assr = set.get(&key, Instant::ZERO).unwrap();
        assr.add(&[0x01], 1).unwrap();
        assert_eq!(assr.assemble(), Some(&[0x00, 0x01][..]));
    }
}