🗝
summary refs log tree commit diff
path: root/src/server/store.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/server/store.rs')
-rw-r--r--src/server/store.rs330
1 files changed, 330 insertions, 0 deletions
diff --git a/src/server/store.rs b/src/server/store.rs
new file mode 100644
index 0000000..98c1bcc
--- /dev/null
+++ b/src/server/store.rs
@@ -0,0 +1,330 @@
+use std::{
+    collections::HashMap,
+    path::PathBuf,
+    sync::Arc,
+    time::{Duration, SystemTime, UNIX_EPOCH},
+};
+
+use argon2::{
+    password_hash::{rand_core::OsRng, Encoding, PasswordHashString, SaltString},
+    Argon2, PasswordHash, PasswordHasher,
+};
+use nanoid::nanoid;
+use serde::{de::Unexpected, Deserialize, Serialize};
+use serde_json::{Map, Value};
+use tap::{Pipe, Tap};
+use tokio::sync::{RwLock, RwLockReadGuard};
+
+#[derive(Clone)]
+pub struct Store(Arc<RwLock<StoreInner>>);
+
+impl Store {
+    pub async fn load(path: PathBuf) -> eyre::Result<Self> {
+        let mut inner = StoreInner::new(path.clone());
+        if path.try_exists()? {
+            let mut map: Map<String, Value> =
+                { std::fs::File::open(&path)?.pipe(serde_json::from_reader)? };
+            let accounts: Vec<Account> =
+                serde_json::from_value(map.remove("accounts").unwrap_or_default())?;
+            let apps: Vec<App> = serde_json::from_value(map.remove("apps").unwrap_or_default())?;
+            for mut account in accounts {
+                account.tokens = account
+                    .tokens
+                    .into_iter()
+                    .filter(|token| token.expires > SystemTime::now())
+                    .collect();
+                for token in &account.tokens {
+                    inner.token_map.insert(
+                        token.value.clone(),
+                        (account.name.clone(), token.expires.clone()),
+                    );
+                }
+                inner.accounts.insert(account.name.clone(), account);
+            }
+            for app in apps {
+                inner.apps.insert(app.name.clone(), app);
+            }
+            inner.invites = serde_json::from_value(map.remove("invites").unwrap_or_default())?;
+            inner.invites = inner
+                .invites
+                .into_iter()
+                .filter(|invite| invite.expires > SystemTime::now())
+                .collect();
+        } else {
+            Self::save(&mut inner).await?;
+        }
+        Ok(Store(Arc::new(RwLock::new(inner))))
+    }
+
+    async fn save(inner: &mut StoreInner) -> std::io::Result<()> {
+        let mut map: Map<String, Value> = Map::new();
+        map.insert(
+            "accounts".into(),
+            serde_json::to_value(inner.accounts.values().collect::<Vec<_>>()).unwrap(),
+        );
+        map.insert(
+            "apps".into(),
+            serde_json::to_value(inner.apps.values().collect::<Vec<_>>()).unwrap(),
+        );
+        map.insert(
+            "invites".into(),
+            serde_json::to_value(&inner.invites).unwrap(),
+        );
+        let data = serde_json::to_vec_pretty(&map).unwrap();
+        let mut temp = inner.path.clone();
+        temp.set_file_name(
+            temp.file_name()
+                .unwrap_or_default()
+                .to_os_string()
+                .tap_mut(|name| name.push(".tmp")),
+        );
+        tokio::fs::write(&temp, data).await?;
+        tokio::fs::rename(temp, &inner.path).await?;
+        Ok(())
+    }
+
+    pub async fn get_account(&self, name: &str) -> Option<RwLockReadGuard<'_, Account>> {
+        let guard = self.0.read().await;
+        RwLockReadGuard::try_map(guard, |guard| guard.accounts.get(name)).ok()
+    }
+
+    pub async fn create_account(&self, name: &str, password: &str) -> std::io::Result<bool> {
+        let hash = Argon2::default()
+            .hash_password(password.as_bytes(), &SaltString::generate(&mut OsRng))
+            .unwrap()
+            .pipe(|hash| hash.serialize())
+            .pipe(OwnedPasswordHash::from);
+        let mut guard = self.0.write().await;
+        if guard.accounts.get(name).is_some() {
+            return Ok(false);
+        }
+        guard.accounts.insert(
+            name.to_string(),
+            Account {
+                name: name.to_string(),
+                password: hash,
+                tokens: Default::default(),
+                scopes: Default::default(),
+            },
+        );
+        Self::save(&mut guard).await?;
+        Ok(true)
+    }
+
+    pub async fn delete_account(&self, name: &str) -> std::io::Result<()> {
+        let mut guard = self.0.write().await;
+        if let Some(account) = guard.accounts.remove(name) {
+            for token in account.tokens {
+                guard.token_map.remove(&token.value);
+            }
+            Self::save(&mut guard).await
+        } else {
+            Ok(())
+        }
+    }
+
+    pub async fn update_account(
+        &self,
+        name: &str,
+        with: impl FnOnce(&mut Account),
+    ) -> std::io::Result<bool> {
+        let mut guard = self.0.write().await;
+        if let Some(account) = guard.accounts.get_mut(name) {
+            with(account);
+            Self::save(&mut guard).await.map(|_| true)
+        } else {
+            Ok(false)
+        }
+    }
+
+    pub async fn list_accounts(&self) -> Vec<String> {
+        self.0
+            .read()
+            .await
+            .accounts
+            .values()
+            .map(|account| account.name.clone())
+            .collect()
+    }
+
+    pub async fn create_token(&self, name: &str) -> std::io::Result<Option<String>> {
+        let mut guard = self.0.write().await;
+        let token = nanoid!(32);
+        if let Some(account) = guard.accounts.get_mut(name) {
+            let expires = SystemTime::now() + Duration::from_secs(60 * 60 * 24 * 30);
+            account.tokens.push(ExpiringValue {
+                value: token.clone(),
+                expires: expires.clone(),
+            });
+            guard
+                .token_map
+                .insert(token.clone(), (name.to_string(), expires));
+            Self::save(&mut guard).await?;
+            Ok(Some(token))
+        } else {
+            Ok(None)
+        }
+    }
+
+    pub async fn check_token(&self, token: &str) -> Option<String> {
+        let guard = self.0.read().await;
+        let Some((name, expires)) = guard.token_map.get(token) else {
+            return None;
+        };
+        if *expires < SystemTime::now() {
+            return None;
+        }
+        Some(name.clone())
+    }
+
+    pub async fn create_invite(&self) -> std::io::Result<String> {
+        let mut guard = self.0.write().await;
+        let invite = nanoid!(32);
+        let expires = SystemTime::now() + Duration::from_secs(60 * 60 * 24 * 14);
+        guard.invites.push(ExpiringValue {
+            value: invite.clone(),
+            expires,
+        });
+        Self::save(&mut guard).await?;
+        Ok(invite)
+    }
+
+    pub async fn check_invite(&self, invite: &str) -> bool {
+        let guard = self.0.read().await;
+        let now = SystemTime::now();
+        guard
+            .invites
+            .iter()
+            .any(|check| check.expires > now && check.value == invite)
+    }
+
+    pub async fn use_invite(&self, invite: &str) -> std::io::Result<bool> {
+        let mut guard = self.0.write().await;
+        let now = SystemTime::now();
+        if let Some((index, _)) = guard
+            .invites
+            .iter()
+            .enumerate()
+            .find(|(_, check)| check.expires > now && check.value == invite)
+        {
+            guard.invites.swap_remove(index);
+            Self::save(&mut guard).await?;
+            Ok(true)
+        } else {
+            Ok(false)
+        }
+    }
+}
+
+pub struct StoreInner {
+    path: PathBuf,
+    accounts: HashMap<String, Account>,
+    apps: HashMap<String, App>,
+    invites: Vec<ExpiringValue>,
+    token_map: HashMap<String, (String, SystemTime)>,
+}
+
+impl StoreInner {
+    fn new(at: PathBuf) -> Self {
+        Self {
+            path: at,
+            accounts: Default::default(),
+            apps: Default::default(),
+            invites: Default::default(),
+            token_map: Default::default(),
+        }
+    }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct Account {
+    pub name: String,
+    pub password: OwnedPasswordHash,
+    pub tokens: Vec<ExpiringValue>,
+    pub scopes: Vec<String>,
+}
+
+#[ouroboros::self_referencing]
+pub struct OwnedPasswordHash {
+    owned: PasswordHashString,
+    #[borrows(owned)]
+    #[not_covariant]
+    parsed: PasswordHash<'this>,
+}
+
+impl OwnedPasswordHash {
+    pub fn from(inner: PasswordHashString) -> Self {
+        OwnedPasswordHash::new(inner, |inner| inner.password_hash())
+    }
+
+    pub fn parsed(&self) -> PasswordHash {
+        self.with_parsed(|x| x.clone())
+    }
+}
+
+impl Serialize for OwnedPasswordHash {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        self.with_owned(|owned| owned.as_str().serialize(serializer))
+    }
+}
+
+impl<'de> Deserialize<'de> for OwnedPasswordHash {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        Ok(String::deserialize(deserializer)?
+            .pipe(|hash| {
+                PasswordHashString::parse(&hash, Encoding::B64).map_err(|_| {
+                    <D::Error as serde::de::Error>::invalid_value(
+                        Unexpected::Str(&hash),
+                        &"valid password hash",
+                    )
+                })
+            })?
+            .pipe(OwnedPasswordHash::from))
+    }
+}
+
+pub struct ExpiringValue {
+    pub value: String,
+    pub expires: SystemTime,
+}
+
+impl Serialize for ExpiringValue {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        <(&str, u64)>::serialize(
+            &(
+                &self.value,
+                (self.expires.duration_since(SystemTime::UNIX_EPOCH))
+                    .unwrap()
+                    .as_secs(),
+            ),
+            serializer,
+        )
+    }
+}
+
+impl<'de> Deserialize<'de> for ExpiringValue {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        <(String, u64)>::deserialize(deserializer).map(|(token, unix)| ExpiringValue {
+            value: token,
+            expires: UNIX_EPOCH + Duration::from_secs(unix),
+        })
+    }
+}
+
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct App {
+    name: String,
+    secret: String,
+}