From bc709f0f23be345a1e2ccd06acd36bd5dac40bde Mon Sep 17 00:00:00 2001 From: Rutger Broekhoff Date: Fri, 12 Jul 2024 00:29:57 +0200 Subject: Restructure server --- gitolfs3-server/src/api.rs | 213 +++++++++++----------- gitolfs3-server/src/authz.rs | 80 +++++---- gitolfs3-server/src/config.rs | 122 ++++++------- gitolfs3-server/src/dlimit.rs | 2 +- gitolfs3-server/src/handler.rs | 388 ++++++++++++++++++++--------------------- gitolfs3-server/src/main.rs | 23 ++- 6 files changed, 424 insertions(+), 404 deletions(-) diff --git a/gitolfs3-server/src/api.rs b/gitolfs3-server/src/api.rs index dba7ada..d71d188 100644 --- a/gitolfs3-server/src/api.rs +++ b/gitolfs3-server/src/api.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use axum::{ async_trait, extract::{rejection, FromRequest, FromRequestParts, Request}, - http::{header, request::Parts, HeaderValue, StatusCode}, + http, response::{IntoResponse, Response}, Extension, Json, }; @@ -11,79 +11,21 @@ use chrono::{DateTime, Utc}; use gitolfs3_common::{Oid, Operation}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -pub const REPO_NOT_FOUND: GitLfsErrorResponse = - make_error_resp(StatusCode::NOT_FOUND, "Repository not found"); - -#[derive(Clone)] -pub struct RepositoryName(pub String); - -pub struct RepositoryNameRejection; - -impl IntoResponse for RepositoryNameRejection { - fn into_response(self) -> Response { - (StatusCode::INTERNAL_SERVER_ERROR, "Missing repository name").into_response() - } -} - -#[async_trait] -impl FromRequestParts for RepositoryName { - type Rejection = RepositoryNameRejection; - - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { - let Ok(Extension(repo_name)) = Extension::::from_request_parts(parts, state).await - else { - return Err(RepositoryNameRejection); - }; - Ok(repo_name) - } -} +// ----------------------- Generic facilities ---------------------- -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)] -pub enum TransferAdapter { - #[serde(rename = "basic")] - Basic, - #[serde(other)] - Unknown, -} +pub type GitLfsErrorResponse<'a> = (http::StatusCode, GitLfsJson>); -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)] -pub enum HashAlgo { - #[serde(rename = "sha256")] - Sha256, - #[serde(other)] - Unknown, -} - -impl Default for HashAlgo { - fn default() -> Self { - Self::Sha256 - } -} - -#[derive(Debug, Deserialize, PartialEq, Eq, Clone)] -pub struct BatchRequestObject { - pub oid: Oid, - pub size: i64, -} - -#[derive(Debug, Serialize, Deserialize, Clone)] -struct BatchRef { - name: String, +#[derive(Debug, Serialize)] +pub struct GitLfsErrorData<'a> { + pub message: &'a str, } -fn default_transfers() -> Vec { - vec![TransferAdapter::Basic] +pub const fn make_error_resp(code: http::StatusCode, message: &str) -> GitLfsErrorResponse { + (code, GitLfsJson(Json(GitLfsErrorData { message }))) } -#[derive(Debug, Deserialize, PartialEq, Eq, Clone)] -pub struct BatchRequest { - pub operation: Operation, - #[serde(default = "default_transfers")] - pub transfers: Vec, - pub objects: Vec, - #[serde(default)] - pub hash_algo: HashAlgo, -} +pub const REPO_NOT_FOUND: GitLfsErrorResponse = + make_error_resp(http::StatusCode::NOT_FOUND, "Repository not found"); #[derive(Debug, Clone)] pub struct GitLfsJson(pub Json); @@ -100,7 +42,7 @@ impl IntoResponse for GitLfsJsonRejection { match self { Self::Json(rej) => rej.into_response(), Self::MissingGitLfsJsonContentType => make_error_resp( - StatusCode::UNSUPPORTED_MEDIA_TYPE, + http::StatusCode::UNSUPPORTED_MEDIA_TYPE, &format!("Expected request with `Content-Type: {LFS_MIME}`"), ) .into_response(), @@ -125,7 +67,7 @@ pub fn is_git_lfs_json_mimetype(mimetype: &str) -> bool { } fn has_git_lfs_json_content_type(req: &Request) -> bool { - let Some(content_type) = req.headers().get(header::CONTENT_TYPE) else { + let Some(content_type) = req.headers().get(http::header::CONTENT_TYPE) else { return false; }; let Ok(content_type) = content_type.to_str() else { @@ -158,46 +100,98 @@ impl IntoResponse for GitLfsJson { let GitLfsJson(json) = self; let mut resp = json.into_response(); resp.headers_mut().insert( - header::CONTENT_TYPE, - HeaderValue::from_static("application/vnd.git-lfs+json; charset=utf-8"), + http::header::CONTENT_TYPE, + http::HeaderValue::from_static("application/vnd.git-lfs+json; charset=utf-8"), ); resp } } -#[derive(Debug, Serialize)] -pub struct GitLfsErrorData<'a> { - pub message: &'a str, +#[derive(Clone)] +pub struct RepositoryName(pub String); + +pub struct RepositoryNameRejection; + +impl IntoResponse for RepositoryNameRejection { + fn into_response(self) -> Response { + ( + http::StatusCode::INTERNAL_SERVER_ERROR, + "Missing repository name", + ) + .into_response() + } } -pub type GitLfsErrorResponse<'a> = (StatusCode, GitLfsJson>); +#[async_trait] +impl FromRequestParts for RepositoryName { + type Rejection = RepositoryNameRejection; -pub const fn make_error_resp(code: StatusCode, message: &str) -> GitLfsErrorResponse { - (code, GitLfsJson(Json(GitLfsErrorData { message }))) + async fn from_request_parts( + parts: &mut http::request::Parts, + state: &S, + ) -> Result { + let Ok(Extension(repo_name)) = Extension::::from_request_parts(parts, state).await + else { + return Err(RepositoryNameRejection); + }; + Ok(repo_name) + } } -#[derive(Debug, Serialize, Clone)] -pub struct BatchResponseObjectAction { - pub href: String, - #[serde(skip_serializing_if = "HashMap::is_empty")] - pub header: HashMap, - pub expires_at: DateTime, +// ----------------------- Git LFS Batch API ----------------------- + +#[derive(Debug, Deserialize, PartialEq, Eq, Clone)] +pub struct BatchRequest { + pub operation: Operation, + #[serde(default = "default_transfers")] + pub transfers: Vec, + pub objects: Vec, + #[serde(default)] + pub hash_algo: HashAlgo, } -#[derive(Default, Debug, Serialize, Clone)] -pub struct BatchResponseObjectActions { - #[serde(skip_serializing_if = "Option::is_none")] - pub upload: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub download: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub verify: Option, +#[derive(Debug, Deserialize, PartialEq, Eq, Clone)] +pub struct BatchRequestObject { + pub oid: Oid, + pub size: i64, } -#[derive(Debug, Clone, Serialize)] -pub struct BatchResponseObjectError { - pub code: u16, - pub message: String, +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)] +pub enum TransferAdapter { + #[serde(rename = "basic")] + Basic, + #[serde(other)] + Unknown, +} + +fn default_transfers() -> Vec { + vec![TransferAdapter::Basic] +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)] +pub enum HashAlgo { + #[serde(rename = "sha256")] + Sha256, + #[serde(other)] + Unknown, +} + +impl Default for HashAlgo { + fn default() -> Self { + Self::Sha256 + } +} + +#[derive(Debug, Serialize, Deserialize, Clone)] +struct BatchRef { + name: String, +} + +#[derive(Debug, Serialize, Clone)] +pub struct BatchResponse { + pub transfer: TransferAdapter, + pub objects: Vec, + pub hash_algo: HashAlgo, } #[derive(Debug, Serialize, Clone)] @@ -211,10 +205,16 @@ pub struct BatchResponseObject { pub error: Option, } +#[derive(Debug, Clone, Serialize)] +pub struct BatchResponseObjectError { + pub code: u16, + pub message: String, +} + impl BatchResponseObject { pub fn error( obj: &BatchRequestObject, - code: StatusCode, + code: http::StatusCode, message: String, ) -> BatchResponseObject { BatchResponseObject { @@ -231,10 +231,21 @@ impl BatchResponseObject { } #[derive(Debug, Serialize, Clone)] -pub struct BatchResponse { - pub transfer: TransferAdapter, - pub objects: Vec, - pub hash_algo: HashAlgo, +pub struct BatchResponseObjectAction { + pub href: String, + #[serde(skip_serializing_if = "HashMap::is_empty")] + pub header: HashMap, + pub expires_at: DateTime, +} + +#[derive(Default, Debug, Serialize, Clone)] +pub struct BatchResponseObjectActions { + #[serde(skip_serializing_if = "Option::is_none")] + pub upload: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub download: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub verify: Option, } #[test] diff --git a/gitolfs3-server/src/authz.rs b/gitolfs3-server/src/authz.rs index 0674cef..8a5f21f 100644 --- a/gitolfs3-server/src/authz.rs +++ b/gitolfs3-server/src/authz.rs @@ -1,6 +1,6 @@ use std::collections::HashSet; -use axum::http::{header, HeaderMap, StatusCode}; +use axum::http; use chrono::{DateTime, Utc}; use gitolfs3_common::{generate_tag, Claims, Digest, Oid, Operation, SpecificClaims}; @@ -11,31 +11,12 @@ use crate::{ pub struct Trusted(pub bool); -fn forwarded_from_trusted_host( - headers: &HeaderMap, - trusted: &HashSet, -) -> Result> { - if let Some(forwarded_host) = headers.get("X-Forwarded-Host") { - if let Ok(forwarded_host) = forwarded_host.to_str() { - if trusted.contains(forwarded_host) { - return Ok(true); - } - } else { - return Err(make_error_resp( - StatusCode::NOT_FOUND, - "Invalid X-Forwarded-Host header", - )); - } - } - Ok(false) -} - pub fn authorize_batch( conf: &AuthorizationConfig, repo_path: &str, public: bool, operation: Operation, - headers: &HeaderMap, + headers: &http::HeaderMap, ) -> Result> { // - No authentication required for downloading exported repos // - When authenticated: @@ -57,7 +38,7 @@ fn authorize_batch_unauthenticated( conf: &AuthorizationConfig, public: bool, operation: Operation, - headers: &HeaderMap, + headers: &http::HeaderMap, ) -> Result> { let trusted = forwarded_from_trusted_host(headers, &conf.trusted_forwarded_hosts)?; match operation { @@ -71,7 +52,7 @@ fn authorize_batch_unauthenticated( return Err(REPO_NOT_FOUND); } Err(make_error_resp( - StatusCode::FORBIDDEN, + http::StatusCode::FORBIDDEN, "Authentication required to upload", )) } @@ -94,7 +75,7 @@ pub fn authorize_get( conf: &AuthorizationConfig, repo_path: &str, oid: Oid, - headers: &HeaderMap, + headers: &http::HeaderMap, ) -> Result<(), GitLfsErrorResponse<'static>> { let claims = VerifyClaimsInput { specific_claims: SpecificClaims::Download(oid), @@ -102,27 +83,48 @@ pub fn authorize_get( }; if !verify_claims(conf, &claims, headers)? { return Err(make_error_resp( - StatusCode::UNAUTHORIZED, + http::StatusCode::UNAUTHORIZED, "Repository not found", )); } Ok(()) } -pub struct VerifyClaimsInput<'a> { - pub specific_claims: SpecificClaims, - pub repo_path: &'a str, +fn forwarded_from_trusted_host( + headers: &http::HeaderMap, + trusted: &HashSet, +) -> Result> { + if let Some(forwarded_host) = headers.get("X-Forwarded-Host") { + if let Ok(forwarded_host) = forwarded_host.to_str() { + if trusted.contains(forwarded_host) { + return Ok(true); + } + } else { + return Err(make_error_resp( + http::StatusCode::NOT_FOUND, + "Invalid X-Forwarded-Host header", + )); + } + } + Ok(false) +} + +struct VerifyClaimsInput<'a> { + specific_claims: SpecificClaims, + repo_path: &'a str, } fn verify_claims( conf: &AuthorizationConfig, claims: &VerifyClaimsInput, - headers: &HeaderMap, + headers: &http::HeaderMap, ) -> Result> { - const INVALID_AUTHZ_HEADER: GitLfsErrorResponse = - make_error_resp(StatusCode::BAD_REQUEST, "Invalid authorization header"); + const INVALID_AUTHZ_HEADER: GitLfsErrorResponse = make_error_resp( + http::StatusCode::BAD_REQUEST, + "Invalid authorization header", + ); - let Some(authz) = headers.get(header::AUTHORIZATION) else { + let Some(authz) = headers.get(http::header::AUTHORIZATION) else { return Ok(false); }; let authz = authz.to_str().map_err(|_| INVALID_AUTHZ_HEADER)?; @@ -141,7 +143,12 @@ fn verify_claims( }, &conf.key, ) - .ok_or_else(|| make_error_resp(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"))?; + .ok_or_else(|| { + make_error_resp( + http::StatusCode::INTERNAL_SERVER_ERROR, + "Internal server error", + ) + })?; if tag != expected_tag { return Err(INVALID_AUTHZ_HEADER); } @@ -175,8 +182,11 @@ fn test_validate_claims() { repo_path: claims.repo_path, specific_claims: claims.specific_claims, }; - let mut headers = HeaderMap::new(); - headers.insert(header::AUTHORIZATION, header_value.try_into().unwrap()); + let mut headers = http::HeaderMap::new(); + headers.insert( + http::header::AUTHORIZATION, + header_value.try_into().unwrap(), + ); assert!(verify_claims(&conf, &verification_claims, &headers).unwrap()); } diff --git a/gitolfs3-server/src/config.rs b/gitolfs3-server/src/config.rs index 75e84dc..c6a51a5 100644 --- a/gitolfs3-server/src/config.rs +++ b/gitolfs3-server/src/config.rs @@ -2,66 +2,6 @@ use std::collections::HashSet; use gitolfs3_common::{load_key, Key}; -struct Env { - s3_access_key_id: String, - s3_secret_access_key: String, - s3_bucket: String, - s3_region: String, - s3_endpoint: String, - base_url: String, - key_path: String, - listen_host: String, - listen_port: String, - download_limit: String, - trusted_forwarded_hosts: String, -} - -fn require_env(name: &str) -> Result { - std::env::var(name) - .map_err(|_| format!("environment variable {name} should be defined and valid")) -} - -impl Env { - fn load() -> Result { - Ok(Env { - s3_secret_access_key: require_env("GITOLFS3_S3_SECRET_ACCESS_KEY_FILE")?, - s3_access_key_id: require_env("GITOLFS3_S3_ACCESS_KEY_ID_FILE")?, - s3_region: require_env("GITOLFS3_S3_REGION")?, - s3_endpoint: require_env("GITOLFS3_S3_ENDPOINT")?, - s3_bucket: require_env("GITOLFS3_S3_BUCKET")?, - base_url: require_env("GITOLFS3_BASE_URL")?, - key_path: require_env("GITOLFS3_KEY_PATH")?, - listen_host: require_env("GITOLFS3_LISTEN_HOST")?, - listen_port: require_env("GITOLFS3_LISTEN_PORT")?, - download_limit: require_env("GITOLFS3_DOWNLOAD_LIMIT")?, - trusted_forwarded_hosts: std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS") - .unwrap_or_default(), - }) - } -} - -fn get_s3_client(env: &Env) -> Result { - let access_key_id = std::fs::read_to_string(&env.s3_access_key_id)?; - let secret_access_key = std::fs::read_to_string(&env.s3_secret_access_key)?; - - let credentials = aws_sdk_s3::config::Credentials::new( - access_key_id, - secret_access_key, - None, - None, - "gitolfs3-env", - ); - let config = aws_config::SdkConfig::builder() - .behavior_version(aws_config::BehaviorVersion::latest()) - .region(aws_config::Region::new(env.s3_region.clone())) - .endpoint_url(&env.s3_endpoint) - .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new( - credentials, - )) - .build(); - Ok(aws_sdk_s3::Client::new(&config)) -} - pub struct Config { pub listen_addr: (String, u16), pub base_url: String, @@ -83,7 +23,7 @@ impl Config { Err(e) => return Err(format!("failed to load configuration: {e}")), }; - let s3_client = match get_s3_client(&env) { + let s3_client = match create_s3_client(&env) { Ok(s3_client) => s3_client, Err(e) => return Err(format!("failed to create S3 client: {e}")), }; @@ -120,3 +60,63 @@ impl Config { }) } } + +fn create_s3_client(env: &Env) -> Result { + let access_key_id = std::fs::read_to_string(&env.s3_access_key_id)?; + let secret_access_key = std::fs::read_to_string(&env.s3_secret_access_key)?; + + let credentials = aws_sdk_s3::config::Credentials::new( + access_key_id, + secret_access_key, + None, + None, + "gitolfs3-env", + ); + let config = aws_config::SdkConfig::builder() + .behavior_version(aws_config::BehaviorVersion::latest()) + .region(aws_config::Region::new(env.s3_region.clone())) + .endpoint_url(&env.s3_endpoint) + .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new( + credentials, + )) + .build(); + Ok(aws_sdk_s3::Client::new(&config)) +} + +struct Env { + s3_access_key_id: String, + s3_secret_access_key: String, + s3_bucket: String, + s3_region: String, + s3_endpoint: String, + base_url: String, + key_path: String, + listen_host: String, + listen_port: String, + download_limit: String, + trusted_forwarded_hosts: String, +} + +impl Env { + fn load() -> Result { + Ok(Env { + s3_secret_access_key: require_env("GITOLFS3_S3_SECRET_ACCESS_KEY_FILE")?, + s3_access_key_id: require_env("GITOLFS3_S3_ACCESS_KEY_ID_FILE")?, + s3_region: require_env("GITOLFS3_S3_REGION")?, + s3_endpoint: require_env("GITOLFS3_S3_ENDPOINT")?, + s3_bucket: require_env("GITOLFS3_S3_BUCKET")?, + base_url: require_env("GITOLFS3_BASE_URL")?, + key_path: require_env("GITOLFS3_KEY_PATH")?, + listen_host: require_env("GITOLFS3_LISTEN_HOST")?, + listen_port: require_env("GITOLFS3_LISTEN_PORT")?, + download_limit: require_env("GITOLFS3_DOWNLOAD_LIMIT")?, + trusted_forwarded_hosts: std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS") + .unwrap_or_default(), + }) + } +} + +fn require_env(name: &str) -> Result { + std::env::var(name) + .map_err(|_| format!("environment variable {name} should be defined and valid")) +} diff --git a/gitolfs3-server/src/dlimit.rs b/gitolfs3-server/src/dlimit.rs index f68bec1..7a82a18 100644 --- a/gitolfs3-server/src/dlimit.rs +++ b/gitolfs3-server/src/dlimit.rs @@ -55,7 +55,7 @@ impl DownloadLimiter { Ok(true) } - pub async fn reset(&mut self) { + async fn reset(&mut self) { self.current = 0; if let Err(e) = self.write_new_count().await { println!("Failed to reset download counter: {e}"); diff --git a/gitolfs3-server/src/handler.rs b/gitolfs3-server/src/handler.rs index 6516291..b9f9bcf 100644 --- a/gitolfs3-server/src/handler.rs +++ b/gitolfs3-server/src/handler.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, sync::Arc}; use aws_sdk_s3::{error::SdkError, operation::head_object::HeadObjectOutput}; use axum::{ extract::{Path, State}, - http::{header, HeaderMap, StatusCode}, + http, response::{IntoResponse, Response}, Json, }; @@ -33,102 +33,6 @@ pub struct AppState { pub dl_limiter: Arc>, } -fn validate_checksum(oid: Oid, obj: &HeadObjectOutput) -> bool { - if let Some(checksum) = obj.checksum_sha256() { - if let Ok(checksum) = BASE64_STANDARD.decode(checksum) { - if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) { - return Oid::from(checksum32b) == oid; - } - } - } - true -} - -fn validate_size(expected: i64, obj: &HeadObjectOutput) -> bool { - if let Some(length) = obj.content_length() { - return length == expected; - } - true -} - -async fn handle_upload_object( - state: &AppState, - repo: &str, - obj: &BatchRequestObject, -) -> Option { - let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); - let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); - - match state - .s3_client - .head_object() - .bucket(&state.s3_bucket) - .key(full_path.clone()) - .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled) - .send() - .await - { - Ok(result) => { - if validate_size(obj.size, &result) && validate_checksum(obj.oid, &result) { - return None; - } - } - Err(SdkError::ServiceError(e)) if e.err().is_not_found() => {} - Err(e) => { - println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid); - return Some(BatchResponseObject::error( - obj, - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to query object information".to_string(), - )); - } - }; - - let expires_in = std::time::Duration::from_secs(5 * 60); - let expires_at = Utc::now() + expires_in; - - let Ok(config) = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in) else { - return Some(BatchResponseObject::error( - obj, - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to generate upload URL".to_string(), - )); - }; - let Ok(presigned) = state - .s3_client - .put_object() - .bucket(&state.s3_bucket) - .key(full_path) - .checksum_sha256(obj.oid.to_string()) - .content_length(obj.size) - .presigned(config) - .await - else { - return Some(BatchResponseObject::error( - obj, - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to generate upload URL".to_string(), - )); - }; - Some(BatchResponseObject { - oid: obj.oid, - size: obj.size, - authenticated: Some(true), - actions: BatchResponseObjectActions { - upload: Some(BatchResponseObjectAction { - header: presigned - .headers() - .map(|(k, v)| (k.to_owned(), v.to_owned())) - .collect(), - expires_at, - href: presigned.uri().to_string(), - }), - ..Default::default() - }, - error: None, - }) -} - async fn handle_download_object( state: &AppState, repo: &str, @@ -152,24 +56,24 @@ async fn handle_download_object( println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid); return BatchResponseObject::error( obj, - StatusCode::INTERNAL_SERVER_ERROR, + http::StatusCode::INTERNAL_SERVER_ERROR, "Failed to query object information".to_string(), ); } }; - // Scaleway actually doesn't provide SHA256 suport, but maybe in the future :) - if !validate_checksum(obj.oid, &result) { + // Scaleway actually doesn't provide SHA256 support, but maybe in the future :) + if !s3_validate_checksum(obj.oid, &result) { return BatchResponseObject::error( obj, - StatusCode::UNPROCESSABLE_ENTITY, + http::StatusCode::UNPROCESSABLE_ENTITY, "Object corrupted".to_string(), ); } - if !validate_size(obj.size, &result) { + if !s3_validate_size(obj.size, &result) { return BatchResponseObject::error( obj, - StatusCode::UNPROCESSABLE_ENTITY, + http::StatusCode::UNPROCESSABLE_ENTITY, "Incorrect size specified (or object corrupted)".to_string(), ); } @@ -181,7 +85,7 @@ async fn handle_download_object( let Ok(config) = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in) else { return BatchResponseObject::error( obj, - StatusCode::INTERNAL_SERVER_ERROR, + http::StatusCode::INTERNAL_SERVER_ERROR, "Failed to generate upload URL".to_string(), ); }; @@ -195,7 +99,7 @@ async fn handle_download_object( else { return BatchResponseObject::error( obj, - StatusCode::INTERNAL_SERVER_ERROR, + http::StatusCode::INTERNAL_SERVER_ERROR, "Failed to generate upload URL".to_string(), ); }; @@ -231,7 +135,7 @@ async fn handle_download_object( Ok(false) => { return BatchResponseObject::error( obj, - StatusCode::SERVICE_UNAVAILABLE, + http::StatusCode::SERVICE_UNAVAILABLE, "Public LFS downloads temporarily unavailable".to_string(), ); } @@ -239,7 +143,7 @@ async fn handle_download_object( println!("Failed to request {content_length} bytes from download limiter: {e}"); return BatchResponseObject::error( obj, - StatusCode::INTERNAL_SERVER_ERROR, + http::StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string(), ); } @@ -257,7 +161,7 @@ async fn handle_download_object( ) else { return BatchResponseObject::error( obj, - StatusCode::INTERNAL_SERVER_ERROR, + http::StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string(), ); }; @@ -292,83 +196,6 @@ async fn handle_download_object( } } -fn repo_exists(name: &str) -> bool { - let Ok(metadata) = std::fs::metadata(name) else { - return false; - }; - metadata.is_dir() -} - -fn is_repo_public(name: &str) -> Option { - if !repo_exists(name) { - return None; - } - match std::fs::metadata(format!("{name}/git-daemon-export-ok")) { - Ok(metadata) if metadata.is_file() => Some(true), - Err(e) if e.kind() == std::io::ErrorKind::NotFound => Some(false), - _ => None, - } -} - -pub async fn batch( - State(state): State>, - headers: HeaderMap, - RepositoryName(repo): RepositoryName, - GitLfsJson(Json(payload)): GitLfsJson, -) -> Response { - let Some(public) = is_repo_public(&repo) else { - return REPO_NOT_FOUND.into_response(); - }; - let Trusted(trusted) = match authorize_batch( - &state.authz_conf, - &repo, - public, - payload.operation, - &headers, - ) { - Ok(authn) => authn, - Err(e) => return e.into_response(), - }; - - if !headers - .get_all("Accept") - .iter() - .filter_map(|v| v.to_str().ok()) - .any(is_git_lfs_json_mimetype) - { - let message = format!("Expected `{LFS_MIME}` in list of acceptable response media types"); - return make_error_resp(StatusCode::NOT_ACCEPTABLE, &message).into_response(); - } - - if payload.hash_algo != HashAlgo::Sha256 { - let message = "Unsupported hashing algorithm specified"; - return make_error_resp(StatusCode::CONFLICT, message).into_response(); - } - if !payload.transfers.is_empty() && !payload.transfers.contains(&TransferAdapter::Basic) { - let message = "Unsupported transfer adapter specified (supported: basic)"; - return make_error_resp(StatusCode::CONFLICT, message).into_response(); - } - - let mut resp = BatchResponse { - transfer: TransferAdapter::Basic, - objects: vec![], - hash_algo: HashAlgo::Sha256, - }; - for obj in payload.objects { - match payload.operation { - Operation::Download => resp - .objects - .push(handle_download_object(&state, &repo, &obj, trusted).await), - Operation::Upload => { - if let Some(obj_resp) = handle_upload_object(&state, &repo, &obj).await { - resp.objects.push(obj_resp); - } - } - }; - } - GitLfsJson(Json(resp)).into_response() -} - #[derive(Deserialize, Copy, Clone)] #[serde(remote = "Self")] pub struct FileParams { @@ -382,11 +209,11 @@ impl<'de> Deserialize<'de> for FileParams { where D: serde::Deserializer<'de>, { - let unchecked @ FileParams { + let unchecked @ Self { oid0: HexByte(oid0), oid1: HexByte(oid1), oid, - } = FileParams::deserialize(deserializer)?; + } = Self::deserialize(deserializer)?; if oid0 != oid.as_bytes()[0] { return Err(de::Error::custom( "first OID path part does not match first byte of full OID", @@ -401,9 +228,9 @@ impl<'de> Deserialize<'de> for FileParams { } } -pub async fn obj_download( +pub async fn handle_obj_download( State(state): State>, - headers: HeaderMap, + headers: http::HeaderMap, RepositoryName(repo): RepositoryName, Path(FileParams { oid0, oid1, oid }): Path, ) -> Response { @@ -425,26 +252,26 @@ pub async fn obj_download( Err(e) => { println!("Failed to GetObject (repo {repo}, OID {oid}): {e}"); return ( - StatusCode::INTERNAL_SERVER_ERROR, + http::StatusCode::INTERNAL_SERVER_ERROR, "Failed to query object information", ) .into_response(); } }; - let mut headers = header::HeaderMap::new(); + let mut headers = http::header::HeaderMap::new(); if let Some(content_type) = result.content_type { let Ok(header_value) = content_type.try_into() else { return ( - StatusCode::INTERNAL_SERVER_ERROR, + http::StatusCode::INTERNAL_SERVER_ERROR, "Object has invalid content type", ) .into_response(); }; - headers.insert(header::CONTENT_TYPE, header_value); + headers.insert(http::header::CONTENT_TYPE, header_value); } if let Some(content_length) = result.content_length { - headers.insert(header::CONTENT_LENGTH, content_length.into()); + headers.insert(http::header::CONTENT_LENGTH, content_length.into()); } let async_read = result.body.into_async_read(); @@ -453,3 +280,176 @@ pub async fn obj_download( (headers, body).into_response() } + +async fn handle_upload_object( + state: &AppState, + repo: &str, + obj: &BatchRequestObject, +) -> Option { + let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); + let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); + + match state + .s3_client + .head_object() + .bucket(&state.s3_bucket) + .key(full_path.clone()) + .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled) + .send() + .await + { + Ok(result) => { + if s3_validate_size(obj.size, &result) && s3_validate_checksum(obj.oid, &result) { + return None; + } + } + Err(SdkError::ServiceError(e)) if e.err().is_not_found() => {} + Err(e) => { + println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid); + return Some(BatchResponseObject::error( + obj, + http::StatusCode::INTERNAL_SERVER_ERROR, + "Failed to query object information".to_string(), + )); + } + }; + + let expires_in = std::time::Duration::from_secs(5 * 60); + let expires_at = Utc::now() + expires_in; + + let Ok(config) = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in) else { + return Some(BatchResponseObject::error( + obj, + http::StatusCode::INTERNAL_SERVER_ERROR, + "Failed to generate upload URL".to_string(), + )); + }; + let Ok(presigned) = state + .s3_client + .put_object() + .bucket(&state.s3_bucket) + .key(full_path) + .checksum_sha256(obj.oid.to_string()) + .content_length(obj.size) + .presigned(config) + .await + else { + return Some(BatchResponseObject::error( + obj, + http::StatusCode::INTERNAL_SERVER_ERROR, + "Failed to generate upload URL".to_string(), + )); + }; + Some(BatchResponseObject { + oid: obj.oid, + size: obj.size, + authenticated: Some(true), + actions: BatchResponseObjectActions { + upload: Some(BatchResponseObjectAction { + header: presigned + .headers() + .map(|(k, v)| (k.to_owned(), v.to_owned())) + .collect(), + expires_at, + href: presigned.uri().to_string(), + }), + ..Default::default() + }, + error: None, + }) +} + +pub async fn handle_batch( + State(state): State>, + headers: http::HeaderMap, + RepositoryName(repo): RepositoryName, + GitLfsJson(Json(payload)): GitLfsJson, +) -> Response { + let Some(public) = is_repo_public(&repo) else { + return REPO_NOT_FOUND.into_response(); + }; + let Trusted(trusted) = match authorize_batch( + &state.authz_conf, + &repo, + public, + payload.operation, + &headers, + ) { + Ok(authn) => authn, + Err(e) => return e.into_response(), + }; + + if !headers + .get_all("Accept") + .iter() + .filter_map(|v| v.to_str().ok()) + .any(is_git_lfs_json_mimetype) + { + let message = format!("Expected `{LFS_MIME}` in list of acceptable response media types"); + return make_error_resp(http::StatusCode::NOT_ACCEPTABLE, &message).into_response(); + } + + if payload.hash_algo != HashAlgo::Sha256 { + let message = "Unsupported hashing algorithm specified"; + return make_error_resp(http::StatusCode::CONFLICT, message).into_response(); + } + if !payload.transfers.is_empty() && !payload.transfers.contains(&TransferAdapter::Basic) { + let message = "Unsupported transfer adapter specified (supported: basic)"; + return make_error_resp(http::StatusCode::CONFLICT, message).into_response(); + } + + let mut resp = BatchResponse { + transfer: TransferAdapter::Basic, + objects: vec![], + hash_algo: HashAlgo::Sha256, + }; + for obj in payload.objects { + match payload.operation { + Operation::Download => resp + .objects + .push(handle_download_object(&state, &repo, &obj, trusted).await), + Operation::Upload => { + if let Some(obj_resp) = handle_upload_object(&state, &repo, &obj).await { + resp.objects.push(obj_resp); + } + } + }; + } + GitLfsJson(Json(resp)).into_response() +} + +fn s3_validate_checksum(oid: Oid, obj: &HeadObjectOutput) -> bool { + if let Some(checksum) = obj.checksum_sha256() { + if let Ok(checksum) = BASE64_STANDARD.decode(checksum) { + if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) { + return Oid::from(checksum32b) == oid; + } + } + } + true +} + +fn s3_validate_size(expected: i64, obj: &HeadObjectOutput) -> bool { + if let Some(length) = obj.content_length() { + return length == expected; + } + true +} + +fn repo_exists(name: &str) -> bool { + let Ok(metadata) = std::fs::metadata(name) else { + return false; + }; + metadata.is_dir() +} + +fn is_repo_public(name: &str) -> Option { + if !repo_exists(name) { + return None; + } + match std::fs::metadata(format!("{name}/git-daemon-export-ok")) { + Ok(metadata) if metadata.is_file() => Some(true), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => Some(false), + _ => None, + } +} diff --git a/gitolfs3-server/src/main.rs b/gitolfs3-server/src/main.rs index c9911ed..46e840a 100644 --- a/gitolfs3-server/src/main.rs +++ b/gitolfs3-server/src/main.rs @@ -10,12 +10,13 @@ use dlimit::DownloadLimiter; use axum::{ extract::OriginalUri, - http::{StatusCode, Uri}, + http::{self, Uri}, routing::{get, post}, Router, ServiceExt, }; -use handler::AppState; +use handler::{handle_batch, handle_obj_download, AppState}; use std::{process::ExitCode, sync::Arc}; +use tokio::net::TcpListener; use tower::Layer; #[tokio::main] @@ -39,14 +40,14 @@ async fn main() -> ExitCode { dl_limiter, }); let app = Router::new() - .route("/batch", post(handler::batch)) - .route("/:oid0/:oid1/:oid", get(handler::obj_download)) + .route("/batch", post(handle_batch)) + .route("/:oid0/:oid1/:oid", get(handle_obj_download)) .with_state(shared_state); let middleware = axum::middleware::map_request(rewrite_url); let app_with_middleware = middleware.layer(app); - let listener = match tokio::net::TcpListener::bind(conf.listen_addr).await { + let listener = match TcpListener::bind(conf.listen_addr).await { Ok(listener) => listener, Err(e) => { println!("Failed to listen: {e}"); @@ -63,25 +64,23 @@ async fn main() -> ExitCode { } } -async fn rewrite_url( - mut req: axum::http::Request, -) -> Result, StatusCode> { +async fn rewrite_url(mut req: http::Request) -> Result, http::StatusCode> { let uri = req.uri(); let original_uri = OriginalUri(uri.clone()); let Some(path_and_query) = uri.path_and_query() else { // L @ no path & query - return Err(StatusCode::BAD_REQUEST); + return Err(http::StatusCode::BAD_REQUEST); }; let Some((repo, path)) = path_and_query.path().split_once("/info/lfs/objects") else { - return Err(StatusCode::NOT_FOUND); + return Err(http::StatusCode::NOT_FOUND); }; let repo = repo .trim_start_matches('/') .trim_end_matches('/') .to_string(); if !path.starts_with('/') || !repo.ends_with(".git") { - return Err(StatusCode::NOT_FOUND); + return Err(http::StatusCode::NOT_FOUND); } let mut parts = uri.clone().into_parts(); @@ -90,7 +89,7 @@ async fn rewrite_url( Some(q) => format!("{path}?{q}").try_into().ok(), }; let Ok(new_uri) = Uri::from_parts(parts) else { - return Err(StatusCode::INTERNAL_SERVER_ERROR); + return Err(http::StatusCode::INTERNAL_SERVER_ERROR); }; *req.uri_mut() = new_uri; -- cgit v1.2.3