ksvmuralidhar committed
Commit 4721618 · verified · 1 Parent(s): 3457983

Create classification_model_monitor.py

Files changed (1):
classification_model_monitor.py ADDED (+243, -0)
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt
from read_predictions_from_db import PredictionDBRead
from read_daily_metrics_from_db import MetricsDBRead
from sklearn.metrics import balanced_accuracy_score, accuracy_score
import logging
from config import (CLASSIFIER_ADJUSTMENT_THRESHOLD,
                    PERFORMANCE_THRESHOLD,
                    CLASSIFIER_THRESHOLD)

logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s', level=logging.INFO)


def filter_prediction_data(data: pd.DataFrame):
    """Drop classes that are not monitored and keep only confidently labeled rows."""
    try:
        logging.info("Entering filter_prediction_data()")
        if data is None:
            raise Exception("Input Prediction Data Frame is None")

        filtered_prediction_data = data.loc[(~data['y_true'].isin(['WEATHER', 'EDUCATION', 'ASTROLOGY', 'OTHERS'])) &
                                            (data['y_true_proba'] > CLASSIFIER_THRESHOLD)].copy()

        logging.info("Exiting filter_prediction_data()")
        return filtered_prediction_data
    except Exception as e:
        logging.critical(f"Error in filter_prediction_data(): {e}")
        return None


def get_adjusted_predictions(df):
    """Apply rule-based adjustments on top of the raw model predictions."""
    try:
        logging.info("Entering get_adjusted_predictions()")
        if df is None:
            raise Exception('Input Filtered Prediction Data Frame is None')
        df = df.copy()
        df.reset_index(drop=True, inplace=True)
        # Low-confidence predictions are reassigned to NATION.
        df.loc[df['y_pred_proba'] < CLASSIFIER_ADJUSTMENT_THRESHOLD, 'y_pred'] = 'NATION'
        df.loc[(df['text'].str.contains('Pakistan')) & (df['y_pred'] == 'NATION'), 'y_pred'] = 'WORLD'
        df.loc[(df['text'].str.contains('Zodiac Sign', case=False)) | (df['text'].str.contains('Horoscope', case=False)), 'y_pred'] = 'SCIENCE'
        logging.info("Exiting get_adjusted_predictions()")
        return df
    except Exception as e:
        logging.critical(f"Error in get_adjusted_predictions(): {e}")
        return None


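# A sketch of how the adjustment rules compose on a toy frame (illustrative only,
# not in the original commit; assumes CLASSIFIER_ADJUSTMENT_THRESHOLD exceeds 0.10):
#
#     _demo = pd.DataFrame({'text': ['Pakistan match report'],
#                           'y_pred': ['SPORTS'], 'y_pred_proba': [0.10]})
#     _demo = get_adjusted_predictions(_demo)
#     # y_pred becomes 'NATION' (low confidence), then 'WORLD' (text mentions Pakistan)

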
def display_kpis(data: pd.DataFrame, adj_data: pd.DataFrame):
    """Display accuracy KPIs for the raw and rule-adjusted predictions."""
    try:
        logging.info("Entering display_kpis()")
        if data is None:
            raise Exception("Input Prediction Data Frame is None")
        if adj_data is None:
            raise Exception('Input Adjusted Data Frame is None')

        n_samples = len(data)
        balanced_accuracy = np.round(balanced_accuracy_score(data['y_true'], data['y_pred']), 4)
        accuracy = np.round(accuracy_score(data['y_true'], data['y_pred']), 4)

        adj_balanced_accuracy = np.round(balanced_accuracy_score(adj_data['y_true'], adj_data['y_pred']), 4)
        adj_accuracy = np.round(accuracy_score(adj_data['y_true'], adj_data['y_pred']), 4)

        st.write('''<style>
        [data-testid="column"] {
            width: calc(33.3333% - 1rem) !important;
            flex: 1 1 calc(33.3333% - 1rem) !important;
            min-width: calc(33% - 1rem) !important;
        }
        </style>''',
                 unsafe_allow_html=True)

        col1, col2 = st.columns(2)
        with col1:
            st.metric(label="Balanced Accuracy", value=balanced_accuracy)
        with col2:
            st.metric(label="Adj Balanced Accuracy", value=adj_balanced_accuracy)

        col3, col4 = st.columns(2)
        with col3:
            st.metric(label="Accuracy", value=accuracy)
        with col4:
            st.metric(label="Adj Accuracy", value=adj_accuracy)

        col5, col6 = st.columns(2)
        with col5:
            st.metric(label="Bal Accuracy Threshold", value=PERFORMANCE_THRESHOLD)
        with col6:
            st.metric(label="N Samples", value=n_samples)
        logging.info("Exiting display_kpis()")
    except Exception as e:
        logging.critical(f'Error in display_kpis(): {e}')
        st.error("Couldn't display KPIs")


def plot_daily_metrics(metrics_df: pd.DataFrame):
    """Plot the daily mean balanced accuracy with its standard deviation as error bars."""
    try:
        logging.info("Entering plot_daily_metrics()")
        st.write(" ")
        if metrics_df is None:
            raise Exception('Input Metrics Data Frame is None')

        metrics_df['evaluation_date'] = pd.to_datetime(metrics_df['evaluation_date'])
        metrics_df['mean_score_minus_std'] = np.round(metrics_df['mean_balanced_accuracy_score'] - metrics_df['std_balanced_accuracy_score'], 4)
        metrics_df['mean_score_plus_std'] = np.round(metrics_df['mean_balanced_accuracy_score'] + metrics_df['std_balanced_accuracy_score'], 4)

        hover_data = {'mean_balanced_accuracy_score': True,
                      'std_balanced_accuracy_score': False,
                      'mean_score_minus_std': True,
                      'mean_score_plus_std': True,
                      'evaluation_window_days': True,
                      'n_splits': True,
                      'sample_start_date': True,
                      'sample_end_date': True,
                      'sample_size_of_each_split': True}

        hover_labels = {'mean_balanced_accuracy_score': "Mean Score",
                        'mean_score_minus_std': "Mean Score - Stdev",
                        'mean_score_plus_std': "Mean Score + Stdev",
                        'evaluation_window_days': "Observation Window (Days)",
                        'sample_start_date': "Observation Window Start Date",
                        'sample_end_date': "Observation Window End Date",
                        'n_splits': "N Splits For Evaluation",
                        'sample_size_of_each_split': "Sample Size of Each Split"}

        fig = px.line(data_frame=metrics_df, x='evaluation_date',
                      y='mean_balanced_accuracy_score',
                      error_y='std_balanced_accuracy_score',
                      title="Daily Balanced Accuracy",
                      color_discrete_sequence=['black'],
                      hover_data=hover_data, labels=hover_labels, markers=True)

        fig.add_hline(y=PERFORMANCE_THRESHOLD, line_dash="dash", line_color="green",
                      annotation_text="<b>THRESHOLD</b>",
                      annotation_position="left top")

        fig.update_layout(dragmode='pan', margin=dict(l=0, r=0, t=110, b=10))
        st.plotly_chart(fig, use_container_width=True)
        logging.info("Exiting plot_daily_metrics()")
    except Exception as e:
        logging.critical(f'Error in plot_daily_metrics(): {e}')
        st.error("Couldn't Plot Daily Model Metrics")


def get_misclassified_classes(data):
    """Compute the misclassification rate per predicted class and collect the misclassified rows."""
    try:
        logging.info("Entering get_misclassified_classes()")
        if data is None:
            raise Exception("Input Prediction Data Frame is None")

        data = data.copy()
        data['match'] = (data['y_true'] == data['y_pred']).astype('int')
        y_pred_counts = data['y_pred'].value_counts()

        misclassified_examples = data.loc[data['match'] == 0, ['text', 'y_true', 'y_pred', 'y_pred_proba', 'url']].copy()
        misclassified_examples.sort_values(by=['y_pred', 'y_pred_proba'], ascending=[True, False], inplace=True)

        # reindex guards against classes with no misclassifications, which would
        # otherwise raise a KeyError when indexing by y_pred_counts.index.
        misclassifications = (data.loc[data['match'] == 0, 'y_pred'].value_counts()
                              .reindex(y_pred_counts.index, fill_value=0))
        misclassifications /= y_pred_counts
        misclassifications.sort_values(ascending=False, inplace=True)
        logging.info("Exiting get_misclassified_classes()")
        return np.round(misclassifications, 2), misclassified_examples
    except Exception as e:
        logging.critical(f'Error in get_misclassified_classes(): {e}')
        return None, None


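# Worked example of the rate computation above (illustrative, not in the original
# commit): with y_pred = ['A', 'A', 'B'] against y_true = ['A', 'B', 'B'],
# y_pred_counts is {A: 2, B: 1} and the misclassified y_pred counts are {A: 1, B: 0},
# so the per-class rates come out as A: 0.5 and B: 0.0.

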
def display_misclassified_examples(misclassified_classes, misclassified_examples):
    """Plot the misclassification percentage per class and list the misclassified examples."""
    try:
        logging.info("Entering display_misclassified_examples()")
        st.write(" ")
        if misclassified_classes is None:
            raise Exception('Misclassified Classes Distribution Data Frame is None')
        if misclassified_examples is None:
            raise Exception('Misclassified Examples Data Frame is None')

        fig, ax = plt.subplots(figsize=(10, 4.5))
        misclassified_classes.plot(kind='bar', ax=ax, color='black', title="Misclassification percentage")
        plt.yticks([])
        plt.xlabel("")
        ax.bar_label(ax.containers[0])
        st.pyplot(fig)

        st.markdown("<b>Misclassified examples</b>", unsafe_allow_html=True)
        st.dataframe(misclassified_examples, hide_index=True)
        st.markdown(
            """
            <style>
            [data-testid="stElementToolbar"] {
                display: none;
            }
            </style>
            """,
            unsafe_allow_html=True
        )
        logging.info("Exiting display_misclassified_examples()")
    except Exception as e:
        logging.critical(f'Error in display_misclassified_examples(): {e}')
        st.error("Couldn't display Misclassification Data")


def classification_model_monitor():
    """Render the full classification model monitoring dashboard."""
    try:
        st.write("<h4>Classification Model Monitor</h4>", unsafe_allow_html=True)

        prediction_db = PredictionDBRead()
        metrics_db = MetricsDBRead()

        # Read prediction data from the DB
        prediction_data = prediction_db.read_predictions_from_db()
        # Filter prediction data
        filtered_prediction_data = filter_prediction_data(prediction_data)
        # Get adjusted prediction data
        adjusted_filtered_prediction_data = get_adjusted_predictions(filtered_prediction_data)
        # Display KPIs
        display_kpis(filtered_prediction_data, adjusted_filtered_prediction_data)

        # Read daily metrics from the DB
        metrics_df = metrics_db.read_metrics_from_db()
        # Display daily metrics line plot
        plot_daily_metrics(metrics_df)

        # Get misclassified class distribution and misclassified examples from prediction data
        misclassified_classes, misclassified_examples = get_misclassified_classes(filtered_prediction_data)
        # Display misclassification data
        display_misclassified_examples(misclassified_classes, misclassified_examples)

        st.markdown(
            """<style>
            [data-testid="stMetricValue"] {
                font-size: 25px;
            }
            </style>
            """, unsafe_allow_html=True
        )

    except Exception as e:
        logging.critical(f"Error in classification_model_monitor(): {e}")
        st.error("Unexpected Error. Couldn't display Classification Model Monitor")

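
# A minimal entry point for running this page standalone, sketched as an
# assumption (not part of the original commit); it presumes the module is
# launched with `streamlit run classification_model_monitor.py`.
if __name__ == "__main__":
    classification_model_monitor()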