aboutsummaryrefslogtreecommitdiffstats
path: root/server/src/main.rs
diff options
context:
space:
mode:
Diffstat (limited to 'server/src/main.rs')
-rw-r--r--server/src/main.rs1028
1 files changed, 1028 insertions, 0 deletions
diff --git a/server/src/main.rs b/server/src/main.rs
new file mode 100644
index 0000000..8baa0d6
--- /dev/null
+++ b/server/src/main.rs
@@ -0,0 +1,1028 @@
1use std::collections::HashMap;
2use std::collections::HashSet;
3use std::process::ExitCode;
4use std::sync::Arc;
5
6use aws_sdk_s3::error::SdkError;
7use aws_sdk_s3::operation::head_object::HeadObjectOutput;
8use axum::extract::rejection;
9use axum::extract::FromRequest;
10use axum::extract::Path;
11use axum::extract::State;
12use axum::http::header;
13use axum::http::HeaderMap;
14use axum::http::HeaderValue;
15use axum::response::Response;
16use axum::Json;
17use axum::ServiceExt;
18use base64::prelude::*;
19use chrono::DateTime;
20use chrono::Utc;
21use common::HexByte;
22use serde::de;
23use serde::de::DeserializeOwned;
24use serde::Deserialize;
25use serde::Serialize;
26use tower::Layer;
27
28use axum::{
29 async_trait,
30 extract::{FromRequestParts, OriginalUri, Request},
31 http::{request::Parts, StatusCode, Uri},
32 response::IntoResponse,
33 routing::{get, post},
34 Extension, Router,
35};
36
37#[derive(Clone)]
38struct RepositoryName(String);
39
40struct RepositoryNameRejection;
41
42impl IntoResponse for RepositoryNameRejection {
43 fn into_response(self) -> Response {
44 (StatusCode::INTERNAL_SERVER_ERROR, "Missing repository name").into_response()
45 }
46}
47
48#[async_trait]
49impl<S: Send + Sync> FromRequestParts<S> for RepositoryName {
50 type Rejection = RepositoryNameRejection;
51
52 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
53 let Ok(Extension(repo_name)) = Extension::<Self>::from_request_parts(parts, state).await
54 else {
55 return Err(RepositoryNameRejection);
56 };
57 Ok(repo_name)
58 }
59}
60
61async fn rewrite_url<B>(
62 mut req: axum::http::Request<B>,
63) -> Result<axum::http::Request<B>, StatusCode> {
64 let uri = req.uri();
65 let original_uri = OriginalUri(uri.clone());
66
67 let Some(path_and_query) = uri.path_and_query() else {
68 // L @ no path & query
69 return Err(StatusCode::BAD_REQUEST);
70 };
71 let Some((repo, path)) = path_and_query.path().split_once("/info/lfs/objects") else {
72 return Err(StatusCode::NOT_FOUND);
73 };
74 let repo = repo
75 .trim_start_matches('/')
76 .trim_end_matches('/')
77 .to_string();
78 if !path.starts_with('/') || !repo.ends_with(".git") {
79 return Err(StatusCode::NOT_FOUND);
80 }
81
82 let mut parts = uri.clone().into_parts();
83 parts.path_and_query = match path_and_query.query() {
84 None => path.try_into().ok(),
85 Some(q) => format!("{path}?{q}").try_into().ok(),
86 };
87 let Ok(new_uri) = Uri::from_parts(parts) else {
88 return Err(StatusCode::INTERNAL_SERVER_ERROR);
89 };
90
91 *req.uri_mut() = new_uri;
92 req.extensions_mut().insert(original_uri);
93 req.extensions_mut().insert(RepositoryName(repo));
94
95 Ok(req)
96}
97
98struct AppState {
99 s3_client: aws_sdk_s3::Client,
100 s3_bucket: String,
101 authz_conf: AuthorizationConfig,
102 // Should not end with a slash.
103 base_url: String,
104}
105
106struct Env {
107 s3_access_key_id: String,
108 s3_secret_access_key: String,
109 s3_bucket: String,
110 s3_region: String,
111 s3_endpoint: String,
112 base_url: String,
113 key_path: String,
114 listen_host: String,
115 listen_port: String,
116 trusted_forwarded_hosts: String,
117}
118
119fn require_env(name: &str) -> Result<String, String> {
120 std::env::var(name)
121 .map_err(|_| format!("environment variable {name} should be defined and valid"))
122}
123
124impl Env {
125 fn load() -> Result<Env, String> {
126 Ok(Env {
127 s3_secret_access_key: require_env("GITOLFS3_S3_SECRET_ACCESS_KEY_FILE")?,
128 s3_access_key_id: require_env("GITOLFS3_S3_ACCESS_KEY_ID_FILE")?,
129 s3_region: require_env("GITOLFS3_S3_REGION")?,
130 s3_endpoint: require_env("GITOLFS3_S3_ENDPOINT")?,
131 s3_bucket: require_env("GITOLFS3_S3_BUCKET")?,
132 base_url: require_env("GITOLFS3_BASE_URL")?,
133 key_path: require_env("GITOLFS3_KEY_PATH")?,
134 listen_host: require_env("GITOLFS3_LISTEN_HOST")?,
135 listen_port: require_env("GITOLFS3_LISTEN_PORT")?,
136 trusted_forwarded_hosts: std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS")
137 .unwrap_or_default(),
138 })
139 }
140}
141
142fn get_s3_client(env: &Env) -> Result<aws_sdk_s3::Client, std::io::Error> {
143 let access_key_id = std::fs::read_to_string(&env.s3_access_key_id)?;
144 let secret_access_key = std::fs::read_to_string(&env.s3_secret_access_key)?;
145
146 let credentials = aws_sdk_s3::config::Credentials::new(
147 access_key_id,
148 secret_access_key,
149 None,
150 None,
151 "gitolfs3-env",
152 );
153 let config = aws_config::SdkConfig::builder()
154 .behavior_version(aws_config::BehaviorVersion::latest())
155 .region(aws_config::Region::new(env.s3_region.clone()))
156 .endpoint_url(&env.s3_endpoint)
157 .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new(
158 credentials,
159 ))
160 .build();
161 Ok(aws_sdk_s3::Client::new(&config))
162}
163
164#[tokio::main]
165async fn main() -> ExitCode {
166 tracing_subscriber::fmt::init();
167
168 let env = match Env::load() {
169 Ok(env) => env,
170 Err(e) => {
171 println!("Failed to load configuration: {e}");
172 return ExitCode::from(2);
173 }
174 };
175
176 let s3_client = match get_s3_client(&env) {
177 Ok(s3_client) => s3_client,
178 Err(e) => {
179 println!("Failed to create S3 client: {e}");
180 return ExitCode::FAILURE;
181 }
182 };
183 let key = match common::load_key(&env.key_path) {
184 Ok(key) => key,
185 Err(e) => {
186 println!("Failed to load Gitolfs3 key: {e}");
187 return ExitCode::FAILURE;
188 }
189 };
190
191 let trusted_forwarded_hosts: HashSet<String> = env
192 .trusted_forwarded_hosts
193 .split(',')
194 .map(|s| s.to_owned())
195 .filter(|s| !s.is_empty())
196 .collect();
197 let base_url = env.base_url.trim_end_matches('/').to_string();
198
199 let authz_conf = AuthorizationConfig {
200 key,
201 trusted_forwarded_hosts,
202 };
203
204 let shared_state = Arc::new(AppState {
205 s3_client,
206 s3_bucket: env.s3_bucket,
207 authz_conf,
208 base_url,
209 });
210 let app = Router::new()
211 .route("/batch", post(batch))
212 .route("/:oid0/:oid1/:oid", get(obj_download))
213 .with_state(shared_state);
214
215 let middleware = axum::middleware::map_request(rewrite_url);
216 let app_with_middleware = middleware.layer(app);
217
218 let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else {
219 println!("Configured LISTEN_PORT should be an unsigned integer no higher than 65535");
220 return ExitCode::from(2);
221 };
222 let addr: (String, u16) = (env.listen_host, listen_port);
223 let listener = match tokio::net::TcpListener::bind(addr).await {
224 Ok(listener) => listener,
225 Err(e) => {
226 println!("Failed to listen: {e}");
227 return ExitCode::FAILURE;
228 }
229 };
230
231 match axum::serve(listener, app_with_middleware.into_make_service()).await {
232 Ok(_) => ExitCode::SUCCESS,
233 Err(e) => {
234 println!("Error serving: {e}");
235 ExitCode::FAILURE
236 }
237 }
238}
239
240#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
241enum TransferAdapter {
242 #[serde(rename = "basic")]
243 Basic,
244 #[serde(other)]
245 Unknown,
246}
247
248#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
249enum HashAlgo {
250 #[serde(rename = "sha256")]
251 Sha256,
252 #[serde(other)]
253 Unknown,
254}
255
256impl Default for HashAlgo {
257 fn default() -> Self {
258 Self::Sha256
259 }
260}
261
262#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
263struct BatchRequestObject {
264 oid: common::Oid,
265 size: i64,
266}
267
268#[derive(Debug, Serialize, Deserialize, Clone)]
269struct BatchRef {
270 name: String,
271}
272
273fn default_transfers() -> Vec<TransferAdapter> {
274 vec![TransferAdapter::Basic]
275}
276
277#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
278struct BatchRequest {
279 operation: common::Operation,
280 #[serde(default = "default_transfers")]
281 transfers: Vec<TransferAdapter>,
282 objects: Vec<BatchRequestObject>,
283 #[serde(default)]
284 hash_algo: HashAlgo,
285}
286
287#[derive(Debug, Clone)]
288struct GitLfsJson<T>(Json<T>);
289
290const LFS_MIME: &str = "application/vnd.git-lfs+json";
291
292enum GitLfsJsonRejection {
293 Json(rejection::JsonRejection),
294 MissingGitLfsJsonContentType,
295}
296
297impl IntoResponse for GitLfsJsonRejection {
298 fn into_response(self) -> Response {
299 match self {
300 Self::Json(rej) => rej.into_response(),
301 Self::MissingGitLfsJsonContentType => make_error_resp(
302 StatusCode::UNSUPPORTED_MEDIA_TYPE,
303 &format!("Expected request with `Content-Type: {LFS_MIME}`"),
304 )
305 .into_response(),
306 }
307 }
308}
309
310fn is_git_lfs_json_mimetype(mimetype: &str) -> bool {
311 let Ok(mime) = mimetype.parse::<mime::Mime>() else {
312 return false;
313 };
314 if mime.type_() != mime::APPLICATION
315 || mime.subtype() != "vnd.git-lfs"
316 || mime.suffix() != Some(mime::JSON)
317 {
318 return false;
319 }
320 match mime.get_param(mime::CHARSET) {
321 Some(mime::UTF_8) | None => true,
322 Some(_) => false,
323 }
324}
325
326fn has_git_lfs_json_content_type(req: &Request) -> bool {
327 let Some(content_type) = req.headers().get(header::CONTENT_TYPE) else {
328 return false;
329 };
330 let Ok(content_type) = content_type.to_str() else {
331 return false;
332 };
333 is_git_lfs_json_mimetype(content_type)
334}
335
336#[async_trait]
337impl<T, S> FromRequest<S> for GitLfsJson<T>
338where
339 T: DeserializeOwned,
340 S: Send + Sync,
341{
342 type Rejection = GitLfsJsonRejection;
343
344 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
345 if !has_git_lfs_json_content_type(&req) {
346 return Err(GitLfsJsonRejection::MissingGitLfsJsonContentType);
347 }
348 Json::<T>::from_request(req, state)
349 .await
350 .map(GitLfsJson)
351 .map_err(GitLfsJsonRejection::Json)
352 }
353}
354
355impl<T: Serialize> IntoResponse for GitLfsJson<T> {
356 fn into_response(self) -> Response {
357 let GitLfsJson(json) = self;
358 let mut resp = json.into_response();
359 resp.headers_mut().insert(
360 header::CONTENT_TYPE,
361 HeaderValue::from_static("application/vnd.git-lfs+json; charset=utf-8"),
362 );
363 resp
364 }
365}
366
367#[derive(Debug, Serialize)]
368struct GitLfsErrorData<'a> {
369 message: &'a str,
370}
371
372type GitLfsErrorResponse<'a> = (StatusCode, GitLfsJson<GitLfsErrorData<'a>>);
373
374const fn make_error_resp(code: StatusCode, message: &str) -> GitLfsErrorResponse {
375 (code, GitLfsJson(Json(GitLfsErrorData { message })))
376}
377
378#[derive(Debug, Serialize, Clone)]
379struct BatchResponseObjectAction {
380 href: String,
381 #[serde(skip_serializing_if = "HashMap::is_empty")]
382 header: HashMap<String, String>,
383 expires_at: DateTime<Utc>,
384}
385
386#[derive(Default, Debug, Serialize, Clone)]
387struct BatchResponseObjectActions {
388 #[serde(skip_serializing_if = "Option::is_none")]
389 upload: Option<BatchResponseObjectAction>,
390 #[serde(skip_serializing_if = "Option::is_none")]
391 download: Option<BatchResponseObjectAction>,
392 #[serde(skip_serializing_if = "Option::is_none")]
393 verify: Option<BatchResponseObjectAction>,
394}
395
396#[derive(Debug, Clone, Serialize)]
397struct BatchResponseObjectError {
398 code: u16,
399 message: String,
400}
401
402#[derive(Debug, Serialize, Clone)]
403struct BatchResponseObject {
404 oid: common::Oid,
405 size: i64,
406 #[serde(skip_serializing_if = "Option::is_none")]
407 authenticated: Option<bool>,
408 actions: BatchResponseObjectActions,
409 #[serde(skip_serializing_if = "Option::is_none")]
410 error: Option<BatchResponseObjectError>,
411}
412
413impl BatchResponseObject {
414 fn error(obj: &BatchRequestObject, code: StatusCode, message: String) -> BatchResponseObject {
415 BatchResponseObject {
416 oid: obj.oid,
417 size: obj.size,
418 authenticated: None,
419 actions: Default::default(),
420 error: Some(BatchResponseObjectError {
421 code: code.as_u16(),
422 message,
423 }),
424 }
425 }
426}
427
428#[derive(Debug, Serialize, Clone)]
429struct BatchResponse {
430 transfer: TransferAdapter,
431 objects: Vec<BatchResponseObject>,
432 hash_algo: HashAlgo,
433}
434
435fn validate_checksum(oid: common::Oid, obj: &HeadObjectOutput) -> bool {
436 if let Some(checksum) = obj.checksum_sha256() {
437 if let Ok(checksum) = BASE64_STANDARD.decode(checksum) {
438 if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) {
439 return common::Oid::from(checksum32b) == oid;
440 }
441 }
442 }
443 true
444}
445
446fn validate_size(expected: i64, obj: &HeadObjectOutput) -> bool {
447 if let Some(length) = obj.content_length() {
448 return length == expected;
449 }
450 true
451}
452
453async fn handle_upload_object(
454 state: &AppState,
455 repo: &str,
456 obj: &BatchRequestObject,
457) -> Option<BatchResponseObject> {
458 let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
459 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
460
461 match state
462 .s3_client
463 .head_object()
464 .bucket(&state.s3_bucket)
465 .key(full_path.clone())
466 .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
467 .send()
468 .await
469 {
470 Ok(result) => {
471 if validate_size(obj.size, &result) && validate_checksum(obj.oid, &result) {
472 return None;
473 }
474 }
475 Err(SdkError::ServiceError(e)) if e.err().is_not_found() => {}
476 Err(e) => {
477 println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid);
478 return Some(BatchResponseObject::error(
479 obj,
480 StatusCode::INTERNAL_SERVER_ERROR,
481 "Failed to query object information".to_string(),
482 ));
483 }
484 };
485
486 let expires_in = std::time::Duration::from_secs(5 * 60);
487 let expires_at = Utc::now() + expires_in;
488
489 let Ok(config) = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in) else {
490 return Some(BatchResponseObject::error(
491 obj,
492 StatusCode::INTERNAL_SERVER_ERROR,
493 "Failed to generate upload URL".to_string(),
494 ));
495 };
496 let Ok(presigned) = state
497 .s3_client
498 .put_object()
499 .bucket(&state.s3_bucket)
500 .key(full_path)
501 .checksum_sha256(obj.oid.to_string())
502 .content_length(obj.size)
503 .presigned(config)
504 .await
505 else {
506 return Some(BatchResponseObject::error(
507 obj,
508 StatusCode::INTERNAL_SERVER_ERROR,
509 "Failed to generate upload URL".to_string(),
510 ));
511 };
512 Some(BatchResponseObject {
513 oid: obj.oid,
514 size: obj.size,
515 authenticated: Some(true),
516 actions: BatchResponseObjectActions {
517 upload: Some(BatchResponseObjectAction {
518 header: presigned
519 .headers()
520 .map(|(k, v)| (k.to_owned(), v.to_owned()))
521 .collect(),
522 expires_at,
523 href: presigned.uri().to_string(),
524 }),
525 ..Default::default()
526 },
527 error: None,
528 })
529}
530
531async fn handle_download_object(
532 state: &AppState,
533 repo: &str,
534 obj: &BatchRequestObject,
535 trusted: bool,
536) -> BatchResponseObject {
537 let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
538 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
539
540 let result = match state
541 .s3_client
542 .head_object()
543 .bucket(&state.s3_bucket)
544 .key(&full_path)
545 .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
546 .send()
547 .await
548 {
549 Ok(result) => result,
550 Err(e) => {
551 println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid);
552 return BatchResponseObject::error(
553 obj,
554 StatusCode::INTERNAL_SERVER_ERROR,
555 "Failed to query object information".to_string(),
556 );
557 }
558 };
559
560 // Scaleway actually doesn't provide SHA256 suport, but maybe in the future :)
561 if !validate_checksum(obj.oid, &result) {
562 return BatchResponseObject::error(
563 obj,
564 StatusCode::UNPROCESSABLE_ENTITY,
565 "Object corrupted".to_string(),
566 );
567 }
568 if !validate_size(obj.size, &result) {
569 return BatchResponseObject::error(
570 obj,
571 StatusCode::UNPROCESSABLE_ENTITY,
572 "Incorrect size specified (or object corrupted)".to_string(),
573 );
574 }
575
576 let expires_in = std::time::Duration::from_secs(5 * 60);
577 let expires_at = Utc::now() + expires_in;
578
579 if trusted {
580 let Ok(config) = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in) else {
581 return BatchResponseObject::error(
582 obj,
583 StatusCode::INTERNAL_SERVER_ERROR,
584 "Failed to generate upload URL".to_string(),
585 );
586 };
587 let Ok(presigned) = state
588 .s3_client
589 .get_object()
590 .bucket(&state.s3_bucket)
591 .key(full_path)
592 .presigned(config)
593 .await
594 else {
595 return BatchResponseObject::error(
596 obj,
597 StatusCode::INTERNAL_SERVER_ERROR,
598 "Failed to generate upload URL".to_string(),
599 );
600 };
601 return BatchResponseObject {
602 oid: obj.oid,
603 size: obj.size,
604 authenticated: Some(true),
605 actions: BatchResponseObjectActions {
606 download: Some(BatchResponseObjectAction {
607 header: presigned
608 .headers()
609 .map(|(k, v)| (k.to_owned(), v.to_owned()))
610 .collect(),
611 expires_at,
612 href: presigned.uri().to_string(),
613 }),
614 ..Default::default()
615 },
616 error: None,
617 };
618 }
619
620 let Some(tag) = common::generate_tag(
621 common::Claims {
622 specific_claims: common::SpecificClaims::Download(obj.oid),
623 repo_path: repo,
624 expires_at,
625 },
626 &state.authz_conf.key,
627 ) else {
628 return BatchResponseObject::error(
629 obj,
630 StatusCode::INTERNAL_SERVER_ERROR,
631 "Internal server error".to_string(),
632 );
633 };
634
635 let upload_path = format!(
636 "{repo}/info/lfs/objects/{}/{}/{}",
637 HexByte(obj.oid[0]),
638 HexByte(obj.oid[1]),
639 obj.oid,
640 );
641
642 BatchResponseObject {
643 oid: obj.oid,
644 size: obj.size,
645 authenticated: Some(true),
646 actions: BatchResponseObjectActions {
647 download: Some(BatchResponseObjectAction {
648 header: {
649 let mut map = HashMap::new();
650 map.insert(
651 "Authorization".to_string(),
652 format!("Gitolfs3-Hmac-Sha256 {tag} {}", expires_at.timestamp()),
653 );
654 map
655 },
656 expires_at,
657 href: format!("{}/{upload_path}", state.base_url),
658 }),
659 ..Default::default()
660 },
661 error: None,
662 }
663}
664
665struct AuthorizationConfig {
666 trusted_forwarded_hosts: HashSet<String>,
667 key: common::Key,
668}
669
670struct Trusted(bool);
671
672fn forwarded_for_trusted_host(
673 headers: &HeaderMap,
674 trusted: &HashSet<String>,
675) -> Result<bool, GitLfsErrorResponse<'static>> {
676 if let Some(forwarded_for) = headers.get("X-Forwarded-For") {
677 if let Ok(forwarded_for) = forwarded_for.to_str() {
678 if trusted.contains(forwarded_for) {
679 return Ok(true);
680 }
681 } else {
682 return Err(make_error_resp(
683 StatusCode::NOT_FOUND,
684 "Invalid X-Forwarded-For header",
685 ));
686 }
687 }
688 Ok(false)
689}
690const REPO_NOT_FOUND: GitLfsErrorResponse =
691 make_error_resp(StatusCode::NOT_FOUND, "Repository not found");
692
693fn authorize_batch(
694 conf: &AuthorizationConfig,
695 repo_path: &str,
696 public: bool,
697 operation: common::Operation,
698 headers: &HeaderMap,
699) -> Result<Trusted, GitLfsErrorResponse<'static>> {
700 // - No authentication required for downloading exported repos
701 // - When authenticated:
702 // - Download / upload over presigned URLs
703 // - When accessing over Tailscale:
704 // - No authentication required for downloading from any repo
705
706 let claims = VerifyClaimsInput {
707 specific_claims: common::SpecificClaims::BatchApi(operation),
708 repo_path,
709 };
710 if verify_claims(conf, &claims, headers)? {
711 return Ok(Trusted(true));
712 }
713
714 let trusted = forwarded_for_trusted_host(headers, &conf.trusted_forwarded_hosts)?;
715 if operation != common::Operation::Download {
716 if trusted {
717 return Err(make_error_resp(
718 StatusCode::FORBIDDEN,
719 "Authentication required to upload",
720 ));
721 }
722 return Err(REPO_NOT_FOUND);
723 }
724 if trusted {
725 return Ok(Trusted(true));
726 }
727
728 if public {
729 Ok(Trusted(false))
730 } else {
731 Err(REPO_NOT_FOUND)
732 }
733}
734
735fn repo_exists(name: &str) -> bool {
736 let Ok(metadata) = std::fs::metadata(name) else {
737 return false;
738 };
739 metadata.is_dir()
740}
741
742fn is_repo_public(name: &str) -> Option<bool> {
743 if !repo_exists(name) {
744 return None;
745 }
746 std::fs::metadata(format!("{name}/git-daemon-export-ok"))
747 .ok()?
748 .is_file()
749 .into()
750}
751
752async fn batch(
753 State(state): State<Arc<AppState>>,
754 headers: HeaderMap,
755 RepositoryName(repo): RepositoryName,
756 GitLfsJson(Json(payload)): GitLfsJson<BatchRequest>,
757) -> Response {
758 let Some(public) = is_repo_public(&repo) else {
759 return REPO_NOT_FOUND.into_response();
760 };
761 let Trusted(trusted) = match authorize_batch(
762 &state.authz_conf,
763 &repo,
764 public,
765 payload.operation,
766 &headers,
767 ) {
768 Ok(authn) => authn,
769 Err(e) => return e.into_response(),
770 };
771
772 if !headers
773 .get_all("Accept")
774 .iter()
775 .filter_map(|v| v.to_str().ok())
776 .any(is_git_lfs_json_mimetype)
777 {
778 let message = format!("Expected `{LFS_MIME}` in list of acceptable response media types");
779 return make_error_resp(StatusCode::NOT_ACCEPTABLE, &message).into_response();
780 }
781
782 if payload.hash_algo != HashAlgo::Sha256 {
783 let message = "Unsupported hashing algorithm specified";
784 return make_error_resp(StatusCode::CONFLICT, message).into_response();
785 }
786 if !payload.transfers.is_empty() && !payload.transfers.contains(&TransferAdapter::Basic) {
787 let message = "Unsupported transfer adapter specified (supported: basic)";
788 return make_error_resp(StatusCode::CONFLICT, message).into_response();
789 }
790
791 let mut resp = BatchResponse {
792 transfer: TransferAdapter::Basic,
793 objects: vec![],
794 hash_algo: HashAlgo::Sha256,
795 };
796 for obj in payload.objects {
797 match payload.operation {
798 common::Operation::Download => resp
799 .objects
800 .push(handle_download_object(&state, &repo, &obj, trusted).await),
801 common::Operation::Upload => {
802 if let Some(obj_resp) = handle_upload_object(&state, &repo, &obj).await {
803 resp.objects.push(obj_resp);
804 }
805 }
806 };
807 }
808 GitLfsJson(Json(resp)).into_response()
809}
810
811#[derive(Deserialize, Copy, Clone)]
812#[serde(remote = "Self")]
813struct FileParams {
814 oid0: HexByte,
815 oid1: HexByte,
816 oid: common::Oid,
817}
818
819impl<'de> Deserialize<'de> for FileParams {
820 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
821 where
822 D: serde::Deserializer<'de>,
823 {
824 let unchecked @ FileParams {
825 oid0: HexByte(oid0),
826 oid1: HexByte(oid1),
827 oid,
828 } = FileParams::deserialize(deserializer)?;
829 if oid0 != oid.as_bytes()[0] {
830 return Err(de::Error::custom(
831 "first OID path part does not match first byte of full OID",
832 ));
833 }
834 if oid1 != oid.as_bytes()[1] {
835 return Err(de::Error::custom(
836 "second OID path part does not match first byte of full OID",
837 ));
838 }
839 Ok(unchecked)
840 }
841}
842
843pub struct VerifyClaimsInput<'a> {
844 pub specific_claims: common::SpecificClaims,
845 pub repo_path: &'a str,
846}
847
848fn verify_claims(
849 conf: &AuthorizationConfig,
850 claims: &VerifyClaimsInput,
851 headers: &HeaderMap,
852) -> Result<bool, GitLfsErrorResponse<'static>> {
853 const INVALID_AUTHZ_HEADER: GitLfsErrorResponse =
854 make_error_resp(StatusCode::BAD_REQUEST, "Invalid authorization header");
855
856 if let Some(authz) = headers.get(header::AUTHORIZATION) {
857 if let Ok(authz) = authz.to_str() {
858 if let Some(val) = authz.strip_prefix("Gitolfs3-Hmac-Sha256 ") {
859 let (tag, expires_at) = val.split_once(' ').ok_or(INVALID_AUTHZ_HEADER)?;
860 let tag: common::Digest<32> = tag.parse().map_err(|_| INVALID_AUTHZ_HEADER)?;
861 let expires_at: i64 = expires_at.parse().map_err(|_| INVALID_AUTHZ_HEADER)?;
862 let expires_at =
863 DateTime::<Utc>::from_timestamp(expires_at, 0).ok_or(INVALID_AUTHZ_HEADER)?;
864 let Some(expected_tag) = common::generate_tag(
865 common::Claims {
866 specific_claims: claims.specific_claims,
867 repo_path: claims.repo_path,
868 expires_at,
869 },
870 &conf.key,
871 ) else {
872 return Err(make_error_resp(
873 StatusCode::INTERNAL_SERVER_ERROR,
874 "Internal server error",
875 ));
876 };
877 if tag == expected_tag {
878 return Ok(true);
879 }
880 }
881 }
882 return Err(INVALID_AUTHZ_HEADER);
883 }
884 Ok(false)
885}
886
887fn authorize_get(
888 conf: &AuthorizationConfig,
889 repo_path: &str,
890 oid: common::Oid,
891 headers: &HeaderMap,
892) -> Result<(), GitLfsErrorResponse<'static>> {
893 let claims = VerifyClaimsInput {
894 specific_claims: common::SpecificClaims::Download(oid),
895 repo_path,
896 };
897 if !verify_claims(conf, &claims, headers)? {
898 return Err(make_error_resp(
899 StatusCode::UNAUTHORIZED,
900 "Repository not found",
901 ));
902 }
903 Ok(())
904}
905
906async fn obj_download(
907 State(state): State<Arc<AppState>>,
908 headers: HeaderMap,
909 RepositoryName(repo): RepositoryName,
910 Path(FileParams { oid0, oid1, oid }): Path<FileParams>,
911) -> Response {
912 if let Err(e) = authorize_get(&state.authz_conf, &repo, oid, &headers) {
913 return e.into_response();
914 }
915
916 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, oid);
917 let result = match state
918 .s3_client
919 .get_object()
920 .bucket(&state.s3_bucket)
921 .key(full_path)
922 .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
923 .send()
924 .await
925 {
926 Ok(result) => result,
927 Err(e) => {
928 println!("Failed to GetObject (repo {repo}, OID {oid}): {e}");
929 return (
930 StatusCode::INTERNAL_SERVER_ERROR,
931 "Failed to query object information",
932 )
933 .into_response();
934 }
935 };
936
937 let mut headers = header::HeaderMap::new();
938 if let Some(content_type) = result.content_type {
939 let Ok(header_value) = content_type.try_into() else {
940 return (
941 StatusCode::INTERNAL_SERVER_ERROR,
942 "Object has invalid content type",
943 )
944 .into_response();
945 };
946 headers.insert(header::CONTENT_TYPE, header_value);
947 }
948 if let Some(content_length) = result.content_length {
949 headers.insert(header::CONTENT_LENGTH, content_length.into());
950 }
951
952 let async_read = result.body.into_async_read();
953 let stream = tokio_util::io::ReaderStream::new(async_read);
954 let body = axum::body::Body::from_stream(stream);
955
956 (headers, body).into_response()
957}
958
959#[test]
960fn test_mimetype() {
961 assert!(is_git_lfs_json_mimetype("application/vnd.git-lfs+json"));
962 assert!(!is_git_lfs_json_mimetype("application/vnd.git-lfs"));
963 assert!(!is_git_lfs_json_mimetype("application/json"));
964 assert!(is_git_lfs_json_mimetype(
965 "application/vnd.git-lfs+json; charset=utf-8"
966 ));
967 assert!(is_git_lfs_json_mimetype(
968 "application/vnd.git-lfs+json; charset=UTF-8"
969 ));
970 assert!(!is_git_lfs_json_mimetype(
971 "application/vnd.git-lfs+json; charset=ISO-8859-1"
972 ));
973}
974
975#[test]
976fn test_deserialize() {
977 let json = r#"{"operation":"upload","objects":[{"oid":"8f4123f9a7181f488c5e111d82cefd992e461ae5df01fd2254399e6e670b2d3c","size":170904}],
978 "transfers":["lfs-standalone-file","basic","ssh"],"ref":{"name":"refs/heads/main"},"hash_algo":"sha256"}"#;
979 let expected = BatchRequest {
980 operation: common::Operation::Upload,
981 objects: vec![BatchRequestObject {
982 oid: "8f4123f9a7181f488c5e111d82cefd992e461ae5df01fd2254399e6e670b2d3c"
983 .parse()
984 .unwrap(),
985 size: 170904,
986 }],
987 transfers: vec![
988 TransferAdapter::Unknown,
989 TransferAdapter::Basic,
990 TransferAdapter::Unknown,
991 ],
992 hash_algo: HashAlgo::Sha256,
993 };
994 assert_eq!(
995 serde_json::from_str::<BatchRequest>(json).unwrap(),
996 expected
997 );
998}
999
1000#[test]
1001fn test_validate_claims() {
1002 let key = "00232f7a019bd34e3921ee6c5f04caf48a4489d1be5d1999038950a7054e0bfea369ce2becc0f13fd3c69f8af2384a25b7ac2d52eb52c33722f3c00c50d4c9c2";
1003 let key: common::Key = key.parse().unwrap();
1004
1005 let claims = common::Claims {
1006 expires_at: Utc::now() + std::time::Duration::from_secs(5 * 60),
1007 repo_path: "lfs-test.git",
1008 specific_claims: common::SpecificClaims::BatchApi(common::Operation::Download),
1009 };
1010 let tag = common::generate_tag(claims, &key).unwrap();
1011 let header_value = format!(
1012 "Gitolfs3-Hmac-Sha256 {tag} {}",
1013 claims.expires_at.timestamp()
1014 );
1015
1016 let conf = AuthorizationConfig {
1017 key,
1018 trusted_forwarded_hosts: HashSet::new(),
1019 };
1020 let verification_claims = VerifyClaimsInput {
1021 repo_path: claims.repo_path,
1022 specific_claims: claims.specific_claims,
1023 };
1024 let mut headers = HeaderMap::new();
1025 headers.insert(header::AUTHORIZATION, header_value.try_into().unwrap());
1026
1027 assert!(verify_claims(&conf, &verification_claims, &headers).unwrap());
1028}