nuojohnchen commited on
Commit
001b7f2
·
verified ·
1 Parent(s): de7a426

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -14
app.py CHANGED
@@ -2,9 +2,11 @@ import gradio as gr
2
  import os
3
  import spaces
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
- import PyPDF2
6
- from io import BytesIO
7
  import torch
 
 
 
 
8
 
9
  # Set environment variables
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
@@ -65,30 +67,70 @@ AVAILABLE_MODELS = {
65
  current_model = None
66
  current_tokenizer = None
67
  current_model_name = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
 
69
  def extract_text_from_pdf(pdf_bytes):
70
- """Extract text from uploaded PDF file"""
71
  if pdf_bytes is None:
72
  return default_paper_content
73
 
74
  try:
75
- # Ensure pdf_bytes is bytes type
76
- if isinstance(pdf_bytes, str):
77
- return pdf_bytes # If already a string, return directly
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- # Use bytes object directly
80
- pdf_reader = PyPDF2.PdfReader(BytesIO(pdf_bytes))
81
 
82
- # Extract text from all pages
83
- text = ""
84
- for page_num in range(len(pdf_reader.pages)):
85
- page = pdf_reader.pages[page_num]
86
- text += page.extract_text() + "\n\n"
87
 
88
- return text
89
  except Exception as e:
90
  print(f"PDF extraction error: {str(e)}")
91
  return default_paper_content
 
 
 
92
 
93
  def load_model(model_name):
94
  """Load model and tokenizer on demand"""
 
2
  import os
3
  import spaces
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
5
  import torch
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ import fitz # PyMuPDF
9
+ from transformers import NougatProcessor, VisionEncoderDecoderModel
10
 
11
  # Set environment variables
12
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
67
  current_model = None
68
  current_tokenizer = None
69
  current_model_name = None
70
+ nougat_model = None
71
+ nougat_processor = None
72
+
73
+ @spaces.GPU(duration=200)
74
+ def load_nougat_model():
75
+ """Load Nougat model for PDF processing"""
76
+ global nougat_model, nougat_processor
77
+
78
+ if nougat_model is None or nougat_processor is None:
79
+ nougat_processor = NougatProcessor.from_pretrained("facebook/nougat-base")
80
+ nougat_model = VisionEncoderDecoderModel.from_pretrained("facebook/nougat-base")
81
+ nougat_model.to("cuda" if torch.cuda.is_available() else "cpu")
82
+
83
+ return nougat_processor, nougat_model
84
 
85
+ @spaces.GPU(duration=200)
86
  def extract_text_from_pdf(pdf_bytes):
87
+ """Extract text from uploaded PDF file using Nougat"""
88
  if pdf_bytes is None:
89
  return default_paper_content
90
 
91
  try:
92
+ # Load Nougat model
93
+ processor, model = load_nougat_model()
94
+
95
+ # Convert PDF to images
96
+ pdf_document = fitz.open(stream=pdf_bytes, filetype="pdf")
97
+ full_text = ""
98
+
99
+ for page_num in range(len(pdf_document)):
100
+ page = pdf_document.load_page(page_num)
101
+ pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) # 2x zoom for better quality
102
+
103
+ # Convert to PIL Image
104
+ img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
105
+
106
+ # Process with Nougat
107
+ pixel_values = processor(img, return_tensors="pt").pixel_values.to(model.device)
108
+
109
+ # Generate text
110
+ outputs = model.generate(
111
+ pixel_values,
112
+ min_length=1,
113
+ max_new_tokens=1024, # Adjust based on expected page content length
114
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
115
+ )
116
+
117
+ # Decode and post-process
118
+ page_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
119
+ page_text = processor.post_process_generation(page_text, fix_markdown=True)
120
 
121
+ full_text += page_text + "\n\n"
 
122
 
123
+ # Clear GPU memory
124
+ del pixel_values, outputs
125
+ torch.cuda.empty_cache()
 
 
126
 
127
+ return full_text
128
  except Exception as e:
129
  print(f"PDF extraction error: {str(e)}")
130
  return default_paper_content
131
+ finally:
132
+ # Clear GPU memory
133
+ torch.cuda.empty_cache()
134
 
135
  def load_model(model_name):
136
  """Load model and tokenizer on demand"""