diff --git a/changelog.d/403.bugfix.rst b/changelog.d/403.bugfix.rst new file mode 100644 index 00000000..3a00357a --- /dev/null +++ b/changelog.d/403.bugfix.rst @@ -0,0 +1 @@ +treq no longer vendors the ``multipart`` library, now that it no longer has import conflicts with ``python-multipart``. diff --git a/pyproject.toml b/pyproject.toml index 661be332..c8511771 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,6 +148,6 @@ ignore_missing_imports = true [[tool.mypy.overrides]] module = [ - "treq._multipart", + "multipart", ] -disallow_untyped_defs = false +ignore_missing_imports = true diff --git a/src/treq/_multipart.py b/src/treq/_multipart.py deleted file mode 100644 index b302a26f..00000000 --- a/src/treq/_multipart.py +++ /dev/null @@ -1,559 +0,0 @@ -""" -Parser for multipart/form-data -============================== - -This module provides a parser for the multipart/form-data format. It can read -from a file, a socket or a WSGI environment. The parser can be used to replace -cgi.FieldStorage to work around its limitations. - -..note:: - - Copyright (c) 2010, Marcel Hellkamp. - Inspired by the Werkzeug library: http://werkzeug.pocoo.org/ - - Permission is hereby granted, free of charge, to any person obtaining a copy - of this software and associated documentation files (the "Software"), to deal - in the Software without restriction, including without limitation the rights - to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - copies of the Software, and to permit persons to whom the Software is - furnished to do so, subject to the following conditions: - - The above copyright notice and this permission notice shall be included in - all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - THE SOFTWARE. -""" - -import re -from io import BytesIO -from typing import IO -from tempfile import TemporaryFile -from urllib.parse import parse_qs -from wsgiref.headers import Headers -from collections.abc import MutableMapping as DictMixin - - -__author__ = "Marcel Hellkamp" -__version__ = "0.2.5" -__license__ = "MIT" -__all__ = ["MultipartError", "MultipartParser", "MultipartPart", "parse_form_data"] - - -############################################################################## -################################ Helper & Misc ############################### -############################################################################## -# Some of these were copied from bottle: https://bottlepy.org - - -# --------- -# MultiDict -# --------- - - -class MultiDict(DictMixin): - """A dict that remembers old values for each key. - HTTP headers may repeat with differing values, - such as Set-Cookie. We need to remember all - values. - """ - - def __init__(self, *args, **kwargs): - self.dict = dict() - for k, v in dict(*args, **kwargs).items(): - self[k] = v - - def __len__(self): - return len(self.dict) - - def __iter__(self): - return iter(self.dict) - - def __contains__(self, key): - return key in self.dict - - def __delitem__(self, key): - del self.dict[key] - - def keys(self): - return self.dict.keys() - - def __getitem__(self, key): - return self.get(key, KeyError, -1) - - def __setitem__(self, key, value): - self.append(key, value) - - def append(self, key, value): - self.dict.setdefault(key, []).append(value) - - def replace(self, key, value): - self.dict[key] = [value] - - def getall(self, key): - return self.dict.get(key) or [] - - def get(self, key, default=None, index=-1): - if key not in self.dict and default != KeyError: - return [default][index] - - return self.dict[key][index] - - def iterallitems(self): - for key, values in self.dict.items(): - for value in values: - yield key, value - - -def to_bytes(data, enc="utf8"): - if isinstance(data, str): - data = data.encode(enc) - - return data - - -def copy_file(stream, target, maxread=-1, buffer_size=2**16): - """Read from :stream and write to :target until :maxread or EOF.""" - size, read = 0, stream.read - - while True: - to_read = buffer_size if maxread < 0 else min(buffer_size, maxread - size) - part = read(to_read) - - if not part: - return size - - target.write(part) - size += len(part) - - -# ------------- -# Header Parser -# ------------- - - -_special = re.escape('()<>@,;:"\\/[]?={} \t') -_re_special = re.compile(r"[%s]" % _special) -_quoted_string = r'"(?:\\.|[^"])*"' # Quoted string -_value = r"(?:[^%s]+|%s)" % (_special, _quoted_string) # Save or quoted string -_option = r"(?:;|^)\s*([^%s]+)\s*=\s*(%s)" % (_special, _value) -_re_option = re.compile(_option) # key=value part of an Content-Type like header - - -def header_quote(val): - if not _re_special.search(val): - return val - - return '"' + val.replace("\\", "\\\\").replace('"', '\\"') + '"' - - -def header_unquote(val, filename=False): - if val[0] == val[-1] == '"': - val = val[1:-1] - - if val[1:3] == ":\\" or val[:2] == "\\\\": - val = val.split("\\")[-1] # fix ie6 bug: full path --> filename - - return val.replace("\\\\", "\\").replace('\\"', '"') - - return val - - -def parse_options_header(header, options=None): - if ";" not in header: - return header.lower().strip(), {} - - content_type, tail = header.split(";", 1) - options = options or {} - - for match in _re_option.finditer(tail): - key = match.group(1).lower() - value = header_unquote(match.group(2), key == "filename") - options[key] = value - - return content_type, options - - -############################################################################## -################################## Multipart ################################# -############################################################################## - - -class MultipartError(ValueError): - pass - - -class MultipartParser: - def __init__( - self, - stream, - boundary, - content_length=-1, - disk_limit=2**30, - mem_limit=2**20, - memfile_limit=2**18, - buffer_size=2**16, - charset="latin1", - ): - """Parse a multipart/form-data byte stream. This object is an iterator - over the parts of the message. - - :param stream: A file-like stream. Must implement ``.read(size)``. - :param boundary: The multipart boundary as a byte string. - :param content_length: The maximum number of bytes to read. - """ - self.stream = stream - self.boundary = boundary - self.content_length = content_length - self.disk_limit = disk_limit - self.memfile_limit = memfile_limit - self.mem_limit = min(mem_limit, self.disk_limit) - self.buffer_size = min(buffer_size, self.mem_limit) - self.charset = charset - - if self.buffer_size - 6 < len(boundary): # "--boundary--\r\n" - raise MultipartError("Boundary does not fit into buffer_size.") - - self._done = [] - self._part_iter = None - - def __iter__(self): - """Iterate over the parts of the multipart message.""" - if not self._part_iter: - self._part_iter = self._iterparse() - - for part in self._done: - yield part - - for part in self._part_iter: - self._done.append(part) - yield part - - def parts(self): - """Returns a list with all parts of the multipart message.""" - return list(self) - - def get(self, name, default=None): - """Return the first part with that name or a default value (None).""" - for part in self: - if name == part.name: - return part - - return default - - def get_all(self, name): - """Return a list of parts with that name.""" - return [p for p in self if p.name == name] - - def _lineiter(self): - """Iterate over a binary file-like object line by line. Each line is - returned as a (line, line_ending) tuple. If the line does not fit - into self.buffer_size, line_ending is empty and the rest of the line - is returned with the next iteration. - """ - read = self.stream.read - maxread, maxbuf = self.content_length, self.buffer_size - buffer = b"" # buffer for the last (partial) line - - while True: - data = read(maxbuf if maxread < 0 else min(maxbuf, maxread)) - maxread -= len(data) - lines = (buffer + data).splitlines(True) - len_first_line = len(lines[0]) - - # be sure that the first line does not become too big - if len_first_line > self.buffer_size: - # at the same time don't split a '\r\n' accidentally - if len_first_line == self.buffer_size + 1 and lines[0].endswith( - b"\r\n" - ): - splitpos = self.buffer_size - 1 - else: - splitpos = self.buffer_size - lines[:1] = [lines[0][:splitpos], lines[0][splitpos:]] - - if data: - buffer = lines[-1] - lines = lines[:-1] - - for line in lines: - if line.endswith(b"\r\n"): - yield line[:-2], b"\r\n" - elif line.endswith(b"\n"): - yield line[:-1], b"\n" - elif line.endswith(b"\r"): - yield line[:-1], b"\r" - else: - yield line, b"" - - if not data: - break - - def _iterparse(self): - lines, line = self._lineiter(), "" - separator = b"--" + to_bytes(self.boundary) - terminator = b"--" + to_bytes(self.boundary) + b"--" - - # Consume first boundary. Ignore any preamble, as required by RFC - # 2046, section 5.1.1. - for line, nl in lines: - if line in (separator, terminator): - break - else: - raise MultipartError("Stream does not contain boundary") - - # Check for empty data - if line == terminator: - for _ in lines: - raise MultipartError("Data after end of stream") - return - - # For each part in stream... - mem_used, disk_used = 0, 0 # Track used resources to prevent DoS - is_tail = False # True if the last line was incomplete (cutted) - - opts = { - "buffer_size": self.buffer_size, - "memfile_limit": self.memfile_limit, - "charset": self.charset, - } - - part = MultipartPart(**opts) - - for line, nl in lines: - if line == terminator and not is_tail: - part.file.seek(0) - yield part - break - - elif line == separator and not is_tail: - if part.is_buffered(): - mem_used += part.size - else: - disk_used += part.size - part.file.seek(0) - - yield part - - part = MultipartPart(**opts) - - else: - is_tail = not nl # The next line continues this one - try: - part.feed(line, nl) - - if part.is_buffered(): - if part.size + mem_used > self.mem_limit: - raise MultipartError("Memory limit reached.") - elif part.size + disk_used > self.disk_limit: - raise MultipartError("Disk limit reached.") - except MultipartError: - part.close() - raise - else: - # If we run off the end of the loop, the current MultipartPart - # will not have been yielded, so it's our responsibility to - # close it. - part.close() - - if line != terminator: - raise MultipartError("Unexpected end of multipart stream.") - - -class MultipartPart: - file: IO[bytes] - - def __init__(self, buffer_size=2**16, memfile_limit=2**18, charset="latin1"): - self.headerlist = [] - self.headers = None - self.file = False # type:ignore - self.size = 0 - self._buf = b"" - self.disposition = None - self.name = None - self.filename = None - self.content_type = None - self.charset = charset - self.memfile_limit = memfile_limit - self.buffer_size = buffer_size - - def feed(self, line, nl=""): - if self.file: - return self.write_body(line, nl) - - return self.write_header(line, nl) - - def write_header(self, line, nl): - line = line.decode(self.charset) - - if not nl: - raise MultipartError("Unexpected end of line in header.") - - if not line.strip(): # blank line -> end of header segment - self.finish_header() - elif line[0] in " \t" and self.headerlist: - name, value = self.headerlist.pop() - self.headerlist.append((name, value + line.strip())) - else: - if ":" not in line: - raise MultipartError("Syntax error in header: No colon.") - - name, value = line.split(":", 1) - self.headerlist.append((name.strip(), value.strip())) - - def write_body(self, line, nl): - if not line and not nl: - return # This does not even flush the buffer - - self.size += len(line) + len(self._buf) - self.file.write(self._buf + line) - self._buf = nl - - if self.content_length > 0 and self.size > self.content_length: - raise MultipartError("Size of body exceeds Content-Length header.") - - if self.size > self.memfile_limit and isinstance(self.file, BytesIO): - # TODO: What about non-file uploads that exceed the memfile_limit? - self.file, old = TemporaryFile(mode="w+b"), self.file - old.seek(0) - copy_file(old, self.file, self.size, self.buffer_size) - - def finish_header(self): - self.file = BytesIO() - self.headers = Headers(self.headerlist) - content_disposition = self.headers.get("Content-Disposition", "") - content_type = self.headers.get("Content-Type", "") - - if not content_disposition: - raise MultipartError("Content-Disposition header is missing.") - - self.disposition, self.options = parse_options_header(content_disposition) - self.name = self.options.get("name") - self.filename = self.options.get("filename") - self.content_type, options = parse_options_header(content_type) - self.charset = options.get("charset") or self.charset - self.content_length = int(self.headers.get("Content-Length", "-1")) - - def is_buffered(self): - """Return true if the data is fully buffered in memory.""" - return isinstance(self.file, BytesIO) - - @property - def value(self): - """Data decoded with the specified charset""" - - return self.raw.decode(self.charset) - - @property - def raw(self): - """Data without decoding""" - pos = self.file.tell() - self.file.seek(0) - - try: - val = self.file.read() - except OSError: - raise - finally: - self.file.seek(pos) - - return val - - def save_as(self, path): - with open(path, "wb") as fp: - pos = self.file.tell() - - try: - self.file.seek(0) - size = copy_file(self.file, fp) - finally: - self.file.seek(pos) - - return size - - def close(self): - if self.file: - self.file.close() - self.file = False # type:ignore - - -############################################################################## -#################################### WSGI #################################### -############################################################################## - - -def parse_form_data(environ, charset="utf8", strict=False, **kwargs): - """Parse form data from an environ dict and return a (forms, files) tuple. - Both tuple values are dictionaries with the form-field name as a key - (unicode) and lists as values (multiple values per key are possible). - The forms-dictionary contains form-field values as unicode strings. - The files-dictionary contains :class:`MultipartPart` instances, either - because the form-field was a file-upload or the value is too big to fit - into memory limits. - - :param environ: An WSGI environment dict. - :param charset: The charset to use if unsure. (default: utf8) - :param strict: If True, raise :exc:`MultipartError` on any parsing - errors. These are silently ignored by default. - """ - - forms, files = MultiDict(), MultiDict() - - try: - if environ.get("REQUEST_METHOD", "GET").upper() not in ("POST", "PUT"): - raise MultipartError("Request method other than POST or PUT.") - content_length = int(environ.get("CONTENT_LENGTH", "-1")) - content_type = environ.get("CONTENT_TYPE", "") - - if not content_type: - raise MultipartError("Missing Content-Type header.") - - content_type, options = parse_options_header(content_type) - stream = environ.get("wsgi.input") or BytesIO() - kwargs["charset"] = charset = options.get("charset", charset) - - if content_type == "multipart/form-data": - boundary = options.get("boundary", "") - - if not boundary: - raise MultipartError("No boundary for multipart/form-data.") - - for part in MultipartParser(stream, boundary, content_length, **kwargs): - if part.filename or not part.is_buffered(): - files[part.name] = part - else: # TODO: Big form-fields are in the files dict. really? - forms[part.name] = part.value - - elif content_type in ( - "application/x-www-form-urlencoded", - "application/x-url-encoded", - ): - mem_limit = kwargs.get("mem_limit", 2**20) - if content_length > mem_limit: - raise MultipartError("Request too big. Increase MAXMEM.") - - data = stream.read(mem_limit).decode(charset) - - if stream.read(1): # These is more that does not fit mem_limit - raise MultipartError("Request too big. Increase MAXMEM.") - - data = parse_qs(data, keep_blank_values=True, encoding=charset) - - for key, values in data.items(): - for value in values: - forms[key] = value - else: - raise MultipartError("Unsupported content type.") - - except MultipartError: - if strict: - for part in files.values(): - part.close() - raise - - return forms, files diff --git a/src/treq/content.py b/src/treq/content.py index f673b1b5..b4a238f9 100644 --- a/src/treq/content.py +++ b/src/treq/content.py @@ -5,7 +5,7 @@ import json from typing import Any, Callable, Optional, cast -from ._multipart import parse_options_header +from multipart import parse_options_header from twisted.internet.defer import Deferred, succeed from twisted.internet.protocol import Protocol, connectionDone from twisted.python.failure import Failure diff --git a/src/treq/test/test_multipart.py b/src/treq/test/test_multipart.py index 8a7e3a06..4614342d 100644 --- a/src/treq/test/test_multipart.py +++ b/src/treq/test/test_multipart.py @@ -5,7 +5,7 @@ from io import BytesIO -from .._multipart import MultipartParser +from multipart import MultipartParser from twisted.trial import unittest from zope.interface.verify import verifyObject diff --git a/tox.ini b/tox.ini index d3d48642..8c57339d 100644 --- a/tox.ini +++ b/tox.ini @@ -85,4 +85,4 @@ commands = # This is a minimal Black-compatible config. # See https://black.readthedocs.io/en/stable/compatible_configs.html#flake8 max-line-length = 88 -extend-ignore = E203, W503, E266 +extend-ignore = E203, W503