aboutsummaryrefslogtreecommitdiffstats
path: root/common/src
diff options
context:
space:
mode:
Diffstat (limited to 'common/src')
-rw-r--r--common/src/lib.rs368
1 files changed, 368 insertions, 0 deletions
diff --git a/common/src/lib.rs b/common/src/lib.rs
new file mode 100644
index 0000000..995352d
--- /dev/null
+++ b/common/src/lib.rs
@@ -0,0 +1,368 @@
1use chrono::{DateTime, Utc};
2use serde::de;
3use serde::{Deserialize, Serialize};
4use std::fmt::Write;
5use std::ops;
6use std::{fmt, str::FromStr};
7use subtle::ConstantTimeEq;
8
9#[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)]
10#[repr(u8)]
11pub enum Operation {
12 #[serde(rename = "download")]
13 Download = 1,
14 #[serde(rename = "upload")]
15 Upload = 2,
16}
17
18#[derive(Debug, PartialEq, Eq, Copy, Clone)]
19pub struct ParseOperationError;
20
21impl fmt::Display for ParseOperationError {
22 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
23 write!(f, "operation should be 'download' or 'upload'")
24 }
25}
26
27impl FromStr for Operation {
28 type Err = ParseOperationError;
29
30 fn from_str(s: &str) -> Result<Self, Self::Err> {
31 match s {
32 "upload" => Ok(Self::Upload),
33 "download" => Ok(Self::Download),
34 _ => Err(ParseOperationError),
35 }
36 }
37}
38
39#[repr(u8)]
40enum AuthType {
41 BatchApi = 1,
42 Download = 2,
43}
44
45/// None means out of range.
46fn decode_nibble(c: u8) -> Option<u8> {
47 if c.is_ascii_digit() {
48 Some(c - b'0')
49 } else if (b'a'..=b'f').contains(&c) {
50 Some(c - b'a' + 10)
51 } else if (b'A'..=b'F').contains(&c) {
52 Some(c - b'A' + 10)
53 } else {
54 None
55 }
56}
57
58#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
59pub struct HexByte(pub u8);
60
61impl<'de> Deserialize<'de> for HexByte {
62 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
63 where
64 D: serde::Deserializer<'de>,
65 {
66 let str = <&str>::deserialize(deserializer)?;
67 let &[b1, b2] = str.as_bytes() else {
68 return Err(de::Error::invalid_length(
69 str.len(),
70 &"two hexadecimal characters",
71 ));
72 };
73 let (Some(b1), Some(b2)) = (decode_nibble(b1), decode_nibble(b2)) else {
74 return Err(de::Error::invalid_value(
75 de::Unexpected::Str(str),
76 &"two hexadecimal characters",
77 ));
78 };
79 Ok(HexByte((b1 << 4) | b2))
80 }
81}
82
83impl fmt::Display for HexByte {
84 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
85 let &HexByte(b) = self;
86 HexFmt(&[b]).fmt(f)
87 }
88}
89
90#[derive(Debug, PartialEq, Eq, Copy, Clone)]
91pub enum ParseHexError {
92 UnevenNibbles,
93 InvalidCharacter,
94 TooShort,
95 TooLong,
96}
97
98impl fmt::Display for ParseHexError {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 match self {
101 Self::UnevenNibbles => {
102 write!(f, "uneven amount of nibbles (chars in range [a-zA-Z0-9])")
103 }
104 Self::InvalidCharacter => write!(f, "non-hex character encountered"),
105 Self::TooShort => write!(f, "unexpected end of hex sequence"),
106 Self::TooLong => write!(f, "longer hex sequence than expected"),
107 }
108 }
109}
110
111#[derive(Debug)]
112pub enum ReadHexError {
113 Io(std::io::Error),
114 Format(ParseHexError),
115}
116
117impl fmt::Display for ReadHexError {
118 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
119 match self {
120 Self::Io(e) => e.fmt(f),
121 Self::Format(e) => e.fmt(f),
122 }
123 }
124}
125
126fn parse_hex_exact(value: &str, buf: &mut [u8]) -> Result<(), ParseHexError> {
127 if value.bytes().len() % 2 == 1 {
128 return Err(ParseHexError::UnevenNibbles);
129 }
130 if value.bytes().len() < 2 * buf.len() {
131 return Err(ParseHexError::TooShort);
132 }
133 if value.bytes().len() > 2 * buf.len() {
134 return Err(ParseHexError::TooLong);
135 }
136 for (i, c) in value.bytes().enumerate() {
137 if let Some(b) = decode_nibble(c) {
138 if i % 2 == 0 {
139 buf[i / 2] = b << 4;
140 } else {
141 buf[i / 2] |= b;
142 }
143 } else {
144 return Err(ParseHexError::InvalidCharacter);
145 }
146 }
147 Ok(())
148}
149
150pub struct SafeByteArray<const N: usize> {
151 inner: [u8; N],
152}
153
154impl<const N: usize> SafeByteArray<N> {
155 pub fn new() -> Self {
156 Self { inner: [0; N] }
157 }
158}
159
160impl<const N: usize> Default for SafeByteArray<N> {
161 fn default() -> Self {
162 Self::new()
163 }
164}
165
166impl<const N: usize> AsRef<[u8]> for SafeByteArray<N> {
167 fn as_ref(&self) -> &[u8] {
168 &self.inner
169 }
170}
171
172impl<const N: usize> AsMut<[u8]> for SafeByteArray<N> {
173 fn as_mut(&mut self) -> &mut [u8] {
174 &mut self.inner
175 }
176}
177
178impl<const N: usize> Drop for SafeByteArray<N> {
179 fn drop(&mut self) {
180 self.inner.fill(0)
181 }
182}
183
184impl<const N: usize> FromStr for SafeByteArray<N> {
185 type Err = ParseHexError;
186
187 fn from_str(value: &str) -> Result<Self, Self::Err> {
188 let mut sba = Self { inner: [0u8; N] };
189 parse_hex_exact(value, &mut sba.inner)?;
190 Ok(sba)
191 }
192}
193
194pub type Oid = Digest<32>;
195
196#[derive(Debug, Copy, Clone)]
197pub enum SpecificClaims {
198 BatchApi(Operation),
199 Download(Oid),
200}
201
202#[derive(Debug, Copy, Clone)]
203pub struct Claims<'a> {
204 pub specific_claims: SpecificClaims,
205 pub repo_path: &'a str,
206 pub expires_at: DateTime<Utc>,
207}
208
209/// Returns None if the claims are invalid. Repo path length may be no more than 100 bytes.
210pub fn generate_tag(claims: Claims, key: impl AsRef<[u8]>) -> Option<Digest<32>> {
211 if claims.repo_path.len() > 100 {
212 return None;
213 }
214
215 let mut hmac = hmac_sha256::HMAC::new(key);
216 match claims.specific_claims {
217 SpecificClaims::BatchApi(operation) => {
218 hmac.update([AuthType::BatchApi as u8]);
219 hmac.update([operation as u8]);
220 }
221 SpecificClaims::Download(oid) => {
222 hmac.update([AuthType::Download as u8]);
223 hmac.update(oid.as_bytes());
224 }
225 }
226 hmac.update([claims.repo_path.len() as u8]);
227 hmac.update(claims.repo_path.as_bytes());
228 hmac.update(claims.expires_at.timestamp().to_be_bytes());
229 Some(hmac.finalize().into())
230}
231
232pub struct HexFmt<B: AsRef<[u8]>>(pub B);
233
234impl<B: AsRef<[u8]>> fmt::Display for HexFmt<B> {
235 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
236 let HexFmt(buf) = self;
237 for b in buf.as_ref() {
238 let (high, low) = (b >> 4, b & 0xF);
239 let highc = if high < 10 {
240 high + b'0'
241 } else {
242 high - 10 + b'a'
243 };
244 let lowc = if low < 10 {
245 low + b'0'
246 } else {
247 low - 10 + b'a'
248 };
249 f.write_char(highc as char)?;
250 f.write_char(lowc as char)?;
251 }
252 Ok(())
253 }
254}
255
256pub struct EscJsonFmt<'a>(pub &'a str);
257
258impl<'a> fmt::Display for EscJsonFmt<'a> {
259 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
260 let EscJsonFmt(buf) = self;
261 for c in buf.chars() {
262 match c {
263 '"' => f.write_str("\\\"")?, // quote
264 '\\' => f.write_str("\\\\")?, // backslash
265 '\x08' => f.write_str("\\b")?, // backspace
266 '\x0C' => f.write_str("\\f")?, // form feed
267 '\n' => f.write_str("\\n")?, // line feed
268 '\r' => f.write_str("\\r")?, // carriage return
269 '\t' => f.write_str("\\t")?, // horizontal tab
270 _ => f.write_char(c)?,
271 };
272 }
273 Ok(())
274 }
275}
276
277#[derive(Debug, Copy, Clone)]
278pub struct Digest<const N: usize> {
279 inner: [u8; N],
280}
281
282impl<const N: usize> ops::Index<usize> for Digest<N> {
283 type Output = u8;
284
285 fn index(&self, index: usize) -> &Self::Output {
286 &self.inner[index]
287 }
288}
289
290impl<const N: usize> Digest<N> {
291 pub fn as_bytes(&self) -> &[u8; N] {
292 &self.inner
293 }
294}
295
296impl<const N: usize> fmt::Display for Digest<N> {
297 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
298 HexFmt(&self.inner).fmt(f)
299 }
300}
301
302impl<const N: usize> Digest<N> {
303 pub fn new(data: [u8; N]) -> Self {
304 Self { inner: data }
305 }
306}
307
308impl<const N: usize> From<[u8; N]> for Digest<N> {
309 fn from(value: [u8; N]) -> Self {
310 Self::new(value)
311 }
312}
313
314impl<const N: usize> From<Digest<N>> for [u8; N] {
315 fn from(val: Digest<N>) -> Self {
316 val.inner
317 }
318}
319
320impl<const N: usize> FromStr for Digest<N> {
321 type Err = ParseHexError;
322
323 fn from_str(value: &str) -> Result<Self, Self::Err> {
324 let mut buf = [0u8; N];
325 parse_hex_exact(value, &mut buf)?;
326 Ok(buf.into())
327 }
328}
329
330impl<const N: usize> ConstantTimeEq for Digest<N> {
331 fn ct_eq(&self, other: &Self) -> subtle::Choice {
332 self.inner.ct_eq(&other.inner)
333 }
334}
335
336impl<const N: usize> PartialEq for Digest<N> {
337 fn eq(&self, other: &Self) -> bool {
338 self.ct_eq(other).into()
339 }
340}
341
342impl<const N: usize> Eq for Digest<N> {}
343
344impl<'de, const N: usize> Deserialize<'de> for Digest<N> {
345 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
346 where
347 D: serde::Deserializer<'de>,
348 {
349 let hex = <&str>::deserialize(deserializer)?;
350 Digest::from_str(hex).map_err(de::Error::custom)
351 }
352}
353
354impl<const N: usize> Serialize for Digest<N> {
355 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
356 where
357 S: serde::Serializer,
358 {
359 serializer.serialize_str(&format!("{self}"))
360 }
361}
362
363pub type Key = SafeByteArray<64>;
364
365pub fn load_key(path: &str) -> Result<Key, ReadHexError> {
366 let key_str = std::fs::read_to_string(path).map_err(ReadHexError::Io)?;
367 key_str.trim().parse().map_err(ReadHexError::Format)
368}