use std::{ collections::HashMap, sync::{ Arc, RwLock, }, }; use ansi_term::Colour::{self, *}; use futures::{Stream, StreamExt}; use tokio::{ io::AsyncWriteExt, net::TcpListener, sync::mpsc::{channel, Receiver, Sender}, }; use cave::{ config::LoadConfig, feed::Post, firehose::FirehoseFactory, }; mod config; fn html_to_text(html: &str) -> String { let mut result = String::with_capacity(html.len()); let mut in_tag = false; let mut entity = None; for c in html.chars() { if c == '<' { // tag open in_tag = true; } else if in_tag && c == '>' { // tag close in_tag = false; } else if in_tag { // ignore } else if c == '&' { entity = Some(String::with_capacity(5)); } else if entity.is_some() && c == ';' { let r = match entity.take().unwrap().as_str() { "amp" => "&", "lt" => "<", "gt" => ">", "quot" => "\"", "apos" => "\'", _ => "", }; result.push_str(r); } else if let Some(entity) = entity.as_mut() { entity.push(c); } else { result.push(c); } } result } fn language_colour(language: &str) -> Colour { let x = language.bytes().fold(0, |x, b| x ^ b); let b = language.as_bytes(); let y = if b.len() >= 1 { (b[0] & 0x1F) << 2 } else { 127 }; let z = if b.len() >= 2 { (b[1] & 0x1F) << 2 } else { 127 }; match x % 6 { 0 => RGB(127 + y, 0, 0), 1 => RGB(0, 127 + z, 0), 2 => RGB(0, 0, 127 + y), 3 => RGB(127 + y, 127 + z, 0), 4 => RGB(127 + y, 0, 127 + z), 5 => RGB(0, 127 + y, 127 + z), 6..=u8::MAX => unreachable!(), } } fn format_message(post: Post) -> Option { let time_str; let time = if let Some(time) = post.timestamp() { time_str = format!("{}", time.format("%H:%M:%S")); &time_str } else { &post.created_at }; let language = &post.language?; let display_name = &post.account.display_name; let username = &post.account.username; let host = post.account.host()?; let text = html_to_text(&post.content); Some(format!( "[{}] {} {} <@{}@{}>\r\n{}\r\n\r\n", Black.on(language_colour(language)).paint(language), Red.paint(time), Yellow.bold().paint(display_name), Yellow.underline().paint(username), Yellow.underline().paint(host), text, )) } #[derive(Clone)] struct State { next_id: Arc>, consumers: Arc>>>>>, } impl State { pub fn new() -> Self { State { next_id: Arc::new(RwLock::new(0)), consumers: Arc::new(RwLock::new(HashMap::new())), } } pub fn get_next_id(&self) -> usize { let mut next_id = self.next_id.write().unwrap(); let result = *next_id; *next_id += 1; result } pub fn pipe(&self) -> Pipe { let (tx, rx) = channel(128); let id = self.get_next_id(); let mut consumers = self.consumers.write().unwrap(); consumers.insert(id, tx); Pipe { id, rx, consumers: self.consumers.clone(), } } pub async fn broadcast(&self, msg: Arc>) { let txs = { let consumers = self.consumers.read().unwrap(); consumers.values() .cloned() .collect::>() }; for tx in txs { let _ = tx.try_send(msg.clone()); } } } struct Pipe { id: usize, pub rx: Receiver>>, consumers: Arc>>>>>, } impl Drop for Pipe { fn drop(&mut self) { tracing::info!("Consumer disconnected"); let mut consumers = self.consumers.write().unwrap(); consumers.remove(&self.id); } } async fn publisher(state: State, firehose: impl Stream, Vec)>) { firehose.for_each(move |(event_type, data)| { let state = state.clone(); async move { if event_type != b"update" { // Only process new posts, no updates return; } let post = serde_json::from_slice(&data) .ok(); let msg = post.and_then(format_message); if let Some(msg) = msg { state.broadcast(Arc::new(msg.into_bytes())).await; } cave::systemd::watchdog(); } }).await; } #[tokio::main] async fn main() { cave::init::exit_on_panic(); cave::init::init_logger(5557); let config = config::Config::load(); let state = State::new(); let firehose_factory = FirehoseFactory::new(config.redis, config.redis_password_file); let firehose = firehose_factory.produce() .await .expect("firehose"); tokio::spawn( publisher(state.clone(), firehose) ); let listener = TcpListener::bind( format!("[::]:{}", config.listen_port) ).await.expect("TcpListener::bind"); cave::systemd::ready(); while let Ok((mut socket, addr)) = listener.accept().await { tracing::info!("Accepted connection from {:?}", addr); let mut pipe = state.pipe(); tokio::spawn(async move { while let Some(msg) = pipe.rx.recv().await { match socket.write_all(&msg[..]).await { Ok(_) => {} Err(_) => break, } } }); } }