diff --git a/sdk/cs/README.md b/sdk/cs/README.md index 8547434d..72ee9ac4 100644 --- a/sdk/cs/README.md +++ b/sdk/cs/README.md @@ -99,6 +99,18 @@ await mgr.DownloadAndRegisterEpsAsync((epName, percent) => Console.WriteLine(); ``` +#### Cancelling model and EP downloads + +Pass a `CancellationToken` to either download API. Cancellation is observed on the next progress update. + +```csharp +// mgr and model already initialized +using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + +await mgr.DownloadAndRegisterEpsAsync(ct: cts.Token); +await model.DownloadAsync(ct: cts.Token); +``` + Catalog access no longer blocks on EP downloads. Call `DownloadAndRegisterEpsAsync` explicitly when you need hardware-accelerated execution providers. ## Quick Start diff --git a/sdk/cs/src/Detail/CoreInterop.cs b/sdk/cs/src/Detail/CoreInterop.cs index b88f5597..a099a0d2 100644 --- a/sdk/cs/src/Detail/CoreInterop.cs +++ b/sdk/cs/src/Detail/CoreInterop.cs @@ -297,6 +297,11 @@ public Response ExecuteCommandImpl(string commandName, string? commandInput, if (helper.Exception != null) { + if (helper.Exception is OperationCanceledException canceledException) + { + throw canceledException; + } + throw new FoundryLocalException("Exception in callback handler. See InnerException for details", helper.Exception); } diff --git a/sdk/cs/src/Detail/ModelVariant.cs b/sdk/cs/src/Detail/ModelVariant.cs index 250c601a..03ea8e10 100644 --- a/sdk/cs/src/Detail/ModelVariant.cs +++ b/sdk/cs/src/Detail/ModelVariant.cs @@ -6,6 +6,8 @@ namespace Microsoft.AI.Foundry.Local; +using System.Globalization; + using Microsoft.AI.Foundry.Local.Detail; using Microsoft.Extensions.Logging; @@ -144,16 +146,22 @@ private async Task DownloadImplAsync(Action? downloadProgress = null, }; ICoreInterop.Response? response; + var useCallbackPath = downloadProgress != null || (ct?.CanBeCanceled ?? 
false); - if (downloadProgress == null) - { - response = await _coreInterop.ExecuteCommandAsync("download_model", request, ct).ConfigureAwait(false); - } - else + if (useCallbackPath) { var callback = new ICoreInterop.CallbackFn(progressString => { - if (float.TryParse(progressString, out var progress)) + if (ct is CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + } + + if (downloadProgress != null && + float.TryParse(progressString, + NumberStyles.Float, + CultureInfo.InvariantCulture, + out var progress)) { downloadProgress(progress); } @@ -162,6 +170,10 @@ private async Task DownloadImplAsync(Action? downloadProgress = null, response = await _coreInterop.ExecuteCommandWithCallbackAsync("download_model", request, callback, ct).ConfigureAwait(false); } + else + { + response = await _coreInterop.ExecuteCommandAsync("download_model", request, ct).ConfigureAwait(false); + } if (response.Error != null) { diff --git a/sdk/cs/src/FoundryLocalManager.cs b/sdk/cs/src/FoundryLocalManager.cs index 10b51285..e4bdeada 100644 --- a/sdk/cs/src/FoundryLocalManager.cs +++ b/sdk/cs/src/FoundryLocalManager.cs @@ -6,6 +6,7 @@ namespace Microsoft.AI.Foundry.Local; using System; +using System.Globalization; using System.Text.Json; using System.Threading.Tasks; @@ -373,20 +374,27 @@ private async Task DownloadAndRegisterEpsImplAsync(IEnumerable ICoreInterop.Response result; - if (progressCallback != null) + var useCallbackPath = progressCallback != null || (ct?.CanBeCanceled ?? 
false); + + if (useCallbackPath) { var callback = new ICoreInterop.CallbackFn(progressString => { + if (ct is CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + } + var sepIndex = progressString.IndexOf('|'); if (sepIndex >= 0) { var name = progressString[..sepIndex]; if (double.TryParse(progressString[(sepIndex + 1)..], - System.Globalization.NumberStyles.Float, - System.Globalization.CultureInfo.InvariantCulture, + NumberStyles.Float, + CultureInfo.InvariantCulture, out var percent)) { - progressCallback(string.IsNullOrEmpty(name) ? "" : name, percent); + progressCallback?.Invoke(string.IsNullOrEmpty(name) ? "" : name, percent); } } }); diff --git a/sdk/cs/test/FoundryLocal.Tests/DownloadCancellationTests.cs b/sdk/cs/test/FoundryLocal.Tests/DownloadCancellationTests.cs new file mode 100644 index 00000000..b8f5aac3 --- /dev/null +++ b/sdk/cs/test/FoundryLocal.Tests/DownloadCancellationTests.cs @@ -0,0 +1,80 @@ +// -------------------------------------------------------------------------------------------------------------------- +// +// Copyright (c) Microsoft. All rights reserved. 
+// +// -------------------------------------------------------------------------------------------------------------------- + +namespace Microsoft.AI.Foundry.Local.Tests; + +using Microsoft.AI.Foundry.Local.Detail; + +using Microsoft.Extensions.Logging; + +using Moq; + +internal sealed class DownloadCancellationTests +{ + [Test] + public async Task ModelVariantDownload_WithCancellableToken_UsesCallbackPathAndPropagatesCancellation() + { + var modelInfo = new ModelInfo + { + Id = "test-model-cpu:1", + Name = "test-model-cpu", + Alias = "test-model", + Version = 1, + ProviderType = "AzureFoundry", + Uri = "azureml://registries/azureml/models/test-model-cpu/versions/1", + ModelType = "ONNX", + }; + + var modelLoadManager = new Mock(MockBehavior.Strict); + var coreInterop = new Mock(MockBehavior.Strict); + var logger = new Mock(); + using var cts = new CancellationTokenSource(); + + coreInterop.Setup(x => x.ExecuteCommandWithCallbackAsync( + It.Is(s => s == "download_model"), + It.Is(r => r != null && + r.Params != null && + r.Params.ContainsKey("Model") && + r.Params["Model"] == modelInfo.Id), + It.IsAny(), + It.IsAny())) + .Returns((string commandName, + CoreInteropRequest? request, + ICoreInterop.CallbackFn callback, + CancellationToken? cancellationToken) => + { + callback("10"); + cts.Cancel(); + callback("20"); + return Task.FromResult(new ICoreInterop.Response()); + }); + + var model = new ModelVariant(modelInfo, modelLoadManager.Object, coreInterop.Object, logger.Object); + + OperationCanceledException? 
caught = null; + try + { + await model.DownloadAsync(ct: cts.Token); + } + catch (OperationCanceledException ex) + { + caught = ex; + } + + await Assert.That(caught).IsNotNull(); + coreInterop.Verify(x => x.ExecuteCommandWithCallbackAsync( + "download_model", + It.IsAny(), + It.IsAny(), + It.IsAny()), + Times.Once); + coreInterop.Verify(x => x.ExecuteCommandAsync( + It.IsAny(), + It.IsAny(), + It.IsAny()), + Times.Never); + } +} diff --git a/sdk/cs/test/FoundryLocal.Tests/Utils.cs b/sdk/cs/test/FoundryLocal.Tests/Utils.cs index a289011b..fe968df1 100644 --- a/sdk/cs/test/FoundryLocal.Tests/Utils.cs +++ b/sdk/cs/test/FoundryLocal.Tests/Utils.cs @@ -443,7 +443,8 @@ private static string GetRepoRoot() while (dir != null) { - if (Directory.Exists(Path.Combine(dir.FullName, ".git"))) + var gitPath = Path.Combine(dir.FullName, ".git"); + if (Directory.Exists(gitPath) || File.Exists(gitPath)) return dir.FullName; dir = dir.Parent; diff --git a/sdk/js/README.md b/sdk/js/README.md index ff1ac542..6594390e 100644 --- a/sdk/js/README.md +++ b/sdk/js/README.md @@ -71,6 +71,19 @@ await manager.downloadAndRegisterEps((epName, percent) => { process.stdout.write('\n'); ``` +#### Cancelling model and EP downloads + +Use an `AbortController` with either `downloadAndRegisterEps()` or `model.download()`. Aborting the signal rejects the in-progress download promise. + +```typescript +// manager and model already initialized +const controller = new AbortController(); +setTimeout(() => controller.abort(), 5000); + +await manager.downloadAndRegisterEps(controller.signal); +await model.download(undefined, controller.signal); +``` + Catalog access does not block on EP downloads. Call `downloadAndRegisterEps()` when you need hardware-accelerated execution providers. ## Quick Start @@ -330,4 +343,4 @@ See `test/README.md` for details on prerequisites and setup. npm run example ``` -This runs the chat completion example in `examples/chat-completion.ts`. 
\ No newline at end of file +This runs the chat completion example in `examples/chat-completion.ts`. diff --git a/sdk/js/src/detail/coreInterop.ts b/sdk/js/src/detail/coreInterop.ts index 72df7e26..269a9565 100644 --- a/sdk/js/src/detail/coreInterop.ts +++ b/sdk/js/src/detail/coreInterop.ts @@ -126,9 +126,47 @@ export class CoreInterop { return this.addon.executeCommandWithBinary(command, dataStr, binBuf); } - public executeCommandStreaming(command: string, params: any, callback: (chunk: string) => void): Promise { + public async executeCommandStreaming( + command: string, + params: any, + callback: (chunk: string) => void, + signal?: AbortSignal + ): Promise { + const createAbortError = (): Error => { + const error = new Error('Operation cancelled'); + error.name = 'AbortError'; + return error; + }; + + if (signal?.aborted) { + throw createAbortError(); + } + const dataStr = params ? JSON.stringify(params) : ''; - return this.addon.executeCommandStreaming(command, dataStr, callback); + let cancelled = false; + const wrappedCallback = (chunk: string) => { + if (signal?.aborted) { + cancelled = true; + throw createAbortError(); + } + + callback(chunk); + }; + + try { + const result = await this.addon.executeCommandStreaming(command, dataStr, wrappedCallback); + if (cancelled || signal?.aborted) { + throw createAbortError(); + } + + return result; + } catch (error) { + if (cancelled || signal?.aborted) { + throw createAbortError(); + } + + throw error; + } } } diff --git a/sdk/js/src/detail/model.ts b/sdk/js/src/detail/model.ts index c1ee0d5f..aba1caf7 100644 --- a/sdk/js/src/detail/model.ts +++ b/sdk/js/src/detail/model.ts @@ -126,9 +126,10 @@ export class Model implements IModel { /** * Downloads the currently selected variant. * @param progressCallback - Optional callback to report download progress. + * @param signal - Optional AbortSignal. When aborted, the download will be cancelled at the next progress update. 
*/ - public download(progressCallback?: (progress: number) => void): Promise { - return this.selectedVariant.download(progressCallback); + public download(progressCallback?: (progress: number) => void, signal?: AbortSignal): Promise { + return this.selectedVariant.download(progressCallback, signal); } /** @@ -202,4 +203,4 @@ export class Model implements IModel { public createResponsesClient(baseUrl: string): ResponsesClient { return this.selectedVariant.createResponsesClient(baseUrl); } -} \ No newline at end of file +} diff --git a/sdk/js/src/detail/modelVariant.ts b/sdk/js/src/detail/modelVariant.ts index 43484bac..527d5493 100644 --- a/sdk/js/src/detail/modelVariant.ts +++ b/sdk/js/src/detail/modelVariant.ts @@ -108,18 +108,24 @@ export class ModelVariant implements IModel { /** * Downloads the model variant. * @param progressCallback - Optional callback to report download progress (0-100). + * @param signal - Optional AbortSignal. When aborted, the download will be + * cancelled at the next progress update and the returned promise will reject. */ - public async download(progressCallback?: (progress: number) => void): Promise { + public async download(progressCallback?: (progress: number) => void, signal?: AbortSignal): Promise { const request = { Params: { Model: this._modelInfo.id } }; - if (!progressCallback) { + if (!progressCallback && !signal) { this.coreInterop.executeCommand("download_model", request); } else { + // Use the streaming path when progress or cancellation is needed. + // Provide a no-op callback when only cancellation is requested so + // the native callback mechanism is engaged. + const cb = progressCallback ?? 
(() => {}); await this.coreInterop.executeCommandStreaming("download_model", request, (chunk: string) => { const progress = parseFloat(chunk); if (!isNaN(progress)) { - progressCallback(progress); + cb(progress); } - }); + }, signal); } } diff --git a/sdk/js/src/foundryLocalManager.ts b/sdk/js/src/foundryLocalManager.ts index f22acdc0..5199aa2f 100644 --- a/sdk/js/src/foundryLocalManager.ts +++ b/sdk/js/src/foundryLocalManager.ts @@ -5,6 +5,13 @@ import { Catalog } from './catalog.js'; import { ResponsesClient } from './openai/responsesClient.js'; import { EpInfo, EpDownloadResult } from './types.js'; +function isAbortSignal(value: unknown): value is AbortSignal { + return typeof value === 'object' + && value !== null + && 'aborted' in value + && typeof (value as AbortSignal).aborted === 'boolean'; +} + /** * The main entry point for the Foundry Local SDK. * Manages the initialization of the core system and provides access to the Catalog and ModelLoadManager. @@ -123,18 +130,38 @@ export class FoundryLocalManager { * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(): Promise; + /** + * Downloads and registers execution providers. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(signal: AbortSignal): Promise; /** * Downloads and registers execution providers. * @param names - Array of EP names to download. * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(names: string[]): Promise; + /** + * Downloads and registers execution providers. + * @param names - Array of EP names to download. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. 
+ */ + public downloadAndRegisterEps(names: string[], signal: AbortSignal): Promise; /** * Downloads and registers execution providers, reporting progress. * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(progressCallback: (epName: string, percent: number) => void): Promise; + /** + * Downloads and registers execution providers, reporting progress. + * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. + */ + public downloadAndRegisterEps(progressCallback: (epName: string, percent: number) => void, signal: AbortSignal): Promise; /** * Downloads and registers execution providers, reporting progress. * @param names - Array of EP names to download. @@ -142,15 +169,40 @@ export class FoundryLocalManager { * @returns A promise that resolves with an EpDownloadResult describing the outcome. */ public downloadAndRegisterEps(names: string[], progressCallback: (epName: string, percent: number) => void): Promise; + /** + * Downloads and registers execution providers, reporting progress. + * @param names - Array of EP names to download. + * @param progressCallback - Callback invoked with (epName, percent) as each EP downloads. Percent is 0-100. + * @param signal - Optional AbortSignal used to cancel an in-progress download. + * @returns A promise that resolves with an EpDownloadResult describing the outcome. 
+ */ + public downloadAndRegisterEps(names: string[], progressCallback: (epName: string, percent: number) => void, signal: AbortSignal): Promise; public async downloadAndRegisterEps( - namesOrCallback?: string[] | ((epName: string, percent: number) => void), - progressCallback?: (epName: string, percent: number) => void + namesOrCallbackOrSignal?: string[] | ((epName: string, percent: number) => void) | AbortSignal, + progressCallbackOrSignal?: ((epName: string, percent: number) => void) | AbortSignal, + maybeSignal?: AbortSignal ): Promise { + let progressCallback: ((epName: string, percent: number) => void) | undefined; let names: string[] | undefined; - if (typeof namesOrCallback === 'function') { - progressCallback = namesOrCallback; + let signal: AbortSignal | undefined; + + if (Array.isArray(namesOrCallbackOrSignal)) { + names = namesOrCallbackOrSignal; + if (typeof progressCallbackOrSignal === 'function') { + progressCallback = progressCallbackOrSignal; + signal = maybeSignal; + } else if (isAbortSignal(progressCallbackOrSignal)) { + signal = progressCallbackOrSignal; + } + } else if (typeof namesOrCallbackOrSignal === 'function') { + progressCallback = namesOrCallbackOrSignal; + if (isAbortSignal(progressCallbackOrSignal)) { + signal = progressCallbackOrSignal; + } + } else if (isAbortSignal(namesOrCallbackOrSignal)) { + signal = namesOrCallbackOrSignal; } else { - names = namesOrCallback; + signal = maybeSignal; } const params: { Params?: { Names: string } } = {}; @@ -180,13 +232,15 @@ export class FoundryLocalManager { progressCallback(epName || '', percent); } } - } + }, + signal ); } else { response = await this.coreInterop.executeCommandStreaming( "download_and_register_eps", Object.keys(params).length > 0 ? 
params : undefined, - () => {} // no-op callback + () => {}, // no-op callback + signal ); } diff --git a/sdk/js/src/imodel.ts b/sdk/js/src/imodel.ts index 8f9bd0c1..122b1a09 100644 --- a/sdk/js/src/imodel.ts +++ b/sdk/js/src/imodel.ts @@ -18,7 +18,13 @@ export interface IModel { get capabilities(): string | null; get supportsToolCalling(): boolean | null; - download(progressCallback?: (progress: number) => void): Promise; + /** + * Download the model to local cache if not already present. + * @param progressCallback - Optional callback for download progress (0-100). + * @param signal - Optional AbortSignal. When aborted, the download will be + * cancelled at the next progress update and the returned promise will reject. + */ + download(progressCallback?: (progress: number) => void, signal?: AbortSignal): Promise; get path(): string; load(): Promise; removeFromCache(): void; diff --git a/sdk/js/test/foundryLocalManager.test.ts b/sdk/js/test/foundryLocalManager.test.ts index 48adcff4..526d2bb2 100644 --- a/sdk/js/test/foundryLocalManager.test.ts +++ b/sdk/js/test/foundryLocalManager.test.ts @@ -1,6 +1,7 @@ import { describe, it } from 'mocha'; import { expect } from 'chai'; import { getTestManager } from './testUtils.js'; +import { FoundryLocalManager } from '../src/foundryLocalManager.js'; describe('Foundry Local Manager Tests', () => { it('should initialize successfully', function() { @@ -78,4 +79,33 @@ describe('Foundry Local Manager Tests', () => { manager.coreInterop.executeCommandStreaming = originalExecuteCommandStreaming; } }); + + it('downloadAndRegisterEps should pass AbortSignal through to streaming interop', async function() { + const calls: unknown[][] = []; + const controller = new AbortController(); + const manager = Object.create(FoundryLocalManager.prototype) as any; + manager.coreInterop = { + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(JSON.stringify({ + Success: true, + Status: 'All providers 
registered', + RegisteredEps: ['CUDAExecutionProvider'], + FailedEps: [] + })); + } + }; + manager._catalog = { + invalidateCache: () => {} + }; + + await FoundryLocalManager.prototype.downloadAndRegisterEps.call( + manager, + ['CUDAExecutionProvider'], + controller.signal + ); + expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_and_register_eps'); + expect(calls[0][3]).to.equal(controller.signal); + }); }); diff --git a/sdk/js/test/model.test.ts b/sdk/js/test/model.test.ts index 4048d9a1..c43b42a0 100644 --- a/sdk/js/test/model.test.ts +++ b/sdk/js/test/model.test.ts @@ -1,6 +1,9 @@ import { describe, it } from 'mocha'; import { expect } from 'chai'; import { getTestManager, TEST_MODEL_ALIAS } from './testUtils.js'; +import { Model } from '../src/detail/model.js'; +import { ModelVariant } from '../src/detail/modelVariant.js'; +import { DeviceType, type ModelInfo } from '../src/types.js'; describe('Model Tests', () => { it('should verify cached models from test-data-shared', async function() { @@ -58,4 +61,40 @@ describe('Model Tests', () => { await model.unload(); expect(await model.isLoaded()).to.be.false; }); -}); \ No newline at end of file + + it('download should use streaming interop when only an AbortSignal is provided', async function() { + const calls: unknown[][] = []; + const controller = new AbortController(); + const fakeCoreInterop = { + executeCommand: () => { + throw new Error('download should not use executeCommand when a signal is provided'); + }, + executeCommandStreaming: (...args: unknown[]) => { + calls.push(args); + return Promise.resolve(''); + } + }; + const modelInfo: ModelInfo = { + id: 'test-model-cpu:1', + name: 'test-model-cpu', + version: 1, + alias: TEST_MODEL_ALIAS, + providerType: 'AzureFoundry', + uri: 'azureml://registries/azureml/models/test-model-cpu/versions/1', + modelType: 'ONNX', + cached: false, + createdAtUnix: 0, + runtime: { + deviceType: DeviceType.CPU, + executionProvider: 
'CPUExecutionProvider' + } + }; + const variant = new ModelVariant(modelInfo, fakeCoreInterop as any, {} as any); + const model = new Model(variant); + + await model.download(undefined, controller.signal); + expect(calls.length).to.equal(1); + expect(calls[0][0]).to.equal('download_model'); + expect(calls[0][3]).to.equal(controller.signal); + }); +}); diff --git a/sdk/python/README.md b/sdk/python/README.md index 2a121411..55a6f8d1 100644 --- a/sdk/python/README.md +++ b/sdk/python/README.md @@ -108,6 +108,21 @@ manager.download_and_register_eps(progress_callback=on_progress) print() ``` +### Cancelling model and EP downloads + +Pass a `threading.Event` as `cancel_event` to either download API. Set the event from another thread or handler to cancel the in-progress download. + +```python +import threading + +# manager and model already initialized +cancel_event = threading.Event() +threading.Timer(5.0, cancel_event.set).start() + +manager.download_and_register_eps(cancel_event=cancel_event) +model.download(cancel_event=cancel_event) +``` + Catalog access does not block on EP downloads. Call `download_and_register_eps()` when you need hardware-accelerated execution providers. ## Quick Start @@ -328,4 +343,4 @@ See [test/README.md](test/README.md) for detailed test setup and structure. 
```bash python examples/chat_completion.py -``` \ No newline at end of file +``` diff --git a/sdk/python/src/detail/core_interop.py b/sdk/python/src/detail/core_interop.py index 1cd53e33..9b5c66f6 100644 --- a/sdk/python/src/detail/core_interop.py +++ b/sdk/python/src/detail/core_interop.py @@ -10,6 +10,7 @@ import logging import os import sys +import threading from dataclasses import dataclass from pathlib import Path @@ -67,6 +68,10 @@ class Response: error: Optional[str] = None +class CancelledException(Exception): + """Raised internally when a download or streaming operation is cancelled.""" + + class CallbackHelper: """Internal helper class to convert the callback from ctypes to a str and call the python callback.""" @staticmethod @@ -75,18 +80,27 @@ def callback(data_ptr, length, self_ptr): try: self = ctypes.cast(self_ptr, ctypes.POINTER(ctypes.py_object)).contents.value + # Check for cancellation before processing the callback data. + if self._cancel_event is not None and self._cancel_event.is_set(): + raise CancelledException("Operation cancelled") + # convert to a string and pass to the python callback data_bytes = ctypes.string_at(data_ptr, length) data_str = data_bytes.decode('utf-8') self._py_callback(data_str) return 0 # continue + except CancelledException as e: + if self is not None and self.exception is None: + self.exception = e + return 1 # cancel except Exception as e: if self is not None and self.exception is None: self.exception = e # keep the first only as they are likely all the same return 1 # cancel on error - def __init__(self, py_callback: Callable[[str], None]): + def __init__(self, py_callback: Callable[[str], None], cancel_event: Optional['threading.Event'] = None): self._py_callback = py_callback + self._cancel_event = cancel_event self.exception = None @@ -225,7 +239,8 @@ def __init__(self, config: Configuration): logger.info("Foundry.Local.Core initialized successfully: %s", response.data) def _execute_command(self, command: str, 
interop_request: InteropRequest = None, - callback: CoreInterop.CALLBACK_TYPE = None): + callback: CoreInterop.CALLBACK_TYPE = None, + cancel_event: Optional[threading.Event] = None): cmd_ptr, cmd_len, cmd_buf = CoreInterop._to_c_buffer(command) data_ptr, data_len, data_buf = CoreInterop._to_c_buffer(interop_request.to_json() if interop_request else None) @@ -237,7 +252,7 @@ def _execute_command(self, command: str, interop_request: InteropRequest = None, # If a callback is provided, use the execute_command_with_callback method # We need a helper to do the initial conversion from ctypes to Python and pass it through to the # provided callback function - callback_helper = CallbackHelper(callback) + callback_helper = CallbackHelper(callback, cancel_event) callback_py_obj = ctypes.py_object(callback_helper) callback_helper_ptr = ctypes.cast(ctypes.pointer(callback_py_obj), ctypes.c_void_p) callback_fn = CoreInterop.CALLBACK_TYPE(CallbackHelper.callback) @@ -245,6 +260,8 @@ def _execute_command(self, command: str, interop_request: InteropRequest = None, lib.execute_command_with_callback(ctypes.byref(req), ctypes.byref(resp), callback_fn, callback_helper_ptr) if callback_helper.exception is not None: + if isinstance(callback_helper.exception, CancelledException): + raise FoundryLocalException("Operation cancelled") raise callback_helper.exception else: lib.execute_command(ctypes.byref(req), ctypes.byref(resp)) @@ -276,23 +293,33 @@ def execute_command(self, command_name: str, command_input: Optional[InteropRequ return response def execute_command_with_callback(self, command_name: str, command_input: Optional[InteropRequest], - callback: Callable[[str], None]) -> Response: + callback: Callable[[str], None], + cancel_event: Optional[threading.Event] = None) -> Response: """Execute a command with a streaming callback. The ``callback`` receives incremental string data from the native layer (e.g. streaming chat tokens or download progress). 
+ If ``cancel_event`` is provided and is set, the native call will be + cancelled at the next callback invocation and a ``FoundryLocalException`` + with message ``"Operation cancelled"`` will be raised. + Args: command_name: The native command name. command_input: Optional request parameters. callback: Called with each incremental string response. + cancel_event: Optional ``threading.Event`` that signals cancellation + when set. Returns: A ``Response`` with ``data`` on success or ``error`` on failure. + + Raises: + FoundryLocalException: If the operation is cancelled or fails. """ logger.debug("Executing command with callback: %s Input: %s", command_name, command_input.params if command_input else None) - response = self._execute_command(command_name, command_input, callback) + response = self._execute_command(command_name, command_input, callback, cancel_event) return response diff --git a/sdk/python/src/detail/model.py b/sdk/python/src/detail/model.py index 6d60b7a2..a71b1dba 100644 --- a/sdk/python/src/detail/model.py +++ b/sdk/python/src/detail/model.py @@ -5,6 +5,7 @@ from __future__ import annotations import logging +from threading import Event from typing import Callable, List, Optional from ..imodel import IModel @@ -115,9 +116,10 @@ def is_loaded(self) -> bool: """Is the currently selected variant loaded in memory?""" return self._selected_variant.is_loaded - def download(self, progress_callback: Optional[Callable[[float], None]] = None) -> None: + def download(self, progress_callback: Optional[Callable[[float], None]] = None, + cancel_event: Optional[Event] = None) -> None: """Download the currently selected variant.""" - self._selected_variant.download(progress_callback) + self._selected_variant.download(progress_callback, cancel_event) def get_path(self) -> str: """Get the path to the currently selected variant.""" diff --git a/sdk/python/src/detail/model_variant.py b/sdk/python/src/detail/model_variant.py index 76efb05c..ff931237 100644 --- 
a/sdk/python/src/detail/model_variant.py +++ b/sdk/python/src/detail/model_variant.py @@ -5,6 +5,7 @@ from __future__ import annotations import logging +from threading import Event from typing import Callable, List, Optional from ..imodel import IModel @@ -112,20 +113,28 @@ def is_loaded(self) -> bool: loaded_model_ids = self._model_load_manager.list_loaded() return self.id in loaded_model_ids - def download(self, progress_callback: Callable[[float], None] = None): + def download(self, progress_callback: Callable[[float], None] = None, + cancel_event: Optional[Event] = None): """Download this variant to the local cache. Args: progress_callback: Optional callback receiving download progress as a percentage (0.0 to 100.0). + cancel_event: Optional ``threading.Event``. When set, the download will be + cancelled at the next progress update and ``FoundryLocalException`` is raised. """ request = InteropRequest(params={"Model": self.id}) - if progress_callback is None: + if progress_callback is None and cancel_event is None: response = self._core_interop.execute_command("download_model", request) else: + # Use the callback path when either progress or cancellation is needed. + # If no progress callback was provided, use a no-op so the native + # callback mechanism is engaged (required for cancellation checks). 
+ user_cb = progress_callback if progress_callback is not None else lambda _pct: None response = self._core_interop.execute_command_with_callback( "download_model", request, - lambda pct_str: progress_callback(float(pct_str)) + lambda pct_str: user_cb(float(pct_str)), + cancel_event, ) logger.info("Download response: %s", response) diff --git a/sdk/python/src/foundry_local_manager.py b/sdk/python/src/foundry_local_manager.py index a649f8e5..f3678267 100644 --- a/sdk/python/src/foundry_local_manager.py +++ b/sdk/python/src/foundry_local_manager.py @@ -101,6 +101,7 @@ def download_and_register_eps( self, names: Optional[list[str]] = None, progress_callback: Optional[Callable[[str, float], None]] = None, + cancel_event: Optional[threading.Event] = None, ) -> EpDownloadResult: """Download and register execution providers. @@ -109,6 +110,8 @@ def download_and_register_eps( all discoverable EPs are downloaded. progress_callback: Optional callback ``(ep_name: str, percent: float) -> None`` invoked as each EP downloads. ``percent`` is 0-100. + cancel_event: Optional ``threading.Event`` that signals cancellation + when set. The download will be cancelled at the next progress update. Returns: ``EpDownloadResult`` describing operation status and per-EP outcomes. 
@@ -120,19 +123,22 @@ def download_and_register_eps( if names is not None and len(names) > 0: request = InteropRequest(params={"Names": ",".join(names)}) - if progress_callback is not None: + if progress_callback is not None or cancel_event is not None: + user_cb = progress_callback + def _on_chunk(chunk: str) -> None: - sep = chunk.find("|") - if sep >= 0: - ep_name = chunk[:sep] or "" - try: - percent = float(chunk[sep + 1:]) - progress_callback(ep_name, percent) - except ValueError: - pass + if user_cb is not None: + sep = chunk.find("|") + if sep >= 0: + ep_name = chunk[:sep] or "" + try: + percent = float(chunk[sep + 1:]) + user_cb(ep_name, percent) + except ValueError: + pass response = self._core_interop.execute_command_with_callback( - "download_and_register_eps", request, _on_chunk + "download_and_register_eps", request, _on_chunk, cancel_event ) else: response = self._core_interop.execute_command("download_and_register_eps", request) diff --git a/sdk/python/src/imodel.py b/sdk/python/src/imodel.py index f723e514..fc63f374 100644 --- a/sdk/python/src/imodel.py +++ b/sdk/python/src/imodel.py @@ -5,6 +5,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from threading import Event from typing import Callable, List, Optional from .openai.chat_client import ChatClient @@ -76,10 +77,13 @@ def supports_tool_calling(self) -> Optional[bool]: pass @abstractmethod - def download(self, progress_callback: Callable[[float], None] = None) -> None: + def download(self, progress_callback: Callable[[float], None] = None, + cancel_event: Optional[Event] = None) -> None: """ Download the model to local cache if not already present. :param progress_callback: Optional callback function for download progress as a percentage (0.0 to 100.0). + :param cancel_event: Optional ``threading.Event``. When set, the download will be + cancelled at the next progress update and ``FoundryLocalException`` is raised. 
""" pass diff --git a/sdk/python/test/test_foundry_local_manager.py b/sdk/python/test/test_foundry_local_manager.py index 31528891..3abb37f6 100644 --- a/sdk/python/test/test_foundry_local_manager.py +++ b/sdk/python/test/test_foundry_local_manager.py @@ -6,6 +6,10 @@ from __future__ import annotations +import threading + +from foundry_local_sdk.foundry_local_manager import FoundryLocalManager + class _Response: def __init__(self, data=None, error=None): @@ -22,6 +26,12 @@ def execute_command(self, command_name, command_input=None): self.calls.append((command_name, command_input)) return self._responses[command_name] + def execute_command_with_callback( + self, command_name, command_input=None, callback=None, cancel_event=None + ): + self.calls.append((command_name, command_input, callback, cancel_event)) + return self._responses[command_name] + class TestFoundryLocalManager: """Foundry Local Manager Tests.""" @@ -81,3 +91,36 @@ def test_download_and_register_eps_returns_result(self, manager): assert result.status == "ok" assert result.registered_eps == ["CUDAExecutionProvider"] assert result.failed_eps == [] + + def test_download_and_register_eps_uses_callback_path_when_cancel_event_is_provided(self): + fake_core = _FakeCoreInterop( + { + "download_and_register_eps": _Response( + data=( + '{"Success":true,"Status":"ok",' + '"RegisteredEps":["CUDAExecutionProvider"],"FailedEps":[]}' + ), + error=None, + ) + } + ) + manager = FoundryLocalManager.__new__(FoundryLocalManager) + manager._core_interop = fake_core + manager.catalog = type( + "_FakeCatalog", + (), + {"_invalidate_cache": staticmethod(lambda: None)}, + )() + cancel_event = threading.Event() + + result = manager.download_and_register_eps( + ["CUDAExecutionProvider"], cancel_event=cancel_event + ) + + assert result.success is True + assert len(fake_core.calls) == 1 + command_name, command_input, callback, seen_cancel_event = fake_core.calls[0] + assert command_name == "download_and_register_eps" + assert 
command_input.params == {"Names": "CUDAExecutionProvider"} + assert callable(callback) + assert seen_cancel_event is cancel_event diff --git a/sdk/python/test/test_model.py b/sdk/python/test/test_model.py index e2ea1509..cd2af9ef 100644 --- a/sdk/python/test/test_model.py +++ b/sdk/python/test/test_model.py @@ -6,6 +6,12 @@ from __future__ import annotations +import threading + +from types import SimpleNamespace + +from foundry_local_sdk.detail.model_variant import ModelVariant + from .conftest import TEST_MODEL_ALIAS, AUDIO_MODEL_ALIAS @@ -86,3 +92,44 @@ def test_should_expose_supports_tool_calling(self, catalog): assert model is not None stc = model.supports_tool_calling assert stc is None or isinstance(stc, bool) + + def test_download_should_use_callback_path_when_cancel_event_is_provided(self): + """Model download should route through callback interop when cancellation is enabled.""" + + class _Response: + def __init__(self, data=None, error=None): + self.data = data + self.error = error + + class _FakeCoreInterop: + def __init__(self): + self.calls = [] + + def execute_command(self, command_name, command_input=None): + raise AssertionError( + "download should not use execute_command when cancel_event is provided" + ) + + def execute_command_with_callback( + self, command_name, command_input=None, callback=None, cancel_event=None + ): + self.calls.append((command_name, command_input, callback, cancel_event)) + return _Response(data="", error=None) + + fake_core = _FakeCoreInterop() + cancel_event = threading.Event() + variant = ModelVariant.__new__(ModelVariant) + variant._model_info = SimpleNamespace(id="test-model-cpu:1", alias=TEST_MODEL_ALIAS) + variant._id = "test-model-cpu:1" + variant._alias = TEST_MODEL_ALIAS + variant._core_interop = fake_core + variant._model_load_manager = None + + variant.download(cancel_event=cancel_event) + + assert len(fake_core.calls) == 1 + command_name, command_input, callback, seen_cancel_event = fake_core.calls[0] + assert 
command_name == "download_model" + assert command_input.params == {"Model": "test-model-cpu:1"} + assert callable(callback) + assert seen_cancel_event is cancel_event diff --git a/sdk/rust/README.md b/sdk/rust/README.md index ce97a7dd..d017ce5e 100644 --- a/sdk/rust/README.md +++ b/sdk/rust/README.md @@ -107,6 +107,28 @@ manager.download_and_register_eps_with_progress(None, move |ep_name: &str, perce println!(); ``` +#### Cancelling model and EP downloads + +Use a shared `Arc<AtomicBool>` with the cancellable download APIs. Set the flag from another task or signal handler to stop the in-progress download. + +```rust +use std::sync::{ + Arc, + atomic::AtomicBool, +}; + +// manager and model already initialized +let cancel_flag = Arc::new(AtomicBool::new(false)); +// call cancel_flag.store(true, ...) from another task or signal handler to cancel + +manager + .download_and_register_eps_cancellable(None, Arc::clone(&cancel_flag)) + .await?; +model + .download_cancellable(None::<fn(f64)>, Arc::clone(&cancel_flag)) + .await?; +``` + Catalog access does not block on EP downloads. Call `download_and_register_eps` when you need hardware-accelerated execution providers. 
## Quick Start diff --git a/sdk/rust/src/detail/core_interop.rs b/sdk/rust/src/detail/core_interop.rs index 43884d7f..881e5ef3 100644 --- a/sdk/rust/src/detail/core_interop.rs +++ b/sdk/rust/src/detail/core_interop.rs @@ -9,6 +9,7 @@ use std::ffi::CString; use std::path::{Path, PathBuf}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use libloading::{Library, Symbol}; @@ -126,6 +127,7 @@ unsafe fn free_native_buffer(ptr: *mut u8) { struct StreamingCallbackState<'a> { callback: &'a mut dyn FnMut(&str), buf: Vec, + cancel_flag: Option>, } impl<'a> StreamingCallbackState<'a> { @@ -133,9 +135,25 @@ impl<'a> StreamingCallbackState<'a> { Self { callback, buf: Vec::new(), + cancel_flag: None, } } + fn with_cancel(callback: &'a mut dyn FnMut(&str), cancel_flag: Arc) -> Self { + Self { + callback, + buf: Vec::new(), + cancel_flag: Some(cancel_flag), + } + } + + /// Returns `true` if cancellation has been requested. + fn is_cancelled(&self) -> bool { + self.cancel_flag + .as_ref() + .is_some_and(|f| f.load(Ordering::Relaxed)) + } + /// Append raw bytes, decode as much valid UTF-8 as possible, and forward /// complete text to the callback. Any trailing incomplete multi-byte /// sequence is kept in the buffer for the next call. Invalid byte @@ -208,15 +226,21 @@ unsafe extern "C" fn streaming_trampoline( // by the caller of `execute_command_with_callback` for the duration of // the native call. let state = &mut *(user_data as *mut StreamingCallbackState<'_>); + + // Check for cancellation before processing the chunk. + if state.is_cancelled() { + return 1; // cancel + } + // SAFETY: `data` is valid for `length` bytes as guaranteed by the native // core's callback contract. 
let slice = std::slice::from_raw_parts(data, length as usize); state.push(slice); + 0 // continue })); - if result.is_err() { - 1 - } else { - 0 + match result { + Ok(ret) => ret, + Err(_) => 1, } } @@ -368,6 +392,32 @@ impl CoreInterop { where F: FnMut(&str), { + self.execute_command_streaming_impl(command, params, &mut callback, None) + } + + /// Like [`Self::execute_command_streaming`], but accepts a cancellation + /// flag. When `cancel_flag` is set to `true`, the native call will be + /// cancelled at the next callback invocation and an error is returned. + pub fn execute_command_streaming_cancellable( + &self, + command: &str, + params: Option<&Value>, + mut callback: F, + cancel_flag: Arc, + ) -> Result + where + F: FnMut(&str), + { + self.execute_command_streaming_impl(command, params, &mut callback, Some(cancel_flag)) + } + + fn execute_command_streaming_impl( + &self, + command: &str, + params: Option<&Value>, + callback: &mut dyn FnMut(&str), + cancel_flag: Option>, + ) -> Result { let cmd = CString::new(command).map_err(|e| FoundryLocalError::CommandExecution { reason: format!("Invalid command string: {e}"), })?; @@ -392,8 +442,10 @@ impl CoreInterop { // Wrap the closure in a StreamingCallbackState that handles partial // UTF-8 sequences split across native callbacks. - let mut cb = |chunk: &str| callback(chunk); - let mut state = StreamingCallbackState::new(&mut cb); + let mut state = match cancel_flag { + Some(flag) => StreamingCallbackState::with_cancel(callback, flag), + None => StreamingCallbackState::new(callback), + }; let user_data = &mut state as *mut StreamingCallbackState<'_> as *mut std::ffi::c_void; // SAFETY: `request` fields point into `cmd` and `data_cstr` which are @@ -410,9 +462,19 @@ impl CoreInterop { ); } + let cancelled = state.is_cancelled(); + // Flush any trailing partial UTF-8 bytes. state.flush(); + if cancelled { + // Free native response memory before returning the error. 
+ Self::process_response(response).ok(); + return Err(FoundryLocalError::CommandExecution { + reason: "Operation cancelled".to_string(), + }); + } + Self::process_response(response) } @@ -456,6 +518,36 @@ impl CoreInterop { })? } + /// Async version of [`Self::execute_command_streaming_cancellable`]. + /// + /// Accepts a shared cancellation flag (`Arc`). When the flag + /// is set to `true`, the native call will be cancelled at the next + /// callback invocation and an error is returned. + pub async fn execute_command_streaming_cancellable_async( + self: &Arc, + command: String, + params: Option, + callback: F, + cancel_flag: Arc, + ) -> Result + where + F: FnMut(&str) + Send + 'static, + { + let this = Arc::clone(self); + tokio::task::spawn_blocking(move || { + this.execute_command_streaming_cancellable( + &command, + params.as_ref(), + callback, + cancel_flag, + ) + }) + .await + .map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("task join error: {e}"), + })? + } + /// Async streaming variant that bridges the FFI callback into a /// [`tokio::sync::mpsc`] channel. /// diff --git a/sdk/rust/src/detail/model.rs b/sdk/rust/src/detail/model.rs index 08288aee..5921fbcd 100644 --- a/sdk/rust/src/detail/model.rs +++ b/sdk/rust/src/detail/model.rs @@ -6,7 +6,7 @@ use std::fmt; use std::path::PathBuf; -use std::sync::atomic::{AtomicUsize, Ordering::Relaxed}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering::Relaxed}; use std::sync::Arc; use super::core_interop::CoreInterop; @@ -213,6 +213,23 @@ impl Model { self.selected_variant().download(progress).await } + /// Like [`Self::download`], but accepts a shared cancellation flag + /// (`Arc`). When the flag is set to `true`, the download + /// will be cancelled at the next progress callback and an error is + /// returned. 
+ pub async fn download_cancellable( + &self, + progress: Option, + cancel_flag: Arc, + ) -> Result<()> + where + F: FnMut(f64) + Send + 'static, + { + self.selected_variant() + .download_cancellable(progress, cancel_flag) + .await + } + /// Return the local file-system path of the (selected) variant. pub async fn path(&self) -> Result { self.selected_variant().path().await diff --git a/sdk/rust/src/detail/model_variant.rs b/sdk/rust/src/detail/model_variant.rs index 1f8ce7d5..a49aae21 100644 --- a/sdk/rust/src/detail/model_variant.rs +++ b/sdk/rust/src/detail/model_variant.rs @@ -5,6 +5,7 @@ use std::fmt; use std::path::PathBuf; +use std::sync::atomic::AtomicBool; use std::sync::Arc; use serde_json::json; @@ -88,12 +89,54 @@ impl ModelVariant { } pub(crate) async fn download(&self, progress: Option) -> Result<()> + where + F: FnMut(f64) + Send + 'static, + { + self.download_impl(progress, None).await + } + + /// Like [`Self::download`], but accepts a shared cancellation flag. + /// When `cancel_flag` is set to `true`, the download will be cancelled at + /// the next progress callback. 
+ pub(crate) async fn download_cancellable( + &self, + progress: Option, + cancel_flag: Arc, + ) -> Result<()> + where + F: FnMut(f64) + Send + 'static, + { + self.download_impl(progress, Some(cancel_flag)).await + } + + async fn download_impl( + &self, + progress: Option, + cancel_flag: Option>, + ) -> Result<()> where F: FnMut(f64) + Send + 'static, { let params = json!({ "Params": { "Model": self.info.id } }); - match progress { - Some(mut cb) => { + match (progress, cancel_flag) { + (Some(mut cb), Some(flag)) => { + let wrapper = move |chunk: &str| { + for token in chunk.split_whitespace() { + if let Ok(pct) = token.parse::() { + cb(pct); + } + } + }; + self.core + .execute_command_streaming_cancellable_async( + "download_model".into(), + Some(params), + wrapper, + flag, + ) + .await?; + } + (Some(mut cb), None) => { let wrapper = move |chunk: &str| { for token in chunk.split_whitespace() { if let Ok(pct) = token.parse::() { @@ -105,7 +148,19 @@ impl ModelVariant { .execute_command_streaming_async("download_model".into(), Some(params), wrapper) .await?; } - None => { + (None, Some(flag)) => { + // Use a no-op callback to engage the callback mechanism + // required for cancellation checks. + self.core + .execute_command_streaming_cancellable_async( + "download_model".into(), + Some(params), + |_: &str| {}, + flag, + ) + .await?; + } + (None, None) => { self.core .execute_command_async("download_model".into(), Some(params)) .await?; diff --git a/sdk/rust/src/foundry_local_manager.rs b/sdk/rust/src/foundry_local_manager.rs index 0c22ef15..a14b42b7 100644 --- a/sdk/rust/src/foundry_local_manager.rs +++ b/sdk/rust/src/foundry_local_manager.rs @@ -4,6 +4,7 @@ //! library, provides access to the model [`Catalog`], and can start / stop //! the local web service. 
+use std::sync::atomic::AtomicBool; use std::sync::{Arc, Mutex, OnceLock}; use serde_json::json; @@ -150,7 +151,19 @@ impl FoundryLocalManager { &self, names: Option<&[&str]>, ) -> Result { - self.download_and_register_eps_impl(names, None::) + self.download_and_register_eps_impl(names, None::, None) + .await + } + + /// Like [`Self::download_and_register_eps`], but accepts a shared + /// cancellation flag (`Arc`). When the flag is set to `true`, + /// the download will be cancelled at the next progress callback. + pub async fn download_and_register_eps_cancellable( + &self, + names: Option<&[&str]>, + cancel_flag: Arc, + ) -> Result { + self.download_and_register_eps_impl(names, None::, Some(cancel_flag)) .await } @@ -169,7 +182,23 @@ impl FoundryLocalManager { where F: FnMut(&str, f64) + Send + 'static, { - self.download_and_register_eps_impl(names, Some(progress_callback)) + self.download_and_register_eps_impl(names, Some(progress_callback), None) + .await + } + + /// Like [`Self::download_and_register_eps_with_progress`], but accepts a + /// shared cancellation flag (`Arc`). When the flag is set to + /// `true`, the download will be cancelled at the next progress callback. 
+ pub async fn download_and_register_eps_with_progress_cancellable( + &self, + names: Option<&[&str]>, + progress_callback: F, + cancel_flag: Arc, + ) -> Result + where + F: FnMut(&str, f64) + Send + 'static, + { + self.download_and_register_eps_impl(names, Some(progress_callback), Some(cancel_flag)) .await } @@ -177,6 +206,7 @@ impl FoundryLocalManager { &self, names: Option<&[&str]>, progress_callback: Option, + cancel_flag: Option>, ) -> Result where F: FnMut(&str, f64) + Send + 'static, @@ -186,8 +216,28 @@ impl FoundryLocalManager { _ => None, }; - let raw = match progress_callback { - Some(cb) => { + let raw = match (progress_callback, cancel_flag) { + (Some(cb), Some(flag)) => { + let mut callback = cb; + let wrapper = move |chunk: &str| { + if let Some(sep) = chunk.find('|') { + let name = &chunk[..sep]; + if let Ok(percent) = chunk[sep + 1..].parse::() { + callback(if name.is_empty() { "" } else { name }, percent); + } + } + }; + + self.core + .execute_command_streaming_cancellable_async( + "download_and_register_eps".into(), + params, + wrapper, + flag, + ) + .await? + } + (Some(cb), None) => { let mut callback = cb; let wrapper = move |chunk: &str| { if let Some(sep) = chunk.find('|') { @@ -206,7 +256,17 @@ impl FoundryLocalManager { ) .await? } - None => { + (None, Some(flag)) => { + self.core + .execute_command_streaming_cancellable_async( + "download_and_register_eps".into(), + params, + |_chunk: &str| {}, + flag, + ) + .await? + } + (None, None) => { self.core .execute_command_async("download_and_register_eps".into(), params) .await?