From 82f1ea4b72be40f58fd0a9a37f8d8d2f7d16f9e0 Mon Sep 17 00:00:00 2001
From: Lumir Balhar <lbalhar@redhat.com>
Date: Wed, 24 Apr 2024 00:19:23 +0200
Subject: [PATCH] CVE-2023-6597: tempfile: don't follow symlinks in cleanup

Co-authored-by: Søren Løvborg <sorenl@unity3d.com>
Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
---
 Lib/tempfile.py           |  44 +++++++++-
 Lib/test/test_tempfile.py | 166 +++++++++++++++++++++++++++++++++++---
 2 files changed, 199 insertions(+), 11 deletions(-)

diff --git a/Lib/tempfile.py b/Lib/tempfile.py
index 2cb5434..d79b70c 100644
--- a/Lib/tempfile.py
+++ b/Lib/tempfile.py
@@ -276,6 +276,23 @@ def _mkstemp_inner(dir, pre, suf, flags, output_type):
                           "No usable temporary file name found")
 
 
+def _dont_follow_symlinks(func, path, *args):
+    # Call func(path, *args) with follow_symlinks=False where supported.
+    if func in _os.supports_follow_symlinks:
+        func(path, *args, follow_symlinks=False)
+    elif _os.name == 'nt' or not _os.path.islink(path):
+        func(path, *args)  # symlinks are skipped here, except on nt
+
+
+def _resetperms(path):  # clear flags and chmod 0o700, not following symlinks
+    try:
+        chflags = _os.chflags
+    except AttributeError:
+        pass  # os.chflags() does not exist on this platform
+    else:
+        _dont_follow_symlinks(chflags, path, 0)  # clear all file flags
+    _dont_follow_symlinks(_os.chmod, path, 0o700)
+
 # User visible interfaces.
 
 def gettempprefix():
@@ -794,9 +811,47 @@ class TemporaryDirectory(object):
             self, self._cleanup, self.name,
             warn_message="Implicitly cleaning up {!r}".format(self))
 
+    @classmethod
+    def _rmtree(cls, name, repeated=False):
+        # shutil.rmtree() that, on PermissionError, resets the mode and flags
+        # of the offending path (and of its parent, unless path is *name*)
+        # and tries again.  *repeated* is true when this call is already
+        # such a retry for *name*; a PermissionError on *name* is then final.
+        def onerror(func, path, exc_info):
+            if issubclass(exc_info[0], PermissionError):
+                if repeated and path == name:
+                    # Resetting the permissions did not help; re-raise
+                    # instead of recursing forever.
+                    raise
+
+                try:
+                    if path != name:
+                        _resetperms(_os.path.dirname(path))
+                    _resetperms(path)
+
+                    try:
+                        _os.unlink(path)
+                    except IsADirectoryError:
+                        cls._rmtree(path)
+                    # PermissionError is raised on FreeBSD for directories
+                    except PermissionError:
+                        # Calling _rmtree() on a non-directory would mask
+                        # this error with NotADirectoryError; re-raise it.
+                        if not _os.path.isdir(path):
+                            raise
+                        cls._rmtree(path, repeated=(path == name))
+                except FileNotFoundError:
+                    pass
+            elif issubclass(exc_info[0], FileNotFoundError):
+                pass
+            else:
+                raise
+
+        _shutil.rmtree(name, onerror=onerror)
+
     @classmethod
     def _cleanup(cls, name, warn_message):
-        _shutil.rmtree(name)
+        cls._rmtree(name)  # resets modes/flags on PermissionError
         _warnings.warn(warn_message, ResourceWarning)
 
     def __repr__(self):
@@ -810,4 +850,4 @@ class TemporaryDirectory(object):
 
     def cleanup(self):
         if self._finalizer.detach():
-            _shutil.rmtree(self.name)
+            self._rmtree(self.name)  # resets modes/flags on PermissionError
diff --git a/Lib/test/test_tempfile.py b/Lib/test/test_tempfile.py
index 710756b..c5560e1 100644
--- a/Lib/test/test_tempfile.py
+++ b/Lib/test/test_tempfile.py
@@ -1298,19 +1298,25 @@ class NulledModules:
 class TestTemporaryDirectory(BaseTestCase):
     """Test TemporaryDirectory()."""
 
-    def do_create(self, dir=None, pre="", suf="", recurse=1):
+    def do_create(self, dir=None, pre="", suf="", recurse=1, dirs=1, files=1):
         if dir is None:
             dir = tempfile.gettempdir()
         tmp = tempfile.TemporaryDirectory(dir=dir, prefix=pre, suffix=suf)
         self.nameCheck(tmp.name, dir, pre, suf)
-        # Create a subdirectory and some files
-        if recurse:
-            d1 = self.do_create(tmp.name, pre, suf, recurse-1)
-            d1.name = None
-        with open(os.path.join(tmp.name, "test.txt"), "wb") as f:
-            f.write(b"Hello world!")
+        self.do_create2(tmp.name, recurse, dirs, files)  # populate the tree
         return tmp
 
+    def do_create2(self, path, recurse=1, dirs=1, files=1):
+        # Add `files` test%d.txt files and, if recurse, `dirs` dir%d subdirs.
+        if recurse:
+            for i in range(dirs):
+                name = os.path.join(path, "dir%d" % i)
+                os.mkdir(name)
+                self.do_create2(name, recurse-1, dirs, files)  # same layout
+        for i in range(files):
+            with open(os.path.join(path, "test%d.txt" % i), "wb") as f:
+                f.write(b"Hello world!")
+
     def test_mkdtemp_failure(self):
         # Check no additional exception if mkdtemp fails
         # Previously would raise AttributeError instead
@@ -1350,11 +1356,108 @@ class TestTemporaryDirectory(BaseTestCase):
                          "TemporaryDirectory %s exists after cleanup" % d1.name)
         self.assertTrue(os.path.exists(d2.name),
                         "Directory pointed to by a symlink was deleted")
-        self.assertEqual(os.listdir(d2.name), ['test.txt'],
+        self.assertEqual(os.listdir(d2.name), ['test0.txt'],
                          "Contents of the directory pointed to by a symlink "
                          "were deleted")
         d2.cleanup()
 
+    @support.skip_unless_symlink
+    def test_cleanup_with_symlink_modes(self):
+        # cleanup() should not follow symlinks when fixing mode bits (#91133)
+        with self.do_create(recurse=0) as d2:
+            file1 = os.path.join(d2, 'file1')
+            open(file1, 'wb').close()
+            dir1 = os.path.join(d2, 'dir1')
+            os.mkdir(dir1)
+            for mode in range(8):
+                mode <<= 6  # owner bits only: 0o000 .. 0o700
+                with self.subTest(mode=format(mode, '03o')):
+                    def test(target, target_is_directory):
+                        d1 = self.do_create(recurse=0)
+                        symlink = os.path.join(d1.name, 'symlink')
+                        os.symlink(target, symlink,
+                                target_is_directory=target_is_directory)
+                        try:
+                            os.chmod(symlink, mode, follow_symlinks=False)
+                        except NotImplementedError:  # not supported here
+                            pass
+                        try:
+                            os.chmod(symlink, mode)
+                        except FileNotFoundError:  # dangling symlink
+                            pass
+                        os.chmod(d1.name, mode)  # restrict the temp dir too
+                        d1.cleanup()
+                        self.assertFalse(os.path.exists(d1.name))
+
+                    with self.subTest('nonexisting file'):
+                        test('nonexisting', target_is_directory=False)
+                    with self.subTest('nonexisting dir'):
+                        test('nonexisting', target_is_directory=True)
+
+                    with self.subTest('existing file'):
+                        os.chmod(file1, mode)
+                        old_mode = os.stat(file1).st_mode
+                        test(file1, target_is_directory=False)
+                        new_mode = os.stat(file1).st_mode
+                        self.assertEqual(new_mode, old_mode,
+                                         '%03o != %03o' % (new_mode, old_mode))
+
+                    with self.subTest('existing dir'):
+                        os.chmod(dir1, mode)
+                        old_mode = os.stat(dir1).st_mode
+                        test(dir1, target_is_directory=True)
+                        new_mode = os.stat(dir1).st_mode
+                        self.assertEqual(new_mode, old_mode,
+                                         '%03o != %03o' % (new_mode, old_mode))
+
+    @unittest.skipUnless(hasattr(os, 'chflags'), 'requires os.chflags')
+    @support.skip_unless_symlink
+    def test_cleanup_with_symlink_flags(self):
+        # cleanup() should not follow symlinks when fixing flags (#91133)
+        flags = stat.UF_IMMUTABLE | stat.UF_NOUNLINK
+        self.check_flags(flags)  # skip if these flags are unsupported
+
+        with self.do_create(recurse=0) as d2:
+            file1 = os.path.join(d2, 'file1')
+            open(file1, 'wb').close()
+            dir1 = os.path.join(d2, 'dir1')
+            os.mkdir(dir1)
+            def test(target, target_is_directory):
+                d1 = self.do_create(recurse=0)
+                symlink = os.path.join(d1.name, 'symlink')
+                os.symlink(target, symlink,
+                           target_is_directory=target_is_directory)
+                try:
+                    os.chflags(symlink, flags, follow_symlinks=False)
+                except NotImplementedError:  # not supported here
+                    pass
+                try:
+                    os.chflags(symlink, flags)
+                except FileNotFoundError:  # dangling symlink
+                    pass
+                os.chflags(d1.name, flags)  # make the temp dir immutable too
+                d1.cleanup()
+                self.assertFalse(os.path.exists(d1.name))
+
+            with self.subTest('nonexisting file'):
+                test('nonexisting', target_is_directory=False)
+            with self.subTest('nonexisting dir'):
+                test('nonexisting', target_is_directory=True)
+
+            with self.subTest('existing file'):
+                os.chflags(file1, flags)
+                old_flags = os.stat(file1).st_flags
+                test(file1, target_is_directory=False)
+                new_flags = os.stat(file1).st_flags
+                self.assertEqual(new_flags, old_flags)
+
+            with self.subTest('existing dir'):
+                os.chflags(dir1, flags)
+                old_flags = os.stat(dir1).st_flags
+                test(dir1, target_is_directory=True)
+                new_flags = os.stat(dir1).st_flags
+                self.assertEqual(new_flags, old_flags)
+
     @support.cpython_only
     def test_del_on_collection(self):
         # A TemporaryDirectory is deleted when garbage collected
@@ -1385,7 +1488,7 @@ class TestTemporaryDirectory(BaseTestCase):
 
                     tmp2 = os.path.join(tmp.name, 'test_dir')
                     os.mkdir(tmp2)
-                    with open(os.path.join(tmp2, "test.txt"), "w") as f:
+                    with open(os.path.join(tmp2, "test0.txt"), "w") as f:
                         f.write("Hello world!")
 
                     {mod}.tmp = tmp
@@ -1453,6 +1556,51 @@ class TestTemporaryDirectory(BaseTestCase):
             self.assertEqual(name, d.name)
         self.assertFalse(os.path.exists(name))
 
+    def test_modes(self):  # cleanup() must cope with any owner mode
+        for mode in range(8):
+            mode <<= 6  # owner bits only: 0o000 .. 0o700
+            with self.subTest(mode=format(mode, '03o')):
+                d = self.do_create(recurse=3, dirs=2, files=2)
+                with d:
+                    # chmod everything, children before their parent dir.
+                    for root, dirs, files in os.walk(d.name, topdown=False):
+                        for name in files:
+                            os.chmod(os.path.join(root, name), mode)
+                        os.chmod(root, mode)
+                    d.cleanup()
+                self.assertFalse(os.path.exists(d.name))
+
+    def check_flags(self, flags):
+        # Skip the test if chflags() rejects these flags (e.g. on FreeBSD 13).
+        filename = support.TESTFN
+        try:
+            open(filename, "w").close()
+            try:
+                os.chflags(filename, flags)
+            except OSError as exc:
+                # "OSError: [Errno 45] Operation not supported"
+                self.skipTest(f"chflags() doesn't support flags "
+                              f"{flags:#b}: {exc}")
+            else:
+                os.chflags(filename, 0)  # clear them before the unlink
+        finally:
+            support.unlink(filename)
+
+    @unittest.skipUnless(hasattr(os, 'chflags'), 'requires os.chflags')
+    def test_flags(self):
+        flags = stat.UF_IMMUTABLE | stat.UF_NOUNLINK
+        self.check_flags(flags)  # skip if these flags are unsupported
+
+        d = self.do_create(recurse=3, dirs=2, files=2)
+        with d:
+            # Change files and directories flags recursively.
+            for root, dirs, files in os.walk(d.name, topdown=False):
+                for name in files:
+                    os.chflags(os.path.join(root, name), flags)
+                os.chflags(root, flags)
+            d.cleanup()  # must clear the flags to remove the tree
+        self.assertFalse(os.path.exists(d.name))
+
 
 if __name__ == "__main__":
     unittest.main()
-- 
2.44.0