File size: 3,138 Bytes
825e978 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import os
from langchain_folder.main import ReturnKeywordsfromPrompt
from dotenv import load_dotenv
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# Load variables from a local .env file into os.environ first, so the
# configuration lookups below can see them.
load_dotenv()

# Flask application; CORS(app) applies flask-cors defaults to every route.
app = Flask(__name__)
CORS(app)

# Dataset location, read from the "file_path" environment variable (None if unset).
CSV_FILE_PATH = os.environ.get('file_path')
# Output directory for generated chart images, relative to the current working
# directory. NOTE(review): Flask serves /static from the app's root_path, so this
# presumably assumes the server is started from the app directory — confirm.
GRAPH_DIR = './static/graphs'
@app.route('/api/search', methods=['POST'])
def search():
    """Receive a search query, run keyword extraction on it, and echo it back.

    Expects a JSON object body such as ``{"query": "..."}``. A missing
    ``query`` key is treated as an empty string.

    Returns:
        200 with ``{"status": "success", "queryReceived": <query>}``, or
        400 with ``{"error": ...}`` when the body is not a JSON object.
    """
    # silent=True makes get_json return None for a missing or malformed JSON
    # body instead of raising. The isinstance check also rejects valid JSON
    # that is not an object (list, string, null), which would otherwise fail
    # on ``.get`` with a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    query = data.get('query', '')
    # NOTE(review): the extracted keywords are not used or returned yet. The
    # call is kept so that any side effects of ReturnKeywordsfromPrompt still
    # happen.
    ReturnKeywordsfromPrompt(query)
    return jsonify({"status": "success", "queryReceived": query})
def _save_plot(path, title, draw):
    """Render one chart with ``draw(ax)``, save it to *path*, and return *path*.

    The figure is closed in ``finally`` even if drawing or saving fails.
    This keeps repeated calls from accumulating open pyplot figures.
    """
    fig, ax = plt.subplots(figsize=(4, 3))
    try:
        draw(ax)
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(path)
    finally:
        plt.close(fig)
    return path


def generate_graphs_from_csv(csv_path, output_dir):
    """Generate up to four summary charts for a CSV file and save them as PNGs.

    Columns are split into numeric (``select_dtypes(include='number')``) and
    object (``select_dtypes(include='object')``) columns. Each chart is
    produced only when the columns it needs exist:

    1. ``graph_1_hist.png``: histogram with KDE of the first numeric column.
    2. ``graph_2_box.png``: box plot of the second numeric column grouped by
       the first object column (needs 2+ numeric and 1+ object column).
    3. ``graph_3_count.png``: value counts of the first object column.
    4. ``graph_4_scatter.png``: first vs. second numeric column.

    Args:
        csv_path: Path of the CSV file, read with ``pandas.read_csv``.
        output_dir: Directory for the images. It is created if missing, and
            existing files with these names are overwritten.

    Returns:
        List of saved file paths in the order above, each formatted as
        ``f"{output_dir}/<name>"``.
    """
    df = pd.read_csv(csv_path)
    os.makedirs(output_dir, exist_ok=True)

    numeric_cols = df.select_dtypes(include='number').columns.tolist()
    categorical_cols = df.select_dtypes(include='object').columns.tolist()
    graph_paths = []

    # Paths are joined with '/' (not os.path.join) so callers always get the
    # same "<output_dir>/<name>" string format.
    if numeric_cols:
        graph_paths.append(_save_plot(
            f'{output_dir}/graph_1_hist.png',
            f'{numeric_cols[0]} Distribution',
            lambda ax: sns.histplot(df[numeric_cols[0]], kde=True, ax=ax),
        ))

    if len(numeric_cols) >= 2 and categorical_cols:
        graph_paths.append(_save_plot(
            f'{output_dir}/graph_2_box.png',
            f'{numeric_cols[1]} by {categorical_cols[0]}',
            lambda ax: sns.boxplot(
                data=df, x=categorical_cols[0], y=numeric_cols[1], ax=ax
            ),
        ))

    if categorical_cols:
        graph_paths.append(_save_plot(
            f'{output_dir}/graph_3_count.png',
            f'{categorical_cols[0]} Distribution',
            lambda ax: sns.countplot(data=df, x=categorical_cols[0], ax=ax),
        ))

    if len(numeric_cols) >= 2:
        graph_paths.append(_save_plot(
            f'{output_dir}/graph_4_scatter.png',
            f'{numeric_cols[0]} vs {numeric_cols[1]}',
            lambda ax: sns.scatterplot(
                data=df, x=numeric_cols[0], y=numeric_cols[1], ax=ax
            ),
        ))

    return graph_paths
@app.route('/api/get_csv', methods=['GET'])
def get_csv():
    """Return the configured CSV file's text for inline display.

    Responses:
        200: the file contents as ``text/csv`` with an inline disposition.
        404: JSON ``{"error": "CSV file not found"}`` if the path does not exist.
        500: JSON ``{"error": <message>}`` if checking or reading the file fails.
    """
    csv_headers = {
        "Content-Type": "text/csv",
        "Content-Disposition": "inline; filename=dataset.csv",
    }
    try:
        if os.path.exists(CSV_FILE_PATH):
            with open(CSV_FILE_PATH, "r", encoding="utf-8") as handle:
                contents = handle.read()
            return contents, 200, csv_headers
        return jsonify({"error": "CSV file not found"}), 404
    except Exception as exc:
        # Top-level boundary: report any failure to the client as JSON.
        return jsonify({"error": str(exc)}), 500
@app.route('/api/download_csv', methods=['GET'])
def download_csv():
    """Send the configured CSV file to the client as an attachment.

    Returns a JSON 404 (the same body ``/api/get_csv`` uses) when no path is
    configured or the path is not a regular file. Without this check,
    ``send_file`` would fail with an unhandled exception.
    """
    if not CSV_FILE_PATH:
        return jsonify({"error": "CSV file not found"}), 404
    # Resolve relative paths against the current working directory. This is
    # how open()/pd.read_csv in the other endpoints resolve them, so the
    # existence check and the file actually sent always agree.
    csv_path = os.path.abspath(CSV_FILE_PATH)
    if not os.path.isfile(csv_path):
        return jsonify({"error": "CSV file not found"}), 404
    return send_file(csv_path, as_attachment=True)
@app.route('/api/get_graphs', methods=['GET'])
def get_graphs():
    """Regenerate the summary charts for the configured CSV and list their URLs.

    Returns ``{"graphs": [...]}``. Each generated path that starts with
    ``./static`` is returned with that prefix rewritten to ``/static``.
    Returns a JSON 404 when the CSV path is unset or is not a regular file,
    instead of failing inside the generator with an unhandled 500.
    """
    if not CSV_FILE_PATH or not os.path.isfile(CSV_FILE_PATH):
        return jsonify({"error": "CSV file not found"}), 404
    paths = generate_graphs_from_csv(CSV_FILE_PATH, GRAPH_DIR)
    # Rewrite only the leading filesystem prefix; a global str.replace would
    # also alter "./static" occurring elsewhere in the path.
    fs_prefix = "./static"
    urls = [
        "/static" + p[len(fs_prefix):] if p.startswith(fs_prefix) else p
        for p in paths
    ]
    return jsonify({"graphs": urls})
if __name__ == '__main__':
    # Debug mode enables Werkzeug's interactive debugger, which can execute
    # arbitrary code. It can be switched off with FLASK_DEBUG=0/false/no.
    # When FLASK_DEBUG is unset, the previous default (debug on) is kept.
    _debug_env = os.getenv('FLASK_DEBUG')
    if _debug_env is None:
        _debug = True
    else:
        _debug = bool(_debug_env) and _debug_env.lower() not in ('0', 'false', 'no')
    app.run(debug=_debug)
|