diff options
author | Ashelyn Dawn <git@ashen.earth> | 2024-07-02 21:28:21 -0600 |
---|---|---|
committer | Ashelyn Rose <git@ashen.earth> | 2024-07-02 22:11:54 -0600 |
commit | a54f77766246f4ce418447cc4d37295c15065b39 (patch) | |
tree | aa9b73592cf36ff56cb25e7dc91f7d1099d1106b /src | |
parent | 5dc71ca03b5402c0a284e25492e63c696f7bdec6 (diff) |
multithreading and message filtering from multiple clients
Diffstat (limited to 'src')
-rw-r--r-- | src/config.rs | 12 | ||||
-rw-r--r-- | src/main.rs | 58 | ||||
-rw-r--r-- | src/system.rs | 121 |
3 files changed, 148 insertions, 43 deletions
diff --git a/src/config.rs b/src/config.rs index 37a967f..26d3666 100644 --- a/src/config.rs +++ b/src/config.rs @@ -10,7 +10,7 @@ pub enum AutoProxyScope { Channel } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub enum PresenceMode { Online, Busy, @@ -18,7 +18,7 @@ pub enum PresenceMode { Invisible, } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub enum AutoproxyConfig { Member(String), Latch { @@ -28,20 +28,20 @@ pub enum AutoproxyConfig { } } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub enum AutoproxyLatchScope { Global, Server } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub struct PluralkitConfig { #[serde(with = "serde_regex")] pub message_pattern: Regex, pub api_token: String, } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub struct System { pub reference_user_id: String, pub members: Vec<Member>, @@ -55,7 +55,7 @@ fn default_forward_pings() -> bool { false } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub struct Member { pub name: String, #[serde(with = "serde_regex")] diff --git a/src/main.rs b/src/main.rs index 02fba86..a63c4bd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,44 +1,28 @@ mod config; -use config::Config; -use twilight_gateway::{Intents, Shard, ShardId}; +mod system; -#[tokio::main(flavor = "current_thread")] -async fn main() { - println!("Hello, world!"); +use std::thread::{self, JoinHandle}; +use tokio::runtime; - let config_str = include_str!("../config.toml"); - let config = Config::load(config_str.to_string()); - - let token = config.systems.get("ashe-test").unwrap().members - .iter().find(|member| member.name == "test").unwrap() - .discord_token.clone(); - - println!("Token: {}", token); - - let intents = Intents::GUILD_MEMBERS | Intents::GUILD_PRESENCES | Intents::GUILD_MESSAGES | Intents::MESSAGE_CONTENT; - - let mut shard = Shard::new(ShardId::ONE, token, intents); +use system::System; - loop { - let event = match shard.next_event().await { - Ok(event) => event, - Err(source) => { - println!("error receiving event"); - - if source.is_fatal() { - break; - } - - continue; - } - }; - - match event { - twilight_gateway::Event::MessageCreate(message) => println!("Message: {:?}", message), - twilight_gateway::Event::MessageUpdate(_) => println!("Message updated"), - twilight_gateway::Event::Ready(_) => println!("Bot ready!"), - _ => (), - } +fn main() { + let config_str = include_str!("../config.toml"); + let config = config::Config::load(config_str.to_string()); + + let handles : Vec<_> = config.systems.into_iter().map(|(system_name, system_config)| -> JoinHandle<()> { + println!("Starting thread for system {}", system_name); + thread::spawn(move || { + let runtime = runtime::Builder::new_current_thread().enable_all().build().expect("Could not construct Tokio runtime"); + runtime.block_on(async { + let mut system = System::new(system_config); + system.start_clients().await; + }) + }) + }).collect(); + + for thread_handle in handles.into_iter() { + thread_handle.join().expect("Child thread has panicked"); } } diff --git a/src/system.rs b/src/system.rs new file mode 100644 index 0000000..e707d7f --- /dev/null +++ b/src/system.rs @@ -0,0 +1,121 @@ +use std::{num::NonZeroUsize, str::FromStr}; + +use std::sync::Arc; +use tokio::task::JoinSet; +use tokio::sync::Mutex; +use twilight_http::Client; +use twilight_model::id::{Id, marker::{MessageMarker, UserMarker}}; +use twilight_gateway::{Intents, Shard, ShardId}; + +#[derive(Clone)] +pub struct System { + pub config: crate::config::System, + pub message_dedup_cache: Arc<Mutex<lru::LruCache<Id<MessageMarker>, ()>>>, +} + +impl System { + pub fn new(system_config: crate::config::System) -> Self { + System { + config: system_config, + message_dedup_cache: Arc::new(Mutex::new(lru::LruCache::new(NonZeroUsize::new(100).unwrap()))) + } + } + + pub async fn start_clients(&mut self) { + println!("Starting clients for system"); + + let reference_user_id : Id<UserMarker> = Id::from_str(self.config.reference_user_id.as_str()) + .expect(format!("Invalid user ID: {}", self.config.reference_user_id).as_str()); + + let mut set = JoinSet::new(); + + for member in self.config.members.iter() { + let token = member.discord_token.clone(); + let self_clone = self.clone(); + set.spawn(async move { + println!("Starting client with token: {}", token); + + let intents = Intents::GUILD_MEMBERS | Intents::GUILD_PRESENCES | Intents::GUILD_MESSAGES | Intents::MESSAGE_CONTENT; + + let mut shard = Shard::new(ShardId::ONE, token.clone(), intents); + let mut client = Client::new(token.clone()); + + loop { + let event = match shard.next_event().await { + Ok(event) => event, + Err(source) => { + println!("error receiving event"); + + if source.is_fatal() { + break; + } + + continue; + } + }; + + match event { + twilight_gateway::Event::Ready(client) => { + println!("Bot started for {}#{}", client.user.name, client.user.discriminator); + }, + + twilight_gateway::Event::MessageCreate(message) => { + if message.author.id != reference_user_id { + continue + } + + if self_clone.is_new_message(message.id).await { + self_clone.handle_message_create(message, &mut client).await; + } + }, + + twilight_gateway::Event::MessageUpdate(message) => { + if message.author.is_none() || message.author.as_ref().unwrap().id != reference_user_id { + continue + } + + if self_clone.is_new_message(message.id).await { + self_clone.handle_message_update(message, &mut client).await; + } + }, + + _ => (), + } + } + }); + } + + while let Some(join_result) = set.join_next().await { + if let Err(join_error) = join_result { + println!("Task encountered error: {}", join_error); + } else { + println!("Task joined cleanly"); + } + } + } + + async fn is_new_message(&self, message_id: Id<MessageMarker>) -> bool { + let mut message_cache = self.message_dedup_cache.lock().await; + if let None = message_cache.get(&message_id) { + message_cache.put(message_id, ()); + true + } else { + false + } + } + + async fn handle_message_create(&self, message: Box<twilight_model::gateway::payload::incoming::MessageCreate>, client: &mut Client) { + println!("Message created: {}", message.content); + + if let Err(err) = client.create_message(message.channel_id) + .reply(message.id) + .content(&format!("Recognized message from authorized user {}", message.author.name)).expect("Error: reply too long") + .await { + println!("Error: {}", err); + } + } + + async fn handle_message_update(&self, _message: Box<twilight_model::gateway::payload::incoming::MessageUpdate>, _client: &mut Client) { + // TODO: handle message edits and stuff + } +} |