Fix initial sync on the backend
All checks were successful
/ Push Docker image to local registry (push) Successful in 2m59s

This commit is contained in:
kalle 2025-04-15 14:03:18 +02:00
parent e3cee4e596
commit 9646266ea8
3 changed files with 246 additions and 215 deletions

209
backend/src/auth.rs Normal file
View file

@ -0,0 +1,209 @@
use std::sync::Arc;
use base64::prelude::*;
use chrono::serde::ts_seconds;
use chrono::{DateTime, Utc};
use cookie::time::{Duration, OffsetDateTime};
use jwt::{SignWithKey, VerifyWithKey};
use openidconnect::{
core::{CoreIdTokenVerifier, CoreResponseType},
AuthenticationFlow, AuthorizationCode, CsrfToken, Nonce, Scope,
};
use rand::{distr::Alphanumeric, rng, Rng};
use warp::{
http::{header, Response, StatusCode},
reject::Rejection,
reply::Reply,
};
use crate::{AuthError, Server};
pub const COOKIE_AUTH_TOKEN: &str = "knotes_auth_token";
pub const COOKIE_AUTH_NONCE: &str = "knotes_auth_nonce";
pub const COOKIE_AUTH_CSRF: &str = "knotes_auth_csrf";
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct UserToken {
pub display_name: String,
pub email: String,
// Unused for now, might be checked later do allow some basic session management.
// Issued At Time
#[allow(dead_code)]
#[serde(with = "ts_seconds")]
pub iat: DateTime<Utc>,
// JWT Id
#[allow(dead_code)]
pub jti: String,
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct AuthState {
redirect_url: Option<String>,
csrf_token: String,
}
#[derive(Debug, serde::Deserialize)]
pub struct AuthLoginParams {
redirect_url: Option<String>,
}
pub async fn handle_auth_login(
server: Arc<Server>,
params: AuthLoginParams,
) -> Result<impl Reply, Rejection> {
let csrf_token = CsrfToken::new_random();
let state = AuthState {
redirect_url: params.redirect_url,
csrf_token: csrf_token.clone().into_secret(),
};
let state = CsrfToken::new(
BASE64_URL_SAFE.encode(
state
.sign_with_key(&server.signing_key)
.map_err(|_| warp::reject::custom(AuthError))?,
),
);
// From what I understand I don't need the csrf_token
let (auth_url, _csrf_token, nonce) = server
.openid_client
.authorize_url(
AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
|| state,
Nonce::new_random,
)
.add_scope(Scope::new("openid".to_string()))
.add_scope(Scope::new("profile".to_string()))
.add_scope(Scope::new("email".to_string()))
.url();
Ok(Response::builder()
.header(
header::SET_COOKIE,
cookie::Cookie::build((COOKIE_AUTH_NONCE, nonce.secret()))
.path("/")
.http_only(true)
.build()
.to_string(),
)
.header(
header::SET_COOKIE,
cookie::Cookie::build((COOKIE_AUTH_CSRF, csrf_token.secret()))
.path("/")
.http_only(true)
.build()
.to_string(),
)
.header(header::LOCATION, auth_url.as_str())
.status(StatusCode::FOUND)
.body("Login")
.unwrap())
}
#[derive(Debug, serde::Deserialize)]
pub struct AuthCallbackParams {
code: String,
state: String,
}
pub async fn handle_auth_callback(
server: Arc<Server>,
nonce: String,
csrf_token: String,
params: AuthCallbackParams,
) -> Result<impl Reply, Rejection> {
let code = AuthorizationCode::new(params.code);
let nonce = Nonce::new(nonce);
let state: AuthState = String::from_utf8(
BASE64_URL_SAFE
.decode(params.state)
.map_err(|_| warp::reject::custom(AuthError))?,
)
.map_err(|_| warp::reject::custom(AuthError))?
.verify_with_key(&server.signing_key)
.map_err(|_| warp::reject::custom(AuthError))?;
if state.csrf_token != csrf_token {
return Err(warp::reject::custom(AuthError));
}
let token_response = server
.openid_client
.exchange_code(code)
.map_err(|_| warp::reject::custom(AuthError))?
.request_async(&server.http_client)
.await
.map_err(|_| warp::reject::custom(AuthError))?;
let id_token_verifier: CoreIdTokenVerifier = server.openid_client.id_token_verifier();
let id_token_claims = token_response
.extra_fields()
.id_token()
.ok_or(warp::reject::custom(AuthError))?
.claims(&id_token_verifier, &nonce)
.map_err(|_| warp::reject::custom(AuthError))?;
let token_id: String = rng()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect();
let display_name = id_token_claims
.given_name()
.ok_or(warp::reject::custom(AuthError))?
.get(None)
.ok_or(warp::reject::custom(AuthError))?
.to_string();
let email = id_token_claims
.email()
.ok_or(warp::reject::custom(AuthError))?
.to_string();
println!("[INFO] Authenticated: {} <{}>", display_name, email);
let user_token = UserToken {
display_name,
email,
iat: Utc::now(),
jti: token_id,
};
let user_token = user_token
.sign_with_key(&server.signing_key)
.map_err(|_| warp::reject::custom(AuthError))?;
Ok(Response::builder()
.header(
header::SET_COOKIE,
cookie::Cookie::build((COOKIE_AUTH_NONCE, ""))
.path("/")
.http_only(true)
.build()
.to_string(),
)
.header(
header::SET_COOKIE,
cookie::Cookie::build((COOKIE_AUTH_CSRF, ""))
.expires(OffsetDateTime::UNIX_EPOCH)
.path("/")
.http_only(true)
.build()
.to_string(),
)
.header(
header::SET_COOKIE,
cookie::Cookie::build((COOKIE_AUTH_TOKEN, user_token))
.expires(OffsetDateTime::now_utc() + Duration::days(365))
.path("/")
.build()
.to_string(),
)
.header(
header::LOCATION,
state.redirect_url.unwrap_or("/app".to_string()),
)
.status(StatusCode::FOUND)
.body("")
.unwrap())
}

View file

@ -1,28 +1,27 @@
use base64::prelude::*;
use chrono::serde::ts_seconds;
use chrono::{DateTime, Utc};
use cookie::time::{Duration, OffsetDateTime};
pub mod auth;
use auth::{
handle_auth_callback, handle_auth_login, UserToken, COOKIE_AUTH_CSRF, COOKIE_AUTH_NONCE,
COOKIE_AUTH_TOKEN,
};
use dotenvy::dotenv;
use futures_util::StreamExt;
use hmac::{Hmac, Mac};
use jwt::{SignWithKey, VerifyWithKey};
use jwt::VerifyWithKey;
use openidconnect::{
core::{
CoreAuthDisplay, CoreAuthPrompt, CoreClient, CoreErrorResponseType, CoreGenderClaim,
CoreIdTokenVerifier, CoreJsonWebKey, CoreJweContentEncryptionAlgorithm,
CoreProviderMetadata, CoreResponseType, CoreRevocableToken, CoreRevocationErrorResponse,
CoreTokenIntrospectionResponse, CoreTokenResponse,
CoreJsonWebKey, CoreJweContentEncryptionAlgorithm, CoreProviderMetadata,
CoreRevocableToken, CoreRevocationErrorResponse, CoreTokenIntrospectionResponse,
CoreTokenResponse,
},
reqwest, AuthenticationFlow, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken,
EmptyAdditionalClaims, EndpointMaybeSet, EndpointNotSet, EndpointSet, IssuerUrl, Nonce,
RedirectUrl, Scope, StandardErrorResponse,
reqwest, Client, ClientId, ClientSecret, EmptyAdditionalClaims, EndpointMaybeSet,
EndpointNotSet, EndpointSet, IssuerUrl, RedirectUrl, StandardErrorResponse,
};
use rand::{distr::Alphanumeric, rng, Rng};
use rocksdb::TransactionDB;
use sha2::Sha256;
use std::fs;
use std::{
char,
collections::HashMap,
env,
path::PathBuf,
@ -31,12 +30,11 @@ use std::{
};
use tokio::sync::{Mutex, RwLock};
use warp::{
http::{header, Response, StatusCode},
reject::Reject,
ws::{WebSocket, Ws},
Filter, Rejection, Reply,
};
use yrs::{sync::Awareness, Doc, Transact};
use yrs::{sync::Awareness, types::ToJson, Array, Doc, ReadTxn, Transact, WriteTxn};
use yrs_kvstore::DocOps;
use yrs_rocksdb::RocksDBStore;
use yrs_warp::{
@ -72,37 +70,18 @@ type OpenidClient = Client<
EndpointMaybeSet,
>;
const COOKIE_AUTH_TOKEN: &str = "knotes_auth_token";
const COOKIE_AUTH_NONCE: &str = "knotes_auth_nonce";
const COOKIE_AUTH_CSRF: &str = "knotes_auth_csrf";
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct UserToken {
display_name: String,
email: String,
// Unused for now, might be checked later do allow some basic session management.
// Issued At Time
#[allow(dead_code)]
#[serde(with = "ts_seconds")]
iat: DateTime<Utc>,
// JWT Id
#[allow(dead_code)]
jti: String,
}
#[derive(Debug)]
struct AuthError;
impl Reject for AuthError {}
struct Connection {
pub struct Connection {
bcast: BroadcastGroup,
db: Arc<TransactionDB>,
#[allow(dyn_drop)]
_db_sub: Arc<dyn Drop + Send + Sync>,
}
struct Server {
pub struct Server {
// There is something to be said for keeping the broadcast group in memory for a bit after all
// clients disconnect, but for now we don't bother.
pub open_docs: RwLock<HashMap<String, Weak<Connection>>>,
@ -137,8 +116,8 @@ impl Server {
None => {
drop(open_docs);
let mut open_docs = self.open_docs.write().await;
let doc = Doc::new();
let data_dir = self.data_dir.join(name.clone());
if let Err(e) = fs::create_dir_all(&data_dir) {
panic!("Was unable to create data directory for note {}, due to the following error. Something is very wrong!\n{}", name, e)
@ -151,7 +130,7 @@ impl Server {
let sub = {
let db = db.clone();
let name = name.clone();
doc.observe_update_v1(move |_, e| {
doc.observe_update_v1(move |doc_txn, e| {
let txn = RocksDBStore::from(db.transaction());
let i = txn.push_update(&name, &e.update).unwrap();
if i % 128 == 0 {
@ -162,6 +141,7 @@ impl Server {
})
.unwrap()
};
{
// Load document from DB
let mut txn = doc.transact_mut();
@ -313,173 +293,3 @@ async fn peer(token: UserToken, doc_name: String, ws: WebSocket, connection: Arc
token.display_name, token.email, doc_name
);
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct AuthState {
redirect_url: Option<String>,
csrf_token: String,
}
#[derive(Debug, serde::Deserialize)]
struct AuthLoginParams {
redirect_url: Option<String>,
}
async fn handle_auth_login(
server: Arc<Server>,
params: AuthLoginParams,
) -> Result<impl Reply, Rejection> {
let csrf_token = CsrfToken::new_random();
let state = AuthState {
redirect_url: params.redirect_url,
csrf_token: csrf_token.clone().into_secret(),
};
let state = CsrfToken::new(
BASE64_URL_SAFE.encode(
state
.sign_with_key(&server.signing_key)
.map_err(|_| warp::reject::custom(AuthError))?,
),
);
// From what I understand I don't need the csrf_token
let (auth_url, _csrf_token, nonce) = server
.openid_client
.authorize_url(
AuthenticationFlow::<CoreResponseType>::AuthorizationCode,
|| state,
Nonce::new_random,
)
.add_scope(Scope::new("openid".to_string()))
.add_scope(Scope::new("profile".to_string()))
.add_scope(Scope::new("email".to_string()))
.url();
Ok(Response::builder()
.header(
header::SET_COOKIE,
cookie::Cookie::build((COOKIE_AUTH_NONCE, nonce.secret()))
.path("/")
.http_only(true)
.build()
.to_string(),
)
.header(
header::SET_COOKIE,
cookie::Cookie::build((COOKIE_AUTH_CSRF, csrf_token.secret()))
.path("/")
.http_only(true)
.build()
.to_string(),
)
.header(header::LOCATION, auth_url.as_str())
.status(StatusCode::FOUND)
.body("Login")
.unwrap())
}
#[derive(Debug, serde::Deserialize)]
struct AuthCallbackParams {
code: String,
state: String,
}
async fn handle_auth_callback(
server: Arc<Server>,
nonce: String,
csrf_token: String,
params: AuthCallbackParams,
) -> Result<impl Reply, Rejection> {
let code = AuthorizationCode::new(params.code);
let nonce = Nonce::new(nonce);
let state: AuthState = String::from_utf8(
BASE64_URL_SAFE
.decode(params.state)
.map_err(|_| warp::reject::custom(AuthError))?,
)
.map_err(|_| warp::reject::custom(AuthError))?
.verify_with_key(&server.signing_key)
.map_err(|_| warp::reject::custom(AuthError))?;
if state.csrf_token != csrf_token {
return Err(warp::reject::custom(AuthError));
}
let token_response = server
.openid_client
.exchange_code(code)
.map_err(|_| warp::reject::custom(AuthError))?
.request_async(&server.http_client)
.await
.map_err(|_| warp::reject::custom(AuthError))?;
let id_token_verifier: CoreIdTokenVerifier = server.openid_client.id_token_verifier();
let id_token_claims = token_response
.extra_fields()
.id_token()
.ok_or(warp::reject::custom(AuthError))?
.claims(&id_token_verifier, &nonce)
.map_err(|_| warp::reject::custom(AuthError))?;
let token_id: String = rng()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect();
let display_name = id_token_claims
.given_name()
.ok_or(warp::reject::custom(AuthError))?
.get(None)
.ok_or(warp::reject::custom(AuthError))?
.to_string();
let email = id_token_claims
.email()
.ok_or(warp::reject::custom(AuthError))?
.to_string();
println!("[INFO] Authenticated: {} <{}>", display_name, email);
let user_token = UserToken {
display_name,
email,
iat: Utc::now(),
jti: token_id,
};
let user_token = user_token
.sign_with_key(&server.signing_key)
.map_err(|_| warp::reject::custom(AuthError))?;
Ok(Response::builder()
.header(
header::SET_COOKIE,
cookie::Cookie::build((COOKIE_AUTH_NONCE, ""))
.path("/")
.http_only(true)
.build()
.to_string(),
)
.header(
header::SET_COOKIE,
cookie::Cookie::build((COOKIE_AUTH_CSRF, ""))
.expires(OffsetDateTime::UNIX_EPOCH)
.path("/")
.http_only(true)
.build()
.to_string(),
)
.header(
header::SET_COOKIE,
cookie::Cookie::build((COOKIE_AUTH_TOKEN, user_token))
.expires(OffsetDateTime::now_utc() + Duration::days(365))
.path("/")
.build()
.to_string(),
)
.header(
header::LOCATION,
state.redirect_url.unwrap_or("/app".to_string()),
)
.status(StatusCode::FOUND)
.body("")
.unwrap())
}

View file

@ -169,6 +169,18 @@ impl BroadcastGroup {
let stream_task = {
let awareness = self.awareness().clone();
tokio::spawn(async move {
// START manual merge of https://github.com/y-crdt/yrs-warp/pull/21
let payload = {
let mut encoder = EncoderV1::new();
let awareness = awareness.read().await;
protocol.start(&awareness, &mut encoder)?;
encoder.to_vec()
};
if !payload.is_empty() {
let mut s = sink.lock().await;
s.send(payload).await.map_err(|e| Error::Other(e.into()))?;
}
// END manual merge
while let Some(res) = stream.next().await {
let msg = Message::decode_v1(&res.map_err(|e| Error::Other(Box::new(e)))?)?;
let reply = Self::handle_msg(&protocol, &awareness, msg).await?;
@ -201,34 +213,34 @@ impl BroadcastGroup {
Message::Sync(msg) => match msg {
SyncMessage::SyncStep1(state_vector) => {
let awareness = awareness.read().await;
protocol.handle_sync_step1(&*awareness, state_vector)
protocol.handle_sync_step1(&awareness, state_vector)
}
SyncMessage::SyncStep2(update) => {
let mut awareness = awareness.write().await;
let update = Update::decode_v1(&update)?;
protocol.handle_sync_step2(&mut *awareness, update)
protocol.handle_sync_step2(&mut awareness, update)
}
SyncMessage::Update(update) => {
let mut awareness = awareness.write().await;
let update = Update::decode_v1(&update)?;
protocol.handle_sync_step2(&mut *awareness, update)
protocol.handle_sync_step2(&mut awareness, update)
}
},
Message::Auth(deny_reason) => {
let awareness = awareness.read().await;
protocol.handle_auth(&*awareness, deny_reason)
protocol.handle_auth(&awareness, deny_reason)
}
Message::AwarenessQuery => {
let awareness = awareness.read().await;
protocol.handle_awareness_query(&*awareness)
protocol.handle_awareness_query(&awareness)
}
Message::Awareness(update) => {
let mut awareness = awareness.write().await;
protocol.handle_awareness_update(&mut *awareness, update)
protocol.handle_awareness_update(&mut awareness, update)
}
Message::Custom(tag, data) => {
let mut awareness = awareness.write().await;
protocol.missing_handle(&mut *awareness, tag, data)
protocol.missing_handle(&mut awareness, tag, data)
}
}
}