From 88ca5bab81ac55ebfffd719de49cf041615c5597 Mon Sep 17 00:00:00 2001 From: John Safranek Date: Thu, 16 Apr 2026 17:17:39 -0700 Subject: [PATCH 1/2] First packet follows check needs pubkey guess When processing the KEX Init message, stash guesses for the peer's KEX and public key algorithms. When reading first_packet_follows, if set check the guesses and set the handshake info flag ignoreNextKexMsg. When processing the KexDhInit message, check that flag. Affected functions: DoKexInit, DoKexDhInit. Issue: F-1686 --- src/internal.c | 25 +++++++++++++++++-------- wolfssh/internal.h | 3 +-- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/internal.c b/src/internal.c index b13ff4426..f1f4c9501 100644 --- a/src/internal.c +++ b/src/internal.c @@ -571,6 +571,7 @@ static HandshakeInfo* HandshakeInfoNew(void* heap) heap, DYNTYPE_HS); if (newHs != NULL) { WMEMSET(newHs, 0, sizeof(HandshakeInfo)); + newHs->expectMsgId = MSGID_NONE; newHs->kexId = ID_NONE; newHs->kexHashId = WC_HASH_TYPE_NONE; newHs->pubKeyId = ID_NONE; @@ -4238,6 +4239,9 @@ static int DoKexInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx) byte algoId; byte list[24] = {ID_NONE}; byte cannedList[24] = {ID_NONE}; + byte kexIdGuess = ID_NONE; + byte pubKeyIdGuess = ID_NONE; + byte kexPacketFollows = 0; word32 listSz; word32 cannedListSz; word32 cannedAlgoNamesSz; @@ -4309,7 +4313,7 @@ static int DoKexInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx) (const byte*)ssh->algoListKex, cannedAlgoNamesSz); } if (ret == WS_SUCCESS) { - ssh->handshake->kexIdGuess = list[0]; + kexIdGuess = list[0]; algoId = MatchIdLists(side, list, listSz, cannedList, cannedListSz); if (algoId == ID_UNKNOWN) { @@ -4354,6 +4358,7 @@ static int DoKexInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx) } } if (ret == WS_SUCCESS) { + pubKeyIdGuess = list[0]; algoId = MatchIdLists(side, list, listSz, cannedList, cannedListSz); if (algoId == ID_UNKNOWN) { WLOG(WS_LOG_DEBUG, "Unable to negotiate Server Host Key Algo"); @@ -4511,10 +4516,15 @@ static int DoKexInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx) /* First KEX Packet Follows */ if (ret == WS_SUCCESS) { WLOG(WS_LOG_DEBUG, "DKI: KEX Packet Follows"); - ret = GetBoolean(&ssh->handshake->kexPacketFollows, buf, len, &begin); + ret = GetBoolean(&kexPacketFollows, buf, len, &begin); if (ret == WS_SUCCESS) { WLOG(WS_LOG_DEBUG, " packet follows: %s", - ssh->handshake->kexPacketFollows ? "yes" : "no"); + kexPacketFollows ? "yes" : "no"); + if (kexPacketFollows + && (kexIdGuess != ssh->handshake->kexId + || pubKeyIdGuess != ssh->handshake->pubKeyId)) { + ssh->handshake->ignoreNextKexMsg = 1; + } } } @@ -4819,12 +4829,11 @@ static int DoKexDhInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx) ret = WS_BAD_ARGUMENT; if (ret == WS_SUCCESS) { - if (ssh->handshake->kexPacketFollows - && ssh->handshake->kexIdGuess != ssh->handshake->kexId) { - + if (ssh->handshake->ignoreNextKexMsg) { /* skip this message. */ - WLOG(WS_LOG_DEBUG, "Skipping the client's KEX init function."); - ssh->handshake->kexPacketFollows = 0; + WLOG(WS_LOG_DEBUG, "Skipping client's KEXDH_INIT message due to " + "first_packet_follows guess mismatch."); + ssh->handshake->ignoreNextKexMsg = 0; *idx += len; return WS_SUCCESS; } diff --git a/wolfssh/internal.h b/wolfssh/internal.h index 9d204137f..d5c46c3d1 100644 --- a/wolfssh/internal.h +++ b/wolfssh/internal.h @@ -628,12 +628,10 @@ typedef struct Keys { typedef struct HandshakeInfo { byte expectMsgId; byte kexId; - byte kexIdGuess; byte kexHashId; byte pubKeyId; byte encryptId; byte macId; - byte kexPacketFollows; byte aeadMode; byte blockSz; @@ -660,6 +658,7 @@ typedef struct HandshakeInfo { word32 generatorSz; #endif + byte ignoreNextKexMsg:1; byte useDh:1; byte useEcc:1; byte useEccMlKem:1; From ed1402a0719229e3947b7a0107ffec8e595cd8b3 Mon Sep 17 00:00:00 2001 From: John Safranek Date: Fri, 17 Apr 2026 15:18:07 -0700 Subject: [PATCH 2/2] First Kex Packet Follows Test Add a regression for checking the `first_kex_packet_follows` flag versus the guesses for KEX algorithm and public key algorithm. --- src/internal.c | 33 +++++++++ tests/regress.c | 166 +++++++++++++++++++++++++++++++++++++++++++++ wolfssh/internal.h | 6 ++ 3 files changed, 205 insertions(+) diff --git a/src/internal.c b/src/internal.c index f1f4c9501..69986c6f7 100644 --- a/src/internal.c +++ b/src/internal.c @@ -858,6 +858,30 @@ int wolfSSH_TestIsMessageAllowed(WOLFSSH* ssh, byte msg, byte state) { return IsMessageAllowed(ssh, msg, state); } + +static int DoKexInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx); +static int DoKexDhInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx); +#ifndef WOLFSSH_NO_DH_GEX_SHA256 +static int DoKexDhGexRequest(WOLFSSH* ssh, byte* buf, word32 len, word32* idx); +#endif + +int wolfSSH_TestDoKexInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx) +{ + return DoKexInit(ssh, buf, len, idx); +} + +int wolfSSH_TestDoKexDhInit(WOLFSSH* ssh, byte* buf, word32 len, word32* idx) +{ + return DoKexDhInit(ssh, buf, len, idx); +} + +#ifndef WOLFSSH_NO_DH_GEX_SHA256 +int wolfSSH_TestDoKexDhGexRequest(WOLFSSH* ssh, byte* buf, word32 len, + word32* idx) +{ + return DoKexDhGexRequest(ssh, buf, len, idx); +} +#endif #endif @@ -6272,6 +6296,15 @@ static int DoKexDhGexRequest(WOLFSSH* ssh, ret = WS_BAD_ARGUMENT; if (ret == WS_SUCCESS) { + if (ssh->handshake->ignoreNextKexMsg) { + /* skip this message. */ + WLOG(WS_LOG_DEBUG, "Skipping client's KEXDH_GEX_REQUEST message " + "due to first_packet_follows guess mismatch."); + ssh->handshake->ignoreNextKexMsg = 0; + *idx += len; + return WS_SUCCESS; + } + begin = *idx; ret = GetUint32(&ssh->handshake->dhGexMinSz, buf, len, &begin); } diff --git a/tests/regress.c b/tests/regress.c index bf37202d1..88d2e814e 100644 --- a/tests/regress.c +++ b/tests/regress.c @@ -1886,6 +1886,167 @@ static void TestKeyboardResponseNullCtx(WOLFSSH* ssh) #endif /* WOLFSSH_KEYBOARD_INTERACTIVE */ +#if !defined(WOLFSSH_NO_ECDH_SHA2_NISTP256) \ + && !defined(WOLFSSH_NO_RSA) \ + && !defined(WOLFSSH_NO_CURVE25519_SHA256) \ + && !defined(WOLFSSH_NO_RSA_SHA2_256) + +#define FPF_KEX_GOOD "ecdh-sha2-nistp256" +#define FPF_KEX_BAD "curve25519-sha256" +#define FPF_KEY_GOOD "ssh-rsa" +#define FPF_KEY_BAD "rsa-sha2-256" + +/* Build a KEXINIT payload using the server ssh's own canned cipher/MAC lists + * so negotiation succeeds whichever AES/HMAC modes are compiled in. */ +static word32 BuildKexInitPayload(WOLFSSH* ssh, const char* kexList, + const char* keyList, byte firstPacketFollows, + byte* out, word32 outSz) +{ + word32 idx = 0; + + /* cookie */ + AssertTrue(idx + COOKIE_SZ <= outSz); + WMEMSET(out + idx, 0, COOKIE_SZ); + idx += COOKIE_SZ; + + idx = AppendString(out, outSz, idx, kexList); + idx = AppendString(out, outSz, idx, keyList); + idx = AppendString(out, outSz, idx, ssh->algoListCipher); + idx = AppendString(out, outSz, idx, ssh->algoListCipher); + idx = AppendString(out, outSz, idx, ssh->algoListMac); + idx = AppendString(out, outSz, idx, ssh->algoListMac); + idx = AppendString(out, outSz, idx, "none"); + idx = AppendString(out, outSz, idx, "none"); + idx = AppendString(out, outSz, idx, ""); + idx = AppendString(out, outSz, idx, ""); + + idx = AppendByte(out, outSz, idx, firstPacketFollows); + idx = AppendUint32(out, outSz, idx, 0); /* reserved */ + + return idx; +} + +typedef struct { + const char* description; + const char* kexList; + const char* keyList; + byte firstPacketFollows; + byte expectIgnore; +} FirstPacketFollowsCase; + +static const FirstPacketFollowsCase firstPacketFollowsCases[] = { + { "follows=0, guesses irrelevant: flag stays off", + FPF_KEX_BAD "," FPF_KEX_GOOD, FPF_KEY_BAD "," FPF_KEY_GOOD, 0, 0 }, + { "follows=1, both guesses match: do not skip", + FPF_KEX_GOOD, FPF_KEY_GOOD, 1, 0 }, + { "follows=1, KEX guess wrong: skip", + FPF_KEX_BAD "," FPF_KEX_GOOD, FPF_KEY_GOOD, 1, 1 }, + { "follows=1, host-key guess wrong: skip", /* regression case */ + FPF_KEX_GOOD, FPF_KEY_BAD "," FPF_KEY_GOOD, 1, 1 }, + { "follows=1, both guesses wrong: skip", + FPF_KEX_BAD "," FPF_KEX_GOOD, FPF_KEY_BAD "," FPF_KEY_GOOD, 1, 1 }, +}; + +static void RunFirstPacketFollowsCase(const FirstPacketFollowsCase* tc) +{ + WOLFSSH_CTX* ctx; + WOLFSSH* ssh; + byte payload[512]; + word32 payloadSz; + word32 idx = 0; + + ctx = wolfSSH_CTX_new(WOLFSSH_ENDPOINT_SERVER, NULL); + AssertNotNull(ctx); + + ssh = wolfSSH_new(ctx); + AssertNotNull(ssh); + + AssertIntEQ(wolfSSH_SetAlgoListKex(ssh, FPF_KEX_GOOD), WS_SUCCESS); + AssertIntEQ(wolfSSH_SetAlgoListKey(ssh, FPF_KEY_GOOD), WS_SUCCESS); + + payloadSz = BuildKexInitPayload(ssh, tc->kexList, tc->keyList, + tc->firstPacketFollows, payload, sizeof(payload)); + + /* DoKexInit's tail hashes and sends a response; on a stripped-down + * WOLFSSH without a loaded host key or a primed peer proto id, that + * tail errors. We only care about the parse path up through + * first_packet_follows, where ignoreNextKexMsg is set. */ + (void)wolfSSH_TestDoKexInit(ssh, payload, payloadSz, &idx); + + AssertNotNull(ssh->handshake); + if (ssh->handshake->ignoreNextKexMsg != tc->expectIgnore) { + Fail(("ignoreNextKexMsg == %u (%s)", + tc->expectIgnore, tc->description), + ("%u", ssh->handshake->ignoreNextKexMsg)); + } + + wolfSSH_free(ssh); + wolfSSH_CTX_free(ctx); +} + +typedef int (*FirstPacketFollowsSkipFn)(WOLFSSH* ssh, byte* buf, word32 len, + word32* idx); + +/* With ignoreNextKexMsg set, the target Do* handler must consume the packet, + * clear the flag, and not advance clientState past CLIENT_KEXINIT_DONE. */ +static void RunFirstPacketFollowsSkipCase(FirstPacketFollowsSkipFn fn, + const char* label) +{ + WOLFSSH_CTX* ctx; + WOLFSSH* ssh; + byte payload[8]; + word32 idx = 0; + int ret; + + ctx = wolfSSH_CTX_new(WOLFSSH_ENDPOINT_SERVER, NULL); + AssertNotNull(ctx); + + ssh = wolfSSH_new(ctx); + AssertNotNull(ssh); + AssertNotNull(ssh->handshake); + + ssh->handshake->ignoreNextKexMsg = 1; + ssh->clientState = CLIENT_KEXINIT_DONE; + + /* Garbage payload — must never be parsed when skipped. */ + WMEMSET(payload, 0xAB, sizeof(payload)); + + ret = fn(ssh, payload, sizeof(payload), &idx); + if (ret != WS_SUCCESS) { + Fail(("%s returns WS_SUCCESS when skipping", label), ("%d", ret)); + } + AssertIntEQ(idx, sizeof(payload)); + AssertIntEQ(ssh->handshake->ignoreNextKexMsg, 0); + AssertIntEQ(ssh->clientState, CLIENT_KEXINIT_DONE); + + wolfSSH_free(ssh); + wolfSSH_CTX_free(ctx); +} + +static void TestFirstPacketFollowsSkipped(void) +{ + RunFirstPacketFollowsSkipCase(wolfSSH_TestDoKexDhInit, "DoKexDhInit"); +#ifndef WOLFSSH_NO_DH_GEX_SHA256 + RunFirstPacketFollowsSkipCase(wolfSSH_TestDoKexDhGexRequest, + "DoKexDhGexRequest"); +#endif +} + +static void TestFirstPacketFollows(void) +{ + size_t i; + size_t n = sizeof(firstPacketFollowsCases) + / sizeof(firstPacketFollowsCases[0]); + + for (i = 0; i < n; i++) { + RunFirstPacketFollowsCase(&firstPacketFollowsCases[i]); + } + TestFirstPacketFollowsSkipped(); +} + +#endif /* first_packet_follows coverage guard */ + + int main(int argc, char** argv) { WOLFSSH_CTX* ctx; @@ -1926,6 +2087,11 @@ int main(int argc, char** argv) TestAgentChannelNullAgentSendsOpenFail(); #endif TestKexInitRejectedWhenKeying(ssh); +#if !defined(WOLFSSH_NO_ECDH_SHA2_NISTP256) && !defined(WOLFSSH_NO_RSA) \ + && !defined(WOLFSSH_NO_CURVE25519_SHA256) \ + && !defined(WOLFSSH_NO_RSA_SHA2_256) + TestFirstPacketFollows(); +#endif TestDisconnectSetsDisconnectError(); TestClientBuffersIdempotent(); TestPasswordEofNoCrash(); diff --git a/wolfssh/internal.h b/wolfssh/internal.h index d5c46c3d1..07e71eb5b 100644 --- a/wolfssh/internal.h +++ b/wolfssh/internal.h @@ -1326,7 +1326,13 @@ enum WS_MessageIdLimits { WOLFSSH_API int wolfSSH_TestIsMessageAllowed(WOLFSSH* ssh, byte msg, byte state); WOLFSSH_API int wolfSSH_TestDoReceive(WOLFSSH* ssh); + WOLFSSH_API int wolfSSH_TestDoKexInit(WOLFSSH* ssh, byte* buf, + word32 len, word32* idx); + WOLFSSH_API int wolfSSH_TestDoKexDhInit(WOLFSSH* ssh, byte* buf, + word32 len, word32* idx); #ifndef WOLFSSH_NO_DH_GEX_SHA256 + WOLFSSH_API int wolfSSH_TestDoKexDhGexRequest(WOLFSSH* ssh, byte* buf, + word32 len, word32* idx); WOLFSSH_API int wolfSSH_TestValidateKexDhGexGroup(const byte* primeGroup, word32 primeGroupSz, const byte* generator, word32 generatorSz, word32 minBits, word32 maxBits, WC_RNG* rng);