|
from flask import Flask, request, jsonify, send_file |
|
from flask_cors import CORS |
|
import os |
|
from langchain_folder.main import ReturnKeywordsfromPrompt |
|
from dotenv import load_dotenv |
|
import pandas as pd |
|
import matplotlib.pyplot as plt |
|
import seaborn as sns |
|
|
|
# Load variables from a .env file (python-dotenv) before any of them are read.
load_dotenv()

# Dataset location, taken from the `file_path` environment variable.
# This is None when the variable is not set.
CSV_FILE_PATH = os.environ.get('file_path')
# Directory that generated graph images are written to (relative path).
GRAPH_DIR = './static/graphs'

app = Flask(__name__)
# Enable flask-cors on the whole application with its default settings.
CORS(app)
|
|
|
|
|
@app.route('/api/search', methods=['POST'])
def search():
    """Accept a search query and run keyword extraction on it.

    The request body must be a JSON object; its optional ``query`` entry
    defaults to ``''``. The query is passed to ``ReturnKeywordsfromPrompt``,
    whose return value is not included in the response.

    Returns:
        ``{"status": "success", "queryReceived": <query>}`` on success, or a
        400 JSON error when the body is missing, malformed, or not a JSON
        object.
    """
    # silent=True returns None instead of raising on a missing or invalid
    # JSON body, so every malformed request gets the same JSON error below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({
            "status": "error",
            "error": "Request body must be a JSON object",
        }), 400

    query = data.get('query', '')
    # The result is not used by this endpoint. The call is kept because it may
    # have side effects other endpoints rely on. NOTE(review): confirm.
    ReturnKeywordsfromPrompt(query)
    return jsonify({"status": "success", "queryReceived": query})
|
|
|
|
|
def _render_graph(path, title, draw):
    """Draw one 4x3-inch figure via *draw*, title it, save it to *path*, close it."""
    fig = plt.figure(figsize=(4, 3))
    try:
        draw()
        plt.title(title)
        plt.tight_layout()
        plt.savefig(path)
    finally:
        # pyplot keeps a reference to every figure until it is closed, so a
        # long-running server would otherwise accumulate one per plot.
        plt.close(fig)
    return path


def generate_graphs_from_csv(csv_path, output_dir):
    """Render summary graphs for a CSV file and return the image paths.

    Columns are split into numeric (``include='number'``) and categorical
    (``include='object'``) by dtype. Depending on what is available, up to
    four PNG files are written into *output_dir*, in this order:

    1. histogram with KDE of the first numeric column;
    2. box plot of the second numeric column per first categorical column;
    3. count plot of the first categorical column;
    4. scatter plot of the first numeric column against the second.

    Args:
        csv_path: Path of the CSV file to read.
        output_dir: Directory for the images; created if it does not exist.

    Returns:
        List of written file paths, each formed as ``f'{output_dir}/<name>'``.
        Empty when the file has neither numeric nor object columns.

    Raises:
        Whatever ``pandas.read_csv`` or the plotting calls raise; figures are
        closed even when plotting fails.
    """
    df = pd.read_csv(csv_path)
    os.makedirs(output_dir, exist_ok=True)

    numeric_cols = df.select_dtypes(include='number').columns.tolist()
    categorical_cols = df.select_dtypes(include='object').columns.tolist()

    # Each entry is (file name, title, zero-argument drawing callable).
    # The names bound below are assigned once, so the lambdas see stable values.
    plots = []

    if numeric_cols:
        first_num = numeric_cols[0]
        plots.append((
            'graph_1_hist.png',
            f'{first_num} Distribution',
            lambda: sns.histplot(df[first_num], kde=True),
        ))

    if len(numeric_cols) >= 2 and categorical_cols:
        box_x, box_y = categorical_cols[0], numeric_cols[1]
        plots.append((
            'graph_2_box.png',
            f'{box_y} by {box_x}',
            lambda: sns.boxplot(data=df, x=box_x, y=box_y),
        ))

    if categorical_cols:
        first_cat = categorical_cols[0]
        plots.append((
            'graph_3_count.png',
            f'{first_cat} Distribution',
            lambda: sns.countplot(data=df, x=first_cat),
        ))

    if len(numeric_cols) >= 2:
        scatter_x, scatter_y = numeric_cols[0], numeric_cols[1]
        plots.append((
            'graph_4_scatter.png',
            f'{scatter_x} vs {scatter_y}',
            lambda: sns.scatterplot(data=df, x=scatter_x, y=scatter_y),
        ))

    # Paths are joined with '/' (not os.path.join) so the returned strings keep
    # the same form on every platform.
    return [
        _render_graph(f'{output_dir}/{file_name}', title, draw)
        for file_name, title, draw in plots
    ]
|
|
|
|
|
@app.route('/api/get_csv', methods=['GET'])
def get_csv():
    """Return the dataset CSV inline as text.

    Returns:
        The file contents with status 200 and CSV headers; a 404 JSON error
        when the path is unset or the file does not exist; a 500 JSON error
        when the file cannot be read or is not valid UTF-8.
    """
    # CSV_FILE_PATH is None when the `file_path` environment variable is
    # unset. Treat that like a missing file rather than passing None to open().
    if not CSV_FILE_PATH:
        return jsonify({"error": "CSV file not found"}), 404

    # Open directly instead of checking existence first: the file could
    # disappear between the check and the open.
    try:
        with open(CSV_FILE_PATH, "r", encoding="utf-8") as f:
            content = f.read()
    except FileNotFoundError:
        return jsonify({"error": "CSV file not found"}), 404
    except (OSError, ValueError) as e:
        # ValueError covers UnicodeDecodeError raised while decoding.
        return jsonify({"error": str(e)}), 500

    return content, 200, {
        "Content-Type": "text/csv",
        "Content-Disposition": "inline; filename=dataset.csv"
    }
|
|
|
|
|
@app.route('/api/download_csv', methods=['GET'])
def download_csv():
    """Send the dataset CSV as a file download.

    Returns:
        The file as an attachment, or a 404 JSON error (same shape as the
        one from ``/api/get_csv``) when the path is unset or is not a file.
    """
    # CSV_FILE_PATH is None when the `file_path` environment variable is unset.
    if not CSV_FILE_PATH or not os.path.isfile(CSV_FILE_PATH):
        return jsonify({"error": "CSV file not found"}), 404

    # Flask resolves a relative path given to send_file against the
    # application root, while open() and pandas use the working directory.
    # Passing an absolute path makes every endpoint read the same file.
    return send_file(os.path.abspath(CSV_FILE_PATH), as_attachment=True)
|
|
|
|
|
@app.route('/api/get_graphs', methods=['GET'])
def get_graphs():
    """Regenerate the dataset graphs and return their URLs.

    Returns:
        ``{"graphs": [...]}`` with one URL path per generated image; a 404
        JSON error when the CSV path is unset or missing; a 500 JSON error
        when reading the CSV or plotting fails.
    """
    # CSV_FILE_PATH is None when the `file_path` environment variable is unset.
    if not CSV_FILE_PATH or not os.path.exists(CSV_FILE_PATH):
        return jsonify({"error": "CSV file not found"}), 404

    try:
        paths = generate_graphs_from_csv(CSV_FILE_PATH, GRAPH_DIR)
    except Exception as e:
        # Endpoint boundary: log the traceback and answer in the same JSON
        # error shape as /api/get_csv instead of an unhandled 500.
        app.logger.exception("Graph generation failed")
        return jsonify({"error": str(e)}), 500

    # Turn the './static/...' file paths into '/static/...' URL paths.
    return jsonify({"graphs": [p.replace("./static", "/static") for p in paths]})
|
|
|
|
|
if __name__ == '__main__':
    # debug=True enables the Werkzeug interactive debugger, which lets anyone
    # who can reach the server execute code. It stays on by default to keep
    # the existing development behaviour, but can now be switched off with
    # FLASK_DEBUG=0 (for example in .env) without editing the source.
    debug_enabled = os.getenv('FLASK_DEBUG', '1').strip().lower() in (
        '1', 'true', 'yes', 'on',
    )
    app.run(debug=debug_enabled)
|
|