import io
import logging
import os
import re
import shutil
import subprocess
import tempfile
import time
from threading import Thread

import jinja2
import pdfkit
import torch
from flask import Flask, request, send_file, jsonify
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Configure cache directories.
# NOTE(review): `transformers` is imported above this point, and huggingface_hub
# is believed to resolve HF_HOME when it is imported, so these assignments may
# not relocate the model cache for this process. Setting them in the container
# environment (before Python starts) is the reliable option.
os.environ['HF_HOME'] = '/app/.cache'
os.environ['XDG_CACHE_HOME'] = '/app/.cache'

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s'
)

# Initialize Flask app and enable CORS.
app = Flask(__name__)
CORS(app)

# Global state written by the background loader thread (load_model) and read
# by the request handlers.
model_loaded = False  # becomes True once `generator` is ready to use
load_error = None     # str() of the loading exception if loading failed
generator = None      # transformers text-generation pipeline

# Locate the wkhtmltopdf executable: the conventional /usr/bin location first,
# then a PATH search, and finally the bare command name.
WKHTMLTOPDF_PATH = '/usr/bin/wkhtmltopdf'
if not os.path.exists(WKHTMLTOPDF_PATH):
    # shutil.which searches PATH in-process (no dependency on a `which`
    # binary) and returns None instead of raising when nothing is found.
    WKHTMLTOPDF_PATH = shutil.which('wkhtmltopdf')
    if WKHTMLTOPDF_PATH is None:
        app.logger.warning("Could not find wkhtmltopdf path. Using default.")
        WKHTMLTOPDF_PATH = 'wkhtmltopdf'

# Configure wkhtmltopdf for pdfkit.
# NOTE(review): pdfkit.configuration() appears to validate the path when it is
# constructed, so this may raise at import if wkhtmltopdf is absent - confirm
# against the installed pdfkit version.
pdf_config = pdfkit.configuration(wkhtmltopdf=WKHTMLTOPDF_PATH)
def load_model():
    """Load GPT-2 and publish a text-generation pipeline through module globals.

    Run on a background thread (see the Thread started right after this
    definition). On success it assigns ``generator`` first and only then sets
    ``model_loaded = True``, so readers never see a True flag with a missing
    pipeline. On failure it stores ``str(exception)`` in ``load_error`` and
    logs the traceback; exceptions are caught, not propagated.
    """
    global model_loaded, load_error, generator
    try:
        app.logger.info("Starting model loading process")
        # Detect device and dtype automatically: half precision only on GPU.
        use_cuda = torch.cuda.is_available()
        dtype = torch.float16 if use_cuda else torch.float32
        device = "cuda" if use_cuda else "cpu"
        app.logger.info(f"Device set to use {device}")

        # HF_HOME is assigned at module level only after `transformers` has
        # been imported; huggingface_hub computes its default cache path at
        # import time, so pass the intended hub cache location explicitly.
        # None leaves the library default in place.
        hf_home = os.environ.get('HF_HOME')
        cache_dir = os.path.join(hf_home, 'hub') if hf_home else None

        model = AutoModelForCausalLM.from_pretrained(
            "gpt2",
            use_safetensors=True,
            device_map="auto",
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            offload_folder="offload",
            cache_dir=cache_dir,
        )
        tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir=cache_dir)

        # The model is already placed by device_map, so the pipeline gets no
        # explicit device argument.
        generator = pipeline(
            'text-generation',
            model=model,
            tokenizer=tokenizer,
            torch_dtype=dtype
        )
        model_loaded = True
        app.logger.info(f"Model loaded successfully on {model.device}")
    except Exception as e:
        load_error = str(e)
        app.logger.error(f"Model loading failed: {load_error}", exc_info=True)
# Start model loading in a background thread so the server can answer /health
# while the weights load. daemon=True lets the process exit (e.g. on Ctrl+C)
# without first waiting for load_model() to finish.
Thread(target=load_model, name="model-loader", daemon=True).start()
# --------------------------------------------------
# IEEE Format Template
# --------------------------------------------------
# Jinja2 template that generate_pdf() renders and wkhtmltopdf converts to PDF.
# wkhtmltopdf parses its input as HTML, so layout is expressed with tags: bare
# newlines would collapse the whole paper into a single paragraph.
# Every interpolated value goes through the `e` (HTML-escape) filter because
# the title and authors come from the request body and everything else from
# model output. `e` leaves already-escaped markup untouched, so this stays
# correct when the Jinja environment also autoescapes.
IEEE_TEMPLATE = """<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>{{ title|e }}</title>
<style>
body { font-family: "Times New Roman", Times, serif; font-size: 10pt; }
h1 { text-align: center; font-size: 24pt; font-weight: normal; margin: 0 0 6pt; }
.authors { text-align: center; font-size: 11pt; margin-bottom: 14pt; }
.abstract, .keywords { font-size: 9pt; font-weight: bold; text-align: justify; }
h2 { font-size: 10pt; font-variant: small-caps; text-align: center; margin: 12pt 0 4pt; }
.section-body { text-align: justify; white-space: pre-wrap; }
.reference { font-size: 8pt; margin: 2pt 0; }
</style>
</head>
<body>
<h1>{{ title|e }}</h1>
{% if authors %}
<div class="authors">{% if authors is string %}{{ authors|e }}{% elif authors is iterable %}{{ authors|join(', ')|e }}{% else %}{{ authors|e }}{% endif %}</div>
{% endif %}
<p class="abstract"><i>Abstract</i>&mdash;{{ abstract|e }}</p>
<p class="keywords"><i>Keywords</i>&mdash; {{ keywords|e }}</p>
{% for section in sections %}
<h2>{{ section.title|e }}</h2>
<div class="section-body">{{ section.content|e }}</div>
{% endfor %}
<h2>References</h2>
{% for ref in references %}
<p class="reference">[{{ loop.index }}] {{ ref|e }}</p>
{% endfor %}
</body>
</html>
"""
# --------------------------------------------------
# API Endpoints
# --------------------------------------------------
@app.route('/health', methods=['GET'])
def health_check():
    """Report liveness plus model-loading state.

    Always answers HTTP 200 with ``status: "ok"``. Callers should read
    ``model_loaded`` to know whether /generate will succeed, and
    ``load_error`` (null unless background loading raised) to tell
    "still loading" apart from "loading failed and will not recover".
    """
    return jsonify({
        "status": "ok",
        "model_loaded": model_loaded,
        "load_error": load_error,
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }), 200
@app.route('/generate', methods=['POST'])
def generate_pdf():
    """Format submitted research content with the model and return an IEEE-style PDF.

    Expects a JSON object body containing at least ``title``, ``authors`` and
    ``content``.

    Responses:
        200 - the PDF as an attachment named after the title.
        400 - missing/invalid JSON body, or required fields missing.
        503 - the model has not finished loading.
        500 - any other failure while formatting or rendering.
    """
    # Check model status
    if not model_loaded:
        app.logger.error("PDF generation requested but model not loaded")
        return jsonify({
            "error": "Model not loaded yet",
            "status": "loading"
        }), 503
    try:
        app.logger.info("Processing PDF generation request")
        # silent=True: a missing, non-JSON or malformed body yields None and
        # is reported as a 400 instead of raising and surfacing as a 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or not data:
            app.logger.error("No data provided in request")
            return jsonify({"error": "No data provided"}), 400
        required = ['title', 'authors', 'content']
        if missing := [field for field in required if field not in data]:
            app.logger.error(f"Missing required fields: {missing}")
            return jsonify({
                "error": f"Missing fields: {', '.join(missing)}"
            }), 400

        # str(): a non-string JSON title must not crash the filename logic.
        title = str(data['title'])
        app.logger.info(f"Received request with title: {title}")

        app.logger.info("Formatting content using the model")
        formatted = format_content(data['content'])

        app.logger.info("Creating HTML from template")
        html = _render_ieee_html(title, data['authors'], formatted)

        app.logger.info("Generating PDF file")
        pdf_bytes = _html_to_pdf(html)

        # Serve from memory: the temporary files are already gone, so nothing
        # is deleted while the response may still be streaming it.
        return send_file(io.BytesIO(pdf_bytes), mimetype='application/pdf',
                         as_attachment=True,
                         download_name=f"{title.replace(' ', '_')}.pdf")
    except Exception as e:
        app.logger.error(f"Request processing failed: {str(e)}", exc_info=True)
        return jsonify({"error": str(e)}), 500


def _render_ieee_html(title, authors, formatted):
    """Fill IEEE_TEMPLATE with request metadata and the parsed model output.

    Args:
        title: Paper title (already a str).
        authors: The request's ``authors`` value, passed through unchanged.
        formatted: Dict with optional keys 'abstract', 'keywords' (list),
            'sections' and 'references'.
    """
    # autoescape=True: title/authors are user input and the rest is model
    # output; neither may inject markup into the HTML that wkhtmltopdf loads.
    env = jinja2.Environment(autoescape=True)
    return env.from_string(IEEE_TEMPLATE).render(
        title=title,
        authors=authors,
        abstract=formatted.get('abstract', ''),
        keywords=', '.join(formatted.get('keywords', [])),
        sections=formatted.get('sections', []),
        references=formatted.get('references', [])
    )


def _html_to_pdf(html):
    """Convert *html* to PDF bytes with wkhtmltopdf; temporary files are always removed.

    Tries wkhtmltopdf under ``xvfb-run`` first. If that command exits non-zero
    or cannot be started at all, falls back to pdfkit with the same options.
    Exceptions from the pdfkit fallback propagate to the caller.
    """
    # PDF options; an empty value denotes a bare flag (e.g. --quiet).
    options = {
        'page-size': 'Letter',
        'margin-top': '0.75in',
        'margin-right': '0.75in',
        'margin-bottom': '0.75in',
        'margin-left': '0.75in',
        'encoding': 'UTF-8',
        'quiet': ''
    }
    fd, pdf_path = tempfile.mkstemp(suffix='.pdf')
    os.close(fd)
    html_path = pdf_path + '.html'
    try:
        with open(html_path, 'w', encoding='utf-8') as f:
            f.write(html)

        # Pass each option and its value as separate arguments, matching
        # wkhtmltopdf's documented `--option <value>` usage, and keep
        # value-less flags such as --quiet instead of dropping them.
        cli_args = []
        for key, value in options.items():
            cli_args.append(f'--{key}')
            if value:
                cli_args.append(value)
        command = ['xvfb-run', '-a', WKHTMLTOPDF_PATH, *cli_args, html_path, pdf_path]
        app.logger.info(f"Running command: {' '.join(command)}")

        try:
            result = subprocess.run(command, capture_output=True, text=True)
            failure = result.stderr if result.returncode != 0 else None
        except OSError as e:
            # e.g. xvfb-run is not installed; still try the pdfkit fallback.
            failure = str(e)

        if failure is not None:
            app.logger.error(f"PDF generation command failed: {failure}")
            app.logger.info("Trying fallback PDF generation with pdfkit")
            pdfkit.from_string(html, pdf_path, options=options, configuration=pdf_config)

        with open(pdf_path, 'rb') as f:
            pdf_bytes = f.read()
        app.logger.info(f"PDF generated successfully at {pdf_path}")
        return pdf_bytes
    finally:
        for path in (html_path, pdf_path):
            try:
                app.logger.info(f"Cleaning up temporary file {path}")
                os.remove(path)
            except FileNotFoundError:
                pass  # never created (e.g. failure before writing it)
            except OSError as e:
                app.logger.warning(f"Failed to remove temporary file: {str(e)}")
# --------------------------------------------------
# Content Formatting
# --------------------------------------------------
def parse_formatted_content(text):
    """Parse generated text into the structure IEEE_TEMPLATE renders.

    The parsing is heuristic:
      * abstract   - the first paragraph after a line reading exactly
                     "abstract" (case-insensitive). It stops at a blank line
                     after text, at a line starting with "keyword", or at a
                     "references" line.
      * keywords   - comma-separated terms after the first em dash, en dash
                     or colon on the first line starting with "keyword".
                     Defaults to a generic list.
      * sections   - before the "references" line, any non-empty line that
                     starts with an uppercase letter or digit and has at most
                     six words opens a section. Following lines are its content.
                     Lines used by the abstract or keywords are skipped.
      * references - every non-empty line after the first "references" line.

    Args:
        text: Generated text with the prompt already removed.

    Returns:
        dict with 'abstract' (str), 'keywords' (list of str), 'sections'
        (list of {'title', 'content'} dicts) and 'references' (list of str).
        If parsing raises, a fallback dict that wraps *text* in one section.
    """
    app.logger.info("Parsing formatted content")
    try:
        lines = text.split('\n')
        # Default structure
        result = {
            'abstract': '',
            'keywords': ['IEEE', 'format', 'research', 'paper'],
            'sections': [],
            'references': []
        }
        # Indices of lines consumed by the abstract/keywords. They are excluded
        # from section parsing so that content is not rendered twice.
        claimed = set()

        references_at = next(
            (i for i, line in enumerate(lines) if line.strip().lower() == 'references'),
            None,
        )
        body_end = len(lines) if references_at is None else references_at

        # Abstract: first paragraph after the "abstract" heading.
        abstract_at = next(
            (i for i, line in enumerate(lines) if line.strip().lower() == 'abstract'),
            None,
        )
        if abstract_at is not None:
            claimed.add(abstract_at)
            paragraph = []
            for i in range(abstract_at + 1, len(lines)):
                stripped = lines[i].strip()
                lowered = stripped.lower()
                if lowered.startswith('keyword') or lowered == 'references':
                    break
                if not stripped:
                    if paragraph:
                        break  # blank line after text ends the paragraph
                    continue   # blank lines before the paragraph are skipped
                paragraph.append(stripped)
                claimed.add(i)
            result['abstract'] = ' '.join(paragraph)

        # Keywords: only the first "keyword..." line is considered.
        for i, line in enumerate(lines):
            if line.strip().lower().startswith('keyword'):
                claimed.add(i)
                parts = re.split('[—–:]', line, maxsplit=1)
                if len(parts) > 1:
                    result['keywords'] = [k.strip() for k in parts[1].strip().split(',') if k.strip()]
                break

        # Sections: the references heading is checked first via body_end, so a
        # capitalized "References" line can never be mistaken for a heading.
        current_title = None
        content_lines = []
        for i in range(body_end):
            if i in claimed:
                continue
            line = lines[i]
            stripped = line.strip()
            is_heading = (
                bool(stripped)
                and (stripped[0].isupper() or stripped[0].isdigit())
                and len(stripped.split()) <= 6
            )
            if is_heading:
                if current_title is not None:
                    result['sections'].append({
                        'title': current_title,
                        'content': '\n'.join(content_lines)
                    })
                current_title = stripped
                content_lines = []
            elif current_title is not None:
                content_lines.append(line)
        # The last section has no following heading to trigger its append.
        if current_title is not None:
            result['sections'].append({
                'title': current_title,
                'content': '\n'.join(content_lines)
            })

        # References: non-empty lines after the first "references" heading;
        # repeated "references" heading lines are not treated as entries.
        if references_at is not None:
            result['references'] = [
                line.strip()
                for line in lines[references_at + 1:]
                if line.strip() and line.strip().lower() != 'references'
            ]

        app.logger.info(f"Content parsed into {len(result['sections'])} sections and {len(result['references'])} references")
        return result
    except Exception as e:
        app.logger.error(f"Error parsing formatted content: {str(e)}", exc_info=True)
        # Return a basic structure if parsing fails
        return {
            'abstract': 'Error parsing content.',
            'keywords': ['IEEE', 'format'],
            'sections': [{'title': 'Content', 'content': text}],
            'references': []
        }
def format_content(content, max_new_tokens=1024):
    """Ask the language model to restructure *content* in IEEE style, then parse it.

    Args:
        content: Raw research content from the request; converted with str().
        max_new_tokens: Upper bound on generated tokens. The effective value
            is further capped so that prompt plus output fit the model's
            position window (1024 for GPT-2). Exceeding that window makes
            generation fail.

    Returns:
        The dict produced by parse_formatted_content(). If generation fails,
        a fallback dict that wraps the original content in one "Content"
        section.
    """
    try:
        app.logger.info("Formatting content with ML model")
        prompt = f"Format this research content to IEEE standards with sections, abstract, and references:\n\n{str(content)}"
        prompt, new_tokens = _fit_prompt_to_context(prompt, max_new_tokens)
        response = generator(
            prompt,
            max_new_tokens=new_tokens,
            temperature=0.5,  # More deterministic output
            do_sample=True,
            num_return_sequences=1,
            # Return only the continuation, so the prompt never has to be
            # located and sliced back out of the generated text.
            return_full_text=False,
        )
        formatted_text = response[0]['generated_text'].strip()
        app.logger.info("Content formatted successfully")
        # Parse the formatted text into structured sections
        return parse_formatted_content(formatted_text)
    except Exception as e:
        app.logger.error(f"Error formatting content: {str(e)}", exc_info=True)
        # Return the original content if formatting fails
        return {
            'abstract': 'Content processing error.',
            'keywords': ['IEEE', 'format'],
            'sections': [{'title': 'Content', 'content': str(content)}],
            'references': []
        }


def _fit_prompt_to_context(prompt, max_new_tokens):
    """Trim *prompt* so that it plus the generated tokens fit the model context.

    Reserves at least ``min(max_new_tokens, context // 2)`` positions for
    output, truncating the prompt's tail if necessary.

    Returns:
        (prompt, new_tokens). The inputs are returned unchanged when the model
        config exposes no ``max_position_embeddings``.

    Raises:
        ValueError: if no room is left for generation.
    """
    tokenizer = generator.tokenizer
    context = getattr(generator.model.config, 'max_position_embeddings', None)
    if not context:
        return prompt, max_new_tokens
    reserve = max(1, min(max_new_tokens, context // 2))
    prompt_limit = context - reserve
    ids = tokenizer(prompt)['input_ids']
    if len(ids) > prompt_limit:
        app.logger.warning(f"Prompt truncated from {len(ids)} to {prompt_limit} tokens to fit the model context")
        prompt = tokenizer.decode(ids[:prompt_limit])
        # Re-encode: a decode/encode round trip can change the token count.
        ids = tokenizer(prompt)['input_ids']
    new_tokens = min(max_new_tokens, context - len(ids))
    if new_tokens < 1:
        raise ValueError("Prompt does not leave room for generation in the model context")
    return prompt, new_tokens
if __name__ == '__main__':
    # Flask development server, reachable on every interface at port 5000.
    app.run(port=5000, host='0.0.0.0')