diff --git a/src/Platform/Microsoft.Testing.Platform/Helpers/EnvironmentVariableConstants.cs b/src/Platform/Microsoft.Testing.Platform/Helpers/EnvironmentVariableConstants.cs index e496cbc3af..dd835b9699 100644 --- a/src/Platform/Microsoft.Testing.Platform/Helpers/EnvironmentVariableConstants.cs +++ b/src/Platform/Microsoft.Testing.Platform/Helpers/EnvironmentVariableConstants.cs @@ -9,7 +9,6 @@ internal static class EnvironmentVariableConstants public const string DOTNET_WATCH = nameof(DOTNET_WATCH); public const string TESTINGPLATFORM_HOTRELOAD_ENABLED = nameof(TESTINGPLATFORM_HOTRELOAD_ENABLED); public const string TESTINGPLATFORM_DEFAULT_HANG_TIMEOUT = nameof(TESTINGPLATFORM_DEFAULT_HANG_TIMEOUT); - public const string TESTINGPLATFORM_MESSAGEBUS_DRAINDATA_ATTEMPTS = nameof(TESTINGPLATFORM_MESSAGEBUS_DRAINDATA_ATTEMPTS); public const string TESTINGPLATFORM_TESTHOSTCONTROLLER_SKIPEXTENSION = nameof(TESTINGPLATFORM_TESTHOSTCONTROLLER_SKIPEXTENSION); public const string TESTINGPLATFORM_TESTHOSTCONTROLLER_PIPENAME = nameof(TESTINGPLATFORM_TESTHOSTCONTROLLER_PIPENAME); diff --git a/src/Platform/Microsoft.Testing.Platform/Hosts/TestHostBuilder.cs b/src/Platform/Microsoft.Testing.Platform/Hosts/TestHostBuilder.cs index beb17eeeea..896c580d34 100644 --- a/src/Platform/Microsoft.Testing.Platform/Hosts/TestHostBuilder.cs +++ b/src/Platform/Microsoft.Testing.Platform/Hosts/TestHostBuilder.cs @@ -809,12 +809,11 @@ private static async Task BuildTestFrameworkAsync(TestFrameworkB IDataConsumer[] dataConsumerServices = [.. 
dataConsumersBuilder]; - AsynchronousMessageBus concreteMessageBusService = new( + var concreteMessageBusService = new AsynchronousMessageBus( dataConsumerServices, serviceProvider.GetTestApplicationCancellationTokenSource(), serviceProvider.GetTask(), - serviceProvider.GetLoggerFactory(), - serviceProvider.GetEnvironment()); + serviceProvider.GetLoggerFactory()); await concreteMessageBusService.InitAsync().ConfigureAwait(false); testFrameworkBuilderData.MessageBusProxy.SetBuiltMessageBus(concreteMessageBusService); diff --git a/src/Platform/Microsoft.Testing.Platform/Hosts/TestHostControllersTestHost.cs b/src/Platform/Microsoft.Testing.Platform/Hosts/TestHostControllersTestHost.cs index 632f30a8a1..fd6927182a 100644 --- a/src/Platform/Microsoft.Testing.Platform/Hosts/TestHostControllersTestHost.cs +++ b/src/Platform/Microsoft.Testing.Platform/Hosts/TestHostControllersTestHost.cs @@ -170,12 +170,11 @@ protected override async Task InternalRunAsync(CancellationToken cancellati } } - AsynchronousMessageBus concreteMessageBusService = new( + var concreteMessageBusService = new AsynchronousMessageBus( [.. 
dataConsumersBuilder], ServiceProvider.GetTestApplicationCancellationTokenSource(), ServiceProvider.GetTask(), - ServiceProvider.GetLoggerFactory(), - ServiceProvider.GetEnvironment()); + ServiceProvider.GetLoggerFactory()); await concreteMessageBusService.InitAsync().ConfigureAwait(false); ((MessageBusProxy)ServiceProvider.GetMessageBus()).SetBuiltMessageBus(concreteMessageBusService); diff --git a/src/Platform/Microsoft.Testing.Platform/Messages/AsyncConsumerDataProcessor.net.cs b/src/Platform/Microsoft.Testing.Platform/Messages/AsyncConsumerDataProcessor.net.cs index d071679abb..2246299a56 100644 --- a/src/Platform/Microsoft.Testing.Platform/Messages/AsyncConsumerDataProcessor.net.cs +++ b/src/Platform/Microsoft.Testing.Platform/Messages/AsyncConsumerDataProcessor.net.cs @@ -15,24 +15,9 @@ internal sealed class AsyncConsumerDataProcessor : IAsyncConsumerDataProcessor { private readonly ITask _task; private readonly CancellationToken _cancellationToken; - private readonly Channel<(IDataProducer DataProducer, IData Data)> _channel = Channel.CreateUnbounded<(IDataProducer DataProducer, IData Data)>(new UnboundedChannelOptions - { - // We process only 1 data at a time - SingleReader = true, - - // We don't know how many threads will call the publish on the message bus - SingleWriter = false, - // We want to unlink the publish that's the message bus - AllowSynchronousContinuations = false, - }); - - // This is needed to avoid possible race condition between drain and _totalPayloadProcessed race condition. - // This is the "logical" consume workflow state. 
- private readonly TaskCompletionSource _consumerState = new(); - private readonly Task _consumeTask; - private long _totalPayloadReceived; - private long _totalPayloadProcessed; + private Channel<(IDataProducer DataProducer, IData Data)> _channel = CreateChannel(); + private Task _consumeTask; public AsyncConsumerDataProcessor(IDataConsumer consumer, ITask task, CancellationToken cancellationToken) { @@ -45,10 +30,7 @@ public AsyncConsumerDataProcessor(IDataConsumer consumer, ITask task, Cancellati public IDataConsumer DataConsumer { get; } public async Task PublishAsync(IDataProducer dataProducer, IData data) - { - Interlocked.Increment(ref _totalPayloadReceived); - await _channel.Writer.WriteAsync((dataProducer, data), _cancellationToken).ConfigureAwait(false); - } + => await _channel.Writer.WriteAsync((dataProducer, data), _cancellationToken).ConfigureAwait(false); private async Task ConsumeAsync() { @@ -58,112 +40,59 @@ private async Task ConsumeAsync() { (IDataProducer dataProducer, IData data) = await _channel.Reader.ReadAsync(_cancellationToken).ConfigureAwait(false); - try - { - // We don't enqueue the data if the consumer is the producer of the data. - // We could optimize this if and make a get with type/all but producers, but it - // could be over-engineering. - if (dataProducer.Uid == DataConsumer.Uid) - { - continue; - } - - try - { - await DataConsumer.ConsumeAsync(dataProducer, data, _cancellationToken).ConfigureAwait(false); - } - - // We let the catch below to handle the graceful cancellation of the process - catch (Exception ex) when (ex is not OperationCanceledException) - { - // If we're draining before to increment the _totalPayloadProcessed we need to signal that we should throw because - // it's possible we have a race condition where the payload counting in DrainDataAsync returns false and the current task is not yet in a - // "faulted state". 
- _consumerState.SetException(ex); - - // We let current task to move to fault state, checked inside CompleteAddingAsync. - throw; - } - } - finally + // We don't enqueue the data if the consumer is the producer of the data. + // We could optimize this if and make a get with type/all but producers, but it + // could be over-engineering. + if (dataProducer.Uid == DataConsumer.Uid) { - Interlocked.Increment(ref _totalPayloadProcessed); + continue; } + + await DataConsumer.ConsumeAsync(dataProducer, data, _cancellationToken).ConfigureAwait(false); } } catch (OperationCanceledException oc) when (oc.CancellationToken == _cancellationToken) { // Ignore we're shutting down } - catch (Exception ex) - { - // For all other exception we signal the state if not already faulted - if (!_consumerState.Task.IsFaulted) - { - _consumerState.SetException(ex); - } - - // let the exception bubble up - throw; - } - - // We're exiting gracefully, signal the correct state. - _consumerState.SetResult(); } public async Task CompleteAddingAsync() { // Signal that no more items will be added to the collection // It's possible that we call this method multiple times _channel.Writer.TryComplete(); // Wait for the consumer to complete await _consumeTask.ConfigureAwait(false); } - public async Task DrainDataAsync() + public async Task DrainDataAsync() { - // We go volatile because we race with Interlocked.Increment in PublishAsync - long totalPayloadProcessed = Volatile.Read(ref _totalPayloadProcessed); - long totalPayloadReceived = Volatile.Read(ref _totalPayloadReceived); - const int minDelayTimeMs = 25; - int currentDelayTimeMs = minDelayTimeMs; - while (Interlocked.CompareExchange(ref _totalPayloadReceived, totalPayloadReceived, totalPayloadProcessed) != totalPayloadProcessed) - { - // When we cancel we throw inside ConsumeAsync and we won't drain anymore any data - if (_cancellationToken.IsCancellationRequested) - { - break; - } - - await 
_task.Delay(currentDelayTimeMs).ConfigureAwait(false); - currentDelayTimeMs = Math.Min(currentDelayTimeMs + minDelayTimeMs, 200); - - if (_consumerState.Task.IsFaulted) - { - // Rethrow the exception - await _consumerState.Task.ConfigureAwait(false); - } - - // Wait for the consumer to complete the current enqueued items - totalPayloadProcessed = Volatile.Read(ref _totalPayloadProcessed); - totalPayloadReceived = Volatile.Read(ref _totalPayloadReceived); - } - - // It' possible that we fail and we have consumed the item - if (_consumerState.Task.IsFaulted) - { - // Rethrow the exception - await _consumerState.Task.ConfigureAwait(false); - } + _channel.Writer.Complete(); + await _consumeTask.ConfigureAwait(false); - return _totalPayloadReceived; + _channel = CreateChannel(); + _consumeTask = _task.Run(ConsumeAsync, _cancellationToken); } // At this point we simply signal the channel as complete and we don't wait for the consumer to complete. // We expect that the CompleteAddingAsync() is already done correctly and so we prefer block the loop and in // case get exception inside the PublishAsync() public void Dispose() => _channel.Writer.TryComplete(); + + private static Channel<(IDataProducer DataProducer, IData Data)> CreateChannel() + => Channel.CreateUnbounded<(IDataProducer DataProducer, IData Data)>(new UnboundedChannelOptions + { + // We process only 1 data at a time + SingleReader = true, + + // We don't know how many threads will call the publish on the message bus + SingleWriter = false, + + // We want to unlink the publish that's the message bus + AllowSynchronousContinuations = false, + }); } #endif diff --git a/src/Platform/Microsoft.Testing.Platform/Messages/AsyncConsumerDataProcessor.netstandard.cs b/src/Platform/Microsoft.Testing.Platform/Messages/AsyncConsumerDataProcessor.netstandard.cs index b43d6ff504..17796d5920 100644 --- 
a/src/Platform/Microsoft.Testing.Platform/Messages/AsyncConsumerDataProcessor.netstandard.cs +++ b/src/Platform/Microsoft.Testing.Platform/Messages/AsyncConsumerDataProcessor.netstandard.cs @@ -12,14 +12,9 @@ internal sealed class AsyncConsumerDataProcessor : IAsyncConsumerDataProcessor { private readonly ITask _task; private readonly CancellationToken _cancellationToken; - private readonly SingleConsumerUnboundedChannel<(IDataProducer DataProducer, IData Data)> _channel = new(); - // This is needed to avoid possible race condition between drain and _totalPayloadProcessed race condition. - // This is the "logical" consume workflow state. - private readonly TaskCompletionSource _consumerState = new(); - private readonly Task _consumeTask; - private long _totalPayloadReceived; - private long _totalPayloadProcessed; + private SingleConsumerUnboundedChannel<(IDataProducer DataProducer, IData Data)> _channel = new(); + private Task _consumeTask; public AsyncConsumerDataProcessor(IDataConsumer dataConsumer, ITask task, CancellationToken cancellationToken) { @@ -34,7 +29,6 @@ public AsyncConsumerDataProcessor(IDataConsumer dataConsumer, ITask task, Cancel public Task PublishAsync(IDataProducer dataProducer, IData data) { _cancellationToken.ThrowIfCancellationRequested(); - Interlocked.Increment(ref _totalPayloadReceived); _channel.Write((dataProducer, data)); return Task.CompletedTask; } @@ -47,37 +41,15 @@ private async Task ConsumeAsync() { while (_channel.TryRead(out (IDataProducer DataProducer, IData Data) item)) { - try + // We don't enqueue the data if the consumer is the producer of the data. + // We could optimize this if and make a get with type/all but producers, but it + // could be over-engineering. + if (item.DataProducer.Uid == DataConsumer.Uid) { - // We don't enqueue the data if the consumer is the producer of the data. - // We could optimize this if and make a get with type/all but producers, but it - // could be over-engineering. 
- if (item.DataProducer.Uid == DataConsumer.Uid) - { - continue; - } - - try - { - await DataConsumer.ConsumeAsync(item.DataProducer, item.Data, _cancellationToken).ConfigureAwait(false); - } - - // We let the catch below to handle the graceful cancellation of the process - catch (Exception ex) when (ex is not OperationCanceledException) - { - // If we're draining before to increment the _totalPayloadProcessed we need to signal that we should throw because - // it's possible we have a race condition where the payload check at line 106 return false and the current task is not yet in a - // "faulted state". - _consumerState.SetException(ex); - - // We let current task to move to fault state, checked inside CompleteAddingAsync. - throw; - } - } - finally - { - Interlocked.Increment(ref _totalPayloadProcessed); + continue; } + + await DataConsumer.ConsumeAsync(item.DataProducer, item.Data, _cancellationToken).ConfigureAwait(false); } } } @@ -85,20 +57,6 @@ private async Task ConsumeAsync() { // Ignore we're shutting down } - catch (Exception ex) - { - // For all other exception we signal the state if not already faulted - if (!_consumerState.Task.IsFaulted) - { - _consumerState.SetException(ex); - } - - // let the exception bubble up - throw; - } - - // We're exiting gracefully, signal the correct state. 
- _consumerState.SetResult(new object()); } public async Task CompleteAddingAsync() @@ -111,43 +69,13 @@ public async Task CompleteAddingAsync() await _consumeTask.ConfigureAwait(false); } - public async Task DrainDataAsync() + public async Task DrainDataAsync() { - // We go volatile because we race with Interlocked.Increment in PublishAsync - long totalPayloadProcessed = Volatile.Read(ref _totalPayloadProcessed); - long totalPayloadReceived = Volatile.Read(ref _totalPayloadReceived); - const int minDelayTimeMs = 25; - int currentDelayTimeMs = minDelayTimeMs; - while (Interlocked.CompareExchange(ref _totalPayloadReceived, totalPayloadReceived, totalPayloadProcessed) != totalPayloadProcessed) - { - // When we cancel we throw inside ConsumeAsync and we won't drain anymore any data - if (_cancellationToken.IsCancellationRequested) - { - break; - } - - await _task.Delay(currentDelayTimeMs).ConfigureAwait(false); - currentDelayTimeMs = Math.Min(currentDelayTimeMs + minDelayTimeMs, 200); - - if (_consumerState.Task.IsFaulted) - { - // Rethrow the exception - await _consumerState.Task.ConfigureAwait(false); - } - - // Wait for the consumer to complete the current enqueued items - totalPayloadProcessed = Volatile.Read(ref _totalPayloadProcessed); - totalPayloadReceived = Volatile.Read(ref _totalPayloadReceived); - } - - // It' possible that we fail and we have consumed the item - if (_consumerState.Task.IsFaulted) - { - // Rethrow the exception - await _consumerState.Task.ConfigureAwait(false); - } + _channel.Complete(); + await _consumeTask.ConfigureAwait(false); - return _totalPayloadReceived; + _channel = new(); + _consumeTask = _task.Run(ConsumeAsync, _cancellationToken); } public void Dispose() diff --git a/src/Platform/Microsoft.Testing.Platform/Messages/AsynchronousMessageBus.cs b/src/Platform/Microsoft.Testing.Platform/Messages/AsynchronousMessageBus.cs index 601876c206..2a2537f3c2 100644 --- 
a/src/Platform/Microsoft.Testing.Platform/Messages/AsynchronousMessageBus.cs +++ b/src/Platform/Microsoft.Testing.Platform/Messages/AsynchronousMessageBus.cs @@ -11,11 +11,7 @@ namespace Microsoft.Testing.Platform.Messages; internal sealed class AsynchronousMessageBus : BaseMessageBus, IMessageBus, IDisposable { - // This is an arbitrary number of attempts to drain the message bus. - // The number of attempts is configurable via the environment variable TESTINGPLATFORM_MESSAGEBUS_DRAINDATA_ATTEMPTS. - private const int DefaultDrainAttempt = 5; private readonly ITask _task; - private readonly IEnvironment _environment; private readonly ILogger _logger; private readonly bool _isTraceLoggingEnabled; private readonly Dictionary _consumerProcessor = []; @@ -28,13 +24,11 @@ public AsynchronousMessageBus( IDataConsumer[] dataConsumers, ITestApplicationCancellationTokenSource testApplicationCancellationTokenSource, ITask task, - ILoggerFactory loggerFactory, - IEnvironment environment) + ILoggerFactory loggerFactory) { _dataConsumers = dataConsumers; _testApplicationCancellationTokenSource = testApplicationCancellationTokenSource; _task = task; - _environment = environment; _logger = loggerFactory.CreateLogger(); _isTraceLoggingEnabled = _logger.IsEnabled(LogLevel.Trace); } @@ -127,51 +121,11 @@ private async Task LogDataAsync(IDataProducer dataProducer, IData data) public override async Task DrainDataAsync() { - Dictionary consumerToDrain = []; - bool anotherRound = true; - string? 
customAttempts = _environment.GetEnvironmentVariable(EnvironmentVariableConstants.TESTINGPLATFORM_MESSAGEBUS_DRAINDATA_ATTEMPTS); - if (!int.TryParse(customAttempts, out int totalNumberOfDrainAttempt)) - { - totalNumberOfDrainAttempt = DefaultDrainAttempt; - } - - var stopwatch = Stopwatch.StartNew(); - CancellationToken cancellationToken = _testApplicationCancellationTokenSource.CancellationToken; - while (anotherRound) + foreach (List dataProcessors in _dataTypeConsumers.Values) { - if (cancellationToken.IsCancellationRequested) - { - return; - } - - if (totalNumberOfDrainAttempt == 0) - { - StringBuilder builder = new(); - builder.Append(CultureInfo.InvariantCulture, $"Publisher/Consumer loop detected during the drain after {stopwatch.Elapsed}.\n{builder}"); - - foreach ((IAsyncConsumerDataProcessor key, long value) in consumerToDrain) - { - builder.AppendLine(CultureInfo.InvariantCulture, $"Consumer '{key.DataConsumer}' payload received {value}."); - } - - throw new InvalidOperationException(builder.ToString()); - } - - totalNumberOfDrainAttempt--; - anotherRound = false; - foreach (List dataProcessors in _dataTypeConsumers.Values) + foreach (IAsyncConsumerDataProcessor asyncMultiProducerMultiConsumerDataProcessor in dataProcessors) { - foreach (IAsyncConsumerDataProcessor asyncMultiProducerMultiConsumerDataProcessor in dataProcessors) - { - consumerToDrain.TryAdd(asyncMultiProducerMultiConsumerDataProcessor, 0); - - long totalPayloadReceived = await asyncMultiProducerMultiConsumerDataProcessor.DrainDataAsync().ConfigureAwait(false); - if (consumerToDrain[asyncMultiProducerMultiConsumerDataProcessor] != totalPayloadReceived) - { - consumerToDrain[asyncMultiProducerMultiConsumerDataProcessor] = totalPayloadReceived; - anotherRound = true; - } - } + await asyncMultiProducerMultiConsumerDataProcessor.DrainDataAsync().ConfigureAwait(false); } } } diff --git a/src/Platform/Microsoft.Testing.Platform/Messages/IAsyncConsumerDataProcessor.cs 
b/src/Platform/Microsoft.Testing.Platform/Messages/IAsyncConsumerDataProcessor.cs index 21317c40f3..f8eb7386cd 100644 --- a/src/Platform/Microsoft.Testing.Platform/Messages/IAsyncConsumerDataProcessor.cs +++ b/src/Platform/Microsoft.Testing.Platform/Messages/IAsyncConsumerDataProcessor.cs @@ -12,7 +12,7 @@ internal interface IAsyncConsumerDataProcessor : IDisposable Task CompleteAddingAsync(); - Task DrainDataAsync(); + Task DrainDataAsync(); Task PublishAsync(IDataProducer dataProducer, IData data); } diff --git a/test/UnitTests/Microsoft.Testing.Platform.UnitTests/Messages/AsynchronousMessageBusTests.cs b/test/UnitTests/Microsoft.Testing.Platform.UnitTests/Messages/AsynchronousMessageBusTests.cs index 1f85e0681b..f30942a439 100644 --- a/test/UnitTests/Microsoft.Testing.Platform.UnitTests/Messages/AsynchronousMessageBusTests.cs +++ b/test/UnitTests/Microsoft.Testing.Platform.UnitTests/Messages/AsynchronousMessageBusTests.cs @@ -19,12 +19,11 @@ public async Task UnexpectedTypePublished_ShouldFail() { using MessageBusProxy proxy = new(); InvalidTypePublished consumer = new(proxy); - AsynchronousMessageBus asynchronousMessageBus = new( + var asynchronousMessageBus = new AsynchronousMessageBus( [consumer], new CTRLPlusCCancellationTokenSource(), new SystemTask(), - new NopLoggerFactory(), - new SystemEnvironment()); + new NopLoggerFactory()); await asynchronousMessageBus.InitAsync(); proxy.SetBuiltMessageBus(asynchronousMessageBus); @@ -40,12 +39,11 @@ public async Task DrainDataAsync_Loop_ShouldFail() using MessageBusProxy proxy = new(); LoopConsumerA consumerA = new(proxy); ConsumerB consumerB = new(proxy); - AsynchronousMessageBus asynchronousMessageBus = new( + var asynchronousMessageBus = new AsynchronousMessageBus( [consumerA, consumerB], new CTRLPlusCCancellationTokenSource(), new SystemTask(), - new NopLoggerFactory(), - new SystemEnvironment()); + new NopLoggerFactory()); await asynchronousMessageBus.InitAsync(); 
proxy.SetBuiltMessageBus(asynchronousMessageBus); @@ -65,12 +63,11 @@ public async Task MessageBus_WhenConsumerProducesAndConsumesTheSameType_ShouldNo using MessageBusProxy proxy = new(); Consumer consumerA = new(proxy, "consumerA"); Consumer consumerB = new(proxy, "consumerB"); - AsynchronousMessageBus asynchronousMessageBus = new( + var asynchronousMessageBus = new AsynchronousMessageBus( [consumerA, consumerB], new CTRLPlusCCancellationTokenSource(), new SystemTask(), - new NopLoggerFactory(), - new SystemEnvironment()); + new NopLoggerFactory()); await asynchronousMessageBus.InitAsync(); proxy.SetBuiltMessageBus(asynchronousMessageBus); @@ -104,12 +101,11 @@ public async Task Consumers_ConsumeData_ShouldNotMissAnyPayload() dummyConsumers.Add(dummyConsumer); } - using AsynchronousMessageBus asynchronousMessageBus = new( + using var asynchronousMessageBus = new AsynchronousMessageBus( dummyConsumers.ToArray(), new CTRLPlusCCancellationTokenSource(), new SystemTask(), - new NopLoggerFactory(), - new SystemEnvironment()); + new NopLoggerFactory()); await asynchronousMessageBus.InitAsync(); proxy.SetBuiltMessageBus(asynchronousMessageBus);