aboutsummaryrefslogtreecommitdiffstats
path: root/server/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'server/src/main.rs')
-rw-r--r--server/src/main.rs105
1 files changed, 104 insertions, 1 deletions
diff --git a/server/src/main.rs b/server/src/main.rs
index 8baa0d6..f18edb3 100644
--- a/server/src/main.rs
+++ b/server/src/main.rs
@@ -23,6 +23,7 @@ use serde::de;
23use serde::de::DeserializeOwned; 23use serde::de::DeserializeOwned;
24use serde::Deserialize; 24use serde::Deserialize;
25use serde::Serialize; 25use serde::Serialize;
26use tokio::io::AsyncWriteExt;
26use tower::Layer; 27use tower::Layer;
27 28
28use axum::{ 29use axum::{
@@ -101,6 +102,7 @@ struct AppState {
101 authz_conf: AuthorizationConfig, 102 authz_conf: AuthorizationConfig,
102 // Should not end with a slash. 103 // Should not end with a slash.
103 base_url: String, 104 base_url: String,
105 dl_limiter: Arc<tokio::sync::Mutex<DownloadLimiter>>,
104} 106}
105 107
106struct Env { 108struct Env {
@@ -113,6 +115,7 @@ struct Env {
113 key_path: String, 115 key_path: String,
114 listen_host: String, 116 listen_host: String,
115 listen_port: String, 117 listen_port: String,
118 download_limit: String,
116 trusted_forwarded_hosts: String, 119 trusted_forwarded_hosts: String,
117} 120}
118 121
@@ -133,6 +136,7 @@ impl Env {
133 key_path: require_env("GITOLFS3_KEY_PATH")?, 136 key_path: require_env("GITOLFS3_KEY_PATH")?,
134 listen_host: require_env("GITOLFS3_LISTEN_HOST")?, 137 listen_host: require_env("GITOLFS3_LISTEN_HOST")?,
135 listen_port: require_env("GITOLFS3_LISTEN_PORT")?, 138 listen_port: require_env("GITOLFS3_LISTEN_PORT")?,
139 download_limit: require_env("GITOLFS3_DOWNLOAD_LIMIT")?,
136 trusted_forwarded_hosts: std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS") 140 trusted_forwarded_hosts: std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS")
137 .unwrap_or_default(), 141 .unwrap_or_default(),
138 }) 142 })
@@ -196,6 +200,23 @@ async fn main() -> ExitCode {
196 .collect(); 200 .collect();
197 let base_url = env.base_url.trim_end_matches('/').to_string(); 201 let base_url = env.base_url.trim_end_matches('/').to_string();
198 202
203 let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else {
204 println!("Configured GITOLFS3_DOWNLOAD_LIMIT should be a 64-bit unsigned integer");
205 return ExitCode::from(2);
206 };
207 let dl_limiter = DownloadLimiter::new(download_limit).await;
208 let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter));
209
210 let resetter_dl_limiter = dl_limiter.clone();
211 tokio::spawn(async move {
212 loop {
213 println!("Resetting download counter in one hour");
214 tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
215 println!("Resetting download counter");
216 resetter_dl_limiter.lock().await.reset().await;
217 }
218 });
219
199 let authz_conf = AuthorizationConfig { 220 let authz_conf = AuthorizationConfig {
200 key, 221 key,
201 trusted_forwarded_hosts, 222 trusted_forwarded_hosts,
@@ -206,6 +227,7 @@ async fn main() -> ExitCode {
206 s3_bucket: env.s3_bucket, 227 s3_bucket: env.s3_bucket,
207 authz_conf, 228 authz_conf,
208 base_url, 229 base_url,
230 dl_limiter,
209 }); 231 });
210 let app = Router::new() 232 let app = Router::new()
211 .route("/batch", post(batch)) 233 .route("/batch", post(batch))
@@ -216,7 +238,9 @@ async fn main() -> ExitCode {
216 let app_with_middleware = middleware.layer(app); 238 let app_with_middleware = middleware.layer(app);
217 239
218 let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else { 240 let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else {
219 println!("Configured LISTEN_PORT should be an unsigned integer no higher than 65535"); 241 println!(
242 "Configured GITOLFS3_LISTEN_PORT should be an unsigned integer no higher than 65535"
243 );
220 return ExitCode::from(2); 244 return ExitCode::from(2);
221 }; 245 };
222 let addr: (String, u16) = (env.listen_host, listen_port); 246 let addr: (String, u16) = (env.listen_host, listen_port);
@@ -617,6 +641,35 @@ async fn handle_download_object(
617 }; 641 };
618 } 642 }
619 643
644 if let Some(content_length) = result.content_length() {
645 if content_length < 0 {
646 match state
647 .dl_limiter
648 .lock()
649 .await
650 .request(content_length as u64)
651 .await
652 {
653 Ok(true) => {}
654 Ok(false) => {
655 return BatchResponseObject::error(
656 obj,
657 StatusCode::SERVICE_UNAVAILABLE,
658 "Public LFS downloads temporarily unavailable".to_string(),
659 );
660 }
661 Err(e) => {
662 println!("Failed to request {content_length} bytes from download limiter: {e}");
663 return BatchResponseObject::error(
664 obj,
665 StatusCode::INTERNAL_SERVER_ERROR,
666 "Internal server error".to_string(),
667 );
668 }
669 }
670 }
671 }
672
620 let Some(tag) = common::generate_tag( 673 let Some(tag) = common::generate_tag(
621 common::Claims { 674 common::Claims {
622 specific_claims: common::SpecificClaims::Download(obj.oid), 675 specific_claims: common::SpecificClaims::Download(obj.oid),
@@ -956,6 +1009,56 @@ async fn obj_download(
956 (headers, body).into_response() 1009 (headers, body).into_response()
957} 1010}
958 1011
1012struct DownloadLimiter {
1013 current: u64,
1014 limit: u64,
1015}
1016
1017impl DownloadLimiter {
1018 async fn new(limit: u64) -> DownloadLimiter {
1019 let dlimit_str = match tokio::fs::read_to_string(".gitolfs3-dlimit").await {
1020 Ok(dlimit_str) => dlimit_str,
1021 Err(e) => {
1022 println!("Failed to read download counter, assuming 0: {e}");
1023 return DownloadLimiter { current: 0, limit };
1024 }
1025 };
1026 let current: u64 = match dlimit_str.parse().map_err(tokio::io::Error::other) {
1027 Ok(current) => current,
1028 Err(e) => {
1029 println!("Failed to read download counter, assuming 0: {e}");
1030 return DownloadLimiter { current: 0, limit };
1031 }
1032 };
1033 DownloadLimiter { current, limit }
1034 }
1035
1036 async fn request(&mut self, n: u64) -> tokio::io::Result<bool> {
1037 if self.current + n > self.limit {
1038 return Ok(false);
1039 }
1040 self.current += n;
1041 self.write_new_limit().await?;
1042 Ok(true)
1043 }
1044
1045 async fn reset(&mut self) {
1046 self.current = 0;
1047 if let Err(e) = self.write_new_limit().await {
1048 println!("Failed to reset download counter: {e}");
1049 }
1050 }
1051
1052 async fn write_new_limit(&self) -> tokio::io::Result<()> {
1053 let cwd = tokio::fs::File::open(std::env::current_dir()?).await?;
1054 let mut file = tokio::fs::File::create(".gitolfs3-dlimit.tmp").await?;
1055 file.write_all(self.limit.to_string().as_bytes()).await?;
1056 file.sync_all().await?;
1057 tokio::fs::rename(".gitolfs3-dlimit.tmp", ".gitolfs3-dlimit").await?;
1058 cwd.sync_all().await
1059 }
1060}
1061
959#[test] 1062#[test]
960fn test_mimetype() { 1063fn test_mimetype() {
961 assert!(is_git_lfs_json_mimetype("application/vnd.git-lfs+json")); 1064 assert!(is_git_lfs_json_mimetype("application/vnd.git-lfs+json"));