Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions sdk/cs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,18 @@ await mgr.DownloadAndRegisterEpsAsync((epName, percent) =>
Console.WriteLine();
```

#### Cancelling model and EP downloads

Pass a `CancellationToken` to either download API. Cancellation is observed on the next progress update.

```csharp
// mgr and model already initialized
using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5));

await mgr.DownloadAndRegisterEpsAsync(ct: cts.Token);
await model.DownloadAsync(ct: cts.Token);
```

Catalog access no longer blocks on EP downloads. Call `DownloadAndRegisterEpsAsync` explicitly when you need hardware-accelerated execution providers.

## Quick Start
Expand Down
5 changes: 5 additions & 0 deletions sdk/cs/src/Detail/CoreInterop.cs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,11 @@ public Response ExecuteCommandImpl(string commandName, string? commandInput,

if (helper.Exception != null)
{
if (helper.Exception is OperationCanceledException canceledException)
{
throw canceledException;
}

throw new FoundryLocalException("Exception in callback handler. See InnerException for details",
helper.Exception);
}
Expand Down
24 changes: 18 additions & 6 deletions sdk/cs/src/Detail/ModelVariant.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

namespace Microsoft.AI.Foundry.Local;

using System.Globalization;

using Microsoft.AI.Foundry.Local.Detail;
using Microsoft.Extensions.Logging;

Expand Down Expand Up @@ -144,16 +146,22 @@ private async Task DownloadImplAsync(Action<float>? downloadProgress = null,
};

ICoreInterop.Response? response;
var useCallbackPath = downloadProgress != null || (ct?.CanBeCanceled ?? false);

if (downloadProgress == null)
{
response = await _coreInterop.ExecuteCommandAsync("download_model", request, ct).ConfigureAwait(false);
}
else
if (useCallbackPath)
{
var callback = new ICoreInterop.CallbackFn(progressString =>
{
if (float.TryParse(progressString, out var progress))
if (ct is CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
}

if (downloadProgress != null &&
float.TryParse(progressString,
NumberStyles.Float,
CultureInfo.InvariantCulture,
out var progress))
{
downloadProgress(progress);
}
Expand All @@ -162,6 +170,10 @@ private async Task DownloadImplAsync(Action<float>? downloadProgress = null,
response = await _coreInterop.ExecuteCommandWithCallbackAsync("download_model", request,
callback, ct).ConfigureAwait(false);
}
else
{
response = await _coreInterop.ExecuteCommandAsync("download_model", request, ct).ConfigureAwait(false);
}

if (response.Error != null)
{
Expand Down
16 changes: 12 additions & 4 deletions sdk/cs/src/FoundryLocalManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
namespace Microsoft.AI.Foundry.Local;

using System;
using System.Globalization;
using System.Text.Json;
using System.Threading.Tasks;

Expand Down Expand Up @@ -373,20 +374,27 @@ private async Task<EpDownloadResult> DownloadAndRegisterEpsImplAsync(IEnumerable

ICoreInterop.Response result;

if (progressCallback != null)
var useCallbackPath = progressCallback != null || (ct?.CanBeCanceled ?? false);

if (useCallbackPath)
{
var callback = new ICoreInterop.CallbackFn(progressString =>
{
if (ct is CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
}

var sepIndex = progressString.IndexOf('|');
if (sepIndex >= 0)
{
var name = progressString[..sepIndex];
if (double.TryParse(progressString[(sepIndex + 1)..],
System.Globalization.NumberStyles.Float,
System.Globalization.CultureInfo.InvariantCulture,
NumberStyles.Float,
CultureInfo.InvariantCulture,
out var percent))
{
progressCallback(string.IsNullOrEmpty(name) ? "" : name, percent);
progressCallback?.Invoke(string.IsNullOrEmpty(name) ? "" : name, percent);
}
}
});
Expand Down
80 changes: 80 additions & 0 deletions sdk/cs/test/FoundryLocal.Tests/DownloadCancellationTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// --------------------------------------------------------------------------------------------------------------------
// <copyright company="Microsoft">
// Copyright (c) Microsoft. All rights reserved.
// </copyright>
// --------------------------------------------------------------------------------------------------------------------

namespace Microsoft.AI.Foundry.Local.Tests;

using Microsoft.AI.Foundry.Local.Detail;

using Microsoft.Extensions.Logging;

using Moq;

internal sealed class DownloadCancellationTests
{
[Test]
public async Task ModelVariantDownload_WithCancellableToken_UsesCallbackPathAndPropagatesCancellation()
{
var modelInfo = new ModelInfo
{
Id = "test-model-cpu:1",
Name = "test-model-cpu",
Alias = "test-model",
Version = 1,
ProviderType = "AzureFoundry",
Uri = "azureml://registries/azureml/models/test-model-cpu/versions/1",
ModelType = "ONNX",
};

var modelLoadManager = new Mock<IModelLoadManager>(MockBehavior.Strict);
var coreInterop = new Mock<ICoreInterop>(MockBehavior.Strict);
var logger = new Mock<ILogger>();
using var cts = new CancellationTokenSource();

coreInterop.Setup(x => x.ExecuteCommandWithCallbackAsync(
It.Is<string>(s => s == "download_model"),
It.Is<CoreInteropRequest?>(r => r != null &&
r.Params != null &&
r.Params.ContainsKey("Model") &&
r.Params["Model"] == modelInfo.Id),
It.IsAny<ICoreInterop.CallbackFn>(),
It.IsAny<CancellationToken?>()))
.Returns((string commandName,
CoreInteropRequest? request,
ICoreInterop.CallbackFn callback,
CancellationToken? cancellationToken) =>
{
callback("10");
cts.Cancel();
callback("20");
return Task.FromResult(new ICoreInterop.Response());
});

var model = new ModelVariant(modelInfo, modelLoadManager.Object, coreInterop.Object, logger.Object);

OperationCanceledException? caught = null;
try
{
await model.DownloadAsync(ct: cts.Token);
}
catch (OperationCanceledException ex)
{
caught = ex;
}

await Assert.That(caught).IsNotNull();
coreInterop.Verify(x => x.ExecuteCommandWithCallbackAsync(
"download_model",
It.IsAny<CoreInteropRequest?>(),
It.IsAny<ICoreInterop.CallbackFn>(),
It.IsAny<CancellationToken?>()),
Times.Once);
coreInterop.Verify(x => x.ExecuteCommandAsync(
It.IsAny<string>(),
It.IsAny<CoreInteropRequest?>(),
It.IsAny<CancellationToken?>()),
Times.Never);
}
}
3 changes: 2 additions & 1 deletion sdk/cs/test/FoundryLocal.Tests/Utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,8 @@ private static string GetRepoRoot()

while (dir != null)
{
if (Directory.Exists(Path.Combine(dir.FullName, ".git")))
var gitPath = Path.Combine(dir.FullName, ".git");
if (Directory.Exists(gitPath) || File.Exists(gitPath))
return dir.FullName;

dir = dir.Parent;
Expand Down
15 changes: 14 additions & 1 deletion sdk/js/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@ await manager.downloadAndRegisterEps((epName, percent) => {
process.stdout.write('\n');
```

#### Cancelling model and EP downloads

Use an `AbortController` with either `downloadAndRegisterEps()` or `model.download()`. Aborting the signal rejects the in-progress download promise.

```typescript
// manager and model already initialized
const controller = new AbortController();
setTimeout(() => controller.abort(), 5000);

await manager.downloadAndRegisterEps(controller.signal);
await model.download(undefined, controller.signal);
```

Catalog access does not block on EP downloads. Call `downloadAndRegisterEps()` when you need hardware-accelerated execution providers.

## Quick Start
Expand Down Expand Up @@ -330,4 +343,4 @@ See `test/README.md` for details on prerequisites and setup.
npm run example
```

This runs the chat completion example in `examples/chat-completion.ts`.
This runs the chat completion example in `examples/chat-completion.ts`.
42 changes: 40 additions & 2 deletions sdk/js/src/detail/coreInterop.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,47 @@ export class CoreInterop {
return this.addon.executeCommandWithBinary(command, dataStr, binBuf);
}

public executeCommandStreaming(command: string, params: any, callback: (chunk: string) => void): Promise<string> {
public async executeCommandStreaming(
command: string,
params: any,
callback: (chunk: string) => void,
signal?: AbortSignal
): Promise<string> {
const createAbortError = (): Error => {
const error = new Error('Operation cancelled');
error.name = 'AbortError';
return error;
};

if (signal?.aborted) {
throw createAbortError();
}

const dataStr = params ? JSON.stringify(params) : '';
return this.addon.executeCommandStreaming(command, dataStr, callback);
let cancelled = false;
const wrappedCallback = (chunk: string) => {
if (signal?.aborted) {
cancelled = true;
throw createAbortError();
}

callback(chunk);
};

try {
const result = await this.addon.executeCommandStreaming(command, dataStr, wrappedCallback);
if (cancelled || signal?.aborted) {
throw createAbortError();
}

return result;
} catch (error) {
if (cancelled || signal?.aborted) {
throw createAbortError();
}

throw error;
}
}

}
7 changes: 4 additions & 3 deletions sdk/js/src/detail/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,10 @@ export class Model implements IModel {
/**
* Downloads the currently selected variant.
* @param progressCallback - Optional callback to report download progress.
* @param signal - Optional AbortSignal. When aborted, the download will be cancelled at the next progress update.
*/
public download(progressCallback?: (progress: number) => void): Promise<void> {
return this.selectedVariant.download(progressCallback);
public download(progressCallback?: (progress: number) => void, signal?: AbortSignal): Promise<void> {
return this.selectedVariant.download(progressCallback, signal);
}

/**
Expand Down Expand Up @@ -202,4 +203,4 @@ export class Model implements IModel {
public createResponsesClient(baseUrl: string): ResponsesClient {
return this.selectedVariant.createResponsesClient(baseUrl);
}
}
}
14 changes: 10 additions & 4 deletions sdk/js/src/detail/modelVariant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,24 @@ export class ModelVariant implements IModel {
/**
* Downloads the model variant.
* @param progressCallback - Optional callback to report download progress (0-100).
* @param signal - Optional AbortSignal. When aborted, the download will be
* cancelled at the next progress update and the returned promise will reject.
*/
public async download(progressCallback?: (progress: number) => void): Promise<void> {
public async download(progressCallback?: (progress: number) => void, signal?: AbortSignal): Promise<void> {
const request = { Params: { Model: this._modelInfo.id } };
if (!progressCallback) {
if (!progressCallback && !signal) {
this.coreInterop.executeCommand("download_model", request);
} else {
// Use the streaming path when progress or cancellation is needed.
// Provide a no-op callback when only cancellation is requested so
// the native callback mechanism is engaged.
const cb = progressCallback ?? (() => {});
await this.coreInterop.executeCommandStreaming("download_model", request, (chunk: string) => {
const progress = parseFloat(chunk);
if (!isNaN(progress)) {
progressCallback(progress);
cb(progress);
}
});
}, signal);
}
}

Expand Down
Loading
Loading