Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import com.google.common.annotations.VisibleForTesting;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
Expand All @@ -28,6 +29,8 @@
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -37,6 +40,7 @@
import lombok.extern.log4j.Log4j2;
import org.apache.bookkeeper.util.collections.ConcurrentLongLongHashMap;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.pulsar.broker.authentication.AuthenticationProvider;
import org.apache.pulsar.broker.authentication.AuthenticationState;
import org.apache.pulsar.broker.namespace.LookupOptions;
Expand Down Expand Up @@ -109,6 +113,7 @@ enum ConnectionState {
private ServerCnx pulsarServerCnx;
private AmqpBrokerService amqpBrokerService;
private AuthenticationState authenticationState;
public final static String SUPPORT_MECHANISM = "PLAIN AMQPLAIN token";

public AmqpConnection(AmqpServiceConfiguration amqpConfig,
AmqpBrokerService amqpBrokerService) {
Expand Down Expand Up @@ -201,7 +206,7 @@ public void receiveConnectionStartOk(FieldTable clientProperties, AMQShortString
}

String authMethod = String.valueOf(mechanism);
if (authMethod.equals("PLAIN")) {
if (authMethod.equals("PLAIN") || authMethod.equals("AMQPLAIN")) {
authMethod = "basic";
}

Expand All @@ -219,18 +224,14 @@ public void receiveConnectionStartOk(FieldTable clientProperties, AMQShortString
AuthData authData;
if (authMethod.equals("basic")) {
// Original format: \000USERNAME\000PASSWORD
String splitter = "\000";
String[] data = StringUtils.stripStart(new String(response, StandardCharsets.UTF_8), splitter)
.split(splitter);
if (data.length != 2) {
if (log.isDebugEnabled()) {
log.debug("Authentication data format error: mechanism={}", mechanism);
}
sendConnectionClose(ErrorCodes.CONNECTION_FORCED, "Authentication data format error", 0);
// Encode data to Pulsar format: USERNAME:PASSWORD
Pair<String, String> userAndPw = getUserAndPwForBasicAuth(response, String.valueOf(mechanism));
if (userAndPw == null) {
// The userAndPw is null indicates auth data is invalid.
return;
}
// Encode data to Pulsar format: USERNAME:PASSWORD
authData = AuthData.of(String.format("%s:%s", data[0], data[1]).getBytes(StandardCharsets.UTF_8));
authData = AuthData.of(String.format("%s:%s", userAndPw.getLeft(), userAndPw.getRight())
.getBytes(StandardCharsets.UTF_8));
} else {
authData = AuthData.of(response);
}
Expand All @@ -257,6 +258,42 @@ public void receiveConnectionStartOk(FieldTable clientProperties, AMQShortString
state = ConnectionState.AWAIT_TUNE_OK;
}

private Pair<String, String> getUserAndPwForBasicAuth(byte[] response, String mechanism) {
if ("PLAIN".equals(mechanism)) {
String splitter = "\000";
String[] data = StringUtils.stripStart(new String(response, StandardCharsets.UTF_8), splitter)
.split(splitter);
if (data.length != 2) {
log.error("Authentication data format error: mechanism={}", mechanism);
sendConnectionClose(ErrorCodes.CONNECTION_FORCED, "Authentication data format error", 0);
return null;
}
return Pair.of(data[0], data[1]);
} else if ("AMQPLAIN".equals(mechanism)) {
Map<String, String> dataMap = new HashMap<>();
ByteBuf byteBuf = Unpooled.wrappedBuffer(response);
while (byteBuf.isReadable()) {
byte[] keyData = new byte[byteBuf.readByte()];
byteBuf.readBytes(keyData);
byteBuf.readByte();
byte[] valueData = new byte[byteBuf.readInt()];
byteBuf.readBytes(valueData);
dataMap.put(new String(keyData), new String(valueData));
}
if (!dataMap.containsKey("LOGIN") || !dataMap.containsKey("PASSWORD")) {
log.error("Authentication data format error: mechanism={}", mechanism);
sendConnectionClose(ErrorCodes.CONNECTION_FORCED, "Authentication data format error", 0);
return null;
}
return Pair.of(dataMap.get("LOGIN"), dataMap.get("PASSWORD"));
} else {
log.error("Authentication data format error: unsupported mechanism={}", mechanism);
sendConnectionClose(ErrorCodes.CONNECTION_FORCED,
"Authentication data format error, unsupported mechanism", 0);
return null;
}
}

@Override
public void receiveConnectionSecureOk(byte[] response) {
if (log.isDebugEnabled()) {
Expand Down Expand Up @@ -464,7 +501,7 @@ public void receiveProtocolHeader(ProtocolInitiation pi) {
(short) pv.getActualMinorVersion(),
null,
// TODO temporary modification
"PLAIN token".getBytes(US_ASCII),
SUPPORT_MECHANISM.getBytes(US_ASCII),
"en_US".getBytes(US_ASCII));
writeFrame(responseBody.generateFrame(0));
state = ConnectionState.AWAIT_START_OK;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.streamnative.pulsar.handlers.amqp.proxy;

import static com.google.common.base.Preconditions.checkState;
import static io.streamnative.pulsar.handlers.amqp.AmqpConnection.SUPPORT_MECHANISM;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.nio.charset.StandardCharsets.UTF_8;

Expand Down Expand Up @@ -172,7 +173,7 @@ public void receiveProtocolHeader(ProtocolInitiation protocolInitiation) {
(short) pv.getActualMinorVersion(),
null,
// TODO temporary modification
"PLAIN token".getBytes(US_ASCII),
SUPPORT_MECHANISM.getBytes(US_ASCII),
"en_US".getBytes(US_ASCII));
writeFrame(responseBody.generateFrame(0));
} catch (QpidException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
import com.rabbitmq.client.Channel;
import com.rabbitmq.client.Connection;
import com.rabbitmq.client.ConnectionFactory;
import com.rabbitmq.client.LongString;
import com.rabbitmq.client.PossibleAuthenticationFailureException;
import com.rabbitmq.client.SaslMechanism;
import com.rabbitmq.client.impl.LongStringHelper;
import io.streamnative.pulsar.handlers.amqp.AmqpTokenAuthenticationTestBase;
import java.io.IOException;
import java.util.concurrent.TimeoutException;
Expand All @@ -30,6 +33,7 @@
* PlainAuthenticationTest tests the plain authentication.
*/
public class PlainAuthenticationTest extends AmqpTokenAuthenticationTestBase {

private void testConnect(int port) throws Exception {
ConnectionFactory connectionFactory = new ConnectionFactory();
connectionFactory.setHost("localhost");
Expand All @@ -53,6 +57,30 @@ public void testConnectToProxy() throws Exception {
testConnect(getAopProxyPortList().get(0));
}

@Test
public void testAMQPLAIN() throws IOException, TimeoutException {
ConnectionFactory factory = new ConnectionFactory();
factory.setPort(getAopProxyPortList().get(0));
factory.setVirtualHost("vhost1");
factory.setSaslConfig(mechanisms -> new SaslMechanism() {
@Override
public String getName() {
return "AMQPLAIN";
}

@Override
public LongString handleChallenge(LongString challenge, String username, String password) {
byte[][] data = new byte[][]{{
5, 76, 79, 71, 73, 78, 83, 0, 0, 0, 10, 115, 117, 112, 101, 114, 85, 115,
101, 114, 50, 8, 80, 65, 83, 83, 87, 79, 82, 68, 83, 0, 0, 0, 13, 115, 117,
112, 101, 114, 112, 97, 115, 115, 119, 111, 114, 100}};
return LongStringHelper.asLongString(data[0]);
}
});
Connection connection = factory.newConnection();
connection.close();
}

private void testConnectWithInvalidToken(int port, boolean isProxy) throws IOException, TimeoutException {
ConnectionFactory connectionFactory = new ConnectionFactory();
connectionFactory.setHost("localhost");
Expand Down