diff options
Diffstat (limited to 'app/src/oauth.rs')
-rw-r--r-- | app/src/oauth.rs | 239 |
1 files changed, 239 insertions, 0 deletions
diff --git a/app/src/oauth.rs b/app/src/oauth.rs new file mode 100644 index 0000000..795d5dc --- /dev/null +++ b/app/src/oauth.rs @@ -0,0 +1,239 @@ +use std::{collections::HashMap, sync::Arc}; +use serde::Deserialize; +use tokio::sync::Mutex; +use tokio::sync::oneshot::{channel, Sender, Receiver}; +use uuid::Uuid; + +use crate::OAUTH_CLIENT_NAME; +#[derive(Clone)] +pub struct OAuthController (Arc<Mutex<OAuthInternal>>); + +type ServerDomain = String; +type AccountIdentifier = String; +type StateCode = Uuid; + +struct OAuthInternal { + servers: HashMap<ServerDomain, Server>, + open_callbacks: HashMap<StateCode, Callback> +} + +#[derive(Clone)] +struct Server { + client_name: String, + client_id: String, + client_secret: String, + accounts: Vec<Account>, +} + +struct Callback { + server: ServerDomain, + pkce_verifier: String, + pkce_challenge: String, + code_channel: (Option<Sender<String>>, Option<Receiver<String>>), +} + +#[derive(Clone)] +struct Account { + username: String, + auth_session: AuthSession, +} + +#[derive(Clone)] +pub struct AuthSession { + pub api_token: String, + refresh_token: Option<String>, +} + +pub struct AddServerResult { + pub auth_url: String, + pub auth_state: StateCode, +} + +impl OAuthController { + pub fn new() -> Self { + Self(Arc::new(Mutex::new(OAuthInternal { + servers: HashMap::new(), + open_callbacks: HashMap::new(), + }))) + } + + pub async fn add_server(&self, instance_domain: &str) -> Result<AddServerResult, String> { + let registration_endpoint = format!("https://{instance_domain}/api/v1/apps"); + let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); + + #[derive(Deserialize)] + struct RegistrationResponse { + id: String, + name: String, + client_id: String, + client_secret: String, + } + + let registration_response : RegistrationResponse = http_client.post(registration_endpoint) + .json(&HashMap::from([ + ("client_name", OAUTH_CLIENT_NAME), + ("redirect_uris", "dev.tempest.foxfleet://oauth-response"), + ("scopes", "read write"), + ])).send().await.expect("Could not send client registration") + .json().await.expect("Could not parse client registration response"); + + let server = Server { + client_name: registration_response.name, + client_id: registration_response.client_id.clone(), + client_secret: registration_response.client_secret, + accounts: Vec::new(), + }; + + {self.0.lock().await.servers.insert(instance_domain.to_string(), server)}; + + let (sender, receiver) = channel::<String>(); + + // TODO: PKCE params for real + let auth_state = Uuid::new_v4(); + let auth_callback = Callback { + server: instance_domain.to_string(), + pkce_verifier: String::new(), + pkce_challenge: String::new(), + code_channel: (Some(sender), Some(receiver)), + }; + + let client_id = registration_response.client_id; + + {self.0.lock().await.open_callbacks.insert(auth_state.clone(), auth_callback)}; + + let auth_url = format!("https://{instance_domain}/oauth/authorize?client_id={client_id}&redirect_uri=dev.tempest.foxfleet://oauth-response&response_type=code&scope=read+write&state={auth_state}"); + + return Ok(AddServerResult { + auth_url, + auth_state, + }) + } + + pub fn resolve_code(&self, state: StateCode, auth_code: String) { + let runtime = tokio::runtime::Handle::current(); + let inner_self = self.0.clone(); + let state = state.clone(); + + runtime.spawn(async move { + if let Some(sender) = inner_self.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.0.take() ).flatten() { + let _ = sender.send(auth_code); + }; + }); + } + + pub async fn finish_signin(&self, state: StateCode) -> Result<(String, String), String> { + let auth_code = self.await_auth_code(&state).await?; + let auth_session = self.exchange_api_token(&state, &auth_code).await?; + let username = self.resolve_account(&state, &auth_session).await?; + + let account = Account { + username: username.clone(), + auth_session + }; + + let mut inner_state = self.0.lock().await; + + if let Some(callback) = inner_state.open_callbacks.remove(&state) { + if let Some (server) = inner_state.servers.get_mut(&callback.server) { + server.accounts.push(account); + return Ok((callback.server.clone(), username)) + } else { + Err("Unknown server URL".to_string()) + } + } else { + Err("Unknown state".to_string()) + } + } + + async fn await_auth_code(&self, state: &StateCode) -> Result<String, String> { + let maybe_receiver = {self.0.lock().await.open_callbacks.get_mut(&state).map(|callback| callback.code_channel.1.take() ).flatten()}; + if let Some(receiver) = maybe_receiver { + if let Ok(auth_code) = receiver.await { + Ok(auth_code) + } else { + Err("Error receiving code from channel".to_string()) + } + } else { + Err("Channel already awaited".to_string()) + } + } + + // TODO: Send PKCE stuff + async fn exchange_api_token(&self, state: &StateCode, auth_code: &String) -> Result<AuthSession, String> { + let maybe_server = { + let inner_state = self.0.lock().await; + if let Some(server_domain) = inner_state.open_callbacks.get(&state).map(|callback| callback.server.clone()) { + if let Some(server) = inner_state.servers.get(&server_domain) { + Some((server_domain, server.clone())) + } else { + None + } + } else { + None + } + }; + + if maybe_server.is_none() { + return Err("Unknown auth state".to_string()); + } + + let (server_domain, server) = maybe_server.unwrap(); + + #[derive(Deserialize)] + struct TokenResponse { + access_token: String, + created_at: u32, + scope: String, + token_type: String, + } + + let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); + let token_endpoint = format!("https://{server_domain}/oauth/token"); + let token_response : TokenResponse = http_client.post(token_endpoint) + .json(&HashMap::from([ + ("redirect_uri", "dev.tempest.foxfleet://oauth-response"), + ("client_id", server.client_id.as_str()), + ("client_secret", server.client_secret.as_str()), + ("grant_type", "authorization_code"), + ("code", auth_code.as_str()), + ])).send().await.expect("Could not get API token") + .json().await.expect("Could not parse client registration response"); + + Ok(AuthSession { + api_token: token_response.access_token, + refresh_token: None, + }) + } + + async fn resolve_account(&self, state: &StateCode, auth_session: &AuthSession) -> Result<String, String> { + let instance_url = { self.0.lock().await.open_callbacks.get(&state).map(|callback| callback.server.clone()) }; + + if let Some(instance_url) = instance_url { + #[derive(Deserialize)] + struct VerifyResponse { + acct: String, + } + + let http_client = reqwest::Client::builder().user_agent("Foxfleet v0.0.1").build().expect("Could not construct client"); + let verify_response : VerifyResponse = http_client.get(format!("https://{instance_url}/api/v1/accounts/verify_credentials")) + .bearer_auth(auth_session.api_token.clone()) + .send().await.expect("Could not look up account information") + .json().await.expect("Could not parse account information"); + + Ok(verify_response.acct) + } else { + Err("No instance URL".to_string()) + } + } + + pub async fn get_api_token(&self, server_domain: String, username: String) -> Result<String, String> { + let api_token_maybe = { self.0.lock().await.servers.get(&server_domain).map(|server| { + server.accounts.iter().find(|account| account.username == username).map(|account| account.auth_session.api_token.clone()) + }).flatten()}; + + match api_token_maybe { + Some(api_token) => Ok(api_token), + None => Err("Could not look up {server_domain} / {username}".to_string()), + } + } +} |