author | Rutger Broekhoff | 2024-01-26 12:28:39 +0100 |
---|---|---|
committer | Rutger Broekhoff | 2024-01-26 12:28:39 +0100 |
commit | f146743061ba8170569bf18518202df9a43c09f3 (patch) | |
tree | 798c303c07f9146d420de0bb5babb3df0fed58f5 | |
parent | c3d120445877b307f5ea7e95aed4bab45d7dede0 (diff) | |
download | gitolfs3-f146743061ba8170569bf18518202df9a43c09f3.tar.gz gitolfs3-f146743061ba8170569bf18518202df9a43c09f3.zip |
Clean up part of the code
-rw-r--r-- | Cargo.lock                       |   7 |
-rw-r--r-- | common/src/lib.rs                |  11 |
-rw-r--r-- | git-lfs-authenticate/Cargo.toml  |   1 |
-rw-r--r-- | git-lfs-authenticate/src/main.rs | 275 |
-rw-r--r-- | server/src/main.rs               | 257 |
-rw-r--r-- | shell/src/main.rs                | 158 |
6 files changed, 306 insertions, 403 deletions
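
Most of the net change replaces the hand-rolled error enums in git-lfs-authenticate with anyhow-based error handling. As a point of reference, here is a minimal sketch of that pattern using the same environment variable names; the simplified key handling (a plain path instead of `common::load_key`) and the `main` body are illustrative only and assume nothing beyond the anyhow crate.

```rust
// Minimal sketch (not the commit's code) of the anyhow-based style the new
// git-lfs-authenticate adopts: one Result type, with bail!/anyhow! instead of
// a purpose-built error enum per failure mode.
use anyhow::{anyhow, bail, Result};

struct Config {
    href_base: String,
    key_path: String, // the real code loads a common::Key here
}

impl Config {
    fn load() -> Result<Self> {
        // `bail!` returns early with a formatted anyhow::Error.
        let Ok(href_base) = std::env::var("GITOLFS3_HREF_BASE") else {
            bail!("configured base URL not provided");
        };
        if !href_base.ends_with('/') {
            bail!("configured base URL does not end with a slash");
        }
        // `anyhow!` adds context while wrapping the underlying error.
        let key_path = std::env::var("GITOLFS3_KEY_PATH")
            .map_err(|e| anyhow!("key path not provided: {e}"))?;
        Ok(Self { href_base, key_path })
    }
}

fn main() {
    match Config::load() {
        Ok(c) => println!("base URL: {}, key file: {}", c.href_base, c.key_path),
        Err(e) => eprintln!("Error: {e}"),
    }
}
```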
@@ -42,6 +42,12 @@ dependencies = [ | |||
42 | ] | 42 | ] |
43 | 43 | ||
44 | [[package]] | 44 | [[package]] |
45 | name = "anyhow" | ||
46 | version = "1.0.79" | ||
47 | source = "registry+https://github.com/rust-lang/crates.io-index" | ||
48 | checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" | ||
49 | |||
50 | [[package]] | ||
45 | name = "async-trait" | 51 | name = "async-trait" |
46 | version = "0.1.77" | 52 | version = "0.1.77" |
47 | source = "registry+https://github.com/rust-lang/crates.io-index" | 53 | source = "registry+https://github.com/rust-lang/crates.io-index" |
@@ -869,6 +875,7 @@ checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" | |||
869 | name = "git-lfs-authenticate" | 875 | name = "git-lfs-authenticate" |
870 | version = "0.1.0" | 876 | version = "0.1.0" |
871 | dependencies = [ | 877 | dependencies = [ |
878 | "anyhow", | ||
872 | "chrono", | 879 | "chrono", |
873 | "common", | 880 | "common", |
874 | ] | 881 | ] |
diff --git a/common/src/lib.rs b/common/src/lib.rs
index 995352d..0a538a5 100644
--- a/common/src/lib.rs
+++ b/common/src/lib.rs
@@ -1,9 +1,10 @@ | |||
1 | use chrono::{DateTime, Utc}; | 1 | use chrono::{DateTime, Utc}; |
2 | use serde::de; | 2 | use serde::{de, Deserialize, Serialize}; |
3 | use serde::{Deserialize, Serialize}; | 3 | use std::{ |
4 | use std::fmt::Write; | 4 | fmt::{self, Write}, |
5 | use std::ops; | 5 | ops, |
6 | use std::{fmt, str::FromStr}; | 6 | str::FromStr, |
7 | }; | ||
7 | use subtle::ConstantTimeEq; | 8 | use subtle::ConstantTimeEq; |
8 | 9 | ||
9 | #[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)] | 10 | #[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)] |
diff --git a/git-lfs-authenticate/Cargo.toml b/git-lfs-authenticate/Cargo.toml
index 217250f..f4ab4d7 100644
--- a/git-lfs-authenticate/Cargo.toml
+++ b/git-lfs-authenticate/Cargo.toml
@@ -4,5 +4,6 @@ version = "0.1.0" | |||
4 | edition = "2021" | 4 | edition = "2021" |
5 | 5 | ||
6 | [dependencies] | 6 | [dependencies] |
7 | anyhow = "1.0" | ||
7 | chrono = "0.4" | 8 | chrono = "0.4" |
8 | common = { path = "../common" } | 9 | common = { path = "../common" } |
diff --git a/git-lfs-authenticate/src/main.rs b/git-lfs-authenticate/src/main.rs
index 36d7818..accc37f 100644
--- a/git-lfs-authenticate/src/main.rs
+++ b/git-lfs-authenticate/src/main.rs
@@ -1,197 +1,13 @@ | |||
1 | use std::{fmt, process::ExitCode, time::Duration}; | 1 | use anyhow::{anyhow, bail, Result}; |
2 | |||
3 | use chrono::Utc; | 2 | use chrono::Utc; |
4 | use common::{Operation, ParseOperationError}; | 3 | use std::{process::ExitCode, time::Duration}; |
5 | |||
6 | fn help() { | ||
7 | eprintln!("Usage: git-lfs-authenticate <REPO> upload/download"); | ||
8 | } | ||
9 | |||
10 | #[derive(Debug, Eq, PartialEq, Copy, Clone)] | ||
11 | enum RepoNameError { | ||
12 | TooLong, | ||
13 | UnresolvedPath, | ||
14 | AbsolutePath, | ||
15 | MissingGitSuffix, | ||
16 | } | ||
17 | |||
18 | impl fmt::Display for RepoNameError { | ||
19 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
20 | match self { | ||
21 | Self::TooLong => write!(f, "too long (more than 100 characters)"), | ||
22 | Self::UnresolvedPath => { | ||
23 | write!(f, "contains path one or more path elements '.' and '..'") | ||
24 | } | ||
25 | Self::AbsolutePath => { | ||
26 | write!(f, "starts with '/', which is not allowed") | ||
27 | } | ||
28 | Self::MissingGitSuffix => write!(f, "misses '.git' suffix"), | ||
29 | } | ||
30 | } | ||
31 | } | ||
32 | |||
33 | // Using `Result<(), E>` here instead of `Option<E>` because `None` typically signifies some error | ||
34 | // state with no further details provided. If we were to return an `Option` type, the user would | ||
35 | // have to first transform it into a `Result` type in order to use the `?` operator, meaning that | ||
36 | // they would have to the following operation to get the type into the right shape: | ||
37 | // `validate_repo_path(path).map_or(Ok(()), Err)`. That would not be very ergonomic. | ||
38 | fn validate_repo_path(path: &str) -> Result<(), RepoNameError> { | ||
39 | if path.len() > 100 { | ||
40 | return Err(RepoNameError::TooLong); | ||
41 | } | ||
42 | if path.contains("//") | ||
43 | || path.contains("/./") | ||
44 | || path.contains("/../") | ||
45 | || path.starts_with("./") | ||
46 | || path.starts_with("../") | ||
47 | { | ||
48 | return Err(RepoNameError::UnresolvedPath); | ||
49 | } | ||
50 | if path.starts_with('/') { | ||
51 | return Err(RepoNameError::AbsolutePath); | ||
52 | } | ||
53 | if !path.ends_with(".git") { | ||
54 | return Err(RepoNameError::MissingGitSuffix); | ||
55 | } | ||
56 | Ok(()) | ||
57 | } | ||
58 | |||
59 | #[derive(Debug, Eq, PartialEq, Copy, Clone)] | ||
60 | enum ParseCmdlineError { | ||
61 | UnknownOperation(ParseOperationError), | ||
62 | InvalidRepoName(RepoNameError), | ||
63 | UnexpectedArgCount(ArgCountError), | ||
64 | } | ||
65 | |||
66 | impl From<RepoNameError> for ParseCmdlineError { | ||
67 | fn from(value: RepoNameError) -> Self { | ||
68 | Self::InvalidRepoName(value) | ||
69 | } | ||
70 | } | ||
71 | |||
72 | impl From<ParseOperationError> for ParseCmdlineError { | ||
73 | fn from(value: ParseOperationError) -> Self { | ||
74 | Self::UnknownOperation(value) | ||
75 | } | ||
76 | } | ||
77 | |||
78 | impl From<ArgCountError> for ParseCmdlineError { | ||
79 | fn from(value: ArgCountError) -> Self { | ||
80 | Self::UnexpectedArgCount(value) | ||
81 | } | ||
82 | } | ||
83 | |||
84 | impl fmt::Display for ParseCmdlineError { | ||
85 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
86 | match self { | ||
87 | Self::UnknownOperation(e) => write!(f, "unknown operation: {e}"), | ||
88 | Self::InvalidRepoName(e) => write!(f, "invalid repository name: {e}"), | ||
89 | Self::UnexpectedArgCount(e) => e.fmt(f), | ||
90 | } | ||
91 | } | ||
92 | } | ||
93 | |||
94 | #[derive(Debug, Eq, PartialEq, Copy, Clone)] | ||
95 | struct ArgCountError { | ||
96 | provided: usize, | ||
97 | expected: usize, | ||
98 | } | ||
99 | |||
100 | impl fmt::Display for ArgCountError { | ||
101 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
102 | write!( | ||
103 | f, | ||
104 | "got {} argument(s), expected {}", | ||
105 | self.provided, self.expected | ||
106 | ) | ||
107 | } | ||
108 | } | ||
109 | |||
110 | fn get_cmdline_args<const N: usize>() -> Result<[String; N], ArgCountError> { | ||
111 | let args = std::env::args(); | ||
112 | if args.len() - 1 != N { | ||
113 | return Err(ArgCountError { | ||
114 | provided: args.len() - 1, | ||
115 | expected: N, | ||
116 | }); | ||
117 | } | ||
118 | |||
119 | // Does not allocate. | ||
120 | const EMPTY_STRING: String = String::new(); | ||
121 | let mut values = [EMPTY_STRING; N]; | ||
122 | |||
123 | // Skip the first element; we do not care about the program name. | ||
124 | for (i, arg) in args.skip(1).enumerate() { | ||
125 | values[i] = arg | ||
126 | } | ||
127 | Ok(values) | ||
128 | } | ||
129 | |||
130 | fn parse_cmdline() -> Result<(String, Operation), ParseCmdlineError> { | ||
131 | let [repo_path, op_str] = get_cmdline_args::<2>()?; | ||
132 | let op: Operation = op_str.parse()?; | ||
133 | validate_repo_path(&repo_path)?; | ||
134 | Ok((repo_path.to_string(), op)) | ||
135 | } | ||
136 | |||
137 | fn repo_exists(name: &str) -> bool { | ||
138 | match std::fs::metadata(name) { | ||
139 | Ok(metadata) => metadata.is_dir(), | ||
140 | _ => false, | ||
141 | } | ||
142 | } | ||
143 | |||
144 | struct Config { | ||
145 | href_base: String, | ||
146 | key_path: String, | ||
147 | } | ||
148 | |||
149 | #[derive(Debug, Eq, PartialEq, Copy, Clone)] | ||
150 | enum LoadConfigError { | ||
151 | BaseUrlNotProvided, | ||
152 | BaseUrlSlashSuffixMissing, | ||
153 | KeyPathNotProvided, | ||
154 | } | ||
155 | |||
156 | impl fmt::Display for LoadConfigError { | ||
157 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
158 | match self { | ||
159 | Self::BaseUrlNotProvided => write!(f, "base URL not provided"), | ||
160 | Self::BaseUrlSlashSuffixMissing => write!(f, "base URL does not end with slash"), | ||
161 | Self::KeyPathNotProvided => write!(f, "key path not provided"), | ||
162 | } | ||
163 | } | ||
164 | } | ||
165 | |||
166 | fn load_config() -> Result<Config, LoadConfigError> { | ||
167 | let Ok(href_base) = std::env::var("GITOLFS3_HREF_BASE") else { | ||
168 | return Err(LoadConfigError::BaseUrlNotProvided); | ||
169 | }; | ||
170 | if !href_base.ends_with('/') { | ||
171 | return Err(LoadConfigError::BaseUrlSlashSuffixMissing); | ||
172 | } | ||
173 | let Ok(key_path) = std::env::var("GITOLFS3_KEY_PATH") else { | ||
174 | return Err(LoadConfigError::KeyPathNotProvided); | ||
175 | }; | ||
176 | Ok(Config { | ||
177 | href_base, | ||
178 | key_path, | ||
179 | }) | ||
180 | } | ||
181 | 4 | ||
182 | fn main() -> ExitCode { | 5 | fn main() -> ExitCode { |
183 | let config = match load_config() { | 6 | let config = match Config::load() { |
184 | Ok(config) => config, | 7 | Ok(config) => config, |
185 | Err(e) => { | 8 | Err(e) => { |
186 | eprintln!("Failed to load config: {e}"); | 9 | eprintln!("Error: {e}"); |
187 | return ExitCode::FAILURE; | 10 | return ExitCode::from(2); |
188 | } | ||
189 | }; | ||
190 | let key = match common::load_key(&config.key_path) { | ||
191 | Ok(key) => key, | ||
192 | Err(e) => { | ||
193 | eprintln!("Failed to load key: {e}"); | ||
194 | return ExitCode::FAILURE; | ||
195 | } | 11 | } |
196 | }; | 12 | }; |
197 | 13 | ||
@@ -199,7 +15,7 @@ fn main() -> ExitCode { | |||
199 | Ok(args) => args, | 15 | Ok(args) => args, |
200 | Err(e) => { | 16 | Err(e) => { |
201 | eprintln!("Error: {e}\n"); | 17 | eprintln!("Error: {e}\n"); |
202 | help(); | 18 | eprintln!("Usage: git-lfs-authenticate <REPO> upload/download"); |
203 | // Exit code 2 signifies bad usage of CLI. | 19 | // Exit code 2 signifies bad usage of CLI. |
204 | return ExitCode::from(2); | 20 | return ExitCode::from(2); |
205 | } | 21 | } |
@@ -217,7 +33,7 @@ fn main() -> ExitCode { | |||
217 | repo_path: &repo_name, | 33 | repo_path: &repo_name, |
218 | expires_at, | 34 | expires_at, |
219 | }, | 35 | }, |
220 | key, | 36 | config.key, |
221 | ) else { | 37 | ) else { |
222 | eprintln!("Failed to generate validation tag"); | 38 | eprintln!("Failed to generate validation tag"); |
223 | return ExitCode::FAILURE; | 39 | return ExitCode::FAILURE; |
@@ -234,3 +50,80 @@ fn main() -> ExitCode { | |||
234 | 50 | ||
235 | ExitCode::SUCCESS | 51 | ExitCode::SUCCESS |
236 | } | 52 | } |
53 | |||
54 | struct Config { | ||
55 | href_base: String, | ||
56 | key: common::Key, | ||
57 | } | ||
58 | |||
59 | impl Config { | ||
60 | fn load() -> Result<Self> { | ||
61 | let Ok(href_base) = std::env::var("GITOLFS3_HREF_BASE") else { | ||
62 | bail!("configured base URL not provided"); | ||
63 | }; | ||
64 | if !href_base.ends_with('/') { | ||
65 | bail!("configured base URL does not end with a slash"); | ||
66 | } | ||
67 | |||
68 | let Ok(key_path) = std::env::var("GITOLFS3_KEY_PATH") else { | ||
69 | bail!("key path not provided"); | ||
70 | }; | ||
71 | let key = common::load_key(&key_path).map_err(|e| anyhow!("failed to load key: {e}"))?; | ||
72 | |||
73 | Ok(Self { href_base, key }) | ||
74 | } | ||
75 | } | ||
76 | |||
77 | fn parse_cmdline() -> Result<(String, common::Operation)> { | ||
78 | let [repo_path, op_str] = get_cmdline_args::<2>()?; | ||
79 | let op: common::Operation = op_str | ||
80 | .parse() | ||
81 | .map_err(|e| anyhow!("unknown operation: {e}"))?; | ||
82 | validate_repo_path(&repo_path).map_err(|e| anyhow!("invalid repository name: {e}"))?; | ||
83 | Ok((repo_path.to_string(), op)) | ||
84 | } | ||
85 | |||
86 | fn get_cmdline_args<const N: usize>() -> Result<[String; N]> { | ||
87 | let args = std::env::args(); | ||
88 | if args.len() - 1 != N { | ||
89 | bail!("got {} argument(s), expected {}", args.len() - 1, N); | ||
90 | } | ||
91 | |||
92 | // Does not allocate. | ||
93 | const EMPTY_STRING: String = String::new(); | ||
94 | let mut values = [EMPTY_STRING; N]; | ||
95 | |||
96 | // Skip the first element; we do not care about the program name. | ||
97 | for (i, arg) in args.skip(1).enumerate() { | ||
98 | values[i] = arg | ||
99 | } | ||
100 | Ok(values) | ||
101 | } | ||
102 | |||
103 | fn validate_repo_path(path: &str) -> Result<()> { | ||
104 | if path.len() > 100 { | ||
105 | bail!("too long (more than 100 characters)"); | ||
106 | } | ||
107 | if path.contains("//") | ||
108 | || path.contains("/./") | ||
109 | || path.contains("/../") | ||
110 | || path.starts_with("./") | ||
111 | || path.starts_with("../") | ||
112 | { | ||
113 | bail!("contains one or more path elements '.' and '..'"); | ||
114 | } | ||
115 | if path.starts_with('/') { | ||
116 | bail!("starts with '/', which is not allowed"); | ||
117 | } | ||
118 | if !path.ends_with(".git") { | ||
119 | bail!("missed '.git' suffix"); | ||
120 | } | ||
121 | Ok(()) | ||
122 | } | ||
123 | |||
124 | fn repo_exists(name: &str) -> bool { | ||
125 | match std::fs::metadata(name) { | ||
126 | Ok(metadata) => metadata.is_dir(), | ||
127 | _ => false, | ||
128 | } | ||
129 | } | ||
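
For reference, the repository-name rules enforced by `validate_repo_path` in the rewritten main.rs above can be exercised with a small standalone program. The function below restates the same checks with shortened error strings; the example paths are hypothetical.

```rust
// Illustration only: the checks mirror validate_repo_path from the diff above,
// but return &'static str errors so the example stays dependency-free.
fn check_repo_path(path: &str) -> Result<(), &'static str> {
    if path.len() > 100 {
        return Err("too long");
    }
    if path.contains("//")
        || path.contains("/./")
        || path.contains("/../")
        || path.starts_with("./")
        || path.starts_with("../")
    {
        return Err("unresolved '.'/'..' path elements");
    }
    if path.starts_with('/') {
        return Err("absolute path");
    }
    if !path.ends_with(".git") {
        return Err("missing '.git' suffix");
    }
    Ok(())
}

fn main() {
    assert!(check_repo_path("foo/bar.git").is_ok());
    assert!(check_repo_path("/srv/git/foo.git").is_err()); // absolute path rejected
    assert!(check_repo_path("../escape.git").is_err());    // path traversal rejected
    assert!(check_repo_path("foo/bar").is_err());          // must end in ".git"
    println!("all examples behave as expected");
}
```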
diff --git a/server/src/main.rs b/server/src/main.rs
index db37d14..4a88dcd 100644
--- a/server/src/main.rs
+++ b/server/src/main.rs
@@ -1,39 +1,82 @@ | |||
1 | use std::collections::HashMap; | 1 | use aws_sdk_s3::{error::SdkError, operation::head_object::HeadObjectOutput}; |
2 | use std::collections::HashSet; | ||
3 | use std::process::ExitCode; | ||
4 | use std::sync::Arc; | ||
5 | |||
6 | use aws_sdk_s3::error::SdkError; | ||
7 | use aws_sdk_s3::operation::head_object::HeadObjectOutput; | ||
8 | use axum::extract::rejection; | ||
9 | use axum::extract::FromRequest; | ||
10 | use axum::extract::Path; | ||
11 | use axum::extract::State; | ||
12 | use axum::http::header; | ||
13 | use axum::http::HeaderMap; | ||
14 | use axum::http::HeaderValue; | ||
15 | use axum::response::Response; | ||
16 | use axum::Json; | ||
17 | use axum::ServiceExt; | ||
18 | use base64::prelude::*; | ||
19 | use chrono::DateTime; | ||
20 | use chrono::Utc; | ||
21 | use common::HexByte; | ||
22 | use serde::de; | ||
23 | use serde::de::DeserializeOwned; | ||
24 | use serde::Deserialize; | ||
25 | use serde::Serialize; | ||
26 | use tokio::io::AsyncWriteExt; | ||
27 | use tower::Layer; | ||
28 | |||
29 | use axum::{ | 2 | use axum::{ |
30 | async_trait, | 3 | async_trait, |
31 | extract::{FromRequestParts, OriginalUri, Request}, | 4 | extract::{rejection, FromRequest, FromRequestParts, OriginalUri, Path, Request, State}, |
32 | http::{request::Parts, StatusCode, Uri}, | 5 | http::{header, request::Parts, HeaderMap, HeaderValue, StatusCode, Uri}, |
33 | response::IntoResponse, | 6 | response::{IntoResponse, Response}, |
34 | routing::{get, post}, | 7 | routing::{get, post}, |
35 | Extension, Router, | 8 | Extension, Json, Router, ServiceExt, |
36 | }; | 9 | }; |
10 | use base64::prelude::*; | ||
11 | use chrono::{DateTime, Utc}; | ||
12 | use serde::{ | ||
13 | de::{self, DeserializeOwned}, | ||
14 | Deserialize, Serialize, | ||
15 | }; | ||
16 | use std::{ | ||
17 | collections::{HashMap, HashSet}, | ||
18 | process::ExitCode, | ||
19 | sync::Arc, | ||
20 | }; | ||
21 | use tokio::io::AsyncWriteExt; | ||
22 | use tower::Layer; | ||
23 | |||
24 | #[tokio::main] | ||
25 | async fn main() -> ExitCode { | ||
26 | tracing_subscriber::fmt::init(); | ||
27 | |||
28 | let conf = match Config::load() { | ||
29 | Ok(conf) => conf, | ||
30 | Err(e) => { | ||
31 | println!("Error: {e}"); | ||
32 | return ExitCode::from(2); | ||
33 | } | ||
34 | }; | ||
35 | |||
36 | let dl_limiter = DownloadLimiter::new(conf.download_limit).await; | ||
37 | let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter)); | ||
38 | |||
39 | let resetter_dl_limiter = dl_limiter.clone(); | ||
40 | tokio::spawn(async move { | ||
41 | loop { | ||
42 | println!("Resetting download counter in one hour"); | ||
43 | tokio::time::sleep(std::time::Duration::from_secs(3600)).await; | ||
44 | println!("Resetting download counter"); | ||
45 | resetter_dl_limiter.lock().await.reset().await; | ||
46 | } | ||
47 | }); | ||
48 | |||
49 | let shared_state = Arc::new(AppState { | ||
50 | s3_client: conf.s3_client, | ||
51 | s3_bucket: conf.s3_bucket, | ||
52 | authz_conf: conf.authz_conf, | ||
53 | base_url: conf.base_url, | ||
54 | dl_limiter, | ||
55 | }); | ||
56 | let app = Router::new() | ||
57 | .route("/batch", post(batch)) | ||
58 | .route("/:oid0/:oid1/:oid", get(obj_download)) | ||
59 | .with_state(shared_state); | ||
60 | |||
61 | let middleware = axum::middleware::map_request(rewrite_url); | ||
62 | let app_with_middleware = middleware.layer(app); | ||
63 | |||
64 | let listener = match tokio::net::TcpListener::bind(conf.listen_addr).await { | ||
65 | Ok(listener) => listener, | ||
66 | Err(e) => { | ||
67 | println!("Failed to listen: {e}"); | ||
68 | return ExitCode::FAILURE; | ||
69 | } | ||
70 | }; | ||
71 | |||
72 | match axum::serve(listener, app_with_middleware.into_make_service()).await { | ||
73 | Ok(_) => ExitCode::SUCCESS, | ||
74 | Err(e) => { | ||
75 | println!("Error serving: {e}"); | ||
76 | ExitCode::FAILURE | ||
77 | } | ||
78 | } | ||
79 | } | ||
37 | 80 | ||
38 | #[derive(Clone)] | 81 | #[derive(Clone)] |
39 | struct RepositoryName(String); | 82 | struct RepositoryName(String); |
@@ -165,99 +208,57 @@ fn get_s3_client(env: &Env) -> Result<aws_sdk_s3::Client, std::io::Error> { | |||
165 | Ok(aws_sdk_s3::Client::new(&config)) | 208 | Ok(aws_sdk_s3::Client::new(&config)) |
166 | } | 209 | } |
167 | 210 | ||
168 | #[tokio::main] | 211 | struct Config { |
169 | async fn main() -> ExitCode { | 212 | listen_addr: (String, u16), |
170 | tracing_subscriber::fmt::init(); | 213 | base_url: String, |
171 | 214 | authz_conf: AuthorizationConfig, | |
172 | let env = match Env::load() { | 215 | s3_client: aws_sdk_s3::Client, |
173 | Ok(env) => env, | 216 | s3_bucket: String, |
174 | Err(e) => { | 217 | download_limit: u64, |
175 | println!("Failed to load configuration: {e}"); | 218 | } |
176 | return ExitCode::from(2); | ||
177 | } | ||
178 | }; | ||
179 | |||
180 | let s3_client = match get_s3_client(&env) { | ||
181 | Ok(s3_client) => s3_client, | ||
182 | Err(e) => { | ||
183 | println!("Failed to create S3 client: {e}"); | ||
184 | return ExitCode::FAILURE; | ||
185 | } | ||
186 | }; | ||
187 | let key = match common::load_key(&env.key_path) { | ||
188 | Ok(key) => key, | ||
189 | Err(e) => { | ||
190 | println!("Failed to load Gitolfs3 key: {e}"); | ||
191 | return ExitCode::FAILURE; | ||
192 | } | ||
193 | }; | ||
194 | |||
195 | let trusted_forwarded_hosts: HashSet<String> = env | ||
196 | .trusted_forwarded_hosts | ||
197 | .split(',') | ||
198 | .map(|s| s.to_owned()) | ||
199 | .filter(|s| !s.is_empty()) | ||
200 | .collect(); | ||
201 | let base_url = env.base_url.trim_end_matches('/').to_string(); | ||
202 | |||
203 | let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else { | ||
204 | println!("Configured GITOLFS3_DOWNLOAD_LIMIT should be a 64-bit unsigned integer"); | ||
205 | return ExitCode::from(2); | ||
206 | }; | ||
207 | let dl_limiter = DownloadLimiter::new(download_limit).await; | ||
208 | let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter)); | ||
209 | |||
210 | let resetter_dl_limiter = dl_limiter.clone(); | ||
211 | tokio::spawn(async move { | ||
212 | loop { | ||
213 | println!("Resetting download counter in one hour"); | ||
214 | tokio::time::sleep(std::time::Duration::from_secs(3600)).await; | ||
215 | println!("Resetting download counter"); | ||
216 | resetter_dl_limiter.lock().await.reset().await; | ||
217 | } | ||
218 | }); | ||
219 | 219 | ||
220 | let authz_conf = AuthorizationConfig { | 220 | impl Config { |
221 | key, | 221 | fn load() -> Result<Self, String> { |
222 | trusted_forwarded_hosts, | 222 | let env = match Env::load() { |
223 | }; | 223 | Ok(env) => env, |
224 | Err(e) => return Err(format!("failed to load configuration: {e}")), | ||
225 | }; | ||
224 | 226 | ||
225 | let shared_state = Arc::new(AppState { | 227 | let s3_client = match get_s3_client(&env) { |
226 | s3_client, | 228 | Ok(s3_client) => s3_client, |
227 | s3_bucket: env.s3_bucket, | 229 | Err(e) => return Err(format!("failed to create S3 client: {e}")), |
228 | authz_conf, | 230 | }; |
229 | base_url, | 231 | let key = match common::load_key(&env.key_path) { |
230 | dl_limiter, | 232 | Ok(key) => key, |
231 | }); | 233 | Err(e) => return Err(format!("failed to load Gitolfs3 key: {e}")), |
232 | let app = Router::new() | 234 | }; |
233 | .route("/batch", post(batch)) | ||
234 | .route("/:oid0/:oid1/:oid", get(obj_download)) | ||
235 | .with_state(shared_state); | ||
236 | 235 | ||
237 | let middleware = axum::middleware::map_request(rewrite_url); | 236 | let trusted_forwarded_hosts: HashSet<String> = env |
238 | let app_with_middleware = middleware.layer(app); | 237 | .trusted_forwarded_hosts |
238 | .split(',') | ||
239 | .map(|s| s.to_owned()) | ||
240 | .filter(|s| !s.is_empty()) | ||
241 | .collect(); | ||
242 | let base_url = env.base_url.trim_end_matches('/').to_string(); | ||
239 | 243 | ||
240 | let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else { | 244 | let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else { |
241 | println!( | 245 | return Err("configured GITOLFS3_LISTEN_PORT is invalid".to_string()); |
242 | "Configured GITOLFS3_LISTEN_PORT should be an unsigned integer no higher than 65535" | 246 | }; |
243 | ); | 247 | let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else { |
244 | return ExitCode::from(2); | 248 | return Err("configured GITOLFS3_DOWNLOAD_LIMIT is invalid".to_string()); |
245 | }; | 249 | }; |
246 | let addr: (String, u16) = (env.listen_host, listen_port); | ||
247 | let listener = match tokio::net::TcpListener::bind(addr).await { | ||
248 | Ok(listener) => listener, | ||
249 | Err(e) => { | ||
250 | println!("Failed to listen: {e}"); | ||
251 | return ExitCode::FAILURE; | ||
252 | } | ||
253 | }; | ||
254 | 250 | ||
255 | match axum::serve(listener, app_with_middleware.into_make_service()).await { | 251 | Ok(Self { |
256 | Ok(_) => ExitCode::SUCCESS, | 252 | listen_addr: (env.listen_host, listen_port), |
257 | Err(e) => { | 253 | base_url, |
258 | println!("Error serving: {e}"); | 254 | authz_conf: AuthorizationConfig { |
259 | ExitCode::FAILURE | 255 | key, |
260 | } | 256 | trusted_forwarded_hosts, |
257 | }, | ||
258 | s3_client, | ||
259 | s3_bucket: env.s3_bucket, | ||
260 | download_limit, | ||
261 | }) | ||
261 | } | 262 | } |
262 | } | 263 | } |
263 | 264 | ||
@@ -479,7 +480,7 @@ async fn handle_upload_object( | |||
479 | repo: &str, | 480 | repo: &str, |
480 | obj: &BatchRequestObject, | 481 | obj: &BatchRequestObject, |
481 | ) -> Option<BatchResponseObject> { | 482 | ) -> Option<BatchResponseObject> { |
482 | let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); | 483 | let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1])); |
483 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); | 484 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); |
484 | 485 | ||
485 | match state | 486 | match state |
@@ -558,7 +559,7 @@ async fn handle_download_object( | |||
558 | obj: &BatchRequestObject, | 559 | obj: &BatchRequestObject, |
559 | trusted: bool, | 560 | trusted: bool, |
560 | ) -> BatchResponseObject { | 561 | ) -> BatchResponseObject { |
561 | let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); | 562 | let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1])); |
562 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); | 563 | let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); |
563 | 564 | ||
564 | let result = match state | 565 | let result = match state |
@@ -687,8 +688,8 @@ async fn handle_download_object( | |||
687 | 688 | ||
688 | let upload_path = format!( | 689 | let upload_path = format!( |
689 | "{repo}/info/lfs/objects/{}/{}/{}", | 690 | "{repo}/info/lfs/objects/{}/{}/{}", |
690 | HexByte(obj.oid[0]), | 691 | common::HexByte(obj.oid[0]), |
691 | HexByte(obj.oid[1]), | 692 | common::HexByte(obj.oid[1]), |
692 | obj.oid, | 693 | obj.oid, |
693 | ); | 694 | ); |
694 | 695 | ||
@@ -866,8 +867,8 @@ async fn batch( | |||
866 | #[derive(Deserialize, Copy, Clone)] | 867 | #[derive(Deserialize, Copy, Clone)] |
867 | #[serde(remote = "Self")] | 868 | #[serde(remote = "Self")] |
868 | struct FileParams { | 869 | struct FileParams { |
869 | oid0: HexByte, | 870 | oid0: common::HexByte, |
870 | oid1: HexByte, | 871 | oid1: common::HexByte, |
871 | oid: common::Oid, | 872 | oid: common::Oid, |
872 | } | 873 | } |
873 | 874 | ||
@@ -877,8 +878,8 @@ impl<'de> Deserialize<'de> for FileParams { | |||
877 | D: serde::Deserializer<'de>, | 878 | D: serde::Deserializer<'de>, |
878 | { | 879 | { |
879 | let unchecked @ FileParams { | 880 | let unchecked @ FileParams { |
880 | oid0: HexByte(oid0), | 881 | oid0: common::HexByte(oid0), |
881 | oid1: HexByte(oid1), | 882 | oid1: common::HexByte(oid1), |
882 | oid, | 883 | oid, |
883 | } = FileParams::deserialize(deserializer)?; | 884 | } = FileParams::deserialize(deserializer)?; |
884 | if oid0 != oid.as_bytes()[0] { | 885 | if oid0 != oid.as_bytes()[0] { |
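
The object paths built in handle_upload_object and handle_download_object follow the layout `{repo}/lfs/objects/<first byte>/<second byte>/<full oid>`. Below is a standalone sketch of that mapping; it assumes `common::HexByte` simply hex-encodes one byte (inlined here as `{:02x}`), and the repository name and all-zero OID are made up.

```rust
// Illustration only: how an LFS object ID maps onto a storage key, following
// the format! calls in handle_upload_object / handle_download_object above.
fn object_key(repo: &str, oid: &[u8; 32]) -> String {
    format!(
        "{repo}/lfs/objects/{:02x}/{:02x}/{}",
        oid[0],
        oid[1],
        oid.iter().map(|b| format!("{b:02x}")).collect::<String>(),
    )
}

fn main() {
    let oid = [0u8; 32]; // SHA-256 object ID
    // Prints "foo.git/lfs/objects/00/00/0000...0000" (64 hex digits).
    println!("{}", object_key("foo.git", &oid));
}
```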
diff --git a/shell/src/main.rs b/shell/src/main.rs
index 4901e7f..4a98828 100644
--- a/shell/src/main.rs
+++ b/shell/src/main.rs
@@ -1,84 +1,5 @@ | |||
1 | use std::{os::unix::process::CommandExt, process::ExitCode}; | 1 | use std::{os::unix::process::CommandExt, process::ExitCode}; |
2 | 2 | ||
3 | fn parse_sq(s: &str) -> Option<(String, &str)> { | ||
4 | #[derive(PartialEq, Eq)] | ||
5 | enum SqState { | ||
6 | Quoted, | ||
7 | Unquoted { may_escape: bool }, | ||
8 | UnquotedEscaped, | ||
9 | } | ||
10 | |||
11 | let mut result = String::new(); | ||
12 | let mut state = SqState::Unquoted { may_escape: false }; | ||
13 | let mut remaining = ""; | ||
14 | for (i, c) in s.char_indices() { | ||
15 | match state { | ||
16 | SqState::Unquoted { may_escape: false } => { | ||
17 | if c != '\'' { | ||
18 | return None; | ||
19 | } | ||
20 | state = SqState::Quoted | ||
21 | } | ||
22 | SqState::Quoted => { | ||
23 | if c == '\'' { | ||
24 | state = SqState::Unquoted { may_escape: true }; | ||
25 | continue; | ||
26 | } | ||
27 | result.push(c); | ||
28 | } | ||
29 | SqState::Unquoted { may_escape: true } => { | ||
30 | if is_posix_space(c) { | ||
31 | remaining = &s[i..]; | ||
32 | break; | ||
33 | } | ||
34 | if c != '\\' { | ||
35 | return None; | ||
36 | } | ||
37 | state = SqState::UnquotedEscaped; | ||
38 | } | ||
39 | SqState::UnquotedEscaped => { | ||
40 | if c != '\\' && c != '!' { | ||
41 | return None; | ||
42 | } | ||
43 | result.push(c); | ||
44 | state = SqState::Unquoted { may_escape: false }; | ||
45 | } | ||
46 | } | ||
47 | } | ||
48 | |||
49 | if state != (SqState::Unquoted { may_escape: true }) { | ||
50 | return None; | ||
51 | } | ||
52 | Some((result, remaining)) | ||
53 | } | ||
54 | |||
55 | fn parse_cmd(mut cmd: &str) -> Option<Vec<String>> { | ||
56 | let mut args = Vec::<String>::new(); | ||
57 | |||
58 | cmd = cmd.trim_matches(is_posix_space); | ||
59 | while !cmd.is_empty() { | ||
60 | if cmd.starts_with('\'') { | ||
61 | let (arg, remaining) = parse_sq(cmd)?; | ||
62 | args.push(arg); | ||
63 | cmd = remaining.trim_start_matches(is_posix_space); | ||
64 | } else if let Some((arg, remaining)) = cmd.split_once(is_posix_space) { | ||
65 | args.push(arg.to_owned()); | ||
66 | cmd = remaining.trim_start_matches(is_posix_space); | ||
67 | } else { | ||
68 | args.push(cmd.to_owned()); | ||
69 | cmd = ""; | ||
70 | } | ||
71 | } | ||
72 | |||
73 | Some(args) | ||
74 | } | ||
75 | |||
76 | fn is_posix_space(c: char) -> bool { | ||
77 | // Form feed: 0x0c | ||
78 | // Vertical tab: 0x0b | ||
79 | c == ' ' || c == '\x0c' || c == '\n' || c == '\r' || c == '\t' || c == '\x0b' | ||
80 | } | ||
81 | |||
82 | fn main() -> ExitCode { | 3 | fn main() -> ExitCode { |
83 | let bad_usage = ExitCode::from(2); | 4 | let bad_usage = ExitCode::from(2); |
84 | 5 | ||
@@ -141,3 +62,82 @@ fn main() -> ExitCode { | |||
141 | eprintln!("Error: {e}"); | 62 | eprintln!("Error: {e}"); |
142 | ExitCode::FAILURE | 63 | ExitCode::FAILURE |
143 | } | 64 | } |
65 | |||
66 | fn parse_cmd(mut cmd: &str) -> Option<Vec<String>> { | ||
67 | let mut args = Vec::<String>::new(); | ||
68 | |||
69 | cmd = cmd.trim_matches(is_posix_space); | ||
70 | while !cmd.is_empty() { | ||
71 | if cmd.starts_with('\'') { | ||
72 | let (arg, remaining) = parse_sq(cmd)?; | ||
73 | args.push(arg); | ||
74 | cmd = remaining.trim_start_matches(is_posix_space); | ||
75 | } else if let Some((arg, remaining)) = cmd.split_once(is_posix_space) { | ||
76 | args.push(arg.to_owned()); | ||
77 | cmd = remaining.trim_start_matches(is_posix_space); | ||
78 | } else { | ||
79 | args.push(cmd.to_owned()); | ||
80 | cmd = ""; | ||
81 | } | ||
82 | } | ||
83 | |||
84 | Some(args) | ||
85 | } | ||
86 | |||
87 | fn is_posix_space(c: char) -> bool { | ||
88 | // Form feed: 0x0c | ||
89 | // Vertical tab: 0x0b | ||
90 | c == ' ' || c == '\x0c' || c == '\n' || c == '\r' || c == '\t' || c == '\x0b' | ||
91 | } | ||
92 | |||
93 | fn parse_sq(s: &str) -> Option<(String, &str)> { | ||
94 | #[derive(PartialEq, Eq)] | ||
95 | enum SqState { | ||
96 | Quoted, | ||
97 | Unquoted { may_escape: bool }, | ||
98 | UnquotedEscaped, | ||
99 | } | ||
100 | |||
101 | let mut result = String::new(); | ||
102 | let mut state = SqState::Unquoted { may_escape: false }; | ||
103 | let mut remaining = ""; | ||
104 | for (i, c) in s.char_indices() { | ||
105 | match state { | ||
106 | SqState::Unquoted { may_escape: false } => { | ||
107 | if c != '\'' { | ||
108 | return None; | ||
109 | } | ||
110 | state = SqState::Quoted | ||
111 | } | ||
112 | SqState::Quoted => { | ||
113 | if c == '\'' { | ||
114 | state = SqState::Unquoted { may_escape: true }; | ||
115 | continue; | ||
116 | } | ||
117 | result.push(c); | ||
118 | } | ||
119 | SqState::Unquoted { may_escape: true } => { | ||
120 | if is_posix_space(c) { | ||
121 | remaining = &s[i..]; | ||
122 | break; | ||
123 | } | ||
124 | if c != '\\' { | ||
125 | return None; | ||
126 | } | ||
127 | state = SqState::UnquotedEscaped; | ||
128 | } | ||
129 | SqState::UnquotedEscaped => { | ||
130 | if c != '\\' && c != '!' { | ||
131 | return None; | ||
132 | } | ||
133 | result.push(c); | ||
134 | state = SqState::Unquoted { may_escape: false }; | ||
135 | } | ||
136 | } | ||
137 | } | ||
138 | |||
139 | if state != (SqState::Unquoted { may_escape: true }) { | ||
140 | return None; | ||
141 | } | ||
142 | Some((result, remaining)) | ||
143 | } | ||
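
For context, `parse_cmd` and `parse_sq` exist to split an SSH command line such as `git-lfs-authenticate 'foo/bar.git' download` into arguments. The simplified tokenizer below only illustrates that input shape; unlike the real `parse_sq` above, it does not handle the backslash escapes (for `\` and `!`) accepted outside the quotes, and it is not the shell's code.

```rust
// Simplified sketch of splitting a git-shell style command line. Handles
// unescaped single quotes and whitespace only.
fn split_simple(cmd: &str) -> Option<Vec<String>> {
    let mut args = Vec::new();
    let mut rest = cmd.trim();
    while !rest.is_empty() {
        if let Some(stripped) = rest.strip_prefix('\'') {
            // Take everything up to the closing quote as one argument.
            let end = stripped.find('\'')?;
            args.push(stripped[..end].to_owned());
            rest = stripped[end + 1..].trim_start();
        } else if let Some((arg, remaining)) = rest.split_once(char::is_whitespace) {
            args.push(arg.to_owned());
            rest = remaining.trim_start();
        } else {
            args.push(rest.to_owned());
            rest = "";
        }
    }
    Some(args)
}

fn main() {
    let args = split_simple("git-lfs-authenticate 'foo/bar.git' download").unwrap();
    assert_eq!(args, ["git-lfs-authenticate", "foo/bar.git", "download"]);
    println!("{args:?}");
}
```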