diff --git a/README.md b/README.md index 55f40d262..4404c24d9 100644 --- a/README.md +++ b/README.md @@ -7,16 +7,16 @@ - +

-`ShinkaEvolve` is a framework that combines Large Language Models (LLMs) with evolutionary algorithms to drive scientific discovery. By leveraging the creative capabilities of LLMs and the optimization power of evolutionary search, `ShinkaEvolve` enables automated exploration and improvement of scientific code. The system is inspired by the [AI Scientist](https://sakana.ai/ai-scientist/), [AlphaEvolve](https://deepmind.google/discover/blog/alphaevolve-a-gemini-powered-coding-agent-for-designing-advanced-algorithms/) and the [Darwin Goedel Machine](https://sakana.ai/dgm/): It maintains a population of programs that evolve over generations, with an ensemble of LLMs acting as intelligent mutation operators that suggest code improvements. +[`ShinkaEvolve`](https://arxiv.org/abs/2509.19349) is a framework that combines Large Language Models (LLMs) with evolutionary algorithms to drive scientific discovery. By leveraging the creative capabilities of LLMs and the optimization power of evolutionary search, `ShinkaEvolve` enables automated exploration and improvement of scientific code. The system is inspired by the [AI Scientist](https://sakana.ai/ai-scientist/), [AlphaEvolve](https://deepmind.google/discover/blog/alphaevolve-a-gemini-powered-coding-agent-for-designing-advanced-algorithms/) and the [Darwin Goedel Machine](https://sakana.ai/dgm/): It maintains a population of programs that evolve over generations, with an ensemble of LLMs acting as intelligent mutation operators that suggest code improvements. The framework supports **parallel evaluation of candidates** locally or on a Slurm cluster. It maintains an archive of successful solutions, enabling knowledge transfer between different evolutionary islands. `ShinkaEvolve` is particularly well-suited for scientific tasks where there is a verifier available and the goal is to optimize performance metrics while maintaining code correctness and readability. 
-![](docs/conceptual.png) +![evolution](https://github.com/user-attachments/assets/22cf3468-17fe-4995-9e13-d602b490a54e) ## Documentation πŸ“ @@ -26,6 +26,7 @@ The framework supports **parallel evaluation of candidates** locally or on a Slu | πŸ““ **[Tutorial Notebook](examples/shinka_tutorial.ipynb)** | Interactive walkthrough of Shinka features | Hands-on examples, configuration, best practices | | βš™οΈ **[Configuration](docs/configuration.md)** | Comprehensive configuration reference | All config options, optimization settings, advanced features | | 🎨 **[WebUI](docs/webui.md)** | Interactive visualization and monitoring | Real-time tracking, result analysis, debugging tools | +| πŸ•ΉοΈ **[Local LLM Support](https://github.com/SakanaAI/ShinkaEvolve/blob/main/docs/support_local_llm.md)** | Instructions for Local LLMs | How to set up local LLMs on your machine | ## Installation & Quick Start πŸš€ @@ -52,9 +53,9 @@ For detailed installation instructions and usage examples, see the [Getting Star | Example | Description | Environment Setup | |---------|-------------|-------------------| | β­• [Circle Packing](examples/circle_packing) | Optimize circle packing to maximize radii. | `LocalJobConfig` | -| πŸ€– [Agent Design](examples/agent_design) | Design agent scaffolds for math tasks. | `LocalJobConfig` | +| πŸ€– [Agent Design](examples/adas_aime) | Design agent scaffolds for math tasks. | `LocalJobConfig` | | 🎯 [ALE-Bench](examples/ale_bench) | Code optimization for ALE-Bench tasks. | `LocalJobConfig` | -| ✨ [Novelty Generator](examples/novelty_generator_bck) | Generate creative, surprising outputs (e.g., ASCII art). | `LocalJobConfig` | +| ✨ [Novelty Generator](examples/novelty_generator) | Generate creative, surprising outputs (e.g., ASCII art).
| `LocalJobConfig` | ## `shinka` Run with Python API 🐍 @@ -308,9 +309,9 @@ If you use `ShinkaEvolve` in your research, please cite it as follows: ``` @article{lange2025shinka, - title={ShinkaEvolve: Towards Open-Ended and Sample-Efficient Program Evolution}, + title={ShinkaEvolve: Towards Open-Ended And Sample-Efficient Program Evolution}, author={Lange, Robert Tjarko and Imajuku, Yuki and Cetin, Edoardo}, - journal={arXiv preprint}, + journal={arXiv preprint arXiv:2509.19349}, year={2025} } -``` \ No newline at end of file +``` diff --git a/configs/config.yaml b/configs/config.yaml index 9702c6617..577e1dfe2 100644 --- a/configs/config.yaml +++ b/configs/config.yaml @@ -2,9 +2,9 @@ defaults: - _self_ - database@_global_: island_small - evolution@_global_: small_budget - - task@_global_: mad_tf + - task@_global_: circle_packing - cluster@_global_: local - - variant@_global_: mad_tf_example + - variant@_global_: circle_packing_example verbose: false results_dir: results diff --git a/configs/task/autoformalization.yaml b/configs/task/autoformalization.yaml new file mode 100644 index 000000000..56881fbf9 --- /dev/null +++ b/configs/task/autoformalization.yaml @@ -0,0 +1,49 @@ +evaluate_function: + _target_: examples.autoformalization.evaluate.main + program_path: ??? + results_dir: ??? + +distributed_job_config: + _target_: shinka.launch.SlurmCondaJobConfig + modules: + - "cuda/12.4" + - "cudnn/8.9.7" + - "hpcx/2.20" + eval_program_path: "shinka/eval_hydra.py" + conda_env: "shinka" + time: "01:00:00" + cpus: 16 + gpus: 1 + mem: "8G" + +evo_config: + task_sys_msg: | + You are an expert mathematician with a specialization in group theory. You want to prove the following statement: + + "Let $H$ be the subgroup generated by two elements $a, b$ of a group $G$. Prove that if $a b=b a$, then $H$ is an abelian group." + + Your task is to identify the minimal set of variables and assumptions needed to formalize the proof in Lean 4. You do not have to provide the proof itself. 
Therefore, your output is expected to look like: + + ```lean + -- Variables + theorem abelian_group () + -- Assumptions + + : + -- Conjecture + () + := + ``` + + Key components to add: + 1. Introduce all variables you need to complete the proof. + 2. Write all assumptions needed to complete the proof. + 3. Think about a conjecture that reflects the scientific concept you want to prove. + + Make sure your output is valid Lean 4. Do not leave the conjecture empty `()`. Do **not** finish the proof, you can leave the program up to the `:=` sign. + + language: "lean" + init_program_path: "examples/autoformalization/initial.lean" + job_type: "slurm_conda" + +exp_name: "shinka_autoformalization" diff --git a/configs/variant/autoformalization_example.yaml b/configs/variant/autoformalization_example.yaml new file mode 100644 index 000000000..2983ba959 --- /dev/null +++ b/configs/variant/autoformalization_example.yaml @@ -0,0 +1,8 @@ +defaults: + - override /database@_global_: island_small + - override /evolution@_global_: small_budget + - override /task@_global_: autoformalization + - override /cluster@_global_: local + - _self_ + +variant_suffix: "_example" \ No newline at end of file diff --git a/docs/getting_started.md b/docs/getting_started.md index 234158839..03bc54c80 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -2,6 +2,8 @@ Shinka is a framework that combines Large Language Models (LLMs) with evolutionary algorithms to drive scientific discovery. This guide will help you get started with installing, configuring, and running your first evolutionary experiments. +![](../docs/conceptual.png) + ## Table of Contents 1. [What is Shinka?](#what-is-shinka) @@ -53,7 +55,7 @@ pip install uv ```bash git clone -cd shinka +cd ShinkaEvolve # Create virtual environment with Python 3.11 uv venv --python 3.11 @@ -79,7 +81,7 @@ conda activate shinka ```bash git clone -cd shinka +cd ShinkaEvolve pip install -e . 
``` @@ -249,7 +251,7 @@ from shinka.core import run_shinka_eval def main(program_path: str, results_dir: str): """Main evaluation function called by Shinka""" - + metrics, correct, error_msg = run_shinka_eval( program_path=program_path, results_dir=results_dir, @@ -268,11 +270,11 @@ def main(program_path: str, results_dir: str): def validate_packing(run_output): """Returns (is_valid: bool, error_msg: str or None)""" centers, radii, reported_sum = run_output - + # Check constraints (bounds, overlaps, etc.) if constraint_violated: return False, "Specific error description" - + return True, None # Valid solution ``` @@ -280,10 +282,10 @@ def validate_packing(run_output): ```python def aggregate_metrics(results, results_dir): """Returns metrics dictionary with required structure""" - + # Extract data from results centers, radii, reported_sum = results[0] - + return { "combined_score": float(reported_sum), # PRIMARY FITNESS (higher = better) "public": { # Visible in WebUI/logs @@ -331,6 +333,75 @@ The `run_shinka_eval` function returns three values: ## Advanced Usage +### Resuming Experiments + +If you need to pause and resume an evolutionary run, or extend a completed run with more generations, Shinka supports seamless resumption from existing results. 
+ +#### How Resuming Works + +When you specify an existing `results_dir` that contains a database, Shinka will: +- Detect the previous run automatically +- Restore the population database and all program history +- Resume meta-recommendations from the last checkpoint +- Continue from the last completed generation + +#### Using the CLI (Hydra) + +```bash +# Resume an existing run and extend to 50 generations +shinka_launch \ + variant=circle_packing_example \ + evo_config.results_dir=results_20250101_120000 \ + evo_config.num_generations=50 + +# Or with a custom task +shinka_launch \ + task=circle_packing \ + database=island_small \ + evolution=small_budget \ + cluster=local \ + evo_config.results_dir=path/to/previous/results \ + evo_config.num_generations=100 +``` + +#### Using the Python API + +```python +from shinka.core import EvolutionRunner, EvolutionConfig +from shinka.database import DatabaseConfig +from shinka.launch import LocalJobConfig + +# Point to existing results directory +evo_config = EvolutionConfig( + num_generations=50, # Extend to 50 total generations + results_dir="results_20250101_120000", # Existing results + # ... other config parameters ... +) + +job_config = LocalJobConfig( + eval_program_path="examples/circle_packing/evaluate.py", +) + +db_config = DatabaseConfig( + archive_size=20, + num_islands=2, +) + +# Run will automatically detect and resume +runner = EvolutionRunner( + evo_config=evo_config, + job_config=job_config, + db_config=db_config, +) +runner.run() +``` + +**Important Notes:** +- The `num_generations` parameter should be set to the **total** number of generations you want (not additional generations) +- For example, if you completed 20 generations and want 30 more, set `num_generations=50` +- The database configuration (number of islands, archive size, etc.) 
should match the original run +- All previous progress, including the best solutions and meta-recommendations, will be preserved + ### Environment Management for Local Jobs When running jobs locally, you have several options for managing Python environments: diff --git a/docs/support_local_llm.md b/docs/support_local_llm.md new file mode 100644 index 000000000..5f406e7b9 --- /dev/null +++ b/docs/support_local_llm.md @@ -0,0 +1,232 @@ + +# 🧩 Integrating Local LLMs into **ShinkaEvolve** + +## 🧠 Overview + +The original **ShinkaEvolve** code does **not** include built-in support for running **local LLMs**. +To enable this functionality, parts of the codebase can be modified to integrate locally hosted models. + +--- + +## πŸ—οΈ Code Organization + +**ShinkaEvolve** uses a **modular architecture** that supports multiple **LLM providers**. +The relevant code for LLM interaction is located in the **`LLM/`** folder, which manages all model communications. +ShinkaEvolve distinguishes between two LLM types: + +* **Regular LLMs** +* **Embedding LLMs** + +--- + +## βš™οΈ Adding a Regular LLM + +To add support for a **regular LLM**, follow these steps. They walk through an example of adding support for gpt-oss models served with Unsloth, which exposes an OpenAI-compatible API (`v1/completions`). +This LLM can then be specified in the configuration variables: + +```yaml +llm_models: +meta_llm_models: +``` + +--- + +### πŸ”§ Step 1: Modify the Client + +The file **`client.py`** is responsible for creating clients that interact with LLMs. +Each client instance is later used to query a specific model. + +To add a local model, introduce a new client configuration.
+The API URL is extracted from the model name, which follows this format: + +``` +local-gptoss-unsloth-url +``` + +#### Example + +```python +elif "local-gptoss-unsloth" in model_name: + # Extract URL from model name + pattern = r"https?://" + match = re.search(pattern, model_name) + if match: + start_index = match.start() + url = model_name[start_index:] + else: + raise ValueError(f"Invalid URL in model name: {model_name}") + + # Create OpenAI-compatible client + client = openai.OpenAI( + api_key="filler", + base_url=url + ) + + # Structured output mode (if required) + if structured_output: + client = instructor.from_openai( + client, + mode=instructor.Mode.JSON, + ) +``` + +--- + +### πŸ“ Step 2: Create the Local Query Function + +Inside the **`models/`** folder, create a new subfolder to store the query functions for your local models: + +``` +LLM/models/local/ +``` + +> Don’t forget to include an empty `__init__.py` file. + +This folder should contain a **custom query function** for the local model. I called my file local_gptoss_unsloth.py. +It should follow the same structure as other functions in `LLM/models/`, but with small adjustments. + +#### My Key Adjustments + +* Replace `max_output_tokens` with **`max_tokens`** to match the local API. +* Extract additional response metadata such as: + + * `total_tokens` + * `thinking_tokens` (if your model includes reasoning traces) + +This function is later imported and registered in **`query.py`**. + +--- + +### 🧩 Step 3: Update `__init__.py` + +Configure **`__init__.py`** to include and expose the new local query function, so it can be imported elsewhere. 
+ +``` +from .local.local_gptoss_unsloth import query_local_gptoss_unsloth # ADDED THIS LINE +from .result import QueryResult + +__all__ = [ + "query_anthropic", + "query_openai", + "query_deepseek", + "query_gemini", + "query_local_gptoss_unsloth", # ADDED THIS LINE + "QueryResult", +] +``` + +--- + +### πŸ“¬ Step 4: Update `query.py` + +Import and register the new local query function in query.py. + +#### Imports + +```python +from .models import ( + query_anthropic, + query_openai, + query_deepseek, + query_gemini, + query_local_gptoss_unsloth, # ADDED THIS LINE + QueryResult, +) +``` + +#### Model Selection Logic + +```python +elif "local-gptoss-unsloth" in model_name: # ADDED THIS LINE + query_fn = query_local_gptoss_unsloth +``` + +--- + +### 🧠 Step 5: Other Observations + +The file **`query.py`** also defines functions such as: + +* `sample_model_kwargs` +* `sample_batch_kwargs` + +However, these are **not referenced anywhere else** in the repository, so no modifications are required here for now. + +--- + +### βœ… Summary + +| Step | File | Change | Description | +| ---- | -------------------------------------------- | -------------------- | -------------------------------------------------------- | +| 1 | `client.py` | Add new client block | Create OpenAI-compatible client for local LLM | +| 2 | `models/local/query_local_gptoss_unsloth.py` | New function | Query local model, adjust tokens, extract reasoning info | +| 3 | `__init__.py` | Add import | Expose new query function | +| 4 | `query.py` | Register model | Add conditional for local LLM | +| 5 | β€” | Review only | Ignored unused functions | + +--- + +## 🧬 Adding a Local Embedding Model + +For embedding models, you can use **Ollama**, which follows the **OpenAI API** format. +The only relevant file is **`embedding.py`**. 
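
Before wiring anything into `embedding.py`, the `local-<model-name>-<url>` naming convention can be sanity-checked in isolation. A minimal sketch (the model name below is a hypothetical example, not a model shipped with ShinkaEvolve):

```python
import re

# Hypothetical Ollama model name following the local-<model-name>-<url> convention.
model_name = "local-nomic-embed-text-http://localhost:11434/v1"

# Lazy first group: it stops at the first "-" that is followed by http:// or https://.
match = re.match(r"local-(.+?)-(https?://.+)", model_name)
if match is None:
    raise ValueError(f"Invalid local model format: {model_name}")

model_to_use = match.group(1)  # "nomic-embed-text"
url = match.group(2)           # "http://localhost:11434/v1"
```

Because the first group is non-greedy, the model name may itself contain hyphens; only the hyphen directly preceding the URL acts as the separator.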
+ +### Code Addition + +```python +elif model_name.startswith("local-"): + # Pattern: local-(model-name)-(http or https url) + match = re.match(r"local-(.+?)-(https?://.+)", model_name) + if match: + model_to_use = match.group(1) + url = match.group(2) + else: + raise ValueError(f"Invalid local model format: {model_name}") + + client = openai.OpenAI( + base_url=url, + api_key="filler" + ) +``` + +#### Notes + +* Compatible with **any Ollama model**. +* The model name must follow this convention: + + ``` + local-model-name-url + ``` +* The code extracts both `model-name` and `url`, and uses them to query Ollama. + +--- + +### Query Logic + +The existing line in **`embedding.py`** remains unchanged: + +```python +response = self.client.embeddings.create( + model=self.model, + input=code, + encoding_format="float" +) +``` + +For local embedding models, `self.model` corresponds to the extracted model name. +The only addition to the **Embedding Client** class: + +```python +elif self.model_name.startswith("local-"): + cost = 0.0 +``` + +--- + +## πŸš€ Result + +ShinkaEvolve can now connect to **locally hosted LLMs** and **embedding models** through **OpenAI-compatible APIs**. +This setup supports **Ollama** and other frameworks such as **gpt-oss** under **Unsloth**. + +If your model has different requirements, follow the same pattern with a distinct model identifier and your own custom logic. 
+ diff --git a/examples/autoformalization/evaluate.py b/examples/autoformalization/evaluate.py new file mode 100644 index 000000000..e89bd1657 --- /dev/null +++ b/examples/autoformalization/evaluate.py @@ -0,0 +1,207 @@ +import os +import logging +import argparse +from typing import Optional, List, Tuple, Dict, Any + +import numpy as np +from lean_interact import LeanREPLConfig, AutoLeanServer, Command, TempRequireProject, FileCommand +from lean_interact.interface import BaseREPLResponse, LeanError + +from .utils_lean import validate_lean, generate_proof +from shinka.llm.client import get_client_llm +from shinka.core import run_shinka_eval + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def check_lean(path_or_str: str) -> BaseREPLResponse | LeanError: + """ + Run the generated proof through the Lean 4 compiler. + + Args: + path_or_str (str): The path to the proof or the proof itself. + + Returns: + BaseREPLResponse: the output of the Lean compiler. + """ + project = TempRequireProject(lean_version="v4.24.0", require="mathlib") + config = LeanREPLConfig(project=project) + server = AutoLeanServer(config) # start Lean REPL + command = FileCommand(path=path_or_str) if path_or_str.endswith(".lean") else Command(cmd=path_or_str) + server_output = server.run(command) + logger.info(server_output.messages) + return server_output + +def validate_proof(run_output: Tuple[str, Optional[str]]) -> Tuple[bool, Optional[str]]: + """ + Validates the proof generation results based on the output of ``generate_proof``. + + Args: + run_output (Tuple[str, Optional[str]]): the run output, containing + - file_path (str): the path to a Lean 4 file containing an incomplete proof (may include ``sorry``s) + - proof_text (Optional[str]): the output of ``generate_proof``.
+ + Returns: + (is_valid: bool, error_message: Optional[str]) + """ + file_path, proof_text = run_output + return validate_lean(proof_text, allow_sorry=False, timeout=60, verbose=False) + + +def aggregate_hypothesis_generation_metrics(results: Tuple[str, str], results_dir: str) -> Dict[str, Any]: + """ + Aggregates metrics for the generation of hypotheses. Assumes num_runs=1. Saves extra.npz with detailed generation + information. + + Args: + results (Tuple[str, str]): the validated output of ``generate_proof``. + results_dir (str): the path to the directory where to save the results. + + Returns: + dict: a dictionary of the results. + + """ + print("Aggregation results:", results) + if not results: + return {"combined_score": 0.0, "error": "No results to aggregate"} + + path, lean_cmd = results + + server_output = check_lean(lean_cmd) + if not server_output.lean_code_is_valid(allow_sorry=False): + penalty = 0 + for message in server_output.messages: + if "error" in message.severity: + penalty += -1 + + messages = server_output.messages + text_feedback = ( + f"The generated proof:\n{lean_cmd} was invalid. Each error or sorry leads to a -1 penalty. " + f"Please consider the following compiler feedback and update the formalization accordingly:\n" + f"{messages}" + ) + else: + text_feedback = "" + penalty = 0 + + public_metrics = { + "proof_length": len(lean_cmd), + } + + private_metrics = {} + metrics = { + "combined_score": len(lean_cmd) + penalty, + "public": public_metrics, + "private": private_metrics, + "text_feedback": text_feedback, + } + + extra_file = os.path.join(results_dir, "extra.npz") + try: + np.savez(extra_file, proof_length=len(lean_cmd)) + print(f"Detailed proof data saved to {extra_file}") + except Exception as e: + print(f"Error saving extra.npz: {e}") + metrics["extra_npz_save_error"] = str(e) + return metrics + + +def get_proof_generation_kwargs(run_index: int) -> Dict[str, Any]: + """ + Provides keyword arguments for generating proofs.
Insert your sampling parameters here. The timeout is provided + in seconds. + + Args: + run_index (int): the index of the run, added for compatibility with ShinkaEvolve and not used. + + Returns: + dict: a dictionary of hypothesis generation parameters. + """ + del run_index # Unused + return { + "sampling_params": { + }, + "timeout": 180, + } + + +def main(program_path: str, results_dir: str, prover_model: str='gpt-5-nano') -> None: + """ + Run the hypothesis evaluation using shinka.eval + + Args: + program_path (str): Path to program to evaluate. + results_dir (str): Dir to save results (metrics.json, correct.json, extra.npz) + prover_model (str): LLM agent used to construct LEAN proofs based on the initial header and formalization. + + Returns: + None + """ + + print(f"Evaluating program: {program_path} with {prover_model}") + print(f"Saving results to: {results_dir}") + os.makedirs(results_dir, exist_ok=True) + + client, prover_model = get_client_llm(prover_model, False,) + + # Helper functions + def _aggregator_with_context( + r: List[Tuple[str, str]], + ) -> Dict[str, Any]: + """A curried function to pass results_dir to the aggregator, extracts the tuple from the list containing 1 element""" + return aggregate_hypothesis_generation_metrics(r[0], results_dir) + + def _kwargs_with_context(run_index: int) -> dict: + """A curried function to pass the proof client to the proof solver""" + return {"model": prover_model, "proof_client": client} | get_proof_generation_kwargs(run_index=run_index) + + num_experiment_runs = 1 + + metrics, correct, error_msg = run_shinka_eval( + program_path=program_path, + results_dir=results_dir, + experiment_fn_name=generate_proof, + num_runs=num_experiment_runs, + get_experiment_kwargs=_kwargs_with_context, + validate_fn=validate_proof, + aggregate_metrics_fn=_aggregator_with_context, + ) + + if correct: + print("Evaluation and Validation completed successfully.") + else: + print(f"Evaluation or Validation failed: {error_msg}") + + 
print("Metrics:") + for key, value in metrics.items(): + if isinstance(value, str) and len(value) > 100: + print(f" {key}: {value[:100]}...") + else: + print(f" {key}: {value}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Hypothesis evaluator using shinka.eval") + parser.add_argument( + "--program_path", + type=str, + default="initial.lean", + help="Path to program to evaluate", + ) + parser.add_argument( + "--results_dir", + type=str, + default="results", + help="Dir to save results (metrics.json, correct.json, extra.npz)", + ) + + parser.add_argument( + "--prover_model", + type=str, + default="gpt-5-nano", # or an actual prover like "deepseek-ai/DeepSeek-Prover-V2-7B" (requires local LLM support) + help="LLM agent used to construct Lean proofs based on the initial header and formalization.", + ) + + parsed_args = parser.parse_args() + main(parsed_args.program_path, parsed_args.results_dir, parsed_args.prover_model) diff --git a/examples/autoformalization/initial.lean b/examples/autoformalization/initial.lean new file mode 100644 index 000000000..19b44103c --- /dev/null +++ b/examples/autoformalization/initial.lean @@ -0,0 +1,14 @@ +import data.real.basic +import data.real.nnreal +import Mathlib.Tactic + +-- EVOLVE-BLOCK-START +theorem abelian_group () +-- Assumptions +() + +: +-- Conjecture +() +-- EVOLVE-BLOCK-END +:= \ No newline at end of file diff --git a/examples/autoformalization/utils_lean.py b/examples/autoformalization/utils_lean.py new file mode 100644 index 000000000..1add4517a --- /dev/null +++ b/examples/autoformalization/utils_lean.py @@ -0,0 +1,267 @@ +import json +import re +import os +import logging +from pathlib import Path +from typing import Tuple, Optional, List + +from openai import OpenAI + +from lean_interact import LeanREPLConfig, AutoLeanServer, TempRequireProject +from lean_interact.interface import Command, FileCommand, LeanError +from lean_interact.utils import remove_lean_comments + +logger =
logging.getLogger(__name__) + + +def generate_prompt(file_path: str) -> str: + """ + Generate a prompt for an LLM-based prover agent for the completion of the proof provided in ``file_path``. + + Args: + file_path (str): the path to a Lean 4 file. + + Returns: + str: a prompt for an LLM-based prover agent based on the contents of ``file_path``. + + Notes: + - This function is based on the example provided by the DeepSeek Development team. + https://huggingface.co/deepseek-ai/DeepSeek-Prover-V2-7B + """ + formal_statement = Path(file_path).read_text(encoding="UTF-8") + + prompt = """ + Complete the following Lean 4 code: + + ```lean4 + {} + ``` + + Please *only* output the submitted, fully proven lean program. Do not add any reasoning or explanation to your answer. You are not allowed to include "sorrys" in your result. Make sure to include all imports in your final answer. + """.strip() + + return prompt.format(formal_statement) + + +def generate_proof( + file_path: str, model: str, proof_client: OpenAI, sampling_params: dict, timeout: int +) -> Tuple[str, Optional[str]]: + """ + Complete the proof provided at ``file_path`` using ``model``. The ``model`` should be hosted via the vLLM-based + OpenAI client. + + Args: + file_path (str): the path to a Lean 4 file containing an incomplete proof (may include ``sorry``s) + model (str): the name of the LLM to use. Recommended are ``DeepSeek-Prover-V2-7B``, + ``Goedel-LM/Goedel-Prover-V2-8B`` or ``Goedel-LM/Goedel-Prover-V2-32B``. + proof_client (OpenAI): the inference API. + sampling_params (dict): a dictionary of the sampling params that are passed as parameters to the API request. + timeout (int): The timeout for the request in seconds. + + Returns: + Tuple[str, Optional[str]]: the file path and the generated proof, or ``None`` if generation failed.
+ + """ + try: + prompt = generate_prompt(file_path) + if not prompt: + return file_path, None + response = proof_client.chat.completions.create( + model=model, + messages=[ + {"role": "user", "content": prompt}, + ], + **sampling_params, + timeout=timeout, + ) + proof_text = response.choices[0].message.content + + results_dir, _ = os.path.split(file_path) + fname = os.path.join(results_dir, fr"unprocessed_proof.lean") + with open(fname, "w", encoding="utf-8") as f: + f.write(proof_text) + + proof_text = postprocess(proof_text) + fname = os.path.join(results_dir, fr"processed_proof.lean") + with open(fname, "w", encoding="utf-8") as f: + f.write(proof_text) + + # Reformat proofs by removing comments, as they can sometimes cause issues when validating the proofs + return file_path, remove_lean_comments(proof_text) + + except Exception as e: + logger.error(f"Error generating proof for {file_path}: {e}") + return file_path, None + + +def postprocess(proof: str) -> str: + """ + Postprocess the ``proof``, fixing common syntax errors. + + Args: + proof (str): a Lean 4 proof. + + Returns: + (str): the post-processed and cleaned proof. + """ + try: + proof = proof.split("```lean4")[1].replace("`", "") + except IndexError: + proof = proof.replace("`", "") + proof = validate_imports(proof_text=proof) + # Reformat proofs by removing comments, as they can sometimes cause issues when validating the proofs + clean_lean = remove_lean_comments(proof) + if clean_lean.endswith("D"): + clean_lean = clean_lean[:-1] + # Replace "βˆ‘ n in range" with "βˆ‘ n ∈ range" + lean_txt = fix_range_notation(clean_lean) + # Replace " pi " with Ο€ + lean_txt = re.sub(r'\s+pi\s+', ' Ο€ ', lean_txt) + return lean_txt + + +def fix_range_notation(text: str) -> str: + """ + Replace 'βˆ‘/∏ variable in range' with 'βˆ‘/∏ variable ∈ range' in Lean 4 code, regardless of the variable name. + + Args: + text (str): The input string containing Lean code. 
+ + Returns: + str: The string with all sum notations fixed. + """ + pattern = r"([βˆ‘βˆ]\s+)(\w+)(\s+)in(\s+)" + replacement = r"\1\2\3∈\4" + return re.sub(pattern, replacement, text) + + +def validate_imports( + proof_text: str, + standard_imports: Optional[List[str]] = None, + open_imports: Optional[List[str]] = None, +) -> str: + """ + Check whether the imports are present in the proof header. Add missing imports if required. + + Args: + proof_text (str): the proof text. + standard_imports (Optional[List[str]]): a list of the standard imports required to solve most proofs in chemical + physics. + open_imports (Optional[List[str]]): a list of the standard opens required to solve most proofs in chemical physics. + + Returns: + str: the proof text including the required imports. + + """ + if not standard_imports: # Default argument + standard_imports = ["import Mathlib", ] + + if not open_imports: + open_imports = ["Real", "BigOperators", "Topology", "Set", "Filter", "Finset"] + + imports, opens, proof = [], [], [] + + proof_lines = proof_text.split("\n") + for line in proof_lines: + if line.startswith("import"): # Verify imports + continue + + if line.lower() == "lean": # fix parsing issues + continue + + elif line.startswith("open"): # Open statement + curated_line = line + for statement in open_imports: + if statement not in curated_line: + curated_line += f" {statement}" + opens.append(curated_line) + + else: # The rest of the proof text + proof.append(line) + + for statement in standard_imports: # Add the remaining standard imports + if statement not in imports: + imports.append(statement) + + proof_text = "\n".join(imports) + "\n" + "\n".join(opens) + "\n" + "\n".join(proof) + + return proof_text.strip() + + +def validate_lean( + path_or_str: str, allow_sorry: bool, timeout: int, verbose: bool, lean_version: str = "v4.24.0" +) -> Tuple[bool, str]: + """ + Verify the validity of the Lean program found at ``path`` via LeanInteract.
The function builds a Lean + Read-Eval-Print Loop (REPL) for checking Lean programs. The function returns ``True`` if the Lean program ran + successfully and the Lean code is considered valid. + + Args: + path_or_str (str): The path of the file to be operated on by the REPL. + allow_sorry (bool): True to allow for partially complete proofs that include ``sorry`` statements. + timeout (int): The timeout for the request in seconds. + verbose (bool): Whether to print additional information when downloading and building the Lean REPL, + and running the Lean REPL request using ``run``. + lean_version (str): The Lean version used. Default is ``"v4.24.0"``, which is the latest version around October + 2025. + + Returns: + bool: ``True`` if the lean program associated with ``path`` is considered valid and has run successfully. + str: the outcome or exception associated with the validation. + + Notes: + - Store the lean output as a json, storing the header and formalization code only, to run as a ``Command``. + - `Formalization_code` must end by `:=`. + """ + project = TempRequireProject(lean_version=lean_version, require="mathlib") + config = LeanREPLConfig(project=project) + + try: + server = AutoLeanServer(config) + command = FileCommand(path=path_or_str) if path_or_str.endswith(".lean") else Command(cmd=path_or_str) + server_output = server.run(command, timeout=timeout, verbose=verbose, add_to_session_cache=False) + + if isinstance(server_output, LeanError): # Check first: LeanError has no `messages` attribute + return False, f"{path_or_str}: " + server_output.message + + logger.info(server_output.messages) + + if not server_output.lean_code_is_valid(allow_sorry=allow_sorry): + return False, f"The provided lean file {path_or_str} is invalid or contains 'sorry'."
+ else: + return ( + True, + f"Run {path_or_str} terminated successfully: {server_output}", + ) # The content may still contain errors + + except (TimeoutError, ConnectionAbortedError, json.JSONDecodeError) as e: + logger.error(f"Error while checking the lean file {path_or_str}: {e}") + return False, str(e) + + +async def async_validate_lean( + path_or_str: str, allow_sorry: bool, timeout: int, verbose: bool, lean_version: str = "v4.24.0" +) -> Tuple[bool, str]: + """ + Verify the validity of the Lean program found at ``path`` via LeanInteract. The function builds a Lean + Read-Eval-Print Loop (REPL) for checking Lean programs. The function returns ``True`` if the Lean program ran + successfully and the Lean code is considered valid. + + Args: + path_or_str (str): The path of the file to be operated on by the REPL. + allow_sorry (bool): True to allow for partially complete proofs that include ``sorry`` statements. + timeout (int): The timeout for the request in seconds. + verbose (bool): Whether to print additional information when downloading and building the Lean REPL, + and running the Lean REPL request using ``run``. + lean_version (str): The Lean version used. Default is ``"v4.24.0"``, which is the latest version around October + 2025. + + Returns: + bool: ``True`` if the lean program associated with ``path`` is considered valid and has run successfully. + str: the outcome or exception associated with the validation. + + Notes: + - Store the lean output as a json, storing the header and formalization code only, to run as a ``Command``. + - `Formalization_code` must end by `:=`.
+ """ + return validate_lean( + path_or_str, allow_sorry=allow_sorry, timeout=timeout, verbose=verbose, lean_version=lean_version + ) \ No newline at end of file diff --git a/examples/shinka_tutorial.ipynb b/examples/shinka_tutorial.ipynb index 66a71a073..c6d818994 100644 --- a/examples/shinka_tutorial.ipynb +++ b/examples/shinka_tutorial.ipynb @@ -237,6 +237,17 @@ "if not llm_models:\n", " llm_models = [\"gpt-5-mini\"] # fallback if no keys detected\n", "\n", + "# pick embedding model based on available keys\n", + "embedding_model_name = \"\"\n", + "if os.getenv(\"GEMINI_API_KEY\"):\n", + " embedding_model_name = \"gemini-embedding-001\"\n", + "elif os.getenv(\"OPENAI_API_KEY\"):\n", + " embedding_model_name = \"text-embedding-3-small\"\n", + "else:\n", + " embedding_model_name = \"text-embedding-3-small\"\n", + "print(f\"βœ… Embedding model selected: {embedding_model_name}\")\n", + "\n", + "\n", "# unique experiment directory\n", "timestamp = dt.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", "run_tag = f\"{timestamp}_weighted_fast\"\n", @@ -271,6 +282,8 @@ " max_novelty_attempts=3,\n", " # ensemble llm selection among candidates based on past performance\n", " llm_dynamic_selection=None, # e.g. 
\"ucb1\"\n", + " # set embedding model\n", + " embedding_model=embedding_model_name,\n", ")\n", "\n", "db_config = DatabaseConfig(\n", @@ -286,11 +299,13 @@ " enforce_island_separation=True,\n", " parent_selection_strategy=\"weighted\",\n", " parent_selection_lambda=10.0,\n", + " \n", ")\n", "\n", "job_config = LocalJobConfig(eval_program_path=\"evaluate.py\")\n", "\n", "print(\"llm_models:\", llm_models)\n", + "print(\"embedding_model:\", embedding_model_name)\n", "print(\"results_dir:\", evo_config.results_dir)" ] }, diff --git a/pyproject.toml b/pyproject.toml index e3ec455af..afbaf5f3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,17 +45,21 @@ dependencies = [ "adjustText", "markdown", "aiofiles", + "google-generativeai", + "lean-interact", ] [tool.setuptools] -packages = ["shinka"] script-files = ["shinka/shinka_launch", "shinka/shinka_visualize"] +[tool.setuptools.packages.find] +include = ["shinka", "shinka.*"] + [tool.setuptools.package-data] "*" = ["*"] -[tool.uv] -dev-dependencies = [ +[dependency-groups] +dev = [ "pytest>=6.0", "black", "isort", diff --git a/shinka/core/runner.py b/shinka/core/runner.py index 3c818742c..92cdee99d 100644 --- a/shinka/core/runner.py +++ b/shinka/core/runner.py @@ -158,7 +158,12 @@ def __init__( # Initialize database and scheduler db_config.db_path = str(db_path) - self.db = ProgramDatabase(config=db_config) + embedding_model_to_use = ( + evo_config.embedding_model or "text-embedding-3-small" + ) + self.db = ProgramDatabase( + config=db_config, embedding_model=embedding_model_to_use + ) self.scheduler = JobScheduler( job_type=evo_config.job_type, config=job_config, # type: ignore @@ -231,6 +236,14 @@ def __init__( self.lang_ext = "cpp" elif self.evo_config.language == "python": self.lang_ext = "py" + elif self.evo_config.language == "rust": + self.lang_ext = "rs" + elif self.evo_config.language == "swift": + self.lang_ext = "swift" + elif self.evo_config.language in ["json", "json5"]: + self.lang_ext = "json" + 
elif self.evo_config.language == "lean": + self.lang_ext = "lean" else: msg = f"Language {self.evo_config.language} not supported" raise ValueError(msg) @@ -481,7 +494,7 @@ def _run_generation_0(self): logger.info(f"Initial program generated and saved to {exec_fname}") # Run the evaluation synchronously - results, rtime = self.scheduler.run(exec_fname, results_dir) + results, rtime = self.scheduler.run(exec_fname, results_dir,) code_embedding, e_cost = self.get_code_embedding(exec_fname) @@ -726,7 +739,7 @@ def _submit_new_job(self): meta_patch_data["novelty_explanation"] = novelty_explanation # Submit the job asynchronously - job_id = self.scheduler.submit_async(exec_fname, results_dir) + job_id = self.scheduler.submit_async(exec_fname, results_dir,) # Add to running jobs queue running_job = RunningJob( @@ -1096,9 +1109,10 @@ def run_patch( # error_attempt is already set from apply_patch or default pass - # Only consider the diff summary for the original.py file!!! - if "original.py" in diff_summary: - diff_summary = diff_summary["original.py"] + # Only consider the diff summary for the original source file + original_filename = f"original.{self.lang_ext}" + if original_filename in diff_summary: + diff_summary = diff_summary[original_filename] meta_edit_data = { "patch_type": patch_type, diff --git a/shinka/core/wrap_eval.py b/shinka/core/wrap_eval.py index 7e1d1e5d3..c411b1bfd 100644 --- a/shinka/core/wrap_eval.py +++ b/shinka/core/wrap_eval.py @@ -96,22 +96,25 @@ def run_shinka_eval( num_valid_runs = 0 num_invalid_runs = 0 - try: - module = load_program(program_path) - if not hasattr(module, experiment_fn_name): - raise AttributeError( - f"Experiment function '{experiment_fn_name}' not found in " - f"{program_path}" - ) - experiment_fn = getattr(module, experiment_fn_name) + all_run_results: List[Any] = [] + execution_times: List[float] = [] - all_run_results: List[Any] = [] - execution_times: List[float] = [] + try: + if program_path.endswith(".lean"): + 
experiment_fn = experiment_fn_name + else: + module = load_program(program_path) + if not hasattr(module, experiment_fn_name): + raise AttributeError( + f"Experiment function '{experiment_fn_name}' not found in " + f"{program_path}" + ) + experiment_fn = getattr(module, experiment_fn_name) for i in range(num_runs): kwargs: Dict[str, Any] = {} if get_experiment_kwargs: - kwargs = get_experiment_kwargs(i) + kwargs = get_experiment_kwargs(i) | {"file_path": program_path} else: kwargs = {"seed": i + 1} diff --git a/shinka/database/complexity.py b/shinka/database/complexity.py index 4116567e9..70cd5d3a1 100644 --- a/shinka/database/complexity.py +++ b/shinka/database/complexity.py @@ -259,8 +259,8 @@ def analyze_code_metrics(code_string, language="python"): # If Python parsing fails, fall back to C++ analysis return analyze_cpp_complexity(code_string) - # For C/C++/CUDA and other languages, use regex-based analysis - elif language in ["cpp", "c", "cuda", "c++"]: + # For C/C++/CUDA/Rust/Swift/JSON and other languages, use regex-based analysis + elif language in ["cpp", "c", "cuda", "c++", "rust", "swift", "json", "json5"]: return analyze_cpp_complexity(code_string) # For unknown languages, use simple line-based complexity diff --git a/shinka/database/dbase.py b/shinka/database/dbase.py index 69fdf5432..2118763c4 100644 --- a/shinka/database/dbase.py +++ b/shinka/database/dbase.py @@ -50,7 +50,7 @@ def clean_nan_values(obj: Any) -> Any: @dataclass class DatabaseConfig: - db_path: Optional[str] = None + db_path: str = "evolution_db.sqlite" num_islands: int = 4 archive_size: int = 100 @@ -82,6 +82,9 @@ class DatabaseConfig: # Beam search parent selection parameters num_beams: int = 5 + # Embedding model name + embedding_model: str = "text-embedding-3-small" + def db_retry(max_retries=5, initial_delay=0.1, backoff_factor=2): """ @@ -248,12 +251,22 @@ class ProgramDatabase: populations, and an archive of elites. 
""" - def __init__(self, config: DatabaseConfig, read_only: bool = False): + def __init__( + self, + config: DatabaseConfig, + embedding_model: str = "text-embedding-3-small", + read_only: bool = False, + ): self.config = config self.conn: Optional[sqlite3.Connection] = None self.cursor: Optional[sqlite3.Cursor] = None self.read_only = read_only - self.embedding_client = EmbeddingClient() + # Only create embedding client if not in read-only mode + # (e.g., WebUI doesn't need it for visualization) + if not read_only: + self.embedding_client = EmbeddingClient(model_name=embedding_model) + else: + self.embedding_client = None self.last_iteration: int = 0 self.best_program_id: Optional[str] = None diff --git a/shinka/database/display.py b/shinka/database/display.py index 4c34d3445..3e55439bf 100644 --- a/shinka/database/display.py +++ b/shinka/database/display.py @@ -122,6 +122,18 @@ def print_program_summary(self, program, console: Optional[RichConsole] = None): else: time_display = f"{time_val:.1f}s" + # Safely extract metadata fields for display + metadata = program.metadata or {} + patch_name_raw = metadata.get("patch_name", "[dim]N/A[/dim]") + if patch_name_raw is None: + patch_name_raw = "[dim]N/A[/dim]" + patch_name = str(patch_name_raw)[:30] + + patch_type_raw = metadata.get("patch_type", "[dim]N/A[/dim]") + if patch_type_raw is None: + patch_type_raw = "[dim]N/A[/dim]" + patch_type = str(patch_type_raw) + # Add the data row island_display = ( f"I-{program.island_idx}" if program.island_idx is not None else "N/A" @@ -131,8 +143,8 @@ def print_program_summary(self, program, console: Optional[RichConsole] = None): island_display, status_display, score_display, - program.metadata.get("patch_name", "[dim]N/A[/dim]")[:30], - program.metadata.get("patch_type", "[dim]N/A[/dim]"), + patch_name, + patch_type, f"{program.complexity:.1f}", cost_display, time_display, diff --git a/shinka/database/inspirations.py b/shinka/database/inspirations.py index ee564dfa1..42c3859d8 
100644 --- a/shinka/database/inspirations.py +++ b/shinka/database/inspirations.py @@ -72,6 +72,7 @@ def sample_context(self, parent: Any, n: int) -> List[Any]: self.cursor.execute( """ SELECT p.id FROM programs p + JOIN archive a ON p.id = a.program_id WHERE p.island_idx = ? AND p.correct = 1 ORDER BY p.combined_score DESC LIMIT ? @@ -93,7 +94,8 @@ def sample_context(self, parent: Any, n: int) -> List[Any]: placeholders_rand = ",".join("?" * len(insp_ids)) sql_rand = f""" SELECT p.id FROM programs p - WHERE p.island_idx = ? AND p.correct = 1 + JOIN archive a ON p.id = a.program_id + WHERE p.island_idx = ? AND p.correct = 1 AND p.id NOT IN ({placeholders_rand}) ORDER BY RANDOM() LIMIT ? """ @@ -111,9 +113,10 @@ def sample_context(self, parent: Any, n: int) -> List[Any]: needed = n - len(inspirations) if needed > 0: placeholders_rand = ",".join("?" * len(insp_ids)) - sql_rand = f"""SELECT id FROM programs - WHERE correct = 1 - AND id NOT IN ({placeholders_rand}) + sql_rand = f"""SELECT p.id FROM programs p + JOIN archive a ON p.id = a.program_id + WHERE p.correct = 1 + AND p.id NOT IN ({placeholders_rand}) ORDER BY RANDOM() LIMIT ? """ params_rand = list(insp_ids) + [needed] diff --git a/shinka/database/islands.py b/shinka/database/islands.py index 9975eac3b..341dea79c 100644 --- a/shinka/database/islands.py +++ b/shinka/database/islands.py @@ -682,6 +682,16 @@ def copy_program_to_islands(self, program: Any) -> List[str]: f"Created copy {new_id[:8]}... of program {program.id[:8]}... " f"for island {island_idx}" ) + + # Add the copied program to the archive if it's correct + # This ensures it can be used as inspiration for that island + if program.correct: + self.cursor.execute( + "INSERT OR IGNORE INTO archive (program_id) VALUES (?)", + (new_id,), + ) + logger.debug(f"Added copy {new_id[:8]}... 
to archive (correct program)") + self.conn.commit() logger.info( f"Created {len(created_ids)} copies of program " diff --git a/shinka/edit/apply_diff.py b/shinka/edit/apply_diff.py index ead28e231..f9dfb97c0 100644 --- a/shinka/edit/apply_diff.py +++ b/shinka/edit/apply_diff.py @@ -12,8 +12,8 @@ ) -EVOLVE_START = re.compile(r"(?:#|//|)?\s*EVOLVE-BLOCK-START") -EVOLVE_END = re.compile(r"(?:#|//|)?\s*EVOLVE-BLOCK-END") +EVOLVE_START = re.compile(r"(?:#|//|--|)?\s*EVOLVE-BLOCK-START") +EVOLVE_END = re.compile(r"(?:#|//|--|)?\s*EVOLVE-BLOCK-END") def _mutable_ranges(text: str) -> list[tuple[int, int]]: @@ -122,9 +122,11 @@ def _clean_evolve_markers(text: str) -> str: r"^\s*#\s*EVOLVE-BLOCK-START\s*$", # Python style r"^\s*//\s*EVOLVE-BLOCK-START\s*$", # C/C++/CUDA style r"^\s*EVOLVE-BLOCK-START\s*$", # Plain text + r"^\s*--\s*EVOLVE-BLOCK-START\s*$", # LEAN 4 r"^\s*#\s*EVOLVE-BLOCK-END\s*$", # Python style r"^\s*//\s*EVOLVE-BLOCK-END\s*$", # C/C++/CUDA r"^\s*EVOLVE-BLOCK-END\s*$", # Plain text + r"^\s*--\s*EVOLVE-BLOCK-END\s*$", # LEAN 4 ] cleaned_text = text @@ -553,7 +555,7 @@ def _create_no_evolve_block_error(original_text: str, operation: str) -> str: "", "Suggestions:", "1. Add EVOLVE-BLOCK-START and EVOLVE-BLOCK-END markers around editable code", - "2. Ensure the markers are properly formatted (with # for Python, // for C/C++)", + "2. Ensure the markers are properly formatted (with # for Python, // for C/C++ and -- for LEAN)", "3. 
Check that there's at least one EVOLVE-BLOCK region in the file", ] ) @@ -698,12 +700,15 @@ def apply_diff_patch( patch_str = _strip_trailing_whitespace(patch_str) # Remove the EVOLVE-BLOCK START and EVOLVE-BLOCK END markers - if language in ["cuda", "cpp"]: - patch_str = re.sub(r"// EVOLVE-BLOCK START\\n", "", patch_str) - patch_str = re.sub(r"// EVOLVE-BLOCK END\\n", "", patch_str) + if language in ["cuda", "cpp", "rust", "swift", "json", "json5"]: + patch_str = re.sub(r"// EVOLVE-BLOCK-START\\n", "", patch_str) + patch_str = re.sub(r"// EVOLVE-BLOCK-END\\n", "", patch_str) elif language == "python": - patch_str = re.sub(r"# EVOLVE-BLOCK START\\n", "", patch_str) - patch_str = re.sub(r"# EVOLVE-BLOCK END\\n", "", patch_str) + patch_str = re.sub(r"# EVOLVE-BLOCK-START\\n", "", patch_str) + patch_str = re.sub(r"# EVOLVE-BLOCK-END\\n", "", patch_str) + elif language.lower() == "lean": + patch_str = re.sub(r"-- EVOLVE-BLOCK-START\\n", "", patch_str) + patch_str = re.sub(r"-- EVOLVE-BLOCK-END\\n", "", patch_str) else: raise ValueError(f"Language {language} not supported") @@ -730,6 +735,14 @@ suffix = ".cpp" elif language == "cuda": suffix = ".cu" + elif language == "rust": + suffix = ".rs" + elif language == "swift": + suffix = ".swift" + elif language in ["json", "json5"]: + suffix = ".json" + elif language.lower() == "lean": # Run full lean files with the `LeanInteract` `FileCommand`.
+ suffix = ".lean" else: raise ValueError(f"Language {language} not supported") diff --git a/shinka/edit/apply_full.py b/shinka/edit/apply_full.py index b7e2e2b37..e14ce7d38 100644 --- a/shinka/edit/apply_full.py +++ b/shinka/edit/apply_full.py @@ -1,6 +1,6 @@ from pathlib import Path from typing import Optional, Union -from .apply_diff import write_git_diff, _mutable_ranges +from .apply_diff import write_git_diff, _mutable_ranges, EVOLVE_START, EVOLVE_END from shinka.llm import extract_between import logging @@ -72,10 +72,15 @@ def apply_full_patch( updated_content = "" last_end = 0 - # Check if patch_code contains EVOLVE-BLOCK markers - patch_mutable_ranges = _mutable_ranges(patch_code) + # Detect EVOLVE markers presence in the patch content + patch_has_start = EVOLVE_START.search(patch_code) is not None + patch_has_end = EVOLVE_END.search(patch_code) is not None + patch_has_both = patch_has_start and patch_has_end + patch_has_none = not patch_has_start and not patch_has_end - if patch_mutable_ranges: + if patch_has_both: + # Patch contains both EVOLVE-BLOCK markers, extract from them + patch_mutable_ranges = _mutable_ranges(patch_code) # Patch contains EVOLVE-BLOCK markers, extract from them for i, (start, end) in enumerate(mutable_ranges): # Add immutable part before this mutable range @@ -91,47 +96,158 @@ def apply_full_patch( updated_content += replacement_content last_end = end - else: + elif patch_has_none: # Patch doesn't contain EVOLVE-BLOCK markers # Assume entire patch content should replace all mutable regions if len(mutable_ranges) == 1: - # Single mutable region, replace with entire patch content + # Single mutable region. If the patch appears to be a full-file + # rewrite that omitted EVOLVE markers, safely extract only the + # content intended for the evolve block by matching immutable + # prefix/suffix from the original file. 
start, end = mutable_ranges[0] - # The mutable range ends before "EVOLVE-BLOCK-END" text - # We need to find the actual start of the comment line - if language == "python": - end_marker = "# EVOLVE-BLOCK-END" - elif language in ["cuda", "cpp"]: - end_marker = "// EVOLVE-BLOCK-END" - else: - end_marker = "# EVOLVE-BLOCK-END" # Default fallback - - end_marker_pos = original.find(end_marker, end - 5) - if end_marker_pos == -1: - # Fallback: use the original end position - end_marker_pos = end + # Immutable portions that remain outside the evolve block + immutable_prefix = original[:start] + immutable_suffix = original[end:] - # Ensure proper newline handling around the patch content - if patch_code and not patch_code.startswith("\n"): - patch_code = "\n" + patch_code + # Also compute the portions strictly outside the marker lines + # to detect full-file patches that omitted EVOLVE markers. + # Find the start and end marker line boundaries. + start_match = None + end_match = None + for m in EVOLVE_START.finditer(original): + if m.end() == start: + start_match = m + break + for m in EVOLVE_END.finditer(original): + if m.start() == end: + end_match = m + break - if patch_code and not patch_code.endswith("\n"): - patch_code = patch_code + "\n" - - updated_content = ( - original[:start] + patch_code + original[end_marker_pos:] + prefix_outside = ( + original[: start_match.start()] if start_match else immutable_prefix + ) + suffix_outside = ( + original[end_match.end() :] if end_match else immutable_suffix ) + + # Heuristic: if patch includes the same immutable prefix/suffix + # outside the markers, treat the middle part as the evolve-block + # replacement. Be tolerant to a missing trailing newline in the + # footer by checking both versions. 
+ suffix_opts = (suffix_outside, suffix_outside.rstrip("\r\n")) + if patch_code.startswith(prefix_outside) and any( + patch_code.endswith(sfx) for sfx in suffix_opts + ): + mid_start = len(prefix_outside) + # choose the matching suffix option to compute end + sfx = next(sfx for sfx in suffix_opts if patch_code.endswith(sfx)) + mid_end = len(patch_code) - len(sfx) + replacement_content = patch_code[mid_start:mid_end] + # Ensure marker boundaries stay on their own lines. + # Add a leading newline only if there is a START marker. + if ( + start_match is not None + and replacement_content + and not replacement_content.startswith("\n") + ): + replacement_content = "\n" + replacement_content + # Add a trailing newline only if there is an END marker. + if ( + end_match is not None + and replacement_content + and not replacement_content.endswith("\n") + ): + replacement_content = replacement_content + "\n" + updated_content = ( + immutable_prefix + replacement_content + immutable_suffix + ) + else: + # Otherwise, assume the patch_code represents only the + # evolve-block payload and insert it directly between markers. + # Ensure proper newline handling around the patch content. + payload = patch_code + if ( + start_match is not None + and payload + and not payload.startswith("\n") + ): + payload = "\n" + payload + if end_match is not None and payload and not payload.endswith("\n"): + payload = payload + "\n" + updated_content = immutable_prefix + payload + immutable_suffix else: - # Multiple mutable regions, this is ambiguous + # Multiple EVOLVE-BLOCK regions found, ambiguous without markers error_message = ( "Multiple EVOLVE-BLOCK regions found but patch " "doesn't specify which to replace" ) return original, 0, None, error_message, None, None + else: + # Patch contains exactly one marker (START xor END). + # Only safe to apply when original has a single evolve region. 
+ if len(mutable_ranges) != 1: + error_message = ( + "Patch contains only one EVOLVE-BLOCK marker, but the original " + f"has {len(mutable_ranges)} editable regions; cannot determine target" + ) + return original, 0, None, error_message, None, None + + # Single target region in original + start, end = mutable_ranges[0] + immutable_prefix = original[:start] + immutable_suffix = original[end:] + + # Find exact marker locations in original for newline policy + start_match = None + end_match = None + for m in EVOLVE_START.finditer(original): + if m.end() == start: + start_match = m + break + for m in EVOLVE_END.finditer(original): + if m.start() == end: + end_match = m + break + + # Compute outside-of-markers prefix/suffix from original + prefix_outside = ( + original[: start_match.start()] if start_match else immutable_prefix + ) + suffix_outside = ( + original[end_match.end() :] if end_match else immutable_suffix + ) + + # Extract payload based on which single marker is present in patch + if patch_has_start and not patch_has_end: + m = EVOLVE_START.search(patch_code) + payload = patch_code[m.end() :] if m else patch_code + # Trim footer if the patch included it + for sfx in (suffix_outside, suffix_outside.rstrip("\r\n")): + if sfx and payload.endswith(sfx): + payload = payload[: -len(sfx)] + break + elif patch_has_end and not patch_has_start: + m = EVOLVE_END.search(patch_code) + payload = patch_code[: m.start()] if m else patch_code + # Trim header if the patch included it + for pfx in (prefix_outside, prefix_outside.rstrip("\r\n")): + if pfx and payload.startswith(pfx): + payload = payload[len(pfx) :] + break + else: + payload = patch_code + + # Normalize newlines so markers remain on their own lines + if start_match is not None and payload and not payload.startswith("\n"): + payload = "\n" + payload + if end_match is not None and payload and not payload.endswith("\n"): + payload = payload + "\n" + + updated_content = immutable_prefix + payload + immutable_suffix # 
Add remaining immutable content after last mutable range - if patch_mutable_ranges and mutable_ranges: + if patch_has_both and mutable_ranges: updated_content += original[mutable_ranges[-1][1] :] num_applied = 1 @@ -146,6 +262,14 @@ def apply_full_patch( suffix = ".cpp" elif language == "cuda": suffix = ".cu" + elif language == "rust": + suffix = ".rs" + elif language == "swift": + suffix = ".swift" + elif language in ["json", "json5"]: + suffix = ".json" + elif language == "lean": + suffix = ".lean" else: raise ValueError(f"Language {language} not supported") diff --git a/shinka/edit/async_apply.py b/shinka/edit/async_apply.py index 8e542c565..e4c21202f 100644 --- a/shinka/edit/async_apply.py +++ b/shinka/edit/async_apply.py @@ -118,6 +118,31 @@ async def validate_code_async( error_msg = stderr.decode() if stderr else "Unknown compilation error" return False, error_msg + elif language == "rust": + # Use rustc for Rust syntax checking + proc = await asyncio.create_subprocess_exec( + "rustc", + "--crate-type=lib", + "-Zparse-only", + code_path, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + + try: + stdout, stderr = await asyncio.wait_for( + proc.communicate(), timeout=timeout + ) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + return False, f"Validation timeout after {timeout}s" + + if proc.returncode == 0: + return True, None + else: + error_msg = stderr.decode() if stderr else "Unknown compilation error" + return False, error_msg elif language == "cpp": # Use g++ for C++ compilation check proc = await asyncio.create_subprocess_exec( @@ -128,6 +153,31 @@ async def validate_code_async( stderr=asyncio.subprocess.PIPE, ) + try: + stdout, stderr = await asyncio.wait_for( + proc.communicate(), timeout=timeout + ) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + return False, f"Validation timeout after {timeout}s" + + if proc.returncode == 0: + return True, None + else: + error_msg = stderr.decode() if 
stderr else "Unknown compilation error" + return False, error_msg + elif language == "swift": + # Use swiftc for Swift syntax checking + proc = await asyncio.create_subprocess_exec( + "swiftc", + "-typecheck", + "-parse-as-library", + code_path, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + try: stdout, stderr = await asyncio.wait_for( proc.communicate(), timeout=timeout diff --git a/shinka/eval_hydra.py b/shinka/eval_hydra.py index 82b3c0d4b..4c5c72d35 100644 --- a/shinka/eval_hydra.py +++ b/shinka/eval_hydra.py @@ -65,4 +65,4 @@ def chdir_to_target(cfg): # import & run under the target directory with chdir_to_target(cfg.evaluate_function): - hydra.utils.instantiate(cfg.evaluate_function) + hydra.utils.instantiate(cfg.evaluate_function) \ No newline at end of file diff --git a/shinka/launch/scheduler.py b/shinka/launch/scheduler.py index 5782613ee..91efa1be2 100644 --- a/shinka/launch/scheduler.py +++ b/shinka/launch/scheduler.py @@ -138,7 +138,13 @@ def _build_command(self, exec_fname_t: str, results_dir_t: str) -> List[str]: ] if self.config.extra_cmd_args: for k, v in self.config.extra_cmd_args.items(): - cmd.extend([f"--{k}", str(v)]) + # Handle boolean flags + if isinstance(v, bool): + if v: # Only append flag if True + cmd.append(f"--{k}") + else: + # For non-boolean values, append both flag and value + cmd.extend([f"--{k}", str(v)]) return cmd def run( @@ -370,4 +376,4 @@ def cancel_job(): def shutdown(self): """Shutdown the thread pool executor.""" - self.executor.shutdown(wait=True) + self.executor.shutdown(wait=True) \ No newline at end of file diff --git a/shinka/llm/dynamic_sampling.py b/shinka/llm/dynamic_sampling.py index 6c038d9fa..eb0cd8cb3 100644 --- a/shinka/llm/dynamic_sampling.py +++ b/shinka/llm/dynamic_sampling.py @@ -28,7 +28,8 @@ def _logdiffexp(a_log, b_log): def _logexpm1(z): z = np.asarray(z, dtype=float) - return np.where(z > 50.0, z, np.log(np.expm1(z))) + with np.errstate(divide='ignore', invalid='ignore'): 
+ return np.where(z > 50.0, z, np.log(np.expm1(z))) class BanditBase(ABC): @@ -433,12 +434,13 @@ def decay(self, factor: float) -> None: if self.use_exponential_scaling and self.asymmetric_scaling: # shrink in exp space to match original score scale s = self.s - log1p_term = np.where( - s > 0.0, - s + np.log(one_minus_factor + np.exp(-s)), - np.log1p(one_minus_factor * np.exp(s)), - ) - self.s = s + np.log(factor) - log1p_term + with np.errstate(divide='ignore', invalid='ignore'): + log1p_term = np.where( + s > 0.0, + s + np.log(one_minus_factor + np.exp(-s)), + np.log1p(one_minus_factor * np.exp(s)), + ) + self.s = s + np.log(factor) - log1p_term if self.adaptive_scale and np.isfinite(self._obs_max): means_log = self._mean() diff --git a/shinka/llm/embedding.py b/shinka/llm/embedding.py index a5c6b07cc..ba751b5f3 100644 --- a/shinka/llm/embedding.py +++ b/shinka/llm/embedding.py @@ -1,5 +1,6 @@ import os import openai +import google.generativeai as genai import pandas as pd from typing import Union, List, Optional, Tuple import numpy as np @@ -20,13 +21,23 @@ "azure-text-embedding-3-large", ] +GEMINI_EMBEDDING_MODELS = [ + "gemini-embedding-exp-03-07", + "gemini-embedding-001", +] + OPENAI_EMBEDDING_COSTS = { "text-embedding-3-small": 0.02 / M, "text-embedding-3-large": 0.13 / M, } +# Gemini embedding costs (approximate - check current pricing) +GEMINI_EMBEDDING_COSTS = { + "gemini-embedding-exp-03-07": 0.0 / M, # Experimental model, often free + "gemini-embedding-001": 0.15 / M, # Check current pricing +} -def get_client_model(model_name: str) -> tuple[openai.OpenAI, str]: +def get_client_model(model_name: str) -> tuple[Union[openai.OpenAI, str], str]: if model_name in OPENAI_EMBEDDING_MODELS: client = openai.OpenAI() model_to_use = model_name @@ -38,6 +49,14 @@ def get_client_model(model_name: str) -> tuple[openai.OpenAI, str]: api_version=os.getenv("AZURE_API_VERSION"), azure_endpoint=os.getenv("AZURE_API_ENDPOINT"), ) + elif model_name in 
GEMINI_EMBEDDING_MODELS: + # Configure Gemini API + api_key = os.getenv("GEMINI_API_KEY") + if not api_key: + raise ValueError("GEMINI_API_KEY environment variable not set for Gemini models") + genai.configure(api_key=api_key) + client = "gemini" # Use string identifier for Gemini + model_to_use = model_name else: raise ValueError(f"Invalid embedding model: {model_name}") @@ -52,9 +71,10 @@ def __init__( Initialize the EmbeddingClient. Args: - model (str): The OpenAI embedding model name to use. + model_name (str): The OpenAI, Azure, or Gemini embedding model name to use. """ self.client, self.model = get_client_model(model_name) + self.model_name = model_name self.verbose = verbose def get_embedding( @@ -76,6 +96,34 @@ single_code = True else: single_code = False + # Handle Gemini models + if self.model_name in GEMINI_EMBEDDING_MODELS: + try: + embeddings = [] + total_tokens = 0 + + for text in code: + result = genai.embed_content( + model=f"models/{self.model}", + content=text, + task_type="retrieval_document" + ) + embeddings.append(result['embedding']) + total_tokens += len(text.split()) # word count as a rough token proxy; cost is an estimate + + cost = total_tokens * GEMINI_EMBEDDING_COSTS.get(self.model, 0.0) + + if single_code: + return embeddings[0] if embeddings else [], cost + else: + return embeddings, cost + except Exception as e: + logger.error(f"Error getting Gemini embedding: {e}") + if single_code: + return [], 0.0 + else: + return [[]], 0.0 + # Handle OpenAI and Azure models (same interface) try: response = self.client.embeddings.create( model=self.model, input=code, encoding_format="float"
M, "output_price": 0.4 / M, }, + "gpt-5.1": { + "input_price": 1.25 / M, + "output_price": 10.0 / M, + }, + "gpt-5.2": { + "input_price": 1.75 / M, + "output_price": 14.0 / M, + }, } @@ -141,6 +153,14 @@ "input_price": 0.1 / M, "output_price": 0.4 / M, }, + "gemini-3-pro-preview" : { + "input_price": 2.0 / M, + "output_price": 12.0 / M, + }, + "gemini-3-flash-preview" : { + "input_price": 0.5 / M, + "output_price": 3.0 / M, + }, } BEDROCK_MODELS = { @@ -171,11 +191,14 @@ "gpt-5", "gpt-5-mini", "gpt-5-nano", + "gpt-5.1", + "gpt-5.2", ] REASONING_CLAUDE_MODELS = [ "claude-3-7-sonnet-20250219", "claude-4-sonnet-20250514", + "claude-sonnet-4-5-20250929", ] REASONING_DEEPSEEK_MODELS = [ @@ -186,6 +209,8 @@ "gemini-2.5-pro", "gemini-2.5-flash", "gemini-2.5-flash-lite-preview-06-17", + "gemini-3-pro-preview", + "gemini-3-flash-preview", ] REASONING_AZURE_MODELS = [ diff --git a/shinka/llm/query.py b/shinka/llm/query.py index a7288df8e..c88c7d7c3 100644 --- a/shinka/llm/query.py +++ b/shinka/llm/query.py @@ -137,16 +137,13 @@ def sample_model_kwargs( r_effort = random.choice(reasoning_efforts) think_bool = r_effort != "auto" if think_bool: - thinking_tokens = [ - t - for t in THINKING_TOKENS.values() - if t < kwargs_dict["max_tokens"] and t >= 1024 - ] + t = THINKING_TOKENS[r_effort] + thinking_tokens = t if t < kwargs_dict["max_tokens"] else 1024 kwargs_dict["extra_body"] = { "extra_body": { "google": { "thinking_config": { - "thinking_budget": random.choice(thinking_tokens), + "thinking_budget": thinking_tokens, "include_thoughts": True, } } @@ -157,19 +154,17 @@ def sample_model_kwargs( REASONING_CLAUDE_MODELS + REASONING_BEDROCK_MODELS ): kwargs_dict["max_tokens"] = min(random.choice(max_tokens), 16384) - think_bool = random.choice(reasoning_efforts) != "auto" + r_effort = random.choice(reasoning_efforts) + think_bool = r_effort != "auto" if think_bool: # filter thinking tokens to be smaller than max_tokens # not auto THINKING_TOKENS - thinking_tokens = [ - t - for t 
in THINKING_TOKENS.values() - if t < kwargs_dict["max_tokens"] and t >= 1024 - ] + t = THINKING_TOKENS[r_effort] + thinking_tokens = t if t < kwargs_dict["max_tokens"] else 1024 # sample only from thinking tokens that are valid kwargs_dict["thinking"] = { "type": "enabled", - "budget_tokens": random.choice(thinking_tokens), + "budget_tokens": thinking_tokens, } else: diff --git a/shinka/webui/__init__.py b/shinka/webui/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_edit_base.py b/tests/test_edit_base.py index edc0e1178..67c6f2e20 100644 --- a/tests/test_edit_base.py +++ b/tests/test_edit_base.py @@ -161,6 +161,110 @@ def new_func2(): # Should have replaced both evolve blocks with new content +def test_apply_full_patch_full_file_without_markers_extracts_block_only(): + """Full-file patch without EVOLVE markers should not copy immutable code + into the evolve block; only the block payload is replaced.""" + original_content = """# Header line\n# EVOLVE-BLOCK-START\nold_line()\n# EVOLVE-BLOCK-END\n# Footer line\n""" + + # Patch is the entire file content but with the EVOLVE markers omitted. 
+    patch_content = """```python
+new_line()
+another_new_line()
+```"""
+
+    expected = """# Header line
+# EVOLVE-BLOCK-START
+new_line()
+another_new_line()
+# EVOLVE-BLOCK-END
+# Footer line
+"""
+
+    result = apply_full_patch(
+        patch_str=patch_content,
+        original_str=original_content,
+        language="python",
+        verbose=False,
+    )
+    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
+
+    assert error is None
+    assert num_applied == 1
+    assert updated_content == expected
+
+
+def test_apply_full_patch_patch_with_start_marker_only():
+    """Patch has only START marker; original has both markers."""
+    original_content = """# Header line
+# EVOLVE-BLOCK-START
+old_line()
+# EVOLVE-BLOCK-END
+# Footer line
+"""
+
+    patch_content = """```python
+# Header line
+# EVOLVE-BLOCK-START
+new_line()
+# Footer line
+```"""
+
+    expected = """# Header line
+# EVOLVE-BLOCK-START
+new_line()
+# EVOLVE-BLOCK-END
+# Footer line
+"""
+
+    result = apply_full_patch(
+        patch_str=patch_content,
+        original_str=original_content,
+        language="python",
+        verbose=False,
+    )
+    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
+
+    assert error is None
+    assert num_applied == 1
+    assert updated_content == expected
+
+
+def test_apply_full_patch_patch_with_end_marker_only():
+    """Patch has only END marker; original has both markers."""
+    original_content = """# Header line
+# EVOLVE-BLOCK-START
+old_line()
+# EVOLVE-BLOCK-END
+# Footer line
+"""
+
+    patch_content = """```python
+# Header line
+new_line()
+# EVOLVE-BLOCK-END
+# Footer line
+```"""
+
+    expected = """# Header line
+# EVOLVE-BLOCK-START
+new_line()
+# EVOLVE-BLOCK-END
+# Footer line
+"""
+
+    result = apply_full_patch(
+        patch_str=patch_content,
+        original_str=original_content,
+        language="python",
+        verbose=False,
+    )
+    updated_content, num_applied, output_path, error, patch_txt, diff_path = result
+
+    assert error is None
+    assert num_applied == 1
+    assert updated_content == expected
+
+
 def test_apply_full_patch_no_evolve_blocks():
     """Test apply_full_patch with no EVOLVE-BLOCK regions - should error."""
     original_content = """# Just regular code
@@ -221,6 +325,41 @@ def new_function():
     assert updated_content == original_content  # Should return original content


+def test_apply_full_patch_patch_with_single_marker_ambiguous_multiple_regions():
+    """Single marker in patch is ambiguous when original has multiple regions."""
+    original_content = """# Header
+# EVOLVE-BLOCK-START
+func1()
+# EVOLVE-BLOCK-END
+
+# EVOLVE-BLOCK-START
+func2()
+# EVOLVE-BLOCK-END
+# Footer
+"""
+
+    # Patch includes only START marker
+    patch_content = """```python
+# Header
+# EVOLVE-BLOCK-START
+new_code()
+# Footer
+```"""
+
+    updated_content, num_applied, output_path, error, patch_txt, diff_path = (
+        apply_full_patch(
+            patch_str=patch_content,
+            original_str=original_content,
+            language="python",
+            verbose=False,
+        )
+    )
+
+    assert num_applied == 0
+    assert error is not None
+    assert "only one EVOLVE-BLOCK marker" in error
+
+
 def test_apply_full_patch_invalid_extraction():
     """Test apply_full_patch with invalid code extraction."""
     original_content = """# EVOLVE-BLOCK-START