infinitymatter committed on
Commit
2f98db0
·
verified ·
1 Parent(s): 2dc76b2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -0
app.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import os
4
+ import json
5
+ import requests
6
+ from ctgan import CTGAN
7
+ from sklearn.preprocessing import LabelEncoder
8
+
9
def generate_schema(prompt):
    """Fetch a dataset schema for ``prompt`` from the Hugging Face Spaces API.

    Args:
        prompt: Free-text description of the desired dataset; sent as the
            single element of the request's ``data`` list.

    Returns:
        The schema dict (containing at least ``columns``, ``types`` and
        ``size``), or ``None`` when the request fails, the response cannot
        be parsed as JSON, or the schema is malformed. Every failure is
        reported to the user through ``st.error``.
    """
    API_URL = "https://infinitymatter-Synthetic_Data_Generator_SRIJAN.hf.space/run/predict"
    REQUEST_TIMEOUT_SECONDS = 60
    required_keys = ('columns', 'types', 'size')

    # Fetch API token securely
    hf_token = st.secrets["hf_token"]
    headers = {"Authorization": f"Bearer {hf_token}"}

    payload = {"data": [prompt]}

    try:
        # Without a timeout a stalled Space would block the app indefinitely.
        response = requests.post(
            API_URL,
            headers=headers,
            json=payload,
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        schema = response.json()

        # NOTE(review): the "/run/predict" route and {"data": [...]} payload
        # look like a Gradio endpoint, which presumably wraps its outputs as
        # {"data": [...]} too -- confirm against the Space. Unwrap only when
        # the top level is not already a schema, so a bare schema still works.
        if (
            isinstance(schema, dict)
            and 'columns' not in schema
            and isinstance(schema.get('data'), list)
            and schema['data']
        ):
            schema = schema['data'][0]
            if isinstance(schema, str):
                schema = json.loads(schema)

        if not isinstance(schema, dict) or any(key not in schema for key in required_keys):
            raise ValueError("Invalid schema format!")

        return schema
    except requests.exceptions.RequestException as e:
        st.error(f"❌ API request failed: {e}")
        return None
    except json.JSONDecodeError:
        # Must precede the ValueError handler: JSONDecodeError subclasses it.
        st.error("❌ Failed to parse JSON response.")
        return None
    except ValueError as e:
        # The caller treats None as "schema generation failed"; letting this
        # propagate would crash the script instead of showing an error.
        st.error(f"❌ {e}")
        return None
34
+
35
+
36
def train_and_generate_synthetic(real_data, schema, output_path):
    """Train a CTGAN model on ``real_data`` and save synthetic rows as CSV.

    Args:
        real_data: DataFrame of real records. It is not modified; label
            encoding is applied to a copy.
        schema: Mapping with parallel ``columns`` / ``types`` sequences and
            a ``size`` entry giving the number of rows to sample.
        output_path: Destination CSV path. Its parent directory is created
            when it does not exist.
    """
    # Columns typed 'string' in the schema are treated as categorical.
    # Columns the schema names but the data lacks are skipped, since neither
    # the encoder nor CTGAN can work with a column that is not there.
    categorical_cols = [
        col
        for col, dtype in zip(schema['columns'], schema['types'])
        if dtype == 'string' and col in real_data.columns
    ]

    # Encode on a copy so the caller's DataFrame keeps its original values.
    training_data = real_data.copy()

    # Store label encoders
    label_encoders = {}
    for col in categorical_cols:
        le = LabelEncoder()
        training_data[col] = le.fit_transform(training_data[col])
        label_encoders[col] = le

    # Train CTGAN
    gan = CTGAN(epochs=300)
    gan.fit(training_data, categorical_cols)

    # Generate synthetic data; the size comes from JSON, so coerce it to int.
    synthetic_data = gan.sample(int(schema['size']))

    # Decode categorical columns
    for col in categorical_cols:
        synthetic_data[col] = label_encoders[col].inverse_transform(synthetic_data[col])

    # Save to CSV, creating the directory output_path actually points into.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    synthetic_data.to_csv(output_path, index=False)
    st.success(f"βœ… Synthetic data saved to {output_path}")
62
+
63
def fetch_data(domain):
    """Load the real dataset stored at ``datasets/<domain>.csv``.

    Returns the loaded DataFrame. When the file does not exist, the problem
    is reported through ``st.error`` and ``None`` is returned. Raises
    ``ValueError`` when the loaded data is not a non-empty DataFrame.
    """
    csv_path = f"datasets/{domain}.csv"

    # Guard clause: a missing file is reported to the user, not raised.
    if not os.path.exists(csv_path):
        st.error(f"❌ Dataset for {domain} not found.")
        return None

    frame = pd.read_csv(csv_path)
    is_usable = isinstance(frame, pd.DataFrame) and not frame.empty
    if not is_usable:
        raise ValueError("❌ Loaded data is invalid!")
    return frame
74
+
75
# Streamlit page script. Streamlit re-executes this top-level code on every
# widget interaction, and st.button() is True only during the run triggered
# by its own click. The schema and real data are therefore kept in
# st.session_state so they are still available on the rerun triggered by the
# "Generate Synthetic Data" button.
st.title("✨ AI-Powered Synthetic Dataset Generator")
st.write("Give a short description of the dataset you need, and AI will generate it for you using real data + GANs!")

# User input
user_prompt = st.text_input("Describe the dataset (e.g., 'Create dataset for hospital patients')", "")
domain = st.selectbox("Select Domain for Real Data", ["healthcare", "finance", "retail", "other"])

if st.button("Generate Schema"):
    # Discard results from a previous request before starting a new one.
    st.session_state["schema"] = None
    st.session_state["data"] = None

    if user_prompt.strip():
        with st.spinner("Generating schema..."):
            new_schema = generate_schema(user_prompt)

        if new_schema is None:
            st.error("❌ Schema generation failed. Please check API response.")
        else:
            st.success("βœ… Schema generated successfully!")
            st.session_state["schema"] = new_schema
            st.session_state["data"] = fetch_data(domain)
    else:
        st.warning("⚠️ Please enter a dataset description before generating the schema.")

# Restore whatever the most recent successful request produced.
schema = st.session_state.get("schema")
data = st.session_state.get("data")

if schema is not None:
    # Shown on every rerun so the schema stays visible after later clicks.
    st.json(schema)

if data is not None and schema is not None:
    output_path = "outputs/synthetic_data.csv"
    if st.button("Generate Synthetic Data"):
        with st.spinner("Training GAN and generating synthetic data..."):
            # Pass a copy so the DataFrame held in session state stays
            # pristine for repeated generations.
            train_and_generate_synthetic(data.copy(), schema, output_path)
        with open(output_path, "rb") as csv_file:
            st.download_button("Download Synthetic Data", csv_file, file_name="synthetic_data.csv", mime="text/csv")