summary refs log tree commit diff
path: root/app/src/oauth.rs
diff options
context:
space:
mode:
authorAshelyn Rose <git@ashen.earth>2025-02-22 20:23:43 -0700
committerAshelyn Rose <git@ashen.earth>2025-02-22 20:23:43 -0700
commitddbef5d475951dfd9157221b611e7d1ac06da86b (patch)
treef143f1dfc0b61c97aeb513c9bfa540260b3e753d /app/src/oauth.rs
parent5c6a049c4c962be7bf889897b16a1778bbe63819 (diff)
Incredibly messy persistence refactor
Stores account information in the app data folder, and account credentials in the OS keychain
Diffstat (limited to 'app/src/oauth.rs')
-rw-r--r--app/src/oauth.rs177
1 files changed, 74 insertions, 103 deletions
diff --git a/app/src/oauth.rs b/app/src/oauth.rs
index 657db30..1b0661c 100644
--- a/app/src/oauth.rs
+++ b/app/src/oauth.rs
@@ -3,46 +3,35 @@ use serde::Deserialize;
 use tauri::async_runtime::RuntimeHandle;
 use tokio::sync::Mutex;
 use tokio::sync::oneshot::{channel, Sender, Receiver};
-use url::{Host, Url};
+use url::Url;
 use uuid::Uuid;
 
+use crate::persistence::PersistenceController;
 use crate::OAUTH_CLIENT_NAME;
+
 #[derive(Clone)]
-pub struct OAuthController (Arc<Mutex<OAuthInternal>>, RuntimeHandle);
+pub struct OAuthController {
+    int: Arc<Mutex<OAuthInternal>>,
+    runtime: RuntimeHandle,
+    persistence: PersistenceController,
+}
 
 type ServerDomain = String;
-type AccountIdentifier = String;
 type StateCode = Uuid;
 
 struct OAuthInternal {
-    servers: HashMap<ServerDomain, Server>,
     open_callbacks: HashMap<StateCode, Callback>
 }
 
-#[derive(Clone)]
-struct Server {
-    client_name: String,
-    client_id: String,
-    client_secret: String,
-    accounts: Vec<Account>,
-}
-
 struct Callback {
-    server: ServerDomain,
+    server_domain: ServerDomain,
     pkce_verifier: String,
     pkce_challenge: String,
     code_channel: (Option<Sender<String>>, Option<Receiver<String>>),
 }
 
-#[derive(Clone)]
-struct Account {
-    username: String,
-    auth_session: AuthSession,
-}
-
-#[derive(Clone)]
-pub struct AuthSession {
-    pub api_token: String,
+struct AuthSession {
+    api_token: String,
     refresh_token: Option<String>,
 }
 
@@ -51,12 +40,20 @@ pub struct AddServerResult {
     pub auth_state: StateCode,
 }
 
+pub struct SigninResult {
+    pub server_domain: String,
+    pub username: String,
+}
+
 impl OAuthController {
-    pub fn new(runtime_handle: RuntimeHandle) -> Self {
-        Self(Arc::new(Mutex::new(OAuthInternal {
-            servers: HashMap::new(),
-            open_callbacks: HashMap::new(),
-        })), runtime_handle)
+    pub fn new(runtime_handle: RuntimeHandle, persistence_controller: PersistenceController) -> Self {
+        Self {
+            int: Arc::new(Mutex::new(OAuthInternal {
+                open_callbacks: HashMap::new(),
+            })),
+            runtime: runtime_handle,
+            persistence: persistence_controller,
+        }
     }
 
     pub fn handle_deeplink(&self, urls: &Vec<Url>) -> Option<()> {
@@ -84,49 +81,51 @@ impl OAuthController {
     }
 
 
-    pub async fn add_server(&self, instance_domain: &str) -> Result<AddServerResult, String> {
+    pub async fn start_authorization(&self, instance_domain: &str) -> Result<AddServerResult, String> {
         let registration_endpoint = format!("https://{instance_domain}/api/v1/apps");
         let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client");
 
-        #[derive(Deserialize)]
-        struct RegistrationResponse {
-            id: String,
-            name: String,
-            client_id: String,
-            client_secret: String,
-        }
+        let server_client = match self.persistence.get_server(&instance_domain.to_string()).await {
+            Some(client) => client,
+            None => {
+                println!("Registering a new client for {instance_domain}");
+
+                #[derive(Deserialize)]
+                struct RegistrationResponse {
+                    id: String,
+                    name: String,
+                    client_id: String,
+                    client_secret: String,
+                }
 
-        let registration_response : RegistrationResponse = http_client.post(registration_endpoint)
-            .json(&HashMap::from([
-                ("client_name", OAUTH_CLIENT_NAME),
-                ("redirect_uris", "dev.tempest.foxfleet://oauth-response"),
-                ("scopes", "read write"),
-            ])).send().await.expect("Could not send client registration")
-            .json().await.expect("Could not parse client registration response");
+                let registration_response : RegistrationResponse = http_client.post(registration_endpoint)
+                    .json(&HashMap::from([
+                        ("client_name", OAUTH_CLIENT_NAME),
+                        ("redirect_uris", "dev.tempest.foxfleet://oauth-response"),
+                        ("scopes", "read write"),
+                    ])).send().await.expect("Could not send client registration")
+                    .json().await.expect("Could not parse client registration response");
 
-        let server = Server {
-            client_name: registration_response.name,
-            client_id: registration_response.client_id.clone(),
-            client_secret: registration_response.client_secret,
-            accounts: Vec::new(),
-        };
+                self.persistence.new_server(instance_domain.to_string(), registration_response.client_id.clone(), registration_response.name.to_string(), registration_response.client_secret).await;
 
-        {self.0.lock().await.servers.insert(instance_domain.to_string(), server)};
+                self.persistence.get_server(&instance_domain.to_string()).await.unwrap()
+            },
+        };
 
         let (sender, receiver) = channel::<String>();
 
         // TODO: PKCE params for real
         let auth_state = Uuid::new_v4();
         let auth_callback = Callback {
-            server: instance_domain.to_string(),
+            server_domain: instance_domain.to_string(),
             pkce_verifier: String::new(),
             pkce_challenge: String::new(),
             code_channel: (Some(sender), Some(receiver)),
         };
 
-        let client_id = registration_response.client_id;
+        let client_id = server_client.client_id;
 
-        {self.0.lock().await.open_callbacks.insert(auth_state.clone(), auth_callback)};
+        {self.int.lock().await.open_callbacks.insert(auth_state.clone(), auth_callback)};
 
         let auth_url = format!("https://{instance_domain}/oauth/authorize?client_id={client_id}&redirect_uri=dev.tempest.foxfleet://oauth-response&response_type=code&scope=read+write&state={auth_state}");
 
@@ -137,8 +136,8 @@ impl OAuthController {
     }
 
     fn resolve_code(&self, state: StateCode, auth_code: String) {
-        let runtime = self.1.clone();
-        let inner_self = self.0.clone();
+        let runtime = self.runtime.clone();
+        let inner_self = self.int.clone();
 
         runtime.spawn(async move {
             if let Some(sender) = inner_self.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.0.take() ).flatten() {
@@ -147,32 +146,28 @@ impl OAuthController {
         });
     }
 
-    pub async fn finish_signin(&self, state: StateCode) -> Result<(String, String), String> {
+    pub async fn finish_authorization(&self, state: StateCode) -> Result<SigninResult, String> {
+        let domain = {self.int.lock().await.open_callbacks.get(&state).map(|cb| cb.server_domain.clone())}
+            .ok_or(format!("No callback code {state}"))?;
+
         let auth_code = self.await_auth_code(&state).await?;
         let auth_session = self.exchange_api_token(&state, &auth_code).await?;
         let username = self.resolve_account(&state, &auth_session).await?;
 
-        let account = Account {
-            username: username.clone(),
-            auth_session
-        };
 
-        let mut inner_state = self.0.lock().await;
+        self.persistence.new_account(username.clone(), domain.clone(), auth_session.api_token).await;
 
-        if let Some(callback) = inner_state.open_callbacks.remove(&state) {
-            if let Some (server) = inner_state.servers.get_mut(&callback.server) {
-                server.accounts.push(account);
-                return Ok((callback.server.clone(), username))
-            } else {
-                Err("Unknown server URL".to_string())
-            }
-        } else {
-            Err("Unknown state".to_string())
-        }
+        // Remove state callback record
+        {self.int.lock().await.open_callbacks.remove(&state)};
+
+        return Ok(SigninResult {
+            server_domain: domain,
+            username
+        })
     }
 
     async fn await_auth_code(&self, state: &StateCode) -> Result<String, String> {
-        let maybe_receiver = {self.0.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.1.take() ).flatten()};
+        let maybe_receiver = {self.int.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.1.take() ).flatten()};
         if let Some(receiver) = maybe_receiver {
             if let Ok(auth_code) = receiver.await {
                 Ok(auth_code)
@@ -186,24 +181,11 @@ impl OAuthController {
 
     // TODO: Send PKCE stuff
     async fn exchange_api_token(&self, state: &StateCode, auth_code: &String) -> Result<AuthSession, String> {
-        let maybe_server = {
-            let inner_state = self.0.lock().await;
-            if let Some(server_domain) = inner_state.open_callbacks.get(&state).map(|callback| callback.server.clone()) {
-                if let Some(server) = inner_state.servers.get(&server_domain) {
-                    Some((server_domain, server.clone()))
-                } else {
-                    None
-                }
-            } else {
-                None
-            }
-        };
-
-        if maybe_server.is_none() {
-            return Err("Unknown auth state".to_string());
-        }
-
-        let (server_domain, server) = maybe_server.unwrap();
+        let server_domain = {self.int.lock().await.open_callbacks.get(&state).map(|cb| cb.server_domain.clone())}
+            .ok_or(format!("No callback code {state}"))?;
+        let server_client = self.persistence.get_server(&server_domain).await
+            .ok_or("Could not look up client for server")?;
+        let client_secret = self.persistence.get_credential(server_client.client_credential_id).await?;
 
         #[derive(Deserialize)]
         struct TokenResponse {
@@ -218,8 +200,8 @@ impl OAuthController {
         let token_response : TokenResponse = http_client.post(token_endpoint)
             .json(&HashMap::from([
                 ("redirect_uri", "dev.tempest.foxfleet://oauth-response"),
-                ("client_id", server.client_id.as_str()),
-                ("client_secret", server.client_secret.as_str()),
+                ("client_id", server_client.client_id.as_str()),
+                ("client_secret", client_secret.as_str()),
                 ("grant_type", "authorization_code"),
                 ("code", auth_code.as_str()),
             ])).send().await.expect("Could not get API token")
@@ -232,7 +214,7 @@ impl OAuthController {
     }
 
     async fn resolve_account(&self, state: &StateCode, auth_session: &AuthSession) -> Result<String, String> {
-        let instance_url = { self.0.lock().await.open_callbacks.get(&state).map(|callback| callback.server.clone()) };
+        let instance_url = { self.int.lock().await.open_callbacks.get(&state).map(|callback| callback.server_domain.clone()) };
 
         if let Some(instance_url) = instance_url {
             #[derive(Deserialize)]
@@ -251,15 +233,4 @@ impl OAuthController {
             Err("No instance URL".to_string())
         }
     }
-
-    pub async fn get_api_token(&self, server_domain: String, username: String) -> Result<String, String> {
-        let api_token_maybe = { self.0.lock().await.servers.get(&server_domain).map(|server| {
-            server.accounts.iter().find(|account| account.username == username).map(|account| account.auth_session.api_token.clone())
-        }).flatten()};
-
-        match api_token_maybe {
-            Some(api_token) => Ok(api_token),
-            None => Err("Could not look up {server_domain} / {username}".to_string()),
-        }
-    }
 }