--- a/paramiko/__init__.py
+++ b/paramiko/__init__.py
@@ -19,12 +19,15 @@
 # flake8: noqa
 import sys
 from paramiko._version import __version__, __version_info__
-from paramiko.transport import SecurityOptions, Transport
+from paramiko.transport import (
+    SecurityOptions,
+    Transport,
+)
 from paramiko.client import (
-    SSHClient,
-    MissingHostKeyPolicy,
     AutoAddPolicy,
+    MissingHostKeyPolicy,
     RejectPolicy,
+    SSHClient,
     WarningPolicy,
 )
 from paramiko.auth_handler import AuthHandler
@@ -43,6 +46,7 @@ from paramiko.ssh_exception import (
     ConfigParseError,
     CouldNotCanonicalize,
     IncompatiblePeer,
+    MessageOrderError,
     PasswordRequiredException,
     ProxyCommandFailure,
     SSHException,
--- a/paramiko/packet.py
+++ b/paramiko/packet.py
@@ -86,6 +86,7 @@ class Packetizer(object):
         self.__need_rekey = False
         self.__init_count = 0
         self.__remainder = bytes()
+        self._initial_kex_done = False
 
         # used for noticing when to re-key:
         self.__sent_bytes = 0
@@ -130,6 +131,12 @@ class Packetizer(object):
     def closed(self):
         return self.__closed
 
+    def reset_seqno_out(self):
+        self.__sequence_number_out = 0
+
+    def reset_seqno_in(self):
+        self.__sequence_number_in = 0
+
     def set_log(self, log):
         """
         Set the Python log object to use for logging.
@@ -425,9 +432,12 @@ class Packetizer(object):
                 out += compute_hmac(
                     self.__mac_key_out, payload, self.__mac_engine_out
                 )[: self.__mac_size_out]
-            self.__sequence_number_out = (
-                self.__sequence_number_out + 1
-            ) & xffffffff
+            next_seq = (self.__sequence_number_out + 1) & xffffffff
+            if next_seq == 0 and not self._initial_kex_done:
+                raise SSHException(
+                    "Sequence number rolled over during initial kex!"
+                )
+            self.__sequence_number_out = next_seq
             self.write_all(out)
 
             self.__sent_bytes += len(out)
@@ -531,7 +541,12 @@ class Packetizer(object):
 
         msg = Message(payload[1:])
         msg.seqno = self.__sequence_number_in
-        self.__sequence_number_in = (self.__sequence_number_in + 1) & xffffffff
+        next_seq = (self.__sequence_number_in + 1) & xffffffff
+        if next_seq == 0 and not self._initial_kex_done:
+            raise SSHException(
+                "Sequence number rolled over during initial kex!"
+            )
+        self.__sequence_number_in = next_seq
 
         # check for rekey
         raw_packet_size = packet_size + self.__mac_size_in + 4
--- a/paramiko/ssh_exception.py
+++ b/paramiko/ssh_exception.py
@@ -235,3 +235,13 @@ class ConfigParseError(SSHException):
     """
 
     pass
+
+
+class MessageOrderError(SSHException):
+    """
+    Out-of-order protocol messages were received, violating "strict kex" mode.
+
+    .. versionadded:: 3.4
+    """
+
+    pass
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -106,11 +106,12 @@ from paramiko.ecdsakey import ECDSAKey
 from paramiko.server import ServerInterface
 from paramiko.sftp_client import SFTPClient
 from paramiko.ssh_exception import (
-    SSHException,
     BadAuthenticationType,
     ChannelException,
     IncompatiblePeer,
+    MessageOrderError,
     ProxyCommandFailure,
+    SSHException,
 )
 from paramiko.util import retry_on_signal, ClosingContextManager, clamp_value
 
@@ -329,6 +330,8 @@ class Transport(threading.Thread, Closin
         gss_deleg_creds=True,
         disabled_algorithms=None,
         server_sig_algs=True,
+        strict_kex=True,
+        packetizer_class=None,
     ):
         """
         Create a new SSH session over an existing socket, or socket-like
@@ -395,6 +398,13 @@ class Transport(threading.Thread, Closin
             Whether to send an extra message to compatible clients, in server
             mode, with a list of supported pubkey algorithms. Default:
             ``True``.
+        :param bool strict_kex:
+            Whether to advertise (and implement, if client also advertises
+            support for) a "strict kex" mode for safer handshaking. Default:
+            ``True``.
+        :param packetizer_class:
+            Which class to use for instantiating the internal packet handler.
+            Default: ``None`` (i.e.: use `Packetizer` as normal).
 
         .. versionchanged:: 1.15
             Added the ``default_window_size`` and ``default_max_packet_size``
@@ -405,10 +415,16 @@ class Transport(threading.Thread, Closin
             Added the ``disabled_algorithms`` kwarg.
         .. versionchanged:: 2.9
             Added the ``server_sig_algs`` kwarg.
+        .. versionchanged:: 3.4
+            Added the ``strict_kex`` kwarg.
+        .. versionchanged:: 3.4
+            Added the ``packetizer_class`` kwarg.
         """
         self.active = False
         self.hostname = None
         self.server_extensions = {}
+        self.advertise_strict_kex = strict_kex
+        self.agreed_on_strict_kex = False
 
         if isinstance(sock, string_types):
             # convert "host:port" into (host, port)
@@ -450,7 +466,7 @@ class Transport(threading.Thread, Closin
         self.sock.settimeout(self._active_check_timeout)
 
         # negotiated crypto parameters
-        self.packetizer = Packetizer(sock)
+        self.packetizer = (packetizer_class or Packetizer)(sock)
         self.local_version = "SSH-" + self._PROTO_ID + "-" + self._CLIENT_ID
         self.remote_version = ""
         self.local_cipher = self.remote_cipher = ""
@@ -524,6 +540,20 @@ class Transport(threading.Thread, Closin
         self.server_accept_cv = threading.Condition(self.lock)
         self.subsystem_table = {}
 
+        # Handler table, now set at init time for easier per-instance
+        # manipulation and subclass twiddling.
+        self._handler_table = {
+            MSG_EXT_INFO: self._parse_ext_info,
+            MSG_NEWKEYS: self._parse_newkeys,
+            MSG_GLOBAL_REQUEST: self._parse_global_request,
+            MSG_REQUEST_SUCCESS: self._parse_request_success,
+            MSG_REQUEST_FAILURE: self._parse_request_failure,
+            MSG_CHANNEL_OPEN_SUCCESS: self._parse_channel_open_success,
+            MSG_CHANNEL_OPEN_FAILURE: self._parse_channel_open_failure,
+            MSG_CHANNEL_OPEN: self._parse_channel_open,
+            MSG_KEXINIT: self._negotiate_keys,
+        }
+
     def _filter_algorithm(self, type_):
         default = getattr(self, "_preferred_{}".format(type_))
         return tuple(
@@ -2067,6 +2097,20 @@ class Transport(threading.Thread, Closin
         # be empty.)
         return reply
 
+    def _enforce_strict_kex(self, ptype):
+        """
+        Conditionally raise `MessageOrderError` during strict initial kex.
+
+        This method should only be called inside code that handles non-KEXINIT
+        messages; it does not interrogate ``ptype`` besides using it to log
+        more accurately.
+        """
+        if self.agreed_on_strict_kex and not self.initial_kex_done:
+            name = MSG_NAMES.get(ptype, "msg {}".format(ptype))
+            raise MessageOrderError(
+                "In strict-kex mode, but was sent {!r}!".format(name)
+            )
+
     def run(self):
         # (use the exposed "run" method, because if we specify a thread target
         # of a private method, threading.Thread will keep a reference to it
@@ -2111,16 +2155,21 @@ class Transport(threading.Thread, Closin
                     except NeedRekeyException:
                         continue
                     if ptype == MSG_IGNORE:
+                        self._enforce_strict_kex(ptype)
                         continue
                     elif ptype == MSG_DISCONNECT:
                         self._parse_disconnect(m)
                         break
                     elif ptype == MSG_DEBUG:
+                        self._enforce_strict_kex(ptype)
                         self._parse_debug(m)
                         continue
                     if len(self._expected_packet) > 0:
                         if ptype not in self._expected_packet:
-                            raise SSHException(
+                            exc_class = SSHException
+                            if self.agreed_on_strict_kex:
+                                exc_class = MessageOrderError
+                            raise exc_class(
                                 "Expecting packet from {!r}, got {:d}".format(
                                     self._expected_packet, ptype
                                 )
@@ -2135,7 +2184,7 @@ class Transport(threading.Thread, Closin
                         if error_msg:
                             self._send_message(error_msg)
                         else:
-                            self._handler_table[ptype](self, m)
+                            self._handler_table[ptype](m)
                     elif ptype in self._channel_handler_table:
                         chanid = m.get_int()
                         chan = self._channels.get(chanid)
@@ -2342,12 +2391,18 @@ class Transport(threading.Thread, Closin
             )
         else:
             available_server_keys = self.preferred_keys
-            # Signal support for MSG_EXT_INFO.
+            # Signal support for MSG_EXT_INFO so server will send it to us.
             # NOTE: doing this here handily means we don't even consider this
             # value when agreeing on real kex algo to use (which is a common
             # pitfall when adding this apparently).
             kex_algos.append("ext-info-c")
 
+        # Similar to ext-info, but used in both server modes, so done outside
+        # of above if/else.
+        if self.advertise_strict_kex:
+            which = "s" if self.server_mode else "c"
+            kex_algos.append("kex-strict-{}-v00@openssh.com".format(which))
+
         m = Message()
         m.add_byte(cMSG_KEXINIT)
         m.add_bytes(os.urandom(16))
@@ -2388,7 +2443,8 @@ class Transport(threading.Thread, Closin
 
     def _get_latest_kex_init(self):
         return self._really_parse_kex_init(
-            Message(self._latest_kex_init), ignore_first_byte=True
+            Message(self._latest_kex_init),
+            ignore_first_byte=True,
         )
 
     def _parse_kex_init(self, m):
@@ -2427,10 +2483,39 @@ class Transport(threading.Thread, Closin
         self._log(DEBUG, "kex follows: {}".format(kex_follows))
         self._log(DEBUG, "=== Key exchange agreements ===")
 
-        # Strip out ext-info "kex algo"
+        # Record, and strip out, ext-info and/or strict-kex non-algorithms
         self._remote_ext_info = None
-        if kex_algo_list[-1].startswith("ext-info-"):
-            self._remote_ext_info = kex_algo_list.pop()
+        self._remote_strict_kex = None
+        to_pop = []
+        for i, algo in enumerate(kex_algo_list):
+            if algo.startswith("ext-info-"):
+                self._remote_ext_info = algo
+                to_pop.insert(0, i)
+            elif algo.startswith("kex-strict-"):
+                # NOTE: this is what we are expecting from the /remote/ end.
+                which = "c" if self.server_mode else "s"
+                expected = "kex-strict-{}-v00@openssh.com".format(which)
+                # Set strict mode if agreed.
+                self.agreed_on_strict_kex = (
+                    algo == expected and self.advertise_strict_kex
+                )
+                self._log(
+                    DEBUG, "Strict kex mode: %s" % self.agreed_on_strict_kex
+                )
+                to_pop.insert(0, i)
+        for i in to_pop:
+            kex_algo_list.pop(i)
+
+        # CVE mitigation: expect zeroed-out seqno anytime we are performing kex
+        # init phase, if strict mode was negotiated.
+        if (
+            self.agreed_on_strict_kex
+            and not self.initial_kex_done
+            and m.seqno != 0
+        ):
+            raise MessageOrderError(
+                "In strict-kex mode, but KEXINIT was not the first packet!"
+            )
 
         # as a server, we pick the first item in the client's list that we
         # support.
@@ -2631,6 +2716,13 @@ class Transport(threading.Thread, Closin
         ):
             self._log(DEBUG, "Switching on inbound compression ...")
             self.packetizer.set_inbound_compressor(compress_in())
+        # Reset inbound sequence number if strict mode.
+        if self.agreed_on_strict_kex:
+            self._log(
+                DEBUG,
+                "Resetting inbound seqno after NEWKEYS due to strict mode",
+            )
+            self.packetizer.reset_seqno_in()
 
     def _activate_outbound(self):
         """switch on newly negotiated encryption parameters for
@@ -2638,6 +2730,13 @@ class Transport(threading.Thread, Closin
         m = Message()
         m.add_byte(cMSG_NEWKEYS)
         self._send_message(m)
+        # Reset outbound sequence number if strict mode.
+        if self.agreed_on_strict_kex:
+            self._log(
+                DEBUG,
+                "Resetting outbound seqno after NEWKEYS due to strict mode",
+            )
+            self.packetizer.reset_seqno_out()
         block_size = self._cipher_info[self.local_cipher]["block-size"]
         if self.server_mode:
             IV_out = self._compute_key("B", block_size)
@@ -2728,7 +2827,9 @@ class Transport(threading.Thread, Closin
             self.auth_handler = AuthHandler(self)
         if not self.initial_kex_done:
             # this was the first key exchange
-            self.initial_kex_done = True
+            # (also signal to packetizer as it sometimes wants to know this
+            # status as well, eg when seqnos rollover)
+            self.initial_kex_done = self.packetizer._initial_kex_done = True
         # send an event?
if self.completion_event is not None: self.completion_event.set() @@ -2982,18 +3083,6 @@ class Transport(threading.Thread, Closin finally: self.lock.release() - _handler_table = { - MSG_EXT_INFO: _parse_ext_info, - MSG_NEWKEYS: _parse_newkeys, - MSG_GLOBAL_REQUEST: _parse_global_request, - MSG_REQUEST_SUCCESS: _parse_request_success, - MSG_REQUEST_FAILURE: _parse_request_failure, - MSG_CHANNEL_OPEN_SUCCESS: _parse_channel_open_success, - MSG_CHANNEL_OPEN_FAILURE: _parse_channel_open_failure, - MSG_CHANNEL_OPEN: _parse_channel_open, - MSG_KEXINIT: _negotiate_keys, - } - _channel_handler_table = { MSG_CHANNEL_SUCCESS: Channel._request_success, MSG_CHANNEL_FAILURE: Channel._request_failed, --- a/tests/test_transport.py +++ b/tests/test_transport.py @@ -23,12 +23,14 @@ Some unit tests for the ssh2 protocol in from __future__ import with_statement from binascii import hexlify +import itertools from contextlib import contextmanager import select import socket import time import threading import random +import sys import unittest try: @@ -37,14 +39,15 @@ except ImportError: from mock import Mock from paramiko import ( + AuthenticationException, AuthHandler, ChannelException, DSSKey, + IncompatiblePeer, + MessageOrderError, Packetizer, RSAKey, SSHException, - AuthenticationException, - IncompatiblePeer, SecurityOptions, ServerInterface, Transport, @@ -57,7 +60,11 @@ from paramiko.common import ( MAX_WINDOW_SIZE, MIN_PACKET_SIZE, MIN_WINDOW_SIZE, + MSG_CHANNEL_OPEN, + MSG_DEBUG, + MSG_IGNORE, MSG_KEXINIT, + MSG_UNIMPLEMENTED, MSG_USERAUTH_SUCCESS, cMSG_CHANNEL_WINDOW_ADJUST, cMSG_UNIMPLEMENTED, @@ -67,6 +74,7 @@ from paramiko.message import Message from .util import needs_builtin, _support, requires_sha1_signing, slow from .loop import LoopSocket +from pytest import mark, raises LONG_BANNER = """\ @@ -154,6 +162,10 @@ class NullServer(ServerInterface): self._tcpip_dest = destination return OPEN_SUCCEEDED +# Faux 'packet type' we do not implement and are unlikely ever to 
(but which is +# technically "within spec" re RFC 4251 +MSG_FUGGEDABOUTIT = 253 + class TransportTest(unittest.TestCase): def setUp(self): @@ -1119,6 +1131,16 @@ class TransportTest(unittest.TestCase): # Real fix's behavior self._expect_unimplemented() + def test_can_override_packetizer_used(self): + class MyPacketizer(Packetizer): + pass + + # control case + assert Transport(sock=LoopSocket()).packetizer.__class__ is Packetizer + # overridden case + tweaked = Transport(sock=LoopSocket(), packetizer_class=MyPacketizer) + assert tweaked.packetizer.__class__ is MyPacketizer + class AlgorithmDisablingTests(unittest.TestCase): def test_preferred_lists_default_to_private_attribute_contents(self): @@ -1202,10 +1224,17 @@ def server( connect=None, pubkeys=None, catch_error=False, + transport_factory=None, + server_transport_factory=None, + defer=False, + skip_verify=False, ): """ SSH server contextmanager for testing. + Yields a tuple of ``(tc, ts)`` (client- and server-side `Transport` + objects), or ``(tc, ts, err)`` when ``catch_error==True``. + :param hostkey: Host key to use for the server; if None, loads ``test_rsa.key``. @@ -1222,6 +1251,17 @@ def server( :param catch_error: Whether to capture connection errors & yield from contextmanager. Necessary for connection_time exception testing. + :param transport_factory: + Like the same-named param in SSHClient: which Transport class to use. + :param server_transport_factory: + Like ``transport_factory``, but only impacts the server transport. + :param bool defer: + Whether to defer authentication during connecting. + + This is really just shorthand for ``connect={}`` which would do roughly + the same thing. Also: this implies skip_verify=True automatically! + :param bool skip_verify: + Whether NOT to do the default "make sure auth passed" check. 
""" if init is None: init = {} @@ -1230,12 +1270,21 @@ def server( if client_init is None: client_init = {} if connect is None: - connect = dict(username="slowdive", password="pygmalion") + # No auth at all please + if defer: + connect = dict() + # Default username based auth + else: + connect = dict(username="slowdive", password="pygmalion") socks = LoopSocket() sockc = LoopSocket() sockc.link(socks) - tc = Transport(sockc, **dict(init, **client_init)) - ts = Transport(socks, **dict(init, **server_init)) + if transport_factory is None: + transport_factory = Transport + if server_transport_factory is None: + server_transport_factory = transport_factory + tc = transport_factory(sockc, **dict(init, **client_init)) + ts = server_transport_factory(socks, **dict(init, **server_init)) if hostkey is None: hostkey = RSAKey.from_private_key_file(_support("test_rsa.key")) @@ -1354,10 +1403,14 @@ class TestSHA2SignatureKeyExchange(unitt class TestExtInfo(unittest.TestCase): - def test_ext_info_handshake(self): + def test_ext_info_handshake_exposed_in_client_kexinit(self): with server() as (tc, _): + # NOTE: this is latest KEXINIT /sent by us/ (Transport retains it) kex = tc._get_latest_kex_init() - assert kex["kex_algo_list"][-1] == "ext-info-c" + # flag in KexAlgorithms list + assert "ext-info-c" in kex["kex_algo_list"] + # data stored on Transport after hearing back from a compatible + # server (such as ourselves in server mode) assert tc.server_extensions == { "server-sig-algs": b"ssh-ed25519,ecdsa-sha2-nistp256,ecdsa-sha2-nistp384,ecdsa-sha2-nistp521,rsa-sha2-512,rsa-sha2-256,ssh-rsa,ssh-dss" # noqa } @@ -1463,3 +1516,187 @@ class TestSHA2SignaturePubkeys(unittest. ) as (tc, ts): assert tc.is_authenticated() assert tc._agreed_pubkey_algorithm == "rsa-sha2-256" + + +class BadSeqPacketizer(Packetizer): + def read_message(self): + cmd, msg = super().read_message() + # Only mess w/ seqno if kexinit. 
+        if cmd == MSG_KEXINIT:
+            # NOTE: this is /only/ the copy of the seqno which gets
+            # transmitted up from Packetizer; it's not modifying
+            # Packetizer's own internal seqno. For these tests,
+            # modifying the latter isn't required, and is also harder
+            # to do w/o triggering MAC mismatches.
+            msg.seqno = 17  # arbitrary nonzero int
+        return cmd, msg
+
+
+class TestStrictKex:
+    def test_kex_algos_includes_kex_strict_c(self):
+        with server() as (tc, _):
+            kex = tc._get_latest_kex_init()
+            assert "kex-strict-c-v00@openssh.com" in kex["kex_algo_list"]
+
+    @mark.parametrize(
+        "server_active,client_active",
+        itertools.product([True, False], repeat=2),
+    )
+    def test_mode_agreement(self, server_active, client_active):
+        with server(
+            server_init=dict(strict_kex=server_active),
+            client_init=dict(strict_kex=client_active),
+        ) as (tc, ts):
+            if server_active and client_active:
+                assert tc.agreed_on_strict_kex is True
+                assert ts.agreed_on_strict_kex is True
+            else:
+                assert tc.agreed_on_strict_kex is False
+                assert ts.agreed_on_strict_kex is False
+
+    def test_mode_advertised_by_default(self):
+        # NOTE: no explicit strict_kex overrides...
+        with server() as (tc, ts):
+            assert all(
+                (
+                    tc.advertise_strict_kex,
+                    tc.agreed_on_strict_kex,
+                    ts.advertise_strict_kex,
+                    ts.agreed_on_strict_kex,
+                )
+            )
+
+    @mark.parametrize(
+        "ptype",
+        (
+            # "normal" but definitely out-of-order message
+            MSG_CHANNEL_OPEN,
+            # Normally ignored, but not in this case
+            MSG_IGNORE,
+            # Normally triggers debug parsing, but not in this case
+            MSG_DEBUG,
+            # Normally ignored, but...you get the idea
+            MSG_UNIMPLEMENTED,
+            # Not real, so would normally trigger us /sending/
+            # MSG_UNIMPLEMENTED, but...
+            MSG_FUGGEDABOUTIT,
+        ),
+    )
+    def test_MessageOrderError_non_kex_messages_in_initial_kex(self, ptype):
+        class AttackTransport(Transport):
+            # Easiest apparent spot on server side which is:
+            # - late enough for both ends to have handshook on strict mode
+            # - early enough to be in the window of opportunity for Terrapin
+            # attack; essentially during actual kex, when the engine is
+            # waiting for things like MSG_KEXECDH_REPLY (for eg curve25519).
+            def _negotiate_keys(self, m):
+                self.clear_to_send_lock.acquire()
+                try:
+                    self.clear_to_send.clear()
+                finally:
+                    self.clear_to_send_lock.release()
+                if self.local_kex_init is None:
+                    # remote side wants to renegotiate
+                    self._send_kex_init()
+                self._parse_kex_init(m)
+                # Here, we would normally kick over to kex_engine, but instead
+                # we want the server to send the OOO message.
+                m = Message()
+                m.add_byte(byte_chr(ptype))
+                # rest of packet unnecessary...
+                self._send_message(m)
+
+        with raises(MessageOrderError):
+            with server(server_transport_factory=AttackTransport) as (tc, _):
+                pass  # above should run and except during connect()
+
+    def test_SSHException_raised_on_out_of_order_messages_when_not_strict(
+        self,
+    ):
+        # This is kind of dumb (either situation is still fatal!) but whatever,
+        # may as well be strict with our new strict flag...
+        with raises(SSHException) as info:  # would be true either way, but
+            with server(
+                client_init=dict(strict_kex=False),
+            ) as (tc, _):
+                tc._expect_packet(MSG_KEXINIT)
+                tc.open_session()
+        assert info.type is SSHException  # NOT MessageOrderError!
+
+    def test_error_not_raised_when_kexinit_not_seq_0_but_unstrict(self):
+        with server(
+            client_init=dict(
+                # Disable strict kex
+                strict_kex=False,
+                # Give our clientside a packetizer that sets all kexinit
+                # Message objects to have .seqno==17, which would trigger the
+                # new logic if we'd forgotten to wrap it in strict-kex check
+                packetizer_class=BadSeqPacketizer,
+            ),
+        ):
+            pass  # kexinit happens at connect...
+
+    def test_MessageOrderError_raised_when_kexinit_not_seq_0_and_strict(self):
+        with raises(MessageOrderError):
+            with server(
+                # Give our clientside a packetizer that sets all kexinit
+                # Message objects to have .seqno==17, which should trigger the
+                # new logic (given we are NOT disabling strict-mode)
+                client_init=dict(packetizer_class=BadSeqPacketizer),
+            ):
+                pass  # kexinit happens at connect...
+
+    def test_sequence_numbers_reset_on_newkeys_when_strict(self):
+        with server(defer=True) as (tc, ts):
+            # When in strict mode, these should all be zero or close to it
+            # (post-kexinit, pre-auth).
+            # Server->client will be 1 (EXT_INFO got sent after NEWKEYS)
+            assert tc.packetizer._Packetizer__sequence_number_in == 1
+            assert ts.packetizer._Packetizer__sequence_number_out == 1
+            # Client->server will be 0
+            assert tc.packetizer._Packetizer__sequence_number_out == 0
+            assert ts.packetizer._Packetizer__sequence_number_in == 0
+
+    def test_sequence_numbers_not_reset_on_newkeys_when_not_strict(self):
+        with server(defer=True, client_init=dict(strict_kex=False)) as (
+            tc,
+            ts,
+        ):
+            # When not in strict mode, these will all be ~3-4 or so
+            # (post-kexinit, pre-auth). Not encoding exact values as it will
+            # change anytime we mess with the test harness...
+            assert tc.packetizer._Packetizer__sequence_number_in != 0
+            assert tc.packetizer._Packetizer__sequence_number_out != 0
+            assert ts.packetizer._Packetizer__sequence_number_in != 0
+            assert ts.packetizer._Packetizer__sequence_number_out != 0
+
+    def test_sequence_number_rollover_detected(self):
+        class RolloverTransport(Transport):
+            def __init__(self, *args, **kwargs):
+                super(RolloverTransport, self).__init__(*args, **kwargs)
+                # Induce an about-to-rollover seqno, such that it rolls over
+                # during initial kex.
+                setattr(
+                    self.packetizer,
+                    "_Packetizer__sequence_number_in",
+                    0xFFFFFFFF,  # max uint32; +1 wraps to 0
+                )
+                setattr(
+                    self.packetizer,
+                    "_Packetizer__sequence_number_out",
+                    0xFFFFFFFF,  # max uint32; +1 wraps to 0
+                )
+
+        with raises(
+            SSHException,
+            match=r"Sequence number rolled over during initial kex!",
+        ):
+            with server(
+                client_init=dict(
+                    # Disable strict kex - this should happen always
+                    strict_kex=False,
+                ),
+                # Transport which tickles its packetizer seqno's
+                transport_factory=RolloverTransport,
+            ):
+                pass  # kexinit happens at connect...