virtual-data-analyst / functions /chart_functions.py
Nolan Zandi
add regression function
a66c2ab
raw
history blame
4.79 kB
import ast
import os
from typing import Dict
from typing import List
from typing import Optional

import pandas as pd
import plotly.express as px
import plotly.io as pio
from dotenv import load_dotenv

from utils import TEMP_DIR
# Load variables from a .env file (if one is found) into the process
# environment before any configuration is read below.
load_dotenv()
# Base URL prefixed to the links for generated chart and CSV files.
# os.getenv returns None when ROOT_URL is not set.
root_url = os.getenv("ROOT_URL")
# Chart kinds this module knows how to build with Plotly Express.
_SUPPORTED_GRAPH_TYPES = ("bar", "scatter", "line", "pie")


def _parse_llm_mapping(raw):
    """Return *raw* as a mapping, parsing it first when it is a string.

    Strings are evaluated with ``ast.literal_eval`` (literals only, no code
    execution). Newlines are removed, and a string with no surrounding braces
    at all (e.g. ``"'mode': 'markers'"``) is wrapped in ``{...}`` first.
    Non-string values are returned unchanged.
    """
    if not isinstance(raw, str):
        return raw
    text = raw.replace("\n", "").strip()
    if not text.startswith("{") and not text.endswith("}"):
        text = "{" + text + "}"
    return ast.literal_eval(text)


def _build_initial_graph(df, graph_type, x_column, y_column, category):
    """Create the base Plotly Express figure for *graph_type* from *df*."""
    # Only colour by category when that column really exists in the data.
    color_kwargs = {"color": category} if category in df.columns else {}
    if graph_type == "bar":
        return px.bar(df, x=x_column, y=y_column, barmode="group", **color_kwargs)
    if graph_type == "scatter":
        return px.scatter(df, x=x_column, y=y_column, **color_kwargs)
    if graph_type == "line":
        return px.line(df, x=x_column, y=y_column, **color_kwargs)
    if graph_type == "pie":
        # px.pie has no x/y parameters: sector labels come from ``names`` and
        # sector sizes from ``values``.
        return px.pie(df, names=x_column, values=y_column, **color_kwargs)
    raise ValueError(f"Unsupported graph_type {graph_type!r}")


def chart_generation_func(data: List[str], x_column: str, y_column: str, graph_type: str, session_hash: str, layout: Optional[Dict[str, str]] = None, category: str = ""):
    """Build a Plotly chart from the session's query.csv and return an iframe.

    The chart is created with Plotly Express from ``query.csv`` in the
    session's temp directory, written to ``chart.html`` in the same directory,
    and referenced from the returned HTML snippet.

    Args:
        data: Trace properties to apply to every trace of the chart. May be a
            list or a single item; each item may be a dict or a string
            containing a dict literal. When several items are given only the
            last one is applied. The keys ``x`` and ``y`` are ignored so the
            values read from the CSV are kept.
        x_column: Column of query.csv used for the x axis (pie: sector names).
        y_column: Column of query.csv used for the y axis (pie: sector values).
        graph_type: One of ``bar``, ``scatter``, ``line`` or ``pie``.
        session_hash: Identifies the session's temp directory and file URL.
        layout: Plotly layout, as a dict, a dict-literal string, or a list
            whose first element is one of those. Defaults to an empty layout.
        category: Optional column used to colour the traces; ignored when it
            is not a column of query.csv.

    Returns:
        ``{"reply": html}`` where ``html`` is an iframe pointing at the chart,
        or an error description when anything goes wrong (no exception is
        propagated to the caller).
    """
    # A fresh dict per call instead of a shared mutable default argument.
    if layout is None:
        layout = {}
    print("CHART GENERATION")
    print(data)
    print(x_column)
    print(y_column)
    print(category)
    print(layout)
    try:
        # Fail early with a clear message instead of an unbound-variable error.
        if graph_type not in _SUPPORTED_GRAPH_TYPES:
            raise ValueError(
                f"Unsupported graph_type {graph_type!r}; "
                f"expected one of: {', '.join(_SUPPORTED_GRAPH_TYPES)}"
            )

        dir_path = TEMP_DIR / str(session_hash)
        chart_path = f'{dir_path}/chart.html'
        csv_query_path = f'{dir_path}/query.csv'

        df = pd.read_csv(csv_query_path)
        initial_graph = _build_initial_graph(df, graph_type, x_column, y_column, category)

        # Processing data to account for variation from LLM: it may send one
        # item or a list, and each item may be a dict or a string.
        data_list = data if isinstance(data, list) else [data]
        # Every entry is parsed so malformed input is reported, but only the
        # last entry is applied. An empty list means "no overrides".
        trace_overrides = {}
        for data_obj in data_list:
            trace_overrides = _parse_llm_mapping(data_obj) or {}

        layout_obj = layout
        if isinstance(layout, list):
            # An empty list means "no layout" rather than an IndexError.
            layout_obj = layout[0] if layout else {}
        layout_dict = _parse_llm_mapping(layout_obj) or {}

        fig = initial_graph.to_dict()
        # NOTE(review): this replaces the layout generated by Plotly Express
        # (presumably including barmode and axis titles) - confirm intended.
        fig["layout"] = layout_dict

        for key, value in trace_overrides.items():
            # Keep the x/y values Plotly Express derived from the CSV.
            if key not in ("x", "y"):
                for trace in fig["data"]:
                    trace[key] = value

        pio.write_html(fig, chart_path, full_html=False)
        chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
        iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
        return {"reply": iframe}
    except Exception as e:
        # Errors are reported back in the reply rather than raised.
        print("CHART ERROR")
        print(e)
        reply = f"""There was an error generating the Plotly Chart from {x_column}, {y_column}, {graph_type}, and {layout}
The error is {e},
You should probably try again.
"""
        return {"reply": reply}
def table_generation_func(data: List[dict], session_hash):
    """Render LLM-provided table data as HTML and save it as a CSV file.

    The data is turned into a pandas DataFrame, written to ``data.csv`` in the
    session's temp directory, and returned as an HTML table followed by a
    download link for that CSV.

    Args:
        data: The table content, either a single item or a list of which only
            the first element is used. The item may be a dict accepted by
            ``pandas.DataFrame.from_dict`` or a string containing such a dict
            literal.
        session_hash: Identifies the session's temp directory and file URL.

    Returns:
        ``{"reply": html}`` with the rendered table and download link, or an
        error description when anything goes wrong (no exception is
        propagated to the caller).
    """
    print("TABLE GENERATION")
    print(data)
    try:
        # Processing data to account for variation from LLM: it may send one
        # item or a list, and the item may be a dict or a string.
        if isinstance(data, list):
            if not data:
                # Clearer than the IndexError that data[0] would raise.
                raise ValueError("No table data was provided (received an empty list)")
            data_obj = data[0]
        else:
            data_obj = data

        if isinstance(data_obj, str):
            # Same normalisation as the chart function: drop newlines and add
            # the surrounding braces when the LLM left them out entirely.
            text = data_obj.replace("\n", "").strip()
            if not text.startswith("{") and not text.endswith("}"):
                text = "{" + text + "}"
            # literal_eval only accepts Python literals; no code is executed.
            data_dict = ast.literal_eval(text)
        else:
            data_dict = data_obj

        dir_path = TEMP_DIR / str(session_hash)
        # Make sure the session directory exists before writing into it.
        os.makedirs(dir_path, exist_ok=True)
        csv_path = f'{dir_path}/data.csv'

        df = pd.DataFrame.from_dict(data_dict)
        print(df)
        df.to_csv(csv_path)

        download_path = f'{root_url}/gradio_api/file/temp/{session_hash}/data.csv'
        html_table = df.to_html() + f'<p>Download as a <a href="{download_path}">CSV file</a></p>'
        print(html_table)
        return {"reply": html_table}
    except Exception as e:
        # Errors are reported back in the reply rather than raised.
        print("TABLE ERROR")
        print(e)
        reply = f"""There was an error generating the Pandas DataFrame table from {data}
The error is {e},
You should probably try again.
"""
        return {"reply": reply}