K00B404 committed on
Commit
6e527a5
·
verified ·
1 Parent(s): 2bf5338

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py CHANGED
@@ -1,3 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  '''import os
2
  import re
3
  from typing import List, Optional, Union
 
1
# Caption one image with the INT4-quantized JoyCaption LLaVA checkpoint.
#
# Flow: load processor + model, download a sample image, build a chat
# prompt, generate greedily, strip the prompt tokens, print the caption.
import io

from auto_round import AutoRoundConfig  # noqa: F401  -- must import for auto-round format
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration


quantized_model_path = "OPEA/llama-joycaption-alpha-two-hf-llava-int4-sym-inc"

# Load JoyCaption INT4 model
processor = AutoProcessor.from_pretrained(quantized_model_path)
model = LlavaForConditionalGeneration.from_pretrained(
    quantized_model_path,
    device_map="auto",
    revision="bc917a8",  # AutoGPTQ format
)
model.eval()

image_url = "http://images.cocodataset.org/train2017/000000116003.jpg"
content = "Write a descriptive caption for this image in a formal tone."

# Fetch the image up front: bound the wait with a timeout, fail loudly on an
# HTTP error, and decode from the fully-read body so the connection is
# released before inference starts.
response = requests.get(image_url, timeout=30)
response.raise_for_status()
image = Image.open(io.BytesIO(response.content))

messages = [
    {
        "role": "system",
        "content": "You are a helpful image captioner.",
    },
    {
        "role": "user",
        "content": content,
    },
]

prompt = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
if not isinstance(prompt, str):
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    raise TypeError(
        f"apply_chat_template returned {type(prompt).__name__}, expected str"
    )

# Inference only: no gradients are needed.
with torch.no_grad():
    inputs = processor(text=[prompt], images=[image], return_tensors="pt").to(model.device)
    # Match the image tensor dtype to the model's dtype.
    inputs['pixel_values'] = inputs['pixel_values'].to(model.dtype)

    # Generate the caption. Decoding is greedy (do_sample=False), so the
    # sampling knobs are explicitly unset rather than given values that
    # would be ignored and only produce warnings.
    generate_ids = model.generate(
        **inputs,
        max_new_tokens=50,
        do_sample=False,
        suppress_tokens=None,
        use_cache=True,
        temperature=None,
        top_k=None,
        top_p=None,
    )[0]

# Trim off the prompt tokens so only newly generated tokens remain.
generate_ids = generate_ids[inputs['input_ids'].shape[1]:]

# Decode the caption
caption = processor.tokenizer.decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
caption = caption.strip()
print(caption)
59
+
60
+
61
  '''import os
62
  import re
63
  from typing import List, Optional, Union