From f146743061ba8170569bf18518202df9a43c09f3 Mon Sep 17 00:00:00 2001
From: Rutger Broekhoff
Date: Fri, 26 Jan 2024 12:28:39 +0100
Subject: Clean up part of the code

---
 Cargo.lock                       |   7 +
 common/src/lib.rs                |  11 +-
 git-lfs-authenticate/Cargo.toml  |   1 +
 git-lfs-authenticate/src/main.rs | 275 ++++++++++++---------------------------
 server/src/main.rs               | 257 ++++++++++++++++++------------------
 shell/src/main.rs                | 158 +++++++++++-----------
 6 files changed, 306 insertions(+), 403 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index f3beb9e..cc204f1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -41,6 +41,12 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "anyhow"
+version = "1.0.79"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca"
+
 [[package]]
 name = "async-trait"
 version = "0.1.77"
@@ -869,6 +875,7 @@ checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253"
 name = "git-lfs-authenticate"
 version = "0.1.0"
 dependencies = [
+ "anyhow",
  "chrono",
  "common",
 ]
diff --git a/common/src/lib.rs b/common/src/lib.rs
index 995352d..0a538a5 100644
--- a/common/src/lib.rs
+++ b/common/src/lib.rs
@@ -1,9 +1,10 @@
 use chrono::{DateTime, Utc};
-use serde::de;
-use serde::{Deserialize, Serialize};
-use std::fmt::Write;
-use std::ops;
-use std::{fmt, str::FromStr};
+use serde::{de, Deserialize, Serialize};
+use std::{
+    fmt::{self, Write},
+    ops,
+    str::FromStr,
+};
 use subtle::ConstantTimeEq;
 
 #[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)]
diff --git a/git-lfs-authenticate/Cargo.toml b/git-lfs-authenticate/Cargo.toml
index 217250f..f4ab4d7 100644
--- a/git-lfs-authenticate/Cargo.toml
+++ b/git-lfs-authenticate/Cargo.toml
@@ -4,5 +4,6 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
+anyhow = "1.0"
 chrono = "0.4"
 common = { path = "../common" }
diff --git a/git-lfs-authenticate/src/main.rs b/git-lfs-authenticate/src/main.rs
index 36d7818..accc37f 100644
--- a/git-lfs-authenticate/src/main.rs
+++ b/git-lfs-authenticate/src/main.rs
@@ -1,197 +1,13 @@
-use std::{fmt, process::ExitCode, time::Duration};
-
+use anyhow::{anyhow, bail, Result};
 use chrono::Utc;
-use common::{Operation, ParseOperationError};
-
-fn help() {
-    eprintln!("Usage: git-lfs-authenticate <REPO> upload/download");
-}
-
-#[derive(Debug, Eq, PartialEq, Copy, Clone)]
-enum RepoNameError {
-    TooLong,
-    UnresolvedPath,
-    AbsolutePath,
-    MissingGitSuffix,
-}
-
-impl fmt::Display for RepoNameError {
-    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        match self {
-            Self::TooLong => write!(f, "too long (more than 100 characters)"),
-            Self::UnresolvedPath => {
-                write!(f, "contains path one or more path elements '.' and '..'")
-            }
-            Self::AbsolutePath => {
-                write!(f, "starts with '/', which is not allowed")
-            }
-            Self::MissingGitSuffix => write!(f, "misses '.git' suffix"),
-        }
-    }
-}
-
-// Using `Result<(), E>` here instead of `Option<E>` because `None` typically signifies some error
-// state with no further details provided. If we were to return an `Option` type, the user would
-// have to first transform it into a `Result` type in order to use the `?` operator, meaning that
-// they would have to the following operation to get the type into the right shape:
-// `validate_repo_path(path).map_or(Ok(()), Err)`. That would not be very ergonomic.
-fn validate_repo_path(path: &str) -> Result<(), RepoNameError> {
-    if path.len() > 100 {
-        return Err(RepoNameError::TooLong);
-    }
-    if path.contains("//")
-        || path.contains("/./")
-        || path.contains("/../")
-        || path.starts_with("./")
-        || path.starts_with("../")
-    {
-        return Err(RepoNameError::UnresolvedPath);
-    }
-    if path.starts_with('/') {
-        return Err(RepoNameError::AbsolutePath);
-    }
-    if !path.ends_with(".git") {
-        return Err(RepoNameError::MissingGitSuffix);
-    }
-    Ok(())
-}
-
-#[derive(Debug, Eq, PartialEq, Copy, Clone)]
-enum ParseCmdlineError {
-    UnknownOperation(ParseOperationError),
-    InvalidRepoName(RepoNameError),
-    UnexpectedArgCount(ArgCountError),
-}
-
-impl From<RepoNameError> for ParseCmdlineError {
-    fn from(value: RepoNameError) -> Self {
-        Self::InvalidRepoName(value)
-    }
-}
-
-impl From<ParseOperationError> for ParseCmdlineError {
-    fn from(value: ParseOperationError) -> Self {
-        Self::UnknownOperation(value)
-    }
-}
-
-impl From<ArgCountError> for ParseCmdlineError {
-    fn from(value: ArgCountError) -> Self {
-        Self::UnexpectedArgCount(value)
-    }
-}
-
-impl fmt::Display for ParseCmdlineError {
-    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        match self {
-            Self::UnknownOperation(e) => write!(f, "unknown operation: {e}"),
-            Self::InvalidRepoName(e) => write!(f, "invalid repository name: {e}"),
-            Self::UnexpectedArgCount(e) => e.fmt(f),
-        }
-    }
-}
-
-#[derive(Debug, Eq, PartialEq, Copy, Clone)]
-struct ArgCountError {
-    provided: usize,
-    expected: usize,
-}
-
-impl fmt::Display for ArgCountError {
-    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        write!(
-            f,
-            "got {} argument(s), expected {}",
-            self.provided, self.expected
-        )
-    }
-}
-
-fn get_cmdline_args<const N: usize>() -> Result<[String; N], ArgCountError> {
-    let args = std::env::args();
-    if args.len() - 1 != N {
-        return Err(ArgCountError {
-            provided: args.len() - 1,
-            expected: N,
-        });
-    }
-
-    // Does not allocate.
-    const EMPTY_STRING: String = String::new();
-    let mut values = [EMPTY_STRING; N];
-
-    // Skip the first element; we do not care about the program name.
-    for (i, arg) in args.skip(1).enumerate() {
-        values[i] = arg
-    }
-    Ok(values)
-}
-
-fn parse_cmdline() -> Result<(String, Operation), ParseCmdlineError> {
-    let [repo_path, op_str] = get_cmdline_args::<2>()?;
-    let op: Operation = op_str.parse()?;
-    validate_repo_path(&repo_path)?;
-    Ok((repo_path.to_string(), op))
-}
-
-fn repo_exists(name: &str) -> bool {
-    match std::fs::metadata(name) {
-        Ok(metadata) => metadata.is_dir(),
-        _ => false,
-    }
-}
-
-struct Config {
-    href_base: String,
-    key_path: String,
-}
-
-#[derive(Debug, Eq, PartialEq, Copy, Clone)]
-enum LoadConfigError {
-    BaseUrlNotProvided,
-    BaseUrlSlashSuffixMissing,
-    KeyPathNotProvided,
-}
-
-impl fmt::Display for LoadConfigError {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        match self {
-            Self::BaseUrlNotProvided => write!(f, "base URL not provided"),
-            Self::BaseUrlSlashSuffixMissing => write!(f, "base URL does not end with slash"),
-            Self::KeyPathNotProvided => write!(f, "key path not provided"),
-        }
-    }
-}
-
-fn load_config() -> Result<Config, LoadConfigError> {
-    let Ok(href_base) = std::env::var("GITOLFS3_HREF_BASE") else {
-        return Err(LoadConfigError::BaseUrlNotProvided);
-    };
-    if !href_base.ends_with('/') {
-        return Err(LoadConfigError::BaseUrlSlashSuffixMissing);
-    }
-    let Ok(key_path) = std::env::var("GITOLFS3_KEY_PATH") else {
-        return Err(LoadConfigError::KeyPathNotProvided);
-    };
-    Ok(Config {
-        href_base,
-        key_path,
-    })
-}
+use std::{process::ExitCode, time::Duration};
 
 fn main() -> ExitCode {
-    let config = match load_config() {
+    let config = match Config::load() {
         Ok(config) => config,
         Err(e) => {
-            eprintln!("Failed to load config: {e}");
-            return ExitCode::FAILURE;
-        }
-    };
-    let key = match common::load_key(&config.key_path) {
-        Ok(key) => key,
-        Err(e) => {
-            eprintln!("Failed to load key: {e}");
-            return ExitCode::FAILURE;
+            eprintln!("Error: {e}");
+            return ExitCode::from(2);
         }
     };
 
@@ -199,7 +15,7 @@ fn main() -> ExitCode {
         Ok(args) => args,
         Err(e) => {
             eprintln!("Error: {e}\n");
-            help();
+            eprintln!("Usage: git-lfs-authenticate <REPO> upload/download");
             // Exit code 2 signifies bad usage of CLI.
             return ExitCode::from(2);
         }
@@ -217,7 +33,7 @@ fn main() -> ExitCode {
             repo_path: &repo_name,
             expires_at,
         },
-        key,
+        config.key,
     ) else {
         eprintln!("Failed to generate validation tag");
         return ExitCode::FAILURE;
@@ -234,3 +50,80 @@ fn main() -> ExitCode {
 
     ExitCode::SUCCESS
 }
+
+struct Config {
+    href_base: String,
+    key: common::Key,
+}
+
+impl Config {
+    fn load() -> Result<Self> {
+        let Ok(href_base) = std::env::var("GITOLFS3_HREF_BASE") else {
+            bail!("configured base URL not provided");
+        };
+        if !href_base.ends_with('/') {
+            bail!("configured base URL does not end with a slash");
+        }
+
+        let Ok(key_path) = std::env::var("GITOLFS3_KEY_PATH") else {
+            bail!("key path not provided");
+        };
+        let key = common::load_key(&key_path).map_err(|e| anyhow!("failed to load key: {e}"))?;
+
+        Ok(Self { href_base, key })
+    }
+}
+
+fn parse_cmdline() -> Result<(String, common::Operation)> {
+    let [repo_path, op_str] = get_cmdline_args::<2>()?;
+    let op: common::Operation = op_str
+        .parse()
+        .map_err(|e| anyhow!("unknown operation: {e}"))?;
+    validate_repo_path(&repo_path).map_err(|e| anyhow!("invalid repository name: {e}"))?;
+    Ok((repo_path.to_string(), op))
+}
+
+fn get_cmdline_args<const N: usize>() -> Result<[String; N]> {
+    let args = std::env::args();
+    if args.len() - 1 != N {
+        bail!("got {} argument(s), expected {}", args.len() - 1, N);
+    }
+
+    // Does not allocate.
+    const EMPTY_STRING: String = String::new();
+    let mut values = [EMPTY_STRING; N];
+
+    // Skip the first element; we do not care about the program name.
+    for (i, arg) in args.skip(1).enumerate() {
+        values[i] = arg
+    }
+    Ok(values)
+}
+
+fn validate_repo_path(path: &str) -> Result<()> {
+    if path.len() > 100 {
+        bail!("too long (more than 100 characters)");
+    }
+    if path.contains("//")
+        || path.contains("/./")
+        || path.contains("/../")
+        || path.starts_with("./")
+        || path.starts_with("../")
+    {
+        bail!("contains one or more path elements '.' and '..'");
+    }
+    if path.starts_with('/') {
+        bail!("starts with '/', which is not allowed");
+    }
+    if !path.ends_with(".git") {
+        bail!("missed '.git' suffix");
+    }
+    Ok(())
+}
+
+fn repo_exists(name: &str) -> bool {
+    match std::fs::metadata(name) {
+        Ok(metadata) => metadata.is_dir(),
+        _ => false,
+    }
+}
diff --git a/server/src/main.rs b/server/src/main.rs
index db37d14..4a88dcd 100644
--- a/server/src/main.rs
+++ b/server/src/main.rs
@@ -1,39 +1,82 @@
-use std::collections::HashMap;
-use std::collections::HashSet;
-use std::process::ExitCode;
-use std::sync::Arc;
-
-use aws_sdk_s3::error::SdkError;
-use aws_sdk_s3::operation::head_object::HeadObjectOutput;
-use axum::extract::rejection;
-use axum::extract::FromRequest;
-use axum::extract::Path;
-use axum::extract::State;
-use axum::http::header;
-use axum::http::HeaderMap;
-use axum::http::HeaderValue;
-use axum::response::Response;
-use axum::Json;
-use axum::ServiceExt;
-use base64::prelude::*;
-use chrono::DateTime;
-use chrono::Utc;
-use common::HexByte;
-use serde::de;
-use serde::de::DeserializeOwned;
-use serde::Deserialize;
-use serde::Serialize;
-use tokio::io::AsyncWriteExt;
-use tower::Layer;
-
+use aws_sdk_s3::{error::SdkError, operation::head_object::HeadObjectOutput};
 use axum::{
     async_trait,
-    extract::{FromRequestParts, OriginalUri, Request},
-    http::{request::Parts, StatusCode, Uri},
-    response::IntoResponse,
+    extract::{rejection, FromRequest, FromRequestParts, OriginalUri, Path, Request, State},
+    http::{header, request::Parts, HeaderMap, HeaderValue, StatusCode, Uri},
+    response::{IntoResponse, Response},
     routing::{get, post},
-    Extension, Router,
+    Extension, Json, Router, ServiceExt,
 };
+use base64::prelude::*;
+use chrono::{DateTime, Utc};
+use serde::{
+    de::{self, DeserializeOwned},
+    Deserialize, Serialize,
+};
+use std::{
+    collections::{HashMap, HashSet},
+    process::ExitCode,
+    sync::Arc,
+};
+use tokio::io::AsyncWriteExt;
+use tower::Layer;
+
+#[tokio::main]
+async fn main() -> ExitCode {
+    tracing_subscriber::fmt::init();
+
+    let conf = match Config::load() {
+        Ok(conf) => conf,
+        Err(e) => {
+            println!("Error: {e}");
+            return ExitCode::from(2);
+        }
+    };
+
+    let dl_limiter = DownloadLimiter::new(conf.download_limit).await;
+    let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter));
+
+    let resetter_dl_limiter = dl_limiter.clone();
+    tokio::spawn(async move {
+        loop {
+            println!("Resetting download counter in one hour");
+            tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
+            println!("Resetting download counter");
+            resetter_dl_limiter.lock().await.reset().await;
+        }
+    });
+
+    let shared_state = Arc::new(AppState {
+        s3_client: conf.s3_client,
+        s3_bucket: conf.s3_bucket,
+        authz_conf: conf.authz_conf,
+        base_url: conf.base_url,
+        dl_limiter,
+    });
+    let app = Router::new()
+        .route("/batch", post(batch))
+        .route("/:oid0/:oid1/:oid", get(obj_download))
+        .with_state(shared_state);
+
+    let middleware = axum::middleware::map_request(rewrite_url);
+    let app_with_middleware = middleware.layer(app);
+
+    let listener = match tokio::net::TcpListener::bind(conf.listen_addr).await {
+        Ok(listener) => listener,
+        Err(e) => {
+            println!("Failed to listen: {e}");
+            return ExitCode::FAILURE;
+        }
+    };
+
+    match axum::serve(listener, app_with_middleware.into_make_service()).await {
+        Ok(_) => ExitCode::SUCCESS,
+        Err(e) => {
+            println!("Error serving: {e}");
+            ExitCode::FAILURE
+        }
+    }
+}
 
 #[derive(Clone)]
 struct RepositoryName(String);
@@ -165,99 +208,57 @@ fn get_s3_client(env: &Env) -> Result<aws_sdk_s3::Client, std::io::Error> {
     Ok(aws_sdk_s3::Client::new(&config))
 }
 
-#[tokio::main]
-async fn main() -> ExitCode {
-    tracing_subscriber::fmt::init();
-
-    let env = match Env::load() {
-        Ok(env) => env,
-        Err(e) => {
-            println!("Failed to load configuration: {e}");
-            return ExitCode::from(2);
-        }
-    };
-
-    let s3_client = match get_s3_client(&env) {
-        Ok(s3_client) => s3_client,
-        Err(e) => {
-            println!("Failed to create S3 client: {e}");
-            return ExitCode::FAILURE;
-        }
-    };
-    let key = match common::load_key(&env.key_path) {
-        Ok(key) => key,
-        Err(e) => {
-            println!("Failed to load Gitolfs3 key: {e}");
-            return ExitCode::FAILURE;
-        }
-    };
-
-    let trusted_forwarded_hosts: HashSet<String> = env
-        .trusted_forwarded_hosts
-        .split(',')
-        .map(|s| s.to_owned())
-        .filter(|s| !s.is_empty())
-        .collect();
-    let base_url = env.base_url.trim_end_matches('/').to_string();
-
-    let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else {
-        println!("Configured GITOLFS3_DOWNLOAD_LIMIT should be a 64-bit unsigned integer");
-        return ExitCode::from(2);
-    };
-    let dl_limiter = DownloadLimiter::new(download_limit).await;
-    let dl_limiter = Arc::new(tokio::sync::Mutex::new(dl_limiter));
-
-    let resetter_dl_limiter = dl_limiter.clone();
-    tokio::spawn(async move {
-        loop {
-            println!("Resetting download counter in one hour");
-            tokio::time::sleep(std::time::Duration::from_secs(3600)).await;
-            println!("Resetting download counter");
-            resetter_dl_limiter.lock().await.reset().await;
-        }
-    });
+struct Config {
+    listen_addr: (String, u16),
+    base_url: String,
+    authz_conf: AuthorizationConfig,
+    s3_client: aws_sdk_s3::Client,
+    s3_bucket: String,
+    download_limit: u64,
+}
 
-    let authz_conf = AuthorizationConfig {
-        key,
-        trusted_forwarded_hosts,
-    };
+impl Config {
+    fn load() -> Result<Self, String> {
+        let env = match Env::load() {
+            Ok(env) => env,
+            Err(e) => return Err(format!("failed to load configuration: {e}")),
+        };
 
-    let shared_state = Arc::new(AppState {
-        s3_client,
-        s3_bucket: env.s3_bucket,
-        authz_conf,
-        base_url,
-        dl_limiter,
-    });
-    let app = Router::new()
-        .route("/batch", post(batch))
-        .route("/:oid0/:oid1/:oid", get(obj_download))
-        .with_state(shared_state);
+        let s3_client = match get_s3_client(&env) {
+            Ok(s3_client) => s3_client,
+            Err(e) => return Err(format!("failed to create S3 client: {e}")),
+        };
+        let key = match common::load_key(&env.key_path) {
+            Ok(key) => key,
+            Err(e) => return Err(format!("failed to load Gitolfs3 key: {e}")),
+        };
 
-    let middleware = axum::middleware::map_request(rewrite_url);
-    let app_with_middleware = middleware.layer(app);
+        let trusted_forwarded_hosts: HashSet<String> = env
+            .trusted_forwarded_hosts
+            .split(',')
+            .map(|s| s.to_owned())
+            .filter(|s| !s.is_empty())
+            .collect();
+        let base_url = env.base_url.trim_end_matches('/').to_string();
 
-    let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else {
-        println!(
-            "Configured GITOLFS3_LISTEN_PORT should be an unsigned integer no higher than 65535"
-        );
-        return ExitCode::from(2);
-    };
-    let addr: (String, u16) = (env.listen_host, listen_port);
-    let listener = match tokio::net::TcpListener::bind(addr).await {
-        Ok(listener) => listener,
-        Err(e) => {
-            println!("Failed to listen: {e}");
-            return ExitCode::FAILURE;
-        }
-    };
+        let Ok(listen_port): Result<u16, _> = env.listen_port.parse() else {
+            return Err("configured GITOLFS3_LISTEN_PORT is invalid".to_string());
+        };
+        let Ok(download_limit): Result<u64, _> = env.download_limit.parse() else {
+            return Err("configured GITOLFS3_DOWNLOAD_LIMIT is invalid".to_string());
+        };
 
-    match axum::serve(listener, app_with_middleware.into_make_service()).await {
-        Ok(_) => ExitCode::SUCCESS,
-        Err(e) => {
-            println!("Error serving: {e}");
-            ExitCode::FAILURE
-        }
+        Ok(Self {
+            listen_addr: (env.listen_host, listen_port),
+            base_url,
+            authz_conf: AuthorizationConfig {
+                key,
+                trusted_forwarded_hosts,
+            },
+            s3_client,
+            s3_bucket: env.s3_bucket,
+            download_limit,
+        })
     }
 }
 
@@ -479,7 +480,7 @@ async fn handle_upload_object(
     repo: &str,
     obj: &BatchRequestObject,
 ) -> Option<BatchResponseObject> {
-    let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
+    let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1]));
     let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
 
     match state
@@ -558,7 +559,7 @@ async fn handle_download_object(
     obj: &BatchRequestObject,
     trusted: bool,
 ) -> BatchResponseObject {
-    let (oid0, oid1) = (HexByte(obj.oid[0]), HexByte(obj.oid[1]));
+    let (oid0, oid1) = (common::HexByte(obj.oid[0]), common::HexByte(obj.oid[1]));
     let full_path = format!("{repo}/lfs/objects/{}/{}/{}", oid0, oid1, obj.oid);
 
     let result = match state
@@ -687,8 +688,8 @@ async fn handle_download_object(
 
     let upload_path = format!(
         "{repo}/info/lfs/objects/{}/{}/{}",
-        HexByte(obj.oid[0]),
-        HexByte(obj.oid[1]),
+        common::HexByte(obj.oid[0]),
+        common::HexByte(obj.oid[1]),
         obj.oid,
     );
 
@@ -866,8 +867,8 @@ async fn batch(
 #[derive(Deserialize, Copy, Clone)]
 #[serde(remote = "Self")]
 struct FileParams {
-    oid0: HexByte,
-    oid1: HexByte,
+    oid0: common::HexByte,
+    oid1: common::HexByte,
     oid: common::Oid,
 }
 
@@ -877,8 +878,8 @@ impl<'de> Deserialize<'de> for FileParams {
         D: serde::Deserializer<'de>,
     {
         let unchecked @ FileParams {
-            oid0: HexByte(oid0),
-            oid1: HexByte(oid1),
+            oid0: common::HexByte(oid0),
+            oid1: common::HexByte(oid1),
             oid,
         } = FileParams::deserialize(deserializer)?;
         if oid0 != oid.as_bytes()[0] {
diff --git a/shell/src/main.rs b/shell/src/main.rs
index 4901e7f..4a98828 100644
--- a/shell/src/main.rs
+++ b/shell/src/main.rs
@@ -1,84 +1,5 @@
 use std::{os::unix::process::CommandExt, process::ExitCode};
 
-fn parse_sq(s: &str) -> Option<(String, &str)> {
-    #[derive(PartialEq, Eq)]
-    enum SqState {
-        Quoted,
-        Unquoted { may_escape: bool },
-        UnquotedEscaped,
-    }
-
-    let mut result = String::new();
-    let mut state = SqState::Unquoted { may_escape: false };
-    let mut remaining = "";
-    for (i, c) in s.char_indices() {
-        match state {
-            SqState::Unquoted { may_escape: false } => {
-                if c != '\'' {
-                    return None;
-                }
-                state = SqState::Quoted
-            }
-            SqState::Quoted => {
-                if c == '\'' {
-                    state = SqState::Unquoted { may_escape: true };
-                    continue;
-                }
-                result.push(c);
-            }
-            SqState::Unquoted { may_escape: true } => {
-                if is_posix_space(c) {
-                    remaining = &s[i..];
-                    break;
-                }
-                if c != '\\' {
-                    return None;
-                }
-                state = SqState::UnquotedEscaped;
-            }
-            SqState::UnquotedEscaped => {
-                if c != '\\' && c != '!' {
-                    return None;
-                }
-                result.push(c);
-                state = SqState::Unquoted { may_escape: false };
-            }
-        }
-    }
-
-    if state != (SqState::Unquoted { may_escape: true }) {
-        return None;
-    }
-    Some((result, remaining))
-}
-
-fn parse_cmd(mut cmd: &str) -> Option<Vec<String>> {
-    let mut args = Vec::<String>::new();
-
-    cmd = cmd.trim_matches(is_posix_space);
-    while !cmd.is_empty() {
-        if cmd.starts_with('\'') {
-            let (arg, remaining) = parse_sq(cmd)?;
-            args.push(arg);
-            cmd = remaining.trim_start_matches(is_posix_space);
-        } else if let Some((arg, remaining)) = cmd.split_once(is_posix_space) {
-            args.push(arg.to_owned());
-            cmd = remaining.trim_start_matches(is_posix_space);
-        } else {
-            args.push(cmd.to_owned());
-            cmd = "";
-        }
-    }
-
-    Some(args)
-}
-
-fn is_posix_space(c: char) -> bool {
-    // Form feed: 0x0c
-    // Vertical tab: 0x0b
-    c == ' ' || c == '\x0c' || c == '\n' || c == '\r' || c == '\t' || c == '\x0b'
-}
-
 fn main() -> ExitCode {
     let bad_usage = ExitCode::from(2);
 
@@ -141,3 +62,82 @@ fn main() -> ExitCode {
     eprintln!("Error: {e}");
     ExitCode::FAILURE
 }
+
+fn parse_cmd(mut cmd: &str) -> Option<Vec<String>> {
+    let mut args = Vec::<String>::new();
+
+    cmd = cmd.trim_matches(is_posix_space);
+    while !cmd.is_empty() {
+        if cmd.starts_with('\'') {
+            let (arg, remaining) = parse_sq(cmd)?;
+            args.push(arg);
+            cmd = remaining.trim_start_matches(is_posix_space);
+        } else if let Some((arg, remaining)) = cmd.split_once(is_posix_space) {
+            args.push(arg.to_owned());
+            cmd = remaining.trim_start_matches(is_posix_space);
+        } else {
+            args.push(cmd.to_owned());
+            cmd = "";
+        }
+    }
+
+    Some(args)
+}
+
+fn is_posix_space(c: char) -> bool {
+    // Form feed: 0x0c
+    // Vertical tab: 0x0b
+    c == ' ' || c == '\x0c' || c == '\n' || c == '\r' || c == '\t' || c == '\x0b'
+}
+
+fn parse_sq(s: &str) -> Option<(String, &str)> {
+    #[derive(PartialEq, Eq)]
+    enum SqState {
+        Quoted,
+        Unquoted { may_escape: bool },
+        UnquotedEscaped,
+    }
+
+    let mut result = String::new();
+    let mut state = SqState::Unquoted { may_escape: false };
+    let mut remaining = "";
+    for (i, c) in s.char_indices() {
+        match state {
+            SqState::Unquoted { may_escape: false } => {
+                if c != '\'' {
+                    return None;
+                }
+                state = SqState::Quoted
+            }
+            SqState::Quoted => {
+                if c == '\'' {
+                    state = SqState::Unquoted { may_escape: true };
+                    continue;
+                }
+                result.push(c);
+            }
+            SqState::Unquoted { may_escape: true } => {
+                if is_posix_space(c) {
+                    remaining = &s[i..];
+                    break;
+                }
+                if c != '\\' {
+                    return None;
+                }
+                state = SqState::UnquotedEscaped;
+            }
+            SqState::UnquotedEscaped => {
+                if c != '\\' && c != '!' {
+                    return None;
+                }
+                result.push(c);
+                state = SqState::Unquoted { may_escape: false };
+            }
+        }
+    }
+
+    if state != (SqState::Unquoted { may_escape: true }) {
+        return None;
+    }
+    Some((result, remaining))
+}
-- 
cgit v1.2.3