commit 038020e662c757b9ab7eb8b8a0eb0c2225430fc7 Author: ashelyn vi Date: Fri Mar 15 18:17:05 2024 -0600 not really tested: split proc macro and structs diff --git a/joinrs/Cargo.toml b/joinrs/Cargo.toml new file mode 100644 index 0000000..54666c9 --- /dev/null +++ b/joinrs/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "joinrs" +version = "0.0.1" +edition = "2021" + +[dependencies] +proc-macro2 = "1.0.78" +quote = "1.0.35" +syn = "2.0.51" +anyhow = "1.0.81" +thiserror = "1.0.58" +joinrs_proc = { path = "../joinrs_proc" } diff --git a/joinrs/src/lib.rs b/joinrs/src/lib.rs new file mode 100644 index 0000000..0bc7021 --- /dev/null +++ b/joinrs/src/lib.rs @@ -0,0 +1,16 @@ +pub use joinrs_proc::query_parsed; +pub use anyhow; + +mod err { + use thiserror::Error; + + #[derive(Error, Debug)] + pub enum QueryError { + #[error("Query returned no rows")] + RowNotFound, + #[error("Expected column {0} to have a value, but it was null")] + NullColumn(String), + #[error("Query returned too many rows, expected 1 but got {0}")] + TooManyRows(u32) + } +} diff --git a/joinrs_proc/Cargo.toml b/joinrs_proc/Cargo.toml new file mode 100644 index 0000000..df010e8 --- /dev/null +++ b/joinrs_proc/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "joinrs_proc" +version = "0.0.1" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1.0.78" +quote = "1.0.35" +syn = "2.0.51" +anyhow = "1.0.81" +thiserror = "1.0.58" diff --git a/joinrs_proc/src/lib.rs b/joinrs_proc/src/lib.rs new file mode 100644 index 0000000..1e9006a --- /dev/null +++ b/joinrs_proc/src/lib.rs @@ -0,0 +1,214 @@ +use quote::quote; +use syn::parse::{Parse, ParseStream}; +use syn::punctuated::Punctuated; +use syn::{bracketed, parse_macro_input, Expr, Ident, LitStr, Token}; + +extern crate proc_macro; + +extern crate self as joinrs_proc; + +#[proc_macro] +pub fn query_parsed(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let macro_params = parse_macro_input!(input as QueryChecked); + let connection = macro_params.connection; + let query = macro_params.query; + let params = macro_params.params; + let column_mapping = macro_params.column_mapping; + let return_type = get_return_type(&column_mapping); + + let query = quote! { + sqlx::query!(#query, #params) + .fetch_all(#connection) + .await.expect("Query failed").into_iter().peekable() + }; + + let parse = generate_parse_expression(&column_mapping); + + let result = quote! { + { + let mut rows = #query; + + let mut parse_closure = || -> joinrs::anyhow::Result<#return_type> { + + if rows.len() < 1 { + // anyhow::bail!(sqly::err::QueryError::RowNotFound) + panic!("Not enough rows"); + } + + let result = #parse; + Ok(result) + }; + + parse_closure() + } + }; + + proc_macro::TokenStream::from(result) +} + +fn generate_parse_expression(value: &TypeMapping) -> proc_macro2::TokenStream { + match value.clone() { + TypeMapping::Column(name) => generate_parse_column(name), + TypeMapping::Vec(mapping) => generate_parse_vec(mapping), + TypeMapping::Object { + struct_name, + properties, + } => generate_parse_object(struct_name, properties), + } +} + +fn generate_parse_column(column: Ident) -> proc_macro2::TokenStream { + quote! { + { + let row = rows.peek().expect("Too few rows"); + row.#column.clone() + } + } +} + +fn generate_parse_object( + struct_name: Ident, + properties: Vec, +) -> proc_macro2::TokenStream { + let property_lines = + properties + .iter() + .map(|property: &TypeProperty| -> proc_macro2::TokenStream { + let name = &property.name; + let value = generate_parse_expression(&property.mapping); + quote! { + #name: #value + } + }); + + quote! { + #struct_name { + #( #property_lines ),* + } + } +} + +fn generate_parse_vec(inner_type: Box) -> proc_macro2::TokenStream { + let value_expression = generate_parse_expression(&*inner_type); + + quote! { + { + let mut vec = Vec::new(); + vec.push(#value_expression); + vec + } + } +} + +fn get_return_type(mapping: &TypeMapping) -> Ident { + match mapping { + TypeMapping::Column(_) => panic!("Cannot generate type mapping for a single column: Unknown type"), + TypeMapping::Vec(mapping) => get_return_type(mapping), + TypeMapping::Object { struct_name, properties: _ } => struct_name.clone(), + } +} + +#[derive(Clone, Debug)] +struct QueryChecked { + connection: Expr, + query: LitStr, + params: Punctuated, + column_mapping: TypeMapping, +} + +#[derive(Clone, Debug)] +enum TypeMapping { + Column(Ident), + Vec(Box), + Object { + struct_name: Ident, + properties: Vec, + }, +} + +#[derive(Clone, Debug)] +struct TypeProperty { + name: Ident, + mapping: TypeMapping, +} + +impl Parse for QueryChecked { + fn parse(input: ParseStream) -> syn::Result { + let mut connection: Option = None; + let mut query: Option = None; + let mut params: Option> = None; + let mut return_type: Option = None; + + while !input.is_empty() { + let identifier = input.parse::()?; + input.parse::()?; + + match identifier.to_string().as_str() { + "connection" => connection = Some(input.parse()?), + "query" => query = Some(input.parse()?), + "params" => { + let content; + bracketed!(content in input); + + params = Some(Punctuated::parse_terminated(&content)?) + } + "return_type" => return_type = Some(input.parse()?), + unknown => { + return Err(input.error(format!("Unknown property {}", unknown))); + } + } + + if !input.is_empty() { + input.parse::()?; + } + } + + Ok(QueryChecked { + connection: connection.ok_or(input.error("Expected connection property"))?, + query: query.ok_or(input.error("Expected query property"))?, + params: params.unwrap_or(Punctuated::new()), + column_mapping: return_type.ok_or(input.error("Expected return type property"))?, + }) + } +} + +impl Parse for TypeMapping { + fn parse(input: ParseStream) -> syn::Result { + let identifier = input.parse::()?; + + if !input.is_empty() && input.peek(Token![<]) && identifier.to_string() == "Vec" { + input.parse::()?; + let internal: TypeMapping = input.parse()?; + input.parse::]>()?; + + Ok(TypeMapping::Vec(Box::new(internal))) + } else if !input.is_empty() && input.peek(syn::token::Brace) { + let internal; + syn::braced!(internal in input); + + let properties: Punctuated = + Punctuated::parse_terminated(&internal)?; + + let properties = properties.iter().map(|p| p.clone()).collect(); + + Ok(TypeMapping::Object { + struct_name: identifier, + properties, + }) + } else { + Ok(TypeMapping::Column(identifier)) + } + } +} +impl Parse for TypeProperty { + fn parse(input: ParseStream) -> syn::Result { + let identifier = input.parse::()?; + input.parse::()?; + let property_type: TypeMapping = input.parse()?; + + Ok(TypeProperty { + name: identifier, + mapping: property_type, + }) + } +}