Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ const val AGENT_MARKETPLACE_PRICING_MIN_MAX = 20.00
// [llm.proxies]
const val AGENT_LLM_PROXIES_MAX_ENTRIES = 16
val AGENT_LLM_PROXY_NAME_LENGTH = 1..32
val AGENT_LLM_PROXY_NAME_PATTERN = "^[A-Z][A-Z]*$".toRegex()
val AGENT_LLM_PROXY_NAME_PATTERN = "^[A-Z_0-9]+$".toRegex()
Comment thread
CaelumF marked this conversation as resolved.
val AGENT_LLM_PROXY_MODEL_LENGTH = 1..128

// [marketplace.identities.erc8004]
Expand Down Expand Up @@ -601,7 +601,7 @@ private fun RegistryAgent.validateLlm() {
validateStringLength("llm.proxies[$index].name", proxy.name, AGENT_LLM_PROXY_NAME_LENGTH)

if (!proxy.name.matches(AGENT_LLM_PROXY_NAME_PATTERN))
throw RegistryException("llm.proxies[$index].name (\"${proxy.name}\") must be uppercase alphabetic only")
throw RegistryException("llm.proxies[$index].name (\"${proxy.name}\") must only contain uppercase alphanumeric or underscore characters")

if (!names.add(proxy.name))
throw RegistryException("llm.proxies[$index].name (\"${proxy.name}\") is not unique")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.JsonClassDiscriminator
import org.coralprotocol.coralserver.config.AddressConsumer
import org.coralprotocol.coralserver.session.SessionAgentExecutionContext

@Serializable
Expand All @@ -20,9 +19,7 @@ sealed interface PrototypeApiUrl {
@SerialName("proxy")
data object Proxy : PrototypeApiUrl {
override fun resolve(executionContext: SessionAgentExecutionContext): String {
return executionContext.applicationRuntimeContext
.getLlmProxyUrl(executionContext, AddressConsumer.LOCAL)
.toString()
TODO("format changing soon")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,11 @@ sealed class PrototypeModelProvider {
override val url: PrototypeApiUrl? = null,
) : PrototypeModelProvider() {
override fun getExecutor(executionContext: SessionAgentExecutionContext): PromptExecutor {
val resolvedUrl = resolveUrlWithProvider(url, executionContext, "openai")
return MultiLLMPromptExecutor(
OpenAILLMClient(
apiKey = key.resolve(executionContext),
settings = if (resolvedUrl == null) OpenAIClientSettings() else OpenAIClientSettings(
baseUrl = resolvedUrl
settings = if (url == null) OpenAIClientSettings() else OpenAIClientSettings(
baseUrl = url.resolve(executionContext)
)
)
)
Expand All @@ -84,12 +83,11 @@ sealed class PrototypeModelProvider {
override val url: PrototypeApiUrl? = null,
) : PrototypeModelProvider() {
override fun getExecutor(executionContext: SessionAgentExecutionContext): PromptExecutor {
val resolvedUrl = resolveUrlWithProvider(url, executionContext, "anthropic")
return MultiLLMPromptExecutor(
AnthropicLLMClient(
apiKey = key.resolve(executionContext),
settings = if (resolvedUrl == null) AnthropicClientSettings() else AnthropicClientSettings(
baseUrl = resolvedUrl
settings = if (url == null) AnthropicClientSettings() else AnthropicClientSettings(
baseUrl = url.resolve(executionContext)
)
)
)
Expand All @@ -107,12 +105,11 @@ sealed class PrototypeModelProvider {
override val url: PrototypeApiUrl? = null,
) : PrototypeModelProvider() {
override fun getExecutor(executionContext: SessionAgentExecutionContext): PromptExecutor {
val resolvedUrl = resolveUrlWithProvider(url, executionContext, "openrouter")
return MultiLLMPromptExecutor(
OpenRouterLLMClient(
apiKey = key.resolve(executionContext),
settings = if (resolvedUrl == null) OpenRouterClientSettings() else OpenRouterClientSettings(
baseUrl = resolvedUrl
settings = if (url == null) OpenRouterClientSettings() else OpenRouterClientSettings(
baseUrl = url.resolve(executionContext)
)
)
)
Expand All @@ -121,16 +118,4 @@ sealed class PrototypeModelProvider {
override val modelClass: Any
get() = OpenRouterModels
}

companion object {
fun resolveUrlWithProvider(
url: PrototypeApiUrl?,
executionContext: SessionAgentExecutionContext,
providerName: String
): String? {
if (url == null) return null
val base = url.resolve(executionContext)
return if (url is PrototypeApiUrl.Proxy) "$base/$providerName" else base
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,24 @@ enum class LlmProviderProfile(
val sdkBaseUrlEnvVar: String? = null,
val sdkPathSuffix: String = ""
) {
OPENAI("openai", "https://api.openai.com", AuthStyle.Bearer, emptyMap(), OpenAIStrategy,
sdkBaseUrlEnvVar = "OPENAI_BASE_URL", sdkPathSuffix = "/v1"),
ANTHROPIC("anthropic", "https://api.anthropic.com", AuthStyle.Custom("x-api-key"), mapOf("anthropic-version" to "2023-06-01"), AnthropicStrategy,
sdkBaseUrlEnvVar = "ANTHROPIC_BASE_URL"),
OPENROUTER("openrouter", "https://openrouter.ai", AuthStyle.Bearer, emptyMap(), OpenAIStrategy,
sdkBaseUrlEnvVar = "OPENROUTER_BASE_URL");
OPENAI(
"openai", "https://api.openai.com", AuthStyle.Bearer, emptyMap(), OpenAIStrategy,
sdkBaseUrlEnvVar = "OPENAI_BASE_URL", sdkPathSuffix = "v1"
),

ANTHROPIC(
"anthropic",
"https://api.anthropic.com",
AuthStyle.Custom("x-api-key"),
mapOf("anthropic-version" to "2023-06-01"),
AnthropicStrategy,
sdkBaseUrlEnvVar = "ANTHROPIC_BASE_URL"
),

OPENROUTER(
"openrouter", "https://openrouter.ai", AuthStyle.Bearer, emptyMap(), OpenAIStrategy,
sdkBaseUrlEnvVar = "OPENROUTER_BASE_URL"
);

companion object {
private val byId = entries.associateBy { it.providerId }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,30 @@
@file:OptIn(ExperimentalSerializationApi::class)

package org.coralprotocol.coralserver.llmproxy

import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonObject
import kotlinx.serialization.json.buildJsonObject
import kotlinx.serialization.json.jsonPrimitive
import kotlinx.serialization.json.longOrNull
import kotlinx.serialization.json.put
import kotlinx.serialization.json.putJsonObject
import org.slf4j.LoggerFactory
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.Serializable
import kotlinx.serialization.SerializationException
import kotlinx.serialization.json.*
import org.coralprotocol.coralserver.logging.LoggingInterface

@Serializable
@JsonIgnoreUnknownKeys
data class LlmUsage(
@JsonNames("prompt_tokens", "input_tokens")
val inputTokens: Long? = null,

private val logger = LoggerFactory.getLogger("llm-proxy-strategy")
@JsonNames("completion_tokens", "output_tokens")
val outputTokens: Long? = null,
)

@Serializable
@JsonIgnoreUnknownKeys
private data class LlmUsageWrapper(val usage: LlmUsage? = null)

interface LlmProviderStrategy {
fun prepareStreamingRequest(requestBody: String, json: Json): String = requestBody
fun extractBufferedTokens(responseBody: String, json: Json): Pair<Long?, Long?>
fun prepareStreamingRequest(requestBody: String, json: Json, logger: LoggingInterface): String = requestBody
fun extractBufferedTokens(responseBody: String, json: Json): LlmUsage?
fun createStreamParser(json: Json): StreamTokenParser
}

Expand All @@ -29,7 +40,7 @@ interface StreamTokenParser {
}

object OpenAIStrategy : LlmProviderStrategy {
override fun prepareStreamingRequest(requestBody: String, json: Json): String {
override fun prepareStreamingRequest(requestBody: String, json: Json, logger: LoggingInterface): String {
return try {
val obj = json.decodeFromString<JsonObject>(requestBody)
if (obj.containsKey("stream_options")) return requestBody
Expand All @@ -39,23 +50,17 @@ object OpenAIStrategy : LlmProviderStrategy {
}
json.encodeToString(JsonObject.serializer(), modified)
} catch (e: Exception) {
logger.debug("Failed to inject stream_options into request body", e)
logger.error(e) { "Failed to inject stream_options into request body" }
requestBody
}
}

override fun extractBufferedTokens(responseBody: String, json: Json): Pair<Long?, Long?> {
return extractUsageField(responseBody, json)
}

override fun extractBufferedTokens(responseBody: String, json: Json) = extractLlmUsage(responseBody, json)
override fun createStreamParser(json: Json): StreamTokenParser = OpenAIStreamParser(json)
}

object AnthropicStrategy : LlmProviderStrategy {
override fun extractBufferedTokens(responseBody: String, json: Json): Pair<Long?, Long?> {
return extractUsageField(responseBody, json)
}

override fun extractBufferedTokens(responseBody: String, json: Json) = extractLlmUsage(responseBody, json)
override fun createStreamParser(json: Json): StreamTokenParser = AnthropicStreamParser(json)
}

Expand All @@ -72,11 +77,12 @@ private class OpenAIStreamParser(private val json: Json) : StreamTokenParser {
if (!line.startsWith("data: ") || line.startsWith("data: [DONE]")) return
chunkCount++
try {
val (inp, out) = extractUsageField(line.removePrefix("data: "), json)
if (inp != null) inputTokens = inp
if (out != null) outputTokens = out
} catch (e: Exception) {
logger.trace("Failed to parse OpenAI stream chunk for token usage", e)
val usageWrapper = json.decodeFromString<LlmUsageWrapper>(line.removePrefix("data: "))

inputTokens = usageWrapper.usage?.inputTokens ?: inputTokens
outputTokens = usageWrapper.usage?.outputTokens ?: outputTokens
} catch (_: SerializationException) {
// ignored, not containing usage information is not an error
}
}
}
Expand Down Expand Up @@ -104,33 +110,31 @@ private class AnthropicStreamParser(private val json: Json) : StreamTokenParser
val obj = json.decodeFromString<JsonObject>(line.removePrefix("data: "))
when (lastEventType) {
"message_start" -> {
val usage = (obj["message"] as? JsonObject)?.get("usage") as? JsonObject
val inp = usage?.get("input_tokens")?.jsonPrimitive?.longOrNull
if (inp != null) inputTokens = inp
val usage = obj["message"]?.jsonObject?.let { extractLlmUsage(it, json) }
inputTokens = usage?.inputTokens ?: inputTokens
}

"message_delta" -> {
val usage = obj["usage"] as? JsonObject
val out = usage?.get("output_tokens")?.jsonPrimitive?.longOrNull
if (out != null) outputTokens = out
val usage = extractLlmUsage(obj, json)
outputTokens = usage?.outputTokens ?: outputTokens
}
}
} catch (e: Exception) {
logger.trace("Failed to parse Anthropic stream event for token usage", e)
} catch (_: SerializationException) {
// ignored, not containing usage information is not an error
}
}
}

private fun extractUsageField(body: String, json: Json): Pair<Long?, Long?> {
return try {
val obj = json.decodeFromString<JsonObject>(body)
val usage = obj["usage"] as? JsonObject ?: return null to null
val input = usage["prompt_tokens"]?.jsonPrimitive?.longOrNull
?: usage["input_tokens"]?.jsonPrimitive?.longOrNull
val output = usage["completion_tokens"]?.jsonPrimitive?.longOrNull
?: usage["output_tokens"]?.jsonPrimitive?.longOrNull
input to output
} catch (e: Exception) {
logger.trace("Failed to extract usage field from response body", e)
null to null
fun extractLlmUsage(body: String, json: Json) =
try {
json.decodeFromString<LlmUsageWrapper>(body).usage
} catch (_: SerializationException) {
null
}
}

fun extractLlmUsage(body: JsonObject, json: Json) =
try {
json.decodeFromJsonElement<LlmUsageWrapper>(body).usage
} catch (_: SerializationException) {
null
}
Loading
Loading