Spaces:
Running
Running
"""from fastapi import FastAPI, UploadFile, File | |
from fastapi.responses import RedirectResponse, JSONResponse | |
from transformers import AutoProcessor, AutoModelForCausalLM | |
from PIL import Image | |
import tempfile | |
import torch | |
app = FastAPI() | |
# Load model | |
try: | |
processor = AutoProcessor.from_pretrained("microsoft/git-large-coco") | |
model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco") | |
USE_GIT = True | |
except Exception: | |
from transformers import pipeline | |
captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning") | |
USE_GIT = False | |
def generate_caption(image_path): | |
try: | |
if USE_GIT: | |
image = Image.open(image_path) | |
inputs = processor(images=image, return_tensors="pt") | |
outputs = model.generate(**inputs, max_length=50) | |
return processor.batch_decode(outputs, skip_special_tokens=True)[0] | |
else: | |
result = captioner(image_path) | |
return result[0]['generated_text'] | |
except Exception as e: | |
return f"Error generating caption: {str(e)}" | |
@app.post("/imagecaption/") | |
async def caption_from_frontend(file: UploadFile = File(...)): | |
contents = await file.read() | |
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp: | |
tmp.write(contents) | |
image_path = tmp.name | |
caption = generate_caption(image_path) | |
return JSONResponse({"caption": caption}) | |
@app.get("/") | |
def home(): | |
return RedirectResponse(url="/")""" | |
from fastapi import UploadFile | |
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline | |
from PIL import Image | |
import tempfile | |
import os | |
import torch | |
from gtts import gTTS | |
import uuid | |
# Load the captioning backend once, at import time.
#
# The primary backend is "microsoft/git-large-coco", driven through an explicit
# processor + causal-LM pair.  If loading it fails for any reason, fall back to
# the "nlpconnect/vit-gpt2-image-captioning" image-to-text pipeline.  USE_GIT
# records which backend is active; generate_caption() branches on it.
#
# All three handles are pre-bound to None so that the backend which was NOT
# loaded is still a defined module attribute instead of a missing name.
processor = None
model = None
captioner = None
try:
    processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
    model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
    USE_GIT = True
except Exception as load_error:
    # The best-effort fallback is deliberate, but report why the primary model
    # was skipped rather than swallowing the failure silently.
    import warnings

    warnings.warn(
        f"Could not load microsoft/git-large-coco ({load_error!r}); "
        "falling back to nlpconnect/vit-gpt2-image-captioning.",
        RuntimeWarning,
    )
    # The processor may have loaded before the model failed; a lone half of
    # the pair is unusable, so drop it.
    processor = None
    model = None
    captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
    USE_GIT = False
def generate_caption(image_path):
    """Return a text caption for the image stored at ``image_path``.

    Uses the GIT ``processor``/``model`` pair when ``USE_GIT`` is true and the
    fallback ``captioner`` pipeline otherwise.

    Args:
        image_path: Filesystem path of the image to caption.

    Returns:
        The generated caption.  This function never raises: on any failure it
        returns a string of the form ``"Error generating caption: <reason>"``,
        which callers inspect to detect errors.
    """
    try:
        if USE_GIT:
            # Open through a context manager so the file handle is released
            # deterministically.  convert() returns a converted copy, so the
            # RGB image remains usable after the source file is closed.
            with Image.open(image_path) as source_image:
                image = source_image.convert("RGB")
            inputs = processor(images=image, return_tensors="pt")
            # max_length=50 caps the length of the generated token sequence.
            outputs = model.generate(**inputs, max_length=50)
            return processor.batch_decode(outputs, skip_special_tokens=True)[0]
        result = captioner(image_path)
        return result[0]['generated_text']
    except Exception as e:
        # Errors are reported in-band; the caller relies on this prefix.
        return f"Error generating caption: {str(e)}"
async def caption_image(file: UploadFile):
    """Caption an uploaded image and synthesize speech for the caption.

    The upload is written to a temporary file (keeping its extension),
    captioned with ``generate_caption``, and the caption is then rendered to
    an mp3 with gTTS in the system temp directory.

    Args:
        file: The uploaded image.  Only ``.jpg``, ``.jpeg``, ``.png``,
            ``.bmp`` and ``.gif`` extensions (case-insensitive) are accepted.

    Returns:
        ``{"caption": <text>, "audio": "/files/<uuid>.mp3"}`` on success, or
        ``{"error": <message>}`` on any failure.  This coroutine does not
        raise for ordinary errors.

    NOTE(review): the returned ``/files/...`` URL presumably maps to the
    system temp directory through a route defined elsewhere - confirm.
    """
    supported_extensions = (".jpg", ".jpeg", ".png", ".bmp", ".gif")
    # Exact prefix emitted by generate_caption() on failure.  Matching only
    # "Error" would reject genuine captions that begin with that word.
    caption_error_prefix = "Error generating caption:"

    try:
        # An upload may arrive without a filename; treat that as "no
        # extension" so it is rejected cleanly instead of raising TypeError.
        _, ext = os.path.splitext(file.filename or "")
        if ext.lower() not in supported_extensions:
            return {"error": "Unsupported file type"}

        # Read the payload before creating the temp file so a failed read
        # leaves nothing on disk.
        contents = await file.read()

        # delete=False because the captioner reopens the file by path; the
        # finally clause guarantees removal even if writing/captioning fails.
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=ext)
        tmp_path = tmp.name
        try:
            with tmp:
                tmp.write(contents)
            caption = generate_caption(tmp_path)
        finally:
            os.remove(tmp_path)

        # generate_caption reports failures in-band as a prefixed string.
        if caption.startswith(caption_error_prefix):
            return {"error": caption}

        # Synthesize TTS audio for the caption under a collision-free name.
        audio_filename = f"{uuid.uuid4()}.mp3"
        audio_path = os.path.join(tempfile.gettempdir(), audio_filename)
        try:
            gTTS(text=caption, lang="en").save(audio_path)
        except Exception:
            # Do not leave a partial, unreferenced mp3 behind.
            if os.path.exists(audio_path):
                os.remove(audio_path)
            raise

        return {
            "caption": caption,
            "audio": f"/files/{audio_filename}"
        }
    except Exception as e:
        return {"error": f"Failed to generate caption: {str(e)}"}