Skip to content

Commit 71c1906

Browse files
committed
Add test
1 parent 9c02863 commit 71c1906

File tree

1 file changed

+59
-48
lines changed

1 file changed

+59
-48
lines changed
Lines changed: 59 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,87 @@
1+
import socket
12
import time
3+
from contextlib import closing
24
from threading import Thread
35

6+
import pytest
47
import redis
58

69
from fakeredis import TcpFakeServer
7-
from fakeredis._tcp_server import TCP_SERVER_TEST_PORT
810
from test import testtools
911

1012

11-
@testtools.run_test_if_lupa_installed()
12-
def test_eval_multiline_script():
13-
"""Test that EVAL works with multi-line Lua scripts."""
14-
server_address = ("127.0.0.1", TCP_SERVER_TEST_PORT)
15-
server = TcpFakeServer(server_address)
16-
t = Thread(target=server.serve_forever, daemon=True)
17-
t.start()
13+
@pytest.fixture(scope="function")
14+
def redis_server():
15+
host = "127.0.0.1"
16+
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
17+
s.bind(("", 0))
18+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
19+
port = s.getsockname()[1]
20+
21+
class TcpFakeServerWithExceptions(TcpFakeServer):
22+
def handle_error(self, request, client_address):
23+
super().handle_error(request, client_address)
24+
# Send an error message back to the client here
25+
request.sendall(b"An error occurred on the server.")
26+
27+
server_address = (host, port)
28+
server = TcpFakeServerWithExceptions(server_address)
29+
thread = Thread(target=server.serve_forever, daemon=True)
30+
thread.start()
1831
time.sleep(0.1)
1932

20-
with redis.Redis(host=server_address[0], port=server_address[1]) as r:
21-
# Multi-line script with trailing newline
22-
script = """
33+
try:
34+
with redis.Redis(host=host, port=port) as r:
35+
yield r
36+
finally:
37+
server.server_close()
38+
server.shutdown()
39+
thread.join()
40+
41+
42+
@testtools.run_test_if_lupa_installed()
43+
def test_eval_multiline_script(redis_server):
44+
"""Test that EVAL works with multi-line Lua scripts."""
45+
# Multi-line script with trailing newline
46+
script = """
2347
local key = KEYS[1]
2448
local value = ARGV[1]
2549
redis.call('SET', key, value)
2650
return redis.call('GET', key)
2751
"""
28-
result = r.eval(script, 1, "testkey", "testvalue")
29-
assert result == b"testvalue"
30-
31-
server.server_close()
32-
server.shutdown()
33-
t.join()
52+
result = redis_server.eval(script, 1, "testkey", "testvalue")
53+
assert result == b"testvalue"
3454

3555

3656
@testtools.run_test_if_lupa_installed()
37-
def test_script_load_multiline():
57+
def test_script_load_multiline(redis_server):
3858
"""Test that SCRIPT LOAD works with multi-line Lua scripts."""
39-
server_address = ("127.0.0.1", TCP_SERVER_TEST_PORT)
40-
server = TcpFakeServer(server_address)
41-
t = Thread(target=server.serve_forever, daemon=True)
42-
t.start()
43-
time.sleep(0.1)
44-
45-
with redis.Redis(host=server_address[0], port=server_address[1]) as r:
46-
# Multi-line script
47-
script = """local x = 1
59+
# Multi-line script
60+
script = """local x = 1
4861
local y = 2
4962
return x + y"""
50-
sha = r.script_load(script)
51-
result = r.evalsha(sha, 0)
52-
assert result == 3
53-
54-
server.server_close()
55-
server.shutdown()
56-
t.join()
63+
sha = redis_server.script_load(script)
64+
result = redis_server.evalsha(sha, 0)
65+
assert result == 3
5766

5867

5968
@testtools.run_test_if_lupa_installed()
60-
def test_eval_script_with_trailing_newline():
69+
def test_eval_script_with_trailing_newline(redis_server):
6170
"""Test that scripts with trailing newlines are preserved."""
62-
server_address = ("127.0.0.1", TCP_SERVER_TEST_PORT)
63-
server = TcpFakeServer(server_address)
64-
t = Thread(target=server.serve_forever, daemon=True)
65-
t.start()
66-
time.sleep(0.1)
71+
# Script with explicit trailing newline
72+
script = "return 'hello'\n"
73+
result = redis_server.eval(script, 0)
74+
assert result == b"hello"
75+
6776

68-
with redis.Redis(host=server_address[0], port=server_address[1]) as r:
69-
# Script with explicit trailing newline
70-
script = "return 'hello'\n"
71-
result = r.eval(script, 0)
72-
assert result == b"hello"
77+
@testtools.run_test_if_lupa_installed()
78+
def test_bulk_string_length(redis_server):
79+
"""Test that malformed bulk string input is handled correctly."""
80+
host = redis_server.connection_pool.connection_kwargs.get("host")
81+
port = redis_server.connection_pool.connection_kwargs.get("port")
7382

74-
server.server_close()
75-
server.shutdown()
76-
t.join()
83+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
84+
s.connect((host, port))
85+
s.sendall(b"$ 1\ntest")
86+
data = s.recv(1024).decode()
87+
assert data != "An error occurred on the server."

0 commit comments

Comments
 (0)