diff --git a/Cargo.lock b/Cargo.lock index 2ddc6b2..d6fe171 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -64,7 +64,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.51", ] [[package]] @@ -75,7 +75,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.51", ] [[package]] @@ -321,7 +321,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.50", + "syn 2.0.51", ] [[package]] @@ -1091,7 +1091,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.50", + "syn 2.0.51", ] [[package]] @@ -1177,7 +1177,7 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.51", "version_check", "yansi", ] @@ -1247,7 +1247,7 @@ checksum = "5fddb4f8d99b0a2ebafc65a87a69a7b9875e4b1ae1f00db265d300ef7f28bccc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.51", ] [[package]] @@ -1344,7 +1344,7 @@ dependencies = [ "proc-macro2", "quote", "rocket_http", - "syn 2.0.50", + "syn 2.0.51", "unicode-xid", "version_check", ] @@ -1405,6 +1405,7 @@ dependencies = [ "serde", "serde_json", "sqlx", + "sqly", "time", "uuid", ] @@ -1469,7 +1470,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.51", ] [[package]] @@ -1808,6 +1809,15 @@ dependencies = [ "uuid", ] +[[package]] +name = "sqly" +version = "0.0.1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.51", +] + [[package]] name = "stable-pattern" version = "0.1.0" @@ -1856,9 +1866,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.50" +version = "2.0.51" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74f1bdc9872430ce9b75da68329d1c1746faf50ffac5f19e02b71e37ff881ffb" +checksum = "6ab617d94515e94ae53b8406c628598680aa0c9587474ecbe58188f7b345d66c" dependencies = [ "proc-macro2", "quote", @@ -1894,7 +1904,7 @@ checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.51", ] [[package]] @@ -1979,7 +1989,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.51", ] [[package]] @@ -2067,7 +2077,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.51", ] [[package]] @@ -2443,7 +2453,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.50", + "syn 2.0.51", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 9e1f234..e97c0e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,3 +13,9 @@ serde = { version = "1", features = ["derive"] } time = { version = "0.3.34" } uuid = { version = "1.7.0", features = ["v4", "serde"] } serde_json = "1.0.114" +sqly = { path = "sqly" } + +[workspace] +members = [ + "sqly", +] diff --git a/sqly/Cargo.toml b/sqly/Cargo.toml new file mode 100644 index 0000000..48a80e5 --- /dev/null +++ b/sqly/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "sqly" +version = "0.0.1" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0.78" +quote = "1.0.35" +syn = "2.0.51" + diff --git a/sqly/rust-toolchain.toml b/sqly/rust-toolchain.toml new file mode 100644 index 0000000..5d56faf --- /dev/null +++ b/sqly/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly" diff --git a/sqly/src/lib.rs b/sqly/src/lib.rs new file mode 100644 index 0000000..10519f8 --- /dev/null +++ b/sqly/src/lib.rs @@ -0,0 +1,192 @@ +use quote::quote; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::{bracketed, parse_macro_input, Expr, Ident, LitStr, Result, Token}; + +extern crate proc_macro; + +#[proc_macro] +pub fn query_parsed(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let macro_params = parse_macro_input!(input as QueryChecked); + let connection = macro_params.connection; + let query = macro_params.query; + let params = macro_params.params; + let return_type = macro_params.return_type; + + let query = quote! { + let rows = sqlx::query!(#query, #params) + .fetch_all(#connection) + .await?.iter() + }; + + let parse = generate_parse_value(&return_type); + + let result = quote! { + { + #query; + #parse + } + }; + + proc_macro::TokenStream::from(result) +} + +fn generate_parse_value(value: &TypeMapping) -> proc_macro2::TokenStream { + match value.clone() { + TypeMapping::Column(name) => generate_parse_column(name), + TypeMapping::Vec(mapping) => generate_parse_vec(mapping), + TypeMapping::Object { + struct_name, + properties, + } => generate_parse_object(struct_name, properties), + } +} + +fn generate_parse_column(column: Ident) -> proc_macro2::TokenStream { + quote! { + rows[0].#column + } +} + +fn generate_parse_object( + struct_name: Ident, + properties: Vec, +) -> proc_macro2::TokenStream { + let property_lines = + properties + .iter() + .map(|property: &TypeProperty| -> proc_macro2::TokenStream { + let name = &property.name; + let value = generate_parse_value(&property.mapping); + quote! { + #name: #value + } + }); + + quote! { + #struct_name { + #( #property_lines ),* + } + } +} + +fn generate_parse_vec(inner_type: Box) -> proc_macro2::TokenStream { + let value_expression = generate_parse_value(&*inner_type); + + quote! { + { + let vec = Vec::new(); + let optional_value = #value_expression; + + if optional_value.is_some() { + vec.push(#value_expression); + } + vec + } + } +} + +#[derive(Clone, Debug)] +struct QueryChecked { + connection: Expr, + query: LitStr, + params: Punctuated, + return_type: TypeMapping, +} + +#[derive(Clone, Debug)] +enum TypeMapping { + Column(Ident), + Vec(Box), + Object { + struct_name: Ident, + properties: Vec, + }, +} + +#[derive(Clone, Debug)] +struct TypeProperty { + name: Ident, + mapping: TypeMapping, +} + +impl Parse for QueryChecked { + fn parse(input: ParseStream) -> Result { + let mut connection: Option = None; + let mut query: Option = None; + let mut params: Option> = None; + let mut return_type: Option = None; + + while !input.is_empty() { + let identifier = input.parse::()?; + input.parse::()?; + + match identifier.to_string().as_str() { + "connection" => connection = Some(input.parse()?), + "query" => query = Some(input.parse()?), + "params" => { + let content; + bracketed!(content in input); + + params = Some(Punctuated::parse_terminated(&content)?) + } + "return_type" => return_type = Some(input.parse()?), + unknown => { + return Err(input.error(format!("Unknown property {}", unknown))); + } + } + + if !input.is_empty() { + input.parse::()?; + } + } + + Ok(QueryChecked { + connection: connection.ok_or(input.error("Expected connection property"))?, + query: query.ok_or(input.error("Expected query property"))?, + params: params.unwrap_or(Punctuated::new()), + return_type: return_type.ok_or(input.error("Expected return type property"))?, + }) + } +} + +impl Parse for TypeMapping { + fn parse(input: ParseStream) -> Result { + let identifier = input.parse::()?; + + if !input.is_empty() && input.peek(Token![<]) && identifier.to_string() == "Vec" { + input.parse::()?; + let internal: TypeMapping = input.parse()?; + input.parse::]>()?; + + Ok(TypeMapping::Vec(Box::new(internal))) + } else if !input.is_empty() && input.peek(syn::token::Brace) { + let internal; + syn::braced!(internal in input); + + let properties: Punctuated = + Punctuated::parse_terminated(&internal)?; + + let properties = properties.iter().map(|p| p.clone()).collect(); + + Ok(TypeMapping::Object { + struct_name: identifier, + properties, + }) + } else { + Ok(TypeMapping::Column(identifier)) + } + } +} +impl Parse for TypeProperty { + fn parse(input: ParseStream) -> Result { + let identifier = input.parse::()?; + input.parse::()?; + let property_type: TypeMapping = input.parse()?; + + Ok(TypeProperty { + name: identifier, + mapping: property_type, + }) + } +} diff --git a/src/db.rs b/src/db.rs deleted file mode 100644 index 51dfbdf..0000000 --- a/src/db.rs +++ /dev/null @@ -1,180 +0,0 @@ -use serde::Serialize; -use sqlx::{Pool, Postgres}; -use uuid::Uuid; - -#[derive(Clone)] -pub struct DB { - connection_pool: Pool:: -} - -#[derive(Serialize)] -pub struct Site { - uuid: Uuid, - title: String, - base_url: String, - theme: String, - boards: Vec -} - -#[derive(Serialize, Clone)] -pub struct Board { - uuid: Uuid, - title: String, - description: String, - threads: Vec -} -impl Board { - fn clone(&self) -> Board { - todo!() - } -} - -#[derive(Serialize, Clone)] -pub struct Thread { - uuid: Uuid, - title: String, - posts: Vec -} - -#[derive(Serialize, Clone)] -pub struct Post { - uuid: Uuid, - contents: String, - author: User -} - -#[derive(Serialize, Clone)] -pub struct User { - uuid: Uuid, - email: String, - username: String, - password_hash: String, - is_admin: bool -} - -impl DB { - pub async fn init() -> Self { - let db_url = std::env::var("DATABASE_URL").expect("Please provide DATABASE_URL in environment"); - let pool = Pool::::connect(db_url.as_str()).await.expect("Could not connect to database"); - - sqlx::migrate!() - .run(&pool) - .await - .expect("Could not run database migrations"); - - DB { - connection_pool: pool - } - } - - pub async fn get_site_data(&self) -> Result { - let rows = sqlx::query!(r#" - select - site_uuid, - site_title, - site_base_url, - site_theme, - board_uuid, - board_title, - board_description, - thread_uuid, - thread_title, - post_uuid as "post_uuid?", - post_contents as "post_contents?", - user_uuid, - user_email, - user_username, - user_password_hash, - user_is_admin - from forum.site - left join forum.board on board_site = site_uuid - left join forum.thread on thread_board = board_uuid - left join forum.post on post_thread = thread_uuid - left join forum.user on post_author = user_uuid - "#).fetch_all(&self.connection_pool).await.expect("Could not connect to database"); - - let mut site : Option = None; - let mut last_board : Option = None; - let mut last_thread : Option = None; - let mut last_post : Option = None; - - for row in rows { - if site.is_none() || row.site_uuid.is_some() && row.site_uuid.unwrap() != site.as_ref().unwrap().uuid { - site = Some(Site { - uuid: row.site_uuid.unwrap(), - title: row.site_title.unwrap_or(String::new()), - base_url: row.site_base_url.unwrap_or(String::new()), - theme: row.site_theme.unwrap_or(String::new()), - boards: Vec::new() - }) - } - - if last_board.is_none() || row.board_uuid.is_some() && row.board_uuid.unwrap() != last_board.as_ref().unwrap().uuid { - if let Some(ref board) = last_board { - site.as_mut().unwrap().boards.push(board.clone()) - } - - if row.board_uuid.is_some() { - last_board = Some(Board { - uuid: row.board_uuid.unwrap(), - title: row.board_title.unwrap_or(String::new()), - description: row.board_description.unwrap_or(String::new()), - threads: Vec::new() - }) - } - } - - if last_thread.is_none() || row.thread_uuid.is_some() && row.thread_uuid.unwrap() != last_thread.as_mut().unwrap().uuid { - if let Some(ref thread) = last_thread { - last_board.as_mut().unwrap().threads.push(thread.clone()) - } - - if row.thread_uuid.is_some() { - last_thread = Some(Thread { - uuid: row.thread_uuid.unwrap(), - title: row.thread_title.unwrap_or(String::new()), - posts: Vec::new() - }) - } - } - - if last_post.is_none() || row.post_uuid.is_some() && row.post_uuid.unwrap() != last_post.as_ref().unwrap().uuid { - if let Some(ref post) = last_post { - last_thread.as_mut().unwrap().posts.push(post.clone()) - } - - if row.post_uuid.is_some() { - last_post = Some(Post { - uuid: row.post_uuid.unwrap(), - contents: row.post_contents.unwrap_or(String::new()), - author: User { - uuid: row.user_uuid.unwrap(), - email: row.user_email.unwrap_or(String::new()), - username: row.user_username.unwrap_or(String::new()), - password_hash: row.user_password_hash.unwrap_or(String::new()), - is_admin: row.user_is_admin.unwrap_or(false) - } - }) - } - } - } - - if last_post.is_some() { - last_thread.as_mut().unwrap().posts.push(last_post.unwrap()); - } - - if last_thread.is_some() { - last_board.as_mut().unwrap().threads.push(last_thread.unwrap()); - } - - if last_board.is_some() { - site.as_mut().unwrap().boards.push(last_board.unwrap()); - } - - if let Some(site) = site { - Ok(site) - } else { - Err("Could not find site in DB".to_string()) - } - } -} diff --git a/src/db/mod.rs b/src/db/mod.rs new file mode 100644 index 0000000..fefb950 --- /dev/null +++ b/src/db/mod.rs @@ -0,0 +1,74 @@ +use sqlx::{Pool, Postgres}; + +pub mod objects; +use objects::{Board, Post, Site, Thread, User}; + +#[derive(Clone)] +pub struct DB { + connection_pool: Pool, +} + +impl DB { + pub async fn init() -> Self { + let db_url = + std::env::var("DATABASE_URL").expect("Please provide DATABASE_URL in environment"); + let pool = Pool::::connect(db_url.as_str()) + .await + .expect("Could not connect to database"); + + sqlx::migrate!() + .run(&pool) + .await + .expect("Could not run database migrations"); + + DB { + connection_pool: pool, + } + } + + pub async fn get_site_data(&self) -> Result { + let site = sqly::query_parsed!( + connection = &self.connection_pool, + query = r#" + select * from forum.site + left join forum.board on board_site = site_uuid + left join forum.thread on thread_board = board_uuid + left join forum.post on post_thread = thread_uuid + left join forum.user on post_author = user_uuid + "#, + // return_type = board_uuid + return_type = Site { + uuid: site_uuid, + title: site_title, + base_url: site_base_url, + theme: site_theme, + boards: Vec + }> + }> + } + ); + + if let Some(site) = site { + Ok(site) + } else { + Err("Could not find site in DB".to_string()) + } + } +} diff --git a/src/db/objects.rs b/src/db/objects.rs new file mode 100644 index 0000000..8993972 --- /dev/null +++ b/src/db/objects.rs @@ -0,0 +1,42 @@ +use serde::Serialize; +use uuid::Uuid; + +#[derive(Serialize, Clone)] +pub struct Site { + pub uuid: Uuid, + pub title: String, + pub base_url: String, + pub theme: String, + pub boards: Vec, +} + +#[derive(Serialize, Clone)] +pub struct Board { + pub uuid: Uuid, + pub title: String, + pub description: String, + pub threads: Vec, +} + +#[derive(Serialize, Clone)] +pub struct Thread { + pub uuid: Uuid, + pub title: String, + pub posts: Vec, +} + +#[derive(Serialize, Clone)] +pub struct Post { + pub uuid: Uuid, + pub contents: String, + pub author: User, +} + +#[derive(Serialize, Clone)] +pub struct User { + pub uuid: Uuid, + pub email: String, + pub username: String, + pub password_hash: String, + pub is_admin: bool, +}