|
| 1 | +import socket |
1 | 2 | import time |
| 3 | +from contextlib import closing |
2 | 4 | from threading import Thread |
3 | 5 |
|
| 6 | +import pytest |
4 | 7 | import redis |
5 | 8 |
|
6 | 9 | from fakeredis import TcpFakeServer |
7 | | -from fakeredis._tcp_server import TCP_SERVER_TEST_PORT |
8 | 10 | from test import testtools |
9 | 11 |
|
10 | 12 |
|
11 | | -@testtools.run_test_if_lupa_installed() |
12 | | -def test_eval_multiline_script(): |
13 | | - """Test that EVAL works with multi-line Lua scripts.""" |
14 | | - server_address = ("127.0.0.1", TCP_SERVER_TEST_PORT) |
15 | | - server = TcpFakeServer(server_address) |
16 | | - t = Thread(target=server.serve_forever, daemon=True) |
17 | | - t.start() |
| 13 | +@pytest.fixture(scope="function") |
| 14 | +def redis_server(): |
| 15 | + host = "127.0.0.1" |
| 16 | + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: |
| 17 | + s.bind(("", 0)) |
| 18 | + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) |
| 19 | + port = s.getsockname()[1] |
| 20 | + |
| 21 | + class TcpFakeServerWithExceptions(TcpFakeServer): |
| 22 | + def handle_error(self, request, client_address): |
| 23 | + super().handle_error(request, client_address) |
| 24 | + # Send an error message back to the client here |
| 25 | + request.sendall(b"An error occurred on the server.") |
| 26 | + |
| 27 | + server_address = (host, port) |
| 28 | + server = TcpFakeServerWithExceptions(server_address) |
| 29 | + thread = Thread(target=server.serve_forever, daemon=True) |
| 30 | + thread.start() |
18 | 31 | time.sleep(0.1) |
19 | 32 |
|
20 | | - with redis.Redis(host=server_address[0], port=server_address[1]) as r: |
21 | | - # Multi-line script with trailing newline |
22 | | - script = """ |
| 33 | + try: |
| 34 | + with redis.Redis(host=host, port=port) as r: |
| 35 | + yield r |
| 36 | + finally: |
| 37 | + server.server_close() |
| 38 | + server.shutdown() |
| 39 | + thread.join() |
| 40 | + |
| 41 | + |
| 42 | +@testtools.run_test_if_lupa_installed() |
| 43 | +def test_eval_multiline_script(redis_server): |
| 44 | + """Test that EVAL works with multi-line Lua scripts.""" |
| 45 | + # Multi-line script with trailing newline |
| 46 | + script = """ |
23 | 47 | local key = KEYS[1] |
24 | 48 | local value = ARGV[1] |
25 | 49 | redis.call('SET', key, value) |
26 | 50 | return redis.call('GET', key) |
27 | 51 | """ |
28 | | - result = r.eval(script, 1, "testkey", "testvalue") |
29 | | - assert result == b"testvalue" |
30 | | - |
31 | | - server.server_close() |
32 | | - server.shutdown() |
33 | | - t.join() |
| 52 | + result = redis_server.eval(script, 1, "testkey", "testvalue") |
| 53 | + assert result == b"testvalue" |
34 | 54 |
|
35 | 55 |
|
36 | 56 | @testtools.run_test_if_lupa_installed() |
37 | | -def test_script_load_multiline(): |
| 57 | +def test_script_load_multiline(redis_server): |
38 | 58 | """Test that SCRIPT LOAD works with multi-line Lua scripts.""" |
39 | | - server_address = ("127.0.0.1", TCP_SERVER_TEST_PORT) |
40 | | - server = TcpFakeServer(server_address) |
41 | | - t = Thread(target=server.serve_forever, daemon=True) |
42 | | - t.start() |
43 | | - time.sleep(0.1) |
44 | | - |
45 | | - with redis.Redis(host=server_address[0], port=server_address[1]) as r: |
46 | | - # Multi-line script |
47 | | - script = """local x = 1 |
| 59 | + # Multi-line script |
| 60 | + script = """local x = 1 |
48 | 61 | local y = 2 |
49 | 62 | return x + y""" |
50 | | - sha = r.script_load(script) |
51 | | - result = r.evalsha(sha, 0) |
52 | | - assert result == 3 |
53 | | - |
54 | | - server.server_close() |
55 | | - server.shutdown() |
56 | | - t.join() |
| 63 | + sha = redis_server.script_load(script) |
| 64 | + result = redis_server.evalsha(sha, 0) |
| 65 | + assert result == 3 |
57 | 66 |
|
58 | 67 |
|
59 | 68 | @testtools.run_test_if_lupa_installed() |
60 | | -def test_eval_script_with_trailing_newline(): |
| 69 | +def test_eval_script_with_trailing_newline(redis_server): |
61 | 70 | """Test that scripts with trailing newlines are preserved.""" |
62 | | - server_address = ("127.0.0.1", TCP_SERVER_TEST_PORT) |
63 | | - server = TcpFakeServer(server_address) |
64 | | - t = Thread(target=server.serve_forever, daemon=True) |
65 | | - t.start() |
66 | | - time.sleep(0.1) |
| 71 | + # Script with explicit trailing newline |
| 72 | + script = "return 'hello'\n" |
| 73 | + result = redis_server.eval(script, 0) |
| 74 | + assert result == b"hello" |
| 75 | + |
67 | 76 |
|
68 | | - with redis.Redis(host=server_address[0], port=server_address[1]) as r: |
69 | | - # Script with explicit trailing newline |
70 | | - script = "return 'hello'\n" |
71 | | - result = r.eval(script, 0) |
72 | | - assert result == b"hello" |
| 77 | +@testtools.run_test_if_lupa_installed() |
| 78 | +def test_bulk_string_length(redis_server): |
| 79 | + """Test that malformed bulk string input is handled correctly.""" |
| 80 | + host = redis_server.connection_pool.connection_kwargs.get("host") |
| 81 | + port = redis_server.connection_pool.connection_kwargs.get("port") |
73 | 82 |
|
74 | | - server.server_close() |
75 | | - server.shutdown() |
76 | | - t.join() |
| 83 | + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: |
| 84 | + s.connect((host, port)) |
| 85 | + s.sendall(b"$ 1\ntest") |
| 86 | + data = s.recv(1024).decode() |
| 87 | + assert data != "An error occurred on the server." |
0 commit comments