diff options
Diffstat (limited to 'rs/server')
-rw-r--r-- | rs/server/Cargo.toml | 1 | ||||
-rw-r--r-- | rs/server/src/main.rs | 279 |
2 files changed, 237 insertions, 43 deletions
diff --git a/rs/server/Cargo.toml b/rs/server/Cargo.toml index 188299a..9a2a9a9 100644 --- a/rs/server/Cargo.toml +++ b/rs/server/Cargo.toml | |||
@@ -13,6 +13,7 @@ chrono = { version = "0.4", features = ["serde"] } | |||
13 | common = { path = "../common" } | 13 | common = { path = "../common" } |
14 | mime = "0.3" | 14 | mime = "0.3" |
15 | serde = { version = "1", features = ["derive"] } | 15 | serde = { version = "1", features = ["derive"] } |
16 | serde_json = "1" | ||
16 | tokio = { version = "1.35", features = ["full"] } | 17 | tokio = { version = "1.35", features = ["full"] } |
17 | tower = "0.4" | 18 | tower = "0.4" |
18 | tower-service = "0.3" | 19 | tower-service = "0.3" |
diff --git a/rs/server/src/main.rs b/rs/server/src/main.rs index 7d8b83a..aebeb88 100644 --- a/rs/server/src/main.rs +++ b/rs/server/src/main.rs | |||
@@ -1,17 +1,21 @@ | |||
1 | use std::collections::HashMap; | 1 | use std::collections::HashMap; |
2 | use std::collections::HashSet; | ||
2 | use std::sync::Arc; | 3 | use std::sync::Arc; |
3 | 4 | ||
4 | use axum::extract::State; | 5 | use aws_sdk_s3::operation::head_object::HeadObjectOutput; |
5 | use axum::extract::rejection; | 6 | use axum::extract::rejection; |
6 | use axum::extract::FromRequest; | 7 | use axum::extract::FromRequest; |
7 | use axum::extract::Path; | 8 | use axum::extract::Path; |
9 | use axum::extract::State; | ||
8 | use axum::http::header; | 10 | use axum::http::header; |
9 | use axum::http::HeaderMap; | 11 | use axum::http::HeaderMap; |
10 | use axum::http::HeaderValue; | 12 | use axum::http::HeaderValue; |
11 | use axum::response::Response; | 13 | use axum::response::Response; |
12 | use axum::Json; | 14 | use axum::Json; |
15 | use axum::ServiceExt; | ||
13 | use base64::prelude::*; | 16 | use base64::prelude::*; |
14 | use chrono::DateTime; | 17 | use chrono::DateTime; |
18 | use chrono::Duration; | ||
15 | use chrono::Utc; | 19 | use chrono::Utc; |
16 | use common::HexByte; | 20 | use common::HexByte; |
17 | use serde::de; | 21 | use serde::de; |
@@ -19,7 +23,6 @@ use serde::de::DeserializeOwned; | |||
19 | use serde::Deserialize; | 23 | use serde::Deserialize; |
20 | use serde::Serialize; | 24 | use serde::Serialize; |
21 | use tower::Layer; | 25 | use tower::Layer; |
22 | use axum::ServiceExt; | ||
23 | 26 | ||
24 | use axum::{ | 27 | use axum::{ |
25 | async_trait, | 28 | async_trait, |
@@ -30,6 +33,8 @@ use axum::{ | |||
30 | Extension, Router, | 33 | Extension, Router, |
31 | }; | 34 | }; |
32 | 35 | ||
36 | use serde_json::json; | ||
37 | |||
33 | #[derive(Clone)] | 38 | #[derive(Clone)] |
34 | struct RepositoryName(String); | 39 | struct RepositoryName(String); |
35 | 40 | ||
@@ -54,7 +59,9 @@ impl<S: Send + Sync> FromRequestParts<S> for RepositoryName { | |||
54 | } | 59 | } |
55 | } | 60 | } |
56 | 61 | ||
57 | async fn rewrite_url<B>(mut req: axum::http::Request<B>) -> Result<axum::http::Request<B>, StatusCode> { | 62 | async fn rewrite_url<B>( |
63 | mut req: axum::http::Request<B>, | ||
64 | ) -> Result<axum::http::Request<B>, StatusCode> { | ||
58 | let uri = req.uri(); | 65 | let uri = req.uri(); |
59 | let original_uri = OriginalUri(uri.clone()); | 66 | let original_uri = OriginalUri(uri.clone()); |
60 | 67 | ||
@@ -67,7 +74,7 @@ async fn rewrite_url<B>(mut req: axum::http::Request<B>) -> Result<axum::http::R | |||
67 | .trim_end_matches('/') | 74 | .trim_end_matches('/') |
68 | .to_string(); | 75 | .to_string(); |
69 | if !path.starts_with("/") || !repo.ends_with(".git") { | 76 | if !path.starts_with("/") || !repo.ends_with(".git") { |
70 | return Err(StatusCode::NOT_FOUND); | 77 | return Err(StatusCode::NOT_FOUND); |
71 | } | 78 | } |
72 | 79 | ||
73 | let mut parts = uri.clone().into_parts(); | 80 | let mut parts = uri.clone().into_parts(); |
@@ -87,16 +94,25 @@ async fn rewrite_url<B>(mut req: axum::http::Request<B>) -> Result<axum::http::R | |||
87 | struct AppState { | 94 | struct AppState { |
88 | s3_client: aws_sdk_s3::Client, | 95 | s3_client: aws_sdk_s3::Client, |
89 | s3_bucket: String, | 96 | s3_bucket: String, |
97 | authz_conf: AuthorizationConfig, | ||
90 | } | 98 | } |
91 | 99 | ||
92 | fn get_s3_client() -> aws_sdk_s3::Client { | 100 | fn get_s3_client() -> aws_sdk_s3::Client { |
93 | let access_key_id = std::env::var("S3_ACCESS_KEY_ID").unwrap(); | 101 | let access_key_id = std::env::var("S3_ACCESS_KEY_ID").unwrap(); |
94 | let secret_access_key = std::env::var("S3_SECRET_ACCESS_KEY").unwrap(); | 102 | let secret_access_key = std::env::var("S3_SECRET_ACCESS_KEY").unwrap(); |
95 | 103 | ||
96 | let credentials = aws_sdk_s3::config::Credentials::new(access_key_id, secret_access_key, None, None, "gitolfs3-env"); | 104 | let credentials = aws_sdk_s3::config::Credentials::new( |
105 | access_key_id, | ||
106 | secret_access_key, | ||
107 | None, | ||
108 | None, | ||
109 | "gitolfs3-env", | ||
110 | ); | ||
97 | let config = aws_config::SdkConfig::builder() | 111 | let config = aws_config::SdkConfig::builder() |
98 | .endpoint_url(std::env::var("S3_ENDPOINT").unwrap()) | 112 | .endpoint_url(std::env::var("S3_ENDPOINT").unwrap()) |
99 | .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new(credentials)) | 113 | .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new( |
114 | credentials, | ||
115 | )) | ||
100 | .build(); | 116 | .build(); |
101 | aws_sdk_s3::Client::new(&config) | 117 | aws_sdk_s3::Client::new(&config) |
102 | } | 118 | } |
@@ -106,9 +122,26 @@ async fn main() { | |||
106 | // run our app with hyper, listening globally on port 3000 | 122 | // run our app with hyper, listening globally on port 3000 |
107 | let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); | 123 | let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); |
108 | 124 | ||
125 | let key_path = std::env::var("GITOLFS3_KEY_PATH").unwrap(); | ||
126 | let key = common::load_key(&key_path).unwrap(); | ||
127 | let trusted_forwarded_hosts = std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS").unwrap(); | ||
128 | let trusted_forwarded_hosts: HashSet<String> = trusted_forwarded_hosts | ||
129 | .split(',') | ||
130 | .map(|s| s.to_owned()) | ||
131 | .collect(); | ||
132 | |||
133 | let authz_conf = AuthorizationConfig { | ||
134 | key, | ||
135 | trusted_forwarded_hosts, | ||
136 | }; | ||
137 | |||
109 | let s3_client = get_s3_client(); | 138 | let s3_client = get_s3_client(); |
110 | let s3_bucket = std::env::var("S3_BUCKET").unwrap(); | 139 | let s3_bucket = std::env::var("S3_BUCKET").unwrap(); |
111 | let shared_state = Arc::new(AppState { s3_client, s3_bucket }); | 140 | let shared_state = Arc::new(AppState { |
141 | s3_client, | ||
142 | s3_bucket, | ||
143 | authz_conf, | ||
144 | }); | ||
112 | let app = Router::new() | 145 | let app = Router::new() |
113 | .route("/batch", post(batch)) | 146 | .route("/batch", post(batch)) |
114 | .route("/:oid0/:oid1/:oid", get(obj_download)) | 147 | .route("/:oid0/:oid1/:oid", get(obj_download)) |
@@ -119,8 +152,8 @@ async fn main() { | |||
119 | let app_with_middleware = middleware.layer(app); | 152 | let app_with_middleware = middleware.layer(app); |
120 | 153 | ||
121 | axum::serve(listener, app_with_middleware.into_make_service()) | 154 | axum::serve(listener, app_with_middleware.into_make_service()) |
122 | .await | 155 | .await |
123 | .unwrap(); | 156 | .unwrap(); |
124 | } | 157 | } |
125 | 158 | ||
126 | #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)] | 159 | #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)] |
@@ -182,11 +215,11 @@ enum GitLfsJsonRejection { | |||
182 | 215 | ||
183 | impl IntoResponse for GitLfsJsonRejection { | 216 | impl IntoResponse for GitLfsJsonRejection { |
184 | fn into_response(self) -> Response { | 217 | fn into_response(self) -> Response { |
185 | ( | 218 | make_error_resp( |
186 | StatusCode::UNSUPPORTED_MEDIA_TYPE, | 219 | StatusCode::UNSUPPORTED_MEDIA_TYPE, |
187 | format!("Expected request with `Content-Type: {LFS_MIME}`"), | 220 | &format!("Expected request with `Content-Type: {LFS_MIME}`"), |
188 | ) | 221 | ) |
189 | .into_response() | 222 | .into_response() |
190 | } | 223 | } |
191 | } | 224 | } |
192 | 225 | ||
@@ -241,12 +274,23 @@ impl<T: Serialize> IntoResponse for GitLfsJson<T> { | |||
241 | let mut resp = json.into_response(); | 274 | let mut resp = json.into_response(); |
242 | resp.headers_mut().insert( | 275 | resp.headers_mut().insert( |
243 | header::CONTENT_TYPE, | 276 | header::CONTENT_TYPE, |
244 | HeaderValue::from_static("application/vnd.git-lfs+json"), | 277 | HeaderValue::from_static("application/vnd.git-lfs+json; charset=utf-8"), |
245 | ); | 278 | ); |
246 | resp | 279 | resp |
247 | } | 280 | } |
248 | } | 281 | } |
249 | 282 | ||
283 | #[derive(Serialize)] | ||
284 | struct GitLfsErrorData<'a> { | ||
285 | message: &'a str, | ||
286 | } | ||
287 | |||
288 | type GitLfsErrorResponse<'a> = (StatusCode, GitLfsJson<GitLfsErrorData<'a>>); | ||
289 | |||
290 | const fn make_error_resp<'a>(code: StatusCode, message: &'a str) -> GitLfsErrorResponse { | ||
291 | (code, GitLfsJson(Json(GitLfsErrorData { message }))) | ||
292 | } | ||
293 | |||
250 | #[derive(Debug, Serialize, Clone)] | 294 | #[derive(Debug, Serialize, Clone)] |
251 | struct BatchResponseObjectAction { | 295 | struct BatchResponseObjectAction { |
252 | href: String, | 296 | href: String, |
@@ -280,66 +324,215 @@ struct BatchResponse { | |||
280 | hash_algo: HashAlgo, | 324 | hash_algo: HashAlgo, |
281 | } | 325 | } |
282 | 326 | ||
327 | fn validate_checksum(oid: Oid, obj: &HeadObjectOutput) -> bool { | ||
328 | if let Some(checksum) = obj.checksum_sha256() { | ||
329 | if let Ok(checksum) = BASE64_STANDARD.decode(checksum) { | ||
330 | if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) { | ||
331 | return Oid::from(checksum32b) == oid; | ||
332 | } | ||
333 | } | ||
334 | } | ||
335 | true | ||
336 | } | ||
337 | |||
338 | fn validate_size(expected: i64, obj: &HeadObjectOutput) -> bool { | ||
339 | if let Some(length) = obj.content_length() { | ||
340 | return length == expected; | ||
341 | } | ||
342 | true | ||
343 | } | ||
344 | |||
283 | async fn handle_download_object(state: &AppState, repo: &str, obj: &BatchRequestObject) { | 345 | async fn handle_download_object(state: &AppState, repo: &str, obj: &BatchRequestObject) { |
284 | let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); | 346 | let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); |
285 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); | 347 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); |
286 | 348 | ||
287 | let result = state.s3_client.head_object(). | 349 | let result = state |
288 | bucket(&state.s3_bucket). | 350 | .s3_client |
289 | key(full_path). | 351 | .head_object() |
290 | checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled). | 352 | .bucket(&state.s3_bucket) |
291 | send().await.unwrap(); | 353 | .key(full_path) |
292 | if let Some(checksum) = result.checksum_sha256() { | 354 | .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled) |
293 | if let Ok(checksum) = BASE64_STANDARD.decode(checksum) { | 355 | .send() |
294 | if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) { | 356 | .await |
295 | if Oid::from(checksum32b) != obj.oid { | 357 | .unwrap(); |
296 | unreachable!(); | 358 | // Scaleway actually doesn't provide SHA256 suport, but maybe in the future :) |
359 | if !validate_checksum(obj.oid, &result) { | ||
360 | unreachable!(); | ||
361 | } | ||
362 | if !validate_size(obj.size, &result) { | ||
363 | unreachable!(); | ||
364 | } | ||
365 | |||
366 | let expires_at = Utc::now() + Duration::seconds(5 * 60); | ||
367 | } | ||
368 | |||
369 | struct AuthorizationConfig { | ||
370 | trusted_forwarded_hosts: HashSet<String>, | ||
371 | key: common::Key, | ||
372 | } | ||
373 | |||
374 | struct Trusted(bool); | ||
375 | |||
376 | fn forwarded_for_trusted_host( | ||
377 | headers: &HeaderMap, | ||
378 | trusted: &HashSet<String>, | ||
379 | ) -> Result<bool, GitLfsErrorResponse<'static>> { | ||
380 | if let Some(forwarded_for) = headers.get("X-Forwarded-For") { | ||
381 | if let Ok(forwarded_for) = forwarded_for.to_str() { | ||
382 | if trusted.contains(forwarded_for) { | ||
383 | return Ok(true); | ||
384 | } | ||
385 | } else { | ||
386 | return Err(make_error_resp( | ||
387 | StatusCode::NOT_FOUND, | ||
388 | "Invalid X-Forwarded-For header", | ||
389 | )); | ||
390 | } | ||
391 | } | ||
392 | return Ok(false); | ||
393 | } | ||
394 | const REPO_NOT_FOUND: GitLfsErrorResponse = | ||
395 | make_error_resp(StatusCode::NOT_FOUND, "Repository not found"); | ||
396 | |||
397 | fn authorize( | ||
398 | conf: &AuthorizationConfig, | ||
399 | headers: &HeaderMap, | ||
400 | repo_path: &str, | ||
401 | public: bool, | ||
402 | operation: common::Operation, | ||
403 | ) -> Result<Trusted, GitLfsErrorResponse<'static>> { | ||
404 | // - No authentication required for downloading exported repos | ||
405 | // - When authenticated: | ||
406 | // - Download / upload over presigned URLs | ||
407 | // - When accessing over Tailscale: | ||
408 | // - No authentication required for downloading from any repo | ||
409 | |||
410 | const INVALID_AUTHZ_HEADER: GitLfsErrorResponse = | ||
411 | make_error_resp(StatusCode::BAD_REQUEST, "Invalid authorization header"); | ||
412 | |||
413 | if let Some(authz) = headers.get(header::AUTHORIZATION) { | ||
414 | if let Ok(authz) = authz.to_str() { | ||
415 | if let Some(val) = authz.strip_prefix("Gitolfs3-Hmac-Sha256 ") { | ||
416 | let Some((tag, expires_at)) = val.split_once(' ') else { | ||
417 | return Err(INVALID_AUTHZ_HEADER); | ||
418 | }; | ||
419 | let Ok(tag): Result<common::Digest<32>, _> = tag.parse() else { | ||
420 | return Err(INVALID_AUTHZ_HEADER); | ||
421 | }; | ||
422 | let Ok(expires_at): Result<i64, _> = expires_at.parse() else { | ||
423 | return Err(INVALID_AUTHZ_HEADER); | ||
424 | }; | ||
425 | let Some(expires_at) = DateTime::<Utc>::from_timestamp(expires_at, 0) else { | ||
426 | return Err(INVALID_AUTHZ_HEADER); | ||
427 | }; | ||
428 | let Some(expected_tag) = common::generate_tag( | ||
429 | common::Claims { | ||
430 | auth_type: common::AuthType::GitLfsAuthenticate, | ||
431 | repo_path, | ||
432 | expires_at, | ||
433 | operation, | ||
434 | }, | ||
435 | &conf.key, | ||
436 | ) else { | ||
437 | return Err(INVALID_AUTHZ_HEADER); | ||
438 | }; | ||
439 | if tag == expected_tag { | ||
440 | return Ok(Trusted(true)); | ||
441 | } else { | ||
442 | return Err(INVALID_AUTHZ_HEADER); | ||
297 | } | 443 | } |
444 | } else { | ||
445 | return Err(INVALID_AUTHZ_HEADER); | ||
298 | } | 446 | } |
447 | } else { | ||
448 | return Err(INVALID_AUTHZ_HEADER); | ||
299 | } | 449 | } |
300 | } | 450 | } |
451 | |||
452 | let trusted = forwarded_for_trusted_host(headers, &conf.trusted_forwarded_hosts)?; | ||
453 | if operation != common::Operation::Download { | ||
454 | if trusted { | ||
455 | return Err(make_error_resp( | ||
456 | StatusCode::FORBIDDEN, | ||
457 | "Authentication required to upload", | ||
458 | )); | ||
459 | } | ||
460 | return Err(REPO_NOT_FOUND); | ||
461 | } | ||
462 | if trusted { | ||
463 | return Ok(Trusted(true)); | ||
464 | } | ||
465 | |||
466 | if public { | ||
467 | Ok(Trusted(false)) | ||
468 | } else { | ||
469 | Err(REPO_NOT_FOUND) | ||
470 | } | ||
471 | } | ||
472 | |||
473 | fn repo_exists(name: &str) -> bool { | ||
474 | let Ok(metadata) = std::fs::metadata(name) else { | ||
475 | return false; | ||
476 | }; | ||
477 | return metadata.is_dir(); | ||
478 | } | ||
479 | |||
480 | fn is_repo_public(name: &str) -> Option<bool> { | ||
481 | if !repo_exists(name) { | ||
482 | return None; | ||
483 | } | ||
484 | std::fs::metadata(format!("{name}/git-daemon-export-ok")) | ||
485 | .ok()? | ||
486 | .is_file() | ||
487 | .into() | ||
301 | } | 488 | } |
302 | 489 | ||
303 | async fn batch( | 490 | async fn batch( |
304 | State(state): State<Arc<AppState>>, | 491 | State(state): State<Arc<AppState>>, |
305 | header: HeaderMap, | 492 | headers: HeaderMap, |
306 | RepositoryName(repo): RepositoryName, | 493 | RepositoryName(repo): RepositoryName, |
307 | GitLfsJson(Json(payload)): GitLfsJson<BatchRequest>, | 494 | GitLfsJson(Json(payload)): GitLfsJson<BatchRequest>, |
308 | ) -> Response { | 495 | ) -> Response { |
309 | if !header | 496 | let Some(public) = is_repo_public(&repo) else { |
497 | return REPO_NOT_FOUND.into_response(); | ||
498 | }; | ||
499 | let authn = match authorize( | ||
500 | &state.authz_conf, | ||
501 | &headers, | ||
502 | &repo, | ||
503 | public, | ||
504 | payload.operation, | ||
505 | ) { | ||
506 | Ok(authn) => authn, | ||
507 | Err(e) => return e.into_response(), | ||
508 | }; | ||
509 | |||
510 | if !headers | ||
310 | .get_all("Accept") | 511 | .get_all("Accept") |
311 | .iter() | 512 | .iter() |
312 | .filter_map(|v| v.to_str().ok()) | 513 | .filter_map(|v| v.to_str().ok()) |
313 | .any(is_git_lfs_json_mimetype) | 514 | .any(is_git_lfs_json_mimetype) |
314 | { | 515 | { |
315 | return ( | 516 | let message = format!("Expected `{LFS_MIME}` in list of acceptable response media types"); |
316 | StatusCode::NOT_ACCEPTABLE, | 517 | return make_error_resp(StatusCode::NOT_ACCEPTABLE, &message).into_response(); |
317 | format!("Expected `{LFS_MIME}` (with UTF-8 charset) in list of acceptable response media types"), | ||
318 | ).into_response(); | ||
319 | } | 518 | } |
320 | 519 | ||
321 | if payload.hash_algo != HashAlgo::Sha256 { | 520 | if payload.hash_algo != HashAlgo::Sha256 { |
322 | return ( | 521 | let message = "Unsupported hashing algorithm specified"; |
323 | StatusCode::CONFLICT, | 522 | return make_error_resp(StatusCode::CONFLICT, message).into_response(); |
324 | "Unsupported hashing algorithm speicifed", | ||
325 | ) | ||
326 | .into_response(); | ||
327 | } | 523 | } |
328 | if !payload.transfers.is_empty() && !payload.transfers.contains(&TransferAdapter::Basic) { | 524 | if !payload.transfers.is_empty() && !payload.transfers.contains(&TransferAdapter::Basic) { |
329 | return ( | 525 | let message = "Unsupported transfer adapter specified (supported: basic)"; |
330 | StatusCode::CONFLICT, | 526 | return make_error_resp(StatusCode::CONFLICT, message).into_response(); |
331 | "Unsupported transfer adapter specified (supported: basic)", | ||
332 | ) | ||
333 | .into_response(); | ||
334 | } | 527 | } |
335 | 528 | ||
336 | let resp: BatchResponse; | 529 | let resp: BatchResponse; |
337 | for obj in payload.objects { | 530 | for obj in payload.objects { |
338 | handle_download_object(&state, &repo, &obj).await; | 531 | handle_download_object(&state, &repo, &obj).await; |
339 | // match payload.operation { | 532 | // match payload.operation { |
340 | // Operation::Download => resp.objects.push(handle_download_object(repo, obj));, | 533 | // Operation::Download => resp.objects.push(handle_download_object(repo, obj));, |
341 | // Operation::Upload => resp.objects.push(handle_upload_object(repo, obj)), | 534 | // Operation::Upload => resp.objects.push(handle_upload_object(repo, obj)), |
342 | // }; | 535 | // }; |
343 | } | 536 | } |
344 | 537 | ||
345 | format!("hi from {repo}\n").into_response() | 538 | format!("hi from {repo}\n").into_response() |