diff options
Diffstat (limited to 'server/src')
-rw-r--r-- | server/src/main.rs | 105 |
1 files changed, 104 insertions, 1 deletions
diff --git a/server/src/main.rs b/server/src/main.rs index 8baa0d6..f18edb3 100644 --- a/server/src/main.rs +++ b/server/src/main.rs | |||
@@ -23,6 +23,7 @@ use serde::de; | |||
23 | use serde::de::DeserializeOwned; | 23 | use serde::de::DeserializeOwned; |
24 | use serde::Deserialize; | 24 | use serde::Deserialize; |
25 | use serde::Serialize; | 25 | use serde::Serialize; |
26 | use tokio::io::AsyncWriteExt; | ||
26 | use tower::Layer; | 27 | use tower::Layer; |
27 | 28 | ||
28 | use axum::{ | 29 | use axum::{ |
@@ -101,6 +102,7 @@ struct AppState { | |||
101 | authz_conf: AuthorizationConfig, | 102 | authz_conf: AuthorizationConfig, |
102 | // Should not end with a slash. | 103 | // Should not end with a slash. |
103 | base_url: String, | 104 | base_url: String, |
105 | dl_limiter: Arc<tokio::sync::Mutex<DownloadLimiter>>, | ||
104 | } | 106 | } |
105 | 107 | ||
106 | struct Env { | 108 | struct Env { |
@@ -113,6 +115,7 @@ struct Env { | |||
113 | key_path: String, | 115 | key_path: String, |
114 | listen_host: String, | 116 | listen_host: String, |
115 | listen_port: String, | 117 | listen_port: String, |
118 | download_limit: String, | ||
116 | trusted_forwarded_hosts: String, | 119 | trusted_forwarded_hosts: String, |
117 | } | 120 | } |
118 | 121 | ||
@@ -133,6 +136,7 @@ impl Env { | |||
133 | key_path: require_env("GITOLFS3_KEY_PATH")?, | 136 | key_path: require_env("GITOLFS3_KEY_PATH")?, |
134 | listen_host: require_env("GITOLFS3_LISTEN_HOST")?, | 137 | listen_host: require_env("GITOLFS3_LISTEN_HOST")?, |
135 | listen_port: require_env("GITOLFS3_LISTEN_PORT")?, | 138 | listen_port: require_env("GITOLFS3_LISTEN_PORT")?, |
139 | download_limit: require_env("GITOLFS3_DOWNLOAD_LIMIT")?, | ||
136 | trusted_forwarded_hosts: std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS") | 140 | trusted_forwarded_hosts: std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS") |
137 | .unwrap_or_default(), | 141 | .unwrap_or_default(), |
138 | }) | 142 | }) |
@@ -196,6 +200,23 @@ async fn main() -> ExitCode { | |||
196 | .collect(); | 200 | .collect(); |
197 | let base_url = env.base_url.trim_end_matches('/').to_string(); | 201 | let base_url = env.base_url.trim_end_matches('/').to_string(); |
198 | 202 | ||
203 | let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else { | ||
204 | println!("Configured GITOLFS3_DOWNLOAD_LIMIT should be a 64-bit unsigned integer"); | ||
205 | return ExitCode::from(2); | ||
206 | }; | ||
207 | let dl_limiter = DownloadLimiter::new(download_limit).await; | ||
208 | let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter)); | ||
209 | |||
210 | let resetter_dl_limiter = dl_limiter.clone(); | ||
211 | tokio::spawn(async move { | ||
212 | loop { | ||
213 | println!("Resetting download counter in one hour"); | ||
214 | tokio::time::sleep(std::time::Duration::from_secs(3600)).await; | ||
215 | println!("Resetting download counter"); | ||
216 | resetter_dl_limiter.lock().await.reset().await; | ||
217 | } | ||
218 | }); | ||
219 | |||
199 | let authz_conf = AuthorizationConfig { | 220 | let authz_conf = AuthorizationConfig { |
200 | key, | 221 | key, |
201 | trusted_forwarded_hosts, | 222 | trusted_forwarded_hosts, |
@@ -206,6 +227,7 @@ async fn main() -> ExitCode { | |||
206 | s3_bucket: env.s3_bucket, | 227 | s3_bucket: env.s3_bucket, |
207 | authz_conf, | 228 | authz_conf, |
208 | base_url, | 229 | base_url, |
230 | dl_limiter, | ||
209 | }); | 231 | }); |
210 | let app = Router::new() | 232 | let app = Router::new() |
211 | .route("/batch", post(batch)) | 233 | .route("/batch", post(batch)) |
@@ -216,7 +238,9 @@ async fn main() -> ExitCode { | |||
216 | let app_with_middleware = middleware.layer(app); | 238 | let app_with_middleware = middleware.layer(app); |
217 | 239 | ||
218 | let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else { | 240 | let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else { |
219 | println!("Configured LISTEN_PORT should be an unsigned integer no higher than 65535"); | 241 | println!( |
242 | "Configured GITOLFS3_LISTEN_PORT should be an unsigned integer no higher than 65535" | ||
243 | ); | ||
220 | return ExitCode::from(2); | 244 | return ExitCode::from(2); |
221 | }; | 245 | }; |
222 | let addr: (String, u16) = (env.listen_host, listen_port); | 246 | let addr: (String, u16) = (env.listen_host, listen_port); |
@@ -617,6 +641,35 @@ async fn handle_download_object( | |||
617 | }; | 641 | }; |
618 | } | 642 | } |
619 | 643 | ||
644 | if let Some(content_length) = result.content_length() { | ||
645 | if content_length < 0 { | ||
646 | match state | ||
647 | .dl_limiter | ||
648 | .lock() | ||
649 | .await | ||
650 | .request(content_length as u64) | ||
651 | .await | ||
652 | { | ||
653 | Ok(true) => {} | ||
654 | Ok(false) => { | ||
655 | return BatchResponseObject::error( | ||
656 | obj, | ||
657 | StatusCode::SERVICE_UNAVAILABLE, | ||
658 | "Public LFS downloads temporarily unavailable".to_string(), | ||
659 | ); | ||
660 | } | ||
661 | Err(e) => { | ||
662 | println!("Failed to request {content_length} bytes from download limiter: {e}"); | ||
663 | return BatchResponseObject::error( | ||
664 | obj, | ||
665 | StatusCode::INTERNAL_SERVER_ERROR, | ||
666 | "Internal server error".to_string(), | ||
667 | ); | ||
668 | } | ||
669 | } | ||
670 | } | ||
671 | } | ||
672 | |||
620 | let Some(tag) = common::generate_tag( | 673 | let Some(tag) = common::generate_tag( |
621 | common::Claims { | 674 | common::Claims { |
622 | specific_claims: common::SpecificClaims::Download(obj.oid), | 675 | specific_claims: common::SpecificClaims::Download(obj.oid), |
@@ -956,6 +1009,56 @@ async fn obj_download( | |||
956 | (headers, body).into_response() | 1009 | (headers, body).into_response() |
957 | } | 1010 | } |
958 | 1011 | ||
1012 | struct DownloadLimiter { | ||
1013 | current: u64, | ||
1014 | limit: u64, | ||
1015 | } | ||
1016 | |||
1017 | impl DownloadLimiter { | ||
1018 | async fn new(limit: u64) -> DownloadLimiter { | ||
1019 | let dlimit_str = match tokio::fs::read_to_string(".gitolfs3-dlimit").await { | ||
1020 | Ok(dlimit_str) => dlimit_str, | ||
1021 | Err(e) => { | ||
1022 | println!("Failed to read download counter, assuming 0: {e}"); | ||
1023 | return DownloadLimiter { current: 0, limit }; | ||
1024 | } | ||
1025 | }; | ||
1026 | let current: u64 = match dlimit_str.parse().map_err(tokio::io::Error::other) { | ||
1027 | Ok(current) => current, | ||
1028 | Err(e) => { | ||
1029 | println!("Failed to read download counter, assuming 0: {e}"); | ||
1030 | return DownloadLimiter { current: 0, limit }; | ||
1031 | } | ||
1032 | }; | ||
1033 | DownloadLimiter { current, limit } | ||
1034 | } | ||
1035 | |||
1036 | async fn request(&mut self, n: u64) -> tokio::io::Result<bool> { | ||
1037 | if self.current + n > self.limit { | ||
1038 | return Ok(false); | ||
1039 | } | ||
1040 | self.current += n; | ||
1041 | self.write_new_limit().await?; | ||
1042 | Ok(true) | ||
1043 | } | ||
1044 | |||
1045 | async fn reset(&mut self) { | ||
1046 | self.current = 0; | ||
1047 | if let Err(e) = self.write_new_limit().await { | ||
1048 | println!("Failed to reset download counter: {e}"); | ||
1049 | } | ||
1050 | } | ||
1051 | |||
1052 | async fn write_new_limit(&self) -> tokio::io::Result<()> { | ||
1053 | let cwd = tokio::fs::File::open(std::env::current_dir()?).await?; | ||
1054 | let mut file = tokio::fs::File::create(".gitolfs3-dlimit.tmp").await?; | ||
1055 | file.write_all(self.limit.to_string().as_bytes()).await?; | ||
1056 | file.sync_all().await?; | ||
1057 | tokio::fs::rename(".gitolfs3-dlimit.tmp", ".gitolfs3-dlimit").await?; | ||
1058 | cwd.sync_all().await | ||
1059 | } | ||
1060 | } | ||
1061 | |||
959 | #[test] | 1062 | #[test] |
960 | fn test_mimetype() { | 1063 | fn test_mimetype() { |
961 | assert!(is_git_lfs_json_mimetype("application/vnd.git-lfs+json")); | 1064 | assert!(is_git_lfs_json_mimetype("application/vnd.git-lfs+json")); |