diff options
author | Ashelyn Rose <git@ashen.earth> | 2025-02-19 15:51:28 -0700 |
---|---|---|
committer | Ashelyn Rose <git@ashen.earth> | 2025-02-19 15:51:28 -0700 |
commit | e8314458ccdf4d3c68969b206cd29f2490fb6308 (patch) | |
tree | 3eac29d969edc813ae407f0b2e114b033c5d8164 /app | |
parent | 5e8d3bc7008d29115bc520a75a9e49c00e2c270f (diff) |
Refactor oauth into its own module
Diffstat (limited to 'app')
-rw-r--r-- | app/Cargo.toml | 1 | ||||
-rw-r--r-- | app/src/lib.rs | 202 | ||||
-rw-r--r-- | app/src/oauth.rs | 239 | ||||
-rw-r--r-- | app/src/state.rs | 46 |
4 files changed, 312 insertions, 176 deletions
diff --git a/app/Cargo.toml b/app/Cargo.toml index f11ae44..31bce47 100644 --- a/app/Cargo.toml +++ b/app/Cargo.toml @@ -20,3 +20,4 @@ tauri-plugin-deep-link = "2" tauri-plugin-single-instance = {version = "2", features = ["deep-link"] } tokio = "1.43.0" url = "2.5.4" +uuid = {version="1.13.1", features= ["v4"] } diff --git a/app/src/lib.rs b/app/src/lib.rs index ae80c24..8ad149f 100644 --- a/app/src/lib.rs +++ b/app/src/lib.rs @@ -1,17 +1,18 @@ -mod state; -use state::AppState; +mod oauth; +use oauth::OAuthController; use tauri_plugin_deep_link::DeepLinkExt; use tauri_plugin_opener::OpenerExt; -use std::collections::HashMap; use tauri::{Manager, State, AppHandle}; use url::Host; -use serde::Deserialize; -use tokio::sync::Mutex; -use tokio::sync::mpsc::channel; +use uuid::Uuid; -const OAUTH_CLIENT_NAME: &'static str = "foxfleet_test"; +pub const OAUTH_CLIENT_NAME: &'static str = "foxfleet_test"; + +struct AppState { + oauth_controller: OAuthController +} #[cfg_attr(mobile, tauri::mobile_entry_point)] pub fn run() { @@ -22,7 +23,9 @@ pub fn run() { .plugin(tauri_plugin_deep_link::init()) .plugin(tauri_plugin_opener::init()) .setup(|app| { - app.manage(Mutex::new(state::AppState::default())); + app.manage(AppState { + oauth_controller: OAuthController::new() + }); #[cfg(any(target_os = "linux", all(debug_assertions, windows)))] app.deep_link().register_all()?; @@ -30,31 +33,37 @@ pub fn run() { app.deep_link().on_open_url(move |event| { if let Some(oauth_callback) = event.urls().iter().find(|url| { if let Some(Host::Domain(domain)) = url.host() { - if domain == "oauth-response" { - return true; - } + domain == "oauth-response" + } else { + false } - false }) { let mut query = oauth_callback.query_pairs(); - if let Some(code) = query.find(|(key, _value)| key == "code") { - let app_handle = app_handle.clone(); - let code = code.1.to_string(); - tauri::async_runtime::spawn(async move { - let app_state = app_handle.state::<Mutex<AppState>>(); - app_state.lock().await.accounts.iter_mut().for_each(|account| { - // TODO: handle if there's multiple of these that match - if let state::ApiCredential::Pending(sender) = &account.api_credential { - let sender = sender.clone(); - let code = code.clone(); - tauri::async_runtime::spawn(async move { - let _ = sender.send(state::AuthCode(code)).await; - }); - } + let query = ( + query.find(|(key, _)| key == "code").map(|(_, value)| value ), + query.find(|(key, _)| key == "state").map(|(_, value)| value ) + ); + let query = match query { + (Some(code), Some(state)) => Some((code, state)), + _ => None, + }; + + if let Some((auth_code, state)) = query { + if let Ok(state) = Uuid::try_parse(&state) { + let state = state.clone(); + let auth_code = auth_code.to_string().clone(); + let app_handle = app_handle.clone(); + let oauth_controller = app_handle.state::<AppState>().oauth_controller.clone(); + + tauri::async_runtime::spawn(async move { + oauth_controller.resolve_code(state, auth_code); }); - }); + } else { + println!("Invalid UUID format: {state}"); + return + } } else { - println!("No code in oauth callback"); + println!("Missing either state or code in oauth callback"); return } } @@ -66,118 +75,53 @@ pub fn run() { .expect("Error starting") } -#[derive(Deserialize)] -struct RegistrationResponse { - id: String, - name: String, - client_id: String, - client_secret: String, -} - -#[derive(Deserialize)] -struct TokenResponse { - access_token: String, - created_at: u32, - scope: String, - token_type: String, -} #[tauri::command] -async fn start_account_auth(app_handle: AppHandle, state: State<'_, Mutex<AppState>>, instance_domain: &str) -> Result<(), ()> { - println!("Starting account auth"); - let registration_endpoint = format!("https://{instance_domain}/api/v1/apps"); - let token_endpoint = format!("https://{instance_domain}/oauth/token"); - let client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); - println!("Registering client"); - let registration_response : RegistrationResponse = client.post(registration_endpoint) - .json(&HashMap::from([ - ("client_name", OAUTH_CLIENT_NAME), - ("redirect_uris", "dev.tempest.foxfleet://oauth-response"), - ("scopes", "read write"), - ])).send().await.expect("Could not send client registration") - .json().await.expect("Could not parse client registration response"); - - // Make channel for awaiting - let (sender, mut receiver) = channel::<state::AuthCode>(1); - - println!("Saving registration"); - { state.lock().await.accounts.push(state::Account { - server_domain: instance_domain.to_string(), - handle_domain: None, - client_credential: state::ClientCredential { - client_name: OAUTH_CLIENT_NAME.to_string(), - client_id: registration_response.client_id.clone(), - client_secret: Some(registration_response.client_secret.clone()), - }, - api_credential: state::ApiCredential::Pending(sender), - }) } - - // Open browser to auth page - println!("Opening authentication page"); - let client_id = registration_response.client_id.clone(); - let auth_page = format!("https://{instance_domain}/oauth/authorize?client_id={client_id}&redirect_uri=dev.tempest.foxfleet://oauth-response&response_type=code&scope=read+write"); - let opener = app_handle.opener(); - if let Err(_) = opener.open_url(auth_page, None::<&str>) { - println!("Could not open authentication page"); - return Err(()) - } - - - // Wait for resolution of the credential - let auth_code = receiver.recv().await; - - if auth_code.is_none() { - return Err(()) - } - - let auth_code = auth_code.unwrap(); - println!("Exchanging auth code for API token"); - - // Get long-lived credential - let token_response : TokenResponse = client.post(token_endpoint) - .json(&HashMap::from([ - ("redirect_uri", "dev.tempest.foxfleet://oauth-response"), - ("client_id", registration_response.client_id.as_str()), - ("client_secret", registration_response.client_secret.as_str()), - ("grant_type", "authorization_code"), - ("code", auth_code.0.as_str()), - ])).send().await.expect("Could not get API token") - .json().await.expect("Could not parse client registration response"); - - println!("Successfully exchanged for credential"); - - // Save credential - { state.lock().await.accounts.iter_mut().for_each(|account| { - if account.server_domain == instance_domain { - account.api_credential = state::ApiCredential::Some { - token: token_response.access_token.clone(), - refresh: None, +async fn start_account_auth(app_handle: AppHandle, state: State<'_, AppState>, instance_domain: &str) -> Result<Vec<String>, ()> { + let add_result = state.oauth_controller.add_server(instance_domain).await; + + let state_nonce = match add_result { + Ok(result) => { + let opener = app_handle.opener(); + if let Err(_) = opener.open_url(result.auth_url, None::<&str>) { + println!("Could not open authentication page"); + return Err(()) } - } - }) }; - println!("Saved credential"); + result.auth_state + } + Err(err) => { + println!("Error adding server: {err}"); + return Err(()) + } + }; - Ok(()) + let signin_result = state.oauth_controller.finish_signin(state_nonce).await; + match signin_result { + Ok((server_domain, username)) => { + println!("Signed in successfully"); + Ok(vec!(server_domain, username)) + } + Err(err) => { + println!("Error completing signin: {err}"); + Err(()) + } + } } #[tauri::command] -async fn get_self(state: State<'_, Mutex<AppState>>) -> Result<String, String> { +async fn get_self(state: State<'_, AppState>, server_domain: String, username: String) -> Result<String, String> { let client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); - let accounts = { state.lock().await.accounts.clone() }; - let account = accounts.iter().find(|account| { - if let state::ApiCredential::Some {token: _, refresh: _} = account.api_credential { - true - } else { - false + let api_key = state.oauth_controller.get_api_token(server_domain, username).await; + match api_key { + Err(err) => { + println!("Error getting API token: {err}"); + return Err(err) } - }); - - if let Some(account) = account { - if let state::ApiCredential::Some {token, refresh: _} = &account.api_credential { + Ok(api_key) => { if let Ok(result) = client.get("https://social.tempest.dev/api/v1/accounts/verify_credentials") - .bearer_auth(token) + .bearer_auth(api_key) .send().await { if let Ok(result) = result.text().await { return Ok(result) @@ -189,6 +133,4 @@ async fn get_self(state: State<'_, Mutex<AppState>>) -> Result<String, String> { } } } - - return Err("No logged in account".to_string()); } diff --git a/app/src/oauth.rs b/app/src/oauth.rs new file mode 100644 index 0000000..795d5dc --- /dev/null +++ b/app/src/oauth.rs @@ -0,0 +1,239 @@ +use std::{collections::HashMap, sync::Arc}; +use serde::Deserialize; +use tokio::sync::Mutex; +use tokio::sync::oneshot::{channel, Sender, Receiver}; +use uuid::Uuid; + +use crate::OAUTH_CLIENT_NAME; +#[derive(Clone)] +pub struct OAuthController (Arc<Mutex<OAuthInternal>>); + +type ServerDomain = String; +type AccountIdentifier = String; +type StateCode = Uuid; + +struct OAuthInternal { + servers: HashMap<ServerDomain, Server>, + open_callbacks: HashMap<StateCode, Callback> +} + +#[derive(Clone)] +struct Server { + client_name: String, + client_id: String, + client_secret: String, + accounts: Vec<Account>, +} + +struct Callback { + server: ServerDomain, + pkce_verifier: String, + pkce_challenge: String, + code_channel: (Option<Sender<String>>, Option<Receiver<String>>), +} + +#[derive(Clone)] +struct Account { + username: String, + auth_session: AuthSession, +} + +#[derive(Clone)] +pub struct AuthSession { + pub api_token: String, + refresh_token: Option<String>, +} + +pub struct AddServerResult { + pub auth_url: String, + pub auth_state: StateCode, +} + +impl OAuthController { + pub fn new() -> Self { + Self(Arc::new(Mutex::new(OAuthInternal { + servers: HashMap::new(), + open_callbacks: HashMap::new(), + }))) + } + + pub async fn add_server(&self, instance_domain: &str) -> Result<AddServerResult, String> { + let registration_endpoint = format!("https://{instance_domain}/api/v1/apps"); + let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); + + #[derive(Deserialize)] + struct RegistrationResponse { + id: String, + name: String, + client_id: String, + client_secret: String, + } + + let registration_response : RegistrationResponse = http_client.post(registration_endpoint) + .json(&HashMap::from([ + ("client_name", OAUTH_CLIENT_NAME), + ("redirect_uris", "dev.tempest.foxfleet://oauth-response"), + ("scopes", "read write"), + ])).send().await.expect("Could not send client registration") + .json().await.expect("Could not parse client registration response"); + + let server = Server { + client_name: registration_response.name, + client_id: registration_response.client_id.clone(), + client_secret: registration_response.client_secret, + accounts: Vec::new(), + }; + + {self.0.lock().await.servers.insert(instance_domain.to_string(), server)}; + + let (sender, receiver) = channel::<String>(); + + // TODO: PKCE params for real + let auth_state = Uuid::new_v4(); + let auth_callback = Callback { + server: instance_domain.to_string(), + pkce_verifier: String::new(), + pkce_challenge: String::new(), + code_channel: (Some(sender), Some(receiver)), + }; + + let client_id = registration_response.client_id; + + {self.0.lock().await.open_callbacks.insert(auth_state.clone(), auth_callback)}; + + let auth_url = format!("https://{instance_domain}/oauth/authorize?client_id={client_id}&redirect_uri=dev.tempest.foxfleet://oauth-response&response_type=code&scope=read+write&state={auth_state}"); + + return Ok(AddServerResult { + auth_url, + auth_state, + }) + } + + pub fn resolve_code(&self, state: StateCode, auth_code: String) { + let runtime = tokio::runtime::Handle::current(); + let inner_self = self.0.clone(); + let state = state.clone(); + + runtime.spawn(async move { + if let Some(sender) = inner_self.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.0.take() ).flatten() { + let _ = sender.send(auth_code); + }; + }); + } + + pub async fn finish_signin(&self, state: StateCode) -> Result<(String, String), String> { + let auth_code = self.await_auth_code(&state).await?; + let auth_session = self.exchange_api_token(&state, &auth_code).await?; + let username = self.resolve_account(&state, &auth_session).await?; + + let account = Account { + username: username.clone(), + auth_session + }; + + let mut inner_state = self.0.lock().await; + + if let Some(callback) = inner_state.open_callbacks.remove(&state) { + if let Some (server) = inner_state.servers.get_mut(&callback.server) { + server.accounts.push(account); + return Ok((callback.server.clone(), username)) + } else { + Err("Unknown server URL".to_string()) + } + } else { + Err("Unknown state".to_string()) + } + } + + async fn await_auth_code(&self, state: &StateCode) -> Result<String, String> { + let maybe_receiver = {self.0.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.1.take() ).flatten()}; + if let Some(receiver) = maybe_receiver { + if let Ok(auth_code) = receiver.await { + Ok(auth_code) + } else { + Err("Error receiving code from channel".to_string()) + } + } else { + Err("Channel already awaited".to_string()) + } + } + + // TODO: Send PKCE stuff + async fn exchange_api_token(&self, state: &StateCode, auth_code: &String) -> Result<AuthSession, String> { + let maybe_server = { + let inner_state = self.0.lock().await; + if let Some(server_domain) = inner_state.open_callbacks.get(&state).map(|callback| callback.server.clone()) { + if let Some(server) = inner_state.servers.get(&server_domain) { + Some((server_domain, server.clone())) + } else { + None + } + } else { + None + } + }; + + if maybe_server.is_none() { + return Err("Unknown auth state".to_string()); + } + + let (server_domain, server) = maybe_server.unwrap(); + + #[derive(Deserialize)] + struct TokenResponse { + access_token: String, + created_at: u32, + scope: String, + token_type: String, + } + + let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); + let token_endpoint = format!("https://{server_domain}/oauth/token"); + let token_response : TokenResponse = http_client.post(token_endpoint) + .json(&HashMap::from([ + ("redirect_uri", "dev.tempest.foxfleet://oauth-response"), + ("client_id", server.client_id.as_str()), + ("client_secret", server.client_secret.as_str()), + ("grant_type", "authorization_code"), + ("code", auth_code.as_str()), + ])).send().await.expect("Could not get API token") + .json().await.expect("Could not parse client registration response"); + + Ok(AuthSession { + api_token: token_response.access_token, + refresh_token: None, + }) + } + + async fn resolve_account(&self, state: &StateCode, auth_session: &AuthSession) -> Result<String, String> { + let instance_url = { self.0.lock().await.open_callbacks.get(&state).map(|callback| callback.server.clone()) }; + + if let Some(instance_url) = instance_url { + #[derive(Deserialize)] + struct VerifyResponse { + acct: String, + } + + let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); + let verify_response : VerifyResponse = http_client.get(format!("https://{instance_url}/api/v1/accounts/verify_credentials")) + .bearer_auth(auth_session.api_token.clone()) + .send().await.expect("Could not look up account information") + .json().await.expect("Could not parse account information"); + + Ok(verify_response.acct) + } else { + Err("No instance URL".to_string()) + } + } + + pub async fn get_api_token(&self, server_domain: String, username: String) -> Result<String, String> { + let api_token_maybe = { self.0.lock().await.servers.get(&server_domain).map(|server| { + server.accounts.iter().find(|account| account.username == username).map(|account| account.auth_session.api_token.clone()) + }).flatten()}; + + match api_token_maybe { + Some(api_token) => Ok(api_token), + None => Err("Could not look up {server_domain} / {username}".to_string()), + } + } +} diff --git a/app/src/state.rs b/app/src/state.rs deleted file mode 100644 index 44e74ed..0000000 --- a/app/src/state.rs +++ /dev/null @@ -1,46 +0,0 @@ -use tokio::sync::mpsc::Sender; - -#[derive(Clone)] -pub struct AppState { - pub preferences: (), - pub accounts: Vec<Account>, -} - -impl AppState { - pub fn default() -> Self { - Self { - preferences: (), - accounts: Vec::new(), - - } - } -} - -#[derive(Clone)] -pub struct Account { - pub server_domain: String, - pub handle_domain: Option<String>, - pub client_credential: ClientCredential, - pub api_credential: ApiCredential, -} - -#[derive(Clone)] -pub struct ClientCredential { - pub client_name: String, - pub client_id: String, - pub client_secret: Option<String>, -} - -#[derive(Clone)] -pub struct AuthCode (pub String); - -#[derive(Clone)] -pub enum ApiCredential { - None, - Pending(Sender<AuthCode>), - Some { - token: String, - refresh: Option<String> - } -} - |