feat: implement client

This commit is contained in:
2026-02-12 13:29:15 -08:00
parent 0d0a548a46
commit 5aa6b98742
23 changed files with 645 additions and 164 deletions

View File

@@ -0,0 +1,3 @@
pub mod peers;
pub mod register;
pub mod ws;

View File

@@ -0,0 +1,8 @@
use actix_web::{HttpResponse, web};
use crate::{AppState, error::Result, storage::StorageImpl};
pub async fn get_peers(app_state: web::Data<AppState>) -> Result<HttpResponse> {
let peers = app_state.storage.get_peers().await?;
Ok(HttpResponse::Ok().json(peers))
}

View File

@@ -0,0 +1,27 @@
use crate::{
AppState, PeerUpdate,
error::Result,
storage::{RegisterRequest, StorageImpl},
};
use actix_web::{HttpResponse, web};
use registry::Peer;
pub async fn register_peer(
app_state: web::Data<AppState>,
request: web::Json<RegisterRequest>,
) -> Result<HttpResponse> {
app_state.storage.register_device(&request).await?;
app_state
.peer_updates
.send(PeerUpdate {
peer: Peer {
public_key: request.public_key.as_str().to_string(),
public_ip: request.public_ip.to_string(),
port: request.port.clone(),
allowed_ips: request.allowed_ips.clone(),
},
})
.unwrap();
Ok(HttpResponse::Ok().finish())
}

View File

@@ -0,0 +1 @@
pub mod peers;

View File

@@ -0,0 +1,92 @@
use crate::{AppState, error::Error, storage::StorageImpl};
use actix_web::{HttpRequest, HttpResponse, rt, web};
use actix_ws::AggregatedMessage;
use futures_util::StreamExt;
use registry::PeerMessage;
pub async fn peers(
req: HttpRequest,
stream: web::Payload,
app_state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let (res, mut session, msg_stream) =
actix_ws::handle(&req, stream).map_err(|e| Error::ws(e))?;
let mut msg_stream = msg_stream.aggregate_continuations();
let mut peer_rx = app_state.peer_updates.subscribe();
match app_state.storage.get_peers().await {
Ok(initial_peers) => {
let msg = PeerMessage::HydratePeers {
peers: initial_peers.clone(),
};
let json = serde_json::to_string(&msg)
.unwrap_or_else(|_| r#"{"type":"HydratePeers","peers":[]}"#.to_string());
if session.text(json).await.is_err() {
return Ok(res);
}
tracing::info!(
"sent initial peer list ({} peers) to new WebSocket client",
initial_peers.len()
);
}
Err(e) => {
tracing::warn!("failed to fetch initial peers: {:?}", e);
session.close(None).await.ok();
return Ok(res);
}
}
rt::spawn(async move {
loop {
tokio::select! {
msg = msg_stream.next() => {
match msg {
Some(Ok(AggregatedMessage::Ping(data))) => {
if session.pong(&data).await.is_err() {
break;
}
}
Some(Ok(AggregatedMessage::Pong(_))) => {}
Some(Ok(AggregatedMessage::Close(_))) => {
break;
}
Some(Ok(AggregatedMessage::Text(_))) => {
}
Some(Ok(AggregatedMessage::Binary(_))) => {
}
Some(Err(_)) => {
break;
}
None => {
break;
}
}
}
update = peer_rx.recv() => {
match update {
Ok(peer_update) => {
let json = serde_json::to_string(&peer_update.peer)
.unwrap_or_else(|_| "{}".to_string());
if session.text(json).await.is_err() {
break;
}
tracing::info!("sent peer update to WebSocket client: {}", peer_update.peer.public_key);
}
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("WebSocket client lagged, missed {} updates", n);
}
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
break;
}
}
}
}
}
session.close(None).await.ok();
tracing::info!("WebSocket client disconnected");
});
Ok(res)
}

46
registry/src/error.rs Normal file
View File

@@ -0,0 +1,46 @@
use actix_web::{HttpResponse, ResponseError};
use thiserror::Error;
use thiserror_ext::{Box, Construct};
#[derive(Error, Debug, Box, Construct)]
#[thiserror_ext(newtype(name = Error))]
pub enum ErrorKind {
#[error("error connecting to valkey at {address}")]
ValkeyConnect {
address: String,
#[source]
source: redis::RedisError,
},
#[error("error getting valkey connection")]
ValkeyGetConnection(#[source] redis::RedisError),
#[error("error adding peer")]
AddPeer {
public_key: String,
public_ip: String,
#[source]
source: redis::RedisError,
},
#[error("error getting peers")]
GetPeer(#[source] redis::RedisError),
#[error("io error")]
Io(#[from] std::io::Error),
#[error("error serializing json: {context}")]
SerializeJson {
context: String,
#[source]
source: serde_json::Error,
},
#[error("error handling ws")]
Ws(#[source] actix_web::Error),
}
impl ResponseError for Error {
fn error_response(&self) -> actix_web::HttpResponse<actix_web::body::BoxBody> {
match self.inner() {
ErrorKind::Ws(e) => e.error_response(),
_ => HttpResponse::InternalServerError().finish(),
}
}
}
pub type Result<T> = core::result::Result<T, Error>;

4
registry/src/lib.rs Normal file
View File

@@ -0,0 +1,4 @@
mod types;
mod utils;
pub use types::peer_message::*;

80
registry/src/main.rs Normal file
View File

@@ -0,0 +1,80 @@
mod endpoints;
mod error;
mod storage;
mod types;
mod utils;
use actix_web::{App, HttpServer, web};
use console::style;
use registry::Peer;
use thiserror_ext::AsReport;
use tokio::sync::broadcast;
use tracing::level_filters::LevelFilter;
use tracing_subscriber::{
EnvFilter, fmt::format::FmtSpan, layer::SubscriberExt, util::SubscriberInitExt,
};
use crate::storage::{Storage, get_storage_from_env};
#[derive(Clone, Debug)]
pub struct PeerUpdate {
pub peer: Peer,
}
pub struct AppState {
pub peer_updates: broadcast::Sender<PeerUpdate>,
pub storage: Storage,
}
async fn run() -> crate::error::Result<()> {
let (peer_tx, _) = broadcast::channel::<PeerUpdate>(100);
let app_state = web::Data::new(AppState {
storage: get_storage_from_env()?,
peer_updates: peer_tx,
});
HttpServer::new(move || {
App::new()
.app_data(app_state.clone())
.wrap(tracing_actix_web::TracingLogger::default())
.route(
"/",
web::get()
.to(async || concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"))),
)
.route(
"/register",
web::post().to(endpoints::register::register_peer),
)
.route("/peers", web::get().to(endpoints::peers::get_peers))
.route("/ws/peers", web::get().to(endpoints::ws::peers::peers))
})
.bind(("0.0.0.0", 8080))?
.run()
.await?;
Ok(())
}
#[actix_web::main]
async fn main() -> std::io::Result<()> {
let tracing_env_filter = EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.from_env_lossy();
tracing_subscriber::registry()
.with(tracing_env_filter)
.with(
tracing_subscriber::fmt::layer()
.compact()
.with_span_events(FmtSpan::CLOSE),
)
.init();
if let Err(e) = run().await {
eprintln!("{}: {}", style("error").red(), e.as_report());
std::process::exit(1)
}
Ok(())
}

View File

@@ -0,0 +1,36 @@
use crate::error::{Error, Result};
mod valkey;
use registry::Peer;
pub use valkey::RegisterRequest;
pub enum Storage {
Valkey(valkey::ValkeyStorage),
}
pub trait StorageImpl {
async fn register_device(&self, request: &RegisterRequest) -> Result<()>;
async fn get_peers(&self) -> Result<Vec<Peer>>;
}
impl StorageImpl for Storage {
async fn register_device(&self, request: &RegisterRequest) -> Result<()> {
match self {
Self::Valkey(storage) => storage.register_device(request).await,
}
}
async fn get_peers(&self) -> Result<Vec<Peer>> {
match self {
Self::Valkey(storage) => storage.get_peers().await,
}
}
}
pub fn get_storage_from_env() -> Result<Storage> {
Ok(Storage::Valkey(valkey::ValkeyStorage {
valkey_client: redis::Client::open("redis://127.0.0.1:6379/")
.map_err(|e| Error::valkey_connect(e, "127.0.0.1:6379/".to_string()))?,
}))
}

View File

@@ -0,0 +1,109 @@
use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use ipnetwork::IpNetwork;
use redis::AsyncTypedCommands;
use registry::Peer;
use serde::Deserialize;
use crate::error::Result;
use crate::utils::WireguardPublicKey;
use crate::{error::Error, storage::StorageImpl};
pub struct ValkeyStorage {
pub valkey_client: redis::Client,
}
#[derive(Deserialize, Clone)]
pub struct RegisterRequest {
pub public_ip: IpAddr,
pub public_key: WireguardPublicKey,
pub port: String,
pub allowed_ips: Vec<IpNetwork>,
}
impl StorageImpl for ValkeyStorage {
async fn register_device(&self, request: &RegisterRequest) -> Result<()> {
let mut conn = self
.valkey_client
.get_multiplexed_async_connection()
.await
.map_err(|e| Error::valkey_get_connection(e))?;
conn.hset_multiple::<_, _, _>(
format!("peer:{}", request.public_key.as_str()),
&[
("public_ip", &request.public_ip.to_string()),
(
"allowed_ips",
&serde_json::to_string(&request.allowed_ips)
.map_err(|e| Error::serialize_json(e, "serializing allowed_ips"))?,
),
("port", &request.port),
],
)
.await
.map_err(|e| {
Error::add_peer(
e,
request.public_key.as_str(),
request.public_ip.to_string(),
)
})?;
conn.sadd("peers", &request.public_key.as_str())
.await
.map_err(|e| {
Error::add_peer(
e,
request.public_key.as_str(),
request.public_ip.to_string(),
)
})?;
Ok(())
}
async fn get_peers(&self) -> Result<Vec<Peer>> {
let mut conn = self
.valkey_client
.get_multiplexed_async_connection()
.await
.map_err(|e| Error::valkey_get_connection(e))?;
let keys: HashSet<String> = conn
.smembers("peers")
.await
.map_err(|e| Error::get_peer(e))?;
if keys.is_empty() {
return Ok(vec![]);
}
let keys: Vec<String> = keys.into_iter().collect();
let mut pipe = redis::pipe();
for key in keys.iter() {
pipe.hgetall(format!("peer:{key}"));
}
let peers: Vec<Peer> = pipe
.query_async::<Vec<HashMap<String, String>>>(&mut conn)
.await
.map_err(|e| Error::get_peer(e))?
.into_iter()
.zip(keys.iter())
.map(|(peer, key): (HashMap<String, String>, &String)| {
let allowed_ips: Vec<IpNetwork> = peer
.get("allowed_ips")
.map(|s| serde_json::from_str(s).unwrap_or_default())
.unwrap_or_default();
Peer {
public_key: key.clone(),
public_ip: peer.get("public_ip").unwrap().to_string(),
port: peer.get("port").unwrap().to_string(),
allowed_ips,
}
})
.collect();
Ok(peers)
}
}

View File

@@ -0,0 +1 @@
pub mod peer_message;

View File

@@ -0,0 +1,18 @@
use serde::{Deserialize, Serialize};
use ipnetwork::IpNetwork;
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Peer {
pub public_key: String,
pub public_ip: String,
pub port: String,
pub allowed_ips: Vec<IpNetwork>,
}
#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum PeerMessage {
HydratePeers { peers: Vec<Peer> },
PeerUpdate { peer: Peer },
}

View File

@@ -0,0 +1,3 @@
mod wg;
pub use wg::WireguardPublicKey;

33
registry/src/utils/wg.rs Normal file
View File

@@ -0,0 +1,33 @@
use base64::Engine;
use serde::{Deserialize, de};
#[derive(Clone)]
pub struct WireguardPublicKey(String);
impl<'de> Deserialize<'de> for WireguardPublicKey {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
let bytes = base64::engine::general_purpose::STANDARD
.decode(&s)
.map_err(|_| de::Error::custom("invalid base64 in public key"))?;
if bytes.len() != 32 {
return Err(de::Error::invalid_length(
bytes.len(),
&"exactly 32 bytes for a Wireguard public key",
));
}
Ok(WireguardPublicKey(s))
}
}
impl WireguardPublicKey {
pub fn as_str(&self) -> &str {
&self.0
}
}