hiyata commited on
Commit
40fe6da
·
verified ·
1 Parent(s): 30b15ea

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -26
app.py CHANGED
@@ -7,7 +7,7 @@ import torch.nn as nn
7
  import shap
8
  import matplotlib.pyplot as plt
9
  import io
10
- from PIL import Image # Import PIL for image handling
11
 
12
  class VirusClassifier(nn.Module):
13
  def __init__(self, input_shape: int):
@@ -30,18 +30,19 @@ class VirusClassifier(nn.Module):
30
  return self.network(x)
31
 
32
  def get_feature_importance(self, x):
33
- """Calculate feature importance using gradient-based method"""
34
  x.requires_grad_(True)
35
  output = self.network(x)
36
- importance = torch.zeros_like(x)
37
 
38
- for i in range(output.shape[1]):
39
- if x.grad is not None:
40
- x.grad.zero_()
41
- output[..., i].sum().backward(retain_graph=True)
42
- importance += torch.abs(x.grad)
43
-
44
- return importance
 
45
 
46
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
47
  """Convert sequence to k-mer frequency vector"""
@@ -111,23 +112,20 @@ def predict(file_obj):
111
 
112
  try:
113
  sequences = parse_fasta(text)
114
- # For simplicity, process only the first sequence for plotting
115
  header, seq = sequences[0]
116
 
117
  raw_freq_vector = sequence_to_kmer_vector(seq)
118
  kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
119
  X_tensor = torch.FloatTensor(kmer_vector).to(device)
120
 
121
- with torch.no_grad():
122
- output = model(X_tensor)
123
- probs = torch.softmax(output, dim=1)
124
-
125
- importance = model.get_feature_importance(X_tensor)
126
  kmer_importance = importance[0].cpu().numpy()
127
 
128
- if np.max(np.abs(kmer_importance)) != 0:
129
- kmer_importance = kmer_importance / np.max(np.abs(kmer_importance)) * 0.002
130
 
 
131
  top_k = 10
132
  top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
133
  important_kmers = [
@@ -140,9 +138,11 @@ def predict(file_obj):
140
  for i in top_indices
141
  ]
142
 
 
143
  top_features = [item['kmer'] for item in important_kmers]
144
  top_values = [item['importance'] for item in important_kmers]
145
 
 
146
  others_mask = np.ones_like(kmer_importance, dtype=bool)
147
  others_mask[top_indices] = False
148
  others_sum = np.sum(kmer_importance[others_mask])
@@ -150,10 +150,12 @@ def predict(file_obj):
150
  top_features.append("Others")
151
  top_values.append(others_sum)
152
 
153
- # Set base_values and expected_value to 0 for the binary classification starting point
 
 
154
  explanation = shap.Explanation(
155
  values=np.array(top_values),
156
- base_values=0.0,
157
  data=np.array([
158
  raw_freq_vector[kmer_dict[feat]] if feat != "Others"
159
  else np.sum(raw_freq_vector[others_mask])
@@ -161,18 +163,33 @@ def predict(file_obj):
161
  ]),
162
  feature_names=top_features
163
  )
164
- explanation.expected_value = 0.0
165
-
166
- fig = shap.plots._waterfall.waterfall_legacy(explanation, show=False)
 
 
 
 
 
 
 
167
 
 
168
  buf = io.BytesIO()
169
- fig.savefig(buf, format='png')
170
  buf.seek(0)
171
  plot_image = Image.open(buf)
 
 
 
 
 
 
172
 
173
  pred_class = 1 if probs[0][1] > probs[0][0] else 0
174
  pred_label = 'human' if pred_class == 1 else 'non-human'
175
 
 
176
  results_text += f"""Sequence: {header}
177
  Prediction: {pred_label}
178
  Confidence: {float(max(probs[0])):0.4f}
@@ -181,8 +198,9 @@ Non-human probability: {float(probs[0][0]):0.4f}
181
  Most influential k-mers (ranked by importance):"""
182
 
183
  for kmer in important_kmers:
 
184
  results_text += f"\n {kmer['kmer']}: "
185
- results_text += f"impact={kmer['importance']:.4f}, "
186
  results_text += f"occurrence={kmer['frequency']*100:.2f}% of sequence "
187
  if kmer['scaled'] > 0:
188
  results_text += f"(appears {abs(kmer['scaled']):.2f}σ more than average)"
@@ -203,4 +221,3 @@ iface = gr.Interface(
203
 
204
  if __name__ == "__main__":
205
  iface.launch(share=True)
206
-
 
7
  import shap
8
  import matplotlib.pyplot as plt
9
  import io
10
+ from PIL import Image
11
 
12
  class VirusClassifier(nn.Module):
13
  def __init__(self, input_shape: int):
 
30
  return self.network(x)
31
 
32
  def get_feature_importance(self, x):
33
+ """Calculate feature importance using gradient-based method for the human class (index 1)"""
34
  x.requires_grad_(True)
35
  output = self.network(x)
36
+ probs = torch.softmax(output, dim=1)
37
 
38
+ # We focus on the human class (index 1) probability
39
+ human_prob = probs[..., 1]
40
+ human_prob.backward()
41
+
42
+ # The gradient shows how each feature affects the human probability
43
+ importance = x.grad
44
+
45
+ return importance, float(human_prob)
46
 
47
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
48
  """Convert sequence to k-mer frequency vector"""
 
112
 
113
  try:
114
  sequences = parse_fasta(text)
 
115
  header, seq = sequences[0]
116
 
117
  raw_freq_vector = sequence_to_kmer_vector(seq)
118
  kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
119
  X_tensor = torch.FloatTensor(kmer_vector).to(device)
120
 
121
+ # Get feature importance and human probability
122
+ importance, human_prob = model.get_feature_importance(X_tensor)
 
 
 
123
  kmer_importance = importance[0].cpu().numpy()
124
 
125
+ # Scale importance values relative to the prediction
126
+ kmer_importance = kmer_importance * human_prob
127
 
128
+ # Get top k-mers by absolute importance
129
  top_k = 10
130
  top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
131
  important_kmers = [
 
138
  for i in top_indices
139
  ]
140
 
141
+ # Prepare data for SHAP waterfall plot
142
  top_features = [item['kmer'] for item in important_kmers]
143
  top_values = [item['importance'] for item in important_kmers]
144
 
145
+ # Calculate the impact of remaining features
146
  others_mask = np.ones_like(kmer_importance, dtype=bool)
147
  others_mask[top_indices] = False
148
  others_sum = np.sum(kmer_importance[others_mask])
 
150
  top_features.append("Others")
151
  top_values.append(others_sum)
152
 
153
+ # Create SHAP explanation
154
+ # Set base_value to 0.5 (neutral prediction)
155
+ # Values represent the push towards human (>0.5) or non-human (<0.5)
156
  explanation = shap.Explanation(
157
  values=np.array(top_values),
158
+ base_values=0.5, # Start from neutral prediction
159
  data=np.array([
160
  raw_freq_vector[kmer_dict[feat]] if feat != "Others"
161
  else np.sum(raw_freq_vector[others_mask])
 
163
  ]),
164
  feature_names=top_features
165
  )
166
+ explanation.expected_value = 0.5
167
+
168
+ # Create waterfall plot
169
+ plt.figure(figsize=(10, 6))
170
+ fig = shap.plots._waterfall.waterfall_legacy(
171
+ explanation,
172
+ show=False,
173
+ max_display=11 # Show all features including "Others"
174
+ )
175
+ plt.title(f"Impact on prediction (>0.5 pushes toward human, <0.5 toward non-human)")
176
 
177
+ # Save plot
178
  buf = io.BytesIO()
179
+ plt.savefig(buf, format='png', bbox_inches='tight', dpi=300)
180
  buf.seek(0)
181
  plot_image = Image.open(buf)
182
+ plt.close()
183
+
184
+ # Calculate final probabilities
185
+ with torch.no_grad():
186
+ output = model(X_tensor)
187
+ probs = torch.softmax(output, dim=1)
188
 
189
  pred_class = 1 if probs[0][1] > probs[0][0] else 0
190
  pred_label = 'human' if pred_class == 1 else 'non-human'
191
 
192
+ # Generate results text
193
  results_text += f"""Sequence: {header}
194
  Prediction: {pred_label}
195
  Confidence: {float(max(probs[0])):0.4f}
 
198
  Most influential k-mers (ranked by importance):"""
199
 
200
  for kmer in important_kmers:
201
+ direction = "human" if kmer['importance'] > 0 else "non-human"
202
  results_text += f"\n {kmer['kmer']}: "
203
+ results_text += f"pushes toward {direction} (impact={abs(kmer['importance']):.4f}), "
204
  results_text += f"occurrence={kmer['frequency']*100:.2f}% of sequence "
205
  if kmer['scaled'] > 0:
206
  results_text += f"(appears {abs(kmer['scaled']):.2f}σ more than average)"
 
221
 
222
  if __name__ == "__main__":
223
  iface.launch(share=True)