refactor!: add ws endpoint, storage trait
This commit is contained in:
3
registry/endpoints/mod.rs
Normal file
3
registry/endpoints/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod peers;
|
||||
pub mod register;
|
||||
pub mod ws;
|
||||
8
registry/endpoints/peers.rs
Normal file
8
registry/endpoints/peers.rs
Normal file
@@ -0,0 +1,8 @@
|
||||
use actix_web::{HttpResponse, web};
|
||||
|
||||
use crate::{AppState, error::Result, storage::StorageImpl};
|
||||
|
||||
pub async fn get_peers(app_state: web::Data<AppState>) -> Result<HttpResponse> {
|
||||
let peers = app_state.storage.get_peers().await?;
|
||||
Ok(HttpResponse::Ok().json(peers))
|
||||
}
|
||||
27
registry/endpoints/register.rs
Normal file
27
registry/endpoints/register.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use crate::{
|
||||
AppState, PeerUpdate,
|
||||
error::Result,
|
||||
storage::{RegisterRequest, StorageImpl},
|
||||
utils::Peer,
|
||||
};
|
||||
use actix_web::{HttpResponse, web};
|
||||
|
||||
pub async fn register_peer(
|
||||
app_state: web::Data<AppState>,
|
||||
request: web::Json<RegisterRequest>,
|
||||
) -> Result<HttpResponse> {
|
||||
app_state.storage.register_device(&request).await?;
|
||||
app_state
|
||||
.peer_updates
|
||||
.send(PeerUpdate {
|
||||
peer: Peer {
|
||||
public_key: request.public_key.as_str().to_string(),
|
||||
public_ip: request.public_ip.to_string(),
|
||||
port: request.port.clone(),
|
||||
allowed_ips: request.allowed_ips.clone(),
|
||||
},
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
Ok(HttpResponse::Ok().finish())
|
||||
}
|
||||
1
registry/endpoints/ws/mod.rs
Normal file
1
registry/endpoints/ws/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod peers;
|
||||
87
registry/endpoints/ws/peers.rs
Normal file
87
registry/endpoints/ws/peers.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
use crate::{AppState, error::Error, storage::StorageImpl};
|
||||
use actix_web::{HttpRequest, HttpResponse, rt, web};
|
||||
use actix_ws::AggregatedMessage;
|
||||
use futures_util::StreamExt;
|
||||
|
||||
pub async fn peers(
|
||||
req: HttpRequest,
|
||||
stream: web::Payload,
|
||||
app_state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let (res, mut session, msg_stream) =
|
||||
actix_ws::handle(&req, stream).map_err(|e| Error::ws(e))?;
|
||||
|
||||
let mut msg_stream = msg_stream.aggregate_continuations();
|
||||
|
||||
let mut peer_rx = app_state.peer_updates.subscribe();
|
||||
|
||||
match app_state.storage.get_peers().await {
|
||||
Ok(initial_peers) => {
|
||||
let json = serde_json::to_string(&initial_peers).unwrap_or_else(|_| "[]".to_string());
|
||||
if session.text(json).await.is_err() {
|
||||
return Ok(res);
|
||||
}
|
||||
tracing::info!(
|
||||
"sent initial peer list ({} peers) to new WebSocket client",
|
||||
initial_peers.len()
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("failed to fetch initial peers: {:?}", e);
|
||||
let _ = session.close(None).await;
|
||||
return Ok(res);
|
||||
}
|
||||
}
|
||||
|
||||
rt::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
msg = msg_stream.next() => {
|
||||
match msg {
|
||||
Some(Ok(AggregatedMessage::Ping(data))) => {
|
||||
if session.pong(&data).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(Ok(AggregatedMessage::Pong(_))) => {}
|
||||
Some(Ok(AggregatedMessage::Close(_))) => {
|
||||
break;
|
||||
}
|
||||
Some(Ok(AggregatedMessage::Text(_))) => {
|
||||
}
|
||||
Some(Ok(AggregatedMessage::Binary(_))) => {
|
||||
}
|
||||
Some(Err(_)) => {
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
update = peer_rx.recv() => {
|
||||
match update {
|
||||
Ok(peer_update) => {
|
||||
let json = serde_json::to_string(&peer_update.peer)
|
||||
.unwrap_or_else(|_| "{}".to_string());
|
||||
if session.text(json).await.is_err() {
|
||||
break;
|
||||
}
|
||||
tracing::info!("sent peer update to WebSocket client: {}", peer_update.peer.public_key);
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => {
|
||||
tracing::warn!("WebSocket client lagged, missed {} updates", n);
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = session.close(None).await;
|
||||
tracing::info!("WebSocket client disconnected");
|
||||
});
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
45
registry/error.rs
Normal file
45
registry/error.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
use actix_web::{HttpResponse, ResponseError};
|
||||
use thiserror::Error;
|
||||
use thiserror_ext::{Box, Construct};
|
||||
|
||||
#[derive(Error, Debug, Box, Construct)]
|
||||
#[thiserror_ext(newtype(name = Error))]
|
||||
pub enum ErrorKind {
|
||||
#[error("error connecting to valkey at {address}")]
|
||||
ValkeyConnect {
|
||||
address: String,
|
||||
#[source]
|
||||
source: redis::RedisError,
|
||||
},
|
||||
#[error("error getting valkey connection")]
|
||||
ValkeyGetConnection(#[source] redis::RedisError),
|
||||
#[error("error adding peer")]
|
||||
AddPeer {
|
||||
public_key: String,
|
||||
public_ip: String,
|
||||
#[source]
|
||||
source: redis::RedisError,
|
||||
},
|
||||
#[error("error getting peers")]
|
||||
GetPeer(#[source] redis::RedisError),
|
||||
#[error("io error")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("error serializing json: {context}")]
|
||||
SerializeJson {
|
||||
context: String,
|
||||
#[source]
|
||||
source: serde_json::Error,
|
||||
},
|
||||
#[error("error handling ws")]
|
||||
Ws(#[source] actix_web::Error),
|
||||
}
|
||||
|
||||
impl ResponseError for Error {
|
||||
fn error_response(&self) -> actix_web::HttpResponse<actix_web::body::BoxBody> {
|
||||
match self.inner() {
|
||||
_ => HttpResponse::InternalServerError().finish(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub type Result<T> = core::result::Result<T, Error>;
|
||||
79
registry/main.rs
Normal file
79
registry/main.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
mod endpoints;
|
||||
mod error;
|
||||
mod storage;
|
||||
mod utils;
|
||||
|
||||
use actix_web::{App, HttpServer, web};
|
||||
use console::style;
|
||||
use thiserror_ext::AsReport;
|
||||
use tokio::sync::broadcast;
|
||||
use tracing::level_filters::LevelFilter;
|
||||
use tracing_subscriber::{
|
||||
EnvFilter, fmt::format::FmtSpan, layer::SubscriberExt, util::SubscriberInitExt,
|
||||
};
|
||||
|
||||
use crate::storage::{Storage, get_storage_from_env};
|
||||
use crate::utils::Peer;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PeerUpdate {
|
||||
pub peer: Peer,
|
||||
}
|
||||
|
||||
pub struct AppState {
|
||||
pub peer_updates: broadcast::Sender<PeerUpdate>,
|
||||
pub storage: Storage,
|
||||
}
|
||||
|
||||
async fn run() -> crate::error::Result<()> {
|
||||
let (peer_tx, _) = broadcast::channel::<PeerUpdate>(100);
|
||||
|
||||
let app_state = web::Data::new(AppState {
|
||||
storage: get_storage_from_env()?,
|
||||
peer_updates: peer_tx,
|
||||
});
|
||||
|
||||
HttpServer::new(move || {
|
||||
App::new()
|
||||
.app_data(app_state.clone())
|
||||
.wrap(tracing_actix_web::TracingLogger::default())
|
||||
.route(
|
||||
"/",
|
||||
web::get()
|
||||
.to(async || concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"))),
|
||||
)
|
||||
.route(
|
||||
"/register",
|
||||
web::post().to(endpoints::register::register_peer),
|
||||
)
|
||||
.route("/peers", web::get().to(endpoints::peers::get_peers))
|
||||
.route("/ws/peers", web::get().to(endpoints::ws::peers::peers))
|
||||
})
|
||||
.bind(("0.0.0.0", 8080))?
|
||||
.run()
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[actix_web::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
let tracing_env_filter = EnvFilter::builder()
|
||||
.with_default_directive(LevelFilter::INFO.into())
|
||||
.from_env_lossy();
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(tracing_env_filter)
|
||||
.with(
|
||||
tracing_subscriber::fmt::layer()
|
||||
.compact()
|
||||
.with_span_events(FmtSpan::CLOSE),
|
||||
)
|
||||
.init();
|
||||
if let Err(e) = run().await {
|
||||
eprintln!("{}: {}", style("error").red(), e.as_report());
|
||||
std::process::exit(1)
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
38
registry/storage/mod.rs
Normal file
38
registry/storage/mod.rs
Normal file
@@ -0,0 +1,38 @@
|
||||
use crate::{
|
||||
error::{Error, Result},
|
||||
utils::Peer,
|
||||
};
|
||||
|
||||
mod valkey;
|
||||
|
||||
pub use valkey::RegisterRequest;
|
||||
|
||||
pub enum Storage {
|
||||
Valkey(valkey::ValkeyStorage),
|
||||
}
|
||||
|
||||
pub trait StorageImpl {
|
||||
async fn register_device(&self, request: &RegisterRequest) -> Result<()>;
|
||||
async fn get_peers(&self) -> Result<Vec<Peer>>;
|
||||
}
|
||||
|
||||
impl StorageImpl for Storage {
|
||||
async fn register_device(&self, request: &RegisterRequest) -> Result<()> {
|
||||
match self {
|
||||
Self::Valkey(storage) => storage.register_device(request).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_peers(&self) -> Result<Vec<Peer>> {
|
||||
match self {
|
||||
Self::Valkey(storage) => storage.get_peers().await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_storage_from_env() -> Result<Storage> {
|
||||
Ok(Storage::Valkey(valkey::ValkeyStorage {
|
||||
valkey_client: redis::Client::open("redis://127.0.0.1:6379/")
|
||||
.map_err(|e| Error::valkey_connect(e, "127.0.0.1:6379/".to_string()))?,
|
||||
}))
|
||||
}
|
||||
108
registry/storage/valkey.rs
Normal file
108
registry/storage/valkey.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::net::IpAddr;
|
||||
|
||||
use ipnetwork::IpNetwork;
|
||||
use redis::AsyncTypedCommands;
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::utils::{Peer, WireguardPublicKey};
|
||||
use crate::{error::Error, storage::StorageImpl};
|
||||
|
||||
pub struct ValkeyStorage {
|
||||
pub valkey_client: redis::Client,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone)]
|
||||
pub struct RegisterRequest {
|
||||
pub public_ip: IpAddr,
|
||||
pub public_key: WireguardPublicKey,
|
||||
pub port: String,
|
||||
pub allowed_ips: Vec<IpNetwork>,
|
||||
}
|
||||
|
||||
impl StorageImpl for ValkeyStorage {
|
||||
async fn register_device(&self, request: &RegisterRequest) -> Result<()> {
|
||||
let mut conn = self
|
||||
.valkey_client
|
||||
.get_multiplexed_async_connection()
|
||||
.await
|
||||
.map_err(|e| Error::valkey_get_connection(e))?;
|
||||
conn.hset_multiple::<_, _, _>(
|
||||
format!("peer:{}", request.public_key.as_str()),
|
||||
&[
|
||||
("public_ip", &request.public_ip.to_string()),
|
||||
(
|
||||
"allowed_ips",
|
||||
&serde_json::to_string(&request.allowed_ips)
|
||||
.map_err(|e| Error::serialize_json(e, "serializing allowed_ips"))?,
|
||||
),
|
||||
("port", &request.port),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
Error::add_peer(
|
||||
e,
|
||||
request.public_key.as_str(),
|
||||
request.public_ip.to_string(),
|
||||
)
|
||||
})?;
|
||||
conn.sadd("peers", &request.public_key.as_str())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
Error::add_peer(
|
||||
e,
|
||||
request.public_key.as_str(),
|
||||
request.public_ip.to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_peers(&self) -> Result<Vec<Peer>> {
|
||||
let mut conn = self
|
||||
.valkey_client
|
||||
.get_multiplexed_async_connection()
|
||||
.await
|
||||
.map_err(|e| Error::valkey_get_connection(e))?;
|
||||
let keys: HashSet<String> = conn
|
||||
.smembers("peers")
|
||||
.await
|
||||
.map_err(|e| Error::get_peer(e))?;
|
||||
|
||||
if keys.is_empty() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
let keys: Vec<String> = keys.into_iter().collect();
|
||||
let mut pipe = redis::pipe();
|
||||
for key in keys.iter() {
|
||||
pipe.hgetall(format!("peer:{key}"));
|
||||
}
|
||||
|
||||
let peers: Vec<Peer> = pipe
|
||||
.query_async::<Vec<HashMap<String, String>>>(&mut conn)
|
||||
.await
|
||||
.map_err(|e| Error::get_peer(e))?
|
||||
.into_iter()
|
||||
.zip(keys.iter())
|
||||
.map(|(peer, key): (HashMap<String, String>, &String)| {
|
||||
let allowed_ips: Vec<IpNetwork> = peer
|
||||
.get("allowed_ips")
|
||||
.map(|s| serde_json::from_str(s).unwrap_or_default())
|
||||
.unwrap_or_default();
|
||||
|
||||
Peer {
|
||||
public_key: key.clone(),
|
||||
public_ip: peer.get("public_ip").unwrap().to_string(),
|
||||
port: peer.get("port").unwrap().to_string(),
|
||||
allowed_ips,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(peers)
|
||||
}
|
||||
}
|
||||
5
registry/utils/mod.rs
Normal file
5
registry/utils/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
mod peer;
|
||||
mod wg;
|
||||
|
||||
pub use peer::Peer;
|
||||
pub use wg::WireguardPublicKey;
|
||||
10
registry/utils/peer.rs
Normal file
10
registry/utils/peer.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
use ipnetwork::IpNetwork;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct Peer {
|
||||
pub public_key: String,
|
||||
pub public_ip: String,
|
||||
pub port: String,
|
||||
pub allowed_ips: Vec<IpNetwork>,
|
||||
}
|
||||
33
registry/utils/wg.rs
Normal file
33
registry/utils/wg.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use base64::Engine;
|
||||
use serde::{Deserialize, de};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct WireguardPublicKey(String);
|
||||
|
||||
impl<'de> Deserialize<'de> for WireguardPublicKey {
|
||||
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
|
||||
let bytes = base64::engine::general_purpose::STANDARD
|
||||
.decode(&s)
|
||||
.map_err(|_| de::Error::custom("invalid base64 in public key"))?;
|
||||
|
||||
if bytes.len() != 32 {
|
||||
return Err(de::Error::invalid_length(
|
||||
bytes.len(),
|
||||
&"exactly 32 bytes for a Wireguard public key",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(WireguardPublicKey(s))
|
||||
}
|
||||
}
|
||||
|
||||
impl WireguardPublicKey {
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user