karths committed
Commit 36029e8 · verified · 1 Parent(s): 8e98f20

Update app.py

Files changed (1)
  1. app.py +265 -0
app.py CHANGED
@@ -0,0 +1,265 @@
+ import gradio as gr
+ import os
+ import torch
+ import numpy as np
+ import random
+ from huggingface_hub import login, HfFolder
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ from scipy.special import softmax
+ import logging
+ import spaces
+ import csv
+ from llama_cpp import Llama
+
+ # Increase the CSV field size limit to handle very long issue descriptions
+ csv.field_size_limit(1000000)
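+ # (Python's default csv field limit is 131072 characters, which long issue bodies can exceed.)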
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
+
+ # Seed all RNGs for reproducibility
+ seed = 42
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+     torch.cuda.manual_seed_all(seed)
+
+ # Log in to Hugging Face; the token is read from the "hf_token" environment variable
+ token = os.getenv("hf_token")
+ if token:
+     HfFolder.save_token(token)
+     login(token)
+ else:
+     logging.warning("Environment variable 'hf_token' is not set; skipping Hugging Face login.")
+
+ model_paths = [
+     "karths/binary_classification_train_port",
+     "karths/binary_classification_train_perf",
+     "karths/binary_classification_train_main",
+     "karths/binary_classification_train_secu",
+     "karths/binary_classification_train_reli",
+     "karths/binary_classification_train_usab",
+     "karths/binary_classification_train_comp"
+ ]
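+ # One binary classifier per quality attribute; all share the distilroberta-base
+ # tokenizer loaded below.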
+
+ quality_mapping = {
+     'binary_classification_train_port': 'Portability',
+     'binary_classification_train_main': 'Maintainability',
+     'binary_classification_train_secu': 'Security',
+     'binary_classification_train_reli': 'Reliability',
+     'binary_classification_train_usab': 'Usability',
+     'binary_classification_train_perf': 'Performance',
+     'binary_classification_train_comp': 'Compatibility'
+ }
+
+ # Pre-load the tokenizer and all classification models for quality prediction
+ tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
+ models = {path: AutoModelForSequenceClassification.from_pretrained(path) for path in model_paths}
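+ # All seven classifiers stay in CPU RAM; model_prediction moves one model at a
+ # time onto the GPU and back, so at most one classifier occupies GPU memory.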
+
+ def get_quality_name(model_name):
+     return quality_mapping.get(model_name.split('/')[-1], "Unknown Quality")
+
+ def model_prediction(model, text, device):
+     model.to(device)
+     model.eval()
+     inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+     inputs = {k: v.to(device) for k, v in inputs.items()}
+     with torch.no_grad():
+         outputs = model(**inputs)
+     logits = outputs.logits
+     probs = softmax(logits.cpu().numpy(), axis=1)
+     avg_prob = np.mean(probs[:, 1])  # mean positive-class probability
+     model.to("cpu")  # move the model back so the GPU is free for the next one
+     return avg_prob
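+ # Example (hypothetical input): model_prediction(models[model_paths[0]], "App crashes on startup", "cpu")
+ # returns a float in [0, 1] -- the classifier's confidence that the text exhibits that quality concern.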
+
+ # --- Llama CPP Model Setup with GPU ---
+ LLAMA_MAX_MAX_NEW_TOKENS = 512
+ LLAMA_DEFAULT_MAX_NEW_TOKENS = 512
+ LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "1024"))
+
+ # Check whether a GPU is available
+ gpu_layers = 0  # 0 keeps every layer on the CPU
+ if torch.cuda.is_available():
+     # Offload all layers to the GPU; lower this if GPU memory is tight
+     gpu_layers = -1
+     logging.info("GPU is available. Using GPU acceleration for llama-cpp.")
+ else:
+     logging.info("GPU is not available. Using CPU for llama-cpp.")
+
+ # Initialize the Llama model, with GPU acceleration when available
+ llm = Llama.from_pretrained(
+     repo_id="Qwen/Qwen2.5-1.5B-Instruct-GGUF",
+     filename="*q8_0.gguf",  # q8_0 quantization
+     n_gpu_layers=gpu_layers,
+     verbose=False
+ )
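+ # n_gpu_layers expects an integer: -1 offloads every layer to the GPU and 0 runs
+ # fully on the CPU, hence the fallback of 0 above rather than None.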
+
+ def llama_generate(
+     message: str,
+     max_new_tokens: int = LLAMA_DEFAULT_MAX_NEW_TOKENS,
+     temperature: float = 0.3,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> str:
+     try:
+         output = llm(
+             message,
+             max_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             top_k=top_k,
+             repeat_penalty=repetition_penalty,
+             echo=False,  # don't include the prompt in the output
+         )
+         # Extract the generated text from the completion
+         return output['choices'][0]['text']
+     except Exception as e:
+         logging.error(f"Error during Llama generation: {e}")
+         return f"Error generating text: {str(e)}"
+
+ def generate_explanation(issue_text, top_quality):
+     """Generates an explanation for the *single* top quality above the threshold."""
+     if not top_quality:
+         return "<div style='color: red;'>No explanation available as no quality tags met the threshold.</div>"
+
+     quality_name = top_quality[0][0]  # top_quality holds one (name, probability) tuple
+
+     prompt = f"""
+     Given the following issue description:
+     ---
+     {issue_text}
+     ---
+     Explain why this issue might be classified as a **{quality_name}** issue. Provide a concise explanation, relating it back to the issue description. Keep the explanation short and concise, and don't include anything else.
+     """
+     logging.info(prompt)
+     try:
+         explanation = llama_generate(prompt)
+         # Format for readability, directly including the quality name
+         formatted_explanation = f"<p>{explanation}</p>"
+         return f"<div style='overflow-y: scroll; max-height: 400px;'>{formatted_explanation}</div>"
+     except Exception as e:
+         logging.error(f"Error during Llama generation: {e}")
+         return "<div style='color: red;'>An error occurred while generating the explanation.</div>"
+
+ # @spaces.GPU(duration=60)
+ def main_interface(text):
+     if not text.strip():
+         return "<div style='color: red;'>No text provided. Please enter a valid issue description.</div>", "", ""
+
+     if len(text) < 30:
+         return "<div style='color: red;'>Text is less than 30 characters.</div>", "", ""
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     results = []
+     for model_path, model in models.items():
+         quality_name = get_quality_name(model_path)
+         avg_prob = model_prediction(model, text, device)
+         if avg_prob >= 0.95:  # keep all results above the threshold
+             results.append((quality_name, avg_prob))
+             logging.info(f"Model: {model_path}, Quality: {quality_name}, Average Probability: {avg_prob:.3f}")
+
+     if not results:
+         return "<div style='color: red;'>No recommendation. Prediction probability is below the threshold.</div>", "", ""
+
+     # Sort by probability and keep only the single top result; the early return
+     # above guarantees that results is non-empty here.
+     top_quality = sorted(results, key=lambda x: x[1], reverse=True)[:1]
+     output_html = render_html_output(top_quality)
+     explanation = generate_explanation(text, top_quality)
+
+     return output_html, "", explanation
+
+ def render_html_output(top_qualities):
+     # Simplified to show only the top prediction
+     styles = """
+     <style>
+     .quality-container {
+         font-family: Arial, sans-serif;
+         text-align: center;
+         margin-top: 20px;
+     }
+     .quality-label, .ranking {
+         display: inline-block;
+         padding: 0.5em 1em;
+         font-size: 18px;
+         font-weight: bold;
+         color: white;
+         background-color: #007bff;
+         border-radius: 0.5rem;
+         margin-right: 10px;
+         box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
+     }
+     </style>
+     """
+     if not top_qualities:  # handle the empty case
+         return styles + "<div class='quality-container'>No Top Prediction</div>"
+
+     quality, _ = top_qualities[0]  # exactly one (name, probability) tuple
+     html_content = f"""
+     <div class="quality-container">
+         <span class="ranking">Top Prediction</span>
+         <span class="quality-label">{quality}</span>
+     </div>
+     """
+     return styles + html_content
+
+ example_texts = [
+     ["The algorithm does not accurately distinguish between the positive and negative classes during edge cases.\n\nEnvironment: Production\nReproduction: Run the classifier on the test dataset with known edge cases."],
+     ["The regression tests do not cover scenarios involving concurrent user sessions.\n\nEnvironment: Test automation suite\nReproduction: Update the test scripts to include tests for concurrent sessions."],
+     ["There is frequent miscommunication between the development and QA teams regarding feature specifications.\n\nEnvironment: Inter-team meetings\nReproduction: Audit recent communication logs and meeting notes between the teams."],
+     ["The service-oriented architecture does not effectively isolate failures, leading to cascading failures across services.\n\nEnvironment: Microservices architecture\nReproduction: Simulate a service failure and observe the impact on other services."]
+ ]
+
+ # Improved CSS for better layout and appearance
+ css = """
+ .quality-container {
+     font-family: Arial, sans-serif;
+     text-align: center;
+     margin-top: 20px;
+     padding: 10px;
+     border: 1px solid #ddd;      /* added border */
+     border-radius: 8px;          /* rounded corners */
+     background-color: #f9f9f9;   /* light background */
+ }
+ .quality-label, .ranking {
+     display: inline-block;
+     padding: 0.5em 1em;
+     font-size: 18px;
+     font-weight: bold;
+     color: white;
+     background-color: #007bff;
+     border-radius: 0.5rem;
+     margin-right: 10px;
+     box-shadow: 0 2px 4px rgba(0, 0, 0, 0.2);
+ }
+ #explanation {
+     border: 1px solid #ccc;
+     padding: 10px;
+     margin-top: 10px;
+     border-radius: 4px;
+     background-color: #fff;      /* white background for the explanation */
+     overflow-y: auto;            /* show a scrollbar if needed */
+ }
+ """
+
+ interface = gr.Interface(
+     fn=main_interface,
+     inputs=gr.Textbox(lines=7, label="Issue Description", placeholder="Enter your issue text here"),
+     outputs=[
+         gr.HTML(label="Prediction Output"),
+         gr.Textbox(label="Predictions", visible=False),
+         gr.Markdown(label="Explanation")
+     ],
+     title="QualityTagger",
+     description="This tool classifies text into different quality domains such as Security, Usability, etc., and provides explanations.",
+     examples=example_texts,
+     css=css  # apply the CSS
+ )
+
+ interface.launch(share=True)