use std::collections::HashSet; use std::future::Future; use std::sync::Arc; use std::time::{Duration, Instant}; use cave::{ block_list::BlockList, db::Database, feed::{Feed, EncodablePost, Post, StreamError}, posts_cache::PostsCache, store::Store, }; use futures::{StreamExt, future, stream::iter}; use reqwest::StatusCode; use crate::scheduler::{Host, InstanceHost}; use crate::webfinger; #[derive(Clone)] pub struct RobotsTxt { robot: Arc>, } impl RobotsTxt { pub async fn fetch( client: &reqwest::Client, host: &Host, ) -> Self { let url = format!("https://{host}/robots.txt"); metrics::increment_gauge!("hunter_requests", 1.0, "type" => "robotstxt"); let robot = async { let body = client.get(url) .send().await .ok()? .text().await .ok()?; texting_robots::Robot::new( env!("CARGO_PKG_NAME"), body.as_bytes(), ).ok() }.await; metrics::decrement_gauge!("hunter_requests", 1.0, "type" => "robotstxt"); RobotsTxt { robot: Arc::new(robot), } } pub fn allowed(&self, url: &str) -> bool { if let Some(ref robot) = self.robot.as_ref() { robot.allowed(url) } else { true } } pub fn delay(&self) -> Option { if let Some(ref robot) = self.robot.as_ref() { robot.delay.map(|delay| Duration::from_secs(delay as u64)) } else { None } } } #[derive(Debug)] pub enum Message { WorkerDone, Fetched { host: Host, new_post_ratio: Option, mean_interval: Option, }, IntroduceHost(InstanceHost), } pub async fn run( message_tx: tokio::sync::mpsc::UnboundedSender, store: Store, db: Database, posts_cache: PostsCache, block_list: BlockList, client: reqwest::Client, host: InstanceHost, ) { // Fetch /robots.txt let robots_txt = RobotsTxt::fetch(&client, &host.host).await; let robots_delay = robots_txt.delay(); // Fetch posts and open stream let (timeline_result, stream_result) = future::join( fetch_timeline( message_tx.clone(), store.clone(), &posts_cache, block_list.clone(), &client, robots_txt.clone(), &host.host ), open_stream( message_tx.clone(), store.clone(), db.clone(), &posts_cache, block_list.clone(), &client, robots_txt, host.host.clone() ), ).await; if let Err(e) = &timeline_result { tracing::error!("{}", e); } if let Err(e) = &stream_result { tracing::error!("{}", e); } // If there is a web server responding, its Webfinger endpoint may point to another domain match (&host.known_user, &timeline_result) { (Some(known_user), Err(timeline_err)) if timeline_err.is_status() => { metrics::increment_gauge!("hunter_requests", 1.0, "type" => "webfinger"); let webfinger_result = webfinger::get_hosts_from_webfinger(&client, known_user, &host.host).await; metrics::decrement_gauge!("hunter_requests", 1.0, "type" => "webfinger"); if let Ok(hosts) = webfinger_result { for host in hosts { message_tx.send(Message::IntroduceHost(InstanceHost { host, known_user: None, })).unwrap(); } } } _ => {} } // Next worker can start message_tx.send(Message::WorkerDone).unwrap(); // Process stream let (mut new_post_ratio, mut mean_interval) = timeline_result.unwrap_or((None, None)); if let Ok((stats_key, stream)) = stream_result { tracing::info!("Processing {stats_key} for {}", &host.host); metrics::increment_gauge!("hunter_requests", 1.0, "type" => stats_key); let start_time = Instant::now(); let post_count = stream.await; let end_time = Instant::now(); metrics::decrement_gauge!("hunter_requests", 1.0, "type" => stats_key); tracing::warn!("Ended {stats_key} for {}. {} posts in {:?}", &host.host, post_count, end_time - start_time); if post_count > 0 { if let Some(ref mut new_post_ratio) = new_post_ratio { *new_post_ratio += post_count as f64 / 100.; } let stream_avg_interval = Duration::from_secs_f64( (end_time - start_time).as_secs_f64() / (post_count as f64) ); if mean_interval.map_or(true, |mean_interval| stream_avg_interval < mean_interval) { mean_interval = Some(stream_avg_interval); } } } // Ready for reenqueue if let Some(mean_interval) = &mut mean_interval { if let Some(robots_delay) = robots_delay { *mean_interval = (*mean_interval).max(robots_delay); } } message_tx.send(Message::Fetched { host: host.host, new_post_ratio, mean_interval, }).unwrap(); } async fn fetch_timeline( message_tx: tokio::sync::mpsc::UnboundedSender, mut store: Store, posts_cache: &PostsCache, block_list: BlockList, client: &reqwest::Client, robots_txt: RobotsTxt, host: &Host, ) -> Result<(Option, Option), reqwest::Error> { let url = format!("https://{host}/api/v1/timelines/public?limit=40"); if ! robots_txt.allowed(&url) { tracing::warn!("Timeline of {} forbidden by robots.txt", host); return Ok((None, None)); } // free as early as possible drop(robots_txt); metrics::increment_gauge!("hunter_requests", 1.0, "type" => "timeline"); let t1 = Instant::now(); let result = Feed::fetch(client, &url).await; let t2 = Instant::now(); metrics::decrement_gauge!("hunter_requests", 1.0, "type" => "timeline"); metrics::histogram!("hunter_fetch_seconds", t2 - t1, "result" => if result.is_ok() { "ok" } else { "error" }); let feed = result?; let mean_interval = feed.mean_post_interval(); let (new_post_ratio, introduce_hosts) = process_posts(&mut store, posts_cache, block_list, host, feed.posts.into_iter()).await; for introduce_host in introduce_hosts { message_tx.send(Message::IntroduceHost(introduce_host)).unwrap(); } // successfully fetched, save for future run store.save_host(host).await.unwrap(); Ok((new_post_ratio, mean_interval)) } fn scan_for_hosts(introduce_hosts: &mut Vec, post: &Post) { // introduce instances from accounts if let Ok(host) = (&post.account).try_into() { introduce_hosts.push(host); } // introduce instances from mentions for mention in &post.mentions { if let Ok(host) = mention.try_into() { introduce_hosts.push(host); } } } async fn process_posts( store: &mut Store, posts_cache: &PostsCache, block_list: BlockList, host: &Host, posts: impl Iterator, ) -> (Option, Vec) { // introduce new hosts, validate posts let mut introduce_hosts = Vec::new(); let mut new_posts = 0; let mut posts_len = 0; for post in posts { posts_len += 1; // potentially save a round-trip to redis with an in-process cache if ! posts_cache.insert(post.url.clone()) { let t1 = Instant::now(); scan_for_hosts(&mut introduce_hosts, &post); if let Some(reblog) = &post.reblog { scan_for_hosts(&mut introduce_hosts, reblog); } if let Some(account_host) = post.account.host() { if block_list.is_blocked(&account_host).await { tracing::warn!("ignore post from blocked host {account_host}"); } else if store.save_post(post).await == Ok(true) { // send away to redis new_posts += 1; } } else { tracing::warn!("drop repost ({:?} on {})", post.account.host(), host); } let t2 = Instant::now(); metrics::histogram!("hunter_post_process_seconds", t2 - t1); } } tracing::trace!("{}: {}/{} new posts", host, new_posts, posts_len); metrics::counter!("hunter_posts", new_posts, "type" => "new"); metrics::counter!("hunter_posts", posts_len, "type" => "total"); let new_post_ratio = if posts_len > 0 { let ratio = (new_posts as f64) / (posts_len as f64); metrics::histogram!("hunter_new_post_ratio", ratio); Some(ratio) } else { None }; // dedup introduce_hosts let mut seen_hosts = HashSet::with_capacity(introduce_hosts.len()); let introduce_hosts = iter( introduce_hosts.into_iter() .filter_map(|introduce_host| { if ! seen_hosts.contains(&introduce_host.host) { seen_hosts.insert(introduce_host.host.clone()); Some(introduce_host) } else { None } }) ) .filter(|introduce_host| { let block_list = block_list.clone(); let host = introduce_host.host.to_string(); async move { ! block_list.is_blocked(&host).await } }) .collect().await; (new_post_ratio, introduce_hosts) } async fn open_stream( message_tx: tokio::sync::mpsc::UnboundedSender, store: Store, db: Database, posts_cache: &PostsCache, block_list: BlockList, client: &reqwest::Client, robots_txt: RobotsTxt, host: Host, ) -> Result<(&'static str, impl Future), String> { let url = format!("https://{host}/api/v1/streaming/public"); if ! robots_txt.allowed(&url) { return Err(format!("Streaming of {host} forbidden by robots.txt")); } // free as early as possible drop(robots_txt); let posts_cache = posts_cache.clone(); metrics::increment_gauge!("hunter_requests", 1.0, "type" => "stream_open"); let mut stream = Feed::stream(client, &url).await; metrics::decrement_gauge!("hunter_requests", 1.0, "type" => "stream_open"); let mut stats_key = "stream"; let mut prev_token: Option = None; let mut token_tries = 0; while let Err(StreamError::HttpStatus(StatusCode::UNAUTHORIZED)) = &stream { if let Some(invalid_token) = prev_token { // If we tried with a token before but it's Unauthorized, delete it. tracing::warn!("Deleting invalid token for host {}: {}", host, invalid_token); let _ = db.delete_token(&host, &invalid_token).await; } if token_tries > 3 { break; } let token = db.get_token(&host).await .expect("db.get_token()"); if let Some(token) = &token { let url = format!("https://{}/api/v1/streaming/public?access_token={}", host, urlencoding::encode(token)); metrics::increment_gauge!("hunter_requests", 1.0, "type" => "stream_open_token"); stream = Feed::stream(client, &url).await; metrics::decrement_gauge!("hunter_requests", 1.0, "type" => "stream_open_token"); stats_key = "stream_token"; } else { tracing::info!("No working token for {}", host); break; } prev_token = token; token_tries += 1; } if let Err(e) = &stream { tracing::error!("Error opening stream to {}: {}", host, e); } let stream = stream.map_err(|e| { format!("Stream error for {host}: {e}") })?; Ok((stats_key, stream.fold(0, move |post_count, post| { let message_tx = message_tx.clone(); let mut store = store.clone(); let posts_cache = posts_cache.clone(); let block_list = block_list.clone(); let host = host.clone(); async move { let (_, introduce_hosts) = process_posts(&mut store, &posts_cache, block_list, &host, [post].into_iter()).await; for introduce_host in introduce_hosts { message_tx.send(Message::IntroduceHost(introduce_host)).unwrap(); } post_count + 1 } }))) }