File size: 3,521 Bytes
694c1c6
 
 
 
1bcc7b4
694c1c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b40aac1
694c1c6
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from constants import *


def groupby(array_like, hue_order=None):
    """Yield ``(group_value, positions)`` for each distinct value of *array_like*.

    Args:
        array_like: 1-D sequence (list, tuple or ndarray) of sortable values.
        hue_order: Optional list of group values to restrict the output to.
            Only a ``list`` is honoured; any other type is ignored and every
            group is yielded.  Entries that do not occur in *array_like* are
            skipped.

    Yields:
        Tuples ``(value, positions)`` where ``positions`` is an integer
        ndarray holding the positions in *array_like* equal to ``value``,
        in their original relative order (the sort is stable).

    Note:
        Groups are always emitted in ascending order of the group value.
        When *hue_order* is given it is sorted before iteration, so it acts
        as a filter and does not dictate the output order.
    """
    # Accept plain Python sequences too: fancy indexing below needs an ndarray.
    data = np.asarray(array_like)
    order = np.argsort(data, kind='stable')
    # On sorted data, `return_index` gives the start of each run of equal
    # values, which is exactly where `order` has to be split.
    values, run_starts = np.unique(data[order], return_index=True)
    name2indices = dict(zip(values, np.split(order, run_starts[1:])))

    keys = hue_order if isinstance(hue_order, list) else name2indices
    for key in sorted(keys):
        if key in name2indices:
            yield key, name2indices[key]


def draw_barplots(targets_list, label_list=None, top_n=5, bin_width=1,
                  hue_group_offset=0.5, hue_order=None,
                  hue2count=None, width=0.9, ax=None, show_legend=True,
                  palette='tab10'):
    """Draw grouped horizontal bars of the most frequent labels per hue group.

    Samples are grouped by their value in *targets_list*.  Within each group
    the occurrences of every distinct label are counted, optionally divided
    by a per-hue denominator, and the *top_n* largest are drawn as one
    contiguous run of horizontal bars.  A dashed line is drawn after each
    group.

    Args:
        targets_list: 1-D sequence of keys into *hue_order*, one per sample.
        label_list: Per-sample labels to count, positionally aligned with
            *targets_list*.  Defaults to ``hue_order[target]`` per sample.
        top_n: Maximum number of bars drawn per hue group.
        bin_width: Vertical space reserved per bar when computing the start
            offset of each hue group.
        hue_group_offset: Extra vertical gap between consecutive hue groups.
        hue_order: Maps a target value to its hue label via ``hue_order[t]``.
            When it contains the label, the label's position (``.index``)
            selects the palette colour; otherwise the group's enumeration
            index is used.  ``None`` is treated as an empty list.
        hue2count: Mapping hue label -> denominator that the counts of that
            group are divided by.  Labels missing from the mapping are
            divided by 1.  ``None`` is treated as an empty mapping.
        width: Bar thickness handed to ``Axes.barh`` as ``height``.
        ax: Matplotlib axes to draw on; defaults to the current axes.
        show_legend: Whether to add a legend below the axes.
        palette: Seaborn palette name, or an indexable collection of colours.

    Returns:
        None.  The plot is drawn on *ax* as a side effect.
    """
    hue_order = [] if hue_order is None else hue_order
    hue2count = {} if hue2count is None else hue2count
    if ax is None:
        ax = plt.gca()
    if isinstance(palette, str):
        palette = sns.color_palette(palette)

    # Work on ndarrays so positional fancy indexing is valid for list input.
    targets = np.asarray(targets_list)
    if label_list is None:
        labels = np.asarray([hue_order[x] for x in targets])
    else:
        labels = np.asarray(label_list)

    hue_values = np.unique(targets)
    # Each hue group starts `top_n * bin_width + hue_group_offset` after the
    # previous one.
    hue_offset = np.arange(len(hue_values)) * (top_n * bin_width + hue_group_offset)
    hue_label2offset = {hue_order[k]: v for k, v in zip(hue_values, hue_offset)}

    tick_positions = []
    tick_labels = []
    max_x_value = 0

    for idx, (hue_index, hue_indices) in enumerate(groupby(targets)):
        hue_label = hue_order[hue_index]
        bin_labels, bin_counts = np.unique(labels[hue_indices], return_counts=True)

        # Normalise by the per-hue denominator (1 when none was supplied).
        bin_counts = bin_counts / hue2count.get(hue_label, 1)
        max_x_value = max(max_x_value, bin_counts.max())

        if hue_label in hue_order:
            color_index = hue_order.index(hue_label)
        else:
            color_index = idx

        # Keep only the `top_n` largest bins, largest first.
        top_indices = np.argsort(bin_counts)[::-1][:top_n]
        bin_labels = bin_labels[top_indices]
        bin_counts = bin_counts[top_indices]

        # NOTE(review): bars inside a group are spaced 1 unit apart regardless
        # of `bin_width`, which only affects the group start offsets -- confirm
        # whether this is intended before using bin_width != 1.
        bin_indices = hue_label2offset[hue_label] + np.arange(len(bin_labels))
        tick_positions.extend(bin_indices)
        tick_labels.extend(bin_labels)

        ax.barh(
            bin_indices, bin_counts, width, label=hue_label,
            color=palette[color_index])
        # Dashed separator just past the last bar of this group.
        line_pos = bin_indices.max() + width/2 + hue_group_offset/2
        ax.axhline(line_pos, linewidth=1, linestyle='dashed', color=POSTER_BLUE)

    if show_legend:
        ax.legend(
            loc='upper center', bbox_to_anchor=(0.5, -0.05),
            fancybox=True, shadow=True,
            ncol=4
        )

    ax.set_yticks(tick_positions)
    ax.set_yticklabels(tick_labels)
    if max_x_value <= 1:
        ax.set_xlim(0, 1.)
    # With no samples there are no ticks; np.max on an empty list would raise.
    if tick_positions:
        ax.set_ylim(-0.5, np.max(tick_positions)+width/2+hue_group_offset/2)
    ax.invert_yaxis()