aboutsummaryrefslogtreecommitdiffstats
path: root/rs/server/src
diff options
context:
space:
mode:
Diffstat (limited to 'rs/server/src')
-rw-r--r--rs/server/src/main.rs279
1 files changed, 236 insertions, 43 deletions
diff --git a/rs/server/src/main.rs b/rs/server/src/main.rs
index 7d8b83a..aebeb88 100644
--- a/rs/server/src/main.rs
+++ b/rs/server/src/main.rs
@@ -1,17 +1,21 @@
1use std::collections::HashMap; 1use std::collections::HashMap;
2use std::collections::HashSet;
2use std::sync::Arc; 3use std::sync::Arc;
3 4
4use axum::extract::State; 5use aws_sdk_s3::operation::head_object::HeadObjectOutput;
5use axum::extract::rejection; 6use axum::extract::rejection;
6use axum::extract::FromRequest; 7use axum::extract::FromRequest;
7use axum::extract::Path; 8use axum::extract::Path;
9use axum::extract::State;
8use axum::http::header; 10use axum::http::header;
9use axum::http::HeaderMap; 11use axum::http::HeaderMap;
10use axum::http::HeaderValue; 12use axum::http::HeaderValue;
11use axum::response::Response; 13use axum::response::Response;
12use axum::Json; 14use axum::Json;
15use axum::ServiceExt;
13use base64::prelude::*; 16use base64::prelude::*;
14use chrono::DateTime; 17use chrono::DateTime;
18use chrono::Duration;
15use chrono::Utc; 19use chrono::Utc;
16use common::HexByte; 20use common::HexByte;
17use serde::de; 21use serde::de;
@@ -19,7 +23,6 @@ use serde::de::DeserializeOwned;
19use serde::Deserialize; 23use serde::Deserialize;
20use serde::Serialize; 24use serde::Serialize;
21use tower::Layer; 25use tower::Layer;
22use axum::ServiceExt;
23 26
24use axum::{ 27use axum::{
25 async_trait, 28 async_trait,
@@ -30,6 +33,8 @@ use axum::{
30 Extension, Router, 33 Extension, Router,
31}; 34};
32 35
36use serde_json::json;
37
33#[derive(Clone)] 38#[derive(Clone)]
34struct RepositoryName(String); 39struct RepositoryName(String);
35 40
@@ -54,7 +59,9 @@ impl<S: Send + Sync> FromRequestParts<S> for RepositoryName {
54 } 59 }
55} 60}
56 61
57async fn rewrite_url<B>(mut req: axum::http::Request<B>) -> Result<axum::http::Request<B>, StatusCode> { 62async fn rewrite_url<B>(
63 mut req: axum::http::Request<B>,
64) -> Result<axum::http::Request<B>, StatusCode> {
58 let uri = req.uri(); 65 let uri = req.uri();
59 let original_uri = OriginalUri(uri.clone()); 66 let original_uri = OriginalUri(uri.clone());
60 67
@@ -67,7 +74,7 @@ async fn rewrite_url<B>(mut req: axum::http::Request<B>) -> Result<axum::http::R
67 .trim_end_matches('/') 74 .trim_end_matches('/')
68 .to_string(); 75 .to_string();
69 if !path.starts_with("/") || !repo.ends_with(".git") { 76 if !path.starts_with("/") || !repo.ends_with(".git") {
70 return Err(StatusCode::NOT_FOUND); 77 return Err(StatusCode::NOT_FOUND);
71 } 78 }
72 79
73 let mut parts = uri.clone().into_parts(); 80 let mut parts = uri.clone().into_parts();
@@ -87,16 +94,25 @@ async fn rewrite_url<B>(mut req: axum::http::Request<B>) -> Result<axum::http::R
87struct AppState { 94struct AppState {
88 s3_client: aws_sdk_s3::Client, 95 s3_client: aws_sdk_s3::Client,
89 s3_bucket: String, 96 s3_bucket: String,
97 authz_conf: AuthorizationConfig,
90} 98}
91 99
92fn get_s3_client() -> aws_sdk_s3::Client { 100fn get_s3_client() -> aws_sdk_s3::Client {
93 let access_key_id = std::env::var("S3_ACCESS_KEY_ID").unwrap(); 101 let access_key_id = std::env::var("S3_ACCESS_KEY_ID").unwrap();
94 let secret_access_key = std::env::var("S3_SECRET_ACCESS_KEY").unwrap(); 102 let secret_access_key = std::env::var("S3_SECRET_ACCESS_KEY").unwrap();
95 103
96 let credentials = aws_sdk_s3::config::Credentials::new(access_key_id, secret_access_key, None, None, "gitolfs3-env"); 104 let credentials = aws_sdk_s3::config::Credentials::new(
105 access_key_id,
106 secret_access_key,
107 None,
108 None,
109 "gitolfs3-env",
110 );
97 let config = aws_config::SdkConfig::builder() 111 let config = aws_config::SdkConfig::builder()
98 .endpoint_url(std::env::var("S3_ENDPOINT").unwrap()) 112 .endpoint_url(std::env::var("S3_ENDPOINT").unwrap())
99 .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new(credentials)) 113 .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new(
114 credentials,
115 ))
100 .build(); 116 .build();
101 aws_sdk_s3::Client::new(&config) 117 aws_sdk_s3::Client::new(&config)
102} 118}
@@ -106,9 +122,26 @@ async fn main() {
106 // run our app with hyper, listening globally on port 3000 122 // run our app with hyper, listening globally on port 3000
107 let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); 123 let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
108 124
125 let key_path = std::env::var("GITOLFS3_KEY_PATH").unwrap();
126 let key = common::load_key(&key_path).unwrap();
127 let trusted_forwarded_hosts = std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS").unwrap();
128 let trusted_forwarded_hosts: HashSet<String> = trusted_forwarded_hosts
129 .split(',')
130 .map(|s| s.to_owned())
131 .collect();
132
133 let authz_conf = AuthorizationConfig {
134 key,
135 trusted_forwarded_hosts,
136 };
137
109 let s3_client = get_s3_client(); 138 let s3_client = get_s3_client();
110 let s3_bucket = std::env::var("S3_BUCKET").unwrap(); 139 let s3_bucket = std::env::var("S3_BUCKET").unwrap();
111 let shared_state = Arc::new(AppState { s3_client, s3_bucket }); 140 let shared_state = Arc::new(AppState {
141 s3_client,
142 s3_bucket,
143 authz_conf,
144 });
112 let app = Router::new() 145 let app = Router::new()
113 .route("/batch", post(batch)) 146 .route("/batch", post(batch))
114 .route("/:oid0/:oid1/:oid", get(obj_download)) 147 .route("/:oid0/:oid1/:oid", get(obj_download))
@@ -119,8 +152,8 @@ async fn main() {
119 let app_with_middleware = middleware.layer(app); 152 let app_with_middleware = middleware.layer(app);
120 153
121 axum::serve(listener, app_with_middleware.into_make_service()) 154 axum::serve(listener, app_with_middleware.into_make_service())
122 .await 155 .await
123 .unwrap(); 156 .unwrap();
124} 157}
125 158
126#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)] 159#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
@@ -182,11 +215,11 @@ enum GitLfsJsonRejection {
182 215
183impl IntoResponse for GitLfsJsonRejection { 216impl IntoResponse for GitLfsJsonRejection {
184 fn into_response(self) -> Response { 217 fn into_response(self) -> Response {
185 ( 218 make_error_resp(
186 StatusCode::UNSUPPORTED_MEDIA_TYPE, 219 StatusCode::UNSUPPORTED_MEDIA_TYPE,
187 format!("Expected request with `Content-Type: {LFS_MIME}`"), 220 &format!("Expected request with `Content-Type: {LFS_MIME}`"),
188 ) 221 )
189 .into_response() 222 .into_response()
190 } 223 }
191} 224}
192 225
@@ -241,12 +274,23 @@ impl<T: Serialize> IntoResponse for GitLfsJson<T> {
241 let mut resp = json.into_response(); 274 let mut resp = json.into_response();
242 resp.headers_mut().insert( 275 resp.headers_mut().insert(
243 header::CONTENT_TYPE, 276 header::CONTENT_TYPE,
244 HeaderValue::from_static("application/vnd.git-lfs+json"), 277 HeaderValue::from_static("application/vnd.git-lfs+json; charset=utf-8"),
245 ); 278 );
246 resp 279 resp
247 } 280 }
248} 281}
249 282
283#[derive(Serialize)]
284struct GitLfsErrorData<'a> {
285 message: &'a str,
286}
287
288type GitLfsErrorResponse<'a> = (StatusCode, GitLfsJson<GitLfsErrorData<'a>>);
289
290const fn make_error_resp<'a>(code: StatusCode, message: &'a str) -> GitLfsErrorResponse {
291 (code, GitLfsJson(Json(GitLfsErrorData { message })))
292}
293
250#[derive(Debug, Serialize, Clone)] 294#[derive(Debug, Serialize, Clone)]
251struct BatchResponseObjectAction { 295struct BatchResponseObjectAction {
252 href: String, 296 href: String,
@@ -280,66 +324,215 @@ struct BatchResponse {
280 hash_algo: HashAlgo, 324 hash_algo: HashAlgo,
281} 325}
282 326
327fn validate_checksum(oid: Oid, obj: &HeadObjectOutput) -> bool {
328 if let Some(checksum) = obj.checksum_sha256() {
329 if let Ok(checksum) = BASE64_STANDARD.decode(checksum) {
330 if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) {
331 return Oid::from(checksum32b) == oid;
332 }
333 }
334 }
335 true
336}
337
338fn validate_size(expected: i64, obj: &HeadObjectOutput) -> bool {
339 if let Some(length) = obj.content_length() {
340 return length == expected;
341 }
342 true
343}
344
283async fn handle_download_object(state: &AppState, repo: &str, obj: &BatchRequestObject) { 345async fn handle_download_object(state: &AppState, repo: &str, obj: &BatchRequestObject) {
284 let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); 346 let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
285 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); 347 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
286 348
287 let result = state.s3_client.head_object(). 349 let result = state
288 bucket(&state.s3_bucket). 350 .s3_client
289 key(full_path). 351 .head_object()
290 checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled). 352 .bucket(&state.s3_bucket)
291 send().await.unwrap(); 353 .key(full_path)
292 if let Some(checksum) = result.checksum_sha256() { 354 .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
293 if let Ok(checksum) = BASE64_STANDARD.decode(checksum) { 355 .send()
294 if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) { 356 .await
295 if Oid::from(checksum32b) != obj.oid { 357 .unwrap();
296 unreachable!(); 358 // Scaleway actually doesn't provide SHA256 suport, but maybe in the future :)
359 if !validate_checksum(obj.oid, &result) {
360 unreachable!();
361 }
362 if !validate_size(obj.size, &result) {
363 unreachable!();
364 }
365
366 let expires_at = Utc::now() + Duration::seconds(5 * 60);
367}
368
369struct AuthorizationConfig {
370 trusted_forwarded_hosts: HashSet<String>,
371 key: common::Key,
372}
373
374struct Trusted(bool);
375
376fn forwarded_for_trusted_host(
377 headers: &HeaderMap,
378 trusted: &HashSet<String>,
379) -> Result<bool, GitLfsErrorResponse<'static>> {
380 if let Some(forwarded_for) = headers.get("X-Forwarded-For") {
381 if let Ok(forwarded_for) = forwarded_for.to_str() {
382 if trusted.contains(forwarded_for) {
383 return Ok(true);
384 }
385 } else {
386 return Err(make_error_resp(
387 StatusCode::NOT_FOUND,
388 "Invalid X-Forwarded-For header",
389 ));
390 }
391 }
392 return Ok(false);
393}
394const REPO_NOT_FOUND: GitLfsErrorResponse =
395 make_error_resp(StatusCode::NOT_FOUND, "Repository not found");
396
397fn authorize(
398 conf: &AuthorizationConfig,
399 headers: &HeaderMap,
400 repo_path: &str,
401 public: bool,
402 operation: common::Operation,
403) -> Result<Trusted, GitLfsErrorResponse<'static>> {
404 // - No authentication required for downloading exported repos
405 // - When authenticated:
406 // - Download / upload over presigned URLs
407 // - When accessing over Tailscale:
408 // - No authentication required for downloading from any repo
409
410 const INVALID_AUTHZ_HEADER: GitLfsErrorResponse =
411 make_error_resp(StatusCode::BAD_REQUEST, "Invalid authorization header");
412
413 if let Some(authz) = headers.get(header::AUTHORIZATION) {
414 if let Ok(authz) = authz.to_str() {
415 if let Some(val) = authz.strip_prefix("Gitolfs3-Hmac-Sha256 ") {
416 let Some((tag, expires_at)) = val.split_once(' ') else {
417 return Err(INVALID_AUTHZ_HEADER);
418 };
419 let Ok(tag): Result<common::Digest<32>, _> = tag.parse() else {
420 return Err(INVALID_AUTHZ_HEADER);
421 };
422 let Ok(expires_at): Result<i64, _> = expires_at.parse() else {
423 return Err(INVALID_AUTHZ_HEADER);
424 };
425 let Some(expires_at) = DateTime::<Utc>::from_timestamp(expires_at, 0) else {
426 return Err(INVALID_AUTHZ_HEADER);
427 };
428 let Some(expected_tag) = common::generate_tag(
429 common::Claims {
430 auth_type: common::AuthType::GitLfsAuthenticate,
431 repo_path,
432 expires_at,
433 operation,
434 },
435 &conf.key,
436 ) else {
437 return Err(INVALID_AUTHZ_HEADER);
438 };
439 if tag == expected_tag {
440 return Ok(Trusted(true));
441 } else {
442 return Err(INVALID_AUTHZ_HEADER);
297 } 443 }
444 } else {
445 return Err(INVALID_AUTHZ_HEADER);
298 } 446 }
447 } else {
448 return Err(INVALID_AUTHZ_HEADER);
299 } 449 }
300 } 450 }
451
452 let trusted = forwarded_for_trusted_host(headers, &conf.trusted_forwarded_hosts)?;
453 if operation != common::Operation::Download {
454 if trusted {
455 return Err(make_error_resp(
456 StatusCode::FORBIDDEN,
457 "Authentication required to upload",
458 ));
459 }
460 return Err(REPO_NOT_FOUND);
461 }
462 if trusted {
463 return Ok(Trusted(true));
464 }
465
466 if public {
467 Ok(Trusted(false))
468 } else {
469 Err(REPO_NOT_FOUND)
470 }
471}
472
473fn repo_exists(name: &str) -> bool {
474 let Ok(metadata) = std::fs::metadata(name) else {
475 return false;
476 };
477 return metadata.is_dir();
478}
479
480fn is_repo_public(name: &str) -> Option<bool> {
481 if !repo_exists(name) {
482 return None;
483 }
484 std::fs::metadata(format!("{name}/git-daemon-export-ok"))
485 .ok()?
486 .is_file()
487 .into()
301} 488}
302 489
303async fn batch( 490async fn batch(
304 State(state): State<Arc<AppState>>, 491 State(state): State<Arc<AppState>>,
305 header: HeaderMap, 492 headers: HeaderMap,
306 RepositoryName(repo): RepositoryName, 493 RepositoryName(repo): RepositoryName,
307 GitLfsJson(Json(payload)): GitLfsJson<BatchRequest>, 494 GitLfsJson(Json(payload)): GitLfsJson<BatchRequest>,
308) -> Response { 495) -> Response {
309 if !header 496 let Some(public) = is_repo_public(&repo) else {
497 return REPO_NOT_FOUND.into_response();
498 };
499 let authn = match authorize(
500 &state.authz_conf,
501 &headers,
502 &repo,
503 public,
504 payload.operation,
505 ) {
506 Ok(authn) => authn,
507 Err(e) => return e.into_response(),
508 };
509
510 if !headers
310 .get_all("Accept") 511 .get_all("Accept")
311 .iter() 512 .iter()
312 .filter_map(|v| v.to_str().ok()) 513 .filter_map(|v| v.to_str().ok())
313 .any(is_git_lfs_json_mimetype) 514 .any(is_git_lfs_json_mimetype)
314 { 515 {
315 return ( 516 let message = format!("Expected `{LFS_MIME}` in list of acceptable response media types");
316 StatusCode::NOT_ACCEPTABLE, 517 return make_error_resp(StatusCode::NOT_ACCEPTABLE, &message).into_response();
317 format!("Expected `{LFS_MIME}` (with UTF-8 charset) in list of acceptable response media types"),
318 ).into_response();
319 } 518 }
320 519
321 if payload.hash_algo != HashAlgo::Sha256 { 520 if payload.hash_algo != HashAlgo::Sha256 {
322 return ( 521 let message = "Unsupported hashing algorithm specified";
323 StatusCode::CONFLICT, 522 return make_error_resp(StatusCode::CONFLICT, message).into_response();
324 "Unsupported hashing algorithm speicifed",
325 )
326 .into_response();
327 } 523 }
328 if !payload.transfers.is_empty() && !payload.transfers.contains(&TransferAdapter::Basic) { 524 if !payload.transfers.is_empty() && !payload.transfers.contains(&TransferAdapter::Basic) {
329 return ( 525 let message = "Unsupported transfer adapter specified (supported: basic)";
330 StatusCode::CONFLICT, 526 return make_error_resp(StatusCode::CONFLICT, message).into_response();
331 "Unsupported transfer adapter specified (supported: basic)",
332 )
333 .into_response();
334 } 527 }
335 528
336 let resp: BatchResponse; 529 let resp: BatchResponse;
337 for obj in payload.objects { 530 for obj in payload.objects {
338 handle_download_object(&state, &repo, &obj).await; 531 handle_download_object(&state, &repo, &obj).await;
339// match payload.operation { 532 // match payload.operation {
340// Operation::Download => resp.objects.push(handle_download_object(repo, obj));, 533 // Operation::Download => resp.objects.push(handle_download_object(repo, obj));,
341// Operation::Upload => resp.objects.push(handle_upload_object(repo, obj)), 534 // Operation::Upload => resp.objects.push(handle_upload_object(repo, obj)),
342// }; 535 // };
343 } 536 }
344 537
345 format!("hi from {repo}\n").into_response() 538 format!("hi from {repo}\n").into_response()