From 849e5ef60ee32fd5743f5a2da9a90fad2b869a49 Mon Sep 17 00:00:00 2001 From: Rutger Broekhoff Date: Wed, 24 Jan 2024 20:02:11 +0100 Subject: Add very simple download limit --- server/src/main.rs | 105 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 104 insertions(+), 1 deletion(-) (limited to 'server/src/main.rs') diff --git a/server/src/main.rs b/server/src/main.rs index 8baa0d6..f18edb3 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -23,6 +23,7 @@ use serde::de; use serde::de::DeserializeOwned; use serde::Deserialize; use serde::Serialize; +use tokio::io::AsyncWriteExt; use tower::Layer; use axum::{ @@ -101,6 +102,7 @@ struct AppState { authz_conf: AuthorizationConfig, // Should not end with a slash. base_url: String, + dl_limiter: Arc>, } struct Env { @@ -113,6 +115,7 @@ struct Env { key_path: String, listen_host: String, listen_port: String, + download_limit: String, trusted_forwarded_hosts: String, } @@ -133,6 +136,7 @@ impl Env { key_path: require_env("GITOLFS3_KEY_PATH")?, listen_host: require_env("GITOLFS3_LISTEN_HOST")?, listen_port: require_env("GITOLFS3_LISTEN_PORT")?, + download_limit: require_env("GITOLFS3_DOWNLOAD_LIMIT")?, trusted_forwarded_hosts: std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS") .unwrap_or_default(), }) @@ -196,6 +200,23 @@ async fn main() -> ExitCode { .collect(); let base_url = env.base_url.trim_end_matches('/').to_string(); + let Ok(download_limit): Result = env.download_limit.parse() else { + println!("Configured GITOLFS3_DOWNLOAD_LIMIT should be a 64-bit unsigned integer"); + return ExitCode::from(2); + }; + let dl_limiter = DownloadLimiter::new(download_limit).await; + let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter)); + + let resetter_dl_limiter = dl_limiter.clone(); + tokio::spawn(async move { + loop { + println!("Resetting download counter in one hour"); + tokio::time::sleep(std::time::Duration::from_secs(3600)).await; + println!("Resetting download counter"); + resetter_dl_limiter.lock().await.reset().await; + } + }); + let authz_conf = AuthorizationConfig { key, trusted_forwarded_hosts, @@ -206,6 +227,7 @@ async fn main() -> ExitCode { s3_bucket: env.s3_bucket, authz_conf, base_url, + dl_limiter, }); let app = Router::new() .route("/batch", post(batch)) @@ -216,7 +238,9 @@ async fn main() -> ExitCode { let app_with_middleware = middleware.layer(app); let Ok(listen_port): Result = env.listen_port.parse() else { - println!("Configured LISTEN_PORT should be an unsigned integer no higher than 65535"); + println!( + "Configured GITOLFS3_LISTEN_PORT should be an unsigned integer no higher than 65535" + ); return ExitCode::from(2); }; let addr: (String, u16) = (env.listen_host, listen_port); @@ -617,6 +641,35 @@ async fn handle_download_object( }; } + if let Some(content_length) = result.content_length() { + if content_length < 0 { + match state + .dl_limiter + .lock() + .await + .request(content_length as u64) + .await + { + Ok(true) => {} + Ok(false) => { + return BatchResponseObject::error( + obj, + StatusCode::SERVICE_UNAVAILABLE, + "Public LFS downloads temporarily unavailable".to_string(), + ); + } + Err(e) => { + println!("Failed to request {content_length} bytes from download limiter: {e}"); + return BatchResponseObject::error( + obj, + StatusCode::INTERNAL_SERVER_ERROR, + "Internal server error".to_string(), + ); + } + } + } + } + let Some(tag) = common::generate_tag( common::Claims { specific_claims: common::SpecificClaims::Download(obj.oid), @@ -956,6 +1009,56 @@ async fn obj_download( (headers, body).into_response() } +struct DownloadLimiter { + current: u64, + limit: u64, +} + +impl DownloadLimiter { + async fn new(limit: u64) -> DownloadLimiter { + let dlimit_str = match tokio::fs::read_to_string(".gitolfs3-dlimit").await { + Ok(dlimit_str) => dlimit_str, + Err(e) => { + println!("Failed to read download counter, assuming 0: {e}"); + return DownloadLimiter { current: 0, limit }; + } + }; + let current: u64 = match dlimit_str.parse().map_err(tokio::io::Error::other) { + Ok(current) => current, + Err(e) => { + println!("Failed to read download counter, assuming 0: {e}"); + return DownloadLimiter { current: 0, limit }; + } + }; + DownloadLimiter { current, limit } + } + + async fn request(&mut self, n: u64) -> tokio::io::Result { + if self.current + n > self.limit { + return Ok(false); + } + self.current += n; + self.write_new_limit().await?; + Ok(true) + } + + async fn reset(&mut self) { + self.current = 0; + if let Err(e) = self.write_new_limit().await { + println!("Failed to reset download counter: {e}"); + } + } + + async fn write_new_limit(&self) -> tokio::io::Result<()> { + let cwd = tokio::fs::File::open(std::env::current_dir()?).await?; + let mut file = tokio::fs::File::create(".gitolfs3-dlimit.tmp").await?; + file.write_all(self.limit.to_string().as_bytes()).await?; + file.sync_all().await?; + tokio::fs::rename(".gitolfs3-dlimit.tmp", ".gitolfs3-dlimit").await?; + cwd.sync_all().await + } +} + #[test] fn test_mimetype() { assert!(is_git_lfs_json_mimetype("application/vnd.git-lfs+json")); -- cgit v1.2.3