# text-process / app.py
# (HuggingFace Space file-viewer header preserved as a comment:
#  author Rivalcoder, commit "Add files" bda5a7d, raw/history/blame, 4.23 kB)
import tempfile
from transformers import pipeline, RobertaForSequenceClassification, RobertaTokenizer
import gradio as gr
from fastapi import FastAPI, UploadFile, File, Request, HTTPException
import os
import json
from typing import Optional, Dict, List
import torch
# Initialize models: a RoBERTa checkpoint fine-tuned for emotion
# classification on Twitter data, wrapped in a transformers pipeline.
model_name = "cardiffnlp/twitter-roberta-base-emotion"
tokenizer = RobertaTokenizer.from_pretrained(model_name)
model = RobertaForSequenceClassification.from_pretrained(model_name)
# NOTE(review): `return_all_scores=True` is deprecated in newer
# transformers releases in favor of `top_k=None`, but the two differ in
# result nesting (this form yields [[{label, score}, ...]], which the
# handlers below index as result[0]) — confirm pipeline version before
# migrating.
emotion_analysis = pipeline("text-classification",
                            model=model,
                            tokenizer=tokenizer,
                            return_all_scores=True)
# Create FastAPI app; the Gradio UI is mounted onto it at the bottom of
# the file, so both the REST endpoint and the UI share one server.
app = FastAPI()
def save_upload_file(upload_file: UploadFile) -> str:
    """Persist an uploaded file to a temporary file on disk.

    Writes the raw upload bytes unchanged. (The original decoded and
    immediately re-encoded `.json` uploads as UTF-8 — a no-op for valid
    UTF-8 content and a crash for any other encoding.)

    Args:
        upload_file: Incoming upload; only ``.filename`` (for the
            suffix) and ``.file`` (a readable binary stream) are used.

    Returns:
        Path of the temporary file. The upload's extension is kept so
        callers can branch on ``.json``. Caller must delete the file.
    """
    try:
        suffix = os.path.splitext(upload_file.filename)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Copy bytes verbatim; no decode/encode round-trip.
            tmp.write(upload_file.file.read())
        return tmp.name
    finally:
        # Always release the upload's underlying stream.
        upload_file.file.close()
@app.post("/api/predict")
async def predict_from_upload(file: UploadFile = File(...)):
    """API endpoint: analyze emotion in an uploaded .txt or .json file.

    JSON uploads are expected to carry the text under a ``description``
    key; any other upload is read as plain text.

    Returns:
        ``{"success": True, "results": [{"label": ..., "score": ...}]}``
        with entries sorted by descending score.

    Raises:
        HTTPException: 400 when no text content is found, 500 on any
        other failure.
    """
    # Save the uploaded file temporarily
    temp_path = save_upload_file(file)
    try:
        # Process based on file type
        if temp_path.endswith('.json'):
            with open(temp_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            text = data.get('description', '')
        else:  # Assume text file
            with open(temp_path, 'r', encoding='utf-8') as f:
                text = f.read()
        if not text.strip():
            raise HTTPException(status_code=400, detail="No text content found")
        # Analyze text
        result = emotion_analysis(text)
        emotions = [{'label': e['label'], 'score': float(e['score'])}
                    for e in sorted(result[0], key=lambda x: x['score'], reverse=True)]
        return {
            "success": True,
            "results": emotions
        }
    except HTTPException:
        # Bug fix: the original caught this below and re-raised it as a
        # generic 500, hiding the deliberate 400. Let it propagate.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Clean up the temp file on every path, success or failure.
        if os.path.exists(temp_path):
            os.unlink(temp_path)
# Gradio interface
def gradio_predict(input_data, uploaded_file=None):
    """Gradio handler: analyze direct text or an uploaded text/JSON file.

    Gradio passes one argument per input component, and the interface
    below declares two (a textbox and a file upload) — the original
    single-parameter signature therefore raised TypeError on every
    invocation. The second parameter defaults to None so the function
    remains callable with a single text argument.

    Args:
        input_data: Text typed into the textbox (may be empty/None).
        uploaded_file: Optional file upload; takes precedence over the
            textbox when present.

    Returns:
        ``{"emotions": [...]}`` sorted by descending score, or
        ``{"error": ...}`` on failure.
    """
    temp_path = None
    try:
        if uploaded_file is not None:  # File upload wins over the textbox
            temp_path = save_upload_file(uploaded_file)
            if temp_path.endswith('.json'):
                with open(temp_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                text = data.get('description', '')
            else:
                with open(temp_path, 'r', encoding='utf-8') as f:
                    text = f.read()
        else:  # Direct text input
            text = input_data or ''
        if not text.strip():
            return {"error": "No text content found"}
        result = emotion_analysis(text)
        return {
            "emotions": [
                {e['label']: float(e['score'])}
                for e in sorted(result[0], key=lambda x: x['score'], reverse=True)
            ]
        }
    except Exception as e:
        return {"error": str(e)}
    finally:
        # Remove the temp file on every path, including errors.
        if temp_path and os.path.exists(temp_path):
            os.unlink(temp_path)
# Create Gradio interface.
# There are TWO input components, so each `examples` row must supply one
# value per component; the original single-element rows were malformed.
demo = gr.Interface(
    fn=gradio_predict,
    inputs=[
        gr.Textbox(label="Enter text directly", lines=5),
        gr.File(label="Or upload text/JSON file", file_types=[".txt", ".json"])
    ],
    outputs=gr.JSON(label="Emotion Analysis"),
    title="Text Emotion Analysis",
    description="Analyze emotion in text using RoBERTa model",
    examples=[
        ["I'm feeling absolutely thrilled about this new project!", None],
        ["This situation is making me extremely anxious and worried.", None]
    ]
)
# Mount Gradio app: serve the UI at the root path of the FastAPI app,
# keeping the /api/predict REST endpoint available alongside it.
app = gr.mount_gradio_app(app, demo, path="/")
# For running locally
if __name__ == "__main__":
    import uvicorn
    # NOTE(review): 0.0.0.0:7860 is the conventional HuggingFace Spaces
    # binding — exposed on all interfaces, intended for container use.
    uvicorn.run(app, host="0.0.0.0", port=7860)