diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index 9d46f265e..07a269afd 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -443,11 +443,15 @@ async def chat_completion(self, **kwargs) -> ChatCompletion: if "model" not in completion_kwargs: completion_kwargs["model"] = self.server.config.model_name - # Compute input_ids (using existing tokens if extending) + # State-aware input_ids computation if not self.track_tree and self.tokenizer is not None: input_ids = self._compute_input_ids(prompt, extending_node) completion_kwargs["input_ids"] = input_ids + if extending_node is not None: + existing_len = len(extending_node.tokens) + completion_kwargs["delta_input_ids"] = input_ids[existing_len:] + # Call the tokens and logprobs wrapper directly ( prompt_tokens, diff --git a/atroposlib/envs/server_handling/routing_utils.py b/atroposlib/envs/server_handling/routing_utils.py new file mode 100644 index 000000000..f100fca8d --- /dev/null +++ b/atroposlib/envs/server_handling/routing_utils.py @@ -0,0 +1,24 @@ +import hashlib +from typing import List, Optional + + +def get_prefix_hash(input_ids: List[int], prefix_cutoff: int = 100) -> str: + """ + Generate a stable MD5 hash for a sequence of tokens. + Used for consistent session routing to maximize KV cache hits. 
+ """ + if not input_ids: + return hashlib.md5(b"").hexdigest() + + cutoff = min(len(input_ids), prefix_cutoff) + prefix_tokens = input_ids[:cutoff] + + prefix_bytes = b",".join(str(t).encode("utf-8") for t in prefix_tokens) + return hashlib.md5(prefix_bytes).hexdigest() + + +def get_consistent_worker_index(prefix_hash: str, num_workers: int) -> int: + """Map a hash string to a worker index.""" + if num_workers <= 0: + return 0 + return int(prefix_hash, 16) % num_workers diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py index b9c493f9c..229e70194 100644 --- a/atroposlib/envs/server_handling/server_manager.py +++ b/atroposlib/envs/server_handling/server_manager.py @@ -14,6 +14,7 @@ ManagedServer, ) from atroposlib.envs.server_handling.openai_server import OpenAIServer +from atroposlib.envs.server_handling.routing_utils import get_consistent_worker_index from atroposlib.envs.server_handling.server_baseline import ( APIServer, APIServerConfig, @@ -22,6 +23,7 @@ ) from atroposlib.envs.server_handling.server_harness import ServerHarness from atroposlib.envs.server_handling.sglang_server import SGLangServer +from atroposlib.envs.server_handling.sglang_stateful_server import StatefulSGLangServer from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer from atroposlib.envs.server_handling.vllm_server import VLLMServer @@ -72,17 +74,16 @@ def __init__( self.use_proxy = use_proxy or bool(self.proxy_url) # Tool parser — passed to ManagedServer for tool call support self.tool_parser = tool_parser - # First we check to see if it's the base server class, and if so, we need to select the appropriate server class - # You can't use type() to check if it's the base server class, because it's an abstract class, it'll appear as - # an ABCMeta, not what you're expecting. 
+ # Select appropriate server class if not explicitly provided if inspect.isabstract(server_class): + if not isinstance(configs, list): if configs.server_type == "openai": server_class = OpenAIServer elif configs.server_type == "trl": server_class = TrlVllmServer elif configs.server_type == "sglang": - server_class = SGLangServer + server_class = StatefulSGLangServer elif configs.server_type == "vllm": server_class = VLLMServer else: @@ -93,7 +94,7 @@ def __init__( elif configs[0].server_type == "trl": server_class = TrlVllmServer elif configs[0].server_type == "sglang": - server_class = SGLangServer + server_class = StatefulSGLangServer elif configs[0].server_type == "vllm": server_class = VLLMServer else: @@ -410,6 +411,7 @@ async def managed_server( self, tokenizer=None, base_url: Optional[str] = None, + session_id: Optional[str] = None, preserve_think_blocks: bool = False, ): """ @@ -427,6 +429,8 @@ async def managed_server( extract from server or create from model name. base_url: Pin the session to a specific backend server by its base_url. In production, this comes from the atropos API's server allocation. + session_id: Session ID or prefix hash for pinning. + preserve_think_blocks: If True, preserves blocks in assistant messages, which are sometimes stripped by chat templates. Defaults to False. 
Usually not needed, since the chat template should be configured @@ -485,16 +489,55 @@ async def managed_server( return # -- In-process path (existing logic) -- - most_available_server = 0 - most_available_server_num_slots = -1 - for i, server in enumerate(self.servers): - if not server.server_healthy: - continue - if server.sem._value > most_available_server_num_slots: - most_available_server = i - most_available_server_num_slots = server.sem._value + # -- In-process path (existing logic + pinning fix) -- + selected_server = None + + # Resolve base_url from session_id + if session_id and not base_url and self.servers: + import hashlib + + hash_str = hashlib.md5(session_id.encode("utf-8")).hexdigest() + idx = get_consistent_worker_index(hash_str, len(self.servers)) + base_url = self.servers[idx].config.base_url + + # Attempt to pin to base_url with retries + if base_url: + for attempt in range(3): + for server in self.servers: + if server.config.base_url == base_url: + if server.server_healthy: + selected_server = server + break + break + + if selected_server: + break + + if attempt < 2: + await asyncio.sleep(0.1) + + if selected_server is None: + warnings.warn( + f"Requested pinned base_url '{base_url}' is not healthy or not found " + "after 3 attempts. Falling back to most available server." + ) - selected_server = self.servers[most_available_server] + # 2. 
Fallback to most available if no pin or pin failed + if selected_server is None: + most_available_server = 0 + most_available_server_num_slots = -1 + for i, server in enumerate(self.servers): + if not server.server_healthy: + continue + if server.sem._value > most_available_server_num_slots: + most_available_server = i + most_available_server_num_slots = server.sem._value + + if most_available_server_num_slots != -1: + selected_server = self.servers[most_available_server] + else: + # Edge case: No healthy servers + selected_server = self.servers[0] # Handle OpenAI servers separately - they don't support token IDs/logprobs if isinstance(selected_server, OpenAIServer): diff --git a/atroposlib/envs/server_handling/sglang_server.py b/atroposlib/envs/server_handling/sglang_server.py index 63201b3e0..e1ea284ad 100644 --- a/atroposlib/envs/server_handling/sglang_server.py +++ b/atroposlib/envs/server_handling/sglang_server.py @@ -40,29 +40,19 @@ def __init__( super().__init__(config, reasoning_config=reasoning_config) async def check_server_status_task(self, chat_completion: bool = True): + + health_url = f"{self.config.base_url.replace('/v1', '')}/health" while True: try: - if chat_completion: - await self.openai.chat.completions.create( - model=self.config.model_name, - messages=[{"role": "user", "content": "hi"}], - max_tokens=1, - ) - else: - await self.openai.completions.create( - model=self.config.model_name, - prompt="hi", - max_tokens=1, - ) - self.server_healthy = True - except ( - aiohttp.ClientError, - openai.OpenAIError, - openai.APITimeoutError, - Exception, - ): + async with aiohttp.ClientSession() as session: + async with session.get(health_url, timeout=5) as response: + if response.status == 200: + self.server_healthy = True + else: + self.server_healthy = False + except Exception: self.server_healthy = False - await asyncio.sleep(1) + await asyncio.sleep(2) # Check every 2 seconds async def _chat_completion_wrapper(self, **kwargs) -> ChatCompletion: """ 
diff --git a/atroposlib/envs/server_handling/sglang_stateful_server.py b/atroposlib/envs/server_handling/sglang_stateful_server.py new file mode 100644 index 000000000..896c2273c --- /dev/null +++ b/atroposlib/envs/server_handling/sglang_stateful_server.py @@ -0,0 +1,126 @@ +import asyncio +import warnings + +import aiohttp + +from atroposlib.envs.server_handling.server_baseline import APIServerConfig +from atroposlib.envs.server_handling.sglang_server import SGLangServer + + +class StatefulSGLangServer(SGLangServer): + """ + SGLangServer extension for stateful Delta-Sync protocol. + Optimizes network payload by sending only token deltas. + Includes auto-rebuild for cache-miss resilience. + """ + + def __init__(self, config: APIServerConfig, reasoning_config=None): + super().__init__(config, reasoning_config=reasoning_config) + self._session = None + + async def _get_session(self): + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=self.config.timeout) + ) + return self._session + + async def _tokens_and_logprobs_completion_wrapper( + self, **kwargs + ) -> tuple[list, list, list, list]: + """ + Interacts with SGLang /generate via raw HTTP, optimized for stateful deltas. + """ + assert ( + kwargs.get("model", None) is not None + ), "Model is required for completion!" + assert ( + kwargs.get("prompt", None) is not None + or kwargs.get("input_ids", None) is not None + ), "Prompt or input_ids is required!" 
+ + if "input_ids" in kwargs: + prompt_tokens_full = kwargs.pop("input_ids") + kwargs.pop("prompt", None) + else: + prompt_tokens_full = self.tokenizer.encode(kwargs.pop("prompt")) + + # Clean double BOS if needed + if ( + len(prompt_tokens_full) >= 2 + and prompt_tokens_full[0] + == self.tokenizer.bos_token_id + == prompt_tokens_full[1] + ): + prompt_tokens_full = prompt_tokens_full[1:] + + if "max_tokens" in kwargs: + kwargs["max_new_tokens"] = kwargs.pop("max_tokens") + if "model" in kwargs: + kwargs.pop("model") + + # Extract new tokens (delta) if this is a continuation. + is_delta_request = False + if "delta_input_ids" in kwargs: + payload_input_ids = kwargs.pop("delta_input_ids") + is_delta_request = True + else: + payload_input_ids = prompt_tokens_full + + request_data = { + "input_ids": payload_input_ids, + "sampling_params": kwargs, + "return_logprob": True, + "return_text_in_logprobs": False, + } + + async def fetch_generate(payload): + session = await self._get_session() + async with session.post( + f"{self.config.base_url.replace('/v1', '')}/generate", + json=payload, + headers=( + {"Authorization": f"Bearer {self.config.api_key}"} + if self.config.api_key + else {} + ), + ) as response: + response.raise_for_status() + return await response.json() + + try: + results = await fetch_generate(request_data) + except Exception as e: + if is_delta_request: + warnings.warn( + f"Stateful request backfired ({e}). Attempting stateless fallback..." 
+ ) + request_data["input_ids"] = prompt_tokens_full + results = await fetch_generate(request_data) + else: + raise e + + if not isinstance(results, list): + results = [results] + + output_tokens_list = [] + output_logprobs_list = [] + finish_reasons_list = [] + + for result in results: + meta_info = result.get("meta_info", {}) + output_token_logprobs = meta_info.get("output_token_logprobs", []) + logprobs = [item[0] for item in output_token_logprobs] + output_ids = [item[1] for item in output_token_logprobs] + finish_reason = meta_info.get("finish_reason", None) + + output_tokens_list.append(output_ids) + output_logprobs_list.append(logprobs) + finish_reasons_list.append(finish_reason) + + return ( + prompt_tokens_full, + output_tokens_list, + output_logprobs_list, + finish_reasons_list, + ) diff --git a/atroposlib/tests/test_server_pinning.py b/atroposlib/tests/test_server_pinning.py new file mode 100644 index 000000000..ffdfbddeb --- /dev/null +++ b/atroposlib/tests/test_server_pinning.py @@ -0,0 +1,136 @@ +import asyncio +import os +import sys +import unittest +import warnings +from unittest.mock import AsyncMock, MagicMock, patch + +# Add atropos to path relative to tests dir +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) + +from atroposlib.envs.server_handling.server_baseline import APIServer +from atroposlib.envs.server_handling.server_manager import ( + APIServerConfig, + ServerManager, +) + + +class MockConfig(APIServerConfig): + base_url: str = "http://localhost:1111" + model_name: str = "test-model" + server_type: str = "sglang" + + +class MockServer(APIServer): + def __init__(self, config, reasoning_config=None): + super().__init__(config, reasoning_config=reasoning_config) + self.server_healthy = True + self.sem = asyncio.Semaphore(10) + self.eval_sem = asyncio.Semaphore(10) + + +class DummyTokenizer: + def encode(self, *args, **kwargs): + return [1, 2, 3] + + def decode(self, *args, **kwargs): + return "dummy" + + +class 
TestServerPinning(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self): + os.environ["ATROPOS_ALLOW_DUMMY_MANAGED_SERVER"] = "1" + self.config1 = MockConfig(base_url="http://worker-1:8000", server_type="openai") + self.config2 = MockConfig(base_url="http://worker-2:8000", server_type="openai") + self.config3 = MockConfig(base_url="http://worker-3:8000", server_type="openai") + + self.manager = ServerManager( + configs=[self.config1, self.config2, self.config3], server_class=MockServer + ) + + async def test_managed_server_pinning_respects_base_url(self): + """Verify managed_server follows base_url pin.""" + + # Make worker-2 very busy (fewest slots available) + self.manager.servers[0].sem = asyncio.Semaphore(10) + self.manager.servers[1].sem = asyncio.Semaphore(1) # worker-2 + self.manager.servers[2].sem = asyncio.Semaphore(10) + + # Pin to worker-2 + + target_url = "http://worker-2:8000" + async with self.manager.managed_server( + base_url=target_url, tokenizer=DummyTokenizer() + ) as managed: + self.assertEqual( + managed.server.config.base_url, + target_url, + "Pinning failed: ServerManager returned the wrong server.", + ) + + async def test_managed_server_pinning_fallback(self): + """Verify fallback when pin is invalid.""" + + self.manager.servers[0].sem = asyncio.Semaphore(5) + self.manager.servers[1].sem = asyncio.Semaphore(10) # worker-2 (most available) + self.manager.servers[2].sem = asyncio.Semaphore(2) + + fake_url = "http://worker-fake:8000" + + # Should fallback to worker-2 because it's most available + async with self.manager.managed_server( + base_url=fake_url, tokenizer=DummyTokenizer() + ) as managed: + self.assertEqual( + managed.server.config.base_url, + "http://worker-2:8000", + "Fallback failed: Did not select most available valid server.", + ) + + async def test_managed_server_session_id_mapping(self): + """Verify deterministic session_id hashing.""" + + # Make all servers equally available so we can trust the hash determinism + 
self.manager.servers[0].sem = asyncio.Semaphore(10) + self.manager.servers[1].sem = asyncio.Semaphore(10) + self.manager.servers[2].sem = asyncio.Semaphore(10) + + # 'demo_session_1_for_hashing' parses to a specific index. Let's rely on that hash being stable. + target_session = "demo_session_1_for_hashing" + + # Test routing by session id + async with self.manager.managed_server(session_id=target_session) as managed: + url_1 = managed.server.config.base_url + + # Do it again to ensure stability + async with self.manager.managed_server(session_id=target_session) as managed2: + url_2 = managed2.server.config.base_url + + self.assertEqual(url_1, url_2, "Session ID mapping should be deterministic.") + + async def test_managed_server_unhealthy_fallback(self): + """Verify fallback when pinned worker is unhealthy.""" + + # Make worker-1 targetted but unhealthy + self.manager.servers[0].server_healthy = False + + # Make worker-3 most available + self.manager.servers[1].sem = asyncio.Semaphore(2) + self.manager.servers[2].sem = asyncio.Semaphore(10) # worker-3 + + target_url = "http://worker-1:8000" + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + async with self.manager.managed_server( + base_url=target_url, tokenizer=DummyTokenizer() + ) as managed: + self.assertEqual( + managed.server.config.base_url, + "http://worker-3:8000", + "Unhealthy node bypass failed.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/benchmark_stateful_perf.py b/benchmark_stateful_perf.py new file mode 100644 index 000000000..35131446a --- /dev/null +++ b/benchmark_stateful_perf.py @@ -0,0 +1,111 @@ +import argparse +import asyncio +import hashlib +import json +import os +import statistics +import sys +import time +from typing import Dict, List + +from transformers import AutoTokenizer + +# Add atropos to path +sys.path.append(os.getcwd()) + +from atroposlib.envs.server_handling.server_manager import ( + APIServerConfig, + ServerManager, 
+) +from atroposlib.envs.server_handling.sglang_stateful_server import StatefulSGLangServer + + +# --------------------------------------------------------------------------- +# HARDWARE BENCHMARK SUITE +# --------------------------------------------------------------------------- +async def run_benchmark( + worker_urls: List[str], num_conversations: int = 5, turns_per_conv: int = 4 +): + print(f"Benchmarking Stateless vs Stateful SGLang ({len(worker_urls)} workers)") + + configs = [ + APIServerConfig( + base_url=url, + model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + server_type="sglang", + health_check=True, + ) + for url in worker_urls + ] + + manager = ServerManager(configs=configs) + await asyncio.sleep(5) + + tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + + results = { + "stateless": {"ttfts": [], "total_times": []}, + "stateful": {"ttfts": [], "total_times": []}, + } + + async def benchmark_mode(mode_name: str, use_stateful: bool): + print(f"Running {mode_name}...") + + for i in range(num_conversations): + session_id = f"bench-{mode_name}-{i}" + messages = [] + + for t in range(turns_per_conv): + messages.append( + {"role": "user", "content": f"Explain topic {t} in one sentence."} + ) + start_time = time.time() + s_id = session_id if use_stateful else None + + async with manager.managed_server( + session_id=s_id, tokenizer=tokenizer + ) as managed: + res = await managed.chat_completion( + messages=messages, max_tokens=10 + ) + ttft = time.time() - start_time + + if t > 0: + results[mode_name]["ttfts"].append(ttft) + + messages.append( + {"role": "assistant", "content": res.choices[0].message.content} + ) + + await benchmark_mode("stateless", use_stateful=False) + await benchmark_mode("stateful", use_stateful=True) + + def get_stats(mode): + ttfts = results[mode]["ttfts"] + if not ttfts: + return 0.0, 0.0 + return statistics.mean(ttfts), statistics.stdev(ttfts) + + mean_sl, std_sl = get_stats("stateless") + mean_sf, std_sf 
= get_stats("stateful") + + print(f"\nResults (Latency T2-T{turns_per_conv}):") + print(f"Stateless: {mean_sl:.4f}s (std={std_sl:.4f})") + print(f"Stateful: {mean_sf:.4f}s (std={std_sf:.4f})") + + if mean_sl > 0: + improvement = (mean_sl - mean_sf) / mean_sl * 100 + print(f"Latency Reduction: {improvement:.2f}%") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--workers", + nargs="+", + default=["http://localhost:30001", "http://localhost:30002"], + ) + parser.add_argument("--convs", type=int, default=3) + args = parser.parse_args() + + asyncio.run(run_benchmark(args.workers, num_conversations=args.convs)) diff --git a/verify_stateful_e2e.py b/verify_stateful_e2e.py new file mode 100644 index 000000000..c08073d66 --- /dev/null +++ b/verify_stateful_e2e.py @@ -0,0 +1,89 @@ +import argparse +import asyncio +import hashlib +import os +import sys +import time +from typing import List + +from transformers import AutoTokenizer + +sys.path.append(os.getcwd()) + +from atroposlib.envs.server_handling.routing_utils import get_consistent_worker_index +from atroposlib.envs.server_handling.server_manager import ( + APIServerConfig, + ServerManager, +) + + +async def run_real_e2e_test(worker_urls: List[str]): + print(f"Hardware Verification on {len(worker_urls)} workers") + + configs = [ + APIServerConfig( + base_url=url, + model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + server_type="sglang", + health_check=True, + ) + for url in worker_urls + ] + + manager = ServerManager(configs=configs) + await asyncio.sleep(8) + + for i, s in enumerate(manager.servers): + print(f"Worker {i} ({s.config.base_url}) Healthy: {s.server_healthy}") + + tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") + + session_id = "conversation-alpha" + idx = get_consistent_worker_index( + hashlib.md5(session_id.encode()).hexdigest(), len(worker_urls) + ) + expected_url = worker_urls[idx] + + print(f"Session {session_id} -> Expected 
{expected_url}") + + messages = [{"role": "user", "content": "What is the capital of France?"}] + + # Turn 1 + async with manager.managed_server( + session_id=session_id, tokenizer=tokenizer + ) as managed: + url_t1 = managed.server.config.base_url + print(f"Turn 1 directed to: {url_t1}") + res1 = await managed.chat_completion(messages=messages, max_tokens=20) + content1 = res1.choices[0].message.content.strip() + print(f"Response 1: {content1}") + + # Turn 2 + history = messages + [{"role": "assistant", "content": content1}] + messages_t2 = history + [{"role": "user", "content": "And its population?"}] + + async with manager.managed_server( + session_id=session_id, tokenizer=tokenizer + ) as managed: + url_t2 = managed.server.config.base_url + print(f"Turn 2 directed to: {url_t2}") + + if url_t1 != url_t2: + print(f"FAIL: Pinning failed ({url_t1} != {url_t2})") + sys.exit(1) + + res2 = await managed.chat_completion(messages=messages_t2, max_tokens=20) + print(f"Response 2: {res2.choices[0].message.content.strip()}") + + print("\nE2E VERIFICATION SUCCESSFUL") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--workers", + nargs="+", + default=["http://localhost:30001", "http://localhost:30002"], + ) + args = parser.parse_args() + asyncio.run(run_real_e2e_test(args.workers))