diff options
Diffstat (limited to 'server/src')
| -rw-r--r-- | server/src/main.rs | 257 |
1 files changed, 129 insertions, 128 deletions
diff --git a/server/src/main.rs b/server/src/main.rs index db37d14..4a88dcd 100644 --- a/server/src/main.rs +++ b/server/src/main.rs | |||
| @@ -1,39 +1,82 @@ | |||
| 1 | use std::collections::HashMap; | 1 | use aws_sdk_s3::{error::SdkError, operation::head_object::HeadObjectOutput}; |
| 2 | use std::collections::HashSet; | ||
| 3 | use std::process::ExitCode; | ||
| 4 | use std::sync::Arc; | ||
| 5 | |||
| 6 | use aws_sdk_s3::error::SdkError; | ||
| 7 | use aws_sdk_s3::operation::head_object::HeadObjectOutput; | ||
| 8 | use axum::extract::rejection; | ||
| 9 | use axum::extract::FromRequest; | ||
| 10 | use axum::extract::Path; | ||
| 11 | use axum::extract::State; | ||
| 12 | use axum::http::header; | ||
| 13 | use axum::http::HeaderMap; | ||
| 14 | use axum::http::HeaderValue; | ||
| 15 | use axum::response::Response; | ||
| 16 | use axum::Json; | ||
| 17 | use axum::ServiceExt; | ||
| 18 | use base64::prelude::*; | ||
| 19 | use chrono::DateTime; | ||
| 20 | use chrono::Utc; | ||
| 21 | use common::HexByte; | ||
| 22 | use serde::de; | ||
| 23 | use serde::de::DeserializeOwned; | ||
| 24 | use serde::Deserialize; | ||
| 25 | use serde::Serialize; | ||
| 26 | use tokio::io::AsyncWriteExt; | ||
| 27 | use tower::Layer; | ||
| 28 | |||
| 29 | use axum::{ | 2 | use axum::{ |
| 30 | async_trait, | 3 | async_trait, |
| 31 | extract::{FromRequestParts, OriginalUri, Request}, | 4 | extract::{rejection, FromRequest, FromRequestParts, OriginalUri, Path, Request, State}, |
| 32 | http::{request::Parts, StatusCode, Uri}, | 5 | http::{header, request::Parts, HeaderMap, HeaderValue, StatusCode, Uri}, |
| 33 | response::IntoResponse, | 6 | response::{IntoResponse, Response}, |
| 34 | routing::{get, post}, | 7 | routing::{get, post}, |
| 35 | Extension, Router, | 8 | Extension, Json, Router, ServiceExt, |
| 36 | }; | 9 | }; |
| 10 | use base64::prelude::*; | ||
| 11 | use chrono::{DateTime, Utc}; | ||
| 12 | use serde::{ | ||
| 13 | de::{self, DeserializeOwned}, | ||
| 14 | Deserialize, Serialize, | ||
| 15 | }; | ||
| 16 | use std::{ | ||
| 17 | collections::{HashMap, HashSet}, | ||
| 18 | process::ExitCode, | ||
| 19 | sync::Arc, | ||
| 20 | }; | ||
| 21 | use tokio::io::AsyncWriteExt; | ||
| 22 | use tower::Layer; | ||
| 23 | |||
| 24 | #[tokio::main] | ||
| 25 | async fn main() -> ExitCode { | ||
| 26 | tracing_subscriber::fmt::init(); | ||
| 27 | |||
| 28 | let conf = match Config::load() { | ||
| 29 | Ok(conf) => conf, | ||
| 30 | Err(e) => { | ||
| 31 | println!("Error: {e}"); | ||
| 32 | return ExitCode::from(2); | ||
| 33 | } | ||
| 34 | }; | ||
| 35 | |||
| 36 | let dl_limiter = DownloadLimiter::new(conf.download_limit).await; | ||
| 37 | let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter)); | ||
| 38 | |||
| 39 | let resetter_dl_limiter = dl_limiter.clone(); | ||
| 40 | tokio::spawn(async move { | ||
| 41 | loop { | ||
| 42 | println!("Resetting download counter in one hour"); | ||
| 43 | tokio::time::sleep(std::time::Duration::from_secs(3600)).await; | ||
| 44 | println!("Resetting download counter"); | ||
| 45 | resetter_dl_limiter.lock().await.reset().await; | ||
| 46 | } | ||
| 47 | }); | ||
| 48 | |||
| 49 | let shared_state = Arc::new(AppState { | ||
| 50 | s3_client: conf.s3_client, | ||
| 51 | s3_bucket: conf.s3_bucket, | ||
| 52 | authz_conf: conf.authz_conf, | ||
| 53 | base_url: conf.base_url, | ||
| 54 | dl_limiter, | ||
| 55 | }); | ||
| 56 | let app = Router::new() | ||
| 57 | .route("/batch", post(batch)) | ||
| 58 | .route("/:oid0/:oid1/:oid", get(obj_download)) | ||
| 59 | .with_state(shared_state); | ||
| 60 | |||
| 61 | let middleware = axum::middleware::map_request(rewrite_url); | ||
| 62 | let app_with_middleware = middleware.layer(app); | ||
| 63 | |||
| 64 | let listener = match tokio::net::TcpListener::bind(conf.listen_addr).await { | ||
| 65 | Ok(listener) => listener, | ||
| 66 | Err(e) => { | ||
| 67 | println!("Failed to listen: {e}"); | ||
| 68 | return ExitCode::FAILURE; | ||
| 69 | } | ||
| 70 | }; | ||
| 71 | |||
| 72 | match axum::serve(listener, app_with_middleware.into_make_service()).await { | ||
| 73 | Ok(_) => ExitCode::SUCCESS, | ||
| 74 | Err(e) => { | ||
| 75 | println!("Error serving: {e}"); | ||
| 76 | ExitCode::FAILURE | ||
| 77 | } | ||
| 78 | } | ||
| 79 | } | ||
| 37 | 80 | ||
| 38 | #[derive(Clone)] | 81 | #[derive(Clone)] |
| 39 | struct RepositoryName(String); | 82 | struct RepositoryName(String); |
| @@ -165,99 +208,57 @@ fn get_s3_client(env: &Env) -> Result<aws_sdk_s3::Client, std::io::Error> { | |||
| 165 | Ok(aws_sdk_s3::Client::new(&config)) | 208 | Ok(aws_sdk_s3::Client::new(&config)) |
| 166 | } | 209 | } |
| 167 | 210 | ||
| 168 | #[tokio::main] | 211 | struct Config { |
| 169 | async fn main() -> ExitCode { | 212 | listen_addr: (String, u16), |
| 170 | tracing_subscriber::fmt::init(); | 213 | base_url: String, |
| 171 | 214 | authz_conf: AuthorizationConfig, | |
| 172 | let env = match Env::load() { | 215 | s3_client: aws_sdk_s3::Client, |
| 173 | Ok(env) => env, | 216 | s3_bucket: String, |
| 174 | Err(e) => { | 217 | download_limit: u64, |
| 175 | println!("Failed to load configuration: {e}"); | 218 | } |
| 176 | return ExitCode::from(2); | ||
| 177 | } | ||
| 178 | }; | ||
| 179 | |||
| 180 | let s3_client = match get_s3_client(&env) { | ||
| 181 | Ok(s3_client) => s3_client, | ||
| 182 | Err(e) => { | ||
| 183 | println!("Failed to create S3 client: {e}"); | ||
| 184 | return ExitCode::FAILURE; | ||
| 185 | } | ||
| 186 | }; | ||
| 187 | let key = match common::load_key(&env.key_path) { | ||
| 188 | Ok(key) => key, | ||
| 189 | Err(e) => { | ||
| 190 | println!("Failed to load Gitolfs3 key: {e}"); | ||
| 191 | return ExitCode::FAILURE; | ||
| 192 | } | ||
| 193 | }; | ||
| 194 | |||
| 195 | let trusted_forwarded_hosts: HashSet<String> = env | ||
| 196 | .trusted_forwarded_hosts | ||
| 197 | .split(',') | ||
| 198 | .map(|s| s.to_owned()) | ||
| 199 | .filter(|s| !s.is_empty()) | ||
| 200 | .collect(); | ||
| 201 | let base_url = env.base_url.trim_end_matches('/').to_string(); | ||
| 202 | |||
| 203 | let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else { | ||
| 204 | println!("Configured GITOLFS3_DOWNLOAD_LIMIT should be a 64-bit unsigned integer"); | ||
| 205 | return ExitCode::from(2); | ||
| 206 | }; | ||
| 207 | let dl_limiter = DownloadLimiter::new(download_limit).await; | ||
| 208 | let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter)); | ||
| 209 | |||
| 210 | let resetter_dl_limiter = dl_limiter.clone(); | ||
| 211 | tokio::spawn(async move { | ||
| 212 | loop { | ||
| 213 | println!("Resetting download counter in one hour"); | ||
| 214 | tokio::time::sleep(std::time::Duration::from_secs(3600)).await; | ||
| 215 | println!("Resetting download counter"); | ||
| 216 | resetter_dl_limiter.lock().await.reset().await; | ||
| 217 | } | ||
| 218 | }); | ||
| 219 | 219 | ||
| 220 | let authz_conf = AuthorizationConfig { | 220 | impl Config { |
| 221 | key, | 221 | fn load() -> Result<Self, String> { |
| 222 | trusted_forwarded_hosts, | 222 | let env = match Env::load() { |
| 223 | }; | 223 | Ok(env) => env, |
| 224 | Err(e) => return Err(format!("failed to load configuration: {e}")), | ||
| 225 | }; | ||
| 224 | 226 | ||
| 225 | let shared_state = Arc::new(AppState { | 227 | let s3_client = match get_s3_client(&env) { |
| 226 | s3_client, | 228 | Ok(s3_client) => s3_client, |
| 227 | s3_bucket: env.s3_bucket, | 229 | Err(e) => return Err(format!("failed to create S3 client: {e}")), |
| 228 | authz_conf, | 230 | }; |
| 229 | base_url, | 231 | let key = match common::load_key(&env.key_path) { |
| 230 | dl_limiter, | 232 | Ok(key) => key, |
| 231 | }); | 233 | Err(e) => return Err(format!("failed to load Gitolfs3 key: {e}")), |
| 232 | let app = Router::new() | 234 | }; |
| 233 | .route("/batch", post(batch)) | ||
| 234 | .route("/:oid0/:oid1/:oid", get(obj_download)) | ||
| 235 | .with_state(shared_state); | ||
| 236 | 235 | ||
| 237 | let middleware = axum::middleware::map_request(rewrite_url); | 236 | let trusted_forwarded_hosts: HashSet<String> = env |
| 238 | let app_with_middleware = middleware.layer(app); | 237 | .trusted_forwarded_hosts |
| 238 | .split(',') | ||
| 239 | .map(|s| s.to_owned()) | ||
| 240 | .filter(|s| !s.is_empty()) | ||
| 241 | .collect(); | ||
| 242 | let base_url = env.base_url.trim_end_matches('/').to_string(); | ||
| 239 | 243 | ||
| 240 | let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else { | 244 | let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else { |
| 241 | println!( | 245 | return Err("configured GITOLFS3_LISTEN_PORT is invalid".to_string()); |
| 242 | "Configured GITOLFS3_LISTEN_PORT should be an unsigned integer no higher than 65535" | 246 | }; |
| 243 | ); | 247 | let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else { |
| 244 | return ExitCode::from(2); | 248 | return Err("configured GITOLFS3_DOWNLOAD_LIMIT is invalid".to_string()); |
| 245 | }; | 249 | }; |
| 246 | let addr: (String, u16) = (env.listen_host, listen_port); | ||
| 247 | let listener = match tokio::net::TcpListener::bind(addr).await { | ||
| 248 | Ok(listener) => listener, | ||
| 249 | Err(e) => { | ||
| 250 | println!("Failed to listen: {e}"); | ||
| 251 | return ExitCode::FAILURE; | ||
| 252 | } | ||
| 253 | }; | ||
| 254 | 250 | ||
| 255 | match axum::serve(listener, app_with_middleware.into_make_service()).await { | 251 | Ok(Self { |
| 256 | Ok(_) => ExitCode::SUCCESS, | 252 | listen_addr: (env.listen_host, listen_port), |
| 257 | Err(e) => { | 253 | base_url, |
| 258 | println!("Error serving: {e}"); | 254 | authz_conf: AuthorizationConfig { |
| 259 | ExitCode::FAILURE | 255 | key, |
| 260 | } | 256 | trusted_forwarded_hosts, |
| 257 | }, | ||
| 258 | s3_client, | ||
| 259 | s3_bucket: env.s3_bucket, | ||
| 260 | download_limit, | ||
| 261 | }) | ||
| 261 | } | 262 | } |
| 262 | } | 263 | } |
| 263 | 264 | ||
| @@ -479,7 +480,7 @@ async fn handle_upload_object( | |||
| 479 | repo: &str, | 480 | repo: &str, |
| 480 | obj: &BatchRequestObject, | 481 | obj: &BatchRequestObject, |
| 481 | ) -> Option<BatchResponseObject> { | 482 | ) -> Option<BatchResponseObject> { |
| 482 | let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); | 483 | let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1])); |
| 483 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); | 484 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); |
| 484 | 485 | ||
| 485 | match state | 486 | match state |
| @@ -558,7 +559,7 @@ async fn handle_download_object( | |||
| 558 | obj: &BatchRequestObject, | 559 | obj: &BatchRequestObject, |
| 559 | trusted: bool, | 560 | trusted: bool, |
| 560 | ) -> BatchResponseObject { | 561 | ) -> BatchResponseObject { |
| 561 | let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); | 562 | let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1])); |
| 562 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); | 563 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); |
| 563 | 564 | ||
| 564 | let result = match state | 565 | let result = match state |
| @@ -687,8 +688,8 @@ async fn handle_download_object( | |||
| 687 | 688 | ||
| 688 | let upload_path = format!( | 689 | let upload_path = format!( |
| 689 | "{repo}/info/lfs/objects/{}/{}/{}", | 690 | "{repo}/info/lfs/objects/{}/{}/{}", |
| 690 | HexByte(obj.oid[0]), | 691 | common::HexByte(obj.oid[0]), |
| 691 | HexByte(obj.oid[1]), | 692 | common::HexByte(obj.oid[1]), |
| 692 | obj.oid, | 693 | obj.oid, |
| 693 | ); | 694 | ); |
| 694 | 695 | ||
| @@ -866,8 +867,8 @@ async fn batch( | |||
| 866 | #[derive(Deserialize, Copy, Clone)] | 867 | #[derive(Deserialize, Copy, Clone)] |
| 867 | #[serde(remote = "Self")] | 868 | #[serde(remote = "Self")] |
| 868 | struct FileParams { | 869 | struct FileParams { |
| 869 | oid0: HexByte, | 870 | oid0: common::HexByte, |
| 870 | oid1: HexByte, | 871 | oid1: common::HexByte, |
| 871 | oid: common::Oid, | 872 | oid: common::Oid, |
| 872 | } | 873 | } |
| 873 | 874 | ||
| @@ -877,8 +878,8 @@ impl<'de> Deserialize<'de> for FileParams { | |||
| 877 | D: serde::Deserializer<'de>, | 878 | D: serde::Deserializer<'de>, |
| 878 | { | 879 | { |
| 879 | let unchecked @ FileParams { | 880 | let unchecked @ FileParams { |
| 880 | oid0: HexByte(oid0), | 881 | oid0: common::HexByte(oid0), |
| 881 | oid1: HexByte(oid1), | 882 | oid1: common::HexByte(oid1), |
| 882 | oid, | 883 | oid, |
| 883 | } = FileParams::deserialize(deserializer)?; | 884 | } = FileParams::deserialize(deserializer)?; |
| 884 | if oid0 != oid.as_bytes()[0] { | 885 | if oid0 != oid.as_bytes()[0] { |