use futures_util::StreamExt; use rocksdb::TransactionDB; use std::{ collections::HashMap, sync::{Arc, Weak}, }; use tokio::sync::{Mutex, RwLock}; use warp::{ ws::{WebSocket, Ws}, Filter, Rejection, Reply, }; use yrs::{sync::Awareness, Doc, Transact}; use yrs_kvstore::DocOps; use yrs_rocksdb::RocksDBStore; use yrs_warp::{ broadcast::BroadcastGroup, ws::{WarpSink, WarpStream}, }; const DATA_DIR: &str = "data"; const STATIC_DIR: &str = "../frontend/dist"; const INDEX_HTML: &str = "../frontend/dist/index.html"; // Web socket flow: // -> Handle connection (ws_handler) // -> Check permissions (Not doing for new) // -> Get or load document // -> // // struct Connection { bcast: BroadcastGroup, // This is purely here to keep the connection to the DB alive while there is at least one // connection. _db_subscription: Arc, } struct Server { // There is something to be said for keeping the broadcast group in memory for a bit after all // clients disconnect, but for now we don't bother. pub open_docs: RwLock>>, pub db: Arc, } impl Server { pub fn new(db: TransactionDB) -> Self { Self { open_docs: RwLock::default(), db: Arc::new(db), } } pub async fn get_or_create_doc(&self, name: String) -> Arc { let open_docs = self.open_docs.read().await; match open_docs.get(&name).and_then(Weak::upgrade) { Some(group) => group.clone(), None => { drop(open_docs); let mut open_docs = self.open_docs.write().await; let doc = Doc::new(); // Subscribe the DB to updates let sub = { let db = self.db.clone(); let name = name.clone(); doc.observe_update_v1(move |_, e| { let txn = RocksDBStore::from(db.transaction()); let i = txn.push_update(&name, &e.update).unwrap(); if i % 128 == 0 { // compact updates into document txn.flush_doc(&name).unwrap(); } txn.commit().unwrap(); }) .unwrap() }; { // Load document from DB let mut txn = doc.transact_mut(); let db_txn = RocksDBStore::from(self.db.transaction()); db_txn.load_doc(&name, &mut txn).unwrap(); } let awareness = Arc::new(RwLock::new(Awareness::new(doc))); let group = BroadcastGroup::new(awareness, 32).await; let connection = Arc::new(Connection { bcast: group, _db_subscription: sub, }); open_docs.insert(name, Arc::downgrade(&connection)); connection } } } } #[tokio::main] async fn main() { let db = TransactionDB::open_default(DATA_DIR).expect("Failed to open DB"); let server = Arc::new(Server::new(db)); let ws = warp::path("sync") .and(warp::path::param()) .and(warp::ws()) .and(warp::any().map(move || server.clone())) .and_then(ws_handler); let static_files = warp::fs::dir(STATIC_DIR); let index = warp::fs::file(INDEX_HTML); let routes = ws.or(static_files).or(index); warp::serve(routes).run(([0, 0, 0, 0], 9000)).await; } async fn ws_handler(name: String, ws: Ws, server: Arc) -> Result { // TODO: Check permissions before upgrading println!("ws_handler: {}", name); let connection = server.get_or_create_doc(name).await; Ok(ws.on_upgrade(move |socket| peer(socket, connection))) } async fn peer(ws: WebSocket, connection: Arc) { let (sink, stream) = ws.split(); let sink = Arc::new(Mutex::new(WarpSink::from(sink))); let stream = WarpStream::from(stream); let sub = connection.bcast.subscribe(sink, stream); match sub.completed().await { Ok(_) => println!("broadcasting for channel finished successfully"), Err(e) => eprintln!("broadcasting for channel finished abruptly: {}", e), } }