From a7f9c8de31231b9fd9c67c57db659f7b01f1a3b0 Mon Sep 17 00:00:00 2001 From: Rutger Broekhoff Date: Mon, 29 Apr 2024 19:18:56 +0200 Subject: Rename crates (and therefore commands) --- gitolfs3-common/src/lib.rs | 348 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 348 insertions(+) create mode 100644 gitolfs3-common/src/lib.rs (limited to 'gitolfs3-common/src') diff --git a/gitolfs3-common/src/lib.rs b/gitolfs3-common/src/lib.rs new file mode 100644 index 0000000..917f566 --- /dev/null +++ b/gitolfs3-common/src/lib.rs @@ -0,0 +1,348 @@ +use chrono::{DateTime, Utc}; +use serde::{de, Deserialize, Serialize}; +use std::{ + fmt::{self, Write}, + ops, + str::FromStr, +}; +use subtle::ConstantTimeEq; + +#[repr(u8)] +enum AuthType { + BatchApi = 1, + Download = 2, +} + +#[derive(Debug, Copy, Clone)] +pub struct Claims<'a> { + pub specific_claims: SpecificClaims, + pub repo_path: &'a str, + pub expires_at: DateTime, +} + +#[derive(Debug, Copy, Clone)] +pub enum SpecificClaims { + BatchApi(Operation), + Download(Oid), +} + +pub type Oid = Digest<32>; + +#[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)] +#[repr(u8)] +pub enum Operation { + #[serde(rename = "download")] + Download = 1, + #[serde(rename = "upload")] + Upload = 2, +} + +/// Returns None if the claims are invalid. Repo path length may be no more than 100 bytes. +pub fn generate_tag(claims: Claims, key: impl AsRef<[u8]>) -> Option> { + if claims.repo_path.len() > 100 { + return None; + } + + let mut hmac = hmac_sha256::HMAC::new(key); + match claims.specific_claims { + SpecificClaims::BatchApi(operation) => { + hmac.update([AuthType::BatchApi as u8]); + hmac.update([operation as u8]); + } + SpecificClaims::Download(oid) => { + hmac.update([AuthType::Download as u8]); + hmac.update(oid.as_bytes()); + } + } + hmac.update([claims.repo_path.len() as u8]); + hmac.update(claims.repo_path.as_bytes()); + hmac.update(claims.expires_at.timestamp().to_be_bytes()); + Some(hmac.finalize().into()) +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub struct ParseOperationError; + +impl fmt::Display for ParseOperationError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "operation should be 'download' or 'upload'") + } +} + +impl FromStr for Operation { + type Err = ParseOperationError; + + fn from_str(s: &str) -> Result { + match s { + "upload" => Ok(Self::Upload), + "download" => Ok(Self::Download), + _ => Err(ParseOperationError), + } + } +} + +/// None means out of range. +fn decode_nibble(c: u8) -> Option { + if c.is_ascii_digit() { + Some(c - b'0') + } else if (b'a'..=b'f').contains(&c) { + Some(c - b'a' + 10) + } else if (b'A'..=b'F').contains(&c) { + Some(c - b'A' + 10) + } else { + None + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct HexByte(pub u8); + +impl<'de> Deserialize<'de> for HexByte { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let str = <&str>::deserialize(deserializer)?; + let &[b1, b2] = str.as_bytes() else { + return Err(de::Error::invalid_length( + str.len(), + &"two hexadecimal characters", + )); + }; + let (Some(b1), Some(b2)) = (decode_nibble(b1), decode_nibble(b2)) else { + return Err(de::Error::invalid_value( + de::Unexpected::Str(str), + &"two hexadecimal characters", + )); + }; + Ok(HexByte((b1 << 4) | b2)) + } +} + +impl fmt::Display for HexByte { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let &HexByte(b) = self; + HexFmt(&[b]).fmt(f) + } +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum ParseHexError { + UnevenNibbles, + InvalidCharacter, + TooShort, + TooLong, +} + +impl fmt::Display for ParseHexError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::UnevenNibbles => { + write!(f, "uneven amount of nibbles (chars in range [a-zA-Z0-9])") + } + Self::InvalidCharacter => write!(f, "non-hex character encountered"), + Self::TooShort => write!(f, "unexpected end of hex sequence"), + Self::TooLong => write!(f, "longer hex sequence than expected"), + } + } +} + +#[derive(Debug)] +pub enum ReadHexError { + Io(std::io::Error), + Format(ParseHexError), +} + +impl fmt::Display for ReadHexError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Io(e) => e.fmt(f), + Self::Format(e) => e.fmt(f), + } + } +} + +fn parse_hex_exact(value: &str, buf: &mut [u8]) -> Result<(), ParseHexError> { + if value.bytes().len() % 2 == 1 { + return Err(ParseHexError::UnevenNibbles); + } + if value.bytes().len() < 2 * buf.len() { + return Err(ParseHexError::TooShort); + } + if value.bytes().len() > 2 * buf.len() { + return Err(ParseHexError::TooLong); + } + for (i, c) in value.bytes().enumerate() { + if let Some(b) = decode_nibble(c) { + if i % 2 == 0 { + buf[i / 2] = b << 4; + } else { + buf[i / 2] |= b; + } + } else { + return Err(ParseHexError::InvalidCharacter); + } + } + Ok(()) +} + +pub type Key = SafeByteArray<64>; + +pub fn load_key(path: &str) -> Result { + let key_str = std::fs::read_to_string(path).map_err(ReadHexError::Io)?; + key_str.trim().parse().map_err(ReadHexError::Format) +} + +pub struct SafeByteArray { + inner: [u8; N], +} + +impl SafeByteArray { + pub fn new() -> Self { + Self { inner: [0; N] } + } +} + +impl Default for SafeByteArray { + fn default() -> Self { + Self::new() + } +} + +impl AsRef<[u8]> for SafeByteArray { + fn as_ref(&self) -> &[u8] { + &self.inner + } +} + +impl AsMut<[u8]> for SafeByteArray { + fn as_mut(&mut self) -> &mut [u8] { + &mut self.inner + } +} + +impl Drop for SafeByteArray { + fn drop(&mut self) { + self.inner.fill(0) + } +} + +impl FromStr for SafeByteArray { + type Err = ParseHexError; + + fn from_str(value: &str) -> Result { + let mut sba = Self { inner: [0u8; N] }; + parse_hex_exact(value, &mut sba.inner)?; + Ok(sba) + } +} + +pub struct HexFmt>(pub B); + +impl> fmt::Display for HexFmt { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let HexFmt(buf) = self; + for b in buf.as_ref() { + let (high, low) = (b >> 4, b & 0xF); + let highc = if high < 10 { + high + b'0' + } else { + high - 10 + b'a' + }; + let lowc = if low < 10 { + low + b'0' + } else { + low - 10 + b'a' + }; + f.write_char(highc as char)?; + f.write_char(lowc as char)?; + } + Ok(()) + } +} + +#[derive(Debug, Copy, Clone)] +pub struct Digest { + inner: [u8; N], +} + +impl ops::Index for Digest { + type Output = u8; + + fn index(&self, index: usize) -> &Self::Output { + &self.inner[index] + } +} + +impl Digest { + pub fn as_bytes(&self) -> &[u8; N] { + &self.inner + } +} + +impl fmt::Display for Digest { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + HexFmt(&self.inner).fmt(f) + } +} + +impl Digest { + pub fn new(data: [u8; N]) -> Self { + Self { inner: data } + } +} + +impl From<[u8; N]> for Digest { + fn from(value: [u8; N]) -> Self { + Self::new(value) + } +} + +impl From> for [u8; N] { + fn from(val: Digest) -> Self { + val.inner + } +} + +impl FromStr for Digest { + type Err = ParseHexError; + + fn from_str(value: &str) -> Result { + let mut buf = [0u8; N]; + parse_hex_exact(value, &mut buf)?; + Ok(buf.into()) + } +} + +impl ConstantTimeEq for Digest { + fn ct_eq(&self, other: &Self) -> subtle::Choice { + self.inner.ct_eq(&other.inner) + } +} + +impl PartialEq for Digest { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Digest {} + +impl<'de, const N: usize> Deserialize<'de> for Digest { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let hex = <&str>::deserialize(deserializer)?; + Digest::from_str(hex).map_err(de::Error::custom) + } +} + +impl Serialize for Digest { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&format!("{self}")) + } +} -- cgit v1.2.3