// caveman/smokestack/src/main.rs
// (226 lines, 5.7 KiB, Rust)

use std::{
collections::HashMap,
sync::{
Arc,
RwLock,
},
};
use ansi_term::Colour::{self, Black, RGB, Red, Yellow};
use futures::{Stream, StreamExt};
use tokio::{
io::AsyncWriteExt,
net::TcpListener,
sync::mpsc::{channel, Receiver, Sender},
};
use cave::{
config::LoadConfig,
feed::Post,
firehose::FirehoseFactory,
};
mod config;
/// Converts a fragment of HTML into plain text.
///
/// Tags are dropped entirely; no line breaks are inserted for `<br>` or `<p>`.
/// Character references are decoded:
///
/// * the named entities `&amp;`, `&lt;`, `&gt;`, `&quot;` and `&apos;`;
/// * numeric references such as `&#39;` and `&#x27;`.
///
/// Any other complete reference (e.g. `&nbsp;`, or a numeric reference to a
/// control character) is dropped. An `&` that does not start a well-formed
/// reference (as in `Tom & Jerry`) is kept literally, together with whatever
/// followed it.
fn html_to_text(html: &str) -> String {
    // Longest entity name buffered before the `&` is treated as plain text.
    const MAX_ENTITY_LEN: usize = 32;

    let mut result = String::with_capacity(html.len());
    let mut in_tag = false;
    // Name of the reference being read, i.e. the text after `&` so far.
    let mut entity: Option<String> = None;

    for c in html.chars() {
        if in_tag {
            if c == '>' {
                in_tag = false;
            }
            continue;
        }

        if let Some(mut name) = entity.take() {
            if c == ';' {
                if let Some(decoded) = decode_entity(&name) {
                    result.push(decoded);
                }
                continue;
            }
            if name.len() < MAX_ENTITY_LEN && (c.is_ascii_alphanumeric() || c == '#') {
                name.push(c);
                entity = Some(name);
                continue;
            }
            // Not a character reference after all: emit it verbatim and let
            // `c` be handled as ordinary input below.
            result.push('&');
            result.push_str(&name);
        }

        match c {
            '<' => in_tag = true,
            '&' => entity = Some(String::new()),
            _ => result.push(c),
        }
    }

    // Input ended in the middle of a reference: keep it verbatim.
    if let Some(name) = entity {
        result.push('&');
        result.push_str(&name);
    }
    result
}

/// Decodes the name of a character reference (the text between `&` and `;`).
///
/// Returns `None` for unknown names, for numeric references that are not valid
/// Unicode scalar values, and for references to control characters. The last
/// rule stops e.g. `&#27;` from smuggling an ESC byte into terminal output.
fn decode_entity(name: &str) -> Option<char> {
    match name {
        "amp" => Some('&'),
        "lt" => Some('<'),
        "gt" => Some('>'),
        "quot" => Some('"'),
        "apos" => Some('\''),
        _ => {
            let number = name.strip_prefix('#')?;
            let code = match number.strip_prefix('x').or_else(|| number.strip_prefix('X')) {
                Some(hex) => u32::from_str_radix(hex, 16).ok()?,
                None => number.parse::<u32>().ok()?,
            };
            char::from_u32(code).filter(|c| !c.is_control())
        }
    }
}
/// Picks a stable background colour for a language tag.
///
/// The XOR of all bytes, reduced modulo 6, selects which RGB channels are lit.
/// The low five bits of the first and second bytes set those channels'
/// brightness (127..=251); a missing first/second byte yields 254 instead.
fn language_colour(language: &str) -> Colour {
    let bytes = language.as_bytes();

    // Brightness offset derived from the byte at `index`, or 127 when absent.
    let offset = |index: usize| bytes.get(index).map_or(127, |&byte| (byte & 0x1F) << 2);
    let bright_y = 127 + offset(0);
    let bright_z = 127 + offset(1);

    let selector = bytes.iter().fold(0u8, |acc, &byte| acc ^ byte) % 6;
    match selector {
        0 => RGB(bright_y, 0, 0),
        1 => RGB(0, bright_z, 0),
        2 => RGB(0, 0, bright_y),
        3 => RGB(bright_y, bright_z, 0),
        4 => RGB(bright_y, 0, bright_z),
        5 => RGB(0, bright_y, bright_z),
        _ => unreachable!(),
    }
}
/// Drops control characters from untrusted text before it reaches a client's
/// terminal, keeping only `\n` and `\t`.
///
/// Post content and account fields come from arbitrary remote servers. Without
/// this, an embedded escape sequence such as `"\x1b[2J"` would be interpreted
/// by every connected client's terminal.
fn strip_control_chars(s: &str) -> String {
    s.chars()
        .filter(|&c| c == '\n' || c == '\t' || !c.is_control())
        .collect()
}

/// Renders `post` as one coloured, terminal-ready record:
///
/// ```text
/// [lang] TIME Display Name <@user@host>
/// post text
/// ```
///
/// followed by a blank line, all using `\r\n` line endings. `TIME` is the
/// parsed timestamp as `HH:MM:SS`, or the raw `created_at` string when
/// `timestamp()` returns `None`. All untrusted fields are stripped of
/// control characters.
///
/// Returns `None` when the post has no language or its account has no host.
fn format_message(post: Post) -> Option<String> {
    let time = match post.timestamp() {
        Some(time) => time.format("%H:%M:%S").to_string(),
        None => strip_control_chars(&post.created_at),
    };
    let language = strip_control_chars(&post.language?);
    let display_name = strip_control_chars(&post.account.display_name);
    let username = strip_control_chars(&post.account.username);
    let host = strip_control_chars(&post.account.host()?);
    let text = strip_control_chars(&html_to_text(&post.content));
    Some(format!(
        "[{}] {} {} <@{}@{}>\r\n{}\r\n\r\n",
        Black.on(language_colour(&language)).paint(&language),
        Red.paint(&time),
        Yellow.bold().paint(&display_name),
        Yellow.underline().paint(&username),
        Yellow.underline().paint(&host),
        text,
    ))
}
/// Shared registry of connected consumers.
///
/// Cloning is cheap: both fields are reference-counted handles, so every
/// clone observes the same id counter and the same consumer map.
#[derive(Clone)]
struct State {
    /// Source of consumer ids; only ever incremented.
    next_id: Arc<std::sync::atomic::AtomicUsize>,
    /// Outgoing channel of every registered consumer, keyed by id.
    consumers: Arc<RwLock<HashMap<usize, Sender<Arc<Vec<u8>>>>>>,
}

impl State {
    /// Creates an empty registry whose first consumer id is 0.
    pub fn new() -> Self {
        State {
            next_id: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
            consumers: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns a fresh id, distinct from every id previously handed out.
    pub fn get_next_id(&self) -> usize {
        // Relaxed is enough: callers only rely on uniqueness, which the atomic
        // read-modify-write guarantees regardless of memory ordering.
        self.next_id
            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
    }

    /// Registers a new consumer and returns its receiving end.
    ///
    /// The channel buffers up to 128 messages. The returned `Pipe` removes its
    /// entry from the consumer map when dropped.
    pub fn pipe(&self) -> Pipe {
        let (tx, rx) = channel(128);
        let id = self.get_next_id();
        self.consumers.write().unwrap().insert(id, tx);
        Pipe {
            id,
            rx,
            consumers: self.consumers.clone(),
        }
    }

    /// Offers `msg` to every registered consumer without waiting.
    ///
    /// Best-effort: `try_send` failures (full buffer or closed receiver) are
    /// ignored, so that consumer simply misses this message.
    pub async fn broadcast(&self, msg: Arc<Vec<u8>>) {
        // `try_send` never blocks and nothing is awaited here, so iterating
        // under the read lock avoids cloning every sender per message.
        let consumers = self.consumers.read().unwrap();
        for tx in consumers.values() {
            let _ = tx.try_send(Arc::clone(&msg));
        }
    }
}
/// Receiving end of one consumer's subscription to the broadcast.
///
/// Dropping it removes the matching sender from the shared consumer map.
struct Pipe {
    /// Key under which this consumer's sender is stored in `consumers`.
    id: usize,
    /// Messages broadcast to this consumer.
    pub rx: Receiver<Arc<Vec<u8>>>,
    /// Handle to the shared consumer map, used to deregister on drop.
    consumers: Arc<RwLock<HashMap<usize, Sender<Arc<Vec<u8>>>>>>,
}

impl Drop for Pipe {
    fn drop(&mut self) {
        tracing::info!("Consumer disconnected");
        self.consumers
            .write()
            .unwrap()
            .remove(&self.id);
    }
}
async fn publisher(state: State, firehose: impl Stream<Item = (Vec<u8>, Vec<u8>)>) {
firehose.for_each(move |(event_type, data)| {
let state = state.clone();
async move {
if event_type != b"update" {
// Only process new posts, no updates
return;
}
let post = serde_json::from_slice(&data)
.ok();
let msg = post.and_then(format_message);
if let Some(msg) = msg {
state.broadcast(Arc::new(msg.into_bytes())).await;
}
cave::systemd::watchdog();
}
}).await;
}
/// Starts the firehose publisher, then streams formatted posts to every TCP
/// client that connects on `[::]:<listen_port>`.
///
/// # Panics
///
/// Panics if the firehose cannot be produced, the port cannot be bound, or
/// `accept` fails with an error other than an aborted/reset connection.
#[tokio::main]
async fn main() {
    cave::init::exit_on_panic();
    cave::init::init_logger(5557);

    let config = config::Config::load();
    let state = State::new();

    let firehose_factory = FirehoseFactory::new(config.redis, config.redis_password_file);
    let firehose = firehose_factory.produce()
        .await
        .expect("firehose");
    tokio::spawn(publisher(state.clone(), firehose));

    let listener = TcpListener::bind(format!("[::]:{}", config.listen_port))
        .await
        .expect("TcpListener::bind");
    cave::systemd::ready();

    loop {
        let (mut socket, addr) = match listener.accept().await {
            Ok(accepted) => accepted,
            // The peer gave up before we accepted; keep serving everyone else.
            Err(e) if matches!(
                e.kind(),
                std::io::ErrorKind::ConnectionAborted | std::io::ErrorKind::ConnectionReset
            ) => {
                tracing::warn!("Failed to accept connection: {}", e);
                continue;
            }
            // Anything else (e.g. out of file descriptors) would likely repeat
            // in a tight loop. Previously the loop silently ended and main
            // returned success; fail loudly instead.
            Err(e) => panic!("TcpListener::accept: {}", e),
        };
        tracing::info!("Accepted connection from {:?}", addr);

        // Dropping `pipe` when the task below ends unregisters this consumer.
        let mut pipe = state.pipe();
        tokio::spawn(async move {
            while let Some(msg) = pipe.rx.recv().await {
                if socket.write_all(&msg[..]).await.is_err() {
                    // The client went away; stop forwarding.
                    break;
                }
            }
        });
    }
}