handle cors

main
Ashelyn Dawn 1 month ago
parent f65905869e
commit 562c40ecc9
No known key found for this signature in database
GPG Key ID: D1980B8C6F349BC1

@ -14,6 +14,7 @@ pub struct Config {
#[derive(Deserialize)]
pub struct Form {
pub slug: String,
pub cors_domain: Option<String>,
pub name: String,
pub recipient_email: String,
pub subject: String,

@ -3,6 +3,7 @@ use std::net::IpAddr;
use std::collections::HashMap;
use rocket::request::{self, FromRequest};
use rocket::response::{Responder, Response as RocketResponse};
use rocket::serde::json::Json;
use rocket::http::Status;
use rocket::Request;
@ -10,6 +11,25 @@ use mail_builder::MessageBuilder;
use serde::Serialize;
pub struct ClientIp(Option<IpAddr>);
pub struct Origin<'r>(Option<&'r str>);
pub enum CorsResponse {
NoCors,
AllowDomain(String)
}
impl<'r> Responder<'r,'r> for CorsResponse {
fn respond_to(self, _request: &'r Request<'_>) -> rocket::response::Result<'r> {
match self {
CorsResponse::NoCors => RocketResponse::build()
.status(Status::NoContent).ok(),
CorsResponse::AllowDomain(origin) => RocketResponse::build()
.raw_header("Access-Control-Allow-Origin", origin)
.raw_header("Access-Control-Allow-Methods", "POST")
.raw_header("Access-Control-Allow-Headers", "content-type")
.status(Status::NoContent).ok(),
}
}
}
#[derive(Serialize)]
pub struct Response {
@ -22,12 +42,39 @@ impl<'r> FromRequest<'r> for ClientIp {
type Error = Infallible;
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
request::Outcome::Success(ClientIp(request.client_ip()))
request::Outcome::Success(ClientIp(request.client_ip()))
}
}
#[rocket::async_trait]
impl<'r> FromRequest<'r> for Origin<'r> {
type Error = Infallible;
async fn from_request(request: &'r Request<'_>) -> request::Outcome<Self, Self::Error> {
request::Outcome::Success(Origin(request.headers().get_one("origin")))
}
}
#[options("/api/contact/<slug>")]
pub async fn options_handler<'r>(state: &rocket::State<crate::State>, origin: Origin<'r>, slug: &str) -> Result<CorsResponse, Status> {
let origin = origin.0.ok_or(Status::BadRequest)?;
let form_conf = &state.config.forms.iter().find(|form| {
form.slug == slug
});
let form_conf = form_conf.ok_or(Status::NotFound)?;
if let Some(allowed_domain) = &form_conf.cors_domain {
if allowed_domain.as_str() == origin {
return Ok(CorsResponse::AllowDomain(allowed_domain.clone()));
}
}
return Ok(CorsResponse::NoCors)
}
#[post("/api/contact/<slug>", data = "<fields>")]
pub async fn handler(state: &rocket::State<crate::State>, client_ip: ClientIp, slug: &str, fields: Json<HashMap<&str, &str>>) -> Result<Json<Response>, (Status, Json<Response>)> {
pub async fn post_handler(state: &rocket::State<crate::State>, client_ip: ClientIp, slug: &str, fields: Json<HashMap<&str, &str>>) -> Result<Json<Response>, (Status, Json<Response>)> {
{
let rate_limit = &state.rate_limit;

@ -56,6 +56,7 @@ fn rocket() -> _ {
rocket::custom(rocket_conf)
.manage(state)
.mount("/", routes![
endpoints::handler
endpoints::post_handler,
endpoints::options_handler,
])
}

Loading…
Cancel
Save