diff --git a/routellm/controller.py b/routellm/controller.py
index 8a02a05..a4b04d8 100644
--- a/routellm/controller.py
+++ b/routellm/controller.py
@@ -24,7 +24,7 @@
     },
     "causal_llm": {"checkpoint_path": "routellm/causal_llm_gpt4_augmented"},
     "bert": {"checkpoint_path": "routellm/bert_gpt4_augmented"},
-    "mf": {"checkpoint_path": "routellm/mf_gpt4_augmented"},
+    "mf": {"checkpoint_path": "madison-evans/routellm_all-MiniLM-L6-v2"},
 }
 
 
@@ -48,7 +48,9 @@ def __init__(
         api_base: Optional[str] = None,
         api_key: Optional[str] = None,
         progress_bar: bool = False,
+        hf_token: Optional[str] = None,  # Hugging Face Hub token, forwarded to routers that download checkpoints
     ):
+        self.hf_token = hf_token
         self.model_pair = ModelPair(strong=strong_model, weak=weak_model)
         self.routers = {}
         self.api_base = api_base
@@ -67,7 +69,8 @@ def __init__(
         for router in routers:
             if router_pbar is not None:
                 router_pbar.set_description(f"Loading {router}")
-            self.routers[router] = ROUTER_CLS[router](**config.get(router, {}))
+            self.routers[router] = ROUTER_CLS[router](hf_token=self.hf_token, **config.get(router, {}))
+
 
         # Some Python magic to match the OpenAI Python SDK
         self.chat = SimpleNamespace(
@@ -101,6 +104,14 @@ def _parse_model_name(self, model: str):
                 f"Invalid model {model}. Model name must be of the format 'router-[router name]-[threshold]."
             )
         return router, threshold
+
+    def get_routed_model(self, messages: list, router: str, threshold: float) -> str:
+        """
+        Get the model that the given router and threshold would route these messages to.
+        """
+        self._validate_router_threshold(router, threshold)
+        routed_model = self._get_routed_model_for_completion(messages, router, threshold)
+        return routed_model
 
     def _get_routed_model_for_completion(
         self, messages: list, router: str, threshold: float
diff --git a/routellm/routers/matrix_factorization/model.py b/routellm/routers/matrix_factorization/model.py
index 09fbb25..ad006a9 100644
--- a/routellm/routers/matrix_factorization/model.py
+++ b/routellm/routers/matrix_factorization/model.py
@@ -1,7 +1,14 @@
 import torch
 from huggingface_hub import PyTorchModelHubMixin
-
+from transformers import AutoTokenizer, AutoModel
 from routellm.routers.similarity_weighted.utils import OPENAI_CLIENT
+import logging
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
 
 MODEL_IDS = {
     "RWKV-4-Raven-14B": 0,
@@ -70,7 +77,6 @@
     "zephyr-7b-beta": 63,
 }
 
-
 class MFModel(torch.nn.Module, PyTorchModelHubMixin):
     def __init__(
         self,
@@ -79,51 +85,91 @@ def __init__(
         dim,
         num_models,
         text_dim,
         num_classes,
         use_proj,
+        use_openai_embeddings=False,  # Default to Hugging Face embeddings
+        embedding_model_name="BAAI/bge-base-en",  # Matches the training notebook
+        hf_token=None,  # Hugging Face API token
     ):
         super().__init__()
-        self._name = "TextMF"
         self.use_proj = use_proj
-        self.P = torch.nn.Embedding(num_models, dim)
+        self.use_openai_embeddings = use_openai_embeddings
+        self.hf_token = hf_token
+        self.embedding_model_name = embedding_model_name
 
-        self.embedding_model = "text-embedding-3-small"
+        # Model embedding matrix
+        self.P = torch.nn.Embedding(num_models, dim)
 
         if self.use_proj:
-            self.text_proj = torch.nn.Sequential(
-                torch.nn.Linear(text_dim, dim, bias=False)
-            )
+            self.text_proj = torch.nn.Linear(text_dim, dim, bias=False)
         else:
-            assert (
-                text_dim == dim
-            ), f"text_dim {text_dim} must be equal to dim {dim} if not using projection"
+            assert text_dim == dim, f"text_dim {text_dim} must be equal to dim {dim} if not using projection"
 
-        self.classifier = torch.nn.Sequential(
-            torch.nn.Linear(dim, num_classes, bias=False)
-        )
+        self.classifier = torch.nn.Linear(dim, num_classes, bias=False)
+
+        if not self.use_openai_embeddings:
+            logger.info(f"Loading Hugging Face tokenizer and model: {self.embedding_model_name}")
+
+            # Load tokenizer & model exactly as in the notebook
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.embedding_model_name,
+                token=hf_token
+            )
+            self.embedding_model = AutoModel.from_pretrained(
+                self.embedding_model_name,
+                token=hf_token
+            )
+            self.embedding_model.eval()  # Inference mode only
+            self.embedding_model.to(self.get_device())
 
     def get_device(self):
         return self.P.weight.device
 
+    def get_prompt_embedding(self, prompt):
+        """Embed the prompt with OpenAI or a Hugging Face encoder (mean pooling)."""
+        if self.use_openai_embeddings:
+            response = OPENAI_CLIENT.embeddings.create(
+                input=[prompt], model=self.embedding_model_name
+            )
+            return torch.tensor(response.data[0].embedding, device=self.get_device())
+
+        logger.info(f"Generating embedding for prompt: {prompt[:30]}...")
+
+        inputs = self.tokenizer(
+            prompt,
+            padding=True,
+            truncation=True,
+            return_tensors="pt"
+        ).to(self.get_device())
+
+        with torch.no_grad():
+            outputs = self.embedding_model(**inputs)
+            last_hidden_state = outputs.last_hidden_state
+
+        # Mean pooling over token embeddings
+        prompt_embed = last_hidden_state.mean(dim=1).squeeze()
+
+        return prompt_embed
+
     def forward(self, model_id, prompt):
         model_id = torch.tensor(model_id, dtype=torch.long).to(self.get_device())
-
         model_embed = self.P(model_id)
         model_embed = torch.nn.functional.normalize(model_embed, p=2, dim=1)
+        prompt_embed = self.get_prompt_embedding(prompt)
 
-        prompt_embed = (
-            OPENAI_CLIENT.embeddings.create(input=[prompt], model=self.embedding_model)
-            .data[0]
-            .embedding
-        )
-        prompt_embed = torch.tensor(prompt_embed, device=self.get_device())
-        prompt_embed = self.text_proj(prompt_embed)
+        if self.use_proj:
+            prompt_embed = self.text_proj(prompt_embed)
 
         return self.classifier(model_embed * prompt_embed).squeeze()
 
     @torch.no_grad()
     def pred_win_rate(self, model_a, model_b, prompt):
         logits = self.forward([model_a, model_b], prompt)
-        winrate = torch.sigmoid(logits[0] - logits[1]).item()
+        raw_diff = logits[0] - logits[1]
+        winrate = torch.sigmoid(raw_diff).item()
+        logger.info(
+            f"For prompt: '{prompt[:30]}...', logits: {[float(x) for x in logits]}, "
+            f"raw difference: {raw_diff:.4f}, winrate (sigmoid): {winrate:.4f}"
+        )
         return winrate
 
     def load(self, path):
-        self.load_state_dict(torch.load(path))
+        self.load_state_dict(torch.load(path))
\ No newline at end of file
diff --git a/routellm/routers/routers.py b/routellm/routers/routers.py
index 0096c0a..a635bfd 100644
--- a/routellm/routers/routers.py
+++ b/routellm/routers/routers.py
@@ -1,7 +1,7 @@
 import abc
 import functools
 import random
-
+from transformers import AutoTokenizer, AutoModel
 import numpy as np
 import torch
 from datasets import concatenate_datasets, load_dataset
@@ -21,6 +21,13 @@
     compute_tiers,
     preprocess_battles,
 )
+import logging
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
 
 
 def no_parallel(cls):
@@ -211,18 +218,47 @@ class MatrixFactorizationRouter(Router):
     def __init__(
         self,
         checkpoint_path,
-        # This is the model pair for scoring at inference time,
-        # and can be different from the model pair used for routing.
strong_model="gpt-4-1106-preview", weak_model="mixtral-8x7b-instruct-v0.1", hidden_size=128, - num_models=64, - text_dim=1536, + num_models=None, + text_dim=None, num_classes=1, use_proj=True, + use_openai_embeddings=True, + embedding_model_name=None, + hf_token=None, ): + """ + A simplified constructor that flattens the logic for: + 1) Setting num_models from MODEL_IDS, + 2) Determining embedding_model_name defaults, + 3) Setting text_dim for OpenAI vs. HF embeddings, + 4) Initializing the MFModel, + 5) Setting strong/weak model IDs. + """ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Default num_models to the length of MODEL_IDS if not provided + num_models = num_models or len(MODEL_IDS) + + # Decide which embedding model_name to use if none provided + if not embedding_model_name: + if use_openai_embeddings: + # e.g. "text-embedding-ada-002" or your default + embedding_model_name = "text-embedding-3-small" + else: + embedding_model_name = "BAAI/bge-base-en" + + # Decide text_dim if not provided + if text_dim is None: + if use_openai_embeddings: + # e.g., 1536 for text-embedding-ada-002 + text_dim = 1536 + else: + text_dim = self._infer_hf_text_dim(embedding_model_name) + + # Initialize the MFModel self.model = MFModel.from_pretrained( checkpoint_path, dim=hidden_size, @@ -230,15 +266,42 @@ def __init__( text_dim=text_dim, num_classes=num_classes, use_proj=use_proj, - ) - self.model = self.model.eval().to(device) + use_openai_embeddings=use_openai_embeddings, + embedding_model_name=embedding_model_name, + hf_token=hf_token, + ).eval().to(device) + + # Store strong/weak model IDs self.strong_model_id = MODEL_IDS[strong_model] self.weak_model_id = MODEL_IDS[weak_model] + @staticmethod + def _infer_hf_text_dim(embedding_model_name: str) -> int: + """ + Helper to load a huggingface model and extract its hidden size. + Immediately frees model from memory. + """ + tokenizer = AutoTokenizer.from_pretrained(embedding_model_name) + hf_model = AutoModel.from_pretrained(embedding_model_name) + dim = hf_model.config.hidden_size + + del tokenizer + del hf_model + + return dim + def calculate_strong_win_rate(self, prompt): + """ + Scores the prompt using the MF model to see how + often the 'strong' model is predicted to win + over the 'weak' model. + """ winrate = self.model.pred_win_rate( - self.strong_model_id, self.weak_model_id, prompt + self.strong_model_id, + self.weak_model_id, + prompt ) + logger.info(f"\n\nwinrate: {winrate}\n\n") return winrate