Spaces:
Sleeping
Sleeping
import numpy as np | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from constants import * | |
def groupby(array_like, hue_order=None):
    """Yield ``(key, positions)`` pairs grouping equal values of a 1-D sequence.

    Parameters
    ----------
    array_like : array-like
        1-D sequence of sortable, hashable values. Plain lists are accepted
        (converted with ``np.asarray``).
    hue_order : list, optional
        When a ``list`` is given, only keys that appear in it are yielded,
        in ``sorted(hue_order)`` order. Any non-list value (including a
        tuple) is ignored and every key is yielded.
        NOTE(review): the list's own ordering is NOT honoured (it is sorted
        first); this preserves the historical behaviour.

    Yields
    ------
    key, positions
        Each distinct value in ascending order, with an integer array of
        the positions where it occurs in ``array_like`` (ascending).
    """
    values = np.asarray(array_like)
    # A stable sort keeps original positions increasing within each group.
    order = np.argsort(values, kind='stable')
    # In the sorted data, the first occurrence of each key marks the start
    # of its group; splitting ``order`` there yields per-key positions.
    keys, group_starts = np.unique(values[order], return_index=True)
    name2indices = dict(zip(keys, np.split(order, group_starts[1:])))
    if hue_order is not None and isinstance(hue_order, list):
        for k in sorted(hue_order):
            if k in name2indices:
                yield k, name2indices[k]
        return
    for k in sorted(name2indices):
        yield k, name2indices[k]
def draw_barplots(targets_list, label_list=None, top_n=5, bin_width=1,
                  hue_group_offset=0.5, hue_order=None,
                  hue2count=None, width=0.9, ax=None, show_legend=True,
                  palette='tab10'):
    """Draw grouped horizontal bars of the most frequent labels per hue.

    Each distinct value of ``targets_list`` is a "hue". For every hue, the
    ``top_n`` most frequent entries of ``label_list`` among that hue's
    samples are drawn as horizontal bars (largest first). Hue groups are
    laid out one after another along the y axis, separated by a dashed
    line, and every bar gets its label as a y tick label.

    Parameters
    ----------
    targets_list : array-like
        Per-sample hue key; each value is used to index ``hue_order``.
    label_list : array-like, optional
        Per-sample label, aligned with ``targets_list``. Defaults to
        ``hue_order[target]`` for each sample.
    top_n : int
        Maximum number of bars drawn per hue group.
    bin_width : float
        Multiplier on ``top_n`` when spacing consecutive hue groups; bars
        inside a group are always one unit apart.
    hue_group_offset : float
        Extra gap between consecutive hue groups; the dashed separator is
        drawn half-way into this gap.
    hue_order : list or mapping
        Maps a hue key to its display label (used as the legend entry and
        for the ``hue2count`` lookup). When the label is a member of
        ``hue_order``, its position there selects the palette colour.
    hue2count : dict, optional
        Maps a hue label to a denominator. Bar lengths are counts divided by
        it; hues not listed use 1 (raw counts).
    width : float
        Bar thickness passed to ``Axes.barh``.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to the current axes (``plt.gca()``).
    show_legend : bool
        Whether to add a legend centred below the axes.
    palette : str or sequence of colours
        Seaborn palette name, or an indexable sequence of colours.
    """
    # None sentinels replace shared mutable defaults; behaviour is the same.
    if hue_order is None:
        hue_order = []
    if hue2count is None:
        hue2count = {}
    if ax is None:
        ax = plt.gca()
    if isinstance(palette, str):
        palette = sns.color_palette(palette)

    # Arrays allow fancy indexing below, so plain lists are accepted too.
    targets_list = np.asarray(targets_list)
    if label_list is None:
        label_list = np.asarray([hue_order[x] for x in targets_list])
    else:
        label_list = np.asarray(label_list)

    hue_values = np.unique(targets_list)
    bin_size = top_n
    # Each hue group reserves bin_size slots plus a gap of hue_group_offset.
    # Offsets are keyed by hue key (not label) so duplicate labels in
    # hue_order cannot make two groups overlap.
    group_stride = bin_size * bin_width + hue_group_offset
    hue2offset = dict(zip(hue_values, np.arange(len(hue_values)) * group_stride))

    tick_positions = []
    tick_labels = []
    max_x_value = 0
    for idx, (hue_index, hue_indices) in enumerate(groupby(targets_list)):
        hue_label = hue_order[hue_index]
        bin_labels, bin_counts = np.unique(label_list[hue_indices], return_counts=True)
        bin_counts = bin_counts / hue2count.get(hue_label, 1)
        max_x_value = max(max_x_value, bin_counts.max())
        # Colour by the label's position in hue_order when it is a member
        # (e.g. a list); otherwise (e.g. a dict keyed by hue) use group order.
        if hue_label in hue_order:
            color_index = hue_order.index(hue_label)
        else:
            color_index = idx

        # Keep only the bin_size largest bars, largest first.
        top_indices = np.argsort(bin_counts)[::-1][:bin_size]
        bin_labels = bin_labels[top_indices]
        bin_counts = bin_counts[top_indices]
        bin_indices = hue2offset[hue_index] + np.arange(len(bin_labels))
        tick_positions.extend(bin_indices)
        tick_labels.extend(bin_labels)

        ax.barh(bin_indices, bin_counts, width, label=hue_label,
                color=palette[color_index])
        # Dashed separator half-way into the gap after this group's last bar.
        line_pos = bin_indices.max() + width / 2 + hue_group_offset / 2
        ax.axhline(line_pos, linewidth=1, linestyle='dashed', color=POSTER_BLUE)

    if show_legend:
        ax.legend(
            loc='upper center', bbox_to_anchor=(0.5, -0.05),
            fancybox=True, shadow=True,
            ncol=4,
        )
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels)
    if max_x_value <= 1:
        ax.set_xlim(0, 1.)
    # With no samples there are no ticks; np.max([]) would raise.
    if tick_positions:
        ax.set_ylim(-0.5, np.max(tick_positions) + width / 2 + hue_group_offset / 2)
    ax.invert_yaxis()