diff --git a/mediapipe/tasks/web/genai/llm_inference/BUILD b/mediapipe/tasks/web/genai/llm_inference/BUILD
index 2ad83adb35..ceade59277 100644
--- a/mediapipe/tasks/web/genai/llm_inference/BUILD
+++ b/mediapipe/tasks/web/genai/llm_inference/BUILD
@@ -21,6 +21,7 @@ mediapipe_ts_library(
     deps = [
+        ":efficient_model_loader",
         ":llm_inference_types",
         ":model_loading_utils",
         "//mediapipe/framework:calculator_jspb_proto",
         "//mediapipe/tasks/cc/core/proto:base_options_jspb_proto",
         "//mediapipe/tasks/cc/genai/inference/calculators:detokenizer_calculator_jspb_proto",
@@ -58,3 +59,11 @@ mediapipe_ts_library(
     ],
     visibility = ["//visibility:public"],
 )
+
+mediapipe_ts_library(
+    name = "efficient_model_loader",
+    srcs = [
+        "efficient_model_loader.ts",
+    ],
+    deps = [],
+)
diff --git a/mediapipe/tasks/web/genai/llm_inference/efficient_model_loader.ts b/mediapipe/tasks/web/genai/llm_inference/efficient_model_loader.ts
new file mode 100644
index 0000000000..625ade68d0
--- /dev/null
+++ b/mediapipe/tasks/web/genai/llm_inference/efficient_model_loader.ts
@@ -0,0 +1,82 @@
+/**
+ * Copyright 2025 The MediaPipe Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Optimized model loading utilities for LLM inference.
+ */
+
+/**
+ * Creates a streaming reader over the model at `modelAssetPath`. Rejects if
+ * the fetch fails, returns a non-OK status, or has no response body.
+ */
+export async function createModelStream(
+  modelAssetPath: string,
+  signal?: AbortSignal,
+): Promise<ReadableStreamDefaultReader<Uint8Array>> {
+  const response = await fetch(modelAssetPath, {signal});
+
+  if (!response.ok) {
+    throw new Error(
+      `Failed to fetch model: ${modelAssetPath} (${response.status})`,
+    );
+  }
+
+  if (!response.body) {
+    throw new Error(`Failed to fetch model: ${modelAssetPath} (no body)`);
+  }
+
+  return response.body.getReader();
+}
+
+/**
+ * Model loader with cancellation support. At most one load is in flight at a
+ * time; starting a new load aborts the previous one.
+ */
+export class ModelLoader {
+  private abortController?: AbortController;
+
+  /**
+   * Fetches the model at `modelAssetPath`, aborting any load that is already
+   * in flight.
+   */
+  async loadModel(
+    modelAssetPath: string,
+  ): Promise<ReadableStreamDefaultReader<Uint8Array>> {
+    this.cancel();
+    const controller = new AbortController();
+    this.abortController = controller;
+    try {
+      return await createModelStream(modelAssetPath, controller.signal);
+    } finally {
+      // Clear the controller once the fetch settles, unless a newer load has
+      // already replaced it.
+      if (this.abortController === controller) {
+        this.abortController = undefined;
+      }
+    }
+  }
+
+  /** Aborts the in-flight load, if any. */
+  cancel(): void {
+    this.abortController?.abort();
+    this.abortController = undefined;
+  }
+
+  /** Returns true while a model fetch is in flight. */
+  isLoading(): boolean {
+    return !!this.abortController;
+  }
+}
\ No newline at end of file
diff --git a/mediapipe/tasks/web/genai/llm_inference/efficient_model_loader_test.ts b/mediapipe/tasks/web/genai/llm_inference/efficient_model_loader_test.ts
new file mode 100644
index 0000000000..bdc0ffbea0
--- /dev/null
+++ b/mediapipe/tasks/web/genai/llm_inference/efficient_model_loader_test.ts
@@ -0,0 +1,147 @@
+/**
+ * Copyright 2025 The MediaPipe Authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import 'jasmine';
+
+import {createModelStream, ModelLoader} from './efficient_model_loader';
+
+/**
+ * Returns a mock `Response` whose body streams `data` in a single chunk.
+ */
+function fakeStreamingResponse(data: Uint8Array): Partial<Response> {
+  return {
+    ok: true,
+    status: 200,
+    body: new ReadableStream({
+      start(controller) {
+        controller.enqueue(data);
+        controller.close();
+      },
+    }),
+  };
+}
+
+/**
+ * Returns a fetch implementation that never resolves on its own but rejects
+ * with an `AbortError` when its signal is aborted, mirroring real fetch.
+ */
+function hangingFetch(): typeof fetch {
+  return (_input, init) =>
+    new Promise<Response>((_, reject) => {
+      init?.signal?.addEventListener('abort', () => {
+        reject(new DOMException('Aborted', 'AbortError'));
+      });
+    });
+}
+
+describe('EfficientModelLoader', () => {
+  let mockFetch: jasmine.Spy;
+  let originalFetch: typeof fetch;
+
+  beforeEach(() => {
+    originalFetch = globalThis.fetch;
+    mockFetch = jasmine.createSpy('fetch');
+    globalThis.fetch = mockFetch;
+  });
+
+  afterEach(() => {
+    globalThis.fetch = originalFetch;
+  });
+
+  describe('createModelStream', () => {
+    it('should create a stream from a successful fetch', async () => {
+      const mockData = new Uint8Array([1, 2, 3, 4, 5]);
+      mockFetch.and.returnValue(
+        Promise.resolve(fakeStreamingResponse(mockData)),
+      );
+
+      const stream = await createModelStream('http://example.com/model.bin');
+      const {value} = await stream.read();
+
+      expect(value).toEqual(mockData);
+    });
+
+    it('should throw an error for a failed fetch', async () => {
+      mockFetch.and.returnValue(Promise.resolve({ok: false, status: 404}));
+
+      await expectAsync(
+        createModelStream('http://example.com/nonexistent.bin'),
+      ).toBeRejectedWithError(/Failed to fetch model.*404/);
+    });
+  });
+
+  describe('ModelLoader', () => {
+    let loader: ModelLoader;
+
+    beforeEach(() => {
+      loader = new ModelLoader();
+    });
+
+    afterEach(() => {
+      loader.cancel();
+    });
+
+    it('should load a model successfully', async () => {
+      const mockData = new Uint8Array([1, 2, 3]);
+      mockFetch.and.returnValue(
+        Promise.resolve(fakeStreamingResponse(mockData)),
+      );
+
+      const stream = await loader.loadModel('http://example.com/model.bin');
+      const {value} = await stream.read();
+
+      expect(value).toEqual(mockData);
+    });
+
+    it('should track loading state', async () => {
+      mockFetch.and.callFake(hangingFetch());
+
+      expect(loader.isLoading()).toBeFalse();
+
+      const loadPromise = loader.loadModel('http://example.com/model.bin');
+      expect(loader.isLoading()).toBeTrue();
+
+      loader.cancel();
+      await expectAsync(loadPromise).toBeRejected();
+      expect(loader.isLoading()).toBeFalse();
+    });
+
+    it('should cancel the previous load when a new load starts', async () => {
+      mockFetch.and.callFake(hangingFetch());
+
+      const firstLoad = loader.loadModel('http://example.com/model1.bin');
+      expect(loader.isLoading()).toBeTrue();
+
+      const secondLoad = loader.loadModel('http://example.com/model2.bin');
+      await expectAsync(firstLoad).toBeRejected();
+      expect(loader.isLoading()).toBeTrue();
+
+      loader.cancel();
+      await expectAsync(secondLoad).toBeRejected();
+      expect(loader.isLoading()).toBeFalse();
+    });
+
+    it('should handle loading errors gracefully', async () => {
+      mockFetch.and.returnValue(Promise.reject(new Error('Network failure')));
+
+      await expectAsync(
+        loader.loadModel('http://example.com/model.bin'),
+      ).toBeRejected();
+
+      expect(loader.isLoading()).toBeFalse();
+    });
+  });
+});
\ No newline at end of file
diff --git a/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts b/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
index d104e5741b..b518a65037 100644
--- a/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
+++ b/mediapipe/tasks/web/genai/llm_inference/llm_inference.ts
@@ -59,6 +59,7 @@ import {
   tee,
   uint8ArrayToStream,
 } from './model_loading_utils';
+import {ModelLoader} from './efficient_model_loader';
 
 export type {
   Audio,
@@ -181,6 +182,7 @@ export class LlmInference extends TaskRunner {
   private streamingReader?: StreamingReader;
   private useLlmEngine = false;
   private isConvertedModel = false;
+  private readonly modelLoader = new ModelLoader();
 
   // The WebGPU device used for LLM inference.
   private wgpuDevice?: GPUDevice;
@@ -394,7 +396,7 @@ export class LlmInference extends TaskRunner {
   override async setOptions(options: LlmInferenceOptions): Promise<void> {
     // TODO: b/324482487 - Support customizing config for Web task of LLM
     // Inference.
-    if (this.isProcessing) {
+    if (this.isProcessing || this.modelLoader.isLoading()) {
       throw new Error('Cannot set options while loading or processing.');
     }
 
@@ -414,20 +416,15 @@ export class LlmInference extends TaskRunner {
     let modelStream: ReadableStreamDefaultReader<Uint8Array> | undefined;
 
     if (options.baseOptions?.modelAssetPath) {
-      const request = await fetch(
-        options.baseOptions.modelAssetPath.toString(),
-      );
-      if (!request.ok) {
-        throw new Error(
-          `Failed to fetch model: ${options.baseOptions.modelAssetPath} (${request.status})`,
-        );
-      }
-      if (!request.body) {
-        throw new Error(
-          `Failed to fetch model: ${options.baseOptions.modelAssetPath} (no body)`,
-        );
-      }
-      modelStream = request.body.getReader();
+      try {
+        modelStream = await this.modelLoader.loadModel(
+          options.baseOptions.modelAssetPath.toString(),
+        );
+      } catch (error) {
+        throw new Error(
+          `Failed to load model from path: ${options.baseOptions.modelAssetPath}. ${error}`,
+        );
+      }
     } else if (options.baseOptions?.modelAssetBuffer instanceof Uint8Array) {
       modelStream = uint8ArrayToStream(
         options.baseOptions.modelAssetBuffer,
@@ -1385,6 +1382,8 @@ export class LlmInference extends TaskRunner {
   }
 
   override close() {
+    this.modelLoader.cancel();
+
     if (this.useLlmEngine) {
       (
         this.graphRunner as unknown as LlmGraphRunner
diff --git a/mediapipe/tasks/web/genai/llm_inference/llm_inference_test.ts b/mediapipe/tasks/web/genai/llm_inference/llm_inference_test.ts
index f2ac5ad46f..01fec86758 100644
--- a/mediapipe/tasks/web/genai/llm_inference/llm_inference_test.ts
+++ b/mediapipe/tasks/web/genai/llm_inference/llm_inference_test.ts
@@ -77,5 +77,28 @@ describe('LlmInference', () => {
     expect(llmInference).toBeDefined();
   });
 
+  it('loads a model from modelAssetPath', async () => {
+    const pathOptions = {
+      baseOptions: {modelAssetPath: modelUrl},
+      numResponses: 1,
+    };
+
+    llmInference = await LlmInference.createFromOptions(
+      genaiFileset,
+      pathOptions,
+    );
+    expect(llmInference).toBeDefined();
+  });
+
+  it('handles modelAssetPath loading errors', async () => {
+    const invalidPathOptions = {
+      baseOptions: {modelAssetPath: 'http://invalid-url/nonexistent.bin'},
+      numResponses: 1,
+    };
+
+    await expectAsync(
+      LlmInference.createFromOptions(genaiFileset, invalidPathOptions),
+    ).toBeRejectedWithError(/Failed to load model from path/);
+  });
+
   it('loads a model, deletes it, and then loads it again', async () => {
     llmInference = await load();
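
Usage sketch (illustrative, not part of the patch): the snippet below drives the new ModelLoader directly and drains the returned reader, which is roughly what LlmInference.setOptions does with the stream internally. The model URL and the countModelBytes helper are hypothetical.

import {ModelLoader} from './efficient_model_loader';

// Hypothetical model URL, used only for illustration.
const MODEL_URL = 'https://example.com/model.task';

async function countModelBytes(): Promise<number> {
  const loader = new ModelLoader();

  // Starting a load aborts any previous in-flight load on the same loader.
  const reader = await loader.loadModel(MODEL_URL);

  // Drain the stream chunk by chunk; in LlmInference the reader is handed
  // to the task's model-loading pipeline instead.
  let totalBytes = 0;
  while (true) {
    const {done, value} = await reader.read();
    if (done) {
      break;
    }
    totalBytes += value.length;
  }
  return totalBytes;
}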