Create classification_model_monitor.py
classification_model_monitor.py (ADDED, +243 −0)
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
import logging
from sklearn.metrics import balanced_accuracy_score, accuracy_score
from read_predictions_from_db import PredictionDBRead
from read_daily_metrics_from_db import MetricsDBRead
from config import (CLASSIFIER_ADJUSTMENT_THRESHOLD,
                    PERFORMANCE_THRESHOLD,
                    CLASSIFIER_THRESHOLD)

logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO)


def filter_prediction_data(data: pd.DataFrame):
    """Drop excluded categories and keep only rows whose y_true_proba clears the threshold."""
    try:
        logging.info("Entering filter_prediction_data()")
        if data is None:
            raise Exception("Input Prediction Data frame is None")

        excluded_labels = ['WEATHER', 'EDUCATION', 'ASTROLOGY', 'OTHERS']
        filtered_prediction_data = data.loc[~data['y_true'].isin(excluded_labels) &
                                            (data['y_true_proba'] > CLASSIFIER_THRESHOLD)].copy()

        logging.info("Exiting filter_prediction_data()")
        return filtered_prediction_data
    except Exception as e:
        logging.critical(f"Error in filter_prediction_data(): {e}")
        return None

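# Worked example (hypothetical rows; column names inferred from usage above):
#   df = pd.DataFrame({'y_true': ['SPORTS', 'WEATHER'], 'y_true_proba': [0.90, 0.95]})
#   With CLASSIFIER_THRESHOLD = 0.5, filter_prediction_data(df) keeps only the
#   'SPORTS' row: 'WEATHER' is in the excluded set even though its probability clears the bar.
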
def get_adjusted_predictions(df):
    """Apply post-hoc business rules on top of the raw model predictions."""
    try:
        logging.info("Entering get_adjusted_predictions()")
        if df is None:
            raise Exception('Input Filtered Prediction Data Frame is None')
        df = df.copy()
        df.reset_index(drop=True, inplace=True)
        # Low-confidence predictions fall back to 'NATION'.
        df.loc[df['y_pred_proba'] < CLASSIFIER_ADJUSTMENT_THRESHOLD, 'y_pred'] = 'NATION'
        # Any 'NATION' prediction whose text mentions Pakistan is re-routed to 'WORLD'.
        df.loc[(df['text'].str.contains('Pakistan')) & (df['y_pred'] == 'NATION'), 'y_pred'] = 'WORLD'
        # Astrology-flavoured items are mapped to 'SCIENCE'.
        df.loc[(df['text'].str.contains('Zodiac Sign', case=False)) | (df['text'].str.contains('Horoscope', case=False)), 'y_pred'] = 'SCIENCE'
        logging.info("Exiting get_adjusted_predictions()")
        return df
    except Exception as e:
        logging.critical(f"Error in get_adjusted_predictions(): {e}")
        return None

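# Worked example (hypothetical row): with CLASSIFIER_ADJUSTMENT_THRESHOLD = 0.5,
# a row with y_pred='SPORTS' and y_pred_proba=0.3 is first re-labelled 'NATION';
# if its text also contains 'Pakistan', the second rule then moves it to 'WORLD'.
# Rule order matters: the Pakistan rule only fires on rows labelled 'NATION'.
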
def display_kpis(data: pd.DataFrame, adj_data: pd.DataFrame):
    """Show headline accuracy metrics for the raw and rule-adjusted predictions."""
    try:
        logging.info("Entering display_kpis()")
        if data is None:
            raise Exception("Input Prediction Data frame is None")
        if adj_data is None:
            raise Exception('Input Adjusted Data frame is None')

        n_samples = len(data)
        balanced_accuracy = np.round(balanced_accuracy_score(data['y_true'], data['y_pred']), 4)
        accuracy = np.round(accuracy_score(data['y_true'], data['y_pred']), 4)

        adj_balanced_accuracy = np.round(balanced_accuracy_score(adj_data['y_true'], adj_data['y_pred']), 4)
        adj_accuracy = np.round(accuracy_score(adj_data['y_true'], adj_data['y_pred']), 4)

        st.write('''<style>
            [data-testid="column"] {
                width: calc(33.3333% - 1rem) !important;
                flex: 1 1 calc(33.3333% - 1rem) !important;
                min-width: calc(33% - 1rem) !important;
            }
            </style>''',
            unsafe_allow_html=True)

        col1, col2 = st.columns(2)
        with col1:
            st.metric(label="Balanced Accuracy", value=balanced_accuracy)
        with col2:
            st.metric(label="Adj Balanced Accuracy", value=adj_balanced_accuracy)

        col3, col4 = st.columns(2)
        with col3:
            st.metric(label="Accuracy", value=accuracy)
        with col4:
            st.metric(label="Adj Accuracy", value=adj_accuracy)

        col5, col6 = st.columns(2)
        with col5:
            st.metric(label="Bal Accuracy Threshold", value=PERFORMANCE_THRESHOLD)
        with col6:
            st.metric(label="N Samples", value=n_samples)
        logging.info("Exiting display_kpis()")
    except Exception as e:
        logging.critical(f'Error in display_kpis(): {e}')
        st.error("Couldn't display KPIs")

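# plot_daily_metrics expects these metrics_df columns (inferred from the hover
# configuration inside the function): evaluation_date, mean_balanced_accuracy_score,
# std_balanced_accuracy_score, evaluation_window_days, n_splits, sample_start_date,
# sample_end_date, sample_size_of_each_split.
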
def plot_daily_metrics(metrics_df: pd.DataFrame):
    """Plot daily mean balanced accuracy with a +/- one-stdev error band."""
    try:
        logging.info("Entering plot_daily_metrics()")
        st.write(" ")
        if metrics_df is None:
            raise Exception('Input Metrics Data Frame is None')

        metrics_df = metrics_df.copy()  # avoid mutating the caller's frame
        metrics_df['evaluation_date'] = pd.to_datetime(metrics_df['evaluation_date'])
        metrics_df['mean_score_minus_std'] = np.round(metrics_df['mean_balanced_accuracy_score'] - metrics_df['std_balanced_accuracy_score'], 4)
        metrics_df['mean_score_plus_std'] = np.round(metrics_df['mean_balanced_accuracy_score'] + metrics_df['std_balanced_accuracy_score'], 4)

        hover_data = {'mean_balanced_accuracy_score': True,
                      'std_balanced_accuracy_score': False,
                      'mean_score_minus_std': True,
                      'mean_score_plus_std': True,
                      'evaluation_window_days': True,
                      'n_splits': True,
                      'sample_start_date': True,
                      'sample_end_date': True,
                      'sample_size_of_each_split': True}

        hover_labels = {'mean_balanced_accuracy_score': "Mean Score",
                        'mean_score_minus_std': "Mean Score - Stdev",
                        'mean_score_plus_std': "Mean Score + Stdev",
                        'evaluation_window_days': "Observation Window (Days)",
                        'sample_start_date': "Observation Window Start Date",
                        'sample_end_date': "Observation Window End Date",
                        'n_splits': "N Splits For Evaluation",
                        'sample_size_of_each_split': "Sample Size of Each Split"}

        fig = px.line(data_frame=metrics_df, x='evaluation_date',
                      y='mean_balanced_accuracy_score',
                      error_y='std_balanced_accuracy_score',
                      title="Daily Balanced Accuracy",
                      color_discrete_sequence=['black'],
                      hover_data=hover_data, labels=hover_labels, markers=True)

        fig.add_hline(y=PERFORMANCE_THRESHOLD, line_dash="dash", line_color="green",
                      annotation_text="<b>THRESHOLD</b>",
                      annotation_position="left top")

        fig.update_layout(dragmode='pan', margin=dict(l=0, r=0, t=110, b=10))
        st.plotly_chart(fig, use_container_width=True)
        logging.info("Exiting plot_daily_metrics()")
    except Exception as e:
        logging.critical(f'Error in plot_daily_metrics(): {e}')
        st.error("Couldn't Plot Daily Model Metrics")

def get_misclassified_classes(data):
    """Return the per-class misclassification rate and the misclassified rows."""
    try:
        logging.info("Entering get_misclassified_classes()")
        if data is None:
            raise Exception("Input Prediction Data Frame is None")

        data = data.copy()
        data['match'] = (data['y_true'] == data['y_pred']).astype('int')
        y_pred_counts = data['y_pred'].value_counts()

        misclassified_examples = data.loc[data['match'] == 0, ['text', 'y_true', 'y_pred', 'y_pred_proba', 'url']].copy()
        misclassified_examples.sort_values(by=['y_pred', 'y_pred_proba'], ascending=[True, False], inplace=True)

        # reindex (rather than label-based lookup) so that classes with zero
        # misclassifications don't raise a KeyError
        misclassifications = data.loc[data['match'] == 0, 'y_pred'].value_counts().reindex(y_pred_counts.index, fill_value=0)
        misclassifications /= y_pred_counts
        misclassifications.sort_values(ascending=False, inplace=True)
        logging.info("Exiting get_misclassified_classes()")
        return np.round(misclassifications, 2), misclassified_examples
    except Exception as e:
        logging.critical(f'Error in get_misclassified_classes(): {e}')
        return None, None

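# Worked example: if 'SPORTS' is predicted 40 times and 10 of those rows have a
# different y_true, its misclassification rate is 10 / 40 = 0.25.
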
def display_misclassified_examples(misclassified_classes, misclassified_examples):
    """Render the misclassification-rate bar chart and a table of misclassified rows."""
    try:
        logging.info("Entering display_misclassified_examples()")
        st.write(" ")
        if misclassified_classes is None:
            raise Exception('Misclassified Classes Distribution Data Frame is None')
        if misclassified_examples is None:
            raise Exception('Misclassified Examples Data Frame is None')

        fig, ax = plt.subplots(figsize=(10, 4.5))
        misclassified_classes.plot(kind='bar', ax=ax, color='black', title="Misclassification rate")
        plt.yticks([])
        plt.xlabel("")
        ax.bar_label(ax.containers[0])
        st.pyplot(fig)

        st.markdown("<b>Misclassified examples</b>", unsafe_allow_html=True)
        st.dataframe(misclassified_examples, hide_index=True)
        # hide the dataframe toolbar for a cleaner read-only view
        st.markdown(
            """
            <style>
            [data-testid="stElementToolbar"] {
                display: none;
            }
            </style>
            """,
            unsafe_allow_html=True
        )
        logging.info("Exiting display_misclassified_examples()")
    except Exception as e:
        logging.critical(f'Error in display_misclassified_examples(): {e}')
        st.error("Couldn't display Misclassification Data")

def classification_model_monitor():
    """Top-level page: KPIs, daily metrics plot, and misclassification breakdown."""
    try:
        st.write("<h4>Classification Model Monitor</h4>", unsafe_allow_html=True)

        prediction_db = PredictionDBRead()
        metrics_db = MetricsDBRead()

        # Read prediction data from DB, filter it, and apply the adjustment rules
        prediction_data = prediction_db.read_predictions_from_db()
        filtered_prediction_data = filter_prediction_data(prediction_data)
        adjusted_filtered_prediction_data = get_adjusted_predictions(filtered_prediction_data)
        # Display KPIs
        display_kpis(filtered_prediction_data, adjusted_filtered_prediction_data)

        # Read daily metrics from DB and display the daily metrics line plot
        metrics_df = metrics_db.read_metrics_from_db()
        plot_daily_metrics(metrics_df)

        # Get misclassified class distribution and examples, then display them
        misclassified_classes, misclassified_examples = get_misclassified_classes(filtered_prediction_data)
        display_misclassified_examples(misclassified_classes, misclassified_examples)

        st.markdown(
            """<style>
            [data-testid="stMetricValue"] {
                font-size: 25px;
            }
            </style>
            """, unsafe_allow_html=True
        )

    except Exception as e:
        logging.critical(f"Error in classification_model_monitor(): {e}")
        st.error("Unexpected Error. Couldn't display Classification Model Monitor")
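
Two supporting pieces are not part of this diff. First, the three constants imported from config: their real values are not shown here, so the sketch below uses hypothetical placeholders only.

    # config.py — hypothetical placeholder values, not the project's real thresholds
    CLASSIFIER_THRESHOLD = 0.5             # rows below this confidence are filtered out
    CLASSIFIER_ADJUSTMENT_THRESHOLD = 0.5  # below this, y_pred falls back to 'NATION'
    PERFORMANCE_THRESHOLD = 0.8            # balanced-accuracy alert line on the daily plot

Second, nothing in the file calls classification_model_monitor(), which suggests it is imported as one page of a larger Streamlit app. To preview it standalone, a minimal wrapper (hypothetical file name) should work:

    # run_monitor.py — hypothetical standalone entry point
    from classification_model_monitor import classification_model_monitor

    classification_model_monitor()

launched with: streamlit run run_monitor.py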