use std::collections::{HashMap, LinkedList};
use std::io::{Read, Write};
use std::net::SocketAddrV4;
use crate::buffer::RecvQueue;
use crate::connection::Connection;
use crate::seq::Seq;
use crate::util::remove_from_list;
use crate::util::time::Duration;
use crate::{
AcceptError, AcceptedTcpState, CloseError, ConnectError, Dependencies, ListenError, Payload,
PollState, PopPacketError, PushPacketError, RecvError, RstCloseError, SendError, Shutdown,
ShutdownError, TcpConfig, TcpError, TcpFlags, TcpHeader, TcpState, TcpStateEnum, TcpStateTrait,
TimerRegisteredBy,
};
/// State of a newly created socket that is neither listening nor connecting.
#[derive(Debug)]
pub struct InitState<X: Dependencies> {
    pub(crate) common: Common<X>,
    /// Configuration handed on to the listener or connection created from this state.
    pub(crate) config: TcpConfig,
}
/// A listening socket and the bookkeeping for the child connections it has spawned.
#[derive(Debug)]
pub struct ListenState<X: Dependencies> {
    pub(crate) common: Common<X>,
    /// Configuration given to each child connection.
    pub(crate) config: TcpConfig,
    /// Limit applied separately to the accept queue and to the number of not-yet-acceptable
    /// children (the "SYN queue") when deciding whether to drop incoming packets.
    pub(crate) max_backlog: u32,
    /// Headers sent by the listener itself (for example RSTs collected from children during
    /// `rst_close`). These are popped before any child's packets.
    pub(crate) send_buffer: LinkedList<TcpHeader>,
    /// All child connections, keyed by `ChildTcpKey`.
    pub(crate) children: slotmap::DenseSlotMap<ChildTcpKey, ChildEntry<X>>,
    /// Maps a connection's (remote, local) address pair to its child key.
    pub(crate) conn_map: HashMap<RemoteLocalPair, ChildTcpKey>,
    /// Children in the Established or CloseWait state, in the order they became acceptable.
    pub(crate) accept_queue: LinkedList<ChildTcpKey>,
    /// Children that currently report `wants_to_send()`.
    pub(crate) to_send: LinkedList<ChildTcpKey>,
}
/// Active open in progress: created by `connect`; waiting for the peer's SYN.
#[derive(Debug)]
pub struct SynSentState<X: Dependencies> {
    pub(crate) common: Common<X>,
    pub(crate) connection: Connection<X::Instant>,
}
/// The peer's SYN has been received, but our SYN has not been acknowledged yet.
#[derive(Debug)]
pub struct SynReceivedState<X: Dependencies> {
    pub(crate) common: Common<X>,
    pub(crate) connection: Connection<X::Instant>,
}
/// Handshake complete; data may flow in both directions.
#[derive(Debug)]
pub struct EstablishedState<X: Dependencies> {
    pub(crate) common: Common<X>,
    pub(crate) connection: Connection<X::Instant>,
}
/// Our FIN has been queued; neither its acknowledgement nor the peer's FIN has arrived yet.
#[derive(Debug)]
pub struct FinWaitOneState<X: Dependencies> {
    pub(crate) common: Common<X>,
    pub(crate) connection: Connection<X::Instant>,
}
/// Our FIN has been acknowledged; waiting for the peer's FIN.
#[derive(Debug)]
pub struct FinWaitTwoState<X: Dependencies> {
    pub(crate) common: Common<X>,
    pub(crate) connection: Connection<X::Instant>,
}
/// Both sides have sent a FIN, but ours has not been acknowledged yet.
#[derive(Debug)]
pub struct ClosingState<X: Dependencies> {
    pub(crate) common: Common<X>,
    pub(crate) connection: Connection<X::Instant>,
}
/// Both FINs exchanged and acknowledged; lingers until a timer moves it to `ClosedState`.
#[derive(Debug)]
pub struct TimeWaitState<X: Dependencies> {
    pub(crate) common: Common<X>,
    pub(crate) connection: Connection<X::Instant>,
}
/// The peer's FIN has been received while our side is still open.
#[derive(Debug)]
pub struct CloseWaitState<X: Dependencies> {
    pub(crate) common: Common<X>,
    pub(crate) connection: Connection<X::Instant>,
}
/// Our FIN was sent after the peer's FIN; waiting for its acknowledgement.
#[derive(Debug)]
pub struct LastAckState<X: Dependencies> {
    pub(crate) common: Common<X>,
    pub(crate) connection: Connection<X::Instant>,
}
/// Socket has been reset and still has RST packets to transmit.
#[derive(Debug)]
pub struct RstState<X: Dependencies> {
    pub(crate) common: Common<X>,
    /// Outgoing headers; a listener's `rst_close` fills this with RST headers only.
    pub(crate) send_buffer: LinkedList<TcpHeader>,
    /// Whether this socket had been connecting/connected (as opposed to a listener or init
    /// socket).
    pub(crate) was_connected: bool,
}
/// Terminal state.
#[derive(Debug)]
pub struct ClosedState<X: Dependencies> {
    pub(crate) common: Common<X>,
    /// Receive data carried over from the connection (e.g. from `TimeWaitState` via
    /// `into_recv_buffer`) — NOTE(review): presumably still readable after close; confirm.
    pub(crate) recv_buffer: RecvQueue,
    /// Whether this socket had been connecting/connected before closing.
    pub(crate) was_connected: bool,
}
/// Returned by listener child helpers when a `ChildTcpKey` no longer refers to a child.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
struct ChildNotFound;
/// State shared by every TCP state variant.
#[derive(Debug)]
pub(crate) struct Common<X: Dependencies> {
    /// Host-provided services (time, timers).
    pub(crate) deps: X,
    /// `Some` when this socket is a child of a listening socket; identifies it in the parent.
    pub(crate) child_key: Option<ChildTcpKey>,
    /// Pending socket error, reported via `poll` and taken by `clear_error`.
    pub(crate) error: Option<TcpError>,
}
impl<X: Dependencies> Common<X> {
    /// Register a timer that applies `f` to this socket's state at `time`.
    ///
    /// The callback given to `deps` receives a `TimerRegisteredBy` value. For
    /// `TimerRegisteredBy::Parent`, `f` is applied to the state directly. For
    /// `TimerRegisteredBy::Child`, the state passed in belongs to the parent listener. If that
    /// parent is still listening, `f` is applied to the child identified by `self.child_key`
    /// through `ListenState::with_child`. Otherwise, or if the child no longer exists, the
    /// timer does nothing.
    ///
    /// # Panics
    ///
    /// The callback panics if it runs as a child timer but `self.child_key` was `None`.
    pub fn register_timer(
        &self,
        time: X::Instant,
        f: impl FnOnce(TcpStateEnum<X>) -> TcpStateEnum<X> + Send + Sync + 'static,
    ) {
        // Capture the key by value so the `move` closures do not borrow `self`.
        let child_key = self.child_key;
        let timer_cb_inner = move |mut parent_state, state_type| {
            match state_type {
                TimerRegisteredBy::Parent => f(parent_state),
                TimerRegisteredBy::Child => {
                    // A child's timer is delivered to the parent; if the parent has stopped
                    // listening, it no longer has children to update.
                    let TcpStateEnum::Listen(parent_listen_state) = &mut parent_state else {
                        return parent_state;
                    };
                    let child_key = child_key.expect(
                        "The timer was supposedly registered by a child state, but there was no \
                        key to identify the child",
                    );
                    let rv = parent_listen_state.with_child(child_key, |state| (f(state), ()));
                    // A missing child (accepted or closed since the timer was registered) is
                    // not an error: the timer is simply dropped.
                    #[allow(clippy::single_match)]
                    match rv {
                        Ok(()) => {}
                        Err(ChildNotFound) => {}
                    }
                    parent_state
                }
            }
        };
        // Adapt to the `&mut TcpState` interface expected by `deps`. NOTE(review): presumably
        // `with_state` takes the enum out, applies the closure and stores the result — confirm.
        let timer_cb = move |parent_state: &mut TcpState<X>, state_type| {
            parent_state.with_state(|state| (timer_cb_inner(state, state_type), ()))
        };
        self.deps.register_timer(time, timer_cb);
    }
    /// Current time as reported by `deps`.
    pub fn current_time(&self) -> X::Instant {
        self.deps.current_time()
    }
    /// Store `new_error` only if no error is already pending.
    ///
    /// Returns `true` if the error was stored, `false` if an earlier error was kept.
    pub fn set_error_if_unset(&mut self, new_error: TcpError) -> bool {
        if self.error.is_none() {
            self.error = Some(new_error);
            return true;
        }
        false
    }
}
/// The (remote, local) address pair identifying one connection. A listening socket uses it as
/// the key that routes incoming packets to the matching child.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct RemoteLocalPair {
    remote: SocketAddrV4,
    local: SocketAddrV4,
}
impl RemoteLocalPair {
    /// Combine the peer's address (`remote`) with our own address (`local`).
    pub fn new(remote: SocketAddrV4, local: SocketAddrV4) -> Self {
        RemoteLocalPair {
            local,
            remote,
        }
    }
}
// Key type for a listener's `children` slot map. NOTE(review): slotmap keys are presumably
// versioned, so a stale key held by a child's timer should not alias a newer child — confirm
// against the slotmap documentation.
slotmap::new_key_type! { pub(crate) struct ChildTcpKey; }
/// A listener's record of one child connection.
#[derive(Debug)]
pub(crate) struct ChildEntry<X: Dependencies> {
    /// The child's state. It is `None` only while `ListenState::with_child` has taken it out to
    /// run a callback.
    state: Option<TcpStateEnum<X>>,
    /// The child's address pair; also its key in `ListenState::conn_map`.
    conn_addrs: RemoteLocalPair,
}
impl<X: Dependencies> InitState<X> {
    /// Create a socket in the initial state. It has no pending error and is not the child of
    /// any listener.
    pub fn new(deps: X, config: TcpConfig) -> Self {
        Self {
            common: Common {
                deps,
                child_key: None,
                error: None,
            },
            config,
        }
    }
}
impl<X: Dependencies> TcpStateTrait<X> for InitState<X> {
    /// An unconnected socket has nothing to flush, so it closes immediately.
    fn close(self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
        let closed: TcpStateEnum<X> = ClosedState::new(self.common, None, false).into();
        (closed, Ok(()))
    }
    /// Same as `close`: there is no peer to reset.
    fn rst_close(self) -> (TcpStateEnum<X>, Result<(), RstCloseError>) {
        let closed: TcpStateEnum<X> = ClosedState::new(self.common, None, false).into();
        (closed, Ok(()))
    }
    fn shutdown(self, _how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
        (self.into(), Err(ShutdownError::NotConnected))
    }
    /// Become a listener once `associate_fn` succeeds. The effective backlog limit is one more
    /// than `backlog` (saturating).
    fn listen<T, E>(
        self,
        backlog: u32,
        associate_fn: impl FnOnce() -> Result<T, E>,
    ) -> (TcpStateEnum<X>, Result<T, ListenError<E>>) {
        match associate_fn() {
            Err(e) => (self.into(), Err(ListenError::FailedAssociation(e))),
            Ok(value) => {
                let limit = backlog.saturating_add(1);
                let listener = ListenState::new(self.common, self.config, limit);
                (listener.into(), Ok(value))
            }
        }
    }
    /// Start an active open toward `remote_addr` from the local address chosen by
    /// `associate_fn`.
    ///
    /// # Panics
    ///
    /// Panics if the associated local address is unspecified (0.0.0.0).
    fn connect<T, E>(
        self,
        remote_addr: SocketAddrV4,
        associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
        let (local_addr, value) = match associate_fn() {
            Ok(pair) => pair,
            Err(e) => return (self.into(), Err(ConnectError::FailedAssociation(e))),
        };
        assert!(!local_addr.ip().is_unspecified());
        let initial_seq = Seq::new(0);
        let connection = Connection::new(local_addr, remote_addr, initial_seq, self.config);
        (SynSentState::new(self.common, connection).into(), Ok(value))
    }
    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
        (self.into(), Err(SendError::NotConnected))
    }
    fn recv(self, _writer: impl Write, _len: usize) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
        (self.into(), Err(RecvError::NotConnected))
    }
    fn clear_error(&mut self) -> Option<TcpError> {
        self.common.error.take()
    }
    /// Only a pending error can be reported before the socket listens or connects.
    fn poll(&self) -> PollState {
        if self.common.error.is_some() {
            PollState::ERROR
        } else {
            PollState::empty()
        }
    }
    fn wants_to_send(&self) -> bool {
        false
    }
    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
        None
    }
}
impl<X: Dependencies> ListenState<X> {
    /// Create a listener with no children. `max_backlog` limits both the accept queue and the
    /// number of not-yet-acceptable children (see `push_packet`).
    fn new(common: Common<X>, config: TcpConfig, max_backlog: u32) -> Self {
        ListenState {
            common,
            config,
            max_backlog,
            send_buffer: LinkedList::new(),
            children: slotmap::DenseSlotMap::with_key(),
            conn_map: HashMap::new(),
            accept_queue: LinkedList::new(),
            to_send: LinkedList::new(),
        }
    }
    /// Create a SYN-RECEIVED child for the connection described by `header`, feed it the
    /// initial packet, and register it in `conn_map` and the send/accept queues.
    ///
    /// # Panics
    ///
    /// Panics if `header` lacks SYN or carries RST, if the child's connection rejects the
    /// packet, or if a child already exists for this address pair.
    fn register_child(&mut self, header: &TcpHeader, payload: Payload) -> ChildTcpKey {
        // The packet's source is the remote end; its destination is our local address.
        let conn_addrs = RemoteLocalPair::new(header.src(), header.dst());
        let key = self.children.insert_with_key(|key| {
            // The child gets forked dependencies and remembers its own key so that its timers
            // can be routed back to it through the parent (see `Common::register_timer`).
            let common = Common {
                deps: self.common.deps.fork(),
                child_key: Some(key),
                error: None,
            };
            assert!(header.flags.contains(TcpFlags::SYN));
            assert!(!header.flags.contains(TcpFlags::RST));
            let mut connection =
                Connection::new(header.dst(), header.src(), Seq::new(0), self.config);
            connection.push_packet(header, payload).unwrap();
            let new_tcp = SynReceivedState::new(common, connection);
            ChildEntry {
                state: Some(new_tcp.into()),
                conn_addrs,
            }
        });
        assert!(self.conn_map.insert(conn_addrs, key).is_none());
        self.sync_child(key).unwrap();
        key
    }
    /// Update the listener's bookkeeping to match the child's current state.
    ///
    /// Membership in `to_send` follows the child's `wants_to_send()`. Membership in
    /// `accept_queue` follows whether the child is Established or CloseWait. A child whose
    /// `poll()` reports `CLOSED` is removed. Returns `Err(ChildNotFound)` for an unknown key.
    fn sync_child(&mut self, key: ChildTcpKey) -> Result<(), ChildNotFound> {
        let is_closed;
        // Scope the borrow of the entry so that `remove_child` can borrow `self` below.
        {
            let entry = self.children.get_mut(key).ok_or(ChildNotFound)?;
            let child = &mut entry.state;
            let conn_addrs = &entry.conn_addrs;
            if child.as_ref().unwrap().wants_to_send() {
                if !self.to_send.contains(&key) {
                    self.to_send.push_back(key);
                }
            } else {
                remove_from_list(&mut self.to_send, &key);
            }
            if matches!(
                child.as_ref().unwrap(),
                TcpStateEnum::Established(_) | TcpStateEnum::CloseWait(_)
            ) {
                if !self.accept_queue.contains(&key) {
                    self.accept_queue.push_back(key);
                }
            } else {
                remove_from_list(&mut self.accept_queue, &key);
            }
            assert!(self.conn_map.contains_key(conn_addrs));
            debug_assert_eq!(self.conn_map.get(conn_addrs).unwrap(), &key);
            is_closed = child.as_ref().unwrap().poll().contains(PollState::CLOSED);
        }
        if is_closed {
            self.remove_child(key).unwrap();
        }
        Ok(())
    }
    /// Remove a child from all listener bookkeeping and return its state, or `None` if the
    /// key is unknown.
    ///
    /// # Panics
    ///
    /// Panics if `conn_map` does not map the child's addresses to `key`.
    fn remove_child(&mut self, key: ChildTcpKey) -> Option<TcpStateEnum<X>> {
        let entry = self.children.remove(key)?;
        let child = entry.state.unwrap();
        let conn_addrs = entry.conn_addrs;
        remove_from_list(&mut self.accept_queue, &key);
        remove_from_list(&mut self.to_send, &key);
        assert_eq!(self.conn_map.remove(&conn_addrs), Some(key));
        Some(child)
    }
    /// Borrow a child's state, if the child exists.
    fn child(&self, key: ChildTcpKey) -> Option<&TcpStateEnum<X>> {
        self.children.get(key)?.state.as_ref()
    }
    /// Run `f` on a child's state by value, store the state `f` returns, then re-sync the
    /// child's bookkeeping (which may remove it if it closed).
    ///
    /// Returns `f`'s extra result, or `Err(ChildNotFound)` if the key is unknown.
    fn with_child<T>(
        &mut self,
        key: ChildTcpKey,
        f: impl FnOnce(TcpStateEnum<X>) -> (TcpStateEnum<X>, T),
    ) -> Result<T, ChildNotFound> {
        let rv;
        {
            let child = &mut self.children.get_mut(key).ok_or(ChildNotFound)?.state;
            // The slot holds `None` only while `f` runs.
            let mut state = child.take().unwrap();
            (state, rv) = f(state);
            *child = Some(state);
        }
        self.sync_child(key).unwrap();
        Ok(rv)
    }
}
impl<X: Dependencies> TcpStateTrait<X> for ListenState<X> {
    /// Closing a listener is the same as resetting it: every child is reset.
    fn close(self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
        let (new_state, rv) = self.rst_close();
        assert!(rv.is_ok());
        (new_state, Ok(()))
    }
    /// Reset every child connection and keep only the RST packets they and the listener have
    /// queued.
    ///
    /// Moves to `RstState` if RSTs remain to be sent, otherwise to `ClosedState`.
    fn rst_close(mut self) -> (TcpStateEnum<X>, Result<(), RstCloseError>) {
        let child_keys = Vec::from_iter(self.children.keys());
        for key in child_keys {
            self.with_child(key, |child| child.rst_close())
                .unwrap()
                .unwrap();
            // Drain the child's outgoing packets. The loop stops when the child has nothing
            // left to send, or once `with_child`'s sync has removed it after it closed.
            while let Ok(Ok((header, payload))) = self.with_child(key, |child| child.pop_packet()) {
                assert!(payload.is_empty());
                self.send_buffer.push_back(header);
            }
        }
        // Every child is expected to have closed and been removed by `sync_child` by now.
        assert!(self.children.is_empty());
        let rst_packets: LinkedList<_> = self
            .send_buffer
            .into_iter()
            .filter(|header| header.flags.contains(TcpFlags::RST))
            .collect();
        let new_state = if rst_packets.is_empty() {
            ClosedState::new(self.common, None, false).into()
        } else {
            RstState::new(self.common, rst_packets, false).into()
        };
        (new_state, Ok(()))
    }
    fn shutdown(self, _how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
        (self.into(), Err(ShutdownError::NotConnected))
    }
    /// Re-listen on an already listening socket, which only updates the backlog limit. The
    /// limit is one more than `backlog` (saturating).
    fn listen<T, E>(
        mut self,
        backlog: u32,
        associate_fn: impl FnOnce() -> Result<T, E>,
    ) -> (TcpStateEnum<X>, Result<T, ListenError<E>>) {
        let rv = match associate_fn() {
            Ok(x) => x,
            Err(e) => return (self.into(), Err(ListenError::FailedAssociation(e))),
        };
        let max_backlog = backlog.saturating_add(1);
        self.max_backlog = max_backlog;
        (self.into(), Ok(rv))
    }
    fn connect<T, E>(
        self,
        _remote_addr: SocketAddrV4,
        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
        (self.into(), Err(ConnectError::IsListening))
    }
    /// Hand the oldest acceptable child to the caller and stop tracking it.
    ///
    /// # Panics
    ///
    /// Panics if the dequeued child cannot be converted into an `AcceptedTcpState`.
    fn accept(mut self) -> (TcpStateEnum<X>, Result<AcceptedTcpState<X>, AcceptError>) {
        let Some(child_key) = self.accept_queue.pop_front() else {
            return (self.into(), Err(AcceptError::NothingToAccept));
        };
        let child = self.remove_child(child_key).unwrap();
        let accepted_state = match child.try_into() {
            Ok(x) => x,
            Err(child) => {
                panic!("Unexpected child TCP state in accept queue: {:?}", child);
            }
        };
        (self.into(), Ok(accepted_state))
    }
    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
        (self.into(), Err(SendError::NotConnected))
    }
    fn recv(self, _writer: impl Write, _len: usize) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
        (self.into(), Err(RecvError::NotConnected))
    }
    /// Route an incoming packet to the matching child, or start a new child on a SYN.
    ///
    /// Packets are silently dropped (returning `Ok(0)`) when:
    /// - it is a SYN and either the accept queue or the SYN queue is full;
    /// - it is an ACK for a SYN-RECEIVED child while the accept queue is full;
    /// - it matches no child and is not a plain SYN (no SYN, or SYN together with RST).
    fn push_packet(
        mut self,
        header: &TcpHeader,
        payload: Payload,
    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
        let max_backlog = self.max_backlog.try_into().unwrap();
        // Children that are not yet acceptable form the "SYN queue".
        let syn_queue_len = self
            .children
            .len()
            .checked_sub(self.accept_queue.len())
            .unwrap();
        let accept_queue_full = self.accept_queue.len() >= max_backlog;
        let syn_queue_full = syn_queue_len >= max_backlog;
        if header.flags.contains(TcpFlags::SYN) && (accept_queue_full || syn_queue_full) {
            return (self.into(), Ok(0));
        }
        let conn_addrs = RemoteLocalPair::new(header.src(), header.dst());
        if let Some(child_key) = self.conn_map.get(&conn_addrs) {
            // Don't let a child complete its handshake if it could not be queued for accept.
            if matches!(self.child(*child_key), Some(TcpStateEnum::SynReceived(_)))
                && header.flags.contains(TcpFlags::ACK)
                && accept_queue_full
            {
                return (self.into(), Ok(0));
            }
            let rv = self
                .with_child(*child_key, |state| state.push_packet(header, payload))
                .unwrap();
            return (self.into(), rv);
        }
        // Only a SYN may open a new child. An RST arriving in LISTEN must be ignored
        // (RFC 793, section 3.9), and `register_child` asserts it never sees one, so a SYN|RST
        // from the network is dropped here rather than panicking.
        if !header.flags.contains(TcpFlags::SYN) || header.flags.contains(TcpFlags::RST) {
            return (self.into(), Ok(0));
        }
        self.register_child(header, payload);
        (self.into(), Ok(0))
    }
    /// Pop the listener's own queued headers first, then packets from children that want to
    /// send, in the order they were queued.
    fn pop_packet(
        mut self,
    ) -> (
        TcpStateEnum<X>,
        Result<(TcpHeader, Payload), PopPacketError>,
    ) {
        if let Some(header) = self.send_buffer.pop_front() {
            return (self.into(), Ok((header, Payload::default())));
        }
        if let Some(child_key) = self.to_send.pop_front() {
            // `with_child` re-syncs the child, re-queuing it if it still wants to send.
            let rv = self
                .with_child(child_key, |state| state.pop_packet())
                .unwrap();
            let (header, payload) = rv.unwrap();
            debug_assert!(payload.is_empty());
            return (self.into(), Ok((header, payload)));
        }
        (self.into(), Err(PopPacketError::NoPacket))
    }
    fn clear_error(&mut self) -> Option<TcpError> {
        self.common.error.take()
    }
    /// Always `LISTENING`; also `READY_TO_ACCEPT` when a child is queued, and `ERROR` when an
    /// error is pending.
    fn poll(&self) -> PollState {
        let mut poll_state = PollState::LISTENING;
        if !self.accept_queue.is_empty() {
            poll_state.insert(PollState::READY_TO_ACCEPT);
        }
        if self.common.error.is_some() {
            poll_state.insert(PollState::ERROR);
        }
        poll_state
    }
    fn wants_to_send(&self) -> bool {
        !self.send_buffer.is_empty() || !self.to_send.is_empty()
    }
    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
        None
    }
}
impl<X: Dependencies> SynSentState<X> {
fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
let state = SynSentState { common, connection };
let timeout = state.common.current_time() + X::Duration::from_secs(60);
state.common.register_timer(timeout, |state| {
if let TcpStateEnum::SynSent(mut state) = state {
state.common.error = Some(TcpError::TimedOut);
let (state, rv) = state.rst_close();
assert!(rv.is_ok());
state
} else {
state
}
});
state
}
}
impl<X: Dependencies> TcpStateTrait<X> for SynSentState<X> {
    /// Abandon the connection attempt and close immediately.
    ///
    /// Records `TcpError::ClosedWhileConnecting` unless an error (such as the `TimedOut` set
    /// by the connect timer) is already pending.
    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
        // `push_packet` leaves SYN-SENT as soon as the peer's SYN arrives, so the receive
        // buffer is expected to be empty here.
        debug_assert!(!self.connection.recv_buf_has_data());
        self.common
            .set_error_if_unset(TcpError::ClosedWhileConnecting);
        let new_state = ClosedState::new(self.common, None, true);
        (new_state.into(), Ok(()))
    }
    /// Identical to `close` in this state: the socket closes without queueing any packet here.
    fn rst_close(mut self) -> (TcpStateEnum<X>, Result<(), RstCloseError>) {
        debug_assert!(!self.connection.recv_buf_has_data());
        self.common
            .set_error_if_unset(TcpError::ClosedWhileConnecting);
        let new_state = ClosedState::new(self.common, None, true);
        (new_state.into(), Ok(()))
    }
    /// A read shutdown arranges for an RST if payload later arrives. A write shutdown
    /// abandons the connection attempt exactly as `close` does.
    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
        if how == Shutdown::Read || how == Shutdown::Both {
            self.connection.send_rst_if_recv_payload()
        }
        if how == Shutdown::Write || how == Shutdown::Both {
            debug_assert!(!self.connection.recv_buf_has_data());
            self.common
                .set_error_if_unset(TcpError::ClosedWhileConnecting);
            let new_state = ClosedState::new(self.common, None, true);
            return (new_state.into(), Ok(()));
        }
        (self.into(), Ok(()))
    }
    fn connect<T, E>(
        self,
        _remote_addr: SocketAddrV4,
        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
        (self.into(), Err(ConnectError::InProgress))
    }
    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
        (self.into(), Err(SendError::NotConnected))
    }
    fn recv(self, _writer: impl Write, _len: usize) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
        (self.into(), Err(RecvError::NotConnected))
    }
    /// Feed a packet to the connection and advance the handshake.
    ///
    /// Packets for other addresses are ignored (`Ok(0)`). Transitions are checked in order:
    /// reset, then SYN received and our SYN acked (Established), then SYN received only
    /// (simultaneous open, SynReceived).
    fn push_packet(
        mut self,
        header: &TcpHeader,
        payload: Payload,
    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
        if !self.connection.packet_addrs_match(header) {
            return (self.into(), Ok(0));
        }
        let pushed_len = match self.connection.push_packet(header, payload) {
            Ok(v) => v,
            Err(e) => return (self.into(), Err(e)),
        };
        if self.connection.is_reset() {
            // Only blame the peer when this packet actually carried the RST.
            if header.flags.contains(TcpFlags::RST) {
                self.common.set_error_if_unset(TcpError::ResetReceived);
            }
            let new_state = connection_was_reset(self.common, self.connection);
            return (new_state, Ok(pushed_len));
        }
        if self.connection.received_syn() && self.connection.syn_was_acked() {
            let new_state = EstablishedState::new(self.common, self.connection);
            return (new_state.into(), Ok(pushed_len));
        }
        if self.connection.received_syn() {
            let new_state = SynReceivedState::new(self.common, self.connection);
            return (new_state.into(), Ok(pushed_len));
        }
        (self.into(), Ok(pushed_len))
    }
    /// Pop the connection's next outgoing packet, as of the current time.
    fn pop_packet(
        mut self,
    ) -> (
        TcpStateEnum<X>,
        Result<(TcpHeader, Payload), PopPacketError>,
    ) {
        let rv = self.connection.pop_packet(self.common.current_time());
        (self.into(), rv)
    }
    fn clear_error(&mut self) -> Option<TcpError> {
        self.common.error.take()
    }
    /// Always `CONNECTING`, plus `ERROR` when an error is pending.
    fn poll(&self) -> PollState {
        let mut poll_state = PollState::CONNECTING;
        if self.common.error.is_some() {
            poll_state.insert(PollState::ERROR);
        }
        poll_state
    }
    fn wants_to_send(&self) -> bool {
        self.connection.wants_to_send()
    }
    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
        Some((self.connection.local_addr, self.connection.remote_addr))
    }
}
impl<X: Dependencies> SynReceivedState<X> {
fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
let state = SynReceivedState { common, connection };
let timeout = state.common.current_time() + X::Duration::from_secs(60);
state.common.register_timer(timeout, |state| {
if let TcpStateEnum::SynReceived(mut state) = state {
state.common.error = Some(TcpError::TimedOut);
let (state, rv) = state.rst_close();
assert!(rv.is_ok());
return state;
}
state
});
state
}
}
impl<X: Dependencies> TcpStateTrait<X> for SynReceivedState<X> {
    /// Close during the handshake.
    ///
    /// If unread data is buffered, the connection is torn down via `reset_connection`.
    /// Otherwise a FIN is queued, `ClosedWhileConnecting` is recorded (unless an error is
    /// already pending), and the socket moves to FIN-WAIT-1.
    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
        let new_state = if self.connection.recv_buf_has_data() {
            reset_connection(self.common, self.connection).into()
        } else {
            self.connection.send_fin();
            self.common
                .set_error_if_unset(TcpError::ClosedWhileConnecting);
            // Any payload arriving after close should be answered with an RST.
            self.connection.send_rst_if_recv_payload();
            FinWaitOneState::new(self.common, self.connection).into()
        };
        (new_state, Ok(()))
    }
    /// Tear the connection down via `reset_connection`.
    fn rst_close(self) -> (TcpStateEnum<X>, Result<(), RstCloseError>) {
        let new_state = reset_connection(self.common, self.connection);
        (new_state.into(), Ok(()))
    }
    /// A read shutdown arranges for an RST if payload later arrives. A write shutdown queues
    /// a FIN, records `ClosedWhileConnecting` if no error is pending, and moves to FIN-WAIT-1.
    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
        if how == Shutdown::Read || how == Shutdown::Both {
            self.connection.send_rst_if_recv_payload()
        }
        if how == Shutdown::Write || how == Shutdown::Both {
            self.connection.send_fin();
            self.common
                .set_error_if_unset(TcpError::ClosedWhileConnecting);
            let new_state = FinWaitOneState::new(self.common, self.connection);
            return (new_state.into(), Ok(()));
        }
        (self.into(), Ok(()))
    }
    fn connect<T, E>(
        self,
        _remote_addr: SocketAddrV4,
        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
        (self.into(), Err(ConnectError::InProgress))
    }
    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
        (self.into(), Err(SendError::NotConnected))
    }
    fn recv(self, _writer: impl Write, _len: usize) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
        (self.into(), Err(RecvError::NotConnected))
    }
    /// Feed a packet to the connection. Packets for other addresses are ignored (`Ok(0)`).
    /// A reset takes priority; otherwise the state moves to Established once our SYN is acked.
    fn push_packet(
        mut self,
        header: &TcpHeader,
        payload: Payload,
    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
        if !self.connection.packet_addrs_match(header) {
            return (self.into(), Ok(0));
        }
        let pushed_len = match self.connection.push_packet(header, payload) {
            Ok(v) => v,
            Err(e) => return (self.into(), Err(e)),
        };
        if self.connection.is_reset() {
            // Only blame the peer when this packet actually carried the RST.
            if header.flags.contains(TcpFlags::RST) {
                self.common.set_error_if_unset(TcpError::ResetReceived);
            }
            let new_state = connection_was_reset(self.common, self.connection);
            return (new_state, Ok(pushed_len));
        }
        if self.connection.syn_was_acked() {
            let new_state = EstablishedState::new(self.common, self.connection);
            return (new_state.into(), Ok(pushed_len));
        }
        (self.into(), Ok(pushed_len))
    }
    /// Pop the connection's next outgoing packet, as of the current time.
    fn pop_packet(
        mut self,
    ) -> (
        TcpStateEnum<X>,
        Result<(TcpHeader, Payload), PopPacketError>,
    ) {
        let rv = self.connection.pop_packet(self.common.current_time());
        (self.into(), rv)
    }
    fn clear_error(&mut self) -> Option<TcpError> {
        self.common.error.take()
    }
    /// Always `CONNECTING`, plus `ERROR` when an error is pending.
    fn poll(&self) -> PollState {
        let mut poll_state = PollState::CONNECTING;
        if self.common.error.is_some() {
            poll_state.insert(PollState::ERROR);
        }
        poll_state
    }
    fn wants_to_send(&self) -> bool {
        self.connection.wants_to_send()
    }
    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
        Some((self.connection.local_addr, self.connection.remote_addr))
    }
}
impl<X: Dependencies> EstablishedState<X> {
    /// Wrap a connection whose handshake is complete. Unlike the handshake states, this
    /// constructor arms no timer.
    fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
        Self { common, connection }
    }
}
impl<X: Dependencies> TcpStateTrait<X> for EstablishedState<X> {
    /// Begin an orderly close with a FIN, unless unread data is buffered. In that case the
    /// connection is handed to `reset_connection` instead.
    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
        if self.connection.recv_buf_has_data() {
            return (reset_connection(self.common, self.connection).into(), Ok(()));
        }
        self.connection.send_fin();
        // Payload that arrives after close should be answered with an RST.
        self.connection.send_rst_if_recv_payload();
        (FinWaitOneState::new(self.common, self.connection).into(), Ok(()))
    }
    /// Tear the connection down via `reset_connection`.
    fn rst_close(self) -> (TcpStateEnum<X>, Result<(), RstCloseError>) {
        let reset = reset_connection(self.common, self.connection);
        (reset.into(), Ok(()))
    }
    /// Shutting down reads arranges for an RST on later payload. Shutting down writes queues
    /// a FIN and moves to FIN-WAIT-1.
    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
        let closes_read = how == Shutdown::Read || how == Shutdown::Both;
        let closes_write = how == Shutdown::Write || how == Shutdown::Both;
        if closes_read {
            self.connection.send_rst_if_recv_payload();
        }
        if !closes_write {
            return (self.into(), Ok(()));
        }
        self.connection.send_fin();
        (FinWaitOneState::new(self.common, self.connection).into(), Ok(()))
    }
    fn connect<T, E>(
        self,
        _remote_addr: SocketAddrV4,
        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
        (self.into(), Err(ConnectError::AlreadyConnected))
    }
    /// Queue up to `len` bytes from `reader` on the connection.
    fn send(
        mut self,
        reader: impl Read,
        len: usize,
    ) -> (TcpStateEnum<X>, Result<usize, SendError>) {
        let outcome = self.connection.send(reader, len);
        (self.into(), outcome)
    }
    /// Copy up to `len` received bytes into `writer`.
    fn recv(
        mut self,
        writer: impl Write,
        len: usize,
    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
        let outcome = self.connection.recv(writer, len);
        (self.into(), outcome)
    }
    /// Feed a packet to the connection. Packets for other addresses are ignored (`Ok(0)`).
    /// A reset takes priority; a FIN from the peer moves the socket to CLOSE-WAIT.
    fn push_packet(
        mut self,
        header: &TcpHeader,
        payload: Payload,
    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
        if !self.connection.packet_addrs_match(header) {
            return (self.into(), Ok(0));
        }
        let accepted = match self.connection.push_packet(header, payload) {
            Ok(len) => len,
            Err(e) => return (self.into(), Err(e)),
        };
        if self.connection.is_reset() {
            // Only blame the peer when this packet actually carried the RST.
            if header.flags.contains(TcpFlags::RST) {
                self.common.set_error_if_unset(TcpError::ResetReceived);
            }
            return (
                connection_was_reset(self.common, self.connection),
                Ok(accepted),
            );
        }
        let next: TcpStateEnum<X> = if self.connection.received_fin() {
            CloseWaitState::new(self.common, self.connection).into()
        } else {
            self.into()
        };
        (next, Ok(accepted))
    }
    /// Pop the connection's next outgoing packet, as of the current time.
    fn pop_packet(
        mut self,
    ) -> (
        TcpStateEnum<X>,
        Result<(TcpHeader, Payload), PopPacketError>,
    ) {
        let now = self.common.current_time();
        let popped = self.connection.pop_packet(now);
        (self.into(), popped)
    }
    fn clear_error(&mut self) -> Option<TcpError> {
        self.common.error.take()
    }
    /// `CONNECTED`, plus `WRITABLE`, `READABLE` and `ERROR` as applicable.
    fn poll(&self) -> PollState {
        let mut flags = PollState::CONNECTED;
        if self.connection.send_buf_has_space() {
            flags.insert(PollState::WRITABLE);
        }
        if self.connection.recv_buf_has_data() {
            flags.insert(PollState::READABLE);
        }
        if self.common.error.is_some() {
            flags.insert(PollState::ERROR);
        }
        flags
    }
    fn wants_to_send(&self) -> bool {
        self.connection.wants_to_send()
    }
    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
        let conn = &self.connection;
        Some((conn.local_addr, conn.remote_addr))
    }
}
impl<X: Dependencies> FinWaitOneState<X> {
    /// Wrap a connection whose FIN has just been queued. No timer is armed.
    fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
        Self { common, connection }
    }
}
impl<X: Dependencies> TcpStateTrait<X> for FinWaitOneState<X> {
    /// Our FIN is already queued. If unread data is buffered, the connection is torn down via
    /// `reset_connection`. Otherwise the socket stays in FIN-WAIT-1 and arranges for an RST if
    /// more payload arrives.
    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
        let new_state = if self.connection.recv_buf_has_data() {
            reset_connection(self.common, self.connection).into()
        } else {
            self.connection.send_rst_if_recv_payload();
            self.into()
        };
        (new_state, Ok(()))
    }
    /// The write side is already shut down (FIN queued), so only a read shutdown has an
    /// effect.
    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
        if how == Shutdown::Read || how == Shutdown::Both {
            self.connection.send_rst_if_recv_payload()
        }
        // Nothing to do: the FIN has already been queued.
        if how == Shutdown::Write || how == Shutdown::Both {
        }
        (self.into(), Ok(()))
    }
    fn connect<T, E>(
        self,
        _remote_addr: SocketAddrV4,
        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
        (self.into(), Err(ConnectError::AlreadyConnected))
    }
    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
        (self.into(), Err(SendError::StreamClosed))
    }
    /// Reading is still allowed: the peer has not sent its FIN yet.
    fn recv(
        mut self,
        writer: impl Write,
        len: usize,
    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
        let rv = self.connection.recv(writer, len);
        (self.into(), rv)
    }
    /// Feed a packet to the connection. Packets for other addresses are ignored (`Ok(0)`).
    ///
    /// Transitions, checked in order: reset; peer FIN received and our FIN acked (TIME-WAIT);
    /// peer FIN only (CLOSING); our FIN acked only (FIN-WAIT-2).
    fn push_packet(
        mut self,
        header: &TcpHeader,
        payload: Payload,
    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
        if !self.connection.packet_addrs_match(header) {
            return (self.into(), Ok(0));
        }
        let pushed_len = match self.connection.push_packet(header, payload) {
            Ok(v) => v,
            Err(e) => return (self.into(), Err(e)),
        };
        if self.connection.is_reset() {
            // Only blame the peer when this packet actually carried the RST.
            if header.flags.contains(TcpFlags::RST) {
                self.common.set_error_if_unset(TcpError::ResetReceived);
            }
            let new_state = connection_was_reset(self.common, self.connection);
            return (new_state, Ok(pushed_len));
        }
        if self.connection.received_fin() && self.connection.fin_was_acked() {
            let new_state = TimeWaitState::new(self.common, self.connection);
            return (new_state.into(), Ok(pushed_len));
        }
        if self.connection.received_fin() {
            let new_state = ClosingState::new(self.common, self.connection);
            return (new_state.into(), Ok(pushed_len));
        }
        if self.connection.fin_was_acked() {
            let new_state = FinWaitTwoState::new(self.common, self.connection);
            return (new_state.into(), Ok(pushed_len));
        }
        (self.into(), Ok(pushed_len))
    }
    /// Pop the connection's next outgoing packet, as of the current time.
    fn pop_packet(
        mut self,
    ) -> (
        TcpStateEnum<X>,
        Result<(TcpHeader, Payload), PopPacketError>,
    ) {
        let rv = self.connection.pop_packet(self.common.current_time());
        (self.into(), rv)
    }
    fn clear_error(&mut self) -> Option<TcpError> {
        self.common.error.take()
    }
    /// `CONNECTED | SEND_CLOSED`, plus `READABLE` and `ERROR` as applicable. Never
    /// `WRITABLE`.
    fn poll(&self) -> PollState {
        let mut poll_state = PollState::CONNECTED;
        if self.connection.recv_buf_has_data() {
            poll_state.insert(PollState::READABLE);
        }
        poll_state.insert(PollState::SEND_CLOSED);
        assert!(!poll_state.contains(PollState::WRITABLE));
        if self.common.error.is_some() {
            poll_state.insert(PollState::ERROR);
        }
        poll_state
    }
    fn wants_to_send(&self) -> bool {
        self.connection.wants_to_send()
    }
    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
        Some((self.connection.local_addr, self.connection.remote_addr))
    }
}
impl<X: Dependencies> FinWaitTwoState<X> {
    /// Wrap a connection whose FIN has been acknowledged by the peer. No timer is armed.
    fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
        Self { common, connection }
    }
}
impl<X: Dependencies> TcpStateTrait<X> for FinWaitTwoState<X> {
fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
let new_state = if self.connection.recv_buf_has_data() {
reset_connection(self.common, self.connection).into()
} else {
self.connection.send_rst_if_recv_payload();
self.into()
};
(new_state, Ok(()))
}
fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
if how == Shutdown::Read || how == Shutdown::Both {
self.connection.send_rst_if_recv_payload()
}
if how == Shutdown::Write || how == Shutdown::Both {
}
(self.into(), Ok(()))
}
fn connect<T, E>(
self,
_remote_addr: SocketAddrV4,
_associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
(self.into(), Err(ConnectError::AlreadyConnected))
}
fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
(self.into(), Err(SendError::StreamClosed))
}
fn recv(
mut self,
writer: impl Write,
len: usize,
) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
let rv = self.connection.recv(writer, len);
(self.into(), rv)
}
fn push_packet(
mut self,
header: &TcpHeader,
payload: Payload,
) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
if !self.connection.packet_addrs_match(header) {
return (self.into(), Ok(0));
}
let pushed_len = match self.connection.push_packet(header, payload) {
Ok(v) => v,
Err(e) => return (self.into(), Err(e)),
};
if self.connection.is_reset() {
if header.flags.contains(TcpFlags::RST) {
self.common.set_error_if_unset(TcpError::ResetReceived);
}
let new_state = connection_was_reset(self.common, self.connection);
return (new_state, Ok(pushed_len));
}
if self.connection.received_fin() {
let new_state = TimeWaitState::new(self.common, self.connection);
return (new_state.into(), Ok(pushed_len));
}
(self.into(), Ok(pushed_len))
}
fn pop_packet(
mut self,
) -> (
TcpStateEnum<X>,
Result<(TcpHeader, Payload), PopPacketError>,
) {
let rv = self.connection.pop_packet(self.common.current_time());
(self.into(), rv)
}
fn clear_error(&mut self) -> Option<TcpError> {
self.common.error.take()
}
fn poll(&self) -> PollState {
let mut poll_state = PollState::CONNECTED;
if self.connection.recv_buf_has_data() {
poll_state.insert(PollState::READABLE);
}
poll_state.insert(PollState::SEND_CLOSED);
assert!(!poll_state.contains(PollState::WRITABLE));
if self.common.error.is_some() {
poll_state.insert(PollState::ERROR);
}
poll_state
}
fn wants_to_send(&self) -> bool {
self.connection.wants_to_send()
}
fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
Some((self.connection.local_addr, self.connection.remote_addr))
}
}
impl<X: Dependencies> ClosingState<X> {
    /// Wrap a connection in which both FINs have been seen but ours is not yet acknowledged.
    /// No timer is armed.
    fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
        Self { common, connection }
    }
}
impl<X: Dependencies> TcpStateTrait<X> for ClosingState<X> {
    /// If unread data is buffered, the connection is torn down via `reset_connection`.
    /// Otherwise the socket stays in CLOSING and arranges for an RST if more payload arrives.
    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
        let new_state = if self.connection.recv_buf_has_data() {
            reset_connection(self.common, self.connection).into()
        } else {
            self.connection.send_rst_if_recv_payload();
            self.into()
        };
        (new_state, Ok(()))
    }
    /// Our FIN is already queued, so only a read shutdown has an effect.
    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
        if how == Shutdown::Read || how == Shutdown::Both {
            self.connection.send_rst_if_recv_payload()
        }
        // Nothing to do: the FIN has already been queued.
        if how == Shutdown::Write || how == Shutdown::Both {
        }
        (self.into(), Ok(()))
    }
    fn connect<T, E>(
        self,
        _remote_addr: SocketAddrV4,
        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
        (self.into(), Err(ConnectError::AlreadyConnected))
    }
    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
        (self.into(), Err(SendError::StreamClosed))
    }
    /// Read remaining buffered data. The peer's FIN has been received, so an empty buffer is
    /// reported as `StreamClosed` rather than `Empty`.
    fn recv(
        mut self,
        writer: impl Write,
        len: usize,
    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
        let rv = self.connection.recv(writer, len);
        if matches!(rv, Err(RecvError::Empty)) {
            return (self.into(), Err(RecvError::StreamClosed));
        }
        (self.into(), rv)
    }
    /// Feed a packet to the connection. Packets for other addresses are ignored (`Ok(0)`).
    /// A reset takes priority; acknowledgement of our FIN moves the socket to TIME-WAIT.
    fn push_packet(
        mut self,
        header: &TcpHeader,
        payload: Payload,
    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
        if !self.connection.packet_addrs_match(header) {
            return (self.into(), Ok(0));
        }
        let pushed_len = match self.connection.push_packet(header, payload) {
            Ok(v) => v,
            Err(e) => return (self.into(), Err(e)),
        };
        if self.connection.is_reset() {
            // Only blame the peer when this packet actually carried the RST.
            if header.flags.contains(TcpFlags::RST) {
                self.common.set_error_if_unset(TcpError::ResetReceived);
            }
            let new_state = connection_was_reset(self.common, self.connection);
            return (new_state, Ok(pushed_len));
        }
        if self.connection.fin_was_acked() {
            let new_state = TimeWaitState::new(self.common, self.connection);
            return (new_state.into(), Ok(pushed_len));
        }
        (self.into(), Ok(pushed_len))
    }
    /// Pop the connection's next outgoing packet, as of the current time.
    fn pop_packet(
        mut self,
    ) -> (
        TcpStateEnum<X>,
        Result<(TcpHeader, Payload), PopPacketError>,
    ) {
        let rv = self.connection.pop_packet(self.common.current_time());
        (self.into(), rv)
    }
    fn clear_error(&mut self) -> Option<TcpError> {
        self.common.error.take()
    }
    /// `CONNECTED | RECV_CLOSED | SEND_CLOSED`, plus `READABLE` and `ERROR` as applicable.
    /// Never `WRITABLE`.
    fn poll(&self) -> PollState {
        let mut poll_state = PollState::CONNECTED;
        poll_state.insert(PollState::RECV_CLOSED);
        if self.connection.recv_buf_has_data() {
            poll_state.insert(PollState::READABLE);
        }
        poll_state.insert(PollState::SEND_CLOSED);
        assert!(!poll_state.contains(PollState::WRITABLE));
        if self.common.error.is_some() {
            poll_state.insert(PollState::ERROR);
        }
        poll_state
    }
    fn wants_to_send(&self) -> bool {
        self.connection.wants_to_send()
    }
    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
        Some((self.connection.local_addr, self.connection.remote_addr))
    }
}
impl<X: Dependencies> TimeWaitState<X> {
    /// Enters TIME-WAIT and schedules the eventual move to CLOSED.
    ///
    /// A timer is registered for 60 seconds after `common.current_time()`. When it fires, the
    /// state is moved to `ClosedState` (keeping the connection's receive buffer) only if it is
    /// still in TIME-WAIT; any other state is returned unchanged.
    fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
        let expires_at = common.current_time() + X::Duration::from_secs(60);
        common.register_timer(expires_at, |state| match state {
            TcpStateEnum::TimeWait(state) => {
                let recv_buffer = state.connection.into_recv_buffer();
                ClosedState::new(state.common, recv_buffer, true).into()
            }
            // The socket already left TIME-WAIT (e.g. it was reset); nothing to do.
            other => other,
        });
        TimeWaitState { common, connection }
    }
}
impl<X: Dependencies> TcpStateTrait<X> for TimeWaitState<X> {
    /// User-initiated close.
    ///
    /// Unlike CLOSING/LAST-ACK, this never resets. It only calls the connection's
    /// `send_rst_if_recv_payload`, and the socket waits for its timer.
    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
        self.connection.send_rst_if_recv_payload();
        (self.into(), Ok(()))
    }

    /// Shuts down one or both directions.
    ///
    /// The send direction is already closed here, so only a read shutdown does anything.
    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
        let closes_read = how == Shutdown::Read || how == Shutdown::Both;
        if closes_read {
            self.connection.send_rst_if_recv_payload();
        }
        (self.into(), Ok(()))
    }

    /// Always fails: this socket already has a peer.
    fn connect<T, E>(
        self,
        _remote_addr: SocketAddrV4,
        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
        let err = ConnectError::AlreadyConnected;
        (self.into(), Err(err))
    }

    /// Always fails: the send direction is closed in this state.
    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
        let closed = Err(SendError::StreamClosed);
        (self.into(), closed)
    }

    /// Forwards to `Connection::recv`, reporting an empty buffer as end of stream.
    fn recv(
        mut self,
        writer: impl Write,
        len: usize,
    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
        let outcome = self.connection.recv(writer, len).map_err(|err| match err {
            RecvError::Empty => RecvError::StreamClosed,
            other => other,
        });
        (self.into(), outcome)
    }

    /// Feeds an incoming segment to the connection.
    ///
    /// Segments whose addresses don't match are ignored with `Ok(0)`. If the connection
    /// becomes reset, the socket leaves through `connection_was_reset`, recording
    /// `ResetReceived` when the segment carried RST.
    fn push_packet(
        mut self,
        header: &TcpHeader,
        payload: Payload,
    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
        if !self.connection.packet_addrs_match(header) {
            return (self.into(), Ok(0));
        }

        let accepted = match self.connection.push_packet(header, payload) {
            Ok(n) => n,
            Err(err) => return (self.into(), Err(err)),
        };

        let next: TcpStateEnum<X> = if self.connection.is_reset() {
            if header.flags.contains(TcpFlags::RST) {
                self.common.set_error_if_unset(TcpError::ResetReceived);
            }
            connection_was_reset(self.common, self.connection)
        } else {
            self.into()
        };
        (next, Ok(accepted))
    }

    /// Pops the next outgoing segment from the connection, stamped with the current time.
    fn pop_packet(
        mut self,
    ) -> (
        TcpStateEnum<X>,
        Result<(TcpHeader, Payload), PopPacketError>,
    ) {
        let now = self.common.current_time();
        let popped = self.connection.pop_packet(now);
        (self.into(), popped)
    }

    /// Returns and clears any pending socket error.
    fn clear_error(&mut self) -> Option<TcpError> {
        Option::take(&mut self.common.error)
    }

    /// Both directions are closed. The socket is readable only while buffered data remains.
    fn poll(&self) -> PollState {
        let mut poll_state =
            PollState::CONNECTED | PollState::RECV_CLOSED | PollState::SEND_CLOSED;
        if self.connection.recv_buf_has_data() {
            poll_state.insert(PollState::READABLE);
        }
        assert!(!poll_state.contains(PollState::WRITABLE));
        if self.common.error.is_some() {
            poll_state.insert(PollState::ERROR);
        }
        poll_state
    }

    fn wants_to_send(&self) -> bool {
        let conn = &self.connection;
        conn.wants_to_send()
    }

    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
        let conn = &self.connection;
        Some((conn.local_addr, conn.remote_addr))
    }
}
impl<X: Dependencies> CloseWaitState<X> {
fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
Self { common, connection }
}
}
impl<X: Dependencies> TcpStateTrait<X> for CloseWaitState<X> {
    /// User-initiated close.
    ///
    /// Unread receive data aborts the connection through `reset_connection`. Otherwise a FIN
    /// is sent, `send_rst_if_recv_payload` is called, and the socket moves to LAST-ACK.
    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
        if self.connection.recv_buf_has_data() {
            let aborted = reset_connection(self.common, self.connection);
            return (aborted.into(), Ok(()));
        }
        self.connection.send_fin();
        self.connection.send_rst_if_recv_payload();
        (LastAckState::new(self.common, self.connection).into(), Ok(()))
    }

    /// Aborts the connection immediately through `reset_connection`.
    fn rst_close(self) -> (TcpStateEnum<X>, Result<(), RstCloseError>) {
        (reset_connection(self.common, self.connection).into(), Ok(()))
    }

    /// Shuts down one or both directions.
    ///
    /// A read shutdown calls `send_rst_if_recv_payload`. A write shutdown sends our FIN and
    /// moves the socket to LAST-ACK. `Both` does the read part first, then the write part.
    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
        let closes_read = how == Shutdown::Read || how == Shutdown::Both;
        let closes_write = how == Shutdown::Write || how == Shutdown::Both;

        if closes_read {
            self.connection.send_rst_if_recv_payload();
        }
        if !closes_write {
            return (self.into(), Ok(()));
        }
        self.connection.send_fin();
        (LastAckState::new(self.common, self.connection).into(), Ok(()))
    }

    /// Always fails: this socket already has a peer.
    fn connect<T, E>(
        self,
        _remote_addr: SocketAddrV4,
        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
        let err = ConnectError::AlreadyConnected;
        (self.into(), Err(err))
    }

    /// The send direction is still open here, so this forwards to `Connection::send`.
    fn send(
        mut self,
        reader: impl Read,
        len: usize,
    ) -> (TcpStateEnum<X>, Result<usize, SendError>) {
        let outcome = self.connection.send(reader, len);
        (self.into(), outcome)
    }

    /// Forwards to `Connection::recv`.
    ///
    /// The receive direction is closed (`poll` reports `RECV_CLOSED`), so an empty buffer is
    /// reported as `RecvError::StreamClosed`.
    fn recv(
        mut self,
        writer: impl Write,
        len: usize,
    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
        let outcome = self.connection.recv(writer, len).map_err(|err| match err {
            RecvError::Empty => RecvError::StreamClosed,
            other => other,
        });
        (self.into(), outcome)
    }

    /// Feeds an incoming segment to the connection.
    ///
    /// Segments whose addresses don't match are ignored with `Ok(0)`. If the connection
    /// becomes reset, the socket leaves through `connection_was_reset`, recording
    /// `ResetReceived` when the segment carried RST.
    fn push_packet(
        mut self,
        header: &TcpHeader,
        payload: Payload,
    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
        if !self.connection.packet_addrs_match(header) {
            return (self.into(), Ok(0));
        }

        let accepted = match self.connection.push_packet(header, payload) {
            Ok(n) => n,
            Err(err) => return (self.into(), Err(err)),
        };

        let next: TcpStateEnum<X> = if self.connection.is_reset() {
            if header.flags.contains(TcpFlags::RST) {
                self.common.set_error_if_unset(TcpError::ResetReceived);
            }
            connection_was_reset(self.common, self.connection)
        } else {
            self.into()
        };
        (next, Ok(accepted))
    }

    /// Pops the next outgoing segment from the connection, stamped with the current time.
    fn pop_packet(
        mut self,
    ) -> (
        TcpStateEnum<X>,
        Result<(TcpHeader, Payload), PopPacketError>,
    ) {
        let now = self.common.current_time();
        let popped = self.connection.pop_packet(now);
        (self.into(), popped)
    }

    /// Returns and clears any pending socket error.
    fn clear_error(&mut self) -> Option<TcpError> {
        Option::take(&mut self.common.error)
    }

    /// The receive direction is closed while the send direction stays open.
    fn poll(&self) -> PollState {
        let mut readiness = PollState::CONNECTED | PollState::RECV_CLOSED;
        if self.connection.send_buf_has_space() {
            readiness.insert(PollState::WRITABLE);
        }
        if self.connection.recv_buf_has_data() {
            readiness.insert(PollState::READABLE);
        }
        if self.common.error.is_some() {
            readiness.insert(PollState::ERROR);
        }
        readiness
    }

    fn wants_to_send(&self) -> bool {
        let conn = &self.connection;
        conn.wants_to_send()
    }

    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
        let conn = &self.connection;
        Some((conn.local_addr, conn.remote_addr))
    }
}
impl<X: Dependencies> LastAckState<X> {
fn new(common: Common<X>, connection: Connection<X::Instant>) -> Self {
Self { common, connection }
}
}
impl<X: Dependencies> TcpStateTrait<X> for LastAckState<X> {
    /// User-initiated close.
    ///
    /// Unread receive data aborts the connection through `reset_connection`. Otherwise
    /// `send_rst_if_recv_payload` is called and the socket keeps waiting for the ACK of our FIN.
    fn close(mut self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
        if self.connection.recv_buf_has_data() {
            let aborted = reset_connection(self.common, self.connection);
            return (aborted.into(), Ok(()));
        }
        self.connection.send_rst_if_recv_payload();
        (self.into(), Ok(()))
    }

    /// Shuts down one or both directions.
    ///
    /// Our FIN has already been sent in this state, so only a read shutdown does anything.
    fn shutdown(mut self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
        let closes_read = how == Shutdown::Read || how == Shutdown::Both;
        if closes_read {
            self.connection.send_rst_if_recv_payload();
        }
        (self.into(), Ok(()))
    }

    /// Always fails: this socket already has a peer.
    fn connect<T, E>(
        self,
        _remote_addr: SocketAddrV4,
        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
        let err = ConnectError::AlreadyConnected;
        (self.into(), Err(err))
    }

    /// Always fails: the send direction is closed in this state.
    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
        let closed = Err(SendError::StreamClosed);
        (self.into(), closed)
    }

    /// Forwards to `Connection::recv`, reporting an empty buffer as end of stream.
    fn recv(
        mut self,
        writer: impl Write,
        len: usize,
    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
        let outcome = self.connection.recv(writer, len).map_err(|err| match err {
            RecvError::Empty => RecvError::StreamClosed,
            other => other,
        });
        (self.into(), outcome)
    }

    /// Feeds an incoming segment to the connection.
    ///
    /// - Segments whose addresses don't match are ignored with `Ok(0)`.
    /// - If the connection becomes reset, the socket leaves through `connection_was_reset`.
    /// - Once `fin_was_acked()` reports true, the socket goes straight to CLOSED, keeping the
    ///   receive buffer.
    fn push_packet(
        mut self,
        header: &TcpHeader,
        payload: Payload,
    ) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
        if !self.connection.packet_addrs_match(header) {
            return (self.into(), Ok(0));
        }

        let accepted = match self.connection.push_packet(header, payload) {
            Ok(n) => n,
            Err(err) => return (self.into(), Err(err)),
        };

        let next: TcpStateEnum<X> = if self.connection.is_reset() {
            if header.flags.contains(TcpFlags::RST) {
                self.common.set_error_if_unset(TcpError::ResetReceived);
            }
            connection_was_reset(self.common, self.connection)
        } else if self.connection.fin_was_acked() {
            let leftover = self.connection.into_recv_buffer();
            ClosedState::new(self.common, leftover, true).into()
        } else {
            self.into()
        };
        (next, Ok(accepted))
    }

    /// Pops the next outgoing segment from the connection, stamped with the current time.
    fn pop_packet(
        mut self,
    ) -> (
        TcpStateEnum<X>,
        Result<(TcpHeader, Payload), PopPacketError>,
    ) {
        let now = self.common.current_time();
        let popped = self.connection.pop_packet(now);
        (self.into(), popped)
    }

    /// Returns and clears any pending socket error.
    fn clear_error(&mut self) -> Option<TcpError> {
        Option::take(&mut self.common.error)
    }

    /// Both directions are closed. The socket is readable only while buffered data remains.
    fn poll(&self) -> PollState {
        let mut poll_state =
            PollState::CONNECTED | PollState::RECV_CLOSED | PollState::SEND_CLOSED;
        if self.connection.recv_buf_has_data() {
            poll_state.insert(PollState::READABLE);
        }
        assert!(!poll_state.contains(PollState::WRITABLE));
        if self.common.error.is_some() {
            poll_state.insert(PollState::ERROR);
        }
        poll_state
    }

    fn wants_to_send(&self) -> bool {
        let conn = &self.connection;
        conn.wants_to_send()
    }

    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
        let conn = &self.connection;
        Some((conn.local_addr, conn.remote_addr))
    }
}
impl<X: Dependencies> RstState<X> {
fn new(common: Common<X>, rst_packets: LinkedList<TcpHeader>, was_connected: bool) -> Self {
debug_assert!(rst_packets.iter().all(|x| x.flags.contains(TcpFlags::RST)));
assert!(!rst_packets.is_empty());
Self {
common,
send_buffer: rst_packets,
was_connected,
}
}
}
impl<X: Dependencies> TcpStateTrait<X> for RstState<X> {
fn close(self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
(self.into(), Ok(()))
}
fn shutdown(self, how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
if !self.was_connected {
return (self.into(), Err(ShutdownError::NotConnected));
}
if how == Shutdown::Read || how == Shutdown::Both {
}
if how == Shutdown::Write || how == Shutdown::Both {
}
(self.into(), Ok(()))
}
fn connect<T, E>(
self,
_remote_addr: SocketAddrV4,
_associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
if self.was_connected {
(self.into(), Err(ConnectError::AlreadyConnected))
} else {
(self.into(), Err(ConnectError::InvalidState))
}
}
fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
if self.was_connected {
(self.into(), Err(SendError::StreamClosed))
} else {
(self.into(), Err(SendError::NotConnected))
}
}
fn recv(self, _writer: impl Write, _len: usize) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
if self.was_connected {
(self.into(), Err(RecvError::StreamClosed))
} else {
(self.into(), Err(RecvError::NotConnected))
}
}
fn push_packet(
self,
_header: &TcpHeader,
_payload: Payload,
) -> (TcpStateEnum<X>, Result<u32, PushPacketError>) {
(self.into(), Ok(0))
}
fn pop_packet(
mut self,
) -> (
TcpStateEnum<X>,
Result<(TcpHeader, Payload), PopPacketError>,
) {
let header = self.send_buffer.pop_front().unwrap();
let packet = (header, Payload::default());
assert!(packet.0.flags.contains(TcpFlags::RST));
if self.send_buffer.is_empty() {
let new_state = ClosedState::new(
self.common,
None,
self.was_connected,
);
return (new_state.into(), Ok(packet));
}
(self.into(), Ok(packet))
}
fn clear_error(&mut self) -> Option<TcpError> {
self.common.error.take()
}
fn poll(&self) -> PollState {
let mut poll_state = PollState::RECV_CLOSED | PollState::SEND_CLOSED;
if self.common.error.is_some() {
poll_state.insert(PollState::ERROR);
}
if self.was_connected {
poll_state.insert(PollState::CONNECTED);
}
poll_state
}
fn wants_to_send(&self) -> bool {
assert!(!self.send_buffer.is_empty());
true
}
fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
None
}
}
impl<X: Dependencies> ClosedState<X> {
fn new(common: Common<X>, recv_buffer: Option<RecvQueue>, was_connected: bool) -> Self {
let recv_buffer = recv_buffer.unwrap_or_else(|| RecvQueue::new(Seq::new(0)));
if !was_connected {
assert!(recv_buffer.is_empty());
}
Self {
common,
recv_buffer,
was_connected,
}
}
}
impl<X: Dependencies> TcpStateTrait<X> for ClosedState<X> {
    /// Closing an already-closed socket is a no-op.
    fn close(self) -> (TcpStateEnum<X>, Result<(), CloseError>) {
        (self.into(), Ok(()))
    }

    /// Fails with `NotConnected` if the socket was never connected, otherwise a no-op.
    fn shutdown(self, _how: Shutdown) -> (TcpStateEnum<X>, Result<(), ShutdownError>) {
        let outcome = if self.was_connected {
            Ok(())
        } else {
            Err(ShutdownError::NotConnected)
        };
        (self.into(), outcome)
    }

    /// Always fails. The error depends on whether the socket had been connected.
    fn connect<T, E>(
        self,
        _remote_addr: SocketAddrV4,
        _associate_fn: impl FnOnce() -> Result<(SocketAddrV4, T), E>,
    ) -> (TcpStateEnum<X>, Result<T, ConnectError<E>>) {
        let err = if self.was_connected {
            ConnectError::AlreadyConnected
        } else {
            ConnectError::InvalidState
        };
        (self.into(), Err(err))
    }

    /// Always fails: `StreamClosed` if previously connected, otherwise `NotConnected`.
    fn send(self, _reader: impl Read, _len: usize) -> (TcpStateEnum<X>, Result<usize, SendError>) {
        let err = if self.was_connected {
            SendError::StreamClosed
        } else {
            SendError::NotConnected
        };
        (self.into(), Err(err))
    }

    /// Reads whatever data the connection left behind in `recv_buffer`.
    ///
    /// Fails with `NotConnected` if the socket was never connected, or `StreamClosed` once the
    /// buffer is empty. I/O errors from `writer` are wrapped in `RecvError::Io`.
    fn recv(
        mut self,
        writer: impl Write,
        len: usize,
    ) -> (TcpStateEnum<X>, Result<usize, RecvError>) {
        let outcome = if !self.was_connected {
            Err(RecvError::NotConnected)
        } else if self.recv_buffer.is_empty() {
            Err(RecvError::StreamClosed)
        } else {
            self.recv_buffer.read(writer, len).map_err(RecvError::Io)
        };
        (self.into(), outcome)
    }

    /// Returns and clears any pending socket error.
    fn clear_error(&mut self) -> Option<TcpError> {
        Option::take(&mut self.common.error)
    }

    /// Reports `CLOSED` with both directions closed. `READABLE` is set while leftover data
    /// remains, and `CONNECTED` if the socket had been connected.
    fn poll(&self) -> PollState {
        let mut poll_state = PollState::CLOSED | PollState::RECV_CLOSED | PollState::SEND_CLOSED;
        if !self.recv_buffer.is_empty() {
            poll_state.insert(PollState::READABLE);
        }
        assert!(!poll_state.contains(PollState::WRITABLE));
        if self.was_connected {
            poll_state.insert(PollState::CONNECTED);
        }
        if self.common.error.is_some() {
            poll_state.insert(PollState::ERROR);
        }
        poll_state
    }

    /// A closed socket never has anything to send.
    fn wants_to_send(&self) -> bool {
        false
    }

    /// No address pair is reported for a closed socket.
    fn local_remote_addrs(&self) -> Option<(SocketAddrV4, SocketAddrV4)> {
        None
    }
}
/// Actively aborts `connection` and returns the resulting RST state.
///
/// It calls `send_rst()` and then hands off to `connection_was_reset`.
///
/// # Panics
///
/// Panics if that hand-off does not yield the `Rst` state, which would mean `send_rst()`
/// failed to queue an RST segment.
fn reset_connection<X: Dependencies>(
    common: Common<X>,
    mut connection: Connection<X::Instant>,
) -> RstState<X> {
    connection.send_rst();
    match connection_was_reset(common, connection) {
        TcpStateEnum::Rst(rst_state) => rst_state,
        new_state => {
            panic!("We called `send_rst()` above but aren't now in the \"rst\" state: {new_state:?}")
        }
    }
}
fn connection_was_reset<X: Dependencies>(
mut common: Common<X>,
mut connection: Connection<X::Instant>,
) -> TcpStateEnum<X> {
assert!(connection.is_reset());
let now = common.current_time();
if let Ok((header, payload)) = connection.pop_packet(now) {
assert!(payload.is_empty());
debug_assert!(connection.pop_packet(now).is_err());
common.set_error_if_unset(TcpError::ResetSent);
let rst_packets = [header].into_iter().collect();
RstState::new(common, rst_packets, true).into()
} else {
ClosedState::new(common, None, true).into()
}
}