diff options
Diffstat (limited to 'common/src')
-rw-r--r-- | common/src/lib.rs | 348 |
1 files changed, 0 insertions, 348 deletions
diff --git a/common/src/lib.rs b/common/src/lib.rs deleted file mode 100644 index 917f566..0000000 --- a/common/src/lib.rs +++ /dev/null | |||
@@ -1,348 +0,0 @@ | |||
1 | use chrono::{DateTime, Utc}; | ||
2 | use serde::{de, Deserialize, Serialize}; | ||
3 | use std::{ | ||
4 | fmt::{self, Write}, | ||
5 | ops, | ||
6 | str::FromStr, | ||
7 | }; | ||
8 | use subtle::ConstantTimeEq; | ||
9 | |||
10 | #[repr(u8)] | ||
11 | enum AuthType { | ||
12 | BatchApi = 1, | ||
13 | Download = 2, | ||
14 | } | ||
15 | |||
16 | #[derive(Debug, Copy, Clone)] | ||
17 | pub struct Claims<'a> { | ||
18 | pub specific_claims: SpecificClaims, | ||
19 | pub repo_path: &'a str, | ||
20 | pub expires_at: DateTime<Utc>, | ||
21 | } | ||
22 | |||
23 | #[derive(Debug, Copy, Clone)] | ||
24 | pub enum SpecificClaims { | ||
25 | BatchApi(Operation), | ||
26 | Download(Oid), | ||
27 | } | ||
28 | |||
29 | pub type Oid = Digest<32>; | ||
30 | |||
31 | #[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)] | ||
32 | #[repr(u8)] | ||
33 | pub enum Operation { | ||
34 | #[serde(rename = "download")] | ||
35 | Download = 1, | ||
36 | #[serde(rename = "upload")] | ||
37 | Upload = 2, | ||
38 | } | ||
39 | |||
40 | /// Returns None if the claims are invalid. Repo path length may be no more than 100 bytes. | ||
41 | pub fn generate_tag(claims: Claims, key: impl AsRef<[u8]>) -> Option<Digest<32>> { | ||
42 | if claims.repo_path.len() > 100 { | ||
43 | return None; | ||
44 | } | ||
45 | |||
46 | let mut hmac = hmac_sha256::HMAC::new(key); | ||
47 | match claims.specific_claims { | ||
48 | SpecificClaims::BatchApi(operation) => { | ||
49 | hmac.update([AuthType::BatchApi as u8]); | ||
50 | hmac.update([operation as u8]); | ||
51 | } | ||
52 | SpecificClaims::Download(oid) => { | ||
53 | hmac.update([AuthType::Download as u8]); | ||
54 | hmac.update(oid.as_bytes()); | ||
55 | } | ||
56 | } | ||
57 | hmac.update([claims.repo_path.len() as u8]); | ||
58 | hmac.update(claims.repo_path.as_bytes()); | ||
59 | hmac.update(claims.expires_at.timestamp().to_be_bytes()); | ||
60 | Some(hmac.finalize().into()) | ||
61 | } | ||
62 | |||
63 | #[derive(Debug, PartialEq, Eq, Copy, Clone)] | ||
64 | pub struct ParseOperationError; | ||
65 | |||
66 | impl fmt::Display for ParseOperationError { | ||
67 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
68 | write!(f, "operation should be 'download' or 'upload'") | ||
69 | } | ||
70 | } | ||
71 | |||
72 | impl FromStr for Operation { | ||
73 | type Err = ParseOperationError; | ||
74 | |||
75 | fn from_str(s: &str) -> Result<Self, Self::Err> { | ||
76 | match s { | ||
77 | "upload" => Ok(Self::Upload), | ||
78 | "download" => Ok(Self::Download), | ||
79 | _ => Err(ParseOperationError), | ||
80 | } | ||
81 | } | ||
82 | } | ||
83 | |||
84 | /// None means out of range. | ||
85 | fn decode_nibble(c: u8) -> Option<u8> { | ||
86 | if c.is_ascii_digit() { | ||
87 | Some(c - b'0') | ||
88 | } else if (b'a'..=b'f').contains(&c) { | ||
89 | Some(c - b'a' + 10) | ||
90 | } else if (b'A'..=b'F').contains(&c) { | ||
91 | Some(c - b'A' + 10) | ||
92 | } else { | ||
93 | None | ||
94 | } | ||
95 | } | ||
96 | |||
97 | #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] | ||
98 | pub struct HexByte(pub u8); | ||
99 | |||
100 | impl<'de> Deserialize<'de> for HexByte { | ||
101 | fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||
102 | where | ||
103 | D: serde::Deserializer<'de>, | ||
104 | { | ||
105 | let str = <&str>::deserialize(deserializer)?; | ||
106 | let &[b1, b2] = str.as_bytes() else { | ||
107 | return Err(de::Error::invalid_length( | ||
108 | str.len(), | ||
109 | &"two hexadecimal characters", | ||
110 | )); | ||
111 | }; | ||
112 | let (Some(b1), Some(b2)) = (decode_nibble(b1), decode_nibble(b2)) else { | ||
113 | return Err(de::Error::invalid_value( | ||
114 | de::Unexpected::Str(str), | ||
115 | &"two hexadecimal characters", | ||
116 | )); | ||
117 | }; | ||
118 | Ok(HexByte((b1 << 4) | b2)) | ||
119 | } | ||
120 | } | ||
121 | |||
122 | impl fmt::Display for HexByte { | ||
123 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
124 | let &HexByte(b) = self; | ||
125 | HexFmt(&[b]).fmt(f) | ||
126 | } | ||
127 | } | ||
128 | |||
129 | #[derive(Debug, PartialEq, Eq, Copy, Clone)] | ||
130 | pub enum ParseHexError { | ||
131 | UnevenNibbles, | ||
132 | InvalidCharacter, | ||
133 | TooShort, | ||
134 | TooLong, | ||
135 | } | ||
136 | |||
137 | impl fmt::Display for ParseHexError { | ||
138 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||
139 | match self { | ||
140 | Self::UnevenNibbles => { | ||
141 | write!(f, "uneven amount of nibbles (chars in range [a-zA-Z0-9])") | ||
142 | } | ||
143 | Self::InvalidCharacter => write!(f, "non-hex character encountered"), | ||
144 | Self::TooShort => write!(f, "unexpected end of hex sequence"), | ||
145 | Self::TooLong => write!(f, "longer hex sequence than expected"), | ||
146 | } | ||
147 | } | ||
148 | } | ||
149 | |||
150 | #[derive(Debug)] | ||
151 | pub enum ReadHexError { | ||
152 | Io(std::io::Error), | ||
153 | Format(ParseHexError), | ||
154 | } | ||
155 | |||
156 | impl fmt::Display for ReadHexError { | ||
157 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
158 | match self { | ||
159 | Self::Io(e) => e.fmt(f), | ||
160 | Self::Format(e) => e.fmt(f), | ||
161 | } | ||
162 | } | ||
163 | } | ||
164 | |||
165 | fn parse_hex_exact(value: &str, buf: &mut [u8]) -> Result<(), ParseHexError> { | ||
166 | if value.bytes().len() % 2 == 1 { | ||
167 | return Err(ParseHexError::UnevenNibbles); | ||
168 | } | ||
169 | if value.bytes().len() < 2 * buf.len() { | ||
170 | return Err(ParseHexError::TooShort); | ||
171 | } | ||
172 | if value.bytes().len() > 2 * buf.len() { | ||
173 | return Err(ParseHexError::TooLong); | ||
174 | } | ||
175 | for (i, c) in value.bytes().enumerate() { | ||
176 | if let Some(b) = decode_nibble(c) { | ||
177 | if i % 2 == 0 { | ||
178 | buf[i / 2] = b << 4; | ||
179 | } else { | ||
180 | buf[i / 2] |= b; | ||
181 | } | ||
182 | } else { | ||
183 | return Err(ParseHexError::InvalidCharacter); | ||
184 | } | ||
185 | } | ||
186 | Ok(()) | ||
187 | } | ||
188 | |||
189 | pub type Key = SafeByteArray<64>; | ||
190 | |||
191 | pub fn load_key(path: &str) -> Result<Key, ReadHexError> { | ||
192 | let key_str = std::fs::read_to_string(path).map_err(ReadHexError::Io)?; | ||
193 | key_str.trim().parse().map_err(ReadHexError::Format) | ||
194 | } | ||
195 | |||
196 | pub struct SafeByteArray<const N: usize> { | ||
197 | inner: [u8; N], | ||
198 | } | ||
199 | |||
200 | impl<const N: usize> SafeByteArray<N> { | ||
201 | pub fn new() -> Self { | ||
202 | Self { inner: [0; N] } | ||
203 | } | ||
204 | } | ||
205 | |||
206 | impl<const N: usize> Default for SafeByteArray<N> { | ||
207 | fn default() -> Self { | ||
208 | Self::new() | ||
209 | } | ||
210 | } | ||
211 | |||
212 | impl<const N: usize> AsRef<[u8]> for SafeByteArray<N> { | ||
213 | fn as_ref(&self) -> &[u8] { | ||
214 | &self.inner | ||
215 | } | ||
216 | } | ||
217 | |||
218 | impl<const N: usize> AsMut<[u8]> for SafeByteArray<N> { | ||
219 | fn as_mut(&mut self) -> &mut [u8] { | ||
220 | &mut self.inner | ||
221 | } | ||
222 | } | ||
223 | |||
224 | impl<const N: usize> Drop for SafeByteArray<N> { | ||
225 | fn drop(&mut self) { | ||
226 | self.inner.fill(0) | ||
227 | } | ||
228 | } | ||
229 | |||
230 | impl<const N: usize> FromStr for SafeByteArray<N> { | ||
231 | type Err = ParseHexError; | ||
232 | |||
233 | fn from_str(value: &str) -> Result<Self, Self::Err> { | ||
234 | let mut sba = Self { inner: [0u8; N] }; | ||
235 | parse_hex_exact(value, &mut sba.inner)?; | ||
236 | Ok(sba) | ||
237 | } | ||
238 | } | ||
239 | |||
240 | pub struct HexFmt<B: AsRef<[u8]>>(pub B); | ||
241 | |||
242 | impl<B: AsRef<[u8]>> fmt::Display for HexFmt<B> { | ||
243 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
244 | let HexFmt(buf) = self; | ||
245 | for b in buf.as_ref() { | ||
246 | let (high, low) = (b >> 4, b & 0xF); | ||
247 | let highc = if high < 10 { | ||
248 | high + b'0' | ||
249 | } else { | ||
250 | high - 10 + b'a' | ||
251 | }; | ||
252 | let lowc = if low < 10 { | ||
253 | low + b'0' | ||
254 | } else { | ||
255 | low - 10 + b'a' | ||
256 | }; | ||
257 | f.write_char(highc as char)?; | ||
258 | f.write_char(lowc as char)?; | ||
259 | } | ||
260 | Ok(()) | ||
261 | } | ||
262 | } | ||
263 | |||
264 | #[derive(Debug, Copy, Clone)] | ||
265 | pub struct Digest<const N: usize> { | ||
266 | inner: [u8; N], | ||
267 | } | ||
268 | |||
269 | impl<const N: usize> ops::Index<usize> for Digest<N> { | ||
270 | type Output = u8; | ||
271 | |||
272 | fn index(&self, index: usize) -> &Self::Output { | ||
273 | &self.inner[index] | ||
274 | } | ||
275 | } | ||
276 | |||
277 | impl<const N: usize> Digest<N> { | ||
278 | pub fn as_bytes(&self) -> &[u8; N] { | ||
279 | &self.inner | ||
280 | } | ||
281 | } | ||
282 | |||
283 | impl<const N: usize> fmt::Display for Digest<N> { | ||
284 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
285 | HexFmt(&self.inner).fmt(f) | ||
286 | } | ||
287 | } | ||
288 | |||
289 | impl<const N: usize> Digest<N> { | ||
290 | pub fn new(data: [u8; N]) -> Self { | ||
291 | Self { inner: data } | ||
292 | } | ||
293 | } | ||
294 | |||
295 | impl<const N: usize> From<[u8; N]> for Digest<N> { | ||
296 | fn from(value: [u8; N]) -> Self { | ||
297 | Self::new(value) | ||
298 | } | ||
299 | } | ||
300 | |||
301 | impl<const N: usize> From<Digest<N>> for [u8; N] { | ||
302 | fn from(val: Digest<N>) -> Self { | ||
303 | val.inner | ||
304 | } | ||
305 | } | ||
306 | |||
307 | impl<const N: usize> FromStr for Digest<N> { | ||
308 | type Err = ParseHexError; | ||
309 | |||
310 | fn from_str(value: &str) -> Result<Self, Self::Err> { | ||
311 | let mut buf = [0u8; N]; | ||
312 | parse_hex_exact(value, &mut buf)?; | ||
313 | Ok(buf.into()) | ||
314 | } | ||
315 | } | ||
316 | |||
317 | impl<const N: usize> ConstantTimeEq for Digest<N> { | ||
318 | fn ct_eq(&self, other: &Self) -> subtle::Choice { | ||
319 | self.inner.ct_eq(&other.inner) | ||
320 | } | ||
321 | } | ||
322 | |||
323 | impl<const N: usize> PartialEq for Digest<N> { | ||
324 | fn eq(&self, other: &Self) -> bool { | ||
325 | self.ct_eq(other).into() | ||
326 | } | ||
327 | } | ||
328 | |||
329 | impl<const N: usize> Eq for Digest<N> {} | ||
330 | |||
331 | impl<'de, const N: usize> Deserialize<'de> for Digest<N> { | ||
332 | fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> | ||
333 | where | ||
334 | D: serde::Deserializer<'de>, | ||
335 | { | ||
336 | let hex = <&str>::deserialize(deserializer)?; | ||
337 | Digest::from_str(hex).map_err(de::Error::custom) | ||
338 | } | ||
339 | } | ||
340 | |||
341 | impl<const N: usize> Serialize for Digest<N> { | ||
342 | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | ||
343 | where | ||
344 | S: serde::Serializer, | ||
345 | { | ||
346 | serializer.serialize_str(&format!("{self}")) | ||
347 | } | ||
348 | } | ||