Compare commits
3 Commits
2d537c3fbd
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
03f38b9ee3
|
|||
|
a022c18ff9
|
|||
|
3cac99c24c
|
1197
Cargo.lock
generated
1197
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,3 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
members = ["registry", "client"]
|
members = ["registry", "client"]
|
||||||
resolver = "3"
|
resolver = "3"
|
||||||
|
|
||||||
[workspace.dependencies]
|
|
||||||
|
|||||||
@@ -11,8 +11,23 @@ serde = { version = "1.0.228", features = ["derive"] }
|
|||||||
serde_json = "1.0.149"
|
serde_json = "1.0.149"
|
||||||
thiserror = "2.0.18"
|
thiserror = "2.0.18"
|
||||||
thiserror-ext = "0.3.0"
|
thiserror-ext = "0.3.0"
|
||||||
tokio = { version = "1.49.0", features = ["macros", "rt-multi-thread"] }
|
tokio = { version = "1.49.0", features = ["macros", "net", "rt-multi-thread", "signal"] }
|
||||||
tokio-tungstenite = { version = "0.28.0", features = ["rustls-tls-native-roots"] }
|
tokio-tungstenite = { version = "0.28.0", features = ["rustls-tls-native-roots"] }
|
||||||
tracing = "0.1.44"
|
tracing = "0.1.44"
|
||||||
tracing-subscriber = { version = "0.3.22", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3.22", features = ["env-filter"] }
|
||||||
registry = { path = "../registry" }
|
registry = { path = "../registry" }
|
||||||
|
dirs = "6.0.0"
|
||||||
|
toml = "1.0.1"
|
||||||
|
url = "2.5.8"
|
||||||
|
reqwest = { version = "0.13.2", features = ["json"] }
|
||||||
|
stunclient = "0.4.2"
|
||||||
|
ipnetwork = { version = "0.21.1", features = ["serde"] }
|
||||||
|
base64 = "0.22.1"
|
||||||
|
libc = "0.2"
|
||||||
|
|
||||||
|
# WireGuard netlink control
|
||||||
|
wireguard-control = "1.1"
|
||||||
|
|
||||||
|
# Network interface management
|
||||||
|
rtnetlink = "0.14"
|
||||||
|
netlink-packet-route = "0.19"
|
||||||
|
|||||||
52
client/src/config.rs
Normal file
52
client/src/config.rs
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use crate::error::{Error, Result};
|
||||||
|
use ipnetwork::IpNetwork;
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
pub const DEFAULT_KEEPALIVE: u16 = 25;
|
||||||
|
pub const INTERFACE_NAME: &str = "mesh0";
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub interface: InterfaceConfig,
|
||||||
|
pub server: ServerConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct InterfaceConfig {
|
||||||
|
pub private_key: String,
|
||||||
|
pub public_key: String,
|
||||||
|
#[serde(default = "default_listen_port")]
|
||||||
|
pub listen_port: u16,
|
||||||
|
#[serde(default = "default_keepalive")]
|
||||||
|
pub persistent_keepalive: u16,
|
||||||
|
pub allowed_ips: Option<Vec<IpNetwork>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_listen_port() -> u16 {
|
||||||
|
51820
|
||||||
|
}
|
||||||
|
fn default_keepalive() -> u16 {
|
||||||
|
DEFAULT_KEEPALIVE
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct ServerConfig {
|
||||||
|
pub ws_url: String,
|
||||||
|
pub url: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn load() -> Result<Self> {
|
||||||
|
let path = base_path().join("config.toml");
|
||||||
|
let bytes = std::fs::read(&path).map_err(|e| Error::read_config(e, &path))?;
|
||||||
|
Ok(toml::from_slice(&bytes)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn base_path() -> PathBuf {
|
||||||
|
dirs::config_dir()
|
||||||
|
.unwrap_or_else(|| PathBuf::from("."))
|
||||||
|
.join("wg-mesh")
|
||||||
|
}
|
||||||
142
client/src/discovery.rs
Normal file
142
client/src/discovery.rs
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
use std::net::{IpAddr, SocketAddr, ToSocketAddrs, UdpSocket};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use stunclient::StunClient;
|
||||||
|
|
||||||
|
use crate::error::{Error, Result};
|
||||||
|
|
||||||
|
const STUN_SERVERS: &[&str] = &[
|
||||||
|
"stun.l.google.com:19302",
|
||||||
|
"stun.cloudflare.com:3478",
|
||||||
|
"stun.stunprotocol.org:3478",
|
||||||
|
];
|
||||||
|
|
||||||
|
const HTTP_IP_SERVICES: &[&str] = &["https://icanhazip.com", "https://ifconfig.me/ip"];
|
||||||
|
|
||||||
|
const STUN_TIMEOUT: Duration = Duration::from_secs(3);
|
||||||
|
const HTTP_TIMEOUT: Duration = Duration::from_secs(5);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PublicEndpoint {
|
||||||
|
pub ip: IpAddr,
|
||||||
|
pub port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for PublicEndpoint {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "{}:{}", self.ip, self.port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn discover_public_endpoint(local_port: u16) -> Result<PublicEndpoint> {
|
||||||
|
match stun_discover(local_port) {
|
||||||
|
Ok(endpoint) => {
|
||||||
|
tracing::debug!(
|
||||||
|
ip = %endpoint.ip,
|
||||||
|
port = endpoint.port,
|
||||||
|
"STUN discovery successful"
|
||||||
|
);
|
||||||
|
return Ok(endpoint);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("STUN discovery failed, falling back to HTTP: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let ip = http_discover_ip().await?;
|
||||||
|
let endpoint = PublicEndpoint {
|
||||||
|
ip,
|
||||||
|
port: local_port,
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
ip = %endpoint.ip,
|
||||||
|
port = endpoint.port,
|
||||||
|
"HTTP fallback discovery successful (port may not reflect NAT mapping)"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stun_discover(local_port: u16) -> Result<PublicEndpoint> {
|
||||||
|
let local_addr: SocketAddr = format!("0.0.0.0:{local_port}")
|
||||||
|
.parse()
|
||||||
|
.expect("valid socket addr");
|
||||||
|
|
||||||
|
let socket = UdpSocket::bind(local_addr).map_err(|e| {
|
||||||
|
Error::stun_discovery(format!(
|
||||||
|
"failed to bind UDP socket on port {local_port}: {e}"
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
socket
|
||||||
|
.set_read_timeout(Some(STUN_TIMEOUT))
|
||||||
|
.map_err(|e| Error::stun_discovery(format!("failed to set socket timeout: {e}")))?;
|
||||||
|
|
||||||
|
for server in STUN_SERVERS {
|
||||||
|
tracing::debug!(server, "Attempting STUN discovery");
|
||||||
|
|
||||||
|
match stun_query(&socket, server) {
|
||||||
|
Ok(endpoint) => return Ok(endpoint),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(server, error = %e, "STUN server failed, trying next");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(Error::discovery_failed())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stun_query(socket: &UdpSocket, server: &str) -> Result<PublicEndpoint> {
|
||||||
|
let server_addr = server
|
||||||
|
.to_socket_addrs()
|
||||||
|
.map_err(|e| Error::stun_discovery(format!("{server}: {e}")))?
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| Error::stun_discovery(format!("{server}: no addresses found")))?;
|
||||||
|
|
||||||
|
let client = StunClient::new(server_addr);
|
||||||
|
|
||||||
|
let public_addr = client
|
||||||
|
.query_external_address(socket)
|
||||||
|
.map_err(|e| Error::stun_discovery(format!("{server}: {e}")))?;
|
||||||
|
|
||||||
|
Ok(PublicEndpoint {
|
||||||
|
ip: public_addr.ip(),
|
||||||
|
port: public_addr.port(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn http_discover_ip() -> Result<IpAddr> {
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(HTTP_TIMEOUT)
|
||||||
|
.build()
|
||||||
|
.map_err(Error::http_discovery)?;
|
||||||
|
|
||||||
|
for service in HTTP_IP_SERVICES {
|
||||||
|
tracing::debug!(service, "Attempting HTTP IP discovery");
|
||||||
|
|
||||||
|
match http_query_ip(&client, service).await {
|
||||||
|
Ok(ip) => return Ok(ip),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(service, error = %e, "HTTP service failed, trying next");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(Error::discovery_failed())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn http_query_ip(client: &reqwest::Client, url: &str) -> Result<IpAddr> {
|
||||||
|
let response = client
|
||||||
|
.get(url)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(Error::http_discovery)?;
|
||||||
|
|
||||||
|
let body = response.text().await.map_err(Error::http_discovery)?;
|
||||||
|
let ip_str = body.trim();
|
||||||
|
|
||||||
|
ip_str
|
||||||
|
.parse::<IpAddr>()
|
||||||
|
.map_err(|_| Error::parse_ip(ip_str.to_string()))
|
||||||
|
}
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use thiserror_ext::{Box, Construct};
|
use thiserror_ext::{Box, Construct};
|
||||||
use tokio_tungstenite::tungstenite;
|
use tokio_tungstenite::tungstenite;
|
||||||
@@ -15,6 +17,75 @@ pub enum ErrorKind {
|
|||||||
WsRead(#[source] tungstenite::Error),
|
WsRead(#[source] tungstenite::Error),
|
||||||
#[error("error deserializing json")]
|
#[error("error deserializing json")]
|
||||||
DeserializeJson(#[source] serde_json::Error),
|
DeserializeJson(#[source] serde_json::Error),
|
||||||
|
#[error("error reading configuration at {path}")]
|
||||||
|
ReadConfig {
|
||||||
|
path: PathBuf,
|
||||||
|
#[source]
|
||||||
|
source: std::io::Error,
|
||||||
|
},
|
||||||
|
#[error("error deserializing toml")]
|
||||||
|
DeserializeToml(#[from] toml::de::Error),
|
||||||
|
#[error("invalid url")]
|
||||||
|
Url(#[from] url::ParseError),
|
||||||
|
#[error("invalid url scheme: {url}")]
|
||||||
|
UrlScheme { url: String },
|
||||||
|
#[error("STUN discovery failed: {message}")]
|
||||||
|
StunDiscovery { message: String },
|
||||||
|
#[error("HTTP IP discovery failed")]
|
||||||
|
HttpDiscovery(#[source] reqwest::Error),
|
||||||
|
#[error("failed to parse IP address from HTTP response: {body}")]
|
||||||
|
ParseIp { body: String },
|
||||||
|
#[error("all discovery methods failed")]
|
||||||
|
DiscoveryFailed,
|
||||||
|
#[error("error with request")]
|
||||||
|
Reqwest(#[from] reqwest::Error),
|
||||||
|
#[error("invalid base64 key: {context}")]
|
||||||
|
InvalidKey { context: String },
|
||||||
|
#[error("error creating WireGuard interface {interface}")]
|
||||||
|
CreateInterface {
|
||||||
|
interface: String,
|
||||||
|
#[source]
|
||||||
|
source: std::io::Error,
|
||||||
|
},
|
||||||
|
#[error("error configuring WireGuard device {interface}")]
|
||||||
|
ConfigureDevice {
|
||||||
|
interface: String,
|
||||||
|
#[source]
|
||||||
|
source: std::io::Error,
|
||||||
|
},
|
||||||
|
#[error("error with netlink operation: {context}")]
|
||||||
|
Netlink { context: String },
|
||||||
|
#[error("error setting interface address")]
|
||||||
|
SetAddress {
|
||||||
|
#[source]
|
||||||
|
source: rtnetlink::Error,
|
||||||
|
},
|
||||||
|
#[error("error bringing interface up")]
|
||||||
|
SetLinkUp {
|
||||||
|
#[source]
|
||||||
|
source: rtnetlink::Error,
|
||||||
|
},
|
||||||
|
#[error("error getting interface index for {interface}")]
|
||||||
|
GetInterface { interface: String },
|
||||||
|
#[error("registration failed")]
|
||||||
|
Registration {
|
||||||
|
#[source]
|
||||||
|
source: reqwest::Error,
|
||||||
|
},
|
||||||
|
#[error("error deregistering device {public_key}")]
|
||||||
|
Deregister {
|
||||||
|
public_key: String,
|
||||||
|
#[source]
|
||||||
|
source: reqwest::Error,
|
||||||
|
},
|
||||||
|
#[error("registration failed with status {status}: {body}")]
|
||||||
|
RegistrationStatus { status: u16, body: String },
|
||||||
|
#[error("error deleting interface")]
|
||||||
|
DeleteInterface {
|
||||||
|
name: String,
|
||||||
|
#[source]
|
||||||
|
source: rtnetlink::Error,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Result<T> = core::result::Result<T, Error>;
|
pub type Result<T> = core::result::Result<T, Error>;
|
||||||
|
|||||||
@@ -1,42 +1,244 @@
|
|||||||
|
mod config;
|
||||||
|
mod discovery;
|
||||||
|
mod error;
|
||||||
|
mod netlink;
|
||||||
|
mod wireguard;
|
||||||
|
|
||||||
|
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
|
||||||
|
|
||||||
use console::style;
|
use console::style;
|
||||||
use futures::StreamExt;
|
use futures::{StreamExt, TryFutureExt, stream::SplitStream};
|
||||||
use registry::PeerMessage;
|
use ipnetwork::IpNetwork;
|
||||||
|
use registry::{Peer, PeerMessage, RegisterResponse};
|
||||||
|
use serde::Serialize;
|
||||||
use thiserror_ext::AsReport;
|
use thiserror_ext::AsReport;
|
||||||
use tokio_tungstenite::tungstenite::Message;
|
use tokio::{
|
||||||
|
net::TcpStream,
|
||||||
|
signal::{self, unix::SignalKind},
|
||||||
|
};
|
||||||
|
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
|
||||||
use tracing::level_filters::LevelFilter;
|
use tracing::level_filters::LevelFilter;
|
||||||
use tracing_subscriber::{
|
use tracing_subscriber::{
|
||||||
EnvFilter, fmt::format::FmtSpan, layer::SubscriberExt, util::SubscriberInitExt,
|
EnvFilter, fmt::format::FmtSpan, layer::SubscriberExt, util::SubscriberInitExt,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::error::Error;
|
use crate::{
|
||||||
|
config::{Config, INTERFACE_NAME},
|
||||||
|
discovery::{PublicEndpoint, discover_public_endpoint},
|
||||||
|
error::{Error, Result},
|
||||||
|
};
|
||||||
|
|
||||||
mod error;
|
#[derive(Serialize)]
|
||||||
|
pub struct RegisterRequest {
|
||||||
|
pub public_ip: IpAddr,
|
||||||
|
pub public_key: String,
|
||||||
|
pub port: String,
|
||||||
|
pub allowed_ips: Vec<IpNetwork>,
|
||||||
|
}
|
||||||
|
|
||||||
async fn run() -> crate::error::Result<()> {
|
async fn register_self(
|
||||||
let url = get_api_from_env();
|
client: &reqwest::Client,
|
||||||
let (ws_stream, response) = tokio_tungstenite::connect_async(&url)
|
endpoint: &PublicEndpoint,
|
||||||
|
config: &Config,
|
||||||
|
) -> Result<Ipv4Addr> {
|
||||||
|
let url = format!("{}/register", &config.server.url);
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
public_ip = %endpoint.ip,
|
||||||
|
port = endpoint.port,
|
||||||
|
"registering with registry"
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.post(&url)
|
||||||
|
.json(&RegisterRequest {
|
||||||
|
public_key: config.interface.public_key.clone(),
|
||||||
|
public_ip: endpoint.ip,
|
||||||
|
port: endpoint.port.to_string(),
|
||||||
|
allowed_ips: config.interface.allowed_ips.clone().unwrap_or(vec![]),
|
||||||
|
})
|
||||||
|
.send()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| Error::ws_connect(e, &url))?;
|
.map_err(Error::registration)?;
|
||||||
let (_, mut read) = ws_stream.split();
|
|
||||||
tracing::info!("connected, response: {:?}", response.status());
|
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status().as_u16();
|
||||||
|
let body = response.text().await.unwrap_or_default();
|
||||||
|
return Err(Error::registration_status(status, body));
|
||||||
|
}
|
||||||
|
|
||||||
|
let register_response: RegisterResponse = response.json().await.map_err(Error::registration)?;
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
mesh_ip = %register_response.mesh_ip,
|
||||||
|
"registration successful, assigned mesh IP"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(register_response.mesh_ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn deregister_self(client: &reqwest::Client, config: &Config) -> Result<()> {
|
||||||
|
let url = format!(
|
||||||
|
"{}/deregister?public_key={}",
|
||||||
|
config.server.url, config.interface.public_key
|
||||||
|
);
|
||||||
|
client
|
||||||
|
.delete(&url)
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.error_for_status()
|
||||||
|
.map_err(|e| Error::deregister(e, &config.interface.public_key))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn configure_peer(peer: &Peer, local_public_key: &str, keepalive: u16) -> Result<()> {
|
||||||
|
if peer.public_key == local_public_key {
|
||||||
|
tracing::debug!(peer_key = %peer.public_key, "skipping self");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let peer_key = wireguard::parse_key(&peer.public_key)?;
|
||||||
|
let endpoint: SocketAddr = format!("{}:{}", peer.public_ip, peer.port)
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| {
|
||||||
|
Error::netlink(format!(
|
||||||
|
"invalid endpoint: {}:{}",
|
||||||
|
peer.public_ip, peer.port
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut allowed_ips: Vec<IpNetwork> =
|
||||||
|
vec![IpNetwork::new(std::net::IpAddr::V4(peer.mesh_ip), 32).expect("valid network")];
|
||||||
|
allowed_ips.extend(peer.allowed_ips.iter().cloned());
|
||||||
|
|
||||||
|
wireguard::configure_peer(&peer_key, Some(endpoint), &allowed_ips, Some(keepalive))?;
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
peer_key = %peer.public_key,
|
||||||
|
mesh_ip = %peer.mesh_ip,
|
||||||
|
endpoint = %endpoint,
|
||||||
|
"configured peer"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn configure_all_peers(peers: &[Peer], local_public_key: &str, keepalive: u16) -> Result<()> {
|
||||||
|
for peer in peers {
|
||||||
|
if let Err(e) = configure_peer(peer, local_public_key, keepalive) {
|
||||||
|
tracing::warn!(peer_key = %peer.public_key, error = %e, "failed to configure peer");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn events(
|
||||||
|
read: &mut SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
|
||||||
|
config: &Config,
|
||||||
|
) -> Result<()> {
|
||||||
while let Some(msg) = read.next().await {
|
while let Some(msg) = read.next().await {
|
||||||
match msg.map_err(|e| Error::ws_read(e))? {
|
match msg.map_err(Error::ws_read)? {
|
||||||
Message::Text(text) => {
|
Message::Text(text) => {
|
||||||
let server_msg: PeerMessage =
|
let server_msg: PeerMessage =
|
||||||
serde_json::from_str(&text).map_err(|e| Error::deserialize_json(e))?;
|
serde_json::from_str(&text).map_err(Error::deserialize_json)?;
|
||||||
|
|
||||||
match server_msg {
|
match server_msg {
|
||||||
PeerMessage::HydratePeers { peers } => {}
|
PeerMessage::HydratePeers { peers } => {
|
||||||
PeerMessage::PeerUpdate { peer } => {}
|
tracing::info!(count = peers.len(), "Received initial peer list");
|
||||||
|
configure_all_peers(
|
||||||
|
&peers,
|
||||||
|
&config.interface.public_key,
|
||||||
|
config.interface.persistent_keepalive,
|
||||||
|
)?;
|
||||||
}
|
}
|
||||||
|
PeerMessage::PeerUpdate { peer } => {
|
||||||
|
tracing::info!(
|
||||||
|
peer_key = %peer.public_key,
|
||||||
|
mesh_ip = %peer.mesh_ip,
|
||||||
|
"Received peer update"
|
||||||
|
);
|
||||||
|
configure_peer(
|
||||||
|
&peer,
|
||||||
|
&config.interface.public_key,
|
||||||
|
config.interface.persistent_keepalive,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Message::Ping(_) => {}
|
||||||
|
Message::Pong(_) => {}
|
||||||
|
Message::Close(_) => {
|
||||||
|
tracing::warn!("connection closed by server");
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn run() -> Result<()> {
|
||||||
|
let config = Config::load()?;
|
||||||
|
let private_key = wireguard::parse_key(&config.interface.private_key)?;
|
||||||
|
let endpoint = discover_public_endpoint(config.interface.listen_port).await?;
|
||||||
|
tracing::info!(
|
||||||
|
public_ip = %endpoint.ip,
|
||||||
|
public_port = endpoint.port,
|
||||||
|
"public endpoint"
|
||||||
|
);
|
||||||
|
let http_client = reqwest::Client::new();
|
||||||
|
let mesh_ip = register_self(&http_client, &endpoint, &config).await?;
|
||||||
|
wireguard::create_interface()?;
|
||||||
|
wireguard::configure_interface(&private_key, config.interface.listen_port)?;
|
||||||
|
|
||||||
|
let (nl_handle, _nl_task) = netlink::connect().await?;
|
||||||
|
netlink::add_address(&nl_handle, mesh_ip, 32).await?;
|
||||||
|
netlink::set_link_up(&nl_handle).await?;
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
interface = INTERFACE_NAME,
|
||||||
|
mesh_ip = %mesh_ip,
|
||||||
|
"interface configured and up"
|
||||||
|
);
|
||||||
|
|
||||||
|
let ws_url = &config.server.ws_url;
|
||||||
|
|
||||||
|
let (ws_stream, response) = tokio_tungstenite::connect_async(ws_url)
|
||||||
|
.await
|
||||||
|
.map_err(|e| Error::ws_connect(e, ws_url))?;
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
status = ?response.status(),
|
||||||
|
"connected to registry WebSocket"
|
||||||
|
);
|
||||||
|
|
||||||
|
let (_, mut read) = ws_stream.split();
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
receiver = events(&mut read, &config) => receiver?,
|
||||||
|
_ = signal::ctrl_c() => {
|
||||||
|
},
|
||||||
|
_ = on_shutdown() => {}
|
||||||
|
};
|
||||||
|
tracing::debug!("gracefully shutting down");
|
||||||
|
netlink::delete_interface(&nl_handle, INTERFACE_NAME).await?;
|
||||||
|
deregister_self(&http_client, &config).await?;
|
||||||
|
tracing::info!("connection closed");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn on_shutdown() {
|
||||||
|
use tokio::signal::unix::signal;
|
||||||
|
let mut sigterm = signal(SignalKind::terminate()).expect("failed to register SIGTERM");
|
||||||
|
let mut sigint = signal(SignalKind::interrupt()).expect("failed to register SIGINT");
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
_ = sigterm.recv() => tracing::info!("received sigterm"),
|
||||||
|
_ = sigint.recv() => tracing::info!("received sigint")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
let tracing_env_filter = EnvFilter::builder()
|
let tracing_env_filter = EnvFilter::builder()
|
||||||
@@ -57,7 +259,3 @@ async fn main() {
|
|||||||
std::process::exit(1)
|
std::process::exit(1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_api_from_env() -> String {
|
|
||||||
return "ws://localhost:8080/ws/peers".to_string();
|
|
||||||
}
|
|
||||||
|
|||||||
84
client/src/netlink.rs
Normal file
84
client/src/netlink.rs
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
use std::net::Ipv4Addr;
|
||||||
|
|
||||||
|
use futures::TryStreamExt;
|
||||||
|
use rtnetlink::Handle;
|
||||||
|
|
||||||
|
use crate::config::INTERFACE_NAME;
|
||||||
|
use crate::error::{Error, Result};
|
||||||
|
|
||||||
|
async fn get_interface_index(handle: &Handle, name: &str) -> Result<u32> {
|
||||||
|
let mut links = handle.link().get().match_name(name.to_string()).execute();
|
||||||
|
|
||||||
|
if let Some(link) = links.try_next().await.map_err(Error::set_link_up)? {
|
||||||
|
Ok(link.header.index)
|
||||||
|
} else {
|
||||||
|
Err(Error::get_interface(name.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn add_address(handle: &Handle, addr: Ipv4Addr, prefix_len: u8) -> Result<()> {
|
||||||
|
let index = get_interface_index(handle, INTERFACE_NAME).await?;
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
interface = INTERFACE_NAME,
|
||||||
|
address = %addr,
|
||||||
|
prefix_len,
|
||||||
|
"adding address to interface"
|
||||||
|
);
|
||||||
|
|
||||||
|
match handle
|
||||||
|
.address()
|
||||||
|
.add(index, std::net::IpAddr::V4(addr), prefix_len)
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(()) => Ok(()),
|
||||||
|
Err(rtnetlink::Error::NetlinkError(e)) if e.raw_code() == -libc::EEXIST => {
|
||||||
|
tracing::debug!(
|
||||||
|
interface = INTERFACE_NAME,
|
||||||
|
address = %addr,
|
||||||
|
"address already exists on interface"
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Err(e) => Err(Error::set_address(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn set_link_up(handle: &Handle) -> Result<()> {
|
||||||
|
let index = get_interface_index(handle, INTERFACE_NAME).await?;
|
||||||
|
|
||||||
|
tracing::info!(interface = INTERFACE_NAME, "set interface up");
|
||||||
|
|
||||||
|
handle
|
||||||
|
.link()
|
||||||
|
.set(index)
|
||||||
|
.up()
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.map_err(Error::set_link_up)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete_interface(handle: &Handle, name: &str) -> Result<()> {
|
||||||
|
let index = get_interface_index(&handle, name).await?;
|
||||||
|
|
||||||
|
handle
|
||||||
|
.link()
|
||||||
|
.del(index)
|
||||||
|
.execute()
|
||||||
|
.await
|
||||||
|
.map_err(|e| Error::delete_interface(e, name))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn connect() -> Result<(Handle, tokio::task::JoinHandle<()>)> {
|
||||||
|
let (connection, handle, _) = rtnetlink::new_connection()
|
||||||
|
.map_err(|e| Error::netlink(format!("failed to create netlink connection: {}", e)))?;
|
||||||
|
|
||||||
|
let join_handle = tokio::spawn(connection);
|
||||||
|
|
||||||
|
Ok((handle, join_handle))
|
||||||
|
}
|
||||||
117
client/src/wireguard.rs
Normal file
117
client/src/wireguard.rs
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
|
use ipnetwork::IpNetwork;
|
||||||
|
use wireguard_control::{
|
||||||
|
AllowedIp, Backend, Device, DeviceUpdate, InterfaceName, InvalidKey, Key, PeerConfigBuilder,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::config::INTERFACE_NAME;
|
||||||
|
use crate::error::{Error, Result};
|
||||||
|
|
||||||
|
pub fn parse_key(key_b64: &str) -> Result<Key> {
|
||||||
|
Key::from_base64(key_b64).map_err(|e: InvalidKey| Error::invalid_key(e.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn interface_name() -> InterfaceName {
|
||||||
|
INTERFACE_NAME.parse().expect("valid interface name")
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn create_interface() -> Result<()> {
|
||||||
|
let name = interface_name();
|
||||||
|
|
||||||
|
if Device::get(&name, Backend::Kernel).is_ok() {
|
||||||
|
tracing::debug!(interface = INTERFACE_NAME, "interface already exists");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!(interface = INTERFACE_NAME, "creating WireGuard interface");
|
||||||
|
|
||||||
|
DeviceUpdate::new()
|
||||||
|
.apply(&name, Backend::Kernel)
|
||||||
|
.map_err(|e| Error::create_interface(e, INTERFACE_NAME.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn configure_interface(private_key: &Key, listen_port: u16) -> Result<()> {
|
||||||
|
let name = interface_name();
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
interface = INTERFACE_NAME,
|
||||||
|
listen_port,
|
||||||
|
"configuring WireGuard interface"
|
||||||
|
);
|
||||||
|
|
||||||
|
DeviceUpdate::new()
|
||||||
|
.set_private_key(private_key.clone())
|
||||||
|
.set_listen_port(listen_port)
|
||||||
|
.apply(&name, Backend::Kernel)
|
||||||
|
.map_err(|e| Error::configure_device(e, INTERFACE_NAME.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn configure_peer(
|
||||||
|
public_key: &Key,
|
||||||
|
endpoint: Option<SocketAddr>,
|
||||||
|
allowed_ips: &[IpNetwork],
|
||||||
|
persistent_keepalive: Option<u16>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let name = interface_name();
|
||||||
|
|
||||||
|
let allowed: Vec<AllowedIp> = allowed_ips
|
||||||
|
.iter()
|
||||||
|
.map(|ip| AllowedIp {
|
||||||
|
address: ip.ip(),
|
||||||
|
cidr: ip.prefix(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut peer = PeerConfigBuilder::new(public_key)
|
||||||
|
.replace_allowed_ips()
|
||||||
|
.add_allowed_ips(&allowed);
|
||||||
|
if let Some(ep) = endpoint {
|
||||||
|
peer = peer.set_endpoint(ep);
|
||||||
|
}
|
||||||
|
if let Some(keepalive) = persistent_keepalive {
|
||||||
|
peer = peer.set_persistent_keepalive_interval(keepalive);
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
peer_key = %public_key.to_base64(),
|
||||||
|
endpoint = ?endpoint,
|
||||||
|
allowed_ips = ?allowed_ips,
|
||||||
|
"configuring peer"
|
||||||
|
);
|
||||||
|
|
||||||
|
DeviceUpdate::new()
|
||||||
|
.add_peer(peer)
|
||||||
|
.apply(&name, Backend::Kernel)
|
||||||
|
.map_err(|e| Error::configure_device(e, INTERFACE_NAME.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn remove_peer(public_key: &Key) -> Result<()> {
|
||||||
|
let name = interface_name();
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
peer_key = %public_key.to_base64(),
|
||||||
|
"removing peer"
|
||||||
|
);
|
||||||
|
|
||||||
|
DeviceUpdate::new()
|
||||||
|
.remove_peer_by_key(public_key)
|
||||||
|
.apply(&name, Backend::Kernel)
|
||||||
|
.map_err(|e| Error::configure_device(e, INTERFACE_NAME.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
pub fn get_device() -> Result<Device> {
|
||||||
|
let name = interface_name();
|
||||||
|
Device::get(&name, Backend::Kernel)
|
||||||
|
.map_err(|e| Error::configure_device(e, INTERFACE_NAME.to_string()))
|
||||||
|
}
|
||||||
5
justfile
5
justfile
@@ -1,2 +1,7 @@
|
|||||||
dev bin="registry":
|
dev bin="registry":
|
||||||
bacon dev -- --bin {{bin}}
|
bacon dev -- --bin {{bin}}
|
||||||
|
|
||||||
|
client target="debug":
|
||||||
|
cargo build -p client
|
||||||
|
sudo setcap cap_net_admin+ep ./target/{{target}}/client
|
||||||
|
./target/{{target}}/client
|
||||||
|
|||||||
6
payload.json
Normal file
6
payload.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"public_key": "pKkl30tba29FG86wuaC0KrpSHMr1tSOujikHFbx75BM=",
|
||||||
|
"public_ip": "75.157.238.86",
|
||||||
|
"port": "51820",
|
||||||
|
"allowed_ips": ["10.100.0.1/32"]
|
||||||
|
}
|
||||||
@@ -21,6 +21,7 @@ tracing-actix-web = "0.7.21"
|
|||||||
tokio = { version = "1.49.0", features = ["macros", "rt-multi-thread", "sync"] }
|
tokio = { version = "1.49.0", features = ["macros", "rt-multi-thread", "sync"] }
|
||||||
thiserror = "2.0.18"
|
thiserror = "2.0.18"
|
||||||
thiserror-ext = "0.3.0"
|
thiserror-ext = "0.3.0"
|
||||||
|
wireguard-control = "1.7.1"
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
name = "registry"
|
name = "registry"
|
||||||
|
|||||||
27
registry/src/endpoints/deregister.rs
Normal file
27
registry/src/endpoints/deregister.rs
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
use actix_web::{HttpResponse, web};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use wireguard_control::Key;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
AppState,
|
||||||
|
error::{Error, Result},
|
||||||
|
storage::StorageImpl,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct DeregisterRequest {
|
||||||
|
public_key: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn deregister(
|
||||||
|
app_state: web::Data<AppState>,
|
||||||
|
query: web::Query<DeregisterRequest>,
|
||||||
|
) -> Result<HttpResponse> {
|
||||||
|
Key::from_base64(&query.public_key).map_err(|_| Error::invalid_key(&query.public_key))?;
|
||||||
|
|
||||||
|
app_state
|
||||||
|
.storage
|
||||||
|
.deregister_device(&query.public_key)
|
||||||
|
.await?;
|
||||||
|
Ok(HttpResponse::Ok().finish())
|
||||||
|
}
|
||||||
@@ -1,3 +1,4 @@
|
|||||||
|
pub mod deregister;
|
||||||
pub mod peers;
|
pub mod peers;
|
||||||
pub mod register;
|
pub mod register;
|
||||||
pub mod ws;
|
pub mod ws;
|
||||||
|
|||||||
@@ -4,13 +4,14 @@ use crate::{
|
|||||||
storage::{RegisterRequest, StorageImpl},
|
storage::{RegisterRequest, StorageImpl},
|
||||||
};
|
};
|
||||||
use actix_web::{HttpResponse, web};
|
use actix_web::{HttpResponse, web};
|
||||||
use registry::Peer;
|
use registry::{Peer, RegisterResponse};
|
||||||
|
|
||||||
pub async fn register_peer(
|
pub async fn register_peer(
|
||||||
app_state: web::Data<AppState>,
|
app_state: web::Data<AppState>,
|
||||||
request: web::Json<RegisterRequest>,
|
request: web::Json<RegisterRequest>,
|
||||||
) -> Result<HttpResponse> {
|
) -> Result<HttpResponse> {
|
||||||
app_state.storage.register_device(&request).await?;
|
let mesh_ip = app_state.storage.register_device(&request).await?;
|
||||||
|
|
||||||
app_state
|
app_state
|
||||||
.peer_updates
|
.peer_updates
|
||||||
.send(PeerUpdate {
|
.send(PeerUpdate {
|
||||||
@@ -18,10 +19,11 @@ pub async fn register_peer(
|
|||||||
public_key: request.public_key.as_str().to_string(),
|
public_key: request.public_key.as_str().to_string(),
|
||||||
public_ip: request.public_ip.to_string(),
|
public_ip: request.public_ip.to_string(),
|
||||||
port: request.port.clone(),
|
port: request.port.clone(),
|
||||||
|
mesh_ip,
|
||||||
allowed_ips: request.allowed_ips.clone(),
|
allowed_ips: request.allowed_ips.clone(),
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
.unwrap();
|
.ok();
|
||||||
|
|
||||||
Ok(HttpResponse::Ok().finish())
|
Ok(HttpResponse::Ok().json(RegisterResponse { mesh_ip }))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ pub async fn peers(
|
|||||||
update = peer_rx.recv() => {
|
update = peer_rx.recv() => {
|
||||||
match update {
|
match update {
|
||||||
Ok(peer_update) => {
|
Ok(peer_update) => {
|
||||||
let json = serde_json::to_string(&peer_update.peer)
|
let json = serde_json::to_string(&PeerMessage::PeerUpdate { peer: peer_update.peer.clone() })
|
||||||
.unwrap_or_else(|_| "{}".to_string());
|
.unwrap_or_else(|_| "{}".to_string());
|
||||||
if session.text(json).await.is_err() {
|
if session.text(json).await.is_err() {
|
||||||
break;
|
break;
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
use actix_web::{HttpResponse, ResponseError};
|
use actix_web::{HttpResponse, ResponseError};
|
||||||
|
use serde::Serialize;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use thiserror_ext::{Box, Construct};
|
use thiserror_ext::{Box, Construct};
|
||||||
|
|
||||||
@@ -32,12 +33,30 @@ pub enum ErrorKind {
|
|||||||
},
|
},
|
||||||
#[error("error handling ws")]
|
#[error("error handling ws")]
|
||||||
Ws(#[source] actix_web::Error),
|
Ws(#[source] actix_web::Error),
|
||||||
|
#[error("IP pool exhausted: no available addresses in {pool}")]
|
||||||
|
IpPoolExhausted { pool: String },
|
||||||
|
#[error("error deregistering device {public_key}")]
|
||||||
|
DeregisterDevice {
|
||||||
|
public_key: String,
|
||||||
|
#[source]
|
||||||
|
source: redis::RedisError,
|
||||||
|
},
|
||||||
|
#[error("error invalid key")]
|
||||||
|
InvalidKey(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct ErrorResponse {
|
||||||
|
error: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResponseError for Error {
|
impl ResponseError for Error {
|
||||||
fn error_response(&self) -> actix_web::HttpResponse<actix_web::body::BoxBody> {
|
fn error_response(&self) -> actix_web::HttpResponse<actix_web::body::BoxBody> {
|
||||||
match self.inner() {
|
match self.inner() {
|
||||||
ErrorKind::Ws(e) => e.error_response(),
|
ErrorKind::Ws(e) => e.error_response(),
|
||||||
|
ErrorKind::InvalidKey(key) => HttpResponse::BadRequest().json(ErrorResponse {
|
||||||
|
error: format!("error invalid key: {key}"),
|
||||||
|
}),
|
||||||
_ => HttpResponse::InternalServerError().finish(),
|
_ => HttpResponse::InternalServerError().finish(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
mod types;
|
mod types;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
pub use types::peer_message::*;
|
pub use types::peer_message::{Peer, PeerMessage, RegisterResponse};
|
||||||
|
|||||||
@@ -47,6 +47,10 @@ async fn run() -> crate::error::Result<()> {
|
|||||||
"/register",
|
"/register",
|
||||||
web::post().to(endpoints::register::register_peer),
|
web::post().to(endpoints::register::register_peer),
|
||||||
)
|
)
|
||||||
|
.route(
|
||||||
|
"/deregister",
|
||||||
|
web::delete().to(endpoints::deregister::deregister),
|
||||||
|
)
|
||||||
.route("/peers", web::get().to(endpoints::peers::get_peers))
|
.route("/peers", web::get().to(endpoints::peers::get_peers))
|
||||||
.route("/ws/peers", web::get().to(endpoints::ws::peers::peers))
|
.route("/ws/peers", web::get().to(endpoints::ws::peers::peers))
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
use std::net::Ipv4Addr;
|
||||||
|
|
||||||
use crate::error::{Error, Result};
|
use crate::error::{Error, Result};
|
||||||
|
|
||||||
mod valkey;
|
mod valkey;
|
||||||
@@ -10,12 +12,13 @@ pub enum Storage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub trait StorageImpl {
|
pub trait StorageImpl {
|
||||||
async fn register_device(&self, request: &RegisterRequest) -> Result<()>;
|
async fn register_device(&self, request: &RegisterRequest) -> Result<Ipv4Addr>;
|
||||||
|
async fn deregister_device(&self, public_key: &str) -> Result<()>;
|
||||||
async fn get_peers(&self) -> Result<Vec<Peer>>;
|
async fn get_peers(&self) -> Result<Vec<Peer>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StorageImpl for Storage {
|
impl StorageImpl for Storage {
|
||||||
async fn register_device(&self, request: &RegisterRequest) -> Result<()> {
|
async fn register_device(&self, request: &RegisterRequest) -> Result<Ipv4Addr> {
|
||||||
match self {
|
match self {
|
||||||
Self::Valkey(storage) => storage.register_device(request).await,
|
Self::Valkey(storage) => storage.register_device(request).await,
|
||||||
}
|
}
|
||||||
@@ -26,6 +29,12 @@ impl StorageImpl for Storage {
|
|||||||
Self::Valkey(storage) => storage.get_peers().await,
|
Self::Valkey(storage) => storage.get_peers().await,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn deregister_device(&self, public_key: &str) -> Result<()> {
|
||||||
|
match self {
|
||||||
|
Self::Valkey(storage) => storage.deregister_device(public_key).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_storage_from_env() -> Result<Storage> {
|
pub fn get_storage_from_env() -> Result<Storage> {
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::net::IpAddr;
|
use std::net::{IpAddr, Ipv4Addr};
|
||||||
|
|
||||||
|
use futures::TryFutureExt;
|
||||||
use ipnetwork::IpNetwork;
|
use ipnetwork::IpNetwork;
|
||||||
use redis::AsyncTypedCommands;
|
use redis::AsyncTypedCommands;
|
||||||
use registry::Peer;
|
use registry::Peer;
|
||||||
@@ -10,6 +11,10 @@ use crate::error::Result;
|
|||||||
use crate::utils::WireguardPublicKey;
|
use crate::utils::WireguardPublicKey;
|
||||||
use crate::{error::Error, storage::StorageImpl};
|
use crate::{error::Error, storage::StorageImpl};
|
||||||
|
|
||||||
|
const MESH_NETWORK_BASE: [u8; 4] = [10, 100, 0, 0];
|
||||||
|
const MESH_POOL_START: u8 = 1;
|
||||||
|
const MESH_POOL_END: u8 = 254;
|
||||||
|
|
||||||
pub struct ValkeyStorage {
|
pub struct ValkeyStorage {
|
||||||
pub valkey_client: redis::Client,
|
pub valkey_client: redis::Client,
|
||||||
}
|
}
|
||||||
@@ -23,22 +28,38 @@ pub struct RegisterRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl StorageImpl for ValkeyStorage {
|
impl StorageImpl for ValkeyStorage {
|
||||||
async fn register_device(&self, request: &RegisterRequest) -> Result<()> {
|
async fn register_device(&self, request: &RegisterRequest) -> Result<Ipv4Addr> {
|
||||||
let mut conn = self
|
let mut conn = self
|
||||||
.valkey_client
|
.valkey_client
|
||||||
.get_multiplexed_async_connection()
|
.get_multiplexed_async_connection()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| Error::valkey_get_connection(e))?;
|
.map_err(|e| Error::valkey_get_connection(e))?;
|
||||||
|
|
||||||
|
let peer_key = format!("peer:{}", request.public_key.as_str());
|
||||||
|
|
||||||
|
let existing_mesh_ip: Option<String> = conn
|
||||||
|
.hget(&peer_key, "mesh_ip")
|
||||||
|
.await
|
||||||
|
.map_err(|e| Error::get_peer(e))?;
|
||||||
|
|
||||||
|
let mesh_ip = if let Some(ip_str) = existing_mesh_ip {
|
||||||
|
ip_str.parse::<Ipv4Addr>().unwrap()
|
||||||
|
} else {
|
||||||
|
let allocated_ip = self.allocate_mesh_ip(&mut conn).await?;
|
||||||
|
allocated_ip
|
||||||
|
};
|
||||||
|
|
||||||
conn.hset_multiple::<_, _, _>(
|
conn.hset_multiple::<_, _, _>(
|
||||||
format!("peer:{}", request.public_key.as_str()),
|
&peer_key,
|
||||||
&[
|
&[
|
||||||
("public_ip", &request.public_ip.to_string()),
|
("public_ip", request.public_ip.to_string()),
|
||||||
(
|
(
|
||||||
"allowed_ips",
|
"allowed_ips",
|
||||||
&serde_json::to_string(&request.allowed_ips)
|
serde_json::to_string(&request.allowed_ips)
|
||||||
.map_err(|e| Error::serialize_json(e, "serializing allowed_ips"))?,
|
.map_err(|e| Error::serialize_json(e, "serializing allowed_ips"))?,
|
||||||
),
|
),
|
||||||
("port", &request.port),
|
("port", request.port.clone()),
|
||||||
|
("mesh_ip", mesh_ip.to_string()),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
@@ -49,7 +70,8 @@ impl StorageImpl for ValkeyStorage {
|
|||||||
request.public_ip.to_string(),
|
request.public_ip.to_string(),
|
||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
conn.sadd("peers", &request.public_key.as_str())
|
|
||||||
|
conn.sadd("peers", request.public_key.as_str())
|
||||||
.await
|
.await
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
Error::add_peer(
|
Error::add_peer(
|
||||||
@@ -59,6 +81,25 @@ impl StorageImpl for ValkeyStorage {
|
|||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
Ok(mesh_ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn deregister_device(&self, public_key: &str) -> Result<()> {
|
||||||
|
let mut conn = self
|
||||||
|
.valkey_client
|
||||||
|
.get_multiplexed_async_connection()
|
||||||
|
.await
|
||||||
|
.map_err(|e| Error::valkey_get_connection(e))?;
|
||||||
|
let hash_key = format!("peer:{public_key}");
|
||||||
|
conn.srem("peers", public_key)
|
||||||
|
.map_err(|e| Error::deregister_device(e, public_key))
|
||||||
|
.await?;
|
||||||
|
let response = conn
|
||||||
|
.del(hash_key)
|
||||||
|
.await
|
||||||
|
.map_err(|e| Error::deregister_device(e, public_key))?;
|
||||||
|
tracing::debug!("deleted hash {keys} key(s) removed", keys = response);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,21 +130,72 @@ impl StorageImpl for ValkeyStorage {
|
|||||||
.map_err(|e| Error::get_peer(e))?
|
.map_err(|e| Error::get_peer(e))?
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.zip(keys.iter())
|
.zip(keys.iter())
|
||||||
.map(|(peer, key): (HashMap<String, String>, &String)| {
|
.filter_map(|(peer, key): (HashMap<String, String>, &String)| {
|
||||||
let allowed_ips: Vec<IpNetwork> = peer
|
let allowed_ips: Vec<IpNetwork> = peer
|
||||||
.get("allowed_ips")
|
.get("allowed_ips")
|
||||||
.map(|s| serde_json::from_str(s).unwrap_or_default())
|
.map(|s| serde_json::from_str(s).unwrap_or_default())
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
Peer {
|
let mesh_ip: Ipv4Addr = peer.get("mesh_ip")?.parse().ok()?;
|
||||||
|
Some(Peer {
|
||||||
public_key: key.clone(),
|
public_key: key.clone(),
|
||||||
public_ip: peer.get("public_ip").unwrap().to_string(),
|
public_ip: peer.get("public_ip")?.to_string(),
|
||||||
port: peer.get("port").unwrap().to_string(),
|
port: peer.get("port")?.to_string(),
|
||||||
|
mesh_ip,
|
||||||
allowed_ips,
|
allowed_ips,
|
||||||
}
|
})
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
Ok(peers)
|
Ok(peers)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ValkeyStorage {
|
||||||
|
async fn allocate_mesh_ip(
|
||||||
|
&self,
|
||||||
|
conn: &mut redis::aio::MultiplexedConnection,
|
||||||
|
) -> Result<Ipv4Addr> {
|
||||||
|
let keys: HashSet<String> = conn
|
||||||
|
.smembers("peers")
|
||||||
|
.await
|
||||||
|
.map_err(|e| Error::get_peer(e))?;
|
||||||
|
|
||||||
|
let mut assigned_ips: HashSet<Ipv4Addr> = HashSet::new();
|
||||||
|
|
||||||
|
if !keys.is_empty() {
|
||||||
|
let mut pipe = redis::pipe();
|
||||||
|
for key in keys.iter() {
|
||||||
|
pipe.hget(format!("peer:{key}"), "mesh_ip");
|
||||||
|
}
|
||||||
|
|
||||||
|
let ips: Vec<Option<String>> = pipe
|
||||||
|
.query_async(conn)
|
||||||
|
.await
|
||||||
|
.map_err(|e| Error::get_peer(e))?;
|
||||||
|
|
||||||
|
for ip_opt in ips {
|
||||||
|
if let Some(ip_str) = ip_opt {
|
||||||
|
if let Ok(ip) = ip_str.parse::<Ipv4Addr>() {
|
||||||
|
assigned_ips.insert(ip);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for last_octet in MESH_POOL_START..=MESH_POOL_END {
|
||||||
|
let candidate = Ipv4Addr::new(
|
||||||
|
MESH_NETWORK_BASE[0],
|
||||||
|
MESH_NETWORK_BASE[1],
|
||||||
|
MESH_NETWORK_BASE[2],
|
||||||
|
last_octet,
|
||||||
|
);
|
||||||
|
|
||||||
|
if !assigned_ips.contains(&candidate) {
|
||||||
|
return Ok(candidate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(Error::ip_pool_exhausted("10.100.0.0/24".to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
use std::net::Ipv4Addr;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use ipnetwork::IpNetwork;
|
use ipnetwork::IpNetwork;
|
||||||
@@ -7,9 +9,15 @@ pub struct Peer {
|
|||||||
pub public_key: String,
|
pub public_key: String,
|
||||||
pub public_ip: String,
|
pub public_ip: String,
|
||||||
pub port: String,
|
pub port: String,
|
||||||
|
pub mesh_ip: Ipv4Addr,
|
||||||
pub allowed_ips: Vec<IpNetwork>,
|
pub allowed_ips: Vec<IpNetwork>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||||
|
pub struct RegisterResponse {
|
||||||
|
pub mesh_ip: Ipv4Addr,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
#[serde(tag = "type")]
|
#[serde(tag = "type")]
|
||||||
pub enum PeerMessage {
|
pub enum PeerMessage {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
use serde::{Deserialize, de};
|
use serde::{de, Deserialize};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct WireguardPublicKey(String);
|
pub struct WireguardPublicKey(String);
|
||||||
|
|||||||
Reference in New Issue
Block a user