use dotenvy::dotenv; use serde::{Deserialize, Serialize}; use serenity::{ model::{ guild, id::{GuildId, RoleId}, }, prelude::TypeMapKey, }; use chrono::{Datelike, SecondsFormat, Utc}; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use serenity::client::Context; use serenity::model::id::UserId; use sqlx::{ sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteRow}, Error, FromRow, Pool, Row, Sqlite, }; use std::{env, str::FromStr, sync::Arc}; use tokio::sync::RwLock; pub struct Config { pub skynet_server: GuildId, pub ldap_api: String, pub home: String, pub database: String, pub auth: String, pub discord_token: String, pub mail_smtp: String, pub mail_user: String, pub mail_pass: String, pub wolves_url: String, } impl TypeMapKey for Config { type Value = Arc>; } pub struct DataBase; impl TypeMapKey for DataBase { type Value = Arc>>; } pub fn get_config() -> Config { dotenv().ok(); // reasonable defaults let mut config = Config { skynet_server: Default::default(), ldap_api: "https://api.account.skynet.ie".to_string(), auth: "".to_string(), discord_token: "".to_string(), home: ".".to_string(), database: "database.db".to_string(), mail_smtp: "".to_string(), mail_user: "".to_string(), mail_pass: "".to_string(), wolves_url: "".to_string(), }; if let Ok(x) = env::var("LDAP_API") { config.ldap_api = x.trim().to_string(); } if let Ok(x) = env::var("SKYNET_SERVER") { config.skynet_server = GuildId::from(str_to_num::(&x)); } if let Ok(x) = env::var("HOME") { config.home = x.trim().to_string(); } if let Ok(x) = env::var("DATABASE") { config.database = x.trim().to_string(); } if let Ok(x) = env::var("LDAP_DISCORD_AUTH") { config.auth = x.trim().to_string(); } if let Ok(x) = env::var("DISCORD_TOKEN") { config.discord_token = x.trim().to_string(); } if let Ok(x) = env::var("EMAIL_SMTP") { config.mail_smtp = x.trim().to_string(); } if let Ok(x) = env::var("EMAIL_USER") { config.mail_user = x.trim().to_string(); } if let Ok(x) = env::var("EMAIL_PASS") { config.mail_pass = x.trim().to_string(); } if let Ok(x) = env::var("WOLVES_URL") { config.wolves_url = x.trim().to_string(); } config } fn str_to_num(x: &str) -> T { x.trim().parse::().unwrap_or_default() } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ServerMembers { pub server: GuildId, pub id_wolves: i64, pub expiry: String, } impl<'r> FromRow<'r, SqliteRow> for ServerMembers { fn from_row(row: &'r SqliteRow) -> Result { let server_tmp: i64 = row.try_get("server")?; let server = GuildId::from(server_tmp as u64); Ok(Self { server, id_wolves: row.try_get("id_wolves")?, expiry: row.try_get("expiry")?, }) } } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct ServerMembersWolves { pub server: GuildId, pub id_wolves: i64, pub expiry: String, pub email: String, pub discord: Option, pub minecraft: Option, } impl<'r> FromRow<'r, SqliteRow> for ServerMembersWolves { fn from_row(row: &'r SqliteRow) -> Result { let server_tmp: i64 = row.try_get("server")?; let server = GuildId::from(server_tmp as u64); let discord = match row.try_get("discord") { Ok(x) => { let tmp: i64 = x; if tmp == 0 { None } else { Some(UserId::from(tmp as u64)) } } _ => None, }; Ok(Self { server, id_wolves: row.try_get("id_wolves")?, expiry: row.try_get("expiry")?, email: row.try_get("email")?, discord, minecraft: row.try_get("minecraft")?, }) } } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Wolves { pub id_wolves: i64, pub email: String, pub discord: Option, pub minecraft: Option, } impl<'r> FromRow<'r, SqliteRow> for Wolves { fn from_row(row: &'r SqliteRow) -> Result { let discord = match row.try_get("discord") { Ok(x) => { let tmp: i64 = x; if tmp == 0 { None } else { Some(UserId::from(tmp as u64)) } } _ => None, }; Ok(Self { id_wolves: row.try_get("id_wolves")?, email: row.try_get("email")?, discord, minecraft: row.try_get("minecraft")?, }) } } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct WolvesVerify { pub email: String, pub discord: UserId, pub auth_code: String, pub date_expiry: String, } impl<'r> FromRow<'r, SqliteRow> for WolvesVerify { fn from_row(row: &'r SqliteRow) -> Result { let user_tmp: i64 = row.try_get("discord")?; let discord = UserId::from(user_tmp as u64); Ok(Self { email: row.try_get("email")?, discord, auth_code: row.try_get("auth_code")?, date_expiry: row.try_get("date_expiry")?, }) } } #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Servers { pub server: GuildId, pub wolves_api: String, pub role_past: Option, pub role_current: Option, pub member_past: i64, pub member_current: i64, } impl<'r> FromRow<'r, SqliteRow> for Servers { fn from_row(row: &'r SqliteRow) -> Result { let server_tmp: i64 = row.try_get("server")?; let server = GuildId::from(server_tmp as u64); let role_past = match row.try_get("role_past") { Ok(x) => { let tmp: i64 = x; if tmp == 0 { None } else { Some(RoleId::from(tmp as u64)) } } _ => None, }; let role_current = match row.try_get("role_current") { Ok(x) => { let tmp: i64 = x; if tmp == 0 { None } else { Some(RoleId::from(tmp as u64)) } } _ => None, }; Ok(Self { server, wolves_api: row.try_get("wolves_api")?, role_past, role_current, member_past: row.try_get("member_past")?, member_current: row.try_get("member_current")?, }) } } pub async fn db_init(config: &Config) -> Result, Error> { let database = format!("{}/{}", &config.home, &config.database); let pool = SqlitePoolOptions::new() .max_connections(5) .connect_with( SqliteConnectOptions::from_str(&format!("sqlite://{}", database))? .foreign_keys(true) .create_if_missing(true), ) .await?; sqlx::query( "CREATE TABLE IF NOT EXISTS wolves ( id_wolves integer PRIMARY KEY, email text not null, discord integer, minecraft text )", ) .execute(&pool) .await?; sqlx::query("CREATE INDEX IF NOT EXISTS index_discord ON wolves (discord)").execute(&pool).await?; sqlx::query( "CREATE TABLE IF NOT EXISTS wolves_verify ( discord integer PRIMARY KEY, email text not null, auth_code text not null, date_expiry text not null )", ) .execute(&pool) .await?; sqlx::query("CREATE INDEX IF NOT EXISTS index_date_expiry ON wolves_verify (date_expiry)") .execute(&pool) .await?; sqlx::query( "CREATE TABLE IF NOT EXISTS server_members ( server integer not null, id_wolves integer not null, expiry text not null, PRIMARY KEY(server,id_wolves), FOREIGN KEY (id_wolves) REFERENCES wolves (id_wolves) )", ) .execute(&pool) .await?; sqlx::query( "CREATE TABLE IF NOT EXISTS servers ( server integer PRIMARY KEY, wolves_api text not null, role_past integer, role_current integer, member_past integer DEFAULT 0, member_current integer DEFAULT 0 )", ) .execute(&pool) .await?; Ok(pool) } pub async fn get_server_config(db: &Pool, server: &GuildId) -> Option { sqlx::query_as::<_, Servers>( r#" SELECT * FROM servers WHERE server = ? "#, ) .bind(*server.as_u64() as i64) .fetch_one(db) .await .ok() } pub async fn get_server_member(db: &Pool, server: &GuildId, member: &guild::Member) -> Result { sqlx::query_as::<_, ServerMembersWolves>( r#" SELECT * FROM server_members JOIN wolves USING (id_wolves) WHERE server = ? AND discord = ? "#, ) .bind(*server.as_u64() as i64) .bind(*member.user.id.as_u64() as i64) .fetch_one(db) .await } pub async fn get_server_config_bulk(db: &Pool) -> Vec { sqlx::query_as::<_, Servers>( r#" SELECT * FROM servers "#, ) .fetch_all(db) .await .unwrap_or_default() } pub fn get_now_iso(short: bool) -> String { let now = Utc::now(); if short { format!("{}-{:02}-{:02}", now.year(), now.month(), now.day()) } else { now.to_rfc3339_opts(SecondsFormat::Millis, true) } } pub fn random_string(len: usize) -> String { thread_rng().sample_iter(&Alphanumeric).take(len).map(char::from).collect() } pub mod set_roles { use super::*; pub async fn update_server(ctx: &Context, server: &Servers, remove_roles: &[Option], members_changed: &Vec) { let db_lock = { let data_read = ctx.data.read().await; data_read.get::().expect("Expected Database in TypeMap.").clone() }; let db = db_lock.read().await; let Servers { server, role_past, role_current, .. } = server; let mut roles_set = [0, 0, 0]; let mut members = vec![]; for member in get_server_member_bulk(&db, server).await { if let Some(x) = member.discord { members.push(x); } } let mut members_all = members.len(); if let Ok(x) = server.members(ctx, None, None).await { for mut member in x { // members_changed acts as an override to only deal with teh users in it if !members_changed.is_empty() && !members_changed.contains(&member.user.id) { continue; } if members.contains(&member.user.id) { let mut roles = vec![]; if let Some(role) = &role_past { if !member.roles.contains(role) { roles_set[0] += 1; roles.push(role.to_owned()); } } if let Some(role) = &role_current { if !member.roles.contains(role) { roles_set[1] += 1; roles.push(role.to_owned()); } } if let Err(e) = member.add_roles(ctx, &roles).await { println!("{:?}", e); } } else { // old and never if let Some(role) = &role_past { if member.roles.contains(role) { members_all += 1; } } if let Some(role) = &role_current { if member.roles.contains(role) { roles_set[2] += 1; // if theya re not a current member and have the role then remove it if let Err(e) = member.remove_role(ctx, role).await { println!("{:?}", e); } } } } for role in remove_roles.iter().flatten() { if let Err(e) = member.remove_role(ctx, role).await { println!("{:?}", e); } } } } set_server_numbers(&db, server, members_all as i64, members.len() as i64).await; // small bit of logging to note changes over time println!("{:?} Changes: New: +{}, Current: +{}/-{}", server.as_u64(), roles_set[0], roles_set[1], roles_set[2]); } async fn get_server_member_bulk(db: &Pool, server: &GuildId) -> Vec { sqlx::query_as::<_, ServerMembersWolves>( r#" SELECT * FROM server_members JOIN wolves USING (id_wolves) WHERE ( server = ? AND discord IS NOT NULL AND expiry > ? ) "#, ) .bind(*server.as_u64() as i64) .bind(get_now_iso(true)) .fetch_all(db) .await .unwrap_or_default() } async fn set_server_numbers(db: &Pool, server: &GuildId, past: i64, current: i64) { match sqlx::query_as::<_, Wolves>( " UPDATE servers SET member_past = ?, member_current = ? WHERE server = ? ", ) .bind(past) .bind(current) .bind(*server.as_u64() as i64) .fetch_optional(db) .await { Ok(_) => {} Err(e) => { println!("Failure to insert into {}", server.as_u64()); println!("{:?}", e); } } } } pub mod get_data { use super::*; use crate::set_roles::update_server; use std::collections::BTreeMap; #[derive(Deserialize, Serialize, Debug)] struct WolvesResultUser { committee: String, wolves_id: String, first_name: String, last_name: String, contact_email: String, student_id: Option, note: Option, expiry: String, requested: String, approved: String, sitename: String, domain: String, } #[derive(Deserialize, Serialize, Debug)] struct WolvesResult { success: i8, result: Vec, } #[derive(Deserialize, Serialize, Debug)] struct WolvesResultLocal { pub id_wolves: String, pub email: String, pub expiry: String, } pub async fn get_wolves(ctx: &Context) { let db_lock = { let data_read = ctx.data.read().await; data_read.get::().expect("Expected Database in TypeMap.").clone() }; let db = db_lock.read().await; let config_lock = { let data_read = ctx.data.read().await; data_read.get::().expect("Expected Config in TypeMap.").clone() }; let config = config_lock.read().await; for server_config in get_server_config_bulk(&db).await { let Servers { server, wolves_api, .. } = &server_config; let existing_tmp = get_server_member(&db, server).await; let existing = existing_tmp.iter().map(|data| (data.id_wolves, data)).collect::>(); // list of users that need to be updated for this server let mut user_to_update = vec![]; for user in get_wolves_sub(&config, wolves_api).await { let id = user.wolves_id.parse::().unwrap_or_default(); match existing.get(&(id as i64)) { None => { // user does not exist already, add everything add_users_wolves(&db, &user).await; add_users_server_members(&db, server, &user).await; } Some(old) => { // always update wolves table, in case data has changed add_users_wolves(&db, &user).await; if old.expiry != user.expiry { add_users_server_members(&db, server, &user).await; if let Some(discord_id) = old.discord { user_to_update.push(discord_id); } } } } } if !user_to_update.is_empty() { update_server(ctx, &server_config, &[], &user_to_update).await; } } } pub async fn get_server_member(db: &Pool, server: &GuildId) -> Vec { sqlx::query_as::<_, ServerMembersWolves>( r#" SELECT * FROM server_members JOIN wolves USING (id_wolves) WHERE ( server = ? AND discord IS NOT NULL ) "#, ) .bind(*server.as_u64() as i64) .fetch_all(db) .await .unwrap_or_default() } async fn get_wolves_sub(config: &Config, wolves_api: &str) -> Vec { if config.wolves_url.is_empty() { return vec![]; } // get wolves data if let Ok(mut res) = surf::post(&config.wolves_url).header("X-AM-Identity", wolves_api).await { if let Ok(WolvesResult { success, result, }) = res.body_json().await { if success != 1 { return vec![]; } return result; } } vec![] } async fn add_users_wolves(db: &Pool, user: &WolvesResultUser) { // expiry match sqlx::query_as::<_, Wolves>( " INSERT INTO wolves (id_wolves, email) VALUES ($1, $2) ON CONFLICT(id_wolves) DO UPDATE SET email = $2 ", ) .bind(&user.wolves_id) .bind(&user.contact_email) .fetch_optional(db) .await { Ok(_) => {} Err(e) => { println!("Failure to insert into Wolves {:?}", user); println!("{:?}", e); } } } async fn add_users_server_members(db: &Pool, server: &GuildId, user: &WolvesResultUser) { match sqlx::query_as::<_, ServerMembers>( " INSERT OR REPLACE INTO server_members (server, id_wolves, expiry) VALUES (?1, ?2, ?3) ", ) .bind(*server.as_u64() as i64) .bind(&user.wolves_id) .bind(&user.expiry) .fetch_optional(db) .await { Ok(_) => {} Err(e) => { println!("Failure to insert into ServerMembers {} {:?}", server.as_u64(), user); println!("{:?}", e); } } } }