diff --git a/cgp/genome.py b/cgp/genome.py index db72b9be..9b48a5ad 100644 --- a/cgp/genome.py +++ b/cgp/genome.py @@ -359,7 +359,11 @@ def change_address_gene_of_output_node(self, new_address: int, output_node_idx: self.dna = dna def set_expression_for_output( - self, dna_insert: List[int], hidden_start_node: int = 0, output_node_idx: int = 0 + self, + dna_insert: List[int], + target_expression: str, + hidden_start_node: int = 0, + output_node_idx: int = 0, ): """Set an expression for one output node @@ -370,6 +374,8 @@ def set_expression_for_output( ---------- dna_insert: List[int] dna segment to be inserted at the first hidden nodes. + target_expression: str + Expression the output node should compile to. Numbers must be written as float. hidden_start_node: int Index of the hidden node, where the insert starts. Relative to the first hidden node. @@ -388,6 +394,22 @@ def set_expression_for_output( self.change_address_gene_of_output_node( new_address=last_inserted_node, output_node_idx=output_node_idx ) + try: + import sympy + + except ModuleNotFoundError: + raise ModuleNotFoundError( + "Can not check output expression. No module named 'sympy' (extra requirement)" + ) + + if self._n_outputs > 1: + output_as_sympy = CartesianGraph(self).to_sympy()[output_node_idx] + else: + output_as_sympy = CartesianGraph(self).to_sympy() + + target_expression_as_sympy = sympy.parse_expr(target_expression) + if not output_as_sympy == target_expression_as_sympy: + raise ValueError("expression of output and target expression do not match") def reorder(self, rng: np.random.RandomState) -> None: """Reorder the genome diff --git a/test/test_genome.py b/test/test_genome.py index 34ae3b47..ffeb43ef 100644 --- a/test/test_genome.py +++ b/test/test_genome.py @@ -835,17 +835,35 @@ def test_set_expression_for_output(genome_params, rng): genome = cgp.Genome(**genome_params) genome.randomize(rng) - new_dna = [0, 0, 1] - genome.set_expression_for_output(new_dna) - x_0 = sympy.symbols("x_0") x_1 = sympy.symbols("x_1") + + new_dna = [0, 0, 1] + genome.set_expression_for_output(new_dna, target_expression="x_0 + x_1") assert CartesianGraph(genome).to_sympy() == x_0 + x_1 new_dna = [1, 0, 1] - genome.set_expression_for_output(new_dna) + genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 - x_1") assert CartesianGraph(genome).to_sympy() == x_0 - x_1 new_dna = [0, 0, 1, 2, 0, 0, 1, 0, 0, 0, 2, 3] # x_0+x_1; 1.0; 0; x_0+x_1 + 1.0 - genome.set_expression_for_output(new_dna) + genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 + x_1 + 1.0") assert CartesianGraph(genome).to_sympy() == x_0 + x_1 + 1.0 + + with pytest.raises(ValueError): + # setting an int in the str causes an error + genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 + x_1 + 1") + genome.set_expression_for_output(dna_insert=new_dna, target_expression="x_0 + x_1 * 1.0") + + genome2_params = { + "n_inputs": 2, + "n_outputs": 2, + "primitives": (cgp.Add, cgp.Sub, cgp.ConstantFloat), + } + genome2 = cgp.Genome(**genome2_params) + genome2.randomize(rng) + + genome2.set_expression_for_output( + new_dna, output_node_idx=1, target_expression="x_0 + x_1 + 1.0" + ) + assert CartesianGraph(genome2).to_sympy()[1] == x_0 + x_1 + 1.0