|
2 | 2 | import re |
3 | 3 | import ssl |
4 | 4 | from dataclasses import dataclass |
5 | | -from typing import Dict, List, Optional, Tuple |
| 5 | +from typing import Dict, List, Optional, Tuple, Union |
6 | 6 | from urllib.parse import unquote, urljoin, urlparse |
7 | 7 |
|
8 | 8 | import structlog |
@@ -61,6 +61,30 @@ def reconstruct(self) -> bytes: |
61 | 61 | return result |
62 | 62 |
|
63 | 63 |
|
| 64 | +@dataclass |
| 65 | +class HttpResponse: |
| 66 | + """Data class to store HTTP response details""" |
| 67 | + |
| 68 | + version: str |
| 69 | + status_code: int |
| 70 | + reason: str |
| 71 | + headers: List[str] |
| 72 | + body: Optional[bytes] = None |
| 73 | + |
| 74 | + def reconstruct(self) -> bytes: |
| 75 | + """Reconstruct HTTP response from stored details""" |
| 76 | + headers = "\r\n".join(self.headers) |
| 77 | + status_line = f"{self.version} {self.status_code} {self.reason}\r\n" |
| 78 | + header_block = f"{status_line}{headers}\r\n\r\n" |
| 79 | + |
| 80 | + # Convert header block to bytes and combine with body |
| 81 | + result = header_block.encode("utf-8") |
| 82 | + if self.body: |
| 83 | + result += self.body |
| 84 | + |
| 85 | + return result |
| 86 | + |
| 87 | + |
64 | 88 | def extract_path(full_path: str) -> str: |
65 | 89 | """Extract clean path from full URL or path string""" |
66 | 90 | logger.debug(f"Extracting path from {full_path}") |
@@ -145,7 +169,7 @@ async def _body_through_pipeline( |
145 | 169 | ) -> Tuple[bytes, PipelineContext]: |
146 | 170 | logger.debug(f"Processing body through pipeline: {len(body)} bytes") |
147 | 171 | strategy = self._select_pipeline(method, path) |
148 | | - if strategy is None: |
| 172 | + if len(body) == 0 or strategy is None: |
149 | 173 | # if we didn't select any strategy that would change the request |
150 | 174 | # let's just pass through the body as-is |
151 | 175 | return body, None |
@@ -243,35 +267,87 @@ def _check_buffer_size(self, new_data: bytes) -> bool: |
243 | 267 | """Check if adding new data would exceed buffer size limit""" |
244 | 268 | return len(self.buffer) + len(new_data) <= MAX_BUFFER_SIZE |
245 | 269 |
|
246 | | - async def _forward_data_through_pipeline(self, data: bytes) -> bytes: |
| 270 | + async def _forward_data_through_pipeline( |
| 271 | + self, data: bytes |
| 272 | + ) -> Union[HttpRequest, HttpResponse]: |
247 | 273 | http_request = http_request_from_bytes(data) |
248 | 274 | if not http_request: |
249 | 275 | # we couldn't parse this into an HTTP request, so we just pass through |
250 | 276 | return data |
251 | 277 |
|
252 | | - http_request.body, context = await self._body_through_pipeline( |
| 278 | + body, context = await self._body_through_pipeline( |
253 | 279 | http_request.method, |
254 | 280 | http_request.path, |
255 | 281 | http_request.headers, |
256 | 282 | http_request.body, |
257 | 283 | ) |
258 | 284 | self.context_tracking = context |
259 | 285 |
|
260 | | - for header in http_request.headers: |
261 | | - if header.lower().startswith("content-length:"): |
262 | | - http_request.headers.remove(header) |
263 | | - break |
264 | | - http_request.headers.append(f"Content-Length: {len(http_request.body)}") |
| 286 | + if context and context.shortcut_response: |
| 287 | + # Send shortcut response |
| 288 | + data_prefix = b'data:' |
| 289 | + http_response = HttpResponse( |
| 290 | + http_request.version, |
| 291 | + 200, |
| 292 | + "OK", |
| 293 | + [ |
| 294 | + "server: uvicorn", |
| 295 | + "cache-control: no-cache", |
| 296 | + "connection: keep-alive", |
| 297 | + "Content-Type: application/json", |
| 298 | + "Transfer-Encoding: chunked", |
| 299 | + ], |
| 300 | + data_prefix + body |
| 301 | + ) |
| 302 | + return http_response |
265 | 303 |
|
266 | | - pipeline_data = http_request.reconstruct() |
| 304 | + else: |
| 305 | + # Forward request to target |
| 306 | + http_request.body = body |
| 307 | + |
| 308 | + for header in http_request.headers: |
| 309 | + if header.lower().startswith("content-length:"): |
| 310 | + http_request.headers.remove(header) |
| 311 | + break |
| 312 | + http_request.headers.append(f"Content-Length: {len(http_request.body)}") |
267 | 313 |
|
268 | | - return pipeline_data |
| 314 | + return http_request |
269 | 315 |
|
270 | 316 | async def _forward_data_to_target(self, data: bytes) -> None: |
271 | | - """Forward data to target if connection is established""" |
272 | | - if self.target_transport and not self.target_transport.is_closing(): |
273 | | - data = await self._forward_data_through_pipeline(data) |
274 | | - self.target_transport.write(data) |
| 317 | + """ |
| 318 | + Forward data to target if connection is established. In case of shortcut |
| 319 | + response, send a response to the client |
| 320 | + """ |
| 321 | + pipeline_output = await self._forward_data_through_pipeline(data) |
| 322 | + |
| 323 | + if isinstance(pipeline_output, HttpResponse): |
| 324 | + # We need to send shortcut response |
| 325 | + if self.transport and not self.transport.is_closing(): |
| 326 | + # First, close target_transport since we don't need to send any |
| 327 | + # request to the target |
| 328 | + self.target_transport.close() |
| 329 | + |
| 330 | + # Send the shortcut response data in a chunk |
| 331 | + chunk = pipeline_output.reconstruct() |
| 332 | + chunk_size = hex(len(chunk))[2:] + "\r\n" |
| 333 | + self.transport.write(chunk_size.encode()) |
| 334 | + self.transport.write(chunk) |
| 335 | + self.transport.write(b"\r\n") |
| 336 | + |
| 337 | + # Send data done chunk |
| 338 | + chunk = b"data: [DONE]\n\n" |
| 339 | + # Add chunk size for DONE message |
| 340 | + chunk_size = hex(len(chunk))[2:] + "\r\n" |
| 341 | + self.transport.write(chunk_size.encode()) |
| 342 | + self.transport.write(chunk) |
| 343 | + self.transport.write(b"\r\n") |
| 344 | + # Now send the final chunk with 0 |
| 345 | + self.transport.write(b"0\r\n\r\n") |
| 346 | + else: |
| 347 | + if self.target_transport and not self.target_transport.is_closing(): |
| 348 | + if isinstance(pipeline_output, HttpRequest): |
| 349 | + pipeline_output = pipeline_output.reconstruct() |
| 350 | + self.target_transport.write(pipeline_output) |
275 | 351 |
|
276 | 352 | def data_received(self, data: bytes) -> None: |
277 | 353 | """Handle received data from client""" |
|
0 commit comments