aboutsummaryrefslogtreecommitdiffstats
path: root/server/src
diff options
context:
space:
mode:
Diffstat (limited to 'server/src')
-rw-r--r--server/src/main.rs257
1 files changed, 129 insertions, 128 deletions
diff --git a/server/src/main.rs b/server/src/main.rs
index db37d14..4a88dcd 100644
--- a/server/src/main.rs
+++ b/server/src/main.rs
@@ -1,39 +1,82 @@
1use std::collections::HashMap; 1use aws_sdk_s3::{error::SdkError, operation::head_object::HeadObjectOutput};
2use std::collections::HashSet;
3use std::process::ExitCode;
4use std::sync::Arc;
5
6use aws_sdk_s3::error::SdkError;
7use aws_sdk_s3::operation::head_object::HeadObjectOutput;
8use axum::extract::rejection;
9use axum::extract::FromRequest;
10use axum::extract::Path;
11use axum::extract::State;
12use axum::http::header;
13use axum::http::HeaderMap;
14use axum::http::HeaderValue;
15use axum::response::Response;
16use axum::Json;
17use axum::ServiceExt;
18use base64::prelude::*;
19use chrono::DateTime;
20use chrono::Utc;
21use common::HexByte;
22use serde::de;
23use serde::de::DeserializeOwned;
24use serde::Deserialize;
25use serde::Serialize;
26use tokio::io::AsyncWriteExt;
27use tower::Layer;
28
29use axum::{ 2use axum::{
30 async_trait, 3 async_trait,
31 extract::{FromRequestParts, OriginalUri, Request}, 4 extract::{rejection, FromRequest, FromRequestParts, OriginalUri, Path, Request, State},
32 http::{request::Parts, StatusCode, Uri}, 5 http::{header, request::Parts, HeaderMap, HeaderValue, StatusCode, Uri},
33 response::IntoResponse, 6 response::{IntoResponse, Response},
34 routing::{get, post}, 7 routing::{get, post},
35 Extension, Router, 8 Extension, Json, Router, ServiceExt,
36}; 9};
10use base64::prelude::*;
11use chrono::{DateTime, Utc};
12use serde::{
13 de::{self, DeserializeOwned},
14 Deserialize, Serialize,
15};
16use std::{
17 collections::{HashMap, HashSet},
18 process::ExitCode,
19 sync::Arc,
20};
21use tokio::io::AsyncWriteExt;
22use tower::Layer;
23
24#[tokio::main]
25async fn main() -> ExitCode {
26 tracing_subscriber::fmt::init();
27
28 let conf = match Config::load() {
29 Ok(conf) => conf,
30 Err(e) => {
31 println!("Error: {e}");
32 return ExitCode::from(2);
33 }
34 };
35
36 let dl_limiter = DownloadLimiter::new(conf.download_limit).await;
37 let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter));
38
39 let resetter_dl_limiter = dl_limiter.clone();
40 tokio::spawn(async move {
41 loop {
42 println!("Resetting download counter in one hour");
43 tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
44 println!("Resetting download counter");
45 resetter_dl_limiter.lock().await.reset().await;
46 }
47 });
48
49 let shared_state = Arc::new(AppState {
50 s3_client: conf.s3_client,
51 s3_bucket: conf.s3_bucket,
52 authz_conf: conf.authz_conf,
53 base_url: conf.base_url,
54 dl_limiter,
55 });
56 let app = Router::new()
57 .route("/batch", post(batch))
58 .route("/:oid0/:oid1/:oid", get(obj_download))
59 .with_state(shared_state);
60
61 let middleware = axum::middleware::map_request(rewrite_url);
62 let app_with_middleware = middleware.layer(app);
63
64 let listener = match tokio::net::TcpListener::bind(conf.listen_addr).await {
65 Ok(listener) => listener,
66 Err(e) => {
67 println!("Failed to listen: {e}");
68 return ExitCode::FAILURE;
69 }
70 };
71
72 match axum::serve(listener, app_with_middleware.into_make_service()).await {
73 Ok(_) => ExitCode::SUCCESS,
74 Err(e) => {
75 println!("Error serving: {e}");
76 ExitCode::FAILURE
77 }
78 }
79}
37 80
38#[derive(Clone)] 81#[derive(Clone)]
39struct RepositoryName(String); 82struct RepositoryName(String);
@@ -165,99 +208,57 @@ fn get_s3_client(env: &Env) -> Result<aws_sdk_s3::Client, std::io::Error> {
165 Ok(aws_sdk_s3::Client::new(&config)) 208 Ok(aws_sdk_s3::Client::new(&config))
166} 209}
167 210
168#[tokio::main] 211struct Config {
169async fn main() -> ExitCode { 212 listen_addr: (String, u16),
170 tracing_subscriber::fmt::init(); 213 base_url: String,
171 214 authz_conf: AuthorizationConfig,
172 let env = match Env::load() { 215 s3_client: aws_sdk_s3::Client,
173 Ok(env) => env, 216 s3_bucket: String,
174 Err(e) => { 217 download_limit: u64,
175 println!("Failed to load configuration: {e}"); 218}
176 return ExitCode::from(2);
177 }
178 };
179
180 let s3_client = match get_s3_client(&env) {
181 Ok(s3_client) => s3_client,
182 Err(e) => {
183 println!("Failed to create S3 client: {e}");
184 return ExitCode::FAILURE;
185 }
186 };
187 let key = match common::load_key(&env.key_path) {
188 Ok(key) => key,
189 Err(e) => {
190 println!("Failed to load Gitolfs3 key: {e}");
191 return ExitCode::FAILURE;
192 }
193 };
194
195 let trusted_forwarded_hosts: HashSet<String> = env
196 .trusted_forwarded_hosts
197 .split(',')
198 .map(|s| s.to_owned())
199 .filter(|s| !s.is_empty())
200 .collect();
201 let base_url = env.base_url.trim_end_matches('/').to_string();
202
203 let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else {
204 println!("Configured GITOLFS3_DOWNLOAD_LIMIT should be a 64-bit unsigned integer");
205 return ExitCode::from(2);
206 };
207 let dl_limiter = DownloadLimiter::new(download_limit).await;
208 let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter));
209
210 let resetter_dl_limiter = dl_limiter.clone();
211 tokio::spawn(async move {
212 loop {
213 println!("Resetting download counter in one hour");
214 tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
215 println!("Resetting download counter");
216 resetter_dl_limiter.lock().await.reset().await;
217 }
218 });
219 219
220 let authz_conf = AuthorizationConfig { 220impl Config {
221 key, 221 fn load() -> Result<Self, String> {
222 trusted_forwarded_hosts, 222 let env = match Env::load() {
223 }; 223 Ok(env) => env,
224 Err(e) => return Err(format!("failed to load configuration: {e}")),
225 };
224 226
225 let shared_state = Arc::new(AppState { 227 let s3_client = match get_s3_client(&env) {
226 s3_client, 228 Ok(s3_client) => s3_client,
227 s3_bucket: env.s3_bucket, 229 Err(e) => return Err(format!("failed to create S3 client: {e}")),
228 authz_conf, 230 };
229 base_url, 231 let key = match common::load_key(&env.key_path) {
230 dl_limiter, 232 Ok(key) => key,
231 }); 233 Err(e) => return Err(format!("failed to load Gitolfs3 key: {e}")),
232 let app = Router::new() 234 };
233 .route("/batch", post(batch))
234 .route("/:oid0/:oid1/:oid", get(obj_download))
235 .with_state(shared_state);
236 235
237 let middleware = axum::middleware::map_request(rewrite_url); 236 let trusted_forwarded_hosts: HashSet<String> = env
238 let app_with_middleware = middleware.layer(app); 237 .trusted_forwarded_hosts
238 .split(',')
239 .map(|s| s.to_owned())
240 .filter(|s| !s.is_empty())
241 .collect();
242 let base_url = env.base_url.trim_end_matches('/').to_string();
239 243
240 let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else { 244 let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else {
241 println!( 245 return Err("configured GITOLFS3_LISTEN_PORT is invalid".to_string());
242 "Configured GITOLFS3_LISTEN_PORT should be an unsigned integer no higher than 65535" 246 };
243 ); 247 let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else {
244 return ExitCode::from(2); 248 return Err("configured GITOLFS3_DOWNLOAD_LIMIT is invalid".to_string());
245 }; 249 };
246 let addr: (String, u16) = (env.listen_host, listen_port);
247 let listener = match tokio::net::TcpListener::bind(addr).await {
248 Ok(listener) => listener,
249 Err(e) => {
250 println!("Failed to listen: {e}");
251 return ExitCode::FAILURE;
252 }
253 };
254 250
255 match axum::serve(listener, app_with_middleware.into_make_service()).await { 251 Ok(Self {
256 Ok(_) => ExitCode::SUCCESS, 252 listen_addr: (env.listen_host, listen_port),
257 Err(e) => { 253 base_url,
258 println!("Error serving: {e}"); 254 authz_conf: AuthorizationConfig {
259 ExitCode::FAILURE 255 key,
260 } 256 trusted_forwarded_hosts,
257 },
258 s3_client,
259 s3_bucket: env.s3_bucket,
260 download_limit,
261 })
261 } 262 }
262} 263}
263 264
@@ -479,7 +480,7 @@ async fn handle_upload_object(
479 repo: &str, 480 repo: &str,
480 obj: &BatchRequestObject, 481 obj: &BatchRequestObject,
481) -> Option<BatchResponseObject> { 482) -> Option<BatchResponseObject> {
482 let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); 483 let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1]));
483 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); 484 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
484 485
485 match state 486 match state
@@ -558,7 +559,7 @@ async fn handle_download_object(
558 obj: &BatchRequestObject, 559 obj: &BatchRequestObject,
559 trusted: bool, 560 trusted: bool,
560) -> BatchResponseObject { 561) -> BatchResponseObject {
561 let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); 562 let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1]));
562 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); 563 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
563 564
564 let result = match state 565 let result = match state
@@ -687,8 +688,8 @@ async fn handle_download_object(
687 688
688 let upload_path = format!( 689 let upload_path = format!(
689 "{repo}/info/lfs/objects/{}/{}/{}", 690 "{repo}/info/lfs/objects/{}/{}/{}",
690 HexByte(obj.oid[0]), 691 common::HexByte(obj.oid[0]),
691 HexByte(obj.oid[1]), 692 common::HexByte(obj.oid[1]),
692 obj.oid, 693 obj.oid,
693 ); 694 );
694 695
@@ -866,8 +867,8 @@ async fn batch(
866#[derive(Deserialize, Copy, Clone)] 867#[derive(Deserialize, Copy, Clone)]
867#[serde(remote = "Self")] 868#[serde(remote = "Self")]
868struct FileParams { 869struct FileParams {
869 oid0: HexByte, 870 oid0: common::HexByte,
870 oid1: HexByte, 871 oid1: common::HexByte,
871 oid: common::Oid, 872 oid: common::Oid,
872} 873}
873 874
@@ -877,8 +878,8 @@ impl<'de> Deserialize<'de> for FileParams {
877 D: serde::Deserializer<'de>, 878 D: serde::Deserializer<'de>,
878 { 879 {
879 let unchecked @ FileParams { 880 let unchecked @ FileParams {
880 oid0: HexByte(oid0), 881 oid0: common::HexByte(oid0),
881 oid1: HexByte(oid1), 882 oid1: common::HexByte(oid1),
882 oid, 883 oid,
883 } = FileParams::deserialize(deserializer)?; 884 } = FileParams::deserialize(deserializer)?;
884 if oid0 != oid.as_bytes()[0] { 885 if oid0 != oid.as_bytes()[0] {