File size: 3,138 Bytes
825e978
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
import os
from langchain_folder.main import ReturnKeywordsfromPrompt
from dotenv import load_dotenv
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Must run before the os.getenv() call below so that values supplied via a
# .env file (python-dotenv) are visible in the process environment.
load_dotenv()

# The Flask application object; every @app.route handler in this module
# registers against it.
app = Flask(__name__)
# Apply flask_cors to the app (presumably so a browser front end served from
# another origin can call these endpoints — confirm the intended origins).
CORS(app)

# Path of the CSV dataset served by the endpoints below. os.getenv returns
# None when the 'file_path' variable is not set.
CSV_FILE_PATH = os.getenv('file_path')
# Directory that generated graph images are written to.
GRAPH_DIR = './static/graphs'


@app.route('/api/search', methods=['POST'])
def search():
    """Accept a search query and run keyword extraction on it.

    Expects a JSON object body with an optional ``"query"`` string
    (defaults to ``''``).

    Returns:
        A JSON response ``{"status": "success", "queryReceived": <query>}``,
        or a 400 JSON error when the body is not a JSON object.
    """
    data = request.get_json()
    # A body of `null`, a list, or a scalar parses fine but has no .get();
    # reject it explicitly instead of crashing with AttributeError (HTTP 500).
    if not isinstance(data, dict):
        return jsonify({
            "status": "error",
            "message": "Request body must be a JSON object",
        }), 400

    query = data.get('query', '')
    # The result is not part of the response; the call is kept because it
    # may have side effects.
    # NOTE(review): confirm whether the keywords should be returned.
    ReturnKeywordsfromPrompt(query)
    return jsonify({"status": "success", "queryReceived": query})


def _render_graph(path, title, draw):
    """Draw one 4x3-inch figure, save it to *path*, and release it.

    Args:
        path: Destination image path handed to ``plt.savefig``.
        title: Title placed on the current axes.
        draw: Zero-argument callable that plots onto the current figure.

    Returns:
        ``path``, so callers can append it directly to their result list.
    """
    fig = plt.figure(figsize=(4, 3))
    try:
        draw()
        plt.title(title)
        plt.tight_layout()
        plt.savefig(path)
    finally:
        # pyplot keeps every figure it creates alive until it is closed
        # explicitly; without this, each call leaks figures.
        plt.close(fig)
    return path


def generate_graphs_from_csv(csv_path, output_dir):
    """Render up to four summary charts for a CSV file as PNG images.

    Charts are chosen from the column dtypes pandas infers:

    * histogram of the first numeric column (needs >= 1 numeric column)
    * box plot of the second numeric column grouped by the first
      object-dtype column (needs >= 2 numeric and >= 1 object column)
    * count plot of the first object-dtype column (needs >= 1 object column)
    * scatter plot of the first two numeric columns (needs >= 2 numeric)

    Args:
        csv_path: Path of the CSV file to read with ``pd.read_csv``.
        output_dir: Directory the images are written to; created if missing.

    Returns:
        List of the image paths written, in the order listed above. Empty
        when the file has neither numeric nor object-dtype columns.
    """
    df = pd.read_csv(csv_path)
    os.makedirs(output_dir, exist_ok=True)

    numeric_cols = df.select_dtypes(include='number').columns.tolist()
    categorical_cols = df.select_dtypes(include='object').columns.tolist()
    graph_paths = []

    if numeric_cols:
        graph_paths.append(_render_graph(
            f'{output_dir}/graph_1_hist.png',
            f'{numeric_cols[0]} Distribution',
            lambda: sns.histplot(df[numeric_cols[0]], kde=True),
        ))

    if len(numeric_cols) >= 2 and categorical_cols:
        graph_paths.append(_render_graph(
            f'{output_dir}/graph_2_box.png',
            f'{numeric_cols[1]} by {categorical_cols[0]}',
            lambda: sns.boxplot(
                data=df, x=categorical_cols[0], y=numeric_cols[1]
            ),
        ))

    if categorical_cols:
        graph_paths.append(_render_graph(
            f'{output_dir}/graph_3_count.png',
            f'{categorical_cols[0]} Distribution',
            lambda: sns.countplot(data=df, x=categorical_cols[0]),
        ))

    if len(numeric_cols) >= 2:
        graph_paths.append(_render_graph(
            f'{output_dir}/graph_4_scatter.png',
            f'{numeric_cols[0]} vs {numeric_cols[1]}',
            lambda: sns.scatterplot(
                data=df, x=numeric_cols[0], y=numeric_cols[1]
            ),
        ))

    return graph_paths


@app.route('/api/get_csv', methods=['GET'])
def get_csv():
    """Serve the configured CSV file's text for inline display.

    Returns:
        The file contents with a ``text/csv`` content type and status 200;
        a JSON error with status 404 when the path does not exist; or a
        JSON error with status 500 carrying the exception message when
        anything else goes wrong.
    """
    try:
        if os.path.exists(CSV_FILE_PATH):
            with open(CSV_FILE_PATH, "r", encoding="utf-8") as handle:
                body = handle.read()
            headers = {
                "Content-Type": "text/csv",
                "Content-Disposition": "inline; filename=dataset.csv"
            }
            return body, 200, headers
        return jsonify({"error": "CSV file not found"}), 404
    except Exception as exc:
        # Route boundary: report the failure as JSON rather than letting it
        # propagate.
        return jsonify({"error": str(exc)}), 500


@app.route('/api/download_csv', methods=['GET'])
def download_csv():
    """Send the configured CSV file as a download attachment.

    Returns:
        The file as an attachment, or a JSON error with status 404 when no
        path is configured or the file does not exist (the same error body
        the ``/api/get_csv`` route uses).
    """
    not_found = {"error": "CSV file not found"}
    # CSV_FILE_PATH is None when the 'file_path' variable is unset.
    if not CSV_FILE_PATH:
        return jsonify(not_found), 404
    try:
        return send_file(CSV_FILE_PATH, as_attachment=True)
    except FileNotFoundError:
        # Let send_file resolve the path itself, then translate a missing
        # file into a 404 instead of an unhandled 500.
        return jsonify(not_found), 404


@app.route('/api/get_graphs', methods=['GET'])
def get_graphs():
    """Regenerate the summary graphs and return their URLs.

    Returns:
        JSON ``{"graphs": [...]}`` listing each generated image with its
        leading ``./static`` rewritten to ``/static``, or a JSON error with
        status 404 when no CSV path is configured or the file is missing.
    """
    not_found = {"error": "CSV file not found"}
    # CSV_FILE_PATH is None when the 'file_path' variable is unset.
    if not CSV_FILE_PATH:
        return jsonify(not_found), 404
    try:
        paths = generate_graphs_from_csv(CSV_FILE_PATH, GRAPH_DIR)
    except FileNotFoundError:
        # A missing dataset is a client-visible "not found", not a crash.
        return jsonify(not_found), 404
    return jsonify({"graphs": [p.replace("./static", "/static") for p in paths]})


if __name__ == '__main__':
    # Debug mode used to be hard-coded on. It stays on by default for local
    # development, but FLASK_DEBUG=0 (or false/no/off) now disables it.
    # WARNING: debug mode enables the interactive debugger, which allows
    # arbitrary code execution; never expose it on a reachable host.
    debug_setting = os.getenv('FLASK_DEBUG', '1').strip().lower()
    app.run(debug=debug_setting not in ('', '0', 'false', 'no', 'off'))