Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion westpa_wexplore/wex_utils.so
27 changes: 18 additions & 9 deletions westpa_wexplore/wexplore.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import networkx as nx
import pandas as pd
import heapq
import cPickle as pickle
#KFW import cPickle as pickle
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could probably be better handled by:

try:
    import cPickle as pickle
except ImportError:
    import pickle

I guess this depends on whether we want to maintain python 2/3 compatibility or if we're dropping python2 completely.

import pickle
import hashlib
import logging
import itertools
Expand Down Expand Up @@ -30,7 +31,8 @@ def __init__(self, n_regions, d_cut, dfunc, dfargs=None, dfkwargs=None):
self.centers = []

# List of bin indices for each level
self.level_indices = [[] for k in xrange(self.n_levels)]
#KFW self.level_indices = [[] for k in xrange(self.n_levels)]
self.level_indices = [[] for k in range(self.n_levels)]

# Directed graph containing that defines the connectivity
# of the hierarchical bin space.
Expand Down Expand Up @@ -102,11 +104,14 @@ def _assign_level(self, coords, centers, mask, output, min_dist):
return output

def dump_graph(self):
print ''
#KFW print ''
print('')
for li, nodes in enumerate(self.level_indices):
print 'Level: ', li
#KFW print 'Level: ', li
print('Level: ', li)
for nix in nodes:
print nix, self.bin_graph.node[nix]
#KFW print nix, self.bin_graph.node[nix]
print(nix, self.bin_graph.node[nix])

def _distribute_to_children(self, G, output, coord_indices, node_indices):
s = pd.Series(output, copy=False)
Expand Down Expand Up @@ -144,15 +149,17 @@ def _prune_violations(self):

self._assign_level(centers, pcenters, mask, output, min_dist)

for k in xrange(ncenters):
#KFW for k in xrange(ncenters):
for k in range(ncenters):
if output[k] != prev_level_nodes.index(parent_nodes[k]):
nix = level_indices[k]
succ = nx.algorithms.traversal.dfs_successors(G, nix).values()
nodes_remove.extend(list(itertools.chain.from_iterable(succ)) + [nix])

if len(nodes_remove):
G.remove_nodes_from(nodes_remove)
for k in xrange(self.n_levels):
#KFW for k in xrange(self.n_levels):
for k in range(self.n_levels):
self.level_indices[k] = [nix for nix in self.level_indices[k] if nix not in nodes_remove]

def assign(self, coords, mask=None, output=None, add_bins=False):
Expand Down Expand Up @@ -211,7 +218,8 @@ def assign(self, coords, mask=None, output=None, add_bins=False):
if coord_ix >= 0:
new_bins.append((0, None, coord_ix))

for lid in xrange(1, self.n_levels):
#KFW for lid in xrange(1, self.n_levels):
for lid in range(1, self.n_levels):
next_obs_nodes = []
for nix in obs_nodes:
successors = list(G.successors(nix))
Expand Down Expand Up @@ -360,7 +368,8 @@ def balance_replicas(self, max_replicas, assignments):
for top_node in self.level_indices[0]:
for nix in nx.algorithms.traversal.dfs_postorder_nodes(G, top_node):
try:
pred = G.pred[nix].keys()[0] # parent node
#KFW pred = G.pred[nix].keys()[0] # parent node
pred = list(G.pred[nix].keys())[0] # parent node
G.node[pred]['nreplicas'] += G.node[nix]['nreplicas']
except IndexError:
pass
Expand Down