diff options
author | Ashelyn Rose <git@ashen.earth> | 2024-10-05 02:20:59 -0600 |
---|---|---|
committer | Ashelyn Rose <git@ashen.earth> | 2024-10-05 02:20:59 -0600 |
commit | 8d564bb8a17b5b0251805baf20c6741d4d20e501 (patch) | |
tree | 74cd7aad01b26548a3b1e3601cd267282c629ffd | |
parent | 0053ccbb31c3b87285bf38ee3eda3308c67ad707 (diff) |
Allow individual client to disconnect and reconnect
-rw-r--r-- | src/system/aggregator.rs | 60 | ||||
-rw-r--r-- | src/system/bot/gateway.rs | 2 | ||||
-rw-r--r-- | src/system/mod.rs | 159 | ||||
-rw-r--r-- | src/system/types.rs | 2 |
4 files changed, 122 insertions, 101 deletions
diff --git a/src/system/aggregator.rs b/src/system/aggregator.rs index 6873272..00ba8e8 100644 --- a/src/system/aggregator.rs +++ b/src/system/aggregator.rs @@ -1,47 +1,67 @@ use lru::LruCache; +use std::sync::Arc; +use tokio::sync::RwLock; use std::num::NonZeroUsize; use tokio::sync::mpsc::{channel, Receiver, Sender}; use twilight_model::channel::Message as TwiMessage; use super::{Message as GatewayMessage, MessageEvent, MessageId, SystemEvent}; -pub struct MessageAggregator { +pub struct AggregatorState { rx: Receiver<MessageEvent>, tx: Sender<MessageEvent>, message_cache: lru::LruCache<MessageId, TwiMessage>, system_emitter: Option<Sender<SystemEvent>>, } +pub struct MessageAggregator { + state: Arc<RwLock<AggregatorState>>, +} + impl MessageAggregator { pub fn new() -> Self { let (tx, rx) = channel::<MessageEvent>(100); Self { - tx, - rx, - message_cache: LruCache::new(NonZeroUsize::new(100).unwrap()), - system_emitter: None, + state: Arc::new(RwLock::new( AggregatorState { + tx, + rx, + message_cache: LruCache::new(NonZeroUsize::new(100).unwrap()), + system_emitter: None, + + })) } } - pub fn get_sender(&self) -> Sender<MessageEvent> { - self.tx.clone() + pub async fn get_sender(&self) -> Sender<MessageEvent> { + self.state.read().await.tx.clone() } - pub fn set_system_handler(&mut self, emitter: Sender<SystemEvent>) -> () { - self.system_emitter = Some(emitter); + pub async fn set_system_handler(&mut self, emitter: Sender<SystemEvent>) -> () { + self.state.write().await.system_emitter = Some(emitter); } - pub fn start(mut self) -> () { + pub async fn lookup_message(&self, message_id: MessageId) -> Option<TwiMessage> { + self.state.write().await.message_cache.get(&message_id).map(|m| m.clone()) + } + + pub fn start(&self) -> () { + let state = self.state.clone(); + tokio::spawn(async move { loop { - match self.rx.recv().await { + let system_emitter = { state.read().await.system_emitter.clone().expect("No system emitter") }; + let self_emitter = { state.read().await.tx.clone() }; + let next_event = { state.write().await.rx.recv().await }; + + + match next_event { None => (), Some((timestamp, message)) => { - let system_emitter = &self.system_emitter.clone().expect("No system emitter"); match message { GatewayMessage::Partial(current_partial, member_id) => { - match self.message_cache.get(¤t_partial.id) { + let cache_content = { state.write().await.message_cache.get(¤t_partial.id).map(|m| m.clone()) }; + match cache_content { Some(original_message) => { let mut updated_message = original_message.clone(); @@ -49,11 +69,11 @@ impl MessageAggregator { updated_message.edited_timestamp = Some(edited_time); } - if let Some(content) = current_partial.content { - updated_message.content = content + if let Some(content) = ¤t_partial.content { + updated_message.content = content.clone() } - self.tx.send((timestamp, GatewayMessage::Complete(updated_message))).await; + self_emitter.send((timestamp, GatewayMessage::Complete(updated_message))).await; }, None => { system_emitter.send( @@ -63,9 +83,9 @@ impl MessageAggregator { }; }, GatewayMessage::Complete(message) => { - let previous_message = self.message_cache.get(&message.id); + let previous_message = { state.write().await.message_cache.get(&message.id).map(|m| m.clone()) }; - if let Some(previous_message) = previous_message.cloned() { + if let Some(previous_message) = previous_message { let previous_timestamp = previous_message.edited_timestamp.unwrap_or(previous_message.timestamp); let current_timestamp = message.edited_timestamp.unwrap_or(message.timestamp); @@ -77,9 +97,9 @@ impl MessageAggregator { // If not, fall through to update stored message } - self.message_cache.put(message.id, message.clone()); + { state.write().await.message_cache.put(message.id, message.clone()); }; - self.system_emitter.as_ref().expect("Aggregator has no system emitter") + system_emitter .send(SystemEvent::NewMessage(timestamp, message)) .await; }, diff --git a/src/system/bot/gateway.rs b/src/system/bot/gateway.rs index 4a83086..5a45083 100644 --- a/src/system/bot/gateway.rs +++ b/src/system/bot/gateway.rs @@ -76,7 +76,7 @@ impl Gateway { if source.is_fatal() { system_channel.send(SystemEvent::GatewayClosed(bot_conf.member_id)).await; - break; + return; } } Ok(event) => match event { diff --git a/src/system/mod.rs b/src/system/mod.rs index bd390b7..4cb99ea 100644 --- a/src/system/mod.rs +++ b/src/system/mod.rs @@ -1,7 +1,7 @@ -use std::{collections::HashMap, str::FromStr, time::Duration}; +use std::{collections::HashMap, str::FromStr, sync::Arc, time::Duration}; use tokio::{ - sync::mpsc::{channel, Sender}, + sync::{mpsc::{channel, Sender}, RwLock}, time::sleep, }; use twilight_model::id::{marker::UserMarker, Id}; @@ -23,16 +23,21 @@ pub struct Manager { pub bots: HashMap<MemberId, Bot>, pub latch_state: Option<(MemberId, Timestamp)>, pub system_sender: Option<Sender<SystemEvent>>, + pub aggregator: MessageAggregator, + pub reference_user_id: UserId, } impl Manager { pub fn new(system_name: String, system_config: crate::config::System) -> Self { Self { + reference_user_id: Id::from_str(&system_config.reference_user_id.as_str()) + .expect(format!("Invalid user id for system {}", &system_name).as_str()), name: system_name, config: system_config, bots: HashMap::new(), latch_state: None, system_sender: None, + aggregator: MessageAggregator::new(), } } @@ -59,104 +64,100 @@ impl Manager { pub async fn start_clients(&mut self) { println!("Starting clients for system {}", self.name); - let reference_user_id: Id<UserMarker> = - Id::from_str(self.config.reference_user_id.as_str()) - .expect(format!("Invalid user ID: {}", self.config.reference_user_id).as_str()); - let (system_sender, mut system_receiver) = channel::<SystemEvent>(100); self.system_sender = Some(system_sender.clone()); - let mut aggregator = MessageAggregator::new(); - aggregator.set_system_handler(system_sender.clone()); + self.aggregator.set_system_handler(system_sender.clone()).await; + self.aggregator.start(); - for (member_id, member) in self.config.members.iter().enumerate() { - // Create gateway listener - let mut bot = Bot::new(member_id, &member, reference_user_id); + for member_id in 0..self.config.members.len() { + self.start_bot(member_id).await; + } - bot.set_message_handler(aggregator.get_sender()).await; - bot.set_system_handler(system_sender.clone()).await; + loop { + match system_receiver.recv().await { + Some(SystemEvent::GatewayConnected(member_id)) => { + let member = self.find_member_by_id(member_id).unwrap(); - // Start gateway listener - bot.start(); - self.bots.insert(member_id, bot); - } + println!("Gateway client {} ({}) connected", member.name, member_id); + } - aggregator.start(); + Some(SystemEvent::GatewayError(member_id, message)) => { + let member = self.find_member_by_id(member_id).unwrap(); - let mut num_connected = 0; + println!("Gateway client {} ran into error {}", member.name, message); + } - loop { - match system_receiver.recv().await { - Some(event) => match event { - SystemEvent::GatewayConnected(member_id) => { - let member = self - .find_member_by_id(member_id) - .expect("Could not find member"); - - num_connected += 1; - println!( - "Gateway client {} ({}) connected", - num_connected, member.name - ); - - if num_connected == self.config.members.len() { - let system_sender = system_sender.clone(); - tokio::spawn(async move { - println!("All gateways connected"); - sleep(Duration::from_secs(5)).await; - let _ = system_sender.send(SystemEvent::AllGatewaysConnected).await; - }); + Some(SystemEvent::GatewayClosed(member_id)) => { + let member = self.find_member_by_id(member_id).unwrap(); + + println!("Gateway client {} closed", member.name); + + self.start_bot(member_id).await; + } + + Some(SystemEvent::NewMessage(event_time, message)) => { + self.handle_message(message, event_time).await; + } + + Some(SystemEvent::RefetchMessage(member_id, message_id, channel_id)) => { + let bot = self.bots.get(&member_id).unwrap(); + bot.refetch_message(message_id, channel_id).await; + } + + Some(SystemEvent::AutoproxyTimeout(time_scheduled)) => { + if let Some((_member, current_last_message)) = self.latch_state.clone() { + if current_last_message == time_scheduled { + println!("Autoproxy timeout has expired: {} (last sent), {} (timeout scheduled)", current_last_message.as_secs(), time_scheduled.as_secs()); + self.latch_state = None; + self.update_status_of_system().await; } } - SystemEvent::GatewayClosed(member_id) => { - let member = self - .find_member_by_id(member_id) - .expect("Could not find member"); + }, - println!("Gateway client {} closed", member.name); + Some(SystemEvent::UpdateClientStatus(member_id)) => { + let bot = self.bots.get(&member_id).unwrap(); - num_connected -= 1; - } - SystemEvent::NewMessage(event_time, message) => { - self.handle_message(message, event_time).await; - } - SystemEvent::RefetchMessage(member_id, message_id, channel_id) => { - let bot = self.bots.get(&member_id).expect("No bot"); - bot.refetch_message(message_id, channel_id).await; - } - SystemEvent::GatewayError(member_id, message) => { - let member = self - .find_member_by_id(member_id) - .expect("Could not find member"); - println!("Gateway client {} ran into error {}", member.name, message); - return; - } - SystemEvent::AutoproxyTimeout(time_scheduled) => { - if let Some((_member, current_last_message)) = self.latch_state.clone() { - if current_last_message == time_scheduled { - println!("Autoproxy timeout has expired: {} (last sent), {} (timeout scheduled)", current_last_message.as_secs(), time_scheduled.as_secs()); - self.latch_state = None; - self.update_status_of_system().await; - } + // TODO: handle other presence modes + if let Some((latched_id, _)) = self.latch_state { + if latched_id == member_id { + bot.set_status(Status::Online).await; + continue } } - SystemEvent::AllGatewaysConnected => { - println!( - "Attempting to set startup status for system {}", - self.name.clone() - ); - self.update_status_of_system().await; - } - _ => (), - }, - None => return, + + bot.set_status(Status::Invisible).await; + } + + _ => continue, } } } + async fn start_bot(&mut self, member_id: MemberId) { + let member = self.find_member_by_id(member_id).unwrap(); + + // Create gateway listener + let mut bot = Bot::new(member_id, &member, self.reference_user_id); + + bot.set_message_handler(self.aggregator.get_sender().await).await; + bot.set_system_handler(self.system_sender.as_ref().unwrap().clone()).await; + + // Start gateway listener + bot.start(); + self.bots.insert(member_id, bot); + + // Schedule status update after a few seconds + let rx = self.system_sender.as_ref().unwrap().clone(); + tokio::spawn(async move { + sleep(Duration::from_secs(10)).await; + let _ = rx.send(SystemEvent::UpdateClientStatus(member_id)).await; + }); + } + async fn handle_message(&mut self, message: TwiMessage, timestamp: Timestamp) { // TODO: Commands if message.content.eq("!panic") { - panic!("Exiting due to user command"); + self.bots.iter_mut().next().unwrap().1.shutdown().await; } // Escape sequence diff --git a/src/system/types.rs b/src/system/types.rs index 41cb10d..483fbbd 100644 --- a/src/system/types.rs +++ b/src/system/types.rs @@ -27,8 +27,8 @@ pub enum SystemEvent { GatewayConnected(MemberId), GatewayError(MemberId, String), GatewayClosed(MemberId), - AllGatewaysConnected, RefetchMessage(MemberId, MessageId, ChannelId), + UpdateClientStatus(MemberId), // User event handling NewMessage(Timestamp, FullMessage), |