Diffstat (limited to 'server/src')
-rw-r--r--  server/src/main.rs  257
1 file changed, 129 insertions, 128 deletions
diff --git a/server/src/main.rs b/server/src/main.rs
index db37d14..4a88dcd 100644
--- a/server/src/main.rs
+++ b/server/src/main.rs
@@ -1,39 +1,82 @@
-use std::collections::HashMap;
-use std::collections::HashSet;
-use std::process::ExitCode;
-use std::sync::Arc;
-
-use aws_sdk_s3::error::SdkError;
-use aws_sdk_s3::operation::head_object::HeadObjectOutput;
-use axum::extract::rejection;
-use axum::extract::FromRequest;
-use axum::extract::Path;
-use axum::extract::State;
-use axum::http::header;
-use axum::http::HeaderMap;
-use axum::http::HeaderValue;
-use axum::response::Response;
-use axum::Json;
-use axum::ServiceExt;
-use base64::prelude::*;
-use chrono::DateTime;
-use chrono::Utc;
-use common::HexByte;
-use serde::de;
-use serde::de::DeserializeOwned;
-use serde::Deserialize;
-use serde::Serialize;
-use tokio::io::AsyncWriteExt;
-use tower::Layer;
-
+use aws_sdk_s3::{error::SdkError, operation::head_object::HeadObjectOutput};
 use axum::{
     async_trait,
-    extract::{FromRequestParts, OriginalUri, Request},
-    http::{request::Parts, StatusCode, Uri},
-    response::IntoResponse,
+    extract::{rejection, FromRequest, FromRequestParts, OriginalUri, Path, Request, State},
+    http::{header, request::Parts, HeaderMap, HeaderValue, StatusCode, Uri},
+    response::{IntoResponse, Response},
     routing::{get, post},
-    Extension, Router,
+    Extension, Json, Router, ServiceExt,
 };
+use base64::prelude::*;
+use chrono::{DateTime, Utc};
+use serde::{
+    de::{self, DeserializeOwned},
+    Deserialize, Serialize,
+};
+use std::{
+    collections::{HashMap, HashSet},
+    process::ExitCode,
+    sync::Arc,
+};
+use tokio::io::AsyncWriteExt;
+use tower::Layer;
+
+#[tokio::main]
+async fn main() -> ExitCode {
+    tracing_subscriber::fmt::init();
+
+    let conf = match Config::load() {
+        Ok(conf) => conf,
+        Err(e) => {
+            println!("Error: {e}");
+            return ExitCode::from(2);
+        }
+    };
+
+    let dl_limiter = DownloadLimiter::new(conf.download_limit).await;
+    let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter));
+
+    let resetter_dl_limiter = dl_limiter.clone();
+    tokio::spawn(async move {
+        loop {
+            println!("Resetting download counter in one hour");
+            tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
+            println!("Resetting download counter");
+            resetter_dl_limiter.lock().await.reset().await;
+        }
+    });
+
+    let shared_state = Arc::new(AppState {
+        s3_client: conf.s3_client,
+        s3_bucket: conf.s3_bucket,
+        authz_conf: conf.authz_conf,
+        base_url: conf.base_url,
+        dl_limiter,
+    });
+    let app = Router::new()
+        .route("/batch", post(batch))
+        .route("/:oid0/:oid1/:oid", get(obj_download))
+        .with_state(shared_state);
+
+    let middleware = axum::middleware::map_request(rewrite_url);
+    let app_with_middleware = middleware.layer(app);
+
+    let listener = match tokio::net::TcpListener::bind(conf.listen_addr).await {
+        Ok(listener) => listener,
+        Err(e) => {
+            println!("Failed to listen: {e}");
+            return ExitCode::FAILURE;
+        }
+    };
+
+    match axum::serve(listener, app_with_middleware.into_make_service()).await {
+        Ok(_) => ExitCode::SUCCESS,
+        Err(e) => {
+            println!("Error serving: {e}");
+            ExitCode::FAILURE
+        }
+    }
+}
 
 #[derive(Clone)]
 struct RepositoryName(String);
@@ -165,99 +208,57 @@ fn get_s3_client(env: &Env) -> Result<aws_sdk_s3::Client, std::io::Error> {
     Ok(aws_sdk_s3::Client::new(&config))
 }
 
-#[tokio::main]
-async fn main() -> ExitCode {
-    tracing_subscriber::fmt::init();
-
-    let env = match Env::load() {
-        Ok(env) => env,
-        Err(e) => {
-            println!("Failed to load configuration: {e}");
-            return ExitCode::from(2);
-        }
-    };
-
-    let s3_client = match get_s3_client(&env) {
-        Ok(s3_client) => s3_client,
-        Err(e) => {
-            println!("Failed to create S3 client: {e}");
-            return ExitCode::FAILURE;
-        }
-    };
-    let key = match common::load_key(&env.key_path) {
-        Ok(key) => key,
-        Err(e) => {
-            println!("Failed to load Gitolfs3 key: {e}");
-            return ExitCode::FAILURE;
-        }
-    };
-
-    let trusted_forwarded_hosts: HashSet<String> = env
-        .trusted_forwarded_hosts
-        .split(',')
-        .map(|s| s.to_owned())
-        .filter(|s| !s.is_empty())
-        .collect();
-    let base_url = env.base_url.trim_end_matches('/').to_string();
-
-    let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else {
-        println!("Configured GITOLFS3_DOWNLOAD_LIMIT should be a 64-bit unsigned integer");
-        return ExitCode::from(2);
-    };
-    let dl_limiter = DownloadLimiter::new(download_limit).await;
-    let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter));
-
-    let resetter_dl_limiter = dl_limiter.clone();
-    tokio::spawn(async move {
-        loop {
-            println!("Resetting download counter in one hour");
-            tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
-            println!("Resetting download counter");
-            resetter_dl_limiter.lock().await.reset().await;
-        }
-    });
+struct Config {
+    listen_addr: (String, u16),
+    base_url: String,
+    authz_conf: AuthorizationConfig,
+    s3_client: aws_sdk_s3::Client,
+    s3_bucket: String,
+    download_limit: u64,
+}
 
-    let authz_conf = AuthorizationConfig {
-        key,
-        trusted_forwarded_hosts,
-    };
+impl Config {
+    fn load() -> Result<Self, String> {
+        let env = match Env::load() {
+            Ok(env) => env,
+            Err(e) => return Err(format!("failed to load configuration: {e}")),
+        };
 
-    let shared_state = Arc::new(AppState {
-        s3_client,
-        s3_bucket: env.s3_bucket,
-        authz_conf,
-        base_url,
-        dl_limiter,
-    });
-    let app = Router::new()
-        .route("/batch", post(batch))
-        .route("/:oid0/:oid1/:oid", get(obj_download))
-        .with_state(shared_state);
+        let s3_client = match get_s3_client(&env) {
+            Ok(s3_client) => s3_client,
+            Err(e) => return Err(format!("failed to create S3 client: {e}")),
+        };
+        let key = match common::load_key(&env.key_path) {
+            Ok(key) => key,
+            Err(e) => return Err(format!("failed to load Gitolfs3 key: {e}")),
+        };
 
-    let middleware = axum::middleware::map_request(rewrite_url);
-    let app_with_middleware = middleware.layer(app);
+        let trusted_forwarded_hosts: HashSet<String> = env
+            .trusted_forwarded_hosts
+            .split(',')
+            .map(|s| s.to_owned())
+            .filter(|s| !s.is_empty())
+            .collect();
+        let base_url = env.base_url.trim_end_matches('/').to_string();
 
-    let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else {
-        println!(
-            "Configured GITOLFS3_LISTEN_PORT should be an unsigned integer no higher than 65535"
-        );
-        return ExitCode::from(2);
-    };
-    let addr: (String, u16) = (env.listen_host, listen_port);
-    let listener = match tokio::net::TcpListener::bind(addr).await {
-        Ok(listener) => listener,
-        Err(e) => {
-            println!("Failed to listen: {e}");
-            return ExitCode::FAILURE;
-        }
-    };
+        let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else {
+            return Err("configured GITOLFS3_LISTEN_PORT is invalid".to_string());
+        };
+        let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else {
+            return Err("configured GITOLFS3_DOWNLOAD_LIMIT is invalid".to_string());
+        };
 
-    match axum::serve(listener, app_with_middleware.into_make_service()).await {
-        Ok(_) => ExitCode::SUCCESS,
-        Err(e) => {
-            println!("Error serving: {e}");
-            ExitCode::FAILURE
-        }
+        Ok(Self {
+            listen_addr: (env.listen_host, listen_port),
+            base_url,
+            authz_conf: AuthorizationConfig {
+                key,
+                trusted_forwarded_hosts,
+            },
+            s3_client,
+            s3_bucket: env.s3_bucket,
+            download_limit,
+        })
     }
 }
 
@@ -479,7 +480,7 @@ async fn handle_upload_object(
     repo: &str,
     obj: &BatchRequestObject,
 ) -> Option<BatchResponseObject> {
-    let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
+    let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1]));
     let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
 
     match state
@@ -558,7 +559,7 @@ async fn handle_download_object(
     obj: &BatchRequestObject,
     trusted: bool,
 ) -> BatchResponseObject {
-    let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
+    let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1]));
     let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
 
     let result = match state
@@ -687,8 +688,8 @@ async fn handle_download_object(
 
     let upload_path = format!(
         "{repo}/info/lfs/objects/{}/{}/{}",
-        HexByte(obj.oid[0]),
-        HexByte(obj.oid[1]),
+        common::HexByte(obj.oid[0]),
+        common::HexByte(obj.oid[1]),
         obj.oid,
     );
 
@@ -866,8 +867,8 @@ async fn batch(
 #[derive(Deserialize, Copy, Clone)]
 #[serde(remote = "Self")]
 struct FileParams {
-    oid0: HexByte,
-    oid1: HexByte,
+    oid0: common::HexByte,
+    oid1: common::HexByte,
     oid: common::Oid,
 }
 
@@ -877,8 +878,8 @@ impl<'de> Deserialize<'de> for FileParams {
         D: serde::Deserializer<'de>,
     {
         let unchecked @ FileParams {
-            oid0: HexByte(oid0),
-            oid1: HexByte(oid1),
+            oid0: common::HexByte(oid0),
+            oid1: common::HexByte(oid1),
             oid,
         } = FileParams::deserialize(deserializer)?;
         if oid0 != oid.as_bytes()[0] {