diff options
| -rw-r--r-- | server/src/main.rs | 105 | 
1 files changed, 104 insertions, 1 deletions
diff --git a/server/src/main.rs b/server/src/main.rs index 8baa0d6..f18edb3 100644 --- a/server/src/main.rs +++ b/server/src/main.rs  | |||
| @@ -23,6 +23,7 @@ use serde::de; | |||
| 23 | use serde::de::DeserializeOwned; | 23 | use serde::de::DeserializeOwned; | 
| 24 | use serde::Deserialize; | 24 | use serde::Deserialize; | 
| 25 | use serde::Serialize; | 25 | use serde::Serialize; | 
| 26 | use tokio::io::AsyncWriteExt; | ||
| 26 | use tower::Layer; | 27 | use tower::Layer; | 
| 27 | 28 | ||
| 28 | use axum::{ | 29 | use axum::{ | 
| @@ -101,6 +102,7 @@ struct AppState { | |||
| 101 | authz_conf: AuthorizationConfig, | 102 | authz_conf: AuthorizationConfig, | 
| 102 | // Should not end with a slash. | 103 | // Should not end with a slash. | 
| 103 | base_url: String, | 104 | base_url: String, | 
| 105 | dl_limiter: Arc<tokio::sync::Mutex<DownloadLimiter>>, | ||
| 104 | } | 106 | } | 
| 105 | 107 | ||
| 106 | struct Env { | 108 | struct Env { | 
| @@ -113,6 +115,7 @@ struct Env { | |||
| 113 | key_path: String, | 115 | key_path: String, | 
| 114 | listen_host: String, | 116 | listen_host: String, | 
| 115 | listen_port: String, | 117 | listen_port: String, | 
| 118 | download_limit: String, | ||
| 116 | trusted_forwarded_hosts: String, | 119 | trusted_forwarded_hosts: String, | 
| 117 | } | 120 | } | 
| 118 | 121 | ||
| @@ -133,6 +136,7 @@ impl Env { | |||
| 133 | key_path: require_env("GITOLFS3_KEY_PATH")?, | 136 | key_path: require_env("GITOLFS3_KEY_PATH")?, | 
| 134 | listen_host: require_env("GITOLFS3_LISTEN_HOST")?, | 137 | listen_host: require_env("GITOLFS3_LISTEN_HOST")?, | 
| 135 | listen_port: require_env("GITOLFS3_LISTEN_PORT")?, | 138 | listen_port: require_env("GITOLFS3_LISTEN_PORT")?, | 
| 139 | download_limit: require_env("GITOLFS3_DOWNLOAD_LIMIT")?, | ||
| 136 | trusted_forwarded_hosts: std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS") | 140 | trusted_forwarded_hosts: std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS") | 
| 137 | .unwrap_or_default(), | 141 | .unwrap_or_default(), | 
| 138 | }) | 142 | }) | 
| @@ -196,6 +200,23 @@ async fn main() -> ExitCode { | |||
| 196 | .collect(); | 200 | .collect(); | 
| 197 | let base_url = env.base_url.trim_end_matches('/').to_string(); | 201 | let base_url = env.base_url.trim_end_matches('/').to_string(); | 
| 198 | 202 | ||
| 203 | let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else { | ||
| 204 | println!("Configured GITOLFS3_DOWNLOAD_LIMIT should be a 64-bit unsigned integer"); | ||
| 205 | return ExitCode::from(2); | ||
| 206 | }; | ||
| 207 | let dl_limiter = DownloadLimiter::new(download_limit).await; | ||
| 208 | let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter)); | ||
| 209 | |||
| 210 | let resetter_dl_limiter = dl_limiter.clone(); | ||
| 211 | tokio::spawn(async move { | ||
| 212 | loop { | ||
| 213 | println!("Resetting download counter in one hour"); | ||
| 214 | tokio::time::sleep(std::time::Duration::from_secs(3600)).await; | ||
| 215 | println!("Resetting download counter"); | ||
| 216 | resetter_dl_limiter.lock().await.reset().await; | ||
| 217 | } | ||
| 218 | }); | ||
| 219 | |||
| 199 | let authz_conf = AuthorizationConfig { | 220 | let authz_conf = AuthorizationConfig { | 
| 200 | key, | 221 | key, | 
| 201 | trusted_forwarded_hosts, | 222 | trusted_forwarded_hosts, | 
| @@ -206,6 +227,7 @@ async fn main() -> ExitCode { | |||
| 206 | s3_bucket: env.s3_bucket, | 227 | s3_bucket: env.s3_bucket, | 
| 207 | authz_conf, | 228 | authz_conf, | 
| 208 | base_url, | 229 | base_url, | 
| 230 | dl_limiter, | ||
| 209 | }); | 231 | }); | 
| 210 | let app = Router::new() | 232 | let app = Router::new() | 
| 211 | .route("/batch", post(batch)) | 233 | .route("/batch", post(batch)) | 
| @@ -216,7 +238,9 @@ async fn main() -> ExitCode { | |||
| 216 | let app_with_middleware = middleware.layer(app); | 238 | let app_with_middleware = middleware.layer(app); | 
| 217 | 239 | ||
| 218 | let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else { | 240 | let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else { | 
| 219 | println!("Configured LISTEN_PORT should be an unsigned integer no higher than 65535"); | 241 | println!( | 
| 242 | "Configured GITOLFS3_LISTEN_PORT should be an unsigned integer no higher than 65535" | ||
| 243 | ); | ||
| 220 | return ExitCode::from(2); | 244 | return ExitCode::from(2); | 
| 221 | }; | 245 | }; | 
| 222 | let addr: (String, u16) = (env.listen_host, listen_port); | 246 | let addr: (String, u16) = (env.listen_host, listen_port); | 
| @@ -617,6 +641,35 @@ async fn handle_download_object( | |||
| 617 | }; | 641 | }; | 
| 618 | } | 642 | } | 
| 619 | 643 | ||
| 644 | if let Some(content_length) = result.content_length() { | ||
| 645 | if content_length < 0 { | ||
| 646 | match state | ||
| 647 | .dl_limiter | ||
| 648 | .lock() | ||
| 649 | .await | ||
| 650 | .request(content_length as u64) | ||
| 651 | .await | ||
| 652 | { | ||
| 653 | Ok(true) => {} | ||
| 654 | Ok(false) => { | ||
| 655 | return BatchResponseObject::error( | ||
| 656 | obj, | ||
| 657 | StatusCode::SERVICE_UNAVAILABLE, | ||
| 658 | "Public LFS downloads temporarily unavailable".to_string(), | ||
| 659 | ); | ||
| 660 | } | ||
| 661 | Err(e) => { | ||
| 662 | println!("Failed to request {content_length} bytes from download limiter: {e}"); | ||
| 663 | return BatchResponseObject::error( | ||
| 664 | obj, | ||
| 665 | StatusCode::INTERNAL_SERVER_ERROR, | ||
| 666 | "Internal server error".to_string(), | ||
| 667 | ); | ||
| 668 | } | ||
| 669 | } | ||
| 670 | } | ||
| 671 | } | ||
| 672 | |||
| 620 | let Some(tag) = common::generate_tag( | 673 | let Some(tag) = common::generate_tag( | 
| 621 | common::Claims { | 674 | common::Claims { | 
| 622 | specific_claims: common::SpecificClaims::Download(obj.oid), | 675 | specific_claims: common::SpecificClaims::Download(obj.oid), | 
| @@ -956,6 +1009,56 @@ async fn obj_download( | |||
| 956 | (headers, body).into_response() | 1009 | (headers, body).into_response() | 
| 957 | } | 1010 | } | 
| 958 | 1011 | ||
| 1012 | struct DownloadLimiter { | ||
| 1013 | current: u64, | ||
| 1014 | limit: u64, | ||
| 1015 | } | ||
| 1016 | |||
| 1017 | impl DownloadLimiter { | ||
| 1018 | async fn new(limit: u64) -> DownloadLimiter { | ||
| 1019 | let dlimit_str = match tokio::fs::read_to_string(".gitolfs3-dlimit").await { | ||
| 1020 | Ok(dlimit_str) => dlimit_str, | ||
| 1021 | Err(e) => { | ||
| 1022 | println!("Failed to read download counter, assuming 0: {e}"); | ||
| 1023 | return DownloadLimiter { current: 0, limit }; | ||
| 1024 | } | ||
| 1025 | }; | ||
| 1026 | let current: u64 = match dlimit_str.parse().map_err(tokio::io::Error::other) { | ||
| 1027 | Ok(current) => current, | ||
| 1028 | Err(e) => { | ||
| 1029 | println!("Failed to read download counter, assuming 0: {e}"); | ||
| 1030 | return DownloadLimiter { current: 0, limit }; | ||
| 1031 | } | ||
| 1032 | }; | ||
| 1033 | DownloadLimiter { current, limit } | ||
| 1034 | } | ||
| 1035 | |||
| 1036 | async fn request(&mut self, n: u64) -> tokio::io::Result<bool> { | ||
| 1037 | if self.current + n > self.limit { | ||
| 1038 | return Ok(false); | ||
| 1039 | } | ||
| 1040 | self.current += n; | ||
| 1041 | self.write_new_limit().await?; | ||
| 1042 | Ok(true) | ||
| 1043 | } | ||
| 1044 | |||
| 1045 | async fn reset(&mut self) { | ||
| 1046 | self.current = 0; | ||
| 1047 | if let Err(e) = self.write_new_limit().await { | ||
| 1048 | println!("Failed to reset download counter: {e}"); | ||
| 1049 | } | ||
| 1050 | } | ||
| 1051 | |||
| 1052 | async fn write_new_limit(&self) -> tokio::io::Result<()> { | ||
| 1053 | let cwd = tokio::fs::File::open(std::env::current_dir()?).await?; | ||
| 1054 | let mut file = tokio::fs::File::create(".gitolfs3-dlimit.tmp").await?; | ||
| 1055 | file.write_all(self.limit.to_string().as_bytes()).await?; | ||
| 1056 | file.sync_all().await?; | ||
| 1057 | tokio::fs::rename(".gitolfs3-dlimit.tmp", ".gitolfs3-dlimit").await?; | ||
| 1058 | cwd.sync_all().await | ||
| 1059 | } | ||
| 1060 | } | ||
| 1061 | |||
| 959 | #[test] | 1062 | #[test] | 
| 960 | fn test_mimetype() { | 1063 | fn test_mimetype() { | 
| 961 | assert!(is_git_lfs_json_mimetype("application/vnd.git-lfs+json")); | 1064 | assert!(is_git_lfs_json_mimetype("application/vnd.git-lfs+json")); |