use std::{collections::HashMap, sync::Arc}; use serde::Deserialize; use tokio::sync::Mutex; use tokio::sync::oneshot::{channel, Sender, Receiver}; use uuid::Uuid; use crate::OAUTH_CLIENT_NAME; #[derive(Clone)] pub struct OAuthController (Arc>); type ServerDomain = String; type AccountIdentifier = String; type StateCode = Uuid; struct OAuthInternal { servers: HashMap, open_callbacks: HashMap } #[derive(Clone)] struct Server { client_name: String, client_id: String, client_secret: String, accounts: Vec, } struct Callback { server: ServerDomain, pkce_verifier: String, pkce_challenge: String, code_channel: (Option>, Option>), } #[derive(Clone)] struct Account { username: String, auth_session: AuthSession, } #[derive(Clone)] pub struct AuthSession { pub api_token: String, refresh_token: Option, } pub struct AddServerResult { pub auth_url: String, pub auth_state: StateCode, } impl OAuthController { pub fn new() -> Self { Self(Arc::new(Mutex::new(OAuthInternal { servers: HashMap::new(), open_callbacks: HashMap::new(), }))) } pub async fn add_server(&self, instance_domain: &str) -> Result { let registration_endpoint = format!("https://{instance_domain}/api/v1/apps"); let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); #[derive(Deserialize)] struct RegistrationResponse { id: String, name: String, client_id: String, client_secret: String, } let registration_response : RegistrationResponse = http_client.post(registration_endpoint) .json(&HashMap::from([ ("client_name", OAUTH_CLIENT_NAME), ("redirect_uris", "dev.tempest.foxfleet://oauth-response"), ("scopes", "read write"), ])).send().await.expect("Could not send client registration") .json().await.expect("Could not parse client registration response"); let server = Server { client_name: registration_response.name, client_id: registration_response.client_id.clone(), client_secret: registration_response.client_secret, accounts: Vec::new(), }; {self.0.lock().await.servers.insert(instance_domain.to_string(), server)}; let (sender, receiver) = channel::(); // TODO: PKCE params for real let auth_state = Uuid::new_v4(); let auth_callback = Callback { server: instance_domain.to_string(), pkce_verifier: String::new(), pkce_challenge: String::new(), code_channel: (Some(sender), Some(receiver)), }; let client_id = registration_response.client_id; {self.0.lock().await.open_callbacks.insert(auth_state.clone(), auth_callback)}; let auth_url = format!("https://{instance_domain}/oauth/authorize?client_id={client_id}&redirect_uri=dev.tempest.foxfleet://oauth-response&response_type=code&scope=read+write&state={auth_state}"); return Ok(AddServerResult { auth_url, auth_state, }) } pub fn resolve_code(&self, state: StateCode, auth_code: String) { let runtime = tokio::runtime::Handle::current(); let inner_self = self.0.clone(); let state = state.clone(); runtime.spawn(async move { if let Some(sender) = inner_self.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.0.take() ).flatten() { let _ = sender.send(auth_code); }; }); } pub async fn finish_signin(&self, state: StateCode) -> Result<(String, String), String> { let auth_code = self.await_auth_code(&state).await?; let auth_session = self.exchange_api_token(&state, &auth_code).await?; let username = self.resolve_account(&state, &auth_session).await?; let account = Account { username: username.clone(), auth_session }; let mut inner_state = self.0.lock().await; if let Some(callback) = inner_state.open_callbacks.remove(&state) { if let Some (server) = inner_state.servers.get_mut(&callback.server) { server.accounts.push(account); return Ok((callback.server.clone(), username)) } else { Err("Unknown server URL".to_string()) } } else { Err("Unknown state".to_string()) } } async fn await_auth_code(&self, state: &StateCode) -> Result { let maybe_receiver = {self.0.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.1.take() ).flatten()}; if let Some(receiver) = maybe_receiver { if let Ok(auth_code) = receiver.await { Ok(auth_code) } else { Err("Error receiving code from channel".to_string()) } } else { Err("Channel already awaited".to_string()) } } // TODO: Send PKCE stuff async fn exchange_api_token(&self, state: &StateCode, auth_code: &String) -> Result { let maybe_server = { let inner_state = self.0.lock().await; if let Some(server_domain) = inner_state.open_callbacks.get(&state).map(|callback| callback.server.clone()) { if let Some(server) = inner_state.servers.get(&server_domain) { Some((server_domain, server.clone())) } else { None } } else { None } }; if maybe_server.is_none() { return Err("Unknown auth state".to_string()); } let (server_domain, server) = maybe_server.unwrap(); #[derive(Deserialize)] struct TokenResponse { access_token: String, created_at: u32, scope: String, token_type: String, } let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); let token_endpoint = format!("https://{server_domain}/oauth/token"); let token_response : TokenResponse = http_client.post(token_endpoint) .json(&HashMap::from([ ("redirect_uri", "dev.tempest.foxfleet://oauth-response"), ("client_id", server.client_id.as_str()), ("client_secret", server.client_secret.as_str()), ("grant_type", "authorization_code"), ("code", auth_code.as_str()), ])).send().await.expect("Could not get API token") .json().await.expect("Could not parse client registration response"); Ok(AuthSession { api_token: token_response.access_token, refresh_token: None, }) } async fn resolve_account(&self, state: &StateCode, auth_session: &AuthSession) -> Result { let instance_url = { self.0.lock().await.open_callbacks.get(&state).map(|callback| callback.server.clone()) }; if let Some(instance_url) = instance_url { #[derive(Deserialize)] struct VerifyResponse { acct: String, } let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); let verify_response : VerifyResponse = http_client.get(format!("https://{instance_url}/api/v1/accounts/verify_credentials")) .bearer_auth(auth_session.api_token.clone()) .send().await.expect("Could not look up account information") .json().await.expect("Could not parse account information"); Ok(verify_response.acct) } else { Err("No instance URL".to_string()) } } pub async fn get_api_token(&self, server_domain: String, username: String) -> Result { let api_token_maybe = { self.0.lock().await.servers.get(&server_domain).map(|server| { server.accounts.iter().find(|account| account.username == username).map(|account| account.auth_session.api_token.clone()) }).flatten()}; match api_token_maybe { Some(api_token) => Ok(api_token), None => Err("Could not look up {server_domain} / {username}".to_string()), } } }