diff options
Diffstat (limited to 'app/src/persistence.rs')
-rw-r--r-- | app/src/persistence.rs | 197 |
1 files changed, 116 insertions, 81 deletions
diff --git a/app/src/persistence.rs b/app/src/persistence.rs index 028eded..8a0e7f2 100644 --- a/app/src/persistence.rs +++ b/app/src/persistence.rs @@ -18,8 +18,10 @@ use tokio::sync::Mutex; use tokio::sync::oneshot; use std::sync::mpsc; use std::fs; +use std::sync::mpsc::SendError; +use tokio::sync::oneshot::error::RecvError; -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, Debug)] pub struct Server { pub domain: String, pub client_name: String, @@ -27,29 +29,76 @@ pub struct Server { pub client_credential_id: Uuid, } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, Debug)] pub struct Account { pub username: String, // Full @handle@domain identifier, used for persistence and key storage pub server_domain: String, // Web domain, used for API access pub api_credential_id: Uuid, } -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Serialize, Deserialize, Debug)] struct DiskState { servers: Vec<Server>, accounts: Vec<Account>, } +impl Default for DiskState { + fn default() -> Self { + Self { + servers: Vec::new(), + accounts: Vec::new(), + } + } +} + +type PutResponse = Result<(), PersistenceError>; +type LoadDiskResponse = Result<DiskState, PersistenceError>; +type GetCredentialResponse = Result<String, PersistenceError>; + enum CredentialThreadRequest { - PutCredential {uuid: Uuid, credential: String, callback: oneshot::Sender<()>}, - GetCredential {uuid: Uuid, callback: oneshot::Sender<String>}, - Close, + PutCredential {uuid: Uuid, credential: String, callback: oneshot::Sender<PutResponse>}, + GetCredential {uuid: Uuid, callback: oneshot::Sender<GetCredentialResponse>}, } enum DiskThreadRequest { - Write {state: DiskState, callback: oneshot::Sender<()> }, - Read {callback: oneshot::Sender<DiskState>}, - Close, + Write {state: DiskState, callback: oneshot::Sender<PutResponse> }, + Read {callback: oneshot::Sender<LoadDiskResponse>}, +} + +#[derive(Debug)] +pub enum PersistenceError { + RecvError(RecvError), + SendError, + IoError(std::io::Error), + NoDataDir, + AccountNotFound {username: String, server_domain: String}, + ServerNotFound {server_domain: String}, + AccountAlreadyExists {username: String, server_domain: String}, + ServerAlreadyRegistered { server_domain: String }, +} + +impl From<RecvError> for PersistenceError { + fn from(value: RecvError) -> Self { + PersistenceError::RecvError(value) + } +} + +impl<T> From<SendError<T>> for PersistenceError { + fn from(_value: SendError<T>) -> Self { + PersistenceError::SendError + } +} + +impl From<std::io::Error> for PersistenceError { + fn from(value: std::io::Error) -> Self { + PersistenceError::IoError(value) + } +} + +impl From<PersistenceError> for String { + fn from(value: PersistenceError) -> Self { + "PersistenceError".to_string() + } } struct PersistenceState { @@ -81,16 +130,16 @@ impl PersistenceController { .spawn(move || CredentialThread::start(credential_receiver) ) .expect("Could not spawn keyring thread"); - // Quickly wait on the disk thread to load config - let (load_send, load_recv) = oneshot::channel::<DiskState>(); - disk_sender.send(DiskThreadRequest::Read { callback: load_send }).expect("Could not load accounts"); - let loaded_state = runtime_handle.block_on(async { - load_recv.await.expect("Could not load accounts") - }); + // Load saved or default config on startup + let (load_send, load_recv) = oneshot::channel::<LoadDiskResponse>(); + + let initial_state = disk_sender.send(DiskThreadRequest::Read { callback: load_send }).ok().map(|()| { + runtime_handle.block_on(load_recv).ok().map(|response| response.ok()) + }).flatten().flatten().unwrap_or_default(); return PersistenceController(Arc::new(Mutex::new(PersistenceState { - servers: loaded_state.servers, - accounts: loaded_state.accounts, + servers: initial_state.servers, + accounts: initial_state.accounts, credential_cache: HashMap::new(), credential_channel: credential_sender, disk_channel: disk_sender, @@ -100,11 +149,11 @@ impl PersistenceController { }))) } - pub async fn new_server(&self, domain: String, client_id: String, client_name: String, client_secret: String) -> Result<Server, String> { + pub async fn new_server(&self, domain: String, client_id: String, client_name: String, client_secret: String) -> Result<Server, PersistenceError> { let has_server_with_domain = {self.0.lock().await.servers.iter().find(|s| s.domain == domain).is_some()}; if has_server_with_domain { - return Err(format!("Server already exists in state with domain {domain}")); + return Err(PersistenceError::ServerAlreadyRegistered { server_domain: domain }); } let client_credential_id = Uuid::new_v4(); @@ -115,9 +164,9 @@ impl PersistenceController { client_credential_id, }; - self.persist_credential(client_credential_id, client_secret).await; + self.persist_credential(client_credential_id, client_secret).await?; {self.0.lock().await.servers.push(server.clone())}; - self.persist_disk().await; + self.persist_disk().await?; return Ok(server) } @@ -127,11 +176,11 @@ impl PersistenceController { return server } - pub async fn new_account(&self, username: String, server_domain: String, api_token: String) -> Result<Account, String> { + pub async fn new_account(&self, username: String, server_domain: String, api_token: String) -> Result<Account, PersistenceError> { let has_account_already = {self.0.lock().await.accounts.iter().find(|a| a.username == username && a.server_domain == server_domain).is_some()}; if has_account_already { - return Err(format!("Account already exists for @{username}@{server_domain}")); + return Err(PersistenceError::AccountAlreadyExists {username, server_domain}); } let api_credential_id = Uuid::new_v4(); @@ -141,9 +190,9 @@ impl PersistenceController { api_credential_id }; - self.persist_credential(api_credential_id, api_token).await; + self.persist_credential(api_credential_id, api_token).await?; {self.0.lock().await.accounts.push(account.clone())}; - self.persist_disk().await; + self.persist_disk().await?; return Ok(account) } @@ -158,51 +207,50 @@ impl PersistenceController { return accounts } - async fn persist_credential(&self, credential_id: Uuid, value: String) { + async fn persist_credential(&self, credential_id: Uuid, value: String) -> Result<(), PersistenceError> { let credential_channel = {self.0.lock().await.credential_channel.clone()}; - let (sender, receiver) = oneshot::channel::<()>(); - credential_channel.send(CredentialThreadRequest::PutCredential { uuid: credential_id, credential: value.clone(), callback: sender }); - receiver.await; + let (sender, receiver) = oneshot::channel::<PutResponse>(); + credential_channel.send(CredentialThreadRequest::PutCredential { uuid: credential_id, credential: value.clone(), callback: sender })?; + receiver.await??; {self.0.lock().await.credential_cache.insert(credential_id, value)}; + + Ok(()) } - pub async fn get_credential(&self, credential_id: Uuid) -> Result<String, String> { + pub async fn get_credential(&self, credential_id: Uuid) -> Result<String, PersistenceError> { if let Some(cached) = {self.0.lock().await.credential_cache.get(&credential_id).cloned()} { return Ok(cached) } let credential_channel = {self.0.lock().await.credential_channel.clone()}; - let (sender, receiver) = oneshot::channel::<String>(); - credential_channel.send(CredentialThreadRequest::GetCredential { uuid: credential_id, callback: sender }); - let retrieved = receiver.await; + let (sender, receiver) = oneshot::channel::<GetCredentialResponse>(); + credential_channel.send(CredentialThreadRequest::GetCredential { uuid: credential_id, callback: sender })?; + let retrieved = receiver.await??; - if retrieved.is_err() { - return Err(format!("Could not retrieve credential: {credential_id}")) - } - - let retrieved = retrieved.unwrap(); {self.0.lock().await.credential_cache.insert(credential_id, retrieved.clone())}; return Ok(retrieved) } - async fn persist_disk(&self) { + async fn persist_disk(&self) -> Result<(), PersistenceError> { let (disk_channel, accounts, servers) = { let state = self.0.lock().await; (state.disk_channel.clone(), state.accounts.clone(), state.servers.clone()) }; - let (sender, receiver) = oneshot::channel::<()>(); + let (sender, receiver) = oneshot::channel::<PutResponse>(); disk_channel.send(DiskThreadRequest::Write { state: DiskState { accounts, servers, - }, callback: sender}); + }, callback: sender})?; - receiver.await; + receiver.await??; + + Ok(()) } } @@ -211,53 +259,44 @@ impl DiskThread { fn start(receiver: mpsc::Receiver<DiskThreadRequest>) { loop { match receiver.recv() { - Ok(DiskThreadRequest::Write { state, callback }) => Self::write(state, callback), - Ok(DiskThreadRequest::Read { callback }) => Self::read(callback), - Ok(DiskThreadRequest::Close) => return, - Err(_) => { - println!("Disk thread hung up on, exiting"); - return - }, + Ok(DiskThreadRequest::Write { state, callback }) => callback.send(Self::write(state)).expect("Disk thread hung up on"), + Ok(DiskThreadRequest::Read { callback }) => callback.send(Self::read()).expect("Disk thread hung up on"), + Err(_) => panic!("Disk thread hung up on, exiting"), } } } - fn write(state: DiskState, callback: oneshot::Sender<()>) { - Self::ensure_dir_exists(); + fn write(state: DiskState) -> PutResponse { + Self::ensure_dir_exists()?; let file_path = Self::file_path(); let data = toml::to_string(&state).unwrap(); - fs::write(file_path, data).unwrap(); - callback.send(()).unwrap(); + fs::write(file_path, data)?; + + Ok(()) } - fn read(callback: oneshot::Sender<DiskState>) { - Self::ensure_dir_exists(); + fn read() -> LoadDiskResponse { + Self::ensure_dir_exists()?; let file_path = Self::file_path(); - let data = match fs::read_to_string(file_path) { - Ok(data) => data, - Err(_) => { - let _ = callback.send(DiskState { - servers: Vec::new(), - accounts: Vec::new(), - }); - return - }, - }; + let data = fs::read_to_string(file_path)?; let result : DiskState = toml::from_str(&data).unwrap(); - let _ = callback.send(result); + + Ok(result) } fn file_path() -> PathBuf { dirs::data_dir().unwrap().join("foxfleet/data.toml") } - fn ensure_dir_exists() { - let dir = dirs::data_dir().unwrap().join("foxfleet"); - if !fs::exists(&dir).unwrap() { - fs::create_dir(&dir).expect("Could not create data directory"); + fn ensure_dir_exists() -> Result<(), PersistenceError> { + let dir = dirs::data_dir().ok_or(PersistenceError::NoDataDir)?.join("foxfleet"); + if !fs::exists(&dir)? { + fs::create_dir(&dir)?; } + + Ok(()) } } @@ -266,26 +305,22 @@ impl CredentialThread { fn start(receiver: mpsc::Receiver<CredentialThreadRequest>) { loop { match receiver.recv() { - Ok(CredentialThreadRequest::PutCredential { uuid, credential, callback }) => Self::put(uuid, credential, callback), - Ok(CredentialThreadRequest::GetCredential { uuid, callback }) => Self::get(uuid, callback), - Ok(CredentialThreadRequest::Close) => return, - Err(_) => { - println!("Credential thread hung up on, exiting"); - return - }, + Ok(CredentialThreadRequest::PutCredential { uuid, credential, callback }) => callback.send(Self::put(uuid, credential)).expect("Credential thread hung up on"), + Ok(CredentialThreadRequest::GetCredential { uuid, callback }) => callback.send(Self::get(uuid)).expect("Credential thread hung up on"), + Err(_) => panic!("Credential thread hung up on, exiting"), } } } - fn get(uuid: Uuid, callback: oneshot::Sender<String>) { + fn get(uuid: Uuid) -> GetCredentialResponse { let entry = Entry::new("dev.tempest.foxfleet", &uuid.to_string()).unwrap(); let credential = entry.get_password().unwrap(); - callback.send(credential).unwrap(); + Ok(credential) } - fn put(uuid: Uuid, credential: String, callback: oneshot::Sender<()>) { + fn put(uuid: Uuid, credential: String) -> PutResponse { let entry = Entry::new("dev.tempest.foxfleet", &uuid.to_string()).unwrap(); entry.set_password(credential.as_str()).unwrap(); - callback.send(()).unwrap(); + Ok(()) } } |