diff options
Diffstat (limited to 'app/src/oauth.rs')
-rw-r--r-- | app/src/oauth.rs | 177 |
1 files changed, 74 insertions, 103 deletions
diff --git a/app/src/oauth.rs b/app/src/oauth.rs index 657db30..1b0661c 100644 --- a/app/src/oauth.rs +++ b/app/src/oauth.rs @@ -3,46 +3,35 @@ use serde::Deserialize; use tauri::async_runtime::RuntimeHandle; use tokio::sync::Mutex; use tokio::sync::oneshot::{channel, Sender, Receiver}; -use url::{Host, Url}; +use url::Url; use uuid::Uuid; +use crate::persistence::PersistenceController; use crate::OAUTH_CLIENT_NAME; + #[derive(Clone)] -pub struct OAuthController (Arc<Mutex<OAuthInternal>>, RuntimeHandle); +pub struct OAuthController { + int: Arc<Mutex<OAuthInternal>>, + runtime: RuntimeHandle, + persistence: PersistenceController, +} type ServerDomain = String; -type AccountIdentifier = String; type StateCode = Uuid; struct OAuthInternal { - servers: HashMap<ServerDomain, Server>, open_callbacks: HashMap<StateCode, Callback> } -#[derive(Clone)] -struct Server { - client_name: String, - client_id: String, - client_secret: String, - accounts: Vec<Account>, -} - struct Callback { - server: ServerDomain, + server_domain: ServerDomain, pkce_verifier: String, pkce_challenge: String, code_channel: (Option<Sender<String>>, Option<Receiver<String>>), } -#[derive(Clone)] -struct Account { - username: String, - auth_session: AuthSession, -} - -#[derive(Clone)] -pub struct AuthSession { - pub api_token: String, +struct AuthSession { + api_token: String, refresh_token: Option<String>, } @@ -51,12 +40,20 @@ pub struct AddServerResult { pub auth_state: StateCode, } +pub struct SigninResult { + pub server_domain: String, + pub username: String, +} + impl OAuthController { - pub fn new(runtime_handle: RuntimeHandle) -> Self { - Self(Arc::new(Mutex::new(OAuthInternal { - servers: HashMap::new(), - open_callbacks: HashMap::new(), - })), runtime_handle) + pub fn new(runtime_handle: RuntimeHandle, persistence_controller: PersistenceController) -> Self { + Self { + int: Arc::new(Mutex::new(OAuthInternal { + open_callbacks: HashMap::new(), + })), + runtime: runtime_handle, + persistence: persistence_controller, + } } pub fn handle_deeplink(&self, urls: &Vec<Url>) -> Option<()> { @@ -84,49 +81,51 @@ impl OAuthController { } - pub async fn add_server(&self, instance_domain: &str) -> Result<AddServerResult, String> { + pub async fn start_authorization(&self, instance_domain: &str) -> Result<AddServerResult, String> { let registration_endpoint = format!("https://{instance_domain}/api/v1/apps"); let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); - #[derive(Deserialize)] - struct RegistrationResponse { - id: String, - name: String, - client_id: String, - client_secret: String, - } + let server_client = match self.persistence.get_server(&instance_domain.to_string()).await { + Some(client) => client, + None => { + println!("Registering a new client for {instance_domain}"); + + #[derive(Deserialize)] + struct RegistrationResponse { + id: String, + name: String, + client_id: String, + client_secret: String, + } - let registration_response : RegistrationResponse = http_client.post(registration_endpoint) - .json(&HashMap::from([ - ("client_name", OAUTH_CLIENT_NAME), - ("redirect_uris", "dev.tempest.foxfleet://oauth-response"), - ("scopes", "read write"), - ])).send().await.expect("Could not send client registration") - .json().await.expect("Could not parse client registration response"); + let registration_response : RegistrationResponse = http_client.post(registration_endpoint) + .json(&HashMap::from([ + ("client_name", OAUTH_CLIENT_NAME), + ("redirect_uris", "dev.tempest.foxfleet://oauth-response"), + ("scopes", "read write"), + ])).send().await.expect("Could not send client registration") + .json().await.expect("Could not parse client registration response"); - let server = Server { - client_name: registration_response.name, - client_id: registration_response.client_id.clone(), - client_secret: registration_response.client_secret, - accounts: Vec::new(), - }; + self.persistence.new_server(instance_domain.to_string(), registration_response.client_id.clone(), registration_response.name.to_string(), registration_response.client_secret).await; - {self.0.lock().await.servers.insert(instance_domain.to_string(), server)}; + self.persistence.get_server(&instance_domain.to_string()).await.unwrap() + }, + }; let (sender, receiver) = channel::<String>(); // TODO: PKCE params for real let auth_state = Uuid::new_v4(); let auth_callback = Callback { - server: instance_domain.to_string(), + server_domain: instance_domain.to_string(), pkce_verifier: String::new(), pkce_challenge: String::new(), code_channel: (Some(sender), Some(receiver)), }; - let client_id = registration_response.client_id; + let client_id = server_client.client_id; - {self.0.lock().await.open_callbacks.insert(auth_state.clone(), auth_callback)}; + {self.int.lock().await.open_callbacks.insert(auth_state.clone(), auth_callback)}; let auth_url = format!("https://{instance_domain}/oauth/authorize?client_id={client_id}&redirect_uri=dev.tempest.foxfleet://oauth-response&response_type=code&scope=read+write&state={auth_state}"); @@ -137,8 +136,8 @@ impl OAuthController { } fn resolve_code(&self, state: StateCode, auth_code: String) { - let runtime = self.1.clone(); - let inner_self = self.0.clone(); + let runtime = self.runtime.clone(); + let inner_self = self.int.clone(); runtime.spawn(async move { if let Some(sender) = inner_self.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.0.take() ).flatten() { @@ -147,32 +146,28 @@ impl OAuthController { }); } - pub async fn finish_signin(&self, state: StateCode) -> Result<(String, String), String> { + pub async fn finish_authorization(&self, state: StateCode) -> Result<SigninResult, String> { + let domain = {self.int.lock().await.open_callbacks.get(&state).map(|cb| cb.server_domain.clone())} + .ok_or(format!("No callback code {state}"))?; + let auth_code = self.await_auth_code(&state).await?; let auth_session = self.exchange_api_token(&state, &auth_code).await?; let username = self.resolve_account(&state, &auth_session).await?; - let account = Account { - username: username.clone(), - auth_session - }; - let mut inner_state = self.0.lock().await; + self.persistence.new_account(username.clone(), domain.clone(), auth_session.api_token).await; - if let Some(callback) = inner_state.open_callbacks.remove(&state) { - if let Some (server) = inner_state.servers.get_mut(&callback.server) { - server.accounts.push(account); - return Ok((callback.server.clone(), username)) - } else { - Err("Unknown server URL".to_string()) - } - } else { - Err("Unknown state".to_string()) - } + // Remove state callback record + {self.int.lock().await.open_callbacks.remove(&state)}; + + return Ok(SigninResult { + server_domain: domain, + username + }) } async fn await_auth_code(&self, state: &StateCode) -> Result<String, String> { - let maybe_receiver = {self.0.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.1.take() ).flatten()}; + let maybe_receiver = {self.int.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.1.take() ).flatten()}; if let Some(receiver) = maybe_receiver { if let Ok(auth_code) = receiver.await { Ok(auth_code) @@ -186,24 +181,11 @@ impl OAuthController { // TODO: Send PKCE stuff async fn exchange_api_token(&self, state: &StateCode, auth_code: &String) -> Result<AuthSession, String> { - let maybe_server = { - let inner_state = self.0.lock().await; - if let Some(server_domain) = inner_state.open_callbacks.get(&state).map(|callback| callback.server.clone()) { - if let Some(server) = inner_state.servers.get(&server_domain) { - Some((server_domain, server.clone())) - } else { - None - } - } else { - None - } - }; - - if maybe_server.is_none() { - return Err("Unknown auth state".to_string()); - } - - let (server_domain, server) = maybe_server.unwrap(); + let server_domain = {self.int.lock().await.open_callbacks.get(&state).map(|cb| cb.server_domain.clone())} + .ok_or(format!("No callback code {state}"))?; + let server_client = self.persistence.get_server(&server_domain).await + .ok_or("Could not look up client for server")?; + let client_secret = self.persistence.get_credential(server_client.client_credential_id).await?; #[derive(Deserialize)] struct TokenResponse { @@ -218,8 +200,8 @@ impl OAuthController { let token_response : TokenResponse = http_client.post(token_endpoint) .json(&HashMap::from([ ("redirect_uri", "dev.tempest.foxfleet://oauth-response"), - ("client_id", server.client_id.as_str()), - ("client_secret", server.client_secret.as_str()), + ("client_id", server_client.client_id.as_str()), + ("client_secret", client_secret.as_str()), ("grant_type", "authorization_code"), ("code", auth_code.as_str()), ])).send().await.expect("Could not get API token") @@ -232,7 +214,7 @@ impl OAuthController { } async fn resolve_account(&self, state: &StateCode, auth_session: &AuthSession) -> Result<String, String> { - let instance_url = { self.0.lock().await.open_callbacks.get(&state).map(|callback| callback.server.clone()) }; + let instance_url = { self.int.lock().await.open_callbacks.get(&state).map(|callback| callback.server_domain.clone()) }; if let Some(instance_url) = instance_url { #[derive(Deserialize)] @@ -251,15 +233,4 @@ impl OAuthController { Err("No instance URL".to_string()) } } - - pub async fn get_api_token(&self, server_domain: String, username: String) -> Result<String, String> { - let api_token_maybe = { self.0.lock().await.servers.get(&server_domain).map(|server| { - server.accounts.iter().find(|account| account.username == username).map(|account| account.auth_session.api_token.clone()) - }).flatten()}; - - match api_token_maybe { - Some(api_token) => Ok(api_token), - None => Err("Could not look up {server_domain} / {username}".to_string()), - } - } } |