File size: 4,785 Bytes
60f68c4
ccbdd61
 
a66c2ab
dd82d0a
 
19b2962
ccbdd61
19b2962
 
 
 
 
60f68c4
a66c2ab
60f68c4
ccbdd61
a66c2ab
 
 
ccbdd61
 
 
 
a66c2ab
ccbdd61
 
 
a66c2ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e00b058
ccbdd61
cd3657e
e00b058
a66c2ab
 
 
cd3657e
a66c2ab
 
 
cd3657e
 
 
a66c2ab
ccbdd61
 
 
 
60f68c4
ccbdd61
 
 
 
 
a66c2ab
 
 
60f68c4
a66c2ab
 
 
 
 
ccbdd61
60f68c4
ccbdd61
60f68c4
ccbdd61
dd82d0a
ccbdd61
 
 
 
657dd2f
a66c2ab
ccbdd61
 
 
 
 
 
dd82d0a
 
ccbdd61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
657dd2f
ccbdd61
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from typing import List
from typing import Dict
import plotly.io as pio
import plotly.express as px
import pandas as pd
from utils import TEMP_DIR
import os
import ast
from dotenv import load_dotenv

# Load settings from a .env file (if one is present) into the process
# environment before any of them are read below.
load_dotenv()

# Base URL prefixed onto the chart iframe and CSV download links built by the
# functions in this module; None when ROOT_URL is not set in the environment.
root_url = os.getenv("ROOT_URL")

_SUPPORTED_GRAPH_TYPES = ("bar", "scatter", "line", "pie")


def _parse_trace_options(raw):
    """Parse one LLM-supplied trace-options string with ``ast.literal_eval``.

    Newlines are removed first. If the text is not already enclosed in a
    matching pair of outer braces it is wrapped in them, so a bare
    ``"key": value`` listing (including one whose last value is itself a
    ``{...}`` literal) still parses as a dict.
    """
    text = raw.replace("\n", "").strip()
    if not (text.startswith("{") and text.endswith("}")):
        text = "{" + text + "}"
    return ast.literal_eval(text)


def _build_initial_graph(df, x_column, y_column, graph_type, category):
    """Create the Plotly Express figure for an already-validated graph_type."""
    builders = {
        "bar": px.bar,
        "scatter": px.scatter,
        "line": px.line,
        "pie": px.pie,
    }
    if graph_type == "pie":
        # px.pie has no x/y parameters; sector labels and sizes are passed
        # as names/values instead.
        plot_kwargs = {"names": x_column, "values": y_column}
    else:
        plot_kwargs = {"x": x_column, "y": y_column}
    # Only colour by category when it names a real column of the query result.
    if category in df.columns:
        plot_kwargs["color"] = category
    if graph_type == "bar":
        plot_kwargs["barmode"] = "group"
    return builders[graph_type](df, **plot_kwargs)


def chart_generation_func(data: List[str], x_column: str, y_column: str, graph_type: str, session_hash: str, layout: Dict[str,str]={}, category: str=""):
    """Render a Plotly chart from the session's query.csv and return an iframe.

    The chart is built from ``TEMP_DIR/<session_hash>/query.csv`` with Plotly
    Express, customised with the supplied trace options and layout, and
    written to ``chart.html`` in the same directory.

    Args:
        data: Trace options, either a single item or a list of items. Each
            item is a dict or a string holding a dict literal (outer braces
            optional). When several items are given, only the last one is
            used. Every key other than ``"x"`` and ``"y"`` is copied onto
            every trace of the figure.
        x_column: Column used for the x axis (or the sector names of a pie).
        y_column: Column used for the y axis (or the sector values of a pie).
        graph_type: One of ``"bar"``, ``"scatter"``, ``"line"`` or ``"pie"``.
        session_hash: Identifies the session directory under ``TEMP_DIR`` and
            is embedded in the returned chart URL.
        layout: Plotly layout as a dict, a string holding a dict literal, or
            a list whose first element is one of those. It replaces the
            layout generated by Plotly Express wholesale.
        category: Optional column name used to colour the traces; ignored
            when it is not a column of the CSV.

    Returns:
        ``{"reply": <html>}`` where the HTML is an iframe pointing at the
        generated chart. If anything fails, ``reply`` instead carries a
        plain-text error description; no exception propagates to the caller.
    """
    print("CHART GENERATION")
    print(data)
    print(x_column)
    print(y_column)
    print(category)
    print(layout)
    try:
        # Fail fast, before any file I/O, with a message that names the
        # problem instead of an unbound-variable error further down.
        if graph_type not in _SUPPORTED_GRAPH_TYPES:
            raise ValueError(
                f"Unsupported graph_type {graph_type!r}; "
                f"expected one of {', '.join(_SUPPORTED_GRAPH_TYPES)}"
            )

        dir_path = TEMP_DIR / str(session_hash)
        chart_path = f'{dir_path}/chart.html'
        csv_query_path = f'{dir_path}/query.csv'

        df = pd.read_csv(csv_query_path)
        initial_graph = _build_initial_graph(df, x_column, y_column, graph_type, category)

        # Processing data to account for variation from LLM: accept a single
        # item or a list, and dicts or dict-literal strings.
        data_list = data if isinstance(data, list) else [data]

        # Start from an empty mapping so an empty list means "no overrides".
        data_dict = {}
        for data_obj in data_list:
            if isinstance(data_obj, str):
                data_dict = _parse_trace_options(data_obj)
            else:
                data_dict = data_obj

        if isinstance(layout, list):
            # An empty list is treated as "no layout" rather than an error.
            layout_obj = layout[0] if layout else {}
        else:
            layout_obj = layout

        if isinstance(layout_obj, str):
            layout_dict = ast.literal_eval(layout_obj)
        else:
            layout_dict = layout_obj

        fig = initial_graph.to_dict()

        fig["layout"] = layout_dict

        # The x/y values come from the CSV, so any x/y supplied in the trace
        # options are ignored; everything else is applied to every trace.
        for key, value in data_dict.items():
            if key in ("x", "y"):
                continue
            for trace in fig["data"]:
                trace[key] = value

        pio.write_html(fig, chart_path, full_html=False)

        chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'

        iframe = '<div style=overflow:auto;><iframe\n    scrolling="yes"\n    width="1000px"\n    height="500px"\n    src="' + chart_url + '"\n    frameborder="0"\n    allowfullscreen\n></iframe>\n</div>'

        return {"reply": iframe}

    except Exception as e:
        # The error text is returned as the reply rather than raised.
        print("CHART ERROR")
        print(e)
        reply = f"""There was an error generating the Plotly Chart from {x_column}, {y_column}, {graph_type}, and {layout}

              The error is {e},

              You should probably try again.

              """
        return {"reply": reply}

def table_generation_func(data: List[dict], session_hash):
    """Turn LLM-supplied table data into an HTML table plus a CSV download.

    Args:
        data: The table contents: a dict, a string holding a dict literal, or
            a list whose first element is one of those (only the first
            element is used). The dict is handed to
            ``pandas.DataFrame.from_dict``.
        session_hash: Identifies the session directory under ``TEMP_DIR``
            where ``data.csv`` is written, and is embedded in the download
            link.

    Returns:
        ``{"reply": <html>}`` where the HTML is the rendered DataFrame
        followed by a link to the CSV file. If anything fails, ``reply``
        instead carries a plain-text error description; no exception
        propagates to the caller.
    """
    print("TABLE GENERATION")
    print(data)
    try:
        session_dir = TEMP_DIR / str(session_hash)

        # Processing data to account for variation from LLM: unwrap a list
        # to its first element, then evaluate string payloads as literals.
        payload = data[0] if isinstance(data, list) else data
        table_data = ast.literal_eval(payload) if isinstance(payload, str) else payload

        table = pd.DataFrame.from_dict(table_data)
        print(table)
        table.to_csv(f'{session_dir}/data.csv')

        csv_link = f'{root_url}/gradio_api/file/temp/{session_hash}/data.csv'
        download_note = f'<p>Download as a <a href="{csv_link}">CSV file</a></p>'
        html_table = table.to_html() + download_note
        print(html_table)

        return {"reply": html_table}

    except Exception as e:
        # The error text is returned as the reply rather than raised.
        print("TABLE ERROR")
        print(e)
        reply = f"""There was an error generating the Pandas DataFrame table from {data}

              The error is {e},

              You should probably try again.

              """
        return {"reply": reply}