From 0956476fda305f48644e53305e12ae46cb67a32b Mon Sep 17 00:00:00 2001
From: Rutger Broekhoff
Date: Fri, 19 Jan 2024 15:20:19 +0100
Subject: Implement crude batch API authorization

---
 rs/server/src/main.rs | 279 ++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 236 insertions(+), 43 deletions(-)

(limited to 'rs/server/src')

diff --git a/rs/server/src/main.rs b/rs/server/src/main.rs
index 7d8b83a..aebeb88 100644
--- a/rs/server/src/main.rs
+++ b/rs/server/src/main.rs
@@ -1,17 +1,21 @@
 use std::collections::HashMap;
+use std::collections::HashSet;
 use std::sync::Arc;
 
-use axum::extract::State;
+use aws_sdk_s3::operation::head_object::HeadObjectOutput;
 use axum::extract::rejection;
 use axum::extract::FromRequest;
 use axum::extract::Path;
+use axum::extract::State;
 use axum::http::header;
 use axum::http::HeaderMap;
 use axum::http::HeaderValue;
 use axum::response::Response;
 use axum::Json;
+use axum::ServiceExt;
 use base64::prelude::*;
 use chrono::DateTime;
+use chrono::Duration;
 use chrono::Utc;
 use common::HexByte;
 use serde::de;
@@ -19,7 +23,6 @@ use serde::de::DeserializeOwned;
 use serde::Deserialize;
 use serde::Serialize;
 use tower::Layer;
-use axum::ServiceExt;
 
 use axum::{
     async_trait,
@@ -30,6 +33,8 @@ use axum::{
     Extension, Router,
 };
 
+use serde_json::json;
+
 #[derive(Clone)]
 struct RepositoryName(String);
 
@@ -54,7 +59,9 @@ impl FromRequestParts for RepositoryName {
     }
 }
 
-async fn rewrite_url(mut req: axum::http::Request) -> Result, StatusCode> {
+async fn rewrite_url(
+    mut req: axum::http::Request,
+) -> Result, StatusCode> {
     let uri = req.uri();
     let original_uri = OriginalUri(uri.clone());
 
     let Some(path_and_query) = uri.path_and_query() else {
@@ -67,7 +74,7 @@ async fn rewrite_url(mut req: axum::http::Request) -> Result(mut req: axum::http::Request) -> Result
 fn get_s3_client() -> aws_sdk_s3::Client {
     let access_key_id = std::env::var("S3_ACCESS_KEY_ID").unwrap();
     let secret_access_key = std::env::var("S3_SECRET_ACCESS_KEY").unwrap();
-    let credentials = aws_sdk_s3::config::Credentials::new(access_key_id, secret_access_key, None, None, "gitolfs3-env");
+    let credentials = aws_sdk_s3::config::Credentials::new(
+        access_key_id,
+        secret_access_key,
+        None,
+        None,
+        "gitolfs3-env",
+    );
     let config = aws_config::SdkConfig::builder()
         .endpoint_url(std::env::var("S3_ENDPOINT").unwrap())
-        .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new(credentials))
+        .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new(
+            credentials,
+        ))
         .build();
     aws_sdk_s3::Client::new(&config)
 }
@@ -106,9 +122,26 @@ async fn main() {
     // run our app with hyper, listening globally on port 3000
     let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
 
+    let key_path = std::env::var("GITOLFS3_KEY_PATH").unwrap();
+    let key = common::load_key(&key_path).unwrap();
+    let trusted_forwarded_hosts = std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS").unwrap();
+    let trusted_forwarded_hosts: HashSet<String> = trusted_forwarded_hosts
+        .split(',')
+        .map(|s| s.to_owned())
+        .collect();
+
+    let authz_conf = AuthorizationConfig {
+        key,
+        trusted_forwarded_hosts,
+    };
+
     let s3_client = get_s3_client();
     let s3_bucket = std::env::var("S3_BUCKET").unwrap();
-    let shared_state = Arc::new(AppState { s3_client, s3_bucket });
+    let shared_state = Arc::new(AppState {
+        s3_client,
+        s3_bucket,
+        authz_conf,
+    });
     let app = Router::new()
         .route("/batch", post(batch))
         .route("/:oid0/:oid1/:oid", get(obj_download))
@@ -119,8 +152,8 @@ async fn main() {
     let app_with_middleware = middleware.layer(app);
 
     axum::serve(listener, app_with_middleware.into_make_service())
-    .await
-    .unwrap();
+        .await
+        .unwrap();
 }
 
 #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
@@ -182,11 +215,11 @@ enum GitLfsJsonRejection {
 
 impl IntoResponse for GitLfsJsonRejection {
     fn into_response(self) -> Response {
-        (
+        make_error_resp(
             StatusCode::UNSUPPORTED_MEDIA_TYPE,
-            format!("Expected request with `Content-Type: {LFS_MIME}`"),
+            &format!("Expected request with `Content-Type: {LFS_MIME}`"),
         )
-            .into_response()
+        .into_response()
     }
 }
 
@@ -241,12 +274,23 @@ impl IntoResponse for GitLfsJson {
         let mut resp = json.into_response();
         resp.headers_mut().insert(
             header::CONTENT_TYPE,
-            HeaderValue::from_static("application/vnd.git-lfs+json"),
+            HeaderValue::from_static("application/vnd.git-lfs+json; charset=utf-8"),
         );
         resp
     }
 }
 
+#[derive(Serialize)]
+struct GitLfsErrorData<'a> {
+    message: &'a str,
+}
+
+type GitLfsErrorResponse<'a> = (StatusCode, GitLfsJson<GitLfsErrorData<'a>>);
+
+const fn make_error_resp<'a>(code: StatusCode, message: &'a str) -> GitLfsErrorResponse {
+    (code, GitLfsJson(Json(GitLfsErrorData { message })))
+}
+
 #[derive(Debug, Serialize, Clone)]
 struct BatchResponseObjectAction {
     href: String,
@@ -280,66 +324,215 @@ struct BatchResponse {
     hash_algo: HashAlgo,
 }
 
+fn validate_checksum(oid: Oid, obj: &HeadObjectOutput) -> bool {
+    if let Some(checksum) = obj.checksum_sha256() {
+        if let Ok(checksum) = BASE64_STANDARD.decode(checksum) {
+            if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) {
+                return Oid::from(checksum32b) == oid;
+            }
+        }
+    }
+    true
+}
+
+fn validate_size(expected: i64, obj: &HeadObjectOutput) -> bool {
+    if let Some(length) = obj.content_length() {
+        return length == expected;
+    }
+    true
+}
+
 async fn handle_download_object(state: &AppState, repo: &str, obj: &BatchRequestObject) {
     let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
     let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
-    let result = state.s3_client.head_object().
-        bucket(&state.s3_bucket).
-        key(full_path).
-        checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled).
-        send().await.unwrap();
-    if let Some(checksum) = result.checksum_sha256() {
-        if let Ok(checksum) = BASE64_STANDARD.decode(checksum) {
-            if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) {
-                if Oid::from(checksum32b) != obj.oid {
-                    unreachable!();
+    let result = state
+        .s3_client
+        .head_object()
+        .bucket(&state.s3_bucket)
+        .key(full_path)
+        .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
+        .send()
+        .await
+        .unwrap();
+    // Scaleway actually doesn't provide SHA256 support, but maybe in the future :)
+    if !validate_checksum(obj.oid, &result) {
+        unreachable!();
+    }
+    if !validate_size(obj.size, &result) {
+        unreachable!();
+    }
+
+    let expires_at = Utc::now() + Duration::seconds(5 * 60);
+}
+
+struct AuthorizationConfig {
+    trusted_forwarded_hosts: HashSet<String>,
+    key: common::Key,
+}
+
+struct Trusted(bool);
+
+fn forwarded_for_trusted_host(
+    headers: &HeaderMap,
+    trusted: &HashSet<String>,
+) -> Result<bool, GitLfsErrorResponse<'static>> {
+    if let Some(forwarded_for) = headers.get("X-Forwarded-For") {
+        if let Ok(forwarded_for) = forwarded_for.to_str() {
+            if trusted.contains(forwarded_for) {
+                return Ok(true);
+            }
+        } else {
+            return Err(make_error_resp(
+                StatusCode::NOT_FOUND,
+                "Invalid X-Forwarded-For header",
+            ));
+        }
+    }
+    return Ok(false);
+}
+const REPO_NOT_FOUND: GitLfsErrorResponse =
+    make_error_resp(StatusCode::NOT_FOUND, "Repository not found");
+
+fn authorize(
+    conf: &AuthorizationConfig,
+    headers: &HeaderMap,
+    repo_path: &str,
+    public: bool,
+    operation: common::Operation,
+) -> Result<Trusted, GitLfsErrorResponse<'static>> {
+    // - No authentication required for downloading exported repos
+    // - When authenticated:
+    //   - Download / upload over presigned URLs
+    // - When accessing over Tailscale:
+    //   - No authentication required for downloading from any repo
+
+    const INVALID_AUTHZ_HEADER: GitLfsErrorResponse =
+        make_error_resp(StatusCode::BAD_REQUEST, "Invalid authorization header");
+
+    if let Some(authz) = headers.get(header::AUTHORIZATION) {
+        if let Ok(authz) = authz.to_str() {
+            if let Some(val) = authz.strip_prefix("Gitolfs3-Hmac-Sha256 ") {
+                let Some((tag, expires_at)) = val.split_once(' ') else {
+                    return Err(INVALID_AUTHZ_HEADER);
+                };
+                let Ok(tag): Result, _> = tag.parse() else {
+                    return Err(INVALID_AUTHZ_HEADER);
+                };
+                let Ok(expires_at): Result<i64, _> = expires_at.parse() else {
+                    return Err(INVALID_AUTHZ_HEADER);
+                };
+                let Some(expires_at) = DateTime::<Utc>::from_timestamp(expires_at, 0) else {
+                    return Err(INVALID_AUTHZ_HEADER);
+                };
+                let Some(expected_tag) = common::generate_tag(
+                    common::Claims {
+                        auth_type: common::AuthType::GitLfsAuthenticate,
+                        repo_path,
+                        expires_at,
+                        operation,
+                    },
+                    &conf.key,
+                ) else {
+                    return Err(INVALID_AUTHZ_HEADER);
+                };
+                if tag == expected_tag {
+                    return Ok(Trusted(true));
+                } else {
+                    return Err(INVALID_AUTHZ_HEADER);
+                }
+            } else {
+                return Err(INVALID_AUTHZ_HEADER);
+            }
+        } else {
+            return Err(INVALID_AUTHZ_HEADER);
+        }
+    }
+
+    let trusted = forwarded_for_trusted_host(headers, &conf.trusted_forwarded_hosts)?;
+    if operation != common::Operation::Download {
+        if trusted {
+            return Err(make_error_resp(
+                StatusCode::FORBIDDEN,
+                "Authentication required to upload",
+            ));
+        }
+        return Err(REPO_NOT_FOUND);
+    }
+    if trusted {
+        return Ok(Trusted(true));
+    }
+
+    if public {
+        Ok(Trusted(false))
+    } else {
+        Err(REPO_NOT_FOUND)
+    }
+}
+
+fn repo_exists(name: &str) -> bool {
+    let Ok(metadata) = std::fs::metadata(name) else {
+        return false;
+    };
+    return metadata.is_dir();
+}
+
+fn is_repo_public(name: &str) -> Option<bool> {
+    if !repo_exists(name) {
+        return None;
+    }
+    std::fs::metadata(format!("{name}/git-daemon-export-ok"))
+        .ok()?
+        .is_file()
+        .into()
 }
 
 async fn batch(
     State(state): State<Arc<AppState>>,
-    header: HeaderMap,
+    headers: HeaderMap,
     RepositoryName(repo): RepositoryName,
     GitLfsJson(Json(payload)): GitLfsJson<BatchRequest>,
 ) -> Response {
-    if !header
+    let Some(public) = is_repo_public(&repo) else {
+        return REPO_NOT_FOUND.into_response();
+    };
+    let authn = match authorize(
+        &state.authz_conf,
+        &headers,
+        &repo,
+        public,
+        payload.operation,
+    ) {
+        Ok(authn) => authn,
+        Err(e) => return e.into_response(),
+    };
+
+    if !headers
         .get_all("Accept")
         .iter()
         .filter_map(|v| v.to_str().ok())
         .any(is_git_lfs_json_mimetype)
     {
-        return (
-            StatusCode::NOT_ACCEPTABLE,
-            format!("Expected `{LFS_MIME}` (with UTF-8 charset) in list of acceptable response media types"),
-        ).into_response();
+        let message = format!("Expected `{LFS_MIME}` in list of acceptable response media types");
+        return make_error_resp(StatusCode::NOT_ACCEPTABLE, &message).into_response();
     }
     if payload.hash_algo != HashAlgo::Sha256 {
-        return (
-            StatusCode::CONFLICT,
-            "Unsupported hashing algorithm speicifed",
-        )
-            .into_response();
+        let message = "Unsupported hashing algorithm specified";
+        return make_error_resp(StatusCode::CONFLICT, message).into_response();
     }
     if !payload.transfers.is_empty() && !payload.transfers.contains(&TransferAdapter::Basic) {
-        return (
-            StatusCode::CONFLICT,
-            "Unsupported transfer adapter specified (supported: basic)",
-        )
-            .into_response();
+        let message = "Unsupported transfer adapter specified (supported: basic)";
+        return make_error_resp(StatusCode::CONFLICT, message).into_response();
     }
     let resp: BatchResponse;
     for obj in payload.objects {
         handle_download_object(&state, &repo, &obj).await;
-//  match payload.operation {
-//      Operation::Download => resp.objects.push(handle_download_object(repo, obj));,
-//      Operation::Upload => resp.objects.push(handle_upload_object(repo, obj)),
-//  };
+        // match payload.operation {
+        //     Operation::Download => resp.objects.push(handle_download_object(repo, obj));,
+        //     Operation::Upload => resp.objects.push(handle_upload_object(repo, obj)),
+        // };
     }
     format!("hi from {repo}\n").into_response()
--
cgit v1.2.3
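
Note on the header format (sketch, not part of the patch): `authorize` above accepts
`Authorization: Gitolfs3-Hmac-Sha256 <tag> <unix-expiry>` and recomputes the tag over the same
claims. Assuming the same `common` crate the server uses, and assuming its tag type implements
`Display`, a client could mint such a header roughly as follows; the helper name is hypothetical.

// Hypothetical client-side helper, mirroring the claims that `authorize` reconstructs.
fn lfs_authorization_header(
    key: &common::Key,
    repo_path: &str,
    operation: common::Operation,
) -> Option<String> {
    // Use the same 5-minute validity window the patch uses for expires_at.
    let expires_at = chrono::Utc::now() + chrono::Duration::seconds(5 * 60);
    let tag = common::generate_tag(
        common::Claims {
            auth_type: common::AuthType::GitLfsAuthenticate,
            repo_path,
            expires_at,
            operation,
        },
        key,
    )?;
    // `authorize` splits the value on a single space: "<tag> <unix timestamp>".
    Some(format!("Gitolfs3-Hmac-Sha256 {tag} {}", expires_at.timestamp()))
}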
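
Sketch (not part of the patch): `handle_download_object` currently stops after computing
`expires_at`; the download action it is building towards would presumably carry a presigned GET
URL. Assuming the aws-sdk-s3 presigning API (`aws_sdk_s3::presigning::PresigningConfig`), the URL
for the object checked above could be produced like this; the helper name is hypothetical.

// Hypothetical continuation of handle_download_object: presign a GET for the object
// at `full_path`, valid for the same 5 minutes as `expires_at`.
async fn presigned_download_url(state: &AppState, full_path: &str) -> Option<String> {
    let presigning = aws_sdk_s3::presigning::PresigningConfig::expires_in(
        std::time::Duration::from_secs(5 * 60),
    )
    .ok()?;
    let presigned = state
        .s3_client
        .get_object()
        .bucket(&state.s3_bucket)
        .key(full_path)
        .presigned(presigning)
        .await
        .ok()?;
    // The resulting URI could then serve as the `href` of a BatchResponseObjectAction.
    Some(presigned.uri().to_string())
}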