use dotenvy::dotenv; use serde::{Deserialize, Serialize}; use serenity::{ model::{ guild, id::{GuildId, RoleId}, }, prelude::TypeMapKey, }; use chrono::{Datelike, SecondsFormat, Utc}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use sqlx::{ sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteRow}, Error, FromRow, Pool, Row, Sqlite, }; use std::{env, str::FromStr, sync::Arc}; use tokio::sync::RwLock; pub struct Config { pub skynet_server: GuildId, pub ldap_api: String, pub home: String, pub database: String, pub csv: String, pub auth: String, pub discord_token: String, pub mail_smtp: String, pub mail_user: String, pub mail_pass: String, } impl TypeMapKey for Config { type Value = Arc>; } pub struct DataBase; impl TypeMapKey for DataBase { type Value = Arc>>; } pub fn get_config() -> Config { dotenv().ok(); // reasonable defaults let mut config = Config { skynet_server: Default::default(), ldap_api: "https://api.account.skynet.ie".to_string(), auth: "".to_string(), discord_token: "".to_string(), home: ".".to_string(), database: "database.db".to_string(), csv: "wolves.csv".to_string(), mail_smtp: "".to_string(), mail_user: "".to_string(), mail_pass: "".to_string(), }; if let Ok(x) = env::var("LDAP_API") { config.ldap_api = x.trim().to_string(); } if let Ok(x) = env::var("SKYNET_SERVER") { config.skynet_server = GuildId::from(str_to_num::(&x)); } if let Ok(x) = env::var("HOME") { config.home = x.trim().to_string(); } if let Ok(x) = env::var("DATABASE") { config.database = x.trim().to_string(); } if let Ok(x) = env::var("CSV") { config.csv = x.trim().to_string(); } if let Ok(x) = env::var("LDAP_DISCORD_AUTH") { config.auth = x.trim().to_string(); } if let Ok(x) = env::var("DISCORD_TOKEN") { config.discord_token = x.trim().to_string(); } if let Ok(x) = env::var("EMAIL_SMTP") { config.mail_smtp = x.trim().to_string(); } if let Ok(x) = env::var("EMAIL_USER") { config.mail_user = x.trim().to_string(); } if let Ok(x) = env::var("EMAIL_PASS") { config.mail_pass = x.trim().to_string(); } config } fn str_to_num(x: &str) -> T { x.trim().parse::().unwrap_or_default() } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ServerMembers { pub server: GuildId, pub id_wolves: String, pub expiry: String, } impl<'r> FromRow<'r, SqliteRow> for ServerMembers { fn from_row(row: &'r SqliteRow) -> Result { let server_tmp: i64 = row.try_get("server")?; let server = GuildId::from(server_tmp as u64); Ok(Self { server, id_wolves: row.try_get("id_wolves")?, expiry: row.try_get("expiry")?, }) } } #[derive(Debug, Clone, Deserialize, Serialize, sqlx::FromRow)] pub struct Wolves { pub id_wolves: String, pub email: String, pub verified: bool, pub discord: Option, pub minecraft: Option, } #[derive(Debug, Clone, Deserialize, Serialize, sqlx::FromRow)] pub struct WolvesVerify { pub email: String, pub discord: String, pub auth_code: String, pub date_expiry: String, } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Servers { pub server: GuildId, pub wolves_api: String, pub role_past: Option, pub role_current: Option, pub member_past: i64, pub member_current: i64, } impl<'r> FromRow<'r, SqliteRow> for Servers { fn from_row(row: &'r SqliteRow) -> Result { let server_tmp: i64 = row.try_get("server")?; let server = GuildId::from(server_tmp as u64); let role_past = match row.try_get("role_past") { Ok(x) => { let tmp: i64 = x; Some(RoleId::from(tmp as u64)) } _ => None, }; let role_current = match row.try_get("role_current") { Ok(x) => { let tmp: i64 = x; Some(RoleId::from(tmp as u64)) } _ => None, }; Ok(Self { server, wolves_api: row.try_get("wolves_api")?, role_past, role_current, member_past: row.try_get("member_past")?, member_current: row.try_get("member_current")?, }) } } pub async fn db_init(config: &Config) -> Result, Error> { let database = format!("{}/{}", &config.home, &config.database); let pool = SqlitePoolOptions::new() .max_connections(5) .connect_with( SqliteConnectOptions::from_str(&format!("sqlite://{}", database))? .foreign_keys(true) .create_if_missing(true), ) .await?; sqlx::query( "CREATE TABLE IF NOT EXISTS wolves ( id_wolves text PRIMARY KEY, email text not null, discord text, minecraft text, verified integer DEFAULT FALSE )", ) .execute(&pool) .await?; sqlx::query("CREATE INDEX IF NOT EXISTS index_discord ON wolves (discord)").execute(&pool).await?; sqlx::query( "CREATE TABLE IF NOT EXISTS wolves_verify ( discord text PRIMARY KEY, email text not null, auth_code text not null, date_expiry text not null )", ) .execute(&pool) .await?; sqlx::query("CREATE INDEX IF NOT EXISTS index_date_expiry ON wolves_verify (date_expiry)") .execute(&pool) .await?; sqlx::query( "CREATE TABLE IF NOT EXISTS server_members ( server integer not null, id_wolves text not null, expiry text not null, PRIMARY KEY(server,id_wolves), FOREIGN KEY (id_wolves) REFERENCES wolves (id_wolves) )", ) .execute(&pool) .await?; sqlx::query( "CREATE TABLE IF NOT EXISTS servers ( server integer PRIMARY KEY, wolves_api text not null, role_past integer, role_current integer, member_past integer DEFAULT 0, member_current integer DEFAULT 0 )", ) .execute(&pool) .await?; Ok(pool) } pub async fn get_server_config(db: &Pool, server: &GuildId) -> Option { sqlx::query_as::<_, Servers>( r#" SELECT * FROM servers WHERE server = ? "#, ) .bind(*server.as_u64() as i64) .fetch_one(db) .await .ok() } pub async fn get_server_member(db: &Pool, server: &GuildId, member: &guild::Member) -> Option { let wolves_data = sqlx::query_as::<_, Wolves>( r#" SELECT * FROM wolves WHERE discord = ? "#, ) .bind(*server.as_u64() as i64) .bind(&member.user.name) .fetch_one(db) .await; if let Ok(user_wolves) = wolves_data { // check if the suer is on the server return sqlx::query_as::<_, ServerMembers>( r#" SELECT * FROM server_members WHERE server = ? AND id_wolves = ? "#, ) .bind(*server.as_u64() as i64) .bind(&user_wolves.id_wolves) .fetch_one(db) .await .ok(); } None } pub async fn get_server_config_bulk(db: &Pool) -> Vec { sqlx::query_as::<_, Servers>( r#" SELECT * FROM servers "#, ) .fetch_all(db) .await .unwrap_or_default() } pub fn get_now_iso(short: bool) -> String { let now = Utc::now(); if short { format!("{}-{:02}-{:02}", now.year(), now.month(), now.day()) } else { now.to_rfc3339_opts(SecondsFormat::Millis, true) } } pub fn random_string(len: usize) -> String { thread_rng().sample_iter(&Alphanumeric).take(len).map(char::from).collect() }