Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,7 @@ import torch.nn as nn
|
|
7 |
import shap
|
8 |
import matplotlib.pyplot as plt
|
9 |
import io
|
10 |
-
from PIL import Image
|
11 |
|
12 |
class VirusClassifier(nn.Module):
|
13 |
def __init__(self, input_shape: int):
|
@@ -30,18 +30,19 @@ class VirusClassifier(nn.Module):
|
|
30 |
return self.network(x)
|
31 |
|
32 |
def get_feature_importance(self, x):
|
33 |
-
"""Calculate feature importance using gradient-based method"""
|
34 |
x.requires_grad_(True)
|
35 |
output = self.network(x)
|
36 |
-
|
37 |
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
45 |
|
46 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
47 |
"""Convert sequence to k-mer frequency vector"""
|
@@ -111,23 +112,20 @@ def predict(file_obj):
|
|
111 |
|
112 |
try:
|
113 |
sequences = parse_fasta(text)
|
114 |
-
# For simplicity, process only the first sequence for plotting
|
115 |
header, seq = sequences[0]
|
116 |
|
117 |
raw_freq_vector = sequence_to_kmer_vector(seq)
|
118 |
kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
|
119 |
X_tensor = torch.FloatTensor(kmer_vector).to(device)
|
120 |
|
121 |
-
|
122 |
-
|
123 |
-
probs = torch.softmax(output, dim=1)
|
124 |
-
|
125 |
-
importance = model.get_feature_importance(X_tensor)
|
126 |
kmer_importance = importance[0].cpu().numpy()
|
127 |
|
128 |
-
|
129 |
-
|
130 |
|
|
|
131 |
top_k = 10
|
132 |
top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
|
133 |
important_kmers = [
|
@@ -140,9 +138,11 @@ def predict(file_obj):
|
|
140 |
for i in top_indices
|
141 |
]
|
142 |
|
|
|
143 |
top_features = [item['kmer'] for item in important_kmers]
|
144 |
top_values = [item['importance'] for item in important_kmers]
|
145 |
|
|
|
146 |
others_mask = np.ones_like(kmer_importance, dtype=bool)
|
147 |
others_mask[top_indices] = False
|
148 |
others_sum = np.sum(kmer_importance[others_mask])
|
@@ -150,10 +150,12 @@ def predict(file_obj):
|
|
150 |
top_features.append("Others")
|
151 |
top_values.append(others_sum)
|
152 |
|
153 |
-
#
|
|
|
|
|
154 |
explanation = shap.Explanation(
|
155 |
values=np.array(top_values),
|
156 |
-
base_values=0.
|
157 |
data=np.array([
|
158 |
raw_freq_vector[kmer_dict[feat]] if feat != "Others"
|
159 |
else np.sum(raw_freq_vector[others_mask])
|
@@ -161,18 +163,33 @@ def predict(file_obj):
|
|
161 |
]),
|
162 |
feature_names=top_features
|
163 |
)
|
164 |
-
explanation.expected_value = 0.
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
|
|
168 |
buf = io.BytesIO()
|
169 |
-
|
170 |
buf.seek(0)
|
171 |
plot_image = Image.open(buf)
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
pred_class = 1 if probs[0][1] > probs[0][0] else 0
|
174 |
pred_label = 'human' if pred_class == 1 else 'non-human'
|
175 |
|
|
|
176 |
results_text += f"""Sequence: {header}
|
177 |
Prediction: {pred_label}
|
178 |
Confidence: {float(max(probs[0])):0.4f}
|
@@ -181,8 +198,9 @@ Non-human probability: {float(probs[0][0]):0.4f}
|
|
181 |
Most influential k-mers (ranked by importance):"""
|
182 |
|
183 |
for kmer in important_kmers:
|
|
|
184 |
results_text += f"\n {kmer['kmer']}: "
|
185 |
-
results_text += f"impact={kmer['importance']:.4f}, "
|
186 |
results_text += f"occurrence={kmer['frequency']*100:.2f}% of sequence "
|
187 |
if kmer['scaled'] > 0:
|
188 |
results_text += f"(appears {abs(kmer['scaled']):.2f}σ more than average)"
|
@@ -203,4 +221,3 @@ iface = gr.Interface(
|
|
203 |
|
204 |
if __name__ == "__main__":
|
205 |
iface.launch(share=True)
|
206 |
-
|
|
|
7 |
import shap
|
8 |
import matplotlib.pyplot as plt
|
9 |
import io
|
10 |
+
from PIL import Image
|
11 |
|
12 |
class VirusClassifier(nn.Module):
|
13 |
def __init__(self, input_shape: int):
|
|
|
30 |
return self.network(x)
|
31 |
|
32 |
def get_feature_importance(self, x):
|
33 |
+
"""Calculate feature importance using gradient-based method for the human class (index 1)"""
|
34 |
x.requires_grad_(True)
|
35 |
output = self.network(x)
|
36 |
+
probs = torch.softmax(output, dim=1)
|
37 |
|
38 |
+
# We focus on the human class (index 1) probability
|
39 |
+
human_prob = probs[..., 1]
|
40 |
+
human_prob.backward()
|
41 |
+
|
42 |
+
# The gradient shows how each feature affects the human probability
|
43 |
+
importance = x.grad
|
44 |
+
|
45 |
+
return importance, float(human_prob)
|
46 |
|
47 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
48 |
"""Convert sequence to k-mer frequency vector"""
|
|
|
112 |
|
113 |
try:
|
114 |
sequences = parse_fasta(text)
|
|
|
115 |
header, seq = sequences[0]
|
116 |
|
117 |
raw_freq_vector = sequence_to_kmer_vector(seq)
|
118 |
kmer_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
|
119 |
X_tensor = torch.FloatTensor(kmer_vector).to(device)
|
120 |
|
121 |
+
# Get feature importance and human probability
|
122 |
+
importance, human_prob = model.get_feature_importance(X_tensor)
|
|
|
|
|
|
|
123 |
kmer_importance = importance[0].cpu().numpy()
|
124 |
|
125 |
+
# Scale importance values relative to the prediction
|
126 |
+
kmer_importance = kmer_importance * human_prob
|
127 |
|
128 |
+
# Get top k-mers by absolute importance
|
129 |
top_k = 10
|
130 |
top_indices = np.argsort(np.abs(kmer_importance))[-top_k:][::-1]
|
131 |
important_kmers = [
|
|
|
138 |
for i in top_indices
|
139 |
]
|
140 |
|
141 |
+
# Prepare data for SHAP waterfall plot
|
142 |
top_features = [item['kmer'] for item in important_kmers]
|
143 |
top_values = [item['importance'] for item in important_kmers]
|
144 |
|
145 |
+
# Calculate the impact of remaining features
|
146 |
others_mask = np.ones_like(kmer_importance, dtype=bool)
|
147 |
others_mask[top_indices] = False
|
148 |
others_sum = np.sum(kmer_importance[others_mask])
|
|
|
150 |
top_features.append("Others")
|
151 |
top_values.append(others_sum)
|
152 |
|
153 |
+
# Create SHAP explanation
|
154 |
+
# Set base_value to 0.5 (neutral prediction)
|
155 |
+
# Values represent the push towards human (>0.5) or non-human (<0.5)
|
156 |
explanation = shap.Explanation(
|
157 |
values=np.array(top_values),
|
158 |
+
base_values=0.5, # Start from neutral prediction
|
159 |
data=np.array([
|
160 |
raw_freq_vector[kmer_dict[feat]] if feat != "Others"
|
161 |
else np.sum(raw_freq_vector[others_mask])
|
|
|
163 |
]),
|
164 |
feature_names=top_features
|
165 |
)
|
166 |
+
explanation.expected_value = 0.5
|
167 |
+
|
168 |
+
# Create waterfall plot
|
169 |
+
plt.figure(figsize=(10, 6))
|
170 |
+
fig = shap.plots._waterfall.waterfall_legacy(
|
171 |
+
explanation,
|
172 |
+
show=False,
|
173 |
+
max_display=11 # Show all features including "Others"
|
174 |
+
)
|
175 |
+
plt.title(f"Impact on prediction (>0.5 pushes toward human, <0.5 toward non-human)")
|
176 |
|
177 |
+
# Save plot
|
178 |
buf = io.BytesIO()
|
179 |
+
plt.savefig(buf, format='png', bbox_inches='tight', dpi=300)
|
180 |
buf.seek(0)
|
181 |
plot_image = Image.open(buf)
|
182 |
+
plt.close()
|
183 |
+
|
184 |
+
# Calculate final probabilities
|
185 |
+
with torch.no_grad():
|
186 |
+
output = model(X_tensor)
|
187 |
+
probs = torch.softmax(output, dim=1)
|
188 |
|
189 |
pred_class = 1 if probs[0][1] > probs[0][0] else 0
|
190 |
pred_label = 'human' if pred_class == 1 else 'non-human'
|
191 |
|
192 |
+
# Generate results text
|
193 |
results_text += f"""Sequence: {header}
|
194 |
Prediction: {pred_label}
|
195 |
Confidence: {float(max(probs[0])):0.4f}
|
|
|
198 |
Most influential k-mers (ranked by importance):"""
|
199 |
|
200 |
for kmer in important_kmers:
|
201 |
+
direction = "human" if kmer['importance'] > 0 else "non-human"
|
202 |
results_text += f"\n {kmer['kmer']}: "
|
203 |
+
results_text += f"pushes toward {direction} (impact={abs(kmer['importance']):.4f}), "
|
204 |
results_text += f"occurrence={kmer['frequency']*100:.2f}% of sequence "
|
205 |
if kmer['scaled'] > 0:
|
206 |
results_text += f"(appears {abs(kmer['scaled']):.2f}σ more than average)"
|
|
|
221 |
|
222 |
if __name__ == "__main__":
|
223 |
iface.launch(share=True)
|
|