diff options
-rw-r--r-- | app/src/lib.rs | 38 | ||||
-rw-r--r-- | app/src/oauth.rs | 29 |
2 files changed, 31 insertions, 36 deletions
diff --git a/app/src/lib.rs b/app/src/lib.rs index 9823cdd..98ada19 100644 --- a/app/src/lib.rs +++ b/app/src/lib.rs @@ -23,46 +23,16 @@ pub fn run() { .plugin(tauri_plugin_deep_link::init()) .plugin(tauri_plugin_opener::init()) .setup(|app| { + let oauth_controller = OAuthController::new(tauri::async_runtime::handle()); app.manage(AppState { - oauth_controller: OAuthController::new(tauri::async_runtime::handle()) + oauth_controller: oauth_controller.clone() }); #[cfg(any(target_os = "linux", all(debug_assertions, windows)))] app.deep_link().register_all()?; - let app_handle = app.handle().clone(); app.deep_link().on_open_url(move |event| { - if let Some(oauth_callback) = event.urls().iter().find(|url| { - if let Some(Host::Domain(domain)) = url.host() { - domain == "oauth-response" - } else { - false - } - }) { - let mut query = oauth_callback.query_pairs(); - let query = ( - query.find(|(key, _)| key == "code").map(|(_, value)| value ), - query.find(|(key, _)| key == "state").map(|(_, value)| value ) - ); - let query = match query { - (Some(code), Some(state)) => Some((code, state)), - _ => None, - }; - - if let Some((auth_code, state)) = query { - if let Ok(state) = Uuid::try_parse(&state) { - let state = state.clone(); - let auth_code = auth_code.to_string().clone(); - let oauth_controller = &app_handle.state::<AppState>().oauth_controller; - - oauth_controller.resolve_code(state, auth_code); - } else { - println!("Invalid UUID format: {state}"); - return - } - } else { - println!("Missing either state or code in oauth callback"); - return - } + if oauth_controller.handle_deeplink(&event.urls()).is_some() { + return } }); Ok(()) diff --git a/app/src/oauth.rs b/app/src/oauth.rs index 9334eac..657db30 100644 --- a/app/src/oauth.rs +++ b/app/src/oauth.rs @@ -3,6 +3,7 @@ use serde::Deserialize; use tauri::async_runtime::RuntimeHandle; use tokio::sync::Mutex; use tokio::sync::oneshot::{channel, Sender, Receiver}; +use url::{Host, Url}; use uuid::Uuid; use crate::OAUTH_CLIENT_NAME; @@ -58,6 +59,31 @@ impl OAuthController { })), runtime_handle) } + pub fn handle_deeplink(&self, urls: &Vec<Url>) -> Option<()> { + let matching_url = urls.iter().find(|url| url.domain().is_some_and(|d| d == "oauth-response")); + let mut query_pairs = matching_url?.query_pairs(); + + let code_and_state = ( + query_pairs.find(|(key, _)| key == "code").map(|(_, value)| value ), + query_pairs.find(|(key, _)| key == "state").map(|(_, value)| value ) + ); + + let has_code_and_state = match code_and_state { + (Some(code), Some(state)) => Some((code.to_string(), state.to_string())), + _ => None, + }; + + let (auth_code, state) = has_code_and_state?; + let state = Uuid::try_parse(&state).map_err(|_| { + println!("Could not parse state field from oauth callback: {state}"); + () + }).ok()?; + + self.resolve_code(state, auth_code); + Some(()) + } + + pub async fn add_server(&self, instance_domain: &str) -> Result<AddServerResult, String> { let registration_endpoint = format!("https://{instance_domain}/api/v1/apps"); let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); @@ -110,10 +136,9 @@ impl OAuthController { }) } - pub fn resolve_code(&self, state: StateCode, auth_code: String) { + fn resolve_code(&self, state: StateCode, auth_code: String) { let runtime = self.1.clone(); let inner_self = self.0.clone(); - let state = state.clone(); runtime.spawn(async move { if let Some(sender) = inner_self.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.0.take() ).flatten() { |