v3/backend/src/main.rs
Kalle Struik e3cee4e596
All checks were successful
/ Push Docker image to local registry (push) Successful in 3m9s
Fix really stupid oversight in the backend that caused all data to get messed up
We were saving all data into a single rocksdb which caused everything to
conflict with each other. This caused edits in one note to delete
content in all other notes.
2025-04-14 22:42:48 +02:00

485 lines
15 KiB
Rust

use base64::prelude::*;
use chrono::serde::ts_seconds;
use chrono::{DateTime, Utc};
use cookie::time::{Duration, OffsetDateTime};
use dotenvy::dotenv;
use futures_util::StreamExt;
use hmac::{Hmac, Mac};
use jwt::{SignWithKey, VerifyWithKey};
use openidconnect::{
core::{
CoreAuthDisplay, CoreAuthPrompt, CoreClient, CoreErrorResponseType, CoreGenderClaim,
CoreIdTokenVerifier, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm,
CoreProviderMetadata, CoreResponseType, CoreRevocableToken, CoreRevocationErrorResponse,
CoreTokenIntrospectionResponse, CoreTokenResponse,
},
reqwest, AuthenticationFlow, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken,
EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce,
RedirectUrl, Scope, StandardErrorResponse,
};
use rand::{distr::Alphanumeric, rng, Rng};
use rocksdb::TransactionDB;
use sha2::Sha256;
use std::fs;
use std::{
char,
collections::HashMap,
env,
path::PathBuf,
str::FromStr,
sync::{Arc, Weak},
};
use tokio::sync::{Mutex, RwLock};
use warp::{
http::{header, Response, StatusCode},
reject::Reject,
ws::{WebSocket, Ws},
Filter, Rejection, Reply,
};
use yrs::{sync::Awareness, Doc, Transact};
use yrs_kvstore::DocOps;
use yrs_rocksdb::RocksDBStore;
use yrs_warp::{
broadcast::BroadcastGroup,
ws::{WarpSink, WarpStream},
};
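// Web socket flow:
// -> Handle connection (ws_handler)
// -> Check permissions (not done for now; documents are only namespaced per user email)
// -> Get or load document (Server::get_or_create_doc)
// -> Subscribe the socket to the document's broadcast group (peer)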
/// Concrete OpenID Connect client type used by the server.
///
/// This is the type produced by `create_openid_client`: the `Core*` implementations with no
/// additional claims, plus the endpoint typestate parameters that
/// `CoreClient::from_provider_metadata` leaves behind (followed by `set_redirect_uri`).
type OpenidClient = Client<
    EmptyAdditionalClaims,
    CoreAuthDisplay,
    CoreGenderClaim,
    CoreJweContentEncryptionAlgorithm,
    CoreJsonWebKey,
    CoreAuthPrompt,
    StandardErrorResponse<CoreErrorResponseType>,
    CoreTokenResponse,
    CoreTokenIntrospectionResponse,
    CoreRevocableToken,
    CoreRevocationErrorResponse,
    EndpointSet,
    EndpointNotSet,
    EndpointNotSet,
    EndpointNotSet,
    EndpointMaybeSet,
    EndpointMaybeSet,
>;
/// Cookie holding the signed `UserToken` JWT that authenticates `/sync` WebSocket requests.
const COOKIE_AUTH_TOKEN: &str = "knotes_auth_token";
/// Cookie holding the OpenID Connect nonce between the login redirect and the callback.
const COOKIE_AUTH_NONCE: &str = "knotes_auth_nonce";
/// Cookie holding the CSRF token that must match the one inside the signed login `state`.
const COOKIE_AUTH_CSRF: &str = "knotes_auth_csrf";
/// Claims of the signed JWT stored in the `COOKIE_AUTH_TOKEN` cookie after a successful login.
///
/// Created and signed in `handle_auth_callback`; verified with the server's HMAC key in
/// `ws_handler`.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct UserToken {
    /// Given name reported by the identity provider; used in log messages.
    display_name: String,
    /// Email reported by the identity provider; also namespaces the user's documents
    /// (document ids are `<email>/<name>`).
    email: String,
    // Unused for now, might be checked later to allow some basic session management.
    // Issued At Time, serialized as Unix seconds.
    #[allow(dead_code)]
    #[serde(with = "ts_seconds")]
    iat: DateTime<Utc>,
    // JWT Id: a random 32-character alphanumeric string; also unused for now.
    #[allow(dead_code)]
    jti: String,
}
/// Rejection returned for every failure in the login / callback flow.
///
/// It carries no details on purpose: the client only learns that authentication failed.
#[derive(Debug)]
struct AuthError;
impl Reject for AuthError {}
/// A document that is currently open, shared by every WebSocket peer connected to it.
///
/// Each peer holds an `Arc<Connection>` while `Server::open_docs` only keeps a `Weak`, so the
/// connection is dropped once the last peer is gone. Fields are dropped in declaration order.
struct Connection {
    /// Broadcast group the peers subscribe to (see `peer`).
    bcast: BroadcastGroup,
    /// Handle to this document's own RocksDB database; the update observer holds another clone.
    db: Arc<TransactionDB>,
    // Subscription handle returned by `Doc::observe_update_v1` in `Server::get_or_create_doc`.
    // It is stored only for its destructor: keeping it alive keeps the observer that persists
    // updates to `db` registered (yrs unsubscribes when the handle is dropped). The `dyn_drop`
    // lint warns that `dyn Drop` is rarely meaningful, but this is the handle type the call
    // returns.
    #[allow(dyn_drop)]
    _db_sub: Arc<dyn Drop + Send + Sync>,
}
/// Shared application state handed to every request handler.
struct Server {
    // There is something to be said for keeping the broadcast group in memory for a bit after all
    // clients disconnect, but for now we don't bother.
    /// Documents that are currently open, keyed by document id (`<email>/<name>`). Values are
    /// `Weak` so this map alone never keeps a document open after its last peer disconnects.
    pub open_docs: RwLock<HashMap<String, Weak<Connection>>>,
    /// HTTP client used for OpenID Connect requests (built in `main` with redirects disabled).
    pub http_client: reqwest::Client,
    /// Client for the configured OpenID Connect provider.
    pub openid_client: OpenidClient,
    /// HMAC-SHA256 key used to sign and verify both the user token and the login `state`.
    pub signing_key: Hmac<Sha256>,
    /// Root data directory; every document gets its own RocksDB database in a subdirectory.
    pub data_dir: PathBuf,
}
impl Server {
    /// Creates a server with no documents open yet.
    pub fn new(
        http_client: reqwest::Client,
        openid_client: OpenidClient,
        signing_key: Hmac<Sha256>,
        data_dir: PathBuf,
    ) -> Self {
        Self {
            open_docs: RwLock::default(),
            http_client,
            openid_client,
            signing_key,
            data_dir,
        }
    }

    /// Returns the open connection for document `name`, loading it from disk if no peer
    /// currently has it open.
    ///
    /// Every document lives in its own RocksDB database in `data_dir/<name>`. Changes to the
    /// in-memory document are appended to that database by an update observer, which compacts
    /// the stored updates into the document every 128 updates.
    ///
    /// # Panics
    ///
    /// Panics if the document directory cannot be created, the database cannot be opened, the
    /// stored document cannot be loaded, or the update observer cannot be registered. The
    /// observer itself panics if persisting an update fails.
    pub async fn get_or_create_doc(&self, name: String) -> Arc<Connection> {
        // Fast path: the document is usually already open, which only needs the read lock.
        let existing = self.open_docs.read().await.get(&name).and_then(Weak::upgrade);
        if let Some(connection) = existing {
            return connection;
        }

        let mut open_docs = self.open_docs.write().await;
        // Another task may have opened the document between releasing the read lock and
        // acquiring the write lock. Opening the same RocksDB directory a second time would fail
        // on its lock file (and panic below), so reuse that connection instead.
        if let Some(connection) = open_docs.get(&name).and_then(Weak::upgrade) {
            return connection;
        }
        // Forget documents whose last peer has disconnected so the map does not grow forever.
        open_docs.retain(|_, connection| connection.strong_count() > 0);

        let data_dir = self.data_dir.join(&name);
        if let Err(e) = fs::create_dir_all(&data_dir) {
            panic!("Was unable to create data directory for note {}, due to the following error. Something is very wrong!\n{}", name, e)
        }
        let db = Arc::new(TransactionDB::open_default(data_dir).expect("Failed to open DB"));

        let doc = Doc::new();
        {
            // Load the document from the DB *before* subscribing the persistence observer below;
            // otherwise the state just read from disk would be written back as a brand-new
            // update every time the document is opened.
            let mut txn = doc.transact_mut();
            let db_txn = RocksDBStore::from(db.transaction());
            db_txn.load_doc(&name, &mut txn).unwrap();
        }

        // Subscribe the DB to updates
        let sub = {
            let db = db.clone();
            let name = name.clone();
            doc.observe_update_v1(move |_, e| {
                let txn = RocksDBStore::from(db.transaction());
                let i = txn.push_update(&name, &e.update).unwrap();
                if i % 128 == 0 {
                    // compact updates into document
                    txn.flush_doc(&name).unwrap();
                }
                txn.commit().unwrap();
            })
            .unwrap()
        };

        let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
        let group = BroadcastGroup::new(awareness, 32).await;
        let connection = Arc::new(Connection {
            bcast: group,
            db,
            _db_sub: sub,
        });
        open_docs.insert(name, Arc::downgrade(&connection));
        connection
    }
}
/// Builds the OpenID Connect client for the configured identity provider.
///
/// The provider's metadata is discovered from `issuer_url` using `http_client`. The resulting
/// client authenticates with `client_id` / `client_secret` and sends users back to
/// `redirect_url` after they log in.
///
/// # Panics
///
/// Panics if provider discovery fails, because the server cannot authenticate anyone without it.
async fn create_openid_client(
    http_client: &reqwest::Client,
    issuer_url: IssuerUrl,
    client_id: ClientId,
    client_secret: ClientSecret,
    redirect_url: RedirectUrl,
) -> OpenidClient {
    let discovery = CoreProviderMetadata::discover_async(issuer_url, http_client);
    let metadata = discovery
        .await
        .expect("Failed to discover OpenID Connect provider metadata");

    let client = CoreClient::from_provider_metadata(metadata, client_id, Some(client_secret));
    client.set_redirect_uri(redirect_url)
}
#[tokio::main]
async fn main() {
// Allow loading .env to fail, since it won't be available in docker.
let _ = dotenv();
let data_dir = env::var("DATA_DIR").expect("DATA_DIR not set");
let data_dir = PathBuf::from_str(&data_dir).expect("DATA_DIR is not a valid path");
let frontend_dir = env::var("FRONTEND_DIR").expect("FRONTEND_DIR not set");
let frontend_dir = PathBuf::from_str(&frontend_dir).expect("FRONTEND_DIR is not a valid path");
let index_file = frontend_dir.join("index.html");
let signing_key = env::var("AUTH_SECRET").expect("AUTH_SECRET not set");
let signing_key =
Hmac::new_from_slice(signing_key.as_bytes()).expect("AUTH_SECRET is not a valid HMAC key");
let app_url = env::var("APP_URL").expect("APP_URL not set");
let http_client = reqwest::ClientBuilder::new()
// Following redirects opens the client up to SSRF vulnerabilities.
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("Failed to create HTTP client");
let openid_client = create_openid_client(
&http_client,
IssuerUrl::new(env::var("AUTH_ISSUER_URL").expect("AUTH_ISSUER_URL not set"))
.expect("Issuer URL invalid"),
ClientId::new(env::var("AUTH_CLIENT_ID").expect("AUTH_CLIENT_ID not set")),
ClientSecret::new(env::var("AUTH_CLIENT_SECRET").expect("AUTH_CLIENT_SECRET not set")),
RedirectUrl::new(format!("{}/api/auth/callback", app_url))
.expect("Redirect URL is invalid (app url?)"),
)
.await;
let server = Arc::new(Server::new(
http_client,
openid_client,
signing_key,
data_dir,
));
let ws = {
let server = server.clone();
warp::path("sync")
.and(warp::path::param())
.and(warp::ws())
.and(warp::any().map(move || server.clone()))
.and(warp::cookie(COOKIE_AUTH_TOKEN))
.and_then(ws_handler)
};
let api_routes = {
let login_route = {
let server = server.clone();
warp::path!("api" / "auth" / "login")
.map(move || server.clone())
.and(warp::query())
.and_then(handle_auth_login)
};
let callback_route = {
let server = server.clone();
warp::path!("api" / "auth" / "callback")
.map(move || server.clone())
.and(warp::cookie(COOKIE_AUTH_NONCE))
.and(warp::cookie(COOKIE_AUTH_CSRF))
.and(warp::query())
.and_then(handle_auth_callback)
};
login_route.or(callback_route)
};
let frontend_files = warp::fs::dir(frontend_dir);
let index = warp::fs::file(index_file);
let routes = api_routes.or(ws).or(frontend_files).or(index);
println!("Starting server on http://0.0.0.0:9000");
warp::serve(routes).run(([0, 0, 0, 0], 9000)).await;
}
/// Returns `true` if `segment` may be used as a single directory name below `DATA_DIR`.
///
/// Rejects empty strings, the special `.` and `..` entries, path separators (`/` and `\`) and
/// control characters, none of which can be a plain, self-contained directory name.
fn is_safe_path_segment(segment: &str) -> bool {
    !segment.is_empty()
        && segment != "."
        && segment != ".."
        && !segment
            .chars()
            .any(|c| c == '/' || c == '\\' || c.is_control())
}

/// Handles `/sync/<name>`: authenticates the user from the auth token cookie and upgrades the
/// request to a WebSocket attached to that user's document `name`.
///
/// Documents are namespaced per user: the document id is `<email>/<name>`, which
/// `Server::get_or_create_doc` also uses as the database directory below `DATA_DIR`.
///
/// # Errors
///
/// Rejects the request if the token cookie does not verify, or if the document name or the
/// email cannot safely be used as a directory name.
async fn ws_handler(
    name: String,
    ws: Ws,
    server: Arc<Server>,
    token: String,
) -> Result<impl Reply, Rejection> {
    let token: UserToken = token
        .verify_with_key(&server.signing_key)
        .map_err(|_| warp::reject())?;
    // Both parts end up in a filesystem path. Without this check a crafted request such as
    // `/sync/..` or `/sync/.` would open a database outside the document's own directory
    // (e.g. directly in DATA_DIR, or in the directory holding all of the user's notes).
    if !is_safe_path_segment(&name) || !is_safe_path_segment(&token.email) {
        println!(
            "[WARN] User {} <{}> requested invalid document name {:?}",
            token.display_name, token.email, name
        );
        return Err(warp::reject());
    }
    println!(
        "[INFO] User {} <{}> connected to document {}",
        token.display_name, token.email, name
    );
    let doc_id = format!("{}/{}", token.email, name);
    let connection = server.get_or_create_doc(doc_id).await;
    Ok(ws.on_upgrade(move |socket| peer(token, name, socket, connection)))
}
/// Serves a single WebSocket client of document `doc_name` until its subscription ends.
///
/// The socket is split into its sending and receiving halves, which are subscribed to the
/// document's broadcast group. `token` is only used for the disconnect log line.
async fn peer(token: UserToken, doc_name: String, ws: WebSocket, connection: Arc<Connection>) {
    let (ws_tx, ws_rx) = ws.split();
    let subscription = connection.bcast.subscribe(
        Arc::new(Mutex::new(WarpSink::from(ws_tx))),
        WarpStream::from(ws_rx),
    );

    // How the subscription ended does not matter here; either way the peer is gone.
    let _ = subscription.completed().await;

    println!(
        "[INFO] User {} <{}> disconnected from document {}",
        token.display_name, token.email, doc_name
    );
}
/// Payload of the OAuth `state` parameter.
///
/// `handle_auth_login` signs it with the server key; `handle_auth_callback` verifies it.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct AuthState {
    /// Where to send the browser after a successful login, as requested at login time.
    redirect_url: Option<String>,
    /// Must equal the `COOKIE_AUTH_CSRF` cookie value when the callback arrives.
    csrf_token: String,
}
/// Query parameters of `/api/auth/login`.
#[derive(Debug, serde::Deserialize)]
struct AuthLoginParams {
    /// Optional location to return to once the login completes.
    redirect_url: Option<String>,
}
/// Starts the OpenID Connect login flow by redirecting the browser to the identity provider.
///
/// The optional `redirect_url` is carried through the OAuth `state` parameter, together with a
/// fresh CSRF token. The state is signed with the server key so the callback can trust it. The
/// nonce and the CSRF token are stored in HTTP-only cookies for `handle_auth_callback` to check.
///
/// # Errors
///
/// Rejects with `AuthError` if the state cannot be signed.
async fn handle_auth_login(
    server: Arc<Server>,
    params: AuthLoginParams,
) -> Result<impl Reply, Rejection> {
    let csrf = CsrfToken::new_random();

    let signed_state = AuthState {
        redirect_url: params.redirect_url,
        csrf_token: csrf.secret().clone(),
    }
    .sign_with_key(&server.signing_key)
    .map_err(|_| warp::reject::custom(AuthError))?;
    let state = CsrfToken::new(BASE64_URL_SAFE.encode(signed_state));

    // `authorize_url` returns the "CSRF token" it was given, i.e. `state` itself, so it can be
    // ignored; the real CSRF token is `csrf`, which goes into a cookie below.
    let initial_request = server.openid_client.authorize_url(
        AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
        move || state,
        Nonce::new_random,
    );
    let (auth_url, _state, nonce) = ["openid", "profile", "email"]
        .iter()
        .fold(initial_request, |request, scope| {
            request.add_scope(Scope::new(scope.to_string()))
        })
        .url();

    // Short-lived, site-wide, HTTP-only cookie used only during the login round trip.
    let session_cookie = |name: &str, value: &str| {
        cookie::Cookie::build((name, value))
            .path("/")
            .http_only(true)
            .build()
            .to_string()
    };

    let response = Response::builder()
        .header(
            header::SET_COOKIE,
            session_cookie(COOKIE_AUTH_NONCE, nonce.secret().as_str()),
        )
        .header(
            header::SET_COOKIE,
            session_cookie(COOKIE_AUTH_CSRF, csrf.secret().as_str()),
        )
        .header(header::LOCATION, auth_url.as_str())
        .status(StatusCode::FOUND)
        .body("Login")
        .unwrap();
    Ok(response)
}
/// Query parameters received on `/api/auth/callback`, the redirect URL registered with the
/// identity provider.
#[derive(Debug, serde::Deserialize)]
struct AuthCallbackParams {
    /// Authorization code to exchange for tokens.
    code: String,
    /// The base64url-encoded, signed `AuthState` created by `handle_auth_login`.
    state: String,
}
async fn handle_auth_callback(
server: Arc<Server>,
nonce: String,
csrf_token: String,
params: AuthCallbackParams,
) -> Result<impl Reply, Rejection> {
let code = AuthorizationCode::new(params.code);
let nonce = Nonce::new(nonce);
let state: AuthState = String::from_utf8(
BASE64_URL_SAFE
.decode(params.state)
.map_err(|_| warp::reject::custom(AuthError))?,
)
.map_err(|_| warp::reject::custom(AuthError))?
.verify_with_key(&server.signing_key)
.map_err(|_| warp::reject::custom(AuthError))?;
if state.csrf_token != csrf_token {
return Err(warp::reject::custom(AuthError));
}
let token_response = server
.openid_client
.exchange_code(code)
.map_err(|_| warp::reject::custom(AuthError))?
.request_async(&server.http_client)
.await
.map_err(|_| warp::reject::custom(AuthError))?;
let id_token_verifier: CoreIdTokenVerifier = server.openid_client.id_token_verifier();
let id_token_claims = token_response
.extra_fields()
.id_token()
.ok_or(warp::reject::custom(AuthError))?
.claims(&id_token_verifier, &nonce)
.map_err(|_| warp::reject::custom(AuthError))?;
let token_id: String = rng()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect();
let display_name = id_token_claims
.given_name()
.ok_or(warp::reject::custom(AuthError))?
.get(None)
.ok_or(warp::reject::custom(AuthError))?
.to_string();
let email = id_token_claims
.email()
.ok_or(warp::reject::custom(AuthError))?
.to_string();
println!("[INFO] Authenticated: {} <{}>", display_name, email);
let user_token = UserToken {
display_name,
email,
iat: Utc::now(),
jti: token_id,
};
let user_token = user_token
.sign_with_key(&server.signing_key)
.map_err(|_| warp::reject::custom(AuthError))?;
Ok(Response::builder()
.header(
header::SET_COOKIE,
cookie::Cookie::build((COOKIE_AUTH_NONCE, ""))
.path("/")
.http_only(true)
.build()
.to_string(),
)
.header(
header::SET_COOKIE,
cookie::Cookie::build((COOKIE_AUTH_CSRF, ""))
.expires(OffsetDateTime::UNIX_EPOCH)
.path("/")
.http_only(true)
.build()
.to_string(),
)
.header(
header::SET_COOKIE,
cookie::Cookie::build((COOKIE_AUTH_TOKEN, user_token))
.expires(OffsetDateTime::now_utc() + Duration::days(365))
.path("/")
.build()
.to_string(),
)
.header(
header::LOCATION,
state.redirect_url.unwrap_or("/app".to_string()),
)
.status(StatusCode::FOUND)
.body("")
.unwrap())
}