Compare commits

...

3 Commits

23 changed files with 2060 additions and 76 deletions

1197
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,3 @@
[workspace] [workspace]
members = ["registry", "client"] members = ["registry", "client"]
resolver = "3" resolver = "3"
[workspace.dependencies]

View File

@@ -11,8 +11,23 @@ serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.149" serde_json = "1.0.149"
thiserror = "2.0.18" thiserror = "2.0.18"
thiserror-ext = "0.3.0" thiserror-ext = "0.3.0"
tokio = { version = "1.49.0", features = ["macros", "rt-multi-thread"] } tokio = { version = "1.49.0", features = ["macros", "net", "rt-multi-thread", "signal"] }
tokio-tungstenite = { version = "0.28.0", features = ["rustls-tls-native-roots"] } tokio-tungstenite = { version = "0.28.0", features = ["rustls-tls-native-roots"] }
tracing = "0.1.44" tracing = "0.1.44"
tracing-subscriber = { version = "0.3.22", features = ["env-filter"] } tracing-subscriber = { version = "0.3.22", features = ["env-filter"] }
registry = { path = "../registry" } registry = { path = "../registry" }
dirs = "6.0.0"
toml = "1.0.1"
url = "2.5.8"
reqwest = { version = "0.13.2", features = ["json"] }
stunclient = "0.4.2"
ipnetwork = { version = "0.21.1", features = ["serde"] }
base64 = "0.22.1"
libc = "0.2"
# WireGuard netlink control
wireguard-control = "1.1"
# Network interface management
rtnetlink = "0.14"
netlink-packet-route = "0.19"

52
client/src/config.rs Normal file
View File

@@ -0,0 +1,52 @@
use std::path::PathBuf;
use crate::error::{Error, Result};
use ipnetwork::IpNetwork;
use serde::Deserialize;
pub const DEFAULT_KEEPALIVE: u16 = 25;
pub const INTERFACE_NAME: &str = "mesh0";
#[derive(Deserialize)]
pub struct Config {
pub interface: InterfaceConfig,
pub server: ServerConfig,
}
#[derive(Deserialize)]
pub struct InterfaceConfig {
pub private_key: String,
pub public_key: String,
#[serde(default = "default_listen_port")]
pub listen_port: u16,
#[serde(default = "default_keepalive")]
pub persistent_keepalive: u16,
pub allowed_ips: Option<Vec<IpNetwork>>,
}
fn default_listen_port() -> u16 {
51820
}
fn default_keepalive() -> u16 {
DEFAULT_KEEPALIVE
}
#[derive(Deserialize)]
pub struct ServerConfig {
pub ws_url: String,
pub url: String,
}
impl Config {
pub fn load() -> Result<Self> {
let path = base_path().join("config.toml");
let bytes = std::fs::read(&path).map_err(|e| Error::read_config(e, &path))?;
Ok(toml::from_slice(&bytes)?)
}
}
pub fn base_path() -> PathBuf {
dirs::config_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("wg-mesh")
}

142
client/src/discovery.rs Normal file
View File

@@ -0,0 +1,142 @@
use std::net::{IpAddr, SocketAddr, ToSocketAddrs, UdpSocket};
use std::time::Duration;
use stunclient::StunClient;
use crate::error::{Error, Result};
const STUN_SERVERS: &[&str] = &[
"stun.l.google.com:19302",
"stun.cloudflare.com:3478",
"stun.stunprotocol.org:3478",
];
const HTTP_IP_SERVICES: &[&str] = &["https://icanhazip.com", "https://ifconfig.me/ip"];
const STUN_TIMEOUT: Duration = Duration::from_secs(3);
const HTTP_TIMEOUT: Duration = Duration::from_secs(5);
#[derive(Debug, Clone)]
pub struct PublicEndpoint {
pub ip: IpAddr,
pub port: u16,
}
impl std::fmt::Display for PublicEndpoint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}:{}", self.ip, self.port)
}
}
pub async fn discover_public_endpoint(local_port: u16) -> Result<PublicEndpoint> {
match stun_discover(local_port) {
Ok(endpoint) => {
tracing::debug!(
ip = %endpoint.ip,
port = endpoint.port,
"STUN discovery successful"
);
return Ok(endpoint);
}
Err(e) => {
tracing::warn!("STUN discovery failed, falling back to HTTP: {e}");
}
}
let ip = http_discover_ip().await?;
let endpoint = PublicEndpoint {
ip,
port: local_port,
};
tracing::debug!(
ip = %endpoint.ip,
port = endpoint.port,
"HTTP fallback discovery successful (port may not reflect NAT mapping)"
);
Ok(endpoint)
}
fn stun_discover(local_port: u16) -> Result<PublicEndpoint> {
let local_addr: SocketAddr = format!("0.0.0.0:{local_port}")
.parse()
.expect("valid socket addr");
let socket = UdpSocket::bind(local_addr).map_err(|e| {
Error::stun_discovery(format!(
"failed to bind UDP socket on port {local_port}: {e}"
))
})?;
socket
.set_read_timeout(Some(STUN_TIMEOUT))
.map_err(|e| Error::stun_discovery(format!("failed to set socket timeout: {e}")))?;
for server in STUN_SERVERS {
tracing::debug!(server, "Attempting STUN discovery");
match stun_query(&socket, server) {
Ok(endpoint) => return Ok(endpoint),
Err(e) => {
tracing::debug!(server, error = %e, "STUN server failed, trying next");
}
}
}
Err(Error::discovery_failed())
}
fn stun_query(socket: &UdpSocket, server: &str) -> Result<PublicEndpoint> {
let server_addr = server
.to_socket_addrs()
.map_err(|e| Error::stun_discovery(format!("{server}: {e}")))?
.next()
.ok_or_else(|| Error::stun_discovery(format!("{server}: no addresses found")))?;
let client = StunClient::new(server_addr);
let public_addr = client
.query_external_address(socket)
.map_err(|e| Error::stun_discovery(format!("{server}: {e}")))?;
Ok(PublicEndpoint {
ip: public_addr.ip(),
port: public_addr.port(),
})
}
async fn http_discover_ip() -> Result<IpAddr> {
let client = reqwest::Client::builder()
.timeout(HTTP_TIMEOUT)
.build()
.map_err(Error::http_discovery)?;
for service in HTTP_IP_SERVICES {
tracing::debug!(service, "Attempting HTTP IP discovery");
match http_query_ip(&client, service).await {
Ok(ip) => return Ok(ip),
Err(e) => {
tracing::debug!(service, error = %e, "HTTP service failed, trying next");
}
}
}
Err(Error::discovery_failed())
}
async fn http_query_ip(client: &reqwest::Client, url: &str) -> Result<IpAddr> {
let response = client
.get(url)
.send()
.await
.map_err(Error::http_discovery)?;
let body = response.text().await.map_err(Error::http_discovery)?;
let ip_str = body.trim();
ip_str
.parse::<IpAddr>()
.map_err(|_| Error::parse_ip(ip_str.to_string()))
}

View File

@@ -1,3 +1,5 @@
use std::path::PathBuf;
use thiserror::Error; use thiserror::Error;
use thiserror_ext::{Box, Construct}; use thiserror_ext::{Box, Construct};
use tokio_tungstenite::tungstenite; use tokio_tungstenite::tungstenite;
@@ -15,6 +17,75 @@ pub enum ErrorKind {
WsRead(#[source] tungstenite::Error), WsRead(#[source] tungstenite::Error),
#[error("error deserializing json")] #[error("error deserializing json")]
DeserializeJson(#[source] serde_json::Error), DeserializeJson(#[source] serde_json::Error),
#[error("error reading configuration at {path}")]
ReadConfig {
path: PathBuf,
#[source]
source: std::io::Error,
},
#[error("error deserializing toml")]
DeserializeToml(#[from] toml::de::Error),
#[error("invalid url")]
Url(#[from] url::ParseError),
#[error("invalid url scheme: {url}")]
UrlScheme { url: String },
#[error("STUN discovery failed: {message}")]
StunDiscovery { message: String },
#[error("HTTP IP discovery failed")]
HttpDiscovery(#[source] reqwest::Error),
#[error("failed to parse IP address from HTTP response: {body}")]
ParseIp { body: String },
#[error("all discovery methods failed")]
DiscoveryFailed,
#[error("error with request")]
Reqwest(#[from] reqwest::Error),
#[error("invalid base64 key: {context}")]
InvalidKey { context: String },
#[error("error creating WireGuard interface {interface}")]
CreateInterface {
interface: String,
#[source]
source: std::io::Error,
},
#[error("error configuring WireGuard device {interface}")]
ConfigureDevice {
interface: String,
#[source]
source: std::io::Error,
},
#[error("error with netlink operation: {context}")]
Netlink { context: String },
#[error("error setting interface address")]
SetAddress {
#[source]
source: rtnetlink::Error,
},
#[error("error bringing interface up")]
SetLinkUp {
#[source]
source: rtnetlink::Error,
},
#[error("error getting interface index for {interface}")]
GetInterface { interface: String },
#[error("registration failed")]
Registration {
#[source]
source: reqwest::Error,
},
#[error("error deregistering device {public_key}")]
Deregister {
public_key: String,
#[source]
source: reqwest::Error,
},
#[error("registration failed with status {status}: {body}")]
RegistrationStatus { status: u16, body: String },
#[error("error deleting interface")]
DeleteInterface {
name: String,
#[source]
source: rtnetlink::Error,
},
} }
pub type Result<T> = core::result::Result<T, Error>; pub type Result<T> = core::result::Result<T, Error>;

View File

@@ -1,42 +1,244 @@
mod config;
mod discovery;
mod error;
mod netlink;
mod wireguard;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use console::style; use console::style;
use futures::StreamExt; use futures::{StreamExt, TryFutureExt, stream::SplitStream};
use registry::PeerMessage; use ipnetwork::IpNetwork;
use registry::{Peer, PeerMessage, RegisterResponse};
use serde::Serialize;
use thiserror_ext::AsReport; use thiserror_ext::AsReport;
use tokio_tungstenite::tungstenite::Message; use tokio::{
net::TcpStream,
signal::{self, unix::SignalKind},
};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
use tracing::level_filters::LevelFilter; use tracing::level_filters::LevelFilter;
use tracing_subscriber::{ use tracing_subscriber::{
EnvFilter, fmt::format::FmtSpan, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, fmt::format::FmtSpan, layer::SubscriberExt, util::SubscriberInitExt,
}; };
use crate::error::Error; use crate::{
config::{Config, INTERFACE_NAME},
discovery::{PublicEndpoint, discover_public_endpoint},
error::{Error, Result},
};
mod error; #[derive(Serialize)]
pub struct RegisterRequest {
pub public_ip: IpAddr,
pub public_key: String,
pub port: String,
pub allowed_ips: Vec<IpNetwork>,
}
async fn run() -> crate::error::Result<()> { async fn register_self(
let url = get_api_from_env(); client: &reqwest::Client,
let (ws_stream, response) = tokio_tungstenite::connect_async(&url) endpoint: &PublicEndpoint,
config: &Config,
) -> Result<Ipv4Addr> {
let url = format!("{}/register", &config.server.url);
tracing::info!(
public_ip = %endpoint.ip,
port = endpoint.port,
"registering with registry"
);
let response = client
.post(&url)
.json(&RegisterRequest {
public_key: config.interface.public_key.clone(),
public_ip: endpoint.ip,
port: endpoint.port.to_string(),
allowed_ips: config.interface.allowed_ips.clone().unwrap_or(vec![]),
})
.send()
.await .await
.map_err(|e| Error::ws_connect(e, &url))?; .map_err(Error::registration)?;
let (_, mut read) = ws_stream.split();
tracing::info!("connected, response: {:?}", response.status());
if !response.status().is_success() {
let status = response.status().as_u16();
let body = response.text().await.unwrap_or_default();
return Err(Error::registration_status(status, body));
}
let register_response: RegisterResponse = response.json().await.map_err(Error::registration)?;
tracing::info!(
mesh_ip = %register_response.mesh_ip,
"registration successful, assigned mesh IP"
);
Ok(register_response.mesh_ip)
}
async fn deregister_self(client: &reqwest::Client, config: &Config) -> Result<()> {
let url = format!(
"{}/deregister?public_key={}",
config.server.url, config.interface.public_key
);
client
.delete(&url)
.send()
.await?
.error_for_status()
.map_err(|e| Error::deregister(e, &config.interface.public_key))?;
Ok(())
}
fn configure_peer(peer: &Peer, local_public_key: &str, keepalive: u16) -> Result<()> {
if peer.public_key == local_public_key {
tracing::debug!(peer_key = %peer.public_key, "skipping self");
return Ok(());
}
let peer_key = wireguard::parse_key(&peer.public_key)?;
let endpoint: SocketAddr = format!("{}:{}", peer.public_ip, peer.port)
.parse()
.map_err(|_| {
Error::netlink(format!(
"invalid endpoint: {}:{}",
peer.public_ip, peer.port
))
})?;
let mut allowed_ips: Vec<IpNetwork> =
vec![IpNetwork::new(std::net::IpAddr::V4(peer.mesh_ip), 32).expect("valid network")];
allowed_ips.extend(peer.allowed_ips.iter().cloned());
wireguard::configure_peer(&peer_key, Some(endpoint), &allowed_ips, Some(keepalive))?;
tracing::info!(
peer_key = %peer.public_key,
mesh_ip = %peer.mesh_ip,
endpoint = %endpoint,
"configured peer"
);
Ok(())
}
fn configure_all_peers(peers: &[Peer], local_public_key: &str, keepalive: u16) -> Result<()> {
for peer in peers {
if let Err(e) = configure_peer(peer, local_public_key, keepalive) {
tracing::warn!(peer_key = %peer.public_key, error = %e, "failed to configure peer");
}
}
Ok(())
}
async fn events(
read: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
config: &Config,
) -> Result<()> {
while let Some(msg) = read.next().await { while let Some(msg) = read.next().await {
match msg.map_err(|e| Error::ws_read(e))? { match msg.map_err(Error::ws_read)? {
Message::Text(text) => { Message::Text(text) => {
let server_msg: PeerMessage = let server_msg: PeerMessage =
serde_json::from_str(&text).map_err(|e| Error::deserialize_json(e))?; serde_json::from_str(&text).map_err(Error::deserialize_json)?;
match server_msg { match server_msg {
PeerMessage::HydratePeers { peers } => {} PeerMessage::HydratePeers { peers } => {
PeerMessage::PeerUpdate { peer } => {} tracing::info!(count = peers.len(), "Received initial peer list");
configure_all_peers(
&peers,
&config.interface.public_key,
config.interface.persistent_keepalive,
)?;
} }
PeerMessage::PeerUpdate { peer } => {
tracing::info!(
peer_key = %peer.public_key,
mesh_ip = %peer.mesh_ip,
"Received peer update"
);
configure_peer(
&peer,
&config.interface.public_key,
config.interface.persistent_keepalive,
)?;
}
}
}
Message::Ping(_) => {}
Message::Pong(_) => {}
Message::Close(_) => {
tracing::warn!("connection closed by server");
break;
} }
_ => {} _ => {}
} }
} }
Ok(()) Ok(())
} }
async fn run() -> Result<()> {
let config = Config::load()?;
let private_key = wireguard::parse_key(&config.interface.private_key)?;
let endpoint = discover_public_endpoint(config.interface.listen_port).await?;
tracing::info!(
public_ip = %endpoint.ip,
public_port = endpoint.port,
"public endpoint"
);
let http_client = reqwest::Client::new();
let mesh_ip = register_self(&http_client, &endpoint, &config).await?;
wireguard::create_interface()?;
wireguard::configure_interface(&private_key, config.interface.listen_port)?;
let (nl_handle, _nl_task) = netlink::connect().await?;
netlink::add_address(&nl_handle, mesh_ip, 32).await?;
netlink::set_link_up(&nl_handle).await?;
tracing::info!(
interface = INTERFACE_NAME,
mesh_ip = %mesh_ip,
"interface configured and up"
);
let ws_url = &config.server.ws_url;
let (ws_stream, response) = tokio_tungstenite::connect_async(ws_url)
.await
.map_err(|e| Error::ws_connect(e, ws_url))?;
tracing::info!(
status = ?response.status(),
"connected to registry WebSocket"
);
let (_, mut read) = ws_stream.split();
tokio::select! {
receiver = events(&mut read, &config) => receiver?,
_ = signal::ctrl_c() => {
},
_ = on_shutdown() => {}
};
tracing::debug!("gracefully shutting down");
netlink::delete_interface(&nl_handle, INTERFACE_NAME).await?;
deregister_self(&http_client, &config).await?;
tracing::info!("connection closed");
Ok(())
}
async fn on_shutdown() {
use tokio::signal::unix::signal;
let mut sigterm = signal(SignalKind::terminate()).expect("failed to register SIGTERM");
let mut sigint = signal(SignalKind::interrupt()).expect("failed to register SIGINT");
tokio::select! {
_ = sigterm.recv() => tracing::info!("received sigterm"),
_ = sigint.recv() => tracing::info!("received sigint")
}
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
let tracing_env_filter = EnvFilter::builder() let tracing_env_filter = EnvFilter::builder()
@@ -57,7 +259,3 @@ async fn main() {
std::process::exit(1) std::process::exit(1)
} }
} }
fn get_api_from_env() -> String {
return "ws://localhost:8080/ws/peers".to_string();
}

84
client/src/netlink.rs Normal file
View File

@@ -0,0 +1,84 @@
use std::net::Ipv4Addr;
use futures::TryStreamExt;
use rtnetlink::Handle;
use crate::config::INTERFACE_NAME;
use crate::error::{Error, Result};
async fn get_interface_index(handle: &Handle, name: &str) -> Result<u32> {
let mut links = handle.link().get().match_name(name.to_string()).execute();
if let Some(link) = links.try_next().await.map_err(Error::set_link_up)? {
Ok(link.header.index)
} else {
Err(Error::get_interface(name.to_string()))
}
}
pub async fn add_address(handle: &Handle, addr: Ipv4Addr, prefix_len: u8) -> Result<()> {
let index = get_interface_index(handle, INTERFACE_NAME).await?;
tracing::debug!(
interface = INTERFACE_NAME,
address = %addr,
prefix_len,
"adding address to interface"
);
match handle
.address()
.add(index, std::net::IpAddr::V4(addr), prefix_len)
.execute()
.await
{
Ok(()) => Ok(()),
Err(rtnetlink::Error::NetlinkError(e)) if e.raw_code() == -libc::EEXIST => {
tracing::debug!(
interface = INTERFACE_NAME,
address = %addr,
"address already exists on interface"
);
Ok(())
}
Err(e) => Err(Error::set_address(e)),
}
}
pub async fn set_link_up(handle: &Handle) -> Result<()> {
let index = get_interface_index(handle, INTERFACE_NAME).await?;
tracing::info!(interface = INTERFACE_NAME, "set interface up");
handle
.link()
.set(index)
.up()
.execute()
.await
.map_err(Error::set_link_up)?;
Ok(())
}
pub async fn delete_interface(handle: &Handle, name: &str) -> Result<()> {
let index = get_interface_index(&handle, name).await?;
handle
.link()
.del(index)
.execute()
.await
.map_err(|e| Error::delete_interface(e, name))?;
Ok(())
}
pub async fn connect() -> Result<(Handle, tokio::task::JoinHandle<()>)> {
let (connection, handle, _) = rtnetlink::new_connection()
.map_err(|e| Error::netlink(format!("failed to create netlink connection: {}", e)))?;
let join_handle = tokio::spawn(connection);
Ok((handle, join_handle))
}

117
client/src/wireguard.rs Normal file
View File

@@ -0,0 +1,117 @@
use std::net::SocketAddr;
use ipnetwork::IpNetwork;
use wireguard_control::{
AllowedIp, Backend, Device, DeviceUpdate, InterfaceName, InvalidKey, Key, PeerConfigBuilder,
};
use crate::config::INTERFACE_NAME;
use crate::error::{Error, Result};
pub fn parse_key(key_b64: &str) -> Result<Key> {
Key::from_base64(key_b64).map_err(|e: InvalidKey| Error::invalid_key(e.to_string()))
}
pub fn interface_name() -> InterfaceName {
INTERFACE_NAME.parse().expect("valid interface name")
}
pub fn create_interface() -> Result<()> {
let name = interface_name();
if Device::get(&name, Backend::Kernel).is_ok() {
tracing::debug!(interface = INTERFACE_NAME, "interface already exists");
return Ok(());
}
tracing::info!(interface = INTERFACE_NAME, "creating WireGuard interface");
DeviceUpdate::new()
.apply(&name, Backend::Kernel)
.map_err(|e| Error::create_interface(e, INTERFACE_NAME.to_string()))?;
Ok(())
}
pub fn configure_interface(private_key: &Key, listen_port: u16) -> Result<()> {
let name = interface_name();
tracing::debug!(
interface = INTERFACE_NAME,
listen_port,
"configuring WireGuard interface"
);
DeviceUpdate::new()
.set_private_key(private_key.clone())
.set_listen_port(listen_port)
.apply(&name, Backend::Kernel)
.map_err(|e| Error::configure_device(e, INTERFACE_NAME.to_string()))?;
Ok(())
}
pub fn configure_peer(
public_key: &Key,
endpoint: Option<SocketAddr>,
allowed_ips: &[IpNetwork],
persistent_keepalive: Option<u16>,
) -> Result<()> {
let name = interface_name();
let allowed: Vec<AllowedIp> = allowed_ips
.iter()
.map(|ip| AllowedIp {
address: ip.ip(),
cidr: ip.prefix(),
})
.collect();
let mut peer = PeerConfigBuilder::new(public_key)
.replace_allowed_ips()
.add_allowed_ips(&allowed);
if let Some(ep) = endpoint {
peer = peer.set_endpoint(ep);
}
if let Some(keepalive) = persistent_keepalive {
peer = peer.set_persistent_keepalive_interval(keepalive);
}
tracing::debug!(
peer_key = %public_key.to_base64(),
endpoint = ?endpoint,
allowed_ips = ?allowed_ips,
"configuring peer"
);
DeviceUpdate::new()
.add_peer(peer)
.apply(&name, Backend::Kernel)
.map_err(|e| Error::configure_device(e, INTERFACE_NAME.to_string()))?;
Ok(())
}
#[allow(dead_code)]
pub fn remove_peer(public_key: &Key) -> Result<()> {
let name = interface_name();
tracing::info!(
peer_key = %public_key.to_base64(),
"removing peer"
);
DeviceUpdate::new()
.remove_peer_by_key(public_key)
.apply(&name, Backend::Kernel)
.map_err(|e| Error::configure_device(e, INTERFACE_NAME.to_string()))?;
Ok(())
}
#[allow(dead_code)]
pub fn get_device() -> Result<Device> {
let name = interface_name();
Device::get(&name, Backend::Kernel)
.map_err(|e| Error::configure_device(e, INTERFACE_NAME.to_string()))
}

View File

@@ -1,2 +1,7 @@
dev bin="registry": dev bin="registry":
bacon dev -- --bin {{bin}} bacon dev -- --bin {{bin}}
client target="debug":
cargo build -p client
sudo setcap cap_net_admin+ep ./target/{{target}}/client
./target/{{target}}/client

6
payload.json Normal file
View File

@@ -0,0 +1,6 @@
{
"public_key": "pKkl30tba29FG86wuaC0KrpSHMr1tSOujikHFbx75BM=",
"public_ip": "75.157.238.86",
"port": "51820",
"allowed_ips": ["10.100.0.1/32"]
}

View File

@@ -21,6 +21,7 @@ tracing-actix-web = "0.7.21"
tokio = { version = "1.49.0", features = ["macros", "rt-multi-thread", "sync"] } tokio = { version = "1.49.0", features = ["macros", "rt-multi-thread", "sync"] }
thiserror = "2.0.18" thiserror = "2.0.18"
thiserror-ext = "0.3.0" thiserror-ext = "0.3.0"
wireguard-control = "1.7.1"
[lib] [lib]
name = "registry" name = "registry"

View File

@@ -0,0 +1,27 @@
use actix_web::{HttpResponse, web};
use serde::Deserialize;
use wireguard_control::Key;
use crate::{
AppState,
error::{Error, Result},
storage::StorageImpl,
};
#[derive(Deserialize)]
pub struct DeregisterRequest {
public_key: String,
}
pub async fn deregister(
app_state: web::Data<AppState>,
query: web::Query<DeregisterRequest>,
) -> Result<HttpResponse> {
Key::from_base64(&query.public_key).map_err(|_| Error::invalid_key(&query.public_key))?;
app_state
.storage
.deregister_device(&query.public_key)
.await?;
Ok(HttpResponse::Ok().finish())
}

View File

@@ -1,3 +1,4 @@
pub mod deregister;
pub mod peers; pub mod peers;
pub mod register; pub mod register;
pub mod ws; pub mod ws;

View File

@@ -4,13 +4,14 @@ use crate::{
storage::{RegisterRequest, StorageImpl}, storage::{RegisterRequest, StorageImpl},
}; };
use actix_web::{HttpResponse, web}; use actix_web::{HttpResponse, web};
use registry::Peer; use registry::{Peer, RegisterResponse};
pub async fn register_peer( pub async fn register_peer(
app_state: web::Data<AppState>, app_state: web::Data<AppState>,
request: web::Json<RegisterRequest>, request: web::Json<RegisterRequest>,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
app_state.storage.register_device(&request).await?; let mesh_ip = app_state.storage.register_device(&request).await?;
app_state app_state
.peer_updates .peer_updates
.send(PeerUpdate { .send(PeerUpdate {
@@ -18,10 +19,11 @@ pub async fn register_peer(
public_key: request.public_key.as_str().to_string(), public_key: request.public_key.as_str().to_string(),
public_ip: request.public_ip.to_string(), public_ip: request.public_ip.to_string(),
port: request.port.clone(), port: request.port.clone(),
mesh_ip,
allowed_ips: request.allowed_ips.clone(), allowed_ips: request.allowed_ips.clone(),
}, },
}) })
.unwrap(); .ok();
Ok(HttpResponse::Ok().finish()) Ok(HttpResponse::Ok().json(RegisterResponse { mesh_ip }))
} }

View File

@@ -67,7 +67,7 @@ pub async fn peers(
update = peer_rx.recv() => { update = peer_rx.recv() => {
match update { match update {
Ok(peer_update) => { Ok(peer_update) => {
let json = serde_json::to_string(&peer_update.peer) let json = serde_json::to_string(&PeerMessage::PeerUpdate { peer: peer_update.peer.clone() })
.unwrap_or_else(|_| "{}".to_string()); .unwrap_or_else(|_| "{}".to_string());
if session.text(json).await.is_err() { if session.text(json).await.is_err() {
break; break;

View File

@@ -1,4 +1,5 @@
use actix_web::{HttpResponse, ResponseError}; use actix_web::{HttpResponse, ResponseError};
use serde::Serialize;
use thiserror::Error; use thiserror::Error;
use thiserror_ext::{Box, Construct}; use thiserror_ext::{Box, Construct};
@@ -32,12 +33,30 @@ pub enum ErrorKind {
}, },
#[error("error handling ws")] #[error("error handling ws")]
Ws(#[source] actix_web::Error), Ws(#[source] actix_web::Error),
#[error("IP pool exhausted: no available addresses in {pool}")]
IpPoolExhausted { pool: String },
#[error("error deregistering device {public_key}")]
DeregisterDevice {
public_key: String,
#[source]
source: redis::RedisError,
},
#[error("error invalid key")]
InvalidKey(String),
}
#[derive(Serialize)]
struct ErrorResponse {
error: String,
} }
impl ResponseError for Error { impl ResponseError for Error {
fn error_response(&self) -> actix_web::HttpResponse<actix_web::body::BoxBody> { fn error_response(&self) -> actix_web::HttpResponse<actix_web::body::BoxBody> {
match self.inner() { match self.inner() {
ErrorKind::Ws(e) => e.error_response(), ErrorKind::Ws(e) => e.error_response(),
ErrorKind::InvalidKey(key) => HttpResponse::BadRequest().json(ErrorResponse {
error: format!("error invalid key: {key}"),
}),
_ => HttpResponse::InternalServerError().finish(), _ => HttpResponse::InternalServerError().finish(),
} }
} }

View File

@@ -1,4 +1,4 @@
mod types; mod types;
mod utils; mod utils;
pub use types::peer_message::*; pub use types::peer_message::{Peer, PeerMessage, RegisterResponse};

View File

@@ -47,6 +47,10 @@ async fn run() -> crate::error::Result<()> {
"/register", "/register",
web::post().to(endpoints::register::register_peer), web::post().to(endpoints::register::register_peer),
) )
.route(
"/deregister",
web::delete().to(endpoints::deregister::deregister),
)
.route("/peers", web::get().to(endpoints::peers::get_peers)) .route("/peers", web::get().to(endpoints::peers::get_peers))
.route("/ws/peers", web::get().to(endpoints::ws::peers::peers)) .route("/ws/peers", web::get().to(endpoints::ws::peers::peers))
}) })

View File

@@ -1,3 +1,5 @@
use std::net::Ipv4Addr;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
mod valkey; mod valkey;
@@ -10,12 +12,13 @@ pub enum Storage {
} }
pub trait StorageImpl { pub trait StorageImpl {
async fn register_device(&self, request: &RegisterRequest) -> Result<()>; async fn register_device(&self, request: &RegisterRequest) -> Result<Ipv4Addr>;
async fn deregister_device(&self, public_key: &str) -> Result<()>;
async fn get_peers(&self) -> Result<Vec<Peer>>; async fn get_peers(&self) -> Result<Vec<Peer>>;
} }
impl StorageImpl for Storage { impl StorageImpl for Storage {
async fn register_device(&self, request: &RegisterRequest) -> Result<()> { async fn register_device(&self, request: &RegisterRequest) -> Result<Ipv4Addr> {
match self { match self {
Self::Valkey(storage) => storage.register_device(request).await, Self::Valkey(storage) => storage.register_device(request).await,
} }
@@ -26,6 +29,12 @@ impl StorageImpl for Storage {
Self::Valkey(storage) => storage.get_peers().await, Self::Valkey(storage) => storage.get_peers().await,
} }
} }
async fn deregister_device(&self, public_key: &str) -> Result<()> {
match self {
Self::Valkey(storage) => storage.deregister_device(public_key).await,
}
}
} }
pub fn get_storage_from_env() -> Result<Storage> { pub fn get_storage_from_env() -> Result<Storage> {

View File

@@ -1,6 +1,7 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::net::IpAddr; use std::net::{IpAddr, Ipv4Addr};
use futures::TryFutureExt;
use ipnetwork::IpNetwork; use ipnetwork::IpNetwork;
use redis::AsyncTypedCommands; use redis::AsyncTypedCommands;
use registry::Peer; use registry::Peer;
@@ -10,6 +11,10 @@ use crate::error::Result;
use crate::utils::WireguardPublicKey; use crate::utils::WireguardPublicKey;
use crate::{error::Error, storage::StorageImpl}; use crate::{error::Error, storage::StorageImpl};
const MESH_NETWORK_BASE: [u8; 4] = [10, 100, 0, 0];
const MESH_POOL_START: u8 = 1;
const MESH_POOL_END: u8 = 254;
pub struct ValkeyStorage { pub struct ValkeyStorage {
pub valkey_client: redis::Client, pub valkey_client: redis::Client,
} }
@@ -23,22 +28,38 @@ pub struct RegisterRequest {
} }
impl StorageImpl for ValkeyStorage { impl StorageImpl for ValkeyStorage {
async fn register_device(&self, request: &RegisterRequest) -> Result<()> { async fn register_device(&self, request: &RegisterRequest) -> Result<Ipv4Addr> {
let mut conn = self let mut conn = self
.valkey_client .valkey_client
.get_multiplexed_async_connection() .get_multiplexed_async_connection()
.await .await
.map_err(|e| Error::valkey_get_connection(e))?; .map_err(|e| Error::valkey_get_connection(e))?;
let peer_key = format!("peer:{}", request.public_key.as_str());
let existing_mesh_ip: Option<String> = conn
.hget(&peer_key, "mesh_ip")
.await
.map_err(|e| Error::get_peer(e))?;
let mesh_ip = if let Some(ip_str) = existing_mesh_ip {
ip_str.parse::<Ipv4Addr>().unwrap()
} else {
let allocated_ip = self.allocate_mesh_ip(&mut conn).await?;
allocated_ip
};
conn.hset_multiple::<_, _, _>( conn.hset_multiple::<_, _, _>(
format!("peer:{}", request.public_key.as_str()), &peer_key,
&[ &[
("public_ip", &request.public_ip.to_string()), ("public_ip", request.public_ip.to_string()),
( (
"allowed_ips", "allowed_ips",
&serde_json::to_string(&request.allowed_ips) serde_json::to_string(&request.allowed_ips)
.map_err(|e| Error::serialize_json(e, "serializing allowed_ips"))?, .map_err(|e| Error::serialize_json(e, "serializing allowed_ips"))?,
), ),
("port", &request.port), ("port", request.port.clone()),
("mesh_ip", mesh_ip.to_string()),
], ],
) )
.await .await
@@ -49,7 +70,8 @@ impl StorageImpl for ValkeyStorage {
request.public_ip.to_string(), request.public_ip.to_string(),
) )
})?; })?;
conn.sadd("peers", &request.public_key.as_str())
conn.sadd("peers", request.public_key.as_str())
.await .await
.map_err(|e| { .map_err(|e| {
Error::add_peer( Error::add_peer(
@@ -59,6 +81,25 @@ impl StorageImpl for ValkeyStorage {
) )
})?; })?;
Ok(mesh_ip)
}
async fn deregister_device(&self, public_key: &str) -> Result<()> {
let mut conn = self
.valkey_client
.get_multiplexed_async_connection()
.await
.map_err(|e| Error::valkey_get_connection(e))?;
let hash_key = format!("peer:{public_key}");
conn.srem("peers", public_key)
.map_err(|e| Error::deregister_device(e, public_key))
.await?;
let response = conn
.del(hash_key)
.await
.map_err(|e| Error::deregister_device(e, public_key))?;
tracing::debug!("deleted hash {keys} key(s) removed", keys = response);
Ok(()) Ok(())
} }
@@ -89,21 +130,72 @@ impl StorageImpl for ValkeyStorage {
.map_err(|e| Error::get_peer(e))? .map_err(|e| Error::get_peer(e))?
.into_iter() .into_iter()
.zip(keys.iter()) .zip(keys.iter())
.map(|(peer, key): (HashMap<String, String>, &String)| { .filter_map(|(peer, key): (HashMap<String, String>, &String)| {
let allowed_ips: Vec<IpNetwork> = peer let allowed_ips: Vec<IpNetwork> = peer
.get("allowed_ips") .get("allowed_ips")
.map(|s| serde_json::from_str(s).unwrap_or_default()) .map(|s| serde_json::from_str(s).unwrap_or_default())
.unwrap_or_default(); .unwrap_or_default();
Peer { let mesh_ip: Ipv4Addr = peer.get("mesh_ip")?.parse().ok()?;
Some(Peer {
public_key: key.clone(), public_key: key.clone(),
public_ip: peer.get("public_ip").unwrap().to_string(), public_ip: peer.get("public_ip")?.to_string(),
port: peer.get("port").unwrap().to_string(), port: peer.get("port")?.to_string(),
mesh_ip,
allowed_ips, allowed_ips,
} })
}) })
.collect(); .collect();
Ok(peers) Ok(peers)
} }
} }
impl ValkeyStorage {
async fn allocate_mesh_ip(
&self,
conn: &mut redis::aio::MultiplexedConnection,
) -> Result<Ipv4Addr> {
let keys: HashSet<String> = conn
.smembers("peers")
.await
.map_err(|e| Error::get_peer(e))?;
let mut assigned_ips: HashSet<Ipv4Addr> = HashSet::new();
if !keys.is_empty() {
let mut pipe = redis::pipe();
for key in keys.iter() {
pipe.hget(format!("peer:{key}"), "mesh_ip");
}
let ips: Vec<Option<String>> = pipe
.query_async(conn)
.await
.map_err(|e| Error::get_peer(e))?;
for ip_opt in ips {
if let Some(ip_str) = ip_opt {
if let Ok(ip) = ip_str.parse::<Ipv4Addr>() {
assigned_ips.insert(ip);
}
}
}
}
for last_octet in MESH_POOL_START..=MESH_POOL_END {
let candidate = Ipv4Addr::new(
MESH_NETWORK_BASE[0],
MESH_NETWORK_BASE[1],
MESH_NETWORK_BASE[2],
last_octet,
);
if !assigned_ips.contains(&candidate) {
return Ok(candidate);
}
}
Err(Error::ip_pool_exhausted("10.100.0.0/24".to_string()))
}
}

View File

@@ -1,3 +1,5 @@
use std::net::Ipv4Addr;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use ipnetwork::IpNetwork; use ipnetwork::IpNetwork;
@@ -7,9 +9,15 @@ pub struct Peer {
pub public_key: String, pub public_key: String,
pub public_ip: String, pub public_ip: String,
pub port: String, pub port: String,
pub mesh_ip: Ipv4Addr,
pub allowed_ips: Vec<IpNetwork>, pub allowed_ips: Vec<IpNetwork>,
} }
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct RegisterResponse {
pub mesh_ip: Ipv4Addr,
}
#[derive(Serialize, Deserialize)] #[derive(Serialize, Deserialize)]
#[serde(tag = "type")] #[serde(tag = "type")]
pub enum PeerMessage { pub enum PeerMessage {

View File

@@ -1,5 +1,5 @@
use base64::Engine; use base64::Engine;
use serde::{Deserialize, de}; use serde::{de, Deserialize};
#[derive(Clone)] #[derive(Clone)]
pub struct WireguardPublicKey(String); pub struct WireguardPublicKey(String);