feat!: add stun discovery
This commit is contained in:
759
Cargo.lock
generated
759
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,3 @@
|
|||||||
[workspace]
|
[workspace]
|
||||||
members = ["registry", "client"]
|
members = ["registry", "client"]
|
||||||
resolver = "3"
|
resolver = "3"
|
||||||
|
|
||||||
[workspace.dependencies]
|
|
||||||
|
|||||||
@@ -19,3 +19,6 @@ registry = { path = "../registry" }
|
|||||||
dirs = "6.0.0"
|
dirs = "6.0.0"
|
||||||
toml = "1.0.1"
|
toml = "1.0.1"
|
||||||
url = "2.5.8"
|
url = "2.5.8"
|
||||||
|
reqwest = { version = "0.13.2", features = ["json"] }
|
||||||
|
stunclient = "0.4.2"
|
||||||
|
ipnetwork = { version = "0.21.1", features = ["serde"] }
|
||||||
|
|||||||
22
client/src/app_state.rs
Normal file
22
client/src/app_state.rs
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
use std::{ops::Deref, sync::Arc};
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AppState {
|
||||||
|
pub reqwest_client: reqwest::Client,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Data<T>(Arc<T>);
|
||||||
|
|
||||||
|
impl<T> Deref for Data<T> {
|
||||||
|
type Target = T;
|
||||||
|
fn deref(&self) -> &Self::Target {
|
||||||
|
&self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Data<T> {
|
||||||
|
pub fn new(inner: T) -> Self {
|
||||||
|
Self(Arc::new(inner))
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
use crate::error::{Error, Result};
|
use crate::error::{Error, Result};
|
||||||
|
use ipnetwork::IpNetwork;
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@@ -11,11 +12,14 @@ pub struct Config {
|
|||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub struct InterfaceConfig {
|
pub struct InterfaceConfig {
|
||||||
pub private_key: String,
|
pub private_key: String,
|
||||||
|
pub public_key: String,
|
||||||
pub listen_port: u16,
|
pub listen_port: u16,
|
||||||
pub address: String,
|
pub address: String,
|
||||||
|
pub allowed_ips: Vec<IpNetwork>,
|
||||||
}
|
}
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub struct ServerConfig {
|
pub struct ServerConfig {
|
||||||
|
pub ws_url: String,
|
||||||
pub url: String,
|
pub url: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
142
client/src/discovery.rs
Normal file
142
client/src/discovery.rs
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
use std::net::{IpAddr, SocketAddr, ToSocketAddrs, UdpSocket};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use stunclient::StunClient;
|
||||||
|
|
||||||
|
use crate::error::{Error, Result};
|
||||||
|
|
||||||
|
const STUN_SERVERS: &[&str] = &[
|
||||||
|
"stun.l.google.com:19302",
|
||||||
|
"stun.cloudflare.com:3478",
|
||||||
|
"stun.stunprotocol.org:3478",
|
||||||
|
];
|
||||||
|
|
||||||
|
const HTTP_IP_SERVICES: &[&str] = &["https://icanhazip.com", "https://ifconfig.me/ip"];
|
||||||
|
|
||||||
|
const STUN_TIMEOUT: Duration = Duration::from_secs(3);
|
||||||
|
const HTTP_TIMEOUT: Duration = Duration::from_secs(5);
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PublicEndpoint {
|
||||||
|
pub ip: IpAddr,
|
||||||
|
pub port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for PublicEndpoint {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "{}:{}", self.ip, self.port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn discover_public_endpoint(local_port: u16) -> Result<PublicEndpoint> {
|
||||||
|
match stun_discover(local_port) {
|
||||||
|
Ok(endpoint) => {
|
||||||
|
tracing::debug!(
|
||||||
|
ip = %endpoint.ip,
|
||||||
|
port = endpoint.port,
|
||||||
|
"STUN discovery successful"
|
||||||
|
);
|
||||||
|
return Ok(endpoint);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("STUN discovery failed, falling back to HTTP: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let ip = http_discover_ip().await?;
|
||||||
|
let endpoint = PublicEndpoint {
|
||||||
|
ip,
|
||||||
|
port: local_port,
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
ip = %endpoint.ip,
|
||||||
|
port = endpoint.port,
|
||||||
|
"HTTP fallback discovery successful (port may not reflect NAT mapping)"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stun_discover(local_port: u16) -> Result<PublicEndpoint> {
|
||||||
|
let local_addr: SocketAddr = format!("0.0.0.0:{local_port}")
|
||||||
|
.parse()
|
||||||
|
.expect("valid socket addr");
|
||||||
|
|
||||||
|
let socket = UdpSocket::bind(local_addr).map_err(|e| {
|
||||||
|
Error::stun_discovery(format!(
|
||||||
|
"failed to bind UDP socket on port {local_port}: {e}"
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
socket
|
||||||
|
.set_read_timeout(Some(STUN_TIMEOUT))
|
||||||
|
.map_err(|e| Error::stun_discovery(format!("failed to set socket timeout: {e}")))?;
|
||||||
|
|
||||||
|
for server in STUN_SERVERS {
|
||||||
|
tracing::debug!(server, "Attempting STUN discovery");
|
||||||
|
|
||||||
|
match stun_query(&socket, server) {
|
||||||
|
Ok(endpoint) => return Ok(endpoint),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(server, error = %e, "STUN server failed, trying next");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(Error::discovery_failed())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stun_query(socket: &UdpSocket, server: &str) -> Result<PublicEndpoint> {
|
||||||
|
let server_addr = server
|
||||||
|
.to_socket_addrs()
|
||||||
|
.map_err(|e| Error::stun_discovery(format!("{server}: {e}")))?
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| Error::stun_discovery(format!("{server}: no addresses found")))?;
|
||||||
|
|
||||||
|
let client = StunClient::new(server_addr);
|
||||||
|
|
||||||
|
let public_addr = client
|
||||||
|
.query_external_address(socket)
|
||||||
|
.map_err(|e| Error::stun_discovery(format!("{server}: {e}")))?;
|
||||||
|
|
||||||
|
Ok(PublicEndpoint {
|
||||||
|
ip: public_addr.ip(),
|
||||||
|
port: public_addr.port(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn http_discover_ip() -> Result<IpAddr> {
|
||||||
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(HTTP_TIMEOUT)
|
||||||
|
.build()
|
||||||
|
.map_err(Error::http_discovery)?;
|
||||||
|
|
||||||
|
for service in HTTP_IP_SERVICES {
|
||||||
|
tracing::debug!(service, "Attempting HTTP IP discovery");
|
||||||
|
|
||||||
|
match http_query_ip(&client, service).await {
|
||||||
|
Ok(ip) => return Ok(ip),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::debug!(service, error = %e, "HTTP service failed, trying next");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(Error::discovery_failed())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn http_query_ip(client: &reqwest::Client, url: &str) -> Result<IpAddr> {
|
||||||
|
let response = client
|
||||||
|
.get(url)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(Error::http_discovery)?;
|
||||||
|
|
||||||
|
let body = response.text().await.map_err(Error::http_discovery)?;
|
||||||
|
let ip_str = body.trim();
|
||||||
|
|
||||||
|
ip_str
|
||||||
|
.parse::<IpAddr>()
|
||||||
|
.map_err(|_| Error::parse_ip(ip_str.to_string()))
|
||||||
|
}
|
||||||
@@ -35,6 +35,16 @@ pub enum ErrorKind {
|
|||||||
#[source]
|
#[source]
|
||||||
source: std::io::Error,
|
source: std::io::Error,
|
||||||
},
|
},
|
||||||
|
#[error("STUN discovery failed: {message}")]
|
||||||
|
StunDiscovery { message: String },
|
||||||
|
#[error("HTTP IP discovery failed")]
|
||||||
|
HttpDiscovery(#[source] reqwest::Error),
|
||||||
|
#[error("failed to parse IP address from HTTP response: {body}")]
|
||||||
|
ParseIp { body: String },
|
||||||
|
#[error("all discovery methods failed")]
|
||||||
|
DiscoveryFailed,
|
||||||
|
#[error("error with request")]
|
||||||
|
Reqwest(#[from] reqwest::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Result<T> = core::result::Result<T, Error>;
|
pub type Result<T> = core::result::Result<T, Error>;
|
||||||
|
|||||||
@@ -1,10 +1,16 @@
|
|||||||
|
mod app_state;
|
||||||
mod config;
|
mod config;
|
||||||
|
mod discovery;
|
||||||
mod error;
|
mod error;
|
||||||
mod wireguard;
|
mod wireguard;
|
||||||
|
|
||||||
|
use std::{fs::OpenOptions, io::Write, net::IpAddr, os::unix::fs::OpenOptionsExt};
|
||||||
|
|
||||||
use console::style;
|
use console::style;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
|
use ipnetwork::IpNetwork;
|
||||||
use registry::{Peer, PeerMessage};
|
use registry::{Peer, PeerMessage};
|
||||||
|
use serde::Serialize;
|
||||||
use thiserror_ext::AsReport;
|
use thiserror_ext::AsReport;
|
||||||
use tokio_tungstenite::tungstenite::Message;
|
use tokio_tungstenite::tungstenite::Message;
|
||||||
use tracing::level_filters::LevelFilter;
|
use tracing::level_filters::LevelFilter;
|
||||||
@@ -14,12 +20,22 @@ use tracing_subscriber::{
|
|||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
app_state::{AppState, Data},
|
||||||
config::{Config, InterfaceConfig, wg_config_path},
|
config::{Config, InterfaceConfig, wg_config_path},
|
||||||
|
discovery::{PublicEndpoint, discover_public_endpoint},
|
||||||
error::{Error, Result},
|
error::{Error, Result},
|
||||||
};
|
};
|
||||||
|
|
||||||
fn parse_url(input: &str) -> Result<String> {
|
#[derive(Serialize)]
|
||||||
let url = Url::parse(&input)?.join("/ws/peers")?;
|
pub struct RegisterRequest {
|
||||||
|
pub public_ip: IpAddr,
|
||||||
|
pub public_key: String,
|
||||||
|
pub port: String,
|
||||||
|
pub allowed_ips: Vec<IpNetwork>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_ws_url(input: &Url) -> Result<String> {
|
||||||
|
let url = input.join("/ws/peers")?;
|
||||||
if url.scheme() != "ws" && url.scheme() != "wss" {
|
if url.scheme() != "ws" && url.scheme() != "wss" {
|
||||||
return Err(Error::url_scheme(url.to_string()));
|
return Err(Error::url_scheme(url.to_string()));
|
||||||
}
|
}
|
||||||
@@ -29,19 +45,60 @@ fn parse_url(input: &str) -> Result<String> {
|
|||||||
fn write_wg_config(interface: &InterfaceConfig, peers: &[Peer]) -> Result<()> {
|
fn write_wg_config(interface: &InterfaceConfig, peers: &[Peer]) -> Result<()> {
|
||||||
let path = wg_config_path();
|
let path = wg_config_path();
|
||||||
let config = wireguard::generate_config(interface, peers);
|
let config = wireguard::generate_config(interface, peers);
|
||||||
std::fs::write(&path, config).map_err(|e| Error::write_config(e, &path))?;
|
let mut file = OpenOptions::new()
|
||||||
|
.write(true)
|
||||||
|
.create(true)
|
||||||
|
.truncate(true)
|
||||||
|
.mode(0o600)
|
||||||
|
.open(&path)
|
||||||
|
.map_err(|e| Error::write_config(e, &path))?;
|
||||||
|
file.write_all(config.as_bytes())
|
||||||
|
.map_err(|e| Error::write_config(e, &path))?;
|
||||||
tracing::info!("wrote {} with {} peers", path.display(), peers.len());
|
tracing::info!("wrote {} with {} peers", path.display(), peers.len());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn register_self(
|
||||||
|
app_state: Data<AppState>,
|
||||||
|
endpoint: &PublicEndpoint,
|
||||||
|
config: &InterfaceConfig,
|
||||||
|
url: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
app_state
|
||||||
|
.reqwest_client
|
||||||
|
.post(url)
|
||||||
|
.json(&RegisterRequest {
|
||||||
|
public_key: config.public_key.clone(),
|
||||||
|
public_ip: endpoint.ip,
|
||||||
|
port: endpoint.port.to_string(),
|
||||||
|
allowed_ips: config.allowed_ips.clone(),
|
||||||
|
})
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.error_for_status()?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
async fn run() -> crate::error::Result<()> {
|
async fn run() -> crate::error::Result<()> {
|
||||||
let config = Config::load()?;
|
let config = Config::load()?;
|
||||||
let url = parse_url(&config.server.url)?;
|
let ws_url = &config.server.ws_url;
|
||||||
let (ws_stream, response) = tokio_tungstenite::connect_async(&url)
|
let (ws_stream, response) = tokio_tungstenite::connect_async(ws_url)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| Error::ws_connect(e, &url))?;
|
.map_err(|e| Error::ws_connect(e, ws_url))?;
|
||||||
let (_, mut read) = ws_stream.split();
|
let (_, mut read) = ws_stream.split();
|
||||||
|
let app_state = Data::new(AppState {
|
||||||
|
reqwest_client: reqwest::Client::new(),
|
||||||
|
});
|
||||||
tracing::info!("connected, response: {:?}", response.status());
|
tracing::info!("connected, response: {:?}", response.status());
|
||||||
|
let endpoint = discover_public_endpoint(config.interface.listen_port).await?;
|
||||||
|
register_self(
|
||||||
|
app_state.clone(),
|
||||||
|
&endpoint,
|
||||||
|
&config.interface,
|
||||||
|
&format!("{}/register", &config.server.url),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let mut peers: Vec<Peer> = Vec::new();
|
let mut peers: Vec<Peer> = Vec::new();
|
||||||
|
|
||||||
|
|||||||
6
payload.json
Normal file
6
payload.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"public_key": "pKkl30tba29FG86wuaC0KrpSHMr1tSOujikHFbx75BM=",
|
||||||
|
"public_ip": "75.157.238.86",
|
||||||
|
"port": "51820",
|
||||||
|
"allowed_ips": ["10.100.0.1/32"]
|
||||||
|
}
|
||||||
@@ -67,7 +67,7 @@ pub async fn peers(
|
|||||||
update = peer_rx.recv() => {
|
update = peer_rx.recv() => {
|
||||||
match update {
|
match update {
|
||||||
Ok(peer_update) => {
|
Ok(peer_update) => {
|
||||||
let json = serde_json::to_string(&peer_update.peer)
|
let json = serde_json::to_string(&PeerMessage::PeerUpdate { peer: peer_update.peer.clone() })
|
||||||
.unwrap_or_else(|_| "{}".to_string());
|
.unwrap_or_else(|_| "{}".to_string());
|
||||||
if session.text(json).await.is_err() {
|
if session.text(json).await.is_err() {
|
||||||
break;
|
break;
|
||||||
|
|||||||
Reference in New Issue
Block a user