diff options
author | Rutger Broekhoff | 2024-01-19 01:53:04 +0100 |
---|---|---|
committer | Rutger Broekhoff | 2024-01-19 01:53:04 +0100 |
commit | fa2a763777d775a853ec41d73efb345e304c81e7 (patch) | |
tree | 0735621337ee93e41e00a0d52fbdabff93b9e8fc /rs/server | |
parent | 5156228f18f08255a1f5c7e22097b8e367881e19 (diff) | |
download | gitolfs3-fa2a763777d775a853ec41d73efb345e304c81e7.tar.gz gitolfs3-fa2a763777d775a853ec41d73efb345e304c81e7.zip |
Separate path rewrite into Tower Layer, shared state for S3 client
Diffstat (limited to 'rs/server')
-rw-r--r-- | rs/server/Cargo.toml | 5 | ||||
-rw-r--r-- | rs/server/src/main.rs | 133 |
2 files changed, 88 insertions, 50 deletions
diff --git a/rs/server/Cargo.toml b/rs/server/Cargo.toml index ac571af..188299a 100644 --- a/rs/server/Cargo.toml +++ b/rs/server/Cargo.toml | |||
@@ -4,12 +4,15 @@ version = "0.1.0" | |||
4 | edition = "2021" | 4 | edition = "2021" |
5 | 5 | ||
6 | [dependencies] | 6 | [dependencies] |
7 | aws-config = { version = "1.1.2", features = ["behavior-version-latest"] } | ||
8 | aws-sdk-s3 = "1.12.0" | ||
7 | axum = "0.7" | 9 | axum = "0.7" |
8 | aws-creds = "0.34" | 10 | aws-creds = "0.34" |
11 | base64 = "0.21" | ||
9 | chrono = { version = "0.4", features = ["serde"] } | 12 | chrono = { version = "0.4", features = ["serde"] } |
10 | common = { path = "../common" } | 13 | common = { path = "../common" } |
11 | mime = "0.3" | 14 | mime = "0.3" |
12 | rust-s3 = "0.33" | ||
13 | serde = { version = "1", features = ["derive"] } | 15 | serde = { version = "1", features = ["derive"] } |
14 | tokio = { version = "1.35", features = ["full"] } | 16 | tokio = { version = "1.35", features = ["full"] } |
17 | tower = "0.4" | ||
15 | tower-service = "0.3" | 18 | tower-service = "0.3" |
diff --git a/rs/server/src/main.rs b/rs/server/src/main.rs index 8fe1d16..7d8b83a 100644 --- a/rs/server/src/main.rs +++ b/rs/server/src/main.rs | |||
@@ -1,6 +1,7 @@ | |||
1 | use std::collections::HashMap; | 1 | use std::collections::HashMap; |
2 | use std::sync::Arc; | ||
2 | 3 | ||
3 | use awscreds::Credentials; | 4 | use axum::extract::State; |
4 | use axum::extract::rejection; | 5 | use axum::extract::rejection; |
5 | use axum::extract::FromRequest; | 6 | use axum::extract::FromRequest; |
6 | use axum::extract::Path; | 7 | use axum::extract::Path; |
@@ -9,23 +10,23 @@ use axum::http::HeaderMap; | |||
9 | use axum::http::HeaderValue; | 10 | use axum::http::HeaderValue; |
10 | use axum::response::Response; | 11 | use axum::response::Response; |
11 | use axum::Json; | 12 | use axum::Json; |
13 | use base64::prelude::*; | ||
12 | use chrono::DateTime; | 14 | use chrono::DateTime; |
13 | use chrono::Utc; | 15 | use chrono::Utc; |
14 | use common::HexByte; | 16 | use common::HexByte; |
15 | use common::Operation; | ||
16 | use s3::Bucket; | ||
17 | use serde::de; | 17 | use serde::de; |
18 | use serde::de::DeserializeOwned; | 18 | use serde::de::DeserializeOwned; |
19 | use serde::Deserialize; | 19 | use serde::Deserialize; |
20 | use serde::Serialize; | 20 | use serde::Serialize; |
21 | use tower_service::Service; | 21 | use tower::Layer; |
22 | use axum::ServiceExt; | ||
22 | 23 | ||
23 | use axum::{ | 24 | use axum::{ |
24 | async_trait, | 25 | async_trait, |
25 | extract::{FromRequestParts, OriginalUri, Request}, | 26 | extract::{FromRequestParts, OriginalUri, Request}, |
26 | http::{request::Parts, StatusCode, Uri}, | 27 | http::{request::Parts, StatusCode, Uri}, |
27 | response::IntoResponse, | 28 | response::IntoResponse, |
28 | routing::{any, get, post, put}, | 29 | routing::{get, post, put}, |
29 | Extension, Router, | 30 | Extension, Router, |
30 | }; | 31 | }; |
31 | 32 | ||
@@ -53,46 +54,71 @@ impl<S: Send + Sync> FromRequestParts<S> for RepositoryName { | |||
53 | } | 54 | } |
54 | } | 55 | } |
55 | 56 | ||
57 | async fn rewrite_url<B>(mut req: axum::http::Request<B>) -> Result<axum::http::Request<B>, StatusCode> { | ||
58 | let uri = req.uri(); | ||
59 | let original_uri = OriginalUri(uri.clone()); | ||
60 | |||
61 | let path_and_query = uri.path_and_query().unwrap(); | ||
62 | let Some((repo, path)) = path_and_query.path().split_once("/info/lfs/objects") else { | ||
63 | return Err(StatusCode::NOT_FOUND); | ||
64 | }; | ||
65 | let repo = repo | ||
66 | .trim_start_matches('/') | ||
67 | .trim_end_matches('/') | ||
68 | .to_string(); | ||
69 | if !path.starts_with("/") || !repo.ends_with(".git") { | ||
70 | return Err(StatusCode::NOT_FOUND); | ||
71 | } | ||
72 | |||
73 | let mut parts = uri.clone().into_parts(); | ||
74 | parts.path_and_query = match path_and_query.query() { | ||
75 | None => path.try_into().ok(), | ||
76 | Some(q) => format!("{path}?{q}").try_into().ok(), | ||
77 | }; | ||
78 | let new_uri = Uri::from_parts(parts).unwrap(); | ||
79 | |||
80 | *req.uri_mut() = new_uri; | ||
81 | req.extensions_mut().insert(original_uri); | ||
82 | req.extensions_mut().insert(RepositoryName(repo)); | ||
83 | |||
84 | Ok(req) | ||
85 | } | ||
86 | |||
87 | struct AppState { | ||
88 | s3_client: aws_sdk_s3::Client, | ||
89 | s3_bucket: String, | ||
90 | } | ||
91 | |||
92 | fn get_s3_client() -> aws_sdk_s3::Client { | ||
93 | let access_key_id = std::env::var("S3_ACCESS_KEY_ID").unwrap(); | ||
94 | let secret_access_key = std::env::var("S3_SECRET_ACCESS_KEY").unwrap(); | ||
95 | |||
96 | let credentials = aws_sdk_s3::config::Credentials::new(access_key_id, secret_access_key, None, None, "gitolfs3-env"); | ||
97 | let config = aws_config::SdkConfig::builder() | ||
98 | .endpoint_url(std::env::var("S3_ENDPOINT").unwrap()) | ||
99 | .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new(credentials)) | ||
100 | .build(); | ||
101 | aws_sdk_s3::Client::new(&config) | ||
102 | } | ||
103 | |||
56 | #[tokio::main] | 104 | #[tokio::main] |
57 | async fn main() { | 105 | async fn main() { |
58 | // run our app with hyper, listening globally on port 3000 | 106 | // run our app with hyper, listening globally on port 3000 |
59 | let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); | 107 | let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); |
60 | let mut app = Router::new() | 108 | |
109 | let s3_client = get_s3_client(); | ||
110 | let s3_bucket = std::env::var("S3_BUCKET").unwrap(); | ||
111 | let shared_state = Arc::new(AppState { s3_client, s3_bucket }); | ||
112 | let app = Router::new() | ||
61 | .route("/batch", post(batch)) | 113 | .route("/batch", post(batch)) |
62 | .route("/:oid0/:oid1/:oid", get(obj_download)) | 114 | .route("/:oid0/:oid1/:oid", get(obj_download)) |
63 | .route("/:oid0/:oid1/:oid", put(obj_upload)); | 115 | .route("/:oid0/:oid1/:oid", put(obj_upload)) |
64 | axum::serve( | 116 | .with_state(shared_state); |
65 | listener, | ||
66 | any(|mut req: Request| async move { | ||
67 | let uri = req.uri(); | ||
68 | let original_uri = OriginalUri(uri.clone()); | ||
69 | |||
70 | let path_and_query = uri.path_and_query().unwrap(); | ||
71 | let Some((repo, path)) = path_and_query.path().split_once("/info/lfs/objects") else { | ||
72 | return Ok(StatusCode::NOT_FOUND.into_response()); | ||
73 | }; | ||
74 | let repo = repo | ||
75 | .trim_start_matches('/') | ||
76 | .trim_end_matches('/') | ||
77 | .to_string(); | ||
78 | if !path.starts_with("/") || !repo.ends_with(".git") { | ||
79 | return Ok(StatusCode::NOT_FOUND.into_response()); | ||
80 | } | ||
81 | 117 | ||
82 | let mut parts = uri.clone().into_parts(); | 118 | let middleware = axum::middleware::map_request(rewrite_url); |
83 | parts.path_and_query = match path_and_query.query() { | 119 | let app_with_middleware = middleware.layer(app); |
84 | None => path.try_into().ok(), | ||
85 | Some(q) => format!("{path}?{q}").try_into().ok(), | ||
86 | }; | ||
87 | let new_uri = Uri::from_parts(parts).unwrap(); | ||
88 | 120 | ||
89 | *req.uri_mut() = new_uri; | 121 | axum::serve(listener, app_with_middleware.into_make_service()) |
90 | req.extensions_mut().insert(original_uri); | ||
91 | req.extensions_mut().insert(RepositoryName(repo)); | ||
92 | |||
93 | app.call(req).await | ||
94 | }), | ||
95 | ) | ||
96 | .await | 122 | .await |
97 | .unwrap(); | 123 | .unwrap(); |
98 | } | 124 | } |
@@ -254,20 +280,28 @@ struct BatchResponse { | |||
254 | hash_algo: HashAlgo, | 280 | hash_algo: HashAlgo, |
255 | } | 281 | } |
256 | 282 | ||
257 | //fn handle_download_object(repo: &str, obj: &BatchRequestObject) { | 283 | async fn handle_download_object(state: &AppState, repo: &str, obj: &BatchRequestObject) { |
258 | // let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); | 284 | let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); |
259 | // let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); | 285 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); |
260 | // | 286 | |
261 | // let bucket_anme = "asdfasdf"; | 287 | let result = state.s3_client.head_object(). |
262 | // let region = s3::Region::Custom { | 288 | bucket(&state.s3_bucket). |
263 | // region: "nl-ams".to_string(), | 289 | key(full_path). |
264 | // endpoint: "rg.nl-ams.swc.cloud".to_string() | 290 | checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled). |
265 | // }; | 291 | send().await.unwrap(); |
266 | // let credentials = Credentials::new(None, None, None, None, None).unwrap(); | 292 | if let Some(checksum) = result.checksum_sha256() { |
267 | // let bucket = Bucket::new(bucket_anme, region, credentials).unwrap(); | 293 | if let Ok(checksum) = BASE64_STANDARD.decode(checksum) { |
268 | //} | 294 | if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) { |
295 | if Oid::from(checksum32b) != obj.oid { | ||
296 | unreachable!(); | ||
297 | } | ||
298 | } | ||
299 | } | ||
300 | } | ||
301 | } | ||
269 | 302 | ||
270 | async fn batch( | 303 | async fn batch( |
304 | State(state): State<Arc<AppState>>, | ||
271 | header: HeaderMap, | 305 | header: HeaderMap, |
272 | RepositoryName(repo): RepositoryName, | 306 | RepositoryName(repo): RepositoryName, |
273 | GitLfsJson(Json(payload)): GitLfsJson<BatchRequest>, | 307 | GitLfsJson(Json(payload)): GitLfsJson<BatchRequest>, |
@@ -301,6 +335,7 @@ async fn batch( | |||
301 | 335 | ||
302 | let resp: BatchResponse; | 336 | let resp: BatchResponse; |
303 | for obj in payload.objects { | 337 | for obj in payload.objects { |
338 | handle_download_object(&state, &repo, &obj).await; | ||
304 | // match payload.operation { | 339 | // match payload.operation { |
305 | // Operation::Download => resp.objects.push(handle_download_object(repo, obj));, | 340 | // Operation::Download => resp.objects.push(handle_download_object(repo, obj));, |
306 | // Operation::Upload => resp.objects.push(handle_upload_object(repo, obj)), | 341 | // Operation::Upload => resp.objects.push(handle_upload_object(repo, obj)), |