File size: 4,608 Bytes
bda5a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3757f34
bda5a7d
 
 
 
 
 
 
 
 
 
 
3757f34
bda5a7d
 
 
 
 
 
3757f34
 
bda5a7d
 
3757f34
 
 
bda5a7d
 
 
 
 
 
 
 
3757f34
 
bda5a7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3757f34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bda5a7d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import tempfile
from transformers import pipeline, RobertaForSequenceClassification, RobertaTokenizer
import gradio as gr
from fastapi import FastAPI, UploadFile, File, Request, HTTPException
import os
import json
from typing import Optional, Dict, List
import torch

# Initialize models.
# Load the emotion classifier once, at import time; the request handlers in this
# module (predict_from_upload, gradio_predict) all reuse these module-level objects.
# NOTE(review): from_pretrained() resolves the model id, presumably downloading
# from the Hugging Face Hub on first run — startup needs network or a warm cache.
model_name = "cardiffnlp/twitter-roberta-base-emotion"
tokenizer = RobertaTokenizer.from_pretrained(model_name)
model = RobertaForSequenceClassification.from_pretrained(model_name)
# return_all_scores=True requests a score for every label instead of only the
# top one; the handlers iterate result[0] as a list of {"label", "score"} dicts.
# NOTE(review): newer transformers releases deprecate return_all_scores in favour
# of top_k=None — confirm the returned nesting is unchanged before switching.
emotion_analysis = pipeline("text-classification",
                          model=model, 
                          tokenizer=tokenizer,
                          return_all_scores=True)

# ASGI application that hosts the REST endpoint; the Gradio UI is mounted onto
# this same app at the bottom of the module.
app = FastAPI()

def save_upload_file(upload_file: "UploadFile") -> str:
    """Copy an uploaded file's contents into a new temporary file on disk.

    The temporary file keeps the upload's extension, lower-cased, so callers
    can branch on ``.json`` regardless of how the client spelled it. It is
    created with ``delete=False``; the caller is responsible for removing it.

    Args:
        upload_file: The upload. Only its ``filename`` (may be ``None``) and
            its ``file`` stream are used; the stream is read fully and always
            closed, even when an error occurs.

    Returns:
        Filesystem path of the temporary copy.

    Raises:
        UnicodeDecodeError: If a ``.json`` upload is not valid UTF-8. This is
            checked before any temporary file is created, so nothing leaks.
    """
    try:
        content = upload_file.file.read()
    finally:
        # Release the upload's underlying handle on every path.
        upload_file.file.close()

    # Some file-like objects yield text; store everything as UTF-8 bytes.
    if isinstance(content, str):
        content = content.encode("utf-8")

    # A client may omit the filename entirely; treat that as "no extension".
    suffix = os.path.splitext(upload_file.filename or "")[1].lower()
    if suffix == ".json":
        # Fail fast on JSON that is not valid UTF-8 instead of copying it.
        content.decode("utf-8")

    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(content)
    return tmp.name

@app.post("/api/predict")
async def predict_from_upload(file: UploadFile = File(...)):
    """Classify the emotions in an uploaded ``.json`` or plain-text file.

    A ``.json`` upload must contain an object whose ``description`` field
    holds the text to analyse; any other upload is read whole as UTF-8 text.

    Returns:
        ``{"success": True, "results": [...]}`` where ``results`` is a list of
        ``{"label": str, "score": float}`` dicts sorted by descending score.

    Raises:
        HTTPException: 400 when the upload holds no usable text or is not
            valid UTF-8 / JSON; 500 for any other failure.
    """
    # NOTE(review): model inference and file I/O here are synchronous inside an
    # async endpoint, so they block the event loop; consider a plain `def`
    # endpoint (run in FastAPI's threadpool) if concurrent requests matter.
    temp_path = None
    try:
        temp_path = save_upload_file(file)
        with open(temp_path, "r", encoding="utf-8") as f:
            if temp_path.lower().endswith(".json"):
                data = json.load(f)
                # Only a JSON object can carry a "description" field.
                text = data.get("description", "") if isinstance(data, dict) else ""
            else:
                text = f.read()

        if not isinstance(text, str) or not text.strip():
            raise HTTPException(status_code=400, detail="No text content found")

        result = emotion_analysis(text)
        emotions = [
            {"label": e["label"], "score": float(e["score"])}
            for e in sorted(result[0], key=lambda x: x["score"], reverse=True)
        ]
        return {"success": True, "results": emotions}

    except HTTPException:
        # Already carries the intended status code; re-raise it unchanged
        # instead of letting the generic handler below turn it into a 500.
        raise
    except (UnicodeDecodeError, json.JSONDecodeError) as e:
        # Malformed client input is a 4xx, not a server fault.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Remove the temporary copy on every path, success or failure.
        if temp_path is not None and os.path.exists(temp_path):
            os.unlink(temp_path)

# Modified gradio_predict to handle both input types correctly
def gradio_predict(input_data, file_data=None):
    """Handle both direct text and file uploads"""
    try:
        # Determine input source
        if file_data is not None:  # File upload takes precedence
            temp_path = save_upload_file(file_data)
            if temp_path.endswith('.json'):
                with open(temp_path, 'r') as f:
                    data = json.load(f)
                text = data.get('description', '')
            else:
                with open(temp_path, 'r') as f:
                    text = f.read()
            os.unlink(temp_path)
        else:  # Use direct text input
            text = input_data
        
        if not text.strip():
            return {"error": "No text content found"}
        
        result = emotion_analysis(text)
        return {
            "emotions": [
                {e['label']: float(e['score'])} 
                for e in sorted(result[0], key=lambda x: x['score'], reverse=True)
            ]
        }
    
    except Exception as e:
        return {"error": str(e)}

# Build the Gradio UI: a textbox and a file picker feeding one "Analyze" button,
# with the dict returned by gradio_predict rendered as JSON.
with gr.Blocks() as demo:
    gr.Markdown("# Text Emotion Analysis")
    
    with gr.Row():
        # First column: the two alternative inputs plus the submit button.
        with gr.Column():
            text_input = gr.Textbox(label="Enter text directly", lines=5)
            # NOTE(review): depending on the Gradio version, gr.File hands the
            # callback a file path (str) or a tempfile-like object, not a FastAPI
            # UploadFile — confirm gradio_predict accepts that value.
            file_input = gr.File(label="Or upload file", file_types=[".txt", ".json"])
            submit_btn = gr.Button("Analyze")
        
        # Second column: displays the result dict.
        with gr.Column():
            output = gr.JSON(label="Results")
    
    # One click handler receives both inputs; gradio_predict uses the file when
    # one is provided and otherwise falls back to the textbox.
    # NOTE(review): api_name="predict" also publishes this handler through
    # Gradio's own API; since the UI is mounted at "/", check that its route
    # does not collide with the FastAPI POST /api/predict endpoint.
    submit_btn.click(
        fn=gradio_predict,
        inputs=[text_input, file_input],
        outputs=output,
        api_name="predict"
    )
    
    # Sample sentences for the textbox.
    gr.Examples(
        examples=[
            ["I'm feeling excited about this new project!"],
            ["This situation makes me anxious and worried"]
        ],
        inputs=text_input
    )
    # Sample files for the file picker.
    # NOTE(review): these paths are relative to the working directory; presumably
    # example1.json / example2.txt must ship alongside the app — confirm, since
    # missing files may break startup.
    gr.Examples(
        examples=[
            ["example1.json"],
            ["example2.txt"]
        ],
        inputs=file_input,
        label="File Examples"
    )

# Serve the Gradio UI at the site root of the same FastAPI app that hosts the
# /api/predict route, so a single server process provides both.
app = gr.mount_gradio_app(app, demo, path="/")

if __name__ == "__main__":
    # Script entry point: listen on every interface, port 7860.
    from uvicorn import run as serve_asgi

    serve_asgi(app, port=7860, host="0.0.0.0")