65 lines
1.8 KiB
Rust
65 lines
1.8 KiB
Rust
use std::collections::HashMap;
|
|
use std::hash::Hash;
|
|
use std::sync::{Arc, Mutex};
|
|
use tokio::sync::oneshot;
|
|
|
|
#[axum::async_trait]
|
|
pub trait Muxable {
|
|
type Key: Clone + Eq + Hash;
|
|
type Result: Clone;
|
|
|
|
async fn request(&self, key: Self::Key) -> Self::Result;
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct RequestMux<T: Muxable + Clone> {
|
|
muxable: Arc<T>,
|
|
pending: Arc<Mutex<HashMap<T::Key, Vec<oneshot::Sender<T::Result>>>>>,
|
|
}
|
|
|
|
// unsafe impl<T: Muxable + Clone> Send for RequestMux<T> {}
|
|
// unsafe impl<T: Muxable + Clone> Sync for RequestMux<T> {}
|
|
|
|
impl<T: Muxable + Clone> RequestMux<T> {
|
|
pub fn new(muxable: T) -> Self {
|
|
RequestMux {
|
|
muxable: Arc::new(muxable),
|
|
pending: Arc::new(Mutex::new(HashMap::new())),
|
|
}
|
|
}
|
|
|
|
pub async fn request(&self, key: T::Key) -> T::Result {
|
|
let rx = if let Some(txs) = self.pending.lock().unwrap().get_mut(&key) {
|
|
let (tx, rx) = oneshot::channel();
|
|
// get in line
|
|
txs.push(tx);
|
|
// drop the reference
|
|
drop(txs);
|
|
Some(rx)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
if let Some(rx) = rx {
|
|
rx.await.unwrap()
|
|
} else {
|
|
// create queue
|
|
self.pending.lock().unwrap().insert(key.clone(), vec![]);
|
|
// run async request
|
|
let result = self.muxable.request(key.clone()).await;
|
|
// remove and obtain queue
|
|
let txs = self.pending.lock().unwrap().remove(&key)
|
|
.unwrap_or_else(|| {
|
|
log::warn!("RequestMux txs has vanished");
|
|
vec![]
|
|
});
|
|
// notify secondary requesters
|
|
for tx in txs.into_iter() {
|
|
let _ = tx.send(result.clone());
|
|
}
|
|
// return result
|
|
result
|
|
}
|
|
}
|
|
}
|