summary refs log tree commit diff
path: root/src/system/mod.rs
diff options
context:
space:
mode:
authorAshelyn Dawn <ashe@ashen.earth>2024-10-01 15:37:09 -0600
committerAshelyn Dawn <ashe@ashen.earth>2024-10-01 15:37:09 -0600
commita6a120ae8b8ed08b0801d76e80a5f7a0b8cde44b (patch)
tree37242703c61137d1198bc373558828b9faa0732b /src/system/mod.rs
parentf87e9727e69e981e4acff31a779e29a35637b964 (diff)
Refactor gateway and message aggregation
Diffstat (limited to 'src/system/mod.rs')
-rw-r--r--src/system/mod.rs500
1 files changed, 500 insertions, 0 deletions
diff --git a/src/system/mod.rs b/src/system/mod.rs
new file mode 100644
index 0000000..ce23466
--- /dev/null
+++ b/src/system/mod.rs
@@ -0,0 +1,500 @@
+use std::{collections::HashMap, str::FromStr, time::Duration};
+
+use futures::future::join_all;
+use tokio::{
+    sync::mpsc::{channel, Sender},
+    time::sleep,
+};
+use twilight_http::Client;
+use twilight_model::http::attachment::Attachment;
+use twilight_model::util::Timestamp;
+use twilight_model::{
+    channel::{
+        message::{AllowedMentions, MentionType, MessageType},
+        Message,
+    },
+    id::{marker::UserMarker, Id},
+};
+
+use crate::config::{AutoproxyConfig, AutoproxyLatchScope, Member};
+
+mod aggregator;
+mod gateway;
+mod types;
+use aggregator::MessageAggregator;
+use gateway::Gateway;
+pub use types::*;
+
+pub struct Manager {
+    pub name: String,
+    pub config: crate::config::System,
+    pub clients: HashMap<MemberId, Client>,
+    pub gateways: HashMap<MemberId, Gateway>,
+    pub latch_state: Option<(MemberId, Timestamp)>,
+    pub last_presence: HashMap<MemberId, Status>,
+    pub system_sender: Option<Sender<SystemEvent>>,
+}
+
+impl Manager {
+    pub fn new(system_name: String, system_config: crate::config::System) -> Self {
+        Self {
+            name: system_name,
+            config: system_config,
+            clients: HashMap::new(),
+            gateways: HashMap::new(),
+            latch_state: None,
+            last_presence: HashMap::new(),
+            system_sender: None,
+        }
+    }
+
+    pub fn find_member_by_name<'a>(
+        &'a self,
+        name: &String,
+    ) -> Option<(MemberId, &'a crate::config::Member)> {
+        self.config
+            .members
+            .iter()
+            .enumerate()
+            .find(|(_member_id, member)| member.name == *name)
+    }
+
+    pub fn find_member_by_id<'a>(&'a self, id: MemberId) -> Option<&'a Member> {
+        self.config
+            .members
+            .iter()
+            .enumerate()
+            .find(|(member_id, _)| *member_id == id)
+            .map_or(None, |(_member_id, member)| Some(member))
+    }
+
+    pub async fn start_clients(&mut self) {
+        println!("Starting clients for system {}", self.name);
+
+        let reference_user_id: Id<UserMarker> =
+            Id::from_str(self.config.reference_user_id.as_str())
+                .expect(format!("Invalid user ID: {}", self.config.reference_user_id).as_str());
+
+        let (system_sender, mut system_receiver) = channel::<SystemEvent>(100);
+        self.system_sender = Some(system_sender.clone());
+        let mut aggregator = MessageAggregator::new();
+        aggregator.set_handler(system_sender.clone());
+
+        for (member_id, member) in self.config.members.iter().enumerate() {
+            // Create outgoing client
+            let client = twilight_http::Client::new(member.discord_token.clone());
+            self.clients.insert(member_id, client);
+
+            // Create gateway listener
+            let mut listener = Gateway::new(member_id, &member, reference_user_id);
+
+            listener.set_message_handler(aggregator.get_sender());
+            listener.set_system_handler(system_sender.clone());
+
+            // Start gateway listener
+            listener.start_listening();
+            self.gateways.insert(member_id, listener);
+        }
+
+        aggregator.start();
+
+        let mut num_connected = 0;
+
+        loop {
+            match system_receiver.recv().await {
+                Some(event) => match event {
+                    SystemEvent::GatewayConnected(member_id) => {
+                        let member = self
+                            .find_member_by_id(member_id)
+                            .expect("Could not find member");
+
+                        num_connected += 1;
+                        println!(
+                            "Gateway client {} ({}) connected",
+                            num_connected, member.name
+                        );
+
+                        if num_connected == self.config.members.len() {
+                            let system_sender = system_sender.clone();
+                            tokio::spawn(async move {
+                                println!("All gateways connected");
+                                sleep(Duration::from_secs(5)).await;
+                                let _ = system_sender.send(SystemEvent::AllGatewaysConnected).await;
+                            });
+                        }
+                    }
+                    SystemEvent::GatewayClosed(member_id) => {
+                        let member = self
+                            .find_member_by_id(member_id)
+                            .expect("Could not find member");
+
+                        println!("Gateway client {} closed", member.name);
+
+                        num_connected -= 1;
+                    }
+                    SystemEvent::NewMessage((event_time, message)) => {
+                        self.handle_message(message, event_time).await;
+                    }
+                    SystemEvent::GatewayError(member_id, message) => {
+                        let member = self
+                            .find_member_by_id(member_id)
+                            .expect("Could not find member");
+                        println!("Gateway client {} ran into error {}", member.name, message);
+                        return;
+                    }
+                    SystemEvent::AutoproxyTimeout(time_scheduled) => {
+                        if let Some((_member, current_last_message)) = self.latch_state.clone() {
+                            if current_last_message == time_scheduled {
+                                println!("Autoproxy timeout has expired: {} (last sent), {} (timeout scheduled)", current_last_message.as_secs(), time_scheduled.as_secs());
+                                self.latch_state = None;
+                                self.update_status_of_system().await;
+                            }
+                        }
+                    }
+                    SystemEvent::AllGatewaysConnected => {
+                        println!(
+                            "Attempting to set startup status for system {}",
+                            self.name.clone()
+                        );
+                        self.update_status_of_system().await;
+                    }
+                    _ => (),
+                },
+                None => return,
+            }
+        }
+    }
+
+    async fn handle_message(&mut self, message: Message, timestamp: Timestamp) {
+        // TODO: Commands
+        if message.content.eq("!panic") {
+            panic!("Exiting due to user command");
+        }
+
+        // Escape sequence
+        if message.content.starts_with(r"\") {
+            if message.content == r"\\" {
+                let client = if let Some((current_member, _)) = self.latch_state.clone() {
+                    self.clients
+                        .get(&current_member)
+                        .expect(format!("No client for member {}", current_member).as_str())
+                } else {
+                    self.clients.iter().next().expect("No clients!").1
+                };
+
+                client
+                    .delete_message(message.channel_id, message.id)
+                    .await
+                    .expect("Could not delete message");
+                self.latch_state = None
+            } else if message.content.starts_with(r"\\") {
+                self.latch_state = None;
+            }
+
+            return;
+        }
+
+        // TODO: Non-latching prefixes maybe?
+
+        // Check for prefix
+        println!("Checking prefix");
+        let match_prefix =
+            self.config
+                .members
+                .iter()
+                .enumerate()
+                .find_map(|(member_id, member)| {
+                    Some((member_id, member.matches_proxy_prefix(&message)?))
+                });
+        if let Some((member_id, matched_content)) = match_prefix {
+            self.proxy_message(&message, member_id, matched_content)
+                .await;
+            println!("Updating proxy state to member id {}", member_id);
+            self.update_autoproxy_state_after_message(member_id, timestamp);
+            self.update_status_of_system().await;
+            return;
+        }
+
+        // Check for autoproxy
+        println!("Checking autoproxy");
+        if let Some(autoproxy_config) = &self.config.autoproxy {
+            match autoproxy_config {
+                AutoproxyConfig::Member { name } => {
+                    let (member_id, _member) = self
+                        .find_member_by_name(&name)
+                        .expect("Invalid autoproxy member name");
+                    self.proxy_message(&message, member_id, message.content.as_str())
+                        .await;
+                }
+                // TODO: Do something with the latch scope
+                // TODO: Do something with presence setting
+                AutoproxyConfig::Latch {
+                    scope,
+                    timeout_seconds,
+                    presence_indicator,
+                } => {
+                    println!("Currently in latch mode");
+                    if let Some((member, last_timestamp)) = self.latch_state.clone() {
+                        println!("We have a latch state");
+                        let time_since_last = timestamp.as_secs() - last_timestamp.as_secs();
+                        println!("Time since last (seconds) {}", time_since_last);
+                        if time_since_last <= (*timeout_seconds).into() {
+                            println!("Proxying");
+                            self.proxy_message(&message, member, message.content.as_str())
+                                .await;
+                            self.latch_state = Some((member, timestamp));
+                            self.update_autoproxy_state_after_message(member, timestamp);
+                            self.update_status_of_system().await;
+                        }
+                    }
+                }
+            }
+        } else {
+            println!("No autoproxy config?");
+        }
+    }
+
+    async fn proxy_message(&self, message: &Message, member: MemberId, content: &str) {
+        let client = self.clients.get(&member).expect("No client for member");
+
+        if let Err(err) = self.duplicate_message(message, client, content).await {
+            match err {
+                MessageDuplicateError::MessageCreate(err) => {
+                    if err.to_string().contains("Cannot send an empty message") {
+                        client
+                            .delete_message(message.channel_id, message.id)
+                            .await
+                            .expect("Could not delete message");
+                    }
+                }
+                _ => println!("Error: {:?}", err),
+            }
+        } else {
+            client
+                .delete_message(message.channel_id, message.id)
+                .await
+                .expect("Could not delete message");
+        }
+    }
+
+    async fn duplicate_message(
+        &self,
+        message: &Message,
+        client: &Client,
+        content: &str,
+    ) -> Result<Message, MessageDuplicateError> {
+        let mut create_message = client.create_message(message.channel_id).content(content)?;
+
+        let mut allowed_mentions = AllowedMentions {
+            parse: Vec::new(),
+            replied_user: false,
+            roles: message.mention_roles.clone(),
+            users: message.mentions.iter().map(|user| user.id).collect(),
+        };
+
+        if message.mention_everyone {
+            allowed_mentions.parse.push(MentionType::Everyone);
+        }
+
+        if message.kind == MessageType::Reply {
+            if let Some(ref_message) = message.referenced_message.as_ref() {
+                create_message = create_message.reply(ref_message.id);
+
+                let pings_referenced_author = message
+                    .mentions
+                    .iter()
+                    .any(|user| user.id == ref_message.author.id);
+
+                if pings_referenced_author {
+                    allowed_mentions.replied_user = true;
+                } else {
+                    allowed_mentions.replied_user = false;
+                }
+            } else {
+                panic!("Cannot proxy message: Was reply but no referenced message");
+            }
+        }
+
+        let attachments = join_all(message.attachments.iter().map(|attachment| async {
+            let filename = attachment.filename.clone();
+            let description_opt = attachment.description.clone();
+            let bytes = reqwest::get(attachment.proxy_url.clone())
+                .await?
+                .bytes()
+                .await?;
+            let mut new_attachment =
+                Attachment::from_bytes(filename, bytes.try_into().unwrap(), attachment.id.into());
+
+            if let Some(description) = description_opt {
+                new_attachment.description(description);
+            }
+
+            Ok(new_attachment)
+        }))
+        .await
+        .iter()
+        .filter_map(
+            |result: &Result<Attachment, MessageDuplicateError>| match result {
+                Ok(attachment) => Some(attachment.clone()),
+                Err(_) => None,
+            },
+        )
+        .collect::<Vec<_>>();
+
+        if attachments.len() > 0 {
+            create_message = create_message.attachments(attachments.as_slice())?;
+        }
+
+        if let Some(flags) = message.flags {
+            create_message = create_message.flags(flags);
+        }
+
+        create_message = create_message.allowed_mentions(Some(&allowed_mentions));
+        let new_message = create_message.await?.model().await?;
+
+        Ok(new_message)
+    }
+
+    fn update_autoproxy_state_after_message(&mut self, member: MemberId, timestamp: Timestamp) {
+        match &self.config.autoproxy {
+            None => (),
+            Some(AutoproxyConfig::Member { name: _ }) => (),
+            Some(AutoproxyConfig::Latch {
+                scope,
+                timeout_seconds,
+                presence_indicator: _,
+            }) => {
+                self.latch_state = Some((member, timestamp));
+
+                if let Some(channel) = self.system_sender.clone() {
+                    let last_message = timestamp.clone();
+                    let timeout_seconds = timeout_seconds.clone();
+
+                    tokio::spawn(async move {
+                        sleep(Duration::from_secs(timeout_seconds.into())).await;
+                        channel
+                            .send(SystemEvent::AutoproxyTimeout(last_message))
+                            .await
+                            .expect("Channel has closed");
+                    });
+                }
+            }
+        }
+    }
+
+    async fn update_status_of_system(&mut self) {
+        let member_states: Vec<(MemberId, Status)> = self
+            .config
+            .members
+            .iter()
+            .enumerate()
+            .map(|(member_id, member)| {
+                (
+                    member_id,
+                    match &self.config.autoproxy {
+                        None => Status::Invisible,
+                        Some(AutoproxyConfig::Member { name }) => {
+                            if member.name == *name {
+                                Status::Online
+                            } else {
+                                Status::Invisible
+                            }
+                        }
+                        Some(AutoproxyConfig::Latch {
+                            scope,
+                            timeout_seconds: _,
+                            presence_indicator,
+                        }) => {
+                            if let AutoproxyLatchScope::Server = scope {
+                                Status::Invisible
+                            } else if !presence_indicator {
+                                Status::Invisible
+                            } else {
+                                match &self.latch_state {
+                                    Some((latch_member, _last_timestamp)) => {
+                                        if member_id == *latch_member {
+                                            Status::Online
+                                        } else {
+                                            Status::Invisible
+                                        }
+                                    }
+                                    None => Status::Invisible,
+                                }
+                            }
+                        }
+                    },
+                )
+            })
+            .collect();
+
+        for (member, status) in member_states {
+            self.update_status_of_member(member, status).await;
+        }
+    }
+
+    async fn update_status_of_member(&mut self, member: MemberId, status: Status) {
+        let last_status = *self.last_presence.get(&member).unwrap_or(&Status::Offline);
+
+        if status == last_status {
+            return;
+        }
+
+        if let Some(gateway) = self.gateways.get(&member) {
+            gateway.set_status(status).await;
+
+            self.last_presence.insert(member, status);
+        } else {
+            let full_member = self
+                .find_member_by_id(member)
+                .expect("Cannot look up member");
+            println!(
+                "Could not look up gateway for member ID {} ({})",
+                member, full_member.name
+            );
+        }
+    }
+}
+
+impl crate::config::Member {
+    pub fn matches_proxy_prefix<'a>(&self, message: &'a Message) -> Option<&'a str> {
+        match self.message_pattern.captures(message.content.as_str()) {
+            None => None,
+            Some(captures) => match captures.name("content") {
+                None => None,
+                Some(matched_content) => Some(matched_content.as_str()),
+            },
+        }
+    }
+}
+
+#[derive(Debug)]
+enum MessageDuplicateError {
+    MessageValidation(twilight_validate::message::MessageValidationError),
+    AttachmentRequest(reqwest::Error),
+    MessageCreate(twilight_http::error::Error),
+    ResponseDeserialization(twilight_http::response::DeserializeBodyError),
+}
+
+impl From<twilight_validate::message::MessageValidationError> for MessageDuplicateError {
+    fn from(value: twilight_validate::message::MessageValidationError) -> Self {
+        MessageDuplicateError::MessageValidation(value)
+    }
+}
+
+impl From<reqwest::Error> for MessageDuplicateError {
+    fn from(value: reqwest::Error) -> Self {
+        MessageDuplicateError::AttachmentRequest(value)
+    }
+}
+
+impl From<twilight_http::error::Error> for MessageDuplicateError {
+    fn from(value: twilight_http::error::Error) -> Self {
+        MessageDuplicateError::MessageCreate(value)
+    }
+}
+
+impl From<twilight_http::response::DeserializeBodyError> for MessageDuplicateError {
+    fn from(value: twilight_http::response::DeserializeBodyError) -> Self {
+        MessageDuplicateError::ResponseDeserialization(value)
+    }
+}