summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/system/aggregator.rs29
-rw-r--r--src/system/bot/client.rs16
-rw-r--r--src/system/bot/gateway.rs2
-rw-r--r--src/system/bot/mod.rs8
-rw-r--r--src/system/message_parser.rs6
-rw-r--r--src/system/mod.rs47
-rw-r--r--src/system/types.rs4
7 files changed, 74 insertions, 38 deletions
diff --git a/src/system/aggregator.rs b/src/system/aggregator.rs
index 00ba8e8..8fbdfdd 100644
--- a/src/system/aggregator.rs
+++ b/src/system/aggregator.rs
@@ -5,12 +5,12 @@ use std::num::NonZeroUsize;
 use tokio::sync::mpsc::{channel, Receiver, Sender};
 use twilight_model::channel::Message as TwiMessage;
 
-use super::{Message as GatewayMessage, MessageEvent, MessageId, SystemEvent};
+use super::{MemberId, Message as GatewayMessage, MessageEvent, MessageId, SystemEvent};
 
 pub struct AggregatorState {
     rx: Receiver<MessageEvent>,
     tx: Sender<MessageEvent>,
-    message_cache: lru::LruCache<MessageId, TwiMessage>,
+    message_cache: lru::LruCache<MessageId, (TwiMessage, MemberId)>,
     system_emitter: Option<Sender<SystemEvent>>,
 }
 
@@ -19,14 +19,14 @@ pub struct MessageAggregator {
 }
 
 impl MessageAggregator {
-    pub fn new() -> Self {
-        let (tx, rx) = channel::<MessageEvent>(100);
+    pub fn new(system_size: usize) -> Self {
+        let (tx, rx) = channel::<MessageEvent>(system_size * 2);
 
         Self {
             state: Arc::new(RwLock::new( AggregatorState {
                 tx,
                 rx,
-                message_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
+                message_cache: LruCache::new(NonZeroUsize::new(system_size * 2).unwrap()),
                 system_emitter: None,
 
             }))
@@ -41,9 +41,10 @@ impl MessageAggregator {
         self.state.write().await.system_emitter = Some(emitter);
     }
 
-    pub async fn lookup_message(&self, message_id: MessageId) -> Option<TwiMessage> {
-        self.state.write().await.message_cache.get(&message_id).map(|m| m.clone())
-    }
+    // We probably don't actully need this since we've got a separate sent-cache by channel
+    // pub async fn lookup_message(&self, message_id: MessageId) -> Option<TwiMessage> {
+    //     self.state.write().await.message_cache.get(&message_id).map(|m| m.clone())
+    // }
 
     pub fn start(&self) -> () {
         let state = self.state.clone();
@@ -62,7 +63,7 @@ impl MessageAggregator {
                             GatewayMessage::Partial(current_partial, member_id) => {
                                 let cache_content = { state.write().await.message_cache.get(&current_partial.id).map(|m| m.clone()) };
                                 match cache_content {
-                                    Some(original_message) => {
+                                    Some((original_message, member_id)) => {
 
                                         let mut updated_message = original_message.clone();
                                         if let Some(edited_time) = current_partial.edited_timestamp {
@@ -73,7 +74,7 @@ impl MessageAggregator {
                                             updated_message.content = content.clone()
                                         }
 
-                                        self_emitter.send((timestamp, GatewayMessage::Complete(updated_message))).await;
+                                        self_emitter.send((timestamp, GatewayMessage::Complete(updated_message, member_id))).await;
                                     },
                                     None => {
                                         system_emitter.send(
@@ -82,10 +83,10 @@ impl MessageAggregator {
                                     },
                                 };
                             },
-                            GatewayMessage::Complete(message) => {
+                            GatewayMessage::Complete(message, member_id) => {
                                 let previous_message = { state.write().await.message_cache.get(&message.id).map(|m| m.clone()) };
 
-                                if let Some(previous_message) = previous_message {
+                                if let Some((previous_message, _last_seen_by)) = previous_message {
                                     let previous_timestamp = previous_message.edited_timestamp.unwrap_or(previous_message.timestamp);
                                     let current_timestamp = message.edited_timestamp.unwrap_or(message.timestamp);
 
@@ -97,10 +98,10 @@ impl MessageAggregator {
                                     // If not, fall through to update stored message
                                 }
 
-                                { state.write().await.message_cache.put(message.id, message.clone()); };
+                                { state.write().await.message_cache.put(message.id, (message.clone(), member_id)); };
 
                                 system_emitter
-                                    .send(SystemEvent::NewMessage(timestamp, message))
+                                    .send(SystemEvent::NewMessage(timestamp, message, member_id))
                                     .await;
                             },
                         };
diff --git a/src/system/bot/client.rs b/src/system/bot/client.rs
index 006ce8f..61d7515 100644
--- a/src/system/bot/client.rs
+++ b/src/system/bot/client.rs
@@ -23,18 +23,22 @@ impl Client {
         }
     }
 
-    pub async fn refetch_message(&self, message_id: MessageId, channel_id: ChannelId) {
+    pub async fn fetch_message(&self, message_id: MessageId, channel_id: ChannelId) -> FullMessage {
         let client = self.client.lock().await;
-        let bot_conf = self.bot_conf.read().await;
-        let message_channel = bot_conf.message_handler.as_ref().expect("No message handler");
 
-        let message = client
+        client
             .message(channel_id, message_id)
             .await
             .expect("Could not load message")
             .model()
             .await
-            .expect("Could not deserialize message");
+            .expect("Could not deserialize message")
+    }
+
+    pub async fn resend_message(&self, message_id: MessageId, channel_id: ChannelId) {
+        let bot_conf = self.bot_conf.read().await;
+        let message = self.fetch_message(message_id, channel_id).await;
+        let message_channel = bot_conf.message_handler.as_ref().expect("No message handler");
 
         let timestamp = if message.edited_timestamp.is_some() {
             message.edited_timestamp.unwrap()
@@ -43,7 +47,7 @@ impl Client {
         };
 
         message_channel
-            .send((timestamp, Message::Complete(message)))
+            .send((timestamp, Message::Complete(message, bot_conf.member_id)))
             .await;
     }
 
diff --git a/src/system/bot/gateway.rs b/src/system/bot/gateway.rs
index 5a45083..5a343d1 100644
--- a/src/system/bot/gateway.rs
+++ b/src/system/bot/gateway.rs
@@ -94,7 +94,7 @@ impl Gateway {
                             }
 
                             message_channel
-                                .send((message.timestamp, Message::Complete(message)))
+                                .send((message.timestamp, Message::Complete(message, bot_conf.member_id)))
                                 .await;
                         }
 
diff --git a/src/system/bot/mod.rs b/src/system/bot/mod.rs
index 3c4585f..2f38075 100644
--- a/src/system/bot/mod.rs
+++ b/src/system/bot/mod.rs
@@ -66,8 +66,12 @@ impl Bot {
         self.gateway.start_listening()
     }
 
-    pub async fn refetch_message(&self, message_id: MessageId, channel_id: ChannelId) {
-        self.client.refetch_message(message_id, channel_id).await;
+    pub async fn fetch_message(&self, message_id: MessageId, channel_id: ChannelId) -> TwiMessage {
+        self.client.fetch_message(message_id, channel_id).await
+    }
+
+    pub async fn resend_message(&self, message_id: MessageId, channel_id: ChannelId) {
+        self.client.resend_message(message_id, channel_id).await;
     }
 
     pub async fn delete_message(&self, channel_id: ChannelId, message_id: MessageId) -> Result<(), TwiError> {
diff --git a/src/system/message_parser.rs b/src/system/message_parser.rs
index b404064..01bfaee 100644
--- a/src/system/message_parser.rs
+++ b/src/system/message_parser.rs
@@ -36,7 +36,7 @@ static CORRECTION_REGEX: LazyLock<Regex> = LazyLock::new(|| {
 });
 
 impl MessageParser {
-    pub fn parse(message: &FullMessage, secondary_message: Option<&FullMessage>, system_config: &System, latch_state: Option<(MemberId, Timestamp)>) -> ParsedMessage {
+    pub fn parse(message: &FullMessage, secondary_message: Option<FullMessage>, system_config: &System, latch_state: Option<(MemberId, Timestamp)>) -> ParsedMessage {
         if message.content == r"\\" {
             return ParsedMessage::LatchClear(if let Some((member_id, _)) = latch_state {
                 member_id
@@ -73,13 +73,13 @@ impl MessageParser {
         ParsedMessage::UnproxiedMessage
     }
 
-    fn parse_command(message: &FullMessage, secondary_message: Option<&FullMessage>, system_config: &System, latch_state: Option<(MemberId, Timestamp)>) -> Command {
+    fn parse_command(message: &FullMessage, secondary_message: Option<FullMessage>, system_config: &System, latch_state: Option<(MemberId, Timestamp)>) -> Command {
         
         // If unable to parse
         Command::UnknownCommand
     }
 
-    fn check_correction(message: &FullMessage, secondary_message: Option<&FullMessage>) -> Option<ParsedMessage> {
+    fn check_correction(message: &FullMessage, secondary_message: Option<FullMessage>) -> Option<ParsedMessage> {
         None
     }
 
diff --git a/src/system/mod.rs b/src/system/mod.rs
index 3c6064b..bd27f4a 100644
--- a/src/system/mod.rs
+++ b/src/system/mod.rs
@@ -1,11 +1,12 @@
-use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
+use std::{collections::HashMap, num::NonZeroUsize, str::FromStr, time::Duration};
 
+use lru::LruCache;
 use tokio::{
-    sync::{mpsc::{channel, Sender}, RwLock},
+    sync::mpsc::{channel, Sender},
     time::sleep,
 };
 use twilight_http::request::channel::reaction::RequestReactionType;
-use twilight_model::{channel::message::ReactionType, id::{marker::UserMarker, Id}};
+use twilight_model::{channel::message::{MessageReference, MessageType, ReactionType}, id::{marker::UserMarker, Id}};
 use twilight_model::util::Timestamp;
 
 use crate::config::{AutoproxyConfig, AutoproxyLatchScope, Member};
@@ -30,6 +31,7 @@ pub struct Manager {
     pub latch_state: Option<(MemberId, Timestamp)>,
     pub system_sender: Option<Sender<SystemEvent>>,
     pub aggregator: MessageAggregator,
+    pub send_cache: LruCache<ChannelId, TwiMessage>,
     pub reference_user_id: UserId,
 }
 
@@ -38,12 +40,13 @@ impl Manager {
         Self {
             reference_user_id: Id::from_str(&system_config.reference_user_id.as_str())
                 .expect(format!("Invalid user id for system {}", &system_name).as_str()),
+            aggregator: MessageAggregator::new(system_config.members.len()),
             name: system_name,
             config: system_config,
             bots: HashMap::new(),
             latch_state: None,
             system_sender: None,
-            aggregator: MessageAggregator::new(),
+            send_cache: LruCache::new(NonZeroUsize::new(15).unwrap()),
         }
     }
 
@@ -101,13 +104,13 @@ impl Manager {
                     self.start_bot(member_id).await;
                 }
 
-                Some(SystemEvent::NewMessage(event_time, message)) => {
-                    self.handle_message(message, event_time).await;
+                Some(SystemEvent::NewMessage(event_time, message, member_id)) => {
+                    self.handle_message(message, event_time, member_id).await;
                 }
 
                 Some(SystemEvent::RefetchMessage(member_id, message_id, channel_id)) => {
                     let bot = self.bots.get(&member_id).unwrap();
-                    bot.refetch_message(message_id, channel_id).await;
+                    bot.resend_message(message_id, channel_id).await;
                 }
 
                 Some(SystemEvent::AutoproxyTimeout(time_scheduled)) => {
@@ -160,8 +163,28 @@ impl Manager {
         });
     }
 
-    async fn handle_message(&mut self, message: TwiMessage, timestamp: Timestamp) {
-        let parsed_message = MessageParser::parse(&message, None, &self.config, self.latch_state);
+    async fn handle_message(&mut self, message: TwiMessage, timestamp: Timestamp, seen_by: MemberId) {
+        // let bot = self.bots.get(&seen_by).expect("No client for member");
+        let last_in_channel = self.send_cache.get(&message.channel_id);
+        let replied_message = if let MessageType::Reply = message.kind {
+            message.referenced_message.clone()
+        } else {
+            None
+        };
+
+        if let None = last_in_channel {
+            println!("ERROR: Could not look up last sent message in channel {}", message.channel_id);
+        }
+
+        let ref_message = if replied_message.is_some() {
+            replied_message.map(|m| *m)
+        } else if last_in_channel.is_some() {
+            last_in_channel.map(|m| m.clone())
+        } else {
+            None
+        };
+
+        let parsed_message = MessageParser::parse(&message, ref_message, &self.config, self.latch_state);
 
         match parsed_message {
             message_parser::ParsedMessage::UnproxiedMessage => (),
@@ -195,7 +218,7 @@ impl Manager {
         }
     }
 
-    async fn proxy_message(&self, message: &TwiMessage, member: MemberId, content: &str) -> Result<(), ()> {
+    async fn proxy_message(&mut self, message: &TwiMessage, member: MemberId, content: &str) -> Result<(), ()> {
         let bot = self.bots.get(&member).expect("No client for member");
 
         let duplicate_result = bot.duplicate_message(message, content).await;
@@ -216,6 +239,10 @@ impl Manager {
             return Err(())
         }
 
+        // Sent successfully, add to send cache
+        let sent_message = duplicate_result.unwrap();
+        self.send_cache.put(sent_message.channel_id, sent_message);
+
         Ok(())
     }
 
diff --git a/src/system/types.rs b/src/system/types.rs
index 483fbbd..89eb189 100644
--- a/src/system/types.rs
+++ b/src/system/types.rs
@@ -14,7 +14,7 @@ pub type Status = twilight_model::gateway::presence::Status;
 
 #[derive(Clone)]
 pub enum Message {
-    Complete(FullMessage),
+    Complete(FullMessage, MemberId),
     Partial(PartialMessage, MemberId),
 }
 
@@ -31,7 +31,7 @@ pub enum SystemEvent {
     UpdateClientStatus(MemberId),
 
     // User event handling
-    NewMessage(Timestamp, FullMessage),
+    NewMessage(Timestamp, FullMessage, MemberId),
     EditedMessage(MessageEvent),
     NewReaction(ReactionEvent),