summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorAshelyn Rose <git@ashen.earth>2024-07-09 12:35:58 -0600
committerAshelyn Rose <git@ashen.earth>2024-07-09 12:35:58 -0600
commit8903b0afca8a7e1b18b33c92f41051bd5ceb9b3d (patch)
tree1a9bee833c838a3922ce00c63c0c95eb370e9cd6 /src
parent0b90860d730b2fbd7ebe2b9c39084edd006f515d (diff)
handle attachments and replies
Diffstat (limited to 'src')
-rw-r--r--src/listener.rs39
-rw-r--r--src/system.rs97
2 files changed, 101 insertions, 35 deletions
diff --git a/src/listener.rs b/src/listener.rs
index 9d0e062..bc9868d 100644
--- a/src/listener.rs
+++ b/src/listener.rs
@@ -1,5 +1,6 @@
 use tokio::sync::mpsc::Sender;
-use twilight_model::{channel::ChannelMention, id::{Id, marker::{ChannelMarker, MessageMarker, UserMarker}}, user::User};
+use twilight_http::Client;
+use twilight_model::{channel::Message, id::{Id, marker::{ChannelMarker, MessageMarker, UserMarker}}, user::User};
 use twilight_gateway::{error::ReceiveMessageError, Intents, Shard, ShardId};
 use twilight_model::util::Timestamp;
 
@@ -20,6 +21,7 @@ impl Listener {
         let intents = Intents::GUILD_MEMBERS | Intents::GUILD_PRESENCES | Intents::GUILD_MESSAGES | Intents::MESSAGE_CONTENT;
 
         let mut shard = Shard::new(ShardId::ONE, self.config.discord_token.clone(), intents);
+        let mut client = Client::new(self.config.discord_token.clone());
         
         loop {
             match shard.next_event().await {
@@ -29,40 +31,44 @@ impl Listener {
                             println!("Bot started for {}#{}", client.user.name, client.user.discriminator);
                         },
 
-                        twilight_gateway::Event::MessageCreate(message) => {
+                        twilight_gateway::Event::MessageCreate(message_create) => {
+                            let message = message_create.0;
+
                             if message.author.id != self.reference_user_id {
                                 continue
                             }
 
                             if let Err(_) = channel.send(ClientEvent::Message {
                                 event_time: message.timestamp,
-                                message_id: message.id,
-                                channel_id: message.channel_id,
-                                author: message.author.clone(),
-                                content: message.content.clone()
+                                message
                             }).await {
                                 println!("Client listener error: System context has already closed");
                                 return
                             }
                         },
 
-                        twilight_gateway::Event::MessageUpdate(message) => {
-                            if message.author.is_none() || message.author.as_ref().unwrap().id != self.reference_user_id {
+                        twilight_gateway::Event::MessageUpdate(message_update) => {
+                            if message_update.author.is_none() || message_update.author.as_ref().unwrap().id != self.reference_user_id {
                                 continue
                             }
 
-                            if message.edited_timestamp.is_none() {
+                            if message_update.edited_timestamp.is_none() {
                                 println!("Message update but no edit timestamp");
                                 continue;
                             }
 
+                            if message_update.content.is_none() {
+                                println!("Message update but no content");
+                                continue;
+                            }
+
+                            let message = client.message(message_update.channel_id, message_update.id)
+                                .await.expect("Could not load message")
+                                .model().await.expect("Could not deserialize message");
 
                             if let Err(_) = channel.send(ClientEvent::Message {
-                                event_time: message.edited_timestamp.unwrap(),
-                                message_id: message.id,
-                                channel_id: message.channel_id,
-                                author: message.author.unwrap(),
-                                content: message.content.unwrap()
+                                event_time: message_update.edited_timestamp.unwrap(),
+                                message,
                             }).await {
                                 println!("Client listener error: System context has already closed");
                                 return
@@ -91,10 +97,7 @@ impl Listener {
 pub enum ClientEvent {
     Message {
         event_time: Timestamp,
-        message_id: Id<MessageMarker>,
-        channel_id: Id<ChannelMarker>,
-        author: User,
-        content: String,
+        message: Message,
     },
     Error(ReceiveMessageError)
 }
diff --git a/src/system.rs b/src/system.rs
index 3fe4195..cf59f0d 100644
--- a/src/system.rs
+++ b/src/system.rs
@@ -1,9 +1,11 @@
 use std::{collections::HashMap, num::NonZeroUsize, str::FromStr};
 
 use tokio::sync::mpsc::channel;
+use futures::future::join_all;
 use twilight_http::Client;
-use twilight_model::id::{Id, marker::{ChannelMarker, MessageMarker, UserMarker}};
+use twilight_model::{channel::{message::MessageType, Message}, id::{Id, marker::{ChannelMarker, MessageMarker, UserMarker}}};
 use twilight_model::util::Timestamp;
+use twilight_model::http::attachment::Attachment;
 
 use crate::{config::{AutoproxyConfig, Member, MemberName}, listener::{Listener, ClientEvent}};
 
@@ -49,9 +51,9 @@ impl System {
         loop {
             match rx.recv().await {
                 Some(event) => match event {
-                    ClientEvent::Message { event_time, message_id, channel_id, content, author: _ } => {
-                        if self.is_new_message(message_id, event_time) {
-                            self.handle_message(message_id, channel_id, content, event_time).await;
+                    ClientEvent::Message { event_time, message } => {
+                        if self.is_new_message(message.id, event_time) {
+                            self.handle_message(message, event_time).await;
                         }
                     },
                     ClientEvent::Error(_err) => {
@@ -78,7 +80,7 @@ impl System {
         }
     }
 
-    async fn handle_message(&mut self, message_id: Id<MessageMarker>, channel_id: Id<ChannelMarker>, content: String, timestamp: Timestamp) {
+    async fn handle_message(&mut self, message: Message, timestamp: Timestamp) {
         // Check for command
         // TODO: Commands
         // TODO: Escaping
@@ -86,9 +88,9 @@ impl System {
         // TODO: Non-latching prefixes maybe?
         
         // Check for prefix
-        let match_prefix = self.config.members.iter().find_map(|member| Some((member, member.matches_proxy_prefix(&content)?)));
+        let match_prefix = self.config.members.iter().find_map(|member| Some((member, member.matches_proxy_prefix(&message)?)));
         if let Some((member, matched_content)) = match_prefix {
-            self.proxy_message(message_id, channel_id, member, matched_content).await;
+            self.proxy_message(&message, member, matched_content).await;
             self.update_autoproxy_state_after_message(member.clone(), timestamp);
             return
         }
@@ -99,7 +101,7 @@ impl System {
             match autoproxy_config {
                 AutoproxyConfig::Member {name} => {
                     let member = self.config.members.iter().find(|member| member.name == *name).expect("Invalid autoproxy member name");
-                    self.proxy_message(message_id, channel_id, member, content.as_str()).await;
+                    self.proxy_message(&message, member, message.content.as_str()).await;
                 },
                 // TODO: Do something with the latch scope
                 // TODO: Do something with presence setting
@@ -107,7 +109,7 @@ impl System {
                     if let Some((member, last_timestamp)) = &self.latch_state {
                         let time_since_last = timestamp.as_secs() - last_timestamp.as_secs();
                         if time_since_last <= (*timeout_seconds).into() {
-                            self.proxy_message(message_id, channel_id, &member, content.as_str()).await;
+                            self.proxy_message(&message, &member, message.content.as_str()).await;
                             self.latch_state = Some((member.clone(), timestamp));
                         }
                     }
@@ -116,13 +118,43 @@ impl System {
         }
     }
 
-    async fn proxy_message(&self, message_id: Id<MessageMarker>, channel_id: Id<ChannelMarker>, member: &Member, content: &str) {
+    async fn proxy_message(&self, message: &Message, member: &Member, content: &str) {
         let client = self.clients.get(&member.name).expect("No client for member");
 
-        if let Ok(_) = client.create_message(channel_id)
-            .content(content).expect("Cannot set content").await {
-                client.delete_message(channel_id, message_id).await.expect("Could not delete message");
-            }
+        if let Ok(_) = self.duplicate_message(message, client, content).await {
+            client.delete_message(message.channel_id, message.id).await.expect("Could not delete message");
+        }
+    }
+
+    async fn duplicate_message(&self, message: &Message, client: &Client, content: &str) -> Result<Message, MessageDuplicateError> {
+        let mut create_message = client.create_message(message.channel_id)
+            .content(content)?;
+
+        if message.kind == MessageType::Reply {
+            create_message = create_message.reply(
+                message.referenced_message.as_ref().expect("Message was reply but no referenced message").id
+            );
+        }
+
+        let attachments = join_all(message.attachments.iter().map(|attachment| async {
+            let filename = attachment.filename.clone();
+            let description = attachment.description.clone();
+            let bytes = reqwest::get(attachment.proxy_url.clone()).await?.bytes().await?;
+
+            // TODO: keep description
+            Ok(Attachment::from_bytes(filename, bytes.try_into().unwrap(), attachment.id.into()))
+        })).await.iter().filter_map(|result: &Result<Attachment, MessageDuplicateError>| match result {
+            Ok(attachment) => Some(attachment.clone()),
+            Err(_) => None,
+        }).collect::<Vec<_>>();
+
+        if attachments.len() > 0 {
+            create_message = create_message.attachments(attachments.as_slice())?;
+        }
+
+        let new_message = create_message.await?.model().await?;
+
+        Ok(new_message)
     }
 
     fn update_autoproxy_state_after_message(&mut self, member: Member, timestamp: Timestamp) {
@@ -139,13 +171,13 @@ impl System {
 
 
 impl crate::config::Member {
-    pub fn matches_proxy_prefix<'a>(&self, content: &'a String) -> Option<&'a str> {
-        match self.message_pattern.captures(content.as_str()) {
+    pub fn matches_proxy_prefix<'a>(&self, message: &'a Message) -> Option<&'a str> {
+        match self.message_pattern.captures(message.content.as_str()) {
             None => None,
             Some(captures) => {
                 let full_match = captures.get(0).unwrap();
 
-                if full_match.len() != content.len() {
+                if full_match.len() != message.content.len() {
                     return None
                 }
 
@@ -157,3 +189,34 @@ impl crate::config::Member {
         }
     }
 }
+
+enum MessageDuplicateError {
+    MessageValidation(twilight_validate::message::MessageValidationError),
+    AttachmentRequest(reqwest::Error),
+    MessageCreate(twilight_http::error::Error),
+    ResponseDeserialization(twilight_http::response::DeserializeBodyError)
+}
+
+impl From<twilight_validate::message::MessageValidationError> for MessageDuplicateError {
+    fn from(value: twilight_validate::message::MessageValidationError) -> Self {
+        MessageDuplicateError::MessageValidation(value)
+    }
+}
+
+impl From<reqwest::Error> for MessageDuplicateError {
+    fn from(value: reqwest::Error) -> Self {
+        MessageDuplicateError::AttachmentRequest(value)
+    }
+}
+
+impl From<twilight_http::error::Error> for MessageDuplicateError {
+    fn from(value: twilight_http::error::Error) -> Self {
+        MessageDuplicateError::MessageCreate(value)
+    }
+}
+
+impl From<twilight_http::response::DeserializeBodyError> for MessageDuplicateError {
+    fn from(value: twilight_http::response::DeserializeBodyError) -> Self {
+        MessageDuplicateError::ResponseDeserialization(value)
+    }
+}