diff options
author | Ashelyn Rose <git@ashen.earth> | 2025-03-01 18:29:08 -0700 |
---|---|---|
committer | Ashelyn Rose <git@ashen.earth> | 2025-03-01 18:29:08 -0700 |
commit | df8e78aded7ce2c8653e81edeaaa026e7c44c713 (patch) | |
tree | 1d2cb57d8703327710ba0dd669321340bbb5adb4 | |
parent | 38be251b5f4ed4dabe21753451e23c233b62b6dd (diff) |
Edit command
-rw-r--r-- | src/system/mod.rs | 30 | ||||
-rw-r--r-- | src/system/plugin.rs | 23 | ||||
-rw-r--r-- | src/system/plugin/autoproxy.rs | 10 | ||||
-rw-r--r-- | src/system/plugin/edit.rs | 123 | ||||
-rw-r--r-- | src/system/plugin/prefixes.rs | 12 | ||||
-rw-r--r-- | src/system/types.rs | 36 |
6 files changed, 200 insertions, 34 deletions
diff --git a/src/system/mod.rs b/src/system/mod.rs index 4eea408..cb3f040 100644 --- a/src/system/mod.rs +++ b/src/system/mod.rs @@ -10,12 +10,13 @@ use twilight_gateway::{Intents, Shard, ShardId}; use twilight_http::Client; pub use types::SystemThreadCommand; use crate::SystemUiEvent; -use std::{num::NonZeroU64, sync::Arc}; +use std::{collections::HashMap, num::NonZeroU64, sync::Arc}; use tokio::sync::Mutex; use plugin::get_plugins; use std::sync::mpsc::Sender as ThreadSender; use types::{Member, Response, System}; +use std::iter::once; pub struct Manager; impl Manager { @@ -36,6 +37,7 @@ impl Manager { ))), client: Arc::new(Mutex::new(Client::new(member.discord_token.clone()))) }).collect(), + message_cache: Arc::new(Mutex::new(HashMap::new())), }; let mut message_receiver = aggregator::MessageAggregator::start(&system); @@ -63,16 +65,22 @@ impl Manager { Some(MemberEvent::Message(message, seen_by)) => { if let Some(command_string) = message.content.strip_prefix(&system.command_prefix) { let mut words = command_string.split_whitespace(); - if let Some(command) = words.next() { - if let Some(plugin) = by_command.get(command) { - logger.log_line(None, format!("Handling command: {command}")).await; - let args : Vec<_> = words.collect(); + if let Some(first_word) = words.next() { + if let Some((command, plugin)) = by_command + .get(first_word) + .map(|command| Some(command)) + .unwrap_or_else(|| by_command.get(first_word.get(0..1).unwrap())) { + logger.log_line(None, format!("Handling command: {command:?}")).await; + let args : Vec<_> = match command { + plugin::PluginCommand::Long(_) => words.collect(), + plugin::PluginCommand::Short(_) => once(first_word).chain(words).collect(), + }; - plugin.handle_command(&logger, &system, &message, args).await; - continue 'member_event; - } else { - logger.log_line(None, format!("Unknown command: {command}")).await; - } + plugin.handle_command(&logger, &system, &message, *command, args).await; + continue 'member_event; + } else { + logger.log_line(None, format!("Unknown command: {first_word}")).await; + } } } @@ -97,6 +105,8 @@ impl Manager { if let Err(err) = {member.client.lock().await.delete_message(message.channel_id, message.id).await.map(|_| ()).map_err(|err| err.to_string()).clone() } { logger.log_err(Some(member.discord_token), format!("Could not proxy message: {err}")).await; {let _ = member.client.lock().await.delete_message(new_message.channel_id, new_message.id).await;} + } else { + system.cache_most_recent_message(new_message.channel_id, new_message.clone(), member.clone()).await; } for plugin in &all_plugins { diff --git a/src/system/plugin.rs b/src/system/plugin.rs index 4419955..379606d 100644 --- a/src/system/plugin.rs +++ b/src/system/plugin.rs @@ -1,5 +1,5 @@ mod autoproxy; -// mod edit; +mod edit; // mod ghost; mod prefixes; // mod reproxy; @@ -8,6 +8,7 @@ use std::{collections::HashMap, sync::Arc}; use async_trait::async_trait; +use edit::Edit; use twilight_model::{channel::{Channel, Message}, id::{marker::ChannelMarker, Id}}; use super::log::Logger; @@ -16,28 +17,38 @@ use crate::system::types::{System, Response}; pub use prefixes::ProxyPrefixes; pub use autoproxy::Autoproxy; +#[derive(Copy, Clone, Debug)] +pub enum PluginCommand { + Long(&'static str), + Short(&'static str), +} + #[async_trait] pub trait SeancePlugin<'system> { - fn get_commands(&self) -> Vec<&'static str>; + fn get_commands(&self) -> Vec<PluginCommand>; - async fn handle_command<'message>(&self, logger: &'system Logger, system: &'system System, message: &'message Message, args: Vec<&'message str>); + async fn handle_command<'message>(&self, logger: &'system Logger, system: &'system System, message: &'message Message, command: PluginCommand, args: Vec<&'message str>); async fn handle_message<'message>(&self, logger: &'system Logger, system: &'system System, message: &'message Message, response: &'message mut Response); async fn post_response<'message>(&self, logger: &'system Logger, system: &'system System, message: &'message Message, channel: Id<ChannelMarker>, response: &'message Response); } -pub fn get_plugins<'system>() -> (Vec<Arc<Box<dyn SeancePlugin<'system>>>>, HashMap<&'static str, Arc<Box<dyn SeancePlugin<'system>>>>) { +pub fn get_plugins<'system>() -> (Vec<Arc<Box<dyn SeancePlugin<'system>>>>, HashMap<&'static str, (PluginCommand, Arc<Box<dyn SeancePlugin<'system>>>)>) { let all_plugins : Vec<Arc<Box<dyn SeancePlugin<'system>>>> = vec![ - Arc::new(Box::new(ProxyPrefixes)), + Arc::new(Box::new(ProxyPrefixes::new())), Arc::new(Box::new(Autoproxy::new())), + Arc::new(Box::new(Edit::new())), ]; let by_commands = all_plugins.iter() .map(|plugin| { let commands = plugin.get_commands(); commands.into_iter().map(|command| { - (command, plugin.clone()) + match command { + PluginCommand::Long(command_word) => (command_word, (command, plugin.clone())), + PluginCommand::Short(command_char) => (command_char, (command, plugin.clone())), + } }) }) .flatten() diff --git a/src/system/plugin/autoproxy.rs b/src/system/plugin/autoproxy.rs index 58f9668..24511c4 100644 --- a/src/system/plugin/autoproxy.rs +++ b/src/system/plugin/autoproxy.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use std::sync::Arc; use tokio::sync::Mutex; use twilight_model::{channel::{Channel, Message}, id::{marker::ChannelMarker, Id}, util::Timestamp}; -use crate::system::types::{System, Member, Response}; +use crate::system::{plugin::PluginCommand, types::{Member, Response, System}}; use super::SeancePlugin; use tokio::time::sleep; use std::time::Duration; @@ -23,18 +23,18 @@ enum InnerState { impl Autoproxy { pub fn new() -> Self { Self { - current_state: Arc::new(Mutex::new(InnerState::Off)) + current_state: Arc::new(Mutex::new(InnerState::LatchInactive)) } } } #[async_trait] impl<'system> SeancePlugin<'system> for Autoproxy { - fn get_commands(&self) -> Vec<&'static str> { - vec!["auto"] + fn get_commands(&self) -> Vec<PluginCommand> { + vec![PluginCommand::Long("auto")] } - async fn handle_command<'message>(&self, logger: &'system Logger, system: &'system System, message: &'message Message, args: Vec<&'message str>) { + async fn handle_command<'message>(&self, logger: &'system Logger, system: &'system System, message: &'message Message, _command: PluginCommand, args: Vec<&'message str>) { let mut args = args.iter().map(|r| *r); let first_word = args.next(); diff --git a/src/system/plugin/edit.rs b/src/system/plugin/edit.rs new file mode 100644 index 0000000..bbd5801 --- /dev/null +++ b/src/system/plugin/edit.rs @@ -0,0 +1,123 @@ +use async_trait::async_trait; +use regex::RegexBuilder; +use twilight_model::channel::message::MessageReference; +use twilight_model::id::{marker::ChannelMarker, Id}; +use twilight_model::channel::Message; +use crate::system::{log::Logger, types::{Response, System}}; + +use super::{PluginCommand, SeancePlugin}; + +pub struct Edit; + +impl Edit { + pub fn new() -> Self { + Self + } +} + +#[async_trait] +impl<'system> SeancePlugin<'system> for Edit { + fn get_commands(&self) -> Vec<PluginCommand> { + vec![ + PluginCommand::Long("edit"), + PluginCommand::Short("s"), + ] + } + + async fn handle_command<'message>(&self, logger: &'system Logger, system: &'system System, message: &'message Message, command: PluginCommand, args: Vec<&'message str>) { + let most_recent_message = system.get_most_recent_message(message.channel_id).await; + + if let Some((edit_target, update_most_recent)) = match message.kind { + twilight_model::channel::message::MessageType::Regular => most_recent_message.map(|(message, _)| (message, true)), + twilight_model::channel::message::MessageType::Reply => async {Some(( + message.referenced_message.as_ref()?.as_ref().clone(), + message.referenced_message.as_ref()?.id == most_recent_message?.0.id + ))}.await, + _ => todo!(), + } { + if let Some(authoring_member) = system.get_member_by_id(edit_target.author.id).await { + if let Some(edit_contents) = match command { + PluginCommand::Long("edit") => Some(args.join(" ")), + PluginCommand::Short("s") => async { + let replacement_command = args.join(" "); + let separator = replacement_command.chars().nth(1).unwrap(); + let parts: Vec<&str> = replacement_command.split(separator).collect(); + + if parts.len() != 3 && parts.len() != 4 { + None? + } + + let pattern = parts.get(1).unwrap(); + let replacement = parts.get(2).unwrap(); + let flags = parts.get(3).unwrap_or(&""); + + let mut global = false; + let mut regex = RegexBuilder::new(pattern); + + for flag in flags.chars() {match flag { + 'i' => {regex.case_insensitive(true);}, + 'm' => {regex.multi_line(true);}, + 'g' => {global = true;}, + 'x' => {regex.ignore_whitespace(true);}, + 'R' => {regex.crlf(true);}, + 's' => {regex.dot_matches_new_line(true);}, + 'U' => {regex.swap_greed(true);}, + other => { + logger.log_err(Some(authoring_member.discord_token.clone()), format!("Unknown regex flag {other}")).await; + None? + }, + }}; + + let valid_regex = regex.build(); + let original_content = edit_target.content; + + // If the regex parses, replace with that + let result = if let Ok(regex) = valid_regex { + if global { + regex.replace_all(original_content.as_str(), *replacement).to_string() + } else { + regex.replace(original_content.as_str(), *replacement).to_string() + } + + // Else attempt replace as string + } else { + original_content.replace(pattern, replacement) + }; + Some(result) + }.await, + _ => unreachable!("Unknown command"), + } { + let result = {authoring_member.client.lock().await.update_message(edit_target.channel_id, edit_target.id) + .content(Some(edit_contents.as_str())) + .expect("Invalid edit contents") + .await + .expect("Could not edit") + .model() + .await + .expect("Could not parse response")}; + + if update_most_recent { + system.cache_most_recent_message(result.channel_id, result, authoring_member.clone()).await; + } + + let _ = {authoring_member.client.lock().await.delete_message(message.channel_id, message.id).await}; + } else { + logger.log_err(None, format!("Could not determine edit contents")).await; + } + } else { + logger.log_err(None, format!("Cannot edit message not sent by system member")).await; + } + + } else { + logger.log_err(None, format!("Cannot find edit target")).await; + } + } + + async fn handle_message<'message>(&self, _logger: &'system Logger, _system: &'system System, _message: &'message Message, _response: &'message mut Response) { + // noop + } + + async fn post_response<'message>(&self, _logger: &'system Logger, _system: &'system System, _message: &'message Message, _channel: Id<ChannelMarker>, _response: &'message Response) { + // noop + } +} diff --git a/src/system/plugin/prefixes.rs b/src/system/plugin/prefixes.rs index 42d8631..55a48d3 100644 --- a/src/system/plugin/prefixes.rs +++ b/src/system/plugin/prefixes.rs @@ -3,17 +3,23 @@ use twilight_model::id::{marker::ChannelMarker, Id}; use twilight_model::channel::Message; use crate::system::{log::Logger, types::{Response, System}}; -use super::SeancePlugin; +use super::{PluginCommand, SeancePlugin}; pub struct ProxyPrefixes; +impl ProxyPrefixes { + pub fn new() -> Self { + Self + } +} + #[async_trait] impl<'system> SeancePlugin<'system> for ProxyPrefixes { - fn get_commands(&self) -> Vec<&'static str> { + fn get_commands(&self) -> Vec<PluginCommand> { vec![] } - async fn handle_command<'message>(&self, _logger: &'system Logger, _system: &'system System, _message: &'message Message, args: Vec<&'message str>) { + async fn handle_command<'message>(&self, _logger: &'system Logger, _system: &'system System, _message: &'message Message, _command: PluginCommand, _args: Vec<&'message str>) { unreachable!("Prefix plugin has no commands") } diff --git a/src/system/types.rs b/src/system/types.rs index c4da066..5aaf59c 100644 --- a/src/system/types.rs +++ b/src/system/types.rs @@ -2,8 +2,9 @@ use regex::Regex; use twilight_http::Client; use twilight_gateway::Shard; use twilight_mention::ParseMention; -use twilight_model::id::{marker::UserMarker, Id}; -use std::sync::Arc; +use twilight_model::channel::Message; +use twilight_model::id::{marker::{ChannelMarker, UserMarker}, Id}; +use std::{collections::HashMap, sync::Arc}; use tokio::sync::Mutex; #[derive(Clone)] @@ -20,7 +21,8 @@ pub struct Member { pub struct System { pub followed_user: Id<UserMarker>, pub command_prefix: String, - pub members: Vec<Member> + pub members: Vec<Member>, + pub message_cache: Arc<Mutex<HashMap<Id<ChannelMarker>, (Member, Message)>>> } #[derive(Clone)] @@ -40,16 +42,30 @@ impl System { pub async fn resolve_mention<'system>(&'system self, maybe_mention: Option<&str>) -> Option<&'system Member> { if let Some(mention) = maybe_mention { if let Ok(mention) = Id::<UserMarker>::parse(mention) { - for member in &self.members { - let is_member = {member.user_id.lock().await.map(|id| id == mention).unwrap_or(false)}; - - if is_member { - return Some(&member) - } - } + return self.get_member_by_id(mention).await; } } None } + + pub async fn get_member_by_id<'system>(&'system self, search_id: Id<UserMarker>) -> Option<&'system Member> { + for member in &self.members { + let is_member = {member.user_id.lock().await.map(|id| id == search_id).unwrap_or(false)}; + + if is_member { + return Some(&member) + } + } + + return None + } + + pub async fn cache_most_recent_message(&self, channel: Id<ChannelMarker>, message: Message, member: Member) { + self.message_cache.lock().await.insert(channel, (member, message)); + } + + pub async fn get_most_recent_message(&self, channel: Id<ChannelMarker>) -> Option<(Message, Member)> { + self.message_cache.lock().await.get(&channel).map(|(member, message)| (message.clone(), member.clone())) + } } |