diff --git a/client/src/error.rs b/client/src/error.rs index b955e46..50634dc 100644 --- a/client/src/error.rs +++ b/client/src/error.rs @@ -27,8 +27,6 @@ pub enum ErrorKind { DeserializeToml(#[from] toml::de::Error), #[error("invalid url")] Url(#[from] url::ParseError), - #[error("invalid url scheme: {url}")] - UrlScheme { url: String }, #[error("STUN discovery failed: {message}")] StunDiscovery { message: String }, #[error("HTTP IP discovery failed")] @@ -86,6 +84,12 @@ pub enum ErrorKind { #[source] source: rtnetlink::Error, }, + #[error("error adding route to {destination}")] + AddRoute { + destination: String, + #[source] + source: rtnetlink::Error, + }, } pub type Result = core::result::Result; diff --git a/client/src/main.rs b/client/src/main.rs index 4dc09c0..a7f81e6 100644 --- a/client/src/main.rs +++ b/client/src/main.rs @@ -2,6 +2,7 @@ mod config; mod discovery; mod error; mod netlink; +mod state; mod wireguard; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; @@ -10,6 +11,7 @@ use console::style; use futures::{StreamExt, TryFutureExt, stream::SplitStream}; use ipnetwork::IpNetwork; use registry::{Peer, PeerMessage, RegisterResponse}; +use rtnetlink::Handle; use serde::Serialize; use thiserror_ext::AsReport; use tokio::{ @@ -26,6 +28,7 @@ use crate::{ config::{Config, INTERFACE_NAME}, discovery::{PublicEndpoint, discover_public_endpoint}, error::{Error, Result}, + state::{AppState, Data}, }; #[derive(Serialize)] @@ -43,12 +46,6 @@ async fn register_self( ) -> Result { let url = format!("{}/register", &config.server.url); - tracing::info!( - public_ip = %endpoint.ip, - port = endpoint.port, - "registering with registry" - ); - let response = client .post(&url) .json(&RegisterRequest { @@ -77,6 +74,7 @@ async fn register_self( Ok(register_response.mesh_ip) } +#[allow(dead_code)] async fn deregister_self(client: &reqwest::Client, config: &Config) -> Result<()> { let url = format!( "{}/deregister?public_key={}", @@ -92,9 +90,13 @@ async fn deregister_self(client: &reqwest::Client, config: &Config) -> Result<() Ok(()) } -fn configure_peer(peer: &Peer, local_public_key: &str, keepalive: u16) -> Result<()> { +async fn configure_peer( + handle: &Handle, + peer: &Peer, + local_public_key: &str, + keepalive: u16, +) -> Result<()> { if peer.public_key == local_public_key { - tracing::debug!(peer_key = %peer.public_key, "skipping self"); return Ok(()); } @@ -113,6 +115,7 @@ fn configure_peer(peer: &Peer, local_public_key: &str, keepalive: u16) -> Result allowed_ips.extend(peer.allowed_ips.iter().cloned()); wireguard::configure_peer(&peer_key, Some(endpoint), &allowed_ips, Some(keepalive))?; + netlink::add_routes(handle, &allowed_ips).await?; tracing::info!( peer_key = %peer.public_key, @@ -124,9 +127,14 @@ fn configure_peer(peer: &Peer, local_public_key: &str, keepalive: u16) -> Result Ok(()) } -fn configure_all_peers(peers: &[Peer], local_public_key: &str, keepalive: u16) -> Result<()> { +async fn configure_all_peers( + handle: &Handle, + peers: &[Peer], + local_public_key: &str, + keepalive: u16, +) -> Result<()> { for peer in peers { - if let Err(e) = configure_peer(peer, local_public_key, keepalive) { + if let Err(e) = configure_peer(handle, peer, local_public_key, keepalive).await { tracing::warn!(peer_key = %peer.public_key, error = %e, "failed to configure peer"); } } @@ -134,44 +142,51 @@ fn configure_all_peers(peers: &[Peer], local_public_key: &str, keepalive: u16) - } async fn events( + app_state: Data, read: &mut SplitStream>>, config: &Config, ) -> Result<()> { while let Some(msg) = read.next().await { - match msg.map_err(Error::ws_read)? { - Message::Text(text) => { + match msg { + Ok(Message::Text(text)) => { let server_msg: PeerMessage = serde_json::from_str(&text).map_err(Error::deserialize_json)?; match server_msg { PeerMessage::HydratePeers { peers } => { - tracing::info!(count = peers.len(), "Received initial peer list"); + tracing::info!(count = peers.len(), "received initial peer list"); configure_all_peers( + &app_state.nl_handle, &peers, &config.interface.public_key, config.interface.persistent_keepalive, - )?; + ) + .await?; } PeerMessage::PeerUpdate { peer } => { tracing::info!( peer_key = %peer.public_key, mesh_ip = %peer.mesh_ip, - "Received peer update" + "received peer update" ); configure_peer( + &app_state.nl_handle, &peer, &config.interface.public_key, config.interface.persistent_keepalive, - )?; + ) + .await?; } } } - Message::Ping(_) => {} - Message::Pong(_) => {} - Message::Close(_) => { + Ok(Message::Close(_)) => { tracing::warn!("connection closed by server"); break; } + Err(tokio_tungstenite::tungstenite::Error::Protocol(e)) => { + tracing::error!(error = %e.as_report(), "protocol error"); + break; + } _ => {} } } @@ -187,43 +202,32 @@ async fn run() -> Result<()> { public_port = endpoint.port, "public endpoint" ); - let http_client = reqwest::Client::new(); - let mesh_ip = register_self(&http_client, &endpoint, &config).await?; + let app_state = Data::new(AppState { + reqwest_client: reqwest::Client::new(), + nl_handle: netlink::connect().await?.0, + }); + let mesh_ip = register_self(&app_state.reqwest_client, &endpoint, &config).await?; wireguard::create_interface()?; wireguard::configure_interface(&private_key, config.interface.listen_port)?; - let (nl_handle, _nl_task) = netlink::connect().await?; - netlink::add_address(&nl_handle, mesh_ip, 32).await?; - netlink::set_link_up(&nl_handle).await?; - - tracing::info!( - interface = INTERFACE_NAME, - mesh_ip = %mesh_ip, - "interface configured and up" - ); + let nl_handle = &app_state.nl_handle; + netlink::add_address(nl_handle, mesh_ip, 32).await?; + netlink::set_link_up(nl_handle).await?; let ws_url = &config.server.ws_url; - - let (ws_stream, response) = tokio_tungstenite::connect_async(ws_url) + let (ws_stream, _) = tokio_tungstenite::connect_async(ws_url) .await .map_err(|e| Error::ws_connect(e, ws_url))?; - tracing::info!( - status = ?response.status(), - "connected to registry WebSocket" - ); - let (_, mut read) = ws_stream.split(); tokio::select! { - receiver = events(&mut read, &config) => receiver?, + receiver = events(app_state.clone(), &mut read, &config) => receiver?, _ = signal::ctrl_c() => { }, _ = on_shutdown() => {} }; - tracing::debug!("gracefully shutting down"); - netlink::delete_interface(&nl_handle, INTERFACE_NAME).await?; - deregister_self(&http_client, &config).await?; + netlink::delete_interface(nl_handle, INTERFACE_NAME).await?; tracing::info!("connection closed"); Ok(()) } diff --git a/client/src/netlink.rs b/client/src/netlink.rs index 6915ff4..da9d4ae 100644 --- a/client/src/netlink.rs +++ b/client/src/netlink.rs @@ -1,10 +1,12 @@ use std::net::Ipv4Addr; use futures::TryStreamExt; +use ipnetwork::IpNetwork; use rtnetlink::Handle; use crate::config::INTERFACE_NAME; use crate::error::{Error, Result}; +use crate::state::{AppState, Data}; async fn get_interface_index(handle: &Handle, name: &str) -> Result { let mut links = handle.link().get().match_name(name.to_string()).execute(); @@ -45,6 +47,48 @@ pub async fn add_address(handle: &Handle, addr: Ipv4Addr, prefix_len: u8) -> Res } } +pub async fn add_routes(handle: &Handle, routes: &[IpNetwork]) -> Result<()> { + let index = get_interface_index(handle, INTERFACE_NAME).await?; + + for route in routes { + let result = match route { + IpNetwork::V4(network) => { + handle + .route() + .add() + .v4() + .destination_prefix(network.ip(), network.prefix()) + .output_interface(index) + .execute() + .await + } + IpNetwork::V6(network) => { + handle + .route() + .add() + .v6() + .destination_prefix(network.ip(), network.prefix()) + .output_interface(index) + .execute() + .await + } + }; + + match result { + Ok(()) => {} + Err(rtnetlink::Error::NetlinkError(e)) if e.raw_code() == -libc::EEXIST => { + tracing::debug!( + destination = %route, + "route already exists" + ); + } + Err(e) => return Err(Error::add_route(e, route.to_string())), + } + } + + Ok(()) +} + pub async fn set_link_up(handle: &Handle) -> Result<()> { let index = get_interface_index(handle, INTERFACE_NAME).await?; diff --git a/client/src/state.rs b/client/src/state.rs new file mode 100644 index 0000000..8f2b291 --- /dev/null +++ b/client/src/state.rs @@ -0,0 +1,23 @@ +use std::{ops::Deref, sync::Arc}; + +#[derive(Clone)] +pub struct AppState { + pub reqwest_client: reqwest::Client, + pub nl_handle: rtnetlink::Handle, +} + +#[derive(Clone)] +pub struct Data(Arc); + +impl Deref for Data { + type Target = T; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Data { + pub fn new(inner: T) -> Self { + Self(Arc::new(inner)) + } +}