use chrono::{DateTime, Utc};
use serde::{de, Deserialize, Serialize};
use std::{
    fmt::{self, Write},
    ops,
    str::FromStr,
};
use subtle::ConstantTimeEq;

/// Discriminant byte mixed into the HMAC input first, so that the two kinds of claims are told
/// apart by [`generate_tag`].
#[repr(u8)]
enum AuthType {
    BatchApi = 1,
    Download = 2,
}

/// The set of claims that [`generate_tag`] authenticates.
#[derive(Debug, Copy, Clone)]
pub struct Claims<'a> {
    /// Claims specific to the kind of request being authorized.
    pub specific_claims: SpecificClaims,
    /// Repository path the claims apply to; at most 100 bytes (see [`generate_tag`]).
    pub repo_path: &'a str,
    /// Expiry time; its Unix timestamp is part of the authenticated data.
    pub expires_at: DateTime<Utc>,
}

/// The request-kind-specific part of [`Claims`].
#[derive(Debug, Copy, Clone)]
pub enum SpecificClaims {
    /// A batch API request for the given operation.
    BatchApi(Operation),
    /// A download of the object with the given ID.
    Download(Oid),
}

/// Object ID: a 32-byte digest.
pub type Oid = Digest<32>;

/// Transfer operation; (de)serialized as `"download"` / `"upload"`.
#[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)]
#[repr(u8)]
pub enum Operation {
    #[serde(rename = "download")]
    Download = 1,
    #[serde(rename = "upload")]
    Upload = 2,
}

/// Computes the HMAC-SHA256 tag over `claims` using `key`.
///
/// Returns None if the claims are invalid. Repo path length may be no more than 100 bytes.
///
/// The authenticated byte string is, in order: the auth type byte, the operation byte (batch API)
/// or the raw object ID (download), the repo path length as a single byte, the repo path bytes,
/// and the expiry timestamp as big-endian bytes.
pub fn generate_tag(claims: Claims, key: impl AsRef<[u8]>) -> Option<Digest<32>> {
    // Upper bound on the repo path length. The length is hashed as a single byte, so the checked
    // conversion below doubles as the guard that it actually fits.
    const MAX_REPO_PATH_LEN: u8 = 100;

    let repo_path_len = u8::try_from(claims.repo_path.len())
        .ok()
        .filter(|&len| len <= MAX_REPO_PATH_LEN)?;

    let mut hmac = hmac_sha256::HMAC::new(key);
    match claims.specific_claims {
        SpecificClaims::BatchApi(operation) => {
            hmac.update([AuthType::BatchApi as u8]);
            hmac.update([operation as u8]);
        }
        SpecificClaims::Download(oid) => {
            hmac.update([AuthType::Download as u8]);
            hmac.update(oid.as_bytes());
        }
    }
    // The length prefix precedes the variable-length path in the hashed data.
    hmac.update([repo_path_len]);
    hmac.update(claims.repo_path.as_bytes());
    hmac.update(claims.expires_at.timestamp().to_be_bytes());
    Some(hmac.finalize().into())
}

/// Error returned when a string is neither `"download"` nor `"upload"`.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub struct ParseOperationError;

impl fmt::Display for ParseOperationError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("operation should be 'download' or 'upload'")
    }
}

impl FromStr for Operation {
    type Err = ParseOperationError;

    /// Parses the exact (case-sensitive) strings `"upload"` and `"download"`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "upload" => Ok(Self::Upload),
            "download" => Ok(Self::Download),
            _ => Err(ParseOperationError),
        }
    }
}

/// None means out of range.
fn decode_nibble(c: u8) -> Option { if c.is_ascii_digit() { Some(c - b'0') } else if (b'a'..=b'f').contains(&c) { Some(c - b'a' + 10) } else if (b'A'..=b'F').contains(&c) { Some(c - b'A' + 10) } else { None } } #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct HexByte(pub u8); impl<'de> Deserialize<'de> for HexByte { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let str = <&str>::deserialize(deserializer)?; let &[b1, b2] = str.as_bytes() else { return Err(de::Error::invalid_length( str.len(), &"two hexadecimal characters", )); }; let (Some(b1), Some(b2)) = (decode_nibble(b1), decode_nibble(b2)) else { return Err(de::Error::invalid_value( de::Unexpected::Str(str), &"two hexadecimal characters", )); }; Ok(HexByte((b1 << 4) | b2)) } } impl fmt::Display for HexByte { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let &HexByte(b) = self; HexFmt(&[b]).fmt(f) } } #[derive(Debug, PartialEq, Eq, Copy, Clone)] pub enum ParseHexError { UnevenNibbles, InvalidCharacter, TooShort, TooLong, } impl fmt::Display for ParseHexError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::UnevenNibbles => { write!(f, "uneven amount of nibbles (chars in range [a-zA-Z0-9])") } Self::InvalidCharacter => write!(f, "non-hex character encountered"), Self::TooShort => write!(f, "unexpected end of hex sequence"), Self::TooLong => write!(f, "longer hex sequence than expected"), } } } #[derive(Debug)] pub enum ReadHexError { Io(std::io::Error), Format(ParseHexError), } impl fmt::Display for ReadHexError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::Io(e) => e.fmt(f), Self::Format(e) => e.fmt(f), } } } fn parse_hex_exact(value: &str, buf: &mut [u8]) -> Result<(), ParseHexError> { if value.bytes().len() % 2 == 1 { return Err(ParseHexError::UnevenNibbles); } if value.bytes().len() < 2 * buf.len() { return Err(ParseHexError::TooShort); } if value.bytes().len() > 2 * buf.len() { 
return Err(ParseHexError::TooLong); } for (i, c) in value.bytes().enumerate() { if let Some(b) = decode_nibble(c) { if i % 2 == 0 { buf[i / 2] = b << 4; } else { buf[i / 2] |= b; } } else { return Err(ParseHexError::InvalidCharacter); } } Ok(()) } pub type Key = SafeByteArray<64>; pub fn load_key(path: &str) -> Result { let key_str = std::fs::read_to_string(path).map_err(ReadHexError::Io)?; key_str.trim().parse().map_err(ReadHexError::Format) } pub struct SafeByteArray { inner: [u8; N], } impl SafeByteArray { pub fn new() -> Self { Self { inner: [0; N] } } } impl Default for SafeByteArray { fn default() -> Self { Self::new() } } impl AsRef<[u8]> for SafeByteArray { fn as_ref(&self) -> &[u8] { &self.inner } } impl AsMut<[u8]> for SafeByteArray { fn as_mut(&mut self) -> &mut [u8] { &mut self.inner } } impl Drop for SafeByteArray { fn drop(&mut self) { self.inner.fill(0) } } impl FromStr for SafeByteArray { type Err = ParseHexError; fn from_str(value: &str) -> Result { let mut sba = Self { inner: [0u8; N] }; parse_hex_exact(value, &mut sba.inner)?; Ok(sba) } } pub struct HexFmt>(pub B); impl> fmt::Display for HexFmt { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let HexFmt(buf) = self; for b in buf.as_ref() { let (high, low) = (b >> 4, b & 0xF); let highc = if high < 10 { high + b'0' } else { high - 10 + b'a' }; let lowc = if low < 10 { low + b'0' } else { low - 10 + b'a' }; f.write_char(highc as char)?; f.write_char(lowc as char)?; } Ok(()) } } #[derive(Debug, Copy, Clone)] pub struct Digest { inner: [u8; N], } impl ops::Index for Digest { type Output = u8; fn index(&self, index: usize) -> &Self::Output { &self.inner[index] } } impl Digest { pub fn as_bytes(&self) -> &[u8; N] { &self.inner } } impl fmt::Display for Digest { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { HexFmt(&self.inner).fmt(f) } } impl Digest { pub fn new(data: [u8; N]) -> Self { Self { inner: data } } } impl From<[u8; N]> for Digest { fn from(value: [u8; N]) -> Self { 
Self::new(value) } } impl From> for [u8; N] { fn from(val: Digest) -> Self { val.inner } } impl FromStr for Digest { type Err = ParseHexError; fn from_str(value: &str) -> Result { let mut buf = [0u8; N]; parse_hex_exact(value, &mut buf)?; Ok(buf.into()) } } impl ConstantTimeEq for Digest { fn ct_eq(&self, other: &Self) -> subtle::Choice { self.inner.ct_eq(&other.inner) } } impl PartialEq for Digest { fn eq(&self, other: &Self) -> bool { self.ct_eq(other).into() } } impl Eq for Digest {} impl<'de, const N: usize> Deserialize<'de> for Digest { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let hex = <&str>::deserialize(deserializer)?; Digest::from_str(hex).map_err(de::Error::custom) } } impl Serialize for Digest { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { serializer.serialize_str(&format!("{self}")) } }