diff options
Diffstat (limited to 'rs/common')
| -rw-r--r-- | rs/common/Cargo.toml | 10 | ||||
| -rw-r--r-- | rs/common/src/lib.rs | 337 |
2 files changed, 347 insertions, 0 deletions
diff --git a/rs/common/Cargo.toml b/rs/common/Cargo.toml new file mode 100644 index 0000000..20d9bdd --- /dev/null +++ b/rs/common/Cargo.toml | |||
| @@ -0,0 +1,10 @@ | |||
| 1 | [package] | ||
| 2 | name = "common" | ||
| 3 | version = "0.1.0" | ||
| 4 | edition = "2021" | ||
| 5 | |||
| 6 | [dependencies] | ||
| 7 | chrono = "0.4" | ||
| 8 | hmac-sha256 = "1.1" | ||
| 9 | subtle = "2.5" | ||
| 10 | serde = { version = "1", features = ["derive"] } | ||
diff --git a/rs/common/src/lib.rs b/rs/common/src/lib.rs new file mode 100644 index 0000000..aafe7f1 --- /dev/null +++ b/rs/common/src/lib.rs | |||
| @@ -0,0 +1,337 @@ | |||
| 1 | use chrono::{DateTime, Utc}; | ||
| 2 | use serde::de; | ||
| 3 | use serde::{Deserialize, Serialize}; | ||
| 4 | use std::fmt::Write; | ||
| 5 | use std::ops; | ||
| 6 | use std::{fmt, str::FromStr}; | ||
| 7 | use subtle::ConstantTimeEq; | ||
| 8 | |||
| 9 | #[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)] | ||
| 10 | #[repr(u8)] | ||
| 11 | pub enum Operation { | ||
| 12 | #[serde(rename = "download")] | ||
| 13 | Download = 1, | ||
| 14 | #[serde(rename = "upload")] | ||
| 15 | Upload = 2, | ||
| 16 | } | ||
| 17 | |||
| 18 | #[derive(Debug, PartialEq, Eq, Copy, Clone)] | ||
| 19 | pub struct ParseOperationError; | ||
| 20 | |||
| 21 | impl fmt::Display for ParseOperationError { | ||
| 22 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
| 23 | write!(f, "operation should be 'download' or 'upload'") | ||
| 24 | } | ||
| 25 | } | ||
| 26 | |||
| 27 | impl FromStr for Operation { | ||
| 28 | type Err = ParseOperationError; | ||
| 29 | |||
| 30 | fn from_str(s: &str) -> Result<Self, Self::Err> { | ||
| 31 | match s { | ||
| 32 | "upload" => Ok(Self::Upload), | ||
| 33 | "download" => Ok(Self::Download), | ||
| 34 | _ => Err(ParseOperationError), | ||
| 35 | } | ||
| 36 | } | ||
| 37 | } | ||
| 38 | |||
| 39 | #[repr(u8)] | ||
| 40 | pub enum AuthType { | ||
| 41 | GitLfsAuthenticate = 1, | ||
| 42 | } | ||
| 43 | |||
| 44 | /// None means out of range. | ||
| 45 | fn decode_nibble(c: u8) -> Option<u8> { | ||
| 46 | if c.is_ascii_digit() { | ||
| 47 | Some(c - b'0') | ||
| 48 | } else if (b'a'..=b'f').contains(&c) { | ||
| 49 | Some(c - b'a' + 10) | ||
| 50 | } else if (b'A'..=b'F').contains(&c) { | ||
| 51 | Some(c - b'A' + 10) | ||
| 52 | } else { | ||
| 53 | None | ||
| 54 | } | ||
| 55 | } | ||
| 56 | |||
| 57 | #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] | ||
| 58 | pub struct HexByte(pub u8); | ||
| 59 | |||
| 60 | impl<'de> Deserialize<'de> for HexByte { | ||
| 61 | fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||
| 62 | where | ||
| 63 | D: serde::Deserializer<'de>, | ||
| 64 | { | ||
| 65 | let str = <&str>::deserialize(deserializer)?; | ||
| 66 | let &[b1, b2] = str.as_bytes() else { | ||
| 67 | return Err(de::Error::invalid_length( | ||
| 68 | str.len(), | ||
| 69 | &"two hexadecimal characters", | ||
| 70 | )); | ||
| 71 | }; | ||
| 72 | let (Some(b1), Some(b2)) = (decode_nibble(b1), decode_nibble(b2)) else { | ||
| 73 | return Err(de::Error::invalid_value( | ||
| 74 | de::Unexpected::Str(str), | ||
| 75 | &"two hexadecimal characters", | ||
| 76 | )); | ||
| 77 | }; | ||
| 78 | Ok(HexByte((b1 << 4) | b2)) | ||
| 79 | } | ||
| 80 | } | ||
| 81 | |||
| 82 | impl fmt::Display for HexByte { | ||
| 83 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
| 84 | let &HexByte(b) = self; | ||
| 85 | HexFmt(&[b]).fmt(f) | ||
| 86 | } | ||
| 87 | } | ||
| 88 | |||
| 89 | #[derive(Debug, PartialEq, Eq, Copy, Clone)] | ||
| 90 | pub enum ParseHexError { | ||
| 91 | UnevenNibbles, | ||
| 92 | InvalidCharacter, | ||
| 93 | TooShort, | ||
| 94 | TooLong, | ||
| 95 | } | ||
| 96 | |||
| 97 | impl fmt::Display for ParseHexError { | ||
| 98 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
| 99 | match self { | ||
| 100 | Self::UnevenNibbles => { | ||
| 101 | write!(f, "uneven amount of nibbles (chars in range [a-zA-Z0-9])") | ||
| 102 | } | ||
| 103 | Self::InvalidCharacter => write!(f, "non-hex character encountered"), | ||
| 104 | Self::TooShort => write!(f, "unexpected end of hex sequence"), | ||
| 105 | Self::TooLong => write!(f, "longer hex sequence than expected"), | ||
| 106 | } | ||
| 107 | } | ||
| 108 | } | ||
| 109 | |||
| 110 | #[derive(Debug)] | ||
| 111 | pub enum ReadHexError { | ||
| 112 | Io(std::io::Error), | ||
| 113 | Format(ParseHexError), | ||
| 114 | } | ||
| 115 | |||
| 116 | impl fmt::Display for ReadHexError { | ||
| 117 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
| 118 | match self { | ||
| 119 | Self::Io(e) => e.fmt(f), | ||
| 120 | Self::Format(e) => e.fmt(f), | ||
| 121 | } | ||
| 122 | } | ||
| 123 | } | ||
| 124 | |||
| 125 | fn parse_hex_exact(value: &str, buf: &mut [u8]) -> Result<(), ParseHexError> { | ||
| 126 | if value.bytes().len() % 2 == 1 { | ||
| 127 | return Err(ParseHexError::UnevenNibbles); | ||
| 128 | } | ||
| 129 | if value.bytes().len() < 2 * buf.len() { | ||
| 130 | return Err(ParseHexError::TooShort); | ||
| 131 | } | ||
| 132 | if value.bytes().len() > 2 * buf.len() { | ||
| 133 | return Err(ParseHexError::TooLong); | ||
| 134 | } | ||
| 135 | for (i, c) in value.bytes().enumerate() { | ||
| 136 | if let Some(b) = decode_nibble(c) { | ||
| 137 | if i % 2 == 0 { | ||
| 138 | buf[i / 2] |= b; | ||
| 139 | } else { | ||
| 140 | buf[i / 2] = b << 4; | ||
| 141 | } | ||
| 142 | } else { | ||
| 143 | return Err(ParseHexError::InvalidCharacter); | ||
| 144 | } | ||
| 145 | } | ||
| 146 | Ok(()) | ||
| 147 | } | ||
| 148 | |||
| 149 | pub struct SafeByteArray<const N: usize> { | ||
| 150 | inner: [u8; N], | ||
| 151 | } | ||
| 152 | |||
| 153 | impl<const N: usize> SafeByteArray<N> { | ||
| 154 | pub fn new() -> Self { | ||
| 155 | Self { inner: [0; N] } | ||
| 156 | } | ||
| 157 | } | ||
| 158 | |||
| 159 | impl<const N: usize> AsRef<[u8]> for SafeByteArray<N> { | ||
| 160 | fn as_ref(&self) -> &[u8] { | ||
| 161 | &self.inner | ||
| 162 | } | ||
| 163 | } | ||
| 164 | |||
| 165 | impl<const N: usize> AsMut<[u8]> for SafeByteArray<N> { | ||
| 166 | fn as_mut(&mut self) -> &mut [u8] { | ||
| 167 | &mut self.inner | ||
| 168 | } | ||
| 169 | } | ||
| 170 | |||
| 171 | impl<const N: usize> Drop for SafeByteArray<N> { | ||
| 172 | fn drop(&mut self) { | ||
| 173 | self.inner.fill(0) | ||
| 174 | } | ||
| 175 | } | ||
| 176 | |||
| 177 | impl<const N: usize> FromStr for SafeByteArray<N> { | ||
| 178 | type Err = ParseHexError; | ||
| 179 | |||
| 180 | fn from_str(value: &str) -> Result<Self, Self::Err> { | ||
| 181 | let mut sba = Self { inner: [0u8; N] }; | ||
| 182 | parse_hex_exact(value, &mut sba.inner)?; | ||
| 183 | Ok(sba) | ||
| 184 | } | ||
| 185 | } | ||
| 186 | |||
| 187 | pub struct Claims<'a> { | ||
| 188 | pub auth_type: AuthType, | ||
| 189 | pub repo_path: &'a str, | ||
| 190 | pub operation: Operation, | ||
| 191 | pub expires_at: DateTime<Utc>, | ||
| 192 | } | ||
| 193 | |||
| 194 | /// Returns None if the claims are invalid. Repo path length may be no more than 100 bytes. | ||
| 195 | pub fn generate_tag(claims: Claims, key: impl AsRef<[u8]>) -> Option<Digest<32>> { | ||
| 196 | if claims.repo_path.len() > 100 { | ||
| 197 | return None; | ||
| 198 | } | ||
| 199 | |||
| 200 | let mut hmac = hmac_sha256::HMAC::new(key); | ||
| 201 | hmac.update([claims.auth_type as u8]); | ||
| 202 | hmac.update([claims.repo_path.len() as u8]); | ||
| 203 | hmac.update(claims.repo_path.as_bytes()); | ||
| 204 | hmac.update([claims.operation as u8]); | ||
| 205 | hmac.update(claims.expires_at.timestamp().to_be_bytes()); | ||
| 206 | Some(hmac.finalize().into()) | ||
| 207 | } | ||
| 208 | |||
| 209 | pub struct HexFmt<B: AsRef<[u8]>>(pub B); | ||
| 210 | |||
| 211 | impl<B: AsRef<[u8]>> fmt::Display for HexFmt<B> { | ||
| 212 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
| 213 | let HexFmt(buf) = self; | ||
| 214 | for b in buf.as_ref() { | ||
| 215 | let (high, low) = (b >> 4, b & 0xF); | ||
| 216 | let highc = if high < 10 { b'0' + high } else { b'a' + high }; | ||
| 217 | let lowc = if low < 10 { b'0' + low } else { b'a' + low }; | ||
| 218 | f.write_char(highc as char)?; | ||
| 219 | f.write_char(lowc as char)?; | ||
| 220 | } | ||
| 221 | Ok(()) | ||
| 222 | } | ||
| 223 | } | ||
| 224 | |||
| 225 | pub struct EscJsonFmt<'a>(pub &'a str); | ||
| 226 | |||
| 227 | impl<'a> fmt::Display for EscJsonFmt<'a> { | ||
| 228 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
| 229 | let EscJsonFmt(buf) = self; | ||
| 230 | for c in buf.chars() { | ||
| 231 | match c { | ||
| 232 | '"' => f.write_str("\\\"")?, // quote | ||
| 233 | '\\' => f.write_str("\\\\")?, // backslash | ||
| 234 | '\x08' => f.write_str("\\b")?, // backspace | ||
| 235 | '\x0C' => f.write_str("\\f")?, // form feed | ||
| 236 | '\n' => f.write_str("\\n")?, // line feed | ||
| 237 | '\r' => f.write_str("\\r")?, // carriage return | ||
| 238 | '\t' => f.write_str("\\t")?, // horizontal tab | ||
| 239 | _ => f.write_char(c)?, | ||
| 240 | }; | ||
| 241 | } | ||
| 242 | Ok(()) | ||
| 243 | } | ||
| 244 | } | ||
| 245 | |||
| 246 | #[derive(Debug, Copy, Clone)] | ||
| 247 | pub struct Digest<const N: usize> { | ||
| 248 | inner: [u8; N], | ||
| 249 | } | ||
| 250 | |||
| 251 | impl<const N: usize> ops::Index<usize> for Digest<N> { | ||
| 252 | type Output = u8; | ||
| 253 | |||
| 254 | fn index(&self, index: usize) -> &Self::Output { | ||
| 255 | &self.inner[index] | ||
| 256 | } | ||
| 257 | } | ||
| 258 | |||
| 259 | impl<const N: usize> Digest<N> { | ||
| 260 | pub fn as_bytes(&self) -> &[u8; N] { | ||
| 261 | &self.inner | ||
| 262 | } | ||
| 263 | } | ||
| 264 | |||
| 265 | impl<const N: usize> fmt::Display for Digest<N> { | ||
| 266 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
| 267 | HexFmt(&self.inner).fmt(f) | ||
| 268 | } | ||
| 269 | } | ||
| 270 | |||
| 271 | impl<const N: usize> Digest<N> { | ||
| 272 | pub fn new(data: [u8; N]) -> Self { | ||
| 273 | Self { inner: data } | ||
| 274 | } | ||
| 275 | } | ||
| 276 | |||
| 277 | impl<const N: usize> From<[u8; N]> for Digest<N> { | ||
| 278 | fn from(value: [u8; N]) -> Self { | ||
| 279 | Self::new(value) | ||
| 280 | } | ||
| 281 | } | ||
| 282 | |||
| 283 | impl<const N: usize> Into<[u8; N]> for Digest<N> { | ||
| 284 | fn into(self) -> [u8; N] { | ||
| 285 | self.inner | ||
| 286 | } | ||
| 287 | } | ||
| 288 | |||
| 289 | impl<const N: usize> FromStr for Digest<N> { | ||
| 290 | type Err = ParseHexError; | ||
| 291 | |||
| 292 | fn from_str(value: &str) -> Result<Self, Self::Err> { | ||
| 293 | let mut buf = [0u8; N]; | ||
| 294 | parse_hex_exact(value, &mut buf)?; | ||
| 295 | Ok(buf.into()) | ||
| 296 | } | ||
| 297 | } | ||
| 298 | |||
| 299 | impl<const N: usize> ConstantTimeEq for Digest<N> { | ||
| 300 | fn ct_eq(&self, other: &Self) -> subtle::Choice { | ||
| 301 | self.inner.ct_eq(&other.inner) | ||
| 302 | } | ||
| 303 | } | ||
| 304 | |||
| 305 | impl<const N: usize> PartialEq for Digest<N> { | ||
| 306 | fn eq(&self, other: &Self) -> bool { | ||
| 307 | self.ct_eq(&other).into() | ||
| 308 | } | ||
| 309 | } | ||
| 310 | |||
| 311 | impl<const N: usize> Eq for Digest<N> {} | ||
| 312 | |||
| 313 | impl<'de, const N: usize> Deserialize<'de> for Digest<N> { | ||
| 314 | fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||
| 315 | where | ||
| 316 | D: serde::Deserializer<'de>, | ||
| 317 | { | ||
| 318 | let hex = <&str>::deserialize(deserializer)?; | ||
| 319 | Digest::from_str(hex).map_err(de::Error::custom) | ||
| 320 | } | ||
| 321 | } | ||
| 322 | |||
| 323 | impl<const N: usize> Serialize for Digest<N> { | ||
| 324 | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | ||
| 325 | where | ||
| 326 | S: serde::Serializer, | ||
| 327 | { | ||
| 328 | serializer.serialize_str(&format!("{self}")) | ||
| 329 | } | ||
| 330 | } | ||
| 331 | |||
| 332 | pub type Key = SafeByteArray<64>; | ||
| 333 | |||
| 334 | pub fn load_key(path: &str) -> Result<Key, ReadHexError> { | ||
| 335 | let key_str = std::fs::read_to_string(path).map_err(ReadHexError::Io)?; | ||
| 336 | key_str.trim().parse().map_err(ReadHexError::Format) | ||
| 337 | } | ||