use serenity::prelude::TypeMapKey; use std::sync::Arc; use tokio::sync::RwLock; use sqlx::{Error, FromRow, Pool, Row, Sqlite}; use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteRow}; use serenity::model::id::{ChannelId, GuildId, RoleId, UserId}; use serde::{Deserialize, Serialize}; use serenity::model::guild; use std::str::FromStr; use crate::Config; pub struct DataBase; impl TypeMapKey for DataBase { type Value = Arc>>; } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ServerMembers { pub server: GuildId, pub id_wolves: i64, pub expiry: String, } impl<'r> FromRow<'r, SqliteRow> for ServerMembers { fn from_row(row: &'r SqliteRow) -> Result { let server_tmp: i64 = row.try_get("server")?; let server = GuildId::from(server_tmp as u64); Ok(Self { server, id_wolves: row.try_get("id_wolves")?, expiry: row.try_get("expiry")?, }) } } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ServerMembersWolves { pub server: GuildId, pub id_wolves: i64, pub expiry: String, pub email: String, pub discord: Option, pub minecraft: Option, } impl<'r> FromRow<'r, SqliteRow> for ServerMembersWolves { fn from_row(row: &'r SqliteRow) -> Result { let server_tmp: i64 = row.try_get("server")?; let server = GuildId::from(server_tmp as u64); let discord = match row.try_get("discord") { Ok(x) => { let tmp: i64 = x; if tmp == 0 { None } else { Some(UserId::from(tmp as u64)) } } _ => None, }; Ok(Self { server, id_wolves: row.try_get("id_wolves")?, expiry: row.try_get("expiry")?, email: row.try_get("email")?, discord, minecraft: row.try_get("minecraft")?, }) } } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Wolves { pub id_wolves: i64, pub email: String, pub discord: Option, pub minecraft: Option, } impl<'r> FromRow<'r, SqliteRow> for Wolves { fn from_row(row: &'r SqliteRow) -> Result { let discord = match row.try_get("discord") { Ok(x) => { let tmp: i64 = x; if tmp == 0 { None } else { Some(UserId::from(tmp as u64)) } } _ => None, }; Ok(Self { id_wolves: row.try_get("id_wolves")?, email: row.try_get("email")?, discord, minecraft: row.try_get("minecraft")?, }) } } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct WolvesVerify { pub email: String, pub discord: UserId, pub auth_code: String, pub date_expiry: String, } impl<'r> FromRow<'r, SqliteRow> for WolvesVerify { fn from_row(row: &'r SqliteRow) -> Result { let user_tmp: i64 = row.try_get("discord")?; let discord = UserId::from(user_tmp as u64); Ok(Self { email: row.try_get("email")?, discord, auth_code: row.try_get("auth_code")?, date_expiry: row.try_get("date_expiry")?, }) } } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Servers { pub server: GuildId, pub wolves_api: String, pub role_past: Option, pub role_current: RoleId, pub member_past: i64, pub member_current: i64, pub bot_channel_id: ChannelId, // TODO: these can be removed in teh future with an API update pub server_name: String, pub wolves_link: String, } impl<'r> FromRow<'r, SqliteRow> for Servers { fn from_row(row: &'r SqliteRow) -> Result { let server_tmp: i64 = row.try_get("server")?; let server = GuildId::from(server_tmp as u64); let role_past = match row.try_get("role_past") { Ok(x) => { let tmp: i64 = x; if tmp == 0 { None } else { Some(RoleId::from(tmp as u64)) } } _ => None, }; let role_current = match row.try_get("role_current") { Ok(x) => { let tmp: i64 = x; RoleId::from(tmp as u64) } _ => RoleId::from(0u64), }; let bot_channel_tmp: i64 = row.try_get("bot_channel_id")?; let bot_channel_id = ChannelId::from(bot_channel_tmp as u64); Ok(Self { server, wolves_api: row.try_get("wolves_api")?, role_past, role_current, member_past: row.try_get("member_past")?, member_current: row.try_get("member_current")?, bot_channel_id, server_name: row.try_get("server_name")?, wolves_link: row.try_get("wolves_link")?, }) } } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct RoleAdder { pub server: GuildId, pub role_a: RoleId, pub role_b: RoleId, pub role_c: RoleId, } impl<'r> FromRow<'r, SqliteRow> for RoleAdder { fn from_row(row: &'r SqliteRow) -> Result { let server_tmp: i64 = row.try_get("server")?; let server = GuildId::from(server_tmp as u64); Ok(Self { server, role_a: get_role_from_row(row, "role_a"), role_b: get_role_from_row(row, "role_b"), role_c: get_role_from_row(row, "role_c"), }) } } fn get_role_from_row(row: &SqliteRow, col: &str) -> RoleId { match row.try_get(col) { Ok(x) => { let tmp: i64 = x; RoleId(tmp as u64) } _ => RoleId::from(0u64), } } pub async fn db_init(config: &Config) -> Result, Error> { let database = format!("{}/{}", &config.home, &config.database); let pool = SqlitePoolOptions::new() .max_connections(5) .connect_with( SqliteConnectOptions::from_str(&format!("sqlite://{}", database))? .foreign_keys(true) .create_if_missing(true), ) .await?; // migrations are amazing! sqlx::migrate!("./db/migrations").run(&pool).await?; Ok(pool) } pub async fn get_server_config(db: &Pool, server: &GuildId) -> Option { sqlx::query_as::<_, Servers>( r#" SELECT * FROM servers WHERE server = ? "#, ) .bind(*server.as_u64() as i64) .fetch_one(db) .await .ok() } pub async fn get_server_member(db: &Pool, server: &GuildId, member: &guild::Member) -> Result { sqlx::query_as::<_, ServerMembersWolves>( r#" SELECT * FROM server_members JOIN wolves USING (id_wolves) WHERE server = ? AND discord = ? "#, ) .bind(*server.as_u64() as i64) .bind(*member.user.id.as_u64() as i64) .fetch_one(db) .await } pub async fn get_server_config_bulk(db: &Pool) -> Vec { sqlx::query_as::<_, Servers>( r#" SELECT * FROM servers "#, ) .fetch_all(db) .await .unwrap_or_default() }