diff options
Diffstat (limited to 'gitolfs3-server/src')
-rw-r--r-- | gitolfs3-server/src/handler.rs | 122 |
1 files changed, 68 insertions, 54 deletions
diff --git a/gitolfs3-server/src/handler.rs b/gitolfs3-server/src/handler.rs index b9f9bcf..64d5492 100644 --- a/gitolfs3-server/src/handler.rs +++ b/gitolfs3-server/src/handler.rs | |||
@@ -33,6 +33,46 @@ pub struct AppState { | |||
33 | pub dl_limiter: Arc<Mutex<DownloadLimiter>>, | 33 | pub dl_limiter: Arc<Mutex<DownloadLimiter>>, |
34 | } | 34 | } |
35 | 35 | ||
36 | enum ObjectStatus { | ||
37 | ExistsOk { content_length: Option<i64> }, | ||
38 | ExistsInconsistent, | ||
39 | DoesNotExist, | ||
40 | } | ||
41 | |||
42 | impl AppState { | ||
43 | async fn check_object(&self, repo: &str, obj: &BatchRequestObject) -> Result<ObjectStatus, ()> { | ||
44 | let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); | ||
45 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); | ||
46 | |||
47 | let result = match self | ||
48 | .s3_client | ||
49 | .head_object() | ||
50 | .bucket(&self.s3_bucket) | ||
51 | .key(full_path) | ||
52 | .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled) | ||
53 | .send() | ||
54 | .await | ||
55 | { | ||
56 | Ok(result) => result, | ||
57 | Err(SdkError::ServiceError(e)) if e.err().is_not_found() => { | ||
58 | return Ok(ObjectStatus::DoesNotExist); | ||
59 | } | ||
60 | Err(e) => { | ||
61 | println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid); | ||
62 | return Err(()); | ||
63 | } | ||
64 | }; | ||
65 | |||
66 | // Scaleway actually doesn't provide SHA256 support, but maybe in the future :) | ||
67 | if !s3_validate_checksum(obj.oid, &result) || !s3_validate_size(obj.size, &result) { | ||
68 | return Ok(ObjectStatus::ExistsInconsistent); | ||
69 | } | ||
70 | Ok(ObjectStatus::ExistsOk { | ||
71 | content_length: result.content_length(), | ||
72 | }) | ||
73 | } | ||
74 | } | ||
75 | |||
36 | async fn handle_download_object( | 76 | async fn handle_download_object( |
37 | state: &AppState, | 77 | state: &AppState, |
38 | repo: &str, | 78 | repo: &str, |
@@ -42,18 +82,16 @@ async fn handle_download_object( | |||
42 | let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); | 82 | let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); |
43 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); | 83 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); |
44 | 84 | ||
45 | let result = match state | 85 | let content_length = match state.check_object(repo, obj).await { |
46 | .s3_client | 86 | Ok(ObjectStatus::ExistsOk { content_length }) => content_length, |
47 | .head_object() | 87 | Ok(_) => { |
48 | .bucket(&state.s3_bucket) | 88 | return BatchResponseObject::error( |
49 | .key(&full_path) | 89 | obj, |
50 | .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled) | 90 | http::StatusCode::UNPROCESSABLE_ENTITY, |
51 | .send() | 91 | "Object corrupted".to_string(), |
52 | .await | 92 | ); |
53 | { | 93 | } |
54 | Ok(result) => result, | 94 | Err(_) => { |
55 | Err(e) => { | ||
56 | println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid); | ||
57 | return BatchResponseObject::error( | 95 | return BatchResponseObject::error( |
58 | obj, | 96 | obj, |
59 | http::StatusCode::INTERNAL_SERVER_ERROR, | 97 | http::StatusCode::INTERNAL_SERVER_ERROR, |
@@ -62,22 +100,6 @@ async fn handle_download_object( | |||
62 | } | 100 | } |
63 | }; | 101 | }; |
64 | 102 | ||
65 | // Scaleway actually doesn't provide SHA256 support, but maybe in the future :) | ||
66 | if !s3_validate_checksum(obj.oid, &result) { | ||
67 | return BatchResponseObject::error( | ||
68 | obj, | ||
69 | http::StatusCode::UNPROCESSABLE_ENTITY, | ||
70 | "Object corrupted".to_string(), | ||
71 | ); | ||
72 | } | ||
73 | if !s3_validate_size(obj.size, &result) { | ||
74 | return BatchResponseObject::error( | ||
75 | obj, | ||
76 | http::StatusCode::UNPROCESSABLE_ENTITY, | ||
77 | "Incorrect size specified (or object corrupted)".to_string(), | ||
78 | ); | ||
79 | } | ||
80 | |||
81 | let expires_in = std::time::Duration::from_secs(5 * 60); | 103 | let expires_in = std::time::Duration::from_secs(5 * 60); |
82 | let expires_at = Utc::now() + expires_in; | 104 | let expires_at = Utc::now() + expires_in; |
83 | 105 | ||
@@ -122,7 +144,7 @@ async fn handle_download_object( | |||
122 | }; | 144 | }; |
123 | } | 145 | } |
124 | 146 | ||
125 | if let Some(content_length) = result.content_length() { | 147 | if let Some(content_length) = content_length { |
126 | if content_length > 0 { | 148 | if content_length > 0 { |
127 | match state | 149 | match state |
128 | .dl_limiter | 150 | .dl_limiter |
@@ -166,13 +188,6 @@ async fn handle_download_object( | |||
166 | ); | 188 | ); |
167 | }; | 189 | }; |
168 | 190 | ||
169 | let upload_path = format!( | ||
170 | "{repo}/info/lfs/objects/{}/{}/{}", | ||
171 | HexByte(obj.oid[0]), | ||
172 | HexByte(obj.oid[1]), | ||
173 | obj.oid, | ||
174 | ); | ||
175 | |||
176 | BatchResponseObject { | 191 | BatchResponseObject { |
177 | oid: obj.oid, | 192 | oid: obj.oid, |
178 | size: obj.size, | 193 | size: obj.size, |
@@ -188,7 +203,13 @@ async fn handle_download_object( | |||
188 | map | 203 | map |
189 | }, | 204 | }, |
190 | expires_at, | 205 | expires_at, |
191 | href: format!("{}/{upload_path}", state.base_url), | 206 | href: format!( |
207 | "{}/{repo}/info/lfs/objects/{}/{}/{}", | ||
208 | state.base_url, | ||
209 | HexByte(obj.oid[0]), | ||
210 | HexByte(obj.oid[1]), | ||
211 | obj.oid | ||
212 | ), | ||
192 | }), | 213 | }), |
193 | ..Default::default() | 214 | ..Default::default() |
194 | }, | 215 | }, |
@@ -289,23 +310,12 @@ async fn handle_upload_object( | |||
289 | let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); | 310 | let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); |
290 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); | 311 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); |
291 | 312 | ||
292 | match state | 313 | match state.check_object(repo, obj).await { |
293 | .s3_client | 314 | Ok(ObjectStatus::ExistsOk { .. }) => { |
294 | .head_object() | 315 | return None; |
295 | .bucket(&state.s3_bucket) | ||
296 | .key(full_path.clone()) | ||
297 | .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled) | ||
298 | .send() | ||
299 | .await | ||
300 | { | ||
301 | Ok(result) => { | ||
302 | if s3_validate_size(obj.size, &result) && s3_validate_checksum(obj.oid, &result) { | ||
303 | return None; | ||
304 | } | ||
305 | } | 316 | } |
306 | Err(SdkError::ServiceError(e)) if e.err().is_not_found() => {} | 317 | Ok(_) => {} |
307 | Err(e) => { | 318 | Err(_) => { |
308 | println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid); | ||
309 | return Some(BatchResponseObject::error( | 319 | return Some(BatchResponseObject::error( |
310 | obj, | 320 | obj, |
311 | http::StatusCode::INTERNAL_SERVER_ERROR, | 321 | http::StatusCode::INTERNAL_SERVER_ERROR, |
@@ -329,7 +339,7 @@ async fn handle_upload_object( | |||
329 | .put_object() | 339 | .put_object() |
330 | .bucket(&state.s3_bucket) | 340 | .bucket(&state.s3_bucket) |
331 | .key(full_path) | 341 | .key(full_path) |
332 | .checksum_sha256(obj.oid.to_string()) | 342 | .checksum_sha256(s3_encode_checksum(obj.oid)) |
333 | .content_length(obj.size) | 343 | .content_length(obj.size) |
334 | .presigned(config) | 344 | .presigned(config) |
335 | .await | 345 | .await |
@@ -418,6 +428,10 @@ pub async fn handle_batch( | |||
418 | GitLfsJson(Json(resp)).into_response() | 428 | GitLfsJson(Json(resp)).into_response() |
419 | } | 429 | } |
420 | 430 | ||
431 | fn s3_encode_checksum(oid: Oid) -> String { | ||
432 | BASE64_STANDARD.encode(oid.as_bytes()) | ||
433 | } | ||
434 | |||
421 | fn s3_validate_checksum(oid: Oid, obj: &HeadObjectOutput) -> bool { | 435 | fn s3_validate_checksum(oid: Oid, obj: &HeadObjectOutput) -> bool { |
422 | if let Some(checksum) = obj.checksum_sha256() { | 436 | if let Some(checksum) = obj.checksum_sha256() { |
423 | if let Ok(checksum) = BASE64_STANDARD.decode(checksum) { | 437 | if let Ok(checksum) = BASE64_STANDARD.decode(checksum) { |