1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
|
mod api;
mod authz;
mod config;
mod dlimit;
mod handler;
use api::RepositoryName;
use config::Config;
use dlimit::DownloadLimiter;
use axum::{
extract::OriginalUri,
http::{self, Uri},
routing::{get, post},
Router, ServiceExt,
};
use handler::{handle_batch, handle_obj_download, AppState};
use std::{process::ExitCode, sync::Arc};
use tokio::net::TcpListener;
use tower::Layer;
/// Entry point: loads the configuration, builds the LFS router and serves it.
///
/// Exit codes:
/// * `2` — [`Config::load`] failed;
/// * [`ExitCode::FAILURE`] — binding the listener failed, or `axum::serve` returned an error;
/// * [`ExitCode::SUCCESS`] — `axum::serve` returned `Ok`.
///
/// Fatal errors are reported on stderr.
#[tokio::main]
async fn main() -> ExitCode {
    tracing_subscriber::fmt::init();

    let conf = match Config::load() {
        Ok(conf) => conf,
        Err(e) => {
            eprintln!("Error: {e}");
            return ExitCode::from(2);
        }
    };

    let dl_limiter = DownloadLimiter::new(conf.download_limit).await;
    let shared_state = Arc::new(AppState {
        s3_client: conf.s3_client,
        s3_bucket: conf.s3_bucket,
        authz_conf: conf.authz_conf,
        base_url: conf.base_url,
        dl_limiter,
    });

    // These routes match what is left of the path after `rewrite_url` strips
    // the `<repo>.git/info/lfs/objects` prefix.
    let app = Router::new()
        .route("/batch", post(handle_batch))
        .route("/:oid0/:oid1/:oid", get(handle_obj_download))
        .with_state(shared_state);
    // Wrap the whole Router instead of using `Router::layer`: middleware added
    // with `Router::layer` runs after routing, but the URI must be rewritten
    // before routes are matched.
    let middleware = axum::middleware::map_request(rewrite_url);
    let app_with_middleware = middleware.layer(app);

    let listener = match TcpListener::bind(conf.listen_addr).await {
        Ok(listener) => listener,
        Err(e) => {
            eprintln!("Failed to listen: {e}");
            return ExitCode::FAILURE;
        }
    };

    match axum::serve(listener, app_with_middleware.into_make_service()).await {
        Ok(_) => ExitCode::SUCCESS,
        Err(e) => {
            eprintln!("Error serving: {e}");
            ExitCode::FAILURE
        }
    }
}
/// Rewrites a Git LFS request URI so the router only sees the part after the repository.
///
/// The request path is split at its first `/info/lfs/objects`:
/// * everything before it is the repository name;
/// * everything after it, plus any query string, becomes the new request URI.
///
/// For example, `/group/project.git/info/lfs/objects/batch?x=1` is rewritten to `/batch?x=1`.
///
/// On success the request also carries two extensions:
/// * [`OriginalUri`] — the URI exactly as received, before rewriting;
/// * [`RepositoryName`] — the repository part with surrounding `/` trimmed
///   (`group/project.git` in the example above).
///
/// # Errors
///
/// * `400 Bad Request` — the URI has no path/query component.
/// * `404 Not Found` — any of the following:
///   * the path contains no `/info/lfs/objects`;
///   * the remainder after it does not start with `/`;
///   * the repository name does not end in `.git`.
/// * `500 Internal Server Error` — the rewritten URI could not be constructed.
async fn rewrite_url<B>(mut req: http::Request<B>) -> Result<http::Request<B>, http::StatusCode> {
    let uri = req.uri();
    let original_uri = OriginalUri(uri.clone());
    let Some(path_and_query) = uri.path_and_query() else {
        // Without a path there is no repository or object to address.
        return Err(http::StatusCode::BAD_REQUEST);
    };
    let Some((repo, path)) = path_and_query.path().split_once("/info/lfs/objects") else {
        return Err(http::StatusCode::NOT_FOUND);
    };
    let repo = repo.trim_start_matches('/').trim_end_matches('/');
    // `path` must begin a new segment (rejects e.g. `.../info/lfs/objectsX`).
    if !path.starts_with('/') || !repo.ends_with(".git") {
        return Err(http::StatusCode::NOT_FOUND);
    }
    // Take ownership now: `repo` borrows from `req`, which is mutated below.
    let repo_name = RepositoryName(repo.to_owned());

    let rewritten = match path_and_query.query() {
        None => path.try_into().ok(),
        Some(q) => format!("{path}?{q}").try_into().ok(),
    };
    // `path` is a suffix of an already-parsed path, so this is not expected to
    // fail. If it does, report it rather than silently routing an empty path.
    let Some(rewritten) = rewritten else {
        return Err(http::StatusCode::INTERNAL_SERVER_ERROR);
    };
    let mut parts = uri.clone().into_parts();
    parts.path_and_query = Some(rewritten);
    let Ok(new_uri) = Uri::from_parts(parts) else {
        return Err(http::StatusCode::INTERNAL_SERVER_ERROR);
    };

    *req.uri_mut() = new_uri;
    req.extensions_mut().insert(original_uri);
    req.extensions_mut().insert(repo_name);
    Ok(req)
}
|