|
import io
import logging
import os
import re
import shutil
import subprocess
import tempfile
import time
from threading import Thread

import jinja2
import pdfkit
import torch
from flask import Flask, jsonify, request, send_file
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
|
|
|
|
# Point the Hugging Face and XDG caches at /app/.cache.
# NOTE(review): huggingface_hub resolves HF_HOME when it is first imported and
# `transformers` is imported above this point, so these assignments may not
# change the cache location -- consider setting them in the container
# environment or before the imports. TODO confirm.
os.environ.update({
    'HF_HOME': '/app/.cache',
    'XDG_CACHE_HOME': '/app/.cache',
})

# Timestamped, level-tagged log lines at INFO and above.
logging.basicConfig(
    format='%(asctime)s [%(levelname)s] %(message)s',
    level=logging.INFO,
)

app = Flask(__name__)
CORS(app)  # flask-cors with its default settings, applied to the whole app

# Model state written by load_model() (run on a background Thread below) and
# read by the request handlers.
model_loaded, load_error, generator = False, None, None
|
|
|
|
|
# Locate the wkhtmltopdf binary: the usual system install path first, then
# whatever is on PATH, then the bare command name as a last resort.
WKHTMLTOPDF_PATH = '/usr/bin/wkhtmltopdf'
if not os.path.exists(WKHTMLTOPDF_PATH):
    # shutil.which needs no external `which` binary and, unlike the bare
    # `except:` it replaces, cannot swallow KeyboardInterrupt/SystemExit.
    _wkhtmltopdf_on_path = shutil.which('wkhtmltopdf')
    if _wkhtmltopdf_on_path:
        WKHTMLTOPDF_PATH = _wkhtmltopdf_on_path
    else:
        app.logger.warning("Could not find wkhtmltopdf path. Using default.")
        WKHTMLTOPDF_PATH = 'wkhtmltopdf'

# pdfkit.configuration() raises OSError (IOError) when it cannot open the
# binary. Previously that crashed the app at import even though a "default"
# was chosen above. Keep the server up (health checks, the xvfb-run path) and
# let the pdfkit fallback fail per request instead: with configuration=None,
# pdfkit re-attempts discovery and raises inside the request handler.
try:
    pdf_config = pdfkit.configuration(wkhtmltopdf=WKHTMLTOPDF_PATH)
except OSError as exc:
    app.logger.error(f"wkhtmltopdf is not usable via pdfkit: {exc}")
    pdf_config = None
|
|
|
def load_model():
    """Load the gpt2 model and tokenizer and publish a text-generation pipeline.

    On success the module globals ``generator`` and ``model_loaded`` are set.
    Any ``Exception`` raised while loading is caught: its message is stored in
    ``load_error`` and logged with a traceback instead of propagating.
    """
    global model_loaded, load_error, generator

    try:
        app.logger.info("Starting model loading process")

        # Half precision only when a GPU is available; CPU stays in float32.
        has_cuda = torch.cuda.is_available()
        device = "cuda" if has_cuda else "cpu"
        dtype = torch.float16 if has_cuda else torch.float32
        app.logger.info(f"Device set to use {device}")

        lm = AutoModelForCausalLM.from_pretrained(
            "gpt2",
            torch_dtype=dtype,
            device_map="auto",
            low_cpu_mem_usage=True,
            offload_folder="offload",
            use_safetensors=True,
        )
        tok = AutoTokenizer.from_pretrained("gpt2")

        generator = pipeline(
            'text-generation',
            model=lm,
            tokenizer=tok,
            torch_dtype=dtype,
        )

        # Flip the flag only after `generator` is assigned, so readers that
        # see model_loaded=True also see a usable pipeline.
        model_loaded = True
        app.logger.info(f"Model loaded successfully on {lm.device}")
    except Exception as exc:
        load_error = str(exc)
        app.logger.error(f"Model loading failed: {load_error}", exc_info=True)
|
|
|
|
|
# Load the model in the background so the server can answer /health while the
# weights load. daemon=True lets the process exit (e.g. on Ctrl+C) without
# waiting for an in-progress download/load to finish.
Thread(target=load_model, name="model-loader", daemon=True).start()
|
|
|
|
|
|
|
|
|
# Jinja2 HTML template for the generated paper, rendered by generate_pdf().
# Variables: title, authors (items with .name and optional .institution /
# .email), abstract, keywords (pre-joined string), sections (items with
# .title / .content) and references (strings, numbered via loop.index).
# Column rules carry -webkit- prefixes because wkhtmltopdf's QtWebKit engine
# predates the unprefixed CSS multi-column properties.
IEEE_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <title>{{ title }}</title>
    <style>
        @page { margin: 0.75in; }
        body {
            font-family: 'Times New Roman', Times, serif;
            font-size: 12pt;
            line-height: 1.5;
        }
        .header { text-align: center; margin-bottom: 24pt; }
        .two-column {
            -webkit-column-count: 2;
            -webkit-column-gap: 0.5in;
            column-count: 2;
            column-gap: 0.5in;
        }
        h1 { font-size: 14pt; margin: 12pt 0; }
        h2 { font-size: 12pt; margin: 12pt 0 6pt 0; }
        .abstract { margin-bottom: 24pt; }
        .keywords { font-weight: bold; margin: 12pt 0; }
        .section-content { white-space: pre-line; }
        .references { margin-top: 24pt; }
        .reference-item { text-indent: -0.5in; padding-left: 0.5in; }
    </style>
</head>
<body>
    <div class="header">
        <h1>{{ title }}</h1>
        <div class="author-info">
            {% for author in authors %}
            {{ author.name }}<br>
            {% if author.institution %}{{ author.institution }}<br>{% endif %}
            {% if author.email %}Email: {{ author.email }}{% endif %}
            {% if not loop.last %}<br>{% endif %}
            {% endfor %}
        </div>
    </div>

    <div class="abstract">
        <h2>Abstract</h2>
        {{ abstract }}
        <div class="keywords">Keywords— {{ keywords }}</div>
    </div>
    <div class="two-column">
        {% for section in sections %}
        <h2>{{ section.title }}</h2>
        <div class="section-content">{{ section.content }}</div>
        {% endfor %}
    </div>
    <div class="references">
        <h2>References</h2>
        {% for ref in references %}
        <div class="reference-item">[{{ loop.index }}] {{ ref }}</div>
        {% endfor %}
    </div>
</body>
</html>
"""
|
|
|
|
|
|
|
|
|
@app.route('/health', methods=['GET'])
def health_check():
    """Report service liveness and model-loading state as JSON.

    Always responds 200 with ``"status": "ok"`` so it can be polled while the
    model is still loading. Clients should inspect ``model_loaded`` (and
    ``load_error``, the recorded failure message if loading raised) to tell
    whether ``/generate`` will accept requests.
    """
    return jsonify({
        "status": "ok",
        "model_loaded": model_loaded,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        # None unless load_model() caught an exception.
        "load_error": load_error,
    }), 200
|
|
|
def _wkhtmltopdf_args(options):
    """Translate a pdfkit-style options dict into wkhtmltopdf CLI arguments.

    A value of ``''`` or ``None`` marks a bare flag (``'quiet': ''`` becomes
    ``--quiet``). Any other value follows its switch as a separate argument,
    the same shape pdfkit itself emits.

    Args:
        options: Mapping of option name (without leading dashes) to value.

    Returns:
        list[str] of command-line arguments.
    """
    args = []
    for name, value in options.items():
        args.append(f'--{name}')
        if value not in ('', None):
            args.append(str(value))
    return args


def _render_pdf_bytes(html, options):
    """Render *html* to PDF and return the resulting bytes.

    First runs ``xvfb-run -a <wkhtmltopdf> ...`` on a temporary HTML file. If
    that command cannot be started or exits non-zero, falls back to
    ``pdfkit.from_string`` with the module-level ``pdf_config``. Both
    temporary files are removed before returning or raising.

    Raises:
        Whatever ``pdfkit.from_string`` or the file I/O raises when both
        rendering paths fail.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
        pdf_path = tmp.name
    html_path = pdf_path + '.html'

    try:
        with open(html_path, 'w', encoding='utf-8') as fh:
            fh.write(html)

        command = ['xvfb-run', '-a', WKHTMLTOPDF_PATH,
                   *_wkhtmltopdf_args(options), html_path, pdf_path]
        app.logger.info(f"Running command: {' '.join(command)}")
        try:
            result = subprocess.run(command, capture_output=True, text=True)
        except FileNotFoundError as exc:
            # xvfb-run itself is not installed / not on PATH.
            app.logger.error(f"PDF generation command failed: {exc}")
            succeeded = False
        else:
            succeeded = result.returncode == 0
            if not succeeded:
                app.logger.error(f"PDF generation command failed: {result.stderr}")

        if not succeeded:
            app.logger.info("Trying fallback PDF generation with pdfkit")
            pdfkit.from_string(html, pdf_path, options=options, configuration=pdf_config)

        with open(pdf_path, 'rb') as fh:
            return fh.read()
    finally:
        for path in (html_path, pdf_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
            except OSError as exc:
                app.logger.warning(f"Failed to remove temporary file: {exc}")


@app.route('/generate', methods=['POST'])
def generate_pdf():
    """Generate an IEEE-style PDF from the JSON request body.

    Expects a JSON object with ``title``, ``authors`` and ``content``. The
    content is restructured by the model (``format_content``), rendered into
    ``IEEE_TEMPLATE`` and converted to PDF.

    Returns:
        The PDF as an attachment named after the title, or a JSON error:
        503 while the model is loading or after it failed to load, 400 for a
        missing or malformed body, 500 for any other failure.
    """
    if not model_loaded:
        if load_error:
            app.logger.error(f"PDF generation requested but model failed to load: {load_error}")
            return jsonify({
                "error": f"Model failed to load: {load_error}",
                "status": "error"
            }), 503
        app.logger.error("PDF generation requested but model not loaded")
        return jsonify({
            "error": "Model not loaded yet",
            "status": "loading"
        }), 503

    try:
        app.logger.info("Processing PDF generation request")

        # silent=True: a missing or non-JSON body yields None (-> 400) rather
        # than raising a werkzeug HTTP error that the broad handler below
        # would turn into a 500.
        data = request.get_json(silent=True)
        if not data:
            app.logger.error("No data provided in request")
            return jsonify({"error": "No data provided"}), 400
        if not isinstance(data, dict):
            app.logger.error("Request body is not a JSON object")
            return jsonify({"error": "Request body must be a JSON object"}), 400

        required = ['title', 'authors', 'content']
        if missing := [field for field in required if field not in data]:
            app.logger.error(f"Missing required fields: {missing}")
            return jsonify({
                "error": f"Missing fields: {', '.join(missing)}"
            }), 400

        app.logger.info(f"Received request with title: {data['title']}")

        app.logger.info("Formatting content using the model")
        formatted = format_content(data['content'])

        app.logger.info("Creating HTML from template")
        # autoescape=True: title/authors come from the client and the body
        # text from the model. Escaping stops them injecting markup (e.g. an
        # <iframe> or <img> with a file:// or internal URL) that wkhtmltopdf
        # may fetch while rendering.
        html = jinja2.Template(IEEE_TEMPLATE, autoescape=True).render(
            title=data['title'],
            authors=data['authors'],
            abstract=formatted.get('abstract', ''),
            keywords=', '.join(formatted.get('keywords', [])),
            sections=formatted.get('sections', []),
            references=formatted.get('references', [])
        )

        # '' marks a bare flag (see _wkhtmltopdf_args; pdfkit reads it the same way).
        options = {
            'page-size': 'Letter',
            'margin-top': '0.75in',
            'margin-right': '0.75in',
            'margin-bottom': '0.75in',
            'margin-left': '0.75in',
            'encoding': 'UTF-8',
            'quiet': ''
        }

        app.logger.info("Generating PDF file")
        pdf_bytes = _render_pdf_bytes(html, options)
        app.logger.info("PDF generated successfully")

        download_name = f"{str(data['title']).replace(' ', '_') or 'paper'}.pdf"
        # Serve from memory: the temp files are already removed, so nothing
        # relies on the OS allowing deletion of a file that is still open.
        return send_file(io.BytesIO(pdf_bytes), mimetype='application/pdf',
                         as_attachment=True, download_name=download_name)

    except Exception as e:
        app.logger.error(f"Request processing failed: {str(e)}", exc_info=True)
        return jsonify({"error": str(e)}), 500
|
|
|
|
|
|
|
|
|
def parse_formatted_content(text):
    """Parse model-generated text into the structure IEEE_TEMPLATE renders.

    Markers are matched case-insensitively on whitespace-stripped lines:

    * a line equal to ``abstract`` starts the abstract, which runs until the
      first line starting with ``keyword`` or a ``references`` heading
      (blank lines dropped, remaining lines joined with spaces);
    * the first line starting with ``keyword`` supplies comma-separated
      keywords after its first separator (em dash, en dash, colon or hyphen);
      the defaults are kept if nothing usable follows;
    * a line equal to ``references`` ends the body; every later non-empty
      line (other than repeated ``references`` headings) is a reference;
    * in the body (after the abstract/keywords), a line of at most six words
      starting with an uppercase letter or digit opens a new section, and
      subsequent lines become that section's content.

    Args:
        text: Raw text produced by the model.

    Returns:
        dict with ``abstract`` (str), ``keywords`` (list[str]), ``sections``
        (list of ``{'title', 'content'}`` dicts) and ``references``
        (list[str]). If parsing raises, a fallback dict carrying the whole
        text as a single "Content" section is returned instead.
    """
    app.logger.info("Parsing formatted content")

    try:
        lines = text.split('\n')
        stripped = [line.strip() for line in lines]
        lowered = [line.lower() for line in stripped]

        result = {
            'abstract': '',
            'keywords': ['IEEE', 'format', 'research', 'paper'],
            'sections': [],
            'references': []
        }

        abstract_idx = next((i for i, low in enumerate(lowered) if low == 'abstract'), None)
        keywords_idx = next((i for i, low in enumerate(lowered) if low.startswith('keyword')), None)
        references_idx = next((i for i, low in enumerate(lowered) if low == 'references'), None)

        body_start = 0
        body_end = len(lines) if references_idx is None else references_idx

        # Abstract: everything between its heading and the keywords line or
        # the References heading.
        if abstract_idx is not None:
            end = abstract_idx + 1
            while end < len(lines) and not (
                    lowered[end].startswith('keyword') or lowered[end] == 'references'):
                end += 1
            result['abstract'] = ' '.join(s for s in stripped[abstract_idx + 1:end] if s)
            body_start = end

        # Keywords: split once on the first dash/colon, then on commas.
        if keywords_idx is not None:
            parts = re.split(r'[—–:-]', stripped[keywords_idx], maxsplit=1)
            if len(parts) > 1:
                keywords = [k.strip() for k in parts[1].split(',') if k.strip()]
                if keywords:
                    result['keywords'] = keywords
            if keywords_idx < body_end:
                # The keywords line is front matter, not a section heading.
                body_start = max(body_start, keywords_idx + 1)

        # Sections: collect (title, [lines]) pairs. Building the output from
        # the whole list afterwards means the final section is never dropped.
        sections = []
        for raw, clean in zip(lines[body_start:body_end], stripped[body_start:body_end]):
            is_heading = (
                bool(clean)
                and (clean[0].isupper() or clean[0].isdigit())
                and len(clean.split()) <= 6
            )
            if is_heading:
                sections.append((clean, []))
            elif sections:
                sections[-1][1].append(raw)
        result['sections'] = [
            {'title': title, 'content': '\n'.join(body)} for title, body in sections
        ]

        # References: non-empty lines after the heading, skipping repeats of it.
        if references_idx is not None:
            result['references'] = [
                clean
                for clean, low in zip(stripped[references_idx + 1:], lowered[references_idx + 1:])
                if clean and low != 'references'
            ]

        app.logger.info(f"Content parsed into {len(result['sections'])} sections and {len(result['references'])} references")
        return result

    except Exception as e:
        app.logger.error(f"Error parsing formatted content: {str(e)}", exc_info=True)

        return {
            'abstract': 'Error parsing content.',
            'keywords': ['IEEE', 'format'],
            'sections': [{'title': 'Content', 'content': text}],
            'references': []
        }
|
|
|
# Generation budget. Prompt and output share the model's position embeddings;
# for the "gpt2" checkpoint loaded in load_model() that is 1024 tokens, so
# requesting 1024 new tokens on top of any prompt overruns the context.
_MAX_NEW_TOKENS = 1024
# Room always left for generation; the prompt is cut to guarantee it.
_MIN_NEW_TOKENS = 256
# Slack for token-count drift when a cut prompt is decoded and re-tokenized.
_CONTEXT_MARGIN = 16


def _fit_prompt_to_context(prompt):
    """Trim *prompt* so prompt plus generation fits the model's context window.

    Returns:
        ``(prompt, max_new_tokens)``. If the model config exposes no
        ``max_position_embeddings`` (or the pipeline has no tokenizer), the
        prompt is returned unchanged with ``_MAX_NEW_TOKENS``.
    """
    config = getattr(generator.model, 'config', None)
    context = getattr(config, 'max_position_embeddings', None)
    tokenizer = getattr(generator, 'tokenizer', None)
    if not context or tokenizer is None:
        return prompt, _MAX_NEW_TOKENS

    usable = context - _CONTEXT_MARGIN
    prompt_ids = tokenizer.encode(prompt)
    max_prompt_tokens = max(usable - _MIN_NEW_TOKENS, 1)
    if len(prompt_ids) > max_prompt_tokens:
        # Cut from the end so the instruction prefix survives.
        prompt_ids = prompt_ids[:max_prompt_tokens]
        prompt = tokenizer.decode(prompt_ids)
    max_new_tokens = max(1, min(_MAX_NEW_TOKENS, usable - len(prompt_ids)))
    return prompt, max_new_tokens


def format_content(content):
    """Ask the text-generation model to restructure *content*, then parse it.

    Args:
        content: User-supplied paper content; converted with ``str()`` before
            being embedded in the prompt. Over-long content is truncated so
            generation fits the model's context window.

    Returns:
        The dict produced by ``parse_formatted_content``. If generation or
        parsing raises, a fallback dict holding the raw content as a single
        "Content" section is returned instead.
    """
    try:
        app.logger.info("Formatting content with ML model")
        prompt = f"Format this research content to IEEE standards with sections, abstract, and references:\n\n{str(content)}"
        prompt, max_new_tokens = _fit_prompt_to_context(prompt)

        response = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=0.5,
            do_sample=True,
            truncation=True,
            num_return_sequences=1,
            # Return only the continuation, so the prompt needs no manual stripping.
            return_full_text=False
        )

        formatted_text = response[0]['generated_text'].strip()

        app.logger.info("Content formatted successfully")

        return parse_formatted_content(formatted_text)

    except Exception as e:
        app.logger.error(f"Error formatting content: {str(e)}", exc_info=True)

        return {
            'abstract': 'Content processing error.',
            'keywords': ['IEEE', 'format'],
            'sections': [{'title': 'Content', 'content': str(content)}],
            'references': []
        }
|
|
|
if __name__ == '__main__':
    # Flask's built-in server, listening on all interfaces at port 5000.
    app.run(port=5000, host='0.0.0.0')