caveman/hunter/src/worker.rs

368 lines
12 KiB
Rust

use std::collections::HashSet;
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, Instant};
use cave::{
block_list::BlockList,
db::Database,
feed::{Feed, EncodablePost, Post, StreamError},
posts_cache::PostsCache,
store::Store,
};
use futures::{StreamExt, future, stream::iter};
use reqwest::StatusCode;
use crate::scheduler::{Host, InstanceHost};
use crate::webfinger;
#[derive(Clone)]
pub struct RobotsTxt {
robot: Arc<Option<texting_robots::Robot>>,
}
impl RobotsTxt {
pub async fn fetch(
client: &reqwest::Client,
host: &Host,
) -> Self {
let url = format!("https://{host}/robots.txt");
metrics::increment_gauge!("hunter_requests", 1.0, "type" => "robotstxt");
let robot = async {
let body = client.get(url)
.send().await
.ok()?
.text().await
.ok()?;
texting_robots::Robot::new(
env!("CARGO_PKG_NAME"),
body.as_bytes(),
).ok()
}.await;
metrics::decrement_gauge!("hunter_requests", 1.0, "type" => "robotstxt");
RobotsTxt {
robot: Arc::new(robot),
}
}
pub fn allowed(&self, url: &str) -> bool {
if let Some(ref robot) = self.robot.as_ref() {
robot.allowed(url)
} else {
true
}
}
pub fn delay(&self) -> Option<Duration> {
if let Some(ref robot) = self.robot.as_ref() {
robot.delay.map(|delay| Duration::from_secs(delay as u64))
} else {
None
}
}
}
/// Messages a worker sends over its `message_tx` channel (presumably consumed
/// by the scheduler; confirm against the receiving side).
#[derive(Debug)]
pub enum Message {
/// Sent by `run` once its initial requests are done (robots.txt, timeline,
/// stream open and the optional Webfinger lookup), so the next worker can
/// start. The worker may still be draining an open stream afterwards.
WorkerDone,
/// Sent by `run` as its very last step, when the host is ready to be
/// re-enqueued.
Fetched {
/// The host that was crawled.
host: Host,
/// Share of timeline posts that were newly saved, increased by 0.01 per
/// streamed post (so it can exceed 1.0). `None` if the timeline produced
/// no ratio.
new_post_ratio: Option<f64>,
/// Mean interval between posts. It comes from the timeline, is lowered to
/// the stream's average interval when that is shorter, and is raised to at
/// least the robots.txt crawl delay. `None` if no source provided one.
mean_interval: Option<Duration>,
},
/// A newly discovered instance, found in posts (authors, mentions,
/// reblogs) or via Webfinger.
IntroduceHost(InstanceHost),
}
pub async fn run(
message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
store: Store,
db: Database,
posts_cache: PostsCache,
block_list: BlockList,
client: reqwest::Client,
host: InstanceHost,
) {
// Fetch /robots.txt
let robots_txt = RobotsTxt::fetch(&client, &host.host).await;
let robots_delay = robots_txt.delay();
// Fetch posts and open stream
let (timeline_result, stream_result) = future::join(
fetch_timeline(
message_tx.clone(), store.clone(),
&posts_cache, block_list.clone(),
&client, robots_txt.clone(), &host.host
),
open_stream(
message_tx.clone(), store.clone(), db.clone(),
&posts_cache, block_list.clone(),
&client, robots_txt, host.host.clone()
),
).await;
if let Err(e) = &timeline_result {
tracing::error!("{}", e);
}
if let Err(e) = &stream_result {
tracing::error!("{}", e);
}
// If there is a web server responding, its Webfinger endpoint may point to another domain
match (&host.known_user, &timeline_result) {
(Some(known_user), Err(timeline_err)) if timeline_err.is_status() => {
metrics::increment_gauge!("hunter_requests", 1.0, "type" => "webfinger");
let webfinger_result = webfinger::get_hosts_from_webfinger(&client, known_user, &host.host).await;
metrics::decrement_gauge!("hunter_requests", 1.0, "type" => "webfinger");
if let Ok(hosts) = webfinger_result {
for host in hosts {
message_tx.send(Message::IntroduceHost(InstanceHost {
host,
known_user: None,
})).unwrap();
}
}
}
_ => {}
}
// Next worker can start
message_tx.send(Message::WorkerDone).unwrap();
// Process stream
let (mut new_post_ratio, mut mean_interval) = timeline_result.unwrap_or((None, None));
if let Ok((stats_key, stream)) = stream_result {
tracing::info!("Processing {stats_key} for {}", &host.host);
metrics::increment_gauge!("hunter_requests", 1.0, "type" => stats_key);
let start_time = Instant::now();
let post_count = stream.await;
let end_time = Instant::now();
metrics::decrement_gauge!("hunter_requests", 1.0, "type" => stats_key);
tracing::warn!("Ended {stats_key} for {}. {} posts in {:?}", &host.host, post_count, end_time - start_time);
if post_count > 0 {
if let Some(ref mut new_post_ratio) = new_post_ratio {
*new_post_ratio += post_count as f64 / 100.;
}
let stream_avg_interval = Duration::from_secs_f64(
(end_time - start_time).as_secs_f64() / (post_count as f64)
);
if mean_interval.map_or(true, |mean_interval| stream_avg_interval < mean_interval) {
mean_interval = Some(stream_avg_interval);
}
}
}
// Ready for reenqueue
if let Some(mean_interval) = &mut mean_interval {
if let Some(robots_delay) = robots_delay {
*mean_interval = (*mean_interval).max(robots_delay);
}
}
message_tx.send(Message::Fetched {
host: host.host,
new_post_ratio,
mean_interval,
}).unwrap();
}
async fn fetch_timeline(
message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
mut store: Store,
posts_cache: &PostsCache,
block_list: BlockList,
client: &reqwest::Client,
robots_txt: RobotsTxt,
host: &Host,
) -> Result<(Option<f64>, Option<Duration>), reqwest::Error> {
let url = format!("https://{host}/api/v1/timelines/public?limit=40");
if ! robots_txt.allowed(&url) {
tracing::warn!("Timeline of {} forbidden by robots.txt", host);
return Ok((None, None));
}
// free as early as possible
drop(robots_txt);
metrics::increment_gauge!("hunter_requests", 1.0, "type" => "timeline");
let t1 = Instant::now();
let result = Feed::fetch(client, &url).await;
let t2 = Instant::now();
metrics::decrement_gauge!("hunter_requests", 1.0, "type" => "timeline");
metrics::histogram!("hunter_fetch_seconds", t2 - t1, "result" => if result.is_ok() { "ok" } else { "error" });
let feed = result?;
let mean_interval = feed.mean_post_interval();
let (new_post_ratio, introduce_hosts) = process_posts(&mut store, posts_cache, block_list, host, feed.posts.into_iter()).await;
for introduce_host in introduce_hosts {
message_tx.send(Message::IntroduceHost(introduce_host)).unwrap();
}
// successfully fetched, save for future run
store.save_host(host).await.unwrap();
Ok((new_post_ratio, mean_interval))
}
/// Appends to `introduce_hosts` the instance of `post`'s author, followed by
/// the instance of every account it mentions. Accounts whose host cannot be
/// derived are skipped.
fn scan_for_hosts(introduce_hosts: &mut Vec<InstanceHost>, post: &Post) {
    // The author's own instance.
    let author_host: Result<InstanceHost, _> = (&post.account).try_into();
    introduce_hosts.extend(author_host.ok());

    // Instances of every mentioned account, in mention order.
    introduce_hosts.extend(post.mentions.iter().filter_map(|mention| {
        let mention_host: Result<InstanceHost, _> = mention.try_into();
        mention_host.ok()
    }));
}
async fn process_posts(
store: &mut Store,
posts_cache: &PostsCache,
block_list: BlockList,
host: &Host,
posts: impl Iterator<Item = EncodablePost>,
) -> (Option<f64>, Vec<InstanceHost>) {
// introduce new hosts, validate posts
let mut introduce_hosts = Vec::new();
let mut new_posts = 0;
let mut posts_len = 0;
for post in posts {
posts_len += 1;
// potentially save a round-trip to redis with an in-process cache
if ! posts_cache.insert(post.url.clone()) {
let t1 = Instant::now();
scan_for_hosts(&mut introduce_hosts, &post);
if let Some(reblog) = &post.reblog {
scan_for_hosts(&mut introduce_hosts, reblog);
}
if let Some(account_host) = post.account.host() {
if block_list.is_blocked(&account_host).await {
tracing::warn!("ignore post from blocked host {account_host}");
} else if store.save_post(post).await == Ok(true) {
// send away to redis
new_posts += 1;
}
} else {
tracing::warn!("drop repost ({:?} on {})", post.account.host(), host);
}
let t2 = Instant::now();
metrics::histogram!("hunter_post_process_seconds", t2 - t1);
}
}
tracing::trace!("{}: {}/{} new posts", host, new_posts, posts_len);
metrics::counter!("hunter_posts", new_posts, "type" => "new");
metrics::counter!("hunter_posts", posts_len, "type" => "total");
let new_post_ratio = if posts_len > 0 {
let ratio = (new_posts as f64) / (posts_len as f64);
metrics::histogram!("hunter_new_post_ratio", ratio);
Some(ratio)
} else {
None
};
// dedup introduce_hosts
let mut seen_hosts = HashSet::with_capacity(introduce_hosts.len());
let introduce_hosts = iter(
introduce_hosts.into_iter()
.filter_map(|introduce_host| {
if ! seen_hosts.contains(&introduce_host.host) {
seen_hosts.insert(introduce_host.host.clone());
Some(introduce_host)
} else {
None
}
})
)
.filter(|introduce_host| {
let block_list = block_list.clone();
let host = introduce_host.host.to_string();
async move {
! block_list.is_blocked(&host).await
}
})
.collect().await;
(new_post_ratio, introduce_hosts)
}
async fn open_stream(
message_tx: tokio::sync::mpsc::UnboundedSender<Message>,
store: Store,
db: Database,
posts_cache: &PostsCache,
block_list: BlockList,
client: &reqwest::Client,
robots_txt: RobotsTxt,
host: Host,
) -> Result<(&'static str, impl Future<Output = usize>), String> {
let url = format!("https://{host}/api/v1/streaming/public");
if ! robots_txt.allowed(&url) {
return Err(format!("Streaming of {host} forbidden by robots.txt"));
}
// free as early as possible
drop(robots_txt);
let posts_cache = posts_cache.clone();
metrics::increment_gauge!("hunter_requests", 1.0, "type" => "stream_open");
let mut stream = Feed::stream(client, &url).await;
metrics::decrement_gauge!("hunter_requests", 1.0, "type" => "stream_open");
let mut stats_key = "stream";
let mut prev_token: Option<String> = None;
let mut token_tries = 0;
while let Err(StreamError::HttpStatus(StatusCode::UNAUTHORIZED)) = &stream {
if let Some(invalid_token) = prev_token {
// If we tried with a token before but it's Unauthorized, delete it.
tracing::warn!("Deleting invalid token for host {}: {}", host, invalid_token);
let _ = db.delete_token(&host, &invalid_token).await;
}
if token_tries > 3 {
break;
}
let token = db.get_token(&host).await
.expect("db.get_token()");
if let Some(token) = &token {
let url = format!("https://{}/api/v1/streaming/public?access_token={}", host, urlencoding::encode(token));
metrics::increment_gauge!("hunter_requests", 1.0, "type" => "stream_open_token");
stream = Feed::stream(client, &url).await;
metrics::decrement_gauge!("hunter_requests", 1.0, "type" => "stream_open_token");
stats_key = "stream_token";
} else {
tracing::info!("No working token for {}", host);
break;
}
prev_token = token;
token_tries += 1;
}
if let Err(e) = &stream {
tracing::error!("Error opening stream to {}: {}", host, e);
}
let stream = stream.map_err(|e| {
format!("Stream error for {host}: {e}")
})?;
Ok((stats_key, stream.fold(0, move |post_count, post| {
let message_tx = message_tx.clone();
let mut store = store.clone();
let posts_cache = posts_cache.clone();
let block_list = block_list.clone();
let host = host.clone();
async move {
let (_, introduce_hosts) =
process_posts(&mut store, &posts_cache, block_list, &host, [post].into_iter()).await;
for introduce_host in introduce_hosts {
message_tx.send(Message::IntroduceHost(introduce_host)).unwrap();
}
post_count + 1
}
})))
}