// v3/backend/yrs-warp/src/signaling.rs
//
// 330 lines
// 10 KiB
// Rust

use futures_util::stream::SplitSink;
use futures_util::{SinkExt, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use std::time::Duration;
use tokio::select;
use tokio::sync::{Mutex, RwLock};
use tokio::time::interval;
use warp::ws::{Message, WebSocket};
use warp::Error;
/// Heartbeat period of a signaling connection: a websocket ping is sent every `PING_TIMEOUT`,
/// and a connection whose `pong_received` flag is still cleared at the following tick is closed.
const PING_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Signaling service is used by y-webrtc protocol in order to exchange WebRTC offerings between
/// clients subscribing to particular rooms.
///
/// # Example
///
/// ```rust
/// use warp::{Filter, Rejection, Reply};
/// use warp::ws::{Ws, WebSocket};
/// use yrs_warp::signaling::{SignalingService, signaling_conn};
///
/// fn main() {
/// let signaling = SignalingService::new();
/// let ws = warp::path("signaling")
/// .and(warp::ws())
/// .and(warp::any().map(move || signaling.clone()))
/// .and_then(ws_handler);
///
/// //warp::serve(routes).run(([0, 0, 0, 0], 8000)).await;
/// }
///
/// async fn ws_handler(ws: Ws, svc: SignalingService) -> Result<impl Reply, Rejection> {
/// Ok(ws.on_upgrade(move |socket| peer(socket, svc)))
/// }
///
/// async fn peer(ws: WebSocket, svc: SignalingService) {
/// match signaling_conn(ws, svc).await {
/// Ok(_) => println!("signaling connection stopped"),
/// Err(e) => eprintln!("signaling connection failed: {}", e),
/// }
/// }
/// ```
#[derive(Debug, Clone)]
pub struct SignalingService(Topics);
impl SignalingService {
pub fn new() -> Self {
SignalingService(Arc::new(RwLock::new(Default::default())))
}
pub async fn publish(&self, topic: &str, msg: Message) -> Result<(), Error> {
let mut failed = Vec::new();
{
let topics = self.0.read().await;
if let Some(subs) = topics.get(topic) {
let client_count = subs.len();
tracing::info!("publishing message to {client_count} clients: {msg:?}");
for sub in subs {
if let Err(e) = sub.try_send(msg.clone()).await {
tracing::info!("failed to send {msg:?}: {e}");
failed.push(sub.clone());
}
}
}
}
if !failed.is_empty() {
let mut topics = self.0.write().await;
if let Some(subs) = topics.get_mut(topic) {
for f in failed {
subs.remove(&f);
}
}
}
Ok(())
}
pub async fn close_topic(&self, topic: &str) -> Result<(), Error> {
let mut topics = self.0.write().await;
if let Some(subs) = topics.remove(topic) {
for sub in subs {
if let Err(e) = sub.close().await {
tracing::warn!("failed to close connection on topic '{topic}': {e}");
}
}
}
Ok(())
}
pub async fn close(self) -> Result<(), Error> {
let mut topics = self.0.write_owned().await;
let mut all_conns = HashSet::new();
for (_, subs) in topics.drain() {
for sub in subs {
all_conns.insert(sub);
}
}
for conn in all_conns {
if let Err(e) = conn.close().await {
tracing::warn!("failed to close connection: {e}");
}
}
Ok(())
}
}
impl Default for SignalingService {
fn default() -> Self {
Self::new()
}
}
/// Shared topic registry: maps each topic (room) name to the set of connections subscribed to it.
/// Connections in a set are compared by identity of their shared sink (see `WsSink`).
type Topics = Arc<RwLock<HashMap<Arc<str>, HashSet<WsSink>>>>;
/// Clonable handle to the write half of a client's websocket.
///
/// All clones share one sink behind an async mutex, so the same connection can be stored in
/// several topic sets while its writes stay serialized.
#[derive(Debug, Clone)]
struct WsSink(Arc<Mutex<SplitSink<WebSocket, Message>>>);
impl WsSink {
    /// Wraps the sink half obtained from splitting a `WebSocket`.
    fn new(sink: SplitSink<WebSocket, Message>) -> Self {
        let shared = Mutex::new(sink);
        Self(Arc::new(shared))
    }

    /// Sends `msg`. If sending fails the sink is closed and the send error is returned;
    /// should closing fail as well, the close error is returned instead.
    async fn try_send(&self, msg: Message) -> Result<(), Error> {
        let mut guard = self.0.lock().await;
        match guard.send(msg).await {
            Ok(()) => Ok(()),
            Err(send_err) => {
                guard.close().await?;
                Err(send_err)
            }
        }
    }

    /// Closes the underlying sink.
    async fn close(&self) -> Result<(), Error> {
        let mut guard = self.0.lock().await;
        guard.close().await
    }
}
impl Hash for WsSink {
    /// Hashes the address of the shared sink, matching the pointer-identity equality of
    /// `WsSink`: every clone of one handle hashes identically.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let addr = Arc::as_ptr(&self.0) as usize;
        state.write_usize(addr);
    }
}
impl PartialEq for WsSink {
    /// Two handles are equal exactly when they point at the same shared sink.
    fn eq(&self, other: &Self) -> bool {
        Arc::as_ptr(&self.0) == Arc::as_ptr(&other.0)
    }
}
/// Pointer identity is reflexive, so the equality is a full equivalence relation.
impl Eq for WsSink {}
/// Handle incoming signaling connection - it's a websocket connection used by y-webrtc protocol
/// to exchange offering metadata between y-webrtc peers. It also manages topic/room access.
pub async fn signaling_conn(ws: WebSocket, service: SignalingService) -> Result<(), Error> {
let mut topics: Topics = service.0;
let (sink, mut stream) = ws.split();
let ws = WsSink::new(sink);
let mut ping_interval = interval(PING_TIMEOUT);
let mut state = ConnState::default();
loop {
select! {
_ = ping_interval.tick() => {
if !state.pong_received {
ws.close().await?;
drop(ping_interval);
return Ok(());
} else {
state.pong_received = false;
if let Err(e) = ws.try_send(Message::ping(Vec::default())).await {
ws.close().await?;
return Err(e);
}
}
},
res = stream.next() => {
match res {
None => {
ws.close().await?;
return Ok(());
},
Some(Err(e)) => {
ws.close().await?;
return Err(e);
},
Some(Ok(msg)) => {
process_msg(msg, &ws, &mut state, &mut topics).await?;
}
}
}
}
}
}
/// JSON `ping` signal of the y-webrtc signaling protocol.
const PING_MSG: &str = r#"{"type":"ping"}"#;
/// JSON `pong` signal, sent in reply to a client's `ping` signal.
const PONG_MSG: &str = r#"{"type":"pong"}"#;
async fn process_msg(
msg: Message,
ws: &WsSink,
state: &mut ConnState,
topics: &mut Topics,
) -> Result<(), Error> {
if msg.is_text() {
let json = msg.to_str().unwrap();
let msg = serde_json::from_str(json).unwrap();
match msg {
Signal::Subscribe {
topics: topic_names,
} => {
if !topic_names.is_empty() {
let mut topics = topics.write().await;
for topic in topic_names {
tracing::trace!("subscribing new client to '{topic}'");
if let Some((key, _)) = topics.get_key_value(topic) {
state.subscribed_topics.insert(key.clone());
let subs = topics.get_mut(topic).unwrap();
subs.insert(ws.clone());
} else {
let topic: Arc<str> = topic.into();
state.subscribed_topics.insert(topic.clone());
let mut subs = HashSet::new();
subs.insert(ws.clone());
topics.insert(topic, subs);
};
}
}
}
Signal::Unsubscribe {
topics: topic_names,
} => {
if !topic_names.is_empty() {
let mut topics = topics.write().await;
for topic in topic_names {
if let Some(subs) = topics.get_mut(topic) {
tracing::trace!("unsubscribing client from '{topic}'");
subs.remove(ws);
}
}
}
}
Signal::Publish { topic } => {
let mut failed = Vec::new();
{
let topics = topics.read().await;
if let Some(receivers) = topics.get(topic) {
let client_count = receivers.len();
tracing::trace!(
"publishing on {client_count} clients at '{topic}': {json}"
);
for receiver in receivers.iter() {
if let Err(e) = receiver.try_send(Message::text(json)).await {
tracing::info!(
"failed to publish message {json} on '{topic}': {e}"
);
failed.push(receiver.clone());
}
}
}
}
if !failed.is_empty() {
let mut topics = topics.write().await;
if let Some(receivers) = topics.get_mut(topic) {
for f in failed {
receivers.remove(&f);
}
}
}
}
Signal::Ping => {
ws.try_send(Message::text(PONG_MSG)).await?;
}
Signal::Pong => {
ws.try_send(Message::text(PING_MSG)).await?;
}
}
} else if msg.is_close() {
let mut topics = topics.write().await;
for topic in state.subscribed_topics.drain() {
if let Some(subs) = topics.get_mut(&topic) {
subs.remove(ws);
if subs.is_empty() {
topics.remove(&topic);
}
}
}
state.closed = true;
} else if msg.is_ping() {
ws.try_send(Message::ping(Vec::default())).await?;
}
Ok(())
}
/// Per-connection bookkeeping used while a signaling connection is being served.
#[derive(Debug)]
struct ConnState {
    /// Set once a websocket close frame from the peer has been processed.
    closed: bool,
    /// Checked on every heartbeat tick: if still `false`, the connection is closed.
    /// It is reset to `false` whenever a heartbeat ping is sent, and starts out `true`
    /// so the first tick sends a ping instead of closing.
    pong_received: bool,
    /// Topics this connection subscribed to, so it can be removed from them when it closes.
    subscribed_topics: HashSet<Arc<str>>,
}
impl Default for ConnState {
    fn default() -> Self {
        Self {
            subscribed_topics: HashSet::default(),
            pong_received: true,
            closed: false,
        }
    }
}
/// Messages of the y-webrtc signaling protocol, distinguished by their `"type"` field.
///
/// String fields borrow from the incoming JSON text. Note that borrowed `&str` fields cannot
/// be deserialized from JSON strings that contain escape sequences.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub(crate) enum Signal<'a> {
    /// `{"type":"publish","topic":...}`: forward this message to the topic's subscribers.
    Publish { topic: &'a str },
    /// `{"type":"subscribe","topics":[...]}`: join the listed topics.
    Subscribe { topics: Vec<&'a str> },
    /// `{"type":"unsubscribe","topics":[...]}`: leave the listed topics.
    Unsubscribe { topics: Vec<&'a str> },
    /// `{"type":"ping"}`
    Ping,
    /// `{"type":"pong"}`
    Pong,
}