|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import re |
|
import uuid |
|
import time |
|
import shutil |
|
import zipfile |
|
import threading |
|
import subprocess |
|
import select |
|
from datetime import datetime |
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
import dash |
|
from dash import dcc, html |
|
import dash_daq as daq |
|
from dash.dependencies import Input, Output, State, ALL |
|
import dash_bootstrap_components as dbc |
|
from dash.exceptions import PreventUpdate |
|
import dash_daq as daq |
|
from flask import Flask, render_template, request, send_file, jsonify, abort |
|
import plotly.graph_objects as go |
|
import plotly.colors as pc |
|
|
|
import yaml |
|
import ruamel.yaml |
|
import pandas as pd |
|
|
|
import logging |
|
import base64 |
|
|
|
# Configure root logging at DEBUG level so debug messages from this module
# (and from any library that logs) are emitted.
logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger(__name__)

# Flask server: serves the '/' welcome page and is passed to the Dash app
# as its underlying server.
server = Flask(__name__)

# Random key generated at import time: it differs between processes and
# restarts, so anything Flask signs with it does not survive a restart.
server.secret_key = os.urandom(24)
|
|
|
|
|
@server.route('/')
def welcome_page():
    """
    Render the welcome page.

    The user name is taken from the request host: the part before the first
    ``-stm32`` (e.g. ``<user>-stm32...``); when the host does not match,
    ``"modelzoo_user"`` is used. Duplicate mode is enabled only when that
    user is ``stmicroelectronics``.

    Returns:
        str: The rendered 'index.html' template with the ``duplicate_mode`` flag.
    """
    host = request.host
    usr_match = re.match(r'^(.*?)-stm32', host)
    hf_user = usr_match.group(1) if usr_match else "modelzoo_user"
    duplicate_mode = hf_user == "stmicroelectronics"

    # Diagnostics go through the module logger (not print) so they follow
    # the logging configuration of the application.
    logger.debug("host: %s, usr_match: %s", host, usr_match)
    logger.debug("hf_user: %s, duplicate_mode: %s", hf_user, duplicate_mode)

    return render_template('index.html', duplicate_mode=duplicate_mode)
|
|
|
|
|
# Bootstrap theme applied to every Dash page.
external_stylesheets = [dbc.themes.LITERA]

# Dash application mounted on the Flask server under /dash_app/.
# Callback exceptions are suppressed because several callback targets
# (e.g. the YAML form) are created dynamically after the initial layout.
app = dash.Dash(
    __name__,
    server=server,
    external_stylesheets=external_stylesheets,
    url_base_pathname='/dash_app/',
    suppress_callback_exceptions=True,
)
|
|
|
# Default user_config.yaml of every supported use case, keyed by the value
# used in the 'selected-model' dropdown. Paths are relative to the current
# working directory.
local_yamls = {
    use_case: f'stm32ai-modelzoo-services/{use_case}/src/user_config.yaml'
    for use_case in (
        'image_classification',
        'human_activity_recognition',
        'hand_posture',
        'object_detection',
        'audio_event_detection',
        'pose_estimation',
        'semantic_segmentation',
    )
}
|
|
|
def banner():
    """
    Build the fixed top bar of the dashboard.

    Left: a link to the stm32ai-modelzoo-services GitHub repository.
    Right: the log panel toggle button (id ``toggle-log``) followed by a
    link to ST Edge AI Developer Cloud.

    Returns:
        html.Div: The banner component (id ``banner``).
    """
    bar_style = {
        "display": "flex",
        "align-items": "center",
        "justify-content": "space-between",
        "position": "fixed",
        "top": "0",
        "left": "0",
        "width": "100%",
        "z-index": "1000",
        "background": "linear-gradient(to right, #03234b, #054080)",
        "box-shadow": "0px 4px 8px rgba(0, 0, 0, 0.2)",
        "border-radius": "0 0 10px 10px"
    }

    github_logo = html.Img(
        src=app.get_asset_url("github-mark-white.png"),
        style={"width": "22px", "height": "22px", "margin-right": "8px"}
    )
    github_link = html.A(
        id="learn-more-button",
        children=[
            github_logo,
            html.Span("stm32ai-modelzoo", style={"font-weight": "bold"})
        ],
        href="https://github.com/STMicroelectronics/stm32ai-modelzoo-services",
        target="_blank",
        style={
            "display": "flex",
            "align-items": "center",
            "color": "#ffffff",
            "text-decoration": "none",
            "font-size": "16px",
            "font-family": "Arial, sans-serif",
            "transition": "color 0.3s ease"
        }
    )

    log_icon = html.Img(
        src=app.get_asset_url("logs.jpg"),
        style={"width": "22px", "height": "22px", "margin-right": "10px"}
    )
    log_toggle = dbc.Button(
        log_icon,
        id="toggle-log",
        n_clicks=0,
        className="",
        style={
            "background": "none",
            "border": "none",
            "padding": "0",
            "margin-right": "12px",
            "cursor": "pointer"
        }
    )

    devcloud_title = html.H5(
        "ST Edge AI Developer Cloud",
        style={
            "margin": "0",
            "color": "#ffffff",
            "font-size": "16px",
            "font-weight": "bold",
            "font-family": "Arial, sans-serif",
            "transition": "color 0.3s ease"
        }
    )
    devcloud_link = html.A(
        devcloud_title,
        href="https://stm32ai-cs.st.com/home",
        target="_blank",
        style={
            "display": "flex",
            "align-items": "center",
            "text-decoration": "none"
        }
    )

    right_group = html.Div(
        [log_toggle, devcloud_link],
        style={"display": "flex", "align-items": "center"}
    )

    return html.Div(
        id="banner",
        className="top-bar",
        style=bar_style,
        children=[github_link, right_group]
    )
|
|
|
def create_dashboard_layout():
    """
    Build the page layout of the STM32 Model Zoo dashboard.

    The page contains the top banner, the use-case dropdown, the YAML
    configuration area (help panel, choice radio items, hidden upload
    control, generated form container), the ST Edge AI Developer Cloud
    credentials inputs with the 'Launch training' button, the results area
    (metric graphs, memory and inference-time charts, confusion matrix), a
    polling interval, the download button and a bottom log panel. Several
    sections start hidden (``display: none``) and are toggled by callbacks.

    Returns:
        html.Div: The root component of the dashboard layout.
    """
    return html.Div([
        banner(),
        dbc.Container([
            dcc.Location(id='url', refresh=False),
            dbc.Row(dbc.Col(html.H3("STM32 Model zoo Dashboard", style={'color': '#03234b', 'text-align': 'center', "margin-top": "80px", "font-family": "Arial, sans-serif"}), className="mb-4")),
            dbc.Row([
                dbc.Col(
                    html.H5("Use case selection", style={'color': '#03234b', 'margin-bottom': '10px'}),
                    width=12
                )
            ], id="use-case-section", style={"display": "none"}),
            dbc.Row(dbc.Col(dcc.Dropdown(
                id='selected-model',
                options=[
                    {'label': 'Image Classification (IC)', 'value': 'image_classification'},
                    {'label': 'Human Activity Recognition (HAR)', 'value': 'human_activity_recognition'},
                    {'label': 'Hand Posture', 'value': 'hand_posture'},
                    {'label': 'Audio Event Detection(AED)', 'value': 'audio_event_detection'},
                    {'label': 'Object Detection', 'value': 'object_detection'},
                    {'label': 'Pose estimation', 'value': 'pose_estimation'},
                    {'label': 'Semantic Segmentation', 'value': 'semantic_segmentation'},
                ],
                placeholder="Please select your use case",
                className="mb-4"
            ))),
            dbc.Row(
                dbc.Col(
                    html.Div(
                        id='toggle-yaml',
                        children=[
                            dbc.Button("How to update User Config ", id="open-offcanvas", n_clicks=0),
                            dbc.Offcanvas(
                                html.Div(
                                    [
                                        html.P([
                                            html.Strong("Configure Dataset section:"),
                                            html.Br(),
                                            "- Dataset path: ../datasets/your_use_case/name_of_dataset or datasets/your_prepared_dataset.",
                                            html.Br(),
                                            html.Br(),
                                            "- For more details, refer to the ",
                                            # Host fixed from "huggingface.co." (stray trailing dot).
                                            html.A("README", href="https://huggingface.co/spaces/STMicroelectronics/stm32-modelzoo-app/blob/main/datasets/README.md", target="_blank", style={'color': '#007bff', 'text-decoration': 'underline'}),
                                            ".",
                                            html.Br(),
                                            html.Br(),
                                            "- If you need to upload your model for evaluation, benchmarking or quantizing:",
                                            html.Br(),
                                            "- Update model path under General section: models/your_model_name",
                                            html.Br(),
                                            "- For more details, refer to the ",
                                            html.A("README", href="https://huggingface.co/spaces/STMicroelectronics/stm32-modelzoo-app/blob/main/models/README.md", target="_blank", style={'color': '#007bff', 'text-decoration': 'underline'}),
                                            "."
                                        ], style={'font-family': 'Arial, sans-serif', 'color': '#03234b', 'fontSize': '18px'})
                                    ]
                                ),
                                id="offcanvas",
                                is_open=False,
                                title="📚 Help",
                                placement="end",
                            ),
                            dcc.RadioItems(
                                id='modify-yaml-choice',
                                labelStyle={'display': 'inline-block', 'margin-right': '10px'},
                                className="mb-4",
                            ),
                            dcc.Upload(
                                id='load-yaml-file',
                                children=html.Button('Upload YAML File'),
                                style={'display': 'none'}
                            ),
                            html.Div(id='load-state', style={'margin-top': '10px'}),
                            # Filled with the generated YAML form by a callback.
                            html.Div(id='yaml-layout', style={'display': 'none'})
                        ],
                        style={'font-family': 'Arial, sans-serif'}
                    )
                )
            ),
            dbc.Row([
                dbc.Col([
                    # Color needs the leading '#': '03234b' is not a valid CSS color value.
                    html.P("Enter your ST Edge AI Developer Cloud credentials:", style={'color': '#03234b', 'fontSize': '15px', 'fontWeight': 'bold'}, className="credentials-text"),
                    dcc.Input(id='devcloud-username-input', type='text', placeholder='Enter username', className="input-field mb-2"),
                    dcc.Input(id='devcloud-password-input', type='password', placeholder='Enter password', className="input-field mb-4")
                ], width=6),
                dbc.Col([
                    dbc.Button('Launch training', id='process-button', color="#ceecf9", className="start-button mb-4", style={'display': 'none', 'box-shadow': '0px 4px 6px rgba(0, 0, 0, 0.1)'})
                ], className="credentials-col")
            ], id='credentials-section', style={
                'display': 'none',
                'justify-content': 'center',
                'align-items': 'center',
                'height': '100vh',
            }, className="credentials-section mb-4"),
            dbc.Row([
                dbc.Col(
                    html.H5("Results visualization", style={'color': '#03234b', 'margin-bottom': '10px'}),
                    width=12
                )
            ], id="results-section", style={"display": "none"}),
            dbc.Row([
                dbc.Col(dbc.Card([
                    dbc.CardHeader("Metrics", style={'background-color': '#03234b', 'color': 'white'}),
                    dbc.CardBody(
                        dcc.Graph(id='acc-visualization', style={'height': '100%', 'width': '100%'}),
                        style={'height': '400px', 'display': 'flex', 'justify-content': 'center', 'align-items': 'center'}
                    )
                ]), width=6, style={'padding': '10px'}),
                dbc.Col(dbc.Card([
                    dbc.CardHeader("Metrics", style={'background-color': '#03234b', 'color': 'white'}),
                    dbc.CardBody(
                        dcc.Graph(id='loss-visualization', style={'height': '100%', 'width': '100%'}),
                        style={'height': '400px', 'display': 'flex', 'justify-content': 'center', 'align-items': 'center'}
                    )
                ]), width=6, style={'padding': '10px'})
            ], style={'margin-bottom': '30px'}),
            dbc.Row([
                dbc.Col(dbc.Card([
                    dbc.CardHeader("Memory Usage", style={'background-color': '#8191a5', 'color': 'white', 'font-size': '20px'}),
                    dbc.CardBody(dcc.Graph(id='memory-bar'))
                ]), width=4),
                dbc.Col(dbc.Card([
                    dbc.CardHeader("Inference Time", style={'background-color': '#8191a5', 'color': 'white', 'font-size': '20px'}),
                    dbc.CardBody(dcc.Graph(id='inference-time'))
                ]), width=4),
            ], justify="center"),
            dbc.Row([
                html.Div(id='metric-graphs-container', style={
                    'margin-bottom': '30px'
                })
            ]),
            # Fires every second; drives the periodic log/metrics refresh.
            dcc.Interval(id='interval-widget', interval=1000, n_intervals=0),
            dcc.Download(id="download-resource"),
            dbc.Row(
                dbc.Col(
                    dbc.Button('Download outputs', id='download-action', className="mb-4", style={
                        'background-color': '#ffd200',
                        'color': '#ffffff',
                        'font-size': '14px',
                        'padding': '10px 10px',
                        'border-radius': '5px',
                        'box-shadow': '0px 4px 6px rgba(0, 0, 0, 0.1)',
                        'margin-top': '20px'
                    }),
                    style={
                        'display': 'flex',
                        'justify-content': 'center',
                        'alignItems': 'center',
                    }
                )
            ),
            dbc.Row([
                dbc.Col(dbc.Card([
                    dbc.CardHeader("Confusion Matrix", style={'background-color': '#8191a5', 'color': 'white', 'font-size': '20px'}),
                    dbc.CardBody(
                        html.Div(
                            html.Img(id='confusion-matrix-img', style={'max-width': '100%', 'height': 'auto'}),
                            style={'display': 'flex', 'justify-content': 'center', 'align-items': 'center', 'height': '100%'}
                        )
                    )
                ]), width=12)
            ], justify="center")
        ], fluid=True),
        # Bottom panel holding the streamed logs; opened with the banner's log button.
        dbc.Offcanvas(
            html.Div(id='log-reader', style={'whiteSpace': 'pre-wrap', 'padding': '15px', 'height': '200px', 'overflow': 'auto'}),
            id="log-offcanvas",
            is_open=False,
            placement="bottom",
            style={'height': '200px', 'background-color': '#343a40', 'color': 'white', 'resize': 'vertical', 'overflow': 'auto'}
        ),
    ])
|
|
|
def read_configs(selected_model):
    """
    Load the default user_config.yaml of the selected use case.

    Args:
        selected_model (str): Use-case key of ``local_yamls``.

    Returns:
        The result of ``yaml.safe_load`` on the file (normally a dict;
        ``None`` for an empty file).

    Raises:
        ValueError: If no use case is given, the use case is unknown, or the
            file cannot be opened or parsed. The original error is chained.
    """
    if not selected_model:
        raise ValueError("No model selected. Please select a valid model.")
    if selected_model not in local_yamls:
        raise ValueError(f"Model '{selected_model}' not found in local_yamls")

    yaml_path = local_yamls[selected_model]
    try:
        # YAML files are UTF-8; don't depend on the platform's locale encoding.
        with open(yaml_path, 'r', encoding='utf-8') as file:
            return yaml.safe_load(file)
    except Exception as e:
        # Callers handle ValueError; keep the root cause in the chain.
        raise ValueError(f"Error reading YAML file at {yaml_path}: {e}") from e
|
|
|
|
|
def build_yaml_form(yaml_content, parent_key=''):
    """
    Recursively build accordion items mirroring a parsed YAML mapping.

    Nested mappings become nested accordion items. Each leaf value becomes
    an input whose id is ``{'type': 'yaml-setting', 'index': <dotted key>}``:
    booleans -> single-checkbox ``dcc.Checklist``, lists -> multi-select
    ``dcc.Dropdown``, anything else -> text ``dcc.Input``. The top-level
    sections 'tools', 'deployment', 'mlflow' and 'hydra' are skipped.

    Parameters:
    - yaml_content (dict | None): The YAML content. ``None`` or an empty
      mapping yields no items.
    - parent_key (str): Dotted path of the enclosing mapping. Default is an
      empty string (top level).

    Returns:
    - list: ``dbc.AccordionItem`` components representing the form fields.
    """
    # An empty YAML file/section is parsed as None: nothing to render.
    if not yaml_content:
        return []

    hidden_sections = {'tools', 'deployment', 'mlflow', 'hydra'}
    accordion_items = []

    for key, value in yaml_content.items():
        if key in hidden_sections and parent_key == '':
            continue

        full_key = f"{parent_key}.{key}" if parent_key else key
        # YAML keys are not always strings (e.g. integers or booleans);
        # str() keeps .capitalize() from raising on them.
        key_text = str(key)

        if isinstance(value, dict):
            if full_key == "dataset":
                section_title = html.Span([
                    "Dataset ",
                    html.Span("*", style={"color": "red", "fontWeight": "bold"}),
                    html.Span(" (Set dataset path)", style={"fontSize": "0.85rem", "color": "#dc3545", "marginLeft": "5px"})
                ])
            else:
                section_title = key_text.capitalize()

            nested_accordion = build_yaml_form(value, full_key)
            accordion_items.append(
                dbc.AccordionItem(
                    nested_accordion,
                    title=section_title
                )
            )
            continue

        field = [html.Label(key_text, style={"font-weight": "bold", "margin-bottom": "5px"})]

        # bool is tested first: it must not fall into the generic input branch.
        if isinstance(value, bool):
            field.append(
                dcc.Checklist(
                    id={'type': 'yaml-setting', 'index': full_key},
                    options=[{'label': '', 'value': True}],
                    value=[True] if value else [],
                    style={"padding": "10px", "border": "1px solid #ddd", "margin-bottom": "10px"}
                )
            )
        elif isinstance(value, list):
            field.append(
                dcc.Dropdown(
                    id={'type': 'yaml-setting', 'index': full_key},
                    options=[{'label': str(v), 'value': v} for v in value],
                    value=value,
                    multi=True,
                    style={"padding": "10px", "border": "1px solid #ddd", "margin-bottom": "10px"}
                )
            )
        else:
            input_style = {
                "padding": "10px",
                "border": "1px solid #ddd",
                "margin-bottom": "10px",
                "width": "100%"
            }
            helper = None

            # Highlight the training dataset path, which users must update.
            if full_key == "dataset.training_path":
                input_style.update({
                    "border": "2px solid #ffc107",
                    "backgroundColor": "#fff8e1"
                })
                helper = html.Div(
                    "⚠️ Please update dataset path.",
                    style={
                        "color": "#856404",
                        "fontSize": "0.85rem",
                        "marginTop": "-8px",
                        "marginBottom": "10px"
                    }
                )

            field.append(
                dcc.Input(
                    id={'type': 'yaml-setting', 'index': full_key},
                    value=value,
                    type='text',
                    style=input_style
                )
            )
            if helper:
                field.append(helper)

        accordion_items.append(
            dbc.AccordionItem(
                field,
                title=key_text.capitalize()
            )
        )

    return accordion_items
|
|
|
def create_yaml(yaml_content):
    """
    Wrap the generated configuration fields into a submittable form.

    Parameters:
        yaml_content (dict): Parsed YAML configuration passed to build_yaml_form.

    Returns:
        dbc.Form: The form contains three parts:
            - an accordion (collapsed by default) of the configuration fields;
            - a centred 'Submit' button (id ``apply-button``);
            - an empty status line (id ``submission-outcome``).
    """
    items = build_yaml_form(yaml_content)
    fields_accordion = dbc.Accordion(items, start_collapsed=True)

    submit_button_style = {
        'background-color': '#FFD200',
        'color': '#03234b',
        'font-size': '14px',
        'padding': '10px 10px 10px 10px',
        'border-radius': '5px',
        'margin-top': '15px',
        'border': '2px solid #FFD200',
        'box-shadow': '0px 4px 6px rgba(0, 0, 0, 0.1)',
    }
    submit_row = html.Div(
        dbc.Button('Submit', id='apply-button', style=submit_button_style),
        style={
            'display': 'flex',
            'justify-content': 'center',
            'margin-top': '15px',
        }
    )

    outcome_line = html.Div(
        id='submission-outcome',
        style={
            'marginTop': '10px',
            'textAlign': 'center',
            'fontStyle': 'italic',
            'color': '#03234b',
            'font-size': '14px'
        }
    )

    return dbc.Form([fields_accordion, submit_row, outcome_line])
|
|
|
|
|
def process_form_configs(form_configs):
    """
    Normalize raw form values before they are written back to the YAML.

    Rules, applied per entry:
    - ``None`` values are dropped.
    - Single-element lists are unwrapped (e.g. a ticked checklist ``[True]``
      becomes ``True``).
    - Strings that parse as an integer become ``int``.
    - Other strings that parse as a finite float become ``float``. This
      includes scientific notation such as ``"1e-3"``.
    - Every other value is kept unchanged.

    Args:
        form_configs (dict): Form values; presumably keyed by the dotted YAML
            path used as the input id index (TODO confirm against the caller).

    Returns:
        dict: The processed values.
    """
    import math

    updated_yaml = {}
    for key, value in form_configs.items():
        if value is None:
            continue

        if isinstance(value, list) and len(value) == 1:
            value = value[0]

        if isinstance(value, str):
            try:
                value = int(value)
            except ValueError:
                # Previously only strings containing '.' were tried as float,
                # so values like "1e-3" stayed strings.
                try:
                    number = float(value)
                except ValueError:
                    pass
                else:
                    # Don't turn words such as "nan" or "inf" into floats.
                    if math.isfinite(number):
                        value = number

        updated_yaml[key] = value

    return updated_yaml
|
|
|
|
|
def create_archive(archive_path, directory_to_compress):
    """
    Create a ZIP archive of every file under a directory.

    Files are stored under their path relative to ``directory_to_compress``.
    The archive itself is skipped when it is written inside that directory.

    Parameters:
        archive_path (str): Path of the ZIP archive to create (overwritten if present).
        directory_to_compress (str): Directory whose contents will be compressed.

    Returns:
        None
    """
    archive_abspath = os.path.abspath(archive_path)

    with zipfile.ZipFile(archive_path, 'w', compression=zipfile.ZIP_DEFLATED) as zipf:
        # Files are added sequentially. A ZipFile accepts only one open write
        # handle at a time, so the previous concurrent zipf.write() calls from
        # a thread pool could fail. Those futures were never checked, so the
        # affected files were silently missing from the archive.
        for root_dir, _sub_dirs, files in os.walk(directory_to_compress):
            for file_name in files:
                file_path = os.path.join(root_dir, file_name)
                if os.path.abspath(file_path) == archive_abspath:
                    continue
                arcname = os.path.relpath(file_path, directory_to_compress)
                zipf.write(file_path, arcname=arcname)
|
|
|
|
|
# The factory itself (not its result) is assigned, so Dash builds the layout
# by calling it rather than reusing one prebuilt tree.
app.layout = create_dashboard_layout

# State shared between the background training thread and the callbacks.
new_training = False          # set when a training run is launched
lock = threading.Lock()       # guards mutations of `logs`
logs = []                     # raw output lines of the current run
|
|
|
def fill_logs(message):
    """
    Append a message to the shared log and return the rendered filtered log.

    Parameters:
        message (str): The message (typically one output line) to append.

    Returns:
        html.Pre: The whole log so far, passed through filter_logs and
        format_logs.
    """
    with lock:
        logs.append(message)
        # Take a snapshot while holding the lock. The filtering/formatting of
        # the whole history (which grows with every line) then runs without
        # blocking other threads that need the lock.
        snapshot = "\n".join(logs)

    formatted_logs = format_logs(filter_logs(snapshot))
    return html.Pre(formatted_logs, style={'whiteSpace': 'pre-wrap', 'wordBreak': 'break-all'})
|
|
|
def filter_logs(logs):
    """
    Keep only the log lines worth displaying.

    A line is kept when it contains any of: '[INFO]', 'Epoch', 'Total params',
    'Trainable params', 'Non-trainable params' or 'Segments built'.

    Parameters:
        logs (str): Newline-separated log text.

    Returns:
        str: The kept lines, joined with '\\n', in their original order.
    """
    markers = (
        '[INFO]',
        'Epoch',
        'Total params',
        'Trainable params',
        'Non-trainable params',
        'Segments built',
    )
    kept = [line for line in logs.split('\n') if any(marker in line for marker in markers)]
    return '\n'.join(kept)
|
|
|
def format_logs(logs):
    """
    Insert a line break before every '[INFO]', 'Epoch' and 'Segments built'.

    Parameters:
        logs (str): Log text.

    Returns:
        str: The text with the extra line breaks.
    """
    result = logs
    for marker in ('[INFO]', 'Epoch', 'Segments built'):
        result = result.replace(marker, '\n' + marker)
    return result
|
|
|
def extract_metrics(logs):
    """
    Parse evaluation metrics out of raw log text.

    The first occurrence of each metric is used. A metric that is not found
    is absent from its sub-dict.

    Parameters:
        logs (str): Raw log text.

    Returns:
        dict: ``{'float': {...}, 'quantized': {...}, 'oks': {...}}`` where
            - 'float' may hold 'accuracy' and 'average_iou';
            - 'quantized' may hold 'accuracy';
            - 'oks' may hold 'precision', 'recall', 'mean_ap' and 'mean_oks'.
    """
    # A decimal number. Unlike ``[\d.]+``, it never captures a lone "." or a
    # sentence-ending period ("0.85."), both of which made float() raise.
    number = r"(\d+(?:\.\d*)?|\.\d+)"

    # (sub-dict, metric name, pattern), in the order metrics are recorded.
    specs = (
        ('float', 'accuracy', r"Accuracy of float model(?: on validation_set)?\s*=\s*" + number + r"\s*%"),
        ('quantized', 'accuracy', r"Accuracy of quantized model(?: on validation_set)?\s*=\s*" + number + r"\s*%"),
        ('oks', 'precision', r"Mean precision:\s*" + number),
        ('oks', 'recall', r"Mean recall:\s*" + number),
        # The former pattern "Mean AP $mAP$:" used '$' (an end-of-line anchor)
        # and could never match; presumably the log reads "Mean AP (mAP): x".
        ('oks', 'mean_ap', r"Mean AP \(mAP\):\s*" + number),
        ('oks', 'mean_oks', r"The mean OKS is :\s*" + number),
        ('float', 'average_iou', r"Average IoU of float model \(all classes\) on validation_set\s*=\s*" + number + r"\s*%"),
    )

    metrics = {
        'float': {},
        'quantized': {},
        'oks': {},
    }
    for group, name, pattern in specs:
        match = re.search(pattern, logs)
        if match:
            metrics[group][name] = float(match.group(1))

    return metrics
|
|
|
|
|
def _parse_inference_memory(logs): |
|
metrics = {} |
|
""" |
|
patterns = { |
|
"ram": r"Total RAM\s*:\s*([\d.]+)\s*\(KiB\)", |
|
"flash": r"Total Flash\s*:\s*([\d.]+)\s*\(KiB\)", |
|
"inference_time": r"Inference Time\s*:\s*([\d.]+)\s*\(ms\)" |
|
} |
|
""" |
|
patterns = { |
|
"ram": r"Total RAM\s*:\s*([\d.]+)\s*\(KiB\)", |
|
"flash": r"Total Flash\s*:\s*([\d.]+)\s*\(KiB\)", |
|
"inference_time": r"Inference Time\s*:\s*([\d.]+)\s*\(ms\)" |
|
} |
|
|
|
for key, pattern in patterns.items(): |
|
matches = re.findall(pattern, logs) |
|
if matches: |
|
metrics[key] = float(matches[-1]) |
|
|
|
return metrics |
|
|
|
def create_accuracy_gauge(accuracy):
    """
    Build a 0-100 gauge figure showing an accuracy percentage.

    Parameters:
        accuracy (float): Accuracy in percent.

    Returns:
        go.Figure: A single-indicator figure titled "Accuracy (%)".
    """
    indicator = go.Indicator(
        mode="gauge+number",
        value=accuracy,
        title={'text': "Accuracy (%)"},
        gauge={
            'axis': {'range': [0, 100]},
            'bar': {'color': "#49B170"},
        },
    )
    return go.Figure(indicator)
|
|
|
def create_iou_gauge(iou_value):
    """
    Build a 0-100 gauge figure showing an IoU percentage.

    Parameters:
        iou_value (float): IoU in percent.

    Returns:
        go.Figure: A single-indicator figure titled "IoU (%)".
    """
    gauge_settings = {
        'axis': {'range': [0, 100]},
        'bar': {'color': "#49B170"},
    }
    return go.Figure(go.Indicator(
        mode="gauge+number",
        value=iou_value,
        title={'text': "IoU (%)"},
        gauge=gauge_settings,
    ))
|
|
|
def latest_confusion_matrix(outputs_folder, recent_directory):
    """
    Return a run's confusion-matrix image as a PNG data URI.

    Candidates are the files of ``outputs_folder/recent_directory`` whose name
    contains "confusion_matrix" and ends with ".png".

    Parameters:
        outputs_folder (str): Parent output directory.
        recent_directory (str): Run sub-directory to search.

    Returns:
        str | None: ``"data:image/png;base64,..."`` for the most recently
        modified candidate (ties broken by file name), or ``None`` when there
        is no candidate.

    Raises:
        OSError: If the run directory cannot be listed (e.g. it does not exist).
    """
    cf_path = os.path.join(outputs_folder, recent_directory)

    candidates = [
        os.path.join(cf_path, filename)
        for filename in sorted(os.listdir(cf_path))
        if "confusion_matrix" in filename and filename.endswith(".png")
    ]
    if not candidates:
        return None

    # os.listdir order is arbitrary, so taking the first match was
    # nondeterministic when several images exist; use the newest one.
    image_path = max(candidates, key=os.path.getmtime)
    with open(image_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode()
    return f"data:image/png;base64,{encoded}"
|
|
|
def run_script(script, devcloud_username, devcloud_password):
    """
    Run a Python script in a subprocess and stream its output into the log.

    The shared ``logs`` list is reset first. The child receives a copy of the
    current environment plus:
        - ``stmai_username`` / ``stmai_password``: the ST Edge AI Developer Cloud credentials;
        - ``STATS_TYPE``;
        - ``PYTHONUNBUFFERED``.
    stdout and stderr are merged, and each output line is passed to
    fill_logs as it arrives. Blocks until the child exits.

    Parameters:
    - script (str): The path to the script to be executed.
    - devcloud_username (str): Username for ST Developer Cloud.
    - devcloud_password (str): Password for ST Developer Cloud.

    Returns:
    - None
    """
    global logs

    with lock:
        logs = []

    isolated_env = os.environ.copy()
    isolated_env['stmai_username'] = devcloud_username
    isolated_env['stmai_password'] = devcloud_password
    isolated_env['STATS_TYPE'] = 'HuggingFace_devcloud'
    # A Python child block-buffers stdout when it is a pipe; without this,
    # log lines would arrive in large delayed chunks.
    isolated_env['PYTHONUNBUFFERED'] = '1'

    # Merging stderr into stdout and reading one pipe line by line replaces
    # the select() loop, which had three problems:
    #   - it busy-spun once one stream reached EOF;
    #   - it could stall on lines already buffered by the text wrapper;
    #   - it dropped stderr output still pending when stdout closed.
    with subprocess.Popen(
        ['python3', script],
        env=isolated_env,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    ) as execution:
        for line in execution.stdout:
            fill_logs(line)
        execution.wait()
|
|
|
def execute_async(script, devcloud_username, devcloud_password):
    """
    Start run_script for the given script on a new background thread.

    The call returns immediately; the script's output is collected by
    run_script.

    Parameters:
        script (str): Path of the Python script to run.
        devcloud_username (str): ST Edge AI Developer Cloud username.
        devcloud_password (str): ST Edge AI Developer Cloud password.

    Returns:
        None
    """
    threading.Thread(
        target=run_script,
        args=(script, devcloud_username, devcloud_password),
    ).start()
|
|
|
|
|
@app.callback(
    Output("config-section", "style"),
    Input('selected-model', 'value')
)
def toggle_config_section(selected_model):
    """
    Show the configuration section once a use case is selected.

    NOTE(review): create_dashboard_layout defines no component with id
    "config-section"; confirm whether this callback is still needed.

    Parameters:
        selected_model (str): Value of the 'selected-model' dropdown.

    Returns:
        dict: ``{"display": "block"}`` when a use case is selected,
        otherwise ``{"display": "none"}``.
    """
    visibility = "block" if selected_model else "none"
    return {"display": visibility}
|
|
|
|
|
@app.callback(
    Output('toggle-yaml', 'style'),
    Input('selected-model', 'value')
)
def dipslay_yaml_container(selected_model):
    """
    Show the YAML configuration container only when a use case is selected.

    Args:
        selected_model (str): Value of the 'selected-model' dropdown.

    Returns:
        dict: ``{'display': 'block'}`` if a use case is selected, else
        ``{'display': 'none'}``.
    """
    return {'display': 'block' if selected_model else 'none'}
|
|
|
|
|
@app.callback(
    Output("offcanvas", "is_open"),
    [Input("open-offcanvas", "n_clicks")],
    [State("offcanvas", "is_open")],
)
def toggle_offcanvas(n1, is_open):
    """
    Flip the help panel's open state when the help button is clicked.

    Args:
        n1 (int | None): Click count of 'open-offcanvas'.
        is_open (bool): Current open state of the panel.

    Returns:
        bool: The inverted state once the button has been clicked, otherwise
        the unchanged state.
    """
    return (not is_open) if n1 else is_open
|
|
|
@app.callback(
    Output("log-offcanvas", "is_open"),
    [Input("toggle-log", "n_clicks")],
    [State("log-offcanvas", "is_open")],
)
def toggle_log(n, is_open):
    """
    Flip the log panel's open state when the banner's log button is clicked.

    Args:
        n (int | None): Click count of 'toggle-log'.
        is_open (bool): Current open state of the log panel.

    Returns:
        bool: The inverted state once the button has been clicked, otherwise
        the unchanged state.
    """
    if not n:
        return is_open
    return not is_open
|
|
|
@app.callback(
    [Output('yaml-layout', 'style'),
     Output('yaml-layout', 'children')],
    [Input('modify-yaml-choice', 'value'),
     Input('selected-model', 'value')]
)
def display_yaml_form(selection_update, selected_model):
    """
    Render the editable configuration form of the selected use case.

    Args:
        selection_update: Value of 'modify-yaml-choice'. It only triggers the
            callback; its value is not used.
        selected_model (str): The selected use case.

    Returns:
        tuple: (style, children) for 'yaml-layout':
            - no use case selected: hidden, with a hint;
            - success: visible, with the form;
            - failure: visible, with an error message if the configuration
              could not be loaded or rendered.
    """
    if not selected_model:
        return {'display': 'none'}, "Please select a model to display its configuration."

    try:
        yaml_conf = read_configs(selected_model)
        form_conf = create_yaml(yaml_conf)
    except ValueError as e:
        # Errors used to be placed in a hidden container, so users never saw them.
        return {'display': 'block'}, f"Error: {str(e)}"
    except Exception as e:
        logger.exception("Unexpected error while building the configuration form for %s", selected_model)
        return {'display': 'block'}, f"Unexpected Error: {str(e)}"

    return {'display': 'block'}, form_conf
|
|
|
|
|
@app.callback(
    Output("log-reader", "style"),
    Input('apply-button', 'n_clicks')
)
def toggle_output_section(n_clicks):
    """
    Show the log reader once the configuration form has been submitted.

    Parameters:
        n_clicks (int | None): Click count of the 'apply-button' Submit button.

    Returns:
        dict: The log reader's style, with ``display`` set to 'none' before
        the first submit and 'block' afterwards.
    """
    # A returned style replaces the whole style prop, so the base style from
    # the layout is repeated here. Without it, white-space: pre-wrap, padding,
    # height and overflow were lost, and plain-text logs lost their line breaks.
    base_style = {'whiteSpace': 'pre-wrap', 'padding': '15px', 'height': '200px', 'overflow': 'auto'}
    if n_clicks is None or n_clicks == 0:
        return {**base_style, 'display': 'none'}
    return {**base_style, 'display': 'block'}
|
|
|
|
|
@app.callback(
    Output('credentials-section', 'style'),
    [Input('modify-yaml-choice', 'value'),
     Input('selected-model', 'value'),
     Input('apply-button', 'n_clicks')]
)
def display_credentials(selection_update, selected_model, n_clicks):
    """
    Show the credentials section once the configuration form has been submitted.

    Visibility depends only on ``n_clicks``. The YAML choice and the selected
    use case merely re-trigger the callback; their values are not used.

    Args:
        selection_update: Value of 'modify-yaml-choice' (unused).
        selected_model (str): The selected use case (unused).
        n_clicks (int | None): Click count of the 'apply-button' Submit button.

    Returns:
        dict: ``{'display': 'none'}`` before the first submit, then
        ``{'display': 'block'}``.
    """
    if n_clicks is None or n_clicks == 0:
        return {'display': 'none'}
    return {'display': 'block'}
|
|
|
@app.callback(
    Output('process-button', 'style'),
    [Input('apply-button', 'n_clicks')]
)
def display_launch_training(n_clicks):
    """
    Show the 'Launch training' button once the configuration form has been submitted.

    Parameters:
        n_clicks (int | None): The number of times the apply button has been clicked.

    Returns:
        dict: The button style, with ``display`` 'inline-block' after a
        submit and 'none' before.
    """
    # The returned style replaces the layout's style entirely; keep the
    # button's box-shadow, which the layout sets and which was dropped before.
    base_style = {'box-shadow': '0px 4px 6px rgba(0, 0, 0, 0.1)'}
    if n_clicks and n_clicks > 0:
        return {'display': 'inline-block', **base_style}
    return {'display': 'none', **base_style}
|
|
|
@app.callback(
    Output("results-section", "style"),
    Output("toggle-log", "className"),
    Input("process-button", "n_clicks")
)
def display_results_section(n_clicks):
    """
    Reveal the results section and make the log icon blink once training is launched.

    Parameters:
        n_clicks (int | None): Click count of the 'Launch training' button.

    Returns:
        tuple: The results-section style and the CSS class of the
        'toggle-log' button:
            - after a launch: ``({"display": "block"}, "blinking")``;
            - otherwise: ``({"display": "none"}, "")``.
    """
    launched = n_clicks and n_clicks > 0
    if not launched:
        return {"display": "none"}, ""
    return {"display": "block"}, "blinking"
|
|
|
def _metrics_run_timestamp(name):
    """Parse a run directory name of the form YYYY_mm_dd_HH_MM_SS; return None if it does not match."""
    try:
        return datetime.strptime(name, '%Y_%m_%d_%H_%M_%S')
    except ValueError:
        return None


def _metrics_latest_run_directory(outputs_folder):
    """
    Return the name of the most recent timestamped run directory in ``outputs_folder``.

    Only sub-directories whose name starts with '20' and parses as
    ``%Y_%m_%d_%H_%M_%S`` are considered. Other entries are ignored instead of
    making ``datetime.strptime`` raise inside the callback. Returns None when
    no such directory exists.
    """
    runs = []
    for name in os.listdir(outputs_folder):
        if not name.startswith('20') or not os.path.isdir(os.path.join(outputs_folder, name)):
            continue
        stamp = _metrics_run_timestamp(name)
        if stamp is not None:
            runs.append((stamp, name))
    if not runs:
        return None
    return max(runs, key=lambda run: run[0])[1]


def _metrics_curve_figure(metrics_dataframe, metric_names):
    """
    Build a Plotly figure dict plotting the given metric columns against the 'epoch' column.

    With two names, the first is drawn as the training curve (yellow) and the
    second as the validation curve (blue). With a single name, only that
    curve is drawn (blue). The y-axis is titled after the first name.
    """
    palette = ['#FFD200', '#3CB4E6'] if len(metric_names) == 2 else ['#3CB4E6']
    traces = [
        {
            'x': metrics_dataframe['epoch'],
            'y': metrics_dataframe[name],
            'type': 'line',
            'name': name.capitalize(),
            'line': {'color': color, 'width': 2, 'dash': 'solid'},
            'hoverinfo': 'x+y+name',
            'hoverlabel': {'bgcolor': '#EEEFF1', 'font': {'color': '#525A63'}},
        }
        for name, color in zip(metric_names, palette)
    ]
    if len(metric_names) == 2:
        title = f'{metric_names[0].capitalize()} vs {metric_names[1].capitalize()}'
    else:
        title = f'{metric_names[0].capitalize()} over Epochs'
    return {
        'data': traces,
        'layout': {
            'title': {'text': title, 'x': 0.5, 'xanchor': 'center'},
            'xaxis': {'title': 'Epochs', 'showgrid': True, 'gridcolor': '#EEEFF1', 'tickangle': 45},
            'yaxis': {'title': metric_names[0].capitalize(), 'showgrid': True, 'gridcolor': '#EEEFF1'},
            'showlegend': True,
            'legend': {
                'x': 1,
                'y': 1,
                'traceorder': 'normal',
                'font': {'size': 10},
                'bgcolor': '#EEEFF1',
                'bordercolor': '#A6ADB5',
                'borderwidth': 1,
            },
            'hovermode': 'closest',
            'plot_bgcolor': '#ffffff',
        },
    }


@app.callback(
    [Output('log-reader', 'children'),
     Output('acc-visualization', 'figure'),
     Output('acc-visualization', 'style'),
     Output('loss-visualization', 'figure'),
     Output('loss-visualization', 'style'),
     Output('confusion-matrix-img', 'src')],
    [Input('interval-widget', 'n_intervals'),
     Input('process-button', 'n_clicks')],
    [State('selected-model', 'value'),
     State('devcloud-username-input', 'value'),
     State('devcloud-password-input', 'value')]
)
def refresh_metrics(n_intervals, nb_clicks, selected_model, devcloud_username, devcloud_password):
    """
    Updates the log display and training metrics based on user actions and intervals.

    - On a process-button click: launches the training script through
      ``execute_async`` when both credentials are provided (and sets the global
      ``new_training`` flag), otherwise appends a credentials reminder to the logs.
    - On an interval tick: once a training has been started, reads
      ``train_metrics.csv`` of the most recent timestamped run in
      ``experiments_outputs`` and plots up to two metric figures plus the
      latest confusion matrix.

    Args:
        n_intervals (int): Tick count of the interval component (trigger only).
        nb_clicks (int): Click count of the process/run button.
        selected_model (str): The selected model from the dropdown.
        devcloud_username (str): ST Developer Cloud username.
        devcloud_password (str): ST Developer Cloud password.

    Returns:
        tuple: (log text, first figure, its style, second figure, its style,
        confusion-matrix image source or None). A graph whose figure is empty
        is hidden.

    Raises:
        PreventUpdate: If the callback was not triggered by a click with a
        positive count or by an interval tick.
    """
    global logs, new_training

    hidden = {'display': 'none'}
    shown = {'display': 'block'}

    def no_figures(confusion_matrix_src=None):
        # The log text is joined at return time because `logs` is a shared global.
        return "\n".join(logs), {}, hidden, {}, hidden, confusion_matrix_src

    callback_context = dash.callback_context
    if not callback_context.triggered:
        raise PreventUpdate

    button = callback_context.triggered[0]['prop_id'].split('.')[0]

    if button == 'process-button':
        if not nb_clicks:
            raise PreventUpdate
        if devcloud_username and devcloud_password:
            st_script = f"stm32ai-modelzoo-services/{selected_model}/src/stm32ai_main.py"
            execute_async(st_script, devcloud_username, devcloud_password)
            new_training = True
            logs.append("Starting application ...")
        else:
            logs.append("Please enter both ST Developer Cloud username and password:")
        return no_figures()

    if button != 'interval-widget':
        raise PreventUpdate

    if not new_training:
        return no_figures()

    outputs_folder = "experiments_outputs"
    if not os.path.exists(outputs_folder):
        os.makedirs(outputs_folder, exist_ok=True)
        return no_figures()

    recent_directory = _metrics_latest_run_directory(outputs_folder)
    if recent_directory is None:
        return no_figures()

    train_metrics_file = os.path.join(outputs_folder, recent_directory, 'logs', 'metrics', 'train_metrics.csv')
    logger.debug("Metrics file : %s", train_metrics_file)
    if not os.path.exists(train_metrics_file):
        return no_figures()

    try:
        metrics_dataframe = pd.read_csv(train_metrics_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError):
        # The training process may still be creating/writing the CSV; retry on the next tick.
        return no_figures()
    if metrics_dataframe.empty:
        return no_figures()

    # (train, validation) pairs, or a single validation-only metric, in display order.
    metric_groups = (
        ('accuracy', 'val_accuracy'),
        ('loss', 'val_loss'),
        ('oks', 'val_oks'),
        ('val_map',),
    )
    figures = []
    if 'epoch' in metrics_dataframe.columns:
        figures = [
            _metrics_curve_figure(metrics_dataframe, group)
            for group in metric_groups
            if all(name in metrics_dataframe.columns for name in group)
        ]

    confusion_matrix_src = latest_confusion_matrix(outputs_folder, recent_directory)

    if not figures:
        return no_figures(confusion_matrix_src)

    has_second = len(figures) > 1
    return (
        "\n".join(logs),
        figures[0],
        shown,
        figures[1] if has_second else {},
        shown if has_second else hidden,
        confusion_matrix_src,
    )
|
|
|
@app.callback(
    Output('metric-graphs-container', 'children'),
    Input('log-reader', 'children')
)
def update_metrics_dashboard(logs):
    """
    Rebuild the row of evaluation-metric gauge cards from the current log text.

    Args:
        logs: Children of the log-reader component; handed to ``extract_metrics``.

    Returns:
        dbc.Row: A centered row holding one card per metric value that
        ``extract_metrics`` reported, in this order: float average IoU,
        float accuracy, quantized accuracy, mean OKS, precision, recall, mean AP.
    """
    parsed = extract_metrics(logs)

    def metric_card(header_text, gauge_figure):
        # A width-4 column containing a titled card with the gauge graph (mode bar hidden).
        header = dbc.CardHeader(header_text, style={'background-color': '#8191a5', 'color': 'white', 'font-size': '18px'})
        body = dbc.CardBody(dcc.Graph(figure=gauge_figure, config={'displayModeBar': False}))
        return dbc.Col(dbc.Card([header, body]), width=4)

    float_section = parsed.get("float", {})
    quantized_section = parsed.get("quantized", {})
    oks_section = parsed.get("oks", {})

    # (metrics section, key in that section, card title, gauge builder), in display order.
    card_specs = [
        (float_section, "average_iou", "Average IoU - Float Model", create_iou_gauge),
        (float_section, "accuracy", "Accuracy - Float Model", create_accuracy_gauge),
        (quantized_section, "accuracy", "Accuracy - Quantized Model", create_accuracy_gauge),
        (oks_section, "mean_oks", "Mean OKS", create_accuracy_gauge),
        (oks_section, "precision", "Mean Precision", create_accuracy_gauge),
        (oks_section, "recall", "Mean Recall", create_accuracy_gauge),
        (oks_section, "mean_ap", "Mean AP (mAP)", create_accuracy_gauge),
    ]

    cards = []
    for section, key, title, build_gauge in card_specs:
        value = section.get(key)
        if value is None:
            continue
        cards.append(metric_card(title, build_gauge(value)))

    return dbc.Row(cards, justify="center")
|
|
|
|
|
@app.callback(
    Output('memory-bar', 'figure'),
    Input('log-reader', 'children')
)
def update_memory_bar(logs):
    """
    Draw a horizontal bar chart of total RAM and Flash usage parsed from the logs.

    Args:
        logs: Children of the log-reader component; handed to ``_parse_inference_memory``.

    Returns:
        go.Figure: Two horizontal bars (RAM, Flash). A value missing from the
        parsed metrics is drawn as 0.
    """
    usage = _parse_inference_memory(logs)
    ram_kib = usage.get('ram', 0)
    flash_kib = usage.get('flash', 0)

    bars = go.Bar(
        y=["Total RAM ", "Total Flash"],
        x=[ram_kib, flash_kib],
        orientation='h',
        marker_color=["#E6007E", "#3cb4e6"],
    )
    figure = go.Figure(data=[bars])
    figure.update_layout(title="Memory Usage (KiB)", xaxis_title="Size (KiB)")
    return figure
|
|
|
@app.callback(
    Output('inference-time', 'figure'),
    Input('log-reader', 'children')
)
def update_inference_time(logs):
    """
    Display the inference time parsed from the logs as a numeric indicator.

    Args:
        logs: Children of the log-reader component; handed to ``_parse_inference_memory``.

    Returns:
        go.Figure: A number indicator titled "Inference Time (ms)"; shows 0 when
        no inference time was parsed.
    """
    parsed = _parse_inference_memory(logs)
    indicator = go.Indicator(
        mode="number",
        value=parsed.get('inference_time', 0),
        title={'text': "Inference Time (ms)"},
    )
    return go.Figure(indicator)
|
|
|
def _apply_form_to_user_yaml(form_input_ids, form_input_values, selected_model):
    """
    Write changed form values into the selected model's user-config YAML and return a status message.

    Form values are keyed by their pattern-matching id ``index``, converted with
    ``process_form_configs``, and written into the YAML at their dotted key
    paths. Intermediate mappings are created when absent. Only keys whose
    value actually changes are reported as updated. Any exception is turned
    into an error message, not propagated.
    """
    # Sentinel so a key absent from the YAML always counts as changed (instead of raising KeyError).
    missing = object()
    new_fields = []
    try:
        form_fields_data = {
            input_id['index']: input_value
            for input_id, input_value in zip(form_input_ids, form_input_values)
        }

        yaml_file_path = local_yamls.get(selected_model)
        if not yaml_file_path:
            return f"ERROR: No user config yaml found for '{selected_model}'."

        yaml_parser = ruamel.yaml.YAML()
        with open(yaml_file_path, 'r') as file:
            current_yaml_data = yaml_parser.load(file)

        updated_yaml_data = process_form_configs(form_fields_data)
        for key, value in updated_yaml_data.items():
            *parent_keys, leaf_key = key.split('.')
            nested_dict = current_yaml_data
            for k in parent_keys:
                nested_dict = nested_dict.setdefault(k, {})
            if nested_dict.get(leaf_key, missing) != value:
                nested_dict[leaf_key] = value
                new_fields.append(key)

        with open(yaml_file_path, 'w') as file:
            yaml_parser.dump(current_yaml_data, file)

        return f"User config yaml file has been updated successfully ! Updated fields are: {', '.join(new_fields)}"
    except Exception as e:
        return f"ERROR: UPDATING USER CONFIG YAML file: {e}"


@app.callback(
    Output('submission-outcome', 'children'),
    [Input('apply-button', 'n_clicks'),
     Input('process-button', 'n_clicks')],
    [State({'type': 'yaml-setting', 'index': ALL}, 'id'),
     State({'type': 'yaml-setting', 'index': ALL}, 'value'),
     State('selected-model', 'value'),
     State('devcloud-username-input', 'value'),
     State('devcloud-password-input', 'value')]
)
def process_button_actions(submit_clicks, exec_nb_clicks, form_input_ids, form_input_values, selected_model, devcloud_username, devcloud_password):
    """
    Handles the actions triggered by the submit (apply) and run (process) buttons.

    - Apply button: writes the form values into the selected model's user
      config YAML and reports which fields changed.
    - Process button: reports the launch status. The training script itself is
      started by the ``refresh_metrics`` callback, which listens to the same
      click and checks the credentials. Launching it here as well used to start
      the script twice.

    Args:
        submit_clicks (int): Click count of the submit/apply button.
        exec_nb_clicks (int): Click count of the execution/run button.
        form_input_ids (list): Pattern-matching ids of the form inputs (dicts with an 'index').
        form_input_values (list): Values of the form inputs, aligned with ``form_input_ids``.
        selected_model (str): The selected model from the dropdown.
        devcloud_username (str): ST Developer Cloud username.
        devcloud_password (str): ST Developer Cloud password.

    Returns:
        str: A status message (YAML update result, or run status).

    Raises:
        PreventUpdate: If the callback was not triggered by a relevant input or
        the triggering button has no clicks.
    """
    callback_context = dash.callback_context
    if not callback_context.triggered:
        raise PreventUpdate

    triggered_button = callback_context.triggered[0]['prop_id'].split('.')[0]

    if triggered_button == 'apply-button':
        if not submit_clicks:
            raise PreventUpdate
        return _apply_form_to_user_yaml(form_input_ids, form_input_values, selected_model)

    if triggered_button == 'process-button':
        if not exec_nb_clicks:
            raise PreventUpdate
        if not (devcloud_username and devcloud_password):
            # Mirrors the credential check done before the script is launched.
            return "Please enter both ST Developer Cloud username and password."
        return "Application is running ..."

    raise PreventUpdate
|
|
|
|
|
|
|
@app.callback(
    Output('download-action', 'style'),
    [Input('interval-widget', 'n_intervals')],
    [State('selected-model', 'value')]
)
def toggle_download_button(n_intervals, selected_model):
    """
    Show the download button once at least one timestamped run directory exists.

    The check looks for a sub-directory whose name starts with '20' inside
    ``<cwd>/experiments_outputs``. Neither argument influences the result.

    Args:
        n_intervals (int): Tick count of the polling interval (used only as a trigger).
        selected_model (str): Currently selected model (not used by the check).

    Returns:
        dict: {'display': 'block'} when such a directory exists, otherwise {'display': 'none'}.
    """
    runs_root = os.path.join(os.getcwd(), "experiments_outputs")
    if os.path.exists(runs_root):
        has_run = any(
            entry.startswith('20') and os.path.isdir(os.path.join(runs_root, entry))
            for entry in os.listdir(runs_root)
        )
        if has_run:
            return {'display': 'block'}
    return {'display': 'none'}
|
|
|
|
|
@app.callback(
    Output('download-resource', 'data'),
    [Input('download-action', 'n_clicks')],
    [State('selected-model', 'value')]
)
def generate_download_link(n_clicks, selected_model):
    """
    Send a ZIP archive of the most recent training run when the download button is clicked.

    The most recent run is the sub-directory of ``<cwd>/experiments_outputs``
    whose name starts with '20' and parses as ``%Y_%m_%d_%H_%M_%S`` with the
    latest timestamp. Entries that do not parse are ignored rather than
    crashing the callback. The archive ``<run>/<run>.zip`` is built with
    ``create_archive`` the first time it is requested and reused afterwards.

    Args:
        n_clicks (int | None): Click count of the download button.
        selected_model (str): Currently selected model (not used to locate the run).

    Returns:
        The payload produced by ``dcc.send_file`` for the ZIP archive.

    Raises:
        PreventUpdate: If the button has not been clicked, no run directory
        exists, or the archive could not be produced.
    """
    if n_clicks is None:
        raise PreventUpdate

    output_directory = os.path.join(os.getcwd(), "experiments_outputs")
    if not os.path.isdir(output_directory):
        raise PreventUpdate

    def run_timestamp(name):
        # None for names that start with '20' but are not run timestamps.
        try:
            return datetime.strptime(name, "%Y_%m_%d_%H_%M_%S")
        except ValueError:
            return None

    runs = []
    for name in os.listdir(output_directory):
        if name.startswith("20") and os.path.isdir(os.path.join(output_directory, name)):
            stamp = run_timestamp(name)
            if stamp is not None:
                runs.append((stamp, name))
    if not runs:
        raise PreventUpdate

    recent_directory = max(runs, key=lambda run: run[0])[1]
    recent_directory_path = os.path.join(output_directory, recent_directory)
    zip_file_path = os.path.join(recent_directory_path, f"{recent_directory}.zip")

    # NOTE(review): the archive is written inside the directory it archives and is never
    # rebuilt once it exists -- confirm create_archive skips the zip itself, and whether a
    # run that is still producing outputs should invalidate a previously built archive.
    if not os.path.exists(zip_file_path):
        create_archive(zip_file_path, recent_directory_path)

    if os.path.exists(zip_file_path):
        return dcc.send_file(zip_file_path)

    raise PreventUpdate
|
|
|
@server.route('/download/<path:subpath>')
def download_file(subpath):
    """
    Route to download a file from the ``experiments_outputs`` directory.

    The requested path is resolved (following '..' segments and symlinks) and
    must stay inside ``<cwd>/experiments_outputs``. This prevents a crafted
    URL from reading arbitrary files on the server.

    Parameters:
    - subpath (str): Path of the file relative to the ``experiments_outputs`` directory.

    Returns:
    - Response: A Flask response sending the file as an attachment when it is a
      regular file inside the outputs directory.
    - tuple: ("File not found", 404) when the path does not exist, is not a
      regular file (e.g. a directory), or resolves outside the outputs directory.
    """
    base_directory = os.path.realpath(os.path.join(os.getcwd(), 'experiments_outputs'))
    file_path = os.path.realpath(os.path.join(base_directory, subpath))

    # Reject paths escaping the outputs directory (path traversal). commonpath raises
    # ValueError when the paths are on different drives (Windows); treat that as outside.
    try:
        inside_base = os.path.commonpath([base_directory, file_path]) == base_directory
    except ValueError:
        inside_base = False

    if not inside_base or not os.path.isfile(file_path):
        return "File not found", 404

    return send_file(file_path, as_attachment=True)
|
|
|
|
|
if __name__ == '__main__':
    # Dash.run_server is deprecated and removed in Dash 3; Dash.run takes the same
    # arguments (extra keywords such as `threaded` go to Flask), so prefer it when present.
    run_app = getattr(app, "run", None) or app.run_server
    run_app(host='0.0.0.0', port=7860, dev_tools_ui=True, dev_tools_hot_reload=True, threaded=True)