Nolan Zandi commited on
Commit
a66c2ab
·
1 Parent(s): 32cb2fb

add regression function

Browse files
app.py CHANGED
@@ -36,7 +36,7 @@ with gr.Blocks(css=css, delete_cache=(3600,3600)) as demo:
36
  title = gr.HTML("<h1 style='text-align:center;'>Virtual Data Analyst</h1>")
37
  description = gr.HTML("""<p style='text-align:center;'>Upload a data file and chat with our virtual data analyst
38
  to get insights on your data set. Currently accepts CSV, TSV, TXT, XLS, XLSX, XML, and JSON files.
39
- Can now generate charts and graphs!
40
  Try a sample file to get started!</p>
41
  <p style='text-align:center;'>This tool is under active development. If you experience bugs with use,
42
  open a discussion in the community tab and I will respond.</p>""")
@@ -63,7 +63,8 @@ with gr.Blocks(css=css, delete_cache=(3600,3600)) as demo:
63
  ["What levels of education have the highest and lowest average balance?"],
64
  ["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
65
  ["Can you generate a bar chart of education vs. average balance?"],
66
- ["Can you generate a table of levels of education versus average balance, percent married, percent with a loan, and percent in default?"]
 
67
  ]
68
  elif "online_retail_data" in filename:
69
  example_questions = [
@@ -71,7 +72,8 @@ with gr.Blocks(css=css, delete_cache=(3600,3600)) as demo:
71
  ["What month had the highest revenue?"],
72
  ["Is revenue higher in the morning or afternoon?"],
73
  ["Can you generate a line graph of revenue per month?"],
74
- ["Can you generate a table of revenue per month?"]
 
75
  ]
76
  else:
77
  try:
 
36
  title = gr.HTML("<h1 style='text-align:center;'>Virtual Data Analyst</h1>")
37
  description = gr.HTML("""<p style='text-align:center;'>Upload a data file and chat with our virtual data analyst
38
  to get insights on your data set. Currently accepts CSV, TSV, TXT, XLS, XLSX, XML, and JSON files.
39
+ Can now generate charts and graphs! Can run linear regressions!
40
  Try a sample file to get started!</p>
41
  <p style='text-align:center;'>This tool is under active development. If you experience bugs with use,
42
  open a discussion in the community tab and I will respond.</p>""")
 
63
  ["What levels of education have the highest and lowest average balance?"],
64
  ["What job is most and least common for a yes response from the individuals, not counting 'unknown'?"],
65
  ["Can you generate a bar chart of education vs. average balance?"],
66
+ ["Can you generate a table of levels of education versus average balance, percent married, percent with a loan, and percent in default?"],
67
+ ["Can we predict the relationship between the number of contacts performed before this campaign and the average balance?"],
68
  ]
69
  elif "online_retail_data" in filename:
70
  example_questions = [
 
72
  ["What month had the highest revenue?"],
73
  ["Is revenue higher in the morning or afternoon?"],
74
  ["Can you generate a line graph of revenue per month?"],
75
+ ["Can you generate a table of revenue per month?"],
76
+ ["Can we predict how time of day affects revenue in this data set?"],
77
  ]
78
  else:
79
  try:
functions/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
  from .sqlite_functions import SQLiteQuery, sqlite_query_func
2
  from .chart_functions import chart_generation_func, table_generation_func
3
  from .chat_functions import example_question_generator, chatbot_with_fc
 
4
 
5
- __all__ = ["SQLiteQuery","sqlite_query_func","chart_generation_func","table_generation_func","example_question_generator","chatbot_with_fc"]
 
1
  from .sqlite_functions import SQLiteQuery, sqlite_query_func
2
  from .chart_functions import chart_generation_func, table_generation_func
3
  from .chat_functions import example_question_generator, chatbot_with_fc
4
+ from .stat_functions import regression_func
5
 
6
+ __all__ = ["SQLiteQuery","sqlite_query_func","chart_generation_func","table_generation_func","regression_func","example_question_generator","chatbot_with_fc"]
functions/chart_functions.py CHANGED
@@ -1,6 +1,7 @@
1
  from typing import List
2
  from typing import Dict
3
  import plotly.io as pio
 
4
  import pandas as pd
5
  from utils import TEMP_DIR
6
  import os
@@ -11,31 +12,59 @@ load_dotenv()
11
 
12
  root_url = os.getenv("ROOT_URL")
13
 
14
- def chart_generation_func(data: List[dict], session_hash: str, layout: Dict[str,str]={}):
15
  print("CHART GENERATION")
16
  print(data)
 
 
 
17
  print(layout)
18
  try:
19
  dir_path = TEMP_DIR / str(session_hash)
20
  chart_path = f'{dir_path}/chart.html'
 
21
 
22
  #Processing data to account for variation from LLM
23
  data_list = []
24
- layout_dict = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  if isinstance(data, list):
27
  data_list = data
28
  else:
29
- data_list.append(data)
30
-
31
- data_dict_list = []
32
- for data_obj in data_list:
33
  if isinstance(data_obj, str):
 
 
 
34
  data_dict = ast.literal_eval(data_obj)
35
  else:
36
  data_dict = data_obj
37
- data_dict_list.append(data_dict)
38
-
39
  if isinstance(layout, list):
40
  layout_obj = layout[0]
41
  else:
@@ -46,9 +75,15 @@ def chart_generation_func(data: List[dict], session_hash: str, layout: Dict[str,
46
  else:
47
  layout_dict = layout_obj
48
 
 
 
 
49
 
50
- fig = dict({"data": data_dict_list,
51
- "layout": layout_dict})
 
 
 
52
  pio.write_html(fig, chart_path, full_html=False)
53
 
54
  chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
@@ -60,7 +95,7 @@ def chart_generation_func(data: List[dict], session_hash: str, layout: Dict[str,
60
  except Exception as e:
61
  print("CHART ERROR")
62
  print(e)
63
- reply = f"""There was an error generating the Plotly Chart from {data} and {layout}
64
  The error is {e},
65
  You should probably try again.
66
  """
 
1
  from typing import List
2
  from typing import Dict
3
  import plotly.io as pio
4
+ import plotly.express as px
5
  import pandas as pd
6
  from utils import TEMP_DIR
7
  import os
 
12
 
13
  root_url = os.getenv("ROOT_URL")
14
 
15
+ def chart_generation_func(data: List[str], x_column: str, y_column: str, graph_type: str, session_hash: str, layout: Dict[str,str]={}, category: str=""):
16
  print("CHART GENERATION")
17
  print(data)
18
+ print(x_column)
19
+ print(y_column)
20
+ print(category)
21
  print(layout)
22
  try:
23
  dir_path = TEMP_DIR / str(session_hash)
24
  chart_path = f'{dir_path}/chart.html'
25
+ csv_query_path = f'{dir_path}/query.csv'
26
 
27
  #Processing data to account for variation from LLM
28
  data_list = []
29
+ layout_dict = {}
30
+
31
+ df = pd.read_csv(csv_query_path)
32
+
33
+ if graph_type == "bar":
34
+ if category in df.columns:
35
+ initial_graph = px.bar(df, x=x_column, y=y_column, color=category, barmode="group")
36
+ else:
37
+ initial_graph = px.bar(df, x=x_column, y=y_column, barmode="group")
38
+ elif graph_type == "scatter":
39
+ if category in df.columns:
40
+ initial_graph = px.scatter(df, x=x_column, y=y_column, color=category)
41
+ else:
42
+ initial_graph = px.scatter(df, x=x_column, y=y_column)
43
+ elif graph_type == "line":
44
+ if category in df.columns:
45
+ initial_graph = px.line(df, x=x_column, y=y_column, color=category)
46
+ else:
47
+ initial_graph = px.line(df, x=x_column, y=y_column)
48
+ elif graph_type == "pie":
49
+ if category in df.columns:
50
+ initial_graph = px.pie(df, x=x_column, y=y_column, color=category)
51
+ else:
52
+ initial_graph = px.pie(df, x=x_column, y=y_column)
53
 
54
  if isinstance(data, list):
55
  data_list = data
56
  else:
57
+ data_list.append(data)
58
+
59
+ for index, data_obj in enumerate(data_list):
 
60
  if isinstance(data_obj, str):
61
+ data_obj = data_obj.replace("\n", "")
62
+ if not data_obj.startswith('{') and not data_obj.endswith('}'):
63
+ data_obj = "{" + data_obj + "}"
64
  data_dict = ast.literal_eval(data_obj)
65
  else:
66
  data_dict = data_obj
67
+
 
68
  if isinstance(layout, list):
69
  layout_obj = layout[0]
70
  else:
 
75
  else:
76
  layout_dict = layout_obj
77
 
78
+ fig = initial_graph.to_dict()
79
+
80
+ fig["layout"] = layout_dict
81
 
82
+ for key, value in data_dict.items():
83
+ if key not in ["x","y"]:
84
+ for data_item in fig["data"]:
85
+ data_item[key] = value
86
+
87
  pio.write_html(fig, chart_path, full_html=False)
88
 
89
  chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
 
95
  except Exception as e:
96
  print("CHART ERROR")
97
  print(e)
98
+ reply = f"""There was an error generating the Plotly Chart from {x_column}, {y_column}, {graph_type}, and {layout}
99
  The error is {e},
100
  You should probably try again.
101
  """
functions/chat_functions.py CHANGED
@@ -36,10 +36,10 @@ def example_question_generator(session_hash):
36
  return example_response["replies"][0].text
37
 
38
  def chatbot_with_fc(message, history, session_hash):
39
- from functions import sqlite_query_func, chart_generation_func, table_generation_func
40
  import tools
41
 
42
- available_functions = {"sql_query_func": sqlite_query_func, "chart_generation_func": chart_generation_func, "table_generation_func":table_generation_func }
43
 
44
  if message_dict[session_hash] != None:
45
  message_dict[session_hash].append(ChatMessage.from_user(message))
@@ -47,8 +47,9 @@ def chatbot_with_fc(message, history, session_hash):
47
  messages = [
48
  ChatMessage.from_system(
49
  """You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'.
50
- You also have access to a chart function that uses plotly dictionaries to generate charts and graphs and returns an iframe that we can display in our chat window.
51
- You also have access to a function, called table_generation_func, that builds table formatted html and generates a link to download as CSV."""
 
52
  )
53
  ]
54
  messages.append(ChatMessage.from_user(message))
 
36
  return example_response["replies"][0].text
37
 
38
  def chatbot_with_fc(message, history, session_hash):
39
+ from functions import sqlite_query_func, chart_generation_func, table_generation_func, regression_func
40
  import tools
41
 
42
+ available_functions = {"sql_query_func": sqlite_query_func, "chart_generation_func": chart_generation_func, "table_generation_func":table_generation_func, "regression_func":regression_func }
43
 
44
  if message_dict[session_hash] != None:
45
  message_dict[session_hash].append(ChatMessage.from_user(message))
 
47
  messages = [
48
  ChatMessage.from_system(
49
  """You are a helpful and knowledgeable agent who has access to an SQLite database which has a table called 'data_source'.
50
+ You also have access to a chart function that can take a query.csv file generated from our sql query and uses plotly dictionaries to generate charts and graphs and returns an iframe that we can display in our chat window.
51
+ You also have access to a function, called table_generation_func, that builds table formatted html and generates a link to download as CSV.
52
+ You also have access to a linear regression function, called regression_func, that can take a query.csv file generated from our sql query and a list of column names for our independent and dependent variables and return a regression data string and a regression chart which is returned as an iframe."""
53
  )
54
  ]
55
  messages.append(ChatMessage.from_user(message))
functions/sqlite_functions.py CHANGED
@@ -15,11 +15,13 @@ class SQLiteQuery:
15
  self.connection = sqlite3.connect(sql_database, check_same_thread=False)
16
 
17
  @component.output_types(results=List[str], queries=List[str])
18
- def run(self, queries: List[str]):
19
  print("ATTEMPTING TO RUN QUERY")
 
20
  results = []
21
  for query in queries:
22
  result = pd.read_sql(query, self.connection)
 
23
  results.append(f"{result}")
24
  self.connection.close()
25
  return {"results": results, "queries": queries}
@@ -30,8 +32,12 @@ def sqlite_query_func(queries: List[str], session_hash):
30
  dir_path = TEMP_DIR / str(session_hash)
31
  sql_query = SQLiteQuery(f'{dir_path}/data_source.db')
32
  try:
33
- result = sql_query.run(queries)
34
- return {"reply": result["results"][0]}
 
 
 
 
35
 
36
  except Exception as e:
37
  reply = f"""There was an error running the SQL Query = {queries}
 
15
  self.connection = sqlite3.connect(sql_database, check_same_thread=False)
16
 
17
  @component.output_types(results=List[str], queries=List[str])
18
+ def run(self, queries: List[str], session_hash):
19
  print("ATTEMPTING TO RUN QUERY")
20
+ dir_path = TEMP_DIR / str(session_hash)
21
  results = []
22
  for query in queries:
23
  result = pd.read_sql(query, self.connection)
24
+ result.to_csv(f'{dir_path}/query.csv', index=False)
25
  results.append(f"{result}")
26
  self.connection.close()
27
  return {"results": results, "queries": queries}
 
32
  dir_path = TEMP_DIR / str(session_hash)
33
  sql_query = SQLiteQuery(f'{dir_path}/data_source.db')
34
  try:
35
+ result = sql_query.run(queries, session_hash)
36
+ if len(result["results"][0]) > 300:
37
+ print("QUERY TOO LARGE")
38
+ return {"reply": "query result too large to be processed by llm, the query results are in our query.csv file"}
39
+ else:
40
+ return {"reply": result["results"][0]}
41
 
42
  except Exception as e:
43
  reply = f"""There was an error running the SQL Query = {queries}
functions/stat_functions.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import pandas as pd
3
+ from typing import List
4
+ from utils import TEMP_DIR
5
+ import plotly.express as px
6
+ import plotly.io as pio
7
+ import os
8
+ from dotenv import load_dotenv
9
+
10
+ load_dotenv()
11
+
12
+ root_url = os.getenv("ROOT_URL")
13
+
14
+ def basic_stats_function(data_list: List[str]):
15
+ return
16
+
17
+ def regression_func(independent_variables: List[str], dependent_variable: str, session_hash, category: str='', regression_type: str="ols"):
18
+ print("LINEAR REGRESSION CALCULATION")
19
+ print(independent_variables)
20
+ print(dependent_variable)
21
+ try:
22
+ dir_path = TEMP_DIR / str(session_hash)
23
+ chart_path = f'{dir_path}/chart.html'
24
+ csv_query_path = f'{dir_path}/query.csv'
25
+
26
+ df = pd.read_csv(csv_query_path)
27
+
28
+ if category in df.columns:
29
+ fig = px.scatter(df, x=independent_variables, y=dependent_variable, color=category, trendline="ols")
30
+ else:
31
+ fig = px.scatter(df, x=independent_variables, y=dependent_variable, trendline="ols")
32
+
33
+ pio.write_html(fig, chart_path, full_html=False)
34
+
35
+ chart_url = f'{root_url}/gradio_api/file/temp/{session_hash}/chart.html'
36
+
37
+ iframe = '<div style=overflow:auto;><iframe\n scrolling="yes"\n width="1000px"\n height="500px"\n src="' + chart_url + '"\n frameborder="0"\n allowfullscreen\n></iframe>\n</div>'
38
+
39
+ results_frame = px.get_trendline_results(fig)
40
+
41
+ print("RESULTS")
42
+ print(results_frame)
43
+ print(results_frame.at[0, 'px_fit_results'])
44
+ results = results_frame.at[0, 'px_fit_results']
45
+ print(results.summary())
46
+
47
+ return {"reply": '{"regression_summary": %s, "regression_chart": %s' % (str(results.summary()), str(iframe))}
48
+
49
+ except Exception as e:
50
+ print("LINEAR REGRESSION ERROR")
51
+ print(e)
52
+ reply = f"""There was an error generating the linear regression calculation from {independent_variables} and {dependent_variable}
53
+ The error is {e},
54
+ You should probably try again.
55
+ """
56
+ return {"reply": reply}
requirements.txt CHANGED
@@ -4,3 +4,4 @@ gradio
4
  pandas
5
  plotly
6
  openpyxl
 
 
4
  pandas
5
  plotly
6
  openpyxl
7
+ statsmodels
tools.py CHANGED
@@ -17,7 +17,9 @@ def tools_call(session_hash):
17
  "type": "function",
18
  "function": {
19
  "name": "sql_query_func",
20
- "description": f"This a tool useful to query a SQLite table called 'data_source' with the following Columns: {columns}",
 
 
21
  "parameters": {
22
  "type": "object",
23
  "properties": {
@@ -37,8 +39,9 @@ def tools_call(session_hash):
37
  "type": "function",
38
  "function": {
39
  "name": "chart_generation_func",
40
- "description": f"""This a chart generation tool useful to generate charts and graphs from queried data from our SQL table called 'data_source
41
- with the following Columns: {columns}. Returns an iframe string which will be displayed inline in our chat window. Do not edit the string returned
 
42
  from the chart_generation_func function in any way and display it fully to the user in the chat window. You can add your own text supplementary
43
  to it for context if desired.""",
44
  "parameters": {
@@ -46,7 +49,40 @@ def tools_call(session_hash):
46
  "properties": {
47
  "data": {
48
  "type": "array",
49
- "description": """The list containing a dictionary that contains the 'data' portion of the plotly chart generation. Infer this from the user's message.""",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  "items": {
51
  "type": "string",
52
  }
@@ -59,7 +95,7 @@ def tools_call(session_hash):
59
  }
60
  }
61
  },
62
- "required": ["data"],
63
  },
64
  },
65
  },
@@ -67,11 +103,10 @@ def tools_call(session_hash):
67
  "type": "function",
68
  "function": {
69
  "name": "table_generation_func",
70
- "description": f"""This an table generation tool useful to format data as a table from queried data from our SQL table called
71
- 'data_source with the following Columns: {columns}. Returns an html string generated from the pandas library and pandas.to_html()
72
  function which will be displayed inline in our chat window. There will also be a link to download the CSV included in the HTML string.
73
- This link should open in a new window. Do not edit the string returned by the function in any way when displaying to the user, as the user needs
74
- all of the information returned by the function in it's exact state.""",
75
  "parameters": {
76
  "type": "object",
77
  "properties": {
@@ -90,5 +125,54 @@ def tools_call(session_hash):
90
  "required": ["data"],
91
  },
92
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  }
94
  ]
 
17
  "type": "function",
18
  "function": {
19
  "name": "sql_query_func",
20
+ "description": f"""This a tool useful to query a SQLite table called 'data_source' with the following Columns: {columns}.
21
+ This function also saves the results of the query to csv file called query.csv. This is useful when query results are too large to process
22
+ or need to be used in an another function.""",
23
  "parameters": {
24
  "type": "object",
25
  "properties": {
 
39
  "type": "function",
40
  "function": {
41
  "name": "chart_generation_func",
42
+ "description": f"""This a chart generation tool useful to generate charts and graphs from queried data from our SQL table called 'data_source.
43
+ The data values will come from the columns of our query.csv (the 'x' and 'y' values of each graph) file but the layout section of the plotly dictionary objects will be generated by you.
44
+ Returns an iframe string which will be displayed inline in our chat window. Do not edit the string returned
45
  from the chart_generation_func function in any way and display it fully to the user in the chat window. You can add your own text supplementary
46
  to it for context if desired.""",
47
  "parameters": {
 
49
  "properties": {
50
  "data": {
51
  "type": "array",
52
+ "description": """The list containing a dictionary that contains the 'data' portion of the plotly chart generation and will include the options requested by the user.
53
+ Do not include the 'x' or 'y' portions of the object as this will come from the query.csv file generated by our SQLite query.
54
+ Infer this from the user's message.""",
55
+ "items": {
56
+ "type": "string",
57
+ }
58
+ },
59
+ "x_column": {
60
+ "type": "string",
61
+ "description": f"""The column in our query.csv file that contain the x values of the graph.""",
62
+ "items": {
63
+ "type": "string",
64
+ }
65
+ },
66
+ "y_column": {
67
+ "type": "string",
68
+ "description": f"""The column in our query.csv file that contain the y values of the graph.""",
69
+ "items": {
70
+ "type": "string",
71
+ }
72
+ },
73
+ "category": {
74
+ "type": "string",
75
+ "description": f"""An optional column in our query.csv file that contain a parameter that will define the category for the data.""",
76
+ "items": {
77
+ "type": "string",
78
+ }
79
+ },
80
+ "graph_type": {
81
+ "type": "string",
82
+ "description": f"""The type of plotly graph we wish to generate.
83
+ This graph_type value can be one of ['bar','scatter','line','pie'].
84
+ Do not send any values outside of this list as the function will fail.
85
+ Infer this from the user's message.""",
86
  "items": {
87
  "type": "string",
88
  }
 
95
  }
96
  }
97
  },
98
+ "required": ["graph_type","x_column","y_column","layout"],
99
  },
100
  },
101
  },
 
103
  "type": "function",
104
  "function": {
105
  "name": "table_generation_func",
106
+ "description": f"""This an table generation tool useful to format data as a table from queried data from our SQL table called 'data_source.
107
+ Returns an html string generated from the pandas library and pandas.to_html()
108
  function which will be displayed inline in our chat window. There will also be a link to download the CSV included in the HTML string.
109
+ Do not edit the string returned by the function in any way when displaying to the user.""",
 
110
  "parameters": {
111
  "type": "object",
112
  "properties": {
 
125
  "required": ["data"],
126
  },
127
  },
128
+ },
129
+ {
130
+ "type": "function",
131
+ "function": {
132
+ "name": "regression_func",
133
+ "description": f"""This a tool to calculate regressions on our SQLite table called 'data_source'.
134
+ We can run queries with our 'sql_query_func' function and they will be available to use in this function via the query.csv file that is generated.
135
+ Returns a dictionary of values that includes a regression_summary and a regression chart (which is an iframe displaying the
136
+ linear regression in chart form and should be shown to the user).""",
137
+ "parameters": {
138
+ "type": "object",
139
+ "properties": {
140
+ "independent_variables": {
141
+ "type": "array",
142
+ "description": f"""A list of strings that states the independent variables in our data set which should be column names in our query.csv file that is generated
143
+ in the 'sql_query_func' function. This will allow us to identify the data to use for our independent variables.
144
+ Infer this from the user's message.""",
145
+ "items": {
146
+ "type": "string",
147
+ }
148
+ },
149
+ "dependent_variable": {
150
+ "type": "string",
151
+ "description": f"""A string that states the dependent variables in our data set which should be a column name in our query.csv file that is generated
152
+ in the 'sql_query_func' function. This will allow us to identify the data to use for our dependent variables.
153
+ Infer this from the user's message.""",
154
+ "items": {
155
+ "type": "string",
156
+ }
157
+ },
158
+ "category": {
159
+ "type": "string",
160
+ "description": f"""An optional column in our query.csv file that contain a parameter that will define the category for the data.
161
+ Do not send value if no category is needed or specified. This category must be present in our query.csv file to be valid.""",
162
+ "items": {
163
+ "type": "string",
164
+ }
165
+ },
166
+ "regression_type": {
167
+ "type": "string",
168
+ "description": f"""A parameter that specifies the type of regression being used from the trendline options that plotly offers. Defaults to 'ols'.""",
169
+ "items": {
170
+ "type": "string",
171
+ }
172
+ },
173
+ },
174
+ "required": ["independent_variables","dependent_variable"],
175
+ },
176
+ },
177
  }
178
  ]