diff options
Diffstat (limited to 'src/system/aggregator.rs')
-rw-r--r-- | src/system/aggregator.rs | 60 |
1 files changed, 40 insertions, 20 deletions
diff --git a/src/system/aggregator.rs b/src/system/aggregator.rs index 6873272..00ba8e8 100644 --- a/src/system/aggregator.rs +++ b/src/system/aggregator.rs @@ -1,47 +1,67 @@ use lru::LruCache; +use std::sync::Arc; +use tokio::sync::RwLock; use std::num::NonZeroUsize; use tokio::sync::mpsc::{channel, Receiver, Sender}; use twilight_model::channel::Message as TwiMessage; use super::{Message as GatewayMessage, MessageEvent, MessageId, SystemEvent}; -pub struct MessageAggregator { +pub struct AggregatorState { rx: Receiver<MessageEvent>, tx: Sender<MessageEvent>, message_cache: lru::LruCache<MessageId, TwiMessage>, system_emitter: Option<Sender<SystemEvent>>, } +pub struct MessageAggregator { + state: Arc<RwLock<AggregatorState>>, +} + impl MessageAggregator { pub fn new() -> Self { let (tx, rx) = channel::<MessageEvent>(100); Self { - tx, - rx, - message_cache: LruCache::new(NonZeroUsize::new(100).unwrap()), - system_emitter: None, + state: Arc::new(RwLock::new( AggregatorState { + tx, + rx, + message_cache: LruCache::new(NonZeroUsize::new(100).unwrap()), + system_emitter: None, + + })) } } - pub fn get_sender(&self) -> Sender<MessageEvent> { - self.tx.clone() + pub async fn get_sender(&self) -> Sender<MessageEvent> { + self.state.read().await.tx.clone() } - pub fn set_system_handler(&mut self, emitter: Sender<SystemEvent>) -> () { - self.system_emitter = Some(emitter); + pub async fn set_system_handler(&mut self, emitter: Sender<SystemEvent>) -> () { + self.state.write().await.system_emitter = Some(emitter); } - pub fn start(mut self) -> () { + pub async fn lookup_message(&self, message_id: MessageId) -> Option<TwiMessage> { + self.state.write().await.message_cache.get(&message_id).map(|m| m.clone()) + } + + pub fn start(&self) -> () { + let state = self.state.clone(); + tokio::spawn(async move { loop { - match self.rx.recv().await { + let system_emitter = { state.read().await.system_emitter.clone().expect("No system emitter") }; + let self_emitter = { state.read().await.tx.clone() }; + let next_event = { state.write().await.rx.recv().await }; + + + match next_event { None => (), Some((timestamp, message)) => { - let system_emitter = &self.system_emitter.clone().expect("No system emitter"); match message { GatewayMessage::Partial(current_partial, member_id) => { - match self.message_cache.get(¤t_partial.id) { + let cache_content = { state.write().await.message_cache.get(¤t_partial.id).map(|m| m.clone()) }; + match cache_content { Some(original_message) => { let mut updated_message = original_message.clone(); @@ -49,11 +69,11 @@ impl MessageAggregator { updated_message.edited_timestamp = Some(edited_time); } - if let Some(content) = current_partial.content { - updated_message.content = content + if let Some(content) = ¤t_partial.content { + updated_message.content = content.clone() } - self.tx.send((timestamp, GatewayMessage::Complete(updated_message))).await; + self_emitter.send((timestamp, GatewayMessage::Complete(updated_message))).await; }, None => { system_emitter.send( @@ -63,9 +83,9 @@ impl MessageAggregator { }; }, GatewayMessage::Complete(message) => { - let previous_message = self.message_cache.get(&message.id); + let previous_message = { state.write().await.message_cache.get(&message.id).map(|m| m.clone()) }; - if let Some(previous_message) = previous_message.cloned() { + if let Some(previous_message) = previous_message { let previous_timestamp = previous_message.edited_timestamp.unwrap_or(previous_message.timestamp); let current_timestamp = message.edited_timestamp.unwrap_or(message.timestamp); @@ -77,9 +97,9 @@ impl MessageAggregator { // If not, fall through to update stored message } - self.message_cache.put(message.id, message.clone()); + { state.write().await.message_cache.put(message.id, message.clone()); }; - self.system_emitter.as_ref().expect("Aggregator has no system emitter") + system_emitter .send(SystemEvent::NewMessage(timestamp, message)) .await; }, |