Files
Gym_CPU/scripts/commons/Server.py

164 lines
6.1 KiB
Python

import subprocess
import os
import time
import threading
class Server():
    """Launch and supervise a pool of local ``rcssservermj`` processes.

    Instance ``i`` (0 <= i < n_servers) is started with server port
    ``first_server_p + i`` and monitor port
    ``first_monitor_p + MONITOR_PORT_OFFSET + i``. If ``WATCHDOG_ENABLED`` is
    true, a daemon thread wakes every ``WATCHDOG_INTERVAL_SEC`` seconds and
    restarts any instance that has exited or whose VmRSS exceeds
    ``WATCHDOG_RSS_MB_LIMIT`` MiB.
    """

    WATCHDOG_ENABLED = True
    WATCHDOG_INTERVAL_SEC = 30.0
    WATCHDOG_RSS_MB_LIMIT = 800  # MiB, compared against VmRSS from /proc/<pid>/status
    # Added to the caller-supplied first monitor port before any port is used.
    MONITOR_PORT_OFFSET = 100

    def __init__(self, first_server_p, first_monitor_p, n_servers, no_render=True, no_realtime=True) -> None:
        """Start ``n_servers`` servers and, if enabled, the watchdog thread.

        Args:
            first_server_p: server port of instance 0; instance ``i`` uses ``first_server_p + i``.
            first_monitor_p: base monitor port; ``MONITOR_PORT_OFFSET`` is added,
                so instance ``i`` uses ``first_monitor_p + MONITOR_PORT_OFFSET + i``.
            n_servers: number of server processes to launch.
            no_render: pass ``--no-render`` to every server.
            no_realtime: pass ``--no-realtime`` to every server.

        Raises:
            RuntimeError: if a server exits right after launch. Every server
                already started by this call is killed before re-raising.
        """
        # makes it easier to kill test servers without affecting train servers
        # (the offset is applied BEFORE the port-conflict check so the check
        # inspects the monitor ports that will actually be used)
        first_monitor_p = first_monitor_p + self.MONITOR_PORT_OFFSET

        try:
            import psutil
        except ModuleNotFoundError:
            print("Info: Cannot check if the server is already running, because the psutil module was not found")
        else:
            self.check_running_servers(psutil, first_server_p, first_monitor_p, n_servers)

        self.first_server_p = first_server_p
        self.n_servers = n_servers
        self.rcss_processes = []
        # (port, mport, cmd, render_arg, realtime_arg) per instance, index-aligned
        # with rcss_processes so the watchdog can respawn with identical args.
        self._server_specs = []
        self._watchdog_stop = threading.Event()
        self._watchdog_lock = threading.Lock()
        self._watchdog_thread = None

        cmd = "rcssservermj"
        render_arg = "--no-render" if no_render else ""
        realtime_arg = "--no-realtime" if no_realtime else ""
        try:
            for i in range(n_servers):
                spec = (first_server_p + i, first_monitor_p + i, cmd, render_arg, realtime_arg)
                proc = self._spawn_server(*spec)
                self._server_specs.append(spec)
                self.rcss_processes.append(proc)
        except BaseException:
            # Servers are started in their own session, so they would outlive
            # this process; don't leak the ones that already came up.
            self._kill_processes(self.rcss_processes)
            raise

        if self.WATCHDOG_ENABLED:
            self._watchdog_thread = threading.Thread(target=self._watchdog_loop, daemon=True)
            self._watchdog_thread.start()

    def _spawn_server(self, port, mport, cmd, render_arg, realtime_arg):
        """Start one server process and return its ``subprocess.Popen`` handle.

        Empty ``render_arg``/``realtime_arg`` strings vanish in ``split()``.
        Output (stdout and stderr) is discarded.

        Raises:
            RuntimeError: if the process has already exited ~30 ms after launch.
        """
        server_cmd = f"{cmd} -c {port} -m {mport} {render_arg} {realtime_arg}".strip()
        proc = subprocess.Popen(
            server_cmd.split(),
            stdout=subprocess.DEVNULL,
            stderr=subprocess.STDOUT,
            start_new_session=True
        )
        # Avoid startup storm when launching many servers at once.
        time.sleep(0.03)
        rc = proc.poll()
        if rc is not None:
            raise RuntimeError(
                f"rcssservermj exited early (code={rc}) on server port {port}, monitor port {mport}"
            )
        return proc

    @staticmethod
    def _kill_processes(procs, wait_timeout=1.0):
        """SIGKILL every ``Popen`` in ``procs``, then reap them (best effort)."""
        for p in procs:
            try:
                p.kill()
            except OSError:
                pass
        # Reap after signalling all of them so the waits overlap.
        for p in procs:
            try:
                p.wait(timeout=wait_timeout)
            except (subprocess.TimeoutExpired, OSError):
                pass

    @staticmethod
    def _pid_rss_mb(pid):
        """Return the resident set size of ``pid`` in MiB, or 0.0 if unavailable.

        Reads the ``VmRSS`` line of ``/proc/<pid>/status`` (reported in kB), so
        on systems without procfs this always returns 0.0.
        """
        try:
            with open(f"/proc/{pid}/status", "r", encoding="utf-8") as f:
                for line in f:
                    if line.startswith("VmRSS:"):
                        parts = line.split()
                        if len(parts) >= 2:
                            # VmRSS is kB
                            return float(parts[1]) / 1024.0
        except OSError:
            # Process gone, no permission, or no /proc: treat as "unknown".
            return 0.0
        return 0.0

    def _restart_server_at_index(self, idx, reason):
        """Stop server ``idx`` and start a replacement with the same arguments.

        The old process gets SIGTERM, and SIGKILL if it has not exited within 1 s.
        It is waited on so that it has exited before the replacement is started
        on the same ports.

        Raises:
            RuntimeError: propagated from ``_spawn_server``; in that case
                ``rcss_processes[idx]`` still holds the old (stopped) process.
        """
        port, mport, cmd, render_arg, realtime_arg = self._server_specs[idx]
        old_proc = self.rcss_processes[idx]
        try:
            old_proc.terminate()
            old_proc.wait(timeout=1.0)
        except (subprocess.TimeoutExpired, OSError):
            try:
                old_proc.kill()
                old_proc.wait(timeout=1.0)
            except (subprocess.TimeoutExpired, OSError):
                pass
        new_proc = self._spawn_server(port, mport, cmd, render_arg, realtime_arg)
        self.rcss_processes[idx] = new_proc
        print(
            f"[ServerWatchdog] Restarted server idx={idx} port={port} monitor={mport} reason={reason}"
        )

    def _watchdog_loop(self):
        """Periodically restart servers that exited or exceed the RSS limit.

        Runs until ``_watchdog_stop`` is set. A failed restart is logged and
        does not end the loop; the stale entry is retried on the next pass.
        """
        while not self._watchdog_stop.wait(self.WATCHDOG_INTERVAL_SEC):
            with self._watchdog_lock:
                for i, proc in enumerate(self.rcss_processes):
                    # kill() may be waiting on the lock; stop restarting promptly.
                    if self._watchdog_stop.is_set():
                        return
                    rc = proc.poll()
                    if rc is not None:
                        reason = f"exited:{rc}"
                    else:
                        rss_mb = self._pid_rss_mb(proc.pid)
                        if rss_mb <= self.WATCHDOG_RSS_MB_LIMIT:
                            continue
                        reason = f"rss_mb:{rss_mb:.1f}"
                    try:
                        self._restart_server_at_index(i, reason)
                    except Exception as e:
                        # An uncaught error here would silently end the watchdog thread.
                        print(f"[ServerWatchdog] Failed to restart server idx={i} reason={reason}: {e}")

    def check_running_servers(self, psutil, first_server_p, first_monitor_p, n_servers):
        ''' Check if any server is running on chosen ports.

        Scans processes whose command line contains "rcssservermj" and reports
        any whose numeric arguments fall in
        [first_server_p, first_server_p + n_servers) or
        [first_monitor_p, first_monitor_p + n_servers). If conflicts exist, it
        blocks on input() until the user types 'kill' (or aborts with ctrl+c),
        then kills the conflicting processes and waits up to 3 s for them to exit.

        Args:
            psutil: the imported ``psutil`` module (passed in because it is optional).
        '''
        server_range = range(first_server_p, first_server_p + n_servers)
        monitor_range = range(first_monitor_p, first_monitor_p + n_servers)

        def safe_cmdline(proc):
            try:
                return proc.cmdline()
            except (psutil.ZombieProcess, psutil.NoSuchProcess, psutil.AccessDenied, OSError):
                return []

        bad_processes = []
        for p in psutil.process_iter():
            cmdline = safe_cmdline(p)
            if not cmdline or "rcssservermj" not in " ".join(cmdline):
                continue
            # currently ignoring remaining default port when only one of the ports is specified (uncommon scenario)
            ports = [int(arg) for arg in cmdline[1:] if arg.isdigit()]
            if not ports:
                ports = [60000, 60100]  # default server ports (changing this is unlikely)
            conflicts = [str(port) for port in ports if port in server_range or port in monitor_range]
            if conflicts:
                if not bad_processes:
                    print("\nThere are already servers running on the same port(s)!")
                bad_processes.append(p)
                print(f"Port(s) {','.join(conflicts)} already in use by \"{' '.join(cmdline)}\" (PID:{p.pid})")

        if not bad_processes:
            return
        print()
        while True:
            inp = input("Enter 'kill' to kill these processes or ctrl+c to abort. ")
            if inp == "kill":
                break
        for p in bad_processes:
            try:
                p.kill()
            except psutil.NoSuchProcess:
                pass  # already exited on its own
        # Let them actually exit so their ports are free before we spawn ours.
        psutil.wait_procs(bad_processes, timeout=3)

    def kill(self):
        """Stop the watchdog and SIGKILL (and reap) every managed server."""
        self._watchdog_stop.set()
        if self._watchdog_thread is not None:
            self._watchdog_thread.join(timeout=1.0)
        # Hold the lock so a restart still in flight cannot install a fresh
        # process after the list has been killed.
        with self._watchdog_lock:
            self._kill_processes(self.rcss_processes)
        print(f"Killed {self.n_servers} rcssservermj processes starting at {self.first_server_p}")