#####################################
# Packages & Dependencies
#####################################
import os

import panel as pn
import yaml
from panel.viewable import Viewer

from app_components import canvas, plots
from app_utils import styles

pn.extension('plotly')

# Absolute directory containing this script. Asset and model paths are resolved
# against it so they do not depend on the current working directory.
FILE_PATH = os.path.dirname(os.path.abspath(__file__))


################################################
# Digit Classifier Layout
################################################
class DigitClassifier(Viewer):
    '''
    Builds and displays the UI for the classifier application.

    Args:
        mod_path (str): The absolute path to the saved TinyVGG model.
        mod_kwargs (dict): A dictionary containing the keyword-arguments for the TinyVGG model.
                           This should have the keys: num_blks, num_convs, in_channels,
                           hidden_channels, and num_classes.
        **params: Additional parameters forwarded to ``panel.viewable.Viewer``.
    '''

    def __init__(self, mod_path: str, mod_kwargs: dict, **params):
        # Initialize the Viewer (a param.Parameterized) before assigning any
        # instance attributes, following Panel's documented Viewer pattern.
        super().__init__(**params)

        # Drawing surface and the button that clears it
        self.canvas = canvas.Canvas(
            sizing_mode='stretch_both',
            styles={'border': 'black solid 0.15rem'},
        )
        self.clear_btn = pn.widgets.Button(
            name='Clear',
            sizing_mode='stretch_width',
            stylesheets=[styles.BTN_STYLESHEET],
        )

        # Model-driven panes (prediction text, input image, probability chart).
        # PlotPanels receives the canvas; presumably it listens to drawing
        # updates to refresh predictions — confirm in app_components.plots.
        self.plot_panels = plots.PlotPanels(
            canvas_info=self.canvas,
            mod_path=mod_path,
            mod_kwargs=mod_kwargs,
        )

        self.github_logo = pn.pane.PNG(
            object=os.path.join(FILE_PATH, 'assets', 'github-mark-white.png'),
            alt_text='GitHub Repo',
            link_url='https://github.com/Jechen00/digit-classifier-app',
            height=70,
            styles={'margin': '0'},
        )

        # Middle column: logo, clear button and the prediction text, stacked vertically
        self.controls_col = pn.FlexBox(
            self.github_logo,
            self.clear_btn,
            self.plot_panels.pred_txt,
            gap='60px',
            flex_direction='column',
            justify_content='center',
            align_items='center',
            flex_wrap='nowrap',
            styles={'width': '40%', 'height': '100%'},
        )

        # Caption drawn over the model-input image; absolute positioning and a
        # high z-index keep it on top of the image pane.
        self.mod_input_txt = pn.pane.HTML(
            object='''
MODEL INPUT
''',
            styles={
                'margin': '0rem',
                'padding-left': '0.15rem',
                'color': 'white',
                'font-size': styles.FONTSIZES['mod_input_txt'],
                'font-family': styles.FONTFAMILY,
                'position': 'absolute',
                'z-index': '100',
            },
        )

        # Top row: canvas | controls | transformed model-input image
        self.img_row = pn.FlexBox(
            self.canvas,
            self.controls_col,
            pn.FlexBox(
                self.mod_input_txt,
                self.plot_panels.img_pane,
                sizing_mode='stretch_both',
                styles={'border': 'solid 0.15rem white'},
            ),
            gap='1%',
            flex_wrap='nowrap',
            flex_direction='row',
            justify_content='center',
            sizing_mode='stretch_width',
            styles={'height': '60%'},
        )

        # Bottom row: prediction-probability bar chart
        self.prob_row = pn.FlexBox(
            self.plot_panels.prob_pane,
            sizing_mode='stretch_width',
            styles={'height': '40%', 'border': 'solid 0.15rem black'},
        )

        # Left sidebar with the description and usage notes
        self.page_info = pn.pane.HTML(
            object='''
Digit Classifier

This is a handwritten digit classifier that uses a convolutional neural network (CNN) to make predictions. The architecture of the model is a scaled-down version of the Visual Geometry Group (VGG) architecture from the paper: Very Deep Convolutional Networks for Large-Scale Image Recognition.


How To Use: Draw a digit (0-9) on the canvas and the model will produce a prediction for it in real time. Prediction probabilities (or confidences) for each digit are displayed in the bar chart, reflecting the model's softmax output distribution. To the right of the canvas, you'll also find the transformed input image, i.e. the canvas drawing after undergoing MNIST preprocessing. This input image represents what the model receives prior to feature extraction and classification.


Note: Due to resource limitations on HF Spaces (CPU basic), performance may vary. For optimal experience, it's recommended to run the app locally.

''',
            styles={
                'margin': ' 0rem',
                'color': styles.CLRS['sidebar_txt'],
                'width': '19.7%',
                'height': '100%',
                'font-family': styles.FONTFAMILY,
                'background-color': styles.CLRS['sidebar'],
                'overflow-y': 'scroll',
                'border': 'solid 0.15rem black',
            },
        )

        # Right-hand content: image row stacked above the probability row
        self.classifier_content = pn.FlexBox(
            self.img_row,
            self.prob_row,
            gap='0.5%',
            flex_direction='column',
            flex_wrap='nowrap',
            sizing_mode='stretch_height',
            styles={'width': '80%'},
        )

        # Sidebar + classifier content, clamped between min/max dimensions
        self.page_content = pn.FlexBox(
            self.page_info,
            self.classifier_content,
            gap='0.3%',
            flex_direction='row',
            justify_content='space-around',
            align_items='center',
            flex_wrap='nowrap',
            styles={
                'height': '100%',
                'width': '100vw',
                'padding': '1%',
                'min-width': '1200px',
                'min-height': '600px',
                'max-width': '3600px',
                'max-height': '1800px',
                'background-color': styles.CLRS['page_bg'],
            },
        )

        # Outer wrapper: stretches to fill the available space and paints it with
        # the page background, so the grey background is always present even
        # when page_content is clamped by its min/max size limits.
        self.page_layout = pn.FlexBox(
            self.page_content,
            justify_content='center',
            flex_wrap='nowrap',
            sizing_mode='stretch_both',
            styles={
                'min-width': 'max-content',
                'background-color': styles.CLRS['page_bg'],
            },
        )

        # Clicking "Clear" resets the drawing canvas
        self.clear_btn.on_click(self.canvas.toggle_clear)

    def __panel__(self):
        '''
        Returns the main layout of the application to be rendered by Panel.
        '''
        return self.page_layout


def create_app() -> DigitClassifier:
    '''
    Creates the application, ensuring that each user gets a different instance of
    DigitClassifier. Wrapping construction in a function keeps it out of the global scope.

    The model weights and the TinyVGG keyword-arguments are read from the
    ``saved_models`` directory next to this script.

    Returns:
        DigitClassifier: A freshly built application instance.

    Raises:
        FileNotFoundError: If the settings YAML file does not exist.
        KeyError: If the settings file has no ``mod_kwargs`` entry.
    '''
    save_dir = os.path.join(FILE_PATH, 'saved_models')
    base_name = 'tiny_vgg_less_compute'
    mod_path = os.path.join(save_dir, f'{base_name}_model.pth')  # Saved model state dict
    settings_path = os.path.join(save_dir, f'{base_name}_settings.yaml')  # Saved model kwargs

    # Load in the model kwargs.
    # NOTE: FullLoader is kept for compatibility with the existing settings file,
    # which ships with the app under saved_models/. Prefer yaml.safe_load if that
    # file only ever contains plain YAML types.
    with open(settings_path, 'r', encoding='utf-8') as f:
        loaded_settings = yaml.load(f, Loader=yaml.FullLoader)
    mod_kwargs = loaded_settings['mod_kwargs']

    return DigitClassifier(mod_path=mod_path, mod_kwargs=mod_kwargs)


################################################
# Serve App
################################################
# Used when serving with `panel serve` on the command line
create_app().servable(title='CNN Digit Classifier')