diff --git a/convert-dense.py b/convert-dense.py index f2c05f17..154d5617 100755 --- a/convert-dense.py +++ b/convert-dense.py @@ -31,23 +31,25 @@ from sentencepiece import SentencePieceProcessor import os -if 'NO_LOCAL_GGUF' not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) + +if "NO_LOCAL_GGUF" not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / "gguf-py")) import gguf if TYPE_CHECKING: from typing import TypeAlias -if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'): +if hasattr(faulthandler, "register") and hasattr(signal, "SIGUSR1"): faulthandler.register(signal.SIGUSR1) -NDArray: TypeAlias = 'np.ndarray[Any, Any]' +NDArray: TypeAlias = "np.ndarray[Any, Any]" DEFAULT_CONCURRENCY = 8 # # data types # + @dataclass(frozen=True) class DataType: name: str @@ -57,14 +59,23 @@ class DataType: def elements_to_bytes(self, n_elements: int) -> int: return n_elements * self.dtype.itemsize + @dataclass(frozen=True) class UnquantizedDataType(DataType): pass -DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0']) -DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0']) -DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = []) -DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0']) + +DT_F16 = UnquantizedDataType( + "F16", dtype=np.dtype(np.float16), valid_conversions=["F32", "Q8_0"] +) +DT_F32 = UnquantizedDataType( + "F32", dtype=np.dtype(np.float32), valid_conversions=["F16", "Q8_0"] +) +DT_I32 = UnquantizedDataType("I32", dtype=np.dtype(np.int16), valid_conversions=[]) +DT_BF16 = UnquantizedDataType( + "BF16", dtype=np.dtype(np.uint16), valid_conversions=["F32", "F16", "Q8_0"] +) + @dataclass(frozen=True) class QuantizedDataType(DataType): @@ -73,54 +84,69 @@ class QuantizedDataType(DataType): ggml_type: gguf.GGMLQuantizationType def quantize(self, arr: NDArray) -> NDArray: - raise NotImplementedError(f'Quantization for {self.name} not implemented') + raise NotImplementedError(f"Quantization for {self.name} not implemented") def elements_to_bytes(self, n_elements: int) -> int: - assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}' + assert ( + n_elements % self.block_size == 0 + ), f"Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}" return self.quantized_dtype.itemsize * (n_elements // self.block_size) + @dataclass(frozen=True) class Q8_0QuantizedDataType(QuantizedDataType): # Mini Q8_0 quantization in Python! def quantize(self, arr: NDArray) -> NDArray: - assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}' - assert arr.dtype == np.float32, f'Bad array type {arr.dtype}' + assert ( + arr.size % self.block_size == 0 and arr.size != 0 + ), f"Bad array size {arr.size}" + assert arr.dtype == np.float32, f"Bad array type {arr.dtype}" n_blocks = arr.size // self.block_size blocks = arr.reshape((n_blocks, self.block_size)) + # Much faster implementation of block quantization contributed by @Cebtenzzre def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[tuple[Any, Any]]: - d = abs(blocks).max(axis = 1) / np.float32(127) - with np.errstate(divide = 'ignore'): + d = abs(blocks).max(axis=1) / np.float32(127) + with np.errstate(divide="ignore"): qs = (blocks / d[:, None]).round() qs[d == 0] = 0 yield from zip(d, qs) - return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype) -DT_Q8_0 = Q8_0QuantizedDataType('Q8_0', - dtype = np.dtype(np.float32), valid_conversions = [], - ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32, - quantized_dtype = np.dtype([('d', ' DataType: @@ -130,9 +156,10 @@ def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType: # 1D tensors are always F32. return dt if len(tensor.shape) > 1 else DT_F32 + GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = { - GGMLFileType.AllF32 : DT_F32, - GGMLFileType.MostlyF16 : DT_F16, + GGMLFileType.AllF32: DT_F32, + GGMLFileType.MostlyF16: DT_F16, GGMLFileType.MostlyQ8_0: DT_Q8_0, } @@ -140,18 +167,19 @@ def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType: # hparams loading # + @dataclass class Params: - n_vocab: int - n_embd: int - n_layer: int - n_ctx: int - n_ff: int - n_head: int - n_head_kv: int + n_vocab: int + n_embd: int + n_layer: int + n_ctx: int + n_ff: int + n_head: int + n_head_kv: int f_norm_eps: float - arch: gguf.MODEL_ARCH = gguf.MODEL_ARCH.LLAMA + arch: gguf.MODEL_ARCH = gguf.MODEL_ARCH.LLAMA rope_scaling_type: gguf.RopeScalingType | None = None f_rope_freq_base: float | None = None @@ -167,36 +195,56 @@ class Params: @staticmethod def guessed(model: LazyModel) -> Params: # try transformer naming first - n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape + n_vocab, n_embd = ( + model["model.embed_tokens.weight"].shape + if "model.embed_tokens.weight" in model + else model["tok_embeddings.weight"].shape + ) # try transformer naming first if "model.layers.0.self_attn.q_proj.weight" in model: - n_layer=next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model) - elif "model.layers.0.self_attn.W_pack.weight" in model: # next: try baichuan naming - n_layer=next(i for i in itertools.count() if f"model.layers.{i}.self_attn.W_pack.weight" not in model) + n_layer = next( + i + for i in itertools.count() + if f"model.layers.{i}.self_attn.q_proj.weight" not in model + ) + elif ( + "model.layers.0.self_attn.W_pack.weight" in model + ): # next: try baichuan naming + n_layer = next( + i + for i in itertools.count() + if f"model.layers.{i}.self_attn.W_pack.weight" not in model + ) else: - n_layer=next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model) + n_layer = next( + i + for i in itertools.count() + if f"layers.{i}.attention.wq.weight" not in model + ) if n_layer < 1: - raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n" - "Suggestion: provide 'config.json' of the model in the same directory containing model files.") + raise Exception( + "failed to guess 'n_layer'. This model is unknown or unsupported.\n" + "Suggestion: provide 'config.json' of the model in the same directory containing model files." + ) - n_head = n_embd // 128 # guessed - n_mult = 256 # guessed + n_head = n_embd // 128 # guessed + n_mult = 256 # guessed # TODO: verify this n_ff = int(2 * (4 * n_embd) / 3) n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult) return Params( - n_vocab = n_vocab, - n_embd = n_embd, - n_layer = n_layer, - n_ctx = -1, - n_ff = n_ff, - n_head = n_head, - n_head_kv = n_head, - f_norm_eps = 1e-5, + n_vocab=n_vocab, + n_embd=n_embd, + n_layer=n_layer, + n_ctx=-1, + n_ff=n_ff, + n_head=n_head, + n_head_kv=n_head, + f_norm_eps=1e-5, ) @staticmethod @@ -213,33 +261,35 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: rope_scaling_type = gguf.RopeScalingType.LINEAR elif typ == "yarn": rope_scaling_type = gguf.RopeScalingType.YARN - n_orig_ctx = rope_scaling['original_max_position_embeddings'] - rope_finetuned = rope_scaling['finetuned'] + n_orig_ctx = rope_scaling["original_max_position_embeddings"] + rope_finetuned = rope_scaling["finetuned"] else: - raise NotImplementedError(f'Unknown rope scaling type: {typ}') + raise NotImplementedError(f"Unknown rope scaling type: {typ}") if "max_sequence_length" in config: n_ctx = config["max_sequence_length"] elif "max_position_embeddings" in config: n_ctx = config["max_position_embeddings"] else: - raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n" - "Suggestion: provide 'config.json' of the model in the same directory containing model files.") + raise Exception( + "failed to guess 'n_ctx'. This model is unknown or unsupported.\n" + "Suggestion: provide 'config.json' of the model in the same directory containing model files." + ) params = Params( - n_vocab = config["vocab_size"], - n_embd = config["hidden_size"], - n_layer = config["num_hidden_layers"], - n_ctx = n_ctx, - n_ff = config["intermediate_size"], - n_head = (n_head := config["num_attention_heads"]), - n_head_kv = config.get("num_key_value_heads", n_head), - f_norm_eps = config["rms_norm_eps"], - f_rope_freq_base = config.get("rope_theta"), - rope_scaling_type = rope_scaling_type, - f_rope_scale = f_rope_scale, - n_orig_ctx = n_orig_ctx, - rope_finetuned = rope_finetuned, + n_vocab=config["vocab_size"], + n_embd=config["hidden_size"], + n_layer=config["num_hidden_layers"], + n_ctx=n_ctx, + n_ff=config["intermediate_size"], + n_head=(n_head := config["num_attention_heads"]), + n_head_kv=config.get("num_key_value_heads", n_head), + f_norm_eps=config["rms_norm_eps"], + f_rope_freq_base=config.get("rope_theta"), + rope_scaling_type=rope_scaling_type, + f_rope_scale=f_rope_scale, + n_orig_ctx=n_orig_ctx, + rope_finetuned=rope_finetuned, ) if config.get("model_type", None) == "bamboo": @@ -247,7 +297,6 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: return params - # LLaMA v2 70B params.json # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1} @staticmethod @@ -266,30 +315,30 @@ def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: n_ctx = 2048 return Params( - n_vocab = config.get("vocab_size", model["tok_embeddings.weight"].shape[0]), - n_embd = config["dim"], - n_layer = config["n_layers"], - n_ctx = n_ctx, - n_ff = model["layers.0.feed_forward.w1.weight"].shape[0], - n_head = (n_head := config["n_heads"]), - n_head_kv = config.get("n_kv_heads", n_head), - f_norm_eps = config["norm_eps"], - f_rope_freq_base = config.get("rope_theta"), + n_vocab=config.get("vocab_size", model["tok_embeddings.weight"].shape[0]), + n_embd=config["dim"], + n_layer=config["n_layers"], + n_ctx=n_ctx, + n_ff=model["layers.0.feed_forward.w1.weight"].shape[0], + n_head=(n_head := config["n_heads"]), + n_head_kv=config.get("n_kv_heads", n_head), + f_norm_eps=config["norm_eps"], + f_rope_freq_base=config.get("rope_theta"), ) @staticmethod def load(model_plus: ModelPlus) -> Params: - hf_config_path = model_plus.paths[0].parent / "config.json" + hf_config_path = model_plus.paths[0].parent / "config.json" orig_config_path = model_plus.paths[0].parent / "params.json" if hf_config_path.exists(): params = Params.loadHFTransformerJson(model_plus.model, hf_config_path) elif orig_config_path.exists(): params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path) - elif model_plus.format != 'none': + elif model_plus.format != "none": params = Params.guessed(model_plus.model) else: - raise ValueError('Cannot guess params when model format is none') + raise ValueError("Cannot guess params when model format is none") params.path_model = model_plus.paths[0].parent @@ -300,43 +349,50 @@ def load(model_plus: ModelPlus) -> Params: # vocab # + class BpeVocab: def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None: - self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read()) + self.bpe_tokenizer = json.loads( + open(str(fname_tokenizer), encoding="utf-8").read() + ) added_tokens: dict[str, int] if fname_added_tokens is not None: # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab. added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) else: # Fall back to trying to find the added tokens in tokenizer.json - tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json' + tokenizer_json_file = fname_tokenizer.parent / "tokenizer.json" if not tokenizer_json_file.is_file(): added_tokens = {} else: tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8")) added_tokens = dict( - (item['content'], item['id']) - for item in tokenizer_json.get('added_tokens', []) + (item["content"], item["id"]) + for item in tokenizer_json.get("added_tokens", []) # Added tokens here can be duplicates of the main vocabulary. - if item['content'] not in self.bpe_tokenizer ) + if item["content"] not in self.bpe_tokenizer + ) vocab_size: int = len(self.bpe_tokenizer) - expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) - actual_ids = sorted(added_tokens.values()) + expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) + actual_ids = sorted(added_tokens.values()) if expected_ids != actual_ids: expected_end_id = vocab_size + len(actual_ids) - 1 - raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}") + raise Exception( + f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}" + ) items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) - self.added_tokens_list = [text for (text, idx) in items] + self.added_tokens_list = [text for (text, idx) in items] self.vocab_size_base: int = vocab_size - self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list) - self.fname_tokenizer = fname_tokenizer - self.fname_added_tokens = fname_added_tokens + self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer + self.fname_added_tokens = fname_added_tokens def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: tokenizer = self.bpe_tokenizer from transformers.models.gpt2 import tokenization_gpt2 + reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.items()} for i, _ in enumerate(tokenizer): @@ -366,18 +422,22 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> No vocab_size: int = self.sentencepiece_tokenizer.vocab_size() - new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size} + new_tokens = { + id: piece for piece, id in added_tokens.items() if id >= vocab_size + } expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens))) - actual_new_ids = sorted(new_tokens.keys()) + actual_new_ids = sorted(new_tokens.keys()) if expected_new_ids != actual_new_ids: - raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}") + raise ValueError( + f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}" + ) # Token pieces that were added to the base vocabulary. - self.added_tokens_list = [new_tokens[id] for id in actual_new_ids] - self.vocab_size_base = vocab_size - self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) - self.fname_tokenizer = fname_tokenizer + self.added_tokens_list = [new_tokens[id] for id in actual_new_ids] + self.vocab_size_base = vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer self.fname_added_tokens = fname_added_tokens def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: @@ -416,20 +476,24 @@ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: def __repr__(self) -> str: return f"" -Vocab: TypeAlias = 'BpeVocab | SentencePieceVocab' + +Vocab: TypeAlias = "BpeVocab | SentencePieceVocab" # # data loading # TODO: reuse (probably move to gguf.py?) # + def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray: - #print( "permute debug " + str(weights.shape[0]) + " x " + str(weights.shape[1]) + " nhead " + str(n_head) + " nheadkv " + str(n_kv_head) ) + # print( "permute debug " + str(weights.shape[0]) + " x " + str(weights.shape[1]) + " nhead " + str(n_head) + " nheadkv " + str(n_kv_head) ) if n_head_kv is not None and n_head != n_head_kv: n_head = n_head_kv - return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape)) + return ( + weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape) + ) class Tensor(metaclass=ABCMeta): @@ -440,7 +504,9 @@ def astype(self, data_type: DataType) -> Tensor: ... @abstractmethod def permute(self, n_head: int, n_head_kv: int) -> Tensor: ... @abstractmethod - def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: ... + def permute_part( + self, n_part: int, n_head: int, n_head_kv: int + ) -> UnquantizedTensor: ... @abstractmethod def part(self, n_part: int) -> UnquantizedTensor: ... @abstractmethod @@ -448,7 +514,9 @@ def to_ggml(self) -> GGMLCompatibleTensor: ... def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray: - assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}" + assert ( + bf16_arr.dtype == np.uint16 + ), f"Input array should be of dtype uint16, but got {bf16_arr.dtype}" fp32_arr = bf16_arr.astype(np.uint32) << 16 return fp32_arr.view(np.float32) @@ -468,9 +536,13 @@ def astype(self, data_type: DataType) -> Tensor: def to_ggml(self) -> UnquantizedTensor: return self - def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: + def permute_part( + self, n_part: int, n_head: int, n_head_kv: int + ) -> UnquantizedTensor: r = self.ndarray.shape[0] // 3 - return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv)) + return UnquantizedTensor( + permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv) + ) def part(self, n_part: int) -> UnquantizedTensor: r = self.ndarray.shape[0] // 3 @@ -480,7 +552,9 @@ def permute(self, n_head: int, n_head_kv: int) -> UnquantizedTensor: return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv)) -def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False) -> NDArray: +def load_unquantized( + lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False +) -> NDArray: tensor = lazy_tensor.load() assert isinstance(tensor, UnquantizedTensor) @@ -491,7 +565,9 @@ def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, conver if convert: tensor.ndarray = tensor.ndarray.astype(expected_dtype) else: - raise ValueError(f'expected this tensor to have dtype {expected_dtype}, got {tensor.ndarray.dtype}') + raise ValueError( + f"expected this tensor to have dtype {expected_dtype}, got {tensor.ndarray.dtype}" + ) return tensor.ndarray @@ -509,8 +585,9 @@ class LazyTensor: def load(self) -> Tensor: ret = self._load() # Should be okay if it maps to the same numpy type? - assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \ - (self.data_type, ret.data_type, self.description) + assert ret.data_type == self.data_type or ( + self.data_type.dtype == ret.data_type.dtype + ), (self.data_type, ret.data_type, self.description) return ret def astype(self, data_type: DataType) -> LazyTensor: @@ -518,21 +595,29 @@ def astype(self, data_type: DataType) -> LazyTensor: def load() -> Tensor: return self.load().astype(data_type) - return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}') + + return LazyTensor( + load, self.shape, data_type, f"convert({data_type}) {self.description}" + ) def validate_conversion_to(self, data_type: DataType) -> None: - if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions: - raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.') + if ( + data_type != self.data_type + and data_type.name not in self.data_type.valid_conversions + ): + raise ValueError( + f"Cannot validate conversion from {self.data_type} to {data_type}." + ) -LazyModel: TypeAlias = 'dict[str, LazyTensor]' +LazyModel: TypeAlias = "dict[str, LazyTensor]" @dataclass class ModelPlus: model: LazyModel paths: list[Path] # Where this was read from. - format: Literal['ggml', 'torch', 'safetensors', 'none'] + format: Literal["ggml", "torch", "safetensors", "none"] vocab: Vocab | None # For GGML models (which have vocab built in), the vocab. @@ -550,9 +635,11 @@ def convert(name: str) -> LazyTensor: if len(lazy_tensors[0].shape) == 1: # the tensor is just duplicated in every file return lazy_tensors[0] - if name.startswith('tok_embeddings.') or \ - name.endswith('.attention.wo.weight') or \ - name.endswith('.feed_forward.w2.weight'): + if ( + name.startswith("tok_embeddings.") + or name.endswith(".attention.wo.weight") + or name.endswith(".feed_forward.w2.weight") + ): # split by columns axis = 1 else: @@ -565,8 +652,16 @@ def load() -> UnquantizedTensor: ndarrays = [load_unquantized(tensor) for tensor in lazy_tensors] concatenated: NDArray = np.concatenate(ndarrays, axis=axis) return UnquantizedTensor(concatenated) - description = 'concatenated[[' + '] | ['.join(lt.description for lt in lazy_tensors) + ']]' - return LazyTensor(load, concatenated_shape, lazy_tensors[0].data_type, description) + + description = ( + "concatenated[[" + + "] | [".join(lt.description for lt in lazy_tensors) + + "]]" + ) + return LazyTensor( + load, concatenated_shape, lazy_tensors[0].data_type, description + ) + return {name: convert(name) for name in names} @@ -596,21 +691,38 @@ def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus: def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor: def load() -> Tensor: return lazy_tensor.load().permute(n_head, n_head_kv) - return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) -def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int) -> LazyTensor: + return LazyTensor( + load, + lazy_tensor.shape, + lazy_tensor.data_type, + f"permute({n_head}, {n_head_kv}) " + lazy_tensor.description, + ) + + +def permute_part_lazy( + lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int +) -> LazyTensor: def load() -> Tensor: return lazy_tensor.load().permute_part(n_part, n_head, n_head_kv) + s = lazy_tensor.shape.copy() s[0] = s[0] // 3 - return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) + return LazyTensor( + load, + s, + lazy_tensor.data_type, + f"permute({n_head}, {n_head_kv}) " + lazy_tensor.description, + ) + def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor: def load() -> Tensor: return lazy_tensor.load().part(n_part) + s = lazy_tensor.shape.copy() s[0] = s[0] // 3 - return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description) + return LazyTensor(load, s, lazy_tensor.data_type, "part " + lazy_tensor.description) # Functionality that simulates `torch.load` but where individual tensors are @@ -640,11 +752,11 @@ def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile self.zip_file = zip_file def persistent_load(self, pid: Any) -> Any: - assert pid[0] == 'storage' + assert pid[0] == "storage" assert isinstance(pid[1], LazyStorageKind) data_type = pid[1].data_type filename_stem = pid[2] - filename = f'{self.data_base_path}/{filename_stem}' + filename = f"{self.data_base_path}/{filename_stem}" info = self.zip_file.getinfo(filename) def load(offset: int, elm_count: int) -> NDArray: @@ -655,18 +767,31 @@ def load(offset: int, elm_count: int) -> NDArray: data = fp.read(size) assert len(data) == size return np.frombuffer(data, dtype) - description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}' + + description = f"storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}" return LazyStorage(load=load, kind=pid[1], description=description) @staticmethod - def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any, - requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor: + def lazy_rebuild_tensor_v2( + storage: Any, + storage_offset: Any, + size: Any, + stride: Any, + requires_grad: Any, + backward_hooks: Any, + metadata: Any = None, + ) -> LazyTensor: assert isinstance(storage, LazyStorage) def load() -> UnquantizedTensor: elm_count = stride[0] * size[0] - return UnquantizedTensor(storage.load(storage_offset, elm_count).reshape(size)) - description = f'pickled storage_offset={storage_offset} in {storage.description}' + return UnquantizedTensor( + storage.load(storage_offset, elm_count).reshape(size) + ) + + description = ( + f"pickled storage_offset={storage_offset} in {storage.description}" + ) return LazyTensor(load, list(size), storage.kind.data_type, description) @staticmethod @@ -676,56 +801,68 @@ def rebuild_from_type_v2(func, new_type, args, state): CLASSES: dict[tuple[str, str], Any] = { # getattr used here as a workaround for mypy not being smart enough to detrmine # the staticmethods have a __func__ attribute. - ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'), - ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'), - ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16), - ('torch', 'HalfStorage'): LazyStorageKind(DT_F16), - ('torch', 'FloatStorage'): LazyStorageKind(DT_F32), - ('torch', 'IntStorage'): LazyStorageKind(DT_I32), - ('torch', 'Tensor'): LazyTensor, + ("torch._tensor", "_rebuild_from_type_v2"): getattr( + rebuild_from_type_v2, "__func__" + ), + ("torch._utils", "_rebuild_tensor_v2"): getattr( + lazy_rebuild_tensor_v2, "__func__" + ), + ("torch", "BFloat16Storage"): LazyStorageKind(DT_BF16), + ("torch", "HalfStorage"): LazyStorageKind(DT_F16), + ("torch", "FloatStorage"): LazyStorageKind(DT_F32), + ("torch", "IntStorage"): LazyStorageKind(DT_I32), + ("torch", "Tensor"): LazyTensor, } def find_class(self, module: str, name: str) -> Any: - if not module.startswith('torch'): + if not module.startswith("torch"): return super().find_class(module, name) return self.CLASSES[(module, name)] def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus: zf = zipfile.ZipFile(outer_fp) - pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')] + pickle_paths = [name for name in zf.namelist() if name.endswith(".pkl")] assert len(pickle_paths) == 1, pickle_paths - pickle_fp = zf.open(pickle_paths[0], 'r') - unpickler = LazyUnpickler(pickle_fp, - data_base_path=pickle_paths[0][:-4], - zip_file=zf) + pickle_fp = zf.open(pickle_paths[0], "r") + unpickler = LazyUnpickler( + pickle_fp, data_base_path=pickle_paths[0][:-4], zip_file=zf + ) model = unpickler.load() as_dict = dict(model.items()) - return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None) + return ModelPlus(model=as_dict, paths=[path], format="torch", vocab=None) def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: - header_size, = struct.unpack(' LazyTensor: - data_type = SAFETENSORS_DATA_TYPES[info['dtype']] + data_type = SAFETENSORS_DATA_TYPES[info["dtype"]] numpy_dtype = data_type.dtype - shape: list[int] = info['shape'] - begin, end = info['data_offsets'] + shape: list[int] = info["shape"] + begin, end = info["data_offsets"] assert 0 <= begin <= end <= len(byte_buf) assert end - begin == math.prod(shape) * numpy_dtype.itemsize buf = byte_buf[begin:end] def load() -> UnquantizedTensor: - return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)) - description = f'safetensors begin={begin} end={end} type={data_type} path={path}' + return UnquantizedTensor( + np.frombuffer(buf, dtype=numpy_dtype).reshape(shape) + ) + + description = ( + f"safetensors begin={begin} end={end} type={data_type} path={path}" + ) return LazyTensor(load, shape, data_type, description) - model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'} - return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None) + + model = { + name: convert(info) for (name, info) in header.items() if name != "__metadata__" + } + return ModelPlus(model=model, paths=[path], format="safetensors", vocab=None) def must_read(fp: IO[bytes], length: int) -> bytes: @@ -737,27 +874,34 @@ def must_read(fp: IO[bytes], length: int) -> bytes: @functools.lru_cache(maxsize=None) def lazy_load_file(path: Path) -> ModelPlus: - fp = open(path, 'rb') + fp = open(path, "rb") first8 = fp.read(8) fp.seek(0) - if first8[:2] == b'PK': + if first8[:2] == b"PK": # A zip file, i.e. PyTorch format return lazy_load_torch_file(fp, path) - elif struct.unpack(' Iterable[Out]: - '''Parallel map, but with backpressure. If the caller doesn't call `next` + +def bounded_parallel_map( + func: Callable[[In], Out], + iterable: Iterable[In], + concurrency: int, + max_workers: int | None = None, + use_processpool_executor: bool = False, +) -> Iterable[Out]: + """Parallel map, but with backpressure. If the caller doesn't call `next` fast enough, this will stop calling `func` at some point rather than letting results pile up in memory. Specifically, there is a max of one - output value buffered per thread.''' + output value buffered per thread.""" if concurrency < 2: yield from map(func, iterable) # Not reached. @@ -767,7 +911,7 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc executor_class = ProcessPoolExecutor else: executor_class = ThreadPoolExecutor - with executor_class(max_workers = max_workers) as executor: + with executor_class(max_workers=max_workers) as executor: futures: list[concurrent.futures.Future[Out]] = [] done = False for _ in range(concurrency): @@ -787,11 +931,14 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc break yield result + def check_vocab_size(params: Params, vocab: Vocab) -> None: if params.n_vocab != vocab.vocab_size: assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab) if params.n_vocab == vocab.vocab_size_base: - print("Ignoring added_tokens.json since model matches vocab size without it.") + print( + "Ignoring added_tokens.json since model matches vocab size without it." + ) vocab.added_tokens_list = [] vocab.vocab_size = vocab.vocab_size_base return @@ -799,14 +946,24 @@ def check_vocab_size(params: Params, vocab: Vocab) -> None: if vocab.fname_added_tokens is not None: msg += f" combined with {vocab.fname_added_tokens}" msg += f" has {vocab.vocab_size})." - if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20 and vocab.fname_added_tokens is None: + if ( + vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20 + and vocab.fname_added_tokens is None + ): msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})." raise Exception(msg) class OutputFile: - def __init__(self, fname_out: Path, arch: gguf.MODEL_ARCH, endianess:gguf.GGUFEndian=gguf.GGUFEndian.LITTLE) -> None: - self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[arch], endianess=endianess) + def __init__( + self, + fname_out: Path, + arch: gguf.MODEL_ARCH, + endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, + ) -> None: + self.gguf = gguf.GGUFWriter( + fname_out, gguf.MODEL_ARCH_NAMES[arch], endianess=endianess + ) def add_meta_arch(self, params: Params) -> None: name = "LLaMA" @@ -815,17 +972,17 @@ def add_meta_arch(self, params: Params) -> None: if params.n_ctx == 4096: name = "LLaMA v2" elif params.path_model is not None: - name = str(params.path_model).split('/')[-1] + name = str(params.path_model).split("/")[-1] - self.gguf.add_name (name) - self.gguf.add_context_length (params.n_ctx) - self.gguf.add_embedding_length (params.n_embd) - self.gguf.add_block_count (params.n_layer) - self.gguf.add_feed_forward_length (params.n_ff) + self.gguf.add_name(name) + self.gguf.add_context_length(params.n_ctx) + self.gguf.add_embedding_length(params.n_embd) + self.gguf.add_block_count(params.n_layer) + self.gguf.add_feed_forward_length(params.n_ff) self.gguf.add_rope_dimension_count(params.n_embd // params.n_head) - self.gguf.add_head_count (params.n_head) - self.gguf.add_head_count_kv (params.n_head_kv) - self.gguf.add_layer_norm_rms_eps (params.f_norm_eps) + self.gguf.add_head_count(params.n_head) + self.gguf.add_head_count_kv(params.n_head_kv) + self.gguf.add_layer_norm_rms_eps(params.f_norm_eps) if params.f_rope_freq_base is not None: self.gguf.add_rope_freq_base(params.f_rope_freq_base) @@ -859,7 +1016,7 @@ def add_meta_vocab(self, vocab: Vocab) -> None: elif isinstance(vocab, BpeVocab): self.gguf.add_tokenizer_model("gpt2") else: - raise ValueError('Unknown vocab type: Not BpeVocab or SentencePieceVocab') + raise ValueError("Unknown vocab type: Not BpeVocab or SentencePieceVocab") self.gguf.add_token_list(tokens) self.gguf.add_token_scores(scores) self.gguf.add_token_types(toktypes) @@ -869,10 +1026,14 @@ def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None: def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: n_elements = int(np.prod(tensor.shape)) - raw_dtype = getattr(tensor.data_type, 'ggml_type', None) - data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype + raw_dtype = getattr(tensor.data_type, "ggml_type", None) + data_type = ( + getattr(tensor.data_type, "quantized_type", None) or tensor.data_type.dtype + ) data_nbytes = tensor.data_type.elements_to_bytes(n_elements) - self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype) + self.gguf.add_tensor_info( + name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype + ) def write_meta(self) -> None: self.gguf.write_header_to_file() @@ -885,7 +1046,13 @@ def close(self) -> None: self.gguf.close() @staticmethod - def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, endianess:gguf.GGUFEndian=gguf.GGUFEndian.LITTLE) -> None: + def write_vocab_only( + fname_out: Path, + params: Params, + vocab: Vocab, + svocab: gguf.SpecialVocab, + endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, + ) -> None: check_vocab_size(params, vocab) of = OutputFile(fname_out, params.arch, endianess=endianess) @@ -913,7 +1080,16 @@ def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray: return dt.quantize(arr) @staticmethod - def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None: + def write_all( + fname_out: Path, + ftype: GGMLFileType, + params: Params, + model: LazyModel, + vocab: Vocab, + svocab: gguf.SpecialVocab, + concurrency: int = DEFAULT_CONCURRENCY, + endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, + ) -> None: check_vocab_size(params, vocab) of = OutputFile(fname_out, params.arch, endianess=endianess) @@ -931,43 +1107,68 @@ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyM of.write_tensor_info() # tensor data - ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency) + ndarrays_inner = bounded_parallel_map( + OutputFile.do_item, model.items(), concurrency=concurrency + ) if ftype == GGMLFileType.MostlyQ8_0: - ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, use_processpool_executor = True) + ndarrays = bounded_parallel_map( + OutputFile.maybe_do_quantize, + ndarrays_inner, + concurrency=concurrency, + max_workers=concurrency, + use_processpool_executor=True, + ) else: ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) start = time.time() - for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): + for i, ((name, lazy_tensor), ndarray) in enumerate( + zip(model.items(), ndarrays) + ): elapsed = time.time() - start - size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) + size = " x ".join(f"{dim:6d}" for dim in lazy_tensor.shape) padi = len(str(len(model))) - print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}") + print( + f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}" + ) of.gguf.write_tensor_data(ndarray) of.close() + def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType: - wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type + wq_type = model[ + gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight" + ].data_type if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32): return GGMLFileType.AllF32 - if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)): + if output_type_str == "f16" or ( + output_type_str is None and wq_type in (DT_F16, DT_BF16) + ): return GGMLFileType.MostlyF16 if output_type_str == "q8_0": return GGMLFileType.MostlyQ8_0 - name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()} + name_to_type = { + name: lazy_tensor.data_type for (name, lazy_tensor) in model.items() + } raise Exception(f"Unexpected combination of types: {name_to_type}") + def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel: - return {name: tensor.astype(output_type.type_for_tensor(name, tensor)) - for (name, tensor) in model.items()} + return { + name: tensor.astype(output_type.type_for_tensor(name, tensor)) + for (name, tensor) in model.items() + } + def convert_model_names(model: LazyModel, params: Params) -> LazyModel: tmap = gguf.TensorNameMap(params.arch, params.n_layer) - should_skip: set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(params.arch, [])) + should_skip: set[gguf.MODEL_TENSOR] = set( + gguf.MODEL_TENSOR_SKIP.get(params.arch, []) + ) tmp = model @@ -975,21 +1176,43 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel: for i in itertools.count(): if f"model.layers.{i}.self_attn.q_proj.weight" in model: print(f"Permuting layer {i}") - tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head, params.n_head) - tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv) - #tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"] + tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_lazy( + model[f"model.layers.{i}.self_attn.q_proj.weight"], + params.n_head, + params.n_head, + ) + tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_lazy( + model[f"model.layers.{i}.self_attn.k_proj.weight"], + params.n_head, + params.n_head_kv, + ) + # tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"] elif f"model.layers.{i}.self_attn.W_pack.weight" in model: print(f"Unpacking and permuting layer {i}") - tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head) - tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv) - tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2) + tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy( + model[f"model.layers.{i}.self_attn.W_pack.weight"], + 0, + params.n_head, + params.n_head, + ) + tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy( + model[f"model.layers.{i}.self_attn.W_pack.weight"], + 1, + params.n_head, + params.n_head_kv, + ) + tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy( + model[f"model.layers.{i}.self_attn.W_pack.weight"], 2 + ) del tmp[f"model.layers.{i}.self_attn.W_pack.weight"] else: break out: LazyModel = {} for name, lazy_tensor in model.items(): - tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None) + tensor_type, name_new = tmap.get_type_and_name( + name, try_suffixes=(".weight", ".bias") + ) or (None, None) if name_new is None: raise Exception(f"Unexpected tensor name: {name}") @@ -997,23 +1220,26 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel: print(f"skipping tensor {name_new}") continue - print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}") + print( + f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}" + ) out[name_new] = lazy_tensor return out + def nth_multifile_path(path: Path, n: int) -> Path | None: - '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + """Given any path belonging to a multi-file model (e.g. foo.bin.1), return the nth path in the model. - ''' + """ # Support the following patterns: patterns: list[tuple[str, str]] = [ # - x.00.pth, x.01.pth, etc. - (r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'), + (r"\.[0-9]{2}\.pth$", f".{n:02}.pth"), # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc. - (r'-[0-9]{5}-of-(.*)$', fr'-{n:05}-of-\1'), + (r"-[0-9]{5}-of-(.*)$", rf"-{n:05}-of-\1"), # x.bin, x.bin.1, etc. - (r'(\.[0-9]+)?$', r'\1' if n == 0 else fr'\1.{n}') + (r"(\.[0-9]+)?$", r"\1" if n == 0 else rf"\1.{n}"), ] for regex, replacement in patterns: if re.search(regex, path.name): @@ -1024,9 +1250,9 @@ def nth_multifile_path(path: Path, n: int) -> Path | None: def find_multifile_paths(path: Path) -> list[Path]: - '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + """Given any path belonging to a multi-file model (e.g. foo.bin.1), return the whole list of paths in the model. - ''' + """ ret: list[Path] = [] for i in itertools.count(): nth_path = nth_multifile_path(path, i) @@ -1042,7 +1268,7 @@ def find_multifile_paths(path: Path) -> list[Path]: def load_some_model(path: Path) -> ModelPlus: - '''Load a model of any supported format.''' + """Load a model of any supported format.""" # Be extra-friendly and accept either a file or a directory: if path.is_dir(): # Check if it's a set of safetensors files first @@ -1050,12 +1276,19 @@ def load_some_model(path: Path) -> ModelPlus: files = [file for glob in globs for file in path.glob(glob)] if not files: # Try the PyTorch patterns too, with lower priority - globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"] + globs = [ + "consolidated.00.pth", + "pytorch_model-00001-of-*.bin", + "*.pt", + "pytorch_model.bin", + ] files = [file for glob in globs for file in path.glob(glob)] if not files: raise Exception(f"Can't find model in directory {path}") if len(files) > 1: - raise Exception(f"Found multiple models in {path}, not sure which to pick: {files}") + raise Exception( + f"Found multiple models in {path}, not sure which to pick: {files}" + ) path = files[0] paths = find_multifile_paths(path) @@ -1074,7 +1307,7 @@ def load_vocab(path: Path, vocabtype: str | None) -> Vocab: # be in the parent of that. if path.is_dir(): vocab_file = "tokenizer.model" - if vocabtype == 'bpe': + if vocabtype == "bpe": vocab_file = "vocab.json" path2 = path / vocab_file # Use `.parent` instead of /.. to handle the symlink case better. @@ -1086,7 +1319,8 @@ def load_vocab(path: Path, vocabtype: str | None) -> Vocab: else: raise FileNotFoundError( f"Could not find {vocab_file} in {path} or its parent; " - "if it's in another directory, pass the directory as --vocab-dir") + "if it's in another directory, pass the directory as --vocab-dir" + ) print(f"Loading vocab file '{path}', type '{vocabtype}'") @@ -1094,22 +1328,25 @@ def load_vocab(path: Path, vocabtype: str | None) -> Vocab: if vocabtype == "bpe": return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None) elif vocabtype == "spm": - return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None) + return SentencePieceVocab( + path, added_tokens_path if added_tokens_path.exists() else None + ) else: raise ValueError(f"Unsupported vocabulary type {vocabtype}") def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path: namestr = { - GGMLFileType.AllF32: "f32", + GGMLFileType.AllF32: "f32", GGMLFileType.MostlyF16: "f16", - GGMLFileType.MostlyQ8_0:"q8_0", + GGMLFileType.MostlyQ8_0: "q8_0", }[file_type] ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf" if ret in model_paths: sys.stderr.write( f"Error: Default output path ({ret}) would overwrite the input. " - "Please explicitly specify a path using --outfile.\n") + "Please explicitly specify a path using --outfile.\n" + ) sys.exit(1) return ret @@ -1119,7 +1356,9 @@ def do_dump_model(model_plus: ModelPlus) -> None: print(f"model_plus.format = {model_plus.format!r}") print(f"model_plus.vocab = {model_plus.vocab!r}") for name, lazy_tensor in model_plus.model.items(): - print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}") + print( + f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}" + ) def main(args_in: list[str] | None = None) -> None: @@ -1127,21 +1366,63 @@ def main(args_in: list[str] | None = None) -> None: if np.uint32(1) == np.uint32(1).newbyteorder("<"): # We currently only support Q8_0 output on little endian systems. output_choices.append("q8_0") - parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file") - parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") - parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") - parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") - parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)") - parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") - parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") - parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin, *.safetensors)") - parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm") - parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") - parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY) - parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine") + parser = argparse.ArgumentParser( + description="Convert a LLaMa model to a GGML compatible file" + ) + parser.add_argument( + "--dump", + action="store_true", + help="don't convert, just show what's in the model", + ) + parser.add_argument( + "--dump-single", + action="store_true", + help="don't convert, just show what's in a single model file", + ) + parser.add_argument( + "--vocab-only", action="store_true", help="extract only the vocab" + ) + parser.add_argument( + "--outtype", + choices=output_choices, + help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)", + ) + parser.add_argument( + "--vocab-dir", + type=Path, + help="directory containing tokenizer.model, if separate from model file", + ) + parser.add_argument( + "--outfile", type=Path, help="path to write to; default: based on input" + ) + parser.add_argument( + "model", + type=Path, + help="directory containing model file, or model file itself (*.pth, *.pt, *.bin, *.safetensors)", + ) + parser.add_argument( + "--vocabtype", + choices=["spm", "bpe"], + help="vocab format (default: spm)", + default="spm", + ) + parser.add_argument( + "--ctx", type=int, help="model training context (default: based on input)" + ) + parser.add_argument( + "--concurrency", + type=int, + help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", + default=DEFAULT_CONCURRENCY, + ) + parser.add_argument( + "--bigendian", + action="store_true", + help="model is executed on big endian machine", + ) args = parser.parse_args(args_in) - + if args.dump_single: model_plus = lazy_load_file(args.model) do_dump_model(model_plus) @@ -1150,7 +1431,9 @@ def main(args_in: list[str] | None = None) -> None: if not args.vocab_only: model_plus = load_some_model(args.model) else: - model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None) + model_plus = ModelPlus( + model={}, paths=[args.model / "dummy"], format="none", vocab=None + ) if args.dump: do_dump_model(model_plus) @@ -1162,10 +1445,12 @@ def main(args_in: list[str] | None = None) -> None: params = Params.load(model_plus) if params.n_ctx == -1: if args.ctx is None: - raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n" - "Please specify one with --ctx:\n" - " - LLaMA v1: --ctx 2048\n" - " - LLaMA v2: --ctx 4096\n") + raise Exception( + "The model doesn't have a context size, and you didn't specify one with --ctx\n" + "Please specify one with --ctx:\n" + " - LLaMA v1: --ctx 2048\n" + " - LLaMA v2: --ctx 4096\n" + ) params.n_ctx = args.ctx if args.outtype: @@ -1183,9 +1468,11 @@ def main(args_in: list[str] | None = None) -> None: raise ValueError("need --outfile if using --vocab-only") # FIXME: Try to respect vocab_dir somehow? vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype) - special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, - load_merges = args.vocabtype == 'bpe', - n_vocab = vocab.vocab_size) + special_vocab = gguf.SpecialVocab( + model_plus.paths[0].parent, + load_merges=args.vocabtype == "bpe", + n_vocab=vocab.vocab_size, + ) outfile = args.outfile OutputFile.write_vocab_only(outfile, params, vocab, special_vocab) print(f"Wrote {outfile}") @@ -1197,22 +1484,33 @@ def main(args_in: list[str] | None = None) -> None: vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent vocab = load_vocab(vocab_dir, args.vocabtype) # FIXME: Try to respect vocab_dir somehow? - special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, - load_merges = args.vocabtype == 'bpe', - n_vocab = vocab.vocab_size) - - model = model_plus.model - model = convert_model_names(model, params) - ftype = pick_output_type(model, args.outtype) - model = convert_to_output_type(model, ftype) + special_vocab = gguf.SpecialVocab( + model_plus.paths[0].parent, + load_merges=args.vocabtype == "bpe", + n_vocab=vocab.vocab_size, + ) + + model = model_plus.model + model = convert_model_names(model, params) + ftype = pick_output_type(model, args.outtype) + model = convert_to_output_type(model, ftype) outfile = args.outfile or default_outfile(model_plus.paths, ftype) params.ftype = ftype print(f"Writing {outfile}, format {ftype}") - OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency, endianess=endianess) + OutputFile.write_all( + outfile, + ftype, + params, + model, + vocab, + special_vocab, + concurrency=args.concurrency, + endianess=endianess, + ) print(f"Wrote {outfile}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/convert-hf-to-powerinfer-gguf.py b/convert-hf-to-powerinfer-gguf.py index 0aa4632e..9a4280bf 100644 --- a/convert-hf-to-powerinfer-gguf.py +++ b/convert-hf-to-powerinfer-gguf.py @@ -90,7 +90,10 @@ def __init__( self.hparams = Model.load_hparams(self.dir_model) self.model_arch = self._get_model_architecture() self.gguf_writer = gguf.GGUFWriter( - fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file = False + fname_out, + gguf.MODEL_ARCH_NAMES[self.model_arch], + endianess=self.endianess, + use_temp_file=False, ) def set_vocab(self): @@ -517,6 +520,7 @@ def write_tensors(self): self.gguf_writer.add_tensor(new_name, data) + class OptModel(Model): def set_gguf_parameters(self, params: PredictorParams): self.gguf_writer.add_name("opt") @@ -527,20 +531,20 @@ def set_gguf_parameters(self, params: PredictorParams): self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) # self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) self.gguf_writer.add_file_type(self.ftype) - + if params.sparse_threshold is not None: self.gguf_writer.add_sparse_threshold(params.sparse_threshold) def write_tensors(self): for name, data_torch in self.get_tensors(): old_dtype = data_torch.dtype - + # convert any unsupported data types to float32 if data_torch.dtype not in (torch.float16, torch.float32): data_torch = data_torch.to(torch.float32) - + data = data_torch.squeeze().numpy() - + # map tensor names new_name = self._translate_tensor_key(name) if new_name is None: @@ -552,8 +556,8 @@ def write_tensors(self): if "ffn_down" in new_name: new_name = new_name.replace("ffn_down", "ffn_down_t") data = data.T - - n_dims = len(data.shape) + + n_dims = len(data.shape) data_dtype = data.dtype # if f32 desired, convert any float16 to float32 @@ -570,11 +574,12 @@ def write_tensors(self): and n_dims == 2 ): data = data.astype(np.float16) - + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") - + self.gguf_writer.add_tensor(new_name, data) + @dataclass class PredictorParams: sparse_threshold: float | None = None @@ -583,12 +588,12 @@ class PredictorParams: def loadPredictorJson(config_path: Path) -> PredictorParams: config = json.load(open(config_path)) return PredictorParams( - sparse_threshold = config.get("sparse_threshold"), + sparse_threshold=config.get("sparse_threshold"), ) @staticmethod def load(model_instance: Model) -> PredictorParams: - config_path = model_instance.dir_mlp_pred / "config.json" + config_path = model_instance.dir_mlp_pred / "config.json" if config_path.exists(): params = PredictorParams.loadPredictorJson(config_path) @@ -597,6 +602,7 @@ def load(model_instance: Model) -> PredictorParams: return params + ###### CONVERSION LOGIC ###### diff --git a/convert.py b/convert.py index 9103f661..4350de83 100755 --- a/convert.py +++ b/convert.py @@ -29,23 +29,25 @@ from sentencepiece import SentencePieceProcessor import os -if 'NO_LOCAL_GGUF' not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) + +if "NO_LOCAL_GGUF" not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / "gguf-py")) import gguf if TYPE_CHECKING: from typing import TypeAlias -if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'): +if hasattr(faulthandler, "register") and hasattr(signal, "SIGUSR1"): faulthandler.register(signal.SIGUSR1) -NDArray: TypeAlias = 'np.ndarray[Any, Any]' +NDArray: TypeAlias = "np.ndarray[Any, Any]" DEFAULT_CONCURRENCY = 8 # # data types # + @dataclass(frozen=True) class DataType: name: str @@ -55,14 +57,23 @@ class DataType: def elements_to_bytes(self, n_elements: int) -> int: return n_elements * self.dtype.itemsize + @dataclass(frozen=True) class UnquantizedDataType(DataType): pass -DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0']) -DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0']) -DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = []) -DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0']) + +DT_F16 = UnquantizedDataType( + "F16", dtype=np.dtype(np.float16), valid_conversions=["F32", "Q8_0"] +) +DT_F32 = UnquantizedDataType( + "F32", dtype=np.dtype(np.float32), valid_conversions=["F16", "Q8_0"] +) +DT_I32 = UnquantizedDataType("I32", dtype=np.dtype(np.int16), valid_conversions=[]) +DT_BF16 = UnquantizedDataType( + "BF16", dtype=np.dtype(np.uint16), valid_conversions=["F32", "F16", "Q8_0"] +) + @dataclass(frozen=True) class QuantizedDataType(DataType): @@ -71,54 +82,69 @@ class QuantizedDataType(DataType): ggml_type: gguf.GGMLQuantizationType def quantize(self, arr: NDArray) -> NDArray: - raise NotImplementedError(f'Quantization for {self.name} not implemented') + raise NotImplementedError(f"Quantization for {self.name} not implemented") def elements_to_bytes(self, n_elements: int) -> int: - assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}' + assert ( + n_elements % self.block_size == 0 + ), f"Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}" return self.quantized_dtype.itemsize * (n_elements // self.block_size) + @dataclass(frozen=True) class Q8_0QuantizedDataType(QuantizedDataType): # Mini Q8_0 quantization in Python! def quantize(self, arr: NDArray) -> NDArray: - assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}' - assert arr.dtype == np.float32, f'Bad array type {arr.dtype}' + assert ( + arr.size % self.block_size == 0 and arr.size != 0 + ), f"Bad array size {arr.size}" + assert arr.dtype == np.float32, f"Bad array type {arr.dtype}" n_blocks = arr.size // self.block_size blocks = arr.reshape((n_blocks, self.block_size)) + # Much faster implementation of block quantization contributed by @Cebtenzzre def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[tuple[Any, Any]]: - d = abs(blocks).max(axis = 1) / np.float32(127) - with np.errstate(divide = 'ignore'): + d = abs(blocks).max(axis=1) / np.float32(127) + with np.errstate(divide="ignore"): qs = (blocks / d[:, None]).round() qs[d == 0] = 0 yield from zip(d, qs) - return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype) -DT_Q8_0 = Q8_0QuantizedDataType('Q8_0', - dtype = np.dtype(np.float32), valid_conversions = [], - ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32, - quantized_dtype = np.dtype([('d', ' DataType: @@ -128,9 +154,10 @@ def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType: # 1D tensors are always F32. return dt if len(tensor.shape) > 1 else DT_F32 + GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = { - GGMLFileType.AllF32 : DT_F32, - GGMLFileType.MostlyF16 : DT_F16, + GGMLFileType.AllF32: DT_F32, + GGMLFileType.MostlyF16: DT_F16, GGMLFileType.MostlyQ8_0: DT_Q8_0, } @@ -138,6 +165,7 @@ def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType: # hparams loading # + @dataclass class PredictorParams: sparse_threshold: float | None = None @@ -146,12 +174,12 @@ class PredictorParams: def loadPredictorJson(model: LazyModel, config_path: Path) -> PredictorParams: config = json.load(open(config_path)) return PredictorParams( - sparse_threshold = config.get("sparse_threshold"), + sparse_threshold=config.get("sparse_threshold"), ) @staticmethod def load(model_plus: ModelPlus) -> PredictorParams: - config_path = model_plus.paths[0].parent / "config.json" + config_path = model_plus.paths[0].parent / "config.json" if config_path.exists(): params = PredictorParams.loadPredictorJson(model_plus.model, config_path) @@ -160,18 +188,19 @@ def load(model_plus: ModelPlus) -> PredictorParams: return params + @dataclass class Params: - n_vocab: int - n_embd: int - n_layer: int - n_ctx: int - n_ff: int - n_head: int - n_head_kv: int + n_vocab: int + n_embd: int + n_layer: int + n_ctx: int + n_ff: int + n_head: int + n_head_kv: int f_norm_eps: float - arch: gguf.MODEL_ARCH = gguf.MODEL_ARCH.LLAMA + arch: gguf.MODEL_ARCH = gguf.MODEL_ARCH.LLAMA rope_scaling_type: gguf.RopeScalingType | None = None f_rope_freq_base: float | None = None f_rope_scale: float | None = None @@ -184,41 +213,63 @@ class Params: path_model: Path | None = None # MLP predictor parameters - predictor_params: PredictorParams = dataclasses.field(default_factory=PredictorParams) + predictor_params: PredictorParams = dataclasses.field( + default_factory=PredictorParams + ) @staticmethod def guessed(model: LazyModel) -> Params: # try transformer naming first - n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape + n_vocab, n_embd = ( + model["model.embed_tokens.weight"].shape + if "model.embed_tokens.weight" in model + else model["tok_embeddings.weight"].shape + ) # try transformer naming first if "model.layers.0.self_attn.q_proj.weight" in model: - n_layer=next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model) - elif "model.layers.0.self_attn.W_pack.weight" in model: # next: try baichuan naming - n_layer=next(i for i in itertools.count() if f"model.layers.{i}.self_attn.W_pack.weight" not in model) + n_layer = next( + i + for i in itertools.count() + if f"model.layers.{i}.self_attn.q_proj.weight" not in model + ) + elif ( + "model.layers.0.self_attn.W_pack.weight" in model + ): # next: try baichuan naming + n_layer = next( + i + for i in itertools.count() + if f"model.layers.{i}.self_attn.W_pack.weight" not in model + ) else: - n_layer=next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model) + n_layer = next( + i + for i in itertools.count() + if f"layers.{i}.attention.wq.weight" not in model + ) if n_layer < 1: - raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n" - "Suggestion: provide 'config.json' of the model in the same directory containing model files.") + raise Exception( + "failed to guess 'n_layer'. This model is unknown or unsupported.\n" + "Suggestion: provide 'config.json' of the model in the same directory containing model files." + ) - n_head = n_embd // 128 # guessed - n_mult = 256 # guessed + n_head = n_embd // 128 # guessed + n_mult = 256 # guessed # TODO: verify this n_ff = int(2 * (4 * n_embd) / 3) n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult) return Params( - n_vocab = n_vocab, - n_embd = n_embd, - n_layer = n_layer, - n_ctx = -1, - n_ff = n_ff, - n_head = n_head, - n_head_kv = n_head, - f_norm_eps = 1e-5, + n_vocab=n_vocab, + n_embd=n_embd, + n_layer=n_layer, + n_ctx=-1, + n_ff=n_ff, + n_head=n_head, + n_head_kv=n_head, + f_norm_eps=1e-5, ) @staticmethod @@ -235,33 +286,35 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: rope_scaling_type = gguf.RopeScalingType.LINEAR elif typ == "yarn": rope_scaling_type = gguf.RopeScalingType.YARN - n_orig_ctx = rope_scaling['original_max_position_embeddings'] - rope_finetuned = rope_scaling['finetuned'] + n_orig_ctx = rope_scaling["original_max_position_embeddings"] + rope_finetuned = rope_scaling["finetuned"] else: - raise NotImplementedError(f'Unknown rope scaling type: {typ}') + raise NotImplementedError(f"Unknown rope scaling type: {typ}") if "max_sequence_length" in config: n_ctx = config["max_sequence_length"] elif "max_position_embeddings" in config: n_ctx = config["max_position_embeddings"] else: - raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n" - "Suggestion: provide 'config.json' of the model in the same directory containing model files.") + raise Exception( + "failed to guess 'n_ctx'. This model is unknown or unsupported.\n" + "Suggestion: provide 'config.json' of the model in the same directory containing model files." + ) params = Params( - n_vocab = config["vocab_size"], - n_embd = config["hidden_size"], - n_layer = config["num_hidden_layers"], - n_ctx = n_ctx, - n_ff = config["intermediate_size"], - n_head = (n_head := config["num_attention_heads"]), - n_head_kv = config.get("num_key_value_heads", n_head), - f_norm_eps = config["rms_norm_eps"], - f_rope_freq_base = config.get("rope_theta"), - rope_scaling_type = rope_scaling_type, - f_rope_scale = f_rope_scale, - n_orig_ctx = n_orig_ctx, - rope_finetuned = rope_finetuned, + n_vocab=config["vocab_size"], + n_embd=config["hidden_size"], + n_layer=config["num_hidden_layers"], + n_ctx=n_ctx, + n_ff=config["intermediate_size"], + n_head=(n_head := config["num_attention_heads"]), + n_head_kv=config.get("num_key_value_heads", n_head), + f_norm_eps=config["rms_norm_eps"], + f_rope_freq_base=config.get("rope_theta"), + rope_scaling_type=rope_scaling_type, + f_rope_scale=f_rope_scale, + n_orig_ctx=n_orig_ctx, + rope_finetuned=rope_finetuned, ) if config.get("model_type", None) == "bamboo": @@ -287,30 +340,30 @@ def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: n_ctx = 2048 return Params( - n_vocab = config.get("vocab_size", model["tok_embeddings.weight"].shape[0]), - n_embd = config["dim"], - n_layer = config["n_layers"], - n_ctx = n_ctx, - n_ff = model["layers.0.feed_forward.w1.weight"].shape[0], - n_head = (n_head := config["n_heads"]), - n_head_kv = config.get("n_kv_heads", n_head), - f_norm_eps = config["norm_eps"], - f_rope_freq_base = config.get("rope_theta"), + n_vocab=config.get("vocab_size", model["tok_embeddings.weight"].shape[0]), + n_embd=config["dim"], + n_layer=config["n_layers"], + n_ctx=n_ctx, + n_ff=model["layers.0.feed_forward.w1.weight"].shape[0], + n_head=(n_head := config["n_heads"]), + n_head_kv=config.get("n_kv_heads", n_head), + f_norm_eps=config["norm_eps"], + f_rope_freq_base=config.get("rope_theta"), ) @staticmethod def load(model_plus: ModelPlus) -> Params: - hf_config_path = model_plus.paths[0].parent / "config.json" + hf_config_path = model_plus.paths[0].parent / "config.json" orig_config_path = model_plus.paths[0].parent / "params.json" if hf_config_path.exists(): params = Params.loadHFTransformerJson(model_plus.model, hf_config_path) elif orig_config_path.exists(): params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path) - elif model_plus.format != 'none': + elif model_plus.format != "none": params = Params.guessed(model_plus.model) else: - raise ValueError('Cannot guess params when model format is none') + raise ValueError("Cannot guess params when model format is none") params.path_model = model_plus.paths[0].parent @@ -321,43 +374,50 @@ def load(model_plus: ModelPlus) -> Params: # vocab # + class BpeVocab: def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None: - self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read()) + self.bpe_tokenizer = json.loads( + open(str(fname_tokenizer), encoding="utf-8").read() + ) added_tokens: dict[str, int] if fname_added_tokens is not None: # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab. added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) else: # Fall back to trying to find the added tokens in tokenizer.json - tokenizer_json_file = fname_tokenizer.parent / 'tokenizer.json' + tokenizer_json_file = fname_tokenizer.parent / "tokenizer.json" if not tokenizer_json_file.is_file(): added_tokens = {} else: tokenizer_json = json.load(open(tokenizer_json_file, encoding="utf-8")) added_tokens = dict( - (item['content'], item['id']) - for item in tokenizer_json.get('added_tokens', []) + (item["content"], item["id"]) + for item in tokenizer_json.get("added_tokens", []) # Added tokens here can be duplicates of the main vocabulary. - if item['content'] not in self.bpe_tokenizer ) + if item["content"] not in self.bpe_tokenizer + ) vocab_size: int = len(self.bpe_tokenizer) - expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) - actual_ids = sorted(added_tokens.values()) + expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) + actual_ids = sorted(added_tokens.values()) if expected_ids != actual_ids: expected_end_id = vocab_size + len(actual_ids) - 1 - raise Exception(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}") + raise Exception( + f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range {vocab_size} - {expected_end_id}; got {actual_ids}" + ) items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) - self.added_tokens_list = [text for (text, idx) in items] + self.added_tokens_list = [text for (text, idx) in items] self.vocab_size_base: int = vocab_size - self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list) - self.fname_tokenizer = fname_tokenizer - self.fname_added_tokens = fname_added_tokens + self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer + self.fname_added_tokens = fname_added_tokens def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: tokenizer = self.bpe_tokenizer from transformers.models.gpt2 import tokenization_gpt2 + reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.items()} for i, _ in enumerate(tokenizer): @@ -387,18 +447,22 @@ def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> No vocab_size: int = self.sentencepiece_tokenizer.vocab_size() - new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size} + new_tokens = { + id: piece for piece, id in added_tokens.items() if id >= vocab_size + } expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens))) - actual_new_ids = sorted(new_tokens.keys()) + actual_new_ids = sorted(new_tokens.keys()) if expected_new_ids != actual_new_ids: - raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}") + raise ValueError( + f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}" + ) # Token pieces that were added to the base vocabulary. - self.added_tokens_list = [new_tokens[id] for id in actual_new_ids] - self.vocab_size_base = vocab_size - self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) - self.fname_tokenizer = fname_tokenizer + self.added_tokens_list = [new_tokens[id] for id in actual_new_ids] + self.vocab_size_base = vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer self.fname_added_tokens = fname_added_tokens def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: @@ -437,20 +501,24 @@ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: def __repr__(self) -> str: return f"" -Vocab: TypeAlias = 'BpeVocab | SentencePieceVocab' + +Vocab: TypeAlias = "BpeVocab | SentencePieceVocab" # # data loading # TODO: reuse (probably move to gguf.py?) # + def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray: - #print( "permute debug " + str(weights.shape[0]) + " x " + str(weights.shape[1]) + " nhead " + str(n_head) + " nheadkv " + str(n_kv_head) ) + # print( "permute debug " + str(weights.shape[0]) + " x " + str(weights.shape[1]) + " nhead " + str(n_head) + " nheadkv " + str(n_kv_head) ) if n_head_kv is not None and n_head != n_head_kv: n_head = n_head_kv - return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape)) + return ( + weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape) + ) class Tensor(metaclass=ABCMeta): @@ -461,7 +529,9 @@ def astype(self, data_type: DataType) -> Tensor: ... @abstractmethod def permute(self, n_head: int, n_head_kv: int) -> Tensor: ... @abstractmethod - def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: ... + def permute_part( + self, n_part: int, n_head: int, n_head_kv: int + ) -> UnquantizedTensor: ... @abstractmethod def part(self, n_part: int) -> UnquantizedTensor: ... @abstractmethod @@ -469,7 +539,9 @@ def to_ggml(self) -> GGMLCompatibleTensor: ... def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray: - assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}" + assert ( + bf16_arr.dtype == np.uint16 + ), f"Input array should be of dtype uint16, but got {bf16_arr.dtype}" fp32_arr = bf16_arr.astype(np.uint32) << 16 return fp32_arr.view(np.float32) @@ -489,9 +561,13 @@ def astype(self, data_type: DataType) -> Tensor: def to_ggml(self) -> UnquantizedTensor: return self - def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: + def permute_part( + self, n_part: int, n_head: int, n_head_kv: int + ) -> UnquantizedTensor: r = self.ndarray.shape[0] // 3 - return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv)) + return UnquantizedTensor( + permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv) + ) def part(self, n_part: int) -> UnquantizedTensor: r = self.ndarray.shape[0] // 3 @@ -501,7 +577,9 @@ def permute(self, n_head: int, n_head_kv: int) -> UnquantizedTensor: return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv)) -def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False) -> NDArray: +def load_unquantized( + lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False +) -> NDArray: tensor = lazy_tensor.load() assert isinstance(tensor, UnquantizedTensor) @@ -512,7 +590,9 @@ def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, conver if convert: tensor.ndarray = tensor.ndarray.astype(expected_dtype) else: - raise ValueError(f'expected this tensor to have dtype {expected_dtype}, got {tensor.ndarray.dtype}') + raise ValueError( + f"expected this tensor to have dtype {expected_dtype}, got {tensor.ndarray.dtype}" + ) return tensor.ndarray @@ -530,8 +610,9 @@ class LazyTensor: def load(self) -> Tensor: ret = self._load() # Should be okay if it maps to the same numpy type? - assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \ - (self.data_type, ret.data_type, self.description) + assert ret.data_type == self.data_type or ( + self.data_type.dtype == ret.data_type.dtype + ), (self.data_type, ret.data_type, self.description) return ret def astype(self, data_type: DataType) -> LazyTensor: @@ -539,29 +620,40 @@ def astype(self, data_type: DataType) -> LazyTensor: def load() -> Tensor: return self.load().astype(data_type) - return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}') - + + return LazyTensor( + load, self.shape, data_type, f"convert({data_type}) {self.description}" + ) + def transposed(self) -> LazyTensor: def load() -> Tensor: loaded = self.load() - assert isinstance(loaded, UnquantizedTensor), f'Cannot transpose {loaded}' + assert isinstance(loaded, UnquantizedTensor), f"Cannot transpose {loaded}" loaded.ndarray = loaded.ndarray.T return loaded - return LazyTensor(load, self.shape[::-1], self.data_type, f'transpose {self.description}') + + return LazyTensor( + load, self.shape[::-1], self.data_type, f"transpose {self.description}" + ) def validate_conversion_to(self, data_type: DataType) -> None: - if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions: - raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.') + if ( + data_type != self.data_type + and data_type.name not in self.data_type.valid_conversions + ): + raise ValueError( + f"Cannot validate conversion from {self.data_type} to {data_type}." + ) -LazyModel: TypeAlias = 'dict[str, LazyTensor]' +LazyModel: TypeAlias = "dict[str, LazyTensor]" @dataclass class ModelPlus: model: LazyModel paths: list[Path] # Where this was read from. - format: Literal['ggml', 'torch', 'safetensors', 'none'] + format: Literal["ggml", "torch", "safetensors", "none"] vocab: Vocab | None # For GGML models (which have vocab built in), the vocab. @@ -579,9 +671,11 @@ def convert(name: str) -> LazyTensor: if len(lazy_tensors[0].shape) == 1: # the tensor is just duplicated in every file return lazy_tensors[0] - if name.startswith('tok_embeddings.') or \ - name.endswith('.attention.wo.weight') or \ - name.endswith('.feed_forward.w2.weight'): + if ( + name.startswith("tok_embeddings.") + or name.endswith(".attention.wo.weight") + or name.endswith(".feed_forward.w2.weight") + ): # split by columns axis = 1 else: @@ -594,8 +688,16 @@ def load() -> UnquantizedTensor: ndarrays = [load_unquantized(tensor) for tensor in lazy_tensors] concatenated: NDArray = np.concatenate(ndarrays, axis=axis) return UnquantizedTensor(concatenated) - description = 'concatenated[[' + '] | ['.join(lt.description for lt in lazy_tensors) + ']]' - return LazyTensor(load, concatenated_shape, lazy_tensors[0].data_type, description) + + description = ( + "concatenated[[" + + "] | [".join(lt.description for lt in lazy_tensors) + + "]]" + ) + return LazyTensor( + load, concatenated_shape, lazy_tensors[0].data_type, description + ) + return {name: convert(name) for name in names} @@ -610,8 +712,9 @@ def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus: except StopIteration: vocab = None - if any("model.embed_tokens.weight" in mp.model for mp in models_plus) or \ - any("model.layers.0.fc1.weight" in mp.model for mp in models_plus): + if any("model.embed_tokens.weight" in mp.model for mp in models_plus) or any( + "model.layers.0.fc1.weight" in mp.model for mp in models_plus + ): # Transformers models put different tensors in different files, but # don't split indivdual tensors between files. model: LazyModel = {} @@ -626,21 +729,38 @@ def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus: def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor: def load() -> Tensor: return lazy_tensor.load().permute(n_head, n_head_kv) - return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) -def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int) -> LazyTensor: + return LazyTensor( + load, + lazy_tensor.shape, + lazy_tensor.data_type, + f"permute({n_head}, {n_head_kv}) " + lazy_tensor.description, + ) + + +def permute_part_lazy( + lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int +) -> LazyTensor: def load() -> Tensor: return lazy_tensor.load().permute_part(n_part, n_head, n_head_kv) + s = lazy_tensor.shape.copy() s[0] = s[0] // 3 - return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) + return LazyTensor( + load, + s, + lazy_tensor.data_type, + f"permute({n_head}, {n_head_kv}) " + lazy_tensor.description, + ) + def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor: def load() -> Tensor: return lazy_tensor.load().part(n_part) + s = lazy_tensor.shape.copy() s[0] = s[0] // 3 - return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description) + return LazyTensor(load, s, lazy_tensor.data_type, "part " + lazy_tensor.description) # Functionality that simulates `torch.load` but where individual tensors are @@ -670,11 +790,11 @@ def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile self.zip_file = zip_file def persistent_load(self, pid: Any) -> Any: - assert pid[0] == 'storage' + assert pid[0] == "storage" assert isinstance(pid[1], LazyStorageKind) data_type = pid[1].data_type filename_stem = pid[2] - filename = f'{self.data_base_path}/{filename_stem}' + filename = f"{self.data_base_path}/{filename_stem}" info = self.zip_file.getinfo(filename) def load(offset: int, elm_count: int) -> NDArray: @@ -685,18 +805,31 @@ def load(offset: int, elm_count: int) -> NDArray: data = fp.read(size) assert len(data) == size return np.frombuffer(data, dtype) - description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}' + + description = f"storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}" return LazyStorage(load=load, kind=pid[1], description=description) @staticmethod - def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any, - requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor: + def lazy_rebuild_tensor_v2( + storage: Any, + storage_offset: Any, + size: Any, + stride: Any, + requires_grad: Any, + backward_hooks: Any, + metadata: Any = None, + ) -> LazyTensor: assert isinstance(storage, LazyStorage) def load() -> UnquantizedTensor: elm_count = stride[0] * size[0] - return UnquantizedTensor(storage.load(storage_offset, elm_count).reshape(size)) - description = f'pickled storage_offset={storage_offset} in {storage.description}' + return UnquantizedTensor( + storage.load(storage_offset, elm_count).reshape(size) + ) + + description = ( + f"pickled storage_offset={storage_offset} in {storage.description}" + ) return LazyTensor(load, list(size), storage.kind.data_type, description) @staticmethod @@ -706,56 +839,68 @@ def rebuild_from_type_v2(func, new_type, args, state): CLASSES: dict[tuple[str, str], Any] = { # getattr used here as a workaround for mypy not being smart enough to detrmine # the staticmethods have a __func__ attribute. - ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'), - ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'), - ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16), - ('torch', 'HalfStorage'): LazyStorageKind(DT_F16), - ('torch', 'FloatStorage'): LazyStorageKind(DT_F32), - ('torch', 'IntStorage'): LazyStorageKind(DT_I32), - ('torch', 'Tensor'): LazyTensor, + ("torch._tensor", "_rebuild_from_type_v2"): getattr( + rebuild_from_type_v2, "__func__" + ), + ("torch._utils", "_rebuild_tensor_v2"): getattr( + lazy_rebuild_tensor_v2, "__func__" + ), + ("torch", "BFloat16Storage"): LazyStorageKind(DT_BF16), + ("torch", "HalfStorage"): LazyStorageKind(DT_F16), + ("torch", "FloatStorage"): LazyStorageKind(DT_F32), + ("torch", "IntStorage"): LazyStorageKind(DT_I32), + ("torch", "Tensor"): LazyTensor, } def find_class(self, module: str, name: str) -> Any: - if not module.startswith('torch'): + if not module.startswith("torch"): return super().find_class(module, name) return self.CLASSES[(module, name)] def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus: zf = zipfile.ZipFile(outer_fp) - pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')] + pickle_paths = [name for name in zf.namelist() if name.endswith(".pkl")] assert len(pickle_paths) == 1, pickle_paths - pickle_fp = zf.open(pickle_paths[0], 'r') - unpickler = LazyUnpickler(pickle_fp, - data_base_path=pickle_paths[0][:-4], - zip_file=zf) + pickle_fp = zf.open(pickle_paths[0], "r") + unpickler = LazyUnpickler( + pickle_fp, data_base_path=pickle_paths[0][:-4], zip_file=zf + ) model = unpickler.load() as_dict = dict(model.items()) - return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None) + return ModelPlus(model=as_dict, paths=[path], format="torch", vocab=None) def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: - header_size, = struct.unpack(' LazyTensor: - data_type = SAFETENSORS_DATA_TYPES[info['dtype']] + data_type = SAFETENSORS_DATA_TYPES[info["dtype"]] numpy_dtype = data_type.dtype - shape: list[int] = info['shape'] - begin, end = info['data_offsets'] + shape: list[int] = info["shape"] + begin, end = info["data_offsets"] assert 0 <= begin <= end <= len(byte_buf) assert end - begin == math.prod(shape) * numpy_dtype.itemsize buf = byte_buf[begin:end] def load() -> UnquantizedTensor: - return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)) - description = f'safetensors begin={begin} end={end} type={data_type} path={path}' + return UnquantizedTensor( + np.frombuffer(buf, dtype=numpy_dtype).reshape(shape) + ) + + description = ( + f"safetensors begin={begin} end={end} type={data_type} path={path}" + ) return LazyTensor(load, shape, data_type, description) - model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'} - return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None) + + model = { + name: convert(info) for (name, info) in header.items() if name != "__metadata__" + } + return ModelPlus(model=model, paths=[path], format="safetensors", vocab=None) def must_read(fp: IO[bytes], length: int) -> bytes: @@ -767,27 +912,34 @@ def must_read(fp: IO[bytes], length: int) -> bytes: @functools.lru_cache(maxsize=None) def lazy_load_file(path: Path) -> ModelPlus: - fp = open(path, 'rb') + fp = open(path, "rb") first8 = fp.read(8) fp.seek(0) - if first8[:2] == b'PK': + if first8[:2] == b"PK": # A zip file, i.e. PyTorch format return lazy_load_torch_file(fp, path) - elif struct.unpack(' Iterable[Out]: - '''Parallel map, but with backpressure. If the caller doesn't call `next` + +def bounded_parallel_map( + func: Callable[[In], Out], + iterable: Iterable[In], + concurrency: int, + max_workers: int | None = None, + use_processpool_executor: bool = False, +) -> Iterable[Out]: + """Parallel map, but with backpressure. If the caller doesn't call `next` fast enough, this will stop calling `func` at some point rather than letting results pile up in memory. Specifically, there is a max of one - output value buffered per thread.''' + output value buffered per thread.""" if concurrency < 2: yield from map(func, iterable) # Not reached. @@ -797,7 +949,7 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc executor_class = ProcessPoolExecutor else: executor_class = ThreadPoolExecutor - with executor_class(max_workers = max_workers) as executor: + with executor_class(max_workers=max_workers) as executor: futures: list[concurrent.futures.Future[Out]] = [] done = False for _ in range(concurrency): @@ -817,11 +969,14 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc break yield result + def check_vocab_size(params: Params, vocab: Vocab) -> None: if params.n_vocab != vocab.vocab_size: assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab) if params.n_vocab == vocab.vocab_size_base: - print("Ignoring added_tokens.json since model matches vocab size without it.") + print( + "Ignoring added_tokens.json since model matches vocab size without it." + ) vocab.added_tokens_list = [] vocab.vocab_size = vocab.vocab_size_base return @@ -829,14 +984,24 @@ def check_vocab_size(params: Params, vocab: Vocab) -> None: if vocab.fname_added_tokens is not None: msg += f" combined with {vocab.fname_added_tokens}" msg += f" has {vocab.vocab_size})." - if vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20 and vocab.fname_added_tokens is None: + if ( + vocab.vocab_size < params.n_vocab < vocab.vocab_size + 20 + and vocab.fname_added_tokens is None + ): msg += f" Most likely you are missing added_tokens.json (should be in {vocab.fname_tokenizer.parent})." raise Exception(msg) class OutputFile: - def __init__(self, fname_out: Path, arch: gguf.MODEL_ARCH, endianess:gguf.GGUFEndian=gguf.GGUFEndian.LITTLE) -> None: - self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[arch], endianess=endianess) + def __init__( + self, + fname_out: Path, + arch: gguf.MODEL_ARCH, + endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, + ) -> None: + self.gguf = gguf.GGUFWriter( + fname_out, gguf.MODEL_ARCH_NAMES[arch], endianess=endianess + ) def add_meta_arch(self, params: Params) -> None: name = "LLaMA" @@ -845,17 +1010,17 @@ def add_meta_arch(self, params: Params) -> None: if params.n_ctx == 4096: name = "LLaMA v2" elif params.path_model is not None: - name = str(params.path_model).split('/')[-1] + name = str(params.path_model).split("/")[-1] - self.gguf.add_name (name) - self.gguf.add_context_length (params.n_ctx) - self.gguf.add_embedding_length (params.n_embd) - self.gguf.add_block_count (params.n_layer) - self.gguf.add_feed_forward_length (params.n_ff) + self.gguf.add_name(name) + self.gguf.add_context_length(params.n_ctx) + self.gguf.add_embedding_length(params.n_embd) + self.gguf.add_block_count(params.n_layer) + self.gguf.add_feed_forward_length(params.n_ff) self.gguf.add_rope_dimension_count(params.n_embd // params.n_head) - self.gguf.add_head_count (params.n_head) - self.gguf.add_head_count_kv (params.n_head_kv) - self.gguf.add_layer_norm_rms_eps (params.f_norm_eps) + self.gguf.add_head_count(params.n_head) + self.gguf.add_head_count_kv(params.n_head_kv) + self.gguf.add_layer_norm_rms_eps(params.f_norm_eps) if params.f_rope_freq_base is not None: self.gguf.add_rope_freq_base(params.f_rope_freq_base) @@ -892,7 +1057,7 @@ def add_meta_vocab(self, vocab: Vocab) -> None: elif isinstance(vocab, BpeVocab): self.gguf.add_tokenizer_model("gpt2") else: - raise ValueError('Unknown vocab type: Not BpeVocab or SentencePieceVocab') + raise ValueError("Unknown vocab type: Not BpeVocab or SentencePieceVocab") self.gguf.add_token_list(tokens) self.gguf.add_token_scores(scores) self.gguf.add_token_types(toktypes) @@ -902,10 +1067,14 @@ def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None: def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: n_elements = int(np.prod(tensor.shape)) - raw_dtype = getattr(tensor.data_type, 'ggml_type', None) - data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype + raw_dtype = getattr(tensor.data_type, "ggml_type", None) + data_type = ( + getattr(tensor.data_type, "quantized_type", None) or tensor.data_type.dtype + ) data_nbytes = tensor.data_type.elements_to_bytes(n_elements) - self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype = raw_dtype) + self.gguf.add_tensor_info( + name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype + ) def write_meta(self) -> None: self.gguf.write_header_to_file() @@ -918,7 +1087,13 @@ def close(self) -> None: self.gguf.close() @staticmethod - def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, endianess:gguf.GGUFEndian=gguf.GGUFEndian.LITTLE) -> None: + def write_vocab_only( + fname_out: Path, + params: Params, + vocab: Vocab, + svocab: gguf.SpecialVocab, + endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, + ) -> None: check_vocab_size(params, vocab) of = OutputFile(fname_out, params.arch, endianess=endianess) @@ -946,7 +1121,16 @@ def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray: return dt.quantize(arr) @staticmethod - def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None: + def write_all( + fname_out: Path, + ftype: GGMLFileType, + params: Params, + model: LazyModel, + vocab: Vocab, + svocab: gguf.SpecialVocab, + concurrency: int = DEFAULT_CONCURRENCY, + endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, + ) -> None: check_vocab_size(params, vocab) of = OutputFile(fname_out, params.arch, endianess=endianess) @@ -964,43 +1148,68 @@ def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyM of.write_tensor_info() # tensor data - ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency) + ndarrays_inner = bounded_parallel_map( + OutputFile.do_item, model.items(), concurrency=concurrency + ) if ftype == GGMLFileType.MostlyQ8_0: - ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, use_processpool_executor = True) + ndarrays = bounded_parallel_map( + OutputFile.maybe_do_quantize, + ndarrays_inner, + concurrency=concurrency, + max_workers=concurrency, + use_processpool_executor=True, + ) else: ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) start = time.time() - for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): + for i, ((name, lazy_tensor), ndarray) in enumerate( + zip(model.items(), ndarrays) + ): elapsed = time.time() - start - size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) + size = " x ".join(f"{dim:6d}" for dim in lazy_tensor.shape) padi = len(str(len(model))) - print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}") + print( + f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}" + ) of.gguf.write_tensor_data(ndarray) of.close() + def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType: - wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type + wq_type = model[ + gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight" + ].data_type if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32): return GGMLFileType.AllF32 - if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)): + if output_type_str == "f16" or ( + output_type_str is None and wq_type in (DT_F16, DT_BF16) + ): return GGMLFileType.MostlyF16 if output_type_str == "q8_0": return GGMLFileType.MostlyQ8_0 - name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()} + name_to_type = { + name: lazy_tensor.data_type for (name, lazy_tensor) in model.items() + } raise Exception(f"Unexpected combination of types: {name_to_type}") + def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel: - return {name: tensor.astype(output_type.type_for_tensor(name, tensor)) - for (name, tensor) in model.items()} + return { + name: tensor.astype(output_type.type_for_tensor(name, tensor)) + for (name, tensor) in model.items() + } + def convert_model_names(model: LazyModel, params: Params) -> LazyModel: tmap = gguf.TensorNameMap(params.arch, params.n_layer) - should_skip: set[gguf.MODEL_TENSOR] = set(gguf.MODEL_TENSOR_SKIP.get(params.arch, [])) + should_skip: set[gguf.MODEL_TENSOR] = set( + gguf.MODEL_TENSOR_SKIP.get(params.arch, []) + ) tmp = model @@ -1008,21 +1217,43 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel: for i in itertools.count(): if f"model.layers.{i}.self_attn.q_proj.weight" in model: print(f"Permuting layer {i}") - tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head, params.n_head) - tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv) - #tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"] + tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_lazy( + model[f"model.layers.{i}.self_attn.q_proj.weight"], + params.n_head, + params.n_head, + ) + tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_lazy( + model[f"model.layers.{i}.self_attn.k_proj.weight"], + params.n_head, + params.n_head_kv, + ) + # tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"] elif f"model.layers.{i}.self_attn.W_pack.weight" in model: print(f"Unpacking and permuting layer {i}") - tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head) - tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv) - tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2) + tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy( + model[f"model.layers.{i}.self_attn.W_pack.weight"], + 0, + params.n_head, + params.n_head, + ) + tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy( + model[f"model.layers.{i}.self_attn.W_pack.weight"], + 1, + params.n_head, + params.n_head_kv, + ) + tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy( + model[f"model.layers.{i}.self_attn.W_pack.weight"], 2 + ) del tmp[f"model.layers.{i}.self_attn.W_pack.weight"] else: break out: LazyModel = {} for name, lazy_tensor in model.items(): - tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None) + tensor_type, name_new = tmap.get_type_and_name( + name, try_suffixes=(".weight", ".bias") + ) or (None, None) if name_new is None: raise Exception(f"Unexpected tensor name: {name}") @@ -1030,37 +1261,41 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel: print(f"skipping tensor {name_new}") continue - print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}") + print( + f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}" + ) out[name_new] = lazy_tensor return out + def postprocess_transpose(model: LazyModel) -> LazyModel: """Transpose ffn_down matrices for Axpy ops.""" out: LazyModel = {} - + for name, lazy_tensor in model.items(): if name.endswith(".ffn_down.weight"): out[name.replace("ffn_down", "ffn_down_t")] = lazy_tensor.transposed() else: out[name] = lazy_tensor - + return out + def nth_multifile_path(path: Path, n: int) -> Path | None: - '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + """Given any path belonging to a multi-file model (e.g. foo.bin.1), return the nth path in the model. - ''' + """ # Support the following patterns: patterns: list[tuple[str, str]] = [ # - x.00.pth, x.01.pth, etc. - (r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'), + (r"\.[0-9]{2}\.pth$", f".{n:02}.pth"), # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc. - (r'-[0-9]{5}-of-(.*)$', fr'-{n:05}-of-\1'), + (r"-[0-9]{5}-of-(.*)$", rf"-{n:05}-of-\1"), # x.bin, x.bin.1, etc. - (r'(\.[0-9]+)?$', r'\1' if n == 0 else fr'\1.{n}'), + (r"(\.[0-9]+)?$", r"\1" if n == 0 else rf"\1.{n}"), # x_0.pt, x_1.pt, etc. - (r'(_[0-9]+)?\.pt$', fr'_{n}.pt'), + (r"(_[0-9]+)?\.pt$", rf"_{n}.pt"), ] for regex, replacement in patterns: if re.search(regex, path.name): @@ -1071,9 +1306,9 @@ def nth_multifile_path(path: Path, n: int) -> Path | None: def find_multifile_paths(path: Path) -> list[Path]: - '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + """Given any path belonging to a multi-file model (e.g. foo.bin.1), return the whole list of paths in the model. - ''' + """ ret: list[Path] = [] for i in itertools.count(): nth_path = nth_multifile_path(path, i) @@ -1089,7 +1324,7 @@ def find_multifile_paths(path: Path) -> list[Path]: def load_some_model(path: Path) -> ModelPlus: - '''Load a model of any supported format.''' + """Load a model of any supported format.""" # Be extra-friendly and accept either a file or a directory: if path.is_dir(): # Check if it's a set of safetensors files first @@ -1097,12 +1332,19 @@ def load_some_model(path: Path) -> ModelPlus: files = [file for glob in globs for file in path.glob(glob)] if not files: # Try the PyTorch patterns too, with lower priority - globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"] + globs = [ + "consolidated.00.pth", + "pytorch_model-00001-of-*.bin", + "*.pt", + "pytorch_model.bin", + ] files = [file for glob in globs for file in path.glob(glob)] if not files: raise Exception(f"Can't find model in directory {path}") if len(files) > 1: - raise Exception(f"Found multiple models in {path}, not sure which to pick: {files}") + raise Exception( + f"Found multiple models in {path}, not sure which to pick: {files}" + ) path = files[0] paths = find_multifile_paths(path) @@ -1114,21 +1356,27 @@ def load_some_model(path: Path) -> ModelPlus: model_plus = merge_multifile_models(models_plus) return model_plus + def load_predictor_model(path: Path) -> ModelPlus: - '''Load MLP models for sparse FFN inference from directory.''' + """Load MLP models for sparse FFN inference from directory.""" assert path.is_dir(), f"MLP model path {path} is not a directory" - + first_model_path = path / "model_0.pt" - assert first_model_path.resolve(), f"MLP model path {path} does not contain model_0.pt" + assert ( + first_model_path.resolve() + ), f"MLP model path {path} does not contain model_0.pt" model_paths = find_multifile_paths(first_model_path) models_plus: list[ModelPlus] = [] for model_path in model_paths: # find number in model_path - model_layer = int(re.search(r'model_(\d+).pt', str(model_path)).group(1)) + model_layer = int(re.search(r"model_(\d+).pt", str(model_path)).group(1)) print(f"Loading MLP model file {model_path}") mlp_model = lazy_load_file(model_path) - mlp_model.model = {f"model.layers.{model_layer}.{name}": tensor for name, tensor in mlp_model.model.items()} + mlp_model.model = { + f"model.layers.{model_layer}.{name}": tensor + for name, tensor in mlp_model.model.items() + } models_plus.append(mlp_model) return merge_multifile_models(models_plus) @@ -1140,7 +1388,7 @@ def load_vocab(path: Path, vocabtype: str | None) -> Vocab: # be in the parent of that. if path.is_dir(): vocab_file = "tokenizer.model" - if vocabtype == 'bpe': + if vocabtype == "bpe": vocab_file = "vocab.json" path2 = path / vocab_file # Use `.parent` instead of /.. to handle the symlink case better. @@ -1152,7 +1400,8 @@ def load_vocab(path: Path, vocabtype: str | None) -> Vocab: else: raise FileNotFoundError( f"Could not find {vocab_file} in {path} or its parent; " - "if it's in another directory, pass the directory as --vocab-dir") + "if it's in another directory, pass the directory as --vocab-dir" + ) print(f"Loading vocab file '{path}', type '{vocabtype}'") @@ -1160,22 +1409,25 @@ def load_vocab(path: Path, vocabtype: str | None) -> Vocab: if vocabtype == "bpe": return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None) elif vocabtype == "spm": - return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None) + return SentencePieceVocab( + path, added_tokens_path if added_tokens_path.exists() else None + ) else: raise ValueError(f"Unsupported vocabulary type {vocabtype}") def default_outfile(model_paths: list[Path], file_type: GGMLFileType) -> Path: namestr = { - GGMLFileType.AllF32: "f32", + GGMLFileType.AllF32: "f32", GGMLFileType.MostlyF16: "f16", - GGMLFileType.MostlyQ8_0:"q8_0", + GGMLFileType.MostlyQ8_0: "q8_0", }[file_type] ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf" if ret in model_paths: sys.stderr.write( f"Error: Default output path ({ret}) would overwrite the input. " - "Please explicitly specify a path using --outfile.\n") + "Please explicitly specify a path using --outfile.\n" + ) sys.exit(1) return ret @@ -1185,7 +1437,9 @@ def do_dump_model(model_plus: ModelPlus) -> None: print(f"model_plus.format = {model_plus.format!r}") print(f"model_plus.vocab = {model_plus.vocab!r}") for name, lazy_tensor in model_plus.model.items(): - print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}") + print( + f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}" + ) def main(args_in: list[str] | None = None) -> None: @@ -1193,19 +1447,64 @@ def main(args_in: list[str] | None = None) -> None: if np.uint32(1) == np.uint32(1).newbyteorder("<"): # We currently only support Q8_0 output on little endian systems. output_choices.append("q8_0") - parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file") - parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") - parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") - parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") - parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)", default="f16") - parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") - parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") - parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") - parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY) - parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine") - parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm") - parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin, *.safetensors)") - parser.add_argument("sparse_predictor", type=Path, help="predictors for sparse FFN inference") + parser = argparse.ArgumentParser( + description="Convert a LLaMa model to a GGML compatible file" + ) + parser.add_argument( + "--dump", + action="store_true", + help="don't convert, just show what's in the model", + ) + parser.add_argument( + "--dump-single", + action="store_true", + help="don't convert, just show what's in a single model file", + ) + parser.add_argument( + "--vocab-only", action="store_true", help="extract only the vocab" + ) + parser.add_argument( + "--outtype", + choices=output_choices, + help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)", + default="f16", + ) + parser.add_argument( + "--vocab-dir", + type=Path, + help="directory containing tokenizer.model, if separate from model file", + ) + parser.add_argument( + "--outfile", type=Path, help="path to write to; default: based on input" + ) + parser.add_argument( + "--ctx", type=int, help="model training context (default: based on input)" + ) + parser.add_argument( + "--concurrency", + type=int, + help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", + default=DEFAULT_CONCURRENCY, + ) + parser.add_argument( + "--bigendian", + action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "--vocabtype", + choices=["spm", "bpe"], + help="vocab format (default: spm)", + default="spm", + ) + parser.add_argument( + "model", + type=Path, + help="directory containing model file, or model file itself (*.pth, *.pt, *.bin, *.safetensors)", + ) + parser.add_argument( + "sparse_predictor", type=Path, help="predictors for sparse FFN inference" + ) args = parser.parse_args(args_in) @@ -1214,12 +1513,19 @@ def main(args_in: list[str] | None = None) -> None: hf_config = json.load(f) if model_type := hf_config.get("model_type") not in ("llama", "bamboo"): # invoke another script to convert other models - print(f"Model architecture {model_type} is not supported by this `convert.py`. Trying with `convert-hf-to-powerinfer-gguf.py`...") - script_path = Path(__file__).resolve().parent / "convert-hf-to-powerinfer-gguf.py" + print( + f"Model architecture {model_type} is not supported by this `convert.py`. Trying with `convert-hf-to-powerinfer-gguf.py`..." + ) + script_path = ( + Path(__file__).resolve().parent / "convert-hf-to-powerinfer-gguf.py" + ) subprocess.run(["python3", str(script_path.absolute())] + sys.argv[1:]) return except FileNotFoundError: - print("Could not find config.json under the original model directory. ", file=sys.stderr) + print( + "Could not find config.json under the original model directory. ", + file=sys.stderr, + ) sys.exit(1) if args.dump_single: @@ -1234,7 +1540,9 @@ def main(args_in: list[str] | None = None) -> None: params.predictor_params = PredictorParams.load(mlp_predictor_plus) model_plus = merge_multifile_models([model_plus, mlp_predictor_plus]) else: - model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None) + model_plus = ModelPlus( + model={}, paths=[args.model / "dummy"], format="none", vocab=None + ) params = Params.load(model_plus) if args.dump: @@ -1246,10 +1554,12 @@ def main(args_in: list[str] | None = None) -> None: if params.n_ctx == -1: if args.ctx is None: - raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n" - "Please specify one with --ctx:\n" - " - LLaMA v1: --ctx 2048\n" - " - LLaMA v2: --ctx 4096\n") + raise Exception( + "The model doesn't have a context size, and you didn't specify one with --ctx\n" + "Please specify one with --ctx:\n" + " - LLaMA v1: --ctx 2048\n" + " - LLaMA v2: --ctx 4096\n" + ) params.n_ctx = args.ctx if args.outtype: @@ -1267,9 +1577,11 @@ def main(args_in: list[str] | None = None) -> None: raise ValueError("need --outfile if using --vocab-only") # FIXME: Try to respect vocab_dir somehow? vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype) - special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, - load_merges = args.vocabtype == 'bpe', - n_vocab = vocab.vocab_size) + special_vocab = gguf.SpecialVocab( + model_plus.paths[0].parent, + load_merges=args.vocabtype == "bpe", + n_vocab=vocab.vocab_size, + ) outfile = args.outfile OutputFile.write_vocab_only(outfile, params, vocab, special_vocab) print(f"Wrote {outfile}") @@ -1281,21 +1593,32 @@ def main(args_in: list[str] | None = None) -> None: vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent vocab = load_vocab(vocab_dir, args.vocabtype) # FIXME: Try to respect vocab_dir somehow? - special_vocab = gguf.SpecialVocab(model_plus.paths[0].parent, - load_merges = args.vocabtype == 'bpe', - n_vocab = vocab.vocab_size) - - model = model_plus.model - model = convert_model_names(model, params) - model = postprocess_transpose(model) - ftype = pick_output_type(model, args.outtype) - model = convert_to_output_type(model, ftype) + special_vocab = gguf.SpecialVocab( + model_plus.paths[0].parent, + load_merges=args.vocabtype == "bpe", + n_vocab=vocab.vocab_size, + ) + + model = model_plus.model + model = convert_model_names(model, params) + model = postprocess_transpose(model) + ftype = pick_output_type(model, args.outtype) + model = convert_to_output_type(model, ftype) outfile = args.outfile or default_outfile(model_plus.paths, ftype) params.ftype = ftype print(f"Writing {outfile}, format {ftype}") - OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency, endianess=endianess) + OutputFile.write_all( + outfile, + ftype, + params, + model, + vocab, + special_vocab, + concurrency=args.concurrency, + endianess=endianess, + ) print(f"Wrote {outfile}") # post-process: write another unique file header to distinguish from the origianl GGUF file @@ -1304,5 +1627,5 @@ def main(args_in: list[str] | None = None) -> None: fout.write(struct.pack(" 0 else []) - - self.lbfgs_x = Tensor('f', [self.nx]) - self.lbfgs_xp = Tensor('f', [self.nx]) - self.lbfgs_g = Tensor('f', [self.nx]) - self.lbfgs_gp = Tensor('f', [self.nx]) - self.lbfgs_d = Tensor('f', [self.nx]) - self.lbfgs_pf = Tensor('f', [self.past] if self.past > 0 else []) - self.lbfgs_lmal = Tensor('f', [self.lbfgs_m]) - self.lbfgs_lmys = Tensor('f', [self.lbfgs_m]) - self.lbfgs_lms = Tensor('f', [self.nx, self.lbfgs_m]) - self.lbfgs_lmy = Tensor('f', [self.nx, self.lbfgs_m]) + raise ValueError( + "Invalid version of optimization context in checkpoint file" + ) + + self.past = struct.unpack(" 0 else []) + + self.lbfgs_x = Tensor("f", [self.nx]) + self.lbfgs_xp = Tensor("f", [self.nx]) + self.lbfgs_g = Tensor("f", [self.nx]) + self.lbfgs_gp = Tensor("f", [self.nx]) + self.lbfgs_d = Tensor("f", [self.nx]) + self.lbfgs_pf = Tensor("f", [self.past] if self.past > 0 else []) + self.lbfgs_lmal = Tensor("f", [self.lbfgs_m]) + self.lbfgs_lmys = Tensor("f", [self.lbfgs_m]) + self.lbfgs_lms = Tensor("f", [self.nx, self.lbfgs_m]) + self.lbfgs_lmy = Tensor("f", [self.nx, self.lbfgs_m]) # forgot to save type in version 1: # guess self.type from number of remaining bytes - size_type_0 = 12 + sum([t.max_storage_size() for t in - [self.adam_m, self.adam_v] - +([self.adam_pf] if (self.past > 0) else [])]) - size_type_1 = 24 + sum([t.max_storage_size() for t in - [self.lbfgs_x, self.lbfgs_xp, self.lbfgs_g, - self.lbfgs_gp, self.lbfgs_d, self.lbfgs_pf, - self.lbfgs_lmal, self.lbfgs_lmys, - self.lbfgs_lms, self.lbfgs_lmy] - +([self.lbfgs_pf] if (self.past > 0) else [])]) + size_type_0 = 12 + sum( + [ + t.max_storage_size() + for t in [self.adam_m, self.adam_v] + + ([self.adam_pf] if (self.past > 0) else []) + ] + ) + size_type_1 = 24 + sum( + [ + t.max_storage_size() + for t in [ + self.lbfgs_x, + self.lbfgs_xp, + self.lbfgs_g, + self.lbfgs_gp, + self.lbfgs_d, + self.lbfgs_pf, + self.lbfgs_lmal, + self.lbfgs_lmys, + self.lbfgs_lms, + self.lbfgs_lmy, + ] + + ([self.lbfgs_pf] if (self.past > 0) else []) + ] + ) # due to alignment padding the size might not by exact # but the difference in size for both types is significant, # so we can just use whichever is closest @@ -179,11 +216,16 @@ def load(self, data, offset): if self.type == 0: offset = self.adam_m.load(data, offset) offset = self.adam_v.load(data, offset) - offset = self.adam_pf.load(data,offset) + offset = self.adam_pf.load(data, offset) - self.adam_fx_best = struct.unpack(' 0: - self.adam_pf.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES) + self.adam_pf.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES + ) elif self.type == 1: gguf_writer.add_string(LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS) - gguf_writer.add_uint32(LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, self.lbfgs_m) - gguf_writer.add_float32(LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, self.lbfgs_fx_best) - gguf_writer.add_float32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, self.lbfgs_step) + gguf_writer.add_uint32( + LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, self.lbfgs_m + ) + gguf_writer.add_float32( + LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, self.lbfgs_fx_best + ) + gguf_writer.add_float32( + LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, self.lbfgs_step + ) gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, self.lbfgs_j) gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, self.lbfgs_k) - gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, self.lbfgs_end) - gguf_writer.add_uint32(LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, self.lbfgs_n_no_improvement) - - self.lbfgs_x.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS) - self.lbfgs_xp.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS) - self.lbfgs_g.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS) - self.lbfgs_gp.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS) - self.lbfgs_d.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION) + gguf_writer.add_int32( + LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, self.lbfgs_end + ) + gguf_writer.add_uint32( + LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, self.lbfgs_n_no_improvement + ) + + self.lbfgs_x.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS + ) + self.lbfgs_xp.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS + ) + self.lbfgs_g.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS + ) + self.lbfgs_gp.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS + ) + self.lbfgs_d.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION + ) if self.past > 0: - self.lbfgs_pf.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES) - self.lbfgs_lmal.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA) - self.lbfgs_lmys.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS) - self.lbfgs_lms.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S) - self.lbfgs_lmy.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y) + self.lbfgs_pf.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES + ) + self.lbfgs_lmal.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA + ) + self.lbfgs_lmys.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS + ) + self.lbfgs_lms.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S + ) + self.lbfgs_lmy.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y + ) else: - raise ValueError('Unknown optimizer type') + raise ValueError("Unknown optimizer type") + class LoraParams: def __init__(self): pass def load(self, data, offset): - self.n_rank_attention_norm = struct.unpack(' 0: rule += ' "," space' - rule += fr' {self._format_literal(prop_name)} space ":" space {prop_rule_name}' + rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}' rule += ' "}" space' return self._add_rule(rule_name, rule) - elif schema_type == 'array' and 'items' in schema: + elif schema_type == "array" and "items" in schema: # TODO `prefixItems` keyword - item_rule_name = self.visit(schema['items'], f'{name}{"-" if name else ""}item') - rule = f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? "]" space' + item_rule_name = self.visit( + schema["items"], f'{name}{"-" if name else ""}item' + ) + rule = ( + f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? "]" space' + ) return self._add_rule(rule_name, rule) else: - assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' + assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" return self._add_rule( - 'root' if rule_name == 'root' else schema_type, - PRIMITIVE_RULES[schema_type] + "root" if rule_name == "root" else schema_type, + PRIMITIVE_RULES[schema_type], ) def format_grammar(self): - return '\n'.join((f'{name} ::= {rule}' for name, rule in self._rules.items())) + return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items())) -def main(args_in = None): +def main(args_in=None): parser = argparse.ArgumentParser( - description=''' + description=""" Generates a grammar (suitable for use in ./main) that produces JSON conforming to a given JSON schema. Only a subset of JSON schema features are supported; more may be added in the future. - ''', + """, ) parser.add_argument( - '--prop-order', + "--prop-order", default=[], - type=lambda s: s.split(','), - help=''' + type=lambda s: s.split(","), + help=""" comma-separated property names defining the order of precedence for object properties; properties not specified here are given lower precedence than those that are, and are sorted alphabetically - ''' + """, ) - parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)') + parser.add_argument("schema", help='file containing JSON schema ("-" for stdin)') args = parser.parse_args(args_in) - schema = json.load(sys.stdin if args.schema == '-' else open(args.schema)) + schema = json.load(sys.stdin if args.schema == "-" else open(args.schema)) prop_order = {name: idx for idx, name in enumerate(args.prop_order)} converter = SchemaConverter(prop_order) - converter.visit(schema, '') + converter.visit(schema, "") print(converter.format_grammar()) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/examples/llava/convert-image-encoder-to-gguf.py b/examples/llava/convert-image-encoder-to-gguf.py index 2f5eef19..9dd4abdf 100644 --- a/examples/llava/convert-image-encoder-to-gguf.py +++ b/examples/llava/convert-image-encoder-to-gguf.py @@ -15,7 +15,9 @@ def k(raw_key: str, arch: str) -> str: return raw_key.format(arch=arch) -def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool: +def should_skip_tensor( + name: str, has_text: bool, has_vision: bool, has_llava: bool +) -> bool: if name in ( "logit_scale", "text_model.embeddings.position_ids", @@ -23,7 +25,11 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b ): return True - if has_llava and name in ["visual_projection.weight", "vision_model.post_layernorm.weight", "vision_model.post_layernorm.bias"]: + if has_llava and name in [ + "visual_projection.weight", + "vision_model.post_layernorm.weight", + "vision_model.post_layernorm.bias", + ]: return True if name.startswith("v") and not has_vision: @@ -42,7 +48,21 @@ def get_tensor_name(name: str) -> str: if "mm_projector" in name: return name.replace("model.mm_projector", "mm") - return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln") + return ( + name.replace("text_model", "t") + .replace("vision_model", "v") + .replace("encoder.layers", "blk") + .replace("embeddings.", "") + .replace("_proj", "") + .replace("self_attn.", "attn_") + .replace("layer_norm", "ln") + .replace("layernorm", "ln") + .replace("mlp.fc1", "ffn_down") + .replace("mlp.fc2", "ffn_up") + .replace("embedding", "embd") + .replace("final", "post") + .replace("layrnorm", "ln") + ) def bytes_to_unicode(): @@ -72,26 +92,61 @@ def bytes_to_unicode(): ap = argparse.ArgumentParser(prog="convert_hf_to_gguf.py") -ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True) -ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16") -ap.add_argument("--text-only", action="store_true", required=False, - help="Save a text-only model. It can't be used to encode images") -ap.add_argument("--vision-only", action="store_true", required=False, - help="Save a vision-only model. It can't be used to encode texts") -ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.") -ap.add_argument("--image-mean", nargs=3, type=float, required=False, help="Override image mean values") -ap.add_argument("--image-std", nargs=3, type=float, required=False, help="Override image std values") -ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None) +ap.add_argument( + "-m", + "--model-dir", + help="Path to model directory cloned from HF Hub", + required=True, +) +ap.add_argument( + "--use-f32", action="store_true", default=False, help="Use f32 instead of f16" +) +ap.add_argument( + "--text-only", + action="store_true", + required=False, + help="Save a text-only model. It can't be used to encode images", +) +ap.add_argument( + "--vision-only", + action="store_true", + required=False, + help="Save a vision-only model. It can't be used to encode texts", +) +ap.add_argument( + "--llava-projector", + help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.", +) +ap.add_argument( + "--image-mean", + nargs=3, + type=float, + required=False, + help="Override image mean values", +) +ap.add_argument( + "--image-std", nargs=3, type=float, required=False, help="Override image std values" +) +ap.add_argument( + "-o", + "--output-dir", + help="Directory to save GGUF files. Default is the original model directory", + default=None, +) args = ap.parse_args() if args.text_only and args.vision_only: - print("--text-only and --image-only arguments cannot be specified at the same time.") + print( + "--text-only and --image-only arguments cannot be specified at the same time." + ) exit(1) if args.use_f32: - print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.") + print( + "WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet." + ) # output in the same directory as the model if output_dir is None dir_model = args.model_dir @@ -148,7 +203,11 @@ def bytes_to_unicode(): fout.add_bool("clip.has_vision_encoder", has_vision_encoder) fout.add_bool("clip.has_llava_projector", has_llava_projector) fout.add_file_type(ftype) -model_name = config["_name_or_path"] if "_name_or_path" in config else os.path.basename(dir_model) +model_name = ( + config["_name_or_path"] + if "_name_or_path" in config + else os.path.basename(dir_model) +) fout.add_name(model_name) if args.text_only: fout.add_description("text-only CLIP model") @@ -164,7 +223,10 @@ def bytes_to_unicode(): fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"]) fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"]) fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"]) - fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"])) + fout.add_uint32( + "clip.text.projection_dim", + t_hparams.get("projection_dim", config["projection_dim"]), + ) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"]) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"]) fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"]) @@ -176,14 +238,33 @@ def bytes_to_unicode(): fout.add_uint32("clip.vision.patch_size", v_hparams["patch_size"]) fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"]) fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"]) - fout.add_uint32("clip.vision.projection_dim", v_hparams.get("projection_dim", config["projection_dim"])) - fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"]) - fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"]) - block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"] + fout.add_uint32( + "clip.vision.projection_dim", + v_hparams.get("projection_dim", config["projection_dim"]), + ) + fout.add_uint32( + k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"] + ) + fout.add_float32( + k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"] + ) + block_count = ( + v_hparams["num_hidden_layers"] - 1 + if has_llava_projector + else v_hparams["num_hidden_layers"] + ) fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count) - image_mean = processor.image_processor.image_mean if args.image_mean is None else args.image_mean - image_std = processor.image_processor.image_std if args.image_std is None else args.image_std + image_mean = ( + processor.image_processor.image_mean + if args.image_mean is None + else args.image_mean + ) + image_std = ( + processor.image_processor.image_std + if args.image_std is None + else args.image_std + ) fout.add_array("clip.vision.image_mean", image_mean) fout.add_array("clip.vision.image_std", image_std) @@ -207,7 +288,9 @@ def bytes_to_unicode(): state_dict = model.state_dict() for name, data in state_dict.items(): - if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector): + if should_skip_tensor( + name, has_text_encoder, has_vision_encoder, has_llava_projector + ): # we don't need this print(f"skipping parameter: {name}") continue diff --git a/examples/llava/llava-surgery.py b/examples/llava/llava-surgery.py index 515f6b58..580c56a8 100644 --- a/examples/llava/llava-surgery.py +++ b/examples/llava/llava-surgery.py @@ -26,7 +26,10 @@ # BakLLaVA models contain CLIP tensors in it clip_tensors = [k for k, v in checkpoint.items() if k.startswith("model.vision_tower")] if len(clip_tensors) > 0: - clip = {name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() for name in clip_tensors} + clip = { + name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() + for name in clip_tensors + } torch.save(clip, f"{args.model}/llava.clip") # remove these tensors diff --git a/examples/make-ggml.py b/examples/make-ggml.py index c73485eb..97d6bb92 100755 --- a/examples/make-ggml.py +++ b/examples/make-ggml.py @@ -37,24 +37,26 @@ - F32: absolutely huge, lossless - not recommended """ import subprocess + subprocess.run(f"pip install huggingface-hub==0.16.4", shell=True, check=True) import argparse import os from huggingface_hub import snapshot_download + def main(model, model_type, outname, outdir, quants, keep_fp16): if not os.path.isdir(model): print(f"Model not found at {model}. Downloading...") try: if outname is None: - outname = model.split('/')[-1] - model = snapshot_download(repo_id=model, cache_dir='../models/hf_cache') + outname = model.split("/")[-1] + model = snapshot_download(repo_id=model, cache_dir="../models/hf_cache") except Exception as e: raise Exception(f"Could not download the model: {e}") if outdir is None: - outdir = f'../models/{outname}' + outdir = f"../models/{outname}" if not os.path.isfile(f"{model}/config.json"): raise Exception(f"Could not find config.json in {model}") @@ -69,9 +71,17 @@ def main(model, model_type, outname, outdir, quants, keep_fp16): print(f"Making unquantised GGUF at {fp16}") if not os.path.isfile(fp16): if model_type != "llama": - subprocess.run(f"python3 ../convert-{model_type}-hf-to-gguf.py {model} 1 --outfile {fp16}", shell=True, check=True) + subprocess.run( + f"python3 ../convert-{model_type}-hf-to-gguf.py {model} 1 --outfile {fp16}", + shell=True, + check=True, + ) else: - subprocess.run(f"python3 ../convert.py {model} --outtype f16 --outfile {fp16}", shell=True, check=True) + subprocess.run( + f"python3 ../convert.py {model} --outtype f16 --outfile {fp16}", + shell=True, + check=True, + ) else: print(f"Unquantised GGML already exists at: {fp16}") @@ -84,15 +94,36 @@ def main(model, model_type, outname, outdir, quants, keep_fp16): if not keep_fp16: os.remove(fp16) + if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Convert/Quantize HF models to GGUF. If you have the HF model downloaded already, pass the path to the model dir. Otherwise, pass the Hugging Face model repo name. You need to be in the /examples folder for it to work.') - parser.add_argument('model', help='Downloaded model dir or Hugging Face model repo name') - parser.add_argument('--model_type', required=True, choices=['llama', 'starcoder', 'falcon', 'baichuan', 'gptneox'], help='Type of the model to be converted. Choose from llama, starcoder, falcon, baichuan, or gptneox.') - parser.add_argument('--outname', default=None, help='Output model(s) name') - parser.add_argument('--outdir', default=None, help='Output directory') - parser.add_argument('--quants', nargs='*', default=["Q4_K_M", "Q5_K_S"], help='Quant types') - parser.add_argument('--keep_fp16', action='store_true', help='Keep fp16 model', default=False) + parser = argparse.ArgumentParser( + description="Convert/Quantize HF models to GGUF. If you have the HF model downloaded already, pass the path to the model dir. Otherwise, pass the Hugging Face model repo name. You need to be in the /examples folder for it to work." + ) + parser.add_argument( + "model", help="Downloaded model dir or Hugging Face model repo name" + ) + parser.add_argument( + "--model_type", + required=True, + choices=["llama", "starcoder", "falcon", "baichuan", "gptneox"], + help="Type of the model to be converted. Choose from llama, starcoder, falcon, baichuan, or gptneox.", + ) + parser.add_argument("--outname", default=None, help="Output model(s) name") + parser.add_argument("--outdir", default=None, help="Output directory") + parser.add_argument( + "--quants", nargs="*", default=["Q4_K_M", "Q5_K_S"], help="Quant types" + ) + parser.add_argument( + "--keep_fp16", action="store_true", help="Keep fp16 model", default=False + ) args = parser.parse_args() - main(args.model, args.model_type, args.outname, args.outdir, args.quants, args.keep_fp16) + main( + args.model, + args.model_type, + args.outname, + args.outdir, + args.quants, + args.keep_fp16, + ) diff --git a/examples/server/api_like_OAI.py b/examples/server/api_like_OAI.py index 313e1a96..dc34395e 100755 --- a/examples/server/api_like_OAI.py +++ b/examples/server/api_like_OAI.py @@ -10,19 +10,64 @@ app = Flask(__name__) slot_id = -1 -parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.") -parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n') -parser.add_argument("--user-name", type=str, help="USER name in chat completions(default: '\\nUSER: ')", default="\\nUSER: ") -parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", default="\\nASSISTANT: ") -parser.add_argument("--system-name", type=str, help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", default="\\nASSISTANT's RULE: ") -parser.add_argument("--stop", type=str, help="the end of response in chat completions(default: '')", default="") -parser.add_argument("--llama-api", type=str, help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", default='http://127.0.0.1:8080') -parser.add_argument("--api-key", type=str, help="Set the api key to allow only few user(default: NULL)", default="") -parser.add_argument("--host", type=str, help="Set the ip address to listen.(default: 127.0.0.1)", default='127.0.0.1') -parser.add_argument("--port", type=int, help="Set the port to listen.(default: 8081)", default=8081) +parser = argparse.ArgumentParser( + description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp." +) +parser.add_argument( + "--chat-prompt", + type=str, + help="the top prompt in chat completions(default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", + default="A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n", +) +parser.add_argument( + "--user-name", + type=str, + help="USER name in chat completions(default: '\\nUSER: ')", + default="\\nUSER: ", +) +parser.add_argument( + "--ai-name", + type=str, + help="ASSISTANT name in chat completions(default: '\\nASSISTANT: ')", + default="\\nASSISTANT: ", +) +parser.add_argument( + "--system-name", + type=str, + help="SYSTEM name in chat completions(default: '\\nASSISTANT's RULE: ')", + default="\\nASSISTANT's RULE: ", +) +parser.add_argument( + "--stop", + type=str, + help="the end of response in chat completions(default: '')", + default="", +) +parser.add_argument( + "--llama-api", + type=str, + help="Set the address of server.cpp in llama.cpp(default: http://127.0.0.1:8080)", + default="http://127.0.0.1:8080", +) +parser.add_argument( + "--api-key", + type=str, + help="Set the api key to allow only few user(default: NULL)", + default="", +) +parser.add_argument( + "--host", + type=str, + help="Set the ip address to listen.(default: 127.0.0.1)", + default="127.0.0.1", +) +parser.add_argument( + "--port", type=int, help="Set the port to listen.(default: 8081)", default=8081 +) args = parser.parse_args() + def is_present(json, key): try: buf = json[key] @@ -32,7 +77,8 @@ def is_present(json, key): return False return True -#convert chat to prompt + +# convert chat to prompt def convert_chat(messages): prompt = "" + args.chat_prompt.replace("\\n", "\n") @@ -41,47 +87,64 @@ def convert_chat(messages): ai_n = args.ai_name.replace("\\n", "\n") stop = args.stop.replace("\\n", "\n") - for line in messages: - if (line["role"] == "system"): + if line["role"] == "system": prompt += f"{system_n}{line['content']}" - if (line["role"] == "user"): + if line["role"] == "user": prompt += f"{user_n}{line['content']}" - if (line["role"] == "assistant"): + if line["role"] == "assistant": prompt += f"{ai_n}{line['content']}{stop}" prompt += ai_n.rstrip() return prompt + def make_postData(body, chat=False, stream=False): postData = {} - if (chat): + if chat: postData["prompt"] = convert_chat(body["messages"]) else: postData["prompt"] = body["prompt"] - if(is_present(body, "temperature")): postData["temperature"] = body["temperature"] - if(is_present(body, "top_k")): postData["top_k"] = body["top_k"] - if(is_present(body, "top_p")): postData["top_p"] = body["top_p"] - if(is_present(body, "max_tokens")): postData["n_predict"] = body["max_tokens"] - if(is_present(body, "presence_penalty")): postData["presence_penalty"] = body["presence_penalty"] - if(is_present(body, "frequency_penalty")): postData["frequency_penalty"] = body["frequency_penalty"] - if(is_present(body, "repeat_penalty")): postData["repeat_penalty"] = body["repeat_penalty"] - if(is_present(body, "mirostat")): postData["mirostat"] = body["mirostat"] - if(is_present(body, "mirostat_tau")): postData["mirostat_tau"] = body["mirostat_tau"] - if(is_present(body, "mirostat_eta")): postData["mirostat_eta"] = body["mirostat_eta"] - if(is_present(body, "seed")): postData["seed"] = body["seed"] - if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()] - if (args.stop != ""): + if is_present(body, "temperature"): + postData["temperature"] = body["temperature"] + if is_present(body, "top_k"): + postData["top_k"] = body["top_k"] + if is_present(body, "top_p"): + postData["top_p"] = body["top_p"] + if is_present(body, "max_tokens"): + postData["n_predict"] = body["max_tokens"] + if is_present(body, "presence_penalty"): + postData["presence_penalty"] = body["presence_penalty"] + if is_present(body, "frequency_penalty"): + postData["frequency_penalty"] = body["frequency_penalty"] + if is_present(body, "repeat_penalty"): + postData["repeat_penalty"] = body["repeat_penalty"] + if is_present(body, "mirostat"): + postData["mirostat"] = body["mirostat"] + if is_present(body, "mirostat_tau"): + postData["mirostat_tau"] = body["mirostat_tau"] + if is_present(body, "mirostat_eta"): + postData["mirostat_eta"] = body["mirostat_eta"] + if is_present(body, "seed"): + postData["seed"] = body["seed"] + if is_present(body, "logit_bias"): + postData["logit_bias"] = [ + [int(token), body["logit_bias"][token]] + for token in body["logit_bias"].keys() + ] + if args.stop != "": postData["stop"] = [args.stop] else: postData["stop"] = [] - if(is_present(body, "stop")): postData["stop"] += body["stop"] + if is_present(body, "stop"): + postData["stop"] += body["stop"] postData["n_keep"] = -1 postData["stream"] = stream postData["cache_prompt"] = True postData["slot_id"] = slot_id return postData + def make_resData(data, chat=False, promptToken=[]): resData = { "id": "chatcmpl" if (chat) else "cmpl", @@ -92,132 +155,187 @@ def make_resData(data, chat=False, promptToken=[]): "usage": { "prompt_tokens": data["tokens_evaluated"], "completion_tokens": data["tokens_predicted"], - "total_tokens": data["tokens_evaluated"] + data["tokens_predicted"] - } + "total_tokens": data["tokens_evaluated"] + data["tokens_predicted"], + }, } - if (len(promptToken) != 0): + if len(promptToken) != 0: resData["promptToken"] = promptToken - if (chat): - #only one choice is supported - resData["choices"] = [{ - "index": 0, - "message": { - "role": "assistant", - "content": data["content"], - }, - "finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length" - }] + if chat: + # only one choice is supported + resData["choices"] = [ + { + "index": 0, + "message": { + "role": "assistant", + "content": data["content"], + }, + "finish_reason": ( + "stop" + if (data["stopped_eos"] or data["stopped_word"]) + else "length" + ), + } + ] else: - #only one choice is supported - resData["choices"] = [{ - "text": data["content"], - "index": 0, - "logprobs": None, - "finish_reason": "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length" - }] + # only one choice is supported + resData["choices"] = [ + { + "text": data["content"], + "index": 0, + "logprobs": None, + "finish_reason": ( + "stop" + if (data["stopped_eos"] or data["stopped_word"]) + else "length" + ), + } + ] return resData -def make_resData_stream(data, chat=False, time_now = 0, start=False): + +def make_resData_stream(data, chat=False, time_now=0, start=False): resData = { "id": "chatcmpl" if (chat) else "cmpl", "object": "chat.completion.chunk" if (chat) else "text_completion.chunk", "created": time_now, "model": "LLaMA_CPP", - "choices": [ - { - "finish_reason": None, - "index": 0 - } - ] + "choices": [{"finish_reason": None, "index": 0}], } slot_id = data["slot_id"] - if (chat): - if (start): - resData["choices"][0]["delta"] = { - "role": "assistant" - } + if chat: + if start: + resData["choices"][0]["delta"] = {"role": "assistant"} else: - resData["choices"][0]["delta"] = { - "content": data["content"] - } - if (data["stop"]): - resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length" + resData["choices"][0]["delta"] = {"content": data["content"]} + if data["stop"]: + resData["choices"][0]["finish_reason"] = ( + "stop" + if (data["stopped_eos"] or data["stopped_word"]) + else "length" + ) else: resData["choices"][0]["text"] = data["content"] - if (data["stop"]): - resData["choices"][0]["finish_reason"] = "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length" + if data["stop"]: + resData["choices"][0]["finish_reason"] = ( + "stop" if (data["stopped_eos"] or data["stopped_word"]) else "length" + ) return resData -@app.route('/chat/completions', methods=['POST']) -@app.route('/v1/chat/completions', methods=['POST']) +@app.route("/chat/completions", methods=["POST"]) +@app.route("/v1/chat/completions", methods=["POST"]) def chat_completions(): - if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key): + if ( + args.api_key != "" + and request.headers["Authorization"].split()[1] != args.api_key + ): return Response(status=403) body = request.get_json() stream = False tokenize = False - if(is_present(body, "stream")): stream = body["stream"] - if(is_present(body, "tokenize")): tokenize = body["tokenize"] + if is_present(body, "stream"): + stream = body["stream"] + if is_present(body, "tokenize"): + tokenize = body["tokenize"] postData = make_postData(body, chat=True, stream=stream) promptToken = [] - if (tokenize): - tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json() + if tokenize: + tokenData = requests.request( + "POST", + urllib.parse.urljoin(args.llama_api, "/tokenize"), + data=json.dumps({"content": postData["prompt"]}), + ).json() promptToken = tokenData["tokens"] - if (not stream): - data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData)) + if not stream: + data = requests.request( + "POST", + urllib.parse.urljoin(args.llama_api, "/completion"), + data=json.dumps(postData), + ) print(data.json()) resData = make_resData(data.json(), chat=True, promptToken=promptToken) return jsonify(resData) else: + def generate(): - data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True) + data = requests.request( + "POST", + urllib.parse.urljoin(args.llama_api, "/completion"), + data=json.dumps(postData), + stream=True, + ) time_now = int(time.time()) resData = make_resData_stream({}, chat=True, time_now=time_now, start=True) - yield 'data: {}\n'.format(json.dumps(resData)) + yield "data: {}\n".format(json.dumps(resData)) for line in data.iter_lines(): if line: - decoded_line = line.decode('utf-8') - resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now) - yield 'data: {}\n'.format(json.dumps(resData)) - return Response(generate(), mimetype='text/event-stream') + decoded_line = line.decode("utf-8") + resData = make_resData_stream( + json.loads(decoded_line[6:]), chat=True, time_now=time_now + ) + yield "data: {}\n".format(json.dumps(resData)) + + return Response(generate(), mimetype="text/event-stream") -@app.route('/completions', methods=['POST']) -@app.route('/v1/completions', methods=['POST']) +@app.route("/completions", methods=["POST"]) +@app.route("/v1/completions", methods=["POST"]) def completion(): - if (args.api_key != "" and request.headers["Authorization"].split()[1] != args.api_key): + if ( + args.api_key != "" + and request.headers["Authorization"].split()[1] != args.api_key + ): return Response(status=403) body = request.get_json() stream = False tokenize = False - if(is_present(body, "stream")): stream = body["stream"] - if(is_present(body, "tokenize")): tokenize = body["tokenize"] + if is_present(body, "stream"): + stream = body["stream"] + if is_present(body, "tokenize"): + tokenize = body["tokenize"] postData = make_postData(body, chat=False, stream=stream) promptToken = [] - if (tokenize): - tokenData = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/tokenize"), data=json.dumps({"content": postData["prompt"]})).json() + if tokenize: + tokenData = requests.request( + "POST", + urllib.parse.urljoin(args.llama_api, "/tokenize"), + data=json.dumps({"content": postData["prompt"]}), + ).json() promptToken = tokenData["tokens"] - if (not stream): - data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData)) + if not stream: + data = requests.request( + "POST", + urllib.parse.urljoin(args.llama_api, "/completion"), + data=json.dumps(postData), + ) print(data.json()) resData = make_resData(data.json(), chat=False, promptToken=promptToken) return jsonify(resData) else: + def generate(): - data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True) + data = requests.request( + "POST", + urllib.parse.urljoin(args.llama_api, "/completion"), + data=json.dumps(postData), + stream=True, + ) time_now = int(time.time()) for line in data.iter_lines(): if line: - decoded_line = line.decode('utf-8') - resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now) - yield 'data: {}\n'.format(json.dumps(resData)) - return Response(generate(), mimetype='text/event-stream') + decoded_line = line.decode("utf-8") + resData = make_resData_stream( + json.loads(decoded_line[6:]), chat=False, time_now=time_now + ) + yield "data: {}\n".format(json.dumps(resData)) + + return Response(generate(), mimetype="text/event-stream") + -if __name__ == '__main__': +if __name__ == "__main__": app.run(args.host, port=args.port) diff --git a/examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py b/examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py index ed93673b..5fd40058 100644 --- a/examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py +++ b/examples/train-text-from-scratch/convert-train-checkpoint-to-gguf.py @@ -8,61 +8,62 @@ import numpy as np from pathlib import Path -if 'NO_LOCAL_GGUF' not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / '..' / '..' / 'gguf-py')) +if "NO_LOCAL_GGUF" not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / ".." / ".." / "gguf-py")) import gguf # gguf constants LLM_KV_OPTIMIZER_TYPE = "optimizer.type" -LLM_KV_OPTIMIZER_TYPE_ADAM = "adam" +LLM_KV_OPTIMIZER_TYPE_ADAM = "adam" LLM_KV_OPTIMIZER_TYPE_LBFGS = "lbfgs" -LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version" -LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count" -LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count" -LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count" -LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized" -LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss" -LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss" -LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count" +LLM_KV_OPTIMIZER_FILE_VERSION = "optimizer.file_version" +LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergence_past_count" +LLM_KV_OPTIMIZER_PARAMETER_COUNT = "optimizer.parameter_count" +LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count" +LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized" +LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss" +LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss" +LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count" LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT = "optimizer.lbfgs.approx_hessian_count" -LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss" -LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step" -LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j" -LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k" -LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end" +LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS = "optimizer.lbfgs.best_loss" +LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP = "optimizer.lbfgs.line_search_step" +LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J = "optimizer.lbfgs.line_search_j" +LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K = "optimizer.lbfgs.line_search_k" +LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END = "optimizer.lbfgs.line_search_end" LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT = "optimizer.lbfgs.no_improvement_count" -LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments" -LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments" +LLM_TENSOR_OPTIMIZER_ADAM_FIRST_MOMENTS = "optimizer.adam.first_moments" +LLM_TENSOR_OPTIMIZER_ADAM_SECOND_MOMENTS = "optimizer.adam.second_moments" LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES = "optimizer.adam.past_loss_values" -LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters" +LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS = "optimizer.lbfgs.current_parameters" LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS = "optimizer.lbfgs.previous_parameters" -LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients" -LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients" -LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction" -LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values" -LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = "optimizer.lbfgs.memory_alpha" -LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys" -LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s" -LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y" - -LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model" +LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS = "optimizer.lbfgs.current_gradients" +LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS = "optimizer.lbfgs.previous_gradients" +LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION = "optimizer.lbfgs.search_direction" +LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES = "optimizer.lbfgs.past_loss_values" +LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA = "optimizer.lbfgs.memory_alpha" +LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS = "optimizer.lbfgs.memory_ys" +LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S = "optimizer.lbfgs.memory_s" +LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y = "optimizer.lbfgs.memory_y" + +LLM_KV_TRAINING_TYPE_TRAIN_MODEL = "train_model" LLM_KV_TRAINING_TYPE_FINETUNE_LORA = "finetune_lora" -LLM_KV_TRAINING_TYPE = "training.type" -LLM_KV_TRAINING_FILE_VERSION = "training.file_version" -LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count" -LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count" -LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count" +LLM_KV_TRAINING_TYPE = "training.type" +LLM_KV_TRAINING_FILE_VERSION = "training.file_version" +LLM_KV_TRAINING_ITERATION_COUNT = "training.iteration_count" +LLM_KV_TRAINING_SAMPLE_COUNT = "training.sample_count" +LLM_KV_TRAINING_TOKEN_COUNT = "training.token_count" + class Tensor: - def __init__(self, dtype='f', ne=None): + def __init__(self, dtype="f", ne=None): if ne is None: ne = [] self.dtype = dtype self.ne = ne self.nbytes = 0 - if self.dtype == 'f': + if self.dtype == "f": if len(self.ne) == 0: self.nbytes = 0 else: @@ -71,38 +72,45 @@ def __init__(self, dtype='f', ne=None): raise ValueError(f"Unhandled data type '{self.dtype}'") def load(self, data, offset): - nd = struct.unpack(' 0 else []) - - self.lbfgs_x = Tensor('f', [self.nx]) - self.lbfgs_xp = Tensor('f', [self.nx]) - self.lbfgs_g = Tensor('f', [self.nx]) - self.lbfgs_gp = Tensor('f', [self.nx]) - self.lbfgs_d = Tensor('f', [self.nx]) - self.lbfgs_pf = Tensor('f', [self.past] if self.past > 0 else []) - self.lbfgs_lmal = Tensor('f', [self.lbfgs_m]) - self.lbfgs_lmys = Tensor('f', [self.lbfgs_m]) - self.lbfgs_lms = Tensor('f', [self.nx, self.lbfgs_m]) - self.lbfgs_lmy = Tensor('f', [self.nx, self.lbfgs_m]) + self.adam_m = Tensor("f", [self.nx]) + self.adam_v = Tensor("f", [self.nx]) + self.adam_pf = Tensor("f", [self.past] if self.past > 0 else []) + + self.lbfgs_x = Tensor("f", [self.nx]) + self.lbfgs_xp = Tensor("f", [self.nx]) + self.lbfgs_g = Tensor("f", [self.nx]) + self.lbfgs_gp = Tensor("f", [self.nx]) + self.lbfgs_d = Tensor("f", [self.nx]) + self.lbfgs_pf = Tensor("f", [self.past] if self.past > 0 else []) + self.lbfgs_lmal = Tensor("f", [self.lbfgs_m]) + self.lbfgs_lmys = Tensor("f", [self.lbfgs_m]) + self.lbfgs_lms = Tensor("f", [self.nx, self.lbfgs_m]) + self.lbfgs_lmy = Tensor("f", [self.nx, self.lbfgs_m]) if self.type == 0: # these tensors are stored, but we don't need their data - x = Tensor('f', [self.nx]) - g = Tensor('f', [self.nx]) - g2 = Tensor('f', [self.nx]) - mh = Tensor('f', [self.nx]) - vh = Tensor('f', [self.nx]) + x = Tensor("f", [self.nx]) + g = Tensor("f", [self.nx]) + g2 = Tensor("f", [self.nx]) + mh = Tensor("f", [self.nx]) + vh = Tensor("f", [self.nx]) offset = x.load(data, offset) offset = g.load(data, offset) @@ -194,9 +240,18 @@ def load(self, data, offset): offset = vh.load(data, offset) offset = self.adam_pf.load(data, offset) - self.adam_fx_best = struct.unpack(' 0 else []) - - self.lbfgs_x = Tensor('f', [self.nx]) - self.lbfgs_xp = Tensor('f', [self.nx]) - self.lbfgs_g = Tensor('f', [self.nx]) - self.lbfgs_gp = Tensor('f', [self.nx]) - self.lbfgs_d = Tensor('f', [self.nx]) - self.lbfgs_pf = Tensor('f', [self.past] if self.past > 0 else []) - self.lbfgs_lmal = Tensor('f', [self.lbfgs_m]) - self.lbfgs_lmys = Tensor('f', [self.lbfgs_m]) - self.lbfgs_lms = Tensor('f', [self.nx, self.lbfgs_m]) - self.lbfgs_lmy = Tensor('f', [self.nx, self.lbfgs_m]) + self.past = struct.unpack(" 0 else []) + + self.lbfgs_x = Tensor("f", [self.nx]) + self.lbfgs_xp = Tensor("f", [self.nx]) + self.lbfgs_g = Tensor("f", [self.nx]) + self.lbfgs_gp = Tensor("f", [self.nx]) + self.lbfgs_d = Tensor("f", [self.nx]) + self.lbfgs_pf = Tensor("f", [self.past] if self.past > 0 else []) + self.lbfgs_lmal = Tensor("f", [self.lbfgs_m]) + self.lbfgs_lmys = Tensor("f", [self.lbfgs_m]) + self.lbfgs_lms = Tensor("f", [self.nx, self.lbfgs_m]) + self.lbfgs_lmy = Tensor("f", [self.nx, self.lbfgs_m]) # forgot to save type in version 1: # guess self.type from number of remaining bytes - size_type_0 = 12 + sum([t.max_storage_size() for t in - [self.adam_m, self.adam_v] - +([self.adam_pf] if (self.past > 0) else [])]) - size_type_1 = 24 + sum([t.max_storage_size() for t in - [self.lbfgs_x, self.lbfgs_xp, self.lbfgs_g, - self.lbfgs_gp, self.lbfgs_d, self.lbfgs_pf, - self.lbfgs_lmal, self.lbfgs_lmys, - self.lbfgs_lms, self.lbfgs_lmy] - +([self.lbfgs_pf] if (self.past > 0) else [])]) + size_type_0 = 12 + sum( + [ + t.max_storage_size() + for t in [self.adam_m, self.adam_v] + + ([self.adam_pf] if (self.past > 0) else []) + ] + ) + size_type_1 = 24 + sum( + [ + t.max_storage_size() + for t in [ + self.lbfgs_x, + self.lbfgs_xp, + self.lbfgs_g, + self.lbfgs_gp, + self.lbfgs_d, + self.lbfgs_pf, + self.lbfgs_lmal, + self.lbfgs_lmys, + self.lbfgs_lms, + self.lbfgs_lmy, + ] + + ([self.lbfgs_pf] if (self.past > 0) else []) + ] + ) # due to alignment padding the size might not by exact # but the difference in size for both types is significant, # so we can just use whichever is closest @@ -266,11 +357,20 @@ def load(self, data, offset): if self.type == 0: offset = self.adam_m.load(data, offset) offset = self.adam_v.load(data, offset) - offset = self.adam_pf.load(data,offset) + offset = self.adam_pf.load(data, offset) - self.adam_fx_best = struct.unpack(' 0: - self.adam_pf.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES) + self.adam_pf.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_ADAM_PAST_LOSS_VALUES + ) elif self.type == 1: gguf_writer.add_string(LLM_KV_OPTIMIZER_TYPE, LLM_KV_OPTIMIZER_TYPE_LBFGS) - gguf_writer.add_uint32(LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, self.lbfgs_m) - gguf_writer.add_float32(LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, self.lbfgs_fx_best) - gguf_writer.add_float32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, self.lbfgs_step) + gguf_writer.add_uint32( + LLM_KV_OPTIMIZER_LBFGS_APPROX_HESSIAN_COUNT, self.lbfgs_m + ) + gguf_writer.add_float32( + LLM_KV_OPTIMIZER_LBFGS_BEST_LOSS, self.lbfgs_fx_best + ) + gguf_writer.add_float32( + LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_STEP, self.lbfgs_step + ) gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_J, self.lbfgs_j) gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_K, self.lbfgs_k) - gguf_writer.add_int32(LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, self.lbfgs_end) - gguf_writer.add_uint32(LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, self.lbfgs_n_no_improvement) - - self.lbfgs_x.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS) - self.lbfgs_xp.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS) - self.lbfgs_g.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS) - self.lbfgs_gp.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS) - self.lbfgs_d.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION) + gguf_writer.add_int32( + LLM_KV_OPTIMIZER_LBFGS_LINE_SEARCH_END, self.lbfgs_end + ) + gguf_writer.add_uint32( + LLM_KV_OPTIMIZER_LBFGS_NO_IMPROVEMENT_COUNT, self.lbfgs_n_no_improvement + ) + + self.lbfgs_x.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_PARAMETERS + ) + self.lbfgs_xp.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_PARAMETERS + ) + self.lbfgs_g.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_CURRENT_GRADIENTS + ) + self.lbfgs_gp.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PREVIOUS_GRADIENTS + ) + self.lbfgs_d.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_SEARCH_DIRECTION + ) if self.past > 0: - self.lbfgs_pf.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES) - self.lbfgs_lmal.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA) - self.lbfgs_lmys.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS) - self.lbfgs_lms.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S) - self.lbfgs_lmy.save_gguf(gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y) + self.lbfgs_pf.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_PAST_LOSS_VALUES + ) + self.lbfgs_lmal.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_ALPHA + ) + self.lbfgs_lmys.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_YS + ) + self.lbfgs_lms.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_S + ) + self.lbfgs_lmy.save_gguf( + gguf_writer, name=LLM_TENSOR_OPTIMIZER_LBFGS_MEMORY_Y + ) else: - raise ValueError('Unknown optimizer type') + raise ValueError("Unknown optimizer type") + class ModelParams: def __init__(self): pass def load(self, data, offset): - self.n_vocab = struct.unpack(' None: gguf_writer.close() -if __name__ == '__main__': +if __name__ == "__main__": writer_example() diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 9459b477..dbe1d5be 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -8,8 +8,8 @@ # constants # -GGUF_MAGIC = 0x46554747 # "GGUF" -GGUF_VERSION = 3 +GGUF_MAGIC = 0x46554747 # "GGUF" +GGUF_VERSION = 3 GGUF_DEFAULT_ALIGNMENT = 32 # @@ -19,58 +19,58 @@ class Keys: class General: - ARCHITECTURE = "general.architecture" + ARCHITECTURE = "general.architecture" QUANTIZATION_VERSION = "general.quantization_version" - ALIGNMENT = "general.alignment" - NAME = "general.name" - AUTHOR = "general.author" - URL = "general.url" - DESCRIPTION = "general.description" - LICENSE = "general.license" - SOURCE_URL = "general.source.url" - SOURCE_HF_REPO = "general.source.huggingface.repository" - FILE_TYPE = "general.file_type" + ALIGNMENT = "general.alignment" + NAME = "general.name" + AUTHOR = "general.author" + URL = "general.url" + DESCRIPTION = "general.description" + LICENSE = "general.license" + SOURCE_URL = "general.source.url" + SOURCE_HF_REPO = "general.source.huggingface.repository" + FILE_TYPE = "general.file_type" class LLM: - CONTEXT_LENGTH = "{arch}.context_length" - EMBEDDING_LENGTH = "{arch}.embedding_length" - BLOCK_COUNT = "{arch}.block_count" - FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" + CONTEXT_LENGTH = "{arch}.context_length" + EMBEDDING_LENGTH = "{arch}.embedding_length" + BLOCK_COUNT = "{arch}.block_count" + FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" - TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" + TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" class Attention: - HEAD_COUNT = "{arch}.attention.head_count" - HEAD_COUNT_KV = "{arch}.attention.head_count_kv" - MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias" - CLAMP_KQV = "{arch}.attention.clamp_kqv" - LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" + HEAD_COUNT = "{arch}.attention.head_count" + HEAD_COUNT_KV = "{arch}.attention.head_count_kv" + MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias" + CLAMP_KQV = "{arch}.attention.clamp_kqv" + LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" class Rope: - DIMENSION_COUNT = "{arch}.rope.dimension_count" - FREQ_BASE = "{arch}.rope.freq_base" - SCALING_TYPE = "{arch}.rope.scaling.type" - SCALING_FACTOR = "{arch}.rope.scaling.factor" + DIMENSION_COUNT = "{arch}.rope.dimension_count" + FREQ_BASE = "{arch}.rope.freq_base" + SCALING_TYPE = "{arch}.rope.scaling.type" + SCALING_FACTOR = "{arch}.rope.scaling.factor" SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" - SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" class Tokenizer: - MODEL = "tokenizer.ggml.model" - LIST = "tokenizer.ggml.tokens" + MODEL = "tokenizer.ggml.model" + LIST = "tokenizer.ggml.tokens" TOKEN_TYPE = "tokenizer.ggml.token_type" - SCORES = "tokenizer.ggml.scores" - MERGES = "tokenizer.ggml.merges" - BOS_ID = "tokenizer.ggml.bos_token_id" - EOS_ID = "tokenizer.ggml.eos_token_id" - UNK_ID = "tokenizer.ggml.unknown_token_id" - SEP_ID = "tokenizer.ggml.seperator_token_id" - PAD_ID = "tokenizer.ggml.padding_token_id" - ADD_BOS = "tokenizer.ggml.add_bos_token" - ADD_EOS = "tokenizer.ggml.add_eos_token" - HF_JSON = "tokenizer.huggingface.json" - RWKV = "tokenizer.rwkv.world" - + SCORES = "tokenizer.ggml.scores" + MERGES = "tokenizer.ggml.merges" + BOS_ID = "tokenizer.ggml.bos_token_id" + EOS_ID = "tokenizer.ggml.eos_token_id" + UNK_ID = "tokenizer.ggml.unknown_token_id" + SEP_ID = "tokenizer.ggml.seperator_token_id" + PAD_ID = "tokenizer.ggml.padding_token_id" + ADD_BOS = "tokenizer.ggml.add_bos_token" + ADD_EOS = "tokenizer.ggml.add_eos_token" + HF_JSON = "tokenizer.huggingface.json" + RWKV = "tokenizer.rwkv.world" + class PowerInfer: SPARSE_THRESHOLD = "powerinfer.sparse_threshold" @@ -84,94 +84,93 @@ class Split: class MODEL_ARCH(IntEnum): - LLAMA = auto() - FALCON = auto() - BAICHUAN = auto() - GPT2 = auto() - GPTJ = auto() - GPTNEOX = auto() - OPT = auto() - MPT = auto() + LLAMA = auto() + FALCON = auto() + BAICHUAN = auto() + GPT2 = auto() + GPTJ = auto() + GPTNEOX = auto() + OPT = auto() + MPT = auto() STARCODER = auto() PERSIMMON = auto() - REFACT = auto() - BERT = auto() - BLOOM = auto() - STABLELM = auto() - BAMBOO = auto() + REFACT = auto() + BERT = auto() + BLOOM = auto() + STABLELM = auto() + BAMBOO = auto() class MODEL_TENSOR(IntEnum): - TOKEN_EMBD = auto() + TOKEN_EMBD = auto() TOKEN_EMBD_NORM = auto() - TOKEN_TYPES = auto() - POS_EMBD = auto() - OUTPUT = auto() - OUTPUT_NORM = auto() - ROPE_FREQS = auto() - ATTN_Q = auto() - ATTN_K = auto() - ATTN_V = auto() - ATTN_QKV = auto() - ATTN_OUT = auto() - ATTN_NORM = auto() - ATTN_NORM_2 = auto() - ATTN_ROT_EMBD = auto() - FFN_GATE = auto() - FFN_DOWN = auto() - FFN_UP = auto() - FFN_NORM = auto() - ATTN_Q_NORM = auto() - ATTN_K_NORM = auto() - FFN_DOWN_T = auto() - FC_1 = auto() - FC_2 = auto() - + TOKEN_TYPES = auto() + POS_EMBD = auto() + OUTPUT = auto() + OUTPUT_NORM = auto() + ROPE_FREQS = auto() + ATTN_Q = auto() + ATTN_K = auto() + ATTN_V = auto() + ATTN_QKV = auto() + ATTN_OUT = auto() + ATTN_NORM = auto() + ATTN_NORM_2 = auto() + ATTN_ROT_EMBD = auto() + FFN_GATE = auto() + FFN_DOWN = auto() + FFN_UP = auto() + FFN_NORM = auto() + ATTN_Q_NORM = auto() + ATTN_K_NORM = auto() + FFN_DOWN_T = auto() + FC_1 = auto() + FC_2 = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { - MODEL_ARCH.LLAMA: "llama", - MODEL_ARCH.FALCON: "falcon", - MODEL_ARCH.BAICHUAN: "baichuan", - MODEL_ARCH.GPT2: "gpt2", - MODEL_ARCH.GPTJ: "gptj", - MODEL_ARCH.GPTNEOX: "gptneox", - MODEL_ARCH.OPT: "opt", - MODEL_ARCH.MPT: "mpt", - MODEL_ARCH.STARCODER: "starcoder", - MODEL_ARCH.PERSIMMON: "persimmon", - MODEL_ARCH.REFACT: "refact", - MODEL_ARCH.BERT: "bert", - MODEL_ARCH.BLOOM: "bloom", - MODEL_ARCH.STABLELM: "stablelm", - MODEL_ARCH.BAMBOO: "bamboo", + MODEL_ARCH.LLAMA: "llama", + MODEL_ARCH.FALCON: "falcon", + MODEL_ARCH.BAICHUAN: "baichuan", + MODEL_ARCH.GPT2: "gpt2", + MODEL_ARCH.GPTJ: "gptj", + MODEL_ARCH.GPTNEOX: "gptneox", + MODEL_ARCH.OPT: "opt", + MODEL_ARCH.MPT: "mpt", + MODEL_ARCH.STARCODER: "starcoder", + MODEL_ARCH.PERSIMMON: "persimmon", + MODEL_ARCH.REFACT: "refact", + MODEL_ARCH.BERT: "bert", + MODEL_ARCH.BLOOM: "bloom", + MODEL_ARCH.STABLELM: "stablelm", + MODEL_ARCH.BAMBOO: "bamboo", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { - MODEL_TENSOR.TOKEN_EMBD: "token_embd", + MODEL_TENSOR.TOKEN_EMBD: "token_embd", MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm", - MODEL_TENSOR.TOKEN_TYPES: "token_types", - MODEL_TENSOR.POS_EMBD: "position_embd", - MODEL_TENSOR.OUTPUT_NORM: "output_norm", - MODEL_TENSOR.OUTPUT: "output", - MODEL_TENSOR.ROPE_FREQS: "rope_freqs", - MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", - MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", - MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", - MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", - MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", - MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", - MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", - MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", - MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", - MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", - MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", - MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", - MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", - MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", - MODEL_TENSOR.FFN_DOWN_T: "blk.{bid}.ffn_down_t", - MODEL_TENSOR.FC_1: "blk.{bid}.fc1", - MODEL_TENSOR.FC_2: "blk.{bid}.fc2", + MODEL_TENSOR.TOKEN_TYPES: "token_types", + MODEL_TENSOR.POS_EMBD: "position_embd", + MODEL_TENSOR.OUTPUT_NORM: "output_norm", + MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.ROPE_FREQS: "rope_freqs", + MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", + MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", + MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", + MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", + MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", + MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", + MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", + MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", + MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", + MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", + MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", + MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", + MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", + MODEL_TENSOR.FFN_DOWN_T: "blk.{bid}.ffn_down_t", + MODEL_TENSOR.FC_1: "blk.{bid}.fc1", + MODEL_TENSOR.FC_2: "blk.{bid}.fc2", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -399,23 +398,23 @@ class MODEL_TENSOR(IntEnum): class TokenType(IntEnum): - NORMAL = 1 - UNKNOWN = 2 - CONTROL = 3 + NORMAL = 1 + UNKNOWN = 2 + CONTROL = 3 USER_DEFINED = 4 - UNUSED = 5 - BYTE = 6 + UNUSED = 5 + BYTE = 6 class RopeScalingType(Enum): - NONE = 'none' - LINEAR = 'linear' - YARN = 'yarn' + NONE = "none" + LINEAR = "linear" + YARN = "yarn" class GGMLQuantizationType(IntEnum): - F32 = 0 - F16 = 1 + F32 = 0 + F16 = 1 Q4_0 = 2 Q4_1 = 3 Q5_0 = 6 @@ -428,9 +427,9 @@ class GGMLQuantizationType(IntEnum): Q5_K = 13 Q6_K = 14 Q8_K = 15 - I8 = 16, + I8 = (16,) I16 = 17 - I32 = 18, + I32 = (18,) class GGUFEndian(IntEnum): @@ -439,18 +438,18 @@ class GGUFEndian(IntEnum): class GGUFValueType(IntEnum): - UINT8 = 0 - INT8 = 1 - UINT16 = 2 - INT16 = 3 - UINT32 = 4 - INT32 = 5 + UINT8 = 0 + INT8 = 1 + UINT16 = 2 + INT16 = 3 + UINT32 = 4 + INT32 = 5 FLOAT32 = 6 - BOOL = 7 - STRING = 8 - ARRAY = 9 - UINT64 = 10 - INT64 = 11 + BOOL = 7 + STRING = 8 + ARRAY = 9 + UINT64 = 10 + INT64 = 11 FLOAT64 = 12 @staticmethod @@ -475,8 +474,8 @@ def get_type(val: Any) -> GGUFValueType: QK_K = 256 # Items here are (block size, type size) GGML_QUANT_SIZES = { - GGMLQuantizationType.F32: (1, 4), - GGMLQuantizationType.F16: (1, 2), + GGMLQuantizationType.F32: (1, 4), + GGMLQuantizationType.F16: (1, 2), GGMLQuantizationType.Q4_0: (32, 2 + 16), GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), @@ -495,52 +494,52 @@ def get_type(val: Any) -> GGUFValueType: # Aliases for backward compatibility. # general -KEY_GENERAL_ARCHITECTURE = Keys.General.ARCHITECTURE +KEY_GENERAL_ARCHITECTURE = Keys.General.ARCHITECTURE KEY_GENERAL_QUANTIZATION_VERSION = Keys.General.QUANTIZATION_VERSION -KEY_GENERAL_ALIGNMENT = Keys.General.ALIGNMENT -KEY_GENERAL_NAME = Keys.General.NAME -KEY_GENERAL_AUTHOR = Keys.General.AUTHOR -KEY_GENERAL_URL = Keys.General.URL -KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION -KEY_GENERAL_LICENSE = Keys.General.LICENSE -KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL -KEY_GENERAL_SOURCE_HF_REPO = Keys.General.SOURCE_HF_REPO -KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE +KEY_GENERAL_ALIGNMENT = Keys.General.ALIGNMENT +KEY_GENERAL_NAME = Keys.General.NAME +KEY_GENERAL_AUTHOR = Keys.General.AUTHOR +KEY_GENERAL_URL = Keys.General.URL +KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION +KEY_GENERAL_LICENSE = Keys.General.LICENSE +KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL +KEY_GENERAL_SOURCE_HF_REPO = Keys.General.SOURCE_HF_REPO +KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE # LLM -KEY_CONTEXT_LENGTH = Keys.LLM.CONTEXT_LENGTH -KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH -KEY_BLOCK_COUNT = Keys.LLM.BLOCK_COUNT -KEY_FEED_FORWARD_LENGTH = Keys.LLM.FEED_FORWARD_LENGTH +KEY_CONTEXT_LENGTH = Keys.LLM.CONTEXT_LENGTH +KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH +KEY_BLOCK_COUNT = Keys.LLM.BLOCK_COUNT +KEY_FEED_FORWARD_LENGTH = Keys.LLM.FEED_FORWARD_LENGTH KEY_USE_PARALLEL_RESIDUAL = Keys.LLM.USE_PARALLEL_RESIDUAL -KEY_TENSOR_DATA_LAYOUT = Keys.LLM.TENSOR_DATA_LAYOUT +KEY_TENSOR_DATA_LAYOUT = Keys.LLM.TENSOR_DATA_LAYOUT # attention -KEY_ATTENTION_HEAD_COUNT = Keys.Attention.HEAD_COUNT -KEY_ATTENTION_HEAD_COUNT_KV = Keys.Attention.HEAD_COUNT_KV -KEY_ATTENTION_MAX_ALIBI_BIAS = Keys.Attention.MAX_ALIBI_BIAS -KEY_ATTENTION_CLAMP_KQV = Keys.Attention.CLAMP_KQV -KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS +KEY_ATTENTION_HEAD_COUNT = Keys.Attention.HEAD_COUNT +KEY_ATTENTION_HEAD_COUNT_KV = Keys.Attention.HEAD_COUNT_KV +KEY_ATTENTION_MAX_ALIBI_BIAS = Keys.Attention.MAX_ALIBI_BIAS +KEY_ATTENTION_CLAMP_KQV = Keys.Attention.CLAMP_KQV +KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS # RoPE -KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT -KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE -KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE -KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR +KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT +KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE +KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE +KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN -KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED +KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED # tokenization -KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL -KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST +KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL +KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST KEY_TOKENIZER_TOKEN_TYPE = Keys.Tokenizer.TOKEN_TYPE -KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES -KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES -KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID -KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID -KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID -KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID -KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID -KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON -KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV +KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES +KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES +KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID +KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID +KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID +KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID +KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID +KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON +KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py index 8682765e..30d175f3 100644 --- a/gguf-py/gguf/gguf_reader.py +++ b/gguf-py/gguf/gguf_reader.py @@ -63,59 +63,79 @@ class ReaderTensor(NamedTuple): class GGUFReader: # I - same as host, S - swapped - byte_order: Literal['I' | 'S'] = 'I' + byte_order: Literal["I" | "S"] = "I" alignment: int = GGUF_DEFAULT_ALIGNMENT # Note: Internal helper, API may change. gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = { - GGUFValueType.UINT8: np.uint8, - GGUFValueType.INT8: np.int8, - GGUFValueType.UINT16: np.uint16, - GGUFValueType.INT16: np.int16, - GGUFValueType.UINT32: np.uint32, - GGUFValueType.INT32: np.int32, + GGUFValueType.UINT8: np.uint8, + GGUFValueType.INT8: np.int8, + GGUFValueType.UINT16: np.uint16, + GGUFValueType.INT16: np.int16, + GGUFValueType.UINT32: np.uint32, + GGUFValueType.INT32: np.int32, GGUFValueType.FLOAT32: np.float32, - GGUFValueType.UINT64: np.uint64, - GGUFValueType.INT64: np.int64, + GGUFValueType.UINT64: np.uint64, + GGUFValueType.INT64: np.int64, GGUFValueType.FLOAT64: np.float64, - GGUFValueType.BOOL: np.bool_, + GGUFValueType.BOOL: np.bool_, } - def __init__(self, path: os.PathLike[str] | str, mode: Literal['r' | 'r+' | 'c'] = 'r'): - self.data = np.memmap(path, mode = mode) + def __init__( + self, path: os.PathLike[str] | str, mode: Literal["r" | "r+" | "c"] = "r" + ): + self.data = np.memmap(path, mode=mode) offs = 0 - if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC: - raise ValueError('GGUF magic invalid') + if self._get(offs, np.uint32, override_order="<")[0] != GGUF_MAGIC: + raise ValueError("GGUF magic invalid") offs += 4 temp_version = self._get(offs, np.uint32) if temp_version[0] & 65535 == 0: # If we get 0 here that means it's (probably) a GGUF file created for # the opposite byte order of the machine this script is running on. - self.byte_order = 'S' + self.byte_order = "S" temp_version = temp_version.newbyteorder(self.byte_order) version = temp_version[0] if version not in READER_SUPPORTED_VERSIONS: - raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle') + raise ValueError( + f"Sorry, file appears to be version {version} which we cannot handle" + ) self.fields: OrderedDict[str, ReaderField] = OrderedDict() self.tensors: list[ReaderTensor] = [] - offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32])) + offs += self._push_field( + ReaderField( + offs, "GGUF.version", [temp_version], [0], [GGUFValueType.UINT32] + ) + ) temp_counts = self._get(offs, np.uint64, 2) - offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64])) - offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64])) + offs += self._push_field( + ReaderField( + offs, + "GGUF.tensor_count", + [temp_counts[:1]], + [0], + [GGUFValueType.UINT64], + ) + ) + offs += self._push_field( + ReaderField( + offs, "GGUF.kv_count", [temp_counts[1:]], [0], [GGUFValueType.UINT64] + ) + ) tensor_count, kv_count = temp_counts offs = self._build_fields(offs, kv_count) offs, tensors_fields = self._build_tensors_fields(offs, tensor_count) - new_align = self.fields.get('general.alignment') + new_align = self.fields.get("general.alignment") if new_align is not None: if new_align.types != [GGUFValueType.UINT64]: - raise ValueError('Bad type for general.alignment field') + raise ValueError("Bad type for general.alignment field") self.alignment = new_align.parts[-1][0] padding = offs % self.alignment if padding != 0: offs += self.alignment - padding self._build_tensors(offs, tensors_fields) - _DT = TypeVar('_DT', bound = npt.DTypeLike) + _DT = TypeVar("_DT", bound=npt.DTypeLike) # Fetch a key/value metadata field by key. def get_field(self, key: str) -> Union[ReaderField, None]: @@ -126,29 +146,39 @@ def get_tensor(self, idx: int) -> ReaderTensor: return self.tensors[idx] def _get( - self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I' | 'S' | '<'] = None, + self, + offset: int, + dtype: npt.DTypeLike, + count: int = 1, + override_order: None | Literal["I" | "S" | "<"] = None, ) -> npt.NDArray[Any]: count = int(count) - itemsize = int(np.empty([], dtype = dtype).itemsize) + itemsize = int(np.empty([], dtype=dtype).itemsize) end_offs = offset + itemsize * count return ( self.data[offset:end_offs] - .view(dtype = dtype)[:count] + .view(dtype=dtype)[:count] .newbyteorder(override_order or self.byte_order) ) def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int: if field.name in self.fields: - raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}') + raise KeyError( + f"Duplicate {field.name} already in list at offset {field.offset}" + ) self.fields[field.name] = field return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts) - def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]: + def _get_str( + self, offset: int + ) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]: slen = self._get(offset, np.uint64) return slen, self._get(offset + 8, np.uint8, slen[0]) def _get_field_parts( - self, orig_offs: int, raw_type: int, + self, + orig_offs: int, + raw_type: int, ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]: offs = orig_offs types: list[GGUFValueType] = [] @@ -173,7 +203,9 @@ def _get_field_parts( aparts: list[npt.NDArray[Any]] = [raw_itype, alen] data_idxs: list[int] = [] for idx in range(alen[0]): - curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0]) + curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts( + offs, raw_itype[0] + ) if idx == 0: types += curr_types idxs_offs = len(aparts) @@ -182,7 +214,7 @@ def _get_field_parts( offs += curr_size return offs - orig_offs, aparts, data_idxs, types # We can't deal with this one. - raise ValueError('Unknown/unhandled field type {gtype}') + raise ValueError(f"Unknown/unhandled field type {gtype}") def _get_tensor(self, orig_offs: int) -> ReaderField: offs = orig_offs @@ -198,7 +230,7 @@ def _get_tensor(self, orig_offs: int) -> ReaderField: offs += int(offset_tensor.nbytes) return ReaderField( orig_offs, - str(bytes(name_data), encoding = 'utf-8'), + str(bytes(name_data), encoding="utf-8"), [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor], [1, 3, 4, 5], ) @@ -212,19 +244,26 @@ def _build_fields(self, offs: int, count: int) -> int: offs += int(raw_kv_type.nbytes) parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type] idxs_offs = len(parts) - field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0]) + field_size, field_parts, field_idxs, field_types = self._get_field_parts( + offs, raw_kv_type[0] + ) parts += field_parts - self._push_field(ReaderField( - orig_offs, - str(bytes(kv_kdata), encoding = 'utf-8'), - parts, - [idx + idxs_offs for idx in field_idxs], - field_types, - ), skip_sum = True) + self._push_field( + ReaderField( + orig_offs, + str(bytes(kv_kdata), encoding="utf-8"), + parts, + [idx + idxs_offs for idx in field_idxs], + field_types, + ), + skip_sum=True, + ) offs += field_size return offs - def _build_tensors_fields(self, offs: int, count: int) -> tuple[int, list[ReaderField]]: + def _build_tensors_fields( + self, offs: int, count: int + ) -> tuple[int, list[ReaderField]]: tensor_fields = [] for _ in range(count): field = self._get_tensor(offs) @@ -251,14 +290,16 @@ def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None: else: item_count = n_bytes item_type = np.uint8 - tensors.append(ReaderTensor( - name = str(bytes(name_data), encoding = 'utf-8'), - tensor_type = ggml_type, - shape = dims, - n_elements = n_elems, - n_bytes = n_bytes, - data_offset = data_offs, - data = self._get(data_offs, item_type, item_count), - field = field, - )) + tensors.append( + ReaderTensor( + name=str(bytes(name_data), encoding="utf-8"), + tensor_type=ggml_type, + shape=dims, + n_elements=n_elems, + n_bytes=n_bytes, + data_offset=data_offs, + data=self._get(data_offs, item_type, item_count), + field=field, + ) + ) self.tensors = tensors diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 0483d7ba..3f7ac8ee 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -24,8 +24,8 @@ class WriterState(Enum): - EMPTY = auto() - HEADER = auto() + EMPTY = auto() + HEADER = auto() KV_DATA = auto() TI_DATA = auto() @@ -35,21 +35,24 @@ class GGUFWriter: temp_file: tempfile.SpooledTemporaryFile[bytes] | None tensors: list[np.ndarray[Any, Any]] _simple_value_packing = { - GGUFValueType.UINT8: "B", - GGUFValueType.INT8: "b", - GGUFValueType.UINT16: "H", - GGUFValueType.INT16: "h", - GGUFValueType.UINT32: "I", - GGUFValueType.INT32: "i", + GGUFValueType.UINT8: "B", + GGUFValueType.INT8: "b", + GGUFValueType.UINT16: "H", + GGUFValueType.INT16: "h", + GGUFValueType.UINT32: "I", + GGUFValueType.INT32: "i", GGUFValueType.FLOAT32: "f", - GGUFValueType.UINT64: "Q", - GGUFValueType.INT64: "q", + GGUFValueType.UINT64: "Q", + GGUFValueType.INT64: "q", GGUFValueType.FLOAT64: "d", - GGUFValueType.BOOL: "?", + GGUFValueType.BOOL: "?", } def __init__( - self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool = True, + self, + path: os.PathLike[str] | str, + arch: str, + use_temp_file: bool = True, endianess: GGUFEndian = GGUFEndian.LITTLE, ): self.fout = open(path, "wb") @@ -64,18 +67,20 @@ def __init__( self.use_temp_file = use_temp_file self.temp_file = None self.tensors = [] - print("gguf: This GGUF file is for {0} Endian only".format( - "Big" if self.endianess == GGUFEndian.BIG else "Little", - )) + print( + "gguf: This GGUF file is for {0} Endian only".format( + "Big" if self.endianess == GGUFEndian.BIG else "Little", + ) + ) self.state = WriterState.EMPTY self.add_architecture() def write_header_to_file(self) -> None: if self.state is not WriterState.EMPTY: - raise ValueError(f'Expected output file to be empty, got {self.state}') + raise ValueError(f"Expected output file to be empty, got {self.state}") - self._write_packed(" None: def write_kv_data_to_file(self) -> None: if self.state is not WriterState.HEADER: - raise ValueError(f'Expected output file to contain the header, got {self.state}') + raise ValueError( + f"Expected output file to contain the header, got {self.state}" + ) self.fout.write(self.kv_data) self.flush() @@ -92,7 +99,9 @@ def write_kv_data_to_file(self) -> None: def write_ti_data_to_file(self) -> None: if self.state is not WriterState.KV_DATA: - raise ValueError(f'Expected output file to contain KV data, got {self.state}') + raise ValueError( + f"Expected output file to contain KV data, got {self.state}" + ) self.fout.write(self.ti_data) self.flush() @@ -158,7 +167,9 @@ def add_array(self, key: str, val: Sequence[Any]) -> None: self.add_key(key) self.add_val(val, GGUFValueType.ARRAY) - def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True) -> None: + def add_val( + self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True + ) -> None: if vtype is None: vtype = GGUFValueType.get_type(val) @@ -168,7 +179,9 @@ def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool pack_fmt = self._simple_value_packing.get(vtype) if pack_fmt is not None: - self.kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL) + self.kv_data += self._pack( + pack_fmt, val, skip_pack_prefix=vtype == GGUFValueType.BOOL + ) elif vtype == GGUFValueType.STRING: encoded_val = val.encode("utf8") if isinstance(val, str) else val self.kv_data += self._pack("Q", len(encoded_val)) @@ -189,11 +202,15 @@ def ggml_pad(x: int, n: int) -> int: return ((x + n - 1) // n) * n def add_tensor_info( - self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32], - tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None, + self, + name: str, + tensor_shape: Sequence[int], + tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32], + tensor_nbytes: int, + raw_dtype: GGMLQuantizationType | None = None, ) -> None: if self.state is not WriterState.EMPTY: - raise ValueError(f'Expected output file to be empty, got {self.state}') + raise ValueError(f"Expected output file to be empty, got {self.state}") if raw_dtype is None and tensor_dtype not in (np.float32, np.float16): raise ValueError("Only F32 and F16 tensors are supported for now") @@ -206,7 +223,11 @@ def add_tensor_info( for i in range(n_dims): self.ti_data += self._pack("Q", tensor_shape[n_dims - 1 - i]) if raw_dtype is None: - dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16 + dtype = ( + GGMLQuantizationType.F32 + if tensor_dtype == np.float32 + else GGMLQuantizationType.F16 + ) else: dtype = raw_dtype self.ti_data += self._pack("I", dtype) @@ -215,18 +236,23 @@ def add_tensor_info( self.ti_data_count += 1 def add_tensor( - self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, + self, + name: str, + tensor: np.ndarray[Any, Any], + raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None, ) -> None: if self.endianess == GGUFEndian.BIG: tensor.byteswap(inplace=True) if self.use_temp_file and self.temp_file is None: - fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024) + fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024) fp.seek(0) self.temp_file = fp shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape - self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype) + self.add_tensor_info( + name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype + ) if self.temp_file is None: self.tensors.append(tensor) @@ -236,13 +262,18 @@ def add_tensor( self.write_padding(self.temp_file, tensor.nbytes) def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None: - pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n + pad = ( + GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) + - n + ) if pad != 0: fp.write(bytes([0] * pad)) def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: if self.state is not WriterState.TI_DATA: - raise ValueError(f'Expected output file to contain tensor info, got {self.state}') + raise ValueError( + f"Expected output file to contain tensor info, got {self.state}" + ) if self.endianess == GGUFEndian.BIG: tensor.byteswap(inplace=True) @@ -304,9 +335,10 @@ def add_file_type(self, ftype: int) -> None: def add_name(self, name: str) -> None: self.add_string(Keys.General.NAME, name) - def add_quantization_version(self, quantization_version: GGMLQuantizationType) -> None: - self.add_uint32( - Keys.General.QUANTIZATION_VERSION, quantization_version) + def add_quantization_version( + self, quantization_version: GGMLQuantizationType + ) -> None: + self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version) def add_custom_alignment(self, alignment: int) -> None: self.data_alignment = alignment @@ -366,10 +398,14 @@ def add_rope_scaling_finetuned(self, value: bool) -> None: def add_tokenizer_model(self, model: str) -> None: self.add_string(Keys.Tokenizer.MODEL, model) - def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None: + def add_token_list( + self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray] + ) -> None: self.add_array(Keys.Tokenizer.LIST, tokens) - def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None: + def add_token_merges( + self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray] + ) -> None: self.add_array(Keys.Tokenizer.MERGES, merges) def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None: @@ -403,10 +439,12 @@ def add_sparse_threshold(self, value: float) -> None: self.add_float32(Keys.PowerInfer.SPARSE_THRESHOLD, value) def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes: - pack_prefix = '' + pack_prefix = "" if not skip_pack_prefix: - pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>' - return struct.pack(f'{pack_prefix}{fmt}', value) + pack_prefix = "<" if self.endianess == GGUFEndian.LITTLE else ">" + return struct.pack(f"{pack_prefix}{fmt}", value) - def _write_packed(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> None: + def _write_packed( + self, fmt: str, value: Any, skip_pack_prefix: bool = False + ) -> None: self.fout.write(self._pack(fmt, value, skip_pack_prefix)) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 641b81f0..24327950 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -9,210 +9,176 @@ class TensorNameMap: mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { # Token embeddings MODEL_TENSOR.TOKEN_EMBD: ( - "gpt_neox.embed_in", # gptneox - "transformer.wte", # gpt2 gpt-j mpt refact - "decoder.embed_tokens", # opt - "transformer.word_embeddings", # falcon - "word_embeddings", # bloom - "model.embed_tokens", # llama-hf - "tok_embeddings", # llama-pth - "embeddings.word_embeddings", # bert + "gpt_neox.embed_in", # gptneox + "transformer.wte", # gpt2 gpt-j mpt refact + "decoder.embed_tokens", # opt + "transformer.word_embeddings", # falcon + "word_embeddings", # bloom + "model.embed_tokens", # llama-hf + "tok_embeddings", # llama-pth + "embeddings.word_embeddings", # bert "language_model.embedding.word_embeddings", # persimmon ), - # Token type embeddings - MODEL_TENSOR.TOKEN_TYPES: ( - "embeddings.token_type_embeddings", # bert - ), - + MODEL_TENSOR.TOKEN_TYPES: ("embeddings.token_type_embeddings",), # bert # Normalization of token embeddings - MODEL_TENSOR.TOKEN_EMBD_NORM: ( - "word_embeddings_layernorm", # bloom - ), - + MODEL_TENSOR.TOKEN_EMBD_NORM: ("word_embeddings_layernorm",), # bloom # Position embeddings MODEL_TENSOR.POS_EMBD: ( - "transformer.wpe", # gpt2 + "transformer.wpe", # gpt2 "embeddings.position_embeddings", # bert - "decoder.embed_positions", # opt + "decoder.embed_positions", # opt ), - # Output MODEL_TENSOR.OUTPUT: ( - "embed_out", # gptneox - "lm_head", # gpt2 mpt falcon llama-hf baichuan - "output", # llama-pth bloom + "embed_out", # gptneox + "lm_head", # gpt2 mpt falcon llama-hf baichuan + "output", # llama-pth bloom "word_embeddings_for_head", # persimmon ), - # Output norm MODEL_TENSOR.OUTPUT_NORM: ( - "gpt_neox.final_layer_norm", # gptneox - "transformer.ln_f", # gpt2 gpt-j falcon - "decoder.final_layer_norm", # opt - "model.norm", # llama-hf baichuan - "norm", # llama-pth - "embeddings.LayerNorm", # bert - "transformer.norm_f", # mpt - "ln_f", # refact bloom + "gpt_neox.final_layer_norm", # gptneox + "transformer.ln_f", # gpt2 gpt-j falcon + "decoder.final_layer_norm", # opt + "model.norm", # llama-hf baichuan + "norm", # llama-pth + "embeddings.LayerNorm", # bert + "transformer.norm_f", # mpt + "ln_f", # refact bloom "language_model.encoder.final_layernorm", # persimmon ), - # Rope frequencies - MODEL_TENSOR.ROPE_FREQS: ( - "rope.freqs", # llama-pth - ), + MODEL_TENSOR.ROPE_FREQS: ("rope.freqs",), # llama-pth } block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { # Attention norm MODEL_TENSOR.ATTN_NORM: ( - "gpt_neox.layers.{bid}.input_layernorm", # gptneox - "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact - "decoder.layers.{bid}.self_attn_layer_norm", # opt - "transformer.blocks.{bid}.norm_1", # mpt - "transformer.h.{bid}.input_layernorm", # falcon7b - "h.{bid}.input_layernorm", # bloom - "transformer.h.{bid}.ln_mlp", # falcon40b - "model.layers.{bid}.input_layernorm", # llama-hf - "layers.{bid}.attention_norm", # llama-pth - "encoder.layer.{bid}.attention.output.LayerNorm", # bert + "gpt_neox.layers.{bid}.input_layernorm", # gptneox + "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact + "decoder.layers.{bid}.self_attn_layer_norm", # opt + "transformer.blocks.{bid}.norm_1", # mpt + "transformer.h.{bid}.input_layernorm", # falcon7b + "h.{bid}.input_layernorm", # bloom + "transformer.h.{bid}.ln_mlp", # falcon40b + "model.layers.{bid}.input_layernorm", # llama-hf + "layers.{bid}.attention_norm", # llama-pth + "encoder.layer.{bid}.attention.output.LayerNorm", # bert "language_model.encoder.layers.{bid}.input_layernorm", # persimmon - "model.layers.{bid}.ln1", # yi + "model.layers.{bid}.ln1", # yi ), - # Attention norm 2 - MODEL_TENSOR.ATTN_NORM_2: ( - "transformer.h.{bid}.ln_attn", # falcon40b - ), - + MODEL_TENSOR.ATTN_NORM_2: ("transformer.h.{bid}.ln_attn",), # falcon40b # Attention query-key-value MODEL_TENSOR.ATTN_QKV: ( - "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox - "transformer.h.{bid}.attn.c_attn", # gpt2 - "transformer.blocks.{bid}.attn.Wqkv", # mpt - "transformer.h.{bid}.self_attention.query_key_value", # falcon - "h.{bid}.self_attention.query_key_value", # bloom + "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox + "transformer.h.{bid}.attn.c_attn", # gpt2 + "transformer.blocks.{bid}.attn.Wqkv", # mpt + "transformer.h.{bid}.self_attention.query_key_value", # falcon + "h.{bid}.self_attention.query_key_value", # bloom "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon ), - # Attention query MODEL_TENSOR.ATTN_Q: ( - "model.layers.{bid}.self_attn.q_proj", # llama-hf - "layers.{bid}.attention.wq", # llama-pth + "model.layers.{bid}.self_attn.q_proj", # llama-hf + "layers.{bid}.attention.wq", # llama-pth "encoder.layer.{bid}.attention.self.query", # bert - "transformer.h.{bid}.attn.q_proj", # gpt-j - "decoder.layers.{bid}.self_attn.q_proj", # opt + "transformer.h.{bid}.attn.q_proj", # gpt-j + "decoder.layers.{bid}.self_attn.q_proj", # opt ), - # Attention key MODEL_TENSOR.ATTN_K: ( - "model.layers.{bid}.self_attn.k_proj", # llama-hf - "layers.{bid}.attention.wk", # llama-pth + "model.layers.{bid}.self_attn.k_proj", # llama-hf + "layers.{bid}.attention.wk", # llama-pth "encoder.layer.{bid}.attention.self.key", # bert - "transformer.h.{bid}.attn.k_proj", # gpt-j - "decoder.layers.{bid}.self_attn.k_proj", # opt + "transformer.h.{bid}.attn.k_proj", # gpt-j + "decoder.layers.{bid}.self_attn.k_proj", # opt ), - # Attention value MODEL_TENSOR.ATTN_V: ( - "model.layers.{bid}.self_attn.v_proj", # llama-hf - "layers.{bid}.attention.wv", # llama-pth + "model.layers.{bid}.self_attn.v_proj", # llama-hf + "layers.{bid}.attention.wv", # llama-pth "encoder.layer.{bid}.attention.self.value", # bert - "transformer.h.{bid}.attn.v_proj", # gpt-j - "decoder.layers.{bid}.self_attn.v_proj", # opt + "transformer.h.{bid}.attn.v_proj", # gpt-j + "decoder.layers.{bid}.self_attn.v_proj", # opt ), - # Attention output MODEL_TENSOR.ATTN_OUT: ( - "gpt_neox.layers.{bid}.attention.dense", # gptneox - "transformer.h.{bid}.attn.c_proj", # gpt2 refact - "decoder.layers.{bid}.self_attn.out_proj", # opt - "transformer.blocks.{bid}.attn.out_proj", # mpt - "transformer.h.{bid}.self_attention.dense", # falcon - "h.{bid}.self_attention.dense", # bloom - "model.layers.{bid}.self_attn.o_proj", # llama-hf - "layers.{bid}.attention.wo", # llama-pth - "encoder.layer.{bid}.attention.output.dense", # bert - "transformer.h.{bid}.attn.out_proj", # gpt-j + "gpt_neox.layers.{bid}.attention.dense", # gptneox + "transformer.h.{bid}.attn.c_proj", # gpt2 refact + "decoder.layers.{bid}.self_attn.out_proj", # opt + "transformer.blocks.{bid}.attn.out_proj", # mpt + "transformer.h.{bid}.self_attention.dense", # falcon + "h.{bid}.self_attention.dense", # bloom + "model.layers.{bid}.self_attn.o_proj", # llama-hf + "layers.{bid}.attention.wo", # llama-pth + "encoder.layer.{bid}.attention.output.dense", # bert + "transformer.h.{bid}.attn.out_proj", # gpt-j "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon ), - # Rotary embeddings MODEL_TENSOR.ATTN_ROT_EMBD: ( - "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf + "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth ), - # Feed-forward norm MODEL_TENSOR.FFN_NORM: ( - "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox - "transformer.h.{bid}.ln_2", # gpt2 refact - "decoder.layers.{bid}.final_layer_norm", # opt - "h.{bid}.post_attention_layernorm", # bloom - "transformer.blocks.{bid}.norm_2", # mpt - "model.layers.{bid}.post_attention_layernorm", # llama-hf - "layers.{bid}.ffn_norm", # llama-pth - "encoder.layer.{bid}.output.LayerNorm", # bert + "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox + "transformer.h.{bid}.ln_2", # gpt2 refact + "decoder.layers.{bid}.final_layer_norm", # opt + "h.{bid}.post_attention_layernorm", # bloom + "transformer.blocks.{bid}.norm_2", # mpt + "model.layers.{bid}.post_attention_layernorm", # llama-hf + "layers.{bid}.ffn_norm", # llama-pth + "encoder.layer.{bid}.output.LayerNorm", # bert "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon - "model.layers.{bid}.ln2", # yi + "model.layers.{bid}.ln2", # yi ), - # Feed-forward up MODEL_TENSOR.FFN_UP: ( - "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox - "transformer.h.{bid}.mlp.c_fc", # gpt2 - "decoder.layers.{bid}.fc1", # opt - "transformer.blocks.{bid}.ffn.up_proj", # mpt - "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon - "h.{bid}.mlp.dense_h_to_4h", # bloom - "model.layers.{bid}.mlp.up_proj", # llama-hf refact - "layers.{bid}.feed_forward.w3", # llama-pth - "encoder.layer.{bid}.intermediate.dense", # bert - "transformer.h.{bid}.mlp.fc_in", # gpt-j + "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox + "transformer.h.{bid}.mlp.c_fc", # gpt2 + "decoder.layers.{bid}.fc1", # opt + "transformer.blocks.{bid}.ffn.up_proj", # mpt + "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon + "h.{bid}.mlp.dense_h_to_4h", # bloom + "model.layers.{bid}.mlp.up_proj", # llama-hf refact + "layers.{bid}.feed_forward.w3", # llama-pth + "encoder.layer.{bid}.intermediate.dense", # bert + "transformer.h.{bid}.mlp.fc_in", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon ), - # Feed-forward gate MODEL_TENSOR.FFN_GATE: ( "model.layers.{bid}.mlp.gate_proj", # llama-hf refact - "layers.{bid}.feed_forward.w1", # llama-pth + "layers.{bid}.feed_forward.w1", # llama-pth ), - # Feed-forward down MODEL_TENSOR.FFN_DOWN: ( - "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox - "transformer.h.{bid}.mlp.c_proj", # gpt2 refact - "decoder.layers.{bid}.fc2", # opt - "transformer.blocks.{bid}.ffn.down_proj", # mpt - "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon - "h.{bid}.mlp.dense_4h_to_h", # bloom - "model.layers.{bid}.mlp.down_proj", # llama-hf - "layers.{bid}.feed_forward.w2", # llama-pth - "encoder.layer.{bid}.output.dense", # bert - "transformer.h.{bid}.mlp.fc_out", # gpt-j + "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox + "transformer.h.{bid}.mlp.c_proj", # gpt2 refact + "decoder.layers.{bid}.fc2", # opt + "transformer.blocks.{bid}.ffn.down_proj", # mpt + "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon + "h.{bid}.mlp.dense_4h_to_h", # bloom + "model.layers.{bid}.mlp.down_proj", # llama-hf + "layers.{bid}.feed_forward.w2", # llama-pth + "encoder.layer.{bid}.output.dense", # bert + "transformer.h.{bid}.mlp.fc_out", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon ), - MODEL_TENSOR.ATTN_Q_NORM: ( "language_model.encoder.layers.{bid}.self_attention.q_layernorm", ), - MODEL_TENSOR.ATTN_K_NORM: ( "language_model.encoder.layers.{bid}.self_attention.k_layernorm", ), - MODEL_TENSOR.ROPE_FREQS: ( "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon ), - - MODEL_TENSOR.FC_1: ( - "model.layers.{bid}.fc1", - ), - - MODEL_TENSOR.FC_2: ( - "model.layers.{bid}.fc2", - ), + MODEL_TENSOR.FC_1: ("model.layers.{bid}.fc1",), + MODEL_TENSOR.FC_2: ("model.layers.{bid}.fc2",), } mapping: dict[str, tuple[MODEL_TENSOR, str]] @@ -230,31 +196,35 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int): for tensor, keys in self.block_mappings_cfg.items(): if tensor not in MODEL_TENSORS[arch]: continue - tensor_name = TENSOR_NAMES[tensor].format(bid = bid) + tensor_name = TENSOR_NAMES[tensor].format(bid=bid) self.mapping[tensor_name] = (tensor, tensor_name) for key in keys: - key = key.format(bid = bid) + key = key.format(bid=bid) self.mapping[key] = (tensor, tensor_name) - def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None: + def get_type_and_name( + self, key: str, try_suffixes: Sequence[str] = () + ) -> tuple[MODEL_TENSOR, str] | None: result = self.mapping.get(key) if result is not None: return result for suffix in try_suffixes: if key.endswith(suffix): - result = self.mapping.get(key[:-len(suffix)]) + result = self.mapping.get(key[: -len(suffix)]) if result is not None: return result[0], result[1] + suffix return None def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None: - result = self.get_type_and_name(key, try_suffixes = try_suffixes) + result = self.get_type_and_name(key, try_suffixes=try_suffixes) if result is None: return None return result[1] - def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None: - result = self.get_type_and_name(key, try_suffixes = try_suffixes) + def get_type( + self, key: str, try_suffixes: Sequence[str] = () + ) -> MODEL_TENSOR | None: + result = self.get_type_and_name(key, try_suffixes=try_suffixes) if result is None: return None return result[0] diff --git a/gguf-py/gguf/vocab.py b/gguf-py/gguf/vocab.py index 71192a92..4b5c6fae 100644 --- a/gguf-py/gguf/vocab.py +++ b/gguf-py/gguf/vocab.py @@ -15,7 +15,9 @@ class SpecialVocab: special_token_ids: dict[str, int] def __init__( - self, path: str | os.PathLike[str], load_merges: bool = False, + self, + path: str | os.PathLike[str], + load_merges: bool = False, special_token_types: tuple[str, ...] | None = None, n_vocab: int | None = None, ): @@ -27,45 +29,51 @@ def __init__( if special_token_types is not None: self.special_token_types = special_token_types else: - self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad') + self.special_token_types = ("bos", "eos", "unk", "sep", "pad") self._load(Path(path)) def __repr__(self) -> str: - return ''.format( - len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset", + return "".format( + len(self.merges), + self.special_token_ids or "unset", + self.add_special_token or "unset", ) def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None: if self.merges: if not quiet: - print(f'gguf: Adding {len(self.merges)} merge(s).') + print(f"gguf: Adding {len(self.merges)} merge(s).") gw.add_token_merges(self.merges) elif self.load_merges: print( - 'gguf: WARNING: Adding merges requested but no merges found, output may be non-functional.', - file = sys.stderr, + "gguf: WARNING: Adding merges requested but no merges found, output may be non-functional.", + file=sys.stderr, ) for typ, tokid in self.special_token_ids.items(): - id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None) + id_handler: Callable[[int], None] | None = getattr( + gw, f"add_{typ}_token_id", None + ) if id_handler is None: print( - f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping', - file = sys.stderr, + f"gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping", + file=sys.stderr, ) continue if not quiet: - print(f'gguf: Setting special token type {typ} to {tokid}') + print(f"gguf: Setting special token type {typ} to {tokid}") id_handler(tokid) for typ, value in self.add_special_token.items(): - add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None) + add_handler: Callable[[bool], None] | None = getattr( + gw, f"add_add_{typ}_token", None + ) if add_handler is None: print( - f'gguf: WARNING: No handler for add_{typ}_token with value {value} - skipping', - file = sys.stderr, + f"gguf: WARNING: No handler for add_{typ}_token with value {value} - skipping", + file=sys.stderr, ) continue if not quiet: - print(f'gguf: Setting add_{typ}_token to {value}') + print(f"gguf: Setting add_{typ}_token to {value}") add_handler(value) def _load(self, path: Path) -> None: @@ -75,12 +83,12 @@ def _load(self, path: Path) -> None: self._try_load_merges_txt(path) def _try_load_merges_txt(self, path: Path) -> bool: - merges_file = path / 'merges.txt' + merges_file = path / "merges.txt" if not merges_file.is_file(): return False - with open(merges_file, 'r') as fp: - first_line = next(fp, '').strip() - if not first_line.startswith('#'): + with open(merges_file, "r") as fp: + first_line = next(fp, "").strip() + if not first_line.startswith("#"): fp.seek(0) line_num = 0 else: @@ -94,11 +102,11 @@ def _try_load_merges_txt(self, path: Path) -> bool: parts = line.split(None, 3) if len(parts) != 2: print( - f'gguf: WARNING: {merges_file.name}: Line {line_num}: Entry malformed, ignoring', - file = sys.stderr, + f"gguf: WARNING: {merges_file.name}: Line {line_num}: Entry malformed, ignoring", + file=sys.stderr, ) continue - merges.append(f'{parts[0]} {parts[1]}') + merges.append(f"{parts[0]} {parts[1]}") self.merges = merges return True @@ -111,35 +119,35 @@ def _set_special_token(self, typ: str, tid: Any) -> None: self.special_token_ids[typ] = tid return print( - f'gguf: WARNING: Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping', - file = sys.stderr, + f"gguf: WARNING: Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping", + file=sys.stderr, ) def _try_load_from_tokenizer_json(self, path: Path) -> bool: - tokenizer_file = path / 'tokenizer.json' + tokenizer_file = path / "tokenizer.json" if not tokenizer_file.is_file(): return False - with open(tokenizer_file, encoding = 'utf-8') as f: + with open(tokenizer_file, encoding="utf-8") as f: tokenizer = json.load(f) if self.load_merges: - merges = tokenizer.get('model', {}).get('merges') + merges = tokenizer.get("model", {}).get("merges") if isinstance(merges, list) and merges and isinstance(merges[0], str): self.merges = merges - tokenizer_config_file = path / 'tokenizer_config.json' - added_tokens = tokenizer.get('added_tokens') + tokenizer_config_file = path / "tokenizer_config.json" + added_tokens = tokenizer.get("added_tokens") if added_tokens is None or not tokenizer_config_file.is_file(): return True - with open(tokenizer_config_file, encoding = 'utf-8') as f: + with open(tokenizer_config_file, encoding="utf-8") as f: tokenizer_config = json.load(f) for typ in self.special_token_types: - add_entry = tokenizer_config.get(f'add_{typ}_token') + add_entry = tokenizer_config.get(f"add_{typ}_token") if isinstance(add_entry, bool): self.add_special_token[typ] = add_entry - entry = tokenizer_config.get(f'{typ}_token') + entry = tokenizer_config.get(f"{typ}_token") if isinstance(entry, str): tc_content = entry elif isinstance(entry, dict): - entry_content = entry.get('content') + entry_content = entry.get("content") if not isinstance(entry_content, str): continue tc_content = entry_content @@ -147,18 +155,22 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool: continue # We only need the first match here. maybe_token_id = next( - (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content), + ( + atok.get("id") + for atok in added_tokens + if atok.get("content") == tc_content + ), None, ) self._set_special_token(typ, maybe_token_id) return True def _try_load_from_config_json(self, path: Path) -> bool: - config_file = path / 'config.json' + config_file = path / "config.json" if not config_file.is_file(): return False - with open(config_file, encoding = 'utf-8') as f: + with open(config_file, encoding="utf-8") as f: config = json.load(f) for typ in self.special_token_types: - self._set_special_token(typ, config.get(f'{typ}_token_id')) + self._set_special_token(typ, config.get(f"{typ}_token_id")) return True diff --git a/gguf-py/scripts/__init__.py b/gguf-py/scripts/__init__.py index 77132db7..220bb9b7 100644 --- a/gguf-py/scripts/__init__.py +++ b/gguf-py/scripts/__init__.py @@ -6,7 +6,7 @@ os.environ["NO_LOCAL_GGUF"] = "TRUE" gguf_convert_endian_entrypoint = import_module("scripts.gguf-convert-endian").main -gguf_dump_entrypoint = import_module("scripts.gguf-dump").main -gguf_set_metadata_entrypoint = import_module("scripts.gguf-set-metadata").main +gguf_dump_entrypoint = import_module("scripts.gguf-dump").main +gguf_set_metadata_entrypoint = import_module("scripts.gguf-set-metadata").main del import_module, os diff --git a/gguf-py/scripts/gguf-convert-endian.py b/gguf-py/scripts/gguf-convert-endian.py index 10a16ad0..fe04d510 100755 --- a/gguf-py/scripts/gguf-convert-endian.py +++ b/gguf-py/scripts/gguf-convert-endian.py @@ -9,7 +9,10 @@ import numpy as np # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): +if ( + "NO_LOCAL_GGUF" not in os.environ + and (Path(__file__).parent.parent.parent / "gguf-py").exists() +): sys.path.insert(0, str(Path(__file__).parent.parent)) import gguf @@ -29,7 +32,9 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None else: file_endian = host_endian order = host_endian if args.order == "native" else args.order - print(f"* Host is {host_endian.upper()} endian, GGUF file seems to be {file_endian.upper()} endian") + print( + f"* Host is {host_endian.upper()} endian, GGUF file seems to be {file_endian.upper()} endian" + ) if file_endian == order: print(f"* File is already {order.upper()} endian. Nothing to do.") sys.exit(0) @@ -40,23 +45,33 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.Q8_0, ): - raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}") + raise ValueError( + f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}" + ) print(f"* Preparing to convert from {file_endian.upper()} to {order.upper()}") if args.dry_run: return print("\n*** Warning *** Warning *** Warning **") print("* This conversion process may damage the file. Ensure you have a backup.") if order != host_endian: - print("* Requested endian differs from host, you will not be able to load the model on this machine.") - print("* The file will be modified immediately, so if conversion fails or is interrupted") - print("* the file will be corrupted. Enter exactly YES if you are positive you want to proceed:") + print( + "* Requested endian differs from host, you will not be able to load the model on this machine." + ) + print( + "* The file will be modified immediately, so if conversion fails or is interrupted" + ) + print( + "* the file will be corrupted. Enter exactly YES if you are positive you want to proceed:" + ) response = input("YES, I am sure> ") if response != "YES": print("You didn't enter YES. Okay then, see ya!") sys.exit(0) print(f"\n* Converting fields ({len(reader.fields)})") for idx, field in enumerate(reader.fields.values()): - print(f"- {idx:4}: Converting field {repr(field.name)}, part count: {len(field.parts)}") + print( + f"- {idx:4}: Converting field {repr(field.name)}, part count: {len(field.parts)}" + ) for part in field.parts: part.byteswap(inplace=True) print(f"\n* Converting tensors ({len(reader.tensors)})") @@ -79,7 +94,7 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None for block_num in range(n_blocks): block_offs = block_num * block_size # I know I said f16, but it doesn't matter here - any simple 16 bit type works. - delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16) + delta = tensor.data[block_offs : block_offs + 2].view(dtype=np.uint16) delta.byteswap(inplace=True) if block_num % 100000 == 0: print(f"[{(n_blocks - block_num) // 1000}K]", end="") @@ -91,20 +106,24 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None def main() -> None: parser = argparse.ArgumentParser(description="Convert GGUF file byte order") parser.add_argument( - "model", type=str, + "model", + type=str, help="GGUF format model filename", ) parser.add_argument( - "order", type=str, choices=['big', 'little', 'native'], + "order", + type=str, + choices=["big", "little", "native"], help="Requested byte order", ) parser.add_argument( - "--dry-run", action="store_true", + "--dry-run", + action="store_true", help="Don't actually change anything", ) args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) - print(f'* Loading: {args.model}') - reader = gguf.GGUFReader(args.model, 'r' if args.dry_run else 'r+') + print(f"* Loading: {args.model}") + reader = gguf.GGUFReader(args.model, "r" if args.dry_run else "r+") convert_byteorder(reader, args) diff --git a/gguf-py/scripts/gguf-dump.py b/gguf-py/scripts/gguf-dump.py index 5141873d..49f5380e 100755 --- a/gguf-py/scripts/gguf-dump.py +++ b/gguf-py/scripts/gguf-dump.py @@ -10,16 +10,19 @@ import numpy as np # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): +if ( + "NO_LOCAL_GGUF" not in os.environ + and (Path(__file__).parent.parent.parent / "gguf-py").exists() +): sys.path.insert(0, str(Path(__file__).parent.parent)) from gguf import GGUFReader, GGUFValueType # noqa: E402 def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]: - host_endian = 'LITTLE' if np.uint32(1) == np.uint32(1).newbyteorder("<") else 'BIG' - if reader.byte_order == 'S': - file_endian = 'BIG' if host_endian == 'LITTLE' else 'LITTLE' + host_endian = "LITTLE" if np.uint32(1) == np.uint32(1).newbyteorder("<") else "BIG" + if reader.byte_order == "S": + file_endian = "BIG" if host_endian == "LITTLE" else "LITTLE" else: file_endian = host_endian return (host_endian, file_endian) @@ -29,34 +32,49 @@ def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]: # please see the comments in the modify_gguf.py example. def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None: host_endian, file_endian = get_file_host_endian(reader) - print(f'* File is {file_endian} endian, script is running on a {host_endian} endian host.') - print(f'\n* Dumping {len(reader.fields)} key/value pair(s)') + print( + f"* File is {file_endian} endian, script is running on a {host_endian} endian host." + ) + print(f"\n* Dumping {len(reader.fields)} key/value pair(s)") for n, field in enumerate(reader.fields.values(), 1): if not field.types: - pretty_type = 'N/A' + pretty_type = "N/A" elif field.types[0] == GGUFValueType.ARRAY: nest_count = len(field.types) - 1 - pretty_type = '[' * nest_count + str(field.types[-1].name) + ']' * nest_count + pretty_type = ( + "[" * nest_count + str(field.types[-1].name) + "]" * nest_count + ) else: pretty_type = str(field.types[-1].name) - print(f' {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}', end = '') + print(f" {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}", end="") if len(field.types) == 1: curr_type = field.types[0] if curr_type == GGUFValueType.STRING: - print(' = {0}'.format(repr(str(bytes(field.parts[-1]), encoding='utf8')[:60])), end = '') + print( + " = {0}".format( + repr(str(bytes(field.parts[-1]), encoding="utf8")[:60]) + ), + end="", + ) elif field.types[0] in reader.gguf_scalar_to_np: - print(' = {0}'.format(field.parts[-1][0]), end = '') + print(" = {0}".format(field.parts[-1][0]), end="") print() if args.no_tensors: return - print(f'\n* Dumping {len(reader.tensors)} tensor(s)') + print(f"\n* Dumping {len(reader.tensors)} tensor(s)") for n, tensor in enumerate(reader.tensors, 1): - prettydims = ', '.join('{0:5}'.format(d) for d in list(tensor.shape) + [1] * (4 - len(tensor.shape))) - print(f' {n:5}: {tensor.n_elements:10} | {prettydims} | {tensor.tensor_type.name:7} | {tensor.name}') + prettydims = ", ".join( + "{0:5}".format(d) + for d in list(tensor.shape) + [1] * (4 - len(tensor.shape)) + ) + print( + f" {n:5}: {tensor.n_elements:10} | {prettydims} | {tensor.tensor_type.name:7} | {tensor.name}" + ) def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None: import json + host_endian, file_endian = get_file_host_endian(reader) metadata: dict[str, Any] = {} tensors: dict[str, Any] = {} @@ -69,7 +87,7 @@ def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None: for idx, field in enumerate(reader.fields.values()): curr: dict[str, Any] = { "index": idx, - "type": field.types[0].name if field.types else 'UNKNOWN', + "type": field.types[0].name if field.types else "UNKNOWN", "offset": field.offset, } metadata[field.name] = curr @@ -79,9 +97,13 @@ def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None: continue itype = field.types[-1] if itype == GGUFValueType.STRING: - curr["value"] = [str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data] + curr["value"] = [ + str(bytes(field.parts[idx]), encoding="utf-8") for idx in field.data + ] else: - curr["value"] = [pv for idx in field.data for pv in field.parts[idx].tolist()] + curr["value"] = [ + pv for idx in field.data for pv in field.parts[idx].tolist() + ] elif field.types[0] == GGUFValueType.STRING: curr["value"] = str(bytes(field.parts[-1]), encoding="utf-8") else: @@ -98,19 +120,25 @@ def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None: def main() -> None: parser = argparse.ArgumentParser(description="Dump GGUF file metadata") - parser.add_argument("model", type=str, help="GGUF format model filename") - parser.add_argument("--no-tensors", action="store_true", help="Don't dump tensor metadata") - parser.add_argument("--json", action="store_true", help="Produce JSON output") - parser.add_argument("--json-array", action="store_true", help="Include full array values in JSON output (long)") + parser.add_argument("model", type=str, help="GGUF format model filename") + parser.add_argument( + "--no-tensors", action="store_true", help="Don't dump tensor metadata" + ) + parser.add_argument("--json", action="store_true", help="Produce JSON output") + parser.add_argument( + "--json-array", + action="store_true", + help="Include full array values in JSON output (long)", + ) args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) if not args.json: - print(f'* Loading: {args.model}') - reader = GGUFReader(args.model, 'r') + print(f"* Loading: {args.model}") + reader = GGUFReader(args.model, "r") if args.json: dump_metadata_json(reader, args) else: dump_metadata(reader, args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/gguf-py/scripts/gguf-set-metadata.py b/gguf-py/scripts/gguf-set-metadata.py index 3ebdfa89..07f3e1c1 100755 --- a/gguf-py/scripts/gguf-set-metadata.py +++ b/gguf-py/scripts/gguf-set-metadata.py @@ -5,15 +5,18 @@ from pathlib import Path # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): +if ( + "NO_LOCAL_GGUF" not in os.environ + and (Path(__file__).parent.parent.parent / "gguf-py").exists() +): sys.path.insert(0, str(Path(__file__).parent.parent)) from gguf import GGUFReader # noqa: E402 def minimal_example(filename: str) -> None: - reader = GGUFReader(filename, 'r+') - field = reader.fields['tokenizer.ggml.bos_token_id'] + reader = GGUFReader(filename, "r+") + field = reader.fields["tokenizer.ggml.bos_token_id"] if field is None: return part_index = field.data[0] @@ -41,7 +44,7 @@ def minimal_example(filename: str) -> None: def set_metadata(reader: GGUFReader, args: argparse.Namespace) -> None: field = reader.get_field(args.key) if field is None: - print(f'! Field {repr(args.key)} not found', file = sys.stderr) + print(f"! Field {repr(args.key)} not found", file=sys.stderr) sys.exit(1) # Note that field.types is a list of types. This is because the GGUF # format supports arrays. For example, an array of UINT32 would @@ -49,42 +52,52 @@ def set_metadata(reader: GGUFReader, args: argparse.Namespace) -> None: handler = reader.gguf_scalar_to_np.get(field.types[0]) if field.types else None if handler is None: print( - f'! This tool only supports changing simple values, {repr(args.key)} has unsupported type {field.types}', - file = sys.stderr, + f"! This tool only supports changing simple values, {repr(args.key)} has unsupported type {field.types}", + file=sys.stderr, ) sys.exit(1) current_value = field.parts[field.data[0]][0] new_value = handler(args.value) - print(f'* Preparing to change field {repr(args.key)} from {current_value} to {new_value}') + print( + f"* Preparing to change field {repr(args.key)} from {current_value} to {new_value}" + ) if current_value == new_value: - print(f'- Key {repr(args.key)} already set to requested value {current_value}') + print(f"- Key {repr(args.key)} already set to requested value {current_value}") sys.exit(0) if args.dry_run: sys.exit(0) if not args.force: - print('*** Warning *** Warning *** Warning **') - print('* Changing fields in a GGUF file can make it unusable. Proceed at your own risk.') - print('* Enter exactly YES if you are positive you want to proceed:') - response = input('YES, I am sure> ') - if response != 'YES': + print("*** Warning *** Warning *** Warning **") + print( + "* Changing fields in a GGUF file can make it unusable. Proceed at your own risk." + ) + print("* Enter exactly YES if you are positive you want to proceed:") + response = input("YES, I am sure> ") + if response != "YES": print("You didn't enter YES. Okay then, see ya!") sys.exit(0) field.parts[field.data[0]][0] = new_value - print('* Field changed. Successful completion.') + print("* Field changed. Successful completion.") def main() -> None: - parser = argparse.ArgumentParser(description="Set a simple value in GGUF file metadata") - parser.add_argument("model", type=str, help="GGUF format model filename") - parser.add_argument("key", type=str, help="Metadata key to set") - parser.add_argument("value", type=str, help="Metadata value to set") - parser.add_argument("--dry-run", action="store_true", help="Don't actually change anything") - parser.add_argument("--force", action="store_true", help="Change the field without confirmation") + parser = argparse.ArgumentParser( + description="Set a simple value in GGUF file metadata" + ) + parser.add_argument("model", type=str, help="GGUF format model filename") + parser.add_argument("key", type=str, help="Metadata key to set") + parser.add_argument("value", type=str, help="Metadata value to set") + parser.add_argument( + "--dry-run", action="store_true", help="Don't actually change anything" + ) + parser.add_argument( + "--force", action="store_true", help="Change the field without confirmation" + ) args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) - print(f'* Loading: {args.model}') - reader = GGUFReader(args.model, 'r' if args.dry_run else 'r+') + print(f"* Loading: {args.model}") + reader = GGUFReader(args.model, "r" if args.dry_run else "r+") set_metadata(reader, args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/powerinfer-py/powerinfer/__main__.py b/powerinfer-py/powerinfer/__main__.py index 53fe879f..a4881856 100644 --- a/powerinfer-py/powerinfer/__main__.py +++ b/powerinfer-py/powerinfer/__main__.py @@ -1,4 +1,3 @@ - import argparse from .solver import solve_gpu_split @@ -6,17 +5,55 @@ if __name__ == "__main__": - + # Set up command line arguments - parser = argparse.ArgumentParser(description='Optimize neuron activation based on VRAM capacity and other parameters.') - parser.add_argument('--activation', type=str, required=True, help='Path to the directory containing activation data.') - parser.add_argument('--neuron', type=int, default=8192*4, help='Total number of neurons in the network.') - parser.add_argument('--capacity', type=int, default=int(8192*4*32*0.1), help='Total VRAM capacity for the model.') - parser.add_argument('--layer', type=int, default=59, help='Total number of layers in the neural network.') - parser.add_argument('--vram-capacity', type=int, help='Total VRAM capacity (Bytes) available for splitting') - parser.add_argument('--batch', type=int, default=256, help='Batch size for processing.') - parser.add_argument('--threshold', type=int, default=0, help='Threshold for splitting a layer across multiple GPUs.') - parser.add_argument('--output', type=str, required=True, help='File path for the output pickle file.') + parser = argparse.ArgumentParser( + description="Optimize neuron activation based on VRAM capacity and other parameters." + ) + parser.add_argument( + "--activation", + type=str, + required=True, + help="Path to the directory containing activation data.", + ) + parser.add_argument( + "--neuron", + type=int, + default=8192 * 4, + help="Total number of neurons in the network.", + ) + parser.add_argument( + "--capacity", + type=int, + default=int(8192 * 4 * 32 * 0.1), + help="Total VRAM capacity for the model.", + ) + parser.add_argument( + "--layer", + type=int, + default=59, + help="Total number of layers in the neural network.", + ) + parser.add_argument( + "--vram-capacity", + type=int, + help="Total VRAM capacity (Bytes) available for splitting", + ) + parser.add_argument( + "--batch", type=int, default=256, help="Batch size for processing." + ) + parser.add_argument( + "--threshold", + type=int, + default=0, + help="Threshold for splitting a layer across multiple GPUs.", + ) + parser.add_argument( + "--output", + type=str, + required=True, + help="File path for the output pickle file.", + ) args = parser.parse_args() @@ -37,7 +74,7 @@ activations_path=args.activation, output_path=args.output, solved_list=solved, - vram_capacity=args.vram_capacity + vram_capacity=args.vram_capacity, ) print(f"Exported to {args.output}") diff --git a/powerinfer-py/powerinfer/export_split.py b/powerinfer-py/powerinfer/export_split.py index 7f230d8c..fd72730f 100644 --- a/powerinfer-py/powerinfer/export_split.py +++ b/powerinfer-py/powerinfer/export_split.py @@ -6,13 +6,15 @@ import torch from pathlib import Path import os -if 'NO_LOCAL_GGUF' not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) + +if "NO_LOCAL_GGUF" not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / "gguf-py")) import gguf import struct import numpy as np import re + def load_activation_weights(models_base: Path): # TODO: might need a specification file to indicate which models to load. # But for now, let's assume it is a plain directory of activation_{0, ... , n_layers - 1}.pt @@ -21,6 +23,7 @@ def load_activation_weights(models_base: Path): activation_files.sort() return [torch.load(models_base / f) for f in activation_files] + def append_gpu_idx(gguf: GGUFWriter, i_layer: int, activation, select_count) -> None: _, indices = torch.topk(activation, k=int(select_count)) gpu_idx = torch.zeros_like(activation) @@ -50,8 +53,13 @@ def append_gpu_idx(gguf: GGUFWriter, i_layer: int, activation, select_count) -> raw_dtype=GGMLQuantizationType.I32, ) -def export_split(activations_path: str, output_path: str, solved_list: list[int], vram_capacity: int): - predictors = load_activation_weights(Path(activations_path)) # predictor => activation acount + +def export_split( + activations_path: str, output_path: str, solved_list: list[int], vram_capacity: int +): + predictors = load_activation_weights( + Path(activations_path) + ) # predictor => activation acount gguf_out = GGUFWriter(output_path, "generic.gpu_index") for i, (activation, selected_count) in enumerate(zip(predictors, solved_list)): append_gpu_idx(gguf_out, i, activation, selected_count) @@ -73,4 +81,3 @@ def export_split(activations_path: str, output_path: str, solved_list: list[int] fout.write(struct.pack(" specify additional CLI ars to be passed to the binary (override all preset files). " - "Unknown args will be ignored.") +epilog = ( + " -- specify additional CLI ars to be passed to the binary (override all preset files). " + "Unknown args will be ignored." +) parser = argparse.ArgumentParser( - description=description, usage=usage, epilog=epilog, formatter_class=argparse.RawTextHelpFormatter) + description=description, + usage=usage, + epilog=epilog, + formatter_class=argparse.RawTextHelpFormatter, +) parser.add_argument("-bin", "--binary", help="The binary to run.") -parser.add_argument("yaml_files", nargs="*", - help="Arbitrary number of YAML files from which to read preset values. " - "If two files specify the same values the later one will be used.") +parser.add_argument( + "yaml_files", + nargs="*", + help="Arbitrary number of YAML files from which to read preset values. " + "If two files specify the same values the later one will be used.", +) known_args, unknown_args = parser.parse_known_args() diff --git a/scripts/verify-checksum-models.py b/scripts/verify-checksum-models.py index dff4b473..1f26b4d0 100755 --- a/scripts/verify-checksum-models.py +++ b/scripts/verify-checksum-models.py @@ -9,7 +9,7 @@ def sha256sum(file): b = bytearray(block_size) file_hash = hashlib.sha256() mv = memoryview(b) - with open(file, 'rb', buffering=0) as f: + with open(file, "rb", buffering=0) as f: while True: n = f.readinto(mv) if not n: @@ -65,15 +65,22 @@ def sha256sum(file): file_missing = "X" # Add the results to the array - results.append({ - "filename": filename, - "valid checksum": valid_checksum, - "file missing": file_missing - }) + results.append( + { + "filename": filename, + "valid checksum": valid_checksum, + "file missing": file_missing, + } + ) # Print column headers for results table -print("\n" + "filename".ljust(40) + "valid checksum".center(20) + "file missing".center(20)) +print( + "\n" + + "filename".ljust(40) + + "valid checksum".center(20) + + "file missing".center(20) +) print("-" * 80) # Output the results as a table diff --git a/smallthinker/convert_hf_to_gguf.py b/smallthinker/convert_hf_to_gguf.py index a57b207c..4878795a 100755 --- a/smallthinker/convert_hf_to_gguf.py +++ b/smallthinker/convert_hf_to_gguf.py @@ -14,7 +14,18 @@ from enum import IntEnum from pathlib import Path from hashlib import sha256 -from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ContextManager, + Iterable, + Iterator, + Literal, + Sequence, + TypeVar, + cast, +) from itertools import chain from transformers import AutoConfig @@ -25,8 +36,8 @@ if TYPE_CHECKING: from torch import Tensor -if 'NO_LOCAL_GGUF' not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +if "NO_LOCAL_GGUF" not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / "gguf-py")) import gguf logger = logging.getLogger("hf-to-gguf") @@ -34,6 +45,7 @@ ###### MODEL DEFINITIONS ###### + class SentencePieceTokenTypes(IntEnum): NORMAL = 1 UNKNOWN = 2 @@ -85,22 +97,41 @@ class ModelBase: block_count: int tensor_map: gguf.TensorNameMap - def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False, - use_temp_file: bool = False, eager: bool = False, - metadata_override: Path | None = None, model_name: str | None = None, - split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, - small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None, - transpose_down: str = "none"): - if type(self) is ModelBase or \ - type(self) is TextModel or \ - type(self) is MmprojModel: - raise TypeError(f"{type(self).__name__!r} should not be directly instantiated") + def __init__( + self, + dir_model: Path, + ftype: gguf.LlamaFileType, + fname_out: Path, + *, + is_big_endian: bool = False, + use_temp_file: bool = False, + eager: bool = False, + metadata_override: Path | None = None, + model_name: str | None = None, + split_max_tensors: int = 0, + split_max_size: int = 0, + dry_run: bool = False, + small_first_shard: bool = False, + hparams: dict[str, Any] | None = None, + remote_hf_model_id: str | None = None, + transpose_down: str = "none", + ): + if ( + type(self) is ModelBase + or type(self) is TextModel + or type(self) is MmprojModel + ): + raise TypeError( + f"{type(self).__name__!r} should not be directly instantiated" + ) self.dir_model = dir_model self.ftype = ftype self.fname_out = fname_out self.is_big_endian = is_big_endian - self.endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE + self.endianess = ( + gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE + ) self.use_temp_file = use_temp_file self.lazy = not eager or (remote_hf_model_id is not None) self.remote_hf_model_id = remote_hf_model_id @@ -108,19 +139,36 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, self.is_safetensors = True def get_remote_tensors() -> Iterator[tuple[str, Tensor]]: - logger.info(f"Using remote model with HuggingFace id: {remote_hf_model_id}") - remote_tensors = gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id) + logger.info( + f"Using remote model with HuggingFace id: {remote_hf_model_id}" + ) + remote_tensors = ( + gguf.utility.SafetensorRemote.get_list_tensors_hf_model( + remote_hf_model_id + ) + ) self.tensor_names = set(name for name in remote_tensors.keys()) - for name, remote_tensor in gguf.utility.SafetensorRemote.get_list_tensors_hf_model(remote_hf_model_id).items(): + for ( + name, + remote_tensor, + ) in gguf.utility.SafetensorRemote.get_list_tensors_hf_model( + remote_hf_model_id + ).items(): yield (name, LazyTorchTensor.from_remote_tensor(remote_tensor)) self.get_tensors = get_remote_tensors else: - self.part_names = ModelBase.get_model_part_names(self.dir_model, "model", ".safetensors") + self.part_names = ModelBase.get_model_part_names( + self.dir_model, "model", ".safetensors" + ) self.is_safetensors = len(self.part_names) > 0 if not self.is_safetensors: - self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin") - self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams + self.part_names = ModelBase.get_model_part_names( + self.dir_model, "pytorch_model", ".bin" + ) + self.hparams = ( + ModelBase.load_hparams(self.dir_model) if hparams is None else hparams + ) self.tensor_names = None self.metadata_override = metadata_override self.model_name = model_name @@ -135,15 +183,27 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]: # NOTE: can't use field "torch_dtype" in config.json, because some finetunes lie. _, first_tensor = next(self.get_tensors()) if first_tensor.dtype == torch.float16: - logger.info(f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})") + logger.info( + f"choosing --outtype f16 from first tensor type ({first_tensor.dtype})" + ) self.ftype = gguf.LlamaFileType.MOSTLY_F16 else: - logger.info(f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})") + logger.info( + f"choosing --outtype bf16 from first tensor type ({first_tensor.dtype})" + ) self.ftype = gguf.LlamaFileType.MOSTLY_BF16 # Configure GGUF Writer - self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, - split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard) + self.gguf_writer = gguf.GGUFWriter( + path=None, + arch=gguf.MODEL_ARCH_NAMES[self.model_arch], + endianess=self.endianess, + use_temp_file=self.use_temp_file, + split_max_tensors=split_max_tensors, + split_max_size=split_max_size, + dry_run=dry_run, + small_first_shard=small_first_shard, + ) @classmethod def add_prefix_to_filename(cls, path: Path, prefix: str) -> Path: @@ -184,9 +244,20 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]: ctx: ContextManager[Any] if self.is_safetensors: from safetensors import safe_open - ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu")) + + ctx = cast( + ContextManager[Any], + safe_open(self.dir_model / part_name, framework="pt", device="cpu"), + ) else: - ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True)) + ctx = contextlib.nullcontext( + torch.load( + str(self.dir_model / part_name), + map_location="cpu", + mmap=True, + weights_only=True, + ) + ) with ctx as model_part: tensor_names_from_parts.update(model_part.keys()) @@ -208,25 +279,41 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]: if len(tensor_names_from_parts.symmetric_difference(self.tensor_names)) > 0: missing = sorted(self.tensor_names.difference(tensor_names_from_parts)) extra = sorted(tensor_names_from_parts.difference(self.tensor_names)) - missing_files = sorted(set(weight_map[n] for n in missing if n in weight_map)) + missing_files = sorted( + set(weight_map[n] for n in missing if n in weight_map) + ) if len(extra) == 0 and len(missing_files) > 0: - raise ValueError(f"Missing or incomplete model files: {missing_files}\n" - f"Missing tensors: {missing}") + raise ValueError( + f"Missing or incomplete model files: {missing_files}\n" + f"Missing tensors: {missing}" + ) else: - raise ValueError("Mismatch between weight map and model parts for tensor names:\n" - f"Missing tensors: {missing}\n" - f"Extra tensors: {extra}") - - def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str: + raise ValueError( + "Mismatch between weight map and model parts for tensor names:\n" + f"Missing tensors: {missing}\n" + f"Extra tensors: {extra}" + ) + + def format_tensor_name( + self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight" + ) -> str: if key not in gguf.MODEL_TENSORS[self.model_arch]: - raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}") + raise ValueError( + f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}" + ) name: str = gguf.TENSOR_NAMES[key] if "{bid}" in name: assert bid is not None name = name.format(bid=bid) return name + suffix - def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | None, suffix: str = ".weight") -> bool: + def match_model_tensor_name( + self, + name: str, + key: gguf.MODEL_TENSOR, + bid: int | None, + suffix: str = ".weight", + ) -> bool: if key not in gguf.MODEL_TENSORS[self.model_arch]: return False key_name: str = gguf.TENSOR_NAMES[key] @@ -239,21 +326,29 @@ def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | return False return name == (key_name + suffix) - def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str: + def map_tensor_name( + self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias") + ) -> str: new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes) if new_name is None: raise ValueError(f"Can not map tensor {name!r}") return new_name def set_gguf_parameters(self): - raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses") + raise NotImplementedError( + "set_gguf_parameters() must be implemented in subclasses" + ) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused return [(self.map_tensor_name(name), data_torch)] - def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: + def tensor_force_quant( + self, name: str, new_name: str, bid: int | None, n_dims: int + ) -> gguf.GGMLQuantizationType | bool: del name, new_name, bid, n_dims # unused return False @@ -263,11 +358,17 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: return () def prepare_tensors(self): - max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") + max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len( + ".weight," + ) - for name, data_torch in chain(self.generate_extra_tensors(), self.get_tensors()): + for name, data_torch in chain( + self.generate_extra_tensors(), self.get_tensors() + ): # we don't need these - if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): + if name.endswith( + (".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq") + ): continue # convert any unsupported data types to float32 @@ -281,7 +382,7 @@ def prepare_tensors(self): bid = int(part) break - for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): + for new_name, data_torch in self.modify_tensors(data_torch, name, bid): old_dtype = data_torch.dtype # TODO: why do we squeeze here? @@ -293,7 +394,9 @@ def prepare_tensors(self): data = data_torch.numpy() n_dims = len(data.shape) - data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims) + data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant( + name, new_name, bid, n_dims + ) # Most of the codebase that takes in 1D tensors or norms only handles F32 tensors if n_dims <= 1 or new_name.endswith("_norm.weight"): @@ -356,7 +459,6 @@ def prepare_tensors(self): else: raise ValueError(f"Unknown file type: {self.ftype.name}") - try: data = gguf.quants.quantize(data, data_qtype) except gguf.QuantError as e: @@ -364,13 +466,19 @@ def prepare_tensors(self): data_qtype = gguf.GGMLQuantizationType.F16 data = gguf.quants.quantize(data, data_qtype) - shape = gguf.quant_shape_from_byte_shape(data.shape, data_qtype) if data.dtype == np.uint8 else data.shape + shape = ( + gguf.quant_shape_from_byte_shape(data.shape, data_qtype) + if data.dtype == np.uint8 + else data.shape + ) # reverse shape to make it similar to the internal ggml dimension order shape_str = f"{{{', '.join(str(n) for n in reversed(shape))}}}" # n_dims is implicit in the shape - logger.info(f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + logger.info( + f"{f'%-{max_name_len}s' % f'{new_name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}" + ) self.gguf_writer.add_tensor(new_name, data, raw_dtype=data_qtype) @@ -379,9 +487,13 @@ def set_type(self): def prepare_metadata(self, vocab_only: bool): - total_params, shared_params, expert_params, expert_count = self.gguf_writer.get_total_parameter_count() + total_params, shared_params, expert_params, expert_count = ( + self.gguf_writer.get_total_parameter_count() + ) - self.metadata = gguf.Metadata.load(self.metadata_override, self.dir_model_card, self.model_name, total_params) + self.metadata = gguf.Metadata.load( + self.metadata_override, self.dir_model_card, self.model_name, total_params + ) # If we are using HF model id, set the metadata name to the model id if self.remote_hf_model_id: @@ -393,7 +505,9 @@ def prepare_metadata(self, vocab_only: bool): # Generate parameter weight class (useful for leader boards) if not yet determined if self.metadata.size_label is None and total_params > 0: - self.metadata.size_label = gguf.size_label(total_params, shared_params, expert_params, expert_count) + self.metadata.size_label = gguf.size_label( + total_params, shared_params, expert_params, expert_count + ) self.set_type() @@ -433,7 +547,9 @@ def load_hparams(dir_model: Path): try: # for security reason, we don't allow loading remote code by default # if a model need remote code, we will fallback to config.json - config = AutoConfig.from_pretrained(dir_model, trust_remote_code=False).to_dict() + config = AutoConfig.from_pretrained( + dir_model, trust_remote_code=False + ).to_dict() except Exception as e: logger.warning(f"Failed to load model config from {dir_model}: {e}") logger.warning("Trying to load config.json instead") @@ -452,10 +568,15 @@ def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: assert names def func(modelcls: AnyModel) -> AnyModel: - model_type = ModelType.MMPROJ if modelcls.model_arch == gguf.MODEL_ARCH.MMPROJ else ModelType.TEXT + model_type = ( + ModelType.MMPROJ + if modelcls.model_arch == gguf.MODEL_ARCH.MMPROJ + else ModelType.TEXT + ) for name in names: cls._model_classes[model_type][name] = modelcls return modelcls + return func @classmethod @@ -466,11 +587,13 @@ def print_registered_models(cls): logger.error(f" - {name}") @classmethod - def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type[ModelBase]: + def from_model_architecture( + cls, arch: str, model_type=ModelType.TEXT + ) -> type[ModelBase]: try: return cls._model_classes[model_type][arch] except KeyError: - raise NotImplementedError(f'Architecture {arch!r} not supported!') from None + raise NotImplementedError(f"Architecture {arch!r} not supported!") from None class TextModel(ModelBase): @@ -485,7 +608,9 @@ def __init__(self, *args, **kwargs): # move the text_config to the root level self.hparams = {**self.hparams, **self.hparams["text_config"]} - self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"]) + self.block_count = self.find_hparam( + ["n_layers", "num_hidden_layers", "n_layer", "num_layers"] + ) self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) @classmethod @@ -509,9 +634,25 @@ def prepare_metadata(self, vocab_only: bool): if self.fname_out.is_dir(): # Generate default filename based on model specification and available metadata if not vocab_only: - fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, self.metadata.size_label, output_type, model_type="LoRA" if total_params < 0 else None) + fname_default: str = gguf.naming_convention( + self.metadata.name, + self.metadata.basename, + self.metadata.finetune, + self.metadata.version, + self.metadata.size_label, + output_type, + model_type="LoRA" if total_params < 0 else None, + ) else: - fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=None, model_type="vocab") + fname_default: str = gguf.naming_convention( + self.metadata.name, + self.metadata.basename, + self.metadata.finetune, + self.metadata.version, + size_label=None, + output_type=None, + model_type="vocab", + ) # Use the default filename self.fname_out = self.fname_out / f"{fname_default}.gguf" @@ -521,7 +662,9 @@ def prepare_metadata(self, vocab_only: bool): # file template strings as it doesn't actually exist as a file # Process templated file name with the output ftype, useful with the "auto" ftype - self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type) + self.fname_out = self.fname_out.parent / gguf.fill_templated_filename( + self.fname_out.name, output_type + ) logger.info("Set model tokenizer") self.set_vocab() @@ -529,19 +672,34 @@ def prepare_metadata(self, vocab_only: bool): def set_gguf_parameters(self): self.gguf_writer.add_block_count(self.block_count) - if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions", "max_length"], optional=True)) is not None: + if ( + n_ctx := self.find_hparam( + ["max_position_embeddings", "n_ctx", "n_positions", "max_length"], + optional=True, + ) + ) is not None: self.gguf_writer.add_context_length(n_ctx) logger.info(f"gguf: context length = {n_ctx}") - if (n_embd := self.find_hparam(["hidden_size", "n_embd", "dim"], optional=True)) is not None: + if ( + n_embd := self.find_hparam(["hidden_size", "n_embd", "dim"], optional=True) + ) is not None: self.gguf_writer.add_embedding_length(n_embd) logger.info(f"gguf: embedding length = {n_embd}") - if (n_ff := self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"], optional=True)) is not None: + if ( + n_ff := self.find_hparam( + ["intermediate_size", "n_inner", "hidden_dim"], optional=True + ) + ) is not None: self.gguf_writer.add_feed_forward_length(n_ff) logger.info(f"gguf: feed forward length = {n_ff}") - if (n_head := self.find_hparam(["num_attention_heads", "n_head", "n_heads"], optional=True)) is not None: + if ( + n_head := self.find_hparam( + ["num_attention_heads", "n_head", "n_heads"], optional=True + ) + ) is not None: self.gguf_writer.add_head_count(n_head) logger.info(f"gguf: head count = {n_head}") @@ -555,7 +713,11 @@ def set_gguf_parameters(self): if (f_rms_eps := self.hparams.get("rms_norm_eps")) is not None: self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) logger.info(f"gguf: rms norm epsilon = {f_rms_eps}") - if (f_norm_eps := self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True)) is not None: + if ( + f_norm_eps := self.find_hparam( + ["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon"], optional=True + ) + ) is not None: self.gguf_writer.add_layer_norm_eps(f_norm_eps) logger.info(f"gguf: layer norm epsilon = {f_norm_eps}") if (n_experts := self.hparams.get("num_local_experts")) is not None: @@ -574,7 +736,7 @@ def set_gguf_parameters(self): def write_vocab(self): if len(self.gguf_writer.tensors) != 1: - raise ValueError('Splitting the vocabulary is not supported') + raise ValueError("Splitting the vocabulary is not supported") self.prepare_metadata(vocab_only=True) self.gguf_writer.write_header_to_file(path=self.fname_out) @@ -593,14 +755,22 @@ def does_token_look_special(self, token: str | bytes) -> bool: # (e.g. command-r, command-r-plus, deepseek-coder, gemma{,-2}) seems_special = token_text in ( "", # deepseek-coder - "", "<2mass>", "[@BOS@]", # gemma{,-2} + "", + "<2mass>", + "[@BOS@]", # gemma{,-2} ) - seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>")) - seems_special = seems_special or (token_text.startswith("<|") and token_text.endswith("|>")) # deepseek-coder + seems_special = seems_special or ( + token_text.startswith("<|") and token_text.endswith("|>") + ) + seems_special = seems_special or ( + token_text.startswith("<|") and token_text.endswith("|>") + ) # deepseek-coder # TODO: should these be marked as UNUSED instead? (maybe not) - seems_special = seems_special or (token_text.startswith("")) # gemma{,-2} + seems_special = seems_special or ( + token_text.startswith("") + ) # gemma{,-2} return seems_special @@ -610,13 +780,16 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: toktypes: list[int] = [] from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) assert max(tokenizer.vocab.values()) < vocab_size tokpre = self.get_vocab_base_pre(tokenizer) - reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} + reverse_vocab = { + id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items() + } added_vocab = tokenizer.get_added_vocab() added_tokens_decoder = tokenizer.added_tokens_decoder @@ -632,16 +805,24 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: # To avoid unexpected issues - we make sure to normalize non-normalized tokens if not added_tokens_decoder[i].normalized: previous_token = token - token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) + token = tokenizer.decode( + tokenizer.encode(token, add_special_tokens=False) + ) if previous_token != token: - logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer") + logger.info( + f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer" + ) - if added_tokens_decoder[i].special or self.does_token_look_special(token): + if added_tokens_decoder[i].special or self.does_token_look_special( + token + ): toktypes.append(gguf.TokenType.CONTROL) else: # NOTE: this was added for Gemma. # Encoding and decoding the tokens above isn't sufficient for this case. - token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces + token = token.replace( + b"\xe2\x96\x81".decode("utf-8"), " " + ) # pre-normalize user-defined spaces toktypes.append(gguf.TokenType.USER_DEFINED) else: toktypes.append(gguf.TokenType.NORMAL) @@ -659,7 +840,7 @@ def get_vocab_base_pre(self, tokenizer) -> str: # we will use this unique identifier to write a "tokenizer.ggml.pre" entry in the GGUF file which we can # use in llama.cpp to implement the same pre-tokenizer - chktxt = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶\u200d🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````""""......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL' + chktxt = "\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶\u200d🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````\"\"\"\"......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL" chktok = tokenizer.encode(chktxt) chkhsh = sha256(str(chktok).encode()).hexdigest() @@ -822,18 +1003,32 @@ def get_vocab_base_pre(self, tokenizer) -> str: if res is None: logger.warning("\n") - logger.warning("**************************************************************************************") + logger.warning( + "**************************************************************************************" + ) logger.warning("** WARNING: The BPE pre-tokenizer was not recognized!") logger.warning("** There are 2 possible reasons for this:") - logger.warning("** - the model has not been added to convert_hf_to_gguf_update.py yet") - logger.warning("** - the pre-tokenization config has changed upstream") - logger.warning("** Check your model files and convert_hf_to_gguf_update.py and update them accordingly.") - logger.warning("** ref: https://github.com/ggml-org/llama.cpp/pull/6920") + logger.warning( + "** - the model has not been added to convert_hf_to_gguf_update.py yet" + ) + logger.warning( + "** - the pre-tokenization config has changed upstream" + ) + logger.warning( + "** Check your model files and convert_hf_to_gguf_update.py and update them accordingly." + ) + logger.warning( + "** ref: https://github.com/ggml-org/llama.cpp/pull/6920" + ) logger.warning("**") logger.warning(f"** chkhsh: {chkhsh}") - logger.warning("**************************************************************************************") + logger.warning( + "**************************************************************************************" + ) logger.warning("\n") - raise NotImplementedError("BPE pre-tokenizer was not recognized - update get_vocab_base_pre()") + raise NotImplementedError( + "BPE pre-tokenizer was not recognized - update get_vocab_base_pre()" + ) logger.debug(f"tokenizer.ggml.pre: {repr(res)}") logger.debug(f"chkhsh: {chkhsh}") @@ -861,6 +1056,7 @@ def _set_vocab_qwen(self): toktypes: list[int] = [] from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) vocab_size = hparams["vocab_size"] assert max(tokenizer.get_vocab().values()) < vocab_size @@ -876,11 +1072,13 @@ def _set_vocab_qwen(self): continue merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank) assert len(merged) == 2 - merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged))) + merges.append(" ".join(map(QwenModel.token_bytes_to_string, merged))) # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined added_vocab = tokenizer.special_tokens - reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()} + reverse_vocab = { + id_: encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items() + } for i in range(vocab_size): if i not in reverse_vocab: @@ -902,10 +1100,16 @@ def _set_vocab_qwen(self): special_vocab.merges = merges # only add special tokens when they were not already loaded from config.json if len(special_vocab.special_token_ids) == 0: - special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"]) - special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"]) + special_vocab._set_special_token( + "bos", tokenizer.special_tokens["<|endoftext|>"] + ) + special_vocab._set_special_token( + "eos", tokenizer.special_tokens["<|endoftext|>"] + ) # this one is usually not in config.json anyway - special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"]) + special_vocab._set_special_token( + "unk", tokenizer.special_tokens["<|endoftext|>"] + ) special_vocab.add_to_gguf(self.gguf_writer) def _set_vocab_sentencepiece(self, add_to_gguf=True): @@ -923,7 +1127,7 @@ def _set_vocab_sentencepiece(self, add_to_gguf=True): def _create_vocab_sentencepiece(self): from sentencepiece import SentencePieceProcessor - tokenizer_path = self.dir_model / 'tokenizer.model' + tokenizer_path = self.dir_model / "tokenizer.model" if not tokenizer_path.is_file(): raise FileNotFoundError(f"File not found: {tokenizer_path}") @@ -931,7 +1135,7 @@ def _create_vocab_sentencepiece(self): tokenizer = SentencePieceProcessor() tokenizer.LoadFromFile(str(tokenizer_path)) - vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size()) tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size @@ -956,38 +1160,48 @@ def _create_vocab_sentencepiece(self): scores[token_id] = score toktypes[token_id] = toktype - added_tokens_file = self.dir_model / 'added_tokens.json' + added_tokens_file = self.dir_model / "added_tokens.json" if added_tokens_file.is_file(): with open(added_tokens_file, "r", encoding="utf-8") as f: added_tokens_json = json.load(f) for key in added_tokens_json: token_id = added_tokens_json[key] if token_id >= vocab_size: - logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + logger.warning( + f"ignore token {token_id}: id is out of range, max={vocab_size - 1}" + ) continue tokens[token_id] = key.encode("utf-8") scores[token_id] = -1000.0 toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED - tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + tokenizer_config_file = self.dir_model / "tokenizer_config.json" if tokenizer_config_file.is_file(): with open(tokenizer_config_file, "r", encoding="utf-8") as f: tokenizer_config_json = json.load(f) - added_tokens_decoder = tokenizer_config_json.get("added_tokens_decoder", {}) + added_tokens_decoder = tokenizer_config_json.get( + "added_tokens_decoder", {} + ) for token_id, token_data in added_tokens_decoder.items(): token_id = int(token_id) token: str = token_data["content"] if token_id >= vocab_size: - logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + logger.warning( + f"ignore token {token_id}: id is out of range, max={vocab_size - 1}" + ) continue if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: if tokens[token_id] != token.encode("utf-8"): - logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token!r}') + logger.warning( + f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token!r}' + ) if token_data.get("special") or self.does_token_look_special(token): toktypes[token_id] = SentencePieceTokenTypes.CONTROL else: - token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces + token = token.replace( + b"\xe2\x96\x81".decode("utf-8"), " " + ) # pre-normalize user-defined spaces toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED scores[token_id] = -1000.0 @@ -995,7 +1209,9 @@ def _create_vocab_sentencepiece(self): if vocab_size > len(tokens): pad_count = vocab_size - len(tokens) - logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") + logger.debug( + f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]" + ) for i in range(1, pad_count + 1): tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) scores.append(-1000.0) @@ -1029,15 +1245,19 @@ def _set_vocab_rwkv_world(self): assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file() vocab_size = self.hparams.get("vocab_size", 65536) - tokens: list[bytes] = [''.encode("utf-8")] + tokens: list[bytes] = ["".encode("utf-8")] toktypes: list[int] = [gguf.TokenType.CONTROL] - with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f: + with open( + self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8" + ) as f: lines = f.readlines() for line in lines: - parts = line.split(' ') + parts = line.split(" ") assert len(parts) >= 3 - token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1]) + token, token_len = ast.literal_eval(" ".join(parts[1:-1])), int( + parts[-1] + ) token = token.encode("utf-8") if isinstance(token, str) else token assert isinstance(token, bytes) assert len(token) == token_len @@ -1063,9 +1283,13 @@ def _set_vocab_rwkv_world(self): special_vocab.add_to_gguf(self.gguf_writer) - def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int): + def _set_vocab_builtin( + self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int + ): tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf" - logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'") + logger.warning( + f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'" + ) vocab_reader = gguf.GGUFReader(tokenizer_path, "r") default_pre = "mpt" if model_name == "gpt-neox" else "default" @@ -1075,25 +1299,35 @@ def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]).decode("utf-8")) field = vocab_reader.get_field(gguf.Keys.Tokenizer.PRE) - self.gguf_writer.add_tokenizer_pre(bytes(field.parts[-1]).decode("utf-8") if field else default_pre) + self.gguf_writer.add_tokenizer_pre( + bytes(field.parts[-1]).decode("utf-8") if field else default_pre + ) field = vocab_reader.get_field(gguf.Keys.Tokenizer.LIST) assert field # token list - self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size]) + self.gguf_writer.add_token_list( + [bytes(field.parts[i]) for i in field.data][:vocab_size] + ) if model_name == "llama-spm": field = vocab_reader.get_field(gguf.Keys.Tokenizer.SCORES) assert field # token scores - self.gguf_writer.add_token_scores([field.parts[i].tolist()[0] for i in field.data][:vocab_size]) + self.gguf_writer.add_token_scores( + [field.parts[i].tolist()[0] for i in field.data][:vocab_size] + ) field = vocab_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE) assert field # token types - self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size]) + self.gguf_writer.add_token_types( + [field.parts[i].tolist()[0] for i in field.data][:vocab_size] + ) if model_name != "llama-spm": field = vocab_reader.get_field(gguf.Keys.Tokenizer.MERGES) assert field # token merges - self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data]) + self.gguf_writer.add_token_merges( + [bytes(field.parts[i]) for i in field.data] + ) if (field := vocab_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)) is not None: self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0]) @@ -1122,7 +1356,9 @@ def _try_set_pooling_type(self) -> None: # get pooling type if pooling_path is not None: - with open(self.dir_model / pooling_path / "config.json", encoding="utf-8") as f: + with open( + self.dir_model / pooling_path / "config.json", encoding="utf-8" + ) as f: pooling = json.load(f) if pooling["pooling_mode_mean_tokens"]: pooling_type = gguf.PoolingType.MEAN @@ -1131,7 +1367,9 @@ def _try_set_pooling_type(self) -> None: elif pooling["pooling_mode_lasttoken"]: pooling_type = gguf.PoolingType.LAST else: - raise NotImplementedError("Only MEAN, CLS, and LAST pooling types supported") + raise NotImplementedError( + "Only MEAN, CLS, and LAST pooling types supported" + ) self.gguf_writer.add_pooling_type(pooling_type) @@ -1143,7 +1381,7 @@ class MmprojModel(ModelBase): n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"] - has_vision_encoder: bool = True # by default + has_vision_encoder: bool = True # by default has_audio_encoder: bool = False # for models having multiple encoders, we need to separate their hparams @@ -1154,7 +1392,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) if self.model_arch != gguf.MODEL_ARCH.MMPROJ: - raise TypeError("MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ") + raise TypeError( + "MmprojModel must be subclassed with model_arch = gguf.MODEL_ARCH.MMPROJ" + ) # get n_embd of the text model if "text_config" not in self.hparams: @@ -1167,6 +1407,7 @@ def __init__(self, *args, **kwargs): # move vision config to the top level, while preserving the original hparams in global_config import copy + self.global_config = copy.deepcopy(self.hparams) self.hparams_vision = self.get_vision_config() self.hparams_audio = self.get_audio_config() @@ -1179,11 +1420,17 @@ def __init__(self, *args, **kwargs): # TODO @ngxson : this is a hack to support both vision and audio encoders have_multiple_encoders = self.has_audio_encoder and self.has_vision_encoder - self.block_count = 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True) - self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.MMPROJ, self.block_count) + self.block_count = ( + 128 if have_multiple_encoders else self.find_hparam(self.n_block_keys, True) + ) + self.tensor_map = gguf.get_tensor_name_map( + gguf.MODEL_ARCH.MMPROJ, self.block_count + ) # load preprocessor config - with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f: + with open( + self.dir_model / "preprocessor_config.json", "r", encoding="utf-8" + ) as f: self.preprocessor_config = json.load(f) def get_vision_config(self) -> dict[str, Any] | None: @@ -1205,13 +1452,21 @@ def set_gguf_parameters(self): # vision config self.gguf_writer.add_vision_image_size(self.find_vparam(["image_size"])) self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"])) - self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"])) - self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"])) + self.gguf_writer.add_vision_embedding_length( + self.find_vparam(["hidden_size"]) + ) + self.gguf_writer.add_vision_feed_forward_length( + self.find_vparam(["intermediate_size"]) + ) self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys)) - self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads"])) + self.gguf_writer.add_vision_head_count( + self.find_vparam(["num_attention_heads"]) + ) # preprocessor config - self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"]) + self.gguf_writer.add_vision_image_mean( + self.preprocessor_config["image_mean"] + ) self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"]) if self.has_audio_encoder: @@ -1219,10 +1474,16 @@ def set_gguf_parameters(self): self.gguf_writer.add_audio_projection_dim(self.n_embd_text) # audio config - self.gguf_writer.add_audio_embedding_length(self.find_aparam(["hidden_size"])) - self.gguf_writer.add_audio_feed_forward_length(self.find_aparam(["intermediate_size"])) + self.gguf_writer.add_audio_embedding_length( + self.find_aparam(["hidden_size"]) + ) + self.gguf_writer.add_audio_feed_forward_length( + self.find_aparam(["intermediate_size"]) + ) self.gguf_writer.add_audio_block_count(self.find_aparam(self.n_block_keys)) - self.gguf_writer.add_audio_head_count(self.find_aparam(["num_attention_heads"])) + self.gguf_writer.add_audio_head_count( + self.find_aparam(["num_attention_heads"]) + ) if not self.has_vision_encoder and not self.has_audio_encoder: raise ValueError("MmprojModel must have either vision or audio encoder") @@ -1238,7 +1499,9 @@ def find_aparam(self, keys: Iterable[str], optional: bool = False) -> Any: assert self.hparams_audio is not None return self._find_param(self.hparams_audio, keys, optional) - def _find_param(self, obj: dict[str, Any], keys: Iterable[str], optional: bool = False) -> Any: + def _find_param( + self, obj: dict[str, Any], keys: Iterable[str], optional: bool = False + ) -> Any: key = next((k for k in keys if k in obj), None) if key is not None: return obj[key] @@ -1259,13 +1522,20 @@ def set_gguf_parameters(self): self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_dimension_count( - int(self.hparams["rotary_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])), + int( + self.hparams["rotary_pct"] + * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + ), ) self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) - self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True)) + self.gguf_writer.add_parallel_residual( + self.hparams.get("use_parallel_residual", True) + ) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"]) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) @@ -1320,13 +1590,15 @@ def set_gguf_parameters(self): self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_file_type(self.ftype) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) - name = re.sub(r'transformer\.', '', name) + name = re.sub(r"transformer\.", "", name) tensors: list[tuple[str, Tensor]] = [] @@ -1389,15 +1661,21 @@ def set_gguf_parameters(self): if self.hparams["attn_config"]["clip_qkv"] is not None: self.gguf_writer.add_clamp_kqv(self.hparams["attn_config"]["clip_qkv"]) if self.hparams["attn_config"]["alibi"]: - self.gguf_writer.add_max_alibi_bias(self.hparams["attn_config"]["alibi_bias_max"]) + self.gguf_writer.add_max_alibi_bias( + self.hparams["attn_config"]["alibi_bias_max"] + ) else: self.gguf_writer.add_max_alibi_bias(0.0) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused if "scales" in name: - new_name = self.map_tensor_name(name, try_suffixes=(".weight", ".bias", ".scales")) + new_name = self.map_tensor_name( + name, try_suffixes=(".weight", ".bias", ".scales") + ) new_name = new_name.replace("scales", "act.scales") else: new_name = self.map_tensor_name(name, try_suffixes=(".weight", ".bias")) @@ -1467,18 +1745,25 @@ def set_gguf_parameters(self): self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) - self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_rope_dimension_count( + self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + ) self.gguf_writer.add_head_count(head_count) self.gguf_writer.add_head_count_kv(head_count_kv) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) self.gguf_writer.add_file_type(self.ftype) rope_scaling = self.hparams.get("rope_scaling") or {} - if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + if ( + rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" + and "factor" in rope_scaling + ): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: head_count = self.hparams["num_attention_heads"] head_count_kv = self.hparams.get("num_key_value_heads", head_count) @@ -1487,37 +1772,57 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if bid is not None and name == f"model.layers.{bid}.self_attn.W_pack.weight": logger.info(f"Unpacking and permuting layer {bid}") tensors = [ - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), - self._reverse_hf_permute_part(data_torch, 0, head_count, head_count)), - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), - self._reverse_hf_permute_part(data_torch, 1, head_count, head_count_kv)), - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), - self._reverse_hf_part(data_torch, 2)), + ( + self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), + self._reverse_hf_permute_part( + data_torch, 0, head_count, head_count + ), + ), + ( + self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), + self._reverse_hf_permute_part( + data_torch, 1, head_count, head_count_kv + ), + ), + ( + self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), + self._reverse_hf_part(data_torch, 2), + ), ] else: tensors = [(self.map_tensor_name(name), data_torch)] return tensors - def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: + def _reverse_hf_permute( + self, weights: Tensor, n_head: int, n_kv_head: int | None = None + ) -> Tensor: if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head return ( - weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + weights.reshape( + n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:] + ) .swapaxes(1, 2) .reshape(weights.shape) ) def _reverse_hf_permute_part( - self, weights: Tensor, n_part: int, n_head: int, n_head_kv: int | None = None, + self, + weights: Tensor, + n_part: int, + n_head: int, + n_head_kv: int | None = None, ) -> Tensor: r = weights.shape[0] // 3 - return self._reverse_hf_permute(weights[r * n_part:r * n_part + r, ...], n_head, n_head_kv) + return self._reverse_hf_permute( + weights[r * n_part : r * n_part + r, ...], n_head, n_head_kv + ) def _reverse_hf_part(self, weights: Tensor, n_part: int) -> Tensor: r = weights.shape[0] // 3 - return weights[r * n_part:r * n_part + r, ...] + return weights[r * n_part : r * n_part + r, ...] @ModelBase.register("XverseForCausalLM") @@ -1533,6 +1838,7 @@ def set_vocab(self): toktypes: list[int] = [] from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model) vocab_size = hparams.get("vocab_size", len(tokenizer.vocab)) # Since we are checking the maximum index, we need to ensure it's strictly less than vocab_size, @@ -1541,16 +1847,18 @@ def set_vocab(self): if max_vocab_index >= vocab_size: raise ValueError("Vocabulary size exceeds expected maximum size.") - reverse_vocab: dict[int, str] = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} + reverse_vocab: dict[int, str] = { + id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items() + } added_vocab = tokenizer.get_added_vocab() for token_id in range(vocab_size): - token_text = reverse_vocab[token_id].encode('utf-8') + token_text = reverse_vocab[token_id].encode("utf-8") # replace "\x00" to string with length > 0 if token_text == b"\x00": toktype = gguf.TokenType.BYTE # special - token_text = f"<{token_text}>".encode('utf-8') - elif re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text): + token_text = f"<{token_text}>".encode("utf-8") + elif re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text): toktype = gguf.TokenType.BYTE # special elif reverse_vocab[token_id] in added_vocab: if tokenizer.added_tokens_decoder[token_id].special: @@ -1591,18 +1899,25 @@ def set_gguf_parameters(self): self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) - self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_rope_dimension_count( + self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + ) self.gguf_writer.add_head_count(head_count) self.gguf_writer.add_head_count_kv(head_count_kv) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) self.gguf_writer.add_file_type(self.ftype) rope_scaling = self.hparams.get("rope_scaling") or {} - if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + if ( + rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" + and "factor" in rope_scaling + ): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused head_count = self.hparams["num_attention_heads"] @@ -1616,12 +1931,16 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] - def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: + def _reverse_hf_permute( + self, weights: Tensor, n_head: int, n_kv_head: int | None = None + ) -> Tensor: if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head return ( - weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + weights.reshape( + n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:] + ) .swapaxes(1, 2) .reshape(weights.shape) ) @@ -1654,7 +1973,9 @@ def set_gguf_parameters(self): self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_file_type(self.ftype) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused # QKV tensor transform @@ -1669,10 +1990,14 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if "query_key_value" in name: n_head = self.find_hparam(["num_attention_heads", "n_head"]) - n_head_kv = self.find_hparam(["num_kv_heads", "n_head_kv"], optional=True) or 1 + n_head_kv = ( + self.find_hparam(["num_kv_heads", "n_head_kv"], optional=True) or 1 + ) head_dim = self.hparams["hidden_size"] // n_head - qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head) + qkv = data_torch.view( + n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head + ) q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head) k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head) v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head) @@ -1706,8 +2031,11 @@ def set_vocab(self): super().set_vocab() # TODO: how to determine special FIM tokens automatically? - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False, - special_token_types = ['prefix', 'suffix', 'middle', 'eot']) + special_vocab = gguf.SpecialVocab( + self.dir_model, + load_merges=False, + special_token_types=["prefix", "suffix", "middle", "eot"], + ) special_vocab._set_special_token("prefix", 1) special_vocab._set_special_token("suffix", 3) special_vocab._set_special_token("middle", 2) @@ -1734,7 +2062,9 @@ def set_gguf_parameters(self): self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_file_type(self.ftype) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: hidden_dim = self.hparams["n_embd"] inner_dim = 4 * hidden_dim hidden_dim = int(2 * inner_dim / 3) @@ -1748,13 +2078,35 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if bid is not None: if name == f"transformer.h.{bid}.attn.kv.weight": - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), data_torch[:n_head_kv * head_dim])) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), data_torch[n_head_kv * head_dim:])) + tensors.append( + ( + self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), + data_torch[: n_head_kv * head_dim], + ) + ) + tensors.append( + ( + self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), + data_torch[n_head_kv * head_dim :], + ) + ) elif name == f"transformer.h.{bid}.attn.q.weight": - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), data_torch)) + tensors.append( + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), data_torch) + ) elif name == f"transformer.h.{bid}.mlp.gate_up_proj.weight": - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), data_torch[:ff_dim])) - tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), data_torch[ff_dim:])) + tensors.append( + ( + self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), + data_torch[:ff_dim], + ) + ) + tensors.append( + ( + self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), + data_torch[ff_dim:], + ) + ) if len(tensors) == 0: tensors.append((self.map_tensor_name(name), data_torch)) @@ -1762,7 +2114,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return tensors -@ModelBase.register("StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM") +@ModelBase.register( + "StableLmForCausalLM", "StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM" +) class StableLMModel(TextModel): model_arch = gguf.MODEL_ARCH.STABLELM @@ -1782,17 +2136,30 @@ def set_gguf_parameters(self): self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"]) - self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) + self.gguf_writer.add_rope_dimension_count( + int( + rotary_factor + * (hparams["hidden_size"] // hparams["num_attention_heads"]) + ) + ) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"]) - self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True) - self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_eps", "norm_eps"])) + self.gguf_writer.add_parallel_residual( + hparams["use_parallel_residual"] + if "use_parallel_residual" in hparams + else True + ) + self.gguf_writer.add_layer_norm_eps( + self.find_hparam(["layer_norm_eps", "norm_eps"]) + ) self.gguf_writer.add_file_type(self.ftype) _q_norms: list[dict[str, Tensor]] | None = None _k_norms: list[dict[str, Tensor]] | None = None - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams["num_key_value_heads"] @@ -1805,7 +2172,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._q_norms[bid][name] = data_torch if len(self._q_norms[bid]) >= n_head: - return self._stack_qk_norm(bid, n_head, self._q_norms[bid], "q_layernorm") + return self._stack_qk_norm( + bid, n_head, self._q_norms[bid], "q_layernorm" + ) else: return [] @@ -1818,13 +2187,21 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self._k_norms[bid][name] = data_torch if len(self._k_norms[bid]) >= n_kv_head: - return self._stack_qk_norm(bid, n_kv_head, self._k_norms[bid], "k_layernorm") + return self._stack_qk_norm( + bid, n_kv_head, self._k_norms[bid], "k_layernorm" + ) else: return [] return [(self.map_tensor_name(name), data_torch)] - def _stack_qk_norm(self, bid: int, n_head: int, norms: dict[str, Tensor], layer_name: str = "q_layernorm"): + def _stack_qk_norm( + self, + bid: int, + n_head: int, + norms: dict[str, Tensor], + layer_name: str = "q_layernorm", + ): datas: list[Tensor] = [] # extract the norms in order for xid in range(n_head): @@ -1844,9 +2221,13 @@ def prepare_tensors(self): if self._q_norms is not None or self._k_norms is not None: # flatten two `list[dict[str, Tensor]]` into a single `list[str]` norms = ( - [k for d in self._q_norms for k in d.keys()] if self._q_norms is not None else [] + [k for d in self._q_norms for k in d.keys()] + if self._q_norms is not None + else [] ) + ( - [k for d in self._k_norms for k in d.keys()] if self._k_norms is not None else [] + [k for d in self._k_norms for k in d.keys()] + if self._k_norms is not None + else [] ) if len(norms) > 0: raise ValueError(f"Unprocessed norms: {norms}") @@ -1859,7 +2240,8 @@ def prepare_tensors(self): "MixtralForCausalLM", "VLlama3ForCausalLM", "LlavaForConditionalGeneration", - "LlamaModel") + "LlamaModel", +) class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA undo_permute = True @@ -1868,7 +2250,9 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # fix for SmolVLM2, missing `num_attention_heads` in config.json if self.hf_arch == "VLlama3ForCausalLM": - self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32) + self.hparams["num_attention_heads"] = self.hparams.get( + "num_attention_heads", 32 + ) def set_vocab(self): try: @@ -1883,21 +2267,24 @@ def set_vocab(self): # Apply to CodeLlama only (and ignore for Llama 3 with a vocab size of 128256) if self.hparams.get("vocab_size", 32000) == 32016: special_vocab = gguf.SpecialVocab( - self.dir_model, load_merges=False, - special_token_types = ['prefix', 'suffix', 'middle', 'eot'] + self.dir_model, + load_merges=False, + special_token_types=["prefix", "suffix", "middle", "eot"], ) special_vocab._set_special_token("prefix", 32007) special_vocab._set_special_token("suffix", 32008) special_vocab._set_special_token("middle", 32009) - special_vocab._set_special_token("eot", 32010) + special_vocab._set_special_token("eot", 32010) special_vocab.add_to_gguf(self.gguf_writer) - tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + tokenizer_config_file = self.dir_model / "tokenizer_config.json" if tokenizer_config_file.is_file(): with open(tokenizer_config_file, "r", encoding="utf-8") as f: tokenizer_config_json = json.load(f) if "add_prefix_space" in tokenizer_config_json: - self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) + self.gguf_writer.add_add_space_prefix( + tokenizer_config_json["add_prefix_space"] + ) # Apply to granite small models only if self.hparams.get("vocab_size", 32000) == 49152: @@ -1913,7 +2300,10 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_dimension_count(rope_dim) rope_scaling = self.hparams.get("rope_scaling") or {} - if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + if ( + rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" + and "factor" in rope_scaling + ): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) @@ -1921,28 +2311,36 @@ def set_gguf_parameters(self): def permute(weights: Tensor, n_head: int, n_head_kv: int | None): if n_head_kv is not None and n_head != n_head_kv: n_head = n_head_kv - return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape)) + return ( + weights.reshape( + n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:] + ) + .swapaxes(1, 2) + .reshape(weights.shape) + ) _experts: list[dict[str, Tensor]] | None = None - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") - is_vision_tensor = "vision_tower" in name \ - or "vision_model" in name \ - or "model.connector" in name \ + is_vision_tensor = ( + "vision_tower" in name + or "vision_model" in name + or "model.connector" in name or "multi_modal_projector" in name + ) if is_vision_tensor: - return [] # skip vision tensors + return [] # skip vision tensors elif self.hf_arch == "LlamaModel": name = "model." + name elif name.startswith("model.text_model"): - name = name.replace("text_model.", "") # for SmolVLM + name = name.replace("text_model.", "") # for SmolVLM elif name.startswith("language_model."): - name = name.replace("language_model.", "") # for the rest + name = name.replace("language_model.", "") # for the rest if self.undo_permute: if name.endswith(("q_proj.weight", "q_proj.bias")): @@ -1988,16 +2386,23 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): - if rope_scaling.get("rope_type", '').lower() == "llama3": + if rope_scaling.get("rope_type", "").lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) if (dim := self.hparams.get("head_dim")) is None: - dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] - freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + dim = ( + self.hparams["hidden_size"] + // self.hparams["num_attention_heads"] + ) + freqs = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim) + ) factor = rope_scaling.get("factor", 8.0) low_freq_factor = rope_scaling.get("low_freq_factor", 1.0) high_freq_factor = rope_scaling.get("high_freq_factor", 4.0) - old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + old_context_len = self.hparams.get( + "original_max_position_embeddings", 8192 + ) low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor @@ -2011,10 +2416,15 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: elif wavelen > low_freq_wavelen: rope_factors.append(factor) else: - smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) rope_factors.append(1 / ((1 - smooth) / factor + smooth)) - yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + yield ( + self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), + torch.tensor(rope_factors, dtype=torch.float32), + ) def prepare_tensors(self): super().prepare_tensors() @@ -2025,9 +2435,10 @@ def prepare_tensors(self): if len(experts) > 0: raise ValueError(f"Unprocessed experts: {experts}") + @ModelBase.register( - "LlavaForConditionalGeneration", # pixtral - "Mistral3ForConditionalGeneration", # mistral small 3.1 + "LlavaForConditionalGeneration", # pixtral + "Mistral3ForConditionalGeneration", # mistral small 3.1 ) class LlavaVisionModel(MmprojModel): img_break_tok_id = -1 @@ -2043,9 +2454,9 @@ def __init__(self, *args, **kwargs): raise ValueError(f"Unsupported model type: {self.hparams['model_type']}") def get_token_id(self, token: str) -> int: - tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + tokenizer_config_file = self.dir_model / "tokenizer_config.json" with open(tokenizer_config_file, "r", encoding="utf-8") as f: - added_tokens_decoder = json.load(f)['added_tokens_decoder'] + added_tokens_decoder = json.load(f)["added_tokens_decoder"] for id_, token_data in added_tokens_decoder.items(): if token_data["content"] == token: return int(id_) @@ -2056,7 +2467,9 @@ def set_gguf_parameters(self): hparams = self.hparams if hparams["model_type"] == "pixtral": self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.PIXTRAL) - self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"]) + self.gguf_writer.add_vision_attention_layernorm_eps( + hparams["layer_norm_eps"] + ) # hidden_act if hparams["hidden_act"] == "silu": @@ -2068,14 +2481,20 @@ def set_gguf_parameters(self): # spatial_merge_size if "spatial_merge_size" in self.global_config: - self.gguf_writer.add_vision_spatial_merge_size(self.global_config["spatial_merge_size"]) + self.gguf_writer.add_vision_spatial_merge_size( + self.global_config["spatial_merge_size"] + ) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused n_head = self.hparams["num_attention_heads"] n_kv_head = n_head - if name.startswith("multi_modal_projector.") or name.startswith("vision_tower."): + if name.startswith("multi_modal_projector.") or name.startswith( + "vision_tower." + ): # process vision tensors if name.endswith(("q_proj.weight", "q_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_head) @@ -2090,10 +2509,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK] return [(self.map_tensor_name(name), img_break_embd)] - return [] # skip other tensors + return [] # skip other tensors -@ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration") +@ModelBase.register( + "Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration" +) class SmolVLMModel(MmprojModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -2101,14 +2522,22 @@ def __init__(self, *args, **kwargs): # fix for SmolVLM2, missing some keys in config.json # default values are taken from transformers code self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152) - self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16) - self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072) + self.hparams["num_attention_heads"] = self.hparams.get( + "num_attention_heads", 16 + ) + self.hparams["intermediate_size"] = self.hparams.get( + "intermediate_size", 3072 + ) def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.IDEFICS3) - self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5)) - self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2)) + self.gguf_writer.add_vision_attention_layernorm_eps( + self.hparams.get("layer_norm_eps", 1e-5) + ) + self.gguf_writer.add_vision_projector_scale_factor( + self.global_config.get("scale_factor", 2) + ) self.gguf_writer.add_vision_use_gelu(True) def tensor_force_quant(self, name, new_name, bid, n_dims): @@ -2117,14 +2546,20 @@ def tensor_force_quant(self, name, new_name, bid, n_dims): return gguf.GGMLQuantizationType.F32 return False - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused - is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name + is_vision_tensor = ( + "vision_tower" in name + or "vision_model" in name + or "model.connector" in name + ) if is_vision_tensor: return [(self.map_tensor_name(name), data_torch)] - return [] # skip other tensors + return [] # skip other tensors @ModelBase.register("Llama4ForConditionalGeneration") @@ -2143,8 +2578,12 @@ def set_vocab(self): def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_interleave_moe_layer_step(self.hparams["interleave_moe_layer_step"]) - self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size_moe"]) + self.gguf_writer.add_interleave_moe_layer_step( + self.hparams["interleave_moe_layer_step"] + ) + self.gguf_writer.add_expert_feed_forward_length( + self.hparams["intermediate_size_moe"] + ) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): if name.startswith("language_model."): @@ -2155,10 +2594,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None): name_up = name.replace("gate_up_proj", "up_proj.weight") name_gate = name.replace("gate_up_proj", "gate_proj.weight") dim_half = data_torch.shape[-1] // 2 - gate_proj_weight, up_proj_weight = data_torch.transpose(-1, -2).split(dim_half, dim=-2) + gate_proj_weight, up_proj_weight = data_torch.transpose(-1, -2).split( + dim_half, dim=-2 + ) return [ (self.map_tensor_name(name_gate), gate_proj_weight), - (self.map_tensor_name(name_up), up_proj_weight) + (self.map_tensor_name(name_up), up_proj_weight), ] if name.endswith("down_proj"): @@ -2176,19 +2617,28 @@ def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LLAMA4) self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"]) - self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"])) + self.gguf_writer.add_vision_projector_scale_factor( + int(1.0 / self.hparams["pixel_shuffle_ratio"]) + ) assert self.hparams["hidden_act"] == "gelu" self.gguf_writer.add_vision_use_gelu(True) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: + del bid # unused if "multi_modal_projector" in name or "vision_model" in name: # process vision tensors if "positional_embedding_vlm" in name and ".weight" not in name: name += ".weight" if "multi_modal_projector.linear_1" in name: # despite the name with number postfix, this is a single fully connected layer - return [(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC] + '.weight', data_torch)] + return [ + ( + gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_MMPROJ_FC] + ".weight", + data_torch, + ) + ] return [(self.map_tensor_name(name), data_torch)] return [] @@ -2224,8 +2674,8 @@ def _find_multiple(n: int, k: int) -> int: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B - _block_configs: list[dict[str,Any]] = self.hparams["block_configs"] + if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B + _block_configs: list[dict[str, Any]] = self.hparams["block_configs"] assert self.block_count == len(_block_configs) self._num_kv_heads = list() self._num_heads = list() @@ -2252,20 +2702,31 @@ def __init__(self, *args, **kwargs): self._num_kv_heads.append(0) self._num_heads.append(0) else: - self._num_kv_heads.append(self.hparams["num_attention_heads"] // _block_configs[il]["attention"]["n_heads_in_group"]) + self._num_kv_heads.append( + self.hparams["num_attention_heads"] + // _block_configs[il]["attention"]["n_heads_in_group"] + ) self._num_heads.append(self.hparams["num_attention_heads"]) - if _block_configs[il]["ffn"]["ffn_mult"] is None: # dummy layer + if _block_configs[il]["ffn"]["ffn_mult"] is None: # dummy layer _ffn_multipliers.append(0.0) else: _ffn_multipliers.append(_block_configs[il]["ffn"]["ffn_mult"]) assert self.block_count == len(self._num_kv_heads) assert self.block_count == len(self._num_heads) assert self.block_count == len(_ffn_multipliers) - assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int) - assert isinstance(self._num_heads, list) and isinstance(self._num_heads[0], int) - assert isinstance(_ffn_multipliers, list) and isinstance(_ffn_multipliers[0], float) + assert isinstance(self._num_kv_heads, list) and isinstance( + self._num_kv_heads[0], int + ) + assert isinstance(self._num_heads, list) and isinstance( + self._num_heads[0], int + ) + assert isinstance(_ffn_multipliers, list) and isinstance( + _ffn_multipliers[0], float + ) self._ffn_dims: list[int] = [ - DeciModel._ffn_mult_to_intermediate_size(multiplier, self.hparams["hidden_size"]) + DeciModel._ffn_mult_to_intermediate_size( + multiplier, self.hparams["hidden_size"] + ) for multiplier in _ffn_multipliers ] @@ -2286,7 +2747,7 @@ def set_vocab(self): self._set_vocab_llama_hf() def set_gguf_parameters(self): - if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B + if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B assert self.block_count == len(self._num_kv_heads) assert self.block_count == len(self._num_heads) assert self.block_count == len(self._ffn_dims) @@ -2299,13 +2760,19 @@ def set_gguf_parameters(self): self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"]) self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) - self.gguf_writer.add_key_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) - self.gguf_writer.add_value_length(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_key_length( + self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + ) + self.gguf_writer.add_value_length( + self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + ) self.gguf_writer.add_file_type(self.ftype) - else: # DeciLM-7B + else: # DeciLM-7B super().set_gguf_parameters() - if "num_key_value_heads_per_layer" in self.hparams: # DeciLM-7B - self._num_kv_heads: list[int] = self.hparams["num_key_value_heads_per_layer"] + if "num_key_value_heads_per_layer" in self.hparams: # DeciLM-7B + self._num_kv_heads: list[int] = self.hparams[ + "num_key_value_heads_per_layer" + ] assert self.block_count == len(self._num_kv_heads) self.gguf_writer.add_head_count_kv(self._num_kv_heads) hparams = self.hparams @@ -2316,7 +2783,10 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_dimension_count(rope_dim) rope_scaling = self.hparams.get("rope_scaling") or {} - if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + if ( + rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" + and "factor" in rope_scaling + ): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) @@ -2324,11 +2794,17 @@ def set_gguf_parameters(self): def permute(weights: Tensor, n_head: int, n_head_kv: int | None): if n_head_kv is not None and n_head != n_head_kv: n_head = n_head_kv - return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape)) + return ( + weights.reshape( + n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:] + ) + .swapaxes(1, 2) + .reshape(weights.shape) + ) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] if bid is not None: if "num_key_value_heads_per_layer" in self.hparams: @@ -2349,16 +2825,23 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): - if rope_scaling.get("rope_type", '').lower() == "llama3": + if rope_scaling.get("rope_type", "").lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) if (dim := self.hparams.get("head_dim")) is None: - dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] - freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + dim = ( + self.hparams["hidden_size"] + // self.hparams["num_attention_heads"] + ) + freqs = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim) + ) factor = rope_scaling.get("factor", 8.0) low_freq_factor = rope_scaling.get("low_freq_factor", 1.0) high_freq_factor = rope_scaling.get("high_freq_factor", 4.0) - old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + old_context_len = self.hparams.get( + "original_max_position_embeddings", 8192 + ) low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor @@ -2372,10 +2855,15 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: elif wavelen > low_freq_wavelen: rope_factors.append(factor) else: - smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) rope_factors.append(1 / ((1 - smooth) / factor + smooth)) - yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + yield ( + self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), + torch.tensor(rope_factors, dtype=torch.float32), + ) def prepare_tensors(self): super().prepare_tensors() @@ -2404,18 +2892,23 @@ def weight_quant(self, weight: Tensor) -> Tensor: result = (weight * iscale).round().clamp(-1, 1) / iscale return result.type(dtype) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: new_name = self.map_tensor_name(name) - if any(self.match_model_tensor_name(new_name, key, bid) for key in [ - gguf.MODEL_TENSOR.ATTN_Q, - gguf.MODEL_TENSOR.ATTN_K, - gguf.MODEL_TENSOR.ATTN_V, - gguf.MODEL_TENSOR.ATTN_OUT, - gguf.MODEL_TENSOR.FFN_UP, - gguf.MODEL_TENSOR.FFN_DOWN, - gguf.MODEL_TENSOR.FFN_GATE, - ]): + if any( + self.match_model_tensor_name(new_name, key, bid) + for key in [ + gguf.MODEL_TENSOR.ATTN_Q, + gguf.MODEL_TENSOR.ATTN_K, + gguf.MODEL_TENSOR.ATTN_V, + gguf.MODEL_TENSOR.ATTN_OUT, + gguf.MODEL_TENSOR.FFN_UP, + gguf.MODEL_TENSOR.FFN_DOWN, + gguf.MODEL_TENSOR.FFN_GATE, + ] + ): # transform weight into 1/0/-1 (in fp32) data_torch = self.weight_quant(data_torch) @@ -2437,7 +2930,9 @@ def set_gguf_parameters(self): _experts: list[dict[str, Tensor]] | None = None - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: # process the experts separately if name.find(".moe.") != -1: n_experts = self.hparams["num_local_experts"] @@ -2457,7 +2952,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter datas: list[Tensor] = [] for xid in range(n_experts): - ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight" + ename = ( + f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight" + ) datas.append(self._experts[bid][ename]) del self._experts[bid][ename] @@ -2503,7 +3000,9 @@ def set_gguf_parameters(self): self.gguf_writer.add_file_type(self.ftype) logger.info(f"gguf: file type = {self.ftype}") - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused n_expert = self.hparams["ffn_config"]["moe_num_experts"] @@ -2515,9 +3014,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # But llama.cpp moe graph works differently # AND the dimensions in ggml are typically in the reverse order of the pytorch dimensions # so (n_expert, n_ff, n_embd) in pytorch is {n_embd, n_ff, n_expert} in ggml_tensor - exp_tensor_names = {"ffn.experts.mlp.w1": None, # LLM_TENSOR_FFN_GATE_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert} - "ffn.experts.mlp.w2": (0, 2, 1), # LLM_TENSOR_FFN_DOWN_EXPS ggml_tensor->ne{n_ff, n_embd, n_expert} - "ffn.experts.mlp.v1": None} # LLM_TENSOR_FFN_UP_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert} + exp_tensor_names = { + "ffn.experts.mlp.w1": None, # LLM_TENSOR_FFN_GATE_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert} + "ffn.experts.mlp.w2": ( + 0, + 2, + 1, + ), # LLM_TENSOR_FFN_DOWN_EXPS ggml_tensor->ne{n_ff, n_embd, n_expert} + "ffn.experts.mlp.v1": None, + } # LLM_TENSOR_FFN_UP_EXPS ggml_tensor->ne{n_embd, n_ff, n_expert} experts = False for exp_tensor_name in exp_tensor_names.keys(): @@ -2534,11 +3039,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # Every other model has the weight names ending in .weight, # let's assume that is the convention which is not the case for dbrx: # https://huggingface.co/databricks/dbrx-instruct/blob/main/model.safetensors.index.json#L15 - new_name = self.map_tensor_name(name if not experts else name + ".weight", try_suffixes=(".weight",)) + new_name = self.map_tensor_name( + name if not experts else name + ".weight", try_suffixes=(".weight",) + ) return [(new_name, data_torch)] - def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: + def tensor_force_quant( + self, name: str, new_name: str, bid: int | None, n_dims: int + ) -> gguf.GGMLQuantizationType | bool: del name, new_name, bid # unused return n_dims > 1 @@ -2553,7 +3062,9 @@ def set_gguf_parameters(self): embedding_scale = float(self.hparams["scale_emb"]) self.gguf_writer.add_embedding_scale(embedding_scale) logger.info(f"gguf: (minicpm) embedding_scale = {embedding_scale}") - residual_scale = self.hparams["scale_depth"] / self.hparams["num_hidden_layers"] ** 0.5 + residual_scale = ( + self.hparams["scale_depth"] / self.hparams["num_hidden_layers"] ** 0.5 + ) self.gguf_writer.add_residual_scale(residual_scale) logger.info(f"gguf: (minicpm) residual_scale = {residual_scale}") logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"] @@ -2562,29 +3073,46 @@ def set_gguf_parameters(self): rope_scaling = self.hparams.get("rope_scaling") or {} if rope_scaling.get("rope_type", rope_scaling.get("type")) == "longrope": self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LONGROPE) - logger.info(f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}") + logger.info( + f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}" + ) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] - rope_scaling = self.find_hparam(['rope_scaling'], True) + rope_scaling = self.find_hparam(["rope_scaling"], True) if rope_scaling is not None: - long_factors = rope_scaling.get('long_factor', None) - short_factors = rope_scaling.get('short_factor', None) + long_factors = rope_scaling.get("long_factor", None) + short_factors = rope_scaling.get("short_factor", None) if long_factors is None or short_factors is None: - raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') + raise KeyError( + "Missing the required key rope_scaling.long_factor or rope_scaling_short_factor" + ) - if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: - raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') + if ( + len(long_factors) != len(short_factors) + or len(long_factors) != rope_dims / 2 + ): + raise ValueError( + f"The length of rope long and short factors must be {rope_dims / 2}" + ) - yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) - yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) + yield ( + self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), + torch.tensor(long_factors, dtype=torch.float32), + ) + yield ( + self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), + torch.tensor(short_factors, dtype=torch.float32), + ) def set_vocab(self): self._set_vocab_sentencepiece() - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused n_head = self.hparams["num_attention_heads"] @@ -2618,35 +3146,54 @@ def set_gguf_parameters(self): if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) - self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_key_length( + hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"] + ) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: - rope_scaling = self.find_hparam(['rope_scaling'], True) + rope_scaling = self.find_hparam(["rope_scaling"], True) if rope_scaling is not None: rope_dims = self.hparams["qk_rope_head_dim"] - long_factors = rope_scaling.get('long_factor', None) - short_factors = rope_scaling.get('short_factor', None) + long_factors = rope_scaling.get("long_factor", None) + short_factors = rope_scaling.get("short_factor", None) if long_factors is None or short_factors is None: - raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') + raise KeyError( + "Missing the required key rope_scaling.long_factor or rope_scaling_short_factor" + ) - if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: - raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}') + if ( + len(long_factors) != len(short_factors) + or len(long_factors) != rope_dims / 2 + ): + raise ValueError( + f"The length of rope long and short factors must be {rope_dims / 2}" + ) - yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) - yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) + yield ( + self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), + torch.tensor(long_factors, dtype=torch.float32), + ) + yield ( + self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), + torch.tensor(short_factors, dtype=torch.float32), + ) def set_vocab(self): self._set_vocab_sentencepiece() - def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor: + def _reverse_hf_permute( + self, weights: Tensor, n_head: int, n_kv_head: int | None = None + ) -> Tensor: if n_kv_head is not None and n_head != n_kv_head: n_head //= n_kv_head return ( - weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + weights.reshape( + n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:] + ) .swapaxes(1, 2) .reshape(weights.shape) ) @@ -2659,11 +3206,14 @@ class QwenModel(TextModel): @staticmethod def token_bytes_to_string(b): from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode + byte_encoder = bytes_to_unicode() - return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')]) + return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) @staticmethod - def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]: + def bpe( + mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None + ) -> list[bytes]: parts = [bytes([b]) for b in token] while True: min_idx = None @@ -2676,7 +3226,11 @@ def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = if min_rank is None or (max_rank is not None and min_rank >= max_rank): break assert min_idx is not None - parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:] + parts = ( + parts[:min_idx] + + [parts[min_idx] + parts[min_idx + 1]] + + parts[min_idx + 2 :] + ) return parts def set_vocab(self): @@ -2688,13 +3242,17 @@ def set_gguf_parameters(self): self.gguf_writer.add_embedding_length(self.hparams["hidden_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) - self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"]) + self.gguf_writer.add_rope_dimension_count( + self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + ) self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_file_type(self.ftype) -@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration") +@ModelBase.register( + "Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration" +) class Qwen2Model(TextModel): model_arch = gguf.MODEL_ARCH.QWEN2 @@ -2704,13 +3262,19 @@ def set_vocab(self): except FileNotFoundError: self._set_vocab_gpt2() - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: if self.hf_arch == "Qwen2Model": name = f"model.{name}" # map to Qwen2ForCausalLM tensors if "language_model." in name: - name = name.replace("language_model.", "") # for InternVL - if name.startswith("mlp") or name.startswith("multi_modal_projector") \ - or name.startswith("vision_model") or name.startswith("audio_tower"): + name = name.replace("language_model.", "") # for InternVL + if ( + name.startswith("mlp") + or name.startswith("multi_modal_projector") + or name.startswith("vision_model") + or name.startswith("audio_tower") + ): # skip vision and audio tensors return [] yield from super().modify_tensors(data_torch, name, bid) @@ -2719,10 +3283,15 @@ def set_gguf_parameters(self): super().set_gguf_parameters() self._try_set_pooling_type() rope_scaling = self.hparams.get("rope_scaling") or {} - if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + if ( + rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" + and "factor" in rope_scaling + ): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) - self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len( + rope_scaling["original_max_position_embeddings"] + ) # def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # if self.hf_arch == "Qwen2Model": @@ -2757,56 +3326,85 @@ def set_vocab(self): except FileNotFoundError: self._set_vocab_gpt2() - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused if name.startswith("thinker."): name = name.replace("thinker.", "") - if name.startswith("visual") or name.startswith("audio") or \ - name.startswith("talker") or name.startswith("token2wav"): + if ( + name.startswith("visual") + or name.startswith("audio") + or name.startswith("talker") + or name.startswith("token2wav") + ): # skip multimodal tensors return [] return [(self.map_tensor_name(name), data_torch)] -@ModelBase.register("Qwen2VLModel", "Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration") +@ModelBase.register( + "Qwen2VLModel", + "Qwen2VLForConditionalGeneration", + "Qwen2_5_VLForConditionalGeneration", +) class Qwen2VLVisionModel(MmprojModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) assert self.hparams_vision is not None self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560) # rename config.json values - self.hparams_vision["num_attention_heads"] = self.hparams_vision.get("num_heads") + self.hparams_vision["num_attention_heads"] = self.hparams_vision.get( + "num_heads" + ) self.hparams_vision["num_hidden_layers"] = self.hparams_vision.get("depth") - if "embed_dim" in self.hparams_vision: # qwen2vl - self.hparams_vision["intermediate_size"] = self.hparams_vision.get("hidden_size") + if "embed_dim" in self.hparams_vision: # qwen2vl + self.hparams_vision["intermediate_size"] = self.hparams_vision.get( + "hidden_size" + ) self.hparams_vision["hidden_size"] = self.hparams_vision.get("embed_dim") def set_gguf_parameters(self): super().set_gguf_parameters() assert self.hparams_vision is not None hparams = self.hparams_vision - model_type = self.global_config['model_type'] - if model_type == 'qwen2_vl': + model_type = self.global_config["model_type"] + if model_type == "qwen2_vl": self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2VL) - elif model_type == 'qwen2_5_vl' or model_type == 'qwen2_5_omni': - if model_type == 'qwen2_5_omni': - self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25O) + elif model_type == "qwen2_5_vl" or model_type == "qwen2_5_omni": + if model_type == "qwen2_5_omni": + self.gguf_writer.add_clip_projector_type( + gguf.VisionProjectorType.QWEN25O + ) else: - self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN25VL) + self.gguf_writer.add_clip_projector_type( + gguf.VisionProjectorType.QWEN25VL + ) self.gguf_writer.add_vision_use_silu(True) # find n_wa_pattern (window attention pattern) fullatt_block_indexes = hparams.get("fullatt_block_indexes") - assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for qwen2_5_vl" + assert ( + fullatt_block_indexes is not None + ), "fullatt_block_indexes is required for qwen2_5_vl" n_wa_pattern = fullatt_block_indexes[0] + 1 # validate n_wa_pattern for i in range(1, len(fullatt_block_indexes)): - if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern: - raise ValueError(f"Invalid fullatt_block_indexes: {fullatt_block_indexes}") + if ( + fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] + != n_wa_pattern + ): + raise ValueError( + f"Invalid fullatt_block_indexes: {fullatt_block_indexes}" + ) self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern) else: - raise ValueError(f"Unknown QwenVL model type: {self.global_config['model_type']}") + raise ValueError( + f"Unknown QwenVL model type: {self.global_config['model_type']}" + ) # default values below are taken from HF tranformers code - self.gguf_writer.add_vision_attention_layernorm_eps(self.global_config.get("rms_norm_eps", 1e-6)) + self.gguf_writer.add_vision_attention_layernorm_eps( + self.global_config.get("rms_norm_eps", 1e-6) + ) def tensor_force_quant(self, name, new_name, bid, n_dims): del bid, name, n_dims # unused @@ -2816,38 +3414,50 @@ def tensor_force_quant(self, name, new_name, bid, n_dims): return gguf.GGMLQuantizationType.F32 return False - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused if name.startswith("visual."): # process visual tensors # split QKV tensors if needed if ".qkv." in name: - if data_torch.ndim == 2: # weight + if data_torch.ndim == 2: # weight c3, _ = data_torch.shape - else: # bias + else: # bias c3 = data_torch.shape[0] assert c3 % 3 == 0 c = c3 // 3 wq = data_torch[:c] - wk = data_torch[c: c * 2] - wv = data_torch[c * 2:] + wk = data_torch[c : c * 2] + wv = data_torch[c * 2 :] return [ (self.map_tensor_name(name.replace("qkv", "q")), wq), (self.map_tensor_name(name.replace("qkv", "k")), wk), (self.map_tensor_name(name.replace("qkv", "v")), wv), ] - elif 'patch_embed.proj.weight' in name: + elif "patch_embed.proj.weight" in name: # split Conv3D into Conv2Ds c1, c2, kt, kh, kw = data_torch.shape del c1, c2, kh, kw # unused - assert kt == 2, "Current implmentation only support temporal_patch_size of 2" + assert ( + kt == 2 + ), "Current implmentation only support temporal_patch_size of 2" return [ - (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight" , data_torch[:, :, 0, ...]), - (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]), + ( + gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + + ".weight", + data_torch[:, :, 0, ...], + ), + ( + gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + + ".weight.1", + data_torch[:, :, 1, ...], + ), ] else: return [(self.map_tensor_name(name), data_torch)] - return [] # skip other tensors + return [] # skip other tensors @ModelBase.register("Qwen2_5OmniModel") @@ -2860,13 +3470,17 @@ def __init__(self, *args, **kwargs): assert self.hparams_audio is not None self.hparams_audio["hidden_size"] = self.hparams_audio["d_model"] self.hparams_audio["intermediate_size"] = self.hparams_audio["encoder_ffn_dim"] - self.hparams_audio["num_attention_heads"] = self.hparams_audio["encoder_attention_heads"] + self.hparams_audio["num_attention_heads"] = self.hparams_audio[ + "encoder_attention_heads" + ] def set_gguf_parameters(self): super().set_gguf_parameters() assert self.hparams_audio is not None self.gguf_writer.add_audio_num_mel_bins(self.hparams_audio["num_mel_bins"]) - self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams_audio.get("layer_norm_eps", 1e-5)) + self.gguf_writer.add_audio_attention_layernorm_eps( + self.hparams_audio.get("layer_norm_eps", 1e-5) + ) def get_vision_config(self) -> dict[str, Any] | None: return self.global_config["thinker_config"].get("vision_config") @@ -2881,9 +3495,15 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: length = 1500 channels = self.hparams_audio["hidden_size"] log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) - inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float()) - scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] - pos_embd = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1).to(dtype=torch.float32) + inv_timescales = torch.exp( + -log_timescale_increment * torch.arange(channels // 2).float() + ) + scaled_time = ( + torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + ) + pos_embd = torch.cat( + [torch.sin(scaled_time), torch.cos(scaled_time)], dim=1 + ).to(dtype=torch.float32) yield ("audio_tower.embed_positions.weight", pos_embd) def tensor_force_quant(self, name, new_name, bid, n_dims): @@ -2892,7 +3512,9 @@ def tensor_force_quant(self, name, new_name, bid, n_dims): return gguf.GGMLQuantizationType.F16 return False - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: if name.startswith("thinker."): name = name.replace("thinker.", "") @@ -2937,46 +3559,68 @@ def tensor_force_quant(self, name, new_name, bid, n_dims): return gguf.GGMLQuantizationType.F32 return False - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused if name.startswith("vision_model") or name.startswith("mlp"): # process visual tensors # correct name if name.startswith("vision_model"): name = "vision_tower." + name - if (".ls" in name or "position_embedding" in name) and not name.endswith(".weight"): + if (".ls" in name or "position_embedding" in name) and not name.endswith( + ".weight" + ): name += ".weight" # split QKV tensors if needed if ".qkv." in name: - if data_torch.ndim == 2: # weight + if data_torch.ndim == 2: # weight c3, _ = data_torch.shape - else: # bias + else: # bias c3 = data_torch.shape[0] assert c3 % 3 == 0 c = c3 // 3 wq = data_torch[:c] - wk = data_torch[c: c * 2] - wv = data_torch[c * 2:] + wk = data_torch[c : c * 2] + wv = data_torch[c * 2 :] return [ - (self.map_tensor_name(name.replace("attn.qkv", "self_attn.q_proj")), wq), - (self.map_tensor_name(name.replace("attn.qkv", "self_attn.k_proj")), wk), - (self.map_tensor_name(name.replace("attn.qkv", "self_attn.v_proj")), wv), + ( + self.map_tensor_name( + name.replace("attn.qkv", "self_attn.q_proj") + ), + wq, + ), + ( + self.map_tensor_name( + name.replace("attn.qkv", "self_attn.k_proj") + ), + wk, + ), + ( + self.map_tensor_name( + name.replace("attn.qkv", "self_attn.v_proj") + ), + wv, + ), ] return [(self.map_tensor_name(name), data_torch)] - return [] # skip other tensors + return [] # skip other tensors @ModelBase.register("WavTokenizerDec") class WavTokenizerDecModel(TextModel): model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused - if \ - name.endswith("codebook.cluster_size") or \ - name.endswith("codebook.embed_avg") or \ - name.endswith("codebook.inited"): + if ( + name.endswith("codebook.cluster_size") + or name.endswith("codebook.embed_avg") + or name.endswith("codebook.inited") + ): logger.debug(f"Skipping {name!r}") return [] @@ -2989,17 +3633,19 @@ def set_vocab(self): def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_vocab_size (self.hparams["vocab_size"]) - self.gguf_writer.add_features_length (self.hparams["n_embd_features"]) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + self.gguf_writer.add_features_length(self.hparams["n_embd_features"]) self.gguf_writer.add_feed_forward_length(self.hparams["n_ff"]) - self.gguf_writer.add_group_norm_eps (self.hparams["group_norm_epsilon"]) - self.gguf_writer.add_group_norm_groups (self.hparams["group_norm_groups"]) + self.gguf_writer.add_group_norm_eps(self.hparams["group_norm_epsilon"]) + self.gguf_writer.add_group_norm_groups(self.hparams["group_norm_groups"]) self.gguf_writer.add_posnet_embedding_length(self.hparams["posnet"]["n_embd"]) - self.gguf_writer.add_posnet_block_count (self.hparams["posnet"]["n_layer"]) + self.gguf_writer.add_posnet_block_count(self.hparams["posnet"]["n_layer"]) - self.gguf_writer.add_convnext_embedding_length(self.hparams["convnext"]["n_embd"]) - self.gguf_writer.add_convnext_block_count (self.hparams["convnext"]["n_layer"]) + self.gguf_writer.add_convnext_embedding_length( + self.hparams["convnext"]["n_embd"] + ) + self.gguf_writer.add_convnext_block_count(self.hparams["convnext"]["n_layer"]) self.gguf_writer.add_causal_attention(False) @@ -3012,23 +3658,40 @@ def set_gguf_parameters(self): super().set_gguf_parameters() if (n_experts := self.hparams.get("num_experts")) is not None: self.gguf_writer.add_expert_count(n_experts) - if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: + if ( + moe_intermediate_size := self.hparams.get("moe_intermediate_size") + ) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}") - if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None: - self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size) - logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}") + if ( + shared_expert_intermediate_size := self.hparams.get( + "shared_expert_intermediate_size" + ) + ) is not None: + self.gguf_writer.add_expert_shared_feed_forward_length( + shared_expert_intermediate_size + ) + logger.info( + f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}" + ) # YaRN is not enabled by default # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts rope_scaling = self.hparams.get("rope_scaling") or {} - if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + if ( + rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" + and "factor" in rope_scaling + ): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) - self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len( + rope_scaling["original_max_position_embeddings"] + ) _experts: list[dict[str, Tensor]] | None = None - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: # process the experts separately if name.find("experts") != -1: n_experts = self.hparams["num_experts"] @@ -3097,7 +3760,9 @@ def set_gguf_parameters(self): self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_file_type(self.ftype) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused tensors: list[tuple[str, Tensor]] = [] @@ -3106,7 +3771,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith((".attn.bias", ".attn.masked_bias")): return tensors - if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_proj.weight")): + if name.endswith( + (".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_proj.weight") + ): data_torch = data_torch.transpose(1, 0) new_name = self.map_tensor_name(name) @@ -3127,14 +3794,18 @@ def set_gguf_parameters(self): n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) - self.gguf_writer.add_context_length(self.find_hparam(["n_positions", "max_position_embeddings"])) + self.gguf_writer.add_context_length( + self.find_hparam(["n_positions", "max_position_embeddings"]) + ) self.gguf_writer.add_embedding_length(n_embd) self.gguf_writer.add_feed_forward_length(4 * n_embd) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head) - self.gguf_writer.add_layer_norm_eps(self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"])) + self.gguf_writer.add_layer_norm_eps( + self.find_hparam(["layer_norm_epsilon", "layer_norm_eps"]) + ) self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) self.gguf_writer.add_file_type(self.ftype) self.gguf_writer.add_add_bos_token(False) @@ -3146,25 +3817,25 @@ class Phi3MiniModel(TextModel): def set_vocab(self): # Phi-4 model uses GPT2Tokenizer - tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + tokenizer_config_file = self.dir_model / "tokenizer_config.json" if tokenizer_config_file.is_file(): with open(tokenizer_config_file, "r", encoding="utf-8") as f: tokenizer_config_json = json.load(f) - tokenizer_class = tokenizer_config_json['tokenizer_class'] - if tokenizer_class == 'GPT2Tokenizer': + tokenizer_class = tokenizer_config_json["tokenizer_class"] + if tokenizer_class == "GPT2Tokenizer": return self._set_vocab_gpt2() from sentencepiece import SentencePieceProcessor - tokenizer_path = self.dir_model / 'tokenizer.model' + tokenizer_path = self.dir_model / "tokenizer.model" if not tokenizer_path.is_file(): - raise ValueError(f'Error: Missing {tokenizer_path}') + raise ValueError(f"Error: Missing {tokenizer_path}") tokenizer = SentencePieceProcessor() tokenizer.LoadFromFile(str(tokenizer_path)) - vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size()) tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size @@ -3190,7 +3861,7 @@ def set_vocab(self): scores[token_id] = score toktypes[token_id] = toktype - added_tokens_file = self.dir_model / 'added_tokens.json' + added_tokens_file = self.dir_model / "added_tokens.json" if added_tokens_file.is_file(): with open(added_tokens_file, "r", encoding="utf-8") as f: added_tokens_json = json.load(f) @@ -3198,31 +3869,37 @@ def set_vocab(self): for key in added_tokens_json: token_id = added_tokens_json[key] if token_id >= vocab_size: - logger.debug(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + logger.debug( + f"ignore token {token_id}: id is out of range, max={vocab_size - 1}" + ) continue tokens[token_id] = key.encode("utf-8") scores[token_id] = -1000.0 toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED - tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + tokenizer_config_file = self.dir_model / "tokenizer_config.json" if tokenizer_config_file.is_file(): with open(tokenizer_config_file, "r", encoding="utf-8") as f: tokenizer_config_json = json.load(f) - added_tokens_decoder = tokenizer_config_json.get("added_tokens_decoder", {}) + added_tokens_decoder = tokenizer_config_json.get( + "added_tokens_decoder", {} + ) for token_id, foken_data in added_tokens_decoder.items(): token_id = int(token_id) token = foken_data["content"].encode("utf-8") if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: if tokens[token_id] != token: - logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}') + logger.warning( + f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}' + ) tokens[token_id] = token scores[token_id] = -1000.0 toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED if foken_data.get("special"): toktypes[token_id] = SentencePieceTokenTypes.CONTROL - tokenizer_file = self.dir_model / 'tokenizer.json' + tokenizer_file = self.dir_model / "tokenizer.json" if tokenizer_file.is_file(): with open(tokenizer_file, "r", encoding="utf-8") as f: tokenizer_json = json.load(f) @@ -3232,7 +3909,9 @@ def set_vocab(self): token = foken_data["content"].encode("utf-8") if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: if tokens[token_id] != token: - logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}') + logger.warning( + f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}' + ) tokens[token_id] = token scores[token_id] = -1000.0 toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED @@ -3263,7 +3942,9 @@ def set_gguf_parameters(self): self.gguf_writer.add_context_length(max_pos_embds) self.gguf_writer.add_rope_scaling_orig_ctx_len(orig_max_pos_embds) self.gguf_writer.add_embedding_length(n_embd) - self.gguf_writer.add_feed_forward_length(self.find_hparam(["intermediate_size"])) + self.gguf_writer.add_feed_forward_length( + self.find_hparam(["intermediate_size"]) + ) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head_kv) @@ -3286,36 +3967,57 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: rope_dims = int(rot_pct * n_embd) // n_head # write rope scaling for long context (128k) model - rope_scaling = self.find_hparam(['rope_scaling'], True) + rope_scaling = self.find_hparam(["rope_scaling"], True) if rope_scaling is None: return scale = max_pos_embds / orig_max_pos_embds - rope_scaling_type = rope_scaling.get('rope_type', rope_scaling.get('type', '')).lower() + rope_scaling_type = rope_scaling.get( + "rope_type", rope_scaling.get("type", "") + ).lower() if len(rope_scaling_type) == 0: - raise KeyError('Missing the required key rope_scaling.type') + raise KeyError("Missing the required key rope_scaling.type") - if rope_scaling_type == 'su' or rope_scaling_type == 'longrope': - attn_factor = math.sqrt(1 + math.log(scale) / math.log(orig_max_pos_embds)) if scale > 1.0 else 1.0 - elif rope_scaling_type == 'yarn': + if rope_scaling_type == "su" or rope_scaling_type == "longrope": + attn_factor = ( + math.sqrt(1 + math.log(scale) / math.log(orig_max_pos_embds)) + if scale > 1.0 + else 1.0 + ) + elif rope_scaling_type == "yarn": attn_factor = 0.1 * math.log(scale) + 1.0 if scale > 1.0 else 1.0 else: - raise NotImplementedError(f'The rope scaling type {rope_scaling_type} is not supported yet') + raise NotImplementedError( + f"The rope scaling type {rope_scaling_type} is not supported yet" + ) self.gguf_writer.add_rope_scaling_attn_factors(attn_factor) - long_factors = rope_scaling.get('long_factor', None) - short_factors = rope_scaling.get('short_factor', None) + long_factors = rope_scaling.get("long_factor", None) + short_factors = rope_scaling.get("short_factor", None) if long_factors is None or short_factors is None: - raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') + raise KeyError( + "Missing the required key rope_scaling.long_factor or rope_scaling_short_factor" + ) - if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2: - raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}. long_factors = {len(long_factors)}, short_factors = {len(short_factors)}.') + if ( + len(long_factors) != len(short_factors) + or len(long_factors) != rope_dims / 2 + ): + raise ValueError( + f"The length of rope long and short factors must be {rope_dims / 2}. long_factors = {len(long_factors)}, short_factors = {len(short_factors)}." + ) - yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32)) - yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32)) + yield ( + self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), + torch.tensor(long_factors, dtype=torch.float32), + ) + yield ( + self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), + torch.tensor(short_factors, dtype=torch.float32), + ) @ModelBase.register("PhiMoEForCausalLM") @@ -3329,7 +4031,9 @@ def set_gguf_parameters(self): self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) self.gguf_writer.add_expert_count(self.hparams["num_local_experts"]) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: # process the experts separately if name.find("block_sparse_moe.experts") != -1: n_experts = self.hparams["num_local_experts"] @@ -3354,7 +4058,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter data_torch = torch.stack(datas, dim=0) - merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight" + merged_name = ( + f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight" + ) new_name = self.map_tensor_name(merged_name) @@ -3391,7 +4097,9 @@ def set_gguf_parameters(self): self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) - self.gguf_writer.add_head_count_kv(5) # hparams["num_key_value_heads"]) is wrong + self.gguf_writer.add_head_count_kv( + 5 + ) # hparams["num_key_value_heads"]) is wrong self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) self.gguf_writer.add_file_type(self.ftype) @@ -3409,7 +4117,9 @@ def shuffle_attn_output_weight(self, data_torch): data_torch = torch.reshape(data_torch, (5120, 5120)) return data_torch - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused new_name = self.map_tensor_name(name) @@ -3444,7 +4154,9 @@ def set_gguf_parameters(self): _has_tok_embd = False - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) @@ -3453,10 +4165,14 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter new_name = self.map_tensor_name(name) # assuming token_embd.weight is seen before output.weight - if not self._has_tok_embd and new_name == self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT): + if not self._has_tok_embd and new_name == self.format_tensor_name( + gguf.MODEL_TENSOR.OUTPUT + ): # even though the tensor file(s) does not contain the word embeddings they are still in the weight map if self.tensor_names and "transformer.wte.weight" in self.tensor_names: - logger.debug(f"{tok_embd_name} not found before {output_name}, assuming they are tied") + logger.debug( + f"{tok_embd_name} not found before {output_name}, assuming they are tied" + ) self.tensor_names.remove("transformer.wte.weight") elif new_name == tok_embd_name: self._has_tok_embd = True @@ -3476,24 +4192,26 @@ def set_vocab(self): from sentencepiece import SentencePieceProcessor from sentencepiece import sentencepiece_model_pb2 as model - tokenizer_path = self.dir_model / 'tokenizer.model' + tokenizer_path = self.dir_model / "tokenizer.model" tokens: list[bytes] = [] scores: list[float] = [] toktypes: list[int] = [] if not tokenizer_path.is_file(): - logger.error(f'Error: Missing {tokenizer_path}') + logger.error(f"Error: Missing {tokenizer_path}") sys.exit(1) - sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] + sentencepiece_model = ( + model.ModelProto() + ) # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix tokenizer = SentencePieceProcessor() tokenizer.LoadFromFile(str(tokenizer_path)) - vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size()) for token_id in range(vocab_size): piece = tokenizer.IdToPiece(token_id) @@ -3515,14 +4233,14 @@ def set_vocab(self): elif tokenizer.IsByte(token_id): toktype = SentencePieceTokenTypes.BYTE # take care of ununsed raw token - if piece.startswith('[UNUSED'): + if piece.startswith("[UNUSED"): toktype = SentencePieceTokenTypes.UNUSED tokens.append(text) scores.append(score) toktypes.append(toktype) - added_tokens_file = self.dir_model / 'added_tokens.json' + added_tokens_file = self.dir_model / "added_tokens.json" if added_tokens_file.is_file(): with open(added_tokens_file, "r", encoding="utf-8") as f: added_tokens_json = json.load(f) @@ -3532,14 +4250,16 @@ def set_vocab(self): scores.append(-1000.0) toktypes.append(SentencePieceTokenTypes.USER_DEFINED) - chat_eos_token = '<|im_end|>' + chat_eos_token = "<|im_end|>" chat_eos_token_id = None - tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + tokenizer_config_file = self.dir_model / "tokenizer_config.json" if tokenizer_config_file.is_file(): with open(tokenizer_config_file, "r", encoding="utf-8") as f: tokenizer_config_json = json.load(f) - added_tokens_decoder = tokenizer_config_json.get("added_tokens_decoder", {}) + added_tokens_decoder = tokenizer_config_json.get( + "added_tokens_decoder", {} + ) for token_id, foken_data in added_tokens_decoder.items(): token_id = int(token_id) token = foken_data["content"] @@ -3548,14 +4268,16 @@ def set_vocab(self): token = token.encode("utf-8") if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: if tokens[token_id] != token: - logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}') + logger.warning( + f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}' + ) tokens[token_id] = token scores[token_id] = -1000.0 toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED if foken_data.get("special"): toktypes[token_id] = SentencePieceTokenTypes.CONTROL - tokenizer_file = self.dir_model / 'tokenizer.json' + tokenizer_file = self.dir_model / "tokenizer.json" if tokenizer_file.is_file(): with open(tokenizer_file, "r", encoding="utf-8") as f: tokenizer_json = json.load(f) @@ -3568,7 +4290,9 @@ def set_vocab(self): token = token.encode("utf-8") if toktypes[token_id] != SentencePieceTokenTypes.UNUSED: if tokens[token_id] != token: - logger.warning(f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}') + logger.warning( + f'replacing token {token_id}: {tokens[token_id].decode("utf-8")!r} -> {token.decode("utf-8")!r}' + ) tokens[token_id] = token scores[token_id] = -1000.0 toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED @@ -3589,8 +4313,10 @@ def set_vocab(self): # TODO: this is a hack, should be fixed # https://github.com/ggml-org/llama.cpp/pull/6745#issuecomment-2067687048 special_vocab.special_token_ids["eos"] = chat_eos_token_id - logger.warning(f"Replace eos:{old_eos} with a special token:{chat_eos_token_id}" - " in chat mode so that the conversation can end normally.") + logger.warning( + f"Replace eos:{old_eos} with a special token:{chat_eos_token_id}" + " in chat mode so that the conversation can end normally." + ) special_vocab.add_to_gguf(self.gguf_writer) @@ -3605,11 +4331,16 @@ def set_gguf_parameters(self): self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"]) self.gguf_writer.add_file_type(self.ftype) rope_scaling = self.hparams.get("rope_scaling") or {} - if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + if ( + rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" + and "factor" in rope_scaling + ): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: num_heads = self.hparams["num_attention_heads"] num_kv_heads = self.hparams["num_key_value_heads"] n_embd = self.hparams["hidden_size"] @@ -3617,7 +4348,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter head_dim = n_embd // num_heads num_groups = num_heads // q_per_kv - name = name.replace("language_model.", "") # InternVL + name = name.replace("language_model.", "") # InternVL if name.startswith("mlp") or name.startswith("vision_model"): # skip visual tensors return [] @@ -3626,11 +4357,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter qkv = data_torch qkv = qkv.reshape((num_groups, q_per_kv + 2, head_dim, n_embd)) - q, k, v = qkv[:, : q_per_kv], qkv[:, -2], qkv[:, -1] + q, k, v = qkv[:, :q_per_kv], qkv[:, -2], qkv[:, -1] # The model weights of q and k equire additional reshape. q = LlamaModel.permute(q.reshape((-1, q.shape[-1])), num_heads, num_heads) - k = LlamaModel.permute(k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads) + k = LlamaModel.permute( + k.reshape((-1, k.shape[-1])), num_heads, num_kv_heads + ) v = v.reshape((-1, v.shape[-1])) return [ @@ -3657,21 +4390,28 @@ def set_vocab(self): special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) - tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + tokenizer_config_file = self.dir_model / "tokenizer_config.json" if tokenizer_config_file.is_file(): with open(tokenizer_config_file, "r", encoding="utf-8") as f: tokenizer_config_json = json.load(f) if "add_prefix_space" in tokenizer_config_json: - self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"]) + self.gguf_writer.add_add_space_prefix( + tokenizer_config_json["add_prefix_space"] + ) if "added_tokens_decoder" in tokenizer_config_json: - for token_id, token_data in tokenizer_config_json["added_tokens_decoder"].items(): + for token_id, token_data in tokenizer_config_json[ + "added_tokens_decoder" + ].items(): if token_data.get("special"): token_id = int(token_id) token = token_data["content"] special_vocab._set_special_token(token, token_id) # update eos token - if token == '<|im_end|>' and "eos" in special_vocab.special_token_ids: + if ( + token == "<|im_end|>" + and "eos" in special_vocab.special_token_ids + ): special_vocab.special_token_ids["eos"] = token_id special_vocab.add_to_gguf(self.gguf_writer) @@ -3686,14 +4426,19 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_dimension_count(rope_dim) rope_scaling = self.hparams.get("rope_scaling") or {} - if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + if ( + rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" + and "factor" in rope_scaling + ): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") - name = name.replace("language_model.", "") # InternVL + name = name.replace("language_model.", "") # InternVL if name.startswith("mlp") or name.startswith("vision_model"): # skip visual tensors return [] @@ -3704,7 +4449,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] -@ModelBase.register("BertModel", "BertForMaskedLM", "CamembertModel", "BertForSequenceClassification") +@ModelBase.register( + "BertModel", "BertForMaskedLM", "CamembertModel", "BertForSequenceClassification" +) class BertModel(TextModel): model_arch = gguf.MODEL_ARCH.BERT @@ -3724,7 +4471,9 @@ def set_gguf_parameters(self): self._try_set_pooling_type() if self.cls_out_labels: - self.gguf_writer.add_classifier_output_labels([v for k, v in sorted(self.cls_out_labels.items())]) + self.gguf_writer.add_classifier_output_labels( + [v for k, v in sorted(self.cls_out_labels.items())] + ) def set_vocab(self): tokens, toktypes, tokpre = self.get_vocab_base() @@ -3742,6 +4491,7 @@ def phantom(tok): if tok.startswith("##"): return tok[2:] return "\u2581" + tok + tokens = list(map(phantom, tokens)) # add vocab to gguf @@ -3754,7 +4504,9 @@ def phantom(tok): special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens)) special_vocab.add_to_gguf(self.gguf_writer) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused if name.startswith("bert."): @@ -3767,8 +4519,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = name[:-5] + ".bias" # we are only using BERT for embeddings so we don't need the pooling layer - if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"): - return [] # we don't need these + if name in ( + "embeddings.position_ids", + "pooler.dense.weight", + "pooler.dense.bias", + ): + return [] # we don't need these if name.startswith("cls.predictions"): return [] @@ -3802,19 +4558,20 @@ def _xlmroberta_set_vocab(self) -> None: from sentencepiece import SentencePieceProcessor from sentencepiece import sentencepiece_model_pb2 as model - tokenizer_path = self.dir_model / 'sentencepiece.bpe.model' + tokenizer_path = self.dir_model / "sentencepiece.bpe.model" tokenizer_json = {} tokenizer_config_json = {} if not tokenizer_path.is_file(): - tokenizer_path = self.dir_model / 'tokenizer.json' - tokenizer_config_path = self.dir_model / 'tokenizer_config.json' + tokenizer_path = self.dir_model / "tokenizer.json" + tokenizer_config_path = self.dir_model / "tokenizer_config.json" if not tokenizer_path.is_file(): raise FileNotFoundError(f"File not found: {tokenizer_path}") from base64 import b64decode from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) with open(tokenizer_path, "r", encoding="utf-8") as fp: @@ -3826,17 +4583,25 @@ def _xlmroberta_set_vocab(self) -> None: add_prefix = tokenizer.add_prefix_space remove_whitespaces = tokenizer.clean_up_tokenization_spaces - precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"]) + precompiled_charsmap = b64decode( + tokenizer_json["normalizer"]["precompiled_charsmap"] + ) vocab_size = max(self.hparams.get("vocab_size", 0), tokenizer.vocab_size) else: - sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] + sentencepiece_model = ( + model.ModelProto() + ) # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix - remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces - precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap + remove_whitespaces = ( + sentencepiece_model.normalizer_spec.remove_extra_whitespaces + ) + precompiled_charsmap = ( + sentencepiece_model.normalizer_spec.precompiled_charsmap + ) tokenizer = SentencePieceProcessor() tokenizer.LoadFromFile(str(tokenizer_path)) @@ -3869,7 +4634,9 @@ def _xlmroberta_set_vocab(self) -> None: else: added_vocab = tokenizer.get_added_vocab() unk_token = tokenizer_config_json.get("unk_token") - unk_token_id = added_vocab.get(unk_token, tokenizer_json["model"].get("unk_id", 3)) + unk_token_id = added_vocab.get( + unk_token, tokenizer_json["model"].get("unk_id", 3) + ) for token_id in range(tokenizer.vocab_size): piece = tokenizer._convert_id_to_token(token_id) @@ -3894,7 +4661,7 @@ def _xlmroberta_set_vocab(self) -> None: if isinstance(tokenizer, SentencePieceProcessor): # realign tokens (see HF tokenizer code) - tokens = [b'', b'', b'', b''] + tokens[3:-1] + tokens = [b"", b"", b"", b""] + tokens[3:-1] scores = [0.0, 0.0, 0.0, 0.0] + scores[3:-1] toktypes = [ SentencePieceTokenTypes.CONTROL, @@ -3905,7 +4672,7 @@ def _xlmroberta_set_vocab(self) -> None: if self.model_arch == gguf.MODEL_ARCH.NOMIC_BERT_MOE: # Add mask token missing from sentencepiece.bpe.model - tokens[250001] = b'' + tokens[250001] = b"" scores[250001] = 0.0 toktypes[250001] = SentencePieceTokenTypes.CONTROL @@ -3924,7 +4691,9 @@ def _xlmroberta_set_vocab(self) -> None: special_vocab.add_to_gguf(self.gguf_writer) -@ModelBase.register("DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification") +@ModelBase.register( + "DistilBertModel", "DistilBertForMaskedLM", "DistilBertForSequenceClassification" +) class DistilBertModel(BertModel): model_arch = gguf.MODEL_ARCH.BERT @@ -3933,7 +4702,9 @@ def set_gguf_parameters(self): logger.info("gguf: layer norm epsilon = 1e-12") super().set_gguf_parameters() - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: if name.startswith("distilbert."): name = name[11:] @@ -3968,12 +4739,16 @@ def set_vocab(self): # we need this to validate the size of the token_type embeddings # though currently we are passing all zeros to the token_type embeddings # "Sequence A" or "Sequence B" - self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1)) + self.gguf_writer.add_token_type_count( + self.hparams.get("type_vocab_size", 1) + ) else: return super().set_vocab() - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: # if name starts with "roberta.", remove the prefix # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main if name.startswith("roberta."): @@ -3982,7 +4757,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # position embeddings start at pad_token_id + 1, so just chop down the weight tensor if name == "embeddings.position_embeddings.weight": if self._position_offset is not None: - data_torch = data_torch[self._position_offset:,:] + data_torch = data_torch[self._position_offset :, :] return super().modify_tensors(data_torch, name, bid) @@ -3991,13 +4766,19 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter class NomicBertModel(BertModel): model_arch = gguf.MODEL_ARCH.BERT - def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any): + def __init__( + self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any + ): hparams = kwargs.pop("hparams", None) if hparams is None: hparams = ModelBase.load_hparams(dir_model) self.is_moe = bool(hparams.get("moe_every_n_layers")) - self.model_arch = gguf.MODEL_ARCH.NOMIC_BERT_MOE if self.is_moe else gguf.MODEL_ARCH.NOMIC_BERT + self.model_arch = ( + gguf.MODEL_ARCH.NOMIC_BERT_MOE + if self.is_moe + else gguf.MODEL_ARCH.NOMIC_BERT + ) super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs) @@ -4005,22 +4786,32 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, if self._tokenizer_is_xlmroberta: self._xlmroberta_tokenizer_init() - npos, mtp = self.hparams["n_positions"], self.hparams.get("max_trained_positions", 2048) + npos, mtp = self.hparams["n_positions"], self.hparams.get( + "max_trained_positions", 2048 + ) if npos == 8192 and mtp == 2048: - self.hparams["n_positions"] = 2048 # nomic-embed-text v1 and v1.5 are trained for 2048 tokens. + self.hparams["n_positions"] = ( + 2048 # nomic-embed-text v1 and v1.5 are trained for 2048 tokens. + ) elif npos == 2048 and mtp == 2048: - self.hparams["n_positions"] = 512 # nomic-embed-text-v2-moe is trained for 512 tokens. + self.hparams["n_positions"] = ( + 512 # nomic-embed-text-v2-moe is trained for 512 tokens. + ) else: - raise ValueError(f"unrecognized parameters: n_positions={npos}, max_trained_positions={mtp}") + raise ValueError( + f"unrecognized parameters: n_positions={npos}, max_trained_positions={mtp}" + ) - assert self.hparams["activation_function"] == "gelu" if self.is_moe else "swiglu" + assert ( + self.hparams["activation_function"] == "gelu" if self.is_moe else "swiglu" + ) # this doesn't do anything in the HF version assert self.hparams["causal"] is False # no bias tensors unless MoE assert self.hparams["qkv_proj_bias"] == self.is_moe - assert self.hparams["mlp_fc1_bias"] == self.is_moe - assert self.hparams["mlp_fc2_bias"] == self.is_moe + assert self.hparams["mlp_fc1_bias"] == self.is_moe + assert self.hparams["mlp_fc2_bias"] == self.is_moe # norm at end of layer assert self.hparams["prenorm"] is False @@ -4034,17 +4825,27 @@ def set_vocab(self) -> None: return self._xlmroberta_set_vocab() return super().set_vocab() - def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) -> Iterable[tuple[str, torch.Tensor]]: + def modify_tensors( + self, data_torch: torch.Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, torch.Tensor]]: # If the tensor is an experts bias tensor, skip it by returning an empty list. if "mlp.experts.bias" in name: return [] # Explicitly return an empty list. if "mlp.experts.mlp.w1" in name: - data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"]) + data_torch = data_torch.view( + self.hparams["num_experts"], + self.hparams["n_inner"], + self.hparams["n_embd"], + ) name += ".weight" if "mlp.experts.mlp.w2" in name: - data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"]) + data_torch = data_torch.view( + self.hparams["num_experts"], + self.hparams["n_inner"], + self.hparams["n_embd"], + ) data_torch = data_torch.transpose(1, 2) name += ".weight" @@ -4080,7 +4881,9 @@ def __init__(self, *args, **kwargs): def set_vocab(self): self._xlmroberta_set_vocab() - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: # if name starts with "roberta.", remove the prefix # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main if name.startswith("roberta."): @@ -4089,7 +4892,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # position embeddings start at pad_token_id + 1, so just chop down the weight tensor if name == "embeddings.position_embeddings.weight": if self._position_offset is not None: - data_torch = data_torch[self._position_offset:,:] + data_torch = data_torch[self._position_offset :, :] return super().modify_tensors(data_torch, name, bid) @@ -4102,13 +4905,16 @@ def set_vocab(self): self._set_vocab_sentencepiece() # TODO: these special tokens should be exported only for the CodeGemma family - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False, - special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot']) + special_vocab = gguf.SpecialVocab( + self.dir_model, + load_merges=False, + special_token_types=["prefix", "suffix", "middle", "fsep", "eot"], + ) special_vocab._set_special_token("prefix", 67) special_vocab._set_special_token("suffix", 69) special_vocab._set_special_token("middle", 68) - special_vocab._set_special_token("fsep", 70) - special_vocab._set_special_token("eot", 107) + special_vocab._set_special_token("fsep", 70) + special_vocab._set_special_token("eot", 107) special_vocab.chat_template = None # do not add it twice special_vocab.add_to_gguf(self.gguf_writer) @@ -4123,19 +4929,27 @@ def set_gguf_parameters(self): self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) - self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv( + self.hparams["num_key_value_heads"] + if "num_key_value_heads" in hparams + else hparams["num_attention_heads"] + ) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) self.gguf_writer.add_key_length(hparams["head_dim"]) self.gguf_writer.add_value_length(hparams["head_dim"]) self.gguf_writer.add_file_type(self.ftype) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused # lm_head is not used in llama.cpp, while autoawq will include this tensor in model # To prevent errors, skip loading lm_head.weight. if name == "lm_head.weight": - logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.") + logger.debug( + f"Skipping get tensor {name!r} in safetensors so that convert can end normally." + ) return [] # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89 @@ -4163,7 +4977,11 @@ def set_gguf_parameters(self): self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) - self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv( + self.hparams["num_key_value_heads"] + if "num_key_value_heads" in hparams + else hparams["num_attention_heads"] + ) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) self.gguf_writer.add_key_length(hparams["head_dim"]) self.gguf_writer.add_value_length(hparams["head_dim"]) @@ -4176,13 +4994,17 @@ def set_gguf_parameters(self): ) self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused # lm_head is not used in llama.cpp, while autoawq will include this tensor in model # To prevent errors, skip loading lm_head.weight. if name == "lm_head.weight": - logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.") + logger.debug( + f"Skipping get tensor {name!r} in safetensors so that convert can end normally." + ) return [] # ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89 @@ -4206,7 +5028,9 @@ def set_gguf_parameters(self): block_count = hparams["num_hidden_layers"] # some default values are not specified in the hparams - self.gguf_writer.add_context_length(hparams.get("max_position_embeddings", 131072)) + self.gguf_writer.add_context_length( + hparams.get("max_position_embeddings", 131072) + ) self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) @@ -4215,7 +5039,9 @@ def set_gguf_parameters(self): self.gguf_writer.add_key_length(hparams.get("head_dim", 256)) self.gguf_writer.add_value_length(hparams.get("head_dim", 256)) self.gguf_writer.add_file_type(self.ftype) - self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0)) # for global layers + self.gguf_writer.add_rope_freq_base( + hparams.get("rope_theta", 1_000_000.0) + ) # for global layers # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3 assert hparams.get("attn_logit_softcapping") is None assert hparams.get("final_logit_softcapping") is None @@ -4227,21 +5053,27 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"]) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused if name.startswith("language_model."): name = name.replace("language_model.", "") - elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \ - or name.startswith("multimodal_projector.") or name.startswith("vision_model."): - return [] # skip vision tensors + elif ( + name.startswith("multi_modal_projector.") + or name.startswith("vision_tower.") + or name.startswith("multimodal_projector.") + or name.startswith("vision_model.") + ): + return [] # skip vision tensors # remove OOV (out-of-vocabulary) rows in token_embd if "embed_tokens.weight" in name: vocab = self._create_vocab_sentencepiece() tokens = vocab[0] - data_torch = data_torch[:len(tokens)] + data_torch = data_torch[: len(tokens)] # ref code in Gemma3RMSNorm # output = output * (1.0 + self.weight.float()) @@ -4258,11 +5090,13 @@ def set_gguf_parameters(self): hparams = self.hparams self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3) # default values below are taken from HF tranformers code - self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6)) + self.gguf_writer.add_vision_attention_layernorm_eps( + hparams.get("layer_norm_eps", 1e-6) + ) self.gguf_writer.add_vision_use_gelu(True) # calculate proj_scale_factor (used by tinygemma3 test model) image_seq_length = self.preprocessor_config.get("image_seq_length", 256) - n_per_side = int(image_seq_length ** 0.5) + n_per_side = int(image_seq_length**0.5) image_size = self.hparams["image_size"] patch_size = self.hparams["patch_size"] proj_scale_factor = (image_size // patch_size) // n_per_side @@ -4280,14 +5114,20 @@ def tensor_force_quant(self, name, new_name, bid, n_dims): return gguf.GGMLQuantizationType.F32 return False - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused if "vision_model.head." in name: - return [] # skip redundant tensors for tinygemma3 - - if name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \ - or name.startswith("multimodal_projector.") or name.startswith("vision_model."): + return [] # skip redundant tensors for tinygemma3 + + if ( + name.startswith("multi_modal_projector.") + or name.startswith("vision_tower.") + or name.startswith("multimodal_projector.") + or name.startswith("vision_model.") + ): # process vision tensors name = name.replace("_weight", ".weight") @@ -4300,7 +5140,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] - return [] # skip other tensors + return [] # skip other tensors @ModelBase.register("Starcoder2ForCausalLM") @@ -4321,7 +5161,11 @@ def set_gguf_parameters(self): hidden_size = self.hparams["hidden_size"] layer_norm_eps = self.hparams["layer_norm_epsilon"] rescale_every_n_layers = self.hparams["rescale_every"] - intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else int((hidden_size * 3.5) // 32 * 32) + intermediate_size = ( + self.hparams["intermediate_size"] + if self.hparams["intermediate_size"] is not None + else int((hidden_size * 3.5) // 32 * 32) + ) time_mix_extra_dim = 64 if hidden_size == 4096 else 32 time_decay_extra_dim = 128 if hidden_size == 4096 else 64 @@ -4342,13 +5186,19 @@ def set_gguf_parameters(self): lerp_weights: dict[int, dict[str, Tensor]] = {} - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: new_name = self.map_tensor_name(name) if not (new_name.endswith(".weight") or new_name.endswith(".bias")): new_name += ".weight" - if new_name.endswith("time_mix_w1.weight") or new_name.endswith("time_mix_decay_w1.weight") or new_name.endswith("time_mix_decay_w2.weight"): + if ( + new_name.endswith("time_mix_w1.weight") + or new_name.endswith("time_mix_decay_w1.weight") + or new_name.endswith("time_mix_decay_w2.weight") + ): data_torch = data_torch.transpose(0, 1) if new_name.endswith("time_mix_w2.weight"): @@ -4360,21 +5210,40 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter try: rescale_every_n_layers = self.hparams["rescale_every"] if rescale_every_n_layers > 0: - if new_name.endswith("time_mix_output.weight") or new_name.endswith("channel_mix_value.weight"): - data_torch = data_torch.div_(2 ** int(bid // rescale_every_n_layers)) + if new_name.endswith("time_mix_output.weight") or new_name.endswith( + "channel_mix_value.weight" + ): + data_torch = data_torch.div_( + 2 ** int(bid // rescale_every_n_layers) + ) except KeyError: pass # concat time_mix_lerp weights to reduce some cpu overhead # also reduces the number of tensors in the model - if bid is not None and "time_mix_lerp" in new_name and "time_mix_lerp_x" not in new_name: + if ( + bid is not None + and "time_mix_lerp" in new_name + and "time_mix_lerp_x" not in new_name + ): try: self.lerp_weights[bid][new_name] = data_torch except KeyError: self.lerp_weights[bid] = {new_name: data_torch} - if all(f"blk.{bid}.time_mix_lerp_{i}.weight" in self.lerp_weights[bid].keys() for i in ["w", "k", "v", "r", "g"]): + if all( + f"blk.{bid}.time_mix_lerp_{i}.weight" in self.lerp_weights[bid].keys() + for i in ["w", "k", "v", "r", "g"] + ): new_name = f"blk.{bid}.time_mix_lerp_fused.weight" - data = torch.stack([self.lerp_weights[bid][f"blk.{bid}.time_mix_lerp_{i}.weight"].unsqueeze(0) for i in ["w", "k", "v", "r", "g"]], dim=0).unsqueeze(1) + data = torch.stack( + [ + self.lerp_weights[bid][ + f"blk.{bid}.time_mix_lerp_{i}.weight" + ].unsqueeze(0) + for i in ["w", "k", "v", "r", "g"] + ], + dim=0, + ).unsqueeze(1) yield (new_name, data) return @@ -4399,8 +5268,12 @@ def set_gguf_parameters(self): head_size = hidden_size // num_attention_heads rms_norm_eps = self.hparams["rms_norm_eps"] intermediate_size = self.hparams["intermediate_size"] - time_mix_extra_dim = self.hparams.get("lora_rank_tokenshift", 64 if hidden_size >= 4096 else 32) - time_decay_extra_dim = self.hparams.get("lora_rank_decay", 128 if hidden_size >= 4096 else 64) + time_mix_extra_dim = self.hparams.get( + "lora_rank_tokenshift", 64 if hidden_size >= 4096 else 32 + ) + time_decay_extra_dim = self.hparams.get( + "lora_rank_decay", 128 if hidden_size >= 4096 else 64 + ) # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) @@ -4421,13 +5294,17 @@ def set_gguf_parameters(self): # required by llama.cpp, unused self.gguf_writer.add_head_count(0) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: for new_name, data in super().modify_tensors(data_torch, name, bid): if "time_mix_w1" in new_name or "time_mix_w2" in new_name: data = data.view(5, -1, data.shape[-1]) # rwkv6qwen2 has a different order of rkvwg instead of the original wkvrg # permute them here to avoid code changes - data = torch.stack([data[3], data[1], data[2], data[0], data[4]], dim=0).view(-1, data.shape[-1]) + data = torch.stack( + [data[3], data[1], data[2], data[0], data[4]], dim=0 + ).view(-1, data.shape[-1]) if "w2" in new_name: data = data.view(5, -1, data.shape[-1]) yield (new_name, data) @@ -4443,7 +5320,7 @@ def set_vocab(self): self._set_vocab_rwkv_world() def calc_lora_rank(self, hidden_size, exponent, multiplier): - return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32 + return max(1, round(hidden_size**exponent * multiplier / 32)) * 32 def set_gguf_parameters(self): block_count = self.hparams["num_hidden_layers"] @@ -4454,19 +5331,55 @@ def set_gguf_parameters(self): head_size = self.hparams["head_dim"] layer_norm_eps = self.hparams["norm_eps"] hidden_size = self.hparams["hidden_size"] - intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4) + intermediate_size = ( + self.hparams["intermediate_size"] + if self.hparams["intermediate_size"] is not None + else (hidden_size * 4) + ) # ICLR: In-Context-Learning-Rate try: - lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8) - lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8) - lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3) - lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6) + lora_rank_decay = ( + self.hparams["lora_rank_decay"] + if self.hparams["lora_rank_decay"] is not None + else self.calc_lora_rank(hidden_size, 0.5, 1.8) + ) + lora_rank_iclr = ( + self.hparams["lora_rank_iclr"] + if self.hparams["lora_rank_iclr"] is not None + else self.calc_lora_rank(hidden_size, 0.5, 1.8) + ) + lora_rank_value_residual_mix = ( + self.hparams["lora_rank_value_residual_mix"] + if self.hparams["lora_rank_value_residual_mix"] is not None + else self.calc_lora_rank(hidden_size, 0.5, 1.3) + ) + lora_rank_gate = ( + self.hparams["lora_rank_gate"] + if self.hparams["lora_rank_gate"] is not None + else self.calc_lora_rank(hidden_size, 0.8, 0.6) + ) except KeyError: - lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8) - lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8) - lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3) - lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6) + lora_rank_decay = ( + self.hparams["decay_low_rank_dim"] + if self.hparams["decay_low_rank_dim"] is not None + else self.calc_lora_rank(hidden_size, 0.5, 1.8) + ) + lora_rank_iclr = ( + self.hparams["a_low_rank_dim"] + if self.hparams["a_low_rank_dim"] is not None + else self.calc_lora_rank(hidden_size, 0.5, 1.8) + ) + lora_rank_value_residual_mix = ( + self.hparams["v_low_rank_dim"] + if self.hparams["v_low_rank_dim"] is not None + else self.calc_lora_rank(hidden_size, 0.5, 1.3) + ) + lora_rank_gate = ( + self.hparams["gate_low_rank_dim"] + if self.hparams["gate_low_rank_dim"] is not None + else self.calc_lora_rank(hidden_size, 0.8, 0.6) + ) # RWKV isn't context limited self.gguf_writer.add_context_length(1048576) @@ -4487,7 +5400,9 @@ def set_gguf_parameters(self): lerp_weights: dict[int, dict[str, Tensor]] = {} lora_needs_transpose: bool = True - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: # unify tensor names here to make life easier name = name.replace("blocks", "layers").replace("ffn", "feed_forward") name = name.replace("self_attn", "attention").replace("attn", "attention") @@ -4502,13 +5417,21 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = name.replace("feed_forward_norm", "ln2") name = name.replace("g_norm", "ln_x") - if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0: + if ( + "attention.v" in name + and "value" not in self.map_tensor_name(name) + and bid == 0 + ): # some models have dummy v0/v1/v2 on first layer while others don't # ignore them all since they are not used return wkv_has_gate = self.hparams.get("wkv_has_gate", True) - lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"] + lerp_list = ( + ["r", "w", "k", "v", "a", "g"] + if wkv_has_gate + else ["r", "w", "k", "v", "a"] + ) if bid is not None and "attention.x_" in name: if "attention.x_x" in name: @@ -4521,9 +5444,21 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter self.lerp_weights[bid][name] = data_torch except KeyError: self.lerp_weights[bid] = {name: data_torch} - if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list): + if all( + f"model.layers.{bid}.attention.x_{i}" + in self.lerp_weights[bid].keys() + for i in lerp_list + ): new_name = f"blk.{bid}.time_mix_lerp_fused.weight" - data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0) + data = torch.stack( + [ + self.lerp_weights[bid][ + f"model.layers.{bid}.attention.x_{i}" + ] + for i in lerp_list + ], + dim=0, + ) yield (new_name, data) return else: @@ -4534,16 +5469,21 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter new_name += ".weight" if self.lora_needs_transpose and any( - new_name.endswith(t) for t in [ - "time_mix_w1.weight", "time_mix_w2.weight", - "time_mix_a1.weight", "time_mix_a2.weight", - "time_mix_v1.weight", "time_mix_v2.weight", - "time_mix_g1.weight", "time_mix_g2.weight", + new_name.endswith(t) + for t in [ + "time_mix_w1.weight", + "time_mix_w2.weight", + "time_mix_a1.weight", + "time_mix_a2.weight", + "time_mix_v1.weight", + "time_mix_v2.weight", + "time_mix_g1.weight", + "time_mix_g2.weight", ] ): data_torch = data_torch.transpose(0, 1) - if 'r_k' in new_name: + if "r_k" in new_name: data_torch = data_torch.flatten() if bid == 0 and "time_mix_a" in new_name: @@ -4619,15 +5559,23 @@ def set_vocab(self): self._set_vocab_builtin("gpt-neox", vocab_size) def set_gguf_parameters(self): - d_model = self.find_hparam(["hidden_size", "d_model"]) - d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 - d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model - d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16 + d_model = self.find_hparam(["hidden_size", "d_model"]) + d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 + d_inner = ( + self.find_hparam(["intermediate_size", "d_inner"], optional=True) + or 2 * d_model + ) + d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16 # ceiling division # ref: https://stackoverflow.com/a/17511341/22827863 # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58 - dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16) - rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -( + d_model // -16 + ) + rms_norm_eps = ( + self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) + or 1e-5 + ) use_dt_b_c_norm = False # For falconmamba we do apply RMS norm on B / DT and C layers if self.find_hparam(["model_type"], optional=True) in ("falcon_mamba",): @@ -4635,22 +5583,32 @@ def set_gguf_parameters(self): # Fail early for models which don't have a block expansion factor of 2 assert d_inner == 2 * d_model - self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default + self.gguf_writer.add_context_length( + 2**20 + ) # arbitrary value; for those who use the default self.gguf_writer.add_embedding_length(d_model) - self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading - self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading + self.gguf_writer.add_feed_forward_length( + 0 + ) # unused, but seemingly required when loading + self.gguf_writer.add_head_count( + 0 + ) # unused, but seemingly required when loading self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_ssm_conv_kernel(d_conv) self.gguf_writer.add_ssm_inner_size(d_inner) self.gguf_writer.add_ssm_state_size(d_state) self.gguf_writer.add_ssm_time_step_rank(dt_rank) self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) - self.gguf_writer.add_ssm_dt_b_c_rms(use_dt_b_c_norm) # For classic Mamba we don't apply rms norm on B / DT layers + self.gguf_writer.add_ssm_dt_b_c_rms( + use_dt_b_c_norm + ) # For classic Mamba we don't apply rms norm on B / DT layers self.gguf_writer.add_file_type(self.ftype) _tok_embd = None - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) tok_embd_name = self.format_tensor_name(gguf.MODEL_TENSOR.TOKEN_EMBD) @@ -4667,7 +5625,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # assuming token_embd.weight is seen before output.weight if self._tok_embd is not None and new_name == output_name: if torch.equal(self._tok_embd, data_torch): - logger.debug(f"{output_name} is equivalent to {tok_embd_name}, omitting") + logger.debug( + f"{output_name} is equivalent to {tok_embd_name}, omitting" + ) return [] elif new_name == tok_embd_name: self._tok_embd = data_torch @@ -4685,7 +5645,9 @@ def __init__(self, *args, **kwargs): # max_position_embeddings = 8192 in config.json but model was actually # trained on 128k context length # aya-23 models don't have model_max_length specified - self.hparams["max_position_embeddings"] = self.find_hparam(["model_max_length", "max_position_embeddings"]) + self.hparams["max_position_embeddings"] = self.find_hparam( + ["model_max_length", "max_position_embeddings"] + ) def set_gguf_parameters(self): super().set_gguf_parameters() @@ -4707,7 +5669,9 @@ def set_gguf_parameters(self): rotary_pct = self.hparams["rotary_pct"] hidden_size = self.hparams["hidden_size"] num_attention_heads = self.hparams["num_attention_heads"] - self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads))) + self.gguf_writer.add_rope_dimension_count( + int(rotary_pct * (hidden_size // num_attention_heads)) + ) self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) @@ -4725,7 +5689,9 @@ def set_gguf_parameters(self): # Same as super class, but permuting q_proj, k_proj # Copied from: LlamaModel - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused n_head = self.hparams["num_attention_heads"] @@ -4757,7 +5723,9 @@ def set_gguf_parameters(self): _experts: list[dict[str, Tensor]] | None = None # Copied from: Qwen2MoeModel - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: # process the experts separately if name.find("experts") != -1: n_experts = self.hparams["num_experts"] @@ -4809,17 +5777,19 @@ class JinaBertV2Model(BertModel): model_arch = gguf.MODEL_ARCH.JINA_BERT_V2 def set_vocab(self): - tokenizer_class = 'BertTokenizer' + tokenizer_class = "BertTokenizer" with open(self.dir_model / "tokenizer_config.json", "r", encoding="utf-8") as f: - tokenizer_class = json.load(f)['tokenizer_class'] + tokenizer_class = json.load(f)["tokenizer_class"] - if tokenizer_class == 'BertTokenizer': + if tokenizer_class == "BertTokenizer": super().set_vocab() - elif tokenizer_class == 'RobertaTokenizer': + elif tokenizer_class == "RobertaTokenizer": self._set_vocab_gpt2() self.gguf_writer.add_token_type_count(2) else: - raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel') + raise NotImplementedError( + f"Tokenizer {tokenizer_class} is not supported for JinaBertModel" + ) @ModelBase.register("OpenELMForCausalLM") @@ -4847,8 +5817,12 @@ def __init__(self, *args, **kwargs): OpenELMModel._make_divisible(multiplier * self._n_embd, ffn_dim_divisor) for multiplier in ffn_multipliers ] - assert isinstance(self._num_kv_heads, list) and isinstance(self._num_kv_heads[0], int) - assert isinstance(self._num_query_heads, list) and isinstance(self._num_query_heads[0], int) + assert isinstance(self._num_kv_heads, list) and isinstance( + self._num_kv_heads[0], int + ) + assert isinstance(self._num_query_heads, list) and isinstance( + self._num_query_heads[0], int + ) # Uses the tokenizer from meta-llama/Llama-2-7b-hf def set_vocab(self): @@ -4885,13 +5859,21 @@ def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any: return super().find_hparam(keys, optional) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: # split ff if bid is not None and name == f"transformer.layers.{bid}.ffn.proj_1.weight": ff_dim = self._ffn_dims[bid] - yield (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), data_torch[:ff_dim]) - yield (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), data_torch[ff_dim:]) + yield ( + self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), + data_torch[:ff_dim], + ) + yield ( + self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), + data_torch[ff_dim:], + ) return yield (self.map_tensor_name(name), data_torch) @@ -4907,17 +5889,17 @@ def set_vocab(self): # tokenizer.model and used them as BOS and EOS instead of adding new tokens. from sentencepiece import SentencePieceProcessor - tokenizer_path = self.dir_model / 'tokenizer.model' + tokenizer_path = self.dir_model / "tokenizer.model" if not tokenizer_path.is_file(): - logger.error(f'Error: Missing {tokenizer_path}') + logger.error(f"Error: Missing {tokenizer_path}") sys.exit(1) # Read the whole vocabulary from the tokenizer.model file tokenizer = SentencePieceProcessor() tokenizer.LoadFromFile(str(tokenizer_path)) - vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size()) tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size @@ -4945,7 +5927,7 @@ def set_vocab(self): # Use the added_tokens_decoder field from tokeniser_config.json as the source # of information about added/redefined tokens and modify them accordingly. - tokenizer_config_file = self.dir_model / 'tokenizer_config.json' + tokenizer_config_file = self.dir_model / "tokenizer_config.json" if tokenizer_config_file.is_file(): with open(tokenizer_config_file, "r", encoding="utf-8") as f: tokenizer_config_json = json.load(f) @@ -4955,7 +5937,9 @@ def set_vocab(self): for token_id, token_json in added_tokens_decoder.items(): token_id = int(token_id) if token_id >= vocab_size: - logger.debug(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + logger.debug( + f"ignore token {token_id}: id is out of range, max={vocab_size - 1}" + ) continue token_content = token_json["content"] @@ -4971,7 +5955,9 @@ def set_vocab(self): token_type = SentencePieceTokenTypes.CONTROL token_score = 0.0 - logger.info(f"Setting added token {token_id} to '{token_content}' (type: {token_type}, score: {token_score:.2f})") + logger.info( + f"Setting added token {token_id} to '{token_content}' (type: {token_type}, score: {token_score:.2f})" + ) tokens[token_id] = token_content.encode("utf-8") toktypes[token_id] = token_type scores[token_id] = token_score @@ -4989,11 +5975,15 @@ def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) + self.gguf_writer.add_rope_dimension_count( + hparams["hidden_size"] // hparams["num_attention_heads"] + ) _experts: list[dict[str, Tensor]] | None = None - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") @@ -5068,7 +6058,9 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_feed_forward_length( + hparams["moe_intermediate_size"] + ) self.gguf_writer.add_expert_weights_scale(1.0) self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) @@ -5079,11 +6071,17 @@ def set_gguf_parameters(self): def permute(weights: Tensor, n_head: int, n_head_kv: int | None): if n_head_kv is not None and n_head != n_head_kv: n_head = n_head_kv - return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape)) + return ( + weights.reshape( + n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:] + ) + .swapaxes(1, 2) + .reshape(weights.shape) + ) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") @@ -5160,12 +6158,18 @@ def set_gguf_parameters(self): self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) # note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA - self.gguf_writer.add_key_length(hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_key_length( + hparams["kv_lora_rank"] + hparams["qk_rope_head_dim"] + ) self.gguf_writer.add_value_length(hparams["kv_lora_rank"]) - self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_key_length_mla( + hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"] + ) self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) - self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_feed_forward_length( + hparams["moe_intermediate_size"] + ) self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) @@ -5176,20 +6180,31 @@ def set_gguf_parameters(self): elif hparams["scoring_func"] == "softmax": self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) else: - raise ValueError(f"Unsupported scoring_func value: {hparams['scoring_func']}") + raise ValueError( + f"Unsupported scoring_func value: {hparams['scoring_func']}" + ) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) rope_scaling = self.hparams.get("rope_scaling") or {} - if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + if ( + rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" + and "factor" in rope_scaling + ): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) - self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) - self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len( + rope_scaling["original_max_position_embeddings"] + ) + self.gguf_writer.add_rope_scaling_yarn_log_mul( + 0.1 * rope_scaling["mscale_all_dim"] + ) _experts: list[dict[str, Tensor]] | None = None - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: # rename e_score_correction_bias tensors if name.endswith("e_score_correction_bias"): name = name.replace("e_score_correction_bias", "e_score_correction.bias") @@ -5244,13 +6259,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim) - kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1]) + kv_b = data_torch.view( + n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1] + ) k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1) k_b = k_b.transpose(1, 2) return [ (self.map_tensor_name(name_kb), k_b), - (self.map_tensor_name(name_vb), v_b) + (self.map_tensor_name(name_vb), v_b), ] return [(self.map_tensor_name(name), data_torch)] @@ -5277,11 +6294,15 @@ def set_gguf_parameters(self): hparams = self.hparams self.gguf_writer.add_vocab_size(hparams["vocab_size"]) self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"]) - self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) + self.gguf_writer.add_key_length( + hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"] + ) self.gguf_writer.add_value_length(hparams["v_head_dim"]) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: return [(self.map_tensor_name(name), data_torch)] def prepare_tensors(self): @@ -5306,34 +6327,38 @@ def set_vocab(self): from sentencepiece import SentencePieceProcessor from sentencepiece import sentencepiece_model_pb2 as model - tokenizer_path = self.dir_model / 'tokenizer.model' + tokenizer_path = self.dir_model / "tokenizer.model" # many older models use spiece.model tokenizer model filename if not tokenizer_path.is_file(): - tokenizer_path = self.dir_model / 'spiece.model' + tokenizer_path = self.dir_model / "spiece.model" if not tokenizer_path.is_file(): raise FileNotFoundError(f"File not found: {tokenizer_path}") - sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] + sentencepiece_model = ( + model.ModelProto() + ) # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) # some models like Pile-T5 family use BPE tokenizer instead of Unigram if sentencepiece_model.trainer_spec.model_type == 2: # BPE # assure the tokenizer model file name is correct - assert tokenizer_path.name == 'tokenizer.model' + assert tokenizer_path.name == "tokenizer.model" return self._set_vocab_sentencepiece() else: assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix - remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces + remove_whitespaces = ( + sentencepiece_model.normalizer_spec.remove_extra_whitespaces + ) precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap tokenizer = SentencePieceProcessor() tokenizer.LoadFromFile(str(tokenizer_path)) - vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size()) tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size @@ -5358,14 +6383,16 @@ def set_vocab(self): scores[token_id] = score toktypes[token_id] = toktype - added_tokens_file = self.dir_model / 'added_tokens.json' + added_tokens_file = self.dir_model / "added_tokens.json" if added_tokens_file.is_file(): with open(added_tokens_file, "r", encoding="utf-8") as f: added_tokens_json = json.load(f) for key in added_tokens_json: token_id = added_tokens_json[key] if token_id >= vocab_size: - logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + logger.warning( + f"ignore token {token_id}: id is out of range, max={vocab_size - 1}" + ) continue tokens[token_id] = key.encode("utf-8") @@ -5374,7 +6401,9 @@ def set_vocab(self): if vocab_size > len(tokens): pad_count = vocab_size - len(tokens) - logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") + logger.debug( + f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]" + ) for i in range(1, pad_count + 1): tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) scores.append(-1000.0) @@ -5395,7 +6424,9 @@ def set_vocab(self): def set_gguf_parameters(self): if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None: - logger.warning("Couldn't find context length in config.json, assuming default value of 512") + logger.warning( + "Couldn't find context length in config.json, assuming default value of 512" + ) n_ctx = 512 self.gguf_writer.add_context_length(n_ctx) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) @@ -5405,24 +6436,36 @@ def set_gguf_parameters(self): self.gguf_writer.add_key_length(self.hparams["d_kv"]) self.gguf_writer.add_value_length(self.hparams["d_kv"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) - self.gguf_writer.add_relative_attn_buckets_count(self.hparams["relative_attention_num_buckets"]) + self.gguf_writer.add_relative_attn_buckets_count( + self.hparams["relative_attention_num_buckets"] + ) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) - self.gguf_writer.add_decoder_start_token_id(self.hparams["decoder_start_token_id"]) + self.gguf_writer.add_decoder_start_token_id( + self.hparams["decoder_start_token_id"] + ) self.gguf_writer.add_file_type(self.ftype) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight", # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder # and decoder and ignore the remaining ones. - if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]: + if name in [ + "decoder.embed_tokens.weight", + "encoder.embed_tokens.weight", + "shared.weight", + ]: if not self.shared_token_embeddings_found: name = "shared.weight" self.shared_token_embeddings_found = True else: - logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.") + logger.debug( + f"Skipping shared tensor {name!r} in safetensors so that convert can end normally." + ) return [] return [(self.map_tensor_name(name), data_torch)] @@ -5443,34 +6486,38 @@ def set_vocab(self): from sentencepiece import SentencePieceProcessor from sentencepiece import sentencepiece_model_pb2 as model - tokenizer_path = self.dir_model / 'tokenizer.model' + tokenizer_path = self.dir_model / "tokenizer.model" # many older models use spiece.model tokenizer model filename if not tokenizer_path.is_file(): - tokenizer_path = self.dir_model / 'spiece.model' + tokenizer_path = self.dir_model / "spiece.model" if not tokenizer_path.is_file(): raise FileNotFoundError(f"File not found: {tokenizer_path}") - sentencepiece_model = model.ModelProto() # pyright: ignore[reportAttributeAccessIssue] + sentencepiece_model = ( + model.ModelProto() + ) # pyright: ignore[reportAttributeAccessIssue] sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read()) # some models like Pile-T5 family use BPE tokenizer instead of Unigram if sentencepiece_model.trainer_spec.model_type == 2: # BPE # assure the tokenizer model file name is correct - assert tokenizer_path.name == 'tokenizer.model' + assert tokenizer_path.name == "tokenizer.model" return self._set_vocab_sentencepiece() else: assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix - remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces + remove_whitespaces = ( + sentencepiece_model.normalizer_spec.remove_extra_whitespaces + ) precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap tokenizer = SentencePieceProcessor() tokenizer.LoadFromFile(str(tokenizer_path)) - vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size()) + vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size()) tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)] scores: list[float] = [-10000.0] * vocab_size @@ -5495,14 +6542,16 @@ def set_vocab(self): scores[token_id] = score toktypes[token_id] = toktype - added_tokens_file = self.dir_model / 'added_tokens.json' + added_tokens_file = self.dir_model / "added_tokens.json" if added_tokens_file.is_file(): with open(added_tokens_file, "r", encoding="utf-8") as f: added_tokens_json = json.load(f) for key in added_tokens_json: token_id = added_tokens_json[key] if token_id >= vocab_size: - logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}') + logger.warning( + f"ignore token {token_id}: id is out of range, max={vocab_size - 1}" + ) continue tokens[token_id] = key.encode("utf-8") @@ -5511,7 +6560,9 @@ def set_vocab(self): if vocab_size > len(tokens): pad_count = vocab_size - len(tokens) - logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]") + logger.debug( + f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]" + ) for i in range(1, pad_count + 1): tokens.append(bytes(f"[PAD{i}]", encoding="utf-8")) scores.append(-1000.0) @@ -5532,7 +6583,9 @@ def set_vocab(self): def set_gguf_parameters(self): if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None: - logger.warning("Couldn't find context length in config.json, assuming default value of 512") + logger.warning( + "Couldn't find context length in config.json, assuming default value of 512" + ) n_ctx = 512 self.gguf_writer.add_context_length(n_ctx) self.gguf_writer.add_embedding_length(self.hparams["d_model"]) @@ -5542,23 +6595,33 @@ def set_gguf_parameters(self): self.gguf_writer.add_key_length(self.hparams["d_kv"]) self.gguf_writer.add_value_length(self.hparams["d_kv"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) - self.gguf_writer.add_relative_attn_buckets_count(self.hparams["relative_attention_num_buckets"]) + self.gguf_writer.add_relative_attn_buckets_count( + self.hparams["relative_attention_num_buckets"] + ) self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_file_type(self.ftype) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight", # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder # and decoder and ignore the remaining ones. - if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]: + if name in [ + "decoder.embed_tokens.weight", + "encoder.embed_tokens.weight", + "shared.weight", + ]: if not self.shared_token_embeddings_found: name = "shared.weight" self.shared_token_embeddings_found = True else: - logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.") + logger.debug( + f"Skipping shared tensor {name!r} in safetensors so that convert can end normally." + ) return [] return [(self.map_tensor_name(name), data_torch)] @@ -5578,19 +6641,21 @@ def __init__(self, *args, **kwargs): # Embeddings scale self.embeddings_scale = 1.0 - if 'mup_embeddings_scale' in self.hparams: - self.embeddings_scale = self.hparams['mup_embeddings_scale'] - elif 'embeddings_scale' in self.hparams: - self.embeddings_scale = self.hparams['embeddings_scale'] + if "mup_embeddings_scale" in self.hparams: + self.embeddings_scale = self.hparams["mup_embeddings_scale"] + elif "embeddings_scale" in self.hparams: + self.embeddings_scale = self.hparams["embeddings_scale"] else: assert False self.width_scale = 1.0 - if 'mup_output_alpha' in self.hparams: - assert 'mup_width_scale' in self.hparams - self.width_scale = self.hparams['mup_output_alpha'] * self.hparams['mup_width_scale'] - elif 'width_scale' in self.hparams: - self.width_scale = self.hparams['width_scale'] + if "mup_output_alpha" in self.hparams: + assert "mup_width_scale" in self.hparams + self.width_scale = ( + self.hparams["mup_output_alpha"] * self.hparams["mup_width_scale"] + ) + elif "width_scale" in self.hparams: + self.width_scale = self.hparams["width_scale"] else: assert False @@ -5608,7 +6673,9 @@ def set_gguf_parameters(self): self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"]) self.gguf_writer.add_file_type(self.ftype) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused tensors: list[tuple[str, Tensor]] = [] @@ -5628,7 +6695,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return tensors - if name.endswith((".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_fc2.weight")): + if name.endswith( + (".c_attn.weight", ".c_proj.weight", ".c_fc.weight", ".c_fc2.weight") + ): data_torch = data_torch.transpose(1, 0) new_name = self.map_tensor_name(name) @@ -5653,7 +6722,10 @@ class Glm4Model(TextModel): def set_vocab(self): from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + + tokenizer = AutoTokenizer.from_pretrained( + self.dir_model, trust_remote_code=True + ) special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) tokens, toktypes, tokpre = self.get_vocab_base() self.gguf_writer.add_tokenizer_model("gpt2") @@ -5661,21 +6733,34 @@ def set_vocab(self): self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_types(toktypes) special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) - special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token( + "eos", tokenizer.get_added_vocab()["<|endoftext|>"] + ) special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) - special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) - special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token( + "unk", tokenizer.get_added_vocab()["<|endoftext|>"] + ) + special_vocab._set_special_token( + "bos", tokenizer.get_added_vocab()["<|endoftext|>"] + ) special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): super().set_gguf_parameters() rope_dim = self.hparams["head_dim"] - self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))) + self.gguf_writer.add_rope_dimension_count( + int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)) + ) rope_scaling = self.hparams.get("rope_scaling") or {} - if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + if ( + rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" + and "factor" in rope_scaling + ): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) - self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len( + rope_scaling["original_max_position_embeddings"] + ) @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") @@ -5690,11 +6775,23 @@ def set_vocab_chatglm3(self): scores: list[float] = [] from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab())) assert max(tokenizer.get_vocab().values()) < vocab_size - role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"] - special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens + role_special_tokens = [ + "<|system|>", + "<|user|>", + "<|assistant|>", + "<|observation|>", + ] + special_tokens = [ + "[MASK]", + "[gMASK]", + "[sMASK]", + "sop", + "eop", + ] + role_special_tokens for token_id in range(vocab_size): piece = tokenizer._convert_id_to_token(token_id) if token_id == 0: @@ -5752,11 +6849,14 @@ def set_vocab_chatglm3(self): @staticmethod def token_bytes_to_string(b): from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode + byte_encoder = bytes_to_unicode() - return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')]) + return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")]) @staticmethod - def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]: + def bpe( + mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None + ) -> list[bytes]: parts = [bytes([b]) for b in token] while True: min_idx = None @@ -5769,7 +6869,11 @@ def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = if min_rank is None or (max_rank is not None and min_rank >= max_rank): break assert min_idx is not None - parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:] + parts = ( + parts[:min_idx] + + [parts[min_idx] + parts[min_idx + 1]] + + parts[min_idx + 2 :] + ) return parts def set_vocab(self): @@ -5783,8 +6887,9 @@ def set_vocab(self): toktypes: list[int] = [] from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True) - vocab_size = hparams.get("padded_vocab_size",hparams["vocab_size"]) + vocab_size = hparams.get("padded_vocab_size", hparams["vocab_size"]) assert max(tokenizer.get_vocab().values()) < vocab_size tokens, toktypes, tokpre = self.get_vocab_base() @@ -5794,39 +6899,61 @@ def set_vocab(self): self.gguf_writer.add_token_types(toktypes) special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) # only add special tokens when they were not already loaded from config.json - special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token( + "eos", tokenizer.get_added_vocab()["<|endoftext|>"] + ) special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # this one is usually not in config.json anyway - special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token( + "unk", tokenizer.get_added_vocab()["<|endoftext|>"] + ) special_vocab.add_to_gguf(self.gguf_writer) def set_gguf_parameters(self): n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed")) n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads")) - n_head_kv = self.hparams.get("multi_query_group_num", self.hparams.get("num_key_value_heads", n_head)) + n_head_kv = self.hparams.get( + "multi_query_group_num", self.hparams.get("num_key_value_heads", n_head) + ) self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed)) self.gguf_writer.add_embedding_length(n_embed) - self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed))) - self.gguf_writer.add_block_count(self.hparams.get("num_layers", self.hparams["num_hidden_layers"])) + self.gguf_writer.add_feed_forward_length( + self.hparams.get( + "ffn_hidden_size", self.hparams.get("intermediate_size", 4 * n_embed) + ) + ) + self.gguf_writer.add_block_count( + self.hparams.get("num_layers", self.hparams["num_hidden_layers"]) + ) self.gguf_writer.add_head_count(n_head) self.gguf_writer.add_head_count_kv(n_head_kv) - self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon",1e-5)) + self.gguf_writer.add_layer_norm_rms_eps( + self.hparams.get("layernorm_epsilon", 1e-5) + ) self.gguf_writer.add_file_type(self.ftype) if "attention_dim" in self.hparams: rope_dim = self.hparams["attention_dim"] else: - rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] - self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))) + rope_dim = ( + self.hparams["hidden_size"] // self.hparams["num_attention_heads"] + ) + self.gguf_writer.add_rope_dimension_count( + int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)) + ) self.gguf_writer.add_add_bos_token(False) rope_freq = 10000 if "rope_ratio" in self.hparams: rope_freq = rope_freq * self.hparams["rope_ratio"] self.gguf_writer.add_rope_freq_base(rope_freq) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused - if name.endswith(".rotary_pos_emb.inv_freq") or name.startswith("model.vision."): + if name.endswith(".rotary_pos_emb.inv_freq") or name.startswith( + "model.vision." + ): return [] name = name.removeprefix("transformer.") @@ -5847,11 +6974,15 @@ def set_gguf_parameters(self): hparams = self.hparams self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - f_norm_eps = self.find_hparam(["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon", "norm_eps"]) + f_norm_eps = self.find_hparam( + ["layer_norm_eps", "layer_norm_epsilon", "norm_epsilon", "norm_eps"] + ) self.gguf_writer.add_layer_norm_eps(f_norm_eps) # * Partial RoPE - rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"]) + rot_pct = self.find_hparam( + ["partial_rotary_factor", "rope_pct", "rope_percent"] + ) n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) @@ -5863,7 +6994,9 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"]) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: # * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side # model.layers.{l}.input_layernorm.weight # model.layers.{l}.post_attention_layernorm.weight @@ -5881,14 +7014,18 @@ class ExaoneModel(TextModel): def set_gguf_parameters(self): hparams = self.hparams - assert (hparams["activation_function"] == "silu") + assert hparams["activation_function"] == "silu" max_position_embeddings = hparams["max_position_embeddings"] embed_dim = hparams["hidden_size"] num_heads = hparams["num_attention_heads"] num_kv_heads = hparams.get("num_key_value_heads", num_heads) layer_norm_eps = hparams["layer_norm_epsilon"] - intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim + intermediate_size = ( + hparams["intermediate_size"] + if "intermediate_size" in hparams + else 4 * embed_dim + ) num_layers = hparams["num_layers"] # ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0 # attention_dropout_rate = hparams["attention_dropout"] @@ -5905,26 +7042,43 @@ def set_gguf_parameters(self): if (rope_theta := self.hparams.get("rope_theta")) is not None: self.gguf_writer.add_rope_freq_base(rope_theta) - rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True) + rotary_factor = self.find_hparam( + ["partial_rotary_factor", "rope_pct"], optional=True + ) rotary_factor = rotary_factor if rotary_factor is not None else 1.0 - self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) + self.gguf_writer.add_rope_dimension_count( + int( + rotary_factor + * (hparams["hidden_size"] // hparams["num_attention_heads"]) + ) + ) rope_scaling = self.hparams.get("rope_scaling") or {} - if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling: + if ( + rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" + and "factor" in rope_scaling + ): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: if rope_scaling := self.find_hparam(["rope_scaling"], optional=True): - if rope_scaling.get("rope_type", '').lower() == "llama3": + if rope_scaling.get("rope_type", "").lower() == "llama3": base = self.hparams.get("rope_theta", 10000.0) if (dim := self.hparams.get("head_dim")) is None: - dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] - freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + dim = ( + self.hparams["hidden_size"] + // self.hparams["num_attention_heads"] + ) + freqs = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim) + ) factor = rope_scaling.get("factor", 8.0) low_freq_factor = rope_scaling.get("low_freq_factor", 1.0) high_freq_factor = rope_scaling.get("high_freq_factor", 4.0) - old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + old_context_len = self.hparams.get( + "original_max_position_embeddings", 8192 + ) low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor @@ -5938,15 +7092,21 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: elif wavelen > low_freq_wavelen: rope_factors.append(factor) else: - smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smooth = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) rope_factors.append(1 / ((1 - smooth) / factor + smooth)) - yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32)) + yield ( + self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), + torch.tensor(rope_factors, dtype=torch.float32), + ) @ModelBase.register("GraniteForCausalLM") class GraniteModel(LlamaModel): """Conversion for IBM's GraniteForCausalLM""" + model_arch = gguf.MODEL_ARCH.GRANITE def set_gguf_parameters(self): @@ -5981,6 +7141,7 @@ def set_gguf_parameters(self): @ModelBase.register("GraniteMoeForCausalLM", "GraniteMoeSharedForCausalLM") class GraniteMoeModel(GraniteModel): """Conversion for IBM's GraniteMoeForCausalLM""" + model_arch = gguf.MODEL_ARCH.GRANITE_MOE def set_gguf_parameters(self): @@ -5989,10 +7150,17 @@ def set_gguf_parameters(self): """ super().set_gguf_parameters() if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"): - self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length) - logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length) + self.gguf_writer.add_expert_shared_feed_forward_length( + shared_feed_forward_length + ) + logger.info( + "gguf: (granitemoeshared) shared_feed_forward_length = %s", + shared_feed_forward_length, + ) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: """In modeling_granitemoe, the JetMoe implementation of parallel experts is used. This essentially merges w1 and w3 into a single tensor with 2x the hidden size that is then split during forward. To keep compatibility @@ -6001,7 +7169,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith("block_sparse_moe.input_linear.weight"): ffn_dim = self.hparams["intermediate_size"] - assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size" + assert ( + data_torch.shape[-2] == 2 * ffn_dim + ), "Merged FFN tensor size must be 2 * intermediate_size" gate, up = data_torch.split(ffn_dim, dim=-2) return [ (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate), @@ -6010,7 +7180,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith("shared_mlp.input_linear.weight"): ffn_dim = self.hparams["shared_intermediate_size"] - assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size" + assert ( + data_torch.shape[-2] == 2 * ffn_dim + ), "Merged FFN tensor size must be 2 * shared_intermediate_size" gate, up = data_torch.split(ffn_dim, dim=-2) return [ (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate), @@ -6035,15 +7207,22 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_dimension_count(rope_dim) rope_scaling = self.hparams.get("rope_scaling") or {} - if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + if ( + rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" + and "factor" in rope_scaling + ): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) - self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len( + rope_scaling["original_max_position_embeddings"] + ) else: self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) - self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_feed_forward_length( + hparams["moe_intermediate_size"] + ) self.gguf_writer.add_expert_weights_scale(1.0) self.gguf_writer.add_expert_count(hparams["num_experts"]) self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"]) @@ -6055,11 +7234,17 @@ def set_gguf_parameters(self): def permute(weights: Tensor, n_head: int, n_head_kv: int | None): if n_head_kv is not None and n_head != n_head_kv: n_head = n_head_kv - return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape)) + return ( + weights.reshape( + n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:] + ) + .swapaxes(1, 2) + .reshape(weights.shape) + ) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") n_embd = self.hparams["hidden_size"] @@ -6069,14 +7254,24 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT) if name.endswith("attention.dense.weight"): - return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), data_torch)] + return [ + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), data_torch) + ] elif name.endswith("query_key_value.weight"): - q, k, v = data_torch.split([n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-2) + q, k, v = data_torch.split( + [n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-2 + ) return [ - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), BailingMoeModel.permute(q, n_head, n_head)), - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), BailingMoeModel.permute(k, n_head, n_kv_head)), - (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v) + ( + self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), + BailingMoeModel.permute(q, n_head, n_head), + ), + ( + self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), + BailingMoeModel.permute(k, n_head, n_kv_head), + ), + (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v), ] elif name.find("mlp.experts") != -1: n_experts = self.hparams["num_experts"] @@ -6139,7 +7334,9 @@ def set_gguf_parameters(self): def set_vocab(self): self._set_vocab_gpt2() - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: # ignore image tokenizer for now # TODO: remove this once image support is implemented for Chameleon if name.startswith("model.vqmodel"): @@ -6154,9 +7351,13 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith(("k_proj.weight", "k_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head) if name.endswith(("q_norm.weight", "q_norm.bias")): - data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_head, hidden_dim) + data_torch = ChameleonModel._reverse_hf_permute( + data_torch, n_head, hidden_dim + ) if name.endswith(("k_norm.weight", "k_norm.bias")): - data_torch = ChameleonModel._reverse_hf_permute(data_torch, n_kv_head, hidden_dim) + data_torch = ChameleonModel._reverse_hf_permute( + data_torch, n_kv_head, hidden_dim + ) return [(self.map_tensor_name(name), data_torch)] @@ -6171,16 +7372,18 @@ def _reverse_hf_permute(data_torch, n_heads, hidden_dim): @ModelBase.register("UltravoxModel") class UltravoxModel(TextModel): - model_arch = gguf.MODEL_ARCH.LLAMA # dummy + model_arch = gguf.MODEL_ARCH.LLAMA # dummy def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - raise NotImplementedError("Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument") + raise NotImplementedError( + "Ultravox does not have text decoder. Instead, it uses Llama or other models for text. If you want to get the audio encoder, please use --mmproj argument" + ) @ModelBase.register("Qwen2AudioForConditionalGeneration") class WhisperEncoderModel(MmprojModel): - has_vision_encoder = False # no vision encoder + has_vision_encoder = False # no vision encoder has_audio_encoder = True def __init__(self, *args, **kwargs): @@ -6193,7 +7396,9 @@ def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.QWEN2A) self.gguf_writer.add_audio_num_mel_bins(self.hparams["num_mel_bins"]) - self.gguf_writer.add_audio_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5)) + self.gguf_writer.add_audio_attention_layernorm_eps( + self.hparams.get("layer_norm_eps", 1e-5) + ) def tensor_force_quant(self, name, new_name, bid, n_dims): del bid, new_name, n_dims # unused @@ -6201,7 +7406,9 @@ def tensor_force_quant(self, name, new_name, bid, n_dims): return gguf.GGMLQuantizationType.F16 return False - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: del bid # unused if name.startswith("language_model."): @@ -6221,41 +7428,62 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter @ModelBase.register("UltravoxModel") class UltravoxWhisperEncoderModel(WhisperEncoderModel): - has_vision_encoder = False # no vision encoder + has_vision_encoder = False # no vision encoder has_audio_encoder = True def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"]) + @ModelBase.register("SmallThinkerForCausalLM") class SmallThinkerMoeModel(TextModel): model_arch = gguf.MODEL_ARCH.SMALLTHINKERMOE def set_gguf_parameters(self): super().set_gguf_parameters() - if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None: + if ( + n_experts := self.hparams.get( + "num_experts", self.hparams.get("moe_num_primary_experts") + ) + ) is not None: self.gguf_writer.add_expert_count(n_experts) - if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None: + if ( + n_experts_used := self.hparams.get( + "num_experts_per_tok", + self.hparams.get("moe_num_active_primary_experts"), + ) + ) is not None: self.gguf_writer.add_expert_used_count(n_experts_used) - if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None: + if ( + moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size") + ) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) self.gguf_writer.add_feed_forward_length(moe_intermediate_size) logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}") - if (router_apply_softmax := self.hparams.get('moe_primary_router_apply_softmax')): + if router_apply_softmax := self.hparams.get("moe_primary_router_apply_softmax"): if router_apply_softmax: - self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) + self.gguf_writer.add_expert_gating_func( + gguf.ExpertGatingFuncType.SOFTMAX + ) else: - self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) - logger.info(f'gguf: router apply softmax = {router_apply_softmax}') - + self.gguf_writer.add_expert_gating_func( + gguf.ExpertGatingFuncType.SIGMOID + ) + logger.info(f"gguf: router apply softmax = {router_apply_softmax}") + # YaRN is not enabled by default # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts rope_scaling = self.hparams.get("rope_scaling") or {} - if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + if ( + rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" + and "factor" in rope_scaling + ): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) - self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len( + rope_scaling["original_max_position_embeddings"] + ) sliding_window_layout = self.hparams.get("sliding_window_layout") if sliding_window_layout: @@ -6263,12 +7491,14 @@ def set_gguf_parameters(self): if i != 0: sliding_window = self.hparams.get("sliding_window_size") self.gguf_writer.add_sliding_window(sliding_window) - logger.info(f'gguf: sliding window = True') + logger.info(f"gguf: sliding window = True") break _experts: list[dict[str, Tensor]] | None = None - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: # process the experts separately print(f"Processing tensor: {name} in block {bid}") @@ -6286,7 +7516,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter print(f"Transposed tensor: {name}, shape: {data_torch.shape}") if name.find("experts") != -1: - n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts")) + n_experts = self.hparams.get( + "num_experts", self.hparams.get("moe_num_primary_experts") + ) assert bid is not None if self._experts is None: @@ -6310,7 +7542,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter data_torch = torch.stack(datas, dim=0) # merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" - merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight" + merged_name = ( + f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight" + ) new_name = self.map_tensor_name(merged_name) @@ -6335,13 +7569,18 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: toktypes: list[int] = [] from transformers import AutoTokenizer - tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True) + + tokenizer = AutoTokenizer.from_pretrained( + self.dir_model, trust_remote_code=True + ) vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) assert max(tokenizer.vocab.values()) < vocab_size tokpre = self.get_vocab_base_pre(tokenizer) - reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()} + reverse_vocab = { + id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items() + } added_vocab = tokenizer.get_added_vocab() added_tokens_decoder = tokenizer.added_tokens_decoder @@ -6357,16 +7596,24 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: # To avoid unexpected issues - we make sure to normalize non-normalized tokens if not added_tokens_decoder[i].normalized: previous_token = token - token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False)) + token = tokenizer.decode( + tokenizer.encode(token, add_special_tokens=False) + ) if previous_token != token: - logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer") + logger.info( + f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer" + ) - if added_tokens_decoder[i].special or self.does_token_look_special(token): + if added_tokens_decoder[i].special or self.does_token_look_special( + token + ): toktypes.append(gguf.TokenType.CONTROL) else: # NOTE: this was added for Gemma. # Encoding and decoding the tokens above isn't sufficient for this case. - token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ") # pre-normalize user-defined spaces + token = token.replace( + b"\xe2\x96\x81".decode("utf-8"), " " + ) # pre-normalize user-defined spaces toktypes.append(gguf.TokenType.USER_DEFINED) else: toktypes.append(gguf.TokenType.NORMAL) @@ -6374,6 +7621,7 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: return tokens, toktypes, tokpre + # -- PowerInfer end @@ -6419,18 +7667,24 @@ def numpy(self) -> gguf.LazyNumpyTensor: return gguf.LazyNumpyTensor( meta=gguf.LazyNumpyTensor.meta_with_dtype_and_shape(dtype, self.shape), args=(self,), - func=(lambda s: s.numpy()) + func=(lambda s: s.numpy()), ) @classmethod - def meta_with_dtype_and_shape(cls, dtype: torch.dtype, shape: tuple[int, ...]) -> Tensor: + def meta_with_dtype_and_shape( + cls, dtype: torch.dtype, shape: tuple[int, ...] + ) -> Tensor: return torch.empty(size=shape, dtype=dtype, device="meta") @classmethod def from_safetensors_slice(cls, st_slice: Any) -> Tensor: dtype = cls._dtype_str_map[st_slice.get_dtype()] shape: tuple[int, ...] = tuple(st_slice.get_shape()) - lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(st_slice,), func=lambda s: s[:]) + lazy = cls( + meta=cls.meta_with_dtype_and_shape(dtype, shape), + args=(st_slice,), + func=lambda s: s[:], + ) return cast(torch.Tensor, lazy) @classmethod @@ -6438,7 +7692,11 @@ def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor): dtype = cls._dtype_str_map[remote_tensor.dtype] shape = remote_tensor.shape meta = cls.meta_with_dtype_and_shape(dtype, shape) - lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape)) + lazy = cls( + meta=meta, + args=(remote_tensor,), + func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape), + ) return cast(torch.Tensor, lazy) @classmethod @@ -6456,81 +7714,106 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Convert a huggingface model to a GGML compatible file") + description="Convert a huggingface model to a GGML compatible file" + ) parser.add_argument( - "--vocab-only", action="store_true", + "--vocab-only", + action="store_true", help="extract only the vocab", ) parser.add_argument( - "--outfile", type=Path, + "--outfile", + type=Path, help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16", + "--outtype", + type=str, + choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], + default="f16", help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( - "--bigendian", action="store_true", + "--bigendian", + action="store_true", help="model is executed on big endian machine", ) parser.add_argument( - "model", type=str, + "model", + type=str, help="directory containing model file or huggingface repository ID (if --remote)", nargs="?", ) parser.add_argument( - "--use-temp-file", action="store_true", + "--use-temp-file", + action="store_true", help="use the tempfile library while processing (helpful when running out of memory, process killed)", ) parser.add_argument( - "--no-lazy", action="store_true", + "--no-lazy", + action="store_true", help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)", ) parser.add_argument( - "--model-name", type=str, default=None, + "--model-name", + type=str, + default=None, help="name of the model", ) parser.add_argument( - "--verbose", action="store_true", + "--verbose", + action="store_true", help="increase output verbosity", ) parser.add_argument( - "--split-max-tensors", type=int, default=0, + "--split-max-tensors", + type=int, + default=0, help="max tensors in each split", ) parser.add_argument( - "--split-max-size", type=str, default="0", + "--split-max-size", + type=str, + default="0", help="max size per split N(M|G)", ) parser.add_argument( - "--dry-run", action="store_true", + "--dry-run", + action="store_true", help="only print out a split plan and exit, without writing any new files", ) parser.add_argument( - "--no-tensor-first-split", action="store_true", - help="do not add tensors to the first split (disabled by default)" + "--no-tensor-first-split", + action="store_true", + help="do not add tensors to the first split (disabled by default)", ) parser.add_argument( - "--metadata", type=Path, - help="Specify the path for an authorship metadata override file" + "--metadata", + type=Path, + help="Specify the path for an authorship metadata override file", ) parser.add_argument( - "--print-supported-models", action="store_true", - help="Print the supported models" + "--print-supported-models", + action="store_true", + help="Print the supported models", ) parser.add_argument( - "--remote", action="store_true", + "--remote", + action="store_true", help="(Experimental) Read safetensors file remotely without downloading to disk. Config and tokenizer files will still be downloaded. To use this feature, you need to specify Hugging Face model repo name instead of a local directory. For example: 'HuggingFaceTB/SmolLM2-1.7B-Instruct'. Note: To access gated repo, set HF_TOKEN environment variable to your Hugging Face token.", ) parser.add_argument( - "--mmproj", action="store_true", + "--mmproj", + action="store_true", help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.", ) - # -- PowerInfer parser.add_argument( - "--transpose-down", type=str, choices=["none", "dense", "moe", "all"], default="none", + "--transpose-down", + type=str, + choices=["none", "dense", "moe", "all"], + default="none", help="transpose down projection in dense layers, choices: none, dense, moe, all. ", ) # -- PowerInfer end @@ -6551,7 +7834,9 @@ def split_str_to_n_bytes(split_str: str) -> int: elif split_str.isnumeric(): n = int(split_str) else: - raise ValueError(f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G") + raise ValueError( + f"Invalid split size: {split_str}, must be a number, optionally followed by K, M, or G" + ) if n < 0: raise ValueError(f"Invalid split size: {split_str}, must be positive") @@ -6568,7 +7853,10 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st # if "architectures" is found in the sub-config, use that instead if model_type == ModelType.TEXT and text_config.get("architectures") is not None: arch = text_config["architectures"][0] - elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None: + elif ( + model_type == ModelType.MMPROJ + and vision_config.get("architectures") is not None + ): arch = vision_config["architectures"][0] return arch @@ -6589,9 +7877,11 @@ def main() -> None: if args.remote: hf_repo_id = args.model from huggingface_hub import snapshot_download + local_dir = snapshot_download( repo_id=hf_repo_id, - allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"]) + allow_patterns=["LICENSE", "*.json", "*.md", "*.txt", "tokenizer.model"], + ) dir_model = Path(local_dir) logger.info(f"Downloaded config and tokenizer to {local_dir}") else: @@ -6599,7 +7889,7 @@ def main() -> None: dir_model = Path(args.model) if not dir_model.is_dir(): - logger.error(f'Error: {dir_model} is not a directory') + logger.error(f"Error: {dir_model} is not a directory") sys.exit(1) ftype_map: dict[str, gguf.LlamaFileType] = { @@ -6638,36 +7928,48 @@ def main() -> None: model_architecture = get_model_architecture(hparams, model_type) logger.info(f"Model architecture: {model_architecture}") try: - model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type) + model_class = ModelBase.from_model_architecture( + model_architecture, model_type=model_type + ) except NotImplementedError: logger.error(f"Model {model_architecture} is not supported") sys.exit(1) - model_instance = model_class(dir_model, output_type, fname_out, - is_big_endian=args.bigendian, use_temp_file=args.use_temp_file, - eager=args.no_lazy, - metadata_override=args.metadata, model_name=args.model_name, - split_max_tensors=args.split_max_tensors, - split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, - small_first_shard=args.no_tensor_first_split, - remote_hf_model_id=hf_repo_id, - - # -- PowerInfer - transpose_down=args.transpose_down - # -- PowerInfer end - - ) + model_instance = model_class( + dir_model, + output_type, + fname_out, + is_big_endian=args.bigendian, + use_temp_file=args.use_temp_file, + eager=args.no_lazy, + metadata_override=args.metadata, + model_name=args.model_name, + split_max_tensors=args.split_max_tensors, + split_max_size=split_str_to_n_bytes(args.split_max_size), + dry_run=args.dry_run, + small_first_shard=args.no_tensor_first_split, + remote_hf_model_id=hf_repo_id, + # -- PowerInfer + transpose_down=args.transpose_down, + # -- PowerInfer end + ) if args.vocab_only: logger.info("Exporting model vocab...") model_instance.write_vocab() - logger.info(f"Model vocab successfully exported to {model_instance.fname_out}") + logger.info( + f"Model vocab successfully exported to {model_instance.fname_out}" + ) else: logger.info("Exporting model...") model_instance.write() - out_path = f"{model_instance.fname_out.parent}{os.sep}" if is_split else model_instance.fname_out + out_path = ( + f"{model_instance.fname_out.parent}{os.sep}" + if is_split + else model_instance.fname_out + ) logger.info(f"Model successfully exported to {out_path}") -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/smallthinker/convert_hf_to_gguf_update.py b/smallthinker/convert_hf_to_gguf_update.py index 2f733f09..79b46ce8 100755 --- a/smallthinker/convert_hf_to_gguf_update.py +++ b/smallthinker/convert_hf_to_gguf_update.py @@ -23,7 +23,9 @@ convert_py_pth = pathlib.Path("convert_hf_to_gguf.py") convert_py = convert_py_pth.read_text(encoding="utf-8") hf_token_pth = pathlib.Path.home() / ".cache" / "huggingface" / "token" -hf_token = hf_token_pth.read_text(encoding="utf-8").strip() if hf_token_pth.exists() else None +hf_token = ( + hf_token_pth.read_text(encoding="utf-8").strip() if hf_token_pth.exists() else None +) class TOKENIZER_TYPE(IntEnum): @@ -55,9 +57,12 @@ class TOKENIZER_TYPE(IntEnum): """ # TODO: generate tokenizer tests for llama.cpp -parser = argparse.ArgumentParser(description=DOC_STRING, formatter_class=argparse.RawTextHelpFormatter) +parser = argparse.ArgumentParser( + description=DOC_STRING, formatter_class=argparse.RawTextHelpFormatter +) parser.add_argument( - "--full", action="store_true", + "--full", + action="store_true", help="download full list of models - make sure you have access to all of them", ) parser.add_argument( @@ -69,74 +74,296 @@ class TOKENIZER_TYPE(IntEnum): hf_token = args.hf_token if args.hf_token is not None else hf_token if hf_token is None: - logger.error("HF token is required. Please provide it as an argument or set it in ~/.cache/huggingface/token") + logger.error( + "HF token is required. Please provide it as an argument or set it in ~/.cache/huggingface/token" + ) sys.exit(1) # TODO: this string has to exercise as much pre-tokenizer functionality as possible # will be updated with time - contributions welcome -CHK_TXT = '\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български \'\'\'\'\'\'```````\"\"\"\"......!!!!!!?????? I\'ve been \'told he\'s there, \'RE you sure? \'M not sure I\'ll make it, \'D you like some tea? We\'Ve a\'lL' +CHK_TXT = "\n \n\n \n\n\n \t \t\t \t\n \n \n \n \n🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ 🦙🦙 3 33 333 3333 33333 333333 3333333 33333333 3.3 3..3 3...3 កាន់តែពិសេសអាច😁 ?我想在apple工作1314151天~ ------======= нещо на Български ''''''```````\"\"\"\"......!!!!!!?????? I've been 'told he's there, 'RE you sure? 'M not sure I'll make it, 'D you like some tea? We'Ve a'lL" # TODO: add models here, base models preferred models = [ - {"name": "llama-spm", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", }, - {"name": "llama-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Meta-Llama-3-8B", }, - {"name": "phi-3", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", }, - {"name": "deepseek-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-llm-7b-base", }, - {"name": "deepseek-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", }, - {"name": "falcon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/falcon-7b", }, - {"name": "bert-bge", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", }, - {"name": "falcon3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon3-7B-Base", }, - {"name": "bert-bge-large", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/BAAI/bge-large-zh-v1.5", }, - {"name": "mpt", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mosaicml/mpt-7b", }, - {"name": "starcoder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigcode/starcoder2-3b", }, - {"name": "gpt-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/openai-community/gpt2", }, - {"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", }, - {"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", }, - {"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", }, - {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, - {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, - {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, - {"name": "jina-v1-en", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", }, - {"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM! - {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", }, - {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", }, - {"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", }, - {"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", }, - {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", }, - {"name": "viking", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Viking-7B", }, # Also used for Viking 13B and 33B - {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", }, - {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", }, - {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", }, - {"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", }, - {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", }, - {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", }, - {"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", }, - {'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", }, - {'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", }, - {"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", }, - {"name": "phi-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/microsoft/phi-2", }, - {"name": "chameleon", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/facebook/chameleon-7b", }, - {"name": "roberta-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base"}, - {"name": "gigachat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct"}, - {"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"}, - {"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"}, - {"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"}, - {"name": "gpt-4o", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Xenova/gpt-4o", }, - {"name": "superbpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", }, - {"name": "trillion", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", }, - {"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", }, - {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", }, - {"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", }, - {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", }, + { + "name": "llama-spm", + "tokt": TOKENIZER_TYPE.SPM, + "repo": "https://huggingface.co/meta-llama/Llama-2-7b-hf", + }, + { + "name": "llama-bpe", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/meta-llama/Meta-Llama-3-8B", + }, + { + "name": "phi-3", + "tokt": TOKENIZER_TYPE.SPM, + "repo": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct", + }, + { + "name": "deepseek-llm", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/deepseek-ai/deepseek-llm-7b-base", + }, + { + "name": "deepseek-coder", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-base", + }, + { + "name": "falcon", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/tiiuae/falcon-7b", + }, + { + "name": "bert-bge", + "tokt": TOKENIZER_TYPE.WPM, + "repo": "https://huggingface.co/BAAI/bge-small-en-v1.5", + }, + { + "name": "falcon3", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/tiiuae/Falcon3-7B-Base", + }, + { + "name": "bert-bge-large", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/BAAI/bge-large-zh-v1.5", + }, + { + "name": "mpt", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/mosaicml/mpt-7b", + }, + { + "name": "starcoder", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/bigcode/starcoder2-3b", + }, + { + "name": "gpt-2", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/openai-community/gpt2", + }, + { + "name": "stablelm2", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", + }, + { + "name": "refact", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", + }, + { + "name": "command-r", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", + }, + { + "name": "qwen2", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", + }, + { + "name": "olmo", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", + }, + { + "name": "dbrx", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/databricks/dbrx-base", + }, + { + "name": "jina-v1-en", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/jinaai/jina-reranker-v1-tiny-en", + }, + { + "name": "jina-v2-en", + "tokt": TOKENIZER_TYPE.WPM, + "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", + }, # WPM! + { + "name": "jina-v2-es", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", + }, + { + "name": "jina-v2-de", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", + }, + { + "name": "smaug-bpe", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", + }, + { + "name": "poro-chat", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", + }, + { + "name": "jina-v2-code", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", + }, + { + "name": "viking", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/LumiOpen/Viking-7B", + }, # Also used for Viking 13B and 33B + { + "name": "gemma", + "tokt": TOKENIZER_TYPE.SPM, + "repo": "https://huggingface.co/google/gemma-2b", + }, + { + "name": "gemma-2", + "tokt": TOKENIZER_TYPE.SPM, + "repo": "https://huggingface.co/google/gemma-2-9b", + }, + { + "name": "jais", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/core42/jais-13b", + }, + { + "name": "t5", + "tokt": TOKENIZER_TYPE.UGM, + "repo": "https://huggingface.co/google-t5/t5-small", + }, + { + "name": "codeshell", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", + }, + { + "name": "tekken", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", + }, + { + "name": "smollm", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", + }, + { + "name": "bloom", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/bigscience/bloom", + }, + { + "name": "gpt3-finnish", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", + }, + { + "name": "exaone", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", + }, + { + "name": "phi-2", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/microsoft/phi-2", + }, + { + "name": "chameleon", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/facebook/chameleon-7b", + }, + { + "name": "roberta-bpe", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/sentence-transformers/stsb-roberta-base", + }, + { + "name": "gigachat", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct", + }, + { + "name": "megrez", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct", + }, + { + "name": "deepseek-v3", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3", + }, + { + "name": "deepseek-r1-qwen", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", + }, + { + "name": "gpt-4o", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/Xenova/gpt-4o", + }, + { + "name": "superbpe", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/UW/OLMo2-8B-SuperBPE-t180k", + }, + { + "name": "trillion", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/trillionlabs/Trillion-7B-preview", + }, + { + "name": "bailingmoe", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/inclusionAI/Ling-lite", + }, + { + "name": "llama4", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", + }, + { + "name": "pixtral", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/mistral-community/pixtral-12b", + }, + { + "name": "seed-coder", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", + }, ] # some models are known to be broken upstream, so we will skip them as exceptions pre_computed_hashes = [ # chatglm-bpe has 2 hashes, why? - {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b"}, - {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, - {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, - {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, + { + "name": "chatglm-bpe", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", + "chkhsh": "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b", + }, + { + "name": "chatglm-bpe", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", + "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516", + }, + { + "name": "glm4", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", + "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2", + }, + { + "name": "minerva-7b", + "tokt": TOKENIZER_TYPE.BPE, + "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", + "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35", + }, ] @@ -145,7 +372,7 @@ def download_file_with_auth(url, token, save_path): response = sess.get(url, headers=headers) response.raise_for_status() os.makedirs(os.path.dirname(save_path), exist_ok=True) - with open(save_path, 'wb') as downloaded_file: + with open(save_path, "wb") as downloaded_file: downloaded_file.write(response.content) logger.info(f"File {save_path} downloaded successfully") @@ -247,11 +474,15 @@ def get_existing_models(convert_py): try: logger.info(f"Loading tokenizer from {f'models/tokenizers/{name}'}...") if name == "t5": - tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False) + tokenizer = AutoTokenizer.from_pretrained( + f"models/tokenizers/{name}", use_fast=False + ) else: tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") except OSError as e: - logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}") + logger.error( + f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}" + ) continue # Skip to the next model if the tokenizer can't be loaded chktok = tokenizer.encode(CHK_TXT) @@ -264,20 +495,25 @@ def get_existing_models(convert_py): logger.info(f"chkhsh: {chkhsh}") # print the "pre_tokenizer" content from the tokenizer.json - with open(f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8") as f: + with open( + f"models/tokenizers/{name}/tokenizer.json", "r", encoding="utf-8" + ) as f: cfg = json.load(f) normalizer = cfg["normalizer"] logger.info("normalizer: " + json.dumps(normalizer, indent=4)) pre_tokenizer = cfg["pre_tokenizer"] logger.info("pre_tokenizer: " + json.dumps(pre_tokenizer, indent=4)) if "ignore_merges" in cfg["model"]: - logger.info("ignore_merges: " + json.dumps(cfg["model"]["ignore_merges"], indent=4)) + logger.info( + "ignore_merges: " + + json.dumps(cfg["model"]["ignore_merges"], indent=4) + ) logger.info("") - src_ifs += f" if chkhsh == \"{chkhsh}\":\n" + src_ifs += f' if chkhsh == "{chkhsh}":\n' src_ifs += f" # ref: {model['repo']}\n" - src_ifs += f" res = \"{name}\"\n" + src_ifs += f' res = "{name}"\n' src_func = f""" def get_vocab_base_pre(self, tokenizer) -> str: @@ -378,7 +614,7 @@ def get_vocab_base_pre(self, tokenizer) -> str: "3333333", "33333333", "333333333", - "Cửa Việt", # llama-bpe fails on this + "Cửa Việt", # llama-bpe fails on this " discards", CHK_TXT, ] @@ -408,7 +644,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: # create the tokenizer try: if name == "t5": - tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False) + tokenizer = AutoTokenizer.from_pretrained( + f"models/tokenizers/{name}", use_fast=False + ) else: tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}") except OSError as e: @@ -440,6 +678,8 @@ def get_vocab_base_pre(self, tokenizer) -> str: for model in models: name = model["name"] - print(f"python3 convert_hf_to_gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only") # noqa: NP100 + print( + f"python3 convert_hf_to_gguf.py models/tokenizers/{name}/ --outfile models/ggml-vocab-{name}.gguf --vocab-only" + ) # noqa: NP100 logger.info("\n") diff --git a/smallthinker/convert_llama_ggml_to_gguf.py b/smallthinker/convert_llama_ggml_to_gguf.py index 29b14e98..db14c188 100755 --- a/smallthinker/convert_llama_ggml_to_gguf.py +++ b/smallthinker/convert_llama_ggml_to_gguf.py @@ -11,8 +11,8 @@ import numpy as np -if 'NO_LOCAL_GGUF' not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +if "NO_LOCAL_GGUF" not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / "gguf-py")) import gguf logger = logging.getLogger("ggml-to-gguf") @@ -25,23 +25,23 @@ class GGMLFormat(IntEnum): class GGMLFType(IntEnum): - ALL_F32 = 0 - MOSTLY_F16 = 1 - MOSTLY_Q4_0 = 2 - MOSTLY_Q4_1 = 3 + ALL_F32 = 0 + MOSTLY_F16 = 1 + MOSTLY_Q4_0 = 2 + MOSTLY_Q4_1 = 3 MOSTLY_Q4_1_SOME_F16 = 4 - MOSTLY_Q8_0 = 7 - MOSTLY_Q5_0 = 8 - MOSTLY_Q5_1 = 9 - MOSTLY_Q2_K = 10 - MOSTLY_Q3_K_S = 11 - MOSTLY_Q3_K_M = 12 - MOSTLY_Q3_K_L = 13 - MOSTLY_Q4_K_S = 14 - MOSTLY_Q4_K_M = 15 - MOSTLY_Q5_K_S = 16 - MOSTLY_Q5_K_M = 17 - MOSTLY_Q6_K = 18 + MOSTLY_Q8_0 = 7 + MOSTLY_Q5_0 = 8 + MOSTLY_Q5_1 = 9 + MOSTLY_Q2_K = 10 + MOSTLY_Q3_K_S = 11 + MOSTLY_Q3_K_M = 12 + MOSTLY_Q3_K_L = 13 + MOSTLY_Q4_K_S = 14 + MOSTLY_Q4_K_M = 15 + MOSTLY_Q5_K_S = 16 + MOSTLY_Q5_K_M = 17 + MOSTLY_Q6_K = 18 class Hyperparameters: @@ -51,8 +51,8 @@ def __init__(self): self.ftype = GGMLFType.ALL_F32 def set_n_ff(self, model): - ff_tensor_idx = model.tensor_map.get(b'layers.0.feed_forward.w1.weight') - assert ff_tensor_idx is not None, 'Missing layer 0 FF tensor' + ff_tensor_idx = model.tensor_map.get(b"layers.0.feed_forward.w1.weight") + assert ff_tensor_idx is not None, "Missing layer 0 FF tensor" ff_tensor = model.tensors[ff_tensor_idx] self.n_ff = ff_tensor.dims[1] @@ -65,32 +65,32 @@ def load(self, data, offset): self.n_layer, self.n_rot, ftype, - ) = struct.unpack('<7I', data[offset:offset + (4 * 7)]) + ) = struct.unpack("<7I", data[offset : offset + (4 * 7)]) try: self.ftype = GGMLFType(ftype) except ValueError: - raise ValueError(f'Invalid ftype {ftype}') + raise ValueError(f"Invalid ftype {ftype}") return 4 * 7 def __str__(self): - return f'' + return f"" class Vocab: - def __init__(self, load_scores = True): + def __init__(self, load_scores=True): self.items = [] self.load_scores = load_scores def load(self, data, offset, n_vocab): orig_offset = offset for _ in range(n_vocab): - itemlen = struct.unpack('= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}' - assert name_len < 4096, 'Absurd tensor name length' + (n_dims, name_len, dtype) = struct.unpack("<3I", data[offset : offset + 12]) + assert n_dims >= 0 and n_dims <= 4, f"Invalid tensor dimensions {n_dims}" + assert name_len < 4096, "Absurd tensor name length" quant = gguf.GGML_QUANT_SIZES.get(dtype) - assert quant is not None, 'Unknown tensor type' + assert quant is not None, "Unknown tensor type" (blksize, tysize) = quant offset += 12 - self.dtype= gguf.GGMLQuantizationType(dtype) - self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)]) + self.dtype = gguf.GGMLQuantizationType(dtype) + self.dims = struct.unpack(f"<{n_dims}I", data[offset : offset + (4 * n_dims)]) offset += 4 * n_dims - self.name = bytes(data[offset:offset + name_len]) + self.name = bytes(data[offset : offset + name_len]) offset += name_len pad = ((offset + 31) & ~31) - offset if self.use_padding else 0 offset += pad @@ -143,52 +143,66 @@ def __init__(self): self.tensors = [] def validate_header(self, data, offset): - magic = bytes(data[offset:offset + 4]) - if magic == b'GGUF': - raise ValueError('File is already in GGUF format.') - if magic == b'lmgg': + magic = bytes(data[offset : offset + 4]) + if magic == b"GGUF": + raise ValueError("File is already in GGUF format.") + if magic == b"lmgg": self.file_format = GGMLFormat.GGML self.format_version = 1 return 4 - version = struct.unpack(' 3: - raise ValueError(f'Cannot handle unexpected GGJT file version {version}') + raise ValueError( + f"Cannot handle unexpected GGJT file version {version}" + ) self.file_format = GGMLFormat.GGJT self.format_version = version return 8 - raise ValueError(f"Unexpected file magic {magic!r}! This doesn't look like a GGML format file.") + raise ValueError( + f"Unexpected file magic {magic!r}! This doesn't look like a GGML format file." + ) def validate_conversion(self, ftype): - err = '' - if (self.file_format < GGMLFormat.GGJT or self.format_version < 2): + err = "" + if self.file_format < GGMLFormat.GGJT or self.format_version < 2: if ftype not in (GGMLFType.ALL_F32, GGMLFType.MOSTLY_F16): - err = 'Quantizations changed in GGJTv2. Can only convert unquantized GGML files older than GGJTv2.' - elif (self.file_format == GGMLFormat.GGJT and self.format_version == 2): - if ftype in (GGMLFType.MOSTLY_Q4_0, GGMLFType.MOSTLY_Q4_1, - GGMLFType.MOSTLY_Q4_1_SOME_F16, GGMLFType.MOSTLY_Q8_0): - err = 'Q4 and Q8 quantizations changed in GGJTv3.' + err = "Quantizations changed in GGJTv2. Can only convert unquantized GGML files older than GGJTv2." + elif self.file_format == GGMLFormat.GGJT and self.format_version == 2: + if ftype in ( + GGMLFType.MOSTLY_Q4_0, + GGMLFType.MOSTLY_Q4_1, + GGMLFType.MOSTLY_Q4_1_SOME_F16, + GGMLFType.MOSTLY_Q8_0, + ): + err = "Q4 and Q8 quantizations changed in GGJTv3." if len(err) > 0: - raise ValueError(f'{err} Sorry, your {self.file_format.name}v{self.format_version} file of type {ftype.name} is not eligible for conversion.') + raise ValueError( + f"{err} Sorry, your {self.file_format.name}v{self.format_version} file of type {ftype.name} is not eligible for conversion." + ) def load(self, data, offset): offset += self.validate_header(data, offset) hp = Hyperparameters() offset += hp.load(data, offset) - logger.info(f'* File format: {self.file_format.name}v{self.format_version} with ftype {hp.ftype.name}') + logger.info( + f"* File format: {self.file_format.name}v{self.format_version} with ftype {hp.ftype.name}" + ) self.validate_conversion(hp.ftype) - vocab = Vocab(load_scores = self.file_format > GGMLFormat.GGML) + vocab = Vocab(load_scores=self.file_format > GGMLFormat.GGML) offset += vocab.load(data, offset, hp.n_vocab) tensors: list[Tensor] = [] tensor_map = {} while offset < len(data): - tensor = Tensor(use_padding = self.file_format > GGMLFormat.GGMF) + tensor = Tensor(use_padding=self.file_format > GGMLFormat.GGMF) offset += tensor.load(data, offset) tensor_map[tensor.name] = len(tensors) tensors.append(tensor) @@ -201,7 +215,15 @@ def load(self, data, offset): class GGMLToGGUF: - def __init__(self, ggml_model, data, cfg, params_override = None, vocab_override = None, special_vocab = None): + def __init__( + self, + ggml_model, + data, + cfg, + params_override=None, + vocab_override=None, + special_vocab=None, + ): hp = ggml_model.hyperparameters self.model = ggml_model self.data = data @@ -220,17 +242,22 @@ def __init__(self, ggml_model, data, cfg, params_override = None, vocab_override for x in range(1, 256): if float(hp.n_head) / float(x) == gqa: n_kv_head = x - assert n_kv_head is not None, "Couldn't determine n_kv_head from GQA param" - logger.info(f'- Guessed n_kv_head = {n_kv_head} based on GQA {cfg.gqa}') + assert ( + n_kv_head is not None + ), "Couldn't determine n_kv_head from GQA param" + logger.info(f"- Guessed n_kv_head = {n_kv_head} based on GQA {cfg.gqa}") self.n_kv_head = n_kv_head - self.name_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, ggml_model.hyperparameters.n_layer) + self.name_map = gguf.get_tensor_name_map( + gguf.MODEL_ARCH.LLAMA, ggml_model.hyperparameters.n_layer + ) def save(self): - logger.info('* Preparing to save GGUF file') + logger.info("* Preparing to save GGUF file") gguf_writer = gguf.GGUFWriter( self.cfg.output, gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.LLAMA], - use_temp_file = False) + use_temp_file=False, + ) self.add_params(gguf_writer) self.add_vocab(gguf_writer) if self.special_vocab is not None: @@ -250,30 +277,30 @@ def add_params(self, gguf_writer): if cfg.desc is not None: desc = cfg.desc else: - desc = f'converted from legacy {self.model.file_format.name}v{self.model.format_version} {hp.ftype.name} format' + desc = f"converted from legacy {self.model.file_format.name}v{self.model.format_version} {hp.ftype.name} format" try: # Filenames aren't necessarily valid UTF8. name = cfg.name if cfg.name is not None else cfg.input.name except UnicodeDecodeError: name = None - logger.info('* Adding model parameters and KV items') + logger.info("* Adding model parameters and KV items") if name is not None: gguf_writer.add_name(name) gguf_writer.add_description(desc) gguf_writer.add_file_type(int(hp.ftype)) if self.params_override is not None: po = self.params_override - assert po.n_embd == hp.n_embd, 'Model hyperparams mismatch' - assert po.n_layer == hp.n_layer, 'Model hyperparams mismatch' - assert po.n_head == hp.n_head, 'Model hyperparams mismatch' - gguf_writer.add_context_length (po.n_ctx) - gguf_writer.add_embedding_length (po.n_embd) - gguf_writer.add_block_count (po.n_layer) - gguf_writer.add_feed_forward_length (po.n_ff) + assert po.n_embd == hp.n_embd, "Model hyperparams mismatch" + assert po.n_layer == hp.n_layer, "Model hyperparams mismatch" + assert po.n_head == hp.n_head, "Model hyperparams mismatch" + gguf_writer.add_context_length(po.n_ctx) + gguf_writer.add_embedding_length(po.n_embd) + gguf_writer.add_block_count(po.n_layer) + gguf_writer.add_feed_forward_length(po.n_ff) gguf_writer.add_rope_dimension_count(po.n_embd // po.n_head) - gguf_writer.add_head_count (po.n_head) - gguf_writer.add_head_count_kv (po.n_head_kv) - gguf_writer.add_layer_norm_rms_eps (po.f_norm_eps) + gguf_writer.add_head_count(po.n_head) + gguf_writer.add_head_count_kv(po.n_head_kv) + gguf_writer.add_layer_norm_rms_eps(po.f_norm_eps) return gguf_writer.add_context_length(cfg.context_length) gguf_writer.add_embedding_length(hp.n_embd) @@ -286,47 +313,50 @@ def add_params(self, gguf_writer): def add_vocab(self, gguf_writer): hp = self.model.hyperparameters - gguf_writer.add_tokenizer_model('llama') - gguf_writer.add_tokenizer_pre('default') + gguf_writer.add_tokenizer_model("llama") + gguf_writer.add_tokenizer_pre("default") tokens = [] scores = [] toktypes = [] if self.vocab_override is not None: vo = self.vocab_override - logger.info('* Adding vocab item(s)') - for (_, (vbytes, score, ttype)) in enumerate(vo.all_tokens()): + logger.info("* Adding vocab item(s)") + for _, (vbytes, score, ttype) in enumerate(vo.all_tokens()): tokens.append(vbytes) scores.append(score) toktypes.append(ttype) - assert len(tokens) == hp.n_vocab, \ - f'Override vocab has a different number of items than hyperparameters - override = {len(tokens)} but n_vocab={hp.n_vocab}' + assert ( + len(tokens) == hp.n_vocab + ), f"Override vocab has a different number of items than hyperparameters - override = {len(tokens)} but n_vocab={hp.n_vocab}" gguf_writer.add_token_list(tokens) gguf_writer.add_token_scores(scores) if len(toktypes) > 0: gguf_writer.add_token_types(toktypes) return - logger.info(f'* Adding {hp.n_vocab} vocab item(s)') - assert len(self.model.vocab.items) >= 3, 'Cannot handle unexpectedly short model vocab' - for (tokid, (vbytes, vscore)) in enumerate(self.model.vocab.items): - tt = 1 # Normal + logger.info(f"* Adding {hp.n_vocab} vocab item(s)") + assert ( + len(self.model.vocab.items) >= 3 + ), "Cannot handle unexpectedly short model vocab" + for tokid, (vbytes, vscore) in enumerate(self.model.vocab.items): + tt = 1 # Normal # Special handling for UNK, BOS, EOS tokens. if tokid <= 2: if tokid == 0: - vbytes = b'' + vbytes = b"" tt = 2 elif tokid == 1: - vbytes = b'' + vbytes = b"" tt = 3 else: - vbytes = b'' + vbytes = b"" tt = 3 elif len(vbytes) == 0: - tt = 3 # Control + tt = 3 # Control elif tokid >= 3 and tokid <= 258 and len(vbytes) == 1: - vbytes = bytes(f'<0x{vbytes[0]:02X}>', encoding = 'UTF-8') - tt = 6 # Byte + vbytes = bytes(f"<0x{vbytes[0]:02X}>", encoding="UTF-8") + tt = 6 # Byte else: - vbytes = vbytes.replace(b' ', b'\xe2\x96\x81') + vbytes = vbytes.replace(b" ", b"\xe2\x96\x81") toktypes.append(tt) tokens.append(vbytes) scores.append(vscore) @@ -340,11 +370,11 @@ def add_vocab(self, gguf_writer): def add_tensors(self, gguf_writer): tensor_map = self.name_map data = self.data - logger.info(f'* Adding {len(self.model.tensors)} tensor(s)') + logger.info(f"* Adding {len(self.model.tensors)} tensor(s)") for tensor in self.model.tensors: - name = str(tensor.name, 'UTF-8') - mapped_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias")) - assert mapped_name is not None, f'Bad name {name}' + name = str(tensor.name, "UTF-8") + mapped_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + assert mapped_name is not None, f"Bad name {name}" tempdims = list(tensor.dims[:]) if len(tempdims) > 1: temp = tempdims[1] @@ -352,99 +382,144 @@ def add_tensors(self, gguf_writer): tempdims[0] = temp gguf_writer.add_tensor( mapped_name, - data[tensor.start_offset:tensor.start_offset + tensor.len_bytes], - raw_shape = tempdims, - raw_dtype = tensor.dtype) + data[tensor.start_offset : tensor.start_offset + tensor.len_bytes], + raw_shape=tempdims, + raw_dtype=tensor.dtype, + ) def handle_metadata(cfg, hp): import examples.convert_legacy_llama as convert - assert cfg.model_metadata_dir.is_dir(), 'Metadata dir is not a directory' - hf_config_path = cfg.model_metadata_dir / "config.json" + assert cfg.model_metadata_dir.is_dir(), "Metadata dir is not a directory" + hf_config_path = cfg.model_metadata_dir / "config.json" orig_config_path = cfg.model_metadata_dir / "params.json" # We pass a fake model here. "original" mode will check the shapes of some # tensors if information is missing in the .json file: other than that, the # model data isn't used so this should be safe (at least for now). fakemodel = { - 'tok_embeddings.weight': convert.LazyTensor.__new__(convert.LazyTensor), - 'layers.0.feed_forward.w1.weight': convert.LazyTensor.__new__(convert.LazyTensor), + "tok_embeddings.weight": convert.LazyTensor.__new__(convert.LazyTensor), + "layers.0.feed_forward.w1.weight": convert.LazyTensor.__new__( + convert.LazyTensor + ), } - fakemodel['tok_embeddings.weight'].shape = [hp.n_vocab] - fakemodel['layers.0.feed_forward.w1.weight'].shape = [hp.n_ff] + fakemodel["tok_embeddings.weight"].shape = [hp.n_vocab] + fakemodel["layers.0.feed_forward.w1.weight"].shape = [hp.n_ff] if hf_config_path.exists(): params = convert.Params.loadHFTransformerJson(fakemodel, hf_config_path) elif orig_config_path.exists(): params = convert.Params.loadOriginalParamsJson(fakemodel, orig_config_path) else: - raise ValueError('Unable to load metadata') - vocab_path = Path(cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir) + raise ValueError("Unable to load metadata") + vocab_path = Path( + cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir + ) vocab_factory = convert.VocabFactory(vocab_path) - vocab, special_vocab = vocab_factory.load_vocab(cfg.vocabtype.split(","), cfg.model_metadata_dir) + vocab, special_vocab = vocab_factory.load_vocab( + cfg.vocabtype.split(","), cfg.model_metadata_dir + ) convert.check_vocab_size(params, vocab) return params, vocab, special_vocab def handle_args(): - parser = argparse.ArgumentParser(description = 'Convert GGML models to GGUF') - parser.add_argument('--input', '-i', type = Path, required = True, - help = 'Input GGMLv3 filename') - parser.add_argument('--output', '-o', type = Path, required = True, - help ='Output GGUF filename') - parser.add_argument('--name', - help = 'Set model name') - parser.add_argument('--desc', - help = 'Set model description') - parser.add_argument('--gqa', type = int, default = 1, - help = 'grouped-query attention factor (use 8 for LLaMA2 70B)') - parser.add_argument('--eps', default = '5.0e-06', - help = 'RMS norm eps: Use 1e-6 for LLaMA1 and OpenLLaMA, use 1e-5 for LLaMA2') - parser.add_argument('--context-length', '-c', type=int, default = 2048, - help = 'Default max context length: LLaMA1 is typically 2048, LLaMA2 is typically 4096') - parser.add_argument('--model-metadata-dir', '-m', type = Path, - help ='Load HuggingFace/.pth vocab and metadata from the specified directory') - parser.add_argument("--vocab-dir", type=Path, - help="directory containing tokenizer.model, if separate from model file - only meaningful with --model-metadata-dir") - parser.add_argument("--vocabtype", default="spm,hfft", - help="vocab format - only meaningful with --model-metadata-dir and/or --vocab-dir (default: spm,hfft)") - parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + parser = argparse.ArgumentParser(description="Convert GGML models to GGUF") + parser.add_argument( + "--input", "-i", type=Path, required=True, help="Input GGMLv3 filename" + ) + parser.add_argument( + "--output", "-o", type=Path, required=True, help="Output GGUF filename" + ) + parser.add_argument("--name", help="Set model name") + parser.add_argument("--desc", help="Set model description") + parser.add_argument( + "--gqa", + type=int, + default=1, + help="grouped-query attention factor (use 8 for LLaMA2 70B)", + ) + parser.add_argument( + "--eps", + default="5.0e-06", + help="RMS norm eps: Use 1e-6 for LLaMA1 and OpenLLaMA, use 1e-5 for LLaMA2", + ) + parser.add_argument( + "--context-length", + "-c", + type=int, + default=2048, + help="Default max context length: LLaMA1 is typically 2048, LLaMA2 is typically 4096", + ) + parser.add_argument( + "--model-metadata-dir", + "-m", + type=Path, + help="Load HuggingFace/.pth vocab and metadata from the specified directory", + ) + parser.add_argument( + "--vocab-dir", + type=Path, + help="directory containing tokenizer.model, if separate from model file - only meaningful with --model-metadata-dir", + ) + parser.add_argument( + "--vocabtype", + default="spm,hfft", + help="vocab format - only meaningful with --model-metadata-dir and/or --vocab-dir (default: spm,hfft)", + ) + parser.add_argument( + "--verbose", action="store_true", help="increase output verbosity" + ) return parser.parse_args() def main(): cfg = handle_args() logging.basicConfig(level=logging.DEBUG if cfg.verbose else logging.INFO) - logger.info(f'* Using config: {cfg}') - logger.warning('=== WARNING === Be aware that this conversion script is best-effort. Use a native GGUF model if possible. === WARNING ===') - if cfg.model_metadata_dir is None and (cfg.gqa == 1 or cfg.eps == '5.0e-06'): - logger.info('- Note: If converting LLaMA2, specifying "--eps 1e-5" is required. 70B models also need "--gqa 8".') - data = np.memmap(cfg.input, mode = 'r') + logger.info(f"* Using config: {cfg}") + logger.warning( + "=== WARNING === Be aware that this conversion script is best-effort. Use a native GGUF model if possible. === WARNING ===" + ) + if cfg.model_metadata_dir is None and (cfg.gqa == 1 or cfg.eps == "5.0e-06"): + logger.info( + '- Note: If converting LLaMA2, specifying "--eps 1e-5" is required. 70B models also need "--gqa 8".' + ) + data = np.memmap(cfg.input, mode="r") model = GGMLModel() - logger.info('* Scanning GGML input file') + logger.info("* Scanning GGML input file") offset = model.load(data, 0) # noqa - logger.info(f'* GGML model hyperparameters: {model.hyperparameters}') + logger.info(f"* GGML model hyperparameters: {model.hyperparameters}") vocab_override = None params_override = None special_vocab = None if cfg.model_metadata_dir is not None: - (params_override, vocab_override, special_vocab) = handle_metadata(cfg, model.hyperparameters) - logger.info('!! Note: When overriding params the --gqa, --eps and --context-length options are ignored.') - logger.info(f'* Overriding params: {params_override}') - logger.info(f'* Overriding vocab: {vocab_override}') - logger.info(f'* Special vocab: {special_vocab}') + (params_override, vocab_override, special_vocab) = handle_metadata( + cfg, model.hyperparameters + ) + logger.info( + "!! Note: When overriding params the --gqa, --eps and --context-length options are ignored." + ) + logger.info(f"* Overriding params: {params_override}") + logger.info(f"* Overriding vocab: {vocab_override}") + logger.info(f"* Special vocab: {special_vocab}") else: - logger.warning('\n=== WARNING === Special tokens may not be converted correctly. Use --model-metadata-dir if possible === WARNING ===\n') + logger.warning( + "\n=== WARNING === Special tokens may not be converted correctly. Use --model-metadata-dir if possible === WARNING ===\n" + ) if model.file_format == GGMLFormat.GGML: - logger.info('! This is a very old GGML file that does not contain vocab scores. Strongly recommend using model metadata!') + logger.info( + "! This is a very old GGML file that does not contain vocab scores. Strongly recommend using model metadata!" + ) converter = GGMLToGGUF( - model, data, cfg, - params_override = params_override, - vocab_override = vocab_override, - special_vocab = special_vocab + model, + data, + cfg, + params_override=params_override, + vocab_override=vocab_override, + special_vocab=special_vocab, ) converter.save() - logger.info(f'* Successful completion. Output saved to: {cfg.output}') + logger.info(f"* Successful completion. Output saved to: {cfg.output}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/smallthinker/convert_lora_to_gguf.py b/smallthinker/convert_lora_to_gguf.py index 00a6733c..d453db37 100755 --- a/smallthinker/convert_lora_to_gguf.py +++ b/smallthinker/convert_lora_to_gguf.py @@ -11,7 +11,16 @@ import json from math import prod from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Iterator, + Sequence, + SupportsIndex, + cast, +) from transformers import AutoConfig import torch @@ -19,8 +28,8 @@ if TYPE_CHECKING: from torch import Tensor -if 'NO_LOCAL_GGUF' not in os.environ: - sys.path.insert(1, str(Path(__file__).parent / 'gguf-py')) +if "NO_LOCAL_GGUF" not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / "gguf-py")) import gguf # reuse model definitions from convert_hf_to_gguf.py @@ -59,7 +68,9 @@ def __getitem__( indices: ( SupportsIndex | slice - | tuple[SupportsIndex | slice | Tensor, ...] # TODO: add ellipsis in the type signature + | tuple[ + SupportsIndex | slice | Tensor, ... + ] # TODO: add ellipsis in the type signature ), ) -> LoraTorchTensor: shape = self.shape @@ -92,7 +103,10 @@ def __getitem__( ) if len(indices) < len(shape): - indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape)))) + indices = ( + *indices, + *(slice(None, None) for _ in range(len(indices), len(shape))), + ) # TODO: make sure this is correct indices_A = ( @@ -140,7 +154,9 @@ def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor: n_elems = prod(orig_shape) n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape) assert n_elems % n_new_elems == 0 - new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),) + new_shape = ( + *(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape), + ) if new_shape[-1] != orig_shape[-1]: raise NotImplementedError # can't reshape the row size trivially @@ -166,7 +182,9 @@ def permute(self, *dims: int) -> LoraTorchTensor: assert all(dim == 1 for dim in self._lora_A.shape[:-2]) return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims)) if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1: - return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims)) + return LoraTorchTensor( + self._lora_B.permute(*dims), self._lora_A.permute(*dims) + ) else: # TODO: compose the above two raise NotImplementedError @@ -181,7 +199,9 @@ def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor: return self.transpose(axis0, axis1) def to(self, *args, **kwargs): - return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs)) + return LoraTorchTensor( + self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs) + ) @classmethod def __torch_function__(cls, func: Callable, types, args=(), kwargs=None): @@ -234,41 +254,53 @@ def get_base_tensor_name(lora_tensor_name: str) -> str: def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Convert a Hugging Face PEFT LoRA adapter to a GGUF file") + description="Convert a Hugging Face PEFT LoRA adapter to a GGUF file" + ) parser.add_argument( - "--outfile", type=Path, + "--outfile", + type=Path, help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", + "--outtype", + type=str, + choices=["f32", "f16", "bf16", "q8_0", "auto"], + default="f16", help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( - "--bigendian", action="store_true", + "--bigendian", + action="store_true", help="model is executed on big endian machine", ) parser.add_argument( - "--no-lazy", action="store_true", + "--no-lazy", + action="store_true", help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)", ) parser.add_argument( - "--verbose", action="store_true", + "--verbose", + action="store_true", help="increase output verbosity", ) parser.add_argument( - "--dry-run", action="store_true", + "--dry-run", + action="store_true", help="only print out what will be done, without writing any new files", ) parser.add_argument( - "--base", type=Path, + "--base", + type=Path, help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config", ) parser.add_argument( - "--base-model-id", type=str, + "--base-model-id", + type=str, help="the model ID of the base model, if it is not available locally or in the adapter config. If specified, it will ignore --base and load the base model config from the Hugging Face hub (Example: 'meta-llama/Llama-3.2-1B-Instruct')", ) parser.add_argument( - "lora_path", type=Path, + "lora_path", + type=Path, help="directory containing Hugging Face PEFT LoRA config (adapter_model.json) and weights (adapter_model.safetensors or adapter_model.bin)", ) @@ -281,7 +313,7 @@ def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]: return config.to_dict() -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) @@ -332,11 +364,17 @@ def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]: hparams = load_hparams_from_hf(model_id) except OSError as e: logger.error(f"Failed to load base model config: {e}") - logger.error("Please try downloading the base model and add its path to --base") + logger.error( + "Please try downloading the base model and add its path to --base" + ) sys.exit(1) else: - logger.error("'base_model_name_or_path' is not found in adapter_config.json") - logger.error("Base model config is required. Please download the base model and add its path to --base") + logger.error( + "'base_model_name_or_path' is not found in adapter_config.json" + ) + logger.error( + "Base model config is required. Please download the base model and add its path to --base" + ) sys.exit(1) else: logger.info(f"Loading base model: {dir_base_model.name}") @@ -354,7 +392,9 @@ class LoraModel(model_class): lora_alpha: float - def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs): + def __init__( + self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs + ): super().__init__(*args, **kwargs) @@ -369,7 +409,9 @@ def set_type(self): self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora") def set_gguf_parameters(self): - self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha) + self.gguf_writer.add_float32( + gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha + ) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: # Never add extra tensors (e.g. rope_freqs) for LoRA adapters @@ -392,10 +434,16 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]: if "_layernorm" in name or ".norm" in name: yield (base_name, tensor) continue - logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor") + logger.error( + f"Unexpected name '{name}': Not a lora_A or lora_B tensor" + ) if ".embed_tokens.weight" in name or ".lm_head.weight" in name: - logger.error("Embeddings is present in the adapter. This can be due to new tokens added during fine tuning") - logger.error("Please refer to https://github.com/ggml-org/llama.cpp/pull/9948") + logger.error( + "Embeddings is present in the adapter. This can be due to new tokens added during fine tuning" + ) + logger.error( + "Please refer to https://github.com/ggml-org/llama.cpp/pull/9948" + ) sys.exit(1) if base_name in tensor_map: @@ -412,16 +460,23 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]: for name, tensor in tensor_map.items(): assert tensor.A is not None assert tensor.B is not None - yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B))) + yield ( + name, + cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)), + ) - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + def modify_tensors( + self, data_torch: Tensor, name: str, bid: int | None + ) -> Iterable[tuple[str, Tensor]]: dest = list(super().modify_tensors(data_torch, name, bid)) # some archs may have the same tensor for lm_head and output (tie word embeddings) # in this case, adapters targeting lm_head will fail when using llama-export-lora # therefore, we ignore them for now # see: https://github.com/ggml-org/llama.cpp/issues/9065 if name == "lm_head.weight" and len(dest) == 0: - raise ValueError("lm_head is present in adapter, but is ignored in base model") + raise ValueError( + "lm_head is present in adapter, but is ignored in base model" + ) for dest_name, dest_data in dest: # mergekit-extract-lora add these layernorm to the adapter if "_norm" in dest_name: diff --git a/smallthinker/docs/docker.md b/smallthinker/docs/docker.md index f8f0573c..7b6cfc8e 100644 --- a/smallthinker/docs/docker.md +++ b/smallthinker/docs/docker.md @@ -30,6 +30,8 @@ The GPU enabled images are not currently tested by CI beyond being built. They a ## Usage +For Chrome AI Hub integration in this repository, prefer host port `18001` for containerized server examples. + The easiest way to download the models, convert them to ggml and optimize them is with the --all-in-one command which includes the full docker image. Replace `/path/to/models` below with the actual path where you downloaded the models. @@ -53,7 +55,7 @@ docker run -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:light -m /model or with a server image: ```bash -docker run -v /path/to/models:/models -p 8000:8000 ghcr.io/ggml-org/llama.cpp:server -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512 +docker run -v /path/to/models:/models -p 18001:18001 ghcr.io/ggml-org/llama.cpp:server -m /models/7B/ggml-model-q4_0.gguf --port 18001 --host 0.0.0.0 -n 512 ``` ## Docker With CUDA @@ -88,7 +90,7 @@ After building locally, Usage is similar to the non-CUDA examples, but you'll ne ```bash docker run --gpus all -v /path/to/models:/models local/llama.cpp:full-cuda --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 -docker run --gpus all -v /path/to/models:/models local/llama.cpp:server-cuda -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512 --n-gpu-layers 1 +docker run --gpus all -v /path/to/models:/models -p 18001:18001 local/llama.cpp:server-cuda -m /models/7B/ggml-model-q4_0.gguf --port 18001 --host 0.0.0.0 -n 512 --n-gpu-layers 1 ``` ## Docker With MUSA @@ -122,5 +124,5 @@ After building locally, Usage is similar to the non-MUSA examples, but you'll ne ```bash docker run -v /path/to/models:/models local/llama.cpp:full-musa --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 docker run -v /path/to/models:/models local/llama.cpp:light-musa -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 -docker run -v /path/to/models:/models local/llama.cpp:server-musa -m /models/7B/ggml-model-q4_0.gguf --port 8000 --host 0.0.0.0 -n 512 --n-gpu-layers 1 +docker run -v /path/to/models:/models -p 18001:18001 local/llama.cpp:server-musa -m /models/7B/ggml-model-q4_0.gguf --port 18001 --host 0.0.0.0 -n 512 --n-gpu-layers 1 ``` diff --git a/smallthinker/examples/convert_legacy_llama.py b/smallthinker/examples/convert_legacy_llama.py index c4ec5c52..441ad1a2 100755 --- a/smallthinker/examples/convert_legacy_llama.py +++ b/smallthinker/examples/convert_legacy_llama.py @@ -28,9 +28,9 @@ import numpy as np -if 'NO_LOCAL_GGUF' not in os.environ: +if "NO_LOCAL_GGUF" not in os.environ: # use .parent.parent since we are in "examples" directory - sys.path.insert(1, str(Path(__file__).parent.parent / 'gguf-py')) + sys.path.insert(1, str(Path(__file__).parent.parent / "gguf-py")) import gguf from gguf import BaseVocab, Vocab, NoVocab, BpeVocab, SentencePieceVocab, LlamaHfVocab @@ -40,17 +40,17 @@ logger = logging.getLogger("convert") -if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'): +if hasattr(faulthandler, "register") and hasattr(signal, "SIGUSR1"): faulthandler.register(signal.SIGUSR1) -NDArray: TypeAlias = 'np.ndarray[Any, Any]' +NDArray: TypeAlias = "np.ndarray[Any, Any]" ARCH = gguf.MODEL_ARCH.LLAMA DEFAULT_CONCURRENCY = 8 -ADDED_TOKENS_FILE = 'added_tokens.json' -FAST_TOKENIZER_FILE = 'tokenizer.json' +ADDED_TOKENS_FILE = "added_tokens.json" +FAST_TOKENIZER_FILE = "tokenizer.json" # # data types @@ -72,10 +72,16 @@ class UnquantizedDataType(DataType): pass -DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0']) -DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0']) -DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = []) -DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0']) +DT_F16 = UnquantizedDataType( + "F16", dtype=np.dtype(np.float16), valid_conversions=["F32", "Q8_0"] +) +DT_F32 = UnquantizedDataType( + "F32", dtype=np.dtype(np.float32), valid_conversions=["F16", "Q8_0"] +) +DT_I32 = UnquantizedDataType("I32", dtype=np.dtype(np.int16), valid_conversions=[]) +DT_BF16 = UnquantizedDataType( + "BF16", dtype=np.dtype(np.uint16), valid_conversions=["F32", "F16", "Q8_0"] +) @dataclass(frozen=True) @@ -85,10 +91,12 @@ class QuantizedDataType(DataType): ggml_type: gguf.GGMLQuantizationType def quantize(self, arr: NDArray) -> NDArray: - raise NotImplementedError(f'Quantization for {self.name} not implemented') + raise NotImplementedError(f"Quantization for {self.name} not implemented") def elements_to_bytes(self, n_elements: int) -> int: - assert n_elements % self.block_size == 0, f'Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}' + assert ( + n_elements % self.block_size == 0 + ), f"Invalid number of elements {n_elements} for {self.name} with block size {self.block_size}" return self.quantized_dtype.itemsize * (n_elements // self.block_size) @@ -96,38 +104,47 @@ def elements_to_bytes(self, n_elements: int) -> int: class Q8_0QuantizedDataType(QuantizedDataType): # Mini Q8_0 quantization in Python! def quantize(self, arr: NDArray) -> NDArray: - assert arr.size % self.block_size == 0 and arr.size != 0, f'Bad array size {arr.size}' - assert arr.dtype == np.float32, f'Bad array type {arr.dtype}' + assert ( + arr.size % self.block_size == 0 and arr.size != 0 + ), f"Bad array size {arr.size}" + assert arr.dtype == np.float32, f"Bad array type {arr.dtype}" n_blocks = arr.size // self.block_size blocks = arr.reshape((n_blocks, self.block_size)) # Much faster implementation of block quantization contributed by @Cebtenzzre def quantize_blocks_q8_0(blocks: NDArray) -> Iterable[tuple[Any, Any]]: - d = abs(blocks).max(axis = 1) / np.float32(127) - with np.errstate(divide = 'ignore'): + d = abs(blocks).max(axis=1) / np.float32(127) + with np.errstate(divide="ignore"): qs = (blocks / d[:, None]).round() qs[d == 0] = 0 yield from zip(d, qs) - return np.fromiter(quantize_blocks_q8_0(blocks), count = n_blocks, dtype = self.quantized_dtype) + + return np.fromiter( + quantize_blocks_q8_0(blocks), count=n_blocks, dtype=self.quantized_dtype + ) -DT_Q8_0 = Q8_0QuantizedDataType('Q8_0', - dtype = np.dtype(np.float32), valid_conversions = [], - ggml_type = gguf.GGMLQuantizationType.Q8_0, block_size = 32, - quantized_dtype = np.dtype([('d', ' Iterable[tuple[Any, Any]]: class GGMLFileType(enum.IntEnum): - AllF32 = 0 - MostlyF16 = 1 # except 1d tensors + AllF32 = 0 + MostlyF16 = 1 # except 1d tensors MostlyQ8_0 = 7 # except 1d tensors def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType: @@ -150,8 +167,8 @@ def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType: GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = { - GGMLFileType.AllF32 : DT_F32, - GGMLFileType.MostlyF16 : DT_F16, + GGMLFileType.AllF32: DT_F32, + GGMLFileType.MostlyF16: DT_F16, GGMLFileType.MostlyQ8_0: DT_Q8_0, } @@ -162,16 +179,16 @@ def type_for_tensor(self, name: str, tensor: LazyTensor) -> DataType: @dataclass class Params: - n_vocab: int - n_embd: int - n_layer: int - n_ctx: int - n_ff: int - n_head: int - n_head_kv: int - n_experts: int | None = None + n_vocab: int + n_embd: int + n_layer: int + n_ctx: int + n_ff: int + n_head: int + n_head_kv: int + n_experts: int | None = None n_experts_used: int | None = None - f_norm_eps: float | None = None + f_norm_eps: float | None = None rope_scaling_type: gguf.RopeScalingType | None = None f_rope_freq_base: float | None = None @@ -187,15 +204,33 @@ class Params: @staticmethod def guessed(model: LazyModel) -> Params: # try transformer naming first - n_vocab, n_embd = model["model.embed_tokens.weight"].shape if "model.embed_tokens.weight" in model else model["tok_embeddings.weight"].shape + n_vocab, n_embd = ( + model["model.embed_tokens.weight"].shape + if "model.embed_tokens.weight" in model + else model["tok_embeddings.weight"].shape + ) # try transformer naming first if "model.layers.0.self_attn.q_proj.weight" in model: - n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.q_proj.weight" not in model) - elif "model.layers.0.self_attn.W_pack.weight" in model: # next: try baichuan naming - n_layer = next(i for i in itertools.count() if f"model.layers.{i}.self_attn.W_pack.weight" not in model) + n_layer = next( + i + for i in itertools.count() + if f"model.layers.{i}.self_attn.q_proj.weight" not in model + ) + elif ( + "model.layers.0.self_attn.W_pack.weight" in model + ): # next: try baichuan naming + n_layer = next( + i + for i in itertools.count() + if f"model.layers.{i}.self_attn.W_pack.weight" not in model + ) else: - n_layer = next(i for i in itertools.count() if f"layers.{i}.attention.wq.weight" not in model) + n_layer = next( + i + for i in itertools.count() + if f"layers.{i}.attention.wq.weight" not in model + ) if n_layer < 1: msg = """\ @@ -203,22 +238,22 @@ def guessed(model: LazyModel) -> Params: Suggestion: provide 'config.json' of the model in the same directory containing model files.""" raise KeyError(textwrap.dedent(msg)) - n_head = n_embd // 128 # guessed - n_mult = 256 # guessed + n_head = n_embd // 128 # guessed + n_mult = 256 # guessed # TODO: verify this n_ff = int(2 * (4 * n_embd) / 3) n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult) return Params( - n_vocab = n_vocab, - n_embd = n_embd, - n_layer = n_layer, - n_ctx = -1, - n_ff = n_ff, - n_head = n_head, - n_head_kv = n_head, - f_norm_eps = 1e-5, + n_vocab=n_vocab, + n_embd=n_embd, + n_layer=n_layer, + n_ctx=-1, + n_ff=n_ff, + n_head=n_head, + n_head_kv=n_head, + f_norm_eps=1e-5, ) @staticmethod @@ -236,10 +271,10 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: rope_scaling_type = gguf.RopeScalingType.LINEAR elif typ == "yarn": rope_scaling_type = gguf.RopeScalingType.YARN - n_ctx_orig = rope_scaling['original_max_position_embeddings'] - rope_finetuned = rope_scaling['finetuned'] + n_ctx_orig = rope_scaling["original_max_position_embeddings"] + rope_finetuned = rope_scaling["finetuned"] else: - raise NotImplementedError(f'Unknown rope scaling type: {typ}') + raise NotImplementedError(f"Unknown rope scaling type: {typ}") if "max_sequence_length" in config: n_ctx = config["max_sequence_length"] @@ -251,7 +286,7 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: Suggestion: provide 'config.json' of the model in the same directory containing model files.""" raise KeyError(textwrap.dedent(msg)) - n_experts = None + n_experts = None n_experts_used = None if "num_local_experts" in config: @@ -259,21 +294,21 @@ def loadHFTransformerJson(model: LazyModel, config_path: Path) -> Params: n_experts_used = config["num_experts_per_tok"] return Params( - n_vocab = config["vocab_size"], - n_embd = config["hidden_size"], - n_layer = config["num_hidden_layers"], - n_ctx = n_ctx, - n_ff = config["intermediate_size"], - n_head = (n_head := config["num_attention_heads"]), - n_head_kv = config.get("num_key_value_heads", n_head), - n_experts = n_experts, - n_experts_used = n_experts_used, - f_norm_eps = config["rms_norm_eps"], - f_rope_freq_base = config.get("rope_theta"), - rope_scaling_type = rope_scaling_type, - f_rope_scale = f_rope_scale, - n_ctx_orig = n_ctx_orig, - rope_finetuned = rope_finetuned, + n_vocab=config["vocab_size"], + n_embd=config["hidden_size"], + n_layer=config["num_hidden_layers"], + n_ctx=n_ctx, + n_ff=config["intermediate_size"], + n_head=(n_head := config["num_attention_heads"]), + n_head_kv=config.get("num_key_value_heads", n_head), + n_experts=n_experts, + n_experts_used=n_experts_used, + f_norm_eps=config["rms_norm_eps"], + f_rope_freq_base=config.get("rope_theta"), + rope_scaling_type=rope_scaling_type, + f_rope_scale=f_rope_scale, + n_ctx_orig=n_ctx_orig, + rope_finetuned=rope_finetuned, ) # LLaMA v2 70B params.json @@ -283,7 +318,7 @@ def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: with open(config_path) as f: config = json.load(f) - n_experts = None + n_experts = None n_experts_used = None f_rope_freq_base = None n_ff = None @@ -307,39 +342,39 @@ def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params: if config.get("moe"): n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0] - n_experts = config["moe"]["num_experts"] + n_experts = config["moe"]["num_experts"] n_experts_used = config["moe"]["num_experts_per_tok"] f_rope_freq_base = 1e6 assert n_ff is not None return Params( - n_vocab = model["tok_embeddings.weight"].shape[0], - n_embd = config["dim"], - n_layer = config["n_layers"], - n_ctx = n_ctx, - n_ff = n_ff, - n_head = (n_head := config["n_heads"]), - n_head_kv = config.get("n_kv_heads", n_head), - n_experts = n_experts, - n_experts_used = n_experts_used, - f_norm_eps = config["norm_eps"], - f_rope_freq_base = config.get("rope_theta", f_rope_freq_base), + n_vocab=model["tok_embeddings.weight"].shape[0], + n_embd=config["dim"], + n_layer=config["n_layers"], + n_ctx=n_ctx, + n_ff=n_ff, + n_head=(n_head := config["n_heads"]), + n_head_kv=config.get("n_kv_heads", n_head), + n_experts=n_experts, + n_experts_used=n_experts_used, + f_norm_eps=config["norm_eps"], + f_rope_freq_base=config.get("rope_theta", f_rope_freq_base), ) @staticmethod def load(model_plus: ModelPlus) -> Params: - hf_config_path = model_plus.paths[0].parent / "config.json" + hf_config_path = model_plus.paths[0].parent / "config.json" orig_config_path = model_plus.paths[0].parent / "params.json" if hf_config_path.exists(): params = Params.loadHFTransformerJson(model_plus.model, hf_config_path) elif orig_config_path.exists(): params = Params.loadOriginalParamsJson(model_plus.model, orig_config_path) - elif model_plus.format != 'none': + elif model_plus.format != "none": params = Params.guessed(model_plus.model) else: - raise ValueError('Cannot guess params when model format is none') + raise ValueError("Cannot guess params when model format is none") params.path_model = model_plus.paths[0].parent @@ -355,9 +390,11 @@ def load(model_plus: ModelPlus) -> Params: def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray: if n_head_kv is not None and n_head != n_head_kv: n_head = n_head_kv - return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) - .swapaxes(1, 2) - .reshape(weights.shape)) + return ( + weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape) + ) class Tensor(ABC): @@ -377,7 +414,9 @@ def to_ggml(self) -> GGMLCompatibleTensor: ... def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray: - assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}" + assert ( + bf16_arr.dtype == np.uint16 + ), f"Input array should be of dtype uint16, but got {bf16_arr.dtype}" fp32_arr = bf16_arr.astype(np.uint32) << 16 return fp32_arr.view(np.float32) @@ -397,9 +436,13 @@ def astype(self, data_type: DataType) -> UnquantizedTensor: def to_ggml(self) -> Self: return self - def permute_part(self, n_part: int, n_head: int, n_head_kv: int) -> UnquantizedTensor: + def permute_part( + self, n_part: int, n_head: int, n_head_kv: int + ) -> UnquantizedTensor: r = self.ndarray.shape[0] // 3 - return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv)) + return UnquantizedTensor( + permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head, n_head_kv) + ) def part(self, n_part: int) -> UnquantizedTensor: r = self.ndarray.shape[0] // 3 @@ -409,7 +452,9 @@ def permute(self, n_head: int, n_head_kv: int) -> UnquantizedTensor: return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv)) -def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False) -> NDArray: +def load_unquantized( + lazy_tensor: LazyTensor, expected_dtype: Any = None, convert: bool = False +) -> NDArray: tensor = lazy_tensor.load() assert isinstance(tensor, UnquantizedTensor) @@ -420,7 +465,9 @@ def load_unquantized(lazy_tensor: LazyTensor, expected_dtype: Any = None, conver if convert: tensor.ndarray = tensor.ndarray.astype(expected_dtype) else: - raise ValueError(f'expected this tensor to have dtype {expected_dtype}, got {tensor.ndarray.dtype}') + raise ValueError( + f"expected this tensor to have dtype {expected_dtype}, got {tensor.ndarray.dtype}" + ) return tensor.ndarray @@ -438,8 +485,9 @@ class LazyTensor: def load(self) -> Tensor: ret = self._load() # Should be okay if it maps to the same numpy type? - assert ret.data_type == self.data_type or (self.data_type.dtype == ret.data_type.dtype), \ - (self.data_type, ret.data_type, self.description) + assert ret.data_type == self.data_type or ( + self.data_type.dtype == ret.data_type.dtype + ), (self.data_type, ret.data_type, self.description) return ret def astype(self, data_type: DataType) -> LazyTensor: @@ -447,16 +495,25 @@ def astype(self, data_type: DataType) -> LazyTensor: def load() -> Tensor: return self.load().astype(data_type) - return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}') + + return LazyTensor( + load, self.shape, data_type, f"convert({data_type}) {self.description}" + ) def validate_conversion_to(self, data_type: DataType) -> None: - if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions: - raise ValueError(f'Cannot validate conversion from {self.data_type} to {data_type}.') + if ( + data_type != self.data_type + and data_type.name not in self.data_type.valid_conversions + ): + raise ValueError( + f"Cannot validate conversion from {self.data_type} to {data_type}." + ) -LazyModel: TypeAlias = 'dict[str, LazyTensor]' +LazyModel: TypeAlias = "dict[str, LazyTensor]" + +ModelFormat: TypeAlias = Literal["ggml", "torch", "safetensors", "none"] -ModelFormat: TypeAlias = Literal['ggml', 'torch', 'safetensors', 'none'] @dataclass class ModelPlus: @@ -480,9 +537,11 @@ def convert(name: str) -> LazyTensor: if len(lazy_tensors[0].shape) == 1: # the tensor is just duplicated in every file return lazy_tensors[0] - if name.startswith('tok_embeddings.') or \ - name.endswith('.attention.wo.weight') or \ - name.endswith('.feed_forward.w2.weight'): + if ( + name.startswith("tok_embeddings.") + or name.endswith(".attention.wo.weight") + or name.endswith(".feed_forward.w2.weight") + ): # split by columns axis = 1 else: @@ -495,8 +554,16 @@ def load() -> UnquantizedTensor: ndarrays = [load_unquantized(tensor) for tensor in lazy_tensors] concatenated = np.concatenate(ndarrays, axis=axis) return UnquantizedTensor(concatenated) - description = 'concatenated[[' + '] | ['.join(lt.description for lt in lazy_tensors) + ']]' - return LazyTensor(load, concatenated_shape, lazy_tensors[0].data_type, description) + + description = ( + "concatenated[[" + + "] | [".join(lt.description for lt in lazy_tensors) + + "]]" + ) + return LazyTensor( + load, concatenated_shape, lazy_tensors[0].data_type, description + ) + return {name: convert(name) for name in names} @@ -526,32 +593,53 @@ def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus: def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor: def load() -> Tensor: return lazy_tensor.load().permute(n_head, n_head_kv) - return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) + return LazyTensor( + load, + lazy_tensor.shape, + lazy_tensor.data_type, + f"permute({n_head}, {n_head_kv}) " + lazy_tensor.description, + ) -def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int) -> LazyTensor: + +def permute_part_lazy( + lazy_tensor: LazyTensor, n_part: int, n_head: int, n_head_kv: int +) -> LazyTensor: def load() -> Tensor: return lazy_tensor.load().permute_part(n_part, n_head, n_head_kv) + s = lazy_tensor.shape.copy() s[0] = s[0] // 3 - return LazyTensor(load, s, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) + return LazyTensor( + load, + s, + lazy_tensor.data_type, + f"permute({n_head}, {n_head_kv}) " + lazy_tensor.description, + ) def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor: def load() -> Tensor: return lazy_tensor.load().part(n_part) + s = lazy_tensor.shape.copy() s[0] = s[0] // 3 - return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description) + return LazyTensor(load, s, lazy_tensor.data_type, "part " + lazy_tensor.description) def pack_experts_lazy(lazy_tensors: list[LazyTensor]) -> LazyTensor: def load() -> Tensor: tensors = [lazy_tensor.load() for lazy_tensor in lazy_tensors] return UnquantizedTensor(np.array([tensor.ndarray for tensor in tensors])) + s = lazy_tensors[0].shape.copy() s.insert(0, len(lazy_tensors)) - return LazyTensor(load, s, lazy_tensors[0].data_type, 'pack_experts ' + ' | '.join(lt.description for lt in lazy_tensors)) + return LazyTensor( + load, + s, + lazy_tensors[0].data_type, + "pack_experts " + " | ".join(lt.description for lt in lazy_tensors), + ) # Functionality that simulates `torch.load` but where individual tensors are @@ -581,11 +669,11 @@ def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile self.zip_file = zip_file def persistent_load(self, pid: Any) -> Any: - assert pid[0] == 'storage' + assert pid[0] == "storage" assert isinstance(pid[1], LazyStorageKind) data_type = pid[1].data_type filename_stem = pid[2] - filename = f'{self.data_base_path}/{filename_stem}' + filename = f"{self.data_base_path}/{filename_stem}" info = self.zip_file.getinfo(filename) def load(offset: int, elm_count: int) -> NDArray: @@ -596,18 +684,31 @@ def load(offset: int, elm_count: int) -> NDArray: data = fp.read(size) assert len(data) == size return np.frombuffer(data, dtype) - description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}' + + description = f"storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}" return LazyStorage(load=load, kind=pid[1], description=description) @staticmethod - def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any, - requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor: + def lazy_rebuild_tensor_v2( + storage: Any, + storage_offset: Any, + size: Any, + stride: Any, + requires_grad: Any, + backward_hooks: Any, + metadata: Any = None, + ) -> LazyTensor: assert isinstance(storage, LazyStorage) def load() -> UnquantizedTensor: elm_count = stride[0] * size[0] - return UnquantizedTensor(storage.load(storage_offset, elm_count).reshape(size)) - description = f'pickled storage_offset={storage_offset} in {storage.description}' + return UnquantizedTensor( + storage.load(storage_offset, elm_count).reshape(size) + ) + + description = ( + f"pickled storage_offset={storage_offset} in {storage.description}" + ) return LazyTensor(load, list(size), storage.kind.data_type, description) @staticmethod @@ -617,57 +718,70 @@ def rebuild_from_type_v2(func, new_type, args, state): CLASSES: dict[tuple[str, str], type[LazyTensor] | LazyStorageKind] = { # getattr used here as a workaround for mypy not being smart enough to determine # the staticmethods have a __func__ attribute. - ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'), - ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'), - ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16), - ('torch', 'HalfStorage'): LazyStorageKind(DT_F16), - ('torch', 'FloatStorage'): LazyStorageKind(DT_F32), - ('torch', 'IntStorage'): LazyStorageKind(DT_I32), - ('torch', 'Tensor'): LazyTensor, + ("torch._tensor", "_rebuild_from_type_v2"): getattr( + rebuild_from_type_v2, "__func__" + ), + ("torch._utils", "_rebuild_tensor_v2"): getattr( + lazy_rebuild_tensor_v2, "__func__" + ), + ("torch", "BFloat16Storage"): LazyStorageKind(DT_BF16), + ("torch", "HalfStorage"): LazyStorageKind(DT_F16), + ("torch", "FloatStorage"): LazyStorageKind(DT_F32), + ("torch", "IntStorage"): LazyStorageKind(DT_I32), + ("torch", "Tensor"): LazyTensor, } def find_class(self, module: str, name: str) -> Any: - if not module.startswith('torch'): + if not module.startswith("torch"): return super().find_class(module, name) return self.CLASSES[(module, name)] def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus: zf = zipfile.ZipFile(outer_fp) - pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')] + pickle_paths = [name for name in zf.namelist() if name.endswith(".pkl")] assert len(pickle_paths) == 1, pickle_paths - pickle_fp = zf.open(pickle_paths[0], 'r') - unpickler = LazyUnpickler(pickle_fp, - data_base_path=pickle_paths[0][:-4], - zip_file=zf) + pickle_fp = zf.open(pickle_paths[0], "r") + unpickler = LazyUnpickler( + pickle_fp, data_base_path=pickle_paths[0][:-4], zip_file=zf + ) model = unpickler.load() - if 'model' in model: model = model['model'] + if "model" in model: + model = model["model"] as_dict = dict(model.items()) - return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None) + return ModelPlus(model=as_dict, paths=[path], format="torch", vocab=None) def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: - header_size, = struct.unpack(' LazyTensor: - data_type = SAFETENSORS_DATA_TYPES[info['dtype']] + data_type = SAFETENSORS_DATA_TYPES[info["dtype"]] numpy_dtype = data_type.dtype - shape: list[int] = info['shape'] - begin, end = info['data_offsets'] + shape: list[int] = info["shape"] + begin, end = info["data_offsets"] assert 0 <= begin <= end <= len(byte_buf) assert end - begin == math.prod(shape) * numpy_dtype.itemsize buf = byte_buf[begin:end] def load() -> UnquantizedTensor: - return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)) - description = f'safetensors begin={begin} end={end} type={data_type} path={path}' + return UnquantizedTensor( + np.frombuffer(buf, dtype=numpy_dtype).reshape(shape) + ) + + description = ( + f"safetensors begin={begin} end={end} type={data_type} path={path}" + ) return LazyTensor(load, shape, data_type, description) - model = {name: convert(info) for (name, info) in header.items() if name != '__metadata__'} - return ModelPlus(model=model, paths=[path], format='safetensors', vocab=None) + + model = { + name: convert(info) for (name, info) in header.items() if name != "__metadata__" + } + return ModelPlus(model=model, paths=[path], format="safetensors", vocab=None) def must_read(fp: IO[bytes], length: int) -> bytes: @@ -679,28 +793,34 @@ def must_read(fp: IO[bytes], length: int) -> bytes: @functools.lru_cache(maxsize=None) def lazy_load_file(path: Path) -> ModelPlus: - fp = open(path, 'rb') + fp = open(path, "rb") first8 = fp.read(8) fp.seek(0) - if first8[:2] == b'PK': + if first8[:2] == b"PK": # A zip file, i.e. PyTorch format return lazy_load_torch_file(fp, path) - elif struct.unpack(' Iterable[Out]: - '''Parallel map, but with backpressure. If the caller doesn't call `next` +def bounded_parallel_map( + func: Callable[[In], Out], + iterable: Iterable[In], + concurrency: int, + max_workers: int | None = None, + use_processpool_executor: bool = False, +) -> Iterable[Out]: + """Parallel map, but with backpressure. If the caller doesn't call `next` fast enough, this will stop calling `func` at some point rather than letting results pile up in memory. Specifically, there is a max of one - output value buffered per thread.''' + output value buffered per thread.""" if concurrency < 2: yield from map(func, iterable) # Not reached. @@ -743,7 +863,9 @@ def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False) # Check for a vocab size mismatch if params.n_vocab == vocab.vocab_size: - logger.warning("Ignoring added_tokens.json since model matches vocab size without it.") + logger.warning( + "Ignoring added_tokens.json since model matches vocab size without it." + ) return if pad_vocab and params.n_vocab > vocab.vocab_size: @@ -767,8 +889,12 @@ def check_vocab_size(params: Params, vocab: BaseVocab, pad_vocab: bool = False) class OutputFile: - def __init__(self, fname_out: Path, endianess:gguf.GGUFEndian = gguf.GGUFEndian.LITTLE): - self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess) + def __init__( + self, fname_out: Path, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE + ): + self.gguf = gguf.GGUFWriter( + fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess + ) def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None: # Metadata About The Model And Its Provenence @@ -837,11 +963,17 @@ def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None if "author" in base_model_entry: self.gguf.add_base_model_author(key, base_model_entry["author"]) if "version" in base_model_entry: - self.gguf.add_base_model_version(key, base_model_entry["version"]) + self.gguf.add_base_model_version( + key, base_model_entry["version"] + ) if "organization" in base_model_entry: - self.gguf.add_base_model_organization(key, base_model_entry["organization"]) + self.gguf.add_base_model_organization( + key, base_model_entry["organization"] + ) if "description" in base_model_entry: - self.gguf.add_base_model_description(key, base_model_entry["description"]) + self.gguf.add_base_model_description( + key, base_model_entry["description"] + ) if "url" in base_model_entry: self.gguf.add_base_model_url(key, base_model_entry["url"]) if "doi" in base_model_entry: @@ -849,7 +981,9 @@ def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None if "uuid" in base_model_entry: self.gguf.add_base_model_uuid(key, base_model_entry["uuid"]) if "repo_url" in base_model_entry: - self.gguf.add_base_model_repo_url(key, base_model_entry["repo_url"]) + self.gguf.add_base_model_repo_url( + key, base_model_entry["repo_url"] + ) if metadata.datasets is not None: self.gguf.add_dataset_count(len(metadata.datasets)) @@ -861,9 +995,13 @@ def add_meta_model(self, params: Params, metadata: gguf.Metadata | None) -> None if "version" in dataset_entry: self.gguf.add_dataset_version(key, dataset_entry["version"]) if "organization" in dataset_entry: - self.gguf.add_dataset_organization(key, dataset_entry["organization"]) + self.gguf.add_dataset_organization( + key, dataset_entry["organization"] + ) if "description" in dataset_entry: - self.gguf.add_dataset_description(key, dataset_entry["description"]) + self.gguf.add_dataset_description( + key, dataset_entry["description"] + ) if "url" in dataset_entry: self.gguf.add_dataset_url(key, dataset_entry["url"]) if "doi" in dataset_entry: @@ -886,8 +1024,8 @@ def add_meta_arch(self, params: Params) -> None: self.gguf.add_block_count(params.n_layer) self.gguf.add_feed_forward_length(params.n_ff) self.gguf.add_rope_dimension_count(params.n_embd // params.n_head) - self.gguf.add_head_count (params.n_head) - self.gguf.add_head_count_kv (params.n_head_kv) + self.gguf.add_head_count(params.n_head) + self.gguf.add_head_count_kv(params.n_head_kv) if params.n_experts: self.gguf.add_expert_count(params.n_experts) @@ -898,7 +1036,7 @@ def add_meta_arch(self, params: Params) -> None: if params.f_norm_eps: self.gguf.add_layer_norm_rms_eps(params.f_norm_eps) else: - raise ValueError('f_norm_eps is None') + raise ValueError("f_norm_eps is None") if params.f_rope_freq_base is not None: self.gguf.add_rope_freq_base(params.f_rope_freq_base) @@ -917,7 +1055,9 @@ def add_meta_arch(self, params: Params) -> None: if params.ftype is not None: self.gguf.add_file_type(params.ftype) - def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]: + def extract_vocabulary_from_model( + self, vocab: Vocab + ) -> tuple[list[bytes], list[float], list[gguf.TokenType]]: tokens = [] scores = [] toktypes = [] @@ -949,10 +1089,14 @@ def add_meta_special_vocab(self, svocab: gguf.SpecialVocab) -> None: def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: n_elements = int(np.prod(tensor.shape)) - raw_dtype = getattr(tensor.data_type, 'ggml_type', None) - data_type = getattr(tensor.data_type, 'quantized_type', None) or tensor.data_type.dtype + raw_dtype = getattr(tensor.data_type, "ggml_type", None) + data_type = ( + getattr(tensor.data_type, "quantized_type", None) or tensor.data_type.dtype + ) data_nbytes = tensor.data_type.elements_to_bytes(n_elements) - self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype) + self.gguf.add_tensor_info( + name, tensor.shape, data_type, data_nbytes, raw_dtype=raw_dtype + ) def write_meta(self) -> None: self.gguf.write_header_to_file() @@ -961,20 +1105,29 @@ def write_meta(self) -> None: def write_tensor_info(self) -> None: self.gguf.write_ti_data_to_file() - def write_tensor_data(self, ftype: GGMLFileType, model: LazyModel, concurrency: int) -> None: - ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency) + def write_tensor_data( + self, ftype: GGMLFileType, model: LazyModel, concurrency: int + ) -> None: + ndarrays_inner = bounded_parallel_map( + OutputFile.do_item, model.items(), concurrency=concurrency + ) if ftype == GGMLFileType.MostlyQ8_0: ndarrays = bounded_parallel_map( - OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency, + OutputFile.maybe_do_quantize, + ndarrays_inner, + concurrency=concurrency, + max_workers=concurrency, use_processpool_executor=True, ) else: ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) start = time.time() - for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): + for i, ((name, lazy_tensor), ndarray) in enumerate( + zip(model.items(), ndarrays) + ): elapsed = time.time() - start - size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) + size = " x ".join(f"{dim:6d}" for dim in lazy_tensor.shape) padi = len(str(len(model))) logger.info( f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}" @@ -986,8 +1139,13 @@ def close(self) -> None: @staticmethod def write_vocab_only( - fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, - endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: gguf.Metadata | None = None, + fname_out: Path, + params: Params, + vocab: Vocab, + svocab: gguf.SpecialVocab, + endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, + pad_vocab: bool = False, + metadata: gguf.Metadata | None = None, ) -> None: check_vocab_size(params, vocab, pad_vocab=pad_vocab) @@ -1018,8 +1176,14 @@ def maybe_do_quantize(item: tuple[DataType, NDArray]) -> NDArray: @staticmethod def write_all( - fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: BaseVocab, svocab: gguf.SpecialVocab, - concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, + fname_out: Path, + ftype: GGMLFileType, + params: Params, + model: LazyModel, + vocab: BaseVocab, + svocab: gguf.SpecialVocab, + concurrency: int = DEFAULT_CONCURRENCY, + endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, metadata: gguf.Metadata | None = None, ) -> None: @@ -1050,28 +1214,38 @@ def write_all( def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType: - wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type + wq_type = model[ + gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight" + ].data_type - if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)): + if output_type_str == "f32" or ( + output_type_str is None and wq_type in (DT_F32, DT_BF16) + ): return GGMLFileType.AllF32 if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16): return GGMLFileType.MostlyF16 if output_type_str == "q8_0": return GGMLFileType.MostlyQ8_0 - name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()} + name_to_type = { + name: lazy_tensor.data_type for (name, lazy_tensor) in model.items() + } raise ValueError(f"Unexpected combination of types: {name_to_type}") -def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]]) -> tuple[int, int, int]: +def per_model_weight_count_estimation( + tensors: Iterable[tuple[str, LazyTensor]], +) -> tuple[int, int, int]: total_params = 0 shared_params = 0 expert_params = 0 for name, lazy_tensor in tensors: # We don't need these - if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): + if name.endswith( + (".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq") + ): continue # Got A Tensor @@ -1093,11 +1267,15 @@ def per_model_weight_count_estimation(tensors: Iterable[tuple[str, LazyTensor]]) def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel: - return {name: tensor.astype(output_type.type_for_tensor(name, tensor)) - for (name, tensor) in model.items()} + return { + name: tensor.astype(output_type.type_for_tensor(name, tensor)) + for (name, tensor) in model.items() + } -def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> LazyModel: +def convert_model_names( + model: LazyModel, params: Params, skip_unknown: bool +) -> LazyModel: tmap = gguf.TensorNameMap(ARCH, params.n_layer) should_skip = set(gguf.MODEL_TENSOR_SKIP.get(ARCH, [])) @@ -1110,62 +1288,103 @@ def convert_model_names(model: LazyModel, params: Params, skip_unknown: bool) -> experts = [] for e in range(params.n_experts): if f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight" in model: - experts.append(model[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"]) + experts.append( + model[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"] + ) del tmp[f"layers.{i_l}.feed_forward.experts.{e}.w{w}.weight"] - elif f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight" in model: - experts.append(model[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"]) - del tmp[f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight"] + elif ( + f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight" + in model + ): + experts.append( + model[ + f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight" + ] + ) + del tmp[ + f"model.layers.{i_l}.block_sparse_moe.experts.{e}.w{w}.weight" + ] else: - raise ValueError(f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.weight") - tmp[f"layers.{i_l}.feed_forward.experts.w{w}.weight"] = pack_experts_lazy(experts) + raise ValueError( + f"Expert tensor not found: layers.{i_l}.feed_forward.experts.{e}.w{w}.weight" + ) + tmp[f"layers.{i_l}.feed_forward.experts.w{w}.weight"] = ( + pack_experts_lazy(experts) + ) # HF models permut or pack some of the tensors, so we need to undo that for i in itertools.count(): if f"model.layers.{i}.self_attn.q_proj.weight" in model: logger.debug(f"Permuting layer {i}") - tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head, params.n_head) - tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv) + tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_lazy( + model[f"model.layers.{i}.self_attn.q_proj.weight"], + params.n_head, + params.n_head, + ) + tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_lazy( + model[f"model.layers.{i}.self_attn.k_proj.weight"], + params.n_head, + params.n_head_kv, + ) # tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"] elif f"model.layers.{i}.self_attn.W_pack.weight" in model: logger.debug(f"Unpacking and permuting layer {i}") - tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head) - tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv) - tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2) + tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy( + model[f"model.layers.{i}.self_attn.W_pack.weight"], + 0, + params.n_head, + params.n_head, + ) + tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy( + model[f"model.layers.{i}.self_attn.W_pack.weight"], + 1, + params.n_head, + params.n_head_kv, + ) + tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy( + model[f"model.layers.{i}.self_attn.W_pack.weight"], 2 + ) del tmp[f"model.layers.{i}.self_attn.W_pack.weight"] else: break out: LazyModel = {} for name, lazy_tensor in model.items(): - tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None) + tensor_type, name_new = tmap.get_type_and_name( + name, try_suffixes=(".weight", ".bias") + ) or (None, None) if name_new is None: if skip_unknown: logger.warning(f"Unexpected tensor name: {name} - skipping") continue - raise ValueError(f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)") + raise ValueError( + f"Unexpected tensor name: {name}. Use --skip-unknown to ignore it (e.g. LLaVA)" + ) if tensor_type in should_skip: logger.debug(f"skipping tensor {name_new}") continue - logger.debug(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}") + logger.debug( + f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type.name:6s} | {lazy_tensor.shape}" + ) out[name_new] = lazy_tensor return out def nth_multifile_path(path: Path, n: int) -> Path | None: - '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + """Given any path belonging to a multi-file model (e.g. foo.bin.1), return the nth path in the model. - ''' + """ # Support the following patterns: patterns = [ # - x.00.pth, x.01.pth, etc. - (r'\.[0-9]{2}\.pth$', f'.{n:02}.pth'), + (r"\.[0-9]{2}\.pth$", f".{n:02}.pth"), # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc. - (r'-[0-9]{5}-of-(.*)$', fr'-{n:05}-of-\1'), + (r"-[0-9]{5}-of-(.*)$", rf"-{n:05}-of-\1"), # x.bin, x.bin.1, etc. - (r'(\.[0-9]+)?$', r'\1' if n == 0 else fr'\1.{n}') + (r"(\.[0-9]+)?$", r"\1" if n == 0 else rf"\1.{n}"), ] for regex, replacement in patterns: if re.search(regex, path.name): @@ -1176,9 +1395,9 @@ def nth_multifile_path(path: Path, n: int) -> Path | None: def find_multifile_paths(path: Path) -> list[Path]: - '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return + """Given any path belonging to a multi-file model (e.g. foo.bin.1), return the whole list of paths in the model. - ''' + """ ret: list[Path] = [] for i in itertools.count(): nth_path = nth_multifile_path(path, i) @@ -1194,20 +1413,31 @@ def find_multifile_paths(path: Path) -> list[Path]: def load_some_model(path: Path) -> ModelPlus: - '''Load a model of any supported format.''' + """Load a model of any supported format.""" # Be extra-friendly and accept either a file or a directory: if path.is_dir(): # Check if it's a set of safetensors files first - globs = ["model-00001-of-*.safetensors", "model.safetensors", "consolidated.safetensors"] + globs = [ + "model-00001-of-*.safetensors", + "model.safetensors", + "consolidated.safetensors", + ] files = [file for glob in globs for file in path.glob(glob)] if not files: # Try the PyTorch patterns too, with lower priority - globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"] + globs = [ + "consolidated.00.pth", + "pytorch_model-00001-of-*.bin", + "*.pt", + "pytorch_model.bin", + ] files = [file for glob in globs for file in path.glob(glob)] if not files: raise FileNotFoundError(f"Can't find model in directory {path}") if len(files) > 1: - raise ValueError(f"Found multiple models in {path}, not sure which to pick: {files}") + raise ValueError( + f"Found multiple models in {path}, not sure which to pick: {files}" + ) path = files[0] paths = find_multifile_paths(path) @@ -1226,7 +1456,9 @@ class VocabFactory: def __init__(self, path: Path): self.path = path - def _create_special_vocab(self, vocab: BaseVocab, model_parent_path: Path) -> gguf.SpecialVocab: + def _create_special_vocab( + self, vocab: BaseVocab, model_parent_path: Path + ) -> gguf.SpecialVocab: load_merges = vocab.name == "bpe" n_vocab = vocab.vocab_size if isinstance(vocab, Vocab) else None return gguf.SpecialVocab( @@ -1237,7 +1469,9 @@ def _create_special_vocab(self, vocab: BaseVocab, model_parent_path: Path) -> gg ) def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab: - vocab_classes: dict[str, type[Vocab]] = {cls.name: cls for cls in self._VOCAB_CLASSES} + vocab_classes: dict[str, type[Vocab]] = { + cls.name: cls for cls in self._VOCAB_CLASSES + } selected_vocabs: dict[str, type[Vocab]] = {} for vtype in vocab_types: try: @@ -1252,12 +1486,16 @@ def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab: except FileNotFoundError: pass # ignore unavailable tokenizers else: - raise FileNotFoundError(f"Could not find a tokenizer matching any of {vocab_types}") + raise FileNotFoundError( + f"Could not find a tokenizer matching any of {vocab_types}" + ) logger.info(f"Loaded vocab file {vocab.fname_tokenizer!r}, type {vocab.name!r}") return vocab - def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) -> tuple[BaseVocab, gguf.SpecialVocab]: + def load_vocab( + self, vocab_types: list[str] | None, model_parent_path: Path + ) -> tuple[BaseVocab, gguf.SpecialVocab]: vocab: BaseVocab if vocab_types is None: vocab = NoVocab() @@ -1271,39 +1509,61 @@ def load_vocab(self, vocab_types: list[str] | None, model_parent_path: Path) -> return vocab, special_vocab -def default_convention_outfile(file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> str: +def default_convention_outfile( + file_type: GGMLFileType, + expert_count: int | None, + model_params_count: tuple[int, int, int], + metadata: gguf.Metadata, +) -> str: name = metadata.name if metadata.name is not None else None basename = metadata.basename if metadata.basename is not None else None finetune = metadata.finetune if metadata.finetune is not None else None version = metadata.version if metadata.version is not None else None - size_label = metadata.size_label if metadata.size_label is not None else gguf.size_label(*model_params_count, expert_count=expert_count or 0) + size_label = ( + metadata.size_label + if metadata.size_label is not None + else gguf.size_label(*model_params_count, expert_count=expert_count or 0) + ) output_type = { - GGMLFileType.AllF32: "F32", + GGMLFileType.AllF32: "F32", GGMLFileType.MostlyF16: "F16", GGMLFileType.MostlyQ8_0: "Q8_0", }[file_type] - return gguf.naming_convention(name, basename, finetune, version, size_label, output_type) - - -def default_outfile(model_paths: list[Path], file_type: GGMLFileType, expert_count: int | None, model_params_count: tuple[int, int, int], metadata: gguf.Metadata) -> Path: - default_filename = default_convention_outfile(file_type, expert_count, model_params_count, metadata) + return gguf.naming_convention( + name, basename, finetune, version, size_label, output_type + ) + + +def default_outfile( + model_paths: list[Path], + file_type: GGMLFileType, + expert_count: int | None, + model_params_count: tuple[int, int, int], + metadata: gguf.Metadata, +) -> Path: + default_filename = default_convention_outfile( + file_type, expert_count, model_params_count, metadata + ) ret = model_paths[0].parent / f"{default_filename}.gguf" if ret in model_paths: logger.error( f"Error: Default output path ({ret}) would overwrite the input. " - "Please explicitly specify a path using --outfile.") + "Please explicitly specify a path using --outfile." + ) sys.exit(1) return ret def do_dump_model(model_plus: ModelPlus) -> None: - print(f"model_plus.paths = {model_plus.paths!r}") # noqa: NP100 - print(f"model_plus.format = {model_plus.format!r}") # noqa: NP100 - print(f"model_plus.vocab = {model_plus.vocab!r}") # noqa: NP100 + print(f"model_plus.paths = {model_plus.paths!r}") # noqa: NP100 + print(f"model_plus.format = {model_plus.format!r}") # noqa: NP100 + print(f"model_plus.vocab = {model_plus.vocab!r}") # noqa: NP100 for name, lazy_tensor in model_plus.model.items(): - print(f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}") # noqa: NP100 + print( + f"{name}: shape={lazy_tensor.shape} type={lazy_tensor.data_type}; {lazy_tensor.description}" + ) # noqa: NP100 def main(args_in: list[str] | None = None) -> None: @@ -1311,25 +1571,86 @@ def main(args_in: list[str] | None = None) -> None: if np.uint32(1) == np.uint32(1).newbyteorder("<"): # We currently only support Q8_0 output on little endian systems. output_choices.append("q8_0") - parser = argparse.ArgumentParser(description="Convert a LLaMA model to a GGML compatible file") - parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") - parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") - parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") - parser.add_argument("--no-vocab", action="store_true", help="store model without the vocab") - parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)") - parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") - parser.add_argument("--vocab-type", help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft)", default="spm,hfft") - parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") - parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") - parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") - parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default=DEFAULT_CONCURRENCY) - parser.add_argument("--big-endian", action="store_true", help="model is executed on big endian machine") - parser.add_argument("--pad-vocab", action="store_true", help="add pad tokens when model vocab expects more than tokenizer metadata provides") - parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing") - parser.add_argument("--verbose", action="store_true", help="increase output verbosity") - parser.add_argument("--metadata", type=Path, help="Specify the path for an authorship metadata override file") - parser.add_argument("--get-outfile", action="store_true", help="get calculated default outfile name") - parser.add_argument("--model-name", type=str, default=None, help="name of the model") + parser = argparse.ArgumentParser( + description="Convert a LLaMA model to a GGML compatible file" + ) + parser.add_argument( + "--dump", + action="store_true", + help="don't convert, just show what's in the model", + ) + parser.add_argument( + "--dump-single", + action="store_true", + help="don't convert, just show what's in a single model file", + ) + parser.add_argument( + "--vocab-only", action="store_true", help="extract only the vocab" + ) + parser.add_argument( + "--no-vocab", action="store_true", help="store model without the vocab" + ) + parser.add_argument( + "--outtype", + choices=output_choices, + help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)", + ) + parser.add_argument( + "--vocab-dir", + type=Path, + help="directory containing tokenizer.model, if separate from model file", + ) + parser.add_argument( + "--vocab-type", + help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft)", + default="spm,hfft", + ) + parser.add_argument( + "--outfile", type=Path, help="path to write to; default: based on input" + ) + parser.add_argument( + "model", + type=Path, + help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)", + ) + parser.add_argument( + "--ctx", type=int, help="model training context (default: based on input)" + ) + parser.add_argument( + "--concurrency", + type=int, + help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", + default=DEFAULT_CONCURRENCY, + ) + parser.add_argument( + "--big-endian", + action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "--pad-vocab", + action="store_true", + help="add pad tokens when model vocab expects more than tokenizer metadata provides", + ) + parser.add_argument( + "--skip-unknown", + action="store_true", + help="skip unknown tensor names instead of failing", + ) + parser.add_argument( + "--verbose", action="store_true", help="increase output verbosity" + ) + parser.add_argument( + "--metadata", + type=Path, + help="Specify the path for an authorship metadata override file", + ) + parser.add_argument( + "--get-outfile", action="store_true", help="get calculated default outfile name" + ) + parser.add_argument( + "--model-name", type=str, default=None, help="name of the model" + ) args = parser.parse_args(args_in) @@ -1353,10 +1674,14 @@ def main(args_in: list[str] | None = None) -> None: model_params_count = per_model_weight_count_estimation(model_plus.model.items()) ftype = pick_output_type(model, args.outtype) - if (metadata is None or metadata.name is None) and params.path_model is not None: + if ( + metadata is None or metadata.name is None + ) and params.path_model is not None: metadata.name = params.path_model.name - print(f"{default_convention_outfile(ftype, params.n_experts, model_params_count, metadata)}") # noqa: NP100 + print( + f"{default_convention_outfile(ftype, params.n_experts, model_params_count, metadata)}" + ) # noqa: NP100 return if args.no_vocab and args.vocab_only: @@ -1370,7 +1695,9 @@ def main(args_in: list[str] | None = None) -> None: if not args.vocab_only: model_plus = load_some_model(dir_model) else: - model_plus = ModelPlus(model = {}, paths = [dir_model / 'dummy'], format = 'none', vocab = None) + model_plus = ModelPlus( + model={}, paths=[dir_model / "dummy"], format="none", vocab=None + ) if args.dump: do_dump_model(model_plus) @@ -1415,17 +1742,24 @@ def main(args_in: list[str] | None = None) -> None: outfile = args.outfile if params is None: params = Params( - n_vocab = vocab.vocab_size, - n_embd = 1, - n_layer = 1, - n_ctx = 1, - n_ff = 1, - n_head = 1, - n_head_kv = 1, - f_norm_eps = 1e-5, + n_vocab=vocab.vocab_size, + n_embd=1, + n_layer=1, + n_ctx=1, + n_ff=1, + n_head=1, + n_head_kv=1, + f_norm_eps=1e-5, ) - OutputFile.write_vocab_only(outfile, params, vocab, special_vocab, - endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata) + OutputFile.write_vocab_only( + outfile, + params, + vocab, + special_vocab, + endianess=endianess, + pad_vocab=args.pad_vocab, + metadata=metadata, + ) logger.info(f"Wrote {outfile}") return @@ -1438,25 +1772,41 @@ def main(args_in: list[str] | None = None) -> None: metadata.name = params.path_model.name model_params_count = per_model_weight_count_estimation(model_plus.model.items()) - logger.info(f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count[0])})") + logger.info( + f"model parameters count : {model_params_count} ({gguf.model_weight_count_rounded_notation(model_params_count[0])})" + ) logger.info(f"Vocab info: {vocab}") logger.info(f"Special vocab info: {special_vocab}") - model = model_plus.model - model = convert_model_names(model, params, args.skip_unknown) - ftype = pick_output_type(model, args.outtype) - model = convert_to_output_type(model, ftype) - outfile = args.outfile or default_outfile(model_plus.paths, ftype, params.n_experts, model_params_count, metadata=metadata) - - metadata.size_label = gguf.size_label(*model_params_count, expert_count=params.n_experts or 0) + model = model_plus.model + model = convert_model_names(model, params, args.skip_unknown) + ftype = pick_output_type(model, args.outtype) + model = convert_to_output_type(model, ftype) + outfile = args.outfile or default_outfile( + model_plus.paths, ftype, params.n_experts, model_params_count, metadata=metadata + ) + + metadata.size_label = gguf.size_label( + *model_params_count, expert_count=params.n_experts or 0 + ) params.ftype = ftype logger.info(f"Writing {outfile}, format {ftype}") - OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, - concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab, metadata=metadata) + OutputFile.write_all( + outfile, + ftype, + params, + model, + vocab, + special_vocab, + concurrency=args.concurrency, + endianess=endianess, + pad_vocab=args.pad_vocab, + metadata=metadata, + ) logger.info(f"Wrote {outfile}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/smallthinker/examples/jeopardy/graph.py b/smallthinker/examples/jeopardy/graph.py index 8bc0706b..e6b6bb6a 100755 --- a/smallthinker/examples/jeopardy/graph.py +++ b/smallthinker/examples/jeopardy/graph.py @@ -11,7 +11,7 @@ def bar_chart(numbers, labels, pos): - plt.bar(pos, numbers, color='blue') + plt.bar(pos, numbers, color="blue") plt.xticks(ticks=pos, labels=labels) plt.title("Jeopardy Results by Model") plt.xlabel("Model") @@ -21,7 +21,9 @@ def bar_chart(numbers, labels, pos): def calculatecorrect(): directory = os.fsencode("./examples/jeopardy/results/") - csv_reader = csv.reader(open("./examples/jeopardy/qasheet.csv", 'rt'), delimiter=',') + csv_reader = csv.reader( + open("./examples/jeopardy/qasheet.csv", "rt"), delimiter="," + ) for row in csv_reader: global rows rows.append(row) @@ -48,7 +50,7 @@ def calculatecorrect(): numbers.append(totalcorrect) -if __name__ == '__main__': +if __name__ == "__main__": calculatecorrect() pos = list(range(numEntries)) labels.append("Human") diff --git a/smallthinker/examples/json_schema_pydantic_example.py b/smallthinker/examples/json_schema_pydantic_example.py index 19c0bdb5..0ee5c40d 100644 --- a/smallthinker/examples/json_schema_pydantic_example.py +++ b/smallthinker/examples/json_schema_pydantic_example.py @@ -10,29 +10,40 @@ if True: - def create_completion(*, response_model=None, endpoint="http://localhost:8080/v1/chat/completions", messages, **kwargs): - ''' + def create_completion( + *, + response_model=None, + endpoint="http://localhost:8080/v1/chat/completions", + messages, + **kwargs, + ): + """ Creates a chat completion using an OpenAI-compatible endpoint w/ JSON schema support (llama.cpp server, llama-cpp-python, Anyscale / Together...) The response_model param takes a type (+ supports Pydantic) and behaves just as w/ Instructor (see below) - ''' + """ response_format = None type_adapter = None if response_model: type_adapter = TypeAdapter(response_model) schema = type_adapter.json_schema() - messages = [{ - "role": "system", - "content": f"You respond in JSON format with the following schema: {json.dumps(schema, indent=2)}" - }] + messages - response_format={"type": "json_object", "schema": schema} + messages = [ + { + "role": "system", + "content": f"You respond in JSON format with the following schema: {json.dumps(schema, indent=2)}", + } + ] + messages + response_format = {"type": "json_object", "schema": schema} - data = requests.post(endpoint, headers={"Content-Type": "application/json"}, - json=dict(messages=messages, response_format=response_format, **kwargs)).json() - if 'error' in data: - raise Exception(data['error']['message']) + data = requests.post( + endpoint, + headers={"Content-Type": "application/json"}, + json=dict(messages=messages, response_format=response_format, **kwargs), + ).json() + if "error" in data: + raise Exception(data["error"]["message"]) content = data["choices"][0]["message"]["content"] return type_adapter.validate_json(content) if type_adapter else content @@ -44,17 +55,20 @@ def create_completion(*, response_model=None, endpoint="http://localhost:8080/v1 # (see https://python.useinstructor.com/) #! pip install instructor openai import instructor, openai + client = instructor.patch( openai.OpenAI(api_key="123", base_url="http://localhost:8080"), - mode=instructor.Mode.JSON_SCHEMA) + mode=instructor.Mode.JSON_SCHEMA, + ) create_completion = client.chat.completions.create -if __name__ == '__main__': +if __name__ == "__main__": class QAPair(BaseModel): class Config: - extra = 'forbid' # triggers additionalProperties: false in the JSON schema + extra = "forbid" # triggers additionalProperties: false in the JSON schema + question: str concise_answer: str justification: str @@ -62,21 +76,28 @@ class Config: class PyramidalSummary(BaseModel): class Config: - extra = 'forbid' # triggers additionalProperties: false in the JSON schema + extra = "forbid" # triggers additionalProperties: false in the JSON schema + title: str summary: str question_answers: Annotated[List[QAPair], MinLen(2)] - sub_sections: Optional[Annotated[List['PyramidalSummary'], MinLen(2)]] - - print("# Summary\n", create_completion( - model="...", - response_model=PyramidalSummary, - messages=[{ - "role": "user", - "content": f""" + sub_sections: Optional[Annotated[List["PyramidalSummary"], MinLen(2)]] + + print( + "# Summary\n", + create_completion( + model="...", + response_model=PyramidalSummary, + messages=[ + { + "role": "user", + "content": f""" You are a highly efficient corporate document summarizer. Create a pyramidal summary of an imaginary internal document about our company processes (starting high-level, going down to each sub sections). Keep questions short, and answers even shorter (trivia / quizz style). - """ - }])) + """, + } + ], + ), + ) diff --git a/smallthinker/examples/json_schema_to_grammar.py b/smallthinker/examples/json_schema_to_grammar.py index ed379585..9b012db9 100755 --- a/smallthinker/examples/json_schema_to_grammar.py +++ b/smallthinker/examples/json_schema_to_grammar.py @@ -8,26 +8,42 @@ import sys from typing import Any, List, Optional, Set, Tuple, Union + def _build_repetition(item_rule, min_items, max_items, separator_rule=None): if max_items == 0: return "" if min_items == 0 and max_items == 1: - return f'{item_rule}?' + return f"{item_rule}?" if not separator_rule: if min_items == 1 and max_items is None: - return f'{item_rule}+' + return f"{item_rule}+" elif min_items == 0 and max_items is None: - return f'{item_rule}*' + return f"{item_rule}*" else: return f'{item_rule}{{{min_items},{max_items if max_items is not None else ""}}}' - result = item_rule + ' ' + _build_repetition(f'({separator_rule} {item_rule})', min_items - 1 if min_items > 0 else 0, max_items - 1 if max_items is not None else None) - return f'({result})?' if min_items == 0 else result + result = ( + item_rule + + " " + + _build_repetition( + f"({separator_rule} {item_rule})", + min_items - 1 if min_items > 0 else 0, + max_items - 1 if max_items is not None else None, + ) + ) + return f"({result})?" if min_items == 0 else result + -def _generate_min_max_int(min_value: Optional[int], max_value: Optional[int], out: list, decimals_left: int = 16, top_level: bool = True): +def _generate_min_max_int( + min_value: Optional[int], + max_value: Optional[int], + out: list, + decimals_left: int = 16, + top_level: bool = True, +): has_min = min_value != None has_max = max_value != None @@ -58,16 +74,16 @@ def uniform_range(from_str: str, to_str: str): while i < len(from_str) and from_str[i] == to_str[i]: i += 1 if i > 0: - out.append("\"") + out.append('"') out.append(from_str[:i]) - out.append("\"") + out.append('"') if i < len(from_str): if i > 0: out.append(" ") sub_len = len(from_str) - i - 1 if sub_len > 0: - from_sub = from_str[i+1:] - to_sub = to_str[i+1:] + from_sub = from_str[i + 1 :] + to_sub = to_str[i + 1 :] sub_zeros = "0" * sub_len sub_nines = "9" * sub_len @@ -90,7 +106,9 @@ def uniform_range(from_str: str, to_str: str): digit_range(chr(ord(from_str[i]) + 1), to_str[i]) to_reached = True else: - digit_range(chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1)) + digit_range( + chr(ord(from_str[i]) + 1), chr(ord(to_str[i]) - 1) + ) out.append(" ") more_digits(sub_len, sub_len) if not to_reached: @@ -108,13 +126,15 @@ def uniform_range(from_str: str, to_str: str): if has_min and has_max: if min_value < 0 and max_value < 0: - out.append("\"-\" (") - _generate_min_max_int(-max_value, -min_value, out, decimals_left, top_level=True) + out.append('"-" (') + _generate_min_max_int( + -max_value, -min_value, out, decimals_left, top_level=True + ) out.append(")") return if min_value < 0: - out.append("\"-\" (") + out.append('"-" (') _generate_min_max_int(0, -min_value, out, decimals_left, top_level=True) out.append(") | ") min_value = 0 @@ -135,7 +155,7 @@ def uniform_range(from_str: str, to_str: str): if has_min: if min_value < 0: - out.append("\"-\" (") + out.append('"-" (') _generate_min_max_int(None, -min_value, out, decimals_left, top_level=False) out.append(") | [0] | [1-9] ") more_digits(0, decimals_left - 1) @@ -147,7 +167,7 @@ def uniform_range(from_str: str, to_str: str): more_digits(1, decimals_left) elif min_value <= 9: c = str(min_value) - range_start = '1' if top_level else '0' + range_start = "1" if top_level else "0" if c > range_start: digit_range(range_start, chr(ord(c) - 1)) out.append(" ") @@ -168,7 +188,9 @@ def uniform_range(from_str: str, to_str: str): out.append(" | ") digit_range(c, c) out.append(" (") - _generate_min_max_int(int(min_s[1:]), None, out, less_decimals, top_level=False) + _generate_min_max_int( + int(min_s[1:]), None, out, less_decimals, top_level=False + ) out.append(")") if c < "9": out.append(" | ") @@ -180,63 +202,89 @@ def uniform_range(from_str: str, to_str: str): if has_max: if max_value >= 0: if top_level: - out.append("\"-\" [1-9] ") + out.append('"-" [1-9] ') more_digits(0, less_decimals) out.append(" | ") _generate_min_max_int(0, max_value, out, decimals_left, top_level=True) else: - out.append("\"-\" (") + out.append('"-" (') _generate_min_max_int(-max_value, None, out, decimals_left, top_level=False) out.append(")") return raise RuntimeError("At least one of min_value or max_value must be set") + class BuiltinRule: def __init__(self, content: str, deps: list | None = None): self.content = content self.deps = deps or [] + # Constraining spaces to prevent model "running away". SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}' PRIMITIVE_RULES = { - 'boolean' : BuiltinRule('("true" | "false") space', []), - 'decimal-part' : BuiltinRule('[0-9]{1,16}', []), - 'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []), - 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), - 'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']), - 'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), - 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), - 'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']), - 'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []), - 'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []), - 'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']), - 'null' : BuiltinRule('"null" space', []), + "boolean": BuiltinRule('("true" | "false") space', []), + "decimal-part": BuiltinRule("[0-9]{1,16}", []), + "integral-part": BuiltinRule("[0] | [1-9] [0-9]{0,15}", []), + "number": BuiltinRule( + '("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', + ["integral-part", "decimal-part"], + ), + "integer": BuiltinRule('("-"? integral-part) space', ["integral-part"]), + "value": BuiltinRule( + "object | array | string | number | boolean | null", + ["object", "array", "string", "number", "boolean", "null"], + ), + "object": BuiltinRule( + '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', + ["string", "value"], + ), + "array": BuiltinRule( + '"[" space ( value ("," space value)* )? "]" space', ["value"] + ), + "uuid": BuiltinRule( + r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', + [], + ), + "char": BuiltinRule( + r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', [] + ), + "string": BuiltinRule(r'"\"" char* "\"" space', ["char"]), + "null": BuiltinRule('"null" space', []), } # TODO: support "uri", "email" string formats STRING_FORMAT_RULES = { - 'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), - 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), - 'date-time' : BuiltinRule('date "T" time', ['date', 'time']), - 'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']), - 'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']), - 'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']), + "date": BuiltinRule( + '[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( "0" [1-9] | [1-2] [0-9] | "3" [0-1] )', + [], + ), + "time": BuiltinRule( + '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', + [], + ), + "date-time": BuiltinRule('date "T" time', ["date", "time"]), + "date-string": BuiltinRule('"\\"" date "\\"" space', ["date"]), + "time-string": BuiltinRule('"\\"" time "\\"" space', ["time"]), + "date-time-string": BuiltinRule('"\\"" date-time "\\"" space', ["date-time"]), } -DOTALL = '[\\U00000000-\\U0010FFFF]' -DOT = '[^\\x0A\\x0D]' +DOTALL = "[\\U00000000-\\U0010FFFF]" +DOT = "[^\\x0A\\x0D]" -RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) +RESERVED_NAMES = set( + ["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()] +) -INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+') +INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+") GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]') -GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'} +GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"', "-": "\\-", "]": "\\]"} -NON_LITERAL_SET = set('|.()[]{}*+?') -ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?') +NON_LITERAL_SET = set("|.()[]{}*+?") +ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set("^$.[]()|{}*+?") class SchemaConverter: @@ -246,7 +294,7 @@ def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern): self._dotall = dotall self._raw_pattern = raw_pattern self._rules = { - 'space': SPACE_RULE, + "space": SPACE_RULE, } self._refs = {} self._refs_being_resolved = set() @@ -257,28 +305,31 @@ def _format_literal(self, literal): ) return f'"{escaped}"' - def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str: - ''' - not_literal('a') -> '[^a]' - not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?' - ''' - assert len(literal) > 0, 'Empty literal not supported' + def not_literal( + self, literal: str, dotall: bool = True, maybe_escaped_underscores=False + ) -> str: + """ + not_literal('a') -> '[^a]' + not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?' + """ + assert len(literal) > 0, "Empty literal not supported" + def recurse(i: int): c = literal[i] - if maybe_escaped_underscores and c == '_': - yield f'[^{c}\\\\]' - yield ' | ' + if maybe_escaped_underscores and c == "_": + yield f"[^{c}\\\\]" + yield " | " yield f'"\\\\"? "{c}"' else: - yield f'[^{c}]' + yield f"[^{c}]" if i < len(literal) - 1: - yield ' | ' + yield " | " yield self._format_literal(c) - yield ' (' + yield " (" yield from recurse(i + 1) - yield ')?' + yield ")?" - return ''.join(('(', *recurse(0), ')')) + return "".join(("(", *recurse(0), ")")) def _not_strings(self, strings): class TrieNode: @@ -296,7 +347,7 @@ def insert(self, string): for s in strings: trie.insert(s) - char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) + char_rule = self._add_primitive("char", PRIMITIVE_RULES["char"]) out = ['["] ( '] def visit(node): @@ -308,70 +359,81 @@ def visit(node): if first: first = False else: - out.append(' | ') - out.append(f'[{c}]') + out.append(" | ") + out.append(f"[{c}]") if child.children: - out.append(f' (') + out.append(f" (") visit(child) - out.append(')') + out.append(")") elif child.is_end_of_string: - out.append(f' {char_rule}+') + out.append(f" {char_rule}+") if node.children: if not first: - out.append(' | ') + out.append(" | ") out.append(f'[^"{"".join(rejects)}] {char_rule}*') + visit(trie) out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space') - return ''.join(out) + return "".join(out) def _add_rule(self, name, rule): - esc_name = INVALID_RULE_CHARS_RE.sub('-', name) + esc_name = INVALID_RULE_CHARS_RE.sub("-", name) if esc_name not in self._rules or self._rules[esc_name] == rule: key = esc_name else: i = 0 - while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule: + while ( + f"{esc_name}{i}" in self._rules + and self._rules[f"{esc_name}{i}"] != rule + ): i += 1 - key = f'{esc_name}{i}' + key = f"{esc_name}{i}" self._rules[key] = rule return key def resolve_refs(self, schema: dict, url: str): - ''' - Resolves all $ref fields in the given schema, fetching any remote schemas, - replacing $ref with absolute reference URL and populating self._refs with the - respective referenced (sub)schema dictionaries. - ''' + """ + Resolves all $ref fields in the given schema, fetching any remote schemas, + replacing $ref with absolute reference URL and populating self._refs with the + respective referenced (sub)schema dictionaries. + """ + def visit(n: dict): if isinstance(n, list): return [visit(x) for x in n] elif isinstance(n, dict): - ref = n.get('$ref') + ref = n.get("$ref") if ref is not None and ref not in self._refs: - if ref.startswith('https://'): - assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)' + if ref.startswith("https://"): + assert ( + self._allow_fetch + ), "Fetching remote schemas is not allowed (use --allow-fetch for force)" import requests - frag_split = ref.split('#') + frag_split = ref.split("#") base_url = frag_split[0] target = self._refs.get(base_url) if target is None: - target = self.resolve_refs(requests.get(ref).json(), base_url) + target = self.resolve_refs( + requests.get(ref).json(), base_url + ) self._refs[base_url] = target - if len(frag_split) == 1 or frag_split[-1] == '': + if len(frag_split) == 1 or frag_split[-1] == "": return target - elif ref.startswith('#/'): + elif ref.startswith("#/"): target = schema - ref = f'{url}{ref}' - n['$ref'] = ref + ref = f"{url}{ref}" + n["$ref"] = ref else: - raise ValueError(f'Unsupported ref {ref}') + raise ValueError(f"Unsupported ref {ref}") - for sel in ref.split('#')[-1].split('/')[1:]: - assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' + for sel in ref.split("#")[-1].split("/")[1:]: + assert ( + target is not None and sel in target + ), f"Error resolving ref {ref}: {sel} not in {target}" target = target[sel] self._refs[ref] = target @@ -380,28 +442,33 @@ def visit(n: dict): visit(v) return n + return visit(schema) def _generate_union_rule(self, name, alt_schemas): - return ' | '.join(( - self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}') - for i, alt_schema in enumerate(alt_schemas) - )) + return " | ".join( + ( + self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}') + for i, alt_schema in enumerate(alt_schemas) + ) + ) def _visit_pattern(self, pattern, name): - ''' - Transforms a regular expression pattern into a GBNF rule. + """ + Transforms a regular expression pattern into a GBNF rule. - Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions - Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions + Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md - Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers. + Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers. - Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which - we define sub-rules to keep the output lean. - ''' + Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which + we define sub-rules to keep the output lean. + """ - assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"' + assert pattern.startswith("^") and pattern.endswith( + "$" + ), 'Pattern must start with "^" and end with "$"' pattern = pattern[1:-1] sub_rule_ids = {} @@ -410,12 +477,12 @@ def _visit_pattern(self, pattern, name): def to_rule(s: tuple[str, bool]) -> str: (txt, is_literal) = s - return "\"" + txt + "\"" if is_literal else txt + return '"' + txt + '"' if is_literal else txt def transform() -> tuple[str, bool]: - ''' - Parse a unit at index i (advancing it), and return its string representation + whether it's a literal. - ''' + """ + Parse a unit at index i (advancing it), and return its string representation + whether it's a literal. + """ nonlocal i nonlocal pattern nonlocal sub_rule_ids @@ -433,64 +500,72 @@ def get_dot(): else: # Accept any character... except \n and \r line break chars (\x0A and \xOD) rule = DOT - return self._add_rule(f'dot', rule) + return self._add_rule(f"dot", rule) def join_seq(): nonlocal seq ret = [] for is_literal, g in itertools.groupby(seq, lambda x: x[1]): if is_literal: - ret.append((''.join(x[0] for x in g), True)) + ret.append(("".join(x[0] for x in g), True)) else: ret.extend(g) if len(ret) == 1: return ret[0] - return (' '.join(to_rule(x) for x in seq), False) + return (" ".join(to_rule(x) for x in seq), False) while i < length: c = pattern[i] - if c == '.': + if c == ".": seq.append((get_dot(), False)) i += 1 - elif c == '(': + elif c == "(": i += 1 if i < length: - assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/' - seq.append((f'({to_rule(transform())})', False)) - elif c == ')': + assert ( + pattern[i] != "?" + ), f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/' + seq.append((f"({to_rule(transform())})", False)) + elif c == ")": i += 1 - assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}' + assert ( + start > 0 and pattern[start - 1] == "(" + ), f"Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}" return join_seq() - elif c == '[': + elif c == "[": square_brackets = c i += 1 - while i < length and pattern[i] != ']': - if pattern[i] == '\\': - square_brackets += pattern[i:i+2] + while i < length and pattern[i] != "]": + if pattern[i] == "\\": + square_brackets += pattern[i : i + 2] i += 2 else: square_brackets += pattern[i] i += 1 - assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}' - square_brackets += ']' + assert ( + i < length + ), f"Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}" + square_brackets += "]" i += 1 seq.append((square_brackets, False)) - elif c == '|': - seq.append(('|', False)) + elif c == "|": + seq.append(("|", False)) i += 1 - elif c in ('*', '+', '?'): + elif c in ("*", "+", "?"): seq[-1] = (to_rule(seq[-1]) + c, False) i += 1 - elif c == '{': + elif c == "{": curly_brackets = c i += 1 - while i < length and pattern[i] != '}': + while i < length and pattern[i] != "}": curly_brackets += pattern[i] i += 1 - assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}' - curly_brackets += '}' + assert ( + i < length + ), f"Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}" + curly_brackets += "}" i += 1 - nums = [s.strip() for s in curly_brackets[1:-1].split(',')] + nums = [s.strip() for s in curly_brackets[1:-1].split(",")] min_times = 0 max_times = None try: @@ -502,35 +577,46 @@ def join_seq(): min_times = int(nums[0]) if nums[0] else 0 max_times = int(nums[1]) if nums[1] else None except ValueError: - raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/') + raise ValueError( + f"Invalid quantifier {curly_brackets} in /{pattern}/" + ) (sub, sub_is_literal) = seq[-1] if not sub_is_literal: id = sub_rule_ids.get(sub) if id is None: - id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub) + id = self._add_rule(f"{name}-{len(sub_rule_ids) + 1}", sub) sub_rule_ids[sub] = id sub = id - seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times), False) + seq[-1] = ( + _build_repetition( + f'"{sub}"' if sub_is_literal else sub, min_times, max_times + ), + False, + ) else: - literal = '' + literal = "" while i < length: - if pattern[i] == '\\' and i < length - 1: + if pattern[i] == "\\" and i < length - 1: next = pattern[i + 1] if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS: i += 1 literal += pattern[i] i += 1 else: - literal += pattern[i:i+2] + literal += pattern[i : i + 2] i += 2 elif pattern[i] == '"' and not self._raw_pattern: literal += '\\"' i += 1 - elif pattern[i] not in NON_LITERAL_SET and \ - (i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET): + elif pattern[i] not in NON_LITERAL_SET and ( + i == length - 1 + or literal == "" + or pattern[i + 1] == "." + or pattern[i + 1] not in NON_LITERAL_SET + ): literal += pattern[i] i += 1 else: @@ -542,12 +628,15 @@ def join_seq(): return self._add_rule( name, - to_rule(transform()) if self._raw_pattern \ - else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space") - + ( + to_rule(transform()) + if self._raw_pattern + else '"\\"" (' + to_rule(transform()) + ') "\\"" space' + ), + ) def _resolve_ref(self, ref): - ref_name = ref.split('/')[-1] + ref_name = ref.split("/")[-1] if ref_name not in self._rules and ref not in self._refs_being_resolved: self._refs_being_resolved.add(ref) resolved = self._refs[ref] @@ -559,153 +648,236 @@ def _generate_constant_rule(self, value): return self._format_literal(json.dumps(value)) def visit(self, schema, name): - schema_type = schema.get('type') - schema_format = schema.get('format') - rule_name = name + '-' if name in RESERVED_NAMES else name or 'root' + schema_type = schema.get("type") + schema_format = schema.get("format") + rule_name = name + "-" if name in RESERVED_NAMES else name or "root" - if (ref := schema.get('$ref')) is not None: + if (ref := schema.get("$ref")) is not None: return self._add_rule(rule_name, self._resolve_ref(ref)) - elif 'oneOf' in schema or 'anyOf' in schema: - return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf'])) + elif "oneOf" in schema or "anyOf" in schema: + return self._add_rule( + rule_name, + self._generate_union_rule(name, schema.get("oneOf") or schema["anyOf"]), + ) elif isinstance(schema_type, list): - return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type])) + return self._add_rule( + rule_name, + self._generate_union_rule( + name, [{**schema, "type": t} for t in schema_type] + ), + ) - elif 'const' in schema: - return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space') + elif "const" in schema: + return self._add_rule( + rule_name, self._generate_constant_rule(schema["const"]) + " space" + ) - elif 'enum' in schema: - rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space' + elif "enum" in schema: + rule = ( + "(" + + " | ".join((self._generate_constant_rule(v) for v in schema["enum"])) + + ") space" + ) return self._add_rule(rule_name, rule) - elif schema_type in (None, 'object') and \ - ('properties' in schema or \ - ('additionalProperties' in schema and schema['additionalProperties'] is not True)): - required = set(schema.get('required', [])) - properties = list(schema.get('properties', {}).items()) - return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties'))) + elif schema_type in (None, "object") and ( + "properties" in schema + or ( + "additionalProperties" in schema + and schema["additionalProperties"] is not True + ) + ): + required = set(schema.get("required", [])) + properties = list(schema.get("properties", {}).items()) + return self._add_rule( + rule_name, + self._build_object_rule( + properties, required, name, schema.get("additionalProperties") + ), + ) - elif schema_type in (None, 'object') and 'allOf' in schema: + elif schema_type in (None, "object") and "allOf" in schema: required = set() properties = [] hybrid_name = name + def add_component(comp_schema, is_required): - if (ref := comp_schema.get('$ref')) is not None: + if (ref := comp_schema.get("$ref")) is not None: comp_schema = self._refs[ref] - if 'properties' in comp_schema: - for prop_name, prop_schema in comp_schema['properties'].items(): + if "properties" in comp_schema: + for prop_name, prop_schema in comp_schema["properties"].items(): properties.append((prop_name, prop_schema)) if is_required: required.add(prop_name) - for t in schema['allOf']: - if 'anyOf' in t: - for tt in t['anyOf']: + for t in schema["allOf"]: + if "anyOf" in t: + for tt in t["anyOf"]: add_component(tt, is_required=False) else: add_component(t, is_required=True) - return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None)) + return self._add_rule( + rule_name, + self._build_object_rule( + properties, required, hybrid_name, additional_properties=None + ), + ) - elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): - items = schema.get('items') or schema['prefixItems'] + elif schema_type in (None, "array") and ( + "items" in schema or "prefixItems" in schema + ): + items = schema.get("items") or schema["prefixItems"] if isinstance(items, list): return self._add_rule( rule_name, - '"[" space ' + - ' "," space '.join( + '"[" space ' + + ' "," space '.join( self.visit(item, f'{name}{"-" if name else ""}tuple-{i}') - for i, item in enumerate(items)) + - ' "]" space') + for i, item in enumerate(items) + ) + + ' "]" space', + ) else: item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') min_items = schema.get("minItems", 0) max_items = schema.get("maxItems") - return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space') + return self._add_rule( + rule_name, + '"[" space ' + + _build_repetition( + item_rule_name, min_items, max_items, separator_rule='"," space' + ) + + ' "]" space', + ) - elif schema_type in (None, 'string') and 'pattern' in schema: - return self._visit_pattern(schema['pattern'], rule_name) + elif schema_type in (None, "string") and "pattern" in schema: + return self._visit_pattern(schema["pattern"], rule_name) - elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''): + elif schema_type in (None, "string") and re.match( + r"^uuid[1-5]?$", schema_format or "" + ): return self._add_primitive( - 'root' if rule_name == 'root' else schema_format, - PRIMITIVE_RULES['uuid'] + "root" if rule_name == "root" else schema_format, + PRIMITIVE_RULES["uuid"], ) - elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES: - prim_name = f'{schema_format}-string' - return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name])) - - elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema): - char_rule = self._add_primitive('char', PRIMITIVE_RULES['char']) - min_len = schema.get('minLength', 0) - max_len = schema.get('maxLength') + elif ( + schema_type in (None, "string") + and f"{schema_format}-string" in STRING_FORMAT_RULES + ): + prim_name = f"{schema_format}-string" + return self._add_rule( + rule_name, + self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]), + ) - return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space') + elif schema_type == "string" and ( + "minLength" in schema or "maxLength" in schema + ): + char_rule = self._add_primitive("char", PRIMITIVE_RULES["char"]) + min_len = schema.get("minLength", 0) + max_len = schema.get("maxLength") + + return self._add_rule( + rule_name, + r'"\"" ' + + _build_repetition(char_rule, min_len, max_len) + + r' "\"" space', + ) - elif schema_type in (None, 'integer') and \ - ('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema): + elif schema_type in (None, "integer") and ( + "minimum" in schema + or "exclusiveMinimum" in schema + or "maximum" in schema + or "exclusiveMaximum" in schema + ): min_value = None max_value = None - if 'minimum' in schema: - min_value = schema['minimum'] - elif 'exclusiveMinimum' in schema: - min_value = schema['exclusiveMinimum'] + 1 - if 'maximum' in schema: - max_value = schema['maximum'] - elif 'exclusiveMaximum' in schema: - max_value = schema['exclusiveMaximum'] - 1 + if "minimum" in schema: + min_value = schema["minimum"] + elif "exclusiveMinimum" in schema: + min_value = schema["exclusiveMinimum"] + 1 + if "maximum" in schema: + max_value = schema["maximum"] + elif "exclusiveMaximum" in schema: + max_value = schema["exclusiveMaximum"] - 1 out = ["("] _generate_min_max_int(min_value, max_value, out) out.append(") space") - return self._add_rule(rule_name, ''.join(out)) + return self._add_rule(rule_name, "".join(out)) - elif (schema_type == 'object') or (len(schema) == 0): - return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) + elif (schema_type == "object") or (len(schema) == 0): + return self._add_rule( + rule_name, self._add_primitive("object", PRIMITIVE_RULES["object"]) + ) else: - assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' + assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero - return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type]) + return self._add_primitive( + "root" if rule_name == "root" else schema_type, + PRIMITIVE_RULES[schema_type], + ) def _add_primitive(self, name: str, rule: BuiltinRule): n = self._add_rule(name, rule.content) for dep in rule.deps: dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep) - assert dep_rule, f'Rule {dep} not known' + assert dep_rule, f"Rule {dep} not known" if dep not in self._rules: self._add_primitive(dep, dep_rule) return n - def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Optional[Union[bool, Any]]): + def _build_object_rule( + self, + properties: List[Tuple[str, Any]], + required: Set[str], + name: str, + additional_properties: Optional[Union[bool, Any]], + ): prop_order = self._prop_order # sort by position in prop_order (if specified) then by original order - sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))] + sorted_props = [ + kv[0] + for _, kv in sorted( + enumerate(properties), + key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]), + ) + ] prop_kv_rule_names = {} for prop_name, prop_schema in properties: - prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}') + prop_rule_name = self.visit( + prop_schema, f'{name}{"-" if name else ""}{prop_name}' + ) prop_kv_rule_names[prop_name] = self._add_rule( f'{name}{"-" if name else ""}{prop_name}-kv', - fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}' + rf'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}', ) required_props = [k for k in sorted_props if k in required] optional_props = [k for k in sorted_props if k not in required] if additional_properties is not None and additional_properties != False: sub_name = f'{name}{"-" if name else ""}additional' - value_rule = self.visit(additional_properties, f'{sub_name}-value') if isinstance(additional_properties, dict) else \ - self._add_primitive('value', PRIMITIVE_RULES['value']) - key_rule = self._add_primitive('string', PRIMITIVE_RULES['string']) if not sorted_props \ - else self._add_rule(f'{sub_name}-k', self._not_strings(sorted_props)) + value_rule = ( + self.visit(additional_properties, f"{sub_name}-value") + if isinstance(additional_properties, dict) + else self._add_primitive("value", PRIMITIVE_RULES["value"]) + ) + key_rule = ( + self._add_primitive("string", PRIMITIVE_RULES["string"]) + if not sorted_props + else self._add_rule(f"{sub_name}-k", self._not_strings(sorted_props)) + ) prop_kv_rule_names["*"] = self._add_rule( - f'{sub_name}-kv', - f'{key_rule} ":" space {value_rule}' + f"{sub_name}-kv", f'{key_rule} ":" space {value_rule}' ) optional_props.append("*") @@ -713,7 +885,7 @@ def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[st rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props) if optional_props: - rule += ' (' + rule += " (" if required_props: rule += ' "," space ( ' @@ -722,93 +894,98 @@ def get_recursive_refs(ks, first_is_optional): kv_rule_name = prop_kv_rule_names[k] comma_ref = f'( "," space {kv_rule_name} )' if first_is_optional: - res = comma_ref + ('*' if k == '*' else '?') + res = comma_ref + ("*" if k == "*" else "?") else: - res = kv_rule_name + (' ' + comma_ref + "*" if k == '*' else '') + res = kv_rule_name + (" " + comma_ref + "*" if k == "*" else "") if len(rest) > 0: - res += ' ' + self._add_rule( + res += " " + self._add_rule( f'{name}{"-" if name else ""}{k}-rest', - get_recursive_refs(rest, first_is_optional=True) + get_recursive_refs(rest, first_is_optional=True), ) return res - rule += ' | '.join( + rule += " | ".join( get_recursive_refs(optional_props[i:], first_is_optional=False) for i in range(len(optional_props)) ) if required_props: - rule += ' )' - rule += ' )?' + rule += " )" + rule += " )?" rule += ' "}" space' return rule def format_grammar(self): - return '\n'.join( - f'{name} ::= {rule}' + return "\n".join( + f"{name} ::= {rule}" for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0]) ) -def main(args_in = None): +def main(args_in=None): parser = argparse.ArgumentParser( - description=''' + description=""" Generates a grammar (suitable for use in ./llama-cli) that produces JSON conforming to a given JSON schema. Only a subset of JSON schema features are supported; more may be added in the future. - ''', + """, ) parser.add_argument( - '--prop-order', + "--prop-order", default=[], - type=lambda s: s.split(','), - help=''' + type=lambda s: s.split(","), + help=""" comma-separated property names defining the order of precedence for object properties; properties not specified here are given lower precedence than those that are, and are kept in their original order from the schema. Required properties are always given precedence over optional properties. - ''' + """, ) parser.add_argument( - '--allow-fetch', - action='store_true', + "--allow-fetch", + action="store_true", default=False, - help='Whether to allow fetching referenced schemas over HTTPS') + help="Whether to allow fetching referenced schemas over HTTPS", + ) parser.add_argument( - '--dotall', - action='store_true', + "--dotall", + action="store_true", default=False, - help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns') + help='Whether to treat dot (".") as matching all chars including line breaks in regular expression patterns', + ) parser.add_argument( - '--raw-pattern', - action='store_true', + "--raw-pattern", + action="store_true", default=False, - help='Treats string patterns as raw patterns w/o quotes (or quote escapes)') + help="Treats string patterns as raw patterns w/o quotes (or quote escapes)", + ) - parser.add_argument('schema', help='file containing JSON schema ("-" for stdin)') + parser.add_argument("schema", help='file containing JSON schema ("-" for stdin)') args = parser.parse_args(args_in) - if args.schema.startswith('https://'): + if args.schema.startswith("https://"): url = args.schema import requests + schema = requests.get(url).json() - elif args.schema == '-': - url = 'stdin' + elif args.schema == "-": + url = "stdin" schema = json.load(sys.stdin) else: - url = f'file://{args.schema}' + url = f"file://{args.schema}" with open(args.schema) as f: schema = json.load(f) converter = SchemaConverter( prop_order={name: idx for idx, name in enumerate(args.prop_order)}, allow_fetch=args.allow_fetch, dotall=args.dotall, - raw_pattern=args.raw_pattern) + raw_pattern=args.raw_pattern, + ) schema = converter.resolve_refs(schema, url) - converter.visit(schema, '') + converter.visit(schema, "") print(converter.format_grammar()) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/smallthinker/examples/pydantic_models_to_grammar.py b/smallthinker/examples/pydantic_models_to_grammar.py index 93e5dcb6..45b27c33 100644 --- a/smallthinker/examples/pydantic_models_to_grammar.py +++ b/smallthinker/examples/pydantic_models_to_grammar.py @@ -6,7 +6,17 @@ from copy import copy from enum import Enum from inspect import getdoc, isclass -from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union, get_args, get_origin, get_type_hints +from typing import ( + TYPE_CHECKING, + Any, + Callable, + List, + Optional, + Union, + get_args, + get_origin, + get_type_hints, +) from docstring_parser import parse from pydantic import BaseModel, create_model @@ -114,7 +124,9 @@ def generate_list_rule(element_type): def get_members_structure(cls, rule_name): if issubclass(cls, Enum): # Handle Enum types - members = [f'"\\"{member.value}\\""' for name, member in cls.__members__.items()] + members = [ + f'"\\"{member.value}\\""' for name, member in cls.__members__.items() + ] return f"{cls.__name__.lower()} ::= " + " | ".join(members) if cls.__annotations__ and cls.__annotations__ != {}: result = f'{rule_name} ::= "{{"' @@ -212,7 +224,9 @@ def generate_gbnf_integer_rules(max_digit=None, min_digit=None): return integer_rule, additional_rules -def generate_gbnf_float_rules(max_digit=None, min_digit=None, max_precision=None, min_precision=None): +def generate_gbnf_float_rules( + max_digit=None, min_digit=None, max_precision=None, min_precision=None +): """ Generate GBNF float rules based on the given constraints. @@ -249,21 +263,29 @@ def generate_gbnf_float_rules(max_digit=None, min_digit=None, max_precision=None fractional_part_rule = "fractional-part" fractional_rule_part = "" if max_precision is not None or min_precision is not None: - fractional_part_rule += (f"-max{max_precision}" if max_precision is not None else "") + ( - f"-min{min_precision}" if min_precision is not None else "" - ) + fractional_part_rule += ( + f"-max{max_precision}" if max_precision is not None else "" + ) + (f"-min{min_precision}" if min_precision is not None else "") # Minimum number of digits - fractional_rule_part = "[0-9]" * (min_precision if min_precision is not None else 1) + fractional_rule_part = "[0-9]" * ( + min_precision if min_precision is not None else 1 + ) # Optional additional digits fractional_rule_part += "".join( - [" [0-9]?"] * ((max_precision - ( - min_precision if min_precision is not None else 1)) if max_precision is not None else 0) + [" [0-9]?"] + * ( + (max_precision - (min_precision if min_precision is not None else 1)) + if max_precision is not None + else 0 + ) ) additional_rules.append(f"{fractional_part_rule} ::= {fractional_rule_part}") # Define the float rule float_rule = f"float-{max_digit if max_digit is not None else 'X'}-{min_digit if min_digit is not None else 'X'}-{max_precision if max_precision is not None else 'X'}-{min_precision if min_precision is not None else 'X'}" - additional_rules.append(f'{float_rule} ::= {integer_part_rule} "." {fractional_part_rule}') + additional_rules.append( + f'{float_rule} ::= {integer_part_rule} "." {fractional_part_rule}' + ) # Generating the integer part rule definition, if necessary if max_digit is not None or min_digit is not None: @@ -271,14 +293,22 @@ def generate_gbnf_float_rules(max_digit=None, min_digit=None, max_precision=None if min_digit is not None and min_digit > 1: integer_rule_part += " [0-9]" * (min_digit - 1) if max_digit is not None: - integer_rule_part += "".join([" [0-9]?"] * (max_digit - (min_digit if min_digit is not None else 1))) + integer_rule_part += "".join( + [" [0-9]?"] * (max_digit - (min_digit if min_digit is not None else 1)) + ) additional_rules.append(f"{integer_part_rule} ::= {integer_rule_part.strip()}") return float_rule, additional_rules def generate_gbnf_rule_for_type( - model_name, field_name, field_type, is_optional, processed_models, created_rules, field_info=None + model_name, + field_name, + field_type, + is_optional, + processed_models, + created_rules, + field_info=None, ) -> tuple[str, list[str]]: """ Generate GBNF rule for a given field type. @@ -305,18 +335,27 @@ def generate_gbnf_rule_for_type( if isclass(origin_type) and issubclass(origin_type, BaseModel): nested_model_name = format_model_and_field_name(field_type.__name__) - nested_model_rules, _ = generate_gbnf_grammar(field_type, processed_models, created_rules) + nested_model_rules, _ = generate_gbnf_grammar( + field_type, processed_models, created_rules + ) rules.extend(nested_model_rules) gbnf_type, rules = nested_model_name, rules elif isclass(origin_type) and issubclass(origin_type, Enum): - enum_values = [f'"\\"{e.value}\\""' for e in field_type] # Adding escaped quotes + enum_values = [ + f'"\\"{e.value}\\""' for e in field_type + ] # Adding escaped quotes enum_rule = f"{model_name}-{field_name} ::= {' | '.join(enum_values)}" rules.append(enum_rule) gbnf_type, rules = model_name + "-" + field_name, rules elif origin_type is list: # Array element_type = get_args(field_type)[0] element_rule_name, additional_rules = generate_gbnf_rule_for_type( - model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules + model_name, + f"{field_name}-element", + element_type, + is_optional, + processed_models, + created_rules, ) rules.extend(additional_rules) array_rule = f"""{model_name}-{field_name} ::= "[" ws {element_rule_name} ("," ws {element_rule_name})* "]" """ @@ -326,7 +365,12 @@ def generate_gbnf_rule_for_type( elif origin_type is set: # Array element_type = get_args(field_type)[0] element_rule_name, additional_rules = generate_gbnf_rule_for_type( - model_name, f"{field_name}-element", element_type, is_optional, processed_models, created_rules + model_name, + f"{field_name}-element", + element_type, + is_optional, + processed_models, + created_rules, ) rules.extend(additional_rules) array_rule = f"""{model_name}-{field_name} ::= "[" ws {element_rule_name} ("," ws {element_rule_name})* "]" """ @@ -339,10 +383,20 @@ def generate_gbnf_rule_for_type( key_type, value_type = get_args(field_type) additional_key_type, additional_key_rules = generate_gbnf_rule_for_type( - model_name, f"{field_name}-key-type", key_type, is_optional, processed_models, created_rules + model_name, + f"{field_name}-key-type", + key_type, + is_optional, + processed_models, + created_rules, ) additional_value_type, additional_value_rules = generate_gbnf_rule_for_type( - model_name, f"{field_name}-value-type", value_type, is_optional, processed_models, created_rules + model_name, + f"{field_name}-value-type", + value_type, + is_optional, + processed_models, + created_rules, ) gbnf_type = rf'{gbnf_type} ::= "{{" ( {additional_key_type} ": " {additional_value_type} ("," "\n" ws {additional_key_type} ":" {additional_value_type})* )? "}}" ' @@ -355,14 +409,24 @@ def generate_gbnf_rule_for_type( for union_type in union_types: if isinstance(union_type, GenericAlias): union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type( - model_name, field_name, union_type, False, processed_models, created_rules + model_name, + field_name, + union_type, + False, + processed_models, + created_rules, ) union_rules.append(union_gbnf_type) rules.extend(union_rules_list) elif not issubclass(union_type, type(None)): union_gbnf_type, union_rules_list = generate_gbnf_rule_for_type( - model_name, field_name, union_type, False, processed_models, created_rules + model_name, + field_name, + union_type, + False, + processed_models, + created_rules, ) union_rules.append(union_gbnf_type) rules.extend(union_rules_list) @@ -371,19 +435,37 @@ def generate_gbnf_rule_for_type( if len(union_rules) == 1: union_grammar_rule = f"{model_name}-{field_name}-optional ::= {' | '.join(union_rules)} | null" else: - union_grammar_rule = f"{model_name}-{field_name}-union ::= {' | '.join(union_rules)}" + union_grammar_rule = ( + f"{model_name}-{field_name}-union ::= {' | '.join(union_rules)}" + ) rules.append(union_grammar_rule) if len(union_rules) == 1: gbnf_type = f"{model_name}-{field_name}-optional" else: gbnf_type = f"{model_name}-{field_name}-union" elif isclass(origin_type) and issubclass(origin_type, str): - if field_info and hasattr(field_info, "json_schema_extra") and field_info.json_schema_extra is not None: - triple_quoted_string = field_info.json_schema_extra.get("triple_quoted_string", False) - markdown_string = field_info.json_schema_extra.get("markdown_code_block", False) + if ( + field_info + and hasattr(field_info, "json_schema_extra") + and field_info.json_schema_extra is not None + ): + triple_quoted_string = field_info.json_schema_extra.get( + "triple_quoted_string", False + ) + markdown_string = field_info.json_schema_extra.get( + "markdown_code_block", False + ) - gbnf_type = PydanticDataType.TRIPLE_QUOTED_STRING.value if triple_quoted_string else PydanticDataType.STRING.value - gbnf_type = PydanticDataType.MARKDOWN_CODE_BLOCK.value if markdown_string else gbnf_type + gbnf_type = ( + PydanticDataType.TRIPLE_QUOTED_STRING.value + if triple_quoted_string + else PydanticDataType.STRING.value + ) + gbnf_type = ( + PydanticDataType.MARKDOWN_CODE_BLOCK.value + if markdown_string + else gbnf_type + ) elif field_info and hasattr(field_info, "pattern"): # Convert regex pattern to grammar rule @@ -401,21 +483,32 @@ def generate_gbnf_rule_for_type( ): # Retrieve precision attributes for floats max_precision = ( - field_info.json_schema_extra.get("max_precision") if field_info and hasattr(field_info, - "json_schema_extra") else None + field_info.json_schema_extra.get("max_precision") + if field_info and hasattr(field_info, "json_schema_extra") + else None ) min_precision = ( - field_info.json_schema_extra.get("min_precision") if field_info and hasattr(field_info, - "json_schema_extra") else None + field_info.json_schema_extra.get("min_precision") + if field_info and hasattr(field_info, "json_schema_extra") + else None + ) + max_digits = ( + field_info.json_schema_extra.get("max_digit") + if field_info and hasattr(field_info, "json_schema_extra") + else None + ) + min_digits = ( + field_info.json_schema_extra.get("min_digit") + if field_info and hasattr(field_info, "json_schema_extra") + else None ) - max_digits = field_info.json_schema_extra.get("max_digit") if field_info and hasattr(field_info, - "json_schema_extra") else None - min_digits = field_info.json_schema_extra.get("min_digit") if field_info and hasattr(field_info, - "json_schema_extra") else None # Generate GBNF rule for float with given attributes gbnf_type, rules = generate_gbnf_float_rules( - max_digit=max_digits, min_digit=min_digits, max_precision=max_precision, min_precision=min_precision + max_digit=max_digits, + min_digit=min_digits, + max_precision=max_precision, + min_precision=min_precision, ) elif ( @@ -426,20 +519,32 @@ def generate_gbnf_rule_for_type( and field_info.json_schema_extra is not None ): # Retrieve digit attributes for integers - max_digits = field_info.json_schema_extra.get("max_digit") if field_info and hasattr(field_info, - "json_schema_extra") else None - min_digits = field_info.json_schema_extra.get("min_digit") if field_info and hasattr(field_info, - "json_schema_extra") else None + max_digits = ( + field_info.json_schema_extra.get("max_digit") + if field_info and hasattr(field_info, "json_schema_extra") + else None + ) + min_digits = ( + field_info.json_schema_extra.get("min_digit") + if field_info and hasattr(field_info, "json_schema_extra") + else None + ) # Generate GBNF rule for integer with given attributes - gbnf_type, rules = generate_gbnf_integer_rules(max_digit=max_digits, min_digit=min_digits) + gbnf_type, rules = generate_gbnf_integer_rules( + max_digit=max_digits, min_digit=min_digits + ) else: gbnf_type, rules = gbnf_type, [] return gbnf_type, rules -def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[BaseModel]], created_rules: dict[str, list[str]]) -> tuple[list[str], bool]: +def generate_gbnf_grammar( + model: type[BaseModel], + processed_models: set[type[BaseModel]], + created_rules: dict[str, list[str]], +) -> tuple[list[str], bool]: """ Generate GBnF Grammar @@ -468,12 +573,17 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas if not issubclass(model, BaseModel): # For non-Pydantic classes, generate model_fields from __annotations__ or __init__ if hasattr(model, "__annotations__") and model.__annotations__: - model_fields = {name: (typ, ...) for name, typ in get_type_hints(model).items()} + model_fields = { + name: (typ, ...) for name, typ in get_type_hints(model).items() + } else: init_signature = inspect.signature(model.__init__) parameters = init_signature.parameters - model_fields = {name: (param.annotation, param.default) for name, param in parameters.items() if - name != "self"} + model_fields = { + name: (param.annotation, param.default) + for name, param in parameters.items() + if name != "self" + } else: # For Pydantic models, use model_fields and check for ellipsis (required fields) model_fields = get_type_hints(model) @@ -488,21 +598,36 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas if not issubclass(model, BaseModel): field_type, default_value = field_info # Check if the field is optional (not required) - is_optional = (default_value is not inspect.Parameter.empty) and (default_value is not Ellipsis) + is_optional = (default_value is not inspect.Parameter.empty) and ( + default_value is not Ellipsis + ) else: field_type = field_info field_info = model.model_fields[field_name] - is_optional = field_info.is_required is False and get_origin(field_type) is Optional + is_optional = ( + field_info.is_required is False and get_origin(field_type) is Optional + ) rule_name, additional_rules = generate_gbnf_rule_for_type( - model_name, format_model_and_field_name(field_name), field_type, is_optional, processed_models, - created_rules, field_info + model_name, + format_model_and_field_name(field_name), + field_type, + is_optional, + processed_models, + created_rules, + field_info, + ) + look_for_markdown_code_block = ( + True if rule_name == "markdown_code_block" else False + ) + look_for_triple_quoted_string = ( + True if rule_name == "triple_quoted_string" else False ) - look_for_markdown_code_block = True if rule_name == "markdown_code_block" else False - look_for_triple_quoted_string = True if rule_name == "triple_quoted_string" else False if not look_for_markdown_code_block and not look_for_triple_quoted_string: if rule_name not in created_rules: created_rules[rule_name] = additional_rules - model_rule_parts.append(f' ws "\\"{field_name}\\"" ":" ws {rule_name}') # Adding escaped quotes + model_rule_parts.append( + f' ws "\\"{field_name}\\"" ":" ws {rule_name}' + ) # Adding escaped quotes nested_rules.extend(additional_rules) else: has_triple_quoted_string = look_for_triple_quoted_string @@ -526,8 +651,10 @@ def generate_gbnf_grammar(model: type[BaseModel], processed_models: set[type[Bas def generate_gbnf_grammar_from_pydantic_models( - models: list[type[BaseModel]], outer_object_name: str | None = None, outer_object_content: str | None = None, - list_of_outputs: bool = False + models: list[type[BaseModel]], + outer_object_name: str | None = None, + outer_object_content: str | None = None, + list_of_outputs: bool = False, ) -> str: """ Generate GBNF Grammar from Pydantic Models. @@ -556,15 +683,21 @@ def generate_gbnf_grammar_from_pydantic_models( created_rules: dict[str, list[str]] = {} if outer_object_name is None: for model in models: - model_rules, _ = generate_gbnf_grammar(model, processed_models, created_rules) + model_rules, _ = generate_gbnf_grammar( + model, processed_models, created_rules + ) all_rules.extend(model_rules) if list_of_outputs: - root_rule = r'root ::= (" "| "\n") "[" ws grammar-models ("," ws grammar-models)* ws "]"' + "\n" + root_rule = ( + r'root ::= (" "| "\n") "[" ws grammar-models ("," ws grammar-models)* ws "]"' + + "\n" + ) else: root_rule = r'root ::= (" "| "\n") grammar-models' + "\n" root_rule += "grammar-models ::= " + " | ".join( - [format_model_and_field_name(model.__name__) for model in models]) + [format_model_and_field_name(model.__name__) for model in models] + ) all_rules.insert(0, root_rule) return "\n".join(all_rules) elif outer_object_name is not None: @@ -576,26 +709,32 @@ def generate_gbnf_grammar_from_pydantic_models( else: root_rule = f"root ::= {format_model_and_field_name(outer_object_name)}\n" - model_rule = ( - rf'{format_model_and_field_name(outer_object_name)} ::= (" "| "\n") "{{" ws "\"{outer_object_name}\"" ":" ws grammar-models' - ) + model_rule = rf'{format_model_and_field_name(outer_object_name)} ::= (" "| "\n") "{{" ws "\"{outer_object_name}\"" ":" ws grammar-models' fields_joined = " | ".join( - [rf"{format_model_and_field_name(model.__name__)}-grammar-model" for model in models]) + [ + rf"{format_model_and_field_name(model.__name__)}-grammar-model" + for model in models + ] + ) grammar_model_rules = f"\ngrammar-models ::= {fields_joined}" mod_rules = [] for model in models: - mod_rule = rf"{format_model_and_field_name(model.__name__)}-grammar-model ::= " + mod_rule = ( + rf"{format_model_and_field_name(model.__name__)}-grammar-model ::= " + ) mod_rule += ( - rf'"\"{model.__name__}\"" "," ws "\"{outer_object_content}\"" ":" ws {format_model_and_field_name(model.__name__)}' + "\n" + rf'"\"{model.__name__}\"" "," ws "\"{outer_object_content}\"" ":" ws {format_model_and_field_name(model.__name__)}' + + "\n" ) mod_rules.append(mod_rule) grammar_model_rules += "\n" + "\n".join(mod_rules) for model in models: - model_rules, has_special_string = generate_gbnf_grammar(model, processed_models, - created_rules) + model_rules, has_special_string = generate_gbnf_grammar( + model, processed_models, created_rules + ) if not has_special_string: model_rules[0] += r'"\n" ws "}"' @@ -670,12 +809,20 @@ def get_primitive_grammar(grammar): triple-quoted-string ::= triple-quotes triple-quoted-string-content triple-quotes triple-quoted-string-content ::= ( [^'] | "'" [^'] | "'" "'" [^'] )* triple-quotes ::= "'''" """ - return "\n" + "\n".join(additional_grammar) + any_block + primitive_grammar + markdown_code_block_grammar + return ( + "\n" + + "\n".join(additional_grammar) + + any_block + + primitive_grammar + + markdown_code_block_grammar + ) def generate_markdown_documentation( - pydantic_models: list[type[BaseModel]], model_prefix="Model", fields_prefix="Fields", - documentation_with_field_description=True + pydantic_models: list[type[BaseModel]], + model_prefix="Model", + fields_prefix="Fields", + documentation_with_field_description=True, ) -> str: """ Generate markdown documentation for a list of Pydantic models. @@ -690,7 +837,9 @@ def generate_markdown_documentation( str: Generated text documentation. """ documentation = "" - pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models] + pyd_models: list[tuple[type[BaseModel], bool]] = [ + (model, True) for model in pydantic_models + ] for model, add_prefix in pyd_models: if add_prefix: documentation += f"{model_prefix}: {model.__name__}\n" @@ -701,7 +850,9 @@ def generate_markdown_documentation( class_doc = getdoc(model) base_class_doc = getdoc(BaseModel) - class_description = class_doc if class_doc and class_doc != base_class_doc else "" + class_description = ( + class_doc if class_doc and class_doc != base_class_doc else "" + ) if class_description != "": documentation += " Description: " documentation += format_multiline_description(class_description, 0) + "\n" @@ -722,15 +873,23 @@ def generate_markdown_documentation( if get_origin(field_type) == Union: element_types = get_args(field_type) for element_type in element_types: - if isclass(element_type) and issubclass(element_type, BaseModel): + if isclass(element_type) and issubclass( + element_type, BaseModel + ): pyd_models.append((element_type, False)) documentation += generate_field_markdown( - name, field_type, model, documentation_with_field_description=documentation_with_field_description + name, + field_type, + model, + documentation_with_field_description=documentation_with_field_description, ) documentation += "\n" - if hasattr(model, "Config") and hasattr(model.Config, - "json_schema_extra") and "example" in model.Config.json_schema_extra: + if ( + hasattr(model, "Config") + and hasattr(model.Config, "json_schema_extra") + and "example" in model.Config.json_schema_extra + ): documentation += f" Expected Example Output for {format_model_and_field_name(model.__name__)}:\n" json_example = json.dumps(model.Config.json_schema_extra["example"]) documentation += format_multiline_description(json_example, 2) + "\n" @@ -739,8 +898,11 @@ def generate_markdown_documentation( def generate_field_markdown( - field_name: str, field_type: type[Any], model: type[BaseModel], depth=1, - documentation_with_field_description=True + field_name: str, + field_type: type[Any], + model: type[BaseModel], + depth=1, + documentation_with_field_description=True, ) -> str: """ Generate markdown documentation for a Pydantic model field. @@ -758,7 +920,9 @@ def generate_field_markdown( indent = " " * depth field_info = model.model_fields.get(field_name) - field_description = field_info.description if field_info and field_info.description else "" + field_description = ( + field_info.description if field_info and field_info.description else "" + ) origin_type = get_origin(field_type) origin_type = field_type if origin_type is None else origin_type @@ -781,7 +945,9 @@ def generate_field_markdown( else: field_text += "\n" else: - field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)})" + field_text = ( + f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)})" + ) if field_description != "": field_text += ":\n" else: @@ -794,11 +960,18 @@ def generate_field_markdown( field_text += f" Description: {field_description}\n" # Check for and include field-specific examples if available - if hasattr(model, "Config") and hasattr(model.Config, - "json_schema_extra") and "example" in model.Config.json_schema_extra: + if ( + hasattr(model, "Config") + and hasattr(model.Config, "json_schema_extra") + and "example" in model.Config.json_schema_extra + ): field_example = model.Config.json_schema_extra["example"].get(field_name) if field_example is not None: - example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example + example_text = ( + f"'{field_example}'" + if isinstance(field_example, str) + else field_example + ) field_text += f"{indent} Example: {example_text}\n" if isclass(origin_type) and issubclass(origin_type, BaseModel): @@ -830,8 +1003,10 @@ def format_json_example(example: dict[str, Any], depth: int) -> str: def generate_text_documentation( - pydantic_models: list[type[BaseModel]], model_prefix="Model", fields_prefix="Fields", - documentation_with_field_description=True + pydantic_models: list[type[BaseModel]], + model_prefix="Model", + fields_prefix="Fields", + documentation_with_field_description=True, ) -> str: """ Generate text documentation for a list of Pydantic models. @@ -846,7 +1021,9 @@ def generate_text_documentation( str: Generated text documentation. """ documentation = "" - pyd_models: list[tuple[type[BaseModel], bool]] = [(model, True) for model in pydantic_models] + pyd_models: list[tuple[type[BaseModel], bool]] = [ + (model, True) for model in pydantic_models + ] for model, add_prefix in pyd_models: if add_prefix: documentation += f"{model_prefix}: {model.__name__}\n" @@ -857,10 +1034,14 @@ def generate_text_documentation( class_doc = getdoc(model) base_class_doc = getdoc(BaseModel) - class_description = class_doc if class_doc and class_doc != base_class_doc else "" + class_description = ( + class_doc if class_doc and class_doc != base_class_doc else "" + ) if class_description != "": documentation += " Description: " - documentation += "\n" + format_multiline_description(class_description, 2) + "\n" + documentation += ( + "\n" + format_multiline_description(class_description, 2) + "\n" + ) if isclass(model) and issubclass(model, BaseModel): documentation_fields = "" @@ -874,10 +1055,15 @@ def generate_text_documentation( if get_origin(field_type) == Union: element_types = get_args(field_type) for element_type in element_types: - if isclass(element_type) and issubclass(element_type, BaseModel): + if isclass(element_type) and issubclass( + element_type, BaseModel + ): pyd_models.append((element_type, False)) documentation_fields += generate_field_text( - name, field_type, model, documentation_with_field_description=documentation_with_field_description + name, + field_type, + model, + documentation_with_field_description=documentation_with_field_description, ) if documentation_fields != "": if add_prefix: @@ -886,8 +1072,11 @@ def generate_text_documentation( documentation += f" Fields:\n{documentation_fields}" documentation += "\n" - if hasattr(model, "Config") and hasattr(model.Config, - "json_schema_extra") and "example" in model.Config.json_schema_extra: + if ( + hasattr(model, "Config") + and hasattr(model.Config, "json_schema_extra") + and "example" in model.Config.json_schema_extra + ): documentation += f" Expected Example Output for {format_model_and_field_name(model.__name__)}:\n" json_example = json.dumps(model.Config.json_schema_extra["example"]) documentation += format_multiline_description(json_example, 2) + "\n" @@ -896,8 +1085,11 @@ def generate_text_documentation( def generate_field_text( - field_name: str, field_type: type[Any], model: type[BaseModel], depth=1, - documentation_with_field_description=True + field_name: str, + field_type: type[Any], + model: type[BaseModel], + depth=1, + documentation_with_field_description=True, ) -> str: """ Generate text documentation for a Pydantic model field. @@ -915,7 +1107,9 @@ def generate_field_text( indent = " " * depth field_info = model.model_fields.get(field_name) - field_description = field_info.description if field_info and field_info.description else "" + field_description = ( + field_info.description if field_info and field_info.description else "" + ) if get_origin(field_type) == list: element_type = get_args(field_type)[0] @@ -935,7 +1129,9 @@ def generate_field_text( else: field_text += "\n" else: - field_text = f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)})" + field_text = ( + f"{indent}{field_name} ({format_model_and_field_name(field_type.__name__)})" + ) if field_description != "": field_text += ":\n" else: @@ -948,11 +1144,18 @@ def generate_field_text( field_text += f"{indent} Description: " + field_description + "\n" # Check for and include field-specific examples if available - if hasattr(model, "Config") and hasattr(model.Config, - "json_schema_extra") and "example" in model.Config.json_schema_extra: + if ( + hasattr(model, "Config") + and hasattr(model.Config, "json_schema_extra") + and "example" in model.Config.json_schema_extra + ): field_example = model.Config.json_schema_extra["example"].get(field_name) if field_example is not None: - example_text = f"'{field_example}'" if isinstance(field_example, str) else field_example + example_text = ( + f"'{field_example}'" + if isinstance(field_example, str) + else field_example + ) field_text += f"{indent} Example: {example_text}\n" if isclass(field_type) and issubclass(field_type, BaseModel): @@ -979,7 +1182,10 @@ def format_multiline_description(description: str, indent_level: int) -> str: def save_gbnf_grammar_and_documentation( - grammar, documentation, grammar_file_path="./grammar.gbnf", documentation_file_path="./grammar_documentation.md" + grammar, + documentation, + grammar_file_path="./grammar.gbnf", + documentation_file_path="./grammar_documentation.md", ): """ Save GBNF grammar and documentation to specified files. @@ -1053,13 +1259,18 @@ def generate_and_save_gbnf_grammar_and_documentation( None """ documentation = generate_markdown_documentation( - pydantic_model_list, model_prefix, fields_prefix, - documentation_with_field_description=documentation_with_field_description + pydantic_model_list, + model_prefix, + fields_prefix, + documentation_with_field_description=documentation_with_field_description, + ) + grammar = generate_gbnf_grammar_from_pydantic_models( + pydantic_model_list, outer_object_name, outer_object_content, list_of_outputs ) - grammar = generate_gbnf_grammar_from_pydantic_models(pydantic_model_list, outer_object_name, outer_object_content, - list_of_outputs) grammar = remove_empty_lines(grammar) - save_gbnf_grammar_and_documentation(grammar, documentation, grammar_file_path, documentation_file_path) + save_gbnf_grammar_and_documentation( + grammar, documentation, grammar_file_path, documentation_file_path + ) def generate_gbnf_grammar_and_documentation( @@ -1087,11 +1298,14 @@ def generate_gbnf_grammar_and_documentation( tuple: GBNF grammar string, documentation string. """ documentation = generate_markdown_documentation( - copy(pydantic_model_list), model_prefix, fields_prefix, - documentation_with_field_description=documentation_with_field_description + copy(pydantic_model_list), + model_prefix, + fields_prefix, + documentation_with_field_description=documentation_with_field_description, + ) + grammar = generate_gbnf_grammar_from_pydantic_models( + pydantic_model_list, outer_object_name, outer_object_content, list_of_outputs ) - grammar = generate_gbnf_grammar_from_pydantic_models(pydantic_model_list, outer_object_name, outer_object_content, - list_of_outputs) grammar = remove_empty_lines(grammar + get_primitive_grammar(grammar)) return grammar, documentation @@ -1122,11 +1336,14 @@ def generate_gbnf_grammar_and_documentation_from_dictionaries( """ pydantic_model_list = create_dynamic_models_from_dictionaries(dictionaries) documentation = generate_markdown_documentation( - copy(pydantic_model_list), model_prefix, fields_prefix, - documentation_with_field_description=documentation_with_field_description + copy(pydantic_model_list), + model_prefix, + fields_prefix, + documentation_with_field_description=documentation_with_field_description, + ) + grammar = generate_gbnf_grammar_from_pydantic_models( + pydantic_model_list, outer_object_name, outer_object_content, list_of_outputs ) - grammar = generate_gbnf_grammar_from_pydantic_models(pydantic_model_list, outer_object_name, outer_object_content, - list_of_outputs) grammar = remove_empty_lines(grammar + get_primitive_grammar(grammar)) return grammar, documentation @@ -1158,15 +1375,20 @@ def create_dynamic_model_from_function(func: Callable[..., Any]): # Assert that the parameter has a type annotation if param.annotation == inspect.Parameter.empty: - raise TypeError(f"Parameter '{param.name}' in function '{func.__name__}' lacks a type annotation") + raise TypeError( + f"Parameter '{param.name}' in function '{func.__name__}' lacks a type annotation" + ) # Find the parameter's description in the docstring - param_doc = next((d for d in docstring.params if d.arg_name == param.name), None) + param_doc = next( + (d for d in docstring.params if d.arg_name == param.name), None + ) # Assert that the parameter has a description if not param_doc or not param_doc.description: raise ValueError( - f"Parameter '{param.name}' in function '{func.__name__}' lacks a description in the docstring") + f"Parameter '{param.name}' in function '{func.__name__}' lacks a description in the docstring" + ) # Add parameter details to the schema param_docs.append((param.name, param_doc)) @@ -1175,7 +1397,9 @@ def create_dynamic_model_from_function(func: Callable[..., Any]): else: default_value = param.default dynamic_fields[param.name] = ( - param.annotation if param.annotation != inspect.Parameter.empty else str, default_value) + param.annotation if param.annotation != inspect.Parameter.empty else str, + default_value, + ) # Creating the dynamic model dynamic_model = create_model(f"{func.__name__}", **dynamic_fields) @@ -1257,7 +1481,9 @@ def list_to_enum(enum_name, values): return Enum(enum_name, {value: value for value in values}) -def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name: str = "CustomModel") -> type[Any]: +def convert_dictionary_to_pydantic_model( + dictionary: dict[str, Any], model_name: str = "CustomModel" +) -> type[Any]: """ Convert a dictionary to a Pydantic model class. @@ -1273,23 +1499,32 @@ def convert_dictionary_to_pydantic_model(dictionary: dict[str, Any], model_name: if "properties" in dictionary: for field_name, field_data in dictionary.get("properties", {}).items(): if field_data == "object": - submodel = convert_dictionary_to_pydantic_model(dictionary, f"{model_name}_{field_name}") + submodel = convert_dictionary_to_pydantic_model( + dictionary, f"{model_name}_{field_name}" + ) fields[field_name] = (submodel, ...) else: field_type = field_data.get("type", "str") if field_data.get("enum", []): - fields[field_name] = (list_to_enum(field_name, field_data.get("enum", [])), ...) + fields[field_name] = ( + list_to_enum(field_name, field_data.get("enum", [])), + ..., + ) elif field_type == "array": items = field_data.get("items", {}) if items != {}: array = {"properties": items} - array_type = convert_dictionary_to_pydantic_model(array, f"{model_name}_{field_name}_items") + array_type = convert_dictionary_to_pydantic_model( + array, f"{model_name}_{field_name}_items" + ) fields[field_name] = (List[array_type], ...) else: fields[field_name] = (list, ...) elif field_type == "object": - submodel = convert_dictionary_to_pydantic_model(field_data, f"{model_name}_{field_name}") + submodel = convert_dictionary_to_pydantic_model( + field_data, f"{model_name}_{field_name}" + ) fields[field_name] = (submodel, ...) elif field_type == "required": required = field_data.get("enum", []) diff --git a/smallthinker/examples/pydantic_models_to_grammar_examples.py b/smallthinker/examples/pydantic_models_to_grammar_examples.py index 6dadb7f3..0c1e3bc6 100755 --- a/smallthinker/examples/pydantic_models_to_grammar_examples.py +++ b/smallthinker/examples/pydantic_models_to_grammar_examples.py @@ -15,8 +15,12 @@ import requests from pydantic import BaseModel, Field -from pydantic_models_to_grammar import (add_run_method_to_dynamic_model, convert_dictionary_to_pydantic_model, - create_dynamic_model_from_function, generate_gbnf_grammar_and_documentation) +from pydantic_models_to_grammar import ( + add_run_method_to_dynamic_model, + convert_dictionary_to_pydantic_model, + create_dynamic_model_from_function, + generate_gbnf_grammar_and_documentation, +) def create_completion(host, prompt, gbnf_grammar): @@ -25,22 +29,31 @@ def create_completion(host, prompt, gbnf_grammar): See https://github.com/ggml-org/llama.cpp/tree/HEAD/tools/server#api-endpoints """ - print(f" Request:\n Grammar:\n{textwrap.indent(gbnf_grammar, ' ')}\n Prompt:\n{textwrap.indent(prompt.rstrip(), ' ')}") + print( + f" Request:\n Grammar:\n{textwrap.indent(gbnf_grammar, ' ')}\n Prompt:\n{textwrap.indent(prompt.rstrip(), ' ')}" + ) headers = {"Content-Type": "application/json"} data = {"prompt": prompt, "grammar": gbnf_grammar} - result = requests.post(f"http://{host}/completion", headers=headers, json=data).json() + result = requests.post( + f"http://{host}/completion", headers=headers, json=data + ).json() assert data.get("error") is None, data logging.info("Result: %s", result) content = result["content"] print(f" Model: {result['model']}") - print(f" Result:\n{textwrap.indent(json.dumps(json.loads(content), indent=2), ' ')}") + print( + f" Result:\n{textwrap.indent(json.dumps(json.loads(content), indent=2), ' ')}" + ) return content # A function for the agent to send a message to the user. class SendMessageToUser(BaseModel): """Send a message to the User.""" - chain_of_thought: str = Field(..., description="Your chain of thought while sending the message.") + + chain_of_thought: str = Field( + ..., description="Your chain of thought while sending the message." + ) message: str = Field(..., description="Message you want to send to the user.") def run(self): @@ -52,14 +65,21 @@ def example_rce(host): print("- example_rce") tools = [SendMessageToUser] gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation( - pydantic_model_list=tools, outer_object_name="function", - outer_object_content="function_parameters", model_prefix="Function", fields_prefix="Parameters") - system_message = "You are an advanced AI, tasked to assist the user by calling functions in JSON format. The following are the available functions and their parameters and types:\n\n" + documentation + pydantic_model_list=tools, + outer_object_name="function", + outer_object_content="function_parameters", + model_prefix="Function", + fields_prefix="Parameters", + ) + system_message = ( + "You are an advanced AI, tasked to assist the user by calling functions in JSON format. The following are the available functions and their parameters and types:\n\n" + + documentation + ) user_message = "What is 42 * 42?" prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant" text = create_completion(host, prompt, gbnf_grammar) json_data = json.loads(text) - tools_map = {tool.__name__:tool for tool in tools} + tools_map = {tool.__name__: tool for tool in tools} # This finds "SendMessageToUser": tool = tools_map.get(json_data["function"]) if not tool: @@ -82,6 +102,7 @@ class MathOperation(Enum): # system prompt. class Calculator(BaseModel): """Perform a math operation on two numbers.""" + number_one: Union[int, float] = Field(..., description="First number.") operation: MathOperation = Field(..., description="Math operation to perform.") number_two: Union[int, float] = Field(..., description="Second number.") @@ -118,9 +139,16 @@ def example_calculator(host): print("- example_calculator") tools = [SendMessageToUser, Calculator] gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation( - pydantic_model_list=tools, outer_object_name="function", - outer_object_content="function_parameters", model_prefix="Function", fields_prefix="Parameters") - system_message = "You are an advanced AI, tasked to assist the user by calling functions in JSON format. The following are the available functions and their parameters and types:\n\n" + documentation + pydantic_model_list=tools, + outer_object_name="function", + outer_object_content="function_parameters", + model_prefix="Function", + fields_prefix="Parameters", + ) + system_message = ( + "You are an advanced AI, tasked to assist the user by calling functions in JSON format. The following are the available functions and their parameters and types:\n\n" + + documentation + ) user_message1 = "What is 42 * 42?" prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{user_message1}<|im_end|>\n<|im_start|>assistant" text = create_completion(host, prompt, gbnf_grammar) @@ -130,12 +158,12 @@ def example_calculator(host): "function_parameters": { "number_one": 42, "operation": "multiply", - "number_two": 42 - } + "number_two": 42, + }, } if json_data != expected: print(" Result is not as expected!") - tools_map = {tool.__name__:tool for tool in tools} + tools_map = {tool.__name__: tool for tool in tools} # This finds "Calculator": tool = tools_map.get(json_data["function"]) if not tool: @@ -148,15 +176,19 @@ def example_calculator(host): class Category(Enum): """The category of the book.""" + Fiction = "Fiction" NonFiction = "Non-Fiction" class Book(BaseModel): """Represents an entry about a book.""" + title: str = Field(..., description="Title of the book.") author: str = Field(..., description="Author of the book.") - published_year: Optional[int] = Field(..., description="Publishing year of the book.") + published_year: Optional[int] = Field( + ..., description="Publishing year of the book." + ) keywords: list[str] = Field(..., description="A list of keywords.") category: Category = Field(..., description="Category of the book.") summary: str = Field(..., description="Summary of the book.") @@ -171,15 +203,22 @@ def example_struct(host): """ print("- example_struct") tools = [Book] - gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation(pydantic_model_list=tools) - system_message = "You are an advanced AI, tasked to create a dataset entry in JSON for a Book. The following is the expected output model:\n\n" + documentation + gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation( + pydantic_model_list=tools + ) + system_message = ( + "You are an advanced AI, tasked to create a dataset entry in JSON for a Book. The following is the expected output model:\n\n" + + documentation + ) text = """The Feynman Lectures on Physics is a physics textbook based on some lectures by Richard Feynman, a Nobel laureate who has sometimes been called "The Great Explainer". The lectures were presented before undergraduate students at the California Institute of Technology (Caltech), during 1961–1963. The book's co-authors are Feynman, Robert B. Leighton, and Matthew Sands.""" prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant" text = create_completion(host, prompt, gbnf_grammar) json_data = json.loads(text) # In this case, there's no function nor function_parameters. # Here the result will vary based on the LLM used. - keys = sorted(["title", "author", "published_year", "keywords", "category", "summary"]) + keys = sorted( + ["title", "author", "published_year", "keywords", "category", "summary"] + ) if keys != sorted(json_data.keys()): print(f"Unexpected result: {sorted(json_data.keys())}") return 1 @@ -201,11 +240,17 @@ def get_current_datetime(output_format: Optional[str] = None): def get_current_weather(location, unit): """Get the current weather in a given location""" if "London" in location: - return json.dumps({"location": "London", "temperature": "42", "unit": unit.value}) + return json.dumps( + {"location": "London", "temperature": "42", "unit": unit.value} + ) elif "New York" in location: - return json.dumps({"location": "New York", "temperature": "24", "unit": unit.value}) + return json.dumps( + {"location": "New York", "temperature": "24", "unit": unit.value} + ) elif "North Pole" in location: - return json.dumps({"location": "North Pole", "temperature": "-42", "unit": unit.value}) + return json.dumps( + {"location": "North Pole", "temperature": "-42", "unit": unit.value} + ) return json.dumps({"location": location, "temperature": "unknown"}) @@ -234,58 +279,66 @@ def example_concurrent(host): }, } # Convert OpenAI function definition into pydantic model. - current_weather_tool_model = convert_dictionary_to_pydantic_model(current_weather_tool) + current_weather_tool_model = convert_dictionary_to_pydantic_model( + current_weather_tool + ) # Add the actual function to a pydantic model. - current_weather_tool_model = add_run_method_to_dynamic_model(current_weather_tool_model, get_current_weather) + current_weather_tool_model = add_run_method_to_dynamic_model( + current_weather_tool_model, get_current_weather + ) # Convert normal Python function to a pydantic model. current_datetime_model = create_dynamic_model_from_function(get_current_datetime) - tools = [SendMessageToUser, Calculator, current_datetime_model, current_weather_tool_model] + tools = [ + SendMessageToUser, + Calculator, + current_datetime_model, + current_weather_tool_model, + ] gbnf_grammar, documentation = generate_gbnf_grammar_and_documentation( - pydantic_model_list=tools, outer_object_name="function", - outer_object_content="params", model_prefix="Function", fields_prefix="Parameters", list_of_outputs=True) - system_message = "You are an advanced AI assistant. You are interacting with the user and with your environment by calling functions. You call functions by writing JSON objects, which represent specific function calls.\nBelow is a list of your available function calls:\n\n" + documentation + pydantic_model_list=tools, + outer_object_name="function", + outer_object_content="params", + model_prefix="Function", + fields_prefix="Parameters", + list_of_outputs=True, + ) + system_message = ( + "You are an advanced AI assistant. You are interacting with the user and with your environment by calling functions. You call functions by writing JSON objects, which represent specific function calls.\nBelow is a list of your available function calls:\n\n" + + documentation + ) text = """Get the date and time, get the current weather in celsius in London and solve the following calculation: 42 * 42""" prompt = f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{text}<|im_end|>\n<|im_start|>assistant" text = create_completion(host, prompt, gbnf_grammar) json_data = json.loads(text) expected = [ - { - "function": "get_current_datetime", - "params": { - "output_format": "%Y-%m-%d %H:%M:%S" - } - }, - { - "function": "get_current_weather", - "params": { - "location": "London", - "unit": "celsius" - } - }, - { - "function": "Calculator", - "params": { - "number_one": 42, - "operation": "multiply", - "number_two": 42 - } - } + { + "function": "get_current_datetime", + "params": {"output_format": "%Y-%m-%d %H:%M:%S"}, + }, + { + "function": "get_current_weather", + "params": {"location": "London", "unit": "celsius"}, + }, + { + "function": "Calculator", + "params": {"number_one": 42, "operation": "multiply", "number_two": 42}, + }, ] res = 0 if json_data != expected: print(" Result is not as expected!") print(" This can happen on highly quantized models") res = 1 - tools_map = {tool.__name__:tool for tool in tools} + tools_map = {tool.__name__: tool for tool in tools} for call in json_data: - tool = tools_map.get(call["function"]) - if not tool: - print(f"Error: unknown tool {call['function']}") - return 1 - result = tool(**call["params"]).run() - print(f" Call {call['function']} returned {result}") + tool = tools_map.get(call["function"]) + if not tool: + print(f"Error: unknown tool {call['function']}") + return 1 + result = tool(**call["params"]).run() + print(f" Call {call['function']} returned {result}") # Should output something like this: # Call get_current_datetime returned 2024-07-15 09:50:38 # Call get_current_weather returned {"location": "London", "temperature": "42", "unit": "celsius"} diff --git a/smallthinker/examples/regex_to_grammar.py b/smallthinker/examples/regex_to_grammar.py index 5cd9210a..768e7479 100644 --- a/smallthinker/examples/regex_to_grammar.py +++ b/smallthinker/examples/regex_to_grammar.py @@ -3,18 +3,24 @@ assert len(sys.argv) >= 2 [_, pattern, *rest] = sys.argv -print(subprocess.check_output( - [ - "python", - os.path.join( - os.path.dirname(os.path.realpath(__file__)), - "json_schema_to_grammar.py"), - *rest, - "-", - "--raw-pattern", - ], - text=True, - input=json.dumps({ - "type": "string", - "pattern": pattern, - }, indent=2))) +print( + subprocess.check_output( + [ + "python", + os.path.join( + os.path.dirname(os.path.realpath(__file__)), "json_schema_to_grammar.py" + ), + *rest, + "-", + "--raw-pattern", + ], + text=True, + input=json.dumps( + { + "type": "string", + "pattern": pattern, + }, + indent=2, + ), + ) +) diff --git a/smallthinker/examples/server_embd.py b/smallthinker/examples/server_embd.py index f8b0ffec..a9dbab94 100644 --- a/smallthinker/examples/server_embd.py +++ b/smallthinker/examples/server_embd.py @@ -8,28 +8,37 @@ result = [] + async def requests_post_async(*args, **kwargs): return await asyncio.threads.to_thread(requests.post, *args, **kwargs) + async def main(): model_url = "http://127.0.0.1:6900" - responses: list[requests.Response] = await asyncio.gather(*[requests_post_async( - url= f"{model_url}/embedding", - json= {"content": "a "*1022} - ) for i in range(n)]) + responses: list[requests.Response] = await asyncio.gather( + *[ + requests_post_async( + url=f"{model_url}/embedding", json={"content": "a " * 1022} + ) + for i in range(n) + ] + ) for response in responses: embedding = response.json()["embedding"] print(embedding[-8:]) result.append(embedding) + asyncio.run(main()) # compute cosine similarity -for i in range(n-1): - for j in range(i+1, n): +for i in range(n - 1): + for j in range(i + 1, n): embedding1 = np.array(result[i]) embedding2 = np.array(result[j]) - similarity = np.dot(embedding1, embedding2) / (np.linalg.norm(embedding1) * np.linalg.norm(embedding2)) + similarity = np.dot(embedding1, embedding2) / ( + np.linalg.norm(embedding1) * np.linalg.norm(embedding2) + ) print(f"Similarity between {i} and {j}: {similarity:.2f}") diff --git a/smallthinker/get_no_moe_weights_ffn.py b/smallthinker/get_no_moe_weights_ffn.py index e08c6b7b..2e11b630 100644 --- a/smallthinker/get_no_moe_weights_ffn.py +++ b/smallthinker/get_no_moe_weights_ffn.py @@ -6,7 +6,6 @@ from pathlib import Path - try: from gguf.gguf_reader import GGUFReader, ReaderTensor from gguf.constants import GGUF_MAGIC, GGUFEndian @@ -16,9 +15,9 @@ from gguf.constants import GGUF_MAGIC, GGUFEndian - def looks_like_moe(name: str) -> bool: - return 'exps' in name + return "exps" in name + def align_offset(offset: int, alignment: int) -> int: if alignment == 0: @@ -28,75 +27,78 @@ def align_offset(offset: int, alignment: int) -> int: def main(src_path: str, dst_path: str): try: - reader = GGUFReader(src_path, 'r') + reader = GGUFReader(src_path, "r") except Exception as e: print(f"read file error: {e}", file=sys.stderr) return - tensors_to_keep: list[ReaderTensor] = [t for t in reader.tensors if not looks_like_moe(t.name)] + tensors_to_keep: list[ReaderTensor] = [ + t for t in reader.tensors if not looks_like_moe(t.name) + ] if len(tensors_to_keep) == len(reader.tensors): - print(" COPY!") + print(" COPY!") alignment = reader.alignment - endian_char = '<' if reader.endianess == GGUFEndian.LITTLE else '>' + endian_char = "<" if reader.endianess == GGUFEndian.LITTLE else ">" - with open(dst_path, 'wb') as fout: + with open(dst_path, "wb") as fout: header_format = f"{endian_char}I I Q Q" - kv_count = int(reader.fields['GGUF.kv_count'].parts[-1][0]) - gguf_version = int(reader.fields['GGUF.version'].parts[-1][0]) + kv_count = int(reader.fields["GGUF.kv_count"].parts[-1][0]) + gguf_version = int(reader.fields["GGUF.version"].parts[-1][0]) new_header = struct.pack( header_format, GGUF_MAGIC, gguf_version, len(tensors_to_keep), kv_count ) fout.write(new_header) - header_size = struct.calcsize(header_format) - first_tensor_info_offset = reader.tensors[0].field.offset if reader.tensors else reader.data_offset + first_tensor_info_offset = ( + reader.tensors[0].field.offset if reader.tensors else reader.data_offset + ) metadata_size = first_tensor_info_offset - header_size - + metadata_bytes = reader.data[header_size : header_size + metadata_size] fout.write(metadata_bytes) - + current_data_offset = 0 for tensor in tensors_to_keep: - name_bytes = tensor.name.encode('utf-8') + name_bytes = tensor.name.encode("utf-8") fout.write(struct.pack(f"{endian_char}Q", len(name_bytes))) fout.write(name_bytes) fout.write(struct.pack(f"{endian_char}I", len(tensor.shape))) - for dim in (tensor.shape): + for dim in tensor.shape: fout.write(struct.pack(f"{endian_char}Q", dim)) fout.write(struct.pack(f"{endian_char}I", tensor.tensor_type.value)) fout.write(struct.pack(f"{endian_char}Q", current_data_offset)) - - + current_data_offset += tensor.n_bytes current_data_offset = align_offset(current_data_offset, alignment) - - + tensor_info_end_offset = fout.tell() data_block_start_offset = align_offset(tensor_info_end_offset, alignment) padding_size = data_block_start_offset - tensor_info_end_offset if padding_size > 0: - fout.write(b'\x00' * padding_size) - - with open(src_path, 'rb') as fin: + fout.write(b"\x00" * padding_size) + + with open(src_path, "rb") as fin: total_data_written = 0 for tensor in tensors_to_keep: fin.seek(tensor.data_offset) tensor_bytes = fin.read(tensor.n_bytes) if len(tensor_bytes) != tensor.n_bytes: - raise IOError(f"Error: expect {tensor.n_bytes} bytes, but get {len(tensor_bytes)}") - + raise IOError( + f"Error: expect {tensor.n_bytes} bytes, but get {len(tensor_bytes)}" + ) + fout.write(tensor_bytes) total_data_written += len(tensor_bytes) - - padding_to_add = align_offset(tensor.n_bytes, alignment) - tensor.n_bytes + padding_to_add = ( + align_offset(tensor.n_bytes, alignment) - tensor.n_bytes + ) if padding_to_add > 0: - fout.write(b'\x00' * padding_to_add) - - + fout.write(b"\x00" * padding_to_add) + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -104,4 +106,4 @@ def main(src_path: str, dst_path: str): parser.add_argument("dst", help="output path") args = parser.parse_args() - main(args.src, args.dst) \ No newline at end of file + main(args.src, args.dst) diff --git a/smallthinker/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/smallthinker/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index 3428113d..375d8a23 100755 --- a/smallthinker/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/smallthinker/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,7 +3,14 @@ from glob import glob import os -TYPES_KV = ["GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_F16"] +TYPES_KV = [ + "GGML_TYPE_Q4_0", + "GGML_TYPE_Q4_1", + "GGML_TYPE_Q5_0", + "GGML_TYPE_Q5_1", + "GGML_TYPE_Q8_0", + "GGML_TYPE_F16", +] SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. @@ -18,13 +25,29 @@ """ -SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n" +SOURCE_FATTN_MMA_CASE = ( + "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n" +) TYPES_MMQ = [ - "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", - "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", - "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", - "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS" + "GGML_TYPE_Q4_0", + "GGML_TYPE_Q4_1", + "GGML_TYPE_Q5_0", + "GGML_TYPE_Q5_1", + "GGML_TYPE_Q8_0", + "GGML_TYPE_Q2_K", + "GGML_TYPE_Q3_K", + "GGML_TYPE_Q4_K", + "GGML_TYPE_Q5_K", + "GGML_TYPE_Q6_K", + "GGML_TYPE_IQ2_XXS", + "GGML_TYPE_IQ2_XS", + "GGML_TYPE_IQ2_S", + "GGML_TYPE_IQ3_XXS", + "GGML_TYPE_IQ3_S", + "GGML_TYPE_IQ1_S", + "GGML_TYPE_IQ4_NL", + "GGML_TYPE_IQ4_XS", ] SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. @@ -54,15 +77,27 @@ def get_head_sizes(type_k, type_v): for type_k in TYPES_KV: for type_v in TYPES_KV: for head_size in get_head_sizes(type_k, type_v): - with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: - f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v)) + with open( + f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", + "w", + ) as f: + f.write( + SOURCE_FATTN_VEC.format( + vkq_size=vkq_size, + head_size=head_size, + type_k=type_k, + type_v=type_v, + ) + ) for ncols in [8, 16, 32, 64]: for ncols2 in [1, 2, 4, 8, 16]: if ncols2 > ncols: continue ncols1 = ncols // ncols2 - with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f: + with open( + f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w" + ) as f: f.write(SOURCE_FATTN_MMA_START) for head_size_kq in [64, 80, 96, 112, 128, 256, 576]: @@ -71,7 +106,14 @@ def get_head_sizes(type_k, type_v): if head_size_kq == 576 and ncols2 != 16: continue head_size_v = head_size_kq if head_size_kq != 576 else 512 - f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) + f.write( + SOURCE_FATTN_MMA_CASE.format( + ncols1=ncols1, + ncols2=ncols2, + head_size_kq=head_size_kq, + head_size_v=head_size_v, + ) + ) for type in TYPES_MMQ: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: diff --git a/smallthinker/ggml/src/ggml-opencl/kernels/embed_kernel.py b/smallthinker/ggml/src/ggml-opencl/kernels/embed_kernel.py index b5d1d724..75fed333 100644 --- a/smallthinker/ggml/src/ggml-opencl/kernels/embed_kernel.py +++ b/smallthinker/ggml/src/ggml-opencl/kernels/embed_kernel.py @@ -2,6 +2,7 @@ import sys import logging + logger = logging.getLogger("opencl-embed-kernel") diff --git a/smallthinker/gguf-py/examples/reader.py b/smallthinker/gguf-py/examples/reader.py index 703b782b..6238a251 100644 --- a/smallthinker/gguf-py/examples/reader.py +++ b/smallthinker/gguf-py/examples/reader.py @@ -22,26 +22,32 @@ def read_gguf_file(gguf_file_path): reader = GGUFReader(gguf_file_path) # List all key-value pairs in a columnized format - print("Key-Value Pairs:") # noqa: NP100 + print("Key-Value Pairs:") # noqa: NP100 max_key_length = max(len(key) for key in reader.fields.keys()) for key, field in reader.fields.items(): value = field.parts[field.data[0]] - print(f"{key:{max_key_length}} : {value}") # noqa: NP100 - print("----") # noqa: NP100 + print(f"{key:{max_key_length}} : {value}") # noqa: NP100 + print("----") # noqa: NP100 # List all tensors - print("Tensors:") # noqa: NP100 + print("Tensors:") # noqa: NP100 tensor_info_format = "{:<30} | Shape: {:<15} | Size: {:<12} | Quantization: {}" - print(tensor_info_format.format("Tensor Name", "Shape", "Size", "Quantization")) # noqa: NP100 - print("-" * 80) # noqa: NP100 + print( + tensor_info_format.format("Tensor Name", "Shape", "Size", "Quantization") + ) # noqa: NP100 + print("-" * 80) # noqa: NP100 for tensor in reader.tensors: shape_str = "x".join(map(str, tensor.shape)) size_str = str(tensor.n_elements) quantization_str = tensor.tensor_type.name - print(tensor_info_format.format(tensor.name, shape_str, size_str, quantization_str)) # noqa: NP100 + print( + tensor_info_format.format( + tensor.name, shape_str, size_str, quantization_str + ) + ) # noqa: NP100 -if __name__ == '__main__': +if __name__ == "__main__": if len(sys.argv) < 2: logger.info("Usage: reader.py ") sys.exit(1) diff --git a/smallthinker/gguf-py/examples/writer.py b/smallthinker/gguf-py/examples/writer.py index 731873a7..6263fc0c 100755 --- a/smallthinker/gguf-py/examples/writer.py +++ b/smallthinker/gguf-py/examples/writer.py @@ -35,5 +35,5 @@ def writer_example() -> None: gguf_writer.close() -if __name__ == '__main__': +if __name__ == "__main__": writer_example() diff --git a/smallthinker/gguf-py/gguf/constants.py b/smallthinker/gguf-py/gguf/constants.py index 76ae9930..f7206b61 100644 --- a/smallthinker/gguf-py/gguf/constants.py +++ b/smallthinker/gguf-py/gguf/constants.py @@ -7,10 +7,10 @@ # constants # -GGUF_MAGIC = 0x46554747 # "GGUF" -GGUF_VERSION = 3 +GGUF_MAGIC = 0x46554747 # "GGUF" +GGUF_VERSION = 3 GGUF_DEFAULT_ALIGNMENT = 32 -GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h +GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h # # metadata keys @@ -19,249 +19,256 @@ class Keys: class General: - TYPE = "general.type" - ARCHITECTURE = "general.architecture" - QUANTIZATION_VERSION = "general.quantization_version" - ALIGNMENT = "general.alignment" - FILE_TYPE = "general.file_type" + TYPE = "general.type" + ARCHITECTURE = "general.architecture" + QUANTIZATION_VERSION = "general.quantization_version" + ALIGNMENT = "general.alignment" + FILE_TYPE = "general.file_type" # Authorship Metadata - NAME = "general.name" - AUTHOR = "general.author" - VERSION = "general.version" - ORGANIZATION = "general.organization" + NAME = "general.name" + AUTHOR = "general.author" + VERSION = "general.version" + ORGANIZATION = "general.organization" - FINETUNE = "general.finetune" - BASENAME = "general.basename" + FINETUNE = "general.finetune" + BASENAME = "general.basename" - DESCRIPTION = "general.description" - QUANTIZED_BY = "general.quantized_by" + DESCRIPTION = "general.description" + QUANTIZED_BY = "general.quantized_by" - SIZE_LABEL = "general.size_label" + SIZE_LABEL = "general.size_label" # Licensing details - LICENSE = "general.license" - LICENSE_NAME = "general.license.name" - LICENSE_LINK = "general.license.link" + LICENSE = "general.license" + LICENSE_NAME = "general.license.name" + LICENSE_LINK = "general.license.link" # Typically represents the converted GGUF repo (Unless native) - URL = "general.url" # Model Website/Paper - DOI = "general.doi" - UUID = "general.uuid" - REPO_URL = "general.repo_url" # Model Source Repository (git/svn/etc...) + URL = "general.url" # Model Website/Paper + DOI = "general.doi" + UUID = "general.uuid" + REPO_URL = "general.repo_url" # Model Source Repository (git/svn/etc...) # Model Source during conversion - SOURCE_URL = "general.source.url" # Model Website/Paper - SOURCE_DOI = "general.source.doi" - SOURCE_UUID = "general.source.uuid" - SOURCE_REPO_URL = "general.source.repo_url" # Model Source Repository (git/svn/etc...) + SOURCE_URL = "general.source.url" # Model Website/Paper + SOURCE_DOI = "general.source.doi" + SOURCE_UUID = "general.source.uuid" + SOURCE_REPO_URL = ( + "general.source.repo_url" # Model Source Repository (git/svn/etc...) + ) # Base Model Source. There can be more than one source if it's a merged # model like with 'Mistral-7B-Merge-14-v0.1'. This will assist in # tracing linage of models as it is finetuned or merged over time. - BASE_MODEL_COUNT = "general.base_model.count" - BASE_MODEL_NAME = "general.base_model.{id}.name" - BASE_MODEL_AUTHOR = "general.base_model.{id}.author" - BASE_MODEL_VERSION = "general.base_model.{id}.version" - BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization" - BASE_MODEL_DESCRIPTION = "general.base_model.{id}.description" - BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper - BASE_MODEL_DOI = "general.base_model.{id}.doi" - BASE_MODEL_UUID = "general.base_model.{id}.uuid" - BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...) + BASE_MODEL_COUNT = "general.base_model.count" + BASE_MODEL_NAME = "general.base_model.{id}.name" + BASE_MODEL_AUTHOR = "general.base_model.{id}.author" + BASE_MODEL_VERSION = "general.base_model.{id}.version" + BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization" + BASE_MODEL_DESCRIPTION = "general.base_model.{id}.description" + BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper + BASE_MODEL_DOI = "general.base_model.{id}.doi" + BASE_MODEL_UUID = "general.base_model.{id}.uuid" + BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...) # Dataset Source - DATASET_COUNT = "general.dataset.count" - DATASET_NAME = "general.dataset.{id}.name" - DATASET_AUTHOR = "general.dataset.{id}.author" - DATASET_VERSION = "general.dataset.{id}.version" - DATASET_ORGANIZATION = "general.dataset.{id}.organization" - DATASET_DESCRIPTION = "general.dataset.{id}.description" - DATASET_URL = "general.dataset.{id}.url" # Model Website/Paper - DATASET_DOI = "general.dataset.{id}.doi" - DATASET_UUID = "general.dataset.{id}.uuid" - DATASET_REPO_URL = "general.dataset.{id}.repo_url" # Model Source Repository (git/svn/etc...) + DATASET_COUNT = "general.dataset.count" + DATASET_NAME = "general.dataset.{id}.name" + DATASET_AUTHOR = "general.dataset.{id}.author" + DATASET_VERSION = "general.dataset.{id}.version" + DATASET_ORGANIZATION = "general.dataset.{id}.organization" + DATASET_DESCRIPTION = "general.dataset.{id}.description" + DATASET_URL = "general.dataset.{id}.url" # Model Website/Paper + DATASET_DOI = "general.dataset.{id}.doi" + DATASET_UUID = "general.dataset.{id}.uuid" + DATASET_REPO_URL = ( + "general.dataset.{id}.repo_url" # Model Source Repository (git/svn/etc...) + ) # Array based KV stores - TAGS = "general.tags" - LANGUAGES = "general.languages" + TAGS = "general.tags" + LANGUAGES = "general.languages" class LLM: - VOCAB_SIZE = "{arch}.vocab_size" - CONTEXT_LENGTH = "{arch}.context_length" - EMBEDDING_LENGTH = "{arch}.embedding_length" - FEATURES_LENGTH = "{arch}.features_length" - BLOCK_COUNT = "{arch}.block_count" - LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count" - FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" - EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length" + VOCAB_SIZE = "{arch}.vocab_size" + CONTEXT_LENGTH = "{arch}.context_length" + EMBEDDING_LENGTH = "{arch}.embedding_length" + FEATURES_LENGTH = "{arch}.features_length" + BLOCK_COUNT = "{arch}.block_count" + LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count" + FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" + EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length" EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length" - USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" - TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" - EXPERT_COUNT = "{arch}.expert_count" - EXPERT_USED_COUNT = "{arch}.expert_used_count" - EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" - EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" - EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" - EXPERT_GATING_FUNC = "{arch}.expert_gating_func" - MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers" - POOLING_TYPE = "{arch}.pooling_type" - LOGIT_SCALE = "{arch}.logit_scale" - DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" - ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" - FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" - SWIN_NORM = "{arch}.swin_norm" - RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers" - TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim" - TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim" - RESIDUAL_SCALE = "{arch}.residual_scale" - EMBEDDING_SCALE = "{arch}.embedding_scale" - TOKEN_SHIFT_COUNT = "{arch}.token_shift_count" - INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step" + USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" + TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" + EXPERT_COUNT = "{arch}.expert_count" + EXPERT_USED_COUNT = "{arch}.expert_used_count" + EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" + EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" + EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" + EXPERT_GATING_FUNC = "{arch}.expert_gating_func" + MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers" + POOLING_TYPE = "{arch}.pooling_type" + LOGIT_SCALE = "{arch}.logit_scale" + DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" + ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" + FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" + SWIN_NORM = "{arch}.swin_norm" + RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers" + TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim" + TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim" + RESIDUAL_SCALE = "{arch}.residual_scale" + EMBEDDING_SCALE = "{arch}.embedding_scale" + TOKEN_SHIFT_COUNT = "{arch}.token_shift_count" + INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step" class Attention: - HEAD_COUNT = "{arch}.attention.head_count" - HEAD_COUNT_KV = "{arch}.attention.head_count_kv" - MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias" - CLAMP_KQV = "{arch}.attention.clamp_kqv" - KEY_LENGTH = "{arch}.attention.key_length" - VALUE_LENGTH = "{arch}.attention.value_length" - LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" - LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" - GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon" - GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups" - CAUSAL = "{arch}.attention.causal" - Q_LORA_RANK = "{arch}.attention.q_lora_rank" - KV_LORA_RANK = "{arch}.attention.kv_lora_rank" - DECAY_LORA_RANK = "{arch}.attention.decay_lora_rank" - ICLR_LORA_RANK = "{arch}.attention.iclr_lora_rank" + HEAD_COUNT = "{arch}.attention.head_count" + HEAD_COUNT_KV = "{arch}.attention.head_count_kv" + MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias" + CLAMP_KQV = "{arch}.attention.clamp_kqv" + KEY_LENGTH = "{arch}.attention.key_length" + VALUE_LENGTH = "{arch}.attention.value_length" + LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" + LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" + GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon" + GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups" + CAUSAL = "{arch}.attention.causal" + Q_LORA_RANK = "{arch}.attention.q_lora_rank" + KV_LORA_RANK = "{arch}.attention.kv_lora_rank" + DECAY_LORA_RANK = "{arch}.attention.decay_lora_rank" + ICLR_LORA_RANK = "{arch}.attention.iclr_lora_rank" VALUE_RESIDUAL_MIX_LORA_RANK = "{arch}.attention.value_residual_mix_lora_rank" - GATE_LORA_RANK = "{arch}.attention.gate_lora_rank" - REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count" - SLIDING_WINDOW = "{arch}.attention.sliding_window" - SCALE = "{arch}.attention.scale" - KEY_LENGTH_MLA = "{arch}.attention.key_length_mla" - VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" + GATE_LORA_RANK = "{arch}.attention.gate_lora_rank" + REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count" + SLIDING_WINDOW = "{arch}.attention.sliding_window" + SCALE = "{arch}.attention.scale" + KEY_LENGTH_MLA = "{arch}.attention.key_length_mla" + VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla" class Rope: - DIMENSION_COUNT = "{arch}.rope.dimension_count" - DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" - FREQ_BASE = "{arch}.rope.freq_base" - SCALING_TYPE = "{arch}.rope.scaling.type" - SCALING_FACTOR = "{arch}.rope.scaling.factor" - SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" - SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" - SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" - SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" + DIMENSION_COUNT = "{arch}.rope.dimension_count" + DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" + FREQ_BASE = "{arch}.rope.freq_base" + SCALING_TYPE = "{arch}.rope.scaling.type" + SCALING_FACTOR = "{arch}.rope.scaling.factor" + SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor" + SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length" + SCALING_FINETUNED = "{arch}.rope.scaling.finetuned" + SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier" class Split: - LLM_KV_SPLIT_NO = "split.no" - LLM_KV_SPLIT_COUNT = "split.count" + LLM_KV_SPLIT_NO = "split.no" + LLM_KV_SPLIT_COUNT = "split.count" LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count" class SSM: - CONV_KERNEL = "{arch}.ssm.conv_kernel" - INNER_SIZE = "{arch}.ssm.inner_size" - STATE_SIZE = "{arch}.ssm.state_size" + CONV_KERNEL = "{arch}.ssm.conv_kernel" + INNER_SIZE = "{arch}.ssm.inner_size" + STATE_SIZE = "{arch}.ssm.state_size" TIME_STEP_RANK = "{arch}.ssm.time_step_rank" - DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" + DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" class WKV: HEAD_SIZE = "{arch}.wkv.head_size" class PosNet: EMBEDDING_LENGTH = "{arch}.posnet.embedding_length" - BLOCK_COUNT = "{arch}.posnet.block_count" + BLOCK_COUNT = "{arch}.posnet.block_count" class ConvNext: EMBEDDING_LENGTH = "{arch}.convnext.embedding_length" - BLOCK_COUNT = "{arch}.convnext.block_count" + BLOCK_COUNT = "{arch}.convnext.block_count" class Classifier: OUTPUT_LABELS = "{arch}.classifier.output_labels" class Tokenizer: - MODEL = "tokenizer.ggml.model" - PRE = "tokenizer.ggml.pre" - LIST = "tokenizer.ggml.tokens" - TOKEN_TYPE = "tokenizer.ggml.token_type" - TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count" # for BERT-style token types - SCORES = "tokenizer.ggml.scores" - MERGES = "tokenizer.ggml.merges" - BOS_ID = "tokenizer.ggml.bos_token_id" - EOS_ID = "tokenizer.ggml.eos_token_id" - EOT_ID = "tokenizer.ggml.eot_token_id" - EOM_ID = "tokenizer.ggml.eom_token_id" - UNK_ID = "tokenizer.ggml.unknown_token_id" - SEP_ID = "tokenizer.ggml.seperator_token_id" - PAD_ID = "tokenizer.ggml.padding_token_id" - MASK_ID = "tokenizer.ggml.mask_token_id" - ADD_BOS = "tokenizer.ggml.add_bos_token" - ADD_EOS = "tokenizer.ggml.add_eos_token" - ADD_PREFIX = "tokenizer.ggml.add_space_prefix" - REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces" + MODEL = "tokenizer.ggml.model" + PRE = "tokenizer.ggml.pre" + LIST = "tokenizer.ggml.tokens" + TOKEN_TYPE = "tokenizer.ggml.token_type" + TOKEN_TYPE_COUNT = ( + "tokenizer.ggml.token_type_count" # for BERT-style token types + ) + SCORES = "tokenizer.ggml.scores" + MERGES = "tokenizer.ggml.merges" + BOS_ID = "tokenizer.ggml.bos_token_id" + EOS_ID = "tokenizer.ggml.eos_token_id" + EOT_ID = "tokenizer.ggml.eot_token_id" + EOM_ID = "tokenizer.ggml.eom_token_id" + UNK_ID = "tokenizer.ggml.unknown_token_id" + SEP_ID = "tokenizer.ggml.seperator_token_id" + PAD_ID = "tokenizer.ggml.padding_token_id" + MASK_ID = "tokenizer.ggml.mask_token_id" + ADD_BOS = "tokenizer.ggml.add_bos_token" + ADD_EOS = "tokenizer.ggml.add_eos_token" + ADD_PREFIX = "tokenizer.ggml.add_space_prefix" + REMOVE_EXTRA_WS = "tokenizer.ggml.remove_extra_whitespaces" PRECOMPILED_CHARSMAP = "tokenizer.ggml.precompiled_charsmap" - HF_JSON = "tokenizer.huggingface.json" - RWKV = "tokenizer.rwkv.world" - CHAT_TEMPLATE = "tokenizer.chat_template" - CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}" - CHAT_TEMPLATES = "tokenizer.chat_templates" + HF_JSON = "tokenizer.huggingface.json" + RWKV = "tokenizer.rwkv.world" + CHAT_TEMPLATE = "tokenizer.chat_template" + CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}" + CHAT_TEMPLATES = "tokenizer.chat_templates" # FIM/Infill special tokens constants - FIM_PRE_ID = "tokenizer.ggml.fim_pre_token_id" - FIM_SUF_ID = "tokenizer.ggml.fim_suf_token_id" - FIM_MID_ID = "tokenizer.ggml.fim_mid_token_id" - FIM_PAD_ID = "tokenizer.ggml.fim_pad_token_id" - FIM_REP_ID = "tokenizer.ggml.fim_rep_token_id" - FIM_SEP_ID = "tokenizer.ggml.fim_sep_token_id" + FIM_PRE_ID = "tokenizer.ggml.fim_pre_token_id" + FIM_SUF_ID = "tokenizer.ggml.fim_suf_token_id" + FIM_MID_ID = "tokenizer.ggml.fim_mid_token_id" + FIM_PAD_ID = "tokenizer.ggml.fim_pad_token_id" + FIM_REP_ID = "tokenizer.ggml.fim_rep_token_id" + FIM_SEP_ID = "tokenizer.ggml.fim_sep_token_id" # deprecated: - PREFIX_ID = "tokenizer.ggml.prefix_token_id" - SUFFIX_ID = "tokenizer.ggml.suffix_token_id" - MIDDLE_ID = "tokenizer.ggml.middle_token_id" + PREFIX_ID = "tokenizer.ggml.prefix_token_id" + SUFFIX_ID = "tokenizer.ggml.suffix_token_id" + MIDDLE_ID = "tokenizer.ggml.middle_token_id" class Adapter: - TYPE = "adapter.type" + TYPE = "adapter.type" LORA_ALPHA = "adapter.lora.alpha" class Clip: - PROJECTOR_TYPE = "clip.projector_type" - HAS_VISION_ENCODER = "clip.has_vision_encoder" - HAS_AUDIO_ENCODER = "clip.has_audio_encoder" + PROJECTOR_TYPE = "clip.projector_type" + HAS_VISION_ENCODER = "clip.has_vision_encoder" + HAS_AUDIO_ENCODER = "clip.has_audio_encoder" HAS_LLAVA_PROJECTOR = "clip.has_llava_projector" class ClipVision: - IMAGE_SIZE = "clip.vision.image_size" - PATCH_SIZE = "clip.vision.patch_size" - EMBEDDING_LENGTH = "clip.vision.embedding_length" + IMAGE_SIZE = "clip.vision.image_size" + PATCH_SIZE = "clip.vision.patch_size" + EMBEDDING_LENGTH = "clip.vision.embedding_length" FEED_FORWARD_LENGTH = "clip.vision.feed_forward_length" - PROJECTION_DIM = "clip.vision.projection_dim" - BLOCK_COUNT = "clip.vision.block_count" - IMAGE_MEAN = "clip.vision.image_mean" - IMAGE_STD = "clip.vision.image_std" - SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size" - USE_GELU = "clip.use_gelu" - USE_SILU = "clip.use_silu" - N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl + PROJECTION_DIM = "clip.vision.projection_dim" + BLOCK_COUNT = "clip.vision.block_count" + IMAGE_MEAN = "clip.vision.image_mean" + IMAGE_STD = "clip.vision.image_std" + SPATIAL_MERGE_SIZE = "clip.vision.spatial_merge_size" + USE_GELU = "clip.use_gelu" + USE_SILU = "clip.use_silu" + N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl class Attention: - HEAD_COUNT = "clip.vision.attention.head_count" - LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon" + HEAD_COUNT = "clip.vision.attention.head_count" + LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon" class Projector: - SCALE_FACTOR = "clip.vision.projector.scale_factor" + SCALE_FACTOR = "clip.vision.projector.scale_factor" class ClipAudio: - NUM_MEL_BINS = "clip.audio.num_mel_bins" - EMBEDDING_LENGTH = "clip.audio.embedding_length" + NUM_MEL_BINS = "clip.audio.num_mel_bins" + EMBEDDING_LENGTH = "clip.audio.embedding_length" FEED_FORWARD_LENGTH = "clip.audio.feed_forward_length" - PROJECTION_DIM = "clip.audio.projection_dim" - BLOCK_COUNT = "clip.audio.block_count" + PROJECTION_DIM = "clip.audio.projection_dim" + BLOCK_COUNT = "clip.audio.block_count" class Attention: - HEAD_COUNT = "clip.audio.attention.head_count" - LAYERNORM_EPS = "clip.audio.attention.layer_norm_epsilon" + HEAD_COUNT = "clip.audio.attention.head_count" + LAYERNORM_EPS = "clip.audio.attention.layer_norm_epsilon" class Projector: - STACK_FACTOR = "clip.audio.projector.stack_factor" + STACK_FACTOR = "clip.audio.projector.stack_factor" + # # recommended mapping of model tensor names for storage in gguf @@ -269,580 +276,578 @@ class Projector: class GGUFType: - MODEL = "model" + MODEL = "model" ADAPTER = "adapter" - MMPROJ = "mmproj" # dummy, unused for now + MMPROJ = "mmproj" # dummy, unused for now class MODEL_ARCH(IntEnum): - MMPROJ = auto() # dummy arch for clip.cpp - LLAMA = auto() - LLAMA4 = auto() - DECI = auto() - FALCON = auto() - BAICHUAN = auto() - GROK = auto() - GPT2 = auto() - GPTJ = auto() - GPTNEOX = auto() - MPT = auto() - STARCODER = auto() - REFACT = auto() - BERT = auto() - NOMIC_BERT = auto() - NOMIC_BERT_MOE = auto() - JINA_BERT_V2 = auto() - BLOOM = auto() - STABLELM = auto() - QWEN = auto() - QWEN2 = auto() - QWEN2MOE = auto() - QWEN2VL = auto() - QWEN3 = auto() - QWEN3MOE = auto() - SMALLTHINKERMOE = auto() - PHI2 = auto() - PHI3 = auto() - PHIMOE = auto() - PLAMO = auto() - CODESHELL = auto() - ORION = auto() - INTERNLM2 = auto() - MINICPM = auto() - MINICPM3 = auto() - GEMMA = auto() - GEMMA2 = auto() - GEMMA3 = auto() - STARCODER2 = auto() - RWKV6 = auto() - RWKV6QWEN2 = auto() - RWKV7 = auto() - ARWKV7 = auto() - MAMBA = auto() - XVERSE = auto() - COMMAND_R = auto() - COHERE2 = auto() - DBRX = auto() - OLMO = auto() - OLMO2 = auto() - OLMOE = auto() - OPENELM = auto() - ARCTIC = auto() - DEEPSEEK = auto() - DEEPSEEK2 = auto() - CHATGLM = auto() - GLM4 = auto() - BITNET = auto() - T5 = auto() - T5ENCODER = auto() - JAIS = auto() - NEMOTRON = auto() - EXAONE = auto() - GRANITE = auto() - GRANITE_MOE = auto() - CHAMELEON = auto() + MMPROJ = auto() # dummy arch for clip.cpp + LLAMA = auto() + LLAMA4 = auto() + DECI = auto() + FALCON = auto() + BAICHUAN = auto() + GROK = auto() + GPT2 = auto() + GPTJ = auto() + GPTNEOX = auto() + MPT = auto() + STARCODER = auto() + REFACT = auto() + BERT = auto() + NOMIC_BERT = auto() + NOMIC_BERT_MOE = auto() + JINA_BERT_V2 = auto() + BLOOM = auto() + STABLELM = auto() + QWEN = auto() + QWEN2 = auto() + QWEN2MOE = auto() + QWEN2VL = auto() + QWEN3 = auto() + QWEN3MOE = auto() + SMALLTHINKERMOE = auto() + PHI2 = auto() + PHI3 = auto() + PHIMOE = auto() + PLAMO = auto() + CODESHELL = auto() + ORION = auto() + INTERNLM2 = auto() + MINICPM = auto() + MINICPM3 = auto() + GEMMA = auto() + GEMMA2 = auto() + GEMMA3 = auto() + STARCODER2 = auto() + RWKV6 = auto() + RWKV6QWEN2 = auto() + RWKV7 = auto() + ARWKV7 = auto() + MAMBA = auto() + XVERSE = auto() + COMMAND_R = auto() + COHERE2 = auto() + DBRX = auto() + OLMO = auto() + OLMO2 = auto() + OLMOE = auto() + OPENELM = auto() + ARCTIC = auto() + DEEPSEEK = auto() + DEEPSEEK2 = auto() + CHATGLM = auto() + GLM4 = auto() + BITNET = auto() + T5 = auto() + T5ENCODER = auto() + JAIS = auto() + NEMOTRON = auto() + EXAONE = auto() + GRANITE = auto() + GRANITE_MOE = auto() + CHAMELEON = auto() WAVTOKENIZER_DEC = auto() - PLM = auto() - BAILINGMOE = auto() + PLM = auto() + BAILINGMOE = auto() class VISION_PROJECTOR_TYPE(IntEnum): - MLP = auto() - LDP = auto() - LDPV2 = auto() + MLP = auto() + LDP = auto() + LDPV2 = auto() RESAMPLER = auto() - GLM_EDGE = auto() - MERGER = auto() - GEMMA3 = auto() + GLM_EDGE = auto() + MERGER = auto() + GEMMA3 = auto() class MODEL_TENSOR(IntEnum): - TOKEN_EMBD = auto() - TOKEN_EMBD_NORM = auto() - TOKEN_TYPES = auto() - POS_EMBD = auto() - OUTPUT = auto() - OUTPUT_NORM = auto() - ROPE_FREQS = auto() - ROPE_FACTORS_LONG = auto() - ROPE_FACTORS_SHORT = auto() - ATTN_Q = auto() - ATTN_K = auto() - ATTN_V = auto() - ATTN_QKV = auto() - ATTN_OUT = auto() - ATTN_NORM = auto() - ATTN_NORM_2 = auto() - ATTN_OUT_NORM = auto() - ATTN_POST_NORM = auto() - ATTN_ROT_EMBD = auto() - FFN_GATE_INP = auto() - FFN_GATE_INP_SHEXP = auto() - FFN_NORM = auto() - FFN_PRE_NORM = auto() - FFN_POST_NORM = auto() - FFN_GATE = auto() - FFN_DOWN = auto() - FFN_UP = auto() - FFN_ACT = auto() - FFN_NORM_EXP = auto() - FFN_GATE_EXP = auto() - FFN_DOWN_EXP = auto() - FFN_UP_EXP = auto() - FFN_GATE_SHEXP = auto() - FFN_DOWN_SHEXP = auto() - FFN_UP_SHEXP = auto() - FFN_EXP_PROBS_B = auto() - ATTN_Q_NORM = auto() - ATTN_K_NORM = auto() - LAYER_OUT_NORM = auto() - SSM_IN = auto() - SSM_CONV1D = auto() - SSM_X = auto() - SSM_DT = auto() - SSM_A = auto() - SSM_D = auto() - SSM_OUT = auto() - TIME_MIX_W0 = auto() - TIME_MIX_W1 = auto() - TIME_MIX_W2 = auto() - TIME_MIX_A0 = auto() - TIME_MIX_A1 = auto() - TIME_MIX_A2 = auto() - TIME_MIX_V0 = auto() - TIME_MIX_V1 = auto() - TIME_MIX_V2 = auto() - TIME_MIX_G1 = auto() - TIME_MIX_G2 = auto() - TIME_MIX_K_K = auto() - TIME_MIX_K_A = auto() - TIME_MIX_R_K = auto() - TIME_MIX_LERP_X = auto() - TIME_MIX_LERP_K = auto() - TIME_MIX_LERP_V = auto() - TIME_MIX_LERP_R = auto() - TIME_MIX_LERP_G = auto() - TIME_MIX_LERP_FUSED = auto() - TIME_MIX_LERP_W = auto() - TIME_MIX_FIRST = auto() - TIME_MIX_DECAY = auto() - TIME_MIX_DECAY_W1 = auto() - TIME_MIX_DECAY_W2 = auto() - TIME_MIX_KEY = auto() - TIME_MIX_VALUE = auto() - TIME_MIX_RECEPTANCE = auto() - TIME_MIX_GATE = auto() - TIME_MIX_LN = auto() - TIME_MIX_OUTPUT = auto() - CHANNEL_MIX_LERP_K = auto() - CHANNEL_MIX_LERP_R = auto() - CHANNEL_MIX_KEY = auto() + TOKEN_EMBD = auto() + TOKEN_EMBD_NORM = auto() + TOKEN_TYPES = auto() + POS_EMBD = auto() + OUTPUT = auto() + OUTPUT_NORM = auto() + ROPE_FREQS = auto() + ROPE_FACTORS_LONG = auto() + ROPE_FACTORS_SHORT = auto() + ATTN_Q = auto() + ATTN_K = auto() + ATTN_V = auto() + ATTN_QKV = auto() + ATTN_OUT = auto() + ATTN_NORM = auto() + ATTN_NORM_2 = auto() + ATTN_OUT_NORM = auto() + ATTN_POST_NORM = auto() + ATTN_ROT_EMBD = auto() + FFN_GATE_INP = auto() + FFN_GATE_INP_SHEXP = auto() + FFN_NORM = auto() + FFN_PRE_NORM = auto() + FFN_POST_NORM = auto() + FFN_GATE = auto() + FFN_DOWN = auto() + FFN_UP = auto() + FFN_ACT = auto() + FFN_NORM_EXP = auto() + FFN_GATE_EXP = auto() + FFN_DOWN_EXP = auto() + FFN_UP_EXP = auto() + FFN_GATE_SHEXP = auto() + FFN_DOWN_SHEXP = auto() + FFN_UP_SHEXP = auto() + FFN_EXP_PROBS_B = auto() + ATTN_Q_NORM = auto() + ATTN_K_NORM = auto() + LAYER_OUT_NORM = auto() + SSM_IN = auto() + SSM_CONV1D = auto() + SSM_X = auto() + SSM_DT = auto() + SSM_A = auto() + SSM_D = auto() + SSM_OUT = auto() + TIME_MIX_W0 = auto() + TIME_MIX_W1 = auto() + TIME_MIX_W2 = auto() + TIME_MIX_A0 = auto() + TIME_MIX_A1 = auto() + TIME_MIX_A2 = auto() + TIME_MIX_V0 = auto() + TIME_MIX_V1 = auto() + TIME_MIX_V2 = auto() + TIME_MIX_G1 = auto() + TIME_MIX_G2 = auto() + TIME_MIX_K_K = auto() + TIME_MIX_K_A = auto() + TIME_MIX_R_K = auto() + TIME_MIX_LERP_X = auto() + TIME_MIX_LERP_K = auto() + TIME_MIX_LERP_V = auto() + TIME_MIX_LERP_R = auto() + TIME_MIX_LERP_G = auto() + TIME_MIX_LERP_FUSED = auto() + TIME_MIX_LERP_W = auto() + TIME_MIX_FIRST = auto() + TIME_MIX_DECAY = auto() + TIME_MIX_DECAY_W1 = auto() + TIME_MIX_DECAY_W2 = auto() + TIME_MIX_KEY = auto() + TIME_MIX_VALUE = auto() + TIME_MIX_RECEPTANCE = auto() + TIME_MIX_GATE = auto() + TIME_MIX_LN = auto() + TIME_MIX_OUTPUT = auto() + CHANNEL_MIX_LERP_K = auto() + CHANNEL_MIX_LERP_R = auto() + CHANNEL_MIX_KEY = auto() CHANNEL_MIX_RECEPTANCE = auto() - CHANNEL_MIX_VALUE = auto() - ATTN_Q_A = auto() - ATTN_Q_B = auto() - ATTN_KV_A_MQA = auto() - ATTN_KV_B = auto() - ATTN_K_B = auto() - ATTN_V_B = auto() - ATTN_Q_A_NORM = auto() - ATTN_KV_A_NORM = auto() - FFN_SUB_NORM = auto() - ATTN_SUB_NORM = auto() - DEC_ATTN_NORM = auto() - DEC_ATTN_Q = auto() - DEC_ATTN_K = auto() - DEC_ATTN_V = auto() - DEC_ATTN_OUT = auto() - DEC_ATTN_REL_B = auto() - DEC_CROSS_ATTN_NORM = auto() - DEC_CROSS_ATTN_Q = auto() - DEC_CROSS_ATTN_K = auto() - DEC_CROSS_ATTN_V = auto() - DEC_CROSS_ATTN_OUT = auto() + CHANNEL_MIX_VALUE = auto() + ATTN_Q_A = auto() + ATTN_Q_B = auto() + ATTN_KV_A_MQA = auto() + ATTN_KV_B = auto() + ATTN_K_B = auto() + ATTN_V_B = auto() + ATTN_Q_A_NORM = auto() + ATTN_KV_A_NORM = auto() + FFN_SUB_NORM = auto() + ATTN_SUB_NORM = auto() + DEC_ATTN_NORM = auto() + DEC_ATTN_Q = auto() + DEC_ATTN_K = auto() + DEC_ATTN_V = auto() + DEC_ATTN_OUT = auto() + DEC_ATTN_REL_B = auto() + DEC_CROSS_ATTN_NORM = auto() + DEC_CROSS_ATTN_Q = auto() + DEC_CROSS_ATTN_K = auto() + DEC_CROSS_ATTN_V = auto() + DEC_CROSS_ATTN_OUT = auto() DEC_CROSS_ATTN_REL_B = auto() - DEC_FFN_NORM = auto() - DEC_FFN_GATE = auto() - DEC_FFN_DOWN = auto() - DEC_FFN_UP = auto() - DEC_OUTPUT_NORM = auto() - ENC_ATTN_NORM = auto() - ENC_ATTN_Q = auto() - ENC_ATTN_K = auto() - ENC_ATTN_V = auto() - ENC_ATTN_OUT = auto() - ENC_ATTN_REL_B = auto() - ENC_FFN_NORM = auto() - ENC_FFN_GATE = auto() - ENC_FFN_DOWN = auto() - ENC_FFN_UP = auto() - ENC_OUTPUT_NORM = auto() - CLS = auto() # classifier - CLS_OUT = auto() # classifier output projection - CONV1D = auto() - CONVNEXT_DW = auto() - CONVNEXT_NORM = auto() - CONVNEXT_PW1 = auto() - CONVNEXT_PW2 = auto() - CONVNEXT_GAMMA = auto() - POSNET_CONV1 = auto() - POSNET_CONV2 = auto() - POSNET_NORM = auto() - POSNET_NORM1 = auto() - POSNET_NORM2 = auto() - POSNET_ATTN_NORM = auto() - POSNET_ATTN_Q = auto() - POSNET_ATTN_K = auto() - POSNET_ATTN_V = auto() - POSNET_ATTN_OUT = auto() + DEC_FFN_NORM = auto() + DEC_FFN_GATE = auto() + DEC_FFN_DOWN = auto() + DEC_FFN_UP = auto() + DEC_OUTPUT_NORM = auto() + ENC_ATTN_NORM = auto() + ENC_ATTN_Q = auto() + ENC_ATTN_K = auto() + ENC_ATTN_V = auto() + ENC_ATTN_OUT = auto() + ENC_ATTN_REL_B = auto() + ENC_FFN_NORM = auto() + ENC_FFN_GATE = auto() + ENC_FFN_DOWN = auto() + ENC_FFN_UP = auto() + ENC_OUTPUT_NORM = auto() + CLS = auto() # classifier + CLS_OUT = auto() # classifier output projection + CONV1D = auto() + CONVNEXT_DW = auto() + CONVNEXT_NORM = auto() + CONVNEXT_PW1 = auto() + CONVNEXT_PW2 = auto() + CONVNEXT_GAMMA = auto() + POSNET_CONV1 = auto() + POSNET_CONV2 = auto() + POSNET_NORM = auto() + POSNET_NORM1 = auto() + POSNET_NORM2 = auto() + POSNET_ATTN_NORM = auto() + POSNET_ATTN_Q = auto() + POSNET_ATTN_K = auto() + POSNET_ATTN_V = auto() + POSNET_ATTN_OUT = auto() # vision - V_MMPROJ = auto() - V_MMPROJ_FC = auto() - V_MMPROJ_MLP = auto() - V_MMPROJ_PEG = auto() - V_ENC_EMBD_CLS = auto() - V_ENC_EMBD_PATCH = auto() - V_ENC_EMBD_POS = auto() - V_ENC_INPUT_NORM = auto() - V_ENC_ATTN_Q = auto() - V_ENC_ATTN_Q_NORM = auto() - V_ENC_ATTN_K = auto() - V_ENC_ATTN_K_NORM = auto() - V_ENC_ATTN_V = auto() - V_ENC_ATTN_O = auto() - V_ENC_ATTN_O_NORM = auto() + V_MMPROJ = auto() + V_MMPROJ_FC = auto() + V_MMPROJ_MLP = auto() + V_MMPROJ_PEG = auto() + V_ENC_EMBD_CLS = auto() + V_ENC_EMBD_PATCH = auto() + V_ENC_EMBD_POS = auto() + V_ENC_INPUT_NORM = auto() + V_ENC_ATTN_Q = auto() + V_ENC_ATTN_Q_NORM = auto() + V_ENC_ATTN_K = auto() + V_ENC_ATTN_K_NORM = auto() + V_ENC_ATTN_V = auto() + V_ENC_ATTN_O = auto() + V_ENC_ATTN_O_NORM = auto() V_ENC_POST_ATTN_NORM = auto() - V_ENC_FFN_UP = auto() - V_ENC_FFN_GATE = auto() - V_ENC_FFN_DOWN = auto() - V_LAYER_SCALE_1 = auto() - V_LAYER_SCALE_2 = auto() - V_PRE_NORM = auto() - V_POST_NORM = auto() - V_MM_INP_NORM = auto() - V_MM_INP_PROJ = auto() # gemma3 - V_MM_SOFT_EMB_NORM = auto() # gemma3 - V_RESMPL_POS_EMBD_K = auto() # minicpmv - V_RESMPL_ATTN_Q = auto() # minicpmv - V_RESMPL_ATTN_K = auto() # minicpmv - V_RESMPL_ATTN_V = auto() # minicpmv - V_RESMPL_ATTN_OUT = auto() # minicpmv - V_RESMPL_KV = auto() # minicpmv - V_RESMPL_KV_NORM = auto() # minicpmv - V_RESMPL_POST_NORM = auto() # minicpmv - V_RESMPL_Q_NORM = auto() # minicpmv - V_RESMPL_PROJ = auto() # minicpmv - V_RESMPL_QUERY = auto() # minicpmv - V_TOK_EMBD_IMG_BREAK = auto() # pixtral - V_MM_PATCH_MERGER = auto() # mistral small 3.1 + V_ENC_FFN_UP = auto() + V_ENC_FFN_GATE = auto() + V_ENC_FFN_DOWN = auto() + V_LAYER_SCALE_1 = auto() + V_LAYER_SCALE_2 = auto() + V_PRE_NORM = auto() + V_POST_NORM = auto() + V_MM_INP_NORM = auto() + V_MM_INP_PROJ = auto() # gemma3 + V_MM_SOFT_EMB_NORM = auto() # gemma3 + V_RESMPL_POS_EMBD_K = auto() # minicpmv + V_RESMPL_ATTN_Q = auto() # minicpmv + V_RESMPL_ATTN_K = auto() # minicpmv + V_RESMPL_ATTN_V = auto() # minicpmv + V_RESMPL_ATTN_OUT = auto() # minicpmv + V_RESMPL_KV = auto() # minicpmv + V_RESMPL_KV_NORM = auto() # minicpmv + V_RESMPL_POST_NORM = auto() # minicpmv + V_RESMPL_Q_NORM = auto() # minicpmv + V_RESMPL_PROJ = auto() # minicpmv + V_RESMPL_QUERY = auto() # minicpmv + V_TOK_EMBD_IMG_BREAK = auto() # pixtral + V_MM_PATCH_MERGER = auto() # mistral small 3.1 # audio (mtmd) - A_ENC_EMBD_POS = auto() - A_ENC_CONV1D = auto() - A_PRE_NORM = auto() - A_POST_NORM = auto() - A_ENC_ATTN_Q = auto() - A_ENC_ATTN_K = auto() - A_ENC_ATTN_V = auto() - A_ENC_INPUT_NORM = auto() - A_ENC_OUTPUT = auto() - A_ENC_OUTPUT_NORM = auto() - A_ENC_FFN_UP = auto() - A_ENC_FFN_GATE = auto() - A_ENC_FFN_DOWN = auto() - A_MMPROJ = auto() - A_MMPROJ_FC = auto() - A_MM_NORM_PRE = auto() - A_MM_NORM_MID = auto() + A_ENC_EMBD_POS = auto() + A_ENC_CONV1D = auto() + A_PRE_NORM = auto() + A_POST_NORM = auto() + A_ENC_ATTN_Q = auto() + A_ENC_ATTN_K = auto() + A_ENC_ATTN_V = auto() + A_ENC_INPUT_NORM = auto() + A_ENC_OUTPUT = auto() + A_ENC_OUTPUT_NORM = auto() + A_ENC_FFN_UP = auto() + A_ENC_FFN_GATE = auto() + A_ENC_FFN_DOWN = auto() + A_MMPROJ = auto() + A_MMPROJ_FC = auto() + A_MM_NORM_PRE = auto() + A_MM_NORM_MID = auto() # -- PowerInfer - LMHEAD_PROFILER = auto() + LMHEAD_PROFILER = auto() # -- PowerInfer end MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { - MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp - MODEL_ARCH.LLAMA: "llama", - MODEL_ARCH.LLAMA4: "llama4", - MODEL_ARCH.DECI: "deci", - MODEL_ARCH.FALCON: "falcon", - MODEL_ARCH.BAICHUAN: "baichuan", - MODEL_ARCH.GROK: "grok", - MODEL_ARCH.GPT2: "gpt2", - MODEL_ARCH.GPTJ: "gptj", - MODEL_ARCH.GPTNEOX: "gptneox", - MODEL_ARCH.MPT: "mpt", - MODEL_ARCH.STARCODER: "starcoder", - MODEL_ARCH.REFACT: "refact", - MODEL_ARCH.BERT: "bert", - MODEL_ARCH.NOMIC_BERT: "nomic-bert", - MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", - MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", - MODEL_ARCH.BLOOM: "bloom", - MODEL_ARCH.STABLELM: "stablelm", - MODEL_ARCH.QWEN: "qwen", - MODEL_ARCH.QWEN2: "qwen2", - MODEL_ARCH.QWEN2MOE: "qwen2moe", - MODEL_ARCH.QWEN2VL: "qwen2vl", - MODEL_ARCH.QWEN3: "qwen3", - MODEL_ARCH.QWEN3MOE: "qwen3moe", - MODEL_ARCH.SMALLTHINKERMOE: "smallthinker", - MODEL_ARCH.PHI2: "phi2", - MODEL_ARCH.PHI3: "phi3", - MODEL_ARCH.PHIMOE: "phimoe", - MODEL_ARCH.PLAMO: "plamo", - MODEL_ARCH.CODESHELL: "codeshell", - MODEL_ARCH.ORION: "orion", - MODEL_ARCH.INTERNLM2: "internlm2", - MODEL_ARCH.MINICPM: "minicpm", - MODEL_ARCH.MINICPM3: "minicpm3", - MODEL_ARCH.GEMMA: "gemma", - MODEL_ARCH.GEMMA2: "gemma2", - MODEL_ARCH.GEMMA3: "gemma3", - MODEL_ARCH.STARCODER2: "starcoder2", - MODEL_ARCH.RWKV6: "rwkv6", - MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", - MODEL_ARCH.RWKV7: "rwkv7", - MODEL_ARCH.ARWKV7: "arwkv7", - MODEL_ARCH.MAMBA: "mamba", - MODEL_ARCH.XVERSE: "xverse", - MODEL_ARCH.COMMAND_R: "command-r", - MODEL_ARCH.COHERE2: "cohere2", - MODEL_ARCH.DBRX: "dbrx", - MODEL_ARCH.OLMO: "olmo", - MODEL_ARCH.OLMO2: "olmo2", - MODEL_ARCH.OLMOE: "olmoe", - MODEL_ARCH.OPENELM: "openelm", - MODEL_ARCH.ARCTIC: "arctic", - MODEL_ARCH.DEEPSEEK: "deepseek", - MODEL_ARCH.DEEPSEEK2: "deepseek2", - MODEL_ARCH.CHATGLM: "chatglm", - MODEL_ARCH.GLM4: "glm4", - MODEL_ARCH.BITNET: "bitnet", - MODEL_ARCH.T5: "t5", - MODEL_ARCH.T5ENCODER: "t5encoder", - MODEL_ARCH.JAIS: "jais", - MODEL_ARCH.NEMOTRON: "nemotron", - MODEL_ARCH.EXAONE: "exaone", - MODEL_ARCH.GRANITE: "granite", - MODEL_ARCH.GRANITE_MOE: "granitemoe", - MODEL_ARCH.CHAMELEON: "chameleon", + MODEL_ARCH.MMPROJ: "clip", # dummy arch for clip.cpp + MODEL_ARCH.LLAMA: "llama", + MODEL_ARCH.LLAMA4: "llama4", + MODEL_ARCH.DECI: "deci", + MODEL_ARCH.FALCON: "falcon", + MODEL_ARCH.BAICHUAN: "baichuan", + MODEL_ARCH.GROK: "grok", + MODEL_ARCH.GPT2: "gpt2", + MODEL_ARCH.GPTJ: "gptj", + MODEL_ARCH.GPTNEOX: "gptneox", + MODEL_ARCH.MPT: "mpt", + MODEL_ARCH.STARCODER: "starcoder", + MODEL_ARCH.REFACT: "refact", + MODEL_ARCH.BERT: "bert", + MODEL_ARCH.NOMIC_BERT: "nomic-bert", + MODEL_ARCH.NOMIC_BERT_MOE: "nomic-bert-moe", + MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", + MODEL_ARCH.BLOOM: "bloom", + MODEL_ARCH.STABLELM: "stablelm", + MODEL_ARCH.QWEN: "qwen", + MODEL_ARCH.QWEN2: "qwen2", + MODEL_ARCH.QWEN2MOE: "qwen2moe", + MODEL_ARCH.QWEN2VL: "qwen2vl", + MODEL_ARCH.QWEN3: "qwen3", + MODEL_ARCH.QWEN3MOE: "qwen3moe", + MODEL_ARCH.SMALLTHINKERMOE: "smallthinker", + MODEL_ARCH.PHI2: "phi2", + MODEL_ARCH.PHI3: "phi3", + MODEL_ARCH.PHIMOE: "phimoe", + MODEL_ARCH.PLAMO: "plamo", + MODEL_ARCH.CODESHELL: "codeshell", + MODEL_ARCH.ORION: "orion", + MODEL_ARCH.INTERNLM2: "internlm2", + MODEL_ARCH.MINICPM: "minicpm", + MODEL_ARCH.MINICPM3: "minicpm3", + MODEL_ARCH.GEMMA: "gemma", + MODEL_ARCH.GEMMA2: "gemma2", + MODEL_ARCH.GEMMA3: "gemma3", + MODEL_ARCH.STARCODER2: "starcoder2", + MODEL_ARCH.RWKV6: "rwkv6", + MODEL_ARCH.RWKV6QWEN2: "rwkv6qwen2", + MODEL_ARCH.RWKV7: "rwkv7", + MODEL_ARCH.ARWKV7: "arwkv7", + MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.XVERSE: "xverse", + MODEL_ARCH.COMMAND_R: "command-r", + MODEL_ARCH.COHERE2: "cohere2", + MODEL_ARCH.DBRX: "dbrx", + MODEL_ARCH.OLMO: "olmo", + MODEL_ARCH.OLMO2: "olmo2", + MODEL_ARCH.OLMOE: "olmoe", + MODEL_ARCH.OPENELM: "openelm", + MODEL_ARCH.ARCTIC: "arctic", + MODEL_ARCH.DEEPSEEK: "deepseek", + MODEL_ARCH.DEEPSEEK2: "deepseek2", + MODEL_ARCH.CHATGLM: "chatglm", + MODEL_ARCH.GLM4: "glm4", + MODEL_ARCH.BITNET: "bitnet", + MODEL_ARCH.T5: "t5", + MODEL_ARCH.T5ENCODER: "t5encoder", + MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.NEMOTRON: "nemotron", + MODEL_ARCH.EXAONE: "exaone", + MODEL_ARCH.GRANITE: "granite", + MODEL_ARCH.GRANITE_MOE: "granitemoe", + MODEL_ARCH.CHAMELEON: "chameleon", MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec", - MODEL_ARCH.PLM: "plm", - MODEL_ARCH.BAILINGMOE: "bailingmoe", + MODEL_ARCH.PLM: "plm", + MODEL_ARCH.BAILINGMOE: "bailingmoe", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { - VISION_PROJECTOR_TYPE.MLP: "mlp", - VISION_PROJECTOR_TYPE.LDP: "ldp", - VISION_PROJECTOR_TYPE.LDPV2: "ldpv2", + VISION_PROJECTOR_TYPE.MLP: "mlp", + VISION_PROJECTOR_TYPE.LDP: "ldp", + VISION_PROJECTOR_TYPE.LDPV2: "ldpv2", VISION_PROJECTOR_TYPE.RESAMPLER: "resampler", - VISION_PROJECTOR_TYPE.GLM_EDGE: "adapter", - VISION_PROJECTOR_TYPE.MERGER: "qwen2vl_merger", - VISION_PROJECTOR_TYPE.GEMMA3: "gemma3", + VISION_PROJECTOR_TYPE.GLM_EDGE: "adapter", + VISION_PROJECTOR_TYPE.MERGER: "qwen2vl_merger", + VISION_PROJECTOR_TYPE.GEMMA3: "gemma3", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { - MODEL_TENSOR.TOKEN_EMBD: "token_embd", - MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm", - MODEL_TENSOR.TOKEN_TYPES: "token_types", - MODEL_TENSOR.POS_EMBD: "position_embd", - MODEL_TENSOR.OUTPUT_NORM: "output_norm", - MODEL_TENSOR.OUTPUT: "output", - MODEL_TENSOR.ROPE_FREQS: "rope_freqs", - MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", - MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", - MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", - MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", - MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", - MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", - MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", - MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", - MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", - MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", - MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", - MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", - MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", - MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm", - MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", - MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp", - MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", - MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm", - MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm", - MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", - MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", - MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", - MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp", - MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp", - MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp", - MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn", - MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps", - MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", - MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", - MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", - MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", - MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", - MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", - MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", - MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", - MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", - MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", - MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", - MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", - MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", - MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", - MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2", - MODEL_TENSOR.TIME_MIX_A0: "blk.{bid}.time_mix_a0", - MODEL_TENSOR.TIME_MIX_A1: "blk.{bid}.time_mix_a1", - MODEL_TENSOR.TIME_MIX_A2: "blk.{bid}.time_mix_a2", - MODEL_TENSOR.TIME_MIX_V0: "blk.{bid}.time_mix_v0", - MODEL_TENSOR.TIME_MIX_V1: "blk.{bid}.time_mix_v1", - MODEL_TENSOR.TIME_MIX_V2: "blk.{bid}.time_mix_v2", - MODEL_TENSOR.TIME_MIX_G1: "blk.{bid}.time_mix_g1", - MODEL_TENSOR.TIME_MIX_G2: "blk.{bid}.time_mix_g2", - MODEL_TENSOR.TIME_MIX_K_K: "blk.{bid}.time_mix_k_k", - MODEL_TENSOR.TIME_MIX_K_A: "blk.{bid}.time_mix_k_a", - MODEL_TENSOR.TIME_MIX_R_K: "blk.{bid}.time_mix_r_k", - MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x", - MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k", - MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v", - MODEL_TENSOR.TIME_MIX_LERP_R: "blk.{bid}.time_mix_lerp_r", - MODEL_TENSOR.TIME_MIX_LERP_G: "blk.{bid}.time_mix_lerp_g", - MODEL_TENSOR.TIME_MIX_LERP_FUSED: "blk.{bid}.time_mix_lerp_fused", - MODEL_TENSOR.TIME_MIX_LERP_W: "blk.{bid}.time_mix_lerp_w", - MODEL_TENSOR.TIME_MIX_FIRST: "blk.{bid}.time_mix_first", - MODEL_TENSOR.TIME_MIX_DECAY: "blk.{bid}.time_mix_decay", - MODEL_TENSOR.TIME_MIX_DECAY_W1: "blk.{bid}.time_mix_decay_w1", - MODEL_TENSOR.TIME_MIX_DECAY_W2: "blk.{bid}.time_mix_decay_w2", - MODEL_TENSOR.TIME_MIX_KEY: "blk.{bid}.time_mix_key", - MODEL_TENSOR.TIME_MIX_VALUE: "blk.{bid}.time_mix_value", - MODEL_TENSOR.TIME_MIX_RECEPTANCE: "blk.{bid}.time_mix_receptance", - MODEL_TENSOR.TIME_MIX_GATE: "blk.{bid}.time_mix_gate", - MODEL_TENSOR.TIME_MIX_LN: "blk.{bid}.time_mix_ln", - MODEL_TENSOR.TIME_MIX_OUTPUT: "blk.{bid}.time_mix_output", - MODEL_TENSOR.CHANNEL_MIX_LERP_K: "blk.{bid}.channel_mix_lerp_k", - MODEL_TENSOR.CHANNEL_MIX_LERP_R: "blk.{bid}.channel_mix_lerp_r", - MODEL_TENSOR.CHANNEL_MIX_KEY: "blk.{bid}.channel_mix_key", - MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: "blk.{bid}.channel_mix_receptance", - MODEL_TENSOR.CHANNEL_MIX_VALUE: "blk.{bid}.channel_mix_value", - MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", - MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", - MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", - MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", - MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", - MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", - MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", - MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", - MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", - MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm", - MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm", - MODEL_TENSOR.DEC_ATTN_Q: "dec.blk.{bid}.attn_q", - MODEL_TENSOR.DEC_ATTN_K: "dec.blk.{bid}.attn_k", - MODEL_TENSOR.DEC_ATTN_V: "dec.blk.{bid}.attn_v", - MODEL_TENSOR.DEC_ATTN_OUT: "dec.blk.{bid}.attn_o", - MODEL_TENSOR.DEC_ATTN_REL_B: "dec.blk.{bid}.attn_rel_b", - MODEL_TENSOR.DEC_CROSS_ATTN_NORM: "dec.blk.{bid}.cross_attn_norm", - MODEL_TENSOR.DEC_CROSS_ATTN_Q: "dec.blk.{bid}.cross_attn_q", - MODEL_TENSOR.DEC_CROSS_ATTN_K: "dec.blk.{bid}.cross_attn_k", - MODEL_TENSOR.DEC_CROSS_ATTN_V: "dec.blk.{bid}.cross_attn_v", - MODEL_TENSOR.DEC_CROSS_ATTN_OUT: "dec.blk.{bid}.cross_attn_o", - MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: "dec.blk.{bid}.cross_attn_rel_b", - MODEL_TENSOR.DEC_FFN_NORM: "dec.blk.{bid}.ffn_norm", - MODEL_TENSOR.DEC_FFN_GATE: "dec.blk.{bid}.ffn_gate", - MODEL_TENSOR.DEC_FFN_DOWN: "dec.blk.{bid}.ffn_down", - MODEL_TENSOR.DEC_FFN_UP: "dec.blk.{bid}.ffn_up", - MODEL_TENSOR.DEC_OUTPUT_NORM: "dec.output_norm", - MODEL_TENSOR.ENC_ATTN_NORM: "enc.blk.{bid}.attn_norm", - MODEL_TENSOR.ENC_ATTN_Q: "enc.blk.{bid}.attn_q", - MODEL_TENSOR.ENC_ATTN_K: "enc.blk.{bid}.attn_k", - MODEL_TENSOR.ENC_ATTN_V: "enc.blk.{bid}.attn_v", - MODEL_TENSOR.ENC_ATTN_OUT: "enc.blk.{bid}.attn_o", - MODEL_TENSOR.ENC_ATTN_REL_B: "enc.blk.{bid}.attn_rel_b", - MODEL_TENSOR.ENC_FFN_NORM: "enc.blk.{bid}.ffn_norm", - MODEL_TENSOR.ENC_FFN_GATE: "enc.blk.{bid}.ffn_gate", - MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", - MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", - MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", - MODEL_TENSOR.CLS: "cls", - MODEL_TENSOR.CLS_OUT: "cls.output", - MODEL_TENSOR.CONV1D: "conv1d", - MODEL_TENSOR.CONVNEXT_DW: "convnext.{bid}.dw", - MODEL_TENSOR.CONVNEXT_NORM: "convnext.{bid}.norm", - MODEL_TENSOR.CONVNEXT_PW1: "convnext.{bid}.pw1", - MODEL_TENSOR.CONVNEXT_PW2: "convnext.{bid}.pw2", - MODEL_TENSOR.CONVNEXT_GAMMA: "convnext.{bid}.gamma", - MODEL_TENSOR.POSNET_CONV1: "posnet.{bid}.conv1", - MODEL_TENSOR.POSNET_CONV2: "posnet.{bid}.conv2", - MODEL_TENSOR.POSNET_NORM: "posnet.{bid}.norm", - MODEL_TENSOR.POSNET_NORM1: "posnet.{bid}.norm1", - MODEL_TENSOR.POSNET_NORM2: "posnet.{bid}.norm2", - MODEL_TENSOR.POSNET_ATTN_NORM: "posnet.{bid}.attn_norm", - MODEL_TENSOR.POSNET_ATTN_Q: "posnet.{bid}.attn_q", - MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k", - MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v", - MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output", + MODEL_TENSOR.TOKEN_EMBD: "token_embd", + MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm", + MODEL_TENSOR.TOKEN_TYPES: "token_types", + MODEL_TENSOR.POS_EMBD: "position_embd", + MODEL_TENSOR.OUTPUT_NORM: "output_norm", + MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.ROPE_FREQS: "rope_freqs", + MODEL_TENSOR.ROPE_FACTORS_LONG: "rope_factors_long", + MODEL_TENSOR.ROPE_FACTORS_SHORT: "rope_factors_short", + MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", + MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", + MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", + MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", + MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", + MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", + MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", + MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", + MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", + MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", + MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm", + MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm", + MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", + MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp", + MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm", + MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", + MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", + MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", + MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp", + MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp", + MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp", + MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn", + MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps", + MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", + MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", + MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", + MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", + MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", + MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in", + MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", + MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", + MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", + MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", + MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", + MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", + MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", + MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2", + MODEL_TENSOR.TIME_MIX_A0: "blk.{bid}.time_mix_a0", + MODEL_TENSOR.TIME_MIX_A1: "blk.{bid}.time_mix_a1", + MODEL_TENSOR.TIME_MIX_A2: "blk.{bid}.time_mix_a2", + MODEL_TENSOR.TIME_MIX_V0: "blk.{bid}.time_mix_v0", + MODEL_TENSOR.TIME_MIX_V1: "blk.{bid}.time_mix_v1", + MODEL_TENSOR.TIME_MIX_V2: "blk.{bid}.time_mix_v2", + MODEL_TENSOR.TIME_MIX_G1: "blk.{bid}.time_mix_g1", + MODEL_TENSOR.TIME_MIX_G2: "blk.{bid}.time_mix_g2", + MODEL_TENSOR.TIME_MIX_K_K: "blk.{bid}.time_mix_k_k", + MODEL_TENSOR.TIME_MIX_K_A: "blk.{bid}.time_mix_k_a", + MODEL_TENSOR.TIME_MIX_R_K: "blk.{bid}.time_mix_r_k", + MODEL_TENSOR.TIME_MIX_LERP_X: "blk.{bid}.time_mix_lerp_x", + MODEL_TENSOR.TIME_MIX_LERP_K: "blk.{bid}.time_mix_lerp_k", + MODEL_TENSOR.TIME_MIX_LERP_V: "blk.{bid}.time_mix_lerp_v", + MODEL_TENSOR.TIME_MIX_LERP_R: "blk.{bid}.time_mix_lerp_r", + MODEL_TENSOR.TIME_MIX_LERP_G: "blk.{bid}.time_mix_lerp_g", + MODEL_TENSOR.TIME_MIX_LERP_FUSED: "blk.{bid}.time_mix_lerp_fused", + MODEL_TENSOR.TIME_MIX_LERP_W: "blk.{bid}.time_mix_lerp_w", + MODEL_TENSOR.TIME_MIX_FIRST: "blk.{bid}.time_mix_first", + MODEL_TENSOR.TIME_MIX_DECAY: "blk.{bid}.time_mix_decay", + MODEL_TENSOR.TIME_MIX_DECAY_W1: "blk.{bid}.time_mix_decay_w1", + MODEL_TENSOR.TIME_MIX_DECAY_W2: "blk.{bid}.time_mix_decay_w2", + MODEL_TENSOR.TIME_MIX_KEY: "blk.{bid}.time_mix_key", + MODEL_TENSOR.TIME_MIX_VALUE: "blk.{bid}.time_mix_value", + MODEL_TENSOR.TIME_MIX_RECEPTANCE: "blk.{bid}.time_mix_receptance", + MODEL_TENSOR.TIME_MIX_GATE: "blk.{bid}.time_mix_gate", + MODEL_TENSOR.TIME_MIX_LN: "blk.{bid}.time_mix_ln", + MODEL_TENSOR.TIME_MIX_OUTPUT: "blk.{bid}.time_mix_output", + MODEL_TENSOR.CHANNEL_MIX_LERP_K: "blk.{bid}.channel_mix_lerp_k", + MODEL_TENSOR.CHANNEL_MIX_LERP_R: "blk.{bid}.channel_mix_lerp_r", + MODEL_TENSOR.CHANNEL_MIX_KEY: "blk.{bid}.channel_mix_key", + MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: "blk.{bid}.channel_mix_receptance", + MODEL_TENSOR.CHANNEL_MIX_VALUE: "blk.{bid}.channel_mix_value", + MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", + MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", + MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa", + MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b", + MODEL_TENSOR.ATTN_K_B: "blk.{bid}.attn_k_b", + MODEL_TENSOR.ATTN_V_B: "blk.{bid}.attn_v_b", + MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm", + MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm", + MODEL_TENSOR.ATTN_SUB_NORM: "blk.{bid}.attn_sub_norm", + MODEL_TENSOR.FFN_SUB_NORM: "blk.{bid}.ffn_sub_norm", + MODEL_TENSOR.DEC_ATTN_NORM: "dec.blk.{bid}.attn_norm", + MODEL_TENSOR.DEC_ATTN_Q: "dec.blk.{bid}.attn_q", + MODEL_TENSOR.DEC_ATTN_K: "dec.blk.{bid}.attn_k", + MODEL_TENSOR.DEC_ATTN_V: "dec.blk.{bid}.attn_v", + MODEL_TENSOR.DEC_ATTN_OUT: "dec.blk.{bid}.attn_o", + MODEL_TENSOR.DEC_ATTN_REL_B: "dec.blk.{bid}.attn_rel_b", + MODEL_TENSOR.DEC_CROSS_ATTN_NORM: "dec.blk.{bid}.cross_attn_norm", + MODEL_TENSOR.DEC_CROSS_ATTN_Q: "dec.blk.{bid}.cross_attn_q", + MODEL_TENSOR.DEC_CROSS_ATTN_K: "dec.blk.{bid}.cross_attn_k", + MODEL_TENSOR.DEC_CROSS_ATTN_V: "dec.blk.{bid}.cross_attn_v", + MODEL_TENSOR.DEC_CROSS_ATTN_OUT: "dec.blk.{bid}.cross_attn_o", + MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: "dec.blk.{bid}.cross_attn_rel_b", + MODEL_TENSOR.DEC_FFN_NORM: "dec.blk.{bid}.ffn_norm", + MODEL_TENSOR.DEC_FFN_GATE: "dec.blk.{bid}.ffn_gate", + MODEL_TENSOR.DEC_FFN_DOWN: "dec.blk.{bid}.ffn_down", + MODEL_TENSOR.DEC_FFN_UP: "dec.blk.{bid}.ffn_up", + MODEL_TENSOR.DEC_OUTPUT_NORM: "dec.output_norm", + MODEL_TENSOR.ENC_ATTN_NORM: "enc.blk.{bid}.attn_norm", + MODEL_TENSOR.ENC_ATTN_Q: "enc.blk.{bid}.attn_q", + MODEL_TENSOR.ENC_ATTN_K: "enc.blk.{bid}.attn_k", + MODEL_TENSOR.ENC_ATTN_V: "enc.blk.{bid}.attn_v", + MODEL_TENSOR.ENC_ATTN_OUT: "enc.blk.{bid}.attn_o", + MODEL_TENSOR.ENC_ATTN_REL_B: "enc.blk.{bid}.attn_rel_b", + MODEL_TENSOR.ENC_FFN_NORM: "enc.blk.{bid}.ffn_norm", + MODEL_TENSOR.ENC_FFN_GATE: "enc.blk.{bid}.ffn_gate", + MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", + MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", + MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", + MODEL_TENSOR.CLS: "cls", + MODEL_TENSOR.CLS_OUT: "cls.output", + MODEL_TENSOR.CONV1D: "conv1d", + MODEL_TENSOR.CONVNEXT_DW: "convnext.{bid}.dw", + MODEL_TENSOR.CONVNEXT_NORM: "convnext.{bid}.norm", + MODEL_TENSOR.CONVNEXT_PW1: "convnext.{bid}.pw1", + MODEL_TENSOR.CONVNEXT_PW2: "convnext.{bid}.pw2", + MODEL_TENSOR.CONVNEXT_GAMMA: "convnext.{bid}.gamma", + MODEL_TENSOR.POSNET_CONV1: "posnet.{bid}.conv1", + MODEL_TENSOR.POSNET_CONV2: "posnet.{bid}.conv2", + MODEL_TENSOR.POSNET_NORM: "posnet.{bid}.norm", + MODEL_TENSOR.POSNET_NORM1: "posnet.{bid}.norm1", + MODEL_TENSOR.POSNET_NORM2: "posnet.{bid}.norm2", + MODEL_TENSOR.POSNET_ATTN_NORM: "posnet.{bid}.attn_norm", + MODEL_TENSOR.POSNET_ATTN_Q: "posnet.{bid}.attn_q", + MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k", + MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v", + MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output", # vision - MODEL_TENSOR.V_MMPROJ: "mm.{bid}", - MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc", - MODEL_TENSOR.V_MMPROJ_MLP: "mm.model.mlp.{bid}", - MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}", - MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd", - MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd", - MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd", - MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q", - MODEL_TENSOR.V_ENC_ATTN_Q_NORM: "v.blk.{bid}.attn_q_norm", - MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k", - MODEL_TENSOR.V_ENC_ATTN_K_NORM: "v.blk.{bid}.attn_k_norm", - MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v", - MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1", - MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out", - MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm", - MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2", - MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up", - MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate", - MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down", - MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1", - MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2", - MODEL_TENSOR.V_PRE_NORM: "v.pre_ln", - MODEL_TENSOR.V_POST_NORM: "v.post_ln", - MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection", - MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm", - MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm", - MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k", - MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q", - MODEL_TENSOR.V_RESMPL_ATTN_K: "resampler.attn.k", - MODEL_TENSOR.V_RESMPL_ATTN_V: "resampler.attn.v", - MODEL_TENSOR.V_RESMPL_ATTN_OUT: "resampler.attn.out", - MODEL_TENSOR.V_RESMPL_KV: "resampler.kv", - MODEL_TENSOR.V_RESMPL_KV_NORM: "resampler.ln_kv", - MODEL_TENSOR.V_RESMPL_POST_NORM: "resampler.ln_post", - MODEL_TENSOR.V_RESMPL_Q_NORM: "resampler.ln_q", - MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj", - MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query", - MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral - MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1 + MODEL_TENSOR.V_MMPROJ: "mm.{bid}", + MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc", + MODEL_TENSOR.V_MMPROJ_MLP: "mm.model.mlp.{bid}", + MODEL_TENSOR.V_MMPROJ_PEG: "mm.model.peg.{bid}", + MODEL_TENSOR.V_ENC_EMBD_CLS: "v.class_embd", + MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.patch_embd", + MODEL_TENSOR.V_ENC_EMBD_POS: "v.position_embd", + MODEL_TENSOR.V_ENC_ATTN_Q: "v.blk.{bid}.attn_q", + MODEL_TENSOR.V_ENC_ATTN_Q_NORM: "v.blk.{bid}.attn_q_norm", + MODEL_TENSOR.V_ENC_ATTN_K: "v.blk.{bid}.attn_k", + MODEL_TENSOR.V_ENC_ATTN_K_NORM: "v.blk.{bid}.attn_k_norm", + MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v", + MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1", + MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out", + MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm", + MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2", + MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up", + MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate", + MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down", + MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1", + MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2", + MODEL_TENSOR.V_PRE_NORM: "v.pre_ln", + MODEL_TENSOR.V_POST_NORM: "v.post_ln", + MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection", + MODEL_TENSOR.V_MM_INP_NORM: "mm.input_norm", + MODEL_TENSOR.V_MM_SOFT_EMB_NORM: "mm.soft_emb_norm", + MODEL_TENSOR.V_RESMPL_POS_EMBD_K: "resampler.pos_embd_k", + MODEL_TENSOR.V_RESMPL_ATTN_Q: "resampler.attn.q", + MODEL_TENSOR.V_RESMPL_ATTN_K: "resampler.attn.k", + MODEL_TENSOR.V_RESMPL_ATTN_V: "resampler.attn.v", + MODEL_TENSOR.V_RESMPL_ATTN_OUT: "resampler.attn.out", + MODEL_TENSOR.V_RESMPL_KV: "resampler.kv", + MODEL_TENSOR.V_RESMPL_KV_NORM: "resampler.ln_kv", + MODEL_TENSOR.V_RESMPL_POST_NORM: "resampler.ln_post", + MODEL_TENSOR.V_RESMPL_Q_NORM: "resampler.ln_q", + MODEL_TENSOR.V_RESMPL_PROJ: "resampler.proj", + MODEL_TENSOR.V_RESMPL_QUERY: "resampler.query", + MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: "v.token_embd.img_break", # pixtral + MODEL_TENSOR.V_MM_PATCH_MERGER: "mm.patch_merger", # mistral small 3.1 # audio (mtmd) - MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", - MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", - MODEL_TENSOR.A_PRE_NORM: "a.pre_ln", - MODEL_TENSOR.A_POST_NORM: "a.post_ln", - MODEL_TENSOR.A_ENC_ATTN_Q: "a.blk.{bid}.attn_q", - MODEL_TENSOR.A_ENC_ATTN_K: "a.blk.{bid}.attn_k", - MODEL_TENSOR.A_ENC_ATTN_V: "a.blk.{bid}.attn_v", - MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1", - MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out", - MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2", - MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up", - MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate", - MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down", - MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}", - MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc", - MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre", - MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid", - + MODEL_TENSOR.A_ENC_EMBD_POS: "a.position_embd", + MODEL_TENSOR.A_ENC_CONV1D: "a.conv1d.{bid}", + MODEL_TENSOR.A_PRE_NORM: "a.pre_ln", + MODEL_TENSOR.A_POST_NORM: "a.post_ln", + MODEL_TENSOR.A_ENC_ATTN_Q: "a.blk.{bid}.attn_q", + MODEL_TENSOR.A_ENC_ATTN_K: "a.blk.{bid}.attn_k", + MODEL_TENSOR.A_ENC_ATTN_V: "a.blk.{bid}.attn_v", + MODEL_TENSOR.A_ENC_INPUT_NORM: "a.blk.{bid}.ln1", + MODEL_TENSOR.A_ENC_OUTPUT: "a.blk.{bid}.attn_out", + MODEL_TENSOR.A_ENC_OUTPUT_NORM: "a.blk.{bid}.ln2", + MODEL_TENSOR.A_ENC_FFN_UP: "a.blk.{bid}.ffn_up", + MODEL_TENSOR.A_ENC_FFN_GATE: "a.blk.{bid}.ffn_gate", + MODEL_TENSOR.A_ENC_FFN_DOWN: "a.blk.{bid}.ffn_down", + MODEL_TENSOR.A_MMPROJ: "mm.a.mlp.{bid}", + MODEL_TENSOR.A_MMPROJ_FC: "mm.a.fc", + MODEL_TENSOR.A_MM_NORM_PRE: "mm.a.norm_pre", + MODEL_TENSOR.A_MM_NORM_MID: "mm.a.norm_mid", # -- PowerInfer - MODEL_TENSOR.LMHEAD_PROFILER: "lm_head_profiler", + MODEL_TENSOR.LMHEAD_PROFILER: "lm_head_profiler", # -- PowerInfer end - } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -1848,7 +1853,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_UP, MODEL_TENSOR.FFN_DOWN, ], - MODEL_ARCH.CHATGLM : [ + MODEL_ARCH.CHATGLM: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.OUTPUT_NORM, @@ -1863,7 +1868,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], - MODEL_ARCH.GLM4 : [ + MODEL_ARCH.GLM4: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.ROPE_FREQS, MODEL_TENSOR.OUTPUT_NORM, @@ -2137,66 +2142,66 @@ class MODEL_TENSOR(IntEnum): class TokenType(IntEnum): - NORMAL = 1 - UNKNOWN = 2 - CONTROL = 3 + NORMAL = 1 + UNKNOWN = 2 + CONTROL = 3 USER_DEFINED = 4 - UNUSED = 5 - BYTE = 6 + UNUSED = 5 + BYTE = 6 class RopeScalingType(Enum): - NONE = 'none' - LINEAR = 'linear' - YARN = 'yarn' - LONGROPE = 'longrope' + NONE = "none" + LINEAR = "linear" + YARN = "yarn" + LONGROPE = "longrope" class PoolingType(IntEnum): NONE = 0 MEAN = 1 - CLS = 2 + CLS = 2 LAST = 3 RANK = 4 class GGMLQuantizationType(IntEnum): - F32 = 0 - F16 = 1 - Q4_0 = 2 - Q4_1 = 3 - Q5_0 = 6 - Q5_1 = 7 - Q8_0 = 8 - Q8_1 = 9 - Q2_K = 10 - Q3_K = 11 - Q4_K = 12 - Q5_K = 13 - Q6_K = 14 - Q8_K = 15 + F32 = 0 + F16 = 1 + Q4_0 = 2 + Q4_1 = 3 + Q5_0 = 6 + Q5_1 = 7 + Q8_0 = 8 + Q8_1 = 9 + Q2_K = 10 + Q3_K = 11 + Q4_K = 12 + Q5_K = 13 + Q6_K = 14 + Q8_K = 15 IQ2_XXS = 16 - IQ2_XS = 17 + IQ2_XS = 17 IQ3_XXS = 18 - IQ1_S = 19 - IQ4_NL = 20 - IQ3_S = 21 - IQ2_S = 22 - IQ4_XS = 23 - I8 = 24 - I16 = 25 - I32 = 26 - I64 = 27 - F64 = 28 - IQ1_M = 29 - BF16 = 30 - TQ1_0 = 34 - TQ2_0 = 35 + IQ1_S = 19 + IQ4_NL = 20 + IQ3_S = 21 + IQ2_S = 22 + IQ4_XS = 23 + I8 = 24 + I16 = 25 + I32 = 26 + I64 = 27 + F64 = 28 + IQ1_M = 29 + BF16 = 30 + TQ1_0 = 34 + TQ2_0 = 35 class ExpertGatingFuncType(IntEnum): - SOFTMAX = 1 - SIGMOID = 2 + SOFTMAX = 1 + SIGMOID = 2 # TODO: add GGMLFileType from ggml_ftype in ggml.h @@ -2205,46 +2210,46 @@ class ExpertGatingFuncType(IntEnum): # from llama_ftype in llama.h # ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE. class LlamaFileType(IntEnum): - ALL_F32 = 0 - MOSTLY_F16 = 1 # except 1d tensors - MOSTLY_Q4_0 = 2 # except 1d tensors - MOSTLY_Q4_1 = 3 # except 1d tensors + ALL_F32 = 0 + MOSTLY_F16 = 1 # except 1d tensors + MOSTLY_Q4_0 = 2 # except 1d tensors + MOSTLY_Q4_1 = 3 # except 1d tensors # MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16 # MOSTLY_Q4_2 = 5 # support has been removed # MOSTLY_Q4_3 = 6 # support has been removed - MOSTLY_Q8_0 = 7 # except 1d tensors - MOSTLY_Q5_0 = 8 # except 1d tensors - MOSTLY_Q5_1 = 9 # except 1d tensors - MOSTLY_Q2_K = 10 # except 1d tensors - MOSTLY_Q3_K_S = 11 # except 1d tensors - MOSTLY_Q3_K_M = 12 # except 1d tensors - MOSTLY_Q3_K_L = 13 # except 1d tensors - MOSTLY_Q4_K_S = 14 # except 1d tensors - MOSTLY_Q4_K_M = 15 # except 1d tensors - MOSTLY_Q5_K_S = 16 # except 1d tensors - MOSTLY_Q5_K_M = 17 # except 1d tensors - MOSTLY_Q6_K = 18 # except 1d tensors - MOSTLY_IQ2_XXS = 19 # except 1d tensors - MOSTLY_IQ2_XS = 20 # except 1d tensors - MOSTLY_Q2_K_S = 21 # except 1d tensors - MOSTLY_IQ3_XS = 22 # except 1d tensors - MOSTLY_IQ3_XXS = 23 # except 1d tensors - MOSTLY_IQ1_S = 24 # except 1d tensors - MOSTLY_IQ4_NL = 25 # except 1d tensors - MOSTLY_IQ3_S = 26 # except 1d tensors - MOSTLY_IQ3_M = 27 # except 1d tensors - MOSTLY_IQ2_S = 28 # except 1d tensors - MOSTLY_IQ2_M = 29 # except 1d tensors - MOSTLY_IQ4_XS = 30 # except 1d tensors - MOSTLY_IQ1_M = 31 # except 1d tensors - MOSTLY_BF16 = 32 # except 1d tensors + MOSTLY_Q8_0 = 7 # except 1d tensors + MOSTLY_Q5_0 = 8 # except 1d tensors + MOSTLY_Q5_1 = 9 # except 1d tensors + MOSTLY_Q2_K = 10 # except 1d tensors + MOSTLY_Q3_K_S = 11 # except 1d tensors + MOSTLY_Q3_K_M = 12 # except 1d tensors + MOSTLY_Q3_K_L = 13 # except 1d tensors + MOSTLY_Q4_K_S = 14 # except 1d tensors + MOSTLY_Q4_K_M = 15 # except 1d tensors + MOSTLY_Q5_K_S = 16 # except 1d tensors + MOSTLY_Q5_K_M = 17 # except 1d tensors + MOSTLY_Q6_K = 18 # except 1d tensors + MOSTLY_IQ2_XXS = 19 # except 1d tensors + MOSTLY_IQ2_XS = 20 # except 1d tensors + MOSTLY_Q2_K_S = 21 # except 1d tensors + MOSTLY_IQ3_XS = 22 # except 1d tensors + MOSTLY_IQ3_XXS = 23 # except 1d tensors + MOSTLY_IQ1_S = 24 # except 1d tensors + MOSTLY_IQ4_NL = 25 # except 1d tensors + MOSTLY_IQ3_S = 26 # except 1d tensors + MOSTLY_IQ3_M = 27 # except 1d tensors + MOSTLY_IQ2_S = 28 # except 1d tensors + MOSTLY_IQ2_M = 29 # except 1d tensors + MOSTLY_IQ4_XS = 30 # except 1d tensors + MOSTLY_IQ1_M = 31 # except 1d tensors + MOSTLY_BF16 = 32 # except 1d tensors # MOSTLY_Q4_0_4_4 = 33 # removed from gguf files, use Q4_0 and runtime repack # MOSTLY_Q4_0_4_8 = 34 # removed from gguf files, use Q4_0 and runtime repack # MOSTLY_Q4_0_8_8 = 35 # removed from gguf files, use Q4_0 and runtime repack - MOSTLY_TQ1_0 = 36 # except 1d tensors - MOSTLY_TQ2_0 = 37 # except 1d tensors + MOSTLY_TQ1_0 = 36 # except 1d tensors + MOSTLY_TQ2_0 = 37 # except 1d tensors - GUESSED = 1024 # not specified in the model file + GUESSED = 1024 # not specified in the model file class GGUFEndian(IntEnum): @@ -2253,18 +2258,18 @@ class GGUFEndian(IntEnum): class GGUFValueType(IntEnum): - UINT8 = 0 - INT8 = 1 - UINT16 = 2 - INT16 = 3 - UINT32 = 4 - INT32 = 5 + UINT8 = 0 + INT8 = 1 + UINT16 = 2 + INT16 = 3 + UINT32 = 4 + INT32 = 5 FLOAT32 = 6 - BOOL = 7 - STRING = 8 - ARRAY = 9 - UINT64 = 10 - INT64 = 11 + BOOL = 7 + STRING = 8 + ARRAY = 9 + UINT64 = 10 + INT64 = 11 FLOAT64 = 12 @staticmethod @@ -2293,110 +2298,110 @@ class VisionProjectorType: QWEN25VL = "qwen2.5vl_merger" ULTRAVOX = "ultravox" INTERNVL = "internvl" - QWEN2A = "qwen2a" # audio - QWEN25O = "qwen2.5o" # omni + QWEN2A = "qwen2a" # audio + QWEN25O = "qwen2.5o" # omni # Items here are (block size, type size) QK_K = 256 GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { - GGMLQuantizationType.F32: (1, 4), - GGMLQuantizationType.F16: (1, 2), - GGMLQuantizationType.Q4_0: (32, 2 + 16), - GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), - GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), - GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16), - GGMLQuantizationType.Q8_0: (32, 2 + 32), - GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32), - GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4), - GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12), - GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12), - GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), - GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), - GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8), + GGMLQuantizationType.F32: (1, 4), + GGMLQuantizationType.F16: (1, 2), + GGMLQuantizationType.Q4_0: (32, 2 + 16), + GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), + GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), + GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16), + GGMLQuantizationType.Q8_0: (32, 2 + 32), + GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32), + GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4), + GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12), + GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12), + GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), + GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), + GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8), GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4), - GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32), + GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32), GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8), - GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16), - GGMLQuantizationType.IQ4_NL: (32, 2 + 16), - GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4), - GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16), - GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64), - GGMLQuantizationType.I8: (1, 1), - GGMLQuantizationType.I16: (1, 2), - GGMLQuantizationType.I32: (1, 4), - GGMLQuantizationType.I64: (1, 8), - GGMLQuantizationType.F64: (1, 8), - GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), - GGMLQuantizationType.BF16: (1, 2), - GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13), - GGMLQuantizationType.TQ2_0: (256, 2 + 64), + GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16), + GGMLQuantizationType.IQ4_NL: (32, 2 + 16), + GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4), + GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16), + GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64), + GGMLQuantizationType.I8: (1, 1), + GGMLQuantizationType.I16: (1, 2), + GGMLQuantizationType.I32: (1, 4), + GGMLQuantizationType.I64: (1, 8), + GGMLQuantizationType.F64: (1, 8), + GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), + GGMLQuantizationType.BF16: (1, 2), + GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13), + GGMLQuantizationType.TQ2_0: (256, 2 + 64), } # Aliases for backward compatibility. # general -KEY_GENERAL_ARCHITECTURE = Keys.General.ARCHITECTURE +KEY_GENERAL_ARCHITECTURE = Keys.General.ARCHITECTURE KEY_GENERAL_QUANTIZATION_VERSION = Keys.General.QUANTIZATION_VERSION -KEY_GENERAL_ALIGNMENT = Keys.General.ALIGNMENT -KEY_GENERAL_NAME = Keys.General.NAME -KEY_GENERAL_AUTHOR = Keys.General.AUTHOR -KEY_GENERAL_URL = Keys.General.URL -KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION -KEY_GENERAL_LICENSE = Keys.General.LICENSE -KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL -KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE +KEY_GENERAL_ALIGNMENT = Keys.General.ALIGNMENT +KEY_GENERAL_NAME = Keys.General.NAME +KEY_GENERAL_AUTHOR = Keys.General.AUTHOR +KEY_GENERAL_URL = Keys.General.URL +KEY_GENERAL_DESCRIPTION = Keys.General.DESCRIPTION +KEY_GENERAL_LICENSE = Keys.General.LICENSE +KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL +KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE # LLM -KEY_VOCAB_SIZE = Keys.LLM.VOCAB_SIZE -KEY_CONTEXT_LENGTH = Keys.LLM.CONTEXT_LENGTH -KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH -KEY_BLOCK_COUNT = Keys.LLM.BLOCK_COUNT -KEY_FEED_FORWARD_LENGTH = Keys.LLM.FEED_FORWARD_LENGTH +KEY_VOCAB_SIZE = Keys.LLM.VOCAB_SIZE +KEY_CONTEXT_LENGTH = Keys.LLM.CONTEXT_LENGTH +KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH +KEY_BLOCK_COUNT = Keys.LLM.BLOCK_COUNT +KEY_FEED_FORWARD_LENGTH = Keys.LLM.FEED_FORWARD_LENGTH KEY_USE_PARALLEL_RESIDUAL = Keys.LLM.USE_PARALLEL_RESIDUAL -KEY_TENSOR_DATA_LAYOUT = Keys.LLM.TENSOR_DATA_LAYOUT +KEY_TENSOR_DATA_LAYOUT = Keys.LLM.TENSOR_DATA_LAYOUT # attention -KEY_ATTENTION_HEAD_COUNT = Keys.Attention.HEAD_COUNT -KEY_ATTENTION_HEAD_COUNT_KV = Keys.Attention.HEAD_COUNT_KV -KEY_ATTENTION_MAX_ALIBI_BIAS = Keys.Attention.MAX_ALIBI_BIAS -KEY_ATTENTION_CLAMP_KQV = Keys.Attention.CLAMP_KQV -KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS +KEY_ATTENTION_HEAD_COUNT = Keys.Attention.HEAD_COUNT +KEY_ATTENTION_HEAD_COUNT_KV = Keys.Attention.HEAD_COUNT_KV +KEY_ATTENTION_MAX_ALIBI_BIAS = Keys.Attention.MAX_ALIBI_BIAS +KEY_ATTENTION_CLAMP_KQV = Keys.Attention.CLAMP_KQV +KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS # RoPE -KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT -KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE -KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE -KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR +KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT +KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE +KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE +KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN -KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED +KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED # SSM -KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL -KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE -KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE +KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL +KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE +KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK -KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS +KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS # tokenization -KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL -KEY_TOKENIZER_PRE = Keys.Tokenizer.PRE -KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST +KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL +KEY_TOKENIZER_PRE = Keys.Tokenizer.PRE +KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST KEY_TOKENIZER_TOKEN_TYPE = Keys.Tokenizer.TOKEN_TYPE -KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES -KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES -KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID -KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID -KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID -KEY_TOKENIZER_EOM_ID = Keys.Tokenizer.EOM_ID -KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID -KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID -KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID -KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID -KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON -KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV +KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES +KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES +KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID +KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID +KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID +KEY_TOKENIZER_EOM_ID = Keys.Tokenizer.EOM_ID +KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID +KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID +KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID +KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID +KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON +KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV KEY_TOKENIZER_FIM_PRE_ID = Keys.Tokenizer.FIM_PRE_ID KEY_TOKENIZER_FIM_SUF_ID = Keys.Tokenizer.FIM_SUF_ID @@ -2406,6 +2411,6 @@ class VisionProjectorType: KEY_TOKENIZER_FIM_SEP_ID = Keys.Tokenizer.FIM_SEP_ID # deprecated -KEY_TOKENIZER_PREFIX_ID = Keys.Tokenizer.PREFIX_ID -KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID -KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID \ No newline at end of file +KEY_TOKENIZER_PREFIX_ID = Keys.Tokenizer.PREFIX_ID +KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID +KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID diff --git a/smallthinker/gguf-py/gguf/gguf_reader.py b/smallthinker/gguf-py/gguf/gguf_reader.py index d87e8f72..0d24fdfa 100644 --- a/smallthinker/gguf-py/gguf/gguf_reader.py +++ b/smallthinker/gguf-py/gguf/gguf_reader.py @@ -56,7 +56,7 @@ class ReaderField(NamedTuple): def contents(self, index_or_slice: int | slice = slice(None)) -> Any: if self.types: - to_string = lambda x: str(x.tobytes(), encoding='utf-8') # noqa: E731 + to_string = lambda x: str(x.tobytes(), encoding="utf-8") # noqa: E731 main_type = self.types[0] if main_type == GGUFValueType.ARRAY: @@ -66,9 +66,9 @@ def contents(self, index_or_slice: int | slice = slice(None)) -> Any: indices = self.data[index_or_slice] if isinstance(index_or_slice, int): - return to_string(self.parts[indices]) # type: ignore + return to_string(self.parts[indices]) # type: ignore else: - return [to_string(self.parts[idx]) for idx in indices] # type: ignore + return [to_string(self.parts[idx]) for idx in indices] # type: ignore else: # FIXME: When/if _get_field_parts() support multi-dimensional arrays, this must do so too @@ -87,7 +87,11 @@ def contents(self, index_or_slice: int | slice = slice(None)) -> Any: if isinstance(index_or_slice, int): return self.parts[self.data[index_or_slice]].tolist()[0] else: - return [pv for idx in self.data[index_or_slice] for pv in self.parts[idx].tolist()] + return [ + pv + for idx in self.data[index_or_slice] + for pv in self.parts[idx].tolist() + ] if main_type == GGUFValueType.STRING: return to_string(self.parts[-1]) @@ -110,32 +114,34 @@ class ReaderTensor(NamedTuple): class GGUFReader: # I - same as host, S - swapped - byte_order: Literal['I', 'S'] = 'I' + byte_order: Literal["I", "S"] = "I" alignment: int = GGUF_DEFAULT_ALIGNMENT data_offset: int # Note: Internal helper, API may change. gguf_scalar_to_np: dict[GGUFValueType, type[np.generic]] = { - GGUFValueType.UINT8: np.uint8, - GGUFValueType.INT8: np.int8, - GGUFValueType.UINT16: np.uint16, - GGUFValueType.INT16: np.int16, - GGUFValueType.UINT32: np.uint32, - GGUFValueType.INT32: np.int32, + GGUFValueType.UINT8: np.uint8, + GGUFValueType.INT8: np.int8, + GGUFValueType.UINT16: np.uint16, + GGUFValueType.INT16: np.int16, + GGUFValueType.UINT32: np.uint32, + GGUFValueType.INT32: np.int32, GGUFValueType.FLOAT32: np.float32, - GGUFValueType.UINT64: np.uint64, - GGUFValueType.INT64: np.int64, + GGUFValueType.UINT64: np.uint64, + GGUFValueType.INT64: np.int64, GGUFValueType.FLOAT64: np.float64, - GGUFValueType.BOOL: np.bool_, + GGUFValueType.BOOL: np.bool_, } - def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = 'r'): - self.data = np.memmap(path, mode = mode) + def __init__( + self, path: os.PathLike[str] | str, mode: Literal["r", "r+", "c"] = "r" + ): + self.data = np.memmap(path, mode=mode) offs = 0 # Check for GGUF magic - if self._get(offs, np.uint32, override_order = '<')[0] != GGUF_MAGIC: - raise ValueError('GGUF magic invalid') + if self._get(offs, np.uint32, override_order="<")[0] != GGUF_MAGIC: + raise ValueError("GGUF magic invalid") offs += 4 # Check GGUF version @@ -143,11 +149,15 @@ def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = if temp_version[0] & 65535 == 0: # If we get 0 here that means it's (probably) a GGUF file created for # the opposite byte order of the machine this script is running on. - self.byte_order = 'S' - temp_version = temp_version.view(temp_version.dtype.newbyteorder(self.byte_order)) + self.byte_order = "S" + temp_version = temp_version.view( + temp_version.dtype.newbyteorder(self.byte_order) + ) version = temp_version[0] if version not in READER_SUPPORTED_VERSIONS: - raise ValueError(f'Sorry, file appears to be version {version} which we cannot handle') + raise ValueError( + f"Sorry, file appears to be version {version} which we cannot handle" + ) if sys.byteorder == "little": # Host is little endian host_endian = GGUFEndian.LITTLE @@ -159,21 +169,37 @@ def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = self.endianess = swapped_endian if self.byte_order == "S" else host_endian self.fields: OrderedDict[str, ReaderField] = OrderedDict() self.tensors: list[ReaderTensor] = [] - offs += self._push_field(ReaderField(offs, 'GGUF.version', [temp_version], [0], [GGUFValueType.UINT32])) + offs += self._push_field( + ReaderField( + offs, "GGUF.version", [temp_version], [0], [GGUFValueType.UINT32] + ) + ) # Check tensor count and kv count temp_counts = self._get(offs, np.uint64, 2) - offs += self._push_field(ReaderField(offs, 'GGUF.tensor_count', [temp_counts[:1]], [0], [GGUFValueType.UINT64])) - offs += self._push_field(ReaderField(offs, 'GGUF.kv_count', [temp_counts[1:]], [0], [GGUFValueType.UINT64])) + offs += self._push_field( + ReaderField( + offs, + "GGUF.tensor_count", + [temp_counts[:1]], + [0], + [GGUFValueType.UINT64], + ) + ) + offs += self._push_field( + ReaderField( + offs, "GGUF.kv_count", [temp_counts[1:]], [0], [GGUFValueType.UINT64] + ) + ) tensor_count, kv_count = temp_counts offs = self._build_fields(offs, kv_count) # Build Tensor Info Fields offs, tensors_fields = self._build_tensor_info(offs, tensor_count) - new_align = self.fields.get('general.alignment') + new_align = self.fields.get("general.alignment") if new_align is not None: if new_align.types != [GGUFValueType.UINT32]: - raise ValueError('Bad type for general.alignment field') + raise ValueError("Bad type for general.alignment field") self.alignment = new_align.parts[-1][0] padding = offs % self.alignment if padding != 0: @@ -181,7 +207,7 @@ def __init__(self, path: os.PathLike[str] | str, mode: Literal['r', 'r+', 'c'] = self.data_offset = offs self._build_tensors(offs, tensors_fields) - _DT = TypeVar('_DT', bound = npt.DTypeLike) + _DT = TypeVar("_DT", bound=npt.DTypeLike) # Fetch a key/value metadata field by key. def get_field(self, key: str) -> Union[ReaderField, None]: @@ -192,31 +218,43 @@ def get_tensor(self, idx: int) -> ReaderTensor: return self.tensors[idx] def _get( - self, offset: int, dtype: npt.DTypeLike, count: int = 1, override_order: None | Literal['I', 'S', '<'] = None, + self, + offset: int, + dtype: npt.DTypeLike, + count: int = 1, + override_order: None | Literal["I", "S", "<"] = None, ) -> npt.NDArray[Any]: count = int(count) - itemsize = int(np.empty([], dtype = dtype).itemsize) + itemsize = int(np.empty([], dtype=dtype).itemsize) end_offs = offset + itemsize * count arr = self.data[offset:end_offs].view(dtype=dtype)[:count] - return arr.view(arr.dtype.newbyteorder(self.byte_order if override_order is None else override_order)) + return arr.view( + arr.dtype.newbyteorder( + self.byte_order if override_order is None else override_order + ) + ) def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int: if field.name in self.fields: # TODO: add option to generate error on duplicate keys # raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}') - logger.warning(f'Duplicate key {field.name} at offset {field.offset}') - self.fields[field.name + '_{}'.format(field.offset)] = field + logger.warning(f"Duplicate key {field.name} at offset {field.offset}") + self.fields[field.name + "_{}".format(field.offset)] = field else: self.fields[field.name] = field return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts) - def _get_str(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]: + def _get_str( + self, offset: int + ) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]: slen = self._get(offset, np.uint64) return slen, self._get(offset + 8, np.uint8, slen[0]) def _get_field_parts( - self, orig_offs: int, raw_type: int, + self, + orig_offs: int, + raw_type: int, ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]: offs = orig_offs types: list[GGUFValueType] = [] @@ -242,7 +280,9 @@ def _get_field_parts( data_idxs: list[int] = [] # FIXME: Handle multi-dimensional arrays properly instead of flattening for idx in range(alen[0]): - curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts(offs, raw_itype[0]) + curr_size, curr_parts, curr_idxs, curr_types = self._get_field_parts( + offs, raw_itype[0] + ) if idx == 0: types += curr_types idxs_offs = len(aparts) @@ -251,7 +291,7 @@ def _get_field_parts( offs += curr_size return offs - orig_offs, aparts, data_idxs, types # We can't deal with this one. - raise ValueError(f'Unknown/unhandled field type {gtype}') + raise ValueError(f"Unknown/unhandled field type {gtype}") def _get_tensor_info_field(self, orig_offs: int) -> ReaderField: offs = orig_offs @@ -278,7 +318,7 @@ def _get_tensor_info_field(self, orig_offs: int) -> ReaderField: return ReaderField( orig_offs, - str(bytes(name_data), encoding = 'utf-8'), + str(bytes(name_data), encoding="utf-8"), [name_len, name_data, n_dims, dims, raw_dtype, offset_tensor], [1, 3, 4, 5], ) @@ -292,19 +332,26 @@ def _build_fields(self, offs: int, count: int) -> int: offs += int(raw_kv_type.nbytes) parts: list[npt.NDArray[Any]] = [kv_klen, kv_kdata, raw_kv_type] idxs_offs = len(parts) - field_size, field_parts, field_idxs, field_types = self._get_field_parts(offs, raw_kv_type[0]) + field_size, field_parts, field_idxs, field_types = self._get_field_parts( + offs, raw_kv_type[0] + ) parts += field_parts - self._push_field(ReaderField( - orig_offs, - str(bytes(kv_kdata), encoding = 'utf-8'), - parts, - [idx + idxs_offs for idx in field_idxs], - field_types, - ), skip_sum = True) + self._push_field( + ReaderField( + orig_offs, + str(bytes(kv_kdata), encoding="utf-8"), + parts, + [idx + idxs_offs for idx in field_idxs], + field_types, + ), + skip_sum=True, + ) offs += field_size return offs - def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderField]]: + def _build_tensor_info( + self, offs: int, count: int + ) -> tuple[int, list[ReaderField]]: tensor_fields = [] for _ in range(count): field = self._get_tensor_info_field(offs) @@ -314,13 +361,13 @@ def _build_tensor_info(self, offs: int, count: int) -> tuple[int, list[ReaderFie def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None: tensors = [] - tensor_names = set() # keep track of name to prevent duplicated tensors + tensor_names = set() # keep track of name to prevent duplicated tensors for field in fields: _name_len, name_data, _n_dims, dims, raw_dtype, offset_tensor = field.parts # check if there's any tensor having same name already in the list - tensor_name = str(bytes(name_data), encoding = 'utf-8') + tensor_name = str(bytes(name_data), encoding="utf-8") if tensor_name in tensor_names: - raise ValueError(f'Found duplicated tensor with name {tensor_name}') + raise ValueError(f"Found duplicated tensor with name {tensor_name}") tensor_names.add(tensor_name) ggml_type = GGMLQuantizationType(raw_dtype[0]) n_elems = int(np.prod(dims)) @@ -354,14 +401,16 @@ def _build_tensors(self, start_offs: int, fields: list[ReaderField]) -> None: item_count = n_bytes item_type = np.uint8 np_dims = quant_shape_to_byte_shape(np_dims, ggml_type) - tensors.append(ReaderTensor( - name = tensor_name, - tensor_type = ggml_type, - shape = dims, - n_elements = n_elems, - n_bytes = n_bytes, - data_offset = data_offs, - data = self._get(data_offs, item_type, item_count).reshape(np_dims), - field = field, - )) + tensors.append( + ReaderTensor( + name=tensor_name, + tensor_type=ggml_type, + shape=dims, + n_elements=n_elems, + n_bytes=n_bytes, + data_offset=data_offs, + data=self._get(data_offs, item_type, item_count).reshape(np_dims), + field=field, + ) + ) self.tensors = tensors diff --git a/smallthinker/gguf-py/gguf/gguf_writer.py b/smallthinker/gguf-py/gguf/gguf_writer.py index de6e45ae..4e29cda3 100644 --- a/smallthinker/gguf-py/gguf/gguf_writer.py +++ b/smallthinker/gguf-py/gguf/gguf_writer.py @@ -54,8 +54,8 @@ class GGUFValue: class WriterState(Enum): NO_FILE = auto() - EMPTY = auto() - HEADER = auto() + EMPTY = auto() + HEADER = auto() KV_DATA = auto() TI_DATA = auto() WEIGHTS = auto() @@ -69,22 +69,29 @@ class GGUFWriter: kv_data: list[dict[str, GGUFValue]] state: WriterState _simple_value_packing = { - GGUFValueType.UINT8: "B", - GGUFValueType.INT8: "b", - GGUFValueType.UINT16: "H", - GGUFValueType.INT16: "h", - GGUFValueType.UINT32: "I", - GGUFValueType.INT32: "i", + GGUFValueType.UINT8: "B", + GGUFValueType.INT8: "b", + GGUFValueType.UINT16: "H", + GGUFValueType.INT16: "h", + GGUFValueType.UINT32: "I", + GGUFValueType.INT32: "i", GGUFValueType.FLOAT32: "f", - GGUFValueType.UINT64: "Q", - GGUFValueType.INT64: "q", + GGUFValueType.UINT64: "Q", + GGUFValueType.INT64: "q", GGUFValueType.FLOAT64: "d", - GGUFValueType.BOOL: "?", + GGUFValueType.BOOL: "?", } def __init__( - self, path: os.PathLike[str] | str | None, arch: str, use_temp_file: bool = False, endianess: GGUFEndian = GGUFEndian.LITTLE, - split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False + self, + path: os.PathLike[str] | str | None, + arch: str, + use_temp_file: bool = False, + endianess: GGUFEndian = GGUFEndian.LITTLE, + split_max_tensors: int = 0, + split_max_size: int = 0, + dry_run: bool = False, + small_first_shard: bool = False, ): self.fout = None self.path = Path(path) if path else None @@ -99,9 +106,11 @@ def __init__( self.split_max_size = split_max_size self.dry_run = dry_run self.small_first_shard = small_first_shard - logger.info("gguf: This GGUF file is for {0} Endian only".format( - "Big" if self.endianess == GGUFEndian.BIG else "Little", - )) + logger.info( + "gguf: This GGUF file is for {0} Endian only".format( + "Big" if self.endianess == GGUFEndian.BIG else "Little", + ) + ) self.state = WriterState.NO_FILE if self.small_first_shard: @@ -130,7 +139,9 @@ def get_total_parameter_count(self) -> tuple[int, int, int, int]: elif name.endswith(".lora_b"): if last_lora_a is None or last_lora_a[0] != name[:-1] + "a": # Bail when the LoRA pair can't be found trivially - logger.warning("can't measure LoRA size correctly, tensor order is unusual") + logger.warning( + "can't measure LoRA size correctly, tensor order is unusual" + ) return 0, 0, 0, 0 else: shape = (*shape[:-1], last_lora_a[1].shape[-1]) @@ -138,7 +149,7 @@ def get_total_parameter_count(self) -> tuple[int, int, int, int]: size = prod(shape) if "_exps." in name: - expert_params += (size // shape[-3]) + expert_params += size // shape[-3] expert_sum += shape[-3] n_expert_tensors += 1 else: @@ -159,15 +170,26 @@ def get_total_parameter_count(self) -> tuple[int, int, int, int]: def format_shard_names(self, path: Path) -> list[Path]: if len(self.tensors) == 1: return [path] - return [path.with_name(SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors))) for i in range(len(self.tensors))] + return [ + path.with_name( + SHARD_NAME_FORMAT.format(path.stem, i + 1, len(self.tensors)) + ) + for i in range(len(self.tensors)) + ] def open_output_file(self, path: Path | None = None) -> None: - if self.state is WriterState.EMPTY and self.fout is not None and (path is None or path == self.path): + if ( + self.state is WriterState.EMPTY + and self.fout is not None + and (path is None or path == self.path) + ): # allow calling this multiple times as long as the path is the same return if self.state is not WriterState.NO_FILE: - raise ValueError(f'Expected output file to be not yet opened, got {self.state}') + raise ValueError( + f"Expected output file to be not yet opened, got {self.state}" + ) if path is not None: self.path = path @@ -183,7 +205,9 @@ def print_plan(self) -> list[Path]: filenames = self.format_shard_names(self.path) assert len(filenames) == len(self.tensors) for name, tensors in zip(filenames, self.tensors): - logger.info(f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}") + logger.info( + f"{name}: n_tensors = {len(tensors)}, total_size = {GGUFWriter.format_n_bytes_to_str(sum(ti.nbytes for ti in tensors.values()))}" + ) if self.dry_run: logger.info("Dry run, not writing files") @@ -203,17 +227,23 @@ def add_shard_kv_data(self) -> None: self.kv_data.extend({} for _ in range(len(self.kv_data), total_splits)) for i, kv_data in enumerate(self.kv_data): kv_data[Keys.Split.LLM_KV_SPLIT_NO] = GGUFValue(i, GGUFValueType.UINT16) - kv_data[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue(total_splits, GGUFValueType.UINT16) - kv_data[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue(total_tensors, GGUFValueType.INT32) + kv_data[Keys.Split.LLM_KV_SPLIT_COUNT] = GGUFValue( + total_splits, GGUFValueType.UINT16 + ) + kv_data[Keys.Split.LLM_KV_SPLIT_TENSORS_COUNT] = GGUFValue( + total_tensors, GGUFValueType.INT32 + ) def write_header_to_file(self, path: Path | None = None) -> None: - if len(self.tensors) == 1 and (self.split_max_tensors != 0 or self.split_max_size != 0): + if len(self.tensors) == 1 and ( + self.split_max_tensors != 0 or self.split_max_size != 0 + ): logger.warning("Model fails split requirements, not splitting") self.open_output_file(path) if self.state is not WriterState.EMPTY: - raise ValueError(f'Expected output file to be empty, got {self.state}') + raise ValueError(f"Expected output file to be empty, got {self.state}") assert self.fout is not None assert len(self.fout) == len(self.tensors) @@ -222,7 +252,7 @@ def write_header_to_file(self, path: Path | None = None) -> None: self.add_shard_kv_data() for fout, tensors, kv_data in zip(self.fout, self.tensors, self.kv_data): - fout.write(self._pack(" None: def write_kv_data_to_file(self) -> None: if self.state is not WriterState.HEADER: - raise ValueError(f'Expected output file to contain the header, got {self.state}') + raise ValueError( + f"Expected output file to contain the header, got {self.state}" + ) assert self.fout is not None for fout, kv_data in zip(self.fout, self.kv_data): @@ -239,7 +271,9 @@ def write_kv_data_to_file(self) -> None: for key, val in kv_data.items(): kv_bytes += self._pack_val(key, GGUFValueType.STRING, add_vtype=False) - kv_bytes += self._pack_val(val.value, val.type, add_vtype=True, sub_type=val.sub_type) + kv_bytes += self._pack_val( + val.value, val.type, add_vtype=True, sub_type=val.sub_type + ) fout.write(kv_bytes) @@ -248,7 +282,9 @@ def write_kv_data_to_file(self) -> None: def write_ti_data_to_file(self) -> None: if self.state is not WriterState.KV_DATA: - raise ValueError(f'Expected output file to contain KV data, got {self.state}') + raise ValueError( + f"Expected output file to contain KV data, got {self.state}" + ) assert self.fout is not None for fout, tensors in zip(self.fout, self.tensors): @@ -269,14 +305,20 @@ def write_ti_data_to_file(self) -> None: fout.flush() self.state = WriterState.TI_DATA - def add_key_value(self, key: str, val: Any, vtype: GGUFValueType, sub_type: GGUFValueType | None = None) -> None: + def add_key_value( + self, + key: str, + val: Any, + vtype: GGUFValueType, + sub_type: GGUFValueType | None = None, + ) -> None: if any(key in kv_data for kv_data in self.kv_data): - raise ValueError(f'Duplicated key name {key!r}') + raise ValueError(f"Duplicated key name {key!r}") self.kv_data[0][key] = GGUFValue(value=val, type=vtype, sub_type=sub_type) def add_uint8(self, key: str, val: int) -> None: - self.add_key_value(key,val, GGUFValueType.UINT8) + self.add_key_value(key, val, GGUFValueType.UINT8) def add_int8(self, key: str, val: int) -> None: self.add_key_value(key, val, GGUFValueType.INT8) @@ -323,14 +365,20 @@ def ggml_pad(x: int, n: int) -> int: return ((x + n - 1) // n) * n def add_tensor_info( - self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype, - tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None, + self, + name: str, + tensor_shape: Sequence[int], + tensor_dtype: np.dtype, + tensor_nbytes: int, + raw_dtype: GGMLQuantizationType | None = None, ) -> None: if self.state is not WriterState.NO_FILE: - raise ValueError(f'Expected output file to be not yet opened, got {self.state}') + raise ValueError( + f"Expected output file to be not yet opened, got {self.state}" + ) if any(name in tensors for tensors in self.tensors): - raise ValueError(f'Duplicated tensor name {name!r}') + raise ValueError(f"Duplicated tensor name {name!r}") if raw_dtype is None: if tensor_dtype == np.float16: @@ -348,7 +396,9 @@ def add_tensor_info( elif tensor_dtype == np.int64: dtype = GGMLQuantizationType.I64 else: - raise ValueError("Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now") + raise ValueError( + "Only F16, F32, F64, I8, I16, I32, I64 tensors are supported for now" + ) else: dtype = raw_dtype if tensor_dtype == np.uint8: @@ -359,16 +409,22 @@ def add_tensor_info( if ( # split when over tensor limit self.split_max_tensors != 0 and len(self.tensors[-1]) >= self.split_max_tensors - ) or ( # split when over size limit + ) or ( # split when over size limit self.split_max_size != 0 - and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes > self.split_max_size + and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes + > self.split_max_size ): self.tensors.append({}) - self.tensors[-1][name] = TensorInfo(shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes) + self.tensors[-1][name] = TensorInfo( + shape=tensor_shape, dtype=dtype, nbytes=tensor_nbytes + ) def add_tensor( - self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, + self, + name: str, + tensor: np.ndarray[Any, Any], + raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None, ) -> None: if self.endianess == GGUFEndian.BIG: @@ -379,7 +435,9 @@ def add_tensor( self.temp_file = fp shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape - self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype) + self.add_tensor_info( + name, shape, tensor.dtype, tensor.nbytes, raw_dtype=raw_dtype + ) if self.temp_file is None: self.tensors[-1][name].tensor = tensor @@ -389,13 +447,21 @@ def add_tensor( self.write_padding(self.temp_file, tensor.nbytes) def write_padding(self, fp: IO[bytes], n: int, align: int | None = None) -> None: - pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n + pad = ( + GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) + - n + ) if pad != 0: fp.write(bytes([0] * pad)) def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: - if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS: - raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') + if ( + self.state is not WriterState.TI_DATA + and self.state is not WriterState.WEIGHTS + ): + raise ValueError( + f"Expected output file to contain tensor info or weights, got {self.state}" + ) assert self.fout is not None if self.endianess == GGUFEndian.BIG: @@ -411,7 +477,9 @@ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: # pop the first tensor info # TODO: cleaner way to get the first key - first_tensor_name = [name for name, _ in zip(self.tensors[file_id].keys(), range(1))][0] + first_tensor_name = [ + name for name, _ in zip(self.tensors[file_id].keys(), range(1)) + ][0] ti = self.tensors[file_id].pop(first_tensor_name) assert ti.nbytes == tensor.nbytes @@ -439,8 +507,15 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None: total_bytes = sum(ti.nbytes for t in self.tensors for ti in t.values()) if len(self.fout) > 1: - shard_bar = tqdm(desc=f"Shard (0/{len(self.fout)})", total=None, unit="byte", unit_scale=True) - bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) + shard_bar = tqdm( + desc=f"Shard (0/{len(self.fout)})", + total=None, + unit="byte", + unit_scale=True, + ) + bar = tqdm( + desc="Writing", total=total_bytes, unit="byte", unit_scale=True + ) for i, (fout, tensors) in enumerate(zip(self.fout, self.tensors)): if shard_bar is not None: @@ -450,7 +525,9 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None: # relying on the fact that Python dicts preserve insertion order (since 3.7) for ti in tensors.values(): - assert ti.tensor is not None # can only iterate once over the tensors + assert ( + ti.tensor is not None + ) # can only iterate once over the tensors assert ti.tensor.nbytes == ti.nbytes ti.tensor.tofile(fout) if shard_bar is not None: @@ -462,7 +539,9 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None: else: self.temp_file.seek(0) - shutil.copyfileobj(self.temp_file, self.fout[0 if not self.small_first_shard else 1]) + shutil.copyfileobj( + self.temp_file, self.fout[0 if not self.small_first_shard else 1] + ) self.flush() self.temp_file.close() @@ -568,10 +647,14 @@ def add_base_model_version(self, source_id: int, version: str) -> None: self.add_string(Keys.General.BASE_MODEL_VERSION.format(id=source_id), version) def add_base_model_organization(self, source_id: int, organization: str) -> None: - self.add_string(Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization) + self.add_string( + Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization + ) def add_base_model_description(self, source_id: int, description: str) -> None: - self.add_string(Keys.General.BASE_MODEL_DESCRIPTION.format(id=source_id), description) + self.add_string( + Keys.General.BASE_MODEL_DESCRIPTION.format(id=source_id), description + ) def add_base_model_url(self, source_id: int, url: str) -> None: self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url) @@ -598,10 +681,14 @@ def add_dataset_version(self, source_id: int, version: str) -> None: self.add_string(Keys.General.DATASET_VERSION.format(id=source_id), version) def add_dataset_organization(self, source_id: int, organization: str) -> None: - self.add_string(Keys.General.DATASET_ORGANIZATION.format(id=source_id), organization) + self.add_string( + Keys.General.DATASET_ORGANIZATION.format(id=source_id), organization + ) def add_dataset_description(self, source_id: int, description: str) -> None: - self.add_string(Keys.General.DATASET_DESCRIPTION.format(id=source_id), description) + self.add_string( + Keys.General.DATASET_DESCRIPTION.format(id=source_id), description + ) def add_dataset_url(self, source_id: int, url: str) -> None: self.add_string(Keys.General.DATASET_URL.format(id=source_id), url) @@ -652,7 +739,9 @@ def add_block_count(self, length: int) -> None: self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length) def add_leading_dense_block_count(self, length: int) -> None: - self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length) + self.add_uint32( + Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length + ) def add_feed_forward_length(self, length: int | Sequence[int]) -> None: if isinstance(length, int): @@ -661,10 +750,14 @@ def add_feed_forward_length(self, length: int | Sequence[int]) -> None: self.add_array(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length) def add_expert_feed_forward_length(self, length: int) -> None: - self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length) + self.add_uint32( + Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length + ) def add_expert_shared_feed_forward_length(self, length: int) -> None: - self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length) + self.add_uint32( + Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length + ) def add_parallel_residual(self, use: bool) -> None: self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use) @@ -757,7 +850,9 @@ def add_token_shift_count(self, count: int) -> None: self.add_uint32(Keys.LLM.TOKEN_SHIFT_COUNT.format(arch=self.arch), count) def add_interleave_moe_layer_step(self, value: int) -> None: - self.add_uint32(Keys.LLM.INTERLEAVE_MOE_LAYER_STEP.format(arch=self.arch), value) + self.add_uint32( + Keys.LLM.INTERLEAVE_MOE_LAYER_STEP.format(arch=self.arch), value + ) def add_layer_norm_eps(self, value: float) -> None: self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value) @@ -787,7 +882,9 @@ def add_iclr_lora_rank(self, length: int) -> None: self.add_uint32(Keys.Attention.ICLR_LORA_RANK.format(arch=self.arch), length) def add_value_residual_mix_lora_rank(self, length: int) -> None: - self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length) + self.add_uint32( + Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length + ) def add_gate_lora_rank(self, length: int) -> None: self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length) @@ -852,10 +949,14 @@ def add_tokenizer_model(self, model: str) -> None: def add_tokenizer_pre(self, pre: str) -> None: self.add_string(Keys.Tokenizer.PRE, pre) - def add_token_list(self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None: + def add_token_list( + self, tokens: Sequence[str] | Sequence[bytes] | Sequence[bytearray] + ) -> None: self.add_array(Keys.Tokenizer.LIST, tokens) - def add_token_merges(self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray]) -> None: + def add_token_merges( + self, merges: Sequence[str] | Sequence[bytes] | Sequence[bytearray] + ) -> None: self.add_array(Keys.Tokenizer.MERGES, merges) def add_token_types(self, types: Sequence[TokenType] | Sequence[int]) -> None: @@ -906,18 +1007,22 @@ def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: template_names = set() for choice in value: - name = choice.get('name', '') - template = choice.get('template') + name = choice.get("name", "") + template = choice.get("template") # Allowing non-alphanumerical characters in template name is probably not a good idea, so filter it - name = ''.join((c if c in ascii_letters + digits else '_' for c in name)) + name = "".join( + (c if c in ascii_letters + digits else "_" for c in name) + ) if name and template is not None: - if name == 'default': + if name == "default": template_default = template else: template_names.add(name) - self.add_string(Keys.Tokenizer.CHAT_TEMPLATE_N.format(name=name), template) + self.add_string( + Keys.Tokenizer.CHAT_TEMPLATE_N.format(name=name), template + ) if template_names: self.add_array(Keys.Tokenizer.CHAT_TEMPLATES, list(template_names)) @@ -1018,12 +1123,18 @@ def add_audio_stack_factor(self, value: int) -> None: self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value) def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes: - pack_prefix = '' + pack_prefix = "" if not skip_pack_prefix: - pack_prefix = '<' if self.endianess == GGUFEndian.LITTLE else '>' - return struct.pack(f'{pack_prefix}{fmt}', value) - - def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: GGUFValueType | None = None) -> bytes: + pack_prefix = "<" if self.endianess == GGUFEndian.LITTLE else ">" + return struct.pack(f"{pack_prefix}{fmt}", value) + + def _pack_val( + self, + val: Any, + vtype: GGUFValueType, + add_vtype: bool, + sub_type: GGUFValueType | None = None, + ) -> bytes: kv_data = bytearray() if add_vtype: @@ -1031,7 +1142,9 @@ def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: G pack_fmt = self._simple_value_packing.get(vtype) if pack_fmt is not None: - kv_data += self._pack(pack_fmt, val, skip_pack_prefix = vtype == GGUFValueType.BOOL) + kv_data += self._pack( + pack_fmt, val, skip_pack_prefix=vtype == GGUFValueType.BOOL + ) elif vtype == GGUFValueType.STRING: encoded_val = val.encode("utf-8") if isinstance(val, str) else val kv_data += self._pack("Q", len(encoded_val)) @@ -1051,7 +1164,9 @@ def _pack_val(self, val: Any, vtype: GGUFValueType, add_vtype: bool, sub_type: G else: ltype = GGUFValueType.get_type(val[0]) if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]): - raise ValueError("All items in a GGUF array should be of the same type") + raise ValueError( + "All items in a GGUF array should be of the same type" + ) kv_data += self._pack("I", ltype) kv_data += self._pack("Q", len(val)) for item in val: diff --git a/smallthinker/gguf-py/gguf/lazy.py b/smallthinker/gguf-py/gguf/lazy.py index f9bcadae..0d6c24bd 100644 --- a/smallthinker/gguf-py/gguf/lazy.py +++ b/smallthinker/gguf-py/gguf/lazy.py @@ -13,7 +13,9 @@ class LazyMeta(ABCMeta): - def __new__(cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs): + def __new__( + cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs + ): def __getattr__(self, name: str) -> Any: meta_attr = getattr(self._meta, name) if callable(meta_attr): @@ -41,6 +43,7 @@ def wrapped_special_op(self, *args, **kwargs): getattr(type(self)._tensor_type, op_name), meta_noop=meta_noop, )(self, *args, **kwargs) + return wrapped_special_op # special methods bypass __getattr__, so they need to be added manually @@ -48,11 +51,48 @@ def wrapped_special_op(self, *args, **kwargs): # NOTE: doing this from a metaclass is very convenient # TODO: make this even more comprehensive for binary_op in ( - "lt", "le", "eq", "ne", "ge", "gt", "not" - "abs", "add", "and", "floordiv", "invert", "lshift", "mod", "mul", "matmul", - "neg", "or", "pos", "pow", "rshift", "sub", "truediv", "xor", - "iadd", "iand", "ifloordiv", "ilshift", "imod", "imul", "ior", "irshift", "isub", "ixor", - "radd", "rand", "rfloordiv", "rmul", "ror", "rpow", "rsub", "rtruediv", "rxor", + "lt", + "le", + "eq", + "ne", + "ge", + "gt", + "not" "abs", + "add", + "and", + "floordiv", + "invert", + "lshift", + "mod", + "mul", + "matmul", + "neg", + "or", + "pos", + "pow", + "rshift", + "sub", + "truediv", + "xor", + "iadd", + "iand", + "ifloordiv", + "ilshift", + "imod", + "imul", + "ior", + "irshift", + "isub", + "ixor", + "radd", + "rand", + "rfloordiv", + "rmul", + "ror", + "rpow", + "rsub", + "rtruediv", + "rxor", ): attr_name = f"__{binary_op}__" # the result of these operators usually has the same shape and dtype as the input, @@ -60,7 +100,9 @@ def wrapped_special_op(self, *args, **kwargs): namespace[attr_name] = mk_wrap(attr_name, meta_noop=True) for special_op in ( - "getitem", "setitem", "len", + "getitem", + "setitem", + "len", ): attr_name = f"__{special_op}__" namespace[attr_name] = mk_wrap(attr_name, meta_noop=False) @@ -77,7 +119,15 @@ class LazyBase(ABC, metaclass=LazyMeta): _kwargs: dict[str, Any] _func: Callable[[Any], Any] | None - def __init__(self, *, meta: Any, data: Any | None = None, args: tuple = (), kwargs: dict[str, Any] | None = None, func: Callable[[Any], Any] | None = None): + def __init__( + self, + *, + meta: Any, + data: Any | None = None, + args: tuple = (), + kwargs: dict[str, Any] | None = None, + func: Callable[[Any], Any] | None = None, + ): super().__init__() self._meta = meta self._data = data @@ -107,7 +157,17 @@ def _recurse_apply(o: Any, fn: Callable[[Any], Any]) -> Any: return o @classmethod - def _wrap_fn(cls, fn: Callable, *, use_self: LazyBase | None = None, meta_noop: bool | DTypeLike | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] = False) -> Callable[[Any], Any]: + def _wrap_fn( + cls, + fn: Callable, + *, + use_self: LazyBase | None = None, + meta_noop: ( + bool + | DTypeLike + | tuple[DTypeLike, Callable[[tuple[int, ...]], tuple[int, ...]]] + ) = False, + ) -> Callable[[Any], Any]: def wrapped_fn(*args, **kwargs): if kwargs is None: kwargs = {} @@ -138,8 +198,12 @@ def wrapped_fn(*args, **kwargs): res = cls.meta_with_dtype_and_shape(meta_noop, res.shape) if isinstance(res, cls._tensor_type): - return cls(meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn) - elif isinstance(res, tuple) and all(isinstance(t, cls._tensor_type) for t in res): + return cls( + meta=cls.eager_to_meta(res), args=args, kwargs=kwargs, func=fn + ) + elif isinstance(res, tuple) and all( + isinstance(t, cls._tensor_type) for t in res + ): # share the evaluation between lazy tuple elements shared_args: list = [args, None] @@ -148,13 +212,23 @@ def eager_tuple_element(a: list[Any], i: int = 0, /, **kw) -> LazyBase: if a[1] is None: a[1] = fn(*a[0], **kw) return a[1][i] - return tuple(cls(meta=cls.eager_to_meta(res[i]), args=(shared_args, i), kwargs=kwargs, func=eager_tuple_element) for i in range(len(res))) + + return tuple( + cls( + meta=cls.eager_to_meta(res[i]), + args=(shared_args, i), + kwargs=kwargs, + func=eager_tuple_element, + ) + for i in range(len(res)) + ) else: del res # not needed # non-tensor return likely relies on the contents of the args # (e.g. the result of torch.equal) eager_args = cls.to_eager(args) return fn(*eager_args, **kwargs) + return wrapped_fn @classmethod @@ -185,7 +259,8 @@ def eager_to_meta(cls, t: Any) -> Any: # must be overridden, meta tensor init is backend-specific @classmethod @abstractmethod - def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: pass + def meta_with_dtype_and_shape(cls, dtype: Any, shape: Any) -> Any: + pass @classmethod def from_eager(cls, t: Any) -> Any: @@ -204,7 +279,9 @@ class LazyNumpyTensor(LazyBase): shape: tuple[int, ...] # Makes the type checker happy in quants.py @classmethod - def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> np.ndarray[Any, Any]: + def meta_with_dtype_and_shape( + cls, dtype: DTypeLike, shape: tuple[int, ...] + ) -> np.ndarray[Any, Any]: # The initial idea was to use np.nan as the fill value, # but non-float types like np.int16 can't use that. # So zero it is. @@ -213,8 +290,16 @@ def meta_with_dtype_and_shape(cls, dtype: DTypeLike, shape: tuple[int, ...]) -> def astype(self, dtype, *args, **kwargs): meta = type(self).meta_with_dtype_and_shape(dtype, self._meta.shape) - full_args = (self, dtype,) + args - return type(self)(meta=meta, args=full_args, kwargs=kwargs, func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs))) + full_args = ( + self, + dtype, + ) + args + return type(self)( + meta=meta, + args=full_args, + kwargs=kwargs, + func=(lambda a, *args, **kwargs: a.astype(*args, **kwargs)), + ) def tofile(self, *args, **kwargs): eager = LazyNumpyTensor.to_eager(self) diff --git a/smallthinker/gguf-py/gguf/metadata.py b/smallthinker/gguf-py/gguf/metadata.py index e807f434..3af3a347 100644 --- a/smallthinker/gguf-py/gguf/metadata.py +++ b/smallthinker/gguf-py/gguf/metadata.py @@ -44,7 +44,12 @@ class Metadata: datasets: Optional[list[dict]] = None @staticmethod - def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Path] = None, model_name: Optional[str] = None, total_params: int = 0) -> Metadata: + def load( + metadata_override_path: Optional[Path] = None, + model_path: Optional[Path] = None, + model_name: Optional[str] = None, + total_params: int = 0, + ) -> Metadata: # This grabs as many contextual authorship metadata as possible from the model repository # making any conversion as required to match the gguf kv store metadata format # as well as giving users the ability to override any authorship metadata that may be incorrect @@ -57,45 +62,77 @@ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Pat # TODO: load adapter_config.json when possible, it usually contains the base model of the LoRA adapter # heuristics - metadata = Metadata.apply_metadata_heuristic(metadata, model_card, hf_params, model_path, total_params) + metadata = Metadata.apply_metadata_heuristic( + metadata, model_card, hf_params, model_path, total_params + ) # Metadata Override File Provided # This is based on LLM_KV_NAMES mapping in llama.cpp metadata_override = Metadata.load_metadata_override(metadata_override_path) - metadata.name = metadata_override.get(Keys.General.NAME, metadata.name) - metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author) - metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version) - metadata.organization = metadata_override.get(Keys.General.ORGANIZATION, metadata.organization) - - metadata.finetune = metadata_override.get(Keys.General.FINETUNE, metadata.finetune) - metadata.basename = metadata_override.get(Keys.General.BASENAME, metadata.basename) - - metadata.description = metadata_override.get(Keys.General.DESCRIPTION, metadata.description) - metadata.quantized_by = metadata_override.get(Keys.General.QUANTIZED_BY, metadata.quantized_by) - - metadata.size_label = metadata_override.get(Keys.General.SIZE_LABEL, metadata.size_label) - metadata.license_name = metadata_override.get(Keys.General.LICENSE_NAME, metadata.license_name) - metadata.license_link = metadata_override.get(Keys.General.LICENSE_LINK, metadata.license_link) - - metadata.url = metadata_override.get(Keys.General.URL, metadata.url) - metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi) - metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid) - metadata.repo_url = metadata_override.get(Keys.General.REPO_URL, metadata.repo_url) - - metadata.source_url = metadata_override.get(Keys.General.SOURCE_URL, metadata.source_url) - metadata.source_doi = metadata_override.get(Keys.General.SOURCE_DOI, metadata.source_doi) - metadata.source_uuid = metadata_override.get(Keys.General.SOURCE_UUID, metadata.source_uuid) - metadata.source_repo_url = metadata_override.get(Keys.General.SOURCE_REPO_URL, metadata.source_repo_url) + metadata.name = metadata_override.get(Keys.General.NAME, metadata.name) + metadata.author = metadata_override.get(Keys.General.AUTHOR, metadata.author) + metadata.version = metadata_override.get(Keys.General.VERSION, metadata.version) + metadata.organization = metadata_override.get( + Keys.General.ORGANIZATION, metadata.organization + ) + + metadata.finetune = metadata_override.get( + Keys.General.FINETUNE, metadata.finetune + ) + metadata.basename = metadata_override.get( + Keys.General.BASENAME, metadata.basename + ) + + metadata.description = metadata_override.get( + Keys.General.DESCRIPTION, metadata.description + ) + metadata.quantized_by = metadata_override.get( + Keys.General.QUANTIZED_BY, metadata.quantized_by + ) + + metadata.size_label = metadata_override.get( + Keys.General.SIZE_LABEL, metadata.size_label + ) + metadata.license_name = metadata_override.get( + Keys.General.LICENSE_NAME, metadata.license_name + ) + metadata.license_link = metadata_override.get( + Keys.General.LICENSE_LINK, metadata.license_link + ) + + metadata.url = metadata_override.get(Keys.General.URL, metadata.url) + metadata.doi = metadata_override.get(Keys.General.DOI, metadata.doi) + metadata.uuid = metadata_override.get(Keys.General.UUID, metadata.uuid) + metadata.repo_url = metadata_override.get( + Keys.General.REPO_URL, metadata.repo_url + ) + + metadata.source_url = metadata_override.get( + Keys.General.SOURCE_URL, metadata.source_url + ) + metadata.source_doi = metadata_override.get( + Keys.General.SOURCE_DOI, metadata.source_doi + ) + metadata.source_uuid = metadata_override.get( + Keys.General.SOURCE_UUID, metadata.source_uuid + ) + metadata.source_repo_url = metadata_override.get( + Keys.General.SOURCE_REPO_URL, metadata.source_repo_url + ) # Base Models is received here as an array of models - metadata.base_models = metadata_override.get("general.base_models", metadata.base_models) + metadata.base_models = metadata_override.get( + "general.base_models", metadata.base_models + ) # Datasets is received here as an array of datasets - metadata.datasets = metadata_override.get("general.datasets", metadata.datasets) + metadata.datasets = metadata_override.get("general.datasets", metadata.datasets) - metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags) - metadata.languages = metadata_override.get(Keys.General.LANGUAGES, metadata.languages) + metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags) + metadata.languages = metadata_override.get( + Keys.General.LANGUAGES, metadata.languages + ) # Direct Metadata Override (via direct cli argument) if model_name is not None: @@ -104,7 +141,9 @@ def load(metadata_override_path: Optional[Path] = None, model_path: Optional[Pat return metadata @staticmethod - def load_metadata_override(metadata_override_path: Optional[Path] = None) -> dict[str, Any]: + def load_metadata_override( + metadata_override_path: Optional[Path] = None, + ) -> dict[str, Any]: if metadata_override_path is None or not metadata_override_path.is_file(): return {} @@ -136,21 +175,23 @@ def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]: return {} for line in lines[1:]: if line == "---": - break # End of frontmatter + break # End of frontmatter else: lines_yaml.append(line) yaml_content = "\n".join(lines_yaml) + "\n" # Quick hack to fix the Norway problem # https://hitchdev.com/strictyaml/why/implicit-typing-removed/ - yaml_content = yaml_content.replace("- no\n", "- \"no\"\n") + yaml_content = yaml_content.replace("- no\n", '- "no"\n') if yaml_content: data = yaml.safe_load(yaml_content) if isinstance(data, dict): return data else: - logger.error(f"while reading YAML model card frontmatter, data is {type(data)} instead of dict") + logger.error( + f"while reading YAML model card frontmatter, data is {type(data)} instead of dict" + ) return {} else: return {} @@ -171,10 +212,21 @@ def load_hf_parameters(model_path: Optional[Path] = None) -> dict[str, Any]: @staticmethod def id_to_title(string): # Convert capitalization into title form unless acronym or version number - return ' '.join([w.title() if w.islower() and not re.match(r'^(v\d+(?:\.\d+)*|\d.*)$', w) else w for w in string.strip().replace('-', ' ').split()]) + return " ".join( + [ + ( + w.title() + if w.islower() and not re.match(r"^(v\d+(?:\.\d+)*|\d.*)$", w) + else w + ) + for w in string.strip().replace("-", " ").split() + ] + ) @staticmethod - def get_model_id_components(model_id: Optional[str] = None, total_params: int = 0) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]: + def get_model_id_components( + model_id: Optional[str] = None, total_params: int = 0 + ) -> tuple[str | None, str | None, str | None, str | None, str | None, str | None]: # Huggingface often store model id as '/' # so let's parse it and apply some heuristics if possible for model name components @@ -182,24 +234,28 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int = # model ID missing return None, None, None, None, None, None - if ' ' in model_id: + if " " in model_id: # model ID is actually a normal human sentence # which means its most likely a normal model name only # not part of the hugging face naming standard, but whatever return model_id, None, None, None, None, None - if '/' in model_id: + if "/" in model_id: # model ID (huggingface style) - org_component, model_full_name_component = model_id.split('/', 1) + org_component, model_full_name_component = model_id.split("/", 1) else: # model ID but missing org components org_component, model_full_name_component = None, model_id # Check if we erroneously matched against './' or '../' etc... - if org_component is not None and len(org_component) > 0 and org_component[0] == '.': + if ( + org_component is not None + and len(org_component) > 0 + and org_component[0] == "." + ): org_component = None - name_parts: list[str] = model_full_name_component.split('-') + name_parts: list[str] = model_full_name_component.split("-") # Remove empty parts for i in reversed(range(len(name_parts))): @@ -213,14 +269,18 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int = # Annotate the name for i, part in enumerate(name_parts): # Version - if re.fullmatch(r'(v|iter)?\d+([.]\d+)*', part, re.IGNORECASE): + if re.fullmatch(r"(v|iter)?\d+([.]\d+)*", part, re.IGNORECASE): name_types[i].add("version") # Quant type (should not be there for base models, but still annotated) - elif re.fullmatch(r'i?q\d(_\w)*|b?fp?(16|32)', part, re.IGNORECASE): + elif re.fullmatch(r"i?q\d(_\w)*|b?fp?(16|32)", part, re.IGNORECASE): name_types[i].add("type") name_parts[i] = part.upper() # Model size - elif i > 0 and re.fullmatch(r'(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)', part, re.IGNORECASE): + elif i > 0 and re.fullmatch( + r"(([A]|\d+[x])?\d+([._]\d+)?[KMBT][\d]?|small|mini|medium|large|x?xl)", + part, + re.IGNORECASE, + ): part = part.replace("_", ".") # Handle weird bloom-7b1 notation if part[-1].isdecimal(): @@ -231,14 +291,19 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int = part = part[:-1] + part[-1].upper() if total_params != 0: try: - label_params = float(part[:-1]) * pow(1000, " KMBT".find(part[-1])) + label_params = float(part[:-1]) * pow( + 1000, " KMBT".find(part[-1]) + ) # Only use it as a size label if it's close or bigger than the model size # Note that LoRA adapters don't necessarily include all layers, # so this is why bigger label sizes are accepted. # Do not use the size label when it's smaller than 1/8 of the model size - if (total_params < 0 and label_params < abs(total_params) // 8) or ( + if ( + total_params < 0 and label_params < abs(total_params) // 8 + ) or ( # Check both directions when the current model isn't a LoRA adapter - total_params > 0 and abs(label_params - total_params) > 7 * total_params // 8 + total_params > 0 + and abs(label_params - total_params) > 7 * total_params // 8 ): # Likely a context length name_types[i].add("finetune") @@ -251,7 +316,9 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int = name_types[i].add("size_label") name_parts[i] = part # Some easy to recognize finetune names - elif i > 0 and re.fullmatch(r'chat|instruct|vision|lora', part, re.IGNORECASE): + elif i > 0 and re.fullmatch( + r"chat|instruct|vision|lora", part, re.IGNORECASE + ): if total_params < 0 and part.lower() == "lora": # ignore redundant "lora" in the finetune part when the output is a lora adapter name_types[i].add("type") @@ -260,7 +327,12 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int = # Ignore word-based size labels when there is at least a number-based one present # TODO: should word-based size labels always be removed instead? - if any(c.isdecimal() for n, t in zip(name_parts, name_types) if "size_label" in t for c in n): + if any( + c.isdecimal() + for n, t in zip(name_parts, name_types) + if "size_label" in t + for c in n + ): for n, t in zip(name_parts, name_types): if "size_label" in t: if all(c.isalpha() for c in n): @@ -284,22 +356,55 @@ def get_model_id_components(model_id: Optional[str] = None, total_params: int = else: break - basename = "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) or None + basename = ( + "-".join(n for n, t in zip(name_parts, name_types) if "basename" in t) + or None + ) # Deduplicate size labels using order-preserving 'dict' ('set' seems to sort the keys) - size_label = "-".join(dict.fromkeys(s for s, t in zip(name_parts, name_types) if "size_label" in t).keys()) or None - finetune = "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) or None + size_label = ( + "-".join( + dict.fromkeys( + s for s, t in zip(name_parts, name_types) if "size_label" in t + ).keys() + ) + or None + ) + finetune = ( + "-".join(f for f, t in zip(name_parts, name_types) if "finetune" in t) + or None + ) # TODO: should the basename version always be excluded? # NOTE: multiple finetune versions are joined together - version = "-".join(v for v, t, in zip(name_parts, name_types) if "version" in t and "basename" not in t) or None + version = ( + "-".join( + v + for v, t, in zip(name_parts, name_types) + if "version" in t and "basename" not in t + ) + or None + ) if size_label is None and finetune is None and version is None: # Too ambiguous, output nothing basename = None - return model_full_name_component, org_component, basename, finetune, version, size_label + return ( + model_full_name_component, + org_component, + basename, + finetune, + version, + size_label, + ) @staticmethod - def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = None, hf_params: Optional[dict] = None, model_path: Optional[Path] = None, total_params: int = 0) -> Metadata: + def apply_metadata_heuristic( + metadata: Metadata, + model_card: Optional[dict] = None, + hf_params: Optional[dict] = None, + model_path: Optional[Path] = None, + total_params: int = 0, + ) -> Metadata: # Reference Model Card Metadata: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 # Model Card Heuristics @@ -307,7 +412,10 @@ def apply_metadata_heuristic(metadata: Metadata, model_card: Optional[dict] = No if model_card is not None: def use_model_card_metadata(metadata_key: str, model_card_key: str): - if model_card_key in model_card and getattr(metadata, metadata_key, None) is None: + if ( + model_card_key in model_card + and getattr(metadata, metadata_key, None) is None + ): setattr(metadata, metadata_key, model_card.get(model_card_key)) def use_array_model_card_metadata(metadata_key: str, model_card_key: str): @@ -368,12 +476,21 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str): use_model_card_metadata("author", "model_creator") use_model_card_metadata("basename", "model_type") - if "base_model" in model_card or "base_models" in model_card or "base_model_sources" in model_card: + if ( + "base_model" in model_card + or "base_models" in model_card + or "base_model_sources" in model_card + ): # This represents the parent models that this is based on # Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges) # Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md metadata_base_models = [] - base_model_value = model_card.get("base_model", model_card.get("base_models", model_card.get("base_model_sources", None))) + base_model_value = model_card.get( + "base_model", + model_card.get( + "base_models", model_card.get("base_model_sources", None) + ), + ) if base_model_value is not None: if isinstance(base_model_value, str): @@ -388,50 +505,94 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str): # NOTE: model size of base model is assumed to be similar to the size of the current model base_model = {} if isinstance(model_id, str): - if model_id.startswith("http://") or model_id.startswith("https://") or model_id.startswith("ssh://"): + if ( + model_id.startswith("http://") + or model_id.startswith("https://") + or model_id.startswith("ssh://") + ): base_model["repo_url"] = model_id # Check if Hugging Face ID is present in URL if "huggingface.co" in model_id: - match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id) + match = re.match( + r"https?://huggingface.co/([^/]+/[^/]+)$", model_id + ) if match: model_id_component = match.group(1) - model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id_component, total_params) + ( + model_full_name_component, + org_component, + basename, + finetune, + version, + size_label, + ) = Metadata.get_model_id_components( + model_id_component, total_params + ) # Populate model dictionary with extracted components if model_full_name_component is not None: - base_model["name"] = Metadata.id_to_title(model_full_name_component) + base_model["name"] = Metadata.id_to_title( + model_full_name_component + ) if org_component is not None: - base_model["organization"] = Metadata.id_to_title(org_component) + base_model["organization"] = ( + Metadata.id_to_title(org_component) + ) if version is not None: base_model["version"] = version else: # Likely a Hugging Face ID - model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) + ( + model_full_name_component, + org_component, + basename, + finetune, + version, + size_label, + ) = Metadata.get_model_id_components(model_id, total_params) # Populate model dictionary with extracted components if model_full_name_component is not None: - base_model["name"] = Metadata.id_to_title(model_full_name_component) + base_model["name"] = Metadata.id_to_title( + model_full_name_component + ) if org_component is not None: - base_model["organization"] = Metadata.id_to_title(org_component) + base_model["organization"] = Metadata.id_to_title( + org_component + ) if version is not None: base_model["version"] = version - if org_component is not None and model_full_name_component is not None: - base_model["repo_url"] = f"https://huggingface.co/{org_component}/{model_full_name_component}" + if ( + org_component is not None + and model_full_name_component is not None + ): + base_model["repo_url"] = ( + f"https://huggingface.co/{org_component}/{model_full_name_component}" + ) elif isinstance(model_id, dict): base_model = model_id else: - logger.error(f"base model entry '{str(model_id)}' not in a known format") + logger.error( + f"base model entry '{str(model_id)}' not in a known format" + ) metadata.base_models.append(base_model) - if "datasets" in model_card or "dataset" in model_card or "dataset_sources" in model_card: + if ( + "datasets" in model_card + or "dataset" in model_card + or "dataset_sources" in model_card + ): # This represents the datasets that this was trained from metadata_datasets = [] - dataset_value = model_card.get("datasets", model_card.get("dataset", model_card.get("dataset_sources", None))) + dataset_value = model_card.get( + "datasets", + model_card.get("dataset", model_card.get("dataset_sources", None)), + ) if dataset_value is not None: if isinstance(dataset_value, str): @@ -451,38 +612,74 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str): # Check if Hugging Face ID is present in URL if "huggingface.co" in dataset_id: - match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", dataset_id) + match = re.match( + r"https?://huggingface.co/([^/]+/[^/]+)$", + dataset_id, + ) if match: dataset_id_component = match.group(1) - dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id_component, total_params) + ( + dataset_name_component, + org_component, + basename, + finetune, + version, + size_label, + ) = Metadata.get_model_id_components( + dataset_id_component, total_params + ) # Populate dataset dictionary with extracted components if dataset_name_component is not None: - dataset["name"] = Metadata.id_to_title(dataset_name_component) + dataset["name"] = Metadata.id_to_title( + dataset_name_component + ) if org_component is not None: - dataset["organization"] = Metadata.id_to_title(org_component) + dataset["organization"] = Metadata.id_to_title( + org_component + ) if version is not None: dataset["version"] = version else: # Likely a Hugging Face ID - dataset_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(dataset_id, total_params) + ( + dataset_name_component, + org_component, + basename, + finetune, + version, + size_label, + ) = Metadata.get_model_id_components( + dataset_id, total_params + ) # Populate dataset dictionary with extracted components if dataset_name_component is not None: - dataset["name"] = Metadata.id_to_title(dataset_name_component) + dataset["name"] = Metadata.id_to_title( + dataset_name_component + ) if org_component is not None: - dataset["organization"] = Metadata.id_to_title(org_component) + dataset["organization"] = Metadata.id_to_title( + org_component + ) if version is not None: dataset["version"] = version - if org_component is not None and dataset_name_component is not None: - dataset["repo_url"] = f"https://huggingface.co/{org_component}/{dataset_name_component}" + if ( + org_component is not None + and dataset_name_component is not None + ): + dataset["repo_url"] = ( + f"https://huggingface.co/{org_component}/{dataset_name_component}" + ) elif isinstance(dataset_id, dict): dataset = dataset_id else: - logger.error(f"dataset entry '{str(dataset_id)}' not in a known format") + logger.error( + f"dataset entry '{str(dataset_id)}' not in a known format" + ) metadata.datasets.append(dataset) @@ -502,11 +699,18 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str): if hf_params is not None: hf_name_or_path = hf_params.get("_name_or_path") - if hf_name_or_path is not None and hf_name_or_path.count('/') <= 1: + if hf_name_or_path is not None and hf_name_or_path.count("/") <= 1: # Use _name_or_path only if its actually a model name and not some computer path # e.g. 'meta-llama/Llama-2-7b-hf' model_id = hf_name_or_path - model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) + ( + model_full_name_component, + org_component, + basename, + finetune, + version, + size_label, + ) = Metadata.get_model_id_components(model_id, total_params) if metadata.name is None and model_full_name_component is not None: metadata.name = Metadata.id_to_title(model_full_name_component) if metadata.organization is None and org_component is not None: @@ -524,7 +728,14 @@ def use_array_model_card_metadata(metadata_key: str, model_card_key: str): ############################################ if model_path is not None: model_id = model_path.name - model_full_name_component, org_component, basename, finetune, version, size_label = Metadata.get_model_id_components(model_id, total_params) + ( + model_full_name_component, + org_component, + basename, + finetune, + version, + size_label, + ) = Metadata.get_model_id_components(model_id, total_params) if metadata.name is None and model_full_name_component is not None: metadata.name = Metadata.id_to_title(model_full_name_component) if metadata.organization is None and org_component is not None: @@ -602,9 +813,13 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter): if "version" in base_model_entry: gguf_writer.add_base_model_version(key, base_model_entry["version"]) if "organization" in base_model_entry: - gguf_writer.add_base_model_organization(key, base_model_entry["organization"]) + gguf_writer.add_base_model_organization( + key, base_model_entry["organization"] + ) if "description" in base_model_entry: - gguf_writer.add_base_model_description(key, base_model_entry["description"]) + gguf_writer.add_base_model_description( + key, base_model_entry["description"] + ) if "url" in base_model_entry: gguf_writer.add_base_model_url(key, base_model_entry["url"]) if "doi" in base_model_entry: @@ -612,7 +827,9 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter): if "uuid" in base_model_entry: gguf_writer.add_base_model_uuid(key, base_model_entry["uuid"]) if "repo_url" in base_model_entry: - gguf_writer.add_base_model_repo_url(key, base_model_entry["repo_url"]) + gguf_writer.add_base_model_repo_url( + key, base_model_entry["repo_url"] + ) if self.datasets is not None: gguf_writer.add_dataset_count(len(self.datasets)) @@ -624,9 +841,13 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter): if "version" in dataset_entry: gguf_writer.add_dataset_version(key, dataset_entry["version"]) if "organization" in dataset_entry: - gguf_writer.add_dataset_organization(key, dataset_entry["organization"]) + gguf_writer.add_dataset_organization( + key, dataset_entry["organization"] + ) if "description" in dataset_entry: - gguf_writer.add_dataset_description(key, dataset_entry["description"]) + gguf_writer.add_dataset_description( + key, dataset_entry["description"] + ) if "url" in dataset_entry: gguf_writer.add_dataset_url(key, dataset_entry["url"]) if "doi" in dataset_entry: diff --git a/smallthinker/gguf-py/gguf/quants.py b/smallthinker/gguf-py/gguf/quants.py index 7c26829e..df4d50a2 100644 --- a/smallthinker/gguf-py/gguf/quants.py +++ b/smallthinker/gguf-py/gguf/quants.py @@ -11,22 +11,35 @@ import numpy as np -def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]: +def quant_shape_to_byte_shape( + shape: Sequence[int], quant_type: GGMLQuantizationType +) -> tuple[int, ...]: block_size, type_size = GGML_QUANT_SIZES[quant_type] if shape[-1] % block_size != 0: - raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})") + raise ValueError( + f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})" + ) return (*shape[:-1], shape[-1] // block_size * type_size) -def quant_shape_from_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]: +def quant_shape_from_byte_shape( + shape: Sequence[int], quant_type: GGMLQuantizationType +) -> tuple[int, ...]: block_size, type_size = GGML_QUANT_SIZES[quant_type] if shape[-1] % type_size != 0: - raise ValueError(f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})") + raise ValueError( + f"Quantized tensor bytes per row ({shape[-1]}) is not a multiple of {quant_type.name} type size ({type_size})" + ) return (*shape[:-1], shape[-1] // type_size * block_size) # This is faster than np.vectorize and np.apply_along_axis because it works on more than one row at a time -def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.ndarray, otype: DTypeLike, oshape: tuple[int, ...]) -> np.ndarray: +def _apply_over_grouped_rows( + func: Callable[[np.ndarray], np.ndarray], + arr: np.ndarray, + otype: DTypeLike, + oshape: tuple[int, ...], +) -> np.ndarray: rows = arr.reshape((-1, arr.shape[-1])) osize = 1 for dim in oshape: @@ -34,7 +47,11 @@ def _apply_over_grouped_rows(func: Callable[[np.ndarray], np.ndarray], arr: np.n out = np.empty(shape=osize, dtype=otype) # compute over groups of 16 rows (arbitrary, but seems good for performance) n_groups = (rows.shape[0] // 16) or 1 - np.concatenate([func(group).ravel() for group in np.array_split(rows, n_groups)], axis=0, out=out) + np.concatenate( + [func(group).ravel() for group in np.array_split(rows, n_groups)], + axis=0, + out=out, + ) return out.reshape(oshape) @@ -64,7 +81,9 @@ def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: elif (q := _type_traits.get(qtype)) is not None: return q.quantize(data) else: - raise NotImplementedError(f"Quantization for {qtype.name} is not yet implemented") + raise NotImplementedError( + f"Quantization for {qtype.name} is not yet implemented" + ) def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: @@ -75,7 +94,9 @@ def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: elif (q := _type_traits.get(qtype)) is not None: return q.dequantize(data) else: - raise NotImplementedError(f"Dequantization for {qtype.name} is not yet implemented") + raise NotImplementedError( + f"Dequantization for {qtype.name} is not yet implemented" + ) class __Quant(ABC): @@ -95,12 +116,10 @@ def __init_subclass__(cls, qtype: GGMLQuantizationType) -> None: cls.qtype = qtype cls.block_size, cls.type_size = GGML_QUANT_SIZES[qtype] cls.__quantize_lazy = LazyNumpyTensor._wrap_fn( - cls.__quantize_array, - meta_noop=(np.uint8, cls.__shape_to_bytes) + cls.__quantize_array, meta_noop=(np.uint8, cls.__shape_to_bytes) ) cls.__dequantize_lazy = LazyNumpyTensor._wrap_fn( - cls.__dequantize_array, - meta_noop=(np.float32, cls.__shape_from_bytes) + cls.__dequantize_array, meta_noop=(np.float32, cls.__shape_from_bytes) ) assert qtype not in _type_traits _type_traits[qtype] = cls @@ -117,10 +136,14 @@ def init_grid(cls): grid = np.frombuffer(cls.grid_hex, dtype=np.uint8) # decode hexadecimal chars from grid grid = grid.reshape((-1, 2)) - grid = (np.where(grid > 0x40, grid + 9, grid) & 0x0F) << np.array([4, 0], dtype=np.uint8).reshape((1, 2)) + grid = (np.where(grid > 0x40, grid + 9, grid) & 0x0F) << np.array( + [4, 0], dtype=np.uint8 + ).reshape((1, 2)) grid = grid[..., 0] | grid[..., 1] # unpack the grid values - grid = grid.reshape((-1, 1)) >> np.array([i for i in range(0, 8, 8 // elems_per_byte)], dtype=np.uint8).reshape((1, elems_per_byte)) + grid = grid.reshape((-1, 1)) >> np.array( + [i for i in range(0, 8, 8 // elems_per_byte)], dtype=np.uint8 + ).reshape((1, elems_per_byte)) grid = (grid & ((1 << bits_per_elem) - 1)).reshape((-1, 1)) grid_map = np.array(cls.grid_map, dtype=np.float32).reshape((1, -1)) grid = np.take_along_axis(grid_map, grid, axis=-1) @@ -168,12 +191,22 @@ def __shape_from_bytes(cls, shape: Sequence[int]): @classmethod def __quantize_array(cls, array: np.ndarray) -> np.ndarray: - return _apply_over_grouped_rows(cls.quantize_rows, arr=array, otype=np.uint8, oshape=cls.__shape_to_bytes(array.shape)) + return _apply_over_grouped_rows( + cls.quantize_rows, + arr=array, + otype=np.uint8, + oshape=cls.__shape_to_bytes(array.shape), + ) @classmethod def __dequantize_array(cls, array: np.ndarray) -> np.ndarray: cls.init_grid() - return _apply_over_grouped_rows(cls.dequantize_rows, arr=array, otype=np.float32, oshape=cls.__shape_from_bytes(array.shape)) + return _apply_over_grouped_rows( + cls.dequantize_rows, + arr=array, + otype=np.float32, + oshape=cls.__shape_from_bytes(array.shape), + ) @classmethod def __quantize_lazy(cls, lazy_tensor: LazyNumpyTensor, /) -> Any: @@ -190,7 +223,9 @@ def can_quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> bool: @classmethod def quantize(cls, tensor: np.ndarray | LazyNumpyTensor) -> np.ndarray: if not cls.can_quantize(tensor): - raise QuantError(f"Can't quantize tensor with shape {tensor.shape} to {cls.qtype.name}") + raise QuantError( + f"Can't quantize tensor with shape {tensor.shape} to {cls.qtype.name}" + ) if isinstance(tensor, LazyNumpyTensor): return cls.__quantize_lazy(tensor) else: @@ -210,9 +245,13 @@ class BF16(__Quant, qtype=GGMLQuantizationType.BF16): def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: n = blocks.view(np.uint32) # force nan to quiet - n = np.where((n & 0x7fffffff) > 0x7f800000, (n & np.uint32(0xffff0000)) | np.uint32(64 << 16), n) + n = np.where( + (n & 0x7FFFFFFF) > 0x7F800000, + (n & np.uint32(0xFFFF0000)) | np.uint32(64 << 16), + n, + ) # round to nearest even - n = (np.uint64(n) + (0x7fff + ((n >> 16) & 1))) >> 16 + n = (np.uint64(n) + (0x7FFF + ((n >> 16) & 1))) >> 16 return n.astype(np.uint16).view(np.uint8) @classmethod @@ -232,7 +271,14 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: with np.errstate(divide="ignore"): id = np.where(d == 0, 0, 1 / d) # FIXME: Q4_0's reference rounding is cursed and depends on FMA - qs = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15) + qs = ( + np.trunc( + (np.float64(blocks) * np.float64(id)) + np.float64(8.5), + dtype=np.float32, + ) + .astype(np.uint8) + .clip(0, 15) + ) qs = qs.reshape((n_blocks, 2, cls.block_size // 2)) qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4)) @@ -249,10 +295,12 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = d.view(np.float16).astype(np.float32) - qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array( + [0, 4], dtype=np.uint8 + ).reshape((1, 1, 2, 1)) qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1)).astype(np.int8) - np.int8(8) - return (d * qs.astype(np.float32)) + return d * qs.astype(np.float32) class Q4_1(__Quant, qtype=GGMLQuantizationType.Q4_1): @@ -266,7 +314,11 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = (max - min) / 15 with np.errstate(divide="ignore"): id = np.where(d == 0, 0, 1 / d) - qs = np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32).astype(np.uint8).clip(0, 15) + qs = ( + np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32) + .astype(np.uint8) + .clip(0, 15) + ) qs = qs.reshape((n_blocks, 2, cls.block_size // 2)) qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4)) @@ -286,7 +338,9 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = d.view(np.float16).astype(np.float32) m = m.view(np.float16).astype(np.float32) - qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array( + [0, 4], dtype=np.uint8 + ).reshape((1, 1, 2, 1)) qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1)).astype(np.float32) return (d * qs) + m @@ -304,12 +358,21 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: with np.errstate(divide="ignore"): id = np.where(d == 0, 0, 1 / d) # FIXME: Q5_0's reference rounding is cursed and depends on FMA - q = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31) + q = ( + np.trunc( + (np.float64(blocks) * np.float64(id)) + np.float64(16.5), + dtype=np.float32, + ) + .astype(np.uint8) + .clip(0, 31) + ) qs = q.reshape((n_blocks, 2, cls.block_size // 2)) qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4)) - qh = np.packbits(q.reshape((n_blocks, 1, 32)) >> np.uint8(4), axis=-1, bitorder="little").reshape(n_blocks, 4) + qh = np.packbits( + q.reshape((n_blocks, 1, 32)) >> np.uint8(4), axis=-1, bitorder="little" + ).reshape(n_blocks, 4) d = d.astype(np.float16).view(np.uint8) @@ -325,14 +388,18 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = d.view(np.float16).astype(np.float32) qh = qh.view(np.uint32) - qh = qh.reshape((n_blocks, 1)) >> np.array([i for i in range(32)], dtype=np.uint32).reshape((1, 32)) - ql = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qh = qh.reshape((n_blocks, 1)) >> np.array( + [i for i in range(32)], dtype=np.uint32 + ).reshape((1, 32)) + ql = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array( + [0, 4], dtype=np.uint8 + ).reshape((1, 1, 2, 1)) qh = (qh & np.uint32(0x01)).astype(np.uint8) ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1)) qs = (ql | (qh << np.uint8(4))).astype(np.int8) - np.int8(16) - return (d * qs.astype(np.float32)) + return d * qs.astype(np.float32) class Q5_1(__Quant, qtype=GGMLQuantizationType.Q5_1): @@ -346,12 +413,18 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = (max - min) / 31 with np.errstate(divide="ignore"): id = np.where(d == 0, 0, 1 / d) - q = np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32).astype(np.uint8).clip(0, 31) + q = ( + np.trunc((blocks - min) * id + np.float32(0.5), dtype=np.float32) + .astype(np.uint8) + .clip(0, 31) + ) qs = q.reshape((n_blocks, 2, cls.block_size // 2)) qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4)) - qh = np.packbits(q.reshape((n_blocks, 1, 32)) >> np.uint8(4), axis=-1, bitorder="little").reshape(n_blocks, 4) + qh = np.packbits( + q.reshape((n_blocks, 1, 32)) >> np.uint8(4), axis=-1, bitorder="little" + ).reshape(n_blocks, 4) d = d.astype(np.float16).view(np.uint8) m = min.astype(np.float16).view(np.uint8) @@ -370,8 +443,12 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: m = m.view(np.float16).astype(np.float32) qh = qh.view(np.uint32) - qh = qh.reshape((n_blocks, 1)) >> np.array([i for i in range(32)], dtype=np.uint32).reshape((1, 32)) - ql = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qh = qh.reshape((n_blocks, 1)) >> np.array( + [i for i in range(32)], dtype=np.uint32 + ).reshape((1, 32)) + ql = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array( + [0, 4], dtype=np.uint8 + ).reshape((1, 1, 2, 1)) qh = (qh & np.uint32(0x01)).astype(np.uint8) ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1)) @@ -403,7 +480,7 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = d.view(np.float16).astype(np.float32) x = x.view(np.int8).astype(np.float32) - return (x * d) + return x * d class Q2_K(__Quant, qtype=GGMLQuantizationType.Q2_K): @@ -420,7 +497,9 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: # (n_blocks, 16, 1) dl = (d * (scales & 0xF).astype(np.float32)).reshape((n_blocks, QK_K // 16, 1)) - ml = (dmin * (scales >> 4).astype(np.float32)).reshape((n_blocks, QK_K // 16, 1)) + ml = (dmin * (scales >> 4).astype(np.float32)).reshape( + (n_blocks, QK_K // 16, 1) + ) shift = np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1)) @@ -458,21 +537,33 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: # 10: OOKKGGCC # 11: PPLLHHDD lscales, hscales = np.hsplit(scales, [8]) - lscales = lscales.reshape((n_blocks, 1, 8)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1)) + lscales = lscales.reshape((n_blocks, 1, 8)) >> np.array( + [0, 4], dtype=np.uint8 + ).reshape((1, 2, 1)) lscales = lscales.reshape((n_blocks, 16)) - hscales = hscales.reshape((n_blocks, 1, 4)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 4, 1)) + hscales = hscales.reshape((n_blocks, 1, 4)) >> np.array( + [0, 2, 4, 6], dtype=np.uint8 + ).reshape((1, 4, 1)) hscales = hscales.reshape((n_blocks, 16)) - scales = (lscales & np.uint8(0x0F)) | ((hscales & np.uint8(0x03)) << np.uint8(4)) + scales = (lscales & np.uint8(0x0F)) | ( + (hscales & np.uint8(0x03)) << np.uint8(4) + ) scales = (scales.astype(np.int8) - np.int8(32)).astype(np.float32) dl = (d * scales).reshape((n_blocks, 16, 1)) - ql = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1)) - qh = hmask.reshape(n_blocks, -1, 1, 32) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8, 1)) + ql = qs.reshape((n_blocks, -1, 1, 32)) >> np.array( + [0, 2, 4, 6], dtype=np.uint8 + ).reshape((1, 1, 4, 1)) + qh = hmask.reshape(n_blocks, -1, 1, 32) >> np.array( + [i for i in range(8)], dtype=np.uint8 + ).reshape((1, 1, 8, 1)) ql = ql.reshape((n_blocks, 16, QK_K // 16)) & np.uint8(3) - qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & np.uint8(1)) + qh = qh.reshape((n_blocks, 16, QK_K // 16)) & np.uint8(1) qh = qh ^ np.uint8(1) # strangely, the offset is zero when the bitmask is 1 - q = (ql.astype(np.int8) - (qh << np.uint8(2)).astype(np.int8)).astype(np.float32) + q = (ql.astype(np.int8) - (qh << np.uint8(2)).astype(np.int8)).astype( + np.float32 + ) return (dl * q).reshape((n_blocks, QK_K)) @@ -521,7 +612,9 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = (d * sc.astype(np.float32)).reshape((n_blocks, -1, 1)) dm = (dmin * m.astype(np.float32)).reshape((n_blocks, -1, 1)) - qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array( + [0, 4], dtype=np.uint8 + ).reshape((1, 1, 2, 1)) qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1, 32)).astype(np.float32) return (d * qs - dm).reshape((n_blocks, QK_K)) @@ -545,8 +638,12 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = (d * sc.astype(np.float32)).reshape((n_blocks, -1, 1)) dm = (dmin * m.astype(np.float32)).reshape((n_blocks, -1, 1)) - ql = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) - qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8, 1)) + ql = qs.reshape((n_blocks, -1, 1, 32)) >> np.array( + [0, 4], dtype=np.uint8 + ).reshape((1, 1, 2, 1)) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array( + [i for i in range(8)], dtype=np.uint8 + ).reshape((1, 1, 8, 1)) ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1, 32)) qh = (qh & np.uint8(0x01)).reshape((n_blocks, -1, 32)) q = (ql | (qh << np.uint8(4))).astype(np.float32) @@ -567,9 +664,13 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = d.view(np.float16).astype(np.float32) d = (d * scales).reshape((n_blocks, QK_K // 16, 1)) - ql = ql.reshape((n_blocks, -1, 1, 64)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + ql = ql.reshape((n_blocks, -1, 1, 64)) >> np.array( + [0, 4], dtype=np.uint8 + ).reshape((1, 1, 2, 1)) ql = (ql & np.uint8(0x0F)).reshape((n_blocks, -1, 32)) - qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1)) + qh = qh.reshape((n_blocks, -1, 1, 32)) >> np.array( + [0, 2, 4, 6], dtype=np.uint8 + ).reshape((1, 1, 4, 1)) qh = (qh & np.uint8(0x03)).reshape((n_blocks, -1, 32)) q = (ql | (qh << np.uint8(4))).astype(np.int8) - np.int8(32) q = q.reshape((n_blocks, QK_K // 16, -1)).astype(np.float32) @@ -588,12 +689,22 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: qs = np_roundf(blocks * id) qs = (qs.astype(np.int8) + np.int8(1)).astype(np.uint8) - qs0, qs1, qh = qs[..., :(32 * 5)], qs[..., (32 * 5):(48 * 5)], qs[..., (48 * 5):] - qs0 = qs0.reshape((n_blocks, -1, 5, 32)) * np.array([81, 27, 9, 3, 1], dtype=np.uint8).reshape((1, 1, 5, 1)) + qs0, qs1, qh = ( + qs[..., : (32 * 5)], + qs[..., (32 * 5) : (48 * 5)], + qs[..., (48 * 5) :], + ) + qs0 = qs0.reshape((n_blocks, -1, 5, 32)) * np.array( + [81, 27, 9, 3, 1], dtype=np.uint8 + ).reshape((1, 1, 5, 1)) qs0 = np.sum(qs0, axis=-2).reshape((n_blocks, -1)) - qs1 = qs1.reshape((n_blocks, -1, 5, 16)) * np.array([81, 27, 9, 3, 1], dtype=np.uint8).reshape((1, 1, 5, 1)) + qs1 = qs1.reshape((n_blocks, -1, 5, 16)) * np.array( + [81, 27, 9, 3, 1], dtype=np.uint8 + ).reshape((1, 1, 5, 1)) qs1 = np.sum(qs1, axis=-2).reshape((n_blocks, -1)) - qh = qh.reshape((n_blocks, -1, 4, 4)) * np.array([81, 27, 9, 3], dtype=np.uint8).reshape((1, 1, 4, 1)) + qh = qh.reshape((n_blocks, -1, 4, 4)) * np.array( + [81, 27, 9, 3], dtype=np.uint8 + ).reshape((1, 1, 4, 1)) qh = np.sum(qh, axis=-2).reshape((n_blocks, -1)) qs = np.concatenate([qs0, qs1, qh], axis=-1) qs = (qs.astype(np.uint16) * 256 + (243 - 1)) // 243 @@ -613,16 +724,22 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = d.view(np.float16).astype(np.float32) qs0, qs1 = qs[..., :32], qs[..., 32:] - qs0 = qs0.reshape((n_blocks, -1, 1, 32)) * np.array([1, 3, 9, 27, 81], dtype=np.uint8).reshape((1, 1, 5, 1)) + qs0 = qs0.reshape((n_blocks, -1, 1, 32)) * np.array( + [1, 3, 9, 27, 81], dtype=np.uint8 + ).reshape((1, 1, 5, 1)) qs0 = qs0.reshape((n_blocks, -1)) - qs1 = qs1.reshape((n_blocks, -1, 1, 16)) * np.array([1, 3, 9, 27, 81], dtype=np.uint8).reshape((1, 1, 5, 1)) + qs1 = qs1.reshape((n_blocks, -1, 1, 16)) * np.array( + [1, 3, 9, 27, 81], dtype=np.uint8 + ).reshape((1, 1, 5, 1)) qs1 = qs1.reshape((n_blocks, -1)) - qh = qh.reshape((n_blocks, -1, 1, 4)) * np.array([1, 3, 9, 27], dtype=np.uint8).reshape((1, 1, 4, 1)) + qh = qh.reshape((n_blocks, -1, 1, 4)) * np.array( + [1, 3, 9, 27], dtype=np.uint8 + ).reshape((1, 1, 4, 1)) qh = qh.reshape((n_blocks, -1)) qs = np.concatenate([qs0, qs1, qh], axis=-1) qs = ((qs.astype(np.uint16) * 3) >> 8).astype(np.int8) - np.int8(1) - return (d * qs.astype(np.float32)) + return d * qs.astype(np.float32) class TQ2_0(__Quant, qtype=GGMLQuantizationType.TQ2_0): @@ -636,7 +753,9 @@ def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: qs = np_roundf(blocks * id) qs = (qs.astype(np.int8) + np.int8(1)).astype(np.uint8) - qs = qs.reshape((n_blocks, -1, 4, 32)) << np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1)) + qs = qs.reshape((n_blocks, -1, 4, 32)) << np.array( + [0, 2, 4, 6], dtype=np.uint8 + ).reshape((1, 1, 4, 1)) qs = qs[..., 0, :] | qs[..., 1, :] | qs[..., 2, :] | qs[..., 3, :] qs = qs.reshape((n_blocks, -1)) @@ -652,10 +771,12 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = d.view(np.float16).astype(np.float32) - qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1)) + qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array( + [0, 2, 4, 6], dtype=np.uint8 + ).reshape((1, 1, 4, 1)) qs = (qs & 0x03).reshape((n_blocks, -1)).astype(np.int8) - np.int8(1) - return (d * qs.astype(np.float32)) + return d * qs.astype(np.float32) class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS): @@ -673,7 +794,7 @@ class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS): # iq2xxs_grid, but with each byte of the original packed in 2 bits, # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2. grid_shape = (256, 8) - grid_map = (0x08, 0x19, 0x2b) + grid_map = (0x08, 0x19, 0x2B) grid_hex = ( b"00000200050008000a00110014002000220028002a0041004400500058006100" b"6400800082008a00a20001010401100115014001840198010002020222028202" @@ -703,21 +824,33 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: qs = qs.view(np.uint32).reshape(n_blocks, -1, 2) - db = d * (np.float32(0.5) + (qs[..., 1] >> 28).astype(np.float32)) * np.float32(0.25) + db = ( + d + * (np.float32(0.5) + (qs[..., 1] >> 28).astype(np.float32)) + * np.float32(0.25) + ) db = db.reshape((n_blocks, -1, 1, 1)) # get the sign indices and unpack the bits - signs = qs[..., 1].reshape((n_blocks, -1, 1)) >> np.array([0, 7, 14, 21], dtype=np.uint32).reshape((1, 1, 4)) + signs = qs[..., 1].reshape((n_blocks, -1, 1)) >> np.array( + [0, 7, 14, 21], dtype=np.uint32 + ).reshape((1, 1, 4)) ksigns = np.frombuffer(cls.ksigns, dtype=np.uint8).reshape((1, 1, 1, 128)) signs = (signs & np.uint32(0x7F)).reshape((n_blocks, -1, 4, 1)) signs = np.take_along_axis(ksigns, signs, axis=-1) - signs = signs.reshape((n_blocks, -1, 4, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 1, 8)) + signs = signs.reshape((n_blocks, -1, 4, 1)) >> np.array( + [i for i in range(8)], dtype=np.uint8 + ).reshape((1, 1, 1, 8)) signs = signs & np.uint8(0x01) signs = np.where(signs == 0, np.float32(1), np.float32(-1)) signs = signs.reshape((n_blocks, -1, 4, 8)) assert cls.grid is not None - grid = np.take_along_axis(cls.grid, qs[..., 0].copy().view(np.uint8).reshape((n_blocks, -1, 1, 1)), axis=-2) + grid = np.take_along_axis( + cls.grid, + qs[..., 0].copy().view(np.uint8).reshape((n_blocks, -1, 1, 1)), + axis=-2, + ) grid = grid.reshape((n_blocks, -1, 4, 8)) return (db * grid * signs).reshape((n_blocks, -1)) @@ -727,7 +860,7 @@ class IQ2_XS(__Quant, qtype=GGMLQuantizationType.IQ2_XS): # iq2xs_grid, but with each byte of the original packed in 2 bits, # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2. grid_shape = (512, 8) - grid_map = (0x08, 0x19, 0x2b) + grid_map = (0x08, 0x19, 0x2B) grid_hex = ( b"00000200050008000a0011001400160019002000220025002800410044004600" b"49005000520055005800610064008000820085008800910094009900a0000101" @@ -773,7 +906,9 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = d.view(np.float16).astype(np.float32) qs = qs.view(np.uint16) - scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2)) + scales = scales.reshape((n_blocks, -1, 1)) >> np.array( + [0, 4], dtype=np.uint8 + ).reshape((1, 1, 2)) scales = (scales & 0x0F).reshape((n_blocks, -1)) db = d * (np.float32(0.5) + scales) * np.float32(0.25) db = db.reshape((n_blocks, -1, 1, 1)) @@ -781,13 +916,17 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: # get the sign indices and unpack the bits signs = np.frombuffer(IQ2_XXS.ksigns, dtype=np.uint8).reshape(1, 1, 128) signs = np.take_along_axis(signs, (qs >> 9).reshape((n_blocks, -1, 1)), axis=-1) - signs = signs.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8)) + signs = signs.reshape((n_blocks, -1, 1)) >> np.array( + [i for i in range(8)], dtype=np.uint8 + ).reshape((1, 1, 8)) signs = signs & np.uint8(0x01) signs = np.where(signs == 0, np.float32(1), np.float32(-1)) signs = signs.reshape((n_blocks, -1, 2, 8)) assert cls.grid is not None - grid = np.take_along_axis(cls.grid, (qs & np.uint16(511)).reshape((n_blocks, -1, 1, 1)), axis=-2) + grid = np.take_along_axis( + cls.grid, (qs & np.uint16(511)).reshape((n_blocks, -1, 1, 1)), axis=-2 + ) grid = grid.reshape((n_blocks, -1, 2, 8)) return (db * grid * signs).reshape((n_blocks, -1)) @@ -797,7 +936,7 @@ class IQ2_S(__Quant, qtype=GGMLQuantizationType.IQ2_S): # iq2s_grid, but with each byte of the original packed in 2 bits, # by mapping 0x08 to 0, 0x19 to 1, and 0x2b to 2. grid_shape = (1024, 8) - grid_map = (0x08, 0x19, 0x2b) + grid_map = (0x08, 0x19, 0x2B) grid_hex = ( b"00000200050008000a0011001400160019002000220025002800410044004600" b"490050005200550058006100640066006900800082008500880091009400a000" @@ -876,19 +1015,27 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = d.view(np.float16).astype(np.float32) - scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2)) + scales = scales.reshape((n_blocks, -1, 1)) >> np.array( + [0, 4], dtype=np.uint8 + ).reshape((1, 1, 2)) scales = (scales & 0x0F).reshape((n_blocks, -1)) db = d * (np.float32(0.5) + scales) * np.float32(0.25) db = db.reshape((n_blocks, -1, 1, 1)) # unpack the sign bits - signs = signs.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8)) + signs = signs.reshape((n_blocks, -1, 1)) >> np.array( + [i for i in range(8)], dtype=np.uint8 + ).reshape((1, 1, 8)) signs = signs & np.uint8(0x01) signs = np.where(signs == 0, np.float32(1), np.float32(-1)) signs = signs.reshape((n_blocks, -1, 2, 8)) - qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4)) - qs = qs.astype(np.uint16) | ((qh & 0x03).astype(np.uint16) << 8).reshape((n_blocks, -1)) + qh = qh.reshape((n_blocks, -1, 1)) >> np.array( + [0, 2, 4, 6], dtype=np.uint8 + ).reshape((1, 1, 4)) + qs = qs.astype(np.uint16) | ((qh & 0x03).astype(np.uint16) << 8).reshape( + (n_blocks, -1) + ) assert cls.grid is not None grid = np.take_along_axis(cls.grid, qs.reshape((n_blocks, -1, 1, 1)), axis=-2) @@ -899,7 +1046,7 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: class IQ3_XXS(__Quant, qtype=GGMLQuantizationType.IQ3_XXS): grid_shape = (256, 4) - grid_map = (0x04, 0x0c, 0x14, 0x1c, 0x24, 0x2c, 0x34, 0x3e) + grid_map = (0x04, 0x0C, 0x14, 0x1C, 0x24, 0x2C, 0x34, 0x3E) grid_hex = ( b"0000020004001100130017002000220031004200730075000101030110011201" b"2101250130013201410154017001000202020402110220022202310233023702" @@ -933,11 +1080,15 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: db = db.reshape((n_blocks, -1, 1, 1)) # get the sign indices and unpack the bits - signs = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 7, 14, 21], dtype=np.uint32).reshape((1, 1, 4)) + signs = scales.reshape((n_blocks, -1, 1)) >> np.array( + [0, 7, 14, 21], dtype=np.uint32 + ).reshape((1, 1, 4)) ksigns = np.frombuffer(IQ2_XXS.ksigns, dtype=np.uint8).reshape((1, 1, 1, 128)) signs = (signs & np.uint32(0x7F)).reshape((n_blocks, -1, 4, 1)) signs = np.take_along_axis(ksigns, signs, axis=-1) - signs = signs.reshape((n_blocks, -1, 4, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 1, 8)) + signs = signs.reshape((n_blocks, -1, 4, 1)) >> np.array( + [i for i in range(8)], dtype=np.uint8 + ).reshape((1, 1, 1, 8)) signs = signs & np.uint8(0x01) signs = np.where(signs == 0, np.float32(1), np.float32(-1)) signs = signs.reshape((n_blocks, -1, 4, 8)) @@ -951,7 +1102,7 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: class IQ3_S(__Quant, qtype=GGMLQuantizationType.IQ3_S): grid_shape = (512, 4) - grid_map = (0x01, 0x03, 0x05, 0x07, 0x09, 0x0b, 0x0d, 0x0f) + grid_map = (0x01, 0x03, 0x05, 0x07, 0x09, 0x0B, 0x0D, 0x0F) grid_hex = ( b"0000010002000500070010001100120014001600200021002500330040004200" b"4500470051005300600062007100740077000001010102010401100111011501" @@ -998,18 +1149,24 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = d.view(np.float16).astype(np.float32) - scales = scales.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2)) + scales = scales.reshape((n_blocks, -1, 1)) >> np.array( + [0, 4], dtype=np.uint8 + ).reshape((1, 1, 2)) scales = (scales & 0x0F).reshape((n_blocks, -1)) db = d * (1 + 2 * scales) db = db.reshape((n_blocks, -1, 1, 1)) # unpack the sign bits - signs = signs.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8).reshape((1, 1, 8)) + signs = signs.reshape((n_blocks, -1, 1)) >> np.array( + [i for i in range(8)], dtype=np.uint8 + ).reshape((1, 1, 8)) signs = signs & np.uint8(0x01) signs = np.where(signs == 0, np.float32(1), np.float32(-1)) signs = signs.reshape((n_blocks, -1, 4, 8)) - qh = qh.reshape((n_blocks, -1, 1)) >> np.array([i for i in range(8)], dtype=np.uint8) + qh = qh.reshape((n_blocks, -1, 1)) >> np.array( + [i for i in range(8)], dtype=np.uint8 + ) qh = (qh & 0x01).astype(np.uint16).reshape((n_blocks, -1)) qs = qs.astype(np.uint16) | (qh << 8) @@ -1173,7 +1330,9 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: delta = np.where((qh & np.uint16(0x8000)) == 0, cls.delta, -cls.delta) delta = delta.reshape((n_blocks, -1, 1, 1)) - qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 3, 6, 9], dtype=np.uint16).reshape((1, 1, 4)) + qh = qh.reshape((n_blocks, -1, 1)) >> np.array( + [0, 3, 6, 9], dtype=np.uint16 + ).reshape((1, 1, 4)) qs = qs.astype(np.uint16) | ((qh & 7) << 8).reshape((n_blocks, -1)) assert cls.grid is not None @@ -1200,17 +1359,25 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: # The f16 scale is packed across multiple bytes scales = scales.view(np.uint16) - d = (scales.reshape((n_blocks, 4)) & np.uint16(0xF000)) >> np.array([12, 8, 4, 0], dtype=np.uint16).reshape((1, 4)) + d = (scales.reshape((n_blocks, 4)) & np.uint16(0xF000)) >> np.array( + [12, 8, 4, 0], dtype=np.uint16 + ).reshape((1, 4)) d = d[..., 0] | d[..., 1] | d[..., 2] | d[..., 3] d = d.view(np.float16).astype(np.float32).reshape((n_blocks, 1)) - scales = scales.reshape(n_blocks, -1, 1) >> np.array([0, 3, 6, 9], dtype=np.uint16).reshape((1, 1, 4)) + scales = scales.reshape(n_blocks, -1, 1) >> np.array( + [0, 3, 6, 9], dtype=np.uint16 + ).reshape((1, 1, 4)) scales = (scales & 0x07).reshape((n_blocks, -1)) dl = d * (2 * scales + 1) dl = dl.reshape((n_blocks, -1, 2, 1, 1)) - qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2)) - qs = qs.astype(np.uint16) | ((qh & 0x07).astype(np.uint16) << 8).reshape((n_blocks, -1)) + qh = qh.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape( + (1, 1, 2) + ) + qs = qs.astype(np.uint16) | ((qh & 0x07).astype(np.uint16) << 8).reshape( + (n_blocks, -1) + ) delta = np.where(qh & 0x08 == 0, cls.delta, -cls.delta) delta = delta.reshape((n_blocks, -1, 2, 2, 1)) @@ -1233,14 +1400,20 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = d.view(np.float16).astype(np.float32) - qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qs = qs.reshape((n_blocks, -1, 1, cls.block_size // 2)) >> np.array( + [0, 4], dtype=np.uint8 + ).reshape((1, 1, 2, 1)) qs = (qs & np.uint8(0x0F)).reshape((n_blocks, -1, 1)) kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16) - qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1)) + qs = ( + np.take_along_axis(kvalues, qs, axis=-1) + .astype(np.float32) + .reshape((n_blocks, -1)) + ) - return (d * qs) + return d * qs class IQ4_XS(__Quant, qtype=GGMLQuantizationType.IQ4_XS): @@ -1255,18 +1428,28 @@ def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: d = d.view(np.float16).astype(np.float32) scales_h = scales_h.view(np.uint16) - scales_l = scales_l.reshape((n_blocks, -1, 1)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2)) - scales_h = scales_h.reshape((n_blocks, 1, -1)) >> np.array([2 * i for i in range(QK_K // 32)], dtype=np.uint16).reshape((1, -1, 1)) + scales_l = scales_l.reshape((n_blocks, -1, 1)) >> np.array( + [0, 4], dtype=np.uint8 + ).reshape((1, 1, 2)) + scales_h = scales_h.reshape((n_blocks, 1, -1)) >> np.array( + [2 * i for i in range(QK_K // 32)], dtype=np.uint16 + ).reshape((1, -1, 1)) scales_l = scales_l.reshape((n_blocks, -1)) & np.uint8(0x0F) scales_h = scales_h.reshape((n_blocks, -1)).astype(np.uint8) & np.uint8(0x03) scales = (scales_l | (scales_h << np.uint8(4))).astype(np.int8) - np.int8(32) dl = (d * scales.astype(np.float32)).reshape((n_blocks, -1, 1)) - qs = qs.reshape((n_blocks, -1, 1, 16)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qs = qs.reshape((n_blocks, -1, 1, 16)) >> np.array( + [0, 4], dtype=np.uint8 + ).reshape((1, 1, 2, 1)) qs = qs.reshape((n_blocks, -1, 32, 1)) & np.uint8(0x0F) kvalues = np.array(IQ4_NL.kvalues, dtype=np.int8).reshape((1, 1, 1, -1)) - qs = np.take_along_axis(kvalues, qs, axis=-1).astype(np.float32).reshape((n_blocks, -1, 32)) + qs = ( + np.take_along_axis(kvalues, qs, axis=-1) + .astype(np.float32) + .reshape((n_blocks, -1, 32)) + ) return (dl * qs).reshape((n_blocks, -1)) diff --git a/smallthinker/gguf-py/gguf/scripts/gguf_convert_endian.py b/smallthinker/gguf-py/gguf/scripts/gguf_convert_endian.py index 0e0febaa..3474bf97 100755 --- a/smallthinker/gguf-py/gguf/scripts/gguf_convert_endian.py +++ b/smallthinker/gguf-py/gguf/scripts/gguf_convert_endian.py @@ -11,7 +11,10 @@ import numpy as np # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): +if ( + "NO_LOCAL_GGUF" not in os.environ + and (Path(__file__).parent.parent.parent.parent / "gguf-py").exists() +): sys.path.insert(0, str(Path(__file__).parent.parent.parent)) import gguf @@ -21,12 +24,14 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None: file_endian = reader.endianess.name - if reader.byte_order == 'S': - host_endian = 'BIG' if file_endian == 'LITTLE' else 'LITTLE' + if reader.byte_order == "S": + host_endian = "BIG" if file_endian == "LITTLE" else "LITTLE" else: host_endian = file_endian order = host_endian if args.order == "native" else args.order.upper() - logger.info(f"* Host is {host_endian} endian, GGUF file seems to be {file_endian} endian") + logger.info( + f"* Host is {host_endian} endian, GGUF file seems to be {file_endian} endian" + ) if file_endian == order: logger.info(f"* File is already {order} endian. Nothing to do.") sys.exit(0) @@ -39,28 +44,42 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None gguf.GGMLQuantizationType.Q4_K, gguf.GGMLQuantizationType.Q6_K, ): - raise ValueError(f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}") + raise ValueError( + f"Cannot handle type {tensor.tensor_type.name} for tensor {repr(tensor.name)}" + ) logger.info(f"* Preparing to convert from {file_endian} to {order}") if args.dry_run: return logger.warning("*** Warning *** Warning *** Warning **") - logger.warning("* This conversion process may damage the file. Ensure you have a backup.") + logger.warning( + "* This conversion process may damage the file. Ensure you have a backup." + ) if order != host_endian: - logger.warning("* Requested endian differs from host, you will not be able to load the model on this machine.") - logger.warning("* The file will be modified immediately, so if conversion fails or is interrupted") - logger.warning("* the file will be corrupted. Enter exactly YES if you are positive you want to proceed:") + logger.warning( + "* Requested endian differs from host, you will not be able to load the model on this machine." + ) + logger.warning( + "* The file will be modified immediately, so if conversion fails or is interrupted" + ) + logger.warning( + "* the file will be corrupted. Enter exactly YES if you are positive you want to proceed:" + ) response = input("YES, I am sure> ") if response != "YES": logger.warning("You didn't enter YES. Okay then, see ya!") sys.exit(0) logger.info(f"* Converting fields ({len(reader.fields)})") for idx, field in enumerate(reader.fields.values()): - logger.info(f"- {idx:4}: Converting field {repr(field.name)}, part count: {len(field.parts)}") + logger.info( + f"- {idx:4}: Converting field {repr(field.name)}, part count: {len(field.parts)}" + ) for part in field.parts: part.byteswap(inplace=True) logger.info(f"* Converting tensors ({len(reader.tensors)})") - for idx, tensor in enumerate(pbar := tqdm(reader.tensors, desc="Converting tensor")): + for idx, tensor in enumerate( + pbar := tqdm(reader.tensors, desc="Converting tensor") + ): log_message = ( f"Converting tensor {repr(tensor.name)}, " f"type={tensor.tensor_type.name}, " @@ -77,19 +96,25 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None # Specific handling of block_q8_0 is required. # Each block_q8_0 consists of an f16 delta (scaling factor) followed by 32 int8 quantizations. - block_size = 34 # 34 bytes = + 32 * + block_size = 34 # 34 bytes = + 32 * n_blocks = len(tensor.data) // block_size - for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)): + for block_num in ( + inner_pbar := tqdm( + range(n_blocks), desc="Byte-swapping Blocks", leave=False + ) + ): block_offs = block_num * block_size # Byte-Swap f16 sized delta field - delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16) + delta = tensor.data[block_offs : block_offs + 2].view(dtype=np.uint16) delta.byteswap(inplace=True) # Byte-Swap Q8 weights if block_num % 100000 == 0: - inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]") + inner_pbar.set_description( + f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]" + ) elif tensor.tensor_type == gguf.GGMLQuantizationType.Q4_K: # Handle Q4_K tensor blocks (block_q4_k) @@ -105,19 +130,27 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None block_size = 144 n_blocks = len(tensor.data) // block_size - for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)): + for block_num in ( + inner_pbar := tqdm( + range(n_blocks), desc="Byte-swapping Blocks", leave=False + ) + ): block_offs = block_num * block_size # Byte-Swap f16 sized fields - delta = tensor.data[block_offs:block_offs + 2].view(dtype=np.uint16) + delta = tensor.data[block_offs : block_offs + 2].view(dtype=np.uint16) delta.byteswap(inplace=True) - delta = tensor.data[block_offs + 2:block_offs + 4].view(dtype=np.uint16) + delta = tensor.data[block_offs + 2 : block_offs + 4].view( + dtype=np.uint16 + ) delta.byteswap(inplace=True) # Byte-Swap if block_num % 100000 == 0: - inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]") + inner_pbar.set_description( + f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]" + ) elif tensor.tensor_type == gguf.GGMLQuantizationType.Q6_K: # Handle Q6_K tensor blocks (block_q6_k) @@ -133,16 +166,24 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None block_size = 210 n_blocks = len(tensor.data) // block_size - for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)): + for block_num in ( + inner_pbar := tqdm( + range(n_blocks), desc="Byte-swapping Blocks", leave=False + ) + ): block_offs = block_num * block_size # Byte-Swap f16 sized field - delta = tensor.data[block_offs + 208:block_offs + 210].view(dtype=np.uint16) + delta = tensor.data[block_offs + 208 : block_offs + 210].view( + dtype=np.uint16 + ) delta.byteswap(inplace=True) # Byte-Swap if block_num % 100000 == 0: - inner_pbar.set_description(f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]") + inner_pbar.set_description( + f"Byte-swapping Blocks [{(n_blocks - block_num) // n_blocks}]" + ) else: # Handle other tensor types @@ -156,25 +197,31 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None def main() -> None: parser = argparse.ArgumentParser(description="Convert GGUF file byte order") parser.add_argument( - "model", type=str, + "model", + type=str, help="GGUF format model filename", ) parser.add_argument( - "order", type=str, choices=['big', 'little', 'native'], + "order", + type=str, + choices=["big", "little", "native"], help="Requested byte order", ) parser.add_argument( - "--dry-run", action="store_true", + "--dry-run", + action="store_true", help="Don't actually change anything", ) - parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + parser.add_argument( + "--verbose", action="store_true", help="increase output verbosity" + ) args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) - logger.info(f'* Loading: {args.model}') - reader = gguf.GGUFReader(args.model, 'r' if args.dry_run else 'r+') + logger.info(f"* Loading: {args.model}") + reader = gguf.GGUFReader(args.model, "r" if args.dry_run else "r+") convert_byteorder(reader, args) diff --git a/smallthinker/gguf-py/gguf/scripts/gguf_dump.py b/smallthinker/gguf-py/gguf/scripts/gguf_dump.py index e282892d..e51e0708 100755 --- a/smallthinker/gguf-py/gguf/scripts/gguf_dump.py +++ b/smallthinker/gguf-py/gguf/scripts/gguf_dump.py @@ -10,7 +10,10 @@ from typing import Any # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): +if ( + "NO_LOCAL_GGUF" not in os.environ + and (Path(__file__).parent.parent.parent.parent / "gguf-py").exists() +): sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from gguf import GGUFReader, GGUFValueType, ReaderTensor # noqa: E402 @@ -20,8 +23,8 @@ def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]: file_endian = reader.endianess.name - if reader.byte_order == 'S': - host_endian = 'BIG' if file_endian == 'LITTLE' else 'LITTLE' + if reader.byte_order == "S": + host_endian = "BIG" if file_endian == "LITTLE" else "LITTLE" else: host_endian = file_endian return (host_endian, file_endian) @@ -31,43 +34,53 @@ def get_file_host_endian(reader: GGUFReader) -> tuple[str, str]: # please see the comments in the modify_gguf.py example. def dump_metadata(reader: GGUFReader, args: argparse.Namespace) -> None: host_endian, file_endian = get_file_host_endian(reader) - print(f'* File is {file_endian} endian, script is running on a {host_endian} endian host.') # noqa: NP100 - print(f'* Dumping {len(reader.fields)} key/value pair(s)') # noqa: NP100 + print( + f"* File is {file_endian} endian, script is running on a {host_endian} endian host." + ) # noqa: NP100 + print(f"* Dumping {len(reader.fields)} key/value pair(s)") # noqa: NP100 for n, field in enumerate(reader.fields.values(), 1): if not field.types: - pretty_type = 'N/A' + pretty_type = "N/A" elif field.types[0] == GGUFValueType.ARRAY: nest_count = len(field.types) - 1 - pretty_type = '[' * nest_count + str(field.types[-1].name) + ']' * nest_count + pretty_type = ( + "[" * nest_count + str(field.types[-1].name) + "]" * nest_count + ) else: pretty_type = str(field.types[-1].name) - log_message = f' {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}' + log_message = f" {n:5}: {pretty_type:10} | {len(field.data):8} | {field.name}" if field.types: curr_type = field.types[0] if curr_type == GGUFValueType.STRING: content = field.contents() if len(content) > 60: - content = content[:57] + '...' - log_message += ' = {0}'.format(repr(content)) + content = content[:57] + "..." + log_message += " = {0}".format(repr(content)) elif curr_type in reader.gguf_scalar_to_np: - log_message += ' = {0}'.format(field.contents()) + log_message += " = {0}".format(field.contents()) else: content = repr(field.contents(slice(6))) if len(field.data) > 6: - content = content[:-1] + ', ...]' - log_message += ' = {0}'.format(content) + content = content[:-1] + ", ...]" + log_message += " = {0}".format(content) print(log_message) # noqa: NP100 if args.no_tensors: return - print(f'* Dumping {len(reader.tensors)} tensor(s)') # noqa: NP100 + print(f"* Dumping {len(reader.tensors)} tensor(s)") # noqa: NP100 for n, tensor in enumerate(reader.tensors, 1): - prettydims = ', '.join('{0:5}'.format(d) for d in list(tensor.shape) + [1] * (4 - len(tensor.shape))) - print(f' {n:5}: {tensor.n_elements:10} | {prettydims} | {tensor.tensor_type.name:7} | {tensor.name}') # noqa: NP100 + prettydims = ", ".join( + "{0:5}".format(d) + for d in list(tensor.shape) + [1] * (4 - len(tensor.shape)) + ) + print( + f" {n:5}: {tensor.n_elements:10} | {prettydims} | {tensor.tensor_type.name:7} | {tensor.name}" + ) # noqa: NP100 def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None: import json + host_endian, file_endian = get_file_host_endian(reader) metadata: dict[str, Any] = {} tensors: dict[str, Any] = {} @@ -80,7 +93,7 @@ def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None: for idx, field in enumerate(reader.fields.values()): curr: dict[str, Any] = { "index": idx, - "type": field.types[0].name if field.types else 'UNKNOWN', + "type": field.types[0].name if field.types else "UNKNOWN", "offset": field.offset, } metadata[field.name] = curr @@ -102,72 +115,108 @@ def dump_metadata_json(reader: GGUFReader, args: argparse.Namespace) -> None: json.dump(result, sys.stdout) -def markdown_table_with_alignment_support(header_map: list[dict[str, str]], data: list[dict[str, Any]]): +def markdown_table_with_alignment_support( + header_map: list[dict[str, str]], data: list[dict[str, Any]] +): # JSON to Markdown table formatting: https://stackoverflow.com/a/72983854/2850957 # Alignment Utility Function def strAlign(padding: int, alignMode: str | None, strVal: str): - if alignMode == 'center': + if alignMode == "center": return strVal.center(padding) - elif alignMode == 'right': - return strVal.rjust(padding - 1) + ' ' - elif alignMode == 'left': - return ' ' + strVal.ljust(padding - 1) - else: # default left - return ' ' + strVal.ljust(padding - 1) + elif alignMode == "right": + return strVal.rjust(padding - 1) + " " + elif alignMode == "left": + return " " + strVal.ljust(padding - 1) + else: # default left + return " " + strVal.ljust(padding - 1) def dashAlign(padding: int, alignMode: str | None): - if alignMode == 'center': - return ':' + '-' * (padding - 2) + ':' - elif alignMode == 'right': - return '-' * (padding - 1) + ':' - elif alignMode == 'left': - return ':' + '-' * (padding - 1) - else: # default left - return '-' * (padding) + if alignMode == "center": + return ":" + "-" * (padding - 2) + ":" + elif alignMode == "right": + return "-" * (padding - 1) + ":" + elif alignMode == "left": + return ":" + "-" * (padding - 1) + else: # default left + return "-" * (padding) # Calculate Padding For Each Column Based On Header and Data Length rowsPadding = {} for index, columnEntry in enumerate(header_map): - padCount = max([len(str(v)) for d in data for k, v in d.items() if k == columnEntry['key_name']], default=0) + 2 - headerPadCount = len(columnEntry['header_name']) + 2 + padCount = ( + max( + [ + len(str(v)) + for d in data + for k, v in d.items() + if k == columnEntry["key_name"] + ], + default=0, + ) + + 2 + ) + headerPadCount = len(columnEntry["header_name"]) + 2 rowsPadding[index] = headerPadCount if padCount <= headerPadCount else padCount # Render Markdown Header rows = [] - rows.append('|'.join(strAlign(rowsPadding[index], columnEntry.get('align'), str(columnEntry['header_name'])) for index, columnEntry in enumerate(header_map))) - rows.append('|'.join(dashAlign(rowsPadding[index], columnEntry.get('align')) for index, columnEntry in enumerate(header_map))) + rows.append( + "|".join( + strAlign( + rowsPadding[index], + columnEntry.get("align"), + str(columnEntry["header_name"]), + ) + for index, columnEntry in enumerate(header_map) + ) + ) + rows.append( + "|".join( + dashAlign(rowsPadding[index], columnEntry.get("align")) + for index, columnEntry in enumerate(header_map) + ) + ) # Render Tabular Data for item in data: - rows.append('|'.join(strAlign(rowsPadding[index], columnEntry.get('align'), str(item[columnEntry['key_name']])) for index, columnEntry in enumerate(header_map))) + rows.append( + "|".join( + strAlign( + rowsPadding[index], + columnEntry.get("align"), + str(item[columnEntry["key_name"]]), + ) + for index, columnEntry in enumerate(header_map) + ) + ) # Convert Tabular String Rows Into String tableString = "" for row in rows: - tableString += f'|{row}|\n' + tableString += f"|{row}|\n" return tableString def element_count_rounded_notation(count: int) -> str: - if count > 1e15 : + if count > 1e15: # Quadrillion scaled_amount = count * 1e-15 scale_suffix = "Q" - elif count > 1e12 : + elif count > 1e12: # Trillions scaled_amount = count * 1e-12 scale_suffix = "T" - elif count > 1e9 : + elif count > 1e9: # Billions scaled_amount = count * 1e-9 scale_suffix = "B" - elif count > 1e6 : + elif count > 1e6: # Millions scaled_amount = count * 1e-6 scale_suffix = "M" - elif count > 1e3 : + elif count > 1e3: # Thousands scaled_amount = count * 1e-3 scale_suffix = "K" @@ -183,35 +232,35 @@ def translate_tensor_name(name): # Source: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#standardized-tensor-names abbreviation_dictionary = { - 'token_embd': 'Token embedding', - 'pos_embd': 'Position embedding', - 'output_norm': 'Output normalization', - 'output': 'Output', - 'attn_norm': 'Attention normalization', - 'attn_norm_2': 'Attention normalization', - 'attn_qkv': 'Attention query-key-value', - 'attn_q': 'Attention query', - 'attn_k': 'Attention key', - 'attn_v': 'Attention value', - 'attn_output': 'Attention output', - 'ffn_norm': 'Feed-forward network normalization', - 'ffn_up': 'Feed-forward network "up"', - 'ffn_gate': 'Feed-forward network "gate"', - 'ffn_down': 'Feed-forward network "down"', - 'ffn_gate_inp': 'Expert-routing layer for the Feed-forward network in Mixture of Expert models', - 'ffn_gate_exp': 'Feed-forward network "gate" layer per expert in Mixture of Expert models', - 'ffn_down_exp': 'Feed-forward network "down" layer per expert in Mixture of Expert models', - 'ffn_up_exp': 'Feed-forward network "up" layer per expert in Mixture of Expert models', - 'ssm_in': 'State space model input projections', - 'ssm_conv1d': 'State space model rolling/shift', - 'ssm_x': 'State space model selective parametrization', - 'ssm_a': 'State space model state compression', - 'ssm_d': 'State space model skip connection', - 'ssm_dt': 'State space model time step', - 'ssm_out': 'State space model output projection', - 'blk': 'Block', - 'enc': 'Encoder', - 'dec': 'Decoder', + "token_embd": "Token embedding", + "pos_embd": "Position embedding", + "output_norm": "Output normalization", + "output": "Output", + "attn_norm": "Attention normalization", + "attn_norm_2": "Attention normalization", + "attn_qkv": "Attention query-key-value", + "attn_q": "Attention query", + "attn_k": "Attention key", + "attn_v": "Attention value", + "attn_output": "Attention output", + "ffn_norm": "Feed-forward network normalization", + "ffn_up": 'Feed-forward network "up"', + "ffn_gate": 'Feed-forward network "gate"', + "ffn_down": 'Feed-forward network "down"', + "ffn_gate_inp": "Expert-routing layer for the Feed-forward network in Mixture of Expert models", + "ffn_gate_exp": 'Feed-forward network "gate" layer per expert in Mixture of Expert models', + "ffn_down_exp": 'Feed-forward network "down" layer per expert in Mixture of Expert models', + "ffn_up_exp": 'Feed-forward network "up" layer per expert in Mixture of Expert models', + "ssm_in": "State space model input projections", + "ssm_conv1d": "State space model rolling/shift", + "ssm_x": "State space model selective parametrization", + "ssm_a": "State space model state compression", + "ssm_d": "State space model skip connection", + "ssm_dt": "State space model time step", + "ssm_out": "State space model output projection", + "blk": "Block", + "enc": "Encoder", + "dec": "Decoder", } expanded_words = [] @@ -222,37 +271,42 @@ def translate_tensor_name(name): else: expanded_words.append(word.title()) - return ' '.join(expanded_words) + return " ".join(expanded_words) def dump_markdown_metadata(reader: GGUFReader, args: argparse.Namespace) -> None: host_endian, file_endian = get_file_host_endian(reader) markdown_content = "" - markdown_content += f'# {args.model} - GGUF Internal File Dump\n\n' - markdown_content += f'- Endian: {file_endian} endian\n' - markdown_content += '\n' - markdown_content += '## Key Value Metadata Store\n\n' - markdown_content += f'There are {len(reader.fields)} key-value pairs in this file\n' - markdown_content += '\n' + markdown_content += f"# {args.model} - GGUF Internal File Dump\n\n" + markdown_content += f"- Endian: {file_endian} endian\n" + markdown_content += "\n" + markdown_content += "## Key Value Metadata Store\n\n" + markdown_content += f"There are {len(reader.fields)} key-value pairs in this file\n" + markdown_content += "\n" kv_dump_table: list[dict[str, str | int]] = [] for n, field in enumerate(reader.fields.values(), 1): if not field.types: - pretty_type = 'N/A' + pretty_type = "N/A" elif field.types[0] == GGUFValueType.ARRAY: nest_count = len(field.types) - 1 - pretty_type = '[' * nest_count + str(field.types[-1].name) + ']' * nest_count + pretty_type = ( + "[" * nest_count + str(field.types[-1].name) + "]" * nest_count + ) else: pretty_type = str(field.types[-1].name) def escape_markdown_inline_code(value_string): # Find the longest contiguous sequence of backticks in the string then # wrap string with appropriate number of backticks required to escape it - max_backticks = max((len(match.group(0)) for match in re.finditer(r'`+', value_string)), default=0) - inline_code_marker = '`' * (max_backticks + 1) + max_backticks = max( + (len(match.group(0)) for match in re.finditer(r"`+", value_string)), + default=0, + ) + inline_code_marker = "`" * (max_backticks + 1) # If the string starts or ends with a backtick, add a space at the beginning and end - if value_string.startswith('`') or value_string.endswith('`'): + if value_string.startswith("`") or value_string.endswith("`"): value_string = f" {value_string} " return f"{inline_code_marker}{value_string}{inline_code_marker}" @@ -263,10 +317,14 @@ def escape_markdown_inline_code(value_string): curr_type = field.types[0] if curr_type == GGUFValueType.STRING: truncate_length = 60 - value_string = str(bytes(field.parts[-1]), encoding='utf-8') + value_string = str(bytes(field.parts[-1]), encoding="utf-8") if len(value_string) > truncate_length: - head = escape_markdown_inline_code(value_string[:truncate_length // 2]) - tail = escape_markdown_inline_code(value_string[-truncate_length // 2:]) + head = escape_markdown_inline_code( + value_string[: truncate_length // 2] + ) + tail = escape_markdown_inline_code( + value_string[-truncate_length // 2 :] + ) value = "{head}...{tail}".format(head=head, tail=tail) else: value = escape_markdown_inline_code(value_string) @@ -281,10 +339,19 @@ def escape_markdown_inline_code(value_string): render_element = min(5, total_elements) for element_pos in range(render_element): truncate_length = 30 - value_string = str(bytes(field.parts[-1 - (total_elements - element_pos - 1) * 2]), encoding='utf-8') + value_string = str( + bytes( + field.parts[-1 - (total_elements - element_pos - 1) * 2] + ), + encoding="utf-8", + ) if len(value_string) > truncate_length: - head = escape_markdown_inline_code(value_string[:truncate_length // 2]) - tail = escape_markdown_inline_code(value_string[-truncate_length // 2:]) + head = escape_markdown_inline_code( + value_string[: truncate_length // 2] + ) + tail = escape_markdown_inline_code( + value_string[-truncate_length // 2 :] + ) value = "{head}...{tail}".format(head=head, tail=tail) else: value = escape_markdown_inline_code(value_string) @@ -293,21 +360,33 @@ def escape_markdown_inline_code(value_string): elif curr_type in reader.gguf_scalar_to_np: render_element = min(7, total_elements) for element_pos in range(render_element): - array_elements.append(str(field.parts[-1 - (total_elements - element_pos - 1)][0])) + array_elements.append( + str(field.parts[-1 - (total_elements - element_pos - 1)][0]) + ) value = f'[ {", ".join(array_elements).strip()}{", ..." if total_elements > len(array_elements) else ""} ]' - kv_dump_table.append({"n":n, "pretty_type":pretty_type, "total_elements":total_elements, "field_name":field.name, "value":value}) + kv_dump_table.append( + { + "n": n, + "pretty_type": pretty_type, + "total_elements": total_elements, + "field_name": field.name, + "value": value, + } + ) kv_dump_table_header_map = [ - {'key_name':'n', 'header_name':'POS', 'align':'right'}, - {'key_name':'pretty_type', 'header_name':'TYPE', 'align':'left'}, - {'key_name':'total_elements', 'header_name':'Count', 'align':'right'}, - {'key_name':'field_name', 'header_name':'Key', 'align':'left'}, - {'key_name':'value', 'header_name':'Value', 'align':'left'}, + {"key_name": "n", "header_name": "POS", "align": "right"}, + {"key_name": "pretty_type", "header_name": "TYPE", "align": "left"}, + {"key_name": "total_elements", "header_name": "Count", "align": "right"}, + {"key_name": "field_name", "header_name": "Key", "align": "left"}, + {"key_name": "value", "header_name": "Value", "align": "left"}, ] - markdown_content += markdown_table_with_alignment_support(kv_dump_table_header_map, kv_dump_table) + markdown_content += markdown_table_with_alignment_support( + kv_dump_table_header_map, kv_dump_table + ) markdown_content += "\n" @@ -320,15 +399,17 @@ def escape_markdown_inline_code(value_string): # Parsing Tensors Record for key, tensor in enumerate(reader.tensors): - tensor_components = tensor.name.split('.') + tensor_components = tensor.name.split(".") # Classify Tensor Group tensor_group_name = "base" - if tensor_components[0] == 'blk': + if tensor_components[0] == "blk": tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}" - elif tensor_components[0] in ['enc', 'dec'] and tensor_components[1] == 'blk': + elif ( + tensor_components[0] in ["enc", "dec"] and tensor_components[1] == "blk" + ): tensor_group_name = f"{tensor_components[0]}.{tensor_components[1]}.{tensor_components[2]}" - elif tensor_components[0] in ['enc', 'dec']: + elif tensor_components[0] in ["enc", "dec"]: tensor_group_name = f"{tensor_components[0]}" # Check if new Tensor Group @@ -341,9 +422,11 @@ def escape_markdown_inline_code(value_string): tensor_name_to_key[tensor.name] = key # Tensors Mapping Dump - markdown_content += f'## Tensors Overview {element_count_rounded_notation(total_elements)} Elements\n\n' - markdown_content += f'Total number of elements in all tensors: {total_elements} Elements\n' - markdown_content += '\n' + markdown_content += f"## Tensors Overview {element_count_rounded_notation(total_elements)} Elements\n\n" + markdown_content += ( + f"Total number of elements in all tensors: {total_elements} Elements\n" + ) + markdown_content += "\n" for group in tensor_prefix_order: tensors = tensor_groups[group] @@ -353,24 +436,41 @@ def escape_markdown_inline_code(value_string): markdown_content += "\n" markdown_content += "### Tensor Data Offset\n" - markdown_content += '\n' - markdown_content += 'This table contains the offset and data segment relative to start of file\n' - markdown_content += '\n' + markdown_content += "\n" + markdown_content += "This table contains the offset and data segment relative to start of file\n" + markdown_content += "\n" tensor_mapping_table: list[dict[str, str | int]] = [] for key, tensor in enumerate(reader.tensors): - data_offset_pretty = '{0:#16x}'.format(tensor.data_offset) - data_size_pretty = '{0:#16x}'.format(tensor.n_bytes) - tensor_mapping_table.append({"t_id":key, "layer_name":tensor.name, "data_offset":data_offset_pretty, "data_size":data_size_pretty}) + data_offset_pretty = "{0:#16x}".format(tensor.data_offset) + data_size_pretty = "{0:#16x}".format(tensor.n_bytes) + tensor_mapping_table.append( + { + "t_id": key, + "layer_name": tensor.name, + "data_offset": data_offset_pretty, + "data_size": data_size_pretty, + } + ) tensors_mapping_table_header_map = [ - {'key_name':'t_id', 'header_name':'T_ID', 'align':'right'}, - {'key_name':'layer_name', 'header_name':'Tensor Layer Name', 'align':'left'}, - {'key_name':'data_offset', 'header_name':'Data Offset (B)', 'align':'right'}, - {'key_name':'data_size', 'header_name':'Data Size (B)', 'align':'right'}, + {"key_name": "t_id", "header_name": "T_ID", "align": "right"}, + { + "key_name": "layer_name", + "header_name": "Tensor Layer Name", + "align": "left", + }, + { + "key_name": "data_offset", + "header_name": "Data Offset (B)", + "align": "right", + }, + {"key_name": "data_size", "header_name": "Data Size (B)", "align": "right"}, ] - markdown_content += markdown_table_with_alignment_support(tensors_mapping_table_header_map, tensor_mapping_table) + markdown_content += markdown_table_with_alignment_support( + tensors_mapping_table_header_map, tensor_mapping_table + ) markdown_content += "\n" for group in tensor_prefix_order: @@ -384,35 +484,81 @@ def escape_markdown_inline_code(value_string): prettify_element_count_size: int = 1 prettify_dimension_max_widths: dict[int, int] = {} for tensor in tensors: - prettify_element_est_count_size = max(prettify_element_est_count_size, len(str(element_count_rounded_notation(tensor.n_elements)))) - prettify_element_count_size = max(prettify_element_count_size, len(str(tensor.n_elements))) - for i, dimension_size in enumerate(list(tensor.shape) + [1] * (4 - len(tensor.shape))): - prettify_dimension_max_widths[i] = max(prettify_dimension_max_widths.get(i,1), len(str(dimension_size))) + prettify_element_est_count_size = max( + prettify_element_est_count_size, + len(str(element_count_rounded_notation(tensor.n_elements))), + ) + prettify_element_count_size = max( + prettify_element_count_size, len(str(tensor.n_elements)) + ) + for i, dimension_size in enumerate( + list(tensor.shape) + [1] * (4 - len(tensor.shape)) + ): + prettify_dimension_max_widths[i] = max( + prettify_dimension_max_widths.get(i, 1), + len(str(dimension_size)), + ) # Generate Tensor Layer Table Content tensor_dump_table: list[dict[str, str | int]] = [] for tensor in tensors: - human_friendly_name = translate_tensor_name(tensor.name.replace(".weight", ".(W)").replace(".bias", ".(B)")) - pretty_dimension = ' x '.join(f'{str(d):>{prettify_dimension_max_widths[i]}}' for i, d in enumerate(list(tensor.shape) + [1] * (4 - len(tensor.shape)))) + human_friendly_name = translate_tensor_name( + tensor.name.replace(".weight", ".(W)").replace(".bias", ".(B)") + ) + pretty_dimension = " x ".join( + f"{str(d):>{prettify_dimension_max_widths[i]}}" + for i, d in enumerate( + list(tensor.shape) + [1] * (4 - len(tensor.shape)) + ) + ) element_count_est = f"({element_count_rounded_notation(tensor.n_elements):>{prettify_element_est_count_size}})" element_count_string = f"{element_count_est} {tensor.n_elements:>{prettify_element_count_size}}" type_name_string = f"{tensor.tensor_type.name}" - tensor_dump_table.append({"t_id":tensor_name_to_key[tensor.name], "layer_name":tensor.name, "human_layer_name":human_friendly_name, "element_count":element_count_string, "pretty_dimension":pretty_dimension, "tensor_type":type_name_string}) + tensor_dump_table.append( + { + "t_id": tensor_name_to_key[tensor.name], + "layer_name": tensor.name, + "human_layer_name": human_friendly_name, + "element_count": element_count_string, + "pretty_dimension": pretty_dimension, + "tensor_type": type_name_string, + } + ) tensor_dump_table_header_map = [ - {'key_name':'t_id', 'header_name':'T_ID', 'align':'right'}, - {'key_name':'layer_name', 'header_name':'Tensor Layer Name', 'align':'left'}, - {'key_name':'human_layer_name', 'header_name':'Human Friendly Tensor Layer Name', 'align':'left'}, - {'key_name':'element_count', 'header_name':'Elements', 'align':'left'}, - {'key_name':'pretty_dimension', 'header_name':'Shape', 'align':'left'}, - {'key_name':'tensor_type', 'header_name':'Type', 'align':'left'}, + {"key_name": "t_id", "header_name": "T_ID", "align": "right"}, + { + "key_name": "layer_name", + "header_name": "Tensor Layer Name", + "align": "left", + }, + { + "key_name": "human_layer_name", + "header_name": "Human Friendly Tensor Layer Name", + "align": "left", + }, + { + "key_name": "element_count", + "header_name": "Elements", + "align": "left", + }, + { + "key_name": "pretty_dimension", + "header_name": "Shape", + "align": "left", + }, + {"key_name": "tensor_type", "header_name": "Type", "align": "left"}, ] - markdown_content += markdown_table_with_alignment_support(tensor_dump_table_header_map, tensor_dump_table) + markdown_content += markdown_table_with_alignment_support( + tensor_dump_table_header_map, tensor_dump_table + ) markdown_content += "\n" markdown_content += f"- Total elements in {group}: ({element_count_rounded_notation(group_elements):>4}) {group_elements}\n" - markdown_content += f"- Percentage of total elements: {group_percentage:.2f}%\n" + markdown_content += ( + f"- Percentage of total elements: {group_percentage:.2f}%\n" + ) markdown_content += "\n\n" print(markdown_content) # noqa: NP100 @@ -420,23 +566,44 @@ def escape_markdown_inline_code(value_string): def main() -> None: parser = argparse.ArgumentParser(description="Dump GGUF file metadata") - parser.add_argument("model", type=str, help="GGUF format model filename") - parser.add_argument("--no-tensors", action="store_true", help="Don't dump tensor metadata") - parser.add_argument("--json", action="store_true", help="Produce JSON output") - parser.add_argument("--json-array", action="store_true", help="Include full array values in JSON output (long)") - parser.add_argument("--data-offset", action="store_true", help="Start of data offset") - parser.add_argument("--data-alignment", action="store_true", help="Data alignment applied globally to data field") - parser.add_argument("--markdown", action="store_true", help="Produce markdown output") - parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + parser.add_argument("model", type=str, help="GGUF format model filename") + parser.add_argument( + "--no-tensors", action="store_true", help="Don't dump tensor metadata" + ) + parser.add_argument("--json", action="store_true", help="Produce JSON output") + parser.add_argument( + "--json-array", + action="store_true", + help="Include full array values in JSON output (long)", + ) + parser.add_argument( + "--data-offset", action="store_true", help="Start of data offset" + ) + parser.add_argument( + "--data-alignment", + action="store_true", + help="Data alignment applied globally to data field", + ) + parser.add_argument( + "--markdown", action="store_true", help="Produce markdown output" + ) + parser.add_argument( + "--verbose", action="store_true", help="increase output verbosity" + ) args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) - if not args.json and not args.markdown and not args.data_offset and not args.data_alignment: - logger.info(f'* Loading: {args.model}') + if ( + not args.json + and not args.markdown + and not args.data_offset + and not args.data_alignment + ): + logger.info(f"* Loading: {args.model}") - reader = GGUFReader(args.model, 'r') + reader = GGUFReader(args.model, "r") if args.json: dump_metadata_json(reader, args) @@ -450,5 +617,5 @@ def main() -> None: dump_metadata(reader, args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/smallthinker/gguf-py/gguf/scripts/gguf_editor_gui.py b/smallthinker/gguf-py/gguf/scripts/gguf_editor_gui.py index 05f4db0f..a21d693e 100755 --- a/smallthinker/gguf-py/gguf/scripts/gguf_editor_gui.py +++ b/smallthinker/gguf-py/gguf/scripts/gguf_editor_gui.py @@ -13,16 +13,33 @@ import numpy as np from PySide6.QtWidgets import ( - QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, - QPushButton, QLabel, QLineEdit, QFileDialog, QTableWidget, - QTableWidgetItem, QComboBox, QMessageBox, QTabWidget, - QTextEdit, QFormLayout, - QHeaderView, QDialog, QDialogButtonBox + QApplication, + QMainWindow, + QWidget, + QVBoxLayout, + QHBoxLayout, + QPushButton, + QLabel, + QLineEdit, + QFileDialog, + QTableWidget, + QTableWidgetItem, + QComboBox, + QMessageBox, + QTabWidget, + QTextEdit, + QFormLayout, + QHeaderView, + QDialog, + QDialogButtonBox, ) from PySide6.QtCore import Qt # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): +if ( + "NO_LOCAL_GGUF" not in os.environ + and (Path(__file__).parent.parent.parent.parent / "gguf-py").exists() +): sys.path.insert(0, str(Path(__file__).parent.parent.parent)) import gguf @@ -43,7 +60,7 @@ TOKENIZER_LINKED_KEYS = [ gguf.Keys.Tokenizer.LIST, gguf.Keys.Tokenizer.TOKEN_TYPE, - gguf.Keys.Tokenizer.SCORES + gguf.Keys.Tokenizer.SCORES, ] @@ -79,7 +96,9 @@ def __init__(self, tokens, token_types, scores, parent=None): # Add page controls self.page_size = 100 # Show 100 items per page self.current_page = 0 - self.total_pages = max(1, (len(self.tokens) + self.page_size - 1) // self.page_size) + self.total_pages = max( + 1, (len(self.tokens) + self.page_size - 1) // self.page_size + ) self.page_label = QLabel(f"Page 1 of {self.total_pages}") filter_layout.addWidget(self.page_label) @@ -98,10 +117,18 @@ def __init__(self, tokens, token_types, scores, parent=None): self.tokens_table = QTableWidget() self.tokens_table.setColumnCount(4) self.tokens_table.setHorizontalHeaderLabels(["Index", "Token", "Type", "Score"]) - self.tokens_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) - self.tokens_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) - self.tokens_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) - self.tokens_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents) + self.tokens_table.horizontalHeader().setSectionResizeMode( + 0, QHeaderView.ResizeMode.ResizeToContents + ) + self.tokens_table.horizontalHeader().setSectionResizeMode( + 1, QHeaderView.ResizeMode.Stretch + ) + self.tokens_table.horizontalHeader().setSectionResizeMode( + 2, QHeaderView.ResizeMode.ResizeToContents + ) + self.tokens_table.horizontalHeader().setSectionResizeMode( + 3, QHeaderView.ResizeMode.ResizeToContents + ) layout.addWidget(self.tokens_table) @@ -121,7 +148,9 @@ def __init__(self, tokens, token_types, scores, parent=None): layout.addLayout(controls_layout) # Buttons - buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel + ) buttons.accepted.connect(self.accept) buttons.rejected.connect(self.reject) layout.addWidget(buttons) @@ -147,7 +176,9 @@ def apply_filter(self): self.filtered_indices.append(i) # Reset to first page and reload - self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + self.total_pages = max( + 1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size + ) self.current_page = 0 self.page_label.setText(f"Page 1 of {self.total_pages}") self.load_page() @@ -156,14 +187,18 @@ def previous_page(self): """Go to the previous page of results.""" if self.current_page > 0: self.current_page -= 1 - self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + self.page_label.setText( + f"Page {self.current_page + 1} of {self.total_pages}" + ) self.load_page() def next_page(self): """Go to the next page of results.""" if self.current_page < self.total_pages - 1: self.current_page += 1 - self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + self.page_label.setText( + f"Page {self.current_page + 1} of {self.total_pages}" + ) self.load_page() def load_page(self): @@ -182,7 +217,9 @@ def load_page(self): # Index index_item = QTableWidgetItem(str(orig_idx)) - index_item.setData(Qt.ItemDataRole.UserRole, orig_idx) # Store original index + index_item.setData( + Qt.ItemDataRole.UserRole, orig_idx + ) # Store original index index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable) self.tokens_table.setItem(row, 0, index_item) @@ -191,7 +228,9 @@ def load_page(self): self.tokens_table.setItem(row, 1, token_item) # Token Type - token_type = self.token_types[orig_idx] if orig_idx < len(self.token_types) else 0 + token_type = ( + self.token_types[orig_idx] if orig_idx < len(self.token_types) else 0 + ) try: enum_val = TokenType(token_type) display_text = f"{enum_val.name} ({token_type})" @@ -223,7 +262,9 @@ def handle_cell_double_click(self, row, column): def edit_token_type(self, row, orig_idx): """Edit a token type using a dialog with a dropdown of all enum options.""" - current_value = self.token_types[orig_idx] if orig_idx < len(self.token_types) else 0 + current_value = ( + self.token_types[orig_idx] if orig_idx < len(self.token_types) else 0 + ) # Create a dialog with enum options dialog = QDialog(self) @@ -244,7 +285,9 @@ def edit_token_type(self, row, orig_idx): layout.addWidget(combo) - buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel + ) buttons.accepted.connect(dialog.accept) buttons.rejected.connect(dialog.reject) layout.addWidget(buttons) @@ -279,7 +322,9 @@ def add_token(self): self.filtered_indices.append(orig_idx) # Update pagination - self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + self.total_pages = max( + 1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size + ) # Go to the last page to show the new item self.current_page = self.total_pages - 1 @@ -325,7 +370,9 @@ def remove_selected(self): self.filtered_indices.append(i) # Update pagination - self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + self.total_pages = max( + 1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size + ) self.current_page = min(self.current_page, self.total_pages - 1) self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") @@ -379,7 +426,9 @@ def __init__(self, array_values, element_type, key=None, parent=None): # Add page controls for large arrays self.page_size = 100 # Show 100 items per page self.current_page = 0 - self.total_pages = max(1, (len(array_values) + self.page_size - 1) // self.page_size) + self.total_pages = max( + 1, (len(array_values) + self.page_size - 1) // self.page_size + ) self.page_label = QLabel(f"Page 1 of {self.total_pages}") filter_layout.addWidget(self.page_label) @@ -401,14 +450,24 @@ def __init__(self, array_values, element_type, key=None, parent=None): if self.enum_type is not None: self.items_table.setColumnCount(3) self.items_table.setHorizontalHeaderLabels(["Index", "Value", "Actions"]) - self.items_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) - self.items_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) - self.items_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) + self.items_table.horizontalHeader().setSectionResizeMode( + 0, QHeaderView.ResizeMode.ResizeToContents + ) + self.items_table.horizontalHeader().setSectionResizeMode( + 1, QHeaderView.ResizeMode.Stretch + ) + self.items_table.horizontalHeader().setSectionResizeMode( + 2, QHeaderView.ResizeMode.ResizeToContents + ) else: self.items_table.setColumnCount(2) self.items_table.setHorizontalHeaderLabels(["Index", "Value"]) - self.items_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents) - self.items_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.Stretch) + self.items_table.horizontalHeader().setSectionResizeMode( + 0, QHeaderView.ResizeMode.ResizeToContents + ) + self.items_table.horizontalHeader().setSectionResizeMode( + 1, QHeaderView.ResizeMode.Stretch + ) layout.addWidget(self.items_table) @@ -434,7 +493,9 @@ def __init__(self, array_values, element_type, key=None, parent=None): layout.addLayout(controls_layout) # Buttons - buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel + ) buttons.accepted.connect(self.accept) buttons.rejected.connect(self.reject) layout.addWidget(buttons) @@ -473,7 +534,9 @@ def apply_filter(self): self.filtered_indices.append(i) # Reset to first page and reload - self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + self.total_pages = max( + 1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size + ) self.current_page = 0 self.page_label.setText(f"Page 1 of {self.total_pages}") self.load_page() @@ -482,14 +545,18 @@ def previous_page(self): """Go to the previous page of results.""" if self.current_page > 0: self.current_page -= 1 - self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + self.page_label.setText( + f"Page {self.current_page + 1} of {self.total_pages}" + ) self.load_page() def next_page(self): """Go to the next page of results.""" if self.current_page < self.total_pages - 1: self.current_page += 1 - self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") + self.page_label.setText( + f"Page {self.current_page + 1} of {self.total_pages}" + ) self.load_page() def load_page(self): @@ -509,7 +576,9 @@ def load_page(self): # Index index_item = QTableWidgetItem(str(orig_idx)) - index_item.setData(Qt.ItemDataRole.UserRole, orig_idx) # Store original index + index_item.setData( + Qt.ItemDataRole.UserRole, orig_idx + ) # Store original index index_item.setFlags(index_item.flags() & ~Qt.ItemFlag.ItemIsEditable) self.items_table.setItem(row, 0, index_item) @@ -556,7 +625,12 @@ def edit_array_enum_value(self): # Get the original index from the table item orig_item = self.items_table.item(row, 0) new_item = self.items_table.item(row, 1) - if orig_item and new_item and self.enum_type and self.edit_enum_value(row, self.enum_type): + if ( + orig_item + and new_item + and self.enum_type + and self.edit_enum_value(row, self.enum_type) + ): orig_idx = orig_item.data(Qt.ItemDataRole.UserRole) new_value = new_item.data(Qt.ItemDataRole.UserRole) # Update the stored value in the array @@ -573,7 +647,9 @@ def bulk_edit_selected(self): selected_rows.add(item.row()) if not selected_rows: - QMessageBox.information(self, "No Selection", "Please select at least one row to edit.") + QMessageBox.information( + self, "No Selection", "Please select at least one row to edit." + ) return # Create a dialog with enum options @@ -589,7 +665,9 @@ def bulk_edit_selected(self): layout.addWidget(combo) - buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel + ) buttons.accepted.connect(dialog.accept) buttons.rejected.connect(dialog.reject) layout.addWidget(buttons) @@ -631,7 +709,9 @@ def add_item(self): self.filtered_indices.append(orig_idx) # Update pagination - self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + self.total_pages = max( + 1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size + ) # Go to the last page to show the new item self.current_page = self.total_pages - 1 @@ -685,7 +765,9 @@ def remove_selected(self): self.filtered_indices.append(i) # Update pagination - self.total_pages = max(1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size) + self.total_pages = max( + 1, (len(self.filtered_indices) + self.page_size - 1) // self.page_size + ) self.current_page = min(self.current_page, self.total_pages - 1) self.page_label.setText(f"Page {self.current_page + 1} of {self.total_pages}") @@ -726,7 +808,9 @@ def edit_enum_value(self, row: int, enum_type: Type[enum.Enum]): layout.addWidget(combo) - buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel + ) buttons.accepted.connect(dialog.accept) buttons.rejected.connect(dialog.reject) layout.addWidget(buttons) @@ -776,7 +860,9 @@ def __init__(self, parent=None): layout.addLayout(form_layout) - buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel + ) buttons.accepted.connect(self.accept) buttons.rejected.connect(self.reject) layout.addWidget(buttons) @@ -802,7 +888,7 @@ def get_data(self) -> Tuple[str, GGUFValueType, Any]: elif value_type == GGUFValueType.FLOAT32: value = np.float32(float(value_text)) elif value_type == GGUFValueType.BOOL: - value = value_text.lower() in ('true', 'yes', '1') + value = value_text.lower() in ("true", "yes", "1") elif value_type == GGUFValueType.STRING: value = value_text else: @@ -860,11 +946,21 @@ def setup_ui(self): # Metadata table self.metadata_table = QTableWidget() self.metadata_table.setColumnCount(4) - self.metadata_table.setHorizontalHeaderLabels(["Key", "Type", "Value", "Actions"]) - self.metadata_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch) - self.metadata_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents) - self.metadata_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.Stretch) - self.metadata_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents) + self.metadata_table.setHorizontalHeaderLabels( + ["Key", "Type", "Value", "Actions"] + ) + self.metadata_table.horizontalHeader().setSectionResizeMode( + 0, QHeaderView.ResizeMode.Stretch + ) + self.metadata_table.horizontalHeader().setSectionResizeMode( + 1, QHeaderView.ResizeMode.ResizeToContents + ) + self.metadata_table.horizontalHeader().setSectionResizeMode( + 2, QHeaderView.ResizeMode.Stretch + ) + self.metadata_table.horizontalHeader().setSectionResizeMode( + 3, QHeaderView.ResizeMode.ResizeToContents + ) metadata_layout.addWidget(self.metadata_table) # Metadata controls @@ -884,12 +980,24 @@ def setup_ui(self): self.tensors_table = QTableWidget() self.tensors_table.setColumnCount(5) - self.tensors_table.setHorizontalHeaderLabels(["Name", "Type", "Shape", "Elements", "Size (bytes)"]) - self.tensors_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.Stretch) - self.tensors_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents) - self.tensors_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeMode.ResizeToContents) - self.tensors_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.ResizeMode.ResizeToContents) - self.tensors_table.horizontalHeader().setSectionResizeMode(4, QHeaderView.ResizeMode.ResizeToContents) + self.tensors_table.setHorizontalHeaderLabels( + ["Name", "Type", "Shape", "Elements", "Size (bytes)"] + ) + self.tensors_table.horizontalHeader().setSectionResizeMode( + 0, QHeaderView.ResizeMode.Stretch + ) + self.tensors_table.horizontalHeader().setSectionResizeMode( + 1, QHeaderView.ResizeMode.ResizeToContents + ) + self.tensors_table.horizontalHeader().setSectionResizeMode( + 2, QHeaderView.ResizeMode.ResizeToContents + ) + self.tensors_table.horizontalHeader().setSectionResizeMode( + 3, QHeaderView.ResizeMode.ResizeToContents + ) + self.tensors_table.horizontalHeader().setSectionResizeMode( + 4, QHeaderView.ResizeMode.ResizeToContents + ) tensors_layout.addWidget(self.tensors_table) # Add tabs to tab widget @@ -907,7 +1015,7 @@ def load_file(self, file_path): self.statusBar().showMessage(f"Loading {file_path}...") QApplication.processEvents() - self.reader = GGUFReader(file_path, 'r') + self.reader = GGUFReader(file_path, "r") self.current_file = file_path self.file_path_edit.setText(file_path) @@ -944,7 +1052,7 @@ def load_metadata(self): # Disconnect to prevent triggering during loading if self.on_metadata_changed_is_connected: with warnings.catch_warnings(): - warnings.filterwarnings('ignore') + warnings.filterwarnings("ignore") self.metadata_table.itemChanged.disconnect(self.on_metadata_changed) self.on_metadata_changed_is_connected = False @@ -966,7 +1074,7 @@ def load_metadata(self): enum_type = self.get_enum_for_key(key) if enum_type is not None and field.types[-1] == GGUFValueType.INT32: element_type = enum_type.__name__ - type_str = '[' * nest_count + element_type + ']' * nest_count + type_str = "[" * nest_count + element_type + "]" * nest_count else: type_str = str(field.types[0].name) # Check if this is an enum field @@ -1037,11 +1145,16 @@ def extract_array_values(self, field: ReaderField) -> list: if curr_type == GGUFValueType.STRING: for element_pos in range(total_elements): - value_string = str(bytes(field.parts[-1 - (total_elements - element_pos - 1) * 2]), encoding='utf-8') + value_string = str( + bytes(field.parts[-1 - (total_elements - element_pos - 1) * 2]), + encoding="utf-8", + ) array_values.append(value_string) elif self.reader and curr_type in self.reader.gguf_scalar_to_np: for element_pos in range(total_elements): - array_values.append(field.parts[-1 - (total_elements - element_pos - 1)][0]) + array_values.append( + field.parts[-1 - (total_elements - element_pos - 1)][0] + ) return array_values @@ -1066,7 +1179,7 @@ def format_field_value(self, field: ReaderField) -> str: if len(field.types) == 1: curr_type = field.types[0] if curr_type == GGUFValueType.STRING: - return str(bytes(field.parts[-1]), encoding='utf-8') + return str(bytes(field.parts[-1]), encoding="utf-8") elif self.reader and curr_type in self.reader.gguf_scalar_to_np: value = field.parts[-1][0] # Check if this field has an enum type @@ -1085,7 +1198,9 @@ def format_field_value(self, field: ReaderField) -> str: if enum_type is not None: array_elements = [] for i in range(render_element): - array_elements.append(self.format_enum_value(array_values[i], enum_type)) + array_elements.append( + self.format_enum_value(array_values[i], enum_type) + ) else: array_elements = [str(array_values[i]) for i in range(render_element)] @@ -1158,9 +1273,9 @@ def on_metadata_changed(self, item): converted_value = enum_val.value except (KeyError, AttributeError): # Check if it's a number or "NAME (value)" format - if '(' in new_value and ')' in new_value: + if "(" in new_value and ")" in new_value: # Extract the value from "NAME (value)" format - value_part = new_value.split('(')[1].split(')')[0].strip() + value_part = new_value.split("(")[1].split(")")[0].strip() converted_value = int(value_part) else: # Try to convert directly to int @@ -1184,7 +1299,8 @@ def on_metadata_changed(self, item): self, f"Invalid Enum Value ({e})", f"'{new_value}' is not a valid {enum_type.__name__} value.\n" - f"Valid values are: {', '.join(v.name for v in enum_type)}") + f"Valid values are: {', '.join(v.name for v in enum_type)}", + ) # Revert to original value original_value = self.format_field_value(field) @@ -1208,7 +1324,7 @@ def on_metadata_changed(self, item): elif value_type == GGUFValueType.FLOAT32: converted_value = np.float32(float(new_value)) elif value_type == GGUFValueType.BOOL: - converted_value = new_value.lower() in ('true', 'yes', '1') + converted_value = new_value.lower() in ("true", "yes", "1") elif value_type == GGUFValueType.STRING: converted_value = new_value else: @@ -1221,7 +1337,11 @@ def on_metadata_changed(self, item): self.statusBar().showMessage(f"Changed {key} to {new_value}") except ValueError: - QMessageBox.warning(self, "Invalid Value", f"The value '{new_value}' is not valid for type {value_type.name}") + QMessageBox.warning( + self, + "Invalid Value", + f"The value '{new_value}' is not valid for type {value_type.name}", + ) # Revert to original value original_value = self.format_field_value(field) @@ -1233,9 +1353,11 @@ def remove_metadata(self): row = button.property("row") reply = QMessageBox.question( - self, "Confirm Removal", + self, + "Confirm Removal", f"Are you sure you want to remove the metadata key '{key}'?", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, QMessageBox.StandardButton.No + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, ) if reply == QMessageBox.StandardButton.Yes: @@ -1287,7 +1409,9 @@ def edit_metadata_enum(self): layout.addWidget(combo) - buttons = QDialogButtonBox(QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel) + buttons = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel + ) buttons.accepted.connect(dialog.accept) buttons.rejected.connect(dialog.reject) layout.addWidget(buttons) @@ -1337,7 +1461,10 @@ def edit_array_metadata(self): new_values = dialog.get_array_values() # Store the change - self.metadata_changes[key] = (GGUFValueType.ARRAY, (element_type, new_values)) + self.metadata_changes[key] = ( + GGUFValueType.ARRAY, + (element_type, new_values), + ) self.modified = True # Update display @@ -1364,7 +1491,9 @@ def edit_tokenizer_metadata(self, trigger_key): # Extract values from each field tokens = self.extract_array_values(tokens_field) if tokens_field else [] - token_types = self.extract_array_values(token_types_field) if token_types_field else [] + token_types = ( + self.extract_array_values(token_types_field) if token_types_field else [] + ) scores = self.extract_array_values(scores_field) if scores_field else [] # Apply any pending changes @@ -1384,26 +1513,28 @@ def edit_tokenizer_metadata(self, trigger_key): if tokens_field: self.metadata_changes[gguf.Keys.Tokenizer.LIST] = ( GGUFValueType.ARRAY, - (tokens_field.types[1], new_tokens) + (tokens_field.types[1], new_tokens), ) if token_types_field: self.metadata_changes[gguf.Keys.Tokenizer.TOKEN_TYPE] = ( GGUFValueType.ARRAY, - (token_types_field.types[1], new_token_types) + (token_types_field.types[1], new_token_types), ) if scores_field: self.metadata_changes[gguf.Keys.Tokenizer.SCORES] = ( GGUFValueType.ARRAY, - (scores_field.types[1], new_scores) + (scores_field.types[1], new_scores), ) self.modified = True # Update display for all three fields self.update_tokenizer_display(gguf.Keys.Tokenizer.LIST, new_tokens) - self.update_tokenizer_display(gguf.Keys.Tokenizer.TOKEN_TYPE, new_token_types) + self.update_tokenizer_display( + gguf.Keys.Tokenizer.TOKEN_TYPE, new_token_types + ) self.update_tokenizer_display(gguf.Keys.Tokenizer.SCORES, new_scores) self.statusBar().showMessage("Updated tokenizer data") @@ -1432,7 +1563,9 @@ def add_metadata(self): for row in range(self.metadata_table.rowCount()): orig_item = self.metadata_table.item(row, 0) if orig_item and orig_item.text() == key: - QMessageBox.warning(self, "Duplicate Key", f"Key '{key}' already exists") + QMessageBox.warning( + self, "Duplicate Key", f"Key '{key}' already exists" + ) return # Add to table @@ -1478,7 +1611,11 @@ def save_file(self): QMessageBox.warning(self, "No File Open", "Please open a GGUF file first") return - if not self.modified and not self.metadata_changes and not self.metadata_to_remove: + if ( + not self.modified + and not self.metadata_changes + and not self.metadata_to_remove + ): QMessageBox.information(self, "No Changes", "No changes to save") return @@ -1494,7 +1631,7 @@ def save_file(self): QApplication.processEvents() # Get architecture and endianness from the original file - arch = 'unknown' + arch = "unknown" field = self.reader.get_field(gguf.Keys.General.ARCHITECTURE) if field: arch = field.contents() @@ -1513,7 +1650,10 @@ def save_file(self): # Copy metadata with changes for field in self.reader.fields.values(): # Skip virtual fields and fields written by GGUFWriter - if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'): + if ( + field.name == gguf.Keys.General.ARCHITECTURE + or field.name.startswith("GGUF.") + ): continue # Skip fields marked for removal @@ -1535,7 +1675,9 @@ def save_file(self): sub_type = field.types[-1] if value is not None: - writer.add_key_value(field.name, value, value_type, sub_type=sub_type) + writer.add_key_value( + field.name, value, value_type, sub_type=sub_type + ) # Add new metadata for key, (value_type, value) in self.metadata_changes.items(): @@ -1552,7 +1694,12 @@ def save_file(self): # Add tensors (including data) for tensor in self.reader.tensors: - writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type) + writer.add_tensor( + tensor.name, + tensor.data, + raw_shape=tensor.data.shape, + raw_dtype=tensor.tensor_type, + ) # Write header and metadata writer.open_output_file(Path(file_path)) @@ -1568,13 +1715,15 @@ def save_file(self): # Ask if user wants to open the new file reply = QMessageBox.question( - self, "Open Saved File", + self, + "Open Saved File", "Would you like to open the newly saved file?", - QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, QMessageBox.StandardButton.Yes + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.Yes, ) if reply == QMessageBox.StandardButton.Yes: - self.reader = GGUFReader(file_path, 'r') + self.reader = GGUFReader(file_path, "r") self.current_file = file_path self.file_path_edit.setText(file_path) @@ -1592,8 +1741,12 @@ def save_file(self): def main() -> None: parser = argparse.ArgumentParser(description="GUI GGUF Editor") - parser.add_argument("model_path", nargs="?", help="path to GGUF model file to load at startup") - parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + parser.add_argument( + "model_path", nargs="?", help="path to GGUF model file to load at startup" + ) + parser.add_argument( + "--verbose", action="store_true", help="increase output verbosity" + ) args = parser.parse_args() @@ -1605,17 +1758,18 @@ def main() -> None: # Load model if specified if args.model_path: - if os.path.isfile(args.model_path) and args.model_path.endswith('.gguf'): + if os.path.isfile(args.model_path) and args.model_path.endswith(".gguf"): window.load_file(args.model_path) else: logger.error(f"Invalid model path: {args.model_path}") QMessageBox.warning( window, "Invalid Model Path", - f"The specified file does not exist or is not a GGUF file: {args.model_path}") + f"The specified file does not exist or is not a GGUF file: {args.model_path}", + ) sys.exit(app.exec()) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/smallthinker/gguf-py/gguf/scripts/gguf_hash.py b/smallthinker/gguf-py/gguf/scripts/gguf_hash.py index 3ef98992..12946027 100755 --- a/smallthinker/gguf-py/gguf/scripts/gguf_hash.py +++ b/smallthinker/gguf-py/gguf/scripts/gguf_hash.py @@ -13,7 +13,10 @@ from tqdm import tqdm # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): +if ( + "NO_LOCAL_GGUF" not in os.environ + and (Path(__file__).parent.parent.parent.parent / "gguf-py").exists() +): sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from gguf import GGUFReader # noqa: E402 @@ -22,12 +25,14 @@ logger = logging.getLogger("gguf-hash") # UUID_NAMESPACE_LLAMA_CPP = uuid.uuid5(uuid.NAMESPACE_URL, 'en.wikipedia.org/wiki/Llama.cpp') -UUID_NAMESPACE_LLAMA_CPP = uuid.UUID('ef001206-dadc-5f6d-a15f-3359e577d4e5') +UUID_NAMESPACE_LLAMA_CPP = uuid.UUID("ef001206-dadc-5f6d-a15f-3359e577d4e5") # For more information about what field.parts and field.data represent, # please see the comments in the modify_gguf.py example. -def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar: bool, no_layer: bool) -> None: +def gguf_hash( + reader: GGUFReader, filename: str, disable_progress_bar: bool, no_layer: bool +) -> None: sha1 = hashlib.sha1() sha256 = hashlib.sha256() uuidv5_sha1 = hashlib.sha1() @@ -38,7 +43,9 @@ def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar: bool, no_ for n, tensor in enumerate(reader.tensors, 1): # We don't need these - if tensor.name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): + if tensor.name.endswith( + (".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq") + ): continue # Calculate Tensor Volume @@ -48,13 +55,21 @@ def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar: bool, no_ total_weights += sum_weights_in_tensor # Hash Progress Bar - bar = tqdm(desc="Hashing", total=total_weights, unit="weights", unit_scale=True, disable=disable_progress_bar) + bar = tqdm( + desc="Hashing", + total=total_weights, + unit="weights", + unit_scale=True, + disable=disable_progress_bar, + ) # Hashing Process for tensor in reader.tensors: # We don't need these - if tensor.name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")): + if tensor.name.endswith( + (".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq") + ): continue # Progressbar @@ -67,11 +82,19 @@ def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar: bool, no_ sha1_layer = hashlib.sha1() sha1_layer.update(tensor.data.data) - print("sha1 {0} {1}:{2}".format(sha1_layer.hexdigest(), filename, tensor.name)) # noqa: NP100 + print( + "sha1 {0} {1}:{2}".format( + sha1_layer.hexdigest(), filename, tensor.name + ) + ) # noqa: NP100 sha256_layer = hashlib.sha256() sha256_layer.update(tensor.data.data) - print("sha256 {0} {1}:{2}".format(sha256_layer.hexdigest(), filename, tensor.name)) # noqa: NP100 + print( + "sha256 {0} {1}:{2}".format( + sha256_layer.hexdigest(), filename, tensor.name + ) + ) # noqa: NP100 sha1.update(tensor.data.data) sha256.update(tensor.data.data) @@ -81,22 +104,30 @@ def gguf_hash(reader: GGUFReader, filename: str, disable_progress_bar: bool, no_ bar.close() # Display Hash Output - print("sha1 {0} {1}".format(sha1.hexdigest(), filename)) # noqa: NP100 - print("sha256 {0} {1}".format(sha256.hexdigest(), filename)) # noqa: NP100 - print("uuid {0} {1}".format(uuid.UUID(bytes=uuidv5_sha1.digest()[:16], version=5), filename)) # noqa: NP100 + print("sha1 {0} {1}".format(sha1.hexdigest(), filename)) # noqa: NP100 + print("sha256 {0} {1}".format(sha256.hexdigest(), filename)) # noqa: NP100 + print( + "uuid {0} {1}".format( + uuid.UUID(bytes=uuidv5_sha1.digest()[:16], version=5), filename + ) + ) # noqa: NP100 def main() -> None: parser = argparse.ArgumentParser(description="Dump GGUF file metadata") - parser.add_argument("model", type=str, help="GGUF format model filename") - parser.add_argument("--no-layer", action="store_true", help="exclude per layer hash") - parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + parser.add_argument("model", type=str, help="GGUF format model filename") + parser.add_argument( + "--no-layer", action="store_true", help="exclude per layer hash" + ) + parser.add_argument( + "--verbose", action="store_true", help="increase output verbosity" + ) parser.add_argument("--progressbar", action="store_true", help="enable progressbar") args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) - reader = GGUFReader(args.model, 'r') + reader = GGUFReader(args.model, "r") gguf_hash(reader, args.model, not args.progressbar, args.no_layer) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/smallthinker/gguf-py/gguf/scripts/gguf_new_metadata.py b/smallthinker/gguf-py/gguf/scripts/gguf_new_metadata.py index 63f23003..0e458542 100755 --- a/smallthinker/gguf-py/gguf/scripts/gguf_new_metadata.py +++ b/smallthinker/gguf-py/gguf/scripts/gguf_new_metadata.py @@ -12,7 +12,10 @@ from typing import Any, Sequence, NamedTuple # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): +if ( + "NO_LOCAL_GGUF" not in os.environ + and (Path(__file__).parent.parent.parent.parent / "gguf-py").exists() +): sys.path.insert(0, str(Path(__file__).parent.parent.parent)) import gguf @@ -23,7 +26,7 @@ class MetadataDetails(NamedTuple): type: gguf.GGUFValueType value: Any - description: str = '' + description: str = "" sub_type: gguf.GGUFValueType | None = None @@ -42,20 +45,30 @@ def find_token(token_list: Sequence[int], token: str) -> Sequence[int]: return token_ids -def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new_metadata: dict[str, MetadataDetails], remove_metadata: Sequence[str]) -> None: +def copy_with_new_metadata( + reader: gguf.GGUFReader, + writer: gguf.GGUFWriter, + new_metadata: dict[str, MetadataDetails], + remove_metadata: Sequence[str], +) -> None: for field in reader.fields.values(): # Suppress virtual fields and fields written by GGUFWriter - if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith('GGUF.'): - logger.debug(f'Suppressing {field.name}') + if field.name == gguf.Keys.General.ARCHITECTURE or field.name.startswith( + "GGUF." + ): + logger.debug(f"Suppressing {field.name}") continue # Skip old chat templates if we have new ones - if field.name.startswith(gguf.Keys.Tokenizer.CHAT_TEMPLATE) and gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata: - logger.debug(f'Skipping {field.name}') + if ( + field.name.startswith(gguf.Keys.Tokenizer.CHAT_TEMPLATE) + and gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata + ): + logger.debug(f"Skipping {field.name}") continue if field.name in remove_metadata: - logger.debug(f'Removing {field.name}') + logger.debug(f"Removing {field.name}") continue val_type = field.types[0] @@ -64,16 +77,23 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new val = new_metadata.get(field.name, old_val) if field.name in new_metadata: - logger.debug(f'Modifying {field.name}: "{old_val.value}" -> "{val.value}" {val.description}') + logger.debug( + f'Modifying {field.name}: "{old_val.value}" -> "{val.value}" {val.description}' + ) del new_metadata[field.name] elif val.value is not None: - logger.debug(f'Copying {field.name}') + logger.debug(f"Copying {field.name}") if val.value is not None: - writer.add_key_value(field.name, val.value, val.type, sub_type=sub_type if val.sub_type is None else val.sub_type) + writer.add_key_value( + field.name, + val.value, + val.type, + sub_type=sub_type if val.sub_type is None else val.sub_type, + ) if gguf.Keys.Tokenizer.CHAT_TEMPLATE in new_metadata: - logger.debug('Adding chat template(s)') + logger.debug("Adding chat template(s)") writer.add_chat_template(new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE].value) del new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] @@ -85,7 +105,13 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new for tensor in reader.tensors: total_bytes += tensor.n_bytes - writer.add_tensor_info(tensor.name, tensor.data.shape, tensor.data.dtype, tensor.data.nbytes, tensor.tensor_type) + writer.add_tensor_info( + tensor.name, + tensor.data.shape, + tensor.data.dtype, + tensor.data.nbytes, + tensor.tensor_type, + ) bar = tqdm(desc="Writing", total=total_bytes, unit="byte", unit_scale=True) @@ -101,22 +127,78 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new def main() -> None: - tokenizer_metadata = (getattr(gguf.Keys.Tokenizer, n) for n in gguf.Keys.Tokenizer.__dict__.keys() if not n.startswith('_')) - token_names = dict((n.split('.')[-1][:-len('_token_id')], n) for n in tokenizer_metadata if n.endswith('_token_id')) - - parser = argparse.ArgumentParser(description="Make a copy of a GGUF file with new metadata") - parser.add_argument("input", type=Path, help="GGUF format model input filename") - parser.add_argument("output", type=Path, help="GGUF format model output filename") - parser.add_argument("--general-name", type=str, help="The models general.name", metavar='"name"') - parser.add_argument("--general-description", type=str, help="The models general.description", metavar='"Description ..."') - parser.add_argument("--chat-template", type=str, help="Chat template string (or JSON string containing templates)", metavar='"{% ... %} ..."') - parser.add_argument("--chat-template-config", type=Path, help="Config file containing chat template(s)", metavar='tokenizer_config.json') - parser.add_argument("--pre-tokenizer", type=str, help="The models tokenizer.ggml.pre", metavar='"pre tokenizer"') - parser.add_argument("--remove-metadata", action="append", type=str, help="Remove metadata (by key name) from output model", metavar='general.url') - parser.add_argument("--special-token", action="append", type=str, help="Special token by value", nargs=2, metavar=(' | '.join(token_names.keys()), '""')) - parser.add_argument("--special-token-by-id", action="append", type=str, help="Special token by id", nargs=2, metavar=(' | '.join(token_names.keys()), '0')) - parser.add_argument("--force", action="store_true", help="Bypass warnings without confirmation") - parser.add_argument("--verbose", action="store_true", help="Increase output verbosity") + tokenizer_metadata = ( + getattr(gguf.Keys.Tokenizer, n) + for n in gguf.Keys.Tokenizer.__dict__.keys() + if not n.startswith("_") + ) + token_names = dict( + (n.split(".")[-1][: -len("_token_id")], n) + for n in tokenizer_metadata + if n.endswith("_token_id") + ) + + parser = argparse.ArgumentParser( + description="Make a copy of a GGUF file with new metadata" + ) + parser.add_argument("input", type=Path, help="GGUF format model input filename") + parser.add_argument("output", type=Path, help="GGUF format model output filename") + parser.add_argument( + "--general-name", type=str, help="The models general.name", metavar='"name"' + ) + parser.add_argument( + "--general-description", + type=str, + help="The models general.description", + metavar='"Description ..."', + ) + parser.add_argument( + "--chat-template", + type=str, + help="Chat template string (or JSON string containing templates)", + metavar='"{% ... %} ..."', + ) + parser.add_argument( + "--chat-template-config", + type=Path, + help="Config file containing chat template(s)", + metavar="tokenizer_config.json", + ) + parser.add_argument( + "--pre-tokenizer", + type=str, + help="The models tokenizer.ggml.pre", + metavar='"pre tokenizer"', + ) + parser.add_argument( + "--remove-metadata", + action="append", + type=str, + help="Remove metadata (by key name) from output model", + metavar="general.url", + ) + parser.add_argument( + "--special-token", + action="append", + type=str, + help="Special token by value", + nargs=2, + metavar=(" | ".join(token_names.keys()), '""'), + ) + parser.add_argument( + "--special-token-by-id", + action="append", + type=str, + help="Special token by id", + nargs=2, + metavar=(" | ".join(token_names.keys()), "0"), + ) + parser.add_argument( + "--force", action="store_true", help="Bypass warnings without confirmation" + ) + parser.add_argument( + "--verbose", action="store_true", help="Increase output verbosity" + ) args = parser.parse_args(None if len(sys.argv) > 2 else ["--help"]) logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) @@ -125,38 +207,57 @@ def main() -> None: remove_metadata = args.remove_metadata or [] if args.general_name: - new_metadata[gguf.Keys.General.NAME] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_name) + new_metadata[gguf.Keys.General.NAME] = MetadataDetails( + gguf.GGUFValueType.STRING, args.general_name + ) if args.general_description: - new_metadata[gguf.Keys.General.DESCRIPTION] = MetadataDetails(gguf.GGUFValueType.STRING, args.general_description) + new_metadata[gguf.Keys.General.DESCRIPTION] = MetadataDetails( + gguf.GGUFValueType.STRING, args.general_description + ) if args.chat_template: - new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, json.loads(args.chat_template) if args.chat_template.startswith('[') else args.chat_template) + new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails( + gguf.GGUFValueType.STRING, + ( + json.loads(args.chat_template) + if args.chat_template.startswith("[") + else args.chat_template + ), + ) if args.chat_template_config: - with open(args.chat_template_config, 'r') as fp: + with open(args.chat_template_config, "r") as fp: config = json.load(fp) - template = config.get('chat_template') + template = config.get("chat_template") if template: - new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails(gguf.GGUFValueType.STRING, template) + new_metadata[gguf.Keys.Tokenizer.CHAT_TEMPLATE] = MetadataDetails( + gguf.GGUFValueType.STRING, template + ) if args.pre_tokenizer: - new_metadata[gguf.Keys.Tokenizer.PRE] = MetadataDetails(gguf.GGUFValueType.STRING, args.pre_tokenizer) + new_metadata[gguf.Keys.Tokenizer.PRE] = MetadataDetails( + gguf.GGUFValueType.STRING, args.pre_tokenizer + ) if remove_metadata: - logger.warning('*** Warning *** Warning *** Warning **') - logger.warning('* Most metadata is required for a fully functional GGUF file,') - logger.warning('* removing crucial metadata may result in a corrupt output file!') + logger.warning("*** Warning *** Warning *** Warning **") + logger.warning("* Most metadata is required for a fully functional GGUF file,") + logger.warning( + "* removing crucial metadata may result in a corrupt output file!" + ) if not args.force: - logger.warning('* Enter exactly YES if you are positive you want to proceed:') - response = input('YES, I am sure> ') - if response != 'YES': + logger.warning( + "* Enter exactly YES if you are positive you want to proceed:" + ) + response = input("YES, I am sure> ") + if response != "YES": logger.info("You didn't enter YES. Okay then, see ya!") sys.exit(0) - logger.info(f'* Loading: {args.input}') - reader = gguf.GGUFReader(args.input, 'r') + logger.info(f"* Loading: {args.input}") + reader = gguf.GGUFReader(args.input, "r") arch = get_field_data(reader, gguf.Keys.General.ARCHITECTURE) @@ -167,11 +268,15 @@ def main() -> None: logger.warning(f'Unknown special token "{name}", ignoring...') else: ids = find_token(token_list, token) - new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, ids[0], f'= {token}') + new_metadata[token_names[name]] = MetadataDetails( + gguf.GGUFValueType.UINT32, ids[0], f"= {token}" + ) if len(ids) > 1: - logger.warning(f'Multiple "{token}" tokens found, choosing ID {ids[0]}, use --special-token-by-id if you want another:') - logger.warning(', '.join(str(i) for i in ids)) + logger.warning( + f'Multiple "{token}" tokens found, choosing ID {ids[0]}, use --special-token-by-id if you want another:' + ) + logger.warning(", ".join(str(i) for i in ids)) for name, id_string in args.special_token_by_id or []: if name not in token_names: @@ -182,29 +287,33 @@ def main() -> None: id_int = int(id_string) if id_int >= 0 and id_int < len(token_list): - new_metadata[token_names[name]] = MetadataDetails(gguf.GGUFValueType.UINT32, id_int, f'= {token_list[id_int]}') + new_metadata[token_names[name]] = MetadataDetails( + gguf.GGUFValueType.UINT32, id_int, f"= {token_list[id_int]}" + ) else: - raise LookupError(f'Token ID {id_int} is not within token list!') + raise LookupError(f"Token ID {id_int} is not within token list!") if os.path.isfile(args.output) and not args.force: - logger.warning('*** Warning *** Warning *** Warning **') - logger.warning(f'* The "{args.output}" GGUF file already exists, it will be overwritten!') - logger.warning('* Enter exactly YES if you are positive you want to proceed:') - response = input('YES, I am sure> ') - if response != 'YES': + logger.warning("*** Warning *** Warning *** Warning **") + logger.warning( + f'* The "{args.output}" GGUF file already exists, it will be overwritten!' + ) + logger.warning("* Enter exactly YES if you are positive you want to proceed:") + response = input("YES, I am sure> ") + if response != "YES": logger.info("You didn't enter YES. Okay then, see ya!") sys.exit(0) - logger.info(f'* Writing: {args.output}') + logger.info(f"* Writing: {args.output}") writer = gguf.GGUFWriter(args.output, arch=arch, endianess=reader.endianess) alignment = get_field_data(reader, gguf.Keys.General.ALIGNMENT) if alignment is not None: - logger.debug(f'Setting custom alignment: {alignment}') + logger.debug(f"Setting custom alignment: {alignment}") writer.data_alignment = alignment copy_with_new_metadata(reader, writer, new_metadata, remove_metadata) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/smallthinker/gguf-py/gguf/scripts/gguf_set_metadata.py b/smallthinker/gguf-py/gguf/scripts/gguf_set_metadata.py index f5809c35..f613ffca 100755 --- a/smallthinker/gguf-py/gguf/scripts/gguf_set_metadata.py +++ b/smallthinker/gguf-py/gguf/scripts/gguf_set_metadata.py @@ -6,7 +6,10 @@ from pathlib import Path # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent.parent / 'gguf-py').exists(): +if ( + "NO_LOCAL_GGUF" not in os.environ + and (Path(__file__).parent.parent.parent.parent / "gguf-py").exists() +): sys.path.insert(0, str(Path(__file__).parent.parent.parent)) from gguf import GGUFReader # noqa: E402 @@ -15,8 +18,8 @@ def minimal_example(filename: str) -> None: - reader = GGUFReader(filename, 'r+') - field = reader.fields['tokenizer.ggml.bos_token_id'] + reader = GGUFReader(filename, "r+") + field = reader.fields["tokenizer.ggml.bos_token_id"] if field is None: return part_index = field.data[0] @@ -44,52 +47,68 @@ def minimal_example(filename: str) -> None: def set_metadata(reader: GGUFReader, args: argparse.Namespace) -> None: field = reader.get_field(args.key) if field is None: - logger.error(f'! Field {repr(args.key)} not found') + logger.error(f"! Field {repr(args.key)} not found") sys.exit(1) # Note that field.types is a list of types. This is because the GGUF # format supports arrays. For example, an array of UINT32 would # look like [GGUFValueType.ARRAY, GGUFValueType.UINT32] handler = reader.gguf_scalar_to_np.get(field.types[0]) if field.types else None if handler is None: - logger.error(f'! This tool only supports changing simple values, {repr(args.key)} has unsupported type {field.types}') + logger.error( + f"! This tool only supports changing simple values, {repr(args.key)} has unsupported type {field.types}" + ) sys.exit(1) current_value = field.parts[field.data[0]][0] new_value = handler(args.value) - logger.info(f'* Preparing to change field {repr(args.key)} from {current_value} to {new_value}') + logger.info( + f"* Preparing to change field {repr(args.key)} from {current_value} to {new_value}" + ) if current_value == new_value: - logger.info(f'- Key {repr(args.key)} already set to requested value {current_value}') + logger.info( + f"- Key {repr(args.key)} already set to requested value {current_value}" + ) sys.exit(0) if args.dry_run: sys.exit(0) if not args.force: - logger.warning('*** Warning *** Warning *** Warning **') - logger.warning('* Changing fields in a GGUF file can make it unusable. Proceed at your own risk.') - logger.warning('* Enter exactly YES if you are positive you want to proceed:') - response = input('YES, I am sure> ') - if response != 'YES': + logger.warning("*** Warning *** Warning *** Warning **") + logger.warning( + "* Changing fields in a GGUF file can make it unusable. Proceed at your own risk." + ) + logger.warning("* Enter exactly YES if you are positive you want to proceed:") + response = input("YES, I am sure> ") + if response != "YES": logger.info("You didn't enter YES. Okay then, see ya!") sys.exit(0) field.parts[field.data[0]][0] = new_value - logger.info('* Field changed. Successful completion.') + logger.info("* Field changed. Successful completion.") def main() -> None: - parser = argparse.ArgumentParser(description="Set a simple value in GGUF file metadata") - parser.add_argument("model", type=str, help="GGUF format model filename") - parser.add_argument("key", type=str, help="Metadata key to set") - parser.add_argument("value", type=str, help="Metadata value to set") - parser.add_argument("--dry-run", action="store_true", help="Don't actually change anything") - parser.add_argument("--force", action="store_true", help="Change the field without confirmation") - parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + parser = argparse.ArgumentParser( + description="Set a simple value in GGUF file metadata" + ) + parser.add_argument("model", type=str, help="GGUF format model filename") + parser.add_argument("key", type=str, help="Metadata key to set") + parser.add_argument("value", type=str, help="Metadata value to set") + parser.add_argument( + "--dry-run", action="store_true", help="Don't actually change anything" + ) + parser.add_argument( + "--force", action="store_true", help="Change the field without confirmation" + ) + parser.add_argument( + "--verbose", action="store_true", help="increase output verbosity" + ) args = parser.parse_args(None if len(sys.argv) > 1 else ["--help"]) logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) - logger.info(f'* Loading: {args.model}') - reader = GGUFReader(args.model, 'r' if args.dry_run else 'r+') + logger.info(f"* Loading: {args.model}") + reader = GGUFReader(args.model, "r" if args.dry_run else "r+") set_metadata(reader, args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/smallthinker/gguf-py/gguf/tensor_mapping.py b/smallthinker/gguf-py/gguf/tensor_mapping.py index b51c3809..920bb2a5 100644 --- a/smallthinker/gguf-py/gguf/tensor_mapping.py +++ b/smallthinker/gguf-py/gguf/tensor_mapping.py @@ -9,1212 +9,887 @@ class TensorNameMap: mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { # Token embeddings MODEL_TENSOR.TOKEN_EMBD: ( - "gpt_neox.embed_in", # gptneox - "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone - "transformer.word_embeddings", # falcon - "word_embeddings", # bloom - "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 - "tok_embeddings", # llama-pth - "embeddings.word_embeddings", # bert nomic-bert + "gpt_neox.embed_in", # gptneox + "transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone + "transformer.word_embeddings", # falcon + "word_embeddings", # bloom + "model.embed_tokens", # llama-hf nemotron olmoe olmo2 rwkv6qwen2 glm4-0414 + "tok_embeddings", # llama-pth + "embeddings.word_embeddings", # bert nomic-bert "language_model.embedding.word_embeddings", # persimmon - "wte", # gpt2 - "transformer.embd.wte", # phi2 - "model.tok_embeddings", # internlm2 - "model.embedding", # mamba-qbert - "backbone.embedding", # mamba - "backbone.embeddings", # mamba-hf - "transformer.in_out_embed", # Grok - "embedding.word_embeddings", # chatglm - "transformer.token_embeddings", # openelm - "shared", # t5 - "rwkv.embeddings", # rwkv6 - "model.embeddings", # rwkv7 - "model.word_embeddings", # bailingmoe - "language_model.model.embed_tokens", # llama4 + "wte", # gpt2 + "transformer.embd.wte", # phi2 + "model.tok_embeddings", # internlm2 + "model.embedding", # mamba-qbert + "backbone.embedding", # mamba + "backbone.embeddings", # mamba-hf + "transformer.in_out_embed", # Grok + "embedding.word_embeddings", # chatglm + "transformer.token_embeddings", # openelm + "shared", # t5 + "rwkv.embeddings", # rwkv6 + "model.embeddings", # rwkv7 + "model.word_embeddings", # bailingmoe + "language_model.model.embed_tokens", # llama4 ), - # Token type embeddings MODEL_TENSOR.TOKEN_TYPES: ( "embeddings.token_type_embeddings", # bert nomic-bert ), - # Normalization of token embeddings MODEL_TENSOR.TOKEN_EMBD_NORM: ( "word_embeddings_layernorm", # bloom - "embeddings.LayerNorm", # bert - "emb_ln", # nomic-bert - "transformer.norm", # openelm - "rwkv.blocks.0.pre_ln", # rwkv - "rwkv.blocks.0.pre_ln", # rwkv6 - "model.pre_ln", # rwkv7 - "model.layers.0.pre_norm", # rwkv7 - "backbone.norm", # wavtokenizer + "embeddings.LayerNorm", # bert + "emb_ln", # nomic-bert + "transformer.norm", # openelm + "rwkv.blocks.0.pre_ln", # rwkv + "rwkv.blocks.0.pre_ln", # rwkv6 + "model.pre_ln", # rwkv7 + "model.layers.0.pre_norm", # rwkv7 + "backbone.norm", # wavtokenizer ), - # Position embeddings MODEL_TENSOR.POS_EMBD: ( - "transformer.wpe", # gpt2 + "transformer.wpe", # gpt2 "embeddings.position_embeddings", # bert - "wpe", # gpt2 + "wpe", # gpt2 ), - # Output MODEL_TENSOR.OUTPUT: ( - "embed_out", # gptneox - "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe - "output", # llama-pth bloom internlm2 + "embed_out", # gptneox + "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe + "output", # llama-pth bloom internlm2 "word_embeddings_for_head", # persimmon - "lm_head.linear", # phi2 - "output_layer", # chatglm - "head", # rwkv - "head.out", # wavtokenizer - "lm_head", # llama4 + "lm_head.linear", # phi2 + "output_layer", # chatglm + "head", # rwkv + "head.out", # wavtokenizer + "lm_head", # llama4 ), - # Output norm MODEL_TENSOR.OUTPUT_NORM: ( - "gpt_neox.final_layer_norm", # gptneox - "transformer.ln_f", # gpt2 gpt-j falcon jais exaone - "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe - "norm", # llama-pth - "transformer.norm_f", # mpt dbrx - "ln_f", # refact bloom qwen gpt2 + "gpt_neox.final_layer_norm", # gptneox + "transformer.ln_f", # gpt2 gpt-j falcon jais exaone + "model.norm", # llama-hf baichuan internlm2 olmoe olmo2 phimoe + "norm", # llama-pth + "transformer.norm_f", # mpt dbrx + "ln_f", # refact bloom qwen gpt2 "language_model.encoder.final_layernorm", # persimmon - "model.final_layernorm", # persimmon - "lm_head.ln", # phi2 - "model.norm_f", # mamba-qbert - "backbone.norm_f", # mamba - "transformer.rms_norm", # Grok - "encoder.final_layernorm", # chatglm - "transformer.norm", # openelm - "model.norm", # nemotron - "rwkv.ln_out", # rwkv6 - "model.ln_out", # rwkv7 - "backbone.final_layer_norm", # wavtokenizer - "model.norm", # llama4 + "model.final_layernorm", # persimmon + "lm_head.ln", # phi2 + "model.norm_f", # mamba-qbert + "backbone.norm_f", # mamba + "transformer.rms_norm", # Grok + "encoder.final_layernorm", # chatglm + "transformer.norm", # openelm + "model.norm", # nemotron + "rwkv.ln_out", # rwkv6 + "model.ln_out", # rwkv7 + "backbone.final_layer_norm", # wavtokenizer + "model.norm", # llama4 ), - # Rope frequencies MODEL_TENSOR.ROPE_FREQS: ( "rope.freqs", # llama-pth "rotary_pos_emb.inv_freq", # chatglm ), - MODEL_TENSOR.ROPE_FACTORS_LONG: (), MODEL_TENSOR.ROPE_FACTORS_SHORT: (), - - MODEL_TENSOR.CONV1D: ( - "backbone.embed", # roberta - ), + MODEL_TENSOR.CONV1D: ("backbone.embed",), # roberta } block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { # Attention norm MODEL_TENSOR.ATTN_NORM: ( - "gpt_neox.layers.{bid}.input_layernorm", # gptneox - "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone - "transformer.blocks.{bid}.norm_1", # mpt - "transformer.h.{bid}.input_layernorm", # falcon7b - "h.{bid}.input_layernorm", # bloom - "transformer.h.{bid}.ln_mlp", # falcon40b - "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe - "layers.{bid}.attention_norm", # llama-pth + "gpt_neox.layers.{bid}.input_layernorm", # gptneox + "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone + "transformer.blocks.{bid}.norm_1", # mpt + "transformer.h.{bid}.input_layernorm", # falcon7b + "h.{bid}.input_layernorm", # bloom + "transformer.h.{bid}.ln_mlp", # falcon40b + "model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe phimoe + "layers.{bid}.attention_norm", # llama-pth "language_model.encoder.layers.{bid}.input_layernorm", # persimmon - "model.layers.{bid}.ln1", # yi - "h.{bid}.ln_1", # gpt2 - "transformer.h.{bid}.ln", # phi2 - "model.layers.layers.{bid}.norm", # plamo - "model.layers.{bid}.attention_norm", # internlm2 - "model.layers.{bid}.norm", # mamba-qbert - "backbone.layers.{bid}.norm", # mamba - "transformer.decoder_layer.{bid}.rms_norm", # Grok - "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx - "encoder.layers.{bid}.input_layernorm", # chatglm - "transformer.layers.{bid}.attn_norm", # openelm - "rwkv.blocks.{bid}.ln1", # rwkv6 - "model.layers.{bid}.ln1", # rwkv7 - "model.layers.{bid}.input_layernorm", # llama4 + "model.layers.{bid}.ln1", # yi + "h.{bid}.ln_1", # gpt2 + "transformer.h.{bid}.ln", # phi2 + "model.layers.layers.{bid}.norm", # plamo + "model.layers.{bid}.attention_norm", # internlm2 + "model.layers.{bid}.norm", # mamba-qbert + "backbone.layers.{bid}.norm", # mamba + "transformer.decoder_layer.{bid}.rms_norm", # Grok + "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx + "encoder.layers.{bid}.input_layernorm", # chatglm + "transformer.layers.{bid}.attn_norm", # openelm + "rwkv.blocks.{bid}.ln1", # rwkv6 + "model.layers.{bid}.ln1", # rwkv7 + "model.layers.{bid}.input_layernorm", # llama4 ), - # Attention norm 2 MODEL_TENSOR.ATTN_NORM_2: ( - "transformer.h.{bid}.ln_attn", # falcon40b - "encoder.layer.{bid}.layer_norm_1", # jina-v2-code - "rwkv.blocks.{bid}.ln2", # rwkv6 - "model.layers.{bid}.ln2", # rwkv7 + "transformer.h.{bid}.ln_attn", # falcon40b + "encoder.layer.{bid}.layer_norm_1", # jina-v2-code + "rwkv.blocks.{bid}.ln2", # rwkv6 + "model.layers.{bid}.ln2", # rwkv7 ), - # Attention query-key-value MODEL_TENSOR.ATTN_QKV: ( - "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox - "transformer.h.{bid}.attn.c_attn", # gpt2 qwen jais - "transformer.blocks.{bid}.attn.Wqkv", # mpt - "transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx - "transformer.h.{bid}.self_attention.query_key_value", # falcon - "h.{bid}.self_attention.query_key_value", # bloom + "gpt_neox.layers.{bid}.attention.query_key_value", # gptneox + "transformer.h.{bid}.attn.c_attn", # gpt2 qwen jais + "transformer.blocks.{bid}.attn.Wqkv", # mpt + "transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx + "transformer.h.{bid}.self_attention.query_key_value", # falcon + "h.{bid}.self_attention.query_key_value", # bloom "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon - "model.layers.{bid}.self_attn.query_key_value", # persimmon - "h.{bid}.attn.c_attn", # gpt2 - "transformer.h.{bid}.mixer.Wqkv", # phi2 - "encoder.layers.{bid}.attn.Wqkv", # nomic-bert - "encoder.layers.{bid}.mixer.Wqkv", # jina - "model.layers.{bid}.self_attn.qkv_proj", # phi3 - "encoder.layers.{bid}.self_attention.query_key_value", # chatglm - "transformer.layers.{bid}.attn.qkv_proj", # openelm + "model.layers.{bid}.self_attn.query_key_value", # persimmon + "h.{bid}.attn.c_attn", # gpt2 + "transformer.h.{bid}.mixer.Wqkv", # phi2 + "encoder.layers.{bid}.attn.Wqkv", # nomic-bert + "encoder.layers.{bid}.mixer.Wqkv", # jina + "model.layers.{bid}.self_attn.qkv_proj", # phi3 + "encoder.layers.{bid}.self_attention.query_key_value", # chatglm + "transformer.layers.{bid}.attn.qkv_proj", # openelm ), - # Attention query MODEL_TENSOR.ATTN_Q: ( - "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe - "model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom - "layers.{bid}.attention.wq", # llama-pth - "encoder.layer.{bid}.attention.self.query", # bert - "transformer.layer.{bid}.attention.q_lin", # distillbert - "transformer.h.{bid}.attn.q_proj", # gpt-j - "model.layers.layers.{bid}.self_attn.q_proj", # plamo - "model.layers.{bid}.attention.wq", # internlm2 - "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok - "transformer.h.{bid}.attn.attention.q_proj", # exaone - "model.layers.{bid}.self_attn.q_proj", # llama4 + "model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2 phimoe + "model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom + "layers.{bid}.attention.wq", # llama-pth + "encoder.layer.{bid}.attention.self.query", # bert + "transformer.layer.{bid}.attention.q_lin", # distillbert + "transformer.h.{bid}.attn.q_proj", # gpt-j + "model.layers.layers.{bid}.self_attn.q_proj", # plamo + "model.layers.{bid}.attention.wq", # internlm2 + "transformer.decoder_layer.{bid}.multi_head_attention.query", # Grok + "transformer.h.{bid}.attn.attention.q_proj", # exaone + "model.layers.{bid}.self_attn.q_proj", # llama4 ), - # Attention key MODEL_TENSOR.ATTN_K: ( - "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe - "model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom - "layers.{bid}.attention.wk", # llama-pth - "encoder.layer.{bid}.attention.self.key", # bert - "transformer.layer.{bid}.attention.k_lin", # distillbert - "transformer.h.{bid}.attn.k_proj", # gpt-j - "transformer.h.{bid}.attn.k", # refact - "model.layers.layers.{bid}.self_attn.k_proj", # plamo - "model.layers.{bid}.attention.wk", # internlm2 - "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok - "transformer.h.{bid}.attn.attention.k_proj", # exaone - "model.layers.{bid}.self_attn.k_proj", # llama4 + "model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2 phimoe + "model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom + "layers.{bid}.attention.wk", # llama-pth + "encoder.layer.{bid}.attention.self.key", # bert + "transformer.layer.{bid}.attention.k_lin", # distillbert + "transformer.h.{bid}.attn.k_proj", # gpt-j + "transformer.h.{bid}.attn.k", # refact + "model.layers.layers.{bid}.self_attn.k_proj", # plamo + "model.layers.{bid}.attention.wk", # internlm2 + "transformer.decoder_layer.{bid}.multi_head_attention.key", # Grok + "transformer.h.{bid}.attn.attention.k_proj", # exaone + "model.layers.{bid}.self_attn.k_proj", # llama4 ), - # Attention value MODEL_TENSOR.ATTN_V: ( - "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe - "layers.{bid}.attention.wv", # llama-pth - "encoder.layer.{bid}.attention.self.value", # bert - "transformer.layer.{bid}.attention.v_lin", # distillbert - "transformer.h.{bid}.attn.v_proj", # gpt-j - "transformer.h.{bid}.attn.v", # refact - "model.layers.layers.{bid}.self_attn.v_proj", # plamo - "model.layers.{bid}.attention.wv", # internlm2 - "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok - "transformer.h.{bid}.attn.attention.v_proj", # exaone - "model.layers.{bid}.self_attn.v_proj", # llama4 + "model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2 phimoe + "layers.{bid}.attention.wv", # llama-pth + "encoder.layer.{bid}.attention.self.value", # bert + "transformer.layer.{bid}.attention.v_lin", # distillbert + "transformer.h.{bid}.attn.v_proj", # gpt-j + "transformer.h.{bid}.attn.v", # refact + "model.layers.layers.{bid}.self_attn.v_proj", # plamo + "model.layers.{bid}.attention.wv", # internlm2 + "transformer.decoder_layer.{bid}.multi_head_attention.value", # Grok + "transformer.h.{bid}.attn.attention.v_proj", # exaone + "model.layers.{bid}.self_attn.v_proj", # llama4 ), - # Attention output MODEL_TENSOR.ATTN_OUT: ( - "gpt_neox.layers.{bid}.attention.dense", # gptneox - "transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen jais - "transformer.blocks.{bid}.attn.out_proj", # mpt - "transformer.h.{bid}.self_attention.dense", # falcon - "h.{bid}.self_attention.dense", # bloom - "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe - "model.layers.{bid}.self_attn.linear_attn", # deci - "layers.{bid}.attention.wo", # llama-pth - "encoder.layer.{bid}.attention.output.dense", # bert - "transformer.layer.{bid}.attention.out_lin", # distillbert - "transformer.h.{bid}.attn.out_proj", # gpt-j - "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon - "model.layers.{bid}.self_attn.dense", # persimmon - "h.{bid}.attn.c_proj", # gpt2 - "transformer.h.{bid}.mixer.out_proj", # phi2 - "model.layers.layers.{bid}.self_attn.o_proj", # plamo - "model.layers.{bid}.attention.wo", # internlm2 - "encoder.layers.{bid}.attn.out_proj", # nomic-bert - "encoder.layers.{bid}.mixer.out_proj", # jina + "gpt_neox.layers.{bid}.attention.dense", # gptneox + "transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen jais + "transformer.blocks.{bid}.attn.out_proj", # mpt + "transformer.h.{bid}.self_attention.dense", # falcon + "h.{bid}.self_attention.dense", # bloom + "model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2 phimoe + "model.layers.{bid}.self_attn.linear_attn", # deci + "layers.{bid}.attention.wo", # llama-pth + "encoder.layer.{bid}.attention.output.dense", # bert + "transformer.layer.{bid}.attention.out_lin", # distillbert + "transformer.h.{bid}.attn.out_proj", # gpt-j + "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon + "model.layers.{bid}.self_attn.dense", # persimmon + "h.{bid}.attn.c_proj", # gpt2 + "transformer.h.{bid}.mixer.out_proj", # phi2 + "model.layers.layers.{bid}.self_attn.o_proj", # plamo + "model.layers.{bid}.attention.wo", # internlm2 + "encoder.layers.{bid}.attn.out_proj", # nomic-bert + "encoder.layers.{bid}.mixer.out_proj", # jina "transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok - "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx - "encoder.layers.{bid}.self_attention.dense", # chatglm - "transformer.layers.{bid}.attn.out_proj", # openelm - "transformer.h.{bid}.attn.attention.out_proj", # exaone - "model.layers.{bid}.self_attn.o_proj", # llama4 + "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx + "encoder.layers.{bid}.self_attention.dense", # chatglm + "transformer.layers.{bid}.attn.out_proj", # openelm + "transformer.h.{bid}.attn.attention.out_proj", # exaone + "model.layers.{bid}.self_attn.o_proj", # llama4 ), - # Attention output norm MODEL_TENSOR.ATTN_OUT_NORM: ( "encoder.layer.{bid}.attention.output.LayerNorm", # bert - "transformer.layer.{bid}.sa_layer_norm", # distillbert - "encoder.layers.{bid}.norm1", # nomic-bert - "transformer.decoder_layer.{bid}.rms_norm_1", # Grok + "transformer.layer.{bid}.sa_layer_norm", # distillbert + "encoder.layers.{bid}.norm1", # nomic-bert + "transformer.decoder_layer.{bid}.rms_norm_1", # Grok "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx ), - MODEL_TENSOR.ATTN_POST_NORM: ( - "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge - "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414 + "model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2 # ge + "model.layers.{bid}.post_self_attn_layernorm", # glm-4-0414 ), - # Rotary embeddings MODEL_TENSOR.ATTN_ROT_EMBD: ( - "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf - "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth - "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo - "transformer.h.{bid}.attn.rotary_emb.inv_freq", # codeshell + "model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf + "layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth + "model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo + "transformer.h.{bid}.attn.rotary_emb.inv_freq", # codeshell ), - # Feed-forward norm MODEL_TENSOR.FFN_NORM: ( - "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox - "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone - "h.{bid}.post_attention_layernorm", # bloom - "transformer.blocks.{bid}.norm_2", # mpt - "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe - "layers.{bid}.ffn_norm", # llama-pth + "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox + "transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone + "h.{bid}.post_attention_layernorm", # bloom + "transformer.blocks.{bid}.norm_2", # mpt + "model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe phimoe + "layers.{bid}.ffn_norm", # llama-pth "language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon - "model.layers.{bid}.ln2", # yi - "h.{bid}.ln_2", # gpt2 - "model.layers.{bid}.ffn_norm", # internlm2 - "transformer.decoder_layer.{bid}.rms_norm_2", # Grok - "encoder.layers.{bid}.post_attention_layernorm", # chatglm - "transformer.layers.{bid}.ffn_norm", # openelm - "model.layers.{bid}.post_attention_layernorm", # llama4 + "model.layers.{bid}.ln2", # yi + "h.{bid}.ln_2", # gpt2 + "model.layers.{bid}.ffn_norm", # internlm2 + "transformer.decoder_layer.{bid}.rms_norm_2", # Grok + "encoder.layers.{bid}.post_attention_layernorm", # chatglm + "transformer.layers.{bid}.ffn_norm", # openelm + "model.layers.{bid}.post_attention_layernorm", # llama4 ), - # Post feed-forward norm MODEL_TENSOR.FFN_PRE_NORM: ( - "model.layers.{bid}.pre_feedforward_layernorm", # gemma2 + "model.layers.{bid}.pre_feedforward_layernorm", # gemma2 ), - # Post feed-forward norm MODEL_TENSOR.FFN_POST_NORM: ( - "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2 - "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414 + "model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2 + "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414 ), - MODEL_TENSOR.FFN_GATE_INP: ( - "layers.{bid}.feed_forward.gate", # mixtral - "model.layers.{bid}.block_sparse_moe.gate", # mixtral phimoe - "model.layers.{bid}.mlp.gate", # qwen2moe olmoe - "transformer.decoder_layer.{bid}.router", # Grok - "transformer.blocks.{bid}.ffn.router.layer", # dbrx - "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe - "model.layers.{bid}.feed_forward.router", # llama4 - "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe - "model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker + "layers.{bid}.feed_forward.gate", # mixtral + "model.layers.{bid}.block_sparse_moe.gate", # mixtral phimoe + "model.layers.{bid}.mlp.gate", # qwen2moe olmoe + "transformer.decoder_layer.{bid}.router", # Grok + "transformer.blocks.{bid}.ffn.router.layer", # dbrx + "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe + "model.layers.{bid}.feed_forward.router", # llama4 + "encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe + "model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker ), - MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( - "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe + "model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe ), - MODEL_TENSOR.FFN_EXP_PROBS_B: ( - "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 + "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 ), - # Feed-forward up MODEL_TENSOR.FFN_UP: ( - "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox - "transformer.h.{bid}.mlp.c_fc", # gpt2 jais - "transformer.blocks.{bid}.ffn.up_proj", # mpt - "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon - "h.{bid}.mlp.dense_h_to_4h", # bloom - "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2 - "layers.{bid}.feed_forward.w3", # llama-pth - "encoder.layer.{bid}.intermediate.dense", # bert - "transformer.layer.{bid}.ffn.lin1", # distillbert - "transformer.h.{bid}.mlp.fc_in", # gpt-j - "transformer.h.{bid}.mlp.linear_3", # refact + "gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox + "transformer.h.{bid}.mlp.c_fc", # gpt2 jais + "transformer.blocks.{bid}.ffn.up_proj", # mpt + "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon + "h.{bid}.mlp.dense_h_to_4h", # bloom + "model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2 + "layers.{bid}.feed_forward.w3", # llama-pth + "encoder.layer.{bid}.intermediate.dense", # bert + "transformer.layer.{bid}.ffn.lin1", # distillbert + "transformer.h.{bid}.mlp.fc_in", # gpt-j + "transformer.h.{bid}.mlp.linear_3", # refact "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon - "model.layers.{bid}.mlp.dense_h_to_4h", # persimmon - "transformer.h.{bid}.mlp.w1", # qwen - "h.{bid}.mlp.c_fc", # gpt2 - "transformer.h.{bid}.mlp.fc1", # phi2 - "model.layers.{bid}.mlp.fc1", # phi2 - "model.layers.{bid}.mlp.gate_up_proj", # phi3 glm-4-0414 - "model.layers.layers.{bid}.mlp.up_proj", # plamo - "model.layers.{bid}.feed_forward.w3", # internlm2 - "encoder.layers.{bid}.mlp.fc11", # nomic-bert - "encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe - "model.layers.{bid}.mlp.c_fc", # starcoder2 - "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 - "model.layers.{bid}.residual_mlp.w3", # arctic - "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm - "transformer.h.{bid}.mlp.c_fc_1", # exaone - "model.layers.{bid}.feed_forward.up_proj", # llama4 - "model.layers.{bid}.block_sparse_moe.up", # smallthinker + "model.layers.{bid}.mlp.dense_h_to_4h", # persimmon + "transformer.h.{bid}.mlp.w1", # qwen + "h.{bid}.mlp.c_fc", # gpt2 + "transformer.h.{bid}.mlp.fc1", # phi2 + "model.layers.{bid}.mlp.fc1", # phi2 + "model.layers.{bid}.mlp.gate_up_proj", # phi3 glm-4-0414 + "model.layers.layers.{bid}.mlp.up_proj", # plamo + "model.layers.{bid}.feed_forward.w3", # internlm2 + "encoder.layers.{bid}.mlp.fc11", # nomic-bert + "encoder.layers.{bid}.mlp.fc1", # nomic-bert-moe + "model.layers.{bid}.mlp.c_fc", # starcoder2 + "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 + "model.layers.{bid}.residual_mlp.w3", # arctic + "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm + "transformer.h.{bid}.mlp.c_fc_1", # exaone + "model.layers.{bid}.feed_forward.up_proj", # llama4 + "model.layers.{bid}.block_sparse_moe.up", # smallthinker ), - MODEL_TENSOR.FFN_UP_EXP: ( - "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) - "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) - "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx - "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) - "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) - "model.layers.{bid}.feed_forward.experts.up_proj", # llama4 - "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe - "model.layers.{bid}.block_sparse_moe.experts.up", # smallthinker + "layers.{bid}.feed_forward.experts.w3", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx + "model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged) + "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) + "model.layers.{bid}.feed_forward.experts.up_proj", # llama4 + "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe + "model.layers.{bid}.block_sparse_moe.experts.up", # smallthinker ), - MODEL_TENSOR.FFN_UP_SHEXP: ( - "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe - "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2 - "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4 + "model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2 + "model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4 ), - # AWQ-activation gate - MODEL_TENSOR.FFN_ACT: ( - "transformer.blocks.{bid}.ffn.act", # mpt - ), - + MODEL_TENSOR.FFN_ACT: ("transformer.blocks.{bid}.ffn.act",), # mpt # Feed-forward gate MODEL_TENSOR.FFN_GATE: ( - "model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2 - "layers.{bid}.feed_forward.w1", # llama-pth - "transformer.h.{bid}.mlp.w2", # qwen - "transformer.h.{bid}.mlp.c_fc2", # jais - "model.layers.layers.{bid}.mlp.gate_proj", # plamo - "model.layers.{bid}.feed_forward.w1", # internlm2 - "encoder.layers.{bid}.mlp.fc12", # nomic-bert - "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 - "transformer.h.{bid}.mlp.linear_1", # refact - "model.layers.{bid}.residual_mlp.w1", # arctic - "transformer.h.{bid}.mlp.c_fc_0", # exaone + "model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2 + "layers.{bid}.feed_forward.w1", # llama-pth + "transformer.h.{bid}.mlp.w2", # qwen + "transformer.h.{bid}.mlp.c_fc2", # jais + "model.layers.layers.{bid}.mlp.gate_proj", # plamo + "model.layers.{bid}.feed_forward.w1", # internlm2 + "encoder.layers.{bid}.mlp.fc12", # nomic-bert + "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 + "transformer.h.{bid}.mlp.linear_1", # refact + "model.layers.{bid}.residual_mlp.w1", # arctic + "transformer.h.{bid}.mlp.c_fc_0", # exaone "model.layers.{bid}.feed_forward.gate_proj", # llama4 - "model.layers.{bid}.block_sparse_moe.gate", # smallthinker + "model.layers.{bid}.block_sparse_moe.gate", # smallthinker ), - MODEL_TENSOR.FFN_GATE_EXP: ( - "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) - "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) - "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx - "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) - "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) - "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 + "layers.{bid}.feed_forward.experts.w1", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx + "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) + "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) + "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 "model.layers.{bid}.block_sparse_moe.experts.gate", # smallthinker ), - MODEL_TENSOR.FFN_GATE_SHEXP: ( - "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe - "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2 - "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4 + "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2 + "model.layers.{bid}.feed_forward.shared_expert.gate_proj", # llama4 ), - # Feed-forward down MODEL_TENSOR.FFN_DOWN: ( - "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox - "transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen jais - "transformer.blocks.{bid}.ffn.down_proj", # mpt - "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon - "h.{bid}.mlp.dense_4h_to_h", # bloom - "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2 - "layers.{bid}.feed_forward.w2", # llama-pth - "encoder.layer.{bid}.output.dense", # bert - "transformer.layer.{bid}.ffn.lin2", # distillbert - "transformer.h.{bid}.mlp.fc_out", # gpt-j + "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox + "transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen jais + "transformer.blocks.{bid}.ffn.down_proj", # mpt + "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon + "h.{bid}.mlp.dense_4h_to_h", # bloom + "model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2 + "layers.{bid}.feed_forward.w2", # llama-pth + "encoder.layer.{bid}.output.dense", # bert + "transformer.layer.{bid}.ffn.lin2", # distillbert + "transformer.h.{bid}.mlp.fc_out", # gpt-j "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon - "model.layers.{bid}.mlp.dense_4h_to_h", # persimmon - "h.{bid}.mlp.c_proj", # gpt2 - "transformer.h.{bid}.mlp.fc2", # phi2 - "model.layers.{bid}.mlp.fc2", # phi2 - "model.layers.layers.{bid}.mlp.down_proj", # plamo - "model.layers.{bid}.feed_forward.w2", # internlm2 - "encoder.layers.{bid}.mlp.fc2", # nomic-bert - "model.layers.{bid}.mlp.c_proj", # starcoder2 - "encoder.layer.{bid}.mlp.wo", # jina-bert-v2 - "transformer.layers.{bid}.ffn.proj_2", # openelm - "model.layers.{bid}.residual_mlp.w2", # arctic - "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 - "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm - "model.layers.h.{bid}.mlp.c_proj", # exaone - "model.layers.{bid}.feed_forward.down_proj", # llama4 - "model.layers.{bid}.block_sparse_moe.down", # smallthinker + "model.layers.{bid}.mlp.dense_4h_to_h", # persimmon + "h.{bid}.mlp.c_proj", # gpt2 + "transformer.h.{bid}.mlp.fc2", # phi2 + "model.layers.{bid}.mlp.fc2", # phi2 + "model.layers.layers.{bid}.mlp.down_proj", # plamo + "model.layers.{bid}.feed_forward.w2", # internlm2 + "encoder.layers.{bid}.mlp.fc2", # nomic-bert + "model.layers.{bid}.mlp.c_proj", # starcoder2 + "encoder.layer.{bid}.mlp.wo", # jina-bert-v2 + "transformer.layers.{bid}.ffn.proj_2", # openelm + "model.layers.{bid}.residual_mlp.w2", # arctic + "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 + "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm + "model.layers.h.{bid}.mlp.c_proj", # exaone + "model.layers.{bid}.feed_forward.down_proj", # llama4 + "model.layers.{bid}.block_sparse_moe.down", # smallthinker ), - MODEL_TENSOR.FFN_DOWN_EXP: ( - "layers.{bid}.feed_forward.experts.w2", # mixtral (merged) - "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged) - "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx - "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) - "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe - "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) - "model.layers.{bid}.feed_forward.experts.down_proj", # llama4 - "encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe + "layers.{bid}.feed_forward.experts.w2", # mixtral (merged) + "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged) + "transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx + "model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged) + "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe + "model.layers.{bid}.block_sparse_moe.experts.w2", # phimoe (merged) + "model.layers.{bid}.feed_forward.experts.down_proj", # llama4 + "encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe "model.layers.{bid}.block_sparse_moe.experts.down", # smallthinker ), - MODEL_TENSOR.FFN_DOWN_SHEXP: ( - "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe - "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 - "model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4 - "model.layers.{bid}.shared_mlp.output_linear", # granitemoe + "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe + "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2 + "model.layers.{bid}.feed_forward.shared_expert.down_proj", # llama4 + "model.layers.{bid}.shared_mlp.output_linear", # granitemoe ), - MODEL_TENSOR.ATTN_Q_NORM: ( "language_model.encoder.layers.{bid}.self_attention.q_layernorm", - "model.layers.{bid}.self_attn.q_layernorm", # persimmon - "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2 - "transformer.blocks.{bid}.attn.q_ln", # sea-lion - "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 - "transformer.layers.{bid}.attn.q_norm", # openelm + "model.layers.{bid}.self_attn.q_layernorm", # persimmon + "model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2 + "transformer.blocks.{bid}.attn.q_ln", # sea-lion + "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 + "transformer.layers.{bid}.attn.q_norm", # openelm ), - MODEL_TENSOR.ATTN_K_NORM: ( "language_model.encoder.layers.{bid}.self_attention.k_layernorm", - "model.layers.{bid}.self_attn.k_layernorm", # persimmon - "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2 - "transformer.blocks.{bid}.attn.k_ln", # sea-lion - "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 - "transformer.layers.{bid}.attn.k_norm", # openelm + "model.layers.{bid}.self_attn.k_layernorm", # persimmon + "model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2 + "transformer.blocks.{bid}.attn.k_ln", # sea-lion + "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 + "transformer.layers.{bid}.attn.k_norm", # openelm ), - MODEL_TENSOR.ROPE_FREQS: ( "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon ), - MODEL_TENSOR.LAYER_OUT_NORM: ( - "encoder.layer.{bid}.output.LayerNorm", # bert - "transformer.layer.{bid}.output_layer_norm", # distillbert - "encoder.layers.{bid}.norm2", # nomic-bert - "transformer.decoder_layer.{bid}.rms_norm_3", # Grok - "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 - "encoder.layer.{bid}.layer_norm_2" # jina-v2-code + "encoder.layer.{bid}.output.LayerNorm", # bert + "transformer.layer.{bid}.output_layer_norm", # distillbert + "encoder.layers.{bid}.norm2", # nomic-bert + "transformer.decoder_layer.{bid}.rms_norm_3", # Grok + "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 + "encoder.layer.{bid}.layer_norm_2", # jina-v2-code ), - MODEL_TENSOR.SSM_IN: ( "model.layers.{bid}.in_proj", "backbone.layers.{bid}.mixer.in_proj", ), - MODEL_TENSOR.SSM_CONV1D: ( "model.layers.{bid}.conv1d", "backbone.layers.{bid}.mixer.conv1d", ), - MODEL_TENSOR.SSM_X: ( "model.layers.{bid}.x_proj", "backbone.layers.{bid}.mixer.x_proj", ), - MODEL_TENSOR.SSM_DT: ( "model.layers.{bid}.dt_proj", "backbone.layers.{bid}.mixer.dt_proj", ), - MODEL_TENSOR.SSM_A: ( "model.layers.{bid}.A_log", "backbone.layers.{bid}.mixer.A_log", ), - MODEL_TENSOR.SSM_D: ( "model.layers.{bid}.D", "backbone.layers.{bid}.mixer.D", ), - MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", ), - - MODEL_TENSOR.TIME_MIX_W0: ( - "model.layers.{bid}.attention.w0", # rwkv7 - ), - + MODEL_TENSOR.TIME_MIX_W0: ("model.layers.{bid}.attention.w0",), # rwkv7 MODEL_TENSOR.TIME_MIX_W1: ( - "rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv6 - "model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2 - "model.layers.{bid}.attention.w1", # rwkv7 + "rwkv.blocks.{bid}.attention.time_maa_w1", # rwkv6 + "model.layers.{bid}.self_attn.time_maa_w1", # rwkv6qwen2 + "model.layers.{bid}.attention.w1", # rwkv7 ), - MODEL_TENSOR.TIME_MIX_W2: ( - "rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv6 - "model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2 - "model.layers.{bid}.attention.w2", # rwkv7 - ), - - MODEL_TENSOR.TIME_MIX_A0: ( - "model.layers.{bid}.attention.a0", # rwkv7 - ), - - MODEL_TENSOR.TIME_MIX_A1: ( - "model.layers.{bid}.attention.a1", # rwkv7 - ), - - MODEL_TENSOR.TIME_MIX_A2: ( - "model.layers.{bid}.attention.a2", # rwkv7 - ), - - MODEL_TENSOR.TIME_MIX_V0: ( - "model.layers.{bid}.attention.v0", # rwkv7 - ), - - MODEL_TENSOR.TIME_MIX_V1: ( - "model.layers.{bid}.attention.v1", # rwkv7 - ), - - MODEL_TENSOR.TIME_MIX_V2: ( - "model.layers.{bid}.attention.v2", # rwkv7 - ), - - MODEL_TENSOR.TIME_MIX_G1: ( - "model.layers.{bid}.attention.g1", # rwkv7 - ), - - MODEL_TENSOR.TIME_MIX_G2: ( - "model.layers.{bid}.attention.g2", # rwkv7 - ), - - MODEL_TENSOR.TIME_MIX_K_K: ( - "model.layers.{bid}.attention.k_k", # rwkv7 - ), - - MODEL_TENSOR.TIME_MIX_K_A: ( - "model.layers.{bid}.attention.k_a", # rwkv7 - ), - - MODEL_TENSOR.TIME_MIX_R_K: ( - "model.layers.{bid}.attention.r_k", # rwkv7 - ), - + "rwkv.blocks.{bid}.attention.time_maa_w2", # rwkv6 + "model.layers.{bid}.self_attn.time_maa_w2", # rwkv6qwen2 + "model.layers.{bid}.attention.w2", # rwkv7 + ), + MODEL_TENSOR.TIME_MIX_A0: ("model.layers.{bid}.attention.a0",), # rwkv7 + MODEL_TENSOR.TIME_MIX_A1: ("model.layers.{bid}.attention.a1",), # rwkv7 + MODEL_TENSOR.TIME_MIX_A2: ("model.layers.{bid}.attention.a2",), # rwkv7 + MODEL_TENSOR.TIME_MIX_V0: ("model.layers.{bid}.attention.v0",), # rwkv7 + MODEL_TENSOR.TIME_MIX_V1: ("model.layers.{bid}.attention.v1",), # rwkv7 + MODEL_TENSOR.TIME_MIX_V2: ("model.layers.{bid}.attention.v2",), # rwkv7 + MODEL_TENSOR.TIME_MIX_G1: ("model.layers.{bid}.attention.g1",), # rwkv7 + MODEL_TENSOR.TIME_MIX_G2: ("model.layers.{bid}.attention.g2",), # rwkv7 + MODEL_TENSOR.TIME_MIX_K_K: ("model.layers.{bid}.attention.k_k",), # rwkv7 + MODEL_TENSOR.TIME_MIX_K_A: ("model.layers.{bid}.attention.k_a",), # rwkv7 + MODEL_TENSOR.TIME_MIX_R_K: ("model.layers.{bid}.attention.r_k",), # rwkv7 MODEL_TENSOR.TIME_MIX_LERP_X: ( - "rwkv.blocks.{bid}.attention.time_maa_x", # rwkv6 + "rwkv.blocks.{bid}.attention.time_maa_x", # rwkv6 "model.layers.{bid}.self_attn.time_maa_x", # rwkv6qwen2 ), - MODEL_TENSOR.TIME_MIX_LERP_K: ( - "rwkv.blocks.{bid}.attention.time_maa_k", # rwkv6 + "rwkv.blocks.{bid}.attention.time_maa_k", # rwkv6 "model.layers.{bid}.self_attn.time_maa_k", # rwkv6qwen2 ), - MODEL_TENSOR.TIME_MIX_LERP_V: ( - "rwkv.blocks.{bid}.attention.time_maa_v", # rwkv6 + "rwkv.blocks.{bid}.attention.time_maa_v", # rwkv6 "model.layers.{bid}.self_attn.time_maa_v", # rwkv6qwen2 ), - MODEL_TENSOR.TIME_MIX_LERP_R: ( - "rwkv.blocks.{bid}.attention.time_maa_r", # rwkv6 + "rwkv.blocks.{bid}.attention.time_maa_r", # rwkv6 "model.layers.{bid}.self_attn.time_maa_r", # rwkv6qwen2 ), - MODEL_TENSOR.TIME_MIX_LERP_G: ( - "rwkv.blocks.{bid}.attention.time_maa_g", # rwkv6 + "rwkv.blocks.{bid}.attention.time_maa_g", # rwkv6 "model.layers.{bid}.self_attn.time_maa_g", # rwkv6qwen2 ), - MODEL_TENSOR.TIME_MIX_LERP_W: ( - "rwkv.blocks.{bid}.attention.time_maa_w", # rwkv6 + "rwkv.blocks.{bid}.attention.time_maa_w", # rwkv6 "model.layers.{bid}.self_attn.time_maa_w", # rwkv6qwen2 ), - MODEL_TENSOR.TIME_MIX_FIRST: ( - "rwkv.blocks.{bid}.attention.time_faaaa", # rwkv6 + "rwkv.blocks.{bid}.attention.time_faaaa", # rwkv6 ), - MODEL_TENSOR.TIME_MIX_DECAY: ( - "rwkv.blocks.{bid}.attention.time_decay", # rwkv6 + "rwkv.blocks.{bid}.attention.time_decay", # rwkv6 "model.layers.{bid}.self_attn.time_decay", # rwkv6qwen2 ), - MODEL_TENSOR.TIME_MIX_DECAY_W1: ( "rwkv.blocks.{bid}.attention.time_decay_w1", # rwkv6 - "model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2 + "model.layers.{bid}.self_attn.time_decay_w1", # rwkv6qwen2 ), - MODEL_TENSOR.TIME_MIX_DECAY_W2: ( "rwkv.blocks.{bid}.attention.time_decay_w2", # rwkv6 - "model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2 + "model.layers.{bid}.self_attn.time_decay_w2", # rwkv6qwen2 ), - MODEL_TENSOR.TIME_MIX_KEY: ( - "rwkv.blocks.{bid}.attention.key", # rwkv6 - "model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2 - "model.layers.{bid}.attention.key", # rwkv7 - "model.layers.{bid}.attention.k_proj", # rwkv7 + "rwkv.blocks.{bid}.attention.key", # rwkv6 + "model.layers.{bid}.self_attn.k_proj", # rwkv6qwen2 + "model.layers.{bid}.attention.key", # rwkv7 + "model.layers.{bid}.attention.k_proj", # rwkv7 ), - MODEL_TENSOR.TIME_MIX_VALUE: ( - "rwkv.blocks.{bid}.attention.value", # rwkv6 - "model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2 + "rwkv.blocks.{bid}.attention.value", # rwkv6 + "model.layers.{bid}.self_attn.v_proj", # rwkv6qwen2 "model.layers.{bid}.attention.value", # rwkv7 - "model.layers.{bid}.attention.v_proj", # rwkv7 + "model.layers.{bid}.attention.v_proj", # rwkv7 ), - MODEL_TENSOR.TIME_MIX_RECEPTANCE: ( "rwkv.blocks.{bid}.attention.receptance", # rwkv6 - "model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2 - "model.layers.{bid}.attention.receptance", # rwkv7 - "model.layers.{bid}.attention.r_proj", # rwkv7 + "model.layers.{bid}.self_attn.q_proj", # rwkv6qwen2 + "model.layers.{bid}.attention.receptance", # rwkv7 + "model.layers.{bid}.attention.r_proj", # rwkv7 ), - MODEL_TENSOR.TIME_MIX_GATE: ( - "rwkv.blocks.{bid}.attention.gate", # rwkv6 - "model.layers.{bid}.self_attn.gate", # rwkv6qwen2 + "rwkv.blocks.{bid}.attention.gate", # rwkv6 + "model.layers.{bid}.self_attn.gate", # rwkv6qwen2 ), - MODEL_TENSOR.TIME_MIX_LN: ( - "rwkv.blocks.{bid}.attention.ln_x", # rwkv6 - "model.layers.{bid}.attention.ln_x" # rwkv7 + "rwkv.blocks.{bid}.attention.ln_x", # rwkv6 + "model.layers.{bid}.attention.ln_x", # rwkv7 ), - MODEL_TENSOR.TIME_MIX_OUTPUT: ( "rwkv.blocks.{bid}.attention.output", # rwkv6 - "model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2 - "model.layers.{bid}.attention.output", # rwkv7 - "model.layers.{bid}.attention.o_proj", # rwkv7 + "model.layers.{bid}.self_attn.o_proj", # rwkv6qwen2 + "model.layers.{bid}.attention.output", # rwkv7 + "model.layers.{bid}.attention.o_proj", # rwkv7 ), - MODEL_TENSOR.CHANNEL_MIX_LERP_K: ( - "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv6 - "model.layers.{bid}.feed_forward.x_k", # rwkv7 + "rwkv.blocks.{bid}.feed_forward.time_maa_k", # rwkv6 + "model.layers.{bid}.feed_forward.x_k", # rwkv7 ), - MODEL_TENSOR.CHANNEL_MIX_LERP_R: ( - "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv6 + "rwkv.blocks.{bid}.feed_forward.time_maa_r", # rwkv6 ), - MODEL_TENSOR.CHANNEL_MIX_KEY: ( "rwkv.blocks.{bid}.feed_forward.key", # rwkv6 - "model.layers.{bid}.feed_forward.key", # rwkv7 + "model.layers.{bid}.feed_forward.key", # rwkv7 ), - MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: ( - "rwkv.blocks.{bid}.feed_forward.receptance", # rwkv6 + "rwkv.blocks.{bid}.feed_forward.receptance", # rwkv6 ), - MODEL_TENSOR.CHANNEL_MIX_VALUE: ( "rwkv.blocks.{bid}.feed_forward.value", # rwkv6 - "model.layers.{bid}.feed_forward.value", # rwkv7 - ), - - MODEL_TENSOR.ATTN_Q_A: ( - "model.layers.{bid}.self_attn.q_a_proj", # deepseek2 - ), - - MODEL_TENSOR.ATTN_Q_B: ( - "model.layers.{bid}.self_attn.q_b_proj", # deepseek2 + "model.layers.{bid}.feed_forward.value", # rwkv7 ), - + MODEL_TENSOR.ATTN_Q_A: ("model.layers.{bid}.self_attn.q_a_proj",), # deepseek2 + MODEL_TENSOR.ATTN_Q_B: ("model.layers.{bid}.self_attn.q_b_proj",), # deepseek2 MODEL_TENSOR.ATTN_KV_A_MQA: ( - "model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2 + "model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2 ), - MODEL_TENSOR.ATTN_KV_B: ( - "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 - ), - - MODEL_TENSOR.ATTN_K_B: ( - "model.layers.{bid}.self_attn.k_b_proj", # deepseek2 - ), - - MODEL_TENSOR.ATTN_V_B: ( - "model.layers.{bid}.self_attn.v_b_proj", # deepseek2 + "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2 ), - + MODEL_TENSOR.ATTN_K_B: ("model.layers.{bid}.self_attn.k_b_proj",), # deepseek2 + MODEL_TENSOR.ATTN_V_B: ("model.layers.{bid}.self_attn.v_b_proj",), # deepseek2 MODEL_TENSOR.ATTN_Q_A_NORM: ( - "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 + "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2 ), - MODEL_TENSOR.ATTN_KV_A_NORM: ( - "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2 + "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2 ), - MODEL_TENSOR.ATTN_SUB_NORM: ( "model.layers.{bid}.self_attn.inner_attn_ln", # bitnet ), - - MODEL_TENSOR.FFN_SUB_NORM: ( - "model.layers.{bid}.mlp.ffn_layernorm", # bitnet - ), - - MODEL_TENSOR.DEC_ATTN_NORM: ( - "decoder.block.{bid}.layer.0.layer_norm", # t5 - ), - - MODEL_TENSOR.DEC_ATTN_Q: ( - "decoder.block.{bid}.layer.0.SelfAttention.q", # t5 - ), - - MODEL_TENSOR.DEC_ATTN_K: ( - "decoder.block.{bid}.layer.0.SelfAttention.k", # t5 - ), - - MODEL_TENSOR.DEC_ATTN_V: ( - "decoder.block.{bid}.layer.0.SelfAttention.v", # t5 - ), - + MODEL_TENSOR.FFN_SUB_NORM: ("model.layers.{bid}.mlp.ffn_layernorm",), # bitnet + MODEL_TENSOR.DEC_ATTN_NORM: ("decoder.block.{bid}.layer.0.layer_norm",), # t5 + MODEL_TENSOR.DEC_ATTN_Q: ("decoder.block.{bid}.layer.0.SelfAttention.q",), # t5 + MODEL_TENSOR.DEC_ATTN_K: ("decoder.block.{bid}.layer.0.SelfAttention.k",), # t5 + MODEL_TENSOR.DEC_ATTN_V: ("decoder.block.{bid}.layer.0.SelfAttention.v",), # t5 MODEL_TENSOR.DEC_ATTN_OUT: ( - "decoder.block.{bid}.layer.0.SelfAttention.o", # t5 + "decoder.block.{bid}.layer.0.SelfAttention.o", # t5 ), - MODEL_TENSOR.DEC_ATTN_REL_B: ( - "decoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5 + "decoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5 ), - MODEL_TENSOR.DEC_CROSS_ATTN_NORM: ( - "decoder.block.{bid}.layer.1.layer_norm", # t5 + "decoder.block.{bid}.layer.1.layer_norm", # t5 ), - MODEL_TENSOR.DEC_CROSS_ATTN_Q: ( - "decoder.block.{bid}.layer.1.EncDecAttention.q", # t5 + "decoder.block.{bid}.layer.1.EncDecAttention.q", # t5 ), - MODEL_TENSOR.DEC_CROSS_ATTN_K: ( - "decoder.block.{bid}.layer.1.EncDecAttention.k", # t5 + "decoder.block.{bid}.layer.1.EncDecAttention.k", # t5 ), - MODEL_TENSOR.DEC_CROSS_ATTN_V: ( - "decoder.block.{bid}.layer.1.EncDecAttention.v", # t5 + "decoder.block.{bid}.layer.1.EncDecAttention.v", # t5 ), - MODEL_TENSOR.DEC_CROSS_ATTN_OUT: ( - "decoder.block.{bid}.layer.1.EncDecAttention.o", # t5 + "decoder.block.{bid}.layer.1.EncDecAttention.o", # t5 ), - MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: ( - "decoder.block.{bid}.layer.1.EncDecAttention.relative_attention_bias", # t5 - ), - - MODEL_TENSOR.DEC_FFN_NORM: ( - "decoder.block.{bid}.layer.2.layer_norm", # t5 + "decoder.block.{bid}.layer.1.EncDecAttention.relative_attention_bias", # t5 ), - + MODEL_TENSOR.DEC_FFN_NORM: ("decoder.block.{bid}.layer.2.layer_norm",), # t5 MODEL_TENSOR.DEC_FFN_GATE: ( - "decoder.block.{bid}.layer.2.DenseReluDense.wi_0", # flan-t5 + "decoder.block.{bid}.layer.2.DenseReluDense.wi_0", # flan-t5 ), - MODEL_TENSOR.DEC_FFN_UP: ( - "decoder.block.{bid}.layer.2.DenseReluDense.wi", # t5 - "decoder.block.{bid}.layer.2.DenseReluDense.wi_1", # flan-t5 + "decoder.block.{bid}.layer.2.DenseReluDense.wi", # t5 + "decoder.block.{bid}.layer.2.DenseReluDense.wi_1", # flan-t5 ), - MODEL_TENSOR.DEC_FFN_DOWN: ( - "decoder.block.{bid}.layer.2.DenseReluDense.wo", # t5 + "decoder.block.{bid}.layer.2.DenseReluDense.wo", # t5 ), - - MODEL_TENSOR.DEC_OUTPUT_NORM: ( - "decoder.final_layer_norm", # t5 - ), - - MODEL_TENSOR.ENC_ATTN_NORM: ( - "encoder.block.{bid}.layer.0.layer_norm", # t5 - ), - - MODEL_TENSOR.ENC_ATTN_Q: ( - "encoder.block.{bid}.layer.0.SelfAttention.q", # t5 - ), - - MODEL_TENSOR.ENC_ATTN_K: ( - "encoder.block.{bid}.layer.0.SelfAttention.k", # t5 - ), - - MODEL_TENSOR.ENC_ATTN_V: ( - "encoder.block.{bid}.layer.0.SelfAttention.v", # t5 - ), - + MODEL_TENSOR.DEC_OUTPUT_NORM: ("decoder.final_layer_norm",), # t5 + MODEL_TENSOR.ENC_ATTN_NORM: ("encoder.block.{bid}.layer.0.layer_norm",), # t5 + MODEL_TENSOR.ENC_ATTN_Q: ("encoder.block.{bid}.layer.0.SelfAttention.q",), # t5 + MODEL_TENSOR.ENC_ATTN_K: ("encoder.block.{bid}.layer.0.SelfAttention.k",), # t5 + MODEL_TENSOR.ENC_ATTN_V: ("encoder.block.{bid}.layer.0.SelfAttention.v",), # t5 MODEL_TENSOR.ENC_ATTN_OUT: ( - "encoder.block.{bid}.layer.0.SelfAttention.o", # t5 + "encoder.block.{bid}.layer.0.SelfAttention.o", # t5 ), - MODEL_TENSOR.ENC_ATTN_REL_B: ( - "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5 - ), - - MODEL_TENSOR.ENC_FFN_NORM: ( - "encoder.block.{bid}.layer.1.layer_norm", # t5 + "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias", # t5 ), - + MODEL_TENSOR.ENC_FFN_NORM: ("encoder.block.{bid}.layer.1.layer_norm",), # t5 MODEL_TENSOR.ENC_FFN_GATE: ( - "encoder.block.{bid}.layer.1.DenseReluDense.wi_0", # flan-t5 + "encoder.block.{bid}.layer.1.DenseReluDense.wi_0", # flan-t5 ), - MODEL_TENSOR.ENC_FFN_UP: ( - "encoder.block.{bid}.layer.1.DenseReluDense.wi", # t5 - "encoder.block.{bid}.layer.1.DenseReluDense.wi_1", # flan-t5 + "encoder.block.{bid}.layer.1.DenseReluDense.wi", # t5 + "encoder.block.{bid}.layer.1.DenseReluDense.wi_1", # flan-t5 ), - MODEL_TENSOR.ENC_FFN_DOWN: ( - "encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5 + "encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5 ), - ############################################################################ # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg - MODEL_TENSOR.ENC_OUTPUT_NORM: ( - "encoder.final_layer_norm", # t5 - ), - + MODEL_TENSOR.ENC_OUTPUT_NORM: ("encoder.final_layer_norm",), # t5 MODEL_TENSOR.CLS: ( - "classifier", # jina - "classifier.dense", # roberta - "pre_classifier", # distillbert - ), - - MODEL_TENSOR.CLS_OUT: ( - "classifier.out_proj", # roberta + "classifier", # jina + "classifier.dense", # roberta + "pre_classifier", # distillbert ), + MODEL_TENSOR.CLS_OUT: ("classifier.out_proj",), # roberta ############################################################################# - - MODEL_TENSOR.CONVNEXT_DW: ( - "backbone.convnext.{bid}.dwconv", # wavtokenizer - ), - - MODEL_TENSOR.CONVNEXT_NORM: ( - "backbone.convnext.{bid}.norm", # wavtokenizer - ), - - MODEL_TENSOR.CONVNEXT_PW1: ( - "backbone.convnext.{bid}.pwconv1", # wavtokenizer - ), - - MODEL_TENSOR.CONVNEXT_PW2: ( - "backbone.convnext.{bid}.pwconv2", # wavtokenizer - ), - - MODEL_TENSOR.CONVNEXT_GAMMA: ( - "backbone.convnext.{bid}.gamma", # wavtokenizer - ), - - MODEL_TENSOR.POSNET_CONV1: ( - "backbone.posnet.{bid}.conv1", # wavtokenizer - ), - - MODEL_TENSOR.POSNET_CONV2: ( - "backbone.posnet.{bid}.conv2", # wavtokenizer - ), - - MODEL_TENSOR.POSNET_NORM: ( - "backbone.posnet.{bid}.norm", # wavtokenizer - ), - - MODEL_TENSOR.POSNET_NORM1: ( - "backbone.posnet.{bid}.norm1", # wavtokenizer - ), - - MODEL_TENSOR.POSNET_NORM2: ( - "backbone.posnet.{bid}.norm2", # wavtokenizer - ), - - MODEL_TENSOR.POSNET_ATTN_NORM: ( - "backbone.posnet.{bid}.norm", # wavtokenizer - ), - - MODEL_TENSOR.POSNET_ATTN_Q: ( - "backbone.posnet.{bid}.q", # wavtokenizer - ), - - MODEL_TENSOR.POSNET_ATTN_K: ( - "backbone.posnet.{bid}.k", # wavtokenizer - ), - - MODEL_TENSOR.POSNET_ATTN_V: ( - "backbone.posnet.{bid}.v", # wavtokenizer - ), - + MODEL_TENSOR.CONVNEXT_DW: ("backbone.convnext.{bid}.dwconv",), # wavtokenizer + MODEL_TENSOR.CONVNEXT_NORM: ("backbone.convnext.{bid}.norm",), # wavtokenizer + MODEL_TENSOR.CONVNEXT_PW1: ("backbone.convnext.{bid}.pwconv1",), # wavtokenizer + MODEL_TENSOR.CONVNEXT_PW2: ("backbone.convnext.{bid}.pwconv2",), # wavtokenizer + MODEL_TENSOR.CONVNEXT_GAMMA: ("backbone.convnext.{bid}.gamma",), # wavtokenizer + MODEL_TENSOR.POSNET_CONV1: ("backbone.posnet.{bid}.conv1",), # wavtokenizer + MODEL_TENSOR.POSNET_CONV2: ("backbone.posnet.{bid}.conv2",), # wavtokenizer + MODEL_TENSOR.POSNET_NORM: ("backbone.posnet.{bid}.norm",), # wavtokenizer + MODEL_TENSOR.POSNET_NORM1: ("backbone.posnet.{bid}.norm1",), # wavtokenizer + MODEL_TENSOR.POSNET_NORM2: ("backbone.posnet.{bid}.norm2",), # wavtokenizer + MODEL_TENSOR.POSNET_ATTN_NORM: ("backbone.posnet.{bid}.norm",), # wavtokenizer + MODEL_TENSOR.POSNET_ATTN_Q: ("backbone.posnet.{bid}.q",), # wavtokenizer + MODEL_TENSOR.POSNET_ATTN_K: ("backbone.posnet.{bid}.k",), # wavtokenizer + MODEL_TENSOR.POSNET_ATTN_V: ("backbone.posnet.{bid}.v",), # wavtokenizer MODEL_TENSOR.POSNET_ATTN_OUT: ( - "backbone.posnet.{bid}.proj_out", # wavtokenizer + "backbone.posnet.{bid}.proj_out", # wavtokenizer ), - ############################################################################# ## Vision encoder - MODEL_TENSOR.V_MMPROJ: ( "multi_modal_projector.linear_{bid}", - "visual.merger.mlp.{bid}", # qwen2vl + "visual.merger.mlp.{bid}", # qwen2vl ), - MODEL_TENSOR.V_MMPROJ_FC: ( - "model.connector.modality_projection.proj", # SmolVLM + "model.connector.modality_projection.proj", # SmolVLM ), - MODEL_TENSOR.V_MMPROJ_MLP: ( "model.mm_projector.mlp.mlp.{bid}", - "vision_model.vision_adapter.mlp.fc{bid}", # llama 4 - "mlp1.{bid}", # InternVL - ), - - MODEL_TENSOR.V_MMPROJ_PEG: ( - "model.mm_projector.peg.peg.{bid}", + "vision_model.vision_adapter.mlp.fc{bid}", # llama 4 + "mlp1.{bid}", # InternVL ), - + MODEL_TENSOR.V_MMPROJ_PEG: ("model.mm_projector.peg.peg.{bid}",), MODEL_TENSOR.V_ENC_EMBD_CLS: ( "vision_tower.vision_model.embeddings.class_embedding", - "vision_model.class_embedding", # llama 4 + "vision_model.class_embedding", # llama 4 ), - MODEL_TENSOR.V_ENC_EMBD_PATCH: ( "vision_tower.vision_model.embeddings.patch_embedding", "vpm.embeddings.patch_embedding", - "model.vision_model.embeddings.patch_embedding", # SmolVLM - "vision_tower.patch_conv", # pixtral - "vision_model.patch_embedding.linear", # llama 4 - "visual.patch_embed.proj", # qwen2vl + "model.vision_model.embeddings.patch_embedding", # SmolVLM + "vision_tower.patch_conv", # pixtral + "vision_model.patch_embedding.linear", # llama 4 + "visual.patch_embed.proj", # qwen2vl ), - MODEL_TENSOR.V_ENC_EMBD_POS: ( "vision_tower.vision_model.embeddings.position_embedding", "vpm.embeddings.position_embedding", - "model.vision_model.embeddings.position_embedding", # SmolVLM - "vision_model.positional_embedding_vlm", # llama 4 + "model.vision_model.embeddings.position_embedding", # SmolVLM + "vision_model.positional_embedding_vlm", # llama 4 ), - MODEL_TENSOR.V_ENC_ATTN_Q: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj", "vpm.encoder.layers.{bid}.self_attn.q_proj", - "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM - "vision_model.model.layers.{bid}.self_attn.q_proj", # llama4 - "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral - "visual.blocks.{bid}.attn.q", # qwen2vl, generated + "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.q_proj", # llama4 + "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral + "visual.blocks.{bid}.attn.q", # qwen2vl, generated ), - MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( - "vision_tower.vision_model.encoder.layers.{bid}.attn.q_norm", # InternVL + "vision_tower.vision_model.encoder.layers.{bid}.attn.q_norm", # InternVL ), - MODEL_TENSOR.V_ENC_ATTN_K: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj", "vpm.encoder.layers.{bid}.self_attn.k_proj", - "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM - "vision_model.model.layers.{bid}.self_attn.k_proj", # llama4 - "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral - "visual.blocks.{bid}.attn.k", # qwen2vl, generated + "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.k_proj", # llama4 + "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral + "visual.blocks.{bid}.attn.k", # qwen2vl, generated ), - MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( - "vision_tower.vision_model.encoder.layers.{bid}.attn.k_norm", # InternVL + "vision_tower.vision_model.encoder.layers.{bid}.attn.k_norm", # InternVL ), - MODEL_TENSOR.V_ENC_ATTN_V: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj", "vpm.encoder.layers.{bid}.self_attn.v_proj", - "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM - "vision_model.model.layers.{bid}.self_attn.v_proj", # llama4 - "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral - "visual.blocks.{bid}.attn.v", # qwen2vl, generated + "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.v_proj", # llama4 + "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral + "visual.blocks.{bid}.attn.v", # qwen2vl, generated ), - MODEL_TENSOR.V_ENC_INPUT_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1", - "vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL + "vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL "vpm.encoder.layers.{bid}.layer_norm1", - "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM - "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral - "vision_model.model.layers.{bid}.input_layernorm", # llama4 - "visual.blocks.{bid}.norm1", # qwen2vl + "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM + "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral + "vision_model.model.layers.{bid}.input_layernorm", # llama4 + "visual.blocks.{bid}.norm1", # qwen2vl ), - MODEL_TENSOR.V_ENC_ATTN_O: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj", - "vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL + "vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL "vpm.encoder.layers.{bid}.self_attn.out_proj", - "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM - "vision_model.model.layers.{bid}.self_attn.o_proj", # llama4 - "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral - "visual.blocks.{bid}.attn.proj", # qwen2vl + "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM + "vision_model.model.layers.{bid}.self_attn.o_proj", # llama4 + "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral + "visual.blocks.{bid}.attn.proj", # qwen2vl ), - MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", - "vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL + "vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL "vpm.encoder.layers.{bid}.layer_norm2", - "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM - "vision_model.model.layers.{bid}.post_attention_layernorm", # llama4 - "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral - "visual.blocks.{bid}.norm2", # qwen2vl + "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM + "vision_model.model.layers.{bid}.post_attention_layernorm", # llama4 + "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral + "visual.blocks.{bid}.norm2", # qwen2vl ), - MODEL_TENSOR.V_ENC_FFN_UP: ( "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1", "vpm.encoder.layers.{bid}.mlp.fc1", - "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 - "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral - "vision_model.model.layers.{bid}.mlp.fc1", # llama4 - "visual.blocks.{bid}.mlp.fc1", # qwen2vl - "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl + "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3 + "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral + "vision_model.model.layers.{bid}.mlp.fc1", # llama4 + "visual.blocks.{bid}.mlp.fc1", # qwen2vl + "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl ), - MODEL_TENSOR.V_ENC_FFN_GATE: ( - "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral - "visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl + "vision_tower.transformer.layers.{bid}.feed_forward.gate_proj", # pixtral + "visual.blocks.{bid}.mlp.gate_proj", # qwen2.5vl ), - MODEL_TENSOR.V_ENC_FFN_DOWN: ( "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2", "vpm.encoder.layers.{bid}.mlp.fc2", - "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 - "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral - "vision_model.model.layers.{bid}.mlp.fc2", # llama4 - "visual.blocks.{bid}.mlp.fc2", # qwen2vl - "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl + "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3 + "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral + "vision_model.model.layers.{bid}.mlp.fc2", # llama4 + "visual.blocks.{bid}.mlp.fc2", # qwen2vl + "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl ), - MODEL_TENSOR.V_LAYER_SCALE_1: ( - "vision_tower.vision_model.encoder.layers.{bid}.ls1", # InternVL + "vision_tower.vision_model.encoder.layers.{bid}.ls1", # InternVL ), - MODEL_TENSOR.V_LAYER_SCALE_2: ( - "vision_tower.vision_model.encoder.layers.{bid}.ls2", # InternVL + "vision_tower.vision_model.encoder.layers.{bid}.ls2", # InternVL ), - MODEL_TENSOR.V_PRE_NORM: ( "vision_tower.vision_model.pre_layrnorm", - "vision_tower.ln_pre", # pixtral - "vision_model.layernorm_pre", # llama4 + "vision_tower.ln_pre", # pixtral + "vision_model.layernorm_pre", # llama4 ), - MODEL_TENSOR.V_POST_NORM: ( "vision_tower.vision_model.post_layernorm", - "model.vision_model.post_layernorm", # SmolVLM - "vision_model.layernorm_post", # llama4 - "visual.merger.ln_q", # qwen2vl - ), - - MODEL_TENSOR.V_MM_INP_PROJ: ( - "multi_modal_projector.mm_input_projection", - ), - - MODEL_TENSOR.V_MM_INP_NORM: ( - "multi_modal_projector.norm", - ), - - MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ( - "multi_modal_projector.mm_soft_emb_norm", - ), - - MODEL_TENSOR.V_RESMPL_POS_EMBD_K: ( - "resampler.pos_embed_k", - ), - + "model.vision_model.post_layernorm", # SmolVLM + "vision_model.layernorm_post", # llama4 + "visual.merger.ln_q", # qwen2vl + ), + MODEL_TENSOR.V_MM_INP_PROJ: ("multi_modal_projector.mm_input_projection",), + MODEL_TENSOR.V_MM_INP_NORM: ("multi_modal_projector.norm",), + MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ("multi_modal_projector.mm_soft_emb_norm",), + MODEL_TENSOR.V_RESMPL_POS_EMBD_K: ("resampler.pos_embed_k",), MODEL_TENSOR.V_RESMPL_ATTN_Q: ( - "resampler.attn.in_proj_q", # tensor generated from resampler.attn.in_proj + "resampler.attn.in_proj_q", # tensor generated from resampler.attn.in_proj ), - MODEL_TENSOR.V_RESMPL_ATTN_K: ( - "resampler.attn.in_proj_k", # tensor generated from resampler.attn.in_proj + "resampler.attn.in_proj_k", # tensor generated from resampler.attn.in_proj ), - MODEL_TENSOR.V_RESMPL_ATTN_V: ( - "resampler.attn.in_proj_v", # tensor generated from resampler.attn.in_proj - ), - - MODEL_TENSOR.V_RESMPL_ATTN_OUT: ( - "resampler.attn.out_proj", - ), - - MODEL_TENSOR.V_RESMPL_KV: ( - "resampler.kv_proj", - ), - - MODEL_TENSOR.V_RESMPL_POST_NORM: ( - "resampler.ln_post", - ), - - MODEL_TENSOR.V_RESMPL_KV_NORM: ( - "resampler.ln_kv", - ), - - MODEL_TENSOR.V_RESMPL_Q_NORM: ( - "resampler.ln_q", - ), - - MODEL_TENSOR.V_RESMPL_PROJ: ( - "resampler.proj", - ), - - MODEL_TENSOR.V_RESMPL_QUERY: ( - "resampler.query", - ), - + "resampler.attn.in_proj_v", # tensor generated from resampler.attn.in_proj + ), + MODEL_TENSOR.V_RESMPL_ATTN_OUT: ("resampler.attn.out_proj",), + MODEL_TENSOR.V_RESMPL_KV: ("resampler.kv_proj",), + MODEL_TENSOR.V_RESMPL_POST_NORM: ("resampler.ln_post",), + MODEL_TENSOR.V_RESMPL_KV_NORM: ("resampler.ln_kv",), + MODEL_TENSOR.V_RESMPL_Q_NORM: ("resampler.ln_q",), + MODEL_TENSOR.V_RESMPL_PROJ: ("resampler.proj",), + MODEL_TENSOR.V_RESMPL_QUERY: ("resampler.query",), MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: ( - "v.token_embd.img_break", # for pixtral, this is a generated vector + "v.token_embd.img_break", # for pixtral, this is a generated vector ), - MODEL_TENSOR.V_MM_PATCH_MERGER: ( - "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1 + "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1 ), - # audio (mtmd) - - MODEL_TENSOR.A_ENC_EMBD_POS: ( - "audio_tower.embed_positions", # ultravox - ), - - MODEL_TENSOR.A_ENC_CONV1D: ( - "audio_tower.conv{bid}", # ultravox - ), - + MODEL_TENSOR.A_ENC_EMBD_POS: ("audio_tower.embed_positions",), # ultravox + MODEL_TENSOR.A_ENC_CONV1D: ("audio_tower.conv{bid}",), # ultravox MODEL_TENSOR.A_PRE_NORM: (), - MODEL_TENSOR.A_POST_NORM: ( - "audio_tower.layer_norm", # ultravox - "audio_tower.ln_post", # qwen2omni + "audio_tower.layer_norm", # ultravox + "audio_tower.ln_post", # qwen2omni ), - MODEL_TENSOR.A_ENC_ATTN_Q: ( - "audio_tower.layers.{bid}.self_attn.q_proj", # ultravox + "audio_tower.layers.{bid}.self_attn.q_proj", # ultravox ), - MODEL_TENSOR.A_ENC_ATTN_K: ( - "audio_tower.layers.{bid}.self_attn.k_proj", # ultravox + "audio_tower.layers.{bid}.self_attn.k_proj", # ultravox ), - MODEL_TENSOR.A_ENC_ATTN_V: ( - "audio_tower.layers.{bid}.self_attn.v_proj", # ultravox + "audio_tower.layers.{bid}.self_attn.v_proj", # ultravox ), - MODEL_TENSOR.A_ENC_INPUT_NORM: ( - "audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox + "audio_tower.layers.{bid}.self_attn_layer_norm", # ultravox ), - MODEL_TENSOR.A_ENC_OUTPUT: ( - "audio_tower.layers.{bid}.self_attn.out_proj", # ultravox + "audio_tower.layers.{bid}.self_attn.out_proj", # ultravox ), - MODEL_TENSOR.A_ENC_OUTPUT_NORM: ( - "audio_tower.layers.{bid}.final_layer_norm", # ultravox - ), - - MODEL_TENSOR.A_ENC_FFN_UP: ( - "audio_tower.layers.{bid}.fc1", # ultravox + "audio_tower.layers.{bid}.final_layer_norm", # ultravox ), - + MODEL_TENSOR.A_ENC_FFN_UP: ("audio_tower.layers.{bid}.fc1",), # ultravox MODEL_TENSOR.A_ENC_FFN_GATE: (), - - MODEL_TENSOR.A_ENC_FFN_DOWN: ( - "audio_tower.layers.{bid}.fc2", # ultravox - ), - + MODEL_TENSOR.A_ENC_FFN_DOWN: ("audio_tower.layers.{bid}.fc2",), # ultravox # note: some tensors below has "audio." pseudo-prefix, to prevent conflicts with vision tensors # this prefix is added in the conversion code in modify_tensors() - MODEL_TENSOR.A_MMPROJ: ( - "audio.multi_modal_projector.linear_{bid}", # ultravox + "audio.multi_modal_projector.linear_{bid}", # ultravox ), - MODEL_TENSOR.A_MMPROJ_FC: ( - "audio.multi_modal_projector.linear", # qwen2audio - "audio_tower.proj", # qwen2omni - ), - - MODEL_TENSOR.A_MM_NORM_PRE: ( - "audio.multi_modal_projector.ln_pre", # ultravox + "audio.multi_modal_projector.linear", # qwen2audio + "audio_tower.proj", # qwen2omni ), - - MODEL_TENSOR.A_MM_NORM_MID: ( - "audio.multi_modal_projector.ln_mid", # ultravox - ), - + MODEL_TENSOR.A_MM_NORM_PRE: ("audio.multi_modal_projector.ln_pre",), # ultravox + MODEL_TENSOR.A_MM_NORM_MID: ("audio.multi_modal_projector.ln_mid",), # ultravox # -- PowerInfer - MODEL_TENSOR.LMHEAD_PROFILER:{ + MODEL_TENSOR.LMHEAD_PROFILER: { "model.lm_head_profiler", - } + }, # -- PowerInfer end } # architecture-specific block mappings arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = { MODEL_ARCH.ARCTIC: { - MODEL_TENSOR.FFN_NORM: ( - "model.layers.{bid}.residual_layernorm", - ), - MODEL_TENSOR.FFN_NORM_EXP: ( - "model.layers.{bid}.post_attention_layernorm", - ), + MODEL_TENSOR.FFN_NORM: ("model.layers.{bid}.residual_layernorm",), + MODEL_TENSOR.FFN_NORM_EXP: ("model.layers.{bid}.post_attention_layernorm",), }, } @@ -1236,31 +911,35 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int): if tensor not in MODEL_TENSORS[arch]: continue - tensor_name = TENSOR_NAMES[tensor].format(bid = bid) + tensor_name = TENSOR_NAMES[tensor].format(bid=bid) self.mapping[tensor_name] = (tensor, tensor_name) for key in keys: - key = key.format(bid = bid) + key = key.format(bid=bid) self.mapping[key] = (tensor, tensor_name) - def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None: + def get_type_and_name( + self, key: str, try_suffixes: Sequence[str] = () + ) -> tuple[MODEL_TENSOR, str] | None: result = self.mapping.get(key) if result is not None: return result for suffix in try_suffixes: if key.endswith(suffix): - result = self.mapping.get(key[:-len(suffix)]) + result = self.mapping.get(key[: -len(suffix)]) if result is not None: return result[0], result[1] + suffix return None def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None: - result = self.get_type_and_name(key, try_suffixes = try_suffixes) + result = self.get_type_and_name(key, try_suffixes=try_suffixes) if result is None: return None return result[1] - def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None: - result = self.get_type_and_name(key, try_suffixes = try_suffixes) + def get_type( + self, key: str, try_suffixes: Sequence[str] = () + ) -> MODEL_TENSOR | None: + result = self.get_type_and_name(key, try_suffixes=try_suffixes) if result is None: return None return result[0] diff --git a/smallthinker/gguf-py/gguf/utility.py b/smallthinker/gguf-py/gguf/utility.py index 00adcbc9..d297b4fa 100644 --- a/smallthinker/gguf-py/gguf/utility.py +++ b/smallthinker/gguf-py/gguf/utility.py @@ -11,21 +11,27 @@ def fill_templated_filename(filename: str, output_type: str | None) -> str: # Given a file name fill in any type templates e.g. 'some-model-name.{ftype}.gguf' ftype_lowercase: str = output_type.lower() if output_type is not None else "" ftype_uppercase: str = output_type.upper() if output_type is not None else "" - return filename.format(ftype_lowercase, - outtype=ftype_lowercase, ftype=ftype_lowercase, - OUTTYPE=ftype_uppercase, FTYPE=ftype_uppercase) - - -def model_weight_count_rounded_notation(model_params_count: int, min_digits: int = 2) -> str: - if model_params_count > 1e12 : + return filename.format( + ftype_lowercase, + outtype=ftype_lowercase, + ftype=ftype_lowercase, + OUTTYPE=ftype_uppercase, + FTYPE=ftype_uppercase, + ) + + +def model_weight_count_rounded_notation( + model_params_count: int, min_digits: int = 2 +) -> str: + if model_params_count > 1e12: # Trillions Of Parameters scaled_model_params = model_params_count * 1e-12 scale_suffix = "T" - elif model_params_count > 1e9 : + elif model_params_count > 1e9: # Billions Of Parameters scaled_model_params = model_params_count * 1e-9 scale_suffix = "B" - elif model_params_count > 1e6 : + elif model_params_count > 1e6: # Millions Of Parameters scaled_model_params = model_params_count * 1e-6 scale_suffix = "M" @@ -34,39 +40,65 @@ def model_weight_count_rounded_notation(model_params_count: int, min_digits: int scaled_model_params = model_params_count * 1e-3 scale_suffix = "K" - fix = max(min_digits - len(str(round(scaled_model_params)).lstrip('0')), 0) + fix = max(min_digits - len(str(round(scaled_model_params)).lstrip("0")), 0) return f"{scaled_model_params:.{fix}f}{scale_suffix}" -def size_label(total_params: int, shared_params: int, expert_params: int, expert_count: int) -> str: +def size_label( + total_params: int, shared_params: int, expert_params: int, expert_count: int +) -> str: if expert_count > 0: - pretty_size = model_weight_count_rounded_notation(abs(shared_params) + abs(expert_params), min_digits=2) + pretty_size = model_weight_count_rounded_notation( + abs(shared_params) + abs(expert_params), min_digits=2 + ) size_class = f"{expert_count}x{pretty_size}" else: - size_class = model_weight_count_rounded_notation(abs(total_params), min_digits=2) + size_class = model_weight_count_rounded_notation( + abs(total_params), min_digits=2 + ) return size_class -def naming_convention(model_name: str | None, base_name: str | None, finetune_string: str | None, version_string: str | None, size_label: str | None, output_type: str | None, model_type: Literal['vocab', 'LoRA'] | None = None) -> str: +def naming_convention( + model_name: str | None, + base_name: str | None, + finetune_string: str | None, + version_string: str | None, + size_label: str | None, + output_type: str | None, + model_type: Literal["vocab", "LoRA"] | None = None, +) -> str: # Reference: https://github.com/ggml-org/ggml/blob/master/docs/gguf.md#gguf-naming-convention if base_name is not None: - name = base_name.strip().replace(' ', '-').replace('/', '-') + name = base_name.strip().replace(" ", "-").replace("/", "-") elif model_name is not None: - name = model_name.strip().replace(' ', '-').replace('/', '-') + name = model_name.strip().replace(" ", "-").replace("/", "-") else: name = "ggml-model" parameters = f"-{size_label}" if size_label is not None else "" - finetune = f"-{finetune_string.strip().replace(' ', '-')}" if finetune_string is not None else "" + finetune = ( + f"-{finetune_string.strip().replace(' ', '-')}" + if finetune_string is not None + else "" + ) - version = f"-{version_string.strip().replace(' ', '-')}" if version_string is not None else "" + version = ( + f"-{version_string.strip().replace(' ', '-')}" + if version_string is not None + else "" + ) - encoding = f"-{output_type.strip().replace(' ', '-').upper()}" if output_type is not None else "" + encoding = ( + f"-{output_type.strip().replace(' ', '-').upper()}" + if output_type is not None + else "" + ) kind = f"-{model_type.strip().replace(' ', '-')}" if model_type is not None else "" @@ -84,7 +116,11 @@ class RemoteTensor: def data(self) -> bytearray: # TODO: handle request errors (maybe with limited retries?) # NOTE: using a bytearray, otherwise PyTorch complains the buffer is not writeable - data = bytearray(SafetensorRemote.get_data_by_range(url=self.url, start=self.offset_start, size=self.size)) + data = bytearray( + SafetensorRemote.get_data_by_range( + url=self.url, start=self.offset_start, size=self.size + ) + ) return data @@ -108,7 +144,7 @@ class SafetensorRemote: """ BASE_DOMAIN = "https://huggingface.co" - ALIGNMENT = 8 # bytes + ALIGNMENT = 8 # bytes @classmethod def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]: @@ -119,24 +155,30 @@ def get_list_tensors_hf_model(cls, model_id: str) -> dict[str, RemoteTensor]: Each tensor is represented as a tuple of (dtype, shape, offset_start, size, remote_safetensor_url) """ # case 1: model has only one single model.safetensor file - is_single_file = cls.check_file_exist(f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors") + is_single_file = cls.check_file_exist( + f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors" + ) if is_single_file: url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors" return cls.get_list_tensors(url) # case 2: model has multiple files - index_url = f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json" + index_url = ( + f"{cls.BASE_DOMAIN}/{model_id}/resolve/main/model.safetensors.index.json" + ) is_multiple_files = cls.check_file_exist(index_url) if is_multiple_files: # read the index file index_data = cls.get_data_by_range(index_url, 0) - index_str = index_data.decode('utf-8') + index_str = index_data.decode("utf-8") index_json = json.loads(index_str) - assert index_json.get("weight_map") is not None, "weight_map not found in index file" + assert ( + index_json.get("weight_map") is not None + ), "weight_map not found in index file" weight_map = index_json["weight_map"] # get the list of files all_files = list(set(weight_map.values())) - all_files.sort() # make sure we load shard files in order + all_files.sort() # make sure we load shard files in order # get the list of tensors tensors: dict[str, RemoteTensor] = {} for file in all_files: @@ -169,9 +211,17 @@ def get_list_tensors(cls, url: str) -> dict[str, RemoteTensor]: offset_start_relative, offset_end_relative = meta["data_offsets"] size = offset_end_relative - offset_start_relative offset_start = data_start_offset + offset_start_relative - res[name] = RemoteTensor(dtype=dtype, shape=tuple(shape), offset_start=offset_start, size=size, url=url) + res[name] = RemoteTensor( + dtype=dtype, + shape=tuple(shape), + offset_start=offset_start, + size=size, + url=url, + ) except KeyError as e: - raise ValueError(f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}") + raise ValueError( + f"Missing key in metadata for tensor '{name}': {e}, meta = {meta}" + ) return res @@ -190,7 +240,7 @@ def get_metadata(cls, url: str) -> tuple[dict, int]: # First 8 bytes contain the metadata length as u64 little-endian if len(raw_data) < 8: raise ValueError("Not enough data to read metadata size") - metadata_length = int.from_bytes(raw_data[:8], byteorder='little') + metadata_length = int.from_bytes(raw_data[:8], byteorder="little") # Calculate the data start offset data_start_offset = 8 + metadata_length @@ -200,11 +250,13 @@ def get_metadata(cls, url: str) -> tuple[dict, int]: # Check if we have enough data to read the metadata if len(raw_data) < 8 + metadata_length: - raise ValueError(f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}") + raise ValueError( + f"Could not read complete metadata. Need {8 + metadata_length} bytes, got {len(raw_data)}" + ) # Extract metadata bytes and parse as JSON - metadata_bytes = raw_data[8:8 + metadata_length] - metadata_str = metadata_bytes.decode('utf-8') + metadata_bytes = raw_data[8 : 8 + metadata_length] + metadata_str = metadata_bytes.decode("utf-8") try: metadata = json.loads(metadata_str) return metadata, data_start_offset diff --git a/smallthinker/gguf-py/gguf/vocab.py b/smallthinker/gguf-py/gguf/vocab.py index cca09798..7db9b37d 100644 --- a/smallthinker/gguf-py/gguf/vocab.py +++ b/smallthinker/gguf-py/gguf/vocab.py @@ -5,7 +5,16 @@ import json import os from pathlib import Path -from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable +from typing import ( + Any, + Callable, + Sequence, + Mapping, + Iterable, + Protocol, + ClassVar, + runtime_checkable, +) from sentencepiece import SentencePieceProcessor @@ -23,7 +32,9 @@ class SpecialVocab: chat_template: str | Sequence[Mapping[str, str]] | None def __init__( - self, path: str | os.PathLike[str], load_merges: bool = False, + self, + path: str | os.PathLike[str], + load_merges: bool = False, special_token_types: Iterable[str] | None = None, n_vocab: int | None = None, ): @@ -36,40 +47,60 @@ def __init__( if special_token_types is not None: self.special_token_types = special_token_types else: - self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask') + self.special_token_types = ( + "bos", + "eos", + "unk", + "sep", + "pad", + "cls", + "mask", + ) self._load(Path(path)) def __repr__(self) -> str: - return ''.format( - len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset", + return "".format( + len(self.merges), + self.special_token_ids or "unset", + self.add_special_token or "unset", ) def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None: if self.merges: if not quiet: - logger.info(f'Adding {len(self.merges)} merge(s).') + logger.info(f"Adding {len(self.merges)} merge(s).") gw.add_token_merges(self.merges) elif self.load_merges: - logger.warning('Adding merges requested but no merges found, output may be non-functional.') + logger.warning( + "Adding merges requested but no merges found, output may be non-functional." + ) for typ, tokid in self.special_token_ids.items(): - id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None) + id_handler: Callable[[int], None] | None = getattr( + gw, f"add_{typ}_token_id", None + ) if id_handler is None: - logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping') + logger.warning( + f"No handler for special token type {typ} with id {tokid} - skipping" + ) continue if not quiet: - logger.info(f'Setting special token type {typ} to {tokid}') + logger.info(f"Setting special token type {typ} to {tokid}") id_handler(tokid) for typ, value in self.add_special_token.items(): - add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None) + add_handler: Callable[[bool], None] | None = getattr( + gw, f"add_add_{typ}_token", None + ) if add_handler is None: - logger.warning(f'No handler for add_{typ}_token with value {value} - skipping') + logger.warning( + f"No handler for add_{typ}_token with value {value} - skipping" + ) continue if not quiet: - logger.info(f'Setting add_{typ}_token to {value}') + logger.info(f"Setting add_{typ}_token to {value}") add_handler(value) if self.chat_template is not None: if not quiet: - logger.info(f'Setting chat_template to {self.chat_template}') + logger.info(f"Setting chat_template to {self.chat_template}") gw.add_chat_template(self.chat_template) def _load(self, path: Path) -> None: @@ -79,12 +110,12 @@ def _load(self, path: Path) -> None: self._try_load_merges_txt(path) def _try_load_merges_txt(self, path: Path) -> bool: - merges_file = path / 'merges.txt' + merges_file = path / "merges.txt" if not merges_file.is_file(): return False - with open(merges_file, 'r', encoding = 'utf-8') as fp: - first_line = next(fp, '').strip() - if not first_line.startswith('#'): + with open(merges_file, "r", encoding="utf-8") as fp: + first_line = next(fp, "").strip() + if not first_line.startswith("#"): fp.seek(0) line_num = 0 else: @@ -97,9 +128,11 @@ def _try_load_merges_txt(self, path: Path) -> bool: continue parts = line.split(None, 3) if len(parts) != 2: - logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring') + logger.warning( + f"{merges_file.name}: Line {line_num}: Entry malformed, ignoring" + ) continue - merges.append(f'{parts[0]} {parts[1]}') + merges.append(f"{parts[0]} {parts[1]}") self.merges = merges return True @@ -107,36 +140,44 @@ def _set_special_token(self, typ: str, tid: Any) -> None: if not isinstance(tid, int): return if tid < 0: - raise ValueError(f'invalid value for special token type {typ}: {tid}') + raise ValueError(f"invalid value for special token type {typ}: {tid}") if self.n_vocab is None or tid < self.n_vocab: if typ in self.special_token_ids: return self.special_token_ids[typ] = tid return - logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping') + logger.warning( + f"Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping" + ) def _try_load_from_tokenizer_json(self, path: Path) -> bool: - tokenizer_file = path / 'tokenizer.json' + tokenizer_file = path / "tokenizer.json" if tokenizer_file.is_file(): - with open(tokenizer_file, encoding = 'utf-8') as f: + with open(tokenizer_file, encoding="utf-8") as f: tokenizer = json.load(f) if self.load_merges: - merges = tokenizer.get('model', {}).get('merges') + merges = tokenizer.get("model", {}).get("merges") if isinstance(merges, list) and merges: if isinstance(merges[0], str): self.merges = merges - elif isinstance(merges[0], list) and len(merges[0]) == 2 and isinstance(merges[0][0], str): + elif ( + isinstance(merges[0], list) + and len(merges[0]) == 2 + and isinstance(merges[0][0], str) + ): # New format since transformers 4.45 to support spaces in merges # ref: https://github.com/ggml-org/llama.cpp/issues/9692 # TODO: internally store as the new format instead of converting to old - if any(' ' in s for pair in merges for s in pair): - logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}') + if any(" " in s for pair in merges for s in pair): + logger.warning( + f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}' + ) self.merges = [ - ' '.join( + " ".join( [ # ensure the spaces are properly encoded - ''.join( - chr(ord(c) + 256) if c == ' ' else c + "".join( + chr(ord(c) + 256) if c == " " else c for c in part ) for part in pair @@ -146,33 +187,35 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool: ] else: raise ValueError("Unknown tokenizer merges format") - added_tokens = tokenizer.get('added_tokens', {}) + added_tokens = tokenizer.get("added_tokens", {}) else: added_tokens = {} - tokenizer_config_file = path / 'tokenizer_config.json' + tokenizer_config_file = path / "tokenizer_config.json" if not tokenizer_config_file.is_file(): return True - with open(tokenizer_config_file, encoding = 'utf-8') as f: + with open(tokenizer_config_file, encoding="utf-8") as f: tokenizer_config = json.load(f) chat_template_alt = None - chat_template_file = path / 'chat_template.json' + chat_template_file = path / "chat_template.json" if chat_template_file.is_file(): - with open(chat_template_file, encoding = 'utf-8') as f: - chat_template_alt = json.load(f).get('chat_template') - chat_template = tokenizer_config.get('chat_template', chat_template_alt) + with open(chat_template_file, encoding="utf-8") as f: + chat_template_alt = json.load(f).get("chat_template") + chat_template = tokenizer_config.get("chat_template", chat_template_alt) if chat_template is None or isinstance(chat_template, (str, list)): self.chat_template = chat_template else: - logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring') + logger.warning( + f"Bad type for chat_template field in {tokenizer_config_file!r} - ignoring" + ) for typ in self.special_token_types: - add_entry = tokenizer_config.get(f'add_{typ}_token') + add_entry = tokenizer_config.get(f"add_{typ}_token") if isinstance(add_entry, bool): self.add_special_token[typ] = add_entry - entry = tokenizer_config.get(f'{typ}_token') + entry = tokenizer_config.get(f"{typ}_token") if isinstance(entry, str): tc_content = entry elif isinstance(entry, dict): - entry_content = entry.get('content') + entry_content = entry.get("content") if not isinstance(entry_content, str): continue tc_content = entry_content @@ -180,20 +223,24 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool: continue # We only need the first match here. maybe_token_id = next( - (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content), + ( + atok.get("id") + for atok in added_tokens + if atok.get("content") == tc_content + ), None, ) self._set_special_token(typ, maybe_token_id) return True def _try_load_from_config_json(self, path: Path) -> bool: - config_file = path / 'config.json' + config_file = path / "config.json" if not config_file.is_file(): return False - with open(config_file, encoding = 'utf-8') as f: + with open(config_file, encoding="utf-8") as f: config = json.load(f) for typ in self.special_token_types: - self._set_special_token(typ, config.get(f'{typ}_token_id')) + self._set_special_token(typ, config.get(f"{typ}_token_id")) return True @@ -229,54 +276,59 @@ class BpeVocab(Vocab): def __init__(self, base_path: Path): added_tokens: dict[str, int] = {} - if (fname_tokenizer := base_path / 'vocab.json').exists(): + if (fname_tokenizer := base_path / "vocab.json").exists(): # "slow" tokenizer with open(fname_tokenizer, encoding="utf-8") as f: self.vocab = json.load(f) try: # FIXME: Verify that added tokens here _cannot_ overlap with the main vocab. - with open(base_path / 'added_tokens.json', encoding="utf-8") as f: + with open(base_path / "added_tokens.json", encoding="utf-8") as f: added_tokens = json.load(f) except FileNotFoundError: pass else: # "fast" tokenizer - fname_tokenizer = base_path / 'tokenizer.json' + fname_tokenizer = base_path / "tokenizer.json" # if this fails, FileNotFoundError propagates to caller with open(fname_tokenizer, encoding="utf-8") as f: tokenizer_json = json.load(f) - tokenizer_model: dict[str, Any] = tokenizer_json['model'] + tokenizer_model: dict[str, Any] = tokenizer_json["model"] if ( - tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False) - or tokenizer_json['decoder']['type'] != 'ByteLevel' + tokenizer_model["type"] != "BPE" + or tokenizer_model.get("byte_fallback", False) + or tokenizer_json["decoder"]["type"] != "ByteLevel" ): - raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer') + raise FileNotFoundError("Cannot find GPT-2 BPE tokenizer") self.vocab = tokenizer_model["vocab"] - if (added := tokenizer_json.get('added_tokens')) is not None: + if (added := tokenizer_json.get("added_tokens")) is not None: # Added tokens here can be duplicates of the main vocabulary. - added_tokens = {item['content']: item['id'] - for item in added - if item['content'] not in self.vocab} + added_tokens = { + item["content"]: item["id"] + for item in added + if item["content"] not in self.vocab + } - vocab_size = len(self.vocab) + vocab_size = len(self.vocab) expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) - actual_ids = sorted(added_tokens.values()) + actual_ids = sorted(added_tokens.values()) if expected_ids != actual_ids: expected_end_id = vocab_size + len(actual_ids) - 1 - raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range " - f"{vocab_size} - {expected_end_id}; got {actual_ids}") + raise ValueError( + f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range " + f"{vocab_size} - {expected_end_id}; got {actual_ids}" + ) items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) - self.added_tokens_dict = added_tokens - self.added_tokens_list = [text for (text, idx) in items] - self.vocab_size_base = vocab_size - self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) - self.fname_tokenizer = fname_tokenizer + self.added_tokens_dict = added_tokens + self.added_tokens_list = [text for (text, idx) in items] + self.vocab_size_base = vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: reverse_vocab = {id: encoded_tok for encoded_tok, id in self.vocab.items()} @@ -303,40 +355,44 @@ class SentencePieceVocab(Vocab): def __init__(self, base_path: Path): added_tokens: dict[str, int] = {} - if (fname_tokenizer := base_path / 'tokenizer.model').exists(): + if (fname_tokenizer := base_path / "tokenizer.model").exists(): # normal location try: - with open(base_path / 'added_tokens.json', encoding="utf-8") as f: + with open(base_path / "added_tokens.json", encoding="utf-8") as f: added_tokens = json.load(f) except FileNotFoundError: pass - elif not (fname_tokenizer := base_path.parent / 'tokenizer.model').exists(): + elif not (fname_tokenizer := base_path.parent / "tokenizer.model").exists(): # not found in alternate location either - raise FileNotFoundError('Cannot find tokenizer.model') + raise FileNotFoundError("Cannot find tokenizer.model") self.sentencepiece_tokenizer = SentencePieceProcessor() self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer)) vocab_size = self.sentencepiece_tokenizer.vocab_size() - new_tokens = {id: piece for piece, id in added_tokens.items() if id >= vocab_size} + new_tokens = { + id: piece for piece, id in added_tokens.items() if id >= vocab_size + } expected_new_ids = list(range(vocab_size, vocab_size + len(new_tokens))) - actual_new_ids = sorted(new_tokens.keys()) + actual_new_ids = sorted(new_tokens.keys()) if expected_new_ids != actual_new_ids: - raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}") + raise ValueError( + f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}" + ) # Token pieces that were added to the base vocabulary. - self.added_tokens_dict = added_tokens - self.added_tokens_list = [new_tokens[id] for id in actual_new_ids] - self.vocab_size_base = vocab_size - self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) - self.fname_tokenizer = fname_tokenizer + self.added_tokens_dict = added_tokens + self.added_tokens_list = [new_tokens[id] for id in actual_new_ids] + self.vocab_size_base = vocab_size + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: tokenizer = self.sentencepiece_tokenizer for i in range(tokenizer.vocab_size()): piece = tokenizer.IdToPiece(i) - text = piece.encode("utf-8") + text = piece.encode("utf-8") score: float = tokenizer.GetScore(i) toktype = gguf.TokenType.NORMAL @@ -374,25 +430,27 @@ class LlamaHfVocab(Vocab): name = "hfft" def __init__(self, base_path: Path): - fname_tokenizer = base_path / 'tokenizer.json' + fname_tokenizer = base_path / "tokenizer.json" # if this fails, FileNotFoundError propagates to caller - with open(fname_tokenizer, encoding='utf-8') as f: + with open(fname_tokenizer, encoding="utf-8") as f: tokenizer_json = json.load(f) # pre-check so we know if we need transformers - tokenizer_model: dict[str, Any] = tokenizer_json['model'] + tokenizer_model: dict[str, Any] = tokenizer_json["model"] is_llama3 = ( - tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False) - and not tokenizer_model.get('byte_fallback', True) + tokenizer_model["type"] == "BPE" + and tokenizer_model.get("ignore_merges", False) + and not tokenizer_model.get("byte_fallback", True) ) if is_llama3: - raise TypeError('Llama 3 must be converted with BpeVocab') + raise TypeError("Llama 3 must be converted with BpeVocab") if not is_llama3 and ( - tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False) - or tokenizer_json['decoder']['type'] != 'Sequence' + tokenizer_model["type"] != "BPE" + or not tokenizer_model.get("byte_fallback", False) + or tokenizer_json["decoder"]["type"] != "Sequence" ): - raise FileNotFoundError('Cannot find Llama BPE tokenizer') + raise FileNotFoundError("Cannot find Llama BPE tokenizer") try: from transformers import AutoTokenizer @@ -414,7 +472,7 @@ def __init__(self, base_path: Path): # Initialize lists and dictionaries for added tokens self.added_tokens_list = [] self.added_tokens_dict = dict() - self.added_tokens_ids = set() + self.added_tokens_ids = set() # Process added tokens for tok, tokidx in sorted( @@ -435,7 +493,7 @@ def __init__(self, base_path: Path): # Set vocabulary sizes self.vocab_size_base = self.tokenizer.vocab_size - self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) + self.vocab_size = self.vocab_size_base + len(self.added_tokens_list) self.fname_tokenizer = fname_tokenizer @@ -454,16 +512,22 @@ def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: # Yield token text, score, and type yield token_text, self.get_token_score(token_id), self.get_token_type( - token_id, token_text, self.special_ids # Reuse already stored special IDs + token_id, + token_text, + self.special_ids, # Reuse already stored special IDs ) - def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType: + def get_token_type( + self, token_id: int, token_text: bytes, special_ids: set[int] + ) -> gguf.TokenType: # Special case for byte tokens - if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text): + if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text): return gguf.TokenType.BYTE # Determine token type based on whether it's a special token - return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL + return ( + gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL + ) def get_token_score(self, token_id: int) -> float: # Placeholder for actual logic to determine the token's score @@ -473,7 +537,9 @@ def get_token_score(self, token_id: int) -> float: def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: for text in self.added_tokens_list: if text in self.specials: - toktype = self.get_token_type(self.specials[text], b'', self.special_ids) + toktype = self.get_token_type( + self.specials[text], b"", self.special_ids + ) score = self.get_token_score(self.specials[text]) else: toktype = gguf.TokenType.USER_DEFINED diff --git a/smallthinker/gguf-py/tests/test_metadata.py b/smallthinker/gguf-py/tests/test_metadata.py index 40d484f4..6a6e8dca 100755 --- a/smallthinker/gguf-py/tests/test_metadata.py +++ b/smallthinker/gguf-py/tests/test_metadata.py @@ -6,7 +6,10 @@ import sys # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): +if ( + "NO_LOCAL_GGUF" not in os.environ + and (Path(__file__).parent.parent.parent / "gguf-py").exists() +): sys.path.insert(0, str(Path(__file__).parent.parent)) import gguf @@ -15,222 +18,591 @@ class TestMetadataMethod(unittest.TestCase): def test_id_to_title(self): - self.assertEqual(gguf.Metadata.id_to_title("Mixtral-8x7B-Instruct-v0.1"), "Mixtral 8x7B Instruct v0.1") - self.assertEqual(gguf.Metadata.id_to_title("Meta-Llama-3-8B"), "Meta Llama 3 8B") - self.assertEqual(gguf.Metadata.id_to_title("hermes-2-pro-llama-3-8b-DPO"), "Hermes 2 Pro Llama 3 8b DPO") + self.assertEqual( + gguf.Metadata.id_to_title("Mixtral-8x7B-Instruct-v0.1"), + "Mixtral 8x7B Instruct v0.1", + ) + self.assertEqual( + gguf.Metadata.id_to_title("Meta-Llama-3-8B"), "Meta Llama 3 8B" + ) + self.assertEqual( + gguf.Metadata.id_to_title("hermes-2-pro-llama-3-8b-DPO"), + "Hermes 2 Pro Llama 3 8b DPO", + ) def test_get_model_id_components(self): # This is the basic standard form with organization marker - self.assertEqual(gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"), - ('Mixtral-8x7B-Instruct-v0.1', "Mistral", 'Mixtral', 'Instruct', 'v0.1', '8x7B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("Mistral/Mixtral-8x7B-Instruct-v0.1"), + ( + "Mixtral-8x7B-Instruct-v0.1", + "Mistral", + "Mixtral", + "Instruct", + "v0.1", + "8x7B", + ), + ) # Similar to basic standard form but without organization marker - self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"), - ('Mixtral-8x7B-Instruct-v0.1', None, 'Mixtral', 'Instruct', 'v0.1', '8x7B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct-v0.1"), + ("Mixtral-8x7B-Instruct-v0.1", None, "Mixtral", "Instruct", "v0.1", "8x7B"), + ) # Missing version - self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct"), - ('Mixtral-8x7B-Instruct', None, 'Mixtral', 'Instruct', None, '8x7B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("Mixtral-8x7B-Instruct"), + ("Mixtral-8x7B-Instruct", None, "Mixtral", "Instruct", None, "8x7B"), + ) # Missing finetune - self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B-v0.1"), - ('Mixtral-8x7B-v0.1', None, 'Mixtral', None, 'v0.1', '8x7B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("Mixtral-8x7B-v0.1"), + ("Mixtral-8x7B-v0.1", None, "Mixtral", None, "v0.1", "8x7B"), + ) # Base name and size label only - self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-8x7B"), - ('Mixtral-8x7B', None, 'Mixtral', None, None, '8x7B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("Mixtral-8x7B"), + ("Mixtral-8x7B", None, "Mixtral", None, None, "8x7B"), + ) # Base name and version only - self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral-v0.1"), - ('Mixtral-v0.1', None, 'Mixtral', None, 'v0.1', None)) + self.assertEqual( + gguf.Metadata.get_model_id_components("Mixtral-v0.1"), + ("Mixtral-v0.1", None, "Mixtral", None, "v0.1", None), + ) ## Edge Cases ## # This is too ambiguous... best to err on caution and output nothing - self.assertEqual(gguf.Metadata.get_model_id_components("Mixtral"), - ('Mixtral', None, None, None, None, None)) + self.assertEqual( + gguf.Metadata.get_model_id_components("Mixtral"), + ("Mixtral", None, None, None, None, None), + ) # Basename has numbers mixed in and also size label provided. Must avoid capturing number in basename - self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"), - ('Meta-Llama-3-8B', "NousResearch", 'Meta-Llama-3', None, None, '8B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("NousResearch/Meta-Llama-3-8B"), + ("Meta-Llama-3-8B", "NousResearch", "Meta-Llama-3", None, None, "8B"), + ) # Non standard naming - self.assertEqual(gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"), - ('Qwen1.5-MoE-A2.7B-Chat', None, 'Qwen1.5-MoE', 'Chat', None, 'A2.7B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("Qwen1.5-MoE-A2.7B-Chat"), + ("Qwen1.5-MoE-A2.7B-Chat", None, "Qwen1.5-MoE", "Chat", None, "A2.7B"), + ) # Capture 'sub size labels' e.g. A14B in '57B-A14B' usually refers to activated params/weight count - self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"), - ('Qwen2-57B-A14B-Instruct', None, 'Qwen2', 'Instruct', None, '57B-A14B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("Qwen2-57B-A14B-Instruct"), + ("Qwen2-57B-A14B-Instruct", None, "Qwen2", "Instruct", None, "57B-A14B"), + ) # Check that it can handle a real model id with no version code # Note that 4k in this string is non standard and microsoft were referring to context length rather than weight count - self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Phi-3-mini-4k-instruct", 4 * 10**9), - ('Phi-3-mini-4k-instruct', 'microsoft', 'Phi-3', '4k-instruct', None, 'mini')) + self.assertEqual( + gguf.Metadata.get_model_id_components( + "microsoft/Phi-3-mini-4k-instruct", 4 * 10**9 + ), + ( + "Phi-3-mini-4k-instruct", + "microsoft", + "Phi-3", + "4k-instruct", + None, + "mini", + ), + ) # There is some legitimate models with only thousands of parameters - self.assertEqual(gguf.Metadata.get_model_id_components("delphi-suite/stories-llama2-50k", 50 * 10**3), - ('stories-llama2-50k', 'delphi-suite', 'stories-llama2', None, None, '50K')) + self.assertEqual( + gguf.Metadata.get_model_id_components( + "delphi-suite/stories-llama2-50k", 50 * 10**3 + ), + ("stories-llama2-50k", "delphi-suite", "stories-llama2", None, None, "50K"), + ) # Non standard and not easy to disambiguate - self.assertEqual(gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"), - ('DeepSeek-Coder-V2-Lite-Instruct', None, 'DeepSeek-Coder-V2-Lite', 'Instruct', None, None)) + self.assertEqual( + gguf.Metadata.get_model_id_components("DeepSeek-Coder-V2-Lite-Instruct"), + ( + "DeepSeek-Coder-V2-Lite-Instruct", + None, + "DeepSeek-Coder-V2-Lite", + "Instruct", + None, + None, + ), + ) # This is a real model_id where they append 2DPO to refer to Direct Preference Optimization - self.assertEqual(gguf.Metadata.get_model_id_components("crestf411/daybreak-kunoichi-2dpo-7b"), - ('daybreak-kunoichi-2dpo-7b', 'crestf411', 'daybreak-kunoichi', '2dpo', None, '7B')) + self.assertEqual( + gguf.Metadata.get_model_id_components( + "crestf411/daybreak-kunoichi-2dpo-7b" + ), + ( + "daybreak-kunoichi-2dpo-7b", + "crestf411", + "daybreak-kunoichi", + "2dpo", + None, + "7B", + ), + ) # This is a real model id where the weight size has a decimal point - self.assertEqual(gguf.Metadata.get_model_id_components("Qwen2-0.5B-Instruct"), - ('Qwen2-0.5B-Instruct', None, 'Qwen2', 'Instruct', None, '0.5B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("Qwen2-0.5B-Instruct"), + ("Qwen2-0.5B-Instruct", None, "Qwen2", "Instruct", None, "0.5B"), + ) # Uses an underscore in the size label - self.assertEqual(gguf.Metadata.get_model_id_components("smallcloudai/Refact-1_6B-fim"), - ('Refact-1_6B-fim', 'smallcloudai', 'Refact', 'fim', None, '1.6B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("smallcloudai/Refact-1_6B-fim"), + ("Refact-1_6B-fim", "smallcloudai", "Refact", "fim", None, "1.6B"), + ) # Uses Iter3 for the version - self.assertEqual(gguf.Metadata.get_model_id_components("UCLA-AGI/Gemma-2-9B-It-SPPO-Iter3"), - ('Gemma-2-9B-It-SPPO-Iter3', 'UCLA-AGI', 'Gemma-2', 'It-SPPO', 'Iter3', '9B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("UCLA-AGI/Gemma-2-9B-It-SPPO-Iter3"), + ( + "Gemma-2-9B-It-SPPO-Iter3", + "UCLA-AGI", + "Gemma-2", + "It-SPPO", + "Iter3", + "9B", + ), + ) # Has two potential versions in the basename - self.assertEqual(gguf.Metadata.get_model_id_components("NousResearch/Hermes-2-Theta-Llama-3-8B"), - ('Hermes-2-Theta-Llama-3-8B', 'NousResearch', 'Hermes-2-Theta-Llama-3', None, None, '8B')) + self.assertEqual( + gguf.Metadata.get_model_id_components( + "NousResearch/Hermes-2-Theta-Llama-3-8B" + ), + ( + "Hermes-2-Theta-Llama-3-8B", + "NousResearch", + "Hermes-2-Theta-Llama-3", + None, + None, + "8B", + ), + ) # Potential version in the basename - self.assertEqual(gguf.Metadata.get_model_id_components("SeaLLMs/SeaLLMs-v3-7B-Chat"), - ('SeaLLMs-v3-7B-Chat', 'SeaLLMs', 'SeaLLMs-v3', 'Chat', None, '7B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("SeaLLMs/SeaLLMs-v3-7B-Chat"), + ("SeaLLMs-v3-7B-Chat", "SeaLLMs", "SeaLLMs-v3", "Chat", None, "7B"), + ) # Underscore in the basename, and 1m for the context size - self.assertEqual(gguf.Metadata.get_model_id_components("internlm/internlm2_5-7b-chat-1m", 7 * 10**9), - ('internlm2_5-7b-chat-1m', 'internlm', 'internlm2_5', 'chat-1m', None, '7B')) + self.assertEqual( + gguf.Metadata.get_model_id_components( + "internlm/internlm2_5-7b-chat-1m", 7 * 10**9 + ), + ( + "internlm2_5-7b-chat-1m", + "internlm", + "internlm2_5", + "chat-1m", + None, + "7B", + ), + ) # Version before the finetune name - self.assertEqual(gguf.Metadata.get_model_id_components("pszemraj/jamba-900M-v0.13-KIx2"), - ('jamba-900M-v0.13-KIx2', 'pszemraj', 'jamba', 'KIx2', 'v0.13', '900M')) + self.assertEqual( + gguf.Metadata.get_model_id_components("pszemraj/jamba-900M-v0.13-KIx2"), + ("jamba-900M-v0.13-KIx2", "pszemraj", "jamba", "KIx2", "v0.13", "900M"), + ) # TODO: hf suffix which could be ignored but isn't - self.assertEqual(gguf.Metadata.get_model_id_components("state-spaces/mamba-2.8b-hf"), - ('mamba-2.8b-hf', 'state-spaces', 'mamba', 'hf', None, '2.8B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("state-spaces/mamba-2.8b-hf"), + ("mamba-2.8b-hf", "state-spaces", "mamba", "hf", None, "2.8B"), + ) # Two sizes, don't merge them, the other is the number of tokens on which it was trained - self.assertEqual(gguf.Metadata.get_model_id_components("abacaj/llama-161M-100B", 161 * 10**6), - ('llama-161M-100B', 'abacaj', 'llama', '100b', None, '161M')) + self.assertEqual( + gguf.Metadata.get_model_id_components( + "abacaj/llama-161M-100B", 161 * 10**6 + ), + ("llama-161M-100B", "abacaj", "llama", "100b", None, "161M"), + ) # It's a trap, there is no size label - self.assertEqual(gguf.Metadata.get_model_id_components("SparseLLM/relu-100B", 1340 * 10**6), - ('relu-100B', 'SparseLLM', 'relu', '100b', None, None)) + self.assertEqual( + gguf.Metadata.get_model_id_components("SparseLLM/relu-100B", 1340 * 10**6), + ("relu-100B", "SparseLLM", "relu", "100b", None, None), + ) # Weird size notation - self.assertEqual(gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"), - ('bloom-7b1-petals', 'bigscience', 'bloom', 'petals', None, '7.1B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("bigscience/bloom-7b1-petals"), + ("bloom-7b1-petals", "bigscience", "bloom", "petals", None, "7.1B"), + ) # Ignore full-text size labels when there are number-based ones, and deduplicate size labels - self.assertEqual(gguf.Metadata.get_model_id_components("MaziyarPanahi/GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1"), - ('GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1', 'MaziyarPanahi', 'GreenNode-mini', 'multilingual-v1olet-Mistral-Instruct', 'v0.1', '7B')) + self.assertEqual( + gguf.Metadata.get_model_id_components( + "MaziyarPanahi/GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1" + ), + ( + "GreenNode-mini-7B-multilingual-v1olet-Mistral-7B-Instruct-v0.1", + "MaziyarPanahi", + "GreenNode-mini", + "multilingual-v1olet-Mistral-Instruct", + "v0.1", + "7B", + ), + ) # Instruct in a name without a size label - self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/Mistral-Nemo-Instruct-2407"), - ('Mistral-Nemo-Instruct-2407', 'mistralai', 'Mistral-Nemo', 'Instruct', '2407', None)) + self.assertEqual( + gguf.Metadata.get_model_id_components( + "mistralai/Mistral-Nemo-Instruct-2407" + ), + ( + "Mistral-Nemo-Instruct-2407", + "mistralai", + "Mistral-Nemo", + "Instruct", + "2407", + None, + ), + ) # Non-obvious splitting relying on 'chat' keyword - self.assertEqual(gguf.Metadata.get_model_id_components("deepseek-ai/DeepSeek-V2-Chat-0628"), - ('DeepSeek-V2-Chat-0628', 'deepseek-ai', 'DeepSeek-V2', 'Chat', '0628', None)) + self.assertEqual( + gguf.Metadata.get_model_id_components("deepseek-ai/DeepSeek-V2-Chat-0628"), + ( + "DeepSeek-V2-Chat-0628", + "deepseek-ai", + "DeepSeek-V2", + "Chat", + "0628", + None, + ), + ) # Multiple versions - self.assertEqual(gguf.Metadata.get_model_id_components("OpenGVLab/Mini-InternVL-Chat-2B-V1-5"), - ('Mini-InternVL-Chat-2B-V1-5', 'OpenGVLab', 'Mini-InternVL', 'Chat', 'V1-5', '2B')) + self.assertEqual( + gguf.Metadata.get_model_id_components( + "OpenGVLab/Mini-InternVL-Chat-2B-V1-5" + ), + ( + "Mini-InternVL-Chat-2B-V1-5", + "OpenGVLab", + "Mini-InternVL", + "Chat", + "V1-5", + "2B", + ), + ) # TODO: DPO in the name - self.assertEqual(gguf.Metadata.get_model_id_components("jondurbin/bagel-dpo-2.8b-v0.2"), - ('bagel-dpo-2.8b-v0.2', 'jondurbin', 'bagel-dpo', None, 'v0.2', '2.8B')) + self.assertEqual( + gguf.Metadata.get_model_id_components("jondurbin/bagel-dpo-2.8b-v0.2"), + ("bagel-dpo-2.8b-v0.2", "jondurbin", "bagel-dpo", None, "v0.2", "2.8B"), + ) # DPO in name, but can't be used for the finetune to keep 'LLaMA-3' in the basename - self.assertEqual(gguf.Metadata.get_model_id_components("voxmenthe/SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized"), - ('SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized', 'voxmenthe', 'SFR-Iterative-DPO-LLaMA-3', 'R-unquantized', None, '8B')) + self.assertEqual( + gguf.Metadata.get_model_id_components( + "voxmenthe/SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized" + ), + ( + "SFR-Iterative-DPO-LLaMA-3-8B-R-unquantized", + "voxmenthe", + "SFR-Iterative-DPO-LLaMA-3", + "R-unquantized", + None, + "8B", + ), + ) # Too ambiguous # TODO: should "base" be a 'finetune' or 'size_label'? # (in this case it should be a size label, but other models use it to signal that they are not finetuned) - self.assertEqual(gguf.Metadata.get_model_id_components("microsoft/Florence-2-base"), - ('Florence-2-base', 'microsoft', None, None, None, None)) + self.assertEqual( + gguf.Metadata.get_model_id_components("microsoft/Florence-2-base"), + ("Florence-2-base", "microsoft", None, None, None, None), + ) ## Invalid cases ## # Start with a dash and has dashes in rows - self.assertEqual(gguf.Metadata.get_model_id_components("mistralai/-Mistral--Nemo-Base-2407-"), - ('-Mistral--Nemo-Base-2407-', 'mistralai', 'Mistral-Nemo-Base', None, '2407', None)) + self.assertEqual( + gguf.Metadata.get_model_id_components( + "mistralai/-Mistral--Nemo-Base-2407-" + ), + ( + "-Mistral--Nemo-Base-2407-", + "mistralai", + "Mistral-Nemo-Base", + None, + "2407", + None, + ), + ) ## LoRA ## - self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B"), - ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration-LoRA', None, '8B')) + self.assertEqual( + gguf.Metadata.get_model_id_components( + "Llama-3-Instruct-abliteration-LoRA-8B" + ), + ( + "Llama-3-Instruct-abliteration-LoRA-8B", + None, + "Llama-3", + "Instruct-abliteration-LoRA", + None, + "8B", + ), + ) # Negative size --> output is a LoRA adaper --> prune "LoRA" out of the name to avoid redundancy with the suffix - self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B", -1234), - ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration', None, '8B')) + self.assertEqual( + gguf.Metadata.get_model_id_components( + "Llama-3-Instruct-abliteration-LoRA-8B", -1234 + ), + ( + "Llama-3-Instruct-abliteration-LoRA-8B", + None, + "Llama-3", + "Instruct-abliteration", + None, + "8B", + ), + ) def test_apply_metadata_heuristic_from_model_card(self): model_card = { - 'tags': ['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'], - 'model-index': [{'name': 'Mixtral-8x7B-Instruct-v0.1', 'results': []}], - 'language': ['en'], - 'datasets': ['teknium/OpenHermes-2.5'], - 'widget': [{'example_title': 'Hermes 2 Pro', 'messages': [{'role': 'system', 'content': 'You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.'}, {'role': 'user', 'content': 'Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.'}]}], - 'base_model': ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"] + "tags": [ + "Llama-3", + "instruct", + "finetune", + "chatml", + "DPO", + "RLHF", + "gpt4", + "synthetic data", + "distillation", + "function calling", + "json mode", + "axolotl", + ], + "model-index": [{"name": "Mixtral-8x7B-Instruct-v0.1", "results": []}], + "language": ["en"], + "datasets": ["teknium/OpenHermes-2.5"], + "widget": [ + { + "example_title": "Hermes 2 Pro", + "messages": [ + { + "role": "system", + "content": "You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.", + }, + { + "role": "user", + "content": "Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.", + }, + ], + } + ], + "base_model": ["EmbeddedLLM/Mistral-7B-Merge-14-v0", "janai-hq/trinity-v1"], } - got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + got = gguf.Metadata.apply_metadata_heuristic( + gguf.Metadata(), model_card, None, None + ) expect = gguf.Metadata() - expect.base_models=[{'name': 'Mistral 7B Merge 14 v0', 'organization': 'EmbeddedLLM', 'version': '14-v0', 'repo_url': 'https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0'}, {'name': 'Trinity v1', 'organization': 'Janai Hq', 'version': 'v1', 'repo_url': 'https://huggingface.co/janai-hq/trinity-v1'}] - expect.tags=['Llama-3', 'instruct', 'finetune', 'chatml', 'DPO', 'RLHF', 'gpt4', 'synthetic data', 'distillation', 'function calling', 'json mode', 'axolotl'] - expect.languages=['en'] - expect.datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}] + expect.base_models = [ + { + "name": "Mistral 7B Merge 14 v0", + "organization": "EmbeddedLLM", + "version": "14-v0", + "repo_url": "https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0", + }, + { + "name": "Trinity v1", + "organization": "Janai Hq", + "version": "v1", + "repo_url": "https://huggingface.co/janai-hq/trinity-v1", + }, + ] + expect.tags = [ + "Llama-3", + "instruct", + "finetune", + "chatml", + "DPO", + "RLHF", + "gpt4", + "synthetic data", + "distillation", + "function calling", + "json mode", + "axolotl", + ] + expect.languages = ["en"] + expect.datasets = [ + { + "name": "OpenHermes 2.5", + "organization": "Teknium", + "version": "2.5", + "repo_url": "https://huggingface.co/teknium/OpenHermes-2.5", + } + ] self.assertEqual(got, expect) # Base Model spec is inferred from model id - model_card = {'base_models': 'teknium/OpenHermes-2.5'} - expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) - got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + model_card = {"base_models": "teknium/OpenHermes-2.5"} + expect = gguf.Metadata( + base_models=[ + { + "name": "OpenHermes 2.5", + "organization": "Teknium", + "version": "2.5", + "repo_url": "https://huggingface.co/teknium/OpenHermes-2.5", + } + ] + ) + got = gguf.Metadata.apply_metadata_heuristic( + gguf.Metadata(), model_card, None, None + ) self.assertEqual(got, expect) # Base Model spec is only url - model_card = {'base_models': ['https://huggingface.co/teknium/OpenHermes-2.5']} - expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) - got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + model_card = {"base_models": ["https://huggingface.co/teknium/OpenHermes-2.5"]} + expect = gguf.Metadata( + base_models=[ + { + "name": "OpenHermes 2.5", + "organization": "Teknium", + "version": "2.5", + "repo_url": "https://huggingface.co/teknium/OpenHermes-2.5", + } + ] + ) + got = gguf.Metadata.apply_metadata_heuristic( + gguf.Metadata(), model_card, None, None + ) self.assertEqual(got, expect) # Base Model spec is given directly - model_card = {'base_models': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]} - expect = gguf.Metadata(base_models=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) - got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + model_card = { + "base_models": [ + { + "name": "OpenHermes 2.5", + "organization": "Teknium", + "version": "2.5", + "repo_url": "https://huggingface.co/teknium/OpenHermes-2.5", + } + ] + } + expect = gguf.Metadata( + base_models=[ + { + "name": "OpenHermes 2.5", + "organization": "Teknium", + "version": "2.5", + "repo_url": "https://huggingface.co/teknium/OpenHermes-2.5", + } + ] + ) + got = gguf.Metadata.apply_metadata_heuristic( + gguf.Metadata(), model_card, None, None + ) self.assertEqual(got, expect) # Dataset spec is inferred from model id - model_card = {'datasets': 'teknium/OpenHermes-2.5'} - expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) - got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + model_card = {"datasets": "teknium/OpenHermes-2.5"} + expect = gguf.Metadata( + datasets=[ + { + "name": "OpenHermes 2.5", + "organization": "Teknium", + "version": "2.5", + "repo_url": "https://huggingface.co/teknium/OpenHermes-2.5", + } + ] + ) + got = gguf.Metadata.apply_metadata_heuristic( + gguf.Metadata(), model_card, None, None + ) self.assertEqual(got, expect) # Dataset spec is only url - model_card = {'datasets': ['https://huggingface.co/teknium/OpenHermes-2.5']} - expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) - got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + model_card = {"datasets": ["https://huggingface.co/teknium/OpenHermes-2.5"]} + expect = gguf.Metadata( + datasets=[ + { + "name": "OpenHermes 2.5", + "organization": "Teknium", + "version": "2.5", + "repo_url": "https://huggingface.co/teknium/OpenHermes-2.5", + } + ] + ) + got = gguf.Metadata.apply_metadata_heuristic( + gguf.Metadata(), model_card, None, None + ) self.assertEqual(got, expect) # Dataset spec is given directly - model_card = {'datasets': [{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]} - expect = gguf.Metadata(datasets=[{'name': 'OpenHermes 2.5', 'organization': 'Teknium', 'version': '2.5', 'repo_url': 'https://huggingface.co/teknium/OpenHermes-2.5'}]) - got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card, None, None) + model_card = { + "datasets": [ + { + "name": "OpenHermes 2.5", + "organization": "Teknium", + "version": "2.5", + "repo_url": "https://huggingface.co/teknium/OpenHermes-2.5", + } + ] + } + expect = gguf.Metadata( + datasets=[ + { + "name": "OpenHermes 2.5", + "organization": "Teknium", + "version": "2.5", + "repo_url": "https://huggingface.co/teknium/OpenHermes-2.5", + } + ] + ) + got = gguf.Metadata.apply_metadata_heuristic( + gguf.Metadata(), model_card, None, None + ) self.assertEqual(got, expect) def test_apply_metadata_heuristic_from_hf_parameters(self): hf_params = {"_name_or_path": "./hermes-2-pro-llama-3-8b-DPO"} - got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None) - expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B') + got = gguf.Metadata.apply_metadata_heuristic( + gguf.Metadata(), model_card=None, hf_params=hf_params, model_path=None + ) + expect = gguf.Metadata( + name="Hermes 2 Pro Llama 3 8b DPO", + finetune="DPO", + basename="hermes-2-pro-llama-3", + size_label="8B", + ) self.assertEqual(got, expect) def test_apply_metadata_heuristic_from_model_dir(self): model_dir_path = Path("./hermes-2-pro-llama-3-8b-DPO") - got = gguf.Metadata.apply_metadata_heuristic(gguf.Metadata(), model_card=None, hf_params=None, model_path=model_dir_path) - expect = gguf.Metadata(name='Hermes 2 Pro Llama 3 8b DPO', finetune='DPO', basename='hermes-2-pro-llama-3', size_label='8B') + got = gguf.Metadata.apply_metadata_heuristic( + gguf.Metadata(), model_card=None, hf_params=None, model_path=model_dir_path + ) + expect = gguf.Metadata( + name="Hermes 2 Pro Llama 3 8b DPO", + finetune="DPO", + basename="hermes-2-pro-llama-3", + size_label="8B", + ) self.assertEqual(got, expect) diff --git a/smallthinker/gguf-py/tests/test_quants.py b/smallthinker/gguf-py/tests/test_quants.py index f04d5acc..33a1bec6 100755 --- a/smallthinker/gguf-py/tests/test_quants.py +++ b/smallthinker/gguf-py/tests/test_quants.py @@ -16,7 +16,10 @@ import numpy as np # Necessary to load the local gguf package -if "NO_LOCAL_GGUF" not in os.environ and (Path(__file__).parent.parent.parent / 'gguf-py').exists(): +if ( + "NO_LOCAL_GGUF" not in os.environ + and (Path(__file__).parent.parent.parent / "gguf-py").exists() +): sys.path.insert(0, str(Path(__file__).parent.parent)) import gguf @@ -64,55 +67,117 @@ def __init__(self, libggml: Path): self.libggml.ggml_quantize_requires_imatrix.argtypes = (ctypes.c_int,) for t in ( - "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", - "q2_K", "q3_K", "q4_K", "q5_K", "q6_K", - "tq1_0", "tq2_0", - "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m", - "iq4_nl", "iq4_xs", + "q4_0", + "q4_1", + "q5_0", + "q5_1", + "q8_0", + "q2_K", + "q3_K", + "q4_K", + "q5_K", + "q6_K", + "tq1_0", + "tq2_0", + "iq2_xxs", + "iq2_xs", + "iq2_s", + "iq3_xxs", + "iq3_s", + "iq1_s", + "iq1_m", + "iq4_nl", + "iq4_xs", ): - dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + t) + dequant_func: ctypes._NamedFuncPointer = getattr( + self.libggml, "dequantize_row_" + t + ) dequant_func.restype = None - dequant_func.argtypes = (ctypes.c_void_p, ctypes.POINTER(ctypes.c_float), ctypes.c_int64) + dequant_func.argtypes = ( + ctypes.c_void_p, + ctypes.POINTER(ctypes.c_float), + ctypes.c_int64, + ) self.libggml.ggml_fp16_to_fp32_row.restype = None - self.libggml.ggml_fp16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64) + self.libggml.ggml_fp16_to_fp32_row.argtypes = ( + ctypes.POINTER(ctypes.c_uint16), + ctypes.POINTER(ctypes.c_float), + ctypes.c_int64, + ) self.libggml.ggml_bf16_to_fp32_row.restype = None - self.libggml.ggml_bf16_to_fp32_row.argtypes = (ctypes.POINTER(ctypes.c_uint16), ctypes.POINTER(ctypes.c_float), ctypes.c_int64) + self.libggml.ggml_bf16_to_fp32_row.argtypes = ( + ctypes.POINTER(ctypes.c_uint16), + ctypes.POINTER(ctypes.c_float), + ctypes.c_int64, + ) self.libggml.ggml_init.argtypes = (ggml_init_params,) self.libggml.ggml_init(ggml_init_params(1 * 1024 * 1024, 0, False)) def dequantize(self, tensor: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: - result = np.zeros(gguf.quant_shape_from_byte_shape(tensor.shape, qtype), dtype=np.float32, order="C") + result = np.zeros( + gguf.quant_shape_from_byte_shape(tensor.shape, qtype), + dtype=np.float32, + order="C", + ) if qtype == GGMLQuantizationType.F32: # no-op result = tensor.view(np.float32) elif qtype == GGMLQuantizationType.F16: - self.libggml.ggml_fp16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size) + self.libggml.ggml_fp16_to_fp32_row( + tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + result.ctypes.data_as(c_float_p), + result.size, + ) elif qtype == GGMLQuantizationType.BF16: - self.libggml.ggml_bf16_to_fp32_row(tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), result.ctypes.data_as(c_float_p), result.size) + self.libggml.ggml_bf16_to_fp32_row( + tensor.ctypes.data_as(ctypes.POINTER(ctypes.c_uint16)), + result.ctypes.data_as(c_float_p), + result.size, + ) else: lw_qname = qtype.name.lower() if lw_qname[-1] == "k": lw_qname = lw_qname[:-1] + "K" - dequant_func: ctypes._NamedFuncPointer = getattr(self.libggml, "dequantize_row_" + lw_qname) - dequant_func(tensor.ctypes.data_as(ctypes.c_void_p), result.ctypes.data_as(c_float_p), result.size) + dequant_func: ctypes._NamedFuncPointer = getattr( + self.libggml, "dequantize_row_" + lw_qname + ) + dequant_func( + tensor.ctypes.data_as(ctypes.c_void_p), + result.ctypes.data_as(c_float_p), + result.size, + ) return result def quantize(self, data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: - result = np.zeros(gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C") + result = np.zeros( + gguf.quant_shape_to_byte_shape(data.shape, qtype), dtype=np.uint8, order="C" + ) if self.libggml.ggml_quantize_requires_imatrix(qtype.value): # TODO: is a column-wise sum of squares appropriate? - qw = np.sum((data * data).reshape((-1, data.shape[-1])), axis=0).ctypes.data_as(c_float_p) + qw = np.sum( + (data * data).reshape((-1, data.shape[-1])), axis=0 + ).ctypes.data_as(c_float_p) else: qw = ctypes.cast(0, c_float_p) - result_size = self.libggml.ggml_quantize_chunk(qtype.value, data.ctypes.data_as(c_float_p), result.ctypes.data_as(ctypes.c_void_p), 0, prod(data.shape[:-1]), data.shape[-1], qw) + result_size = self.libggml.ggml_quantize_chunk( + qtype.value, + data.ctypes.data_as(c_float_p), + result.ctypes.data_as(ctypes.c_void_p), + 0, + prod(data.shape[:-1]), + data.shape[-1], + qw, + ) assert result.size == result_size return result -def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) -> bool: +def compare_tensors( + t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType +) -> bool: same = np.array_equal(t1, t2) if same: return True @@ -130,20 +195,30 @@ def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType) if num_bad_blocks == 0 and t1.shape == t2.shape: logger.debug("Bits are equal, but arrays don't match, likely contains NANs") return True - logger.debug(f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)") + logger.debug( + f"{num_bad_blocks} bad blocks ({100 * num_bad_blocks / x.shape[0]:.6f}%)" + ) bad_block_id = np.argmax(diff_bits, axis=0) logger.debug(f"Worst block id: {bad_block_id}") - logger.debug(f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}") + logger.debug( + f"Sample bad block ({diff_bits[bad_block_id]} differing bits):\n{t1[bad_block_id]}\nReference:\n{t2[bad_block_id]}" + ) sum_diff_bits = np.sum(diff_bits) - logger.debug(f"{sum_diff_bits} bits differ ({100 * sum_diff_bits / (x.size * 8):.6f}%)") + logger.debug( + f"{sum_diff_bits} bits differ ({100 * sum_diff_bits / (x.size * 8):.6f}%)" + ) return False def do_test(libggml_path: Path, quick: bool = False): ggml_quants = GGMLQuants(libggml_path) - np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n}) + np.set_printoptions( + precision=None, + threshold=(4 * 256) + 1, + formatter={"int": lambda n: "0x%02X" % n}, + ) r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False) @@ -152,14 +227,18 @@ def do_test(libggml_path: Path, quick: bool = False): has_quantize = False try: - gguf.dequantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][1]), dtype=np.uint8), qtype) + gguf.dequantize( + np.zeros((gguf.GGML_QUANT_SIZES[qtype][1]), dtype=np.uint8), qtype + ) has_dequantize = True except (NotImplementedError, AssertionError) as e: if isinstance(e, AssertionError): logger.error(f"Error with {qtype.name}: {e}") raise e try: - gguf.quantize(np.zeros((gguf.GGML_QUANT_SIZES[qtype][0]), dtype=np.float32), qtype) + gguf.quantize( + np.zeros((gguf.GGML_QUANT_SIZES[qtype][0]), dtype=np.float32), qtype + ) has_quantize = True except (NotImplementedError, AssertionError) as e: if isinstance(e, AssertionError): @@ -210,7 +289,9 @@ def do_test(libggml_path: Path, quick: bool = False): else: logger.info(f"Dequantization from {qtype.name} matches exactly ✅") - rq_shape = gguf.quants.quant_shape_to_byte_shape((8, 1024, 1024 // 2), qtype) + rq_shape = gguf.quants.quant_shape_to_byte_shape( + (8, 1024, 1024 // 2), qtype + ) rq = np.random.random(rq_shape).astype(np.float16).view(np.uint8) logger.debug(f"Dequantizing random f16 data as {qtype.name} with Python") @@ -221,15 +302,34 @@ def do_test(libggml_path: Path, quick: bool = False): dequant_equal = compare_tensors(pydq, ggdq, qtype) if not dequant_equal: - logger.error(f"Dequantization from random f16 data as {qtype.name} does not match ❌") + logger.error( + f"Dequantization from random f16 data as {qtype.name} does not match ❌" + ) else: - logger.info(f"Dequantization from random f16 data as {qtype.name} matches exactly ✅") + logger.info( + f"Dequantization from random f16 data as {qtype.name} matches exactly ✅" + ) if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation") - parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "ggml" / "src" / "libggml.so", help="The path to libggml.so") - parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary") + parser = argparse.ArgumentParser( + description="Test Python (de)quantization against the reference C implementation" + ) + parser.add_argument( + "--libggml", + type=Path, + default=Path(__file__).parent.parent.parent + / "build" + / "ggml" + / "src" + / "libggml.so", + help="The path to libggml.so", + ) + parser.add_argument( + "--quick", + action="store_true", + help="Don't quantize with C when it's not strictly necessary", + ) args = parser.parse_args() diff --git a/smallthinker/powerinfer/libaz/external/fmt/support/docopt.py b/smallthinker/powerinfer/libaz/external/fmt/support/docopt.py index 2e43f7ce..e3e7a730 100644 --- a/smallthinker/powerinfer/libaz/external/fmt/support/docopt.py +++ b/smallthinker/powerinfer/libaz/external/fmt/support/docopt.py @@ -1,32 +1,31 @@ """Pythonic command-line interface parser that will make you smile. - * http://docopt.org - * Repository and issue-tracker: https://github.com/docopt/docopt - * Licensed under terms of MIT license (see LICENSE-MIT) - * Copyright (c) 2013 Vladimir Keleshev, vladimir@keleshev.com +* http://docopt.org +* Repository and issue-tracker: https://github.com/docopt/docopt +* Licensed under terms of MIT license (see LICENSE-MIT) +* Copyright (c) 2013 Vladimir Keleshev, vladimir@keleshev.com """ + import sys import re -__all__ = ['docopt'] -__version__ = '0.6.1' +__all__ = ["docopt"] +__version__ = "0.6.1" class DocoptLanguageError(Exception): - """Error in construction of usage-message by developer.""" class DocoptExit(SystemExit): - """Exit in case user invoked program with incorrect arguments.""" - usage = '' + usage = "" - def __init__(self, message=''): - SystemExit.__init__(self, (message + '\n' + self.usage).strip()) + def __init__(self, message=""): + SystemExit.__init__(self, (message + "\n" + self.usage).strip()) class Pattern(object): @@ -44,11 +43,11 @@ def fix(self): def fix_identities(self, uniq=None): """Make pattern-tree tips point to same object if they are equal.""" - if not hasattr(self, 'children'): + if not hasattr(self, "children"): return self uniq = list(set(self.flat())) if uniq is None else uniq for i, child in enumerate(self.children): - if not hasattr(child, 'children'): + if not hasattr(child, "children"): assert child in uniq self.children[i] = uniq[uniq.index(child)] else: @@ -97,14 +96,13 @@ def transform(pattern): class LeafPattern(Pattern): - """Leaf/terminal node of a pattern tree.""" def __init__(self, name, value=None): self.name, self.value = name, value def __repr__(self): - return '%s(%r, %r)' % (self.__class__.__name__, self.name, self.value) + return "%s(%r, %r)" % (self.__class__.__name__, self.name, self.value) def flat(self, *types): return [self] if not types or type(self) in types else [] @@ -114,14 +112,13 @@ def match(self, left, collected=None): pos, match = self.single_match(left) if match is None: return False, left, collected - left_ = left[:pos] + left[pos + 1:] + left_ = left[:pos] + left[pos + 1 :] same_name = [a for a in collected if a.name == self.name] if type(self.value) in (int, list): if type(self.value) is int: increment = 1 else: - increment = ([match.value] if type(match.value) is str - else match.value) + increment = [match.value] if type(match.value) is str else match.value if not same_name: match.value = increment return True, left_, collected + [match] @@ -131,15 +128,16 @@ def match(self, left, collected=None): class BranchPattern(Pattern): - """Branch/inner node of a pattern tree.""" def __init__(self, *children): self.children = list(children) def __repr__(self): - return '%s(%s)' % (self.__class__.__name__, - ', '.join(repr(a) for a in self.children)) + return "%s(%s)" % ( + self.__class__.__name__, + ", ".join(repr(a) for a in self.children), + ) def flat(self, *types): if type(self) in types: @@ -157,8 +155,8 @@ def single_match(self, left): @classmethod def parse(class_, source): - name = re.findall('(<\S*?>)', source)[0] - value = re.findall('\[default: (.*)\]', source, flags=re.I) + name = re.findall("(<\S*?>)", source)[0] + value = re.findall("\[default: (.*)\]", source, flags=re.I) return class_(name, value[0] if value else None) @@ -187,17 +185,17 @@ def __init__(self, short=None, long=None, argcount=0, value=False): @classmethod def parse(class_, option_description): short, long, argcount, value = None, None, 0, False - options, _, description = option_description.strip().partition(' ') - options = options.replace(',', ' ').replace('=', ' ') + options, _, description = option_description.strip().partition(" ") + options = options.replace(",", " ").replace("=", " ") for s in options.split(): - if s.startswith('--'): + if s.startswith("--"): long = s - elif s.startswith('-'): + elif s.startswith("-"): short = s else: argcount = 1 if argcount: - matched = re.findall('\[default: (.*)\]', description, flags=re.I) + matched = re.findall("\[default: (.*)\]", description, flags=re.I) value = matched[0] if matched else None return class_(short, long, argcount, value) @@ -212,8 +210,12 @@ def name(self): return self.long or self.short def __repr__(self): - return 'Option(%r, %r, %r, %r)' % (self.short, self.long, - self.argcount, self.value) + return "Option(%r, %r, %r, %r)" % ( + self.short, + self.long, + self.argcount, + self.value, + ) class Required(BranchPattern): @@ -239,7 +241,6 @@ def match(self, left, collected=None): class OptionsShortcut(Optional): - """Marker/placeholder for [options] shortcut.""" @@ -282,13 +283,13 @@ def match(self, left, collected=None): class Tokens(list): def __init__(self, source, error=DocoptExit): - self += source.split() if hasattr(source, 'split') else source + self += source.split() if hasattr(source, "split") else source self.error = error @staticmethod def from_pattern(source): - source = re.sub(r'([\[\]\(\)\|]|\.\.\.)', r' \1 ', source) - source = [s for s in re.split('\s+|(\S*<.*?>)', source) if s] + source = re.sub(r"([\[\]\(\)\|]|\.\.\.)", r" \1 ", source) + source = [s for s in re.split("\s+|(\S*<.*?>)", source) if s] return Tokens(source, error=DocoptLanguageError) def move(self): @@ -300,31 +301,34 @@ def current(self): def parse_long(tokens, options): """long ::= '--' chars [ ( ' ' | '=' ) chars ] ;""" - long, eq, value = tokens.move().partition('=') - assert long.startswith('--') - value = None if eq == value == '' else value + long, eq, value = tokens.move().partition("=") + assert long.startswith("--") + value = None if eq == value == "" else value similar = [o for o in options if o.long == long] if tokens.error is DocoptExit and similar == []: # if no exact match similar = [o for o in options if o.long and o.long.startswith(long)] if len(similar) > 1: # might be simply specified ambiguously 2+ times? - raise tokens.error('%s is not a unique prefix: %s?' % - (long, ', '.join(o.long for o in similar))) + raise tokens.error( + "%s is not a unique prefix: %s?" + % (long, ", ".join(o.long for o in similar)) + ) elif len(similar) < 1: - argcount = 1 if eq == '=' else 0 + argcount = 1 if eq == "=" else 0 o = Option(None, long, argcount) options.append(o) if tokens.error is DocoptExit: o = Option(None, long, argcount, value if argcount else True) else: - o = Option(similar[0].short, similar[0].long, - similar[0].argcount, similar[0].value) + o = Option( + similar[0].short, similar[0].long, similar[0].argcount, similar[0].value + ) if o.argcount == 0: if value is not None: - raise tokens.error('%s must not have an argument' % o.long) + raise tokens.error("%s must not have an argument" % o.long) else: if value is None: - if tokens.current() in [None, '--']: - raise tokens.error('%s requires argument' % o.long) + if tokens.current() in [None, "--"]: + raise tokens.error("%s requires argument" % o.long) value = tokens.move() if tokens.error is DocoptExit: o.value = value if value is not None else True @@ -334,32 +338,32 @@ def parse_long(tokens, options): def parse_shorts(tokens, options): """shorts ::= '-' ( chars )* [ [ ' ' ] chars ] ;""" token = tokens.move() - assert token.startswith('-') and not token.startswith('--') - left = token.lstrip('-') + assert token.startswith("-") and not token.startswith("--") + left = token.lstrip("-") parsed = [] - while left != '': - short, left = '-' + left[0], left[1:] + while left != "": + short, left = "-" + left[0], left[1:] similar = [o for o in options if o.short == short] if len(similar) > 1: - raise tokens.error('%s is specified ambiguously %d times' % - (short, len(similar))) + raise tokens.error( + "%s is specified ambiguously %d times" % (short, len(similar)) + ) elif len(similar) < 1: o = Option(short, None, 0) options.append(o) if tokens.error is DocoptExit: o = Option(short, None, 0, True) else: # why copying is necessary here? - o = Option(short, similar[0].long, - similar[0].argcount, similar[0].value) + o = Option(short, similar[0].long, similar[0].argcount, similar[0].value) value = None if o.argcount != 0: - if left == '': - if tokens.current() in [None, '--']: - raise tokens.error('%s requires argument' % short) + if left == "": + if tokens.current() in [None, "--"]: + raise tokens.error("%s requires argument" % short) value = tokens.move() else: value = left - left = '' + left = "" if tokens.error is DocoptExit: o.value = value if value is not None else True parsed.append(o) @@ -370,17 +374,17 @@ def parse_pattern(source, options): tokens = Tokens.from_pattern(source) result = parse_expr(tokens, options) if tokens.current() is not None: - raise tokens.error('unexpected ending: %r' % ' '.join(tokens)) + raise tokens.error("unexpected ending: %r" % " ".join(tokens)) return Required(*result) def parse_expr(tokens, options): """expr ::= seq ( '|' seq )* ;""" seq = parse_seq(tokens, options) - if tokens.current() != '|': + if tokens.current() != "|": return seq result = [Required(*seq)] if len(seq) > 1 else seq - while tokens.current() == '|': + while tokens.current() == "|": tokens.move() seq = parse_seq(tokens, options) result += [Required(*seq)] if len(seq) > 1 else seq @@ -390,9 +394,9 @@ def parse_expr(tokens, options): def parse_seq(tokens, options): """seq ::= ( atom [ '...' ] )* ;""" result = [] - while tokens.current() not in [None, ']', ')', '|']: + while tokens.current() not in [None, "]", ")", "|"]: atom = parse_atom(tokens, options) - if tokens.current() == '...': + if tokens.current() == "...": atom = [OneOrMore(*atom)] tokens.move() result += atom @@ -401,25 +405,25 @@ def parse_seq(tokens, options): def parse_atom(tokens, options): """atom ::= '(' expr ')' | '[' expr ']' | 'options' - | long | shorts | argument | command ; + | long | shorts | argument | command ; """ token = tokens.current() result = [] - if token in '([': + if token in "([": tokens.move() - matching, pattern = {'(': [')', Required], '[': [']', Optional]}[token] + matching, pattern = {"(": [")", Required], "[": ["]", Optional]}[token] result = pattern(*parse_expr(tokens, options)) if tokens.move() != matching: raise tokens.error("unmatched '%s'" % token) return [result] - elif token == 'options': + elif token == "options": tokens.move() return [OptionsShortcut()] - elif token.startswith('--') and token != '--': + elif token.startswith("--") and token != "--": return parse_long(tokens, options) - elif token.startswith('-') and token not in ('-', '--'): + elif token.startswith("-") and token not in ("-", "--"): return parse_shorts(tokens, options) - elif token.startswith('<') and token.endswith('>') or token.isupper(): + elif token.startswith("<") and token.endswith(">") or token.isupper(): return [Argument(tokens.move())] else: return [Command(tokens.move())] @@ -436,11 +440,11 @@ def parse_argv(tokens, options, options_first=False): """ parsed = [] while tokens.current() is not None: - if tokens.current() == '--': + if tokens.current() == "--": return parsed + [Argument(None, v) for v in tokens] - elif tokens.current().startswith('--'): + elif tokens.current().startswith("--"): parsed += parse_long(tokens, options) - elif tokens.current().startswith('-') and tokens.current() != '-': + elif tokens.current().startswith("-") and tokens.current() != "-": parsed += parse_shorts(tokens, options) elif options_first: return parsed + [Argument(None, v) for v in tokens] @@ -451,40 +455,42 @@ def parse_argv(tokens, options, options_first=False): def parse_defaults(doc): defaults = [] - for s in parse_section('options:', doc): + for s in parse_section("options:", doc): # FIXME corner case "bla: options: --foo" - _, _, s = s.partition(':') # get rid of "options:" - split = re.split('\n[ \t]*(-\S+?)', '\n' + s)[1:] + _, _, s = s.partition(":") # get rid of "options:" + split = re.split("\n[ \t]*(-\S+?)", "\n" + s)[1:] split = [s1 + s2 for s1, s2 in zip(split[::2], split[1::2])] - options = [Option.parse(s) for s in split if s.startswith('-')] + options = [Option.parse(s) for s in split if s.startswith("-")] defaults += options return defaults def parse_section(name, source): - pattern = re.compile('^([^\n]*' + name + '[^\n]*\n?(?:[ \t].*?(?:\n|$))*)', - re.IGNORECASE | re.MULTILINE) + pattern = re.compile( + "^([^\n]*" + name + "[^\n]*\n?(?:[ \t].*?(?:\n|$))*)", + re.IGNORECASE | re.MULTILINE, + ) return [s.strip() for s in pattern.findall(source)] def formal_usage(section): - _, _, section = section.partition(':') # drop "usage:" + _, _, section = section.partition(":") # drop "usage:" pu = section.split() - return '( ' + ' '.join(') | (' if s == pu[0] else s for s in pu[1:]) + ' )' + return "( " + " ".join(") | (" if s == pu[0] else s for s in pu[1:]) + " )" def extras(help, version, options, doc): - if help and any((o.name in ('-h', '--help')) and o.value for o in options): + if help and any((o.name in ("-h", "--help")) and o.value for o in options): print(doc.strip("\n")) sys.exit() - if version and any(o.name == '--version' and o.value for o in options): + if version and any(o.name == "--version" and o.value for o in options): print(version) sys.exit() class Dict(dict): def __repr__(self): - return '{%s}' % ',\n '.join('%r: %r' % i for i in sorted(self.items())) + return "{%s}" % ",\n ".join("%r: %r" % i for i in sorted(self.items())) def docopt(doc, argv=None, help=True, version=None, options_first=False): @@ -552,7 +558,7 @@ def docopt(doc, argv=None, help=True, version=None, options_first=False): """ argv = sys.argv[1:] if argv is None else argv - usage_sections = parse_section('usage:', doc) + usage_sections = parse_section("usage:", doc) if len(usage_sections) == 0: raise DocoptLanguageError('"usage:" (case-insensitive) not found.') if len(usage_sections) > 1: @@ -562,7 +568,7 @@ def docopt(doc, argv=None, help=True, version=None, options_first=False): options = parse_defaults(doc) pattern = parse_pattern(formal_usage(DocoptExit.usage), options) # [default] syntax for argument is disabled - #for a in pattern.flat(Argument): + # for a in pattern.flat(Argument): # same_name = [d for d in arguments if d.name == a.name] # if same_name: # a.value = same_name[0].value @@ -571,7 +577,7 @@ def docopt(doc, argv=None, help=True, version=None, options_first=False): for options_shortcut in pattern.flat(OptionsShortcut): doc_options = parse_defaults(doc) options_shortcut.children = list(set(doc_options) - pattern_options) - #if any_options: + # if any_options: # options_shortcut.children += [Option(o.short, o.long, o.argcount) # for o in argv if type(o) is Option] extras(help, version, argv, doc) diff --git a/smallthinker/powerinfer/libaz/external/fmt/support/printable.py b/smallthinker/powerinfer/libaz/external/fmt/support/printable.py index 8fa86b30..7b274838 100644 --- a/smallthinker/powerinfer/libaz/external/fmt/support/printable.py +++ b/smallthinker/powerinfer/libaz/external/fmt/support/printable.py @@ -13,7 +13,8 @@ import os import subprocess -NUM_CODEPOINTS=0x110000 +NUM_CODEPOINTS = 0x110000 + def to_ranges(iter): current = None @@ -27,11 +28,15 @@ def to_ranges(iter): if current is not None: yield tuple(current) + def get_escaped(codepoints): for c in codepoints: - if (c.class_ or "Cn") in "Cc Cf Cs Co Cn Zl Zp Zs".split() and c.value != ord(' '): + if (c.class_ or "Cn") in "Cc Cf Cs Co Cn Zl Zp Zs".split() and c.value != ord( + " " + ): yield c.value + def get_file(f): try: return open(os.path.basename(f)) @@ -39,7 +44,9 @@ def get_file(f): subprocess.run(["curl", "-O", f], check=True) return open(os.path.basename(f)) -Codepoint = namedtuple('Codepoint', 'value class_') + +Codepoint = namedtuple("Codepoint", "value class_") + def get_codepoints(f): r = csv.reader(f, delimiter=";") @@ -70,13 +77,14 @@ def get_codepoints(f): for c in range(prev_codepoint + 1, NUM_CODEPOINTS): yield Codepoint(c, None) + def compress_singletons(singletons): - uppers = [] # (upper, # items in lowers) + uppers = [] # (upper, # items in lowers) lowers = [] for i in singletons: upper = i >> 8 - lower = i & 0xff + lower = i & 0xFF if len(uppers) == 0 or uppers[-1][0] != upper: uppers.append((upper, 1)) else: @@ -86,10 +94,11 @@ def compress_singletons(singletons): return uppers, lowers + def compress_normal(normal): # lengths 0x00..0x7f are encoded as 00, 01, ..., 7e, 7f # lengths 0x80..0x7fff are encoded as 80 80, 80 81, ..., ff fe, ff ff - compressed = [] # [truelen, (truelenaux), falselen, (falselenaux)] + compressed = [] # [truelen, (truelenaux), falselen, (falselenaux)] prev_start = 0 for start, count in normal: @@ -99,21 +108,22 @@ def compress_normal(normal): assert truelen < 0x8000 and falselen < 0x8000 entry = [] - if truelen > 0x7f: + if truelen > 0x7F: entry.append(0x80 | (truelen >> 8)) - entry.append(truelen & 0xff) + entry.append(truelen & 0xFF) else: - entry.append(truelen & 0x7f) - if falselen > 0x7f: + entry.append(truelen & 0x7F) + if falselen > 0x7F: entry.append(0x80 | (falselen >> 8)) - entry.append(falselen & 0xff) + entry.append(falselen & 0xFF) else: - entry.append(falselen & 0x7f) + entry.append(falselen & 0x7F) compressed.append(entry) return compressed + def print_singletons(uppers, lowers, uppersname, lowersname): print(" static constexpr singleton {}[] = {{".format(uppersname)) for u, c in uppers: @@ -121,21 +131,25 @@ def print_singletons(uppers, lowers, uppersname, lowersname): print(" };") print(" static constexpr unsigned char {}[] = {{".format(lowersname)) for i in range(0, len(lowers), 8): - print(" {}".format(" ".join("{:#04x},".format(l) for l in lowers[i:i+8]))) + print( + " {}".format(" ".join("{:#04x},".format(l) for l in lowers[i : i + 8])) + ) print(" };") + def print_normal(normal, normalname): print(" static constexpr unsigned char {}[] = {{".format(normalname)) for v in normal: print(" {}".format(" ".join("{:#04x},".format(i) for i in v))) print(" };") + def main(): file = get_file("https://www.unicode.org/Public/UNIDATA/UnicodeData.txt") codepoints = get_codepoints(file) - CUTOFF=0x10000 + CUTOFF = 0x10000 singletons0 = [] singletons1 = [] normal0 = [] @@ -170,14 +184,17 @@ def main(): normal0 = compress_normal(normal0) normal1 = compress_normal(normal1) - print("""\ + print( + """\ FMT_FUNC auto is_printable(uint32_t cp) -> bool {\ -""") - print_singletons(singletons0u, singletons0l, 'singletons0', 'singletons0_lower') - print_singletons(singletons1u, singletons1l, 'singletons1', 'singletons1_lower') - print_normal(normal0, 'normal0') - print_normal(normal1, 'normal1') - print("""\ +""" + ) + print_singletons(singletons0u, singletons0l, "singletons0", "singletons0_lower") + print_singletons(singletons1u, singletons1l, "singletons1", "singletons1_lower") + print_normal(normal0, "normal0") + print_normal(normal1, "normal1") + print( + """\ auto lower = static_cast(cp); if (cp < 0x10000) { return is_printable(lower, singletons0, @@ -189,13 +206,19 @@ def main(): sizeof(singletons1) / sizeof(*singletons1), singletons1_lower, normal1, sizeof(normal1)); }\ -""") +""" + ) for a, b in extra: print(" if (0x{:x} <= cp && cp < 0x{:x}) return false;".format(a, a + b)) - print("""\ + print( + """\ return cp < 0x{:x}; }}\ -""".format(NUM_CODEPOINTS)) +""".format( + NUM_CODEPOINTS + ) + ) + -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/smallthinker/powerinfer/libaz/external/fmt/support/python/mkdocstrings_handlers/cxx/__init__.py b/smallthinker/powerinfer/libaz/external/fmt/support/python/mkdocstrings_handlers/cxx/__init__.py index 85d52a58..14c641d2 100644 --- a/smallthinker/powerinfer/libaz/external/fmt/support/python/mkdocstrings_handlers/cxx/__init__.py +++ b/smallthinker/powerinfer/libaz/external/fmt/support/python/mkdocstrings_handlers/cxx/__init__.py @@ -14,11 +14,15 @@ class Definition: """A definition extracted by Doxygen.""" - def __init__(self, name: str, kind: Optional[str] = None, - node: Optional[ElementTree.Element] = None, - is_member: bool = False): + def __init__( + self, + name: str, + kind: Optional[str] = None, + node: Optional[ElementTree.Element] = None, + is_member: bool = False, + ): self.name = name - self.kind = kind if kind is not None else node.get('kind') + self.kind = kind if kind is not None else node.get("kind") self.desc = None self.id = name if not is_member else None self.members = None @@ -30,20 +34,16 @@ def __init__(self, name: str, kind: Optional[str] = None, # A map from Doxygen to HTML tags. tag_map = { - 'bold': 'b', - 'emphasis': 'em', - 'computeroutput': 'code', - 'para': 'p', - 'programlisting': 'pre', - 'verbatim': 'pre' + "bold": "b", + "emphasis": "em", + "computeroutput": "code", + "para": "p", + "programlisting": "pre", + "verbatim": "pre", } # A map from Doxygen tags to text. -tag_text_map = { - 'codeline': '', - 'highlight': '', - 'sp': ' ' -} +tag_text_map = {"codeline": "", "highlight": "", "sp": " "} def escape_html(s: str) -> str: @@ -51,50 +51,51 @@ def escape_html(s: str) -> str: def doxyxml2html(nodes: List[ElementTree.Element]): - out = '' + out = "" for n in nodes: tag = tag_map.get(n.tag) if not tag: out += tag_text_map[n.tag] - out += '<' + tag + '>' if tag else '' - out += '' if tag == 'pre' else '' + out += "<" + tag + ">" if tag else "" + out += '' if tag == "pre" else "" if n.text: out += escape_html(n.text) out += doxyxml2html(list(n)) - out += '' if tag == 'pre' else '' - out += '' if tag else '' + out += "" if tag == "pre" else "" + out += "" if tag else "" if n.tail: out += n.tail return out def convert_template_params(node: ElementTree.Element) -> Optional[List[Definition]]: - template_param_list = node.find('templateparamlist') + template_param_list = node.find("templateparamlist") if template_param_list is None: return None params = [] - for param_node in template_param_list.findall('param'): - name = param_node.find('declname') - param = Definition(name.text if name is not None else '', 'param') - param.type = param_node.find('type').text + for param_node in template_param_list.findall("param"): + name = param_node.find("declname") + param = Definition(name.text if name is not None else "", "param") + param.type = param_node.find("type").text params.append(param) return params def get_description(node: ElementTree.Element) -> List[ElementTree.Element]: - return node.findall('briefdescription/para') + \ - node.findall('detaileddescription/para') + return node.findall("briefdescription/para") + node.findall( + "detaileddescription/para" + ) def normalize_type(type_: str) -> str: - type_ = type_.replace('< ', '<').replace(' >', '>') - return type_.replace(' &', '&').replace(' *', '*') + type_ = type_.replace("< ", "<").replace(" >", ">") + return type_.replace(" &", "&").replace(" *", "*") def convert_type(type_: ElementTree.Element) -> Optional[str]: if type_ is None: return None - result = type_.text if type_.text else '' + result = type_.text if type_.text else "" for ref in type_: result += ref.text if ref.tail: @@ -105,87 +106,99 @@ def convert_type(type_: ElementTree.Element) -> Optional[str]: def convert_params(func: ElementTree.Element) -> List[Definition]: params = [] - for p in func.findall('param'): - d = Definition(p.find('declname').text, 'param') - d.type = convert_type(p.find('type')) + for p in func.findall("param"): + d = Definition(p.find("declname").text, "param") + d.type = convert_type(p.find("type")) params.append(d) return params def convert_return_type(d: Definition, node: ElementTree.Element) -> None: d.trailing_return_type = None - if d.type == 'auto' or d.type == 'constexpr auto': - parts = node.find('argsstring').text.split(' -> ') + if d.type == "auto" or d.type == "constexpr auto": + parts = node.find("argsstring").text.split(" -> ") if len(parts) > 1: d.trailing_return_type = normalize_type(parts[1]) def render_param(param: Definition) -> str: - return param.type + (f' {param.name}' if len(param.name) > 0 else '') + return param.type + (f" {param.name}" if len(param.name) > 0 else "") def render_decl(d: Definition) -> str: - text = '' + text = "" if d.id is not None: text += f'\n' text += '
'
 
-    text += '
' + text += "
" if d.template_params is not None: - text += 'template <' - text += ', '.join([render_param(p) for p in d.template_params]) - text += '>\n' - text += '
' - - text += '
' - end = ';' - if d.kind == 'function' or d.kind == 'variable': - text += d.type + ' ' if len(d.type) > 0 else '' - elif d.kind == 'typedef': - text += 'using ' - elif d.kind == 'define': - end = '' + text += "template <" + text += ", ".join([render_param(p) for p in d.template_params]) + text += ">\n" + text += "
" + + text += "
" + end = ";" + if d.kind == "function" or d.kind == "variable": + text += d.type + " " if len(d.type) > 0 else "" + elif d.kind == "typedef": + text += "using " + elif d.kind == "define": + end = "" else: - text += d.kind + ' ' + text += d.kind + " " text += d.name if d.params is not None: - params = ', '.join([ - (p.type + ' ' if p.type else '') + p.name for p in d.params]) - text += '(' + escape_html(params) + ')' + params = ", ".join( + [(p.type + " " if p.type else "") + p.name for p in d.params] + ) + text += "(" + escape_html(params) + ")" if d.trailing_return_type: - text += ' -⁠> ' + escape_html(d.trailing_return_type) - elif d.kind == 'typedef': - text += ' = ' + escape_html(d.type) + text += " -⁠> " + escape_html(d.trailing_return_type) + elif d.kind == "typedef": + text += " = " + escape_html(d.type) text += end - text += '
' - text += '
\n' + text += "" + text += "\n" if d.id is not None: - text += f'
\n' + text += f"\n" return text class CxxHandler(BaseHandler): def __init__(self, **kwargs: Any) -> None: - super().__init__(handler='cxx', **kwargs) + super().__init__(handler="cxx", **kwargs) headers = [ - 'args.h', 'base.h', 'chrono.h', 'color.h', 'compile.h', 'format.h', - 'os.h', 'ostream.h', 'printf.h', 'ranges.h', 'std.h', 'xchar.h' + "args.h", + "base.h", + "chrono.h", + "color.h", + "compile.h", + "format.h", + "os.h", + "ostream.h", + "printf.h", + "ranges.h", + "std.h", + "xchar.h", ] # Run doxygen. - cmd = ['doxygen', '-'] + cmd = ["doxygen", "-"] support_dir = Path(__file__).parents[3] top_dir = os.path.dirname(support_dir) - include_dir = os.path.join(top_dir, 'include', 'fmt') + include_dir = os.path.join(top_dir, "include", "fmt") self._ns2doxyxml = {} - build_dir = os.path.join(top_dir, 'build') + build_dir = os.path.join(top_dir, "build") os.makedirs(build_dir, exist_ok=True) - self._doxyxml_dir = os.path.join(build_dir, 'doxyxml') + self._doxyxml_dir = os.path.join(build_dir, "doxyxml") p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=STDOUT) - _, _ = p.communicate(input=r''' + _, _ = p.communicate( + input=r""" PROJECT_NAME = fmt GENERATE_XML = YES GENERATE_LATEX = NO @@ -205,9 +218,13 @@ def __init__(self, **kwargs: Any) -> None: "FMT_BEGIN_NAMESPACE=namespace fmt {{" \ "FMT_END_NAMESPACE=}}" \ "FMT_DOC=1" - '''.format( - ' '.join([os.path.join(include_dir, h) for h in headers]), - self._doxyxml_dir).encode('utf-8')) + """.format( + " ".join([os.path.join(include_dir, h) for h in headers]), + self._doxyxml_dir, + ).encode( + "utf-8" + ) + ) if p.returncode != 0: raise CalledProcessError(p.returncode, cmd) @@ -224,33 +241,34 @@ def __init__(self, **kwargs: Any) -> None: for node in doxyxml.getroot(): root.append(node) - def collect_compound(self, identifier: str, - cls: List[ElementTree.Element]) -> Definition: + def collect_compound( + self, identifier: str, cls: List[ElementTree.Element] + ) -> Definition: """Collect a compound definition such as a struct.""" - path = os.path.join(self._doxyxml_dir, cls[0].get('refid') + '.xml') + path = os.path.join(self._doxyxml_dir, cls[0].get("refid") + ".xml") with open(path) as f: xml = ElementTree.parse(f) - node = xml.find('compounddef') + node = xml.find("compounddef") d = Definition(identifier, node=node) d.template_params = convert_template_params(node) d.desc = get_description(node) d.members = [] - for m in \ - node.findall('sectiondef[@kind="public-attrib"]/memberdef') + \ - node.findall('sectiondef[@kind="public-func"]/memberdef'): - name = m.find('name').text + for m in node.findall( + 'sectiondef[@kind="public-attrib"]/memberdef' + ) + node.findall('sectiondef[@kind="public-func"]/memberdef'): + name = m.find("name").text # Doxygen incorrectly classifies members of private unnamed unions as # public members of the containing class. - if name.endswith('_'): + if name.endswith("_"): continue desc = get_description(m) if len(desc) == 0: continue - kind = m.get('kind') - member = Definition(name if name else '', kind=kind, is_member=True) - type_text = m.find('type').text - member.type = type_text if type_text else '' - if kind == 'function': + kind = m.get("kind") + member = Definition(name if name else "", kind=kind, is_member=True) + type_text = m.find("type").text + member.type = type_text if type_text else "" + if kind == "function": member.params = convert_params(m) convert_return_type(member, m) member.template_params = None @@ -259,15 +277,15 @@ def collect_compound(self, identifier: str, return d def collect(self, identifier: str, _config: Mapping[str, Any]) -> Definition: - qual_name = 'fmt::' + identifier + qual_name = "fmt::" + identifier param_str = None - paren = qual_name.find('(') + paren = qual_name.find("(") if paren > 0: - qual_name, param_str = qual_name[:paren], qual_name[paren + 1:-1] + qual_name, param_str = qual_name[:paren], qual_name[paren + 1 : -1] - colons = qual_name.rfind('::') - namespace, name = qual_name[:colons], qual_name[colons + 2:] + colons = qual_name.rfind("::") + namespace, name = qual_name[:colons], qual_name[colons + 2 :] # Load XML. doxyxml = self._ns2doxyxml.get(namespace) @@ -277,29 +295,29 @@ def collect(self, identifier: str, _config: Mapping[str, Any]) -> Definition: doxyxml = ElementTree.parse(f) self._ns2doxyxml[namespace] = doxyxml - nodes = doxyxml.findall( - f"compounddef/sectiondef/memberdef/name[.='{name}']/..") + nodes = doxyxml.findall(f"compounddef/sectiondef/memberdef/name[.='{name}']/..") if len(nodes) == 0: nodes = self._file_doxyxml.findall( - f"compounddef/sectiondef/memberdef/name[.='{name}']/..") + f"compounddef/sectiondef/memberdef/name[.='{name}']/.." + ) candidates = [] for node in nodes: # Process a function or a typedef. params = None d = Definition(name, node=node) - if d.kind == 'function': + if d.kind == "function": params = convert_params(node) - node_param_str = ', '.join([p.type for p in params]) + node_param_str = ", ".join([p.type for p in params]) if param_str and param_str != node_param_str: - candidates.append(f'{name}({node_param_str})') + candidates.append(f"{name}({node_param_str})") continue - elif d.kind == 'define': + elif d.kind == "define": params = [] - for p in node.findall('param'): - param = Definition(p.find('defname').text, kind='param') + for p in node.findall("param"): + param = Definition(p.find("defname").text, kind="param") param.type = None params.append(param) - d.type = convert_type(node.find('type')) + d.type = convert_type(node.find("type")) d.template_params = convert_template_params(node) d.params = params convert_return_type(d, node) @@ -308,12 +326,12 @@ def collect(self, identifier: str, _config: Mapping[str, Any]) -> Definition: cls = doxyxml.findall(f"compounddef/innerclass[.='{qual_name}']") if not cls: - raise Exception(f'Cannot find {identifier}. Candidates: {candidates}') + raise Exception(f"Cannot find {identifier}. Candidates: {candidates}") return self.collect_compound(identifier, cls) def render(self, d: Definition, config: dict) -> str: if d.id is not None: - self.do_heading('', 0, id=d.id) + self.do_heading("", 0, id=d.id) text = '
\n' text += render_decl(d) text += '
\n' @@ -321,13 +339,14 @@ def render(self, d: Definition, config: dict) -> str: if d.members is not None: for m in d.members: text += self.render(m, config) - text += '
\n' - text += '
\n' + text += "\n" + text += "\n" return text -def get_handler(theme: str, custom_templates: Optional[str] = None, - **_config: Any) -> CxxHandler: +def get_handler( + theme: str, custom_templates: Optional[str] = None, **_config: Any +) -> CxxHandler: """Return an instance of `CxxHandler`. Arguments: diff --git a/smallthinker/powerinfer/libaz/external/fmt/support/release.py b/smallthinker/powerinfer/libaz/external/fmt/support/release.py index 26de7f4f..416431d1 100644 --- a/smallthinker/powerinfer/libaz/external/fmt/support/release.py +++ b/smallthinker/powerinfer/libaz/external/fmt/support/release.py @@ -21,31 +21,31 @@ def __init__(self, dir): self.dir = dir def call(self, method, args, **kwargs): - return check_call(['git', method] + list(args), **kwargs) + return check_call(["git", method] + list(args), **kwargs) def add(self, *args): - return self.call('add', args, cwd=self.dir) + return self.call("add", args, cwd=self.dir) def checkout(self, *args): - return self.call('checkout', args, cwd=self.dir) + return self.call("checkout", args, cwd=self.dir) def clean(self, *args): - return self.call('clean', args, cwd=self.dir) + return self.call("clean", args, cwd=self.dir) def clone(self, *args): - return self.call('clone', list(args) + [self.dir]) + return self.call("clone", list(args) + [self.dir]) def commit(self, *args): - return self.call('commit', args, cwd=self.dir) + return self.call("commit", args, cwd=self.dir) def pull(self, *args): - return self.call('pull', args, cwd=self.dir) + return self.call("pull", args, cwd=self.dir) def push(self, *args): - return self.call('push', args, cwd=self.dir) + return self.call("push", args, cwd=self.dir) def reset(self, *args): - return self.call('reset', args, cwd=self.dir) + return self.call("reset", args, cwd=self.dir) def update(self, *args): clone = not os.path.exists(self.dir) @@ -55,8 +55,8 @@ def update(self, *args): def clean_checkout(repo, branch): - repo.clean('-f', '-d') - repo.reset('--hard') + repo.clean("-f", "-d") + repo.reset("--hard") repo.checkout(branch) @@ -65,70 +65,71 @@ def __init__(self, cwd): self.cwd = cwd def __call__(self, *args, **kwargs): - kwargs['cwd'] = kwargs.get('cwd', self.cwd) + kwargs["cwd"] = kwargs.get("cwd", self.cwd) check_call(args, **kwargs) def create_build_env(): """Create a build environment.""" + class Env: pass + env = Env() env.fmt_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - env.build_dir = 'build' - env.fmt_repo = Git(os.path.join(env.build_dir, 'fmt')) + env.build_dir = "build" + env.fmt_repo = Git(os.path.join(env.build_dir, "fmt")) return env -if __name__ == '__main__': +if __name__ == "__main__": args = docopt.docopt(__doc__) env = create_build_env() fmt_repo = env.fmt_repo - branch = args.get('') + branch = args.get("") if branch is None: - branch = 'master' - if not fmt_repo.update('-b', branch, 'git@github.com:fmtlib/fmt'): + branch = "master" + if not fmt_repo.update("-b", branch, "git@github.com:fmtlib/fmt"): clean_checkout(fmt_repo, branch) # Update the date in the changelog and extract the version and the first # section content. - changelog = 'ChangeLog.md' + changelog = "ChangeLog.md" changelog_path = os.path.join(fmt_repo.dir, changelog) is_first_section = True first_section = [] for i, line in enumerate(fileinput.input(changelog_path, inplace=True)): if i == 0: - version = re.match(r'# (.*) - TBD', line).group(1) - line = '# {} - {}\n'.format( - version, datetime.date.today().isoformat()) + version = re.match(r"# (.*) - TBD", line).group(1) + line = "# {} - {}\n".format(version, datetime.date.today().isoformat()) elif not is_first_section: pass - elif line.startswith('#'): + elif line.startswith("#"): is_first_section = False else: first_section.append(line) sys.stdout.write(line) - if first_section[0] == '\n': + if first_section[0] == "\n": first_section.pop(0) ns_version = None - base_h_path = os.path.join(fmt_repo.dir, 'include', 'fmt', 'base.h') + base_h_path = os.path.join(fmt_repo.dir, "include", "fmt", "base.h") for line in fileinput.input(base_h_path): - m = re.match(r'\s*inline namespace v(.*) .*', line) + m = re.match(r"\s*inline namespace v(.*) .*", line) if m: ns_version = m.group(1) break - major_version = version.split('.')[0] + major_version = version.split(".")[0] if not ns_version or ns_version != major_version: - raise Exception(f'Version mismatch {ns_version} != {major_version}') + raise Exception(f"Version mismatch {ns_version} != {major_version}") # Workaround GitHub-flavored Markdown treating newlines as
. - changes = '' + changes = "" code_block = False stripped = False for line in first_section: - if re.match(r'^\s*```', line): + if re.match(r"^\s*```", line): code_block = not code_block changes += line stripped = False @@ -136,53 +137,64 @@ class Env: if code_block: changes += line continue - if line == '\n' or re.match(r'^\s*\|.*', line): + if line == "\n" or re.match(r"^\s*\|.*", line): if stripped: - changes += '\n' + changes += "\n" stripped = False changes += line continue if stripped: - line = ' ' + line.lstrip() + line = " " + line.lstrip() changes += line.rstrip() stripped = True - fmt_repo.checkout('-B', 'release') + fmt_repo.checkout("-B", "release") fmt_repo.add(changelog) - fmt_repo.commit('-m', 'Update version') + fmt_repo.commit("-m", "Update version") # Build the docs and package. run = Runner(fmt_repo.dir) - run('cmake', '.') - run('make', 'doc', 'package_source') + run("cmake", ".") + run("make", "doc", "package_source") # Create a release on GitHub. - fmt_repo.push('origin', 'release') - auth_headers = {'Authorization': 'token ' + os.getenv('FMT_TOKEN')} + fmt_repo.push("origin", "release") + auth_headers = {"Authorization": "token " + os.getenv("FMT_TOKEN")} req = urllib.request.Request( - 'https://api.github.com/repos/fmtlib/fmt/releases', - data=json.dumps({'tag_name': version, - 'target_commitish': 'release', - 'body': changes, 'draft': True}).encode('utf-8'), - headers=auth_headers, method='POST') + "https://api.github.com/repos/fmtlib/fmt/releases", + data=json.dumps( + { + "tag_name": version, + "target_commitish": "release", + "body": changes, + "draft": True, + } + ).encode("utf-8"), + headers=auth_headers, + method="POST", + ) with urllib.request.urlopen(req) as response: if response.status != 201: - raise Exception(f'Failed to create a release ' + - '{response.status} {response.reason}') - response_data = json.loads(response.read().decode('utf-8')) - id = response_data['id'] + raise Exception( + f"Failed to create a release " + "{response.status} {response.reason}" + ) + response_data = json.loads(response.read().decode("utf-8")) + id = response_data["id"] # Upload the package. - uploads_url = 'https://uploads.github.com/repos/fmtlib/fmt/releases' - package = 'fmt-{}.zip'.format(version) + uploads_url = "https://uploads.github.com/repos/fmtlib/fmt/releases" + package = "fmt-{}.zip".format(version) req = urllib.request.Request( - f'{uploads_url}/{id}/assets?name={package}', - headers={'Content-Type': 'application/zip'} | auth_headers, - data=open('build/fmt/' + package, 'rb').read(), method='POST') + f"{uploads_url}/{id}/assets?name={package}", + headers={"Content-Type": "application/zip"} | auth_headers, + data=open("build/fmt/" + package, "rb").read(), + method="POST", + ) with urllib.request.urlopen(req) as response: if response.status != 201: - raise Exception(f'Failed to upload an asset ' - '{response.status} {response.reason}') + raise Exception( + f"Failed to upload an asset " "{response.status} {response.reason}" + ) - short_version = '.'.join(version.split('.')[:-1]) - check_call(['./mkdocs', 'deploy', short_version]) + short_version = ".".join(version.split(".")[:-1]) + check_call(["./mkdocs", "deploy", short_version]) diff --git a/smallthinker/powerinfer/libaz/external/googletest/googlemock/test/gmock_leak_test.py b/smallthinker/powerinfer/libaz/external/googletest/googlemock/test/gmock_leak_test.py index 8b02bc46..59ab58cc 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googlemock/test/gmock_leak_test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googlemock/test/gmock_leak_test.py @@ -33,10 +33,10 @@ from googlemock.test import gmock_test_utils -PROGRAM_PATH = gmock_test_utils.GetTestExecutablePath('gmock_leak_test_') -TEST_WITH_EXPECT_CALL = [PROGRAM_PATH, '--gtest_filter=*ExpectCall*'] -TEST_WITH_ON_CALL = [PROGRAM_PATH, '--gtest_filter=*OnCall*'] -TEST_MULTIPLE_LEAKS = [PROGRAM_PATH, '--gtest_filter=*MultipleLeaked*'] +PROGRAM_PATH = gmock_test_utils.GetTestExecutablePath("gmock_leak_test_") +TEST_WITH_EXPECT_CALL = [PROGRAM_PATH, "--gtest_filter=*ExpectCall*"] +TEST_WITH_ON_CALL = [PROGRAM_PATH, "--gtest_filter=*OnCall*"] +TEST_MULTIPLE_LEAKS = [PROGRAM_PATH, "--gtest_filter=*MultipleLeaked*"] environ = gmock_test_utils.environ SetEnvVar = gmock_test_utils.SetEnvVar @@ -51,63 +51,61 @@ class GMockLeakTest(gmock_test_utils.TestCase): - def testCatchesLeakedMockByDefault(self): - self.assertNotEqual( - 0, - gmock_test_utils.Subprocess( - TEST_WITH_EXPECT_CALL, env=environ - ).exit_code, - ) - self.assertNotEqual( - 0, gmock_test_utils.Subprocess(TEST_WITH_ON_CALL, env=environ).exit_code - ) + def testCatchesLeakedMockByDefault(self): + self.assertNotEqual( + 0, + gmock_test_utils.Subprocess(TEST_WITH_EXPECT_CALL, env=environ).exit_code, + ) + self.assertNotEqual( + 0, gmock_test_utils.Subprocess(TEST_WITH_ON_CALL, env=environ).exit_code + ) - def testDoesNotCatchLeakedMockWhenDisabled(self): - self.assertEqual( - 0, - gmock_test_utils.Subprocess( - TEST_WITH_EXPECT_CALL + ['--gmock_catch_leaked_mocks=0'], - env=environ, - ).exit_code, - ) - self.assertEqual( - 0, - gmock_test_utils.Subprocess( - TEST_WITH_ON_CALL + ['--gmock_catch_leaked_mocks=0'], env=environ - ).exit_code, - ) + def testDoesNotCatchLeakedMockWhenDisabled(self): + self.assertEqual( + 0, + gmock_test_utils.Subprocess( + TEST_WITH_EXPECT_CALL + ["--gmock_catch_leaked_mocks=0"], + env=environ, + ).exit_code, + ) + self.assertEqual( + 0, + gmock_test_utils.Subprocess( + TEST_WITH_ON_CALL + ["--gmock_catch_leaked_mocks=0"], env=environ + ).exit_code, + ) - def testCatchesLeakedMockWhenEnabled(self): - self.assertNotEqual( - 0, - gmock_test_utils.Subprocess( - TEST_WITH_EXPECT_CALL + ['--gmock_catch_leaked_mocks'], env=environ - ).exit_code, - ) - self.assertNotEqual( - 0, - gmock_test_utils.Subprocess( - TEST_WITH_ON_CALL + ['--gmock_catch_leaked_mocks'], env=environ - ).exit_code, - ) + def testCatchesLeakedMockWhenEnabled(self): + self.assertNotEqual( + 0, + gmock_test_utils.Subprocess( + TEST_WITH_EXPECT_CALL + ["--gmock_catch_leaked_mocks"], env=environ + ).exit_code, + ) + self.assertNotEqual( + 0, + gmock_test_utils.Subprocess( + TEST_WITH_ON_CALL + ["--gmock_catch_leaked_mocks"], env=environ + ).exit_code, + ) - def testCatchesLeakedMockWhenEnabledWithExplictFlagValue(self): - self.assertNotEqual( - 0, - gmock_test_utils.Subprocess( - TEST_WITH_EXPECT_CALL + ['--gmock_catch_leaked_mocks=1'], - env=environ, - ).exit_code, - ) + def testCatchesLeakedMockWhenEnabledWithExplictFlagValue(self): + self.assertNotEqual( + 0, + gmock_test_utils.Subprocess( + TEST_WITH_EXPECT_CALL + ["--gmock_catch_leaked_mocks=1"], + env=environ, + ).exit_code, + ) - def testCatchesMultipleLeakedMocks(self): - self.assertNotEqual( - 0, - gmock_test_utils.Subprocess( - TEST_MULTIPLE_LEAKS + ['--gmock_catch_leaked_mocks'], env=environ - ).exit_code, - ) + def testCatchesMultipleLeakedMocks(self): + self.assertNotEqual( + 0, + gmock_test_utils.Subprocess( + TEST_MULTIPLE_LEAKS + ["--gmock_catch_leaked_mocks"], env=environ + ).exit_code, + ) -if __name__ == '__main__': - gmock_test_utils.Main() +if __name__ == "__main__": + gmock_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googlemock/test/gmock_output_test.py b/smallthinker/powerinfer/libaz/external/googletest/googlemock/test/gmock_output_test.py index 7c24d683..03d4b219 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googlemock/test/gmock_output_test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googlemock/test/gmock_output_test.py @@ -47,144 +47,144 @@ # The flag for generating the golden file -GENGOLDEN_FLAG = '--gengolden' +GENGOLDEN_FLAG = "--gengolden" -PROGRAM_PATH = gmock_test_utils.GetTestExecutablePath('gmock_output_test_') -COMMAND = [PROGRAM_PATH, '--gtest_stack_trace_depth=0', '--gtest_print_time=0'] -GOLDEN_NAME = 'gmock_output_test_golden.txt' +PROGRAM_PATH = gmock_test_utils.GetTestExecutablePath("gmock_output_test_") +COMMAND = [PROGRAM_PATH, "--gtest_stack_trace_depth=0", "--gtest_print_time=0"] +GOLDEN_NAME = "gmock_output_test_golden.txt" GOLDEN_PATH = os.path.join(gmock_test_utils.GetSourceDir(), GOLDEN_NAME) def ToUnixLineEnding(s): - """Changes all Windows/Mac line endings in s to UNIX line endings.""" + """Changes all Windows/Mac line endings in s to UNIX line endings.""" - return s.replace('\r\n', '\n').replace('\r', '\n') + return s.replace("\r\n", "\n").replace("\r", "\n") def RemoveReportHeaderAndFooter(output): - """Removes Google Test result report's header and footer from the output.""" + """Removes Google Test result report's header and footer from the output.""" - output = re.sub(r'.*gtest_main.*\n', '', output) - output = re.sub(r'\[.*\d+ tests.*\n', '', output) - output = re.sub(r'\[.* test environment .*\n', '', output) - output = re.sub(r'\[=+\] \d+ tests .* ran.*', '', output) - output = re.sub(r'.* FAILED TESTS\n', '', output) - return output + output = re.sub(r".*gtest_main.*\n", "", output) + output = re.sub(r"\[.*\d+ tests.*\n", "", output) + output = re.sub(r"\[.* test environment .*\n", "", output) + output = re.sub(r"\[=+\] \d+ tests .* ran.*", "", output) + output = re.sub(r".* FAILED TESTS\n", "", output) + return output def RemoveLocations(output): - """Removes all file location info from a Google Test program's output. + """Removes all file location info from a Google Test program's output. - Args: - output: the output of a Google Test program. + Args: + output: the output of a Google Test program. - Returns: - output with all file location info (in the form of - 'DIRECTORY/FILE_NAME:LINE_NUMBER: 'or - 'DIRECTORY\\FILE_NAME(LINE_NUMBER): ') replaced by - 'FILE:#: '. - """ + Returns: + output with all file location info (in the form of + 'DIRECTORY/FILE_NAME:LINE_NUMBER: 'or + 'DIRECTORY\\FILE_NAME(LINE_NUMBER): ') replaced by + 'FILE:#: '. + """ - return re.sub(r'.*[/\\](.+)(\:\d+|\(\d+\))\:', 'FILE:#:', output) + return re.sub(r".*[/\\](.+)(\:\d+|\(\d+\))\:", "FILE:#:", output) def NormalizeErrorMarker(output): - """Normalizes the error marker, which is different on Windows vs on Linux.""" + """Normalizes the error marker, which is different on Windows vs on Linux.""" - return re.sub(r' error: ', ' Failure\n', output) + return re.sub(r" error: ", " Failure\n", output) def RemoveMemoryAddresses(output): - """Removes memory addresses from the test output.""" + """Removes memory addresses from the test output.""" - return re.sub(r'@\w+', '@0x#', output) + return re.sub(r"@\w+", "@0x#", output) def RemoveTestNamesOfLeakedMocks(output): - """Removes the test names of leaked mock objects from the test output.""" + """Removes the test names of leaked mock objects from the test output.""" - return re.sub(r'\(used in test .+\) ', '', output) + return re.sub(r"\(used in test .+\) ", "", output) def GetLeakyTests(output): - """Returns a list of test names that leak mock objects.""" + """Returns a list of test names that leak mock objects.""" - # findall() returns a list of all matches of the regex in output. - # For example, if '(used in test FooTest.Bar)' is in output, the - # list will contain 'FooTest.Bar'. - return re.findall(r'\(used in test (.+)\)', output) + # findall() returns a list of all matches of the regex in output. + # For example, if '(used in test FooTest.Bar)' is in output, the + # list will contain 'FooTest.Bar'. + return re.findall(r"\(used in test (.+)\)", output) def GetNormalizedOutputAndLeakyTests(output): - """Normalizes the output of gmock_output_test_. + """Normalizes the output of gmock_output_test_. - Args: - output: The test output. + Args: + output: The test output. - Returns: - A tuple (the normalized test output, the list of test names that have - leaked mocks). - """ + Returns: + A tuple (the normalized test output, the list of test names that have + leaked mocks). + """ - output = ToUnixLineEnding(output) - output = RemoveReportHeaderAndFooter(output) - output = NormalizeErrorMarker(output) - output = RemoveLocations(output) - output = RemoveMemoryAddresses(output) - return (RemoveTestNamesOfLeakedMocks(output), GetLeakyTests(output)) + output = ToUnixLineEnding(output) + output = RemoveReportHeaderAndFooter(output) + output = NormalizeErrorMarker(output) + output = RemoveLocations(output) + output = RemoveMemoryAddresses(output) + return (RemoveTestNamesOfLeakedMocks(output), GetLeakyTests(output)) def GetShellCommandOutput(cmd): - """Runs a command in a sub-process, and returns its STDOUT in a string.""" + """Runs a command in a sub-process, and returns its STDOUT in a string.""" - return gmock_test_utils.Subprocess(cmd, capture_stderr=False).output + return gmock_test_utils.Subprocess(cmd, capture_stderr=False).output def GetNormalizedCommandOutputAndLeakyTests(cmd): - """Runs a command and returns its normalized output and a list of leaky tests. + """Runs a command and returns its normalized output and a list of leaky tests. - Args: - cmd: the shell command. - """ + Args: + cmd: the shell command. + """ - # Disables exception pop-ups on Windows. - os.environ['GTEST_CATCH_EXCEPTIONS'] = '1' - return GetNormalizedOutputAndLeakyTests(GetShellCommandOutput(cmd)) + # Disables exception pop-ups on Windows. + os.environ["GTEST_CATCH_EXCEPTIONS"] = "1" + return GetNormalizedOutputAndLeakyTests(GetShellCommandOutput(cmd)) class GMockOutputTest(gmock_test_utils.TestCase): - def testOutput(self): - (output, leaky_tests) = GetNormalizedCommandOutputAndLeakyTests(COMMAND) - golden_file = open(GOLDEN_PATH, 'rb') - golden = golden_file.read().decode('utf-8') - golden_file.close() - # On Windows the repository might have been checked out with \r\n line - # endings, so normalize it here. - golden = ToUnixLineEnding(golden) - - # The normalized output should match the golden file. - self.assertEqual(golden, output) - - # The raw output should contain 2 leaked mock object errors for - # test GMockOutputTest.CatchesLeakedMocks. - self.assertEqual( - [ - 'GMockOutputTest.CatchesLeakedMocks', - 'GMockOutputTest.CatchesLeakedMocks', - ], - leaky_tests, - ) - - -if __name__ == '__main__': - if sys.argv[1:] == [GENGOLDEN_FLAG]: - (output, _) = GetNormalizedCommandOutputAndLeakyTests(COMMAND) - golden_file = open(GOLDEN_PATH, 'wb') - golden_file.write(output) - golden_file.close() - # Suppress the error "googletest was imported but a call to its main() - # was never detected." - os._exit(0) - else: - gmock_test_utils.Main() + def testOutput(self): + (output, leaky_tests) = GetNormalizedCommandOutputAndLeakyTests(COMMAND) + golden_file = open(GOLDEN_PATH, "rb") + golden = golden_file.read().decode("utf-8") + golden_file.close() + # On Windows the repository might have been checked out with \r\n line + # endings, so normalize it here. + golden = ToUnixLineEnding(golden) + + # The normalized output should match the golden file. + self.assertEqual(golden, output) + + # The raw output should contain 2 leaked mock object errors for + # test GMockOutputTest.CatchesLeakedMocks. + self.assertEqual( + [ + "GMockOutputTest.CatchesLeakedMocks", + "GMockOutputTest.CatchesLeakedMocks", + ], + leaky_tests, + ) + + +if __name__ == "__main__": + if sys.argv[1:] == [GENGOLDEN_FLAG]: + (output, _) = GetNormalizedCommandOutputAndLeakyTests(COMMAND) + golden_file = open(GOLDEN_PATH, "wb") + golden_file.write(output) + golden_file.close() + # Suppress the error "googletest was imported but a call to its main() + # was never detected." + os._exit(0) + else: + gmock_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googlemock/test/gmock_test_utils.py b/smallthinker/powerinfer/libaz/external/googletest/googlemock/test/gmock_test_utils.py index edad1f75..01a2f12a 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googlemock/test/gmock_test_utils.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googlemock/test/gmock_test_utils.py @@ -36,45 +36,45 @@ def GetSourceDir(): - """Returns the absolute path of the directory where the .py files are.""" + """Returns the absolute path of the directory where the .py files are.""" - return gtest_test_utils.GetSourceDir() + return gtest_test_utils.GetSourceDir() def GetTestExecutablePath(executable_name): - """Returns the absolute path of the test binary given its name. + """Returns the absolute path of the test binary given its name. - The function will print a message and abort the program if the resulting file - doesn't exist. + The function will print a message and abort the program if the resulting file + doesn't exist. - Args: - executable_name: name of the test binary that the test script runs. + Args: + executable_name: name of the test binary that the test script runs. - Returns: - The absolute path of the test binary. - """ + Returns: + The absolute path of the test binary. + """ - return gtest_test_utils.GetTestExecutablePath(executable_name) + return gtest_test_utils.GetTestExecutablePath(executable_name) def GetExitStatus(exit_code): - """Returns the argument to exit(), or -1 if exit() wasn't called. - - Args: - exit_code: the result value of os.system(command). - """ - - if os.name == 'nt': - # On Windows, os.WEXITSTATUS() doesn't work and os.system() returns - # the argument to exit() directly. - return exit_code - else: - # On Unix, os.WEXITSTATUS() must be used to extract the exit status - # from the result of os.system(). - if os.WIFEXITED(exit_code): - return os.WEXITSTATUS(exit_code) + """Returns the argument to exit(), or -1 if exit() wasn't called. + + Args: + exit_code: the result value of os.system(command). + """ + + if os.name == "nt": + # On Windows, os.WEXITSTATUS() doesn't work and os.system() returns + # the argument to exit() directly. + return exit_code else: - return -1 + # On Unix, os.WEXITSTATUS() must be used to extract the exit status + # from the result of os.system(). + if os.WIFEXITED(exit_code): + return os.WEXITSTATUS(exit_code) + else: + return -1 # Exposes utilities from gtest_test_utils. @@ -86,6 +86,6 @@ def GetExitStatus(exit_code): def Main(): - """Runs the unit test.""" + """Runs the unit test.""" - gtest_test_utils.Main() + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-break-on-failure-unittest.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-break-on-failure-unittest.py index e314b5cc..14ebb61a 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-break-on-failure-unittest.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-break-on-failure-unittest.py @@ -43,23 +43,23 @@ # Constants. -IS_WINDOWS = os.name == 'nt' +IS_WINDOWS = os.name == "nt" # The environment variable for enabling/disabling the break-on-failure mode. -BREAK_ON_FAILURE_ENV_VAR = 'GTEST_BREAK_ON_FAILURE' +BREAK_ON_FAILURE_ENV_VAR = "GTEST_BREAK_ON_FAILURE" # The command line flag for enabling/disabling the break-on-failure mode. -BREAK_ON_FAILURE_FLAG = 'gtest_break_on_failure' +BREAK_ON_FAILURE_FLAG = "gtest_break_on_failure" # The environment variable for enabling/disabling the throw-on-failure mode. -THROW_ON_FAILURE_ENV_VAR = 'GTEST_THROW_ON_FAILURE' +THROW_ON_FAILURE_ENV_VAR = "GTEST_THROW_ON_FAILURE" # The environment variable for enabling/disabling the catch-exceptions mode. -CATCH_EXCEPTIONS_ENV_VAR = 'GTEST_CATCH_EXCEPTIONS' +CATCH_EXCEPTIONS_ENV_VAR = "GTEST_CATCH_EXCEPTIONS" # Path to the googletest-break-on-failure-unittest_ program. EXE_PATH = gtest_test_utils.GetTestExecutablePath( - 'googletest-break-on-failure-unittest_' + "googletest-break-on-failure-unittest_" ) @@ -75,121 +75,121 @@ def Run(command): - """Runs a command; returns 1 if it was killed by a signal, or 0 otherwise.""" + """Runs a command; returns 1 if it was killed by a signal, or 0 otherwise.""" - p = gtest_test_utils.Subprocess(command, env=environ) - if p.terminated_by_signal: - return 1 - else: - return 0 + p = gtest_test_utils.Subprocess(command, env=environ) + if p.terminated_by_signal: + return 1 + else: + return 0 # The tests. class GTestBreakOnFailureUnitTest(gtest_test_utils.TestCase): - """Unit test for Google Test's break-on-failure mode. - - Tests using the GTEST_BREAK_ON_FAILURE environment variable or - the --gtest_break_on_failure flag to turn assertion failures into - segmentation faults. - """ - - def RunAndVerify(self, env_var_value, flag_value, expect_seg_fault): - """Runs googletest-break-on-failure-unittest_ and verifies its behavior. - - Runs googletest-break-on-failure-unittest_ and verifies that it does - (or does not) have a seg-fault. - - Args: - env_var_value: value of the GTEST_BREAK_ON_FAILURE environment - variable; None if the variable should be unset. - flag_value: value of the --gtest_break_on_failure flag; None if the - flag should not be present. - expect_seg_fault: 1 if the program is expected to generate a seg-fault; 0 - otherwise. + """Unit test for Google Test's break-on-failure mode. + + Tests using the GTEST_BREAK_ON_FAILURE environment variable or + the --gtest_break_on_failure flag to turn assertion failures into + segmentation faults. """ - SetEnvVar(BREAK_ON_FAILURE_ENV_VAR, env_var_value) + def RunAndVerify(self, env_var_value, flag_value, expect_seg_fault): + """Runs googletest-break-on-failure-unittest_ and verifies its behavior. - if env_var_value is None: - env_var_value_msg = ' is not set' - else: - env_var_value_msg = '=' + env_var_value + Runs googletest-break-on-failure-unittest_ and verifies that it does + (or does not) have a seg-fault. - if flag_value is None: - flag = '' - elif flag_value == '0': - flag = '--%s=0' % BREAK_ON_FAILURE_FLAG - else: - flag = '--%s' % BREAK_ON_FAILURE_FLAG + Args: + env_var_value: value of the GTEST_BREAK_ON_FAILURE environment + variable; None if the variable should be unset. + flag_value: value of the --gtest_break_on_failure flag; None if the + flag should not be present. + expect_seg_fault: 1 if the program is expected to generate a seg-fault; 0 + otherwise. + """ - command = [EXE_PATH] - if flag: - command.append(flag) + SetEnvVar(BREAK_ON_FAILURE_ENV_VAR, env_var_value) - if expect_seg_fault: - should_or_not = 'should' - else: - should_or_not = 'should not' + if env_var_value is None: + env_var_value_msg = " is not set" + else: + env_var_value_msg = "=" + env_var_value + + if flag_value is None: + flag = "" + elif flag_value == "0": + flag = "--%s=0" % BREAK_ON_FAILURE_FLAG + else: + flag = "--%s" % BREAK_ON_FAILURE_FLAG + + command = [EXE_PATH] + if flag: + command.append(flag) + + if expect_seg_fault: + should_or_not = "should" + else: + should_or_not = "should not" - has_seg_fault = Run(command) + has_seg_fault = Run(command) - SetEnvVar(BREAK_ON_FAILURE_ENV_VAR, None) + SetEnvVar(BREAK_ON_FAILURE_ENV_VAR, None) - msg = 'when %s%s, an assertion failure in "%s" %s cause a seg-fault.' % ( - BREAK_ON_FAILURE_ENV_VAR, - env_var_value_msg, - ' '.join(command), - should_or_not, - ) - self.assertTrue(has_seg_fault == expect_seg_fault, msg) + msg = 'when %s%s, an assertion failure in "%s" %s cause a seg-fault.' % ( + BREAK_ON_FAILURE_ENV_VAR, + env_var_value_msg, + " ".join(command), + should_or_not, + ) + self.assertTrue(has_seg_fault == expect_seg_fault, msg) - def testDefaultBehavior(self): - """Tests the behavior of the default mode.""" + def testDefaultBehavior(self): + """Tests the behavior of the default mode.""" - self.RunAndVerify(env_var_value=None, flag_value=None, expect_seg_fault=0) + self.RunAndVerify(env_var_value=None, flag_value=None, expect_seg_fault=0) - def testEnvVar(self): - """Tests using the GTEST_BREAK_ON_FAILURE environment variable.""" + def testEnvVar(self): + """Tests using the GTEST_BREAK_ON_FAILURE environment variable.""" - self.RunAndVerify(env_var_value='0', flag_value=None, expect_seg_fault=0) - self.RunAndVerify(env_var_value='1', flag_value=None, expect_seg_fault=1) + self.RunAndVerify(env_var_value="0", flag_value=None, expect_seg_fault=0) + self.RunAndVerify(env_var_value="1", flag_value=None, expect_seg_fault=1) - def testFlag(self): - """Tests using the --gtest_break_on_failure flag.""" + def testFlag(self): + """Tests using the --gtest_break_on_failure flag.""" - self.RunAndVerify(env_var_value=None, flag_value='0', expect_seg_fault=0) - self.RunAndVerify(env_var_value=None, flag_value='1', expect_seg_fault=1) + self.RunAndVerify(env_var_value=None, flag_value="0", expect_seg_fault=0) + self.RunAndVerify(env_var_value=None, flag_value="1", expect_seg_fault=1) - def testFlagOverridesEnvVar(self): - """Tests that the flag overrides the environment variable.""" + def testFlagOverridesEnvVar(self): + """Tests that the flag overrides the environment variable.""" - self.RunAndVerify(env_var_value='0', flag_value='0', expect_seg_fault=0) - self.RunAndVerify(env_var_value='0', flag_value='1', expect_seg_fault=1) - self.RunAndVerify(env_var_value='1', flag_value='0', expect_seg_fault=0) - self.RunAndVerify(env_var_value='1', flag_value='1', expect_seg_fault=1) + self.RunAndVerify(env_var_value="0", flag_value="0", expect_seg_fault=0) + self.RunAndVerify(env_var_value="0", flag_value="1", expect_seg_fault=1) + self.RunAndVerify(env_var_value="1", flag_value="0", expect_seg_fault=0) + self.RunAndVerify(env_var_value="1", flag_value="1", expect_seg_fault=1) - def testBreakOnFailureOverridesThrowOnFailure(self): - """Tests that gtest_break_on_failure overrides gtest_throw_on_failure.""" + def testBreakOnFailureOverridesThrowOnFailure(self): + """Tests that gtest_break_on_failure overrides gtest_throw_on_failure.""" - SetEnvVar(THROW_ON_FAILURE_ENV_VAR, '1') - try: - self.RunAndVerify(env_var_value=None, flag_value='1', expect_seg_fault=1) - finally: - SetEnvVar(THROW_ON_FAILURE_ENV_VAR, None) + SetEnvVar(THROW_ON_FAILURE_ENV_VAR, "1") + try: + self.RunAndVerify(env_var_value=None, flag_value="1", expect_seg_fault=1) + finally: + SetEnvVar(THROW_ON_FAILURE_ENV_VAR, None) - if IS_WINDOWS: + if IS_WINDOWS: - def testCatchExceptionsDoesNotInterfere(self): - """Tests that gtest_catch_exceptions doesn't interfere.""" + def testCatchExceptionsDoesNotInterfere(self): + """Tests that gtest_catch_exceptions doesn't interfere.""" - SetEnvVar(CATCH_EXCEPTIONS_ENV_VAR, '1') - try: - self.RunAndVerify(env_var_value='1', flag_value='1', expect_seg_fault=1) - finally: - SetEnvVar(CATCH_EXCEPTIONS_ENV_VAR, None) + SetEnvVar(CATCH_EXCEPTIONS_ENV_VAR, "1") + try: + self.RunAndVerify(env_var_value="1", flag_value="1", expect_seg_fault=1) + finally: + SetEnvVar(CATCH_EXCEPTIONS_ENV_VAR, None) -if __name__ == '__main__': - gtest_test_utils.Main() +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-catch-exceptions-test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-catch-exceptions-test.py index 180e18de..ac25a289 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-catch-exceptions-test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-catch-exceptions-test.py @@ -38,21 +38,21 @@ from googletest.test import gtest_test_utils # Constants. -FLAG_PREFIX = '--gtest_' -LIST_TESTS_FLAG = FLAG_PREFIX + 'list_tests' -NO_CATCH_EXCEPTIONS_FLAG = FLAG_PREFIX + 'catch_exceptions=0' -FILTER_FLAG = FLAG_PREFIX + 'filter' +FLAG_PREFIX = "--gtest_" +LIST_TESTS_FLAG = FLAG_PREFIX + "list_tests" +NO_CATCH_EXCEPTIONS_FLAG = FLAG_PREFIX + "catch_exceptions=0" +FILTER_FLAG = FLAG_PREFIX + "filter" # Path to the googletest-catch-exceptions-ex-test_ binary, compiled with # exceptions enabled. EX_EXE_PATH = gtest_test_utils.GetTestExecutablePath( - 'googletest-catch-exceptions-ex-test_' + "googletest-catch-exceptions-ex-test_" ) # Path to the googletest-catch-exceptions-test_ binary, compiled with # exceptions disabled. EXE_PATH = gtest_test_utils.GetTestExecutablePath( - 'googletest-catch-exceptions-no-ex-test_' + "googletest-catch-exceptions-no-ex-test_" ) environ = gtest_test_utils.environ @@ -65,251 +65,234 @@ # the file. SetEnvVar(gtest_test_utils.PREMATURE_EXIT_FILE_ENV_VAR, None) -TEST_LIST = gtest_test_utils.Subprocess( - [EXE_PATH, LIST_TESTS_FLAG], env=environ -).output +TEST_LIST = gtest_test_utils.Subprocess([EXE_PATH, LIST_TESTS_FLAG], env=environ).output -SUPPORTS_SEH_EXCEPTIONS = 'ThrowsSehException' in TEST_LIST +SUPPORTS_SEH_EXCEPTIONS = "ThrowsSehException" in TEST_LIST if SUPPORTS_SEH_EXCEPTIONS: - BINARY_OUTPUT = gtest_test_utils.Subprocess([EXE_PATH], env=environ).output + BINARY_OUTPUT = gtest_test_utils.Subprocess([EXE_PATH], env=environ).output -EX_BINARY_OUTPUT = gtest_test_utils.Subprocess( - [EX_EXE_PATH], env=environ -).output +EX_BINARY_OUTPUT = gtest_test_utils.Subprocess([EX_EXE_PATH], env=environ).output # The tests. if SUPPORTS_SEH_EXCEPTIONS: - class CatchSehExceptionsTest(gtest_test_utils.TestCase): - """Tests exception-catching behavior.""" - - def TestSehExceptions(self, test_output): - self.assertIn( - ( - 'SEH exception with code 0x2a thrown ' - "in the test fixture's constructor" - ), - test_output, - ) - self.assertIn( - ( - 'SEH exception with code 0x2a thrown ' - "in the test fixture's destructor" - ), - test_output, - ) - self.assertIn( - 'SEH exception with code 0x2a thrown in SetUpTestSuite()', test_output - ) - self.assertIn( - 'SEH exception with code 0x2a thrown in TearDownTestSuite()', - test_output, - ) - self.assertIn( - 'SEH exception with code 0x2a thrown in SetUp()', test_output - ) - self.assertIn( - 'SEH exception with code 0x2a thrown in TearDown()', test_output - ) - self.assertIn( - 'SEH exception with code 0x2a thrown in the test body', test_output - ) - - def testCatchesSehExceptionsWithCxxExceptionsEnabled(self): - self.TestSehExceptions(EX_BINARY_OUTPUT) - - def testCatchesSehExceptionsWithCxxExceptionsDisabled(self): - self.TestSehExceptions(BINARY_OUTPUT) + class CatchSehExceptionsTest(gtest_test_utils.TestCase): + """Tests exception-catching behavior.""" + + def TestSehExceptions(self, test_output): + self.assertIn( + ( + "SEH exception with code 0x2a thrown " + "in the test fixture's constructor" + ), + test_output, + ) + self.assertIn( + ( + "SEH exception with code 0x2a thrown " + "in the test fixture's destructor" + ), + test_output, + ) + self.assertIn( + "SEH exception with code 0x2a thrown in SetUpTestSuite()", test_output + ) + self.assertIn( + "SEH exception with code 0x2a thrown in TearDownTestSuite()", + test_output, + ) + self.assertIn("SEH exception with code 0x2a thrown in SetUp()", test_output) + self.assertIn( + "SEH exception with code 0x2a thrown in TearDown()", test_output + ) + self.assertIn( + "SEH exception with code 0x2a thrown in the test body", test_output + ) + + def testCatchesSehExceptionsWithCxxExceptionsEnabled(self): + self.TestSehExceptions(EX_BINARY_OUTPUT) + + def testCatchesSehExceptionsWithCxxExceptionsDisabled(self): + self.TestSehExceptions(BINARY_OUTPUT) class CatchCxxExceptionsTest(gtest_test_utils.TestCase): - """Tests C++ exception-catching behavior. - - Tests in this test case verify that: - * C++ exceptions are caught and logged as C++ (not SEH) exceptions - * Exception thrown affect the remainder of the test work flow in the - expected manner. - """ - - def testCatchesCxxExceptionsInFixtureConstructor(self): - self.assertTrue( - 'C++ exception with description ' - '"Standard C++ exception" thrown ' - "in the test fixture's constructor" - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertTrue( - 'unexpected' not in EX_BINARY_OUTPUT, - ( - 'This failure belongs in this test only if ' - '"CxxExceptionInConstructorTest" (no quotes) ' - 'appears on the same line as words "called unexpectedly"' - ), - ) - - if ( - 'CxxExceptionInDestructorTest.ThrowsExceptionInDestructor' - in EX_BINARY_OUTPUT - ): - - def testCatchesCxxExceptionsInFixtureDestructor(self): - self.assertTrue( - 'C++ exception with description ' - '"Standard C++ exception" thrown ' - "in the test fixture's destructor" - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertTrue( - 'CxxExceptionInDestructorTest::TearDownTestSuite() ' - 'called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - - def testCatchesCxxExceptionsInSetUpTestCase(self): - self.assertTrue( - 'C++ exception with description "Standard C++ exception"' - ' thrown in SetUpTestSuite()' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertTrue( - 'CxxExceptionInConstructorTest::TearDownTestSuite() called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertFalse( - 'CxxExceptionInSetUpTestSuiteTest constructor called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertFalse( - 'CxxExceptionInSetUpTestSuiteTest destructor called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertFalse( - 'CxxExceptionInSetUpTestSuiteTest::SetUp() called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertFalse( - 'CxxExceptionInSetUpTestSuiteTest::TearDown() called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertFalse( - 'CxxExceptionInSetUpTestSuiteTest test body called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - - def testCatchesCxxExceptionsInTearDownTestCase(self): - self.assertTrue( - 'C++ exception with description "Standard C++ exception"' - ' thrown in TearDownTestSuite()' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - - def testCatchesCxxExceptionsInSetUp(self): - self.assertTrue( - 'C++ exception with description "Standard C++ exception"' - ' thrown in SetUp()' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertTrue( - 'CxxExceptionInSetUpTest::TearDownTestSuite() called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertTrue( - 'CxxExceptionInSetUpTest destructor called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertTrue( - 'CxxExceptionInSetUpTest::TearDown() called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertTrue( - 'unexpected' not in EX_BINARY_OUTPUT, - ( - 'This failure belongs in this test only if ' - '"CxxExceptionInSetUpTest" (no quotes) ' - 'appears on the same line as words "called unexpectedly"' - ), - ) - - def testCatchesCxxExceptionsInTearDown(self): - self.assertTrue( - 'C++ exception with description "Standard C++ exception"' - ' thrown in TearDown()' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertTrue( - 'CxxExceptionInTearDownTest::TearDownTestSuite() called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertTrue( - 'CxxExceptionInTearDownTest destructor called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - - def testCatchesCxxExceptionsInTestBody(self): - self.assertTrue( - 'C++ exception with description "Standard C++ exception"' - ' thrown in the test body' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertTrue( - 'CxxExceptionInTestBodyTest::TearDownTestSuite() called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertTrue( - 'CxxExceptionInTestBodyTest destructor called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - self.assertTrue( - 'CxxExceptionInTestBodyTest::TearDown() called as expected.' - in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - - def testCatchesNonStdCxxExceptions(self): - self.assertTrue( - 'Unknown C++ exception thrown in the test body' in EX_BINARY_OUTPUT, - EX_BINARY_OUTPUT, - ) - - def testUnhandledCxxExceptionsAbortTheProgram(self): - # Filters out SEH exception tests on Windows. Unhandled SEH exceptions - # cause tests to show pop-up windows there. - filter_out_seh_tests_flag = FILTER_FLAG + '=-*Seh*' - # By default, Google Test doesn't catch the exceptions. - uncaught_exceptions_ex_binary_output = gtest_test_utils.Subprocess( - [EX_EXE_PATH, NO_CATCH_EXCEPTIONS_FLAG, filter_out_seh_tests_flag], - env=environ, - ).output - - self.assertIn( - 'Unhandled C++ exception terminating the program', - uncaught_exceptions_ex_binary_output, - ) - self.assertNotIn('unexpected', uncaught_exceptions_ex_binary_output) - - -if __name__ == '__main__': - gtest_test_utils.Main() + """Tests C++ exception-catching behavior. + + Tests in this test case verify that: + * C++ exceptions are caught and logged as C++ (not SEH) exceptions + * Exception thrown affect the remainder of the test work flow in the + expected manner. + """ + + def testCatchesCxxExceptionsInFixtureConstructor(self): + self.assertTrue( + "C++ exception with description " + '"Standard C++ exception" thrown ' + "in the test fixture's constructor" in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertTrue( + "unexpected" not in EX_BINARY_OUTPUT, + ( + "This failure belongs in this test only if " + '"CxxExceptionInConstructorTest" (no quotes) ' + 'appears on the same line as words "called unexpectedly"' + ), + ) + + if "CxxExceptionInDestructorTest.ThrowsExceptionInDestructor" in EX_BINARY_OUTPUT: + + def testCatchesCxxExceptionsInFixtureDestructor(self): + self.assertTrue( + "C++ exception with description " + '"Standard C++ exception" thrown ' + "in the test fixture's destructor" in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertTrue( + "CxxExceptionInDestructorTest::TearDownTestSuite() " + "called as expected." in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + + def testCatchesCxxExceptionsInSetUpTestCase(self): + self.assertTrue( + 'C++ exception with description "Standard C++ exception"' + " thrown in SetUpTestSuite()" in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertTrue( + "CxxExceptionInConstructorTest::TearDownTestSuite() called as expected." + in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertFalse( + "CxxExceptionInSetUpTestSuiteTest constructor called as expected." + in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertFalse( + "CxxExceptionInSetUpTestSuiteTest destructor called as expected." + in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertFalse( + "CxxExceptionInSetUpTestSuiteTest::SetUp() called as expected." + in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertFalse( + "CxxExceptionInSetUpTestSuiteTest::TearDown() called as expected." + in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertFalse( + "CxxExceptionInSetUpTestSuiteTest test body called as expected." + in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + + def testCatchesCxxExceptionsInTearDownTestCase(self): + self.assertTrue( + 'C++ exception with description "Standard C++ exception"' + " thrown in TearDownTestSuite()" in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + + def testCatchesCxxExceptionsInSetUp(self): + self.assertTrue( + 'C++ exception with description "Standard C++ exception"' + " thrown in SetUp()" in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertTrue( + "CxxExceptionInSetUpTest::TearDownTestSuite() called as expected." + in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertTrue( + "CxxExceptionInSetUpTest destructor called as expected." + in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertTrue( + "CxxExceptionInSetUpTest::TearDown() called as expected." + in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertTrue( + "unexpected" not in EX_BINARY_OUTPUT, + ( + "This failure belongs in this test only if " + '"CxxExceptionInSetUpTest" (no quotes) ' + 'appears on the same line as words "called unexpectedly"' + ), + ) + + def testCatchesCxxExceptionsInTearDown(self): + self.assertTrue( + 'C++ exception with description "Standard C++ exception"' + " thrown in TearDown()" in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertTrue( + "CxxExceptionInTearDownTest::TearDownTestSuite() called as expected." + in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertTrue( + "CxxExceptionInTearDownTest destructor called as expected." + in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + + def testCatchesCxxExceptionsInTestBody(self): + self.assertTrue( + 'C++ exception with description "Standard C++ exception"' + " thrown in the test body" in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertTrue( + "CxxExceptionInTestBodyTest::TearDownTestSuite() called as expected." + in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertTrue( + "CxxExceptionInTestBodyTest destructor called as expected." + in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + self.assertTrue( + "CxxExceptionInTestBodyTest::TearDown() called as expected." + in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + + def testCatchesNonStdCxxExceptions(self): + self.assertTrue( + "Unknown C++ exception thrown in the test body" in EX_BINARY_OUTPUT, + EX_BINARY_OUTPUT, + ) + + def testUnhandledCxxExceptionsAbortTheProgram(self): + # Filters out SEH exception tests on Windows. Unhandled SEH exceptions + # cause tests to show pop-up windows there. + filter_out_seh_tests_flag = FILTER_FLAG + "=-*Seh*" + # By default, Google Test doesn't catch the exceptions. + uncaught_exceptions_ex_binary_output = gtest_test_utils.Subprocess( + [EX_EXE_PATH, NO_CATCH_EXCEPTIONS_FLAG, filter_out_seh_tests_flag], + env=environ, + ).output + + self.assertIn( + "Unhandled C++ exception terminating the program", + uncaught_exceptions_ex_binary_output, + ) + self.assertNotIn("unexpected", uncaught_exceptions_ex_binary_output) + + +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-color-test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-color-test.py index 8968cf1f..9aba107d 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-color-test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-color-test.py @@ -34,97 +34,97 @@ import os from googletest.test import gtest_test_utils -IS_WINDOWS = os.name == 'nt' +IS_WINDOWS = os.name == "nt" -COLOR_ENV_VAR = 'GTEST_COLOR' -COLOR_FLAG = 'gtest_color' -COMMAND = gtest_test_utils.GetTestExecutablePath('googletest-color-test_') +COLOR_ENV_VAR = "GTEST_COLOR" +COLOR_FLAG = "gtest_color" +COMMAND = gtest_test_utils.GetTestExecutablePath("googletest-color-test_") def SetEnvVar(env_var, value): - """Sets the env variable to 'value'; unsets it when 'value' is None.""" + """Sets the env variable to 'value'; unsets it when 'value' is None.""" - if value is not None: - os.environ[env_var] = value - elif env_var in os.environ: - del os.environ[env_var] + if value is not None: + os.environ[env_var] = value + elif env_var in os.environ: + del os.environ[env_var] def UsesColor(term, color_env_var, color_flag): - """Runs googletest-color-test_ and returns its exit code.""" + """Runs googletest-color-test_ and returns its exit code.""" - SetEnvVar('TERM', term) - SetEnvVar(COLOR_ENV_VAR, color_env_var) + SetEnvVar("TERM", term) + SetEnvVar(COLOR_ENV_VAR, color_env_var) - if color_flag is None: - args = [] - else: - args = ['--%s=%s' % (COLOR_FLAG, color_flag)] - p = gtest_test_utils.Subprocess([COMMAND] + args) - return not p.exited or p.exit_code + if color_flag is None: + args = [] + else: + args = ["--%s=%s" % (COLOR_FLAG, color_flag)] + p = gtest_test_utils.Subprocess([COMMAND] + args) + return not p.exited or p.exit_code class GTestColorTest(gtest_test_utils.TestCase): - def testNoEnvVarNoFlag(self): - """Tests the case when there's neither GTEST_COLOR nor --gtest_color.""" - - if not IS_WINDOWS: - self.assertTrue(not UsesColor('dumb', None, None)) - self.assertTrue(not UsesColor('emacs', None, None)) - self.assertTrue(not UsesColor('xterm-mono', None, None)) - self.assertTrue(not UsesColor('unknown', None, None)) - self.assertTrue(not UsesColor(None, None, None)) - self.assertTrue(UsesColor('linux', None, None)) - self.assertTrue(UsesColor('cygwin', None, None)) - self.assertTrue(UsesColor('xterm', None, None)) - self.assertTrue(UsesColor('xterm-color', None, None)) - self.assertTrue(UsesColor('xterm-kitty', None, None)) - self.assertTrue(UsesColor('alacritty', None, None)) - self.assertTrue(UsesColor('xterm-256color', None, None)) - - def testFlagOnly(self): - """Tests the case when there's --gtest_color but not GTEST_COLOR.""" - - self.assertTrue(not UsesColor('dumb', None, 'no')) - self.assertTrue(not UsesColor('xterm-color', None, 'no')) - if not IS_WINDOWS: - self.assertTrue(not UsesColor('emacs', None, 'auto')) - self.assertTrue(UsesColor('xterm', None, 'auto')) - self.assertTrue(UsesColor('dumb', None, 'yes')) - self.assertTrue(UsesColor('xterm', None, 'yes')) - - def testEnvVarOnly(self): - """Tests the case when there's GTEST_COLOR but not --gtest_color.""" - - self.assertTrue(not UsesColor('dumb', 'no', None)) - self.assertTrue(not UsesColor('xterm-color', 'no', None)) - if not IS_WINDOWS: - self.assertTrue(not UsesColor('dumb', 'auto', None)) - self.assertTrue(UsesColor('xterm-color', 'auto', None)) - self.assertTrue(UsesColor('dumb', 'yes', None)) - self.assertTrue(UsesColor('xterm-color', 'yes', None)) - - def testEnvVarAndFlag(self): - """Tests the case when there are both GTEST_COLOR and --gtest_color.""" - - self.assertTrue(not UsesColor('xterm-color', 'no', 'no')) - self.assertTrue(UsesColor('dumb', 'no', 'yes')) - self.assertTrue(UsesColor('xterm-color', 'no', 'auto')) - - def testAliasesOfYesAndNo(self): - """Tests using aliases in specifying --gtest_color.""" - - self.assertTrue(UsesColor('dumb', None, 'true')) - self.assertTrue(UsesColor('dumb', None, 'YES')) - self.assertTrue(UsesColor('dumb', None, 'T')) - self.assertTrue(UsesColor('dumb', None, '1')) - - self.assertTrue(not UsesColor('xterm', None, 'f')) - self.assertTrue(not UsesColor('xterm', None, 'false')) - self.assertTrue(not UsesColor('xterm', None, '0')) - self.assertTrue(not UsesColor('xterm', None, 'unknown')) - - -if __name__ == '__main__': - gtest_test_utils.Main() + def testNoEnvVarNoFlag(self): + """Tests the case when there's neither GTEST_COLOR nor --gtest_color.""" + + if not IS_WINDOWS: + self.assertTrue(not UsesColor("dumb", None, None)) + self.assertTrue(not UsesColor("emacs", None, None)) + self.assertTrue(not UsesColor("xterm-mono", None, None)) + self.assertTrue(not UsesColor("unknown", None, None)) + self.assertTrue(not UsesColor(None, None, None)) + self.assertTrue(UsesColor("linux", None, None)) + self.assertTrue(UsesColor("cygwin", None, None)) + self.assertTrue(UsesColor("xterm", None, None)) + self.assertTrue(UsesColor("xterm-color", None, None)) + self.assertTrue(UsesColor("xterm-kitty", None, None)) + self.assertTrue(UsesColor("alacritty", None, None)) + self.assertTrue(UsesColor("xterm-256color", None, None)) + + def testFlagOnly(self): + """Tests the case when there's --gtest_color but not GTEST_COLOR.""" + + self.assertTrue(not UsesColor("dumb", None, "no")) + self.assertTrue(not UsesColor("xterm-color", None, "no")) + if not IS_WINDOWS: + self.assertTrue(not UsesColor("emacs", None, "auto")) + self.assertTrue(UsesColor("xterm", None, "auto")) + self.assertTrue(UsesColor("dumb", None, "yes")) + self.assertTrue(UsesColor("xterm", None, "yes")) + + def testEnvVarOnly(self): + """Tests the case when there's GTEST_COLOR but not --gtest_color.""" + + self.assertTrue(not UsesColor("dumb", "no", None)) + self.assertTrue(not UsesColor("xterm-color", "no", None)) + if not IS_WINDOWS: + self.assertTrue(not UsesColor("dumb", "auto", None)) + self.assertTrue(UsesColor("xterm-color", "auto", None)) + self.assertTrue(UsesColor("dumb", "yes", None)) + self.assertTrue(UsesColor("xterm-color", "yes", None)) + + def testEnvVarAndFlag(self): + """Tests the case when there are both GTEST_COLOR and --gtest_color.""" + + self.assertTrue(not UsesColor("xterm-color", "no", "no")) + self.assertTrue(UsesColor("dumb", "no", "yes")) + self.assertTrue(UsesColor("xterm-color", "no", "auto")) + + def testAliasesOfYesAndNo(self): + """Tests using aliases in specifying --gtest_color.""" + + self.assertTrue(UsesColor("dumb", None, "true")) + self.assertTrue(UsesColor("dumb", None, "YES")) + self.assertTrue(UsesColor("dumb", None, "T")) + self.assertTrue(UsesColor("dumb", None, "1")) + + self.assertTrue(not UsesColor("xterm", None, "f")) + self.assertTrue(not UsesColor("xterm", None, "false")) + self.assertTrue(not UsesColor("xterm", None, "0")) + self.assertTrue(not UsesColor("xterm", None, "unknown")) + + +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-env-var-test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-env-var-test.py index 24d8edbb..ac01bffb 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-env-var-test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-env-var-test.py @@ -35,86 +35,86 @@ from googletest.test import gtest_test_utils -IS_WINDOWS = os.name == 'nt' -IS_LINUX = os.name == 'posix' and os.uname()[0] == 'Linux' +IS_WINDOWS = os.name == "nt" +IS_LINUX = os.name == "posix" and os.uname()[0] == "Linux" -COMMAND = gtest_test_utils.GetTestExecutablePath('googletest-env-var-test_') +COMMAND = gtest_test_utils.GetTestExecutablePath("googletest-env-var-test_") environ = os.environ.copy() def AssertEq(expected, actual): - if expected != actual: - print('Expected: %s' % (expected,)) - print(' Actual: %s' % (actual,)) - raise AssertionError + if expected != actual: + print("Expected: %s" % (expected,)) + print(" Actual: %s" % (actual,)) + raise AssertionError def SetEnvVar(env_var, value): - """Sets the env variable to 'value'; unsets it when 'value' is None.""" + """Sets the env variable to 'value'; unsets it when 'value' is None.""" - if value is not None: - environ[env_var] = value - elif env_var in environ: - del environ[env_var] + if value is not None: + environ[env_var] = value + elif env_var in environ: + del environ[env_var] def GetFlag(flag): - """Runs googletest-env-var-test_ and returns its output.""" + """Runs googletest-env-var-test_ and returns its output.""" - args = [COMMAND] - if flag is not None: - args += [flag] - return gtest_test_utils.Subprocess(args, env=environ).output + args = [COMMAND] + if flag is not None: + args += [flag] + return gtest_test_utils.Subprocess(args, env=environ).output def TestFlag(flag, test_val, default_val): - """Verifies that the given flag is affected by the corresponding env var.""" + """Verifies that the given flag is affected by the corresponding env var.""" - env_var = 'GTEST_' + flag.upper() - SetEnvVar(env_var, test_val) - AssertEq(test_val, GetFlag(flag)) - SetEnvVar(env_var, None) - AssertEq(default_val, GetFlag(flag)) + env_var = "GTEST_" + flag.upper() + SetEnvVar(env_var, test_val) + AssertEq(test_val, GetFlag(flag)) + SetEnvVar(env_var, None) + AssertEq(default_val, GetFlag(flag)) class GTestEnvVarTest(gtest_test_utils.TestCase): - def testEnvVarAffectsFlag(self): - """Tests that environment variable should affect the corresponding flag.""" + def testEnvVarAffectsFlag(self): + """Tests that environment variable should affect the corresponding flag.""" - TestFlag('break_on_failure', '1', '0') - TestFlag('color', 'yes', 'auto') - SetEnvVar('TESTBRIDGE_TEST_RUNNER_FAIL_FAST', None) # For 'fail_fast' test - TestFlag('fail_fast', '1', '0') - TestFlag('filter', 'FooTest.Bar', '*') - SetEnvVar('XML_OUTPUT_FILE', None) # For 'output' test - TestFlag('output', 'xml:tmp/foo.xml', '') - TestFlag('brief', '1', '0') - TestFlag('print_time', '0', '1') - TestFlag('repeat', '999', '1') - TestFlag('throw_on_failure', '1', '0') - TestFlag('death_test_style', 'threadsafe', 'fast') - TestFlag('catch_exceptions', '0', '1') + TestFlag("break_on_failure", "1", "0") + TestFlag("color", "yes", "auto") + SetEnvVar("TESTBRIDGE_TEST_RUNNER_FAIL_FAST", None) # For 'fail_fast' test + TestFlag("fail_fast", "1", "0") + TestFlag("filter", "FooTest.Bar", "*") + SetEnvVar("XML_OUTPUT_FILE", None) # For 'output' test + TestFlag("output", "xml:tmp/foo.xml", "") + TestFlag("brief", "1", "0") + TestFlag("print_time", "0", "1") + TestFlag("repeat", "999", "1") + TestFlag("throw_on_failure", "1", "0") + TestFlag("death_test_style", "threadsafe", "fast") + TestFlag("catch_exceptions", "0", "1") - if IS_LINUX: - TestFlag('death_test_use_fork', '1', '0') - TestFlag('stack_trace_depth', '0', '100') + if IS_LINUX: + TestFlag("death_test_use_fork", "1", "0") + TestFlag("stack_trace_depth", "0", "100") - def testXmlOutputFile(self): - """Tests that $XML_OUTPUT_FILE affects the output flag.""" + def testXmlOutputFile(self): + """Tests that $XML_OUTPUT_FILE affects the output flag.""" - SetEnvVar('GTEST_OUTPUT', None) - SetEnvVar('XML_OUTPUT_FILE', 'tmp/bar.xml') - AssertEq('xml:tmp/bar.xml', GetFlag('output')) + SetEnvVar("GTEST_OUTPUT", None) + SetEnvVar("XML_OUTPUT_FILE", "tmp/bar.xml") + AssertEq("xml:tmp/bar.xml", GetFlag("output")) - def testXmlOutputFileOverride(self): - """Tests that $XML_OUTPUT_FILE is overridden by $GTEST_OUTPUT.""" + def testXmlOutputFileOverride(self): + """Tests that $XML_OUTPUT_FILE is overridden by $GTEST_OUTPUT.""" - SetEnvVar('GTEST_OUTPUT', 'xml:tmp/foo.xml') - SetEnvVar('XML_OUTPUT_FILE', 'tmp/bar.xml') - AssertEq('xml:tmp/foo.xml', GetFlag('output')) + SetEnvVar("GTEST_OUTPUT", "xml:tmp/foo.xml") + SetEnvVar("XML_OUTPUT_FILE", "tmp/bar.xml") + AssertEq("xml:tmp/foo.xml", GetFlag("output")) -if __name__ == '__main__': - gtest_test_utils.Main() +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-fail-if-no-test-linked-test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-fail-if-no-test-linked-test.py index f5854ba9..7d5b4364 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-fail-if-no-test-linked-test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-fail-if-no-test-linked-test.py @@ -42,128 +42,118 @@ class GTestFailIfNoTestLinkedTest(gtest_test_utils.TestCase): - """Tests the --gtest_fail_if_no_test_linked flag.""" - - def Run(self, program_name, flag=None, env=None): - """Run the given program with the given flag. - - Args: - program_name: Name of the program to run. - flag: The command line flag to pass to the program, or None. - env: Dictionary with environment to pass to the subprocess. - - Returns: - True if the program exits with code 0, false otherwise. - """ - - exe_path = gtest_test_utils.GetTestExecutablePath(program_name) - args = [exe_path] - if flag is not None: - args += [flag] - process = gtest_test_utils.Subprocess(args, capture_stderr=False, env=env) - return process.exited and process.exit_code == 0 - - def testSucceedsIfNoTestLinkedAndFlagNotSpecified(self): - """Tests the behavior of no test linked and flag not specified.""" - self.assertTrue( - self.Run("googletest-fail-if-no-test-linked-test-without-test_") - ) - - def testSucceedsIfNoTestLinkedAndFlagNotSpecifiedWithWarningFile(self): - """Tests that no test linked results in warning file output.""" - - warning_file = os.path.join(gtest_test_utils.GetTempDir(), "NO_TEST_LINKED") - self.assertTrue( - self.Run( - "googletest-fail-if-no-test-linked-test-without-test_", - env={TEST_WARNINGS_OUTPUT_FILE: warning_file}, + """Tests the --gtest_fail_if_no_test_linked flag.""" + + def Run(self, program_name, flag=None, env=None): + """Run the given program with the given flag. + + Args: + program_name: Name of the program to run. + flag: The command line flag to pass to the program, or None. + env: Dictionary with environment to pass to the subprocess. + + Returns: + True if the program exits with code 0, false otherwise. + """ + + exe_path = gtest_test_utils.GetTestExecutablePath(program_name) + args = [exe_path] + if flag is not None: + args += [flag] + process = gtest_test_utils.Subprocess(args, capture_stderr=False, env=env) + return process.exited and process.exit_code == 0 + + def testSucceedsIfNoTestLinkedAndFlagNotSpecified(self): + """Tests the behavior of no test linked and flag not specified.""" + self.assertTrue( + self.Run("googletest-fail-if-no-test-linked-test-without-test_") ) - ) - warning_file_contents = open(warning_file, "r").read() - self.assertEqual( - warning_file_contents, - "This test program does NOT link in any test case. Please make sure" - " this is intended.\n", - ) - - def testFailsIfNoTestLinkedAndFlagSpecified(self): - """Tests the behavior of no test linked and flag specified.""" - - warning_file = os.path.join( - gtest_test_utils.GetTempDir(), "SHOULD_NOT_EXIST" - ) - self.assertFalse( - self.Run( - "googletest-fail-if-no-test-linked-test-without-test_", - f"--{FAIL_IF_NO_TEST_LINKED_FLAG}", - env={TEST_WARNINGS_OUTPUT_FILE: warning_file}, + + def testSucceedsIfNoTestLinkedAndFlagNotSpecifiedWithWarningFile(self): + """Tests that no test linked results in warning file output.""" + + warning_file = os.path.join(gtest_test_utils.GetTempDir(), "NO_TEST_LINKED") + self.assertTrue( + self.Run( + "googletest-fail-if-no-test-linked-test-without-test_", + env={TEST_WARNINGS_OUTPUT_FILE: warning_file}, + ) + ) + warning_file_contents = open(warning_file, "r").read() + self.assertEqual( + warning_file_contents, + "This test program does NOT link in any test case. Please make sure" + " this is intended.\n", + ) + + def testFailsIfNoTestLinkedAndFlagSpecified(self): + """Tests the behavior of no test linked and flag specified.""" + + warning_file = os.path.join(gtest_test_utils.GetTempDir(), "SHOULD_NOT_EXIST") + self.assertFalse( + self.Run( + "googletest-fail-if-no-test-linked-test-without-test_", + f"--{FAIL_IF_NO_TEST_LINKED_FLAG}", + env={TEST_WARNINGS_OUTPUT_FILE: warning_file}, + ) ) - ) - with self.assertRaises(FileNotFoundError): - open(warning_file, "r") - - def testSucceedsIfEnabledTestLinkedAndFlagNotSpecified(self): - """Tests the behavior of enabled test linked and flag not specified.""" - - warning_file = os.path.join( - gtest_test_utils.GetTempDir(), "SHOULD_NOT_EXIST" - ) - self.assertTrue( - self.Run( - "googletest-fail-if-no-test-linked-test-with-enabled-test_", - env={TEST_WARNINGS_OUTPUT_FILE: warning_file}, + with self.assertRaises(FileNotFoundError): + open(warning_file, "r") + + def testSucceedsIfEnabledTestLinkedAndFlagNotSpecified(self): + """Tests the behavior of enabled test linked and flag not specified.""" + + warning_file = os.path.join(gtest_test_utils.GetTempDir(), "SHOULD_NOT_EXIST") + self.assertTrue( + self.Run( + "googletest-fail-if-no-test-linked-test-with-enabled-test_", + env={TEST_WARNINGS_OUTPUT_FILE: warning_file}, + ) ) - ) - with self.assertRaises(FileNotFoundError): - open(warning_file, "r") - - def testSucceedsIfEnabledTestLinkedAndFlagSpecified(self): - """Tests the behavior of enabled test linked and flag specified.""" - - warning_file = os.path.join( - gtest_test_utils.GetTempDir(), "SHOULD_NOT_EXIST" - ) - self.assertTrue( - self.Run( - "googletest-fail-if-no-test-linked-test-with-enabled-test_", - f"--{FAIL_IF_NO_TEST_LINKED_FLAG}", - env={TEST_WARNINGS_OUTPUT_FILE: warning_file}, + with self.assertRaises(FileNotFoundError): + open(warning_file, "r") + + def testSucceedsIfEnabledTestLinkedAndFlagSpecified(self): + """Tests the behavior of enabled test linked and flag specified.""" + + warning_file = os.path.join(gtest_test_utils.GetTempDir(), "SHOULD_NOT_EXIST") + self.assertTrue( + self.Run( + "googletest-fail-if-no-test-linked-test-with-enabled-test_", + f"--{FAIL_IF_NO_TEST_LINKED_FLAG}", + env={TEST_WARNINGS_OUTPUT_FILE: warning_file}, + ) ) - ) - with self.assertRaises(FileNotFoundError): - open(warning_file, "r") - - def testSucceedsIfDisabledTestLinkedAndFlagNotSpecified(self): - """Tests the behavior of disabled test linked and flag not specified.""" - - warning_file = os.path.join( - gtest_test_utils.GetTempDir(), "SHOULD_NOT_EXIST" - ) - self.assertTrue( - self.Run( - "googletest-fail-if-no-test-linked-test-with-disabled-test_", - env={TEST_WARNINGS_OUTPUT_FILE: warning_file}, + with self.assertRaises(FileNotFoundError): + open(warning_file, "r") + + def testSucceedsIfDisabledTestLinkedAndFlagNotSpecified(self): + """Tests the behavior of disabled test linked and flag not specified.""" + + warning_file = os.path.join(gtest_test_utils.GetTempDir(), "SHOULD_NOT_EXIST") + self.assertTrue( + self.Run( + "googletest-fail-if-no-test-linked-test-with-disabled-test_", + env={TEST_WARNINGS_OUTPUT_FILE: warning_file}, + ) ) - ) - with self.assertRaises(FileNotFoundError): - open(warning_file, "r") - - def testSucceedsIfDisabledTestLinkedAndFlagSpecified(self): - """Tests the behavior of disabled test linked and flag specified.""" - - warning_file = os.path.join( - gtest_test_utils.GetTempDir(), "SHOULD_NOT_EXIST" - ) - self.assertTrue( - self.Run( - "googletest-fail-if-no-test-linked-test-with-disabled-test_", - f"--{FAIL_IF_NO_TEST_LINKED_FLAG}", - env={TEST_WARNINGS_OUTPUT_FILE: warning_file}, + with self.assertRaises(FileNotFoundError): + open(warning_file, "r") + + def testSucceedsIfDisabledTestLinkedAndFlagSpecified(self): + """Tests the behavior of disabled test linked and flag specified.""" + + warning_file = os.path.join(gtest_test_utils.GetTempDir(), "SHOULD_NOT_EXIST") + self.assertTrue( + self.Run( + "googletest-fail-if-no-test-linked-test-with-disabled-test_", + f"--{FAIL_IF_NO_TEST_LINKED_FLAG}", + env={TEST_WARNINGS_OUTPUT_FILE: warning_file}, + ) ) - ) - with self.assertRaises(FileNotFoundError): - open(warning_file, "r") + with self.assertRaises(FileNotFoundError): + open(warning_file, "r") if __name__ == "__main__": - gtest_test_utils.Main() + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-failfast-unittest.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-failfast-unittest.py index cdbce0c5..02d95d25 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-failfast-unittest.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-failfast-unittest.py @@ -46,33 +46,30 @@ # Constants. # Bazel testbridge environment variable for fail fast -BAZEL_FAIL_FAST_ENV_VAR = 'TESTBRIDGE_TEST_RUNNER_FAIL_FAST' +BAZEL_FAIL_FAST_ENV_VAR = "TESTBRIDGE_TEST_RUNNER_FAIL_FAST" # The environment variable for specifying fail fast. -FAIL_FAST_ENV_VAR = 'GTEST_FAIL_FAST' +FAIL_FAST_ENV_VAR = "GTEST_FAIL_FAST" # The command line flag for specifying fail fast. -FAIL_FAST_FLAG = 'gtest_fail_fast' +FAIL_FAST_FLAG = "gtest_fail_fast" # The command line flag to run disabled tests. -RUN_DISABLED_FLAG = 'gtest_also_run_disabled_tests' +RUN_DISABLED_FLAG = "gtest_also_run_disabled_tests" # The command line flag for specifying a filter. -FILTER_FLAG = 'gtest_filter' +FILTER_FLAG = "gtest_filter" # Command to run the googletest-failfast-unittest_ program. -COMMAND = gtest_test_utils.GetTestExecutablePath( - 'googletest-failfast-unittest_' -) +COMMAND = gtest_test_utils.GetTestExecutablePath("googletest-failfast-unittest_") # The command line flag to tell Google Test to output the list of tests it # will run. -LIST_TESTS_FLAG = '--gtest_list_tests' +LIST_TESTS_FLAG = "--gtest_list_tests" # Indicates whether Google Test supports death tests. SUPPORTS_DEATH_TESTS = ( - 'HasDeathTest' - in gtest_test_utils.Subprocess([COMMAND, LIST_TESTS_FLAG]).output + "HasDeathTest" in gtest_test_utils.Subprocess([COMMAND, LIST_TESTS_FLAG]).output ) # Utilities. @@ -81,381 +78,377 @@ def SetEnvVar(env_var, value): - """Sets the env variable to 'value'; unsets it when 'value' is None.""" + """Sets the env variable to 'value'; unsets it when 'value' is None.""" - if value is not None: - environ[env_var] = value - elif env_var in environ: - del environ[env_var] + if value is not None: + environ[env_var] = value + elif env_var in environ: + del environ[env_var] def RunAndReturnOutput(test_suite=None, fail_fast=None, run_disabled=False): - """Runs the test program and returns its output.""" - - args = [] - xml_path = os.path.join( - gtest_test_utils.GetTempDir(), '.GTestFailFastUnitTest.xml' - ) - args += ['--gtest_output=xml:' + xml_path] - if fail_fast is not None: - if isinstance(fail_fast, str): - args += ['--%s=%s' % (FAIL_FAST_FLAG, fail_fast)] - elif fail_fast: - args += ['--%s' % FAIL_FAST_FLAG] - else: - args += ['--no%s' % FAIL_FAST_FLAG] - if test_suite: - args += ['--%s=%s.*' % (FILTER_FLAG, test_suite)] - if run_disabled: - args += ['--%s' % RUN_DISABLED_FLAG] - txt_out = gtest_test_utils.Subprocess([COMMAND] + args, env=environ).output - with open(xml_path) as xml_file: - return txt_out, xml_file.read() + """Runs the test program and returns its output.""" + + args = [] + xml_path = os.path.join(gtest_test_utils.GetTempDir(), ".GTestFailFastUnitTest.xml") + args += ["--gtest_output=xml:" + xml_path] + if fail_fast is not None: + if isinstance(fail_fast, str): + args += ["--%s=%s" % (FAIL_FAST_FLAG, fail_fast)] + elif fail_fast: + args += ["--%s" % FAIL_FAST_FLAG] + else: + args += ["--no%s" % FAIL_FAST_FLAG] + if test_suite: + args += ["--%s=%s.*" % (FILTER_FLAG, test_suite)] + if run_disabled: + args += ["--%s" % RUN_DISABLED_FLAG] + txt_out = gtest_test_utils.Subprocess([COMMAND] + args, env=environ).output + with open(xml_path) as xml_file: + return txt_out, xml_file.read() # The unit test. class GTestFailFastUnitTest(gtest_test_utils.TestCase): - """Tests the env variable or the command line flag for fail_fast.""" - - def testDefaultBehavior(self): - """Tests the behavior of not specifying the fail_fast.""" - - txt, _ = RunAndReturnOutput() - self.assertIn('22 FAILED TEST', txt) - - def testGoogletestFlag(self): - txt, _ = RunAndReturnOutput(test_suite='HasSimpleTest', fail_fast=True) - self.assertIn('1 FAILED TEST', txt) - self.assertIn('[ SKIPPED ] 3 tests', txt) - - txt, _ = RunAndReturnOutput(test_suite='HasSimpleTest', fail_fast=False) - self.assertIn('4 FAILED TEST', txt) - self.assertNotIn('[ SKIPPED ]', txt) - - def testGoogletestEnvVar(self): - """Tests the behavior of specifying fail_fast via Googletest env var.""" - - try: - SetEnvVar(FAIL_FAST_ENV_VAR, '1') - txt, _ = RunAndReturnOutput('HasSimpleTest') - self.assertIn('1 FAILED TEST', txt) - self.assertIn('[ SKIPPED ] 3 tests', txt) - - SetEnvVar(FAIL_FAST_ENV_VAR, '0') - txt, _ = RunAndReturnOutput('HasSimpleTest') - self.assertIn('4 FAILED TEST', txt) - self.assertNotIn('[ SKIPPED ]', txt) - finally: - SetEnvVar(FAIL_FAST_ENV_VAR, None) - - def testBazelEnvVar(self): - """Tests the behavior of specifying fail_fast via Bazel testbridge.""" - - try: - SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, '1') - txt, _ = RunAndReturnOutput('HasSimpleTest') - self.assertIn('1 FAILED TEST', txt) - self.assertIn('[ SKIPPED ] 3 tests', txt) - - SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, '0') - txt, _ = RunAndReturnOutput('HasSimpleTest') - self.assertIn('4 FAILED TEST', txt) - self.assertNotIn('[ SKIPPED ]', txt) - finally: - SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, None) - - def testFlagOverridesEnvVar(self): - """Tests precedence of flag over env var.""" - - try: - SetEnvVar(FAIL_FAST_ENV_VAR, '0') - txt, _ = RunAndReturnOutput('HasSimpleTest', True) - self.assertIn('1 FAILED TEST', txt) - self.assertIn('[ SKIPPED ] 3 tests', txt) - finally: - SetEnvVar(FAIL_FAST_ENV_VAR, None) - - def testGoogletestEnvVarOverridesBazelEnvVar(self): - """Tests that the Googletest native env var over Bazel testbridge.""" - - try: - SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, '0') - SetEnvVar(FAIL_FAST_ENV_VAR, '1') - txt, _ = RunAndReturnOutput('HasSimpleTest') - self.assertIn('1 FAILED TEST', txt) - self.assertIn('[ SKIPPED ] 3 tests', txt) - finally: - SetEnvVar(FAIL_FAST_ENV_VAR, None) - SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, None) - - def testEventListener(self): - txt, _ = RunAndReturnOutput(test_suite='HasSkipTest', fail_fast=True) - self.assertIn('1 FAILED TEST', txt) - self.assertIn('[ SKIPPED ] 3 tests', txt) - for expected_count, callback in [ - (1, 'OnTestSuiteStart'), - (5, 'OnTestStart'), - (5, 'OnTestEnd'), - (5, 'OnTestPartResult'), - (1, 'OnTestSuiteEnd'), - ]: - self.assertEqual( - expected_count, - txt.count(callback), - 'Expected %d calls to callback %s match count on output: %s ' - % (expected_count, callback, txt), - ) - - txt, _ = RunAndReturnOutput(test_suite='HasSkipTest', fail_fast=False) - self.assertIn('3 FAILED TEST', txt) - self.assertIn('[ SKIPPED ] 1 test', txt) - for expected_count, callback in [ - (1, 'OnTestSuiteStart'), - (5, 'OnTestStart'), - (5, 'OnTestEnd'), - (5, 'OnTestPartResult'), - (1, 'OnTestSuiteEnd'), - ]: - self.assertEqual( - expected_count, - txt.count(callback), - 'Expected %d calls to callback %s match count on output: %s ' - % (expected_count, callback, txt), - ) - - def assertXmlResultCount(self, result, count, xml): - self.assertEqual( - count, - xml.count('result="%s"' % result), - 'Expected \'result="%s"\' match count of %s: %s ' - % (result, count, xml), - ) - - def assertXmlStatusCount(self, status, count, xml): - self.assertEqual( - count, - xml.count('status="%s"' % status), - 'Expected \'status="%s"\' match count of %s: %s ' - % (status, count, xml), - ) - - def assertFailFastXmlAndTxtOutput( - self, - fail_fast, - test_suite, - passed_count, - failure_count, - skipped_count, - suppressed_count, - run_disabled=False, - ): - """Assert XML and text output of a test execution.""" - - txt, xml = RunAndReturnOutput(test_suite, fail_fast, run_disabled) - if failure_count > 0: - self.assertIn('%s FAILED TEST' % failure_count, txt) - if suppressed_count > 0: - self.assertIn('%s DISABLED TEST' % suppressed_count, txt) - if skipped_count > 0: - self.assertIn('[ SKIPPED ] %s tests' % skipped_count, txt) - self.assertXmlStatusCount( - 'run', passed_count + failure_count + skipped_count, xml - ) - self.assertXmlStatusCount('notrun', suppressed_count, xml) - self.assertXmlResultCount('completed', passed_count + failure_count, xml) - self.assertXmlResultCount('skipped', skipped_count, xml) - self.assertXmlResultCount('suppressed', suppressed_count, xml) - - def assertFailFastBehavior( - self, - test_suite, - passed_count, - failure_count, - skipped_count, - suppressed_count, - run_disabled=False, - ): - """Assert --fail_fast via flag.""" - - for fail_fast in ('true', '1', 't', True): - self.assertFailFastXmlAndTxtOutput( - fail_fast, - test_suite, - passed_count, - failure_count, - skipped_count, - suppressed_count, - run_disabled, - ) - - def assertNotFailFastBehavior( - self, - test_suite, - passed_count, - failure_count, - skipped_count, - suppressed_count, - run_disabled=False, - ): - """Assert --nofail_fast via flag.""" - - for fail_fast in ('false', '0', 'f', False): - self.assertFailFastXmlAndTxtOutput( - fail_fast, - test_suite, - passed_count, - failure_count, - skipped_count, - suppressed_count, - run_disabled, - ) - - def testFlag_HasFixtureTest(self): - """Tests the behavior of fail_fast and TEST_F.""" - self.assertFailFastBehavior( - test_suite='HasFixtureTest', - passed_count=1, - failure_count=1, - skipped_count=3, - suppressed_count=0, - ) - self.assertNotFailFastBehavior( - test_suite='HasFixtureTest', - passed_count=1, - failure_count=4, - skipped_count=0, - suppressed_count=0, - ) - - def testFlag_HasSimpleTest(self): - """Tests the behavior of fail_fast and TEST.""" - self.assertFailFastBehavior( - test_suite='HasSimpleTest', - passed_count=1, - failure_count=1, - skipped_count=3, - suppressed_count=0, - ) - self.assertNotFailFastBehavior( - test_suite='HasSimpleTest', - passed_count=1, - failure_count=4, - skipped_count=0, - suppressed_count=0, - ) - - def testFlag_HasParametersTest(self): - """Tests the behavior of fail_fast and TEST_P.""" - self.assertFailFastBehavior( - test_suite='HasParametersSuite/HasParametersTest', - passed_count=0, - failure_count=1, - skipped_count=3, - suppressed_count=0, - ) - self.assertNotFailFastBehavior( - test_suite='HasParametersSuite/HasParametersTest', - passed_count=0, - failure_count=4, - skipped_count=0, - suppressed_count=0, - ) - - def testFlag_HasDisabledTest(self): - """Tests the behavior of fail_fast and Disabled test cases.""" - self.assertFailFastBehavior( - test_suite='HasDisabledTest', - passed_count=1, - failure_count=1, - skipped_count=2, - suppressed_count=1, - run_disabled=False, - ) - self.assertNotFailFastBehavior( - test_suite='HasDisabledTest', - passed_count=1, - failure_count=3, - skipped_count=0, - suppressed_count=1, + """Tests the env variable or the command line flag for fail_fast.""" + + def testDefaultBehavior(self): + """Tests the behavior of not specifying the fail_fast.""" + + txt, _ = RunAndReturnOutput() + self.assertIn("22 FAILED TEST", txt) + + def testGoogletestFlag(self): + txt, _ = RunAndReturnOutput(test_suite="HasSimpleTest", fail_fast=True) + self.assertIn("1 FAILED TEST", txt) + self.assertIn("[ SKIPPED ] 3 tests", txt) + + txt, _ = RunAndReturnOutput(test_suite="HasSimpleTest", fail_fast=False) + self.assertIn("4 FAILED TEST", txt) + self.assertNotIn("[ SKIPPED ]", txt) + + def testGoogletestEnvVar(self): + """Tests the behavior of specifying fail_fast via Googletest env var.""" + + try: + SetEnvVar(FAIL_FAST_ENV_VAR, "1") + txt, _ = RunAndReturnOutput("HasSimpleTest") + self.assertIn("1 FAILED TEST", txt) + self.assertIn("[ SKIPPED ] 3 tests", txt) + + SetEnvVar(FAIL_FAST_ENV_VAR, "0") + txt, _ = RunAndReturnOutput("HasSimpleTest") + self.assertIn("4 FAILED TEST", txt) + self.assertNotIn("[ SKIPPED ]", txt) + finally: + SetEnvVar(FAIL_FAST_ENV_VAR, None) + + def testBazelEnvVar(self): + """Tests the behavior of specifying fail_fast via Bazel testbridge.""" + + try: + SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, "1") + txt, _ = RunAndReturnOutput("HasSimpleTest") + self.assertIn("1 FAILED TEST", txt) + self.assertIn("[ SKIPPED ] 3 tests", txt) + + SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, "0") + txt, _ = RunAndReturnOutput("HasSimpleTest") + self.assertIn("4 FAILED TEST", txt) + self.assertNotIn("[ SKIPPED ]", txt) + finally: + SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, None) + + def testFlagOverridesEnvVar(self): + """Tests precedence of flag over env var.""" + + try: + SetEnvVar(FAIL_FAST_ENV_VAR, "0") + txt, _ = RunAndReturnOutput("HasSimpleTest", True) + self.assertIn("1 FAILED TEST", txt) + self.assertIn("[ SKIPPED ] 3 tests", txt) + finally: + SetEnvVar(FAIL_FAST_ENV_VAR, None) + + def testGoogletestEnvVarOverridesBazelEnvVar(self): + """Tests that the Googletest native env var over Bazel testbridge.""" + + try: + SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, "0") + SetEnvVar(FAIL_FAST_ENV_VAR, "1") + txt, _ = RunAndReturnOutput("HasSimpleTest") + self.assertIn("1 FAILED TEST", txt) + self.assertIn("[ SKIPPED ] 3 tests", txt) + finally: + SetEnvVar(FAIL_FAST_ENV_VAR, None) + SetEnvVar(BAZEL_FAIL_FAST_ENV_VAR, None) + + def testEventListener(self): + txt, _ = RunAndReturnOutput(test_suite="HasSkipTest", fail_fast=True) + self.assertIn("1 FAILED TEST", txt) + self.assertIn("[ SKIPPED ] 3 tests", txt) + for expected_count, callback in [ + (1, "OnTestSuiteStart"), + (5, "OnTestStart"), + (5, "OnTestEnd"), + (5, "OnTestPartResult"), + (1, "OnTestSuiteEnd"), + ]: + self.assertEqual( + expected_count, + txt.count(callback), + "Expected %d calls to callback %s match count on output: %s " + % (expected_count, callback, txt), + ) + + txt, _ = RunAndReturnOutput(test_suite="HasSkipTest", fail_fast=False) + self.assertIn("3 FAILED TEST", txt) + self.assertIn("[ SKIPPED ] 1 test", txt) + for expected_count, callback in [ + (1, "OnTestSuiteStart"), + (5, "OnTestStart"), + (5, "OnTestEnd"), + (5, "OnTestPartResult"), + (1, "OnTestSuiteEnd"), + ]: + self.assertEqual( + expected_count, + txt.count(callback), + "Expected %d calls to callback %s match count on output: %s " + % (expected_count, callback, txt), + ) + + def assertXmlResultCount(self, result, count, xml): + self.assertEqual( + count, + xml.count('result="%s"' % result), + "Expected 'result=\"%s\"' match count of %s: %s " % (result, count, xml), + ) + + def assertXmlStatusCount(self, status, count, xml): + self.assertEqual( + count, + xml.count('status="%s"' % status), + "Expected 'status=\"%s\"' match count of %s: %s " % (status, count, xml), + ) + + def assertFailFastXmlAndTxtOutput( + self, + fail_fast, + test_suite, + passed_count, + failure_count, + skipped_count, + suppressed_count, run_disabled=False, - ) - - def testFlag_HasDisabledRunDisabledTest(self): - """Tests the behavior of fail_fast and Disabled test cases enabled.""" - self.assertFailFastBehavior( - test_suite='HasDisabledTest', - passed_count=1, - failure_count=1, - skipped_count=3, - suppressed_count=0, - run_disabled=True, - ) - self.assertNotFailFastBehavior( - test_suite='HasDisabledTest', - passed_count=1, - failure_count=4, - skipped_count=0, - suppressed_count=0, - run_disabled=True, - ) - - def testFlag_HasDisabledSuiteTest(self): - """Tests the behavior of fail_fast and Disabled test suites.""" - self.assertFailFastBehavior( - test_suite='DISABLED_HasDisabledSuite', - passed_count=0, - failure_count=0, - skipped_count=0, - suppressed_count=5, + ): + """Assert XML and text output of a test execution.""" + + txt, xml = RunAndReturnOutput(test_suite, fail_fast, run_disabled) + if failure_count > 0: + self.assertIn("%s FAILED TEST" % failure_count, txt) + if suppressed_count > 0: + self.assertIn("%s DISABLED TEST" % suppressed_count, txt) + if skipped_count > 0: + self.assertIn("[ SKIPPED ] %s tests" % skipped_count, txt) + self.assertXmlStatusCount( + "run", passed_count + failure_count + skipped_count, xml + ) + self.assertXmlStatusCount("notrun", suppressed_count, xml) + self.assertXmlResultCount("completed", passed_count + failure_count, xml) + self.assertXmlResultCount("skipped", skipped_count, xml) + self.assertXmlResultCount("suppressed", suppressed_count, xml) + + def assertFailFastBehavior( + self, + test_suite, + passed_count, + failure_count, + skipped_count, + suppressed_count, run_disabled=False, - ) - self.assertNotFailFastBehavior( - test_suite='DISABLED_HasDisabledSuite', - passed_count=0, - failure_count=0, - skipped_count=0, - suppressed_count=5, + ): + """Assert --fail_fast via flag.""" + + for fail_fast in ("true", "1", "t", True): + self.assertFailFastXmlAndTxtOutput( + fail_fast, + test_suite, + passed_count, + failure_count, + skipped_count, + suppressed_count, + run_disabled, + ) + + def assertNotFailFastBehavior( + self, + test_suite, + passed_count, + failure_count, + skipped_count, + suppressed_count, run_disabled=False, - ) - - def testFlag_HasDisabledSuiteRunDisabledTest(self): - """Tests the behavior of fail_fast and Disabled test suites enabled.""" - self.assertFailFastBehavior( - test_suite='DISABLED_HasDisabledSuite', - passed_count=1, - failure_count=1, - skipped_count=3, - suppressed_count=0, - run_disabled=True, - ) - self.assertNotFailFastBehavior( - test_suite='DISABLED_HasDisabledSuite', - passed_count=1, - failure_count=4, - skipped_count=0, - suppressed_count=0, - run_disabled=True, - ) - - if SUPPORTS_DEATH_TESTS: - - def testFlag_HasDeathTest(self): - """Tests the behavior of fail_fast and death tests.""" - self.assertFailFastBehavior( - test_suite='HasDeathTest', - passed_count=1, - failure_count=1, - skipped_count=3, - suppressed_count=0, - ) - self.assertNotFailFastBehavior( - test_suite='HasDeathTest', - passed_count=1, - failure_count=4, - skipped_count=0, - suppressed_count=0, - ) - - -if __name__ == '__main__': - gtest_test_utils.Main() + ): + """Assert --nofail_fast via flag.""" + + for fail_fast in ("false", "0", "f", False): + self.assertFailFastXmlAndTxtOutput( + fail_fast, + test_suite, + passed_count, + failure_count, + skipped_count, + suppressed_count, + run_disabled, + ) + + def testFlag_HasFixtureTest(self): + """Tests the behavior of fail_fast and TEST_F.""" + self.assertFailFastBehavior( + test_suite="HasFixtureTest", + passed_count=1, + failure_count=1, + skipped_count=3, + suppressed_count=0, + ) + self.assertNotFailFastBehavior( + test_suite="HasFixtureTest", + passed_count=1, + failure_count=4, + skipped_count=0, + suppressed_count=0, + ) + + def testFlag_HasSimpleTest(self): + """Tests the behavior of fail_fast and TEST.""" + self.assertFailFastBehavior( + test_suite="HasSimpleTest", + passed_count=1, + failure_count=1, + skipped_count=3, + suppressed_count=0, + ) + self.assertNotFailFastBehavior( + test_suite="HasSimpleTest", + passed_count=1, + failure_count=4, + skipped_count=0, + suppressed_count=0, + ) + + def testFlag_HasParametersTest(self): + """Tests the behavior of fail_fast and TEST_P.""" + self.assertFailFastBehavior( + test_suite="HasParametersSuite/HasParametersTest", + passed_count=0, + failure_count=1, + skipped_count=3, + suppressed_count=0, + ) + self.assertNotFailFastBehavior( + test_suite="HasParametersSuite/HasParametersTest", + passed_count=0, + failure_count=4, + skipped_count=0, + suppressed_count=0, + ) + + def testFlag_HasDisabledTest(self): + """Tests the behavior of fail_fast and Disabled test cases.""" + self.assertFailFastBehavior( + test_suite="HasDisabledTest", + passed_count=1, + failure_count=1, + skipped_count=2, + suppressed_count=1, + run_disabled=False, + ) + self.assertNotFailFastBehavior( + test_suite="HasDisabledTest", + passed_count=1, + failure_count=3, + skipped_count=0, + suppressed_count=1, + run_disabled=False, + ) + + def testFlag_HasDisabledRunDisabledTest(self): + """Tests the behavior of fail_fast and Disabled test cases enabled.""" + self.assertFailFastBehavior( + test_suite="HasDisabledTest", + passed_count=1, + failure_count=1, + skipped_count=3, + suppressed_count=0, + run_disabled=True, + ) + self.assertNotFailFastBehavior( + test_suite="HasDisabledTest", + passed_count=1, + failure_count=4, + skipped_count=0, + suppressed_count=0, + run_disabled=True, + ) + + def testFlag_HasDisabledSuiteTest(self): + """Tests the behavior of fail_fast and Disabled test suites.""" + self.assertFailFastBehavior( + test_suite="DISABLED_HasDisabledSuite", + passed_count=0, + failure_count=0, + skipped_count=0, + suppressed_count=5, + run_disabled=False, + ) + self.assertNotFailFastBehavior( + test_suite="DISABLED_HasDisabledSuite", + passed_count=0, + failure_count=0, + skipped_count=0, + suppressed_count=5, + run_disabled=False, + ) + + def testFlag_HasDisabledSuiteRunDisabledTest(self): + """Tests the behavior of fail_fast and Disabled test suites enabled.""" + self.assertFailFastBehavior( + test_suite="DISABLED_HasDisabledSuite", + passed_count=1, + failure_count=1, + skipped_count=3, + suppressed_count=0, + run_disabled=True, + ) + self.assertNotFailFastBehavior( + test_suite="DISABLED_HasDisabledSuite", + passed_count=1, + failure_count=4, + skipped_count=0, + suppressed_count=0, + run_disabled=True, + ) + + if SUPPORTS_DEATH_TESTS: + + def testFlag_HasDeathTest(self): + """Tests the behavior of fail_fast and death tests.""" + self.assertFailFastBehavior( + test_suite="HasDeathTest", + passed_count=1, + failure_count=1, + skipped_count=3, + suppressed_count=0, + ) + self.assertNotFailFastBehavior( + test_suite="HasDeathTest", + passed_count=1, + failure_count=4, + skipped_count=0, + suppressed_count=0, + ) + + +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-filter-unittest.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-filter-unittest.py index a44882a6..6d5c83e7 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-filter-unittest.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-filter-unittest.py @@ -44,9 +44,9 @@ import re try: - from sets import Set as set # For Python 2.3 compatibility + from sets import Set as set # For Python 2.3 compatibility except ImportError: - pass + pass import sys from googletest.test import gtest_test_utils @@ -59,11 +59,11 @@ # exception is thrown if the input is anything other than 'True' nor 'False'. CAN_PASS_EMPTY_ENV = False if sys.executable: - os.environ['EMPTY_VAR'] = '' - child = gtest_test_utils.Subprocess( - [sys.executable, '-c', "import os; print('EMPTY_VAR' in os.environ)"] - ) - CAN_PASS_EMPTY_ENV = eval(child.output) + os.environ["EMPTY_VAR"] = "" + child = gtest_test_utils.Subprocess( + [sys.executable, "-c", "import os; print('EMPTY_VAR' in os.environ)"] + ) + CAN_PASS_EMPTY_ENV = eval(child.output) # Check if this platform can unset environment variables in child processes. @@ -74,12 +74,12 @@ # is thrown if the input is neither 'True' nor 'False'. CAN_UNSET_ENV = False if sys.executable: - os.environ['UNSET_VAR'] = 'X' - del os.environ['UNSET_VAR'] - child = gtest_test_utils.Subprocess( - [sys.executable, '-c', "import os; print('UNSET_VAR' not in os.environ)"] - ) - CAN_UNSET_ENV = eval(child.output) + os.environ["UNSET_VAR"] = "X" + del os.environ["UNSET_VAR"] + child = gtest_test_utils.Subprocess( + [sys.executable, "-c", "import os; print('UNSET_VAR' not in os.environ)"] + ) + CAN_UNSET_ENV = eval(child.output) # Checks if we should test with an empty filter. This doesn't @@ -90,87 +90,86 @@ # The environment variable for specifying the test filters. -FILTER_ENV_VAR = 'GTEST_FILTER' +FILTER_ENV_VAR = "GTEST_FILTER" # The environment variables for test sharding. -TOTAL_SHARDS_ENV_VAR = 'GTEST_TOTAL_SHARDS' -SHARD_INDEX_ENV_VAR = 'GTEST_SHARD_INDEX' -SHARD_STATUS_FILE_ENV_VAR = 'GTEST_SHARD_STATUS_FILE' +TOTAL_SHARDS_ENV_VAR = "GTEST_TOTAL_SHARDS" +SHARD_INDEX_ENV_VAR = "GTEST_SHARD_INDEX" +SHARD_STATUS_FILE_ENV_VAR = "GTEST_SHARD_STATUS_FILE" # The environment variable for the test warnings output file. -TEST_WARNINGS_OUTPUT_FILE = 'TEST_WARNINGS_OUTPUT_FILE' +TEST_WARNINGS_OUTPUT_FILE = "TEST_WARNINGS_OUTPUT_FILE" # The command line flag for specifying the test filters. -FILTER_FLAG = 'gtest_filter' +FILTER_FLAG = "gtest_filter" # The command line flag for including disabled tests. -ALSO_RUN_DISABLED_TESTS_FLAG = 'gtest_also_run_disabled_tests' +ALSO_RUN_DISABLED_TESTS_FLAG = "gtest_also_run_disabled_tests" # Command to run the googletest-filter-unittest_ program. -COMMAND = gtest_test_utils.GetTestExecutablePath('googletest-filter-unittest_') +COMMAND = gtest_test_utils.GetTestExecutablePath("googletest-filter-unittest_") # Regex for determining whether parameterized tests are enabled in the binary. -PARAM_TEST_REGEX = re.compile(r'/ParamTest') +PARAM_TEST_REGEX = re.compile(r"/ParamTest") # Regex for parsing test case names from Google Test's output. -TEST_CASE_REGEX = re.compile(r'^\[\-+\] \d+ tests? from (\w+(/\w+)?)') +TEST_CASE_REGEX = re.compile(r"^\[\-+\] \d+ tests? from (\w+(/\w+)?)") # Regex for parsing test names from Google Test's output. -TEST_REGEX = re.compile(r'^\[\s*RUN\s*\].*\.(\w+(/\w+)?)') +TEST_REGEX = re.compile(r"^\[\s*RUN\s*\].*\.(\w+(/\w+)?)") # Regex for parsing disabled banner from Google Test's output -DISABLED_BANNER_REGEX = re.compile(r'^\[\s*DISABLED\s*\] (.*)') +DISABLED_BANNER_REGEX = re.compile(r"^\[\s*DISABLED\s*\] (.*)") # The command line flag to tell Google Test to output the list of tests it # will run. -LIST_TESTS_FLAG = '--gtest_list_tests' +LIST_TESTS_FLAG = "--gtest_list_tests" # Indicates whether Google Test supports death tests. SUPPORTS_DEATH_TESTS = ( - 'HasDeathTest' - in gtest_test_utils.Subprocess([COMMAND, LIST_TESTS_FLAG]).output + "HasDeathTest" in gtest_test_utils.Subprocess([COMMAND, LIST_TESTS_FLAG]).output ) # Full names of all tests in googletest-filter-unittests_. PARAM_TESTS = [ - 'SeqP/ParamTest.TestX/0', - 'SeqP/ParamTest.TestX/1', - 'SeqP/ParamTest.TestY/0', - 'SeqP/ParamTest.TestY/1', - 'SeqQ/ParamTest.TestX/0', - 'SeqQ/ParamTest.TestX/1', - 'SeqQ/ParamTest.TestY/0', - 'SeqQ/ParamTest.TestY/1', + "SeqP/ParamTest.TestX/0", + "SeqP/ParamTest.TestX/1", + "SeqP/ParamTest.TestY/0", + "SeqP/ParamTest.TestY/1", + "SeqQ/ParamTest.TestX/0", + "SeqQ/ParamTest.TestX/1", + "SeqQ/ParamTest.TestY/0", + "SeqQ/ParamTest.TestY/1", ] DISABLED_TESTS = [ - 'BarTest.DISABLED_TestFour', - 'BarTest.DISABLED_TestFive', - 'BazTest.DISABLED_TestC', - 'DISABLED_FoobarTest.Test1', - 'DISABLED_FoobarTest.DISABLED_Test2', - 'DISABLED_FoobarbazTest.TestA', + "BarTest.DISABLED_TestFour", + "BarTest.DISABLED_TestFive", + "BazTest.DISABLED_TestC", + "DISABLED_FoobarTest.Test1", + "DISABLED_FoobarTest.DISABLED_Test2", + "DISABLED_FoobarbazTest.TestA", ] if SUPPORTS_DEATH_TESTS: - DEATH_TESTS = [ - 'HasDeathTest.Test1', - 'HasDeathTest.Test2', - ] + DEATH_TESTS = [ + "HasDeathTest.Test1", + "HasDeathTest.Test2", + ] else: - DEATH_TESTS = [] + DEATH_TESTS = [] # All the non-disabled tests. ACTIVE_TESTS = ( [ - 'FooTest.Abc', - 'FooTest.Xyz', - 'BarTest.TestOne', - 'BarTest.TestTwo', - 'BarTest.TestThree', - 'BazTest.TestOne', - 'BazTest.TestA', - 'BazTest.TestB', + "FooTest.Abc", + "FooTest.Xyz", + "BarTest.TestOne", + "BarTest.TestTwo", + "BarTest.TestThree", + "BazTest.TestOne", + "BazTest.TestA", + "BazTest.TestB", ] + DEATH_TESTS + PARAM_TESTS @@ -184,582 +183,580 @@ def SetEnvVar(env_var, value): - """Sets the env variable to 'value'; unsets it when 'value' is None.""" + """Sets the env variable to 'value'; unsets it when 'value' is None.""" - if value is not None: - environ[env_var] = value - elif env_var in environ: - del environ[env_var] + if value is not None: + environ[env_var] = value + elif env_var in environ: + del environ[env_var] def RunAndReturnOutput(args=None): - """Runs the test program and returns its output.""" + """Runs the test program and returns its output.""" - return gtest_test_utils.Subprocess( - [COMMAND] + (args or []), env=environ - ).output + return gtest_test_utils.Subprocess([COMMAND] + (args or []), env=environ).output def RunAndExtractTestList(args=None): - """Runs the test program and returns its exit code and a list of tests run.""" - - p = gtest_test_utils.Subprocess([COMMAND] + (args or []), env=environ) - tests_run = [] - test_case = '' - test = '' - for line in p.output.split('\n'): - match = TEST_CASE_REGEX.match(line) - if match is not None: - test_case = match.group(1) - else: - match = TEST_REGEX.match(line) - if match is not None: - test = match.group(1) - tests_run.append(test_case + '.' + test) - return (tests_run, p.exit_code) + """Runs the test program and returns its exit code and a list of tests run.""" + + p = gtest_test_utils.Subprocess([COMMAND] + (args or []), env=environ) + tests_run = [] + test_case = "" + test = "" + for line in p.output.split("\n"): + match = TEST_CASE_REGEX.match(line) + if match is not None: + test_case = match.group(1) + else: + match = TEST_REGEX.match(line) + if match is not None: + test = match.group(1) + tests_run.append(test_case + "." + test) + return (tests_run, p.exit_code) def RunAndExtractDisabledBannerList(args=None): - """Runs the test program and returns tests that printed a disabled banner.""" - p = gtest_test_utils.Subprocess([COMMAND] + (args or []), env=environ) - banners_printed = [] - for line in p.output.split('\n'): - match = DISABLED_BANNER_REGEX.match(line) - if match is not None: - banners_printed.append(match.group(1)) - return banners_printed + """Runs the test program and returns tests that printed a disabled banner.""" + p = gtest_test_utils.Subprocess([COMMAND] + (args or []), env=environ) + banners_printed = [] + for line in p.output.split("\n"): + match = DISABLED_BANNER_REGEX.match(line) + if match is not None: + banners_printed.append(match.group(1)) + return banners_printed def InvokeWithModifiedEnv(extra_env, function, *args, **kwargs): - """Runs the given function and arguments in a modified environment.""" - try: - original_env = environ.copy() - environ.update(extra_env) - return function(*args, **kwargs) - finally: - environ.clear() - environ.update(original_env) + """Runs the given function and arguments in a modified environment.""" + try: + original_env = environ.copy() + environ.update(extra_env) + return function(*args, **kwargs) + finally: + environ.clear() + environ.update(original_env) def RunWithSharding(total_shards, shard_index, command): - """Runs a test program shard and returns exit code and a list of tests run.""" + """Runs a test program shard and returns exit code and a list of tests run.""" - extra_env = { - SHARD_INDEX_ENV_VAR: str(shard_index), - TOTAL_SHARDS_ENV_VAR: str(total_shards), - } - return InvokeWithModifiedEnv(extra_env, RunAndExtractTestList, command) + extra_env = { + SHARD_INDEX_ENV_VAR: str(shard_index), + TOTAL_SHARDS_ENV_VAR: str(total_shards), + } + return InvokeWithModifiedEnv(extra_env, RunAndExtractTestList, command) # The unit test. class GTestFilterUnitTest(gtest_test_utils.TestCase): - """Tests the env variable or the command line flag to filter tests.""" - - # Utilities. - - def AssertSetEqual(self, lhs, rhs): - """Asserts that two sets are equal.""" - - for elem in lhs: - self.assertTrue(elem in rhs, '%s in %s' % (elem, rhs)) - - for elem in rhs: - self.assertTrue(elem in lhs, '%s in %s' % (elem, lhs)) - - def AssertPartitionIsValid(self, set_var, list_of_sets): - """Asserts that list_of_sets is a valid partition of set_var.""" - - full_partition = [] - for slice_var in list_of_sets: - full_partition.extend(slice_var) - self.assertEqual(len(set_var), len(full_partition)) - self.assertEqual(set(set_var), set(full_partition)) - - def AdjustForParameterizedTests(self, tests_to_run): - """Adjust tests_to_run in case value parameterized tests are disabled.""" - - global param_tests_present - if not param_tests_present: - return list(set(tests_to_run) - set(PARAM_TESTS)) - else: - return tests_to_run - - def RunAndVerify(self, gtest_filter, tests_to_run): - """Checks that the binary runs correct set of tests for a given filter.""" - - tests_to_run = self.AdjustForParameterizedTests(tests_to_run) - - # First, tests using the environment variable. - - # Windows removes empty variables from the environment when passing it - # to a new process. This means it is impossible to pass an empty filter - # into a process using the environment variable. However, we can still - # test the case when the variable is not supplied (i.e., gtest_filter is - # None). - # pylint: disable=g-explicit-bool-comparison - if CAN_TEST_EMPTY_FILTER or gtest_filter != '': - SetEnvVar(FILTER_ENV_VAR, gtest_filter) - tests_run = RunAndExtractTestList()[0] - SetEnvVar(FILTER_ENV_VAR, None) - self.AssertSetEqual(tests_run, tests_to_run) - # pylint: enable=g-explicit-bool-comparison - - # Next, tests using the command line flag. - - if gtest_filter is None: - args = [] - else: - args = ['--%s=%s' % (FILTER_FLAG, gtest_filter)] - - tests_run = RunAndExtractTestList(args)[0] - self.AssertSetEqual(tests_run, tests_to_run) - - def RunAndVerifyWithSharding( - self, - gtest_filter, - total_shards, - tests_to_run, - args=None, - check_exit_0=False, - ): - """Checks that binary runs correct tests for the given filter and shard. - - Runs all shards of googletest-filter-unittest_ with the given filter, and - verifies that the right set of tests were run. The union of tests run - on each shard should be identical to tests_to_run, without duplicates. - If check_exit_0, . - - Args: - gtest_filter: A filter to apply to the tests. - total_shards: A total number of shards to split test run into. - tests_to_run: A set of tests expected to run. - args: Arguments to pass to the to the test binary. - check_exit_0: When set to a true value, make sure that all shards return - 0. - """ - - tests_to_run = self.AdjustForParameterizedTests(tests_to_run) - - # Windows removes empty variables from the environment when passing it - # to a new process. This means it is impossible to pass an empty filter - # into a process using the environment variable. However, we can still - # test the case when the variable is not supplied (i.e., gtest_filter is - # None). - # pylint: disable=g-explicit-bool-comparison - if CAN_TEST_EMPTY_FILTER or gtest_filter != '': - SetEnvVar(FILTER_ENV_VAR, gtest_filter) - partition = [] - for i in range(0, total_shards): - (tests_run, exit_code) = RunWithSharding(total_shards, i, args) - if check_exit_0: - self.assertEqual(0, exit_code) - partition.append(tests_run) - - self.AssertPartitionIsValid(tests_to_run, partition) - SetEnvVar(FILTER_ENV_VAR, None) - # pylint: enable=g-explicit-bool-comparison - - def RunAndVerifyAllowingDisabled(self, gtest_filter, tests_to_run): - """Checks that the binary runs correct set of tests for the given filter. - - Runs googletest-filter-unittest_ with the given filter, and enables - disabled tests. Verifies that the right set of tests were run. - - Args: - gtest_filter: A filter to apply to the tests. - tests_to_run: A set of tests expected to run. - """ - - tests_to_run = self.AdjustForParameterizedTests(tests_to_run) - - # Construct the command line. - args = ['--%s' % ALSO_RUN_DISABLED_TESTS_FLAG] - if gtest_filter is not None: - args.append('--%s=%s' % (FILTER_FLAG, gtest_filter)) - - tests_run = RunAndExtractTestList(args)[0] - self.AssertSetEqual(tests_run, tests_to_run) - - def setUp(self): - """Sets up test case. - - Determines whether value-parameterized tests are enabled in the binary and - sets the flags accordingly. - """ - - global param_tests_present - if param_tests_present is None: - param_tests_present = ( - PARAM_TEST_REGEX.search(RunAndReturnOutput()) is not None - ) - - def testDefaultBehavior(self): - """Tests the behavior of not specifying the filter.""" - - self.RunAndVerify(None, ACTIVE_TESTS) - - def testDefaultBehaviorWithShards(self): - """Tests the behavior without the filter, with sharding enabled.""" - - self.RunAndVerifyWithSharding(None, 1, ACTIVE_TESTS) - self.RunAndVerifyWithSharding(None, 2, ACTIVE_TESTS) - self.RunAndVerifyWithSharding(None, len(ACTIVE_TESTS) - 1, ACTIVE_TESTS) - self.RunAndVerifyWithSharding(None, len(ACTIVE_TESTS), ACTIVE_TESTS) - self.RunAndVerifyWithSharding(None, len(ACTIVE_TESTS) + 1, ACTIVE_TESTS) - - def testEmptyFilter(self): - """Tests an empty filter.""" - - self.RunAndVerify('', []) - self.RunAndVerifyWithSharding('', 1, []) - self.RunAndVerifyWithSharding('', 2, []) - - def testBadFilter(self): - """Tests a filter that matches nothing.""" - - self.RunAndVerify('BadFilter', []) - self.RunAndVerifyAllowingDisabled('BadFilter', []) - - def testBadFilterWithWarningFile(self): - """Tests the warning file when a filter that matches nothing.""" - - warning_file = os.path.join( - gtest_test_utils.GetTempDir(), 'testBadFilterWithWarningFile' - ) - extra_env = {TEST_WARNINGS_OUTPUT_FILE: warning_file} - args = ['--%s=%s' % (FILTER_FLAG, 'BadFilter')] - InvokeWithModifiedEnv(extra_env, RunAndReturnOutput, args) - with open(warning_file, 'r') as f: - warning_file_contents = f.read() - self.assertEqual( - warning_file_contents, - 'filter "BadFilter" did not match any test; no tests were run\n', - ) - - def testFullName(self): - """Tests filtering by full name.""" - - self.RunAndVerify('FooTest.Xyz', ['FooTest.Xyz']) - self.RunAndVerifyAllowingDisabled('FooTest.Xyz', ['FooTest.Xyz']) - self.RunAndVerifyWithSharding('FooTest.Xyz', 5, ['FooTest.Xyz']) - - def testUniversalFilters(self): - """Tests filters that match everything.""" - - self.RunAndVerify('*', ACTIVE_TESTS) - self.RunAndVerify('*.*', ACTIVE_TESTS) - self.RunAndVerifyWithSharding('*.*', len(ACTIVE_TESTS) - 3, ACTIVE_TESTS) - self.RunAndVerifyAllowingDisabled('*', ACTIVE_TESTS + DISABLED_TESTS) - self.RunAndVerifyAllowingDisabled('*.*', ACTIVE_TESTS + DISABLED_TESTS) - - def testFilterByTestCase(self): - """Tests filtering by test case name.""" - - self.RunAndVerify('FooTest.*', ['FooTest.Abc', 'FooTest.Xyz']) - - BAZ_TESTS = ['BazTest.TestOne', 'BazTest.TestA', 'BazTest.TestB'] - self.RunAndVerify('BazTest.*', BAZ_TESTS) - self.RunAndVerifyAllowingDisabled( - 'BazTest.*', BAZ_TESTS + ['BazTest.DISABLED_TestC'] - ) - - def testFilterByTest(self): - """Tests filtering by test name.""" - - self.RunAndVerify('*.TestOne', ['BarTest.TestOne', 'BazTest.TestOne']) - - def testFilterDisabledTests(self): - """Select only the disabled tests to run.""" + """Tests the env variable or the command line flag to filter tests.""" + + # Utilities. + + def AssertSetEqual(self, lhs, rhs): + """Asserts that two sets are equal.""" + + for elem in lhs: + self.assertTrue(elem in rhs, "%s in %s" % (elem, rhs)) + + for elem in rhs: + self.assertTrue(elem in lhs, "%s in %s" % (elem, lhs)) + + def AssertPartitionIsValid(self, set_var, list_of_sets): + """Asserts that list_of_sets is a valid partition of set_var.""" + + full_partition = [] + for slice_var in list_of_sets: + full_partition.extend(slice_var) + self.assertEqual(len(set_var), len(full_partition)) + self.assertEqual(set(set_var), set(full_partition)) + + def AdjustForParameterizedTests(self, tests_to_run): + """Adjust tests_to_run in case value parameterized tests are disabled.""" + + global param_tests_present + if not param_tests_present: + return list(set(tests_to_run) - set(PARAM_TESTS)) + else: + return tests_to_run + + def RunAndVerify(self, gtest_filter, tests_to_run): + """Checks that the binary runs correct set of tests for a given filter.""" + + tests_to_run = self.AdjustForParameterizedTests(tests_to_run) + + # First, tests using the environment variable. + + # Windows removes empty variables from the environment when passing it + # to a new process. This means it is impossible to pass an empty filter + # into a process using the environment variable. However, we can still + # test the case when the variable is not supplied (i.e., gtest_filter is + # None). + # pylint: disable=g-explicit-bool-comparison + if CAN_TEST_EMPTY_FILTER or gtest_filter != "": + SetEnvVar(FILTER_ENV_VAR, gtest_filter) + tests_run = RunAndExtractTestList()[0] + SetEnvVar(FILTER_ENV_VAR, None) + self.AssertSetEqual(tests_run, tests_to_run) + # pylint: enable=g-explicit-bool-comparison + + # Next, tests using the command line flag. + + if gtest_filter is None: + args = [] + else: + args = ["--%s=%s" % (FILTER_FLAG, gtest_filter)] + + tests_run = RunAndExtractTestList(args)[0] + self.AssertSetEqual(tests_run, tests_to_run) + + def RunAndVerifyWithSharding( + self, + gtest_filter, + total_shards, + tests_to_run, + args=None, + check_exit_0=False, + ): + """Checks that binary runs correct tests for the given filter and shard. + + Runs all shards of googletest-filter-unittest_ with the given filter, and + verifies that the right set of tests were run. The union of tests run + on each shard should be identical to tests_to_run, without duplicates. + If check_exit_0, . + + Args: + gtest_filter: A filter to apply to the tests. + total_shards: A total number of shards to split test run into. + tests_to_run: A set of tests expected to run. + args: Arguments to pass to the to the test binary. + check_exit_0: When set to a true value, make sure that all shards return + 0. + """ + + tests_to_run = self.AdjustForParameterizedTests(tests_to_run) + + # Windows removes empty variables from the environment when passing it + # to a new process. This means it is impossible to pass an empty filter + # into a process using the environment variable. However, we can still + # test the case when the variable is not supplied (i.e., gtest_filter is + # None). + # pylint: disable=g-explicit-bool-comparison + if CAN_TEST_EMPTY_FILTER or gtest_filter != "": + SetEnvVar(FILTER_ENV_VAR, gtest_filter) + partition = [] + for i in range(0, total_shards): + (tests_run, exit_code) = RunWithSharding(total_shards, i, args) + if check_exit_0: + self.assertEqual(0, exit_code) + partition.append(tests_run) + + self.AssertPartitionIsValid(tests_to_run, partition) + SetEnvVar(FILTER_ENV_VAR, None) + # pylint: enable=g-explicit-bool-comparison + + def RunAndVerifyAllowingDisabled(self, gtest_filter, tests_to_run): + """Checks that the binary runs correct set of tests for the given filter. + + Runs googletest-filter-unittest_ with the given filter, and enables + disabled tests. Verifies that the right set of tests were run. + + Args: + gtest_filter: A filter to apply to the tests. + tests_to_run: A set of tests expected to run. + """ + + tests_to_run = self.AdjustForParameterizedTests(tests_to_run) + + # Construct the command line. + args = ["--%s" % ALSO_RUN_DISABLED_TESTS_FLAG] + if gtest_filter is not None: + args.append("--%s=%s" % (FILTER_FLAG, gtest_filter)) + + tests_run = RunAndExtractTestList(args)[0] + self.AssertSetEqual(tests_run, tests_to_run) + + def setUp(self): + """Sets up test case. + + Determines whether value-parameterized tests are enabled in the binary and + sets the flags accordingly. + """ + + global param_tests_present + if param_tests_present is None: + param_tests_present = ( + PARAM_TEST_REGEX.search(RunAndReturnOutput()) is not None + ) + + def testDefaultBehavior(self): + """Tests the behavior of not specifying the filter.""" + + self.RunAndVerify(None, ACTIVE_TESTS) + + def testDefaultBehaviorWithShards(self): + """Tests the behavior without the filter, with sharding enabled.""" + + self.RunAndVerifyWithSharding(None, 1, ACTIVE_TESTS) + self.RunAndVerifyWithSharding(None, 2, ACTIVE_TESTS) + self.RunAndVerifyWithSharding(None, len(ACTIVE_TESTS) - 1, ACTIVE_TESTS) + self.RunAndVerifyWithSharding(None, len(ACTIVE_TESTS), ACTIVE_TESTS) + self.RunAndVerifyWithSharding(None, len(ACTIVE_TESTS) + 1, ACTIVE_TESTS) + + def testEmptyFilter(self): + """Tests an empty filter.""" + + self.RunAndVerify("", []) + self.RunAndVerifyWithSharding("", 1, []) + self.RunAndVerifyWithSharding("", 2, []) + + def testBadFilter(self): + """Tests a filter that matches nothing.""" + + self.RunAndVerify("BadFilter", []) + self.RunAndVerifyAllowingDisabled("BadFilter", []) + + def testBadFilterWithWarningFile(self): + """Tests the warning file when a filter that matches nothing.""" + + warning_file = os.path.join( + gtest_test_utils.GetTempDir(), "testBadFilterWithWarningFile" + ) + extra_env = {TEST_WARNINGS_OUTPUT_FILE: warning_file} + args = ["--%s=%s" % (FILTER_FLAG, "BadFilter")] + InvokeWithModifiedEnv(extra_env, RunAndReturnOutput, args) + with open(warning_file, "r") as f: + warning_file_contents = f.read() + self.assertEqual( + warning_file_contents, + 'filter "BadFilter" did not match any test; no tests were run\n', + ) + + def testFullName(self): + """Tests filtering by full name.""" + + self.RunAndVerify("FooTest.Xyz", ["FooTest.Xyz"]) + self.RunAndVerifyAllowingDisabled("FooTest.Xyz", ["FooTest.Xyz"]) + self.RunAndVerifyWithSharding("FooTest.Xyz", 5, ["FooTest.Xyz"]) + + def testUniversalFilters(self): + """Tests filters that match everything.""" + + self.RunAndVerify("*", ACTIVE_TESTS) + self.RunAndVerify("*.*", ACTIVE_TESTS) + self.RunAndVerifyWithSharding("*.*", len(ACTIVE_TESTS) - 3, ACTIVE_TESTS) + self.RunAndVerifyAllowingDisabled("*", ACTIVE_TESTS + DISABLED_TESTS) + self.RunAndVerifyAllowingDisabled("*.*", ACTIVE_TESTS + DISABLED_TESTS) + + def testFilterByTestCase(self): + """Tests filtering by test case name.""" + + self.RunAndVerify("FooTest.*", ["FooTest.Abc", "FooTest.Xyz"]) + + BAZ_TESTS = ["BazTest.TestOne", "BazTest.TestA", "BazTest.TestB"] + self.RunAndVerify("BazTest.*", BAZ_TESTS) + self.RunAndVerifyAllowingDisabled( + "BazTest.*", BAZ_TESTS + ["BazTest.DISABLED_TestC"] + ) - self.RunAndVerify('DISABLED_FoobarTest.Test1', []) - self.RunAndVerifyAllowingDisabled( - 'DISABLED_FoobarTest.Test1', ['DISABLED_FoobarTest.Test1'] - ) + def testFilterByTest(self): + """Tests filtering by test name.""" - self.RunAndVerify('*DISABLED_*', []) - self.RunAndVerifyAllowingDisabled('*DISABLED_*', DISABLED_TESTS) - - self.RunAndVerify('*.DISABLED_*', []) - self.RunAndVerifyAllowingDisabled( - '*.DISABLED_*', - [ - 'BarTest.DISABLED_TestFour', - 'BarTest.DISABLED_TestFive', - 'BazTest.DISABLED_TestC', - 'DISABLED_FoobarTest.DISABLED_Test2', - ], - ) + self.RunAndVerify("*.TestOne", ["BarTest.TestOne", "BazTest.TestOne"]) - self.RunAndVerify('DISABLED_*', []) - self.RunAndVerifyAllowingDisabled( - 'DISABLED_*', - [ - 'DISABLED_FoobarTest.Test1', - 'DISABLED_FoobarTest.DISABLED_Test2', - 'DISABLED_FoobarbazTest.TestA', - ], - ) + def testFilterDisabledTests(self): + """Select only the disabled tests to run.""" - def testWildcardInTestCaseName(self): - """Tests using wildcard in the test case name.""" - - self.RunAndVerify( - '*a*.*', - [ - 'BarTest.TestOne', - 'BarTest.TestTwo', - 'BarTest.TestThree', - 'BazTest.TestOne', - 'BazTest.TestA', - 'BazTest.TestB', - ] - + DEATH_TESTS - + PARAM_TESTS, - ) + self.RunAndVerify("DISABLED_FoobarTest.Test1", []) + self.RunAndVerifyAllowingDisabled( + "DISABLED_FoobarTest.Test1", ["DISABLED_FoobarTest.Test1"] + ) - def testWildcardInTestName(self): - """Tests using wildcard in the test name.""" + self.RunAndVerify("*DISABLED_*", []) + self.RunAndVerifyAllowingDisabled("*DISABLED_*", DISABLED_TESTS) + + self.RunAndVerify("*.DISABLED_*", []) + self.RunAndVerifyAllowingDisabled( + "*.DISABLED_*", + [ + "BarTest.DISABLED_TestFour", + "BarTest.DISABLED_TestFive", + "BazTest.DISABLED_TestC", + "DISABLED_FoobarTest.DISABLED_Test2", + ], + ) - self.RunAndVerify('*.*A*', ['FooTest.Abc', 'BazTest.TestA']) + self.RunAndVerify("DISABLED_*", []) + self.RunAndVerifyAllowingDisabled( + "DISABLED_*", + [ + "DISABLED_FoobarTest.Test1", + "DISABLED_FoobarTest.DISABLED_Test2", + "DISABLED_FoobarbazTest.TestA", + ], + ) - def testFilterWithoutDot(self): - """Tests a filter that has no '.' in it.""" + def testWildcardInTestCaseName(self): + """Tests using wildcard in the test case name.""" + + self.RunAndVerify( + "*a*.*", + [ + "BarTest.TestOne", + "BarTest.TestTwo", + "BarTest.TestThree", + "BazTest.TestOne", + "BazTest.TestA", + "BazTest.TestB", + ] + + DEATH_TESTS + + PARAM_TESTS, + ) - self.RunAndVerify( - '*z*', - [ - 'FooTest.Xyz', - 'BazTest.TestOne', - 'BazTest.TestA', - 'BazTest.TestB', - ], - ) + def testWildcardInTestName(self): + """Tests using wildcard in the test name.""" - def testTwoPatterns(self): - """Tests filters that consist of two patterns.""" + self.RunAndVerify("*.*A*", ["FooTest.Abc", "BazTest.TestA"]) - self.RunAndVerify( - 'Foo*.*:*A*', - [ - 'FooTest.Abc', - 'FooTest.Xyz', - 'BazTest.TestA', - ], - ) + def testFilterWithoutDot(self): + """Tests a filter that has no '.' in it.""" - # An empty pattern + a non-empty one - self.RunAndVerify(':*A*', ['FooTest.Abc', 'BazTest.TestA']) - - def testThreePatterns(self): - """Tests filters that consist of three patterns.""" - - self.RunAndVerify( - '*oo*:*A*:*One', - [ - 'FooTest.Abc', - 'FooTest.Xyz', - 'BarTest.TestOne', - 'BazTest.TestOne', - 'BazTest.TestA', - ], - ) + self.RunAndVerify( + "*z*", + [ + "FooTest.Xyz", + "BazTest.TestOne", + "BazTest.TestA", + "BazTest.TestB", + ], + ) - # The 2nd pattern is empty. - self.RunAndVerify( - '*oo*::*One', - [ - 'FooTest.Abc', - 'FooTest.Xyz', - 'BarTest.TestOne', - 'BazTest.TestOne', - ], - ) + def testTwoPatterns(self): + """Tests filters that consist of two patterns.""" - # The last 2 patterns are empty. - self.RunAndVerify( - '*oo*::', - [ - 'FooTest.Abc', - 'FooTest.Xyz', - ], - ) + self.RunAndVerify( + "Foo*.*:*A*", + [ + "FooTest.Abc", + "FooTest.Xyz", + "BazTest.TestA", + ], + ) - def testNegativeFilters(self): - self.RunAndVerify( - '*-BazTest.TestOne', - [ - 'FooTest.Abc', - 'FooTest.Xyz', - 'BarTest.TestOne', - 'BarTest.TestTwo', - 'BarTest.TestThree', - 'BazTest.TestA', - 'BazTest.TestB', - ] - + DEATH_TESTS - + PARAM_TESTS, - ) + # An empty pattern + a non-empty one + self.RunAndVerify(":*A*", ["FooTest.Abc", "BazTest.TestA"]) + + def testThreePatterns(self): + """Tests filters that consist of three patterns.""" + + self.RunAndVerify( + "*oo*:*A*:*One", + [ + "FooTest.Abc", + "FooTest.Xyz", + "BarTest.TestOne", + "BazTest.TestOne", + "BazTest.TestA", + ], + ) - self.RunAndVerify( - '*-FooTest.Abc:BazTest.*', - [ - 'FooTest.Xyz', - 'BarTest.TestOne', - 'BarTest.TestTwo', - 'BarTest.TestThree', - ] - + DEATH_TESTS - + PARAM_TESTS, - ) + # The 2nd pattern is empty. + self.RunAndVerify( + "*oo*::*One", + [ + "FooTest.Abc", + "FooTest.Xyz", + "BarTest.TestOne", + "BazTest.TestOne", + ], + ) - self.RunAndVerify( - 'BarTest.*-BarTest.TestOne', - [ - 'BarTest.TestTwo', - 'BarTest.TestThree', - ], - ) + # The last 2 patterns are empty. + self.RunAndVerify( + "*oo*::", + [ + "FooTest.Abc", + "FooTest.Xyz", + ], + ) - # Tests without leading '*'. - self.RunAndVerify( - '-FooTest.Abc:FooTest.Xyz:BazTest.*', - [ - 'BarTest.TestOne', - 'BarTest.TestTwo', - 'BarTest.TestThree', - ] - + DEATH_TESTS - + PARAM_TESTS, - ) + def testNegativeFilters(self): + self.RunAndVerify( + "*-BazTest.TestOne", + [ + "FooTest.Abc", + "FooTest.Xyz", + "BarTest.TestOne", + "BarTest.TestTwo", + "BarTest.TestThree", + "BazTest.TestA", + "BazTest.TestB", + ] + + DEATH_TESTS + + PARAM_TESTS, + ) - # Value parameterized tests. - self.RunAndVerify('*/*', PARAM_TESTS) - - # Value parameterized tests filtering by the sequence name. - self.RunAndVerify( - 'SeqP/*', - [ - 'SeqP/ParamTest.TestX/0', - 'SeqP/ParamTest.TestX/1', - 'SeqP/ParamTest.TestY/0', - 'SeqP/ParamTest.TestY/1', - ], - ) + self.RunAndVerify( + "*-FooTest.Abc:BazTest.*", + [ + "FooTest.Xyz", + "BarTest.TestOne", + "BarTest.TestTwo", + "BarTest.TestThree", + ] + + DEATH_TESTS + + PARAM_TESTS, + ) - # Value parameterized tests filtering by the test name. - self.RunAndVerify( - '*/0', - [ - 'SeqP/ParamTest.TestX/0', - 'SeqP/ParamTest.TestY/0', - 'SeqQ/ParamTest.TestX/0', - 'SeqQ/ParamTest.TestY/0', - ], - ) + self.RunAndVerify( + "BarTest.*-BarTest.TestOne", + [ + "BarTest.TestTwo", + "BarTest.TestThree", + ], + ) - def testFlagOverridesEnvVar(self): - """Tests that the filter flag overrides the filtering env. variable.""" + # Tests without leading '*'. + self.RunAndVerify( + "-FooTest.Abc:FooTest.Xyz:BazTest.*", + [ + "BarTest.TestOne", + "BarTest.TestTwo", + "BarTest.TestThree", + ] + + DEATH_TESTS + + PARAM_TESTS, + ) - SetEnvVar(FILTER_ENV_VAR, 'Foo*') - args = ['--%s=%s' % (FILTER_FLAG, '*One')] - tests_run = RunAndExtractTestList(args)[0] - SetEnvVar(FILTER_ENV_VAR, None) + # Value parameterized tests. + self.RunAndVerify("*/*", PARAM_TESTS) + + # Value parameterized tests filtering by the sequence name. + self.RunAndVerify( + "SeqP/*", + [ + "SeqP/ParamTest.TestX/0", + "SeqP/ParamTest.TestX/1", + "SeqP/ParamTest.TestY/0", + "SeqP/ParamTest.TestY/1", + ], + ) - self.AssertSetEqual(tests_run, ['BarTest.TestOne', 'BazTest.TestOne']) + # Value parameterized tests filtering by the test name. + self.RunAndVerify( + "*/0", + [ + "SeqP/ParamTest.TestX/0", + "SeqP/ParamTest.TestY/0", + "SeqQ/ParamTest.TestX/0", + "SeqQ/ParamTest.TestY/0", + ], + ) - def testShardStatusFileIsCreated(self): - """Tests that the shard file is created if specified in the environment.""" + def testFlagOverridesEnvVar(self): + """Tests that the filter flag overrides the filtering env. variable.""" - shard_status_file = os.path.join( - gtest_test_utils.GetTempDir(), 'shard_status_file' - ) - self.assertTrue(not os.path.exists(shard_status_file)) + SetEnvVar(FILTER_ENV_VAR, "Foo*") + args = ["--%s=%s" % (FILTER_FLAG, "*One")] + tests_run = RunAndExtractTestList(args)[0] + SetEnvVar(FILTER_ENV_VAR, None) - extra_env = {SHARD_STATUS_FILE_ENV_VAR: shard_status_file} - try: - InvokeWithModifiedEnv(extra_env, RunAndReturnOutput) - finally: - self.assertTrue(os.path.exists(shard_status_file)) - os.remove(shard_status_file) + self.AssertSetEqual(tests_run, ["BarTest.TestOne", "BazTest.TestOne"]) - def testShardStatusFileIsCreatedWithListTests(self): - """Tests that the shard file is created with the "list_tests" flag.""" + def testShardStatusFileIsCreated(self): + """Tests that the shard file is created if specified in the environment.""" - shard_status_file = os.path.join( - gtest_test_utils.GetTempDir(), 'shard_status_file2' - ) - self.assertTrue(not os.path.exists(shard_status_file)) + shard_status_file = os.path.join( + gtest_test_utils.GetTempDir(), "shard_status_file" + ) + self.assertTrue(not os.path.exists(shard_status_file)) - extra_env = {SHARD_STATUS_FILE_ENV_VAR: shard_status_file} - try: - output = InvokeWithModifiedEnv( - extra_env, RunAndReturnOutput, [LIST_TESTS_FLAG] - ) - finally: - # This assertion ensures that Google Test enumerated the tests as - # opposed to running them. - self.assertTrue( - '[==========]' not in output, - ( - 'Unexpected output during test enumeration.\n' - 'Please ensure that LIST_TESTS_FLAG is assigned the\n' - 'correct flag value for listing Google Test tests.' - ), - ) - - self.assertTrue(os.path.exists(shard_status_file)) - os.remove(shard_status_file) - - def testDisabledBanner(self): - """Tests that the disabled banner prints only tests that match filter.""" - make_filter = lambda s: ['--%s=%s' % (FILTER_FLAG, s)] - - banners = RunAndExtractDisabledBannerList(make_filter('*')) - self.AssertSetEqual( - banners, - [ - 'BarTest.DISABLED_TestFour', - 'BarTest.DISABLED_TestFive', - 'BazTest.DISABLED_TestC', - ], - ) + extra_env = {SHARD_STATUS_FILE_ENV_VAR: shard_status_file} + try: + InvokeWithModifiedEnv(extra_env, RunAndReturnOutput) + finally: + self.assertTrue(os.path.exists(shard_status_file)) + os.remove(shard_status_file) - banners = RunAndExtractDisabledBannerList(make_filter('Bar*')) - self.AssertSetEqual( - banners, ['BarTest.DISABLED_TestFour', 'BarTest.DISABLED_TestFive'] - ) + def testShardStatusFileIsCreatedWithListTests(self): + """Tests that the shard file is created with the "list_tests" flag.""" - banners = RunAndExtractDisabledBannerList(make_filter('*-Bar*')) - self.AssertSetEqual(banners, ['BazTest.DISABLED_TestC']) - - if SUPPORTS_DEATH_TESTS: - - def testShardingWorksWithDeathTests(self): - """Tests integration with death tests and sharding.""" - - gtest_filter = 'HasDeathTest.*:SeqP/*' - expected_tests = [ - 'HasDeathTest.Test1', - 'HasDeathTest.Test2', - 'SeqP/ParamTest.TestX/0', - 'SeqP/ParamTest.TestX/1', - 'SeqP/ParamTest.TestY/0', - 'SeqP/ParamTest.TestY/1', - ] - - for flag in [ - '--gtest_death_test_style=threadsafe', - '--gtest_death_test_style=fast', - ]: - self.RunAndVerifyWithSharding( - gtest_filter, 3, expected_tests, check_exit_0=True, args=[flag] + shard_status_file = os.path.join( + gtest_test_utils.GetTempDir(), "shard_status_file2" ) - self.RunAndVerifyWithSharding( - gtest_filter, 5, expected_tests, check_exit_0=True, args=[flag] + self.assertTrue(not os.path.exists(shard_status_file)) + + extra_env = {SHARD_STATUS_FILE_ENV_VAR: shard_status_file} + try: + output = InvokeWithModifiedEnv( + extra_env, RunAndReturnOutput, [LIST_TESTS_FLAG] + ) + finally: + # This assertion ensures that Google Test enumerated the tests as + # opposed to running them. + self.assertTrue( + "[==========]" not in output, + ( + "Unexpected output during test enumeration.\n" + "Please ensure that LIST_TESTS_FLAG is assigned the\n" + "correct flag value for listing Google Test tests." + ), + ) + + self.assertTrue(os.path.exists(shard_status_file)) + os.remove(shard_status_file) + + def testDisabledBanner(self): + """Tests that the disabled banner prints only tests that match filter.""" + make_filter = lambda s: ["--%s=%s" % (FILTER_FLAG, s)] + + banners = RunAndExtractDisabledBannerList(make_filter("*")) + self.AssertSetEqual( + banners, + [ + "BarTest.DISABLED_TestFour", + "BarTest.DISABLED_TestFive", + "BazTest.DISABLED_TestC", + ], ) + banners = RunAndExtractDisabledBannerList(make_filter("Bar*")) + self.AssertSetEqual( + banners, ["BarTest.DISABLED_TestFour", "BarTest.DISABLED_TestFive"] + ) -if __name__ == '__main__': - gtest_test_utils.Main() + banners = RunAndExtractDisabledBannerList(make_filter("*-Bar*")) + self.AssertSetEqual(banners, ["BazTest.DISABLED_TestC"]) + + if SUPPORTS_DEATH_TESTS: + + def testShardingWorksWithDeathTests(self): + """Tests integration with death tests and sharding.""" + + gtest_filter = "HasDeathTest.*:SeqP/*" + expected_tests = [ + "HasDeathTest.Test1", + "HasDeathTest.Test2", + "SeqP/ParamTest.TestX/0", + "SeqP/ParamTest.TestX/1", + "SeqP/ParamTest.TestY/0", + "SeqP/ParamTest.TestY/1", + ] + + for flag in [ + "--gtest_death_test_style=threadsafe", + "--gtest_death_test_style=fast", + ]: + self.RunAndVerifyWithSharding( + gtest_filter, 3, expected_tests, check_exit_0=True, args=[flag] + ) + self.RunAndVerifyWithSharding( + gtest_filter, 5, expected_tests, check_exit_0=True, args=[flag] + ) + + +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-global-environment-unittest.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-global-environment-unittest.py index bd73a2e1..243476cf 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-global-environment-unittest.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-global-environment-unittest.py @@ -40,102 +40,102 @@ def RunAndReturnOutput(args=None): - """Runs the test program and returns its output.""" + """Runs the test program and returns its output.""" - return gtest_test_utils.Subprocess( - [ - gtest_test_utils.GetTestExecutablePath( - 'googletest-global-environment-unittest_' - ) - ] - + (args or []) - ).output + return gtest_test_utils.Subprocess( + [ + gtest_test_utils.GetTestExecutablePath( + "googletest-global-environment-unittest_" + ) + ] + + (args or []) + ).output class GTestGlobalEnvironmentUnitTest(gtest_test_utils.TestCase): - """Tests global test environment failures.""" - - def testEnvironmentSetUpFails(self): - """Tests the behavior of not specifying the fail_fast.""" - - # Run the test. - txt = RunAndReturnOutput() - - # We should see the text of the global environment setup error. - self.assertIn('Canned environment setup error', txt) - - # Our test should have been skipped due to the error, and not treated as a - # pass. - self.assertIn('[ SKIPPED ] 1 test', txt) - self.assertIn('[ PASSED ] 0 tests', txt) - - # The test case shouldn't have been run. - self.assertNotIn('Unexpected call', txt) - - def testEnvironmentSetUpAndTornDownForEachRepeat(self): - """Tests the behavior of test environments and gtest_repeat.""" - - # When --gtest_recreate_environments_when_repeating is true, the global test - # environment should be set up and torn down for each iteration. - txt = RunAndReturnOutput([ - '--gtest_repeat=2', - '--gtest_recreate_environments_when_repeating=true', - ]) - - expected_pattern = ( - '(.|\n)*' - r'Repeating all tests \(iteration 1\)' - '(.|\n)*' - 'Global test environment set-up.' - '(.|\n)*' - 'SomeTest.DoesFoo' - '(.|\n)*' - 'Global test environment tear-down' - '(.|\n)*' - r'Repeating all tests \(iteration 2\)' - '(.|\n)*' - 'Global test environment set-up.' - '(.|\n)*' - 'SomeTest.DoesFoo' - '(.|\n)*' - 'Global test environment tear-down' - '(.|\n)*' - ) - self.assertRegex(txt, expected_pattern) - - def testEnvironmentSetUpAndTornDownOnce(self): - """Tests environment and --gtest_recreate_environments_when_repeating.""" - - # By default the environment should only be set up and torn down once, at - # the start and end of the test respectively. - txt = RunAndReturnOutput( - [ - '--gtest_repeat=2', - ] - ) - - expected_pattern = ( - '(.|\n)*' - r'Repeating all tests \(iteration 1\)' - '(.|\n)*' - 'Global test environment set-up.' - '(.|\n)*' - 'SomeTest.DoesFoo' - '(.|\n)*' - r'Repeating all tests \(iteration 2\)' - '(.|\n)*' - 'SomeTest.DoesFoo' - '(.|\n)*' - 'Global test environment tear-down' - '(.|\n)*' - ) - self.assertRegex(txt, expected_pattern) - - self.assertEqual(len(re.findall('Global test environment set-up', txt)), 1) - self.assertEqual( - len(re.findall('Global test environment tear-down', txt)), 1 - ) - - -if __name__ == '__main__': - gtest_test_utils.Main() + """Tests global test environment failures.""" + + def testEnvironmentSetUpFails(self): + """Tests the behavior of not specifying the fail_fast.""" + + # Run the test. + txt = RunAndReturnOutput() + + # We should see the text of the global environment setup error. + self.assertIn("Canned environment setup error", txt) + + # Our test should have been skipped due to the error, and not treated as a + # pass. + self.assertIn("[ SKIPPED ] 1 test", txt) + self.assertIn("[ PASSED ] 0 tests", txt) + + # The test case shouldn't have been run. + self.assertNotIn("Unexpected call", txt) + + def testEnvironmentSetUpAndTornDownForEachRepeat(self): + """Tests the behavior of test environments and gtest_repeat.""" + + # When --gtest_recreate_environments_when_repeating is true, the global test + # environment should be set up and torn down for each iteration. + txt = RunAndReturnOutput( + [ + "--gtest_repeat=2", + "--gtest_recreate_environments_when_repeating=true", + ] + ) + + expected_pattern = ( + "(.|\n)*" + r"Repeating all tests \(iteration 1\)" + "(.|\n)*" + "Global test environment set-up." + "(.|\n)*" + "SomeTest.DoesFoo" + "(.|\n)*" + "Global test environment tear-down" + "(.|\n)*" + r"Repeating all tests \(iteration 2\)" + "(.|\n)*" + "Global test environment set-up." + "(.|\n)*" + "SomeTest.DoesFoo" + "(.|\n)*" + "Global test environment tear-down" + "(.|\n)*" + ) + self.assertRegex(txt, expected_pattern) + + def testEnvironmentSetUpAndTornDownOnce(self): + """Tests environment and --gtest_recreate_environments_when_repeating.""" + + # By default the environment should only be set up and torn down once, at + # the start and end of the test respectively. + txt = RunAndReturnOutput( + [ + "--gtest_repeat=2", + ] + ) + + expected_pattern = ( + "(.|\n)*" + r"Repeating all tests \(iteration 1\)" + "(.|\n)*" + "Global test environment set-up." + "(.|\n)*" + "SomeTest.DoesFoo" + "(.|\n)*" + r"Repeating all tests \(iteration 2\)" + "(.|\n)*" + "SomeTest.DoesFoo" + "(.|\n)*" + "Global test environment tear-down" + "(.|\n)*" + ) + self.assertRegex(txt, expected_pattern) + + self.assertEqual(len(re.findall("Global test environment set-up", txt)), 1) + self.assertEqual(len(re.findall("Global test environment tear-down", txt)), 1) + + +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-json-outfiles-test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-json-outfiles-test.py index 5626004e..bb279d6b 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-json-outfiles-test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-json-outfiles-test.py @@ -35,146 +35,154 @@ from googletest.test import gtest_json_test_utils from googletest.test import gtest_test_utils -GTEST_OUTPUT_SUBDIR = 'json_outfiles' -GTEST_OUTPUT_1_TEST = 'gtest_xml_outfile1_test_' -GTEST_OUTPUT_2_TEST = 'gtest_xml_outfile2_test_' +GTEST_OUTPUT_SUBDIR = "json_outfiles" +GTEST_OUTPUT_1_TEST = "gtest_xml_outfile1_test_" +GTEST_OUTPUT_2_TEST = "gtest_xml_outfile2_test_" EXPECTED_1 = { - 'tests': 1, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'name': 'AllTests', - 'testsuites': [{ - 'name': 'PropertyOne', - 'tests': 1, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [{ - 'name': 'TestSomeProperties', - 'file': 'gtest_xml_outfile1_test_.cc', - 'line': 41, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'PropertyOne', - 'SetUpProp': '1', - 'TestSomeProperty': '1', - 'TearDownProp': '1', - }], - }], + "tests": 1, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "name": "AllTests", + "testsuites": [ + { + "name": "PropertyOne", + "tests": 1, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "TestSomeProperties", + "file": "gtest_xml_outfile1_test_.cc", + "line": 41, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "PropertyOne", + "SetUpProp": "1", + "TestSomeProperty": "1", + "TearDownProp": "1", + } + ], + } + ], } EXPECTED_2 = { - 'tests': 1, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'name': 'AllTests', - 'testsuites': [{ - 'name': 'PropertyTwo', - 'tests': 1, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [{ - 'name': 'TestInt64ConvertibleProperties', - 'file': 'gtest_xml_outfile2_test_.cc', - 'line': 43, - 'status': 'RUN', - 'result': 'COMPLETED', - 'timestamp': '*', - 'time': '*', - 'classname': 'PropertyTwo', - 'SetUpProp': '2', - 'TestFloatProperty': '3.25', - 'TestDoubleProperty': '4.75', - 'TestSizetProperty': '5', - 'TestBoolProperty': 'true', - 'TestCharProperty': 'A', - 'TestInt16Property': '6', - 'TestInt32Property': '7', - 'TestInt64Property': '8', - 'TestEnumProperty': '9', - 'TestAtomicIntProperty': '10', - 'TearDownProp': '2', - }], - }], + "tests": 1, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "name": "AllTests", + "testsuites": [ + { + "name": "PropertyTwo", + "tests": 1, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "TestInt64ConvertibleProperties", + "file": "gtest_xml_outfile2_test_.cc", + "line": 43, + "status": "RUN", + "result": "COMPLETED", + "timestamp": "*", + "time": "*", + "classname": "PropertyTwo", + "SetUpProp": "2", + "TestFloatProperty": "3.25", + "TestDoubleProperty": "4.75", + "TestSizetProperty": "5", + "TestBoolProperty": "true", + "TestCharProperty": "A", + "TestInt16Property": "6", + "TestInt32Property": "7", + "TestInt64Property": "8", + "TestEnumProperty": "9", + "TestAtomicIntProperty": "10", + "TearDownProp": "2", + } + ], + } + ], } class GTestJsonOutFilesTest(gtest_test_utils.TestCase): - """Unit test for Google Test's JSON output functionality.""" - - def setUp(self): - # We want the trailing '/' that the last "" provides in os.path.join, for - # telling Google Test to create an output directory instead of a single file - # for xml output. - self.output_dir_ = os.path.join( - gtest_test_utils.GetTempDir(), GTEST_OUTPUT_SUBDIR, '' - ) - self.DeleteFilesAndDir() - - def tearDown(self): - self.DeleteFilesAndDir() - - def DeleteFilesAndDir(self): - try: - os.remove(os.path.join(self.output_dir_, GTEST_OUTPUT_1_TEST + '.json')) - except os.error: - pass - try: - os.remove(os.path.join(self.output_dir_, GTEST_OUTPUT_2_TEST + '.json')) - except os.error: - pass - try: - os.rmdir(self.output_dir_) - except os.error: - pass - - def testOutfile1(self): - self._TestOutFile(GTEST_OUTPUT_1_TEST, EXPECTED_1) - - def testOutfile2(self): - self._TestOutFile(GTEST_OUTPUT_2_TEST, EXPECTED_2) - - def _TestOutFile(self, test_name, expected): - gtest_prog_path = gtest_test_utils.GetTestExecutablePath(test_name) - command = [gtest_prog_path, '--gtest_output=json:%s' % self.output_dir_] - p = gtest_test_utils.Subprocess( - command, working_dir=gtest_test_utils.GetTempDir() - ) - self.assertTrue(p.exited) - self.assertEqual(0, p.exit_code) - - output_file_name1 = test_name + '.json' - output_file1 = os.path.join(self.output_dir_, output_file_name1) - output_file_name2 = 'lt-' + output_file_name1 - output_file2 = os.path.join(self.output_dir_, output_file_name2) - self.assertTrue( - os.path.isfile(output_file1) or os.path.isfile(output_file2), - output_file1, - ) - - if os.path.isfile(output_file1): - with open(output_file1) as f: - actual = json.load(f) - else: - with open(output_file2) as f: - actual = json.load(f) - self.assertEqual(expected, gtest_json_test_utils.normalize(actual)) - - -if __name__ == '__main__': - os.environ['GTEST_STACK_TRACE_DEPTH'] = '0' - gtest_test_utils.Main() + """Unit test for Google Test's JSON output functionality.""" + + def setUp(self): + # We want the trailing '/' that the last "" provides in os.path.join, for + # telling Google Test to create an output directory instead of a single file + # for xml output. + self.output_dir_ = os.path.join( + gtest_test_utils.GetTempDir(), GTEST_OUTPUT_SUBDIR, "" + ) + self.DeleteFilesAndDir() + + def tearDown(self): + self.DeleteFilesAndDir() + + def DeleteFilesAndDir(self): + try: + os.remove(os.path.join(self.output_dir_, GTEST_OUTPUT_1_TEST + ".json")) + except os.error: + pass + try: + os.remove(os.path.join(self.output_dir_, GTEST_OUTPUT_2_TEST + ".json")) + except os.error: + pass + try: + os.rmdir(self.output_dir_) + except os.error: + pass + + def testOutfile1(self): + self._TestOutFile(GTEST_OUTPUT_1_TEST, EXPECTED_1) + + def testOutfile2(self): + self._TestOutFile(GTEST_OUTPUT_2_TEST, EXPECTED_2) + + def _TestOutFile(self, test_name, expected): + gtest_prog_path = gtest_test_utils.GetTestExecutablePath(test_name) + command = [gtest_prog_path, "--gtest_output=json:%s" % self.output_dir_] + p = gtest_test_utils.Subprocess( + command, working_dir=gtest_test_utils.GetTempDir() + ) + self.assertTrue(p.exited) + self.assertEqual(0, p.exit_code) + + output_file_name1 = test_name + ".json" + output_file1 = os.path.join(self.output_dir_, output_file_name1) + output_file_name2 = "lt-" + output_file_name1 + output_file2 = os.path.join(self.output_dir_, output_file_name2) + self.assertTrue( + os.path.isfile(output_file1) or os.path.isfile(output_file2), + output_file1, + ) + + if os.path.isfile(output_file1): + with open(output_file1) as f: + actual = json.load(f) + else: + with open(output_file2) as f: + actual = json.load(f) + self.assertEqual(expected, gtest_json_test_utils.normalize(actual)) + + +if __name__ == "__main__": + os.environ["GTEST_STACK_TRACE_DEPTH"] = "0" + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-json-output-unittest.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-json-output-unittest.py index c75051c8..b521a83c 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-json-output-unittest.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-json-output-unittest.py @@ -40,596 +40,617 @@ from googletest.test import gtest_json_test_utils from googletest.test import gtest_test_utils -GTEST_FILTER_FLAG = '--gtest_filter' -GTEST_LIST_TESTS_FLAG = '--gtest_list_tests' -GTEST_OUTPUT_FLAG = '--gtest_output' -GTEST_DEFAULT_OUTPUT_FILE = 'test_detail.json' -GTEST_PROGRAM_NAME = 'gtest_xml_output_unittest_' +GTEST_FILTER_FLAG = "--gtest_filter" +GTEST_LIST_TESTS_FLAG = "--gtest_list_tests" +GTEST_OUTPUT_FLAG = "--gtest_output" +GTEST_DEFAULT_OUTPUT_FILE = "test_detail.json" +GTEST_PROGRAM_NAME = "gtest_xml_output_unittest_" # The flag indicating stacktraces are not supported -NO_STACKTRACE_SUPPORT_FLAG = '--no_stacktrace_support' +NO_STACKTRACE_SUPPORT_FLAG = "--no_stacktrace_support" SUPPORTS_STACK_TRACES = NO_STACKTRACE_SUPPORT_FLAG not in sys.argv if SUPPORTS_STACK_TRACES: - STACK_TRACE_TEMPLATE = '\nStack trace:\n*' + STACK_TRACE_TEMPLATE = "\nStack trace:\n*" else: - STACK_TRACE_TEMPLATE = '\n' + STACK_TRACE_TEMPLATE = "\n" EXPECTED_NON_EMPTY = { - 'tests': 28, - 'failures': 5, - 'disabled': 2, - 'errors': 0, - 'timestamp': '*', - 'time': '*', - 'ad_hoc_property': '42', - 'name': 'AllTests', - 'testsuites': [ + "tests": 28, + "failures": 5, + "disabled": 2, + "errors": 0, + "timestamp": "*", + "time": "*", + "ad_hoc_property": "42", + "name": "AllTests", + "testsuites": [ { - 'name': 'SuccessfulTest', - 'tests': 1, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [{ - 'name': 'Succeeds', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 53, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'SuccessfulTest', - }], + "name": "SuccessfulTest", + "tests": 1, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "Succeeds", + "file": "gtest_xml_output_unittest_.cc", + "line": 53, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "SuccessfulTest", + } + ], }, { - 'name': 'FailedTest', - 'tests': 1, - 'failures': 1, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [{ - 'name': 'Fails', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 61, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'FailedTest', - 'failures': [{ - 'failure': ( - 'gtest_xml_output_unittest_.cc:*\n' - 'Expected equality of these values:\n' - ' 1\n 2' - + STACK_TRACE_TEMPLATE - ), - 'type': '', - }], - }], + "name": "FailedTest", + "tests": 1, + "failures": 1, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "Fails", + "file": "gtest_xml_output_unittest_.cc", + "line": 61, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "FailedTest", + "failures": [ + { + "failure": ( + "gtest_xml_output_unittest_.cc:*\n" + "Expected equality of these values:\n" + " 1\n 2" + STACK_TRACE_TEMPLATE + ), + "type": "", + } + ], + } + ], }, { - 'name': 'DisabledTest', - 'tests': 1, - 'failures': 0, - 'disabled': 1, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [{ - 'name': 'DISABLED_test_not_run', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 68, - 'status': 'NOTRUN', - 'result': 'SUPPRESSED', - 'time': '*', - 'timestamp': '*', - 'classname': 'DisabledTest', - }], + "name": "DisabledTest", + "tests": 1, + "failures": 0, + "disabled": 1, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "DISABLED_test_not_run", + "file": "gtest_xml_output_unittest_.cc", + "line": 68, + "status": "NOTRUN", + "result": "SUPPRESSED", + "time": "*", + "timestamp": "*", + "classname": "DisabledTest", + } + ], }, { - 'name': 'SkippedTest', - 'tests': 3, - 'failures': 1, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [ - { - 'name': 'Skipped', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 75, - 'status': 'RUN', - 'result': 'SKIPPED', - 'time': '*', - 'timestamp': '*', - 'classname': 'SkippedTest', - 'skipped': [ - {'message': 'gtest_xml_output_unittest_.cc:*\n\n'} - ], + "name": "SkippedTest", + "tests": 3, + "failures": 1, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "Skipped", + "file": "gtest_xml_output_unittest_.cc", + "line": 75, + "status": "RUN", + "result": "SKIPPED", + "time": "*", + "timestamp": "*", + "classname": "SkippedTest", + "skipped": [{"message": "gtest_xml_output_unittest_.cc:*\n\n"}], }, { - 'name': 'SkippedWithMessage', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 79, - 'status': 'RUN', - 'result': 'SKIPPED', - 'time': '*', - 'timestamp': '*', - 'classname': 'SkippedTest', - 'skipped': [{ - 'message': ( - 'gtest_xml_output_unittest_.cc:*\n' - 'It is good practice to tell why you skip a test.\n' - ) - }], + "name": "SkippedWithMessage", + "file": "gtest_xml_output_unittest_.cc", + "line": 79, + "status": "RUN", + "result": "SKIPPED", + "time": "*", + "timestamp": "*", + "classname": "SkippedTest", + "skipped": [ + { + "message": ( + "gtest_xml_output_unittest_.cc:*\n" + "It is good practice to tell why you skip a test.\n" + ) + } + ], }, { - 'name': 'SkippedAfterFailure', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 83, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'SkippedTest', - 'failures': [{ - 'failure': ( - 'gtest_xml_output_unittest_.cc:*\n' - 'Expected equality of these values:\n' - ' 1\n 2' - + STACK_TRACE_TEMPLATE - ), - 'type': '', - }], - 'skipped': [{ - 'message': ( - 'gtest_xml_output_unittest_.cc:*\n' - 'It is good practice to tell why you skip a test.\n' - ) - }], + "name": "SkippedAfterFailure", + "file": "gtest_xml_output_unittest_.cc", + "line": 83, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "SkippedTest", + "failures": [ + { + "failure": ( + "gtest_xml_output_unittest_.cc:*\n" + "Expected equality of these values:\n" + " 1\n 2" + STACK_TRACE_TEMPLATE + ), + "type": "", + } + ], + "skipped": [ + { + "message": ( + "gtest_xml_output_unittest_.cc:*\n" + "It is good practice to tell why you skip a test.\n" + ) + } + ], }, ], }, { - 'name': 'MixedResultTest', - 'tests': 3, - 'failures': 1, - 'disabled': 1, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [ - { - 'name': 'Succeeds', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 88, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'MixedResultTest', + "name": "MixedResultTest", + "tests": 3, + "failures": 1, + "disabled": 1, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "Succeeds", + "file": "gtest_xml_output_unittest_.cc", + "line": 88, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "MixedResultTest", }, { - 'name': 'Fails', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 93, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'MixedResultTest', - 'failures': [ + "name": "Fails", + "file": "gtest_xml_output_unittest_.cc", + "line": 93, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "MixedResultTest", + "failures": [ { - 'failure': ( - 'gtest_xml_output_unittest_.cc:*\n' - 'Expected equality of these values:\n' - ' 1\n 2' - + STACK_TRACE_TEMPLATE + "failure": ( + "gtest_xml_output_unittest_.cc:*\n" + "Expected equality of these values:\n" + " 1\n 2" + STACK_TRACE_TEMPLATE ), - 'type': '', + "type": "", }, { - 'failure': ( - 'gtest_xml_output_unittest_.cc:*\n' - 'Expected equality of these values:\n' - ' 2\n 3' - + STACK_TRACE_TEMPLATE + "failure": ( + "gtest_xml_output_unittest_.cc:*\n" + "Expected equality of these values:\n" + " 2\n 3" + STACK_TRACE_TEMPLATE ), - 'type': '', + "type": "", }, ], }, { - 'name': 'DISABLED_test', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 98, - 'status': 'NOTRUN', - 'result': 'SUPPRESSED', - 'time': '*', - 'timestamp': '*', - 'classname': 'MixedResultTest', + "name": "DISABLED_test", + "file": "gtest_xml_output_unittest_.cc", + "line": 98, + "status": "NOTRUN", + "result": "SUPPRESSED", + "time": "*", + "timestamp": "*", + "classname": "MixedResultTest", }, ], }, { - 'name': 'XmlQuotingTest', - 'tests': 1, - 'failures': 1, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [{ - 'name': 'OutputsCData', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 102, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'XmlQuotingTest', - 'failures': [{ - 'failure': ( - 'gtest_xml_output_unittest_.cc:*\n' - 'Failed\nXML output: ' - '' - + STACK_TRACE_TEMPLATE - ), - 'type': '', - }], - }], + "name": "XmlQuotingTest", + "tests": 1, + "failures": 1, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "OutputsCData", + "file": "gtest_xml_output_unittest_.cc", + "line": 102, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "XmlQuotingTest", + "failures": [ + { + "failure": ( + "gtest_xml_output_unittest_.cc:*\n" + 'Failed\nXML output: ' + "" + + STACK_TRACE_TEMPLATE + ), + "type": "", + } + ], + } + ], }, { - 'name': 'InvalidCharactersTest', - 'tests': 1, - 'failures': 1, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [{ - 'name': 'InvalidCharactersInMessage', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 109, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'InvalidCharactersTest', - 'failures': [{ - 'failure': ( - 'gtest_xml_output_unittest_.cc:*\n' - 'Failed\nInvalid characters in brackets' - ' [\x01\x02]' - + STACK_TRACE_TEMPLATE - ), - 'type': '', - }], - }], + "name": "InvalidCharactersTest", + "tests": 1, + "failures": 1, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "InvalidCharactersInMessage", + "file": "gtest_xml_output_unittest_.cc", + "line": 109, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "InvalidCharactersTest", + "failures": [ + { + "failure": ( + "gtest_xml_output_unittest_.cc:*\n" + "Failed\nInvalid characters in brackets" + " [\x01\x02]" + STACK_TRACE_TEMPLATE + ), + "type": "", + } + ], + } + ], }, { - 'name': 'PropertyRecordingTest', - 'tests': 4, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'SetUpTestSuite': 'yes', - 'SetUpTestSuite (with whitespace)': 'yes and yes', - 'TearDownTestSuite': 'aye', - 'TearDownTestSuite (with whitespace)': 'aye and aye', - 'testsuite': [ - { - 'name': 'OneProperty', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 125, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'PropertyRecordingTest', - 'key_1': '1', + "name": "PropertyRecordingTest", + "tests": 4, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "SetUpTestSuite": "yes", + "SetUpTestSuite (with whitespace)": "yes and yes", + "TearDownTestSuite": "aye", + "TearDownTestSuite (with whitespace)": "aye and aye", + "testsuite": [ + { + "name": "OneProperty", + "file": "gtest_xml_output_unittest_.cc", + "line": 125, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "PropertyRecordingTest", + "key_1": "1", }, { - 'name': 'IntValuedProperty', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 129, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'PropertyRecordingTest', - 'key_int': '1', + "name": "IntValuedProperty", + "file": "gtest_xml_output_unittest_.cc", + "line": 129, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "PropertyRecordingTest", + "key_int": "1", }, { - 'name': 'ThreeProperties', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 133, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'PropertyRecordingTest', - 'key_1': '1', - 'key_2': '2', - 'key_3': '3', + "name": "ThreeProperties", + "file": "gtest_xml_output_unittest_.cc", + "line": 133, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "PropertyRecordingTest", + "key_1": "1", + "key_2": "2", + "key_3": "3", }, { - 'name': 'TwoValuesForOneKeyUsesLastValue', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 139, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'PropertyRecordingTest', - 'key_1': '2', + "name": "TwoValuesForOneKeyUsesLastValue", + "file": "gtest_xml_output_unittest_.cc", + "line": 139, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "PropertyRecordingTest", + "key_1": "2", }, ], }, { - 'name': 'NoFixtureTest', - 'tests': 3, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [ - { - 'name': 'RecordProperty', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 144, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'NoFixtureTest', - 'key': '1', + "name": "NoFixtureTest", + "tests": 3, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "RecordProperty", + "file": "gtest_xml_output_unittest_.cc", + "line": 144, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "NoFixtureTest", + "key": "1", }, { - 'name': 'ExternalUtilityThatCallsRecordIntValuedProperty', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 157, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'NoFixtureTest', - 'key_for_utility_int': '1', + "name": "ExternalUtilityThatCallsRecordIntValuedProperty", + "file": "gtest_xml_output_unittest_.cc", + "line": 157, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "NoFixtureTest", + "key_for_utility_int": "1", }, { - 'name': ( - 'ExternalUtilityThatCallsRecordStringValuedProperty' - ), - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 161, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'NoFixtureTest', - 'key_for_utility_string': '1', + "name": ("ExternalUtilityThatCallsRecordStringValuedProperty"), + "file": "gtest_xml_output_unittest_.cc", + "line": 161, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "NoFixtureTest", + "key_for_utility_string": "1", }, ], }, { - 'name': 'SetupFailTest', - 'tests': 1, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [ - { - 'name': 'NoopPassingTest', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 172, - 'status': 'RUN', - 'result': 'SKIPPED', - 'timestamp': '*', - 'time': '*', - 'classname': 'SetupFailTest', - 'skipped': [ - {'message': 'gtest_xml_output_unittest_.cc:*\n'} - ], + "name": "SetupFailTest", + "tests": 1, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "NoopPassingTest", + "file": "gtest_xml_output_unittest_.cc", + "line": 172, + "status": "RUN", + "result": "SKIPPED", + "timestamp": "*", + "time": "*", + "classname": "SetupFailTest", + "skipped": [{"message": "gtest_xml_output_unittest_.cc:*\n"}], }, { - 'name': '', - 'status': 'RUN', - 'result': 'COMPLETED', - 'timestamp': '*', - 'time': '*', - 'classname': '', - 'failures': [{ - 'failure': ( - 'gtest_xml_output_unittest_.cc:*\nExpected equality' - ' of these values:\n 1\n 2' - + STACK_TRACE_TEMPLATE - ), - 'type': '', - }], + "name": "", + "status": "RUN", + "result": "COMPLETED", + "timestamp": "*", + "time": "*", + "classname": "", + "failures": [ + { + "failure": ( + "gtest_xml_output_unittest_.cc:*\nExpected equality" + " of these values:\n 1\n 2" + STACK_TRACE_TEMPLATE + ), + "type": "", + } + ], }, ], }, { - 'name': 'TearDownFailTest', - 'tests': 1, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'timestamp': '*', - 'time': '*', - 'testsuite': [ - { - 'name': 'NoopPassingTest', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 179, - 'status': 'RUN', - 'result': 'COMPLETED', - 'timestamp': '*', - 'time': '*', - 'classname': 'TearDownFailTest', + "name": "TearDownFailTest", + "tests": 1, + "failures": 0, + "disabled": 0, + "errors": 0, + "timestamp": "*", + "time": "*", + "testsuite": [ + { + "name": "NoopPassingTest", + "file": "gtest_xml_output_unittest_.cc", + "line": 179, + "status": "RUN", + "result": "COMPLETED", + "timestamp": "*", + "time": "*", + "classname": "TearDownFailTest", }, { - 'name': '', - 'status': 'RUN', - 'result': 'COMPLETED', - 'timestamp': '*', - 'time': '*', - 'classname': '', - 'failures': [{ - 'failure': ( - 'gtest_xml_output_unittest_.cc:*\nExpected equality' - ' of these values:\n 1\n 2' - + STACK_TRACE_TEMPLATE - ), - 'type': '', - }], + "name": "", + "status": "RUN", + "result": "COMPLETED", + "timestamp": "*", + "time": "*", + "classname": "", + "failures": [ + { + "failure": ( + "gtest_xml_output_unittest_.cc:*\nExpected equality" + " of these values:\n 1\n 2" + STACK_TRACE_TEMPLATE + ), + "type": "", + } + ], }, ], }, { - 'name': 'TypedTest/0', - 'tests': 1, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [{ - 'name': 'HasTypeParamAttribute', - 'type_param': 'int', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 193, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'TypedTest/0', - }], + "name": "TypedTest/0", + "tests": 1, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "HasTypeParamAttribute", + "type_param": "int", + "file": "gtest_xml_output_unittest_.cc", + "line": 193, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "TypedTest/0", + } + ], }, { - 'name': 'TypedTest/1', - 'tests': 1, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [{ - 'name': 'HasTypeParamAttribute', - 'type_param': 'long', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 193, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'TypedTest/1', - }], + "name": "TypedTest/1", + "tests": 1, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "HasTypeParamAttribute", + "type_param": "long", + "file": "gtest_xml_output_unittest_.cc", + "line": 193, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "TypedTest/1", + } + ], }, { - 'name': 'Single/TypeParameterizedTestSuite/0', - 'tests': 1, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [{ - 'name': 'HasTypeParamAttribute', - 'type_param': 'int', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 200, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'Single/TypeParameterizedTestSuite/0', - }], + "name": "Single/TypeParameterizedTestSuite/0", + "tests": 1, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "HasTypeParamAttribute", + "type_param": "int", + "file": "gtest_xml_output_unittest_.cc", + "line": 200, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "Single/TypeParameterizedTestSuite/0", + } + ], }, { - 'name': 'Single/TypeParameterizedTestSuite/1', - 'tests': 1, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [{ - 'name': 'HasTypeParamAttribute', - 'type_param': 'long', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 200, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'Single/TypeParameterizedTestSuite/1', - }], + "name": "Single/TypeParameterizedTestSuite/1", + "tests": 1, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "HasTypeParamAttribute", + "type_param": "long", + "file": "gtest_xml_output_unittest_.cc", + "line": 200, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "Single/TypeParameterizedTestSuite/1", + } + ], }, { - 'name': 'Single/ValueParamTest', - 'tests': 4, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [ - { - 'name': 'HasValueParamAttribute/0', - 'value_param': '33', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 184, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'Single/ValueParamTest', + "name": "Single/ValueParamTest", + "tests": 4, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "HasValueParamAttribute/0", + "value_param": "33", + "file": "gtest_xml_output_unittest_.cc", + "line": 184, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "Single/ValueParamTest", }, { - 'name': 'HasValueParamAttribute/1', - 'value_param': '42', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 184, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'Single/ValueParamTest', + "name": "HasValueParamAttribute/1", + "value_param": "42", + "file": "gtest_xml_output_unittest_.cc", + "line": 184, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "Single/ValueParamTest", }, { - 'name': 'AnotherTestThatHasValueParamAttribute/0', - 'value_param': '33', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 185, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'Single/ValueParamTest', + "name": "AnotherTestThatHasValueParamAttribute/0", + "value_param": "33", + "file": "gtest_xml_output_unittest_.cc", + "line": 185, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "Single/ValueParamTest", }, { - 'name': 'AnotherTestThatHasValueParamAttribute/1', - 'value_param': '42', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 185, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'Single/ValueParamTest', + "name": "AnotherTestThatHasValueParamAttribute/1", + "value_param": "42", + "file": "gtest_xml_output_unittest_.cc", + "line": 185, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "Single/ValueParamTest", }, ], }, @@ -637,76 +658,85 @@ } EXPECTED_FILTERED = { - 'tests': 1, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'name': 'AllTests', - 'ad_hoc_property': '42', - 'testsuites': [{ - 'name': 'SuccessfulTest', - 'tests': 1, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [{ - 'name': 'Succeeds', - 'file': 'gtest_xml_output_unittest_.cc', - 'line': 53, - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': 'SuccessfulTest', - }], - }], + "tests": 1, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "name": "AllTests", + "ad_hoc_property": "42", + "testsuites": [ + { + "name": "SuccessfulTest", + "tests": 1, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "Succeeds", + "file": "gtest_xml_output_unittest_.cc", + "line": 53, + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "SuccessfulTest", + } + ], + } + ], } EXPECTED_NO_TEST = { - 'tests': 0, - 'failures': 0, - 'disabled': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'name': 'AllTests', - 'testsuites': [{ - 'name': 'NonTestSuiteFailure', - 'tests': 1, - 'failures': 1, - 'disabled': 0, - 'skipped': 0, - 'errors': 0, - 'time': '*', - 'timestamp': '*', - 'testsuite': [{ - 'name': '', - 'status': 'RUN', - 'result': 'COMPLETED', - 'time': '*', - 'timestamp': '*', - 'classname': '', - 'failures': [{ - 'failure': ( - 'gtest_no_test_unittest.cc:*\n' - 'Expected equality of these values:\n' - ' 1\n 2' - + STACK_TRACE_TEMPLATE - ), - 'type': '', - }], - }], - }], + "tests": 0, + "failures": 0, + "disabled": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "name": "AllTests", + "testsuites": [ + { + "name": "NonTestSuiteFailure", + "tests": 1, + "failures": 1, + "disabled": 0, + "skipped": 0, + "errors": 0, + "time": "*", + "timestamp": "*", + "testsuite": [ + { + "name": "", + "status": "RUN", + "result": "COMPLETED", + "time": "*", + "timestamp": "*", + "classname": "", + "failures": [ + { + "failure": ( + "gtest_no_test_unittest.cc:*\n" + "Expected equality of these values:\n" + " 1\n 2" + STACK_TRACE_TEMPLATE + ), + "type": "", + } + ], + } + ], + } + ], } GTEST_PROGRAM_PATH = gtest_test_utils.GetTestExecutablePath(GTEST_PROGRAM_NAME) SUPPORTS_TYPED_TESTS = ( - 'TypedTest' + "TypedTest" in gtest_test_utils.Subprocess( [GTEST_PROGRAM_PATH, GTEST_LIST_TESTS_FLAG], capture_stderr=False ).output @@ -714,201 +744,201 @@ class GTestJsonOutputUnitTest(gtest_test_utils.TestCase): - """Unit test for Google Test's JSON output functionality.""" - - # This test currently breaks on platforms that do not support typed and - # type-parameterized tests, so we don't run it under them. - if SUPPORTS_TYPED_TESTS: - - def testNonEmptyJsonOutput(self): - """Verifies JSON output for a Google Test binary with non-empty output. - - Runs a test program that generates a non-empty JSON output, and - tests that the JSON output is expected. - """ - self._TestJsonOutput(GTEST_PROGRAM_NAME, EXPECTED_NON_EMPTY, 1) - - def testNoTestJsonOutput(self): - """Verifies JSON output for a Google Test binary without actual tests. - - Runs a test program that generates an JSON output for a binary with no - tests, and tests that the JSON output is expected. - """ - - self._TestJsonOutput('gtest_no_test_unittest', EXPECTED_NO_TEST, 0) - - def testTimestampValue(self): - """Checks whether the timestamp attribute in the JSON output is valid. - - Runs a test program that generates an empty JSON output, and checks if - the timestamp attribute in the testsuites tag is valid. - """ - actual = self._GetJsonOutput('gtest_no_test_unittest', [], 0) - date_time_str = actual['timestamp'] - # datetime.strptime() is only available in Python 2.5+ so we have to - # parse the expected datetime manually. - match = re.match(r'(\d+)-(\d\d)-(\d\d)T(\d\d):(\d\d):(\d\d)', date_time_str) - self.assertTrue( - re.match, - 'JSON datettime string %s has incorrect format' % date_time_str, - ) - date_time_from_json = datetime.datetime( - year=int(match.group(1)), - month=int(match.group(2)), - day=int(match.group(3)), - hour=int(match.group(4)), - minute=int(match.group(5)), - second=int(match.group(6)), - ) - - time_delta = abs(datetime.datetime.now() - date_time_from_json) - # timestamp value should be near the current local time - self.assertTrue( - time_delta < datetime.timedelta(seconds=600), - 'time_delta is %s' % time_delta, - ) - - def testDefaultOutputFile(self): - """Verifies the default output file name. - - Confirms that Google Test produces an JSON output file with the expected - default name if no name is explicitly specified. - """ - output_file = os.path.join( - gtest_test_utils.GetTempDir(), GTEST_DEFAULT_OUTPUT_FILE - ) - gtest_prog_path = gtest_test_utils.GetTestExecutablePath( - 'gtest_no_test_unittest' - ) - try: - os.remove(output_file) - except OSError: - e = sys.exc_info()[1] - if e.errno != errno.ENOENT: - raise - - p = gtest_test_utils.Subprocess( - [gtest_prog_path, '%s=json' % GTEST_OUTPUT_FLAG], - working_dir=gtest_test_utils.GetTempDir(), - ) - self.assertTrue(p.exited) - self.assertEqual(0, p.exit_code) - self.assertTrue(os.path.isfile(output_file)) - - def testSuppressedJsonOutput(self): - """Verifies that no JSON output is generated. - - Tests that no JSON file is generated if the default JSON listener is - shut down before RUN_ALL_TESTS is invoked. - """ - - json_path = os.path.join( - gtest_test_utils.GetTempDir(), GTEST_PROGRAM_NAME + 'out.json' - ) - if os.path.isfile(json_path): - os.remove(json_path) - - command = [ - GTEST_PROGRAM_PATH, - '%s=json:%s' % (GTEST_OUTPUT_FLAG, json_path), - '--shut_down_xml', - ] - p = gtest_test_utils.Subprocess(command) - if p.terminated_by_signal: - # p.signal is available only if p.terminated_by_signal is True. - self.assertFalse( - p.terminated_by_signal, - '%s was killed by signal %d' % (GTEST_PROGRAM_NAME, p.signal), - ) - else: - self.assertTrue(p.exited) - self.assertEqual( - 1, - p.exit_code, - "'%s' exited with code %s, which doesn't match " - 'the expected exit code %s.' % (command, p.exit_code, 1), - ) - - self.assertTrue(not os.path.isfile(json_path)) - - def testFilteredTestJsonOutput(self): - """Verifies JSON output when a filter is applied. - - Runs a test program that executes only some tests and verifies that - non-selected tests do not show up in the JSON output. - """ - - self._TestJsonOutput( - GTEST_PROGRAM_NAME, - EXPECTED_FILTERED, - 0, - extra_args=['%s=SuccessfulTest.*' % GTEST_FILTER_FLAG], - ) - - def _GetJsonOutput(self, gtest_prog_name, extra_args, expected_exit_code): - """Returns the JSON output generated by running the program gtest_prog_name. - - Furthermore, the program's exit code must be expected_exit_code. - - Args: - gtest_prog_name: Google Test binary name. - extra_args: extra arguments to binary invocation. - expected_exit_code: program's exit code. - """ - json_path = os.path.join( - gtest_test_utils.GetTempDir(), gtest_prog_name + 'out.json' - ) - gtest_prog_path = gtest_test_utils.GetTestExecutablePath(gtest_prog_name) - - command = [ - gtest_prog_path, - '%s=json:%s' % (GTEST_OUTPUT_FLAG, json_path), - ] + extra_args - p = gtest_test_utils.Subprocess(command) - if p.terminated_by_signal: - self.assertTrue( - False, '%s was killed by signal %d' % (gtest_prog_name, p.signal) - ) - else: - self.assertTrue(p.exited) - self.assertEqual( - expected_exit_code, - p.exit_code, - "'%s' exited with code %s, which doesn't match " - 'the expected exit code %s.' - % (command, p.exit_code, expected_exit_code), - ) - with open(json_path) as f: - actual = json.load(f) - return actual - - def _TestJsonOutput( - self, gtest_prog_name, expected, expected_exit_code, extra_args=None - ): - """Checks the JSON output generated by the Google Test binary. - - Asserts that the JSON document generated by running the program - gtest_prog_name matches expected_json, a string containing another - JSON document. Furthermore, the program's exit code must be - expected_exit_code. - - Args: - gtest_prog_name: Google Test binary name. - expected: expected output. - expected_exit_code: program's exit code. - extra_args: extra arguments to binary invocation. - """ - - actual = self._GetJsonOutput( - gtest_prog_name, extra_args or [], expected_exit_code - ) - self.assertEqual(expected, gtest_json_test_utils.normalize(actual)) - - -if __name__ == '__main__': - if NO_STACKTRACE_SUPPORT_FLAG in sys.argv: - # unittest.main() can't handle unknown flags - sys.argv.remove(NO_STACKTRACE_SUPPORT_FLAG) - - os.environ['GTEST_STACK_TRACE_DEPTH'] = '1' - gtest_test_utils.Main() + """Unit test for Google Test's JSON output functionality.""" + + # This test currently breaks on platforms that do not support typed and + # type-parameterized tests, so we don't run it under them. + if SUPPORTS_TYPED_TESTS: + + def testNonEmptyJsonOutput(self): + """Verifies JSON output for a Google Test binary with non-empty output. + + Runs a test program that generates a non-empty JSON output, and + tests that the JSON output is expected. + """ + self._TestJsonOutput(GTEST_PROGRAM_NAME, EXPECTED_NON_EMPTY, 1) + + def testNoTestJsonOutput(self): + """Verifies JSON output for a Google Test binary without actual tests. + + Runs a test program that generates an JSON output for a binary with no + tests, and tests that the JSON output is expected. + """ + + self._TestJsonOutput("gtest_no_test_unittest", EXPECTED_NO_TEST, 0) + + def testTimestampValue(self): + """Checks whether the timestamp attribute in the JSON output is valid. + + Runs a test program that generates an empty JSON output, and checks if + the timestamp attribute in the testsuites tag is valid. + """ + actual = self._GetJsonOutput("gtest_no_test_unittest", [], 0) + date_time_str = actual["timestamp"] + # datetime.strptime() is only available in Python 2.5+ so we have to + # parse the expected datetime manually. + match = re.match(r"(\d+)-(\d\d)-(\d\d)T(\d\d):(\d\d):(\d\d)", date_time_str) + self.assertTrue( + re.match, + "JSON datettime string %s has incorrect format" % date_time_str, + ) + date_time_from_json = datetime.datetime( + year=int(match.group(1)), + month=int(match.group(2)), + day=int(match.group(3)), + hour=int(match.group(4)), + minute=int(match.group(5)), + second=int(match.group(6)), + ) + + time_delta = abs(datetime.datetime.now() - date_time_from_json) + # timestamp value should be near the current local time + self.assertTrue( + time_delta < datetime.timedelta(seconds=600), + "time_delta is %s" % time_delta, + ) + + def testDefaultOutputFile(self): + """Verifies the default output file name. + + Confirms that Google Test produces an JSON output file with the expected + default name if no name is explicitly specified. + """ + output_file = os.path.join( + gtest_test_utils.GetTempDir(), GTEST_DEFAULT_OUTPUT_FILE + ) + gtest_prog_path = gtest_test_utils.GetTestExecutablePath( + "gtest_no_test_unittest" + ) + try: + os.remove(output_file) + except OSError: + e = sys.exc_info()[1] + if e.errno != errno.ENOENT: + raise + + p = gtest_test_utils.Subprocess( + [gtest_prog_path, "%s=json" % GTEST_OUTPUT_FLAG], + working_dir=gtest_test_utils.GetTempDir(), + ) + self.assertTrue(p.exited) + self.assertEqual(0, p.exit_code) + self.assertTrue(os.path.isfile(output_file)) + + def testSuppressedJsonOutput(self): + """Verifies that no JSON output is generated. + + Tests that no JSON file is generated if the default JSON listener is + shut down before RUN_ALL_TESTS is invoked. + """ + + json_path = os.path.join( + gtest_test_utils.GetTempDir(), GTEST_PROGRAM_NAME + "out.json" + ) + if os.path.isfile(json_path): + os.remove(json_path) + + command = [ + GTEST_PROGRAM_PATH, + "%s=json:%s" % (GTEST_OUTPUT_FLAG, json_path), + "--shut_down_xml", + ] + p = gtest_test_utils.Subprocess(command) + if p.terminated_by_signal: + # p.signal is available only if p.terminated_by_signal is True. + self.assertFalse( + p.terminated_by_signal, + "%s was killed by signal %d" % (GTEST_PROGRAM_NAME, p.signal), + ) + else: + self.assertTrue(p.exited) + self.assertEqual( + 1, + p.exit_code, + "'%s' exited with code %s, which doesn't match " + "the expected exit code %s." % (command, p.exit_code, 1), + ) + + self.assertTrue(not os.path.isfile(json_path)) + + def testFilteredTestJsonOutput(self): + """Verifies JSON output when a filter is applied. + + Runs a test program that executes only some tests and verifies that + non-selected tests do not show up in the JSON output. + """ + + self._TestJsonOutput( + GTEST_PROGRAM_NAME, + EXPECTED_FILTERED, + 0, + extra_args=["%s=SuccessfulTest.*" % GTEST_FILTER_FLAG], + ) + + def _GetJsonOutput(self, gtest_prog_name, extra_args, expected_exit_code): + """Returns the JSON output generated by running the program gtest_prog_name. + + Furthermore, the program's exit code must be expected_exit_code. + + Args: + gtest_prog_name: Google Test binary name. + extra_args: extra arguments to binary invocation. + expected_exit_code: program's exit code. + """ + json_path = os.path.join( + gtest_test_utils.GetTempDir(), gtest_prog_name + "out.json" + ) + gtest_prog_path = gtest_test_utils.GetTestExecutablePath(gtest_prog_name) + + command = [ + gtest_prog_path, + "%s=json:%s" % (GTEST_OUTPUT_FLAG, json_path), + ] + extra_args + p = gtest_test_utils.Subprocess(command) + if p.terminated_by_signal: + self.assertTrue( + False, "%s was killed by signal %d" % (gtest_prog_name, p.signal) + ) + else: + self.assertTrue(p.exited) + self.assertEqual( + expected_exit_code, + p.exit_code, + "'%s' exited with code %s, which doesn't match " + "the expected exit code %s." + % (command, p.exit_code, expected_exit_code), + ) + with open(json_path) as f: + actual = json.load(f) + return actual + + def _TestJsonOutput( + self, gtest_prog_name, expected, expected_exit_code, extra_args=None + ): + """Checks the JSON output generated by the Google Test binary. + + Asserts that the JSON document generated by running the program + gtest_prog_name matches expected_json, a string containing another + JSON document. Furthermore, the program's exit code must be + expected_exit_code. + + Args: + gtest_prog_name: Google Test binary name. + expected: expected output. + expected_exit_code: program's exit code. + extra_args: extra arguments to binary invocation. + """ + + actual = self._GetJsonOutput( + gtest_prog_name, extra_args or [], expected_exit_code + ) + self.assertEqual(expected, gtest_json_test_utils.normalize(actual)) + + +if __name__ == "__main__": + if NO_STACKTRACE_SUPPORT_FLAG in sys.argv: + # unittest.main() can't handle unknown flags + sys.argv.remove(NO_STACKTRACE_SUPPORT_FLAG) + + os.environ["GTEST_STACK_TRACE_DEPTH"] = "1" + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-list-tests-unittest.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-list-tests-unittest.py index 977e57f0..a69bc71f 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-list-tests-unittest.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-list-tests-unittest.py @@ -43,12 +43,10 @@ # Constants. # The command line flag for enabling/disabling listing all tests. -LIST_TESTS_FLAG = 'gtest_list_tests' +LIST_TESTS_FLAG = "gtest_list_tests" # Path to the googletest-list-tests-unittest_ program. -EXE_PATH = gtest_test_utils.GetTestExecutablePath( - 'googletest-list-tests-unittest_' -) +EXE_PATH = gtest_test_utils.GetTestExecutablePath("googletest-list-tests-unittest_") # The expected output when running googletest-list-tests-unittest_ with # --gtest_list_tests @@ -118,108 +116,106 @@ def Run(args): - """Runs googletest-list-tests-unittest_ and returns the list of tests printed.""" + """Runs googletest-list-tests-unittest_ and returns the list of tests printed.""" - return gtest_test_utils.Subprocess( - [EXE_PATH] + args, capture_stderr=False - ).output + return gtest_test_utils.Subprocess([EXE_PATH] + args, capture_stderr=False).output # The unit test. class GTestListTestsUnitTest(gtest_test_utils.TestCase): - """Tests using the --gtest_list_tests flag to list all tests.""" - - def RunAndVerify(self, flag_value, expected_output_re, other_flag): - """Run googletest-list-tests-unittest_ and verify the output. - - Runs googletest-list-tests-unittest_ and verifies that it prints - the correct tests. - - Args: - flag_value: value of the --gtest_list_tests flag; None if the flag - should not be present. - expected_output_re: regular expression that matches the expected output - after running command; - other_flag: a different flag to be passed to command along with - gtest_list_tests; None if the flag should not be present. - """ - - if flag_value is None: - flag = '' - flag_expression = 'not set' - elif flag_value == '0': - flag = '--%s=0' % LIST_TESTS_FLAG - flag_expression = '0' - else: - flag = '--%s' % LIST_TESTS_FLAG - flag_expression = '1' - - args = [flag] - - if other_flag is not None: - args += [other_flag] - - output = Run(args) - - if expected_output_re: - self.assertTrue( - expected_output_re.match(output), - 'when %s is %s, the output of "%s" is "%s",\n' - 'which does not match regex "%s"' - % ( - LIST_TESTS_FLAG, - flag_expression, - ' '.join(args), - output, - expected_output_re.pattern, - ), - ) - else: - self.assertTrue( - not EXPECTED_OUTPUT_NO_FILTER_RE.match(output), - 'when %s is %s, the output of "%s" is "%s"' - % (LIST_TESTS_FLAG, flag_expression, ' '.join(args), output), - ) - - def testDefaultBehavior(self): - """Tests the behavior of the default mode.""" - - self.RunAndVerify(flag_value=None, expected_output_re=None, other_flag=None) - - def testFlag(self): - """Tests using the --gtest_list_tests flag.""" - - self.RunAndVerify(flag_value='0', expected_output_re=None, other_flag=None) - self.RunAndVerify( - flag_value='1', - expected_output_re=EXPECTED_OUTPUT_NO_FILTER_RE, - other_flag=None, - ) - - def testOverrideNonFilterFlags(self): - """Tests that --gtest_list_tests overrides the non-filter flags.""" - - self.RunAndVerify( - flag_value='1', - expected_output_re=EXPECTED_OUTPUT_NO_FILTER_RE, - other_flag='--gtest_break_on_failure', - ) - - def testWithFilterFlags(self): - """Tests that --gtest_list_tests takes into account the filter flags. - - Tests that --gtest_list_tests takes into account the - --gtest_filter flag. - """ - - self.RunAndVerify( - flag_value='1', - expected_output_re=EXPECTED_OUTPUT_FILTER_FOO_RE, - other_flag='--gtest_filter=Foo*', - ) - - -if __name__ == '__main__': - gtest_test_utils.Main() + """Tests using the --gtest_list_tests flag to list all tests.""" + + def RunAndVerify(self, flag_value, expected_output_re, other_flag): + """Run googletest-list-tests-unittest_ and verify the output. + + Runs googletest-list-tests-unittest_ and verifies that it prints + the correct tests. + + Args: + flag_value: value of the --gtest_list_tests flag; None if the flag + should not be present. + expected_output_re: regular expression that matches the expected output + after running command; + other_flag: a different flag to be passed to command along with + gtest_list_tests; None if the flag should not be present. + """ + + if flag_value is None: + flag = "" + flag_expression = "not set" + elif flag_value == "0": + flag = "--%s=0" % LIST_TESTS_FLAG + flag_expression = "0" + else: + flag = "--%s" % LIST_TESTS_FLAG + flag_expression = "1" + + args = [flag] + + if other_flag is not None: + args += [other_flag] + + output = Run(args) + + if expected_output_re: + self.assertTrue( + expected_output_re.match(output), + 'when %s is %s, the output of "%s" is "%s",\n' + 'which does not match regex "%s"' + % ( + LIST_TESTS_FLAG, + flag_expression, + " ".join(args), + output, + expected_output_re.pattern, + ), + ) + else: + self.assertTrue( + not EXPECTED_OUTPUT_NO_FILTER_RE.match(output), + 'when %s is %s, the output of "%s" is "%s"' + % (LIST_TESTS_FLAG, flag_expression, " ".join(args), output), + ) + + def testDefaultBehavior(self): + """Tests the behavior of the default mode.""" + + self.RunAndVerify(flag_value=None, expected_output_re=None, other_flag=None) + + def testFlag(self): + """Tests using the --gtest_list_tests flag.""" + + self.RunAndVerify(flag_value="0", expected_output_re=None, other_flag=None) + self.RunAndVerify( + flag_value="1", + expected_output_re=EXPECTED_OUTPUT_NO_FILTER_RE, + other_flag=None, + ) + + def testOverrideNonFilterFlags(self): + """Tests that --gtest_list_tests overrides the non-filter flags.""" + + self.RunAndVerify( + flag_value="1", + expected_output_re=EXPECTED_OUTPUT_NO_FILTER_RE, + other_flag="--gtest_break_on_failure", + ) + + def testWithFilterFlags(self): + """Tests that --gtest_list_tests takes into account the filter flags. + + Tests that --gtest_list_tests takes into account the + --gtest_filter flag. + """ + + self.RunAndVerify( + flag_value="1", + expected_output_re=EXPECTED_OUTPUT_FILTER_FOO_RE, + other_flag="--gtest_filter=Foo*", + ) + + +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-output-test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-output-test.py index 6d80d532..49f58edf 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-output-test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-output-test.py @@ -46,47 +46,47 @@ # The flag for generating the golden file -GENGOLDEN_FLAG = '--gengolden' -CATCH_EXCEPTIONS_ENV_VAR_NAME = 'GTEST_CATCH_EXCEPTIONS' +GENGOLDEN_FLAG = "--gengolden" +CATCH_EXCEPTIONS_ENV_VAR_NAME = "GTEST_CATCH_EXCEPTIONS" # The flag indicating stacktraces are not supported -NO_STACKTRACE_SUPPORT_FLAG = '--no_stacktrace_support' +NO_STACKTRACE_SUPPORT_FLAG = "--no_stacktrace_support" -IS_LINUX = os.name == 'posix' and os.uname()[0] == 'Linux' -IS_WINDOWS = os.name == 'nt' +IS_LINUX = os.name == "posix" and os.uname()[0] == "Linux" +IS_WINDOWS = os.name == "nt" -GOLDEN_NAME = 'googletest-output-test-golden-lin.txt' +GOLDEN_NAME = "googletest-output-test-golden-lin.txt" -PROGRAM_PATH = gtest_test_utils.GetTestExecutablePath('googletest-output-test_') +PROGRAM_PATH = gtest_test_utils.GetTestExecutablePath("googletest-output-test_") # At least one command we exercise must not have the # 'internal_skip_environment_and_ad_hoc_tests' argument. -COMMAND_LIST_TESTS = ({}, [PROGRAM_PATH, '--gtest_list_tests']) -COMMAND_WITH_COLOR = ({}, [PROGRAM_PATH, '--gtest_color=yes']) +COMMAND_LIST_TESTS = ({}, [PROGRAM_PATH, "--gtest_list_tests"]) +COMMAND_WITH_COLOR = ({}, [PROGRAM_PATH, "--gtest_color=yes"]) COMMAND_WITH_TIME = ( {}, [ PROGRAM_PATH, - '--gtest_print_time', - 'internal_skip_environment_and_ad_hoc_tests', - '--gtest_filter=FatalFailureTest.*:LoggingTest.*', + "--gtest_print_time", + "internal_skip_environment_and_ad_hoc_tests", + "--gtest_filter=FatalFailureTest.*:LoggingTest.*", ], ) COMMAND_WITH_DISABLED = ( {}, [ PROGRAM_PATH, - '--gtest_also_run_disabled_tests', - 'internal_skip_environment_and_ad_hoc_tests', - '--gtest_filter=*DISABLED_*', + "--gtest_also_run_disabled_tests", + "internal_skip_environment_and_ad_hoc_tests", + "--gtest_filter=*DISABLED_*", ], ) COMMAND_WITH_SHARDING = ( - {'GTEST_SHARD_INDEX': '1', 'GTEST_TOTAL_SHARDS': '2'}, + {"GTEST_SHARD_INDEX": "1", "GTEST_TOTAL_SHARDS": "2"}, [ PROGRAM_PATH, - 'internal_skip_environment_and_ad_hoc_tests', - '--gtest_filter=PassingTest.*', + "internal_skip_environment_and_ad_hoc_tests", + "--gtest_filter=PassingTest.*", ], ) @@ -94,183 +94,181 @@ def ToUnixLineEnding(s): - """Changes all Windows/Mac line endings in s to UNIX line endings.""" + """Changes all Windows/Mac line endings in s to UNIX line endings.""" - return s.replace('\r\n', '\n').replace('\r', '\n') + return s.replace("\r\n", "\n").replace("\r", "\n") def RemoveLocations(test_output): - """Removes all file location info from a Google Test program's output. + """Removes all file location info from a Google Test program's output. - Args: - test_output: the output of a Google Test program. + Args: + test_output: the output of a Google Test program. - Returns: - output with all file location info (in the form of - 'DIRECTORY/FILE_NAME:LINE_NUMBER: 'or - 'DIRECTORY\\FILE_NAME(LINE_NUMBER): ') replaced by - 'FILE_NAME:#: '. - """ + Returns: + output with all file location info (in the form of + 'DIRECTORY/FILE_NAME:LINE_NUMBER: 'or + 'DIRECTORY\\FILE_NAME(LINE_NUMBER): ') replaced by + 'FILE_NAME:#: '. + """ - return re.sub( - r'.*[/\\]((googletest-output-test_|gtest).cc)(\:\d+|\(\d+\))\: ', - r'\1:#: ', - test_output, - ) + return re.sub( + r".*[/\\]((googletest-output-test_|gtest).cc)(\:\d+|\(\d+\))\: ", + r"\1:#: ", + test_output, + ) def RemoveStackTraceDetails(output): - """Removes all stack traces from a Google Test program's output.""" + """Removes all stack traces from a Google Test program's output.""" - # *? means "find the shortest string that matches". - return re.sub( - r'Stack trace:(.|\n)*?\n\n', 'Stack trace: (omitted)\n\n', output - ) + # *? means "find the shortest string that matches". + return re.sub(r"Stack trace:(.|\n)*?\n\n", "Stack trace: (omitted)\n\n", output) def RemoveStackTraces(output): - """Removes all traces of stack traces from a Google Test program's output.""" + """Removes all traces of stack traces from a Google Test program's output.""" - # *? means "find the shortest string that matches". - return re.sub(r'Stack trace:(.|\n)*?\n', '', output) + # *? means "find the shortest string that matches". + return re.sub(r"Stack trace:(.|\n)*?\n", "", output) def RemoveTime(output): - """Removes all time information from a Google Test program's output.""" + """Removes all time information from a Google Test program's output.""" - return re.sub(r'\(\d+ ms', '(? ms', output) + return re.sub(r"\(\d+ ms", "(? ms", output) def RemoveTypeInfoDetails(test_output): - """Removes compiler-specific type info from Google Test program's output. + """Removes compiler-specific type info from Google Test program's output. - Args: - test_output: the output of a Google Test program. + Args: + test_output: the output of a Google Test program. - Returns: - output with type information normalized to canonical form. - """ + Returns: + output with type information normalized to canonical form. + """ - # some compilers output the name of type 'unsigned int' as 'unsigned' - return re.sub(r'unsigned int', 'unsigned', test_output) + # some compilers output the name of type 'unsigned int' as 'unsigned' + return re.sub(r"unsigned int", "unsigned", test_output) def NormalizeToCurrentPlatform(test_output): - """Normalizes platform specific output details for easier comparison.""" + """Normalizes platform specific output details for easier comparison.""" - if IS_WINDOWS: - # Removes the color information that is not present on Windows. - test_output = re.sub('\x1b\\[(0;3\d)?m', '', test_output) - # Changes failure message headers into the Windows format. - test_output = re.sub(r': Failure\n', r': error: ', test_output) - # Changes file(line_number) to file:line_number. - test_output = re.sub(r'((\w|\.)+)\((\d+)\):', r'\1:\3:', test_output) + if IS_WINDOWS: + # Removes the color information that is not present on Windows. + test_output = re.sub("\x1b\\[(0;3\d)?m", "", test_output) + # Changes failure message headers into the Windows format. + test_output = re.sub(r": Failure\n", r": error: ", test_output) + # Changes file(line_number) to file:line_number. + test_output = re.sub(r"((\w|\.)+)\((\d+)\):", r"\1:\3:", test_output) - return test_output + return test_output def RemoveTestCounts(output): - """Removes test counts from a Google Test program's output.""" + """Removes test counts from a Google Test program's output.""" - output = re.sub(r'\d+ tests?, listed below', '? tests, listed below', output) - output = re.sub(r'\d+ FAILED TESTS', '? FAILED TESTS', output) - output = re.sub( - r'\d+ tests? from \d+ test cases?', '? tests from ? test cases', output - ) - output = re.sub(r'\d+ tests? from ([a-zA-Z_])', r'? tests from \1', output) - return re.sub(r'\d+ tests?\.', '? tests.', output) + output = re.sub(r"\d+ tests?, listed below", "? tests, listed below", output) + output = re.sub(r"\d+ FAILED TESTS", "? FAILED TESTS", output) + output = re.sub( + r"\d+ tests? from \d+ test cases?", "? tests from ? test cases", output + ) + output = re.sub(r"\d+ tests? from ([a-zA-Z_])", r"? tests from \1", output) + return re.sub(r"\d+ tests?\.", "? tests.", output) def RemoveMatchingTests(test_output, pattern): - """Removes output of specified tests from a Google Test program's output. + """Removes output of specified tests from a Google Test program's output. - This function strips not only the beginning and the end of a test but also - all output in between. + This function strips not only the beginning and the end of a test but also + all output in between. - Args: - test_output: A string containing the test output. - pattern: A regex string that matches names of test cases or tests - to remove. + Args: + test_output: A string containing the test output. + pattern: A regex string that matches names of test cases or tests + to remove. - Returns: - Contents of test_output with tests whose names match pattern removed. - """ + Returns: + Contents of test_output with tests whose names match pattern removed. + """ - test_output = re.sub( - r'.*\[ RUN \] .*%s(.|\n)*?\[( FAILED | OK )\] .*%s.*\n' - % (pattern, pattern), - '', - test_output, - ) - return re.sub(r'.*%s.*\n' % pattern, '', test_output) + test_output = re.sub( + r".*\[ RUN \] .*%s(.|\n)*?\[( FAILED | OK )\] .*%s.*\n" + % (pattern, pattern), + "", + test_output, + ) + return re.sub(r".*%s.*\n" % pattern, "", test_output) def NormalizeOutput(output): - """Normalizes output (the output of googletest-output-test_.exe).""" + """Normalizes output (the output of googletest-output-test_.exe).""" - output = ToUnixLineEnding(output) - output = RemoveLocations(output) - output = RemoveStackTraceDetails(output) - output = RemoveTime(output) - return output + output = ToUnixLineEnding(output) + output = RemoveLocations(output) + output = RemoveStackTraceDetails(output) + output = RemoveTime(output) + return output def GetShellCommandOutput(env_cmd): - """Runs a command in a sub-process, and returns its output in a string. + """Runs a command in a sub-process, and returns its output in a string. - Args: - env_cmd: The shell command. A 2-tuple where element 0 is a dict of extra - environment variables to set, and element 1 is a string with the command - and any flags. + Args: + env_cmd: The shell command. A 2-tuple where element 0 is a dict of extra + environment variables to set, and element 1 is a string with the command + and any flags. - Returns: - A string with the command's combined standard and diagnostic output. - """ + Returns: + A string with the command's combined standard and diagnostic output. + """ - # Spawns cmd in a sub-process, and gets its standard I/O file objects. - # Set and save the environment properly. - environ = os.environ.copy() - environ.update(env_cmd[0]) - p = gtest_test_utils.Subprocess(env_cmd[1], env=environ) + # Spawns cmd in a sub-process, and gets its standard I/O file objects. + # Set and save the environment properly. + environ = os.environ.copy() + environ.update(env_cmd[0]) + p = gtest_test_utils.Subprocess(env_cmd[1], env=environ) - return p.output + return p.output def GetCommandOutput(env_cmd): - """Runs a command and returns output with all file location info stripped off. + """Runs a command and returns output with all file location info stripped off. - Args: - env_cmd: The shell command. A 2-tuple where element 0 is a dict of extra - environment variables to set, and element 1 is a string with the command - and any flags. + Args: + env_cmd: The shell command. A 2-tuple where element 0 is a dict of extra + environment variables to set, and element 1 is a string with the command + and any flags. - Returns: - A string with the command's combined standard and diagnostic output. File - location info is stripped. - """ + Returns: + A string with the command's combined standard and diagnostic output. File + location info is stripped. + """ - # Disables exception pop-ups on Windows. - environ, cmdline = env_cmd - environ = dict(environ) # Ensures we are modifying a copy. - environ[CATCH_EXCEPTIONS_ENV_VAR_NAME] = '1' - return NormalizeOutput(GetShellCommandOutput((environ, cmdline))) + # Disables exception pop-ups on Windows. + environ, cmdline = env_cmd + environ = dict(environ) # Ensures we are modifying a copy. + environ[CATCH_EXCEPTIONS_ENV_VAR_NAME] = "1" + return NormalizeOutput(GetShellCommandOutput((environ, cmdline))) def GetOutputOfAllCommands(): - """Returns concatenated output from several representative commands.""" + """Returns concatenated output from several representative commands.""" - return ( - GetCommandOutput(COMMAND_WITH_COLOR) - + GetCommandOutput(COMMAND_WITH_TIME) - + GetCommandOutput(COMMAND_WITH_DISABLED) - + GetCommandOutput(COMMAND_WITH_SHARDING) - ) + return ( + GetCommandOutput(COMMAND_WITH_COLOR) + + GetCommandOutput(COMMAND_WITH_TIME) + + GetCommandOutput(COMMAND_WITH_DISABLED) + + GetCommandOutput(COMMAND_WITH_SHARDING) + ) test_list = GetShellCommandOutput(COMMAND_LIST_TESTS) -SUPPORTS_DEATH_TESTS = 'DeathTest' in test_list -SUPPORTS_TYPED_TESTS = 'TypedTest' in test_list -SUPPORTS_THREADS = 'ExpectFailureWithThreadsTest' in test_list +SUPPORTS_DEATH_TESTS = "DeathTest" in test_list +SUPPORTS_TYPED_TESTS = "TypedTest" in test_list +SUPPORTS_THREADS = "ExpectFailureWithThreadsTest" in test_list SUPPORTS_STACK_TRACES = NO_STACKTRACE_SUPPORT_FLAG not in sys.argv CAN_GENERATE_GOLDEN_FILE = ( @@ -283,103 +281,103 @@ def GetOutputOfAllCommands(): class GTestOutputTest(gtest_test_utils.TestCase): - def RemoveUnsupportedTests(self, test_output): - if not SUPPORTS_DEATH_TESTS: - test_output = RemoveMatchingTests(test_output, 'DeathTest') - if not SUPPORTS_TYPED_TESTS: - test_output = RemoveMatchingTests(test_output, 'TypedTest') - test_output = RemoveMatchingTests(test_output, 'TypedDeathTest') - test_output = RemoveMatchingTests(test_output, 'TypeParamDeathTest') - if not SUPPORTS_THREADS: - test_output = RemoveMatchingTests( - test_output, 'ExpectFailureWithThreadsTest' - ) - test_output = RemoveMatchingTests( - test_output, 'ScopedFakeTestPartResultReporterTest' - ) - test_output = RemoveMatchingTests(test_output, 'WorksConcurrently') - if not SUPPORTS_STACK_TRACES: - test_output = RemoveStackTraces(test_output) - - return test_output - - def testOutput(self): - output = GetOutputOfAllCommands() - - golden_file = open(GOLDEN_PATH, 'rb') - # A mis-configured source control system can cause \r appear in EOL - # sequences when we read the golden file irrespective of an operating - # system used. Therefore, we need to strip those \r's from newlines - # unconditionally. - golden = ToUnixLineEnding(golden_file.read().decode()) - golden_file.close() - - # We want the test to pass regardless of certain features being - # supported or not. - - # We still have to remove type name specifics in all cases. - normalized_actual = RemoveTypeInfoDetails(output) - normalized_golden = RemoveTypeInfoDetails(golden) - - if CAN_GENERATE_GOLDEN_FILE: - self.assertEqual( - normalized_golden, - normalized_actual, - '\n'.join( - difflib.unified_diff( - normalized_golden.split('\n'), - normalized_actual.split('\n'), - 'golden', - 'actual', - ) - ), - ) - else: - normalized_actual = NormalizeToCurrentPlatform( - RemoveTestCounts(normalized_actual) - ) - normalized_golden = NormalizeToCurrentPlatform( - RemoveTestCounts(self.RemoveUnsupportedTests(normalized_golden)) - ) - - # This code is very handy when debugging golden file differences: - if os.getenv('DEBUG_GTEST_OUTPUT_TEST'): - open( - os.path.join( - gtest_test_utils.GetSourceDir(), - '_googletest-output-test_normalized_actual.txt', - ), - 'wb', - ).write(normalized_actual) - open( - os.path.join( - gtest_test_utils.GetSourceDir(), - '_googletest-output-test_normalized_golden.txt', - ), - 'wb', - ).write(normalized_golden) - - self.assertEqual(normalized_golden, normalized_actual) - - -if __name__ == '__main__': - if NO_STACKTRACE_SUPPORT_FLAG in sys.argv: - # unittest.main() can't handle unknown flags - sys.argv.remove(NO_STACKTRACE_SUPPORT_FLAG) - - if GENGOLDEN_FLAG in sys.argv: - if CAN_GENERATE_GOLDEN_FILE: - output = GetOutputOfAllCommands() - golden_file = open(GOLDEN_PATH, 'wb') - golden_file.write(output.encode()) - golden_file.close() - else: - message = """Unable to write a golden file when compiled in an environment + def RemoveUnsupportedTests(self, test_output): + if not SUPPORTS_DEATH_TESTS: + test_output = RemoveMatchingTests(test_output, "DeathTest") + if not SUPPORTS_TYPED_TESTS: + test_output = RemoveMatchingTests(test_output, "TypedTest") + test_output = RemoveMatchingTests(test_output, "TypedDeathTest") + test_output = RemoveMatchingTests(test_output, "TypeParamDeathTest") + if not SUPPORTS_THREADS: + test_output = RemoveMatchingTests( + test_output, "ExpectFailureWithThreadsTest" + ) + test_output = RemoveMatchingTests( + test_output, "ScopedFakeTestPartResultReporterTest" + ) + test_output = RemoveMatchingTests(test_output, "WorksConcurrently") + if not SUPPORTS_STACK_TRACES: + test_output = RemoveStackTraces(test_output) + + return test_output + + def testOutput(self): + output = GetOutputOfAllCommands() + + golden_file = open(GOLDEN_PATH, "rb") + # A mis-configured source control system can cause \r appear in EOL + # sequences when we read the golden file irrespective of an operating + # system used. Therefore, we need to strip those \r's from newlines + # unconditionally. + golden = ToUnixLineEnding(golden_file.read().decode()) + golden_file.close() + + # We want the test to pass regardless of certain features being + # supported or not. + + # We still have to remove type name specifics in all cases. + normalized_actual = RemoveTypeInfoDetails(output) + normalized_golden = RemoveTypeInfoDetails(golden) + + if CAN_GENERATE_GOLDEN_FILE: + self.assertEqual( + normalized_golden, + normalized_actual, + "\n".join( + difflib.unified_diff( + normalized_golden.split("\n"), + normalized_actual.split("\n"), + "golden", + "actual", + ) + ), + ) + else: + normalized_actual = NormalizeToCurrentPlatform( + RemoveTestCounts(normalized_actual) + ) + normalized_golden = NormalizeToCurrentPlatform( + RemoveTestCounts(self.RemoveUnsupportedTests(normalized_golden)) + ) + + # This code is very handy when debugging golden file differences: + if os.getenv("DEBUG_GTEST_OUTPUT_TEST"): + open( + os.path.join( + gtest_test_utils.GetSourceDir(), + "_googletest-output-test_normalized_actual.txt", + ), + "wb", + ).write(normalized_actual) + open( + os.path.join( + gtest_test_utils.GetSourceDir(), + "_googletest-output-test_normalized_golden.txt", + ), + "wb", + ).write(normalized_golden) + + self.assertEqual(normalized_golden, normalized_actual) + + +if __name__ == "__main__": + if NO_STACKTRACE_SUPPORT_FLAG in sys.argv: + # unittest.main() can't handle unknown flags + sys.argv.remove(NO_STACKTRACE_SUPPORT_FLAG) + + if GENGOLDEN_FLAG in sys.argv: + if CAN_GENERATE_GOLDEN_FILE: + output = GetOutputOfAllCommands() + golden_file = open(GOLDEN_PATH, "wb") + golden_file.write(output.encode()) + golden_file.close() + else: + message = """Unable to write a golden file when compiled in an environment that does not support all the required features (death tests, typed tests, stack traces, and multiple threads). Please build this test and generate the golden file using Blaze on Linux.""" - sys.stderr.write(message) - sys.exit(1) - else: - gtest_test_utils.Main() + sys.stderr.write(message) + sys.exit(1) + else: + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-param-test-invalid-name1-test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-param-test-invalid-name1-test.py index 4886e49e..be61f3f4 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-param-test-invalid-name1-test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-param-test-invalid-name1-test.py @@ -32,32 +32,32 @@ from googletest.test import gtest_test_utils -binary_name = 'googletest-param-test-invalid-name1-test_' +binary_name = "googletest-param-test-invalid-name1-test_" COMMAND = gtest_test_utils.GetTestExecutablePath(binary_name) def Assert(condition): - if not condition: - raise AssertionError + if not condition: + raise AssertionError def TestExitCodeAndOutput(command): - """Runs the given command and verifies its exit code and output.""" + """Runs the given command and verifies its exit code and output.""" - err = 'Parameterized test name \'"InvalidWithQuotes"\' is invalid' + err = "Parameterized test name '\"InvalidWithQuotes\"' is invalid" - p = gtest_test_utils.Subprocess(command) - Assert(p.terminated_by_signal) + p = gtest_test_utils.Subprocess(command) + Assert(p.terminated_by_signal) - # Verify the output message contains appropriate output - Assert(err in p.output) + # Verify the output message contains appropriate output + Assert(err in p.output) class GTestParamTestInvalidName1Test(gtest_test_utils.TestCase): - def testExitCodeAndOutput(self): - TestExitCodeAndOutput(COMMAND) + def testExitCodeAndOutput(self): + TestExitCodeAndOutput(COMMAND) -if __name__ == '__main__': - gtest_test_utils.Main() +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-param-test-invalid-name2-test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-param-test-invalid-name2-test.py index bcd8ddf0..a674a63d 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-param-test-invalid-name2-test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-param-test-invalid-name2-test.py @@ -32,32 +32,32 @@ from googletest.test import gtest_test_utils -binary_name = 'googletest-param-test-invalid-name2-test_' +binary_name = "googletest-param-test-invalid-name2-test_" COMMAND = gtest_test_utils.GetTestExecutablePath(binary_name) def Assert(condition): - if not condition: - raise AssertionError + if not condition: + raise AssertionError def TestExitCodeAndOutput(command): - """Runs the given command and verifies its exit code and output.""" + """Runs the given command and verifies its exit code and output.""" - err = "Duplicate parameterized test name 'a'" + err = "Duplicate parameterized test name 'a'" - p = gtest_test_utils.Subprocess(command) - Assert(p.terminated_by_signal) + p = gtest_test_utils.Subprocess(command) + Assert(p.terminated_by_signal) - # Check for appropriate output - Assert(err in p.output) + # Check for appropriate output + Assert(err in p.output) class GTestParamTestInvalidName2Test(gtest_test_utils.TestCase): - def testExitCodeAndOutput(self): - TestExitCodeAndOutput(COMMAND) + def testExitCodeAndOutput(self): + TestExitCodeAndOutput(COMMAND) -if __name__ == '__main__': - gtest_test_utils.Main() +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-setuptestsuite-test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-setuptestsuite-test.py index 899531f3..af4696fd 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-setuptestsuite-test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-setuptestsuite-test.py @@ -33,26 +33,24 @@ from googletest.test import gtest_test_utils -COMMAND = gtest_test_utils.GetTestExecutablePath( - 'googletest-setuptestsuite-test_' -) +COMMAND = gtest_test_utils.GetTestExecutablePath("googletest-setuptestsuite-test_") class GTestSetUpTestSuiteTest(gtest_test_utils.TestCase): - def testSetupErrorAndTearDownError(self): - p = gtest_test_utils.Subprocess(COMMAND) - self.assertNotEqual(p.exit_code, 0, msg=p.output) + def testSetupErrorAndTearDownError(self): + p = gtest_test_utils.Subprocess(COMMAND) + self.assertNotEqual(p.exit_code, 0, msg=p.output) - self.assertIn( - ( - '[ FAILED ] SetupFailTest: SetUpTestSuite or TearDownTestSuite\n[' - ' FAILED ] TearDownFailTest: SetUpTestSuite or' - ' TearDownTestSuite\n\n 2 FAILED TEST SUITES\n' - ), - p.output, - ) + self.assertIn( + ( + "[ FAILED ] SetupFailTest: SetUpTestSuite or TearDownTestSuite\n[" + " FAILED ] TearDownFailTest: SetUpTestSuite or" + " TearDownTestSuite\n\n 2 FAILED TEST SUITES\n" + ), + p.output, + ) -if __name__ == '__main__': - gtest_test_utils.Main() +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-shuffle-test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-shuffle-test.py index 61e3a15e..0db01b8a 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-shuffle-test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-shuffle-test.py @@ -34,13 +34,13 @@ from googletest.test import gtest_test_utils # Command to run the googletest-shuffle-test_ program. -COMMAND = gtest_test_utils.GetTestExecutablePath('googletest-shuffle-test_') +COMMAND = gtest_test_utils.GetTestExecutablePath("googletest-shuffle-test_") # The environment variables for test sharding. -TOTAL_SHARDS_ENV_VAR = 'GTEST_TOTAL_SHARDS' -SHARD_INDEX_ENV_VAR = 'GTEST_SHARD_INDEX' +TOTAL_SHARDS_ENV_VAR = "GTEST_TOTAL_SHARDS" +SHARD_INDEX_ENV_VAR = "GTEST_SHARD_INDEX" -TEST_FILTER = 'A*.A:A*.B:C*' +TEST_FILTER = "A*.A:A*.B:C*" ALL_TESTS = [] ACTIVE_TESTS = [] @@ -54,325 +54,311 @@ def AlsoRunDisabledTestsFlag(): - return '--gtest_also_run_disabled_tests' + return "--gtest_also_run_disabled_tests" def FilterFlag(test_filter): - return '--gtest_filter=%s' % (test_filter,) + return "--gtest_filter=%s" % (test_filter,) def RepeatFlag(n): - return '--gtest_repeat=%s' % (n,) + return "--gtest_repeat=%s" % (n,) def ShuffleFlag(): - return '--gtest_shuffle' + return "--gtest_shuffle" def RandomSeedFlag(n): - return '--gtest_random_seed=%s' % (n,) + return "--gtest_random_seed=%s" % (n,) def RunAndReturnOutput(extra_env, args): - """Runs the test program and returns its output.""" + """Runs the test program and returns its output.""" - environ_copy = os.environ.copy() - environ_copy.update(extra_env) + environ_copy = os.environ.copy() + environ_copy.update(extra_env) - return gtest_test_utils.Subprocess([COMMAND] + args, env=environ_copy).output + return gtest_test_utils.Subprocess([COMMAND] + args, env=environ_copy).output def GetTestsForAllIterations(extra_env, args): - """Runs the test program and returns a list of test lists. + """Runs the test program and returns a list of test lists. - Args: - extra_env: a map from environment variables to their values - args: command line flags to pass to googletest-shuffle-test_ + Args: + extra_env: a map from environment variables to their values + args: command line flags to pass to googletest-shuffle-test_ - Returns: - A list where the i-th element is the list of tests run in the i-th - test iteration. - """ + Returns: + A list where the i-th element is the list of tests run in the i-th + test iteration. + """ - test_iterations = [] - for line in RunAndReturnOutput(extra_env, args).split('\n'): - if line.startswith('----'): - tests = [] - test_iterations.append(tests) - elif line.strip(): - tests.append(line.strip()) # 'TestCaseName.TestName' + test_iterations = [] + for line in RunAndReturnOutput(extra_env, args).split("\n"): + if line.startswith("----"): + tests = [] + test_iterations.append(tests) + elif line.strip(): + tests.append(line.strip()) # 'TestCaseName.TestName' - return test_iterations + return test_iterations def GetTestCases(tests): - """Returns a list of test cases in the given full test names. + """Returns a list of test cases in the given full test names. - Args: - tests: a list of full test names + Args: + tests: a list of full test names - Returns: - A list of test cases from 'tests', in their original order. - Consecutive duplicates are removed. - """ + Returns: + A list of test cases from 'tests', in their original order. + Consecutive duplicates are removed. + """ - test_cases = [] - for test in tests: - test_case = test.split('.')[0] - if not test_case in test_cases: - test_cases.append(test_case) + test_cases = [] + for test in tests: + test_case = test.split(".")[0] + if not test_case in test_cases: + test_cases.append(test_case) - return test_cases + return test_cases def CalculateTestLists(): - """Calculates the list of tests run under different flags.""" + """Calculates the list of tests run under different flags.""" - if not ALL_TESTS: - ALL_TESTS.extend( - GetTestsForAllIterations({}, [AlsoRunDisabledTestsFlag()])[0] - ) + if not ALL_TESTS: + ALL_TESTS.extend(GetTestsForAllIterations({}, [AlsoRunDisabledTestsFlag()])[0]) - if not ACTIVE_TESTS: - ACTIVE_TESTS.extend(GetTestsForAllIterations({}, [])[0]) + if not ACTIVE_TESTS: + ACTIVE_TESTS.extend(GetTestsForAllIterations({}, [])[0]) - if not FILTERED_TESTS: - FILTERED_TESTS.extend( - GetTestsForAllIterations({}, [FilterFlag(TEST_FILTER)])[0] - ) + if not FILTERED_TESTS: + FILTERED_TESTS.extend( + GetTestsForAllIterations({}, [FilterFlag(TEST_FILTER)])[0] + ) - if not SHARDED_TESTS: - SHARDED_TESTS.extend( - GetTestsForAllIterations( - {TOTAL_SHARDS_ENV_VAR: '3', SHARD_INDEX_ENV_VAR: '1'}, [] - )[0] - ) + if not SHARDED_TESTS: + SHARDED_TESTS.extend( + GetTestsForAllIterations( + {TOTAL_SHARDS_ENV_VAR: "3", SHARD_INDEX_ENV_VAR: "1"}, [] + )[0] + ) - if not SHUFFLED_ALL_TESTS: - SHUFFLED_ALL_TESTS.extend( - GetTestsForAllIterations( - {}, [AlsoRunDisabledTestsFlag(), ShuffleFlag(), RandomSeedFlag(1)] - )[0] - ) + if not SHUFFLED_ALL_TESTS: + SHUFFLED_ALL_TESTS.extend( + GetTestsForAllIterations( + {}, [AlsoRunDisabledTestsFlag(), ShuffleFlag(), RandomSeedFlag(1)] + )[0] + ) - if not SHUFFLED_ACTIVE_TESTS: - SHUFFLED_ACTIVE_TESTS.extend( - GetTestsForAllIterations({}, [ShuffleFlag(), RandomSeedFlag(1)])[0] - ) + if not SHUFFLED_ACTIVE_TESTS: + SHUFFLED_ACTIVE_TESTS.extend( + GetTestsForAllIterations({}, [ShuffleFlag(), RandomSeedFlag(1)])[0] + ) - if not SHUFFLED_FILTERED_TESTS: - SHUFFLED_FILTERED_TESTS.extend( - GetTestsForAllIterations( - {}, [ShuffleFlag(), RandomSeedFlag(1), FilterFlag(TEST_FILTER)] - )[0] - ) + if not SHUFFLED_FILTERED_TESTS: + SHUFFLED_FILTERED_TESTS.extend( + GetTestsForAllIterations( + {}, [ShuffleFlag(), RandomSeedFlag(1), FilterFlag(TEST_FILTER)] + )[0] + ) - if not SHUFFLED_SHARDED_TESTS: - SHUFFLED_SHARDED_TESTS.extend( - GetTestsForAllIterations( - {TOTAL_SHARDS_ENV_VAR: '3', SHARD_INDEX_ENV_VAR: '1'}, - [ShuffleFlag(), RandomSeedFlag(1)], - )[0] - ) + if not SHUFFLED_SHARDED_TESTS: + SHUFFLED_SHARDED_TESTS.extend( + GetTestsForAllIterations( + {TOTAL_SHARDS_ENV_VAR: "3", SHARD_INDEX_ENV_VAR: "1"}, + [ShuffleFlag(), RandomSeedFlag(1)], + )[0] + ) class GTestShuffleUnitTest(gtest_test_utils.TestCase): - """Tests test shuffling.""" - - def setUp(self): - CalculateTestLists() - - def testShufflePreservesNumberOfTests(self): - self.assertEqual(len(ALL_TESTS), len(SHUFFLED_ALL_TESTS)) - self.assertEqual(len(ACTIVE_TESTS), len(SHUFFLED_ACTIVE_TESTS)) - self.assertEqual(len(FILTERED_TESTS), len(SHUFFLED_FILTERED_TESTS)) - self.assertEqual(len(SHARDED_TESTS), len(SHUFFLED_SHARDED_TESTS)) - - def testShuffleChangesTestOrder(self): - self.assertTrue(SHUFFLED_ALL_TESTS != ALL_TESTS, SHUFFLED_ALL_TESTS) - self.assertTrue( - SHUFFLED_ACTIVE_TESTS != ACTIVE_TESTS, SHUFFLED_ACTIVE_TESTS - ) - self.assertTrue( - SHUFFLED_FILTERED_TESTS != FILTERED_TESTS, SHUFFLED_FILTERED_TESTS - ) - self.assertTrue( - SHUFFLED_SHARDED_TESTS != SHARDED_TESTS, SHUFFLED_SHARDED_TESTS - ) - - def testShuffleChangesTestCaseOrder(self): - self.assertTrue( - GetTestCases(SHUFFLED_ALL_TESTS) != GetTestCases(ALL_TESTS), - GetTestCases(SHUFFLED_ALL_TESTS), - ) - self.assertTrue( - GetTestCases(SHUFFLED_ACTIVE_TESTS) != GetTestCases(ACTIVE_TESTS), - GetTestCases(SHUFFLED_ACTIVE_TESTS), - ) - self.assertTrue( - GetTestCases(SHUFFLED_FILTERED_TESTS) != GetTestCases(FILTERED_TESTS), - GetTestCases(SHUFFLED_FILTERED_TESTS), - ) - self.assertTrue( - GetTestCases(SHUFFLED_SHARDED_TESTS) != GetTestCases(SHARDED_TESTS), - GetTestCases(SHUFFLED_SHARDED_TESTS), - ) - - def testShuffleDoesNotRepeatTest(self): - for test in SHUFFLED_ALL_TESTS: - self.assertEqual( - 1, - SHUFFLED_ALL_TESTS.count(test), - '%s appears more than once' % (test,), - ) - for test in SHUFFLED_ACTIVE_TESTS: - self.assertEqual( - 1, - SHUFFLED_ACTIVE_TESTS.count(test), - '%s appears more than once' % (test,), - ) - for test in SHUFFLED_FILTERED_TESTS: - self.assertEqual( - 1, - SHUFFLED_FILTERED_TESTS.count(test), - '%s appears more than once' % (test,), - ) - for test in SHUFFLED_SHARDED_TESTS: - self.assertEqual( - 1, - SHUFFLED_SHARDED_TESTS.count(test), - '%s appears more than once' % (test,), - ) - - def testShuffleDoesNotCreateNewTest(self): - for test in SHUFFLED_ALL_TESTS: - self.assertTrue(test in ALL_TESTS, '%s is an invalid test' % (test,)) - for test in SHUFFLED_ACTIVE_TESTS: - self.assertTrue(test in ACTIVE_TESTS, '%s is an invalid test' % (test,)) - for test in SHUFFLED_FILTERED_TESTS: - self.assertTrue(test in FILTERED_TESTS, '%s is an invalid test' % (test,)) - for test in SHUFFLED_SHARDED_TESTS: - self.assertTrue(test in SHARDED_TESTS, '%s is an invalid test' % (test,)) - - def testShuffleIncludesAllTests(self): - for test in ALL_TESTS: - self.assertTrue(test in SHUFFLED_ALL_TESTS, '%s is missing' % (test,)) - for test in ACTIVE_TESTS: - self.assertTrue(test in SHUFFLED_ACTIVE_TESTS, '%s is missing' % (test,)) - for test in FILTERED_TESTS: - self.assertTrue( - test in SHUFFLED_FILTERED_TESTS, '%s is missing' % (test,) - ) - for test in SHARDED_TESTS: - self.assertTrue(test in SHUFFLED_SHARDED_TESTS, '%s is missing' % (test,)) - - def testShuffleLeavesDeathTestsAtFront(self): - non_death_test_found = False - for test in SHUFFLED_ACTIVE_TESTS: - if 'DeathTest.' in test: + """Tests test shuffling.""" + + def setUp(self): + CalculateTestLists() + + def testShufflePreservesNumberOfTests(self): + self.assertEqual(len(ALL_TESTS), len(SHUFFLED_ALL_TESTS)) + self.assertEqual(len(ACTIVE_TESTS), len(SHUFFLED_ACTIVE_TESTS)) + self.assertEqual(len(FILTERED_TESTS), len(SHUFFLED_FILTERED_TESTS)) + self.assertEqual(len(SHARDED_TESTS), len(SHUFFLED_SHARDED_TESTS)) + + def testShuffleChangesTestOrder(self): + self.assertTrue(SHUFFLED_ALL_TESTS != ALL_TESTS, SHUFFLED_ALL_TESTS) + self.assertTrue(SHUFFLED_ACTIVE_TESTS != ACTIVE_TESTS, SHUFFLED_ACTIVE_TESTS) self.assertTrue( - not non_death_test_found, - '%s appears after a non-death test' % (test,), + SHUFFLED_FILTERED_TESTS != FILTERED_TESTS, SHUFFLED_FILTERED_TESTS ) - else: - non_death_test_found = True + self.assertTrue(SHUFFLED_SHARDED_TESTS != SHARDED_TESTS, SHUFFLED_SHARDED_TESTS) - def _VerifyTestCasesDoNotInterleave(self, tests): - test_cases = [] - for test in tests: - [test_case, _] = test.split('.') - if test_cases and test_cases[-1] != test_case: - test_cases.append(test_case) - self.assertEqual( - 1, - test_cases.count(test_case), - 'Test case %s is not grouped together in %s' % (test_case, tests), + def testShuffleChangesTestCaseOrder(self): + self.assertTrue( + GetTestCases(SHUFFLED_ALL_TESTS) != GetTestCases(ALL_TESTS), + GetTestCases(SHUFFLED_ALL_TESTS), + ) + self.assertTrue( + GetTestCases(SHUFFLED_ACTIVE_TESTS) != GetTestCases(ACTIVE_TESTS), + GetTestCases(SHUFFLED_ACTIVE_TESTS), + ) + self.assertTrue( + GetTestCases(SHUFFLED_FILTERED_TESTS) != GetTestCases(FILTERED_TESTS), + GetTestCases(SHUFFLED_FILTERED_TESTS), + ) + self.assertTrue( + GetTestCases(SHUFFLED_SHARDED_TESTS) != GetTestCases(SHARDED_TESTS), + GetTestCases(SHUFFLED_SHARDED_TESTS), ) - def testShuffleDoesNotInterleaveTestCases(self): - self._VerifyTestCasesDoNotInterleave(SHUFFLED_ALL_TESTS) - self._VerifyTestCasesDoNotInterleave(SHUFFLED_ACTIVE_TESTS) - self._VerifyTestCasesDoNotInterleave(SHUFFLED_FILTERED_TESTS) - self._VerifyTestCasesDoNotInterleave(SHUFFLED_SHARDED_TESTS) - - def testShuffleRestoresOrderAfterEachIteration(self): - # Get the test lists in all 3 iterations, using random seed 1, 2, - # and 3 respectively. Google Test picks a different seed in each - # iteration, and this test depends on the current implementation - # picking successive numbers. This dependency is not ideal, but - # makes the test much easier to write. - # pylint: disable-next=unbalanced-tuple-unpacking - [tests_in_iteration1, tests_in_iteration2, tests_in_iteration3] = ( - GetTestsForAllIterations( - {}, [ShuffleFlag(), RandomSeedFlag(1), RepeatFlag(3)] + def testShuffleDoesNotRepeatTest(self): + for test in SHUFFLED_ALL_TESTS: + self.assertEqual( + 1, + SHUFFLED_ALL_TESTS.count(test), + "%s appears more than once" % (test,), + ) + for test in SHUFFLED_ACTIVE_TESTS: + self.assertEqual( + 1, + SHUFFLED_ACTIVE_TESTS.count(test), + "%s appears more than once" % (test,), + ) + for test in SHUFFLED_FILTERED_TESTS: + self.assertEqual( + 1, + SHUFFLED_FILTERED_TESTS.count(test), + "%s appears more than once" % (test,), + ) + for test in SHUFFLED_SHARDED_TESTS: + self.assertEqual( + 1, + SHUFFLED_SHARDED_TESTS.count(test), + "%s appears more than once" % (test,), + ) + + def testShuffleDoesNotCreateNewTest(self): + for test in SHUFFLED_ALL_TESTS: + self.assertTrue(test in ALL_TESTS, "%s is an invalid test" % (test,)) + for test in SHUFFLED_ACTIVE_TESTS: + self.assertTrue(test in ACTIVE_TESTS, "%s is an invalid test" % (test,)) + for test in SHUFFLED_FILTERED_TESTS: + self.assertTrue(test in FILTERED_TESTS, "%s is an invalid test" % (test,)) + for test in SHUFFLED_SHARDED_TESTS: + self.assertTrue(test in SHARDED_TESTS, "%s is an invalid test" % (test,)) + + def testShuffleIncludesAllTests(self): + for test in ALL_TESTS: + self.assertTrue(test in SHUFFLED_ALL_TESTS, "%s is missing" % (test,)) + for test in ACTIVE_TESTS: + self.assertTrue(test in SHUFFLED_ACTIVE_TESTS, "%s is missing" % (test,)) + for test in FILTERED_TESTS: + self.assertTrue(test in SHUFFLED_FILTERED_TESTS, "%s is missing" % (test,)) + for test in SHARDED_TESTS: + self.assertTrue(test in SHUFFLED_SHARDED_TESTS, "%s is missing" % (test,)) + + def testShuffleLeavesDeathTestsAtFront(self): + non_death_test_found = False + for test in SHUFFLED_ACTIVE_TESTS: + if "DeathTest." in test: + self.assertTrue( + not non_death_test_found, + "%s appears after a non-death test" % (test,), + ) + else: + non_death_test_found = True + + def _VerifyTestCasesDoNotInterleave(self, tests): + test_cases = [] + for test in tests: + [test_case, _] = test.split(".") + if test_cases and test_cases[-1] != test_case: + test_cases.append(test_case) + self.assertEqual( + 1, + test_cases.count(test_case), + "Test case %s is not grouped together in %s" % (test_case, tests), + ) + + def testShuffleDoesNotInterleaveTestCases(self): + self._VerifyTestCasesDoNotInterleave(SHUFFLED_ALL_TESTS) + self._VerifyTestCasesDoNotInterleave(SHUFFLED_ACTIVE_TESTS) + self._VerifyTestCasesDoNotInterleave(SHUFFLED_FILTERED_TESTS) + self._VerifyTestCasesDoNotInterleave(SHUFFLED_SHARDED_TESTS) + + def testShuffleRestoresOrderAfterEachIteration(self): + # Get the test lists in all 3 iterations, using random seed 1, 2, + # and 3 respectively. Google Test picks a different seed in each + # iteration, and this test depends on the current implementation + # picking successive numbers. This dependency is not ideal, but + # makes the test much easier to write. + # pylint: disable-next=unbalanced-tuple-unpacking + [tests_in_iteration1, tests_in_iteration2, tests_in_iteration3] = ( + GetTestsForAllIterations( + {}, [ShuffleFlag(), RandomSeedFlag(1), RepeatFlag(3)] + ) ) - ) - - # Make sure running the tests with random seed 1 gets the same - # order as in iteration 1 above. - tests_with_seed1 = GetTestsForAllIterations( - {}, [ShuffleFlag(), RandomSeedFlag(1)] - )[0] - self.assertEqual(tests_in_iteration1, tests_with_seed1) - - # Make sure running the tests with random seed 2 gets the same - # order as in iteration 2 above. Success means that Google Test - # correctly restores the test order before re-shuffling at the - # beginning of iteration 2. - tests_with_seed2 = GetTestsForAllIterations( - {}, [ShuffleFlag(), RandomSeedFlag(2)] - )[0] - self.assertEqual(tests_in_iteration2, tests_with_seed2) - - # Make sure running the tests with random seed 3 gets the same - # order as in iteration 3 above. Success means that Google Test - # correctly restores the test order before re-shuffling at the - # beginning of iteration 3. - tests_with_seed3 = GetTestsForAllIterations( - {}, [ShuffleFlag(), RandomSeedFlag(3)] - )[0] - self.assertEqual(tests_in_iteration3, tests_with_seed3) - - def testShuffleGeneratesNewOrderInEachIteration(self): - # pylint: disable-next=unbalanced-tuple-unpacking - [tests_in_iteration1, tests_in_iteration2, tests_in_iteration3] = ( - GetTestsForAllIterations( - {}, [ShuffleFlag(), RandomSeedFlag(1), RepeatFlag(3)] + + # Make sure running the tests with random seed 1 gets the same + # order as in iteration 1 above. + tests_with_seed1 = GetTestsForAllIterations( + {}, [ShuffleFlag(), RandomSeedFlag(1)] + )[0] + self.assertEqual(tests_in_iteration1, tests_with_seed1) + + # Make sure running the tests with random seed 2 gets the same + # order as in iteration 2 above. Success means that Google Test + # correctly restores the test order before re-shuffling at the + # beginning of iteration 2. + tests_with_seed2 = GetTestsForAllIterations( + {}, [ShuffleFlag(), RandomSeedFlag(2)] + )[0] + self.assertEqual(tests_in_iteration2, tests_with_seed2) + + # Make sure running the tests with random seed 3 gets the same + # order as in iteration 3 above. Success means that Google Test + # correctly restores the test order before re-shuffling at the + # beginning of iteration 3. + tests_with_seed3 = GetTestsForAllIterations( + {}, [ShuffleFlag(), RandomSeedFlag(3)] + )[0] + self.assertEqual(tests_in_iteration3, tests_with_seed3) + + def testShuffleGeneratesNewOrderInEachIteration(self): + # pylint: disable-next=unbalanced-tuple-unpacking + [tests_in_iteration1, tests_in_iteration2, tests_in_iteration3] = ( + GetTestsForAllIterations( + {}, [ShuffleFlag(), RandomSeedFlag(1), RepeatFlag(3)] + ) ) - ) - - self.assertTrue( - tests_in_iteration1 != tests_in_iteration2, tests_in_iteration1 - ) - self.assertTrue( - tests_in_iteration1 != tests_in_iteration3, tests_in_iteration1 - ) - self.assertTrue( - tests_in_iteration2 != tests_in_iteration3, tests_in_iteration2 - ) - - def testShuffleShardedTestsPreservesPartition(self): - # If we run M tests on N shards, the same M tests should be run in - # total, regardless of the random seeds used by the shards. - tests1 = GetTestsForAllIterations( - {TOTAL_SHARDS_ENV_VAR: '3', SHARD_INDEX_ENV_VAR: '0'}, - [ShuffleFlag(), RandomSeedFlag(1)], - )[0] - tests2 = GetTestsForAllIterations( - {TOTAL_SHARDS_ENV_VAR: '3', SHARD_INDEX_ENV_VAR: '1'}, - [ShuffleFlag(), RandomSeedFlag(20)], - )[0] - tests3 = GetTestsForAllIterations( - {TOTAL_SHARDS_ENV_VAR: '3', SHARD_INDEX_ENV_VAR: '2'}, - [ShuffleFlag(), RandomSeedFlag(25)], - )[0] - sorted_sharded_tests = tests1 + tests2 + tests3 - sorted_sharded_tests.sort() - sorted_active_tests = [] - sorted_active_tests.extend(ACTIVE_TESTS) - sorted_active_tests.sort() - self.assertEqual(sorted_active_tests, sorted_sharded_tests) - - -if __name__ == '__main__': - gtest_test_utils.Main() + + self.assertTrue(tests_in_iteration1 != tests_in_iteration2, tests_in_iteration1) + self.assertTrue(tests_in_iteration1 != tests_in_iteration3, tests_in_iteration1) + self.assertTrue(tests_in_iteration2 != tests_in_iteration3, tests_in_iteration2) + + def testShuffleShardedTestsPreservesPartition(self): + # If we run M tests on N shards, the same M tests should be run in + # total, regardless of the random seeds used by the shards. + tests1 = GetTestsForAllIterations( + {TOTAL_SHARDS_ENV_VAR: "3", SHARD_INDEX_ENV_VAR: "0"}, + [ShuffleFlag(), RandomSeedFlag(1)], + )[0] + tests2 = GetTestsForAllIterations( + {TOTAL_SHARDS_ENV_VAR: "3", SHARD_INDEX_ENV_VAR: "1"}, + [ShuffleFlag(), RandomSeedFlag(20)], + )[0] + tests3 = GetTestsForAllIterations( + {TOTAL_SHARDS_ENV_VAR: "3", SHARD_INDEX_ENV_VAR: "2"}, + [ShuffleFlag(), RandomSeedFlag(25)], + )[0] + sorted_sharded_tests = tests1 + tests2 + tests3 + sorted_sharded_tests.sort() + sorted_active_tests = [] + sorted_active_tests.extend(ACTIVE_TESTS) + sorted_active_tests.sort() + self.assertEqual(sorted_active_tests, sorted_sharded_tests) + + +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-throw-on-failure-test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-throw-on-failure-test.py index 106b0044..f6266d0b 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-throw-on-failure-test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-throw-on-failure-test.py @@ -42,125 +42,123 @@ # Constants. # The command line flag for enabling/disabling the throw-on-failure mode. -THROW_ON_FAILURE = 'gtest_throw_on_failure' +THROW_ON_FAILURE = "gtest_throw_on_failure" # Path to the googletest-throw-on-failure-test_ program, compiled with # exceptions disabled. -EXE_PATH = gtest_test_utils.GetTestExecutablePath( - 'googletest-throw-on-failure-test_' -) +EXE_PATH = gtest_test_utils.GetTestExecutablePath("googletest-throw-on-failure-test_") # Utilities. def SetEnvVar(env_var, value): - """Sets an environment variable. + """Sets an environment variable. - Sets an environment variable to a given value; unsets it when the - given value is None. + Sets an environment variable to a given value; unsets it when the + given value is None. - Args: - env_var: environment variable. - value: value to set. - """ + Args: + env_var: environment variable. + value: value to set. + """ - env_var = env_var.upper() - if value is not None: - os.environ[env_var] = value - elif env_var in os.environ: - del os.environ[env_var] + env_var = env_var.upper() + if value is not None: + os.environ[env_var] = value + elif env_var in os.environ: + del os.environ[env_var] def Run(command): - """Runs a command; returns True/False if its exit code is/isn't 0.""" + """Runs a command; returns True/False if its exit code is/isn't 0.""" - print('Running "%s". . .' % ' '.join(command)) - p = gtest_test_utils.Subprocess(command) - return p.exited and p.exit_code == 0 + print('Running "%s". . .' % " ".join(command)) + p = gtest_test_utils.Subprocess(command) + return p.exited and p.exit_code == 0 # The tests. class ThrowOnFailureTest(gtest_test_utils.TestCase): - """Tests the throw-on-failure mode.""" - - def RunAndVerify(self, env_var_value, flag_value, should_fail): - """Runs googletest-throw-on-failure-test_ and verifies its behavior. - - Runs googletest-throw-on-failure-test_ and verifies that it does - (or does not) exit with a non-zero code. - - Args: - env_var_value: value of the GTEST_BREAK_ON_FAILURE environment - variable; None if the variable should be unset. - flag_value: value of the --gtest_break_on_failure flag; None if the - flag should not be present. - should_fail: True if and only if the program is expected to fail. - """ - - SetEnvVar(THROW_ON_FAILURE, env_var_value) - - if env_var_value is None: - env_var_value_msg = ' is not set' - else: - env_var_value_msg = '=' + env_var_value - - if flag_value is None: - flag = '' - elif flag_value == '0': - flag = '--%s=0' % THROW_ON_FAILURE - else: - flag = '--%s' % THROW_ON_FAILURE - - command = [EXE_PATH] - if flag: - command.append(flag) - - if should_fail: - should_or_not = 'should' - else: - should_or_not = 'should not' - - failed = not Run(command) - - SetEnvVar(THROW_ON_FAILURE, None) - - msg = ( - 'when %s%s, an assertion failure in "%s" %s cause a non-zero exit code.' - % ( - THROW_ON_FAILURE, - env_var_value_msg, - ' '.join(command), - should_or_not, + """Tests the throw-on-failure mode.""" + + def RunAndVerify(self, env_var_value, flag_value, should_fail): + """Runs googletest-throw-on-failure-test_ and verifies its behavior. + + Runs googletest-throw-on-failure-test_ and verifies that it does + (or does not) exit with a non-zero code. + + Args: + env_var_value: value of the GTEST_BREAK_ON_FAILURE environment + variable; None if the variable should be unset. + flag_value: value of the --gtest_break_on_failure flag; None if the + flag should not be present. + should_fail: True if and only if the program is expected to fail. + """ + + SetEnvVar(THROW_ON_FAILURE, env_var_value) + + if env_var_value is None: + env_var_value_msg = " is not set" + else: + env_var_value_msg = "=" + env_var_value + + if flag_value is None: + flag = "" + elif flag_value == "0": + flag = "--%s=0" % THROW_ON_FAILURE + else: + flag = "--%s" % THROW_ON_FAILURE + + command = [EXE_PATH] + if flag: + command.append(flag) + + if should_fail: + should_or_not = "should" + else: + should_or_not = "should not" + + failed = not Run(command) + + SetEnvVar(THROW_ON_FAILURE, None) + + msg = ( + 'when %s%s, an assertion failure in "%s" %s cause a non-zero exit code.' + % ( + THROW_ON_FAILURE, + env_var_value_msg, + " ".join(command), + should_or_not, + ) ) - ) - self.assertTrue(failed == should_fail, msg) + self.assertTrue(failed == should_fail, msg) - def testDefaultBehavior(self): - """Tests the behavior of the default mode.""" + def testDefaultBehavior(self): + """Tests the behavior of the default mode.""" - self.RunAndVerify(env_var_value=None, flag_value=None, should_fail=False) + self.RunAndVerify(env_var_value=None, flag_value=None, should_fail=False) - def testThrowOnFailureEnvVar(self): - """Tests using the GTEST_THROW_ON_FAILURE environment variable.""" + def testThrowOnFailureEnvVar(self): + """Tests using the GTEST_THROW_ON_FAILURE environment variable.""" - self.RunAndVerify(env_var_value='0', flag_value=None, should_fail=False) - self.RunAndVerify(env_var_value='1', flag_value=None, should_fail=True) + self.RunAndVerify(env_var_value="0", flag_value=None, should_fail=False) + self.RunAndVerify(env_var_value="1", flag_value=None, should_fail=True) - def testThrowOnFailureFlag(self): - """Tests using the --gtest_throw_on_failure flag.""" + def testThrowOnFailureFlag(self): + """Tests using the --gtest_throw_on_failure flag.""" - self.RunAndVerify(env_var_value=None, flag_value='0', should_fail=False) - self.RunAndVerify(env_var_value=None, flag_value='1', should_fail=True) + self.RunAndVerify(env_var_value=None, flag_value="0", should_fail=False) + self.RunAndVerify(env_var_value=None, flag_value="1", should_fail=True) - def testThrowOnFailureFlagOverridesEnvVar(self): - """Tests that --gtest_throw_on_failure overrides GTEST_THROW_ON_FAILURE.""" + def testThrowOnFailureFlagOverridesEnvVar(self): + """Tests that --gtest_throw_on_failure overrides GTEST_THROW_ON_FAILURE.""" - self.RunAndVerify(env_var_value='0', flag_value='0', should_fail=False) - self.RunAndVerify(env_var_value='0', flag_value='1', should_fail=True) - self.RunAndVerify(env_var_value='1', flag_value='0', should_fail=False) - self.RunAndVerify(env_var_value='1', flag_value='1', should_fail=True) + self.RunAndVerify(env_var_value="0", flag_value="0", should_fail=False) + self.RunAndVerify(env_var_value="0", flag_value="1", should_fail=True) + self.RunAndVerify(env_var_value="1", flag_value="0", should_fail=False) + self.RunAndVerify(env_var_value="1", flag_value="1", should_fail=True) -if __name__ == '__main__': - gtest_test_utils.Main() +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-uninitialized-test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-uninitialized-test.py index e5af7c84..df165c0f 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-uninitialized-test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/googletest-uninitialized-test.py @@ -33,38 +33,36 @@ from googletest.test import gtest_test_utils -COMMAND = gtest_test_utils.GetTestExecutablePath( - 'googletest-uninitialized-test_' -) +COMMAND = gtest_test_utils.GetTestExecutablePath("googletest-uninitialized-test_") def Assert(condition): - if not condition: - raise AssertionError + if not condition: + raise AssertionError def AssertEq(expected, actual): - if expected != actual: - print('Expected: %s' % (expected,)) - print(' Actual: %s' % (actual,)) - raise AssertionError + if expected != actual: + print("Expected: %s" % (expected,)) + print(" Actual: %s" % (actual,)) + raise AssertionError def TestExitCodeAndOutput(command): - """Runs the given command and verifies its exit code and output.""" + """Runs the given command and verifies its exit code and output.""" - # Verifies that 'command' exits with code 1. - p = gtest_test_utils.Subprocess(command) - if p.exited and p.exit_code == 0: - Assert('IMPORTANT NOTICE' in p.output) - Assert('InitGoogleTest' in p.output) + # Verifies that 'command' exits with code 1. + p = gtest_test_utils.Subprocess(command) + if p.exited and p.exit_code == 0: + Assert("IMPORTANT NOTICE" in p.output) + Assert("InitGoogleTest" in p.output) class GTestUninitializedTest(gtest_test_utils.TestCase): - def testExitCodeAndOutput(self): - TestExitCodeAndOutput(COMMAND) + def testExitCodeAndOutput(self): + TestExitCodeAndOutput(COMMAND) -if __name__ == '__main__': - gtest_test_utils.Main() +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_help_test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_help_test.py index 38fc90ff..f643a376 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_help_test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_help_test.py @@ -43,141 +43,140 @@ from googletest.test import gtest_test_utils -FREEBSD = ('FreeBSD', 'GNU/kFreeBSD') -NETBSD = ('NetBSD',) -OPENBSD = ('OpenBSD',) +FREEBSD = ("FreeBSD", "GNU/kFreeBSD") +NETBSD = ("NetBSD",) +OPENBSD = ("OpenBSD",) def is_bsd_based_os() -> bool: - """Determine whether or not the OS is BSD-based.""" - if os.name != 'posix': - return False + """Determine whether or not the OS is BSD-based.""" + if os.name != "posix": + return False - return os.uname()[0] in (FREEBSD + NETBSD + OPENBSD) + return os.uname()[0] in (FREEBSD + NETBSD + OPENBSD) -IS_DARWIN = os.name == 'posix' and os.uname()[0] == 'Darwin' -IS_LINUX = os.name == 'posix' and os.uname()[0] == 'Linux' -IS_GNUHURD = os.name == 'posix' and os.uname()[0] == 'GNU' -IS_WINDOWS = os.name == 'nt' +IS_DARWIN = os.name == "posix" and os.uname()[0] == "Darwin" +IS_LINUX = os.name == "posix" and os.uname()[0] == "Linux" +IS_GNUHURD = os.name == "posix" and os.uname()[0] == "GNU" +IS_WINDOWS = os.name == "nt" -PROGRAM_PATH = gtest_test_utils.GetTestExecutablePath('gtest_help_test_') -FLAG_PREFIX = '--gtest_' -DEATH_TEST_STYLE_FLAG = FLAG_PREFIX + 'death_test_style' -STREAM_RESULT_TO_FLAG = FLAG_PREFIX + 'stream_result_to' -LIST_TESTS_FLAG = FLAG_PREFIX + 'list_tests' -INTERNAL_FLAG_FOR_TESTING = FLAG_PREFIX + 'internal_flag_for_testing' +PROGRAM_PATH = gtest_test_utils.GetTestExecutablePath("gtest_help_test_") +FLAG_PREFIX = "--gtest_" +DEATH_TEST_STYLE_FLAG = FLAG_PREFIX + "death_test_style" +STREAM_RESULT_TO_FLAG = FLAG_PREFIX + "stream_result_to" +LIST_TESTS_FLAG = FLAG_PREFIX + "list_tests" +INTERNAL_FLAG_FOR_TESTING = FLAG_PREFIX + "internal_flag_for_testing" SUPPORTS_DEATH_TESTS = ( - 'DeathTest' - in gtest_test_utils.Subprocess([PROGRAM_PATH, LIST_TESTS_FLAG]).output + "DeathTest" in gtest_test_utils.Subprocess([PROGRAM_PATH, LIST_TESTS_FLAG]).output ) -HAS_ABSL_FLAGS = '--has_absl_flags' in sys.argv +HAS_ABSL_FLAGS = "--has_absl_flags" in sys.argv # The help message must match this regex. HELP_REGEX = re.compile( FLAG_PREFIX - + r'list_tests.*' + + r"list_tests.*" + FLAG_PREFIX - + r'filter=.*' + + r"filter=.*" + FLAG_PREFIX - + r'also_run_disabled_tests.*' + + r"also_run_disabled_tests.*" + FLAG_PREFIX - + r'repeat=.*' + + r"repeat=.*" + FLAG_PREFIX - + r'shuffle.*' + + r"shuffle.*" + FLAG_PREFIX - + r'random_seed=.*' + + r"random_seed=.*" + FLAG_PREFIX - + r'color=.*' + + r"color=.*" + FLAG_PREFIX - + r'brief.*' + + r"brief.*" + FLAG_PREFIX - + r'print_time.*' + + r"print_time.*" + FLAG_PREFIX - + r'output=.*' + + r"output=.*" + FLAG_PREFIX - + r'break_on_failure.*' + + r"break_on_failure.*" + FLAG_PREFIX - + r'throw_on_failure.*' + + r"throw_on_failure.*" + FLAG_PREFIX - + r'catch_exceptions=0.*', + + r"catch_exceptions=0.*", re.DOTALL, ) def run_with_flag(flag): - """Runs gtest_help_test_ with the given flag. + """Runs gtest_help_test_ with the given flag. - Returns: - the exit code and the text output as a tuple. - Args: - flag: the command-line flag to pass to gtest_help_test_, or None. - """ + Returns: + the exit code and the text output as a tuple. + Args: + flag: the command-line flag to pass to gtest_help_test_, or None. + """ - if flag is None: - command = [PROGRAM_PATH] - else: - command = [PROGRAM_PATH, flag] - child = gtest_test_utils.Subprocess(command) - return child.exit_code, child.output + if flag is None: + command = [PROGRAM_PATH] + else: + command = [PROGRAM_PATH, flag] + child = gtest_test_utils.Subprocess(command) + return child.exit_code, child.output class GTestHelpTest(gtest_test_utils.TestCase): - """Tests the --help flag and its equivalent forms.""" + """Tests the --help flag and its equivalent forms.""" - def test_prints_help_with_full_flag(self): - """Verifies correct behavior when help flag is specified. + def test_prints_help_with_full_flag(self): + """Verifies correct behavior when help flag is specified. - The right message must be printed and the tests must - skipped when the given flag is specified. - """ + The right message must be printed and the tests must + skipped when the given flag is specified. + """ - exit_code, output = run_with_flag('--help') - if HAS_ABSL_FLAGS: - # The Abseil flags library prints the ProgramUsageMessage() with - # --help and returns 1. - self.assertEqual(1, exit_code) - else: - self.assertEqual(0, exit_code) + exit_code, output = run_with_flag("--help") + if HAS_ABSL_FLAGS: + # The Abseil flags library prints the ProgramUsageMessage() with + # --help and returns 1. + self.assertEqual(1, exit_code) + else: + self.assertEqual(0, exit_code) - self.assertTrue(HELP_REGEX.search(output), output) + self.assertTrue(HELP_REGEX.search(output), output) - if IS_DARWIN or IS_LINUX or IS_GNUHURD or is_bsd_based_os(): - self.assertIn(STREAM_RESULT_TO_FLAG, output) - else: - self.assertNotIn(STREAM_RESULT_TO_FLAG, output) + if IS_DARWIN or IS_LINUX or IS_GNUHURD or is_bsd_based_os(): + self.assertIn(STREAM_RESULT_TO_FLAG, output) + else: + self.assertNotIn(STREAM_RESULT_TO_FLAG, output) - if SUPPORTS_DEATH_TESTS and not IS_WINDOWS: - self.assertIn(DEATH_TEST_STYLE_FLAG, output) - else: - self.assertNotIn(DEATH_TEST_STYLE_FLAG, output) + if SUPPORTS_DEATH_TESTS and not IS_WINDOWS: + self.assertIn(DEATH_TEST_STYLE_FLAG, output) + else: + self.assertNotIn(DEATH_TEST_STYLE_FLAG, output) - def test_runs_tests_without_help_flag(self): - """Verifies correct behavior when no help flag is specified. + def test_runs_tests_without_help_flag(self): + """Verifies correct behavior when no help flag is specified. - Verifies that when no help flag is specified, the tests are run - and the help message is not printed. - """ + Verifies that when no help flag is specified, the tests are run + and the help message is not printed. + """ - exit_code, output = run_with_flag(None) - self.assertNotEqual(exit_code, 0) - self.assertFalse(HELP_REGEX.search(output), output) + exit_code, output = run_with_flag(None) + self.assertNotEqual(exit_code, 0) + self.assertFalse(HELP_REGEX.search(output), output) - def test_runs_tests_with_gtest_internal_flag(self): - """Verifies correct behavior when internal testing flag is specified. + def test_runs_tests_with_gtest_internal_flag(self): + """Verifies correct behavior when internal testing flag is specified. - Verifies that the tests are run and no help message is printed when - a flag starting with Google Test prefix and 'internal_' is supplied. - """ + Verifies that the tests are run and no help message is printed when + a flag starting with Google Test prefix and 'internal_' is supplied. + """ - exit_code, output = run_with_flag(INTERNAL_FLAG_FOR_TESTING) - self.assertNotEqual(exit_code, 0) - self.assertFalse(HELP_REGEX.search(output), output) + exit_code, output = run_with_flag(INTERNAL_FLAG_FOR_TESTING) + self.assertNotEqual(exit_code, 0) + self.assertFalse(HELP_REGEX.search(output), output) -if __name__ == '__main__': - if '--has_absl_flags' in sys.argv: - sys.argv.remove('--has_absl_flags') - gtest_test_utils.Main() +if __name__ == "__main__": + if "--has_absl_flags" in sys.argv: + sys.argv.remove("--has_absl_flags") + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_json_test_utils.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_json_test_utils.py index 694a7a60..d91194cc 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_json_test_utils.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_json_test_utils.py @@ -33,35 +33,35 @@ def normalize(obj): - """Normalize output object. + """Normalize output object. - Args: - obj: Google Test's JSON output object to normalize. + Args: + obj: Google Test's JSON output object to normalize. - Returns: - Normalized output without any references to transient information that may - change from run to run. - """ + Returns: + Normalized output without any references to transient information that may + change from run to run. + """ - def _normalize(key, value): - if key == 'time': - return re.sub(r'^\d+(\.\d+)?s$', '*', value) - elif key == 'timestamp': - return re.sub(r'^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\dZ$', '*', value) - elif key == 'failure': - value = re.sub(r'^.*[/\\](.*:)\d+\n', '\\1*\n', value) - return re.sub(r'Stack trace:\n(.|\n)*', 'Stack trace:\n*', value) - elif key == 'message': - value = re.sub(r'^.*[/\\](.*:)\d+\n', '\\1*\n', value) - return re.sub(r'Stack trace:\n(.|\n)*', 'Stack trace:\n*', value) - elif key == 'file': - return re.sub(r'^.*[/\\](.*)', '\\1', value) - else: - return normalize(value) + def _normalize(key, value): + if key == "time": + return re.sub(r"^\d+(\.\d+)?s$", "*", value) + elif key == "timestamp": + return re.sub(r"^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\dZ$", "*", value) + elif key == "failure": + value = re.sub(r"^.*[/\\](.*:)\d+\n", "\\1*\n", value) + return re.sub(r"Stack trace:\n(.|\n)*", "Stack trace:\n*", value) + elif key == "message": + value = re.sub(r"^.*[/\\](.*:)\d+\n", "\\1*\n", value) + return re.sub(r"Stack trace:\n(.|\n)*", "Stack trace:\n*", value) + elif key == "file": + return re.sub(r"^.*[/\\](.*)", "\\1", value) + else: + return normalize(value) - if isinstance(obj, dict): - return {k: _normalize(k, v) for k, v in obj.items()} - if isinstance(obj, list): - return [normalize(x) for x in obj] - else: - return obj + if isinstance(obj, dict): + return {k: _normalize(k, v) for k, v in obj.items()} + if isinstance(obj, list): + return [normalize(x) for x in obj] + else: + return obj diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_list_output_unittest.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_list_output_unittest.py index afd521d2..fd255e53 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_list_output_unittest.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_list_output_unittest.py @@ -42,8 +42,8 @@ import re from googletest.test import gtest_test_utils -GTEST_LIST_TESTS_FLAG = '--gtest_list_tests' -GTEST_OUTPUT_FLAG = '--gtest_output' +GTEST_LIST_TESTS_FLAG = "--gtest_list_tests" +GTEST_OUTPUT_FLAG = "--gtest_output" EXPECTED_XML = """<\?xml version="1.0" encoding="UTF-8"\?> @@ -224,66 +224,66 @@ class GTestListTestsOutputUnitTest(gtest_test_utils.TestCase): - """Unit test for Google Test's list tests with output to file functionality.""" + """Unit test for Google Test's list tests with output to file functionality.""" - def testXml(self): - """Verifies XML output for listing tests in a Google Test binary. + def testXml(self): + """Verifies XML output for listing tests in a Google Test binary. - Runs a test program that generates an empty XML output, and - tests that the XML output is expected. - """ - self._TestOutput('xml', EXPECTED_XML) + Runs a test program that generates an empty XML output, and + tests that the XML output is expected. + """ + self._TestOutput("xml", EXPECTED_XML) - def testJSON(self): - """Verifies XML output for listing tests in a Google Test binary. + def testJSON(self): + """Verifies XML output for listing tests in a Google Test binary. - Runs a test program that generates an empty XML output, and - tests that the XML output is expected. - """ - self._TestOutput('json', EXPECTED_JSON) + Runs a test program that generates an empty XML output, and + tests that the XML output is expected. + """ + self._TestOutput("json", EXPECTED_JSON) - def _GetOutput(self, out_format): - file_path = os.path.join( - gtest_test_utils.GetTempDir(), 'test_out.' + out_format - ) - gtest_prog_path = gtest_test_utils.GetTestExecutablePath( - 'gtest_list_output_unittest_' - ) + def _GetOutput(self, out_format): + file_path = os.path.join( + gtest_test_utils.GetTempDir(), "test_out." + out_format + ) + gtest_prog_path = gtest_test_utils.GetTestExecutablePath( + "gtest_list_output_unittest_" + ) - command = [ - gtest_prog_path, - '%s=%s:%s' % (GTEST_OUTPUT_FLAG, out_format, file_path), - '--gtest_list_tests', - ] - environ_copy = os.environ.copy() - p = gtest_test_utils.Subprocess( - command, env=environ_copy, working_dir=gtest_test_utils.GetTempDir() - ) + command = [ + gtest_prog_path, + "%s=%s:%s" % (GTEST_OUTPUT_FLAG, out_format, file_path), + "--gtest_list_tests", + ] + environ_copy = os.environ.copy() + p = gtest_test_utils.Subprocess( + command, env=environ_copy, working_dir=gtest_test_utils.GetTempDir() + ) - self.assertTrue(p.exited) - self.assertEqual(0, p.exit_code) - self.assertTrue(os.path.isfile(file_path)) - with open(file_path) as f: - result = f.read() - return result + self.assertTrue(p.exited) + self.assertEqual(0, p.exit_code) + self.assertTrue(os.path.isfile(file_path)) + with open(file_path) as f: + result = f.read() + return result - def _TestOutput(self, test_format, expected_output): - actual = self._GetOutput(test_format) - actual_lines = actual.splitlines() - expected_lines = expected_output.splitlines() - line_count = 0 - for actual_line in actual_lines: - expected_line = expected_lines[line_count] - expected_line_re = re.compile(expected_line.strip()) - self.assertTrue( - expected_line_re.match(actual_line.strip()), - 'actual output of "%s",\n' - 'which does not match expected regex of "%s"\n' - 'on line %d' % (actual, expected_output, line_count), - ) - line_count = line_count + 1 + def _TestOutput(self, test_format, expected_output): + actual = self._GetOutput(test_format) + actual_lines = actual.splitlines() + expected_lines = expected_output.splitlines() + line_count = 0 + for actual_line in actual_lines: + expected_line = expected_lines[line_count] + expected_line_re = re.compile(expected_line.strip()) + self.assertTrue( + expected_line_re.match(actual_line.strip()), + 'actual output of "%s",\n' + 'which does not match expected regex of "%s"\n' + "on line %d" % (actual, expected_output, line_count), + ) + line_count = line_count + 1 -if __name__ == '__main__': - os.environ['GTEST_STACK_TRACE_DEPTH'] = '1' - gtest_test_utils.Main() +if __name__ == "__main__": + os.environ["GTEST_STACK_TRACE_DEPTH"] = "1" + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_skip_check_output_test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_skip_check_output_test.py index b30a1650..d47fe497 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_skip_check_output_test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_skip_check_output_test.py @@ -38,7 +38,7 @@ from googletest.test import gtest_test_utils # Path to the gtest_skip_in_environment_setup_test binary -EXE_PATH = gtest_test_utils.GetTestExecutablePath('gtest_skip_test') +EXE_PATH = gtest_test_utils.GetTestExecutablePath("gtest_skip_test") OUTPUT = gtest_test_utils.Subprocess([EXE_PATH]).output @@ -46,15 +46,15 @@ # Test. class SkipEntireEnvironmentTest(gtest_test_utils.TestCase): - def testSkipEntireEnvironmentTest(self): - self.assertIn('Skipped\nskipping single test\n', OUTPUT) - skip_fixture = 'Skipped\nskipping all tests for this fixture\n' - self.assertIsNotNone( - re.search(skip_fixture + '.*' + skip_fixture, OUTPUT, flags=re.DOTALL), - repr(OUTPUT), - ) - self.assertNotIn('FAILED', OUTPUT) + def testSkipEntireEnvironmentTest(self): + self.assertIn("Skipped\nskipping single test\n", OUTPUT) + skip_fixture = "Skipped\nskipping all tests for this fixture\n" + self.assertIsNotNone( + re.search(skip_fixture + ".*" + skip_fixture, OUTPUT, flags=re.DOTALL), + repr(OUTPUT), + ) + self.assertNotIn("FAILED", OUTPUT) -if __name__ == '__main__': - gtest_test_utils.Main() +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_skip_environment_check_output_test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_skip_environment_check_output_test.py index 388a4e95..7c0cae8d 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_skip_environment_check_output_test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_skip_environment_check_output_test.py @@ -37,7 +37,7 @@ # Path to the gtest_skip_in_environment_setup_test binary EXE_PATH = gtest_test_utils.GetTestExecutablePath( - 'gtest_skip_in_environment_setup_test' + "gtest_skip_in_environment_setup_test" ) OUTPUT = gtest_test_utils.Subprocess([EXE_PATH]).output @@ -46,10 +46,10 @@ # Test. class SkipEntireEnvironmentTest(gtest_test_utils.TestCase): - def testSkipEntireEnvironmentTest(self): - self.assertIn('Skipping the entire environment', OUTPUT) - self.assertNotIn('FAILED', OUTPUT) + def testSkipEntireEnvironmentTest(self): + self.assertIn("Skipping the entire environment", OUTPUT) + self.assertNotIn("FAILED", OUTPUT) -if __name__ == '__main__': - gtest_test_utils.Main() +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_test_utils.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_test_utils.py index 964fa9fa..d21484f0 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_test_utils.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_test_utils.py @@ -35,31 +35,32 @@ import subprocess import sys -IS_WINDOWS = os.name == 'nt' -IS_CYGWIN = os.name == 'posix' and 'CYGWIN' in os.uname()[0] -IS_OS2 = os.name == 'os2' +IS_WINDOWS = os.name == "nt" +IS_CYGWIN = os.name == "posix" and "CYGWIN" in os.uname()[0] +IS_OS2 = os.name == "os2" import atexit import shutil import tempfile import unittest as _test_module + # pylint: enable=g-import-not-at-top -GTEST_OUTPUT_VAR_NAME = 'GTEST_OUTPUT' +GTEST_OUTPUT_VAR_NAME = "GTEST_OUTPUT" # The environment variable for specifying the path to the premature-exit file. -PREMATURE_EXIT_FILE_ENV_VAR = 'TEST_PREMATURE_EXIT_FILE' +PREMATURE_EXIT_FILE_ENV_VAR = "TEST_PREMATURE_EXIT_FILE" environ = os.environ.copy() def SetEnvVar(env_var, value): - """Sets/unsets an environment variable to a given value.""" + """Sets/unsets an environment variable to a given value.""" - if value is not None: - environ[env_var] = value - elif env_var in environ: - del environ[env_var] + if value is not None: + environ[env_var] = value + elif env_var in environ: + del environ[env_var] # Here we expose a class from a particular module, depending on the @@ -69,194 +70,194 @@ def SetEnvVar(env_var, value): # Initially maps a flag to its default value. After # _ParseAndStripGTestFlags() is called, maps a flag to its actual value. _flag_map = { - 'source_dir': os.path.dirname(sys.argv[0]), - 'build_dir': os.path.dirname(sys.argv[0]), + "source_dir": os.path.dirname(sys.argv[0]), + "build_dir": os.path.dirname(sys.argv[0]), } _gtest_flags_are_parsed = False def _ParseAndStripGTestFlags(argv): - """Parses and strips Google Test flags from argv. This is idempotent.""" - - global _gtest_flags_are_parsed - if _gtest_flags_are_parsed: - return - - _gtest_flags_are_parsed = True - for flag in _flag_map: - # The environment variable overrides the default value. - if flag.upper() in os.environ: - _flag_map[flag] = os.environ[flag.upper()] - - # The command line flag overrides the environment variable. - i = 1 # Skips the program name. - while i < len(argv): - prefix = '--' + flag + '=' - if argv[i].startswith(prefix): - _flag_map[flag] = argv[i][len(prefix) :] - del argv[i] - break - else: - # We don't increment i in case we just found a --gtest_* flag - # and removed it from argv. - i += 1 + """Parses and strips Google Test flags from argv. This is idempotent.""" + + global _gtest_flags_are_parsed + if _gtest_flags_are_parsed: + return + + _gtest_flags_are_parsed = True + for flag in _flag_map: + # The environment variable overrides the default value. + if flag.upper() in os.environ: + _flag_map[flag] = os.environ[flag.upper()] + + # The command line flag overrides the environment variable. + i = 1 # Skips the program name. + while i < len(argv): + prefix = "--" + flag + "=" + if argv[i].startswith(prefix): + _flag_map[flag] = argv[i][len(prefix) :] + del argv[i] + break + else: + # We don't increment i in case we just found a --gtest_* flag + # and removed it from argv. + i += 1 def GetFlag(flag): - """Returns the value of the given flag.""" + """Returns the value of the given flag.""" - # In case GetFlag() is called before Main(), we always call - # _ParseAndStripGTestFlags() here to make sure the --gtest_* flags - # are parsed. - _ParseAndStripGTestFlags(sys.argv) + # In case GetFlag() is called before Main(), we always call + # _ParseAndStripGTestFlags() here to make sure the --gtest_* flags + # are parsed. + _ParseAndStripGTestFlags(sys.argv) - return _flag_map[flag] + return _flag_map[flag] def GetSourceDir(): - """Returns the absolute path of the directory where the .py files are.""" + """Returns the absolute path of the directory where the .py files are.""" - return os.path.abspath(GetFlag('source_dir')) + return os.path.abspath(GetFlag("source_dir")) def GetBuildDir(): - """Returns the absolute path of the directory where the test binaries are.""" + """Returns the absolute path of the directory where the test binaries are.""" - return os.path.abspath(GetFlag('build_dir')) + return os.path.abspath(GetFlag("build_dir")) _temp_dir = None + def _RemoveTempDir(): - if _temp_dir: - shutil.rmtree(_temp_dir, ignore_errors=True) + if _temp_dir: + shutil.rmtree(_temp_dir, ignore_errors=True) + atexit.register(_RemoveTempDir) def GetTempDir(): - global _temp_dir - if not _temp_dir: - _temp_dir = tempfile.mkdtemp() - return _temp_dir + global _temp_dir + if not _temp_dir: + _temp_dir = tempfile.mkdtemp() + return _temp_dir def GetTestExecutablePath(executable_name, build_dir=None): - """Returns the absolute path of the test binary given its name. + """Returns the absolute path of the test binary given its name. - The function will print a message and abort the program if the resulting file - doesn't exist. + The function will print a message and abort the program if the resulting file + doesn't exist. - Args: - executable_name: name of the test binary that the test script runs. - build_dir: directory where to look for executables, by default the - result of GetBuildDir(). + Args: + executable_name: name of the test binary that the test script runs. + build_dir: directory where to look for executables, by default the + result of GetBuildDir(). - Returns: - The absolute path of the test binary. - """ + Returns: + The absolute path of the test binary. + """ - path = os.path.abspath( - os.path.join(build_dir or GetBuildDir(), executable_name) - ) - if (IS_WINDOWS or IS_CYGWIN or IS_OS2) and not path.endswith('.exe'): - path += '.exe' + path = os.path.abspath(os.path.join(build_dir or GetBuildDir(), executable_name)) + if (IS_WINDOWS or IS_CYGWIN or IS_OS2) and not path.endswith(".exe"): + path += ".exe" - if not os.path.exists(path): - message = ( - 'Unable to find the test binary "%s". Please make sure to provide\n' - 'a path to the binary via the --build_dir flag or the BUILD_DIR\n' - 'environment variable.' % path - ) - print(message, file=sys.stderr) - sys.exit(1) + if not os.path.exists(path): + message = ( + 'Unable to find the test binary "%s". Please make sure to provide\n' + "a path to the binary via the --build_dir flag or the BUILD_DIR\n" + "environment variable." % path + ) + print(message, file=sys.stderr) + sys.exit(1) - return path + return path def GetExitStatus(exit_code): - """Returns the argument to exit(), or -1 if exit() wasn't called. - - Args: - exit_code: the result value of os.system(command). - """ - - if os.name == 'nt': - # On Windows, os.WEXITSTATUS() doesn't work and os.system() returns - # the argument to exit() directly. - return exit_code - else: - # On Unix, os.WEXITSTATUS() must be used to extract the exit status - # from the result of os.system(). - if os.WIFEXITED(exit_code): - return os.WEXITSTATUS(exit_code) - else: - return -1 - - -class Subprocess: - - def __init__(self, command, working_dir=None, capture_stderr=True, env=None): - """Changes into a specified directory, if provided, and executes a command. - - Restores the old directory afterwards. + """Returns the argument to exit(), or -1 if exit() wasn't called. Args: - command: The command to run, in the form of sys.argv. - working_dir: The directory to change into. - capture_stderr: Determines whether to capture stderr in the output member - or to discard it. - env: Dictionary with environment to pass to the subprocess. - - Returns: - An object that represents outcome of the executed process. It has the - following attributes: - terminated_by_signal True if and only if the child process has been - terminated by a signal. - exited True if and only if the child process exited - normally. - exit_code The code with which the child process exited. - output Child process's stdout and stderr output - combined in a string. + exit_code: the result value of os.system(command). """ - if capture_stderr: - stderr = subprocess.STDOUT - else: - stderr = subprocess.PIPE - - p = subprocess.Popen( - command, - stdout=subprocess.PIPE, - stderr=stderr, - cwd=working_dir, - universal_newlines=True, - env=env, - ) - # communicate returns a tuple with the file object for the child's - # output. - self.output = p.communicate()[0] - self._return_code = p.returncode - - if bool(self._return_code & 0x80000000): - self.terminated_by_signal = True - self.exited = False + if os.name == "nt": + # On Windows, os.WEXITSTATUS() doesn't work and os.system() returns + # the argument to exit() directly. + return exit_code else: - self.terminated_by_signal = False - self.exited = True - self.exit_code = self._return_code + # On Unix, os.WEXITSTATUS() must be used to extract the exit status + # from the result of os.system(). + if os.WIFEXITED(exit_code): + return os.WEXITSTATUS(exit_code) + else: + return -1 + + +class Subprocess: + + def __init__(self, command, working_dir=None, capture_stderr=True, env=None): + """Changes into a specified directory, if provided, and executes a command. + + Restores the old directory afterwards. + + Args: + command: The command to run, in the form of sys.argv. + working_dir: The directory to change into. + capture_stderr: Determines whether to capture stderr in the output member + or to discard it. + env: Dictionary with environment to pass to the subprocess. + + Returns: + An object that represents outcome of the executed process. It has the + following attributes: + terminated_by_signal True if and only if the child process has been + terminated by a signal. + exited True if and only if the child process exited + normally. + exit_code The code with which the child process exited. + output Child process's stdout and stderr output + combined in a string. + """ + + if capture_stderr: + stderr = subprocess.STDOUT + else: + stderr = subprocess.PIPE + + p = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=stderr, + cwd=working_dir, + universal_newlines=True, + env=env, + ) + # communicate returns a tuple with the file object for the child's + # output. + self.output = p.communicate()[0] + self._return_code = p.returncode + + if bool(self._return_code & 0x80000000): + self.terminated_by_signal = True + self.exited = False + else: + self.terminated_by_signal = False + self.exited = True + self.exit_code = self._return_code def Main(): - """Runs the unit test.""" - - # We must call _ParseAndStripGTestFlags() before calling - # unittest.main(). Otherwise the latter will be confused by the - # --gtest_* flags. - _ParseAndStripGTestFlags(sys.argv) - # The tested binaries should not be writing XML output files unless the - # script explicitly instructs them to. - if GTEST_OUTPUT_VAR_NAME in os.environ: - del os.environ[GTEST_OUTPUT_VAR_NAME] - - _test_module.main() + """Runs the unit test.""" + + # We must call _ParseAndStripGTestFlags() before calling + # unittest.main(). Otherwise the latter will be confused by the + # --gtest_* flags. + _ParseAndStripGTestFlags(sys.argv) + # The tested binaries should not be writing XML output files unless the + # script explicitly instructs them to. + if GTEST_OUTPUT_VAR_NAME in os.environ: + del os.environ[GTEST_OUTPUT_VAR_NAME] + + _test_module.main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_testbridge_test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_testbridge_test.py index 0d58758b..bb170174 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_testbridge_test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_testbridge_test.py @@ -33,31 +33,31 @@ from googletest.test import gtest_test_utils -binary_name = 'gtest_testbridge_test_' +binary_name = "gtest_testbridge_test_" COMMAND = gtest_test_utils.GetTestExecutablePath(binary_name) -TESTBRIDGE_NAME = 'TESTBRIDGE_TEST_ONLY' +TESTBRIDGE_NAME = "TESTBRIDGE_TEST_ONLY" def Assert(condition): - if not condition: - raise AssertionError + if not condition: + raise AssertionError class GTestTestFilterTest(gtest_test_utils.TestCase): - def testTestExecutionIsFiltered(self): - """Tests that the test filter is picked up from the testbridge env var.""" - subprocess_env = os.environ.copy() + def testTestExecutionIsFiltered(self): + """Tests that the test filter is picked up from the testbridge env var.""" + subprocess_env = os.environ.copy() - subprocess_env[TESTBRIDGE_NAME] = '*.TestThatSucceeds' - p = gtest_test_utils.Subprocess(COMMAND, env=subprocess_env) + subprocess_env[TESTBRIDGE_NAME] = "*.TestThatSucceeds" + p = gtest_test_utils.Subprocess(COMMAND, env=subprocess_env) - self.assertEqual(0, p.exit_code) + self.assertEqual(0, p.exit_code) - Assert('filter = *.TestThatSucceeds' in p.output) - Assert('[ OK ] TestFilterTest.TestThatSucceeds' in p.output) - Assert('[ PASSED ] 1 test.' in p.output) + Assert("filter = *.TestThatSucceeds" in p.output) + Assert("[ OK ] TestFilterTest.TestThatSucceeds" in p.output) + Assert("[ PASSED ] 1 test." in p.output) -if __name__ == '__main__': - gtest_test_utils.Main() +if __name__ == "__main__": + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_xml_outfiles_test.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_xml_outfiles_test.py index d17cc0c9..9bc483f1 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_xml_outfiles_test.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_xml_outfiles_test.py @@ -79,69 +79,69 @@ class GTestXMLOutFilesTest(gtest_xml_test_utils.GTestXMLTestCase): - """Unit test for Google Test's XML output functionality.""" - - def setUp(self): - # We want the trailing '/' that the last "" provides in os.path.join, for - # telling Google Test to create an output directory instead of a single file - # for xml output. - self.output_dir_ = os.path.join( - gtest_test_utils.GetTempDir(), GTEST_OUTPUT_SUBDIR, "" - ) - self.DeleteFilesAndDir() - - def tearDown(self): - self.DeleteFilesAndDir() - - def DeleteFilesAndDir(self): - try: - os.remove(os.path.join(self.output_dir_, GTEST_OUTPUT_1_TEST + ".xml")) - except os.error: - pass - try: - os.remove(os.path.join(self.output_dir_, GTEST_OUTPUT_2_TEST + ".xml")) - except os.error: - pass - try: - os.rmdir(self.output_dir_) - except os.error: - pass - - def testOutfile1(self): - self._TestOutFile(GTEST_OUTPUT_1_TEST, EXPECTED_XML_1) - - def testOutfile2(self): - self._TestOutFile(GTEST_OUTPUT_2_TEST, EXPECTED_XML_2) - - def _TestOutFile(self, test_name, expected_xml): - gtest_prog_path = gtest_test_utils.GetTestExecutablePath(test_name) - command = [gtest_prog_path, "--gtest_output=xml:%s" % self.output_dir_] - p = gtest_test_utils.Subprocess( - command, working_dir=gtest_test_utils.GetTempDir() - ) - self.assertTrue(p.exited) - self.assertEqual(0, p.exit_code) - - output_file_name1 = test_name + ".xml" - output_file1 = os.path.join(self.output_dir_, output_file_name1) - output_file_name2 = "lt-" + output_file_name1 - output_file2 = os.path.join(self.output_dir_, output_file_name2) - self.assertTrue( - os.path.isfile(output_file1) or os.path.isfile(output_file2), - output_file1, - ) - - expected = minidom.parseString(expected_xml) - if os.path.isfile(output_file1): - actual = minidom.parse(output_file1) - else: - actual = minidom.parse(output_file2) - self.NormalizeXml(actual.documentElement) - self.AssertEquivalentNodes(expected.documentElement, actual.documentElement) - expected.unlink() - actual.unlink() + """Unit test for Google Test's XML output functionality.""" + + def setUp(self): + # We want the trailing '/' that the last "" provides in os.path.join, for + # telling Google Test to create an output directory instead of a single file + # for xml output. + self.output_dir_ = os.path.join( + gtest_test_utils.GetTempDir(), GTEST_OUTPUT_SUBDIR, "" + ) + self.DeleteFilesAndDir() + + def tearDown(self): + self.DeleteFilesAndDir() + + def DeleteFilesAndDir(self): + try: + os.remove(os.path.join(self.output_dir_, GTEST_OUTPUT_1_TEST + ".xml")) + except os.error: + pass + try: + os.remove(os.path.join(self.output_dir_, GTEST_OUTPUT_2_TEST + ".xml")) + except os.error: + pass + try: + os.rmdir(self.output_dir_) + except os.error: + pass + + def testOutfile1(self): + self._TestOutFile(GTEST_OUTPUT_1_TEST, EXPECTED_XML_1) + + def testOutfile2(self): + self._TestOutFile(GTEST_OUTPUT_2_TEST, EXPECTED_XML_2) + + def _TestOutFile(self, test_name, expected_xml): + gtest_prog_path = gtest_test_utils.GetTestExecutablePath(test_name) + command = [gtest_prog_path, "--gtest_output=xml:%s" % self.output_dir_] + p = gtest_test_utils.Subprocess( + command, working_dir=gtest_test_utils.GetTempDir() + ) + self.assertTrue(p.exited) + self.assertEqual(0, p.exit_code) + + output_file_name1 = test_name + ".xml" + output_file1 = os.path.join(self.output_dir_, output_file_name1) + output_file_name2 = "lt-" + output_file_name1 + output_file2 = os.path.join(self.output_dir_, output_file_name2) + self.assertTrue( + os.path.isfile(output_file1) or os.path.isfile(output_file2), + output_file1, + ) + + expected = minidom.parseString(expected_xml) + if os.path.isfile(output_file1): + actual = minidom.parse(output_file1) + else: + actual = minidom.parse(output_file2) + self.NormalizeXml(actual.documentElement) + self.AssertEquivalentNodes(expected.documentElement, actual.documentElement) + expected.unlink() + actual.unlink() if __name__ == "__main__": - os.environ["GTEST_STACK_TRACE_DEPTH"] = "0" - gtest_test_utils.Main() + os.environ["GTEST_STACK_TRACE_DEPTH"] = "0" + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_xml_output_unittest.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_xml_output_unittest.py index 87a7683a..9d612f88 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_xml_output_unittest.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_xml_output_unittest.py @@ -41,30 +41,30 @@ from googletest.test import gtest_test_utils from googletest.test import gtest_xml_test_utils -GTEST_FILTER_FLAG = '--gtest_filter' -GTEST_LIST_TESTS_FLAG = '--gtest_list_tests' -GTEST_OUTPUT_FLAG = '--gtest_output' -GTEST_DEFAULT_OUTPUT_FILE = 'test_detail.xml' -GTEST_PROGRAM_NAME = 'gtest_xml_output_unittest_' +GTEST_FILTER_FLAG = "--gtest_filter" +GTEST_LIST_TESTS_FLAG = "--gtest_list_tests" +GTEST_OUTPUT_FLAG = "--gtest_output" +GTEST_DEFAULT_OUTPUT_FILE = "test_detail.xml" +GTEST_PROGRAM_NAME = "gtest_xml_output_unittest_" # The flag indicating stacktraces are not supported -NO_STACKTRACE_SUPPORT_FLAG = '--no_stacktrace_support' +NO_STACKTRACE_SUPPORT_FLAG = "--no_stacktrace_support" # The environment variables for test sharding. -TOTAL_SHARDS_ENV_VAR = 'GTEST_TOTAL_SHARDS' -SHARD_INDEX_ENV_VAR = 'GTEST_SHARD_INDEX' -SHARD_STATUS_FILE_ENV_VAR = 'GTEST_SHARD_STATUS_FILE' +TOTAL_SHARDS_ENV_VAR = "GTEST_TOTAL_SHARDS" +SHARD_INDEX_ENV_VAR = "GTEST_SHARD_INDEX" +SHARD_STATUS_FILE_ENV_VAR = "GTEST_SHARD_STATUS_FILE" SUPPORTS_STACK_TRACES = NO_STACKTRACE_SUPPORT_FLAG not in sys.argv if SUPPORTS_STACK_TRACES: - STACK_TRACE_TEMPLATE = '\nStack trace:\n*' - STACK_TRACE_ENTITY_TEMPLATE = '' + STACK_TRACE_TEMPLATE = "\nStack trace:\n*" + STACK_TRACE_ENTITY_TEMPLATE = "" else: - STACK_TRACE_TEMPLATE = '\n' - STACK_TRACE_ENTITY_TEMPLATE = ' ' - # unittest.main() can't handle unknown flags - sys.argv.remove(NO_STACKTRACE_SUPPORT_FLAG) + STACK_TRACE_TEMPLATE = "\n" + STACK_TRACE_ENTITY_TEMPLATE = " " + # unittest.main() can't handle unknown flags + sys.argv.remove(NO_STACKTRACE_SUPPORT_FLAG) EXPECTED_NON_EMPTY_XML = """ @@ -222,8 +222,8 @@ """ % { - 'stack': STACK_TRACE_TEMPLATE, - 'stack_entity': STACK_TRACE_ENTITY_TEMPLATE, + "stack": STACK_TRACE_TEMPLATE, + "stack_entity": STACK_TRACE_ENTITY_TEMPLATE, } EXPECTED_FILTERED_TEST_XML = """ @@ -275,14 +275,14 @@ """ % { - 'stack': STACK_TRACE_TEMPLATE, - 'stack_entity': STACK_TRACE_ENTITY_TEMPLATE, + "stack": STACK_TRACE_TEMPLATE, + "stack_entity": STACK_TRACE_ENTITY_TEMPLATE, } GTEST_PROGRAM_PATH = gtest_test_utils.GetTestExecutablePath(GTEST_PROGRAM_NAME) SUPPORTS_TYPED_TESTS = ( - 'TypedTest' + "TypedTest" in gtest_test_utils.Subprocess( [GTEST_PROGRAM_PATH, GTEST_LIST_TESTS_FLAG], capture_stderr=False ).output @@ -290,224 +290,222 @@ class GTestXMLOutputUnitTest(gtest_xml_test_utils.GTestXMLTestCase): - """Unit test for Google Test's XML output functionality.""" - - # This test currently breaks on platforms that do not support typed and - # type-parameterized tests, so we don't run it under them. - if SUPPORTS_TYPED_TESTS: - - def testNonEmptyXmlOutput(self): - """Generates non-empty XML and verifies it matches the expected output. - - Runs a test program that generates a non-empty XML output, and - tests that the XML output is expected. - """ - self._TestXmlOutput(GTEST_PROGRAM_NAME, EXPECTED_NON_EMPTY_XML, 1) - - def testNoTestXmlOutput(self): - """Verifies XML output for a Google Test binary without actual tests. - - Runs a test program that generates an XML output for a binary without tests, - and tests that the XML output is expected. - """ - - self._TestXmlOutput('gtest_no_test_unittest', EXPECTED_NO_TEST_XML, 0) - - def testTimestampValue(self): - """Checks whether the timestamp attribute in the XML output is valid. - - Runs a test program that generates an empty XML output, and checks if - the timestamp attribute in the testsuites tag is valid. - """ - actual = self._GetXmlOutput('gtest_no_test_unittest', [], {}, 0) - date_time_str = actual.documentElement.getAttributeNode('timestamp').value - # datetime.strptime() is only available in Python 2.5+ so we have to - # parse the expected datetime manually. - match = re.match(r'(\d+)-(\d\d)-(\d\d)T(\d\d):(\d\d):(\d\d)', date_time_str) - self.assertTrue( - re.match, 'XML datettime string %s has incorrect format' % date_time_str - ) - date_time_from_xml = datetime.datetime( - year=int(match.group(1)), - month=int(match.group(2)), - day=int(match.group(3)), - hour=int(match.group(4)), - minute=int(match.group(5)), - second=int(match.group(6)), - ) - - time_delta = abs(datetime.datetime.now() - date_time_from_xml) - # timestamp value should be near the current local time - self.assertLess(time_delta, datetime.timedelta(seconds=600)) - actual.unlink() - - def testDefaultOutputFile(self): - """Tests XML file with default name is created when name is not specified. - - Confirms that Google Test produces an XML output file with the expected - default name if no name is explicitly specified. - """ - output_file = os.path.join( - gtest_test_utils.GetTempDir(), GTEST_DEFAULT_OUTPUT_FILE - ) - gtest_prog_path = gtest_test_utils.GetTestExecutablePath( - 'gtest_no_test_unittest' - ) - try: - os.remove(output_file) - except OSError: - e = sys.exc_info()[1] - if e.errno != errno.ENOENT: - raise - - p = gtest_test_utils.Subprocess( - [gtest_prog_path, '%s=xml' % GTEST_OUTPUT_FLAG], - working_dir=gtest_test_utils.GetTempDir(), - ) - self.assertTrue(p.exited) - self.assertEqual(0, p.exit_code) - self.assertTrue(os.path.isfile(output_file)) - - def testSuppressedXmlOutput(self): - """Verifies XML output is suppressed if default listener is shut down. - - Tests that no XML file is generated if the default XML listener is - shut down before RUN_ALL_TESTS is invoked. - """ - - xml_path = os.path.join( - gtest_test_utils.GetTempDir(), GTEST_PROGRAM_NAME + 'out.xml' - ) - if os.path.isfile(xml_path): - os.remove(xml_path) - - command = [ - GTEST_PROGRAM_PATH, - '%s=xml:%s' % (GTEST_OUTPUT_FLAG, xml_path), - '--shut_down_xml', - ] - p = gtest_test_utils.Subprocess(command) - if p.terminated_by_signal: - # p.signal is available only if p.terminated_by_signal is True. - self.assertFalse( - p.terminated_by_signal, - '%s was killed by signal %d' % (GTEST_PROGRAM_NAME, p.signal), - ) - else: - self.assertTrue(p.exited) - self.assertEqual( - 1, - p.exit_code, - "'%s' exited with code %s, which doesn't match " - 'the expected exit code %s.' % (command, p.exit_code, 1), - ) - - self.assertFalse(os.path.isfile(xml_path)) - - def testFilteredTestXmlOutput(self): - """Verifies XML output when a filter is applied. - - Runs a test program that executes only some tests and verifies that - non-selected tests do not show up in the XML output. - """ - - self._TestXmlOutput( - GTEST_PROGRAM_NAME, - EXPECTED_FILTERED_TEST_XML, - 0, - extra_args=['%s=SuccessfulTest.*' % GTEST_FILTER_FLAG], - ) - - def testShardedTestXmlOutput(self): - """Verifies XML output when run using multiple shards. - - Runs a test program that executes only one shard and verifies that tests - from other shards do not show up in the XML output. - """ - - self._TestXmlOutput( - GTEST_PROGRAM_NAME, - EXPECTED_SHARDED_TEST_XML, - 0, - extra_env={SHARD_INDEX_ENV_VAR: '0', TOTAL_SHARDS_ENV_VAR: '10'}, - ) - - def _GetXmlOutput( - self, gtest_prog_name, extra_args, extra_env, expected_exit_code - ): - """Returns the XML output generated by running the program gtest_prog_name. - - Furthermore, the program's exit code must be expected_exit_code. - - Args: - gtest_prog_name: Program to run. - extra_args: Optional arguments to pass to program. - extra_env: Optional environment variables to set. - expected_exit_code: Expected exit code from running gtest_prog_name. - """ - xml_path = os.path.join( - gtest_test_utils.GetTempDir(), gtest_prog_name + 'out.xml' - ) - gtest_prog_path = gtest_test_utils.GetTestExecutablePath(gtest_prog_name) - - command = [ - gtest_prog_path, - '%s=xml:%s' % (GTEST_OUTPUT_FLAG, xml_path), - ] + extra_args - environ_copy = os.environ.copy() - if extra_env: - environ_copy.update(extra_env) - p = gtest_test_utils.Subprocess(command, env=environ_copy) - - if p.terminated_by_signal: - self.assertTrue( - False, '%s was killed by signal %d' % (gtest_prog_name, p.signal) - ) - else: - self.assertTrue(p.exited) - self.assertEqual( - expected_exit_code, - p.exit_code, - "'%s' exited with code %s, which doesn't match " - 'the expected exit code %s.' - % (command, p.exit_code, expected_exit_code), - ) - actual = minidom.parse(xml_path) - return actual - - def _TestXmlOutput( - self, - gtest_prog_name, - expected_xml, - expected_exit_code, - extra_args=None, - extra_env=None, - ): - """Asserts that the XML document matches. - - Asserts that the XML document generated by running the program - gtest_prog_name matches expected_xml, a string containing another - XML document. Furthermore, the program's exit code must be - expected_exit_code. - - Args: - gtest_prog_name: Program to run. - expected_xml: Path to XML document to match. - expected_exit_code: Expected exit code from running gtest_prog_name. - extra_args: Optional arguments to pass to program. - extra_env: Optional environment variables to set. - """ - - actual = self._GetXmlOutput( - gtest_prog_name, extra_args or [], extra_env or {}, expected_exit_code - ) - expected = minidom.parseString(expected_xml) - self.NormalizeXml(actual.documentElement) - self.AssertEquivalentNodes(expected.documentElement, actual.documentElement) - expected.unlink() - actual.unlink() - - -if __name__ == '__main__': - os.environ['GTEST_STACK_TRACE_DEPTH'] = '1' - gtest_test_utils.Main() + """Unit test for Google Test's XML output functionality.""" + + # This test currently breaks on platforms that do not support typed and + # type-parameterized tests, so we don't run it under them. + if SUPPORTS_TYPED_TESTS: + + def testNonEmptyXmlOutput(self): + """Generates non-empty XML and verifies it matches the expected output. + + Runs a test program that generates a non-empty XML output, and + tests that the XML output is expected. + """ + self._TestXmlOutput(GTEST_PROGRAM_NAME, EXPECTED_NON_EMPTY_XML, 1) + + def testNoTestXmlOutput(self): + """Verifies XML output for a Google Test binary without actual tests. + + Runs a test program that generates an XML output for a binary without tests, + and tests that the XML output is expected. + """ + + self._TestXmlOutput("gtest_no_test_unittest", EXPECTED_NO_TEST_XML, 0) + + def testTimestampValue(self): + """Checks whether the timestamp attribute in the XML output is valid. + + Runs a test program that generates an empty XML output, and checks if + the timestamp attribute in the testsuites tag is valid. + """ + actual = self._GetXmlOutput("gtest_no_test_unittest", [], {}, 0) + date_time_str = actual.documentElement.getAttributeNode("timestamp").value + # datetime.strptime() is only available in Python 2.5+ so we have to + # parse the expected datetime manually. + match = re.match(r"(\d+)-(\d\d)-(\d\d)T(\d\d):(\d\d):(\d\d)", date_time_str) + self.assertTrue( + re.match, "XML datettime string %s has incorrect format" % date_time_str + ) + date_time_from_xml = datetime.datetime( + year=int(match.group(1)), + month=int(match.group(2)), + day=int(match.group(3)), + hour=int(match.group(4)), + minute=int(match.group(5)), + second=int(match.group(6)), + ) + + time_delta = abs(datetime.datetime.now() - date_time_from_xml) + # timestamp value should be near the current local time + self.assertLess(time_delta, datetime.timedelta(seconds=600)) + actual.unlink() + + def testDefaultOutputFile(self): + """Tests XML file with default name is created when name is not specified. + + Confirms that Google Test produces an XML output file with the expected + default name if no name is explicitly specified. + """ + output_file = os.path.join( + gtest_test_utils.GetTempDir(), GTEST_DEFAULT_OUTPUT_FILE + ) + gtest_prog_path = gtest_test_utils.GetTestExecutablePath( + "gtest_no_test_unittest" + ) + try: + os.remove(output_file) + except OSError: + e = sys.exc_info()[1] + if e.errno != errno.ENOENT: + raise + + p = gtest_test_utils.Subprocess( + [gtest_prog_path, "%s=xml" % GTEST_OUTPUT_FLAG], + working_dir=gtest_test_utils.GetTempDir(), + ) + self.assertTrue(p.exited) + self.assertEqual(0, p.exit_code) + self.assertTrue(os.path.isfile(output_file)) + + def testSuppressedXmlOutput(self): + """Verifies XML output is suppressed if default listener is shut down. + + Tests that no XML file is generated if the default XML listener is + shut down before RUN_ALL_TESTS is invoked. + """ + + xml_path = os.path.join( + gtest_test_utils.GetTempDir(), GTEST_PROGRAM_NAME + "out.xml" + ) + if os.path.isfile(xml_path): + os.remove(xml_path) + + command = [ + GTEST_PROGRAM_PATH, + "%s=xml:%s" % (GTEST_OUTPUT_FLAG, xml_path), + "--shut_down_xml", + ] + p = gtest_test_utils.Subprocess(command) + if p.terminated_by_signal: + # p.signal is available only if p.terminated_by_signal is True. + self.assertFalse( + p.terminated_by_signal, + "%s was killed by signal %d" % (GTEST_PROGRAM_NAME, p.signal), + ) + else: + self.assertTrue(p.exited) + self.assertEqual( + 1, + p.exit_code, + "'%s' exited with code %s, which doesn't match " + "the expected exit code %s." % (command, p.exit_code, 1), + ) + + self.assertFalse(os.path.isfile(xml_path)) + + def testFilteredTestXmlOutput(self): + """Verifies XML output when a filter is applied. + + Runs a test program that executes only some tests and verifies that + non-selected tests do not show up in the XML output. + """ + + self._TestXmlOutput( + GTEST_PROGRAM_NAME, + EXPECTED_FILTERED_TEST_XML, + 0, + extra_args=["%s=SuccessfulTest.*" % GTEST_FILTER_FLAG], + ) + + def testShardedTestXmlOutput(self): + """Verifies XML output when run using multiple shards. + + Runs a test program that executes only one shard and verifies that tests + from other shards do not show up in the XML output. + """ + + self._TestXmlOutput( + GTEST_PROGRAM_NAME, + EXPECTED_SHARDED_TEST_XML, + 0, + extra_env={SHARD_INDEX_ENV_VAR: "0", TOTAL_SHARDS_ENV_VAR: "10"}, + ) + + def _GetXmlOutput(self, gtest_prog_name, extra_args, extra_env, expected_exit_code): + """Returns the XML output generated by running the program gtest_prog_name. + + Furthermore, the program's exit code must be expected_exit_code. + + Args: + gtest_prog_name: Program to run. + extra_args: Optional arguments to pass to program. + extra_env: Optional environment variables to set. + expected_exit_code: Expected exit code from running gtest_prog_name. + """ + xml_path = os.path.join( + gtest_test_utils.GetTempDir(), gtest_prog_name + "out.xml" + ) + gtest_prog_path = gtest_test_utils.GetTestExecutablePath(gtest_prog_name) + + command = [ + gtest_prog_path, + "%s=xml:%s" % (GTEST_OUTPUT_FLAG, xml_path), + ] + extra_args + environ_copy = os.environ.copy() + if extra_env: + environ_copy.update(extra_env) + p = gtest_test_utils.Subprocess(command, env=environ_copy) + + if p.terminated_by_signal: + self.assertTrue( + False, "%s was killed by signal %d" % (gtest_prog_name, p.signal) + ) + else: + self.assertTrue(p.exited) + self.assertEqual( + expected_exit_code, + p.exit_code, + "'%s' exited with code %s, which doesn't match " + "the expected exit code %s." + % (command, p.exit_code, expected_exit_code), + ) + actual = minidom.parse(xml_path) + return actual + + def _TestXmlOutput( + self, + gtest_prog_name, + expected_xml, + expected_exit_code, + extra_args=None, + extra_env=None, + ): + """Asserts that the XML document matches. + + Asserts that the XML document generated by running the program + gtest_prog_name matches expected_xml, a string containing another + XML document. Furthermore, the program's exit code must be + expected_exit_code. + + Args: + gtest_prog_name: Program to run. + expected_xml: Path to XML document to match. + expected_exit_code: Expected exit code from running gtest_prog_name. + extra_args: Optional arguments to pass to program. + extra_env: Optional environment variables to set. + """ + + actual = self._GetXmlOutput( + gtest_prog_name, extra_args or [], extra_env or {}, expected_exit_code + ) + expected = minidom.parseString(expected_xml) + self.NormalizeXml(actual.documentElement) + self.AssertEquivalentNodes(expected.documentElement, actual.documentElement) + expected.unlink() + actual.unlink() + + +if __name__ == "__main__": + os.environ["GTEST_STACK_TRACE_DEPTH"] = "1" + gtest_test_utils.Main() diff --git a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_xml_test_utils.py b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_xml_test_utils.py index 74e0f4a0..e9dd589a 100644 --- a/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_xml_test_utils.py +++ b/smallthinker/powerinfer/libaz/external/googletest/googletest/test/gtest_xml_test_utils.py @@ -33,210 +33,210 @@ from xml.dom import minidom, Node from googletest.test import gtest_test_utils -GTEST_DEFAULT_OUTPUT_FILE = 'test_detail.xml' +GTEST_DEFAULT_OUTPUT_FILE = "test_detail.xml" class GTestXMLTestCase(gtest_test_utils.TestCase): - """Base class for tests of Google Test's XML output functionality.""" - - def AssertEquivalentNodes(self, expected_node, actual_node): - """Asserts that actual_node is equivalent to expected_node. - - Asserts that actual_node (a DOM node object) is equivalent to - expected_node (another DOM node object), in that either both of - them are CDATA nodes and have the same value, or both are DOM - elements and actual_node meets all of the following conditions: - - * It has the same tag name as expected_node. - * It has the same set of attributes as expected_node, each with - the same value as the corresponding attribute of expected_node. - Exceptions are any attribute named "time", which needs only be - convertible to a floating-point number and any attribute named - "type_param" which only has to be non-empty. - * It has an equivalent set of child nodes (including elements and - CDATA sections) as expected_node. Note that we ignore the - order of the children as they are not guaranteed to be in any - particular order. - - Args: - expected_node: expected DOM node object - actual_node: actual DOM node object - """ - - if expected_node.nodeType == Node.CDATA_SECTION_NODE: - self.assertEqual(Node.CDATA_SECTION_NODE, actual_node.nodeType) - self.assertEqual(expected_node.nodeValue, actual_node.nodeValue) - return - - self.assertEqual(Node.ELEMENT_NODE, actual_node.nodeType) - self.assertEqual(Node.ELEMENT_NODE, expected_node.nodeType) - self.assertEqual(expected_node.tagName, actual_node.tagName) - - expected_attributes = expected_node.attributes - actual_attributes = actual_node.attributes - self.assertEqual( - expected_attributes.length, - actual_attributes.length, - 'attribute numbers differ in element %s:\nExpected: %r\nActual: %r' - % ( - actual_node.tagName, - expected_attributes.keys(), - actual_attributes.keys(), - ), - ) - for i in range(expected_attributes.length): - expected_attr = expected_attributes.item(i) - actual_attr = actual_attributes.get(expected_attr.name) - self.assertTrue( - actual_attr is not None, - 'expected attribute %s not found in element %s' - % (expected_attr.name, actual_node.tagName), - ) - self.assertEqual( - expected_attr.value, - actual_attr.value, - ' values of attribute %s in element %s differ: %s vs %s' - % ( - expected_attr.name, - actual_node.tagName, - expected_attr.value, - actual_attr.value, - ), - ) - - expected_children = self._GetChildren(expected_node) - actual_children = self._GetChildren(actual_node) - self.assertEqual( - len(expected_children), - len(actual_children), - 'number of child elements differ in element ' + actual_node.tagName, - ) - for child_id, child in expected_children.items(): - self.assertTrue( - child_id in actual_children, - '<%s> is not in <%s> (in element %s)' - % (child_id, actual_children, actual_node.tagName), - ) - self.AssertEquivalentNodes(child, actual_children[child_id]) - - identifying_attribute = { - 'testsuites': 'name', - 'testsuite': 'name', - 'testcase': 'name', - 'failure': 'message', - 'skipped': 'message', - 'property': 'name', - } - - def _GetChildren(self, element): - """Fetches all of the child nodes of element, a DOM Element object. - - Returns them as the values of a dictionary keyed by the IDs of the children. - For , , , and elements, the ID - is the value of their "name" attribute; for elements, it is the - value of the "message" attribute; for elements, it is the value - of their parent's "name" attribute plus the literal string "properties"; - CDATA sections and non-whitespace text nodes are concatenated into a single - CDATA section with ID "detail". An exception is raised if any element other - than the above four is encountered, if two child elements with the same - identifying attributes are encountered, or if any other type of node is - encountered. - - Args: - element: DOM Element object - - Returns: - Dictionary where keys are the IDs of the children. - """ - - children = {} - for child in element.childNodes: - if child.nodeType == Node.ELEMENT_NODE: - if child.tagName == 'properties': - self.assertTrue( - child.parentNode is not None, - 'Encountered element without a parent', - ) - child_id = child.parentNode.getAttribute('name') + '-properties' - else: - self.assertTrue( - child.tagName in self.identifying_attribute, - 'Encountered unknown element <%s>' % child.tagName, - ) - child_id = child.getAttribute( - self.identifying_attribute[child.tagName] - ) - self.assertNotIn(child_id, children) - children[child_id] = child - elif child.nodeType in [Node.TEXT_NODE, Node.CDATA_SECTION_NODE]: - if 'detail' not in children: - if ( - child.nodeType == Node.CDATA_SECTION_NODE - or not child.nodeValue.isspace() - ): - children['detail'] = child.ownerDocument.createCDATASection( - child.nodeValue + """Base class for tests of Google Test's XML output functionality.""" + + def AssertEquivalentNodes(self, expected_node, actual_node): + """Asserts that actual_node is equivalent to expected_node. + + Asserts that actual_node (a DOM node object) is equivalent to + expected_node (another DOM node object), in that either both of + them are CDATA nodes and have the same value, or both are DOM + elements and actual_node meets all of the following conditions: + + * It has the same tag name as expected_node. + * It has the same set of attributes as expected_node, each with + the same value as the corresponding attribute of expected_node. + Exceptions are any attribute named "time", which needs only be + convertible to a floating-point number and any attribute named + "type_param" which only has to be non-empty. + * It has an equivalent set of child nodes (including elements and + CDATA sections) as expected_node. Note that we ignore the + order of the children as they are not guaranteed to be in any + particular order. + + Args: + expected_node: expected DOM node object + actual_node: actual DOM node object + """ + + if expected_node.nodeType == Node.CDATA_SECTION_NODE: + self.assertEqual(Node.CDATA_SECTION_NODE, actual_node.nodeType) + self.assertEqual(expected_node.nodeValue, actual_node.nodeValue) + return + + self.assertEqual(Node.ELEMENT_NODE, actual_node.nodeType) + self.assertEqual(Node.ELEMENT_NODE, expected_node.nodeType) + self.assertEqual(expected_node.tagName, actual_node.tagName) + + expected_attributes = expected_node.attributes + actual_attributes = actual_node.attributes + self.assertEqual( + expected_attributes.length, + actual_attributes.length, + "attribute numbers differ in element %s:\nExpected: %r\nActual: %r" + % ( + actual_node.tagName, + expected_attributes.keys(), + actual_attributes.keys(), + ), + ) + for i in range(expected_attributes.length): + expected_attr = expected_attributes.item(i) + actual_attr = actual_attributes.get(expected_attr.name) + self.assertTrue( + actual_attr is not None, + "expected attribute %s not found in element %s" + % (expected_attr.name, actual_node.tagName), ) - else: - children['detail'].nodeValue += child.nodeValue - else: - self.fail('Encountered unexpected node type %d' % child.nodeType) - return children - - def NormalizeXml(self, element): - """Normalizes XML that may change from run to run. - - Normalizes Google Test's XML output to eliminate references to transient - information that may change from run to run. - - * The "time" attribute of , and - elements is replaced with a single asterisk, if it contains - only digit characters. - * The "timestamp" attribute of elements is replaced with a - single asterisk, if it contains a valid ISO8601 datetime value. - * The "type_param" attribute of elements is replaced with a - single asterisk (if it sn non-empty) as it is the type name returned - by the compiler and is platform dependent. - * The line info reported in the first line of the "message" - attribute and CDATA section of elements is replaced with the - file's basename and a single asterisk for the line number. - * The directory names in file paths are removed. - * The stack traces are removed. - - Args: - element: DOM element to normalize - """ - - if element.tagName == 'testcase': - source_file = element.getAttributeNode('file') - if source_file: - source_file.value = re.sub(r'^.*[/\\](.*)', '\\1', source_file.value) - if element.tagName in ('testsuites', 'testsuite', 'testcase'): - timestamp = element.getAttributeNode('timestamp') - timestamp.value = re.sub( - r'^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\d\.\d\d\d$', '*', timestamp.value - ) - if element.tagName in ('testsuites', 'testsuite', 'testcase'): - time = element.getAttributeNode('time') - # The value for exact N seconds has a trailing decimal point (e.g., "10." - # instead of "10") - time.value = re.sub(r'^\d+\.(\d+)?$', '*', time.value) - type_param = element.getAttributeNode('type_param') - if type_param and type_param.value: - type_param.value = '*' - elif element.tagName == 'failure' or element.tagName == 'skipped': - source_line_pat = r'^.*[/\\](.*:)\d+\n' - # Replaces the source line information with a normalized form. - message = element.getAttributeNode('message') - message.value = re.sub(source_line_pat, '\\1*\n', message.value) - for child in element.childNodes: - if child.nodeType == Node.CDATA_SECTION_NODE: - # Replaces the source line information with a normalized form. - cdata = re.sub(source_line_pat, '\\1*\n', child.nodeValue) - # Removes the actual stack trace. - child.nodeValue = re.sub( - r'Stack trace:\n(.|\n)*', 'Stack trace:\n*', cdata - ) - for child in element.childNodes: - if child.nodeType == Node.ELEMENT_NODE: - self.NormalizeXml(child) + self.assertEqual( + expected_attr.value, + actual_attr.value, + " values of attribute %s in element %s differ: %s vs %s" + % ( + expected_attr.name, + actual_node.tagName, + expected_attr.value, + actual_attr.value, + ), + ) + + expected_children = self._GetChildren(expected_node) + actual_children = self._GetChildren(actual_node) + self.assertEqual( + len(expected_children), + len(actual_children), + "number of child elements differ in element " + actual_node.tagName, + ) + for child_id, child in expected_children.items(): + self.assertTrue( + child_id in actual_children, + "<%s> is not in <%s> (in element %s)" + % (child_id, actual_children, actual_node.tagName), + ) + self.AssertEquivalentNodes(child, actual_children[child_id]) + + identifying_attribute = { + "testsuites": "name", + "testsuite": "name", + "testcase": "name", + "failure": "message", + "skipped": "message", + "property": "name", + } + + def _GetChildren(self, element): + """Fetches all of the child nodes of element, a DOM Element object. + + Returns them as the values of a dictionary keyed by the IDs of the children. + For , , , and elements, the ID + is the value of their "name" attribute; for elements, it is the + value of the "message" attribute; for elements, it is the value + of their parent's "name" attribute plus the literal string "properties"; + CDATA sections and non-whitespace text nodes are concatenated into a single + CDATA section with ID "detail". An exception is raised if any element other + than the above four is encountered, if two child elements with the same + identifying attributes are encountered, or if any other type of node is + encountered. + + Args: + element: DOM Element object + + Returns: + Dictionary where keys are the IDs of the children. + """ + + children = {} + for child in element.childNodes: + if child.nodeType == Node.ELEMENT_NODE: + if child.tagName == "properties": + self.assertTrue( + child.parentNode is not None, + "Encountered element without a parent", + ) + child_id = child.parentNode.getAttribute("name") + "-properties" + else: + self.assertTrue( + child.tagName in self.identifying_attribute, + "Encountered unknown element <%s>" % child.tagName, + ) + child_id = child.getAttribute( + self.identifying_attribute[child.tagName] + ) + self.assertNotIn(child_id, children) + children[child_id] = child + elif child.nodeType in [Node.TEXT_NODE, Node.CDATA_SECTION_NODE]: + if "detail" not in children: + if ( + child.nodeType == Node.CDATA_SECTION_NODE + or not child.nodeValue.isspace() + ): + children["detail"] = child.ownerDocument.createCDATASection( + child.nodeValue + ) + else: + children["detail"].nodeValue += child.nodeValue + else: + self.fail("Encountered unexpected node type %d" % child.nodeType) + return children + + def NormalizeXml(self, element): + """Normalizes XML that may change from run to run. + + Normalizes Google Test's XML output to eliminate references to transient + information that may change from run to run. + + * The "time" attribute of , and + elements is replaced with a single asterisk, if it contains + only digit characters. + * The "timestamp" attribute of elements is replaced with a + single asterisk, if it contains a valid ISO8601 datetime value. + * The "type_param" attribute of elements is replaced with a + single asterisk (if it sn non-empty) as it is the type name returned + by the compiler and is platform dependent. + * The line info reported in the first line of the "message" + attribute and CDATA section of elements is replaced with the + file's basename and a single asterisk for the line number. + * The directory names in file paths are removed. + * The stack traces are removed. + + Args: + element: DOM element to normalize + """ + + if element.tagName == "testcase": + source_file = element.getAttributeNode("file") + if source_file: + source_file.value = re.sub(r"^.*[/\\](.*)", "\\1", source_file.value) + if element.tagName in ("testsuites", "testsuite", "testcase"): + timestamp = element.getAttributeNode("timestamp") + timestamp.value = re.sub( + r"^\d{4}-\d\d-\d\dT\d\d:\d\d:\d\d\.\d\d\d$", "*", timestamp.value + ) + if element.tagName in ("testsuites", "testsuite", "testcase"): + time = element.getAttributeNode("time") + # The value for exact N seconds has a trailing decimal point (e.g., "10." + # instead of "10") + time.value = re.sub(r"^\d+\.(\d+)?$", "*", time.value) + type_param = element.getAttributeNode("type_param") + if type_param and type_param.value: + type_param.value = "*" + elif element.tagName == "failure" or element.tagName == "skipped": + source_line_pat = r"^.*[/\\](.*:)\d+\n" + # Replaces the source line information with a normalized form. + message = element.getAttributeNode("message") + message.value = re.sub(source_line_pat, "\\1*\n", message.value) + for child in element.childNodes: + if child.nodeType == Node.CDATA_SECTION_NODE: + # Replaces the source line information with a normalized form. + cdata = re.sub(source_line_pat, "\\1*\n", child.nodeValue) + # Removes the actual stack trace. + child.nodeValue = re.sub( + r"Stack trace:\n(.|\n)*", "Stack trace:\n*", cdata + ) + for child in element.childNodes: + if child.nodeType == Node.ELEMENT_NODE: + self.NormalizeXml(child) diff --git a/smallthinker/scripts/compare-llama-bench.py b/smallthinker/scripts/compare-llama-bench.py index a1013c3b..60b3b33b 100755 --- a/smallthinker/scripts/compare-llama-bench.py +++ b/smallthinker/scripts/compare-llama-bench.py @@ -16,56 +16,166 @@ import git from tabulate import tabulate except ImportError as e: - print("the following Python libraries are required: GitPython, tabulate.") # noqa: NP100 + print( + "the following Python libraries are required: GitPython, tabulate." + ) # noqa: NP100 raise e logger = logging.getLogger("compare-llama-bench") # All llama-bench SQL fields DB_FIELDS = [ - "build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename", - "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads", - "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers", - "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides", + "build_commit", + "build_number", + "cpu_info", + "gpu_info", + "backends", + "model_filename", + "model_type", + "model_size", + "model_n_params", + "n_batch", + "n_ubatch", + "n_threads", + "cpu_mask", + "cpu_strict", + "poll", + "type_k", + "type_v", + "n_gpu_layers", + "split_mode", + "main_gpu", + "no_kv_offload", + "flash_attn", + "tensor_split", + "tensor_buft_overrides", "defrag_thold", - "use_mmap", "embeddings", "no_op_offload", "n_prompt", "n_gen", "n_depth", - "test_time", "avg_ns", "stddev_ns", "avg_ts", "stddev_ts", + "use_mmap", + "embeddings", + "no_op_offload", + "n_prompt", + "n_gen", + "n_depth", + "test_time", + "avg_ns", + "stddev_ns", + "avg_ts", + "stddev_ts", ] DB_TYPES = [ - "TEXT", "INTEGER", "TEXT", "TEXT", "TEXT", "TEXT", - "TEXT", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", - "TEXT", "INTEGER", "INTEGER", "TEXT", "TEXT", "INTEGER", - "TEXT", "INTEGER", "INTEGER", "INTEGER", "TEXT", "TEXT", + "TEXT", + "INTEGER", + "TEXT", + "TEXT", + "TEXT", + "TEXT", + "TEXT", + "INTEGER", + "INTEGER", + "INTEGER", + "INTEGER", + "INTEGER", + "TEXT", + "INTEGER", + "INTEGER", + "TEXT", + "TEXT", + "INTEGER", + "TEXT", + "INTEGER", + "INTEGER", + "INTEGER", + "TEXT", + "TEXT", + "REAL", + "INTEGER", + "INTEGER", + "INTEGER", + "INTEGER", + "INTEGER", + "INTEGER", + "TEXT", + "INTEGER", + "INTEGER", + "REAL", "REAL", - "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", "INTEGER", - "TEXT", "INTEGER", "INTEGER", "REAL", "REAL", ] assert len(DB_FIELDS) == len(DB_TYPES) # Properties by which to differentiate results per commit: KEY_PROPERTIES = [ - "cpu_info", "gpu_info", "backends", "n_gpu_layers", "tensor_buft_overrides", "model_filename", "model_type", - "n_batch", "n_ubatch", "embeddings", "cpu_mask", "cpu_strict", "poll", "n_threads", "type_k", "type_v", - "use_mmap", "no_kv_offload", "split_mode", "main_gpu", "tensor_split", "flash_attn", "n_prompt", "n_gen", "n_depth" + "cpu_info", + "gpu_info", + "backends", + "n_gpu_layers", + "tensor_buft_overrides", + "model_filename", + "model_type", + "n_batch", + "n_ubatch", + "embeddings", + "cpu_mask", + "cpu_strict", + "poll", + "n_threads", + "type_k", + "type_v", + "use_mmap", + "no_kv_offload", + "split_mode", + "main_gpu", + "tensor_split", + "flash_attn", + "n_prompt", + "n_gen", + "n_depth", ] # Properties that are boolean and are converted to Yes/No for the table: -BOOL_PROPERTIES = ["embeddings", "cpu_strict", "use_mmap", "no_kv_offload", "flash_attn"] +BOOL_PROPERTIES = [ + "embeddings", + "cpu_strict", + "use_mmap", + "no_kv_offload", + "flash_attn", +] # Header names for the table: PRETTY_NAMES = { - "cpu_info": "CPU", "gpu_info": "GPU", "backends": "Backends", "n_gpu_layers": "GPU layers", - "tensor_buft_overrides": "Tensor overrides", "model_filename": "File", "model_type": "Model", "model_size": "Model size [GiB]", - "model_n_params": "Num. of par.", "n_batch": "Batch size", "n_ubatch": "Microbatch size", "embeddings": "Embeddings", - "cpu_mask": "CPU mask", "cpu_strict": "CPU strict", "poll": "Poll", "n_threads": "Threads", "type_k": "K type", "type_v": "V type", - "use_mmap": "Use mmap", "no_kv_offload": "NKVO", "split_mode": "Split mode", "main_gpu": "Main GPU", "tensor_split": "Tensor split", + "cpu_info": "CPU", + "gpu_info": "GPU", + "backends": "Backends", + "n_gpu_layers": "GPU layers", + "tensor_buft_overrides": "Tensor overrides", + "model_filename": "File", + "model_type": "Model", + "model_size": "Model size [GiB]", + "model_n_params": "Num. of par.", + "n_batch": "Batch size", + "n_ubatch": "Microbatch size", + "embeddings": "Embeddings", + "cpu_mask": "CPU mask", + "cpu_strict": "CPU strict", + "poll": "Poll", + "n_threads": "Threads", + "type_k": "K type", + "type_v": "V type", + "use_mmap": "Use mmap", + "no_kv_offload": "NKVO", + "split_mode": "Split mode", + "main_gpu": "Main GPU", + "tensor_split": "Tensor split", "flash_attn": "FlashAttention", } DEFAULT_SHOW = ["model_type"] # Always show these properties by default. DEFAULT_HIDE = ["model_filename"] # Always hide these properties by default. -GPU_NAME_STRIP = ["NVIDIA GeForce ", "Tesla ", "AMD Radeon "] # Strip prefixes for smaller tables. +GPU_NAME_STRIP = [ + "NVIDIA GeForce ", + "Tesla ", + "AMD Radeon ", +] # Strip prefixes for smaller tables. MODEL_SUFFIX_REPLACE = {" - Small": "_S", " - Medium": "_M", " - Large": "_L"} DESCRIPTION = """Creates tables from llama-bench data written to multiple JSON/CSV files, a single JSONL file or SQLite database. Example usage (Linux): @@ -82,7 +192,8 @@ """ parser = argparse.ArgumentParser( - description=DESCRIPTION, formatter_class=argparse.RawDescriptionHelpFormatter) + description=DESCRIPTION, formatter_class=argparse.RawDescriptionHelpFormatter +) help_b = ( "The baseline commit to compare performance to. " "Accepts either a branch name, tag name, or commit hash. " @@ -119,7 +230,11 @@ "If the columns are manually specified, then the results for each unique combination of the " "specified values are averaged WITHOUT weighing by the --repetitions parameter of llama-bench." ) -parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed") +parser.add_argument( + "--check", + action="store_true", + help="check if all required Python libraries are installed", +) parser.add_argument("-s", "--show", help=help_s) parser.add_argument("--verbose", action="store_true", help="increase output verbosity") @@ -179,11 +294,11 @@ def find_parent_in_data(self, commit: git.Commit) -> Optional[str]: seen_hexsha8 = set() while heap: depth, current_commit = heapq.heappop(heap) - current_hexsha8 = commit.hexsha[:self.build_len] + current_hexsha8 = commit.hexsha[: self.build_len] if current_hexsha8 in self.builds: return current_hexsha8 for parent in commit.parents: - parent_hexsha8 = parent.hexsha[:self.build_len] + parent_hexsha8 = parent.hexsha[: self.build_len] if parent_hexsha8 not in seen_hexsha8: seen_hexsha8.add(parent_hexsha8) heapq.heappush(heap, (depth + 1, parent)) @@ -192,13 +307,13 @@ def find_parent_in_data(self, commit: git.Commit) -> Optional[str]: def get_all_parent_hexsha8s(self, commit: git.Commit) -> Sequence[str]: """Helper method to recursively get hexsha8 values for all parents of a commit.""" unvisited = [commit] - visited = [] + visited = [] while unvisited: current_commit = unvisited.pop(0) - visited.append(current_commit.hexsha[:self.build_len]) + visited.append(current_commit.hexsha[: self.build_len]) for parent in current_commit.parents: - if parent.hexsha[:self.build_len] not in visited: + if parent.hexsha[: self.build_len] not in visited: unvisited.append(parent) return visited @@ -208,10 +323,10 @@ def get_commit_name(self, hexsha8: str) -> str: if self.repo is None: return hexsha8 for h in self.repo.heads: - if h.commit.hexsha[:self.build_len] == hexsha8: + if h.commit.hexsha[: self.build_len] == hexsha8: return h.name for t in self.repo.tags: - if t.commit.hexsha[:self.build_len] == hexsha8: + if t.commit.hexsha[: self.build_len] == hexsha8: return t.name return hexsha8 @@ -221,20 +336,24 @@ def get_commit_hexsha8(self, name: str) -> Optional[str]: return None for h in self.repo.heads: if h.name == name: - return h.commit.hexsha[:self.build_len] + return h.commit.hexsha[: self.build_len] for t in self.repo.tags: if t.name == name: - return t.commit.hexsha[:self.build_len] + return t.commit.hexsha[: self.build_len] for c in self.repo.iter_commits("--all"): - if c.hexsha[:self.build_len] == name[:self.build_len]: - return c.hexsha[:self.build_len] + if c.hexsha[: self.build_len] == name[: self.build_len]: + return c.hexsha[: self.build_len] return None - def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]: + def builds_timestamp( + self, reverse: bool = False + ) -> Union[Iterator[tuple], Sequence[tuple]]: """Helper method that gets rows of (build_commit, test_time) sorted by the latter.""" return [] - def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]: + def get_rows( + self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str + ) -> Sequence[tuple]: """ Helper method that gets table rows for some list of properties. Rows are created by combining those where all provided properties are equal. @@ -252,37 +371,71 @@ def __init__(self): super().__init__() self.connection = sqlite3.connect(":memory:") self.cursor = self.connection.cursor() - self.cursor.execute(f"CREATE TABLE test({', '.join(' '.join(x) for x in zip(DB_FIELDS, DB_TYPES))});") + self.cursor.execute( + f"CREATE TABLE test({', '.join(' '.join(x) for x in zip(DB_FIELDS, DB_TYPES))});" + ) def _builds_init(self): if self.connection: - self.build_len_min = self.cursor.execute("SELECT MIN(LENGTH(build_commit)) from test;").fetchone()[0] - self.build_len_max = self.cursor.execute("SELECT MAX(LENGTH(build_commit)) from test;").fetchone()[0] + self.build_len_min = self.cursor.execute( + "SELECT MIN(LENGTH(build_commit)) from test;" + ).fetchone()[0] + self.build_len_max = self.cursor.execute( + "SELECT MAX(LENGTH(build_commit)) from test;" + ).fetchone()[0] if self.build_len_min != self.build_len_max: - logger.warning("Data contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. " - "Try purging the the database of old commits.") - self.cursor.execute(f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});") - - builds = self.cursor.execute("SELECT DISTINCT build_commit FROM test;").fetchall() - self.builds = list(map(lambda b: b[0], builds)) # list[tuple[str]] -> list[str] + logger.warning( + "Data contains commit hashes of differing lengths. It's possible that the wrong commits will be compared. " + "Try purging the the database of old commits." + ) + self.cursor.execute( + f"UPDATE test SET build_commit = SUBSTRING(build_commit, 1, {self.build_len_min});" + ) + + builds = self.cursor.execute( + "SELECT DISTINCT build_commit FROM test;" + ).fetchall() + self.builds = list( + map(lambda b: b[0], builds) + ) # list[tuple[str]] -> list[str] super()._builds_init() - def builds_timestamp(self, reverse: bool = False) -> Union[Iterator[tuple], Sequence[tuple]]: + def builds_timestamp( + self, reverse: bool = False + ) -> Union[Iterator[tuple], Sequence[tuple]]: data = self.cursor.execute( - "SELECT build_commit, test_time FROM test ORDER BY test_time;").fetchall() + "SELECT build_commit, test_time FROM test ORDER BY test_time;" + ).fetchall() return reversed(data) if reverse else data - def get_rows(self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str) -> Sequence[tuple]: + def get_rows( + self, properties: list[str], hexsha8_baseline: str, hexsha8_compare: str + ) -> Sequence[tuple]: select_string = ", ".join( - [f"tb.{p}" for p in properties] + ["tb.n_prompt", "tb.n_gen", "tb.n_depth", "AVG(tb.avg_ts)", "AVG(tc.avg_ts)"]) + [f"tb.{p}" for p in properties] + + [ + "tb.n_prompt", + "tb.n_gen", + "tb.n_depth", + "AVG(tb.avg_ts)", + "AVG(tc.avg_ts)", + ] + ) equal_string = " AND ".join( - [f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + [ - f"tb.build_commit = '{hexsha8_baseline}'", f"tc.build_commit = '{hexsha8_compare}'"] + [f"tb.{p} = tc.{p}" for p in KEY_PROPERTIES] + + [ + f"tb.build_commit = '{hexsha8_baseline}'", + f"tc.build_commit = '{hexsha8_compare}'", + ] + ) + group_order_string = ", ".join( + [f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"] + ) + query = ( + f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} " + f"GROUP BY {group_order_string} ORDER BY {group_order_string};" ) - group_order_string = ", ".join([f"tb.{p}" for p in properties] + ["tb.n_gen", "tb.n_prompt", "tb.n_depth"]) - query = (f"SELECT {select_string} FROM test tb JOIN test tc ON {equal_string} " - f"GROUP BY {group_order_string} ORDER BY {group_order_string};") return self.cursor.execute(query).fetchall() @@ -302,7 +455,9 @@ def valid_format(data_file: str) -> bool: try: if cursor.execute("PRAGMA schema_version;").fetchone()[0] == 0: - raise sqlite3.DatabaseError("The provided input file does not exist or is empty.") + raise sqlite3.DatabaseError( + "The provided input file does not exist or is empty." + ) except sqlite3.DatabaseError as e: logger.debug(f'"{data_file}" is not a valid SQLite3 file.', exc_info=e) cursor = None @@ -322,10 +477,15 @@ def __init__(self, data_file: str): for k in parsed.keys() - set(DB_FIELDS): del parsed[k] - if (missing_keys := self._check_keys(parsed.keys())): - raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}") + if missing_keys := self._check_keys(parsed.keys()): + raise RuntimeError( + f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}" + ) - self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values())) + self.cursor.execute( + f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", + tuple(parsed.values()), + ) self._builds_init() @@ -355,10 +515,15 @@ def __init__(self, data_files: list[str]): for k in entry.keys() - set(DB_FIELDS): del entry[k] - if (missing_keys := self._check_keys(entry.keys())): - raise RuntimeError(f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}") + if missing_keys := self._check_keys(entry.keys()): + raise RuntimeError( + f"Missing required data key(s) at entry {i + 1}: {', '.join(missing_keys)}" + ) - self.cursor.execute(f"INSERT INTO test({', '.join(entry.keys())}) VALUES({', '.join('?' * len(entry))});", tuple(entry.values())) + self.cursor.execute( + f"INSERT INTO test({', '.join(entry.keys())}) VALUES({', '.join('?' * len(entry))});", + tuple(entry.values()), + ) self._builds_init() @@ -390,10 +555,15 @@ def __init__(self, data_files: list[str]): for k in keys - set(DB_FIELDS): del parsed[k] - if (missing_keys := self._check_keys(keys)): - raise RuntimeError(f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}") + if missing_keys := self._check_keys(keys): + raise RuntimeError( + f"Missing required data key(s) at line {i + 1}: {', '.join(missing_keys)}" + ) - self.cursor.execute(f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", tuple(parsed.values())) + self.cursor.execute( + f"INSERT INTO test({', '.join(parsed.keys())}) VALUES({', '.join('?' * len(parsed))});", + tuple(parsed.values()), + ) self._builds_init() @@ -451,15 +621,21 @@ def valid_format(data_files: list[str]) -> bool: sys.exit(1) # Otherwise, search for the most recent parent of master for which there is data: elif bench_data.repo is not None: - hexsha8_baseline = bench_data.find_parent_in_data(bench_data.repo.heads.master.commit) + hexsha8_baseline = bench_data.find_parent_in_data( + bench_data.repo.heads.master.commit + ) if hexsha8_baseline is None: - logger.error("No baseline was provided and did not find data for any master branch commits.\n") + logger.error( + "No baseline was provided and did not find data for any master branch commits.\n" + ) parser.print_help() sys.exit(1) else: - logger.error("No baseline was provided and the current working directory " - "is not part of a git repository from which a baseline could be inferred.\n") + logger.error( + "No baseline was provided and the current working directory " + "is not part of a git repository from which a baseline could be inferred.\n" + ) parser.print_help() sys.exit(1) @@ -481,19 +657,25 @@ def valid_format(data_files: list[str]) -> bool: # Otherwise, search for the commit for llama-bench was most recently run # and that is not a parent of master: elif bench_data.repo is not None: - hexsha8s_master = bench_data.get_all_parent_hexsha8s(bench_data.repo.heads.master.commit) - for (hexsha8, _) in bench_data.builds_timestamp(reverse=True): + hexsha8s_master = bench_data.get_all_parent_hexsha8s( + bench_data.repo.heads.master.commit + ) + for hexsha8, _ in bench_data.builds_timestamp(reverse=True): if hexsha8 not in hexsha8s_master: hexsha8_compare = hexsha8 break if hexsha8_compare is None: - logger.error("No compare target was provided and did not find data for any non-master commits.\n") + logger.error( + "No compare target was provided and did not find data for any non-master commits.\n" + ) parser.print_help() sys.exit(1) else: - logger.error("No compare target was provided and the current working directory " - "is not part of a git repository from which a compare target could be inferred.\n") + logger.error( + "No compare target was provided and the current working directory " + "is not part of a git repository from which a compare target could be inferred.\n" + ) parser.print_help() sys.exit(1) @@ -505,7 +687,9 @@ def valid_format(data_files: list[str]) -> bool: show = known_args.show.split(",") unknown_cols = [] for prop in show: - if prop not in KEY_PROPERTIES[:-3]: # Last three values are n_prompt, n_gen, n_depth. + if ( + prop not in KEY_PROPERTIES[:-3] + ): # Last three values are n_prompt, n_gen, n_depth. unknown_cols.append(prop) if unknown_cols: logger.error(f"Unknown values for --show: {', '.join(unknown_cols)}") @@ -547,14 +731,16 @@ def valid_format(data_files: list[str]) -> bool: rows_show = bench_data.get_rows(show, hexsha8_baseline, hexsha8_compare) if not rows_show: - logger.error(f"No comparable data was found between {name_baseline} and {name_compare}.\n") + logger.error( + f"No comparable data was found between {name_baseline} and {name_compare}.\n" + ) sys.exit(1) table = [] for row in rows_show: n_prompt = int(row[-5]) - n_gen = int(row[-4]) - n_depth = int(row[-3]) + n_gen = int(row[-4]) + n_depth = int(row[-3]) if n_prompt != 0 and n_gen == 0: test_name = f"pp{n_prompt}" elif n_prompt == 0 and n_gen != 0: @@ -565,7 +751,12 @@ def valid_format(data_files: list[str]) -> bool: test_name = f"{test_name}@d{n_depth}" # Regular columns test name avg t/s values Speedup # VVVVVVVVVVVVV VVVVVVVVV VVVVVVVVVVVVVV VVVVVVV - table.append(list(row[:-5]) + [test_name] + list(row[-2:]) + [float(row[-1]) / float(row[-2])]) + table.append( + list(row[:-5]) + + [test_name] + + list(row[-2:]) + + [float(row[-1]) / float(row[-2])] + ) # Some a-posteriori fixes to make the table contents prettier: for bool_property in BOOL_PROPERTIES: @@ -576,14 +767,14 @@ def valid_format(data_files: list[str]) -> bool: if "model_type" in show: ip = show.index("model_type") - for (old, new) in MODEL_SUFFIX_REPLACE.items(): + for old, new in MODEL_SUFFIX_REPLACE.items(): for row_table in table: row_table[ip] = row_table[ip].replace(old, new) if "model_size" in show: ip = show.index("model_size") for row_table in table: - row_table[ip] = float(row_table[ip]) / 1024 ** 3 + row_table[ip] = float(row_table[ip]) / 1024**3 if "gpu_info" in show: ip = show.index("gpu_info") @@ -597,12 +788,11 @@ def valid_format(data_files: list[str]) -> bool: if len(gpu_names) >= 2 and all_names_the_same: row_table[ip] = f"{num_gpus}x {gpu_names[0]}" -headers = [PRETTY_NAMES[p] for p in show] +headers = [PRETTY_NAMES[p] for p in show] headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"] -print(tabulate( # noqa: NP100 - table, - headers=headers, - floatfmt=".2f", - tablefmt=known_args.output -)) +print( + tabulate( # noqa: NP100 + table, headers=headers, floatfmt=".2f", tablefmt=known_args.output + ) +) diff --git a/smallthinker/scripts/fetch_server_test_models.py b/smallthinker/scripts/fetch_server_test_models.py index ac483ef5..321e91c9 100755 --- a/smallthinker/scripts/fetch_server_test_models.py +++ b/smallthinker/scripts/fetch_server_test_models.py @@ -1,15 +1,15 @@ #!/usr/bin/env python -''' - This script fetches all the models used in the server tests. +""" +This script fetches all the models used in the server tests. - This is useful for slow tests that use larger models, to avoid them timing out on the model downloads. +This is useful for slow tests that use larger models, to avoid them timing out on the model downloads. - It is meant to be run from the root of the repository. +It is meant to be run from the root of the repository. - Example: - python scripts/fetch_server_test_models.py - ( cd tools/server/tests && ./tests.sh -v -x -m slow ) -''' +Example: + python scripts/fetch_server_test_models.py + ( cd tools/server/tests && ./tests.sh -v -x -m slow ) +""" import ast import glob import logging @@ -28,78 +28,123 @@ class Config: frozen = True -def collect_hf_model_test_parameters(test_file) -> Generator[HuggingFaceModel, None, None]: +def collect_hf_model_test_parameters( + test_file, +) -> Generator[HuggingFaceModel, None, None]: try: with open(test_file) as f: tree = ast.parse(f.read()) except Exception as e: - logging.error(f'collect_hf_model_test_parameters failed on {test_file}: {e}') + logging.error(f"collect_hf_model_test_parameters failed on {test_file}: {e}") return for node in ast.walk(tree): if isinstance(node, ast.FunctionDef): for dec in node.decorator_list: - if isinstance(dec, ast.Call) and isinstance(dec.func, ast.Attribute) and dec.func.attr == 'parametrize': + if ( + isinstance(dec, ast.Call) + and isinstance(dec.func, ast.Attribute) + and dec.func.attr == "parametrize" + ): param_names = ast.literal_eval(dec.args[0]).split(",") if "hf_repo" not in param_names: continue raw_param_values = dec.args[1] if not isinstance(raw_param_values, ast.List): - logging.warning(f'Skipping non-list parametrize entry at {test_file}:{node.lineno}') + logging.warning( + f"Skipping non-list parametrize entry at {test_file}:{node.lineno}" + ) continue hf_repo_idx = param_names.index("hf_repo") - hf_file_idx = param_names.index("hf_file") if "hf_file" in param_names else None + hf_file_idx = ( + param_names.index("hf_file") + if "hf_file" in param_names + else None + ) for t in raw_param_values.elts: if not isinstance(t, ast.Tuple): - logging.warning(f'Skipping non-tuple parametrize entry at {test_file}:{node.lineno}') + logging.warning( + f"Skipping non-tuple parametrize entry at {test_file}:{node.lineno}" + ) continue yield HuggingFaceModel( hf_repo=ast.literal_eval(t.elts[hf_repo_idx]), - hf_file=ast.literal_eval(t.elts[hf_file_idx]) if hf_file_idx is not None else None) - - -if __name__ == '__main__': - logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') - - models = sorted(list(set([ - model - for test_file in glob.glob('tools/server/tests/unit/test_*.py') - for model in collect_hf_model_test_parameters(test_file) - ])), key=lambda m: (m.hf_repo, m.hf_file)) - - logging.info(f'Found {len(models)} models in parameterized tests:') + hf_file=( + ast.literal_eval(t.elts[hf_file_idx]) + if hf_file_idx is not None + else None + ), + ) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") + + models = sorted( + list( + set( + [ + model + for test_file in glob.glob("tools/server/tests/unit/test_*.py") + for model in collect_hf_model_test_parameters(test_file) + ] + ) + ), + key=lambda m: (m.hf_repo, m.hf_file), + ) + + logging.info(f"Found {len(models)} models in parameterized tests:") for m in models: - logging.info(f' - {m.hf_repo} / {m.hf_file}') + logging.info(f" - {m.hf_repo} / {m.hf_file}") cli_path = os.environ.get( - 'LLAMA_CLI_BIN_PATH', + "LLAMA_CLI_BIN_PATH", os.path.join( os.path.dirname(__file__), - '../build/bin/Release/llama-cli.exe' if os.name == 'nt' else '../build/bin/llama-cli')) + ( + "../build/bin/Release/llama-cli.exe" + if os.name == "nt" + else "../build/bin/llama-cli" + ), + ), + ) for m in models: - if '<' in m.hf_repo or (m.hf_file is not None and '<' in m.hf_file): + if "<" in m.hf_repo or (m.hf_file is not None and "<" in m.hf_file): continue - if m.hf_file is not None and '-of-' in m.hf_file: - logging.warning(f'Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file') + if m.hf_file is not None and "-of-" in m.hf_file: + logging.warning( + f"Skipping model at {m.hf_repo} / {m.hf_file} because it is a split file" + ) continue - logging.info(f'Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched') + logging.info( + f"Using llama-cli to ensure model {m.hf_repo}/{m.hf_file} was fetched" + ) cmd = [ cli_path, - '-hfr', m.hf_repo, - *([] if m.hf_file is None else ['-hff', m.hf_file]), - '-n', '1', - '-p', 'Hey', - '--no-warmup', - '--log-disable', - '-no-cnv'] - if m.hf_file != 'tinyllamas/stories260K.gguf' and 'Mistral-Nemo' not in m.hf_repo: - cmd.append('-fa') + "-hfr", + m.hf_repo, + *([] if m.hf_file is None else ["-hff", m.hf_file]), + "-n", + "1", + "-p", + "Hey", + "--no-warmup", + "--log-disable", + "-no-cnv", + ] + if ( + m.hf_file != "tinyllamas/stories260K.gguf" + and "Mistral-Nemo" not in m.hf_repo + ): + cmd.append("-fa") try: subprocess.check_call(cmd) except subprocess.CalledProcessError: - logging.error(f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n {" ".join(cmd)}') + logging.error( + f'Failed to fetch model at {m.hf_repo} / {m.hf_file} with command:\n {" ".join(cmd)}' + ) exit(1) diff --git a/smallthinker/scripts/gen-unicode-data.py b/smallthinker/scripts/gen-unicode-data.py index 2d9bde01..c50f1c6a 100644 --- a/smallthinker/scripts/gen-unicode-data.py +++ b/smallthinker/scripts/gen-unicode-data.py @@ -50,33 +50,33 @@ def unicode_data_iter(): # see definition in unicode.h -CODEPOINT_FLAG_UNDEFINED = 0x0001 # -CODEPOINT_FLAG_NUMBER = 0x0002 # \p{N} -CODEPOINT_FLAG_LETTER = 0x0004 # \p{L} -CODEPOINT_FLAG_SEPARATOR = 0x0008 # \p{Z} -CODEPOINT_FLAG_MARK = 0x0010 # \p{M} +CODEPOINT_FLAG_UNDEFINED = 0x0001 # +CODEPOINT_FLAG_NUMBER = 0x0002 # \p{N} +CODEPOINT_FLAG_LETTER = 0x0004 # \p{L} +CODEPOINT_FLAG_SEPARATOR = 0x0008 # \p{Z} +CODEPOINT_FLAG_MARK = 0x0010 # \p{M} CODEPOINT_FLAG_PUNCTUATION = 0x0020 # \p{P} -CODEPOINT_FLAG_SYMBOL = 0x0040 # \p{S} -CODEPOINT_FLAG_CONTROL = 0x0080 # \p{C} +CODEPOINT_FLAG_SYMBOL = 0x0040 # \p{S} +CODEPOINT_FLAG_CONTROL = 0x0080 # \p{C} UNICODE_CATEGORY_TO_FLAG = { - "Cn": CODEPOINT_FLAG_UNDEFINED, # Undefined - "Cc": CODEPOINT_FLAG_CONTROL, # Control - "Cf": CODEPOINT_FLAG_CONTROL, # Format - "Co": CODEPOINT_FLAG_CONTROL, # Private Use - "Cs": CODEPOINT_FLAG_CONTROL, # Surrrogate - "Ll": CODEPOINT_FLAG_LETTER, # Lowercase Letter - "Lm": CODEPOINT_FLAG_LETTER, # Modifier Letter - "Lo": CODEPOINT_FLAG_LETTER, # Other Letter - "Lt": CODEPOINT_FLAG_LETTER, # Titlecase Letter - "Lu": CODEPOINT_FLAG_LETTER, # Uppercase Letter - "L&": CODEPOINT_FLAG_LETTER, # Cased Letter - "Mc": CODEPOINT_FLAG_MARK, # Spacing Mark - "Me": CODEPOINT_FLAG_MARK, # Enclosing Mark - "Mn": CODEPOINT_FLAG_MARK, # Nonspacing Mark - "Nd": CODEPOINT_FLAG_NUMBER, # Decimal Number - "Nl": CODEPOINT_FLAG_NUMBER, # Letter Number - "No": CODEPOINT_FLAG_NUMBER, # Other Number + "Cn": CODEPOINT_FLAG_UNDEFINED, # Undefined + "Cc": CODEPOINT_FLAG_CONTROL, # Control + "Cf": CODEPOINT_FLAG_CONTROL, # Format + "Co": CODEPOINT_FLAG_CONTROL, # Private Use + "Cs": CODEPOINT_FLAG_CONTROL, # Surrrogate + "Ll": CODEPOINT_FLAG_LETTER, # Lowercase Letter + "Lm": CODEPOINT_FLAG_LETTER, # Modifier Letter + "Lo": CODEPOINT_FLAG_LETTER, # Other Letter + "Lt": CODEPOINT_FLAG_LETTER, # Titlecase Letter + "Lu": CODEPOINT_FLAG_LETTER, # Uppercase Letter + "L&": CODEPOINT_FLAG_LETTER, # Cased Letter + "Mc": CODEPOINT_FLAG_MARK, # Spacing Mark + "Me": CODEPOINT_FLAG_MARK, # Enclosing Mark + "Mn": CODEPOINT_FLAG_MARK, # Nonspacing Mark + "Nd": CODEPOINT_FLAG_NUMBER, # Decimal Number + "Nl": CODEPOINT_FLAG_NUMBER, # Letter Number + "No": CODEPOINT_FLAG_NUMBER, # Other Number "Pc": CODEPOINT_FLAG_PUNCTUATION, # Connector Punctuation "Pd": CODEPOINT_FLAG_PUNCTUATION, # Dash Punctuation "Pe": CODEPOINT_FLAG_PUNCTUATION, # Close Punctuation @@ -84,23 +84,23 @@ def unicode_data_iter(): "Pi": CODEPOINT_FLAG_PUNCTUATION, # Initial Punctuation "Po": CODEPOINT_FLAG_PUNCTUATION, # Other Punctuation "Ps": CODEPOINT_FLAG_PUNCTUATION, # Open Punctuation - "Sc": CODEPOINT_FLAG_SYMBOL, # Currency Symbol - "Sk": CODEPOINT_FLAG_SYMBOL, # Modifier Symbol - "Sm": CODEPOINT_FLAG_SYMBOL, # Math Symbol - "So": CODEPOINT_FLAG_SYMBOL, # Other Symbol - "Zl": CODEPOINT_FLAG_SEPARATOR, # Line Separator - "Zp": CODEPOINT_FLAG_SEPARATOR, # Paragraph Separator - "Zs": CODEPOINT_FLAG_SEPARATOR, # Space Separator + "Sc": CODEPOINT_FLAG_SYMBOL, # Currency Symbol + "Sk": CODEPOINT_FLAG_SYMBOL, # Modifier Symbol + "Sm": CODEPOINT_FLAG_SYMBOL, # Math Symbol + "So": CODEPOINT_FLAG_SYMBOL, # Other Symbol + "Zl": CODEPOINT_FLAG_SEPARATOR, # Line Separator + "Zp": CODEPOINT_FLAG_SEPARATOR, # Paragraph Separator + "Zs": CODEPOINT_FLAG_SEPARATOR, # Space Separator } -codepoint_flags = array.array('H', [CODEPOINT_FLAG_UNDEFINED]) * MAX_CODEPOINTS +codepoint_flags = array.array("H", [CODEPOINT_FLAG_UNDEFINED]) * MAX_CODEPOINTS table_whitespace = [] table_lowercase = [] table_uppercase = [] table_nfd = [] -for (cpt, cpt_lower, cpt_upper, categ, bidir) in unicode_data_iter(): +for cpt, cpt_lower, cpt_upper, categ, bidir in unicode_data_iter(): # convert codepoint to unicode character char = chr(cpt) @@ -116,7 +116,7 @@ def unicode_data_iter(): table_uppercase.append((cpt, cpt_upper)) # NFD normalization - norm = ord(unicodedata.normalize('NFD', char)[0]) + norm = ord(unicodedata.normalize("NFD", char)[0]) if cpt != norm: table_nfd.append((cpt, norm)) @@ -124,7 +124,9 @@ def unicode_data_iter(): # whitespaces, see "" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt table_whitespace.extend(range(0x0009, 0x000D + 1)) table_whitespace.extend(range(0x2000, 0x200A + 1)) -table_whitespace.extend([0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000]) +table_whitespace.extend( + [0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000] +) # sort by codepoint @@ -155,11 +157,13 @@ def unicode_data_iter(): # Generate 'unicode-data.cpp': # python ./scripts//gen-unicode-data.py > unicode-data.cpp + def out(line=""): - print(line, end='\n') # noqa + print(line, end="\n") # noqa -out("""\ +out( + """\ // generated with scripts/gen-unicode-data.py #include "unicode-data.h" @@ -168,9 +172,12 @@ def out(line=""): #include #include #include -""") +""" +) -out("const std::vector> unicode_ranges_flags = { // start, flags // last=next_start-1") +out( + "const std::vector> unicode_ranges_flags = { // start, flags // last=next_start-1" +) for codepoint, flags in ranges_flags: out("{0x%06X, 0x%04X}," % (codepoint, flags)) out("};\n") diff --git a/smallthinker/scripts/get_chat_template.py b/smallthinker/scripts/get_chat_template.py index b4827b31..477bdcb3 100755 --- a/smallthinker/scripts/get_chat_template.py +++ b/smallthinker/scripts/get_chat_template.py @@ -1,15 +1,15 @@ #!/usr/bin/env python -''' - Fetches the Jinja chat template of a HuggingFace model. - If a model has multiple chat templates, you can specify the variant name. +""" +Fetches the Jinja chat template of a HuggingFace model. +If a model has multiple chat templates, you can specify the variant name. - Syntax: - ./scripts/get_chat_template.py model_id [variant] +Syntax: + ./scripts/get_chat_template.py model_id [variant] - Examples: - ./scripts/get_chat_template.py CohereForAI/c4ai-command-r-plus tool_use - ./scripts/get_chat_template.py microsoft/Phi-3.5-mini-instruct -''' +Examples: + ./scripts/get_chat_template.py CohereForAI/c4ai-command-r-plus tool_use + ./scripts/get_chat_template.py microsoft/Phi-3.5-mini-instruct +""" import json import re @@ -21,14 +21,23 @@ def get_chat_template(model_id, variant=None): # Use huggingface_hub library if available. # Allows access to gated models if the user has access and ran `huggingface-cli login`. from huggingface_hub import hf_hub_download - with open(hf_hub_download(repo_id=model_id, filename="tokenizer_config.json"), encoding="utf-8") as f: + + with open( + hf_hub_download(repo_id=model_id, filename="tokenizer_config.json"), + encoding="utf-8", + ) as f: config_str = f.read() except ImportError: import requests + assert re.match(r"^[\w.-]+/[\w.-]+$", model_id), f"Invalid model ID: {model_id}" - response = requests.get(f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json") + response = requests.get( + f"https://huggingface.co/{model_id}/resolve/main/tokenizer_config.json" + ) if response.status_code == 401: - raise Exception('Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`') + raise Exception( + "Access to this model is gated, please request access, authenticate with `huggingface-cli login` and make sure to run `pip install huggingface_hub`" + ) response.raise_for_status() config_str = response.text @@ -37,27 +46,36 @@ def get_chat_template(model_id, variant=None): except json.JSONDecodeError: # Fix https://huggingface.co/NousResearch/Meta-Llama-3-8B-Instruct/blob/main/tokenizer_config.json # (Remove extra '}' near the end of the file) - config = json.loads(re.sub(r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', r'\1', config_str)) + config = json.loads( + re.sub( + r'\}([\n\s]*\}[\n\s]*\],[\n\s]*"clean_up_tokenization_spaces")', + r"\1", + config_str, + ) + ) - chat_template = config['chat_template'] + chat_template = config["chat_template"] if isinstance(chat_template, str): return chat_template else: - variants = { - ct['name']: ct['template'] - for ct in chat_template - } + variants = {ct["name"]: ct["template"] for ct in chat_template} def format_variants(): - return ', '.join(f'"{v}"' for v in variants.keys()) + return ", ".join(f'"{v}"' for v in variants.keys()) if variant is None: - if 'default' not in variants: - raise Exception(f'Please specify a chat template variant (one of {format_variants()})') - variant = 'default' - sys.stderr.write(f'Note: picked "default" chat template variant (out of {format_variants()})\n') + if "default" not in variants: + raise Exception( + f"Please specify a chat template variant (one of {format_variants()})" + ) + variant = "default" + sys.stderr.write( + f'Note: picked "default" chat template variant (out of {format_variants()})\n' + ) elif variant not in variants: - raise Exception(f"Variant {variant} not found in chat template (found {format_variants()})") + raise Exception( + f"Variant {variant} not found in chat template (found {format_variants()})" + ) return variants[variant] @@ -72,5 +90,5 @@ def main(args): sys.stdout.write(template) -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv[1:]) diff --git a/smallthinker/scripts/sync_vendor.py b/smallthinker/scripts/sync_vendor.py index 1151c9f0..08e57115 100755 --- a/smallthinker/scripts/sync_vendor.py +++ b/smallthinker/scripts/sync_vendor.py @@ -3,20 +3,16 @@ import urllib.request vendor = { - "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", + "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", "https://github.com/nlohmann/json/releases/latest/download/json_fwd.hpp": "vendor/nlohmann/json_fwd.hpp", - # sync manually # "https://raw.githubusercontent.com/ochafik/minja/refs/heads/main/include/minja/minja.hpp": "vendor/minja/minja.hpp", # "https://raw.githubusercontent.com/ochafik/minja/refs/heads/main/include/minja/chat-template.hpp": "vendor/minja/chat-template.hpp", - "https://raw.githubusercontent.com/nothings/stb/refs/heads/master/stb_image.h": "vendor/stb/stb_image.h", - "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.22/miniaudio.h": "vendor/miniaudio/miniaudio.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.20.1/httplib.h": "vendor/cpp-httplib/httplib.h", } for url, filename in vendor.items(): - print(f"downloading {url} to {filename}") # noqa: NP100 + print(f"downloading {url} to {filename}") # noqa: NP100 urllib.request.urlretrieve(url, filename) diff --git a/smallthinker/scripts/tool_bench.py b/smallthinker/scripts/tool_bench.py index d8018e2e..fcbacf05 100755 --- a/smallthinker/scripts/tool_bench.py +++ b/smallthinker/scripts/tool_bench.py @@ -1,26 +1,26 @@ #!/usr/bin/env uv run -''' - Simplistic tool call benchmarks for llama-server and ollama. +""" +Simplistic tool call benchmarks for llama-server and ollama. - Essentially runs the tests at server/tools/server/tests/unit/test_tool_call.py N times, at different temperatures and on different backends (current llama-server, baseline llama-server and ollama), - and plots the results of multiple runs (from same .jsonl file or multiple ones) as a success rate heatmap. +Essentially runs the tests at server/tools/server/tests/unit/test_tool_call.py N times, at different temperatures and on different backends (current llama-server, baseline llama-server and ollama), +and plots the results of multiple runs (from same .jsonl file or multiple ones) as a success rate heatmap. - Simple usage example: +Simple usage example: - cmake -B build -DLLAMA_CURL=1 && cmake --build build --config Release -j -t llama-server + cmake -B build -DLLAMA_CURL=1 && cmake --build build --config Release -j -t llama-server - export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server - export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp} + export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server + export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp} - ./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L - ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M - ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b + ./scripts/tool_bench.py run --n 10 --temp -1 --temp 0 --temp 1 --temp 2 --temp 5 --llama-baseline $PWD/buildMaster/bin/llama-server --output qwen14b.jsonl --hf bartowski/Qwen2.5-14B-Instruct-GGUF:Q4_K_L + ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M" --output qwen1.5b.jsonl --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF --ollama qwen2.5:1.5b-instruct-q4_K_M + ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M" --output qwenc7b.jsonl --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF --ollama qwen2.5-coder:7b - ./scripts/tool_bench.py plot *.jsonl # Opens window w/ heatmap - ./scripts/tool_bench.py plot qwen*.jsonl --output qwen.png # Saves heatmap to qwen.png + ./scripts/tool_bench.py plot *.jsonl # Opens window w/ heatmap + ./scripts/tool_bench.py plot qwen*.jsonl --output qwen.png # Saves heatmap to qwen.png - (please see ./scripts/tool_bench.sh for a more complete example) -''' +(please see ./scripts/tool_bench.sh for a more complete example) +""" # /// script # requires-python = ">=3.10" # dependencies = [ @@ -53,7 +53,12 @@ sys.path.insert(0, Path(__file__).parent.parent.as_posix()) if True: from tools.server.tests.utils import ServerProcess - from tools.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather + from tools.server.tests.unit.test_tool_call import ( + TIMEOUT_SERVER_START, + do_test_calc_result, + do_test_hello_world, + do_test_weather, + ) @contextmanager @@ -62,15 +67,15 @@ def stop(): nonlocal sp if sp is not None: sp.stop() - sp = None # type: ignore + sp = None # type: ignore + atexit.register(stop) yield sp stop() logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s' + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) @@ -78,7 +83,12 @@ def stop(): @app.command() -def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[str] = None, server_regex: Optional[str] = None): +def plot( + files: List[Path], + output: Optional[Path] = None, + test_regex: Optional[str] = None, + server_regex: Optional[str] = None, +): lines: List[Dict] = [] for file in files: @@ -91,7 +101,7 @@ def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[ raw_data = f.read() logger.info(f"Reading {file} ({len(raw_data)} bytes)") - for line_num, line in enumerate(raw_data.split('\n'), 1): + for line_num, line in enumerate(raw_data.split("\n"), 1): line = line.strip() if not line: continue @@ -159,11 +169,7 @@ def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[ matrix: list[list[float]] = [] index: list[str] = [] - all_cols = [ - (server_name, test) - for server_name in server_names - for test in tests - ] + all_cols = [(server_name, test) for server_name in server_names for test in tests] for model in models: for temp in temps: index.append(f"{model} @ {temp}") @@ -180,21 +186,33 @@ def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[ plt.figure(figsize=(12, 6)) sns.heatmap( - df, annot=True, cmap="RdYlGn", vmin=0.0, vmax=1.0, cbar=True, fmt=".2f", center=0.5, square=True, linewidths=0.5, + df, + annot=True, + cmap="RdYlGn", + vmin=0.0, + vmax=1.0, + cbar=True, + fmt=".2f", + center=0.5, + square=True, + linewidths=0.5, cbar_kws={"label": "Success Ratio"}, ) - plt.title(f"Tool Call Bench (n = {str(min(total_counts)) if len(total_counts) == 1 else f'{min(total_counts)}-{max(total_counts)}'})\nSuccess Ratios by Server & Test", pad=20) + plt.title( + f"Tool Call Bench (n = {str(min(total_counts)) if len(total_counts) == 1 else f'{min(total_counts)}-{max(total_counts)}'})\nSuccess Ratios by Server & Test", + pad=20, + ) plt.xlabel("Server & Test", labelpad=10) plt.ylabel("Model @ Temperature", labelpad=10) - plt.xticks(rotation=45, ha='right') + plt.xticks(rotation=45, ha="right") plt.yticks(rotation=0) plt.tight_layout() if output: - plt.savefig(output, dpi=300, bbox_inches='tight') + plt.savefig(output, dpi=300, bbox_inches="tight") logger.info(f"Plot saved to {output}") else: plt.show() @@ -203,14 +221,32 @@ def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[ @app.command() def run( output: Annotated[Path, typer.Option(help="Output JSON file")], - model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None, - hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None, - chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None, - chat_template_file: Annotated[Optional[str], typer.Option(help="Chat template file override for llama-server")] = None, - ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None, - llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server baseline binary path to use as baseline")] = None, + model: Annotated[ + Optional[str], typer.Option(help="Name of the model to test (server agnostic)") + ] = None, + hf: Annotated[ + Optional[str], + typer.Option( + help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server" + ), + ] = None, + chat_template: Annotated[ + Optional[str], typer.Option(help="Chat template override for llama-server") + ] = None, + chat_template_file: Annotated[ + Optional[str], typer.Option(help="Chat template file override for llama-server") + ] = None, + ollama: Annotated[ + Optional[str], typer.Option(help="Ollama model tag to test") + ] = None, + llama_baseline: Annotated[ + Optional[str], + typer.Option(help="llama-server baseline binary path to use as baseline"), + ] = None, n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10, - temp: Annotated[Optional[List[float]], typer.Option(help="Set of temperatures to test")] = None, + temp: Annotated[ + Optional[List[float]], typer.Option(help="Set of temperatures to test") + ] = None, top_p: Annotated[Optional[float], typer.Option(help="top_p")] = None, top_k: Annotated[Optional[int], typer.Option(help="top_k")] = None, ctk: Annotated[Optional[str], typer.Option(help="ctk")] = None, @@ -220,14 +256,19 @@ def run( port: Annotated[int, typer.Option(help="llama-server port")] = 8084, force: Annotated[bool, typer.Option(help="Force overwrite of output file")] = False, append: Annotated[bool, typer.Option(help="Append to output file")] = False, - - test_hello_world: Annotated[bool, typer.Option(help="Whether to run the hello world test")] = True, - test_weather: Annotated[bool, typer.Option(help="Whether to run the weather test")] = True, - test_calc_result: Annotated[bool, typer.Option(help="Whether to run the calc result test")] = False, + test_hello_world: Annotated[ + bool, typer.Option(help="Whether to run the hello world test") + ] = True, + test_weather: Annotated[ + bool, typer.Option(help="Whether to run the weather test") + ] = True, + test_calc_result: Annotated[ + bool, typer.Option(help="Whether to run the calc result test") + ] = False, ): # Check only one of output and append - n_predict = 512 # High because of DeepSeek R1 + n_predict = 512 # High because of DeepSeek R1 # n_ctx = 8192 n_ctx = 2048 @@ -237,30 +278,46 @@ def run( elif ollama is not None: model = ollama - assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite" - - with output.open('a' if append else 'w') as output_file: - - def run(server: ServerProcess, *, server_name: str, model_id: str, temp: Optional[float] = None, output_kwargs={}, request_kwargs={}): + assert ( + force or append or not output.exists() + ), f"Output file already exists: {output}; use --force to overwrite" + + with output.open("a" if append else "w") as output_file: + + def run( + server: ServerProcess, + *, + server_name: str, + model_id: str, + temp: Optional[float] = None, + output_kwargs={}, + request_kwargs={}, + ): request_kwargs = {**request_kwargs} if temp is not None: - request_kwargs['temperature'] = temp + request_kwargs["temperature"] = temp if top_p is not None: - request_kwargs['top_p'] = top_p + request_kwargs["top_p"] = top_p if top_k is not None: - request_kwargs['top_k'] = top_k + request_kwargs["top_k"] = top_k if seed is not None: - request_kwargs['seed'] = seed + request_kwargs["seed"] = seed - request_kwargs['cache_prompt'] = False + request_kwargs["cache_prompt"] = False tests = {} if test_hello_world: - tests["hello world"] = lambda server: do_test_hello_world(server, **request_kwargs) + tests["hello world"] = lambda server: do_test_hello_world( + server, **request_kwargs + ) if test_weather: - tests["weather"] = lambda server: do_test_weather(server, **request_kwargs) + tests["weather"] = lambda server: do_test_weather( + server, **request_kwargs + ) if test_calc_result: - tests["calc result"] = lambda server: do_test_calc_result(server, None, 512, **request_kwargs) + tests["calc result"] = lambda server: do_test_calc_result( + server, None, 512, **request_kwargs + ) for test_name, test in tests.items(): success_count = 0 @@ -279,42 +336,50 @@ def elapsed(): test(server) success_times.append(elapsed()) success_count += 1 - logger.info('success') + logger.info("success") except Exception as e: - logger.error(f'failure: {e}') + logger.error(f"failure: {e}") failure_count += 1 failure_times.append(elapsed()) failures.append(str(e)) # import traceback # traceback.print_exc() - output_file.write(json.dumps({**output_kwargs, **dict( - model=model, - server_name=server_name, - model_id=model_id, - test=test_name, - temp=t, - top_p=top_p, - top_k=top_k, - ctk=ctk, - ctv=ctv, - seed=seed, - success_ratio=float(success_count) / n, - avg_time=mean(success_times + failure_times), - median_time=median(success_times + failure_times), - success_count=success_count, - success_times=success_times, - failure_count=failure_count, - failure_times=failure_times, - failures=list(set(failures)), - )}) + '\n') + output_file.write( + json.dumps( + { + **output_kwargs, + **dict( + model=model, + server_name=server_name, + model_id=model_id, + test=test_name, + temp=t, + top_p=top_p, + top_k=top_k, + ctk=ctk, + ctv=ctv, + seed=seed, + success_ratio=float(success_count) / n, + avg_time=mean(success_times + failure_times), + median_time=median(success_times + failure_times), + success_count=success_count, + success_times=success_times, + failure_count=failure_count, + failure_times=failure_times, + failures=list(set(failures)), + ), + } + ) + + "\n" + ) output_file.flush() for t in [None] if temp is None else [t if t >= 0 else None for t in temp]: if hf is not None: - servers: list[Tuple[str, Optional[str]]] = [('llama-server', None)] + servers: list[Tuple[str, Optional[str]]] = [("llama-server", None)] if llama_baseline is not None: - servers.append(('llama-server (baseline)', llama_baseline)) + servers.append(("llama-server (baseline)", llama_baseline)) for server_name, server_path in servers: server = ServerProcess() @@ -370,7 +435,7 @@ def elapsed(): request_kwargs=dict( model=ollama, max_tokens=n_predict, - num_ctx = n_ctx, + num_ctx=n_ctx, ), ) diff --git a/smallthinker/scripts/verify-checksum-models.py b/smallthinker/scripts/verify-checksum-models.py index 0b5b9aaf..5ce0811d 100755 --- a/smallthinker/scripts/verify-checksum-models.py +++ b/smallthinker/scripts/verify-checksum-models.py @@ -12,7 +12,7 @@ def sha256sum(file): b = bytearray(block_size) file_hash = hashlib.sha256() mv = memoryview(b) - with open(file, 'rb', buffering=0) as f: + with open(file, "rb", buffering=0) as f: while True: n = f.readinto(mv) if not n: @@ -68,17 +68,23 @@ def sha256sum(file): file_missing = "X" # Add the results to the array - results.append({ - "filename": filename, - "valid checksum": valid_checksum, - "file missing": file_missing - }) + results.append( + { + "filename": filename, + "valid checksum": valid_checksum, + "file missing": file_missing, + } + ) # Print column headers for results table -print("filename".ljust(40) + "valid checksum".center(20) + "file missing".center(20)) # noqa: NP100 -print("-" * 80) # noqa: NP100 +print( + "filename".ljust(40) + "valid checksum".center(20) + "file missing".center(20) +) # noqa: NP100 +print("-" * 80) # noqa: NP100 # Output the results as a table for r in results: - print(f"{r['filename']:40} {r['valid checksum']:^20} {r['file missing']:^20}") # noqa: NP100 + print( + f"{r['filename']:40} {r['valid checksum']:^20} {r['file missing']:^20}" + ) # noqa: NP100 diff --git a/smallthinker/tests/test-tokenizer-0.py b/smallthinker/tests/test-tokenizer-0.py index cd760d1c..5fb6487a 100644 --- a/smallthinker/tests/test-tokenizer-0.py +++ b/smallthinker/tests/test-tokenizer-0.py @@ -5,7 +5,9 @@ parser = argparse.ArgumentParser() parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file") -parser.add_argument("--fname-tok", help="path to a text file to tokenize", required=True) +parser.add_argument( + "--fname-tok", help="path to a text file to tokenize", required=True +) args = parser.parse_args() dir_tokenizer = args.dir_tokenizer @@ -13,16 +15,18 @@ tokenizer = AutoTokenizer.from_pretrained(dir_tokenizer) -print('tokenizing file: ', fname_tok) # noqa: NP100 -fname_out = fname_tok + '.tok' -with open(fname_tok, 'r', encoding='utf-8') as f: +print("tokenizing file: ", fname_tok) # noqa: NP100 +fname_out = fname_tok + ".tok" +with open(fname_tok, "r", encoding="utf-8") as f: lines = f.readlines() - s = ''.join(lines) + s = "".join(lines) t_start = time.time() res = tokenizer.encode(s, add_special_tokens=False) t_end = time.time() - print('\nmain : tokenized in', "{:.3f}".format(1000.0 * (t_end - t_start)), 'ms (py)') # noqa: NP100 - with open(fname_out, 'w', encoding='utf-8') as f: + print( + "\nmain : tokenized in", "{:.3f}".format(1000.0 * (t_end - t_start)), "ms (py)" + ) # noqa: NP100 + with open(fname_out, "w", encoding="utf-8") as f: for x in res: # LLaMA v3 for some reason strips the space for these tokens (and others) # if x == 662: @@ -40,7 +44,7 @@ # else: # f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n') # f.write(str(x) + ' \'' + tokenizer.decode(x).strip() + '\'\n') - f.write(str(x) + '\n') - print('len(res): ', len(res)) # noqa: NP100 - print('len(lines): ', len(lines)) # noqa: NP100 -print('results written to: ', fname_out) # noqa: NP100 + f.write(str(x) + "\n") + print("len(res): ", len(res)) # noqa: NP100 + print("len(lines): ", len(lines)) # noqa: NP100 +print("results written to: ", fname_out) # noqa: NP100 diff --git a/smallthinker/tests/test-tokenizer-random.py b/smallthinker/tests/test-tokenizer-random.py index c6cdcb55..8417d1ff 100644 --- a/smallthinker/tests/test-tokenizer-random.py +++ b/smallthinker/tests/test-tokenizer-random.py @@ -30,25 +30,44 @@ class LibLlama: DEFAULT_PATH_LLAMA_H = "./include/llama.h" DEFAULT_PATH_INCLUDES = ["./ggml/include/", "./include/"] - DEFAULT_PATH_LIBLLAMA = "./build/src/libllama.so" # CMakeLists.txt: BUILD_SHARED_LIBS ON + DEFAULT_PATH_LIBLLAMA = ( + "./build/src/libllama.so" # CMakeLists.txt: BUILD_SHARED_LIBS ON + ) - def __init__(self, path_llama_h: str | None = None, path_includes: list[str] = [], path_libllama: str | None = None): + def __init__( + self, + path_llama_h: str | None = None, + path_includes: list[str] = [], + path_libllama: str | None = None, + ): path_llama_h = path_llama_h or self.DEFAULT_PATH_LLAMA_H path_includes = path_includes or self.DEFAULT_PATH_INCLUDES path_libllama = path_libllama or self.DEFAULT_PATH_LIBLLAMA - (self.ffi, self.lib) = self._load_libllama_cffi(path_llama_h, path_includes, path_libllama) + (self.ffi, self.lib) = self._load_libllama_cffi( + path_llama_h, path_includes, path_libllama + ) self.lib.llama_backend_init() - def _load_libllama_cffi(self, path_llama_h: str, path_includes: list[str], path_libllama: str) -> tuple[cffi.FFI, Any]: - cmd = ["gcc", "-O0", "-E", "-P", "-D__restrict=", "-D__attribute__(x)=", "-D__asm__(x)="] + def _load_libllama_cffi( + self, path_llama_h: str, path_includes: list[str], path_libllama: str + ) -> tuple[cffi.FFI, Any]: + cmd = [ + "gcc", + "-O0", + "-E", + "-P", + "-D__restrict=", + "-D__attribute__(x)=", + "-D__asm__(x)=", + ] cmd += ["-I" + path for path in path_includes] + [path_llama_h] res = subprocess.run(cmd, stdout=subprocess.PIPE) - assert (res.returncode == 0) + assert res.returncode == 0 source = res.stdout.decode() ffi = cffi.FFI() if True: # workarounds for pycparser source = "typedef struct { } __builtin_va_list;" + "\n" + source - source = source.replace("sizeof (int)", str(ffi.sizeof("int"))) + source = source.replace("sizeof (int)", str(ffi.sizeof("int"))) source = source.replace("sizeof (void *)", str(ffi.sizeof("void*"))) source = source.replace("sizeof (size_t)", str(ffi.sizeof("size_t"))) source = source.replace("sizeof(int32_t)", str(ffi.sizeof("int32_t"))) @@ -83,7 +102,9 @@ def __init__(self, libllama: LibLlama, path_model: str, mparams={}, cparams={}): cparams = libllama.context_default_params(**cparams) self.ctx = self.lib.llama_new_context_with_model(self.model, cparams) if not self.ctx: - raise RuntimeError("error: failed to create context for model '%s'" % path_model) + raise RuntimeError( + "error: failed to create context for model '%s'" % path_model + ) n_tokens_max = self.lib.llama_n_ctx(self.ctx) self.token_ids = self.ffi.new("llama_token[]", n_tokens_max) self.text_buff = self.ffi.new("uint8_t[]", 1024) @@ -97,24 +118,67 @@ def free(self): self.model = None self.lib = None - def tokenize(self, text: str, add_special: bool = False, parse_special: bool = False) -> list[int]: + def tokenize( + self, text: str, add_special: bool = False, parse_special: bool = False + ) -> list[int]: encoded_text: bytes = text.encode("utf-8") - num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special) + num = self.lib.llama_tokenize( + self.model, + encoded_text, + len(encoded_text), + self.token_ids, + len(self.token_ids), + add_special, + parse_special, + ) while num < 0 and len(self.token_ids) < (16 << 20): self.token_ids = self.ffi.new("llama_token[]", -2 * num) - num = self.lib.llama_tokenize(self.model, encoded_text, len(encoded_text), self.token_ids, len(self.token_ids), add_special, parse_special) + num = self.lib.llama_tokenize( + self.model, + encoded_text, + len(encoded_text), + self.token_ids, + len(self.token_ids), + add_special, + parse_special, + ) return list(self.token_ids[0:num]) - def detokenize(self, ids: list[int], remove_special: bool = False, unparse_special: bool = False) -> str: + def detokenize( + self, + ids: list[int], + remove_special: bool = False, + unparse_special: bool = False, + ) -> str: if len(self.token_ids) < len(ids): self.token_ids = self.ffi.new("llama_token[]", 2 * len(ids)) for i, id in enumerate(ids): self.token_ids[i] = id - num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special) + num = self.lib.llama_detokenize( + self.model, + self.token_ids, + len(ids), + self.text_buff, + len(self.text_buff), + remove_special, + unparse_special, + ) while num < 0 and len(self.text_buff) < (16 << 20): self.text_buff = self.ffi.new("uint8_t[]", -2 * num) - num = self.lib.llama_detokenize(self.model, self.token_ids, len(ids), self.text_buff, len(self.text_buff), remove_special, unparse_special) - return str(cast(Buffer, self.ffi.buffer(self.text_buff, num)), encoding="utf-8", errors="replace") # replace errors with '\uFFFD' + num = self.lib.llama_detokenize( + self.model, + self.token_ids, + len(ids), + self.text_buff, + len(self.text_buff), + remove_special, + unparse_special, + ) + return str( + cast(Buffer, self.ffi.buffer(self.text_buff, num)), + encoding="utf-8", + errors="replace", + ) # replace errors with '\uFFFD' class Tokenizer: @@ -126,7 +190,7 @@ def decode(self, ids: list[int]) -> str: raise NotImplementedError -class TokenizerGroundtruth (Tokenizer): +class TokenizerGroundtruth(Tokenizer): def __init__(self, dir_tokenizer: str): self.model: PreTrainedTokenizer = AutoTokenizer.from_pretrained(dir_tokenizer) @@ -143,7 +207,9 @@ def __init__(self, dir_tokenizer: str): self.vocab = list(sorted(self.vocab)) # tokens and lists self.special_tokens = list(self.model.all_special_tokens) - self.added_tokens = self.model.batch_decode(self.model.added_tokens_encoder.values(), skip_special_tokens=False) + self.added_tokens = self.model.batch_decode( + self.model.added_tokens_encoder.values(), skip_special_tokens=False + ) self.bos_token = self.model.bos_token self.eos_token = self.model.eos_token @@ -154,14 +220,19 @@ def decode(self, ids: list[int]) -> str: return self.model.decode(ids, skip_special_tokens=False) -class TokenizerLlamaCpp (Tokenizer): +class TokenizerLlamaCpp(Tokenizer): libllama: LibLlama | None = None def __init__(self, vocab_file: str): if not self.libllama: self.libllama = LibLlama() - self.model = LibLlamaModel(self.libllama, vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096)) + self.model = LibLlamaModel( + self.libllama, + vocab_file, + mparams=dict(vocab_only=True), + cparams=dict(n_ctx=4096), + ) def encode(self, text: str) -> list[int]: return self.model.tokenize(text, add_special=True, parse_special=True) @@ -219,30 +290,30 @@ def generator_custom_text() -> Iterator[str]: def generator_custom_text_edge_cases() -> Iterator[str]: """Edge cases found while debugging""" yield from [ - '\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F} - '¼-a', # unicode_ranges_digit, 0x00BC - '½-a', # unicode_ranges_digit, 0x00BD - '¾-a', # unicode_ranges_digit, 0x00BE - 'a 〇b', # unicode_ranges_digit, 0x3007 - 'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms - '\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM) - 'Cửa Việt', # llama-3, ignore_merges = true - 'a', # Phi-3 fail - '<|endoftext|>', # Phi-3 fail - 'a\na', # bert fail - '"`', # falcon - ' \u2e4e', # falcon - '\n\x0b ', # falcon - 'a\xa0\xa0\x00b', # jina-v2-es - 'one ', # jina-v2-es lstrip=true - 'a b', # rstrip phi-3 - 'a b', # lstrip jina-v2 - '\xa0aC', # deepseek - '\u2029 \uA3E4', # deepseek-llm + "\x1f-a", # unicode_ranges_control, {0x00001C, 0x00001F} + "¼-a", # unicode_ranges_digit, 0x00BC + "½-a", # unicode_ranges_digit, 0x00BD + "¾-a", # unicode_ranges_digit, 0x00BE + "a 〇b", # unicode_ranges_digit, 0x3007 + "Ⅵ-a", # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms + "\ufeff//", # unicode_ranges_control, 0xFEFF (BOM) + "Cửa Việt", # llama-3, ignore_merges = true + "a", # Phi-3 fail + "<|endoftext|>", # Phi-3 fail + "a\na", # bert fail + '"`', # falcon + " \u2e4e", # falcon + "\n\x0b ", # falcon + "a\xa0\xa0\x00b", # jina-v2-es + "one ", # jina-v2-es lstrip=true + "a b", # rstrip phi-3 + "a b", # lstrip jina-v2 + "\xa0aC", # deepseek + "\u2029 \ua3e4", # deepseek-llm "a ?", - 'å', # mpt - '\U000ac517', # utf-8 encode error, falcon - '\U000522f4', # utf-8 encode error, starcoder + "å", # mpt + "\U000ac517", # utf-8 encode error, falcon + "\U000522f4", # utf-8 encode error, starcoder "abcd", " abcd", ] @@ -289,20 +360,30 @@ def generator_added_lr_strip(tokenizer: TokenizerGroundtruth) -> Iterator[str]: yield "a" + lstrip + token + rstrip + "z" -def generator_random_added_tokens(tokenizer: TokenizerGroundtruth, iterations=100) -> Iterator[str]: +def generator_random_added_tokens( + tokenizer: TokenizerGroundtruth, iterations=100 +) -> Iterator[str]: separations = [" ", "\n", "\t", "-", "!", "one", "1", "", ""] - all_tokens = list(sorted(set(tokenizer.special_tokens + tokenizer.added_tokens + separations))) + all_tokens = list( + sorted(set(tokenizer.special_tokens + tokenizer.added_tokens + separations)) + ) rand = random.Random() for m in range(iterations): rand.seed(m) words = rand.choices(all_tokens, k=500) if words and words[0] == tokenizer.bos_token: # skip spam warning of double BOS - while len(words) > 1 and words[1] == tokenizer.bos_token: # leave one starting BOS + while ( + len(words) > 1 and words[1] == tokenizer.bos_token + ): # leave one starting BOS words.pop(0) if tokenizer.add_bos_token: # drop all starting BOS words.pop(0) - if words and words[-1] == tokenizer.eos_token: # skip spam warning of double EOS - while len(words) > 1 and words[-2] == tokenizer.eos_token: # leave one trailing EOS + if ( + words and words[-1] == tokenizer.eos_token + ): # skip spam warning of double EOS + while ( + len(words) > 1 and words[-2] == tokenizer.eos_token + ): # leave one trailing EOS words.pop(-1) if tokenizer.add_bos_token: # drop all trailing EOS words.pop(-1) @@ -314,13 +395,19 @@ def generator_random_chars(iterations=100) -> Iterator[str]: NUM_WORDS = 400 WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5) - CHARS = list(sorted(set(""" + CHARS = list( + sorted( + set( + """ ABCDEFGHIJKLMNOPQRSTUVWXYZ abcdefghijklmnopqrstuvwxyz ÁÉÍÓÚÀÈÌÒÙÂÊÎÔÛÄËÏÖÜ áéíóúàèìòùâêîôûäëïöü .-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_ - """))) + """ + ) + ) + ) rand = random.Random() for m in range(iterations): @@ -344,7 +431,11 @@ def _valid(cpt): return False # if cpt == 0x2029: # deepseek-llm # return False - if unicodedata.category(chr(cpt)) in ("Cn", "Cs", "Co"): # undefined, surrogates, private + if unicodedata.category(chr(cpt)) in ( + "Cn", + "Cs", + "Co", + ): # undefined, surrogates, private return False return True @@ -373,7 +464,9 @@ def generator_random_unicodes(iterations=100) -> Iterator[str]: yield "".join(text) -def generator_random_vocab_chars(tokenizer: TokenizerGroundtruth, iterations=100) -> Iterator[str]: +def generator_random_vocab_chars( + tokenizer: TokenizerGroundtruth, iterations=100 +) -> Iterator[str]: """Brute force random text with vocab characters""" vocab_chars = set() @@ -388,7 +481,9 @@ def generator_random_vocab_chars(tokenizer: TokenizerGroundtruth, iterations=100 yield "".join(text) -def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100) -> Iterator[str]: +def generator_random_vocab_words( + tokenizer: TokenizerGroundtruth, iterations=100 +) -> Iterator[str]: """Brute force random text from vocab words""" vocab = [w.strip() for w in tokenizer.vocab] @@ -407,7 +502,11 @@ def generator_random_vocab_words(tokenizer: TokenizerGroundtruth, iterations=100 yield "".join(text) -def compare_tokenizers(tokenizer1: TokenizerGroundtruth, tokenizer2: TokenizerLlamaCpp, generator: Iterator[str]): +def compare_tokenizers( + tokenizer1: TokenizerGroundtruth, + tokenizer2: TokenizerLlamaCpp, + generator: Iterator[str], +): def find_first_mismatch(ids1: list[int] | str, ids2: list[int] | str): for i, (a, b) in enumerate(zip(ids1, ids2)): @@ -423,10 +522,10 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool: # equal to source text? if tokenizer1.add_bos_token: # remove BOS if text2.startswith(tokenizer1.bos_token): - text2 = text2[len(tokenizer1.bos_token):] + text2 = text2[len(tokenizer1.bos_token) :] if tokenizer1.add_eos_token: # remove EOS if text2.endswith(tokenizer1.eos_token): - text2 = text2[:-len(tokenizer1.eos_token)] + text2 = text2[: -len(tokenizer1.eos_token)] return text == text2 t_encode1 = 0 @@ -477,17 +576,23 @@ def check_detokenizer(text: str, text1: str, text2: str) -> bool: break t_total = time.perf_counter() - t_start - logger.info(f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}") + logger.info( + f"{generator.__qualname__}: end, {t_encode1=:.3f} {t_encode2=:.3f} {t_decode1=:.3f} {t_decode2=:.3f} {t_total=:.3f}" + ) def main(argv: list[str] | None = None): parser = argparse.ArgumentParser() parser.add_argument("vocab_file", type=str, help="path to vocab 'gguf' file") - parser.add_argument("dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file") - parser.add_argument("--verbose", action="store_true", help="increase output verbosity") + parser.add_argument( + "dir_tokenizer", type=str, help="directory containing 'tokenizer.model' file" + ) + parser.add_argument( + "--verbose", action="store_true", help="increase output verbosity" + ) args = parser.parse_args(argv) - logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO) + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) logger.info(f"VOCABFILE: '{args.vocab_file}'") tokenizer1 = TokenizerGroundtruth(args.dir_tokenizer) @@ -514,47 +619,47 @@ def main(argv: list[str] | None = None): if True: logging.basicConfig( - level = logging.DEBUG, - format = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s %(message)s", - datefmt = "%Y-%m-%d %H:%M:%S", - filename = logger.name + ".log", - filemode = "a" + level=logging.DEBUG, + format="%(asctime)s.%(msecs)03d %(name)s %(levelname)s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + filename=logger.name + ".log", + filemode="a", ) logging.basicConfig( - level = logging.DEBUG, - format = "%(levelname)s %(message)s", + level=logging.DEBUG, + format="%(levelname)s %(message)s", ) - path_tokenizers = Path("./models/tokenizers/") + path_tokenizers = Path("./models/tokenizers/") path_vocab_format = "./models/ggml-vocab-%s.gguf" tokenizers = [ - "llama-spm", # SPM - "phi-3", # SPM - "gemma", # SPM - "gemma-2", # SPM - "baichuan", # SPM - "bert-bge", # WPM - "jina-v2-en", # WPM - "llama-bpe", # BPE - "phi-2", # BPE - "deepseek-llm", # BPE - "deepseek-coder", # BPE - "falcon", # BPE - "mpt", # BPE - "starcoder", # BPE - "gpt-2", # BPE - "stablelm2", # BPE - "refact", # BPE - "qwen2", # BPE - "olmo", # BPE - "jina-v2-es", # BPE - "jina-v2-de", # BPE - "smaug-bpe", # BPE - "poro-chat", # BPE - "jina-v2-code", # BPE - "viking", # BPE - "jais", # BPE + "llama-spm", # SPM + "phi-3", # SPM + "gemma", # SPM + "gemma-2", # SPM + "baichuan", # SPM + "bert-bge", # WPM + "jina-v2-en", # WPM + "llama-bpe", # BPE + "phi-2", # BPE + "deepseek-llm", # BPE + "deepseek-coder", # BPE + "falcon", # BPE + "mpt", # BPE + "starcoder", # BPE + "gpt-2", # BPE + "stablelm2", # BPE + "refact", # BPE + "qwen2", # BPE + "olmo", # BPE + "jina-v2-es", # BPE + "jina-v2-de", # BPE + "smaug-bpe", # BPE + "poro-chat", # BPE + "jina-v2-code", # BPE + "viking", # BPE + "jais", # BPE ] logger.info("=" * 50) diff --git a/smallthinker/tools/mtmd/legacy-models/convert_image_encoder_to_gguf.py b/smallthinker/tools/mtmd/legacy-models/convert_image_encoder_to_gguf.py index 2949faec..d64719c0 100644 --- a/smallthinker/tools/mtmd/legacy-models/convert_image_encoder_to_gguf.py +++ b/smallthinker/tools/mtmd/legacy-models/convert_image_encoder_to_gguf.py @@ -16,7 +16,9 @@ def k(raw_key: str, arch: str) -> str: return raw_key.format(arch=arch) -def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool: +def should_skip_tensor( + name: str, has_text: bool, has_vision: bool, has_llava: bool +) -> bool: if name in ( "logit_scale", "text_model.embeddings.position_ids", @@ -24,7 +26,11 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b ): return True - if has_llava and name in ["visual_projection.weight", "vision_model.post_layernorm.weight", "vision_model.post_layernorm.bias"]: + if has_llava and name in [ + "visual_projection.weight", + "vision_model.post_layernorm.weight", + "vision_model.post_layernorm.bias", + ]: return True if name.startswith("v") and not has_vision: @@ -53,11 +59,25 @@ def get_tensor_name(name: str) -> str: return name if "mm_projector" in name: name = name.replace("model.mm_projector", "mm") - name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1) - name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1) + name = re.sub(r"mm\.mlp\.mlp", "mm.model.mlp", name, count=1) + name = re.sub(r"mm\.peg\.peg", "mm.model.peg", name, count=1) return name - return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln") + return ( + name.replace("text_model", "t") + .replace("vision_model", "v") + .replace("encoder.layers", "blk") + .replace("embeddings.", "") + .replace("_proj", "") + .replace("self_attn.", "attn_") + .replace("layer_norm", "ln") + .replace("layernorm", "ln") + .replace("mlp.fc1", "ffn_down") + .replace("mlp.fc2", "ffn_up") + .replace("embedding", "embd") + .replace("final", "post") + .replace("layrnorm", "ln") + ) def bytes_to_unicode(): @@ -87,52 +107,113 @@ def bytes_to_unicode(): ap = argparse.ArgumentParser() -ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True) -ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16") -ap.add_argument('--bigendian', action="store_true", default=False, help="Model is executed on big-endian machine") -ap.add_argument("--text-only", action="store_true", required=False, - help="Save a text-only model. It can't be used to encode images") -ap.add_argument("--vision-only", action="store_true", required=False, - help="Save a vision-only model. It can't be used to encode texts") -ap.add_argument("--clip-model-is-vision", action="store_true", required=False, - help="The clip model is a pure vision model (ShareGPT4V vision extract for example)") +ap.add_argument( + "-m", + "--model-dir", + help="Path to model directory cloned from HF Hub", + required=True, +) +ap.add_argument( + "--use-f32", action="store_true", default=False, help="Use f32 instead of f16" +) +ap.add_argument( + "--bigendian", + action="store_true", + default=False, + help="Model is executed on big-endian machine", +) +ap.add_argument( + "--text-only", + action="store_true", + required=False, + help="Save a text-only model. It can't be used to encode images", +) +ap.add_argument( + "--vision-only", + action="store_true", + required=False, + help="Save a vision-only model. It can't be used to encode texts", +) +ap.add_argument( + "--clip-model-is-vision", + action="store_true", + required=False, + help="The clip model is a pure vision model (ShareGPT4V vision extract for example)", +) # Selectable visual encoders that are compatible with this script encoder_group = ap.add_mutually_exclusive_group() -encoder_group.add_argument("--clip-model-is-openclip", action="store_true", required=False, - help="The clip model is from openclip (for ViT-SO400M type))") -encoder_group.add_argument("--clip-model-is-siglip", action="store_true", required=False, - help="the visual encoder is Siglip.") - -ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.") -ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp") -ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None) +encoder_group.add_argument( + "--clip-model-is-openclip", + action="store_true", + required=False, + help="The clip model is from openclip (for ViT-SO400M type))", +) +encoder_group.add_argument( + "--clip-model-is-siglip", + action="store_true", + required=False, + help="the visual encoder is Siglip.", +) + +ap.add_argument( + "--llava-projector", + help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.", +) +ap.add_argument( + "--projector-type", + help="Type of projector. Possible values: mlp, ldp, ldpv2", + choices=["mlp", "ldp", "ldpv2"], + default="mlp", +) +ap.add_argument( + "-o", + "--output-dir", + help="Directory to save GGUF files. Default is the original model directory", + default=None, +) # Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711 # Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5 default_image_mean = [0.48145466, 0.4578275, 0.40821073] default_image_std = [0.26862954, 0.26130258, 0.27577711] -ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) -ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) +ap.add_argument( + "--image-mean", + type=float, + nargs="+", + help="Mean of the images for normalization (overrides processor) ", + default=None, +) +ap.add_argument( + "--image-std", + type=float, + nargs="+", + help="Standard deviation of the images for normalization (overrides processor)", + default=None, +) # with proper args = ap.parse_args() if args.text_only and args.vision_only: - print("--text-only and --image-only arguments cannot be specified at the same time.") + print( + "--text-only and --image-only arguments cannot be specified at the same time." + ) exit(1) if args.use_f32: - print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.") + print( + "WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet." + ) # output in the same directory as the model if output_dir is None dir_model = args.model_dir if ( - args.clip_model_is_vision or - not os.path.exists(dir_model + "/vocab.json") or - args.clip_model_is_openclip or - args.clip_model_is_siglip + args.clip_model_is_vision + or not os.path.exists(dir_model + "/vocab.json") + or args.clip_model_is_openclip + or args.clip_model_is_siglip ): vocab = None tokens = None @@ -192,13 +273,21 @@ def bytes_to_unicode(): os.makedirs(output_dir, exist_ok=True) output_prefix = os.path.basename(output_dir).replace("ggml_", "") fname_out = os.path.join(output_dir, f"{fname_middle}model-{ftype_str[ftype]}.gguf") -fout = GGUFWriter(path=fname_out, arch="clip", endianess=GGUFEndian.LITTLE if not args.bigendian else GGUFEndian.BIG) +fout = GGUFWriter( + path=fname_out, + arch="clip", + endianess=GGUFEndian.LITTLE if not args.bigendian else GGUFEndian.BIG, +) fout.add_bool("clip.has_text_encoder", has_text_encoder) fout.add_bool("clip.has_vision_encoder", has_vision_encoder) fout.add_bool("clip.has_llava_projector", has_llava_projector) fout.add_file_type(ftype) -model_name = config["_name_or_path"] if "_name_or_path" in config else os.path.basename(dir_model) +model_name = ( + config["_name_or_path"] + if "_name_or_path" in config + else os.path.basename(dir_model) +) fout.add_name(model_name) if args.text_only: fout.add_description("text-only CLIP model") @@ -229,7 +318,6 @@ def bytes_to_unicode(): fout.add_token_list(tokens) - def get_non_negative_vision_feature_layers(v_hparams): """ Determine the vision feature layer(s) for the llava model, which are indices into the @@ -243,7 +331,9 @@ def get_non_negative_vision_feature_layers(v_hparams): the model as an unset value. If no vision feature layer is found, we leave it unset. """ num_hidden_layers = v_hparams["num_hidden_layers"] - to_non_negative = lambda layer_idx: layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1 + to_non_negative = lambda layer_idx: ( + layer_idx if layer_idx >= 0 else num_hidden_layers + layer_idx + 1 + ) feature_layers_key = None # Key used for llava models in transformers if "vision_feature_layer" in config: @@ -257,6 +347,7 @@ def get_non_negative_vision_feature_layers(v_hparams): feature_layers = [feature_layers] return [to_non_negative(feature_layer) for feature_layer in feature_layers] + # Determine if we have explicitly specified vision feature layers in our config feature_layers = get_non_negative_vision_feature_layers(v_hparams) @@ -265,7 +356,9 @@ def get_non_negative_vision_feature_layers(v_hparams): if args.clip_model_is_siglip: visual_projection_dim = 0 else: - visual_projection_dim = v_hparams.get("projection_dim", config["projection_dim"]) + visual_projection_dim = v_hparams.get( + "projection_dim", config["projection_dim"] + ) # set vision_model hparams fout.add_uint32("clip.vision.image_size", v_hparams["image_size"]) @@ -273,46 +366,54 @@ def get_non_negative_vision_feature_layers(v_hparams): fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"]) fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"]) fout.add_uint32("clip.vision.projection_dim", visual_projection_dim) - fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"]) - fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"]) + fout.add_uint32( + k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"] + ) + fout.add_float32( + k(KEY_ATTENTION_LAYERNORM_EPS, VISION), v_hparams["layer_norm_eps"] + ) if feature_layers: block_count = max(feature_layers) else: - block_count = v_hparams["num_hidden_layers"] - 1 if has_llava_projector else v_hparams["num_hidden_layers"] + block_count = ( + v_hparams["num_hidden_layers"] - 1 + if has_llava_projector + else v_hparams["num_hidden_layers"] + ) fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), block_count) - # /** - # "image_grid_pinpoints": [ - # [ - # 336, - # 672 - # ], - # [ - # 672, - # 336 - # ], - # [ - # 672, - # 672 - # ], - # [ - # 1008, - # 336 - # ], - # [ - # 336, - # 1008 - # ] - # ], - # Flattened: - # [ - # 336, 672, - # 672, 336, - # 672, 672, - # 1008, 336, - # 336, 1008 - # ] - # * - # */ + # /** + # "image_grid_pinpoints": [ + # [ + # 336, + # 672 + # ], + # [ + # 672, + # 336 + # ], + # [ + # 672, + # 672 + # ], + # [ + # 1008, + # 336 + # ], + # [ + # 336, + # 1008 + # ] + # ], + # Flattened: + # [ + # 336, 672, + # 672, 336, + # 672, 672, + # 1008, 336, + # 336, 1008 + # ] + # * + # */ if "image_grid_pinpoints" in v_hparams: # flatten it image_grid_pinpoints = [] @@ -321,23 +422,41 @@ def get_non_negative_vision_feature_layers(v_hparams): image_grid_pinpoints.append(p) fout.add_array("clip.vision.image_grid_pinpoints", image_grid_pinpoints) if "image_crop_resolution" in v_hparams: - fout.add_uint32("clip.vision.image_crop_resolution", v_hparams["image_crop_resolution"]) + fout.add_uint32( + "clip.vision.image_crop_resolution", v_hparams["image_crop_resolution"] + ) if "image_aspect_ratio" in v_hparams: - fout.add_string("clip.vision.image_aspect_ratio", v_hparams["image_aspect_ratio"]) + fout.add_string( + "clip.vision.image_aspect_ratio", v_hparams["image_aspect_ratio"] + ) if "image_split_resolution" in v_hparams: - fout.add_uint32("clip.vision.image_split_resolution", v_hparams["image_split_resolution"]) + fout.add_uint32( + "clip.vision.image_split_resolution", v_hparams["image_split_resolution"] + ) if "mm_patch_merge_type" in v_hparams: - fout.add_string("clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"]) + fout.add_string( + "clip.vision.mm_patch_merge_type", v_hparams["mm_patch_merge_type"] + ) if "mm_projector_type" in v_hparams: fout.add_string("clip.vision.mm_projector_type", v_hparams["mm_projector_type"]) if feature_layers: fout.add_array("clip.vision.feature_layer", feature_layers) if processor is not None: - image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean # pyright: ignore[reportAttributeAccessIssue] - image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std # pyright: ignore[reportAttributeAccessIssue] + image_mean = ( + processor.image_processor.image_mean + if args.image_mean is None or args.image_mean == default_image_mean + else args.image_mean + ) # pyright: ignore[reportAttributeAccessIssue] + image_std = ( + processor.image_processor.image_std + if args.image_std is None or args.image_std == default_image_std + else args.image_std + ) # pyright: ignore[reportAttributeAccessIssue] else: - image_mean = args.image_mean if args.image_mean is not None else default_image_mean + image_mean = ( + args.image_mean if args.image_mean is not None else default_image_mean + ) image_std = args.image_std if args.image_std is not None else default_image_std fout.add_array("clip.vision.image_mean", image_mean) fout.add_array("clip.vision.image_std", image_std) @@ -352,7 +471,9 @@ def get_non_negative_vision_feature_layers(v_hparams): if feature_layers is None: model.vision_model.encoder.layers.pop(-1) else: - model.vision_model.encoder.layers = model.vision_model.encoder.layers[:max(feature_layers)] + model.vision_model.encoder.layers = model.vision_model.encoder.layers[ + : max(feature_layers) + ] projector = torch.load(args.llava_projector) for name, data in projector.items(): @@ -369,7 +490,9 @@ def get_non_negative_vision_feature_layers(v_hparams): state_dict = model.state_dict() for name, data in state_dict.items(): - if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_llava_projector): + if should_skip_tensor( + name, has_text_encoder, has_vision_encoder, has_llava_projector + ): # we don't need this print(f"skipping parameter: {name}") continue diff --git a/smallthinker/tools/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py b/smallthinker/tools/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py index 848ef1cf..69738a97 100644 --- a/smallthinker/tools/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py +++ b/smallthinker/tools/mtmd/legacy-models/glmedge-convert-image-encoder-to-gguf.py @@ -11,11 +11,14 @@ VISION = "clip.vision" from transformers import SiglipVisionModel, SiglipVisionConfig + def k(raw_key: str, arch: str) -> str: return raw_key.format(arch=arch) -def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: bool) -> bool: +def should_skip_tensor( + name: str, has_text: bool, has_vision: bool, has_llava: bool +) -> bool: if name in ( "logit_scale", "text_model.embeddings.position_ids", @@ -34,7 +37,7 @@ def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_llava: b "vision_model.head.mlp.fc1.weight", "vision_model.head.mlp.fc1.bias", "vision_model.head.mlp.fc2.weight", - "vision_model.head.mlp.fc2.bias" + "vision_model.head.mlp.fc2.bias", ): return True @@ -52,11 +55,25 @@ def get_tensor_name(name: str) -> str: return name if "mm_projector" in name: name = name.replace("model.mm_projector", "mm") - name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1) - name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1) + name = re.sub(r"mm\.mlp\.mlp", "mm.model.mlp", name, count=1) + name = re.sub(r"mm\.peg\.peg", "mm.model.peg", name, count=1) return name - return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln") + return ( + name.replace("text_model", "t") + .replace("vision_model", "v") + .replace("encoder.layers", "blk") + .replace("embeddings.", "") + .replace("_proj", "") + .replace("self_attn.", "attn_") + .replace("layer_norm", "ln") + .replace("layernorm", "ln") + .replace("mlp.fc1", "ffn_down") + .replace("mlp.fc2", "ffn_up") + .replace("embedding", "embd") + .replace("final", "post") + .replace("layrnorm", "ln") + ) def bytes_to_unicode(): @@ -86,41 +103,97 @@ def bytes_to_unicode(): ap = argparse.ArgumentParser() -ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True) -ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16") -ap.add_argument("--text-only", action="store_true", required=False, - help="Save a text-only model. It can't be used to encode images") -ap.add_argument("--vision-only", action="store_true", required=False, - help="Save a vision-only model. It can't be used to encode texts") -ap.add_argument("--clip-model-is-vision", action="store_true", required=False, - help="The clip model is a pure vision model (ShareGPT4V vision extract for example)") -ap.add_argument("--clip-model-is-openclip", action="store_true", required=False, - help="The clip model is from openclip (for ViT-SO400M type))") -ap.add_argument("--llava-projector", help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.") -ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2","adapter"], default="adapter") -ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None) +ap.add_argument( + "-m", + "--model-dir", + help="Path to model directory cloned from HF Hub", + required=True, +) +ap.add_argument( + "--use-f32", action="store_true", default=False, help="Use f32 instead of f16" +) +ap.add_argument( + "--text-only", + action="store_true", + required=False, + help="Save a text-only model. It can't be used to encode images", +) +ap.add_argument( + "--vision-only", + action="store_true", + required=False, + help="Save a vision-only model. It can't be used to encode texts", +) +ap.add_argument( + "--clip-model-is-vision", + action="store_true", + required=False, + help="The clip model is a pure vision model (ShareGPT4V vision extract for example)", +) +ap.add_argument( + "--clip-model-is-openclip", + action="store_true", + required=False, + help="The clip model is from openclip (for ViT-SO400M type))", +) +ap.add_argument( + "--llava-projector", + help="Path to llava.projector file. If specified, save an image encoder for LLaVA models.", +) +ap.add_argument( + "--projector-type", + help="Type of projector. Possible values: mlp, ldp, ldpv2", + choices=["mlp", "ldp", "ldpv2", "adapter"], + default="adapter", +) +ap.add_argument( + "-o", + "--output-dir", + help="Directory to save GGUF files. Default is the original model directory", + default=None, +) # Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711 # Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5 default_image_mean = [0.5, 0.5, 0.5] default_image_std = [0.5, 0.5, 0.5] -ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) -ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) +ap.add_argument( + "--image-mean", + type=float, + nargs="+", + help="Mean of the images for normalization (overrides processor) ", + default=None, +) +ap.add_argument( + "--image-std", + type=float, + nargs="+", + help="Standard deviation of the images for normalization (overrides processor)", + default=None, +) # with proper args = ap.parse_args() if args.text_only and args.vision_only: - print("--text-only and --image-only arguments cannot be specified at the same time.") + print( + "--text-only and --image-only arguments cannot be specified at the same time." + ) exit(1) if args.use_f32: - print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.") + print( + "WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet." + ) # output in the same directory as the model if output_dir is None dir_model = args.model_dir -if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip: +if ( + args.clip_model_is_vision + or not os.path.exists(dir_model + "/vocab.json") + or args.clip_model_is_openclip +): vocab = None tokens = None else: @@ -179,7 +252,11 @@ def bytes_to_unicode(): fout.add_bool("clip.has_vision_encoder", has_vision_encoder) fout.add_bool("clip.has_glm_projector", has_glm_projector) fout.add_file_type(ftype) -model_name = config["_name_or_path"] if "_name_or_path" in config else os.path.basename(dir_model) +model_name = ( + config["_name_or_path"] + if "_name_or_path" in config + else os.path.basename(dir_model) +) fout.add_name(model_name) if has_glm_projector: fout.add_description("image encoder for glm4v") @@ -194,7 +271,10 @@ def bytes_to_unicode(): fout.add_uint32(k(KEY_CONTEXT_LENGTH, TEXT), t_hparams["max_position_embeddings"]) fout.add_uint32(k(KEY_EMBEDDING_LENGTH, TEXT), t_hparams["hidden_size"]) fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, TEXT), t_hparams["intermediate_size"]) - fout.add_uint32("clip.text.projection_dim", t_hparams.get("projection_dim", config["projection_dim"])) + fout.add_uint32( + "clip.text.projection_dim", + t_hparams.get("projection_dim", config["projection_dim"]), + ) fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, TEXT), t_hparams["num_attention_heads"]) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, TEXT), t_hparams["layer_norm_eps"]) fout.add_uint32(k(KEY_BLOCK_COUNT, TEXT), t_hparams["num_hidden_layers"]) @@ -207,7 +287,9 @@ def bytes_to_unicode(): fout.add_uint32(k(KEY_EMBEDDING_LENGTH, VISION), v_hparams["hidden_size"]) fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), v_hparams["intermediate_size"]) fout.add_uint32("clip.vision.projection_dim", 0) - fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"]) + fout.add_uint32( + k(KEY_ATTENTION_HEAD_COUNT, VISION), v_hparams["num_attention_heads"] + ) fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), v_hparams["num_hidden_layers"]) @@ -230,14 +312,16 @@ def bytes_to_unicode(): else: data = data.squeeze().numpy().astype(np.float32) if name.startswith("vision."): - name=name.replace("vision.","") + name = name.replace("vision.", "") fout.add_tensor(name, data) print(f"Projector {name} - {data.dtype} - shape = {data.shape}") # print(f"Projector {name} tensors added\n") state_dict = model.state_dict() # pyright: ignore[reportAttributeAccessIssue] for name, data in state_dict.items(): - if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_glm_projector): + if should_skip_tensor( + name, has_text_encoder, has_vision_encoder, has_glm_projector + ): # we don't need this print(f"skipping parameter: {name}") continue diff --git a/smallthinker/tools/mtmd/legacy-models/glmedge-surgery.py b/smallthinker/tools/mtmd/legacy-models/glmedge-surgery.py index 16bb915d..ffb46f7f 100644 --- a/smallthinker/tools/mtmd/legacy-models/glmedge-surgery.py +++ b/smallthinker/tools/mtmd/legacy-models/glmedge-surgery.py @@ -8,7 +8,9 @@ args = ap.parse_args() # find the model part that includes the the multimodal projector weights -model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True) +model = AutoModel.from_pretrained( + args.model, trust_remote_code=True, local_files_only=True +) checkpoint = model.state_dict() # get a list of mm tensor names @@ -18,9 +20,14 @@ projector = {name: checkpoint[name].float() for name in mm_tensors} torch.save(projector, f"{args.model}/glm.projector") -clip_tensors = [k for k, v in checkpoint.items() if k.startswith("vision.vit.model.vision_model.")] +clip_tensors = [ + k for k, v in checkpoint.items() if k.startswith("vision.vit.model.vision_model.") +] if len(clip_tensors) > 0: - clip = {name.replace("vision.vit.model.", ""): checkpoint[name].float() for name in clip_tensors} + clip = { + name.replace("vision.vit.model.", ""): checkpoint[name].float() + for name in clip_tensors + } torch.save(clip, f"{args.model}/glm.clip") # added tokens should be removed to be able to convert Mistral models diff --git a/smallthinker/tools/mtmd/legacy-models/llava_surgery.py b/smallthinker/tools/mtmd/legacy-models/llava_surgery.py index 4f2da3be..0aa96212 100644 --- a/smallthinker/tools/mtmd/legacy-models/llava_surgery.py +++ b/smallthinker/tools/mtmd/legacy-models/llava_surgery.py @@ -22,17 +22,18 @@ # BakLLaVA models contain CLIP tensors in it clip_tensors = [k for k, v in checkpoint.items() if k.startswith("model.vision_tower")] if len(clip_tensors) > 0: - clip = {name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() for name in clip_tensors} + clip = { + name.replace("vision_tower.vision_tower.", ""): checkpoint[name].float() + for name in clip_tensors + } torch.save(clip, f"{args.model}/llava.clip") - # added tokens should be removed to be able to convert Mistral models if os.path.exists(f"{args.model}/added_tokens.json"): with open(f"{args.model}/added_tokens.json", "w") as f: f.write("{}\n") - print("Done!") print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") print(f"Also, use {args.model}/llava.projector to prepare a llava-encoder.gguf file.") diff --git a/smallthinker/tools/mtmd/legacy-models/llava_surgery_v2.py b/smallthinker/tools/mtmd/legacy-models/llava_surgery_v2.py index b07c3e32..68ef75b7 100644 --- a/smallthinker/tools/mtmd/legacy-models/llava_surgery_v2.py +++ b/smallthinker/tools/mtmd/legacy-models/llava_surgery_v2.py @@ -6,61 +6,69 @@ from safetensors.torch import save_file from typing import Any, ContextManager, cast + # Function to determine if file is a SafeTensor file def is_safetensor_file(file_path): - return file_path.endswith('.safetensors') + return file_path.endswith(".safetensors") # Unified loading function def load_model(file_path): if is_safetensor_file(file_path): tensors = {} - with cast(ContextManager[Any], safe_open(file_path, framework="pt", device="cpu")) as f: + with cast( + ContextManager[Any], safe_open(file_path, framework="pt", device="cpu") + ) as f: for key in f.keys(): tensors[key] = f.get_tensor(key).clone() # output shape print(f"{key} : {tensors[key].shape}") - return tensors, 'safetensor' + return tensors, "safetensor" else: - return torch.load(file_path, map_location=torch.device('cpu')), 'pytorch' + return torch.load(file_path, map_location=torch.device("cpu")), "pytorch" # Unified saving function def save_model(model, file_path, file_type): - if file_type == 'safetensor': + if file_type == "safetensor": # safe_save(model, file_path) save_file(model, file_path) else: torch.save(model, file_path) + # Helpers to match weight names from specific components or # determine if a saved shard contains that component def is_vision_tower(weight_name): return ( - weight_name.startswith("model.vision_tower") or - weight_name.startswith("vit.") or - weight_name.startswith("vision_tower") + weight_name.startswith("model.vision_tower") + or weight_name.startswith("vit.") + or weight_name.startswith("vision_tower") ) + def is_newline(weight_name): - return ( - weight_name.startswith("model.image_newline") or - weight_name.startswith("image_newline") + return weight_name.startswith("model.image_newline") or weight_name.startswith( + "image_newline" ) + def is_mm_projector(weight_name): return ( - weight_name.startswith("model.mm_projector") or - weight_name.startswith("vision_proj.") or - weight_name.startswith("multi_modal_projector") + weight_name.startswith("model.mm_projector") + or weight_name.startswith("vision_proj.") + or weight_name.startswith("multi_modal_projector") ) + def newline_criteria(checkpoint): return any(is_newline(k) for k in checkpoint.keys()) + def proj_criteria(checkpoint): return any(is_mm_projector(k) for k in checkpoint.keys()) + # Adapted function to clean vision tower from checkpoint def clean_vision_tower_from_checkpoint(checkpoint_path): checkpoint, file_type = load_model(checkpoint_path) @@ -82,13 +90,15 @@ def clean_vision_tower_from_checkpoint(checkpoint_path): existing_clip = {} # Update existing_clip with new tensors, avoid duplicates for name in clip_tensors: - simple_name = name[name.index('vision_model.'):] if 'vision_model.' in name else name + simple_name = ( + name[name.index("vision_model.") :] if "vision_model." in name else name + ) print(f"Adding {simple_name} to llava.clip") if simple_name not in existing_clip: existing_clip[simple_name] = checkpoint[name] # Save the updated clip tensors back to llava.clip - save_model(existing_clip, clip_path, 'pytorch') + save_model(existing_clip, clip_path, "pytorch") # Remove the tensors from the original checkpoint for name in clip_tensors: @@ -98,6 +108,7 @@ def clean_vision_tower_from_checkpoint(checkpoint_path): return True return False + def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector): newline_checkpoint_path = None projector_checkpoint_path = None @@ -115,28 +126,54 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector): # Command-line interface setup ap = argparse.ArgumentParser() ap.add_argument("-m", "--model", required=True, help="Path to LLaVA v1.5+ model") -ap.add_argument("-C", "--clean-vision-tower", action="store_true", help="Remove any vision tower from the model files") +ap.add_argument( + "-C", + "--clean-vision-tower", + action="store_true", + help="Remove any vision tower from the model files", +) args = ap.parse_args() if args.clean_vision_tower: # Generalized to handle both PyTorch and SafeTensors models - model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True) + model_files = sorted( + glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True + ) # checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and path.startswith('pytorch')) or (path.endswith('.safetensors') and path.startswith('model'))] - checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])] + checkpoint_paths = [ + path + for path in model_files + if (path.endswith(".bin") and "pytorch" in path.split("/")[-1].split("\\")[-1]) + or ( + path.endswith(".safetensors") + and "model" in path.split("/")[-1].split("\\")[-1] + ) + ] for projector_checkpoint_path in checkpoint_paths: print(f"Cleaning {projector_checkpoint_path}") if not clean_vision_tower_from_checkpoint(projector_checkpoint_path): print(f"No vision tower found in {projector_checkpoint_path}") # we break once none is found, so far all models append them at the end # break - print("Done! All vision tower tensors are removed from the model files and stored in llava.clip file.") + print( + "Done! All vision tower tensors are removed from the model files and stored in llava.clip file." + ) # Now we look for the projector in the last checkpoint model_files = sorted(glob.glob(f"{args.model}/*"), key=os.path.getmtime, reverse=True) -checkpoint_paths = [path for path in model_files if (path.endswith('.bin') and 'pytorch' in path.split('/')[-1].split('\\')[-1]) or (path.endswith('.safetensors') and 'model' in path.split('/')[-1].split('\\')[-1])] +checkpoint_paths = [ + path + for path in model_files + if (path.endswith(".bin") and "pytorch" in path.split("/")[-1].split("\\")[-1]) + or ( + path.endswith(".safetensors") and "model" in path.split("/")[-1].split("\\")[-1] + ) +] # last_checkpoint_path = checkpoint_paths[0] # first_checkpoint_path = checkpoint_paths[-1] -newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints(checkpoint_paths, newline_criteria, proj_criteria) +newline_checkpoint_path, projector_checkpoint_path = find_relevant_checkpoints( + checkpoint_paths, newline_criteria, proj_criteria +) print(f"Taking projector from {projector_checkpoint_path}") first_mm_tensors = [] @@ -157,7 +194,9 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector): if last_checkpoint is not None: for k, v in last_checkpoint.items(): print(k) - print(f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors.") + print( + f"Found {len(mm_tensors)} tensors to extract out of {len(last_checkpoint) if last_checkpoint is not None else 0} tensors." + ) print("No tensors found. Is this a LLaVA model?") exit() @@ -173,7 +212,7 @@ def find_relevant_checkpoints(checkpoint_paths, newline_criteria, projector): projector[name] = first_checkpoint[name].float() if len(projector) > 0: - save_model(projector, f"{args.model}/llava.projector", 'pytorch') + save_model(projector, f"{args.model}/llava.projector", "pytorch") print("Done!") print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") diff --git a/smallthinker/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py b/smallthinker/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py index cfe0961f..90d71925 100644 --- a/smallthinker/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py +++ b/smallthinker/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch Siglip model. """ +"""PyTorch Siglip model.""" # Copied from HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit and add tgt_sizes @@ -37,6 +37,7 @@ logger = logging.get_logger(__name__) + class SiglipVisionConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`SiglipVisionModel`]. It is used to instantiate a @@ -107,6 +108,7 @@ def __init__( self.layer_norm_eps = layer_norm_eps self.hidden_act = hidden_act + _CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ @@ -114,6 +116,7 @@ def __init__( # See all SigLIP models at https://huggingface.co/models?filter=siglip ] + # Copied from transformers.models.llama.modeling_llama._get_unpad_data def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) @@ -177,7 +180,11 @@ def norm_cdf(x): def trunc_normal_tf_( - tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, ): """Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the @@ -233,6 +240,7 @@ def lecun_normal_(tensor): def default_flax_embed_init(tensor): variance_scaling_(tensor, mode="fan_in", distribution="normal") + class SiglipVisionEmbeddings(nn.Module): def __init__(self, config: SiglipVisionConfig): super().__init__() @@ -254,6 +262,7 @@ def __init__(self, config: SiglipVisionConfig): self.num_positions = self.num_patches self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + class SiglipAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -277,6 +286,7 @@ def __init__(self, config): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip class SiglipMLP(nn.Module): def __init__(self, config): @@ -293,13 +303,12 @@ def __init__(self, config: SiglipVisionConfig): super().__init__() self.embed_dim = config.hidden_size self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self.self_attn = ( - SiglipAttention(config) - ) + self.self_attn = SiglipAttention(config) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP(config) self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + class SiglipPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -383,9 +392,12 @@ class SiglipEncoder(nn.Module): def __init__(self, config: SiglipVisionConfig): super().__init__() self.config = config - self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList( + [SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) self.gradient_checkpointing = False + class SiglipVisionTransformer(SiglipPreTrainedModel): config_class = SiglipVisionConfig main_input_name = "pixel_values" @@ -407,13 +419,17 @@ def __init__(self, config: SiglipVisionConfig): def get_input_embeddings(self) -> nn.Module: return self.embeddings.patch_embedding + import argparse import json import re import numpy as np from gguf import * -from transformers.models.idefics2.modeling_idefics2 import Idefics2VisionTransformer, Idefics2VisionConfig +from transformers.models.idefics2.modeling_idefics2 import ( + Idefics2VisionTransformer, + Idefics2VisionConfig, +) TEXT = "clip.text" VISION = "clip.vision" @@ -423,7 +439,9 @@ def add_key_str(raw_key: str, arch: str) -> str: return raw_key.format(arch=arch) -def should_skip_tensor(name: str, has_text: bool, has_vision: bool, has_minicpmv: bool) -> bool: +def should_skip_tensor( + name: str, has_text: bool, has_vision: bool, has_minicpmv: bool +) -> bool: if name in ( "logit_scale", "text_model.embeddings.position_ids", @@ -448,11 +466,25 @@ def get_tensor_name(name: str) -> str: return name if "mm_projector" in name: name = name.replace("model.mm_projector", "mm") - name = re.sub(r'mm\.mlp\.mlp', 'mm.model.mlp', name, count=1) - name = re.sub(r'mm\.peg\.peg', 'mm.model.peg', name, count=1) + name = re.sub(r"mm\.mlp\.mlp", "mm.model.mlp", name, count=1) + name = re.sub(r"mm\.peg\.peg", "mm.model.peg", name, count=1) return name - return name.replace("text_model", "t").replace("vision_model", "v").replace("encoder.layers", "blk").replace("embeddings.", "").replace("_proj", "").replace("self_attn.", "attn_").replace("layer_norm", "ln").replace("layernorm", "ln").replace("mlp.fc1", "ffn_down").replace("mlp.fc2", "ffn_up").replace("embedding", "embd").replace("final", "post").replace("layrnorm", "ln") + return ( + name.replace("text_model", "t") + .replace("vision_model", "v") + .replace("encoder.layers", "blk") + .replace("embeddings.", "") + .replace("_proj", "") + .replace("self_attn.", "attn_") + .replace("layer_norm", "ln") + .replace("layernorm", "ln") + .replace("mlp.fc1", "ffn_down") + .replace("mlp.fc2", "ffn_up") + .replace("embedding", "embd") + .replace("final", "post") + .replace("layrnorm", "ln") + ) def bytes_to_unicode(): @@ -482,42 +514,103 @@ def bytes_to_unicode(): ap = argparse.ArgumentParser() -ap.add_argument("-m", "--model-dir", help="Path to model directory cloned from HF Hub", required=True) -ap.add_argument("--use-f32", action="store_true", default=False, help="Use f32 instead of f16") -ap.add_argument("--text-only", action="store_true", required=False, - help="Save a text-only model. It can't be used to encode images") -ap.add_argument("--vision-only", action="store_true", required=False, - help="Save a vision-only model. It can't be used to encode texts") -ap.add_argument("--clip-model-is-vision", action="store_true", required=False, - help="The clip model is a pure vision model (ShareGPT4V vision extract for example)") -ap.add_argument("--clip-model-is-openclip", action="store_true", required=False, - help="The clip model is from openclip (for ViT-SO400M type))") -ap.add_argument("--minicpmv-projector", help="Path to minicpmv.projector file. If specified, save an image encoder for MiniCPM-V models.") -ap.add_argument("--projector-type", help="Type of projector. Possible values: mlp, ldp, ldpv2", choices=["mlp", "ldp", "ldpv2"], default="mlp") -ap.add_argument("-o", "--output-dir", help="Directory to save GGUF files. Default is the original model directory", default=None) +ap.add_argument( + "-m", + "--model-dir", + help="Path to model directory cloned from HF Hub", + required=True, +) +ap.add_argument( + "--use-f32", action="store_true", default=False, help="Use f32 instead of f16" +) +ap.add_argument( + "--text-only", + action="store_true", + required=False, + help="Save a text-only model. It can't be used to encode images", +) +ap.add_argument( + "--vision-only", + action="store_true", + required=False, + help="Save a vision-only model. It can't be used to encode texts", +) +ap.add_argument( + "--clip-model-is-vision", + action="store_true", + required=False, + help="The clip model is a pure vision model (ShareGPT4V vision extract for example)", +) +ap.add_argument( + "--clip-model-is-openclip", + action="store_true", + required=False, + help="The clip model is from openclip (for ViT-SO400M type))", +) +ap.add_argument( + "--minicpmv-projector", + help="Path to minicpmv.projector file. If specified, save an image encoder for MiniCPM-V models.", +) +ap.add_argument( + "--projector-type", + help="Type of projector. Possible values: mlp, ldp, ldpv2", + choices=["mlp", "ldp", "ldpv2"], + default="mlp", +) +ap.add_argument( + "-o", + "--output-dir", + help="Directory to save GGUF files. Default is the original model directory", + default=None, +) # Example --image_mean 0.48145466 0.4578275 0.40821073 --image_std 0.26862954 0.26130258 0.27577711 # Example --image_mean 0.5 0.5 0.5 --image_std 0.5 0.5 0.5 default_image_mean = [0.48145466, 0.4578275, 0.40821073] default_image_std = [0.26862954, 0.26130258, 0.27577711] -ap.add_argument('--image-mean', type=float, nargs='+', help='Mean of the images for normalization (overrides processor) ', default=None) -ap.add_argument('--image-std', type=float, nargs='+', help='Standard deviation of the images for normalization (overrides processor)', default=None) -ap.add_argument('--minicpmv_version', type=int, help='minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4', default=2) +ap.add_argument( + "--image-mean", + type=float, + nargs="+", + help="Mean of the images for normalization (overrides processor) ", + default=None, +) +ap.add_argument( + "--image-std", + type=float, + nargs="+", + help="Standard deviation of the images for normalization (overrides processor)", + default=None, +) +ap.add_argument( + "--minicpmv_version", + type=int, + help="minicpmv_version: MiniCPM-V-2 use 1; MiniCPM-V-2.5 use 2; MiniCPM-V-2.6 use 3; MiniCPM-o-2.6 use 4", + default=2, +) # with proper args = ap.parse_args() if args.text_only and args.vision_only: - print("--text-only and --image-only arguments cannot be specified at the same time.") + print( + "--text-only and --image-only arguments cannot be specified at the same time." + ) exit(1) if args.use_f32: - print("WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet.") + print( + "WARNING: Weights for the convolution op is always saved in f16, as the convolution op in GGML does not support 32-bit kernel weights yet." + ) # output in the same directory as the model if output_dir is None dir_model = args.model_dir -if args.clip_model_is_vision or not os.path.exists(dir_model + "/vocab.json") or args.clip_model_is_openclip: +if ( + args.clip_model_is_vision + or not os.path.exists(dir_model + "/vocab.json") + or args.clip_model_is_openclip +): vocab = None tokens = None else: @@ -560,14 +653,14 @@ def bytes_to_unicode(): block_count = 27 default_vision_config = { - "hidden_size": 1152, - "image_size": 980, - "intermediate_size": 4304, - "model_type": "idefics2", - "num_attention_heads": 16, - "num_hidden_layers": 27, - "patch_size": 14, - } + "hidden_size": 1152, + "image_size": 980, + "intermediate_size": 4304, + "model_type": "idefics2", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, +} vision_config = Idefics2VisionConfig(**default_vision_config) model = Idefics2VisionTransformer(vision_config) @@ -637,10 +730,20 @@ def bytes_to_unicode(): fout.add_uint32(add_key_str(KEY_BLOCK_COUNT, VISION), block_count) if processor is not None: - image_mean = processor.image_processor.image_mean if args.image_mean is None or args.image_mean == default_image_mean else args.image_mean - image_std = processor.image_processor.image_std if args.image_std is None or args.image_std == default_image_std else args.image_std + image_mean = ( + processor.image_processor.image_mean + if args.image_mean is None or args.image_mean == default_image_mean + else args.image_mean + ) + image_std = ( + processor.image_processor.image_std + if args.image_std is None or args.image_std == default_image_std + else args.image_std + ) else: - image_mean = args.image_mean if args.image_mean is not None else default_image_mean + image_mean = ( + args.image_mean if args.image_mean is not None else default_image_mean + ) image_std = args.image_std if args.image_std is not None else default_image_std fout.add_array("clip.vision.image_mean", image_mean) fout.add_array("clip.vision.image_std", image_std) @@ -648,6 +751,7 @@ def bytes_to_unicode(): use_gelu = True fout.add_bool("clip.use_gelu", use_gelu) + def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ embed_dim: output dimension for each position @@ -656,11 +760,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float32) - omega /= embed_dim / 2. - omega = 1. / 10000 ** omega # (D/2,) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product emb_sin = np.sin(out) # (M, D/2) emb_cos = np.cos(out) # (M, D/2) @@ -668,6 +772,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb + def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): assert embed_dim % 2 == 0 @@ -702,15 +807,20 @@ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) return pos_embed + def _replace_name_resampler(s, v): if re.match("resampler.pos_embed", s): return { s: v, - re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))), + re.sub("pos_embed", "pos_embed_k", s): torch.from_numpy( + get_2d_sincos_pos_embed(emb_dim, (70, 70)) + ), } if re.match("resampler.proj", s): return { - re.sub("proj", "pos_embed_k", s): torch.from_numpy(get_2d_sincos_pos_embed(emb_dim, (70, 70))), + re.sub("proj", "pos_embed_k", s): torch.from_numpy( + get_2d_sincos_pos_embed(emb_dim, (70, 70)) + ), re.sub("proj", "proj.weight", s): v.transpose(-1, -2).contiguous(), } if re.match("resampler.attn.in_proj_.*", s): @@ -721,6 +831,7 @@ def _replace_name_resampler(s, v): } return {s: v} + if has_minicpmv_projector: projector = torch.load(args.minicpmv_projector) new_state_dict = {} @@ -755,6 +866,7 @@ def _replace_name_resampler(s, v): print("Projector tensors added\n") + def _replace_name(s, v): s = "vision_model." + s if re.match("vision_model.embeddings.position_embedding", s): @@ -763,6 +875,7 @@ def _replace_name(s, v): return {s: v} + state_dict = model.state_dict() new_state_dict = {} for k, v in state_dict.items(): @@ -771,7 +884,9 @@ def _replace_name(s, v): new_state_dict[nk] = nv state_dict = new_state_dict for name, data in state_dict.items(): - if should_skip_tensor(name, has_text_encoder, has_vision_encoder, has_minicpmv_projector): + if should_skip_tensor( + name, has_text_encoder, has_vision_encoder, has_minicpmv_projector + ): # we don't need this print(f"skipping parameter: {name}") continue diff --git a/smallthinker/tools/mtmd/legacy-models/minicpmv-surgery.py b/smallthinker/tools/mtmd/legacy-models/minicpmv-surgery.py index ba821165..62bca488 100644 --- a/smallthinker/tools/mtmd/legacy-models/minicpmv-surgery.py +++ b/smallthinker/tools/mtmd/legacy-models/minicpmv-surgery.py @@ -8,7 +8,12 @@ args = ap.parse_args() # find the model part that includes the the multimodal projector weights -model = AutoModel.from_pretrained(args.model, trust_remote_code=True, local_files_only=True, torch_dtype=torch.bfloat16) +model = AutoModel.from_pretrained( + args.model, + trust_remote_code=True, + local_files_only=True, + torch_dtype=torch.bfloat16, +) checkpoint = model.state_dict() # get a list of mm tensor names @@ -34,7 +39,7 @@ "AutoModel": "modeling_minicpm.MiniCPMModel", "AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM", "AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM", - "AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification" + "AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification", } model.llm.save_pretrained(f"{args.model}/model") tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) @@ -42,4 +47,6 @@ print("Done!") print(f"Now you can convert {args.model} to a regular LLaMA GGUF file.") -print(f"Also, use {args.model}/minicpmv.projector to prepare a minicpmv-encoder.gguf file.") +print( + f"Also, use {args.model}/minicpmv.projector to prepare a minicpmv-encoder.gguf file." +) diff --git a/smallthinker/tools/server/README.md b/smallthinker/tools/server/README.md index 06533c17..90adc973 100644 --- a/smallthinker/tools/server/README.md +++ b/smallthinker/tools/server/README.md @@ -208,7 +208,7 @@ services: llamacpp-server: image: ghcr.io/ggml-org/llama.cpp:server ports: - - 8080:8080 + - 18001:8080 volumes: - ./models:/models environment: @@ -254,6 +254,8 @@ For more details, please refer to [multimodal documentation](../../docs/multimod The project includes a web-based user interface that enables interaction with the model through the `/chat/completions` endpoint. +For Chrome AI Hub integration in this repository, prefer Docker runtime and map host port `18001` to container port `8080`. + The web UI is developed using: - `react` framework for frontend development - `tailwindcss` and `daisyui` for styling @@ -261,16 +263,13 @@ The web UI is developed using: A pre-built version is available as a single HTML file under `/public` directory. -To build or to run the dev server (with hot reload): +To build the web UI assets: ```sh # make sure you have nodejs installed cd tools/server/webui npm i -# to run the dev server -npm run dev - # to build the public/index.html.gz npm run build ``` @@ -279,38 +278,23 @@ headers (like build/tools/server/index.html.gz.hpp) that will be included by server.cpp. This is done by building `llama-server` as described in the [build](#build) section above. -NOTE: if you are using the vite dev server, you can change the API base URL to llama.cpp. To do that, run this code snippet in browser's console: +NOTE: for Chrome AI Hub integration, if you need to override API base URL in browser console: ```js -localStorage.setItem('base', 'http://localhost:8080') +localStorage.setItem('base', 'http://localhost:18001') ``` ## Quick Start -To get started right away, run the following command, making sure to use the correct path for the model you have: - -### Unix-based systems (Linux, macOS, etc.) - -```bash -./llama-server -m models/7B/ggml-model.gguf -c 2048 -``` - -### Windows - -```powershell -llama-server.exe -m models\7B\ggml-model.gguf -c 2048 -``` - -The above command will start a server that by default listens on `127.0.0.1:8080`. -You can consume the endpoints with Postman or NodeJS with axios library. You can visit the web front end at the same url. +For Chrome AI Hub usage in this repository, run the server in Docker only and use host endpoint `http://localhost:18001`. ### Docker ```bash -docker run -p 8080:8080 -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:server -m models/7B/ggml-model.gguf -c 512 --host 0.0.0.0 --port 8080 +docker run -p 18001:8080 -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:server -m models/7B/ggml-model.gguf -c 512 --host 0.0.0.0 --port 8080 # or, with CUDA: -docker run -p 8080:8080 -v /path/to/models:/models --gpus all ghcr.io/ggml-org/llama.cpp:server-cuda -m models/7B/ggml-model.gguf -c 512 --host 0.0.0.0 --port 8080 --n-gpu-layers 99 +docker run -p 18001:8080 -v /path/to/models:/models --gpus all ghcr.io/ggml-org/llama.cpp:server-cuda -m models/7B/ggml-model.gguf -c 512 --host 0.0.0.0 --port 8080 --n-gpu-layers 99 ``` ## Testing with CURL @@ -319,7 +303,7 @@ Using [curl](https://curl.se/). On Windows, `curl.exe` should be available in th ```sh curl --request POST \ - --url http://localhost:8080/completion \ + --url http://localhost:18001/completion \ --header "Content-Type: application/json" \ --data '{"prompt": "Building a website can be done in 10 simple steps:","n_predict": 128}' ``` @@ -345,7 +329,7 @@ Create an index.js file and put this inside: const prompt = "Building a website can be done in 10 simple steps:" async function test() { - let response = await fetch("http://127.0.0.1:8080/completion", { + let response = await fetch("http://localhost:18001/completion", { method: "POST", body: JSON.stringify({ prompt, @@ -1089,7 +1073,7 @@ Example usage with `openai` python library: import openai client = openai.OpenAI( - base_url="http://localhost:8080/v1", # "http://:port" + base_url="http://localhost:18001/v1", # "http://:port" api_key = "sk-no-key-required" ) @@ -1122,7 +1106,7 @@ You can use either Python `openai` library with appropriate checkpoints: import openai client = openai.OpenAI( - base_url="http://localhost:8080/v1", # "http://:port" + base_url="http://localhost:18001/v1", # "http://:port" api_key = "sk-no-key-required" ) @@ -1140,7 +1124,7 @@ print(completion.choices[0].message) ... or raw HTTP requests: ```shell -curl http://localhost:8080/v1/chat/completions \ +curl http://localhost:18001/v1/chat/completions \ -H "Content-Type: application/json" \ -H "Authorization: Bearer no-key" \ -d '{ @@ -1177,7 +1161,7 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r - input as string ```shell - curl http://localhost:8080/v1/embeddings \ + curl http://localhost:18001/v1/embeddings \ -H "Content-Type: application/json" \ -H "Authorization: Bearer no-key" \ -d '{ @@ -1190,7 +1174,7 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r - `input` as string array ```shell - curl http://localhost:8080/v1/embeddings \ + curl http://localhost:18001/v1/embeddings \ -H "Content-Type: application/json" \ -H "Authorization: Bearer no-key" \ -d '{ @@ -1272,7 +1256,7 @@ A new chat-based UI has replaced the old completion-based since [this PR](https: For example: ```sh -./llama-server -m my_model.gguf -c 8192 --path ./tools/server/public_legacy +docker run -p 18001:8080 -v /path/to/models:/models ghcr.io/ggml-org/llama.cpp:server -m /models/my_model.gguf -c 8192 --host 0.0.0.0 --port 8080 --path ./tools/server/public_legacy ``` ### Extending or building alternative Web Front End diff --git a/smallthinker/tools/server/bench/bench.py b/smallthinker/tools/server/bench/bench.py index 5cc6f92a..b30cc4af 100644 --- a/smallthinker/tools/server/bench/bench.py +++ b/smallthinker/tools/server/bench/bench.py @@ -27,24 +27,65 @@ def main(args_in: list[str] | None = None) -> None: parser.add_argument("--runner-label", type=str, help="Runner label", required=True) parser.add_argument("--branch", type=str, help="Branch name", default="detached") parser.add_argument("--commit", type=str, help="Commit name", default="dirty") - parser.add_argument("--host", type=str, help="Server listen host", default="0.0.0.0") + parser.add_argument( + "--host", type=str, help="Server listen host", default="0.0.0.0" + ) parser.add_argument("--port", type=int, help="Server listen host", default="8080") - parser.add_argument("--model-path-prefix", type=str, help="Prefix where to store the model files", default="models") - parser.add_argument("--n-prompts", type=int, - help="SERVER_BENCH_N_PROMPTS: total prompts to randomly select in the benchmark", required=True) - parser.add_argument("--max-prompt-tokens", type=int, - help="SERVER_BENCH_MAX_PROMPT_TOKENS: maximum prompt tokens to filter out in the dataset", - required=True) - parser.add_argument("--max-tokens", type=int, - help="SERVER_BENCH_MAX_CONTEXT: maximum context size of the completions request to filter out in the dataset: prompt + predicted tokens", - required=True) - parser.add_argument("--hf-repo", type=str, help="Hugging Face model repository", required=True) - parser.add_argument("--hf-file", type=str, help="Hugging Face model file", required=True) - parser.add_argument("-ngl", "--n-gpu-layers", type=int, help="layers to the GPU for computation", required=True) - parser.add_argument("--ctx-size", type=int, help="Set the size of the prompt context", required=True) - parser.add_argument("--parallel", type=int, help="Set the number of slots for process requests", required=True) - parser.add_argument("--batch-size", type=int, help="Set the batch size for prompt processing", required=True) - parser.add_argument("--ubatch-size", type=int, help="physical maximum batch size", required=True) + parser.add_argument( + "--model-path-prefix", + type=str, + help="Prefix where to store the model files", + default="models", + ) + parser.add_argument( + "--n-prompts", + type=int, + help="SERVER_BENCH_N_PROMPTS: total prompts to randomly select in the benchmark", + required=True, + ) + parser.add_argument( + "--max-prompt-tokens", + type=int, + help="SERVER_BENCH_MAX_PROMPT_TOKENS: maximum prompt tokens to filter out in the dataset", + required=True, + ) + parser.add_argument( + "--max-tokens", + type=int, + help="SERVER_BENCH_MAX_CONTEXT: maximum context size of the completions request to filter out in the dataset: prompt + predicted tokens", + required=True, + ) + parser.add_argument( + "--hf-repo", type=str, help="Hugging Face model repository", required=True + ) + parser.add_argument( + "--hf-file", type=str, help="Hugging Face model file", required=True + ) + parser.add_argument( + "-ngl", + "--n-gpu-layers", + type=int, + help="layers to the GPU for computation", + required=True, + ) + parser.add_argument( + "--ctx-size", type=int, help="Set the size of the prompt context", required=True + ) + parser.add_argument( + "--parallel", + type=int, + help="Set the number of slots for process requests", + required=True, + ) + parser.add_argument( + "--batch-size", + type=int, + help="Set the batch size for prompt processing", + required=True, + ) + parser.add_argument( + "--ubatch-size", type=int, help="physical maximum batch size", required=True + ) parser.add_argument("--scenario", type=str, help="Scenario to run", required=True) parser.add_argument("--duration", type=str, help="Bench scenario", required=True) @@ -66,20 +107,23 @@ def main(args_in: list[str] | None = None) -> None: try: start_benchmark(args) - with open("results.github.env", 'w') as github_env: + with open("results.github.env", "w") as github_env: # parse output - with open('k6-results.json', 'r') as bench_results: + with open("k6-results.json", "r") as bench_results: # Load JSON data from file data = json.load(bench_results) - for metric_name in data['metrics']: - for metric_metric in data['metrics'][metric_name]: - value = data['metrics'][metric_name][metric_metric] + for metric_name in data["metrics"]: + for metric_metric in data["metrics"][metric_name]: + value = data["metrics"][metric_name][metric_metric] if isinstance(value, float) or isinstance(value, int): value = round(value, 2) - data['metrics'][metric_name][metric_metric]=value + data["metrics"][metric_name][metric_metric] = value github_env.write( - f"{escape_metric_name(metric_name)}_{escape_metric_name(metric_metric)}={value}\n") - iterations = data['root_group']['checks']['success completion']['passes'] + f"{escape_metric_name(metric_name)}_{escape_metric_name(metric_metric)}={value}\n" + ) + iterations = data["root_group"]["checks"]["success completion"][ + "passes" + ] except Exception: print("bench: error :") @@ -89,7 +133,7 @@ def main(args_in: list[str] | None = None) -> None: if server_process: try: print(f"bench: shutting down server pid={server_process.pid} ...") - if os.name == 'nt': + if os.name == "nt": interrupt = signal.CTRL_C_EVENT else: interrupt = signal.SIGINT @@ -97,55 +141,80 @@ def main(args_in: list[str] | None = None) -> None: server_process.wait(0.5) except subprocess.TimeoutExpired: - print(f"server still alive after 500ms, force-killing pid={server_process.pid} ...") + print( + f"server still alive after 500ms, force-killing pid={server_process.pid} ..." + ) server_process.kill() # SIGKILL server_process.wait() while is_server_listening(args.host, args.port): time.sleep(0.1) - title = (f"llama.cpp {args.name} on {args.runner_label}\n " - f"duration={args.duration} {iterations} iterations") - xlabel = (f"{args.hf_repo}/{args.hf_file}\n" - f"parallel={args.parallel} ctx-size={args.ctx_size} ngl={args.n_gpu_layers} batch-size={args.batch_size} ubatch-size={args.ubatch_size} pp={args.max_prompt_tokens} pp+tg={args.max_tokens}\n" - f"branch={args.branch} commit={args.commit}") + title = ( + f"llama.cpp {args.name} on {args.runner_label}\n " + f"duration={args.duration} {iterations} iterations" + ) + xlabel = ( + f"{args.hf_repo}/{args.hf_file}\n" + f"parallel={args.parallel} ctx-size={args.ctx_size} ngl={args.n_gpu_layers} batch-size={args.batch_size} ubatch-size={args.ubatch_size} pp={args.max_prompt_tokens} pp+tg={args.max_tokens}\n" + f"branch={args.branch} commit={args.commit}" + ) # Prometheus end_time = time.time() prometheus_metrics = {} if is_server_listening("0.0.0.0", 9090): - metrics = ['prompt_tokens_seconds', 'predicted_tokens_seconds', - 'kv_cache_usage_ratio', 'requests_processing', 'requests_deferred'] + metrics = [ + "prompt_tokens_seconds", + "predicted_tokens_seconds", + "kv_cache_usage_ratio", + "requests_processing", + "requests_deferred", + ] for metric in metrics: - resp = requests.get(f"http://localhost:9090/api/v1/query_range", - params={'query': 'llamacpp:' + metric, 'start': start_time, 'end': end_time, 'step': 2}) - - with open(f"{metric}.json", 'w') as metric_json: + resp = requests.get( + f"http://localhost:9090/api/v1/query_range", + params={ + "query": "llamacpp:" + metric, + "start": start_time, + "end": end_time, + "step": 2, + }, + ) + + with open(f"{metric}.json", "w") as metric_json: metric_json.write(resp.text) if resp.status_code != 200: - print(f"bench: unable to extract prometheus metric {metric}: {resp.text}") + print( + f"bench: unable to extract prometheus metric {metric}: {resp.text}" + ) else: metric_data = resp.json() - values = metric_data['data']['result'][0]['values'] + values = metric_data["data"]["result"][0]["values"] timestamps, metric_values = zip(*values) metric_values = [float(value) for value in metric_values] prometheus_metrics[metric] = metric_values - timestamps_dt = [str(datetime.fromtimestamp(int(ts))) for ts in timestamps] + timestamps_dt = [ + str(datetime.fromtimestamp(int(ts))) for ts in timestamps + ] plt.figure(figsize=(16, 10), dpi=80) plt.plot(timestamps_dt, metric_values, label=metric) - plt.xticks(rotation=0, fontsize=14, horizontalalignment='center', alpha=.7) - plt.yticks(fontsize=12, alpha=.7) + plt.xticks( + rotation=0, fontsize=14, horizontalalignment="center", alpha=0.7 + ) + plt.yticks(fontsize=12, alpha=0.7) ylabel = f"llamacpp:{metric}" - plt.title(title, - fontsize=14, wrap=True) - plt.grid(axis='both', alpha=.3) + plt.title(title, fontsize=14, wrap=True) + plt.grid(axis="both", alpha=0.3) plt.ylabel(ylabel, fontsize=22) plt.xlabel(xlabel, fontsize=14, wrap=True) plt.gca().xaxis.set_major_locator(matplotlib.dates.MinuteLocator()) - plt.gca().xaxis.set_major_formatter(matplotlib.dates.DateFormatter("%Y-%m-%d %H:%M:%S")) + plt.gca().xaxis.set_major_formatter( + matplotlib.dates.DateFormatter("%Y-%m-%d %H:%M:%S") + ) plt.gcf().autofmt_xdate() # Remove borders @@ -155,13 +224,12 @@ def main(args_in: list[str] | None = None) -> None: plt.gca().spines["left"].set_alpha(0.3) # Save the plot as a jpg image - plt.savefig(f'{metric}.jpg', dpi=60) + plt.savefig(f"{metric}.jpg", dpi=60) plt.close() # Mermaid format in case images upload failed - with open(f"{metric}.mermaid", 'w') as mermaid_f: - mermaid = ( - f"""--- + with open(f"{metric}.mermaid", "w") as mermaid_f: + mermaid = f"""--- config: xyChart: titleFontSize: 12 @@ -176,56 +244,73 @@ def main(args_in: list[str] | None = None) -> None: y-axis "llamacpp:{metric}" x-axis "llamacpp:{metric}" {int(min(timestamps))} --> {int(max(timestamps))} line [{', '.join([str(round(float(value), 2)) for value in metric_values])}] - """) + """ mermaid_f.write(mermaid) # 140 chars max for commit status description bench_results = { "i": iterations, "req": { - "p95": round(data['metrics']["http_req_duration"]["p(95)"], 2), - "avg": round(data['metrics']["http_req_duration"]["avg"], 2), + "p95": round(data["metrics"]["http_req_duration"]["p(95)"], 2), + "avg": round(data["metrics"]["http_req_duration"]["avg"], 2), }, "pp": { - "p95": round(data['metrics']["llamacpp_prompt_processing_second"]["p(95)"], 2), - "avg": round(data['metrics']["llamacpp_prompt_processing_second"]["avg"], 2), - "0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2) if 'prompt_tokens_seconds' in prometheus_metrics else 0, + "p95": round( + data["metrics"]["llamacpp_prompt_processing_second"]["p(95)"], 2 + ), + "avg": round( + data["metrics"]["llamacpp_prompt_processing_second"]["avg"], 2 + ), + "0": ( + round(mean(prometheus_metrics["prompt_tokens_seconds"]), 2) + if "prompt_tokens_seconds" in prometheus_metrics + else 0 + ), }, "tg": { - "p95": round(data['metrics']["llamacpp_tokens_second"]["p(95)"], 2), - "avg": round(data['metrics']["llamacpp_tokens_second"]["avg"], 2), - "0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2) if 'predicted_tokens_seconds' in prometheus_metrics else 0, + "p95": round(data["metrics"]["llamacpp_tokens_second"]["p(95)"], 2), + "avg": round(data["metrics"]["llamacpp_tokens_second"]["avg"], 2), + "0": ( + round(mean(prometheus_metrics["predicted_tokens_seconds"]), 2) + if "predicted_tokens_seconds" in prometheus_metrics + else 0 + ), }, } - with open("results.github.env", 'a') as github_env: - github_env.write(f"BENCH_RESULTS={json.dumps(bench_results, indent=None, separators=(',', ':') )}\n") + with open("results.github.env", "a") as github_env: + github_env.write( + f"BENCH_RESULTS={json.dumps(bench_results, indent=None, separators=(',', ':') )}\n" + ) github_env.write(f"BENCH_ITERATIONS={iterations}\n") - title = title.replace('\n', ' ') - xlabel = xlabel.replace('\n', ' ') + title = title.replace("\n", " ") + xlabel = xlabel.replace("\n", " ") github_env.write(f"BENCH_GRAPH_TITLE={title}\n") github_env.write(f"BENCH_GRAPH_XLABEL={xlabel}\n") def start_benchmark(args): - k6_path = './k6' - if 'BENCH_K6_BIN_PATH' in os.environ: - k6_path = os.environ['BENCH_K6_BIN_PATH'] + k6_path = "./k6" + if "BENCH_K6_BIN_PATH" in os.environ: + k6_path = os.environ["BENCH_K6_BIN_PATH"] k6_args = [ - 'run', args.scenario, - '--no-color', - '--no-connection-reuse', - '--no-vu-connection-reuse', + "run", + args.scenario, + "--no-color", + "--no-connection-reuse", + "--no-vu-connection-reuse", ] - k6_args.extend(['--duration', args.duration]) - k6_args.extend(['--iterations', args.n_prompts]) - k6_args.extend(['--vus', args.parallel]) - k6_args.extend(['--summary-export', 'k6-results.json']) - k6_args.extend(['--out', 'csv=k6-results.csv']) + k6_args.extend(["--duration", args.duration]) + k6_args.extend(["--iterations", args.n_prompts]) + k6_args.extend(["--vus", args.parallel]) + k6_args.extend(["--summary-export", "k6-results.json"]) + k6_args.extend(["--out", "csv=k6-results.csv"]) args = f"SERVER_BENCH_N_PROMPTS={args.n_prompts} SERVER_BENCH_MAX_PROMPT_TOKENS={args.max_prompt_tokens} SERVER_BENCH_MAX_CONTEXT={args.max_tokens} " - args = args + ' '.join([str(arg) for arg in [k6_path, *k6_args]]) + args = args + " ".join([str(arg) for arg in [k6_path, *k6_args]]) print(f"bench: starting k6 with: {args}") - k6_completed = subprocess.run(args, shell=True, stdout=sys.stdout, stderr=sys.stderr) + k6_completed = subprocess.run( + args, shell=True, stdout=sys.stdout, stderr=sys.stderr + ) if k6_completed.returncode != 0: raise Exception("bench: unable to run k6") @@ -235,7 +320,7 @@ def start_server(args): attempts = 0 max_attempts = 600 - if 'GITHUB_ACTIONS' in os.environ: + if "GITHUB_ACTIONS" in os.environ: max_attempts *= 2 while not is_server_listening(args.host, args.port): @@ -259,42 +344,45 @@ def start_server(args): def start_server_background(args): # Start the server - server_path = '../../../build/bin/llama-server' - if 'LLAMA_SERVER_BIN_PATH' in os.environ: - server_path = os.environ['LLAMA_SERVER_BIN_PATH'] + server_path = "../../../build/bin/llama-server" + if "LLAMA_SERVER_BIN_PATH" in os.environ: + server_path = os.environ["LLAMA_SERVER_BIN_PATH"] server_args = [ - '--host', args.host, - '--port', args.port, + "--host", + args.host, + "--port", + args.port, ] - server_args.extend(['--hf-repo', args.hf_repo]) - server_args.extend(['--hf-file', args.hf_file]) - server_args.extend(['--n-gpu-layers', args.n_gpu_layers]) - server_args.extend(['--ctx-size', args.ctx_size]) - server_args.extend(['--parallel', args.parallel]) - server_args.extend(['--batch-size', args.batch_size]) - server_args.extend(['--ubatch-size', args.ubatch_size]) - server_args.extend(['--n-predict', args.max_tokens * 2]) - server_args.extend(['--defrag-thold', "0.1"]) - server_args.append('--cont-batching') - server_args.append('--metrics') - server_args.append('--flash-attn') + server_args.extend(["--hf-repo", args.hf_repo]) + server_args.extend(["--hf-file", args.hf_file]) + server_args.extend(["--n-gpu-layers", args.n_gpu_layers]) + server_args.extend(["--ctx-size", args.ctx_size]) + server_args.extend(["--parallel", args.parallel]) + server_args.extend(["--batch-size", args.batch_size]) + server_args.extend(["--ubatch-size", args.ubatch_size]) + server_args.extend(["--n-predict", args.max_tokens * 2]) + server_args.extend(["--defrag-thold", "0.1"]) + server_args.append("--cont-batching") + server_args.append("--metrics") + server_args.append("--flash-attn") args = [str(arg) for arg in [server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}") - pkwargs = { - 'stdout': subprocess.PIPE, - 'stderr': subprocess.PIPE - } + pkwargs = {"stdout": subprocess.PIPE, "stderr": subprocess.PIPE} server_process = subprocess.Popen( - args, - **pkwargs) # pyright: ignore[reportArgumentType, reportCallIssue] + args, **pkwargs + ) # pyright: ignore[reportArgumentType, reportCallIssue] def server_log(in_stream, out_stream): - for line in iter(in_stream.readline, b''): - print(line.decode('utf-8'), end='', file=out_stream) + for line in iter(in_stream.readline, b""): + print(line.decode("utf-8"), end="", file=out_stream) - thread_stdout = threading.Thread(target=server_log, args=(server_process.stdout, sys.stdout)) + thread_stdout = threading.Thread( + target=server_log, args=(server_process.stdout, sys.stdout) + ) thread_stdout.start() - thread_stderr = threading.Thread(target=server_log, args=(server_process.stderr, sys.stderr)) + thread_stderr = threading.Thread( + target=server_log, args=(server_process.stderr, sys.stderr) + ) thread_stderr.start() return server_process @@ -316,8 +404,8 @@ def is_server_ready(server_fqdn, server_port): def escape_metric_name(metric_name): - return re.sub('[^A-Z0-9]', '_', metric_name.upper()) + return re.sub("[^A-Z0-9]", "_", metric_name.upper()) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/smallthinker/tools/server/tests/unit/test_basic.py b/smallthinker/tools/server/tests/unit/test_basic.py index 1485de8c..514f7154 100644 --- a/smallthinker/tools/server/tests/unit/test_basic.py +++ b/smallthinker/tools/server/tests/unit/test_basic.py @@ -47,7 +47,7 @@ def test_server_slots(): server.server_slots = False server.start() res = server.make_request("GET", "/slots") - assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED + assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED assert "error" in res.body server.stop() @@ -70,11 +70,15 @@ def test_load_split_model(): server.model_hf_file = "tinyllamas/split/stories15M-q8_0-00001-of-00003.gguf" server.model_alias = "tinyllama-split" server.start() - res = server.make_request("POST", "/completion", data={ - "n_predict": 16, - "prompt": "Hello", - "temperature": 0.0, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "n_predict": 16, + "prompt": "Hello", + "temperature": 0.0, + }, + ) assert res.status_code == 200 assert match_regex("(little|girl)+", res.body["content"]) diff --git a/smallthinker/tools/server/tests/unit/test_chat_completion.py b/smallthinker/tools/server/tests/unit/test_chat_completion.py index 1b5205f7..5fde22b4 100644 --- a/smallthinker/tools/server/tests/unit/test_chat_completion.py +++ b/smallthinker/tools/server/tests/unit/test_chat_completion.py @@ -4,6 +4,7 @@ server: ServerProcess + @pytest.fixture(autouse=True) def create_server(): global server @@ -15,38 +16,151 @@ def create_server(): [ (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None), (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None), - (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None), - (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None), - (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'), - (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", False, None), - ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length", True, None), - (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None), - (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None), - ] + ( + None, + "Book", + "What is the best book", + 8, + '(Suddenly)+|\\{ " Sarax.', + 77, + 8, + "length", + False, + None, + ), + ( + None, + "Book", + "What is the best book", + 8, + '(Suddenly)+|\\{ " Sarax.', + 77, + 8, + "length", + True, + None, + ), + ( + None, + "Book", + "What is the best book", + 8, + '(Suddenly)+|\\{ " Sarax.', + 77, + 8, + "length", + True, + "chatml", + ), + ( + None, + "Book", + "What is the best book", + 8, + "^ blue", + 23, + 8, + "length", + True, + "This is not a chat template, it is", + ), + ( + "codellama70b", + "You are a coding assistant.", + "Write the fibonacci function in c++.", + 128, + "(Aside|she|felter|alonger)+", + 104, + 64, + "length", + False, + None, + ), + ( + "codellama70b", + "You are a coding assistant.", + "Write the fibonacci function in c++.", + 128, + "(Aside|she|felter|alonger)+", + 104, + 64, + "length", + True, + None, + ), + ( + None, + "Book", + [ + {"type": "text", "text": "What is"}, + {"type": "text", "text": "the best book"}, + ], + 8, + "Whillicter", + 79, + 8, + "length", + False, + None, + ), + ( + None, + "Book", + [ + {"type": "text", "text": "What is"}, + {"type": "text", "text": "the best book"}, + ], + 8, + "Whillicter", + 79, + 8, + "length", + True, + None, + ), + ], ) -def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template): +def test_chat_completion( + model, + system_prompt, + user_prompt, + max_tokens, + re_content, + n_prompt, + n_predicted, + finish_reason, + jinja, + chat_template, +): global server server.jinja = jinja server.chat_template = chat_template server.start() - res = server.make_request("POST", "/chat/completions", data={ - "model": model, - "max_tokens": max_tokens, - "messages": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - }) + res = server.make_request( + "POST", + "/chat/completions", + data={ + "model": model, + "max_tokens": max_tokens, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + }, + ) assert res.status_code == 200 - assert "cmpl" in res.body["id"] # make sure the completion id has the expected format + assert ( + "cmpl" in res.body["id"] + ) # make sure the completion id has the expected format assert res.body["system_fingerprint"].startswith("b") assert res.body["model"] == model if model is not None else server.model_alias assert res.body["usage"]["prompt_tokens"] == n_prompt assert res.body["usage"]["completion_tokens"] == n_predicted choice = res.body["choices"][0] assert "assistant" == choice["message"]["role"] - assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}' + assert match_regex( + re_content, choice["message"]["content"] + ), f'Expected {re_content}, got {choice["message"]["content"]}' assert choice["finish_reason"] == finish_reason @@ -54,21 +168,41 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason", [ ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"), - ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"), - ] + ( + "You are a coding assistant.", + "Write the fibonacci function in c++.", + 128, + "(Aside|she|felter|alonger)+", + 104, + 64, + "length", + ), + ], ) -def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason): +def test_chat_completion_stream( + system_prompt, + user_prompt, + max_tokens, + re_content, + n_prompt, + n_predicted, + finish_reason, +): global server - server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL + server.model_alias = None # try using DEFAULT_OAICOMPAT_MODEL server.start() - res = server.make_stream_request("POST", "/chat/completions", data={ - "max_tokens": max_tokens, - "messages": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt}, - ], - "stream": True, - }) + res = server.make_stream_request( + "POST", + "/chat/completions", + data={ + "max_tokens": max_tokens, + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], + "stream": True, + }, + ) content = "" last_cmpl_id = None for i, data in enumerate(res): @@ -80,10 +214,14 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte else: assert "role" not in choice["delta"] assert data["system_fingerprint"].startswith("b") - assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future + assert ( + "gpt-3.5" in data["model"] + ) # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future if last_cmpl_id is None: last_cmpl_id = data["id"] - assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream + assert ( + last_cmpl_id == data["id"] + ) # make sure the completion id is the same for all events in the stream if choice["finish_reason"] in ["stop", "length"]: assert data["usage"]["prompt_tokens"] == n_prompt assert data["usage"]["completion_tokens"] == n_predicted @@ -92,13 +230,15 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte assert choice["finish_reason"] == finish_reason else: assert choice["finish_reason"] is None - content += choice["delta"]["content"] or '' + content += choice["delta"]["content"] or "" def test_chat_completion_with_openai_library(): global server server.start() - client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + client = OpenAI( + api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1" + ) res = client.chat.completions.create( model="gpt-3.5-turbo-instruct", messages=[ @@ -120,55 +260,86 @@ def test_chat_template(): server.chat_template = "llama3" server.debug = True # to get the "__verbose" object in the response server.start() - res = server.make_request("POST", "/chat/completions", data={ - "max_tokens": 8, - "messages": [ - {"role": "system", "content": "Book"}, - {"role": "user", "content": "What is the best book"}, - ] - }) + res = server.make_request( + "POST", + "/chat/completions", + data={ + "max_tokens": 8, + "messages": [ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ], + }, + ) assert res.status_code == 200 assert "__verbose" in res.body - assert res.body["__verbose"]["prompt"] == " <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + assert ( + res.body["__verbose"]["prompt"] + == " <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ) def test_apply_chat_template(): global server server.chat_template = "command-r" server.start() - res = server.make_request("POST", "/apply-template", data={ - "messages": [ - {"role": "system", "content": "You are a test."}, - {"role": "user", "content":"Hi there"}, - ] - }) + res = server.make_request( + "POST", + "/apply-template", + data={ + "messages": [ + {"role": "system", "content": "You are a test."}, + {"role": "user", "content": "Hi there"}, + ] + }, + ) assert res.status_code == 200 assert "prompt" in res.body - assert res.body["prompt"] == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + assert ( + res.body["prompt"] + == "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>You are a test.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>Hi there<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + ) -@pytest.mark.parametrize("response_format,n_predicted,re_content", [ - ({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""), - ({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"), - ({"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, 10, "\"foooooo\""), - ({"type": "json_object"}, 10, "(\\{|John)+"), - ({"type": "sound"}, 0, None), - # invalid response format (expected to fail) - ({"type": "json_object", "schema": 123}, 0, None), - ({"type": "json_object", "schema": {"type": 123}}, 0, None), - ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None), -]) -def test_completion_with_response_format(response_format: dict, n_predicted: int, re_content: str | None): +@pytest.mark.parametrize( + "response_format,n_predicted,re_content", + [ + ({"type": "json_object", "schema": {"const": "42"}}, 6, '"42"'), + ( + {"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, + 10, + "[ -3000 ]", + ), + ( + {"type": "json_schema", "json_schema": {"schema": {"const": "foooooo"}}}, + 10, + '"foooooo"', + ), + ({"type": "json_object"}, 10, "(\\{|John)+"), + ({"type": "sound"}, 0, None), + # invalid response format (expected to fail) + ({"type": "json_object", "schema": 123}, 0, None), + ({"type": "json_object", "schema": {"type": 123}}, 0, None), + ({"type": "json_object", "schema": {"type": "hiccup"}}, 0, None), + ], +) +def test_completion_with_response_format( + response_format: dict, n_predicted: int, re_content: str | None +): global server server.start() - res = server.make_request("POST", "/chat/completions", data={ - "max_tokens": n_predicted, - "messages": [ - {"role": "system", "content": "You are a coding assistant."}, - {"role": "user", "content": "Write an example"}, - ], - "response_format": response_format, - }) + res = server.make_request( + "POST", + "/chat/completions", + data={ + "max_tokens": n_predicted, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "response_format": response_format, + }, + ) if re_content is not None: assert res.status_code == 200 choice = res.body["choices"][0] @@ -178,63 +349,92 @@ def test_completion_with_response_format(response_format: dict, n_predicted: int assert "error" in res.body -@pytest.mark.parametrize("jinja,json_schema,n_predicted,re_content", [ - (False, {"const": "42"}, 6, "\"42\""), - (True, {"const": "42"}, 6, "\"42\""), -]) -def test_completion_with_json_schema(jinja: bool, json_schema: dict, n_predicted: int, re_content: str): +@pytest.mark.parametrize( + "jinja,json_schema,n_predicted,re_content", + [ + (False, {"const": "42"}, 6, '"42"'), + (True, {"const": "42"}, 6, '"42"'), + ], +) +def test_completion_with_json_schema( + jinja: bool, json_schema: dict, n_predicted: int, re_content: str +): global server server.jinja = jinja server.start() - res = server.make_request("POST", "/chat/completions", data={ - "max_tokens": n_predicted, - "messages": [ - {"role": "system", "content": "You are a coding assistant."}, - {"role": "user", "content": "Write an example"}, - ], - "json_schema": json_schema, - }) - assert res.status_code == 200, f'Expected 200, got {res.status_code}' + res = server.make_request( + "POST", + "/chat/completions", + data={ + "max_tokens": n_predicted, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "json_schema": json_schema, + }, + ) + assert res.status_code == 200, f"Expected 200, got {res.status_code}" choice = res.body["choices"][0] - assert match_regex(re_content, choice["message"]["content"]), f'Expected {re_content}, got {choice["message"]["content"]}' + assert match_regex( + re_content, choice["message"]["content"] + ), f'Expected {re_content}, got {choice["message"]["content"]}' -@pytest.mark.parametrize("jinja,grammar,n_predicted,re_content", [ - (False, 'root ::= "a"{5,5}', 6, "a{5,5}"), - (True, 'root ::= "a"{5,5}', 6, "a{5,5}"), -]) -def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re_content: str): +@pytest.mark.parametrize( + "jinja,grammar,n_predicted,re_content", + [ + (False, 'root ::= "a"{5,5}', 6, "a{5,5}"), + (True, 'root ::= "a"{5,5}', 6, "a{5,5}"), + ], +) +def test_completion_with_grammar( + jinja: bool, grammar: str, n_predicted: int, re_content: str +): global server server.jinja = jinja server.start() - res = server.make_request("POST", "/chat/completions", data={ - "max_tokens": n_predicted, - "messages": [ - {"role": "user", "content": "Does not matter what I say, does it?"}, - ], - "grammar": grammar, - }) + res = server.make_request( + "POST", + "/chat/completions", + data={ + "max_tokens": n_predicted, + "messages": [ + {"role": "user", "content": "Does not matter what I say, does it?"}, + ], + "grammar": grammar, + }, + ) assert res.status_code == 200, res.body choice = res.body["choices"][0] - assert match_regex(re_content, choice["message"]["content"]), choice["message"]["content"] + assert match_regex(re_content, choice["message"]["content"]), choice["message"][ + "content" + ] -@pytest.mark.parametrize("messages", [ - None, - "string", - [123], - [{}], - [{"role": 123}], - [{"role": "system", "content": 123}], - # [{"content": "hello"}], # TODO: should not be a valid case - [{"role": "system", "content": "test"}, {}], -]) +@pytest.mark.parametrize( + "messages", + [ + None, + "string", + [123], + [{}], + [{"role": 123}], + [{"role": "system", "content": 123}], + # [{"content": "hello"}], # TODO: should not be a valid case + [{"role": "system", "content": "test"}, {}], + ], +) def test_invalid_chat_completion_req(messages): global server server.start() - res = server.make_request("POST", "/chat/completions", data={ - "messages": messages, - }) + res = server.make_request( + "POST", + "/chat/completions", + data={ + "messages": messages, + }, + ) assert res.status_code == 400 or res.status_code == 500 assert "error" in res.body @@ -242,18 +442,22 @@ def test_invalid_chat_completion_req(messages): def test_chat_completion_with_timings_per_token(): global server server.start() - res = server.make_stream_request("POST", "/chat/completions", data={ - "max_tokens": 10, - "messages": [{"role": "user", "content": "test"}], - "stream": True, - "timings_per_token": True, - }) + res = server.make_stream_request( + "POST", + "/chat/completions", + data={ + "max_tokens": 10, + "messages": [{"role": "user", "content": "test"}], + "stream": True, + "timings_per_token": True, + }, + ) for i, data in enumerate(res): if i == 0: # Check first role message for stream=True assert data["choices"][0]["delta"]["content"] is None assert data["choices"][0]["delta"]["role"] == "assistant" - assert "timings" not in data, f'First event should not have timings: {data}' + assert "timings" not in data, f"First event should not have timings: {data}" else: assert "role" not in data["choices"][0]["delta"] assert "timings" in data @@ -266,7 +470,9 @@ def test_chat_completion_with_timings_per_token(): def test_logprobs(): global server server.start() - client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + client = OpenAI( + api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1" + ) res = client.chat.completions.create( model="gpt-3.5-turbo-instruct", temperature=0.0, @@ -279,7 +485,7 @@ def test_logprobs(): top_logprobs=10, ) output_text = res.choices[0].message.content - aggregated_text = '' + aggregated_text = "" assert res.choices[0].logprobs is not None assert res.choices[0].logprobs.content is not None for token in res.choices[0].logprobs.content: @@ -293,7 +499,9 @@ def test_logprobs(): def test_logprobs_stream(): global server server.start() - client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + client = OpenAI( + api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1" + ) res = client.chat.completions.create( model="gpt-3.5-turbo-instruct", temperature=0.0, @@ -306,8 +514,8 @@ def test_logprobs_stream(): top_logprobs=10, stream=True, ) - output_text = '' - aggregated_text = '' + output_text = "" + aggregated_text = "" for i, data in enumerate(res): choice = data.choices[0] if i == 0: diff --git a/smallthinker/tools/server/tests/unit/test_completion.py b/smallthinker/tools/server/tests/unit/test_completion.py index f6909e9a..e651cb3c 100644 --- a/smallthinker/tools/server/tests/unit/test_completion.py +++ b/smallthinker/tools/server/tests/unit/test_completion.py @@ -12,18 +12,42 @@ def create_server(): global server server = ServerPreset.tinyllama2() -@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [ - ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False), - ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True), -]) -def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool): + +@pytest.mark.parametrize( + "prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", + [ + ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False), + ( + "Write a joke about AI from a very long prompt which will not be truncated", + 256, + "(princesses|everyone|kids|Anna|forest)+", + 46, + 64, + False, + True, + ), + ], +) +def test_completion( + prompt: str, + n_predict: int, + re_content: str, + n_prompt: int, + n_predicted: int, + truncated: bool, + return_tokens: bool, +): global server server.start() - res = server.make_request("POST", "/completion", data={ - "n_predict": n_predict, - "prompt": prompt, - "return_tokens": return_tokens, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "n_predict": n_predict, + "prompt": prompt, + "return_tokens": return_tokens, + }, + ) assert res.status_code == 200 assert res.body["timings"]["prompt_n"] == n_prompt assert res.body["timings"]["predicted_n"] == n_predicted @@ -37,18 +61,39 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, assert res.body["tokens"] == [] -@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [ - ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False), - ("Write a joke about AI from a very long prompt which will not be truncated", 256, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False), -]) -def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool): +@pytest.mark.parametrize( + "prompt,n_predict,re_content,n_prompt,n_predicted,truncated", + [ + ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False), + ( + "Write a joke about AI from a very long prompt which will not be truncated", + 256, + "(princesses|everyone|kids|Anna|forest)+", + 46, + 64, + False, + ), + ], +) +def test_completion_stream( + prompt: str, + n_predict: int, + re_content: str, + n_prompt: int, + n_predicted: int, + truncated: bool, +): global server server.start() - res = server.make_stream_request("POST", "/completion", data={ - "n_predict": n_predict, - "prompt": prompt, - "stream": True, - }) + res = server.make_stream_request( + "POST", + "/completion", + data={ + "n_predict": n_predict, + "prompt": prompt, + "stream": True, + }, + ) content = "" for data in res: assert "stop" in data and type(data["stop"]) == bool @@ -60,7 +105,9 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp assert type(data["has_new_line"]) == bool assert "generation_settings" in data assert server.n_predict is not None - assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict) + assert data["generation_settings"]["n_predict"] == min( + n_predict, server.n_predict + ) assert data["generation_settings"]["seed"] == server.seed assert match_regex(re_content, content) else: @@ -72,15 +119,23 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp def test_completion_stream_vs_non_stream(): global server server.start() - res_stream = server.make_stream_request("POST", "/completion", data={ - "n_predict": 8, - "prompt": "I believe the meaning of life is", - "stream": True, - }) - res_non_stream = server.make_request("POST", "/completion", data={ - "n_predict": 8, - "prompt": "I believe the meaning of life is", - }) + res_stream = server.make_stream_request( + "POST", + "/completion", + data={ + "n_predict": 8, + "prompt": "I believe the meaning of life is", + "stream": True, + }, + ) + res_non_stream = server.make_request( + "POST", + "/completion", + data={ + "n_predict": 8, + "prompt": "I believe the meaning of life is", + }, + ) content_stream = "" for data in res_stream: content_stream += data["content"] @@ -90,7 +145,9 @@ def test_completion_stream_vs_non_stream(): def test_completion_with_openai_library(): global server server.start() - client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + client = OpenAI( + api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1" + ) res = client.completions.create( model="davinci-002", prompt="I believe the meaning of life is", @@ -105,14 +162,16 @@ def test_completion_with_openai_library(): def test_completion_stream_with_openai_library(): global server server.start() - client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + client = OpenAI( + api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1" + ) res = client.completions.create( model="davinci-002", prompt="I believe the meaning of life is", max_tokens=8, stream=True, ) - output_text = '' + output_text = "" for data in res: choice = data.choices[0] if choice.finish_reason is None: @@ -128,7 +187,9 @@ def test_completion_stream_with_openai_library_stops(): server.model_hf_repo = "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M" server.model_hf_file = None server.start() - client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") + client = OpenAI( + api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1" + ) res = client.completions.create( model="davinci-002", prompt="System: You are helpfull assistant.\nAssistant:\nHey! How could I help?\nUser:\nTell me a joke.\nAssistant:\n", @@ -136,13 +197,15 @@ def test_completion_stream_with_openai_library_stops(): max_tokens=200, stream=True, ) - output_text = '' + output_text = "" for data in res: choice = data.choices[0] if choice.finish_reason is None: assert choice.text is not None output_text += choice.text - assert match_regex("Sure, here's one for[\\s\\S]*", output_text), f'Unexpected output: {output_text}' + assert match_regex( + "Sure, here's one for[\\s\\S]*", output_text + ), f"Unexpected output: {output_text}" @pytest.mark.parametrize("n_slots", [1, 2]) @@ -152,12 +215,16 @@ def test_consistent_result_same_seed(n_slots: int): server.start() last_res = None for _ in range(4): - res = server.make_request("POST", "/completion", data={ - "prompt": "I believe the meaning of life is", - "seed": 42, - "temperature": 0.0, - "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": 0.0, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }, + ) if last_res is not None: assert res.body["content"] == last_res.body["content"] last_res = res @@ -170,18 +237,23 @@ def test_different_result_different_seed(n_slots: int): server.start() last_res = None for seed in range(4): - res = server.make_request("POST", "/completion", data={ - "prompt": "I believe the meaning of life is", - "seed": seed, - "temperature": 1.0, - "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "I believe the meaning of life is", + "seed": seed, + "temperature": 1.0, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }, + ) if last_res is not None: assert res.body["content"] != last_res.body["content"] last_res = res + # TODO figure why it don't work with temperature = 1 -# @pytest.mark.parametrize("temperature", [0.0, 1.0]) +# @pytest.mark.parametrize("temperature", [0.0, 1.0]) @pytest.mark.parametrize("n_batch", [16, 32]) @pytest.mark.parametrize("temperature", [0.0]) def test_consistent_result_different_batch_size(n_batch: int, temperature: float): @@ -190,12 +262,16 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float server.start() last_res = None for _ in range(4): - res = server.make_request("POST", "/completion", data={ - "prompt": "I believe the meaning of life is", - "seed": 42, - "temperature": temperature, - "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": temperature, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }, + ) if last_res is not None: assert res.body["content"] == last_res.body["content"] last_res = res @@ -205,30 +281,42 @@ def test_consistent_result_different_batch_size(n_batch: int, temperature: float def test_cache_vs_nocache_prompt(): global server server.start() - res_cache = server.make_request("POST", "/completion", data={ - "prompt": "I believe the meaning of life is", - "seed": 42, - "temperature": 1.0, - "cache_prompt": True, - }) - res_no_cache = server.make_request("POST", "/completion", data={ - "prompt": "I believe the meaning of life is", - "seed": 42, - "temperature": 1.0, - "cache_prompt": False, - }) + res_cache = server.make_request( + "POST", + "/completion", + data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": 1.0, + "cache_prompt": True, + }, + ) + res_no_cache = server.make_request( + "POST", + "/completion", + data={ + "prompt": "I believe the meaning of life is", + "seed": 42, + "temperature": 1.0, + "cache_prompt": False, + }, + ) assert res_cache.body["content"] == res_no_cache.body["content"] def test_nocache_long_input_prompt(): global server server.start() - res = server.make_request("POST", "/completion", data={ - "prompt": "I believe the meaning of life is"*32, - "seed": 42, - "temperature": 1.0, - "cache_prompt": False, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "I believe the meaning of life is" * 32, + "seed": 42, + "temperature": 1.0, + "cache_prompt": False, + }, + ) assert res.status_code == 200 @@ -237,53 +325,76 @@ def test_completion_with_tokens_input(): server.temperature = 0.0 server.start() prompt_str = "I believe the meaning of life is" - res = server.make_request("POST", "/tokenize", data={ - "content": prompt_str, - "add_special": True, - }) + res = server.make_request( + "POST", + "/tokenize", + data={ + "content": prompt_str, + "add_special": True, + }, + ) assert res.status_code == 200 tokens = res.body["tokens"] # single completion - res = server.make_request("POST", "/completion", data={ - "prompt": tokens, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": tokens, + }, + ) assert res.status_code == 200 assert type(res.body["content"]) == str # batch completion - res = server.make_request("POST", "/completion", data={ - "prompt": [tokens, tokens], - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": [tokens, tokens], + }, + ) assert res.status_code == 200 assert type(res.body) == list assert len(res.body) == 2 assert res.body[0]["content"] == res.body[1]["content"] # mixed string and tokens - res = server.make_request("POST", "/completion", data={ - "prompt": [tokens, prompt_str], - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": [tokens, prompt_str], + }, + ) assert res.status_code == 200 assert type(res.body) == list assert len(res.body) == 2 assert res.body[0]["content"] == res.body[1]["content"] # mixed string and tokens in one sequence - res = server.make_request("POST", "/completion", data={ - "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str], - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": [1, 2, 3, 4, 5, 6, prompt_str, 7, 8, 9, 10, prompt_str], + }, + ) assert res.status_code == 200 assert type(res.body["content"]) == str -@pytest.mark.parametrize("n_slots,n_requests", [ - (1, 3), - (2, 2), - (2, 4), - (4, 2), # some slots must be idle - (4, 6), -]) +@pytest.mark.parametrize( + "n_slots,n_requests", + [ + (1, 3), + (2, 2), + (2, 4), + (4, 2), # some slots must be idle + (4, 6), + ], +) def test_completion_parallel_slots(n_slots: int, n_requests: int): global server server.n_slots = n_slots @@ -298,6 +409,7 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int): ("Write another very long music lyrics.", "(friends|step|sky)+"), ("Write a very long joke.", "(cat|Whiskers)+"), ] + def check_slots_status(): should_all_slots_busy = n_requests >= n_slots time.sleep(0.1) @@ -311,11 +423,20 @@ def check_slots_status(): tasks = [] for i in range(n_requests): prompt, re_content = PROMPTS[i % len(PROMPTS)] - tasks.append((server.make_request, ("POST", "/completion", { - "prompt": prompt, - "seed": 42, - "temperature": 1.0, - }))) + tasks.append( + ( + server.make_request, + ( + "POST", + "/completion", + { + "prompt": prompt, + "seed": 42, + "temperature": 1.0, + }, + ), + ) + ) tasks.append((check_slots_status, ())) results = parallel_function_calls(tasks) @@ -334,7 +455,11 @@ def check_slots_status(): "prompt,n_predict,response_fields", [ ("I believe the meaning of life is", 8, []), - ("I believe the meaning of life is", 32, ["content", "generation_settings/n_predict", "prompt"]), + ( + "I believe the meaning of life is", + 32, + ["content", "generation_settings/n_predict", "prompt"], + ), ], ) def test_completion_response_fields( @@ -367,12 +492,16 @@ def test_completion_response_fields( def test_n_probs(): global server server.start() - res = server.make_request("POST", "/completion", data={ - "prompt": "I believe the meaning of life is", - "n_probs": 10, - "temperature": 0.0, - "n_predict": 5, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "I believe the meaning of life is", + "n_probs": 10, + "temperature": 0.0, + "n_predict": 5, + }, + ) assert res.status_code == 200 assert "completion_probabilities" in res.body assert len(res.body["completion_probabilities"]) == 5 @@ -392,13 +521,17 @@ def test_n_probs(): def test_n_probs_stream(): global server server.start() - res = server.make_stream_request("POST", "/completion", data={ - "prompt": "I believe the meaning of life is", - "n_probs": 10, - "temperature": 0.0, - "n_predict": 5, - "stream": True, - }) + res = server.make_stream_request( + "POST", + "/completion", + data={ + "prompt": "I believe the meaning of life is", + "n_probs": 10, + "temperature": 0.0, + "n_predict": 5, + "stream": True, + }, + ) for data in res: if data["stop"] == False: assert "completion_probabilities" in data @@ -419,13 +552,17 @@ def test_n_probs_stream(): def test_n_probs_post_sampling(): global server server.start() - res = server.make_request("POST", "/completion", data={ - "prompt": "I believe the meaning of life is", - "n_probs": 10, - "temperature": 0.0, - "n_predict": 5, - "post_sampling_probs": True, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "I believe the meaning of life is", + "n_probs": 10, + "temperature": 0.0, + "n_predict": 5, + "post_sampling_probs": True, + }, + ) assert res.status_code == 200 assert "completion_probabilities" in res.body assert len(res.body["completion_probabilities"]) == 5 @@ -453,12 +590,17 @@ def test_cancel_request(): server.start() # send a request that will take a long time, but cancel it before it finishes try: - server.make_request("POST", "/completion", data={ - "prompt": "I believe the meaning of life is", - }, timeout=0.1) + server.make_request( + "POST", + "/completion", + data={ + "prompt": "I believe the meaning of life is", + }, + timeout=0.1, + ) except requests.exceptions.ReadTimeout: - pass # expected + pass # expected # make sure the slot is free - time.sleep(1) # wait for HTTP_POLLING_SECONDS + time.sleep(1) # wait for HTTP_POLLING_SECONDS res = server.make_request("GET", "/slots") assert res.body[0]["is_processing"] == False diff --git a/smallthinker/tools/server/tests/unit/test_ctx_shift.py b/smallthinker/tools/server/tests/unit/test_ctx_shift.py index 2431ac70..be853bc6 100644 --- a/smallthinker/tools/server/tests/unit/test_ctx_shift.py +++ b/smallthinker/tools/server/tests/unit/test_ctx_shift.py @@ -11,6 +11,7 @@ Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. """.strip() + @pytest.fixture(scope="module", autouse=True) def create_server(): global server @@ -26,29 +27,42 @@ def test_ctx_shift_enabled(): # 64 tokens are generated thanks to shifting the context when it gets full global server server.start() - res = server.make_request("POST", "/completion", data={ - "n_predict": 64, - "prompt": LONG_TEXT, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "n_predict": 64, + "prompt": LONG_TEXT, + }, + ) assert res.status_code == 200 assert res.body["timings"]["prompt_n"] == 109 assert res.body["timings"]["predicted_n"] == 64 assert res.body["truncated"] is True -@pytest.mark.parametrize("n_predict,n_token_output,truncated", [ - (64, 64, False), - (-1, 120, True), -]) -def test_ctx_shift_disabled_short_prompt(n_predict: int, n_token_output: int, truncated: bool): +@pytest.mark.parametrize( + "n_predict,n_token_output,truncated", + [ + (64, 64, False), + (-1, 120, True), + ], +) +def test_ctx_shift_disabled_short_prompt( + n_predict: int, n_token_output: int, truncated: bool +): global server server.disable_ctx_shift = True server.n_predict = -1 server.start() - res = server.make_request("POST", "/completion", data={ - "n_predict": n_predict, - "prompt": "Hi how are you", - }) + res = server.make_request( + "POST", + "/completion", + data={ + "n_predict": n_predict, + "prompt": "Hi how are you", + }, + ) assert res.status_code == 200 assert res.body["timings"]["predicted_n"] == n_token_output assert res.body["truncated"] == truncated @@ -58,23 +72,32 @@ def test_ctx_shift_disabled_long_prompt(): global server server.disable_ctx_shift = True server.start() - res = server.make_request("POST", "/completion", data={ - "n_predict": 64, - "prompt": LONG_TEXT, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "n_predict": 64, + "prompt": LONG_TEXT, + }, + ) assert res.status_code != 200 assert "error" in res.body assert "exceeds the available context size" in res.body["error"]["message"] + def test_ctx_shift_disabled_stream(): global server server.disable_ctx_shift = True server.start() - res = server.make_stream_request("POST", "/v1/completions", data={ - "n_predict": 256, - "prompt": "Once", - "stream": True, - }) + res = server.make_stream_request( + "POST", + "/v1/completions", + data={ + "n_predict": 256, + "prompt": "Once", + "stream": True, + }, + ) content = "" for data in res: choice = data["choices"][0] diff --git a/smallthinker/tools/server/tests/unit/test_embedding.py b/smallthinker/tools/server/tests/unit/test_embedding.py index 0feb452c..4ff8c874 100644 --- a/smallthinker/tools/server/tests/unit/test_embedding.py +++ b/smallthinker/tools/server/tests/unit/test_embedding.py @@ -8,6 +8,7 @@ EPSILON = 1e-3 + @pytest.fixture(scope="module", autouse=True) def create_server(): global server @@ -16,57 +17,69 @@ def create_server(): def test_embedding_single(): global server - server.pooling = 'last' + server.pooling = "last" server.start() - res = server.make_request("POST", "/v1/embeddings", data={ - "input": "I believe the meaning of life is", - }) + res = server.make_request( + "POST", + "/v1/embeddings", + data={ + "input": "I believe the meaning of life is", + }, + ) assert res.status_code == 200 - assert len(res.body['data']) == 1 - assert 'embedding' in res.body['data'][0] - assert len(res.body['data'][0]['embedding']) > 1 + assert len(res.body["data"]) == 1 + assert "embedding" in res.body["data"][0] + assert len(res.body["data"][0]["embedding"]) > 1 # make sure embedding vector is normalized - assert abs(sum([x ** 2 for x in res.body['data'][0]['embedding']]) - 1) < EPSILON + assert abs(sum([x**2 for x in res.body["data"][0]["embedding"]]) - 1) < EPSILON def test_embedding_multiple(): global server - server.pooling = 'last' + server.pooling = "last" server.start() - res = server.make_request("POST", "/v1/embeddings", data={ - "input": [ - "I believe the meaning of life is", - "Write a joke about AI from a very long prompt which will not be truncated", - "This is a test", - "This is another test", - ], - }) + res = server.make_request( + "POST", + "/v1/embeddings", + data={ + "input": [ + "I believe the meaning of life is", + "Write a joke about AI from a very long prompt which will not be truncated", + "This is a test", + "This is another test", + ], + }, + ) assert res.status_code == 200 - assert len(res.body['data']) == 4 - for d in res.body['data']: - assert 'embedding' in d - assert len(d['embedding']) > 1 + assert len(res.body["data"]) == 4 + for d in res.body["data"]: + assert "embedding" in d + assert len(d["embedding"]) > 1 def test_embedding_multiple_with_fa(): server = ServerPreset.bert_bge_small_with_fa() - server.pooling = 'last' + server.pooling = "last" server.start() # one of these should trigger the FA branch (i.e. context size % 256 == 0) - res = server.make_request("POST", "/v1/embeddings", data={ - "input": [ - "a "*253, - "b "*254, - "c "*255, - "d "*256, - ], - }) + res = server.make_request( + "POST", + "/v1/embeddings", + data={ + "input": [ + "a " * 253, + "b " * 254, + "c " * 255, + "d " * 256, + ], + }, + ) assert res.status_code == 200 - assert len(res.body['data']) == 4 - for d in res.body['data']: - assert 'embedding' in d - assert len(d['embedding']) > 1 + assert len(res.body["data"]) == 4 + for d in res.body["data"]: + assert "embedding" in d + assert len(d["embedding"]) > 1 @pytest.mark.parametrize( @@ -83,47 +96,55 @@ def test_embedding_multiple_with_fa(): (["string1", [12, 34, 56]], True), ([[12, 34, 56], [12, 34, 56]], True), ([[12, 34, 56], [12, "string", 34, 56]], True), - ] + ], ) def test_embedding_mixed_input(input, is_multi_prompt: bool): global server server.start() res = server.make_request("POST", "/v1/embeddings", data={"input": input}) assert res.status_code == 200 - data = res.body['data'] + data = res.body["data"] if is_multi_prompt: assert len(data) == len(input) for d in data: - assert 'embedding' in d - assert len(d['embedding']) > 1 + assert "embedding" in d + assert len(d["embedding"]) > 1 else: - assert 'embedding' in data[0] - assert len(data[0]['embedding']) > 1 + assert "embedding" in data[0] + assert len(data[0]["embedding"]) > 1 def test_embedding_pooling_none(): global server - server.pooling = 'none' + server.pooling = "none" server.start() - res = server.make_request("POST", "/embeddings", data={ - "input": "hello hello hello", - }) + res = server.make_request( + "POST", + "/embeddings", + data={ + "input": "hello hello hello", + }, + ) assert res.status_code == 200 - assert 'embedding' in res.body[0] - assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special + assert "embedding" in res.body[0] + assert len(res.body[0]["embedding"]) == 5 # 3 text tokens + 2 special # make sure embedding vector is not normalized - for x in res.body[0]['embedding']: - assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON + for x in res.body[0]["embedding"]: + assert abs(sum([x**2 for x in x]) - 1) > EPSILON def test_embedding_pooling_none_oai(): global server - server.pooling = 'none' + server.pooling = "none" server.start() - res = server.make_request("POST", "/v1/embeddings", data={ - "input": "hello hello hello", - }) + res = server.make_request( + "POST", + "/v1/embeddings", + data={ + "input": "hello hello hello", + }, + ) # /v1/embeddings does not support pooling type 'none' assert res.status_code == 400 @@ -132,25 +153,34 @@ def test_embedding_pooling_none_oai(): def test_embedding_openai_library_single(): global server - server.pooling = 'last' + server.pooling = "last" server.start() - client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") - res = client.embeddings.create(model="text-embedding-3-small", input="I believe the meaning of life is") + client = OpenAI( + api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1" + ) + res = client.embeddings.create( + model="text-embedding-3-small", input="I believe the meaning of life is" + ) assert len(res.data) == 1 assert len(res.data[0].embedding) > 1 def test_embedding_openai_library_multiple(): global server - server.pooling = 'last' + server.pooling = "last" server.start() - client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1") - res = client.embeddings.create(model="text-embedding-3-small", input=[ - "I believe the meaning of life is", - "Write a joke about AI from a very long prompt which will not be truncated", - "This is a test", - "This is another test", - ]) + client = OpenAI( + api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1" + ) + res = client.embeddings.create( + model="text-embedding-3-small", + input=[ + "I believe the meaning of life is", + "Write a joke about AI from a very long prompt which will not be truncated", + "This is a test", + "This is another test", + ], + ) assert len(res.data) == 4 for d in res.data: assert len(d.embedding) > 1 @@ -158,32 +188,40 @@ def test_embedding_openai_library_multiple(): def test_embedding_error_prompt_too_long(): global server - server.pooling = 'last' + server.pooling = "last" server.start() - res = server.make_request("POST", "/v1/embeddings", data={ - "input": "This is a test " * 512, - }) + res = server.make_request( + "POST", + "/v1/embeddings", + data={ + "input": "This is a test " * 512, + }, + ) assert res.status_code != 200 assert "too large" in res.body["error"]["message"] def test_same_prompt_give_same_result(): - server.pooling = 'last' + server.pooling = "last" server.start() - res = server.make_request("POST", "/v1/embeddings", data={ - "input": [ - "I believe the meaning of life is", - "I believe the meaning of life is", - "I believe the meaning of life is", - "I believe the meaning of life is", - "I believe the meaning of life is", - ], - }) + res = server.make_request( + "POST", + "/v1/embeddings", + data={ + "input": [ + "I believe the meaning of life is", + "I believe the meaning of life is", + "I believe the meaning of life is", + "I believe the meaning of life is", + "I believe the meaning of life is", + ], + }, + ) assert res.status_code == 200 - assert len(res.body['data']) == 5 - for i in range(1, len(res.body['data'])): - v0 = res.body['data'][0]['embedding'] - vi = res.body['data'][i]['embedding'] + assert len(res.body["data"]) == 5 + for i in range(1, len(res.body["data"])): + v0 = res.body["data"][0]["embedding"] + vi = res.body["data"][i]["embedding"] for x, y in zip(v0, vi): assert abs(x - y) < EPSILON @@ -193,29 +231,33 @@ def test_same_prompt_give_same_result(): [ ("I believe the meaning of life is", 9), ("This is a test", 6), - ] + ], ) def test_embedding_usage_single(content, n_tokens): global server server.start() res = server.make_request("POST", "/v1/embeddings", data={"input": content}) assert res.status_code == 200 - assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] - assert res.body['usage']['prompt_tokens'] == n_tokens + assert res.body["usage"]["prompt_tokens"] == res.body["usage"]["total_tokens"] + assert res.body["usage"]["prompt_tokens"] == n_tokens def test_embedding_usage_multiple(): global server server.start() - res = server.make_request("POST", "/v1/embeddings", data={ - "input": [ - "I believe the meaning of life is", - "I believe the meaning of life is", - ], - }) + res = server.make_request( + "POST", + "/v1/embeddings", + data={ + "input": [ + "I believe the meaning of life is", + "I believe the meaning of life is", + ], + }, + ) assert res.status_code == 200 - assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] - assert res.body['usage']['prompt_tokens'] == 2 * 9 + assert res.body["usage"]["prompt_tokens"] == res.body["usage"]["total_tokens"] + assert res.body["usage"]["prompt_tokens"] == 2 * 9 def test_embedding_openai_library_base64(): @@ -223,17 +265,16 @@ def test_embedding_openai_library_base64(): test_input = "Test base64 embedding output" # get embedding in default format - res = server.make_request("POST", "/v1/embeddings", data={ - "input": test_input - }) + res = server.make_request("POST", "/v1/embeddings", data={"input": test_input}) assert res.status_code == 200 vec0 = res.body["data"][0]["embedding"] # get embedding in base64 format - res = server.make_request("POST", "/v1/embeddings", data={ - "input": test_input, - "encoding_format": "base64" - }) + res = server.make_request( + "POST", + "/v1/embeddings", + data={"input": test_input, "encoding_format": "base64"}, + ) assert res.status_code == 200 assert "data" in res.body @@ -247,7 +288,7 @@ def test_embedding_openai_library_base64(): decoded = base64.b64decode(embedding_data["embedding"]) # Verify decoded data can be converted back to float array float_count = len(decoded) // 4 # 4 bytes per float - floats = struct.unpack(f'{float_count}f', decoded) + floats = struct.unpack(f"{float_count}f", decoded) assert len(floats) > 0 assert all(isinstance(x, float) for x in floats) assert len(floats) == len(vec0) diff --git a/smallthinker/tools/server/tests/unit/test_infill.py b/smallthinker/tools/server/tests/unit/test_infill.py index 10554db0..1297aab3 100644 --- a/smallthinker/tools/server/tests/unit/test_infill.py +++ b/smallthinker/tools/server/tests/unit/test_infill.py @@ -3,6 +3,7 @@ server = ServerPreset.tinyllama_infill() + @pytest.fixture(scope="module", autouse=True) def create_server(): global server @@ -12,11 +13,15 @@ def create_server(): def test_infill_without_input_extra(): global server server.start() - res = server.make_request("POST", "/infill", data={ - "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", - "prompt": " int n_threads = llama_", - "input_suffix": "}\n", - }) + res = server.make_request( + "POST", + "/infill", + data={ + "input_prefix": '#include \n#include "llama.h"\n\nint main() {\n', + "prompt": " int n_threads = llama_", + "input_suffix": "}\n", + }, + ) assert res.status_code == 200 assert match_regex("(Ann|small|shiny|Daddy)+", res.body["content"]) @@ -24,35 +29,48 @@ def test_infill_without_input_extra(): def test_infill_with_input_extra(): global server server.start() - res = server.make_request("POST", "/infill", data={ - "input_extra": [{ - "filename": "llama.h", - "text": "LLAMA_API int32_t llama_n_threads();\n" - }], - "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", - "prompt": " int n_threads = llama_", - "input_suffix": "}\n", - }) + res = server.make_request( + "POST", + "/infill", + data={ + "input_extra": [ + { + "filename": "llama.h", + "text": "LLAMA_API int32_t llama_n_threads();\n", + } + ], + "input_prefix": '#include \n#include "llama.h"\n\nint main() {\n', + "prompt": " int n_threads = llama_", + "input_suffix": "}\n", + }, + ) assert res.status_code == 200 assert match_regex("(Dad|excited|park)+", res.body["content"]) -@pytest.mark.parametrize("input_extra", [ - {}, - {"filename": "ok"}, - {"filename": 123}, - {"filename": 123, "text": "abc"}, - {"filename": 123, "text": 456}, -]) +@pytest.mark.parametrize( + "input_extra", + [ + {}, + {"filename": "ok"}, + {"filename": 123}, + {"filename": 123, "text": "abc"}, + {"filename": 123, "text": 456}, + ], +) def test_invalid_input_extra_req(input_extra): global server server.start() - res = server.make_request("POST", "/infill", data={ - "input_extra": [input_extra], - "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", - "prompt": " int n_threads = llama_", - "input_suffix": "}\n", - }) + res = server.make_request( + "POST", + "/infill", + data={ + "input_extra": [input_extra], + "input_prefix": '#include \n#include "llama.h"\n\nint main() {\n', + "prompt": " int n_threads = llama_", + "input_suffix": "}\n", + }, + ) assert res.status_code == 400 assert "error" in res.body @@ -64,14 +82,23 @@ def test_with_qwen_model(): server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF" server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf" server.start(timeout_seconds=600) - res = server.make_request("POST", "/infill", data={ - "input_extra": [{ - "filename": "llama.h", - "text": "LLAMA_API int32_t llama_n_threads();\n" - }], - "input_prefix": "#include \n#include \"llama.h\"\n\nint main() {\n", - "prompt": " int n_threads = llama_", - "input_suffix": "}\n", - }) + res = server.make_request( + "POST", + "/infill", + data={ + "input_extra": [ + { + "filename": "llama.h", + "text": "LLAMA_API int32_t llama_n_threads();\n", + } + ], + "input_prefix": '#include \n#include "llama.h"\n\nint main() {\n', + "prompt": " int n_threads = llama_", + "input_suffix": "}\n", + }, + ) assert res.status_code == 200 - assert res.body["content"] == "n_threads();\n printf(\"Number of threads: %d\\n\", n_threads);\n return 0;\n" + assert ( + res.body["content"] + == 'n_threads();\n printf("Number of threads: %d\\n", n_threads);\n return 0;\n' + ) diff --git a/smallthinker/tools/server/tests/unit/test_lora.py b/smallthinker/tools/server/tests/unit/test_lora.py index c1aa8be7..a9427278 100644 --- a/smallthinker/tools/server/tests/unit/test_lora.py +++ b/smallthinker/tools/server/tests/unit/test_lora.py @@ -5,6 +5,7 @@ LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf" + @pytest.fixture(scope="module", autouse=True) def create_server(): global server @@ -12,22 +13,29 @@ def create_server(): server.lora_files = [download_file(LORA_FILE_URL)] -@pytest.mark.parametrize("scale,re_content", [ - # without applying lora, the model should behave like a bedtime story generator - (0.0, "(little|girl|three|years|old)+"), - # with lora, the model should behave like a Shakespearean text generator - (1.0, "(eye|love|glass|sun)+"), -]) +@pytest.mark.parametrize( + "scale,re_content", + [ + # without applying lora, the model should behave like a bedtime story generator + (0.0, "(little|girl|three|years|old)+"), + # with lora, the model should behave like a Shakespearean text generator + (1.0, "(eye|love|glass|sun)+"), + ], +) def test_lora(scale: float, re_content: str): global server server.start() - res_lora_control = server.make_request("POST", "/lora-adapters", data=[ - {"id": 0, "scale": scale} - ]) + res_lora_control = server.make_request( + "POST", "/lora-adapters", data=[{"id": 0, "scale": scale}] + ) assert res_lora_control.status_code == 200 - res = server.make_request("POST", "/completion", data={ - "prompt": "Look in thy glass", - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "Look in thy glass", + }, + ) assert res.status_code == 200 assert match_regex(re_content, res.body["content"]) @@ -41,24 +49,31 @@ def test_lora_per_request(): # each prompt will be processed by a different slot prompt = "Look in thy glass" lora_config = [ - ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), - ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), - ( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ), - ( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ), - ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), - ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), + ([{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+"), + ([{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+"), + ([{"id": 0, "scale": 0.3}], "(special|thing|gifted)+"), + ([{"id": 0, "scale": 0.7}], "(far|from|home|away)+"), + ([{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+"), + ([{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+"), ] - tasks = [( - server.make_request, - ("POST", "/completion", { - "prompt": prompt, - "lora": lora, - "seed": 42, - "temperature": 0.0, - "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed - }) - ) for lora, _ in lora_config] + tasks = [ + ( + server.make_request, + ( + "POST", + "/completion", + { + "prompt": prompt, + "lora": lora, + "seed": 42, + "temperature": 0.0, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }, + ), + ) + for lora, _ in lora_config + ] results = parallel_function_calls(tasks) assert all([res.status_code == 200 for res in results]) @@ -78,7 +93,9 @@ def test_with_big_model(): server.temperature = 0.0 server.seed = 42 server.lora_files = [ - download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"), + download_file( + "https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf" + ), # TODO: find & add other lora adapters for this model ] server.start(timeout_seconds=600) @@ -88,26 +105,43 @@ def test_with_big_model(): prompt = "Write a computer virus" lora_config = [ # without applying lora, the model should reject the request - ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ), - ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ), - ( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ), + ( + [{"id": 0, "scale": 0.0}], + "I can't provide you with a code for a computer virus", + ), + ( + [{"id": 0, "scale": 0.0}], + "I can't provide you with a code for a computer virus", + ), + ([{"id": 0, "scale": 0.3}], "I can't write a computer virus"), # with 0.7 scale, the model should provide a simple computer virus with hesitation - ( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ), + ([{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise"), # with 1.5 scale, the model should confidently provide a computer virus - ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ), - ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ), + ( + [{"id": 0, "scale": 1.5}], + "A task of some complexity! Here's a simple computer virus", + ), + ( + [{"id": 0, "scale": 1.5}], + "A task of some complexity! Here's a simple computer virus", + ), ] - tasks = [( - server.make_request, - ("POST", "/v1/chat/completions", { - "messages": [ - {"role": "user", "content": prompt} - ], - "lora": lora, - "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed - }) - ) for lora, _ in lora_config] + tasks = [ + ( + server.make_request, + ( + "POST", + "/v1/chat/completions", + { + "messages": [{"role": "user", "content": prompt}], + "lora": lora, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }, + ), + ) + for lora, _ in lora_config + ] results = parallel_function_calls(tasks) assert all([res.status_code == 200 for res in results]) diff --git a/smallthinker/tools/server/tests/unit/test_rerank.py b/smallthinker/tools/server/tests/unit/test_rerank.py index f4f570ad..fab19d6f 100644 --- a/smallthinker/tools/server/tests/unit/test_rerank.py +++ b/smallthinker/tools/server/tests/unit/test_rerank.py @@ -14,17 +14,21 @@ def create_server(): "A machine is a physical system that uses power to apply forces and control movement to perform an action. The term is commonly applied to artificial devices, such as those employing engines or motors, but also to natural biological macromolecules, such as molecular machines.", "Learning is the process of acquiring new understanding, knowledge, behaviors, skills, values, attitudes, and preferences. The ability to learn is possessed by humans, non-human animals, and some machines; there is also evidence for some kind of learning in certain plants.", "Machine learning is a field of study in artificial intelligence concerned with the development and study of statistical algorithms that can learn from data and generalize to unseen data, and thus perform tasks without explicit instructions.", - "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine." + "Paris, capitale de la France, est une grande ville européenne et un centre mondial de l'art, de la mode, de la gastronomie et de la culture. Son paysage urbain du XIXe siècle est traversé par de larges boulevards et la Seine.", ] def test_rerank(): global server server.start() - res = server.make_request("POST", "/rerank", data={ - "query": "Machine learning is", - "documents": TEST_DOCUMENTS, - }) + res = server.make_request( + "POST", + "/rerank", + data={ + "query": "Machine learning is", + "documents": TEST_DOCUMENTS, + }, + ) assert res.status_code == 200 assert len(res.body["results"]) == 4 @@ -44,10 +48,14 @@ def test_rerank(): def test_rerank_tei_format(): global server server.start() - res = server.make_request("POST", "/rerank", data={ - "query": "Machine learning is", - "texts": TEST_DOCUMENTS, - }) + res = server.make_request( + "POST", + "/rerank", + data={ + "query": "Machine learning is", + "texts": TEST_DOCUMENTS, + }, + ) assert res.status_code == 200 assert len(res.body) == 4 @@ -64,19 +72,26 @@ def test_rerank_tei_format(): assert least_relevant["index"] == 3 -@pytest.mark.parametrize("documents", [ - [], - None, - 123, - [1, 2, 3], -]) +@pytest.mark.parametrize( + "documents", + [ + [], + None, + 123, + [1, 2, 3], + ], +) def test_invalid_rerank_req(documents): global server server.start() - res = server.make_request("POST", "/rerank", data={ - "query": "Machine learning is", - "documents": documents, - }) + res = server.make_request( + "POST", + "/rerank", + data={ + "query": "Machine learning is", + "documents": documents, + }, + ) assert res.status_code == 400 assert "error" in res.body @@ -86,19 +101,23 @@ def test_invalid_rerank_req(documents): [ ("Machine learning is", "A machine", "Learning is", 19), ("Which city?", "Machine learning is ", "Paris, capitale de la", 26), - ] + ], ) def test_rerank_usage(query, doc1, doc2, n_tokens): global server server.start() - res = server.make_request("POST", "/rerank", data={ - "query": query, - "documents": [ - doc1, - doc2, - ] - }) + res = server.make_request( + "POST", + "/rerank", + data={ + "query": query, + "documents": [ + doc1, + doc2, + ], + }, + ) assert res.status_code == 200 - assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens'] - assert res.body['usage']['prompt_tokens'] == n_tokens + assert res.body["usage"]["prompt_tokens"] == res.body["usage"]["total_tokens"] + assert res.body["usage"]["prompt_tokens"] == n_tokens diff --git a/smallthinker/tools/server/tests/unit/test_security.py b/smallthinker/tools/server/tests/unit/test_security.py index 620b2537..0d96e108 100644 --- a/smallthinker/tools/server/tests/unit/test_security.py +++ b/smallthinker/tools/server/tests/unit/test_security.py @@ -6,6 +6,7 @@ TEST_API_KEY = "sk-this-is-the-secret-key" + @pytest.fixture(scope="module", autouse=True) def create_server(): global server @@ -26,11 +27,16 @@ def test_access_public_endpoint(endpoint: str): def test_incorrect_api_key(api_key: str): global server server.start() - res = server.make_request("POST", "/completions", data={ - "prompt": "I believe the meaning of life is", - }, headers={ - "Authorization": f"Bearer {api_key}" if api_key else None, - }) + res = server.make_request( + "POST", + "/completions", + data={ + "prompt": "I believe the meaning of life is", + }, + headers={ + "Authorization": f"Bearer {api_key}" if api_key else None, + }, + ) assert res.status_code == 401 assert "error" in res.body assert res.body["error"]["type"] == "authentication_error" @@ -39,11 +45,16 @@ def test_incorrect_api_key(api_key: str): def test_correct_api_key(): global server server.start() - res = server.make_request("POST", "/completions", data={ - "prompt": "I believe the meaning of life is", - }, headers={ - "Authorization": f"Bearer {TEST_API_KEY}", - }) + res = server.make_request( + "POST", + "/completions", + data={ + "prompt": "I believe the meaning of life is", + }, + headers={ + "Authorization": f"Bearer {TEST_API_KEY}", + }, + ) assert res.status_code == 200 assert "error" not in res.body assert "content" in res.body @@ -52,7 +63,10 @@ def test_correct_api_key(): def test_openai_library_correct_api_key(): global server server.start() - client = OpenAI(api_key=TEST_API_KEY, base_url=f"http://{server.server_host}:{server.server_port}") + client = OpenAI( + api_key=TEST_API_KEY, + base_url=f"http://{server.server_host}:{server.server_port}", + ) res = client.chat.completions.create( model="gpt-3.5-turbo", messages=[ @@ -63,21 +77,28 @@ def test_openai_library_correct_api_key(): assert len(res.choices) == 1 -@pytest.mark.parametrize("origin,cors_header,cors_header_value", [ - ("localhost", "Access-Control-Allow-Origin", "localhost"), - ("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"), - ("origin", "Access-Control-Allow-Credentials", "true"), - ("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"), - ("web.mydomain.fr", "Access-Control-Allow-Headers", "*"), -]) +@pytest.mark.parametrize( + "origin,cors_header,cors_header_value", + [ + ("localhost", "Access-Control-Allow-Origin", "localhost"), + ("web.mydomain.fr", "Access-Control-Allow-Origin", "web.mydomain.fr"), + ("origin", "Access-Control-Allow-Credentials", "true"), + ("web.mydomain.fr", "Access-Control-Allow-Methods", "GET, POST"), + ("web.mydomain.fr", "Access-Control-Allow-Headers", "*"), + ], +) def test_cors_options(origin: str, cors_header: str, cors_header_value: str): global server server.start() - res = server.make_request("OPTIONS", "/completions", headers={ - "Origin": origin, - "Access-Control-Request-Method": "POST", - "Access-Control-Request-Headers": "Authorization", - }) + res = server.make_request( + "OPTIONS", + "/completions", + headers={ + "Origin": origin, + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": "Authorization", + }, + ) assert res.status_code == 200 assert cors_header in res.headers assert res.headers[cors_header] == cors_header_value diff --git a/smallthinker/tools/server/tests/unit/test_slot_save.py b/smallthinker/tools/server/tests/unit/test_slot_save.py index 38704f5e..fe28f984 100644 --- a/smallthinker/tools/server/tests/unit/test_slot_save.py +++ b/smallthinker/tools/server/tests/unit/test_slot_save.py @@ -3,6 +3,7 @@ server = ServerPreset.tinyllama2() + @pytest.fixture(scope="module", autouse=True) def create_server(): global server @@ -16,55 +17,79 @@ def test_slot_save_restore(): server.start() # First prompt in slot 1 should be fully processed - res = server.make_request("POST", "/completion", data={ - "prompt": "What is the capital of France?", - "id_slot": 1, - "cache_prompt": True, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "What is the capital of France?", + "id_slot": 1, + "cache_prompt": True, + }, + ) assert res.status_code == 200 assert match_regex("(Whiskers|Flana)+", res.body["content"]) assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed # Save state of slot 1 - res = server.make_request("POST", "/slots/1?action=save", data={ - "filename": "slot1.bin", - }) + res = server.make_request( + "POST", + "/slots/1?action=save", + data={ + "filename": "slot1.bin", + }, + ) assert res.status_code == 200 assert res.body["n_saved"] == 84 # Since we have cache, this should only process the last tokens - res = server.make_request("POST", "/completion", data={ - "prompt": "What is the capital of Germany?", - "id_slot": 1, - "cache_prompt": True, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "What is the capital of Germany?", + "id_slot": 1, + "cache_prompt": True, + }, + ) assert res.status_code == 200 assert match_regex("(Jack|said)+", res.body["content"]) assert res.body["timings"]["prompt_n"] == 6 # only different part is processed # Loading the saved cache into slot 0 - res = server.make_request("POST", "/slots/0?action=restore", data={ - "filename": "slot1.bin", - }) + res = server.make_request( + "POST", + "/slots/0?action=restore", + data={ + "filename": "slot1.bin", + }, + ) assert res.status_code == 200 assert res.body["n_restored"] == 84 # Since we have cache, slot 0 should only process the last tokens - res = server.make_request("POST", "/completion", data={ - "prompt": "What is the capital of Germany?", - "id_slot": 0, - "cache_prompt": True, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "What is the capital of Germany?", + "id_slot": 0, + "cache_prompt": True, + }, + ) assert res.status_code == 200 assert match_regex("(Jack|said)+", res.body["content"]) assert res.body["timings"]["prompt_n"] == 6 # only different part is processed # For verification that slot 1 was not corrupted during slot 0 load, same thing should work - res = server.make_request("POST", "/completion", data={ - "prompt": "What is the capital of Germany?", - "id_slot": 1, - "cache_prompt": True, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "What is the capital of Germany?", + "id_slot": 1, + "cache_prompt": True, + }, + ) assert res.status_code == 200 assert match_regex("(Jack|said)+", res.body["content"]) assert res.body["timings"]["prompt_n"] == 1 @@ -74,11 +99,15 @@ def test_slot_erase(): global server server.start() - res = server.make_request("POST", "/completion", data={ - "prompt": "What is the capital of France?", - "id_slot": 1, - "cache_prompt": True, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "What is the capital of France?", + "id_slot": 1, + "cache_prompt": True, + }, + ) assert res.status_code == 200 assert match_regex("(Whiskers|Flana)+", res.body["content"]) assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed @@ -88,11 +117,15 @@ def test_slot_erase(): assert res.status_code == 200 # re-run the same prompt, it should process all tokens again - res = server.make_request("POST", "/completion", data={ - "prompt": "What is the capital of France?", - "id_slot": 1, - "cache_prompt": True, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "What is the capital of France?", + "id_slot": 1, + "cache_prompt": True, + }, + ) assert res.status_code == 200 assert match_regex("(Whiskers|Flana)+", res.body["content"]) assert res.body["timings"]["prompt_n"] == 21 # all tokens are processed diff --git a/smallthinker/tools/server/tests/unit/test_speculative.py b/smallthinker/tools/server/tests/unit/test_speculative.py index 54db38cf..195712b5 100644 --- a/smallthinker/tools/server/tests/unit/test_speculative.py +++ b/smallthinker/tools/server/tests/unit/test_speculative.py @@ -7,6 +7,7 @@ MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf" + def create_server(): global server server = ServerPreset.stories15m_moe() @@ -25,11 +26,15 @@ def test_with_and_without_draft(): global server server.model_draft = None # disable draft model server.start() - res = server.make_request("POST", "/completion", data={ - "prompt": "I believe the meaning of life is", - "temperature": 0.0, - "top_k": 1, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }, + ) assert res.status_code == 200 content_no_draft = res.body["content"] server.stop() @@ -37,11 +42,15 @@ def test_with_and_without_draft(): # create new server with draft model create_server() server.start() - res = server.make_request("POST", "/completion", data={ - "prompt": "I believe the meaning of life is", - "temperature": 0.0, - "top_k": 1, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }, + ) assert res.status_code == 200 content_draft = res.body["content"] @@ -63,11 +72,15 @@ def test_different_draft_min_draft_max(): server.draft_min = draft_min server.draft_max = draft_max server.start() - res = server.make_request("POST", "/completion", data={ - "prompt": "I believe the meaning of life is", - "temperature": 0.0, - "top_k": 1, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }, + ) assert res.status_code == 200 if last_content is not None: assert last_content == res.body["content"] @@ -78,12 +91,16 @@ def test_slot_ctx_not_exceeded(): global server server.n_ctx = 64 server.start() - res = server.make_request("POST", "/completion", data={ - "prompt": "Hello " * 56, - "temperature": 0.0, - "top_k": 1, - "speculative.p_min": 0.0, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "Hello " * 56, + "temperature": 0.0, + "top_k": 1, + "speculative.p_min": 0.0, + }, + ) assert res.status_code == 200 assert len(res.body["content"]) > 0 @@ -92,34 +109,50 @@ def test_with_ctx_shift(): global server server.n_ctx = 64 server.start() - res = server.make_request("POST", "/completion", data={ - "prompt": "Hello " * 56, - "temperature": 0.0, - "top_k": 1, - "n_predict": 64, - "speculative.p_min": 0.0, - }) + res = server.make_request( + "POST", + "/completion", + data={ + "prompt": "Hello " * 56, + "temperature": 0.0, + "top_k": 1, + "n_predict": 64, + "speculative.p_min": 0.0, + }, + ) assert res.status_code == 200 assert len(res.body["content"]) > 0 assert res.body["tokens_predicted"] == 64 assert res.body["truncated"] == True -@pytest.mark.parametrize("n_slots,n_requests", [ - (1, 2), - (2, 2), -]) +@pytest.mark.parametrize( + "n_slots,n_requests", + [ + (1, 2), + (2, 2), + ], +) def test_multi_requests_parallel(n_slots: int, n_requests: int): global server server.n_slots = n_slots server.start() tasks = [] for _ in range(n_requests): - tasks.append((server.make_request, ("POST", "/completion", { - "prompt": "I believe the meaning of life is", - "temperature": 0.0, - "top_k": 1, - }))) + tasks.append( + ( + server.make_request, + ( + "POST", + "/completion", + { + "prompt": "I believe the meaning of life is", + "temperature": 0.0, + "top_k": 1, + }, + ), + ) + ) results = parallel_function_calls(tasks) for res in results: assert res.status_code == 200 diff --git a/smallthinker/tools/server/tests/unit/test_template.py b/smallthinker/tools/server/tests/unit/test_template.py index c53eda5b..7edb6994 100644 --- a/smallthinker/tools/server/tests/unit/test_template.py +++ b/smallthinker/tools/server/tests/unit/test_template.py @@ -6,6 +6,7 @@ import sys from unit.test_tool_call import TEST_TOOL + path = Path(__file__).resolve().parents[1] sys.path.insert(0, str(path)) @@ -14,7 +15,8 @@ server: ServerProcess -TIMEOUT_SERVER_START = 15*60 +TIMEOUT_SERVER_START = 15 * 60 + @pytest.fixture(autouse=True) def create_server(): @@ -26,83 +28,127 @@ def create_server(): @pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]]) -@pytest.mark.parametrize("template_name,reasoning_budget,expected_end", [ - ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", None, "\n"), - ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", -1, "\n"), - ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", 0, "\n"), - - ("Qwen-Qwen3-0.6B", -1, "<|im_start|>assistant\n"), - ("Qwen-Qwen3-0.6B", 0, "<|im_start|>assistant\n\n\n\n\n"), - - ("Qwen-QwQ-32B", -1, "<|im_start|>assistant\n\n"), - ("Qwen-QwQ-32B", 0, "<|im_start|>assistant\n\n"), - - ("CohereForAI-c4ai-command-r7b-12-2024-tool_use", -1, "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"), - ("CohereForAI-c4ai-command-r7b-12-2024-tool_use", 0, "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|><|END_THINKING|>"), -]) -def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expected_end: str, tools: list[dict]): +@pytest.mark.parametrize( + "template_name,reasoning_budget,expected_end", + [ + ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", None, "\n"), + ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", -1, "\n"), + ("deepseek-ai-DeepSeek-R1-Distill-Qwen-32B", 0, "\n"), + ("Qwen-Qwen3-0.6B", -1, "<|im_start|>assistant\n"), + ("Qwen-Qwen3-0.6B", 0, "<|im_start|>assistant\n\n\n\n\n"), + ("Qwen-QwQ-32B", -1, "<|im_start|>assistant\n\n"), + ("Qwen-QwQ-32B", 0, "<|im_start|>assistant\n\n"), + ( + "CohereForAI-c4ai-command-r7b-12-2024-tool_use", + -1, + "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + ), + ( + "CohereForAI-c4ai-command-r7b-12-2024-tool_use", + 0, + "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|><|START_THINKING|><|END_THINKING|>", + ), + ], +) +def test_reasoning_budget( + template_name: str, + reasoning_budget: int | None, + expected_end: str, + tools: list[dict], +): global server server.jinja = True server.reasoning_budget = reasoning_budget - server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.chat_template_file = f"../../../models/templates/{template_name}.jinja" server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/apply-template", data={ - "messages": [ - {"role": "user", "content": "What is today?"}, - ], - "tools": tools, - }) + res = server.make_request( + "POST", + "/apply-template", + data={ + "messages": [ + {"role": "user", "content": "What is today?"}, + ], + "tools": tools, + }, + ) assert res.status_code == 200 prompt = res.body["prompt"] - assert prompt.endswith(expected_end), f"Expected prompt to end with '{expected_end}', got '{prompt}'" + assert prompt.endswith( + expected_end + ), f"Expected prompt to end with '{expected_end}', got '{prompt}'" @pytest.mark.parametrize("tools", [None, [], [TEST_TOOL]]) -@pytest.mark.parametrize("template_name,format", [ - ("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"), - ("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"), -]) +@pytest.mark.parametrize( + "template_name,format", + [ + ("meta-llama-Llama-3.3-70B-Instruct", "%d %b %Y"), + ("fireworks-ai-llama-3-firefunction-v2", "%b %d %Y"), + ], +) def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]): global server server.jinja = True - server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.chat_template_file = f"../../../models/templates/{template_name}.jinja" server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/apply-template", data={ - "messages": [ - {"role": "user", "content": "What is today?"}, - ], - "tools": tools, - }) + res = server.make_request( + "POST", + "/apply-template", + data={ + "messages": [ + {"role": "user", "content": "What is today?"}, + ], + "tools": tools, + }, + ) assert res.status_code == 200 prompt = res.body["prompt"] today_str = datetime.date.today().strftime(format) - assert today_str in prompt, f"Expected today's date ({today_str}) in content ({prompt})" + assert ( + today_str in prompt + ), f"Expected today's date ({today_str}) in content ({prompt})" @pytest.mark.parametrize("add_generation_prompt", [False, True]) -@pytest.mark.parametrize("template_name,expected_generation_prompt", [ - ("meta-llama-Llama-3.3-70B-Instruct", "<|start_header_id|>assistant<|end_header_id|>"), -]) -def test_add_generation_prompt(template_name: str, expected_generation_prompt: str, add_generation_prompt: bool): +@pytest.mark.parametrize( + "template_name,expected_generation_prompt", + [ + ( + "meta-llama-Llama-3.3-70B-Instruct", + "<|start_header_id|>assistant<|end_header_id|>", + ), + ], +) +def test_add_generation_prompt( + template_name: str, expected_generation_prompt: str, add_generation_prompt: bool +): global server server.jinja = True - server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.chat_template_file = f"../../../models/templates/{template_name}.jinja" server.start(timeout_seconds=TIMEOUT_SERVER_START) - res = server.make_request("POST", "/apply-template", data={ - "messages": [ - {"role": "user", "content": "What is today?"}, - ], - "add_generation_prompt": add_generation_prompt, - }) + res = server.make_request( + "POST", + "/apply-template", + data={ + "messages": [ + {"role": "user", "content": "What is today?"}, + ], + "add_generation_prompt": add_generation_prompt, + }, + ) assert res.status_code == 200 prompt = res.body["prompt"] if add_generation_prompt: - assert expected_generation_prompt in prompt, f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})" + assert ( + expected_generation_prompt in prompt + ), f"Expected generation prompt ({expected_generation_prompt}) in content ({prompt})" else: - assert expected_generation_prompt not in prompt, f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})" + assert ( + expected_generation_prompt not in prompt + ), f"Did not expect generation prompt ({expected_generation_prompt}) in content ({prompt})" diff --git a/smallthinker/tools/server/tests/unit/test_tokenize.py b/smallthinker/tools/server/tests/unit/test_tokenize.py index 382457c9..87f2476c 100644 --- a/smallthinker/tools/server/tests/unit/test_tokenize.py +++ b/smallthinker/tools/server/tests/unit/test_tokenize.py @@ -15,15 +15,17 @@ def test_tokenize_detokenize(): server.start() # tokenize content = "What is the capital of France ?" - res_tok = server.make_request("POST", "/tokenize", data={ - "content": content - }) + res_tok = server.make_request("POST", "/tokenize", data={"content": content}) assert res_tok.status_code == 200 assert len(res_tok.body["tokens"]) > 5 # detokenize - res_detok = server.make_request("POST", "/detokenize", data={ - "tokens": res_tok.body["tokens"], - }) + res_detok = server.make_request( + "POST", + "/detokenize", + data={ + "tokens": res_tok.body["tokens"], + }, + ) assert res_detok.status_code == 200 assert res_detok.body["content"].strip() == content @@ -34,10 +36,14 @@ def test_tokenize_with_bos(): # tokenize content = "What is the capital of France ?" bosId = 1 - res_tok = server.make_request("POST", "/tokenize", data={ - "content": content, - "add_special": True, - }) + res_tok = server.make_request( + "POST", + "/tokenize", + data={ + "content": content, + "add_special": True, + }, + ) assert res_tok.status_code == 200 assert res_tok.body["tokens"][0] == bosId @@ -47,10 +53,14 @@ def test_tokenize_with_pieces(): server.start() # tokenize content = "This is a test string with unicode 媽 and emoji 🤗" - res_tok = server.make_request("POST", "/tokenize", data={ - "content": content, - "with_pieces": True, - }) + res_tok = server.make_request( + "POST", + "/tokenize", + data={ + "content": content, + "with_pieces": True, + }, + ) assert res_tok.status_code == 200 for token in res_tok.body["tokens"]: assert "id" in token diff --git a/smallthinker/tools/server/tests/unit/test_tool_call.py b/smallthinker/tools/server/tests/unit/test_tool_call.py index 20f048c6..b521b72c 100755 --- a/smallthinker/tools/server/tests/unit/test_tool_call.py +++ b/smallthinker/tools/server/tests/unit/test_tool_call.py @@ -4,6 +4,7 @@ # ensure grandparent path is in sys.path from pathlib import Path import sys + path = Path(__file__).resolve().parents[1] sys.path.insert(0, str(path)) @@ -12,9 +13,10 @@ server: ServerProcess -TIMEOUT_SERVER_START = 15*60 +TIMEOUT_SERVER_START = 15 * 60 TIMEOUT_HTTP_REQUEST = 60 + @pytest.fixture(autouse=True) def create_server(): global server @@ -23,12 +25,14 @@ def create_server(): server.server_port = 8081 server.n_slots = 1 + class CompletionMode(Enum): NORMAL = "normal" STREAMED = "streamed" + TEST_TOOL = { - "type":"function", + "type": "function", "function": { "name": "test", "description": "", @@ -37,9 +41,9 @@ class CompletionMode(Enum): "properties": { "success": {"type": "boolean", "const": True}, }, - "required": ["success"] - } - } + "required": ["success"], + }, + }, } PYTHON_TOOL = { @@ -52,179 +56,284 @@ class CompletionMode(Enum): "properties": { "code": { "type": "string", - "description": "The code to run in the ipython interpreter." + "description": "The code to run in the ipython interpreter.", } }, - "required": ["code"] - } - } + "required": ["code"], + }, + }, } WEATHER_TOOL = { - "type":"function", - "function":{ - "name":"get_current_weather", - "description":"Get the current weather in a given location", - "parameters":{ - "type":"object", - "properties":{ - "location":{ - "type":"string", - "description":"The city and country/state, e.g. 'San Francisco, CA', or 'Paris, France'" - } - }, - "required":["location"] - } - } + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and country/state, e.g. 'San Francisco, CA', or 'Paris, France'", + } + }, + "required": ["location"], + }, + }, } -def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs): - body = server.make_any_request("POST", "/v1/chat/completions", data={ - "max_tokens": n_predict, - "messages": [ - {"role": "system", "content": "You are a coding assistant."}, - {"role": "user", "content": "Write an example"}, - ], - "tool_choice": "required", - "tools": [tool], - "parallel_tool_calls": False, - **kwargs, - }) + +def do_test_completion_with_required_tool_tiny( + server: ServerProcess, tool: dict, argument_key: str | None, n_predict, **kwargs +): + body = server.make_any_request( + "POST", + "/v1/chat/completions", + data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "tool_choice": "required", + "tools": [tool], + "parallel_tool_calls": False, + **kwargs, + }, + ) # assert res.status_code == 200, f"Expected status code 200, got {res.status_code}" choice = body["choices"][0] tool_calls = choice["message"].get("tool_calls") - assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + assert ( + tool_calls and len(tool_calls) == 1 + ), f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] - assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' + assert choice["message"].get("content") in ( + None, + "", + ), f'Expected no content in {choice["message"]}' # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}' - expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] + expected_function_name = ( + "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] + ) assert expected_function_name == tool_call["function"]["name"] actual_arguments = tool_call["function"]["arguments"] assert isinstance(actual_arguments, str) if argument_key is not None: actual_arguments = json.loads(actual_arguments) - assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" + assert ( + argument_key in actual_arguments + ), f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) -@pytest.mark.parametrize("template_name,tool,argument_key", [ - ("google-gemma-2-2b-it", TEST_TOOL, "success"), - ("google-gemma-2-2b-it", TEST_TOOL, "success"), - ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"), - ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"), - ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"), - ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"), -]) -def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode): +@pytest.mark.parametrize( + "template_name,tool,argument_key", + [ + ("google-gemma-2-2b-it", TEST_TOOL, "success"), + ("google-gemma-2-2b-it", TEST_TOOL, "success"), + ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"), + ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"), + ], +) +def test_completion_with_required_tool_tiny_fast( + template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode +): global server n_predict = 1024 # server = ServerPreset.stories15m_moe() server.jinja = True server.n_predict = n_predict - server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.chat_template_file = f"../../../models/templates/{template_name}.jinja" server.start(timeout_seconds=TIMEOUT_SERVER_START) - do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0) + do_test_completion_with_required_tool_tiny( + server, + tool, + argument_key, + n_predict, + stream=stream == CompletionMode.STREAMED, + temperature=0.0, + top_k=1, + top_p=1.0, + ) @pytest.mark.slow @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) -@pytest.mark.parametrize("template_name,tool,argument_key", [ - ("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"), - ("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"), - - ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"), - ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"), - - ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"), - # Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own. - # ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"), - - ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"), - ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"), - - ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"), - ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"), - - ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"), - ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"), - - ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"), - ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"), - - ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"), - ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"), - - ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"), - # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "codeFalse), True), - # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"), - -]) -def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode): +@pytest.mark.parametrize( + "template_name,tool,argument_key", + [ + ("meta-llama-Llama-3.1-8B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.1-8B-Instruct", PYTHON_TOOL, "code"), + ("meetkai-functionary-medium-v3.1", TEST_TOOL, "success"), + ("meetkai-functionary-medium-v3.1", PYTHON_TOOL, "code"), + ("meetkai-functionary-medium-v3.2", TEST_TOOL, "success"), + # Functionary v3.2 format supports raw python content, which w/ a dummy stories model will never end on its own. + # ("meetkai-functionary-medium-v3.2", PYTHON_TOOL, "code"), + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", TEST_TOOL, "success"), + ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", PYTHON_TOOL, "code"), + ("meta-llama-Llama-3.2-3B-Instruct", TEST_TOOL, "success"), + ("meta-llama-Llama-3.2-3B-Instruct", PYTHON_TOOL, "code"), + ("mistralai-Mistral-Nemo-Instruct-2407", TEST_TOOL, "success"), + ("mistralai-Mistral-Nemo-Instruct-2407", PYTHON_TOOL, "code"), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", TEST_TOOL, "success"), + ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", PYTHON_TOOL, "code"), + ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", TEST_TOOL, "success"), + ("deepseek-ai-DeepSeek-R1-Distill-Llama-8B", PYTHON_TOOL, "code"), + ("fireworks-ai-llama-3-firefunction-v2", TEST_TOOL, "success"), + # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "codeFalse), True), + # ("fireworks-ai-llama-3-firefunction-v2", PYTHON_TOOL, "code"), + ], +) +def test_completion_with_required_tool_tiny_slow( + template_name: str, tool: dict, argument_key: str | None, stream: CompletionMode +): global server n_predict = 512 # server = ServerPreset.stories15m_moe() server.jinja = True server.n_predict = n_predict - server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.chat_template_file = f"../../../models/templates/{template_name}.jinja" server.start(timeout_seconds=TIMEOUT_SERVER_START) - do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED) + do_test_completion_with_required_tool_tiny( + server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED + ) @pytest.mark.slow @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) -@pytest.mark.parametrize("tool,argument_key,hf_repo,template_override", [ - (TEST_TOOL, "success", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), - (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), - (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), - - (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), - (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), - (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"), - - (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), - (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), - (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), - - (TEST_TOOL, "success", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), - (PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), - (PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"), - - (TEST_TOOL, "success", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None), - (PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None), - (PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"), - - (TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), - (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), - (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), - - (TEST_TOOL, "success", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - (PYTHON_TOOL, "code", "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), - - (TEST_TOOL, "success", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), - (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), - (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), - - # (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), - # (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), - # (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), - - (TEST_TOOL, "success", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)), - (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", ("meetkai/functionary-medium-v3.2", None)), - (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", "chatml"), - - (TEST_TOOL, "success", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), - (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), - (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), - - (TEST_TOOL, "success", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), - (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), - (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), - - (TEST_TOOL, "success", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - (PYTHON_TOOL, "code", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), -]) -def test_completion_with_required_tool_real_model(tool: dict, argument_key: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode): +@pytest.mark.parametrize( + "tool,argument_key,hf_repo,template_override", + [ + ( + TEST_TOOL, + "success", + "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", + None, + ), + (PYTHON_TOOL, "code", "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + ( + PYTHON_TOOL, + "code", + "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", + "chatml", + ), + (TEST_TOOL, "success", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"), + (TEST_TOOL, "success", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None), + ( + PYTHON_TOOL, + "code", + "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", + "chatml", + ), + (TEST_TOOL, "success", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + (PYTHON_TOOL, "code", "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + ( + TEST_TOOL, + "success", + "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", + ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use"), + ), + ( + PYTHON_TOOL, + "code", + "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", + ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use"), + ), + ( + PYTHON_TOOL, + "code", + "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", + "chatml", + ), + ( + TEST_TOOL, + "success", + "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", + ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use"), + ), + ( + PYTHON_TOOL, + "code", + "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", + ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use"), + ), + (PYTHON_TOOL, "code", "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + # (TEST_TOOL, "success", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + # (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + # (PYTHON_TOOL, "code", "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + ( + TEST_TOOL, + "success", + "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", + ("meetkai/functionary-medium-v3.2", None), + ), + ( + PYTHON_TOOL, + "code", + "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", + ("meetkai/functionary-medium-v3.2", None), + ), + (PYTHON_TOOL, "code", "bartowski/functionary-small-v3.2-GGUF:Q4_K_M", "chatml"), + ( + TEST_TOOL, + "success", + "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", + ("meta-llama/Llama-3.2-3B-Instruct", None), + ), + ( + PYTHON_TOOL, + "code", + "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", + ("meta-llama/Llama-3.2-3B-Instruct", None), + ), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + ( + TEST_TOOL, + "success", + "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", + ("meta-llama/Llama-3.2-3B-Instruct", None), + ), + ( + PYTHON_TOOL, + "code", + "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", + ("meta-llama/Llama-3.2-3B-Instruct", None), + ), + (PYTHON_TOOL, "code", "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", "chatml"), + ( + TEST_TOOL, + "success", + "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", + None, + ), + ( + PYTHON_TOOL, + "code", + "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", + None, + ), + ], +) +def test_completion_with_required_tool_real_model( + tool: dict, + argument_key: str | None, + hf_repo: str, + template_override: str | Tuple[str, str | None] | None, + stream: CompletionMode, +): global server n_predict = 512 server.jinja = True @@ -235,133 +344,187 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str if isinstance(template_override, tuple): (template_hf_repo, template_variant) = template_override server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" - assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + assert os.path.exists( + server.chat_template_file + ), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - body = server.make_any_request("POST", "/v1/chat/completions", data={ - "max_tokens": n_predict, - "messages": [ - {"role": "system", "content": "You are a coding assistant."}, - {"role": "user", "content": "Write an example"}, - ], - "tool_choice": "required", - "tools": [tool], - "parallel_tool_calls": False, - "stream": stream == CompletionMode.STREAMED, - "temperature": 0.0, - "top_k": 1, - "top_p": 1.0, - }, timeout=TIMEOUT_HTTP_REQUEST) + body = server.make_any_request( + "POST", + "/v1/chat/completions", + data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "Write an example"}, + ], + "tool_choice": "required", + "tools": [tool], + "parallel_tool_calls": False, + "stream": stream == CompletionMode.STREAMED, + "temperature": 0.0, + "top_k": 1, + "top_p": 1.0, + }, + timeout=TIMEOUT_HTTP_REQUEST, + ) choice = body["choices"][0] tool_calls = choice["message"].get("tool_calls") - assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + assert ( + tool_calls and len(tool_calls) == 1 + ), f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' - expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] + expected_function_name = ( + "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] + ) assert expected_function_name == tool_call["function"]["name"] actual_arguments = tool_call["function"]["arguments"] assert isinstance(actual_arguments, str) if argument_key is not None: actual_arguments = json.loads(actual_arguments) - assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" - - -def do_test_completion_without_tool_call(server: ServerProcess, n_predict: int, tools: list[dict], tool_choice: str | None, **kwargs): - body = server.make_any_request("POST", "/v1/chat/completions", data={ - "max_tokens": n_predict, - "messages": [ - {"role": "system", "content": "You are a coding assistant."}, - {"role": "user", "content": "say hello world with python"}, - ], - "tools": tools if tools else None, - "tool_choice": tool_choice, - **kwargs, - }, timeout=TIMEOUT_HTTP_REQUEST) + assert ( + argument_key in actual_arguments + ), f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" + + +def do_test_completion_without_tool_call( + server: ServerProcess, + n_predict: int, + tools: list[dict], + tool_choice: str | None, + **kwargs, +): + body = server.make_any_request( + "POST", + "/v1/chat/completions", + data={ + "max_tokens": n_predict, + "messages": [ + {"role": "system", "content": "You are a coding assistant."}, + {"role": "user", "content": "say hello world with python"}, + ], + "tools": tools if tools else None, + "tool_choice": tool_choice, + **kwargs, + }, + timeout=TIMEOUT_HTTP_REQUEST, + ) choice = body["choices"][0] - assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}' + assert ( + choice["message"].get("tool_calls") is None + ), f'Expected no tool call in {choice["message"]}' @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) -@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [ - ("meta-llama-Llama-3.3-70B-Instruct", 128, [], None), - ("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None), - ("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], 'none'), -]) -def test_completion_without_tool_call_fast(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode): +@pytest.mark.parametrize( + "template_name,n_predict,tools,tool_choice", + [ + ("meta-llama-Llama-3.3-70B-Instruct", 128, [], None), + ("meta-llama-Llama-3.3-70B-Instruct", 128, [TEST_TOOL], None), + ("meta-llama-Llama-3.3-70B-Instruct", 128, [PYTHON_TOOL], "none"), + ], +) +def test_completion_without_tool_call_fast( + template_name: str, + n_predict: int, + tools: list[dict], + tool_choice: str | None, + stream: CompletionMode, +): global server server.n_predict = n_predict server.jinja = True - server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.chat_template_file = f"../../../models/templates/{template_name}.jinja" server.start(timeout_seconds=TIMEOUT_SERVER_START) - do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED) + do_test_completion_without_tool_call( + server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED + ) @pytest.mark.slow @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) -@pytest.mark.parametrize("template_name,n_predict,tools,tool_choice", [ - ("meetkai-functionary-medium-v3.2", 256, [], None), - ("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None), - ("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], 'none'), - ("meetkai-functionary-medium-v3.1", 256, [], None), - ("meetkai-functionary-medium-v3.1", 256, [TEST_TOOL], None), - ("meetkai-functionary-medium-v3.1", 256, [PYTHON_TOOL], 'none'), - ("meta-llama-Llama-3.2-3B-Instruct", 256, [], None), - ("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None), - ("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], 'none'), -]) -def test_completion_without_tool_call_slow(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None, stream: CompletionMode): +@pytest.mark.parametrize( + "template_name,n_predict,tools,tool_choice", + [ + ("meetkai-functionary-medium-v3.2", 256, [], None), + ("meetkai-functionary-medium-v3.2", 256, [TEST_TOOL], None), + ("meetkai-functionary-medium-v3.2", 256, [PYTHON_TOOL], "none"), + ("meetkai-functionary-medium-v3.1", 256, [], None), + ("meetkai-functionary-medium-v3.1", 256, [TEST_TOOL], None), + ("meetkai-functionary-medium-v3.1", 256, [PYTHON_TOOL], "none"), + ("meta-llama-Llama-3.2-3B-Instruct", 256, [], None), + ("meta-llama-Llama-3.2-3B-Instruct", 256, [TEST_TOOL], None), + ("meta-llama-Llama-3.2-3B-Instruct", 256, [PYTHON_TOOL], "none"), + ], +) +def test_completion_without_tool_call_slow( + template_name: str, + n_predict: int, + tools: list[dict], + tool_choice: str | None, + stream: CompletionMode, +): global server server.n_predict = n_predict server.jinja = True - server.chat_template_file = f'../../../models/templates/{template_name}.jinja' + server.chat_template_file = f"../../../models/templates/{template_name}.jinja" server.start(timeout_seconds=TIMEOUT_SERVER_START) - do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED) + do_test_completion_without_tool_call( + server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED + ) @pytest.mark.slow @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) -@pytest.mark.parametrize("hf_repo,template_override", [ - ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), - ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), - - ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), - ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), - - ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), - ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"), - - ("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None), - ("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"), - - ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), - ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), - - ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), - - ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), - ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), - - # ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), - # ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), - - # ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), - # ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), - - ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), - ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), - - ("bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L", ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use")), - - ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - - # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. - ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), - - # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), -]) -def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode): +@pytest.mark.parametrize( + "hf_repo,template_override", + [ + ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-1.5B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + ( + "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", + ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use"), + ), + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + ( + "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", + ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use"), + ), + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + # ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + # ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + # ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), + # ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), + ( + "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", + ("meta-llama/Llama-3.2-3B-Instruct", None), + ), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", "chatml"), + ( + "bartowski/c4ai-command-r7b-12-2024-GGUF:Q6_K_L", + ("CohereForAI/c4ai-command-r7b-12-2024", "tool_use"), + ), + ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + # Note: gemma-2-2b-it knows itself as "model", not "assistant", so we don't test the ill-suited chatml on it. + ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + # ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama/Llama-3.2-3B-Instruct", None)), + ], +) +def test_weather( + hf_repo: str, + template_override: str | Tuple[str, str | None] | None, + stream: CompletionMode, +): global server n_predict = 512 server.jinja = True @@ -372,56 +535,106 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | if isinstance(template_override, tuple): (template_hf_repo, template_variant) = template_override server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" - assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + assert os.path.exists( + server.chat_template_file + ), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict) + do_test_weather( + server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict + ) def do_test_weather(server: ServerProcess, **kwargs): - body = server.make_any_request("POST", "/v1/chat/completions", data={ - "messages": [ - {"role": "system", "content": "You are a chatbot that uses tools/functions. Dont overthink things."}, - {"role": "user", "content": "What is the weather in Istanbul?"}, - ], - "tools": [WEATHER_TOOL], - **kwargs, - }, timeout=TIMEOUT_HTTP_REQUEST) + body = server.make_any_request( + "POST", + "/v1/chat/completions", + data={ + "messages": [ + { + "role": "system", + "content": "You are a chatbot that uses tools/functions. Dont overthink things.", + }, + {"role": "user", "content": "What is the weather in Istanbul?"}, + ], + "tools": [WEATHER_TOOL], + **kwargs, + }, + timeout=TIMEOUT_HTTP_REQUEST, + ) choice = body["choices"][0] tool_calls = choice["message"].get("tool_calls") - assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + assert ( + tool_calls and len(tool_calls) == 1 + ), f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' - assert tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"], f'Expected weather tool call, got {tool_call["function"]["name"]}' + assert ( + tool_call["function"]["name"] == WEATHER_TOOL["function"]["name"] + ), f'Expected weather tool call, got {tool_call["function"]["name"]}' # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}' actual_arguments = json.loads(tool_call["function"]["arguments"]) - assert 'location' in actual_arguments, f"location not found in {json.dumps(actual_arguments)}" + assert ( + "location" in actual_arguments + ), f"location not found in {json.dumps(actual_arguments)}" location = actual_arguments["location"] - assert isinstance(location, str), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}" - assert re.match('^Istanbul(( |, ?)(TR|Turkey|Türkiye))?$', location), f'Expected Istanbul for location, got {location}' + assert isinstance( + location, str + ), f"Expected location to be a string, got {type(location)}: {json.dumps(location)}" + assert re.match( + "^Istanbul(( |, ?)(TR|Turkey|Türkiye))?$", location + ), f"Expected Istanbul for location, got {location}" @pytest.mark.slow @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) -@pytest.mark.parametrize("result_override,n_predict,hf_repo,template_override", [ - (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), - (None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None), - (None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"), - (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), - (None, 128, "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - (None, 128, "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use")), - (None, 128, "bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai/functionary-medium-v3.2", None)), - (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), - (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), - (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), - ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), - - # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value) - # (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), - # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), -]) -def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode): +@pytest.mark.parametrize( + "result_override,n_predict,hf_repo,template_override", + [ + (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", None), + (None, 128, "bartowski/Qwen2.5-Coder-3B-Instruct-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + ( + None, + 128, + "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", + ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use"), + ), + ( + None, + 128, + "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", + ("NousResearch/Hermes-3-Llama-3.1-8B", "tool_use"), + ), + ( + None, + 128, + "bartowski/functionary-small-v3.2-GGUF:Q8_0", + ("meetkai/functionary-medium-v3.2", None), + ), + (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + (None, 128, "bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + (None, 128, "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + ( + "[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", + 8192, + "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", + ("llama-cpp-deepseek-r1", None), + ), + # TODO: fix these (wrong results, either didn't respect decimal instruction or got wrong value) + # (None, 128, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + # ("[\\s\\S]*?\\*\\*\\s*0.5($|\\*\\*)", 8192, "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + ], +) +def test_calc_result( + result_override: str | None, + n_predict: int, + hf_repo: str, + template_override: str | Tuple[str, str | None] | None, + stream: CompletionMode, +): global server server.jinja = True server.n_ctx = 8192 * 2 @@ -431,84 +644,144 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, if isinstance(template_override, tuple): (template_hf_repo, template_variant) = template_override server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" - assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + assert os.path.exists( + server.chat_template_file + ), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED) - - -def do_test_calc_result(server: ServerProcess, result_override: str | None, n_predict: int, **kwargs): - body = server.make_any_request("POST", "/v1/chat/completions", data={ - "max_tokens": n_predict, - "messages": [ - {"role": "system", "content": "You are a tools-calling assistant. You express numerical values with at most two decimals."}, - {"role": "user", "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?"}, - { - "role": "assistant", - "content": None, - "tool_calls": [ - { - "id": "call_6789", - "type": "function", - "function": { - "name": "calculate", - "arguments": "{\"expression\":\"sin(30 * pi / 180)\"}" + do_test_calc_result( + server, result_override, n_predict, stream=stream == CompletionMode.STREAMED + ) + + +def do_test_calc_result( + server: ServerProcess, result_override: str | None, n_predict: int, **kwargs +): + body = server.make_any_request( + "POST", + "/v1/chat/completions", + data={ + "max_tokens": n_predict, + "messages": [ + { + "role": "system", + "content": "You are a tools-calling assistant. You express numerical values with at most two decimals.", + }, + { + "role": "user", + "content": "What's the y coordinate of a point on the unit sphere at angle 30 degrees?", + }, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_6789", + "type": "function", + "function": { + "name": "calculate", + "arguments": '{"expression":"sin(30 * pi / 180)"}', + }, } - } - ] - }, - { - "role": "tool", - "name": "calculate", - "content": "0.55644242476", - "tool_call_id": "call_6789" - } - ], - "tools": [ - { - "type":"function", - "function":{ - "name":"calculate", - "description":"A calculator function that computes values of arithmetic expressions in the Python syntax", - "parameters":{ - "type":"object", - "properties":{ - "expression":{ - "type":"string", - "description":"An arithmetic expression to compute the value of (Python syntad, assuming all floats)" - } + ], + }, + { + "role": "tool", + "name": "calculate", + "content": "0.55644242476", + "tool_call_id": "call_6789", + }, + ], + "tools": [ + { + "type": "function", + "function": { + "name": "calculate", + "description": "A calculator function that computes values of arithmetic expressions in the Python syntax", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "An arithmetic expression to compute the value of (Python syntad, assuming all floats)", + } + }, + "required": ["expression"], }, - "required":["expression"] - } + }, } - } - ], - **kwargs, - }, timeout=TIMEOUT_HTTP_REQUEST) + ], + **kwargs, + }, + timeout=TIMEOUT_HTTP_REQUEST, + ) choice = body["choices"][0] tool_calls = choice["message"].get("tool_calls") assert tool_calls is None, f'Expected no tool call in {choice["message"]}' content = choice["message"].get("content") assert content is not None, f'Expected content in {choice["message"]}' if result_override is not None: - assert re.match(result_override, content), f'Expected {result_override}, got {content}' + assert re.match( + result_override, content + ), f"Expected {result_override}, got {content}" else: - assert re.match('^[\\s\\S]*?((That\'s|\\bis) (approximately )?)?\\b0\\.(5\\b|56\\b|556)', content), \ - f'Expected something like "The y coordinate is 0.56.", got {content}' + assert re.match( + "^[\\s\\S]*?((That's|\\bis) (approximately )?)?\\b0\\.(5\\b|56\\b|556)", + content, + ), f'Expected something like "The y coordinate is 0.56.", got {content}' @pytest.mark.slow @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) -@pytest.mark.parametrize("n_predict,reasoning_format,expect_reasoning_content,expect_content,hf_repo,template_override", [ - (128, 'deepseek', None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), - (128, None, None, "^The sum of 102 and 7 is 109[\\s\\S]*", "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), - (1024, 'deepseek', "I need to calculate the sum of 102 and 7[\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - (1024, 'deepseek', "First, I [\\s\\S]*", "To find the sum of[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", ("llama-cpp-deepseek-r1", None)), - # (1024, 'none', CompletionMode.NORMAL, None, "^(\\s*)?I need[\\s\\S]*?\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - # (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None), -]) -def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] | None, expect_content: str | None, expect_reasoning_content: str | None, hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode): +@pytest.mark.parametrize( + "n_predict,reasoning_format,expect_reasoning_content,expect_content,hf_repo,template_override", + [ + ( + 128, + "deepseek", + None, + "^The sum of 102 and 7 is 109[\\s\\S]*", + "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", + None, + ), + ( + 128, + None, + None, + "^The sum of 102 and 7 is 109[\\s\\S]*", + "bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", + None, + ), + ( + 1024, + "deepseek", + "I need to calculate the sum of 102 and 7[\\s\\S]*", + "To find the sum of[\\s\\S]*", + "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", + None, + ), + ( + 1024, + "deepseek", + "First, I [\\s\\S]*", + "To find the sum of[\\s\\S]*", + "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", + ("llama-cpp-deepseek-r1", None), + ), + # (1024, 'none', CompletionMode.NORMAL, None, "^(\\s*)?I need[\\s\\S]*?\\s*To find[\\s\\S]*", "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + # (128, 'deepseek', None, "^Okay, let me figure out the sum of 102 and 7[\\s\\S]*", "bartowski/Qwen_QwQ-32B-GGUF:Q4_K_M", None), + ], +) +def test_thoughts( + n_predict: int, + reasoning_format: Literal["deepseek", "none"] | None, + expect_content: str | None, + expect_reasoning_content: str | None, + hf_repo: str, + template_override: str | Tuple[str, str | None] | None, + stream: CompletionMode, +): global server server.reasoning_format = reasoning_format server.jinja = True @@ -519,71 +792,101 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] if isinstance(template_override, tuple): (template_hf_repo, template_variant) = template_override server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" - assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + assert os.path.exists( + server.chat_template_file + ), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - body = server.make_any_request("POST", "/v1/chat/completions", data={ - "max_tokens": n_predict, - "messages": [ - {"role": "user", "content": "What's the sum of 102 and 7?"}, - ], - "stream": stream == CompletionMode.STREAMED, - }, timeout=TIMEOUT_HTTP_REQUEST) + body = server.make_any_request( + "POST", + "/v1/chat/completions", + data={ + "max_tokens": n_predict, + "messages": [ + {"role": "user", "content": "What's the sum of 102 and 7?"}, + ], + "stream": stream == CompletionMode.STREAMED, + }, + timeout=TIMEOUT_HTTP_REQUEST, + ) choice = body["choices"][0] - assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}' + assert ( + choice["message"].get("tool_calls") is None + ), f'Expected no tool call in {choice["message"]}' content = choice["message"].get("content") if expect_content is None: - assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' + assert choice["message"].get("content") in ( + None, + "", + ), f'Expected no content in {choice["message"]}' else: - assert re.match(expect_content, content), f'Expected {expect_content}, got {content}' + assert re.match( + expect_content, content + ), f"Expected {expect_content}, got {content}" reasoning_content = choice["message"].get("reasoning_content") if expect_reasoning_content is None: - assert reasoning_content is None, f'Expected no reasoning content in {choice["message"]}' + assert ( + reasoning_content is None + ), f'Expected no reasoning content in {choice["message"]}' else: - assert re.match(expect_reasoning_content, reasoning_content), f'Expected {expect_reasoning_content}, got {reasoning_content}' + assert re.match( + expect_reasoning_content, reasoning_content + ), f"Expected {expect_reasoning_content}, got {reasoning_content}" @pytest.mark.slow @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) -@pytest.mark.parametrize("hf_repo,template_override", [ - ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), - - ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), - ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), - - ("bartowski/functionary-small-v3.2-GGUF:Q8_0", ("meetkai-functionary-medium-v3.2", None)), - ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), - - # ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), - ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), - - ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), - ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", None), - - ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", ("meta-llama-Llama-3.2-3B-Instruct", None)), - ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", None), - - ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), - ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), - - ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use")), - ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), - - ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use")), - ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), - - ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), - ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), - - ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), - ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"), -]) -def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | None] | None, stream: CompletionMode): +@pytest.mark.parametrize( + "hf_repo,template_override", + [ + ("bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF:Q4_K_M", None), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", None), + ("bartowski/Phi-3.5-mini-instruct-GGUF:Q4_K_M", "chatml"), + ( + "bartowski/functionary-small-v3.2-GGUF:Q8_0", + ("meetkai-functionary-medium-v3.2", None), + ), + ("bartowski/functionary-small-v3.2-GGUF:Q8_0", "chatml"), + # ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Meta-Llama-3.1-8B-Instruct-GGUF:Q4_K_M", "chatml"), + ( + "bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", + ("meta-llama-Llama-3.2-3B-Instruct", None), + ), + ("bartowski/Llama-3.2-1B-Instruct-GGUF:Q4_K_M", None), + ( + "bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", + ("meta-llama-Llama-3.2-3B-Instruct", None), + ), + ("bartowski/Llama-3.2-3B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", None), + ("bartowski/Qwen2.5-7B-Instruct-GGUF:Q4_K_M", "chatml"), + ( + "bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", + ("NousResearch/Hermes-2-Pro-Llama-3-8B", "tool_use"), + ), + ("bartowski/Hermes-2-Pro-Llama-3-8B-GGUF:Q4_K_M", "chatml"), + ( + "bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", + ("NousResearch-Hermes-3-Llama-3.1-8B", "tool_use"), + ), + ("bartowski/Hermes-3-Llama-3.1-8B-GGUF:Q4_K_M", "chatml"), + ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", None), + ("bartowski/Mistral-Nemo-Instruct-2407-GGUF:Q4_K_M", "chatml"), + ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", None), + ("bartowski/gemma-2-2b-it-GGUF:Q4_K_M", "chatml"), + ], +) +def test_hello_world( + hf_repo: str, + template_override: str | Tuple[str, str | None] | None, + stream: CompletionMode, +): global server - n_predict = 512 # High because of DeepSeek R1 + n_predict = 512 # High because of DeepSeek R1 server.jinja = True server.n_ctx = 8192 server.n_predict = n_predict @@ -592,32 +895,50 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non if isinstance(template_override, tuple): (template_hf_repo, template_variant) = template_override server.chat_template_file = f"../../../models/templates/{template_hf_repo.replace('/', '-') + ('-' + template_variant if template_variant else '')}.jinja" - assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." + assert os.path.exists( + server.chat_template_file + ), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." elif isinstance(template_override, str): server.chat_template = template_override server.start(timeout_seconds=TIMEOUT_SERVER_START) - do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict) + do_test_hello_world( + server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict + ) def do_test_hello_world(server: ServerProcess, **kwargs): - body = server.make_any_request("POST", "/v1/chat/completions", data={ - "messages": [ - {"role": "system", "content": "You are a tool-calling agent."}, - {"role": "user", "content": "say hello world with python"}, - ], - "tools": [PYTHON_TOOL], - **kwargs, - }, timeout=TIMEOUT_HTTP_REQUEST) + body = server.make_any_request( + "POST", + "/v1/chat/completions", + data={ + "messages": [ + {"role": "system", "content": "You are a tool-calling agent."}, + {"role": "user", "content": "say hello world with python"}, + ], + "tools": [PYTHON_TOOL], + **kwargs, + }, + timeout=TIMEOUT_HTTP_REQUEST, + ) choice = body["choices"][0] tool_calls = choice["message"].get("tool_calls") - assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}' + assert ( + tool_calls and len(tool_calls) == 1 + ), f'Expected 1 tool call in {choice["message"]}' tool_call = tool_calls[0] # assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' assert tool_call["function"]["name"] == PYTHON_TOOL["function"]["name"] # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}' actual_arguments = json.loads(tool_call["function"]["arguments"]) - assert 'code' in actual_arguments, f"code not found in {json.dumps(actual_arguments)}" + assert ( + "code" in actual_arguments + ), f"code not found in {json.dumps(actual_arguments)}" code = actual_arguments["code"] - assert isinstance(code, str), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}" - assert re.match(r'''print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)''', re.sub(r'#.*\n?', '', code)), f'Expected hello world, got {code}' + assert isinstance( + code, str + ), f"Expected code to be a string, got {type(code)}: {json.dumps(code)}" + assert re.match( + r"""print\(("[Hh]ello,? [Ww]orld!?"|'[Hh]ello,? [Ww]orld!?')\)""", + re.sub(r"#.*\n?", "", code), + ), f"Expected hello world, got {code}" diff --git a/smallthinker/tools/server/tests/unit/test_vision_api.py b/smallthinker/tools/server/tests/unit/test_vision_api.py index fc63caa1..20810909 100644 --- a/smallthinker/tools/server/tests/unit/test_vision_api.py +++ b/smallthinker/tools/server/tests/unit/test_vision_api.py @@ -5,12 +5,18 @@ server: ServerProcess -IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png" -IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png" +IMG_URL_0 = ( + "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png" +) +IMG_URL_1 = ( + "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png" +) response = requests.get(IMG_URL_0) -response.raise_for_status() # Raise an exception for bad status codes -IMG_BASE64_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8") +response.raise_for_status() # Raise an exception for bad status codes +IMG_BASE64_0 = "data:image/png;base64," + base64.b64encode(response.content).decode( + "utf-8" +) @pytest.fixture(autouse=True) @@ -23,33 +29,55 @@ def create_server(): "prompt, image_url, success, re_content", [ # test model is trained on CIFAR-10, but it's quite dumb due to small size - ("What is this:\n", IMG_URL_0, True, "(cat)+"), - ("What is this:\n", "IMG_BASE64_0", True, "(cat)+"), # exceptional, so that we don't cog up the log - ("What is this:\n", IMG_URL_1, True, "(frog)+"), - ("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache - ("What is this:\n", "malformed", False, None), - ("What is this:\n", "https://google.com/404", False, None), # non-existent image - ("What is this:\n", "https://ggml.ai", False, None), # non-image data + ("What is this:\n", IMG_URL_0, True, "(cat)+"), + ( + "What is this:\n", + "IMG_BASE64_0", + True, + "(cat)+", + ), # exceptional, so that we don't cog up the log + ("What is this:\n", IMG_URL_1, True, "(frog)+"), + ("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache + ("What is this:\n", "malformed", False, None), + ( + "What is this:\n", + "https://google.com/404", + False, + None, + ), # non-existent image + ("What is this:\n", "https://ggml.ai", False, None), # non-image data # TODO @ngxson : test with multiple images, no images and with audio - ] + ], ) def test_vision_chat_completion(prompt, image_url, success, re_content): global server - server.start(timeout_seconds=60) # vision model may take longer to load due to download size + server.start( + timeout_seconds=60 + ) # vision model may take longer to load due to download size if image_url == "IMG_BASE64_0": image_url = IMG_BASE64_0 - res = server.make_request("POST", "/chat/completions", data={ - "temperature": 0.0, - "top_k": 1, - "messages": [ - {"role": "user", "content": [ - {"type": "text", "text": prompt}, - {"type": "image_url", "image_url": { - "url": image_url, - }}, - ]}, - ], - }) + res = server.make_request( + "POST", + "/chat/completions", + data={ + "temperature": 0.0, + "top_k": 1, + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": prompt}, + { + "type": "image_url", + "image_url": { + "url": image_url, + }, + }, + ], + }, + ], + }, + ) if success: assert res.status_code == 200 choice = res.body["choices"][0] @@ -57,4 +85,3 @@ def test_vision_chat_completion(prompt, image_url, success, re_content): assert match_regex(re_content, choice["message"]["content"]) else: assert res.status_code != 200 - diff --git a/smallthinker/tools/server/tests/utils.py b/smallthinker/tools/server/tests/utils.py index bc547ca0..49079126 100644 --- a/smallthinker/tools/server/tests/utils.py +++ b/smallthinker/tools/server/tests/utils.py @@ -84,7 +84,7 @@ class ServerProcess: draft_max: int | None = None no_webui: bool | None = None jinja: bool | None = None - reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None + reasoning_format: Literal["deepseek", "none", "nothink"] | None = None reasoning_budget: int | None = None chat_template: str | None = None chat_template_file: str | None = None @@ -215,7 +215,11 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: creationflags=flags, stdout=sys.stdout, stderr=sys.stdout, - env={**os.environ, "LLAMA_CACHE": "tmp"} if "LLAMA_CACHE" not in os.environ else None, + env=( + {**os.environ, "LLAMA_CACHE": "tmp"} + if "LLAMA_CACHE" not in os.environ + else None + ), ) server_instances.add(self) @@ -225,9 +229,15 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: start_time = time.time() while time.time() - start_time < timeout_seconds: try: - response = self.make_request("GET", "/health", headers={ - "Authorization": f"Bearer {self.api_key}" if self.api_key else None - }) + response = self.make_request( + "GET", + "/health", + headers={ + "Authorization": ( + f"Bearer {self.api_key}" if self.api_key else None + ) + }, + ) if response.status_code == 200: self.ready = True return # server is ready @@ -235,7 +245,9 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None: pass # Check if process died if self.process.poll() is not None: - raise RuntimeError(f"Server process died with return code {self.process.returncode}") + raise RuntimeError( + f"Server process died with return code {self.process.returncode}" + ) print(f"Waiting for server to start...") time.sleep(0.5) @@ -290,9 +302,9 @@ def make_stream_request( raise ValueError(f"Unimplemented method: {method}") for line_bytes in response.iter_lines(): line = line_bytes.decode("utf-8") - if '[DONE]' in line: + if "[DONE]" in line: break - elif line.startswith('data: '): + elif line.startswith("data: "): data = json.loads(line[6:]) print("Partial response from server", json.dumps(data, indent=2)) yield data @@ -305,7 +317,7 @@ def make_any_request( headers: dict | None = None, timeout: float | None = None, ) -> dict: - stream = data.get('stream', False) + stream = data.get("stream", False) if stream: content: list[str] = [] reasoning_content: list[str] = [] @@ -318,56 +330,75 @@ def make_any_request( arguments_parts = 0 for chunk in self.make_stream_request(method, path, data, headers): - assert len(chunk['choices']) == 1, f'Expected 1 choice, got {len(chunk["choices"])}' - choice = chunk['choices'][0] - if choice['delta'].get('content') is not None: - assert len(choice['delta']['content']) > 0, f'Expected non empty content delta!' - content.append(choice['delta']['content']) + assert ( + len(chunk["choices"]) == 1 + ), f'Expected 1 choice, got {len(chunk["choices"])}' + choice = chunk["choices"][0] + if choice["delta"].get("content") is not None: + assert ( + len(choice["delta"]["content"]) > 0 + ), f"Expected non empty content delta!" + content.append(choice["delta"]["content"]) content_parts += 1 - if choice['delta'].get('reasoning_content') is not None: - assert len(choice['delta']['reasoning_content']) > 0, f'Expected non empty reasoning_content delta!' - reasoning_content.append(choice['delta']['reasoning_content']) + if choice["delta"].get("reasoning_content") is not None: + assert ( + len(choice["delta"]["reasoning_content"]) > 0 + ), f"Expected non empty reasoning_content delta!" + reasoning_content.append(choice["delta"]["reasoning_content"]) reasoning_content_parts += 1 - if choice['delta'].get('finish_reason') is not None: - finish_reason = choice['delta']['finish_reason'] - for tc in choice['delta'].get('tool_calls', []): - if 'function' not in tc: + if choice["delta"].get("finish_reason") is not None: + finish_reason = choice["delta"]["finish_reason"] + for tc in choice["delta"].get("tool_calls", []): + if "function" not in tc: raise ValueError(f"Expected function type, got {tc['type']}") - if tc['index'] >= len(tool_calls): - assert 'id' in tc - assert tc.get('type') == 'function' - assert 'function' in tc and 'name' in tc['function'] and len(tc['function']['name']) > 0, \ - f"Expected function call with name, got {tc.get('function')}" - tool_calls.append(dict( - id="", - type="function", - function=dict( - name="", - arguments="", + if tc["index"] >= len(tool_calls): + assert "id" in tc + assert tc.get("type") == "function" + assert ( + "function" in tc + and "name" in tc["function"] + and len(tc["function"]["name"]) > 0 + ), f"Expected function call with name, got {tc.get('function')}" + tool_calls.append( + dict( + id="", + type="function", + function=dict( + name="", + arguments="", + ), ) - )) - tool_call = tool_calls[tc['index']] - if tc.get('id') is not None: - tool_call['id'] = tc['id'] - fct = tc['function'] - assert 'id' not in fct, f"Function call should not have id: {fct}" - if fct.get('name') is not None: - tool_call['function']['name'] = tool_call['function'].get('name', '') + fct['name'] - if fct.get('arguments') is not None: - tool_call['function']['arguments'] += fct['arguments'] + ) + tool_call = tool_calls[tc["index"]] + if tc.get("id") is not None: + tool_call["id"] = tc["id"] + fct = tc["function"] + assert "id" not in fct, f"Function call should not have id: {fct}" + if fct.get("name") is not None: + tool_call["function"]["name"] = ( + tool_call["function"].get("name", "") + fct["name"] + ) + if fct.get("arguments") is not None: + tool_call["function"]["arguments"] += fct["arguments"] arguments_parts += 1 tool_call_parts += 1 - print(f'Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts') + print( + f"Streamed response had {content_parts} content parts, {reasoning_content_parts} reasoning_content parts, {tool_call_parts} tool call parts incl. {arguments_parts} arguments parts" + ) result = dict( choices=[ dict( index=0, finish_reason=finish_reason, message=dict( - role='assistant', - content=''.join(content) if content else None, - reasoning_content=''.join(reasoning_content) if reasoning_content else None, + role="assistant", + content="".join(content) if content else None, + reasoning_content=( + "".join(reasoning_content) + if reasoning_content + else None + ), tool_calls=tool_calls if tool_calls else None, ), ) @@ -377,11 +408,12 @@ def make_any_request( return result else: response = self.make_request(method, path, data, headers, timeout=timeout) - assert response.status_code == 200, f"Server returned error: {response.status_code}" + assert ( + response.status_code == 200 + ), f"Server returned error: {response.status_code}" return response.body - server_instances: Set[ServerProcess] = set() @@ -485,7 +517,9 @@ def tinygemma3() -> ServerProcess: return server -def parallel_function_calls(function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]]) -> List[Any]: +def parallel_function_calls( + function_list: List[Tuple[Callable[..., Any], Tuple[Any, ...]]], +) -> List[Any]: """ Run multiple functions in parallel and return results in the same order as calls. Equivalent to Promise.all in JS. @@ -542,8 +576,8 @@ def download_file(url: str, output_file_path: str | None = None) -> str: Returns the local path of the downloaded file. """ - file_name = url.split('/').pop() - output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path + file_name = url.split("/").pop() + output_file = f"./tmp/{file_name}" if output_file_path is None else output_file_path if not os.path.exists(output_file): print(f"Downloading {url} to {output_file}") wget.download(url, out=output_file) diff --git a/smallthinker/tools/server/webui/src/utils/storage.ts b/smallthinker/tools/server/webui/src/utils/storage.ts index 505693e9..db7fe129 100644 --- a/smallthinker/tools/server/webui/src/utils/storage.ts +++ b/smallthinker/tools/server/webui/src/utils/storage.ts @@ -30,6 +30,28 @@ db.version(1).stores({ messages: '&id, convId, [convId+id], timestamp', }); +const parseStorageJson = ( + key: string, + fallback: T, + opts?: { clearOnInvalid?: boolean } +): T => { + const raw = localStorage.getItem(key); + if (!raw) return fallback; + + try { + return JSON.parse(raw) as T; + } catch (error) { + console.warn( + `[StorageUtils] Invalid JSON in localStorage key "${key}"`, + error + ); + if (opts?.clearOnInvalid) { + localStorage.removeItem(key); + } + return fallback; + } +}; + // convId is a string prefixed with 'conv-' const StorageUtils = { /** @@ -193,11 +215,27 @@ const StorageUtils = { // manage config getConfig(): typeof CONFIG_DEFAULT { - const savedVal = JSON.parse(localStorage.getItem('config') || '{}'); + const savedVal = parseStorageJson( + 'config', + {}, + { + clearOnInvalid: true, + } + ); + const isObjectConfig = + savedVal !== null && + typeof savedVal === 'object' && + !Array.isArray(savedVal); + if (!isObjectConfig) { + console.warn( + '[StorageUtils] Config value is not an object. Resetting localStorage key "config".' + ); + localStorage.removeItem('config'); + } // to prevent breaking changes in the future, we always provide default value for missing keys return { ...CONFIG_DEFAULT, - ...savedVal, + ...(isObjectConfig ? (savedVal as Partial) : {}), }; }, setConfig(config: typeof CONFIG_DEFAULT) { @@ -231,12 +269,51 @@ interface LSMessage { content: string; timings?: TimingReport; } + +const isLSMessage = (value: unknown): value is LSMessage => { + if (!value || typeof value !== 'object') return false; + const msg = value as Partial; + return ( + typeof msg.id === 'number' && + (msg.role === 'user' || + msg.role === 'assistant' || + msg.role === 'system') && + typeof msg.content === 'string' + ); +}; + +const isLSConversation = (value: unknown): value is LSConversation => { + if (!value || typeof value !== 'object') return false; + const conv = value as Partial; + if ( + typeof conv.id !== 'string' || + typeof conv.lastModified !== 'number' || + !Array.isArray(conv.messages) + ) { + return false; + } + return conv.messages.every(isLSMessage); +}; + async function migrationLStoIDB() { if (localStorage.getItem('migratedToIDB')) return; const res: LSConversation[] = []; for (const key in localStorage) { if (key.startsWith('conv-')) { - res.push(JSON.parse(localStorage.getItem(key) ?? '{}')); + const parsed = parseStorageJson(key, null, { + clearOnInvalid: true, + }); + if (parsed === null) continue; + + if (!isLSConversation(parsed)) { + console.warn( + `[StorageUtils] Invalid conversation payload in key "${key}". Removing key.` + ); + localStorage.removeItem(key); + continue; + } + + res.push(parsed); } } if (res.length === 0) return; diff --git a/smallthinker/tools/tts/convert_pt_to_hf.py b/smallthinker/tools/tts/convert_pt_to_hf.py index 8909a65f..ab4c6fab 100644 --- a/smallthinker/tools/tts/convert_pt_to_hf.py +++ b/smallthinker/tools/tts/convert_pt_to_hf.py @@ -12,7 +12,7 @@ from safetensors.torch import save_file # default -model_path = './model.pt'; +model_path = "./model.pt" # read from CLI if len(sys.argv) > 1: @@ -23,18 +23,18 @@ print(f"Loading model from {model_path}") -model = torch.load(model_path, map_location='cpu') +model = torch.load(model_path, map_location="cpu") -#print(model) +# print(model) # print all keys for key in model.keys(): print(key) - if key == 'hyper_parameters': - #print(model[key]) + if key == "hyper_parameters": + # print(model[key]) # dump as json pretty print(json.dumps(model[key], indent=4)) - #if key != 'state_dict' and key != 'optimizer_states': + # if key != 'state_dict' and key != 'optimizer_states': # print(model[key]) # Check if the loaded model is a state_dict or a model instance @@ -48,8 +48,9 @@ for key in state_dict.keys(): print(key) + # Ensure the state_dict is flat and contains only torch.Tensor objects -def flatten_state_dict(state_dict, parent_key='', sep='.'): +def flatten_state_dict(state_dict, parent_key="", sep="."): items = [] items_new = [] @@ -65,22 +66,24 @@ def flatten_state_dict(state_dict, parent_key='', sep='.'): for key, value in list(items): # keep only what we need for inference - if not key.startswith('state_dict.feature_extractor.encodec.quantizer.') and \ - not key.startswith('state_dict.backbone.') and \ - not key.startswith('state_dict.head.out'): - print('Skipping key: ', key) - continue + if ( + not key.startswith("state_dict.feature_extractor.encodec.quantizer.") + and not key.startswith("state_dict.backbone.") + and not key.startswith("state_dict.head.out") + ): + print("Skipping key: ", key) + continue new_key = key - new_key = new_key.replace('state_dict.', '') - new_key = new_key.replace('pos_net', 'posnet') + new_key = new_key.replace("state_dict.", "") + new_key = new_key.replace("pos_net", "posnet") # check if matches "backbone.posnet.%d.bias" or "backbone.posnet.%d.weight" if new_key.startswith("backbone.posnet."): match = re.match(r"backbone\.posnet\.(\d+)\.(bias|weight)", new_key) if match: - new_key = f"backbone.posnet.{match.group(1)}.norm.{match.group(2)}" + new_key = f"backbone.posnet.{match.group(1)}.norm.{match.group(2)}" # "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed" -> "backbone.embedding.weight" if new_key == "feature_extractor.encodec.quantizer.vq.layers.0._codebook.embed": @@ -100,7 +103,15 @@ def flatten_state_dict(state_dict, parent_key='', sep='.'): new_key = new_key.replace("gamma", "gamma.weight") # convert from 1D [768] to 2D [768, 1] so that ggml_add can broadcast the bias - if (new_key.endswith("norm.weight") or new_key.endswith("norm1.weight") or new_key.endswith("norm2.weight") or new_key.endswith(".bias")) and (new_key.startswith("backbone.posnet") or new_key.startswith("backbone.embed.bias")): + if ( + new_key.endswith("norm.weight") + or new_key.endswith("norm1.weight") + or new_key.endswith("norm2.weight") + or new_key.endswith(".bias") + ) and ( + new_key.startswith("backbone.posnet") + or new_key.startswith("backbone.embed.bias") + ): value = value.unsqueeze(1) if new_key.endswith("dwconv.bias"): @@ -111,8 +122,8 @@ def flatten_state_dict(state_dict, parent_key='', sep='.'): size_total_mb += size_mb - #print(key, '->', new_key, ': ', value) - #print(key, '->', new_key) + # print(key, '->', new_key, ': ', value) + # print(key, '->', new_key) items_new.append((new_key, value)) @@ -120,11 +131,12 @@ def flatten_state_dict(state_dict, parent_key='', sep='.'): return dict(items_new) + flattened_state_dict = flatten_state_dict(state_dict) # Convert the model to the safetensors format -output_path = path_dst + '/model.safetensors' +output_path = path_dst + "/model.safetensors" save_file(flattened_state_dict, output_path) print(f"Model has been successfully converted and saved to {output_path}") @@ -133,27 +145,20 @@ def flatten_state_dict(state_dict, parent_key='', sep='.'): total_size = os.path.getsize(output_path) # Create the weight map -weight_map = { - "model.safetensors": ["*"] # Assuming all weights are in one file -} +weight_map = {"model.safetensors": ["*"]} # Assuming all weights are in one file # Create metadata for the index.json file -metadata = { - "total_size": total_size, - "weight_map": weight_map -} +metadata = {"total_size": total_size, "weight_map": weight_map} # Save the metadata to index.json -index_path = path_dst + '/index.json' -with open(index_path, 'w') as f: +index_path = path_dst + "/index.json" +with open(index_path, "w") as f: json.dump(metadata, f, indent=4) print(f"Metadata has been saved to {index_path}") config = { - "architectures": [ - "WavTokenizerDec" - ], + "architectures": ["WavTokenizerDec"], "hidden_size": 1282, "n_embd_features": 512, "n_ff": 2304, @@ -162,19 +167,13 @@ def flatten_state_dict(state_dict, parent_key='', sep='.'): "layer_norm_epsilon": 1e-6, "group_norm_epsilon": 1e-6, "group_norm_groups": 32, - "max_position_embeddings": 8192, # ? + "max_position_embeddings": 8192, # ? "n_layer": 12, - "posnet": { - "n_embd": 768, - "n_layer": 6 - }, - "convnext": { - "n_embd": 768, - "n_layer": 12 - }, + "posnet": {"n_embd": 768, "n_layer": 6}, + "convnext": {"n_embd": 768, "n_layer": 12}, } -with open(path_dst + '/config.json', 'w') as f: +with open(path_dst + "/config.json", "w") as f: json.dump(config, f, indent=4) print(f"Config has been saved to {path_dst + 'config.json'}") diff --git a/smallthinker/tools/tts/tts-outetts.py b/smallthinker/tools/tts/tts-outetts.py index 3791f9fc..023347d5 100644 --- a/smallthinker/tools/tts/tts-outetts.py +++ b/smallthinker/tools/tts/tts-outetts.py @@ -1,6 +1,7 @@ import sys -#import json -#import struct + +# import json +# import struct import requests import re import struct @@ -25,7 +26,7 @@ def fold(buffer, n_out, n_win, n_hop, n_pad): for i in range(n_frames): start = i * n_hop end = start + n_win - result[start:end] += buffer[i * n_win:(i + 1) * n_win] + result[start:end] += buffer[i * n_win : (i + 1) * n_win] return result[n_pad:-n_pad] if n_pad > 0 else result @@ -73,8 +74,8 @@ def embd_to_audio(embd, n_codes, n_embd, n_thread=4): results = list(executor.map(process_frame, args)) for l, (frame, hann2) in enumerate(results): - res[l*n_fft:(l+1)*n_fft] = frame - hann2_buffer[l*n_fft:(l+1)*n_fft] = hann2 + res[l * n_fft : (l + 1) * n_fft] = frame + hann2_buffer[l * n_fft : (l + 1) * n_fft] = hann2 audio = fold(res, n_out, n_win, n_hop, n_pad) env = fold(hann2_buffer, n_out, n_win, n_hop, n_pad) @@ -95,42 +96,47 @@ def save_wav(filename, audio_data, sample_rate): chunk_size = 36 + data_size # 36 = size of header minus first 8 bytes header = struct.pack( - '<4sI4s4sIHHIIHH4sI', - b'RIFF', + "<4sI4s4sIHHIIHH4sI", + b"RIFF", chunk_size, - b'WAVE', - b'fmt ', - 16, # fmt chunk size - 1, # audio format (PCM) + b"WAVE", + b"fmt ", + 16, # fmt chunk size + 1, # audio format (PCM) num_channels, sample_rate, byte_rate, block_align, bits_per_sample, - b'data', - data_size + b"data", + data_size, ) audio_data = np.clip(audio_data * 32767, -32768, 32767) pcm_data = audio_data.astype(np.int16) - with open(filename, 'wb') as f: + with open(filename, "wb") as f: f.write(header) f.write(pcm_data.tobytes()) def process_text(text: str): - text = re.sub(r'\d+(\.\d+)?', lambda x: x.group(), text.lower()) # TODO this needs to be fixed - text = re.sub(r'[-_/,\.\\]', ' ', text) - text = re.sub(r'[^a-z\s]', '', text) - text = re.sub(r'\s+', ' ', text).strip() + text = re.sub( + r"\d+(\.\d+)?", lambda x: x.group(), text.lower() + ) # TODO this needs to be fixed + text = re.sub(r"[-_/,\.\\]", " ", text) + text = re.sub(r"[^a-z\s]", "", text) + text = re.sub(r"\s+", " ", text).strip() return text.split() + # usage: # python tts-outetts.py http://server-llm:port http://server-dec:port "text" if len(sys.argv) <= 3: - print("usage: python tts-outetts.py http://server-llm:port http://server-dec:port \"text\"") + print( + 'usage: python tts-outetts.py http://server-llm:port http://server-dec:port "text"' + ) exit(1) host_llm = sys.argv[1] @@ -146,100 +152,857 @@ def process_text(text: str): # voice data # TODO: load from json -#suffix = """<|audio_start|> -#the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|> -#overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|> -#package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|> -#from<|t_0.19|><|code_start|><|604|><|782|><|1682|><|872|><|1532|><|1600|><|1036|><|1761|><|647|><|1554|><|1371|><|653|><|1595|><|950|><|code_end|> -#just<|t_0.25|><|code_start|><|1782|><|1670|><|317|><|786|><|1748|><|631|><|599|><|1155|><|1364|><|1524|><|36|><|1591|><|889|><|1535|><|541|><|440|><|1532|><|50|><|870|><|code_end|> -#two<|t_0.24|><|code_start|><|1681|><|1510|><|673|><|799|><|805|><|1342|><|330|><|519|><|62|><|640|><|1138|><|565|><|1552|><|1497|><|1552|><|572|><|1715|><|1732|><|code_end|> -#people<|t_0.39|><|code_start|><|593|><|274|><|136|><|740|><|691|><|633|><|1484|><|1061|><|1138|><|1485|><|344|><|428|><|397|><|1562|><|645|><|917|><|1035|><|1449|><|1669|><|487|><|442|><|1484|><|1329|><|1832|><|1704|><|600|><|761|><|653|><|269|><|code_end|> -#is<|t_0.16|><|code_start|><|566|><|583|><|1755|><|646|><|1337|><|709|><|802|><|1008|><|485|><|1583|><|652|><|10|><|code_end|> -#pretty<|t_0.32|><|code_start|><|1818|><|1747|><|692|><|733|><|1010|><|534|><|406|><|1697|><|1053|><|1521|><|1355|><|1274|><|816|><|1398|><|211|><|1218|><|817|><|1472|><|1703|><|686|><|13|><|822|><|445|><|1068|><|code_end|> -#remarkable<|t_0.68|><|code_start|><|230|><|1048|><|1705|><|355|><|706|><|1149|><|1535|><|1787|><|1356|><|1396|><|835|><|1583|><|486|><|1249|><|286|><|937|><|1076|><|1150|><|614|><|42|><|1058|><|705|><|681|><|798|><|934|><|490|><|514|><|1399|><|572|><|1446|><|1703|><|1346|><|1040|><|1426|><|1304|><|664|><|171|><|1530|><|625|><|64|><|1708|><|1830|><|1030|><|443|><|1509|><|1063|><|1605|><|1785|><|721|><|1440|><|923|><|code_end|> -#sure<|t_0.36|><|code_start|><|792|><|1780|><|923|><|1640|><|265|><|261|><|1525|><|567|><|1491|><|1250|><|1730|><|362|><|919|><|1766|><|543|><|1|><|333|><|113|><|970|><|252|><|1606|><|133|><|302|><|1810|><|1046|><|1190|><|1675|><|code_end|> -#i<|t_0.08|><|code_start|><|123|><|439|><|1074|><|705|><|1799|><|637|><|code_end|> -#have<|t_0.16|><|code_start|><|1509|><|599|><|518|><|1170|><|552|><|1029|><|1267|><|864|><|419|><|143|><|1061|><|0|><|code_end|> -#some<|t_0.16|><|code_start|><|619|><|400|><|1270|><|62|><|1370|><|1832|><|917|><|1661|><|167|><|269|><|1366|><|1508|><|code_end|> -#critiques<|t_0.60|><|code_start|><|559|><|584|><|1163|><|1129|><|1313|><|1728|><|721|><|1146|><|1093|><|577|><|928|><|27|><|630|><|1080|><|1346|><|1337|><|320|><|1382|><|1175|><|1682|><|1556|><|990|><|1683|><|860|><|1721|><|110|><|786|><|376|><|1085|><|756|><|1523|><|234|><|1334|><|1506|><|1578|><|659|><|612|><|1108|><|1466|><|1647|><|308|><|1470|><|746|><|556|><|1061|><|code_end|> -#about<|t_0.29|><|code_start|><|26|><|1649|><|545|><|1367|><|1263|><|1728|><|450|><|859|><|1434|><|497|><|1220|><|1285|><|179|><|755|><|1154|><|779|><|179|><|1229|><|1213|><|922|><|1774|><|1408|><|code_end|> -#some<|t_0.23|><|code_start|><|986|><|28|><|1649|><|778|><|858|><|1519|><|1|><|18|><|26|><|1042|><|1174|><|1309|><|1499|><|1712|><|1692|><|1516|><|1574|><|code_end|> -#of<|t_0.07|><|code_start|><|197|><|716|><|1039|><|1662|><|64|><|code_end|> -#the<|t_0.08|><|code_start|><|1811|><|1568|><|569|><|886|><|1025|><|1374|><|code_end|> -#gameplay<|t_0.48|><|code_start|><|1269|><|1092|><|933|><|1362|><|1762|><|1700|><|1675|><|215|><|781|><|1086|><|461|><|838|><|1022|><|759|><|649|><|1416|><|1004|><|551|><|909|><|787|><|343|><|830|><|1391|><|1040|><|1622|><|1779|><|1360|><|1231|><|1187|><|1317|><|76|><|997|><|989|><|978|><|737|><|189|><|code_end|> -#aspects<|t_0.56|><|code_start|><|1423|><|797|><|1316|><|1222|><|147|><|719|><|1347|><|386|><|1390|><|1558|><|154|><|440|><|634|><|592|><|1097|><|1718|><|712|><|763|><|1118|><|1721|><|1311|><|868|><|580|><|362|><|1435|><|868|><|247|><|221|><|886|><|1145|><|1274|><|1284|><|457|><|1043|><|1459|><|1818|><|62|><|599|><|1035|><|62|><|1649|><|778|><|code_end|> -#but<|t_0.20|><|code_start|><|780|><|1825|><|1681|><|1007|><|861|><|710|><|702|><|939|><|1669|><|1491|><|613|><|1739|><|823|><|1469|><|648|><|code_end|> -#its<|t_0.09|><|code_start|><|92|><|688|><|1623|><|962|><|1670|><|527|><|599|><|code_end|> -#still<|t_0.27|><|code_start|><|636|><|10|><|1217|><|344|><|713|><|957|><|823|><|154|><|1649|><|1286|><|508|><|214|><|1760|><|1250|><|456|><|1352|><|1368|><|921|><|615|><|5|><|code_end|> -#really<|t_0.36|><|code_start|><|55|><|420|><|1008|><|1659|><|27|><|644|><|1266|><|617|><|761|><|1712|><|109|><|1465|><|1587|><|503|><|1541|><|619|><|197|><|1019|><|817|><|269|><|377|><|362|><|1381|><|507|><|1488|><|4|><|1695|><|code_end|> -#enjoyable<|t_0.49|><|code_start|><|678|><|501|><|864|><|319|><|288|><|1472|><|1341|><|686|><|562|><|1463|><|619|><|1563|><|471|><|911|><|730|><|1811|><|1006|><|520|><|861|><|1274|><|125|><|1431|><|638|><|621|><|153|><|876|><|1770|><|437|><|987|><|1653|><|1109|><|898|><|1285|><|80|><|593|><|1709|><|843|><|code_end|> -#and<|t_0.15|><|code_start|><|1285|><|987|><|303|><|1037|><|730|><|1164|><|502|><|120|><|1737|><|1655|><|1318|><|code_end|> -#it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><|code_end|> -#looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|> -#lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>""" +# suffix = """<|audio_start|> +# the<|t_0.08|><|code_start|><|257|><|740|><|636|><|913|><|788|><|1703|><|code_end|> +# overall<|t_0.36|><|code_start|><|127|><|201|><|191|><|774|><|700|><|532|><|1056|><|557|><|798|><|298|><|1741|><|747|><|1662|><|1617|><|1702|><|1527|><|368|><|1588|><|1049|><|1008|><|1625|><|747|><|1576|><|728|><|1019|><|1696|><|1765|><|code_end|> +# package<|t_0.56|><|code_start|><|935|><|584|><|1319|><|627|><|1016|><|1491|><|1344|><|1117|><|1526|><|1040|><|239|><|1435|><|951|><|498|><|723|><|1180|><|535|><|789|><|1649|><|1637|><|78|><|465|><|1668|><|901|><|595|><|1675|><|117|><|1009|><|1667|><|320|><|840|><|79|><|507|><|1762|><|1508|><|1228|><|1768|><|802|><|1450|><|1457|><|232|><|639|><|code_end|> +# from<|t_0.19|><|code_start|><|604|><|782|><|1682|><|872|><|1532|><|1600|><|1036|><|1761|><|647|><|1554|><|1371|><|653|><|1595|><|950|><|code_end|> +# just<|t_0.25|><|code_start|><|1782|><|1670|><|317|><|786|><|1748|><|631|><|599|><|1155|><|1364|><|1524|><|36|><|1591|><|889|><|1535|><|541|><|440|><|1532|><|50|><|870|><|code_end|> +# two<|t_0.24|><|code_start|><|1681|><|1510|><|673|><|799|><|805|><|1342|><|330|><|519|><|62|><|640|><|1138|><|565|><|1552|><|1497|><|1552|><|572|><|1715|><|1732|><|code_end|> +# people<|t_0.39|><|code_start|><|593|><|274|><|136|><|740|><|691|><|633|><|1484|><|1061|><|1138|><|1485|><|344|><|428|><|397|><|1562|><|645|><|917|><|1035|><|1449|><|1669|><|487|><|442|><|1484|><|1329|><|1832|><|1704|><|600|><|761|><|653|><|269|><|code_end|> +# is<|t_0.16|><|code_start|><|566|><|583|><|1755|><|646|><|1337|><|709|><|802|><|1008|><|485|><|1583|><|652|><|10|><|code_end|> +# pretty<|t_0.32|><|code_start|><|1818|><|1747|><|692|><|733|><|1010|><|534|><|406|><|1697|><|1053|><|1521|><|1355|><|1274|><|816|><|1398|><|211|><|1218|><|817|><|1472|><|1703|><|686|><|13|><|822|><|445|><|1068|><|code_end|> +# remarkable<|t_0.68|><|code_start|><|230|><|1048|><|1705|><|355|><|706|><|1149|><|1535|><|1787|><|1356|><|1396|><|835|><|1583|><|486|><|1249|><|286|><|937|><|1076|><|1150|><|614|><|42|><|1058|><|705|><|681|><|798|><|934|><|490|><|514|><|1399|><|572|><|1446|><|1703|><|1346|><|1040|><|1426|><|1304|><|664|><|171|><|1530|><|625|><|64|><|1708|><|1830|><|1030|><|443|><|1509|><|1063|><|1605|><|1785|><|721|><|1440|><|923|><|code_end|> +# sure<|t_0.36|><|code_start|><|792|><|1780|><|923|><|1640|><|265|><|261|><|1525|><|567|><|1491|><|1250|><|1730|><|362|><|919|><|1766|><|543|><|1|><|333|><|113|><|970|><|252|><|1606|><|133|><|302|><|1810|><|1046|><|1190|><|1675|><|code_end|> +# i<|t_0.08|><|code_start|><|123|><|439|><|1074|><|705|><|1799|><|637|><|code_end|> +# have<|t_0.16|><|code_start|><|1509|><|599|><|518|><|1170|><|552|><|1029|><|1267|><|864|><|419|><|143|><|1061|><|0|><|code_end|> +# some<|t_0.16|><|code_start|><|619|><|400|><|1270|><|62|><|1370|><|1832|><|917|><|1661|><|167|><|269|><|1366|><|1508|><|code_end|> +# critiques<|t_0.60|><|code_start|><|559|><|584|><|1163|><|1129|><|1313|><|1728|><|721|><|1146|><|1093|><|577|><|928|><|27|><|630|><|1080|><|1346|><|1337|><|320|><|1382|><|1175|><|1682|><|1556|><|990|><|1683|><|860|><|1721|><|110|><|786|><|376|><|1085|><|756|><|1523|><|234|><|1334|><|1506|><|1578|><|659|><|612|><|1108|><|1466|><|1647|><|308|><|1470|><|746|><|556|><|1061|><|code_end|> +# about<|t_0.29|><|code_start|><|26|><|1649|><|545|><|1367|><|1263|><|1728|><|450|><|859|><|1434|><|497|><|1220|><|1285|><|179|><|755|><|1154|><|779|><|179|><|1229|><|1213|><|922|><|1774|><|1408|><|code_end|> +# some<|t_0.23|><|code_start|><|986|><|28|><|1649|><|778|><|858|><|1519|><|1|><|18|><|26|><|1042|><|1174|><|1309|><|1499|><|1712|><|1692|><|1516|><|1574|><|code_end|> +# of<|t_0.07|><|code_start|><|197|><|716|><|1039|><|1662|><|64|><|code_end|> +# the<|t_0.08|><|code_start|><|1811|><|1568|><|569|><|886|><|1025|><|1374|><|code_end|> +# gameplay<|t_0.48|><|code_start|><|1269|><|1092|><|933|><|1362|><|1762|><|1700|><|1675|><|215|><|781|><|1086|><|461|><|838|><|1022|><|759|><|649|><|1416|><|1004|><|551|><|909|><|787|><|343|><|830|><|1391|><|1040|><|1622|><|1779|><|1360|><|1231|><|1187|><|1317|><|76|><|997|><|989|><|978|><|737|><|189|><|code_end|> +# aspects<|t_0.56|><|code_start|><|1423|><|797|><|1316|><|1222|><|147|><|719|><|1347|><|386|><|1390|><|1558|><|154|><|440|><|634|><|592|><|1097|><|1718|><|712|><|763|><|1118|><|1721|><|1311|><|868|><|580|><|362|><|1435|><|868|><|247|><|221|><|886|><|1145|><|1274|><|1284|><|457|><|1043|><|1459|><|1818|><|62|><|599|><|1035|><|62|><|1649|><|778|><|code_end|> +# but<|t_0.20|><|code_start|><|780|><|1825|><|1681|><|1007|><|861|><|710|><|702|><|939|><|1669|><|1491|><|613|><|1739|><|823|><|1469|><|648|><|code_end|> +# its<|t_0.09|><|code_start|><|92|><|688|><|1623|><|962|><|1670|><|527|><|599|><|code_end|> +# still<|t_0.27|><|code_start|><|636|><|10|><|1217|><|344|><|713|><|957|><|823|><|154|><|1649|><|1286|><|508|><|214|><|1760|><|1250|><|456|><|1352|><|1368|><|921|><|615|><|5|><|code_end|> +# really<|t_0.36|><|code_start|><|55|><|420|><|1008|><|1659|><|27|><|644|><|1266|><|617|><|761|><|1712|><|109|><|1465|><|1587|><|503|><|1541|><|619|><|197|><|1019|><|817|><|269|><|377|><|362|><|1381|><|507|><|1488|><|4|><|1695|><|code_end|> +# enjoyable<|t_0.49|><|code_start|><|678|><|501|><|864|><|319|><|288|><|1472|><|1341|><|686|><|562|><|1463|><|619|><|1563|><|471|><|911|><|730|><|1811|><|1006|><|520|><|861|><|1274|><|125|><|1431|><|638|><|621|><|153|><|876|><|1770|><|437|><|987|><|1653|><|1109|><|898|><|1285|><|80|><|593|><|1709|><|843|><|code_end|> +# and<|t_0.15|><|code_start|><|1285|><|987|><|303|><|1037|><|730|><|1164|><|502|><|120|><|1737|><|1655|><|1318|><|code_end|> +# it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><|code_end|> +# looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|> +# lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>""" # TODO: tokenization is slow for some reason - here is pre-tokenized input -suffix = [ 151667, 198, 1782, 155780, 151669, 151929, 152412, 152308, 152585, 152460, 153375, 151670, 198, 74455, - 155808, 151669, 151799, 151873, 151863, 152446, 152372, 152204, 152728, 152229, 152470, 151970, 153413, - 152419, 153334, 153289, 153374, 153199, 152040, 153260, 152721, 152680, 153297, 152419, 153248, 152400, - 152691, 153368, 153437, 151670, 198, 1722, 155828, 151669, 152607, 152256, 152991, 152299, 152688, 153163, - 153016, 152789, 153198, 152712, 151911, 153107, 152623, 152170, 152395, 152852, 152207, 152461, 153321, - 153309, 151750, 152137, 153340, 152573, 152267, 153347, 151789, 152681, 153339, 151992, 152512, 151751, - 152179, 153434, 153180, 152900, 153440, 152474, 153122, 153129, 151904, 152311, 151670, 198, 1499, 155791, - 151669, 152276, 152454, 153354, 152544, 153204, 153272, 152708, 153433, 152319, 153226, 153043, 152325, - 153267, 152622, 151670, 198, 4250, 155797, 151669, 153454, 153342, 151989, 152458, 153420, 152303, 152271, - 152827, 153036, 153196, 151708, 153263, 152561, 153207, 152213, 152112, 153204, 151722, 152542, 151670, 198, - 19789, 155796, 151669, 153353, 153182, 152345, 152471, 152477, 153014, 152002, 152191, 151734, 152312, 152810, - 152237, 153224, 153169, 153224, 152244, 153387, 153404, 151670, 198, 16069, 155811, 151669, 152265, 151946, - 151808, 152412, 152363, 152305, 153156, 152733, 152810, 153157, 152016, 152100, 152069, 153234, 152317, - 152589, 152707, 153121, 153341, 152159, 152114, 153156, 153001, 153504, 153376, 152272, 152433, 152325, - 151941, 151670, 198, 285, 155788, 151669, 152238, 152255, 153427, 152318, 153009, 152381, 152474, 152680, - 152157, 153255, 152324, 151682, 151670, 198, 32955, 155804, 151669, 153490, 153419, 152364, 152405, 152682, - 152206, 152078, 153369, 152725, 153193, 153027, 152946, 152488, 153070, 151883, 152890, 152489, 153144, - 153375, 152358, 151685, 152494, 152117, 152740, 151670, 198, 37448, 480, 155840, 151669, 151902, 152720, - 153377, 152027, 152378, 152821, 153207, 153459, 153028, 153068, 152507, 153255, 152158, 152921, 151958, - 152609, 152748, 152822, 152286, 151714, 152730, 152377, 152353, 152470, 152606, 152162, 152186, 153071, - 152244, 153118, 153375, 153018, 152712, 153098, 152976, 152336, 151843, 153202, 152297, 151736, 153380, - 153502, 152702, 152115, 153181, 152735, 153277, 153457, 152393, 153112, 152595, 151670, 198, 19098, 155808, - 151669, 152464, 153452, 152595, 153312, 151937, 151933, 153197, 152239, 153163, 152922, 153402, 152034, - 152591, 153438, 152215, 151673, 152005, 151785, 152642, 151924, 153278, 151805, 151974, 153482, 152718, - 152862, 153347, 151670, 198, 72, 155780, 151669, 151795, 152111, 152746, 152377, 153471, 152309, 151670, 198, - 19016, 155788, 151669, 153181, 152271, 152190, 152842, 152224, 152701, 152939, 152536, 152091, 151815, 152733, - 151672, 151670, 198, 14689, 155788, 151669, 152291, 152072, 152942, 151734, 153042, 153504, 152589, 153333, - 151839, 151941, 153038, 153180, 151670, 198, 36996, 8303, 155832, 151669, 152231, 152256, 152835, 152801, - 152985, 153400, 152393, 152818, 152765, 152249, 152600, 151699, 152302, 152752, 153018, 153009, 151992, - 153054, 152847, 153354, 153228, 152662, 153355, 152532, 153393, 151782, 152458, 152048, 152757, 152428, - 153195, 151906, 153006, 153178, 153250, 152331, 152284, 152780, 153138, 153319, 151980, 153142, 152418, - 152228, 152733, 151670, 198, 9096, 155801, 151669, 151698, 153321, 152217, 153039, 152935, 153400, 152122, - 152531, 153106, 152169, 152892, 152957, 151851, 152427, 152826, 152451, 151851, 152901, 152885, 152594, - 153446, 153080, 151670, 198, 14689, 155795, 151669, 152658, 151700, 153321, 152450, 152530, 153191, 151673, - 151690, 151698, 152714, 152846, 152981, 153171, 153384, 153364, 153188, 153246, 151670, 198, 1055, 155779, - 151669, 151869, 152388, 152711, 153334, 151736, 151670, 198, 1782, 155780, 151669, 153483, 153240, 152241, - 152558, 152697, 153046, 151670, 198, 5804, 1363, 155820, 151669, 152941, 152764, 152605, 153034, 153434, - 153372, 153347, 151887, 152453, 152758, 152133, 152510, 152694, 152431, 152321, 153088, 152676, 152223, - 152581, 152459, 152015, 152502, 153063, 152712, 153294, 153451, 153032, 152903, 152859, 152989, 151748, - 152669, 152661, 152650, 152409, 151861, 151670, 198, 300, 7973, 155828, 151669, 153095, 152469, 152988, - 152894, 151819, 152391, 153019, 152058, 153062, 153230, 151826, 152112, 152306, 152264, 152769, 153390, - 152384, 152435, 152790, 153393, 152983, 152540, 152252, 152034, 153107, 152540, 151919, 151893, 152558, - 152817, 152946, 152956, 152129, 152715, 153131, 153490, 151734, 152271, 152707, 151734, 153321, 152450, - 151670, 198, 8088, 155792, 151669, 152452, 153497, 153353, 152679, 152533, 152382, 152374, 152611, 153341, - 153163, 152285, 153411, 152495, 153141, 152320, 151670, 198, 1199, 155781, 151669, 151764, 152360, 153295, - 152634, 153342, 152199, 152271, 151670, 198, 43366, 155799, 151669, 152308, 151682, 152889, 152016, 152385, - 152629, 152495, 151826, 153321, 152958, 152180, 151886, 153432, 152922, 152128, 153024, 153040, 152593, - 152287, 151677, 151670, 198, 53660, 155808, 151669, 151727, 152092, 152680, 153331, 151699, 152316, 152938, - 152289, 152433, 153384, 151781, 153137, 153259, 152175, 153213, 152291, 151869, 152691, 152489, 151941, - 152049, 152034, 153053, 152179, 153160, 151676, 153367, 151670, 198, 268, 4123, 480, 155821, 151669, 152350, - 152173, 152536, 151991, 151960, 153144, 153013, 152358, 152234, 153135, 152291, 153235, 152143, 152583, - 152402, 153483, 152678, 152192, 152533, 152946, 151797, 153103, 152310, 152293, 151825, 152548, 153442, - 152109, 152659, 153325, 152781, 152570, 152957, 151752, 152265, 153381, 152515, 151670, 198, 437, 155787, - 151669, 152957, 152659, 151975, 152709, 152402, 152836, 152174, 151792, 153409, 153327, 152990, 151670, 198, - 275, 155781, 151669, 152520, 153038, 152067, 153273, 153185, 152265, 152974, 151670, 198, 94273, 155799, - 151669, 152953, 152938, 153427, 152244, 151920, 153423, 152929, 152367, 153052, 152129, 152331, 152257, - 152987, 152777, 153448, 152408, 151696, 152408, 152326, 152699, 151670, 198, 385, 16239, 155828, 151669, - 152306, 152268, 153438, 153228, 152978, 152957, 153153, 153393, 152795, 152110, 152918, 152923, 152467, - 152331, 153053, 153330, 151889, 153444, 152234, 152624, 151779, 152801, 152784, 152139, 152222, 152751, - 152512, 153287, 153141, 153052, 151840, 152589, 152508, 153499, 152109, 152255, 151739, 152267, 152759, - 153318, 153165, 153349, 151670, ] +suffix = [ + 151667, + 198, + 1782, + 155780, + 151669, + 151929, + 152412, + 152308, + 152585, + 152460, + 153375, + 151670, + 198, + 74455, + 155808, + 151669, + 151799, + 151873, + 151863, + 152446, + 152372, + 152204, + 152728, + 152229, + 152470, + 151970, + 153413, + 152419, + 153334, + 153289, + 153374, + 153199, + 152040, + 153260, + 152721, + 152680, + 153297, + 152419, + 153248, + 152400, + 152691, + 153368, + 153437, + 151670, + 198, + 1722, + 155828, + 151669, + 152607, + 152256, + 152991, + 152299, + 152688, + 153163, + 153016, + 152789, + 153198, + 152712, + 151911, + 153107, + 152623, + 152170, + 152395, + 152852, + 152207, + 152461, + 153321, + 153309, + 151750, + 152137, + 153340, + 152573, + 152267, + 153347, + 151789, + 152681, + 153339, + 151992, + 152512, + 151751, + 152179, + 153434, + 153180, + 152900, + 153440, + 152474, + 153122, + 153129, + 151904, + 152311, + 151670, + 198, + 1499, + 155791, + 151669, + 152276, + 152454, + 153354, + 152544, + 153204, + 153272, + 152708, + 153433, + 152319, + 153226, + 153043, + 152325, + 153267, + 152622, + 151670, + 198, + 4250, + 155797, + 151669, + 153454, + 153342, + 151989, + 152458, + 153420, + 152303, + 152271, + 152827, + 153036, + 153196, + 151708, + 153263, + 152561, + 153207, + 152213, + 152112, + 153204, + 151722, + 152542, + 151670, + 198, + 19789, + 155796, + 151669, + 153353, + 153182, + 152345, + 152471, + 152477, + 153014, + 152002, + 152191, + 151734, + 152312, + 152810, + 152237, + 153224, + 153169, + 153224, + 152244, + 153387, + 153404, + 151670, + 198, + 16069, + 155811, + 151669, + 152265, + 151946, + 151808, + 152412, + 152363, + 152305, + 153156, + 152733, + 152810, + 153157, + 152016, + 152100, + 152069, + 153234, + 152317, + 152589, + 152707, + 153121, + 153341, + 152159, + 152114, + 153156, + 153001, + 153504, + 153376, + 152272, + 152433, + 152325, + 151941, + 151670, + 198, + 285, + 155788, + 151669, + 152238, + 152255, + 153427, + 152318, + 153009, + 152381, + 152474, + 152680, + 152157, + 153255, + 152324, + 151682, + 151670, + 198, + 32955, + 155804, + 151669, + 153490, + 153419, + 152364, + 152405, + 152682, + 152206, + 152078, + 153369, + 152725, + 153193, + 153027, + 152946, + 152488, + 153070, + 151883, + 152890, + 152489, + 153144, + 153375, + 152358, + 151685, + 152494, + 152117, + 152740, + 151670, + 198, + 37448, + 480, + 155840, + 151669, + 151902, + 152720, + 153377, + 152027, + 152378, + 152821, + 153207, + 153459, + 153028, + 153068, + 152507, + 153255, + 152158, + 152921, + 151958, + 152609, + 152748, + 152822, + 152286, + 151714, + 152730, + 152377, + 152353, + 152470, + 152606, + 152162, + 152186, + 153071, + 152244, + 153118, + 153375, + 153018, + 152712, + 153098, + 152976, + 152336, + 151843, + 153202, + 152297, + 151736, + 153380, + 153502, + 152702, + 152115, + 153181, + 152735, + 153277, + 153457, + 152393, + 153112, + 152595, + 151670, + 198, + 19098, + 155808, + 151669, + 152464, + 153452, + 152595, + 153312, + 151937, + 151933, + 153197, + 152239, + 153163, + 152922, + 153402, + 152034, + 152591, + 153438, + 152215, + 151673, + 152005, + 151785, + 152642, + 151924, + 153278, + 151805, + 151974, + 153482, + 152718, + 152862, + 153347, + 151670, + 198, + 72, + 155780, + 151669, + 151795, + 152111, + 152746, + 152377, + 153471, + 152309, + 151670, + 198, + 19016, + 155788, + 151669, + 153181, + 152271, + 152190, + 152842, + 152224, + 152701, + 152939, + 152536, + 152091, + 151815, + 152733, + 151672, + 151670, + 198, + 14689, + 155788, + 151669, + 152291, + 152072, + 152942, + 151734, + 153042, + 153504, + 152589, + 153333, + 151839, + 151941, + 153038, + 153180, + 151670, + 198, + 36996, + 8303, + 155832, + 151669, + 152231, + 152256, + 152835, + 152801, + 152985, + 153400, + 152393, + 152818, + 152765, + 152249, + 152600, + 151699, + 152302, + 152752, + 153018, + 153009, + 151992, + 153054, + 152847, + 153354, + 153228, + 152662, + 153355, + 152532, + 153393, + 151782, + 152458, + 152048, + 152757, + 152428, + 153195, + 151906, + 153006, + 153178, + 153250, + 152331, + 152284, + 152780, + 153138, + 153319, + 151980, + 153142, + 152418, + 152228, + 152733, + 151670, + 198, + 9096, + 155801, + 151669, + 151698, + 153321, + 152217, + 153039, + 152935, + 153400, + 152122, + 152531, + 153106, + 152169, + 152892, + 152957, + 151851, + 152427, + 152826, + 152451, + 151851, + 152901, + 152885, + 152594, + 153446, + 153080, + 151670, + 198, + 14689, + 155795, + 151669, + 152658, + 151700, + 153321, + 152450, + 152530, + 153191, + 151673, + 151690, + 151698, + 152714, + 152846, + 152981, + 153171, + 153384, + 153364, + 153188, + 153246, + 151670, + 198, + 1055, + 155779, + 151669, + 151869, + 152388, + 152711, + 153334, + 151736, + 151670, + 198, + 1782, + 155780, + 151669, + 153483, + 153240, + 152241, + 152558, + 152697, + 153046, + 151670, + 198, + 5804, + 1363, + 155820, + 151669, + 152941, + 152764, + 152605, + 153034, + 153434, + 153372, + 153347, + 151887, + 152453, + 152758, + 152133, + 152510, + 152694, + 152431, + 152321, + 153088, + 152676, + 152223, + 152581, + 152459, + 152015, + 152502, + 153063, + 152712, + 153294, + 153451, + 153032, + 152903, + 152859, + 152989, + 151748, + 152669, + 152661, + 152650, + 152409, + 151861, + 151670, + 198, + 300, + 7973, + 155828, + 151669, + 153095, + 152469, + 152988, + 152894, + 151819, + 152391, + 153019, + 152058, + 153062, + 153230, + 151826, + 152112, + 152306, + 152264, + 152769, + 153390, + 152384, + 152435, + 152790, + 153393, + 152983, + 152540, + 152252, + 152034, + 153107, + 152540, + 151919, + 151893, + 152558, + 152817, + 152946, + 152956, + 152129, + 152715, + 153131, + 153490, + 151734, + 152271, + 152707, + 151734, + 153321, + 152450, + 151670, + 198, + 8088, + 155792, + 151669, + 152452, + 153497, + 153353, + 152679, + 152533, + 152382, + 152374, + 152611, + 153341, + 153163, + 152285, + 153411, + 152495, + 153141, + 152320, + 151670, + 198, + 1199, + 155781, + 151669, + 151764, + 152360, + 153295, + 152634, + 153342, + 152199, + 152271, + 151670, + 198, + 43366, + 155799, + 151669, + 152308, + 151682, + 152889, + 152016, + 152385, + 152629, + 152495, + 151826, + 153321, + 152958, + 152180, + 151886, + 153432, + 152922, + 152128, + 153024, + 153040, + 152593, + 152287, + 151677, + 151670, + 198, + 53660, + 155808, + 151669, + 151727, + 152092, + 152680, + 153331, + 151699, + 152316, + 152938, + 152289, + 152433, + 153384, + 151781, + 153137, + 153259, + 152175, + 153213, + 152291, + 151869, + 152691, + 152489, + 151941, + 152049, + 152034, + 153053, + 152179, + 153160, + 151676, + 153367, + 151670, + 198, + 268, + 4123, + 480, + 155821, + 151669, + 152350, + 152173, + 152536, + 151991, + 151960, + 153144, + 153013, + 152358, + 152234, + 153135, + 152291, + 153235, + 152143, + 152583, + 152402, + 153483, + 152678, + 152192, + 152533, + 152946, + 151797, + 153103, + 152310, + 152293, + 151825, + 152548, + 153442, + 152109, + 152659, + 153325, + 152781, + 152570, + 152957, + 151752, + 152265, + 153381, + 152515, + 151670, + 198, + 437, + 155787, + 151669, + 152957, + 152659, + 151975, + 152709, + 152402, + 152836, + 152174, + 151792, + 153409, + 153327, + 152990, + 151670, + 198, + 275, + 155781, + 151669, + 152520, + 153038, + 152067, + 153273, + 153185, + 152265, + 152974, + 151670, + 198, + 94273, + 155799, + 151669, + 152953, + 152938, + 153427, + 152244, + 151920, + 153423, + 152929, + 152367, + 153052, + 152129, + 152331, + 152257, + 152987, + 152777, + 153448, + 152408, + 151696, + 152408, + 152326, + 152699, + 151670, + 198, + 385, + 16239, + 155828, + 151669, + 152306, + 152268, + 153438, + 153228, + 152978, + 152957, + 153153, + 153393, + 152795, + 152110, + 152918, + 152923, + 152467, + 152331, + 153053, + 153330, + 151889, + 153444, + 152234, + 152624, + 151779, + 152801, + 152784, + 152139, + 152222, + 152751, + 152512, + 153287, + 153141, + 153052, + 151840, + 152589, + 152508, + 153499, + 152109, + 152255, + 151739, + 152267, + 152759, + 153318, + 153165, + 153349, + 151670, +] response = requests.post( host_llm + "/completion", @@ -251,15 +1014,15 @@ def process_text(text: str): "samplers": ["top_k"], "top_k": 16, "seed": 1003, - } + }, ) response_json = response.json() -#print(json.dumps(response_json, indent=4)) -#print(json.dumps(response_json["prompt"], indent=4).replace("\\n", "\n")) -#print(json.dumps(response_json["timings"], indent=4)) -#print(json.dumps(response_json["tokens"], indent=4)) +# print(json.dumps(response_json, indent=4)) +# print(json.dumps(response_json["prompt"], indent=4).replace("\\n", "\n")) +# print(json.dumps(response_json["timings"], indent=4)) +# print(json.dumps(response_json["tokens"], indent=4)) codes = response_json["tokens"] @@ -269,12 +1032,12 @@ def process_text(text: str): host_dec + "/embeddings", json={ "input": [*codes], - } + }, ) response_json = response.json() -#print(json.dumps(response_json, indent=4)) +# print(json.dumps(response_json, indent=4)) # spectrogram embd = response_json[0]["embedding"] @@ -282,18 +1045,18 @@ def process_text(text: str): n_codes = len(embd) n_embd = len(embd[0]) -print('spectrogram generated: n_codes: %d, n_embd: %d' % (n_codes, n_embd)) +print("spectrogram generated: n_codes: %d, n_embd: %d" % (n_codes, n_embd)) # post-process the spectrogram to convert to audio -print('converting to audio ...') +print("converting to audio ...") audio = embd_to_audio(embd, n_codes, n_embd) -print('audio generated: %d samples' % len(audio)) +print("audio generated: %d samples" % len(audio)) filename = "output.wav" -sample_rate = 24000 # sampling rate +sample_rate = 24000 # sampling rate # zero out first 0.25 seconds -audio[:24000 // 4] = 0.0 +audio[: 24000 // 4] = 0.0 save_wav(filename, audio, sample_rate) print('audio written to file "%s"' % filename) diff --git a/spm-headers/ggml.h b/spm-headers/ggml.h deleted file mode 120000 index 39215298..00000000 --- a/spm-headers/ggml.h +++ /dev/null @@ -1 +0,0 @@ -../ggml.h \ No newline at end of file diff --git a/spm-headers/ggml.h b/spm-headers/ggml.h new file mode 100644 index 00000000..2d10c431 --- /dev/null +++ b/spm-headers/ggml.h @@ -0,0 +1,2251 @@ +#pragma once + +// +// GGML Tensor Library +// +// This documentation is still a work in progress. +// If you wish some specific topics to be covered, feel free to drop a comment: +// +// https://github.com/ggerganov/whisper.cpp/issues/40 +// +// ## Overview +// +// This library implements: +// +// - a set of tensor operations +// - automatic differentiation +// - basic optimization algorithms +// +// The aim of this library is to provide a minimalistic approach for various machine learning tasks. This includes, +// but is not limited to, the following: +// +// - linear regression +// - support vector machines +// - neural networks +// +// The library allows the user to define a certain function using the available tensor operations. This function +// definition is represented internally via a computation graph. Each tensor operation in the function definition +// corresponds to a node in the graph. Having the computation graph defined, the user can choose to compute the +// function's value and/or its gradient with respect to the input variables. Optionally, the function can be optimized +// using one of the available optimization algorithms. +// +// For example, here we define the function: f(x) = a*x^2 + b +// +// { +// struct ggml_init_params params = { +// .mem_size = 16*1024*1024, +// .mem_buffer = NULL, +// }; +// +// // memory allocation happens here +// struct ggml_context * ctx = ggml_init(params); +// +// struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// +// ggml_set_param(ctx, x); // x is an input variable +// +// struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); +// struct ggml_tensor * x2 = ggml_mul(ctx, x, x); +// struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x2), b); +// +// ... +// } +// +// Notice that the function definition above does not involve any actual computation. The computation is performed only +// when the user explicitly requests it. For example, to compute the function's value at x = 2.0: +// +// { +// ... +// +// struct ggml_cgraph * gf = ggml_new_graph(ctx); +// ggml_build_forward_expand(gf, f); +// +// // set the input variable and parameter values +// ggml_set_f32(x, 2.0f); +// ggml_set_f32(a, 3.0f); +// ggml_set_f32(b, 4.0f); +// +// ggml_graph_compute_with_ctx(ctx, &gf, n_threads); +// +// printf("f = %f\n", ggml_get_f32_1d(f, 0)); +// +// ... +// } +// +// The actual computation is performed in the ggml_graph_compute() function. +// +// The ggml_new_tensor_...() functions create new tensors. They are allocated in the memory buffer provided to the +// ggml_init() function. You have to be careful not to exceed the memory buffer size. Therefore, you have to know +// in advance how much memory you need for your computation. Alternatively, you can allocate a large enough memory +// and after defining the computation graph, call the ggml_used_mem() function to find out how much memory was +// actually needed. +// +// The ggml_set_param() function marks a tensor as an input variable. This is used by the automatic +// differentiation and optimization algorithms. +// +// The described approach allows to define the function graph once and then compute its forward or backward graphs +// multiple times. All computations will use the same memory buffer allocated in the ggml_init() function. This way +// the user can avoid the memory allocation overhead at runtime. +// +// The library supports multi-dimensional tensors - up to 4 dimensions. The FP16 and FP32 data types are first class +// citizens, but in theory the library can be extended to support FP8 and integer data types. +// +// Each tensor operation produces a new tensor. Initially the library was envisioned to support only the use of unary +// and binary operations. Most of the available operations fall into one of these two categories. With time, it became +// clear that the library needs to support more complex operations. The way to support these operations is not clear +// yet, but a few examples are demonstrated in the following operations: +// +// - ggml_permute() +// - ggml_conv_1d_1s() +// - ggml_conv_1d_2s() +// +// For each tensor operator, the library implements a forward and backward computation function. The forward function +// computes the output tensor value given the input tensor values. The backward function computes the adjoint of the +// input tensors given the adjoint of the output tensor. For a detailed explanation of what this means, take a +// calculus class, or watch the following video: +// +// What is Automatic Differentiation? +// https://www.youtube.com/watch?v=wG_nF1awSSY +// +// +// ## Tensor data (struct ggml_tensor) +// +// The tensors are stored in memory via the ggml_tensor struct. The structure provides information about the size of +// the tensor, the data type, and the memory buffer where the tensor data is stored. Additionally, it contains +// pointers to the "source" tensors - i.e. the tensors that were used to compute the current tensor. For example: +// +// { +// struct ggml_tensor * c = ggml_add(ctx, a, b); +// +// assert(c->src[0] == a); +// assert(c->src[1] == b); +// } +// +// The multi-dimensional tensors are stored in row-major order. The ggml_tensor struct contains fields for the +// number of elements in each dimension ("ne") as well as the number of bytes ("nb", a.k.a. stride). This allows +// to store tensors that are not contiguous in memory, which is useful for operations such as transposition and +// permutation. All tensor operations have to take the stride into account and not assume that the tensor is +// contiguous in memory. +// +// The data of the tensor is accessed via the "data" pointer. For example: +// +// { +// const int nx = 2; +// const int ny = 3; +// +// struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, ny); +// +// for (int y = 0; y < ny; y++) { +// for (int x = 0; x < nx; x++) { +// *(float *) ((char *) a->data + y*a->nb[1] + x*a->nb[0]) = x + y; +// } +// } +// +// ... +// } +// +// Alternatively, there are helper functions, such as ggml_get_f32_1d() and ggml_set_f32_1d() that can be used. +// +// ## The matrix multiplication operator (ggml_mul_mat) +// +// TODO +// +// +// ## Multi-threading +// +// TODO +// +// +// ## Overview of ggml.c +// +// TODO +// +// +// ## SIMD optimizations +// +// TODO +// +// +// ## Debugging ggml +// +// TODO +// +// + +#ifdef GGML_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef GGML_BUILD +# define GGML_API __declspec(dllexport) +# else +# define GGML_API __declspec(dllimport) +# endif +# else +# define GGML_API __attribute__ ((visibility ("default"))) +# endif +#else +# define GGML_API +#endif + +// TODO: support for clang +#ifdef __GNUC__ +# define GGML_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define GGML_DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define GGML_DEPRECATED(func, hint) func +#endif + +#ifndef __GNUC__ +# define GGML_ATTRIBUTE_FORMAT(...) +#elif defined(__MINGW32__) +# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +# define GGML_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif + +#include +#include +#include +#ifdef __cplusplus + #include + using std::atomic_int; + using std::memory_order; + using std::memory_order_acquire; +#else /* not __cplusplus */ +#if defined(_WIN32) +# include "atomic_windows.h" +#else +# include +#endif +#endif /* __cplusplus */ + +#define GGML_FILE_MAGIC 0x67676d6c // "ggml" +#define GGML_FILE_VERSION 1 + +#define GGML_QNT_VERSION 2 // bump this on quantization format changes +#define GGML_QNT_VERSION_FACTOR 1000 // do not change this + +#define GGML_MAX_DIMS 4 +#define GGML_MAX_PARAMS 1024 +#define GGML_MAX_CONTEXTS 64 +#define GGML_MAX_SRC 6 +#define GGML_MAX_NAME 64 +#define GGML_MAX_OP_PARAMS 64 +#define GGML_DEFAULT_N_THREADS 4 +#define GGML_DEFAULT_GRAPH_SIZE 2048 +#if UINTPTR_MAX == 0xFFFFFFFF + #define GGML_MEM_ALIGN 4 +#else + #define GGML_MEM_ALIGN 16 +#endif + +#define GGML_EXIT_SUCCESS 0 +#define GGML_EXIT_ABORTED 1 + +#define GGUF_MAGIC "GGUF" +#define GGUF_POWERINFER_MAGIC "PWRI" + +#define GGUF_VERSION 3 + +#define GGUF_DEFAULT_ALIGNMENT 32 + +#define GGML_UNUSED(x) (void)(x) + +#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) + +#define GGML_ASSERT(x) \ + do { \ + if (!(x)) { \ + fprintf(stderr, "GGML_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + fflush(stderr); \ + fflush(stdout); \ + ggml_print_backtrace(); \ + exit(1); \ + } \ + } while (0) + +#define GGML_ASSERT_DBG(x, s, ...) \ + do { \ + if (!(x)) { \ + fprintf(stderr, "GGML_ASSERT: %s:%d: " s "\n", __FILE__, __LINE__, ##__VA_ARGS__); \ + fflush(stderr); \ + fflush(stdout); \ + ggml_print_backtrace(); \ + exit(1); \ + } \ + } while (0) + +#ifndef NDEBUG +#define GGML_UNREACHABLE() GGML_ASSERT(!"statement should not be reached") +#elif defined(__GNUC__) +#define GGML_UNREACHABLE() __builtin_unreachable() +#else +#define GGML_UNREACHABLE() ((void) 0) +#endif + +// used to copy the number of elements and stride in bytes of tensors into local variables. +// main purpose is to reduce code duplication and improve readability. +// +// example: +// +// GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne); +// GGML_TENSOR_LOCALS(size_t, nb1, src1, nb); +// +#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \ + const type prefix##0 = (pointer)->array[0]; \ + GGML_UNUSED(prefix##0); +#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \ + const type prefix##1 = (pointer)->array[1]; \ + GGML_UNUSED(prefix##1); +#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \ + const type prefix##2 = (pointer)->array[2]; \ + GGML_UNUSED(prefix##2); +#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \ + GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \ + const type prefix##3 = (pointer)->array[3]; \ + GGML_UNUSED(prefix##3); + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(__ARM_NEON) && defined(__CUDACC__) + typedef half ggml_fp16_t; +#elif defined(__ARM_NEON) + typedef __fp16 ggml_fp16_t; +#else + typedef uint16_t ggml_fp16_t; +#endif + + // convert FP16 <-> FP32 + GGML_API float ggml_fp16_to_fp32(ggml_fp16_t x); + GGML_API ggml_fp16_t ggml_fp32_to_fp16(float x); + + GGML_API void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int n); + GGML_API void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int n); + + struct ggml_object; + struct ggml_context; + + enum ggml_type { + GGML_TYPE_F32 = 0, + GGML_TYPE_F16 = 1, + GGML_TYPE_Q4_0 = 2, + GGML_TYPE_Q4_1 = 3, + // GGML_TYPE_Q4_2 = 4, support has been removed + // GGML_TYPE_Q4_3 (5) support has been removed + GGML_TYPE_Q5_0 = 6, + GGML_TYPE_Q5_1 = 7, + GGML_TYPE_Q8_0 = 8, + GGML_TYPE_Q8_1 = 9, + // k-quantizations + GGML_TYPE_Q2_K = 10, + GGML_TYPE_Q3_K = 11, + GGML_TYPE_Q4_K = 12, + GGML_TYPE_Q5_K = 13, + GGML_TYPE_Q6_K = 14, + GGML_TYPE_Q8_K = 15, + GGML_TYPE_I8, + GGML_TYPE_I16, + GGML_TYPE_I32, + GGML_TYPE_COUNT, + }; + + enum ggml_backend_type { + GGML_BACKEND_CPU = 0, + GGML_BACKEND_GPU = 10, + GGML_BACKEND_GPU_SPLIT = 20, + }; + + enum ggml_sparse_deriv { + GGML_DENSE_INFERENCE = 0, + GGML_SPARSE_INFERENCE = 1, + }; + + // model file types + enum ggml_ftype { + GGML_FTYPE_UNKNOWN = -1, + GGML_FTYPE_ALL_F32 = 0, + GGML_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 + GGML_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + GGML_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors + GGML_FTYPE_MOSTLY_Q3_K = 11, // except 1d tensors + GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors + GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors + GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors + }; + + // available tensor operations: + enum ggml_op { + GGML_OP_NONE = 0, + + GGML_OP_DUP, + GGML_OP_ADD, + GGML_OP_ADD1, + GGML_OP_ACC, + GGML_OP_SUB, + GGML_OP_MUL, + GGML_OP_DIV, + GGML_OP_SQR, + GGML_OP_SQRT, + GGML_OP_LOG, + GGML_OP_SUM, + GGML_OP_SUM_ROWS, + GGML_OP_MEAN, + GGML_OP_ARGMAX, + GGML_OP_REPEAT, + GGML_OP_REPEAT_BACK, + GGML_OP_CONCAT, + GGML_OP_SILU_BACK, + GGML_OP_NORM, // normalize + GGML_OP_RMS_NORM, + GGML_OP_RMS_NORM_BACK, + GGML_OP_GROUP_NORM, + + GGML_OP_MUL_MAT, + GGML_OP_MUL_MAT_SPARSE, + GGML_OP_AXPY, + GGML_OP_OUT_PROD, + + GGML_OP_SCALE, + GGML_OP_SET, + GGML_OP_CPY, + GGML_OP_CONT, + GGML_OP_RESHAPE, + GGML_OP_VIEW, + GGML_OP_PERMUTE, + GGML_OP_TRANSPOSE, + GGML_OP_GET_ROWS, + GGML_OP_GET_ROWS_BACK, + GGML_OP_DIAG, + GGML_OP_DIAG_MASK_INF, + GGML_OP_DIAG_MASK_ZERO, + GGML_OP_SOFT_MAX, + GGML_OP_SOFT_MAX_BACK, + GGML_OP_ROPE, + GGML_OP_ROPE_BACK, + GGML_OP_ALIBI, + GGML_OP_CLAMP, + GGML_OP_CONV_TRANSPOSE_1D, + GGML_OP_IM2COL, + GGML_OP_CONV_TRANSPOSE_2D, + GGML_OP_POOL_1D, + GGML_OP_POOL_2D, + + GGML_OP_UPSCALE, // nearest interpolate + + GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_FF, + GGML_OP_FLASH_ATTN_BACK, + GGML_OP_WIN_PART, + GGML_OP_WIN_UNPART, + GGML_OP_GET_REL_POS, + GGML_OP_ADD_REL_POS, + + GGML_OP_UNARY, + + GGML_OP_MAP_UNARY, + GGML_OP_MAP_BINARY, + + GGML_OP_MAP_CUSTOM1_F32, + GGML_OP_MAP_CUSTOM2_F32, + GGML_OP_MAP_CUSTOM3_F32, + + GGML_OP_MAP_CUSTOM1, + GGML_OP_MAP_CUSTOM2, + GGML_OP_MAP_CUSTOM3, + + GGML_OP_CROSS_ENTROPY_LOSS, + GGML_OP_CROSS_ENTROPY_LOSS_BACK, + + GGML_OP_COUNT, + }; + + enum ggml_unary_op { + GGML_UNARY_OP_ABS, + GGML_UNARY_OP_SGN, + GGML_UNARY_OP_NEG, + GGML_UNARY_OP_STEP, + GGML_UNARY_OP_TANH, + GGML_UNARY_OP_ELU, + GGML_UNARY_OP_RELU, + GGML_UNARY_OP_GELU, + GGML_UNARY_OP_GELU_QUICK, + GGML_UNARY_OP_SILU, + GGML_UNARY_OP_LEAKY + }; + + enum ggml_object_type { + GGML_OBJECT_TENSOR, + GGML_OBJECT_GRAPH, + GGML_OBJECT_WORK_BUFFER + }; + + enum ggml_log_level { + GGML_LOG_LEVEL_ERROR = 2, + GGML_LOG_LEVEL_WARN = 3, + GGML_LOG_LEVEL_INFO = 4 + }; + + // ggml object + struct ggml_object { + size_t offs; + size_t size; + + struct ggml_object * next; + + enum ggml_object_type type; + + char padding[4]; + }; + + static const size_t GGML_OBJECT_SIZE = sizeof(struct ggml_object); + + // n-dimensional tensor + struct ggml_tensor { + enum ggml_type type; + enum ggml_backend_type backend; + + struct ggml_backend_buffer * buffer; + + int n_dims; + int64_t ne[GGML_MAX_DIMS]; // number of elements + size_t nb[GGML_MAX_DIMS]; // stride in bytes: + // nb[0] = ggml_type_size(type) + // nb[1] = nb[0] * (ne[0] / ggml_blck_size(type)) + padding + // nb[i] = nb[i-1] * ne[i-1] + + // compute data + enum ggml_op op; + + // op params - allocated as int32_t for alignment + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; + + bool is_param; + + struct ggml_tensor * grad; + struct ggml_tensor * src[GGML_MAX_SRC]; + + // performance + atomic_int is_finish; + int perf_runs; + int64_t perf_cycles; + int64_t perf_time_us; + + struct ggml_tensor * view_src; + size_t view_offs; + + void * data; + + char name[GGML_MAX_NAME]; + + void * extra; // extra things e.g. for ggml-cuda.cu + + char padding[12]; + }; + + + static const int64_t GGML_NE_WILDCARD = -1; + + static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); + + // the compute plan that needs to be prepared for ggml_graph_compute() + // since https://github.com/ggerganov/ggml/issues/287 + struct ggml_cplan { + size_t work_size; // size of work buffer, calculated by `ggml_graph_plan()` + uint8_t * work_data; // work buffer, to be allocated by caller before calling to `ggml_graph_compute()` + + int n_threads; + + // abort ggml_graph_compute when true + bool (*abort_callback)(void * data); + void * abort_callback_data; + }; + + enum ggml_cgraph_eval_order { + GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT = 0, + GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT, + GGML_CGRAPH_EVAL_ORDER_COUNT + }; + + struct ggml_hash_set { + size_t size; + struct ggml_tensor ** keys; + }; + + // computation graph + struct ggml_cgraph { + int size; + int n_nodes; + int n_leafs; + + struct ggml_tensor ** nodes; + struct ggml_tensor ** grads; + struct ggml_tensor ** leafs; + + struct ggml_hash_set visited_hash_table; + + enum ggml_cgraph_eval_order order; + + // performance + int perf_runs; + int64_t perf_cycles; + int64_t perf_time_us; + }; + + // scratch buffer + struct ggml_scratch { + size_t offs; + size_t size; + void * data; + }; + + struct ggml_context { + size_t mem_size; + void * mem_buffer; + bool mem_buffer_owned; + bool no_alloc; + bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers + + int n_objects; + + struct ggml_object * objects_begin; + struct ggml_object * objects_end; + + struct ggml_scratch scratch; + struct ggml_scratch scratch_save; + }; + + struct ggml_init_params { + // memory pool + size_t mem_size; // bytes + void * mem_buffer; // if NULL, memory will be allocated internally + bool no_alloc; // don't allocate memory for the tensor data + }; + + + // compute types + + // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled. + // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995. + enum ggml_task_type { + GGML_TASK_INIT = 0, + GGML_TASK_COMPUTE, + GGML_TASK_FINALIZE, + }; + + struct ggml_compute_params { + enum ggml_task_type type; + + // ith = thread index, nth = number of threads + int ith, nth; + + // work buffer for all threads + size_t wsize; + void * wdata; + atomic_int *aic; + }; + + // misc + + GGML_API void ggml_time_init(void); // call this once at the beginning of the program + GGML_API int64_t ggml_time_ms(void); + GGML_API int64_t ggml_time_us(void); + GGML_API int64_t ggml_cycles(void); + GGML_API int64_t ggml_cycles_per_ms(void); + + GGML_API void ggml_print_backtrace(void); + + GGML_API void ggml_numa_init(void); // call once for better performance on NUMA systems + GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node + + GGML_API void ggml_print_object (const struct ggml_object * obj); + GGML_API void ggml_print_objects(const struct ggml_context * ctx); + + GGML_API + + GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor); + GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor); + GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor); + GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN + GGML_API size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split); + + GGML_API int ggml_blck_size (enum ggml_type type); + GGML_API size_t ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block + GGML_API float ggml_type_sizef(enum ggml_type type); // ggml_type_size()/ggml_blck_size() as float + + GGML_API const char * ggml_type_name(enum ggml_type type); + GGML_API const char * ggml_op_name (enum ggml_op op); + GGML_API const char * ggml_op_symbol(enum ggml_op op); + + GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor); + + GGML_API bool ggml_is_quantized(enum ggml_type type); + + // TODO: temporary until model loading of ggml examples is refactored + GGML_API enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype); + + GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); + GGML_API bool ggml_is_contiguous(const struct ggml_tensor * tensor); + GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor); + + GGML_API bool ggml_are_same_shape(const struct ggml_tensor * t0, const struct ggml_tensor * t1); + + // use this to compute the memory overhead of a tensor + GGML_API size_t ggml_tensor_overhead(void); + + // main + + GGML_API struct ggml_context * ggml_init(struct ggml_init_params params); + GGML_API void ggml_free(struct ggml_context * ctx); + + GGML_API size_t ggml_used_mem(const struct ggml_context * ctx); + + GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch); + GGML_API bool ggml_get_no_alloc(struct ggml_context * ctx); + GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc); + + GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx); + GGML_API size_t ggml_get_mem_size (const struct ggml_context * ctx); + GGML_API size_t ggml_get_max_tensor_size(const struct ggml_context * ctx); + + GGML_API struct ggml_tensor * ggml_new_tensor( + struct ggml_context * ctx, + enum ggml_type type, + int n_dims, + const int64_t *ne); + + GGML_API struct ggml_tensor * ggml_new_tensor_1d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0); + + GGML_API struct ggml_tensor * ggml_new_tensor_2d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1); + + GGML_API struct ggml_tensor * ggml_new_tensor_3d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + GGML_API struct ggml_tensor * ggml_new_tensor_4d( + struct ggml_context * ctx, + enum ggml_type type, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value); + GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value); + + GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); + GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src); + + // Context tensor enumeration and lookup + GGML_API struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx); + GGML_API struct ggml_tensor * ggml_get_next_tensor (struct ggml_context * ctx, struct ggml_tensor * tensor); + GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); + + GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); + GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value); + GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value); + + // Converts a flat index into coordinates + GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3); + + GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i); + GGML_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value); + + GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); + GGML_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value); + + GGML_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i); + GGML_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value); + + GGML_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3); + GGML_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value); + + GGML_API void * ggml_get_data (const struct ggml_tensor * tensor); + GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor); + GGML_API int32_t * ggml_get_data_i32(const struct ggml_tensor * tensor); + + GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor); + + GGML_API const char * ggml_get_name (const struct ggml_tensor * tensor); + GGML_API struct ggml_tensor * ggml_set_name ( struct ggml_tensor * tensor, const char * name); + GGML_ATTRIBUTE_FORMAT(2, 3) + GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...); + + GGML_API void ggml_set_backend(struct ggml_tensor * tensor, enum ggml_backend_type backend); + + + // + // operations on tensors with backpropagation + // + + GGML_API struct ggml_tensor * ggml_dup( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_dup_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_add( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor *ggml_add_idx( + struct ggml_context *ctx, + struct ggml_tensor *a, + struct ggml_tensor *b, + struct ggml_tensor *idx); + + GGML_API struct ggml_tensor * ggml_add_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_add_cast( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + enum ggml_type type); + + GGML_API struct ggml_tensor * ggml_add1( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_add1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_acc( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_acc_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_sub( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_sub_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_mul( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_mul_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_div( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_div_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_sqr( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sqr_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sqrt( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sqrt_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_log( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_log_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // return scalar + GGML_API struct ggml_tensor * ggml_sum( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // sums along rows, with input shape [a,b,c,d] return shape [1,b,c,d] + GGML_API struct ggml_tensor * ggml_sum_rows( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // mean along rows + GGML_API struct ggml_tensor * ggml_mean( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // argmax along rows + GGML_API struct ggml_tensor * ggml_argmax( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // if a is the same shape as b, and a is not parameter, return a + // otherwise, return a new tensor: repeat(a) to fit in b + GGML_API struct ggml_tensor * ggml_repeat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // sums repetitions in a into shape of b + GGML_API struct ggml_tensor * ggml_repeat_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // concat a and b on dim 2 + // used in stable-diffusion + GGML_API struct ggml_tensor * ggml_concat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_abs( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_abs_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sgn( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_sgn_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_neg( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_neg_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_step( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_step_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_tanh( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_tanh_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_elu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_elu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_relu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_leaky( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_relu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // TODO: double-check this computation is correct + GGML_API struct ggml_tensor * ggml_gelu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_quick( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_gelu_quick_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_silu( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_silu_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // a - x + // b - dy + GGML_API struct ggml_tensor * ggml_silu_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // normalize along rows + GGML_API struct ggml_tensor * ggml_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_rms_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + GGML_API struct ggml_tensor * ggml_rms_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float eps); + + // group normalize along ne0*ne1*n_groups + // used in stable-diffusion + // TODO: eps is hardcoded to 1e-6 for now + GGML_API struct ggml_tensor * ggml_group_norm( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups); + + GGML_API struct ggml_tensor * ggml_group_norm_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_groups); + + // a - x + // b - dy + GGML_API struct ggml_tensor * ggml_rms_norm_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + float eps); + + // A: k columns, n rows => [ne03, ne02, n, k] + // B: k columns, m rows (i.e. we transpose it internally) => [ne03 * x, ne02 * y, m, k] + // result is n columns, m rows => [ne03 * x, ne02 * y, m, n] + GGML_API struct ggml_tensor * ggml_mul_mat( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + GGML_API struct ggml_tensor *ggml_mul_mat_idx( + struct ggml_context *ctx, + struct ggml_tensor *a, + struct ggml_tensor *b, + struct ggml_tensor *sparse_idx, + struct ggml_tensor *gpu_idx); + GGML_API struct ggml_tensor *ggml_mul_mat_idx_upscale( + struct ggml_context *ctx, + struct ggml_tensor *a, + struct ggml_tensor *b, + struct ggml_tensor *sparse_idx, + struct ggml_tensor *gpu_bucket, + int64_t result_ne0); + GGML_API struct ggml_tensor *ggml_axpy( + struct ggml_context *ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * sparse_idx, + struct ggml_tensor * hybrid_aux); + + // A: m columns, n rows, + // B: p columns, n rows, + // result is m columns, p rows + GGML_API struct ggml_tensor * ggml_out_prod( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // + // operations on tensors without backpropagation + // + + GGML_API struct ggml_tensor * ggml_scale( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_scale_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // b -> view(a,offset,nb1,nb2,3), return modified a + GGML_API struct ggml_tensor * ggml_set( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + // b -> view(a,offset,nb1,nb2,3), return view(a) + GGML_API struct ggml_tensor * ggml_set_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t nb2, + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_set_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t offset); + + GGML_API struct ggml_tensor * ggml_set_1d_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t offset); + + // b -> view(a,offset,nb1,nb2,3), return modified a + GGML_API struct ggml_tensor * ggml_set_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t offset); + + // b -> view(a,offset,nb1,nb2,3), return view(a) + GGML_API struct ggml_tensor * ggml_set_2d_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + size_t nb1, + size_t offset); + + // a -> b, return view(b) + GGML_API struct ggml_tensor * ggml_cpy( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // a -> b, in-place, return view(b) + GGML_API struct ggml_tensor * ggml_cpy_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // make contiguous + GGML_API struct ggml_tensor * ggml_cont( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // make contiguous, in-place + GGML_API struct ggml_tensor * ggml_cont_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // make contiguous, with new shape + GGML_API struct ggml_tensor * ggml_cont_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0); + + GGML_API struct ggml_tensor * ggml_cont_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1); + + GGML_API struct ggml_tensor * ggml_cont_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + GGML_API struct ggml_tensor * ggml_cont_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + // return view(a), b specifies the new shape + // TODO: when we start computing gradient, make a copy instead of view + GGML_API struct ggml_tensor * ggml_reshape( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // return view(a) + // TODO: when we start computing gradient, make a copy instead of view + GGML_API struct ggml_tensor * ggml_reshape_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0); + + GGML_API struct ggml_tensor * ggml_reshape_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1); + + // return view(a) + // TODO: when we start computing gradient, make a copy instead of view + GGML_API struct ggml_tensor * ggml_reshape_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2); + + GGML_API struct ggml_tensor * ggml_reshape_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3); + + // offset in bytes + GGML_API struct ggml_tensor * ggml_view_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + size_t offset); + + GGML_API struct ggml_tensor * ggml_view_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + size_t nb1, // row stride in bytes + size_t offset); + + GGML_API struct ggml_tensor * ggml_view_3d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + size_t nb1, // row stride in bytes + size_t nb2, // slice stride in bytes + size_t offset); + + GGML_API struct ggml_tensor * ggml_view_4d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int64_t ne3, + size_t nb1, // row stride in bytes + size_t nb2, // slice stride in bytes + size_t nb3, + size_t offset); + + GGML_API struct ggml_tensor * ggml_permute( + struct ggml_context * ctx, + struct ggml_tensor * a, + int axis0, + int axis1, + int axis2, + int axis3); + + // alias for ggml_permute(ctx, a, 1, 0, 2, 3) + GGML_API struct ggml_tensor * ggml_transpose( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_get_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_get_rows_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c); + + GGML_API struct ggml_tensor * ggml_diag( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // set elements above the diagonal to -INF + GGML_API struct ggml_tensor * ggml_diag_mask_inf( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_diag_mask_inf_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + // set elements above the diagonal to 0 + GGML_API struct ggml_tensor * ggml_diag_mask_zero( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_diag_mask_zero_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past); + + GGML_API struct ggml_tensor * ggml_soft_max( + struct ggml_context * ctx, + struct ggml_tensor * a); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_soft_max_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a); + + GGML_API struct ggml_tensor * ggml_soft_max_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_soft_max_back_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // rotary position embedding + // if mode & 1 == 1, skip n_past elements (DEPRECATED) + // if mode & 2 == 1, GPT-NeoX style + // if mode & 4 == 1, ChatGLM style + // + // b is an int32 vector with size a->ne[2], it contains the positions + GGML_API struct ggml_tensor * ggml_rope( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_rope_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx); + + // custom RoPE + GGML_API struct ggml_tensor * ggml_rope_custom( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_rope_custom_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow); + + // compute correction dims for YaRN RoPE scaling + void ggml_rope_yarn_corr_dims( + int n_dims, int n_orig_ctx, float freq_base, float beta_fast, float beta_slow, float dims[2]); + + // xPos RoPE, in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_rope_xpos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + float base, + bool down); + + // rotary position embedding backward, i.e compute dx from dy + // a - dy + GGML_API struct ggml_tensor * ggml_rope_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int n_dims, + int mode, + int n_ctx, + int n_orig_ctx, + float freq_base, + float freq_scale, + float ext_factor, + float attn_factor, + float beta_fast, + float beta_slow, + float xpos_base, + bool xpos_down); + + // alibi position embedding + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_alibi( + struct ggml_context * ctx, + struct ggml_tensor * a, + int n_past, + int n_head, + float bias_max); + + // clamp + // in-place, returns view(a) + GGML_API struct ggml_tensor * ggml_clamp( + struct ggml_context * ctx, + struct ggml_tensor * a, + float min, + float max); + + GGML_API struct ggml_tensor * ggml_im2col( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + bool is_2D); + + GGML_API struct ggml_tensor * ggml_conv_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, // stride + int p0, // padding + int d0); // dilation + + // conv_1d with padding = half + // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d) + GGML_API struct ggml_tensor* ggml_conv_1d_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s, + int d); + + GGML_API struct ggml_tensor * ggml_conv_transpose_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int p0, + int d0); + + GGML_API struct ggml_tensor * ggml_conv_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1); + + + // kernel size is a->ne[0] x a->ne[1] + // stride is equal to kernel size + // padding is zero + // example: + // a: 16 16 3 768 + // b: 1024 1024 3 1 + // res: 64 64 768 1 + // used in sam + GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + // kernel size is a->ne[0] x a->ne[1] + // stride is 1 + // padding is half + // example: + // a: 3 3 256 256 + // b: 64 64 256 1 + // res: 64 64 256 1 + // used in sam + GGML_API struct ggml_tensor * ggml_conv_2d_s1_ph( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + int stride); + + enum ggml_op_pool { + GGML_OP_POOL_MAX, + GGML_OP_POOL_AVG, + GGML_OP_POOL_COUNT, + }; + + GGML_API struct ggml_tensor * ggml_pool_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_op_pool op, + int k0, // kernel size + int s0, // stride + int p0); // padding + + // the result will have 2*p0 padding for the first dimension + // and 2*p1 padding for the second dimension + GGML_API struct ggml_tensor * ggml_pool_2d( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_op_pool op, + int k0, + int k1, + int s0, + int s1, + float p0, + float p1); + + // nearest interpolate + // used in stable-diffusion + GGML_API struct ggml_tensor * ggml_upscale( + struct ggml_context * ctx, + struct ggml_tensor * a, + int scale_factor); + + GGML_API struct ggml_tensor * ggml_flash_attn( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + bool masked); + + GGML_API struct ggml_tensor * ggml_flash_attn_back( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * d, + bool masked); + + GGML_API struct ggml_tensor * ggml_flash_ff( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b0, + struct ggml_tensor * b1, + struct ggml_tensor * c0, + struct ggml_tensor * c1); + + // partition into non-overlapping windows with padding if needed + // example: + // a: 768 64 64 1 + // w: 14 + // res: 768 14 14 25 + // used in sam + GGML_API struct ggml_tensor * ggml_win_part( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w); + + // reverse of ggml_win_part + // used in sam + GGML_API struct ggml_tensor * ggml_win_unpart( + struct ggml_context * ctx, + struct ggml_tensor * a, + int w0, + int h0, + int w); + + GGML_API struct ggml_tensor * ggml_unary( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op); + + GGML_API struct ggml_tensor * ggml_unary_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_unary_op op); + + // used in sam + GGML_API struct ggml_tensor * ggml_get_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + int qh, + int kh); + + // used in sam + + GGML_API struct ggml_tensor * ggml_add_rel_pos( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph); + + GGML_API struct ggml_tensor * ggml_add_rel_pos_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * pw, + struct ggml_tensor * ph); + + // custom operators + + typedef void (*ggml_unary_op_f32_t) (const int, float *, const float *); + typedef void (*ggml_binary_op_f32_t)(const int, float *, const float *, const float *); + + typedef void (*ggml_custom1_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *); + typedef void (*ggml_custom2_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); + typedef void (*ggml_custom3_op_f32_t)(struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *, const struct ggml_tensor *); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_unary_op_f32_t fun), + "use ggml_map_custom1 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_unary_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_unary_op_f32_t fun), + "use ggml_map_custom1_inplace instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_binary_op_f32_t fun), + "use ggml_map_custom2 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_binary_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_binary_op_f32_t fun), + "use ggml_map_custom2_inplace instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_custom1_op_f32_t fun), + "use ggml_map_custom1 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom1_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_custom1_op_f32_t fun), + "use ggml_map_custom1_inplace instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_custom2_op_f32_t fun), + "use ggml_map_custom2 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom2_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_custom2_op_f32_t fun), + "use ggml_map_custom2_inplace instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + ggml_custom3_op_f32_t fun), + "use ggml_map_custom3 instead"); + + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_map_custom3_inplace_f32( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + ggml_custom3_op_f32_t fun), + "use ggml_map_custom3_inplace instead"); + + // custom operators v2 + + typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); + typedef void (*ggml_custom2_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, int ith, int nth, void * userdata); + typedef void (*ggml_custom3_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, const struct ggml_tensor * b, const struct ggml_tensor * c, int ith, int nth, void * userdata); + + #define GGML_N_TASKS_MAX -1 + + GGML_API struct ggml_tensor * ggml_map_custom1( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_custom1_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom1_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + ggml_custom1_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom2( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_custom2_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom2_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + ggml_custom2_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom3( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + ggml_custom3_op_t fun, + int n_tasks, + void * userdata); + + GGML_API struct ggml_tensor * ggml_map_custom3_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c, + ggml_custom3_op_t fun, + int n_tasks, + void * userdata); + + // loss function + + GGML_API struct ggml_tensor * ggml_cross_entropy_loss( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b); + + GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c); + + // + // automatic differentiation + // + + GGML_API void ggml_set_param( + struct ggml_context * ctx, + struct ggml_tensor * tensor); + + + GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); + GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep); + + // graph allocation in a context + GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false + GGML_API struct ggml_cgraph * ggml_new_graph_custom (struct ggml_context * ctx, size_t size, bool grads); + GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph); + GGML_API struct ggml_cgraph * ggml_graph_view (struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i0, int i1); + GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst); + GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads + GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph); + + GGML_API size_t ggml_graph_overhead(void); + GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads); + + // ggml_graph_plan() has to be called before ggml_graph_compute() + // when plan.work_size > 0, caller must allocate memory for plan.work_data + GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/); + GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); + + // same as ggml_graph_compute() but the work data is allocated as a part of the context + // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data + GGML_API void ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads); + + GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name); + + GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname); + GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval); + + // print info and performance information for the graph + GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph); + + // dump the graph into a file using the dot format + GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); + + // build gradient checkpointing backward graph gb for gf using provided checkpoints + // gb_tmp will contain original backward graph with rewritten backward process nodes, + // but without the second forward pass nodes. + GGML_API void ggml_build_backward_gradient_checkpointing( + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + struct ggml_cgraph * gb_tmp, + struct ggml_tensor * * checkpoints, + int n_checkpoints); + // + // optimization + // + + // optimization methods + enum ggml_opt_type { + GGML_OPT_ADAM, + GGML_OPT_LBFGS, + }; + + // linesearch methods + enum ggml_linesearch { + GGML_LINESEARCH_DEFAULT = 1, + + GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0, + GGML_LINESEARCH_BACKTRACKING_WOLFE = 1, + GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2, + }; + + // optimization return values + enum ggml_opt_result { + GGML_OPT_OK = 0, + GGML_OPT_DID_NOT_CONVERGE, + GGML_OPT_NO_CONTEXT, + GGML_OPT_INVALID_WOLFE, + GGML_OPT_FAIL, + GGML_OPT_CANCEL, + + GGML_LINESEARCH_FAIL = -128, + GGML_LINESEARCH_MINIMUM_STEP, + GGML_LINESEARCH_MAXIMUM_STEP, + GGML_LINESEARCH_MAXIMUM_ITERATIONS, + GGML_LINESEARCH_INVALID_PARAMETERS, + }; + + typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel); + typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); + + // optimization parameters + // + // see ggml.c (ggml_opt_default_params) for default values + // + struct ggml_opt_params { + enum ggml_opt_type type; + + size_t graph_size; + + int n_threads; + + // delta-based convergence test + // + // if past == 0 - disabled + // if past > 0: + // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|) + // + int past; + float delta; + + // maximum number of iterations without improvement + // + // if 0 - disabled + // if > 0: + // assume convergence if no cost improvement in this number of iterations + // + int max_no_improvement; + + bool print_forward_graph; + bool print_backward_graph; + + int n_gradient_accumulation; + + // ADAM parameters + struct { + int n_iter; + + float sched; // schedule multiplier (fixed, decay or warmup) + float decay; // weight decay for AdamW, use 0.0f to disable + int decay_min_ndim; // minimum number of tensor dimension to apply weight decay + float alpha; // learning rate + float beta1; + float beta2; + float eps; // epsilon for numerical stability + float eps_f; // epsilon for convergence test + float eps_g; // epsilon for convergence test + float gclip; // gradient clipping + } adam; + + // LBFGS parameters + struct { + int m; // number of corrections to approximate the inv. Hessian + int n_iter; + int max_linesearch; + + float eps; // convergence tolerance + float ftol; // line search tolerance + float wolfe; + float min_step; + float max_step; + + enum ggml_linesearch linesearch; + } lbfgs; + }; + + struct ggml_opt_context { + struct ggml_context * ctx; + struct ggml_opt_params params; + + int iter; + int64_t nx; // number of parameter elements + + bool just_initialized; + + float loss_before; + float loss_after; + + struct { + struct ggml_tensor * g; // current gradient + struct ggml_tensor * m; // first moment + struct ggml_tensor * v; // second moment + struct ggml_tensor * pf; // past function values + float fx_best; + float fx_prev; + int n_no_improvement; + } adam; + + struct { + struct ggml_tensor * x; // current parameters + struct ggml_tensor * xp; // previous parameters + struct ggml_tensor * g; // current gradient + struct ggml_tensor * gp; // previous gradient + struct ggml_tensor * d; // search direction + struct ggml_tensor * pf; // past function values + struct ggml_tensor * lmal; // the L-BFGS memory alpha + struct ggml_tensor * lmys; // the L-BFGS memory ys + struct ggml_tensor * lms; // the L-BFGS memory s + struct ggml_tensor * lmy; // the L-BFGS memory y + float fx_best; + float step; + int j; + int k; + int end; + int n_no_improvement; + } lbfgs; + }; + + GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type); + + // optimize the function defined by the tensor f + GGML_API enum ggml_opt_result ggml_opt( + struct ggml_context * ctx, + struct ggml_opt_params params, + struct ggml_tensor * f); + + // initialize optimizer context + GGML_API void ggml_opt_init( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_opt_params params, + int64_t nx); + + // continue optimizing the function defined by the tensor f + GGML_API enum ggml_opt_result ggml_opt_resume( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_tensor * f); + + // continue optimizing the function defined by the tensor f + GGML_API enum ggml_opt_result ggml_opt_resume_g( + struct ggml_context * ctx, + struct ggml_opt_context * opt, + struct ggml_tensor * f, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data); + + // + // quantization + // + + // TODO: these would probably get removed in favor of the more general ggml_quantize_chunk + GGML_API size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q8_0(const float * src, void * dst, int n, int k, int64_t * hist); + + GGML_API size_t ggml_quantize_q2_K(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q3_K(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q4_K(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q5_K(const float * src, void * dst, int n, int k, int64_t * hist); + GGML_API size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist); + + GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); + + // + // gguf + // + + enum gguf_type { + GGUF_TYPE_UINT8 = 0, + GGUF_TYPE_INT8 = 1, + GGUF_TYPE_UINT16 = 2, + GGUF_TYPE_INT16 = 3, + GGUF_TYPE_UINT32 = 4, + GGUF_TYPE_INT32 = 5, + GGUF_TYPE_FLOAT32 = 6, + GGUF_TYPE_BOOL = 7, + GGUF_TYPE_STRING = 8, + GGUF_TYPE_ARRAY = 9, + GGUF_TYPE_UINT64 = 10, + GGUF_TYPE_INT64 = 11, + GGUF_TYPE_FLOAT64 = 12, + GGUF_TYPE_COUNT, // marks the end of the enum + }; + + struct gguf_context; + + struct gguf_init_params { + bool no_alloc; + + // if not NULL, create a ggml_context and allocate the tensor data in it + struct ggml_context ** ctx; + }; + + GGML_API struct gguf_context * gguf_init_empty(void); + GGML_API struct gguf_context * gguf_init_empty_sparse(void); + GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); + //GGML_API struct gguf_context * gguf_init_from_buffer(..); + + GGML_API void gguf_free(struct gguf_context * ctx); + + GGML_API const char * gguf_type_name(enum gguf_type type); + + GGML_API int gguf_get_version (const struct gguf_context * ctx); + GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx); + GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); + GGML_API void * gguf_get_data (const struct gguf_context * ctx); + + GGML_API int gguf_get_n_kv(const struct gguf_context * ctx); + GGML_API int gguf_find_key(const struct gguf_context * ctx, const char * key); + GGML_API const char * gguf_get_key (const struct gguf_context * ctx, int key_id); + + GGML_API enum gguf_type gguf_get_kv_type (const struct gguf_context * ctx, int key_id); + GGML_API enum gguf_type gguf_get_arr_type(const struct gguf_context * ctx, int key_id); + + // will abort if the wrong type is used for the key + GGML_API uint8_t gguf_get_val_u8 (const struct gguf_context * ctx, int key_id); + GGML_API int8_t gguf_get_val_i8 (const struct gguf_context * ctx, int key_id); + GGML_API uint16_t gguf_get_val_u16 (const struct gguf_context * ctx, int key_id); + GGML_API int16_t gguf_get_val_i16 (const struct gguf_context * ctx, int key_id); + GGML_API uint32_t gguf_get_val_u32 (const struct gguf_context * ctx, int key_id); + GGML_API int32_t gguf_get_val_i32 (const struct gguf_context * ctx, int key_id); + GGML_API float gguf_get_val_f32 (const struct gguf_context * ctx, int key_id); + GGML_API uint64_t gguf_get_val_u64 (const struct gguf_context * ctx, int key_id); + GGML_API int64_t gguf_get_val_i64 (const struct gguf_context * ctx, int key_id); + GGML_API double gguf_get_val_f64 (const struct gguf_context * ctx, int key_id); + GGML_API bool gguf_get_val_bool(const struct gguf_context * ctx, int key_id); + GGML_API const char * gguf_get_val_str (const struct gguf_context * ctx, int key_id); + GGML_API int gguf_get_arr_n (const struct gguf_context * ctx, int key_id); + GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id); + GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i); + + GGML_API enum ggml_sparse_deriv gguf_get_sparse_deriv(const struct gguf_context * ctx); + GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx); + GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name); + GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i); + GGML_API char * gguf_get_tensor_name (const struct gguf_context * ctx, int i); + + // overrides existing values or adds a new one + GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val); + GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val); + GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val); + GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val); + GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val); + GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val); + GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val); + GGML_API void gguf_set_val_u64 (struct gguf_context * ctx, const char * key, uint64_t val); + GGML_API void gguf_set_val_i64 (struct gguf_context * ctx, const char * key, int64_t val); + GGML_API void gguf_set_val_f64 (struct gguf_context * ctx, const char * key, double val); + GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val); + GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val); + GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n); + GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n); + + // set or add KV pairs from another context + GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src); + + // manage tensor info + GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor); + GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type); + GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size); + + // writing gguf files can be done in 2 ways: + // + // - write the entire gguf_context to a binary file in a single pass: + // + // gguf_write_to_file(ctx, fname); + // + // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data: + // + // FILE * f = fopen(fname, "wb"); + // fseek(f, gguf_get_meta_size(ctx), SEEK_SET); + // fwrite(f, ...); + // void * data = gguf_meta_get_meta_data(ctx); + // fseek(f, 0, SEEK_SET); + // fwrite(f, data, gguf_get_meta_size(ctx)); + // free(data); + // fclose(f); + // + + // write the entire context to a binary file + GGML_API void gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta); + + // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding + GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx); + GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data); + + // + // system info + // + + GGML_API int ggml_cpu_has_avx (void); + GGML_API int ggml_cpu_has_avx2 (void); + GGML_API int ggml_cpu_has_avx512 (void); + GGML_API int ggml_cpu_has_avx512_vbmi(void); + GGML_API int ggml_cpu_has_avx512_vnni(void); + GGML_API int ggml_cpu_has_fma (void); + GGML_API int ggml_cpu_has_neon (void); + GGML_API int ggml_cpu_has_arm_fma (void); + GGML_API int ggml_cpu_has_metal (void); + GGML_API int ggml_cpu_has_f16c (void); + GGML_API int ggml_cpu_has_fp16_va (void); + GGML_API int ggml_cpu_has_wasm_simd (void); + GGML_API int ggml_cpu_has_blas (void); + GGML_API int ggml_cpu_has_cublas (void); + GGML_API int ggml_cpu_has_clblast (void); + GGML_API int ggml_cpu_has_gpublas (void); + GGML_API int ggml_cpu_has_sse3 (void); + GGML_API int ggml_cpu_has_ssse3 (void); + GGML_API int ggml_cpu_has_vsx (void); + + // + // global variables + // + // TODO: these should be moved to the context + extern float sparse_pred_threshold; + + // + // Internal types and functions exposed for tests and benchmarks + // + +#ifdef __cplusplus +// restrict not standard in C++ +#define GGML_RESTRICT +#else +#define GGML_RESTRICT restrict +#endif + typedef void (*ggml_to_float_t) (const void * GGML_RESTRICT x, float * GGML_RESTRICT y, int k); + typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int k); + typedef void (*ggml_vec_dot_t) (const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT x, const void * GGML_RESTRICT y); + + typedef struct { + const char * type_name; + int blck_size; + size_t type_size; + bool is_quantized; + ggml_to_float_t to_float; + ggml_from_float_t from_float; + ggml_from_float_t from_float_reference; + ggml_vec_dot_t vec_dot; + enum ggml_type vec_dot_type; + } ggml_type_traits_t; + + GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type); + +#ifdef __cplusplus +} +#endif diff --git a/spm-headers/llama.h b/spm-headers/llama.h deleted file mode 120000 index 9acceb98..00000000 --- a/spm-headers/llama.h +++ /dev/null @@ -1 +0,0 @@ -../llama.h \ No newline at end of file diff --git a/spm-headers/llama.h b/spm-headers/llama.h new file mode 100644 index 00000000..aeabaa27 --- /dev/null +++ b/spm-headers/llama.h @@ -0,0 +1,788 @@ +#ifndef LLAMA_H +#define LLAMA_H + +#include "ggml.h" +#ifdef GGML_USE_CUBLAS +#include "ggml-cuda.h" +#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES +#else +#define LLAMA_MAX_DEVICES 1 +#endif // GGML_USE_CUBLAS +#include +#include +#include +#include + +#ifdef LLAMA_SHARED +# if defined(_WIN32) && !defined(__MINGW32__) +# ifdef LLAMA_BUILD +# define LLAMA_API __declspec(dllexport) +# else +# define LLAMA_API __declspec(dllimport) +# endif +# else +# define LLAMA_API __attribute__ ((visibility ("default"))) +# endif +#else +# define LLAMA_API +#endif + +#ifdef __GNUC__ +# define DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define DEPRECATED(func, hint) func +#endif + +#define LLAMA_DEFAULT_SEED 0xFFFFFFFF + +#define LLAMA_MAX_RNG_STATE (64*1024) + +#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' + +#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN +#define LLAMA_SESSION_VERSION 2 + +#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) +// Defined when llama.cpp is compiled with support for offloading model layers to GPU. +#define LLAMA_SUPPORTS_GPU_OFFLOAD +#endif + +#ifdef __cplusplus +extern "C" { +#endif + + // + // C interface + // + // TODO: show sample usage + // + + struct llama_model; + struct llama_context; + + typedef int32_t llama_pos; + typedef int32_t llama_token; + typedef int32_t llama_seq_id; + + enum llama_vocab_type { + LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece + LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding + }; + + enum llama_token_type { + LLAMA_TOKEN_TYPE_UNDEFINED = 0, + LLAMA_TOKEN_TYPE_NORMAL = 1, + LLAMA_TOKEN_TYPE_UNKNOWN = 2, + LLAMA_TOKEN_TYPE_CONTROL = 3, + LLAMA_TOKEN_TYPE_USER_DEFINED = 4, + LLAMA_TOKEN_TYPE_UNUSED = 5, + LLAMA_TOKEN_TYPE_BYTE = 6, + }; + + // model file types + enum llama_ftype { + LLAMA_FTYPE_ALL_F32 = 0, + LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 + // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed + // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed + LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q2_K = 10, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_S = 11, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_M = 12, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_L = 13, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_S = 14, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_M = 15, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors + + LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file + }; + + enum llama_rope_scaling_type { + LLAMA_ROPE_SCALING_UNSPECIFIED = -1, + LLAMA_ROPE_SCALING_NONE = 0, + LLAMA_ROPE_SCALING_LINEAR = 1, + LLAMA_ROPE_SCALING_YARN = 2, + LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN, + }; + + typedef struct llama_token_data { + llama_token id; // token id + float logit; // log-odds of the token + float p; // probability of the token + } llama_token_data; + + typedef struct llama_token_data_array { + llama_token_data * data; + size_t size; + bool sorted; + } llama_token_data_array; + + typedef void (*llama_progress_callback)(float progress, void *ctx); + + // Input data for llama_decode + // A llama_batch object can contain input about one or many sequences + // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens + // + // - token : the token ids of the input (used when embd is NULL) + // - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) + // - pos : the positions of the respective token in the sequence + // - seq_id : the sequence to which the respective token belongs + // - logits : if zero, the logits for the respective token will not be output + // + typedef struct llama_batch { + int32_t n_tokens; + + llama_token * token; + float * embd; + llama_pos * pos; + int32_t * n_seq_id; + llama_seq_id ** seq_id; + int8_t * logits; + + // NOTE: helpers for smooth API transition - can be deprecated in the future + // for future-proof code, use the above fields instead and ignore everything below + // + // pos[i] = all_pos_0 + i*all_pos_1 + // + llama_pos all_pos_0; // used if pos == NULL + llama_pos all_pos_1; // used if pos == NULL + llama_seq_id all_seq_id; // used if seq_id == NULL + } llama_batch; + + struct llama_model_params { + int32_t n_gpu_layers; // number of layers to store in VRAM + int32_t main_gpu; // the GPU that is used for scratch and small tensors + float vram_budget_gb; // VRAM budget in GB, -1 for all available VRAM (for a single GPU) + const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) + + // called with a progress value between 0 and 1, pass NULL to disable + llama_progress_callback progress_callback; + // context pointer passed to the progress callback + void * progress_callback_user_data; + + // Keep the booleans together to avoid misalignment during copy-by-value. + bool vocab_only; // only load the vocabulary, no weights + bool use_mmap; // use mmap if possible + bool use_mlock; // force system to keep model in RAM + bool reset_gpu_index; // force reset of the GPU index + bool disable_gpu_index; // bypass the GPU index and FFN split + }; + + struct llama_context_params { + uint32_t seed; // RNG seed, -1 for random + uint32_t n_ctx; // text context, 0 = from model + uint32_t n_batch; // prompt processing maximum batch size + uint32_t n_threads; // number of threads to use for generation + uint32_t n_threads_batch; // number of threads to use for batch processing + int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` + + // ref: https://github.com/ggerganov/llama.cpp/pull/2054 + float rope_freq_base; // RoPE base frequency, 0 = from model + float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model + float yarn_ext_factor; // YaRN extrapolation mix factor, NaN = from model + float yarn_attn_factor; // YaRN magnitude scaling factor + float yarn_beta_fast; // YaRN low correction dim + float yarn_beta_slow; // YaRN high correction dim + uint32_t yarn_orig_ctx; // YaRN original context size + + // Keep the booleans together to avoid misalignment during copy-by-value. + bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true) + bool f16_kv; // use fp16 for KV cache, fp32 otherwise + bool logits_all; // the llama_eval() call computes all logits, not just the last one + bool embedding; // embedding mode only + }; + + // model quantization parameters + typedef struct llama_model_quantize_params { + int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() + enum llama_ftype ftype; // quantize to this llama_ftype + bool allow_requantize; // allow quantizing non-f32/f16 tensors + bool quantize_output_tensor; // quantize output.weight + bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored + bool pure; // disable k-quant mixtures and quantize all tensors to the same type + } llama_model_quantize_params; + + // grammar types + struct llama_grammar; + + // grammar element type + enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 6, + }; + + typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point or rule ID + } llama_grammar_element; + + // performance timing information + struct llama_timings { + double t_start_ms; + double t_end_ms; + double t_load_ms; + double t_sample_ms; + double t_p_eval_ms; + double t_eval_ms; + + int32_t n_sample; + int32_t n_p_eval; + int32_t n_eval; + }; + + // Helpers for getting default parameters + LLAMA_API struct llama_model_params llama_model_default_params(void); + LLAMA_API struct llama_context_params llama_context_default_params(void); + LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); + + // Initialize the llama + ggml backend + // If numa is true, use NUMA optimizations + // Call once at the start of the program + LLAMA_API void llama_backend_init(bool numa); + + // Call once at the end of the program - currently only used for MPI + LLAMA_API void llama_backend_free(void); + + LLAMA_API struct llama_model * llama_load_model_from_file( + const char * path_model, + struct llama_model_params params); + + LLAMA_API struct llama_model * llama_load_model_from_file_with_context( + const char * path_model, + struct llama_model_params params, + struct llama_context_params * cparams); + + LLAMA_API void llama_free_model(struct llama_model * model); + + LLAMA_API struct llama_context * llama_new_context_with_model( + struct llama_model * model, + struct llama_context_params params); + + // Frees all allocated memory + LLAMA_API void llama_free(struct llama_context * ctx); + + LLAMA_API int64_t llama_time_us(void); + + LLAMA_API int llama_max_devices (void); + LLAMA_API bool llama_mmap_supported (void); + LLAMA_API bool llama_mlock_supported(void); + + LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); + + LLAMA_API int llama_n_ctx (const struct llama_context * ctx); + + LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); + LLAMA_API bool llama_use_sparse_inference(const struct llama_model * model); + + LLAMA_API int llama_n_vocab (const struct llama_model * model); + LLAMA_API int llama_n_ctx_train(const struct llama_model * model); + LLAMA_API int llama_n_embd (const struct llama_model * model); + + // Get the model's RoPE frequency scaling factor + LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); + + // Get a string describing the model type + LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size); + + // Returns the total size of all the tensors in the model in bytes + LLAMA_API uint64_t llama_model_size(const struct llama_model * model); + + // Returns the total number of parameters in the model + LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); + + // Get a llama model tensor + LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name); + + // Returns 0 on success + LLAMA_API int llama_model_quantize( + const char * fname_inp, + const char * fname_out, + const llama_model_quantize_params * params); + + // Reserve KV cache in VRAM. This is an optimization to allocate KV cache before + // FFN layers being split and offloaded to GPU. + LLAMA_API void llama_reserve_model_kv_cache(struct llama_model * model, const struct llama_context_params * cparams); + + // Apply a LoRA adapter to a loaded model + // path_base_model is the path to a higher quality model to use as a base for + // the layers modified by the adapter. Can be NULL to use the current loaded model. + // The model needs to be reloaded before applying a new adapter, otherwise the adapter + // will be applied on top of the previous one + // Returns 0 on success + LLAMA_API DEPRECATED(int llama_apply_lora_from_file( + struct llama_context * ctx, + const char * path_lora, + float scale, + const char * path_base_model, + int n_threads), + "use llama_model_apply_lora_from_file instead"); + + LLAMA_API int llama_model_apply_lora_from_file( + const struct llama_model * model, + const char * path_lora, + float scale, + const char * path_base_model, + int n_threads); + + LLAMA_API size_t llama_model_offload_ffn_split(struct llama_model * model); + + // + // KV cache + // + + // Returns the number of tokens in the KV cache + LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx), + "avoid using this, it will be removed in the future, instead - count the tokens in user code"); + + // Clear the KV cache + LLAMA_API void llama_kv_cache_clear( + struct llama_context * ctx); + + // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + // seq_id < 0 : match any sequence + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + LLAMA_API void llama_kv_cache_seq_rm( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1); + + // Copy all tokens that belong to the specified sequence to another sequence + // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + LLAMA_API void llama_kv_cache_seq_cp( + struct llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1); + + // Removes all tokens that do not belong to the specified sequence + LLAMA_API void llama_kv_cache_seq_keep( + struct llama_context * ctx, + llama_seq_id seq_id); + + // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) + // If the KV cache is RoPEd, the KV data is updated accordingly + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) + LLAMA_API void llama_kv_cache_seq_shift( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta); + + // + // State / sessions + // + + // Returns the maximum size in bytes of the state (rng, logits, embedding + // and kv_cache) - will often be smaller after compacting tokens + LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); + + // Copies the state to the specified destination address. + // Destination needs to have allocated enough memory. + // Returns the number of bytes copied + LLAMA_API size_t llama_copy_state_data( + struct llama_context * ctx, + uint8_t * dst); + + // Set the state reading from the specified address + // Returns the number of bytes read + LLAMA_API size_t llama_set_state_data( + struct llama_context * ctx, + uint8_t * src); + + // Save/load session file + LLAMA_API bool llama_load_session_file( + struct llama_context * ctx, + const char * path_session, + llama_token * tokens_out, + size_t n_token_capacity, + size_t * n_token_count_out); + + LLAMA_API bool llama_save_session_file( + struct llama_context * ctx, + const char * path_session, + const llama_token * tokens, + size_t n_token_count); + + // + // Decoding + // + + // Run the llama inference to obtain the logits and probabilities for the next token(s). + // tokens + n_tokens is the provided batch of new tokens to process + // n_past is the number of tokens to use from previous eval calls + // Returns 0 on success + // DEPRECATED: use llama_decode() instead + LLAMA_API DEPRECATED(int llama_eval( + struct llama_context * ctx, + llama_token * tokens, + int32_t n_tokens, + int n_past), + "use llama_decode() instead"); + + // Same as llama_eval, but use float matrix input directly. + // DEPRECATED: use llama_decode() instead + LLAMA_API DEPRECATED(int llama_eval_embd( + struct llama_context * ctx, + float * embd, + int32_t n_tokens, + int n_past), + "use llama_decode() instead"); + + // Return batch for single sequence of tokens starting at pos_0 + // + // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it + // + LLAMA_API struct llama_batch llama_batch_get_one( + llama_token * tokens, + int32_t n_tokens, + llama_pos pos_0, + llama_seq_id seq_id); + + // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens + // Each token can be assigned up to n_seq_max sequence ids + // The batch has to be freed with llama_batch_free() + // If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) + // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token + // The rest of the llama_batch members are allocated with size n_tokens + // All members are left uninitialized + LLAMA_API struct llama_batch llama_batch_init( + int32_t n_tokens, + int32_t embd, + int32_t n_seq_max); + + // Frees a batch of tokens allocated with llama_batch_init() + LLAMA_API void llama_batch_free(struct llama_batch batch); + + // Positive return values does not mean a fatal error, but rather a warning. + // 0 - success + // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) + // < 0 - error + LLAMA_API int llama_decode( + struct llama_context * ctx, + struct llama_batch batch); + + // Set the number of threads used for decoding + // n_threads is the number of threads used for generation (single token) + // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) + LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); + + // Token logits obtained from the last call to llama_eval() + // The logits for the last token are stored in the last row + // Logits for which llama_batch.logits[i] == 0 are undefined + // Rows: n_tokens provided with llama_batch + // Cols: n_vocab + LLAMA_API float * llama_get_logits(struct llama_context * ctx); + + // Logits for the ith token. Equivalent to: + // llama_get_logits(ctx) + i*n_vocab + LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); + + // Get the embeddings for the input + // shape: [n_embd] (1-dimensional) + LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); + + // + // Vocab + // + + LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token); + + LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); + + LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token); + + // Special tokens + LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence + LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence + LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line + + // codellama infill tokens + LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix + LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle + LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix + LLAMA_API llama_token llama_token_eot (const struct llama_model * model); // End of infill middle + + // + // Tokenization + // + + /// @details Convert the provided text into tokens. + /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. + /// @return Returns the number of tokens on success, no more than n_max_tokens + /// @return Returns a negative number on failure - the number of tokens that would have been returned + /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. + /// Does not insert a leading space. + LLAMA_API int llama_tokenize( + const struct llama_model * model, + const char * text, + int text_len, + llama_token * tokens, + int n_max_tokens, + bool add_bos, + bool special); + + // Token Id -> Piece. + // Uses the vocabulary in the provided context. + // Does not write null terminator to the buffer. + // User code is responsible to remove the leading whitespace of the first non-BOS token when decoding multiple tokens. + LLAMA_API int llama_token_to_piece( + const struct llama_model * model, + llama_token token, + char * buf, + int length); + + // + // Grammar + // + + LLAMA_API struct llama_grammar * llama_grammar_init( + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + + LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); + + LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar); + + // + // Sampling functions + // + + // Sets the current rng seed. + LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); + + /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. + /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + LLAMA_API void llama_sample_repetition_penalties( + struct llama_context * ctx, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present); + + /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. + /// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. + /// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. + LLAMA_API void llama_sample_classifier_free_guidance( + struct llama_context * ctx, + llama_token_data_array * candidates, + struct llama_context * guidance_ctx, + float scale); + + /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. + LLAMA_API void llama_sample_softmax( + struct llama_context * ctx, + llama_token_data_array * candidates); + + /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + LLAMA_API void llama_sample_top_k( + struct llama_context * ctx, + llama_token_data_array * candidates, + int k, + size_t min_keep); + + /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + LLAMA_API void llama_sample_top_p( + struct llama_context * ctx, + llama_token_data_array * candidates, + float p, + size_t min_keep); + + /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 + LLAMA_API void llama_sample_min_p( + struct llama_context * ctx, + llama_token_data_array * candidates, + float p, + size_t min_keep); + + /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. + LLAMA_API void llama_sample_tail_free( + struct llama_context * ctx, + llama_token_data_array * candidates, + float z, + size_t min_keep); + + /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. + LLAMA_API void llama_sample_typical( + struct llama_context * ctx, + llama_token_data_array * candidates, + float p, + size_t min_keep); + + LLAMA_API void llama_sample_temp( + struct llama_context * ctx, + llama_token_data_array * candidates, + float temp); + + LLAMA_API DEPRECATED(void llama_sample_temperature( + struct llama_context * ctx, + llama_token_data_array * candidates, + float temp), + "use llama_sample_temp instead"); + + /// @details Apply constraints from grammar + LLAMA_API void llama_sample_grammar( + struct llama_context * ctx, + llama_token_data_array * candidates, + const struct llama_grammar * grammar); + + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + LLAMA_API llama_token llama_sample_token_mirostat( + struct llama_context * ctx, + llama_token_data_array * candidates, + float tau, + float eta, + int m, + float * mu); + + /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + LLAMA_API llama_token llama_sample_token_mirostat_v2( + struct llama_context * ctx, + llama_token_data_array * candidates, + float tau, + float eta, + float * mu); + + /// @details Selects the token with the highest probability. + /// Does not compute the token probabilities. Use llama_sample_softmax() instead. + LLAMA_API llama_token llama_sample_token_greedy( + struct llama_context * ctx, + llama_token_data_array * candidates); + + /// @details Randomly selects a token from the candidates based on their probabilities. + LLAMA_API llama_token llama_sample_token( + struct llama_context * ctx, + llama_token_data_array * candidates); + + /// @details Accepts the sampled token into the grammar + LLAMA_API void llama_grammar_accept_token( + struct llama_context * ctx, + struct llama_grammar * grammar, + llama_token token); + + // + // Beam search + // + + struct llama_beam_view { + const llama_token * tokens; + + size_t n_tokens; + float p; // Cumulative beam probability (renormalized relative to all beams) + bool eob; // Callback should set this to true when a beam is at end-of-beam. + }; + + // Passed to beam_search_callback function. + // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams + // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks. + // These pointers are valid only during the synchronous callback, so should not be saved. + struct llama_beams_state { + struct llama_beam_view * beam_views; + + size_t n_beams; // Number of elements in beam_views[]. + size_t common_prefix_length; // Current max length of prefix tokens shared by all beams. + bool last_call; // True iff this is the last callback invocation. + }; + + // Type of pointer to the beam_search_callback function. + // void* callback_data is any custom data passed to llama_beam_search, that is subsequently + // passed back to beam_search_callback. This avoids having to use global variables in the callback. + typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state); + + /// @details Deterministically returns entire sentence constructed by a beam search. + /// @param ctx Pointer to the llama_context. + /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state. + /// @param callback_data A pointer that is simply passed back to callback. + /// @param n_beams Number of beams to use. + /// @param n_past Number of tokens already evaluated. + /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier. + LLAMA_API void llama_beam_search( + struct llama_context * ctx, + llama_beam_search_callback_fn_t callback, + void * callback_data, + size_t n_beams, + int n_past, + int n_predict); + + // Performance information + LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); + + LLAMA_API void llama_print_timings(struct llama_context * ctx); + LLAMA_API void llama_reset_timings(struct llama_context * ctx); + + // Print system information + LLAMA_API const char * llama_print_system_info(void); + + // Set callback for all future logging events. + // If this is not called, or NULL is supplied, everything is output on stderr. + LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); + + LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx); + +#ifdef __cplusplus +} +#endif + +// Internal API to be implemented by llama.cpp and used by tests/benchmarks only +#ifdef LLAMA_API_INTERNAL + +#include +#include + +struct ggml_tensor; + +const std::vector> & llama_internal_get_tensor_map( + struct llama_context * ctx +); + +#endif // LLAMA_API_INTERNAL + +#endif // LLAMA_H diff --git a/tests/test-tokenizer-0-falcon.py b/tests/test-tokenizer-0-falcon.py index cf65a3f6..dbf38988 100644 --- a/tests/test-tokenizer-0-falcon.py +++ b/tests/test-tokenizer-0-falcon.py @@ -8,7 +8,7 @@ parser = argparse.ArgumentParser() parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file") -parser.add_argument("--fname-tok", help="path to a text file to tokenize") +parser.add_argument("--fname-tok", help="path to a text file to tokenize") args = parser.parse_args() dir_tokenizer = args.dir_tokenizer @@ -16,37 +16,37 @@ tokenizer = AutoTokenizer.from_pretrained(dir_tokenizer) tests = [ - "", - " ", - " ", - " ", - "\t", - "\n", - "\t\n", - "Hello world", - " Hello world", - "Hello World", - " Hello World", - " Hello World!", - "Hello, world!", - " Hello, world!", - " this is 🦙.cpp", - "w048 7tuijk dsdfhu", - "нещо на Български", - "កាន់តែពិសេសអាចខលចេញ", - "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", - "Hello", - " Hello", - " Hello", - " Hello", - " Hello", - " Hello\n Hello", - "\n =", - "' era", - ] + "", + " ", + " ", + " ", + "\t", + "\n", + "\t\n", + "Hello world", + " Hello world", + "Hello World", + " Hello World", + " Hello World!", + "Hello, world!", + " Hello, world!", + " this is 🦙.cpp", + "w048 7tuijk dsdfhu", + "нещо на Български", + "កាន់តែពិសេសអាចខលចេញ", + "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", + "Hello", + " Hello", + " Hello", + " Hello", + " Hello", + " Hello\n Hello", + "\n =", + "' era", +] for text in tests: - print('text: ', text) + print("text: ", text) print(tokenizer.encode(text)) print(tokenizer.decode(tokenizer.encode(text))) @@ -54,31 +54,31 @@ for text in tests: res = tokenizer.encode(text) - k = text.replace('\n', '\\n') - k = k.replace('\t', '\\t') + k = text.replace("\n", "\\n") + k = k.replace("\t", "\\t") k = '"' + k + '"' - print("{ %-24s, { " % k, end='') + print("{ %-24s, { " % k, end="") for x in res: - print("%7d," % x, end='') + print("%7d," % x, end="") print(" }, },") -print(tokenizer.encode('hello')) -print(tokenizer.encode('world')) -print(tokenizer.encode(' world')) -print(tokenizer.encode('hello world')) +print(tokenizer.encode("hello")) +print(tokenizer.encode("world")) +print(tokenizer.encode(" world")) +print(tokenizer.encode("hello world")) fname_tok = args.fname_tok if fname_tok: - print('tokenizing file: ', fname_tok) - fname_out = fname_tok + '.tok' - with open(fname_tok, 'r', encoding='utf-8') as f: + print("tokenizing file: ", fname_tok) + fname_out = fname_tok + ".tok" + with open(fname_tok, "r", encoding="utf-8") as f: lines = f.readlines() - s = ''.join(lines) + s = "".join(lines) res = tokenizer.encode(s) # write to file - with open(fname_out, 'w', encoding='utf-8') as f: + with open(fname_out, "w", encoding="utf-8") as f: for x in res: - f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n') - print('len(res): ', len(res)) - print('len(lines): ', len(lines)) - print('results written to: ', fname_out) + f.write(str(x) + " '" + tokenizer.decode(x) + "'\n") + print("len(res): ", len(res)) + print("len(lines): ", len(lines)) + print("results written to: ", fname_out) diff --git a/tests/test-tokenizer-0-llama.py b/tests/test-tokenizer-0-llama.py index 078f680b..ed482136 100644 --- a/tests/test-tokenizer-0-llama.py +++ b/tests/test-tokenizer-0-llama.py @@ -8,87 +8,87 @@ parser = argparse.ArgumentParser() parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file") -parser.add_argument("--fname-tok", help="path to a text file to tokenize") +parser.add_argument("--fname-tok", help="path to a text file to tokenize") args = parser.parse_args() dir_tokenizer = args.dir_tokenizer -tokenizer = SentencePieceProcessor(dir_tokenizer + '/tokenizer.model') +tokenizer = SentencePieceProcessor(dir_tokenizer + "/tokenizer.model") tests = [ - "", - " ", - " ", - " ", - "\t", - "\n", - "\t\n", - "Hello world", - " Hello world", - "Hello World", - " Hello World", - " Hello World!", - "Hello, world!", - " Hello, world!", - " this is 🦙.cpp", - "w048 7tuijk dsdfhu", - "нещо на Български", - "កាន់តែពិសេសអាចខលចេញ", - "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", - "Hello", - " Hello", - " Hello", - " Hello", - " Hello", - " Hello\n Hello", - ] + "", + " ", + " ", + " ", + "\t", + "\n", + "\t\n", + "Hello world", + " Hello world", + "Hello World", + " Hello World", + " Hello World!", + "Hello, world!", + " Hello, world!", + " this is 🦙.cpp", + "w048 7tuijk dsdfhu", + "нещо на Български", + "កាន់តែពិសេសអាចខលចេញ", + "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", + "Hello", + " Hello", + " Hello", + " Hello", + " Hello", + " Hello\n Hello", +] for text in tests: - print('text: ', text) - print('\nwith bos:') + print("text: ", text) + print("\nwith bos:") print(tokenizer.encode(text, add_bos=True)) print(tokenizer.decode(tokenizer.encode(text, add_bos=True))) - print('\nwithout bos:') + print("\nwithout bos:") print(tokenizer.encode(text, add_bos=False)) print(tokenizer.decode(tokenizer.encode(text, add_bos=False))) -print("'" + tokenizer.id_to_piece(15043) + "'") # '_Hello' -print("'" + tokenizer.id_to_piece(29871) + "'") # '_' -print("'" + tokenizer.decode([15043]) + "'") # 'Hello' -print("'" + tokenizer.decode([15043, 15043]) + "'") # 'Hello Hello' -print("'" + tokenizer.decode([29871, 15043]) + "'") # ' Hello' -print("'" + tokenizer.decode([29871, 15043, 29871, 15043]) + "'") # ' Hello Hello' +print("'" + tokenizer.id_to_piece(15043) + "'") # '_Hello' +print("'" + tokenizer.id_to_piece(29871) + "'") # '_' +print("'" + tokenizer.decode([15043]) + "'") # 'Hello' +print("'" + tokenizer.decode([15043, 15043]) + "'") # 'Hello Hello' +print("'" + tokenizer.decode([29871, 15043]) + "'") # ' Hello' +print("'" + tokenizer.decode([29871, 15043, 29871, 15043]) + "'") # ' Hello Hello' print("\n\ntests for C++:\n") for text in tests: res = tokenizer.encode(text, add_bos=False) - k = text.replace('\n', '\\n') - k = k.replace('\t', '\\t') + k = text.replace("\n", "\\n") + k = k.replace("\t", "\\t") k = '"' + k + '"' - print("{ %-24s, { " % k, end='') + print("{ %-24s, { " % k, end="") for x in res: - print("%7d," % x, end='') + print("%7d," % x, end="") print(" }, },") -print(tokenizer.encode('hello')) -print(tokenizer.encode('world')) -print(tokenizer.encode(' world')) -print(tokenizer.encode('hello world')) +print(tokenizer.encode("hello")) +print(tokenizer.encode("world")) +print(tokenizer.encode(" world")) +print(tokenizer.encode("hello world")) fname_tok = args.fname_tok if fname_tok: - print('tokenizing file: ', fname_tok) - fname_out = fname_tok + '.tok' - with open(fname_tok, 'r', encoding='utf-8') as f: + print("tokenizing file: ", fname_tok) + fname_out = fname_tok + ".tok" + with open(fname_tok, "r", encoding="utf-8") as f: lines = f.readlines() - s = ''.join(lines) + s = "".join(lines) res = tokenizer.encode(s, add_bos=True) # write to file - with open(fname_out, 'w', encoding='utf-8') as f: + with open(fname_out, "w", encoding="utf-8") as f: for x in res: - f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n') - print('len(res): ', len(res)) - print('len(lines): ', len(lines)) - print('results written to: ', fname_out) + f.write(str(x) + " '" + tokenizer.decode(x) + "'\n") + print("len(res): ", len(res)) + print("len(lines): ", len(lines)) + print("results written to: ", fname_out)