ikraamkb commited on
Commit
f018781
·
verified ·
1 Parent(s): 4b810e0

Update appImage.py

Browse files
Files changed (1) hide show
  1. appImage.py +20 -9
appImage.py CHANGED
@@ -43,12 +43,11 @@ async def caption_from_frontend(file: UploadFile = File(...)):
43
  @app.get("/")
44
  def home():
45
  return RedirectResponse(url="/")"""
46
- # appImage.py
47
- from fastapi import UploadFile, File
48
- from transformers import AutoProcessor, AutoModelForCausalLM
49
- from transformers import pipeline
50
  from PIL import Image
51
  import tempfile
 
52
  import torch
53
 
54
  # Load model
@@ -63,7 +62,7 @@ except Exception:
63
  def generate_caption(image_path):
64
  try:
65
  if USE_GIT:
66
- image = Image.open(image_path)
67
  inputs = processor(images=image, return_tensors="pt")
68
  outputs = model.generate(**inputs, max_length=50)
69
  return processor.batch_decode(outputs, skip_special_tokens=True)[0]
@@ -73,14 +72,26 @@ def generate_caption(image_path):
73
  except Exception as e:
74
  return f"Error generating caption: {str(e)}"
75
 
76
- async def caption_image(file: UploadFile = File(...)):
77
  try:
78
- contents = await file.read()
79
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
 
 
 
 
 
 
80
  tmp.write(contents)
81
  tmp_path = tmp.name
82
 
 
83
  caption = generate_caption(tmp_path)
84
- return caption
 
 
 
 
 
85
  except Exception as e:
86
  return {"error": f"Failed to generate caption: {str(e)}"}
 
43
  @app.get("/")
44
  def home():
45
  return RedirectResponse(url="/")"""
46
+ from fastapi import UploadFile
47
+ from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
 
 
48
  from PIL import Image
49
  import tempfile
50
+ import os
51
  import torch
52
 
53
  # Load model
 
62
  def generate_caption(image_path):
63
  try:
64
  if USE_GIT:
65
+ image = Image.open(image_path).convert("RGB")
66
  inputs = processor(images=image, return_tensors="pt")
67
  outputs = model.generate(**inputs, max_length=50)
68
  return processor.batch_decode(outputs, skip_special_tokens=True)[0]
 
72
  except Exception as e:
73
  return f"Error generating caption: {str(e)}"
74
 
75
+ async def caption_image(file: UploadFile):
76
  try:
77
+ # Get file extension correctly
78
+ _, ext = os.path.splitext(file.filename)
79
+ if ext.lower() not in [".jpg", ".jpeg", ".png", ".bmp", ".gif"]:
80
+ return {"error": "Unsupported file type"}
81
+
82
+ # Save the uploaded image with correct extension
83
+ with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp:
84
+ contents = await file.read()
85
  tmp.write(contents)
86
  tmp_path = tmp.name
87
 
88
+ # Generate caption
89
  caption = generate_caption(tmp_path)
90
+
91
+ # Handle errors inside generate_caption
92
+ if caption.startswith("Error"):
93
+ return {"error": caption}
94
+ return {"caption": caption}
95
+
96
  except Exception as e:
97
  return {"error": f"Failed to generate caption: {str(e)}"}