feat: move start_watch into trait definition, to do that use async_traits
Some checks failed
/ build (push) Has been cancelled

This commit is contained in:
Tamipes 2026-06-04 19:31:20 +02:00
parent 8a7c3f5203
commit 750e8dcbf0
5 changed files with 111 additions and 101 deletions

12
Cargo.lock generated
View file

@ -76,6 +76,17 @@ dependencies = [
"syn",
]
[[package]]
name = "async-trait"
version = "0.1.89"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "atomic-waker"
version = "1.1.2"
@ -954,6 +965,7 @@ name = "mc-ingress"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"clap",
"either",
"evalexpr",

View file

@ -41,3 +41,4 @@ nix = { version= "0.30.1", features = [ "zerocopy"] }
tokio-splice2 = "0.3.2"
strip-ansi-escapes = "0.2.1"
evalexpr = { version = "13.1.0", features = ["regex"] }
async-trait = "0.1.89"

View file

@ -166,7 +166,6 @@ impl MinecraftAPI<Server> for McApi {
a
}
};
tracing::info!(inter_addr = inter_addr);
return Ok(Server {
dep: deployment,
srv: service,
@ -177,77 +176,8 @@ impl MinecraftAPI<Server> for McApi {
});
}
async fn start_watch(
self,
server: impl MinecraftServerHandle,
frequency: Duration,
) -> Result<(), OpaqueError> {
let addr = server.get_addr().ok_or("could not get addr of server")?;
let port = server.get_port().ok_or("could not get port of server")?;
let full_addr = format!("{addr}:{port}");
if let Some(handle) = self.map.lock().await.get(&full_addr) {
if !handle.is_finished() {
return Ok(());
}
}
let span = tracing::span!(parent: None,tracing::Level::INFO, "server_watcher", addr, port);
let full_addr_clone = full_addr.clone();
let api = self.clone();
let handle = tokio::spawn(
async move {
tracing::info!("starting watcher");
loop {
tokio::time::sleep(frequency).await;
let server = match api.query_server(&addr, &port).await {
Ok(x) => x,
Err(e) => {
tracing::error!(
err = format!("{}", e.context),
"could not query server"
);
return;
}
};
let status_json = match server.query_description().await {
Ok(x) => x,
Err(e) => {
tracing::error!(
err = format!("{}", e.context),
"could not query description"
);
return;
}
};
if status_json.get_players_online() == 0 {
// With this I don't need to specify that StatusTrait
// should be send as well.
// Otherwise I would need to have it be defined as:
// trait StatusTrait: Send { ... }
drop(status_json);
let mut guard = api.map.lock().await;
guard.remove(&full_addr_clone);
drop(guard);
if let Err(err) = server.stop().await {
tracing::error!(
trace = %err.print_span_trace(),
err = err.context,
msg = "failed to stop server"
);
}
return;
}
}
}
.instrument(span),
);
let mut guard = self.map.lock().await;
guard.insert(full_addr.clone(), handle);
drop(guard);
Ok(())
fn get_map(&self) -> Arc<tokio::sync::Mutex<HashMap<String, tokio::task::JoinHandle<()>>>> {
self.map.clone()
}
}
@ -290,10 +220,10 @@ impl fmt::Debug for Server {
.name
.unwrap_or("#error#".to_string()),
)
.field("server_addr", &self.server_addr)
.finish()
}
}
#[async_trait::async_trait]
impl MinecraftServerHandle for Server {
async fn start(&self) -> Result<(), OpaqueError> {
self.set_scale(1).await.map_err(|e| {
@ -357,8 +287,8 @@ impl MinecraftServerHandle for Server {
self.inter_addr.as_str()
}
fn get_addr(&self) -> Option<String> {
Some(self.server_addr.clone())
fn get_addr(&self) -> String {
self.server_addr.clone()
}
async fn query_description(&self) -> Result<Box<dyn StatusTrait>, OpaqueError> {
@ -367,9 +297,7 @@ impl MinecraftServerHandle for Server {
ServerDeploymentStatus::Connectable(mut tcp_stream) => {
let handshake = crate::packets::serverbound::handshake::Handshake::create(
crate::types::VarInt::from(746).ok_or("could not create VarInt WTF?")?,
crate::types::VarString::from(
self.get_addr().ok_or("failed to get addr of server")?,
),
crate::types::VarString::from(self.get_addr()),
crate::types::UShort::from(1234),
crate::types::VarInt::from(1).ok_or("could not create VarInt WTF?")?,
)
@ -403,8 +331,8 @@ impl MinecraftServerHandle for Server {
}
}
fn get_port(&self) -> Option<String> {
Some(self.server_port.clone())
fn get_port(&self) -> String {
self.server_port.clone()
}
fn get_motd(&self) -> Option<String> {
@ -434,15 +362,13 @@ impl MinecraftServerHandle for Server {
impl Server {
async fn set_scale(&self, num: i32) -> Result<(), kube::Error> {
let name = self
.srv
.dep
.metadata
.clone()
.name
.unwrap_or("#error#".to_string());
let res = self.cache.set_dep_scale(&name, num).await;
if res.is_ok() {
tracing::info!("scaled replicas of {} to {num}", self.server_addr);
}
let _res = self.cache.set_dep_scale(&name, num).await?;
tracing::info!("scaled replicas of {} to {num}", self.server_addr);
Ok(())
}
}

View file

@ -100,9 +100,12 @@ async fn main() {
async fn process_connection<T: MinecraftServerHandle>(
mut client_stream: TcpStream,
addr: SocketAddr,
api: impl MinecraftAPI<T>,
api: impl MinecraftAPI<T> + Send + Sync + 'static + Clone,
config: Config,
) -> Result<(), OpaqueError> {
) -> Result<(), OpaqueError>
where
T: Send + Sync + 'static,
{
// this is wrapper so that async doesnt mess up the span, and
// to make sure this doesn't propagate to later `handle_*`
#[tracing::instrument(level = "info", skip(client_stream, config))]
@ -185,7 +188,10 @@ async fn handle_status<T: MinecraftServerHandle>(
client_stream: &mut TcpStream,
handshake: &Handshake,
api: impl MinecraftAPI<T>,
) -> Result<(), OpaqueError> {
) -> Result<(), OpaqueError>
where
T: Send + Sync + 'static,
{
let client_packet = Packet::parse(client_stream).await?;
if client_packet.id.get_int() != 0 {
return Err(OpaqueError::create(&format!(
@ -276,8 +282,11 @@ async fn handle_login<T: MinecraftServerHandle>(
client_stream: &mut TcpStream,
handshake: &Handshake,
login_start: LoginStart,
api: impl MinecraftAPI<T>,
) -> Result<(), OpaqueError> {
api: impl MinecraftAPI<T> + Send + Sync + 'static + Clone,
) -> Result<(), OpaqueError>
where
T: Send + Sync + 'static,
{
let server = api
.query_server(
&handshake.get_server_address(),
@ -290,7 +299,7 @@ async fn handle_login<T: MinecraftServerHandle>(
tracing::debug!(msg = "server status", status = ?status);
match status {
ServerDeploymentStatus::Connectable(mut server_stream) => {
api.start_watch(server.clone(), Duration::from_secs(600))
api.start_watch(server.clone(), Duration::from_secs(60))
.await?;
// referenced from:
@ -341,7 +350,7 @@ async fn handle_login<T: MinecraftServerHandle>(
}
ServerDeploymentStatus::Offline => {
server.start().await?;
api.start_watch(server.clone(), Duration::from_secs(600))
api.start_watch(server.clone(), Duration::from_secs(60))
.await?;
mc_server::send_disconnect(client_stream, format!("[\"\",{{\"text\":\"Okayy, §2starting§r the server!\n\n\"}},{{\"text\":\"{BYE_MESSAGE}\"}}]").as_str()).await?;
}

View file

@ -1,4 +1,7 @@
use std::{collections::HashMap, sync::Arc};
use tokio::{io::AsyncWriteExt, net::TcpStream};
use tracing::Instrument;
use crate::{
packets::{
@ -60,13 +63,14 @@ pub async fn send_disconnect(
Ok(())
}
#[async_trait::async_trait]
pub trait MinecraftServerHandle: Clone {
async fn start(&self) -> Result<(), OpaqueError>;
async fn stop(&self) -> Result<(), OpaqueError>;
async fn query_status(&self) -> Result<ServerDeploymentStatus, OpaqueError>;
fn get_internal_address(&self) -> &str;
fn get_addr(&self) -> Option<String>;
fn get_port(&self) -> Option<String>;
fn get_addr(&self) -> String;
fn get_port(&self) -> String;
fn get_motd(&self) -> Option<String>;
async fn query_server_connectable(&self) -> Result<TcpStream, OpaqueError> {
@ -111,24 +115,82 @@ pub trait MinecraftServerHandle: Clone {
tracing::trace!("data exchanged while proxying status: {:?}", data_amount);
Ok(())
}
// TODO: move the implementation to here, but
// the async things are *strange* in rust
fn query_description(
&self,
) -> impl std::future::Future<Output = Result<Box<dyn StatusTrait>, OpaqueError>> + Send;
async fn query_description(&self) -> Result<Box<dyn StatusTrait>, OpaqueError>;
}
pub trait MinecraftAPI<T> {
async fn query_server(&self, addr: &str, port: &str) -> Result<T, OpaqueError>;
fn get_map(&self) -> Arc<tokio::sync::Mutex<HashMap<String, tokio::task::JoinHandle<()>>>>;
// TODO: move the implementation to here, but
/// This should be callable even if there is already a watcher,
/// and it should handle the collision itself while returning OK().
async fn start_watch(
self,
server: impl MinecraftServerHandle,
server: impl MinecraftServerHandle + Send + Sync + 'static,
frequency: std::time::Duration,
) -> Result<(), OpaqueError>;
) -> Result<(), OpaqueError>
where
Self: Send + Sync + 'static + Clone,
{
let inter_addr = server.get_internal_address().to_string();
if let Some(handle) = self.get_map().lock().await.get(&inter_addr) {
if !handle.is_finished() {
return Ok(());
}
}
let span = tracing::span!(parent: None,tracing::Level::INFO, "server_watcher", inter_addr, join_addr = server.get_addr(), join_port = server.get_port());
let full_addr_clone = inter_addr.clone();
let api = self.clone();
let handle = tokio::spawn(
async move {
tracing::info!("starting watcher");
loop {
tokio::time::sleep(frequency).await;
let status_json = match server.query_description().await {
Ok(x) => x,
Err(e) => {
tracing::error!(
err = format!("{}", e.context),
"could not query description"
);
return;
}
};
if status_json.get_players_online() == 0 {
// With this I don't need to specify that StatusTrait
// should be send as well.
// Otherwise I would need to have it be defined as:
// trait StatusTrait: Send { ... }
drop(status_json);
let map = api.get_map();
let mut guard = map.lock().await;
guard.remove(&full_addr_clone);
drop(guard);
drop(map);
if let Err(err) = server.stop().await {
tracing::error!(
trace = %err.print_span_trace(),
err = err.context,
msg = "failed to stop server"
);
}
return;
}
}
}
.instrument(span),
);
let map = self.get_map();
let mut guard = map.lock().await;
guard.insert(inter_addr.clone(), handle);
drop(guard);
Ok(())
}
}
pub enum ServerDeploymentStatus {