summary refs log tree commit diff
path: root/src/system
diff options
context:
space:
mode:
authorAshelyn Rose <git@ashen.earth>2024-10-05 02:20:59 -0600
committerAshelyn Rose <git@ashen.earth>2024-10-05 02:20:59 -0600
commit8d564bb8a17b5b0251805baf20c6741d4d20e501 (patch)
tree74cd7aad01b26548a3b1e3601cd267282c629ffd /src/system
parent0053ccbb31c3b87285bf38ee3eda3308c67ad707 (diff)
Allow individual client to disconnect and reconnect
Diffstat (limited to 'src/system')
-rw-r--r--src/system/aggregator.rs60
-rw-r--r--src/system/bot/gateway.rs2
-rw-r--r--src/system/mod.rs159
-rw-r--r--src/system/types.rs2
4 files changed, 122 insertions, 101 deletions
diff --git a/src/system/aggregator.rs b/src/system/aggregator.rs
index 6873272..00ba8e8 100644
--- a/src/system/aggregator.rs
+++ b/src/system/aggregator.rs
@@ -1,47 +1,67 @@
 use lru::LruCache;
+use std::sync::Arc;
+use tokio::sync::RwLock;
 use std::num::NonZeroUsize;
 use tokio::sync::mpsc::{channel, Receiver, Sender};
 use twilight_model::channel::Message as TwiMessage;
 
 use super::{Message as GatewayMessage, MessageEvent, MessageId, SystemEvent};
 
-pub struct MessageAggregator {
+pub struct AggregatorState {
     rx: Receiver<MessageEvent>,
     tx: Sender<MessageEvent>,
     message_cache: lru::LruCache<MessageId, TwiMessage>,
     system_emitter: Option<Sender<SystemEvent>>,
 }
 
+pub struct MessageAggregator {
+    state: Arc<RwLock<AggregatorState>>,
+}
+
 impl MessageAggregator {
     pub fn new() -> Self {
         let (tx, rx) = channel::<MessageEvent>(100);
 
         Self {
-            tx,
-            rx,
-            message_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
-            system_emitter: None,
+            state: Arc::new(RwLock::new( AggregatorState {
+                tx,
+                rx,
+                message_cache: LruCache::new(NonZeroUsize::new(100).unwrap()),
+                system_emitter: None,
+
+            }))
         }
     }
 
-    pub fn get_sender(&self) -> Sender<MessageEvent> {
-        self.tx.clone()
+    pub async fn get_sender(&self) -> Sender<MessageEvent> {
+        self.state.read().await.tx.clone()
     }
 
-    pub fn set_system_handler(&mut self, emitter: Sender<SystemEvent>) -> () {
-        self.system_emitter = Some(emitter);
+    pub async fn set_system_handler(&mut self, emitter: Sender<SystemEvent>) -> () {
+        self.state.write().await.system_emitter = Some(emitter);
     }
 
-    pub fn start(mut self) -> () {
+    pub async fn lookup_message(&self, message_id: MessageId) -> Option<TwiMessage> {
+        self.state.write().await.message_cache.get(&message_id).map(|m| m.clone())
+    }
+
+    pub fn start(&self) -> () {
+        let state = self.state.clone();
+
         tokio::spawn(async move {
             loop {
-                match self.rx.recv().await {
+                let system_emitter = { state.read().await.system_emitter.clone().expect("No system emitter") };
+                let self_emitter = { state.read().await.tx.clone() };
+                let next_event = { state.write().await.rx.recv().await };
+
+
+                match next_event {
                     None => (),
                     Some((timestamp, message)) => {
-                        let system_emitter = &self.system_emitter.clone().expect("No system emitter");
                         match message {
                             GatewayMessage::Partial(current_partial, member_id) => {
-                                match self.message_cache.get(&current_partial.id) {
+                                let cache_content = { state.write().await.message_cache.get(&current_partial.id).map(|m| m.clone()) };
+                                match cache_content {
                                     Some(original_message) => {
 
                                         let mut updated_message = original_message.clone();
@@ -49,11 +69,11 @@ impl MessageAggregator {
                                             updated_message.edited_timestamp = Some(edited_time);
                                         }
 
-                                        if let Some(content) = current_partial.content {
-                                            updated_message.content = content
+                                        if let Some(content) = &current_partial.content {
+                                            updated_message.content = content.clone()
                                         }
 
-                                        self.tx.send((timestamp, GatewayMessage::Complete(updated_message))).await;
+                                        self_emitter.send((timestamp, GatewayMessage::Complete(updated_message))).await;
                                     },
                                     None => {
                                         system_emitter.send(
@@ -63,9 +83,9 @@ impl MessageAggregator {
                                 };
                             },
                             GatewayMessage::Complete(message) => {
-                                let previous_message = self.message_cache.get(&message.id);
+                                let previous_message = { state.write().await.message_cache.get(&message.id).map(|m| m.clone()) };
 
-                                if let Some(previous_message) = previous_message.cloned() {
+                                if let Some(previous_message) = previous_message {
                                     let previous_timestamp = previous_message.edited_timestamp.unwrap_or(previous_message.timestamp);
                                     let current_timestamp = message.edited_timestamp.unwrap_or(message.timestamp);
 
@@ -77,9 +97,9 @@ impl MessageAggregator {
                                     // If not, fall through to update stored message
                                 }
 
-                                self.message_cache.put(message.id, message.clone());
+                                { state.write().await.message_cache.put(message.id, message.clone()); };
 
-                                self.system_emitter.as_ref().expect("Aggregator has no system emitter")
+                                system_emitter
                                     .send(SystemEvent::NewMessage(timestamp, message))
                                     .await;
                             },
diff --git a/src/system/bot/gateway.rs b/src/system/bot/gateway.rs
index 4a83086..5a45083 100644
--- a/src/system/bot/gateway.rs
+++ b/src/system/bot/gateway.rs
@@ -76,7 +76,7 @@ impl Gateway {
 
                         if source.is_fatal() {
                             system_channel.send(SystemEvent::GatewayClosed(bot_conf.member_id)).await;
-                            break;
+                            return;
                         }
                     }
                     Ok(event) => match event {
diff --git a/src/system/mod.rs b/src/system/mod.rs
index bd390b7..4cb99ea 100644
--- a/src/system/mod.rs
+++ b/src/system/mod.rs
@@ -1,7 +1,7 @@
-use std::{collections::HashMap, str::FromStr, time::Duration};
+use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration};
 
 use tokio::{
-    sync::mpsc::{channel, Sender},
+    sync::{mpsc::{channel, Sender}, RwLock},
     time::sleep,
 };
 use twilight_model::id::{marker::UserMarker, Id};
@@ -23,16 +23,21 @@ pub struct Manager {
     pub bots: HashMap<MemberId, Bot>,
     pub latch_state: Option<(MemberId, Timestamp)>,
     pub system_sender: Option<Sender<SystemEvent>>,
+    pub aggregator: MessageAggregator,
+    pub reference_user_id: UserId,
 }
 
 impl Manager {
     pub fn new(system_name: String, system_config: crate::config::System) -> Self {
         Self {
+            reference_user_id: Id::from_str(&system_config.reference_user_id.as_str())
+                .expect(format!("Invalid user id for system {}", &system_name).as_str()),
             name: system_name,
             config: system_config,
             bots: HashMap::new(),
             latch_state: None,
             system_sender: None,
+            aggregator: MessageAggregator::new(),
         }
     }
 
@@ -59,104 +64,100 @@ impl Manager {
     pub async fn start_clients(&mut self) {
         println!("Starting clients for system {}", self.name);
 
-        let reference_user_id: Id<UserMarker> =
-            Id::from_str(self.config.reference_user_id.as_str())
-                .expect(format!("Invalid user ID: {}", self.config.reference_user_id).as_str());
-
         let (system_sender, mut system_receiver) = channel::<SystemEvent>(100);
         self.system_sender = Some(system_sender.clone());
-        let mut aggregator = MessageAggregator::new();
-        aggregator.set_system_handler(system_sender.clone());
+        self.aggregator.set_system_handler(system_sender.clone()).await;
+        self.aggregator.start();
 
-        for (member_id, member) in self.config.members.iter().enumerate() {
-            // Create gateway listener
-            let mut bot = Bot::new(member_id, &member, reference_user_id);
+        for member_id in 0..self.config.members.len() {
+            self.start_bot(member_id).await;
+        }
 
-            bot.set_message_handler(aggregator.get_sender()).await;
-            bot.set_system_handler(system_sender.clone()).await;
+        loop {
+            match system_receiver.recv().await {
+                Some(SystemEvent::GatewayConnected(member_id)) => {
+                    let member = self.find_member_by_id(member_id).unwrap();
 
-            // Start gateway listener
-            bot.start();
-            self.bots.insert(member_id, bot);
-        }
+                    println!("Gateway client {} ({}) connected", member.name, member_id);
+                }
 
-        aggregator.start();
+                Some(SystemEvent::GatewayError(member_id, message)) => {
+                    let member = self.find_member_by_id(member_id).unwrap();
 
-        let mut num_connected = 0;
+                    println!("Gateway client {} ran into error {}", member.name, message);
+                }
 
-        loop {
-            match system_receiver.recv().await {
-                Some(event) => match event {
-                    SystemEvent::GatewayConnected(member_id) => {
-                        let member = self
-                            .find_member_by_id(member_id)
-                            .expect("Could not find member");
-
-                        num_connected += 1;
-                        println!(
-                            "Gateway client {} ({}) connected",
-                            num_connected, member.name
-                        );
-
-                        if num_connected == self.config.members.len() {
-                            let system_sender = system_sender.clone();
-                            tokio::spawn(async move {
-                                println!("All gateways connected");
-                                sleep(Duration::from_secs(5)).await;
-                                let _ = system_sender.send(SystemEvent::AllGatewaysConnected).await;
-                            });
+                Some(SystemEvent::GatewayClosed(member_id)) => {
+                    let member = self.find_member_by_id(member_id).unwrap();
+
+                    println!("Gateway client {} closed", member.name);
+
+                    self.start_bot(member_id).await;
+                }
+
+                Some(SystemEvent::NewMessage(event_time, message)) => {
+                    self.handle_message(message, event_time).await;
+                }
+
+                Some(SystemEvent::RefetchMessage(member_id, message_id, channel_id)) => {
+                    let bot = self.bots.get(&member_id).unwrap();
+                    bot.refetch_message(message_id, channel_id).await;
+                }
+
+                Some(SystemEvent::AutoproxyTimeout(time_scheduled)) => {
+                    if let Some((_member, current_last_message)) = self.latch_state.clone() {
+                        if current_last_message == time_scheduled {
+                            println!("Autoproxy timeout has expired: {} (last sent), {} (timeout scheduled)", current_last_message.as_secs(), time_scheduled.as_secs());
+                            self.latch_state = None;
+                            self.update_status_of_system().await;
                         }
                     }
-                    SystemEvent::GatewayClosed(member_id) => {
-                        let member = self
-                            .find_member_by_id(member_id)
-                            .expect("Could not find member");
+                },
 
-                        println!("Gateway client {} closed", member.name);
+                Some(SystemEvent::UpdateClientStatus(member_id)) => {
+                    let bot = self.bots.get(&member_id).unwrap();
 
-                        num_connected -= 1;
-                    }
-                    SystemEvent::NewMessage(event_time, message) => {
-                        self.handle_message(message, event_time).await;
-                    }
-                    SystemEvent::RefetchMessage(member_id, message_id, channel_id) => {
-                        let bot = self.bots.get(&member_id).expect("No bot");
-                        bot.refetch_message(message_id, channel_id).await;
-                    }
-                    SystemEvent::GatewayError(member_id, message) => {
-                        let member = self
-                            .find_member_by_id(member_id)
-                            .expect("Could not find member");
-                        println!("Gateway client {} ran into error {}", member.name, message);
-                        return;
-                    }
-                    SystemEvent::AutoproxyTimeout(time_scheduled) => {
-                        if let Some((_member, current_last_message)) = self.latch_state.clone() {
-                            if current_last_message == time_scheduled {
-                                println!("Autoproxy timeout has expired: {} (last sent), {} (timeout scheduled)", current_last_message.as_secs(), time_scheduled.as_secs());
-                                self.latch_state = None;
-                                self.update_status_of_system().await;
-                            }
+                    // TODO: handle other presence modes
+                    if let Some((latched_id, _)) = self.latch_state {
+                        if latched_id == member_id {
+                            bot.set_status(Status::Online).await;
+                            continue
                         }
                     }
-                    SystemEvent::AllGatewaysConnected => {
-                        println!(
-                            "Attempting to set startup status for system {}",
-                            self.name.clone()
-                        );
-                        self.update_status_of_system().await;
-                    }
-                    _ => (),
-                },
-                None => return,
+
+                    bot.set_status(Status::Invisible).await;
+                }
+
+                _ => continue,
             }
         }
     }
 
+    async fn start_bot(&mut self, member_id: MemberId) {
+        let member = self.find_member_by_id(member_id).unwrap();
+
+        // Create gateway listener
+        let mut bot = Bot::new(member_id, &member, self.reference_user_id);
+
+        bot.set_message_handler(self.aggregator.get_sender().await).await;
+        bot.set_system_handler(self.system_sender.as_ref().unwrap().clone()).await;
+
+        // Start gateway listener
+        bot.start();
+        self.bots.insert(member_id, bot);
+
+        // Schedule status update after a few seconds
+        let rx = self.system_sender.as_ref().unwrap().clone();
+        tokio::spawn(async move {
+            sleep(Duration::from_secs(10)).await;
+            let _ = rx.send(SystemEvent::UpdateClientStatus(member_id)).await;
+        });
+    }
+
     async fn handle_message(&mut self, message: TwiMessage, timestamp: Timestamp) {
         // TODO: Commands
         if message.content.eq("!panic") {
-            panic!("Exiting due to user command");
+           self.bots.iter_mut().next().unwrap().1.shutdown().await;
         }
 
         // Escape sequence
diff --git a/src/system/types.rs b/src/system/types.rs
index 41cb10d..483fbbd 100644
--- a/src/system/types.rs
+++ b/src/system/types.rs
@@ -27,8 +27,8 @@ pub enum SystemEvent {
     GatewayConnected(MemberId),
     GatewayError(MemberId, String),
     GatewayClosed(MemberId),
-    AllGatewaysConnected,
     RefetchMessage(MemberId, MessageId, ChannelId),
+    UpdateClientStatus(MemberId),
 
     // User event handling
     NewMessage(Timestamp, FullMessage),