feat!: add stun discovery

This commit is contained in:
2026-02-14 16:17:22 -08:00
parent 3cac99c24c
commit a022c18ff9
10 changed files with 985 additions and 34 deletions

759
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,3 @@
[workspace]
members = ["registry", "client"]
resolver = "3"
[workspace.dependencies]

View File

@@ -19,3 +19,6 @@ registry = { path = "../registry" }
dirs = "6.0.0"
toml = "1.0.1"
url = "2.5.8"
reqwest = { version = "0.13.2", features = ["json"] }
stunclient = "0.4.2"
ipnetwork = { version = "0.21.1", features = ["serde"] }

22
client/src/app_state.rs Normal file
View File

@@ -0,0 +1,22 @@
use std::{ops::Deref, sync::Arc};
#[derive(Clone)]
pub struct AppState {
pub reqwest_client: reqwest::Client,
}
#[derive(Clone)]
pub struct Data<T>(Arc<T>);
impl<T> Deref for Data<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> Data<T> {
pub fn new(inner: T) -> Self {
Self(Arc::new(inner))
}
}

View File

@@ -1,6 +1,7 @@
use std::path::PathBuf;
use crate::error::{Error, Result};
use ipnetwork::IpNetwork;
use serde::Deserialize;
#[derive(Deserialize)]
@@ -11,11 +12,14 @@ pub struct Config {
#[derive(Deserialize)]
pub struct InterfaceConfig {
pub private_key: String,
pub public_key: String,
pub listen_port: u16,
pub address: String,
pub allowed_ips: Vec<IpNetwork>,
}
#[derive(Deserialize)]
pub struct ServerConfig {
pub ws_url: String,
pub url: String,
}

142
client/src/discovery.rs Normal file
View File

@@ -0,0 +1,142 @@
use std::net::{IpAddr, SocketAddr, ToSocketAddrs, UdpSocket};
use std::time::Duration;
use stunclient::StunClient;
use crate::error::{Error, Result};
const STUN_SERVERS: &[&str] = &[
"stun.l.google.com:19302",
"stun.cloudflare.com:3478",
"stun.stunprotocol.org:3478",
];
const HTTP_IP_SERVICES: &[&str] = &["https://icanhazip.com", "https://ifconfig.me/ip"];
const STUN_TIMEOUT: Duration = Duration::from_secs(3);
const HTTP_TIMEOUT: Duration = Duration::from_secs(5);
#[derive(Debug, Clone)]
pub struct PublicEndpoint {
pub ip: IpAddr,
pub port: u16,
}
impl std::fmt::Display for PublicEndpoint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}:{}", self.ip, self.port)
}
}
pub async fn discover_public_endpoint(local_port: u16) -> Result<PublicEndpoint> {
match stun_discover(local_port) {
Ok(endpoint) => {
tracing::debug!(
ip = %endpoint.ip,
port = endpoint.port,
"STUN discovery successful"
);
return Ok(endpoint);
}
Err(e) => {
tracing::warn!("STUN discovery failed, falling back to HTTP: {e}");
}
}
let ip = http_discover_ip().await?;
let endpoint = PublicEndpoint {
ip,
port: local_port,
};
tracing::debug!(
ip = %endpoint.ip,
port = endpoint.port,
"HTTP fallback discovery successful (port may not reflect NAT mapping)"
);
Ok(endpoint)
}
fn stun_discover(local_port: u16) -> Result<PublicEndpoint> {
let local_addr: SocketAddr = format!("0.0.0.0:{local_port}")
.parse()
.expect("valid socket addr");
let socket = UdpSocket::bind(local_addr).map_err(|e| {
Error::stun_discovery(format!(
"failed to bind UDP socket on port {local_port}: {e}"
))
})?;
socket
.set_read_timeout(Some(STUN_TIMEOUT))
.map_err(|e| Error::stun_discovery(format!("failed to set socket timeout: {e}")))?;
for server in STUN_SERVERS {
tracing::debug!(server, "Attempting STUN discovery");
match stun_query(&socket, server) {
Ok(endpoint) => return Ok(endpoint),
Err(e) => {
tracing::debug!(server, error = %e, "STUN server failed, trying next");
}
}
}
Err(Error::discovery_failed())
}
fn stun_query(socket: &UdpSocket, server: &str) -> Result<PublicEndpoint> {
let server_addr = server
.to_socket_addrs()
.map_err(|e| Error::stun_discovery(format!("{server}: {e}")))?
.next()
.ok_or_else(|| Error::stun_discovery(format!("{server}: no addresses found")))?;
let client = StunClient::new(server_addr);
let public_addr = client
.query_external_address(socket)
.map_err(|e| Error::stun_discovery(format!("{server}: {e}")))?;
Ok(PublicEndpoint {
ip: public_addr.ip(),
port: public_addr.port(),
})
}
async fn http_discover_ip() -> Result<IpAddr> {
let client = reqwest::Client::builder()
.timeout(HTTP_TIMEOUT)
.build()
.map_err(Error::http_discovery)?;
for service in HTTP_IP_SERVICES {
tracing::debug!(service, "Attempting HTTP IP discovery");
match http_query_ip(&client, service).await {
Ok(ip) => return Ok(ip),
Err(e) => {
tracing::debug!(service, error = %e, "HTTP service failed, trying next");
}
}
}
Err(Error::discovery_failed())
}
async fn http_query_ip(client: &reqwest::Client, url: &str) -> Result<IpAddr> {
let response = client
.get(url)
.send()
.await
.map_err(Error::http_discovery)?;
let body = response.text().await.map_err(Error::http_discovery)?;
let ip_str = body.trim();
ip_str
.parse::<IpAddr>()
.map_err(|_| Error::parse_ip(ip_str.to_string()))
}

View File

@@ -35,6 +35,16 @@ pub enum ErrorKind {
#[source]
source: std::io::Error,
},
#[error("STUN discovery failed: {message}")]
StunDiscovery { message: String },
#[error("HTTP IP discovery failed")]
HttpDiscovery(#[source] reqwest::Error),
#[error("failed to parse IP address from HTTP response: {body}")]
ParseIp { body: String },
#[error("all discovery methods failed")]
DiscoveryFailed,
#[error("error with request")]
Reqwest(#[from] reqwest::Error),
}
pub type Result<T> = core::result::Result<T, Error>;

View File

@@ -1,10 +1,16 @@
mod app_state;
mod config;
mod discovery;
mod error;
mod wireguard;
use std::{fs::OpenOptions, io::Write, net::IpAddr, os::unix::fs::OpenOptionsExt};
use console::style;
use futures::StreamExt;
use ipnetwork::IpNetwork;
use registry::{Peer, PeerMessage};
use serde::Serialize;
use thiserror_ext::AsReport;
use tokio_tungstenite::tungstenite::Message;
use tracing::level_filters::LevelFilter;
@@ -14,12 +20,22 @@ use tracing_subscriber::{
use url::Url;
use crate::{
app_state::{AppState, Data},
config::{Config, InterfaceConfig, wg_config_path},
discovery::{PublicEndpoint, discover_public_endpoint},
error::{Error, Result},
};
fn parse_url(input: &str) -> Result<String> {
let url = Url::parse(&input)?.join("/ws/peers")?;
#[derive(Serialize)]
pub struct RegisterRequest {
pub public_ip: IpAddr,
pub public_key: String,
pub port: String,
pub allowed_ips: Vec<IpNetwork>,
}
fn parse_ws_url(input: &Url) -> Result<String> {
let url = input.join("/ws/peers")?;
if url.scheme() != "ws" && url.scheme() != "wss" {
return Err(Error::url_scheme(url.to_string()));
}
@@ -29,19 +45,60 @@ fn parse_url(input: &str) -> Result<String> {
fn write_wg_config(interface: &InterfaceConfig, peers: &[Peer]) -> Result<()> {
let path = wg_config_path();
let config = wireguard::generate_config(interface, peers);
std::fs::write(&path, config).map_err(|e| Error::write_config(e, &path))?;
let mut file = OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.mode(0o600)
.open(&path)
.map_err(|e| Error::write_config(e, &path))?;
file.write_all(config.as_bytes())
.map_err(|e| Error::write_config(e, &path))?;
tracing::info!("wrote {} with {} peers", path.display(), peers.len());
Ok(())
}
async fn register_self(
app_state: Data<AppState>,
endpoint: &PublicEndpoint,
config: &InterfaceConfig,
url: &str,
) -> Result<()> {
app_state
.reqwest_client
.post(url)
.json(&RegisterRequest {
public_key: config.public_key.clone(),
public_ip: endpoint.ip,
port: endpoint.port.to_string(),
allowed_ips: config.allowed_ips.clone(),
})
.send()
.await?
.error_for_status()?;
Ok(())
}
async fn run() -> crate::error::Result<()> {
let config = Config::load()?;
let url = parse_url(&config.server.url)?;
let (ws_stream, response) = tokio_tungstenite::connect_async(&url)
let ws_url = &config.server.ws_url;
let (ws_stream, response) = tokio_tungstenite::connect_async(ws_url)
.await
.map_err(|e| Error::ws_connect(e, &url))?;
.map_err(|e| Error::ws_connect(e, ws_url))?;
let (_, mut read) = ws_stream.split();
let app_state = Data::new(AppState {
reqwest_client: reqwest::Client::new(),
});
tracing::info!("connected, response: {:?}", response.status());
let endpoint = discover_public_endpoint(config.interface.listen_port).await?;
register_self(
app_state.clone(),
&endpoint,
&config.interface,
&format!("{}/register", &config.server.url),
)
.await?;
let mut peers: Vec<Peer> = Vec::new();

6
payload.json Normal file
View File

@@ -0,0 +1,6 @@
{
"public_key": "pKkl30tba29FG86wuaC0KrpSHMr1tSOujikHFbx75BM=",
"public_ip": "75.157.238.86",
"port": "51820",
"allowed_ips": ["10.100.0.1/32"]
}

View File

@@ -67,7 +67,7 @@ pub async fn peers(
update = peer_rx.recv() => {
match update {
Ok(peer_update) => {
let json = serde_json::to_string(&peer_update.peer)
let json = serde_json::to_string(&PeerMessage::PeerUpdate { peer: peer_update.peer.clone() })
.unwrap_or_else(|_| "{}".to_string());
if session.text(json).await.is_err() {
break;