diff options
Diffstat (limited to 'rs/common')
| -rw-r--r-- | rs/common/Cargo.toml | 10 | ||||
| -rw-r--r-- | rs/common/src/lib.rs | 368 |
2 files changed, 0 insertions, 378 deletions
diff --git a/rs/common/Cargo.toml b/rs/common/Cargo.toml deleted file mode 100644 index 20d9bdd..0000000 --- a/rs/common/Cargo.toml +++ /dev/null | |||
| @@ -1,10 +0,0 @@ | |||
| 1 | [package] | ||
| 2 | name = "common" | ||
| 3 | version = "0.1.0" | ||
| 4 | edition = "2021" | ||
| 5 | |||
| 6 | [dependencies] | ||
| 7 | chrono = "0.4" | ||
| 8 | hmac-sha256 = "1.1" | ||
| 9 | subtle = "2.5" | ||
| 10 | serde = { version = "1", features = ["derive"] } | ||
diff --git a/rs/common/src/lib.rs b/rs/common/src/lib.rs deleted file mode 100644 index 995352d..0000000 --- a/rs/common/src/lib.rs +++ /dev/null | |||
| @@ -1,368 +0,0 @@ | |||
| 1 | use chrono::{DateTime, Utc}; | ||
| 2 | use serde::de; | ||
| 3 | use serde::{Deserialize, Serialize}; | ||
| 4 | use std::fmt::Write; | ||
| 5 | use std::ops; | ||
| 6 | use std::{fmt, str::FromStr}; | ||
| 7 | use subtle::ConstantTimeEq; | ||
| 8 | |||
| 9 | #[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)] | ||
| 10 | #[repr(u8)] | ||
| 11 | pub enum Operation { | ||
| 12 | #[serde(rename = "download")] | ||
| 13 | Download = 1, | ||
| 14 | #[serde(rename = "upload")] | ||
| 15 | Upload = 2, | ||
| 16 | } | ||
| 17 | |||
| 18 | #[derive(Debug, PartialEq, Eq, Copy, Clone)] | ||
| 19 | pub struct ParseOperationError; | ||
| 20 | |||
| 21 | impl fmt::Display for ParseOperationError { | ||
| 22 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
| 23 | write!(f, "operation should be 'download' or 'upload'") | ||
| 24 | } | ||
| 25 | } | ||
| 26 | |||
| 27 | impl FromStr for Operation { | ||
| 28 | type Err = ParseOperationError; | ||
| 29 | |||
| 30 | fn from_str(s: &str) -> Result<Self, Self::Err> { | ||
| 31 | match s { | ||
| 32 | "upload" => Ok(Self::Upload), | ||
| 33 | "download" => Ok(Self::Download), | ||
| 34 | _ => Err(ParseOperationError), | ||
| 35 | } | ||
| 36 | } | ||
| 37 | } | ||
| 38 | |||
| 39 | #[repr(u8)] | ||
| 40 | enum AuthType { | ||
| 41 | BatchApi = 1, | ||
| 42 | Download = 2, | ||
| 43 | } | ||
| 44 | |||
| 45 | /// None means out of range. | ||
| 46 | fn decode_nibble(c: u8) -> Option<u8> { | ||
| 47 | if c.is_ascii_digit() { | ||
| 48 | Some(c - b'0') | ||
| 49 | } else if (b'a'..=b'f').contains(&c) { | ||
| 50 | Some(c - b'a' + 10) | ||
| 51 | } else if (b'A'..=b'F').contains(&c) { | ||
| 52 | Some(c - b'A' + 10) | ||
| 53 | } else { | ||
| 54 | None | ||
| 55 | } | ||
| 56 | } | ||
| 57 | |||
| 58 | #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] | ||
| 59 | pub struct HexByte(pub u8); | ||
| 60 | |||
| 61 | impl<'de> Deserialize<'de> for HexByte { | ||
| 62 | fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||
| 63 | where | ||
| 64 | D: serde::Deserializer<'de>, | ||
| 65 | { | ||
| 66 | let str = <&str>::deserialize(deserializer)?; | ||
| 67 | let &[b1, b2] = str.as_bytes() else { | ||
| 68 | return Err(de::Error::invalid_length( | ||
| 69 | str.len(), | ||
| 70 | &"two hexadecimal characters", | ||
| 71 | )); | ||
| 72 | }; | ||
| 73 | let (Some(b1), Some(b2)) = (decode_nibble(b1), decode_nibble(b2)) else { | ||
| 74 | return Err(de::Error::invalid_value( | ||
| 75 | de::Unexpected::Str(str), | ||
| 76 | &"two hexadecimal characters", | ||
| 77 | )); | ||
| 78 | }; | ||
| 79 | Ok(HexByte((b1 << 4) | b2)) | ||
| 80 | } | ||
| 81 | } | ||
| 82 | |||
| 83 | impl fmt::Display for HexByte { | ||
| 84 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
| 85 | let &HexByte(b) = self; | ||
| 86 | HexFmt(&[b]).fmt(f) | ||
| 87 | } | ||
| 88 | } | ||
| 89 | |||
| 90 | #[derive(Debug, PartialEq, Eq, Copy, Clone)] | ||
| 91 | pub enum ParseHexError { | ||
| 92 | UnevenNibbles, | ||
| 93 | InvalidCharacter, | ||
| 94 | TooShort, | ||
| 95 | TooLong, | ||
| 96 | } | ||
| 97 | |||
| 98 | impl fmt::Display for ParseHexError { | ||
| 99 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
| 100 | match self { | ||
| 101 | Self::UnevenNibbles => { | ||
| 102 | write!(f, "uneven amount of nibbles (chars in range [a-zA-Z0-9])") | ||
| 103 | } | ||
| 104 | Self::InvalidCharacter => write!(f, "non-hex character encountered"), | ||
| 105 | Self::TooShort => write!(f, "unexpected end of hex sequence"), | ||
| 106 | Self::TooLong => write!(f, "longer hex sequence than expected"), | ||
| 107 | } | ||
| 108 | } | ||
| 109 | } | ||
| 110 | |||
| 111 | #[derive(Debug)] | ||
| 112 | pub enum ReadHexError { | ||
| 113 | Io(std::io::Error), | ||
| 114 | Format(ParseHexError), | ||
| 115 | } | ||
| 116 | |||
| 117 | impl fmt::Display for ReadHexError { | ||
| 118 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
| 119 | match self { | ||
| 120 | Self::Io(e) => e.fmt(f), | ||
| 121 | Self::Format(e) => e.fmt(f), | ||
| 122 | } | ||
| 123 | } | ||
| 124 | } | ||
| 125 | |||
| 126 | fn parse_hex_exact(value: &str, buf: &mut [u8]) -> Result<(), ParseHexError> { | ||
| 127 | if value.bytes().len() % 2 == 1 { | ||
| 128 | return Err(ParseHexError::UnevenNibbles); | ||
| 129 | } | ||
| 130 | if value.bytes().len() < 2 * buf.len() { | ||
| 131 | return Err(ParseHexError::TooShort); | ||
| 132 | } | ||
| 133 | if value.bytes().len() > 2 * buf.len() { | ||
| 134 | return Err(ParseHexError::TooLong); | ||
| 135 | } | ||
| 136 | for (i, c) in value.bytes().enumerate() { | ||
| 137 | if let Some(b) = decode_nibble(c) { | ||
| 138 | if i % 2 == 0 { | ||
| 139 | buf[i / 2] = b << 4; | ||
| 140 | } else { | ||
| 141 | buf[i / 2] |= b; | ||
| 142 | } | ||
| 143 | } else { | ||
| 144 | return Err(ParseHexError::InvalidCharacter); | ||
| 145 | } | ||
| 146 | } | ||
| 147 | Ok(()) | ||
| 148 | } | ||
| 149 | |||
| 150 | pub struct SafeByteArray<const N: usize> { | ||
| 151 | inner: [u8; N], | ||
| 152 | } | ||
| 153 | |||
| 154 | impl<const N: usize> SafeByteArray<N> { | ||
| 155 | pub fn new() -> Self { | ||
| 156 | Self { inner: [0; N] } | ||
| 157 | } | ||
| 158 | } | ||
| 159 | |||
| 160 | impl<const N: usize> Default for SafeByteArray<N> { | ||
| 161 | fn default() -> Self { | ||
| 162 | Self::new() | ||
| 163 | } | ||
| 164 | } | ||
| 165 | |||
| 166 | impl<const N: usize> AsRef<[u8]> for SafeByteArray<N> { | ||
| 167 | fn as_ref(&self) -> &[u8] { | ||
| 168 | &self.inner | ||
| 169 | } | ||
| 170 | } | ||
| 171 | |||
| 172 | impl<const N: usize> AsMut<[u8]> for SafeByteArray<N> { | ||
| 173 | fn as_mut(&mut self) -> &mut [u8] { | ||
| 174 | &mut self.inner | ||
| 175 | } | ||
| 176 | } | ||
| 177 | |||
| 178 | impl<const N: usize> Drop for SafeByteArray<N> { | ||
| 179 | fn drop(&mut self) { | ||
| 180 | self.inner.fill(0) | ||
| 181 | } | ||
| 182 | } | ||
| 183 | |||
| 184 | impl<const N: usize> FromStr for SafeByteArray<N> { | ||
| 185 | type Err = ParseHexError; | ||
| 186 | |||
| 187 | fn from_str(value: &str) -> Result<Self, Self::Err> { | ||
| 188 | let mut sba = Self { inner: [0u8; N] }; | ||
| 189 | parse_hex_exact(value, &mut sba.inner)?; | ||
| 190 | Ok(sba) | ||
| 191 | } | ||
| 192 | } | ||
| 193 | |||
| 194 | pub type Oid = Digest<32>; | ||
| 195 | |||
| 196 | #[derive(Debug, Copy, Clone)] | ||
| 197 | pub enum SpecificClaims { | ||
| 198 | BatchApi(Operation), | ||
| 199 | Download(Oid), | ||
| 200 | } | ||
| 201 | |||
| 202 | #[derive(Debug, Copy, Clone)] | ||
| 203 | pub struct Claims<'a> { | ||
| 204 | pub specific_claims: SpecificClaims, | ||
| 205 | pub repo_path: &'a str, | ||
| 206 | pub expires_at: DateTime<Utc>, | ||
| 207 | } | ||
| 208 | |||
| 209 | /// Returns None if the claims are invalid. Repo path length may be no more than 100 bytes. | ||
| 210 | pub fn generate_tag(claims: Claims, key: impl AsRef<[u8]>) -> Option<Digest<32>> { | ||
| 211 | if claims.repo_path.len() > 100 { | ||
| 212 | return None; | ||
| 213 | } | ||
| 214 | |||
| 215 | let mut hmac = hmac_sha256::HMAC::new(key); | ||
| 216 | match claims.specific_claims { | ||
| 217 | SpecificClaims::BatchApi(operation) => { | ||
| 218 | hmac.update([AuthType::BatchApi as u8]); | ||
| 219 | hmac.update([operation as u8]); | ||
| 220 | } | ||
| 221 | SpecificClaims::Download(oid) => { | ||
| 222 | hmac.update([AuthType::Download as u8]); | ||
| 223 | hmac.update(oid.as_bytes()); | ||
| 224 | } | ||
| 225 | } | ||
| 226 | hmac.update([claims.repo_path.len() as u8]); | ||
| 227 | hmac.update(claims.repo_path.as_bytes()); | ||
| 228 | hmac.update(claims.expires_at.timestamp().to_be_bytes()); | ||
| 229 | Some(hmac.finalize().into()) | ||
| 230 | } | ||
| 231 | |||
| 232 | pub struct HexFmt<B: AsRef<[u8]>>(pub B); | ||
| 233 | |||
| 234 | impl<B: AsRef<[u8]>> fmt::Display for HexFmt<B> { | ||
| 235 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
| 236 | let HexFmt(buf) = self; | ||
| 237 | for b in buf.as_ref() { | ||
| 238 | let (high, low) = (b >> 4, b & 0xF); | ||
| 239 | let highc = if high < 10 { | ||
| 240 | high + b'0' | ||
| 241 | } else { | ||
| 242 | high - 10 + b'a' | ||
| 243 | }; | ||
| 244 | let lowc = if low < 10 { | ||
| 245 | low + b'0' | ||
| 246 | } else { | ||
| 247 | low - 10 + b'a' | ||
| 248 | }; | ||
| 249 | f.write_char(highc as char)?; | ||
| 250 | f.write_char(lowc as char)?; | ||
| 251 | } | ||
| 252 | Ok(()) | ||
| 253 | } | ||
| 254 | } | ||
| 255 | |||
| 256 | pub struct EscJsonFmt<'a>(pub &'a str); | ||
| 257 | |||
| 258 | impl<'a> fmt::Display for EscJsonFmt<'a> { | ||
| 259 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
| 260 | let EscJsonFmt(buf) = self; | ||
| 261 | for c in buf.chars() { | ||
| 262 | match c { | ||
| 263 | '"' => f.write_str("\\\"")?, // quote | ||
| 264 | '\\' => f.write_str("\\\\")?, // backslash | ||
| 265 | '\x08' => f.write_str("\\b")?, // backspace | ||
| 266 | '\x0C' => f.write_str("\\f")?, // form feed | ||
| 267 | '\n' => f.write_str("\\n")?, // line feed | ||
| 268 | '\r' => f.write_str("\\r")?, // carriage return | ||
| 269 | '\t' => f.write_str("\\t")?, // horizontal tab | ||
| 270 | _ => f.write_char(c)?, | ||
| 271 | }; | ||
| 272 | } | ||
| 273 | Ok(()) | ||
| 274 | } | ||
| 275 | } | ||
| 276 | |||
| 277 | #[derive(Debug, Copy, Clone)] | ||
| 278 | pub struct Digest<const N: usize> { | ||
| 279 | inner: [u8; N], | ||
| 280 | } | ||
| 281 | |||
| 282 | impl<const N: usize> ops::Index<usize> for Digest<N> { | ||
| 283 | type Output = u8; | ||
| 284 | |||
| 285 | fn index(&self, index: usize) -> &Self::Output { | ||
| 286 | &self.inner[index] | ||
| 287 | } | ||
| 288 | } | ||
| 289 | |||
| 290 | impl<const N: usize> Digest<N> { | ||
| 291 | pub fn as_bytes(&self) -> &[u8; N] { | ||
| 292 | &self.inner | ||
| 293 | } | ||
| 294 | } | ||
| 295 | |||
| 296 | impl<const N: usize> fmt::Display for Digest<N> { | ||
| 297 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
| 298 | HexFmt(&self.inner).fmt(f) | ||
| 299 | } | ||
| 300 | } | ||
| 301 | |||
| 302 | impl<const N: usize> Digest<N> { | ||
| 303 | pub fn new(data: [u8; N]) -> Self { | ||
| 304 | Self { inner: data } | ||
| 305 | } | ||
| 306 | } | ||
| 307 | |||
| 308 | impl<const N: usize> From<[u8; N]> for Digest<N> { | ||
| 309 | fn from(value: [u8; N]) -> Self { | ||
| 310 | Self::new(value) | ||
| 311 | } | ||
| 312 | } | ||
| 313 | |||
| 314 | impl<const N: usize> From<Digest<N>> for [u8; N] { | ||
| 315 | fn from(val: Digest<N>) -> Self { | ||
| 316 | val.inner | ||
| 317 | } | ||
| 318 | } | ||
| 319 | |||
| 320 | impl<const N: usize> FromStr for Digest<N> { | ||
| 321 | type Err = ParseHexError; | ||
| 322 | |||
| 323 | fn from_str(value: &str) -> Result<Self, Self::Err> { | ||
| 324 | let mut buf = [0u8; N]; | ||
| 325 | parse_hex_exact(value, &mut buf)?; | ||
| 326 | Ok(buf.into()) | ||
| 327 | } | ||
| 328 | } | ||
| 329 | |||
| 330 | impl<const N: usize> ConstantTimeEq for Digest<N> { | ||
| 331 | fn ct_eq(&self, other: &Self) -> subtle::Choice { | ||
| 332 | self.inner.ct_eq(&other.inner) | ||
| 333 | } | ||
| 334 | } | ||
| 335 | |||
| 336 | impl<const N: usize> PartialEq for Digest<N> { | ||
| 337 | fn eq(&self, other: &Self) -> bool { | ||
| 338 | self.ct_eq(other).into() | ||
| 339 | } | ||
| 340 | } | ||
| 341 | |||
| 342 | impl<const N: usize> Eq for Digest<N> {} | ||
| 343 | |||
| 344 | impl<'de, const N: usize> Deserialize<'de> for Digest<N> { | ||
| 345 | fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||
| 346 | where | ||
| 347 | D: serde::Deserializer<'de>, | ||
| 348 | { | ||
| 349 | let hex = <&str>::deserialize(deserializer)?; | ||
| 350 | Digest::from_str(hex).map_err(de::Error::custom) | ||
| 351 | } | ||
| 352 | } | ||
| 353 | |||
| 354 | impl<const N: usize> Serialize for Digest<N> { | ||
| 355 | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | ||
| 356 | where | ||
| 357 | S: serde::Serializer, | ||
| 358 | { | ||
| 359 | serializer.serialize_str(&format!("{self}")) | ||
| 360 | } | ||
| 361 | } | ||
| 362 | |||
| 363 | pub type Key = SafeByteArray<64>; | ||
| 364 | |||
| 365 | pub fn load_key(path: &str) -> Result<Key, ReadHexError> { | ||
| 366 | let key_str = std::fs::read_to_string(path).map_err(ReadHexError::Io)?; | ||
| 367 | key_str.trim().parse().map_err(ReadHexError::Format) | ||
| 368 | } | ||