Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 0 additions & 93 deletions ot/helpers/openmp_helpers.py

This file was deleted.

1 change: 0 additions & 1 deletion ot/lp/EMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ enum ProblemType {
};

int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init);
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);

int EMD_wrap_sparse(
int n1,
Expand Down
87 changes: 0 additions & 87 deletions ot/lp/EMD_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@


#include "network_simplex_simple.h"
#include "network_simplex_simple_omp.h"
#include "sparse_bipartitegraph.h"
#include "EMD.h"
#include <cstdint>
Expand Down Expand Up @@ -239,92 +238,6 @@ int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
}







int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) {
// beware M and C are stored in row major C style!!!

using namespace lemon_omp;
uint64_t n, m, cur;

typedef FullBipartiteDigraph Digraph;
DIGRAPH_TYPEDEFS(Digraph);

// Get the number of non zero coordinates for r and c
n=0;
for (int i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
n++;
}else if(val<0){
return INFEASIBLE;
}
}
m=0;
for (int i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
m++;
}else if(val<0){
return INFEASIBLE;
}
}

// Define the graph

std::vector<uint64_t> indI(n), indJ(m);
std::vector<double> weights1(n), weights2(m);
Digraph di(n, m);
const SetupPolicy policy = make_setup_policy(n, m, n1, n2, false);
NetworkSimplexSimple<Digraph,double,double, node_id_type> net(
di, policy.use_arc_mixing, (int) (n + m), n * m, maxIter, numThreads
);

// Set supply and demand, don't account for 0 values (faster)

cur=0;
for (uint64_t i=0; i<n1; i++) {
double val=*(X+i);
if (val>0) {
weights1[ cur ] = val;
indI[cur++]=i;
}
}

// Demand is actually negative supply...

cur=0;
for (uint64_t i=0; i<n2; i++) {
double val=*(Y+i);
if (val>0) {
weights2[ cur ] = -val;
indJ[cur++]=i;
}
}


net.supplyMap(&weights1[0], (int) n, &weights2[0], (int) m);

setup_explicit_arc_costs(net, di, D, n2, indI, indJ, n, m);

// Solve the problem with the network simplex algorithm

int ret=net.run();
if (ret==(int)net.OPTIMAL || ret==(int)net.MAX_ITER_REACHED) {
*cost = 0;
extract_compressed_support(
net, di, INVALID, D, G, alpha, beta, cost, indI, indJ, n, n2
);

}

return ret;
}

// ============================================================================
// SPARSE VERSION: Accepts edge list instead of dense cost matrix
// ============================================================================
Expand Down
14 changes: 8 additions & 6 deletions ot/lp/_barycenter_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,10 @@ def free_support_barycenter(
Print information along iterations
log : bool, optional
record log if True
numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
If compiled with OpenMP, chooses the number of threads to parallelize.
"max" selects the highest number possible.
numThreads: int or "max", optional (default=1)
Deprecated compatibility argument. Multi-threaded EMD is no longer
supported and any value greater than 1 falls back to the
single-threaded solver with a warning.

Returns
-------
Expand Down Expand Up @@ -349,9 +350,10 @@ def generalized_free_support_barycenter(
Print information along iterations
log : bool, optional
record log if True
numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
If compiled with OpenMP, chooses the number of threads to parallelize.
"max" selects the highest number possible.
numThreads: int or "max", optional (default=1)
Deprecated compatibility argument. Multi-threaded EMD is no longer
supported and any value greater than 1 falls back to the
single-threaded solver with a warning.
eps: Stability coefficient for the change of variable matrix inversion
If the :math:`\mathbf{P}_i^T` matrices don't span :math:`\mathbb{R}^d`, the problem is ill-defined and a matrix
inversion will fail. In this case one may set eps=1e-8 and get a solution anyway (which may make little sense)
Expand Down
14 changes: 8 additions & 6 deletions ot/lp/_network_simplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,10 @@ def emd(
center_dual: boolean, optional (default=True)
If True, centers the dual potential using function
:py:func:`ot.lp.center_ot_dual`.
numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
If compiled with OpenMP, chooses the number of threads to parallelize.
"max" selects the highest number possible.
numThreads: int or "max", optional (default=1)
Deprecated compatibility argument. Multi-threaded EMD is no longer
supported and any value greater than 1 falls back to the
single-threaded solver with a warning.
check_marginals: bool, optional (default=True)
If True, checks that the marginals mass are equal. If False, skips the
check.
Expand Down Expand Up @@ -590,9 +591,10 @@ def emd2(
center_dual: boolean, optional (default=True)
If True, centers the dual potential using function
:py:func:`ot.lp.center_ot_dual`.
numThreads: int or "max", optional (default=1, i.e. OpenMP is not used)
If compiled with OpenMP, chooses the number of threads to parallelize.
"max" selects the highest number possible.
numThreads: int or "max", optional (default=1)
Deprecated compatibility argument. Multi-threaded EMD is no longer
supported and any value greater than 1 falls back to the
single-threaded solver with a warning.
check_marginals: bool, optional (default=True)
If True, checks that the marginals mass are equal. If False, skips the
check.
Expand Down
12 changes: 7 additions & 5 deletions ot/lp/emd_wrap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import warnings

cdef extern from "EMD.h":
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil
int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
int EMD_wrap_lazy(int n1, int n2, double *X, double *Y, double *coords_a, double *coords_b, int dim, int metric, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, double* alpha_init, double* beta_init) nogil
cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED
Expand Down Expand Up @@ -128,12 +127,15 @@ def emd_c(np.ndarray[double, ndim=1, mode="c"] a, np.ndarray[double, ndim=1, mod
alpha_init_ptr = <double*> alpha_init_c.data
beta_init_ptr = <double*> beta_init_c.data

if numThreads != 1:
warnings.warn(
"numThreads is no longer supported for EMD; falling back to the single-threaded solver.",
UserWarning,
)

# calling the function
with nogil:
if numThreads == 1:
result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, alpha_init_ptr, beta_init_ptr)
else:
result_code = EMD_wrap_omp(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, numThreads)
result_code = EMD_wrap(n1, n2, <double*> a.data, <double*> b.data, <double*> M.data, <double*> G.data, <double*> alpha.data, <double*> beta.data, <double*> &cost, max_iter, alpha_init_ptr, beta_init_ptr)
return G, cost, alpha, beta, result_code


Expand Down
Loading
Loading