File size: 3,848 Bytes
694c1c6
 
d03fbaa
 
1bcc7b4
694c1c6
 
 
 
 
 
 
 
 
 
 
2d1d8cb
 
694c1c6
af49af1
2d1d8cb
 
 
b40aac1
2d1d8cb
b40aac1
 
 
 
7f8d6ba
694c1c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b40aac1
 
 
694c1c6
b40aac1
 
694c1c6
b40aac1
 
694c1c6
b40aac1
 
694c1c6
af49af1
2d1d8cb
 
 
b40aac1
2d1d8cb
b40aac1
 
 
 
694c1c6
af49af1
694c1c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b40aac1
694c1c6
 
 
 
 
 
 
b40aac1
 
 
694c1c6
b40aac1
 
694c1c6
b40aac1
 
694c1c6
b40aac1
694c1c6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114

import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('agg')

import plot_utils
from constants import *


class MatplotlibDataPlotter:
    """Render bar plots for single-domain and domain-pair tables.

    The instance owns two matplotlib figures, one per plot type. Each
    ``plot_*`` call clears and redraws its figure in place, so the figure
    object returned by a method is the same one on every call.

    The data frames are read by column name:

    * ``num_domains_in_region_df``: ``cds_region_id``, ``num_domains``
    * ``single_df`` / ``pair_df``: ``cds_region_id``, ``biosyn_class``,
      ``biosyn_class_index``, ``profile_name`` and one
      ``cosine_similarity_<split_name>`` column per supported split.
    """

    # Layout settings forwarded unchanged to ``plot_utils.draw_barplots``.
    _TOP_N = 5
    _BIN_WIDTH = 1
    _HUE_GROUP_OFFSET = 0.5
    _BAR_WIDTH = 0.9

    def __init__(self, single_df, pair_df, num_domains_in_region_df):
        """Store the input tables and create the two reusable figures.

        Args:
            single_df: Table with one row per single-domain entry.
            pair_df: Table with one row per domain-pair entry.
            num_domains_in_region_df: Table giving ``num_domains`` for each
                ``cds_region_id``; used to filter regions before plotting.
        """
        self.single_df = single_df
        self.pair_df = pair_df

        self.num_domains_in_region_df = num_domains_in_region_df

        self.single_domains_fig = plt.figure(figsize=(5, 10))
        self.pair_domains_fig = plt.figure(figsize=(5, 10))

    def plot_single_domains(self, num_domains, split_name="stratified"):
        """Redraw and return the single-domain figure.

        Args:
            num_domains: Only regions whose ``num_domains`` is greater than
                or equal to this value are plotted.
            split_name: Suffix selecting the ``cosine_similarity_<split_name>``
                column used to pick one row per region.

        Returns:
            ``self.single_domains_fig`` after it has been cleared and redrawn.
        """
        return self._plot_top_matches(
            self.single_df, self.single_domains_fig, num_domains, split_name
        )

    def plot_pair_domains(self, num_domains, split_name="stratified"):
        """Redraw and return the domain-pair figure.

        Args:
            num_domains: Only regions whose ``num_domains`` is greater than
                or equal to this value are plotted.
            split_name: Suffix selecting the ``cosine_similarity_<split_name>``
                column used to pick one row per region.

        Returns:
            ``self.pair_domains_fig`` after it has been cleared and redrawn.
        """
        return self._plot_top_matches(
            self.pair_df, self.pair_domains_fig, num_domains, split_name
        )

    def _plot_top_matches(self, source_df, fig, num_domains, split_name):
        """Filter ``source_df`` by region size and draw it onto ``fig``.

        Shared implementation behind ``plot_single_domains`` and
        ``plot_pair_domains``; the two differ only in which table is read
        and which figure is drawn on.
        """
        region_sizes = self.num_domains_in_region_df
        selected_region_ids = region_sizes.loc[
            region_sizes.num_domains >= num_domains, 'cds_region_id'
        ].values

        subset = source_df.loc[
            source_df.cds_region_id.isin(selected_region_ids)
        ]

        # Map each biosyn_class to the number of distinct regions carrying
        # it: de-duplicate (region, class) pairs, then count per class.
        class_counts = (
            subset[['cds_region_id', 'biosyn_class']]
            .drop_duplicates()
            .groupby("biosyn_class", as_index=False)
            .count()
        )
        hue2count = dict(class_counts.values)

        # For every region keep the index label of the row with the highest
        # similarity under the requested split.
        # NOTE(review): the labels are fed to ``.loc`` below, so this assumes
        # the index of ``source_df`` is unique - confirm with the caller.
        column_name = f'cosine_similarity_{split_name}'
        best_row_labels = subset.groupby('cds_region_id').agg(
            {column_name: 'idxmax'}
        ).values.flatten()

        targets_list = subset.loc[best_row_labels, 'biosyn_class_index'].values
        label_list = subset.loc[best_row_labels, 'profile_name'].values

        # Reuse the same figure object: wipe the previous drawing first.
        fig.clf()
        ax = fig.gca()
        plot_utils.draw_barplots(
            targets_list,
            label_list=label_list,
            top_n=self._TOP_N,
            bin_width=self._BIN_WIDTH,
            hue_group_offset=self._HUE_GROUP_OFFSET,
            hue_order=BIOSYN_CLASS_NAMES,
            hue2count=hue2count,
            width=self._BAR_WIDTH,
            ax=ax,
            show_legend=False,
            palette=COLOR_PALETTE
        )
        fig.tight_layout()
        return fig