From f146743061ba8170569bf18518202df9a43c09f3 Mon Sep 17 00:00:00 2001 From: Rutger Broekhoff Date: Fri, 26 Jan 2024 12:28:39 +0100 Subject: Clean up part of the code --- server/src/main.rs | 257 +++++++++++++++++++++++++++-------------------------- 1 file changed, 129 insertions(+), 128 deletions(-) (limited to 'server/src') diff --git a/server/src/main.rs b/server/src/main.rs index db37d14..4a88dcd 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,39 +1,82 @@ -use std::collections::HashMap; -use std::collections::HashSet; -use std::process::ExitCode; -use std::sync::Arc; - -use aws_sdk_s3::error::SdkError; -use aws_sdk_s3::operation::head_object::HeadObjectOutput; -use axum::extract::rejection; -use axum::extract::FromRequest; -use axum::extract::Path; -use axum::extract::State; -use axum::http::header; -use axum::http::HeaderMap; -use axum::http::HeaderValue; -use axum::response::Response; -use axum::Json; -use axum::ServiceExt; -use base64::prelude::*; -use chrono::DateTime; -use chrono::Utc; -use common::HexByte; -use serde::de; -use serde::de::DeserializeOwned; -use serde::Deserialize; -use serde::Serialize; -use tokio::io::AsyncWriteExt; -use tower::Layer; - +use aws_sdk_s3::{error::SdkError, operation::head_object::HeadObjectOutput}; use axum::{ async_trait, - extract::{FromRequestParts, OriginalUri, Request}, - http::{request::Parts, StatusCode, Uri}, - response::IntoResponse, + extract::{rejection, FromRequest, FromRequestParts, OriginalUri, Path, Request, State}, + http::{header, request::Parts, HeaderMap, HeaderValue, StatusCode, Uri}, + response::{IntoResponse, Response}, routing::{get, post}, - Extension, Router, + Extension, Json, Router, ServiceExt, }; +use base64::prelude::*; +use chrono::{DateTime, Utc}; +use serde::{ + de::{self, DeserializeOwned}, + Deserialize, Serialize, +}; +use std::{ + collections::{HashMap, HashSet}, + process::ExitCode, + sync::Arc, +}; +use tokio::io::AsyncWriteExt; +use tower::Layer; + +#[tokio::main] +async fn main() -> ExitCode { + tracing_subscriber::fmt::init(); + + let conf = match Config::load() { + Ok(conf) => conf, + Err(e) => { + println!("Error: {e}"); + return ExitCode::from(2); + } + }; + + let dl_limiter = DownloadLimiter::new(conf.download_limit).await; + let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter)); + + let resetter_dl_limiter = dl_limiter.clone(); + tokio::spawn(async move { + loop { + println!("Resetting download counter in one hour"); + tokio::time::sleep(std::time::Duration::from_secs(3600)).await; + println!("Resetting download counter"); + resetter_dl_limiter.lock().await.reset().await; + } + }); + + let shared_state = Arc::new(AppState { + s3_client: conf.s3_client, + s3_bucket: conf.s3_bucket, + authz_conf: conf.authz_conf, + base_url: conf.base_url, + dl_limiter, + }); + let app = Router::new() + .route("/batch", post(batch)) + .route("/:oid0/:oid1/:oid", get(obj_download)) + .with_state(shared_state); + + let middleware = axum::middleware::map_request(rewrite_url); + let app_with_middleware = middleware.layer(app); + + let listener = match tokio::net::TcpListener::bind(conf.listen_addr).await { + Ok(listener) => listener, + Err(e) => { + println!("Failed to listen: {e}"); + return ExitCode::FAILURE; + } + }; + + match axum::serve(listener, app_with_middleware.into_make_service()).await { + Ok(_) => ExitCode::SUCCESS, + Err(e) => { + println!("Error serving: {e}"); + ExitCode::FAILURE + } + } +} #[derive(Clone)] struct RepositoryName(String); @@ -165,99 +208,57 @@ fn get_s3_client(env: &Env) -> Result { Ok(aws_sdk_s3::Client::new(&config)) } -#[tokio::main] -async fn main() -> ExitCode { - tracing_subscriber::fmt::init(); - - let env = match Env::load() { - Ok(env) => env, - Err(e) => { - println!("Failed to load configuration: {e}"); - return ExitCode::from(2); - } - }; - - let s3_client = match get_s3_client(&env) { - Ok(s3_client) => s3_client, - Err(e) => { - println!("Failed to create S3 client: {e}"); - return ExitCode::FAILURE; - } - }; - let key = match common::load_key(&env.key_path) { - Ok(key) => key, - Err(e) => { - println!("Failed to load Gitolfs3 key: {e}"); - return ExitCode::FAILURE; - } - }; - - let trusted_forwarded_hosts: HashSet = env - .trusted_forwarded_hosts - .split(',') - .map(|s| s.to_owned()) - .filter(|s| !s.is_empty()) - .collect(); - let base_url = env.base_url.trim_end_matches('/').to_string(); - - let Ok(download_limit): Result = env.download_limit.parse() else { - println!("Configured GITOLFS3_DOWNLOAD_LIMIT should be a 64-bit unsigned integer"); - return ExitCode::from(2); - }; - let dl_limiter = DownloadLimiter::new(download_limit).await; - let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter)); - - let resetter_dl_limiter = dl_limiter.clone(); - tokio::spawn(async move { - loop { - println!("Resetting download counter in one hour"); - tokio::time::sleep(std::time::Duration::from_secs(3600)).await; - println!("Resetting download counter"); - resetter_dl_limiter.lock().await.reset().await; - } - }); +struct Config { + listen_addr: (String, u16), + base_url: String, + authz_conf: AuthorizationConfig, + s3_client: aws_sdk_s3::Client, + s3_bucket: String, + download_limit: u64, +} - let authz_conf = AuthorizationConfig { - key, - trusted_forwarded_hosts, - }; +impl Config { + fn load() -> Result { + let env = match Env::load() { + Ok(env) => env, + Err(e) => return Err(format!("failed to load configuration: {e}")), + }; - let shared_state = Arc::new(AppState { - s3_client, - s3_bucket: env.s3_bucket, - authz_conf, - base_url, - dl_limiter, - }); - let app = Router::new() - .route("/batch", post(batch)) - .route("/:oid0/:oid1/:oid", get(obj_download)) - .with_state(shared_state); + let s3_client = match get_s3_client(&env) { + Ok(s3_client) => s3_client, + Err(e) => return Err(format!("failed to create S3 client: {e}")), + }; + let key = match common::load_key(&env.key_path) { + Ok(key) => key, + Err(e) => return Err(format!("failed to load Gitolfs3 key: {e}")), + }; - let middleware = axum::middleware::map_request(rewrite_url); - let app_with_middleware = middleware.layer(app); + let trusted_forwarded_hosts: HashSet = env + .trusted_forwarded_hosts + .split(',') + .map(|s| s.to_owned()) + .filter(|s| !s.is_empty()) + .collect(); + let base_url = env.base_url.trim_end_matches('/').to_string(); - let Ok(listen_port): Result = env.listen_port.parse() else { - println!( - "Configured GITOLFS3_LISTEN_PORT should be an unsigned integer no higher than 65535" - ); - return ExitCode::from(2); - }; - let addr: (String, u16) = (env.listen_host, listen_port); - let listener = match tokio::net::TcpListener::bind(addr).await { - Ok(listener) => listener, - Err(e) => { - println!("Failed to listen: {e}"); - return ExitCode::FAILURE; - } - }; + let Ok(listen_port): Result = env.listen_port.parse() else { + return Err("configured GITOLFS3_LISTEN_PORT is invalid".to_string()); + }; + let Ok(download_limit): Result = env.download_limit.parse() else { + return Err("configured GITOLFS3_DOWNLOAD_LIMIT is invalid".to_string()); + }; - match axum::serve(listener, app_with_middleware.into_make_service()).await { - Ok(_) => ExitCode::SUCCESS, - Err(e) => { - println!("Error serving: {e}"); - ExitCode::FAILURE - } + Ok(Self { + listen_addr: (env.listen_host, listen_port), + base_url, + authz_conf: AuthorizationConfig { + key, + trusted_forwarded_hosts, + }, + s3_client, + s3_bucket: env.s3_bucket, + download_limit, + }) } } @@ -479,7 +480,7 @@ async fn handle_upload_object( repo: &str, obj: &BatchRequestObject, ) -> Option { - let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); + let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1])); let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); match state @@ -558,7 +559,7 @@ async fn handle_download_object( obj: &BatchRequestObject, trusted: bool, ) -> BatchResponseObject { - let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); + let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1])); let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); let result = match state @@ -687,8 +688,8 @@ async fn handle_download_object( let upload_path = format!( "{repo}/info/lfs/objects/{}/{}/{}", - HexByte(obj.oid[0]), - HexByte(obj.oid[1]), + common::HexByte(obj.oid[0]), + common::HexByte(obj.oid[1]), obj.oid, ); @@ -866,8 +867,8 @@ async fn batch( #[derive(Deserialize, Copy, Clone)] #[serde(remote = "Self")] struct FileParams { - oid0: HexByte, - oid1: HexByte, + oid0: common::HexByte, + oid1: common::HexByte, oid: common::Oid, } @@ -877,8 +878,8 @@ impl<'de> Deserialize<'de> for FileParams { D: serde::Deserializer<'de>, { let unchecked @ FileParams { - oid0: HexByte(oid0), - oid1: HexByte(oid1), + oid0: common::HexByte(oid0), + oid1: common::HexByte(oid1), oid, } = FileParams::deserialize(deserializer)?; if oid0 != oid.as_bytes()[0] { -- cgit v1.2.3