summary refs log tree commit diff
path: root/app
diff options
context:
space:
mode:
authorAshelyn Rose <git@ashen.earth>2025-02-19 15:51:28 -0700
committerAshelyn Rose <git@ashen.earth>2025-02-19 15:51:28 -0700
commite8314458ccdf4d3c68969b206cd29f2490fb6308 (patch)
tree3eac29d969edc813ae407f0b2e114b033c5d8164 /app
parent5e8d3bc7008d29115bc520a75a9e49c00e2c270f (diff)
Refactor oauth into its own module
Diffstat (limited to 'app')
-rw-r--r--app/Cargo.toml1
-rw-r--r--app/src/lib.rs202
-rw-r--r--app/src/oauth.rs239
-rw-r--r--app/src/state.rs46
4 files changed, 312 insertions, 176 deletions
diff --git a/app/Cargo.toml b/app/Cargo.toml
index f11ae44..31bce47 100644
--- a/app/Cargo.toml
+++ b/app/Cargo.toml
@@ -20,3 +20,4 @@ tauri-plugin-deep-link = "2"
 tauri-plugin-single-instance = {version = "2", features = ["deep-link"] }
 tokio = "1.43.0"
 url = "2.5.4"
+uuid = {version="1.13.1", features= ["v4"] }
diff --git a/app/src/lib.rs b/app/src/lib.rs
index ae80c24..8ad149f 100644
--- a/app/src/lib.rs
+++ b/app/src/lib.rs
@@ -1,17 +1,18 @@
-mod state;
-use state::AppState;
+mod oauth;
+use oauth::OAuthController;
 
 use tauri_plugin_deep_link::DeepLinkExt;
 use tauri_plugin_opener::OpenerExt;
 
-use std::collections::HashMap;
 use tauri::{Manager, State, AppHandle};
 use url::Host;
-use serde::Deserialize;
-use tokio::sync::Mutex;
-use tokio::sync::mpsc::channel;
+use uuid::Uuid;
 
-const OAUTH_CLIENT_NAME: &'static str = "foxfleet_test";
+pub const OAUTH_CLIENT_NAME: &'static str = "foxfleet_test";
+
+struct AppState {
+    oauth_controller: OAuthController
+}
 
 #[cfg_attr(mobile, tauri::mobile_entry_point)]
 pub fn run() {
@@ -22,7 +23,9 @@ pub fn run() {
         .plugin(tauri_plugin_deep_link::init())
         .plugin(tauri_plugin_opener::init())
         .setup(|app| {
-            app.manage(Mutex::new(state::AppState::default()));
+            app.manage(AppState {
+                oauth_controller: OAuthController::new()
+            });
 
             #[cfg(any(target_os = "linux", all(debug_assertions, windows)))]
             app.deep_link().register_all()?;
@@ -30,31 +33,37 @@ pub fn run() {
             app.deep_link().on_open_url(move |event| {
                 if let Some(oauth_callback) = event.urls().iter().find(|url| {
                     if let Some(Host::Domain(domain)) = url.host() {
-                        if domain == "oauth-response" {
-                            return true;
-                        }
+                        domain == "oauth-response"
+                    } else {
+                        false
                     }
-                    false
                 }) {
                     let mut query = oauth_callback.query_pairs();
-                    if let Some(code) = query.find(|(key, _value)| key == "code") {
-                        let app_handle = app_handle.clone();
-                        let code = code.1.to_string();
-                        tauri::async_runtime::spawn(async move {
-                            let app_state = app_handle.state::<Mutex<AppState>>();
-                            app_state.lock().await.accounts.iter_mut().for_each(|account| {
-                                // TODO: handle if there's multiple of these that match
-                                if let state::ApiCredential::Pending(sender) = &account.api_credential {
-                                    let sender = sender.clone();
-                                    let code = code.clone();
-                                    tauri::async_runtime::spawn(async move {
-                                        let _ = sender.send(state::AuthCode(code)).await;
-                                    });
-                                }
+                    let query = (
+                        query.find(|(key, _)| key == "code").map(|(_, value)| value ),
+                        query.find(|(key, _)| key == "state").map(|(_, value)| value )
+                    );
+                    let query = match query {
+                        (Some(code), Some(state)) => Some((code, state)),
+                        _ => None,
+                    };
+
+                    if let Some((auth_code, state)) = query {
+                        if let Ok(state) = Uuid::try_parse(&state) {
+                            let state = state.clone();
+                            let auth_code = auth_code.to_string().clone();
+                            let app_handle = app_handle.clone();
+                            let oauth_controller = app_handle.state::<AppState>().oauth_controller.clone();
+
+                            tauri::async_runtime::spawn(async move {
+                                oauth_controller.resolve_code(state, auth_code);
                             });
-                        });
+                        } else {
+                            println!("Invalid UUID format: {state}");
+                            return
+                        }
                     } else {
-                        println!("No code in oauth callback");
+                        println!("Missing either state or code in oauth callback");
                         return
                     }
                 }
@@ -66,118 +75,53 @@ pub fn run() {
         .expect("Error starting")
 }
 
-#[derive(Deserialize)]
-struct RegistrationResponse {
-    id: String,
-    name: String,
-    client_id: String,
-    client_secret: String,
-}
-
-#[derive(Deserialize)]
-struct TokenResponse {
-    access_token: String,
-    created_at: u32,
-    scope: String,
-    token_type: String,
-}
 
 #[tauri::command]
-async fn start_account_auth(app_handle: AppHandle, state: State<'_, Mutex<AppState>>, instance_domain: &str) -> Result<(), ()> {
-    println!("Starting account auth");
-    let registration_endpoint = format!("https://{instance_domain}/api/v1/apps");
-    let token_endpoint = format!("https://{instance_domain}/oauth/token");
-    let client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client");
-    println!("Registering client");
-    let registration_response : RegistrationResponse = client.post(registration_endpoint)
-        .json(&HashMap::from([
-            ("client_name", OAUTH_CLIENT_NAME),
-            ("redirect_uris", "dev.tempest.foxfleet://oauth-response"),
-            ("scopes", "read write"),
-        ])).send().await.expect("Could not send client registration")
-        .json().await.expect("Could not parse client registration response");
-
-    // Make channel for awaiting
-    let (sender, mut receiver) = channel::<state::AuthCode>(1);
-
-    println!("Saving registration");
-    { state.lock().await.accounts.push(state::Account {
-        server_domain: instance_domain.to_string(),
-        handle_domain: None,
-        client_credential: state::ClientCredential {
-            client_name: OAUTH_CLIENT_NAME.to_string(),
-            client_id: registration_response.client_id.clone(),
-            client_secret: Some(registration_response.client_secret.clone()),
-        },
-        api_credential: state::ApiCredential::Pending(sender),
-    }) }
-
-    // Open browser to auth page
-    println!("Opening authentication page");
-    let client_id = registration_response.client_id.clone();
-    let auth_page = format!("https://{instance_domain}/oauth/authorize?client_id={client_id}&redirect_uri=dev.tempest.foxfleet://oauth-response&response_type=code&scope=read+write");
-    let opener = app_handle.opener();
-    if let Err(_) = opener.open_url(auth_page, None::<&str>) {
-        println!("Could not open authentication page");
-        return Err(())
-    }
-
-
-    // Wait for resolution of the credential
-    let auth_code = receiver.recv().await;
-
-    if auth_code.is_none() {
-        return Err(())
-    }
-
-    let auth_code = auth_code.unwrap();
-    println!("Exchanging auth code for API token");
-
-    // Get long-lived credential
-    let token_response : TokenResponse = client.post(token_endpoint)
-        .json(&HashMap::from([
-            ("redirect_uri", "dev.tempest.foxfleet://oauth-response"),
-            ("client_id", registration_response.client_id.as_str()),
-            ("client_secret", registration_response.client_secret.as_str()),
-            ("grant_type", "authorization_code"),
-            ("code", auth_code.0.as_str()),
-        ])).send().await.expect("Could not get API token")
-        .json().await.expect("Could not parse client registration response");
-
-    println!("Successfully exchanged for credential");
-
-    // Save credential
-    { state.lock().await.accounts.iter_mut().for_each(|account| {
-        if account.server_domain == instance_domain {
-            account.api_credential = state::ApiCredential::Some {
-                token: token_response.access_token.clone(),
-                refresh: None,
+async fn start_account_auth(app_handle: AppHandle, state: State<'_, AppState>, instance_domain: &str) -> Result<Vec<String>, ()> {
+    let add_result = state.oauth_controller.add_server(instance_domain).await;
+
+    let state_nonce = match add_result {
+        Ok(result) => {
+            let opener = app_handle.opener();
+            if let Err(_) = opener.open_url(result.auth_url, None::<&str>) {
+                println!("Could not open authentication page");
+                return Err(())
             }
-        }
-    }) };
 
-    println!("Saved credential");
+            result.auth_state
+        }
+        Err(err) => {
+            println!("Error adding server: {err}");
+            return Err(())
+        }
+    };
 
-    Ok(())
+    let signin_result = state.oauth_controller.finish_signin(state_nonce).await;
+    match signin_result {
+        Ok((server_domain, username)) => {
+            println!("Signed in successfully");
+            Ok(vec!(server_domain, username))
+        }
+        Err(err) => {
+            println!("Error completing signin: {err}");
+            Err(())
+        }
+    }
 }
 
 #[tauri::command]
-async fn get_self(state: State<'_, Mutex<AppState>>) -> Result<String, String> {
+async fn get_self(state: State<'_, AppState>, server_domain: String, username: String) -> Result<String, String> {
     let client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client");
 
-    let accounts = { state.lock().await.accounts.clone() };
-    let account = accounts.iter().find(|account| {
-        if let state::ApiCredential::Some {token: _, refresh: _} = account.api_credential {
-            true
-        } else {
-            false
+    let api_key = state.oauth_controller.get_api_token(server_domain, username).await;
+    match api_key {
+        Err(err) => {
+            println!("Error getting API token: {err}");
+            return Err(err)
         }
-    });
-
-    if let Some(account) = account {
-        if let state::ApiCredential::Some {token, refresh: _} = &account.api_credential {
+        Ok(api_key) => {
             if let Ok(result) = client.get("https://social.tempest.dev/api/v1/accounts/verify_credentials")
-                .bearer_auth(token)
+                .bearer_auth(api_key)
                 .send().await {
                     if let Ok(result) = result.text().await {
                         return Ok(result)
@@ -189,6 +133,4 @@ async fn get_self(state: State<'_, Mutex<AppState>>) -> Result<String, String> {
                 }
         }
     }
-
-    return Err("No logged in account".to_string());
 }
diff --git a/app/src/oauth.rs b/app/src/oauth.rs
new file mode 100644
index 0000000..795d5dc
--- /dev/null
+++ b/app/src/oauth.rs
@@ -0,0 +1,239 @@
+use std::{collections::HashMap, sync::Arc};
+use serde::Deserialize;
+use tokio::sync::Mutex;
+use tokio::sync::oneshot::{channel, Sender, Receiver};
+use uuid::Uuid;
+
+use crate::OAUTH_CLIENT_NAME;
+#[derive(Clone)]
+pub struct OAuthController (Arc<Mutex<OAuthInternal>>);
+
+type ServerDomain = String;
+type AccountIdentifier = String;
+type StateCode = Uuid;
+
+struct OAuthInternal {
+    servers: HashMap<ServerDomain, Server>,
+    open_callbacks: HashMap<StateCode, Callback>
+}
+
+#[derive(Clone)]
+struct Server {
+    client_name: String,
+    client_id: String,
+    client_secret: String,
+    accounts: Vec<Account>,
+}
+
+struct Callback {
+    server: ServerDomain,
+    pkce_verifier: String,
+    pkce_challenge: String,
+    code_channel: (Option<Sender<String>>, Option<Receiver<String>>),
+}
+
+#[derive(Clone)]
+struct Account {
+    username: String,
+    auth_session: AuthSession,
+}
+
+#[derive(Clone)]
+pub struct AuthSession {
+    pub api_token: String,
+    refresh_token: Option<String>,
+}
+
+pub struct AddServerResult {
+    pub auth_url: String,
+    pub auth_state: StateCode,
+}
+
+impl OAuthController {
+    pub fn new() -> Self {
+        Self(Arc::new(Mutex::new(OAuthInternal {
+            servers: HashMap::new(),
+            open_callbacks: HashMap::new(),
+        })))
+    }
+
+    pub async fn add_server(&self, instance_domain: &str) -> Result<AddServerResult, String> {
+        let registration_endpoint = format!("https://{instance_domain}/api/v1/apps");
+        let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client");
+
+        #[derive(Deserialize)]
+        struct RegistrationResponse {
+            id: String,
+            name: String,
+            client_id: String,
+            client_secret: String,
+        }
+
+        let registration_response : RegistrationResponse = http_client.post(registration_endpoint)
+            .json(&HashMap::from([
+                ("client_name", OAUTH_CLIENT_NAME),
+                ("redirect_uris", "dev.tempest.foxfleet://oauth-response"),
+                ("scopes", "read write"),
+            ])).send().await.expect("Could not send client registration")
+            .json().await.expect("Could not parse client registration response");
+
+        let server = Server {
+            client_name: registration_response.name,
+            client_id: registration_response.client_id.clone(),
+            client_secret: registration_response.client_secret,
+            accounts: Vec::new(),
+        };
+
+        {self.0.lock().await.servers.insert(instance_domain.to_string(), server)};
+
+        let (sender, receiver) = channel::<String>();
+
+        // TODO: PKCE params for real
+        let auth_state = Uuid::new_v4();
+        let auth_callback = Callback {
+            server: instance_domain.to_string(),
+            pkce_verifier: String::new(),
+            pkce_challenge: String::new(),
+            code_channel: (Some(sender), Some(receiver)),
+        };
+
+        let client_id = registration_response.client_id;
+
+        {self.0.lock().await.open_callbacks.insert(auth_state.clone(), auth_callback)};
+
+        let auth_url = format!("https://{instance_domain}/oauth/authorize?client_id={client_id}&redirect_uri=dev.tempest.foxfleet://oauth-response&response_type=code&scope=read+write&state={auth_state}");
+
+        return Ok(AddServerResult {
+            auth_url,
+            auth_state,
+        })
+    }
+
+    pub fn resolve_code(&self, state: StateCode, auth_code: String) {
+        let runtime = tokio::runtime::Handle::current();
+        let inner_self = self.0.clone();
+        let state = state.clone();
+
+        runtime.spawn(async move {
+            if let Some(sender) = inner_self.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.0.take() ).flatten() {
+                let _ = sender.send(auth_code);
+            };
+        });
+    }
+
+    pub async fn finish_signin(&self, state: StateCode) -> Result<(String, String), String> {
+        let auth_code = self.await_auth_code(&state).await?;
+        let auth_session = self.exchange_api_token(&state, &auth_code).await?;
+        let username = self.resolve_account(&state, &auth_session).await?;
+
+        let account = Account {
+            username: username.clone(),
+            auth_session
+        };
+
+        let mut inner_state = self.0.lock().await;
+
+        if let Some(callback) = inner_state.open_callbacks.remove(&state) {
+            if let Some (server) = inner_state.servers.get_mut(&callback.server) {
+                server.accounts.push(account);
+                return Ok((callback.server.clone(), username))
+            } else {
+                Err("Unknown server URL".to_string())
+            }
+        } else {
+            Err("Unknown state".to_string())
+        }
+    }
+
+    async fn await_auth_code(&self, state: &StateCode) -> Result<String, String> {
+        let maybe_receiver = {self.0.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.1.take() ).flatten()};
+        if let Some(receiver) = maybe_receiver {
+            if let Ok(auth_code) = receiver.await {
+                Ok(auth_code)
+            } else {
+                Err("Error receiving code from channel".to_string())
+            }
+        } else {
+            Err("Channel already awaited".to_string())
+        }
+    }
+
+    // TODO: Send PKCE stuff
+    async fn exchange_api_token(&self, state: &StateCode, auth_code: &String) -> Result<AuthSession, String> {
+        let maybe_server = {
+            let inner_state = self.0.lock().await;
+            if let Some(server_domain) = inner_state.open_callbacks.get(&state).map(|callback| callback.server.clone()) {
+                if let Some(server) = inner_state.servers.get(&server_domain) {
+                    Some((server_domain, server.clone()))
+                } else {
+                    None
+                }
+            } else {
+                None
+            }
+        };
+
+        if maybe_server.is_none() {
+            return Err("Unknown auth state".to_string());
+        }
+
+        let (server_domain, server) = maybe_server.unwrap();
+
+        #[derive(Deserialize)]
+        struct TokenResponse {
+            access_token: String,
+            created_at: u32,
+            scope: String,
+            token_type: String,
+        }
+
+        let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client");
+        let token_endpoint = format!("https://{server_domain}/oauth/token");
+        let token_response : TokenResponse = http_client.post(token_endpoint)
+            .json(&HashMap::from([
+                ("redirect_uri", "dev.tempest.foxfleet://oauth-response"),
+                ("client_id", server.client_id.as_str()),
+                ("client_secret", server.client_secret.as_str()),
+                ("grant_type", "authorization_code"),
+                ("code", auth_code.as_str()),
+            ])).send().await.expect("Could not get API token")
+            .json().await.expect("Could not parse client registration response");
+
+        Ok(AuthSession {
+            api_token: token_response.access_token,
+            refresh_token: None,
+        })
+    }
+
+    async fn resolve_account(&self, state: &StateCode, auth_session: &AuthSession) -> Result<String, String> {
+        let instance_url = { self.0.lock().await.open_callbacks.get(&state).map(|callback| callback.server.clone()) };
+
+        if let Some(instance_url) = instance_url {
+            #[derive(Deserialize)]
+            struct VerifyResponse {
+                acct: String,
+            }
+
+            let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client");
+            let verify_response : VerifyResponse = http_client.get(format!("https://{instance_url}/api/v1/accounts/verify_credentials"))
+                .bearer_auth(auth_session.api_token.clone())
+                .send().await.expect("Could not look up account information")
+                .json().await.expect("Could not parse account information");
+
+            Ok(verify_response.acct)
+        } else {
+            Err("No instance URL".to_string())
+        }
+    }
+
+    pub async fn get_api_token(&self, server_domain: String, username: String) -> Result<String, String> {
+        let api_token_maybe = { self.0.lock().await.servers.get(&server_domain).map(|server| {
+            server.accounts.iter().find(|account| account.username == username).map(|account| account.auth_session.api_token.clone())
+        }).flatten()};
+
+        match api_token_maybe {
+            Some(api_token) => Ok(api_token),
+            None => Err("Could not look up {server_domain} / {username}".to_string()),
+        }
+    }
+}
diff --git a/app/src/state.rs b/app/src/state.rs
deleted file mode 100644
index 44e74ed..0000000
--- a/app/src/state.rs
+++ /dev/null
@@ -1,46 +0,0 @@
-use tokio::sync::mpsc::Sender;
-
-#[derive(Clone)]
-pub struct AppState {
-    pub preferences: (),
-    pub accounts: Vec<Account>,
-}
-
-impl AppState {
-    pub fn default() -> Self {
-        Self {
-            preferences: (),
-            accounts: Vec::new(),
-
-        }
-    }
-}
-
-#[derive(Clone)]
-pub struct Account {
-    pub server_domain: String,
-    pub handle_domain: Option<String>,
-    pub client_credential: ClientCredential,
-    pub api_credential: ApiCredential,
-}
-
-#[derive(Clone)]
-pub struct ClientCredential {
-    pub client_name: String,
-    pub client_id: String,
-    pub client_secret: Option<String>,
-}
-
-#[derive(Clone)]
-pub struct AuthCode (pub String);
-
-#[derive(Clone)]
-pub enum ApiCredential {
-    None,
-    Pending(Sender<AuthCode>),
-    Some {
-        token: String,
-        refresh: Option<String>
-    }
-}
-