summary refs log tree commit diff
path: root/app/src/persistence.rs
diff options
context:
space:
mode:
Diffstat (limited to 'app/src/persistence.rs')
-rw-r--r--app/src/persistence.rs197
1 files changed, 116 insertions, 81 deletions
diff --git a/app/src/persistence.rs b/app/src/persistence.rs
index 028eded..8a0e7f2 100644
--- a/app/src/persistence.rs
+++ b/app/src/persistence.rs
@@ -18,8 +18,10 @@ use tokio::sync::Mutex;
 use tokio::sync::oneshot;
 use std::sync::mpsc;
 use std::fs;
+use std::sync::mpsc::SendError;
+use tokio::sync::oneshot::error::RecvError;
 
-#[derive(Clone, Serialize, Deserialize)]
+#[derive(Clone, Serialize, Deserialize, Debug)]
 pub struct Server {
     pub domain: String,
     pub client_name: String,
@@ -27,29 +29,76 @@ pub struct Server {
     pub client_credential_id: Uuid,
 }
 
-#[derive(Clone, Serialize, Deserialize)]
+#[derive(Clone, Serialize, Deserialize, Debug)]
 pub struct Account {
     pub username: String, // Full @handle@domain identifier, used for persistence and key storage
     pub server_domain: String, // Web domain, used for API access
     pub api_credential_id: Uuid,
 }
 
-#[derive(Clone, Serialize, Deserialize)]
+#[derive(Clone, Serialize, Deserialize, Debug)]
 struct DiskState {
     servers: Vec<Server>,
     accounts: Vec<Account>,
 }
 
+impl Default for DiskState {
+    fn default() -> Self {
+        Self {
+            servers: Vec::new(),
+            accounts: Vec::new(),
+        }
+    }
+}
+
+type PutResponse = Result<(), PersistenceError>;
+type LoadDiskResponse = Result<DiskState, PersistenceError>;
+type GetCredentialResponse = Result<String, PersistenceError>;
+
 enum CredentialThreadRequest {
-    PutCredential {uuid: Uuid, credential: String, callback: oneshot::Sender<()>},
-    GetCredential {uuid: Uuid, callback: oneshot::Sender<String>},
-    Close,
+    PutCredential {uuid: Uuid, credential: String, callback: oneshot::Sender<PutResponse>},
+    GetCredential {uuid: Uuid, callback: oneshot::Sender<GetCredentialResponse>},
 }
 
 enum DiskThreadRequest {
-    Write {state: DiskState, callback: oneshot::Sender<()> },
-    Read {callback: oneshot::Sender<DiskState>},
-    Close,
+    Write {state: DiskState, callback: oneshot::Sender<PutResponse> },
+    Read {callback: oneshot::Sender<LoadDiskResponse>},
+}
+
+#[derive(Debug)]
+pub enum PersistenceError {
+    RecvError(RecvError),
+    SendError,
+    IoError(std::io::Error),
+    NoDataDir,
+    AccountNotFound {username: String, server_domain: String},
+    ServerNotFound {server_domain: String},
+    AccountAlreadyExists {username: String, server_domain: String},
+    ServerAlreadyRegistered { server_domain: String },
+}
+
+impl From<RecvError> for PersistenceError {
+    fn from(value: RecvError) -> Self {
+        PersistenceError::RecvError(value)
+    }
+}
+
+impl<T> From<SendError<T>> for PersistenceError {
+    fn from(_value: SendError<T>) -> Self {
+        PersistenceError::SendError
+    }
+}
+
+impl From<std::io::Error> for PersistenceError {
+    fn from(value: std::io::Error) -> Self {
+        PersistenceError::IoError(value)
+    }
+}
+
+impl From<PersistenceError> for String {
+    fn from(value: PersistenceError) -> Self {
+        "PersistenceError".to_string()
+    }
 }
 
 struct PersistenceState {
@@ -81,16 +130,16 @@ impl PersistenceController {
             .spawn(move || CredentialThread::start(credential_receiver) )
             .expect("Could not spawn keyring thread");
 
-        // Quickly wait on the disk thread to load config
-        let (load_send, load_recv) = oneshot::channel::<DiskState>();
-        disk_sender.send(DiskThreadRequest::Read { callback: load_send }).expect("Could not load accounts");
-        let loaded_state = runtime_handle.block_on(async {
-            load_recv.await.expect("Could not load accounts")
-        });
+        // Load saved or default config on startup
+        let (load_send, load_recv) = oneshot::channel::<LoadDiskResponse>();
+
+        let initial_state = disk_sender.send(DiskThreadRequest::Read { callback: load_send }).ok().map(|()| {
+            runtime_handle.block_on(load_recv).ok().map(|response| response.ok())
+        }).flatten().flatten().unwrap_or_default();
 
         return PersistenceController(Arc::new(Mutex::new(PersistenceState {
-            servers: loaded_state.servers,
-            accounts: loaded_state.accounts,
+            servers: initial_state.servers,
+            accounts: initial_state.accounts,
             credential_cache: HashMap::new(),
             credential_channel: credential_sender,
             disk_channel: disk_sender,
@@ -100,11 +149,11 @@ impl PersistenceController {
         })))
     }
 
-    pub async fn new_server(&self, domain: String, client_id: String, client_name: String, client_secret: String) -> Result<Server, String> {
+    pub async fn new_server(&self, domain: String, client_id: String, client_name: String, client_secret: String) -> Result<Server, PersistenceError> {
         let has_server_with_domain = {self.0.lock().await.servers.iter().find(|s| s.domain == domain).is_some()};
 
         if has_server_with_domain {
-            return Err(format!("Server already exists in state with domain {domain}"));
+            return Err(PersistenceError::ServerAlreadyRegistered { server_domain: domain });
         }
 
         let client_credential_id = Uuid::new_v4();
@@ -115,9 +164,9 @@ impl PersistenceController {
             client_credential_id,
         };
 
-        self.persist_credential(client_credential_id, client_secret).await;
+        self.persist_credential(client_credential_id, client_secret).await?;
         {self.0.lock().await.servers.push(server.clone())};
-        self.persist_disk().await;
+        self.persist_disk().await?;
 
         return Ok(server)
     }
@@ -127,11 +176,11 @@ impl PersistenceController {
         return server
     }
 
-    pub async fn new_account(&self, username: String, server_domain: String, api_token: String) -> Result<Account, String> {
+    pub async fn new_account(&self, username: String, server_domain: String, api_token: String) -> Result<Account, PersistenceError> {
         let has_account_already = {self.0.lock().await.accounts.iter().find(|a| a.username == username && a.server_domain == server_domain).is_some()};
 
         if has_account_already {
-            return Err(format!("Account already exists for @{username}@{server_domain}"));
+            return Err(PersistenceError::AccountAlreadyExists {username, server_domain});
         }
 
         let api_credential_id = Uuid::new_v4();
@@ -141,9 +190,9 @@ impl PersistenceController {
             api_credential_id
         };
 
-        self.persist_credential(api_credential_id, api_token).await;
+        self.persist_credential(api_credential_id, api_token).await?;
         {self.0.lock().await.accounts.push(account.clone())};
-        self.persist_disk().await;
+        self.persist_disk().await?;
 
         return Ok(account)
     }
@@ -158,51 +207,50 @@ impl PersistenceController {
         return accounts
     }
 
-    async fn persist_credential(&self, credential_id: Uuid, value: String) {
+    async fn persist_credential(&self, credential_id: Uuid, value: String) -> Result<(), PersistenceError> {
         let credential_channel = {self.0.lock().await.credential_channel.clone()};
 
-        let (sender, receiver) = oneshot::channel::<()>();
-        credential_channel.send(CredentialThreadRequest::PutCredential { uuid: credential_id, credential: value.clone(), callback: sender });
-        receiver.await;
+        let (sender, receiver) = oneshot::channel::<PutResponse>();
+        credential_channel.send(CredentialThreadRequest::PutCredential { uuid: credential_id, credential: value.clone(), callback: sender })?;
+        receiver.await??;
 
         {self.0.lock().await.credential_cache.insert(credential_id, value)};
+
+        Ok(())
     }
 
-    pub async fn get_credential(&self, credential_id: Uuid) -> Result<String, String> {
+    pub async fn get_credential(&self, credential_id: Uuid) -> Result<String, PersistenceError> {
         if let Some(cached) = {self.0.lock().await.credential_cache.get(&credential_id).cloned()} {
             return Ok(cached)
         }
 
         let credential_channel = {self.0.lock().await.credential_channel.clone()};
 
-        let (sender, receiver) = oneshot::channel::<String>();
-        credential_channel.send(CredentialThreadRequest::GetCredential { uuid: credential_id, callback: sender });
-        let retrieved = receiver.await;
+        let (sender, receiver) = oneshot::channel::<GetCredentialResponse>();
+        credential_channel.send(CredentialThreadRequest::GetCredential { uuid: credential_id, callback: sender })?;
+        let retrieved = receiver.await??;
 
-        if retrieved.is_err() {
-            return Err(format!("Could not retrieve credential: {credential_id}"))
-        }
-
-        let retrieved = retrieved.unwrap();
         {self.0.lock().await.credential_cache.insert(credential_id, retrieved.clone())};
 
         return Ok(retrieved)
     }
 
-    async fn persist_disk(&self) {
+    async fn persist_disk(&self) -> Result<(), PersistenceError> {
         let (disk_channel, accounts, servers) = {
             let state = self.0.lock().await;
             (state.disk_channel.clone(), state.accounts.clone(), state.servers.clone())
         };
 
-        let (sender, receiver) = oneshot::channel::<()>();
+        let (sender, receiver) = oneshot::channel::<PutResponse>();
 
         disk_channel.send(DiskThreadRequest::Write { state: DiskState {
             accounts,
             servers,
-        }, callback: sender});
+        }, callback: sender})?;
 
-        receiver.await;
+        receiver.await??;
+
+        Ok(())
     }
 }
 
@@ -211,53 +259,44 @@ impl DiskThread {
     fn start(receiver: mpsc::Receiver<DiskThreadRequest>) {
         loop {
             match receiver.recv() {
-                Ok(DiskThreadRequest::Write { state, callback }) => Self::write(state, callback),
-                Ok(DiskThreadRequest::Read { callback }) => Self::read(callback),
-                Ok(DiskThreadRequest::Close) => return,
-                Err(_) => {
-                    println!("Disk thread hung up on, exiting");
-                    return
-                },
+                Ok(DiskThreadRequest::Write { state, callback }) => callback.send(Self::write(state)).expect("Disk thread hung up on"),
+                Ok(DiskThreadRequest::Read { callback }) => callback.send(Self::read()).expect("Disk thread hung up on"),
+                Err(_) => panic!("Disk thread hung up on, exiting"),
             }
         }
     }
 
-    fn write(state: DiskState, callback: oneshot::Sender<()>) {
-        Self::ensure_dir_exists();
+    fn write(state: DiskState) -> PutResponse {
+        Self::ensure_dir_exists()?;
         let file_path = Self::file_path();
         let data = toml::to_string(&state).unwrap();
 
-        fs::write(file_path, data).unwrap();
-        callback.send(()).unwrap();
+        fs::write(file_path, data)?;
+
+        Ok(())
     }
 
-    fn read(callback: oneshot::Sender<DiskState>) {
-        Self::ensure_dir_exists();
+    fn read() -> LoadDiskResponse {
+        Self::ensure_dir_exists()?;
         let file_path = Self::file_path();
-        let data = match fs::read_to_string(file_path) {
-            Ok(data) => data,
-            Err(_) => {
-                let _ = callback.send(DiskState {
-                    servers: Vec::new(),
-                    accounts: Vec::new(),
-                });
-                return
-            },
-        };
+        let data = fs::read_to_string(file_path)?;
 
         let result : DiskState = toml::from_str(&data).unwrap();
-        let _ = callback.send(result);
+
+        Ok(result)
     }
 
     fn file_path() -> PathBuf {
         dirs::data_dir().unwrap().join("foxfleet/data.toml")
     }
 
-    fn ensure_dir_exists() {
-        let dir = dirs::data_dir().unwrap().join("foxfleet");
-        if !fs::exists(&dir).unwrap() {
-            fs::create_dir(&dir).expect("Could not create data directory");
+    fn ensure_dir_exists() -> Result<(), PersistenceError> {
+        let dir = dirs::data_dir().ok_or(PersistenceError::NoDataDir)?.join("foxfleet");
+        if !fs::exists(&dir)? {
+            fs::create_dir(&dir)?;
         }
+
+        Ok(())
     }
 }
 
@@ -266,26 +305,22 @@ impl CredentialThread {
     fn start(receiver: mpsc::Receiver<CredentialThreadRequest>) {
         loop {
             match receiver.recv() {
-                Ok(CredentialThreadRequest::PutCredential { uuid, credential, callback }) => Self::put(uuid, credential, callback),
-                Ok(CredentialThreadRequest::GetCredential { uuid, callback }) => Self::get(uuid, callback),
-                Ok(CredentialThreadRequest::Close) => return,
-                Err(_) => {
-                    println!("Credential thread hung up on, exiting");
-                    return
-                },
+                Ok(CredentialThreadRequest::PutCredential { uuid, credential, callback }) => callback.send(Self::put(uuid, credential)).expect("Credential thread hung up on"),
+                Ok(CredentialThreadRequest::GetCredential { uuid, callback }) => callback.send(Self::get(uuid)).expect("Credential thread hung up on"),
+                Err(_) => panic!("Credential thread hung up on, exiting"),
             }
         }
     }
 
-    fn get(uuid: Uuid, callback: oneshot::Sender<String>) {
+    fn get(uuid: Uuid) -> GetCredentialResponse {
         let entry = Entry::new("dev.tempest.foxfleet", &uuid.to_string()).unwrap();
         let credential = entry.get_password().unwrap();
-        callback.send(credential).unwrap();
+        Ok(credential)
     }
 
-    fn put(uuid: Uuid, credential: String, callback: oneshot::Sender<()>) {
+    fn put(uuid: Uuid, credential: String) -> PutResponse {
         let entry = Entry::new("dev.tempest.foxfleet", &uuid.to_string()).unwrap();
         entry.set_password(credential.as_str()).unwrap();
-        callback.send(()).unwrap();
+        Ok(())
     }
 }