author    Rutger Broekhoff    2024-04-29 19:18:56 +0200
committer Rutger Broekhoff    2024-04-29 19:18:56 +0200
commit    a7f9c8de31231b9fd9c67c57db659f7b01f1a3b0 (patch)
tree      040c518d243769c1920b8201a07a626f4e5934cf /server/src
parent    80c4d49ad7f590f5af1304939fdaea7baf13142e (diff)
download  gitolfs3-a7f9c8de31231b9fd9c67c57db659f7b01f1a3b0.tar.gz
          gitolfs3-a7f9c8de31231b9fd9c67c57db659f7b01f1a3b0.zip
Rename crates (and therefore commands)
Diffstat (limited to 'server/src')
-rw-r--r--  server/src/main.rs  1151
1 file changed, 0 insertions, 1151 deletions
diff --git a/server/src/main.rs b/server/src/main.rs
deleted file mode 100644
index 7371c0d..0000000
--- a/server/src/main.rs
+++ /dev/null
@@ -1,1151 +0,0 @@
use aws_sdk_s3::{error::SdkError, operation::head_object::HeadObjectOutput};
use axum::{
    async_trait,
    extract::{rejection, FromRequest, FromRequestParts, OriginalUri, Path, Request, State},
    http::{header, request::Parts, HeaderMap, HeaderValue, StatusCode, Uri},
    response::{IntoResponse, Response},
    routing::{get, post},
    Extension, Json, Router, ServiceExt,
};
use base64::prelude::*;
use chrono::{DateTime, Utc};
use serde::{
    de::{self, DeserializeOwned},
    Deserialize, Serialize,
};
use std::{
    collections::{HashMap, HashSet},
    process::ExitCode,
    sync::Arc,
};
use tokio::io::AsyncWriteExt;
use tower::Layer;
23
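// Entry point: loads the configuration, spawns the hourly download-counter
// reset task, and serves the Git LFS endpoints (the Batch API and plain
// object downloads) over HTTP.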
#[tokio::main]
async fn main() -> ExitCode {
    tracing_subscriber::fmt::init();

    let conf = match Config::load() {
        Ok(conf) => conf,
        Err(e) => {
            println!("Error: {e}");
            return ExitCode::from(2);
        }
    };

    let dl_limiter = DownloadLimiter::new(conf.download_limit).await;
    let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter));

    let resetter_dl_limiter = dl_limiter.clone();
    tokio::spawn(async move {
        loop {
            println!("Resetting download counter in one hour");
            tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
            println!("Resetting download counter");
            resetter_dl_limiter.lock().await.reset().await;
        }
    });

    let shared_state = Arc::new(AppState {
        s3_client: conf.s3_client,
        s3_bucket: conf.s3_bucket,
        authz_conf: conf.authz_conf,
        base_url: conf.base_url,
        dl_limiter,
    });
    let app = Router::new()
        .route("/batch", post(batch))
        .route("/:oid0/:oid1/:oid", get(obj_download))
        .with_state(shared_state);

    let middleware = axum::middleware::map_request(rewrite_url);
    let app_with_middleware = middleware.layer(app);

    let listener = match tokio::net::TcpListener::bind(conf.listen_addr).await {
        Ok(listener) => listener,
        Err(e) => {
            println!("Failed to listen: {e}");
            return ExitCode::FAILURE;
        }
    };

    match axum::serve(listener, app_with_middleware.into_make_service()).await {
        Ok(_) => ExitCode::SUCCESS,
        Err(e) => {
            println!("Error serving: {e}");
            ExitCode::FAILURE
        }
    }
}

#[derive(Clone)]
struct RepositoryName(String);

struct RepositoryNameRejection;

impl IntoResponse for RepositoryNameRejection {
    fn into_response(self) -> Response {
        (StatusCode::INTERNAL_SERVER_ERROR, "Missing repository name").into_response()
    }
}

#[async_trait]
impl<S: Send + Sync> FromRequestParts<S> for RepositoryName {
    type Rejection = RepositoryNameRejection;

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let Ok(Extension(repo_name)) = Extension::<Self>::from_request_parts(parts, state).await
        else {
            return Err(RepositoryNameRejection);
        };
        Ok(repo_name)
    }
}
104
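// Rewrites e.g. "/foo/bar.git/info/lfs/objects/batch" to "/batch" so the
// router above can match it, and stashes RepositoryName("foo/bar.git") (plus
// the original URI) in the request extensions for the handlers to pick up.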
async fn rewrite_url<B>(
    mut req: axum::http::Request<B>,
) -> Result<axum::http::Request<B>, StatusCode> {
    let uri = req.uri();
    let original_uri = OriginalUri(uri.clone());

    let Some(path_and_query) = uri.path_and_query() else {
        // Request URI has no path and query.
        return Err(StatusCode::BAD_REQUEST);
    };
    let Some((repo, path)) = path_and_query.path().split_once("/info/lfs/objects") else {
        return Err(StatusCode::NOT_FOUND);
    };
    let repo = repo
        .trim_start_matches('/')
        .trim_end_matches('/')
        .to_string();
    if !path.starts_with('/') || !repo.ends_with(".git") {
        return Err(StatusCode::NOT_FOUND);
    }

    let mut parts = uri.clone().into_parts();
    parts.path_and_query = match path_and_query.query() {
        None => path.try_into().ok(),
        Some(q) => format!("{path}?{q}").try_into().ok(),
    };
    let Ok(new_uri) = Uri::from_parts(parts) else {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    };

    *req.uri_mut() = new_uri;
    req.extensions_mut().insert(original_uri);
    req.extensions_mut().insert(RepositoryName(repo));

    Ok(req)
}

struct AppState {
    s3_client: aws_sdk_s3::Client,
    s3_bucket: String,
    authz_conf: AuthorizationConfig,
    // Should not end with a slash.
    base_url: String,
    dl_limiter: Arc<tokio::sync::Mutex<DownloadLimiter>>,
}

struct Env {
    s3_access_key_id: String,
    s3_secret_access_key: String,
    s3_bucket: String,
    s3_region: String,
    s3_endpoint: String,
    base_url: String,
    key_path: String,
    listen_host: String,
    listen_port: String,
    download_limit: String,
    trusted_forwarded_hosts: String,
}

fn require_env(name: &str) -> Result<String, String> {
    std::env::var(name)
        .map_err(|_| format!("environment variable {name} should be defined and valid"))
}

impl Env {
    fn load() -> Result<Env, String> {
        Ok(Env {
            s3_secret_access_key: require_env("GITOLFS3_S3_SECRET_ACCESS_KEY_FILE")?,
            s3_access_key_id: require_env("GITOLFS3_S3_ACCESS_KEY_ID_FILE")?,
            s3_region: require_env("GITOLFS3_S3_REGION")?,
            s3_endpoint: require_env("GITOLFS3_S3_ENDPOINT")?,
            s3_bucket: require_env("GITOLFS3_S3_BUCKET")?,
            base_url: require_env("GITOLFS3_BASE_URL")?,
            key_path: require_env("GITOLFS3_KEY_PATH")?,
            listen_host: require_env("GITOLFS3_LISTEN_HOST")?,
            listen_port: require_env("GITOLFS3_LISTEN_PORT")?,
            download_limit: require_env("GITOLFS3_DOWNLOAD_LIMIT")?,
            trusted_forwarded_hosts: std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS")
                .unwrap_or_default(),
        })
    }
}
188
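// The *_FILE environment variables name files holding the actual S3
// credentials, so the secrets themselves can stay out of the environment.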
fn get_s3_client(env: &Env) -> Result<aws_sdk_s3::Client, std::io::Error> {
    let access_key_id = std::fs::read_to_string(&env.s3_access_key_id)?;
    let secret_access_key = std::fs::read_to_string(&env.s3_secret_access_key)?;

    let credentials = aws_sdk_s3::config::Credentials::new(
        access_key_id,
        secret_access_key,
        None,
        None,
        "gitolfs3-env",
    );
    let config = aws_config::SdkConfig::builder()
        .behavior_version(aws_config::BehaviorVersion::latest())
        .region(aws_config::Region::new(env.s3_region.clone()))
        .endpoint_url(&env.s3_endpoint)
        .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new(
            credentials,
        ))
        .build();
    Ok(aws_sdk_s3::Client::new(&config))
}

struct Config {
    listen_addr: (String, u16),
    base_url: String,
    authz_conf: AuthorizationConfig,
    s3_client: aws_sdk_s3::Client,
    s3_bucket: String,
    download_limit: u64,
}

impl Config {
    fn load() -> Result<Self, String> {
        let env = match Env::load() {
            Ok(env) => env,
            Err(e) => return Err(format!("failed to load configuration: {e}")),
        };

        let s3_client = match get_s3_client(&env) {
            Ok(s3_client) => s3_client,
            Err(e) => return Err(format!("failed to create S3 client: {e}")),
        };
        let key = match common::load_key(&env.key_path) {
            Ok(key) => key,
            Err(e) => return Err(format!("failed to load Gitolfs3 key: {e}")),
        };

        let trusted_forwarded_hosts: HashSet<String> = env
            .trusted_forwarded_hosts
            .split(',')
            .map(|s| s.to_owned())
            .filter(|s| !s.is_empty())
            .collect();
        let base_url = env.base_url.trim_end_matches('/').to_string();

        let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else {
            return Err("configured GITOLFS3_LISTEN_PORT is invalid".to_string());
        };
        let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else {
            return Err("configured GITOLFS3_DOWNLOAD_LIMIT is invalid".to_string());
        };

        Ok(Self {
            listen_addr: (env.listen_host, listen_port),
            base_url,
            authz_conf: AuthorizationConfig {
                key,
                trusted_forwarded_hosts,
            },
            s3_client,
            s3_bucket: env.s3_bucket,
            download_limit,
        })
    }
}
264
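// Wire types for the Git LFS Batch API. Unrecognized transfer adapters and
// hash algorithms deserialize to Unknown via #[serde(other)] instead of
// failing the whole request.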
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
enum TransferAdapter {
    #[serde(rename = "basic")]
    Basic,
    #[serde(other)]
    Unknown,
}

#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
enum HashAlgo {
    #[serde(rename = "sha256")]
    Sha256,
    #[serde(other)]
    Unknown,
}

impl Default for HashAlgo {
    fn default() -> Self {
        Self::Sha256
    }
}

#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
struct BatchRequestObject {
    oid: common::Oid,
    size: i64,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
struct BatchRef {
    name: String,
}

fn default_transfers() -> Vec<TransferAdapter> {
    vec![TransferAdapter::Basic]
}

#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
struct BatchRequest {
    operation: common::Operation,
    #[serde(default = "default_transfers")]
    transfers: Vec<TransferAdapter>,
    objects: Vec<BatchRequestObject>,
    #[serde(default)]
    hash_algo: HashAlgo,
}

#[derive(Debug, Clone)]
struct GitLfsJson<T>(Json<T>);

const LFS_MIME: &str = "application/vnd.git-lfs+json";

enum GitLfsJsonRejection {
    Json(rejection::JsonRejection),
    MissingGitLfsJsonContentType,
}

impl IntoResponse for GitLfsJsonRejection {
    fn into_response(self) -> Response {
        match self {
            Self::Json(rej) => rej.into_response(),
            Self::MissingGitLfsJsonContentType => make_error_resp(
                StatusCode::UNSUPPORTED_MEDIA_TYPE,
                &format!("Expected request with `Content-Type: {LFS_MIME}`"),
            )
            .into_response(),
        }
    }
}

fn is_git_lfs_json_mimetype(mimetype: &str) -> bool {
    let Ok(mime) = mimetype.parse::<mime::Mime>() else {
        return false;
    };
    if mime.type_() != mime::APPLICATION
        || mime.subtype() != "vnd.git-lfs"
        || mime.suffix() != Some(mime::JSON)
    {
        return false;
    }
    match mime.get_param(mime::CHARSET) {
        Some(mime::UTF_8) | None => true,
        Some(_) => false,
    }
}

fn has_git_lfs_json_content_type(req: &Request) -> bool {
    let Some(content_type) = req.headers().get(header::CONTENT_TYPE) else {
        return false;
    };
    let Ok(content_type) = content_type.to_str() else {
        return false;
    };
    is_git_lfs_json_mimetype(content_type)
}

#[async_trait]
impl<T, S> FromRequest<S> for GitLfsJson<T>
where
    T: DeserializeOwned,
    S: Send + Sync,
{
    type Rejection = GitLfsJsonRejection;

    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
        if !has_git_lfs_json_content_type(&req) {
            return Err(GitLfsJsonRejection::MissingGitLfsJsonContentType);
        }
        Json::<T>::from_request(req, state)
            .await
            .map(GitLfsJson)
            .map_err(GitLfsJsonRejection::Json)
    }
}

impl<T: Serialize> IntoResponse for GitLfsJson<T> {
    fn into_response(self) -> Response {
        let GitLfsJson(json) = self;
        let mut resp = json.into_response();
        resp.headers_mut().insert(
            header::CONTENT_TYPE,
            HeaderValue::from_static("application/vnd.git-lfs+json; charset=utf-8"),
        );
        resp
    }
}

#[derive(Debug, Serialize)]
struct GitLfsErrorData<'a> {
    message: &'a str,
}

type GitLfsErrorResponse<'a> = (StatusCode, GitLfsJson<GitLfsErrorData<'a>>);

const fn make_error_resp(code: StatusCode, message: &str) -> GitLfsErrorResponse {
    (code, GitLfsJson(Json(GitLfsErrorData { message })))
}

#[derive(Debug, Serialize, Clone)]
struct BatchResponseObjectAction {
    href: String,
    #[serde(skip_serializing_if = "HashMap::is_empty")]
    header: HashMap<String, String>,
    expires_at: DateTime<Utc>,
}

#[derive(Default, Debug, Serialize, Clone)]
struct BatchResponseObjectActions {
    #[serde(skip_serializing_if = "Option::is_none")]
    upload: Option<BatchResponseObjectAction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    download: Option<BatchResponseObjectAction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    verify: Option<BatchResponseObjectAction>,
}

#[derive(Debug, Clone, Serialize)]
struct BatchResponseObjectError {
    code: u16,
    message: String,
}

#[derive(Debug, Serialize, Clone)]
struct BatchResponseObject {
    oid: common::Oid,
    size: i64,
    #[serde(skip_serializing_if = "Option::is_none")]
    authenticated: Option<bool>,
    actions: BatchResponseObjectActions,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<BatchResponseObjectError>,
}

impl BatchResponseObject {
    fn error(obj: &BatchRequestObject, code: StatusCode, message: String) -> BatchResponseObject {
        BatchResponseObject {
            oid: obj.oid,
            size: obj.size,
            authenticated: None,
            actions: Default::default(),
            error: Some(BatchResponseObjectError {
                code: code.as_u16(),
                message,
            }),
        }
    }
}

#[derive(Debug, Serialize, Clone)]
struct BatchResponse {
    transfer: TransferAdapter,
    objects: Vec<BatchResponseObject>,
    hash_algo: HashAlgo,
}
459
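// Treat the object as valid when S3 reports no usable SHA-256 checksum: only
// a present, decodable, 32-byte checksum that differs from the OID counts as
// a mismatch.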
fn validate_checksum(oid: common::Oid, obj: &HeadObjectOutput) -> bool {
    if let Some(checksum) = obj.checksum_sha256() {
        if let Ok(checksum) = BASE64_STANDARD.decode(checksum) {
            if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) {
                return common::Oid::from(checksum32b) == oid;
            }
        }
    }
    true
}

fn validate_size(expected: i64, obj: &HeadObjectOutput) -> bool {
    if let Some(length) = obj.content_length() {
        return length == expected;
    }
    true
}
477
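// Upload flow: a HeadObject probe first checks whether the object already
// exists with the expected size and checksum; if so, no action is returned
// (None) and the client skips the upload. Otherwise the client gets a
// presigned S3 PUT that is valid for five minutes.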
async fn handle_upload_object(
    state: &AppState,
    repo: &str,
    obj: &BatchRequestObject,
) -> Option<BatchResponseObject> {
    let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1]));
    let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);

    match state
        .s3_client
        .head_object()
        .bucket(&state.s3_bucket)
        .key(full_path.clone())
        .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
        .send()
        .await
    {
        Ok(result) => {
            if validate_size(obj.size, &result) && validate_checksum(obj.oid, &result) {
                return None;
            }
        }
        Err(SdkError::ServiceError(e)) if e.err().is_not_found() => {}
        Err(e) => {
            println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid);
            return Some(BatchResponseObject::error(
                obj,
                StatusCode::INTERNAL_SERVER_ERROR,
                "Failed to query object information".to_string(),
            ));
        }
    };

    let expires_in = std::time::Duration::from_secs(5 * 60);
    let expires_at = Utc::now() + expires_in;

    let Ok(config) = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in) else {
        return Some(BatchResponseObject::error(
            obj,
            StatusCode::INTERNAL_SERVER_ERROR,
            "Failed to generate upload URL".to_string(),
        ));
    };
    let Ok(presigned) = state
        .s3_client
        .put_object()
        .bucket(&state.s3_bucket)
        .key(full_path)
        .checksum_sha256(obj.oid.to_string())
        .content_length(obj.size)
        .presigned(config)
        .await
    else {
        return Some(BatchResponseObject::error(
            obj,
            StatusCode::INTERNAL_SERVER_ERROR,
            "Failed to generate upload URL".to_string(),
        ));
    };
    Some(BatchResponseObject {
        oid: obj.oid,
        size: obj.size,
        authenticated: Some(true),
        actions: BatchResponseObjectActions {
            upload: Some(BatchResponseObjectAction {
                header: presigned
                    .headers()
                    .map(|(k, v)| (k.to_owned(), v.to_owned()))
                    .collect(),
                expires_at,
                href: presigned.uri().to_string(),
            }),
            ..Default::default()
        },
        error: None,
    })
}
555
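// Download flow: trusted clients get a presigned S3 GET directly. Untrusted
// (public) downloads are budgeted against the DownloadLimiter and served
// through this server instead, authorized by a short-lived HMAC tag in the
// Authorization header.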
async fn handle_download_object(
    state: &AppState,
    repo: &str,
    obj: &BatchRequestObject,
    trusted: bool,
) -> BatchResponseObject {
    let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1]));
    let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);

    let result = match state
        .s3_client
        .head_object()
        .bucket(&state.s3_bucket)
        .key(&full_path)
        .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
        .send()
        .await
    {
        Ok(result) => result,
        Err(e) => {
            println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid);
            return BatchResponseObject::error(
                obj,
                StatusCode::INTERNAL_SERVER_ERROR,
                "Failed to query object information".to_string(),
            );
        }
    };

    // Scaleway actually doesn't provide SHA256 support, but maybe in the future :)
    if !validate_checksum(obj.oid, &result) {
        return BatchResponseObject::error(
            obj,
            StatusCode::UNPROCESSABLE_ENTITY,
            "Object corrupted".to_string(),
        );
    }
    if !validate_size(obj.size, &result) {
        return BatchResponseObject::error(
            obj,
            StatusCode::UNPROCESSABLE_ENTITY,
            "Incorrect size specified (or object corrupted)".to_string(),
        );
    }

    let expires_in = std::time::Duration::from_secs(5 * 60);
    let expires_at = Utc::now() + expires_in;

    if trusted {
        let Ok(config) = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in) else {
            return BatchResponseObject::error(
                obj,
                StatusCode::INTERNAL_SERVER_ERROR,
                "Failed to generate download URL".to_string(),
            );
        };
        let Ok(presigned) = state
            .s3_client
            .get_object()
            .bucket(&state.s3_bucket)
            .key(full_path)
            .presigned(config)
            .await
        else {
            return BatchResponseObject::error(
                obj,
                StatusCode::INTERNAL_SERVER_ERROR,
                "Failed to generate download URL".to_string(),
            );
        };
        return BatchResponseObject {
            oid: obj.oid,
            size: obj.size,
            authenticated: Some(true),
            actions: BatchResponseObjectActions {
                download: Some(BatchResponseObjectAction {
                    header: presigned
                        .headers()
                        .map(|(k, v)| (k.to_owned(), v.to_owned()))
                        .collect(),
                    expires_at,
                    href: presigned.uri().to_string(),
                }),
                ..Default::default()
            },
            error: None,
        };
    }

    if let Some(content_length) = result.content_length() {
        if content_length > 0 {
            match state
                .dl_limiter
                .lock()
                .await
                .request(content_length as u64)
                .await
            {
                Ok(true) => {}
                Ok(false) => {
                    return BatchResponseObject::error(
                        obj,
                        StatusCode::SERVICE_UNAVAILABLE,
                        "Public LFS downloads temporarily unavailable".to_string(),
                    );
                }
                Err(e) => {
                    println!("Failed to request {content_length} bytes from download limiter: {e}");
                    return BatchResponseObject::error(
                        obj,
                        StatusCode::INTERNAL_SERVER_ERROR,
                        "Internal server error".to_string(),
                    );
                }
            }
        }
    }

    let Some(tag) = common::generate_tag(
        common::Claims {
            specific_claims: common::SpecificClaims::Download(obj.oid),
            repo_path: repo,
            expires_at,
        },
        &state.authz_conf.key,
    ) else {
        return BatchResponseObject::error(
            obj,
            StatusCode::INTERNAL_SERVER_ERROR,
            "Internal server error".to_string(),
        );
    };

    let download_path = format!(
        "{repo}/info/lfs/objects/{}/{}/{}",
        common::HexByte(obj.oid[0]),
        common::HexByte(obj.oid[1]),
        obj.oid,
    );

    BatchResponseObject {
        oid: obj.oid,
        size: obj.size,
        authenticated: Some(true),
        actions: BatchResponseObjectActions {
            download: Some(BatchResponseObjectAction {
                header: {
                    let mut map = HashMap::new();
                    map.insert(
                        "Authorization".to_string(),
                        format!("Gitolfs3-Hmac-Sha256 {tag} {}", expires_at.timestamp()),
                    );
                    map
                },
                expires_at,
                href: format!("{}/{download_path}", state.base_url),
            }),
            ..Default::default()
        },
        error: None,
    }
}

struct AuthorizationConfig {
    trusted_forwarded_hosts: HashSet<String>,
    key: common::Key,
}

struct Trusted(bool);
725
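// A request counts as trusted when a reverse proxy we trust (e.g. one serving
// a Tailscale network) set X-Forwarded-Host to one of the configured hosts.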
fn forwarded_from_trusted_host(
    headers: &HeaderMap,
    trusted: &HashSet<String>,
) -> Result<bool, GitLfsErrorResponse<'static>> {
    if let Some(forwarded_host) = headers.get("X-Forwarded-Host") {
        if let Ok(forwarded_host) = forwarded_host.to_str() {
            if trusted.contains(forwarded_host) {
                return Ok(true);
            }
        } else {
            return Err(make_error_resp(
                StatusCode::NOT_FOUND,
                "Invalid X-Forwarded-Host header",
            ));
        }
    }
    Ok(false)
}

const REPO_NOT_FOUND: GitLfsErrorResponse =
    make_error_resp(StatusCode::NOT_FOUND, "Repository not found");

fn authorize_batch(
    conf: &AuthorizationConfig,
    repo_path: &str,
    public: bool,
    operation: common::Operation,
    headers: &HeaderMap,
) -> Result<Trusted, GitLfsErrorResponse<'static>> {
    // - No authentication required for downloading exported repos
    // - When authenticated:
    //   - Download / upload over presigned URLs
    // - When accessing over Tailscale:
    //   - No authentication required for downloading from any repo

    let claims = VerifyClaimsInput {
        specific_claims: common::SpecificClaims::BatchApi(operation),
        repo_path,
    };
    if !verify_claims(conf, &claims, headers)? {
        return authorize_batch_unauthenticated(conf, public, operation, headers);
    }
    Ok(Trusted(true))
}

fn authorize_batch_unauthenticated(
    conf: &AuthorizationConfig,
    public: bool,
    operation: common::Operation,
    headers: &HeaderMap,
) -> Result<Trusted, GitLfsErrorResponse<'static>> {
    let trusted = forwarded_from_trusted_host(headers, &conf.trusted_forwarded_hosts)?;
    match operation {
        common::Operation::Upload => {
            // Trusted users can clone all repositories (by virtue of accessing the server via a
            // trusted network). However, they cannot push without proper authentication.
            // Untrusted users who are also not authenticated should not be able to learn which
            // repositories exist. Therefore, we tell untrusted && unauthenticated users that the
            // repo doesn't exist, but tell trusted users that they need to authenticate.
            if !trusted {
                return Err(REPO_NOT_FOUND);
            }
            Err(make_error_resp(
                StatusCode::FORBIDDEN,
                "Authentication required to upload",
            ))
        }
        common::Operation::Download => {
            // Again, trusted users can see all repos. For untrusted users, we first need to check
            // whether the repo is public before we authorize. If the user is untrusted and the
            // repo isn't public, we just act like it doesn't even exist.
            if !trusted {
                if !public {
                    return Err(REPO_NOT_FOUND);
                }
                return Ok(Trusted(false));
            }
            Ok(Trusted(true))
        }
    }
}

fn repo_exists(name: &str) -> bool {
    let Ok(metadata) = std::fs::metadata(name) else {
        return false;
    };
    metadata.is_dir()
}

fn is_repo_public(name: &str) -> Option<bool> {
    if !repo_exists(name) {
        return None;
    }
    match std::fs::metadata(format!("{name}/git-daemon-export-ok")) {
        Ok(metadata) if metadata.is_file() => Some(true),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => Some(false),
        _ => None,
    }
}
825
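// POST /batch: the Git LFS Batch API endpoint. Authorizes the request,
// validates the negotiated content type, hash algorithm and transfer
// adapter, and then produces one upload or download action per object.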
async fn batch(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    RepositoryName(repo): RepositoryName,
    GitLfsJson(Json(payload)): GitLfsJson<BatchRequest>,
) -> Response {
    let Some(public) = is_repo_public(&repo) else {
        return REPO_NOT_FOUND.into_response();
    };
    let Trusted(trusted) = match authorize_batch(
        &state.authz_conf,
        &repo,
        public,
        payload.operation,
        &headers,
    ) {
        Ok(authn) => authn,
        Err(e) => return e.into_response(),
    };

    if !headers
        .get_all("Accept")
        .iter()
        .filter_map(|v| v.to_str().ok())
        .any(is_git_lfs_json_mimetype)
    {
        let message = format!("Expected `{LFS_MIME}` in list of acceptable response media types");
        return make_error_resp(StatusCode::NOT_ACCEPTABLE, &message).into_response();
    }

    if payload.hash_algo != HashAlgo::Sha256 {
        let message = "Unsupported hashing algorithm specified";
        return make_error_resp(StatusCode::CONFLICT, message).into_response();
    }
    if !payload.transfers.is_empty() && !payload.transfers.contains(&TransferAdapter::Basic) {
        let message = "Unsupported transfer adapter specified (supported: basic)";
        return make_error_resp(StatusCode::CONFLICT, message).into_response();
    }

    let mut resp = BatchResponse {
        transfer: TransferAdapter::Basic,
        objects: vec![],
        hash_algo: HashAlgo::Sha256,
    };
    for obj in payload.objects {
        match payload.operation {
            common::Operation::Download => resp
                .objects
                .push(handle_download_object(&state, &repo, &obj, trusted).await),
            common::Operation::Upload => {
                if let Some(obj_resp) = handle_upload_object(&state, &repo, &obj).await {
                    resp.objects.push(obj_resp);
                }
            }
        };
    }
    GitLfsJson(Json(resp)).into_response()
}
884
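// Path parameters for GET /:oid0/:oid1/:oid. The #[serde(remote = "Self")]
// trick lets the manual Deserialize impl below reuse the derived one and then
// check that the two path prefix bytes match the full OID.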
#[derive(Deserialize, Copy, Clone)]
#[serde(remote = "Self")]
struct FileParams {
    oid0: common::HexByte,
    oid1: common::HexByte,
    oid: common::Oid,
}

impl<'de> Deserialize<'de> for FileParams {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let unchecked @ FileParams {
            oid0: common::HexByte(oid0),
            oid1: common::HexByte(oid1),
            oid,
        } = FileParams::deserialize(deserializer)?;
        if oid0 != oid.as_bytes()[0] {
            return Err(de::Error::custom(
                "first OID path part does not match first byte of full OID",
            ));
        }
        if oid1 != oid.as_bytes()[1] {
            return Err(de::Error::custom(
                "second OID path part does not match second byte of full OID",
            ));
        }
        Ok(unchecked)
    }
}

pub struct VerifyClaimsInput<'a> {
    pub specific_claims: common::SpecificClaims,
    pub repo_path: &'a str,
}
921
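// Verifies an `Authorization: Gitolfs3-Hmac-Sha256 <tag> <unix-expiry>`
// header by recomputing the HMAC tag over the claims with the server key.
// Returns Ok(false) if no Authorization header is present at all.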
fn verify_claims(
    conf: &AuthorizationConfig,
    claims: &VerifyClaimsInput,
    headers: &HeaderMap,
) -> Result<bool, GitLfsErrorResponse<'static>> {
    const INVALID_AUTHZ_HEADER: GitLfsErrorResponse =
        make_error_resp(StatusCode::BAD_REQUEST, "Invalid authorization header");

    let Some(authz) = headers.get(header::AUTHORIZATION) else {
        return Ok(false);
    };
    let authz = authz.to_str().map_err(|_| INVALID_AUTHZ_HEADER)?;
    let val = authz
        .strip_prefix("Gitolfs3-Hmac-Sha256 ")
        .ok_or(INVALID_AUTHZ_HEADER)?;
    let (tag, expires_at) = val.split_once(' ').ok_or(INVALID_AUTHZ_HEADER)?;
    let tag: common::Digest<32> = tag.parse().map_err(|_| INVALID_AUTHZ_HEADER)?;
    let expires_at: i64 = expires_at.parse().map_err(|_| INVALID_AUTHZ_HEADER)?;
    let expires_at = DateTime::<Utc>::from_timestamp(expires_at, 0).ok_or(INVALID_AUTHZ_HEADER)?;
    let expected_tag = common::generate_tag(
        common::Claims {
            specific_claims: claims.specific_claims,
            repo_path: claims.repo_path,
            expires_at,
        },
        &conf.key,
    )
    .ok_or_else(|| make_error_resp(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"))?;
    if tag != expected_tag {
        return Err(INVALID_AUTHZ_HEADER);
    }

    Ok(true)
}

fn authorize_get(
    conf: &AuthorizationConfig,
    repo_path: &str,
    oid: common::Oid,
    headers: &HeaderMap,
) -> Result<(), GitLfsErrorResponse<'static>> {
    let claims = VerifyClaimsInput {
        specific_claims: common::SpecificClaims::Download(oid),
        repo_path,
    };
    if !verify_claims(conf, &claims, headers)? {
        return Err(make_error_resp(
            StatusCode::UNAUTHORIZED,
            "Repository not found",
        ));
    }
    Ok(())
}

async fn obj_download(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    RepositoryName(repo): RepositoryName,
    Path(FileParams { oid0, oid1, oid }): Path<FileParams>,
) -> Response {
    if let Err(e) = authorize_get(&state.authz_conf, &repo, oid, &headers) {
        return e.into_response();
    }

    let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, oid);
    let result = match state
        .s3_client
        .get_object()
        .bucket(&state.s3_bucket)
        .key(full_path)
        .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
        .send()
        .await
    {
        Ok(result) => result,
        Err(e) => {
            println!("Failed to GetObject (repo {repo}, OID {oid}): {e}");
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Failed to query object information",
            )
                .into_response();
        }
    };

    let mut headers = header::HeaderMap::new();
    if let Some(content_type) = result.content_type {
        let Ok(header_value) = content_type.try_into() else {
            return (
                StatusCode::INTERNAL_SERVER_ERROR,
                "Object has invalid content type",
            )
                .into_response();
        };
        headers.insert(header::CONTENT_TYPE, header_value);
    }
    if let Some(content_length) = result.content_length {
        headers.insert(header::CONTENT_LENGTH, content_length.into());
    }

    let async_read = result.body.into_async_read();
    let stream = tokio_util::io::ReaderStream::new(async_read);
    let body = axum::body::Body::from_stream(stream);

    (headers, body).into_response()
}
1028
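// Byte budget for public (untrusted) downloads. The counter is persisted to
// .gitolfs3-dlimit in the working directory so it survives restarts, and is
// reset to zero by the hourly task spawned in main.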
struct DownloadLimiter {
    current: u64,
    limit: u64,
}

impl DownloadLimiter {
    async fn new(limit: u64) -> DownloadLimiter {
        let dlimit_str = match tokio::fs::read_to_string(".gitolfs3-dlimit").await {
            Ok(dlimit_str) => dlimit_str,
            Err(e) => {
                println!("Failed to read download counter, assuming 0: {e}");
                return DownloadLimiter { current: 0, limit };
            }
        };
        let current: u64 = match dlimit_str
            .parse()
            .map_err(|e| tokio::io::Error::new(tokio::io::ErrorKind::InvalidData, e))
        {
            Ok(current) => current,
            Err(e) => {
                println!("Failed to read download counter, assuming 0: {e}");
                return DownloadLimiter { current: 0, limit };
            }
        };
        DownloadLimiter { current, limit }
    }

    async fn request(&mut self, n: u64) -> tokio::io::Result<bool> {
        if self.current + n > self.limit {
            return Ok(false);
        }
        self.current += n;
        self.write_new_count().await?;
        Ok(true)
    }

    async fn reset(&mut self) {
        self.current = 0;
        if let Err(e) = self.write_new_count().await {
            println!("Failed to reset download counter: {e}");
        }
    }

    async fn write_new_count(&self) -> tokio::io::Result<()> {
        let cwd = tokio::fs::File::open(std::env::current_dir()?).await?;
        let mut file = tokio::fs::File::create(".gitolfs3-dlimit.tmp").await?;
        file.write_all(self.current.to_string().as_bytes()).await?;
        file.sync_all().await?;
        tokio::fs::rename(".gitolfs3-dlimit.tmp", ".gitolfs3-dlimit").await?;
        cwd.sync_all().await
    }
}

#[test]
fn test_mimetype() {
    assert!(is_git_lfs_json_mimetype("application/vnd.git-lfs+json"));
    assert!(!is_git_lfs_json_mimetype("application/vnd.git-lfs"));
    assert!(!is_git_lfs_json_mimetype("application/json"));
    assert!(is_git_lfs_json_mimetype(
        "application/vnd.git-lfs+json; charset=utf-8"
    ));
    assert!(is_git_lfs_json_mimetype(
        "application/vnd.git-lfs+json; charset=UTF-8"
    ));
    assert!(!is_git_lfs_json_mimetype(
        "application/vnd.git-lfs+json; charset=ISO-8859-1"
    ));
}

#[test]
fn test_deserialize() {
    let json = r#"{"operation":"upload","objects":[{"oid":"8f4123f9a7181f488c5e111d82cefd992e461ae5df01fd2254399e6e670b2d3c","size":170904}],
        "transfers":["lfs-standalone-file","basic","ssh"],"ref":{"name":"refs/heads/main"},"hash_algo":"sha256"}"#;
    let expected = BatchRequest {
        operation: common::Operation::Upload,
        objects: vec![BatchRequestObject {
            oid: "8f4123f9a7181f488c5e111d82cefd992e461ae5df01fd2254399e6e670b2d3c"
                .parse()
                .unwrap(),
            size: 170904,
        }],
        transfers: vec![
            TransferAdapter::Unknown,
            TransferAdapter::Basic,
            TransferAdapter::Unknown,
        ],
        hash_algo: HashAlgo::Sha256,
    };
    assert_eq!(
        serde_json::from_str::<BatchRequest>(json).unwrap(),
        expected
    );
}

#[test]
fn test_validate_claims() {
    let key = "00232f7a019bd34e3921ee6c5f04caf48a4489d1be5d1999038950a7054e0bfea369ce2becc0f13fd3c69f8af2384a25b7ac2d52eb52c33722f3c00c50d4c9c2";
    let key: common::Key = key.parse().unwrap();

    let claims = common::Claims {
        expires_at: Utc::now() + std::time::Duration::from_secs(5 * 60),
        repo_path: "lfs-test.git",
        specific_claims: common::SpecificClaims::BatchApi(common::Operation::Download),
    };
    let tag = common::generate_tag(claims, &key).unwrap();
    let header_value = format!(
        "Gitolfs3-Hmac-Sha256 {tag} {}",
        claims.expires_at.timestamp()
    );

    let conf = AuthorizationConfig {
        key,
        trusted_forwarded_hosts: HashSet::new(),
    };
    let verification_claims = VerifyClaimsInput {
        repo_path: claims.repo_path,
        specific_claims: claims.specific_claims,
    };
    let mut headers = HeaderMap::new();
    headers.insert(header::AUTHORIZATION, header_value.try_into().unwrap());

    assert!(verify_claims(&conf, &verification_claims, &headers).unwrap());
}