aboutsummaryrefslogtreecommitdiffstats
path: root/gitolfs3-server/src
diff options
context:
space:
mode:
Diffstat (limited to 'gitolfs3-server/src')
-rw-r--r--gitolfs3-server/src/api.rs279
-rw-r--r--gitolfs3-server/src/authz.rs182
-rw-r--r--gitolfs3-server/src/config.rs122
-rw-r--r--gitolfs3-server/src/dlimit.rs73
-rw-r--r--gitolfs3-server/src/handler.rs455
-rw-r--r--gitolfs3-server/src/main.rs1087
6 files changed, 1128 insertions, 1070 deletions
diff --git a/gitolfs3-server/src/api.rs b/gitolfs3-server/src/api.rs
new file mode 100644
index 0000000..dba7ada
--- /dev/null
+++ b/gitolfs3-server/src/api.rs
@@ -0,0 +1,279 @@
1use std::collections::HashMap;
2
3use axum::{
4 async_trait,
5 extract::{rejection, FromRequest, FromRequestParts, Request},
6 http::{header, request::Parts, HeaderValue, StatusCode},
7 response::{IntoResponse, Response},
8 Extension, Json,
9};
10use chrono::{DateTime, Utc};
11use gitolfs3_common::{Oid, Operation};
12use serde::{de::DeserializeOwned, Deserialize, Serialize};
13
14pub const REPO_NOT_FOUND: GitLfsErrorResponse =
15 make_error_resp(StatusCode::NOT_FOUND, "Repository not found");
16
17#[derive(Clone)]
18pub struct RepositoryName(pub String);
19
20pub struct RepositoryNameRejection;
21
22impl IntoResponse for RepositoryNameRejection {
23 fn into_response(self) -> Response {
24 (StatusCode::INTERNAL_SERVER_ERROR, "Missing repository name").into_response()
25 }
26}
27
28#[async_trait]
29impl<S: Send + Sync> FromRequestParts<S> for RepositoryName {
30 type Rejection = RepositoryNameRejection;
31
32 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
33 let Ok(Extension(repo_name)) = Extension::<Self>::from_request_parts(parts, state).await
34 else {
35 return Err(RepositoryNameRejection);
36 };
37 Ok(repo_name)
38 }
39}
40
41#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
42pub enum TransferAdapter {
43 #[serde(rename = "basic")]
44 Basic,
45 #[serde(other)]
46 Unknown,
47}
48
49#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
50pub enum HashAlgo {
51 #[serde(rename = "sha256")]
52 Sha256,
53 #[serde(other)]
54 Unknown,
55}
56
57impl Default for HashAlgo {
58 fn default() -> Self {
59 Self::Sha256
60 }
61}
62
63#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
64pub struct BatchRequestObject {
65 pub oid: Oid,
66 pub size: i64,
67}
68
69#[derive(Debug, Serialize, Deserialize, Clone)]
70struct BatchRef {
71 name: String,
72}
73
74fn default_transfers() -> Vec<TransferAdapter> {
75 vec![TransferAdapter::Basic]
76}
77
78#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
79pub struct BatchRequest {
80 pub operation: Operation,
81 #[serde(default = "default_transfers")]
82 pub transfers: Vec<TransferAdapter>,
83 pub objects: Vec<BatchRequestObject>,
84 #[serde(default)]
85 pub hash_algo: HashAlgo,
86}
87
88#[derive(Debug, Clone)]
89pub struct GitLfsJson<T>(pub Json<T>);
90
91pub const LFS_MIME: &str = "application/vnd.git-lfs+json";
92
93pub enum GitLfsJsonRejection {
94 Json(rejection::JsonRejection),
95 MissingGitLfsJsonContentType,
96}
97
98impl IntoResponse for GitLfsJsonRejection {
99 fn into_response(self) -> Response {
100 match self {
101 Self::Json(rej) => rej.into_response(),
102 Self::MissingGitLfsJsonContentType => make_error_resp(
103 StatusCode::UNSUPPORTED_MEDIA_TYPE,
104 &format!("Expected request with `Content-Type: {LFS_MIME}`"),
105 )
106 .into_response(),
107 }
108 }
109}
110
111pub fn is_git_lfs_json_mimetype(mimetype: &str) -> bool {
112 let Ok(mime) = mimetype.parse::<mime::Mime>() else {
113 return false;
114 };
115 if mime.type_() != mime::APPLICATION
116 || mime.subtype() != "vnd.git-lfs"
117 || mime.suffix() != Some(mime::JSON)
118 {
119 return false;
120 }
121 match mime.get_param(mime::CHARSET) {
122 Some(mime::UTF_8) | None => true,
123 Some(_) => false,
124 }
125}
126
127fn has_git_lfs_json_content_type(req: &Request) -> bool {
128 let Some(content_type) = req.headers().get(header::CONTENT_TYPE) else {
129 return false;
130 };
131 let Ok(content_type) = content_type.to_str() else {
132 return false;
133 };
134 is_git_lfs_json_mimetype(content_type)
135}
136
137#[async_trait]
138impl<T, S> FromRequest<S> for GitLfsJson<T>
139where
140 T: DeserializeOwned,
141 S: Send + Sync,
142{
143 type Rejection = GitLfsJsonRejection;
144
145 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
146 if !has_git_lfs_json_content_type(&req) {
147 return Err(GitLfsJsonRejection::MissingGitLfsJsonContentType);
148 }
149 Json::<T>::from_request(req, state)
150 .await
151 .map(GitLfsJson)
152 .map_err(GitLfsJsonRejection::Json)
153 }
154}
155
156impl<T: Serialize> IntoResponse for GitLfsJson<T> {
157 fn into_response(self) -> Response {
158 let GitLfsJson(json) = self;
159 let mut resp = json.into_response();
160 resp.headers_mut().insert(
161 header::CONTENT_TYPE,
162 HeaderValue::from_static("application/vnd.git-lfs+json; charset=utf-8"),
163 );
164 resp
165 }
166}
167
168#[derive(Debug, Serialize)]
169pub struct GitLfsErrorData<'a> {
170 pub message: &'a str,
171}
172
173pub type GitLfsErrorResponse<'a> = (StatusCode, GitLfsJson<GitLfsErrorData<'a>>);
174
175pub const fn make_error_resp(code: StatusCode, message: &str) -> GitLfsErrorResponse {
176 (code, GitLfsJson(Json(GitLfsErrorData { message })))
177}
178
179#[derive(Debug, Serialize, Clone)]
180pub struct BatchResponseObjectAction {
181 pub href: String,
182 #[serde(skip_serializing_if = "HashMap::is_empty")]
183 pub header: HashMap<String, String>,
184 pub expires_at: DateTime<Utc>,
185}
186
187#[derive(Default, Debug, Serialize, Clone)]
188pub struct BatchResponseObjectActions {
189 #[serde(skip_serializing_if = "Option::is_none")]
190 pub upload: Option<BatchResponseObjectAction>,
191 #[serde(skip_serializing_if = "Option::is_none")]
192 pub download: Option<BatchResponseObjectAction>,
193 #[serde(skip_serializing_if = "Option::is_none")]
194 pub verify: Option<BatchResponseObjectAction>,
195}
196
197#[derive(Debug, Clone, Serialize)]
198pub struct BatchResponseObjectError {
199 pub code: u16,
200 pub message: String,
201}
202
203#[derive(Debug, Serialize, Clone)]
204pub struct BatchResponseObject {
205 pub oid: Oid,
206 pub size: i64,
207 #[serde(skip_serializing_if = "Option::is_none")]
208 pub authenticated: Option<bool>,
209 pub actions: BatchResponseObjectActions,
210 #[serde(skip_serializing_if = "Option::is_none")]
211 pub error: Option<BatchResponseObjectError>,
212}
213
214impl BatchResponseObject {
215 pub fn error(
216 obj: &BatchRequestObject,
217 code: StatusCode,
218 message: String,
219 ) -> BatchResponseObject {
220 BatchResponseObject {
221 oid: obj.oid,
222 size: obj.size,
223 authenticated: None,
224 actions: Default::default(),
225 error: Some(BatchResponseObjectError {
226 code: code.as_u16(),
227 message,
228 }),
229 }
230 }
231}
232
233#[derive(Debug, Serialize, Clone)]
234pub struct BatchResponse {
235 pub transfer: TransferAdapter,
236 pub objects: Vec<BatchResponseObject>,
237 pub hash_algo: HashAlgo,
238}
239
240#[test]
241fn test_mimetype() {
242 assert!(is_git_lfs_json_mimetype("application/vnd.git-lfs+json"));
243 assert!(!is_git_lfs_json_mimetype("application/vnd.git-lfs"));
244 assert!(!is_git_lfs_json_mimetype("application/json"));
245 assert!(is_git_lfs_json_mimetype(
246 "application/vnd.git-lfs+json; charset=utf-8"
247 ));
248 assert!(is_git_lfs_json_mimetype(
249 "application/vnd.git-lfs+json; charset=UTF-8"
250 ));
251 assert!(!is_git_lfs_json_mimetype(
252 "application/vnd.git-lfs+json; charset=ISO-8859-1"
253 ));
254}
255
256#[test]
257fn test_deserialize() {
258 let json = r#"{"operation":"upload","objects":[{"oid":"8f4123f9a7181f488c5e111d82cefd992e461ae5df01fd2254399e6e670b2d3c","size":170904}],
259 "transfers":["lfs-standalone-file","basic","ssh"],"ref":{"name":"refs/heads/main"},"hash_algo":"sha256"}"#;
260 let expected = BatchRequest {
261 operation: Operation::Upload,
262 objects: vec![BatchRequestObject {
263 oid: "8f4123f9a7181f488c5e111d82cefd992e461ae5df01fd2254399e6e670b2d3c"
264 .parse()
265 .unwrap(),
266 size: 170904,
267 }],
268 transfers: vec![
269 TransferAdapter::Unknown,
270 TransferAdapter::Basic,
271 TransferAdapter::Unknown,
272 ],
273 hash_algo: HashAlgo::Sha256,
274 };
275 assert_eq!(
276 serde_json::from_str::<BatchRequest>(json).unwrap(),
277 expected
278 );
279}
diff --git a/gitolfs3-server/src/authz.rs b/gitolfs3-server/src/authz.rs
new file mode 100644
index 0000000..0674cef
--- /dev/null
+++ b/gitolfs3-server/src/authz.rs
@@ -0,0 +1,182 @@
1use std::collections::HashSet;
2
3use axum::http::{header, HeaderMap, StatusCode};
4use chrono::{DateTime, Utc};
5use gitolfs3_common::{generate_tag, Claims, Digest, Oid, Operation, SpecificClaims};
6
7use crate::{
8 api::{make_error_resp, GitLfsErrorResponse, REPO_NOT_FOUND},
9 config::AuthorizationConfig,
10};
11
12pub struct Trusted(pub bool);
13
14fn forwarded_from_trusted_host(
15 headers: &HeaderMap,
16 trusted: &HashSet<String>,
17) -> Result<bool, GitLfsErrorResponse<'static>> {
18 if let Some(forwarded_host) = headers.get("X-Forwarded-Host") {
19 if let Ok(forwarded_host) = forwarded_host.to_str() {
20 if trusted.contains(forwarded_host) {
21 return Ok(true);
22 }
23 } else {
24 return Err(make_error_resp(
25 StatusCode::NOT_FOUND,
26 "Invalid X-Forwarded-Host header",
27 ));
28 }
29 }
30 Ok(false)
31}
32
33pub fn authorize_batch(
34 conf: &AuthorizationConfig,
35 repo_path: &str,
36 public: bool,
37 operation: Operation,
38 headers: &HeaderMap,
39) -> Result<Trusted, GitLfsErrorResponse<'static>> {
40 // - No authentication required for downloading exported repos
41 // - When authenticated:
42 // - Download / upload over presigned URLs
43 // - When accessing over Tailscale:
44 // - No authentication required for downloading from any repo
45
46 let claims = VerifyClaimsInput {
47 specific_claims: SpecificClaims::BatchApi(operation),
48 repo_path,
49 };
50 if !verify_claims(conf, &claims, headers)? {
51 return authorize_batch_unauthenticated(conf, public, operation, headers);
52 }
53 Ok(Trusted(true))
54}
55
56fn authorize_batch_unauthenticated(
57 conf: &AuthorizationConfig,
58 public: bool,
59 operation: Operation,
60 headers: &HeaderMap,
61) -> Result<Trusted, GitLfsErrorResponse<'static>> {
62 let trusted = forwarded_from_trusted_host(headers, &conf.trusted_forwarded_hosts)?;
63 match operation {
64 Operation::Upload => {
65 // Trusted users can clone all repositories (by virtue of accessing the server via a
66 // trusted network). However, they can not push without proper authentication. Untrusted
67 // users who are also not authenticated should not need to know which repositories exists.
68 // Therefore, we tell untrusted && unauthenticated users that the repo doesn't exist, but
69 // tell trusted users that they need to authenticate.
70 if !trusted {
71 return Err(REPO_NOT_FOUND);
72 }
73 Err(make_error_resp(
74 StatusCode::FORBIDDEN,
75 "Authentication required to upload",
76 ))
77 }
78 Operation::Download => {
79 // Again, trusted users can see all repos. For untrusted users, we first need to check
80 // whether the repo is public before we authorize. If the user is untrusted and the
81 // repo isn't public, we just act like it doesn't even exist.
82 if !trusted {
83 if !public {
84 return Err(REPO_NOT_FOUND);
85 }
86 return Ok(Trusted(false));
87 }
88 Ok(Trusted(true))
89 }
90 }
91}
92
93pub fn authorize_get(
94 conf: &AuthorizationConfig,
95 repo_path: &str,
96 oid: Oid,
97 headers: &HeaderMap,
98) -> Result<(), GitLfsErrorResponse<'static>> {
99 let claims = VerifyClaimsInput {
100 specific_claims: SpecificClaims::Download(oid),
101 repo_path,
102 };
103 if !verify_claims(conf, &claims, headers)? {
104 return Err(make_error_resp(
105 StatusCode::UNAUTHORIZED,
106 "Repository not found",
107 ));
108 }
109 Ok(())
110}
111
112pub struct VerifyClaimsInput<'a> {
113 pub specific_claims: SpecificClaims,
114 pub repo_path: &'a str,
115}
116
117fn verify_claims(
118 conf: &AuthorizationConfig,
119 claims: &VerifyClaimsInput,
120 headers: &HeaderMap,
121) -> Result<bool, GitLfsErrorResponse<'static>> {
122 const INVALID_AUTHZ_HEADER: GitLfsErrorResponse =
123 make_error_resp(StatusCode::BAD_REQUEST, "Invalid authorization header");
124
125 let Some(authz) = headers.get(header::AUTHORIZATION) else {
126 return Ok(false);
127 };
128 let authz = authz.to_str().map_err(|_| INVALID_AUTHZ_HEADER)?;
129 let val = authz
130 .strip_prefix("Gitolfs3-Hmac-Sha256 ")
131 .ok_or(INVALID_AUTHZ_HEADER)?;
132 let (tag, expires_at) = val.split_once(' ').ok_or(INVALID_AUTHZ_HEADER)?;
133 let tag: Digest<32> = tag.parse().map_err(|_| INVALID_AUTHZ_HEADER)?;
134 let expires_at: i64 = expires_at.parse().map_err(|_| INVALID_AUTHZ_HEADER)?;
135 let expires_at = DateTime::<Utc>::from_timestamp(expires_at, 0).ok_or(INVALID_AUTHZ_HEADER)?;
136 let expected_tag = generate_tag(
137 Claims {
138 specific_claims: claims.specific_claims,
139 repo_path: claims.repo_path,
140 expires_at,
141 },
142 &conf.key,
143 )
144 .ok_or_else(|| make_error_resp(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"))?;
145 if tag != expected_tag {
146 return Err(INVALID_AUTHZ_HEADER);
147 }
148
149 Ok(true)
150}
151
152#[test]
153fn test_validate_claims() {
154 use gitolfs3_common::Key;
155
156 let key = "00232f7a019bd34e3921ee6c5f04caf48a4489d1be5d1999038950a7054e0bfea369ce2becc0f13fd3c69f8af2384a25b7ac2d52eb52c33722f3c00c50d4c9c2";
157 let key: Key = key.parse().unwrap();
158
159 let claims = Claims {
160 expires_at: Utc::now() + std::time::Duration::from_secs(5 * 60),
161 repo_path: "lfs-test.git",
162 specific_claims: SpecificClaims::BatchApi(Operation::Download),
163 };
164 let tag = generate_tag(claims, &key).unwrap();
165 let header_value = format!(
166 "Gitolfs3-Hmac-Sha256 {tag} {}",
167 claims.expires_at.timestamp()
168 );
169
170 let conf = AuthorizationConfig {
171 key,
172 trusted_forwarded_hosts: HashSet::new(),
173 };
174 let verification_claims = VerifyClaimsInput {
175 repo_path: claims.repo_path,
176 specific_claims: claims.specific_claims,
177 };
178 let mut headers = HeaderMap::new();
179 headers.insert(header::AUTHORIZATION, header_value.try_into().unwrap());
180
181 assert!(verify_claims(&conf, &verification_claims, &headers).unwrap());
182}
diff --git a/gitolfs3-server/src/config.rs b/gitolfs3-server/src/config.rs
new file mode 100644
index 0000000..75e84dc
--- /dev/null
+++ b/gitolfs3-server/src/config.rs
@@ -0,0 +1,122 @@
1use std::collections::HashSet;
2
3use gitolfs3_common::{load_key, Key};
4
5struct Env {
6 s3_access_key_id: String,
7 s3_secret_access_key: String,
8 s3_bucket: String,
9 s3_region: String,
10 s3_endpoint: String,
11 base_url: String,
12 key_path: String,
13 listen_host: String,
14 listen_port: String,
15 download_limit: String,
16 trusted_forwarded_hosts: String,
17}
18
19fn require_env(name: &str) -> Result<String, String> {
20 std::env::var(name)
21 .map_err(|_| format!("environment variable {name} should be defined and valid"))
22}
23
24impl Env {
25 fn load() -> Result<Env, String> {
26 Ok(Env {
27 s3_secret_access_key: require_env("GITOLFS3_S3_SECRET_ACCESS_KEY_FILE")?,
28 s3_access_key_id: require_env("GITOLFS3_S3_ACCESS_KEY_ID_FILE")?,
29 s3_region: require_env("GITOLFS3_S3_REGION")?,
30 s3_endpoint: require_env("GITOLFS3_S3_ENDPOINT")?,
31 s3_bucket: require_env("GITOLFS3_S3_BUCKET")?,
32 base_url: require_env("GITOLFS3_BASE_URL")?,
33 key_path: require_env("GITOLFS3_KEY_PATH")?,
34 listen_host: require_env("GITOLFS3_LISTEN_HOST")?,
35 listen_port: require_env("GITOLFS3_LISTEN_PORT")?,
36 download_limit: require_env("GITOLFS3_DOWNLOAD_LIMIT")?,
37 trusted_forwarded_hosts: std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS")
38 .unwrap_or_default(),
39 })
40 }
41}
42
43fn get_s3_client(env: &Env) -> Result<aws_sdk_s3::Client, std::io::Error> {
44 let access_key_id = std::fs::read_to_string(&env.s3_access_key_id)?;
45 let secret_access_key = std::fs::read_to_string(&env.s3_secret_access_key)?;
46
47 let credentials = aws_sdk_s3::config::Credentials::new(
48 access_key_id,
49 secret_access_key,
50 None,
51 None,
52 "gitolfs3-env",
53 );
54 let config = aws_config::SdkConfig::builder()
55 .behavior_version(aws_config::BehaviorVersion::latest())
56 .region(aws_config::Region::new(env.s3_region.clone()))
57 .endpoint_url(&env.s3_endpoint)
58 .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new(
59 credentials,
60 ))
61 .build();
62 Ok(aws_sdk_s3::Client::new(&config))
63}
64
65pub struct Config {
66 pub listen_addr: (String, u16),
67 pub base_url: String,
68 pub authz_conf: AuthorizationConfig,
69 pub s3_client: aws_sdk_s3::Client,
70 pub s3_bucket: String,
71 pub download_limit: u64,
72}
73
74pub struct AuthorizationConfig {
75 pub trusted_forwarded_hosts: HashSet<String>,
76 pub key: Key,
77}
78
79impl Config {
80 pub fn load() -> Result<Self, String> {
81 let env = match Env::load() {
82 Ok(env) => env,
83 Err(e) => return Err(format!("failed to load configuration: {e}")),
84 };
85
86 let s3_client = match get_s3_client(&env) {
87 Ok(s3_client) => s3_client,
88 Err(e) => return Err(format!("failed to create S3 client: {e}")),
89 };
90 let key = match load_key(&env.key_path) {
91 Ok(key) => key,
92 Err(e) => return Err(format!("failed to load Gitolfs3 key: {e}")),
93 };
94
95 let trusted_forwarded_hosts: HashSet<String> = env
96 .trusted_forwarded_hosts
97 .split(',')
98 .map(|s| s.to_owned())
99 .filter(|s| !s.is_empty())
100 .collect();
101 let base_url = env.base_url.trim_end_matches('/').to_string();
102
103 let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else {
104 return Err("configured GITOLFS3_LISTEN_PORT is invalid".to_string());
105 };
106 let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else {
107 return Err("configured GITOLFS3_DOWNLOAD_LIMIT is invalid".to_string());
108 };
109
110 Ok(Self {
111 listen_addr: (env.listen_host, listen_port),
112 base_url,
113 authz_conf: AuthorizationConfig {
114 key,
115 trusted_forwarded_hosts,
116 },
117 s3_client,
118 s3_bucket: env.s3_bucket,
119 download_limit,
120 })
121 }
122}
diff --git a/gitolfs3-server/src/dlimit.rs b/gitolfs3-server/src/dlimit.rs
new file mode 100644
index 0000000..f68bec1
--- /dev/null
+++ b/gitolfs3-server/src/dlimit.rs
@@ -0,0 +1,73 @@
1use std::sync::Arc;
2use std::time::Duration;
3use tokio::io::AsyncWriteExt;
4use tokio::sync::Mutex;
5
6// I know that this is pretty bad, but it's good enough (??) for now.
7pub struct DownloadLimiter {
8 current: u64,
9 limit: u64,
10}
11
12impl DownloadLimiter {
13 pub async fn new(limit: u64) -> Arc<Mutex<DownloadLimiter>> {
14 let dlimit_str = match tokio::fs::read_to_string(".gitolfs3-dlimit").await {
15 Ok(dlimit_str) => dlimit_str,
16 Err(e) => {
17 println!("Failed to read download counter, assuming 0: {e}");
18 return DownloadLimiter { current: 0, limit }.auto_resetting();
19 }
20 };
21 let current: u64 = match dlimit_str
22 .parse()
23 .map_err(|e| tokio::io::Error::new(tokio::io::ErrorKind::InvalidData, e))
24 {
25 Ok(current) => current,
26 Err(e) => {
27 println!("Failed to read download counter, assuming 0: {e}");
28 return DownloadLimiter { current: 0, limit }.auto_resetting();
29 }
30 };
31
32 DownloadLimiter { current, limit }.auto_resetting()
33 }
34
35 fn auto_resetting(self) -> Arc<Mutex<Self>> {
36 let limiter_ref = Arc::new(Mutex::new(self));
37 let limiter_ref_cloned = limiter_ref.clone();
38 tokio::spawn(async move {
39 loop {
40 println!("Resetting download counter in one hour");
41 tokio::time::sleep(Duration::from_secs(3600)).await;
42 println!("Resetting download counter");
43 limiter_ref_cloned.lock().await.reset().await;
44 }
45 });
46 limiter_ref
47 }
48
49 pub async fn request(&mut self, n: u64) -> tokio::io::Result<bool> {
50 if self.current + n > self.limit {
51 return Ok(false);
52 }
53 self.current += n;
54 self.write_new_count().await?;
55 Ok(true)
56 }
57
58 pub async fn reset(&mut self) {
59 self.current = 0;
60 if let Err(e) = self.write_new_count().await {
61 println!("Failed to reset download counter: {e}");
62 }
63 }
64
65 async fn write_new_count(&self) -> tokio::io::Result<()> {
66 let cwd = tokio::fs::File::open(std::env::current_dir()?).await?;
67 let mut file = tokio::fs::File::create(".gitolfs3-dlimit.tmp").await?;
68 file.write_all(self.current.to_string().as_bytes()).await?;
69 file.sync_all().await?;
70 tokio::fs::rename(".gitolfs3-dlimit.tmp", ".gitolfs3-dlimit").await?;
71 cwd.sync_all().await
72 }
73}
diff --git a/gitolfs3-server/src/handler.rs b/gitolfs3-server/src/handler.rs
new file mode 100644
index 0000000..6516291
--- /dev/null
+++ b/gitolfs3-server/src/handler.rs
@@ -0,0 +1,455 @@
1use std::{collections::HashMap, sync::Arc};
2
3use aws_sdk_s3::{error::SdkError, operation::head_object::HeadObjectOutput};
4use axum::{
5 extract::{Path, State},
6 http::{header, HeaderMap, StatusCode},
7 response::{IntoResponse, Response},
8 Json,
9};
10use base64::{prelude::BASE64_STANDARD, Engine};
11use chrono::Utc;
12use gitolfs3_common::{generate_tag, Claims, HexByte, Oid, Operation, SpecificClaims};
13use serde::{de, Deserialize};
14use tokio::sync::Mutex;
15
16use crate::{
17 api::{
18 is_git_lfs_json_mimetype, make_error_resp, BatchRequest, BatchRequestObject, BatchResponse,
19 BatchResponseObject, BatchResponseObjectAction, BatchResponseObjectActions, GitLfsJson,
20 HashAlgo, RepositoryName, TransferAdapter, LFS_MIME, REPO_NOT_FOUND,
21 },
22 authz::{authorize_batch, authorize_get, Trusted},
23 config::AuthorizationConfig,
24 dlimit::DownloadLimiter,
25};
26
27pub struct AppState {
28 pub s3_client: aws_sdk_s3::Client,
29 pub s3_bucket: String,
30 pub authz_conf: AuthorizationConfig,
31 // Should not end with a slash.
32 pub base_url: String,
33 pub dl_limiter: Arc<Mutex<DownloadLimiter>>,
34}
35
36fn validate_checksum(oid: Oid, obj: &HeadObjectOutput) -> bool {
37 if let Some(checksum) = obj.checksum_sha256() {
38 if let Ok(checksum) = BASE64_STANDARD.decode(checksum) {
39 if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) {
40 return Oid::from(checksum32b) == oid;
41 }
42 }
43 }
44 true
45}
46
47fn validate_size(expected: i64, obj: &HeadObjectOutput) -> bool {
48 if let Some(length) = obj.content_length() {
49 return length == expected;
50 }
51 true
52}
53
54async fn handle_upload_object(
55 state: &AppState,
56 repo: &str,
57 obj: &BatchRequestObject,
58) -> Option<BatchResponseObject> {
59 let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
60 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
61
62 match state
63 .s3_client
64 .head_object()
65 .bucket(&state.s3_bucket)
66 .key(full_path.clone())
67 .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
68 .send()
69 .await
70 {
71 Ok(result) => {
72 if validate_size(obj.size, &result) && validate_checksum(obj.oid, &result) {
73 return None;
74 }
75 }
76 Err(SdkError::ServiceError(e)) if e.err().is_not_found() => {}
77 Err(e) => {
78 println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid);
79 return Some(BatchResponseObject::error(
80 obj,
81 StatusCode::INTERNAL_SERVER_ERROR,
82 "Failed to query object information".to_string(),
83 ));
84 }
85 };
86
87 let expires_in = std::time::Duration::from_secs(5 * 60);
88 let expires_at = Utc::now() + expires_in;
89
90 let Ok(config) = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in) else {
91 return Some(BatchResponseObject::error(
92 obj,
93 StatusCode::INTERNAL_SERVER_ERROR,
94 "Failed to generate upload URL".to_string(),
95 ));
96 };
97 let Ok(presigned) = state
98 .s3_client
99 .put_object()
100 .bucket(&state.s3_bucket)
101 .key(full_path)
102 .checksum_sha256(obj.oid.to_string())
103 .content_length(obj.size)
104 .presigned(config)
105 .await
106 else {
107 return Some(BatchResponseObject::error(
108 obj,
109 StatusCode::INTERNAL_SERVER_ERROR,
110 "Failed to generate upload URL".to_string(),
111 ));
112 };
113 Some(BatchResponseObject {
114 oid: obj.oid,
115 size: obj.size,
116 authenticated: Some(true),
117 actions: BatchResponseObjectActions {
118 upload: Some(BatchResponseObjectAction {
119 header: presigned
120 .headers()
121 .map(|(k, v)| (k.to_owned(), v.to_owned()))
122 .collect(),
123 expires_at,
124 href: presigned.uri().to_string(),
125 }),
126 ..Default::default()
127 },
128 error: None,
129 })
130}
131
132async fn handle_download_object(
133 state: &AppState,
134 repo: &str,
135 obj: &BatchRequestObject,
136 trusted: bool,
137) -> BatchResponseObject {
138 let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
139 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
140
141 let result = match state
142 .s3_client
143 .head_object()
144 .bucket(&state.s3_bucket)
145 .key(&full_path)
146 .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
147 .send()
148 .await
149 {
150 Ok(result) => result,
151 Err(e) => {
152 println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid);
153 return BatchResponseObject::error(
154 obj,
155 StatusCode::INTERNAL_SERVER_ERROR,
156 "Failed to query object information".to_string(),
157 );
158 }
159 };
160
161 // Scaleway actually doesn't provide SHA256 suport, but maybe in the future :)
162 if !validate_checksum(obj.oid, &result) {
163 return BatchResponseObject::error(
164 obj,
165 StatusCode::UNPROCESSABLE_ENTITY,
166 "Object corrupted".to_string(),
167 );
168 }
169 if !validate_size(obj.size, &result) {
170 return BatchResponseObject::error(
171 obj,
172 StatusCode::UNPROCESSABLE_ENTITY,
173 "Incorrect size specified (or object corrupted)".to_string(),
174 );
175 }
176
177 let expires_in = std::time::Duration::from_secs(5 * 60);
178 let expires_at = Utc::now() + expires_in;
179
180 if trusted {
181 let Ok(config) = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in) else {
182 return BatchResponseObject::error(
183 obj,
184 StatusCode::INTERNAL_SERVER_ERROR,
185 "Failed to generate upload URL".to_string(),
186 );
187 };
188 let Ok(presigned) = state
189 .s3_client
190 .get_object()
191 .bucket(&state.s3_bucket)
192 .key(full_path)
193 .presigned(config)
194 .await
195 else {
196 return BatchResponseObject::error(
197 obj,
198 StatusCode::INTERNAL_SERVER_ERROR,
199 "Failed to generate upload URL".to_string(),
200 );
201 };
202 return BatchResponseObject {
203 oid: obj.oid,
204 size: obj.size,
205 authenticated: Some(true),
206 actions: BatchResponseObjectActions {
207 download: Some(BatchResponseObjectAction {
208 header: presigned
209 .headers()
210 .map(|(k, v)| (k.to_owned(), v.to_owned()))
211 .collect(),
212 expires_at,
213 href: presigned.uri().to_string(),
214 }),
215 ..Default::default()
216 },
217 error: None,
218 };
219 }
220
221 if let Some(content_length) = result.content_length() {
222 if content_length > 0 {
223 match state
224 .dl_limiter
225 .lock()
226 .await
227 .request(content_length as u64)
228 .await
229 {
230 Ok(true) => {}
231 Ok(false) => {
232 return BatchResponseObject::error(
233 obj,
234 StatusCode::SERVICE_UNAVAILABLE,
235 "Public LFS downloads temporarily unavailable".to_string(),
236 );
237 }
238 Err(e) => {
239 println!("Failed to request {content_length} bytes from download limiter: {e}");
240 return BatchResponseObject::error(
241 obj,
242 StatusCode::INTERNAL_SERVER_ERROR,
243 "Internal server error".to_string(),
244 );
245 }
246 }
247 }
248 }
249
250 let Some(tag) = generate_tag(
251 Claims {
252 specific_claims: SpecificClaims::Download(obj.oid),
253 repo_path: repo,
254 expires_at,
255 },
256 &state.authz_conf.key,
257 ) else {
258 return BatchResponseObject::error(
259 obj,
260 StatusCode::INTERNAL_SERVER_ERROR,
261 "Internal server error".to_string(),
262 );
263 };
264
265 let upload_path = format!(
266 "{repo}/info/lfs/objects/{}/{}/{}",
267 HexByte(obj.oid[0]),
268 HexByte(obj.oid[1]),
269 obj.oid,
270 );
271
272 BatchResponseObject {
273 oid: obj.oid,
274 size: obj.size,
275 authenticated: Some(true),
276 actions: BatchResponseObjectActions {
277 download: Some(BatchResponseObjectAction {
278 header: {
279 let mut map = HashMap::new();
280 map.insert(
281 "Authorization".to_string(),
282 format!("Gitolfs3-Hmac-Sha256 {tag} {}", expires_at.timestamp()),
283 );
284 map
285 },
286 expires_at,
287 href: format!("{}/{upload_path}", state.base_url),
288 }),
289 ..Default::default()
290 },
291 error: None,
292 }
293}
294
295fn repo_exists(name: &str) -> bool {
296 let Ok(metadata) = std::fs::metadata(name) else {
297 return false;
298 };
299 metadata.is_dir()
300}
301
302fn is_repo_public(name: &str) -> Option<bool> {
303 if !repo_exists(name) {
304 return None;
305 }
306 match std::fs::metadata(format!("{name}/git-daemon-export-ok")) {
307 Ok(metadata) if metadata.is_file() => Some(true),
308 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Some(false),
309 _ => None,
310 }
311}
312
313pub async fn batch(
314 State(state): State<Arc<AppState>>,
315 headers: HeaderMap,
316 RepositoryName(repo): RepositoryName,
317 GitLfsJson(Json(payload)): GitLfsJson<BatchRequest>,
318) -> Response {
319 let Some(public) = is_repo_public(&repo) else {
320 return REPO_NOT_FOUND.into_response();
321 };
322 let Trusted(trusted) = match authorize_batch(
323 &state.authz_conf,
324 &repo,
325 public,
326 payload.operation,
327 &headers,
328 ) {
329 Ok(authn) => authn,
330 Err(e) => return e.into_response(),
331 };
332
333 if !headers
334 .get_all("Accept")
335 .iter()
336 .filter_map(|v| v.to_str().ok())
337 .any(is_git_lfs_json_mimetype)
338 {
339 let message = format!("Expected `{LFS_MIME}` in list of acceptable response media types");
340 return make_error_resp(StatusCode::NOT_ACCEPTABLE, &message).into_response();
341 }
342
343 if payload.hash_algo != HashAlgo::Sha256 {
344 let message = "Unsupported hashing algorithm specified";
345 return make_error_resp(StatusCode::CONFLICT, message).into_response();
346 }
347 if !payload.transfers.is_empty() && !payload.transfers.contains(&TransferAdapter::Basic) {
348 let message = "Unsupported transfer adapter specified (supported: basic)";
349 return make_error_resp(StatusCode::CONFLICT, message).into_response();
350 }
351
352 let mut resp = BatchResponse {
353 transfer: TransferAdapter::Basic,
354 objects: vec![],
355 hash_algo: HashAlgo::Sha256,
356 };
357 for obj in payload.objects {
358 match payload.operation {
359 Operation::Download => resp
360 .objects
361 .push(handle_download_object(&state, &repo, &obj, trusted).await),
362 Operation::Upload => {
363 if let Some(obj_resp) = handle_upload_object(&state, &repo, &obj).await {
364 resp.objects.push(obj_resp);
365 }
366 }
367 };
368 }
369 GitLfsJson(Json(resp)).into_response()
370}
371
372#[derive(Deserialize, Copy, Clone)]
373#[serde(remote = "Self")]
374pub struct FileParams {
375 oid0: HexByte,
376 oid1: HexByte,
377 oid: Oid,
378}
379
380impl<'de> Deserialize<'de> for FileParams {
381 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
382 where
383 D: serde::Deserializer<'de>,
384 {
385 let unchecked @ FileParams {
386 oid0: HexByte(oid0),
387 oid1: HexByte(oid1),
388 oid,
389 } = FileParams::deserialize(deserializer)?;
390 if oid0 != oid.as_bytes()[0] {
391 return Err(de::Error::custom(
392 "first OID path part does not match first byte of full OID",
393 ));
394 }
395 if oid1 != oid.as_bytes()[1] {
396 return Err(de::Error::custom(
397 "second OID path part does not match first byte of full OID",
398 ));
399 }
400 Ok(unchecked)
401 }
402}
403
404pub async fn obj_download(
405 State(state): State<Arc<AppState>>,
406 headers: HeaderMap,
407 RepositoryName(repo): RepositoryName,
408 Path(FileParams { oid0, oid1, oid }): Path<FileParams>,
409) -> Response {
410 if let Err(e) = authorize_get(&state.authz_conf, &repo, oid, &headers) {
411 return e.into_response();
412 }
413
414 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, oid);
415 let result = match state
416 .s3_client
417 .get_object()
418 .bucket(&state.s3_bucket)
419 .key(full_path)
420 .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
421 .send()
422 .await
423 {
424 Ok(result) => result,
425 Err(e) => {
426 println!("Failed to GetObject (repo {repo}, OID {oid}): {e}");
427 return (
428 StatusCode::INTERNAL_SERVER_ERROR,
429 "Failed to query object information",
430 )
431 .into_response();
432 }
433 };
434
435 let mut headers = header::HeaderMap::new();
436 if let Some(content_type) = result.content_type {
437 let Ok(header_value) = content_type.try_into() else {
438 return (
439 StatusCode::INTERNAL_SERVER_ERROR,
440 "Object has invalid content type",
441 )
442 .into_response();
443 };
444 headers.insert(header::CONTENT_TYPE, header_value);
445 }
446 if let Some(content_length) = result.content_length {
447 headers.insert(header::CONTENT_LENGTH, content_length.into());
448 }
449
450 let async_read = result.body.into_async_read();
451 let stream = tokio_util::io::ReaderStream::new(async_read);
452 let body = axum::body::Body::from_stream(stream);
453
454 (headers, body).into_response()
455}
diff --git a/gitolfs3-server/src/main.rs b/gitolfs3-server/src/main.rs
index b05a0c8..c9911ed 100644
--- a/gitolfs3-server/src/main.rs
+++ b/gitolfs3-server/src/main.rs
@@ -1,27 +1,21 @@
1use aws_sdk_s3::{error::SdkError, operation::head_object::HeadObjectOutput}; 1mod api;
2mod authz;
3mod config;
4mod dlimit;
5mod handler;
6
7use api::RepositoryName;
8use config::Config;
9use dlimit::DownloadLimiter;
10
2use axum::{ 11use axum::{
3 async_trait, 12 extract::OriginalUri,
4 extract::{rejection, FromRequest, FromRequestParts, OriginalUri, Path, Request, State}, 13 http::{StatusCode, Uri},
5 http::{header, request::Parts, HeaderMap, HeaderValue, StatusCode, Uri},
6 response::{IntoResponse, Response},
7 routing::{get, post}, 14 routing::{get, post},
8 Extension, Json, Router, ServiceExt, 15 Router, ServiceExt,
9};
10use base64::prelude::*;
11use chrono::{DateTime, Utc};
12use gitolfs3_common::{
13 generate_tag, load_key, Claims, Digest, HexByte, Key, Oid, Operation, SpecificClaims,
14};
15use serde::{
16 de::{self, DeserializeOwned},
17 Deserialize, Serialize,
18}; 16};
19use std::{ 17use handler::AppState;
20 collections::{HashMap, HashSet}, 18use std::{process::ExitCode, sync::Arc};
21 process::ExitCode,
22 sync::Arc,
23};
24use tokio::io::AsyncWriteExt;
25use tower::Layer; 19use tower::Layer;
26 20
27#[tokio::main] 21#[tokio::main]
@@ -37,18 +31,6 @@ async fn main() -> ExitCode {
37 }; 31 };
38 32
39 let dl_limiter = DownloadLimiter::new(conf.download_limit).await; 33 let dl_limiter = DownloadLimiter::new(conf.download_limit).await;
40 let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter));
41
42 let resetter_dl_limiter = dl_limiter.clone();
43 tokio::spawn(async move {
44 loop {
45 println!("Resetting download counter in one hour");
46 tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
47 println!("Resetting download counter");
48 resetter_dl_limiter.lock().await.reset().await;
49 }
50 });
51
52 let shared_state = Arc::new(AppState { 34 let shared_state = Arc::new(AppState {
53 s3_client: conf.s3_client, 35 s3_client: conf.s3_client,
54 s3_bucket: conf.s3_bucket, 36 s3_bucket: conf.s3_bucket,
@@ -57,8 +39,8 @@ async fn main() -> ExitCode {
57 dl_limiter, 39 dl_limiter,
58 }); 40 });
59 let app = Router::new() 41 let app = Router::new()
60 .route("/batch", post(batch)) 42 .route("/batch", post(handler::batch))
61 .route("/:oid0/:oid1/:oid", get(obj_download)) 43 .route("/:oid0/:oid1/:oid", get(handler::obj_download))
62 .with_state(shared_state); 44 .with_state(shared_state);
63 45
64 let middleware = axum::middleware::map_request(rewrite_url); 46 let middleware = axum::middleware::map_request(rewrite_url);
@@ -81,30 +63,6 @@ async fn main() -> ExitCode {
81 } 63 }
82} 64}
83 65
84#[derive(Clone)]
85struct RepositoryName(String);
86
87struct RepositoryNameRejection;
88
89impl IntoResponse for RepositoryNameRejection {
90 fn into_response(self) -> Response {
91 (StatusCode::INTERNAL_SERVER_ERROR, "Missing repository name").into_response()
92 }
93}
94
95#[async_trait]
96impl<S: Send + Sync> FromRequestParts<S> for RepositoryName {
97 type Rejection = RepositoryNameRejection;
98
99 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
100 let Ok(Extension(repo_name)) = Extension::<Self>::from_request_parts(parts, state).await
101 else {
102 return Err(RepositoryNameRejection);
103 };
104 Ok(repo_name)
105 }
106}
107
108async fn rewrite_url<B>( 66async fn rewrite_url<B>(
109 mut req: axum::http::Request<B>, 67 mut req: axum::http::Request<B>,
110) -> Result<axum::http::Request<B>, StatusCode> { 68) -> Result<axum::http::Request<B>, StatusCode> {
@@ -141,1014 +99,3 @@ async fn rewrite_url<B>(
141 99
142 Ok(req) 100 Ok(req)
143} 101}
144
145struct AppState {
146 s3_client: aws_sdk_s3::Client,
147 s3_bucket: String,
148 authz_conf: AuthorizationConfig,
149 // Should not end with a slash.
150 base_url: String,
151 dl_limiter: Arc<tokio::sync::Mutex<DownloadLimiter>>,
152}
153
154struct Env {
155 s3_access_key_id: String,
156 s3_secret_access_key: String,
157 s3_bucket: String,
158 s3_region: String,
159 s3_endpoint: String,
160 base_url: String,
161 key_path: String,
162 listen_host: String,
163 listen_port: String,
164 download_limit: String,
165 trusted_forwarded_hosts: String,
166}
167
168fn require_env(name: &str) -> Result<String, String> {
169 std::env::var(name)
170 .map_err(|_| format!("environment variable {name} should be defined and valid"))
171}
172
173impl Env {
174 fn load() -> Result<Env, String> {
175 Ok(Env {
176 s3_secret_access_key: require_env("GITOLFS3_S3_SECRET_ACCESS_KEY_FILE")?,
177 s3_access_key_id: require_env("GITOLFS3_S3_ACCESS_KEY_ID_FILE")?,
178 s3_region: require_env("GITOLFS3_S3_REGION")?,
179 s3_endpoint: require_env("GITOLFS3_S3_ENDPOINT")?,
180 s3_bucket: require_env("GITOLFS3_S3_BUCKET")?,
181 base_url: require_env("GITOLFS3_BASE_URL")?,
182 key_path: require_env("GITOLFS3_KEY_PATH")?,
183 listen_host: require_env("GITOLFS3_LISTEN_HOST")?,
184 listen_port: require_env("GITOLFS3_LISTEN_PORT")?,
185 download_limit: require_env("GITOLFS3_DOWNLOAD_LIMIT")?,
186 trusted_forwarded_hosts: std::env::var("GITOLFS3_TRUSTED_FORWARDED_HOSTS")
187 .unwrap_or_default(),
188 })
189 }
190}
191
192fn get_s3_client(env: &Env) -> Result<aws_sdk_s3::Client, std::io::Error> {
193 let access_key_id = std::fs::read_to_string(&env.s3_access_key_id)?;
194 let secret_access_key = std::fs::read_to_string(&env.s3_secret_access_key)?;
195
196 let credentials = aws_sdk_s3::config::Credentials::new(
197 access_key_id,
198 secret_access_key,
199 None,
200 None,
201 "gitolfs3-env",
202 );
203 let config = aws_config::SdkConfig::builder()
204 .behavior_version(aws_config::BehaviorVersion::latest())
205 .region(aws_config::Region::new(env.s3_region.clone()))
206 .endpoint_url(&env.s3_endpoint)
207 .credentials_provider(aws_sdk_s3::config::SharedCredentialsProvider::new(
208 credentials,
209 ))
210 .build();
211 Ok(aws_sdk_s3::Client::new(&config))
212}
213
214struct Config {
215 listen_addr: (String, u16),
216 base_url: String,
217 authz_conf: AuthorizationConfig,
218 s3_client: aws_sdk_s3::Client,
219 s3_bucket: String,
220 download_limit: u64,
221}
222
223impl Config {
224 fn load() -> Result<Self, String> {
225 let env = match Env::load() {
226 Ok(env) => env,
227 Err(e) => return Err(format!("failed to load configuration: {e}")),
228 };
229
230 let s3_client = match get_s3_client(&env) {
231 Ok(s3_client) => s3_client,
232 Err(e) => return Err(format!("failed to create S3 client: {e}")),
233 };
234 let key = match load_key(&env.key_path) {
235 Ok(key) => key,
236 Err(e) => return Err(format!("failed to load Gitolfs3 key: {e}")),
237 };
238
239 let trusted_forwarded_hosts: HashSet<String> = env
240 .trusted_forwarded_hosts
241 .split(',')
242 .map(|s| s.to_owned())
243 .filter(|s| !s.is_empty())
244 .collect();
245 let base_url = env.base_url.trim_end_matches('/').to_string();
246
247 let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else {
248 return Err("configured GITOLFS3_LISTEN_PORT is invalid".to_string());
249 };
250 let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else {
251 return Err("configured GITOLFS3_DOWNLOAD_LIMIT is invalid".to_string());
252 };
253
254 Ok(Self {
255 listen_addr: (env.listen_host, listen_port),
256 base_url,
257 authz_conf: AuthorizationConfig {
258 key,
259 trusted_forwarded_hosts,
260 },
261 s3_client,
262 s3_bucket: env.s3_bucket,
263 download_limit,
264 })
265 }
266}
267
268#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
269enum TransferAdapter {
270 #[serde(rename = "basic")]
271 Basic,
272 #[serde(other)]
273 Unknown,
274}
275
276#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Copy)]
277enum HashAlgo {
278 #[serde(rename = "sha256")]
279 Sha256,
280 #[serde(other)]
281 Unknown,
282}
283
284impl Default for HashAlgo {
285 fn default() -> Self {
286 Self::Sha256
287 }
288}
289
290#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
291struct BatchRequestObject {
292 oid: Oid,
293 size: i64,
294}
295
296#[derive(Debug, Serialize, Deserialize, Clone)]
297struct BatchRef {
298 name: String,
299}
300
301fn default_transfers() -> Vec<TransferAdapter> {
302 vec![TransferAdapter::Basic]
303}
304
305#[derive(Debug, Deserialize, PartialEq, Eq, Clone)]
306struct BatchRequest {
307 operation: Operation,
308 #[serde(default = "default_transfers")]
309 transfers: Vec<TransferAdapter>,
310 objects: Vec<BatchRequestObject>,
311 #[serde(default)]
312 hash_algo: HashAlgo,
313}
314
315#[derive(Debug, Clone)]
316struct GitLfsJson<T>(Json<T>);
317
318const LFS_MIME: &str = "application/vnd.git-lfs+json";
319
320enum GitLfsJsonRejection {
321 Json(rejection::JsonRejection),
322 MissingGitLfsJsonContentType,
323}
324
325impl IntoResponse for GitLfsJsonRejection {
326 fn into_response(self) -> Response {
327 match self {
328 Self::Json(rej) => rej.into_response(),
329 Self::MissingGitLfsJsonContentType => make_error_resp(
330 StatusCode::UNSUPPORTED_MEDIA_TYPE,
331 &format!("Expected request with `Content-Type: {LFS_MIME}`"),
332 )
333 .into_response(),
334 }
335 }
336}
337
338fn is_git_lfs_json_mimetype(mimetype: &str) -> bool {
339 let Ok(mime) = mimetype.parse::<mime::Mime>() else {
340 return false;
341 };
342 if mime.type_() != mime::APPLICATION
343 || mime.subtype() != "vnd.git-lfs"
344 || mime.suffix() != Some(mime::JSON)
345 {
346 return false;
347 }
348 match mime.get_param(mime::CHARSET) {
349 Some(mime::UTF_8) | None => true,
350 Some(_) => false,
351 }
352}
353
354fn has_git_lfs_json_content_type(req: &Request) -> bool {
355 let Some(content_type) = req.headers().get(header::CONTENT_TYPE) else {
356 return false;
357 };
358 let Ok(content_type) = content_type.to_str() else {
359 return false;
360 };
361 is_git_lfs_json_mimetype(content_type)
362}
363
364#[async_trait]
365impl<T, S> FromRequest<S> for GitLfsJson<T>
366where
367 T: DeserializeOwned,
368 S: Send + Sync,
369{
370 type Rejection = GitLfsJsonRejection;
371
372 async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
373 if !has_git_lfs_json_content_type(&req) {
374 return Err(GitLfsJsonRejection::MissingGitLfsJsonContentType);
375 }
376 Json::<T>::from_request(req, state)
377 .await
378 .map(GitLfsJson)
379 .map_err(GitLfsJsonRejection::Json)
380 }
381}
382
383impl<T: Serialize> IntoResponse for GitLfsJson<T> {
384 fn into_response(self) -> Response {
385 let GitLfsJson(json) = self;
386 let mut resp = json.into_response();
387 resp.headers_mut().insert(
388 header::CONTENT_TYPE,
389 HeaderValue::from_static("application/vnd.git-lfs+json; charset=utf-8"),
390 );
391 resp
392 }
393}
394
395#[derive(Debug, Serialize)]
396struct GitLfsErrorData<'a> {
397 message: &'a str,
398}
399
400type GitLfsErrorResponse<'a> = (StatusCode, GitLfsJson<GitLfsErrorData<'a>>);
401
402const fn make_error_resp(code: StatusCode, message: &str) -> GitLfsErrorResponse {
403 (code, GitLfsJson(Json(GitLfsErrorData { message })))
404}
405
406#[derive(Debug, Serialize, Clone)]
407struct BatchResponseObjectAction {
408 href: String,
409 #[serde(skip_serializing_if = "HashMap::is_empty")]
410 header: HashMap<String, String>,
411 expires_at: DateTime<Utc>,
412}
413
414#[derive(Default, Debug, Serialize, Clone)]
415struct BatchResponseObjectActions {
416 #[serde(skip_serializing_if = "Option::is_none")]
417 upload: Option<BatchResponseObjectAction>,
418 #[serde(skip_serializing_if = "Option::is_none")]
419 download: Option<BatchResponseObjectAction>,
420 #[serde(skip_serializing_if = "Option::is_none")]
421 verify: Option<BatchResponseObjectAction>,
422}
423
424#[derive(Debug, Clone, Serialize)]
425struct BatchResponseObjectError {
426 code: u16,
427 message: String,
428}
429
430#[derive(Debug, Serialize, Clone)]
431struct BatchResponseObject {
432 oid: Oid,
433 size: i64,
434 #[serde(skip_serializing_if = "Option::is_none")]
435 authenticated: Option<bool>,
436 actions: BatchResponseObjectActions,
437 #[serde(skip_serializing_if = "Option::is_none")]
438 error: Option<BatchResponseObjectError>,
439}
440
441impl BatchResponseObject {
442 fn error(obj: &BatchRequestObject, code: StatusCode, message: String) -> BatchResponseObject {
443 BatchResponseObject {
444 oid: obj.oid,
445 size: obj.size,
446 authenticated: None,
447 actions: Default::default(),
448 error: Some(BatchResponseObjectError {
449 code: code.as_u16(),
450 message,
451 }),
452 }
453 }
454}
455
456#[derive(Debug, Serialize, Clone)]
457struct BatchResponse {
458 transfer: TransferAdapter,
459 objects: Vec<BatchResponseObject>,
460 hash_algo: HashAlgo,
461}
462
463fn validate_checksum(oid: Oid, obj: &HeadObjectOutput) -> bool {
464 if let Some(checksum) = obj.checksum_sha256() {
465 if let Ok(checksum) = BASE64_STANDARD.decode(checksum) {
466 if let Ok(checksum32b) = TryInto::<[u8; 32]>::try_into(checksum) {
467 return Oid::from(checksum32b) == oid;
468 }
469 }
470 }
471 true
472}
473
474fn validate_size(expected: i64, obj: &HeadObjectOutput) -> bool {
475 if let Some(length) = obj.content_length() {
476 return length == expected;
477 }
478 true
479}
480
481async fn handle_upload_object(
482 state: &AppState,
483 repo: &str,
484 obj: &BatchRequestObject,
485) -> Option<BatchResponseObject> {
486 let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
487 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
488
489 match state
490 .s3_client
491 .head_object()
492 .bucket(&state.s3_bucket)
493 .key(full_path.clone())
494 .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
495 .send()
496 .await
497 {
498 Ok(result) => {
499 if validate_size(obj.size, &result) && validate_checksum(obj.oid, &result) {
500 return None;
501 }
502 }
503 Err(SdkError::ServiceError(e)) if e.err().is_not_found() => {}
504 Err(e) => {
505 println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid);
506 return Some(BatchResponseObject::error(
507 obj,
508 StatusCode::INTERNAL_SERVER_ERROR,
509 "Failed to query object information".to_string(),
510 ));
511 }
512 };
513
514 let expires_in = std::time::Duration::from_secs(5 * 60);
515 let expires_at = Utc::now() + expires_in;
516
517 let Ok(config) = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in) else {
518 return Some(BatchResponseObject::error(
519 obj,
520 StatusCode::INTERNAL_SERVER_ERROR,
521 "Failed to generate upload URL".to_string(),
522 ));
523 };
524 let Ok(presigned) = state
525 .s3_client
526 .put_object()
527 .bucket(&state.s3_bucket)
528 .key(full_path)
529 .checksum_sha256(obj.oid.to_string())
530 .content_length(obj.size)
531 .presigned(config)
532 .await
533 else {
534 return Some(BatchResponseObject::error(
535 obj,
536 StatusCode::INTERNAL_SERVER_ERROR,
537 "Failed to generate upload URL".to_string(),
538 ));
539 };
540 Some(BatchResponseObject {
541 oid: obj.oid,
542 size: obj.size,
543 authenticated: Some(true),
544 actions: BatchResponseObjectActions {
545 upload: Some(BatchResponseObjectAction {
546 header: presigned
547 .headers()
548 .map(|(k, v)| (k.to_owned(), v.to_owned()))
549 .collect(),
550 expires_at,
551 href: presigned.uri().to_string(),
552 }),
553 ..Default::default()
554 },
555 error: None,
556 })
557}
558
559async fn handle_download_object(
560 state: &AppState,
561 repo: &str,
562 obj: &BatchRequestObject,
563 trusted: bool,
564) -> BatchResponseObject {
565 let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
566 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
567
568 let result = match state
569 .s3_client
570 .head_object()
571 .bucket(&state.s3_bucket)
572 .key(&full_path)
573 .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
574 .send()
575 .await
576 {
577 Ok(result) => result,
578 Err(e) => {
579 println!("Failed to HeadObject (repo {repo}, OID {}): {e}", obj.oid);
580 return BatchResponseObject::error(
581 obj,
582 StatusCode::INTERNAL_SERVER_ERROR,
583 "Failed to query object information".to_string(),
584 );
585 }
586 };
587
588 // Scaleway actually doesn't provide SHA256 suport, but maybe in the future :)
589 if !validate_checksum(obj.oid, &result) {
590 return BatchResponseObject::error(
591 obj,
592 StatusCode::UNPROCESSABLE_ENTITY,
593 "Object corrupted".to_string(),
594 );
595 }
596 if !validate_size(obj.size, &result) {
597 return BatchResponseObject::error(
598 obj,
599 StatusCode::UNPROCESSABLE_ENTITY,
600 "Incorrect size specified (or object corrupted)".to_string(),
601 );
602 }
603
604 let expires_in = std::time::Duration::from_secs(5 * 60);
605 let expires_at = Utc::now() + expires_in;
606
607 if trusted {
608 let Ok(config) = aws_sdk_s3::presigning::PresigningConfig::expires_in(expires_in) else {
609 return BatchResponseObject::error(
610 obj,
611 StatusCode::INTERNAL_SERVER_ERROR,
612 "Failed to generate upload URL".to_string(),
613 );
614 };
615 let Ok(presigned) = state
616 .s3_client
617 .get_object()
618 .bucket(&state.s3_bucket)
619 .key(full_path)
620 .presigned(config)
621 .await
622 else {
623 return BatchResponseObject::error(
624 obj,
625 StatusCode::INTERNAL_SERVER_ERROR,
626 "Failed to generate upload URL".to_string(),
627 );
628 };
629 return BatchResponseObject {
630 oid: obj.oid,
631 size: obj.size,
632 authenticated: Some(true),
633 actions: BatchResponseObjectActions {
634 download: Some(BatchResponseObjectAction {
635 header: presigned
636 .headers()
637 .map(|(k, v)| (k.to_owned(), v.to_owned()))
638 .collect(),
639 expires_at,
640 href: presigned.uri().to_string(),
641 }),
642 ..Default::default()
643 },
644 error: None,
645 };
646 }
647
648 if let Some(content_length) = result.content_length() {
649 if content_length > 0 {
650 match state
651 .dl_limiter
652 .lock()
653 .await
654 .request(content_length as u64)
655 .await
656 {
657 Ok(true) => {}
658 Ok(false) => {
659 return BatchResponseObject::error(
660 obj,
661 StatusCode::SERVICE_UNAVAILABLE,
662 "Public LFS downloads temporarily unavailable".to_string(),
663 );
664 }
665 Err(e) => {
666 println!("Failed to request {content_length} bytes from download limiter: {e}");
667 return BatchResponseObject::error(
668 obj,
669 StatusCode::INTERNAL_SERVER_ERROR,
670 "Internal server error".to_string(),
671 );
672 }
673 }
674 }
675 }
676
677 let Some(tag) = generate_tag(
678 Claims {
679 specific_claims: SpecificClaims::Download(obj.oid),
680 repo_path: repo,
681 expires_at,
682 },
683 &state.authz_conf.key,
684 ) else {
685 return BatchResponseObject::error(
686 obj,
687 StatusCode::INTERNAL_SERVER_ERROR,
688 "Internal server error".to_string(),
689 );
690 };
691
692 let upload_path = format!(
693 "{repo}/info/lfs/objects/{}/{}/{}",
694 HexByte(obj.oid[0]),
695 HexByte(obj.oid[1]),
696 obj.oid,
697 );
698
699 BatchResponseObject {
700 oid: obj.oid,
701 size: obj.size,
702 authenticated: Some(true),
703 actions: BatchResponseObjectActions {
704 download: Some(BatchResponseObjectAction {
705 header: {
706 let mut map = HashMap::new();
707 map.insert(
708 "Authorization".to_string(),
709 format!("Gitolfs3-Hmac-Sha256 {tag} {}", expires_at.timestamp()),
710 );
711 map
712 },
713 expires_at,
714 href: format!("{}/{upload_path}", state.base_url),
715 }),
716 ..Default::default()
717 },
718 error: None,
719 }
720}
721
722struct AuthorizationConfig {
723 trusted_forwarded_hosts: HashSet<String>,
724 key: Key,
725}
726
727struct Trusted(bool);
728
729fn forwarded_from_trusted_host(
730 headers: &HeaderMap,
731 trusted: &HashSet<String>,
732) -> Result<bool, GitLfsErrorResponse<'static>> {
733 if let Some(forwarded_host) = headers.get("X-Forwarded-Host") {
734 if let Ok(forwarded_host) = forwarded_host.to_str() {
735 if trusted.contains(forwarded_host) {
736 return Ok(true);
737 }
738 } else {
739 return Err(make_error_resp(
740 StatusCode::NOT_FOUND,
741 "Invalid X-Forwarded-Host header",
742 ));
743 }
744 }
745 Ok(false)
746}
747
748const REPO_NOT_FOUND: GitLfsErrorResponse =
749 make_error_resp(StatusCode::NOT_FOUND, "Repository not found");
750
751fn authorize_batch(
752 conf: &AuthorizationConfig,
753 repo_path: &str,
754 public: bool,
755 operation: Operation,
756 headers: &HeaderMap,
757) -> Result<Trusted, GitLfsErrorResponse<'static>> {
758 // - No authentication required for downloading exported repos
759 // - When authenticated:
760 // - Download / upload over presigned URLs
761 // - When accessing over Tailscale:
762 // - No authentication required for downloading from any repo
763
764 let claims = VerifyClaimsInput {
765 specific_claims: SpecificClaims::BatchApi(operation),
766 repo_path,
767 };
768 if !verify_claims(conf, &claims, headers)? {
769 return authorize_batch_unauthenticated(conf, public, operation, headers);
770 }
771 Ok(Trusted(true))
772}
773
774fn authorize_batch_unauthenticated(
775 conf: &AuthorizationConfig,
776 public: bool,
777 operation: Operation,
778 headers: &HeaderMap,
779) -> Result<Trusted, GitLfsErrorResponse<'static>> {
780 let trusted = forwarded_from_trusted_host(headers, &conf.trusted_forwarded_hosts)?;
781 match operation {
782 Operation::Upload => {
783 // Trusted users can clone all repositories (by virtue of accessing the server via a
784 // trusted network). However, they can not push without proper authentication. Untrusted
785 // users who are also not authenticated should not need to know which repositories exists.
786 // Therefore, we tell untrusted && unauthenticated users that the repo doesn't exist, but
787 // tell trusted users that they need to authenticate.
788 if !trusted {
789 return Err(REPO_NOT_FOUND);
790 }
791 Err(make_error_resp(
792 StatusCode::FORBIDDEN,
793 "Authentication required to upload",
794 ))
795 }
796 Operation::Download => {
797 // Again, trusted users can see all repos. For untrusted users, we first need to check
798 // whether the repo is public before we authorize. If the user is untrusted and the
799 // repo isn't public, we just act like it doesn't even exist.
800 if !trusted {
801 if !public {
802 return Err(REPO_NOT_FOUND);
803 }
804 return Ok(Trusted(false));
805 }
806 Ok(Trusted(true))
807 }
808 }
809}
810
811fn repo_exists(name: &str) -> bool {
812 let Ok(metadata) = std::fs::metadata(name) else {
813 return false;
814 };
815 metadata.is_dir()
816}
817
818fn is_repo_public(name: &str) -> Option<bool> {
819 if !repo_exists(name) {
820 return None;
821 }
822 match std::fs::metadata(format!("{name}/git-daemon-export-ok")) {
823 Ok(metadata) if metadata.is_file() => Some(true),
824 Err(e) if e.kind() == std::io::ErrorKind::NotFound => Some(false),
825 _ => None,
826 }
827}
828
829async fn batch(
830 State(state): State<Arc<AppState>>,
831 headers: HeaderMap,
832 RepositoryName(repo): RepositoryName,
833 GitLfsJson(Json(payload)): GitLfsJson<BatchRequest>,
834) -> Response {
835 let Some(public) = is_repo_public(&repo) else {
836 return REPO_NOT_FOUND.into_response();
837 };
838 let Trusted(trusted) = match authorize_batch(
839 &state.authz_conf,
840 &repo,
841 public,
842 payload.operation,
843 &headers,
844 ) {
845 Ok(authn) => authn,
846 Err(e) => return e.into_response(),
847 };
848
849 if !headers
850 .get_all("Accept")
851 .iter()
852 .filter_map(|v| v.to_str().ok())
853 .any(is_git_lfs_json_mimetype)
854 {
855 let message = format!("Expected `{LFS_MIME}` in list of acceptable response media types");
856 return make_error_resp(StatusCode::NOT_ACCEPTABLE, &message).into_response();
857 }
858
859 if payload.hash_algo != HashAlgo::Sha256 {
860 let message = "Unsupported hashing algorithm specified";
861 return make_error_resp(StatusCode::CONFLICT, message).into_response();
862 }
863 if !payload.transfers.is_empty() && !payload.transfers.contains(&TransferAdapter::Basic) {
864 let message = "Unsupported transfer adapter specified (supported: basic)";
865 return make_error_resp(StatusCode::CONFLICT, message).into_response();
866 }
867
868 let mut resp = BatchResponse {
869 transfer: TransferAdapter::Basic,
870 objects: vec![],
871 hash_algo: HashAlgo::Sha256,
872 };
873 for obj in payload.objects {
874 match payload.operation {
875 Operation::Download => resp
876 .objects
877 .push(handle_download_object(&state, &repo, &obj, trusted).await),
878 Operation::Upload => {
879 if let Some(obj_resp) = handle_upload_object(&state, &repo, &obj).await {
880 resp.objects.push(obj_resp);
881 }
882 }
883 };
884 }
885 GitLfsJson(Json(resp)).into_response()
886}
887
888#[derive(Deserialize, Copy, Clone)]
889#[serde(remote = "Self")]
890struct FileParams {
891 oid0: HexByte,
892 oid1: HexByte,
893 oid: Oid,
894}
895
896impl<'de> Deserialize<'de> for FileParams {
897 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
898 where
899 D: serde::Deserializer<'de>,
900 {
901 let unchecked @ FileParams {
902 oid0: HexByte(oid0),
903 oid1: HexByte(oid1),
904 oid,
905 } = FileParams::deserialize(deserializer)?;
906 if oid0 != oid.as_bytes()[0] {
907 return Err(de::Error::custom(
908 "first OID path part does not match first byte of full OID",
909 ));
910 }
911 if oid1 != oid.as_bytes()[1] {
912 return Err(de::Error::custom(
913 "second OID path part does not match first byte of full OID",
914 ));
915 }
916 Ok(unchecked)
917 }
918}
919
920pub struct VerifyClaimsInput<'a> {
921 pub specific_claims: SpecificClaims,
922 pub repo_path: &'a str,
923}
924
925fn verify_claims(
926 conf: &AuthorizationConfig,
927 claims: &VerifyClaimsInput,
928 headers: &HeaderMap,
929) -> Result<bool, GitLfsErrorResponse<'static>> {
930 const INVALID_AUTHZ_HEADER: GitLfsErrorResponse =
931 make_error_resp(StatusCode::BAD_REQUEST, "Invalid authorization header");
932
933 let Some(authz) = headers.get(header::AUTHORIZATION) else {
934 return Ok(false);
935 };
936 let authz = authz.to_str().map_err(|_| INVALID_AUTHZ_HEADER)?;
937 let val = authz
938 .strip_prefix("Gitolfs3-Hmac-Sha256 ")
939 .ok_or(INVALID_AUTHZ_HEADER)?;
940 let (tag, expires_at) = val.split_once(' ').ok_or(INVALID_AUTHZ_HEADER)?;
941 let tag: Digest<32> = tag.parse().map_err(|_| INVALID_AUTHZ_HEADER)?;
942 let expires_at: i64 = expires_at.parse().map_err(|_| INVALID_AUTHZ_HEADER)?;
943 let expires_at = DateTime::<Utc>::from_timestamp(expires_at, 0).ok_or(INVALID_AUTHZ_HEADER)?;
944 let expected_tag = generate_tag(
945 Claims {
946 specific_claims: claims.specific_claims,
947 repo_path: claims.repo_path,
948 expires_at,
949 },
950 &conf.key,
951 )
952 .ok_or_else(|| make_error_resp(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"))?;
953 if tag != expected_tag {
954 return Err(INVALID_AUTHZ_HEADER);
955 }
956
957 Ok(true)
958}
959
960fn authorize_get(
961 conf: &AuthorizationConfig,
962 repo_path: &str,
963 oid: Oid,
964 headers: &HeaderMap,
965) -> Result<(), GitLfsErrorResponse<'static>> {
966 let claims = VerifyClaimsInput {
967 specific_claims: SpecificClaims::Download(oid),
968 repo_path,
969 };
970 if !verify_claims(conf, &claims, headers)? {
971 return Err(make_error_resp(
972 StatusCode::UNAUTHORIZED,
973 "Repository not found",
974 ));
975 }
976 Ok(())
977}
978
979async fn obj_download(
980 State(state): State<Arc<AppState>>,
981 headers: HeaderMap,
982 RepositoryName(repo): RepositoryName,
983 Path(FileParams { oid0, oid1, oid }): Path<FileParams>,
984) -> Response {
985 if let Err(e) = authorize_get(&state.authz_conf, &repo, oid, &headers) {
986 return e.into_response();
987 }
988
989 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, oid);
990 let result = match state
991 .s3_client
992 .get_object()
993 .bucket(&state.s3_bucket)
994 .key(full_path)
995 .checksum_mode(aws_sdk_s3::types::ChecksumMode::Enabled)
996 .send()
997 .await
998 {
999 Ok(result) => result,
1000 Err(e) => {
1001 println!("Failed to GetObject (repo {repo}, OID {oid}): {e}");
1002 return (
1003 StatusCode::INTERNAL_SERVER_ERROR,
1004 "Failed to query object information",
1005 )
1006 .into_response();
1007 }
1008 };
1009
1010 let mut headers = header::HeaderMap::new();
1011 if let Some(content_type) = result.content_type {
1012 let Ok(header_value) = content_type.try_into() else {
1013 return (
1014 StatusCode::INTERNAL_SERVER_ERROR,
1015 "Object has invalid content type",
1016 )
1017 .into_response();
1018 };
1019 headers.insert(header::CONTENT_TYPE, header_value);
1020 }
1021 if let Some(content_length) = result.content_length {
1022 headers.insert(header::CONTENT_LENGTH, content_length.into());
1023 }
1024
1025 let async_read = result.body.into_async_read();
1026 let stream = tokio_util::io::ReaderStream::new(async_read);
1027 let body = axum::body::Body::from_stream(stream);
1028
1029 (headers, body).into_response()
1030}
1031
1032struct DownloadLimiter {
1033 current: u64,
1034 limit: u64,
1035}
1036
1037impl DownloadLimiter {
1038 async fn new(limit: u64) -> DownloadLimiter {
1039 let dlimit_str = match tokio::fs::read_to_string(".gitolfs3-dlimit").await {
1040 Ok(dlimit_str) => dlimit_str,
1041 Err(e) => {
1042 println!("Failed to read download counter, assuming 0: {e}");
1043 return DownloadLimiter { current: 0, limit };
1044 }
1045 };
1046 let current: u64 = match dlimit_str
1047 .parse()
1048 .map_err(|e| tokio::io::Error::new(tokio::io::ErrorKind::InvalidData, e))
1049 {
1050 Ok(current) => current,
1051 Err(e) => {
1052 println!("Failed to read download counter, assuming 0: {e}");
1053 return DownloadLimiter { current: 0, limit };
1054 }
1055 };
1056 DownloadLimiter { current, limit }
1057 }
1058
1059 async fn request(&mut self, n: u64) -> tokio::io::Result<bool> {
1060 if self.current + n > self.limit {
1061 return Ok(false);
1062 }
1063 self.current += n;
1064 self.write_new_count().await?;
1065 Ok(true)
1066 }
1067
1068 async fn reset(&mut self) {
1069 self.current = 0;
1070 if let Err(e) = self.write_new_count().await {
1071 println!("Failed to reset download counter: {e}");
1072 }
1073 }
1074
1075 async fn write_new_count(&self) -> tokio::io::Result<()> {
1076 let cwd = tokio::fs::File::open(std::env::current_dir()?).await?;
1077 let mut file = tokio::fs::File::create(".gitolfs3-dlimit.tmp").await?;
1078 file.write_all(self.current.to_string().as_bytes()).await?;
1079 file.sync_all().await?;
1080 tokio::fs::rename(".gitolfs3-dlimit.tmp", ".gitolfs3-dlimit").await?;
1081 cwd.sync_all().await
1082 }
1083}
1084
1085#[test]
1086fn test_mimetype() {
1087 assert!(is_git_lfs_json_mimetype("application/vnd.git-lfs+json"));
1088 assert!(!is_git_lfs_json_mimetype("application/vnd.git-lfs"));
1089 assert!(!is_git_lfs_json_mimetype("application/json"));
1090 assert!(is_git_lfs_json_mimetype(
1091 "application/vnd.git-lfs+json; charset=utf-8"
1092 ));
1093 assert!(is_git_lfs_json_mimetype(
1094 "application/vnd.git-lfs+json; charset=UTF-8"
1095 ));
1096 assert!(!is_git_lfs_json_mimetype(
1097 "application/vnd.git-lfs+json; charset=ISO-8859-1"
1098 ));
1099}
1100
1101#[test]
1102fn test_deserialize() {
1103 let json = r#"{"operation":"upload","objects":[{"oid":"8f4123f9a7181f488c5e111d82cefd992e461ae5df01fd2254399e6e670b2d3c","size":170904}],
1104 "transfers":["lfs-standalone-file","basic","ssh"],"ref":{"name":"refs/heads/main"},"hash_algo":"sha256"}"#;
1105 let expected = BatchRequest {
1106 operation: Operation::Upload,
1107 objects: vec![BatchRequestObject {
1108 oid: "8f4123f9a7181f488c5e111d82cefd992e461ae5df01fd2254399e6e670b2d3c"
1109 .parse()
1110 .unwrap(),
1111 size: 170904,
1112 }],
1113 transfers: vec![
1114 TransferAdapter::Unknown,
1115 TransferAdapter::Basic,
1116 TransferAdapter::Unknown,
1117 ],
1118 hash_algo: HashAlgo::Sha256,
1119 };
1120 assert_eq!(
1121 serde_json::from_str::<BatchRequest>(json).unwrap(),
1122 expected
1123 );
1124}
1125
1126#[test]
1127fn test_validate_claims() {
1128 let key = "00232f7a019bd34e3921ee6c5f04caf48a4489d1be5d1999038950a7054e0bfea369ce2becc0f13fd3c69f8af2384a25b7ac2d52eb52c33722f3c00c50d4c9c2";
1129 let key: Key = key.parse().unwrap();
1130
1131 let claims = Claims {
1132 expires_at: Utc::now() + std::time::Duration::from_secs(5 * 60),
1133 repo_path: "lfs-test.git",
1134 specific_claims: SpecificClaims::BatchApi(Operation::Download),
1135 };
1136 let tag = generate_tag(claims, &key).unwrap();
1137 let header_value = format!(
1138 "Gitolfs3-Hmac-Sha256 {tag} {}",
1139 claims.expires_at.timestamp()
1140 );
1141
1142 let conf = AuthorizationConfig {
1143 key,
1144 trusted_forwarded_hosts: HashSet::new(),
1145 };
1146 let verification_claims = VerifyClaimsInput {
1147 repo_path: claims.repo_path,
1148 specific_claims: claims.specific_claims,
1149 };
1150 let mut headers = HeaderMap::new();
1151 headers.insert(header::AUTHORIZATION, header_value.try_into().unwrap());
1152
1153 assert!(verify_claims(&conf, &verification_claims, &headers).unwrap());
1154}