import os

# Configure cache directories before transformers is imported so they take effect
os.environ['HF_HOME'] = '/app/.cache'
os.environ['XDG_CACHE_HOME'] = '/app/.cache'

import time
import tempfile
import logging
import subprocess
from threading import Thread

import jinja2
import pdfkit
import torch
from flask import Flask, request, send_file, jsonify
from flask_cors import CORS
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s'
)

# Initialize Flask app
app = Flask(__name__)
CORS(app)

# Global state tracking
model_loaded = False
load_error = None
generator = None

# Find wkhtmltopdf path
WKHTMLTOPDF_PATH = '/usr/bin/wkhtmltopdf'
if not os.path.exists(WKHTMLTOPDF_PATH):
    # Try to find it using which
    try:
        WKHTMLTOPDF_PATH = subprocess.check_output(['which', 'wkhtmltopdf']).decode().strip()
    except (subprocess.CalledProcessError, OSError):
        app.logger.warning("Could not find wkhtmltopdf path. Using default.")
        WKHTMLTOPDF_PATH = 'wkhtmltopdf'

# Configure wkhtmltopdf
pdf_config = pdfkit.configuration(wkhtmltopdf=WKHTMLTOPDF_PATH)


def load_model():
    """Load the GPT-2 model and text-generation pipeline in the background."""
    global model_loaded, load_error, generator
    try:
        app.logger.info("Starting model loading process")

        # Detect device and dtype automatically
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        device = "cuda" if torch.cuda.is_available() else "cpu"
        app.logger.info(f"Device set to use {device}")

        model = AutoModelForCausalLM.from_pretrained(
            "gpt2",
            use_safetensors=True,
            device_map="auto",
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
            offload_folder="offload"
        )
        tokenizer = AutoTokenizer.from_pretrained("gpt2")

        # Initialize pipeline without explicit device assignment
        generator = pipeline(
            'text-generation',
            model=model,
            tokenizer=tokenizer,
            torch_dtype=dtype
        )

        model_loaded = True
        app.logger.info(f"Model loaded successfully on {model.device}")
    except Exception as e:
        load_error = str(e)
        app.logger.error(f"Model loading failed: {load_error}", exc_info=True)


# Start model loading in background thread
Thread(target=load_model).start()

# --------------------------------------------------
# IEEE Format Template
# --------------------------------------------------
IEEE_TEMPLATE = """

<!DOCTYPE html>
<html>
<head>
  <meta charset="utf-8">
  <title>{{ title }}</title>
</head>
<body>

  <h1>{{ title }}</h1>

  {% for author in authors %}
  <p>
    {{ author.name }}<br>
    {% if author.institution %}{{ author.institution }}<br>{% endif %}
    {% if author.email %}Email: {{ author.email }}{% endif %}
  </p>
  {% if not loop.last %}<br>{% endif %}
  {% endfor %}

  <h2>Abstract</h2>
  <p>{{ abstract }}</p>
  <p><em>Keywords</em>— {{ keywords }}</p>

  {% for section in sections %}
  <h2>{{ section.title }}</h2>
  <p>{{ section.content }}</p>
  {% endfor %}

  <h2>References</h2>
  {% for ref in references %}
  <p>[{{ loop.index }}] {{ ref }}</p>
  {% endfor %}

</body>
</html>
""" # -------------------------------------------------- # API Endpoints # -------------------------------------------------- @app.route('/health', methods=['GET']) def health_check(): return jsonify({ "status": "ok", "model_loaded": model_loaded, "device": "cuda" if torch.cuda.is_available() else "cpu" }), 200 app.logger.info(f"Health check returning status: {'ready' if model_loaded else 'loading'}, device: {device_info}") return jsonify({ "status": "ready" if model_loaded else "loading", "model_loaded": model_loaded, "device": device_info }), status_code @app.route('/generate', methods=['POST']) def generate_pdf(): # Check model status if not model_loaded: app.logger.error("PDF generation requested but model not loaded") return jsonify({ "error": "Model not loaded yet", "status": "loading" }), 503 try: app.logger.info("Processing PDF generation request") # Validate input data = request.json if not data: app.logger.error("No data provided in request") return jsonify({"error": "No data provided"}), 400 required = ['title', 'authors', 'content'] if missing := [field for field in required if field not in data]: app.logger.error(f"Missing required fields: {missing}") return jsonify({ "error": f"Missing fields: {', '.join(missing)}" }), 400 app.logger.info(f"Received request with title: {data['title']}") # Format content with model app.logger.info("Formatting content using the model") formatted = format_content(data['content']) app.logger.info("Creating HTML from template") # Generate HTML html = jinja2.Template(IEEE_TEMPLATE).render( title=data['title'], authors=data['authors'], abstract=formatted.get('abstract', ''), keywords=', '.join(formatted.get('keywords', [])), sections=formatted.get('sections', []), references=formatted.get('references', []) ) # PDF options options = { 'page-size': 'Letter', 'margin-top': '0.75in', 'margin-right': '0.75in', 'margin-bottom': '0.75in', 'margin-left': '0.75in', 'encoding': 'UTF-8', 'quiet': '' } # Create temporary PDF app.logger.info("Generating PDF file") pdf_path = None try: with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as f: pdf_path = f.name # Generate PDF using xvfb-run as a separate process html_path = pdf_path + '.html' with open(html_path, 'w', encoding='utf-8') as f: f.write(html) command = ['xvfb-run', '-a', WKHTMLTOPDF_PATH] + \ [f'--{k}={v}' for k, v in options.items() if v] + \ [html_path, pdf_path] app.logger.info(f"Running command: {' '.join(command)}") result = subprocess.run(command, capture_output=True, text=True) if result.returncode != 0: app.logger.error(f"PDF generation command failed: {result.stderr}") # Fallback to direct pdfkit if available app.logger.info("Trying fallback PDF generation with pdfkit") pdfkit.from_string(html, pdf_path, options=options, configuration=pdf_config) # Clean up HTML file os.remove(html_path) app.logger.info(f"PDF generated successfully at {pdf_path}") return send_file(pdf_path, mimetype='application/pdf', as_attachment=True, download_name=f"{data['title'].replace(' ', '_')}.pdf") except Exception as e: app.logger.error(f"PDF generation failed: {str(e)}", exc_info=True) raise except Exception as e: app.logger.error(f"Request processing failed: {str(e)}", exc_info=True) return jsonify({"error": str(e)}), 500 finally: # Clean up temporary file if 'pdf_path' in locals() and pdf_path: try: app.logger.info(f"Cleaning up temporary file {pdf_path}") os.remove(pdf_path) except Exception as e: app.logger.warning(f"Failed to remove temporary file: {str(e)}") # 
# --------------------------------------------------
# Content Formatting
# --------------------------------------------------
def parse_formatted_content(text):
    """Parse the generated text into structured sections"""
    app.logger.info("Parsing formatted content")
    try:
        lines = text.split('\n')

        # Default structure
        result = {
            'abstract': '',
            'keywords': ['IEEE', 'format', 'research', 'paper'],
            'sections': [],
            'references': []
        }

        # Extract abstract (simple approach - first paragraph after "Abstract")
        abstract_start = None
        for i, line in enumerate(lines):
            if line.strip().lower() == 'abstract':
                abstract_start = i + 1
                break

        if abstract_start:
            abstract_text = []
            i = abstract_start
            while i < len(lines) and not lines[i].strip().lower().startswith('keyword'):
                if lines[i].strip():
                    abstract_text.append(lines[i].strip())
                i += 1
            result['abstract'] = ' '.join(abstract_text)

        # Extract keywords
        for line in lines:
            if line.strip().lower().startswith('keyword'):
                # Extract keywords from the line
                keyword_parts = line.split('—')
                if len(keyword_parts) > 1:
                    keywords = keyword_parts[1].strip().split(',')
                    result['keywords'] = [k.strip() for k in keywords if k.strip()]
                break

        # Extract sections
        current_section = None
        section_content = []

        # Skip lines until we find a section heading
        started = False
        for line in lines:
            # Very basic heuristic for Roman numeral / numbered section headings
            if line.strip() and (line.strip()[0].isupper() or line.strip()[0].isdigit()):
                started = True
            if not started:
                continue

            if line.strip() and (line.strip()[0].isupper() or line.strip()[0].isdigit()) \
                    and len(line.strip().split()) <= 6:
                # This is likely a section heading
                if current_section:
                    # Save the previous section
                    result['sections'].append({
                        'title': current_section,
                        'content': '\n'.join(section_content)
                    })
                section_content = []
                current_section = line.strip()
            elif current_section and line.strip().lower() == 'references':
                # We've reached the references section
                if current_section:
                    # Save the last section
                    result['sections'].append({
                        'title': current_section,
                        'content': '\n'.join(section_content)
                    })
                break
            elif current_section:
                # Add to current section content
                section_content.append(line)

        # Extract references
        in_references = False
        for line in lines:
            if line.strip().lower() == 'references':
                in_references = True
                continue
            if in_references and line.strip():
                result['references'].append(line.strip())

        app.logger.info(
            f"Content parsed into {len(result['sections'])} sections and {len(result['references'])} references"
        )
        return result
    except Exception as e:
        app.logger.error(f"Error parsing formatted content: {str(e)}", exc_info=True)
        # Return a basic structure if parsing fails
        return {
            'abstract': 'Error parsing content.',
            'keywords': ['IEEE', 'format'],
            'sections': [{'title': 'Content', 'content': text}],
            'references': []
        }


def format_content(content):
    """Format the content using the ML model"""
    try:
        app.logger.info("Formatting content with ML model")

        prompt = f"Format this research content to IEEE standards with sections, abstract, and references:\n\n{str(content)}"

        response = generator(
            prompt,
            max_new_tokens=1024,  # Increased for more complete generation
            temperature=0.5,      # More deterministic output
            do_sample=True,
            truncation=True,
            num_return_sequences=1
        )

        generated_text = response[0]['generated_text']

        # Remove the prompt from the generated text
        if prompt in generated_text:
            formatted_text = generated_text[len(prompt):].strip()
        else:
            formatted_text = generated_text

        app.logger.info("Content formatted successfully")

        # Parse the formatted text into structured sections
        return parse_formatted_content(formatted_text)
    except Exception as e:
        app.logger.error(f"Error formatting content: {str(e)}", exc_info=True)
        # Return the original content if formatting fails
        return {
            'abstract': 'Content processing error.',
            'keywords': ['IEEE', 'format'],
            'sections': [{'title': 'Content', 'content': str(content)}],
            'references': []
        }


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)
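
# Example programmatic client (illustrative sketch; the URL, payload values, and
# output filename are assumptions, not part of the service itself):
#
#   import time
#   import requests
#
#   # Poll /health until the background model load has finished
#   while requests.get("http://localhost:5000/health").json().get("status") != "ready":
#       time.sleep(5)
#
#   resp = requests.post("http://localhost:5000/generate", json={
#       "title": "Sample Paper",
#       "authors": [{"name": "A. Author"}],
#       "content": "Introduction ... Method ... Results ...",
#   })
#   resp.raise_for_status()
#   with open("Sample_Paper.pdf", "wb") as fh:
#       fh.write(resp.content)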