summary refs log tree commit diff
path: root/app/src
diff options
context:
space:
mode:
Diffstat (limited to 'app/src')
-rw-r--r--app/src/lib.rs38
-rw-r--r--app/src/oauth.rs29
2 files changed, 31 insertions, 36 deletions
diff --git a/app/src/lib.rs b/app/src/lib.rs
index 9823cdd..98ada19 100644
--- a/app/src/lib.rs
+++ b/app/src/lib.rs
@@ -23,46 +23,16 @@ pub fn run() {
         .plugin(tauri_plugin_deep_link::init())
         .plugin(tauri_plugin_opener::init())
         .setup(|app| {
+            let oauth_controller = OAuthController::new(tauri::async_runtime::handle());
             app.manage(AppState {
-                oauth_controller: OAuthController::new(tauri::async_runtime::handle())
+                oauth_controller: oauth_controller.clone() 
             });
 
             #[cfg(any(target_os = "linux", all(debug_assertions, windows)))]
             app.deep_link().register_all()?;
-            let app_handle = app.handle().clone();
             app.deep_link().on_open_url(move |event| {
-                if let Some(oauth_callback) = event.urls().iter().find(|url| {
-                    if let Some(Host::Domain(domain)) = url.host() {
-                        domain == "oauth-response"
-                    } else {
-                        false
-                    }
-                }) {
-                    let mut query = oauth_callback.query_pairs();
-                    let query = (
-                        query.find(|(key, _)| key == "code").map(|(_, value)| value ),
-                        query.find(|(key, _)| key == "state").map(|(_, value)| value )
-                    );
-                    let query = match query {
-                        (Some(code), Some(state)) => Some((code, state)),
-                        _ => None,
-                    };
-
-                    if let Some((auth_code, state)) = query {
-                        if let Ok(state) = Uuid::try_parse(&state) {
-                            let state = state.clone();
-                            let auth_code = auth_code.to_string().clone();
-                            let oauth_controller = &app_handle.state::<AppState>().oauth_controller;
-
-                            oauth_controller.resolve_code(state, auth_code);
-                        } else {
-                            println!("Invalid UUID format: {state}");
-                            return
-                        }
-                    } else {
-                        println!("Missing either state or code in oauth callback");
-                        return
-                    }
+                if oauth_controller.handle_deeplink(&event.urls()).is_some() {
+                    return
                 }
             });
             Ok(())
diff --git a/app/src/oauth.rs b/app/src/oauth.rs
index 9334eac..657db30 100644
--- a/app/src/oauth.rs
+++ b/app/src/oauth.rs
@@ -3,6 +3,7 @@ use serde::Deserialize;
 use tauri::async_runtime::RuntimeHandle;
 use tokio::sync::Mutex;
 use tokio::sync::oneshot::{channel, Sender, Receiver};
+use url::{Host, Url};
 use uuid::Uuid;
 
 use crate::OAUTH_CLIENT_NAME;
@@ -58,6 +59,31 @@ impl OAuthController {
         })), runtime_handle)
     }
 
+    pub fn handle_deeplink(&self, urls: &Vec<Url>) -> Option<()> {
+        let matching_url = urls.iter().find(|url| url.domain().is_some_and(|d| d == "oauth-response"));
+        let mut query_pairs = matching_url?.query_pairs();
+
+        let code_and_state = (
+            query_pairs.find(|(key, _)| key == "code").map(|(_, value)| value ),
+            query_pairs.find(|(key, _)| key == "state").map(|(_, value)| value )
+        );
+
+        let has_code_and_state = match code_and_state {
+            (Some(code), Some(state)) => Some((code.to_string(), state.to_string())),
+            _ => None,
+        };
+
+        let (auth_code, state) = has_code_and_state?;
+        let state = Uuid::try_parse(&state).map_err(|_| {
+            println!("Could not parse state field from oauth callback: {state}");
+            ()
+        }).ok()?;
+
+        self.resolve_code(state, auth_code);
+        Some(())
+    }
+
+
     pub async fn add_server(&self, instance_domain: &str) -> Result<AddServerResult, String> {
         let registration_endpoint = format!("https://{instance_domain}/api/v1/apps");
         let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client");
@@ -110,10 +136,9 @@ impl OAuthController {
         })
     }
 
-    pub fn resolve_code(&self, state: StateCode, auth_code: String) {
+    fn resolve_code(&self, state: StateCode, auth_code: String) {
         let runtime = self.1.clone();
         let inner_self = self.0.clone();
-        let state = state.clone();
 
         runtime.spawn(async move {
             if let Some(sender) = inner_self.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.0.take() ).flatten() {