From 7edcc4856400107602c28a58b9eb774d577f0375 Mon Sep 17 00:00:00 2001
From: Rutger Broekhoff
Date: Mon, 21 Oct 2024 00:12:46 +0200
Subject: Fix SHA256 checksum encoding

---
 gitolfs3-server/src/handler.rs | 122 +++++++++++++++++++++++------------------
 1 file changed, 68 insertions(+), 54 deletions(-)

(limited to 'gitolfs3-server/src')

diff --git a/gitolfs3-server/src/handler.rs b/gitolfs3-server/src/handler.rs
index b9f9bcf..64d5492 100644
--- a/gitolfs3-server/src/handler.rs
+++ b/gitolfs3-server/src/handler.rs
@@ -33,6 +33,46 @@ pub struct AppState {
     pub dl_limiter: Arc<Mutex<DownloadLimiter>>,
 }
 
+enum ObjectStatus {
+    ExistsOk { content_length: Option<i64> },
+    ExistsInconsistent,
+    DoesNotExist,
+}
+
+impl AppState {
+    async fn check_object(&self, repo: &str, obj: &BatchRequestObject) -> Result<ObjectStatus, ()> {
+        let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
+        let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
+
+        let result = match self
+            .s3_client
+            .head_object()
+            .bucket(&self.s3_bucket)
+            .key(full_path)
+            .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
+            .send()
+            .await
+        {
+            Ok(result) => result,
+            Err(SdkError::ServiceError(e)) if e.err().is_not_found() => {
+                return Ok(ObjectStatus::DoesNotExist);
+            }
+            Err(e) => {
+                println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid);
+                return Err(());
+            }
+        };
+
+        // Scaleway actually doesn't provide SHA256 support, but maybe in the future :)
+        if !s3_validate_checksum(obj.oid, &result) || !s3_validate_size(obj.size, &result) {
+            return Ok(ObjectStatus::ExistsInconsistent);
+        }
+        Ok(ObjectStatus::ExistsOk {
+            content_length: result.content_length(),
+        })
+    }
+}
+
 async fn handle_download_object(
     state: &AppState,
     repo: &str,
@@ -42,18 +82,16 @@ async fn handle_download_object(
     let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
     let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
 
-    let result = match state
-        .s3_client
-        .head_object()
-        .bucket(&state.s3_bucket)
-        .key(&full_path)
-        .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
-        .send()
-        .await
-    {
-        Ok(result) => result,
-        Err(e) => {
-            println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid);
+    let content_length = match state.check_object(repo, obj).await {
+        Ok(ObjectStatus::ExistsOk { content_length }) => content_length,
+        Ok(_) => {
+            return BatchResponseObject::error(
+                obj,
+                http::StatusCode::UNPROCESSABLE_ENTITY,
+                "Object corrupted".to_string(),
+            );
+        }
+        Err(_) => {
             return BatchResponseObject::error(
                 obj,
                 http::StatusCode::INTERNAL_SERVER_ERROR,
@@ -62,22 +100,6 @@ async fn handle_download_object(
         }
     };
 
-    // Scaleway actually doesn't provide SHA256 support, but maybe in the future :)
-    if !s3_validate_checksum(obj.oid, &result) {
-        return BatchResponseObject::error(
-            obj,
-            http::StatusCode::UNPROCESSABLE_ENTITY,
-            "Object corrupted".to_string(),
-        );
-    }
-    if !s3_validate_size(obj.size, &result) {
-        return BatchResponseObject::error(
-            obj,
-            http::StatusCode::UNPROCESSABLE_ENTITY,
-            "Incorrect size specified (or object corrupted)".to_string(),
-        );
-    }
-
     let expires_in = std::time::Duration::from_secs(5 * 60);
     let expires_at = Utc::now() + expires_in;
 
@@ -122,7 +144,7 @@ async fn handle_download_object(
         };
     }
 
-    if let Some(content_length) = result.content_length() {
+    if let Some(content_length) = content_length {
         if content_length > 0 {
             match state
                 .dl_limiter
@@ -166,13 +188,6 @@ async fn handle_download_object(
         );
     };
 
-    let upload_path = format!(
-        "{repo}/info/lfs/objects/{}/{}/{}",
-        HexByte(obj.oid[0]),
-        HexByte(obj.oid[1]),
-        obj.oid,
-    );
-
     BatchResponseObject {
         oid: obj.oid,
         size: obj.size,
@@ -188,7 +203,13 @@ async fn handle_download_object(
                     map
                 },
                 expires_at,
-                href: format!("{}/{upload_path}", state.base_url),
+                href: format!(
+                    "{}/{repo}/info/lfs/objects/{}/{}/{}",
+                    state.base_url,
+                    HexByte(obj.oid[0]),
+                    HexByte(obj.oid[1]),
+                    obj.oid
+                ),
             }),
             ..Default::default()
         },
@@ -289,23 +310,12 @@ async fn handle_upload_object(
     let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
     let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
 
-    match state
-        .s3_client
-        .head_object()
-        .bucket(&state.s3_bucket)
-        .key(full_path.clone())
-        .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
-        .send()
-        .await
-    {
-        Ok(result) => {
-            if s3_validate_size(obj.size, &result) && s3_validate_checksum(obj.oid, &result) {
-                return None;
-            }
+    match state.check_object(repo, obj).await {
+        Ok(ObjectStatus::ExistsOk { .. }) => {
+            return None;
         }
-        Err(SdkError::ServiceError(e)) if e.err().is_not_found() => {}
-        Err(e) => {
-            println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid);
+        Ok(_) => {}
+        Err(_) => {
             return Some(BatchResponseObject::error(
                 obj,
                 http::StatusCode::INTERNAL_SERVER_ERROR,
@@ -329,7 +339,7 @@ async fn handle_upload_object(
         .put_object()
         .bucket(&state.s3_bucket)
         .key(full_path)
-        .checksum_sha256(obj.oid.to_string())
+        .checksum_sha256(s3_encode_checksum(obj.oid))
         .content_length(obj.size)
         .presigned(config)
         .await
@@ -418,6 +428,10 @@ pub async fn handle_batch(
     GitLfsJson(Json(resp)).into_response()
 }
 
+fn s3_encode_checksum(oid: Oid) -> String {
+    BASE64_STANDARD.encode(oid.as_bytes())
+}
+
 fn s3_validate_checksum(oid: Oid, obj: &HeadObjectOutput) -> bool {
     if let Some(checksum) = obj.checksum_sha256() {
         if let Ok(checksum) = BASE64_STANDARD.decode(checksum) {
-- 
cgit v1.2.3