diff options
Diffstat (limited to 'common/src')
-rw-r--r-- | common/src/lib.rs | 368 |
1 files changed, 368 insertions, 0 deletions
diff --git a/common/src/lib.rs b/common/src/lib.rs new file mode 100644 index 0000000..995352d --- /dev/null +++ b/common/src/lib.rs | |||
@@ -0,0 +1,368 @@ | |||
1 | use chrono::{DateTime, Utc}; | ||
2 | use serde::de; | ||
3 | use serde::{Deserialize, Serialize}; | ||
4 | use std::fmt::Write; | ||
5 | use std::ops; | ||
6 | use std::{fmt, str::FromStr}; | ||
7 | use subtle::ConstantTimeEq; | ||
8 | |||
9 | #[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)] | ||
10 | #[repr(u8)] | ||
11 | pub enum Operation { | ||
12 | #[serde(rename = "download")] | ||
13 | Download = 1, | ||
14 | #[serde(rename = "upload")] | ||
15 | Upload = 2, | ||
16 | } | ||
17 | |||
18 | #[derive(Debug, PartialEq, Eq, Copy, Clone)] | ||
19 | pub struct ParseOperationError; | ||
20 | |||
21 | impl fmt::Display for ParseOperationError { | ||
22 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
23 | write!(f, "operation should be 'download' or 'upload'") | ||
24 | } | ||
25 | } | ||
26 | |||
27 | impl FromStr for Operation { | ||
28 | type Err = ParseOperationError; | ||
29 | |||
30 | fn from_str(s: &str) -> Result<Self, Self::Err> { | ||
31 | match s { | ||
32 | "upload" => Ok(Self::Upload), | ||
33 | "download" => Ok(Self::Download), | ||
34 | _ => Err(ParseOperationError), | ||
35 | } | ||
36 | } | ||
37 | } | ||
38 | |||
39 | #[repr(u8)] | ||
40 | enum AuthType { | ||
41 | BatchApi = 1, | ||
42 | Download = 2, | ||
43 | } | ||
44 | |||
45 | /// None means out of range. | ||
46 | fn decode_nibble(c: u8) -> Option<u8> { | ||
47 | if c.is_ascii_digit() { | ||
48 | Some(c - b'0') | ||
49 | } else if (b'a'..=b'f').contains(&c) { | ||
50 | Some(c - b'a' + 10) | ||
51 | } else if (b'A'..=b'F').contains(&c) { | ||
52 | Some(c - b'A' + 10) | ||
53 | } else { | ||
54 | None | ||
55 | } | ||
56 | } | ||
57 | |||
58 | #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] | ||
59 | pub struct HexByte(pub u8); | ||
60 | |||
61 | impl<'de> Deserialize<'de> for HexByte { | ||
62 | fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||
63 | where | ||
64 | D: serde::Deserializer<'de>, | ||
65 | { | ||
66 | let str = <&str>::deserialize(deserializer)?; | ||
67 | let &[b1, b2] = str.as_bytes() else { | ||
68 | return Err(de::Error::invalid_length( | ||
69 | str.len(), | ||
70 | &"two hexadecimal characters", | ||
71 | )); | ||
72 | }; | ||
73 | let (Some(b1), Some(b2)) = (decode_nibble(b1), decode_nibble(b2)) else { | ||
74 | return Err(de::Error::invalid_value( | ||
75 | de::Unexpected::Str(str), | ||
76 | &"two hexadecimal characters", | ||
77 | )); | ||
78 | }; | ||
79 | Ok(HexByte((b1 << 4) | b2)) | ||
80 | } | ||
81 | } | ||
82 | |||
83 | impl fmt::Display for HexByte { | ||
84 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
85 | let &HexByte(b) = self; | ||
86 | HexFmt(&[b]).fmt(f) | ||
87 | } | ||
88 | } | ||
89 | |||
90 | #[derive(Debug, PartialEq, Eq, Copy, Clone)] | ||
91 | pub enum ParseHexError { | ||
92 | UnevenNibbles, | ||
93 | InvalidCharacter, | ||
94 | TooShort, | ||
95 | TooLong, | ||
96 | } | ||
97 | |||
98 | impl fmt::Display for ParseHexError { | ||
99 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
100 | match self { | ||
101 | Self::UnevenNibbles => { | ||
102 | write!(f, "uneven amount of nibbles (chars in range [a-zA-Z0-9])") | ||
103 | } | ||
104 | Self::InvalidCharacter => write!(f, "non-hex character encountered"), | ||
105 | Self::TooShort => write!(f, "unexpected end of hex sequence"), | ||
106 | Self::TooLong => write!(f, "longer hex sequence than expected"), | ||
107 | } | ||
108 | } | ||
109 | } | ||
110 | |||
111 | #[derive(Debug)] | ||
112 | pub enum ReadHexError { | ||
113 | Io(std::io::Error), | ||
114 | Format(ParseHexError), | ||
115 | } | ||
116 | |||
117 | impl fmt::Display for ReadHexError { | ||
118 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
119 | match self { | ||
120 | Self::Io(e) => e.fmt(f), | ||
121 | Self::Format(e) => e.fmt(f), | ||
122 | } | ||
123 | } | ||
124 | } | ||
125 | |||
126 | fn parse_hex_exact(value: &str, buf: &mut [u8]) -> Result<(), ParseHexError> { | ||
127 | if value.bytes().len() % 2 == 1 { | ||
128 | return Err(ParseHexError::UnevenNibbles); | ||
129 | } | ||
130 | if value.bytes().len() < 2 * buf.len() { | ||
131 | return Err(ParseHexError::TooShort); | ||
132 | } | ||
133 | if value.bytes().len() > 2 * buf.len() { | ||
134 | return Err(ParseHexError::TooLong); | ||
135 | } | ||
136 | for (i, c) in value.bytes().enumerate() { | ||
137 | if let Some(b) = decode_nibble(c) { | ||
138 | if i % 2 == 0 { | ||
139 | buf[i / 2] = b << 4; | ||
140 | } else { | ||
141 | buf[i / 2] |= b; | ||
142 | } | ||
143 | } else { | ||
144 | return Err(ParseHexError::InvalidCharacter); | ||
145 | } | ||
146 | } | ||
147 | Ok(()) | ||
148 | } | ||
149 | |||
150 | pub struct SafeByteArray<const N: usize> { | ||
151 | inner: [u8; N], | ||
152 | } | ||
153 | |||
154 | impl<const N: usize> SafeByteArray<N> { | ||
155 | pub fn new() -> Self { | ||
156 | Self { inner: [0; N] } | ||
157 | } | ||
158 | } | ||
159 | |||
160 | impl<const N: usize> Default for SafeByteArray<N> { | ||
161 | fn default() -> Self { | ||
162 | Self::new() | ||
163 | } | ||
164 | } | ||
165 | |||
166 | impl<const N: usize> AsRef<[u8]> for SafeByteArray<N> { | ||
167 | fn as_ref(&self) -> &[u8] { | ||
168 | &self.inner | ||
169 | } | ||
170 | } | ||
171 | |||
172 | impl<const N: usize> AsMut<[u8]> for SafeByteArray<N> { | ||
173 | fn as_mut(&mut self) -> &mut [u8] { | ||
174 | &mut self.inner | ||
175 | } | ||
176 | } | ||
177 | |||
178 | impl<const N: usize> Drop for SafeByteArray<N> { | ||
179 | fn drop(&mut self) { | ||
180 | self.inner.fill(0) | ||
181 | } | ||
182 | } | ||
183 | |||
184 | impl<const N: usize> FromStr for SafeByteArray<N> { | ||
185 | type Err = ParseHexError; | ||
186 | |||
187 | fn from_str(value: &str) -> Result<Self, Self::Err> { | ||
188 | let mut sba = Self { inner: [0u8; N] }; | ||
189 | parse_hex_exact(value, &mut sba.inner)?; | ||
190 | Ok(sba) | ||
191 | } | ||
192 | } | ||
193 | |||
194 | pub type Oid = Digest<32>; | ||
195 | |||
196 | #[derive(Debug, Copy, Clone)] | ||
197 | pub enum SpecificClaims { | ||
198 | BatchApi(Operation), | ||
199 | Download(Oid), | ||
200 | } | ||
201 | |||
202 | #[derive(Debug, Copy, Clone)] | ||
203 | pub struct Claims<'a> { | ||
204 | pub specific_claims: SpecificClaims, | ||
205 | pub repo_path: &'a str, | ||
206 | pub expires_at: DateTime<Utc>, | ||
207 | } | ||
208 | |||
209 | /// Returns None if the claims are invalid. Repo path length may be no more than 100 bytes. | ||
210 | pub fn generate_tag(claims: Claims, key: impl AsRef<[u8]>) -> Option<Digest<32>> { | ||
211 | if claims.repo_path.len() > 100 { | ||
212 | return None; | ||
213 | } | ||
214 | |||
215 | let mut hmac = hmac_sha256::HMAC::new(key); | ||
216 | match claims.specific_claims { | ||
217 | SpecificClaims::BatchApi(operation) => { | ||
218 | hmac.update([AuthType::BatchApi as u8]); | ||
219 | hmac.update([operation as u8]); | ||
220 | } | ||
221 | SpecificClaims::Download(oid) => { | ||
222 | hmac.update([AuthType::Download as u8]); | ||
223 | hmac.update(oid.as_bytes()); | ||
224 | } | ||
225 | } | ||
226 | hmac.update([claims.repo_path.len() as u8]); | ||
227 | hmac.update(claims.repo_path.as_bytes()); | ||
228 | hmac.update(claims.expires_at.timestamp().to_be_bytes()); | ||
229 | Some(hmac.finalize().into()) | ||
230 | } | ||
231 | |||
232 | pub struct HexFmt<B: AsRef<[u8]>>(pub B); | ||
233 | |||
234 | impl<B: AsRef<[u8]>> fmt::Display for HexFmt<B> { | ||
235 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
236 | let HexFmt(buf) = self; | ||
237 | for b in buf.as_ref() { | ||
238 | let (high, low) = (b >> 4, b & 0xF); | ||
239 | let highc = if high < 10 { | ||
240 | high + b'0' | ||
241 | } else { | ||
242 | high - 10 + b'a' | ||
243 | }; | ||
244 | let lowc = if low < 10 { | ||
245 | low + b'0' | ||
246 | } else { | ||
247 | low - 10 + b'a' | ||
248 | }; | ||
249 | f.write_char(highc as char)?; | ||
250 | f.write_char(lowc as char)?; | ||
251 | } | ||
252 | Ok(()) | ||
253 | } | ||
254 | } | ||
255 | |||
256 | pub struct EscJsonFmt<'a>(pub &'a str); | ||
257 | |||
258 | impl<'a> fmt::Display for EscJsonFmt<'a> { | ||
259 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
260 | let EscJsonFmt(buf) = self; | ||
261 | for c in buf.chars() { | ||
262 | match c { | ||
263 | '"' => f.write_str("\\\"")?, // quote | ||
264 | '\\' => f.write_str("\\\\")?, // backslash | ||
265 | '\x08' => f.write_str("\\b")?, // backspace | ||
266 | '\x0C' => f.write_str("\\f")?, // form feed | ||
267 | '\n' => f.write_str("\\n")?, // line feed | ||
268 | '\r' => f.write_str("\\r")?, // carriage return | ||
269 | '\t' => f.write_str("\\t")?, // horizontal tab | ||
270 | _ => f.write_char(c)?, | ||
271 | }; | ||
272 | } | ||
273 | Ok(()) | ||
274 | } | ||
275 | } | ||
276 | |||
277 | #[derive(Debug, Copy, Clone)] | ||
278 | pub struct Digest<const N: usize> { | ||
279 | inner: [u8; N], | ||
280 | } | ||
281 | |||
282 | impl<const N: usize> ops::Index<usize> for Digest<N> { | ||
283 | type Output = u8; | ||
284 | |||
285 | fn index(&self, index: usize) -> &Self::Output { | ||
286 | &self.inner[index] | ||
287 | } | ||
288 | } | ||
289 | |||
290 | impl<const N: usize> Digest<N> { | ||
291 | pub fn as_bytes(&self) -> &[u8; N] { | ||
292 | &self.inner | ||
293 | } | ||
294 | } | ||
295 | |||
296 | impl<const N: usize> fmt::Display for Digest<N> { | ||
297 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
298 | HexFmt(&self.inner).fmt(f) | ||
299 | } | ||
300 | } | ||
301 | |||
302 | impl<const N: usize> Digest<N> { | ||
303 | pub fn new(data: [u8; N]) -> Self { | ||
304 | Self { inner: data } | ||
305 | } | ||
306 | } | ||
307 | |||
308 | impl<const N: usize> From<[u8; N]> for Digest<N> { | ||
309 | fn from(value: [u8; N]) -> Self { | ||
310 | Self::new(value) | ||
311 | } | ||
312 | } | ||
313 | |||
314 | impl<const N: usize> From<Digest<N>> for [u8; N] { | ||
315 | fn from(val: Digest<N>) -> Self { | ||
316 | val.inner | ||
317 | } | ||
318 | } | ||
319 | |||
320 | impl<const N: usize> FromStr for Digest<N> { | ||
321 | type Err = ParseHexError; | ||
322 | |||
323 | fn from_str(value: &str) -> Result<Self, Self::Err> { | ||
324 | let mut buf = [0u8; N]; | ||
325 | parse_hex_exact(value, &mut buf)?; | ||
326 | Ok(buf.into()) | ||
327 | } | ||
328 | } | ||
329 | |||
330 | impl<const N: usize> ConstantTimeEq for Digest<N> { | ||
331 | fn ct_eq(&self, other: &Self) -> subtle::Choice { | ||
332 | self.inner.ct_eq(&other.inner) | ||
333 | } | ||
334 | } | ||
335 | |||
336 | impl<const N: usize> PartialEq for Digest<N> { | ||
337 | fn eq(&self, other: &Self) -> bool { | ||
338 | self.ct_eq(other).into() | ||
339 | } | ||
340 | } | ||
341 | |||
342 | impl<const N: usize> Eq for Digest<N> {} | ||
343 | |||
344 | impl<'de, const N: usize> Deserialize<'de> for Digest<N> { | ||
345 | fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||
346 | where | ||
347 | D: serde::Deserializer<'de>, | ||
348 | { | ||
349 | let hex = <&str>::deserialize(deserializer)?; | ||
350 | Digest::from_str(hex).map_err(de::Error::custom) | ||
351 | } | ||
352 | } | ||
353 | |||
354 | impl<const N: usize> Serialize for Digest<N> { | ||
355 | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | ||
356 | where | ||
357 | S: serde::Serializer, | ||
358 | { | ||
359 | serializer.serialize_str(&format!("{self}")) | ||
360 | } | ||
361 | } | ||
362 | |||
363 | pub type Key = SafeByteArray<64>; | ||
364 | |||
365 | pub fn load_key(path: &str) -> Result<Key, ReadHexError> { | ||
366 | let key_str = std::fs::read_to_string(path).map_err(ReadHexError::Io)?; | ||
367 | key_str.trim().parse().map_err(ReadHexError::Format) | ||
368 | } | ||