From 94c0991a62246018bc9909907c2889519158079d Mon Sep 17 00:00:00 2001
From: Derek Higgins
Date: Thu, 4 Jan 2024 11:30:57 +0100
Subject: [PATCH] Add ipv6 support to should_bypass_proxies

Add support to should_bypass_proxies to support
IPv6 addresses and CIDRs in no_proxy. Includes
adding IPv6 support to various other helper functions.
---
 requests/utils.py   | 83 ++++++++++++++++++++++++++++++++++++++-------
 tests/test_utils.py | 67 ++++++++++++++++++++++++++++++++----
 2 files changed, 131 insertions(+), 19 deletions(-)

diff --git a/requests/utils.py b/requests/utils.py
index db67938..f3f780c 100644
--- a/requests/utils.py
+++ b/requests/utils.py
@@ -623,18 +623,46 @@ def requote_uri(uri):
         return quote(uri, safe=safe_without_percent)
 
 
+def _get_mask_bits(mask, totalbits=32):
+    """Converts a mask from /xx format to an int
+    to be used as a mask for IPs in int format
+
+    Example: if mask is 24 function returns 0xFFFFFF00
+             if mask is 24 and totalbits=128 function
+                returns 0xFFFFFF00000000000000000000000000
+
+    :rtype: int
+    """
+    bits = ((1 << mask) - 1) << (totalbits - mask)
+    return bits
+
+
 def address_in_network(ip, net):
     """This function allows you to check if an IP belongs to a network subnet
 
     Example: returns True if ip = 192.168.1.1 and net = 192.168.1.0/24
              returns False if ip = 192.168.1.1 and net = 192.168.100.0/24
+             returns True if ip = 1:2:3:4::1 and net = 1:2:3:4::/64
 
     :rtype: bool
     """
-    ipaddr = struct.unpack('=L', socket.inet_aton(ip))[0]
     netaddr, bits = net.split('/')
-    netmask = struct.unpack('=L', socket.inet_aton(dotted_netmask(int(bits))))[0]
-    network = struct.unpack('=L', socket.inet_aton(netaddr))[0] & netmask
+    if is_ipv4_address(ip) and is_ipv4_address(netaddr):
+        ipaddr = struct.unpack(">L", socket.inet_aton(ip))[0]
+        netmask = _get_mask_bits(int(bits))
+        network = struct.unpack(">L", socket.inet_aton(netaddr))[0]
+    elif is_ipv6_address(ip) and is_ipv6_address(netaddr):
+        ipaddr_msb, ipaddr_lsb = struct.unpack(
+            ">QQ", socket.inet_pton(socket.AF_INET6, ip)
+        )
+        ipaddr = (ipaddr_msb << 64) ^ ipaddr_lsb
+        netmask = _get_mask_bits(int(bits), 128)
+        network_msb, network_lsb = struct.unpack(
+            ">QQ", socket.inet_pton(socket.AF_INET6, netaddr)
+        )
+        network = (network_msb << 64) ^ network_lsb
+    else:
+        return False
     return (ipaddr & netmask) == (network & netmask)
 
 
@@ -654,12 +682,39 @@ def is_ipv4_address(string_ip):
     :rtype: bool
     """
     try:
-        socket.inet_aton(string_ip)
+        socket.inet_pton(socket.AF_INET, string_ip)
+    except socket.error:
+        return False
+    return True
+
+
+def is_ipv6_address(string_ip):
+    """
+    :rtype: bool
+    """
+    try:
+        socket.inet_pton(socket.AF_INET6, string_ip)
     except socket.error:
         return False
     return True
 
 
+def compare_ips(a, b):
+    """
+    Compare 2 IPs, uses socket.inet_pton to normalize IPv6 IPs
+
+    :rtype: bool
+    """
+    if a == b:
+        return True
+    try:
+        return socket.inet_pton(socket.AF_INET6, a) == socket.inet_pton(
+            socket.AF_INET6, b
+        )
+    except socket.error:
+        return False
+
+
 def is_valid_cidr(string_network):
     """
     Very simple check of the cidr format in no_proxy variable.
@@ -667,17 +722,19 @@ def is_valid_cidr(string_network):
     :rtype: bool
     """
     if string_network.count('/') == 1:
+        address, mask = string_network.split("/")
         try:
-            mask = int(string_network.split('/')[1])
+            mask = int(mask)
         except ValueError:
             return False
 
-        if mask < 1 or mask > 32:
-            return False
-
-        try:
-            socket.inet_aton(string_network.split('/')[0])
-        except socket.error:
+        if is_ipv4_address(address):
+            if mask < 1 or mask > 32:
+                return False
+        elif is_ipv6_address(address):
+            if mask < 1 or mask > 128:
+                return False
+        else:
             return False
     else:
         return False
@@ -734,12 +791,12 @@ def should_bypass_proxies(url, no_proxy):
             host for host in no_proxy.replace(' ', '').split(',') if host
         )
 
-        if is_ipv4_address(parsed.hostname):
+        if is_ipv4_address(parsed.hostname) or is_ipv6_address(parsed.hostname):
             for proxy_ip in no_proxy:
                 if is_valid_cidr(proxy_ip):
                     if address_in_network(parsed.hostname, proxy_ip):
                         return True
-                elif parsed.hostname == proxy_ip:
+                elif compare_ips(parsed.hostname, proxy_ip):
                     # If no_proxy ip was defined in plain IP notation instead of cidr notation &
                     # matches the IP of the index
                     return True
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 463516b..4ce139a 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -21,7 +21,7 @@ from requests.utils import (
     requote_uri, select_proxy, should_bypass_proxies, super_len,
     to_key_val_list, to_native_string,
     unquote_header_value, unquote_unreserved,
-    urldefragauth, add_dict_to_cookiejar, set_environ)
+    urldefragauth, add_dict_to_cookiejar, set_environ, _get_mask_bits, compare_ips)
 from requests._internal_utils import unicode_is_ascii
 
 from .compat import StringIO, cStringIO
@@ -215,9 +215,15 @@ class TestIsIPv4Address:
 
 
 class TestIsValidCIDR:
-
-    def test_valid(self):
-        assert is_valid_cidr('192.168.1.0/24')
+    @pytest.mark.parametrize(
+        "value",
+        (
+            "192.168.1.0/24",
+            "1:2:3:4::/64",
+        ),
+    )
+    def test_valid(self, value):
+        assert is_valid_cidr(value)
 
     @pytest.mark.parametrize(
         'value', (
@@ -226,6 +232,11 @@ class TestIsValidCIDR:
             '192.168.1.0/128',
             '192.168.1.0/-1',
             '192.168.1.999/24',
+            "1:2:3:4::1",
+            "1:2:3:4::/a",
+            "1:2:3:4::0/321",
+            "1:2:3:4::/-1",
+            "1:2:3:4::12211/64",
         ))
     def test_invalid(self, value):
         assert not is_valid_cidr(value)
@@ -239,6 +250,12 @@ class TestAddressInNetwork:
     def test_invalid(self):
         assert not address_in_network('172.16.0.1', '192.168.1.0/24')
 
+    def test_valid_v6(self):
+        assert address_in_network("1:2:3:4::1111", "1:2:3:4::/64")
+
+    def test_invalid_v6(self):
+        assert not address_in_network("1:2:3:4::1111", "1:2:3:4::/124")
+
 
 class TestGuessFilename:
 
@@ -624,13 +641,18 @@ def test_urldefragauth(url, expected):
             ('http://172.16.1.12:5000/', False),
             ('http://google.com:5000/v1.0/', False),
             ('file:///some/path/on/disk', True),
+            ("http://[1:2:3:4:5:6:7:8]:5000/", True),
+            ("http://[1:2:3:4::1]/", True),
+            ("http://[1:2:3:9::1]/", True),
+            ("http://[1:2:3:9:0:0:0:1]/", True),
+            ("http://[1:2:3:9::2]/", False),
     ))
 def test_should_bypass_proxies(url, expected, monkeypatch):
     """Tests for function should_bypass_proxies to check if proxy
     can be bypassed or not
     """
-    monkeypatch.setenv('no_proxy', '192.168.0.0/24,127.0.0.1,localhost.localdomain,172.16.1.1, google.com:6000')
-    monkeypatch.setenv('NO_PROXY', '192.168.0.0/24,127.0.0.1,localhost.localdomain,172.16.1.1, google.com:6000')
+    monkeypatch.setenv('no_proxy', '192.168.0.0/24,127.0.0.1,localhost.localdomain,1:2:3:4::/64,1:2:3:9::1,172.16.1.1, google.com:6000')
+    monkeypatch.setenv('NO_PROXY', '192.168.0.0/24,127.0.0.1,localhost.localdomain,1:2:3:4::/64,1:2:3:9::1,172.16.1.1, google.com:6000')
     assert should_bypass_proxies(url, no_proxy=None) == expected
 
 
@@ -781,3 +803,36 @@ def test_set_environ_raises_exception():
             raise Exception('Expected exception')
 
     assert 'Expected exception' in str(exception.value)
+
+
+@pytest.mark.parametrize(
+    "mask, totalbits, maskbits",
+    (
+        (24, None, 0xFFFFFF00),
+        (31, None, 0xFFFFFFFE),
+        (0, None, 0x0),
+        (4, 4, 0xF),
+        (24, 128, 0xFFFFFF00000000000000000000000000),
+    ),
+)
+def test__get_mask_bits(mask, totalbits, maskbits):
+    args = {"mask": mask}
+    if totalbits is not None:
+        args["totalbits"] = totalbits
+    assert _get_mask_bits(**args) == maskbits
+
+
+@pytest.mark.parametrize(
+    "a, b, expected",
+    (
+        ("1.2.3.4", "1.2.3.4", True),
+        ("1.2.3.4", "2.2.3.4", False),
+        ("1::4", "1.2.3.4", False),
+        ("1::4", "1::4", True),
+        ("1::4", "1:0:0:0:0:0:0:4", True),
+        ("1::4", "1:0:0:0:0:0::4", True),
+        ("1::4", "1:0:0:0:0:0:1:4", False),
+    ),
+)
+def test_compare_ips(a, b, expected):
+    assert compare_ips(a, b) == expected
-- 
2.43.0