diff --git a/websockets/test_protocol.py b/websockets/test_protocol.py
index 2e256c6..312f7e7 100644
--- a/websockets/test_protocol.py
+++ b/websockets/test_protocol.py
@@ -258,7 +258,13 @@ class CommonTests:
         self.run_loop_once()
         # The connection is established.
         self.assertEqual(self.protocol.local_address, ('host', 4312))
-        get_extra_info.assert_called_once_with('sockname', None)
+        # Some asyncio versions also query 'sslcontext' when the connection
+        # is made; accept that extra leading call but nothing else.
+        self.assertIn(get_extra_info.call_args_list, [
+            [unittest.mock.call('sockname', None)],
+            [unittest.mock.call('sslcontext'),
+             unittest.mock.call('sockname', None)],
+        ])
 
     def test_remote_address(self):
         get_extra_info = unittest.mock.Mock(return_value=('host', 4312))
@@ -268,7 +274,13 @@ class CommonTests:
         self.run_loop_once()
         # The connection is established.
         self.assertEqual(self.protocol.remote_address, ('host', 4312))
-        get_extra_info.assert_called_once_with('peername', None)
+        # Some asyncio versions also query 'sslcontext' when the connection
+        # is made; accept that extra leading call but nothing else.
+        self.assertIn(get_extra_info.call_args_list, [
+            [unittest.mock.call('peername', None)],
+            [unittest.mock.call('sslcontext'),
+             unittest.mock.call('peername', None)],
+        ])
 
     def test_open(self):
         self.assertTrue(self.protocol.open)