diff --git a/kt-kernel/python/experts.py b/kt-kernel/python/experts.py
index 5d867e9f8..d56d51817 100644
--- a/kt-kernel/python/experts.py
+++ b/kt-kernel/python/experts.py
@@ -65,6 +65,7 @@ def __new__(
         cpu_save: bool = False,
         max_deferred_experts_per_token: Optional[int] = None,
         method: str = "AMXINT4",
+        numa_nodes: Optional[List[int]] = None,
     ):
         """
         Factory method to create the appropriate backend implementation.
@@ -85,6 +86,7 @@ def __new__(
             chunked_prefill_size: Maximum prefill chunk size
             cpu_save: Whether to save weights to CPU memory
             max_deferred_experts_per_token: Number of experts per token to defer. Defaults to 0.
+            numa_nodes: Explicit list of NUMA node IDs for subpool mapping. If None, defaults to sequential.
             method: Backend method ("AMXINT4", "AMXINT8", "RAWINT4", "FP8", "BF16", "LLAMAFILE", "MOE_INT4", "MOE_INT8")
 
         Returns:
@@ -117,6 +119,7 @@ def __new__(
             cpu_save=cpu_save,
             max_deferred_experts_per_token=max_deferred_experts_per_token,
             method=method,
+            numa_nodes=numa_nodes,
         )
 
     # Forward static methods to the base class
diff --git a/kt-kernel/python/experts_base.py b/kt-kernel/python/experts_base.py
index ba879ff08..e7e9a0833 100644
--- a/kt-kernel/python/experts_base.py
+++ b/kt-kernel/python/experts_base.py
@@ -164,6 +164,7 @@ def __init__(
         cpu_save: bool = False,
         max_deferred_experts_per_token: Optional[int] = None,
         method: str = "AMXINT4",
+        numa_nodes: Optional[List[int]] = None,
     ):
         """
         Initialize base MoE Wrapper.
@@ -185,6 +186,8 @@ def __init__(
             cpu_save: Whether to save weights to CPU memory
             max_deferred_experts_per_token: Number of experts per token to defer on this layer. Defaults to 0 (no defer).
             method: Backend method string
+            numa_nodes: Explicit list of NUMA node IDs for subpool mapping.
+                If None, defaults to [0, 1, ..., threadpool_count-1].
         """
         self.layer_idx = layer_idx
         self.num_experts = num_experts
@@ -221,7 +224,15 @@ def __init__(
 
         if BaseMoEWrapper._cpu_infer_instance is None:
             worker_config = kt_kernel_ext.WorkerPoolConfig()
-            subpool_numa_map = list(range(threadpool_count))
+            if numa_nodes is not None:
+                if len(numa_nodes) != threadpool_count:
+                    raise ValueError(
+                        f"numa_nodes length ({len(numa_nodes)}) must match "
+                        f"threadpool_count ({threadpool_count})"
+                    )
+                subpool_numa_map = list(numa_nodes)
+            else:
+                subpool_numa_map = list(range(threadpool_count))
             subpool_thread_count = [
                 cpuinfer_threads // threadpool_count + (1 if i < cpuinfer_threads % threadpool_count else 0)
                 for i in range(threadpool_count)
diff --git a/kt-kernel/python/utils/amx.py b/kt-kernel/python/utils/amx.py
index cb7fd82ed..d7e7aa0bb 100644
--- a/kt-kernel/python/utils/amx.py
+++ b/kt-kernel/python/utils/amx.py
@@ -1,7 +1,7 @@
 import os
 import torch
 import ctypes
-from typing import Optional
+from typing import List, Optional
 
 # Use relative imports for package structure
 from ..experts_base import BaseMoEWrapper
@@ -47,6 +47,7 @@ def __init__(
         cpu_save: bool = False,
         max_deferred_experts_per_token: Optional[int] = None,
         method: str = "AMXINT4",
+        numa_nodes: Optional[List[int]] = None,
     ):
         """
         Initialize AMX MoE Wrapper.
@@ -97,6 +98,7 @@ def __init__(
             cpu_save=cpu_save,
             max_deferred_experts_per_token=max_deferred_experts_per_token,
             method=method,
+            numa_nodes=numa_nodes,
         )
 
         # AMX-specific: Check if we should load merged safetensor weights
@@ -282,7 +284,11 @@ def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
             moe_config.save = True
             moe_config.load = False
             base_key = f"model.layers.{self.layer_idx}"
-            w = self.safetensor_loader.load_experts(base_key)
+            try:
+                w = self.safetensor_loader.load_experts(base_key)
+            except (ValueError, KeyError):
+                base_key = f"model.language_model.layers.{self.layer_idx}"
+                w = self.safetensor_loader.load_experts(base_key)
 
             self.gate_proj = torch.cat(w["gate_weight"], dim=0).contiguous()
             self.up_proj = torch.cat(w["up_weight"], dim=0).contiguous()
@@ -379,6 +385,7 @@ def __init__(
             cpu_save=cpu_save,
             max_deferred_experts_per_token=max_deferred_experts_per_token,
             method=method,
+            numa_nodes=numa_nodes,
         )
 
         if NativeMoEWrapper._native_loader_instance is None:
@@ -416,7 +423,12 @@ def load_weights(self, physical_to_logical_map_cpu: torch.Tensor):
         t0 = time.time()
 
         base_key = f"model.layers.{self.layer_idx}"
-        weights = self.loader.load_experts(base_key)
+        try:
+            weights = self.loader.load_experts(base_key)
+        except (ValueError, KeyError):
+            # For VL/multimodal models (e.g. Qwen3.5) with 'language_model' prefix
+            base_key = f"model.language_model.layers.{self.layer_idx}"
+            weights = self.loader.load_experts(base_key)
         t1 = time.time()
 
         # Keep individual tensors instead of stacking - avoid expensive memory copy
diff --git a/kt-kernel/python/utils/llamafile.py b/kt-kernel/python/utils/llamafile.py
index 708c29d15..66ebbeaa7 100644
--- a/kt-kernel/python/utils/llamafile.py
+++ b/kt-kernel/python/utils/llamafile.py
@@ -1,5 +1,5 @@
 import torch
-from typing import Optional
+from typing import List, Optional
 import os
 
 # Use relative imports for package structure
@@ -133,6 +133,7 @@ def __init__(
             cpu_save=cpu_save,
             max_deferred_experts_per_token=max_deferred_experts_per_token,
             method=method,
+            numa_nodes=numa_nodes,
         )
 
         self.weights_to_keep = None
diff --git a/kt-kernel/python/utils/moe_kernel.py b/kt-kernel/python/utils/moe_kernel.py
index 1d772eab4..87326f359 100644
--- a/kt-kernel/python/utils/moe_kernel.py
+++ b/kt-kernel/python/utils/moe_kernel.py
@@ -1,7 +1,7 @@
 import os
 import torch
 import ctypes
-from typing import Optional
+from typing import List, Optional
 
 # Use relative imports for package structure
 from ..experts_base import BaseMoEWrapper
@@ -97,6 +97,7 @@ def __init__(
             cpu_save=cpu_save,
             max_deferred_experts_per_token=max_deferred_experts_per_token,
             method=method,
+            numa_nodes=numa_nodes,
         )
 
         # moe-specific: Check if we should load merged safetensor weights
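
For reference, here is a minimal standalone sketch of the subpool mapping this patch adds to `BaseMoEWrapper.__init__`. The validation and thread-distribution logic mirror the `experts_base.py` hunk above; the helper name `build_subpool_config` and the example topology are illustrative assumptions, not part of the patch:

```python
from typing import List, Optional, Tuple

def build_subpool_config(
    cpuinfer_threads: int,
    threadpool_count: int,
    numa_nodes: Optional[List[int]] = None,
) -> Tuple[List[int], List[int]]:
    """Mirror of the subpool mapping logic added to BaseMoEWrapper.__init__."""
    if numa_nodes is not None:
        # One entry per subpool: each subpool is pinned to exactly one NUMA node.
        if len(numa_nodes) != threadpool_count:
            raise ValueError(
                f"numa_nodes length ({len(numa_nodes)}) must match "
                f"threadpool_count ({threadpool_count})"
            )
        subpool_numa_map = list(numa_nodes)
    else:
        # Previous (and default) behavior: sequential node IDs 0..threadpool_count-1.
        subpool_numa_map = list(range(threadpool_count))
    # Spread threads as evenly as possible; the first
    # (cpuinfer_threads % threadpool_count) subpools take one extra thread each.
    subpool_thread_count = [
        cpuinfer_threads // threadpool_count
        + (1 if i < cpuinfer_threads % threadpool_count else 0)
        for i in range(threadpool_count)
    ]
    return subpool_numa_map, subpool_thread_count

# Hypothetical topology: skip NUMA node 1 (e.g. a memory-only node) and pin the
# two worker subpools to nodes 0 and 2 instead of the sequential default.
print(build_subpool_config(62, 2, numa_nodes=[0, 2]))  # ([0, 2], [31, 31])
print(build_subpool_config(61, 2))                     # ([0, 1], [31, 30])
```

Without the new parameter, a machine whose usable nodes are not numbered 0..N-1 could not be expressed; passing `numa_nodes` makes the node IDs explicit while keeping the old sequential mapping as the default.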
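Similarly, a sketch of the checkpoint-key fallback added to both `load_weights` implementations: plain LLM checkpoints name decoder layers `model.layers.N`, while VL/multimodal checkpoints (e.g. Qwen3.5) nest them under `model.language_model.layers.N`. The helper name and the duck-typed `loader` argument are assumptions for illustration:

```python
def load_experts_with_fallback(loader, layer_idx: int):
    """Try the plain-LLM layer key first, then the multimodal prefix."""
    try:
        return loader.load_experts(f"model.layers.{layer_idx}")
    except (ValueError, KeyError):
        # VL/multimodal checkpoints nest the text decoder under
        # 'language_model', so retry with the prefixed key.
        return loader.load_experts(f"model.language_model.layers.{layer_idx}")
```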