diff options
author | Ashelyn Rose <git@ashen.earth> | 2024-10-05 18:02:12 -0600 |
---|---|---|
committer | Ashelyn Rose <git@ashen.earth> | 2024-10-05 18:02:12 -0600 |
commit | ca4b0e32be531053f19ce2895b994130b247af4a (patch) | |
tree | bdfef248c6917109b6a768bed9733b0999ae0e2f /src/system | |
parent | 4fa4907c3da23249ddec2bcb50e48f708152059e (diff) |
Add sent-message cache for looking up command targets
Diffstat (limited to 'src/system')
-rw-r--r-- | src/system/aggregator.rs | 29 | ||||
-rw-r--r-- | src/system/bot/client.rs | 16 | ||||
-rw-r--r-- | src/system/bot/gateway.rs | 2 | ||||
-rw-r--r-- | src/system/bot/mod.rs | 8 | ||||
-rw-r--r-- | src/system/message_parser.rs | 6 | ||||
-rw-r--r-- | src/system/mod.rs | 47 | ||||
-rw-r--r-- | src/system/types.rs | 4 |
7 files changed, 74 insertions, 38 deletions
diff --git a/src/system/aggregator.rs b/src/system/aggregator.rs index 00ba8e8..8fbdfdd 100644 --- a/src/system/aggregator.rs +++ b/src/system/aggregator.rs @@ -5,12 +5,12 @@ use std::num::NonZeroUsize; use tokio::sync::mpsc::{channel, Receiver, Sender}; use twilight_model::channel::Message as TwiMessage; -use super::{Message as GatewayMessage, MessageEvent, MessageId, SystemEvent}; +use super::{MemberId, Message as GatewayMessage, MessageEvent, MessageId, SystemEvent}; pub struct AggregatorState { rx: Receiver<MessageEvent>, tx: Sender<MessageEvent>, - message_cache: lru::LruCache<MessageId, TwiMessage>, + message_cache: lru::LruCache<MessageId, (TwiMessage, MemberId)>, system_emitter: Option<Sender<SystemEvent>>, } @@ -19,14 +19,14 @@ pub struct MessageAggregator { } impl MessageAggregator { - pub fn new() -> Self { - let (tx, rx) = channel::<MessageEvent>(100); + pub fn new(system_size: usize) -> Self { + let (tx, rx) = channel::<MessageEvent>(system_size * 2); Self { state: Arc::new(RwLock::new( AggregatorState { tx, rx, - message_cache: LruCache::new(NonZeroUsize::new(100).unwrap()), + message_cache: LruCache::new(NonZeroUsize::new(system_size * 2).unwrap()), system_emitter: None, })) @@ -41,9 +41,10 @@ impl MessageAggregator { self.state.write().await.system_emitter = Some(emitter); } - pub async fn lookup_message(&self, message_id: MessageId) -> Option<TwiMessage> { - self.state.write().await.message_cache.get(&message_id).map(|m| m.clone()) - } + // We probably don't actully need this since we've got a separate sent-cache by channel + // pub async fn lookup_message(&self, message_id: MessageId) -> Option<TwiMessage> { + // self.state.write().await.message_cache.get(&message_id).map(|m| m.clone()) + // } pub fn start(&self) -> () { let state = self.state.clone(); @@ -62,7 +63,7 @@ impl MessageAggregator { GatewayMessage::Partial(current_partial, member_id) => { let cache_content = { state.write().await.message_cache.get(¤t_partial.id).map(|m| m.clone()) }; match cache_content { - Some(original_message) => { + Some((original_message, member_id)) => { let mut updated_message = original_message.clone(); if let Some(edited_time) = current_partial.edited_timestamp { @@ -73,7 +74,7 @@ impl MessageAggregator { updated_message.content = content.clone() } - self_emitter.send((timestamp, GatewayMessage::Complete(updated_message))).await; + self_emitter.send((timestamp, GatewayMessage::Complete(updated_message, member_id))).await; }, None => { system_emitter.send( @@ -82,10 +83,10 @@ impl MessageAggregator { }, }; }, - GatewayMessage::Complete(message) => { + GatewayMessage::Complete(message, member_id) => { let previous_message = { state.write().await.message_cache.get(&message.id).map(|m| m.clone()) }; - if let Some(previous_message) = previous_message { + if let Some((previous_message, _last_seen_by)) = previous_message { let previous_timestamp = previous_message.edited_timestamp.unwrap_or(previous_message.timestamp); let current_timestamp = message.edited_timestamp.unwrap_or(message.timestamp); @@ -97,10 +98,10 @@ impl MessageAggregator { // If not, fall through to update stored message } - { state.write().await.message_cache.put(message.id, message.clone()); }; + { state.write().await.message_cache.put(message.id, (message.clone(), member_id)); }; system_emitter - .send(SystemEvent::NewMessage(timestamp, message)) + .send(SystemEvent::NewMessage(timestamp, message, member_id)) .await; }, }; diff --git a/src/system/bot/client.rs b/src/system/bot/client.rs index 006ce8f..61d7515 100644 --- a/src/system/bot/client.rs +++ b/src/system/bot/client.rs @@ -23,18 +23,22 @@ impl Client { } } - pub async fn refetch_message(&self, message_id: MessageId, channel_id: ChannelId) { + pub async fn fetch_message(&self, message_id: MessageId, channel_id: ChannelId) -> FullMessage { let client = self.client.lock().await; - let bot_conf = self.bot_conf.read().await; - let message_channel = bot_conf.message_handler.as_ref().expect("No message handler"); - let message = client + client .message(channel_id, message_id) .await .expect("Could not load message") .model() .await - .expect("Could not deserialize message"); + .expect("Could not deserialize message") + } + + pub async fn resend_message(&self, message_id: MessageId, channel_id: ChannelId) { + let bot_conf = self.bot_conf.read().await; + let message = self.fetch_message(message_id, channel_id).await; + let message_channel = bot_conf.message_handler.as_ref().expect("No message handler"); let timestamp = if message.edited_timestamp.is_some() { message.edited_timestamp.unwrap() @@ -43,7 +47,7 @@ impl Client { }; message_channel - .send((timestamp, Message::Complete(message))) + .send((timestamp, Message::Complete(message, bot_conf.member_id))) .await; } diff --git a/src/system/bot/gateway.rs b/src/system/bot/gateway.rs index 5a45083..5a343d1 100644 --- a/src/system/bot/gateway.rs +++ b/src/system/bot/gateway.rs @@ -94,7 +94,7 @@ impl Gateway { } message_channel - .send((message.timestamp, Message::Complete(message))) + .send((message.timestamp, Message::Complete(message, bot_conf.member_id))) .await; } diff --git a/src/system/bot/mod.rs b/src/system/bot/mod.rs index 3c4585f..2f38075 100644 --- a/src/system/bot/mod.rs +++ b/src/system/bot/mod.rs @@ -66,8 +66,12 @@ impl Bot { self.gateway.start_listening() } - pub async fn refetch_message(&self, message_id: MessageId, channel_id: ChannelId) { - self.client.refetch_message(message_id, channel_id).await; + pub async fn fetch_message(&self, message_id: MessageId, channel_id: ChannelId) -> TwiMessage { + self.client.fetch_message(message_id, channel_id).await + } + + pub async fn resend_message(&self, message_id: MessageId, channel_id: ChannelId) { + self.client.resend_message(message_id, channel_id).await; } pub async fn delete_message(&self, channel_id: ChannelId, message_id: MessageId) -> Result<(), TwiError> { diff --git a/src/system/message_parser.rs b/src/system/message_parser.rs index b404064..01bfaee 100644 --- a/src/system/message_parser.rs +++ b/src/system/message_parser.rs @@ -36,7 +36,7 @@ static CORRECTION_REGEX: LazyLock<Regex> = LazyLock::new(|| { }); impl MessageParser { - pub fn parse(message: &FullMessage, secondary_message: Option<&FullMessage>, system_config: &System, latch_state: Option<(MemberId, Timestamp)>) -> ParsedMessage { + pub fn parse(message: &FullMessage, secondary_message: Option<FullMessage>, system_config: &System, latch_state: Option<(MemberId, Timestamp)>) -> ParsedMessage { if message.content == r"\\" { return ParsedMessage::LatchClear(if let Some((member_id, _)) = latch_state { member_id @@ -73,13 +73,13 @@ impl MessageParser { ParsedMessage::UnproxiedMessage } - fn parse_command(message: &FullMessage, secondary_message: Option<&FullMessage>, system_config: &System, latch_state: Option<(MemberId, Timestamp)>) -> Command { + fn parse_command(message: &FullMessage, secondary_message: Option<FullMessage>, system_config: &System, latch_state: Option<(MemberId, Timestamp)>) -> Command { // If unable to parse Command::UnknownCommand } - fn check_correction(message: &FullMessage, secondary_message: Option<&FullMessage>) -> Option<ParsedMessage> { + fn check_correction(message: &FullMessage, secondary_message: Option<FullMessage>) -> Option<ParsedMessage> { None } diff --git a/src/system/mod.rs b/src/system/mod.rs index 3c6064b..bd27f4a 100644 --- a/src/system/mod.rs +++ b/src/system/mod.rs @@ -1,11 +1,12 @@ -use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration}; +use std::{collections::HashMap, num::NonZeroUsize, str::FromStr, time::Duration}; +use lru::LruCache; use tokio::{ - sync::{mpsc::{channel, Sender}, RwLock}, + sync::mpsc::{channel, Sender}, time::sleep, }; use twilight_http::request::channel::reaction::RequestReactionType; -use twilight_model::{channel::message::ReactionType, id::{marker::UserMarker, Id}}; +use twilight_model::{channel::message::{MessageReference, MessageType, ReactionType}, id::{marker::UserMarker, Id}}; use twilight_model::util::Timestamp; use crate::config::{AutoproxyConfig, AutoproxyLatchScope, Member}; @@ -30,6 +31,7 @@ pub struct Manager { pub latch_state: Option<(MemberId, Timestamp)>, pub system_sender: Option<Sender<SystemEvent>>, pub aggregator: MessageAggregator, + pub send_cache: LruCache<ChannelId, TwiMessage>, pub reference_user_id: UserId, } @@ -38,12 +40,13 @@ impl Manager { Self { reference_user_id: Id::from_str(&system_config.reference_user_id.as_str()) .expect(format!("Invalid user id for system {}", &system_name).as_str()), + aggregator: MessageAggregator::new(system_config.members.len()), name: system_name, config: system_config, bots: HashMap::new(), latch_state: None, system_sender: None, - aggregator: MessageAggregator::new(), + send_cache: LruCache::new(NonZeroUsize::new(15).unwrap()), } } @@ -101,13 +104,13 @@ impl Manager { self.start_bot(member_id).await; } - Some(SystemEvent::NewMessage(event_time, message)) => { - self.handle_message(message, event_time).await; + Some(SystemEvent::NewMessage(event_time, message, member_id)) => { + self.handle_message(message, event_time, member_id).await; } Some(SystemEvent::RefetchMessage(member_id, message_id, channel_id)) => { let bot = self.bots.get(&member_id).unwrap(); - bot.refetch_message(message_id, channel_id).await; + bot.resend_message(message_id, channel_id).await; } Some(SystemEvent::AutoproxyTimeout(time_scheduled)) => { @@ -160,8 +163,28 @@ impl Manager { }); } - async fn handle_message(&mut self, message: TwiMessage, timestamp: Timestamp) { - let parsed_message = MessageParser::parse(&message, None, &self.config, self.latch_state); + async fn handle_message(&mut self, message: TwiMessage, timestamp: Timestamp, seen_by: MemberId) { + // let bot = self.bots.get(&seen_by).expect("No client for member"); + let last_in_channel = self.send_cache.get(&message.channel_id); + let replied_message = if let MessageType::Reply = message.kind { + message.referenced_message.clone() + } else { + None + }; + + if let None = last_in_channel { + println!("ERROR: Could not look up last sent message in channel {}", message.channel_id); + } + + let ref_message = if replied_message.is_some() { + replied_message.map(|m| *m) + } else if last_in_channel.is_some() { + last_in_channel.map(|m| m.clone()) + } else { + None + }; + + let parsed_message = MessageParser::parse(&message, ref_message, &self.config, self.latch_state); match parsed_message { message_parser::ParsedMessage::UnproxiedMessage => (), @@ -195,7 +218,7 @@ impl Manager { } } - async fn proxy_message(&self, message: &TwiMessage, member: MemberId, content: &str) -> Result<(), ()> { + async fn proxy_message(&mut self, message: &TwiMessage, member: MemberId, content: &str) -> Result<(), ()> { let bot = self.bots.get(&member).expect("No client for member"); let duplicate_result = bot.duplicate_message(message, content).await; @@ -216,6 +239,10 @@ impl Manager { return Err(()) } + // Sent successfully, add to send cache + let sent_message = duplicate_result.unwrap(); + self.send_cache.put(sent_message.channel_id, sent_message); + Ok(()) } diff --git a/src/system/types.rs b/src/system/types.rs index 483fbbd..89eb189 100644 --- a/src/system/types.rs +++ b/src/system/types.rs @@ -14,7 +14,7 @@ pub type Status = twilight_model::gateway::presence::Status; #[derive(Clone)] pub enum Message { - Complete(FullMessage), + Complete(FullMessage, MemberId), Partial(PartialMessage, MemberId), } @@ -31,7 +31,7 @@ pub enum SystemEvent { UpdateClientStatus(MemberId), // User event handling - NewMessage(Timestamp, FullMessage), + NewMessage(Timestamp, FullMessage, MemberId), EditedMessage(MessageEvent), NewReaction(ReactionEvent), |