use std::{collections::HashMap, sync::Arc}; use serde::Deserialize; use tauri::async_runtime::RuntimeHandle; use tokio::sync::Mutex; use tokio::sync::oneshot::{channel, Sender, Receiver}; use url::Url; use uuid::Uuid; use crate::persistence::PersistenceController; use crate::OAUTH_CLIENT_NAME; #[derive(Clone)] pub struct OAuthController { int: Arc>, runtime: RuntimeHandle, persistence: PersistenceController, } type ServerDomain = String; type StateCode = Uuid; struct OAuthInternal { open_callbacks: HashMap } struct Callback { server_domain: ServerDomain, pkce_verifier: String, pkce_challenge: String, code_channel: (Option>, Option>), } struct AuthSession { api_token: String, refresh_token: Option, } pub struct AddServerResult { pub auth_url: String, pub auth_state: StateCode, } pub struct SigninResult { pub server_domain: String, pub username: String, } impl OAuthController { pub fn new(runtime_handle: RuntimeHandle, persistence_controller: PersistenceController) -> Self { Self { int: Arc::new(Mutex::new(OAuthInternal { open_callbacks: HashMap::new(), })), runtime: runtime_handle, persistence: persistence_controller, } } pub fn handle_deeplink(&self, urls: &Vec) -> Option<()> { let matching_url = urls.iter().find(|url| url.domain().is_some_and(|d| d == "oauth-response")); let mut query_pairs = matching_url?.query_pairs(); let code_and_state = ( query_pairs.find(|(key, _)| key == "code").map(|(_, value)| value ), query_pairs.find(|(key, _)| key == "state").map(|(_, value)| value ) ); let has_code_and_state = match code_and_state { (Some(code), Some(state)) => Some((code.to_string(), state.to_string())), _ => None, }; let (auth_code, state) = has_code_and_state?; let state = Uuid::try_parse(&state).map_err(|_| { println!("Could not parse state field from oauth callback: {state}"); () }).ok()?; self.resolve_code(state, auth_code); Some(()) } pub async fn start_authorization(&self, instance_domain: &str) -> Result { let registration_endpoint = format!("https://{instance_domain}/api/v1/apps"); let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); let server_client = match self.persistence.get_server(&instance_domain.to_string()).await { Some(client) => client, None => { println!("Registering a new client for {instance_domain}"); #[derive(Deserialize)] struct RegistrationResponse { id: String, name: String, client_id: String, client_secret: String, } let registration_response : RegistrationResponse = http_client.post(registration_endpoint) .json(&HashMap::from([ ("client_name", OAUTH_CLIENT_NAME), ("redirect_uris", "dev.tempest.foxfleet://oauth-response"), ("scopes", "read write"), ])).send().await.expect("Could not send client registration") .json().await.expect("Could not parse client registration response"); self.persistence.new_server(instance_domain.to_string(), registration_response.client_id.clone(), registration_response.name.to_string(), registration_response.client_secret).await?; self.persistence.get_server(&instance_domain.to_string()).await.unwrap() }, }; let (sender, receiver) = channel::(); // TODO: PKCE params for real let auth_state = Uuid::new_v4(); let auth_callback = Callback { server_domain: instance_domain.to_string(), pkce_verifier: String::new(), pkce_challenge: String::new(), code_channel: (Some(sender), Some(receiver)), }; let client_id = server_client.client_id; {self.int.lock().await.open_callbacks.insert(auth_state.clone(), auth_callback)}; let auth_url = format!("https://{instance_domain}/oauth/authorize?client_id={client_id}&redirect_uri=dev.tempest.foxfleet://oauth-response&response_type=code&scope=read+write&state={auth_state}"); return Ok(AddServerResult { auth_url, auth_state, }) } fn resolve_code(&self, state: StateCode, auth_code: String) { let runtime = self.runtime.clone(); let inner_self = self.int.clone(); runtime.spawn(async move { if let Some(sender) = inner_self.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.0.take() ).flatten() { let _ = sender.send(auth_code); }; }); } pub async fn finish_authorization(&self, state: StateCode) -> Result { let domain = {self.int.lock().await.open_callbacks.get(&state).map(|cb| cb.server_domain.clone())} .ok_or(format!("No callback code {state}"))?; let auth_code = self.await_auth_code(&state).await?; let auth_session = self.exchange_api_token(&state, &auth_code).await?; let username = self.resolve_account(&state, &auth_session).await?; self.persistence.new_account(username.clone(), domain.clone(), auth_session.api_token).await?; // Remove state callback record {self.int.lock().await.open_callbacks.remove(&state)}; return Ok(SigninResult { server_domain: domain, username }) } async fn await_auth_code(&self, state: &StateCode) -> Result { let maybe_receiver = {self.int.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.1.take() ).flatten()}; if let Some(receiver) = maybe_receiver { if let Ok(auth_code) = receiver.await { Ok(auth_code) } else { Err("Error receiving code from channel".to_string()) } } else { Err("Channel already awaited".to_string()) } } // TODO: Send PKCE stuff async fn exchange_api_token(&self, state: &StateCode, auth_code: &String) -> Result { let server_domain = {self.int.lock().await.open_callbacks.get(&state).map(|cb| cb.server_domain.clone())} .ok_or(format!("No callback code {state}"))?; let server_client = self.persistence.get_server(&server_domain).await .ok_or("Could not look up client for server")?; let client_secret = self.persistence.get_credential(server_client.client_credential_id).await?; #[derive(Deserialize)] struct TokenResponse { access_token: String, created_at: u32, scope: String, token_type: String, } let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); let token_endpoint = format!("https://{server_domain}/oauth/token"); let token_response : TokenResponse = http_client.post(token_endpoint) .json(&HashMap::from([ ("redirect_uri", "dev.tempest.foxfleet://oauth-response"), ("client_id", server_client.client_id.as_str()), ("client_secret", client_secret.as_str()), ("grant_type", "authorization_code"), ("code", auth_code.as_str()), ])).send().await.expect("Could not get API token") .json().await.expect("Could not parse client registration response"); Ok(AuthSession { api_token: token_response.access_token, refresh_token: None, }) } async fn resolve_account(&self, state: &StateCode, auth_session: &AuthSession) -> Result { let instance_url = { self.int.lock().await.open_callbacks.get(&state).map(|callback| callback.server_domain.clone()) }; if let Some(instance_url) = instance_url { #[derive(Deserialize)] struct VerifyResponse { acct: String, } let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); let verify_response : VerifyResponse = http_client.get(format!("https://{instance_url}/api/v1/accounts/verify_credentials")) .bearer_auth(auth_session.api_token.clone()) .send().await.expect("Could not look up account information") .json().await.expect("Could not parse account information"); Ok(verify_response.acct) } else { Err("No instance URL".to_string()) } } }