From 50240c828da96724d76c14dd0e71b78bc2dce657 Mon Sep 17 00:00:00 2001
From: Dan Streetman <ddstreet@ieee.org>
Date: Wed, 5 Jul 2023 12:59:47 -0400
Subject: [PATCH] openssl: add openssl_digest_size()

Add a function that returns the digest (hash) size, in bytes, for a given digest algorithm name.

(cherry picked from commit c52a003dc86bb91b2724a00449a50b26009fdfd0)

Related: RHEL-16182
---
 src/shared/openssl-util.c | 31 +++++++++++++++++++++++
 src/shared/openssl-util.h |  3 +++
 src/test/test-openssl.c   | 52 +++++++++++++++++++++++++++++++++++++++
 3 files changed, 86 insertions(+)

diff --git a/src/shared/openssl-util.c b/src/shared/openssl-util.c
index 3d3d8090f8..ecdb418402 100644
--- a/src/shared/openssl-util.c
+++ b/src/shared/openssl-util.c
@@ -87,6 +87,37 @@ int openssl_hash(const EVP_MD *alg,
         return 0;
 }
 
+/* Stores in *ret_digest_size the number of bytes generated by the specified digest algorithm. This can be
+ * used only for fixed-size algorithms, e.g. md5, sha1, sha256, etc. Do not use this for variable-sized
+ * digest algorithms, e.g. shake128. Returns 0 on success, -EOPNOTSUPP if the algorithm is not supported,
+ * or < 0 for any other error, e.g. if OpenSSL reports no positive size for the algorithm. */
+int openssl_digest_size(const char *digest_alg, size_t *ret_digest_size) {
+        assert(digest_alg);
+        assert(ret_digest_size);
+
+#if OPENSSL_VERSION_MAJOR >= 3
+        _cleanup_(EVP_MD_freep) EVP_MD *md = EVP_MD_fetch(NULL, digest_alg, NULL);
+#else
+        const EVP_MD *md = EVP_get_digestbyname(digest_alg);
+#endif
+        if (!md)
+                return log_debug_errno(SYNTHETIC_ERRNO(EOPNOTSUPP),
+                                       "Digest algorithm '%s' not supported.", digest_alg);
+
+        int digest_size; /* signed: OpenSSL reports failure as -1, which must not wrap to SIZE_MAX */
+#if OPENSSL_VERSION_MAJOR >= 3
+        digest_size = EVP_MD_get_size(md);
+#else
+        digest_size = EVP_MD_size(md);
+#endif
+        if (digest_size <= 0)
+                return log_openssl_errors("Failed to get Digest size");
+
+        *ret_digest_size = (size_t) digest_size;
+
+        return 0;
+}
+
 int rsa_encrypt_bytes(
                 EVP_PKEY *pkey,
                 const void *decrypted_key,
diff --git a/src/shared/openssl-util.h b/src/shared/openssl-util.h
index 90158f589b..309dc16805 100644
--- a/src/shared/openssl-util.h
+++ b/src/shared/openssl-util.h
@@ -39,6 +39,7 @@ DEFINE_TRIVIAL_CLEANUP_FUNC_FULL(SSL*, SSL_free, NULL);
 DEFINE_TRIVIAL_CLEANUP_FUNC_FULL(BIO*, BIO_free, NULL);
 DEFINE_TRIVIAL_CLEANUP_FUNC_FULL(EVP_MD_CTX*, EVP_MD_CTX_free, NULL);
 #if OPENSSL_VERSION_MAJOR >= 3
+DEFINE_TRIVIAL_CLEANUP_FUNC_FULL(EVP_MD*, EVP_MD_free, NULL); /* EVP_MD_freep(), for EVP_MD_fetch() results */
 DEFINE_TRIVIAL_CLEANUP_FUNC_FULL(OSSL_PARAM*, OSSL_PARAM_free, NULL);
 DEFINE_TRIVIAL_CLEANUP_FUNC_FULL(OSSL_PARAM_BLD*, OSSL_PARAM_BLD_free, NULL);
 #else
@@ -57,6 +58,8 @@ int openssl_pkey_from_pem(const void *pem, size_t pem_size, EVP_PKEY **ret);
 
 int openssl_hash(const EVP_MD *alg, const void *msg, size_t msg_len, uint8_t *ret_hash, size_t *ret_hash_len);
 
+int openssl_digest_size(const char *digest_alg, size_t *ret_digest_size); /* 0 on success, < 0 errno on failure */
+
 int rsa_encrypt_bytes(EVP_PKEY *pkey, const void *decrypted_key, size_t decrypted_key_size, void **ret_encrypt_key, size_t *ret_encrypt_key_size);
 
 int rsa_pkey_to_suitable_key_size(EVP_PKEY *pkey, size_t *ret_suitable_key_size);
diff --git a/src/test/test-openssl.c b/src/test/test-openssl.c
index c46ecdcda8..a8a2b534a4 100644
--- a/src/test/test-openssl.c
+++ b/src/test/test-openssl.c
@@ -2,6 +2,7 @@
 
 #include "hexdecoct.h"
 #include "openssl-util.h"
+#include "string-util.h"
 #include "tests.h"
 
 TEST(openssl_pkey_from_pem) {
@@ -102,4 +103,55 @@ TEST(invalid) {
         assert_se(pkey == NULL);
 }
 
+static const struct {
+        const char *alg;
+        size_t size;
+} digest_size_table[] = {
+        /* SHA1 "family" */
+        { "sha1",     20, },
+#if OPENSSL_VERSION_MAJOR >= 3
+        { "sha-1",    20, },
+#endif
+        /* SHA2 family */
+        { "sha224",   28, },
+        { "sha256",   32, },
+        { "sha384",   48, },
+        { "sha512",   64, },
+#if OPENSSL_VERSION_MAJOR >= 3
+        { "sha-224",  28, },
+        { "sha2-224", 28, },
+        { "sha-256",  32, },
+        { "sha2-256", 32, },
+        { "sha-384",  48, },
+        { "sha2-384", 48, },
+        { "sha-512",  64, },
+        { "sha2-512", 64, },
+#endif
+        /* SHA3 family */
+        { "sha3-224", 28, },
+        { "sha3-256", 32, },
+        { "sha3-384", 48, },
+        { "sha3-512", 64, },
+        /* SM3 family */
+        { "sm3",      32, },
+        /* MD5 family */
+        { "md5",      16, },
+};
+
+TEST(digest_size) {
+        size_t size;
+
+        FOREACH_ARRAY(t, digest_size_table, ELEMENTSOF(digest_size_table)) {
+                assert_se(openssl_digest_size(t->alg, &size) >= 0); /* assert_se(): the call has side effects */
+                assert_se(size == t->size);
+
+                _cleanup_free_ char *uppercase_alg = strdup(t->alg); /* the uppercase name must resolve too */
+                assert_se(uppercase_alg);
+                assert_se(openssl_digest_size(ascii_strupper(uppercase_alg), &size) >= 0);
+                assert_se(size == t->size);
+        }
+
+        assert_se(openssl_digest_size("invalid.alg", &size) == -EOPNOTSUPP);
+}
+
 DEFINE_TEST_MAIN(LOG_DEBUG);