From f5ff2803af0e03f57ab3093a9384d91abb9de083 Mon Sep 17 00:00:00 2001 From: Rutger Broekhoff Date: Mon, 22 Jan 2024 22:52:01 +0100 Subject: Finish basic implementation of Rust LFS server --- rs/Cargo.lock | 1 + rs/common/src/lib.rs | 43 ++- rs/git-lfs-authenticate/src/main.rs | 15 +- rs/server/Cargo.toml | 1 + rs/server/src/main.rs | 548 +++++++++++++++++++++++++++++------- 5 files changed, 483 insertions(+), 125 deletions(-) diff --git a/rs/Cargo.lock b/rs/Cargo.lock index 29a2eac..5a83471 100644 --- a/rs/Cargo.lock +++ b/rs/Cargo.lock @@ -1884,6 +1884,7 @@ dependencies = [ "serde", "serde_json", "tokio", + "tokio-util", "tower", "tower-service", ] diff --git a/rs/common/src/lib.rs b/rs/common/src/lib.rs index aafe7f1..27205bd 100644 --- a/rs/common/src/lib.rs +++ b/rs/common/src/lib.rs @@ -37,8 +37,9 @@ impl FromStr for Operation { } #[repr(u8)] -pub enum AuthType { - GitLfsAuthenticate = 1, +enum AuthType { + BatchApi = 1, + Download = 2, } /// None means out of range. @@ -156,6 +157,12 @@ impl SafeByteArray { } } +impl Default for SafeByteArray { + fn default() -> Self { + Self::new() + } +} + impl AsRef<[u8]> for SafeByteArray { fn as_ref(&self) -> &[u8] { &self.inner @@ -184,10 +191,18 @@ impl FromStr for SafeByteArray { } } +pub type Oid = Digest<32>; + +#[derive(Debug, Copy, Clone)] +pub enum SpecificClaims { + BatchApi(Operation), + Download(Oid), +} + +#[derive(Debug, Copy, Clone)] pub struct Claims<'a> { - pub auth_type: AuthType, + pub specific_claims: SpecificClaims, pub repo_path: &'a str, - pub operation: Operation, pub expires_at: DateTime, } @@ -198,10 +213,18 @@ pub fn generate_tag(claims: Claims, key: impl AsRef<[u8]>) -> Option> } let mut hmac = hmac_sha256::HMAC::new(key); - hmac.update([claims.auth_type as u8]); + match claims.specific_claims { + SpecificClaims::BatchApi(operation) => { + hmac.update([AuthType::BatchApi as u8]); + hmac.update([operation as u8]); + } + SpecificClaims::Download(oid) => { + hmac.update([AuthType::Download as u8]); + hmac.update(oid.as_bytes()); + } + } hmac.update([claims.repo_path.len() as u8]); hmac.update(claims.repo_path.as_bytes()); - hmac.update([claims.operation as u8]); hmac.update(claims.expires_at.timestamp().to_be_bytes()); Some(hmac.finalize().into()) } @@ -280,9 +303,9 @@ impl From<[u8; N]> for Digest { } } -impl Into<[u8; N]> for Digest { - fn into(self) -> [u8; N] { - self.inner +impl From> for [u8; N] { + fn from(val: Digest) -> Self { + val.inner } } @@ -304,7 +327,7 @@ impl ConstantTimeEq for Digest { impl PartialEq for Digest { fn eq(&self, other: &Self) -> bool { - self.ct_eq(&other).into() + self.ct_eq(other).into() } } diff --git a/rs/git-lfs-authenticate/src/main.rs b/rs/git-lfs-authenticate/src/main.rs index db95923..36d7818 100644 --- a/rs/git-lfs-authenticate/src/main.rs +++ b/rs/git-lfs-authenticate/src/main.rs @@ -148,30 +148,30 @@ struct Config { #[derive(Debug, Eq, PartialEq, Copy, Clone)] enum LoadConfigError { - BaseUrlMissing, + BaseUrlNotProvided, BaseUrlSlashSuffixMissing, - KeyPathMissing, + KeyPathNotProvided, } impl fmt::Display for LoadConfigError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::BaseUrlMissing => write!(f, "base URL not provided"), + Self::BaseUrlNotProvided => write!(f, "base URL not provided"), Self::BaseUrlSlashSuffixMissing => write!(f, "base URL does not end with slash"), - Self::KeyPathMissing => write!(f, "key path not provided"), + Self::KeyPathNotProvided => write!(f, "key path not provided"), } } } fn load_config() -> Result { let Ok(href_base) = std::env::var("GITOLFS3_HREF_BASE") else { - return Err(LoadConfigError::BaseUrlMissing); + return Err(LoadConfigError::BaseUrlNotProvided); }; if !href_base.ends_with('/') { return Err(LoadConfigError::BaseUrlSlashSuffixMissing); } let Ok(key_path) = std::env::var("GITOLFS3_KEY_PATH") else { - return Err(LoadConfigError::KeyPathMissing); + return Err(LoadConfigError::KeyPathNotProvided); }; Ok(Config { href_base, @@ -213,10 +213,9 @@ fn main() -> ExitCode { let expires_at = Utc::now() + Duration::from_secs(5 * 60); let Some(tag) = common::generate_tag( common::Claims { - auth_type: common::AuthType::GitLfsAuthenticate, + specific_claims: common::SpecificClaims::BatchApi(operation), repo_path: &repo_name, expires_at, - operation, }, key, ) else { diff --git a/rs/server/Cargo.toml b/rs/server/Cargo.toml index 9a2a9a9..987e154 100644 --- a/rs/server/Cargo.toml +++ b/rs/server/Cargo.toml @@ -15,5 +15,6 @@ mime = "0.3" serde = { version = "1", features = ["derive"] } serde_json = "1" tokio = { version = "1.35", features = ["full"] } +tokio-util = "0.7" tower = "0.4" tower-service = "0.3" diff --git a/rs/server/src/main.rs b/rs/server/src/main.rs index 0266f61..99805a4 100644 --- a/rs/server/src/main.rs +++ b/rs/server/src/main.rs @@ -1,7 +1,9 @@ use std::collections::HashMap; use std::collections::HashSet; +use std::process::ExitCode; use std::sync::Arc; +use aws_sdk_s3::error::SdkError; use aws_sdk_s3::operation::head_object::HeadObjectOutput; use axum::extract::rejection; use axum::extract::FromRequest; @@ -15,7 +17,6 @@ use axum::Json; use axum::ServiceExt; use base64::prelude::*; use chrono::DateTime; -use chrono::Duration; use chrono::Utc; use common::HexByte; use serde::de; @@ -29,12 +30,10 @@ use axum::{ extract::{FromRequestParts, OriginalUri, Request}, http::{request::Parts, StatusCode, Uri}, response::IntoResponse, - routing::{get, post, put}, + routing::{get, post}, Extension, Router, }; -use serde_json::json; - #[derive(Clone)] struct RepositoryName(String); @@ -65,7 +64,10 @@ async fn rewrite_url( let uri = req.uri(); let original_uri = OriginalUri(uri.clone()); - let path_and_query = uri.path_and_query().unwrap(); + let Some(path_and_query) = uri.path_and_query() else { + // L @ no path & query + return Err(StatusCode::BAD_REQUEST); + }; let Some((repo, path)) = path_and_query.path().split_once("/info/lfs/objects") else { return Err(StatusCode::NOT_FOUND); }; @@ -73,7 +75,7 @@ async fn rewrite_url( .trim_start_matches('/') .trim_end_matches('/') .to_string(); - if !path.starts_with("/") || !repo.ends_with(".git") { + if !path.starts_with('/') || !repo.ends_with(".git") { return Err(StatusCode::NOT_FOUND); } @@ -82,7 +84,9 @@ async fn rewrite_url( None => path.try_into().ok(), Some(q) => format!("{path}?{q}").try_into().ok(), }; - let new_uri = Uri::from_parts(parts).unwrap(); + let Ok(new_uri) = Uri::from_parts(parts) else { + return Err(StatusCode::INTERNAL_SERVER_ERROR); + }; *req.uri_mut() = new_uri; req.extensions_mut().insert(original_uri); @@ -95,11 +99,45 @@ struct AppState { s3_client: aws_sdk_s3::Client, s3_bucket: String, authz_conf: AuthorizationConfig, + // Should not end with a slash. + base_url: String, +} + +struct Env { + s3_access_key_id: String, + s3_secret_access_key: String, + s3_bucket: String, + s3_endpoint: String, + base_url: String, + key_path: String, + listen_host: String, + listen_port: String, + trusted_forwarded_hosts: String, +} + +fn require_env(name: &str) -> Result { + std::env::var(name).map_err(|_| format!("environment variable {name} should be defined and valid")) +} + +impl Env { + fn load() -> Result { + Ok(Env { + s3_secret_access_key: require_env("GITOLFS3_S3_ACCESS_KEY_FILE")?, + s3_access_key_id: require_env("GITOLFS3_S3_ACCESS_KEY_ID_FILE")?, + s3_endpoint: require_env("GITOLFS3_S3_ENDPOINT")?, + s3_bucket: require_env("GITOLFS3_S3_BUCKET")?, + base_url: require_env("GITOLFS3_BASE_URL")?, + key_path: require_env("GITOLFS3_KEY_PATH")?, + listen_host: require_env("GITOLFS3_LISTEN_HOST")?, + listen_port: require_env("GITOLFS3_LISTEN_PORT")?, + trusted_forwarded_hosts: std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS").unwrap_or_default(), + }) + } } -fn get_s3_client() -> aws_sdk_s3::Client { - let access_key_id = std::env::var("S3_ACCESS_KEY_ID").unwrap(); - let secret_access_key = std::env::var("S3_SECRET_ACCESS_KEY").unwrap(); +fn get_s3_client(env: &Env) -> Result { + let access_key_id = std::fs::read_to_string(&env.s3_access_key_id)?; + let secret_access_key = std::fs::read_to_string(&env.s3_secret_access_key)?; let credentials = aws_sdk_s3::config::Credentials::new( access_key_id, @@ -109,51 +147,85 @@ fn get_s3_client() -> aws_sdk_s3::Client { "gitolfs3-env", ); let config = aws_config::SdkConfig::builder() - .endpoint_url(std::env::var("S3_ENDPOINT").unwrap()) + .endpoint_url(&env.s3_endpoint) .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new( credentials, )) .build(); - aws_sdk_s3::Client::new(&config) + Ok(aws_sdk_s3::Client::new(&config)) } #[tokio::main] -async fn main() { - // run our app with hyper, listening globally on port 3000 - let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); - - let key_path = std::env::var("GITOLFS3_KEY_PATH").unwrap(); - let key = common::load_key(&key_path).unwrap(); - let trusted_forwarded_hosts = std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS").unwrap(); - let trusted_forwarded_hosts: HashSet = trusted_forwarded_hosts +async fn main() -> ExitCode { + let env = match Env::load() { + Ok(env) => env, + Err(e) => { + println!("Failed to load configuration: {e}"); + return ExitCode::from(2); + } + }; + + let s3_client = match get_s3_client(&env) { + Ok(s3_client) => s3_client, + Err(e) => { + println!("Failed to create S3 client: {e}"); + return ExitCode::FAILURE; + }, + }; + let key = match common::load_key(&env.key_path) { + Ok(key) => key, + Err(e) => { + println!("Failed to load Gitolfs3 key: {e}"); + return ExitCode::FAILURE; + } + }; + + let trusted_forwarded_hosts: HashSet = env.trusted_forwarded_hosts .split(',') .map(|s| s.to_owned()) + .filter(|s| !s.is_empty()) .collect(); + let base_url = env.base_url.trim_end_matches('/').to_string(); let authz_conf = AuthorizationConfig { key, trusted_forwarded_hosts, }; - let s3_client = get_s3_client(); - let s3_bucket = std::env::var("S3_BUCKET").unwrap(); let shared_state = Arc::new(AppState { s3_client, - s3_bucket, + s3_bucket: env.s3_bucket, authz_conf, + base_url, }); let app = Router::new() .route("/batch", post(batch)) .route("/:oid0/:oid1/:oid", get(obj_download)) - .route("/:oid0/:oid1/:oid", put(obj_upload)) .with_state(shared_state); let middleware = axum::middleware::map_request(rewrite_url); let app_with_middleware = middleware.layer(app); - axum::serve(listener, app_with_middleware.into_make_service()) - .await - .unwrap(); + let Ok(listen_port): Result = env.listen_port.parse() else { + println!("Configured LISTEN_PORT should be an unsigned integer no higher than 65535"); + return ExitCode::from(2); + }; + let addr: (String, u16) = (env.listen_host, listen_port); + let listener = match tokio::net::TcpListener::bind(addr).await { + Ok(listener) => listener, + Err(e) => { + println!("Failed to listen: {e}"); + return ExitCode::FAILURE; + } + }; + + match axum::serve(listener, app_with_middleware.into_make_service()).await { + Ok(_) => ExitCode::SUCCESS, + Err(e) => { + println!("Error serving: {e}"); + ExitCode::FAILURE + } + } } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)] @@ -174,11 +246,9 @@ impl Default for HashAlgo { } } -type Oid = common::Digest<32>; - #[derive(Debug, Deserialize, Clone)] struct BatchRequestObject { - oid: Oid, + oid: common::Oid, size: i64, } @@ -196,8 +266,6 @@ struct BatchRequest { operation: common::Operation, #[serde(default = "default_transfers")] transfers: Vec, - #[serde(rename = "ref")] - reference: Option, objects: Vec, #[serde(default)] hash_algo: HashAlgo, @@ -206,7 +274,7 @@ struct BatchRequest { #[derive(Clone)] struct GitLfsJson(Json); -const LFS_MIME: &'static str = "application/vnd.git-lfs+json"; +const LFS_MIME: &str = "application/vnd.git-lfs+json"; enum GitLfsJsonRejection { Json(rejection::JsonRejection), @@ -246,7 +314,7 @@ fn has_git_lfs_json_content_type(req: &Request) -> bool { let Ok(content_type) = content_type.to_str() else { return false; }; - return is_git_lfs_json_mimetype(content_type); + is_git_lfs_json_mimetype(content_type) } #[async_trait] @@ -287,7 +355,7 @@ struct GitLfsErrorData<'a> { type GitLfsErrorResponse<'a> = (StatusCode, GitLfsJson>); -const fn make_error_resp<'a>(code: StatusCode, message: &'a str) -> GitLfsErrorResponse { +const fn make_error_resp(code: StatusCode, message: &str) -> GitLfsErrorResponse { (code, GitLfsJson(Json(GitLfsErrorData { message }))) } @@ -309,13 +377,36 @@ struct BatchResponseObjectActions { verify: Option, } +#[derive(Debug, Clone, Serialize)] +struct BatchResponseObjectError { + code: u16, + message: String, +} + #[derive(Debug, Serialize, Clone)] struct BatchResponseObject { - oid: Oid, + oid: common::Oid, size: i64, #[serde(skip_serializing_if = "Option::is_none")] authenticated: Option, actions: BatchResponseObjectActions, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} + +impl BatchResponseObject { + fn error(obj: &BatchRequestObject, code: StatusCode, message: String) -> BatchResponseObject { + BatchResponseObject { + oid: obj.oid, + size: obj.size, + authenticated: None, + actions: Default::default(), + error: Some(BatchResponseObjectError { + code: code.as_u16(), + message, + }), + } + } } #[derive(Debug, Serialize, Clone)] @@ -325,11 +416,11 @@ struct BatchResponse { hash_algo: HashAlgo, } -fn validate_checksum(oid: Oid, obj: &HeadObjectOutput) -> bool { +fn validate_checksum(oid: common::Oid, obj: &HeadObjectOutput) -> bool { if let Some(checksum) = obj.checksum_sha256() { if let Ok(checksum) = BASE64_STANDARD.decode(checksum) { if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) { - return Oid::from(checksum32b) == oid; + return common::Oid::from(checksum32b) == oid; } } } @@ -343,11 +434,15 @@ fn validate_size(expected: i64, obj: &HeadObjectOutput) -> bool { true } -async fn handle_download_object(state: &AppState, repo: &str, obj: &BatchRequestObject, trusted: bool) -> BatchResponseObject { +async fn handle_upload_object( + state: &AppState, + repo: &str, + obj: &BatchRequestObject, +) -> Option { let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); - let result = state + match state .s3_client .head_object() .bucket(&state.s3_bucket) @@ -355,36 +450,189 @@ async fn handle_download_object(state: &AppState, repo: &str, obj: &BatchRequest .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled) .send() .await - .unwrap(); // TODO: don't unwrap() + { + Ok(result) => { + if validate_size(obj.size, &result) && validate_checksum(obj.oid, &result) { + return None; + } + } + Err(SdkError::ServiceError(e)) if e.err().is_not_found() => {} + _ => { + return Some(BatchResponseObject::error( + obj, + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to query object information".to_string(), + )); + } + }; + + let expires_in = std::time::Duration::from_secs(5 * 60); + let expires_at = Utc::now() + expires_in; + + let Ok(config) = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in) else { + return Some(BatchResponseObject::error( + obj, + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to generate upload URL".to_string(), + )); + }; + let Ok(presigned) = state + .s3_client + .put_object() + .checksum_sha256(obj.oid.to_string()) + .content_length(obj.size) + .presigned(config) + .await + else { + return Some(BatchResponseObject::error( + obj, + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to generate upload URL".to_string(), + )); + }; + Some(BatchResponseObject { + oid: obj.oid, + size: obj.size, + authenticated: Some(true), + actions: BatchResponseObjectActions { + download: Some(BatchResponseObjectAction { + header: presigned + .headers() + .map(|(k, v)| (k.to_owned(), v.to_owned())) + .collect(), + expires_at, + href: presigned.uri().to_string(), + }), + ..Default::default() + }, + error: None, + }) +} + +async fn handle_download_object( + state: &AppState, + repo: &str, + obj: &BatchRequestObject, + trusted: bool, +) -> BatchResponseObject { + let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); + let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); + + let result = match state + .s3_client + .head_object() + .bucket(&state.s3_bucket) + .key(full_path) + .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled) + .send() + .await + { + Ok(result) => result, + Err(_) => { + return BatchResponseObject::error( + obj, + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to query object information".to_string(), + ) + } + }; + // Scaleway actually doesn't provide SHA256 suport, but maybe in the future :) if !validate_checksum(obj.oid, &result) { - todo!(); + return BatchResponseObject::error( + obj, + StatusCode::UNPROCESSABLE_ENTITY, + "Object corrupted".to_string(), + ); } if !validate_size(obj.size, &result) { - todo!(); + return BatchResponseObject::error( + obj, + StatusCode::UNPROCESSABLE_ENTITY, + "Incorrect size specified (or object corrupted)".to_string(), + ); } let expires_in = std::time::Duration::from_secs(5 * 60); let expires_at = Utc::now() + expires_in; if trusted { - let config = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in).unwrap(); - let presigned = state.s3_client.get_object().presigned(config).await.unwrap(); - return BatchResponseObject{ + let Ok(config) = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in) else { + return BatchResponseObject::error( + obj, + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to generate upload URL".to_string(), + ); + }; + let Ok(presigned) = state.s3_client.get_object().presigned(config).await else { + return BatchResponseObject::error( + obj, + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to generate upload URL".to_string(), + ); + }; + return BatchResponseObject { oid: obj.oid, size: obj.size, authenticated: Some(true), actions: BatchResponseObjectActions { - download: Some(BatchResponseObjectAction{ - header: presigned.headers().map(|(k, v)| (k.to_owned(), v.to_owned())).collect(), + download: Some(BatchResponseObjectAction { + header: presigned + .headers() + .map(|(k, v)| (k.to_owned(), v.to_owned())) + .collect(), expires_at, href: presigned.uri().to_string(), }), ..Default::default() - } + }, + error: None, }; } - todo!(); + + let Some(tag) = common::generate_tag( + common::Claims { + specific_claims: common::SpecificClaims::Download(obj.oid), + repo_path: repo, + expires_at, + }, + &state.authz_conf.key, + ) else { + return BatchResponseObject::error( + obj, + StatusCode::INTERNAL_SERVER_ERROR, + "Internal server error".to_string(), + ); + }; + + let upload_path = format!( + "{repo}/info/lfs/objects/{}/{}/{}", + HexByte(obj.oid[0]), + HexByte(obj.oid[1]), + obj.oid, + ); + + BatchResponseObject { + oid: obj.oid, + size: obj.size, + authenticated: Some(true), + actions: BatchResponseObjectActions { + download: Some(BatchResponseObjectAction { + header: { + let mut map = HashMap::new(); + map.insert( + "Authorization".to_string(), + format!("Gitolfs3-Hmac-Sha256 {tag} {}", expires_at.timestamp()), + ); + map + }, + expires_at, + href: format!("{}/{upload_path}", state.base_url), + }), + ..Default::default() + }, + error: None, + } } struct AuthorizationConfig { @@ -410,17 +658,17 @@ fn forwarded_for_trusted_host( )); } } - return Ok(false); + Ok(false) } const REPO_NOT_FOUND: GitLfsErrorResponse = make_error_resp(StatusCode::NOT_FOUND, "Repository not found"); -fn authorize( +fn authorize_batch( conf: &AuthorizationConfig, - headers: &HeaderMap, repo_path: &str, public: bool, operation: common::Operation, + headers: &HeaderMap, ) -> Result> { // - No authentication required for downloading exported repos // - When authenticated: @@ -428,46 +676,12 @@ fn authorize( // - When accessing over Tailscale: // - No authentication required for downloading from any repo - const INVALID_AUTHZ_HEADER: GitLfsErrorResponse = - make_error_resp(StatusCode::BAD_REQUEST, "Invalid authorization header"); - - if let Some(authz) = headers.get(header::AUTHORIZATION) { - if let Ok(authz) = authz.to_str() { - if let Some(val) = authz.strip_prefix("Gitolfs3-Hmac-Sha256 ") { - let Some((tag, expires_at)) = val.split_once(' ') else { - return Err(INVALID_AUTHZ_HEADER); - }; - let Ok(tag): Result, _> = tag.parse() else { - return Err(INVALID_AUTHZ_HEADER); - }; - let Ok(expires_at): Result = expires_at.parse() else { - return Err(INVALID_AUTHZ_HEADER); - }; - let Some(expires_at) = DateTime::::from_timestamp(expires_at, 0) else { - return Err(INVALID_AUTHZ_HEADER); - }; - let Some(expected_tag) = common::generate_tag( - common::Claims { - auth_type: common::AuthType::GitLfsAuthenticate, - repo_path, - expires_at, - operation, - }, - &conf.key, - ) else { - return Err(INVALID_AUTHZ_HEADER); - }; - if tag == expected_tag { - return Ok(Trusted(true)); - } else { - return Err(INVALID_AUTHZ_HEADER); - } - } else { - return Err(INVALID_AUTHZ_HEADER); - } - } else { - return Err(INVALID_AUTHZ_HEADER); - } + let claims = VerifyClaimsInput { + specific_claims: common::SpecificClaims::BatchApi(operation), + repo_path, + }; + if verify_claims(conf, &claims, headers)? { + return Ok(Trusted(true)); } let trusted = forwarded_for_trusted_host(headers, &conf.trusted_forwarded_hosts)?; @@ -495,7 +709,7 @@ fn repo_exists(name: &str) -> bool { let Ok(metadata) = std::fs::metadata(name) else { return false; }; - return metadata.is_dir(); + metadata.is_dir() } fn is_repo_public(name: &str) -> Option { @@ -517,12 +731,12 @@ async fn batch( let Some(public) = is_repo_public(&repo) else { return REPO_NOT_FOUND.into_response(); }; - let Trusted(trusted) = match authorize( + let Trusted(trusted) = match authorize_batch( &state.authz_conf, - &headers, &repo, public, payload.operation, + &headers, ) { Ok(authn) => authn, Err(e) => return e.into_response(), @@ -547,16 +761,24 @@ async fn batch( return make_error_resp(StatusCode::CONFLICT, message).into_response(); } - let resp: BatchResponse; + let mut resp = BatchResponse { + transfer: TransferAdapter::Basic, + objects: vec![], + hash_algo: HashAlgo::Sha256, + }; for obj in payload.objects { - handle_download_object(&state, &repo, &obj, trusted).await; - // match payload.operation { - // Operation::Download => resp.objects.push(handle_download_object(repo, obj));, - // Operation::Upload => resp.objects.push(handle_upload_object(repo, obj)), - // }; + match payload.operation { + common::Operation::Download => resp + .objects + .push(handle_download_object(&state, &repo, &obj, trusted).await), + common::Operation::Upload => { + if let Some(obj_resp) = handle_upload_object(&state, &repo, &obj).await { + resp.objects.push(obj_resp); + } + } + }; } - - format!("hi from {repo}\n").into_response() + GitLfsJson(Json(resp)).into_response() } #[derive(Deserialize, Copy, Clone)] @@ -564,7 +786,7 @@ async fn batch( struct FileParams { oid0: HexByte, oid1: HexByte, - oid: Oid, + oid: common::Oid, } impl<'de> Deserialize<'de> for FileParams { @@ -591,6 +813,118 @@ impl<'de> Deserialize<'de> for FileParams { } } -async fn obj_download(Path(FileParams { oid0, oid1, oid }): Path) {} +pub struct VerifyClaimsInput<'a> { + pub specific_claims: common::SpecificClaims, + pub repo_path: &'a str, +} + +// Note: expires_at is ignored. +fn verify_claims( + conf: &AuthorizationConfig, + claims: &VerifyClaimsInput, + headers: &HeaderMap, +) -> Result> { + const INVALID_AUTHZ_HEADER: GitLfsErrorResponse = + make_error_resp(StatusCode::BAD_REQUEST, "Invalid authorization header"); + + if let Some(authz) = headers.get(header::AUTHORIZATION) { + if let Ok(authz) = authz.to_str() { + if let Some(val) = authz.strip_prefix("Gitolfs3-Hmac-Sha256 ") { + let (tag, expires_at) = val.split_once(' ').ok_or(INVALID_AUTHZ_HEADER)?; + let tag: common::Digest<32> = tag.parse().map_err(|_| INVALID_AUTHZ_HEADER)?; + let expires_at: i64 = expires_at.parse().map_err(|_| INVALID_AUTHZ_HEADER)?; + let expires_at = + DateTime::::from_timestamp(expires_at, 0).ok_or(INVALID_AUTHZ_HEADER)?; + let Some(expected_tag) = common::generate_tag( + common::Claims { + specific_claims: claims.specific_claims, + repo_path: claims.repo_path, + expires_at, + }, + &conf.key, + ) else { + return Err(make_error_resp( + StatusCode::INTERNAL_SERVER_ERROR, + "Internal server error", + )); + }; + if tag == expected_tag { + return Ok(true); + } + } + } + return Err(INVALID_AUTHZ_HEADER); + } + Ok(false) +} + +fn authorize_get( + conf: &AuthorizationConfig, + repo_path: &str, + oid: common::Oid, + headers: &HeaderMap, +) -> Result<(), GitLfsErrorResponse<'static>> { + let claims = VerifyClaimsInput { + specific_claims: common::SpecificClaims::Download(oid), + repo_path, + }; + if !verify_claims(conf, &claims, headers)? { + return Err(make_error_resp( + StatusCode::UNAUTHORIZED, + "Repository not found", + )); + } + Ok(()) +} -async fn obj_upload(Path(FileParams { oid0, oid1, oid }): Path) {} +async fn obj_download( + State(state): State>, + headers: HeaderMap, + RepositoryName(repo): RepositoryName, + Path(FileParams { oid0, oid1, oid }): Path, +) -> Response { + if let Err(e) = authorize_get(&state.authz_conf, &repo, oid, &headers) { + return e.into_response(); + } + + let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, oid); + let result = match state + .s3_client + .get_object() + .bucket(&state.s3_bucket) + .key(full_path) + .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled) + .send() + .await + { + Ok(result) => result, + Err(_) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to query object information", + ) + .into_response(); + } + }; + + let mut headers = header::HeaderMap::new(); + if let Some(content_type) = result.content_type { + let Ok(header_value) = content_type.try_into() else { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + "Object has invalid content type", + ) + .into_response(); + }; + headers.insert(header::CONTENT_TYPE, header_value); + } + if let Some(content_length) = result.content_length { + headers.insert(header::CONTENT_LENGTH, content_length.into()); + } + + let async_read = result.body.into_async_read(); + let stream = tokio_util::io::ReaderStream::new(async_read); + let body = axum::body::Body::from_stream(stream); + + (headers, body).into_response() +} -- cgit v1.2.3