diff --git a/api/contrib/envoy/extensions/filters/network/mysql_proxy/v3/mysql_proxy.proto b/api/contrib/envoy/extensions/filters/network/mysql_proxy/v3/mysql_proxy.proto index f3f2cefdc372d..14e832163fce3 100644 --- a/api/contrib/envoy/extensions/filters/network/mysql_proxy/v3/mysql_proxy.proto +++ b/api/contrib/envoy/extensions/filters/network/mysql_proxy/v3/mysql_proxy.proto @@ -16,10 +16,34 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; // MySQL Proxy :ref:`configuration overview `. // [#extension: envoy.filters.network.mysql_proxy] +// [#next-free-field: 4] message MySQLProxy { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.mysql_proxy.v1alpha1.MySQLProxy"; + // Downstream SSL operational modes. + enum SSLMode { + // Do not terminate SSL session initiated by a client. + // The MySQL proxy filter will pass all encrypted and unencrypted packets to the upstream server. + DISABLE = 0; + + // The MySQL proxy filter will terminate SSL session initiated by a client + // and close downstream connections that do not initiate SSL. + // The filter will mediate ``caching_sha2_password`` RSA authentication when + // the upstream MySQL server requires full authentication over the plaintext connection. + // The filter chain must use :ref:`starttls transport socket + // `. + REQUIRE = 1; + + // The MySQL proxy filter will accept downstream client's encryption settings. + // If the client wants to use clear-text, Envoy will not enforce SSL encryption. + // If the client wants to use encryption, Envoy will terminate SSL and mediate + // ``caching_sha2_password`` RSA authentication when needed. + // The filter chain must use :ref:`starttls transport socket + // `. + ALLOW = 2; + } + // The human readable prefix to use when emitting :ref:`statistics // `. string stat_prefix = 1 [(validate.rules).string = {min_len: 1}]; @@ -27,4 +51,15 @@ message MySQLProxy { // [#not-implemented-hide:] The optional path to use for writing MySQL access logs. // If the access log field is empty, access logs will not be written. string access_log = 2; + + // Controls whether to terminate SSL sessions initiated by downstream clients. + // If enabled, the filter chain must use + // :ref:`starttls transport socket `. + // Defaults to ``DISABLE``. + SSLMode downstream_ssl = 3; + + // TODO: Add upstream_ssl (SSLMode) to support encrypting the connection from + // Envoy to the upstream MySQL server, similar to PostgresProxy.upstream_ssl. + // When implemented, the corresponding cluster should use + // :ref:`starttls transport socket `. } diff --git a/bazel/deps.yaml b/bazel/deps.yaml index 714d766bdffd2..804f5ab8d215e 100644 --- a/bazel/deps.yaml +++ b/bazel/deps.yaml @@ -257,6 +257,18 @@ libprotobuf_mutator: license: "Apache-2.0" license_url: "https://github.com/google/libprotobuf-mutator/blob/v{version}/LICENSE" +libsodium: + project_name: "libsodium" + project_desc: "A modern, portable, easy-to-use crypto library for secure memory and zeroing" + project_url: "https://github.com/jedisct1/libsodium" + release_date: "2024-08-23" + use_category: + - other + extensions: + - envoy.filters.network.mysql_proxy + license: "ISC" + license_url: "https://github.com/jedisct1/libsodium/blob/{version}/LICENSE" + libsxg: project_name: "libsxg" project_desc: "Signed HTTP Exchange library" diff --git a/bazel/foreign_cc/BUILD b/bazel/foreign_cc/BUILD index 7cd494b368ad5..6f7a67fa79ae8 100644 --- a/bazel/foreign_cc/BUILD +++ b/bazel/foreign_cc/BUILD @@ -219,6 +219,27 @@ cc_library( deps = ["unicode_icu_build"], ) +configure_make( + name = "libsodium", + configure_in_place = True, + configure_options = [ + "--disable-shared", + "--enable-static", + "--disable-asm", + "--without-pthreads", + ], + env = select({ + "//bazel:clang_build": { + "AR": "$(AR)", + "RANLIB": "$(AR) -s", + }, + "//conditions:default": {}, + }), + lib_source = "@libsodium//:all", + out_static_libs = ["libsodium.a"], + tags = ["skip_on_windows"], +) + envoy_cmake( name = "libsxg", build_args = select({ diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 267761ef27c28..d00536ad5699b 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -167,6 +167,7 @@ def envoy_dependencies(skip_targets = []): _aws_lc() _aws_c_auth_testdata() + _libsodium() _liburing() _com_github_bazel_buildtools() _c_ares() @@ -323,6 +324,12 @@ def _aws_c_auth_testdata(): build_file = "@envoy//bazel/external:aws-c-auth.BUILD", ) +def _libsodium(): + external_http_archive( + name = "libsodium", + build_file_content = BUILD_ALL_CONTENT, + ) + def _liburing(): external_http_archive( name = "liburing", diff --git a/bazel/repository_locations.bzl b/bazel/repository_locations.bzl index a26fd720d9bfa..8e44951a11f4a 100644 --- a/bazel/repository_locations.bzl +++ b/bazel/repository_locations.bzl @@ -96,6 +96,12 @@ REPOSITORY_LOCATIONS_SPEC = dict( strip_prefix = "aws-c-auth-{version}", urls = ["https://github.com/awslabs/aws-c-auth/archive/refs/tags/v{version}.tar.gz"], ), + libsodium = dict( + version = "1.0.20", + sha256 = "ebb65ef6ca439333c2bb41a0c1990587288da07f6c7fd07cb3a18cc18d30ce19", + strip_prefix = "libsodium-{version}", + urls = ["https://download.libsodium.org/libsodium/releases/libsodium-{version}.tar.gz"], + ), liburing = dict( version = "2.13", sha256 = "618e34dbea408fc9e33d7c4babd746036dbdedf7fce2496b1178ced0f9b5b357", diff --git a/changelogs/current.yaml b/changelogs/current.yaml index da3995e73ff95..073b2687384c4 100644 --- a/changelogs/current.yaml +++ b/changelogs/current.yaml @@ -722,6 +722,15 @@ new_features: Added :ref:`set() ` to the Lua filter state API, allowing Lua scripts to create and store filter state objects dynamically using registered object factories. +- area: mysql_proxy + change: | + Added SSL termination support to the MySQL proxy filter with RSA-mediated ``caching_sha2_password`` + authentication. The filter can now terminate downstream TLS connections using the + :ref:`starttls transport socket ` + and transparently mediate MySQL 8.0+ ``caching_sha2_password`` full authentication by performing + RSA public key exchange on behalf of the client. Added a new + :ref:`downstream_ssl ` + config option with ``DISABLE``, ``REQUIRE``, and ``ALLOW`` modes. - area: http change: | Fixed an issue where filter chain execution could continue on HTTP streams that had been reset but not yet diff --git a/contrib/mysql_proxy/filters/network/source/BUILD b/contrib/mysql_proxy/filters/network/source/BUILD index 5bb9bf73dc62a..8f4a8346a3db8 100644 --- a/contrib/mysql_proxy/filters/network/source/BUILD +++ b/contrib/mysql_proxy/filters/network/source/BUILD @@ -27,8 +27,10 @@ envoy_cc_library( "//envoy/server:filter_config_interface", "//envoy/stats:stats_interface", "//envoy/stats:stats_macros", + "//source/common/crypto:utility_lib", "//source/common/network:filter_lib", "//source/extensions/filters/network:well_known_names", + "@envoy_api//contrib/envoy/extensions/filters/network/mysql_proxy/v3:pkg_cc_proto", "@envoy_api//envoy/config/core/v3:pkg_cc_proto", ], ) @@ -80,6 +82,7 @@ envoy_cc_library( hdrs = ["mysql_utils.h"], deps = [ ":codec_interface", + "//bazel/foreign_cc:libsodium", "//source/common/buffer:buffer_lib", ], ) diff --git a/contrib/mysql_proxy/filters/network/source/mysql_codec.h b/contrib/mysql_proxy/filters/network/source/mysql_codec.h index d040ab3fb6dbd..af6b3cc425455 100644 --- a/contrib/mysql_proxy/filters/network/source/mysql_codec.h +++ b/contrib/mysql_proxy/filters/network/source/mysql_codec.h @@ -33,6 +33,10 @@ constexpr uint8_t MYSQL_RESP_MORE = 0x01; constexpr uint8_t MYSQL_RESP_AUTH_SWITCH = 0xfe; constexpr uint8_t MYSQL_RESP_ERR = 0xff; +constexpr uint8_t MYSQL_CACHINGSHA2_FAST_AUTH_SUCCESS = 0x03; +constexpr uint8_t MYSQL_CACHINGSHA2_FULL_AUTH_REQUIRED = 0x04; +constexpr uint8_t MYSQL_REQUEST_PUBLIC_KEY = 0x02; + constexpr uint8_t EOF_MARKER = 0xfe; constexpr uint8_t ERR_MARKER = 0xff; @@ -102,6 +106,7 @@ constexpr uint8_t LENENCODINT_3BYTES = 0xfd; constexpr uint8_t LENENCODINT_8BYTES = 0xfe; constexpr uint32_t DEFAULT_MAX_PACKET_SIZE = (1 << 24) - 1; // 16M-1 +constexpr uint32_t SSL_CONNECTION_REQUEST_PACKET_SIZE = 32; constexpr uint8_t MIN_PROTOCOL_VERSION = 10; constexpr char MYSQL_STR_END = '\0'; diff --git a/contrib/mysql_proxy/filters/network/source/mysql_codec_clogin.cc b/contrib/mysql_proxy/filters/network/source/mysql_codec_clogin.cc index a2b6e3664cb07..602575b6e9cf5 100644 --- a/contrib/mysql_proxy/filters/network/source/mysql_codec_clogin.cc +++ b/contrib/mysql_proxy/filters/network/source/mysql_codec_clogin.cc @@ -71,7 +71,7 @@ DecodeStatus ClientLogin::parseMessage(Buffer::Instance& buffer, uint32_t len) { return DecodeStatus::Failure; } setBaseClientCap(base_cap); - if (base_cap & CLIENT_SSL) { + if (len == SSL_CONNECTION_REQUEST_PACKET_SIZE && (base_cap & CLIENT_SSL)) { return parseResponseSsl(buffer); } if (base_cap & CLIENT_PROTOCOL_41) { diff --git a/contrib/mysql_proxy/filters/network/source/mysql_config.cc b/contrib/mysql_proxy/filters/network/source/mysql_config.cc index c5e3b32d66ec1..1292a14a81ce7 100644 --- a/contrib/mysql_proxy/filters/network/source/mysql_config.cc +++ b/contrib/mysql_proxy/filters/network/source/mysql_config.cc @@ -28,8 +28,8 @@ NetworkFilters::MySQLProxy::MySQLConfigFactory::createFilterFactoryFromProtoType const std::string stat_prefix = fmt::format("mysql.{}", proto_config.stat_prefix()); - MySQLFilterConfigSharedPtr filter_config( - std::make_shared(stat_prefix, context.scope())); + MySQLFilterConfigSharedPtr filter_config(std::make_shared( + stat_prefix, context.scope(), proto_config.downstream_ssl())); return [filter_config](Network::FilterManager& filter_manager) -> void { filter_manager.addFilter(std::make_shared(filter_config)); }; diff --git a/contrib/mysql_proxy/filters/network/source/mysql_decoder.h b/contrib/mysql_proxy/filters/network/source/mysql_decoder.h index 491cb93a49d26..60ad85c33c54d 100644 --- a/contrib/mysql_proxy/filters/network/source/mysql_decoder.h +++ b/contrib/mysql_proxy/filters/network/source/mysql_decoder.h @@ -7,6 +7,7 @@ #include "contrib/mysql_proxy/filters/network/source/mysql_codec_greeting.h" #include "contrib/mysql_proxy/filters/network/source/mysql_codec_switch_resp.h" #include "contrib/mysql_proxy/filters/network/source/mysql_session.h" +#include "contrib/mysql_proxy/filters/network/source/mysql_utils.h" namespace Envoy { namespace Extensions { @@ -21,14 +22,15 @@ class DecoderCallbacks { virtual ~DecoderCallbacks() = default; virtual void onProtocolError() PURE; - virtual void onNewMessage(MySQLSession::State) PURE; virtual void onServerGreeting(ServerGreeting&) PURE; - virtual void onClientLogin(ClientLogin&) PURE; + virtual void onClientLogin(ClientLogin&, MySQLSession::State) PURE; virtual void onClientLoginResponse(ClientLoginResponse&) PURE; virtual void onClientSwitchResponse(ClientSwitchResponse&) PURE; virtual void onMoreClientLoginResponse(ClientLoginResponse&) PURE; virtual void onCommand(Command&) PURE; virtual void onCommandResponse(CommandResponse&) PURE; + virtual void onAuthSwitchMoreClientData(std::unique_ptr data) PURE; + virtual bool onSSLRequest() PURE; }; /** @@ -38,16 +40,33 @@ class Decoder { public: virtual ~Decoder() = default; - virtual void onData(Buffer::Instance& data) PURE; + enum class Result { + ReadyForNext, // Decoder processed previous message and is ready for the next message. + Stopped // Received and processed message disrupts the current flow. Decoder stopped accepting + // data. This happens when decoder wants filter to perform some action, for example to + // call starttls transport socket to enable TLS. + }; + + struct PayloadMetadata { + uint8_t seq; + uint32_t len; + }; + + virtual Result onData(Buffer::Instance& data, bool is_upstream) PURE; virtual MySQLSession& getSession() PURE; const Extensions::Common::SQLUtils::SQLUtils::DecoderAttributes& getAttributes() const { return attributes_; } + const std::vector& getPayloadMetadataList() const { + return payload_metadata_list_; + } + protected: // Decoder attributes. Extensions::Common::SQLUtils::SQLUtils::DecoderAttributes attributes_; + std::vector payload_metadata_list_{}; }; using DecoderPtr = std::unique_ptr; diff --git a/contrib/mysql_proxy/filters/network/source/mysql_decoder_impl.cc b/contrib/mysql_proxy/filters/network/source/mysql_decoder_impl.cc index 36d28b7688093..7175985c33ffa 100644 --- a/contrib/mysql_proxy/filters/network/source/mysql_decoder_impl.cc +++ b/contrib/mysql_proxy/filters/network/source/mysql_decoder_impl.cc @@ -11,8 +11,10 @@ namespace Extensions { namespace NetworkFilters { namespace MySQLProxy { -void DecoderImpl::parseMessage(Buffer::Instance& message, uint8_t seq, uint32_t len) { - ENVOY_LOG(trace, "mysql_proxy: parsing message, seq {}, len {}", seq, len); +void DecoderImpl::parseMessage(Buffer::Instance& message, uint8_t seq, uint32_t len, + bool is_upstream) { + ENVOY_LOG(trace, "mysql_proxy: parsing message, seq {}, len {}, is_upstream {}", seq, len, + is_upstream); // Run the MySQL state machine switch (session_.getState()) { case MySQLSession::State::Init: { @@ -27,14 +29,15 @@ void DecoderImpl::parseMessage(Buffer::Instance& message, uint8_t seq, uint32_t // Process Client Handshake Response ClientLogin client_login{}; client_login.decode(message, seq, len); - if (client_login.isSSLRequest()) { + + if (len == SSL_CONNECTION_REQUEST_PACKET_SIZE && client_login.isSSLRequest()) { session_.setState(MySQLSession::State::SslPt); } else if (client_login.isResponse41()) { session_.setState(MySQLSession::State::ChallengeResp41); } else { session_.setState(MySQLSession::State::ChallengeResp320); } - callbacks_.onClientLogin(client_login); + callbacks_.onClientLogin(client_login, session_.getState()); break; } case MySQLSession::State::SslPt: @@ -48,6 +51,7 @@ void DecoderImpl::parseMessage(Buffer::Instance& message, uint8_t seq, uint32_t session_.setState(MySQLSession::State::NotHandled); break; } + ENVOY_LOG(trace, "mysql_proxy: ChallengeResp resp_code is {}", resp_code); std::unique_ptr msg; MySQLSession::State state = MySQLSession::State::NotHandled; switch (resp_code) { @@ -55,7 +59,7 @@ void DecoderImpl::parseMessage(Buffer::Instance& message, uint8_t seq, uint32_t msg = std::make_unique(); state = MySQLSession::State::Req; // reset seq# when entering the REQ state - session_.setExpectedSeq(MYSQL_REQUEST_PKT_NUM); + session_.resetSeq(); break; } case MYSQL_RESP_AUTH_SWITCH: { @@ -70,6 +74,7 @@ void DecoderImpl::parseMessage(Buffer::Instance& message, uint8_t seq, uint32_t } case MYSQL_RESP_MORE: { msg = std::make_unique(); + state = MySQLSession::State::AuthSwitchMore; break; } default: @@ -92,7 +97,12 @@ void DecoderImpl::parseMessage(Buffer::Instance& message, uint8_t seq, uint32_t case MySQLSession::State::AuthSwitchMore: { uint8_t resp_code; - if (BufferHelper::peekUint8(message, resp_code) != DecodeStatus::Success) { + if (is_upstream) { + std::unique_ptr secure_data; + BufferHelper::readSecureBytes(message, len, secure_data); + callbacks_.onAuthSwitchMoreClientData(std::move(secure_data)); + break; + } else if (BufferHelper::peekUint8(message, resp_code) != DecodeStatus::Success) { session_.setState(MySQLSession::State::NotHandled); break; } @@ -102,19 +112,19 @@ void DecoderImpl::parseMessage(Buffer::Instance& message, uint8_t seq, uint32_t case MYSQL_RESP_OK: { msg = std::make_unique(); state = MySQLSession::State::Req; - session_.setExpectedSeq(MYSQL_REQUEST_PKT_NUM); + session_.resetSeq(); break; } case MYSQL_RESP_MORE: { msg = std::make_unique(); - state = MySQLSession::State::AuthSwitchResp; + state = MySQLSession::State::AuthSwitchMore; break; } case MYSQL_RESP_ERR: { msg = std::make_unique(); // stop parsing auth req/response, attempt to resync in command state state = MySQLSession::State::Resync; - session_.setExpectedSeq(MYSQL_REQUEST_PKT_NUM); + session_.resetSeq(); break; } case MYSQL_RESP_AUTH_SWITCH: { @@ -165,59 +175,84 @@ void DecoderImpl::parseMessage(Buffer::Instance& message, uint8_t seq, uint32_t static_cast(session_.getState())); } -bool DecoderImpl::decode(Buffer::Instance& data) { +bool DecoderImpl::decode(Buffer::Instance& data, bool is_upstream) { ENVOY_LOG(trace, "mysql_proxy: decoding {} bytes", data.length()); uint32_t len = 0; uint8_t seq = 0; + bool return_without_parse = false; + bool result = true; + + auto current_state = session_.getState(); // ignore ssl message - if (session_.getState() == MySQLSession::State::SslPt) { + if (current_state == MySQLSession::State::SslPt) { data.drain(data.length()); - return true; + return result; } if (BufferHelper::peekHdr(data, len, seq) != DecodeStatus::Success) { throw EnvoyException("error parsing mysql packet header"); } ENVOY_LOG(trace, "mysql_proxy: seq {}, len {}", seq, len); + // If message is split over multiple packets, hold off until the entire message is available. // Consider the size of the header here as it's not consumed yet. if (sizeof(uint32_t) + len > data.length()) { - return false; + return_without_parse = true; + result = false; } - BufferHelper::consumeHdr(data); // Consume the header once the message is fully available. - callbacks_.onNewMessage(session_.getState()); - // Ignore duplicate and out-of-sync packets. - if (seq != session_.getExpectedSeq()) { + if (seq != session_.getExpectedSeq(is_upstream)) { // case when server response is over, and client send req if (session_.getState() == MySQLSession::State::ReqResp && seq == MYSQL_REQUEST_PKT_NUM) { - session_.setExpectedSeq(MYSQL_REQUEST_PKT_NUM); + session_.resetSeq(); session_.setState(MySQLSession::State::Req); + current_state = session_.getState(); } else { ENVOY_LOG(info, "mysql_proxy: ignoring out-of-sync packet"); callbacks_.onProtocolError(); - data.drain(len); // Ensure that the whole message was consumed - return true; + data.drain(sizeof(uint32_t) + len); // Ensure that the whole message was consumed + return_without_parse = true; } } - session_.setExpectedSeq(seq + 1); + + payload_metadata_list_.push_back( + {.seq = session_.convertToSeqOnReciever(seq, is_upstream), .len = len}); + + if (return_without_parse) { + return result; + } + + BufferHelper::consumeHdr(data); // Consume the header once the message is fully available. + session_.incSeq(); const ssize_t data_len = data.length(); - parseMessage(data, seq, len); + parseMessage(data, seq, len, is_upstream); const ssize_t consumed_len = data_len - data.length(); data.drain(len - consumed_len); // Ensure that the whole message was consumed ENVOY_LOG(trace, "mysql_proxy: {} bytes remaining in buffer", data.length()); - return true; + return result; } -void DecoderImpl::onData(Buffer::Instance& data) { +Decoder::Result DecoderImpl::onData(Buffer::Instance& data, bool is_upstream) { + payload_metadata_list_.clear(); + // TODO(venilnoronha): handle messages over 16 mb. See // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_packets.html#sect_protocol_basic_packets_sending_mt_16mb. - while (!BufferHelper::endOfBuffer(data) && decode(data)) { + while (!BufferHelper::endOfBuffer(data) && decode(data, is_upstream)) { } + + if (is_upstream && session_.getState() == MySQLSession::State::SslPt) { + if (!callbacks_.onSSLRequest()) { + session_.adjustSeqOffset(1); + session_.setState(MySQLSession::State::ChallengeReq); + return Decoder::Result::Stopped; + } + } + + return Decoder::Result::ReadyForNext; } DecoderFactoryImpl DecoderFactoryImpl::instance_; diff --git a/contrib/mysql_proxy/filters/network/source/mysql_decoder_impl.h b/contrib/mysql_proxy/filters/network/source/mysql_decoder_impl.h index 6c126edaeba1f..85c8678e64e71 100644 --- a/contrib/mysql_proxy/filters/network/source/mysql_decoder_impl.h +++ b/contrib/mysql_proxy/filters/network/source/mysql_decoder_impl.h @@ -11,12 +11,12 @@ class DecoderImpl : public Decoder, public Logger::Loggable DecoderImpl(DecoderCallbacks& callbacks) : callbacks_(callbacks) {} // MySQLProxy::Decoder - void onData(Buffer::Instance& data) override; + Decoder::Result onData(Buffer::Instance& data, bool is_upstream) override; MySQLSession& getSession() override { return session_; } private: - bool decode(Buffer::Instance& data); - void parseMessage(Buffer::Instance& message, uint8_t seq, uint32_t len); + bool decode(Buffer::Instance& data, bool is_upstream); + void parseMessage(Buffer::Instance& message, uint8_t seq, uint32_t len, bool is_upstream); DecoderCallbacks& callbacks_; MySQLSession session_; diff --git a/contrib/mysql_proxy/filters/network/source/mysql_filter.cc b/contrib/mysql_proxy/filters/network/source/mysql_filter.cc index 1574bd2a23e80..3ef0368c5884d 100644 --- a/contrib/mysql_proxy/filters/network/source/mysql_filter.cc +++ b/contrib/mysql_proxy/filters/network/source/mysql_filter.cc @@ -1,23 +1,30 @@ #include "contrib/mysql_proxy/filters/network/source/mysql_filter.h" +#include + #include "envoy/config/core/v3/base.pb.h" #include "source/common/buffer/buffer_impl.h" #include "source/common/common/assert.h" #include "source/common/common/logger.h" +#include "source/common/crypto/utility.h" #include "source/extensions/filters/network/well_known_names.h" #include "contrib/mysql_proxy/filters/network/source/mysql_codec.h" #include "contrib/mysql_proxy/filters/network/source/mysql_codec_clogin_resp.h" #include "contrib/mysql_proxy/filters/network/source/mysql_decoder_impl.h" +#include "contrib/mysql_proxy/filters/network/source/mysql_utils.h" +#include "openssl/evp.h" +#include "openssl/rsa.h" namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace MySQLProxy { -MySQLFilterConfig::MySQLFilterConfig(const std::string& stat_prefix, Stats::Scope& scope) - : scope_(scope), stats_(generateStats(stat_prefix, scope)) {} +MySQLFilterConfig::MySQLFilterConfig(const std::string& stat_prefix, Stats::Scope& scope, + SSLMode downstream_ssl) + : scope_(scope), stats_(generateStats(stat_prefix, scope)), downstream_ssl_(downstream_ssl) {} MySQLFilter::MySQLFilter(MySQLFilterConfigSharedPtr config) : config_(std::move(config)) {} @@ -25,27 +32,145 @@ void MySQLFilter::initializeReadFilterCallbacks(Network::ReadFilterCallbacks& ca read_callbacks_ = &callbacks; } +void MySQLFilter::initializeWriteFilterCallbacks(Network::WriteFilterCallbacks& callbacks) { + write_callbacks_ = &callbacks; +} + Network::FilterStatus MySQLFilter::onData(Buffer::Instance& data, bool) { + Network::FilterStatus status = Network::FilterStatus::Continue; + uint64_t remaining = read_buffer_.length(); + // Safety measure just to make sure that if we have a decoding error we keep going and lose stats. // This can be removed once we are more confident of this code. - if (sniffing_) { - read_buffer_.add(data); - doDecode(read_buffer_); + if (!sniffing_) { + return status; } - return Network::FilterStatus::Continue; + + read_buffer_.add(data); + status = doDecode(read_buffer_, true); + + if (status == Network::FilterStatus::StopIteration) { + data.drain(data.length()); + return status; + } + + // RSA mediation: intercept client cleartext password. + if (rsa_auth_state_ == RsaAuthState::WaitingClientPassword && cleartext_password_) { + data.drain(data.length()); + + uint8_t inject_seq = getSession().getExpectedSeq(false) - 1; + ENVOY_CONN_LOG(trace, + "mysql_proxy: intercepted client password, sending request-public-key (seq={})", + read_callbacks_->connection(), inject_seq); + + Buffer::OwnedImpl buf; + BufferHelper::addUint24(buf, 1); + BufferHelper::addUint8(buf, inject_seq); + BufferHelper::addUint8(buf, MYSQL_REQUEST_PUBLIC_KEY); + read_callbacks_->injectReadDataToFilterChain(buf, false); + + rsa_auth_state_ = RsaAuthState::WaitingServerKey; + return Network::FilterStatus::StopIteration; + } + + if (config_->terminateSsl()) { + doRewrite(data, remaining, true); + } + + return status; } Network::FilterStatus MySQLFilter::onWrite(Buffer::Instance& data, bool) { + Network::FilterStatus status = Network::FilterStatus::Continue; + // Safety measure just to make sure that if we have a decoding error we keep going and lose stats. // This can be removed once we are more confident of this code. - if (sniffing_) { + if (!sniffing_) { + return status; + } + + // RSA mediation: intercept server's public key response. + if (rsa_auth_state_ == RsaAuthState::WaitingServerKey) { write_buffer_.add(data); - doDecode(write_buffer_); + data.drain(data.length()); + + // Check if we have a complete packet. + uint32_t len = 0; + uint8_t seq = 0; + if (BufferHelper::peekHdr(write_buffer_, len, seq) != DecodeStatus::Success || + sizeof(uint32_t) + len > write_buffer_.length()) { + ENVOY_CONN_LOG(trace, + "mysql_proxy: waiting for complete public key packet ({} bytes buffered)", + read_callbacks_->connection(), write_buffer_.length()); + return Network::FilterStatus::StopIteration; + } + + ENVOY_CONN_LOG(trace, "mysql_proxy: received server public key packet (seq={}, len={})", + read_callbacks_->connection(), seq, len); + + // Full packet available. Parse it: [hdr][0x01 marker][PEM key bytes]. + BufferHelper::consumeHdr(write_buffer_); + uint8_t marker; + BufferHelper::readUint8(write_buffer_, marker); + if (marker != MYSQL_RESP_MORE) { + ENVOY_CONN_LOG(error, "mysql_proxy: unexpected marker 0x{:02x} in public key response", + read_callbacks_->connection(), marker); + read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); + return Network::FilterStatus::StopIteration; + } + + std::string pem_key; + BufferHelper::readStringBySize(write_buffer_, len - 1, pem_key); + write_buffer_.drain(write_buffer_.length()); + + ENVOY_CONN_LOG(trace, "mysql_proxy: extracted PEM key ({} bytes), encrypting password", + read_callbacks_->connection(), pem_key.size()); + + sendEncryptedPassword(pem_key, seq); + getSession().adjustSeqOffset(-2); + rsa_auth_state_ = RsaAuthState::WaitingServerResult; + + ENVOY_CONN_LOG(trace, "mysql_proxy: RSA encrypted password sent, waiting for server result", + read_callbacks_->connection()); + + // Nothing to forward to client; data was already drained. + return Network::FilterStatus::Continue; } - return Network::FilterStatus::Continue; + + uint64_t remaining = write_buffer_.length(); + + write_buffer_.add(data); + status = doDecode(write_buffer_, false); + + if (status == Network::FilterStatus::StopIteration) { + data.drain(data.length()); + return status; + } + + if (config_->terminateSsl()) { + doRewrite(data, remaining, false); + } + + return status; } -void MySQLFilter::doDecode(Buffer::Instance& buffer) { +bool MySQLFilter::onSSLRequest() { + if (!config_->terminateSsl()) { + return true; + } + + if (!read_callbacks_->connection().startSecureTransport()) { + ENVOY_CONN_LOG(info, "mysql_proxy: cannot enable secure transport. Check configuration.", + read_callbacks_->connection()); + read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); + } else { + ENVOY_CONN_LOG(trace, "mysql_proxy: enabled SSL termination.", read_callbacks_->connection()); + } + + return false; +} + +Network::FilterStatus MySQLFilter::doDecode(Buffer::Instance& buffer, bool is_upstream) { // Clear dynamic metadata. envoy::config::core::v3::Metadata& dynamic_metadata = read_callbacks_->connection().streamInfo().dynamicMetadata(); @@ -58,7 +183,12 @@ void MySQLFilter::doDecode(Buffer::Instance& buffer) { } try { - decoder_->onData(buffer); + switch (decoder_->onData(buffer, is_upstream)) { + case Decoder::Result::ReadyForNext: + return Network::FilterStatus::Continue; + case Decoder::Result::Stopped: + return Network::FilterStatus::StopIteration; + } } catch (EnvoyException& e) { ENVOY_LOG(info, "mysql_proxy: decoding error: {}", e.what()); config_->stats_.decoder_errors_.inc(); @@ -66,38 +196,231 @@ void MySQLFilter::doDecode(Buffer::Instance& buffer) { read_buffer_.drain(read_buffer_.length()); write_buffer_.drain(write_buffer_.length()); } + + return Network::FilterStatus::Continue; } DecoderPtr MySQLFilter::createDecoder(DecoderCallbacks& callbacks) { return std::make_unique(callbacks); } +void MySQLFilter::rewritePacketHeader(Buffer::Instance& data, uint8_t seq, uint32_t len) { + BufferHelper::consumeHdr(data); + BufferHelper::addUint24(data, len); + BufferHelper::addUint8(data, seq); +} + +void MySQLFilter::stripSslCapability(Buffer::Instance& data) { + uint32_t client_cap = 0; + BufferHelper::readUint32(data, client_cap); + BufferHelper::addUint32(data, client_cap & ~CLIENT_SSL); +} + +void MySQLFilter::doRewrite(Buffer::Instance& data, uint64_t remaining, bool is_upstream) { + MySQLSession::State state = getSession().getState(); + auto& payload_metadata_list = decoder_->getPayloadMetadataList(); + const uint64_t original_data_size = data.length(); + uint64_t max_data_size = original_data_size; + + for (size_t i = 0; i < payload_metadata_list.size(); ++i) { + uint8_t seq = payload_metadata_list[i].seq; + uint32_t len = payload_metadata_list[i].len; + + if (i == 0 && remaining > 0) { + // First packet spans old internal buffer and new data. The header and first + // (remaining - 4) payload bytes are in the internal buffer, not in data. + ASSERT(remaining >= 4, "partial header should not appear in payload metadata"); + len -= remaining - 4; + } else { + rewritePacketHeader(data, seq, len); + max_data_size -= 4; + + if (is_upstream && (state == MySQLSession::State::ChallengeResp41 || + state == MySQLSession::State::ChallengeResp320)) { + stripSslCapability(data); + len -= 4; + } + } + + uint64_t copy_size = std::min(static_cast(len), max_data_size); + std::string payload; + payload.reserve(copy_size); + BufferHelper::readStringBySize(data, copy_size, payload); + BufferHelper::addBytes(data, payload.c_str(), payload.size()); + max_data_size -= copy_size; + } + + ASSERT(data.length() == original_data_size, "doRewrite must not change overall buffer size"); +} + void MySQLFilter::onProtocolError() { config_->stats_.protocol_errors_.inc(); } -void MySQLFilter::onNewMessage(MySQLSession::State state) { - if (state == MySQLSession::State::ChallengeReq) { - config_->stats_.login_attempts_.inc(); +void MySQLFilter::onServerGreeting(ServerGreeting& greeting) { + ENVOY_CONN_LOG(trace, "mysql_proxy: server greeting: version={}, auth_plugin={}, scramble_len={}", + read_callbacks_->connection(), greeting.getVersion(), greeting.getAuthPluginName(), + greeting.getAuthPluginData().size()); + if (config_->terminateSsl()) { + server_scramble_ = greeting.getAuthPluginData(); + // The MySQL greeting protocol may include a trailing null filler byte in + // auth_plugin_data, making it 21 bytes. The actual nonce used by + // caching_sha2_password is always 20 bytes. Truncate to avoid corrupting + // the XOR for passwords longer than 20 characters. + if (server_scramble_.size() > NATIVE_PSSWORD_HASH_LENGTH) { + server_scramble_.resize(NATIVE_PSSWORD_HASH_LENGTH); + } + auth_plugin_name_ = greeting.getAuthPluginName(); + ENVOY_CONN_LOG(trace, + "mysql_proxy: captured scramble ({} bytes) for SSL termination, plugin={}", + read_callbacks_->connection(), server_scramble_.size(), auth_plugin_name_); } } -void MySQLFilter::onClientLogin(ClientLogin& client_login) { +void MySQLFilter::onClientLogin(ClientLogin& client_login, MySQLSession::State state) { + ENVOY_CONN_LOG(trace, "mysql_proxy: client login: ssl_request={}, state={}, user={}", + read_callbacks_->connection(), client_login.isSSLRequest(), + static_cast(state), client_login.getUsername()); if (client_login.isSSLRequest()) { config_->stats_.upgraded_to_ssl_.inc(); } + + if (state == MySQLSession::State::ChallengeResp41 || + state == MySQLSession::State::ChallengeResp320) { + config_->stats_.login_attempts_.inc(); + + // REQUIRE mode: reject clients that did not initiate SSL. + using MySQLProto = envoy::extensions::filters::network::mysql_proxy::v3::MySQLProxy; + if (config_->downstream_ssl_ == MySQLProto::REQUIRE && getSession().getSeqOffset() == 0) { + ENVOY_CONN_LOG(info, + "mysql_proxy: downstream_ssl=REQUIRE but client did not initiate SSL, closing", + read_callbacks_->connection()); + read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); + } + } } void MySQLFilter::onClientLoginResponse(ClientLoginResponse& client_login_resp) { + ENVOY_CONN_LOG(trace, "mysql_proxy: server login response: resp_code=0x{:02x}", + read_callbacks_->connection(), client_login_resp.getRespCode()); if (client_login_resp.getRespCode() == MYSQL_RESP_AUTH_SWITCH) { config_->stats_.auth_switch_request_.inc(); } else if (client_login_resp.getRespCode() == MYSQL_RESP_ERR) { config_->stats_.login_failures_.inc(); + } else if (config_->terminateSsl() && getSession().getSeqOffset() != 0 && + client_login_resp.getRespCode() == MYSQL_RESP_MORE) { + auto* auth_more = dynamic_cast(&client_login_resp); + if (auth_more && !auth_more->getAuthMoreData().empty()) { + ENVOY_CONN_LOG(trace, "mysql_proxy: AuthMoreData[0]=0x{:02x}, plugin={}", + read_callbacks_->connection(), auth_more->getAuthMoreData()[0], + auth_plugin_name_); + if (auth_more->getAuthMoreData()[0] == MYSQL_CACHINGSHA2_FULL_AUTH_REQUIRED && + auth_plugin_name_ == "caching_sha2_password") { + rsa_auth_state_ = RsaAuthState::WaitingClientPassword; + ENVOY_CONN_LOG(trace, "mysql_proxy: full auth required, entering RSA mediation", + read_callbacks_->connection()); + } else if (auth_more->getAuthMoreData()[0] == MYSQL_CACHINGSHA2_FAST_AUTH_SUCCESS) { + ENVOY_CONN_LOG(trace, "mysql_proxy: fast auth success (cache hit), no RSA needed", + read_callbacks_->connection()); + } + } } } void MySQLFilter::onMoreClientLoginResponse(ClientLoginResponse& client_login_resp) { + ENVOY_CONN_LOG(trace, "mysql_proxy: more login response: resp_code=0x{:02x}, rsa_state={}", + read_callbacks_->connection(), client_login_resp.getRespCode(), + static_cast(rsa_auth_state_)); if (client_login_resp.getRespCode() == MYSQL_RESP_ERR) { config_->stats_.login_failures_.inc(); } + if (rsa_auth_state_ == RsaAuthState::WaitingServerResult) { + ENVOY_CONN_LOG(trace, "mysql_proxy: RSA mediation complete, result=0x{:02x}", + read_callbacks_->connection(), client_login_resp.getRespCode()); + rsa_auth_state_ = RsaAuthState::Inactive; + } +} + +void MySQLFilter::onAuthSwitchMoreClientData(std::unique_ptr data) { + ENVOY_CONN_LOG(trace, "mysql_proxy: client auth data received, len={}, rsa_state={}", + read_callbacks_->connection(), data ? data->size() : 0, + static_cast(rsa_auth_state_)); + if (rsa_auth_state_ == RsaAuthState::WaitingClientPassword && data) { + // Password arrives in SecureBytes (guarded memory, zeroed on free). + // The decoder already read it via readSecureBytes which zeroed the source buffer. + cleartext_password_ = std::move(data); + ENVOY_CONN_LOG(trace, "mysql_proxy: captured cleartext password ({} bytes) in secure memory", + read_callbacks_->connection(), cleartext_password_->size()); + } +} + +void MySQLFilter::sendEncryptedPassword(const std::string& pem_key, uint8_t last_server_seq) { + if (!cleartext_password_) { + ENVOY_CONN_LOG(error, "mysql_proxy: no cleartext password captured", + read_callbacks_->connection()); + read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); + return; + } + + // XOR password with scramble in secure memory. The cleartext_password_ includes the + // trailing null sent by the client (password\0 format for caching_sha2_password). + SecureBytes xored(cleartext_password_->size()); + for (size_t i = 0; i < cleartext_password_->size(); i++) { + xored[i] = (*cleartext_password_)[i] ^ server_scramble_[i % server_scramble_.size()]; + } + + // Import the server's public key. + auto pkey = Envoy::Common::Crypto::UtilitySingleton::get().importPublicKeyPEM(pem_key); + if (!pkey || !pkey->getEVP_PKEY()) { + ENVOY_CONN_LOG(error, "mysql_proxy: failed to import server public key", + read_callbacks_->connection()); + read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); + return; + } + + // RSA-encrypt with OAEP/SHA-1 padding (MySQL requirement). + bssl::UniquePtr ctx(EVP_PKEY_CTX_new(pkey->getEVP_PKEY(), nullptr)); + if (!ctx || EVP_PKEY_encrypt_init(ctx.get()) <= 0 || + EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING) <= 0 || + EVP_PKEY_CTX_set_rsa_oaep_md(ctx.get(), EVP_sha1()) <= 0) { + ENVOY_CONN_LOG(error, "mysql_proxy: failed to initialize RSA encryption context", + read_callbacks_->connection()); + read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); + return; + } + + size_t out_len = 0; + if (EVP_PKEY_encrypt(ctx.get(), nullptr, &out_len, xored.data(), xored.size()) <= 0) { + ENVOY_CONN_LOG(error, "mysql_proxy: failed to determine RSA ciphertext length", + read_callbacks_->connection()); + read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); + cleartext_password_.reset(); + return; + } + + std::vector encrypted(out_len); + if (EVP_PKEY_encrypt(ctx.get(), encrypted.data(), &out_len, xored.data(), xored.size()) <= 0) { + ENVOY_CONN_LOG(error, "mysql_proxy: RSA encryption failed", read_callbacks_->connection()); + read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); + cleartext_password_.reset(); + return; + } + encrypted.resize(out_len); + + // Build the encrypted password packet: [3-byte len][seq][encrypted_data]. + uint8_t enc_seq = static_cast(last_server_seq + 1); + Buffer::OwnedImpl buf; + BufferHelper::addUint24(buf, encrypted.size()); + BufferHelper::addUint8(buf, enc_seq); + BufferHelper::addVector(buf, encrypted); + + ENVOY_CONN_LOG(trace, + "mysql_proxy: injecting RSA-encrypted password (seq={}, {} bytes ciphertext)", + read_callbacks_->connection(), enc_seq, encrypted.size()); + + // Inject toward server (bypasses our filter). + read_callbacks_->injectReadDataToFilterChain(buf, false); + + // Securely destroy the cleartext password (sodium_free zeroes before freeing). + cleartext_password_.reset(); } void MySQLFilter::onCommand(Command& command) { diff --git a/contrib/mysql_proxy/filters/network/source/mysql_filter.h b/contrib/mysql_proxy/filters/network/source/mysql_filter.h index bf41c53762269..33928f9181ee9 100644 --- a/contrib/mysql_proxy/filters/network/source/mysql_filter.h +++ b/contrib/mysql_proxy/filters/network/source/mysql_filter.h @@ -9,6 +9,7 @@ #include "source/common/common/logger.h" +#include "contrib/envoy/extensions/filters/network/mysql_proxy/v3/mysql_proxy.pb.h" #include "contrib/mysql_proxy/filters/network/source/mysql_codec.h" #include "contrib/mysql_proxy/filters/network/source/mysql_codec_clogin.h" #include "contrib/mysql_proxy/filters/network/source/mysql_codec_clogin_resp.h" @@ -17,6 +18,7 @@ #include "contrib/mysql_proxy/filters/network/source/mysql_codec_switch_resp.h" #include "contrib/mysql_proxy/filters/network/source/mysql_decoder.h" #include "contrib/mysql_proxy/filters/network/source/mysql_session.h" +#include "contrib/mysql_proxy/filters/network/source/mysql_utils.h" namespace Envoy { namespace Extensions { @@ -49,12 +51,19 @@ struct MySQLProxyStats { */ class MySQLFilterConfig { public: - MySQLFilterConfig(const std::string& stat_prefix, Stats::Scope& scope); + using SSLMode = envoy::extensions::filters::network::mysql_proxy::v3::MySQLProxy::SSLMode; + + MySQLFilterConfig(const std::string& stat_prefix, Stats::Scope& scope, SSLMode downstream_ssl); const MySQLProxyStats& stats() { return stats_; } + bool terminateSsl() const { + return downstream_ssl_ != + envoy::extensions::filters::network::mysql_proxy::v3::MySQLProxy::DISABLE; + } Stats::Scope& scope_; MySQLProxyStats stats_; + SSLMode downstream_ssl_; private: MySQLProxyStats generateStats(const std::string& prefix, Stats::Scope& scope) { @@ -64,6 +73,13 @@ class MySQLFilterConfig { using MySQLFilterConfigSharedPtr = std::shared_ptr; +enum class RsaAuthState { + Inactive, // Normal operation + WaitingClientPassword, // Server sent AuthMoreData(0x04), forwarded to client, waiting for pw + WaitingServerKey, // Sent 0x02 to server, waiting for PEM public key + WaitingServerResult, // Sent RSA-encrypted pw, waiting for OK/ERR +}; + /** * Implementation of MySQL proxy filter. */ @@ -79,29 +95,47 @@ class MySQLFilter : public Network::Filter, DecoderCallbacks, Logger::Loggable data) override; + bool onSSLRequest() override; - void doDecode(Buffer::Instance& buffer); + Network::FilterStatus doDecode(Buffer::Instance& buffer, bool is_upstream); DecoderPtr createDecoder(DecoderCallbacks& callbacks); + void doRewrite(Buffer::Instance& buffer, uint64_t remaining, bool is_upstream); MySQLSession& getSession() { return decoder_->getSession(); } + // Helpers for doRewrite. + static void rewritePacketHeader(Buffer::Instance& data, uint8_t seq, uint32_t len); + static void stripSslCapability(Buffer::Instance& data); + + RsaAuthState getRsaAuthState() const { return rsa_auth_state_; } + private: + void sendEncryptedPassword(const std::string& pem_key, uint8_t last_server_seq); + Network::ReadFilterCallbacks* read_callbacks_{}; + Network::WriteFilterCallbacks* write_callbacks_{}; MySQLFilterConfigSharedPtr config_; Buffer::OwnedImpl read_buffer_; Buffer::OwnedImpl write_buffer_; std::unique_ptr decoder_; bool sniffing_{true}; + + // RSA mediation state for caching_sha2_password full authentication. + RsaAuthState rsa_auth_state_{RsaAuthState::Inactive}; + std::unique_ptr cleartext_password_; + std::vector server_scramble_; + std::string auth_plugin_name_; }; } // namespace MySQLProxy diff --git a/contrib/mysql_proxy/filters/network/source/mysql_session.h b/contrib/mysql_proxy/filters/network/source/mysql_session.h index 691d582633b11..c659c50b4554a 100644 --- a/contrib/mysql_proxy/filters/network/source/mysql_session.h +++ b/contrib/mysql_proxy/filters/network/source/mysql_session.h @@ -27,12 +27,23 @@ class MySQLSession : Logger::Loggable { void setState(MySQLSession::State state) { state_ = state; } MySQLSession::State getState() { return state_; } - uint8_t getExpectedSeq() { return expected_seq_; } - void setExpectedSeq(uint8_t seq) { expected_seq_ = seq; } + uint8_t getExpectedSeq(bool is_upstream) { return seq_ - (is_upstream ? 0 : seq_offset_); } + uint8_t convertToSeqOnReciever(uint8_t seq, bool is_upstream) { + return seq - (is_upstream ? 1 : -1) * seq_offset_; + } + void resetSeq() { + seq_ = MYSQL_REQUEST_PKT_NUM; + seq_offset_ = 0; + } + void incSeq() { seq_++; } + int8_t getSeqOffset() const { return seq_offset_; } + void setSeqOffset(int8_t offset) { seq_offset_ = offset; } + void adjustSeqOffset(int8_t delta) { seq_offset_ += delta; } private: MySQLSession::State state_{State::Init}; - uint8_t expected_seq_{0}; + uint8_t seq_{0}; + int8_t seq_offset_{0}; }; } // namespace MySQLProxy diff --git a/contrib/mysql_proxy/filters/network/source/mysql_utils.cc b/contrib/mysql_proxy/filters/network/source/mysql_utils.cc index c51fffd2076ce..98fe09c66d46a 100644 --- a/contrib/mysql_proxy/filters/network/source/mysql_utils.cc +++ b/contrib/mysql_proxy/filters/network/source/mysql_utils.cc @@ -216,6 +216,36 @@ DecodeStatus BufferHelper::peekHdr(Buffer::Instance& buffer, uint32_t& len, uint return DecodeStatus::Success; } +DecodeStatus BufferHelper::readSecureBytes(Buffer::Instance& buffer, size_t len, + std::unique_ptr& out) { + if (buffer.length() < len) { + return DecodeStatus::Failure; + } + + out = std::make_unique(len); + + // Copy data into secure memory, then zero the source buffer slices. + uint64_t copied = 0; + auto slices = buffer.getRawSlices(); + + for (const auto& slice : slices) { + if (copied == len) { + break; + } + + const uint64_t chunk = std::min(static_cast(slice.len_), len - copied); + auto* src = static_cast(slice.mem_); + + std::memcpy(out->data() + copied, src, chunk); + sodium_memzero(src, chunk); + + copied += chunk; + } + + buffer.drain(len); + return DecodeStatus::Success; +} + } // namespace MySQLProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/contrib/mysql_proxy/filters/network/source/mysql_utils.h b/contrib/mysql_proxy/filters/network/source/mysql_utils.h index 254ce0f8edc81..2e8d7c651f934 100644 --- a/contrib/mysql_proxy/filters/network/source/mysql_utils.h +++ b/contrib/mysql_proxy/filters/network/source/mysql_utils.h @@ -8,12 +8,51 @@ #include "source/common/common/logger.h" #include "contrib/mysql_proxy/filters/network/source/mysql_codec.h" +#include "sodium.h" namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace MySQLProxy { +// Secure memory buffer backed by libsodium's guarded allocation. +// Memory is allocated with guard pages (mprotect) and is automatically +// zeroed on destruction, preventing password leakage in memory. +class SecureBytes { +public: + explicit SecureBytes(size_t len) : len_(len) { + data_ = static_cast(sodium_malloc(len_)); + if (data_ == nullptr) { + throw std::bad_alloc(); + } + } + + ~SecureBytes() { + if (data_ != nullptr) { + sodium_free(data_); // zeroes memory before freeing + } + } + + SecureBytes(const SecureBytes&) = delete; + SecureBytes& operator=(const SecureBytes&) = delete; + + SecureBytes(SecureBytes&& other) noexcept : data_(other.data_), len_(other.len_) { + other.data_ = nullptr; + other.len_ = 0; + } + + uint8_t* data() { return data_; } + const uint8_t* data() const { return data_; } + size_t size() const { return len_; } + + uint8_t operator[](size_t i) const { return data_[i]; } + uint8_t& operator[](size_t i) { return data_[i]; } + +private: + uint8_t* data_{nullptr}; + size_t len_{0}; +}; + /** * IO helpers for reading/writing MySQL data from/to a buffer. * MySQL uses unsigned integer values in Little Endian format only. @@ -50,6 +89,11 @@ class BufferHelper : public Logger::Loggable { static DecodeStatus peekUint8(Buffer::Instance& buffer, uint8_t& val); static void consumeHdr(Buffer::Instance& buffer); static DecodeStatus peekHdr(Buffer::Instance& buffer, uint32_t& len, uint8_t& seq); + + // Read `len` bytes from buffer into a SecureBytes object backed by guarded memory, + // then zero the original data in the buffer to prevent password leakage. + static DecodeStatus readSecureBytes(Buffer::Instance& buffer, size_t len, + std::unique_ptr& out); }; } // namespace MySQLProxy diff --git a/contrib/mysql_proxy/filters/network/test/BUILD b/contrib/mysql_proxy/filters/network/test/BUILD index a29546b61dbd9..ae52f4770d1eb 100644 --- a/contrib/mysql_proxy/filters/network/test/BUILD +++ b/contrib/mysql_proxy/filters/network/test/BUILD @@ -78,9 +78,11 @@ envoy_cc_test( srcs = [ "mysql_filter_test.cc", ], + external_deps = ["ssl"], deps = [ ":mysql_test_utils_lib", "//contrib/mysql_proxy/filters/network/source:config", + "//source/common/crypto:utility_lib", "//test/mocks/network:network_mocks", ], ) @@ -104,6 +106,31 @@ envoy_cc_test( ], ) +envoy_cc_test( + name = "mysql_ssl_integration_test", + srcs = [ + "mysql_ssl_integration_test.cc", + ], + data = [ + "mysql_ssl_disable_test_config.yaml", + "mysql_ssl_allow_test_config.yaml", + "mysql_ssl_require_test_config.yaml", + "//test/config/integration/certs", + ], + external_deps = ["ssl"], + rbe_pool = "6gig", + deps = [ + ":mysql_test_utils_lib", + "//contrib/mysql_proxy/filters/network/source:config", + "//contrib/mysql_proxy/filters/network/source:filter_lib", + "//source/common/tcp_proxy", + "//source/extensions/filters/network/tcp_proxy:config", + "//source/extensions/transport_sockets/raw_buffer:config", + "//source/extensions/transport_sockets/starttls:config", + "//test/integration:integration_lib", + ], +) + envoy_cc_test( name = "mysql_command_tests", srcs = [ diff --git a/contrib/mysql_proxy/filters/network/test/mysql_filter_test.cc b/contrib/mysql_proxy/filters/network/test/mysql_filter_test.cc index 1b816e86bd7b9..2fa75c5771450 100644 --- a/contrib/mysql_proxy/filters/network/test/mysql_filter_test.cc +++ b/contrib/mysql_proxy/filters/network/test/mysql_filter_test.cc @@ -1,15 +1,23 @@ #include "source/common/buffer/buffer_impl.h" +#include "source/common/crypto/utility.h" #include "test/mocks/network/mocks.h" #include "contrib/mysql_proxy/filters/network/source/mysql_codec.h" +#include "contrib/mysql_proxy/filters/network/source/mysql_codec_clogin_resp.h" +#include "contrib/mysql_proxy/filters/network/source/mysql_codec_greeting.h" #include "contrib/mysql_proxy/filters/network/source/mysql_filter.h" #include "contrib/mysql_proxy/filters/network/source/mysql_utils.h" #include "gmock/gmock.h" #include "gtest/gtest.h" #include "mysql_test_utils.h" +#include "openssl/evp.h" +#include "openssl/pem.h" +#include "openssl/rsa.h" +using testing::_; using testing::NiceMock; +using testing::ReturnRef; namespace Envoy { namespace Extensions { @@ -22,10 +30,48 @@ class MySQLFilterTest : public testing::Test, public MySQLTestUtils { public: MySQLFilterTest() { ENVOY_LOG_MISC(info, "test"); } - void initialize() { - config_ = std::make_shared(stat_prefix_, scope_); + using MySQLProxyProto = envoy::extensions::filters::network::mysql_proxy::v3::MySQLProxy; + + void initialize(MySQLProxyProto::SSLMode downstream_ssl = MySQLProxyProto::DISABLE) { + config_ = std::make_shared(stat_prefix_, scope_, downstream_ssl); filter_ = std::make_unique(config_); filter_->initializeReadFilterCallbacks(filter_callbacks_); + filter_->initializeWriteFilterCallbacks(write_filter_callbacks_); + } + + // Encode a server greeting for caching_sha2_password with 20-byte scramble. + std::string encodeServerGreetingCachingSha2() { + ServerGreeting greeting; + greeting.setProtocol(MYSQL_PROTOCOL_10); + greeting.setVersion(getVersion()); + greeting.setThreadId(MYSQL_THREAD_ID); + greeting.setAuthPluginData(getAuthPluginData20()); + greeting.setServerCap(CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION); + greeting.setServerCharset(MYSQL_SERVER_LANGUAGE); + greeting.setServerStatus(MYSQL_SERVER_STATUS); + greeting.setAuthPluginName("caching_sha2_password"); + Buffer::OwnedImpl buffer; + greeting.encode(buffer); + BufferHelper::encodeHdr(buffer, GREETING_SEQ_NUM); + return buffer.toString(); + } + + // Encode an AuthMoreData packet with specific data bytes. + std::string encodeAuthMoreDataPacket(const std::vector& data, uint8_t seq) { + AuthMoreMessage auth_more; + auth_more.setAuthMoreData(data); + Buffer::OwnedImpl buffer; + auth_more.encode(buffer); + BufferHelper::encodeHdr(buffer, seq); + return buffer.toString(); + } + + // Encode a raw client-to-server packet (e.g., cleartext password). + std::string encodeRawPacket(const std::string& payload, uint8_t seq) { + Buffer::OwnedImpl buffer; + BufferHelper::addString(buffer, payload); + BufferHelper::encodeHdr(buffer, seq); + return buffer.toString(); } MySQLFilterConfigSharedPtr config_; @@ -34,6 +80,8 @@ class MySQLFilterTest : public testing::Test, public MySQLTestUtils { Stats::Scope& scope_{*store_.rootScope()}; std::string stat_prefix_{"test."}; NiceMock filter_callbacks_; + NiceMock write_filter_callbacks_; + NiceMock connection_; }; // Test New Session counter increment @@ -54,11 +102,11 @@ TEST_F(MySQLFilterTest, MySqlFallbackToTcpProxy) { EXPECT_EQ(1UL, config_->stats().sessions_.value()); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(" ")); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(1UL, config_->stats().decoder_errors_.value()); Buffer::InstancePtr more_data(new Buffer::OwnedImpl("scooby doo - part 2!")); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*more_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*more_data, false)); } /** @@ -74,7 +122,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake41OkTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM); @@ -85,7 +133,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake41OkTest) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); EXPECT_EQ(MySQLSession::State::Req, filter_->getSession().getState()); } @@ -210,7 +258,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake41ErrTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM); @@ -221,14 +269,14 @@ TEST_F(MySQLFilterTest, MySqlHandshake41ErrTest) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_ERR); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); EXPECT_EQ(1UL, config_->stats().login_failures_.value()); EXPECT_EQ(MySQLSession::State::Error, filter_->getSession().getState()); } /** * Test MySQL Handshake with protocol version 41 - * Server responds with Error + * Server responds with Auth More Data * SM: greeting(p=10) -> challenge-req(v41) -> serv-resp-more */ TEST_F(MySQLFilterTest, MySqlHandshake41AuthMoreTest) { @@ -240,7 +288,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake41AuthMoreTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM); @@ -251,8 +299,8 @@ TEST_F(MySQLFilterTest, MySqlHandshake41AuthMoreTest) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_MORE); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); - EXPECT_EQ(MySQLSession::State::NotHandled, filter_->getSession().getState()); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); + EXPECT_EQ(MySQLSession::State::AuthSwitchMore, filter_->getSession().getState()); } /** @@ -268,7 +316,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320OkTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(0, "user1", CHALLENGE_SEQ_NUM); @@ -279,7 +327,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320OkTest) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); EXPECT_EQ(MySQLSession::State::Req, filter_->getSession().getState()); } @@ -296,7 +344,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320OkTestIncomplete) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(0, "user1", CHALLENGE_SEQ_NUM); @@ -307,7 +355,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320OkTestIncomplete) { std::string srv_resp_data = encodeMessage(0); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); EXPECT_EQ(MySQLSession::State::NotHandled, filter_->getSession().getState()); } @@ -325,7 +373,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320ErrTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(0, "user1", CHALLENGE_SEQ_NUM); @@ -336,7 +384,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320ErrTest) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_ERR); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); EXPECT_EQ(1UL, config_->stats().login_failures_.value()); EXPECT_EQ(MySQLSession::State::Error, filter_->getSession().getState()); } @@ -346,7 +394,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320ErrTest) { * State-machine moves to SSL-Pass-Through * SM: greeting(p=10) -> challenge-req(v320) -> SSL_PT */ -TEST_F(MySQLFilterTest, MySqlHandshakeSSLTest) { +TEST_F(MySQLFilterTest, MySqlHandshakeSSLPassThroughTest) { initialize(); EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); @@ -355,23 +403,86 @@ TEST_F(MySQLFilterTest, MySqlHandshakeSSLTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); - std::string clogin_data = - encodeClientLogin(CLIENT_SSL | CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM); + // Send SSL Connection Request packet. + // https://dev.mysql.com/doc/internals/en/ssl-handshake.html + std::string clogin_data = encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM); Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); - EXPECT_EQ(1UL, config_->stats().login_attempts_.value()); + EXPECT_EQ(0UL, config_->stats().login_attempts_.value()); + EXPECT_EQ(1UL, config_->stats().upgraded_to_ssl_.value()); + EXPECT_EQ(MySQLSession::State::SslPt, filter_->getSession().getState()); + + // After SSL handshaking, attempt to login. + // Since the SSL-Pass-Through, # of login attempts is unknown. + clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM + 1); + client_login_data = Buffer::InstancePtr(new Buffer::OwnedImpl(clogin_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); + EXPECT_EQ(0UL, config_->stats().login_attempts_.value()); EXPECT_EQ(1UL, config_->stats().upgraded_to_ssl_.value()); EXPECT_EQ(MySQLSession::State::SslPt, filter_->getSession().getState()); + std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK, 0, CHALLENGE_RESP_SEQ_NUM + 1); + Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); + EXPECT_EQ(MySQLSession::State::SslPt, filter_->getSession().getState()); + Buffer::OwnedImpl query_create_index("!@#$encr$#@!"); - BufferHelper::encodeHdr(query_create_index, 2); + BufferHelper::encodeHdr(query_create_index, 0); EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(query_create_index, false)); EXPECT_EQ(MySQLSession::State::SslPt, filter_->getSession().getState()); } +/** + * Test MySQL Handshake with SSL Request + * State-machine moves to SSL-Terminate + * SM: greeting(p=10) -> challenge-req(v320) -> SSL_PT -> ChallengeReq -> Req -> ReqResp + */ +TEST_F(MySQLFilterTest, MySqlHandshakeSSLTerminateTest) { + initialize(MySQLProxyProto::REQUIRE); + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + EXPECT_EQ(1UL, config_->stats().sessions_.value()); + + std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); + + std::string clogin_data = encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); + + EXPECT_CALL(connection_, startSecureTransport()).WillOnce(testing::Return(true)); + EXPECT_CALL(connection_, close(_)).Times(0); + + EXPECT_EQ(Envoy::Network::FilterStatus::StopIteration, + filter_->onData(*client_login_data, false)); + EXPECT_EQ(0UL, config_->stats().login_attempts_.value()); + EXPECT_EQ(1UL, config_->stats().upgraded_to_ssl_.value()); + EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); + + clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM + 1); + client_login_data = Buffer::InstancePtr(new Buffer::OwnedImpl(clogin_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); + EXPECT_EQ(1UL, config_->stats().login_attempts_.value()); + EXPECT_EQ(MySQLSession::State::ChallengeResp41, filter_->getSession().getState()); + + std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK, 0, CHALLENGE_RESP_SEQ_NUM); + Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); + EXPECT_EQ(MySQLSession::State::Req, filter_->getSession().getState()); + + Buffer::OwnedImpl query_create_index("!@#$encr$#@!"); + BufferHelper::encodeHdr(query_create_index, 0); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(query_create_index, false)); + EXPECT_EQ(MySQLSession::State::ReqResp, filter_->getSession().getState()); +} + /** * Test MySQL Handshake with protocol version 320 * Server responds with Auth Switch @@ -387,7 +498,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(0, "user1", CHALLENGE_SEQ_NUM); @@ -398,7 +509,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchTest) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_AUTH_SWITCH); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); std::string auth_switch_resp = encodeAuthSwitchResp(); Buffer::InstancePtr client_switch_resp(new Buffer::OwnedImpl(auth_switch_resp)); @@ -407,7 +518,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchTest) { std::string srv_resp_ok_data = encodeClientLoginResp(MYSQL_RESP_OK, 1); Buffer::InstancePtr server_resp_ok_data(new Buffer::OwnedImpl(srv_resp_ok_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_ok_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_ok_data, false)); EXPECT_EQ(MySQLSession::State::Req, filter_->getSession().getState()); } @@ -426,7 +537,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchAuthSwitchTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(0, "user1", CHALLENGE_SEQ_NUM); @@ -437,7 +548,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchAuthSwitchTest) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_AUTH_SWITCH); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); std::string auth_switch_resp = encodeAuthSwitchResp(); Buffer::InstancePtr client_switch_resp(new Buffer::OwnedImpl(auth_switch_resp)); @@ -446,7 +557,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchAuthSwitchTest) { std::string srv_resp_ok_data = encodeClientLoginResp(MYSQL_RESP_AUTH_SWITCH, 1); Buffer::InstancePtr server_resp_ok_data(new Buffer::OwnedImpl(srv_resp_ok_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_ok_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_ok_data, false)); EXPECT_EQ(MySQLSession::State::NotHandled, filter_->getSession().getState()); } @@ -465,7 +576,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchErrTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(0, "user1", CHALLENGE_SEQ_NUM); @@ -476,7 +587,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchErrTest) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_AUTH_SWITCH); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); std::string auth_switch_resp = encodeAuthSwitchResp(); Buffer::InstancePtr client_switch_resp(new Buffer::OwnedImpl(auth_switch_resp)); @@ -485,7 +596,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchErrTest) { std::string srv_resp_ok_data = encodeClientLoginResp(MYSQL_RESP_ERR, 1); Buffer::InstancePtr server_resp_ok_data(new Buffer::OwnedImpl(srv_resp_ok_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_ok_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_ok_data, false)); EXPECT_EQ(MySQLSession::State::Resync, filter_->getSession().getState()); Command mysql_cmd_encode{}; @@ -516,7 +627,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchIncompleteRespcode) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(0, "user1", CHALLENGE_SEQ_NUM); @@ -527,7 +638,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchIncompleteRespcode) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_AUTH_SWITCH); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); std::string auth_switch_resp = encodeAuthSwitchResp(); Buffer::InstancePtr client_switch_resp(new Buffer::OwnedImpl(auth_switch_resp)); @@ -536,7 +647,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchIncompleteRespcode) { std::string srv_resp_ok_data = encodeMessage(0, 1); Buffer::InstancePtr server_resp_ok_data(new Buffer::OwnedImpl(srv_resp_ok_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_ok_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_ok_data, false)); EXPECT_EQ(MySQLSession::State::NotHandled, filter_->getSession().getState()); } @@ -555,7 +666,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchErrFailResync) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(0, "user1", CHALLENGE_SEQ_NUM); @@ -566,7 +677,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchErrFailResync) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_AUTH_SWITCH); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); std::string auth_switch_resp = encodeAuthSwitchResp(); Buffer::InstancePtr client_switch_resp(new Buffer::OwnedImpl(auth_switch_resp)); @@ -575,7 +686,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchErrFailResync) { std::string srv_resp_ok_data = encodeClientLoginResp(MYSQL_RESP_ERR, 1); Buffer::InstancePtr server_resp_ok_data(new Buffer::OwnedImpl(srv_resp_ok_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_ok_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_ok_data, false)); EXPECT_EQ(MySQLSession::State::Resync, filter_->getSession().getState()); Command mysql_cmd_encode{}; @@ -590,7 +701,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchErrFailResync) { } /** - * Negative Testing MySQL Handshake with protocol version 320 + * MySQL Handshake with protocol version 320 * Server responds with Auth Switch More * SM: greeting(p=10) -> challenge-req(v320) -> serv-resp-auth-switch -> * -> auth_switch_resp -> serv-resp-auth-switch-more @@ -604,7 +715,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchMoreandMore) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(0, "user1", CHALLENGE_SEQ_NUM); @@ -615,7 +726,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchMoreandMore) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_AUTH_SWITCH); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); std::string auth_switch_resp = encodeAuthSwitchResp(); Buffer::InstancePtr client_switch_resp(new Buffer::OwnedImpl(auth_switch_resp)); @@ -624,8 +735,8 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchMoreandMore) { std::string srv_resp_ok_data = encodeClientLoginResp(MYSQL_RESP_MORE, 1); Buffer::InstancePtr server_resp_ok_data(new Buffer::OwnedImpl(srv_resp_ok_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_ok_data, false)); - EXPECT_EQ(MySQLSession::State::AuthSwitchResp, filter_->getSession().getState()); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_ok_data, false)); + EXPECT_EQ(MySQLSession::State::AuthSwitchMore, filter_->getSession().getState()); } /** @@ -643,7 +754,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchMoreandUnhandled) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(0, "user1", CHALLENGE_SEQ_NUM); @@ -654,7 +765,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchMoreandUnhandled) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_AUTH_SWITCH); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); std::string auth_switch_resp = encodeAuthSwitchResp(); Buffer::InstancePtr client_switch_resp(new Buffer::OwnedImpl(auth_switch_resp)); @@ -663,7 +774,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchMoreandUnhandled) { std::string srv_resp_ok_data = encodeClientLoginResp(0x32, 1); Buffer::InstancePtr server_resp_ok_data(new Buffer::OwnedImpl(srv_resp_ok_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_ok_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_ok_data, false)); EXPECT_EQ(MySQLSession::State::NotHandled, filter_->getSession().getState()); } @@ -681,24 +792,25 @@ TEST_F(MySQLFilterTest, MySqlHandshake41Ok2GreetTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string greeting_data2 = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data2(new Buffer::OwnedImpl(greeting_data2)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data2, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data2, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); + EXPECT_EQ(0UL, config_->stats().login_attempts_.value()); EXPECT_EQ(1UL, config_->stats().protocol_errors_.value()); std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM); Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); - EXPECT_EQ(2UL, config_->stats().login_attempts_.value()); + EXPECT_EQ(1UL, config_->stats().login_attempts_.value()); EXPECT_EQ(MySQLSession::State::ChallengeResp41, filter_->getSession().getState()); std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); EXPECT_EQ(MySQLSession::State::Req, filter_->getSession().getState()); } @@ -717,7 +829,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake41Ok2CloginTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM); @@ -735,7 +847,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake41Ok2CloginTest) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); EXPECT_EQ(MySQLSession::State::Req, filter_->getSession().getState()); } @@ -760,7 +872,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake41OkOOOLoginTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); } @@ -786,12 +898,12 @@ TEST_F(MySQLFilterTest, MySqlHandshake41OkOOOFullLoginTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); EXPECT_EQ(2UL, config_->stats().protocol_errors_.value()); } @@ -811,12 +923,12 @@ TEST_F(MySQLFilterTest, MySqlHandshake41OkGreetingLoginOKTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); EXPECT_EQ(1UL, config_->stats_.protocol_errors_.value()); } @@ -835,7 +947,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320WrongCloginSeqTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(0, "user1", 2); @@ -860,7 +972,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchWrongSeqTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(0, "user1", CHALLENGE_SEQ_NUM); @@ -876,12 +988,12 @@ TEST_F(MySQLFilterTest, MySqlHandshake320AuthSwitchWrongSeqTest) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_AUTH_SWITCH); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); EXPECT_EQ(MySQLSession::State::AuthSwitchResp, filter_->getSession().getState()); std::string srv_resp_ok_data = encodeClientLoginResp(MYSQL_RESP_OK, 1); Buffer::InstancePtr server_resp_ok_data(new Buffer::OwnedImpl(srv_resp_ok_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_ok_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_ok_data, false)); EXPECT_EQ(MySQLSession::State::AuthSwitchResp, filter_->getSession().getState()); } @@ -900,7 +1012,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320WrongServerRespCode) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(0, "user1", CHALLENGE_SEQ_NUM); @@ -911,7 +1023,7 @@ TEST_F(MySQLFilterTest, MySqlHandshake320WrongServerRespCode) { std::string srv_resp_ok_data = encodeClientLoginResp(0x53, 0); Buffer::InstancePtr server_resp_ok_data(new Buffer::OwnedImpl(srv_resp_ok_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_ok_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_ok_data, false)); EXPECT_EQ(MySQLSession::State::NotHandled, filter_->getSession().getState()); Buffer::OwnedImpl client_query_data; @@ -934,7 +1046,7 @@ TEST_F(MySQLFilterTest, MySqlWrongHdrPkt) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(0, "user1", CHALLENGE_SEQ_NUM); @@ -945,7 +1057,7 @@ TEST_F(MySQLFilterTest, MySqlWrongHdrPkt) { std::string srv_resp_ok_data = encodeClientLoginResp(0x53, 0); Buffer::InstancePtr server_resp_ok_data(new Buffer::OwnedImpl(srv_resp_ok_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_ok_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_ok_data, false)); EXPECT_EQ(MySQLSession::State::NotHandled, filter_->getSession().getState()); Buffer::OwnedImpl client_query_data("123"); @@ -968,7 +1080,7 @@ TEST_F(MySQLFilterTest, MySqlLoginAndQueryTest) { std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*greet_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM); @@ -979,7 +1091,7 @@ TEST_F(MySQLFilterTest, MySqlLoginAndQueryTest) { std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK); Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*server_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); EXPECT_EQ(MySQLSession::State::Req, filter_->getSession().getState()); Command mysql_cmd_encode{}; @@ -996,7 +1108,7 @@ TEST_F(MySQLFilterTest, MySqlLoginAndQueryTest) { srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK, 0, 1); Buffer::InstancePtr request_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*request_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*request_resp_data, false)); EXPECT_EQ(MySQLSession::State::ReqResp, filter_->getSession().getState()); mysql_cmd_encode.setCmd(Command::Cmd::Query); @@ -1012,7 +1124,7 @@ TEST_F(MySQLFilterTest, MySqlLoginAndQueryTest) { srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK, 0, 1); Buffer::InstancePtr show_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*show_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*show_resp_data, false)); EXPECT_EQ(MySQLSession::State::ReqResp, filter_->getSession().getState()); mysql_cmd_encode.setCmd(Command::Cmd::Query); @@ -1029,7 +1141,7 @@ TEST_F(MySQLFilterTest, MySqlLoginAndQueryTest) { srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK, 0, 1); Buffer::InstancePtr create_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*create_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*create_resp_data, false)); EXPECT_EQ(MySQLSession::State::ReqResp, filter_->getSession().getState()); mysql_cmd_encode.setCmd(Command::Cmd::Query); @@ -1047,7 +1159,7 @@ TEST_F(MySQLFilterTest, MySqlLoginAndQueryTest) { srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK, 0, 1); Buffer::InstancePtr create_index_resp_data(new Buffer::OwnedImpl(srv_resp_data)); EXPECT_EQ(Envoy::Network::FilterStatus::Continue, - filter_->onData(*create_index_resp_data, false)); + filter_->onWrite(*create_index_resp_data, false)); EXPECT_EQ(MySQLSession::State::ReqResp, filter_->getSession().getState()); mysql_cmd_encode.setCmd(Command::Cmd::FieldList); @@ -1064,8 +1176,800 @@ TEST_F(MySQLFilterTest, MySqlLoginAndQueryTest) { srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK, 0, 1); Buffer::InstancePtr field_list_resp_data(new Buffer::OwnedImpl(srv_resp_data)); - EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*field_list_resp_data, false)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*field_list_resp_data, false)); + EXPECT_EQ(MySQLSession::State::ReqResp, filter_->getSession().getState()); +} + +/** + * Test RSA mediation for caching_sha2_password full authentication flow. + * SSL termination + cache miss: client sends cleartext password over TLS, + * filter RSA-encrypts it for the plaintext upstream connection. + */ +TEST_F(MySQLFilterTest, MySqlCachingSha2FullAuthRsaMediation) { + initialize(MySQLProxyProto::REQUIRE); + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + // Step 1: Server greeting with caching_sha2_password plugin. + std::string greeting_data = encodeServerGreetingCachingSha2(); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); + + // Step 2: Client SSL request (seq=1). + std::string ssl_req = encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr ssl_data(new Buffer::OwnedImpl(ssl_req)); + EXPECT_CALL(connection_, startSecureTransport()).WillOnce(testing::Return(true)); + EXPECT_EQ(Envoy::Network::FilterStatus::StopIteration, filter_->onData(*ssl_data, false)); + EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); + + // Step 3: Client login (seq=2 from client, should be rewritten to seq=1 for server). + std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM + 1); + Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); + EXPECT_EQ(MySQLSession::State::ChallengeResp41, filter_->getSession().getState()); + + // Step 4: Server responds with AuthMoreData(0x04) = full auth required (raw seq=2). + std::string auth_more_data = + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FULL_AUTH_REQUIRED}, CHALLENGE_RESP_SEQ_NUM); + Buffer::InstancePtr auth_more(new Buffer::OwnedImpl(auth_more_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*auth_more, false)); + EXPECT_EQ(MySQLSession::State::AuthSwitchMore, filter_->getSession().getState()); + EXPECT_EQ(RsaAuthState::WaitingClientPassword, filter_->getRsaAuthState()); + + // Step 5: Client sends cleartext password (seq=4 from client perspective). + // The filter should intercept this and inject a request-public-key packet. + std::string password = "secret"; + std::string pw_payload = password + '\0'; + std::string pw_data = encodeRawPacket(pw_payload, 4); + + Buffer::OwnedImpl captured_request_key; + EXPECT_CALL(filter_callbacks_, injectReadDataToFilterChain(_, false)) + .WillOnce([&](Buffer::Instance& buf, bool) { captured_request_key.add(buf); }); + + Buffer::InstancePtr pw_buf(new Buffer::OwnedImpl(pw_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::StopIteration, filter_->onData(*pw_buf, false)); + EXPECT_EQ(RsaAuthState::WaitingServerKey, filter_->getRsaAuthState()); + + // Verify the request-public-key packet: [len=1][seq=3][0x02] + ASSERT_EQ(5u, captured_request_key.length()); + uint32_t req_len = 0; + uint8_t req_seq = 0; + BufferHelper::peekHdr(captured_request_key, req_len, req_seq); + EXPECT_EQ(1u, req_len); + EXPECT_EQ(3u, req_seq); + BufferHelper::consumeHdr(captured_request_key); + uint8_t req_code; + BufferHelper::readUint8(captured_request_key, req_code); + EXPECT_EQ(MYSQL_REQUEST_PUBLIC_KEY, req_code); + + // Step 6: Generate an RSA-2048 key pair for the test. + bssl::UniquePtr gen_ctx(EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr)); + ASSERT_TRUE(gen_ctx); + ASSERT_GT(EVP_PKEY_keygen_init(gen_ctx.get()), 0); + ASSERT_GT(EVP_PKEY_CTX_set_rsa_keygen_bits(gen_ctx.get(), 2048), 0); + EVP_PKEY* raw_pkey = nullptr; + ASSERT_GT(EVP_PKEY_keygen(gen_ctx.get(), &raw_pkey), 0); + bssl::UniquePtr pkey(raw_pkey); + + // Extract public key PEM. + bssl::UniquePtr bio(BIO_new(BIO_s_mem())); + PEM_write_bio_PUBKEY(bio.get(), pkey.get()); + char* pem_data; + long pem_len = BIO_get_mem_data(bio.get(), &pem_data); + std::string pem_key(pem_data, pem_len); + + // Step 7: Server sends PEM key as AuthMoreData (raw seq=4). + // The filter should intercept this and inject the RSA-encrypted password. + Buffer::OwnedImpl captured_encrypted_pw; + EXPECT_CALL(filter_callbacks_, injectReadDataToFilterChain(_, false)) + .WillOnce([&](Buffer::Instance& buf, bool) { captured_encrypted_pw.add(buf); }); + + std::string key_packet = + encodeAuthMoreDataPacket(std::vector(pem_key.begin(), pem_key.end()), 4); + Buffer::InstancePtr key_buf(new Buffer::OwnedImpl(key_packet)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*key_buf, false)); + EXPECT_EQ(RsaAuthState::WaitingServerResult, filter_->getRsaAuthState()); + + // Verify the encrypted password packet header: [len=256][seq=5][encrypted_data] + ASSERT_GT(captured_encrypted_pw.length(), 4u); + uint32_t enc_len = 0; + uint8_t enc_seq = 0; + BufferHelper::peekHdr(captured_encrypted_pw, enc_len, enc_seq); + EXPECT_EQ(256u, enc_len); // RSA-2048 produces 256-byte ciphertext + EXPECT_EQ(5u, enc_seq); + + // Decrypt and verify the password XOR scramble. + BufferHelper::consumeHdr(captured_encrypted_pw); + std::string ciphertext; + BufferHelper::readStringBySize(captured_encrypted_pw, enc_len, ciphertext); + + bssl::UniquePtr dec_ctx(EVP_PKEY_CTX_new(pkey.get(), nullptr)); + ASSERT_TRUE(dec_ctx); + ASSERT_GT(EVP_PKEY_decrypt_init(dec_ctx.get()), 0); + ASSERT_GT(EVP_PKEY_CTX_set_rsa_padding(dec_ctx.get(), RSA_PKCS1_OAEP_PADDING), 0); + ASSERT_GT(EVP_PKEY_CTX_set_rsa_oaep_md(dec_ctx.get(), EVP_sha1()), 0); + + size_t plain_len = 0; + ASSERT_GT(EVP_PKEY_decrypt(dec_ctx.get(), nullptr, &plain_len, + reinterpret_cast(ciphertext.data()), + ciphertext.size()), + 0); + std::vector plaintext(plain_len); + ASSERT_GT(EVP_PKEY_decrypt(dec_ctx.get(), plaintext.data(), &plain_len, + reinterpret_cast(ciphertext.data()), + ciphertext.size()), + 0); + plaintext.resize(plain_len); + + // plaintext = (password + \0) XOR scramble (cyclic 20-byte) + std::vector scramble = getAuthPluginData20(); // 20 bytes of 0xff + ASSERT_EQ(pw_payload.size(), plaintext.size()); + for (size_t i = 0; i < plaintext.size(); i++) { + uint8_t expected = static_cast(pw_payload[i]) ^ scramble[i % scramble.size()]; + EXPECT_EQ(expected, plaintext[i]) << "mismatch at byte " << i; + } + + // Step 8: Server sends OK (raw seq=6, rewritten to seq=5 for client). + std::string srv_ok = encodeClientLoginResp(MYSQL_RESP_OK, 0, 6); + Buffer::InstancePtr ok_buf(new Buffer::OwnedImpl(srv_ok)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*ok_buf, false)); + EXPECT_EQ(MySQLSession::State::Req, filter_->getSession().getState()); + EXPECT_EQ(RsaAuthState::Inactive, filter_->getRsaAuthState()); +} + +/** + * Test that fast auth success (AuthMoreData 0x03) passes through without RSA mediation. + */ +TEST_F(MySQLFilterTest, MySqlCachingSha2FastAuthPassthrough) { + initialize(MySQLProxyProto::REQUIRE); + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + // Server greeting with caching_sha2_password. + std::string greeting_data = encodeServerGreetingCachingSha2(); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + // SSL request. + std::string ssl_req = encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr ssl_data(new Buffer::OwnedImpl(ssl_req)); + EXPECT_CALL(connection_, startSecureTransport()).WillOnce(testing::Return(true)); + EXPECT_EQ(Envoy::Network::FilterStatus::StopIteration, filter_->onData(*ssl_data, false)); + + // Client login. + std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM + 1); + Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); + EXPECT_EQ(MySQLSession::State::ChallengeResp41, filter_->getSession().getState()); + + // Server responds with AuthMoreData(0x03) = fast auth success. + std::string auth_more_data = + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FAST_AUTH_SUCCESS}, CHALLENGE_RESP_SEQ_NUM); + Buffer::InstancePtr auth_more(new Buffer::OwnedImpl(auth_more_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*auth_more, false)); + EXPECT_EQ(MySQLSession::State::AuthSwitchMore, filter_->getSession().getState()); + // RSA mediation should NOT be triggered. + EXPECT_EQ(RsaAuthState::Inactive, filter_->getRsaAuthState()); + + // Server sends OK (raw seq=3: next seq after server's AuthMoreData at seq=2). + std::string srv_ok = encodeClientLoginResp(MYSQL_RESP_OK, 0, 3); + Buffer::InstancePtr ok_buf(new Buffer::OwnedImpl(srv_ok)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*ok_buf, false)); + EXPECT_EQ(MySQLSession::State::Req, filter_->getSession().getState()); +} + +/** + * Test RSA mediation when server returns ERR after encrypted password. + */ +TEST_F(MySQLFilterTest, MySqlCachingSha2FullAuthRsaErr) { + initialize(MySQLProxyProto::REQUIRE); + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + // Server greeting. + std::string greeting_data = encodeServerGreetingCachingSha2(); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + // SSL request. + std::string ssl_req = encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr ssl_data(new Buffer::OwnedImpl(ssl_req)); + EXPECT_CALL(connection_, startSecureTransport()).WillOnce(testing::Return(true)); + EXPECT_EQ(Envoy::Network::FilterStatus::StopIteration, filter_->onData(*ssl_data, false)); + + // Client login. + std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM + 1); + Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); + + // AuthMoreData(0x04) = full auth required. + std::string auth_more_data = + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FULL_AUTH_REQUIRED}, CHALLENGE_RESP_SEQ_NUM); + Buffer::InstancePtr auth_more(new Buffer::OwnedImpl(auth_more_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*auth_more, false)); + EXPECT_EQ(RsaAuthState::WaitingClientPassword, filter_->getRsaAuthState()); + + // Client password. + std::string pw_data = encodeRawPacket(std::string("secret") + '\0', 4); + + EXPECT_CALL(filter_callbacks_, injectReadDataToFilterChain(_, false)).Times(1); + Buffer::InstancePtr pw_buf(new Buffer::OwnedImpl(pw_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::StopIteration, filter_->onData(*pw_buf, false)); + EXPECT_EQ(RsaAuthState::WaitingServerKey, filter_->getRsaAuthState()); + + // Generate RSA key pair and send PEM key. + bssl::UniquePtr gen_ctx(EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr)); + ASSERT_TRUE(gen_ctx); + ASSERT_GT(EVP_PKEY_keygen_init(gen_ctx.get()), 0); + ASSERT_GT(EVP_PKEY_CTX_set_rsa_keygen_bits(gen_ctx.get(), 2048), 0); + EVP_PKEY* raw_pkey = nullptr; + ASSERT_GT(EVP_PKEY_keygen(gen_ctx.get(), &raw_pkey), 0); + bssl::UniquePtr pkey(raw_pkey); + bssl::UniquePtr bio(BIO_new(BIO_s_mem())); + PEM_write_bio_PUBKEY(bio.get(), pkey.get()); + char* pem_data; + long pem_len = BIO_get_mem_data(bio.get(), &pem_data); + std::string pem_key(pem_data, pem_len); + + EXPECT_CALL(filter_callbacks_, injectReadDataToFilterChain(_, false)).Times(1); + std::string key_packet = + encodeAuthMoreDataPacket(std::vector(pem_key.begin(), pem_key.end()), 4); + Buffer::InstancePtr key_buf(new Buffer::OwnedImpl(key_packet)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*key_buf, false)); + EXPECT_EQ(RsaAuthState::WaitingServerResult, filter_->getRsaAuthState()); + + // Server responds with ERR (raw seq=6). + std::string srv_err = encodeClientLoginResp(MYSQL_RESP_ERR, 0, 6); + Buffer::InstancePtr err_buf(new Buffer::OwnedImpl(srv_err)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*err_buf, false)); + EXPECT_EQ(1UL, config_->stats().login_failures_.value()); + EXPECT_EQ(RsaAuthState::Inactive, filter_->getRsaAuthState()); +} + +/** + * Test that RSA mediation is NOT triggered when downstream_ssl is DISABLE. + */ +TEST_F(MySQLFilterTest, MySqlCachingSha2NoTerminateSsl) { + initialize(); // downstream_ssl = DISABLE + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + // Server greeting with caching_sha2_password. + std::string greeting_data = encodeServerGreetingCachingSha2(); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + EXPECT_EQ(MySQLSession::State::ChallengeReq, filter_->getSession().getState()); + + // Client login (no SSL in this path). + std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); + EXPECT_EQ(MySQLSession::State::ChallengeResp41, filter_->getSession().getState()); + + // AuthMoreData(0x04) from server. + std::string auth_more_data = + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FULL_AUTH_REQUIRED}, CHALLENGE_RESP_SEQ_NUM); + Buffer::InstancePtr auth_more(new Buffer::OwnedImpl(auth_more_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*auth_more, false)); + // RSA mediation should NOT be triggered because downstream_ssl is DISABLE. + EXPECT_EQ(RsaAuthState::Inactive, filter_->getRsaAuthState()); +} + +// ============================================================================= +// Extended coverage: SSL Terminated, SSL Passthrough, No-SSL with queries +// ============================================================================= + +/** + * SSL Terminated: basic login with mysql_native_password (no caching_sha2). + * Verifies seq rewriting without RSA mediation. + */ +TEST_F(MySQLFilterTest, MySqlSslTerminateNativePasswordLoginOk) { + initialize(MySQLProxyProto::REQUIRE); + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + // Server greeting (standard, no caching_sha2_password). + std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + // SSL request. + std::string ssl_req = encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr ssl_data(new Buffer::OwnedImpl(ssl_req)); + EXPECT_CALL(connection_, startSecureTransport()).WillOnce(testing::Return(true)); + EXPECT_EQ(Envoy::Network::FilterStatus::StopIteration, filter_->onData(*ssl_data, false)); + EXPECT_EQ(1UL, config_->stats().upgraded_to_ssl_.value()); + + // Client login (seq=2, rewritten to seq=1). + std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM + 1); + Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); + EXPECT_EQ(1UL, config_->stats().login_attempts_.value()); + + // Server OK (seq=2, rewritten to seq=3). + std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK, 0, CHALLENGE_RESP_SEQ_NUM); + Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); + EXPECT_EQ(MySQLSession::State::Req, filter_->getSession().getState()); + EXPECT_EQ(RsaAuthState::Inactive, filter_->getRsaAuthState()); +} + +/** + * SSL Terminated: login followed by query execution. + * Verifies the filter works correctly after auth completes. + */ +TEST_F(MySQLFilterTest, MySqlSslTerminateLoginThenQuery) { + initialize(MySQLProxyProto::REQUIRE); + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + // Greeting + SSL + Login + OK (same as native password test above). + std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + std::string ssl_req = encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr ssl_data(new Buffer::OwnedImpl(ssl_req)); + EXPECT_CALL(connection_, startSecureTransport()).WillOnce(testing::Return(true)); + EXPECT_EQ(Envoy::Network::FilterStatus::StopIteration, filter_->onData(*ssl_data, false)); + + std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM + 1); + Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); + + std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK, 0, CHALLENGE_RESP_SEQ_NUM); + Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); + EXPECT_EQ(MySQLSession::State::Req, filter_->getSession().getState()); + + // Now send a query — should be processed normally after auth. + Command mysql_cmd{}; + mysql_cmd.setCmd(Command::Cmd::Query); + mysql_cmd.setData("SELECT 1"); + Buffer::OwnedImpl query_data; + mysql_cmd.encode(query_data); + BufferHelper::encodeHdr(query_data, 0); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(query_data, false)); EXPECT_EQ(MySQLSession::State::ReqResp, filter_->getSession().getState()); + EXPECT_EQ(1UL, config_->stats().queries_parsed_.value()); +} + +/** + * SSL Terminated: full auth RSA followed by query execution. + * End-to-end: greeting → SSL → login → AuthMore(0x04) → pw → RSA → OK → query. + */ +TEST_F(MySQLFilterTest, MySqlSslTerminateRsaThenQuery) { + initialize(MySQLProxyProto::REQUIRE); + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + // Greeting with caching_sha2_password. + std::string greeting_data = encodeServerGreetingCachingSha2(); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + // SSL + login. + EXPECT_CALL(connection_, startSecureTransport()).WillOnce(testing::Return(true)); + Buffer::InstancePtr ssl_data( + new Buffer::OwnedImpl(encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM))); + EXPECT_EQ(Envoy::Network::FilterStatus::StopIteration, filter_->onData(*ssl_data, false)); + + Buffer::InstancePtr login_data( + new Buffer::OwnedImpl(encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM + 1))); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*login_data, false)); + + // AuthMoreData(0x04). + Buffer::InstancePtr auth_more(new Buffer::OwnedImpl( + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FULL_AUTH_REQUIRED}, CHALLENGE_RESP_SEQ_NUM))); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*auth_more, false)); + EXPECT_EQ(RsaAuthState::WaitingClientPassword, filter_->getRsaAuthState()); + + // Client password → intercepted, request-public-key injected. + EXPECT_CALL(filter_callbacks_, injectReadDataToFilterChain(_, false)).Times(1); + Buffer::InstancePtr pw_buf( + new Buffer::OwnedImpl(encodeRawPacket(std::string("secret") + '\0', 4))); + EXPECT_EQ(Envoy::Network::FilterStatus::StopIteration, filter_->onData(*pw_buf, false)); + + // PEM key → intercepted, encrypted password injected. + bssl::UniquePtr gen_ctx(EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr)); + EVP_PKEY_keygen_init(gen_ctx.get()); + EVP_PKEY_CTX_set_rsa_keygen_bits(gen_ctx.get(), 2048); + EVP_PKEY* raw_pkey = nullptr; + EVP_PKEY_keygen(gen_ctx.get(), &raw_pkey); + bssl::UniquePtr pkey(raw_pkey); + bssl::UniquePtr bio(BIO_new(BIO_s_mem())); + PEM_write_bio_PUBKEY(bio.get(), pkey.get()); + char* pem_data; + long pem_len = BIO_get_mem_data(bio.get(), &pem_data); + std::string pem_key(pem_data, pem_len); + + EXPECT_CALL(filter_callbacks_, injectReadDataToFilterChain(_, false)).Times(1); + Buffer::InstancePtr key_buf(new Buffer::OwnedImpl( + encodeAuthMoreDataPacket(std::vector(pem_key.begin(), pem_key.end()), 4))); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*key_buf, false)); + EXPECT_EQ(RsaAuthState::WaitingServerResult, filter_->getRsaAuthState()); + + // Server OK (raw seq=6). + Buffer::InstancePtr ok_buf(new Buffer::OwnedImpl(encodeClientLoginResp(MYSQL_RESP_OK, 0, 6))); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*ok_buf, false)); + EXPECT_EQ(MySQLSession::State::Req, filter_->getSession().getState()); + EXPECT_EQ(RsaAuthState::Inactive, filter_->getRsaAuthState()); + + // Query after RSA auth — verify seq numbers are correct and filter works. + Command mysql_cmd{}; + mysql_cmd.setCmd(Command::Cmd::Query); + mysql_cmd.setData("SELECT 1"); + Buffer::OwnedImpl query_data; + mysql_cmd.encode(query_data); + BufferHelper::encodeHdr(query_data, 0); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(query_data, false)); + EXPECT_EQ(MySQLSession::State::ReqResp, filter_->getSession().getState()); + EXPECT_EQ(1UL, config_->stats().queries_parsed_.value()); +} + +/** + * SSL Passthrough: server greeting with SSL → SSL pass-through → auth more data. + * Verifies that caching_sha2 auth more goes through unmodified in passthrough mode. + */ +TEST_F(MySQLFilterTest, MySqlSslPassthroughCachingSha2) { + initialize(); // downstream_ssl = DISABLE + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + // Server greeting (passes through). + std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + // Client SSL request → enters SslPt (passthrough) state. + std::string ssl_req = encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr ssl_data(new Buffer::OwnedImpl(ssl_req)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*ssl_data, false)); + EXPECT_EQ(MySQLSession::State::SslPt, filter_->getSession().getState()); + EXPECT_EQ(1UL, config_->stats().upgraded_to_ssl_.value()); + + // All further data is opaque (encrypted), filter just passes through. + Buffer::OwnedImpl encrypted_data("encrypted-login-packet-bytes"); + BufferHelper::encodeHdr(encrypted_data, CHALLENGE_SEQ_NUM + 1); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(encrypted_data, false)); + EXPECT_EQ(MySQLSession::State::SslPt, filter_->getSession().getState()); + + // RSA mediation never triggers. + EXPECT_EQ(RsaAuthState::Inactive, filter_->getRsaAuthState()); +} + +/** + * SSL Terminated: startSecureTransport() fails → connection closed. + */ +TEST_F(MySQLFilterTest, MySqlSslTerminateStartTlsFails) { + initialize(MySQLProxyProto::REQUIRE); + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + // SSL request, but startSecureTransport fails. + std::string ssl_req = encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr ssl_data(new Buffer::OwnedImpl(ssl_req)); + EXPECT_CALL(connection_, startSecureTransport()).WillOnce(testing::Return(false)); + EXPECT_CALL(connection_, close(Network::ConnectionCloseType::NoFlush)); + EXPECT_EQ(Envoy::Network::FilterStatus::StopIteration, filter_->onData(*ssl_data, false)); +} + +/** + * No-SSL: caching_sha2 fast auth (0x03) without any SSL involved. + * Verifies the filter handles auth more data correctly without SSL termination. + */ +TEST_F(MySQLFilterTest, MySqlNoSslCachingSha2FastAuth) { + initialize(); // downstream_ssl = DISABLE + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + // Server greeting with caching_sha2_password. + std::string greeting_data = encodeServerGreetingCachingSha2(); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + // Client login (no SSL). + std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); + EXPECT_EQ(1UL, config_->stats().login_attempts_.value()); + + // Server responds with fast auth success (0x03). + std::string auth_more_data = + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FAST_AUTH_SUCCESS}, CHALLENGE_RESP_SEQ_NUM); + Buffer::InstancePtr auth_more(new Buffer::OwnedImpl(auth_more_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*auth_more, false)); + EXPECT_EQ(MySQLSession::State::AuthSwitchMore, filter_->getSession().getState()); + // No RSA mediation without SSL termination. + EXPECT_EQ(RsaAuthState::Inactive, filter_->getRsaAuthState()); + + // Server OK. + std::string ok_data = encodeClientLoginResp(MYSQL_RESP_OK, 0, 3); + Buffer::InstancePtr ok_buf(new Buffer::OwnedImpl(ok_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*ok_buf, false)); + EXPECT_EQ(MySQLSession::State::Req, filter_->getSession().getState()); +} + +/** + * No-SSL: caching_sha2 full auth (0x04) without SSL termination. + * The filter should NOT intercept — full auth is handled by client/server directly. + */ +TEST_F(MySQLFilterTest, MySqlNoSslCachingSha2FullAuth) { + initialize(); // downstream_ssl = DISABLE + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + // Server greeting with caching_sha2_password. + std::string greeting_data = encodeServerGreetingCachingSha2(); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + // Client login (no SSL). + std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); + + // Server responds with full auth required (0x04). + std::string auth_more_data = + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FULL_AUTH_REQUIRED}, CHALLENGE_RESP_SEQ_NUM); + Buffer::InstancePtr auth_more(new Buffer::OwnedImpl(auth_more_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*auth_more, false)); + // RSA mediation NOT triggered (downstream_ssl is DISABLE). + EXPECT_EQ(RsaAuthState::Inactive, filter_->getRsaAuthState()); + + // Client sends request-public-key (0x02) — passes through to server. + std::string req_key = encodeRawPacket(std::string(1, MYSQL_REQUEST_PUBLIC_KEY), 3); + Buffer::InstancePtr req_key_buf(new Buffer::OwnedImpl(req_key)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*req_key_buf, false)); + + // No interception — filter stays inactive. + EXPECT_EQ(RsaAuthState::Inactive, filter_->getRsaAuthState()); +} + +/** + * No-SSL: plain login followed by query. + * Basic end-to-end without any SSL. + */ +TEST_F(MySQLFilterTest, MySqlNoSslLoginThenQuery) { + initialize(); // downstream_ssl = DISABLE + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); + EXPECT_EQ(1UL, config_->stats().login_attempts_.value()); + + std::string srv_resp_data = encodeClientLoginResp(MYSQL_RESP_OK); + Buffer::InstancePtr server_resp_data(new Buffer::OwnedImpl(srv_resp_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*server_resp_data, false)); + EXPECT_EQ(MySQLSession::State::Req, filter_->getSession().getState()); + + // Query. + Command mysql_cmd{}; + mysql_cmd.setCmd(Command::Cmd::Query); + mysql_cmd.setData("CREATE TABLE t (id INT)"); + Buffer::OwnedImpl query_data; + mysql_cmd.encode(query_data); + BufferHelper::encodeHdr(query_data, 0); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(query_data, false)); + EXPECT_EQ(MySQLSession::State::ReqResp, filter_->getSession().getState()); + EXPECT_EQ(1UL, config_->stats().queries_parsed_.value()); +} + +/** + * SSL REQUIRE mode: client that does NOT send SSL request gets rejected. + */ +TEST_F(MySQLFilterTest, MySqlSslRequireRejectsNonSslClient) { + initialize(MySQLProxyProto::REQUIRE); + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + // Server greeting. + std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + // Client sends login directly (no SSL request) — should be rejected. + std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); + EXPECT_CALL(connection_, close(Network::ConnectionCloseType::NoFlush)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); +} + +/** + * SSL REQUIRE mode: client that sends SSL request is accepted. + */ +TEST_F(MySQLFilterTest, MySqlSslRequireAcceptsSslClient) { + initialize(MySQLProxyProto::REQUIRE); + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + // Client sends SSL request — accepted. + std::string ssl_req = encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr ssl_data(new Buffer::OwnedImpl(ssl_req)); + EXPECT_CALL(connection_, startSecureTransport()).WillOnce(testing::Return(true)); + EXPECT_CALL(connection_, close(_)).Times(0); + EXPECT_EQ(Envoy::Network::FilterStatus::StopIteration, filter_->onData(*ssl_data, false)); + + // Client login after SSL — no rejection. + std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM + 1); + Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); + EXPECT_EQ(1UL, config_->stats().login_attempts_.value()); +} + +/** + * SSL ALLOW mode: client without SSL is accepted. + */ +TEST_F(MySQLFilterTest, MySqlSslAllowAcceptsNonSslClient) { + initialize(MySQLProxyProto::ALLOW); + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + // Client sends login directly (no SSL) — should be accepted in ALLOW mode. + std::string clogin_data = encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr client_login_data(new Buffer::OwnedImpl(clogin_data)); + EXPECT_CALL(connection_, close(_)).Times(0); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*client_login_data, false)); + EXPECT_EQ(1UL, config_->stats().login_attempts_.value()); +} + +/** + * SSL ALLOW mode: client with SSL initiates TLS, RSA mediation works. + */ +TEST_F(MySQLFilterTest, MySqlSslAllowWithSslClientRsaMediation) { + initialize(MySQLProxyProto::ALLOW); + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + // Greeting with caching_sha2_password. + std::string greeting_data = encodeServerGreetingCachingSha2(); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + // SSL request — accepted in ALLOW mode. + EXPECT_CALL(connection_, startSecureTransport()).WillOnce(testing::Return(true)); + Buffer::InstancePtr ssl_data( + new Buffer::OwnedImpl(encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM))); + EXPECT_EQ(Envoy::Network::FilterStatus::StopIteration, filter_->onData(*ssl_data, false)); + EXPECT_EQ(1UL, config_->stats().upgraded_to_ssl_.value()); + + // Client login. + Buffer::InstancePtr login_data( + new Buffer::OwnedImpl(encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM + 1))); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*login_data, false)); + + // AuthMoreData(0x04) — full auth required, should trigger RSA mediation. + Buffer::InstancePtr auth_more(new Buffer::OwnedImpl( + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FULL_AUTH_REQUIRED}, CHALLENGE_RESP_SEQ_NUM))); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*auth_more, false)); + EXPECT_EQ(RsaAuthState::WaitingClientPassword, filter_->getRsaAuthState()); + + // Client password — intercepted, request-public-key injected. + EXPECT_CALL(filter_callbacks_, injectReadDataToFilterChain(_, false)).Times(1); + Buffer::InstancePtr pw_buf( + new Buffer::OwnedImpl(encodeRawPacket(std::string("secret") + '\0', 4))); + EXPECT_EQ(Envoy::Network::FilterStatus::StopIteration, filter_->onData(*pw_buf, false)); + EXPECT_EQ(RsaAuthState::WaitingServerKey, filter_->getRsaAuthState()); +} + +/** + * SSL ALLOW mode: non-SSL client with caching_sha2 — NO RSA mediation. + * Client handles RSA directly with the server. + */ +TEST_F(MySQLFilterTest, MySqlSslAllowNonSslCachingSha2NoMediation) { + initialize(MySQLProxyProto::ALLOW); + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + std::string greeting_data = encodeServerGreetingCachingSha2(); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + // Client sends login directly (no SSL). + Buffer::InstancePtr login_data( + new Buffer::OwnedImpl(encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM))); + EXPECT_CALL(connection_, close(_)).Times(0); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*login_data, false)); + + // AuthMoreData(0x04) — should NOT trigger RSA mediation (no SSL was terminated). + Buffer::InstancePtr auth_more(new Buffer::OwnedImpl( + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FULL_AUTH_REQUIRED}, CHALLENGE_RESP_SEQ_NUM))); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*auth_more, false)); + EXPECT_EQ(RsaAuthState::Inactive, filter_->getRsaAuthState()); +} + +/** + * SSL DISABLE mode: caching_sha2 full auth passes through without mediation. + */ +TEST_F(MySQLFilterTest, MySqlSslDisableCachingSha2FullAuthNoMediation) { + initialize(); // DISABLE + + EXPECT_CALL(filter_callbacks_, connection()).WillRepeatedly(ReturnRef(connection_)); + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + std::string greeting_data = encodeServerGreetingCachingSha2(); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + Buffer::InstancePtr login_data( + new Buffer::OwnedImpl(encodeClientLogin(CLIENT_PROTOCOL_41, "user1", CHALLENGE_SEQ_NUM))); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*login_data, false)); + + // AuthMoreData(0x04) — no RSA mediation in DISABLE mode. + Buffer::InstancePtr auth_more(new Buffer::OwnedImpl( + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FULL_AUTH_REQUIRED}, CHALLENGE_RESP_SEQ_NUM))); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*auth_more, false)); + EXPECT_EQ(RsaAuthState::Inactive, filter_->getRsaAuthState()); +} + +/** + * SSL DISABLE mode: SSL request passes through to server (passthrough behavior). + */ +TEST_F(MySQLFilterTest, MySqlSslDisableSslPassthrough) { + initialize(); // DISABLE + + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onNewConnection()); + + std::string greeting_data = encodeServerGreeting(MYSQL_PROTOCOL_10); + Buffer::InstancePtr greet_data(new Buffer::OwnedImpl(greeting_data)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onWrite(*greet_data, false)); + + // Client sends SSL request — in DISABLE mode, it passes through (SslPt state). + std::string ssl_req = encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM); + Buffer::InstancePtr ssl_data(new Buffer::OwnedImpl(ssl_req)); + EXPECT_EQ(Envoy::Network::FilterStatus::Continue, filter_->onData(*ssl_data, false)); + EXPECT_EQ(MySQLSession::State::SslPt, filter_->getSession().getState()); + EXPECT_EQ(1UL, config_->stats().upgraded_to_ssl_.value()); } } // namespace MySQLProxy diff --git a/contrib/mysql_proxy/filters/network/test/mysql_ssl_allow_test_config.yaml b/contrib/mysql_proxy/filters/network/test/mysql_ssl_allow_test_config.yaml new file mode 100644 index 0000000000000..b6f69f62f054d --- /dev/null +++ b/contrib/mysql_proxy/filters/network/test/mysql_ssl_allow_test_config.yaml @@ -0,0 +1,53 @@ +admin: + access_log: + - name: envoy.access_loggers.file + typed_config: + "@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog + path: "{}" + address: + socket_address: + address: "{}" + port_value: 0 +static_resources: + clusters: + name: cluster_0 + connect_timeout: 2s + load_assignment: + cluster_name: cluster_0 + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: "{}" + port_value: 0 + listeners: + name: listener_0 + address: + socket_address: + address: "{}" + port_value: 0 + filter_chains: + - filters: + - name: mysql + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.network.mysql_proxy.v3.MySQLProxy + stat_prefix: mysql_stats + downstream_ssl: ALLOW + - name: tcp + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.network.tcp_proxy.v3.TcpProxy + stat_prefix: tcp_stats + cluster: cluster_0 + transport_socket: + name: starttls + typed_config: + "@type": type.googleapis.com/envoy.extensions.transport_sockets.starttls.v3.StartTlsConfig + cleartext_socket_config: + tls_socket_config: + common_tls_context: + tls_certificates: + certificate_chain: + filename: "{}" + private_key: + filename: "{}" diff --git a/contrib/mysql_proxy/filters/network/test/mysql_ssl_disable_test_config.yaml b/contrib/mysql_proxy/filters/network/test/mysql_ssl_disable_test_config.yaml new file mode 100644 index 0000000000000..cf291957162bd --- /dev/null +++ b/contrib/mysql_proxy/filters/network/test/mysql_ssl_disable_test_config.yaml @@ -0,0 +1,40 @@ +admin: + access_log: + - name: envoy.access_loggers.file + typed_config: + "@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog + path: "{}" + address: + socket_address: + address: "{}" + port_value: 0 +static_resources: + clusters: + name: cluster_0 + connect_timeout: 2s + load_assignment: + cluster_name: cluster_0 + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: "{}" + port_value: 0 + listeners: + name: listener_0 + address: + socket_address: + address: "{}" + port_value: 0 + filter_chains: + - filters: + - name: mysql + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.network.mysql_proxy.v3.MySQLProxy + stat_prefix: mysql_stats + - name: tcp + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.network.tcp_proxy.v3.TcpProxy + stat_prefix: tcp_stats + cluster: cluster_0 diff --git a/contrib/mysql_proxy/filters/network/test/mysql_ssl_integration_test.cc b/contrib/mysql_proxy/filters/network/test/mysql_ssl_integration_test.cc new file mode 100644 index 0000000000000..adee0df8dc25f --- /dev/null +++ b/contrib/mysql_proxy/filters/network/test/mysql_ssl_integration_test.cc @@ -0,0 +1,904 @@ +#include "envoy/extensions/transport_sockets/raw_buffer/v3/raw_buffer.pb.h" + +#include "source/common/buffer/buffer_impl.h" +#include "source/common/network/connection_impl.h" +#include "source/extensions/transport_sockets/raw_buffer/config.h" + +#include "test/integration/fake_upstream.h" +#include "test/integration/integration.h" +#include "test/integration/ssl_utility.h" +#include "test/integration/utility.h" +#include "test/test_common/network_utility.h" + +#include "contrib/mysql_proxy/filters/network/source/mysql_codec.h" +#include "contrib/mysql_proxy/filters/network/source/mysql_codec_clogin.h" +#include "contrib/mysql_proxy/filters/network/source/mysql_codec_clogin_resp.h" +#include "contrib/mysql_proxy/filters/network/source/mysql_codec_command.h" +#include "contrib/mysql_proxy/filters/network/source/mysql_codec_greeting.h" +#include "contrib/mysql_proxy/filters/network/source/mysql_utils.h" +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "mysql_test_utils.h" +#include "openssl/evp.h" +#include "openssl/pem.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace MySQLProxy { + +// Client connection that supports mid-stream TLS upgrade (startTLS). +// Adapted from test/extensions/transport_sockets/starttls/starttls_integration_test.cc. +class StartTlsClientConnection : public Network::ClientConnectionImpl { +public: + StartTlsClientConnection(Event::Dispatcher& dispatcher, + const Network::Address::InstanceConstSharedPtr& remote_address, + const Network::Address::InstanceConstSharedPtr& source_address, + Network::TransportSocketPtr&& transport_socket, + const Network::ConnectionSocket::OptionsSharedPtr& options) + : ClientConnectionImpl(dispatcher, remote_address, source_address, + std::move(transport_socket), options, nullptr) {} + + // Swap the raw transport socket for a TLS one and trigger the TLS handshake. + void upgradeToTls(Network::TransportSocketPtr&& tls_socket) { + transport_socket_ = std::move(tls_socket); + transport_socket_->setTransportSocketCallbacks(*this); + connecting_ = true; + ioHandle().activateFileEvents(Event::FileReadyType::Write); + } +}; + +class MySQLSSLIntegrationTest : public testing::TestWithParam, + public MySQLTestUtils, + public BaseIntegrationTest { + std::string mysqlSslConfig() { + return fmt::format( + fmt::runtime(TestEnvironment::readFileToStringForTest(TestEnvironment::runfilesPath( + "contrib/mysql_proxy/filters/network/test/mysql_ssl_require_test_config.yaml"))), + Platform::null_device_path, Network::Test::getLoopbackAddressString(GetParam()), + Network::Test::getLoopbackAddressString(GetParam()), + Network::Test::getAnyAddressString(GetParam()), + TestEnvironment::runfilesPath("test/config/integration/certs/servercert.pem"), + TestEnvironment::runfilesPath("test/config/integration/certs/serverkey.pem")); + } + +public: + MySQLSSLIntegrationTest() : BaseIntegrationTest(GetParam(), mysqlSslConfig()) { + skip_tag_extraction_rule_check_ = true; + } + + void initialize() override { + EXPECT_CALL(*mock_buffer_factory_, createBuffer_(_, _, _)) + .WillOnce(Invoke([&](std::function below_low, std::function above_high, + std::function above_overflow) -> Buffer::Instance* { + client_write_buffer_ = + new NiceMock(below_low, above_high, above_overflow); + ON_CALL(*client_write_buffer_, move(_)) + .WillByDefault(Invoke(client_write_buffer_, &MockWatermarkBuffer::baseMove)); + ON_CALL(*client_write_buffer_, drain(_)) + .WillByDefault(Invoke(client_write_buffer_, &MockWatermarkBuffer::trackDrains)); + return client_write_buffer_; + })) + .WillOnce(Invoke([&](std::function below_low, std::function above_high, + std::function above_overflow) -> Buffer::Instance* { + return new Buffer::WatermarkBuffer(below_low, above_high, above_overflow); + })); + + // Create raw buffer and TLS transport socket factories. + auto raw_config = + std::make_unique(); + auto raw_factory = + std::make_unique(); + cleartext_context_ = Network::UpstreamTransportSocketFactoryPtr{ + raw_factory->createTransportSocketFactory(*raw_config, factory_context_).value()}; + + tls_context_manager_ = std::make_unique( + server_factory_context_); + tls_context_ = Ssl::createClientSslTransportSocketFactory({}, *tls_context_manager_, *api_); + + payload_reader_ = std::make_shared(*dispatcher_); + + BaseIntegrationTest::initialize(); + + // Create client connection with raw cleartext transport socket. + Network::Address::InstanceConstSharedPtr address = + Ssl::getSslAddress(version_, lookupPort("listener_0")); + conn_ = std::make_unique( + *dispatcher_, address, Network::Address::InstanceConstSharedPtr(), + cleartext_context_->createTransportSocket(nullptr, nullptr), nullptr); + conn_->enableHalfClose(true); + conn_->addConnectionCallbacks(connect_callbacks_); + conn_->addReadFilter(payload_reader_); + } + + // Upgrade the client connection to TLS. + void upgradeClientToTls() { + conn_->upgradeToTls(tls_context_->createTransportSocket( + std::make_shared( + absl::string_view(""), std::vector(), std::vector()), + nullptr)); + connect_callbacks_.reset(); + while (!connect_callbacks_.connected() && !connect_callbacks_.closed()) { + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + } + ASSERT_TRUE(connect_callbacks_.connected()); + } + + // Write data to client connection and wait for it to be sent. + void clientWrite(const std::string& data) { + uint64_t prev_drained = client_write_buffer_->bytesDrained(); + Buffer::OwnedImpl buf(data); + conn_->write(buf, false); + while (client_write_buffer_->bytesDrained() < prev_drained + data.length()) { + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + } + } + + // Wait for specific payload from server (via Envoy) on the client side. + void clientWaitForData(const std::string& expected) { + payload_reader_->setDataToWaitFor(expected); + dispatcher_->run(Event::Dispatcher::RunType::Block); + } + + // Server greeting with caching_sha2_password and CLIENT_SSL. + std::string encodeServerGreetingCachingSha2() { + ServerGreeting greeting; + greeting.setProtocol(MYSQL_PROTOCOL_10); + greeting.setVersion(getVersion()); + greeting.setThreadId(MYSQL_THREAD_ID); + greeting.setAuthPluginData(getAuthPluginData20()); + greeting.setServerCap(CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION | CLIENT_SSL); + greeting.setServerCharset(MYSQL_SERVER_LANGUAGE); + greeting.setServerStatus(MYSQL_SERVER_STATUS); + greeting.setAuthPluginName("caching_sha2_password"); + Buffer::OwnedImpl buffer; + greeting.encode(buffer); + BufferHelper::encodeHdr(buffer, GREETING_SEQ_NUM); + return buffer.toString(); + } + + std::string encodeAuthMoreDataPacket(const std::vector& data, uint8_t seq) { + AuthMoreMessage auth_more; + auth_more.setAuthMoreData(data); + Buffer::OwnedImpl buffer; + auth_more.encode(buffer); + BufferHelper::encodeHdr(buffer, seq); + return buffer.toString(); + } + + std::string generateTestPemKey() { + bssl::UniquePtr gen_ctx(EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr)); + EVP_PKEY_keygen_init(gen_ctx.get()); + EVP_PKEY_CTX_set_rsa_keygen_bits(gen_ctx.get(), 2048); + EVP_PKEY* raw_pkey = nullptr; + EVP_PKEY_keygen(gen_ctx.get(), &raw_pkey); + bssl::UniquePtr pkey(raw_pkey); + bssl::UniquePtr bio(BIO_new(BIO_s_mem())); + PEM_write_bio_PUBKEY(bio.get(), pkey.get()); + char* pem_data; + long pem_len = BIO_get_mem_data(bio.get(), &pem_data); + return std::string(pem_data, pem_len); + } + + std::unique_ptr tls_context_manager_; + Network::UpstreamTransportSocketFactoryPtr tls_context_; + Network::UpstreamTransportSocketFactoryPtr cleartext_context_; + MockWatermarkBuffer* client_write_buffer_{nullptr}; + ConnectionStatusCallbacks connect_callbacks_; + std::unique_ptr conn_; + std::shared_ptr payload_reader_; +}; + +INSTANTIATE_TEST_SUITE_P(IpVersions, MySQLSSLIntegrationTest, + testing::ValuesIn(TestEnvironment::getIpVersionsForTest())); + +/** + * caching_sha2_password fast auth (cache hit) with SSL termination. + * + * Client ←TLS→ Envoy ←plaintext→ FakeUpstream + * greeting → SSL req → TLS upgrade → login → AuthMoreData(0x03) → OK + */ +TEST_P(MySQLSSLIntegrationTest, CachingSha2FastAuth) { + initialize(); + conn_->connect(); + + FakeRawConnectionPtr fake_upstream; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream)); + + // 1. Server greeting (cleartext, before TLS). + std::string greeting = encodeServerGreetingCachingSha2(); + ASSERT_TRUE(fake_upstream->write(greeting)); + clientWaitForData(greeting); + + // 2. Client sends SSL request (cleartext). + std::string ssl_req = encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM); + clientWrite(ssl_req); + // SSL request consumed by filter; filter calls startSecureTransport. + + // 3. Client upgrades to TLS. + upgradeClientToTls(); + + // 4. Client sends login over TLS (seq=2, rewritten to seq=1 for upstream). + std::string login = encodeClientLogin(CLIENT_PROTOCOL_41, "testuser", CHALLENGE_SEQ_NUM + 1); + clientWrite(login); + + // Upstream receives the rewritten login in cleartext. + std::string upstream_data; + ASSERT_TRUE(fake_upstream->waitForData([](const std::string& data) { return data.length() > 0; }, + &upstream_data)); + + // 5. Server responds with fast auth success (seq=2, rewritten to seq=3 for client). + ASSERT_TRUE(fake_upstream->write( + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FAST_AUTH_SUCCESS}, CHALLENGE_RESP_SEQ_NUM))); + + // 6. Server sends OK (seq=3, rewritten to seq=4 for client). + ASSERT_TRUE(fake_upstream->write(encodeClientLoginResp(MYSQL_RESP_OK, 0, 3))); + + // Wait for client to receive both responses. + // Use a short sleep + non-block dispatch to let data flow through. + timeSystem().advanceTimeWait(std::chrono::milliseconds(500)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + conn_->close(Network::ConnectionCloseType::FlushWrite); + ASSERT_TRUE(fake_upstream->waitForDisconnect()); + + test_server_->waitForCounterGe("mysql.mysql_stats.upgraded_to_ssl", 1); + test_server_->waitForCounterGe("mysql.mysql_stats.login_attempts", 1); + EXPECT_EQ(test_server_->counter("mysql.mysql_stats.login_failures")->value(), 0); +} + +/** + * caching_sha2_password full auth (cache miss) with RSA mediation. + * + * greeting → SSL req → TLS → login → AuthMoreData(0x04) → client pw → + * [filter: request-public-key → PEM key → RSA encrypt] → OK + */ +TEST_P(MySQLSSLIntegrationTest, CachingSha2FullAuthRsaMediation) { + initialize(); + conn_->connect(); + + FakeRawConnectionPtr fake_upstream; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream)); + + // 1. Server greeting. + std::string greeting = encodeServerGreetingCachingSha2(); + ASSERT_TRUE(fake_upstream->write(greeting)); + clientWaitForData(greeting); + + // 2. SSL request + TLS upgrade. + clientWrite(encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM)); + upgradeClientToTls(); + + // 3. Client login over TLS. + clientWrite(encodeClientLogin(CLIENT_PROTOCOL_41, "testuser", CHALLENGE_SEQ_NUM + 1)); + + std::string upstream_data; + ASSERT_TRUE(fake_upstream->waitForData([](const std::string& data) { return data.length() > 0; }, + &upstream_data)); + + // 4. Server: full auth required (seq=2). + ASSERT_TRUE(fake_upstream->write( + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FULL_AUTH_REQUIRED}, CHALLENGE_RESP_SEQ_NUM))); + + // Wait for client to receive AuthMoreData (filter forwards it). + timeSystem().advanceTimeWait(std::chrono::milliseconds(200)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + // 5. Client sends cleartext password (seq=4 from client perspective). + std::string pw_payload = std::string("testpass") + '\0'; + Buffer::OwnedImpl pw_pkt; + BufferHelper::addString(pw_pkt, pw_payload); + BufferHelper::encodeHdr(pw_pkt, 4); + clientWrite(pw_pkt.toString()); + + // 6. Filter intercepts password, sends request-public-key (0x02, seq=3) to upstream. + size_t prev_len = upstream_data.length(); + ASSERT_TRUE(fake_upstream->waitForData( + [prev_len](const std::string& data) { return data.length() >= prev_len + 5; }, + &upstream_data)); + + // Verify request-public-key packet. + std::string req_key = upstream_data.substr(prev_len); + EXPECT_EQ(static_cast(req_key[3]), 3u); // seq + EXPECT_EQ(static_cast(req_key[4]), MYSQL_REQUEST_PUBLIC_KEY); + + // 7. Server sends PEM public key (seq=4). + std::string pem_key = generateTestPemKey(); + ASSERT_TRUE(fake_upstream->write( + encodeAuthMoreDataPacket(std::vector(pem_key.begin(), pem_key.end()), 4))); + + // 8. Filter RSA-encrypts password and sends to upstream (seq=5, 256 bytes). + prev_len = upstream_data.length(); + ASSERT_TRUE(fake_upstream->waitForData( + [prev_len](const std::string& data) { return data.length() >= prev_len + 4 + 256; }, + &upstream_data)); + + std::string enc_pkt = upstream_data.substr(prev_len); + uint32_t enc_len = static_cast(enc_pkt[0]) | (static_cast(enc_pkt[1]) << 8) | + (static_cast(enc_pkt[2]) << 16); + EXPECT_EQ(enc_len, 256u); // RSA-2048 ciphertext + EXPECT_EQ(static_cast(enc_pkt[3]), 5u); // seq + + // 9. Server sends OK (raw seq=6, rewritten to seq=5 for client). + ASSERT_TRUE(fake_upstream->write(encodeClientLoginResp(MYSQL_RESP_OK, 0, 6))); + + timeSystem().advanceTimeWait(std::chrono::milliseconds(500)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + conn_->close(Network::ConnectionCloseType::FlushWrite); + ASSERT_TRUE(fake_upstream->waitForDisconnect()); + + test_server_->waitForCounterGe("mysql.mysql_stats.upgraded_to_ssl", 1); + test_server_->waitForCounterGe("mysql.mysql_stats.login_attempts", 1); + EXPECT_EQ(test_server_->counter("mysql.mysql_stats.login_failures")->value(), 0); +} + +/** + * Full auth — server rejects after RSA-encrypted password. + */ +TEST_P(MySQLSSLIntegrationTest, CachingSha2FullAuthRsaErr) { + initialize(); + conn_->connect(); + + FakeRawConnectionPtr fake_upstream; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream)); + + std::string greeting = encodeServerGreetingCachingSha2(); + ASSERT_TRUE(fake_upstream->write(greeting)); + clientWaitForData(greeting); + + clientWrite(encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM)); + upgradeClientToTls(); + clientWrite(encodeClientLogin(CLIENT_PROTOCOL_41, "testuser", CHALLENGE_SEQ_NUM + 1)); + + std::string upstream_data; + ASSERT_TRUE(fake_upstream->waitForData([](const std::string& data) { return data.length() > 0; }, + &upstream_data)); + + // Full auth required. + ASSERT_TRUE(fake_upstream->write( + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FULL_AUTH_REQUIRED}, CHALLENGE_RESP_SEQ_NUM))); + timeSystem().advanceTimeWait(std::chrono::milliseconds(200)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + // Client password. + Buffer::OwnedImpl pw_pkt; + BufferHelper::addString(pw_pkt, std::string("wrongpw") + '\0'); + BufferHelper::encodeHdr(pw_pkt, 4); + clientWrite(pw_pkt.toString()); + + // Wait for request-public-key. + size_t prev_len = upstream_data.length(); + ASSERT_TRUE(fake_upstream->waitForData( + [prev_len](const std::string& data) { return data.length() >= prev_len + 5; }, + &upstream_data)); + + // Send PEM key. + std::string pem_key = generateTestPemKey(); + ASSERT_TRUE(fake_upstream->write( + encodeAuthMoreDataPacket(std::vector(pem_key.begin(), pem_key.end()), 4))); + + // Wait for encrypted password. + prev_len = upstream_data.length(); + ASSERT_TRUE(fake_upstream->waitForData( + [prev_len](const std::string& data) { return data.length() >= prev_len + 4 + 256; }, + &upstream_data)); + + // Server sends ERR (raw seq=6). + ASSERT_TRUE(fake_upstream->write(encodeClientLoginResp(MYSQL_RESP_ERR, 0, 6))); + + timeSystem().advanceTimeWait(std::chrono::milliseconds(500)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + conn_->close(Network::ConnectionCloseType::FlushWrite); + ASSERT_TRUE(fake_upstream->waitForDisconnect()); + + test_server_->waitForCounterGe("mysql.mysql_stats.login_failures", 1); +} + +/** + * SSL Terminated: basic login with native password (no caching_sha2), then query. + * Verifies seq rewriting works for the full lifecycle. + */ +TEST_P(MySQLSSLIntegrationTest, SslTerminateLoginThenQuery) { + initialize(); + conn_->connect(); + + FakeRawConnectionPtr fake_upstream; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream)); + + // Greeting (standard, not caching_sha2). + std::string greeting = encodeServerGreeting(MYSQL_PROTOCOL_10); + ASSERT_TRUE(fake_upstream->write(greeting)); + clientWaitForData(greeting); + + // SSL request + TLS upgrade. + clientWrite(encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM)); + upgradeClientToTls(); + + // Client login over TLS. + clientWrite(encodeClientLogin(CLIENT_PROTOCOL_41, "testuser", CHALLENGE_SEQ_NUM + 1)); + + std::string upstream_data; + ASSERT_TRUE(fake_upstream->waitForData([](const std::string& data) { return data.length() > 0; }, + &upstream_data)); + + // Server OK (seq=2, rewritten to seq=3 for client). + ASSERT_TRUE(fake_upstream->write(encodeClientLoginResp(MYSQL_RESP_OK))); + + timeSystem().advanceTimeWait(std::chrono::milliseconds(200)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + // Send a query (seq=0 after resetSeq). + Command mysql_cmd{}; + mysql_cmd.setCmd(Command::Cmd::Query); + mysql_cmd.setData("SELECT 1"); + Buffer::OwnedImpl query_buf; + mysql_cmd.encode(query_buf); + BufferHelper::encodeHdr(query_buf, 0); + clientWrite(query_buf.toString()); + + // Upstream should receive the query. + size_t prev_len = upstream_data.length(); + ASSERT_TRUE(fake_upstream->waitForData( + [prev_len](const std::string& data) { return data.length() > prev_len; }, &upstream_data)); + + conn_->close(Network::ConnectionCloseType::FlushWrite); + ASSERT_TRUE(fake_upstream->waitForDisconnect()); + + test_server_->waitForCounterGe("mysql.mysql_stats.upgraded_to_ssl", 1); + test_server_->waitForCounterGe("mysql.mysql_stats.login_attempts", 1); + test_server_->waitForCounterGe("mysql.mysql_stats.queries_parsed", 1); +} + +/** + * SSL Terminated: RSA mediation followed by query execution. + * Full lifecycle: greeting → SSL → login → AuthMore(0x04) → RSA → OK → query. + */ +TEST_P(MySQLSSLIntegrationTest, CachingSha2FullAuthRsaThenQuery) { + initialize(); + conn_->connect(); + + FakeRawConnectionPtr fake_upstream; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream)); + + // Greeting + SSL + login. + std::string greeting = encodeServerGreetingCachingSha2(); + ASSERT_TRUE(fake_upstream->write(greeting)); + clientWaitForData(greeting); + + clientWrite(encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM)); + upgradeClientToTls(); + clientWrite(encodeClientLogin(CLIENT_PROTOCOL_41, "testuser", CHALLENGE_SEQ_NUM + 1)); + + std::string upstream_data; + ASSERT_TRUE(fake_upstream->waitForData([](const std::string& data) { return data.length() > 0; }, + &upstream_data)); + + // Full auth required → client password → RSA. + ASSERT_TRUE(fake_upstream->write( + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FULL_AUTH_REQUIRED}, CHALLENGE_RESP_SEQ_NUM))); + timeSystem().advanceTimeWait(std::chrono::milliseconds(200)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + Buffer::OwnedImpl pw_pkt; + BufferHelper::addString(pw_pkt, std::string("testpass") + '\0'); + BufferHelper::encodeHdr(pw_pkt, 4); + clientWrite(pw_pkt.toString()); + + // Wait for request-public-key. + size_t prev_len = upstream_data.length(); + ASSERT_TRUE(fake_upstream->waitForData( + [prev_len](const std::string& data) { return data.length() >= prev_len + 5; }, + &upstream_data)); + + // Send PEM key + wait for encrypted password. + std::string pem_key = generateTestPemKey(); + ASSERT_TRUE(fake_upstream->write( + encodeAuthMoreDataPacket(std::vector(pem_key.begin(), pem_key.end()), 4))); + prev_len = upstream_data.length(); + ASSERT_TRUE(fake_upstream->waitForData( + [prev_len](const std::string& data) { return data.length() >= prev_len + 4 + 256; }, + &upstream_data)); + + // Server OK (raw seq=6). + ASSERT_TRUE(fake_upstream->write(encodeClientLoginResp(MYSQL_RESP_OK, 0, 6))); + timeSystem().advanceTimeWait(std::chrono::milliseconds(200)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + // Now send a query — verifies seq numbers are correct after RSA mediation. + Command mysql_cmd{}; + mysql_cmd.setCmd(Command::Cmd::Query); + mysql_cmd.setData("SELECT 1"); + Buffer::OwnedImpl query_buf; + mysql_cmd.encode(query_buf); + BufferHelper::encodeHdr(query_buf, 0); + clientWrite(query_buf.toString()); + + prev_len = upstream_data.length(); + ASSERT_TRUE(fake_upstream->waitForData( + [prev_len](const std::string& data) { return data.length() > prev_len; }, &upstream_data)); + + conn_->close(Network::ConnectionCloseType::FlushWrite); + ASSERT_TRUE(fake_upstream->waitForDisconnect()); + + test_server_->waitForCounterGe("mysql.mysql_stats.queries_parsed", 1); +} + +// ============================================================================= +// DISABLE mode integration tests — plain TCP proxy, no SSL termination. +// ============================================================================= + +class MySQLDisableIntegrationTest : public testing::TestWithParam, + public MySQLTestUtils, + public BaseIntegrationTest { + std::string mysqlConfig() { + return fmt::format( + fmt::runtime(TestEnvironment::readFileToStringForTest(TestEnvironment::runfilesPath( + "contrib/mysql_proxy/filters/network/test/mysql_ssl_disable_test_config.yaml"))), + Platform::null_device_path, Network::Test::getLoopbackAddressString(GetParam()), + Network::Test::getLoopbackAddressString(GetParam()), + Network::Test::getAnyAddressString(GetParam())); + } + +public: + MySQLDisableIntegrationTest() : BaseIntegrationTest(GetParam(), mysqlConfig()) { + skip_tag_extraction_rule_check_ = true; + } + + void SetUp() override { BaseIntegrationTest::initialize(); } +}; + +INSTANTIATE_TEST_SUITE_P(IpVersions, MySQLDisableIntegrationTest, + testing::ValuesIn(TestEnvironment::getIpVersionsForTest())); + +/** + * DISABLE mode: basic login, no SSL involved. + */ +TEST_P(MySQLDisableIntegrationTest, DisableBasicLogin) { + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); + FakeRawConnectionPtr fake_upstream; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream)); + + // Greeting. + std::string greeting = encodeServerGreeting(MYSQL_PROTOCOL_10); + ASSERT_TRUE(fake_upstream->write(greeting)); + tcp_client->waitForData(greeting, true); + + // Client login (no SSL). + std::string login = encodeClientLogin(CLIENT_PROTOCOL_41, "testuser", CHALLENGE_SEQ_NUM); + ASSERT_TRUE(tcp_client->write(login)); + + std::string upstream_data; + ASSERT_TRUE(fake_upstream->waitForData(login.length(), &upstream_data)); + EXPECT_EQ(login, upstream_data); + + // Server OK. + std::string ok_resp = encodeClientLoginResp(MYSQL_RESP_OK); + ASSERT_TRUE(fake_upstream->write(ok_resp)); + + std::string downstream(greeting); + downstream.append(ok_resp); + tcp_client->waitForData(downstream, true); + + tcp_client->close(); + ASSERT_TRUE(fake_upstream->waitForDisconnect()); + + test_server_->waitForCounterGe("mysql.mysql_stats.login_attempts", 1); + EXPECT_EQ(test_server_->counter("mysql.mysql_stats.login_failures")->value(), 0); +} + +/** + * DISABLE mode: SSL request passes through to upstream (passthrough). + */ +TEST_P(MySQLDisableIntegrationTest, DisableSslPassthrough) { + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); + FakeRawConnectionPtr fake_upstream; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream)); + + // Greeting. + std::string greeting = encodeServerGreeting(MYSQL_PROTOCOL_10); + ASSERT_TRUE(fake_upstream->write(greeting)); + tcp_client->waitForData(greeting, true); + + // Client sends SSL request — in DISABLE mode, it passes through. + std::string ssl_req = encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM); + ASSERT_TRUE(tcp_client->write(ssl_req)); + + std::string upstream_data; + ASSERT_TRUE(fake_upstream->waitForData(ssl_req.length(), &upstream_data)); + // The SSL request should be forwarded unmodified. + EXPECT_EQ(ssl_req, upstream_data); + + tcp_client->close(); + ASSERT_TRUE(fake_upstream->waitForDisconnect()); + + EXPECT_EQ(test_server_->counter("mysql.mysql_stats.upgraded_to_ssl")->value(), 1); +} + +// ============================================================================= +// ALLOW mode integration tests — terminate SSL if client requests, accept cleartext. +// Uses the same StartTlsClientConnection as the REQUIRE tests. +// ============================================================================= + +class MySQLAllowIntegrationTest : public testing::TestWithParam, + public MySQLTestUtils, + public BaseIntegrationTest { + std::string mysqlSslConfig() { + return fmt::format( + fmt::runtime(TestEnvironment::readFileToStringForTest(TestEnvironment::runfilesPath( + "contrib/mysql_proxy/filters/network/test/mysql_ssl_allow_test_config.yaml"))), + Platform::null_device_path, Network::Test::getLoopbackAddressString(GetParam()), + Network::Test::getLoopbackAddressString(GetParam()), + Network::Test::getAnyAddressString(GetParam()), + TestEnvironment::runfilesPath("test/config/integration/certs/servercert.pem"), + TestEnvironment::runfilesPath("test/config/integration/certs/serverkey.pem")); + } + +public: + MySQLAllowIntegrationTest() : BaseIntegrationTest(GetParam(), mysqlSslConfig()) { + skip_tag_extraction_rule_check_ = true; + } + + void initialize() override { + EXPECT_CALL(*mock_buffer_factory_, createBuffer_(_, _, _)) + .WillOnce(Invoke([&](std::function below_low, std::function above_high, + std::function above_overflow) -> Buffer::Instance* { + client_write_buffer_ = + new NiceMock(below_low, above_high, above_overflow); + ON_CALL(*client_write_buffer_, move(_)) + .WillByDefault(Invoke(client_write_buffer_, &MockWatermarkBuffer::baseMove)); + ON_CALL(*client_write_buffer_, drain(_)) + .WillByDefault(Invoke(client_write_buffer_, &MockWatermarkBuffer::trackDrains)); + return client_write_buffer_; + })) + .WillRepeatedly(Invoke([](std::function below_low, std::function above_high, + std::function above_overflow) -> Buffer::Instance* { + return new Buffer::WatermarkBuffer(below_low, above_high, above_overflow); + })); + + auto raw_config = + std::make_unique(); + auto raw_factory = + std::make_unique(); + cleartext_context_ = Network::UpstreamTransportSocketFactoryPtr{ + raw_factory->createTransportSocketFactory(*raw_config, factory_context_).value()}; + + tls_context_manager_ = std::make_unique( + server_factory_context_); + tls_context_ = Ssl::createClientSslTransportSocketFactory({}, *tls_context_manager_, *api_); + + payload_reader_ = std::make_shared(*dispatcher_); + + BaseIntegrationTest::initialize(); + + Network::Address::InstanceConstSharedPtr address = + Ssl::getSslAddress(version_, lookupPort("listener_0")); + conn_ = std::make_unique( + *dispatcher_, address, Network::Address::InstanceConstSharedPtr(), + cleartext_context_->createTransportSocket(nullptr, nullptr), nullptr); + conn_->enableHalfClose(true); + conn_->addConnectionCallbacks(connect_callbacks_); + conn_->addReadFilter(payload_reader_); + } + + void upgradeClientToTls() { + conn_->upgradeToTls(tls_context_->createTransportSocket( + std::make_shared( + absl::string_view(""), std::vector(), std::vector()), + nullptr)); + connect_callbacks_.reset(); + while (!connect_callbacks_.connected() && !connect_callbacks_.closed()) { + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + } + ASSERT_TRUE(connect_callbacks_.connected()); + } + + void clientWrite(const std::string& data) { + uint64_t prev_drained = client_write_buffer_->bytesDrained(); + Buffer::OwnedImpl buf(data); + conn_->write(buf, false); + while (client_write_buffer_->bytesDrained() < prev_drained + data.length()) { + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + } + } + + void clientWaitForData(const std::string& expected) { + payload_reader_->setDataToWaitFor(expected); + dispatcher_->run(Event::Dispatcher::RunType::Block); + } + + std::string encodeServerGreetingCachingSha2() { + ServerGreeting greeting; + greeting.setProtocol(MYSQL_PROTOCOL_10); + greeting.setVersion(getVersion()); + greeting.setThreadId(MYSQL_THREAD_ID); + greeting.setAuthPluginData(getAuthPluginData20()); + greeting.setServerCap(CLIENT_PLUGIN_AUTH | CLIENT_SECURE_CONNECTION | CLIENT_SSL); + greeting.setServerCharset(MYSQL_SERVER_LANGUAGE); + greeting.setServerStatus(MYSQL_SERVER_STATUS); + greeting.setAuthPluginName("caching_sha2_password"); + Buffer::OwnedImpl buffer; + greeting.encode(buffer); + BufferHelper::encodeHdr(buffer, GREETING_SEQ_NUM); + return buffer.toString(); + } + + std::string encodeAuthMoreDataPacket(const std::vector& data, uint8_t seq) { + AuthMoreMessage auth_more; + auth_more.setAuthMoreData(data); + Buffer::OwnedImpl buffer; + auth_more.encode(buffer); + BufferHelper::encodeHdr(buffer, seq); + return buffer.toString(); + } + + std::string generateTestPemKey() { + bssl::UniquePtr gen_ctx(EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr)); + EVP_PKEY_keygen_init(gen_ctx.get()); + EVP_PKEY_CTX_set_rsa_keygen_bits(gen_ctx.get(), 2048); + EVP_PKEY* raw_pkey = nullptr; + EVP_PKEY_keygen(gen_ctx.get(), &raw_pkey); + bssl::UniquePtr pkey(raw_pkey); + bssl::UniquePtr bio(BIO_new(BIO_s_mem())); + PEM_write_bio_PUBKEY(bio.get(), pkey.get()); + char* pem_data; + long pem_len = BIO_get_mem_data(bio.get(), &pem_data); + return std::string(pem_data, pem_len); + } + + std::unique_ptr tls_context_manager_; + Network::UpstreamTransportSocketFactoryPtr tls_context_; + Network::UpstreamTransportSocketFactoryPtr cleartext_context_; + MockWatermarkBuffer* client_write_buffer_{nullptr}; + ConnectionStatusCallbacks connect_callbacks_; + std::unique_ptr conn_; + std::shared_ptr payload_reader_; +}; + +INSTANTIATE_TEST_SUITE_P(IpVersions, MySQLAllowIntegrationTest, + testing::ValuesIn(TestEnvironment::getIpVersionsForTest())); + +/** + * ALLOW mode: SSL client connects, terminates SSL, basic login OK. + */ +TEST_P(MySQLAllowIntegrationTest, AllowSslClientLogin) { + initialize(); + conn_->connect(); + + FakeRawConnectionPtr fake_upstream; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream)); + + std::string greeting = encodeServerGreeting(MYSQL_PROTOCOL_10); + ASSERT_TRUE(fake_upstream->write(greeting)); + clientWaitForData(greeting); + + // SSL request + TLS upgrade — accepted in ALLOW mode. + clientWrite(encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM)); + upgradeClientToTls(); + + clientWrite(encodeClientLogin(CLIENT_PROTOCOL_41, "testuser", CHALLENGE_SEQ_NUM + 1)); + + std::string upstream_data; + ASSERT_TRUE(fake_upstream->waitForData([](const std::string& data) { return data.length() > 0; }, + &upstream_data)); + + ASSERT_TRUE(fake_upstream->write(encodeClientLoginResp(MYSQL_RESP_OK))); + + timeSystem().advanceTimeWait(std::chrono::milliseconds(500)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + conn_->close(Network::ConnectionCloseType::FlushWrite); + ASSERT_TRUE(fake_upstream->waitForDisconnect()); + + test_server_->waitForCounterGe("mysql.mysql_stats.upgraded_to_ssl", 1); + test_server_->waitForCounterGe("mysql.mysql_stats.login_attempts", 1); +} + +/** + * ALLOW mode: non-SSL client connects in cleartext, login OK. + */ +TEST_P(MySQLAllowIntegrationTest, AllowNonSslClientLogin) { + initialize(); + conn_->connect(); + + FakeRawConnectionPtr fake_upstream; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream)); + + std::string greeting = encodeServerGreeting(MYSQL_PROTOCOL_10); + ASSERT_TRUE(fake_upstream->write(greeting)); + clientWaitForData(greeting); + + // Client sends login directly (no SSL) — accepted in ALLOW mode. + clientWrite(encodeClientLogin(CLIENT_PROTOCOL_41, "testuser", CHALLENGE_SEQ_NUM)); + + std::string upstream_data; + ASSERT_TRUE(fake_upstream->waitForData([](const std::string& data) { return data.length() > 0; }, + &upstream_data)); + + ASSERT_TRUE(fake_upstream->write(encodeClientLoginResp(MYSQL_RESP_OK))); + + timeSystem().advanceTimeWait(std::chrono::milliseconds(500)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + conn_->close(Network::ConnectionCloseType::FlushWrite); + ASSERT_TRUE(fake_upstream->waitForDisconnect()); + + test_server_->waitForCounterGe("mysql.mysql_stats.login_attempts", 1); + EXPECT_EQ(test_server_->counter("mysql.mysql_stats.upgraded_to_ssl")->value(), 0); +} + +/** + * ALLOW mode: SSL client with caching_sha2 full auth (RSA mediation). + */ +TEST_P(MySQLAllowIntegrationTest, AllowSslFullAuthRsaMediation) { + initialize(); + conn_->connect(); + + FakeRawConnectionPtr fake_upstream; + ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream)); + + // Greeting with caching_sha2_password. + std::string greeting = encodeServerGreetingCachingSha2(); + ASSERT_TRUE(fake_upstream->write(greeting)); + clientWaitForData(greeting); + + // SSL + login. + clientWrite(encodeClientLogin(CLIENT_SSL, "", CHALLENGE_SEQ_NUM)); + upgradeClientToTls(); + clientWrite(encodeClientLogin(CLIENT_PROTOCOL_41, "testuser", CHALLENGE_SEQ_NUM + 1)); + + std::string upstream_data; + ASSERT_TRUE(fake_upstream->waitForData([](const std::string& data) { return data.length() > 0; }, + &upstream_data)); + + // Full auth required. + ASSERT_TRUE(fake_upstream->write( + encodeAuthMoreDataPacket({MYSQL_CACHINGSHA2_FULL_AUTH_REQUIRED}, CHALLENGE_RESP_SEQ_NUM))); + timeSystem().advanceTimeWait(std::chrono::milliseconds(200)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + // Client password. + Buffer::OwnedImpl pw_pkt; + BufferHelper::addString(pw_pkt, std::string("testpass") + '\0'); + BufferHelper::encodeHdr(pw_pkt, 4); + clientWrite(pw_pkt.toString()); + + // Filter should inject request-public-key (0x02, seq=3). + size_t prev_len = upstream_data.length(); + ASSERT_TRUE(fake_upstream->waitForData( + [prev_len](const std::string& data) { return data.length() >= prev_len + 5; }, + &upstream_data)); + + std::string req_key = upstream_data.substr(prev_len); + EXPECT_EQ(static_cast(req_key[3]), 3u); + EXPECT_EQ(static_cast(req_key[4]), MYSQL_REQUEST_PUBLIC_KEY); + + // Send PEM key. + std::string pem_key = generateTestPemKey(); + ASSERT_TRUE(fake_upstream->write( + encodeAuthMoreDataPacket(std::vector(pem_key.begin(), pem_key.end()), 4))); + + // Wait for encrypted password (256 bytes). + prev_len = upstream_data.length(); + ASSERT_TRUE(fake_upstream->waitForData( + [prev_len](const std::string& data) { return data.length() >= prev_len + 4 + 256; }, + &upstream_data)); + + // Server OK. + ASSERT_TRUE(fake_upstream->write(encodeClientLoginResp(MYSQL_RESP_OK, 0, 6))); + timeSystem().advanceTimeWait(std::chrono::milliseconds(500)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + conn_->close(Network::ConnectionCloseType::FlushWrite); + ASSERT_TRUE(fake_upstream->waitForDisconnect()); + + test_server_->waitForCounterGe("mysql.mysql_stats.upgraded_to_ssl", 1); + test_server_->waitForCounterGe("mysql.mysql_stats.login_attempts", 1); + EXPECT_EQ(test_server_->counter("mysql.mysql_stats.login_failures")->value(), 0); +} + +} // namespace MySQLProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/contrib/mysql_proxy/filters/network/test/mysql_ssl_require_test_config.yaml b/contrib/mysql_proxy/filters/network/test/mysql_ssl_require_test_config.yaml new file mode 100644 index 0000000000000..e5536a29fc2fd --- /dev/null +++ b/contrib/mysql_proxy/filters/network/test/mysql_ssl_require_test_config.yaml @@ -0,0 +1,53 @@ +admin: + access_log: + - name: envoy.access_loggers.file + typed_config: + "@type": type.googleapis.com/envoy.extensions.access_loggers.file.v3.FileAccessLog + path: "{}" + address: + socket_address: + address: "{}" + port_value: 0 +static_resources: + clusters: + name: cluster_0 + connect_timeout: 2s + load_assignment: + cluster_name: cluster_0 + endpoints: + - lb_endpoints: + - endpoint: + address: + socket_address: + address: "{}" + port_value: 0 + listeners: + name: listener_0 + address: + socket_address: + address: "{}" + port_value: 0 + filter_chains: + - filters: + - name: mysql + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.network.mysql_proxy.v3.MySQLProxy + stat_prefix: mysql_stats + downstream_ssl: REQUIRE + - name: tcp + typed_config: + "@type": type.googleapis.com/envoy.extensions.filters.network.tcp_proxy.v3.TcpProxy + stat_prefix: tcp_stats + cluster: cluster_0 + transport_socket: + name: starttls + typed_config: + "@type": type.googleapis.com/envoy.extensions.transport_sockets.starttls.v3.StartTlsConfig + cleartext_socket_config: + tls_socket_config: + common_tls_context: + tls_certificates: + certificate_chain: + filename: "{}" + private_key: + filename: "{}"