From 7c95191434415d1c9b7fe9b130df13cce630b6b5 Mon Sep 17 00:00:00 2001
From: Matt Caswell <matt@openssl.org>
Date: Fri, 21 Jun 2024 10:09:41 +0100
Subject: [PATCH 09/10] Add explicit testing of ALPN and NPN in sslapitest

We already had some tests elsewhere, but this extends that testing with
additional tests of the NPN and ALPN callbacks.

Follow on from CVE-2024-5535

Reviewed-by: Neil Horman <nhorman@openssl.org>
Reviewed-by: Tomas Mraz <tomas@openssl.org>
(Merged from https://github.com/openssl/openssl/pull/24717)
---
 test/sslapitest.c | 229 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 229 insertions(+)

diff --git a/test/sslapitest.c b/test/sslapitest.c
index 15cb9060cb..7a55a2b721 100644
--- a/test/sslapitest.c
+++ b/test/sslapitest.c
@@ -11877,6 +11877,231 @@ static int test_select_next_proto(int idx)
     return ret;
 }
 
+/* One-entry protocol lists in wire format: length byte, then the name */
+static const unsigned char fooprot[] = {3, 'f', 'o', 'o' };
+static const unsigned char barprot[] = {3, 'b', 'a', 'r' };
+#if !defined(OPENSSL_NO_TLS1_2) && !defined(OPENSSL_NO_NEXTPROTONEG)
+static int npn_advert_cb(SSL *ssl, const unsigned char **out,
+                         unsigned int *outlen, void *arg)
+{
+    int *idx = (int *)arg;
+    /* Server NPN advert callback; *arg is test_npn()'s subtest number */
+    switch (*idx) {
+    default: /* subtests 3 and 4 also advertise "foo" */
+    case 0:
+        *out = fooprot;
+        *outlen = sizeof(fooprot);
+        return SSL_TLSEXT_ERR_OK;
+    /* Test 1: advertise an empty list (*out is left unset) */
+    case 1:
+        *outlen = 0;
+        return SSL_TLSEXT_ERR_OK;
+    /* Test 2: NOACK, so the server sends no NPN advert */
+    case 2:
+        return SSL_TLSEXT_ERR_NOACK;
+    }
+}
+
+static int npn_select_cb(SSL *s, unsigned char **out, unsigned char *outlen,
+                         const unsigned char *in, unsigned int inlen, void *arg)
+{
+    int *idx = (int *)arg;
+    /* Client NPN select callback; *arg is test_npn()'s subtest number */
+    switch (*idx) {
+    case 0:
+    case 1: /* select "foo"; in test 1 the advert list was empty */
+        *out = (unsigned char *)(fooprot + 1);
+        *outlen = *fooprot;
+        return SSL_TLSEXT_ERR_OK;
+    /* Test 3: select "bar" although the server only advertised "foo" */
+    case 3:
+        *out = (unsigned char *)(barprot + 1);
+        *outlen = *barprot;
+        return SSL_TLSEXT_ERR_OK;
+    /* Test 4: empty selection, which test_npn() expects to fail */
+    case 4:
+        *outlen = 0;
+        return SSL_TLSEXT_ERR_OK;
+    /* Test 2 sends no advert, so this path shouldn't be reached then */
+    default:
+    case 2:
+        return SSL_TLSEXT_ERR_ALERT_FATAL;
+    }
+}
+
+/*
+ * Test the NPN advert (server) and select (client) callbacks
+ * Test 0: advert = foo, select = foo
+ * Test 1: advert = <empty>, select = foo
+ * Test 2: no advert (NOACK), so nothing is negotiated
+ * Test 3: advert = foo, select = bar (unadvertised, but accepted)
+ * Test 4: advert = foo, select = <empty> (should fail)
+ */
+static int test_npn(int idx)
+{
+    SSL_CTX *sctx = NULL, *cctx = NULL;
+    SSL *serverssl = NULL, *clientssl = NULL;
+    int testresult = 0;
+
+    if (!TEST_true(create_ssl_ctx_pair(libctx, TLS_server_method(),
+                                       TLS_client_method(), 0, TLS1_2_VERSION,
+                                       &sctx, &cctx, cert, privkey)))
+        goto end;
+    /* Both callbacks read the current subtest number through &idx */
+    SSL_CTX_set_next_protos_advertised_cb(sctx, npn_advert_cb, &idx);
+    SSL_CTX_set_next_proto_select_cb(cctx, npn_select_cb, &idx);
+
+    if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl, NULL,
+                                      NULL)))
+        goto end;
+    /* Subtest 4 is the only one whose handshake is expected to fail */
+    if (idx == 4) {
+        /* We don't allow empty selection of NPN, so this should fail */
+        if (!TEST_false(create_ssl_connection(serverssl, clientssl,
+                                              SSL_ERROR_NONE)))
+            goto end;
+    } else {
+        const unsigned char *prot;
+        unsigned int protlen;
+
+        if (!TEST_true(create_ssl_connection(serverssl, clientssl,
+                                             SSL_ERROR_NONE)))
+            goto end;
+        /* Check the negotiated protocol as seen by the server */
+        SSL_get0_next_proto_negotiated(serverssl, &prot, &protlen);
+        switch (idx) {
+        case 0:
+        case 1:
+            if (!TEST_mem_eq(prot, protlen, fooprot + 1, *fooprot))
+                goto end;
+            break;
+        case 2: /* no advert, so protlen should be 0 */
+            if (!TEST_uint_eq(protlen, 0))
+                goto end;
+            break;
+        case 3:
+            if (!TEST_mem_eq(prot, protlen, barprot + 1, *barprot))
+                goto end;
+            break;
+        default:
+            TEST_error("Should not get here");
+            goto end;
+        }
+    }
+
+    testresult = 1;
+ end:
+    SSL_free(serverssl);
+    SSL_free(clientssl);
+    SSL_CTX_free(sctx);
+    SSL_CTX_free(cctx);
+
+    return testresult;
+}
+#endif /* !defined(OPENSSL_NO_TLS1_2) && !defined(OPENSSL_NO_NEXTPROTONEG) */
+
+static int alpn_select_cb2(SSL *ssl, const unsigned char **out,
+                           unsigned char *outlen, const unsigned char *in,
+                           unsigned int inlen, void *arg)
+{
+    int *idx = (int *)arg;
+    /* Server ALPN select callback; *arg is test_alpn()'s subtest number */
+    switch (*idx) {
+    case 0: /* select "foo", which the client offered */
+        *out = (unsigned char *)(fooprot + 1);
+        *outlen = *fooprot;
+        return SSL_TLSEXT_ERR_OK;
+    /* Test 2: select "bar", which the client did not offer */
+    case 2:
+        *out = (unsigned char *)(barprot + 1);
+        *outlen = *barprot;
+        return SSL_TLSEXT_ERR_OK;
+    /* Test 3: empty selection (*out is left unset) */
+    case 3:
+        *outlen = 0;
+        return SSL_TLSEXT_ERR_OK;
+    /* Test 1 offers no ALPN, so this path shouldn't be reached then */
+    default:
+    case 1:
+        return SSL_TLSEXT_ERR_ALERT_FATAL;
+    }
+    return 0; /* not reached: every case above returns */
+}
+
+/*
+ * Test the server-side ALPN select callback
+ * Test 0: client = foo, select = foo
+ * Test 1: client = <empty>, select = none (nothing negotiated)
+ * Test 2: client = foo, select = bar (should fail: bar not offered)
+ * Test 3: client = foo, select = <empty> (should fail)
+ */
+static int test_alpn(int idx)
+{
+    SSL_CTX *sctx = NULL, *cctx = NULL;
+    SSL *serverssl = NULL, *clientssl = NULL;
+    int testresult = 0;
+    const unsigned char *prots = fooprot;
+    unsigned int protslen = sizeof(fooprot);
+
+    if (!TEST_true(create_ssl_ctx_pair(libctx, TLS_server_method(),
+                                       TLS_client_method(), 0, 0,
+                                       &sctx, &cctx, cert, privkey)))
+        goto end;
+    /* The server's callback reads the subtest number through &idx */
+    SSL_CTX_set_alpn_select_cb(sctx, alpn_select_cb2, &idx);
+
+    if (!TEST_true(create_ssl_objects(sctx, cctx, &serverssl, &clientssl, NULL,
+                                      NULL)))
+        goto end;
+    /* Test 1: give the client an empty (NULL) ALPN protocol list */
+    if (idx == 1) {
+        prots = NULL;
+        protslen = 0;
+    }
+
+    /* SSL_set_alpn_protos returns 0 for success! */
+    if (!TEST_false(SSL_set_alpn_protos(clientssl, prots, protslen)))
+        goto end;
+
+    if (idx == 2 || idx == 3) {
+        /* Selecting an unoffered (bar) or empty ALPN must fail */
+        if (!TEST_false(create_ssl_connection(serverssl, clientssl,
+                                              SSL_ERROR_NONE)))
+            goto end;
+    } else {
+        const unsigned char *prot;
+        unsigned int protlen;
+
+        if (!TEST_true(create_ssl_connection(serverssl, clientssl,
+                                             SSL_ERROR_NONE)))
+            goto end;
+        /* Check the selected protocol as seen by the client */
+        SSL_get0_alpn_selected(clientssl, &prot, &protlen);
+        switch (idx) {
+        case 0:
+            if (!TEST_mem_eq(prot, protlen, fooprot + 1, *fooprot))
+                goto end;
+            break;
+        case 1: /* no ALPN offered, so protlen should be 0 */
+            if (!TEST_uint_eq(protlen, 0))
+                goto end;
+            break;
+        default:
+            TEST_error("Should not get here");
+            goto end;
+        }
+    }
+
+    testresult = 1;
+ end:
+    SSL_free(serverssl);
+    SSL_free(clientssl);
+    SSL_CTX_free(sctx);
+    SSL_CTX_free(cctx);
+
+    return testresult;
+}
+
 OPT_TEST_DECLARE_USAGE("certfile privkeyfile srpvfile tmpfile provider config dhfile\n")
 
 int setup_tests(void)
@@ -12190,6 +12415,10 @@ int setup_tests(void)
     ADD_TEST(test_data_retry);
     ADD_ALL_TESTS(test_multi_resume, 5);
     ADD_ALL_TESTS(test_select_next_proto, OSSL_NELEM(next_proto_tests));
+#if !defined(OPENSSL_NO_TLS1_2) && !defined(OPENSSL_NO_NEXTPROTONEG)
+    ADD_ALL_TESTS(test_npn, 5);
+#endif
+    ADD_ALL_TESTS(test_alpn, 4);
     return 1;
 
  err:
-- 
2.46.0