diff --git a/src/scenicplus/plotting/dotplot.py b/src/scenicplus/plotting/dotplot.py index de01e30..f5f8ae4 100644 --- a/src/scenicplus/plotting/dotplot.py +++ b/src/scenicplus/plotting/dotplot.py @@ -102,10 +102,10 @@ def _find_idx(l, e): size_matrix = size_matrix[color_matrix_features[2]] if subset_eRegulons is not None: - #change to TF names - subset_eRegulons = [x.split('_')[0] for x in subset_eRegulons] - size_matrix = size_matrix[[x for x in size_matrix if x.split('_')[0] in subset_eRegulons]] - color_matrix = color_matrix[[x for x in color_matrix if x.split('_')[0] in subset_eRegulons]] + #filter eRegulon + subset_eRegulons = [x.split('_(')[0] for x in subset_eRegulons] + size_matrix = size_matrix[[x for x in size_matrix if x.split('_(')[0] in subset_eRegulons]] + color_matrix = color_matrix[[x for x in color_matrix if x.split('_(')[0] in subset_eRegulons]] if scale_size_matrix: size_matrix = (size_matrix - size_matrix.min()) / (size_matrix.max() - size_matrix.min()) @@ -204,8 +204,12 @@ def heatmap_dotplot( order = pd.concat([idx_max[idx_max == x] for x in tmp.index.tolist() if len(plotting_df[plotting_df == x]) > 0]).index.tolist() plotting_df['eRegulon_name'] = pd.Categorical(plotting_df['eRegulon_name'], categories = order) plotnine.options.figure_size = figsize + #check repressor availability + plotting_df['repressor_activator'] = ['activator' if '+' in n.split('_')[1] and 'extended' not in n or '+' in n.split('_')[2] and 'extended' in n else 'repressor' for n in plotting_df['eRegulon_name']] + repressor_availability = 'repressor' in plotting_df['repressor_activator'] + if not repressor_availability: + split_repressor_activator = False if split_repressor_activator: - plotting_df['repressor_activator'] = ['activator' if '+' in n.split('_')[1] and 'extended' not in n or '+' in n.split('_')[2] and 'extended' in n else 'repressor' for n in plotting_df['eRegulon_name']] if orientation == 'vertical': plot = ( ggplot(plotting_df, aes('index', 'eRegulon_name')) @@ -218,7 +222,6 @@ def heatmap_dotplot( + geom_point( mapping = aes(size = 'size_val'), colour = "black") - + theme(axis_text_x=element_text(rotation=90, hjust=1)) + theme(axis_title_x = element_blank(), axis_title_y = element_blank())) elif orientation == 'horizontal': plot = ( @@ -252,6 +255,7 @@ def heatmap_dotplot( + geom_point( mapping = aes(size = 'size_val'), colour = "black") + + theme(axis_text_x=element_text(rotation=90, hjust=1)) + theme(axis_title_x = element_blank(), axis_title_y = element_blank())) if save is not None: plot.save(save)