File size: 1,499 Bytes
32dd4d2
 
7cab805
d5d3aa6
32dd4d2
d5d3aa6
 
 
 
32dd4d2
d5d3aa6
 
32dd4d2
d5d3aa6
32dd4d2
 
d5d3aa6
 
 
7cab805
d5d3aa6
 
7cab805
d5d3aa6
32dd4d2
7cab805
d5d3aa6
 
7cab805
d5d3aa6
32dd4d2
d5d3aa6
32dd4d2
 
 
 
 
 
d5d3aa6
32dd4d2
 
d5d3aa6
7cab805
32dd4d2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import contextlib
import logging
import os
import tempfile

import torch
from fastapi import FastAPI, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import RedirectResponse, JSONResponse
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

app = FastAPI()

# Load the captioning model once at import time so every request reuses it.
# Prefer the "microsoft/git-large-coco" processor/model pair; if either cannot
# be loaded, fall back to the "nlpconnect/vit-gpt2-image-captioning"
# image-to-text pipeline. USE_GIT tells generate_caption which one is active:
#   USE_GIT True  -> `processor` and `model` are defined
#   USE_GIT False -> `captioner` is defined
try:
    processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
    model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
    USE_GIT = True
except Exception:
    # Keep the best-effort fallback, but record why the preferred model failed
    # instead of silently degrading the service.
    logging.getLogger(__name__).warning(
        "Could not load microsoft/git-large-coco; falling back to "
        "nlpconnect/vit-gpt2-image-captioning",
        exc_info=True,
    )
    from transformers import pipeline
    captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
    USE_GIT = False

def generate_caption(image_path):
    """Return a text caption for the image stored at ``image_path``.

    Uses the module-level GIT ``processor``/``model`` when ``USE_GIT`` is
    true, otherwise the fallback ``captioner`` pipeline.

    Args:
        image_path: Filesystem path of the image to caption.

    Returns:
        The generated caption. On any failure, no exception is raised.
        Instead a string of the form ``"Error generating caption: <message>"``
        is returned, and the traceback is logged.
    """
    try:
        if USE_GIT:
            # The context manager closes the file handle as soon as the pixels
            # have been consumed by the processor, even if decoding fails.
            with Image.open(image_path) as image:
                inputs = processor(images=image, return_tensors="pt")
            outputs = model.generate(**inputs, max_length=50)
            return processor.batch_decode(outputs, skip_special_tokens=True)[0]
        result = captioner(image_path)
        return result[0]['generated_text']
    except Exception as e:
        # Callers receive an error string rather than an exception. Log the
        # full traceback so the failure is still diagnosable.
        logging.getLogger(__name__).exception(
            "Caption generation failed for %s", image_path
        )
        return f"Error generating caption: {str(e)}"

@app.post("/imagecaption/")
async def caption_from_frontend(file: UploadFile = File(...)):
    """Caption an uploaded image and respond with ``{"caption": <text>}``.

    The upload is written to a temporary file because ``generate_caption``
    takes a filesystem path. That file is removed once captioning finishes.
    """
    contents = await file.read()
    # delete=False lets the file be reopened by path after this handle closes.
    # Reopening a still-open NamedTemporaryFile fails on Windows.
    # NOTE(review): the ".png" suffix is cosmetic; PIL-based loaders identify
    # the format from file content — confirm if other loaders are added.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        tmp.write(contents)
        image_path = tmp.name

    try:
        # Model inference is blocking work. Run it in a worker thread so it
        # does not stall the event loop for every other request.
        caption = await run_in_threadpool(generate_caption, image_path)
    finally:
        # With delete=False nothing else removes the file. Without this, every
        # request leaks one temp file. Cleanup is best-effort.
        with contextlib.suppress(OSError):
            os.unlink(image_path)
    return JSONResponse({"caption": caption})

@app.get("/")
def home():
    """Send visitors of the root URL to FastAPI's interactive API docs.

    The original target was "/" itself, which made every request to the root
    an infinite redirect loop. "/docs" is the default ``docs_url`` of an app
    created with ``FastAPI()`` and no overrides.
    """
    return RedirectResponse(url="/docs")