refactor!: use rtnetlink for client interfaces, add deregistration

This commit is contained in:
2026-02-17 22:59:50 -08:00
parent a022c18ff9
commit 03f38b9ee3
20 changed files with 975 additions and 167 deletions

View File

@@ -11,7 +11,7 @@ serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.149"
thiserror = "2.0.18"
thiserror-ext = "0.3.0"
tokio = { version = "1.49.0", features = ["macros", "rt-multi-thread"] }
tokio = { version = "1.49.0", features = ["macros", "net", "rt-multi-thread", "signal"] }
tokio-tungstenite = { version = "0.28.0", features = ["rustls-tls-native-roots"] }
tracing = "0.1.44"
tracing-subscriber = { version = "0.3.22", features = ["env-filter"] }
@@ -22,3 +22,12 @@ url = "2.5.8"
reqwest = { version = "0.13.2", features = ["json"] }
stunclient = "0.4.2"
ipnetwork = { version = "0.21.1", features = ["serde"] }
base64 = "0.22.1"
libc = "0.2"
# WireGuard netlink control
wireguard-control = "1.1"
# Network interface management
rtnetlink = "0.14"
netlink-packet-route = "0.19"

View File

@@ -1,22 +0,0 @@
use std::{ops::Deref, sync::Arc};
#[derive(Clone)]
pub struct AppState {
pub reqwest_client: reqwest::Client,
}
#[derive(Clone)]
pub struct Data<T>(Arc<T>);
impl<T> Deref for Data<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> Data<T> {
pub fn new(inner: T) -> Self {
Self(Arc::new(inner))
}
}

View File

@@ -4,19 +4,33 @@ use crate::error::{Error, Result};
use ipnetwork::IpNetwork;
use serde::Deserialize;
pub const DEFAULT_KEEPALIVE: u16 = 25;
pub const INTERFACE_NAME: &str = "mesh0";
#[derive(Deserialize)]
pub struct Config {
pub interface: InterfaceConfig,
pub server: ServerConfig,
}
#[derive(Deserialize)]
pub struct InterfaceConfig {
pub private_key: String,
pub public_key: String,
#[serde(default = "default_listen_port")]
pub listen_port: u16,
pub address: String,
pub allowed_ips: Vec<IpNetwork>,
#[serde(default = "default_keepalive")]
pub persistent_keepalive: u16,
pub allowed_ips: Option<Vec<IpNetwork>>,
}
fn default_listen_port() -> u16 {
51820
}
fn default_keepalive() -> u16 {
DEFAULT_KEEPALIVE
}
#[derive(Deserialize)]
pub struct ServerConfig {
pub ws_url: String,
@@ -36,7 +50,3 @@ pub fn base_path() -> PathBuf {
.unwrap_or_else(|| PathBuf::from("."))
.join("wg-mesh")
}
pub fn wg_config_path() -> PathBuf {
PathBuf::from("/etc/wireguard/mesh0.conf")
}

View File

@@ -29,12 +29,6 @@ pub enum ErrorKind {
Url(#[from] url::ParseError),
#[error("invalid url scheme: {url}")]
UrlScheme { url: String },
#[error("error writing wireguard config to {path}")]
WriteConfig {
path: PathBuf,
#[source]
source: std::io::Error,
},
#[error("STUN discovery failed: {message}")]
StunDiscovery { message: String },
#[error("HTTP IP discovery failed")]
@@ -45,6 +39,53 @@ pub enum ErrorKind {
DiscoveryFailed,
#[error("error with request")]
Reqwest(#[from] reqwest::Error),
#[error("invalid base64 key: {context}")]
InvalidKey { context: String },
#[error("error creating WireGuard interface {interface}")]
CreateInterface {
interface: String,
#[source]
source: std::io::Error,
},
#[error("error configuring WireGuard device {interface}")]
ConfigureDevice {
interface: String,
#[source]
source: std::io::Error,
},
#[error("error with netlink operation: {context}")]
Netlink { context: String },
#[error("error setting interface address")]
SetAddress {
#[source]
source: rtnetlink::Error,
},
#[error("error bringing interface up")]
SetLinkUp {
#[source]
source: rtnetlink::Error,
},
#[error("error getting interface index for {interface}")]
GetInterface { interface: String },
#[error("registration failed")]
Registration {
#[source]
source: reqwest::Error,
},
#[error("error deregistering device {public_key}")]
Deregister {
public_key: String,
#[source]
source: reqwest::Error,
},
#[error("registration failed with status {status}: {body}")]
RegistrationStatus { status: u16, body: String },
#[error("error deleting interface")]
DeleteInterface {
name: String,
#[source]
source: rtnetlink::Error,
},
}
pub type Result<T> = core::result::Result<T, Error>;

View File

@@ -1,27 +1,29 @@
mod app_state;
mod config;
mod discovery;
mod error;
mod netlink;
mod wireguard;
use std::{fs::OpenOptions, io::Write, net::IpAddr, os::unix::fs::OpenOptionsExt};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use console::style;
use futures::StreamExt;
use futures::{StreamExt, TryFutureExt, stream::SplitStream};
use ipnetwork::IpNetwork;
use registry::{Peer, PeerMessage};
use registry::{Peer, PeerMessage, RegisterResponse};
use serde::Serialize;
use thiserror_ext::AsReport;
use tokio_tungstenite::tungstenite::Message;
use tokio::{
net::TcpStream,
signal::{self, unix::SignalKind},
};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
use tracing::level_filters::LevelFilter;
use tracing_subscriber::{
EnvFilter, fmt::format::FmtSpan, layer::SubscriberExt, util::SubscriberInitExt,
};
use url::Url;
use crate::{
app_state::{AppState, Data},
config::{Config, InterfaceConfig, wg_config_path},
config::{Config, INTERFACE_NAME},
discovery::{PublicEndpoint, discover_public_endpoint},
error::{Error, Result},
};
@@ -34,105 +36,209 @@ pub struct RegisterRequest {
pub allowed_ips: Vec<IpNetwork>,
}
fn parse_ws_url(input: &Url) -> Result<String> {
let url = input.join("/ws/peers")?;
if url.scheme() != "ws" && url.scheme() != "wss" {
return Err(Error::url_scheme(url.to_string()));
}
Ok(url.to_string())
}
fn write_wg_config(interface: &InterfaceConfig, peers: &[Peer]) -> Result<()> {
let path = wg_config_path();
let config = wireguard::generate_config(interface, peers);
let mut file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.mode(0o600)
.open(&path)
.map_err(|e| Error::write_config(e, &path))?;
file.write_all(config.as_bytes())
.map_err(|e| Error::write_config(e, &path))?;
tracing::info!("wrote {} with {} peers", path.display(), peers.len());
Ok(())
}
async fn register_self(
app_state: Data<AppState>,
client: &reqwest::Client,
endpoint: &PublicEndpoint,
config: &InterfaceConfig,
url: &str,
) -> Result<()> {
app_state
.reqwest_client
.post(url)
config: &Config,
) -> Result<Ipv4Addr> {
let url = format!("{}/register", &config.server.url);
tracing::info!(
public_ip = %endpoint.ip,
port = endpoint.port,
"registering with registry"
);
let response = client
.post(&url)
.json(&RegisterRequest {
public_key: config.public_key.clone(),
public_key: config.interface.public_key.clone(),
public_ip: endpoint.ip,
port: endpoint.port.to_string(),
allowed_ips: config.allowed_ips.clone(),
allowed_ips: config.interface.allowed_ips.clone().unwrap_or(vec![]),
})
.send()
.await
.map_err(Error::registration)?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body = response.text().await.unwrap_or_default();
return Err(Error::registration_status(status, body));
}
let register_response: RegisterResponse = response.json().await.map_err(Error::registration)?;
tracing::info!(
mesh_ip = %register_response.mesh_ip,
"registration successful, assigned mesh IP"
);
Ok(register_response.mesh_ip)
}
async fn deregister_self(client: &reqwest::Client, config: &Config) -> Result<()> {
let url = format!(
"{}/deregister?public_key={}",
config.server.url, config.interface.public_key
);
client
.delete(&url)
.send()
.await?
.error_for_status()?;
.error_for_status()
.map_err(|e| Error::deregister(e, &config.interface.public_key))?;
Ok(())
}
async fn run() -> crate::error::Result<()> {
let config = Config::load()?;
let ws_url = &config.server.ws_url;
let (ws_stream, response) = tokio_tungstenite::connect_async(ws_url)
.await
.map_err(|e| Error::ws_connect(e, ws_url))?;
let (_, mut read) = ws_stream.split();
let app_state = Data::new(AppState {
reqwest_client: reqwest::Client::new(),
});
tracing::info!("connected, response: {:?}", response.status());
let endpoint = discover_public_endpoint(config.interface.listen_port).await?;
register_self(
app_state.clone(),
&endpoint,
&config.interface,
&format!("{}/register", &config.server.url),
)
.await?;
fn configure_peer(peer: &Peer, local_public_key: &str, keepalive: u16) -> Result<()> {
if peer.public_key == local_public_key {
tracing::debug!(peer_key = %peer.public_key, "skipping self");
return Ok(());
}
let mut peers: Vec<Peer> = Vec::new();
let peer_key = wireguard::parse_key(&peer.public_key)?;
let endpoint: SocketAddr = format!("{}:{}", peer.public_ip, peer.port)
.parse()
.map_err(|_| {
Error::netlink(format!(
"invalid endpoint: {}:{}",
peer.public_ip, peer.port
))
})?;
let mut allowed_ips: Vec<IpNetwork> =
vec![IpNetwork::new(std::net::IpAddr::V4(peer.mesh_ip), 32).expect("valid network")];
allowed_ips.extend(peer.allowed_ips.iter().cloned());
wireguard::configure_peer(&peer_key, Some(endpoint), &allowed_ips, Some(keepalive))?;
tracing::info!(
peer_key = %peer.public_key,
mesh_ip = %peer.mesh_ip,
endpoint = %endpoint,
"configured peer"
);
Ok(())
}
fn configure_all_peers(peers: &[Peer], local_public_key: &str, keepalive: u16) -> Result<()> {
for peer in peers {
if let Err(e) = configure_peer(peer, local_public_key, keepalive) {
tracing::warn!(peer_key = %peer.public_key, error = %e, "failed to configure peer");
}
}
Ok(())
}
async fn events(
read: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
config: &Config,
) -> Result<()> {
while let Some(msg) = read.next().await {
match msg.map_err(|e| Error::ws_read(e))? {
match msg.map_err(Error::ws_read)? {
Message::Text(text) => {
let server_msg: PeerMessage =
serde_json::from_str(&text).map_err(|e| Error::deserialize_json(e))?;
serde_json::from_str(&text).map_err(Error::deserialize_json)?;
match server_msg {
PeerMessage::HydratePeers { peers: new_peers } => {
tracing::info!("received {} peers", new_peers.len());
peers = new_peers;
write_wg_config(&config.interface, &peers)?;
PeerMessage::HydratePeers { peers } => {
tracing::info!(count = peers.len(), "Received initial peer list");
configure_all_peers(
&peers,
&config.interface.public_key,
config.interface.persistent_keepalive,
)?;
}
PeerMessage::PeerUpdate { peer } => {
tracing::info!("peer update: {}", peer.public_key);
if let Some(existing) =
peers.iter_mut().find(|p| p.public_key == peer.public_key)
{
*existing = peer;
} else {
peers.push(peer);
}
write_wg_config(&config.interface, &peers)?;
tracing::info!(
peer_key = %peer.public_key,
mesh_ip = %peer.mesh_ip,
"Received peer update"
);
configure_peer(
&peer,
&config.interface.public_key,
config.interface.persistent_keepalive,
)?;
}
}
}
Message::Ping(_) => {}
Message::Pong(_) => {}
Message::Close(_) => {
tracing::warn!("connection closed by server");
break;
}
_ => {}
}
}
Ok(())
}
async fn run() -> Result<()> {
let config = Config::load()?;
let private_key = wireguard::parse_key(&config.interface.private_key)?;
let endpoint = discover_public_endpoint(config.interface.listen_port).await?;
tracing::info!(
public_ip = %endpoint.ip,
public_port = endpoint.port,
"public endpoint"
);
let http_client = reqwest::Client::new();
let mesh_ip = register_self(&http_client, &endpoint, &config).await?;
wireguard::create_interface()?;
wireguard::configure_interface(&private_key, config.interface.listen_port)?;
let (nl_handle, _nl_task) = netlink::connect().await?;
netlink::add_address(&nl_handle, mesh_ip, 32).await?;
netlink::set_link_up(&nl_handle).await?;
tracing::info!(
interface = INTERFACE_NAME,
mesh_ip = %mesh_ip,
"interface configured and up"
);
let ws_url = &config.server.ws_url;
let (ws_stream, response) = tokio_tungstenite::connect_async(ws_url)
.await
.map_err(|e| Error::ws_connect(e, ws_url))?;
tracing::info!(
status = ?response.status(),
"connected to registry WebSocket"
);
let (_, mut read) = ws_stream.split();
tokio::select! {
receiver = events(&mut read, &config) => receiver?,
_ = signal::ctrl_c() => {
},
_ = on_shutdown() => {}
};
tracing::debug!("gracefully shutting down");
netlink::delete_interface(&nl_handle, INTERFACE_NAME).await?;
deregister_self(&http_client, &config).await?;
tracing::info!("connection closed");
Ok(())
}
async fn on_shutdown() {
use tokio::signal::unix::signal;
let mut sigterm = signal(SignalKind::terminate()).expect("failed to register SIGTERM");
let mut sigint = signal(SignalKind::interrupt()).expect("failed to register SIGINT");
tokio::select! {
_ = sigterm.recv() => tracing::info!("received sigterm"),
_ = sigint.recv() => tracing::info!("received sigint")
}
}
#[tokio::main]
async fn main() {
let tracing_env_filter = EnvFilter::builder()

84
client/src/netlink.rs Normal file
View File

@@ -0,0 +1,84 @@
use std::net::Ipv4Addr;
use futures::TryStreamExt;
use rtnetlink::Handle;
use crate::config::INTERFACE_NAME;
use crate::error::{Error, Result};
async fn get_interface_index(handle: &Handle, name: &str) -> Result<u32> {
let mut links = handle.link().get().match_name(name.to_string()).execute();
if let Some(link) = links.try_next().await.map_err(Error::set_link_up)? {
Ok(link.header.index)
} else {
Err(Error::get_interface(name.to_string()))
}
}
pub async fn add_address(handle: &Handle, addr: Ipv4Addr, prefix_len: u8) -> Result<()> {
let index = get_interface_index(handle, INTERFACE_NAME).await?;
tracing::debug!(
interface = INTERFACE_NAME,
address = %addr,
prefix_len,
"adding address to interface"
);
match handle
.address()
.add(index, std::net::IpAddr::V4(addr), prefix_len)
.execute()
.await
{
Ok(()) => Ok(()),
Err(rtnetlink::Error::NetlinkError(e)) if e.raw_code() == -libc::EEXIST => {
tracing::debug!(
interface = INTERFACE_NAME,
address = %addr,
"address already exists on interface"
);
Ok(())
}
Err(e) => Err(Error::set_address(e)),
}
}
pub async fn set_link_up(handle: &Handle) -> Result<()> {
let index = get_interface_index(handle, INTERFACE_NAME).await?;
tracing::info!(interface = INTERFACE_NAME, "set interface up");
handle
.link()
.set(index)
.up()
.execute()
.await
.map_err(Error::set_link_up)?;
Ok(())
}
pub async fn delete_interface(handle: &Handle, name: &str) -> Result<()> {
let index = get_interface_index(&handle, name).await?;
handle
.link()
.del(index)
.execute()
.await
.map_err(|e| Error::delete_interface(e, name))?;
Ok(())
}
pub async fn connect() -> Result<(Handle, tokio::task::JoinHandle<()>)> {
let (connection, handle, _) = rtnetlink::new_connection()
.map_err(|e| Error::netlink(format!("failed to create netlink connection: {}", e)))?;
let join_handle = tokio::spawn(connection);
Ok((handle, join_handle))
}

View File

@@ -1,24 +1,117 @@
use registry::Peer;
use std::net::SocketAddr;
use crate::config::InterfaceConfig;
use ipnetwork::IpNetwork;
use wireguard_control::{
AllowedIp, Backend, Device, DeviceUpdate, InterfaceName, InvalidKey, Key, PeerConfigBuilder,
};
pub fn generate_config(interface: &InterfaceConfig, peers: &[Peer]) -> String {
let mut config = format!(
"[Interface]\nPrivateKey = {}\nListenPort = {}\nAddress = {}\n",
interface.private_key, interface.listen_port, interface.address,
);
for peer in peers {
config.push_str(&format!(
"\n[Peer]\nPublicKey = {}\nEndpoint = {}:{}\nAllowedIPs = {}\n",
peer.public_key,
peer.public_ip,
peer.port,
peer.allowed_ips
.iter()
.map(|ip| ip.to_string())
.collect::<Vec<_>>()
.join(", "),
));
}
config
use crate::config::INTERFACE_NAME;
use crate::error::{Error, Result};
pub fn parse_key(key_b64: &str) -> Result<Key> {
Key::from_base64(key_b64).map_err(|e: InvalidKey| Error::invalid_key(e.to_string()))
}
pub fn interface_name() -> InterfaceName {
INTERFACE_NAME.parse().expect("valid interface name")
}
pub fn create_interface() -> Result<()> {
let name = interface_name();
if Device::get(&name, Backend::Kernel).is_ok() {
tracing::debug!(interface = INTERFACE_NAME, "interface already exists");
return Ok(());
}
tracing::info!(interface = INTERFACE_NAME, "creating WireGuard interface");
DeviceUpdate::new()
.apply(&name, Backend::Kernel)
.map_err(|e| Error::create_interface(e, INTERFACE_NAME.to_string()))?;
Ok(())
}
pub fn configure_interface(private_key: &Key, listen_port: u16) -> Result<()> {
let name = interface_name();
tracing::debug!(
interface = INTERFACE_NAME,
listen_port,
"configuring WireGuard interface"
);
DeviceUpdate::new()
.set_private_key(private_key.clone())
.set_listen_port(listen_port)
.apply(&name, Backend::Kernel)
.map_err(|e| Error::configure_device(e, INTERFACE_NAME.to_string()))?;
Ok(())
}
pub fn configure_peer(
public_key: &Key,
endpoint: Option<SocketAddr>,
allowed_ips: &[IpNetwork],
persistent_keepalive: Option<u16>,
) -> Result<()> {
let name = interface_name();
let allowed: Vec<AllowedIp> = allowed_ips
.iter()
.map(|ip| AllowedIp {
address: ip.ip(),
cidr: ip.prefix(),
})
.collect();
let mut peer = PeerConfigBuilder::new(public_key)
.replace_allowed_ips()
.add_allowed_ips(&allowed);
if let Some(ep) = endpoint {
peer = peer.set_endpoint(ep);
}
if let Some(keepalive) = persistent_keepalive {
peer = peer.set_persistent_keepalive_interval(keepalive);
}
tracing::debug!(
peer_key = %public_key.to_base64(),
endpoint = ?endpoint,
allowed_ips = ?allowed_ips,
"configuring peer"
);
DeviceUpdate::new()
.add_peer(peer)
.apply(&name, Backend::Kernel)
.map_err(|e| Error::configure_device(e, INTERFACE_NAME.to_string()))?;
Ok(())
}
#[allow(dead_code)]
pub fn remove_peer(public_key: &Key) -> Result<()> {
let name = interface_name();
tracing::info!(
peer_key = %public_key.to_base64(),
"removing peer"
);
DeviceUpdate::new()
.remove_peer_by_key(public_key)
.apply(&name, Backend::Kernel)
.map_err(|e| Error::configure_device(e, INTERFACE_NAME.to_string()))?;
Ok(())
}
#[allow(dead_code)]
pub fn get_device() -> Result<Device> {
let name = interface_name();
Device::get(&name, Backend::Kernel)
.map_err(|e| Error::configure_device(e, INTERFACE_NAME.to_string()))
}