diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java index 1fcc317f9d..997caea594 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/metrics/ZeroCopyMetrics.java @@ -17,6 +17,9 @@ */ package org.apache.ratis.grpc.metrics; +import java.util.function.Supplier; + +import org.apache.ratis.grpc.util.ZeroCopyMessageMarshaller.Metrics; import org.apache.ratis.metrics.LongCounter; import org.apache.ratis.metrics.MetricRegistryInfo; import org.apache.ratis.metrics.RatisMetricRegistry; @@ -24,8 +27,6 @@ import org.apache.ratis.thirdparty.com.google.common.annotations.VisibleForTesting; import org.apache.ratis.thirdparty.com.google.protobuf.AbstractMessage; -import java.util.function.Supplier; - public class ZeroCopyMetrics extends RatisMetrics { private static final String RATIS_GRPC_METRICS_APP_NAME = "ratis_grpc"; private static final String RATIS_GRPC_METRICS_COMP_NAME = "zero_copy"; @@ -35,6 +36,20 @@ public class ZeroCopyMetrics extends RatisMetrics { private final LongCounter nonZeroCopyMessages = getRegistry().counter("num_non_zero_copy_messages"); private final LongCounter releasedMessages = getRegistry().counter("num_released_messages"); + // Per-message-type zero-copy counters. + private final LongCounter zeroCopyAppendEntries = getRegistry().counter("num_zero_copy_append_entries"); + private final LongCounter zeroCopyInstallSnapshot = getRegistry().counter("num_zero_copy_install_snapshot"); + private final LongCounter zeroCopyClientRequest = getRegistry().counter("num_zero_copy_client_request"); + + // Aggregated savings and parse time (nanos) for zero-copy path. + private final LongCounter bytesSavedByZeroCopy = getRegistry().counter("bytes_saved_by_zero_copy"); + private final LongCounter zeroCopyParseTimeNanos = getRegistry().counter("zero_copy_parse_time_nanos"); + + // Reason counters for zero-copy fallback. + private final LongCounter fallbackNotKnownLength = getRegistry().counter("zero_copy_fallback_not_known_length"); + private final LongCounter fallbackNotDetachable = getRegistry().counter("zero_copy_fallback_not_detachable"); + private final LongCounter fallbackNotByteBuffer = getRegistry().counter("zero_copy_fallback_not_byte_buffer"); + public ZeroCopyMetrics() { super(createRegistry()); } @@ -54,6 +69,21 @@ public void onZeroCopyMessage(AbstractMessage ignored) { zeroCopyMessages.inc(); } + public void onZeroCopyAppendEntries(AbstractMessage ignored) { + onZeroCopyMessage(ignored); + zeroCopyAppendEntries.inc(); + } + + public void onZeroCopyInstallSnapshot(AbstractMessage ignored) { + onZeroCopyMessage(ignored); + zeroCopyInstallSnapshot.inc(); + } + + public void onZeroCopyClientRequest(AbstractMessage ignored) { + onZeroCopyMessage(ignored); + zeroCopyClientRequest.inc(); + } + public void onNonZeroCopyMessage(AbstractMessage ignored) { nonZeroCopyMessages.inc(); } @@ -62,6 +92,34 @@ public void onReleasedMessage(AbstractMessage ignored) { releasedMessages.inc(); } + public ZeroCopyMessageMarshallerMetrics newMarshallerMetrics() { + return new ZeroCopyMessageMarshallerMetrics(); + } + + // Adapter used by ZeroCopyMessageMarshaller to report parse stats and fallback reasons. + public class ZeroCopyMessageMarshallerMetrics implements Metrics { + @Override + public void onZeroCopyParse(long bytesSaved, long parseTimeNanos) { + bytesSavedByZeroCopy.inc(bytesSaved); + zeroCopyParseTimeNanos.inc(parseTimeNanos); + } + + @Override + public void onFallbackNotKnownLength() { + fallbackNotKnownLength.inc(); + } + + @Override + public void onFallbackNotDetachable() { + fallbackNotDetachable.inc(); + } + + @Override + public void onFallbackNotByteBuffer() { + fallbackNotByteBuffer.inc(); + } + } + @VisibleForTesting public long zeroCopyMessages() { return zeroCopyMessages.getCount(); @@ -76,4 +134,4 @@ public long nonZeroCopyMessages() { public long releasedMessages() { return releasedMessages.getCount(); } -} \ No newline at end of file +} diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java index b7548780cd..04a6fcef2c 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcClientProtocolService.java @@ -161,7 +161,8 @@ void closeAllExisting(RaftGroupId groupId) { this.executor = executor; this.zeroCopyEnabled = zeroCopyEnabled; this.zeroCopyRequestMarshaller = new ZeroCopyMessageMarshaller<>(RaftClientRequestProto.getDefaultInstance(), - zeroCopyMetrics::onZeroCopyMessage, zeroCopyMetrics::onNonZeroCopyMessage, zeroCopyMetrics::onReleasedMessage); + zeroCopyMetrics::onZeroCopyClientRequest, zeroCopyMetrics::onNonZeroCopyMessage, + zeroCopyMetrics::onReleasedMessage, zeroCopyMetrics.newMarshallerMetrics()); zeroCopyMetrics.addUnreleased("client_protocol", zeroCopyRequestMarshaller::getUnclosedCount); } diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java index 7e17cb3cf4..729b907394 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/server/GrpcServerProtocolService.java @@ -47,6 +47,7 @@ import static org.apache.ratis.grpc.GrpcUtil.addMethodWithCustomMarshaller; import static org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.getAppendEntriesMethod; +import static org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.getInstallSnapshotMethod; class GrpcServerProtocolService extends RaftServerProtocolServiceImplBase { public static final Logger LOG = LoggerFactory.getLogger(GrpcServerProtocolService.class); @@ -59,6 +60,7 @@ private enum BatchLogKey implements BatchLogger.Key { static class PendingServerRequest { private final AtomicReference> requestRef; private final CompletableFuture future = new CompletableFuture<>(); + private volatile String requestString; PendingServerRequest(ReferenceCountedObject requestRef) { requestRef.retain(); @@ -71,6 +73,14 @@ REQUEST getRequest() { .orElse(null); } + void setRequestString(String requestString) { + this.requestString = requestString; + } + + String getRequestString() { + return requestString; + } + CompletableFuture getFuture() { return future; } @@ -104,8 +114,7 @@ String getName() { private String getPreviousRequestString() { return Optional.ofNullable(previousOnNext.get()) - .map(PendingServerRequest::getRequest) - .map(this::requestToString) + .map(PendingServerRequest::getRequestString) .orElse(null); } @@ -178,12 +187,17 @@ public void onNext(REQUEST request) { } final PendingServerRequest current = new PendingServerRequest<>(requestRef); - final long callId = getCallId(current.getRequest()); - final boolean isHeartbeat = isHeartbeat(current.getRequest()); - final Optional> previous = Optional.ofNullable(previousOnNext.getAndSet(current)); - final CompletableFuture previousFuture = previous.map(PendingServerRequest::getFuture) - .orElse(CompletableFuture.completedFuture(null)); + current.getFuture().whenComplete((r, e) -> current.release()); + final REQUEST currentRequest = current.getRequest(); + final long callId = getCallId(currentRequest); + final boolean isHeartbeat = isHeartbeat(currentRequest); + Optional> previous = Optional.empty(); + CompletableFuture previousFuture = CompletableFuture.completedFuture(null); try { + current.setRequestString(requestToString(currentRequest)); + previous = Optional.ofNullable(previousOnNext.getAndSet(current)); + previousFuture = previous.map(PendingServerRequest::getFuture) + .orElse(CompletableFuture.completedFuture(null)); final CompletableFuture f = process(requestRef).exceptionally(e -> { // Handle cases, such as RaftServer is paused handleError(e, callId, isHeartbeat); @@ -243,6 +257,7 @@ private void releaseLast() { private final RaftServer server; private final boolean zeroCopyEnabled; private final ZeroCopyMessageMarshaller zeroCopyRequestMarshaller; + private final ZeroCopyMessageMarshaller zeroCopyInstallSnapshotMarshaller; GrpcServerProtocolService(Supplier idSupplier, RaftServer server, boolean zeroCopyEnabled, ZeroCopyMetrics zeroCopyMetrics) { @@ -250,8 +265,15 @@ private void releaseLast() { this.server = server; this.zeroCopyEnabled = zeroCopyEnabled; this.zeroCopyRequestMarshaller = new ZeroCopyMessageMarshaller<>(AppendEntriesRequestProto.getDefaultInstance(), - zeroCopyMetrics::onZeroCopyMessage, zeroCopyMetrics::onNonZeroCopyMessage, zeroCopyMetrics::onReleasedMessage); + zeroCopyMetrics::onZeroCopyAppendEntries, zeroCopyMetrics::onNonZeroCopyMessage, + zeroCopyMetrics::onReleasedMessage, zeroCopyMetrics.newMarshallerMetrics()); + this.zeroCopyInstallSnapshotMarshaller = new ZeroCopyMessageMarshaller<>( + InstallSnapshotRequestProto.getDefaultInstance(), + zeroCopyMetrics::onZeroCopyInstallSnapshot, zeroCopyMetrics::onNonZeroCopyMessage, + zeroCopyMetrics::onReleasedMessage, zeroCopyMetrics.newMarshallerMetrics()); zeroCopyMetrics.addUnreleased("server_protocol", zeroCopyRequestMarshaller::getUnclosedCount); + zeroCopyMetrics.addUnreleased("server_protocol_install_snapshot", + zeroCopyInstallSnapshotMarshaller::getUnclosedCount); } RaftPeerId getId() { @@ -268,9 +290,16 @@ ServerServiceDefinition bindServiceWithZeroCopy() { // Add appendEntries with zero copy marshaller. addMethodWithCustomMarshaller(orig, builder, getAppendEntriesMethod(), zeroCopyRequestMarshaller); + // Add installSnapshot with zero copy marshaller for zero-copy counters/metrics. + addMethodWithCustomMarshaller(orig, builder, getInstallSnapshotMethod(), zeroCopyInstallSnapshotMarshaller); // Add remaining methods as is. + final String appendEntriesMethod = getAppendEntriesMethod().getFullMethodName(); + final String installSnapshotMethod = getInstallSnapshotMethod().getFullMethodName(); orig.getMethods().stream().filter( - x -> !x.getMethodDescriptor().getFullMethodName().equals(getAppendEntriesMethod().getFullMethodName()) + x -> { + final String methodName = x.getMethodDescriptor().getFullMethodName(); + return !methodName.equals(appendEntriesMethod) && !methodName.equals(installSnapshotMethod); + } ).forEach( builder::addMethod ); @@ -365,6 +394,11 @@ CompletableFuture process(InstallSnapshotRequestProto return CompletableFuture.completedFuture(server.installSnapshot(request)); } + @Override + void release(InstallSnapshotRequestProto request) { + zeroCopyInstallSnapshotMarshaller.release(request); + } + @Override long getCallId(InstallSnapshotRequestProto request) { return request.getServerRequest().getCallId(); diff --git a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java index eddf2495e4..38e93b99b3 100644 --- a/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java +++ b/ratis-grpc/src/main/java/org/apache/ratis/grpc/util/ZeroCopyMessageMarshaller.java @@ -55,6 +55,23 @@ public class ZeroCopyMessageMarshaller implements PrototypeMarshaller { static final Logger LOG = LoggerFactory.getLogger(ZeroCopyMessageMarshaller.class); + public interface Metrics { + default void onZeroCopyParse(long bytesSaved, long parseTimeNanos) { + } + + default void onFallbackNotKnownLength() { + } + + default void onFallbackNotDetachable() { + } + + default void onFallbackNotByteBuffer() { + } + } + + private static final Metrics NOOP_METRICS = new Metrics() { + }; + private final String name; private final Map unclosedStreams = Collections.synchronizedMap(new IdentityHashMap<>()); private final Parser parser; @@ -63,13 +80,19 @@ public class ZeroCopyMessageMarshaller implements Prototy private final Consumer zeroCopyCount; private final Consumer nonZeroCopyCount; private final Consumer releasedCount; + private final Metrics metrics; public ZeroCopyMessageMarshaller(T defaultInstance) { - this(defaultInstance, m -> {}, m -> {}, m -> {}); + this(defaultInstance, m -> {}, m -> {}, m -> {}, NOOP_METRICS); } public ZeroCopyMessageMarshaller(T defaultInstance, Consumer zeroCopyCount, Consumer nonZeroCopyCount, Consumer releasedCount) { + this(defaultInstance, zeroCopyCount, nonZeroCopyCount, releasedCount, NOOP_METRICS); + } + + public ZeroCopyMessageMarshaller(T defaultInstance, Consumer zeroCopyCount, Consumer nonZeroCopyCount, + Consumer releasedCount, Metrics metrics) { this.name = JavaUtils.getClassSimpleName(defaultInstance.getClass()) + "-Marshaller"; @SuppressWarnings("unchecked") final Parser p = (Parser) defaultInstance.getParserForType(); @@ -79,6 +102,7 @@ public ZeroCopyMessageMarshaller(T defaultInstance, Consumer zeroCopyCount, C this.zeroCopyCount = zeroCopyCount; this.nonZeroCopyCount = nonZeroCopyCount; this.releasedCount = releasedCount; + this.metrics = metrics == null ? NOOP_METRICS : metrics; } @Override @@ -158,28 +182,36 @@ private List getByteStrings(InputStream detached, int exactSize) thr */ private T parseZeroCopy(InputStream stream) throws IOException { if (!(stream instanceof KnownLength)) { + metrics.onFallbackNotKnownLength(); LOG.debug("stream is not KnownLength: {}", stream.getClass()); return null; } if (!(stream instanceof Detachable)) { + metrics.onFallbackNotDetachable(); LOG.debug("stream is not Detachable: {}", stream.getClass()); return null; } if (!(stream instanceof HasByteBuffer)) { + metrics.onFallbackNotByteBuffer(); LOG.debug("stream is not HasByteBuffer: {}", stream.getClass()); return null; } if (!((HasByteBuffer) stream).byteBufferSupported()) { + metrics.onFallbackNotByteBuffer(); LOG.debug("stream is HasByteBuffer but not byteBufferSupported: {}", stream.getClass()); return null; } final int exactSize = stream.available(); InputStream detached = ((Detachable) stream).detach(); + // Measure only the zero-copy parse path (detach + parse). + final long startNanos = System.nanoTime(); try { final List byteStrings = getByteStrings(detached, exactSize); final T message = parseFrom(byteStrings, exactSize); + metrics.onZeroCopyParse(exactSize, System.nanoTime() - startNanos); + final InputStream previous = unclosedStreams.put(message, detached); Preconditions.assertNull(previous, "previous"); diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/util/TestZeroCopyMetrics.java b/ratis-test/src/test/java/org/apache/ratis/grpc/util/TestZeroCopyMetrics.java new file mode 100644 index 0000000000..01c8dc9c1c --- /dev/null +++ b/ratis-test/src/test/java/org/apache/ratis/grpc/util/TestZeroCopyMetrics.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ratis.grpc.util; + +import org.apache.ratis.grpc.metrics.ZeroCopyMetrics; +import org.apache.ratis.proto.RaftProtos.AppendEntriesRequestProto; +import org.apache.ratis.proto.RaftProtos.InstallSnapshotRequestProto; +import org.apache.ratis.proto.RaftProtos.RaftClientRequestProto; +import org.apache.ratis.test.proto.BinaryRequest; +import org.apache.ratis.thirdparty.com.google.protobuf.ByteString; +import org.apache.ratis.thirdparty.io.grpc.Detachable; +import org.apache.ratis.thirdparty.io.grpc.HasByteBuffer; +import org.apache.ratis.thirdparty.io.grpc.KnownLength; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class TestZeroCopyMetrics { + @Test + public void testZeroCopyMetricsTrackMessageTypesAndMarshallerStats() { + final ZeroCopyMetrics metrics = newZeroCopyMetrics(); + try { + metrics.onZeroCopyAppendEntries(AppendEntriesRequestProto.getDefaultInstance()); + metrics.onZeroCopyInstallSnapshot(InstallSnapshotRequestProto.getDefaultInstance()); + metrics.onZeroCopyClientRequest(RaftClientRequestProto.getDefaultInstance()); + metrics.onNonZeroCopyMessage(AppendEntriesRequestProto.getDefaultInstance()); + metrics.onReleasedMessage(AppendEntriesRequestProto.getDefaultInstance()); + + final ZeroCopyMetrics.ZeroCopyMessageMarshallerMetrics marshallerMetrics = metrics.newMarshallerMetrics(); + marshallerMetrics.onZeroCopyParse(123L, 456L); + marshallerMetrics.onFallbackNotKnownLength(); + marshallerMetrics.onFallbackNotDetachable(); + marshallerMetrics.onFallbackNotByteBuffer(); + + assertEquals(3L, metrics.zeroCopyMessages()); + assertEquals(1L, metrics.nonZeroCopyMessages()); + assertEquals(1L, metrics.releasedMessages()); + assertCounter(metrics, "num_zero_copy_append_entries", 1L); + assertCounter(metrics, "num_zero_copy_install_snapshot", 1L); + assertCounter(metrics, "num_zero_copy_client_request", 1L); + assertCounter(metrics, "bytes_saved_by_zero_copy", 123L); + assertCounter(metrics, "zero_copy_parse_time_nanos", 456L); + assertCounter(metrics, "zero_copy_fallback_not_known_length", 1L); + assertCounter(metrics, "zero_copy_fallback_not_detachable", 1L); + assertCounter(metrics, "zero_copy_fallback_not_byte_buffer", 1L); + } finally { + metrics.unregister(); + } + } + + @Test + public void testMarshallerReportsZeroCopyParseMetrics() { + final BinaryRequest request = BinaryRequest.newBuilder() + .setData(ByteString.copyFromUtf8("zero-copy")) + .build(); + final byte[] bytes = request.toByteArray(); + final RecordingMetrics metrics = new RecordingMetrics(); + final AtomicInteger zeroCopyCount = new AtomicInteger(); + final AtomicInteger nonZeroCopyCount = new AtomicInteger(); + final AtomicInteger releasedCount = new AtomicInteger(); + final ZeroCopyMessageMarshaller marshaller = new ZeroCopyMessageMarshaller<>( + BinaryRequest.getDefaultInstance(), + ignored -> zeroCopyCount.incrementAndGet(), + ignored -> nonZeroCopyCount.incrementAndGet(), + ignored -> releasedCount.incrementAndGet(), + metrics); + + final BinaryRequest parsed = marshaller.parse(new DetachableByteBufferInputStream(bytes)); + assertEquals(request, parsed); + assertEquals(1, zeroCopyCount.get()); + assertEquals(0, nonZeroCopyCount.get()); + assertEquals(bytes.length, metrics.bytesSavedByZeroCopy); + assertTrue(metrics.zeroCopyParseTimeNanos > 0); + assertEquals(1, marshaller.getUnclosedCount()); + + marshaller.release(parsed); + assertEquals(1, releasedCount.get()); + assertEquals(0, marshaller.getUnclosedCount()); + } + + @Test + public void testMarshallerReportsFallbackNotKnownLength() { + final BinaryRequest request = BinaryRequest.newBuilder() + .setData(ByteString.copyFromUtf8("known-length-fallback")) + .build(); + final RecordingMetrics metrics = new RecordingMetrics(); + final AtomicInteger nonZeroCopyCount = new AtomicInteger(); + final ZeroCopyMessageMarshaller marshaller = new ZeroCopyMessageMarshaller<>( + BinaryRequest.getDefaultInstance(), + ignored -> fail("Should not use zero-copy path"), + ignored -> nonZeroCopyCount.incrementAndGet(), + ignored -> { }, + metrics); + + final BinaryRequest parsed = marshaller.parse(new ByteArrayInputStream(request.toByteArray())); + assertEquals(request, parsed); + assertEquals(1, nonZeroCopyCount.get()); + assertEquals(1, metrics.fallbackNotKnownLength); + assertEquals(0, metrics.fallbackNotDetachable); + assertEquals(0, metrics.fallbackNotByteBuffer); + } + + @Test + public void testMarshallerReportsFallbackNotDetachable() { + final BinaryRequest request = BinaryRequest.newBuilder() + .setData(ByteString.copyFromUtf8("not-detachable")) + .build(); + final RecordingMetrics metrics = new RecordingMetrics(); + final AtomicInteger nonZeroCopyCount = new AtomicInteger(); + final ZeroCopyMessageMarshaller marshaller = new ZeroCopyMessageMarshaller<>( + BinaryRequest.getDefaultInstance(), + ignored -> fail("Should not use zero-copy path"), + ignored -> nonZeroCopyCount.incrementAndGet(), + ignored -> { }, + metrics); + + final BinaryRequest parsed = marshaller.parse(new KnownLengthByteArrayInputStream(request.toByteArray())); + assertEquals(request, parsed); + assertEquals(1, nonZeroCopyCount.get()); + assertEquals(0, metrics.fallbackNotKnownLength); + assertEquals(1, metrics.fallbackNotDetachable); + assertEquals(0, metrics.fallbackNotByteBuffer); + } + + @Test + public void testMarshallerReportsFallbackNotByteBuffer() { + final BinaryRequest request = BinaryRequest.newBuilder() + .setData(ByteString.copyFromUtf8("not-byte-buffer")) + .build(); + final RecordingMetrics metrics = new RecordingMetrics(); + final AtomicInteger nonZeroCopyCount = new AtomicInteger(); + final ZeroCopyMessageMarshaller marshaller = new ZeroCopyMessageMarshaller<>( + BinaryRequest.getDefaultInstance(), + ignored -> fail("Should not use zero-copy path"), + ignored -> nonZeroCopyCount.incrementAndGet(), + ignored -> { }, + metrics); + + final BinaryRequest parsed = marshaller.parse(new KnownLengthDetachableByteArrayInputStream(request.toByteArray())); + assertEquals(request, parsed); + assertEquals(1, nonZeroCopyCount.get()); + assertEquals(0, metrics.fallbackNotKnownLength); + assertEquals(0, metrics.fallbackNotDetachable); + assertEquals(1, metrics.fallbackNotByteBuffer); + } + + private static void assertCounter(ZeroCopyMetrics metrics, String name, long expected) { + assertEquals(expected, metrics.getRegistry().counter(name).getCount(), name); + } + + private static ZeroCopyMetrics newZeroCopyMetrics() { + final ZeroCopyMetrics metrics = new ZeroCopyMetrics(); + metrics.unregister(); + return new ZeroCopyMetrics(); + } + + private static class RecordingMetrics implements ZeroCopyMessageMarshaller.Metrics { + private long bytesSavedByZeroCopy; + private long zeroCopyParseTimeNanos; + private int fallbackNotKnownLength; + private int fallbackNotDetachable; + private int fallbackNotByteBuffer; + + @Override + public void onZeroCopyParse(long bytesSaved, long parseTimeNanos) { + this.bytesSavedByZeroCopy += bytesSaved; + this.zeroCopyParseTimeNanos += parseTimeNanos; + } + + @Override + public void onFallbackNotKnownLength() { + fallbackNotKnownLength++; + } + + @Override + public void onFallbackNotDetachable() { + fallbackNotDetachable++; + } + + @Override + public void onFallbackNotByteBuffer() { + fallbackNotByteBuffer++; + } + } + + private static class KnownLengthByteArrayInputStream extends ByteArrayInputStream implements KnownLength { + KnownLengthByteArrayInputStream(byte[] buf) { + super(buf); + } + } + + private static class KnownLengthDetachableByteArrayInputStream extends KnownLengthByteArrayInputStream + implements Detachable { + KnownLengthDetachableByteArrayInputStream(byte[] buf) { + super(buf); + } + + @Override + public InputStream detach() { + return this; + } + } + + private static class DetachableByteBufferInputStream extends InputStream + implements KnownLength, Detachable, HasByteBuffer { + private final byte[] bytes; + private int position; + private int mark; + + DetachableByteBufferInputStream(byte[] bytes) { + this.bytes = bytes; + } + + @Override + public InputStream detach() { + return this; + } + + @Override + public boolean byteBufferSupported() { + return true; + } + + @Override + public ByteBuffer getByteBuffer() { + return ByteBuffer.wrap(bytes, position, available()).slice(); + } + + @Override + public int read() { + return position < bytes.length ? bytes[position++] & 0xff : -1; + } + + @Override + public int read(byte[] b, int off, int len) { + if (position >= bytes.length) { + return -1; + } + final int n = Math.min(len, available()); + System.arraycopy(bytes, position, b, off, n); + position += n; + return n; + } + + @Override + public long skip(long n) { + final int skipped = Math.min((int) n, available()); + position += skipped; + return skipped; + } + + @Override + public int available() { + return bytes.length - position; + } + + @Override + public synchronized void mark(int readlimit) { + this.mark = position; + } + + @Override + public synchronized void reset() { + this.position = mark; + } + + @Override + public boolean markSupported() { + return true; + } + } +}