diff options
Diffstat (limited to 'app/src/persistence.rs')
-rw-r--r-- | app/src/persistence.rs | 291 |
1 files changed, 291 insertions, 0 deletions
diff --git a/app/src/persistence.rs b/app/src/persistence.rs new file mode 100644 index 0000000..028eded --- /dev/null +++ b/app/src/persistence.rs @@ -0,0 +1,291 @@ +use keyring::Entry; +use serde::{Deserialize, Serialize}; +/** + * This module keeps persistent application account and credential state, + * persists it to disk (and the OS keyring) and makes it available to + * other modules in FoxFleet. + * + * All keystore and disk operations run on a secondary thread to + * prevent causing hangups (and as the OS keystore APIs are synchronous) + */ + +use tauri::async_runtime::RuntimeHandle; +use uuid::Uuid; +use std::path::PathBuf; +use std::thread::{self, JoinHandle}; +use std::{collections::HashMap, sync::Arc}; +use tokio::sync::Mutex; +use tokio::sync::oneshot; +use std::sync::mpsc; +use std::fs; + +#[derive(Clone, Serialize, Deserialize)] +pub struct Server { + pub domain: String, + pub client_name: String, + pub client_id: String, + pub client_credential_id: Uuid, +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct Account { + pub username: String, // Full @handle@domain identifier, used for persistence and key storage + pub server_domain: String, // Web domain, used for API access + pub api_credential_id: Uuid, +} + +#[derive(Clone, Serialize, Deserialize)] +struct DiskState { + servers: Vec<Server>, + accounts: Vec<Account>, +} + +enum CredentialThreadRequest { + PutCredential {uuid: Uuid, credential: String, callback: oneshot::Sender<()>}, + GetCredential {uuid: Uuid, callback: oneshot::Sender<String>}, + Close, +} + +enum DiskThreadRequest { + Write {state: DiskState, callback: oneshot::Sender<()> }, + Read {callback: oneshot::Sender<DiskState>}, + Close, +} + +struct PersistenceState { + servers: Vec<Server>, + accounts: Vec<Account>, + credential_cache: HashMap<Uuid, String>, + credential_channel: mpsc::Sender<CredentialThreadRequest>, + disk_channel: mpsc::Sender<DiskThreadRequest>, + credential_joinhandle: JoinHandle<()>, + disk_joinhandle: JoinHandle<()>, + async_runtime: RuntimeHandle, +} + +#[derive(Clone)] +pub struct PersistenceController (Arc<Mutex<PersistenceState>>); + +impl PersistenceController { + pub fn new(runtime_handle: RuntimeHandle) -> Self { + let (credential_sender, credential_receiver) = mpsc::channel::<CredentialThreadRequest>(); + let (disk_sender, disk_receiver) = mpsc::channel::<DiskThreadRequest>(); + + let disk_joinhandle = thread::Builder::new() + .name("foxfleet::persistence::disk".to_string()) + .spawn(move || DiskThread::start(disk_receiver) ) + .expect("Could not spawn disk thread"); + + let credential_joinhandle = thread::Builder::new() + .name("foxfleet::persistence::keyring".to_string()) + .spawn(move || CredentialThread::start(credential_receiver) ) + .expect("Could not spawn keyring thread"); + + // Quickly wait on the disk thread to load config + let (load_send, load_recv) = oneshot::channel::<DiskState>(); + disk_sender.send(DiskThreadRequest::Read { callback: load_send }).expect("Could not load accounts"); + let loaded_state = runtime_handle.block_on(async { + load_recv.await.expect("Could not load accounts") + }); + + return PersistenceController(Arc::new(Mutex::new(PersistenceState { + servers: loaded_state.servers, + accounts: loaded_state.accounts, + credential_cache: HashMap::new(), + credential_channel: credential_sender, + disk_channel: disk_sender, + credential_joinhandle, + disk_joinhandle, + async_runtime: runtime_handle, + }))) + } + + pub async fn new_server(&self, domain: String, client_id: String, client_name: String, client_secret: String) -> Result<Server, String> { + let has_server_with_domain = {self.0.lock().await.servers.iter().find(|s| s.domain == domain).is_some()}; + + if has_server_with_domain { + return Err(format!("Server already exists in state with domain {domain}")); + } + + let client_credential_id = Uuid::new_v4(); + let server = Server { + domain, + client_id, + client_name, + client_credential_id, + }; + + self.persist_credential(client_credential_id, client_secret).await; + {self.0.lock().await.servers.push(server.clone())}; + self.persist_disk().await; + + return Ok(server) + } + + pub async fn get_server(&self, domain: &String) -> Option<Server> { + let server = {self.0.lock().await.servers.iter().find(|s| s.domain == *domain).cloned()}; + return server + } + + pub async fn new_account(&self, username: String, server_domain: String, api_token: String) -> Result<Account, String> { + let has_account_already = {self.0.lock().await.accounts.iter().find(|a| a.username == username && a.server_domain == server_domain).is_some()}; + + if has_account_already { + return Err(format!("Account already exists for @{username}@{server_domain}")); + } + + let api_credential_id = Uuid::new_v4(); + let account = Account { + username, + server_domain, + api_credential_id + }; + + self.persist_credential(api_credential_id, api_token).await; + {self.0.lock().await.accounts.push(account.clone())}; + self.persist_disk().await; + + return Ok(account) + } + + pub async fn get_account(&self, username: &String, server_domain: &String) -> Option<Account> { + let account = {self.0.lock().await.accounts.iter().find(|a| a.username == *username && a.server_domain == *server_domain).cloned()}; + return account + } + + pub async fn get_all_accounts(&self) -> Vec<Account> { + let accounts = {self.0.lock().await.accounts.clone()}; + return accounts + } + + async fn persist_credential(&self, credential_id: Uuid, value: String) { + let credential_channel = {self.0.lock().await.credential_channel.clone()}; + + let (sender, receiver) = oneshot::channel::<()>(); + credential_channel.send(CredentialThreadRequest::PutCredential { uuid: credential_id, credential: value.clone(), callback: sender }); + receiver.await; + + {self.0.lock().await.credential_cache.insert(credential_id, value)}; + } + + pub async fn get_credential(&self, credential_id: Uuid) -> Result<String, String> { + if let Some(cached) = {self.0.lock().await.credential_cache.get(&credential_id).cloned()} { + return Ok(cached) + } + + let credential_channel = {self.0.lock().await.credential_channel.clone()}; + + let (sender, receiver) = oneshot::channel::<String>(); + credential_channel.send(CredentialThreadRequest::GetCredential { uuid: credential_id, callback: sender }); + let retrieved = receiver.await; + + if retrieved.is_err() { + return Err(format!("Could not retrieve credential: {credential_id}")) + } + + let retrieved = retrieved.unwrap(); + {self.0.lock().await.credential_cache.insert(credential_id, retrieved.clone())}; + + return Ok(retrieved) + } + + async fn persist_disk(&self) { + let (disk_channel, accounts, servers) = { + let state = self.0.lock().await; + (state.disk_channel.clone(), state.accounts.clone(), state.servers.clone()) + }; + + let (sender, receiver) = oneshot::channel::<()>(); + + disk_channel.send(DiskThreadRequest::Write { state: DiskState { + accounts, + servers, + }, callback: sender}); + + receiver.await; + } +} + +struct DiskThread; +impl DiskThread { + fn start(receiver: mpsc::Receiver<DiskThreadRequest>) { + loop { + match receiver.recv() { + Ok(DiskThreadRequest::Write { state, callback }) => Self::write(state, callback), + Ok(DiskThreadRequest::Read { callback }) => Self::read(callback), + Ok(DiskThreadRequest::Close) => return, + Err(_) => { + println!("Disk thread hung up on, exiting"); + return + }, + } + } + } + + fn write(state: DiskState, callback: oneshot::Sender<()>) { + Self::ensure_dir_exists(); + let file_path = Self::file_path(); + let data = toml::to_string(&state).unwrap(); + + fs::write(file_path, data).unwrap(); + callback.send(()).unwrap(); + } + + fn read(callback: oneshot::Sender<DiskState>) { + Self::ensure_dir_exists(); + let file_path = Self::file_path(); + let data = match fs::read_to_string(file_path) { + Ok(data) => data, + Err(_) => { + let _ = callback.send(DiskState { + servers: Vec::new(), + accounts: Vec::new(), + }); + return + }, + }; + + let result : DiskState = toml::from_str(&data).unwrap(); + let _ = callback.send(result); + } + + fn file_path() -> PathBuf { + dirs::data_dir().unwrap().join("foxfleet/data.toml") + } + + fn ensure_dir_exists() { + let dir = dirs::data_dir().unwrap().join("foxfleet"); + if !fs::exists(&dir).unwrap() { + fs::create_dir(&dir).expect("Could not create data directory"); + } + } +} + +struct CredentialThread; +impl CredentialThread { + fn start(receiver: mpsc::Receiver<CredentialThreadRequest>) { + loop { + match receiver.recv() { + Ok(CredentialThreadRequest::PutCredential { uuid, credential, callback }) => Self::put(uuid, credential, callback), + Ok(CredentialThreadRequest::GetCredential { uuid, callback }) => Self::get(uuid, callback), + Ok(CredentialThreadRequest::Close) => return, + Err(_) => { + println!("Credential thread hung up on, exiting"); + return + }, + } + } + } + + fn get(uuid: Uuid, callback: oneshot::Sender<String>) { + let entry = Entry::new("dev.tempest.foxfleet", &uuid.to_string()).unwrap(); + let credential = entry.get_password().unwrap(); + callback.send(credential).unwrap(); + } + + fn put(uuid: Uuid, credential: String, callback: oneshot::Sender<()>) { + let entry = Entry::new("dev.tempest.foxfleet", &uuid.to_string()).unwrap(); + entry.set_password(credential.as_str()).unwrap(); + callback.send(()).unwrap(); + } +} |