use std::{ collections::HashMap, sync::{ Arc, RwLock, }, }; use futures::{Stream, StreamExt}; use tokio::{ io::AsyncWriteExt, net::TcpListener, sync::mpsc::{channel, Receiver, Sender}, }; use cave::{ config::LoadConfig, feed::Post, firehose::FirehoseFactory, }; mod config; fn html_to_text(html: &str) -> String { let mut result = String::with_capacity(html.len()); let mut in_tag = false; let mut entity = None; for c in html.chars() { if c == '<' { // tag open in_tag = true; } else if in_tag && c == '>' { // tag close in_tag = false; } else if in_tag { // ignore } else if c == '&' { entity = Some(String::with_capacity(5)); } else if entity.is_some() && c == ';' { let r = match entity.take().unwrap().as_str() { "amp" => "&", "lt" => "<", "gt" => ">", "quot" => "\"", "apos" => "\'", _ => "", }; result.push_str(r); } else if let Some(entity) = entity.as_mut() { entity.push(c); } else { result.push(c); } } result } fn format_message(post: Post) -> Option { let language = &post.language?; let time = &post.created_at; let display_name = &post.account.display_name; let username = &post.account.username; let host = post.account.host()?; let text = html_to_text(&post.content); Some(format!( "[{}] {} {} <@{}@{}>\r\n{}\r\n\r\n", language, time, display_name, username, host, text, )) } #[derive(Clone)] struct State { next_id: Arc>, consumers: Arc>>>>>, } impl State { pub fn new() -> Self { State { next_id: Arc::new(RwLock::new(0)), consumers: Arc::new(RwLock::new(HashMap::new())), } } pub fn get_next_id(&self) -> usize { let mut next_id = self.next_id.write().unwrap(); let result = *next_id; *next_id += 1; result } pub fn pipe(&self) -> Pipe { let (tx, rx) = channel(1); let id = self.get_next_id(); let mut consumers = self.consumers.write().unwrap(); consumers.insert(id, tx); Pipe { id, rx, consumers: self.consumers.clone(), } } pub async fn broadcast(&self, msg: Arc>) { let txs = { let consumers = self.consumers.read().unwrap(); consumers.values() .cloned() .collect::>() }; for tx in txs { let _ = tx.send(msg.clone()).await; } } } struct Pipe { id: usize, pub rx: Receiver>>, consumers: Arc>>>>>, } impl Drop for Pipe { fn drop(&mut self) { log::trace!("drop pipe"); let mut consumers = self.consumers.write().unwrap(); consumers.remove(&self.id); } } async fn publisher(state: State, firehose: impl Stream>) { firehose.for_each(move |data| { let state = state.clone(); async move { let post = serde_json::from_slice(&data) .ok(); let msg = post.and_then(format_message); if let Some(msg) = msg { state.broadcast(Arc::new(msg.into_bytes())).await; } } }).await; } #[tokio::main] async fn main() { cave::init::exit_on_panic(); cave::init::init_logger(); let config = config::Config::load(); let state = State::new(); let firehose_factory = FirehoseFactory::new(config.redis); let firehose = firehose_factory.produce() .await .expect("firehose") .filter_map(|item| async move { item.ok() }); tokio::spawn( publisher(state.clone(), firehose) ); let listener = TcpListener::bind( format!("[::]:{}", config.listen_port) ).await.expect("TcpListener::bind"); cave::systemd::ready(); while let Ok((mut socket, addr)) = listener.accept().await { log::info!("Accepted connection from {:?}", addr); let mut pipe = state.pipe(); tokio::spawn(async move { log::trace!("while..."); while let Some(msg) = pipe.rx.recv().await { match socket.write_all(&msg[..]).await { Ok(_) => {} Err(_) => break, } } }); } }