summary refs log tree commit diff
path: root/app/src/oauth.rs
diff options
context:
space:
mode:
Diffstat (limited to 'app/src/oauth.rs')
-rw-r--r--app/src/oauth.rs239
1 files changed, 239 insertions, 0 deletions
diff --git a/app/src/oauth.rs b/app/src/oauth.rs
new file mode 100644
index 0000000..795d5dc
--- /dev/null
+++ b/app/src/oauth.rs
@@ -0,0 +1,239 @@
+use std::{collections::HashMap, sync::Arc};
+use serde::Deserialize;
+use tokio::sync::Mutex;
+use tokio::sync::oneshot::{channel, Sender, Receiver};
+use uuid::Uuid;
+
+use crate::OAUTH_CLIENT_NAME;
+#[derive(Clone)]
+pub struct OAuthController (Arc<Mutex<OAuthInternal>>);
+
+type ServerDomain = String;
+type AccountIdentifier = String;
+type StateCode = Uuid;
+
+struct OAuthInternal {
+    servers: HashMap<ServerDomain, Server>,
+    open_callbacks: HashMap<StateCode, Callback>
+}
+
+#[derive(Clone)]
+struct Server {
+    client_name: String,
+    client_id: String,
+    client_secret: String,
+    accounts: Vec<Account>,
+}
+
+struct Callback {
+    server: ServerDomain,
+    pkce_verifier: String,
+    pkce_challenge: String,
+    code_channel: (Option<Sender<String>>, Option<Receiver<String>>),
+}
+
+#[derive(Clone)]
+struct Account {
+    username: String,
+    auth_session: AuthSession,
+}
+
+#[derive(Clone)]
+pub struct AuthSession {
+    pub api_token: String,
+    refresh_token: Option<String>,
+}
+
+pub struct AddServerResult {
+    pub auth_url: String,
+    pub auth_state: StateCode,
+}
+
+impl OAuthController {
+    pub fn new() -> Self {
+        Self(Arc::new(Mutex::new(OAuthInternal {
+            servers: HashMap::new(),
+            open_callbacks: HashMap::new(),
+        })))
+    }
+
+    pub async fn add_server(&self, instance_domain: &str) -> Result<AddServerResult, String> {
+        let registration_endpoint = format!("https://{instance_domain}/api/v1/apps");
+        let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client");
+
+        #[derive(Deserialize)]
+        struct RegistrationResponse {
+            id: String,
+            name: String,
+            client_id: String,
+            client_secret: String,
+        }
+
+        let registration_response : RegistrationResponse = http_client.post(registration_endpoint)
+            .json(&HashMap::from([
+                ("client_name", OAUTH_CLIENT_NAME),
+                ("redirect_uris", "dev.tempest.foxfleet://oauth-response"),
+                ("scopes", "read write"),
+            ])).send().await.expect("Could not send client registration")
+            .json().await.expect("Could not parse client registration response");
+
+        let server = Server {
+            client_name: registration_response.name,
+            client_id: registration_response.client_id.clone(),
+            client_secret: registration_response.client_secret,
+            accounts: Vec::new(),
+        };
+
+        {self.0.lock().await.servers.insert(instance_domain.to_string(), server)};
+
+        let (sender, receiver) = channel::<String>();
+
+        // TODO: PKCE params for real
+        let auth_state = Uuid::new_v4();
+        let auth_callback = Callback {
+            server: instance_domain.to_string(),
+            pkce_verifier: String::new(),
+            pkce_challenge: String::new(),
+            code_channel: (Some(sender), Some(receiver)),
+        };
+
+        let client_id = registration_response.client_id;
+
+        {self.0.lock().await.open_callbacks.insert(auth_state.clone(), auth_callback)};
+
+        let auth_url = format!("https://{instance_domain}/oauth/authorize?client_id={client_id}&redirect_uri=dev.tempest.foxfleet://oauth-response&response_type=code&scope=read+write&state={auth_state}");
+
+        return Ok(AddServerResult {
+            auth_url,
+            auth_state,
+        })
+    }
+
+    pub fn resolve_code(&self, state: StateCode, auth_code: String) {
+        let runtime = tokio::runtime::Handle::current();
+        let inner_self = self.0.clone();
+        let state = state.clone();
+
+        runtime.spawn(async move {
+            if let Some(sender) = inner_self.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.0.take() ).flatten() {
+                let _ = sender.send(auth_code);
+            };
+        });
+    }
+
+    pub async fn finish_signin(&self, state: StateCode) -> Result<(String, String), String> {
+        let auth_code = self.await_auth_code(&state).await?;
+        let auth_session = self.exchange_api_token(&state, &auth_code).await?;
+        let username = self.resolve_account(&state, &auth_session).await?;
+
+        let account = Account {
+            username: username.clone(),
+            auth_session
+        };
+
+        let mut inner_state = self.0.lock().await;
+
+        if let Some(callback) = inner_state.open_callbacks.remove(&state) {
+            if let Some (server) = inner_state.servers.get_mut(&callback.server) {
+                server.accounts.push(account);
+                return Ok((callback.server.clone(), username))
+            } else {
+                Err("Unknown server URL".to_string())
+            }
+        } else {
+            Err("Unknown state".to_string())
+        }
+    }
+
+    async fn await_auth_code(&self, state: &StateCode) -> Result<String, String> {
+        let maybe_receiver = {self.0.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.1.take() ).flatten()};
+        if let Some(receiver) = maybe_receiver {
+            if let Ok(auth_code) = receiver.await {
+                Ok(auth_code)
+            } else {
+                Err("Error receiving code from channel".to_string())
+            }
+        } else {
+            Err("Channel already awaited".to_string())
+        }
+    }
+
+    // TODO: Send PKCE stuff
+    async fn exchange_api_token(&self, state: &StateCode, auth_code: &String) -> Result<AuthSession, String> {
+        let maybe_server = {
+            let inner_state = self.0.lock().await;
+            if let Some(server_domain) = inner_state.open_callbacks.get(&state).map(|callback| callback.server.clone()) {
+                if let Some(server) = inner_state.servers.get(&server_domain) {
+                    Some((server_domain, server.clone()))
+                } else {
+                    None
+                }
+            } else {
+                None
+            }
+        };
+
+        if maybe_server.is_none() {
+            return Err("Unknown auth state".to_string());
+        }
+
+        let (server_domain, server) = maybe_server.unwrap();
+
+        #[derive(Deserialize)]
+        struct TokenResponse {
+            access_token: String,
+            created_at: u32,
+            scope: String,
+            token_type: String,
+        }
+
+        let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client");
+        let token_endpoint = format!("https://{server_domain}/oauth/token");
+        let token_response : TokenResponse = http_client.post(token_endpoint)
+            .json(&HashMap::from([
+                ("redirect_uri", "dev.tempest.foxfleet://oauth-response"),
+                ("client_id", server.client_id.as_str()),
+                ("client_secret", server.client_secret.as_str()),
+                ("grant_type", "authorization_code"),
+                ("code", auth_code.as_str()),
+            ])).send().await.expect("Could not get API token")
+            .json().await.expect("Could not parse client registration response");
+
+        Ok(AuthSession {
+            api_token: token_response.access_token,
+            refresh_token: None,
+        })
+    }
+
+    async fn resolve_account(&self, state: &StateCode, auth_session: &AuthSession) -> Result<String, String> {
+        let instance_url = { self.0.lock().await.open_callbacks.get(&state).map(|callback| callback.server.clone()) };
+
+        if let Some(instance_url) = instance_url {
+            #[derive(Deserialize)]
+            struct VerifyResponse {
+                acct: String,
+            }
+
+            let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client");
+            let verify_response : VerifyResponse = http_client.get(format!("https://{instance_url}/api/v1/accounts/verify_credentials"))
+                .bearer_auth(auth_session.api_token.clone())
+                .send().await.expect("Could not look up account information")
+                .json().await.expect("Could not parse account information");
+
+            Ok(verify_response.acct)
+        } else {
+            Err("No instance URL".to_string())
+        }
+    }
+
+    pub async fn get_api_token(&self, server_domain: String, username: String) -> Result<String, String> {
+        let api_token_maybe = { self.0.lock().await.servers.get(&server_domain).map(|server| {
+            server.accounts.iter().find(|account| account.username == username).map(|account| account.auth_session.api_token.clone())
+        }).flatten()};
+
+        match api_token_maybe {
+            Some(api_token) => Ok(api_token),
+            None => Err("Could not look up {server_domain} / {username}".to_string()),
+        }
+    }
+}