summary refs log tree commit diff
path: root/src/system
diff options
context:
space:
mode:
authorAshelyn Dawn <ashe@ashen.earth>2024-10-01 15:37:09 -0600
committerAshelyn Dawn <ashe@ashen.earth>2024-10-01 15:37:09 -0600
commita6a120ae8b8ed08b0801d76e80a5f7a0b8cde44b (patch)
tree37242703c61137d1198bc373558828b9faa0732b /src/system
parentf87e9727e69e981e4acff31a779e29a35637b964 (diff)
Refactor gateway and message aggregation
Diffstat (limited to 'src/system')
-rw-r--r--src/system/aggregator.rs61
-rw-r--r--src/system/gateway.rs167
-rw-r--r--src/system/mod.rs500
-rw-r--r--src/system/types.rs33
4 files changed, 761 insertions, 0 deletions
diff --git a/src/system/aggregator.rs b/src/system/aggregator.rs
new file mode 100644
index 0000000..0177249
--- /dev/null
+++ b/src/system/aggregator.rs
@@ -0,0 +1,61 @@
+use lru::LruCache;
+use std::num::NonZeroUsize;
+use tokio::sync::mpsc::{channel, Receiver, Sender};
+
+use super::{MessageEvent, MessageId, SystemEvent};
+
+pub struct MessageAggregator {
+    rx: Receiver<MessageEvent>,
+    tx: Sender<MessageEvent>,
+    message_cache: lru::LruCache<MessageId, MessageEvent>,
+    emitter: Option<Sender<SystemEvent>>,
+}
+
+impl MessageAggregator {
+    pub fn new() -> Self {
+        let (tx, rx) = channel::<MessageEvent>(100);
+
+        Self {
+            tx,
+            rx,
+            message_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
+            emitter: None,
+        }
+    }
+
+    pub fn get_sender(&self) -> Sender<MessageEvent> {
+        self.tx.clone()
+    }
+
+    pub fn set_handler(&mut self, emitter: Sender<SystemEvent>) -> () {
+        self.emitter = Some(emitter);
+    }
+
+    pub fn start(mut self) -> () {
+        tokio::spawn(async move {
+            loop {
+                match self.rx.recv().await {
+                    None => return,
+                    Some((timestamp, message)) => {
+                        let last_seen_timestamp = self.message_cache.get(&message.id);
+                        let current_timestamp = timestamp;
+
+                        if last_seen_timestamp.is_none()
+                            || last_seen_timestamp.unwrap().0.as_micros()
+                                < current_timestamp.as_micros()
+                        {
+                            self.message_cache
+                                .put(message.id, (timestamp, message.clone()));
+
+                            if let Some(emitter) = &self.emitter {
+                                emitter
+                                    .send(SystemEvent::NewMessage((timestamp, message)))
+                                    .await;
+                            }
+                        };
+                    }
+                }
+            }
+        });
+    }
+}
diff --git a/src/system/gateway.rs b/src/system/gateway.rs
new file mode 100644
index 0000000..17bfe3d
--- /dev/null
+++ b/src/system/gateway.rs
@@ -0,0 +1,167 @@
+use std::sync::Arc;
+use tokio::sync::mpsc::Sender;
+use tokio::sync::Mutex;
+use twilight_gateway::{Intents, Shard, ShardId};
+use twilight_http::Client;
+use twilight_model::gateway::{
+    payload::outgoing::{update_presence::UpdatePresencePayload, UpdatePresence},
+    OpCode,
+};
+
+use super::{MemberId, MessageEvent, Status, SystemEvent, UserId};
+
+pub struct Gateway {
+    member_id: MemberId,
+    discord_token: String,
+    reference_user_id: UserId,
+    message_handler: Option<Arc<Mutex<Sender<MessageEvent>>>>,
+    system_handler: Option<Arc<Mutex<Sender<SystemEvent>>>>,
+    shard: Arc<Mutex<Shard>>,
+}
+
+impl Gateway {
+    pub fn new(
+        member_id: MemberId,
+        config: &crate::config::Member,
+        reference_user_id: UserId,
+    ) -> Self {
+        let intents = Intents::GUILD_MEMBERS
+            | Intents::GUILD_PRESENCES
+            | Intents::GUILD_MESSAGES
+            | Intents::MESSAGE_CONTENT;
+
+        Self {
+            member_id,
+            discord_token: config.discord_token.clone(),
+            reference_user_id,
+            message_handler: None,
+            system_handler: None,
+            shard: Arc::new(Mutex::new(Shard::new(
+                ShardId::ONE,
+                config.discord_token.clone(),
+                intents,
+            ))),
+        }
+    }
+
+    pub fn set_message_handler(&mut self, handler: Sender<MessageEvent>) {
+        self.message_handler = Some(Arc::new(Mutex::new(handler)));
+    }
+
+    pub fn set_system_handler(&mut self, handler: Sender<SystemEvent>) {
+        self.system_handler = Some(Arc::new(Mutex::new(handler)));
+    }
+
+    pub async fn set_status(&self, status: Status) {
+        let mut shard = self.shard.lock().await;
+
+        shard
+            .command(&UpdatePresence {
+                d: UpdatePresencePayload {
+                    activities: Vec::new(),
+                    afk: false,
+                    since: None,
+                    status,
+                },
+                op: OpCode::PresenceUpdate,
+            })
+            .await
+            .expect("Could not send command to gateway");
+    }
+
+    pub fn start_listening(&self) {
+        let message_channel = self.message_handler.clone();
+        let system_channel = self.system_handler.clone();
+        let shard = self.shard.clone();
+        let member_id = self.member_id.clone();
+        let reference_user_id = self.reference_user_id.clone();
+        let client = Client::new(self.discord_token.clone());
+
+        tokio::spawn(async move {
+            loop {
+                let next_event = { shard.lock().await.next_event().await };
+
+                match next_event {
+                    Err(source) => {
+                        if let Some(channel) = &system_channel {
+                            let channel = channel.lock().await;
+
+                            channel
+                                .send(SystemEvent::GatewayError(member_id, source.to_string()))
+                                .await;
+
+                            if source.is_fatal() {
+                                channel.send(SystemEvent::GatewayClosed(member_id)).await;
+                                break;
+                            }
+                        }
+                        todo!("Handle this")
+                    }
+                    Ok(event) => match event {
+                        twilight_gateway::Event::Ready(_) => {
+                            if let Some(channel) = &system_channel {
+                                channel
+                                    .lock()
+                                    .await
+                                    .send(SystemEvent::GatewayConnected(member_id))
+                                    .await;
+                            }
+                        }
+
+                        twilight_gateway::Event::MessageCreate(message_create) => {
+                            let message = message_create.0;
+
+                            if message.author.id != reference_user_id {
+                                continue;
+                            }
+
+                            if let Some(channel) = &message_channel {
+                                channel
+                                    .lock()
+                                    .await
+                                    .send((message.timestamp, message))
+                                    .await;
+                            }
+                        }
+
+                        twilight_gateway::Event::MessageUpdate(message_update) => {
+                            if message_update.author.is_none()
+                                || message_update.author.as_ref().unwrap().id != reference_user_id
+                            {
+                                continue;
+                            }
+
+                            if message_update.edited_timestamp.is_none() {
+                                println!("Message update but no edit timestamp");
+                                continue;
+                            }
+
+                            if message_update.content.is_none() {
+                                println!("Message update but no content");
+                                continue;
+                            }
+
+                            let message = client
+                                .message(message_update.channel_id, message_update.id)
+                                .await
+                                .expect("Could not load message")
+                                .model()
+                                .await
+                                .expect("Could not deserialize message");
+
+                            if let Some(channel) = &message_channel {
+                                channel
+                                    .lock()
+                                    .await
+                                    .send((message_update.edited_timestamp.unwrap(), message))
+                                    .await;
+                            }
+                        }
+
+                        _ => (),
+                    },
+                }
+            }
+        });
+    }
+}
diff --git a/src/system/mod.rs b/src/system/mod.rs
new file mode 100644
index 0000000..ce23466
--- /dev/null
+++ b/src/system/mod.rs
@@ -0,0 +1,500 @@
+use std::{collections::HashMap, str::FromStr, time::Duration};
+
+use futures::future::join_all;
+use tokio::{
+    sync::mpsc::{channel, Sender},
+    time::sleep,
+};
+use twilight_http::Client;
+use twilight_model::http::attachment::Attachment;
+use twilight_model::util::Timestamp;
+use twilight_model::{
+    channel::{
+        message::{AllowedMentions, MentionType, MessageType},
+        Message,
+    },
+    id::{marker::UserMarker, Id},
+};
+
+use crate::config::{AutoproxyConfig, AutoproxyLatchScope, Member};
+
+mod aggregator;
+mod gateway;
+mod types;
+use aggregator::MessageAggregator;
+use gateway::Gateway;
+pub use types::*;
+
+pub struct Manager {
+    pub name: String,
+    pub config: crate::config::System,
+    pub clients: HashMap<MemberId, Client>,
+    pub gateways: HashMap<MemberId, Gateway>,
+    pub latch_state: Option<(MemberId, Timestamp)>,
+    pub last_presence: HashMap<MemberId, Status>,
+    pub system_sender: Option<Sender<SystemEvent>>,
+}
+
+impl Manager {
+    pub fn new(system_name: String, system_config: crate::config::System) -> Self {
+        Self {
+            name: system_name,
+            config: system_config,
+            clients: HashMap::new(),
+            gateways: HashMap::new(),
+            latch_state: None,
+            last_presence: HashMap::new(),
+            system_sender: None,
+        }
+    }
+
+    pub fn find_member_by_name<'a>(
+        &'a self,
+        name: &String,
+    ) -> Option<(MemberId, &'a crate::config::Member)> {
+        self.config
+            .members
+            .iter()
+            .enumerate()
+            .find(|(_member_id, member)| member.name == *name)
+    }
+
+    pub fn find_member_by_id<'a>(&'a self, id: MemberId) -> Option<&'a Member> {
+        self.config
+            .members
+            .iter()
+            .enumerate()
+            .find(|(member_id, _)| *member_id == id)
+            .map_or(None, |(_member_id, member)| Some(member))
+    }
+
+    pub async fn start_clients(&mut self) {
+        println!("Starting clients for system {}", self.name);
+
+        let reference_user_id: Id<UserMarker> =
+            Id::from_str(self.config.reference_user_id.as_str())
+                .expect(format!("Invalid user ID: {}", self.config.reference_user_id).as_str());
+
+        let (system_sender, mut system_receiver) = channel::<SystemEvent>(100);
+        self.system_sender = Some(system_sender.clone());
+        let mut aggregator = MessageAggregator::new();
+        aggregator.set_handler(system_sender.clone());
+
+        for (member_id, member) in self.config.members.iter().enumerate() {
+            // Create outgoing client
+            let client = twilight_http::Client::new(member.discord_token.clone());
+            self.clients.insert(member_id, client);
+
+            // Create gateway listener
+            let mut listener = Gateway::new(member_id, &member, reference_user_id);
+
+            listener.set_message_handler(aggregator.get_sender());
+            listener.set_system_handler(system_sender.clone());
+
+            // Start gateway listener
+            listener.start_listening();
+            self.gateways.insert(member_id, listener);
+        }
+
+        aggregator.start();
+
+        let mut num_connected = 0;
+
+        loop {
+            match system_receiver.recv().await {
+                Some(event) => match event {
+                    SystemEvent::GatewayConnected(member_id) => {
+                        let member = self
+                            .find_member_by_id(member_id)
+                            .expect("Could not find member");
+
+                        num_connected += 1;
+                        println!(
+                            "Gateway client {} ({}) connected",
+                            num_connected, member.name
+                        );
+
+                        if num_connected == self.config.members.len() {
+                            let system_sender = system_sender.clone();
+                            tokio::spawn(async move {
+                                println!("All gateways connected");
+                                sleep(Duration::from_secs(5)).await;
+                                let _ = system_sender.send(SystemEvent::AllGatewaysConnected).await;
+                            });
+                        }
+                    }
+                    SystemEvent::GatewayClosed(member_id) => {
+                        let member = self
+                            .find_member_by_id(member_id)
+                            .expect("Could not find member");
+
+                        println!("Gateway client {} closed", member.name);
+
+                        num_connected -= 1;
+                    }
+                    SystemEvent::NewMessage((event_time, message)) => {
+                        self.handle_message(message, event_time).await;
+                    }
+                    SystemEvent::GatewayError(member_id, message) => {
+                        let member = self
+                            .find_member_by_id(member_id)
+                            .expect("Could not find member");
+                        println!("Gateway client {} ran into error {}", member.name, message);
+                        return;
+                    }
+                    SystemEvent::AutoproxyTimeout(time_scheduled) => {
+                        if let Some((_member, current_last_message)) = self.latch_state.clone() {
+                            if current_last_message == time_scheduled {
+                                println!("Autoproxy timeout has expired: {} (last sent), {} (timeout scheduled)", current_last_message.as_secs(), time_scheduled.as_secs());
+                                self.latch_state = None;
+                                self.update_status_of_system().await;
+                            }
+                        }
+                    }
+                    SystemEvent::AllGatewaysConnected => {
+                        println!(
+                            "Attempting to set startup status for system {}",
+                            self.name.clone()
+                        );
+                        self.update_status_of_system().await;
+                    }
+                    _ => (),
+                },
+                None => return,
+            }
+        }
+    }
+
+    async fn handle_message(&mut self, message: Message, timestamp: Timestamp) {
+        // TODO: Commands
+        if message.content.eq("!panic") {
+            panic!("Exiting due to user command");
+        }
+
+        // Escape sequence
+        if message.content.starts_with(r"\") {
+            if message.content == r"\\" {
+                let client = if let Some((current_member, _)) = self.latch_state.clone() {
+                    self.clients
+                        .get(&current_member)
+                        .expect(format!("No client for member {}", current_member).as_str())
+                } else {
+                    self.clients.iter().next().expect("No clients!").1
+                };
+
+                client
+                    .delete_message(message.channel_id, message.id)
+                    .await
+                    .expect("Could not delete message");
+                self.latch_state = None
+            } else if message.content.starts_with(r"\\") {
+                self.latch_state = None;
+            }
+
+            return;
+        }
+
+        // TODO: Non-latching prefixes maybe?
+
+        // Check for prefix
+        println!("Checking prefix");
+        let match_prefix =
+            self.config
+                .members
+                .iter()
+                .enumerate()
+                .find_map(|(member_id, member)| {
+                    Some((member_id, member.matches_proxy_prefix(&message)?))
+                });
+        if let Some((member_id, matched_content)) = match_prefix {
+            self.proxy_message(&message, member_id, matched_content)
+                .await;
+            println!("Updating proxy state to member id {}", member_id);
+            self.update_autoproxy_state_after_message(member_id, timestamp);
+            self.update_status_of_system().await;
+            return;
+        }
+
+        // Check for autoproxy
+        println!("Checking autoproxy");
+        if let Some(autoproxy_config) = &self.config.autoproxy {
+            match autoproxy_config {
+                AutoproxyConfig::Member { name } => {
+                    let (member_id, _member) = self
+                        .find_member_by_name(&name)
+                        .expect("Invalid autoproxy member name");
+                    self.proxy_message(&message, member_id, message.content.as_str())
+                        .await;
+                }
+                // TODO: Do something with the latch scope
+                // TODO: Do something with presence setting
+                AutoproxyConfig::Latch {
+                    scope,
+                    timeout_seconds,
+                    presence_indicator,
+                } => {
+                    println!("Currently in latch mode");
+                    if let Some((member, last_timestamp)) = self.latch_state.clone() {
+                        println!("We have a latch state");
+                        let time_since_last = timestamp.as_secs() - last_timestamp.as_secs();
+                        println!("Time since last (seconds) {}", time_since_last);
+                        if time_since_last <= (*timeout_seconds).into() {
+                            println!("Proxying");
+                            self.proxy_message(&message, member, message.content.as_str())
+                                .await;
+                            self.latch_state = Some((member, timestamp));
+                            self.update_autoproxy_state_after_message(member, timestamp);
+                            self.update_status_of_system().await;
+                        }
+                    }
+                }
+            }
+        } else {
+            println!("No autoproxy config?");
+        }
+    }
+
+    async fn proxy_message(&self, message: &Message, member: MemberId, content: &str) {
+        let client = self.clients.get(&member).expect("No client for member");
+
+        if let Err(err) = self.duplicate_message(message, client, content).await {
+            match err {
+                MessageDuplicateError::MessageCreate(err) => {
+                    if err.to_string().contains("Cannot send an empty message") {
+                        client
+                            .delete_message(message.channel_id, message.id)
+                            .await
+                            .expect("Could not delete message");
+                    }
+                }
+                _ => println!("Error: {:?}", err),
+            }
+        } else {
+            client
+                .delete_message(message.channel_id, message.id)
+                .await
+                .expect("Could not delete message");
+        }
+    }
+
+    async fn duplicate_message(
+        &self,
+        message: &Message,
+        client: &Client,
+        content: &str,
+    ) -> Result<Message, MessageDuplicateError> {
+        let mut create_message = client.create_message(message.channel_id).content(content)?;
+
+        let mut allowed_mentions = AllowedMentions {
+            parse: Vec::new(),
+            replied_user: false,
+            roles: message.mention_roles.clone(),
+            users: message.mentions.iter().map(|user| user.id).collect(),
+        };
+
+        if message.mention_everyone {
+            allowed_mentions.parse.push(MentionType::Everyone);
+        }
+
+        if message.kind == MessageType::Reply {
+            if let Some(ref_message) = message.referenced_message.as_ref() {
+                create_message = create_message.reply(ref_message.id);
+
+                let pings_referenced_author = message
+                    .mentions
+                    .iter()
+                    .any(|user| user.id == ref_message.author.id);
+
+                if pings_referenced_author {
+                    allowed_mentions.replied_user = true;
+                } else {
+                    allowed_mentions.replied_user = false;
+                }
+            } else {
+                panic!("Cannot proxy message: Was reply but no referenced message");
+            }
+        }
+
+        let attachments = join_all(message.attachments.iter().map(|attachment| async {
+            let filename = attachment.filename.clone();
+            let description_opt = attachment.description.clone();
+            let bytes = reqwest::get(attachment.proxy_url.clone())
+                .await?
+                .bytes()
+                .await?;
+            let mut new_attachment =
+                Attachment::from_bytes(filename, bytes.try_into().unwrap(), attachment.id.into());
+
+            if let Some(description) = description_opt {
+                new_attachment.description(description);
+            }
+
+            Ok(new_attachment)
+        }))
+        .await
+        .iter()
+        .filter_map(
+            |result: &Result<Attachment, MessageDuplicateError>| match result {
+                Ok(attachment) => Some(attachment.clone()),
+                Err(_) => None,
+            },
+        )
+        .collect::<Vec<_>>();
+
+        if attachments.len() > 0 {
+            create_message = create_message.attachments(attachments.as_slice())?;
+        }
+
+        if let Some(flags) = message.flags {
+            create_message = create_message.flags(flags);
+        }
+
+        create_message = create_message.allowed_mentions(Some(&allowed_mentions));
+        let new_message = create_message.await?.model().await?;
+
+        Ok(new_message)
+    }
+
+    fn update_autoproxy_state_after_message(&mut self, member: MemberId, timestamp: Timestamp) {
+        match &self.config.autoproxy {
+            None => (),
+            Some(AutoproxyConfig::Member { name: _ }) => (),
+            Some(AutoproxyConfig::Latch {
+                scope,
+                timeout_seconds,
+                presence_indicator: _,
+            }) => {
+                self.latch_state = Some((member, timestamp));
+
+                if let Some(channel) = self.system_sender.clone() {
+                    let last_message = timestamp.clone();
+                    let timeout_seconds = timeout_seconds.clone();
+
+                    tokio::spawn(async move {
+                        sleep(Duration::from_secs(timeout_seconds.into())).await;
+                        channel
+                            .send(SystemEvent::AutoproxyTimeout(last_message))
+                            .await
+                            .expect("Channel has closed");
+                    });
+                }
+            }
+        }
+    }
+
+    async fn update_status_of_system(&mut self) {
+        let member_states: Vec<(MemberId, Status)> = self
+            .config
+            .members
+            .iter()
+            .enumerate()
+            .map(|(member_id, member)| {
+                (
+                    member_id,
+                    match &self.config.autoproxy {
+                        None => Status::Invisible,
+                        Some(AutoproxyConfig::Member { name }) => {
+                            if member.name == *name {
+                                Status::Online
+                            } else {
+                                Status::Invisible
+                            }
+                        }
+                        Some(AutoproxyConfig::Latch {
+                            scope,
+                            timeout_seconds: _,
+                            presence_indicator,
+                        }) => {
+                            if let AutoproxyLatchScope::Server = scope {
+                                Status::Invisible
+                            } else if !presence_indicator {
+                                Status::Invisible
+                            } else {
+                                match &self.latch_state {
+                                    Some((latch_member, _last_timestamp)) => {
+                                        if member_id == *latch_member {
+                                            Status::Online
+                                        } else {
+                                            Status::Invisible
+                                        }
+                                    }
+                                    None => Status::Invisible,
+                                }
+                            }
+                        }
+                    },
+                )
+            })
+            .collect();
+
+        for (member, status) in member_states {
+            self.update_status_of_member(member, status).await;
+        }
+    }
+
+    async fn update_status_of_member(&mut self, member: MemberId, status: Status) {
+        let last_status = *self.last_presence.get(&member).unwrap_or(&Status::Offline);
+
+        if status == last_status {
+            return;
+        }
+
+        if let Some(gateway) = self.gateways.get(&member) {
+            gateway.set_status(status).await;
+
+            self.last_presence.insert(member, status);
+        } else {
+            let full_member = self
+                .find_member_by_id(member)
+                .expect("Cannot look up member");
+            println!(
+                "Could not look up gateway for member ID {} ({})",
+                member, full_member.name
+            );
+        }
+    }
+}
+
+impl crate::config::Member {
+    pub fn matches_proxy_prefix<'a>(&self, message: &'a Message) -> Option<&'a str> {
+        match self.message_pattern.captures(message.content.as_str()) {
+            None => None,
+            Some(captures) => match captures.name("content") {
+                None => None,
+                Some(matched_content) => Some(matched_content.as_str()),
+            },
+        }
+    }
+}
+
+#[derive(Debug)]
+enum MessageDuplicateError {
+    MessageValidation(twilight_validate::message::MessageValidationError),
+    AttachmentRequest(reqwest::Error),
+    MessageCreate(twilight_http::error::Error),
+    ResponseDeserialization(twilight_http::response::DeserializeBodyError),
+}
+
+impl From<twilight_validate::message::MessageValidationError> for MessageDuplicateError {
+    fn from(value: twilight_validate::message::MessageValidationError) -> Self {
+        MessageDuplicateError::MessageValidation(value)
+    }
+}
+
+impl From<reqwest::Error> for MessageDuplicateError {
+    fn from(value: reqwest::Error) -> Self {
+        MessageDuplicateError::AttachmentRequest(value)
+    }
+}
+
+impl From<twilight_http::error::Error> for MessageDuplicateError {
+    fn from(value: twilight_http::error::Error) -> Self {
+        MessageDuplicateError::MessageCreate(value)
+    }
+}
+
+impl From<twilight_http::response::DeserializeBodyError> for MessageDuplicateError {
+    fn from(value: twilight_http::response::DeserializeBodyError) -> Self {
+        MessageDuplicateError::ResponseDeserialization(value)
+    }
+}
diff --git a/src/system/types.rs b/src/system/types.rs
new file mode 100644
index 0000000..862ddd1
--- /dev/null
+++ b/src/system/types.rs
@@ -0,0 +1,33 @@
+use twilight_model::channel::Message;
+use twilight_model::id::marker::{MessageMarker, UserMarker};
+use twilight_model::id::Id;
+use twilight_model::util::Timestamp;
+
+pub type MemberId = usize;
+pub type MessageId = Id<MessageMarker>;
+pub type UserId = Id<UserMarker>;
+
+pub type Status = twilight_model::gateway::presence::Status;
+
+pub type MessageEvent = (Timestamp, Message);
+pub type ReactionEvent = (Timestamp, ());
+pub type CommandEvent = (Timestamp, ());
+
+pub enum SystemEvent {
+    // Process of operation
+    GatewayConnected(MemberId),
+    GatewayError(MemberId, String),
+    GatewayClosed(MemberId),
+    AllGatewaysConnected,
+
+    // User event handling
+    NewMessage(MessageEvent),
+    EditedMessage(MessageEvent),
+    NewReaction(ReactionEvent),
+
+    // Command handling
+    NewCommand(CommandEvent),
+
+    // Autoproxy
+    AutoproxyTimeout(Timestamp),
+}