From 9646266ea822c356481dea7a343169084c1af5b1 Mon Sep 17 00:00:00 2001 From: Kalle Struik Date: Tue, 15 Apr 2025 14:03:18 +0200 Subject: [PATCH] Fix initial sync on the backend --- backend/src/auth.rs | 209 +++++++++++++++++++++++++++ backend/src/main.rs | 226 +++--------------------------- backend/yrs-warp/src/broadcast.rs | 26 +++- 3 files changed, 246 insertions(+), 215 deletions(-) create mode 100644 backend/src/auth.rs diff --git a/backend/src/auth.rs b/backend/src/auth.rs new file mode 100644 index 0000000..1da69a0 --- /dev/null +++ b/backend/src/auth.rs @@ -0,0 +1,209 @@ +use std::sync::Arc; + +use base64::prelude::*; + +use chrono::serde::ts_seconds; +use chrono::{DateTime, Utc}; +use cookie::time::{Duration, OffsetDateTime}; +use jwt::{SignWithKey, VerifyWithKey}; +use openidconnect::{ + core::{CoreIdTokenVerifier, CoreResponseType}, + AuthenticationFlow, AuthorizationCode, CsrfToken, Nonce, Scope, +}; +use rand::{distr::Alphanumeric, rng, Rng}; +use warp::{ + http::{header, Response, StatusCode}, + reject::Rejection, + reply::Reply, +}; + +use crate::{AuthError, Server}; + +pub const COOKIE_AUTH_TOKEN: &str = "knotes_auth_token"; +pub const COOKIE_AUTH_NONCE: &str = "knotes_auth_nonce"; +pub const COOKIE_AUTH_CSRF: &str = "knotes_auth_csrf"; + +#[derive(Debug, serde::Serialize, serde::Deserialize)] +pub struct UserToken { + pub display_name: String, + pub email: String, + + // Unused for now, might be checked later do allow some basic session management. + // Issued At Time + #[allow(dead_code)] + #[serde(with = "ts_seconds")] + pub iat: DateTime, + // JWT Id + #[allow(dead_code)] + pub jti: String, +} + +#[derive(Debug, serde::Serialize, serde::Deserialize)] +struct AuthState { + redirect_url: Option, + csrf_token: String, +} + +#[derive(Debug, serde::Deserialize)] +pub struct AuthLoginParams { + redirect_url: Option, +} + +pub async fn handle_auth_login( + server: Arc, + params: AuthLoginParams, +) -> Result { + let csrf_token = CsrfToken::new_random(); + let state = AuthState { + redirect_url: params.redirect_url, + csrf_token: csrf_token.clone().into_secret(), + }; + let state = CsrfToken::new( + BASE64_URL_SAFE.encode( + state + .sign_with_key(&server.signing_key) + .map_err(|_| warp::reject::custom(AuthError))?, + ), + ); + // From what I understand I don't need the csrf_token + let (auth_url, _csrf_token, nonce) = server + .openid_client + .authorize_url( + AuthenticationFlow::::AuthorizationCode, + || state, + Nonce::new_random, + ) + .add_scope(Scope::new("openid".to_string())) + .add_scope(Scope::new("profile".to_string())) + .add_scope(Scope::new("email".to_string())) + .url(); + + Ok(Response::builder() + .header( + header::SET_COOKIE, + cookie::Cookie::build((COOKIE_AUTH_NONCE, nonce.secret())) + .path("/") + .http_only(true) + .build() + .to_string(), + ) + .header( + header::SET_COOKIE, + cookie::Cookie::build((COOKIE_AUTH_CSRF, csrf_token.secret())) + .path("/") + .http_only(true) + .build() + .to_string(), + ) + .header(header::LOCATION, auth_url.as_str()) + .status(StatusCode::FOUND) + .body("Login") + .unwrap()) +} + +#[derive(Debug, serde::Deserialize)] +pub struct AuthCallbackParams { + code: String, + state: String, +} + +pub async fn handle_auth_callback( + server: Arc, + nonce: String, + csrf_token: String, + params: AuthCallbackParams, +) -> Result { + let code = AuthorizationCode::new(params.code); + let nonce = Nonce::new(nonce); + let state: AuthState = String::from_utf8( + BASE64_URL_SAFE + .decode(params.state) + .map_err(|_| warp::reject::custom(AuthError))?, + ) + .map_err(|_| warp::reject::custom(AuthError))? + .verify_with_key(&server.signing_key) + .map_err(|_| warp::reject::custom(AuthError))?; + + if state.csrf_token != csrf_token { + return Err(warp::reject::custom(AuthError)); + } + + let token_response = server + .openid_client + .exchange_code(code) + .map_err(|_| warp::reject::custom(AuthError))? + .request_async(&server.http_client) + .await + .map_err(|_| warp::reject::custom(AuthError))?; + + let id_token_verifier: CoreIdTokenVerifier = server.openid_client.id_token_verifier(); + let id_token_claims = token_response + .extra_fields() + .id_token() + .ok_or(warp::reject::custom(AuthError))? + .claims(&id_token_verifier, &nonce) + .map_err(|_| warp::reject::custom(AuthError))?; + + let token_id: String = rng() + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect(); + + let display_name = id_token_claims + .given_name() + .ok_or(warp::reject::custom(AuthError))? + .get(None) + .ok_or(warp::reject::custom(AuthError))? + .to_string(); + let email = id_token_claims + .email() + .ok_or(warp::reject::custom(AuthError))? + .to_string(); + + println!("[INFO] Authenticated: {} <{}>", display_name, email); + + let user_token = UserToken { + display_name, + email, + iat: Utc::now(), + jti: token_id, + }; + let user_token = user_token + .sign_with_key(&server.signing_key) + .map_err(|_| warp::reject::custom(AuthError))?; + + Ok(Response::builder() + .header( + header::SET_COOKIE, + cookie::Cookie::build((COOKIE_AUTH_NONCE, "")) + .path("/") + .http_only(true) + .build() + .to_string(), + ) + .header( + header::SET_COOKIE, + cookie::Cookie::build((COOKIE_AUTH_CSRF, "")) + .expires(OffsetDateTime::UNIX_EPOCH) + .path("/") + .http_only(true) + .build() + .to_string(), + ) + .header( + header::SET_COOKIE, + cookie::Cookie::build((COOKIE_AUTH_TOKEN, user_token)) + .expires(OffsetDateTime::now_utc() + Duration::days(365)) + .path("/") + .build() + .to_string(), + ) + .header( + header::LOCATION, + state.redirect_url.unwrap_or("/app".to_string()), + ) + .status(StatusCode::FOUND) + .body("") + .unwrap()) +} diff --git a/backend/src/main.rs b/backend/src/main.rs index 479d7d1..da7e0c2 100644 --- a/backend/src/main.rs +++ b/backend/src/main.rs @@ -1,28 +1,27 @@ -use base64::prelude::*; -use chrono::serde::ts_seconds; -use chrono::{DateTime, Utc}; -use cookie::time::{Duration, OffsetDateTime}; +pub mod auth; + +use auth::{ + handle_auth_callback, handle_auth_login, UserToken, COOKIE_AUTH_CSRF, COOKIE_AUTH_NONCE, + COOKIE_AUTH_TOKEN, +}; use dotenvy::dotenv; use futures_util::StreamExt; use hmac::{Hmac, Mac}; -use jwt::{SignWithKey, VerifyWithKey}; +use jwt::VerifyWithKey; use openidconnect::{ core::{ CoreAuthDisplay, CoreAuthPrompt, CoreClient, CoreErrorResponseType, CoreGenderClaim, - CoreIdTokenVerifier, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, - CoreProviderMetadata, CoreResponseType, CoreRevocableToken, CoreRevocationErrorResponse, - CoreTokenIntrospectionResponse, CoreTokenResponse, + CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata, + CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse, + CoreTokenResponse, }, - reqwest, AuthenticationFlow, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken, - EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce, - RedirectUrl, Scope, StandardErrorResponse, + reqwest, Client, ClientId, ClientSecret, EmptyAdditionalClaims, EndpointMaybeSet, + EndpointNotSet, EndpointSet, IssuerUrl, RedirectUrl, StandardErrorResponse, }; -use rand::{distr::Alphanumeric, rng, Rng}; use rocksdb::TransactionDB; use sha2::Sha256; use std::fs; use std::{ - char, collections::HashMap, env, path::PathBuf, @@ -31,12 +30,11 @@ use std::{ }; use tokio::sync::{Mutex, RwLock}; use warp::{ - http::{header, Response, StatusCode}, reject::Reject, ws::{WebSocket, Ws}, Filter, Rejection, Reply, }; -use yrs::{sync::Awareness, Doc, Transact}; +use yrs::{sync::Awareness, types::ToJson, Array, Doc, ReadTxn, Transact, WriteTxn}; use yrs_kvstore::DocOps; use yrs_rocksdb::RocksDBStore; use yrs_warp::{ @@ -72,37 +70,18 @@ type OpenidClient = Client< EndpointMaybeSet, >; -const COOKIE_AUTH_TOKEN: &str = "knotes_auth_token"; -const COOKIE_AUTH_NONCE: &str = "knotes_auth_nonce"; -const COOKIE_AUTH_CSRF: &str = "knotes_auth_csrf"; - -#[derive(Debug, serde::Serialize, serde::Deserialize)] -struct UserToken { - display_name: String, - email: String, - - // Unused for now, might be checked later do allow some basic session management. - // Issued At Time - #[allow(dead_code)] - #[serde(with = "ts_seconds")] - iat: DateTime, - // JWT Id - #[allow(dead_code)] - jti: String, -} - #[derive(Debug)] struct AuthError; impl Reject for AuthError {} -struct Connection { +pub struct Connection { bcast: BroadcastGroup, db: Arc, #[allow(dyn_drop)] _db_sub: Arc, } -struct Server { +pub struct Server { // There is something to be said for keeping the broadcast group in memory for a bit after all // clients disconnect, but for now we don't bother. pub open_docs: RwLock>>, @@ -137,8 +116,8 @@ impl Server { None => { drop(open_docs); let mut open_docs = self.open_docs.write().await; - let doc = Doc::new(); + let data_dir = self.data_dir.join(name.clone()); if let Err(e) = fs::create_dir_all(&data_dir) { panic!("Was unable to create data directory for note {}, due to the following error. Something is very wrong!\n{}", name, e) @@ -151,7 +130,7 @@ impl Server { let sub = { let db = db.clone(); let name = name.clone(); - doc.observe_update_v1(move |_, e| { + doc.observe_update_v1(move |doc_txn, e| { let txn = RocksDBStore::from(db.transaction()); let i = txn.push_update(&name, &e.update).unwrap(); if i % 128 == 0 { @@ -162,6 +141,7 @@ impl Server { }) .unwrap() }; + { // Load document from DB let mut txn = doc.transact_mut(); @@ -313,173 +293,3 @@ async fn peer(token: UserToken, doc_name: String, ws: WebSocket, connection: Arc token.display_name, token.email, doc_name ); } - -#[derive(Debug, serde::Serialize, serde::Deserialize)] -struct AuthState { - redirect_url: Option, - csrf_token: String, -} - -#[derive(Debug, serde::Deserialize)] -struct AuthLoginParams { - redirect_url: Option, -} - -async fn handle_auth_login( - server: Arc, - params: AuthLoginParams, -) -> Result { - let csrf_token = CsrfToken::new_random(); - let state = AuthState { - redirect_url: params.redirect_url, - csrf_token: csrf_token.clone().into_secret(), - }; - let state = CsrfToken::new( - BASE64_URL_SAFE.encode( - state - .sign_with_key(&server.signing_key) - .map_err(|_| warp::reject::custom(AuthError))?, - ), - ); - // From what I understand I don't need the csrf_token - let (auth_url, _csrf_token, nonce) = server - .openid_client - .authorize_url( - AuthenticationFlow::::AuthorizationCode, - || state, - Nonce::new_random, - ) - .add_scope(Scope::new("openid".to_string())) - .add_scope(Scope::new("profile".to_string())) - .add_scope(Scope::new("email".to_string())) - .url(); - - Ok(Response::builder() - .header( - header::SET_COOKIE, - cookie::Cookie::build((COOKIE_AUTH_NONCE, nonce.secret())) - .path("/") - .http_only(true) - .build() - .to_string(), - ) - .header( - header::SET_COOKIE, - cookie::Cookie::build((COOKIE_AUTH_CSRF, csrf_token.secret())) - .path("/") - .http_only(true) - .build() - .to_string(), - ) - .header(header::LOCATION, auth_url.as_str()) - .status(StatusCode::FOUND) - .body("Login") - .unwrap()) -} - -#[derive(Debug, serde::Deserialize)] -struct AuthCallbackParams { - code: String, - state: String, -} - -async fn handle_auth_callback( - server: Arc, - nonce: String, - csrf_token: String, - params: AuthCallbackParams, -) -> Result { - let code = AuthorizationCode::new(params.code); - let nonce = Nonce::new(nonce); - let state: AuthState = String::from_utf8( - BASE64_URL_SAFE - .decode(params.state) - .map_err(|_| warp::reject::custom(AuthError))?, - ) - .map_err(|_| warp::reject::custom(AuthError))? - .verify_with_key(&server.signing_key) - .map_err(|_| warp::reject::custom(AuthError))?; - - if state.csrf_token != csrf_token { - return Err(warp::reject::custom(AuthError)); - } - - let token_response = server - .openid_client - .exchange_code(code) - .map_err(|_| warp::reject::custom(AuthError))? - .request_async(&server.http_client) - .await - .map_err(|_| warp::reject::custom(AuthError))?; - - let id_token_verifier: CoreIdTokenVerifier = server.openid_client.id_token_verifier(); - let id_token_claims = token_response - .extra_fields() - .id_token() - .ok_or(warp::reject::custom(AuthError))? - .claims(&id_token_verifier, &nonce) - .map_err(|_| warp::reject::custom(AuthError))?; - - let token_id: String = rng() - .sample_iter(&Alphanumeric) - .take(32) - .map(char::from) - .collect(); - - let display_name = id_token_claims - .given_name() - .ok_or(warp::reject::custom(AuthError))? - .get(None) - .ok_or(warp::reject::custom(AuthError))? - .to_string(); - let email = id_token_claims - .email() - .ok_or(warp::reject::custom(AuthError))? - .to_string(); - - println!("[INFO] Authenticated: {} <{}>", display_name, email); - - let user_token = UserToken { - display_name, - email, - iat: Utc::now(), - jti: token_id, - }; - let user_token = user_token - .sign_with_key(&server.signing_key) - .map_err(|_| warp::reject::custom(AuthError))?; - - Ok(Response::builder() - .header( - header::SET_COOKIE, - cookie::Cookie::build((COOKIE_AUTH_NONCE, "")) - .path("/") - .http_only(true) - .build() - .to_string(), - ) - .header( - header::SET_COOKIE, - cookie::Cookie::build((COOKIE_AUTH_CSRF, "")) - .expires(OffsetDateTime::UNIX_EPOCH) - .path("/") - .http_only(true) - .build() - .to_string(), - ) - .header( - header::SET_COOKIE, - cookie::Cookie::build((COOKIE_AUTH_TOKEN, user_token)) - .expires(OffsetDateTime::now_utc() + Duration::days(365)) - .path("/") - .build() - .to_string(), - ) - .header( - header::LOCATION, - state.redirect_url.unwrap_or("/app".to_string()), - ) - .status(StatusCode::FOUND) - .body("") - .unwrap()) -} diff --git a/backend/yrs-warp/src/broadcast.rs b/backend/yrs-warp/src/broadcast.rs index 4e96677..2791dda 100644 --- a/backend/yrs-warp/src/broadcast.rs +++ b/backend/yrs-warp/src/broadcast.rs @@ -169,6 +169,18 @@ impl BroadcastGroup { let stream_task = { let awareness = self.awareness().clone(); tokio::spawn(async move { + // START manual merge of https://github.com/y-crdt/yrs-warp/pull/21 + let payload = { + let mut encoder = EncoderV1::new(); + let awareness = awareness.read().await; + protocol.start(&awareness, &mut encoder)?; + encoder.to_vec() + }; + if !payload.is_empty() { + let mut s = sink.lock().await; + s.send(payload).await.map_err(|e| Error::Other(e.into()))?; + } + // END manual merge while let Some(res) = stream.next().await { let msg = Message::decode_v1(&res.map_err(|e| Error::Other(Box::new(e)))?)?; let reply = Self::handle_msg(&protocol, &awareness, msg).await?; @@ -201,34 +213,34 @@ impl BroadcastGroup { Message::Sync(msg) => match msg { SyncMessage::SyncStep1(state_vector) => { let awareness = awareness.read().await; - protocol.handle_sync_step1(&*awareness, state_vector) + protocol.handle_sync_step1(&awareness, state_vector) } SyncMessage::SyncStep2(update) => { let mut awareness = awareness.write().await; let update = Update::decode_v1(&update)?; - protocol.handle_sync_step2(&mut *awareness, update) + protocol.handle_sync_step2(&mut awareness, update) } SyncMessage::Update(update) => { let mut awareness = awareness.write().await; let update = Update::decode_v1(&update)?; - protocol.handle_sync_step2(&mut *awareness, update) + protocol.handle_sync_step2(&mut awareness, update) } }, Message::Auth(deny_reason) => { let awareness = awareness.read().await; - protocol.handle_auth(&*awareness, deny_reason) + protocol.handle_auth(&awareness, deny_reason) } Message::AwarenessQuery => { let awareness = awareness.read().await; - protocol.handle_awareness_query(&*awareness) + protocol.handle_awareness_query(&awareness) } Message::Awareness(update) => { let mut awareness = awareness.write().await; - protocol.handle_awareness_update(&mut *awareness, update) + protocol.handle_awareness_update(&mut awareness, update) } Message::Custom(tag, data) => { let mut awareness = awareness.write().await; - protocol.missing_handle(&mut *awareness, tag, data) + protocol.missing_handle(&mut awareness, tag, data) } } }