Luigi commited on
Commit
d5cb4e0
·
1 Parent(s): c5ee215

Delay Model Loading Until Inside a GPU Context

Browse files
Files changed (1) hide show
  1. app.py +5 -14
app.py CHANGED
@@ -5,29 +5,20 @@ from transformers import CLIPProcessor, CLIPModel
5
 
6
  # Load the CLIP model and processor on the CPU initially
7
  model_name = "openai/clip-vit-base-patch32"
8
- model = CLIPModel.from_pretrained(model_name)
9
- processor = CLIPProcessor.from_pretrained(model_name)
10
 
11
  @spaces.GPU
12
  def clip_similarity(image, text):
13
- """
14
- Computes a similarity score between an input image and text using the CLIP model.
15
- This function is decorated with @spaces.GPU so that the model is moved to GPU only when needed.
16
- """
17
- # Create a torch device for cuda
18
- device = torch.device("cuda")
19
 
20
- # Move the model to GPU within the function
21
  model.to(device)
22
 
23
- # Preprocess the inputs and move tensors to GPU
24
  inputs = processor(text=[text], images=image, return_tensors="pt", padding=True)
25
- inputs = {key: val.to(device) for key, val in inputs.items()}
26
 
27
- # Run inference
28
  outputs = model(**inputs)
29
-
30
- # Extract similarity score (logits_per_image): higher value indicates better matching
31
  similarity_score = outputs.logits_per_image.detach().cpu().numpy()[0]
32
  return float(similarity_score)
33
 
 
5
 
6
  # Load the CLIP model and processor on the CPU initially
7
  model_name = "openai/clip-vit-base-patch32"
 
 
8
 
9
  @spaces.GPU
10
  def clip_similarity(image, text):
11
+ # Load the model and processor inside GPU context
12
+ model = CLIPModel.from_pretrained(model_name)
13
+ processor = CLIPProcessor.from_pretrained(model_name)
 
 
 
14
 
15
+ device = torch.device("cuda")
16
  model.to(device)
17
 
 
18
  inputs = processor(text=[text], images=image, return_tensors="pt", padding=True)
19
+ inputs = {k: v.to(device) for k, v in inputs.items()}
20
 
 
21
  outputs = model(**inputs)
 
 
22
  similarity_score = outputs.logits_per_image.detach().cpu().numpy()[0]
23
  return float(similarity_score)
24