aboutsummaryrefslogtreecommitdiffstats
path: root/rs/server
diff options
context:
space:
mode:
authorLibravatar Rutger Broekhoff2024-01-19 01:53:04 +0100
committerLibravatar Rutger Broekhoff2024-01-19 01:53:04 +0100
commitfa2a763777d775a853ec41d73efb345e304c81e7 (patch)
tree0735621337ee93e41e00a0d52fbdabff93b9e8fc /rs/server
parent5156228f18f08255a1f5c7e22097b8e367881e19 (diff)
downloadgitolfs3-fa2a763777d775a853ec41d73efb345e304c81e7.tar.gz
gitolfs3-fa2a763777d775a853ec41d73efb345e304c81e7.zip
Separate path rewrite into Tower Layer, shared state for S3 client
Diffstat (limited to 'rs/server')
-rw-r--r--rs/server/Cargo.toml5
-rw-r--r--rs/server/src/main.rs133
2 files changed, 88 insertions, 50 deletions
diff --git a/rs/server/Cargo.toml b/rs/server/Cargo.toml
index ac571af..188299a 100644
--- a/rs/server/Cargo.toml
+++ b/rs/server/Cargo.toml
@@ -4,12 +4,15 @@ version = "0.1.0"
4edition = "2021" 4edition = "2021"
5 5
6[dependencies] 6[dependencies]
7aws-config = { version = "1.1.2", features = ["behavior-version-latest"] }
8aws-sdk-s3 = "1.12.0"
7axum = "0.7" 9axum = "0.7"
8aws-creds = "0.34" 10aws-creds = "0.34"
11base64 = "0.21"
9chrono = { version = "0.4", features = ["serde"] } 12chrono = { version = "0.4", features = ["serde"] }
10common = { path = "../common" } 13common = { path = "../common" }
11mime = "0.3" 14mime = "0.3"
12rust-s3 = "0.33"
13serde = { version = "1", features = ["derive"] } 15serde = { version = "1", features = ["derive"] }
14tokio = { version = "1.35", features = ["full"] } 16tokio = { version = "1.35", features = ["full"] }
17tower = "0.4"
15tower-service = "0.3" 18tower-service = "0.3"
diff --git a/rs/server/src/main.rs b/rs/server/src/main.rs
index 8fe1d16..7d8b83a 100644
--- a/rs/server/src/main.rs
+++ b/rs/server/src/main.rs
@@ -1,6 +1,7 @@
1use std::collections::HashMap; 1use std::collections::HashMap;
2use std::sync::Arc;
2 3
3use awscreds::Credentials; 4use axum::extract::State;
4use axum::extract::rejection; 5use axum::extract::rejection;
5use axum::extract::FromRequest; 6use axum::extract::FromRequest;
6use axum::extract::Path; 7use axum::extract::Path;
@@ -9,23 +10,23 @@ use axum::http::HeaderMap;
9use axum::http::HeaderValue; 10use axum::http::HeaderValue;
10use axum::response::Response; 11use axum::response::Response;
11use axum::Json; 12use axum::Json;
13use base64::prelude::*;
12use chrono::DateTime; 14use chrono::DateTime;
13use chrono::Utc; 15use chrono::Utc;
14use common::HexByte; 16use common::HexByte;
15use common::Operation;
16use s3::Bucket;
17use serde::de; 17use serde::de;
18use serde::de::DeserializeOwned; 18use serde::de::DeserializeOwned;
19use serde::Deserialize; 19use serde::Deserialize;
20use serde::Serialize; 20use serde::Serialize;
21use tower_service::Service; 21use tower::Layer;
22use axum::ServiceExt;
22 23
23use axum::{ 24use axum::{
24 async_trait, 25 async_trait,
25 extract::{FromRequestParts, OriginalUri, Request}, 26 extract::{FromRequestParts, OriginalUri, Request},
26 http::{request::Parts, StatusCode, Uri}, 27 http::{request::Parts, StatusCode, Uri},
27 response::IntoResponse, 28 response::IntoResponse,
28 routing::{any, get, post, put}, 29 routing::{get, post, put},
29 Extension, Router, 30 Extension, Router,
30}; 31};
31 32
@@ -53,46 +54,71 @@ impl<S: Send + Sync> FromRequestParts<S> for RepositoryName {
53 } 54 }
54} 55}
55 56
57async fn rewrite_url<B>(mut req: axum::http::Request<B>) -> Result<axum::http::Request<B>, StatusCode> {
58 let uri = req.uri();
59 let original_uri = OriginalUri(uri.clone());
60
61 let path_and_query = uri.path_and_query().unwrap();
62 let Some((repo, path)) = path_and_query.path().split_once("/info/lfs/objects") else {
63 return Err(StatusCode::NOT_FOUND);
64 };
65 let repo = repo
66 .trim_start_matches('/')
67 .trim_end_matches('/')
68 .to_string();
69 if !path.starts_with("/") || !repo.ends_with(".git") {
70 return Err(StatusCode::NOT_FOUND);
71 }
72
73 let mut parts = uri.clone().into_parts();
74 parts.path_and_query = match path_and_query.query() {
75 None => path.try_into().ok(),
76 Some(q) => format!("{path}?{q}").try_into().ok(),
77 };
78 let new_uri = Uri::from_parts(parts).unwrap();
79
80 *req.uri_mut() = new_uri;
81 req.extensions_mut().insert(original_uri);
82 req.extensions_mut().insert(RepositoryName(repo));
83
84 Ok(req)
85}
86
87struct AppState {
88 s3_client: aws_sdk_s3::Client,
89 s3_bucket: String,
90}
91
92fn get_s3_client() -> aws_sdk_s3::Client {
93 let access_key_id = std::env::var("S3_ACCESS_KEY_ID").unwrap();
94 let secret_access_key = std::env::var("S3_SECRET_ACCESS_KEY").unwrap();
95
96 let credentials = aws_sdk_s3::config::Credentials::new(access_key_id, secret_access_key, None, None, "gitolfs3-env");
97 let config = aws_config::SdkConfig::builder()
98 .endpoint_url(std::env::var("S3_ENDPOINT").unwrap())
99 .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new(credentials))
100 .build();
101 aws_sdk_s3::Client::new(&config)
102}
103
56#[tokio::main] 104#[tokio::main]
57async fn main() { 105async fn main() {
58 // run our app with hyper, listening globally on port 3000 106 // run our app with hyper, listening globally on port 3000
59 let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); 107 let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
60 let mut app = Router::new() 108
109 let s3_client = get_s3_client();
110 let s3_bucket = std::env::var("S3_BUCKET").unwrap();
111 let shared_state = Arc::new(AppState { s3_client, s3_bucket });
112 let app = Router::new()
61 .route("/batch", post(batch)) 113 .route("/batch", post(batch))
62 .route("/:oid0/:oid1/:oid", get(obj_download)) 114 .route("/:oid0/:oid1/:oid", get(obj_download))
63 .route("/:oid0/:oid1/:oid", put(obj_upload)); 115 .route("/:oid0/:oid1/:oid", put(obj_upload))
64 axum::serve( 116 .with_state(shared_state);
65 listener,
66 any(|mut req: Request| async move {
67 let uri = req.uri();
68 let original_uri = OriginalUri(uri.clone());
69
70 let path_and_query = uri.path_and_query().unwrap();
71 let Some((repo, path)) = path_and_query.path().split_once("/info/lfs/objects") else {
72 return Ok(StatusCode::NOT_FOUND.into_response());
73 };
74 let repo = repo
75 .trim_start_matches('/')
76 .trim_end_matches('/')
77 .to_string();
78 if !path.starts_with("/") || !repo.ends_with(".git") {
79 return Ok(StatusCode::NOT_FOUND.into_response());
80 }
81 117
82 let mut parts = uri.clone().into_parts(); 118 let middleware = axum::middleware::map_request(rewrite_url);
83 parts.path_and_query = match path_and_query.query() { 119 let app_with_middleware = middleware.layer(app);
84 None => path.try_into().ok(),
85 Some(q) => format!("{path}?{q}").try_into().ok(),
86 };
87 let new_uri = Uri::from_parts(parts).unwrap();
88 120
89 *req.uri_mut() = new_uri; 121 axum::serve(listener, app_with_middleware.into_make_service())
90 req.extensions_mut().insert(original_uri);
91 req.extensions_mut().insert(RepositoryName(repo));
92
93 app.call(req).await
94 }),
95 )
96 .await 122 .await
97 .unwrap(); 123 .unwrap();
98} 124}
@@ -254,20 +280,28 @@ struct BatchResponse {
254 hash_algo: HashAlgo, 280 hash_algo: HashAlgo,
255} 281}
256 282
257//fn handle_download_object(repo: &str, obj: &BatchRequestObject) { 283async fn handle_download_object(state: &AppState, repo: &str, obj: &BatchRequestObject) {
258// let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); 284 let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
259// let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); 285 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
260// 286
261// let bucket_anme = "asdfasdf"; 287 let result = state.s3_client.head_object().
262// let region = s3::Region::Custom { 288 bucket(&state.s3_bucket).
263// region: "nl-ams".to_string(), 289 key(full_path).
264// endpoint: "rg.nl-ams.swc.cloud".to_string() 290 checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled).
265// }; 291 send().await.unwrap();
266// let credentials = Credentials::new(None, None, None, None, None).unwrap(); 292 if let Some(checksum) = result.checksum_sha256() {
267// let bucket = Bucket::new(bucket_anme, region, credentials).unwrap(); 293 if let Ok(checksum) = BASE64_STANDARD.decode(checksum) {
268//} 294 if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) {
295 if Oid::from(checksum32b) != obj.oid {
296 unreachable!();
297 }
298 }
299 }
300 }
301}
269 302
270async fn batch( 303async fn batch(
304 State(state): State<Arc<AppState>>,
271 header: HeaderMap, 305 header: HeaderMap,
272 RepositoryName(repo): RepositoryName, 306 RepositoryName(repo): RepositoryName,
273 GitLfsJson(Json(payload)): GitLfsJson<BatchRequest>, 307 GitLfsJson(Json(payload)): GitLfsJson<BatchRequest>,
@@ -301,6 +335,7 @@ async fn batch(
301 335
302 let resp: BatchResponse; 336 let resp: BatchResponse;
303 for obj in payload.objects { 337 for obj in payload.objects {
338 handle_download_object(&state, &repo, &obj).await;
304// match payload.operation { 339// match payload.operation {
305// Operation::Download => resp.objects.push(handle_download_object(repo, obj));, 340// Operation::Download => resp.objects.push(handle_download_object(repo, obj));,
306// Operation::Upload => resp.objects.push(handle_upload_object(repo, obj)), 341// Operation::Upload => resp.objects.push(handle_upload_object(repo, obj)),