aboutsummaryrefslogtreecommitdiffstats
path: root/gitolfs3-common/src
diff options
context:
space:
mode:
Diffstat (limited to 'gitolfs3-common/src')
-rw-r--r--gitolfs3-common/src/lib.rs348
1 files changed, 348 insertions, 0 deletions
diff --git a/gitolfs3-common/src/lib.rs b/gitolfs3-common/src/lib.rs
new file mode 100644
index 0000000..917f566
--- /dev/null
+++ b/gitolfs3-common/src/lib.rs
@@ -0,0 +1,348 @@
1use chrono::{DateTime, Utc};
2use serde::{de, Deserialize, Serialize};
3use std::{
4 fmt::{self, Write},
5 ops,
6 str::FromStr,
7};
8use subtle::ConstantTimeEq;
9
10#[repr(u8)]
11enum AuthType {
12 BatchApi = 1,
13 Download = 2,
14}
15
16#[derive(Debug, Copy, Clone)]
17pub struct Claims<'a> {
18 pub specific_claims: SpecificClaims,
19 pub repo_path: &'a str,
20 pub expires_at: DateTime<Utc>,
21}
22
23#[derive(Debug, Copy, Clone)]
24pub enum SpecificClaims {
25 BatchApi(Operation),
26 Download(Oid),
27}
28
29pub type Oid = Digest<32>;
30
31#[derive(Debug, Eq, PartialEq, Copy, Clone, Serialize, Deserialize)]
32#[repr(u8)]
33pub enum Operation {
34 #[serde(rename = "download")]
35 Download = 1,
36 #[serde(rename = "upload")]
37 Upload = 2,
38}
39
40/// Returns None if the claims are invalid. Repo path length may be no more than 100 bytes.
41pub fn generate_tag(claims: Claims, key: impl AsRef<[u8]>) -> Option<Digest<32>> {
42 if claims.repo_path.len() > 100 {
43 return None;
44 }
45
46 let mut hmac = hmac_sha256::HMAC::new(key);
47 match claims.specific_claims {
48 SpecificClaims::BatchApi(operation) => {
49 hmac.update([AuthType::BatchApi as u8]);
50 hmac.update([operation as u8]);
51 }
52 SpecificClaims::Download(oid) => {
53 hmac.update([AuthType::Download as u8]);
54 hmac.update(oid.as_bytes());
55 }
56 }
57 hmac.update([claims.repo_path.len() as u8]);
58 hmac.update(claims.repo_path.as_bytes());
59 hmac.update(claims.expires_at.timestamp().to_be_bytes());
60 Some(hmac.finalize().into())
61}
62
63#[derive(Debug, PartialEq, Eq, Copy, Clone)]
64pub struct ParseOperationError;
65
66impl fmt::Display for ParseOperationError {
67 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68 write!(f, "operation should be 'download' or 'upload'")
69 }
70}
71
72impl FromStr for Operation {
73 type Err = ParseOperationError;
74
75 fn from_str(s: &str) -> Result<Self, Self::Err> {
76 match s {
77 "upload" => Ok(Self::Upload),
78 "download" => Ok(Self::Download),
79 _ => Err(ParseOperationError),
80 }
81 }
82}
83
84/// None means out of range.
85fn decode_nibble(c: u8) -> Option<u8> {
86 if c.is_ascii_digit() {
87 Some(c - b'0')
88 } else if (b'a'..=b'f').contains(&c) {
89 Some(c - b'a' + 10)
90 } else if (b'A'..=b'F').contains(&c) {
91 Some(c - b'A' + 10)
92 } else {
93 None
94 }
95}
96
97#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
98pub struct HexByte(pub u8);
99
100impl<'de> Deserialize<'de> for HexByte {
101 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
102 where
103 D: serde::Deserializer<'de>,
104 {
105 let str = <&str>::deserialize(deserializer)?;
106 let &[b1, b2] = str.as_bytes() else {
107 return Err(de::Error::invalid_length(
108 str.len(),
109 &"two hexadecimal characters",
110 ));
111 };
112 let (Some(b1), Some(b2)) = (decode_nibble(b1), decode_nibble(b2)) else {
113 return Err(de::Error::invalid_value(
114 de::Unexpected::Str(str),
115 &"two hexadecimal characters",
116 ));
117 };
118 Ok(HexByte((b1 << 4) | b2))
119 }
120}
121
122impl fmt::Display for HexByte {
123 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
124 let &HexByte(b) = self;
125 HexFmt(&[b]).fmt(f)
126 }
127}
128
129#[derive(Debug, PartialEq, Eq, Copy, Clone)]
130pub enum ParseHexError {
131 UnevenNibbles,
132 InvalidCharacter,
133 TooShort,
134 TooLong,
135}
136
137impl fmt::Display for ParseHexError {
138 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
139 match self {
140 Self::UnevenNibbles => {
141 write!(f, "uneven amount of nibbles (chars in range [a-zA-Z0-9])")
142 }
143 Self::InvalidCharacter => write!(f, "non-hex character encountered"),
144 Self::TooShort => write!(f, "unexpected end of hex sequence"),
145 Self::TooLong => write!(f, "longer hex sequence than expected"),
146 }
147 }
148}
149
150#[derive(Debug)]
151pub enum ReadHexError {
152 Io(std::io::Error),
153 Format(ParseHexError),
154}
155
156impl fmt::Display for ReadHexError {
157 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
158 match self {
159 Self::Io(e) => e.fmt(f),
160 Self::Format(e) => e.fmt(f),
161 }
162 }
163}
164
165fn parse_hex_exact(value: &str, buf: &mut [u8]) -> Result<(), ParseHexError> {
166 if value.bytes().len() % 2 == 1 {
167 return Err(ParseHexError::UnevenNibbles);
168 }
169 if value.bytes().len() < 2 * buf.len() {
170 return Err(ParseHexError::TooShort);
171 }
172 if value.bytes().len() > 2 * buf.len() {
173 return Err(ParseHexError::TooLong);
174 }
175 for (i, c) in value.bytes().enumerate() {
176 if let Some(b) = decode_nibble(c) {
177 if i % 2 == 0 {
178 buf[i / 2] = b << 4;
179 } else {
180 buf[i / 2] |= b;
181 }
182 } else {
183 return Err(ParseHexError::InvalidCharacter);
184 }
185 }
186 Ok(())
187}
188
189pub type Key = SafeByteArray<64>;
190
191pub fn load_key(path: &str) -> Result<Key, ReadHexError> {
192 let key_str = std::fs::read_to_string(path).map_err(ReadHexError::Io)?;
193 key_str.trim().parse().map_err(ReadHexError::Format)
194}
195
196pub struct SafeByteArray<const N: usize> {
197 inner: [u8; N],
198}
199
200impl<const N: usize> SafeByteArray<N> {
201 pub fn new() -> Self {
202 Self { inner: [0; N] }
203 }
204}
205
206impl<const N: usize> Default for SafeByteArray<N> {
207 fn default() -> Self {
208 Self::new()
209 }
210}
211
212impl<const N: usize> AsRef<[u8]> for SafeByteArray<N> {
213 fn as_ref(&self) -> &[u8] {
214 &self.inner
215 }
216}
217
218impl<const N: usize> AsMut<[u8]> for SafeByteArray<N> {
219 fn as_mut(&mut self) -> &mut [u8] {
220 &mut self.inner
221 }
222}
223
224impl<const N: usize> Drop for SafeByteArray<N> {
225 fn drop(&mut self) {
226 self.inner.fill(0)
227 }
228}
229
230impl<const N: usize> FromStr for SafeByteArray<N> {
231 type Err = ParseHexError;
232
233 fn from_str(value: &str) -> Result<Self, Self::Err> {
234 let mut sba = Self { inner: [0u8; N] };
235 parse_hex_exact(value, &mut sba.inner)?;
236 Ok(sba)
237 }
238}
239
240pub struct HexFmt<B: AsRef<[u8]>>(pub B);
241
242impl<B: AsRef<[u8]>> fmt::Display for HexFmt<B> {
243 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
244 let HexFmt(buf) = self;
245 for b in buf.as_ref() {
246 let (high, low) = (b >> 4, b & 0xF);
247 let highc = if high < 10 {
248 high + b'0'
249 } else {
250 high - 10 + b'a'
251 };
252 let lowc = if low < 10 {
253 low + b'0'
254 } else {
255 low - 10 + b'a'
256 };
257 f.write_char(highc as char)?;
258 f.write_char(lowc as char)?;
259 }
260 Ok(())
261 }
262}
263
264#[derive(Debug, Copy, Clone)]
265pub struct Digest<const N: usize> {
266 inner: [u8; N],
267}
268
269impl<const N: usize> ops::Index<usize> for Digest<N> {
270 type Output = u8;
271
272 fn index(&self, index: usize) -> &Self::Output {
273 &self.inner[index]
274 }
275}
276
277impl<const N: usize> Digest<N> {
278 pub fn as_bytes(&self) -> &[u8; N] {
279 &self.inner
280 }
281}
282
283impl<const N: usize> fmt::Display for Digest<N> {
284 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
285 HexFmt(&self.inner).fmt(f)
286 }
287}
288
289impl<const N: usize> Digest<N> {
290 pub fn new(data: [u8; N]) -> Self {
291 Self { inner: data }
292 }
293}
294
295impl<const N: usize> From<[u8; N]> for Digest<N> {
296 fn from(value: [u8; N]) -> Self {
297 Self::new(value)
298 }
299}
300
301impl<const N: usize> From<Digest<N>> for [u8; N] {
302 fn from(val: Digest<N>) -> Self {
303 val.inner
304 }
305}
306
307impl<const N: usize> FromStr for Digest<N> {
308 type Err = ParseHexError;
309
310 fn from_str(value: &str) -> Result<Self, Self::Err> {
311 let mut buf = [0u8; N];
312 parse_hex_exact(value, &mut buf)?;
313 Ok(buf.into())
314 }
315}
316
317impl<const N: usize> ConstantTimeEq for Digest<N> {
318 fn ct_eq(&self, other: &Self) -> subtle::Choice {
319 self.inner.ct_eq(&other.inner)
320 }
321}
322
323impl<const N: usize> PartialEq for Digest<N> {
324 fn eq(&self, other: &Self) -> bool {
325 self.ct_eq(other).into()
326 }
327}
328
329impl<const N: usize> Eq for Digest<N> {}
330
331impl<'de, const N: usize> Deserialize<'de> for Digest<N> {
332 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
333 where
334 D: serde::Deserializer<'de>,
335 {
336 let hex = <&str>::deserialize(deserializer)?;
337 Digest::from_str(hex).map_err(de::Error::custom)
338 }
339}
340
341impl<const N: usize> Serialize for Digest<N> {
342 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
343 where
344 S: serde::Serializer,
345 {
346 serializer.serialize_str(&format!("{self}"))
347 }
348}