use async_trait::async_trait;
use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use std::sync::Arc;
use tokio::sync::Mutex;
/// `Transport` implementation backed by the write half of a [`TcpStream`].
///
/// Only the writer lives here; the read half of the same stream is owned by a
/// background task spawned when the transport is created.
pub struct TcpTransport {
    /// Write half of the connection. It sits behind an async mutex so that a
    /// `send` holds exclusive access for its whole `write_all`, and concurrent
    /// sends cannot interleave their bytes on the wire.
    writer: Arc<Mutex<tokio::io::WriteHalf<TcpStream>>>,
}

#[async_trait]
impl Transport for TcpTransport {
    /// Writes the whole of `data` to the socket.
    ///
    /// # Errors
    /// Returns any I/O error from the underlying write. Examples are a reset
    /// from the peer, or a write attempted after `disconnect` has already shut
    /// down the write side.
    async fn send(&self, data: Vec<u8>) -> Result<(), anyhow::Error> {
        let mut writer = self.writer.lock().await;
        writer.write_all(&data).await?;
        Ok(())
    }

    /// Shuts down the write side of the connection (sends a TCP FIN).
    ///
    /// Dropping the transport is not enough to close the socket. A stream
    /// produced by `tokio::io::split` stays open until *both* halves are
    /// dropped, and the read half is held by the reader task until the peer
    /// closes. Shutting down the writer tells the peer we are done. A
    /// well-behaved peer then closes its side, which ends the reader task.
    ///
    /// This waits for any in-flight `send` to release the writer lock first.
    /// Errors are ignored (for example, already shut down or connection
    /// reset): the connection is being torn down either way, and the method
    /// has no way to report failure.
    async fn disconnect(&self) {
        let mut writer = self.writer.lock().await;
        let _ = writer.shutdown().await;
    }
}
/// Factory that opens a fresh TCP connection to one fixed address every time
/// a transport is requested.
pub struct TcpTransportFactory {
    /// Connection target passed verbatim to `TcpStream::connect`.
    address: String,
}

impl TcpTransportFactory {
    /// Builds a factory that will connect to `address`.
    ///
    /// Any value convertible into a `String` (`&str`, `String`, ...) is
    /// accepted and stored as an owned copy.
    pub fn new(address: impl Into<String>) -> Self {
        let address: String = address.into();
        Self { address }
    }
}
#[async_trait]
impl TransportFactory for TcpTransportFactory {
async fn create_transport(
&self,
) -> Result<(Arc<dyn Transport>, async_channel::Receiver<TransportEvent>), anyhow::Error>
{
let stream = TcpStream::connect(&self.address).await?;
let (reader, writer) = tokio::io::split(stream);
let (event_tx, event_rx) = async_channel::bounded(100);
let transport = Arc::new(TcpTransport {
writer: Arc::new(Mutex::new(writer)),
});
// Spawn read task
tokio::task::spawn(async move {
let mut reader = reader;
let mut buf = vec![0u8; 4096];
event_tx.send(TransportEvent::Connected).await.ok();
loop {
match reader.read(&mut buf).await {
Ok(0) => break,
Ok(n) => {
let data = bytes::Bytes::copy_from_slice(&buf[..n]);
if event_tx.send(TransportEvent::DataReceived(data)).await.is_err() {
break;
}
}
Err(_) => break,
}
}
event_tx.send(TransportEvent::Disconnected).await.ok();
});
Ok((transport, event_rx))
}
}