6 changes: 5 additions & 1 deletion atroposlib/envs/server_handling/managed_server.py
@@ -443,11 +443,15 @@ async def chat_completion(self, **kwargs) -> ChatCompletion:
if "model" not in completion_kwargs:
completion_kwargs["model"] = self.server.config.model_name

# Compute input_ids (using existing tokens if extending)
# State-aware input_ids computation
if not self.track_tree and self.tokenizer is not None:
input_ids = self._compute_input_ids(prompt, extending_node)
completion_kwargs["input_ids"] = input_ids

if extending_node is not None:
existing_len = len(extending_node.tokens)
completion_kwargs["delta_input_ids"] = input_ids[existing_len:]

# Call the tokens and logprobs wrapper directly
(
prompt_tokens,
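For orientation, here is a minimal sketch of the delta computation introduced above, with purely illustrative token values (only extending_node.tokens and delta_input_ids come from the diff; the numbers are hypothetical):

```python
# Hypothetical values for illustration; not part of the diff.
input_ids = [1, 15, 22, 87, 90, 301, 302]  # full prompt after re-tokenization
existing_tokens = [1, 15, 22, 87, 90]      # tokens the backend session already holds (extending_node.tokens)
existing_len = len(existing_tokens)        # 5

delta_input_ids = input_ids[existing_len:]  # [301, 302] -- only the new tokens travel over the wire
```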
24 changes: 24 additions & 0 deletions atroposlib/envs/server_handling/routing_utils.py
@@ -0,0 +1,24 @@
import hashlib
from typing import List, Optional


def get_prefix_hash(input_ids: List[int], prefix_cutoff: int = 100) -> str:
"""
Generate a stable MD5 hash from the first prefix_cutoff tokens of a sequence.
Used for consistent session routing to maximize KV cache hits.
"""
if not input_ids:
return "empty_prefix"

cutoff = min(len(input_ids), prefix_cutoff)
prefix_tokens = input_ids[:cutoff]

prefix_bytes = b",".join(str(t).encode("utf-8") for t in prefix_tokens)
return hashlib.md5(prefix_bytes).hexdigest()


def get_consistent_worker_index(prefix_hash: str, num_workers: int) -> int:
"""Map a hash string to a worker index."""
if num_workers <= 0:
return 0
return int(prefix_hash, 16) % num_workers
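A small usage sketch of the two helpers, assuming the module path added in this PR (token values are illustrative):

```python
from atroposlib.envs.server_handling.routing_utils import (
    get_consistent_worker_index,
    get_prefix_hash,
)

input_ids = [1, 15, 22, 87, 90]           # illustrative token IDs
prefix_hash = get_prefix_hash(input_ids)  # identical prefixes always yield the same hash
worker_idx = get_consistent_worker_index(prefix_hash, num_workers=4)
# Requests that share a prompt prefix land on the same worker, maximizing KV cache reuse.
```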
71 changes: 57 additions & 14 deletions atroposlib/envs/server_handling/server_manager.py
@@ -14,6 +14,7 @@
ManagedServer,
)
from atroposlib.envs.server_handling.openai_server import OpenAIServer
from atroposlib.envs.server_handling.routing_utils import get_consistent_worker_index
from atroposlib.envs.server_handling.server_baseline import (
APIServer,
APIServerConfig,
@@ -22,6 +23,7 @@
)
from atroposlib.envs.server_handling.server_harness import ServerHarness
from atroposlib.envs.server_handling.sglang_server import SGLangServer
from atroposlib.envs.server_handling.sglang_stateful_server import StatefulSGLangServer
from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
from atroposlib.envs.server_handling.vllm_server import VLLMServer

@@ -72,17 +74,16 @@ def __init__(
self.use_proxy = use_proxy or bool(self.proxy_url)
# Tool parser — passed to ManagedServer for tool call support
self.tool_parser = tool_parser
# First we check to see if it's the base server class, and if so, we need to select the appropriate server class
# You can't use type() to check if it's the base server class, because it's an abstract class, it'll appear as
# an ABCMeta, not what you're expecting.
# Select appropriate server class if not explicitly provided
if inspect.isabstract(server_class):

if not isinstance(configs, list):
if configs.server_type == "openai":
server_class = OpenAIServer
elif configs.server_type == "trl":
server_class = TrlVllmServer
elif configs.server_type == "sglang":
server_class = SGLangServer
server_class = StatefulSGLangServer
elif configs.server_type == "vllm":
server_class = VLLMServer
else:
@@ -93,7 +94,7 @@ def __init__(
elif configs[0].server_type == "trl":
server_class = TrlVllmServer
elif configs[0].server_type == "sglang":
server_class = SGLangServer
server_class = StatefulSGLangServer
elif configs[0].server_type == "vllm":
server_class = VLLMServer
else:
Expand Down Expand Up @@ -410,6 +411,7 @@ async def managed_server(
self,
tokenizer=None,
base_url: Optional[str] = None,
session_id: Optional[str] = None,
preserve_think_blocks: bool = False,
):
"""
@@ -427,6 +429,8 @@
extract from server or create from model name.
base_url: Pin the session to a specific backend server by its base_url.
In production, this comes from the atropos API's server allocation.
session_id: Session ID or prefix hash; hashed to deterministically pin the session to a backend server when base_url is not provided.

preserve_think_blocks: If True, preserves <think> blocks in assistant messages,
which are sometimes stripped by chat templates. Defaults to False.
Usually not needed, since the chat template should be configured
@@ -485,16 +489,55 @@
return

# -- In-process path (existing logic) --
most_available_server = 0
most_available_server_num_slots = -1
for i, server in enumerate(self.servers):
if not server.server_healthy:
continue
if server.sem._value > most_available_server_num_slots:
most_available_server = i
most_available_server_num_slots = server.sem._value
# -- In-process path (existing logic + pinning fix) --
selected_server = None

# Resolve base_url from session_id
if session_id and not base_url and self.servers:
import hashlib

hash_str = hashlib.md5(session_id.encode("utf-8")).hexdigest()
idx = get_consistent_worker_index(hash_str, len(self.servers))
base_url = self.servers[idx].config.base_url

# Attempt to pin to base_url with retries
if base_url:
for attempt in range(3):
for server in self.servers:
if server.config.base_url == base_url:
if server.server_healthy:
selected_server = server
break
break

if selected_server:
break

if attempt < 2:
await asyncio.sleep(0.1)

if selected_server is None:
warnings.warn(
f"Requested pinned base_url '{base_url}' is not healthy or not found "
"after 3 attempts. Falling back to most available server."
)

selected_server = self.servers[most_available_server]
# Fallback to the most available server if no pin was requested or pinning failed
if selected_server is None:
most_available_server = 0
most_available_server_num_slots = -1
for i, server in enumerate(self.servers):
if not server.server_healthy:
continue
if server.sem._value > most_available_server_num_slots:
most_available_server = i
most_available_server_num_slots = server.sem._value

if most_available_server_num_slots != -1:
selected_server = self.servers[most_available_server]
else:
# Edge case: No healthy servers
selected_server = self.servers[0]

# Handle OpenAI servers separately - they don't support token IDs/logprobs
if isinstance(selected_server, OpenAIServer):
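For reference, the session_id-to-backend resolution above boils down to the following standalone sketch (the server URLs and session string are hypothetical; the hashing mirrors the code in this hunk):

```python
import hashlib

from atroposlib.envs.server_handling.routing_utils import get_consistent_worker_index

server_urls = ["http://node0:30000/v1", "http://node1:30000/v1"]  # hypothetical backends
session_id = "rollout-42"                                          # hypothetical session key

hash_str = hashlib.md5(session_id.encode("utf-8")).hexdigest()
idx = get_consistent_worker_index(hash_str, len(server_urls))
pinned_base_url = server_urls[idx]  # the same session_id always resolves to the same backend
```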
30 changes: 10 additions & 20 deletions atroposlib/envs/server_handling/sglang_server.py
@@ -40,29 +40,19 @@ def __init__(
super().__init__(config, reasoning_config=reasoning_config)

async def check_server_status_task(self, chat_completion: bool = True):

health_url = f"{self.config.base_url.replace('/v1', '')}/health"
while True:
try:
if chat_completion:
await self.openai.chat.completions.create(
model=self.config.model_name,
messages=[{"role": "user", "content": "hi"}],
max_tokens=1,
)
else:
await self.openai.completions.create(
model=self.config.model_name,
prompt="hi",
max_tokens=1,
)
self.server_healthy = True
except (
aiohttp.ClientError,
openai.OpenAIError,
openai.APITimeoutError,
Exception,
):
async with aiohttp.ClientSession() as session:
async with session.get(health_url, timeout=5) as response:
if response.status == 200:
self.server_healthy = True
else:
self.server_healthy = False
except Exception:
self.server_healthy = False
await asyncio.sleep(1)
await asyncio.sleep(2) # Check every 2 seconds

async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion:
"""
126 changes: 126 additions & 0 deletions atroposlib/envs/server_handling/sglang_stateful_server.py
@@ -0,0 +1,126 @@
import asyncio
import warnings

import aiohttp

from atroposlib.envs.server_handling.server_baseline import APIServerConfig
from atroposlib.envs.server_handling.sglang_server import SGLangServer


class StatefulSGLangServer(SGLangServer):
"""
SGLangServer extension for stateful Delta-Sync protocol.
Optimizes network payload by sending only token deltas.
Falls back to resending the full prompt if a delta request fails (e.g., on a backend cache miss).
"""

def __init__(self, config: APIServerConfig, reasoning_config=None):
super().__init__(config, reasoning_config=reasoning_config)
self._session = None

async def _get_session(self):
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=self.config.timeout)
)
return self._session

async def _tokens_and_logprobs_completion_wrapper(
self, **kwargs
) -> tuple[list, list, list, list]:
"""
Interacts with SGLang /generate via raw HTTP, optimized for stateful deltas.
"""
assert (
kwargs.get("model", None) is not None
), "Model is required for completion!"
assert (
kwargs.get("prompt", None) is not None
or kwargs.get("input_ids", None) is not None
), "Prompt or input_ids is required!"

if "input_ids" in kwargs:
prompt_tokens_full = kwargs.pop("input_ids")
kwargs.pop("prompt", None)
else:
prompt_tokens_full = self.tokenizer.encode(kwargs.pop("prompt"))

# Clean double BOS if needed
if (
len(prompt_tokens_full) >= 2
and prompt_tokens_full[0]
== self.tokenizer.bos_token_id
== prompt_tokens_full[1]
):
prompt_tokens_full = prompt_tokens_full[1:]

if "max_tokens" in kwargs:
kwargs["max_new_tokens"] = kwargs.pop("max_tokens")
if "model" in kwargs:
kwargs.pop("model")

# Extract new tokens (delta) if this is a continuation.
is_delta_request = False
if "delta_input_ids" in kwargs:
payload_input_ids = kwargs.pop("delta_input_ids")
is_delta_request = True
else:
payload_input_ids = prompt_tokens_full

request_data = {
"input_ids": payload_input_ids,
"sampling_params": kwargs,
"return_logprob": True,
"return_text_in_logprobs": False,
}

async def fetch_generate(payload):
session = await self._get_session()
async with session.post(
f"{self.config.base_url.replace('/v1', '')}/generate",
json=payload,
headers=(
{"Authorization": f"Bearer {self.config.api_key}"}
if self.config.api_key
else {}
),
) as response:
response.raise_for_status()
return await response.json()

try:
results = await fetch_generate(request_data)
except Exception as e:
if is_delta_request:
warnings.warn(
f"Stateful request backfired ({e}). Attempting stateless fallback..."
)
request_data["input_ids"] = prompt_tokens_full
results = await fetch_generate(request_data)
else:
raise e

if not isinstance(results, list):
results = [results]

output_tokens_list = []
output_logprobs_list = []
finish_reasons_list = []

for result in results:
meta_info = result.get("meta_info", {})
output_token_logprobs = meta_info.get("output_token_logprobs", [])
logprobs = [item[0] for item in output_token_logprobs]
output_ids = [item[1] for item in output_token_logprobs]
finish_reason = meta_info.get("finish_reason", None)

output_tokens_list.append(output_ids)
output_logprobs_list.append(logprobs)
finish_reasons_list.append(finish_reason)

return (
prompt_tokens_full,
output_tokens_list,
output_logprobs_list,
finish_reasons_list,
)
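To summarize the wrapper's behavior, the two payload shapes it sends to SGLang's /generate endpoint look roughly like this (field names are taken from the code above; token values and sampling parameters are illustrative):

```python
# First turn: no delta is available, so the full prompt is sent.
full_request = {
    "input_ids": [1, 15, 22, 87, 90],
    "sampling_params": {"max_new_tokens": 64, "temperature": 0.7},
    "return_logprob": True,
    "return_text_in_logprobs": False,
}

# Continuation: only the tokens appended since the last turn are sent.
delta_request = {
    "input_ids": [301, 302],  # delta_input_ids forwarded by ManagedServer.chat_completion
    "sampling_params": {"max_new_tokens": 64, "temperature": 0.7},
    "return_logprob": True,
    "return_text_in_logprobs": False,
}
# If the delta request fails (e.g., the backend has lost the session's KV cache),
# the wrapper retries once with the full prompt_tokens_full.
```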