diff --git a/src/tls.c b/src/tls.c index ddadb6cf760..bd737df59c5 100644 --- a/src/tls.c +++ b/src/tls.c @@ -1707,7 +1707,9 @@ static const char *connTLSGetLastError(connection *conn_) { } static int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, int barrier) { + tls_connection *tls_conn = (tls_connection *)conn; conn->write_handler = func; + if (!func) tls_conn->flags &= ~TLS_CONN_FLAG_WRITE_WANT_READ; if (barrier) conn->flags |= CONN_FLAG_WRITE_BARRIER; else @@ -1717,7 +1719,9 @@ static int connTLSSetWriteHandler(connection *conn, ConnectionCallbackFunc func, } static int connTLSSetReadHandler(connection *conn, ConnectionCallbackFunc func) { + tls_connection *tls_conn = (tls_connection *)conn; conn->read_handler = func; + if (!func) tls_conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE; updateSSLEvent((tls_connection *)conn); return C_OK; } diff --git a/src/unit/Makefile b/src/unit/Makefile index af669d67385..3bbd83328ec 100644 --- a/src/unit/Makefile +++ b/src/unit/Makefile @@ -171,6 +171,11 @@ ifeq ($(BUILD_TLS),yes) LD_LIBS += $(TLS_LIBS) endif +# Add systemd linking if needed based on PREV_FINAL_CFLAGS +ifneq (,$(findstring -DHAVE_LIBSYSTEMD,$(PREV_FINAL_CFLAGS))) + LD_LIBS += -lsystemd +endif + # Compile C++ test files, recompile if generated_wrappers or compiler flags change %.o: %.cpp generated_wrappers.cpp .flags $(CXX) -MD -MP -std=c++17 -faligned-new -Wno-write-strings -fpermissive $(GCC_FLAGS) $(OPT) $(DEBUG) $(TEST_CFLAGS) -Wall -Wno-deprecated-declarations -c \ diff --git a/src/unit/test_tls.cpp b/src/unit/test_tls.cpp new file mode 100644 index 00000000000..4c9a1f39ec2 --- /dev/null +++ b/src/unit/test_tls.cpp @@ -0,0 +1,81 @@ +/* + * Copyright (c) Valkey Contributors + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ + +#include "generated_wrappers.hpp" +#include +#include +#include + +extern "C" { +#include "config.h" +#include "fmacros.h" +#include "connection.h" +#include "ae.h" +#include "server.h" + +/* Locally duplicate the opaque tls_connection to manipulate its flags */ +#define TLS_CONN_FLAG_WRITE_WANT_READ (1 << 1) + +typedef struct fake_tls_connection { + connection c; + int flags; +} fake_tls_connection; +} + +class TlsEventTest : public ::testing::Test { +protected: + void SetUp() override { + server.logfile = (char *)""; + server.el = aeCreateEventLoop(1024); + ASSERT_NE(server.el, nullptr); + connTypeInitialize(); + } + + void TearDown() override { + if (server.el) { + aeDeleteEventLoop(server.el); + server.el = nullptr; + } + connTypeCleanupAll(); + } +}; + +TEST_F(TlsEventTest, BusyLoopClearance) { + int fds[2]; + ASSERT_EQ(pipe(fds), 0); + + ConnectionType *ct = connectionTypeTls(); + if (!ct) { + close(fds[0]); + close(fds[1]); + GTEST_SKIP() << "TLS not supported in this build"; + return; + } + + fake_tls_connection *conn = (fake_tls_connection *)zcalloc(sizeof(fake_tls_connection)); + conn->c.type = ct; + conn->c.fd = fds[0]; + conn->c.state = CONN_STATE_CONNECTED; + + /* 1. Set Want Read */ + conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ; + conn->c.type->set_write_handler(&conn->c, (ConnectionCallbackFunc)0xdeadbeef, 0); + + int mask = aeGetFileEvents(server.el, conn->c.fd); + ASSERT_NE(mask & AE_READABLE, 0); + + /* 2. High level clears */ + conn->c.type->set_write_handler(&conn->c, nullptr, 0); + + mask = aeGetFileEvents(server.el, conn->c.fd); + ASSERT_EQ(mask & AE_READABLE, 0); + + aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE); + aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE); + close(fds[0]); + close(fds[1]); + zfree(conn); +} diff --git a/src/valkey-cli.c b/src/valkey-cli.c index b746f4f344c..b62243529ad 100644 --- a/src/valkey-cli.c +++ b/src/valkey-cli.c @@ -91,6 +91,8 @@ #define CLI_RCFILE_DEFAULT ".valkeyclirc" #define CLI_AUTH_ENV "VALKEYCLI_AUTH" #define OLD_CLI_AUTH_ENV "REDISCLI_AUTH" +#define CLI_HOST_ENV "VALKEYCLI_HOST" +#define CLI_PORT_ENV "VALKEYCLI_PORT" #define CLI_CLUSTER_YES_ENV "VALKEYCLI_CLUSTER_YES" #define OLD_CLI_CLUSTER_YES_ENV "REDISCLI_CLUSTER_YES" @@ -2929,15 +2931,23 @@ static int parseOptions(int argc, char **argv) { return i; } +/* Reads environment variables and overrides the global configuration */ static void parseEnv(void) { - /* Set auth from env, but do not overwrite CLI arguments if passed */ char *auth = getenv(CLI_AUTH_ENV); if (auth == NULL) { auth = getenv(OLD_CLI_AUTH_ENV); } - if (auth != NULL && config.conn_info.auth == NULL) { + if (auth != NULL) { config.conn_info.auth = auth; } + char *host = getenv(CLI_HOST_ENV); + if (host != NULL) { + config.conn_info.hostip = sdsnew(host); + } + char *port = getenv(CLI_PORT_ENV); + if (port != NULL) { + config.conn_info.hostport = atoi(port); + } /* Check for cluster yes flag with fallback to legacy env variable */ char *cluster_yes = getenv(CLI_CLUSTER_YES_ENV); @@ -10043,6 +10053,7 @@ int main(int argc, char **argv) { int firstarg; struct timeval tv; + /* Valkey defaults */ memset(&config.sslconfig, 0, sizeof(config.sslconfig)); config.ct = VALKEY_CONN_TCP; config.conn_info.hostip = sdsnew("127.0.0.1"); @@ -10134,12 +10145,14 @@ int main(int argc, char **argv) { config.mb_delim = sdsnew("\n"); config.cmd_delim = sdsnew("\n"); + /* Override configuration based on environment variables */ + parseEnv(); + + /* Override configuration based on explicit command-line arguments */ firstarg = parseOptions(argc, argv); argc -= firstarg; argv += firstarg; - parseEnv(); - if (config.askpass) { config.conn_info.auth = askPassword("Please input password: "); }