Spaces:
Sleeping
Sleeping
File size: 4,716 Bytes
3a329d1 7ac370b 3a329d1 7ac370b 3a329d1 7ac370b ca7444f 7ac370b 1bcc7b4 7ac370b b40aac1 ca7444f 7ac370b b40aac1 ca7444f 7ac370b 1bcc7b4 7ac370b b40aac1 f4c1e61 7ac370b 1bcc7b4 7ac370b 2d1d8cb 7ac370b b40aac1 7ac370b 12ee82e 7ac370b 12ee82e 7ac370b b40aac1 7ac370b b40aac1 7ac370b b40aac1 7ac370b 1bcc7b4 12ee82e b40aac1 7ac370b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from constants import *
from mpl_data_plotter import MatplotlibDataPlotter
def convert_int64_to_int32(df):
    """Downcast ``int64`` columns of *df* to ``int32`` in place to save memory.

    A column is only downcast when all of its values fit in the ``int32``
    range; ``astype('int32')`` would otherwise silently wrap large values, so
    such columns are left as ``int64``. The name of every converted column is
    printed.

    Args:
        df: DataFrame to convert. It is modified in place.

    Returns:
        The same DataFrame object, so callers can reassign the result.
    """
    int32_info = np.iinfo(np.int32)
    for col in df.columns:
        if df[col].dtype != 'int64':
            continue
        values = df[col]
        # An empty column trivially fits; min()/max() would be NaN for it.
        fits_int32 = values.empty or (
            values.min() >= int32_info.min and values.max() <= int32_info.max
        )
        if not fits_int32:
            continue
        print(col)
        df[col] = values.astype('int32')
    return df
def _load_domains_table(path):
    """Load one gzip-compressed domains CSV and normalise it for plotting.

    Renames ``bgc_class`` to ``biosyn_class``, adds ``biosyn_class_index``
    (the position of each row's class in ``BIOSYN_CLASS_NAMES``) and downcasts
    ``int64`` columns via ``convert_int64_to_int32``.

    Raises:
        ValueError: if a row's class is not in ``BIOSYN_CLASS_NAMES``
            (raised by ``list.index``).
    """
    table = pd.read_csv(path, compression='gzip')
    table.rename(columns={'bgc_class': 'biosyn_class'}, inplace=True)
    table['biosyn_class_index'] = table.biosyn_class.apply(
        lambda class_name: BIOSYN_CLASS_NAMES.index(class_name)
    )
    return convert_int64_to_int32(table)


print("Loading domains data...")
single_df = _load_domains_table(SINGLE_DOMAINS_FILE)
pair_df = _load_domains_table(PAIR_DOMAINS_FILE)

# Per cds_region_id, the count of non-null as_domain_id values in single_df;
# its distinct values bound the "Min number of domains" slider in the UI.
num_domains_in_region_df = (
    single_df.groupby('cds_region_id', as_index=False)
    .agg({'as_domain_id': 'count'})
    .rename(columns={'as_domain_id': 'num_domains'})
)
unique_domain_lengths = num_domains_in_region_df.num_domains.unique()

print("Initializing data plotter...")
data_plotter = MatplotlibDataPlotter(single_df, pair_df, num_domains_in_region_df)
def create_color_legend(class_to_color):
    """Build a Gradio HTML component showing one colour swatch per class.

    Args:
        class_to_color: Mapping of class name to a CSS colour value; entries
            are rendered in the mapping's iteration order.

    Returns:
        A ``gr.HTML`` component containing the legend markup.
    """
    from html import escape

    # Escape names and colours so values containing markup characters
    # (&, <, >, quotes) cannot break the surrounding HTML/style attribute.
    legend_items = "".join(
        f"""
            <div style="display: flex; align-items: center; gap: 5px;">
                <div style="width: 20px; height: 20px; background-color: {escape(str(color))}; border-radius: 3px;"></div>
                <span>{escape(str(class_name))}</span>
            </div>
        """
        for class_name, color in class_to_color.items()
    )
    legend_html = f"""
    <div style="margin: 10px 0; padding: 10px; border: 1px solid #ddd; border-radius: 4px; background: white;">
        <div style="font-weight: bold; margin-bottom: 8px;">Color Legend:</div>
        <div style="display: flex; flex-wrap: wrap; gap: 15px; align-items: center;">
            {legend_items}
        </div>
    </div>
    """
    return gr.HTML(legend_html)
def update_all_plots(frequency, split_name):
    """Redraw both plots for the current UI selection.

    Args:
        frequency: Current value of the "Min number of domains" slider.
        split_name: Selected model name ("stratified" or one of the
            ``BIOSYN_CLASS_NAMES`` radio choices).

    Returns:
        A 2-tuple ``(single, pair)`` from the module-level ``data_plotter``,
        ordered to match the single- and pair-domain ``gr.Plot`` outputs.
    """
    single_plot = data_plotter.plot_single_domains(frequency, split_name)
    pair_plot = data_plotter.plot_pair_domains(frequency, split_name)
    return single_plot, pair_plot
print("Defining blocks...")
# Gradio interface: controls on top, the two domain plots side by side below.
with gr.Blocks(title="BGC Keyword Plotter") as demo:
    gr.Markdown("## BGC Keyword Plotter")
    gr.Markdown("Select the model name and minimal number of domains in Antismash-db subset.")
    color_legend = create_color_legend(BIOSYN_CLASS_HEX_COLORS)

    with gr.Row():
        # Slider bounds span the observed per-region domain counts.
        frequency_slider = gr.Slider(
            minimum=int(unique_domain_lengths.min()),
            maximum=int(unique_domain_lengths.max()),
            step=1,
            value=int(unique_domain_lengths.min()),
            label="Min number of domains",
        )
        model_selector = gr.Radio(
            choices=["stratified"] + BIOSYN_CLASS_NAMES,
            value="stratified",
            label="Model name",
        )

    with gr.Row():
        with gr.Column():
            single_domains_plot = gr.Plot(
                label="Single domains",
                container=True,
                elem_id="single_domains_plot",
            )
        with gr.Column():
            pair_domains_plot = gr.Plot(label="Pair domains")

    # Releasing the slider, the initial page load and choosing a model all
    # redraw both plots from the same two controls, so register them together.
    # Fresh inputs/outputs lists are built for each registration.
    for register_plot_update in (frequency_slider.release, demo.load, model_selector.input):
        register_plot_update(
            fn=update_all_plots,
            inputs=[frequency_slider, model_selector],
            outputs=[single_domains_plot, pair_domains_plot],
        )

print("Launching!...")
demo.launch()