taybeyond committed on
Commit
2a4714f
·
verified ·
1 Parent(s): 912204b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -13
app.py CHANGED
@@ -1,26 +1,39 @@
1
# UI and model-serving dependencies.
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
from PIL import Image

# Hub identifier of the vision-language chat model to serve.
MODEL_ID = "Qwen/Qwen1.5-VL-Chat"

# Load once at startup. trust_remote_code is needed because the checkpoint
# ships custom modelling code; device_map="auto" places weights on the
# available device(s); bfloat16 halves the memory footprint vs fp32.
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
10
 
11
def qwen_vl_chat(image, question):
    """Answer *question* about *image* with the loaded Qwen-VL model.

    NOTE(review): relies on the module-level ``processor`` and ``model``
    globals created at import time.
    """
    # One processor call tokenizes the text and preprocesses the image;
    # the resulting tensors are moved onto the model's device.
    batch = processor(text=question, images=image, return_tensors="pt")
    batch = batch.to(model.device)
    generated = model.generate(**batch, max_new_tokens=256)
    # Decode the single returned sequence, dropping special tokens.
    return processor.batch_decode(generated, skip_special_tokens=True)[0]
16
 
 
 
 
 
 
 
 
 
 
 
 
 
17
# Wire the inference function into a simple image+text -> text web UI.
demo = gr.Interface(
    fn=qwen_vl_chat,
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="请输入问题"),
    ],
    outputs="text",
    title="🧠 Qwen1.5-VL 图文问答 Demo",
    description="上传一张图片,问一个问题,模型会给出答案。",
)

if __name__ == "__main__":
    demo.launch()
 
1
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
import torch
import os
from huggingface_hub import login

# Hugging Face access token, read from the environment (set HF_TOKEN in the
# Space / host settings; never hard-code credentials in source). May be None
# when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# BUGFIX: login(token=None) raises when HF_TOKEN is unset. Only authenticate
# when a token is actually available; public checkpoints still load
# anonymously, and from_pretrained(token=None) below is harmless.
if HF_TOKEN:
    login(token=HF_TOKEN)
10
 
11
# Which checkpoint to serve.
MODEL_ID = "Qwen/Qwen-VL-Chat"

# Load processor and model once at startup. trust_remote_code is mandatory
# for Qwen-VL (the repo ships custom modelling code); device_map="auto"
# spreads the weights over the available device(s).
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    token=HF_TOKEN,
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    token=HF_TOKEN,
)
# Inference only — disable dropout and other train-time behavior.
model.eval()
 
18
 
19
# Inference entry point used by the Gradio UI.
def ask(image, prompt):
    """Answer *prompt* about *image* with Qwen-VL-Chat.

    Args:
        image: PIL.Image uploaded through the Gradio UI.
        prompt: the user's question (str).

    Returns:
        The decoded model output (str).
    """
    import tempfile

    # BUGFIX: Qwen-VL's from_list_format expects an image *path or URL*,
    # not a PIL.Image object — passing the PIL object breaks the custom
    # tokenizer. Persist the upload to a temporary PNG first.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image_path = tmp.name
    image.save(image_path)
    try:
        query = processor.from_list_format([
            {"image": image_path},
            {"text": prompt},
        ])
        inputs = processor(query, return_tensors="pt").to(model.device)
        # No gradients are needed at inference time — saves memory.
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=512)
        return processor.batch_decode(outputs, skip_special_tokens=True)[0]
    finally:
        # Remove the temp file so repeated requests don't accumulate files.
        os.remove(image_path)
29
+
30
# Gradio page wiring: image + question in, answer text out.
demo = gr.Interface(
    fn=ask,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="请输入问题")],
    outputs="text",
    # BUGFIX: the served checkpoint is Qwen-VL-Chat, not "Qwen1.5-VL-Chat";
    # keep the user-facing title consistent with MODEL_ID.
    title="Qwen-VL-Chat 在线体验",
)

if __name__ == "__main__":
    demo.launch()