diff options
Diffstat (limited to 'rs/server/src')
-rw-r--r-- | rs/server/src/main.rs | 347 |
1 files changed, 347 insertions, 0 deletions
diff --git a/rs/server/src/main.rs b/rs/server/src/main.rs new file mode 100644 index 0000000..8fe1d16 --- /dev/null +++ b/rs/server/src/main.rs | |||
@@ -0,0 +1,347 @@ | |||
1 | use std::collections::HashMap; | ||
2 | |||
3 | use awscreds::Credentials; | ||
4 | use axum::extract::rejection; | ||
5 | use axum::extract::FromRequest; | ||
6 | use axum::extract::Path; | ||
7 | use axum::http::header; | ||
8 | use axum::http::HeaderMap; | ||
9 | use axum::http::HeaderValue; | ||
10 | use axum::response::Response; | ||
11 | use axum::Json; | ||
12 | use chrono::DateTime; | ||
13 | use chrono::Utc; | ||
14 | use common::HexByte; | ||
15 | use common::Operation; | ||
16 | use s3::Bucket; | ||
17 | use serde::de; | ||
18 | use serde::de::DeserializeOwned; | ||
19 | use serde::Deserialize; | ||
20 | use serde::Serialize; | ||
21 | use tower_service::Service; | ||
22 | |||
23 | use axum::{ | ||
24 | async_trait, | ||
25 | extract::{FromRequestParts, OriginalUri, Request}, | ||
26 | http::{request::Parts, StatusCode, Uri}, | ||
27 | response::IntoResponse, | ||
28 | routing::{any, get, post, put}, | ||
29 | Extension, Router, | ||
30 | }; | ||
31 | |||
32 | #[derive(Clone)] | ||
33 | struct RepositoryName(String); | ||
34 | |||
35 | struct RepositoryNameRejection; | ||
36 | |||
37 | impl IntoResponse for RepositoryNameRejection { | ||
38 | fn into_response(self) -> Response { | ||
39 | (StatusCode::INTERNAL_SERVER_ERROR, "Missing repository name").into_response() | ||
40 | } | ||
41 | } | ||
42 | |||
43 | #[async_trait] | ||
44 | impl<S: Send + Sync> FromRequestParts<S> for RepositoryName { | ||
45 | type Rejection = RepositoryNameRejection; | ||
46 | |||
47 | async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { | ||
48 | let Ok(Extension(repo_name)) = Extension::<Self>::from_request_parts(parts, state).await | ||
49 | else { | ||
50 | return Err(RepositoryNameRejection); | ||
51 | }; | ||
52 | Ok(repo_name) | ||
53 | } | ||
54 | } | ||
55 | |||
56 | #[tokio::main] | ||
57 | async fn main() { | ||
58 | // run our app with hyper, listening globally on port 3000 | ||
59 | let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); | ||
60 | let mut app = Router::new() | ||
61 | .route("/batch", post(batch)) | ||
62 | .route("/:oid0/:oid1/:oid", get(obj_download)) | ||
63 | .route("/:oid0/:oid1/:oid", put(obj_upload)); | ||
64 | axum::serve( | ||
65 | listener, | ||
66 | any(|mut req: Request| async move { | ||
67 | let uri = req.uri(); | ||
68 | let original_uri = OriginalUri(uri.clone()); | ||
69 | |||
70 | let path_and_query = uri.path_and_query().unwrap(); | ||
71 | let Some((repo, path)) = path_and_query.path().split_once("/info/lfs/objects") else { | ||
72 | return Ok(StatusCode::NOT_FOUND.into_response()); | ||
73 | }; | ||
74 | let repo = repo | ||
75 | .trim_start_matches('/') | ||
76 | .trim_end_matches('/') | ||
77 | .to_string(); | ||
78 | if !path.starts_with("/") || !repo.ends_with(".git") { | ||
79 | return Ok(StatusCode::NOT_FOUND.into_response()); | ||
80 | } | ||
81 | |||
82 | let mut parts = uri.clone().into_parts(); | ||
83 | parts.path_and_query = match path_and_query.query() { | ||
84 | None => path.try_into().ok(), | ||
85 | Some(q) => format!("{path}?{q}").try_into().ok(), | ||
86 | }; | ||
87 | let new_uri = Uri::from_parts(parts).unwrap(); | ||
88 | |||
89 | *req.uri_mut() = new_uri; | ||
90 | req.extensions_mut().insert(original_uri); | ||
91 | req.extensions_mut().insert(RepositoryName(repo)); | ||
92 | |||
93 | app.call(req).await | ||
94 | }), | ||
95 | ) | ||
96 | .await | ||
97 | .unwrap(); | ||
98 | } | ||
99 | |||
100 | #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)] | ||
101 | enum TransferAdapter { | ||
102 | #[serde(rename = "basic")] | ||
103 | Basic, | ||
104 | } | ||
105 | |||
106 | #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)] | ||
107 | enum HashAlgo { | ||
108 | #[serde(rename = "sha256")] | ||
109 | Sha256, | ||
110 | } | ||
111 | |||
112 | impl Default for HashAlgo { | ||
113 | fn default() -> Self { | ||
114 | Self::Sha256 | ||
115 | } | ||
116 | } | ||
117 | |||
118 | type Oid = common::Digest<32>; | ||
119 | |||
120 | #[derive(Debug, Deserialize, Clone)] | ||
121 | struct BatchRequestObject { | ||
122 | oid: Oid, | ||
123 | size: i64, | ||
124 | } | ||
125 | |||
126 | #[derive(Debug, Serialize, Deserialize, Clone)] | ||
127 | struct BatchRef { | ||
128 | name: String, | ||
129 | } | ||
130 | |||
131 | fn default_transfers() -> Vec<TransferAdapter> { | ||
132 | vec![TransferAdapter::Basic] | ||
133 | } | ||
134 | |||
135 | #[derive(Deserialize)] | ||
136 | struct BatchRequest { | ||
137 | operation: common::Operation, | ||
138 | #[serde(default = "default_transfers")] | ||
139 | transfers: Vec<TransferAdapter>, | ||
140 | #[serde(rename = "ref")] | ||
141 | reference: Option<BatchRef>, | ||
142 | objects: Vec<BatchRequestObject>, | ||
143 | #[serde(default)] | ||
144 | hash_algo: HashAlgo, | ||
145 | } | ||
146 | |||
147 | #[derive(Clone)] | ||
148 | struct GitLfsJson<T>(Json<T>); | ||
149 | |||
150 | const LFS_MIME: &'static str = "application/vnd.git-lfs+json"; | ||
151 | |||
152 | enum GitLfsJsonRejection { | ||
153 | Json(rejection::JsonRejection), | ||
154 | MissingGitLfsJsonContentType, | ||
155 | } | ||
156 | |||
157 | impl IntoResponse for GitLfsJsonRejection { | ||
158 | fn into_response(self) -> Response { | ||
159 | ( | ||
160 | StatusCode::UNSUPPORTED_MEDIA_TYPE, | ||
161 | format!("Expected request with `Content-Type: {LFS_MIME}`"), | ||
162 | ) | ||
163 | .into_response() | ||
164 | } | ||
165 | } | ||
166 | |||
167 | fn is_git_lfs_json_mimetype(mimetype: &str) -> bool { | ||
168 | let Ok(mime) = mimetype.parse::<mime::Mime>() else { | ||
169 | return false; | ||
170 | }; | ||
171 | if mime.type_() != mime::APPLICATION | ||
172 | || mime.subtype() != "vnd.git-lfs" | ||
173 | || mime.suffix() != Some(mime::JSON) | ||
174 | { | ||
175 | return false; | ||
176 | } | ||
177 | match mime.get_param(mime::CHARSET) { | ||
178 | Some(mime::UTF_8) | None => true, | ||
179 | Some(_) => false, | ||
180 | } | ||
181 | } | ||
182 | |||
183 | fn has_git_lfs_json_content_type(req: &Request) -> bool { | ||
184 | let Some(content_type) = req.headers().get(header::CONTENT_TYPE) else { | ||
185 | return false; | ||
186 | }; | ||
187 | let Ok(content_type) = content_type.to_str() else { | ||
188 | return false; | ||
189 | }; | ||
190 | return is_git_lfs_json_mimetype(content_type); | ||
191 | } | ||
192 | |||
193 | #[async_trait] | ||
194 | impl<T, S> FromRequest<S> for GitLfsJson<T> | ||
195 | where | ||
196 | T: DeserializeOwned, | ||
197 | S: Send + Sync, | ||
198 | { | ||
199 | type Rejection = GitLfsJsonRejection; | ||
200 | |||
201 | async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> { | ||
202 | if !has_git_lfs_json_content_type(&req) { | ||
203 | return Err(GitLfsJsonRejection::MissingGitLfsJsonContentType); | ||
204 | } | ||
205 | Json::<T>::from_request(req, state) | ||
206 | .await | ||
207 | .map(GitLfsJson) | ||
208 | .map_err(GitLfsJsonRejection::Json) | ||
209 | } | ||
210 | } | ||
211 | |||
212 | impl<T: Serialize> IntoResponse for GitLfsJson<T> { | ||
213 | fn into_response(self) -> Response { | ||
214 | let GitLfsJson(json) = self; | ||
215 | let mut resp = json.into_response(); | ||
216 | resp.headers_mut().insert( | ||
217 | header::CONTENT_TYPE, | ||
218 | HeaderValue::from_static("application/vnd.git-lfs+json"), | ||
219 | ); | ||
220 | resp | ||
221 | } | ||
222 | } | ||
223 | |||
224 | #[derive(Debug, Serialize, Clone)] | ||
225 | struct BatchResponseObjectAction { | ||
226 | href: String, | ||
227 | #[serde(skip_serializing_if = "HashMap::is_empty")] | ||
228 | header: HashMap<String, String>, | ||
229 | expires_at: DateTime<Utc>, | ||
230 | } | ||
231 | |||
232 | #[derive(Debug, Serialize, Clone)] | ||
233 | struct BatchResponseObjectActions { | ||
234 | #[serde(skip_serializing_if = "Option::is_none")] | ||
235 | upload: Option<BatchResponseObjectAction>, | ||
236 | #[serde(skip_serializing_if = "Option::is_none")] | ||
237 | download: Option<BatchResponseObjectAction>, | ||
238 | #[serde(skip_serializing_if = "Option::is_none")] | ||
239 | verify: Option<BatchResponseObjectAction>, | ||
240 | } | ||
241 | |||
242 | #[derive(Debug, Serialize, Clone)] | ||
243 | struct BatchResponseObject { | ||
244 | oid: Oid, | ||
245 | size: i64, | ||
246 | #[serde(skip_serializing_if = "Option::is_none")] | ||
247 | authenticated: Option<bool>, | ||
248 | } | ||
249 | |||
250 | #[derive(Debug, Serialize, Clone)] | ||
251 | struct BatchResponse { | ||
252 | transfer: TransferAdapter, | ||
253 | objects: Vec<BatchResponseObject>, | ||
254 | hash_algo: HashAlgo, | ||
255 | } | ||
256 | |||
257 | //fn handle_download_object(repo: &str, obj: &BatchRequestObject) { | ||
258 | // let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); | ||
259 | // let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); | ||
260 | // | ||
261 | // let bucket_anme = "asdfasdf"; | ||
262 | // let region = s3::Region::Custom { | ||
263 | // region: "nl-ams".to_string(), | ||
264 | // endpoint: "rg.nl-ams.swc.cloud".to_string() | ||
265 | // }; | ||
266 | // let credentials = Credentials::new(None, None, None, None, None).unwrap(); | ||
267 | // let bucket = Bucket::new(bucket_anme, region, credentials).unwrap(); | ||
268 | //} | ||
269 | |||
270 | async fn batch( | ||
271 | header: HeaderMap, | ||
272 | RepositoryName(repo): RepositoryName, | ||
273 | GitLfsJson(Json(payload)): GitLfsJson<BatchRequest>, | ||
274 | ) -> Response { | ||
275 | if !header | ||
276 | .get_all("Accept") | ||
277 | .iter() | ||
278 | .filter_map(|v| v.to_str().ok()) | ||
279 | .any(is_git_lfs_json_mimetype) | ||
280 | { | ||
281 | return ( | ||
282 | StatusCode::NOT_ACCEPTABLE, | ||
283 | format!("Expected `{LFS_MIME}` (with UTF-8 charset) in list of acceptable response media types"), | ||
284 | ).into_response(); | ||
285 | } | ||
286 | |||
287 | if payload.hash_algo != HashAlgo::Sha256 { | ||
288 | return ( | ||
289 | StatusCode::CONFLICT, | ||
290 | "Unsupported hashing algorithm speicifed", | ||
291 | ) | ||
292 | .into_response(); | ||
293 | } | ||
294 | if !payload.transfers.is_empty() && !payload.transfers.contains(&TransferAdapter::Basic) { | ||
295 | return ( | ||
296 | StatusCode::CONFLICT, | ||
297 | "Unsupported transfer adapter specified (supported: basic)", | ||
298 | ) | ||
299 | .into_response(); | ||
300 | } | ||
301 | |||
302 | let resp: BatchResponse; | ||
303 | for obj in payload.objects { | ||
304 | // match payload.operation { | ||
305 | // Operation::Download => resp.objects.push(handle_download_object(repo, obj));, | ||
306 | // Operation::Upload => resp.objects.push(handle_upload_object(repo, obj)), | ||
307 | // }; | ||
308 | } | ||
309 | |||
310 | format!("hi from {repo}\n").into_response() | ||
311 | } | ||
312 | |||
313 | #[derive(Deserialize, Copy, Clone)] | ||
314 | #[serde(remote = "Self")] | ||
315 | struct FileParams { | ||
316 | oid0: HexByte, | ||
317 | oid1: HexByte, | ||
318 | oid: Oid, | ||
319 | } | ||
320 | |||
321 | impl<'de> Deserialize<'de> for FileParams { | ||
322 | fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||
323 | where | ||
324 | D: serde::Deserializer<'de>, | ||
325 | { | ||
326 | let unchecked @ FileParams { | ||
327 | oid0: HexByte(oid0), | ||
328 | oid1: HexByte(oid1), | ||
329 | oid, | ||
330 | } = FileParams::deserialize(deserializer)?; | ||
331 | if oid0 != oid.as_bytes()[0] { | ||
332 | return Err(de::Error::custom( | ||
333 | "first OID path part does not match first byte of full OID", | ||
334 | )); | ||
335 | } | ||
336 | if oid1 != oid.as_bytes()[1] { | ||
337 | return Err(de::Error::custom( | ||
338 | "second OID path part does not match first byte of full OID", | ||
339 | )); | ||
340 | } | ||
341 | Ok(unchecked) | ||
342 | } | ||
343 | } | ||
344 | |||
345 | async fn obj_download(Path(FileParams { oid0, oid1, oid }): Path<FileParams>) {} | ||
346 | |||
347 | async fn obj_upload(Path(FileParams { oid0, oid1, oid }): Path<FileParams>) {} | ||