use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize, de};
use std::{
    fmt::{self, Write},
    ops,
    str::FromStr,
};
use subtle::ConstantTimeEq;

/// Discriminator byte that [`generate_tag`] feeds into the MAC before anything else, so
/// that the two kinds of claims produce distinct MAC inputs.
#[repr(u8)]
enum AuthType {
    BatchApi = 1,
    Download = 2,
}

/// The data authenticated by [`generate_tag`].
#[derive(Debug, Copy, Clone)]
pub struct Claims<'a> {
    /// What the tag grants; see [`SpecificClaims`].
    pub specific_claims: SpecificClaims,
    /// Repository path; [`generate_tag`] rejects paths longer than 100 bytes.
    pub repo_path: &'a str,
    /// Expiry instant; only its Unix timestamp (whole seconds) enters the tag.
    pub expires_at: DateTime<Utc>,
}

/// The claim-kind-specific part of [`Claims`].
#[derive(Debug, Copy, Clone)]
pub enum SpecificClaims {
    /// A claim carrying an [`Operation`].
    BatchApi(Operation),
    /// A claim carrying a single object ID.
    Download(Oid),
}

/// Object identifier: a 32-byte digest.
pub type Oid = Digest<32>;

/// Operation kind; (de)serialized as the lowercase strings `"download"` / `"upload"`.
///
/// The numeric discriminants are part of the MAC input built by [`generate_tag`], so
/// changing them changes every generated tag.
#[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)]
#[repr(u8)]
pub enum Operation {
    #[serde(rename = "download")]
    Download = 1,
    #[serde(rename = "upload")]
    Upload = 2,
}

/// Computes the HMAC-SHA256 tag over `claims` using `key`.
///
/// Returns `None` if the claims are invalid: the repo path may be no more than 100 bytes.
///
/// The MAC input is, in order:
/// 1. the claim kind byte, followed by the operation byte (batch API) or the raw object ID
///    bytes (download);
/// 2. the repo path length as one byte, then the repo path bytes;
/// 3. the expiry as a big-endian Unix timestamp.
pub fn generate_tag(claims: Claims, key: impl AsRef<[u8]>) -> Option<Digest<32>> {
    let Claims {
        specific_claims,
        repo_path,
        expires_at,
    } = claims;

    // The length prefix written below is a single byte; the 100-byte cap keeps it in range.
    let path = repo_path.as_bytes();
    if path.len() > 100 {
        return None;
    }

    let mut mac = hmac_sha256::HMAC::new(key);
    match specific_claims {
        SpecificClaims::BatchApi(operation) => {
            mac.update([AuthType::BatchApi as u8, operation as u8]);
        }
        SpecificClaims::Download(oid) => {
            mac.update([AuthType::Download as u8]);
            mac.update(oid.as_bytes());
        }
    }
    mac.update([path.len() as u8]);
    mac.update(path);
    mac.update(expires_at.timestamp().to_be_bytes());

    Some(Digest::from(mac.finalize()))
}

/// Error returned when a string is neither `"download"` nor `"upload"`.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub struct ParseOperationError;

impl fmt::Display for ParseOperationError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("operation should be 'download' or 'upload'")
    }
}

impl FromStr for Operation {
    type Err = ParseOperationError;

    /// Parses the exact (case-sensitive) strings `"download"` and `"upload"`.
    fn from_str(text: &str) -> Result<Self, Self::Err> {
        match text {
            "download" => Ok(Self::Download),
            "upload" => Ok(Self::Upload),
            _ => Err(ParseOperationError),
        }
    }
}

/// None means out of range.
fn decode_nibble(c: u8) -> Option<u8> { if c.is_ascii_digit() { Some(c - b'0') } else if (b'a'..=b'f').contains(&c) { Some(c - b'a' + 10) } else if (b'A'..=b'F').contains(&c) { Some(c - b'A' + 10) } else { None } } #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct HexByte(pub u8); impl<'de> Deserialize<'de> for HexByte { fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> where D: serde::Deserializer<'de>, { let str = <&str>::deserialize(deserializer)?; let &[b1, b2] = str.as_bytes() else { return Err(de::Error::invalid_length( str.len(), &"two hexadecimal characters", )); }; let (Some(b1), Some(b2)) = (decode_nibble(b1), decode_nibble(b2)) else { return Err(de::Error::invalid_value( de::Unexpected::Str(str), &"two hexadecimal characters", )); }; Ok(HexByte((b1 << 4) | b2)) } } impl fmt::Display for HexByte { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let &HexByte(b) = self; HexFmt(&[b]).fmt(f) } } #[derive(Debug, PartialEq, Eq, Copy, Clone)] pub enum ParseHexError { UnevenNibbles, InvalidCharacter, TooShort, TooLong, } impl fmt::Display for ParseHexError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::UnevenNibbles => { write!(f, "uneven amount of nibbles (chars in range [a-zA-Z0-9])") } Self::InvalidCharacter => write!(f, "non-hex character encountered"), Self::TooShort => write!(f, "unexpected end of hex sequence"), Self::TooLong => write!(f, "longer hex sequence than expected"), } } } #[derive(Debug)] pub enum ReadHexError { Io(std::io::Error), Format(ParseHexError), } impl fmt::Display for ReadHexError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::Io(e) => e.fmt(f), Self::Format(e) => e.fmt(f), } } } fn parse_hex_exact(value: &str, buf: &mut [u8]) -> Result<(), ParseHexError> { if value.len() % 2 == 1 { return Err(ParseHexError::UnevenNibbles); } if value.len() < 2 * buf.len() { return Err(ParseHexError::TooShort); } if value.len() > 2 * buf.len() { return 
Err(ParseHexError::TooLong); } for (i, c) in value.bytes().enumerate() { if let Some(b) = decode_nibble(c) { if i % 2 == 0 { buf[i / 2] = b << 4; } else { buf[i / 2] |= b; } } else { return Err(ParseHexError::InvalidCharacter); } } Ok(()) } pub type Key = SafeByteArray<64>; pub fn load_key(path: &str) -> Result<Key, ReadHexError> { let key_str = std::fs::read_to_string(path).map_err(ReadHexError::Io)?; key_str.trim().parse().map_err(ReadHexError::Format) } pub struct SafeByteArray<const N: usize> { inner: [u8; N], } impl<const N: usize> SafeByteArray<N> { pub fn new() -> Self { Self { inner: [0; N] } } } impl<const N: usize> Default for SafeByteArray<N> { fn default() -> Self { Self::new() } } impl<const N: usize> AsRef<[u8]> for SafeByteArray<N> { fn as_ref(&self) -> &[u8] { &self.inner } } impl<const N: usize> AsMut<[u8]> for SafeByteArray<N> { fn as_mut(&mut self) -> &mut [u8] { &mut self.inner } } impl<const N: usize> Drop for SafeByteArray<N> { fn drop(&mut self) { self.inner.fill(0) } } impl<const N: usize> FromStr for SafeByteArray<N> { type Err = ParseHexError; fn from_str(value: &str) -> Result<Self, Self::Err> { let mut sba = Self { inner: [0u8; N] }; parse_hex_exact(value, &mut sba.inner)?; Ok(sba) } } pub struct HexFmt<B: AsRef<[u8]>>(pub B); impl<B: AsRef<[u8]>> fmt::Display for HexFmt<B> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let HexFmt(buf) = self; for b in buf.as_ref() { let (high, low) = (b >> 4, b & 0xF); let highc = if high < 10 { high + b'0' } else { high - 10 + b'a' }; let lowc = if low < 10 { low + b'0' } else { low - 10 + b'a' }; f.write_char(highc as char)?; f.write_char(lowc as char)?; } Ok(()) } } #[derive(Debug, Copy, Clone)] pub struct Digest<const N: usize> { inner: [u8; N], } impl<const N: usize> ops::Index<usize> for Digest<N> { type Output = u8; fn index(&self, index: usize) -> &Self::Output { &self.inner[index] } } impl<const N: usize> Digest<N> { pub fn as_bytes(&self) -> &[u8; N] { &self.inner } } impl<const N: 
usize> fmt::Display for Digest<N> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { HexFmt(&self.inner).fmt(f) } } impl<const N: usize> Digest<N> { pub fn new(data: [u8; N]) -> Self { Self { inner: data } } } impl<const N: usize> From<[u8; N]> for Digest<N> { fn from(value: [u8; N]) -> Self { Self::new(value) } } impl<const N: usize> From<Digest<N>> for [u8; N] { fn from(val: Digest<N>) -> Self { val.inner } } impl<const N: usize> FromStr for Digest<N> { type Err = ParseHexError; fn from_str(value: &str) -> Result<Self, Self::Err> { let mut buf = [0u8; N]; parse_hex_exact(value, &mut buf)?; Ok(buf.into()) } } impl<const N: usize> ConstantTimeEq for Digest<N> { fn ct_eq(&self, other: &Self) -> subtle::Choice { self.inner.ct_eq(&other.inner) } } impl<const N: usize> PartialEq for Digest<N> { fn eq(&self, other: &Self) -> bool { self.ct_eq(other).into() } } impl<const N: usize> Eq for Digest<N> {} impl<'de, const N: usize> Deserialize<'de> for Digest<N> { fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> where D: serde::Deserializer<'de>, { let hex = <&str>::deserialize(deserializer)?; Digest::from_str(hex).map_err(de::Error::custom) } } impl<const N: usize> Serialize for Digest<N> { fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: serde::Serializer, { serializer.serialize_str(&format!("{self}")) } }