hiyata committed on
Commit
1b5b7bf
·
verified ·
1 Parent(s): 0c54683

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -89
app.py CHANGED
@@ -655,153 +655,174 @@ def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
655
  'pos_fraction': float(np.mean(gene_shap > 0))
656
  }
657
 
658
- def create_simple_genome_diagram(gene_results, genome_length):
 
 
 
 
659
  from PIL import Image, ImageDraw, ImageFont
660
-
661
  # Validate inputs
662
  if not gene_results or genome_length <= 0:
663
- img = Image.new('RGBA', (800, 100), color=(255, 255, 255, 255))
664
- draw = ImageDraw.Draw(img, 'RGBA')
665
  draw.text((10, 40), "Error: Invalid input data", fill='black')
666
  return img
667
-
668
- # Ensure valid gene coords
669
  for gene in gene_results:
670
  gene['start'] = max(0, int(gene['start']))
671
- gene['end'] = min(genome_length, int(gene['end']))
672
  if gene['start'] >= gene['end']:
673
- print(f"Warning: Invalid coordinates for gene {gene.get('gene_name','?')}: "
674
- f"{gene['start']}-{gene['end']}")
675
-
676
- # Dimensions
677
- width, height = 1500, 600
678
  margin = 50
679
  track_height = 40
680
-
681
- # Create RGBA image
682
- img = Image.new('RGBA', (width, height), (255, 255, 255, 255))
683
- draw = ImageDraw.Draw(img, 'RGBA')
684
-
685
- # Fonts
686
  try:
687
  font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
688
  title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
689
  except:
690
  font = ImageFont.load_default()
691
  title_font = ImageFont.load_default()
692
-
693
- # Draw title text
694
- draw.text((margin, margin // 2), "Genome SHAP Analysis", fill='black', font=title_font)
695
-
696
- # Draw genome line & ticks FIRST (so rectangles are partially see-through)
697
  line_y = height // 2
698
- draw.line([(margin, line_y), (width - margin, line_y)], fill='black', width=2)
699
-
700
- # Scale factor
701
- scale = (width - 2 * margin) / float(genome_length)
702
-
703
- # Ticks
704
  num_ticks = 10
705
- step = 1 if genome_length < num_ticks else (genome_length // num_ticks)
 
 
 
 
 
706
  for i in range(0, genome_length + 1, step):
707
  x_coord = margin + i * scale
708
- draw.line([(int(x_coord), line_y - 5), (int(x_coord), line_y + 5)],
709
- fill='black', width=1)
710
- draw.text((int(x_coord - 20), line_y + 10), f"{i:,}", fill='black', font=font)
711
-
712
- # Sort genes by absolute shap so smaller shap genes get drawn first
713
- # (and partially appear behind bigger shap genes).
 
714
  sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
715
-
716
- # Draw gene boxes with partial alpha
717
  for idx, gene in enumerate(sorted_genes):
 
718
  start_x = margin + int(gene['start'] * scale)
719
  end_x = margin + int(gene['end'] * scale)
720
-
721
- # Compute color
722
  avg_shap = gene['avg_shap']
723
- intensity = min(255, int(abs(avg_shap)*500))
724
- # clamp a bit so it doesn't look white
725
- intensity = max(50, intensity)
 
 
726
 
727
  if avg_shap > 0:
728
- # Red-ish
729
- color = (255, 255 - intensity, 255 - intensity, 180)
730
  else:
731
- # Blue-ish
732
- color = (255 - intensity, 255 - intensity, 255, 180)
733
-
734
- # Partially transparent rectangle
735
  draw.rectangle([
736
- (start_x, line_y - track_height // 2),
737
- (end_x, line_y + track_height // 2)
738
- ], fill=color, outline=(0, 0, 0, 255))
739
-
740
- # Label
741
- label = gene.get('gene_name', '?')
 
 
 
 
742
  label_mask = font.getmask(label)
743
  label_width, label_height = label_mask.size
744
-
745
- # Above or below
746
  if idx % 2 == 0:
747
  text_y = line_y - track_height - 15
748
  else:
749
  text_y = line_y + track_height + 5
750
-
751
- # If there's room, draw horizontally; else rotate
752
  gene_width = end_x - start_x
753
  if gene_width > label_width:
754
  text_x = start_x + (gene_width - label_width) // 2
755
- draw.text((text_x, text_y), label, fill='black', font=font)
756
  elif gene_width > 20:
757
  txt_img = Image.new('RGBA', (label_width, label_height), (255, 255, 255, 0))
758
  txt_draw = ImageDraw.Draw(txt_img)
759
  txt_draw.text((0, 0), label, font=font, fill='black')
760
  rotated_img = txt_img.rotate(90, expand=True)
761
  img.paste(rotated_img, (int(start_x), int(text_y)), rotated_img)
762
-
763
- # Legend
764
  legend_x = margin
765
  legend_y = height - margin
766
- draw.text((legend_x, legend_y - 60), "SHAP Values:", fill='black', font=font)
767
-
768
- box_width, box_height = 20, 20
 
 
769
  spacing = 15
770
- # strong human-like
 
771
  draw.rectangle([
772
- (legend_x, legend_y - 45),
773
- (legend_x + box_width, legend_y - 45 + box_height)
774
- ], fill=(255, 0, 0, 255), outline=(0, 0, 0, 255))
775
- draw.text((legend_x + box_width + spacing, legend_y - 45),
776
  "Strong human-like signal", fill='black', font=font)
777
-
778
- # weak human-like
779
  draw.rectangle([
780
- (legend_x, legend_y - 20),
781
- (legend_x + box_width, legend_y - 20 + box_height)
782
- ], fill=(255, 200, 200, 255), outline=(0, 0, 0, 255))
783
- draw.text((legend_x + box_width + spacing, legend_y - 20),
784
  "Weak human-like signal", fill='black', font=font)
785
-
786
- # weak non-human-like
787
  draw.rectangle([
788
- (legend_x + 250, legend_y - 45),
789
- (legend_x + 250 + box_width, legend_y - 45 + box_height)
790
- ], fill=(200, 200, 255, 255), outline=(0, 0, 0, 255))
791
- draw.text((legend_x + 250 + box_width + spacing, legend_y - 45),
792
  "Weak non-human-like signal", fill='black', font=font)
793
-
794
- # strong non-human-like
795
  draw.rectangle([
796
- (legend_x + 250, legend_y - 20),
797
- (legend_x + 250 + box_width, legend_y - 20 + box_height)
798
- ], fill=(0, 0, 255, 255), outline=(0, 0, 0, 255))
799
- draw.text((legend_x + 250 + box_width + spacing, legend_y - 20),
800
  "Strong non-human-like signal", fill='black', font=font)
801
-
802
  return img
803
 
804
 
 
805
  def analyze_gene_features(sequence_file: str,
806
  features_file: str,
807
  fasta_text: str = "",
 
655
  'pos_fraction': float(np.mean(gene_shap > 0))
656
  }
657
 
658
+ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
659
+ """
660
+ Create a simple genome diagram using PIL, forcing a minimum color intensity
661
+ so that small SHAP values don't appear white.
662
+ """
663
  from PIL import Image, ImageDraw, ImageFont
664
+
665
  # Validate inputs
666
  if not gene_results or genome_length <= 0:
667
+ img = Image.new('RGB', (800, 100), color='white')
668
+ draw = ImageDraw.Draw(img)
669
  draw.text((10, 40), "Error: Invalid input data", fill='black')
670
  return img
671
+
672
+ # Ensure all gene coordinates are valid integers
673
  for gene in gene_results:
674
  gene['start'] = max(0, int(gene['start']))
675
+ gene['end'] = min(genome_length, int(gene['end']))
676
  if gene['start'] >= gene['end']:
677
+ print(f"Warning: Invalid coordinates for gene {gene.get('gene_name','?')}: {gene['start']}-{gene['end']}")
678
+
679
+ # Image dimensions
680
+ width = 1500
681
+ height = 600
682
  margin = 50
683
  track_height = 40
684
+
685
+ # Create image with white background
686
+ img = Image.new('RGB', (width, height), 'white')
687
+ draw = ImageDraw.Draw(img)
688
+
689
+ # Try to load font, fall back to default if unavailable
690
  try:
691
  font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
692
  title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
693
  except:
694
  font = ImageFont.load_default()
695
  title_font = ImageFont.load_default()
696
+
697
+ # Draw title
698
+ draw.text((margin, margin // 2), "Genome SHAP Analysis", fill='black', font=title_font or font)
699
+
700
+ # Draw genome line
701
  line_y = height // 2
702
+ draw.line([(int(margin), int(line_y)), (int(width - margin), int(line_y))], fill='black', width=2)
703
+
704
+ # Calculate scale factor
705
+ scale = float(width - 2 * margin) / float(genome_length)
706
+
707
+ # Determine a reasonable step for scale markers
708
  num_ticks = 10
709
+ if genome_length < num_ticks:
710
+ step = 1
711
+ else:
712
+ step = genome_length // num_ticks
713
+
714
+ # Draw scale markers
715
  for i in range(0, genome_length + 1, step):
716
  x_coord = margin + i * scale
717
+ draw.line([
718
+ (int(x_coord), int(line_y - 5)),
719
+ (int(x_coord), int(line_y + 5))
720
+ ], fill='black', width=1)
721
+ draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
722
+
723
+ # Sort genes by absolute SHAP value for drawing
724
  sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
725
+
726
+ # Draw genes
727
  for idx, gene in enumerate(sorted_genes):
728
+ # Calculate position and ensure integers
729
  start_x = margin + int(gene['start'] * scale)
730
  end_x = margin + int(gene['end'] * scale)
731
+
732
+ # Calculate color based on SHAP value
733
  avg_shap = gene['avg_shap']
734
+
735
+ # Convert shap -> color intensity (0 to 255)
736
+ # Then clamp to a minimum intensity so it never ends up plain white
737
+ intensity = int(abs(avg_shap) * 500)
738
+ intensity = max(50, min(255, intensity)) # clamp between 50 and 255
739
 
740
  if avg_shap > 0:
741
+ # Red-ish for positive
742
+ color = (255, 255 - intensity, 255 - intensity)
743
  else:
744
+ # Blue-ish for negative or zero
745
+ color = (255 - intensity, 255 - intensity, 255)
746
+
747
+ # Draw gene rectangle
748
  draw.rectangle([
749
+ (int(start_x), int(line_y - track_height // 2)),
750
+ (int(end_x), int(line_y + track_height // 2))
751
+ ], fill=color, outline='black')
752
+
753
+ # Prepare gene name label
754
+ label = str(gene.get('gene_name','?'))
755
+
756
+ # If getsize() or textsize() is missing, use getmask(...).size as fallback
757
+ # But if your Pillow version supports font.getsize, you can do:
758
+ # label_width, label_height = font.getsize(label)
759
  label_mask = font.getmask(label)
760
  label_width, label_height = label_mask.size
761
+
762
+ # Alternate label positions above/below line
763
  if idx % 2 == 0:
764
  text_y = line_y - track_height - 15
765
  else:
766
  text_y = line_y + track_height + 5
767
+
768
+ # Decide whether to rotate text based on space
769
  gene_width = end_x - start_x
770
  if gene_width > label_width:
771
  text_x = start_x + (gene_width - label_width) // 2
772
+ draw.text((int(text_x), int(text_y)), label, fill='black', font=font)
773
  elif gene_width > 20:
774
  txt_img = Image.new('RGBA', (label_width, label_height), (255, 255, 255, 0))
775
  txt_draw = ImageDraw.Draw(txt_img)
776
  txt_draw.text((0, 0), label, font=font, fill='black')
777
  rotated_img = txt_img.rotate(90, expand=True)
778
  img.paste(rotated_img, (int(start_x), int(text_y)), rotated_img)
779
+
780
+ # Draw legend
781
  legend_x = margin
782
  legend_y = height - margin
783
+ draw.text((int(legend_x), int(legend_y - 60)), "SHAP Values:", fill='black', font=font)
784
+
785
+ # Draw legend boxes
786
+ box_width = 20
787
+ box_height = 20
788
  spacing = 15
789
+
790
+ # Strong human-like
791
  draw.rectangle([
792
+ (int(legend_x), int(legend_y - 45)),
793
+ (int(legend_x + box_width), int(legend_y - 45 + box_height))
794
+ ], fill=(255, 0, 0), outline='black')
795
+ draw.text((int(legend_x + box_width + spacing), int(legend_y - 45)),
796
  "Strong human-like signal", fill='black', font=font)
797
+
798
+ # Weak human-like
799
  draw.rectangle([
800
+ (int(legend_x), int(legend_y - 20)),
801
+ (int(legend_x + box_width), int(legend_y - 20 + box_height))
802
+ ], fill=(255, 200, 200), outline='black')
803
+ draw.text((int(legend_x + box_width + spacing), int(legend_y - 20)),
804
  "Weak human-like signal", fill='black', font=font)
805
+
806
+ # Weak non-human-like
807
  draw.rectangle([
808
+ (int(legend_x + 250), int(legend_y - 45)),
809
+ (int(legend_x + 250 + box_width), int(legend_y - 45 + box_height))
810
+ ], fill=(200, 200, 255), outline='black')
811
+ draw.text((int(legend_x + 250 + box_width + spacing), int(legend_y - 45)),
812
  "Weak non-human-like signal", fill='black', font=font)
813
+
814
+ # Strong non-human-like
815
  draw.rectangle([
816
+ (int(legend_x + 250), int(legend_y - 20)),
817
+ (int(legend_x + 250 + box_width), int(legend_y - 20 + box_height))
818
+ ], fill=(0, 0, 255), outline='black')
819
+ draw.text((int(legend_x + 250 + box_width + spacing), int(legend_y - 20)),
820
  "Strong non-human-like signal", fill='black', font=font)
821
+
822
  return img
823
 
824
 
825
+
826
  def analyze_gene_features(sequence_file: str,
827
  features_file: str,
828
  fasta_text: str = "",