// v3/backend/src/main.rs (137 lines, 4.3 KiB, Rust)

use futures_util::StreamExt;
use rocksdb::TransactionDB;
use std::{
collections::HashMap,
sync::{Arc, Weak},
};
use tokio::sync::{Mutex, RwLock};
use warp::{
ws::{WebSocket, Ws},
Filter, Rejection, Reply,
};
use yrs::{sync::Awareness, Doc, Transact};
use yrs_kvstore::DocOps;
use yrs_rocksdb::RocksDBStore;
use yrs_warp::{
broadcast::BroadcastGroup,
ws::{WarpSink, WarpStream},
};
/// Directory containing the built frontend, served as static files.
const STATIC_DIR: &str = "../frontend/dist";
/// Page served for any request no earlier route matched.
const INDEX_HTML: &str = "../frontend/dist/index.html";
/// Directory RocksDB opens for persisting documents.
const DATA_DIR: &str = "data";
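// Web socket flow:
// -> Handle connection (ws_handler)
// -> Check permissions (not done for now; see the TODO in ws_handler)
// -> Get or load the document (Server::get_or_create_doc)
// -> Subscribe the socket to the document's broadcast group (peer)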
/// Per-document state shared by every client that has that document open.
///
/// [`Server`] keeps only a `Weak` reference to this, so it lives exactly as long as some client
/// task still holds an `Arc<Connection>`.
struct Connection {
    /// Group that each client's socket is subscribed to (see `peer`).
    bcast: BroadcastGroup,
    /// Never read. It is held only so the database update observer stays alive while at least
    /// one client holds this connection.
    _db_subscription: Arc<dyn Drop + Send + Sync>,
}
/// Process-wide server state: the document database and the documents currently open.
struct Server {
    /// Open documents keyed by name, held weakly so an entry stops upgrading once its last
    /// client drops its `Arc<Connection>`.
    ///
    /// There is something to be said for keeping the broadcast group in memory for a bit after
    /// all clients disconnect, but for now we don't bother.
    pub open_docs: RwLock<HashMap<String, Weak<Connection>>>,
    /// Shared handle to the RocksDB transaction database storing document updates.
    pub db: Arc<TransactionDB>,
}
impl Server {
pub fn new(db: TransactionDB) -> Self {
Self {
open_docs: RwLock::default(),
db: Arc::new(db),
}
}
pub async fn get_or_create_doc(&self, name: String) -> Arc<Connection> {
let open_docs = self.open_docs.read().await;
match open_docs.get(&name).and_then(Weak::upgrade) {
Some(group) => group.clone(),
None => {
drop(open_docs);
let mut open_docs = self.open_docs.write().await;
let doc = Doc::new();
// Subscribe the DB to updates
let sub = {
let db = self.db.clone();
let name = name.clone();
doc.observe_update_v1(move |_, e| {
let txn = RocksDBStore::from(db.transaction());
let i = txn.push_update(&name, &e.update).unwrap();
if i % 128 == 0 {
// compact updates into document
txn.flush_doc(&name).unwrap();
}
txn.commit().unwrap();
})
.unwrap()
};
{
// Load document from DB
let mut txn = doc.transact_mut();
let db_txn = RocksDBStore::from(self.db.transaction());
db_txn.load_doc(&name, &mut txn).unwrap();
}
let awareness = Arc::new(RwLock::new(Awareness::new(doc)));
let group = BroadcastGroup::new(awareness, 32).await;
let connection = Arc::new(Connection {
bcast: group,
_db_subscription: sub,
});
open_docs.insert(name, Arc::downgrade(&connection));
connection
}
}
}
}
/// Opens the document database and serves the app on port 9000 on all interfaces.
///
/// Route order matters: `/sync/<name>` WebSocket upgrades are tried first, then files under
/// `STATIC_DIR`, and finally `INDEX_HTML` for anything else.
#[tokio::main]
async fn main() {
    let server = Arc::new(Server::new(
        TransactionDB::open_default(DATA_DIR).expect("Failed to open DB"),
    ));
    // Hands every request its own clone of the shared server handle.
    let with_server = warp::any().map(move || Arc::clone(&server));

    let sync_route = warp::path("sync")
        .and(warp::path::param())
        .and(warp::ws())
        .and(with_server)
        .and_then(ws_handler);
    let routes = sync_route
        .or(warp::fs::dir(STATIC_DIR))
        .or(warp::fs::file(INDEX_HTML));

    warp::serve(routes).run(([0, 0, 0, 0], 9000)).await;
}
/// Handler for the `sync/<name>` route.
///
/// Opens (or joins) document `name` and upgrades the request to a WebSocket served by `peer`.
/// It never rejects; the `Result` return type is what `Filter::and_then` requires.
async fn ws_handler(name: String, ws: Ws, server: Arc<Server>) -> Result<impl Reply, Rejection> {
    // TODO: Check permissions before upgrading
    println!("ws_handler: {}", name);
    let connection = server.get_or_create_doc(name).await;
    let upgrade = ws.on_upgrade(move |socket| peer(socket, connection));
    Ok(upgrade)
}
/// Serves a single client socket.
///
/// Subscribes the socket to the document's broadcast group, waits until that subscription
/// completes, and logs how it ended. `connection` is owned for the whole call, so the document
/// stays open at least as long as this client is being served.
async fn peer(ws: WebSocket, connection: Arc<Connection>) {
    let (ws_tx, ws_rx) = ws.split();
    let sink = Arc::new(Mutex::new(WarpSink::from(ws_tx)));
    let subscription = connection.bcast.subscribe(sink, WarpStream::from(ws_rx));
    let outcome = subscription.completed().await;
    if let Err(e) = outcome {
        eprintln!("broadcasting for channel finished abruptly: {}", e);
    } else {
        println!("broadcasting for channel finished successfully");
    }
}