use base64::prelude::*;
use chrono::serde::ts_seconds;
use chrono::{DateTime, Utc};
use cookie::time::{Duration, OffsetDateTime};
use dotenvy::dotenv;
use futures_util::StreamExt;
use hmac::{Hmac, Mac};
use jwt::{SignWithKey, VerifyWithKey};
use openidconnect::{
    core::{
        CoreAuthDisplay, CoreAuthPrompt, CoreClient, CoreErrorResponseType, CoreGenderClaim,
        CoreIdTokenVerifier, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm,
        CoreProviderMetadata, CoreResponseType, CoreRevocableToken, CoreRevocationErrorResponse,
        CoreTokenIntrospectionResponse, CoreTokenResponse,
    },
    reqwest, AuthenticationFlow, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken,
    EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce,
    RedirectUrl, Scope, StandardErrorResponse,
};
use rand::{distr::Alphanumeric, rng, Rng};
use rocksdb::TransactionDB;
use sha2::Sha256;
use std::fs;
use std::{
    char,
    collections::HashMap,
    env,
    path::PathBuf,
    str::FromStr,
    sync::{Arc, Weak},
};
use tokio::sync::{Mutex, RwLock};
use warp::{
    http::{header, Response, StatusCode},
    reject::Reject,
    ws::{WebSocket, Ws},
    Filter, Rejection, Reply,
};
use yrs::{sync::Awareness, Doc, Transact};
use yrs_kvstore::DocOps;
use yrs_rocksdb::RocksDBStore;
use yrs_warp::{
    broadcast::BroadcastGroup,
    ws::{WarpSink, WarpStream},
};

// Web socket flow:
// -> Handle connection (ws_handler)
// -> Check permissions (not done for now)
// -> Get or load document (Server::get_or_create_doc)
// -> Subscribe the socket to the document's broadcast group (peer)

/// Concrete OpenID Connect client type used throughout this file.
///
/// This is the type returned by `create_openid_client`, i.e. a `CoreClient`
/// built from discovered provider metadata with a redirect URI set. The
/// trailing `Endpoint*` parameters are type-state markers describing which
/// provider endpoints are known to be configured.
type OpenidClient = Client<
    EmptyAdditionalClaims,
    CoreAuthDisplay,
    CoreGenderClaim,
    CoreJweContentEncryptionAlgorithm,
    CoreJsonWebKey,
    CoreAuthPrompt,
    StandardErrorResponse<CoreErrorResponseType>,
    CoreTokenResponse,
    CoreTokenIntrospectionResponse,
    CoreRevocableToken,
    CoreRevocationErrorResponse,
    EndpointSet,
    EndpointNotSet,
    EndpointNotSet,
    EndpointNotSet,
    EndpointMaybeSet,
    EndpointMaybeSet,
>;

/// Cookie carrying the signed `UserToken`; set by the auth callback and
/// required by the `sync` web socket route.
const COOKIE_AUTH_TOKEN: &str = "knotes_auth_token";
/// Cookie holding the OpenID Connect nonce between the login redirect and the
/// auth callback.
const COOKIE_AUTH_NONCE: &str = "knotes_auth_nonce";
/// Cookie holding the CSRF token that the auth callback compares against the
/// signed `state` parameter.
const COOKIE_AUTH_CSRF: &str = "knotes_auth_csrf";
#[derive(Debug, serde::Serialize, serde::Deserialize)] struct UserToken { display_name: String, email: String, // Unused for now, might be checked later do allow some basic session management. // Issued At Time #[allow(dead_code)] #[serde(with = "ts_seconds")] iat: DateTime, // JWT Id #[allow(dead_code)] jti: String, } #[derive(Debug)] struct AuthError; impl Reject for AuthError {} struct Connection { bcast: BroadcastGroup, db: Arc, #[allow(dyn_drop)] _db_sub: Arc, } struct Server { // There is something to be said for keeping the broadcast group in memory for a bit after all // clients disconnect, but for now we don't bother. pub open_docs: RwLock>>, pub http_client: reqwest::Client, pub openid_client: OpenidClient, pub signing_key: Hmac, pub data_dir: PathBuf, } impl Server { pub fn new( http_client: reqwest::Client, openid_client: OpenidClient, signing_key: Hmac, data_dir: PathBuf, ) -> Self { Self { open_docs: RwLock::default(), http_client, openid_client, signing_key, data_dir, } } pub async fn get_or_create_doc(&self, name: String) -> Arc { let open_docs = self.open_docs.read().await; match open_docs.get(&name).and_then(Weak::upgrade) { Some(connection) => connection.clone(), None => { drop(open_docs); let mut open_docs = self.open_docs.write().await; let doc = Doc::new(); let data_dir = self.data_dir.join(name.clone()); if let Err(e) = fs::create_dir_all(&data_dir) { panic!("Was unable to create data directory for note {}, due to the following error. 
Something is very wrong!\n{}", name, e) } let db = Arc::new(TransactionDB::open_default(data_dir).expect("Failed to open DB")); // Subscribe the DB to updates let sub = { let db = db.clone(); let name = name.clone(); doc.observe_update_v1(move |_, e| { let txn = RocksDBStore::from(db.transaction()); let i = txn.push_update(&name, &e.update).unwrap(); if i % 128 == 0 { // compact updates into document txn.flush_doc(&name).unwrap(); } txn.commit().unwrap(); }) .unwrap() }; { // Load document from DB let mut txn = doc.transact_mut(); let db_txn = RocksDBStore::from(db.transaction()); db_txn.load_doc(&name, &mut txn).unwrap(); } let awareness = Arc::new(RwLock::new(Awareness::new(doc))); let group = BroadcastGroup::new(awareness, 32).await; let connection = Arc::new(Connection { bcast: group, db, _db_sub: sub, }); open_docs.insert(name, Arc::downgrade(&connection)); connection } } } } async fn create_openid_client( http_client: &reqwest::Client, issuer_url: IssuerUrl, client_id: ClientId, client_secret: ClientSecret, redirect_url: RedirectUrl, ) -> OpenidClient { let provider_metadata = CoreProviderMetadata::discover_async(issuer_url, http_client) .await .expect("Failed to discover OpenID Connect provider metadata"); CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret)) .set_redirect_uri(redirect_url) } #[tokio::main] async fn main() { // Allow loading .env to fail, since it won't be available in docker. 
let _ = dotenv(); let data_dir = env::var("DATA_DIR").expect("DATA_DIR not set"); let data_dir = PathBuf::from_str(&data_dir).expect("DATA_DIR is not a valid path"); let frontend_dir = env::var("FRONTEND_DIR").expect("FRONTEND_DIR not set"); let frontend_dir = PathBuf::from_str(&frontend_dir).expect("FRONTEND_DIR is not a valid path"); let index_file = frontend_dir.join("index.html"); let signing_key = env::var("AUTH_SECRET").expect("AUTH_SECRET not set"); let signing_key = Hmac::new_from_slice(signing_key.as_bytes()).expect("AUTH_SECRET is not a valid HMAC key"); let app_url = env::var("APP_URL").expect("APP_URL not set"); let http_client = reqwest::ClientBuilder::new() // Following redirects opens the client up to SSRF vulnerabilities. .redirect(reqwest::redirect::Policy::none()) .build() .expect("Failed to create HTTP client"); let openid_client = create_openid_client( &http_client, IssuerUrl::new(env::var("AUTH_ISSUER_URL").expect("AUTH_ISSUER_URL not set")) .expect("Issuer URL invalid"), ClientId::new(env::var("AUTH_CLIENT_ID").expect("AUTH_CLIENT_ID not set")), ClientSecret::new(env::var("AUTH_CLIENT_SECRET").expect("AUTH_CLIENT_SECRET not set")), RedirectUrl::new(format!("{}/api/auth/callback", app_url)) .expect("Redirect URL is invalid (app url?)"), ) .await; let server = Arc::new(Server::new( http_client, openid_client, signing_key, data_dir, )); let ws = { let server = server.clone(); warp::path("sync") .and(warp::path::param()) .and(warp::ws()) .and(warp::any().map(move || server.clone())) .and(warp::cookie(COOKIE_AUTH_TOKEN)) .and_then(ws_handler) }; let api_routes = { let login_route = { let server = server.clone(); warp::path!("api" / "auth" / "login") .map(move || server.clone()) .and(warp::query()) .and_then(handle_auth_login) }; let callback_route = { let server = server.clone(); warp::path!("api" / "auth" / "callback") .map(move || server.clone()) .and(warp::cookie(COOKIE_AUTH_NONCE)) .and(warp::cookie(COOKIE_AUTH_CSRF)) .and(warp::query()) 
.and_then(handle_auth_callback) }; login_route.or(callback_route) }; let frontend_files = warp::fs::dir(frontend_dir); let index = warp::fs::file(index_file); let routes = api_routes.or(ws).or(frontend_files).or(index); println!("Starting server on http://0.0.0.0:9000"); warp::serve(routes).run(([0, 0, 0, 0], 9000)).await; } async fn ws_handler( name: String, ws: Ws, server: Arc, token: String, ) -> Result { let token: UserToken = token .verify_with_key(&server.signing_key) .map_err(|_| warp::reject())?; println!( "[INFO] User {} <{}> connected to document {}", token.display_name, token.email, name ); let doc_id = format!("{}/{}", token.email, name); let connection = server.get_or_create_doc(doc_id).await; Ok(ws.on_upgrade(move |socket| peer(token, name, socket, connection))) } async fn peer(token: UserToken, doc_name: String, ws: WebSocket, connection: Arc) { let (sink, stream) = ws.split(); let sink = Arc::new(Mutex::new(WarpSink::from(sink))); let stream = WarpStream::from(stream); let sub = connection.bcast.subscribe(sink, stream); let _ = sub.completed().await; println!( "[INFO] User {} <{}> disconnected from document {}", token.display_name, token.email, doc_name ); } #[derive(Debug, serde::Serialize, serde::Deserialize)] struct AuthState { redirect_url: Option, csrf_token: String, } #[derive(Debug, serde::Deserialize)] struct AuthLoginParams { redirect_url: Option, } async fn handle_auth_login( server: Arc, params: AuthLoginParams, ) -> Result { let csrf_token = CsrfToken::new_random(); let state = AuthState { redirect_url: params.redirect_url, csrf_token: csrf_token.clone().into_secret(), }; let state = CsrfToken::new( BASE64_URL_SAFE.encode( state .sign_with_key(&server.signing_key) .map_err(|_| warp::reject::custom(AuthError))?, ), ); // From what I understand I don't need the csrf_token let (auth_url, _csrf_token, nonce) = server .openid_client .authorize_url( AuthenticationFlow::::AuthorizationCode, || state, Nonce::new_random, ) 
.add_scope(Scope::new("openid".to_string())) .add_scope(Scope::new("profile".to_string())) .add_scope(Scope::new("email".to_string())) .url(); Ok(Response::builder() .header( header::SET_COOKIE, cookie::Cookie::build((COOKIE_AUTH_NONCE, nonce.secret())) .path("/") .http_only(true) .build() .to_string(), ) .header( header::SET_COOKIE, cookie::Cookie::build((COOKIE_AUTH_CSRF, csrf_token.secret())) .path("/") .http_only(true) .build() .to_string(), ) .header(header::LOCATION, auth_url.as_str()) .status(StatusCode::FOUND) .body("Login") .unwrap()) } #[derive(Debug, serde::Deserialize)] struct AuthCallbackParams { code: String, state: String, } async fn handle_auth_callback( server: Arc, nonce: String, csrf_token: String, params: AuthCallbackParams, ) -> Result { let code = AuthorizationCode::new(params.code); let nonce = Nonce::new(nonce); let state: AuthState = String::from_utf8( BASE64_URL_SAFE .decode(params.state) .map_err(|_| warp::reject::custom(AuthError))?, ) .map_err(|_| warp::reject::custom(AuthError))? .verify_with_key(&server.signing_key) .map_err(|_| warp::reject::custom(AuthError))?; if state.csrf_token != csrf_token { return Err(warp::reject::custom(AuthError)); } let token_response = server .openid_client .exchange_code(code) .map_err(|_| warp::reject::custom(AuthError))? .request_async(&server.http_client) .await .map_err(|_| warp::reject::custom(AuthError))?; let id_token_verifier: CoreIdTokenVerifier = server.openid_client.id_token_verifier(); let id_token_claims = token_response .extra_fields() .id_token() .ok_or(warp::reject::custom(AuthError))? .claims(&id_token_verifier, &nonce) .map_err(|_| warp::reject::custom(AuthError))?; let token_id: String = rng() .sample_iter(&Alphanumeric) .take(32) .map(char::from) .collect(); let display_name = id_token_claims .given_name() .ok_or(warp::reject::custom(AuthError))? .get(None) .ok_or(warp::reject::custom(AuthError))? 
.to_string(); let email = id_token_claims .email() .ok_or(warp::reject::custom(AuthError))? .to_string(); println!("[INFO] Authenticated: {} <{}>", display_name, email); let user_token = UserToken { display_name, email, iat: Utc::now(), jti: token_id, }; let user_token = user_token .sign_with_key(&server.signing_key) .map_err(|_| warp::reject::custom(AuthError))?; Ok(Response::builder() .header( header::SET_COOKIE, cookie::Cookie::build((COOKIE_AUTH_NONCE, "")) .path("/") .http_only(true) .build() .to_string(), ) .header( header::SET_COOKIE, cookie::Cookie::build((COOKIE_AUTH_CSRF, "")) .expires(OffsetDateTime::UNIX_EPOCH) .path("/") .http_only(true) .build() .to_string(), ) .header( header::SET_COOKIE, cookie::Cookie::build((COOKIE_AUTH_TOKEN, user_token)) .expires(OffsetDateTime::now_utc() + Duration::days(365)) .path("/") .build() .to_string(), ) .header( header::LOCATION, state.redirect_url.unwrap_or("/app".to_string()), ) .status(StatusCode::FOUND) .body("") .unwrap()) }