Spaces:
Sleeping
Sleeping
#!/usr/bin/env python | |
# Copyright 2017 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""This tool creates an html visualization of a TensorFlow Lite graph. | |
Example usage: | |
python visualize.py foo.tflite foo.html | |
""" | |
import json | |
import os | |
import re | |
import sys | |
import numpy as np | |
# pylint: disable=g-import-not-at-top | |
if not os.path.splitext(__file__)[0].endswith( | |
os.path.join("tflite_runtime", "visualize")): | |
# This file is part of tensorflow package. | |
from tensorflow.lite.python import schema_py_generated as schema_fb | |
else: | |
# This file is part of tflite_runtime package. | |
from tflite_runtime import schema_py_generated as schema_fb | |
import gradio as gr | |
from html import escape | |
# Stylesheet for the generated visualization (tables, tooltips, graph nodes).
# NOTE(review): this constant is not referenced anywhere in this file. Its
# selectors (`div`, `h1`, `table`, `text`, ...) are global, so putting it into
# the Gradio page head would presumably restyle the surrounding Gradio UI as
# well. Confirm that before wiring it in.
_VISUALIZER_STYLE = """<style>
table {background-color: #eca;}
th {background-color: black; color: white;}
h1 {
  background-color: #ffaa00;
  padding:5px;
  color: black;
}
svg {
  margin: 10px;
  border: 2px;
  border-style: solid;
  border-color: black;
  background: white;
}
div {
  border-radius: 5px;
  background-color: #fec;
  padding:5px;
  margin:5px;
}
.tooltip {color: blue;}
.tooltip .tooltipcontent {
  visibility: hidden;
  color: black;
  background-color: yellow;
  padding: 5px;
  border-radius: 4px;
  position: absolute;
  z-index: 1;
}
.tooltip:hover .tooltipcontent {
  visibility: visible;
}
.edges line {
  stroke: #333;
}
text {
  font-weight: bold;
}
.nodes text {
  color: black;
  pointer-events: none;
  font-size: 11px;
}
</style>"""

# Markup for the Gradio page <head>. Despite its name, this contains no CSS:
# it only loads d3 v4, which the per-subgraph scripts in _D3_HTML_TEMPLATE use.
_CSS = """
<script src="https://d3js.org/d3.v4.min.js"></script>
"""
# Per-subgraph d3 script. It is filled in as
# `_D3_HTML_TEMPLATE % (graph_json, subgraph_index)`:
#   - `%s` receives the JSON graph {"nodes": [...], "edges": [...]};
#   - `%d` selects the matching <svg id='subgraphN'> element.
# Any literal percent sign added to this template must be written as `%%`.
_D3_HTML_TEMPLATE = """
<script>
function buildGraph() {
  // Build graph data
  var graph = %s;

  var svg = d3.select("#subgraph%d")
  var width = svg.attr("width");
  var height = svg.attr("height");
  // Make the graph zoomable/pannable. The zoom handler transforms the inner
  // group that `svg` is reassigned to right below.
  svg = svg.call(d3.zoom().on("zoom", function() {
    svg.attr("transform", d3.event.transform);
  })).append("g");

  var color = d3.scaleOrdinal(d3.schemeDark2);

  // NOTE(review): the simulation is never given the nodes or a tick handler,
  // so it does not reposition anything. The layout comes from the x/y
  // values computed in Python.
  var simulation = d3.forceSimulation()
      .force("link", d3.forceLink().id(function(d) {return d.id;}))
      .force("charge", d3.forceManyBody())
      .force("center", d3.forceCenter(0.5 * width, 0.5 * height));

  var edge = svg.append("g").attr("class", "edges").selectAll("line")
    .data(graph.edges).enter().append("path").attr("stroke","black").attr("fill","none")

  // Make the node group
  var node = svg.selectAll(".nodes")
    .data(graph.nodes)
    .enter().append("g")
    .attr("x", function(d){return d.x})
    .attr("y", function(d){return d.y})
    .attr("transform", function(d) {
      return "translate( " + d.x + ", " + d.y + ")"
    })
    .attr("class", "nodes")
      .call(d3.drag()
          .on("start", function(d) {
            if(!d3.event.active) simulation.alphaTarget(1.0).restart();
            d.fx = d.x;d.fy = d.y;
          })
          .on("drag", function(d) {
            d.fx = d3.event.x; d.fy = d3.event.y;
          })
          .on("end", function(d) {
            if (!d3.event.active) simulation.alphaTarget(0);
            d.fx = d.fy = null;
          }));
  // Within the group, draw a box for the node position and text
  // on the side.
  var node_width = 150;
  var node_height = 30;
  node.append("rect")
      .attr("r", "5px")
      .attr("width", node_width)
      .attr("height", node_height)
      .attr("rx", function(d) { return d.group == 1 ? 1 : 10; })
      .attr("stroke", "#000000")
      .attr("fill", function(d) { return d.group == 1 ? "#dddddd" : "#000000"; })
  node.append("text")
      .text(function(d) { return d.name; })
      .attr("x", 5)
      .attr("y", 20)
      .attr("fill", function(d) { return d.group == 1 ? "#000000" : "#eeeeee"; })

  // Re-select the rendered node groups and index them by node id, so each
  // edge can look up its endpoints.
  var node = svg.selectAll(".nodes")
      .data(graph.nodes);
  var name_to_g = {}
  node.each(function(data, index, nodes) {
    name_to_g[data.id] = this;
  });
  function proc(w, t) {
    return parseInt(w.getAttribute(t));
  }
  // Draw each edge as a cubic Bezier from the bottom-centre of the source box
  // to the top-centre of the target box.
  edge.attr("d", function(d) {
    function lerp(t, a, b) {
      return (1.0-t) * a + t * b;
    }
    var x1 = proc(name_to_g[d.source],"x") + node_width /2;
    var y1 = proc(name_to_g[d.source],"y") + node_height;
    var x2 = proc(name_to_g[d.target],"x") + node_width /2;
    var y2 = proc(name_to_g[d.target],"y");
    var s = "M " + x1 + " " + y1
        + " C " + x1 + " " + lerp(.5, y1, y2)
        + " " + x2 + " " + lerp(.5, y1, y2)
        + " " + x2 + " " + y2
    return s;
  });
}
buildGraph();
</script>
"""
def TensorTypeToName(tensor_type):
    """Converts a numerical TensorType enum value to a readable name.

    Args:
      tensor_type: integer value of a `schema_fb.TensorType` member.

    Returns:
      The name of the public `schema_fb.TensorType` attribute whose value is
      `tensor_type`, or None if there is none.
    """
    for name, value in vars(schema_fb.TensorType).items():
        # Only public attributes are enum members. Class dunders such as
        # `__module__` or `__doc__` must not match; `__doc__` can be None, so
        # a None `tensor_type` would otherwise be reported as "__doc__".
        if not name.startswith("_") and value == tensor_type:
            return name
    return None
def BuiltinCodeToName(code):
    """Converts a builtin op code enum value to a readable name.

    Args:
      code: integer value of a `schema_fb.BuiltinOperator` member.

    Returns:
      The name of the public `schema_fb.BuiltinOperator` attribute whose value
      is `code`, or None if there is none.
    """
    for name, value in vars(schema_fb.BuiltinOperator).items():
        # Skip class dunders (`__module__`, `__doc__`, ...): they are not op
        # codes, and `__doc__` may be None and spuriously match a None code.
        if not name.startswith("_") and value == code:
            return name
    return None
def NameListToString(name_list):
    """Converts a string field from the model dict to a Python ``str``.

    String fields may already be ``str`` (e.g. JSON input). They may also be
    sequences of integer byte values; presumably these come from flatbuffer
    bytes that FlatbufferToDict turned into lists.

    Args:
      name_list: a ``str`` (returned unchanged), an iterable of integers, or
        None.

    Returns:
      The decoded text, or "" when ``name_list`` is None or empty.
    """
    if isinstance(name_list, str):
        return name_list
    if name_list is None:
        return ""
    values = [int(val) for val in name_list]
    try:
        # FlatBuffers strings are UTF-8. Mapping each byte through chr()
        # would mangle every non-ASCII name, so decode the bytes as a whole.
        return bytes(values).decode("utf-8")
    except ValueError:  # also covers UnicodeDecodeError
        # Values outside 0..255 or not valid UTF-8: fall back to treating
        # each value as one code point.
        return "".join(chr(val) for val in values)
class OpCodeMapper:
    """Callable that renders an operator's opcode index as "NAME (index)".

    The names come from `data["operator_codes"]`. A builtin operator uses its
    enum name; an entry whose builtin name is "CUSTOM" uses its custom_code
    string instead.
    """

    def __init__(self, data):
        self.code_to_name = {}
        for position, opcode in enumerate(data["operator_codes"]):
            label = BuiltinCodeToName(opcode["builtin_code"])
            if label == "CUSTOM":
                label = NameListToString(opcode["custom_code"])
            self.code_to_name[position] = label

    def __call__(self, x):
        # Indices with no operator_codes entry are shown as "<UNKNOWN>".
        label = self.code_to_name.get(x, "<UNKNOWN>")
        return "%s (%d)" % (label, x)
class DataSizeMapper:
    """Callable that describes a buffer's data by its length in bytes."""

    def __call__(self, x):
        # A missing buffer payload is shown as a placeholder.
        if x is None:
            return "--"
        return "%d bytes" % len(x)
class TensorMapper:
    """Renders a list of tensor indices as a hoverable tooltip.

    The visible text is the index list. Hovering shows one line per tensor
    with its index, name, type, shape and shape signature.
    """

    def __init__(self, subgraph_data):
        # Subgraph dict; only its "tensors" list is read.
        self.data = subgraph_data

    def __call__(self, x):
        if x is None:
            return ""
        rows = []
        for i in x:
            if i < 0:
                # TFLite marks an omitted optional operand with -1. Using it
                # as a Python index would silently pick the last tensor.
                rows.append("%d (omitted optional operand)<br>" % i)
                continue
            tensor = self.data["tensors"][i]
            # Names come from the (possibly untrusted) model file, so escape
            # them before embedding them in HTML.
            name = escape(NameListToString(tensor["name"]))
            # str(): an unrecognized type code maps to None, which must not
            # crash the line formatting.
            type_name = str(TensorTypeToName(tensor["type"]))
            shape = repr(tensor["shape"]) if "shape" in tensor else "[]"
            signature = (repr(tensor["shape_signature"])
                         if "shape_signature" in tensor else "[]")
            rows.append("%s %s %s %s %s<br>" % (i, name, type_name, shape, signature))
        return ("<span class='tooltip'><span class='tooltipcontent'>"
                + "".join(rows)
                + "</span>"
                + repr(x)
                + "</span>")
def GenerateGraph(subgraph_idx, g, opcode_mapper):
    """Produces the HTML <script> that draws one subgraph's DAG with d3.

    Operators and tensors both become nodes. Edges run tensor -> op for op
    inputs and op -> tensor for op outputs.

    Args:
      subgraph_idx: index of the subgraph; selects the target
        `<svg id='subgraph<idx>'>` element.
      g: subgraph dict with "operators" and "tensors" lists (either may be
        None).
      opcode_mapper: callable mapping an opcode index to a display name
        (e.g. an OpCodeMapper).

    Returns:
      `_D3_HTML_TEMPLATE` filled with the graph JSON and the subgraph index.
    """
    def TensorName(idx):
        return "t%d" % idx

    def OpName(idx):
        return "o%d" % idx

    edges = []
    nodes = []
    # Initial (row, column) placement of each tensor: taken from the first op
    # that consumes it, otherwise from the first op that produces it.
    first_consumer = {}
    first_producer = {}
    pixel_mult = 200  # TODO(aselle): multiplier for initial placement
    width_mult = 170  # TODO(aselle): multiplier for initial placement
    for op_index, op in enumerate(g["operators"] or []):
        if op["inputs"] is not None:
            for tensor_input_position, tensor_index in enumerate(op["inputs"]):
                if tensor_index < 0:
                    # Omitted optional operand (TFLite uses -1). There is no
                    # tensor node for it, and a dangling edge would make the
                    # d3 edge callback dereference an undefined node.
                    continue
                first_consumer.setdefault(
                    tensor_index,
                    ((op_index - 0.5 + 1) * pixel_mult,
                     (tensor_input_position + 1) * width_mult))
                edges.append({
                    "source": TensorName(tensor_index),
                    "target": OpName(op_index)
                })
        if op["outputs"] is not None:
            for tensor_output_position, tensor_index in enumerate(op["outputs"]):
                if tensor_index < 0:
                    continue
                first_producer.setdefault(
                    tensor_index,
                    ((op_index + 0.5 + 1) * pixel_mult,
                     (tensor_output_position + 1) * width_mult))
                edges.append({
                    "target": TensorName(tensor_index),
                    "source": OpName(op_index)
                })
        nodes.append({
            "id": OpName(op_index),
            "name": opcode_mapper(op["opcode_index"]),
            "group": 2,
            "x": pixel_mult,
            "y": (op_index + 1) * pixel_mult
        })

    for tensor_index, tensor in enumerate(g["tensors"] or []):
        position = first_consumer.get(
            tensor_index, first_producer.get(tensor_index, (0, 0)))
        # Tensors are dicts here, so read the shape by key. The previous
        # getattr(tensor, "shape", []) never found it and always showed [].
        shape = tensor.get("shape")
        nodes.append({
            "id": TensorName(tensor_index),
            "name": "%r (%d)" % (shape if shape is not None else [], tensor_index),
            "group": 1,
            "x": position[1],
            "y": position[0]
        })

    graph_str = json.dumps({"nodes": nodes, "edges": edges})
    # The JSON is spliced into an inline <script>. Escape &, < and > (valid
    # JSON/JS escapes inside strings) so that a tensor or op name containing
    # "</script>" cannot end the script element early.
    graph_str = (graph_str.replace("&", "\\u0026")
                 .replace("<", "\\u003c")
                 .replace(">", "\\u003e"))
    return _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
def GenerateTableHtml(items, keys_to_print, display_index=True):
    """Given a list of object values and keys to print, make an HTML table.

    Args:
      items: rows to print, as a list of dicts. None is treated as no rows.
      keys_to_print: list of (key, display_fn) pairs. `key` is the column
        header and is looked up in each row; a missing key yields None.
        `display_fn(value)` returns the cell's HTML. When `display_fn` is
        None, the raw value is HTML-escaped and shown as text.
      display_index: add a first column holding each row's index in `items`.

    Returns:
      An html table.
    """
    # A single header row. The old code opened an extra, never-closed <tr>.
    parts = ["<table>\n", "<tr>\n"]
    if display_index:
        parts.append("<th>index</th>")
    for key, _ in keys_to_print:
        parts.append("<th>%s</th>" % escape(str(key)))
    parts.append("</tr>\n")
    for row_index, row in enumerate(items or []):
        parts.append("<tr>\n")
        if display_index:
            parts.append("<td>%d</td>" % row_index)
        for key, mapper in keys_to_print:
            value = row[key] if key in row else None
            # Raw values may come from the model file: escape them. Mapper
            # output is HTML by contract (e.g. TensorMapper tooltips).
            cell = escape(str(value)) if mapper is None else mapper(value)
            # Wrap in a 1-tuple so a tuple-valued cell formats correctly.
            parts.append("<td>%s</td>\n" % (cell,))
        parts.append("</tr>\n")
    parts.append("</table>\n")
    return "".join(parts)
def CamelCaseToSnakeCase(camel_case_input):
    """Returns `camel_case_input` rewritten from CamelCase to snake_case.

    The first pass splits before each capitalized word. The second pass
    splits the remaining lowercase/digit-to-uppercase boundaries. The result
    is lowercased.
    """
    with_word_breaks = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input)
    fully_split = re.sub("([a-z0-9])([A-Z])", r"\1_\2", with_word_breaks)
    return fully_split.lower()
def FlatbufferToDict(fb, preserve_as_numpy):
    """Recursively converts a flatbuffer object tree into plain dicts/lists.

    Big parts of the flatbuffer are deliberately not turned into Python lists.
    This keeps conversion of large graphs to seconds rather than minutes.

    Args:
      fb: a flat buffer structure (i.e. ModelT) or any value nested inside it.
      preserve_as_numpy: if true, numpy arrays below this point stay as numpy
        arrays; otherwise they are converted with `tolist()`. Anything under
        an attribute named "buffers" is always preserved.

    Returns:
      Scalars and strings unchanged; objects as dicts keyed by snake_case
      attribute name; numpy arrays per `preserve_as_numpy`; other sized
      containers as lists; anything else unchanged.
    """
    if isinstance(fb, (int, float, str)):
        return fb
    if hasattr(fb, "__dict__"):
        converted = {}
        for name in dir(fb):
            value = fb.__getattribute__(name)
            # Keep only public data attributes; skip methods and private names.
            if callable(value) or name[0] == "_":
                continue
            keep_numpy = name == "buffers" or preserve_as_numpy
            converted[CamelCaseToSnakeCase(name)] = FlatbufferToDict(value, keep_numpy)
        return converted
    if isinstance(fb, np.ndarray):
        return fb if preserve_as_numpy else fb.tolist()
    if hasattr(fb, "__len__"):
        return [FlatbufferToDict(item, preserve_as_numpy) for item in fb]
    return fb
def CreateDictFromFlatbuffer(buffer_data):
    """Parses a serialized TFLite model into nested Python dicts and lists.

    Args:
      buffer_data: the serialized model buffer (e.g. the bytearray read from
        a .tflite file).

    Returns:
      The result of FlatbufferToDict on the unpacked ModelT. Numpy arrays
      become lists, except under any attribute named "buffers".
    """
    root_table = schema_fb.Model.GetRootAsModel(buffer_data, 0)
    unpacked = schema_fb.ModelT.InitFromObj(root_table)
    return FlatbufferToDict(unpacked, preserve_as_numpy=False)
def _LoadModelData(tflite_input, input_is_filepath):
    """Loads a TFLite model as nested dicts/lists.

    Args:
      tflite_input: when `input_is_filepath` is true, a path to a
        .tflite/.bin/.tf_lite flatbuffer or a .json dump (extension matched
        case-insensitively). Otherwise, the serialized model, which is passed
        straight to CreateDictFromFlatbuffer.
      input_is_filepath: whether `tflite_input` is a path.

    Returns:
      The model as a dict.

    Raises:
      RuntimeError: if the path does not exist or its extension is unsupported.
    """
    if not input_is_filepath:
        return CreateDictFromFlatbuffer(tflite_input)
    if not os.path.exists(tflite_input):
        raise RuntimeError("Invalid filename %r" % tflite_input)
    extension = os.path.splitext(tflite_input)[1].lower()
    if extension in (".tflite", ".bin", ".tf_lite"):
        with open(tflite_input, "rb") as file_handle:
            return CreateDictFromFlatbuffer(bytearray(file_handle.read()))
    if extension == ".json":
        # `with` closes the handle; the old json.load(open(...)) leaked it.
        with open(tflite_input, encoding="utf-8") as file_handle:
            return json.load(file_handle)
    raise RuntimeError("Input file was not .tflite, .bin, .tf_lite or .json")


def _EscapedMapper(mapper):
    """Wraps a plain-text display function so its output is HTML-escaped."""
    return lambda value: escape(str(mapper(value)))


def _SubgraphHtml(subgraph_idx, g, opcode_mapper):
    """Renders one subgraph: I/O table, tensor table, op table and d3 graph."""
    tensor_mapper = TensorMapper(g)
    op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper),
                          ("builtin_options", None),
                          # Custom op names come from the model file; escaping
                          # also keeps "<UNKNOWN>" from being parsed as a tag.
                          ("opcode_index", _EscapedMapper(opcode_mapper))]
    tensor_keys_to_display = [("name", _EscapedMapper(NameListToString)),
                              ("type", TensorTypeToName), ("shape", None),
                              ("shape_signature", None), ("buffer", None),
                              ("quantization", None)]
    parts = ["<div class='subgraph'>",
             "<h2>Subgraph %d</h2>\n" % subgraph_idx,
             "<h3>Inputs/Outputs</h3>\n",
             GenerateTableHtml([{
                 "inputs": g.get("inputs"),
                 "outputs": g.get("outputs")
             }], [("inputs", tensor_mapper), ("outputs", tensor_mapper)],
                               display_index=False),
             "<h3>Tensors</h3>\n",
             GenerateTableHtml(g.get("tensors") or [], tensor_keys_to_display)]
    if g.get("operators"):
        parts.append("<h3>Ops</h3>\n")
        parts.append(GenerateTableHtml(g["operators"], op_keys_to_display))
    # Visual graph.
    parts.append("<svg id='subgraph%d' width='1600' height='900'></svg>\n" % (
        subgraph_idx,))
    parts.append(GenerateGraph(subgraph_idx, g, opcode_mapper))
    parts.append("</div>")
    return "".join(parts)


def create_html(tflite_input, input_is_filepath=True):  # pylint: disable=invalid-name
    """Returns html description with the given tflite model.

    Args:
      tflite_input: TFLite flatbuffer model path or model object.
      input_is_filepath: Tells if tflite_input is a model path or a model object.

    Returns:
      Dump of the given tflite model in HTML format.

    Raises:
      RuntimeError: If the input is not valid.
    """
    data = _LoadModelData(tflite_input, input_is_filepath)
    data["filename"] = tflite_input if input_is_filepath else (
        "Null (used model object)")  # Avoid special case

    # Normalize a missing/None op-code list, so the loop below, OpCodeMapper
    # and the op-code table all receive a list.
    data["operator_codes"] = data.get("operator_codes") or []
    # Reconcile the two op-code fields by keeping the larger value. A field
    # that is absent (e.g. an omitted default in JSON input) counts as 0.
    # NOTE(review): presumably older models only fill deprecated_builtin_code.
    # Confirm against the TFLite schema.
    for op_code in data["operator_codes"]:
        op_code["builtin_code"] = max(op_code.get("builtin_code", 0),
                                      op_code.get("deprecated_builtin_code", 0))

    # Load d3 before the per-subgraph inline scripts that call it. When the
    # result is saved as a standalone page (see main), scripts run in
    # document order.
    parts = ['<script src="https://d3js.org/d3.v4.min.js"></script>\n',
             "<h1>TensorFlow Lite Model</h1>\n",
             "<table>\n"]
    # The description is rendered like tensor names: string fields may arrive
    # from FlatbufferToDict as lists of byte values.
    toplevel_stuff = [("filename", str), ("version", str),
                      ("description", NameListToString)]
    for key, mapping in toplevel_stuff:
        parts.append("<tr><th>%s</th><td>%s</td></tr>\n" % (
            key, escape(mapping(data.get(key)))))
    parts.append("</table>\n")

    # The mapper depends only on the model, not on the subgraph.
    opcode_mapper = OpCodeMapper(data)
    for subgraph_idx, g in enumerate(data.get("subgraphs") or []):
        parts.append(_SubgraphHtml(subgraph_idx, g, opcode_mapper))

    # Buffer contents are summarized by size only.
    parts.append("<h2>Buffers</h2>\n")
    parts.append(GenerateTableHtml(data.get("buffers") or [],
                                   [("data", DataSizeMapper())]))
    parts.append("<h2>Operator Codes</h2>\n")
    parts.append(GenerateTableHtml(
        data["operator_codes"],
        [("builtin_code", BuiltinCodeToName),
         ("custom_code", _EscapedMapper(NameListToString)),
         ("version", None)]))
    return "".join(parts)
def main(argv):
    """Command-line entry point: writes the HTML report for one model to a file.

    Args:
      argv: [program, input .tflite/.json path, output .html path]. With fewer
        entries, a usage line is printed to stderr and nothing is written.

    Returns:
      0 on success, 1 when arguments are missing (suitable for sys.exit).
    """
    program = argv[0] if argv else "visualize.py"
    try:
        tflite_input = argv[1]
        html_output = argv[2]
    except IndexError:
        print("Usage: %s <input tflite> <output html>" % (program,), file=sys.stderr)
        return 1
    html = create_html(tflite_input)
    # The report can contain non-ASCII tensor/op names, so don't depend on the
    # platform's default encoding.
    with open(html_output, "w", encoding="utf-8") as output_file:
        output_file.write(html)
    return 0
def process_file(file):
    """Gradio callback: renders the uploaded model as HTML.

    Args:
      file: the value of the gr.File input. This is either an object with a
        `.name` path or a plain path string (presumably depending on the
        Gradio version; confirm). It may be None, e.g. when the upload is
        cleared.

    Returns:
      The visualization HTML, "" when there is no file, or an HTML-escaped
      "Error: ..." message.
    """
    if file is None:
        return ""
    path = getattr(file, "name", file)
    try:
        return create_html(path)
    except Exception as e:  # pylint: disable=broad-except
        # UI boundary: show any failure in the output pane instead of
        # crashing the callback. The message is rendered as HTML and can echo
        # the user-supplied file name, so escape it.
        return "Error: %s" % escape(str(e))
# Gradio UI. `_CSS` puts the d3 loader into the page head.
with gr.Blocks(head=_CSS) as demo:
    gr.Markdown(
        """
## TensorFlow Lite Model Visualizer
Drag and drop your `.tflite`, `.bin` or `.tf_lite` model files below to analyze them.
""")
    file_input = gr.File(label="Upload TFLite File")
    html_output = gr.HTML(label="Generated HTML", container=True)
    file_input.change(process_file, inputs=file_input, outputs=html_output)

if __name__ == "__main__":
    # Launch only when run as a script, not on import. With arguments, act as
    # the command-line tool (`python visualize.py foo.tflite foo.html`);
    # otherwise start the web UI.
    if len(sys.argv) > 1:
        sys.exit(main(sys.argv))
    demo.launch()