From 08c48bbe717a0be0f8c41a4652473d8483e0fbde Mon Sep 17 00:00:00 2001
From: Yu Watanabe <watanabe.yu+github@gmail.com>
Date: Tue, 23 May 2023 17:49:16 +0900
Subject: [PATCH] core/unit: update bidirectional dependency simultaneously

Previously, if unit_add_dependency_hashmap() failed, a one-directional
unit dependency reference might be created, which could trigger a
use-after-free. See issue #27742 for more details.

This makes unit dependencies always bidirectional, and cleanly reverts
a partial update on failure.

Fixes #27742.

(cherry picked from commit 831108245eb757f41fe0ebbccf1b42c9dd0ce297)

Resolves: #2213521
---
 src/core/unit.c | 164 ++++++++++++++++++++++++++++++------------------
 1 file changed, 103 insertions(+), 61 deletions(-)

diff --git a/src/core/unit.c b/src/core/unit.c
index 6acc63e091..438213a47a 100644
--- a/src/core/unit.c
+++ b/src/core/unit.c
@@ -990,46 +990,6 @@ static int unit_per_dependency_type_hashmap_update(
         return 1;
 }
 
-static int unit_add_dependency_hashmap(
-                Hashmap **dependencies,
-                UnitDependency d,
-                Unit *other,
-                UnitDependencyMask origin_mask,
-                UnitDependencyMask destination_mask) {
-
-        Hashmap *per_type;
-        int r;
-
-        assert(dependencies);
-        assert(other);
-        assert(origin_mask < _UNIT_DEPENDENCY_MASK_FULL);
-        assert(destination_mask < _UNIT_DEPENDENCY_MASK_FULL);
-        assert(origin_mask > 0 || destination_mask > 0);
-
-        /* Ensure the top-level dependency hashmap exists that maps UnitDependency → Hashmap(Unit* →
-         * UnitDependencyInfo) */
-        r = hashmap_ensure_allocated(dependencies, NULL);
-        if (r < 0)
-                return r;
-
-        /* Acquire the inner hashmap, that maps Unit* → UnitDependencyInfo, for the specified dependency
-         * type, and if it's missing allocate it and insert it. */
-        per_type = hashmap_get(*dependencies, UNIT_DEPENDENCY_TO_PTR(d));
-        if (!per_type) {
-                per_type = hashmap_new(NULL);
-                if (!per_type)
-                        return -ENOMEM;
-
-                r = hashmap_put(*dependencies, UNIT_DEPENDENCY_TO_PTR(d), per_type);
-                if (r < 0) {
-                        hashmap_free(per_type);
-                        return r;
-                }
-        }
-
-        return unit_per_dependency_type_hashmap_update(per_type, other, origin_mask, destination_mask);
-}
-
 static void unit_merge_dependencies(Unit *u, Unit *other) {
         Hashmap *deps;
         void *dt; /* Actually of type UnitDependency, except that we don't bother casting it here,
@@ -3002,11 +2962,38 @@ bool unit_job_is_applicable(Unit *u, JobType j) {
         }
 }
 
-int unit_add_dependency(
+static Hashmap *unit_get_dependency_hashmap_per_type(Unit *u, UnitDependency d) {
+        Hashmap *deps;
+        /* Returns the per-type hashmap of u for dependency type d, creating it if missing. NULL on failure. */
+        assert(u);
+        assert(d >= 0 && d < _UNIT_DEPENDENCY_MAX);
+
+        deps = hashmap_get(u->dependencies, UNIT_DEPENDENCY_TO_PTR(d));
+        if (!deps) {
+                _cleanup_hashmap_free_ Hashmap *h = NULL;
+
+                h = hashmap_new(NULL);
+                if (!h)
+                        return NULL;
+                /* Register the new inner hashmap in u->dependencies under the key for d. */
+                if (hashmap_ensure_put(&u->dependencies, NULL, UNIT_DEPENDENCY_TO_PTR(d), h) < 0)
+                        return NULL;
+                /* Presumably TAKE_PTR() disarms the _cleanup_ handler so h survives this scope — confirm. */
+                deps = TAKE_PTR(h);
+        }
+
+        return deps;
+}
+
+typedef enum NotifyDependencyFlags {
+        NOTIFY_DEPENDENCY_UPDATE_FROM = 1 << 0, /* the entry stored in u's own dependency hashmap changed */
+        NOTIFY_DEPENDENCY_UPDATE_TO   = 1 << 1, /* the inverse entry stored in other's hashmap changed */
+} NotifyDependencyFlags;
+
+static int unit_add_dependency_impl(
                 Unit *u,
                 UnitDependency d,
                 Unit *other,
-                bool add_reference,
                 UnitDependencyMask mask) {
 
         static const UnitDependency inverse_table[_UNIT_DEPENDENCY_MAX] = {
@@ -3042,12 +3029,78 @@ int unit_add_dependency(
                 [UNIT_IN_SLICE]               = UNIT_SLICE_OF,
                 [UNIT_SLICE_OF]               = UNIT_IN_SLICE,
         };
+
+        Hashmap *u_deps, *other_deps;
+        UnitDependencyInfo u_info, u_info_old, other_info, other_info_old;
+        NotifyDependencyFlags flags = 0;
+        int r;
+
+        assert(u);
+        assert(other);
+        assert(d >= 0 && d < _UNIT_DEPENDENCY_MAX);
+        assert(inverse_table[d] >= 0 && inverse_table[d] < _UNIT_DEPENDENCY_MAX);
+        assert(mask > 0 && mask < _UNIT_DEPENDENCY_MASK_FULL);
+
+        /* Ensure the following two hashmaps for each unit exist:
+         * - the top-level dependency hashmap that maps UnitDependency → Hashmap(Unit* → UnitDependencyInfo),
+         * - the inner hashmap, that maps Unit* → UnitDependencyInfo, for the specified dependency type. */
+        u_deps = unit_get_dependency_hashmap_per_type(u, d);
+        if (!u_deps)
+                return -ENOMEM;
+
+        other_deps = unit_get_dependency_hashmap_per_type(other, inverse_table[d]);
+        if (!other_deps)
+                return -ENOMEM;
+
+        /* Save the original dependency info. */
+        u_info.data = u_info_old.data = hashmap_get(u_deps, other);
+        other_info.data = other_info_old.data = hashmap_get(other_deps, u);
+
+        /* Update dependency info. */
+        u_info.origin_mask |= mask;
+        other_info.destination_mask |= mask;
+
+        /* Save updated dependency info. */
+        if (u_info.data != u_info_old.data) {
+                r = hashmap_replace(u_deps, other, u_info.data);
+                if (r < 0)
+                        return r;
+
+                flags = NOTIFY_DEPENDENCY_UPDATE_FROM;
+        }
+
+        if (other_info.data != other_info_old.data) {
+                r = hashmap_replace(other_deps, u, other_info.data);
+                if (r < 0) {
+                        if (u_info.data != u_info_old.data) {
+                                /* Restore the old dependency. */
+                                if (u_info_old.data)
+                                        (void) hashmap_update(u_deps, other, u_info_old.data);
+                                else
+                                        hashmap_remove(u_deps, other);
+                        }
+                        return r;
+                }
+
+                flags |= NOTIFY_DEPENDENCY_UPDATE_TO;
+        }
+
+        return flags;
+}
+
+int unit_add_dependency(
+                Unit *u,
+                UnitDependency d,
+                Unit *other,
+                bool add_reference,
+                UnitDependencyMask mask) {
+
         UnitDependencyAtom a;
         int r;
 
         /* Helper to know whether sending a notification is necessary or not: if the dependency is already
          * there, no need to notify! */
-        bool notify, notify_other = false;
+        NotifyDependencyFlags notify_flags;
 
         assert(u);
         assert(d >= 0 && d < _UNIT_DEPENDENCY_MAX);
@@ -3103,35 +3156,24 @@ int unit_add_dependency(
                 return log_unit_error_errno(u, SYNTHETIC_ERRNO(EINVAL),
                                             "Requested dependency SliceOf=%s refused (%s is not a cgroup unit).", other->id, other->id);
 
-        r = unit_add_dependency_hashmap(&u->dependencies, d, other, mask, 0);
-        if (r < 0)
-                return r;
-        notify = r > 0;
-
-        assert(inverse_table[d] >= 0 && inverse_table[d] < _UNIT_DEPENDENCY_MAX);
-        r = unit_add_dependency_hashmap(&other->dependencies, inverse_table[d], u, 0, mask);
+        r = unit_add_dependency_impl(u, d, other, mask);
         if (r < 0)
                 return r;
-        notify_other = r > 0;
+        notify_flags = r;
 
         if (add_reference) {
-                r = unit_add_dependency_hashmap(&u->dependencies, UNIT_REFERENCES, other, mask, 0);
-                if (r < 0)
-                        return r;
-                notify = notify || r > 0;
-
-                r = unit_add_dependency_hashmap(&other->dependencies, UNIT_REFERENCED_BY, u, 0, mask);
+                r = unit_add_dependency_impl(u, UNIT_REFERENCES, other, mask);
                 if (r < 0)
                         return r;
-                notify_other = notify_other || r > 0;
+                notify_flags |= r;
         }
 
-        if (notify)
+        if (FLAGS_SET(notify_flags, NOTIFY_DEPENDENCY_UPDATE_FROM))
                 unit_add_to_dbus_queue(u);
-        if (notify_other)
+        if (FLAGS_SET(notify_flags, NOTIFY_DEPENDENCY_UPDATE_TO))
                 unit_add_to_dbus_queue(other);
 
-        return notify || notify_other;
+        return notify_flags != 0;
 }
 
 int unit_add_two_dependencies(Unit *u, UnitDependency d, UnitDependency e, Unit *other, bool add_reference, UnitDependencyMask mask) {