From cd1d137d7905bfd4a6174d45051e04cd4dc97626 Mon Sep 17 00:00:00 2001 From: Jason Madden Date: Thu, 31 Aug 2023 17:05:48 -0500 Subject: [PATCH] gevent.pywsgi: Much improved handling of chunk trailers. Validation is much stricter to the specification. Fixes #1989 (cherry picked from commit 2f53c851eaf926767fbac62385615efd4886221c) --- docs/changes/1989.bugfix | 26 ++ src/gevent/pywsgi.py | 236 +++++++++++++---- src/gevent/subprocess.py | 7 +- src/gevent/testing/testcase.py | 461 +++++++++++++++++++++++++++++++++ src/greentest/test__pywsgi.py | 224 +++++++++++++++- 5 files changed, 886 insertions(+), 68 deletions(-) create mode 100644 docs/changes/1989.bugfix create mode 100644 src/gevent/testing/testcase.py diff --git a/docs/changes/1989.bugfix b/docs/changes/1989.bugfix new file mode 100644 index 00000000..7ce4a93a --- /dev/null +++ b/docs/changes/1989.bugfix @@ -0,0 +1,26 @@ +Make ``gevent.pywsgi`` comply more closely with the HTTP specification +for chunked transfer encoding. In particular, we are much stricter +about trailers, and trailers that are invalid (too long or featuring +disallowed characters) forcibly close the connection to the client +*after* the results have been sent. + +Trailers otherwise continue to be ignored and are not available to the +WSGI application. + +Previously, carefully crafted invalid trailers in chunked requests on +keep-alive connections might appear as two requests to +``gevent.pywsgi``. Because this was handled exactly as a normal +keep-alive connection with two requests, the WSGI application should +handle it normally. However, if you were counting on some upstream +server to filter incoming requests based on paths or header fields, +and the upstream server simply passed trailers through without +validating them, then this embedded second request would bypass those +checks. 
(If the upstream server validated that the trailers meet the +HTTP specification, this could not occur, because characters that are +required in an HTTP request, like a space, are not allowed in +trailers.) CVE-2023-41419 was reserved for this. + +Our thanks to the original reporters, Keran Mu +(mkr22@mails.tsinghua.edu.cn) and Jianjun Chen +(jianjun@tsinghua.edu.cn), from Tsinghua University and Zhongguancun +Laboratory. diff --git a/src/gevent/pywsgi.py b/src/gevent/pywsgi.py index 2726f6d4..c7b2f9c5 100644 --- a/src/gevent/pywsgi.py +++ b/src/gevent/pywsgi.py @@ -8,6 +8,25 @@ WSGI work is handled by :class:`WSGIHandler` --- a new instance is created for each request. The server can be customized to use different subclasses of :class:`WSGIHandler`. +.. important:: + + This server is intended primarily for development and testing, and + secondarily for other "safe" scenarios where it will not be exposed to + potentially malicious input. The code has not been security audited, + and is not intended for direct exposure to the public Internet. For production + usage on the Internet, either choose a production-strength server such as + gunicorn, or put a reverse proxy between gevent and the Internet. + +.. versionchanged:: NEXT + + Complies more closely with the HTTP specification for chunked transfer encoding. + In particular, we are much stricter about trailers, and trailers that + are invalid (too long or featuring disallowed characters) forcibly close + the connection to the client *after* the results have been sent. + + Trailers otherwise continue to be ignored and are not available to the + WSGI application. + """ # FIXME: Can we refactor to make smallor? 
# pylint:disable=too-many-lines @@ -20,10 +39,7 @@ import time import traceback from datetime import datetime -try: - from urllib import unquote -except ImportError: - from urllib.parse import unquote # python 2 pylint:disable=import-error,no-name-in-module +from urllib.parse import unquote from gevent import socket import gevent @@ -51,29 +67,52 @@ __all__ = [ MAX_REQUEST_LINE = 8192 # Weekday and month names for HTTP date/time formatting; always English! -_WEEKDAYNAME = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] -_MONTHNAME = [None, # Dummy so we can use 1-based month numbers +_WEEKDAYNAME = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun") +_MONTHNAME = (None, # Dummy so we can use 1-based month numbers "Jan", "Feb", "Mar", "Apr", "May", "Jun", - "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"] + "Jul", "Aug", "Sep", "Oct", "Nov", "Dec") # The contents of the "HEX" grammar rule for HTTP, upper and lowercase A-F plus digits, # in byte form for comparing to the network. _HEX = string.hexdigits.encode('ascii') +# The characters allowed in "token" rules. + +# token = 1*tchar +# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" +# / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" +# / DIGIT / ALPHA +# ; any VCHAR, except delimiters +# ALPHA = %x41-5A / %x61-7A ; A-Z / a-z +_ALLOWED_TOKEN_CHARS = frozenset( + # Remember we have to be careful because bytestrings + # inexplicably iterate as integers, which are not equal to bytes. 
+ + # explicit chars then DIGIT + (c.encode('ascii') for c in "!#$%&'*+-.^_`|~0123456789") + # Then we add ALPHA +) | {c.encode('ascii') for c in string.ascii_letters} +assert b'A' in _ALLOWED_TOKEN_CHARS + + # Errors _ERRORS = dict() _INTERNAL_ERROR_STATUS = '500 Internal Server Error' _INTERNAL_ERROR_BODY = b'Internal Server Error' -_INTERNAL_ERROR_HEADERS = [('Content-Type', 'text/plain'), - ('Connection', 'close'), - ('Content-Length', str(len(_INTERNAL_ERROR_BODY)))] +_INTERNAL_ERROR_HEADERS = ( + ('Content-Type', 'text/plain'), + ('Connection', 'close'), + ('Content-Length', str(len(_INTERNAL_ERROR_BODY))) +) _ERRORS[500] = (_INTERNAL_ERROR_STATUS, _INTERNAL_ERROR_HEADERS, _INTERNAL_ERROR_BODY) _BAD_REQUEST_STATUS = '400 Bad Request' _BAD_REQUEST_BODY = '' -_BAD_REQUEST_HEADERS = [('Content-Type', 'text/plain'), - ('Connection', 'close'), - ('Content-Length', str(len(_BAD_REQUEST_BODY)))] +_BAD_REQUEST_HEADERS = ( + ('Content-Type', 'text/plain'), + ('Connection', 'close'), + ('Content-Length', str(len(_BAD_REQUEST_BODY))) +) _ERRORS[400] = (_BAD_REQUEST_STATUS, _BAD_REQUEST_HEADERS, _BAD_REQUEST_BODY) _REQUEST_TOO_LONG_RESPONSE = b"HTTP/1.1 414 Request URI Too Long\r\nConnection: close\r\nContent-length: 0\r\n\r\n" @@ -198,23 +237,32 @@ class Input(object): # Read and return the next integer chunk length. If no # chunk length can be read, raises _InvalidClientInput. - # Here's the production for a chunk: - # (http://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html) - # chunk = chunk-size [ chunk-extension ] CRLF - # chunk-data CRLF - # chunk-size = 1*HEX - # chunk-extension= *( ";" chunk-ext-name [ "=" chunk-ext-val ] ) - # chunk-ext-name = token - # chunk-ext-val = token | quoted-string - - # To cope with malicious or broken clients that fail to send valid - # chunk lines, the strategy is to read character by character until we either reach - # a ; or newline. If at any time we read a non-HEX digit, we bail. 
If we hit a - # ;, indicating an chunk-extension, we'll read up to the next - # MAX_REQUEST_LINE characters - # looking for the CRLF, and if we don't find it, we bail. If we read more than 16 hex characters, - # (the number needed to represent a 64-bit chunk size), we bail (this protects us from - # a client that sends an infinite stream of `F`, for example). + # Here's the production for a chunk (actually the whole body): + # (https://www.rfc-editor.org/rfc/rfc7230#section-4.1) + + # chunked-body = *chunk + # last-chunk + # trailer-part + # CRLF + # + # chunk = chunk-size [ chunk-ext ] CRLF + # chunk-data CRLF + # chunk-size = 1*HEXDIG + # last-chunk = 1*("0") [ chunk-ext ] CRLF + # trailer-part = *( header-field CRLF ) + # chunk-data = 1*OCTET ; a sequence of chunk-size octets + + # To cope with malicious or broken clients that fail to send + # valid chunk lines, the strategy is to read character by + # character until we either reach a ; or newline. If at any + # time we read a non-HEX digit, we bail. If we hit a ;, + # indicating an chunk-extension, we'll read up to the next + # MAX_REQUEST_LINE characters ("A server ought to limit the + # total length of chunk extensions received") looking for the + # CRLF, and if we don't find it, we bail. If we read more than + # 16 hex characters, (the number needed to represent a 64-bit + # chunk size), we bail (this protects us from a client that + # sends an infinite stream of `F`, for example). buf = BytesIO() while 1: @@ -222,16 +270,20 @@ class Input(object): if not char: self._chunked_input_error = True raise _InvalidClientInput("EOF before chunk end reached") - if char == b'\r': - break - if char == b';': + + if char in ( + b'\r', # Beginning EOL + b';', # Beginning extension + ): break - if char not in _HEX: + if char not in _HEX: # Invalid data. 
self._chunked_input_error = True raise _InvalidClientInput("Non-hex data", char) + buf.write(char) - if buf.tell() > 16: + + if buf.tell() > 16: # Too many hex bytes self._chunked_input_error = True raise _InvalidClientInput("Chunk-size too large.") @@ -251,11 +303,72 @@ class Input(object): if char == b'\r': # We either got here from the main loop or from the # end of an extension + self.__read_chunk_size_crlf(rfile, newline_only=True) + result = int(buf.getvalue(), 16) + if result == 0: + # The only time a chunk size of zero is allowed is the final + # chunk. It is either followed by another \r\n, or some trailers + # which are then followed by \r\n. + while self.__read_chunk_trailer(rfile): + pass + return result + + # Trailers have the following production (they are a header-field followed by CRLF) + # See above for the definition of "token". + # + # header-field = field-name ":" OWS field-value OWS + # field-name = token + # field-value = *( field-content / obs-fold ) + # field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] + # field-vchar = VCHAR / obs-text + # obs-fold = CRLF 1*( SP / HTAB ) + # ; obsolete line folding + # ; see Section 3.2.4 + + + def __read_chunk_trailer(self, rfile, ): + # With rfile positioned just after a \r\n, read a trailer line. + # Return a true value if a non-empty trailer was read, and + # return false if an empty trailer was read (meaning the trailers are + # done). + # If a single line exceeds the MAX_REQUEST_LINE, raise an exception. + # If the field-name portion contains invalid characters, raise an exception. + + i = 0 + empty = True + seen_field_name = False + while i < MAX_REQUEST_LINE: + char = rfile.read(1) + if char == b'\r': + # Either read the next \n or raise an error. + self.__read_chunk_size_crlf(rfile, newline_only=True) + break + # Not a \r, so we are NOT an empty chunk. + empty = False + if char == b':' and i > 0: + # We're ending the field-name part; stop validating characters. 
+ # Unless : was the first character... + seen_field_name = True + if not seen_field_name and char not in _ALLOWED_TOKEN_CHARS: + raise _InvalidClientInput('Invalid token character: %r' % (char,)) + i += 1 + else: + # We read too much + self._chunked_input_error = True + raise _InvalidClientInput("Too large chunk trailer") + return not empty + + def __read_chunk_size_crlf(self, rfile, newline_only=False): + # Also for safety, correctly verify that we get \r\n when expected. + if not newline_only: char = rfile.read(1) - if char != b'\n': + if char != b'\r': self._chunked_input_error = True - raise _InvalidClientInput("Line didn't end in CRLF") - return int(buf.getvalue(), 16) + raise _InvalidClientInput("Line didn't end in CRLF: %r" % (char,)) + char = rfile.read(1) + if char != b'\n': + self._chunked_input_error = True + raise _InvalidClientInput("Line didn't end in LF: %r" % (char,)) def _chunked_read(self, length=None, use_readline=False): # pylint:disable=too-many-branches @@ -291,7 +404,7 @@ class Input(object): self.position += datalen if self.chunk_length == self.position: - rfile.readline() + self.__read_chunk_size_crlf(rfile) if length is not None: length -= datalen @@ -304,9 +417,9 @@ class Input(object): # determine the next size to read self.chunk_length = self.__read_chunk_length(rfile) self.position = 0 - if self.chunk_length == 0: - # Last chunk. Terminates with a CRLF. - rfile.readline() + # If chunk_length was 0, we already read any trailers and + # validated that we have ended with \r\n\r\n. 
+ return b''.join(response) def read(self, length=None): @@ -521,7 +634,8 @@ class WSGIHandler(object): elif len(words) == 2: self.command, self.path = words if self.command != "GET": - raise _InvalidClientRequest('Expected GET method: %r', raw_requestline) + raise _InvalidClientRequest('Expected GET method; Got command=%r; path=%r; raw=%r' % ( + self.command, self.path, raw_requestline,)) self.request_version = "HTTP/0.9" # QQQ I'm pretty sure we can drop support for HTTP/0.9 else: @@ -936,14 +1050,28 @@ class WSGIHandler(object): finally: try: self.wsgi_input._discard() - except (socket.error, IOError): - # Don't let exceptions during discarding + except _InvalidClientInput: + # This one is deliberately raised to the outer + # scope, because, with the incoming stream in some bad state, + # we can't be sure we can synchronize and properly parse the next + # request. + raise + except socket.error: + # Don't let socket exceptions during discarding # input override any exception that may have been # raised by the application, such as our own _InvalidClientInput. # In the general case, these aren't even worth logging (see the comment # just below) pass - except _InvalidClientInput: + except _InvalidClientInput as ex: + # DO log this one because: + # - Some of the data may have been read and acted on by the + # application; + # - The response may or may not have been sent; + # - It's likely that the client is bad, or malicious, and + # users might wish to take steps to block the client. + self._handle_client_error(ex) + self.close_connection = True self._send_error_response_if_possible(400) except socket.error as ex: if ex.args[0] in (errno.EPIPE, errno.ECONNRESET): @@ -994,16 +1122,22 @@ class WSGIHandler(object): def _handle_client_error(self, ex): # Called for invalid client input # Returns the appropriate error response. 
- if not isinstance(ex, ValueError): + if not isinstance(ex, (ValueError, _InvalidClientInput)): # XXX: Why not self._log_error to send it through the loop's # handle_error method? + # _InvalidClientRequest is a ValueError; _InvalidClientInput is an IOError. traceback.print_exc() if isinstance(ex, _InvalidClientRequest): - # These come with good error messages, and we want to let - # log_error deal with the formatting, especially to handle encoding - self.log_error(*ex.args) + # No formatting needed, that's already been handled. In fact, because the + # formatted message contains user input, it might have a % in it, and attempting + # to format that with no arguments would be an error. + # However, the error messages do not include the requesting IP + # necessarily, so we do add that. + self.log_error('(from %s) %s', self.client_address, ex.formatted_message) else: - self.log_error('Invalid request: %s', str(ex) or ex.__class__.__name__) + self.log_error('Invalid request (from %s): %s', + self.client_address, + str(ex) or ex.__class__.__name__) return ('400', _BAD_REQUEST_RESPONSE) def _headers(self): diff --git a/src/gevent/subprocess.py b/src/gevent/subprocess.py index 2ea165e5..449e5e32 100644 --- a/src/gevent/subprocess.py +++ b/src/gevent/subprocess.py @@ -280,10 +280,11 @@ def check_output(*popenargs, **kwargs): To capture standard error in the result, use ``stderr=STDOUT``:: - >>> check_output(["/bin/sh", "-c", + >>> output = check_output(["/bin/sh", "-c", ... "ls -l non_existent_file ; exit 0"], - ... stderr=STDOUT) - 'ls: non_existent_file: No such file or directory\n' + ... stderr=STDOUT).decode('ascii').strip() + >>> print(output.rsplit(':', 1)[1].strip()) + No such file or directory There is an additional optional argument, "input", allowing you to pass a string to the subprocess's stdin. 
If you use this argument diff --git a/src/gevent/testing/testcase.py b/src/gevent/testing/testcase.py new file mode 100644 index 00000000..ddfe5b99 --- /dev/null +++ b/src/gevent/testing/testcase.py @@ -0,0 +1,461 @@ +# Copyright (c) 2018 gevent community +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. +from __future__ import absolute_import, print_function, division + +import sys +import os.path +from contextlib import contextmanager +from unittest import TestCase as BaseTestCase +from functools import wraps + +import gevent +from gevent._util import LazyOnClass +from gevent._compat import perf_counter +from gevent._compat import get_clock_info +from gevent._hub_local import get_hub_if_exists + +from . import sysinfo +from . import params +from . import leakcheck +from . import errorhandler +from . 
import flaky + +from .patched_tests_setup import get_switch_expected + +class TimeAssertMixin(object): + @flaky.reraises_flaky_timeout() + def assertTimeoutAlmostEqual(self, first, second, places=None, msg=None, delta=None): + try: + self.assertAlmostEqual(first, second, places=places, msg=msg, delta=delta) + except AssertionError: + flaky.reraiseFlakyTestTimeout() + + + if sysinfo.EXPECT_POOR_TIMER_RESOLUTION: + # pylint:disable=unused-argument + def assertTimeWithinRange(self, time_taken, min_time, max_time): + return + else: + def assertTimeWithinRange(self, time_taken, min_time, max_time): + self.assertLessEqual(time_taken, max_time) + self.assertGreaterEqual(time_taken, min_time) + + @contextmanager + def runs_in_given_time(self, expected, fuzzy=None, min_time=None): + if fuzzy is None: + if sysinfo.EXPECT_POOR_TIMER_RESOLUTION or sysinfo.LIBUV: + # The noted timer jitter issues on appveyor/pypy3 + fuzzy = expected * 5.0 + else: + fuzzy = expected / 2.0 + min_time = min_time if min_time is not None else expected - fuzzy + max_time = expected + fuzzy + start = perf_counter() + yield (min_time, max_time) + elapsed = perf_counter() - start + try: + self.assertTrue( + min_time <= elapsed <= max_time, + 'Expected: %r; elapsed: %r; min: %r; max: %r; fuzzy %r; clock_info: %s' % ( + expected, elapsed, min_time, max_time, fuzzy, get_clock_info('perf_counter') + )) + except AssertionError: + flaky.reraiseFlakyTestRaceCondition() + + def runs_in_no_time( + self, + fuzzy=(0.01 if not sysinfo.EXPECT_POOR_TIMER_RESOLUTION and not sysinfo.LIBUV else 1.0)): + return self.runs_in_given_time(0.0, fuzzy) + + +class GreenletAssertMixin(object): + """Assertions related to greenlets.""" + + def assert_greenlet_ready(self, g): + self.assertTrue(g.dead, g) + self.assertTrue(g.ready(), g) + self.assertFalse(g, g) + + def assert_greenlet_not_ready(self, g): + self.assertFalse(g.dead, g) + self.assertFalse(g.ready(), g) + + def assert_greenlet_spawned(self, g): + 
self.assertTrue(g.started, g) + self.assertFalse(g.dead, g) + + # No difference between spawned and switched-to once + assert_greenlet_started = assert_greenlet_spawned + + def assert_greenlet_finished(self, g): + self.assertFalse(g.started, g) + self.assertTrue(g.dead, g) + + +class StringAssertMixin(object): + """ + Assertions dealing with strings. + """ + + @LazyOnClass + def HEX_NUM_RE(self): + import re + return re.compile('-?0x[0123456789abcdef]+L?', re.I) + + def normalize_addr(self, s, replace='X'): + # https://github.com/PyCQA/pylint/issues/1127 + return self.HEX_NUM_RE.sub(replace, s) # pylint:disable=no-member + + def normalize_module(self, s, module=None, replace='module'): + if module is None: + module = type(self).__module__ + + return s.replace(module, replace) + + def normalize(self, s): + return self.normalize_module(self.normalize_addr(s)) + + def assert_nstr_endswith(self, o, val): + s = str(o) + n = self.normalize(s) + self.assertTrue(n.endswith(val), (s, n)) + + def assert_nstr_startswith(self, o, val): + s = str(o) + n = self.normalize(s) + self.assertTrue(n.startswith(val), (s, n)) + + + +class TestTimeout(gevent.Timeout): + _expire_info = '' + + def __init__(self, timeout, method='Not Given'): + gevent.Timeout.__init__( + self, + timeout, + '%r: test timed out\n' % (method,), + ref=False + ) + + def _on_expiration(self, prev_greenlet, ex): + from gevent.util import format_run_info + loop = gevent.get_hub().loop + debug_info = 'N/A' + if hasattr(loop, 'debug'): + debug_info = [str(s) for s in loop.debug()] + run_info = format_run_info() + self._expire_info = 'Loop Debug:\n%s\nRun Info:\n%s' % ( + '\n'.join(debug_info), '\n'.join(run_info) + ) + gevent.Timeout._on_expiration(self, prev_greenlet, ex) + + def __str__(self): + s = gevent.Timeout.__str__(self) + s += self._expire_info + return s + +def _wrap_timeout(timeout, method): + if timeout is None: + return method + + @wraps(method) + def timeout_wrapper(self, *args, **kwargs): + with 
TestTimeout(timeout, method): + return method(self, *args, **kwargs) + + return timeout_wrapper + +def _get_class_attr(classDict, bases, attr, default=AttributeError): + NONE = object() + value = classDict.get(attr, NONE) + if value is not NONE: + return value + for base in bases: + value = getattr(base, attr, NONE) + if value is not NONE: + return value + if default is AttributeError: + raise AttributeError('Attribute %r not found\n%s\n%s\n' % (attr, classDict, bases)) + return default + + +class TestCaseMetaClass(type): + # wrap each test method with + # a) timeout check + # b) fatal error check + # c) restore the hub's error handler (see expect_one_error) + # d) totalrefcount check + def __new__(cls, classname, bases, classDict): + # pylint and pep8 fight over what this should be called (mcs or cls). + # pylint gets it right, but we cant scope disable pep8, so we go with + # its convention. + # pylint: disable=bad-mcs-classmethod-argument + timeout = classDict.get('__timeout__', 'NONE') + if timeout == 'NONE': + timeout = getattr(bases[0], '__timeout__', None) + if sysinfo.RUN_LEAKCHECKS and timeout is not None: + timeout *= 6 + check_totalrefcount = _get_class_attr(classDict, bases, 'check_totalrefcount', True) + + error_fatal = _get_class_attr(classDict, bases, 'error_fatal', True) + uses_handle_error = _get_class_attr(classDict, bases, 'uses_handle_error', True) + # Python 3: must copy, we mutate the classDict. Interestingly enough, + # it doesn't actually error out, but under 3.6 we wind up wrapping + # and re-wrapping the same items over and over and over. + for key, value in list(classDict.items()): + if key.startswith('test') and callable(value): + classDict.pop(key) + # XXX: When did we stop doing this? 
+ #value = wrap_switch_count_check(value) + #value = _wrap_timeout(timeout, value) + error_fatal = getattr(value, 'error_fatal', error_fatal) + if error_fatal: + value = errorhandler.wrap_error_fatal(value) + if uses_handle_error: + value = errorhandler.wrap_restore_handle_error(value) + if check_totalrefcount and sysinfo.RUN_LEAKCHECKS: + value = leakcheck.wrap_refcount(value) + classDict[key] = value + return type.__new__(cls, classname, bases, classDict) + +def _noop(): + return + +class SubscriberCleanupMixin(object): + + def setUp(self): + super(SubscriberCleanupMixin, self).setUp() + from gevent import events + self.__old_subscribers = events.subscribers[:] + + def addSubscriber(self, sub): + from gevent import events + events.subscribers.append(sub) + + def tearDown(self): + from gevent import events + events.subscribers[:] = self.__old_subscribers + super(SubscriberCleanupMixin, self).tearDown() + + +class TestCase(TestCaseMetaClass("NewBase", + (SubscriberCleanupMixin, + TimeAssertMixin, + GreenletAssertMixin, + StringAssertMixin, + BaseTestCase,), + {})): + __timeout__ = params.LOCAL_TIMEOUT if not sysinfo.RUNNING_ON_CI else params.CI_TIMEOUT + + switch_expected = 'default' + #: Set this to true to cause errors that get reported to the hub to + #: always get propagated to the main greenlet. This can be done at the + #: class or method level. + #: .. caution:: This can hide errors and make it look like exceptions + #: are propagated even if they're not. 
+ error_fatal = True + uses_handle_error = True + close_on_teardown = () + # This is really used by the SubscriberCleanupMixin + __old_subscribers = () # pylint:disable=unused-private-member + + def run(self, *args, **kwargs): # pylint:disable=signature-differs + if self.switch_expected == 'default': + self.switch_expected = get_switch_expected(self.fullname) + return super(TestCase, self).run(*args, **kwargs) + + def setUp(self): + super(TestCase, self).setUp() + # Especially if we're running in leakcheck mode, where + # the same test gets executed repeatedly, we need to update the + # current time. Tests don't always go through the full event loop, + # so that doesn't always happen. test__pool.py:TestPoolYYY.test_async + # tends to show timeouts that are too short if we don't. + # XXX: Should some core part of the loop call this? + hub = get_hub_if_exists() + if hub and hub.loop: + hub.loop.update_now() + self.close_on_teardown = [] + self.addCleanup(self._tearDownCloseOnTearDown) + + def tearDown(self): + if getattr(self, 'skipTearDown', False): + del self.close_on_teardown[:] + return + + cleanup = getattr(self, 'cleanup', _noop) + cleanup() + self._error = self._none + super(TestCase, self).tearDown() + + def _tearDownCloseOnTearDown(self): + while self.close_on_teardown: + x = self.close_on_teardown.pop() + close = getattr(x, 'close', x) + try: + close() + except Exception: # pylint:disable=broad-except + pass + + def _close_on_teardown(self, resource): + """ + *resource* either has a ``close`` method, or is a + callable. + """ + self.close_on_teardown.append(resource) + return resource + + @property + def testname(self): + return getattr(self, '_testMethodName', '') or getattr(self, '_TestCase__testMethodName') + + @property + def testcasename(self): + return self.__class__.__name__ + '.' 
+ self.testname + + @property + def modulename(self): + return os.path.basename(sys.modules[self.__class__.__module__].__file__).rsplit('.', 1)[0] + + @property + def fullname(self): + return os.path.splitext(os.path.basename(self.modulename))[0] + '.' + self.testcasename + + _none = (None, None, None) + # (context, kind, value) + _error = _none + + def expect_one_error(self): + self.assertEqual(self._error, self._none) + gevent.get_hub().handle_error = self._store_error + + def _store_error(self, where, t, value, tb): + del tb + if self._error != self._none: + gevent.get_hub().parent.throw(t, value) + else: + self._error = (where, t, value) + + def peek_error(self): + return self._error + + def get_error(self): + try: + return self._error + finally: + self._error = self._none + + def assert_error(self, kind=None, value=None, error=None, where_type=None): + if error is None: + error = self.get_error() + econtext, ekind, evalue = error + if kind is not None: + self.assertIsInstance(kind, type) + self.assertIsNotNone( + ekind, + "Error must not be none %r" % (error,)) + assert issubclass(ekind, kind), error + if value is not None: + if isinstance(value, str): + self.assertEqual(str(evalue), value) + else: + self.assertIs(evalue, value) + if where_type is not None: + self.assertIsInstance(econtext, where_type) + return error + + def assertMonkeyPatchedFuncSignatures(self, mod_name, func_names=(), exclude=()): + # If inspect.getfullargspec is not available, + # We use inspect.getargspec because it's the only thing available + # in Python 2.7, but it is deprecated + # pylint:disable=deprecated-method,too-many-locals + import inspect + import warnings + from gevent.monkey import get_original + # XXX: Very similar to gevent.monkey.patch_module. Should refactor? + gevent_module = getattr(__import__('gevent.' 
+ mod_name), mod_name) + module_name = getattr(gevent_module, '__target__', mod_name) + + funcs_given = True + if not func_names: + funcs_given = False + func_names = getattr(gevent_module, '__implements__') + + for func_name in func_names: + if func_name in exclude: + continue + gevent_func = getattr(gevent_module, func_name) + if not inspect.isfunction(gevent_func) and not funcs_given: + continue + + func = get_original(module_name, func_name) + + try: + with warnings.catch_warnings(): + try: + getfullargspec = inspect.getfullargspec + except AttributeError: + warnings.simplefilter("ignore") + getfullargspec = inspect.getargspec + gevent_sig = getfullargspec(gevent_func) + sig = getfullargspec(func) + except TypeError: + if funcs_given: + raise + # Can't do this one. If they specifically asked for it, + # it's an error, otherwise it's not. + # Python 3 can check a lot more than Python 2 can. + continue + self.assertEqual(sig.args, gevent_sig.args, func_name) + # The next two might not actually matter? + self.assertEqual(sig.varargs, gevent_sig.varargs, func_name) + self.assertEqual(sig.defaults, gevent_sig.defaults, func_name) + if hasattr(sig, 'keywords'): # the old version + msg = (func_name, sig.keywords, gevent_sig.keywords) + try: + self.assertEqual(sig.keywords, gevent_sig.keywords, msg) + except AssertionError: + # Ok, if we take `kwargs` and the original function doesn't, + # that's OK. We have to do that as a compatibility hack sometimes to + # work across multiple python versions. + self.assertIsNone(sig.keywords, msg) + self.assertEqual('kwargs', gevent_sig.keywords) + else: + # The new hotness. Unfortunately, we can't actually check these things + # until we drop Python 2 support from the shared code. The only known place + # this is a problem is python 3.11 socket.create_connection(), which we manually + # ignore. So the checks all pass as is. 
+ self.assertEqual(sig.kwonlyargs, gevent_sig.kwonlyargs, func_name) + self.assertEqual(sig.kwonlydefaults, gevent_sig.kwonlydefaults, func_name) + # Should deal with others: https://docs.python.org/3/library/inspect.html#inspect.getfullargspec + + def assertEqualFlakyRaceCondition(self, a, b): + try: + self.assertEqual(a, b) + except AssertionError: + flaky.reraiseFlakyTestRaceCondition() + + def assertStartsWith(self, it, has_prefix): + self.assertTrue(it.startswith(has_prefix), (it, has_prefix)) + + def assertNotMonkeyPatched(self): + from gevent import monkey + self.assertFalse(monkey.is_anything_patched()) diff --git a/src/greentest/test__pywsgi.py b/src/greentest/test__pywsgi.py index 98631f8a..3bbec9c8 100644 --- a/src/greentest/test__pywsgi.py +++ b/src/greentest/test__pywsgi.py @@ -24,21 +24,12 @@ from gevent import monkey monkey.patch_all(thread=False) -try: - from urllib.parse import parse_qs -except ImportError: - # Python 2 - from cgi import parse_qs +from contextlib import contextmanager +from urllib.parse import parse_qs import os import sys -try: - # On Python 2, we want the C-optimized version if - # available; it has different corner-case behaviour than - # the Python implementation, and it used by socket.makefile - # by default. - from cStringIO import StringIO -except ImportError: - from io import BytesIO as StringIO +from io import BytesIO as StringIO + import weakref from wsgiref.validate import validator @@ -165,7 +156,13 @@ class Response(object): assert self.body == body, 'Unexpected body: %r (expected %r)\n%s' % (self.body, body, self) @classmethod - def read(cls, fd, code=200, reason='default', version='1.1', body=None, chunks=None, content_length=None): + def read(cls, fd, code=200, reason='default', version='1.1', + body=None, chunks=None, content_length=None): + """ + Read an HTTP response, optionally perform assertions, + and return the Response object. 
+ """ + # pylint:disable=too-many-branches _status_line, headers = read_headers(fd) self = cls(_status_line, headers) if code is not None: @@ -583,6 +580,39 @@ class TestChunkedPost(TestCase): @staticmethod def application(env, start_response): + start_response('200 OK', [('Content-Type', 'text/plain')]) + if env['PATH_INFO'] == '/readline': + data = env['wsgi.input'].readline(-1) + return [data] + + def test_negative_chunked_readline(self): + data = (b'POST /readline HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n' + b'Transfer-Encoding: chunked\r\n\r\n' + b'2\r\noh\r\n4\r\n hai\r\n0\r\n\r\n') + with self.makefile() as fd: + fd.write(data) + read_http(fd, body='oh hai') + + def test_negative_nonchunked_readline(self): + data = (b'POST /readline HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n' + b'Content-Length: 6\r\n\r\n' + b'oh hai') + with self.makefile() as fd: + fd.write(data) + read_http(fd, body='oh hai') + + +class TestChunkedPost(TestCase): + + calls = 0 + + def setUp(self): + super().setUp() + self.calls = 0 + + def application(self, env, start_response): + self.calls += 1 + self.assertTrue(env.get('wsgi.input_terminated')) start_response('200 OK', [('Content-Type', 'text/plain')]) if env['PATH_INFO'] == '/a': data = env['wsgi.input'].read(6) @@ -593,6 +623,8 @@ class TestChunkedPost(TestCase): elif env['PATH_INFO'] == '/c': return [x for x in iter(lambda: env['wsgi.input'].read(1), b'')] + return [b'We should not get here', env['PATH_INFO'].encode('ascii')] + def test_014_chunked_post(self): fd = self.makefile() data = (b'POST /a HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n' @@ -660,6 +692,170 @@ class TestChunkedPost(TestCase): fd.write(data) read_http(fd, code=400) + def test_trailers_keepalive_ignored(self): + # Trailers after a chunk are ignored. 
+ data = (
+ b'POST /a HTTP/1.1\r\n'
+ b'Host: localhost\r\n'
+ b'Connection: keep-alive\r\n'
+ b'Transfer-Encoding: chunked\r\n'
+ b'\r\n'
+ b'2\r\noh\r\n'
+ b'4\r\n hai\r\n'
+ b'0\r\n' # last-chunk
+ # Normally the final CRLF would go here, but if you put in a
+ # trailer, it doesn't.
+ b'trailer1: value1\r\n'
+ b'trailer2: value2\r\n'
+ b'\r\n' # Really terminate the chunk.
+ b'POST /a HTTP/1.1\r\n'
+ b'Host: localhost\r\n'
+ b'Connection: close\r\n'
+ b'Transfer-Encoding: chunked\r\n'
+ b'\r\n'
+ b'2\r\noh\r\n'
+ b'4\r\n bye\r\n'
+ b'0\r\n' # last-chunk
+ )
+ with self.makefile() as fd:
+ fd.write(data)
+ read_http(fd, body='oh hai')
+ read_http(fd, body='oh bye')
+
+ self.assertEqual(self.calls, 2)
+
+ def test_trailers_too_long(self):
+ # Trailers that are too long force the connection to be aborted.
+ data = (
+ b'POST /a HTTP/1.1\r\n'
+ b'Host: localhost\r\n'
+ b'Connection: keep-alive\r\n'
+ b'Transfer-Encoding: chunked\r\n'
+ b'\r\n'
+ b'2\r\noh\r\n'
+ b'4\r\n hai\r\n'
+ b'0\r\n' # last-chunk
+ # Normally the final CRLF would go here, but if you put in a
+ # trailer, it doesn't.
+ b'trailer2: value2' # note lack of \r\n
+ )
+ data += b't' * pywsgi.MAX_REQUEST_LINE
+ # No termination, because we detect the trailer as being too
+ # long and abort the connection.
+ with self.makefile() as fd:
+ fd.write(data)
+ read_http(fd, body='oh hai')
+ with self.assertRaises(ConnectionClosed):
+ read_http(fd, body='oh bye')
+
+ def test_trailers_request_smuggling_missing_last_chunk_keep_alive(self):
+ # When something that looks like a request line comes in the trailer
+ # as the first line, immediately after an invalid last chunk.
+ # We detect this and abort the connection, because the
+ # whitespace in the GET line isn't a legal part of a trailer.
+ # If we didn't abort the connection, then, because we specified
+ # keep-alive, the server would be hanging around waiting for more input.
+ data = ( + b'POST /a HTTP/1.1\r\n' + b'Host: localhost\r\n' + b'Connection: keep-alive\r\n' + b'Transfer-Encoding: chunked\r\n' + b'\r\n' + b'2\r\noh\r\n' + b'4\r\n hai\r\n' + b'0' # last-chunk, but missing the \r\n + # Normally the final CRLF would go here, but if you put in a + # trailer, it doesn't. + # b'\r\n' + b'GET /path2?a=:123 HTTP/1.1\r\n' + b'Host: a.com\r\n' + b'Connection: close\r\n' + b'\r\n' + ) + with self.makefile() as fd: + fd.write(data) + read_http(fd, body='oh hai') + with self.assertRaises(ConnectionClosed): + read_http(fd) + + self.assertEqual(self.calls, 1) + + def test_trailers_request_smuggling_missing_last_chunk_close(self): + # Same as the above, except the trailers are actually valid + # and since we ask to close the connection we don't get stuck + # waiting for more input. + data = ( + b'POST /a HTTP/1.1\r\n' + b'Host: localhost\r\n' + b'Connection: close\r\n' + b'Transfer-Encoding: chunked\r\n' + b'\r\n' + b'2\r\noh\r\n' + b'4\r\n hai\r\n' + b'0\r\n' # last-chunk + # Normally the final CRLF would go here, but if you put in a + # trailer, it doesn't. + # b'\r\n' + b'GETpath2a:123 HTTP/1.1\r\n' + b'Host: a.com\r\n' + b'Connection: close\r\n' + b'\r\n' + ) + with self.makefile() as fd: + fd.write(data) + read_http(fd, body='oh hai') + with self.assertRaises(ConnectionClosed): + read_http(fd) + + def test_trailers_request_smuggling_header_first(self): + # When something that looks like a header comes in the first line. 
+ data = ( + b'POST /a HTTP/1.1\r\n' + b'Host: localhost\r\n' + b'Connection: keep-alive\r\n' + b'Transfer-Encoding: chunked\r\n' + b'\r\n' + b'2\r\noh\r\n' + b'4\r\n hai\r\n' + b'0\r\n' # last-chunk, but only one CRLF + b'Header: value\r\n' + b'GET /path2?a=:123 HTTP/1.1\r\n' + b'Host: a.com\r\n' + b'Connection: close\r\n' + b'\r\n' + ) + with self.makefile() as fd: + fd.write(data) + read_http(fd, body='oh hai') + with self.assertRaises(ConnectionClosed): + read_http(fd, code=400) + + self.assertEqual(self.calls, 1) + + def test_trailers_request_smuggling_request_terminates_then_header(self): + data = ( + b'POST /a HTTP/1.1\r\n' + b'Host: localhost\r\n' + b'Connection: keep-alive\r\n' + b'Transfer-Encoding: chunked\r\n' + b'\r\n' + b'2\r\noh\r\n' + b'4\r\n hai\r\n' + b'0\r\n' # last-chunk + b'\r\n' + b'Header: value' + b'GET /path2?a=:123 HTTP/1.1\r\n' + b'Host: a.com\r\n' + b'Connection: close\r\n' + b'\r\n' + ) + with self.makefile() as fd: + fd.write(data) + read_http(fd, body='oh hai') + read_http(fd, code=400) + + self.assertEqual(self.calls, 1) + class TestUseWrite(TestCase): -- 2.42.0