feat!: add stun discovery
This commit is contained in:
@@ -19,3 +19,6 @@ registry = { path = "../registry" }
|
||||
dirs = "6.0.0"
|
||||
toml = "1.0.1"
|
||||
url = "2.5.8"
|
||||
reqwest = { version = "0.13.2", features = ["json"] }
|
||||
stunclient = "0.4.2"
|
||||
ipnetwork = { version = "0.21.1", features = ["serde"] }
|
||||
|
||||
22
client/src/app_state.rs
Normal file
22
client/src/app_state.rs
Normal file
@@ -0,0 +1,22 @@
|
||||
use std::{ops::Deref, sync::Arc};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub reqwest_client: reqwest::Client,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Data<T>(Arc<T>);
|
||||
|
||||
impl<T> Deref for Data<T> {
|
||||
type Target = T;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Data<T> {
|
||||
pub fn new(inner: T) -> Self {
|
||||
Self(Arc::new(inner))
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
use ipnetwork::IpNetwork;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
@@ -11,11 +12,14 @@ pub struct Config {
|
||||
#[derive(Deserialize)]
|
||||
pub struct InterfaceConfig {
|
||||
pub private_key: String,
|
||||
pub public_key: String,
|
||||
pub listen_port: u16,
|
||||
pub address: String,
|
||||
pub allowed_ips: Vec<IpNetwork>,
|
||||
}
|
||||
#[derive(Deserialize)]
|
||||
pub struct ServerConfig {
|
||||
pub ws_url: String,
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
|
||||
142
client/src/discovery.rs
Normal file
142
client/src/discovery.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
use std::net::{IpAddr, SocketAddr, ToSocketAddrs, UdpSocket};
|
||||
use std::time::Duration;
|
||||
|
||||
use stunclient::StunClient;
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
const STUN_SERVERS: &[&str] = &[
|
||||
"stun.l.google.com:19302",
|
||||
"stun.cloudflare.com:3478",
|
||||
"stun.stunprotocol.org:3478",
|
||||
];
|
||||
|
||||
const HTTP_IP_SERVICES: &[&str] = &["https://icanhazip.com", "https://ifconfig.me/ip"];
|
||||
|
||||
const STUN_TIMEOUT: Duration = Duration::from_secs(3);
|
||||
const HTTP_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PublicEndpoint {
|
||||
pub ip: IpAddr,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PublicEndpoint {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}:{}", self.ip, self.port)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn discover_public_endpoint(local_port: u16) -> Result<PublicEndpoint> {
|
||||
match stun_discover(local_port) {
|
||||
Ok(endpoint) => {
|
||||
tracing::debug!(
|
||||
ip = %endpoint.ip,
|
||||
port = endpoint.port,
|
||||
"STUN discovery successful"
|
||||
);
|
||||
return Ok(endpoint);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("STUN discovery failed, falling back to HTTP: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
let ip = http_discover_ip().await?;
|
||||
let endpoint = PublicEndpoint {
|
||||
ip,
|
||||
port: local_port,
|
||||
};
|
||||
|
||||
tracing::debug!(
|
||||
ip = %endpoint.ip,
|
||||
port = endpoint.port,
|
||||
"HTTP fallback discovery successful (port may not reflect NAT mapping)"
|
||||
);
|
||||
|
||||
Ok(endpoint)
|
||||
}
|
||||
|
||||
fn stun_discover(local_port: u16) -> Result<PublicEndpoint> {
|
||||
let local_addr: SocketAddr = format!("0.0.0.0:{local_port}")
|
||||
.parse()
|
||||
.expect("valid socket addr");
|
||||
|
||||
let socket = UdpSocket::bind(local_addr).map_err(|e| {
|
||||
Error::stun_discovery(format!(
|
||||
"failed to bind UDP socket on port {local_port}: {e}"
|
||||
))
|
||||
})?;
|
||||
|
||||
socket
|
||||
.set_read_timeout(Some(STUN_TIMEOUT))
|
||||
.map_err(|e| Error::stun_discovery(format!("failed to set socket timeout: {e}")))?;
|
||||
|
||||
for server in STUN_SERVERS {
|
||||
tracing::debug!(server, "Attempting STUN discovery");
|
||||
|
||||
match stun_query(&socket, server) {
|
||||
Ok(endpoint) => return Ok(endpoint),
|
||||
Err(e) => {
|
||||
tracing::debug!(server, error = %e, "STUN server failed, trying next");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(Error::discovery_failed())
|
||||
}
|
||||
|
||||
fn stun_query(socket: &UdpSocket, server: &str) -> Result<PublicEndpoint> {
|
||||
let server_addr = server
|
||||
.to_socket_addrs()
|
||||
.map_err(|e| Error::stun_discovery(format!("{server}: {e}")))?
|
||||
.next()
|
||||
.ok_or_else(|| Error::stun_discovery(format!("{server}: no addresses found")))?;
|
||||
|
||||
let client = StunClient::new(server_addr);
|
||||
|
||||
let public_addr = client
|
||||
.query_external_address(socket)
|
||||
.map_err(|e| Error::stun_discovery(format!("{server}: {e}")))?;
|
||||
|
||||
Ok(PublicEndpoint {
|
||||
ip: public_addr.ip(),
|
||||
port: public_addr.port(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn http_discover_ip() -> Result<IpAddr> {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(HTTP_TIMEOUT)
|
||||
.build()
|
||||
.map_err(Error::http_discovery)?;
|
||||
|
||||
for service in HTTP_IP_SERVICES {
|
||||
tracing::debug!(service, "Attempting HTTP IP discovery");
|
||||
|
||||
match http_query_ip(&client, service).await {
|
||||
Ok(ip) => return Ok(ip),
|
||||
Err(e) => {
|
||||
tracing::debug!(service, error = %e, "HTTP service failed, trying next");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(Error::discovery_failed())
|
||||
}
|
||||
|
||||
async fn http_query_ip(client: &reqwest::Client, url: &str) -> Result<IpAddr> {
|
||||
let response = client
|
||||
.get(url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(Error::http_discovery)?;
|
||||
|
||||
let body = response.text().await.map_err(Error::http_discovery)?;
|
||||
let ip_str = body.trim();
|
||||
|
||||
ip_str
|
||||
.parse::<IpAddr>()
|
||||
.map_err(|_| Error::parse_ip(ip_str.to_string()))
|
||||
}
|
||||
@@ -35,6 +35,16 @@ pub enum ErrorKind {
|
||||
#[source]
|
||||
source: std::io::Error,
|
||||
},
|
||||
#[error("STUN discovery failed: {message}")]
|
||||
StunDiscovery { message: String },
|
||||
#[error("HTTP IP discovery failed")]
|
||||
HttpDiscovery(#[source] reqwest::Error),
|
||||
#[error("failed to parse IP address from HTTP response: {body}")]
|
||||
ParseIp { body: String },
|
||||
#[error("all discovery methods failed")]
|
||||
DiscoveryFailed,
|
||||
#[error("error with request")]
|
||||
Reqwest(#[from] reqwest::Error),
|
||||
}
|
||||
|
||||
pub type Result<T> = core::result::Result<T, Error>;
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
mod app_state;
|
||||
mod config;
|
||||
mod discovery;
|
||||
mod error;
|
||||
mod wireguard;
|
||||
|
||||
use std::{fs::OpenOptions, io::Write, net::IpAddr, os::unix::fs::OpenOptionsExt};
|
||||
|
||||
use console::style;
|
||||
use futures::StreamExt;
|
||||
use ipnetwork::IpNetwork;
|
||||
use registry::{Peer, PeerMessage};
|
||||
use serde::Serialize;
|
||||
use thiserror_ext::AsReport;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::level_filters::LevelFilter;
|
||||
@@ -14,12 +20,22 @@ use tracing_subscriber::{
|
||||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
app_state::{AppState, Data},
|
||||
config::{Config, InterfaceConfig, wg_config_path},
|
||||
discovery::{PublicEndpoint, discover_public_endpoint},
|
||||
error::{Error, Result},
|
||||
};
|
||||
|
||||
fn parse_url(input: &str) -> Result<String> {
|
||||
let url = Url::parse(&input)?.join("/ws/peers")?;
|
||||
#[derive(Serialize)]
|
||||
pub struct RegisterRequest {
|
||||
pub public_ip: IpAddr,
|
||||
pub public_key: String,
|
||||
pub port: String,
|
||||
pub allowed_ips: Vec<IpNetwork>,
|
||||
}
|
||||
|
||||
fn parse_ws_url(input: &Url) -> Result<String> {
|
||||
let url = input.join("/ws/peers")?;
|
||||
if url.scheme() != "ws" && url.scheme() != "wss" {
|
||||
return Err(Error::url_scheme(url.to_string()));
|
||||
}
|
||||
@@ -29,19 +45,60 @@ fn parse_url(input: &str) -> Result<String> {
|
||||
fn write_wg_config(interface: &InterfaceConfig, peers: &[Peer]) -> Result<()> {
|
||||
let path = wg_config_path();
|
||||
let config = wireguard::generate_config(interface, peers);
|
||||
std::fs::write(&path, config).map_err(|e| Error::write_config(e, &path))?;
|
||||
let mut file = OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.mode(0o600)
|
||||
.open(&path)
|
||||
.map_err(|e| Error::write_config(e, &path))?;
|
||||
file.write_all(config.as_bytes())
|
||||
.map_err(|e| Error::write_config(e, &path))?;
|
||||
tracing::info!("wrote {} with {} peers", path.display(), peers.len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn register_self(
|
||||
app_state: Data<AppState>,
|
||||
endpoint: &PublicEndpoint,
|
||||
config: &InterfaceConfig,
|
||||
url: &str,
|
||||
) -> Result<()> {
|
||||
app_state
|
||||
.reqwest_client
|
||||
.post(url)
|
||||
.json(&RegisterRequest {
|
||||
public_key: config.public_key.clone(),
|
||||
public_ip: endpoint.ip,
|
||||
port: endpoint.port.to_string(),
|
||||
allowed_ips: config.allowed_ips.clone(),
|
||||
})
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run() -> crate::error::Result<()> {
|
||||
let config = Config::load()?;
|
||||
let url = parse_url(&config.server.url)?;
|
||||
let (ws_stream, response) = tokio_tungstenite::connect_async(&url)
|
||||
let ws_url = &config.server.ws_url;
|
||||
let (ws_stream, response) = tokio_tungstenite::connect_async(ws_url)
|
||||
.await
|
||||
.map_err(|e| Error::ws_connect(e, &url))?;
|
||||
.map_err(|e| Error::ws_connect(e, ws_url))?;
|
||||
let (_, mut read) = ws_stream.split();
|
||||
let app_state = Data::new(AppState {
|
||||
reqwest_client: reqwest::Client::new(),
|
||||
});
|
||||
tracing::info!("connected, response: {:?}", response.status());
|
||||
let endpoint = discover_public_endpoint(config.interface.listen_port).await?;
|
||||
register_self(
|
||||
app_state.clone(),
|
||||
&endpoint,
|
||||
&config.interface,
|
||||
&format!("{}/register", &config.server.url),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let mut peers: Vec<Peer> = Vec::new();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user