AnseMin commited on
Commit
98482ce
·
1 Parent(s): b3c2847

New feature: Mistral OCR

Browse files

- changed the requirement txt and added new mistral ocr parser script

.env.example CHANGED
@@ -1,3 +1,5 @@
1
  # API keys for various services
2
- GOOGLE_API_KEY=your_google_api_key_here
3
- OPENAI_API_KEY=your_openai_api_key_here
 
 
 
1
  # API keys for various services
2
+ GOOGLE_API_KEY=your_google_api_key
3
+ OPENAI_API_KEY=your_openai_api_key
4
+ TESSDATAFIX_PREFIX=/path/to/tessdata
5
+ MISTRAL_API_KEY=your_mistral_api_key
requirements.txt CHANGED
@@ -17,6 +17,9 @@ pydantic==2.7.1
17
  # Gemini API client
18
  google-genai>=0.1.0
19
 
 
 
 
20
  # GOT-OCR dependencies - exactly as in original
21
  torch
22
  torchvision
 
17
  # Gemini API client
18
  google-genai>=0.1.0
19
 
20
+ # Mistral AI client
21
+ mistralai>=1.0.0
22
+
23
  # GOT-OCR dependencies - exactly as in original
24
  torch
25
  torchvision
src/parsers/__init__.py CHANGED
@@ -3,6 +3,7 @@
3
  # Import all parsers to ensure they're registered
4
  from src.parsers.gemini_flash_parser import GeminiFlashParser
5
  from src.parsers.got_ocr_parser import GotOcrParser
 
6
 
7
  # Import MarkItDown parser if available - needs to be imported last so it's default
8
  try:
 
3
  # Import all parsers to ensure they're registered
4
  from src.parsers.gemini_flash_parser import GeminiFlashParser
5
  from src.parsers.got_ocr_parser import GotOcrParser
6
+ from src.parsers.mistral_ocr_parser import MistralOcrParser
7
 
8
  # Import MarkItDown parser if available - needs to be imported last so it's default
9
  try:
src/parsers/mistral_ocr_parser.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Dict, List, Optional, Any, Union
3
+ import os
4
+ import base64
5
+ import tempfile
6
+ import json
7
+ from PIL import Image
8
+ import io
9
+
10
+ from src.parsers.parser_interface import DocumentParser
11
+ from src.parsers.parser_registry import ParserRegistry
12
+
13
+ # Import the Mistral AI client
14
+ try:
15
+ from mistralai import Mistral
16
+ MISTRAL_AVAILABLE = True
17
+ except ImportError:
18
+ MISTRAL_AVAILABLE = False
19
+
20
+ # Load API key from environment variable
21
+ api_key = os.getenv("MISTRAL_API_KEY")
22
+
23
+ # Check if API key is available and print a message if not
24
+ if not api_key:
25
+ print("Warning: MISTRAL_API_KEY environment variable not found. Mistral OCR parser may not work.")
26
+
27
+ class MistralOcrParser(DocumentParser):
28
+ """Parser that uses Mistral OCR to convert documents to markdown."""
29
+
30
+ @classmethod
31
+ def get_name(cls) -> str:
32
+ return "Mistral OCR"
33
+
34
+ @classmethod
35
+ def get_supported_ocr_methods(cls) -> List[Dict[str, Any]]:
36
+ return [
37
+ {
38
+ "id": "ocr",
39
+ "name": "OCR Only",
40
+ "default_params": {}
41
+ },
42
+ {
43
+ "id": "understand",
44
+ "name": "Document Understanding",
45
+ "default_params": {}
46
+ }
47
+ ]
48
+
49
+ @classmethod
50
+ def get_description(cls) -> str:
51
+ return "Mistral OCR parser for extracting text from documents and images with optional document understanding"
52
+
53
+ def encode_image(self, image_path):
54
+ """Encode the image to base64."""
55
+ try:
56
+ with open(image_path, "rb") as image_file:
57
+ return base64.b64encode(image_file.read()).decode('utf-8')
58
+ except FileNotFoundError:
59
+ print(f"Error: The file {image_path} was not found.")
60
+ return None
61
+ except Exception as e:
62
+ print(f"Error: {e}")
63
+ return None
64
+
65
+ def parse(self, file_path: Union[str, Path], ocr_method: Optional[str] = None, **kwargs) -> str:
66
+ """Parse a document using Mistral OCR."""
67
+ if not MISTRAL_AVAILABLE:
68
+ raise ImportError(
69
+ "The Mistral AI client is not installed. "
70
+ "Please install it with 'pip install mistralai'."
71
+ )
72
+
73
+ # Use the globally loaded API key
74
+ if not api_key:
75
+ raise ValueError(
76
+ "MISTRAL_API_KEY environment variable is not set. "
77
+ "Please set it to your Mistral API key."
78
+ )
79
+
80
+ # Check the OCR method
81
+ use_document_understanding = ocr_method == "understand"
82
+
83
+ try:
84
+ # Initialize the Mistral client
85
+ client = Mistral(api_key=api_key)
86
+
87
+ # Determine file type based on extension
88
+ file_path = Path(file_path)
89
+ file_extension = file_path.suffix.lower()
90
+
91
+ # Process the document with OCR
92
+ if use_document_understanding:
93
+ # Use document understanding via chat API for enhanced extraction
94
+ return self._extract_with_document_understanding(client, file_path, file_extension)
95
+ else:
96
+ # Use regular OCR for basic text extraction
97
+ return self._extract_with_ocr(client, file_path, file_extension)
98
+
99
+ except Exception as e:
100
+ error_message = f"Error parsing document with Mistral OCR: {str(e)}"
101
+ print(error_message)
102
+ return f"# Error\n\n{error_message}\n\nPlease check your API key and try again."
103
+
104
+ def _extract_with_ocr(self, client, file_path, file_extension):
105
+ """Extract document content using basic OCR."""
106
+ try:
107
+ # Process according to file type
108
+ if file_extension in ['.pdf']:
109
+ # For PDFs, we need to upload the file to the Mistral API first
110
+ try:
111
+ # Upload the file to Mistral API
112
+ uploaded_pdf = client.files.upload(
113
+ file={
114
+ "file_name": file_path.name,
115
+ "content": open(file_path, "rb"),
116
+ },
117
+ purpose="ocr"
118
+ )
119
+
120
+ # Get signed URL for the file
121
+ signed_url = client.files.get_signed_url(file_id=uploaded_pdf.id)
122
+
123
+ # Use the signed URL for OCR processing
124
+ ocr_response = client.ocr.process(
125
+ model="mistral-ocr-latest",
126
+ document={
127
+ "type": "document_url",
128
+ "document_url": signed_url.url
129
+ }
130
+ )
131
+ except Exception as e:
132
+ # If file upload fails, try to use a direct URL method with base64
133
+ print(f"Failed to upload PDF, trying alternate method: {str(e)}")
134
+ base64_pdf = self.encode_image(file_path)
135
+
136
+ if base64_pdf:
137
+ ocr_response = client.ocr.process(
138
+ model="mistral-ocr-latest",
139
+ document={
140
+ "type": "image_url",
141
+ "image_url": f"data:application/pdf;base64,{base64_pdf}"
142
+ }
143
+ )
144
+ else:
145
+ return "# Error\n\nFailed to process PDF document."
146
+ else:
147
+ # For images (jpg, png, etc.), use image_url with base64
148
+ base64_image = self.encode_image(file_path)
149
+ if not base64_image:
150
+ return "# Error\n\nFailed to encode the image."
151
+
152
+ mime_type = self._get_mime_type(file_extension)
153
+
154
+ ocr_response = client.ocr.process(
155
+ model="mistral-ocr-latest",
156
+ document={
157
+ "type": "image_url",
158
+ "image_url": f"data:{mime_type};base64,{base64_image}"
159
+ }
160
+ )
161
+
162
+ # Process the OCR response
163
+ # The Mistral OCR response is structured with pages that contain text content
164
+ markdown_text = ""
165
+
166
+ # Check if the response contains pages
167
+ if hasattr(ocr_response, 'pages') and ocr_response.pages:
168
+ for page in ocr_response.pages:
169
+ # Add page number as heading
170
+ page_num = page.index if hasattr(page, 'index') else "Unknown"
171
+ markdown_text += f"## Page {page_num}\n\n"
172
+
173
+ # Add text content if available
174
+ if hasattr(page, 'text'):
175
+ markdown_text += page.text + "\n\n"
176
+
177
+ # Or markdown content if that's how it's structured
178
+ elif hasattr(page, 'markdown'):
179
+ markdown_text += page.markdown + "\n\n"
180
+
181
+ # Add any extracted tables with markdown formatting
182
+ if hasattr(page, 'tables') and page.tables:
183
+ for i, table in enumerate(page.tables):
184
+ markdown_text += f"### Table {i+1}\n\n"
185
+ if hasattr(table, 'markdown'):
186
+ markdown_text += table.markdown + "\n\n"
187
+ elif hasattr(table, 'data'):
188
+ # Convert table data to markdown format
189
+ markdown_text += self._convert_table_to_markdown(table.data) + "\n\n"
190
+
191
+ # If no markdown was generated, check for raw content
192
+ if not markdown_text and hasattr(ocr_response, 'content'):
193
+ markdown_text = ocr_response.content
194
+
195
+ # If still no content, try to access any available data
196
+ if not markdown_text:
197
+ # Try to get a JSON representation to extract data
198
+ try:
199
+ response_dict = ocr_response.to_dict() if hasattr(ocr_response, 'to_dict') else ocr_response.__dict__
200
+ markdown_text = "# Extracted Content\n\n"
201
+
202
+ # Look for content or text in the response dictionary
203
+ if 'content' in response_dict:
204
+ markdown_text += response_dict['content']
205
+ elif 'text' in response_dict:
206
+ markdown_text += response_dict['text']
207
+ elif 'pages' in response_dict:
208
+ for page in response_dict['pages']:
209
+ if 'text' in page:
210
+ markdown_text += page['text'] + "\n\n"
211
+ else:
212
+ # Just dump what we got as JSON
213
+ markdown_text += f"```json\n{json.dumps(response_dict, indent=2)}\n```"
214
+ except Exception as e:
215
+ markdown_text = f"# Error Processing Response\n\nCould not process the OCR response: {str(e)}"
216
+
217
+ # If we still have no content, return an error
218
+ if not markdown_text:
219
+ return "# Error\n\nNo text was extracted from the document."
220
+
221
+ return f"# Document Content\n\n{markdown_text}"
222
+
223
+ except Exception as e:
224
+ return f"# OCR Extraction Error\n\n{str(e)}"
225
+
226
+ def _extract_with_document_understanding(self, client, file_path, file_extension):
227
+ """Extract and understand document content using chat completion."""
228
+ try:
229
+ # For PDFs and images, we'll use Mistral's document understanding capability
230
+ if file_extension in ['.pdf']:
231
+ # Upload PDF first
232
+ try:
233
+ # Upload the file
234
+ uploaded_pdf = client.files.upload(
235
+ file={
236
+ "file_name": file_path.name,
237
+ "content": open(file_path, "rb"),
238
+ },
239
+ purpose="ocr"
240
+ )
241
+
242
+ # Get the signed URL
243
+ signed_url = client.files.get_signed_url(file_id=uploaded_pdf.id)
244
+
245
+ # Send to chat completion API with document understanding prompt
246
+ chat_response = client.chat.complete(
247
+ model="mistral-large-latest",
248
+ messages=[
249
+ {
250
+ "role": "user",
251
+ "content": [
252
+ {
253
+ "type": "text",
254
+ "text": "Convert this document to well-formatted markdown. Preserve all important content, structure, headings, lists, and tables. Include brief descriptions of any images."
255
+ },
256
+ {
257
+ "type": "document_url",
258
+ "document_url": signed_url.url
259
+ }
260
+ ]
261
+ }
262
+ ]
263
+ )
264
+
265
+ # Get the markdown result
266
+ return chat_response.choices[0].message.content
267
+
268
+ except Exception as e:
269
+ # Fall back to OCR if document understanding fails
270
+ print(f"Document understanding failed, falling back to OCR: {str(e)}")
271
+ return self._extract_with_ocr(client, file_path, file_extension)
272
+
273
+ else:
274
+ # For images, encode to base64
275
+ base64_image = self.encode_image(file_path)
276
+ if not base64_image:
277
+ return "# Error\n\nFailed to encode the image."
278
+
279
+ mime_type = self._get_mime_type(file_extension)
280
+
281
+ # Use the chat API with the image for document understanding
282
+ chat_response = client.chat.complete(
283
+ model="mistral-large-latest",
284
+ messages=[
285
+ {
286
+ "role": "user",
287
+ "content": [
288
+ {
289
+ "type": "text",
290
+ "text": "Extract all text from this image and convert it to well-formatted markdown. Preserve the structure and layout as much as possible."
291
+ },
292
+ {
293
+ "type": "image_url",
294
+ "image_url": {
295
+ "url": f"data:{mime_type};base64,{base64_image}"
296
+ }
297
+ }
298
+ ]
299
+ }
300
+ ]
301
+ )
302
+
303
+ # Get the markdown result
304
+ return chat_response.choices[0].message.content
305
+
306
+ except Exception as e:
307
+ return f"# Document Understanding Error\n\n{str(e)}\n\nFalling back to OCR method."
308
+
309
+ def _get_mime_type(self, file_extension: str) -> str:
310
+ """Get the MIME type for a file extension."""
311
+ mime_types = {
312
+ ".pdf": "application/pdf",
313
+ ".jpg": "image/jpeg",
314
+ ".jpeg": "image/jpeg",
315
+ ".png": "image/png",
316
+ ".gif": "image/gif",
317
+ ".bmp": "image/bmp",
318
+ ".tiff": "image/tiff",
319
+ ".tif": "image/tiff",
320
+ }
321
+
322
+ return mime_types.get(file_extension, "application/octet-stream")
323
+
324
+ def _convert_table_to_markdown(self, table_data) -> str:
325
+ """Convert a table data structure to markdown format."""
326
+ if not table_data or not isinstance(table_data, list):
327
+ return ""
328
+
329
+ # Create markdown table
330
+ markdown = ""
331
+
332
+ # Add header row
333
+ if table_data and isinstance(table_data[0], list):
334
+ header = table_data[0]
335
+ markdown += "| " + " | ".join(str(cell) for cell in header) + " |\n"
336
+
337
+ # Add separator row
338
+ markdown += "| " + " | ".join(["---"] * len(header)) + " |\n"
339
+
340
+ # Add data rows
341
+ for row in table_data[1:]:
342
+ markdown += "| " + " | ".join(str(cell) for cell in row) + " |\n"
343
+
344
+ return markdown
345
+
346
+
347
+ # Register the parser with the registry
348
+ if MISTRAL_AVAILABLE:
349
+ ParserRegistry.register(MistralOcrParser)
350
+ else:
351
+ print("Mistral OCR parser not registered: mistralai package not installed")