diff --git a/libs/providers/baseten/package.json b/libs/providers/baseten/package.json new file mode 100644 index 000000000..58d41f1d8 --- /dev/null +++ b/libs/providers/baseten/package.json @@ -0,0 +1,73 @@ +{ + "name": "@langchain/baseten", + "version": "1.0.0-alpha.0", + "description": "Baseten LLM provider for LangChain and deepagents", + "main": "./dist/index.cjs", + "module": "./dist/index.js", + "types": "./dist/index.d.ts", + "type": "module", + "scripts": { + "build": "tsdown", + "clean": "rm -rf dist/ .tsdown/", + "dev": "tsc --watch", + "typecheck": "tsc --noEmit", + "prepublishOnly": "pnpm build", + "test": "vitest run", + "test:unit": "vitest run", + "test:int": "vitest run --mode int" + }, + "repository": { + "type": "git", + "url": "git+https://github.com/langchain-ai/deepagentsjs.git" + }, + "keywords": [ + "ai", + "agents", + "langgraph", + "langchain", + "typescript", + "llm", + "baseten", + "chat-model" + ], + "author": "LangChain", + "license": "MIT", + "bugs": { + "url": "https://github.com/langchain-ai/deepagentsjs/issues" + }, + "homepage": "https://github.com/langchain-ai/deepagentsjs#readme", + "dependencies": { + "@langchain/openai": "^1.4.1" + }, + "peerDependencies": { + "@langchain/core": "^1.1.38" + }, + "devDependencies": { + "deepagents": "workspace:*", + "@langchain/core": "^1.1.38", + "@tsconfig/recommended": "^1.0.13", + "@types/node": "^25.1.0", + "@vitest/coverage-v8": "^4.0.18", + "dotenv": "^17.2.3", + "tsdown": "^0.21.4", + "tsx": "^4.21.0", + "typescript": "^6.0.2", + "vitest": "^4.0.18" + }, + "exports": { + ".": { + "import": { + "types": "./dist/index.d.ts", + "default": "./dist/index.js" + }, + "require": { + "types": "./dist/index.d.cts", + "default": "./dist/index.cjs" + } + }, + "./package.json": "./package.json" + }, + "files": [ + "dist/**/*" + ] +} diff --git a/libs/providers/baseten/src/baseten.int.test.ts b/libs/providers/baseten/src/baseten.int.test.ts new file mode 100644 index 000000000..681bd135c --- 
/dev/null +++ b/libs/providers/baseten/src/baseten.int.test.ts @@ -0,0 +1,65 @@ +import { describe, it, expect } from "vitest"; +import { HumanMessage } from "@langchain/core/messages"; +import { ChatBaseten } from "./baseten.js"; +import { createDeepAgent } from "deepagents"; + +const BASETEN_MODEL = "deepseek-ai/DeepSeek-V3.1"; + +describe("ChatBaseten Integration Tests", () => { + it( + "should invoke ChatBaseten directly", + { timeout: 60_000 }, + async () => { + const model = new ChatBaseten({ model: BASETEN_MODEL }); + + const result = await model.invoke([ + new HumanMessage("What is 2 + 2? Answer with just the number."), + ]); + + expect(result.content).toBeTruthy(); + expect(typeof result.content).toBe("string"); + expect(result.content).toContain("4"); + }, + ); + + it( + "should stream responses from ChatBaseten", + { timeout: 60_000 }, + async () => { + const model = new ChatBaseten({ model: BASETEN_MODEL }); + + const chunks: string[] = []; + for await (const chunk of await model.stream( + "Say the word 'hello' and nothing else.", + )) { + if (typeof chunk.content === "string") { + chunks.push(chunk.content); + } + } + + const fullResponse = chunks.join(""); + expect(fullResponse.toLowerCase()).toContain("hello"); + }, + ); + + it( + "should work with createDeepAgent", + { timeout: 90_000 }, + async () => { + const model = new ChatBaseten({ model: BASETEN_MODEL }); + + const agent = createDeepAgent({ model }); + + const result = await agent.invoke({ + messages: [ + new HumanMessage( + "What is the capital of France? 
Answer in one word.", + ), + ], + }); + + const lastMessage = result.messages[result.messages.length - 1]; + expect(lastMessage.content).toBeTruthy(); + }, + ); +}); diff --git a/libs/providers/baseten/src/baseten.test.ts b/libs/providers/baseten/src/baseten.test.ts new file mode 100644 index 000000000..8af380c82 --- /dev/null +++ b/libs/providers/baseten/src/baseten.test.ts @@ -0,0 +1,633 @@ +import { describe, it, expect, beforeEach, afterEach, vi } from "vitest"; +import { + ChatBaseten, + normalizeToolCallChunks, + normalizeModelUrl, +} from "./index.js"; +import { DEFAULT_BASE_URL, DEFAULT_API_KEY_ENV_VAR } from "./types.js"; +import { AIMessageChunk } from "@langchain/core/messages"; +import { ChatGenerationChunk } from "@langchain/core/outputs"; +import type { ToolCallChunk } from "@langchain/core/messages/tool"; + +const TEST_API_KEY = "test-baseten-api-key"; +const TEST_MODEL = "deepseek-ai/DeepSeek-V3.1"; + +describe("ChatBaseten", () => { + let originalEnv: string | undefined; + + beforeEach(() => { + originalEnv = process.env[DEFAULT_API_KEY_ENV_VAR]; + delete process.env[DEFAULT_API_KEY_ENV_VAR]; + }); + + afterEach(() => { + if (originalEnv !== undefined) { + process.env[DEFAULT_API_KEY_ENV_VAR] = originalEnv; + } else { + delete process.env[DEFAULT_API_KEY_ENV_VAR]; + } + }); + + describe("constructor", () => { + it("sets default base URL to Baseten inference endpoint", () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + }); + + expect((model as any).clientConfig.baseURL).toBe(DEFAULT_BASE_URL); + }); + + it("allows overriding base URL for self-deployed models", () => { + const customURL = "https://model-abc123.api.baseten.co/v1"; + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + baseURL: customURL, + }); + + expect((model as any).clientConfig.baseURL).toBe(customURL); + }); + + it("sets the model name correctly", () => { + const model = new ChatBaseten({ + model: 
TEST_MODEL, + basetenApiKey: TEST_API_KEY, + }); + + expect(model.model).toBe(TEST_MODEL); + }); + + it("passes through additional ChatOpenAI options", () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + temperature: 0.5, + maxTokens: 1024, + }); + + expect(model.temperature).toBe(0.5); + expect(model.maxTokens).toBe(1024); + }); + + it("preserves additional configuration options alongside baseURL", () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + configuration: { + defaultHeaders: { "X-Custom": "value" }, + }, + }); + + expect((model as any).clientConfig.baseURL).toBe(DEFAULT_BASE_URL); + expect((model as any).clientConfig.defaultHeaders).toEqual({ + "X-Custom": "value", + }); + }); + + it("enables streamUsage by default", () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + }); + + expect(model.streamUsage).toBe(true); + }); + + it("allows disabling streamUsage explicitly", () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + streamUsage: false, + }); + + expect(model.streamUsage).toBe(false); + }); + + it("normalizes modelUrl and uses it as baseURL", () => { + const model = new ChatBaseten({ + model: "custom-model", + modelUrl: + "https://model-abc123.api.baseten.co/environments/production/predict", + basetenApiKey: TEST_API_KEY, + }); + + expect((model as any).clientConfig.baseURL).toBe( + "https://model-abc123.api.baseten.co/environments/production/sync/v1", + ); + }); + + it("infers model name from modelUrl when model is omitted", () => { + const model = new ChatBaseten({ + modelUrl: + "https://model-xyz789.api.baseten.co/environments/production/sync/v1", + basetenApiKey: TEST_API_KEY, + }); + + expect(model.model).toBe("model-xyz789"); + }); + + it("prefers explicit model over URL-inferred name", () => { + const model = new ChatBaseten({ + model: "my-org/my-model", + modelUrl: + 
"https://model-abc123.api.baseten.co/environments/production/sync/v1", + basetenApiKey: TEST_API_KEY, + }); + + expect(model.model).toBe("my-org/my-model"); + }); + + it("modelUrl overrides baseURL", () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + modelUrl: + "https://model-abc123.api.baseten.co/environments/production/sync", + baseURL: "https://should-be-ignored.com/v1", + basetenApiKey: TEST_API_KEY, + }); + + expect((model as any).clientConfig.baseURL).toBe( + "https://model-abc123.api.baseten.co/environments/production/sync/v1", + ); + }); + }); + + describe("API key resolution", () => { + it("uses basetenApiKey when provided explicitly", () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + }); + + expect(model.apiKey).toBe(TEST_API_KEY); + }); + + it("uses apiKey field as fallback", () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + apiKey: "fallback-key", + }); + + expect(model.apiKey).toBe("fallback-key"); + }); + + it("falls back to BASETEN_API_KEY environment variable", () => { + process.env[DEFAULT_API_KEY_ENV_VAR] = "env-api-key"; + + const model = new ChatBaseten({ model: TEST_MODEL }); + + expect(model.apiKey).toBe("env-api-key"); + }); + + it("prefers basetenApiKey over apiKey and env var", () => { + process.env[DEFAULT_API_KEY_ENV_VAR] = "env-key"; + + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: "explicit-key", + apiKey: "generic-key", + }); + + expect(model.apiKey).toBe("explicit-key"); + }); + + it("throws a descriptive error when no API key is available", () => { + expect( + () => new ChatBaseten({ model: TEST_MODEL }), + ).toThrowError(/Baseten API key not found/); + }); + }); + + describe("getName", () => { + it("returns 'ChatBaseten'", () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + }); + + expect(model.getName()).toBe("ChatBaseten"); + }); + }); + + describe("lc_name", () => { + it("returns 
'ChatBaseten' for serialization", () => { + expect(ChatBaseten.lc_name()).toBe("ChatBaseten"); + }); + }); + + describe("getLsParams", () => { + it("returns baseten as the LangSmith provider", () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + }); + + const params = model.getLsParams({} as any); + + expect(params.ls_provider).toBe("baseten"); + expect(params.ls_model_name).toBe(TEST_MODEL); + }); + }); + + // ------------------------------------------------------------------------- + // Reasoning content pass-through + // ------------------------------------------------------------------------- + + describe("reasoning content", () => { + it("preserves reasoning_content in additional_kwargs from _generate", async () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + }); + + const msg = new AIMessageChunk({ content: "Answer" }); + msg.additional_kwargs = { reasoning_content: "Let me think..." }; + + const fakeResult = { + generations: [ + { text: "Answer", message: msg, generationInfo: {} }, + ], + llmOutput: {}, + }; + + const superGenerate = vi + .spyOn(Object.getPrototypeOf(ChatBaseten.prototype), "_generate") + .mockResolvedValueOnce(fakeResult); + + const result = await model._generate([], {} as any); + + expect( + result.generations[0].message.additional_kwargs.reasoning_content, + ).toBe("Let me think..."); + + superGenerate.mockRestore(); + }); + + it("preserves reasoning_content in additional_kwargs from stream chunks", async () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + }); + + const msg = new AIMessageChunk({ content: "Hello" }); + msg.additional_kwargs = { reasoning_content: "step 1" }; + const chunk = new ChatGenerationChunk({ + text: "Hello", + message: msg, + }); + + const superStream = vi + .spyOn( + Object.getPrototypeOf(ChatBaseten.prototype), + "_streamResponseChunks", + ) + .mockReturnValueOnce( + (async 
function* () { + yield chunk; + })(), + ); + + const chunks: ChatGenerationChunk[] = []; + for await (const c of model._streamResponseChunks([], {} as any)) { + chunks.push(c); + } + + expect( + (chunks[0].message as AIMessageChunk).additional_kwargs + .reasoning_content, + ).toBe("step 1"); + + superStream.mockRestore(); + }); + }); + + // ------------------------------------------------------------------------- + // _generate override: response metadata enrichment + // ------------------------------------------------------------------------- + + describe("_generate", () => { + it("adds model_provider to response metadata", async () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + }); + + const fakeResult = { + generations: [ + { + text: "Hello!", + message: new AIMessageChunk({ content: "Hello!" }), + generationInfo: {}, + }, + ], + llmOutput: {}, + }; + + const superGenerate = vi + .spyOn(Object.getPrototypeOf(ChatBaseten.prototype), "_generate") + .mockResolvedValueOnce(fakeResult); + + const result = await model._generate([], {} as any); + + expect(result.generations[0].message.response_metadata).toEqual( + expect.objectContaining({ model_provider: "baseten" }), + ); + + superGenerate.mockRestore(); + }); + + it("preserves existing response metadata", async () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + }); + + const msg = new AIMessageChunk({ content: "Hi" }); + msg.response_metadata = { finish_reason: "stop" }; + + const fakeResult = { + generations: [{ text: "Hi", message: msg, generationInfo: {} }], + llmOutput: {}, + }; + + const superGenerate = vi + .spyOn(Object.getPrototypeOf(ChatBaseten.prototype), "_generate") + .mockResolvedValueOnce(fakeResult); + + const result = await model._generate([], {} as any); + const meta = result.generations[0].message.response_metadata; + + expect(meta).toEqual({ + finish_reason: "stop", + model_provider: "baseten", + }); + + 
superGenerate.mockRestore(); + }); + }); + + // ------------------------------------------------------------------------- + // _streamResponseChunks override + // ------------------------------------------------------------------------- + + describe("_streamResponseChunks", () => { + function makeChunk( + content: string, + overrides?: { + tool_call_chunks?: ToolCallChunk[]; + usage_metadata?: any; + response_metadata?: Record<string, unknown>; + }, + ): ChatGenerationChunk { + const msg = new AIMessageChunk({ + content, + tool_call_chunks: overrides?.tool_call_chunks, + usage_metadata: overrides?.usage_metadata, + }); + if (overrides?.response_metadata) { + msg.response_metadata = overrides.response_metadata; + } + return new ChatGenerationChunk({ text: content, message: msg }); + } + + async function* fakeStream( + chunks: ChatGenerationChunk[], + ): AsyncGenerator<ChatGenerationChunk> { + for (const c of chunks) yield c; + } + + it("tags every chunk with model_provider: baseten", async () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + }); + + const superStream = vi + .spyOn( + Object.getPrototypeOf(ChatBaseten.prototype), + "_streamResponseChunks", + ) + .mockReturnValueOnce(fakeStream([makeChunk("hello")])); + + const chunks: ChatGenerationChunk[] = []; + for await (const c of model._streamResponseChunks([], {} as any)) { + chunks.push(c); + } + + expect(chunks).toHaveLength(1); + expect(chunks[0].message.response_metadata).toEqual( + expect.objectContaining({ model_provider: "baseten" }), + ); + + superStream.mockRestore(); + }); + + it("strips usage_metadata from content chunks", async () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + }); + + const contentChunk = makeChunk("hi", { + usage_metadata: { + input_tokens: 10, + output_tokens: 5, + total_tokens: 15, + }, + }); + const usageOnlyChunk = makeChunk("", { + usage_metadata: { + input_tokens: 10, + output_tokens: 20, + total_tokens: 30, + }, + }); + 
const superStream = vi + .spyOn( + Object.getPrototypeOf(ChatBaseten.prototype), + "_streamResponseChunks", + ) + .mockReturnValueOnce( + fakeStream([contentChunk, usageOnlyChunk]), + ); + + const chunks: ChatGenerationChunk[] = []; + for await (const c of model._streamResponseChunks([], {} as any)) { + chunks.push(c); + } + + expect(chunks).toHaveLength(2); + expect( + (chunks[0].message as AIMessageChunk).usage_metadata, + ).toBeUndefined(); + expect( + (chunks[1].message as AIMessageChunk).usage_metadata, + ).toBeDefined(); + + superStream.mockRestore(); + }); + + it("normalizes tool_call_chunks in stream", async () => { + const model = new ChatBaseten({ + model: TEST_MODEL, + basetenApiKey: TEST_API_KEY, + }); + + const chunk = makeChunk("", { + tool_call_chunks: [ + { name: "get_weather", args: '{"loc', id: "call_1", index: 0 }, + { name: undefined, args: 'ation":', id: "call_2", index: 0 }, + ], + }); + + const superStream = vi + .spyOn( + Object.getPrototypeOf(ChatBaseten.prototype), + "_streamResponseChunks", + ) + .mockReturnValueOnce(fakeStream([chunk])); + + const chunks: ChatGenerationChunk[] = []; + for await (const c of model._streamResponseChunks([], {} as any)) { + chunks.push(c); + } + + const msg = chunks[0].message as AIMessageChunk; + expect(msg.tool_call_chunks).toHaveLength(1); + expect(msg.tool_call_chunks![0]).toMatchObject({ + name: "get_weather", + args: '{"location":', + index: 0, + }); + + superStream.mockRestore(); + }); + }); +}); + +// --------------------------------------------------------------------------- +// normalizeToolCallChunks (unit, exported) +// --------------------------------------------------------------------------- + +describe("normalizeToolCallChunks", () => { + it("returns single chunk with name unchanged", () => { + const chunks: ToolCallChunk[] = [ + { name: "search", args: '{"q":"hi"}', id: "c1", index: 0 }, + ]; + const result = normalizeToolCallChunks(chunks); + expect(result).toEqual(chunks); + }); + + 
it("consolidates same-index entries", () => { + const chunks: ToolCallChunk[] = [ + { name: "search", args: '{"q":', id: "c1", index: 0 }, + { name: undefined, args: '"hi"}', id: "c2", index: 0 }, + ]; + const result = normalizeToolCallChunks(chunks); + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + name: "search", + args: '{"q":"hi"}', + id: "c1", + index: 0, + }); + }); + + it("nulls out id on continuation chunks (no name)", () => { + const chunks: ToolCallChunk[] = [ + { name: undefined, args: '"rest"}', id: "c99", index: 0 }, + ]; + const result = normalizeToolCallChunks(chunks); + expect(result).toHaveLength(1); + expect(result[0].id).toBeUndefined(); + expect(result[0].args).toBe('"rest"}'); + }); + + it("keeps separate entries for different indices", () => { + const chunks: ToolCallChunk[] = [ + { name: "search", args: '{"a":1}', id: "c1", index: 0 }, + { name: "fetch", args: '{"b":2}', id: "c2", index: 1 }, + ]; + const result = normalizeToolCallChunks(chunks); + expect(result).toHaveLength(2); + expect(result[0].name).toBe("search"); + expect(result[1].name).toBe("fetch"); + }); + + it("handles empty array", () => { + expect(normalizeToolCallChunks([])).toEqual([]); + }); + + it("preserves first non-null id when merging", () => { + const chunks: ToolCallChunk[] = [ + { name: "fn", args: "{", id: undefined, index: 0 }, + { name: undefined, args: "}", id: "c5", index: 0 }, + ]; + const result = normalizeToolCallChunks(chunks); + expect(result).toHaveLength(1); + expect(result[0].id).toBe("c5"); + expect(result[0].name).toBe("fn"); + }); + + it("merges three same-index deltas", () => { + const chunks: ToolCallChunk[] = [ + { name: "tool", args: '{"a":', id: "c1", index: 0 }, + { name: undefined, args: '"b",', id: "c2", index: 0 }, + { name: undefined, args: '"c":1}', id: "c3", index: 0 }, + ]; + const result = normalizeToolCallChunks(chunks); + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + name: "tool", + args: 
'{"a":"b","c":1}', + id: "c1", + index: 0, + }); + }); +}); + +// --------------------------------------------------------------------------- +// normalizeModelUrl (unit, exported) +// --------------------------------------------------------------------------- + +describe("normalizeModelUrl", () => { + it("converts /predict to /sync/v1", () => { + expect( + normalizeModelUrl( + "https://model-abc123.api.baseten.co/environments/production/predict", + ), + ).toBe( + "https://model-abc123.api.baseten.co/environments/production/sync/v1", + ); + }); + + it("appends /v1 to /sync", () => { + expect( + normalizeModelUrl( + "https://model-abc123.api.baseten.co/environments/production/sync", + ), + ).toBe( + "https://model-abc123.api.baseten.co/environments/production/sync/v1", + ); + }); + + it("leaves /sync/v1 unchanged", () => { + const url = + "https://model-abc123.api.baseten.co/environments/production/sync/v1"; + expect(normalizeModelUrl(url)).toBe(url); + }); + + it("appends /v1 to bare URL without trailing slash", () => { + expect( + normalizeModelUrl("https://model-abc123.api.baseten.co"), + ).toBe("https://model-abc123.api.baseten.co/v1"); + }); + + it("strips trailing slash before appending /v1", () => { + expect( + normalizeModelUrl("https://model-abc123.api.baseten.co/"), + ).toBe("https://model-abc123.api.baseten.co/v1"); + }); +}); diff --git a/libs/providers/baseten/src/baseten.ts b/libs/providers/baseten/src/baseten.ts new file mode 100644 index 000000000..074c9c842 --- /dev/null +++ b/libs/providers/baseten/src/baseten.ts @@ -0,0 +1,227 @@ +/** + * ChatBaseten — Baseten LLM provider for LangChain. + * + * Extends ChatOpenAI to target Baseten's OpenAI-compatible inference API. + * Includes streaming fixes for TensorRT-LLM serving quirks ported from + * the Python `langchain-baseten` package. 
+ * + * @packageDocumentation + */ + +import { ChatOpenAI } from "@langchain/openai"; +import type { LangSmithParams } from "@langchain/core/language_models/chat_models"; +import type { BaseMessage } from "@langchain/core/messages"; +import { AIMessageChunk } from "@langchain/core/messages"; +import type { ToolCallChunk } from "@langchain/core/messages/tool"; +import type { ChatGenerationChunk, ChatResult } from "@langchain/core/outputs"; +import type { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; +import { + DEFAULT_BASE_URL, + DEFAULT_API_KEY_ENV_VAR, + normalizeModelUrl, + type BasetenChatInput, +} from "./types.js"; +function resolveApiKey(fields?: BasetenChatInput): string { + if (fields?.basetenApiKey) return fields.basetenApiKey; + if (typeof fields?.apiKey === "string") return fields.apiKey; + if (typeof fields?.openAIApiKey === "string") return fields.openAIApiKey; + + const envKey = process.env[DEFAULT_API_KEY_ENV_VAR]; + if (envKey) return envKey; + + throw new Error( + `Baseten API key not found. Provide it via the "basetenApiKey" constructor ` + + `option or set the ${DEFAULT_API_KEY_ENV_VAR} environment variable.\n\n` + + ` Get your API key at https://app.baseten.co/settings/api-keys`, + ); +} + +/** + * Fix TensorRT-LLM tool-call streaming quirks: + * - Fold same-index deltas within a single SSE event into one entry + * - Clear `id` on continuation deltas (no `name`) so `concat()` merges by index + * + * See: Python `langchain-baseten._normalize_tool_call_chunks` + */ +export function normalizeToolCallChunks( + chunks: ToolCallChunk[], +): ToolCallChunk[] { + if (chunks.length <= 1 && (!chunks[0] || chunks[0].name)) return chunks; + + const byIndex = new Map<number, ToolCallChunk>(); + + for (const tc of chunks) { + if (tc.index == null) continue; + const existing = byIndex.get(tc.index); + if (!existing) { + byIndex.set(tc.index, { ...tc }); + } else { + byIndex.set(tc.index, { + ...existing, + name: existing.name ?? 
tc.name, + args: (existing.args ?? "") + (tc.args ?? ""), + id: existing.id ?? tc.id, + }); + } + } + + const result: ToolCallChunk[] = []; + for (const tc of byIndex.values()) { + if (!tc.name && tc.id != null) { + result.push({ ...tc, id: undefined }); + } else { + result.push(tc); + } + } + + return result; +} + +function chunkHasContent(message: AIMessageChunk): boolean { + if (typeof message.content === "string" && message.content.length > 0) + return true; + if (Array.isArray(message.content) && message.content.length > 0) + return true; + if (message.tool_call_chunks && message.tool_call_chunks.length > 0) + return true; + return false; +} + +function inferModelNameFromUrl(url: string): string { + const match = /model-([a-zA-Z0-9]+)/.exec(url); + return match ? `model-${match[1]}` : "baseten-model"; +} + +/** + * Baseten chat model for LangChain. + * + * Wraps {@link ChatOpenAI} pointed at `https://inference.baseten.co/v1`. + * Streaming, tool calling, structured output, and token tracking are + * all inherited from ChatOpenAI. + * + * @example + * ```typescript + * import { ChatBaseten } from "@langchain/baseten"; + * + * const model = new ChatBaseten({ + * model: "deepseek-ai/DeepSeek-V3.1", + * }); + * + * const result = await model.invoke("What is the capital of France?"); + * ``` + * + * @example + * ```typescript + * import { createDeepAgent } from "deepagents"; + * + * const agent = createDeepAgent({ + * model: new ChatBaseten({ model: "deepseek-ai/DeepSeek-V3.1" }), + * }); + * ``` + */ +export class ChatBaseten extends ChatOpenAI { + static lc_name() { + return "ChatBaseten"; + } + + constructor(fields?: BasetenChatInput) { + const apiKey = resolveApiKey(fields); + + let baseURL: string; + let model: string; + + if (fields?.modelUrl) { + baseURL = normalizeModelUrl(fields.modelUrl); + model = fields.model ?? inferModelNameFromUrl(fields.modelUrl); + } else { + baseURL = fields?.baseURL ?? DEFAULT_BASE_URL; + model = fields?.model ?? 
""; + } + + const { + basetenApiKey: _basetenApiKey, + baseURL: _baseURL, + modelUrl: _modelUrl, + configuration, + ...rest + } = fields ?? {}; + + super({ + ...rest, + model, + apiKey, + streamUsage: rest.streamUsage ?? true, + configuration: { + ...configuration, + baseURL, + }, + }); + } + + override getName(): string { + return "ChatBaseten"; + } + + override getLsParams( + options: this["ParsedCallOptions"], + ): LangSmithParams { + const params = super.getLsParams(options); + return { + ...params, + ls_provider: "baseten", + }; + } + + override async _generate( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun, + ): Promise { + const result = await super._generate(messages, options, runManager); + for (const generation of result.generations) { + generation.message.response_metadata = { + ...generation.message.response_metadata, + model_provider: "baseten", + }; + } + return result; + } + + // Reasoning content (`additional_kwargs.reasoning_content`) is handled by + // the parent ChatOpenAI via its completions converters — no override needed. + + override async *_streamResponseChunks( + messages: BaseMessage[], + options: this["ParsedCallOptions"], + runManager?: CallbackManagerForLLMRun, + ): AsyncGenerator { + for await (const chunk of super._streamResponseChunks( + messages, + options, + runManager, + )) { + const message = chunk.message; + + if (AIMessageChunk.isInstance(message)) { + if (message.tool_call_chunks && message.tool_call_chunks.length > 0) { + message.tool_call_chunks = normalizeToolCallChunks( + message.tool_call_chunks, + ); + } + + // Baseten sends cumulative usage on every content chunk; strip it so + // LangChain only counts the final usage-only chunk. 
+ if (message.usage_metadata && chunkHasContent(message)) { + message.usage_metadata = undefined; + } + + message.response_metadata = { + ...message.response_metadata, + model_provider: "baseten", + }; + } + + yield chunk; + } + } +} diff --git a/libs/providers/baseten/src/index.ts b/libs/providers/baseten/src/index.ts new file mode 100644 index 000000000..e0c3fe7c6 --- /dev/null +++ b/libs/providers/baseten/src/index.ts @@ -0,0 +1,32 @@ +/** + * @langchain/baseten + * + * Baseten LLM provider for LangChain and deepagents. + * + * This package provides a `ChatBaseten` class that extends `ChatOpenAI` to connect + * to Baseten's OpenAI-compatible inference API, enabling access to open-source LLMs + * hosted on Baseten's infrastructure. + * + * @example + * ```typescript + * import { ChatBaseten } from "@langchain/baseten"; + * import { createDeepAgent } from "deepagents"; + * + * const model = new ChatBaseten({ + * model: "deepseek-ai/DeepSeek-V3.1", + * // Uses BASETEN_API_KEY env var by default + * }); + * + * const agent = createDeepAgent({ model }); + * + * const result = await agent.invoke({ + * messages: [{ role: "user", content: "Hello!" }], + * }); + * ``` + * + * @packageDocumentation + */ + +export { ChatBaseten, normalizeToolCallChunks } from "./baseten.js"; +export { normalizeModelUrl } from "./types.js"; +export type { BasetenChatInput } from "./types.js"; diff --git a/libs/providers/baseten/src/types.ts b/libs/providers/baseten/src/types.ts new file mode 100644 index 000000000..c78050853 --- /dev/null +++ b/libs/providers/baseten/src/types.ts @@ -0,0 +1,105 @@ +/** + * Type definitions for the Baseten LLM provider. + */ + +import type { ChatOpenAIFields } from "@langchain/openai"; + +/** + * Default base URL for Baseten's managed inference API. + * Supports all open-source models hosted on Baseten's Model APIs. 
+ * + * For self-deployed models, override with the model-specific URL: + * `https://model-{model_id}.api.baseten.co/v1` + */ +export const DEFAULT_BASE_URL = "https://inference.baseten.co/v1"; + +/** + * Default environment variable name for the Baseten API key. + */ +export const DEFAULT_API_KEY_ENV_VAR = "BASETEN_API_KEY"; + +/** + * Normalize a dedicated model URL to OpenAI-compatible `/sync/v1` format. + * + * Baseten dedicated model endpoints come in several forms: + * - `.../predict` → converted to `.../sync/v1` + * - `.../sync` → appended with `/v1` + * - anything else → ensures trailing `/v1` + * + * See: Python `langchain-baseten._normalize_model_url` + */ +export function normalizeModelUrl(url: string): string { + if (url.endsWith("/predict")) { + return url.replace(/\/predict$/, "/sync/v1"); + } + if (url.endsWith("/sync")) { + return `${url}/v1`; + } + if (!url.endsWith("/v1")) { + return `${url.replace(/\/+$/, "")}/v1`; + } + return url; +} + +/** + * Input fields for constructing a `ChatBaseten` instance. + * + * Extends `ChatOpenAIFields` since Baseten exposes an OpenAI-compatible API. + * The API key defaults to the `BASETEN_API_KEY` environment variable and + * the base URL defaults to Baseten's managed inference endpoint. + * + * @example + * ```typescript + * const input: BasetenChatInput = { + * model: "deepseek-ai/DeepSeek-V3.1", + * // apiKey defaults to process.env.BASETEN_API_KEY + * }; + * ``` + * + * @example + * ```typescript + * // Self-deployed model via dedicated URL + * const input: BasetenChatInput = { + * modelUrl: "https://model-abc123.api.baseten.co/environments/production/predict", + * basetenApiKey: "my-key", + * }; + * ``` + */ +export interface BasetenChatInput extends Omit<ChatOpenAIFields, "model"> { + /** + * Baseten model slug in `org/model-name` format. + * Optional when `modelUrl` is provided (the model ID will be + * extracted from the URL). 
+ * + * @example "deepseek-ai/DeepSeek-V3.1" + * @example "zai-org/GLM-5" + * @example "moonshotai/Kimi-K2.5" + */ + model?: string; + + /** + * Dedicated model URL for self-deployed Baseten models. + * Supports `/predict`, `/sync`, and `/sync/v1` endpoint formats; + * the URL is automatically normalized to `/sync/v1`. + * + * When provided, overrides `baseURL`. + * + * @example "https://model-abc123.api.baseten.co/environments/production/predict" + */ + modelUrl?: string; + + /** + * Baseten API key. If not provided, falls back to the `BASETEN_API_KEY` + * environment variable. + */ + basetenApiKey?: string; + + /** + * Override the base URL for Baseten's API. + * Defaults to `https://inference.baseten.co/v1`. + * + * Set this for self-deployed models: + * `https://model-{model_id}.api.baseten.co/v1` + */ + baseURL?: string; +} diff --git a/libs/providers/baseten/tsconfig.json b/libs/providers/baseten/tsconfig.json new file mode 100644 index 000000000..7c0f864fb --- /dev/null +++ b/libs/providers/baseten/tsconfig.json @@ -0,0 +1,8 @@ +{ + "extends": "../../../tsconfig.json", + "compilerOptions": { + "outDir": "dist" + }, + "include": ["src/**/*.ts", "src/*.ts"], + "exclude": ["node_modules", "dist"] +} diff --git a/libs/providers/baseten/tsdown.config.ts b/libs/providers/baseten/tsdown.config.ts new file mode 100644 index 000000000..78daf267b --- /dev/null +++ b/libs/providers/baseten/tsdown.config.ts @@ -0,0 +1,26 @@ +import { defineConfig } from "tsdown"; + +const external = [/^[^./]/]; + +export default defineConfig([ + { + entry: ["./src/index.ts"], + format: ["esm"], + dts: true, + clean: true, + sourcemap: true, + outDir: "dist", + outExtensions: () => ({ js: ".js" }), + external, + }, + { + entry: ["./src/index.ts"], + format: ["cjs"], + dts: true, + clean: true, + sourcemap: true, + outDir: "dist", + outExtensions: () => ({ js: ".cjs" }), + external, + }, +]); diff --git a/libs/providers/baseten/vitest.config.ts b/libs/providers/baseten/vitest.config.ts 
new file mode 100644 index 000000000..4b0e8fa9e --- /dev/null +++ b/libs/providers/baseten/vitest.config.ts @@ -0,0 +1,41 @@ +import { + configDefaults, + defineConfig, + type ViteUserConfigExport, +} from "vitest/config"; + +export default defineConfig((env) => { + const common: ViteUserConfigExport = { + test: { + environment: "node", + hideSkippedTests: true, + globals: true, + testTimeout: 60_000, + hookTimeout: 60_000, + teardownTimeout: 60_000, + exclude: ["**/*.int.test.ts", ...configDefaults.exclude], + setupFiles: ["dotenv/config"], + }, + }; + + if (env.mode === "int") { + return { + test: { + ...common.test, + globals: false, + testTimeout: 100_000, + exclude: configDefaults.exclude, + include: ["**/*.int.test.ts"], + name: "int", + sequence: { concurrent: false }, + }, + } satisfies ViteUserConfigExport; + } + + return { + test: { + ...common.test, + include: ["src/**/*.test.ts"], + }, + } satisfies ViteUserConfigExport; +}); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index b2674b692..0aaf2c50c 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -504,6 +504,43 @@ importers: specifier: ^4.0.18 version: 4.1.2(@opentelemetry/api@1.9.0)(@types/node@25.5.0)(@vitest/ui@4.0.18)(vite@8.0.3(@emnapi/core@1.8.1)(@emnapi/runtime@1.8.1)(@types/node@25.5.0)(esbuild@0.27.3)(jiti@2.6.1)(tsx@4.21.0)(yaml@2.8.3)) + libs/providers/baseten: + dependencies: + '@langchain/openai': + specifier: ^1.4.1 + version: 1.4.1(@langchain/core@1.1.38(@opentelemetry/api@1.9.0)(@opentelemetry/exporter-trace-otlp-proto@0.207.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.5.1(@opentelemetry/api@1.9.0))(openai@6.33.0(ws@8.19.0)(zod@4.3.6))(ws@8.19.0))(ws@8.19.0) + devDependencies: + '@langchain/core': + specifier: ^1.1.38 + version: 1.1.38(@opentelemetry/api@1.9.0)(@opentelemetry/exporter-trace-otlp-proto@0.207.0(@opentelemetry/api@1.9.0))(@opentelemetry/sdk-trace-base@2.5.1(@opentelemetry/api@1.9.0))(openai@6.33.0(ws@8.19.0)(zod@4.3.6))(ws@8.19.0) + '@tsconfig/recommended': 
+ specifier: ^1.0.13 + version: 1.0.13 + '@types/node': + specifier: ^25.1.0 + version: 25.5.0 + '@vitest/coverage-v8': + specifier: ^4.0.18 + version: 4.0.18(vitest@4.1.2(@opentelemetry/api@1.9.0)(@types/node@25.5.0)(vite@8.0.3(@emnapi/core@1.8.1)(@emnapi/runtime@1.8.1)(@types/node@25.5.0)(esbuild@0.27.3)(jiti@2.6.1)(tsx@4.21.0)(yaml@2.8.3))) + deepagents: + specifier: workspace:* + version: link:../../deepagents + dotenv: + specifier: ^17.2.3 + version: 17.3.1 + tsdown: + specifier: ^0.21.4 + version: 0.21.7(@emnapi/core@1.8.1)(@emnapi/runtime@1.8.1)(synckit@0.11.12)(typescript@6.0.2) + tsx: + specifier: ^4.21.0 + version: 4.21.0 + typescript: + specifier: ^6.0.2 + version: 6.0.2 + vitest: + specifier: ^4.0.18 + version: 4.1.2(@opentelemetry/api@1.9.0)(@types/node@25.5.0)(@vitest/ui@4.0.18)(vite@8.0.3(@emnapi/core@1.8.1)(@emnapi/runtime@1.8.1)(@types/node@25.5.0)(esbuild@0.27.3)(jiti@2.6.1)(tsx@4.21.0)(yaml@2.8.3)) + libs/providers/daytona: dependencies: '@daytonaio/sdk':