Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion cgp/genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
from .node import Node, OperatorNode
from .primitives import Primitives

try:
Comment thread
jakobj marked this conversation as resolved.
Outdated
import sympy

sympy_available = False
except ModuleNotFoundError:
sympy_available = True

try:
import torch # noqa: F401

Expand Down Expand Up @@ -359,7 +366,11 @@ def change_address_gene_of_output_node(self, new_address: int, output_node_idx:
self.dna = dna

def set_expression_for_output(
self, dna_insert: List[int], hidden_start_node: int = 0, output_node_idx: int = 0
self,
dna_insert: List[int],
hidden_start_node: int = 0,
output_node_idx: int = 0,
target_expression: Optional[str] = None,
Comment thread
HenrikMettler marked this conversation as resolved.
Outdated
):
"""Set an expression for one output node

Expand All @@ -378,6 +389,9 @@ def set_expression_for_output(
Index of the output node which will read the last node of the insert.
Relative to the first output node.
Defaults to 0.
target_expression: str, optional
Expression the output node should compile to. Numbers must be written as float.
Defaults to None.
Returns
----------
None
Expand All @@ -388,6 +402,16 @@ def set_expression_for_output(
self.change_address_gene_of_output_node(
new_address=last_inserted_node, output_node_idx=output_node_idx
)
if target_expression is not None:
Comment thread
HenrikMettler marked this conversation as resolved.
Outdated
if self._n_outputs > 1:
output_as_sympy = CartesianGraph(self).to_sympy()[output_node_idx]
else:
output_as_sympy = CartesianGraph(self).to_sympy()

target_expression_as_sympy = sympy.parse_expr(target_expression)
if not output_as_sympy == target_expression_as_sympy:
print(target_expression_as_sympy, output_as_sympy)
Comment thread
HenrikMettler marked this conversation as resolved.
Outdated
raise ValueError("Target expression and set output expression do not match")
Comment thread
HenrikMettler marked this conversation as resolved.
Outdated

def reorder(self, rng: np.random.RandomState) -> None:
"""Reorder the genome
Expand Down
20 changes: 18 additions & 2 deletions test/test_genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,9 +843,25 @@ def test_set_expression_for_output(genome_params, rng):
assert CartesianGraph(genome).to_sympy() == x_0 + x_1

new_dna = [1, 0, 1]
genome.set_expression_for_output(new_dna)
genome.set_expression_for_output(dna_insert=new_dna, target_expression=" x_0 - x_1")
Comment thread
HenrikMettler marked this conversation as resolved.
Outdated
assert CartesianGraph(genome).to_sympy() == x_0 - x_1

new_dna = [0, 0, 1, 2, 0, 0, 1, 0, 0, 0, 2, 3] # x_0+x_1; 1.0; 0; x_0+x_1 + 1.0
genome.set_expression_for_output(new_dna)
genome.set_expression_for_output(dna_insert=new_dna, target_expression=" x_0 + x_1 + 1.0")
Comment thread
HenrikMettler marked this conversation as resolved.
Outdated
assert CartesianGraph(genome).to_sympy() == x_0 + x_1 + 1.0

with pytest.raises(ValueError):
genome.set_expression_for_output(dna_insert=new_dna, target_expression=" x_0 + x_1 + 1")
Comment thread
HenrikMettler marked this conversation as resolved.
Outdated

genome2_params = {
"n_inputs": 2,
"n_outputs": 2,
"primitives": (cgp.Add, cgp.Sub, cgp.ConstantFloat),
}
genome2 = cgp.Genome(**genome2_params)
genome2.randomize(rng)

genome2.set_expression_for_output(
new_dna, output_node_idx=1, target_expression=" x_0 + x_1 + 1.0"
Comment thread
HenrikMettler marked this conversation as resolved.
Outdated
)
assert CartesianGraph(genome2).to_sympy()[1] == x_0 + x_1 + 1.0