Skip to content

Commit 338d9a0

Browse files
committed
Fixed formatting.
1 parent 3896b94 commit 338d9a0

2 files changed

Lines changed: 51 additions & 32 deletions

File tree

python/rcs/rpc/server.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
21
# import wrapper
32
from gymnasium import Wrapper
43
import rpyc
54
from rpyc.utils.server import ThreadedServer
6-
rpyc.core.protocol.DEFAULT_CONFIG['allow_pickle'] = True
5+
6+
rpyc.core.protocol.DEFAULT_CONFIG["allow_pickle"] = True
7+
78

89
@rpyc.service
910
class RcsServer(Wrapper, rpyc.Service):
10-
def __init__(self, env, host='localhost', port=50051):
11+
def __init__(self, env, host="localhost", port=50051):
1112
super().__init__(env)
1213
self.host = host
1314
self.port = port
@@ -25,25 +26,24 @@ def reset(self, **kwargs):
2526
@rpyc.exposed
2627
def get_obs(self):
2728
"""Get the current observation using the Wrapper base class if available."""
28-
if hasattr(super(), 'get_obs'):
29+
if hasattr(super(), "get_obs"):
2930
return super().get_obs()
30-
elif hasattr(self.env, 'get_obs'):
31+
if hasattr(self.env, "get_obs"):
3132
return self.env.get_obs()
32-
else:
33-
raise NotImplementedError("The environment does not have a get_obs method.")
33+
error = "The environment does not have a get_obs method."
34+
raise NotImplementedError(error)
3435

3536
@rpyc.exposed
3637
def unwrapped(self):
3738
"""Return the unwrapped environment using the Wrapper base class."""
3839
return super().unwrapped
39-
40+
4041
@rpyc.exposed
4142
def action_space(self):
4243
"""Return the action space using the Wrapper base class."""
4344
return super().action_space
4445

4546
def start(self):
46-
import time
4747
print(f"Starting RcsServer RPC (looped OneShotServer) on {self.host}:{self.port}")
4848
t = ThreadedServer(self, port=self.port)
49-
t.start()
49+
t.start()

python/tests/test_rpc.py

Lines changed: 41 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1+
# noqa: type
2+
3+
from contextlib import suppress
14
import multiprocessing
5+
from multiprocessing.context import SpawnContext, ForkServerContext
26
import time
37
import socket
48
import sys
59
import traceback
610
import os
711
import pytest
8-
from typing import Optional # Add this import at the top
12+
from typing import Optional, Type, Union # Add Type and Union here
913
from rcs.envs.creators import SimEnvCreator
1014
from rcs.envs.utils import (
11-
default_mujoco_cameraset_cfg,
1215
default_sim_gripper_cfg,
1316
default_sim_robot_cfg,
1417
)
@@ -18,17 +21,19 @@
1821

1922
HOST = "127.0.0.1"
2023

24+
2125
def get_free_port() -> int:
2226
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
2327
s.bind((HOST, 0))
2428
return s.getsockname()[1]
2529

30+
2631
def wait_for_port(
2732
host: str,
2833
port: int,
2934
timeout: float,
3035
server_proc: Optional[multiprocessing.Process] = None,
31-
err_q: Optional[multiprocessing.Queue] = None
36+
err_q: Optional[multiprocessing.Queue] = None,
3237
) -> None:
3338
start = time.time()
3439
last_exc = None
@@ -44,21 +49,17 @@ def wait_for_port(
4449
if server_proc is not None and not server_proc.is_alive():
4550
server_err = None
4651
if err_q is not None:
47-
try:
52+
with suppress(Exception):
4853
server_err = err_q.get_nowait()
49-
except Exception:
50-
pass
5154
msg = f"Server process exited early (exitcode={server_proc.exitcode})."
5255
if server_err:
5356
msg += f"\nServer traceback:\n{server_err}"
5457
raise RuntimeError(msg)
5558
time.sleep(0.2)
5659
server_err = None
5760
if err_q is not None:
58-
try:
61+
with suppress(Exception):
5962
server_err = err_q.get_nowait()
60-
except Exception:
61-
pass
6263
msg = f"Timed out waiting for {host}:{port} to open."
6364
if last_exc:
6465
msg += f" Last socket error: {last_exc}"
@@ -68,6 +69,7 @@ def wait_for_port(
6869
msg += f"\nServer traceback:\n{server_err}"
6970
raise TimeoutError(msg)
7071

72+
7173
def run_server(host: str, port: int, err_q: multiprocessing.Queue) -> None:
7274
try:
7375
env = SimEnvCreator()(
@@ -76,7 +78,7 @@ def run_server(host: str, port: int, err_q: multiprocessing.Queue) -> None:
7678
robot_cfg=default_sim_robot_cfg(),
7779
gripper_cfg=default_sim_gripper_cfg(),
7880
# Disabled to avoid rendering problem in python subprocess.
79-
#cameras=default_mujoco_cameraset_cfg(),
81+
# cameras=default_mujoco_cameraset_cfg(),
8082
max_relative_movement=0.1,
8183
relative_to=RelativeTo.LAST_STEP,
8284
)
@@ -90,20 +92,22 @@ def run_server(host: str, port: int, err_q: multiprocessing.Queue) -> None:
9092
time.sleep(1)
9193
except Exception:
9294
tb = "".join(traceback.format_exception(*sys.exc_info()))
93-
try:
95+
with suppress(Exception):
9496
err_q.put(tb)
95-
except Exception:
96-
pass
9797
sys.exit(1)
9898

99-
def _mp_context() -> multiprocessing.context.BaseContext:
99+
100+
def _mp_context() -> Union[SpawnContext, ForkServerContext]:
100101
# Prefer spawn to avoid fork-related issues with GL/MuJoCo/threaded libs
101102
methods = multiprocessing.get_all_start_methods()
102103
if "spawn" in methods:
103104
return multiprocessing.get_context("spawn")
104105
if "forkserver" in methods:
105106
return multiprocessing.get_context("forkserver")
106-
return multiprocessing.get_context(methods[0])
107+
108+
msg = "No suitable multiprocessing context found."
109+
raise RuntimeError(msg)
110+
107111

108112
def _external_server_from_env() -> tuple[str, int] | None:
109113
# Set RCS_TEST_HOST and RCS_TEST_PORT to reuse an already running server.
@@ -119,6 +123,7 @@ def _external_server_from_env() -> tuple[str, int] | None:
119123
return HOST, 50055
120124
return None
121125

126+
122127
def test_run_server_starts_and_stops():
123128
# Skip if reusing an external server
124129
ext = _external_server_from_env()
@@ -130,17 +135,25 @@ def test_run_server_starts_and_stops():
130135
server_proc = ctx.Process(target=run_server, args=(HOST, port, err_q))
131136
server_proc.start()
132137
try:
133-
wait_for_port(HOST, port, timeout=120.0, server_proc=server_proc, err_q=err_q)
138+
wait_for_port(HOST, port, timeout=120.0, server_proc=server_proc, err_q=err_q) # type: ignore
134139
assert server_proc.is_alive(), "Server process did not start as expected."
135140
finally:
136141
if server_proc.is_alive():
137142
server_proc.terminate()
138143
server_proc.join(timeout=5)
139144
assert not server_proc.is_alive(), "Server process did not terminate as expected."
140145

146+
141147
class TestRcsClientServer:
148+
client: RcsClient
149+
host: str = HOST
150+
port: int = 0
151+
server_proc = None
152+
err_q: Optional[multiprocessing.Queue] = None
153+
154+
142155
@classmethod
143-
def setup_class(cls):
156+
def setup_class(cls: Type["TestRcsClientServer"]):
144157
ext = _external_server_from_env()
145158
if ext:
146159
cls.host, cls.port = ext
@@ -156,11 +169,11 @@ def setup_class(cls):
156169
cls.server_proc = ctx.Process(target=run_server, args=(cls.host, cls.port, cls.err_q))
157170
cls.server_proc.start()
158171
# Wait until the server is actually listening or fail early if it crashed
159-
wait_for_port(cls.host, cls.port, timeout=180.0, server_proc=cls.server_proc, err_q=cls.err_q)
172+
wait_for_port(cls.host, cls.port, timeout=180.0, server_proc=cls.server_proc, err_q=cls.err_q) # type: ignore
160173
cls.client = RcsClient(host=cls.host, port=cls.port)
161174

162175
@classmethod
163-
def teardown_class(cls):
176+
def teardown_class(cls: Type["TestRcsClientServer"]):
164177
try:
165178
if getattr(cls, "client", None):
166179
cls.client.close()
@@ -188,8 +201,14 @@ def test_unwrapped(self):
188201
_ = self.client.unwrapped
189202

190203
def test_close(self):
191-
self.client.close()
204+
if self.client is not None:
205+
self.client.close()
192206
# Reconnect for further tests
193-
wait_for_port(self.__class__.host, self.__class__.port, timeout=15.0,
194-
server_proc=self.__class__.server_proc, err_q=self.__class__.err_q)
207+
wait_for_port(
208+
self.__class__.host,
209+
self.__class__.port,
210+
timeout=15.0,
211+
server_proc=self.__class__.server_proc, # type: ignore
212+
err_q=self.__class__.err_q,
213+
)
195214
self.__class__.client = RcsClient(host=self.__class__.host, port=self.__class__.port)

0 commit comments

Comments
 (0)