aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLibravatar Rutger Broekhoff2024-01-26 12:28:39 +0100
committerLibravatar Rutger Broekhoff2024-01-26 12:28:39 +0100
commitf146743061ba8170569bf18518202df9a43c09f3 (patch)
tree798c303c07f9146d420de0bb5babb3df0fed58f5
parentc3d120445877b307f5ea7e95aed4bab45d7dede0 (diff)
downloadgitolfs3-f146743061ba8170569bf18518202df9a43c09f3.tar.gz
gitolfs3-f146743061ba8170569bf18518202df9a43c09f3.zip
Clean up part of the code
-rw-r--r--Cargo.lock7
-rw-r--r--common/src/lib.rs11
-rw-r--r--git-lfs-authenticate/Cargo.toml1
-rw-r--r--git-lfs-authenticate/src/main.rs275
-rw-r--r--server/src/main.rs257
-rw-r--r--shell/src/main.rs158
6 files changed, 306 insertions, 403 deletions
diff --git a/Cargo.lock b/Cargo.lock
index f3beb9e..cc204f1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -42,6 +42,12 @@ dependencies = [
42] 42]
43 43
44[[package]] 44[[package]]
45name = "anyhow"
46version = "1.0.79"
47source = "registry+https://github.com/rust-lang/crates.io-index"
48checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca"
49
50[[package]]
45name = "async-trait" 51name = "async-trait"
46version = "0.1.77" 52version = "0.1.77"
47source = "registry+https://github.com/rust-lang/crates.io-index" 53source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -869,6 +875,7 @@ checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
869name = "git-lfs-authenticate" 875name = "git-lfs-authenticate"
870version = "0.1.0" 876version = "0.1.0"
871dependencies = [ 877dependencies = [
878 "anyhow",
872 "chrono", 879 "chrono",
873 "common", 880 "common",
874] 881]
diff --git a/common/src/lib.rs b/common/src/lib.rs
index 995352d..0a538a5 100644
--- a/common/src/lib.rs
+++ b/common/src/lib.rs
@@ -1,9 +1,10 @@
1use chrono::{DateTime, Utc}; 1use chrono::{DateTime, Utc};
2use serde::de; 2use serde::{de, Deserialize, Serialize};
3use serde::{Deserialize, Serialize}; 3use std::{
4use std::fmt::Write; 4 fmt::{self, Write},
5use std::ops; 5 ops,
6use std::{fmt, str::FromStr}; 6 str::FromStr,
7};
7use subtle::ConstantTimeEq; 8use subtle::ConstantTimeEq;
8 9
9#[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)] 10#[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)]
diff --git a/git-lfs-authenticate/Cargo.toml b/git-lfs-authenticate/Cargo.toml
index 217250f..f4ab4d7 100644
--- a/git-lfs-authenticate/Cargo.toml
+++ b/git-lfs-authenticate/Cargo.toml
@@ -4,5 +4,6 @@ version = "0.1.0"
4edition = "2021" 4edition = "2021"
5 5
6[dependencies] 6[dependencies]
7anyhow = "1.0"
7chrono = "0.4" 8chrono = "0.4"
8common = { path = "../common" } 9common = { path = "../common" }
diff --git a/git-lfs-authenticate/src/main.rs b/git-lfs-authenticate/src/main.rs
index 36d7818..accc37f 100644
--- a/git-lfs-authenticate/src/main.rs
+++ b/git-lfs-authenticate/src/main.rs
@@ -1,197 +1,13 @@
1use std::{fmt, process::ExitCode, time::Duration}; 1use anyhow::{anyhow, bail, Result};
2
3use chrono::Utc; 2use chrono::Utc;
4use common::{Operation, ParseOperationError}; 3use std::{process::ExitCode, time::Duration};
5
6fn help() {
7 eprintln!("Usage: git-lfs-authenticate <REPO> upload/download");
8}
9
10#[derive(Debug, Eq, PartialEq, Copy, Clone)]
11enum RepoNameError {
12 TooLong,
13 UnresolvedPath,
14 AbsolutePath,
15 MissingGitSuffix,
16}
17
18impl fmt::Display for RepoNameError {
19 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
20 match self {
21 Self::TooLong => write!(f, "too long (more than 100 characters)"),
22 Self::UnresolvedPath => {
23 write!(f, "contains path one or more path elements '.' and '..'")
24 }
25 Self::AbsolutePath => {
26 write!(f, "starts with '/', which is not allowed")
27 }
28 Self::MissingGitSuffix => write!(f, "misses '.git' suffix"),
29 }
30 }
31}
32
33// Using `Result<(), E>` here instead of `Option<E>` because `None` typically signifies some error
34// state with no further details provided. If we were to return an `Option` type, the user would
35// have to first transform it into a `Result` type in order to use the `?` operator, meaning that
36// they would have to the following operation to get the type into the right shape:
37// `validate_repo_path(path).map_or(Ok(()), Err)`. That would not be very ergonomic.
38fn validate_repo_path(path: &str) -> Result<(), RepoNameError> {
39 if path.len() > 100 {
40 return Err(RepoNameError::TooLong);
41 }
42 if path.contains("//")
43 || path.contains("/./")
44 || path.contains("/../")
45 || path.starts_with("./")
46 || path.starts_with("../")
47 {
48 return Err(RepoNameError::UnresolvedPath);
49 }
50 if path.starts_with('/') {
51 return Err(RepoNameError::AbsolutePath);
52 }
53 if !path.ends_with(".git") {
54 return Err(RepoNameError::MissingGitSuffix);
55 }
56 Ok(())
57}
58
59#[derive(Debug, Eq, PartialEq, Copy, Clone)]
60enum ParseCmdlineError {
61 UnknownOperation(ParseOperationError),
62 InvalidRepoName(RepoNameError),
63 UnexpectedArgCount(ArgCountError),
64}
65
66impl From<RepoNameError> for ParseCmdlineError {
67 fn from(value: RepoNameError) -> Self {
68 Self::InvalidRepoName(value)
69 }
70}
71
72impl From<ParseOperationError> for ParseCmdlineError {
73 fn from(value: ParseOperationError) -> Self {
74 Self::UnknownOperation(value)
75 }
76}
77
78impl From<ArgCountError> for ParseCmdlineError {
79 fn from(value: ArgCountError) -> Self {
80 Self::UnexpectedArgCount(value)
81 }
82}
83
84impl fmt::Display for ParseCmdlineError {
85 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
86 match self {
87 Self::UnknownOperation(e) => write!(f, "unknown operation: {e}"),
88 Self::InvalidRepoName(e) => write!(f, "invalid repository name: {e}"),
89 Self::UnexpectedArgCount(e) => e.fmt(f),
90 }
91 }
92}
93
94#[derive(Debug, Eq, PartialEq, Copy, Clone)]
95struct ArgCountError {
96 provided: usize,
97 expected: usize,
98}
99
100impl fmt::Display for ArgCountError {
101 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
102 write!(
103 f,
104 "got {} argument(s), expected {}",
105 self.provided, self.expected
106 )
107 }
108}
109
110fn get_cmdline_args<const N: usize>() -> Result<[String; N], ArgCountError> {
111 let args = std::env::args();
112 if args.len() - 1 != N {
113 return Err(ArgCountError {
114 provided: args.len() - 1,
115 expected: N,
116 });
117 }
118
119 // Does not allocate.
120 const EMPTY_STRING: String = String::new();
121 let mut values = [EMPTY_STRING; N];
122
123 // Skip the first element; we do not care about the program name.
124 for (i, arg) in args.skip(1).enumerate() {
125 values[i] = arg
126 }
127 Ok(values)
128}
129
130fn parse_cmdline() -> Result<(String, Operation), ParseCmdlineError> {
131 let [repo_path, op_str] = get_cmdline_args::<2>()?;
132 let op: Operation = op_str.parse()?;
133 validate_repo_path(&repo_path)?;
134 Ok((repo_path.to_string(), op))
135}
136
137fn repo_exists(name: &str) -> bool {
138 match std::fs::metadata(name) {
139 Ok(metadata) => metadata.is_dir(),
140 _ => false,
141 }
142}
143
144struct Config {
145 href_base: String,
146 key_path: String,
147}
148
149#[derive(Debug, Eq, PartialEq, Copy, Clone)]
150enum LoadConfigError {
151 BaseUrlNotProvided,
152 BaseUrlSlashSuffixMissing,
153 KeyPathNotProvided,
154}
155
156impl fmt::Display for LoadConfigError {
157 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158 match self {
159 Self::BaseUrlNotProvided => write!(f, "base URL not provided"),
160 Self::BaseUrlSlashSuffixMissing => write!(f, "base URL does not end with slash"),
161 Self::KeyPathNotProvided => write!(f, "key path not provided"),
162 }
163 }
164}
165
166fn load_config() -> Result<Config, LoadConfigError> {
167 let Ok(href_base) = std::env::var("GITOLFS3_HREF_BASE") else {
168 return Err(LoadConfigError::BaseUrlNotProvided);
169 };
170 if !href_base.ends_with('/') {
171 return Err(LoadConfigError::BaseUrlSlashSuffixMissing);
172 }
173 let Ok(key_path) = std::env::var("GITOLFS3_KEY_PATH") else {
174 return Err(LoadConfigError::KeyPathNotProvided);
175 };
176 Ok(Config {
177 href_base,
178 key_path,
179 })
180}
181 4
182fn main() -> ExitCode { 5fn main() -> ExitCode {
183 let config = match load_config() { 6 let config = match Config::load() {
184 Ok(config) => config, 7 Ok(config) => config,
185 Err(e) => { 8 Err(e) => {
186 eprintln!("Failed to load config: {e}"); 9 eprintln!("Error: {e}");
187 return ExitCode::FAILURE; 10 return ExitCode::from(2);
188 }
189 };
190 let key = match common::load_key(&config.key_path) {
191 Ok(key) => key,
192 Err(e) => {
193 eprintln!("Failed to load key: {e}");
194 return ExitCode::FAILURE;
195 } 11 }
196 }; 12 };
197 13
@@ -199,7 +15,7 @@ fn main() -> ExitCode {
199 Ok(args) => args, 15 Ok(args) => args,
200 Err(e) => { 16 Err(e) => {
201 eprintln!("Error: {e}\n"); 17 eprintln!("Error: {e}\n");
202 help(); 18 eprintln!("Usage: git-lfs-authenticate <REPO> upload/download");
203 // Exit code 2 signifies bad usage of CLI. 19 // Exit code 2 signifies bad usage of CLI.
204 return ExitCode::from(2); 20 return ExitCode::from(2);
205 } 21 }
@@ -217,7 +33,7 @@ fn main() -> ExitCode {
217 repo_path: &repo_name, 33 repo_path: &repo_name,
218 expires_at, 34 expires_at,
219 }, 35 },
220 key, 36 config.key,
221 ) else { 37 ) else {
222 eprintln!("Failed to generate validation tag"); 38 eprintln!("Failed to generate validation tag");
223 return ExitCode::FAILURE; 39 return ExitCode::FAILURE;
@@ -234,3 +50,80 @@ fn main() -> ExitCode {
234 50
235 ExitCode::SUCCESS 51 ExitCode::SUCCESS
236} 52}
53
54struct Config {
55 href_base: String,
56 key: common::Key,
57}
58
59impl Config {
60 fn load() -> Result<Self> {
61 let Ok(href_base) = std::env::var("GITOLFS3_HREF_BASE") else {
62 bail!("configured base URL not provided");
63 };
64 if !href_base.ends_with('/') {
65 bail!("configured base URL does not end with a slash");
66 }
67
68 let Ok(key_path) = std::env::var("GITOLFS3_KEY_PATH") else {
69 bail!("key path not provided");
70 };
71 let key = common::load_key(&key_path).map_err(|e| anyhow!("failed to load key: {e}"))?;
72
73 Ok(Self { href_base, key })
74 }
75}
76
77fn parse_cmdline() -> Result<(String, common::Operation)> {
78 let [repo_path, op_str] = get_cmdline_args::<2>()?;
79 let op: common::Operation = op_str
80 .parse()
81 .map_err(|e| anyhow!("unknown operation: {e}"))?;
82 validate_repo_path(&repo_path).map_err(|e| anyhow!("invalid repository name: {e}"))?;
83 Ok((repo_path.to_string(), op))
84}
85
86fn get_cmdline_args<const N: usize>() -> Result<[String; N]> {
87 let args = std::env::args();
88 if args.len() - 1 != N {
89 bail!("got {} argument(s), expected {}", args.len() - 1, N);
90 }
91
92 // Does not allocate.
93 const EMPTY_STRING: String = String::new();
94 let mut values = [EMPTY_STRING; N];
95
96 // Skip the first element; we do not care about the program name.
97 for (i, arg) in args.skip(1).enumerate() {
98 values[i] = arg
99 }
100 Ok(values)
101}
102
103fn validate_repo_path(path: &str) -> Result<()> {
104 if path.len() > 100 {
105 bail!("too long (more than 100 characters)");
106 }
107 if path.contains("//")
108 || path.contains("/./")
109 || path.contains("/../")
110 || path.starts_with("./")
111 || path.starts_with("../")
112 {
113 bail!("contains one or more path elements '.' and '..'");
114 }
115 if path.starts_with('/') {
116 bail!("starts with '/', which is not allowed");
117 }
118 if !path.ends_with(".git") {
119 bail!("missed '.git' suffix");
120 }
121 Ok(())
122}
123
124fn repo_exists(name: &str) -> bool {
125 match std::fs::metadata(name) {
126 Ok(metadata) => metadata.is_dir(),
127 _ => false,
128 }
129}
diff --git a/server/src/main.rs b/server/src/main.rs
index db37d14..4a88dcd 100644
--- a/server/src/main.rs
+++ b/server/src/main.rs
@@ -1,39 +1,82 @@
1use std::collections::HashMap; 1use aws_sdk_s3::{error::SdkError, operation::head_object::HeadObjectOutput};
2use std::collections::HashSet;
3use std::process::ExitCode;
4use std::sync::Arc;
5
6use aws_sdk_s3::error::SdkError;
7use aws_sdk_s3::operation::head_object::HeadObjectOutput;
8use axum::extract::rejection;
9use axum::extract::FromRequest;
10use axum::extract::Path;
11use axum::extract::State;
12use axum::http::header;
13use axum::http::HeaderMap;
14use axum::http::HeaderValue;
15use axum::response::Response;
16use axum::Json;
17use axum::ServiceExt;
18use base64::prelude::*;
19use chrono::DateTime;
20use chrono::Utc;
21use common::HexByte;
22use serde::de;
23use serde::de::DeserializeOwned;
24use serde::Deserialize;
25use serde::Serialize;
26use tokio::io::AsyncWriteExt;
27use tower::Layer;
28
29use axum::{ 2use axum::{
30 async_trait, 3 async_trait,
31 extract::{FromRequestParts, OriginalUri, Request}, 4 extract::{rejection, FromRequest, FromRequestParts, OriginalUri, Path, Request, State},
32 http::{request::Parts, StatusCode, Uri}, 5 http::{header, request::Parts, HeaderMap, HeaderValue, StatusCode, Uri},
33 response::IntoResponse, 6 response::{IntoResponse, Response},
34 routing::{get, post}, 7 routing::{get, post},
35 Extension, Router, 8 Extension, Json, Router, ServiceExt,
36}; 9};
10use base64::prelude::*;
11use chrono::{DateTime, Utc};
12use serde::{
13 de::{self, DeserializeOwned},
14 Deserialize, Serialize,
15};
16use std::{
17 collections::{HashMap, HashSet},
18 process::ExitCode,
19 sync::Arc,
20};
21use tokio::io::AsyncWriteExt;
22use tower::Layer;
23
24#[tokio::main]
25async fn main() -> ExitCode {
26 tracing_subscriber::fmt::init();
27
28 let conf = match Config::load() {
29 Ok(conf) => conf,
30 Err(e) => {
31 println!("Error: {e}");
32 return ExitCode::from(2);
33 }
34 };
35
36 let dl_limiter = DownloadLimiter::new(conf.download_limit).await;
37 let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter));
38
39 let resetter_dl_limiter = dl_limiter.clone();
40 tokio::spawn(async move {
41 loop {
42 println!("Resetting download counter in one hour");
43 tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
44 println!("Resetting download counter");
45 resetter_dl_limiter.lock().await.reset().await;
46 }
47 });
48
49 let shared_state = Arc::new(AppState {
50 s3_client: conf.s3_client,
51 s3_bucket: conf.s3_bucket,
52 authz_conf: conf.authz_conf,
53 base_url: conf.base_url,
54 dl_limiter,
55 });
56 let app = Router::new()
57 .route("/batch", post(batch))
58 .route("/:oid0/:oid1/:oid", get(obj_download))
59 .with_state(shared_state);
60
61 let middleware = axum::middleware::map_request(rewrite_url);
62 let app_with_middleware = middleware.layer(app);
63
64 let listener = match tokio::net::TcpListener::bind(conf.listen_addr).await {
65 Ok(listener) => listener,
66 Err(e) => {
67 println!("Failed to listen: {e}");
68 return ExitCode::FAILURE;
69 }
70 };
71
72 match axum::serve(listener, app_with_middleware.into_make_service()).await {
73 Ok(_) => ExitCode::SUCCESS,
74 Err(e) => {
75 println!("Error serving: {e}");
76 ExitCode::FAILURE
77 }
78 }
79}
37 80
38#[derive(Clone)] 81#[derive(Clone)]
39struct RepositoryName(String); 82struct RepositoryName(String);
@@ -165,99 +208,57 @@ fn get_s3_client(env: &Env) -> Result<aws_sdk_s3::Client, std::io::Error> {
165 Ok(aws_sdk_s3::Client::new(&config)) 208 Ok(aws_sdk_s3::Client::new(&config))
166} 209}
167 210
168#[tokio::main] 211struct Config {
169async fn main() -> ExitCode { 212 listen_addr: (String, u16),
170 tracing_subscriber::fmt::init(); 213 base_url: String,
171 214 authz_conf: AuthorizationConfig,
172 let env = match Env::load() { 215 s3_client: aws_sdk_s3::Client,
173 Ok(env) => env, 216 s3_bucket: String,
174 Err(e) => { 217 download_limit: u64,
175 println!("Failed to load configuration: {e}"); 218}
176 return ExitCode::from(2);
177 }
178 };
179
180 let s3_client = match get_s3_client(&env) {
181 Ok(s3_client) => s3_client,
182 Err(e) => {
183 println!("Failed to create S3 client: {e}");
184 return ExitCode::FAILURE;
185 }
186 };
187 let key = match common::load_key(&env.key_path) {
188 Ok(key) => key,
189 Err(e) => {
190 println!("Failed to load Gitolfs3 key: {e}");
191 return ExitCode::FAILURE;
192 }
193 };
194
195 let trusted_forwarded_hosts: HashSet<String> = env
196 .trusted_forwarded_hosts
197 .split(',')
198 .map(|s| s.to_owned())
199 .filter(|s| !s.is_empty())
200 .collect();
201 let base_url = env.base_url.trim_end_matches('/').to_string();
202
203 let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else {
204 println!("Configured GITOLFS3_DOWNLOAD_LIMIT should be a 64-bit unsigned integer");
205 return ExitCode::from(2);
206 };
207 let dl_limiter = DownloadLimiter::new(download_limit).await;
208 let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter));
209
210 let resetter_dl_limiter = dl_limiter.clone();
211 tokio::spawn(async move {
212 loop {
213 println!("Resetting download counter in one hour");
214 tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
215 println!("Resetting download counter");
216 resetter_dl_limiter.lock().await.reset().await;
217 }
218 });
219 219
220 let authz_conf = AuthorizationConfig { 220impl Config {
221 key, 221 fn load() -> Result<Self, String> {
222 trusted_forwarded_hosts, 222 let env = match Env::load() {
223 }; 223 Ok(env) => env,
224 Err(e) => return Err(format!("failed to load configuration: {e}")),
225 };
224 226
225 let shared_state = Arc::new(AppState { 227 let s3_client = match get_s3_client(&env) {
226 s3_client, 228 Ok(s3_client) => s3_client,
227 s3_bucket: env.s3_bucket, 229 Err(e) => return Err(format!("failed to create S3 client: {e}")),
228 authz_conf, 230 };
229 base_url, 231 let key = match common::load_key(&env.key_path) {
230 dl_limiter, 232 Ok(key) => key,
231 }); 233 Err(e) => return Err(format!("failed to load Gitolfs3 key: {e}")),
232 let app = Router::new() 234 };
233 .route("/batch", post(batch))
234 .route("/:oid0/:oid1/:oid", get(obj_download))
235 .with_state(shared_state);
236 235
237 let middleware = axum::middleware::map_request(rewrite_url); 236 let trusted_forwarded_hosts: HashSet<String> = env
238 let app_with_middleware = middleware.layer(app); 237 .trusted_forwarded_hosts
238 .split(',')
239 .map(|s| s.to_owned())
240 .filter(|s| !s.is_empty())
241 .collect();
242 let base_url = env.base_url.trim_end_matches('/').to_string();
239 243
240 let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else { 244 let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else {
241 println!( 245 return Err("configured GITOLFS3_LISTEN_PORT is invalid".to_string());
242 "Configured GITOLFS3_LISTEN_PORT should be an unsigned integer no higher than 65535" 246 };
243 ); 247 let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else {
244 return ExitCode::from(2); 248 return Err("configured GITOLFS3_DOWNLOAD_LIMIT is invalid".to_string());
245 }; 249 };
246 let addr: (String, u16) = (env.listen_host, listen_port);
247 let listener = match tokio::net::TcpListener::bind(addr).await {
248 Ok(listener) => listener,
249 Err(e) => {
250 println!("Failed to listen: {e}");
251 return ExitCode::FAILURE;
252 }
253 };
254 250
255 match axum::serve(listener, app_with_middleware.into_make_service()).await { 251 Ok(Self {
256 Ok(_) => ExitCode::SUCCESS, 252 listen_addr: (env.listen_host, listen_port),
257 Err(e) => { 253 base_url,
258 println!("Error serving: {e}"); 254 authz_conf: AuthorizationConfig {
259 ExitCode::FAILURE 255 key,
260 } 256 trusted_forwarded_hosts,
257 },
258 s3_client,
259 s3_bucket: env.s3_bucket,
260 download_limit,
261 })
261 } 262 }
262} 263}
263 264
@@ -479,7 +480,7 @@ async fn handle_upload_object(
479 repo: &str, 480 repo: &str,
480 obj: &BatchRequestObject, 481 obj: &BatchRequestObject,
481) -> Option<BatchResponseObject> { 482) -> Option<BatchResponseObject> {
482 let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); 483 let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1]));
483 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); 484 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
484 485
485 match state 486 match state
@@ -558,7 +559,7 @@ async fn handle_download_object(
558 obj: &BatchRequestObject, 559 obj: &BatchRequestObject,
559 trusted: bool, 560 trusted: bool,
560) -> BatchResponseObject { 561) -> BatchResponseObject {
561 let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1])); 562 let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1]));
562 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid); 563 let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
563 564
564 let result = match state 565 let result = match state
@@ -687,8 +688,8 @@ async fn handle_download_object(
687 688
688 let upload_path = format!( 689 let upload_path = format!(
689 "{repo}/info/lfs/objects/{}/{}/{}", 690 "{repo}/info/lfs/objects/{}/{}/{}",
690 HexByte(obj.oid[0]), 691 common::HexByte(obj.oid[0]),
691 HexByte(obj.oid[1]), 692 common::HexByte(obj.oid[1]),
692 obj.oid, 693 obj.oid,
693 ); 694 );
694 695
@@ -866,8 +867,8 @@ async fn batch(
866#[derive(Deserialize, Copy, Clone)] 867#[derive(Deserialize, Copy, Clone)]
867#[serde(remote = "Self")] 868#[serde(remote = "Self")]
868struct FileParams { 869struct FileParams {
869 oid0: HexByte, 870 oid0: common::HexByte,
870 oid1: HexByte, 871 oid1: common::HexByte,
871 oid: common::Oid, 872 oid: common::Oid,
872} 873}
873 874
@@ -877,8 +878,8 @@ impl<'de> Deserialize<'de> for FileParams {
877 D: serde::Deserializer<'de>, 878 D: serde::Deserializer<'de>,
878 { 879 {
879 let unchecked @ FileParams { 880 let unchecked @ FileParams {
880 oid0: HexByte(oid0), 881 oid0: common::HexByte(oid0),
881 oid1: HexByte(oid1), 882 oid1: common::HexByte(oid1),
882 oid, 883 oid,
883 } = FileParams::deserialize(deserializer)?; 884 } = FileParams::deserialize(deserializer)?;
884 if oid0 != oid.as_bytes()[0] { 885 if oid0 != oid.as_bytes()[0] {
diff --git a/shell/src/main.rs b/shell/src/main.rs
index 4901e7f..4a98828 100644
--- a/shell/src/main.rs
+++ b/shell/src/main.rs
@@ -1,84 +1,5 @@
1use std::{os::unix::process::CommandExt, process::ExitCode}; 1use std::{os::unix::process::CommandExt, process::ExitCode};
2 2
3fn parse_sq(s: &str) -> Option<(String, &str)> {
4 #[derive(PartialEq, Eq)]
5 enum SqState {
6 Quoted,
7 Unquoted { may_escape: bool },
8 UnquotedEscaped,
9 }
10
11 let mut result = String::new();
12 let mut state = SqState::Unquoted { may_escape: false };
13 let mut remaining = "";
14 for (i, c) in s.char_indices() {
15 match state {
16 SqState::Unquoted { may_escape: false } => {
17 if c != '\'' {
18 return None;
19 }
20 state = SqState::Quoted
21 }
22 SqState::Quoted => {
23 if c == '\'' {
24 state = SqState::Unquoted { may_escape: true };
25 continue;
26 }
27 result.push(c);
28 }
29 SqState::Unquoted { may_escape: true } => {
30 if is_posix_space(c) {
31 remaining = &s[i..];
32 break;
33 }
34 if c != '\\' {
35 return None;
36 }
37 state = SqState::UnquotedEscaped;
38 }
39 SqState::UnquotedEscaped => {
40 if c != '\\' && c != '!' {
41 return None;
42 }
43 result.push(c);
44 state = SqState::Unquoted { may_escape: false };
45 }
46 }
47 }
48
49 if state != (SqState::Unquoted { may_escape: true }) {
50 return None;
51 }
52 Some((result, remaining))
53}
54
55fn parse_cmd(mut cmd: &str) -> Option<Vec<String>> {
56 let mut args = Vec::<String>::new();
57
58 cmd = cmd.trim_matches(is_posix_space);
59 while !cmd.is_empty() {
60 if cmd.starts_with('\'') {
61 let (arg, remaining) = parse_sq(cmd)?;
62 args.push(arg);
63 cmd = remaining.trim_start_matches(is_posix_space);
64 } else if let Some((arg, remaining)) = cmd.split_once(is_posix_space) {
65 args.push(arg.to_owned());
66 cmd = remaining.trim_start_matches(is_posix_space);
67 } else {
68 args.push(cmd.to_owned());
69 cmd = "";
70 }
71 }
72
73 Some(args)
74}
75
76fn is_posix_space(c: char) -> bool {
77 // Form feed: 0x0c
78 // Vertical tab: 0x0b
79 c == ' ' || c == '\x0c' || c == '\n' || c == '\r' || c == '\t' || c == '\x0b'
80}
81
82fn main() -> ExitCode { 3fn main() -> ExitCode {
83 let bad_usage = ExitCode::from(2); 4 let bad_usage = ExitCode::from(2);
84 5
@@ -141,3 +62,82 @@ fn main() -> ExitCode {
141 eprintln!("Error: {e}"); 62 eprintln!("Error: {e}");
142 ExitCode::FAILURE 63 ExitCode::FAILURE
143} 64}
65
66fn parse_cmd(mut cmd: &str) -> Option<Vec<String>> {
67 let mut args = Vec::<String>::new();
68
69 cmd = cmd.trim_matches(is_posix_space);
70 while !cmd.is_empty() {
71 if cmd.starts_with('\'') {
72 let (arg, remaining) = parse_sq(cmd)?;
73 args.push(arg);
74 cmd = remaining.trim_start_matches(is_posix_space);
75 } else if let Some((arg, remaining)) = cmd.split_once(is_posix_space) {
76 args.push(arg.to_owned());
77 cmd = remaining.trim_start_matches(is_posix_space);
78 } else {
79 args.push(cmd.to_owned());
80 cmd = "";
81 }
82 }
83
84 Some(args)
85}
86
87fn is_posix_space(c: char) -> bool {
88 // Form feed: 0x0c
89 // Vertical tab: 0x0b
90 c == ' ' || c == '\x0c' || c == '\n' || c == '\r' || c == '\t' || c == '\x0b'
91}
92
93fn parse_sq(s: &str) -> Option<(String, &str)> {
94 #[derive(PartialEq, Eq)]
95 enum SqState {
96 Quoted,
97 Unquoted { may_escape: bool },
98 UnquotedEscaped,
99 }
100
101 let mut result = String::new();
102 let mut state = SqState::Unquoted { may_escape: false };
103 let mut remaining = "";
104 for (i, c) in s.char_indices() {
105 match state {
106 SqState::Unquoted { may_escape: false } => {
107 if c != '\'' {
108 return None;
109 }
110 state = SqState::Quoted
111 }
112 SqState::Quoted => {
113 if c == '\'' {
114 state = SqState::Unquoted { may_escape: true };
115 continue;
116 }
117 result.push(c);
118 }
119 SqState::Unquoted { may_escape: true } => {
120 if is_posix_space(c) {
121 remaining = &s[i..];
122 break;
123 }
124 if c != '\\' {
125 return None;
126 }
127 state = SqState::UnquotedEscaped;
128 }
129 SqState::UnquotedEscaped => {
130 if c != '\\' && c != '!' {
131 return None;
132 }
133 result.push(c);
134 state = SqState::Unquoted { may_escape: false };
135 }
136 }
137 }
138
139 if state != (SqState::Unquoted { may_escape: true }) {
140 return None;
141 }
142 Some((result, remaining))
143}