karths commited on
Commit
a754efe
·
verified ·
1 Parent(s): 740d8bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -18
app.py CHANGED
@@ -75,7 +75,7 @@ def model_prediction(model, text, device):
75
  # --- Llama 3.2 3B Model Setup ---
76
  LLAMA_MAX_MAX_NEW_TOKENS = 512 # Max tokens for Explanation
77
  LLAMA_DEFAULT_MAX_NEW_TOKENS = 512 # Max tokens for explantion
78
- LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "700")) # Reduced
79
  llama_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Explicit device
80
  llama_model_id = "meta-llama/Llama-3.2-3B-Instruct"
81
  llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_id)
@@ -125,28 +125,43 @@ def llama_generate(
125
 
126
 
127
  def generate_explanation(issue_text, top_qualities):
128
- """Generates an explanation using Llama 3.2 3B."""
129
  if not top_qualities:
130
- return "No explanation available as no quality tags were predicted."
131
 
132
- # Build the prompt, explicitly mentioning each quality
133
- prompt_parts = [
134
- "Given the following issue description:\n---\n",
135
- issue_text,
136
- "\n---\n",
137
- "Explain why this issue might be classified under the following quality categories. Provide a concise explanation for each category, relating it back to the issue description:\n"
138
- ]
139
- for quality, _ in top_qualities: # Iterate through qualities
140
- prompt_parts.append(f"- {quality}\n")
141
 
142
- prompt = "".join(prompt_parts)
143
 
144
  try:
145
- explanation = llama_generate(prompt) # Get the explanation (not streamed)
146
- return explanation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  except Exception as e:
148
  logging.error(f"Error during Llama generation: {e}")
149
- return "An error occurred while generating the explanation."
150
 
151
 
152
  # @spaces.GPU(duration=60) # Apply the GPU decorator *only* to the main interface
@@ -229,10 +244,10 @@ interface = gr.Interface(
229
  outputs=[
230
  gr.HTML(label="Prediction Output"),
231
  gr.Textbox(label="Predictions", visible=False),
232
- gr.Textbox(label="Explanation", lines=5)
233
  ],
234
  title="QualityTagger",
235
  description="This tool classifies text into different quality domains such as Security, Usability, etc., and provides explanations.",
236
  examples=example_texts
237
  )
238
- interface.launch(share=True)
 
75
  # --- Llama 3.2 3B Model Setup ---
76
  LLAMA_MAX_MAX_NEW_TOKENS = 512 # Max tokens for Explanation
77
  LLAMA_DEFAULT_MAX_NEW_TOKENS = 512 # Max tokens for explantion
78
+ LLAMA_MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048")) # Reduced
79
  llama_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") # Explicit device
80
  llama_model_id = "meta-llama/Llama-3.2-3B-Instruct"
81
  llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_id)
 
125
 
126
 
127
  def generate_explanation(issue_text, top_qualities):
 
128
  if not top_qualities:
129
+ return "<div style='color: red;'>No explanation available as no quality tags were predicted.</div>"
130
 
131
+ prompt = f"""
132
+ Given the following issue description:
133
+ ---
134
+ {issue_text}
135
+ ---
136
+ Explain why this issue might be classified under the following quality categories. Provide a concise explanation for each category, relating it back to the issue description:
137
+ """
138
+ for quality, _ in top_qualities:
139
+ prompt += f"- {quality}\n"
140
 
 
141
 
142
  try:
143
+ explanation = llama_generate(prompt)
144
+ # Format the explanation for better readability
145
+ formatted_explanation = ""
146
+ for quality, _ in top_qualities:
147
+ formatted_explanation += f"<p><b>{quality}:</b></p>" # Bold the quality name
148
+ # Find the explanation for this specific quality. This is a simple
149
+ # approach that works if Llama follows the prompt structure.
150
+ # A more robust approach might use regex or sentence embeddings.
151
+ start = explanation.find(quality)
152
+ if start != -1:
153
+ start += len(quality) + 2 # Move past "Quality:"
154
+ end = explanation.find("\n", start) # Find next newline
155
+ if end == -1:
156
+ end = len(explanation)
157
+ formatted_explanation += f"<p>{explanation[start:end].strip()}</p>" # Add the explanation text
158
+ else:
159
+ formatted_explanation += f"<p>Explanation for {quality} not found.</p>"
160
+
161
+ return f"<div style='overflow-y: scroll; max-height: 400px;'>{formatted_explanation}</div>" #Added scroll
162
  except Exception as e:
163
  logging.error(f"Error during Llama generation: {e}")
164
+ return "<div style='color: red;'>An error occurred while generating the explanation.</div>"
165
 
166
 
167
  # @spaces.GPU(duration=60) # Apply the GPU decorator *only* to the main interface
 
244
  outputs=[
245
  gr.HTML(label="Prediction Output"),
246
  gr.Textbox(label="Predictions", visible=False),
247
+ gr.HTML(label="Explanation") # Change to gr.HTML
248
  ],
249
  title="QualityTagger",
250
  description="This tool classifies text into different quality domains such as Security, Usability, etc., and provides explanations.",
251
  examples=example_texts
252
  )
253
+ interface.launch(share=True)