From fa2a763777d775a853ec41d73efb345e304c81e7 Mon Sep 17 00:00:00 2001 From: Rutger Broekhoff Date: Fri, 19 Jan 2024 01:53:04 +0100 Subject: Separate path rewrite into Tower Layer, shared state for S3 client --- rs/server/src/main.rs | 133 +++++++++++++++++++++++++++++++------------------- 1 file changed, 84 insertions(+), 49 deletions(-) (limited to 'rs/server/src/main.rs') diff --git a/rs/server/src/main.rs b/rs/server/src/main.rs index 8fe1d16..7d8b83a 100644 --- a/rs/server/src/main.rs +++ b/rs/server/src/main.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; +use std::sync::Arc; -use awscreds::Credentials; +use axum::extract::State; use axum::extract::rejection; use axum::extract::FromRequest; use axum::extract::Path; @@ -9,23 +10,23 @@ use axum::http::HeaderMap; use axum::http::HeaderValue; use axum::response::Response; use axum::Json; +use base64::prelude::*; use chrono::DateTime; use chrono::Utc; use common::HexByte; -use common::Operation; -use s3::Bucket; use serde::de; use serde::de::DeserializeOwned; use serde::Deserialize; use serde::Serialize; -use tower_service::Service; +use tower::Layer; +use axum::ServiceExt; use axum::{ async_trait, extract::{FromRequestParts, OriginalUri, Request}, http::{request::Parts, StatusCode, Uri}, response::IntoResponse, - routing::{any, get, post, put}, + routing::{get, post, put}, Extension, Router, }; @@ -53,46 +54,71 @@ impl FromRequestParts for RepositoryName { } } +async fn rewrite_url(mut req: axum::http::Request) -> Result, StatusCode> { + let uri = req.uri(); + let original_uri = OriginalUri(uri.clone()); + + let path_and_query = uri.path_and_query().unwrap(); + let Some((repo, path)) = path_and_query.path().split_once("/info/lfs/objects") else { + return Err(StatusCode::NOT_FOUND); + }; + let repo = repo + .trim_start_matches('/') + .trim_end_matches('/') + .to_string(); + if !path.starts_with("/") || !repo.ends_with(".git") { + return Err(StatusCode::NOT_FOUND); + } + + let mut parts = uri.clone().into_parts(); + parts.path_and_query = match path_and_query.query() { + None => path.try_into().ok(), + Some(q) => format!("{path}?{q}").try_into().ok(), + }; + let new_uri = Uri::from_parts(parts).unwrap(); + + *req.uri_mut() = new_uri; + req.extensions_mut().insert(original_uri); + req.extensions_mut().insert(RepositoryName(repo)); + + Ok(req) +} + +struct AppState { + s3_client: aws_sdk_s3::Client, + s3_bucket: String, +} + +fn get_s3_client() -> aws_sdk_s3::Client { + let access_key_id = std::env::var("S3_ACCESS_KEY_ID").unwrap(); + let secret_access_key = std::env::var("S3_SECRET_ACCESS_KEY").unwrap(); + + let credentials = aws_sdk_s3::config::Credentials::new(access_key_id, secret_access_key, None, None, "gitolfs3-env"); + let config = aws_config::SdkConfig::builder() + .endpoint_url(std::env::var("S3_ENDPOINT").unwrap()) + .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new(credentials)) + .build(); + aws_sdk_s3::Client::new(&config) +} + #[tokio::main] async fn main() { // run our app with hyper, listening globally on port 3000 let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); - let mut app = Router::new() + + let s3_client = get_s3_client(); + let s3_bucket = std::env::var("S3_BUCKET").unwrap(); + let shared_state = Arc::new(AppState { s3_client, s3_bucket }); + let app = Router::new() .route("/batch", post(batch)) .route("/:oid0/:oid1/:oid", get(obj_download)) - .route("/:oid0/:oid1/:oid", put(obj_upload)); - axum::serve( - listener, - any(|mut req: Request| async move { - let uri = req.uri(); - let original_uri = OriginalUri(uri.clone()); - - let path_and_query = uri.path_and_query().unwrap(); - let Some((repo, path)) = path_and_query.path().split_once("/info/lfs/objects") else { - return Ok(StatusCode::NOT_FOUND.into_response()); - }; - let repo = repo - .trim_start_matches('/') - .trim_end_matches('/') - .to_string(); - if !path.starts_with("/") || !repo.ends_with(".git") { - return Ok(StatusCode::NOT_FOUND.into_response()); - } + .route("/:oid0/:oid1/:oid", put(obj_upload)) + .with_state(shared_state); - let mut parts = uri.clone().into_parts(); - parts.path_and_query = match path_and_query.query() { - None => path.try_into().ok(), - Some(q) => format!("{path}?{q}").try_into().ok(), - }; - let new_uri = Uri::from_parts(parts).unwrap(); + let middleware = axum::middleware::map_request(rewrite_url); + let app_with_middleware = middleware.layer(app); - *req.uri_mut() = new_uri; - req.extensions_mut().insert(original_uri); - req.extensions_mut().insert(RepositoryName(repo)); - - app.call(req).await - }), - ) + axum::serve(listener, app_with_middleware.into_make_service()) .await .unwrap(); } @@ -254,20 +280,28 @@ struct BatchResponse { hash_algo: HashAlgo, } -//fn handle_download_object(repo: &str, obj: &BatchRequestObject) { -// let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); -// let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); -// -// let bucket_anme = "asdfasdf"; -// let region = s3::Region::Custom { -// region: "nl-ams".to_string(), -// endpoint: "rg.nl-ams.swc.cloud".to_string() -// }; -// let credentials = Credentials::new(None, None, None, None, None).unwrap(); -// let bucket = Bucket::new(bucket_anme, region, credentials).unwrap(); -//} +async fn handle_download_object(state: &AppState, repo: &str, obj: &BatchRequestObject) { + let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); + let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); + + let result = state.s3_client.head_object(). + bucket(&state.s3_bucket). + key(full_path). + checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled). + send().await.unwrap(); + if let Some(checksum) = result.checksum_sha256() { + if let Ok(checksum) = BASE64_STANDARD.decode(checksum) { + if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) { + if Oid::from(checksum32b) != obj.oid { + unreachable!(); + } + } + } + } +} async fn batch( + State(state): State>, header: HeaderMap, RepositoryName(repo): RepositoryName, GitLfsJson(Json(payload)): GitLfsJson, @@ -301,6 +335,7 @@ async fn batch( let resp: BatchResponse; for obj in payload.objects { + handle_download_object(&state, &repo, &obj).await; // match payload.operation { // Operation::Download => resp.objects.push(handle_download_object(repo, obj));, // Operation::Upload => resp.objects.push(handle_upload_object(repo, obj)), -- cgit v1.2.3