aboutsummaryrefslogtreecommitdiffstats
path: root/rs/common/src
diff options
context:
space:
mode:
Diffstat (limited to 'rs/common/src')
-rw-r--r--rs/common/src/lib.rs337
1 files changed, 337 insertions, 0 deletions
diff --git a/rs/common/src/lib.rs b/rs/common/src/lib.rs
new file mode 100644
index 0000000..aafe7f1
--- /dev/null
+++ b/rs/common/src/lib.rs
@@ -0,0 +1,337 @@
1use chrono::{DateTime, Utc};
2use serde::de;
3use serde::{Deserialize, Serialize};
4use std::fmt::Write;
5use std::ops;
6use std::{fmt, str::FromStr};
7use subtle::ConstantTimeEq;
8
9#[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)]
10#[repr(u8)]
11pub enum Operation {
12 #[serde(rename = "download")]
13 Download = 1,
14 #[serde(rename = "upload")]
15 Upload = 2,
16}
17
18#[derive(Debug, PartialEq, Eq, Copy, Clone)]
19pub struct ParseOperationError;
20
21impl fmt::Display for ParseOperationError {
22 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
23 write!(f, "operation should be 'download' or 'upload'")
24 }
25}
26
27impl FromStr for Operation {
28 type Err = ParseOperationError;
29
30 fn from_str(s: &str) -> Result<Self, Self::Err> {
31 match s {
32 "upload" => Ok(Self::Upload),
33 "download" => Ok(Self::Download),
34 _ => Err(ParseOperationError),
35 }
36 }
37}
38
39#[repr(u8)]
40pub enum AuthType {
41 GitLfsAuthenticate = 1,
42}
43
44/// None means out of range.
45fn decode_nibble(c: u8) -> Option<u8> {
46 if c.is_ascii_digit() {
47 Some(c - b'0')
48 } else if (b'a'..=b'f').contains(&c) {
49 Some(c - b'a' + 10)
50 } else if (b'A'..=b'F').contains(&c) {
51 Some(c - b'A' + 10)
52 } else {
53 None
54 }
55}
56
57#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
58pub struct HexByte(pub u8);
59
60impl<'de> Deserialize<'de> for HexByte {
61 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
62 where
63 D: serde::Deserializer<'de>,
64 {
65 let str = <&str>::deserialize(deserializer)?;
66 let &[b1, b2] = str.as_bytes() else {
67 return Err(de::Error::invalid_length(
68 str.len(),
69 &"two hexadecimal characters",
70 ));
71 };
72 let (Some(b1), Some(b2)) = (decode_nibble(b1), decode_nibble(b2)) else {
73 return Err(de::Error::invalid_value(
74 de::Unexpected::Str(str),
75 &"two hexadecimal characters",
76 ));
77 };
78 Ok(HexByte((b1 << 4) | b2))
79 }
80}
81
82impl fmt::Display for HexByte {
83 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
84 let &HexByte(b) = self;
85 HexFmt(&[b]).fmt(f)
86 }
87}
88
89#[derive(Debug, PartialEq, Eq, Copy, Clone)]
90pub enum ParseHexError {
91 UnevenNibbles,
92 InvalidCharacter,
93 TooShort,
94 TooLong,
95}
96
97impl fmt::Display for ParseHexError {
98 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
99 match self {
100 Self::UnevenNibbles => {
101 write!(f, "uneven amount of nibbles (chars in range [a-zA-Z0-9])")
102 }
103 Self::InvalidCharacter => write!(f, "non-hex character encountered"),
104 Self::TooShort => write!(f, "unexpected end of hex sequence"),
105 Self::TooLong => write!(f, "longer hex sequence than expected"),
106 }
107 }
108}
109
110#[derive(Debug)]
111pub enum ReadHexError {
112 Io(std::io::Error),
113 Format(ParseHexError),
114}
115
116impl fmt::Display for ReadHexError {
117 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
118 match self {
119 Self::Io(e) => e.fmt(f),
120 Self::Format(e) => e.fmt(f),
121 }
122 }
123}
124
125fn parse_hex_exact(value: &str, buf: &mut [u8]) -> Result<(), ParseHexError> {
126 if value.bytes().len() % 2 == 1 {
127 return Err(ParseHexError::UnevenNibbles);
128 }
129 if value.bytes().len() < 2 * buf.len() {
130 return Err(ParseHexError::TooShort);
131 }
132 if value.bytes().len() > 2 * buf.len() {
133 return Err(ParseHexError::TooLong);
134 }
135 for (i, c) in value.bytes().enumerate() {
136 if let Some(b) = decode_nibble(c) {
137 if i % 2 == 0 {
138 buf[i / 2] |= b;
139 } else {
140 buf[i / 2] = b << 4;
141 }
142 } else {
143 return Err(ParseHexError::InvalidCharacter);
144 }
145 }
146 Ok(())
147}
148
149pub struct SafeByteArray<const N: usize> {
150 inner: [u8; N],
151}
152
153impl<const N: usize> SafeByteArray<N> {
154 pub fn new() -> Self {
155 Self { inner: [0; N] }
156 }
157}
158
159impl<const N: usize> AsRef<[u8]> for SafeByteArray<N> {
160 fn as_ref(&self) -> &[u8] {
161 &self.inner
162 }
163}
164
165impl<const N: usize> AsMut<[u8]> for SafeByteArray<N> {
166 fn as_mut(&mut self) -> &mut [u8] {
167 &mut self.inner
168 }
169}
170
171impl<const N: usize> Drop for SafeByteArray<N> {
172 fn drop(&mut self) {
173 self.inner.fill(0)
174 }
175}
176
177impl<const N: usize> FromStr for SafeByteArray<N> {
178 type Err = ParseHexError;
179
180 fn from_str(value: &str) -> Result<Self, Self::Err> {
181 let mut sba = Self { inner: [0u8; N] };
182 parse_hex_exact(value, &mut sba.inner)?;
183 Ok(sba)
184 }
185}
186
187pub struct Claims<'a> {
188 pub auth_type: AuthType,
189 pub repo_path: &'a str,
190 pub operation: Operation,
191 pub expires_at: DateTime<Utc>,
192}
193
194/// Returns None if the claims are invalid. Repo path length may be no more than 100 bytes.
195pub fn generate_tag(claims: Claims, key: impl AsRef<[u8]>) -> Option<Digest<32>> {
196 if claims.repo_path.len() > 100 {
197 return None;
198 }
199
200 let mut hmac = hmac_sha256::HMAC::new(key);
201 hmac.update([claims.auth_type as u8]);
202 hmac.update([claims.repo_path.len() as u8]);
203 hmac.update(claims.repo_path.as_bytes());
204 hmac.update([claims.operation as u8]);
205 hmac.update(claims.expires_at.timestamp().to_be_bytes());
206 Some(hmac.finalize().into())
207}
208
209pub struct HexFmt<B: AsRef<[u8]>>(pub B);
210
211impl<B: AsRef<[u8]>> fmt::Display for HexFmt<B> {
212 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
213 let HexFmt(buf) = self;
214 for b in buf.as_ref() {
215 let (high, low) = (b >> 4, b & 0xF);
216 let highc = if high < 10 { b'0' + high } else { b'a' + high };
217 let lowc = if low < 10 { b'0' + low } else { b'a' + low };
218 f.write_char(highc as char)?;
219 f.write_char(lowc as char)?;
220 }
221 Ok(())
222 }
223}
224
225pub struct EscJsonFmt<'a>(pub &'a str);
226
227impl<'a> fmt::Display for EscJsonFmt<'a> {
228 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
229 let EscJsonFmt(buf) = self;
230 for c in buf.chars() {
231 match c {
232 '"' => f.write_str("\\\"")?, // quote
233 '\\' => f.write_str("\\\\")?, // backslash
234 '\x08' => f.write_str("\\b")?, // backspace
235 '\x0C' => f.write_str("\\f")?, // form feed
236 '\n' => f.write_str("\\n")?, // line feed
237 '\r' => f.write_str("\\r")?, // carriage return
238 '\t' => f.write_str("\\t")?, // horizontal tab
239 _ => f.write_char(c)?,
240 };
241 }
242 Ok(())
243 }
244}
245
246#[derive(Debug, Copy, Clone)]
247pub struct Digest<const N: usize> {
248 inner: [u8; N],
249}
250
251impl<const N: usize> ops::Index<usize> for Digest<N> {
252 type Output = u8;
253
254 fn index(&self, index: usize) -> &Self::Output {
255 &self.inner[index]
256 }
257}
258
259impl<const N: usize> Digest<N> {
260 pub fn as_bytes(&self) -> &[u8; N] {
261 &self.inner
262 }
263}
264
265impl<const N: usize> fmt::Display for Digest<N> {
266 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
267 HexFmt(&self.inner).fmt(f)
268 }
269}
270
271impl<const N: usize> Digest<N> {
272 pub fn new(data: [u8; N]) -> Self {
273 Self { inner: data }
274 }
275}
276
277impl<const N: usize> From<[u8; N]> for Digest<N> {
278 fn from(value: [u8; N]) -> Self {
279 Self::new(value)
280 }
281}
282
283impl<const N: usize> Into<[u8; N]> for Digest<N> {
284 fn into(self) -> [u8; N] {
285 self.inner
286 }
287}
288
289impl<const N: usize> FromStr for Digest<N> {
290 type Err = ParseHexError;
291
292 fn from_str(value: &str) -> Result<Self, Self::Err> {
293 let mut buf = [0u8; N];
294 parse_hex_exact(value, &mut buf)?;
295 Ok(buf.into())
296 }
297}
298
299impl<const N: usize> ConstantTimeEq for Digest<N> {
300 fn ct_eq(&self, other: &Self) -> subtle::Choice {
301 self.inner.ct_eq(&other.inner)
302 }
303}
304
305impl<const N: usize> PartialEq for Digest<N> {
306 fn eq(&self, other: &Self) -> bool {
307 self.ct_eq(&other).into()
308 }
309}
310
311impl<const N: usize> Eq for Digest<N> {}
312
313impl<'de, const N: usize> Deserialize<'de> for Digest<N> {
314 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
315 where
316 D: serde::Deserializer<'de>,
317 {
318 let hex = <&str>::deserialize(deserializer)?;
319 Digest::from_str(hex).map_err(de::Error::custom)
320 }
321}
322
323impl<const N: usize> Serialize for Digest<N> {
324 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
325 where
326 S: serde::Serializer,
327 {
328 serializer.serialize_str(&format!("{self}"))
329 }
330}
331
332pub type Key = SafeByteArray<64>;
333
334pub fn load_key(path: &str) -> Result<Key, ReadHexError> {
335 let key_str = std::fs::read_to_string(path).map_err(ReadHexError::Io)?;
336 key_str.trim().parse().map_err(ReadHexError::Format)
337}