diff --git a/fakeredis/_tcp_server.py b/fakeredis/_tcp_server.py index 3136be0e..69f6140d 100644 --- a/fakeredis/_tcp_server.py +++ b/fakeredis/_tcp_server.py @@ -55,8 +55,10 @@ def load(self) -> Any: length = int(rest) bulk_string = self.reader.read(length) terminator = self.reader.read(2) - if len(bulk_string) != length or terminator != b"\r\n": - raise ValueError() + if len(bulk_string) != length: + raise ValueError(f"Invalid bulk string length. Expected {length} bytes, got: {len(bulk_string)}") + if terminator != b"\r\n": + raise ValueError(f"Invalid terminator. Expected \\r\\n, got: {terminator}") return bulk_string if prefix == b":": return int(rest) diff --git a/test/test_tcp_server/test_reader.py b/test/test_tcp_server/test_reader.py index a4c81e91..97f08669 100644 --- a/test/test_tcp_server/test_reader.py +++ b/test/test_tcp_server/test_reader.py @@ -1,76 +1,87 @@ +import socket import time +from contextlib import closing from threading import Thread +import pytest import redis from fakeredis import TcpFakeServer -from fakeredis._tcp_server import TCP_SERVER_TEST_PORT from test import testtools -@testtools.run_test_if_lupa_installed() -def test_eval_multiline_script(): - """Test that EVAL works with multi-line Lua scripts.""" - server_address = ("127.0.0.1", TCP_SERVER_TEST_PORT) - server = TcpFakeServer(server_address) - t = Thread(target=server.serve_forever, daemon=True) - t.start() +@pytest.fixture(scope="function") +def redis_server(): + host = "127.0.0.1" + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + port = s.getsockname()[1] + + class TcpFakeServerWithExceptions(TcpFakeServer): + def handle_error(self, request, client_address): + super().handle_error(request, client_address) + # Send an error message back to the client here + request.sendall(b"An error occurred on the server.") + + server_address = (host, port) + server = TcpFakeServerWithExceptions(server_address) + thread = Thread(target=server.serve_forever, daemon=True) + thread.start() time.sleep(0.1) - with redis.Redis(host=server_address[0], port=server_address[1]) as r: - # Multi-line script with trailing newline - script = """ + try: + with redis.Redis(host=host, port=port) as r: + yield r + finally: + server.server_close() + server.shutdown() + thread.join() + + +@testtools.run_test_if_lupa_installed() +def test_eval_multiline_script(redis_server): + """Test that EVAL works with multi-line Lua scripts.""" + # Multi-line script with trailing newline + script = """ local key = KEYS[1] local value = ARGV[1] redis.call('SET', key, value) return redis.call('GET', key) """ - result = r.eval(script, 1, "testkey", "testvalue") - assert result == b"testvalue" - - server.server_close() - server.shutdown() - t.join() + result = redis_server.eval(script, 1, "testkey", "testvalue") + assert result == b"testvalue" @testtools.run_test_if_lupa_installed() -def test_script_load_multiline(): +def test_script_load_multiline(redis_server): """Test that SCRIPT LOAD works with multi-line Lua scripts.""" - server_address = ("127.0.0.1", TCP_SERVER_TEST_PORT) - server = TcpFakeServer(server_address) - t = Thread(target=server.serve_forever, daemon=True) - t.start() - time.sleep(0.1) - - with redis.Redis(host=server_address[0], port=server_address[1]) as r: - # Multi-line script - script = """local x = 1 + # Multi-line script + script = """local x = 1 local y = 2 return x + y""" - sha = r.script_load(script) - result = r.evalsha(sha, 0) - assert result == 3 - - server.server_close() - server.shutdown() - t.join() + sha = redis_server.script_load(script) + result = redis_server.evalsha(sha, 0) + assert result == 3 @testtools.run_test_if_lupa_installed() -def test_eval_script_with_trailing_newline(): +def test_eval_script_with_trailing_newline(redis_server): """Test that scripts with trailing newlines are preserved.""" - server_address = ("127.0.0.1", TCP_SERVER_TEST_PORT) - server = TcpFakeServer(server_address) - t = Thread(target=server.serve_forever, daemon=True) - t.start() - time.sleep(0.1) + # Script with explicit trailing newline + script = "return 'hello'\n" + result = redis_server.eval(script, 0) + assert result == b"hello" + - with redis.Redis(host=server_address[0], port=server_address[1]) as r: - # Script with explicit trailing newline - script = "return 'hello'\n" - result = r.eval(script, 0) - assert result == b"hello" +@testtools.run_test_if_lupa_installed() +def test_bulk_string_length(redis_server): + """Test that malformed bulk string input is handled correctly.""" + host = redis_server.connection_pool.connection_kwargs.get("host") + port = redis_server.connection_pool.connection_kwargs.get("port") - server.server_close() - server.shutdown() - t.join() + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.connect((host, port)) + s.sendall(b"$ 1\ntest") + data = s.recv(1024).decode() + assert data != "An error occurred on the server."