diff --git a/amqp-impl/src/main/java/io/streamnative/pulsar/handlers/amqp/AmqpConnection.java b/amqp-impl/src/main/java/io/streamnative/pulsar/handlers/amqp/AmqpConnection.java index 37074156..39ad661d 100644 --- a/amqp-impl/src/main/java/io/streamnative/pulsar/handlers/amqp/AmqpConnection.java +++ b/amqp-impl/src/main/java/io/streamnative/pulsar/handlers/amqp/AmqpConnection.java @@ -18,6 +18,7 @@ import com.google.common.annotations.VisibleForTesting; import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; import io.netty.channel.Channel; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandlerContext; @@ -28,6 +29,8 @@ import java.lang.reflect.Method; import java.lang.reflect.Proxy; import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -37,6 +40,7 @@ import lombok.extern.log4j.Log4j2; import org.apache.bookkeeper.util.collections.ConcurrentLongLongHashMap; import org.apache.commons.lang.StringUtils; +import org.apache.commons.lang3.tuple.Pair; import org.apache.pulsar.broker.authentication.AuthenticationProvider; import org.apache.pulsar.broker.authentication.AuthenticationState; import org.apache.pulsar.broker.namespace.LookupOptions; @@ -109,6 +113,7 @@ enum ConnectionState { private ServerCnx pulsarServerCnx; private AmqpBrokerService amqpBrokerService; private AuthenticationState authenticationState; + public final static String SUPPORT_MECHANISM = "PLAIN AMQPLAIN token"; public AmqpConnection(AmqpServiceConfiguration amqpConfig, AmqpBrokerService amqpBrokerService) { @@ -201,7 +206,7 @@ public void receiveConnectionStartOk(FieldTable clientProperties, AMQShortString } String authMethod = String.valueOf(mechanism); - if (authMethod.equals("PLAIN")) { + if (authMethod.equals("PLAIN") || authMethod.equals("AMQPLAIN")) { authMethod = "basic"; } @@ -219,18 +224,14 @@ public void receiveConnectionStartOk(FieldTable clientProperties, AMQShortString AuthData authData; if (authMethod.equals("basic")) { // Original format: \000USERNAME\000PASSWORD - String splitter = "\000"; - String[] data = StringUtils.stripStart(new String(response, StandardCharsets.UTF_8), splitter) - .split(splitter); - if (data.length != 2) { - if (log.isDebugEnabled()) { - log.debug("Authentication data format error: mechanism={}", mechanism); - } - sendConnectionClose(ErrorCodes.CONNECTION_FORCED, "Authentication data format error", 0); + // Encode data to Pulsar format: USERNAME:PASSWORD + Pair userAndPw = getUserAndPwForBasicAuth(response, String.valueOf(mechanism)); + if (userAndPw == null) { + // The userAndPw is null indicates auth data is invalid. return; } - // Encode data to Pulsar format: USERNAME:PASSWORD - authData = AuthData.of(String.format("%s:%s", data[0], data[1]).getBytes(StandardCharsets.UTF_8)); + authData = AuthData.of(String.format("%s:%s", userAndPw.getLeft(), userAndPw.getRight()) + .getBytes(StandardCharsets.UTF_8)); } else { authData = AuthData.of(response); } @@ -257,6 +258,42 @@ public void receiveConnectionStartOk(FieldTable clientProperties, AMQShortString state = ConnectionState.AWAIT_TUNE_OK; } + private Pair getUserAndPwForBasicAuth(byte[] response, String mechanism) { + if ("PLAIN".equals(mechanism)) { + String splitter = "\000"; + String[] data = StringUtils.stripStart(new String(response, StandardCharsets.UTF_8), splitter) + .split(splitter); + if (data.length != 2) { + log.error("Authentication data format error: mechanism={}", mechanism); + sendConnectionClose(ErrorCodes.CONNECTION_FORCED, "Authentication data format error", 0); + return null; + } + return Pair.of(data[0], data[1]); + } else if ("AMQPLAIN".equals(mechanism)) { + Map dataMap = new HashMap<>(); + ByteBuf byteBuf = Unpooled.wrappedBuffer(response); + while (byteBuf.isReadable()) { + byte[] keyData = new byte[byteBuf.readByte()]; + byteBuf.readBytes(keyData); + byteBuf.readByte(); + byte[] valueData = new byte[byteBuf.readInt()]; + byteBuf.readBytes(valueData); + dataMap.put(new String(keyData), new String(valueData)); + } + if (!dataMap.containsKey("LOGIN") || !dataMap.containsKey("PASSWORD")) { + log.error("Authentication data format error: mechanism={}", mechanism); + sendConnectionClose(ErrorCodes.CONNECTION_FORCED, "Authentication data format error", 0); + return null; + } + return Pair.of(dataMap.get("LOGIN"), dataMap.get("PASSWORD")); + } else { + log.error("Authentication data format error: unsupported mechanism={}", mechanism); + sendConnectionClose(ErrorCodes.CONNECTION_FORCED, + "Authentication data format error, unsupported mechanism", 0); + return null; + } + } + @Override public void receiveConnectionSecureOk(byte[] response) { if (log.isDebugEnabled()) { @@ -464,7 +501,7 @@ public void receiveProtocolHeader(ProtocolInitiation pi) { (short) pv.getActualMinorVersion(), null, // TODO temporary modification - "PLAIN token".getBytes(US_ASCII), + SUPPORT_MECHANISM.getBytes(US_ASCII), "en_US".getBytes(US_ASCII)); writeFrame(responseBody.generateFrame(0)); state = ConnectionState.AWAIT_START_OK; diff --git a/amqp-impl/src/main/java/io/streamnative/pulsar/handlers/amqp/proxy/ProxyConnection.java b/amqp-impl/src/main/java/io/streamnative/pulsar/handlers/amqp/proxy/ProxyConnection.java index 22b6e695..6536de19 100644 --- a/amqp-impl/src/main/java/io/streamnative/pulsar/handlers/amqp/proxy/ProxyConnection.java +++ b/amqp-impl/src/main/java/io/streamnative/pulsar/handlers/amqp/proxy/ProxyConnection.java @@ -14,6 +14,7 @@ package io.streamnative.pulsar.handlers.amqp.proxy; import static com.google.common.base.Preconditions.checkState; +import static io.streamnative.pulsar.handlers.amqp.AmqpConnection.SUPPORT_MECHANISM; import static java.nio.charset.StandardCharsets.US_ASCII; import static java.nio.charset.StandardCharsets.UTF_8; @@ -172,7 +173,7 @@ public void receiveProtocolHeader(ProtocolInitiation protocolInitiation) { (short) pv.getActualMinorVersion(), null, // TODO temporary modification - "PLAIN token".getBytes(US_ASCII), + SUPPORT_MECHANISM.getBytes(US_ASCII), "en_US".getBytes(US_ASCII)); writeFrame(responseBody.generateFrame(0)); } catch (QpidException e) { diff --git a/tests/src/test/java/io/streamnative/pulsar/handlers/amqp/rabbitmq/authentication/PlainAuthenticationTest.java b/tests/src/test/java/io/streamnative/pulsar/handlers/amqp/rabbitmq/authentication/PlainAuthenticationTest.java index 5daa5d91..de263bfc 100644 --- a/tests/src/test/java/io/streamnative/pulsar/handlers/amqp/rabbitmq/authentication/PlainAuthenticationTest.java +++ b/tests/src/test/java/io/streamnative/pulsar/handlers/amqp/rabbitmq/authentication/PlainAuthenticationTest.java @@ -19,7 +19,10 @@ import com.rabbitmq.client.Channel; import com.rabbitmq.client.Connection; import com.rabbitmq.client.ConnectionFactory; +import com.rabbitmq.client.LongString; import com.rabbitmq.client.PossibleAuthenticationFailureException; +import com.rabbitmq.client.SaslMechanism; +import com.rabbitmq.client.impl.LongStringHelper; import io.streamnative.pulsar.handlers.amqp.AmqpTokenAuthenticationTestBase; import java.io.IOException; import java.util.concurrent.TimeoutException; @@ -30,6 +33,7 @@ * PlainAuthenticationTest tests the plain authentication. */ public class PlainAuthenticationTest extends AmqpTokenAuthenticationTestBase { + private void testConnect(int port) throws Exception { ConnectionFactory connectionFactory = new ConnectionFactory(); connectionFactory.setHost("localhost"); @@ -53,6 +57,30 @@ public void testConnectToProxy() throws Exception { testConnect(getAopProxyPortList().get(0)); } + @Test + public void testAMQPLAIN() throws IOException, TimeoutException { + ConnectionFactory factory = new ConnectionFactory(); + factory.setPort(getAopProxyPortList().get(0)); + factory.setVirtualHost("vhost1"); + factory.setSaslConfig(mechanisms -> new SaslMechanism() { + @Override + public String getName() { + return "AMQPLAIN"; + } + + @Override + public LongString handleChallenge(LongString challenge, String username, String password) { + byte[][] data = new byte[][]{{ + 5, 76, 79, 71, 73, 78, 83, 0, 0, 0, 10, 115, 117, 112, 101, 114, 85, 115, + 101, 114, 50, 8, 80, 65, 83, 83, 87, 79, 82, 68, 83, 0, 0, 0, 13, 115, 117, + 112, 101, 114, 112, 97, 115, 115, 119, 111, 114, 100}}; + return LongStringHelper.asLongString(data[0]); + } + }); + Connection connection = factory.newConnection(); + connection.close(); + } + private void testConnectWithInvalidToken(int port, boolean isProxy) throws IOException, TimeoutException { ConnectionFactory connectionFactory = new ConnectionFactory(); connectionFactory.setHost("localhost");