Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 1 addition & 82 deletions src/bridge.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function MOI.Bridges.Constraint.bridge_constraint(
func.iterators[k].values[idx[k]] for k in eachindex(func.iterators)
]
expanded = _expand(func.func, values)
scalar_func = _convert(F, expanded)
scalar_func = convert(F, expanded)
ci = MOI.Utilities.normalize_and_add_constraint(
model,
scalar_func,
Expand Down Expand Up @@ -164,84 +164,3 @@ function _eval_op(head::Symbol, args::Vector)
)
end
end

# --- Conversion from expanded ScalarNonlinearFunction to target type F ---

function _convert(
::Type{MOI.ScalarNonlinearFunction},
expr::MOI.ScalarNonlinearFunction,
)
return expr
end
_convert(::Type{F}, expr::F) where {F} = expr

function _convert(
::Type{MOI.ScalarAffineFunction{T}},
expr::MOI.ScalarNonlinearFunction,
) where {T}
terms, constant = _collect_affine_terms(T, expr)
if terms === nothing
throw(InexactError(:convert, MOI.ScalarAffineFunction{T}, expr))
end
return MOI.ScalarAffineFunction(terms, T(constant))
end

function _collect_affine_terms(
::Type{T},
expr::MOI.ScalarNonlinearFunction,
) where {T}
if expr.head == :+ && length(expr.args) == 2
t1, c1 = _collect_affine_terms(T, expr.args[1])
t2, c2 = _collect_affine_terms(T, expr.args[2])
(t1 === nothing || t2 === nothing) && return (nothing, zero(T))
return (vcat(t1, t2), c1 + c2)
elseif expr.head == :- && length(expr.args) == 2
t1, c1 = _collect_affine_terms(T, expr.args[1])
t2, c2 = _collect_affine_terms(T, expr.args[2])
(t1 === nothing || t2 === nothing) && return (nothing, zero(T))
neg_t2 = [MOI.ScalarAffineTerm(-t.coefficient, t.variable) for t in t2]
return (vcat(t1, neg_t2), c1 - c2)
elseif expr.head == :- && length(expr.args) == 1
t1, c1 = _collect_affine_terms(T, expr.args[1])
t1 === nothing && return (nothing, zero(T))
neg_t1 = [MOI.ScalarAffineTerm(-t.coefficient, t.variable) for t in t1]
return (neg_t1, -c1)
elseif expr.head == :* && length(expr.args) == 2
a1, a2 = expr.args
if a1 isa Number && a2 isa MOI.VariableIndex
return ([MOI.ScalarAffineTerm(T(a1), a2)], zero(T))
elseif a2 isa Number && a1 isa MOI.VariableIndex
return ([MOI.ScalarAffineTerm(T(a2), a1)], zero(T))
elseif a1 isa Number && a2 isa MOI.ScalarNonlinearFunction
t2, c2 = _collect_affine_terms(T, a2)
t2 === nothing && return (nothing, zero(T))
scaled = [
MOI.ScalarAffineTerm(T(a1) * t.coefficient, t.variable) for
t in t2
]
return (scaled, T(a1) * c2)
elseif a2 isa Number && a1 isa MOI.ScalarNonlinearFunction
t1, c1 = _collect_affine_terms(T, a1)
t1 === nothing && return (nothing, zero(T))
scaled = [
MOI.ScalarAffineTerm(T(a2) * t.coefficient, t.variable) for
t in t1
]
return (scaled, T(a2) * c1)
elseif a1 isa Number && a2 isa Number
return (MOI.ScalarAffineTerm{T}[], T(a1 * a2))
else
return (nothing, zero(T))
end
else
return (nothing, zero(T))
end
end

function _collect_affine_terms(::Type{T}, x::MOI.VariableIndex) where {T}
return ([MOI.ScalarAffineTerm(one(T), x)], zero(T))
end

function _collect_affine_terms(::Type{T}, x::Number) where {T}
return (MOI.ScalarAffineTerm{T}[], T(x))
end
12 changes: 0 additions & 12 deletions test/bridge.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,18 +217,6 @@ function test_expand_with_variable_in_expr()
@test result.args[2] == 3.0
end

function test_convert_to_affine()
x1 = MOI.VariableIndex(1)
# x1 - 1.0 should convert to ScalarAffineFunction
func = MOI.ScalarNonlinearFunction(:-, Any[x1, 1.0])
result = GenOpt._convert(MOI.ScalarAffineFunction{Float64}, func)
@test result isa MOI.ScalarAffineFunction{Float64}
@test length(result.terms) == 1
@test result.terms[1].coefficient == 1.0
@test result.terms[1].variable == x1
@test result.constant == -1.0
end

function test_simple_constraint_group()
# min sum(x) s.t. x[i] >= 1 for i in 1..3
optimizer = _create_optimizer()
Expand Down