summary refs log tree commit diff
path: root/src/system/aggregator.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/system/aggregator.rs')
-rw-r--r--src/system/aggregator.rs60
1 files changed, 40 insertions, 20 deletions
diff --git a/src/system/aggregator.rs b/src/system/aggregator.rs
index 6873272..00ba8e8 100644
--- a/src/system/aggregator.rs
+++ b/src/system/aggregator.rs
@@ -1,47 +1,67 @@
 use lru::LruCache;
+use std::sync::Arc;
+use tokio::sync::RwLock;
 use std::num::NonZeroUsize;
 use tokio::sync::mpsc::{channel, Receiver, Sender};
 use twilight_model::channel::Message as TwiMessage;
 
 use super::{Message as GatewayMessage, MessageEvent, MessageId, SystemEvent};
 
-pub struct MessageAggregator {
+pub struct AggregatorState {
     rx: Receiver<MessageEvent>,
     tx: Sender<MessageEvent>,
     message_cache: lru::LruCache<MessageId, TwiMessage>,
     system_emitter: Option<Sender<SystemEvent>>,
 }
 
+pub struct MessageAggregator {
+    state: Arc<RwLock<AggregatorState>>,
+}
+
 impl MessageAggregator {
     pub fn new() -> Self {
         let (tx, rx) = channel::<MessageEvent>(100);
 
         Self {
-            tx,
-            rx,
-            message_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
-            system_emitter: None,
+            state: Arc::new(RwLock::new( AggregatorState {
+                tx,
+                rx,
+                message_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
+                system_emitter: None,
+
+            }))
         }
     }
 
-    pub fn get_sender(&self) -> Sender<MessageEvent> {
-        self.tx.clone()
+    pub async fn get_sender(&self) -> Sender<MessageEvent> {
+        self.state.read().await.tx.clone()
     }
 
-    pub fn set_system_handler(&mut self, emitter: Sender<SystemEvent>) -> () {
-        self.system_emitter = Some(emitter);
+    pub async fn set_system_handler(&mut self, emitter: Sender<SystemEvent>) -> () {
+        self.state.write().await.system_emitter = Some(emitter);
     }
 
-    pub fn start(mut self) -> () {
+    pub async fn lookup_message(&self, message_id: MessageId) -> Option<TwiMessage> {
+        self.state.write().await.message_cache.get(&message_id).map(|m| m.clone())
+    }
+
+    pub fn start(&self) -> () {
+        let state = self.state.clone();
+
         tokio::spawn(async move {
             loop {
-                match self.rx.recv().await {
+                let system_emitter = { state.read().await.system_emitter.clone().expect("No system emitter") };
+                let self_emitter = { state.read().await.tx.clone() };
+                let next_event = { state.write().await.rx.recv().await };
+
+
+                match next_event {
                     None => (),
                     Some((timestamp, message)) => {
-                        let system_emitter = &self.system_emitter.clone().expect("No system emitter");
                         match message {
                             GatewayMessage::Partial(current_partial, member_id) => {
-                                match self.message_cache.get(&current_partial.id) {
+                                let cache_content = { state.write().await.message_cache.get(&current_partial.id).map(|m| m.clone()) };
+                                match cache_content {
                                     Some(original_message) => {
 
                                         let mut updated_message = original_message.clone();
@@ -49,11 +69,11 @@ impl MessageAggregator {
                                             updated_message.edited_timestamp = Some(edited_time);
                                         }
 
-                                        if let Some(content) = current_partial.content {
-                                            updated_message.content = content
+                                        if let Some(content) = &current_partial.content {
+                                            updated_message.content = content.clone()
                                         }
 
-                                        self.tx.send((timestamp, GatewayMessage::Complete(updated_message))).await;
+                                        self_emitter.send((timestamp, GatewayMessage::Complete(updated_message))).await;
                                     },
                                     None => {
                                         system_emitter.send(
@@ -63,9 +83,9 @@ impl MessageAggregator {
                                 };
                             },
                             GatewayMessage::Complete(message) => {
-                                let previous_message = self.message_cache.get(&message.id);
+                                let previous_message = { state.write().await.message_cache.get(&message.id).map(|m| m.clone()) };
 
-                                if let Some(previous_message) = previous_message.cloned() {
+                                if let Some(previous_message) = previous_message {
                                     let previous_timestamp = previous_message.edited_timestamp.unwrap_or(previous_message.timestamp);
                                     let current_timestamp = message.edited_timestamp.unwrap_or(message.timestamp);
 
@@ -77,9 +97,9 @@ impl MessageAggregator {
                                     // If not, fall through to update stored message
                                 }
 
-                                self.message_cache.put(message.id, message.clone());
+                                { state.write().await.message_cache.put(message.id, message.clone()); };
 
-                                self.system_emitter.as_ref().expect("Aggregator has no system emitter")
+                                system_emitter
                                     .send(SystemEvent::NewMessage(timestamp, message))
                                     .await;
                             },