#####################################
# Packages & Dependencies
#####################################
import param
import panel as pn
import torch
import numpy as np
import plotly.graph_objects as go

from . import canvas
from app_utils import styles

import sys
import os

APP_PATH = os.path.dirname(os.path.dirname(__file__))   # Path to the digit-classifier-app directory
# os.path.join keeps the entry relative when APP_PATH is empty; plain string
# concatenation would turn it into the filesystem-root path '/model_training'.
sys.path.append(os.path.join(APP_PATH, 'model_training'))

# Imports from model_training (resolvable only after the sys.path update above)
import data_setup, model

#####################################
# Plotly Panels
#####################################
PLOTLY_CONFIGS = {
    'displayModeBar': True,
    'displaylogo': False,
    'modeBarButtonsToRemove': ['autoScale', 'lasso', 'select', 'toImage', 'pan', 'zoom', 'zoomIn', 'zoomOut']
}


class PlotPanels(param.Parameterized):
    '''
    Contains all Plotly pane objects for the application. This includes the
    probability bar chart and the MNIST preprocessed image heat map, plus an
    HTML pane with the predicted label and its probability.

    Args:
        canvas_info (param.ClassSelector): A Canvas class object to get the data URI of the drawn image.
        mod_path (str): The absolute path to the saved TinyVGG model.
        mod_kwargs (dict): A dictionary containing the keyword-arguments for the TinyVGG model.
            This should have the keys: num_blks, num_convs, in_channels, hidden_channels, and num_classes
    '''
    canvas_info = param.ClassSelector(class_ = canvas.Canvas)   # Canvas object to get the data URI

    def __init__(self, mod_path: str, mod_kwargs: dict, **params):
        super().__init__(**params)
        self.class_labels = np.arange(0, 10)

        # NOTE(review): torch.load unpickles the file, so mod_path must point
        # to a trusted checkpoint.
        self.cnn_mod = model.TinyVGG(**mod_kwargs)
        self.cnn_mod.load_state_dict(torch.load(mod_path, map_location = 'cpu'))

        self.img_pane = pn.pane.Plotly(
            name = 'image_plot',
            config = PLOTLY_CONFIGS,
            sizing_mode = 'stretch_both',
            margin = 0,
        )
        self.prob_pane = pn.pane.Plotly(
            name = 'prob_plot',
            config = PLOTLY_CONFIGS,
            sizing_mode = 'stretch_both',
            margin = 0
        )
        self.pred_txt = pn.pane.HTML(
            styles = {'margin':'0rem',
                      'color':styles.CLRS['pred_txt'],
                      'font-size':styles.FONTSIZES['pred_txt'],
                      'font-family':styles.FONTFAMILY}
        )

        # Initialize plotly figures
        self._update_prediction()

        # Set up watchers that update based on data URI changes
        self.canvas_info.param.watch(self._update_prediction, 'uri')

    def _update_prediction(self, *event):
        '''
        Performs all prediction-related updates for the application. This function is
        connected to the URI parameter of canvas_info through a watcher. Any time the
        URI changes, a class prediction is made immediately. Following this, the
        prediction text, model input heatmap, and probability bar chart are updated.

        Args:
            *event: Watcher event(s) passed by param; they are not inspected.
        '''
        # The tensor/prediction state must be refreshed first; the remaining
        # updates only read the attributes it sets.
        self._update_preprocessed_tensor()
        self._update_pred_txt()
        self._update_img_plot()
        self._update_prob_plot()

    def _update_preprocessed_tensor(self):
        '''
        Transforms the data URI (string) from canvas_info into a preprocessed tensor.
        This is done by having it undergo the MNIST preprocessing pipeline (see
        mnist_preprocess in data_setup for details). Additionally, a prediction is made
        for the preprocessed tensor to get its class label. The corresponding set of
        prediction probabilities is stored.

        Sets self.input_img, self.pred_probs and self.pred_label. When the URI is
        empty, these fall back to a zero image, zero probabilities and None.
        '''
        # Check if uri is non-empty
        if self.canvas_info.uri:
            self.input_img = data_setup.mnist_preprocess(self.canvas_info.uri)

            self.cnn_mod.eval()     # Set CNN to eval & inference mode
            with torch.inference_mode():
                # unsqueeze(0) adds the leading batch dimension for the model
                pred_logits = self.cnn_mod(self.input_img.unsqueeze(0))
                self.pred_probs = torch.softmax(pred_logits, dim = 1)[0].numpy()

            # Plain int so the label can be used directly as a list index
            self.pred_label = int(np.argmax(self.pred_probs))
        else:
            self.input_img = torch.zeros((28, 28))
            self.pred_probs = np.zeros(len(self.class_labels))
            self.pred_label = None

    def _update_pred_txt(self):
        '''
        Updates the prediction and probability HTML text to reflect the current
        prediction state. 'N/A' is shown for both when there is no prediction.
        '''
        # Same "no prediction" test as the plot updaters (pred_label is None
        # exactly when the URI was empty during the last tensor update).
        if self.pred_label is not None:
            pred = self.pred_label
            prob = f'{self.pred_probs[self.pred_label]:.3f}'
        else:
            pred, prob = 'N/A', 'N/A'

        # HTML collapses plain newlines, so an explicit <br> separates the lines
        self.pred_txt.object = (
            f'Prediction: {pred}<br>'
            f'Probability: {prob}'
        )

    def _update_prob_plot(self):
        '''
        Updates the probability bar chart to showcase the softmax output probability
        distribution obtained from the prediction in _update_preprocessed_tensor.
        The bar of the predicted class (if any) is drawn in its own colors.
        '''
        # Marker fill and outline color for bar plot
        mkr_clrs = [styles.CLRS['base_bar']] * len(self.class_labels)
        mkr_line_clrs = [styles.CLRS['base_bar_line']] * len(self.class_labels)

        if self.pred_label is not None:
            mkr_clrs[self.pred_label] = styles.CLRS['pred_bar']
            mkr_line_clrs[self.pred_label] = styles.CLRS['pred_bar_line']

        fig = go.Figure()

        # Bar plot
        fig.add_trace(
            go.Bar(
                x = self.class_labels,
                y = self.pred_probs,
                marker_color = mkr_clrs,
                marker_line_color = mkr_line_clrs,
                marker_line_width = 1.5,
                showlegend = False,
                text = self.pred_probs,
                textposition = 'outside',
                textfont = dict(color = styles.CLRS['plot_txt'],
                                size = styles.FONTSIZES['plot_bar_txt'],
                                family = styles.FONTFAMILY),
                texttemplate = '%{text:.3f}',
                customdata = self.pred_probs * 100,
                hoverlabel_font = dict(family = styles.FONTFAMILY),
                # <br> is Plotly's line break; <extra></extra> hides the
                # secondary hover box that would show the trace name.
                hovertemplate = 'Class Label: %{x}' +
                                '<br>Probability: %{customdata:.2f} %' +
                                '<extra></extra>'
            )
        )

        # Used to fix axis limits (fully transparent markers, no hover/legend)
        fig.add_trace(
            go.Scatter(
                x = [0.5, 0.5],
                y = [0.1, 1],
                marker = dict(color = 'rgba(0, 0, 0, 0)', size = 10),
                mode = 'markers',
                hoverinfo = 'skip',
                showlegend = False
            )
        )

        fig.update_yaxes(
            title = dict(text = 'Prediction Probability',
                         standoff = 0,
                         font = dict(color = styles.CLRS['plot_txt'],
                                     size = styles.FONTSIZES['plot_labels'],
                                     family = styles.FONTFAMILY)),
            tickfont = dict(size = styles.FONTSIZES['plot_ticks'], family = styles.FONTFAMILY),
            dtick = 0.1,
            ticks = 'outside',
            ticklen = 0,
            gridcolor = styles.CLRS['prob_plot_grid']
        )
        fig.update_xaxes(
            title = dict(text = 'Class Label',
                         standoff = 6,
                         font = dict(color = styles.CLRS['plot_txt'],
                                     size = styles.FONTSIZES['plot_labels'],
                                     family = styles.FONTFAMILY)),
            dtick = 1,
            tickfont = dict(size = styles.FONTSIZES['plot_ticks'], family = styles.FONTFAMILY),
        )
        fig.update_layout(
            paper_bgcolor = styles.CLRS['prob_plot_bg'],
            plot_bgcolor = styles.CLRS['prob_plot_bg'],
            margin = dict(l = 60, r = 0, t = 5, b = 45),
        )

        self.prob_pane.object = fig

    def _update_img_plot(self):
        '''
        Updates the heat map to showcase the current model input, i.e. the
        preprocessed canvas drawing.
        '''
        img_np = self.input_img.squeeze().numpy()

        # Scale colors to the image's own range when there is a prediction;
        # otherwise use a fixed 0-1 range for the all-zero placeholder image.
        if self.pred_label is not None:
            zmin, zmax = np.min(img_np), np.max(img_np)
        else:
            zmin, zmax = 0, 1

        fig = go.Figure(
            data = go.Heatmap(
                z = img_np,
                colorscale = 'gray',
                showscale = False,
                zmin = zmin,
                zmax = zmax,
                hoverlabel_font = dict(family = styles.FONTFAMILY),
                # <br> is Plotly's line break; <extra></extra> hides the
                # secondary hover box that would show the trace name.
                hovertemplate = 'Pixel Position: (%{x}, %{y})' +
                                '<br>Pixel Value: %{z:.3f}' +
                                '<extra></extra>'
            )
        )

        # Reverse the y-axis so row 0 of the image is drawn at the top
        fig.update_yaxes(autorange = 'reversed')
        fig.update_layout(
            plot_bgcolor = styles.CLRS['img_plot_bg'],
            margin = dict(l = 0, r = 0, t = 0, b = 0),
            xaxis = dict(showticklabels = False),
            yaxis = dict(showticklabels = False),
        )

        self.img_pane.object = fig