feat!: auto add routes to exposed subnets

This commit is contained in:
2026-02-18 19:37:08 -08:00
parent 76d60080f7
commit 17a86a5054
4 changed files with 118 additions and 43 deletions

View File

@@ -27,8 +27,6 @@ pub enum ErrorKind {
DeserializeToml(#[from] toml::de::Error), DeserializeToml(#[from] toml::de::Error),
#[error("invalid url")] #[error("invalid url")]
Url(#[from] url::ParseError), Url(#[from] url::ParseError),
#[error("invalid url scheme: {url}")]
UrlScheme { url: String },
#[error("STUN discovery failed: {message}")] #[error("STUN discovery failed: {message}")]
StunDiscovery { message: String }, StunDiscovery { message: String },
#[error("HTTP IP discovery failed")] #[error("HTTP IP discovery failed")]
@@ -86,6 +84,12 @@ pub enum ErrorKind {
#[source] #[source]
source: rtnetlink::Error, source: rtnetlink::Error,
}, },
#[error("error adding route to {destination}")]
AddRoute {
destination: String,
#[source]
source: rtnetlink::Error,
},
} }
pub type Result<T> = core::result::Result<T, Error>; pub type Result<T> = core::result::Result<T, Error>;

View File

@@ -2,6 +2,7 @@ mod config;
mod discovery; mod discovery;
mod error; mod error;
mod netlink; mod netlink;
mod state;
mod wireguard; mod wireguard;
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::{IpAddr, Ipv4Addr, SocketAddr};
@@ -10,6 +11,7 @@ use console::style;
use futures::{StreamExt, TryFutureExt, stream::SplitStream}; use futures::{StreamExt, TryFutureExt, stream::SplitStream};
use ipnetwork::IpNetwork; use ipnetwork::IpNetwork;
use registry::{Peer, PeerMessage, RegisterResponse}; use registry::{Peer, PeerMessage, RegisterResponse};
use rtnetlink::Handle;
use serde::Serialize; use serde::Serialize;
use thiserror_ext::AsReport; use thiserror_ext::AsReport;
use tokio::{ use tokio::{
@@ -26,6 +28,7 @@ use crate::{
config::{Config, INTERFACE_NAME}, config::{Config, INTERFACE_NAME},
discovery::{PublicEndpoint, discover_public_endpoint}, discovery::{PublicEndpoint, discover_public_endpoint},
error::{Error, Result}, error::{Error, Result},
state::{AppState, Data},
}; };
#[derive(Serialize)] #[derive(Serialize)]
@@ -43,12 +46,6 @@ async fn register_self(
) -> Result<Ipv4Addr> { ) -> Result<Ipv4Addr> {
let url = format!("{}/register", &config.server.url); let url = format!("{}/register", &config.server.url);
tracing::info!(
public_ip = %endpoint.ip,
port = endpoint.port,
"registering with registry"
);
let response = client let response = client
.post(&url) .post(&url)
.json(&RegisterRequest { .json(&RegisterRequest {
@@ -77,6 +74,7 @@ async fn register_self(
Ok(register_response.mesh_ip) Ok(register_response.mesh_ip)
} }
#[allow(dead_code)]
async fn deregister_self(client: &reqwest::Client, config: &Config) -> Result<()> { async fn deregister_self(client: &reqwest::Client, config: &Config) -> Result<()> {
let url = format!( let url = format!(
"{}/deregister?public_key={}", "{}/deregister?public_key={}",
@@ -92,9 +90,13 @@ async fn deregister_self(client: &reqwest::Client, config: &Config) -> Result<()
Ok(()) Ok(())
} }
fn configure_peer(peer: &Peer, local_public_key: &str, keepalive: u16) -> Result<()> { async fn configure_peer(
handle: &Handle,
peer: &Peer,
local_public_key: &str,
keepalive: u16,
) -> Result<()> {
if peer.public_key == local_public_key { if peer.public_key == local_public_key {
tracing::debug!(peer_key = %peer.public_key, "skipping self");
return Ok(()); return Ok(());
} }
@@ -113,6 +115,7 @@ fn configure_peer(peer: &Peer, local_public_key: &str, keepalive: u16) -> Result
allowed_ips.extend(peer.allowed_ips.iter().cloned()); allowed_ips.extend(peer.allowed_ips.iter().cloned());
wireguard::configure_peer(&peer_key, Some(endpoint), &allowed_ips, Some(keepalive))?; wireguard::configure_peer(&peer_key, Some(endpoint), &allowed_ips, Some(keepalive))?;
netlink::add_routes(handle, &allowed_ips).await?;
tracing::info!( tracing::info!(
peer_key = %peer.public_key, peer_key = %peer.public_key,
@@ -124,9 +127,14 @@ fn configure_peer(peer: &Peer, local_public_key: &str, keepalive: u16) -> Result
Ok(()) Ok(())
} }
fn configure_all_peers(peers: &[Peer], local_public_key: &str, keepalive: u16) -> Result<()> { async fn configure_all_peers(
handle: &Handle,
peers: &[Peer],
local_public_key: &str,
keepalive: u16,
) -> Result<()> {
for peer in peers { for peer in peers {
if let Err(e) = configure_peer(peer, local_public_key, keepalive) { if let Err(e) = configure_peer(handle, peer, local_public_key, keepalive).await {
tracing::warn!(peer_key = %peer.public_key, error = %e, "failed to configure peer"); tracing::warn!(peer_key = %peer.public_key, error = %e, "failed to configure peer");
} }
} }
@@ -134,44 +142,51 @@ fn configure_all_peers(peers: &[Peer], local_public_key: &str, keepalive: u16) -
} }
async fn events( async fn events(
app_state: Data<AppState>,
read: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>, read: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
config: &Config, config: &Config,
) -> Result<()> { ) -> Result<()> {
while let Some(msg) = read.next().await { while let Some(msg) = read.next().await {
match msg.map_err(Error::ws_read)? { match msg {
Message::Text(text) => { Ok(Message::Text(text)) => {
let server_msg: PeerMessage = let server_msg: PeerMessage =
serde_json::from_str(&text).map_err(Error::deserialize_json)?; serde_json::from_str(&text).map_err(Error::deserialize_json)?;
match server_msg { match server_msg {
PeerMessage::HydratePeers { peers } => { PeerMessage::HydratePeers { peers } => {
tracing::info!(count = peers.len(), "Received initial peer list"); tracing::info!(count = peers.len(), "received initial peer list");
configure_all_peers( configure_all_peers(
&app_state.nl_handle,
&peers, &peers,
&config.interface.public_key, &config.interface.public_key,
config.interface.persistent_keepalive, config.interface.persistent_keepalive,
)?; )
.await?;
} }
PeerMessage::PeerUpdate { peer } => { PeerMessage::PeerUpdate { peer } => {
tracing::info!( tracing::info!(
peer_key = %peer.public_key, peer_key = %peer.public_key,
mesh_ip = %peer.mesh_ip, mesh_ip = %peer.mesh_ip,
"Received peer update" "received peer update"
); );
configure_peer( configure_peer(
&app_state.nl_handle,
&peer, &peer,
&config.interface.public_key, &config.interface.public_key,
config.interface.persistent_keepalive, config.interface.persistent_keepalive,
)?; )
.await?;
} }
} }
} }
Message::Ping(_) => {} Ok(Message::Close(_)) => {
Message::Pong(_) => {}
Message::Close(_) => {
tracing::warn!("connection closed by server"); tracing::warn!("connection closed by server");
break; break;
} }
Err(tokio_tungstenite::tungstenite::Error::Protocol(e)) => {
tracing::error!(error = %e.as_report(), "protocol error");
break;
}
_ => {} _ => {}
} }
} }
@@ -187,43 +202,32 @@ async fn run() -> Result<()> {
public_port = endpoint.port, public_port = endpoint.port,
"public endpoint" "public endpoint"
); );
let http_client = reqwest::Client::new(); let app_state = Data::new(AppState {
let mesh_ip = register_self(&http_client, &endpoint, &config).await?; reqwest_client: reqwest::Client::new(),
nl_handle: netlink::connect().await?.0,
});
let mesh_ip = register_self(&app_state.reqwest_client, &endpoint, &config).await?;
wireguard::create_interface()?; wireguard::create_interface()?;
wireguard::configure_interface(&private_key, config.interface.listen_port)?; wireguard::configure_interface(&private_key, config.interface.listen_port)?;
let (nl_handle, _nl_task) = netlink::connect().await?; let nl_handle = &app_state.nl_handle;
netlink::add_address(&nl_handle, mesh_ip, 32).await?; netlink::add_address(nl_handle, mesh_ip, 32).await?;
netlink::set_link_up(&nl_handle).await?; netlink::set_link_up(nl_handle).await?;
tracing::info!(
interface = INTERFACE_NAME,
mesh_ip = %mesh_ip,
"interface configured and up"
);
let ws_url = &config.server.ws_url; let ws_url = &config.server.ws_url;
let (ws_stream, _) = tokio_tungstenite::connect_async(ws_url)
let (ws_stream, response) = tokio_tungstenite::connect_async(ws_url)
.await .await
.map_err(|e| Error::ws_connect(e, ws_url))?; .map_err(|e| Error::ws_connect(e, ws_url))?;
tracing::info!(
status = ?response.status(),
"connected to registry WebSocket"
);
let (_, mut read) = ws_stream.split(); let (_, mut read) = ws_stream.split();
tokio::select! { tokio::select! {
receiver = events(&mut read, &config) => receiver?, receiver = events(app_state.clone(), &mut read, &config) => receiver?,
_ = signal::ctrl_c() => { _ = signal::ctrl_c() => {
}, },
_ = on_shutdown() => {} _ = on_shutdown() => {}
}; };
tracing::debug!("gracefully shutting down"); netlink::delete_interface(nl_handle, INTERFACE_NAME).await?;
netlink::delete_interface(&nl_handle, INTERFACE_NAME).await?;
deregister_self(&http_client, &config).await?;
tracing::info!("connection closed"); tracing::info!("connection closed");
Ok(()) Ok(())
} }

View File

@@ -1,10 +1,12 @@
use std::net::Ipv4Addr; use std::net::Ipv4Addr;
use futures::TryStreamExt; use futures::TryStreamExt;
use ipnetwork::IpNetwork;
use rtnetlink::Handle; use rtnetlink::Handle;
use crate::config::INTERFACE_NAME; use crate::config::INTERFACE_NAME;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::state::{AppState, Data};
async fn get_interface_index(handle: &Handle, name: &str) -> Result<u32> { async fn get_interface_index(handle: &Handle, name: &str) -> Result<u32> {
let mut links = handle.link().get().match_name(name.to_string()).execute(); let mut links = handle.link().get().match_name(name.to_string()).execute();
@@ -45,6 +47,48 @@ pub async fn add_address(handle: &Handle, addr: Ipv4Addr, prefix_len: u8) -> Res
} }
} }
pub async fn add_routes(handle: &Handle, routes: &[IpNetwork]) -> Result<()> {
let index = get_interface_index(handle, INTERFACE_NAME).await?;
for route in routes {
let result = match route {
IpNetwork::V4(network) => {
handle
.route()
.add()
.v4()
.destination_prefix(network.ip(), network.prefix())
.output_interface(index)
.execute()
.await
}
IpNetwork::V6(network) => {
handle
.route()
.add()
.v6()
.destination_prefix(network.ip(), network.prefix())
.output_interface(index)
.execute()
.await
}
};
match result {
Ok(()) => {}
Err(rtnetlink::Error::NetlinkError(e)) if e.raw_code() == -libc::EEXIST => {
tracing::debug!(
destination = %route,
"route already exists"
);
}
Err(e) => return Err(Error::add_route(e, route.to_string())),
}
}
Ok(())
}
pub async fn set_link_up(handle: &Handle) -> Result<()> { pub async fn set_link_up(handle: &Handle) -> Result<()> {
let index = get_interface_index(handle, INTERFACE_NAME).await?; let index = get_interface_index(handle, INTERFACE_NAME).await?;

23
client/src/state.rs Normal file
View File

@@ -0,0 +1,23 @@
use std::{ops::Deref, sync::Arc};
#[derive(Clone)]
pub struct AppState {
pub reqwest_client: reqwest::Client,
pub nl_handle: rtnetlink::Handle,
}
#[derive(Clone)]
pub struct Data<T>(Arc<T>);
impl<T> Deref for Data<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> Data<T> {
pub fn new(inner: T) -> Self {
Self(Arc::new(inner))
}
}