diff --git a/Cargo.lock b/Cargo.lock index f185686..8df1bb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -185,6 +185,20 @@ dependencies = [ "syn", ] +[[package]] +name = "actix-ws" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12d4f2fbee3ef7a22fa6cb0e416b962237a167ed0419f22d4e451da2d7f082f8" +dependencies = [ + "actix-codec", + "actix-http", + "actix-web", + "bytestring", + "futures-core", + "tokio", +] + [[package]] name = "adler2" version = "2.0.1" @@ -534,6 +548,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -541,6 +570,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -549,6 +579,34 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -567,9 +625,13 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-macro", "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", "slab", @@ -1468,9 +1530,21 @@ dependencies = [ "pin-project-lite", "signal-hook-registry", "socket2 0.6.2", + "tokio-macros", "windows-sys 0.61.2", ] +[[package]] +name = "tokio-macros" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -1695,8 +1769,11 @@ name = "wg-mesh" version = "0.1.0" dependencies = [ "actix-web", + "actix-ws", "base64", "console", + "futures", + "futures-util", "ipnetwork", "redis", "serde", diff --git a/Cargo.toml b/Cargo.toml index 41991fa..7f532ef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,15 +5,26 @@ edition = "2024" [dependencies] actix-web = "4.12.1" +actix-ws = "0.3.1" base64 = "0.22.1" console = "0.16.2" +futures = "0.3.31" +futures-util = "0.3.31" ipnetwork = { version = "0.21.1", features = ["serde"] } redis = { version = "=1.0.2", features = ["connection-manager", "tokio-comp"] } serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.149" thiserror = "2.0.18" thiserror-ext = "0.3.0" -tokio = "1.49.0" +tokio = { version = "1.49.0", features = ["macros", "sync"] } tracing = "0.1.44" tracing-actix-web = "0.7.21" tracing-subscriber = { version = "0.3.22", features = ["env-filter"] } + +[[bin]] +name = "registry" +path = "registry/main.rs" + +[[bin]] +name = "client" +path = "client/main.rs" diff --git a/client/main.rs b/client/main.rs new file mode 100644 index 0000000..f328e4d --- /dev/null +++ b/client/main.rs @@ -0,0 +1 @@ +fn main() {} diff --git a/justfile b/justfile index 2fb610f..0e59899 100644 --- a/justfile +++ b/justfile @@ -1,2 +1,2 @@ -dev: - bacon dev +dev bin="registry": + bacon dev -- --bin {{bin}} diff --git a/src/endpoints/mod.rs b/registry/endpoints/mod.rs similarity index 73% rename from src/endpoints/mod.rs rename to registry/endpoints/mod.rs index 8a9c71b..7175a3d 100644 --- a/src/endpoints/mod.rs +++ b/registry/endpoints/mod.rs @@ -1,2 +1,3 @@ pub mod peers; pub mod register; +pub mod ws; diff --git a/registry/endpoints/peers.rs b/registry/endpoints/peers.rs new file mode 100644 index 0000000..24211c8 --- /dev/null +++ b/registry/endpoints/peers.rs @@ -0,0 +1,8 @@ +use actix_web::{HttpResponse, web}; + +use crate::{AppState, error::Result, storage::StorageImpl}; + +pub async fn get_peers(app_state: web::Data) -> Result { + let peers = app_state.storage.get_peers().await?; + Ok(HttpResponse::Ok().json(peers)) +} diff --git a/registry/endpoints/register.rs b/registry/endpoints/register.rs new file mode 100644 index 0000000..a77ad09 --- /dev/null +++ b/registry/endpoints/register.rs @@ -0,0 +1,27 @@ +use crate::{ + AppState, PeerUpdate, + error::Result, + storage::{RegisterRequest, StorageImpl}, + utils::Peer, +}; +use actix_web::{HttpResponse, web}; + +pub async fn register_peer( + app_state: web::Data, + request: web::Json, +) -> Result { + app_state.storage.register_device(&request).await?; + app_state + .peer_updates + .send(PeerUpdate { + peer: Peer { + public_key: request.public_key.as_str().to_string(), + public_ip: request.public_ip.to_string(), + port: request.port.clone(), + allowed_ips: request.allowed_ips.clone(), + }, + }) + .unwrap(); + + Ok(HttpResponse::Ok().finish()) +} diff --git a/registry/endpoints/ws/mod.rs b/registry/endpoints/ws/mod.rs new file mode 100644 index 0000000..d9cebd1 --- /dev/null +++ b/registry/endpoints/ws/mod.rs @@ -0,0 +1 @@ +pub mod peers; diff --git a/registry/endpoints/ws/peers.rs b/registry/endpoints/ws/peers.rs new file mode 100644 index 0000000..d0dc779 --- /dev/null +++ b/registry/endpoints/ws/peers.rs @@ -0,0 +1,87 @@ +use crate::{AppState, error::Error, storage::StorageImpl}; +use actix_web::{HttpRequest, HttpResponse, rt, web}; +use actix_ws::AggregatedMessage; +use futures_util::StreamExt; + +pub async fn peers( + req: HttpRequest, + stream: web::Payload, + app_state: web::Data, +) -> Result { + let (res, mut session, msg_stream) = + actix_ws::handle(&req, stream).map_err(|e| Error::ws(e))?; + + let mut msg_stream = msg_stream.aggregate_continuations(); + + let mut peer_rx = app_state.peer_updates.subscribe(); + + match app_state.storage.get_peers().await { + Ok(initial_peers) => { + let json = serde_json::to_string(&initial_peers).unwrap_or_else(|_| "[]".to_string()); + if session.text(json).await.is_err() { + return Ok(res); + } + tracing::info!( + "sent initial peer list ({} peers) to new WebSocket client", + initial_peers.len() + ); + } + Err(e) => { + tracing::warn!("failed to fetch initial peers: {:?}", e); + let _ = session.close(None).await; + return Ok(res); + } + } + + rt::spawn(async move { + loop { + tokio::select! { + msg = msg_stream.next() => { + match msg { + Some(Ok(AggregatedMessage::Ping(data))) => { + if session.pong(&data).await.is_err() { + break; + } + } + Some(Ok(AggregatedMessage::Pong(_))) => {} + Some(Ok(AggregatedMessage::Close(_))) => { + break; + } + Some(Ok(AggregatedMessage::Text(_))) => { + } + Some(Ok(AggregatedMessage::Binary(_))) => { + } + Some(Err(_)) => { + break; + } + None => { + break; + } + } + } + update = peer_rx.recv() => { + match update { + Ok(peer_update) => { + let json = serde_json::to_string(&peer_update.peer) + .unwrap_or_else(|_| "{}".to_string()); + if session.text(json).await.is_err() { + break; + } + tracing::info!("sent peer update to WebSocket client: {}", peer_update.peer.public_key); + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + tracing::warn!("WebSocket client lagged, missed {} updates", n); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => { + break; + } + } + } + } + } + let _ = session.close(None).await; + tracing::info!("WebSocket client disconnected"); + }); + + Ok(res) +} diff --git a/src/error.rs b/registry/error.rs similarity index 94% rename from src/error.rs rename to registry/error.rs index 47bc7e2..6d277a7 100644 --- a/src/error.rs +++ b/registry/error.rs @@ -30,6 +30,8 @@ pub enum ErrorKind { #[source] source: serde_json::Error, }, + #[error("error handling ws")] + Ws(#[source] actix_web::Error), } impl ResponseError for Error { diff --git a/src/main.rs b/registry/main.rs similarity index 75% rename from src/main.rs rename to registry/main.rs index 5d21b73..ecd5491 100644 --- a/src/main.rs +++ b/registry/main.rs @@ -1,25 +1,36 @@ mod endpoints; mod error; +mod storage; mod utils; use actix_web::{App, HttpServer, web}; use console::style; use thiserror_ext::AsReport; +use tokio::sync::broadcast; use tracing::level_filters::LevelFilter; use tracing_subscriber::{ EnvFilter, fmt::format::FmtSpan, layer::SubscriberExt, util::SubscriberInitExt, }; -use crate::error::Error; +use crate::storage::{Storage, get_storage_from_env}; +use crate::utils::Peer; -struct AppState { - valkey_client: redis::Client, +#[derive(Clone, Debug)] +pub struct PeerUpdate { + pub peer: Peer, +} + +pub struct AppState { + pub peer_updates: broadcast::Sender, + pub storage: Storage, } async fn run() -> crate::error::Result<()> { + let (peer_tx, _) = broadcast::channel::(100); + let app_state = web::Data::new(AppState { - valkey_client: redis::Client::open("redis://127.0.0.1:6379/") - .map_err(|e| Error::valkey_connect(e, "127.0.0.1:6379/".to_string()))?, + storage: get_storage_from_env()?, + peer_updates: peer_tx, }); HttpServer::new(move || { @@ -36,6 +47,7 @@ async fn run() -> crate::error::Result<()> { web::post().to(endpoints::register::register_peer), ) .route("/peers", web::get().to(endpoints::peers::get_peers)) + .route("/ws/peers", web::get().to(endpoints::ws::peers::peers)) }) .bind(("0.0.0.0", 8080))? .run() diff --git a/registry/storage/mod.rs b/registry/storage/mod.rs new file mode 100644 index 0000000..8c8cc33 --- /dev/null +++ b/registry/storage/mod.rs @@ -0,0 +1,38 @@ +use crate::{ + error::{Error, Result}, + utils::Peer, +}; + +mod valkey; + +pub use valkey::RegisterRequest; + +pub enum Storage { + Valkey(valkey::ValkeyStorage), +} + +pub trait StorageImpl { + async fn register_device(&self, request: &RegisterRequest) -> Result<()>; + async fn get_peers(&self) -> Result>; +} + +impl StorageImpl for Storage { + async fn register_device(&self, request: &RegisterRequest) -> Result<()> { + match self { + Self::Valkey(storage) => storage.register_device(request).await, + } + } + + async fn get_peers(&self) -> Result> { + match self { + Self::Valkey(storage) => storage.get_peers().await, + } + } +} + +pub fn get_storage_from_env() -> Result { + Ok(Storage::Valkey(valkey::ValkeyStorage { + valkey_client: redis::Client::open("redis://127.0.0.1:6379/") + .map_err(|e| Error::valkey_connect(e, "127.0.0.1:6379/".to_string()))?, + })) +} diff --git a/registry/storage/valkey.rs b/registry/storage/valkey.rs new file mode 100644 index 0000000..ba81920 --- /dev/null +++ b/registry/storage/valkey.rs @@ -0,0 +1,108 @@ +use std::collections::{HashMap, HashSet}; +use std::net::IpAddr; + +use ipnetwork::IpNetwork; +use redis::AsyncTypedCommands; +use serde::Deserialize; + +use crate::error::Result; +use crate::utils::{Peer, WireguardPublicKey}; +use crate::{error::Error, storage::StorageImpl}; + +pub struct ValkeyStorage { + pub valkey_client: redis::Client, +} + +#[derive(Deserialize, Clone)] +pub struct RegisterRequest { + pub public_ip: IpAddr, + pub public_key: WireguardPublicKey, + pub port: String, + pub allowed_ips: Vec, +} + +impl StorageImpl for ValkeyStorage { + async fn register_device(&self, request: &RegisterRequest) -> Result<()> { + let mut conn = self + .valkey_client + .get_multiplexed_async_connection() + .await + .map_err(|e| Error::valkey_get_connection(e))?; + conn.hset_multiple::<_, _, _>( + format!("peer:{}", request.public_key.as_str()), + &[ + ("public_ip", &request.public_ip.to_string()), + ( + "allowed_ips", + &serde_json::to_string(&request.allowed_ips) + .map_err(|e| Error::serialize_json(e, "serializing allowed_ips"))?, + ), + ("port", &request.port), + ], + ) + .await + .map_err(|e| { + Error::add_peer( + e, + request.public_key.as_str(), + request.public_ip.to_string(), + ) + })?; + conn.sadd("peers", &request.public_key.as_str()) + .await + .map_err(|e| { + Error::add_peer( + e, + request.public_key.as_str(), + request.public_ip.to_string(), + ) + })?; + + Ok(()) + } + + async fn get_peers(&self) -> Result> { + let mut conn = self + .valkey_client + .get_multiplexed_async_connection() + .await + .map_err(|e| Error::valkey_get_connection(e))?; + let keys: HashSet = conn + .smembers("peers") + .await + .map_err(|e| Error::get_peer(e))?; + + if keys.is_empty() { + return Ok(vec![]); + } + + let keys: Vec = keys.into_iter().collect(); + let mut pipe = redis::pipe(); + for key in keys.iter() { + pipe.hgetall(format!("peer:{key}")); + } + + let peers: Vec = pipe + .query_async::>>(&mut conn) + .await + .map_err(|e| Error::get_peer(e))? + .into_iter() + .zip(keys.iter()) + .map(|(peer, key): (HashMap, &String)| { + let allowed_ips: Vec = peer + .get("allowed_ips") + .map(|s| serde_json::from_str(s).unwrap_or_default()) + .unwrap_or_default(); + + Peer { + public_key: key.clone(), + public_ip: peer.get("public_ip").unwrap().to_string(), + port: peer.get("port").unwrap().to_string(), + allowed_ips, + } + }) + .collect(); + + Ok(peers) + } +} diff --git a/src/utils/mod.rs b/registry/utils/mod.rs similarity index 100% rename from src/utils/mod.rs rename to registry/utils/mod.rs diff --git a/src/utils/peer.rs b/registry/utils/peer.rs similarity index 71% rename from src/utils/peer.rs rename to registry/utils/peer.rs index b21bb10..059b1c9 100644 --- a/src/utils/peer.rs +++ b/registry/utils/peer.rs @@ -1,9 +1,10 @@ use ipnetwork::IpNetwork; use serde::{Deserialize, Serialize}; -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub struct Peer { pub public_key: String, pub public_ip: String, + pub port: String, pub allowed_ips: Vec, } diff --git a/src/utils/wg.rs b/registry/utils/wg.rs similarity index 100% rename from src/utils/wg.rs rename to registry/utils/wg.rs diff --git a/src/endpoints/peers.rs b/src/endpoints/peers.rs deleted file mode 100644 index 43b8a6c..0000000 --- a/src/endpoints/peers.rs +++ /dev/null @@ -1,49 +0,0 @@ -use std::collections::HashMap; - -use actix_web::{HttpResponse, web}; -use ipnetwork::IpNetwork; -use redis::AsyncTypedCommands; - -use crate::{ - AppState, - error::{Error, Result}, - utils::Peer, -}; - -pub async fn get_peers(app_state: web::Data) -> Result { - let mut conn = app_state - .valkey_client - .get_multiplexed_async_connection() - .await - .map_err(|e| Error::valkey_get_connection(e))?; - let keys = conn - .smembers("peers") - .await - .map_err(|e| Error::get_peer(e))?; - let mut pipe = redis::pipe(); - for key in keys.iter() { - pipe.hgetall(format!("peer:{key}")); - } - - let peers: Vec = pipe - .query_async::>>(&mut conn) - .await - .map_err(|e| Error::get_peer(e))? - .into_iter() - .zip(keys.iter()) - .map(|(peer, key)| { - let allowed_ips: Vec = peer - .get("allowed_ips") - .map(|s| serde_json::from_str(s).unwrap_or_default()) - .unwrap_or_default(); - - Peer { - public_key: key.clone(), - public_ip: peer.get("public_ip").unwrap().to_string(), - allowed_ips, - } - }) - .collect(); - - Ok(HttpResponse::Ok().json(peers)) -} diff --git a/src/endpoints/register.rs b/src/endpoints/register.rs deleted file mode 100644 index 87dc66f..0000000 --- a/src/endpoints/register.rs +++ /dev/null @@ -1,59 +0,0 @@ -use std::net::IpAddr; - -use crate::{ - AppState, - error::{Error, Result}, - utils::WireguardPublicKey, -}; -use actix_web::{HttpResponse, web}; -use ipnetwork::IpNetwork; -use redis::AsyncTypedCommands; -use serde::Deserialize; - -#[derive(Deserialize, Clone)] -pub struct RegisterRequest { - public_ip: IpAddr, - public_key: WireguardPublicKey, - allowed_ips: Vec, -} - -pub async fn register_peer( - app_state: web::Data, - request: web::Json, -) -> Result { - let mut conn = app_state - .valkey_client - .get_multiplexed_async_connection() - .await - .map_err(|e| Error::valkey_get_connection(e))?; - conn.hset_multiple::<_, _, _>( - format!("peer:{}", request.public_key.as_str()), - &[ - ("public_ip", &request.public_ip.to_string()), - ( - "allowed_ips", - &serde_json::to_string(&request.allowed_ips) - .map_err(|e| Error::serialize_json(e, "serializing allowed_ips"))?, - ), - ], - ) - .await - .map_err(|e| { - Error::add_peer( - e, - request.public_key.as_str(), - request.public_ip.to_string(), - ) - })?; - conn.sadd("peers", &request.public_key.as_str()) - .await - .map_err(|e| { - Error::add_peer( - e, - request.public_key.as_str(), - request.public_ip.to_string(), - ) - })?; - - Ok(HttpResponse::Ok().finish()) -}