diff options
-rw-r--r-- | Cargo.lock | 53 | ||||
-rw-r--r-- | Cargo.toml | 2 | ||||
-rw-r--r-- | src/config.rs | 12 | ||||
-rw-r--r-- | src/main.rs | 58 | ||||
-rw-r--r-- | src/system.rs | 121 |
5 files changed, 203 insertions, 43 deletions
diff --git a/Cargo.lock b/Cargo.lock index 6daeffd..f642e3f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -18,6 +18,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" [[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] name = "aho-corasick" version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -42,6 +54,12 @@ dependencies = [ ] [[package]] +name = "allocator-api2" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c6cb57a04249c6480766f7f7cef5467412af1490f8d1e243141daddada3264f" + +[[package]] name = "autocfg" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -327,6 +345,10 @@ name = "hashbrown" version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +dependencies = [ + "ahash", + "allocator-api2", +] [[package]] name = "http" @@ -459,6 +481,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] +name = "lru" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3262e75e648fce39813cb56ac41f3c3e3f65217ebf3844d818d1f9398cfb0dc" +dependencies = [ + "hashbrown", +] + +[[package]] name = "memchr" version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -733,6 +764,7 @@ dependencies = [ name = "seance-rs" version = "0.1.0" dependencies = [ + "lru", "regex", "serde", "serde_regex", @@ -740,6 +772,7 @@ dependencies = [ "toml", "twilight-gateway", "twilight-http", + "twilight-model", ] [[package]] @@ -1523,3 +1556,23 @@ checksum = "1931d78a9c73861da0134f453bb1f790ce49b2e30eba8410b4b79bac72b46a2d" dependencies = [ "memchr", ] + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index 3ff6966..43f49c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +lru = "0.12.3" regex = "1.10.2" serde = { version = "1.0.196", features = [ "derive" ] } serde_regex = "1.1.0" @@ -13,3 +14,4 @@ tokio = { version = "1.38.0", features = [ "rt", "macros" ] } toml = "0.8.8" twilight-gateway = "0.15.4" twilight-http = "0.15.4" +twilight-model = "0.15.4" diff --git a/src/config.rs b/src/config.rs index 37a967f..26d3666 100644 --- a/src/config.rs +++ b/src/config.rs @@ -10,7 +10,7 @@ pub enum AutoProxyScope { Channel } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub enum PresenceMode { Online, Busy, @@ -18,7 +18,7 @@ pub enum PresenceMode { Invisible, } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub enum AutoproxyConfig { Member(String), Latch { @@ -28,20 +28,20 @@ pub enum AutoproxyConfig { } } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub enum AutoproxyLatchScope { Global, Server } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub struct PluralkitConfig { #[serde(with = "serde_regex")] pub message_pattern: Regex, pub api_token: String, } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub struct System { pub reference_user_id: String, pub members: Vec<Member>, @@ -55,7 +55,7 @@ fn default_forward_pings() -> bool { false } -#[derive(Deserialize)] +#[derive(Deserialize, Clone)] pub struct Member { pub name: String, #[serde(with = "serde_regex")] diff --git a/src/main.rs b/src/main.rs index 02fba86..a63c4bd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,44 +1,28 @@ mod config; -use config::Config; -use twilight_gateway::{Intents, Shard, ShardId}; +mod system; -#[tokio::main(flavor = "current_thread")] -async fn main() { - println!("Hello, world!"); +use std::thread::{self, JoinHandle}; +use tokio::runtime; - let config_str = include_str!("../config.toml"); - let config = Config::load(config_str.to_string()); - - let token = config.systems.get("ashe-test").unwrap().members - .iter().find(|member| member.name == "test").unwrap() - .discord_token.clone(); - - println!("Token: {}", token); - - let intents = Intents::GUILD_MEMBERS | Intents::GUILD_PRESENCES | Intents::GUILD_MESSAGES | Intents::MESSAGE_CONTENT; - - let mut shard = Shard::new(ShardId::ONE, token, intents); +use system::System; - loop { - let event = match shard.next_event().await { - Ok(event) => event, - Err(source) => { - println!("error receiving event"); - - if source.is_fatal() { - break; - } - - continue; - } - }; - - match event { - twilight_gateway::Event::MessageCreate(message) => println!("Message: {:?}", message), - twilight_gateway::Event::MessageUpdate(_) => println!("Message updated"), - twilight_gateway::Event::Ready(_) => println!("Bot ready!"), - _ => (), - } +fn main() { + let config_str = include_str!("../config.toml"); + let config = config::Config::load(config_str.to_string()); + + let handles : Vec<_> = config.systems.into_iter().map(|(system_name, system_config)| -> JoinHandle<()> { + println!("Starting thread for system {}", system_name); + thread::spawn(move || { + let runtime = runtime::Builder::new_current_thread().enable_all().build().expect("Could not construct Tokio runtime"); + runtime.block_on(async { + let mut system = System::new(system_config); + system.start_clients().await; + }) + }) + }).collect(); + + for thread_handle in handles.into_iter() { + thread_handle.join().expect("Child thread has panicked"); } } diff --git a/src/system.rs b/src/system.rs new file mode 100644 index 0000000..e707d7f --- /dev/null +++ b/src/system.rs @@ -0,0 +1,121 @@ +use std::{num::NonZeroUsize, str::FromStr}; + +use std::sync::Arc; +use tokio::task::JoinSet; +use tokio::sync::Mutex; +use twilight_http::Client; +use twilight_model::id::{Id, marker::{MessageMarker, UserMarker}}; +use twilight_gateway::{Intents, Shard, ShardId}; + +#[derive(Clone)] +pub struct System { + pub config: crate::config::System, + pub message_dedup_cache: Arc<Mutex<lru::LruCache<Id<MessageMarker>, ()>>>, +} + +impl System { + pub fn new(system_config: crate::config::System) -> Self { + System { + config: system_config, + message_dedup_cache: Arc::new(Mutex::new(lru::LruCache::new(NonZeroUsize::new(100).unwrap()))) + } + } + + pub async fn start_clients(&mut self) { + println!("Starting clients for system"); + + let reference_user_id : Id<UserMarker> = Id::from_str(self.config.reference_user_id.as_str()) + .expect(format!("Invalid user ID: {}", self.config.reference_user_id).as_str()); + + let mut set = JoinSet::new(); + + for member in self.config.members.iter() { + let token = member.discord_token.clone(); + let self_clone = self.clone(); + set.spawn(async move { + println!("Starting client with token: {}", token); + + let intents = Intents::GUILD_MEMBERS | Intents::GUILD_PRESENCES | Intents::GUILD_MESSAGES | Intents::MESSAGE_CONTENT; + + let mut shard = Shard::new(ShardId::ONE, token.clone(), intents); + let mut client = Client::new(token.clone()); + + loop { + let event = match shard.next_event().await { + Ok(event) => event, + Err(source) => { + println!("error receiving event"); + + if source.is_fatal() { + break; + } + + continue; + } + }; + + match event { + twilight_gateway::Event::Ready(client) => { + println!("Bot started for {}#{}", client.user.name, client.user.discriminator); + }, + + twilight_gateway::Event::MessageCreate(message) => { + if message.author.id != reference_user_id { + continue + } + + if self_clone.is_new_message(message.id).await { + self_clone.handle_message_create(message, &mut client).await; + } + }, + + twilight_gateway::Event::MessageUpdate(message) => { + if message.author.is_none() || message.author.as_ref().unwrap().id != reference_user_id { + continue + } + + if self_clone.is_new_message(message.id).await { + self_clone.handle_message_update(message, &mut client).await; + } + }, + + _ => (), + } + } + }); + } + + while let Some(join_result) = set.join_next().await { + if let Err(join_error) = join_result { + println!("Task encountered error: {}", join_error); + } else { + println!("Task joined cleanly"); + } + } + } + + async fn is_new_message(&self, message_id: Id<MessageMarker>) -> bool { + let mut message_cache = self.message_dedup_cache.lock().await; + if let None = message_cache.get(&message_id) { + message_cache.put(message_id, ()); + true + } else { + false + } + } + + async fn handle_message_create(&self, message: Box<twilight_model::gateway::payload::incoming::MessageCreate>, client: &mut Client) { + println!("Message created: {}", message.content); + + if let Err(err) = client.create_message(message.channel_id) + .reply(message.id) + .content(&format!("Recognized message from authorized user {}", message.author.name)).expect("Error: reply too long") + .await { + println!("Error: {}", err); + } + } + + async fn handle_message_update(&self, _message: Box<twilight_model::gateway::payload::incoming::MessageUpdate>, _client: &mut Client) { + // TODO: handle message edits and stuff + } +} |