LayBraid committed
Commit ae92333 · Parent: b2c2198

:construction: update app

Files changed (2):
  1. requirements.txt +3 -1
  2. text_to_image.py +48 -1
requirements.txt CHANGED
@@ -1,3 +1,5 @@
 streamlit==1.2.0
 transformers~=4.19.4
-numpy~=1.22.2
+numpy~=1.22.2
+nmslib~=2.1.1
+Pillow~=9.0.1
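The two new pins support the retrieval UI added below: nmslib provides the approximate nearest-neighbour index over the image vectors, and Pillow loads the result images from disk. The numpy pin itself is unchanged (the -/+ pair looks like a whitespace-only rewrite of that line, likely a missing trailing newline). In a standard pip environment the additions are picked up with pip install -r requirements.txt; note that nmslib may need to compile native extensions on platforms without prebuilt wheels.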
text_to_image.py CHANGED
@@ -1,16 +1,63 @@
+import json
+import os
 import numpy as np
 import streamlit as st
+from PIL import Image
 from transformers import CLIPProcessor, FlaxCLIPModel
+import nmslib
+
+
+def load_index(image_vector_file):
+    filenames, image_vecs = [], []
+    fvec = open(image_vector_file, "r")
+    for line in fvec:
+        cols = line.strip().split(' ')
+        filename = cols[0]
+        image_vec = np.array([float(x) for x in cols[1].split(',')])
+        filenames.append(filename)
+        image_vecs.append(image_vec)
+    V = np.array(image_vecs)
+    index = nmslib.init(method='hnsw', space='cosinesimil')
+    index.addDataPointBatch(V)
+    index.createIndex({'post': 2}, print_progress=True)
+    return filenames, index
+
+
+def load_captions(caption_file):
+    image2caption = {}
+    with open(caption_file, "r") as fcap:
+        for line in fcap:
+            data = json.loads(line.strip())
+            filename = data["filename"]
+            captions = data["captions"]
+            image2caption[filename] = captions
+    return image2caption
 
 
 def get_image(text):
     model = FlaxCLIPModel.from_pretrained("flax-community/clip-rsicd-v2")
     processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+    filename, index = load_index("./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv")
+    image2caption = load_captions("./images/test-captions.json")
 
     inputs = processor(text=[text], image=None, return_tensors="jax", padding=True)
 
     vector = model.get_text_features(**inputs)
     vector = np.asarray(vector)
+    ids, distances = index.knnQuery(vector, k=10)
+    result_filenames = [filename[id] for id in ids]
+    for rank, (result_filename, score) in enumerate(zip(result_filenames, distances)):
+        caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - score)
+        col1, col2, col3 = st.columns([2, 10, 10])
+        col1.markdown("{:d}.".format(rank + 1))
+        col2.image(Image.open(os.path.join("./images", result_filename)),
+                   caption=caption)
+        caption_text = []
+        for caption in image2caption[result_filename]:
+            caption_text.append("* {:s}".format(caption))
+        col3.markdown("".join(caption_text))
+    st.markdown("---")
+    suggest_idx = -1
 
 
 def app():
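A note on the on-disk formats the two new loaders expect, as implied by the parsing code in this hunk: load_index reads one line per image of the form "<filename> <comma-separated floats>" (despite the .tsv name, the separator is a single space), and load_captions reads one JSON object per line with "filename" and "captions" fields. A minimal sketch that writes files in these layouts, with made-up names and values (real entries would carry 512-dimensional CLIP embeddings):

    import json

    # Illustration only: file names and numbers here are hypothetical, but the
    # layouts match what load_index and load_captions parse.
    with open("vectors/sample.tsv", "w") as fvec:
        # "<filename> <comma-separated vector>", one image per line
        fvec.write("airport_1.jpg 0.12,-0.53,0.88\n")

    with open("images/sample-captions.json", "w") as fcap:
        # JSON lines: one {"filename": ..., "captions": [...]} object per line
        fcap.write(json.dumps({
            "filename": "airport_1.jpg",
            "captions": ["many planes are parked next to a long building"],
        }) + "\n")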
@@ -20,5 +67,5 @@ def app():
     text = st.text_input("Enter text: ")
 
     if st.button("Search"):
-        st.image(get_image(text))
+        get_image(text)
 
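The retrieval path added to get_image is a standard nmslib HNSW round trip: build the index once over the image vectors, then look up the k nearest neighbours of the CLIP text embedding. A self-contained sketch of the same calls on random data (shapes and values are arbitrary):

    import numpy as np
    import nmslib

    # Index a toy batch of vectors, as load_index does for the image embeddings.
    vectors = np.random.rand(100, 512).astype(np.float32)
    index = nmslib.init(method='hnsw', space='cosinesimil')
    index.addDataPointBatch(vectors)
    index.createIndex({'post': 2}, print_progress=False)

    # Query with a single vector. The 'cosinesimil' space returns cosine
    # distances, which is why the app shows 1.0 - score as the similarity.
    query = np.random.rand(512).astype(np.float32)
    ids, distances = index.knnQuery(query, k=10)
    print(ids, distances)

With the vector and caption files in place, the app runs as a normal Streamlit script via streamlit run text_to_image.py. Rendering now happens inside get_image itself, which is why the caller drops st.image: the old get_image never returned a value, so st.image(get_image(text)) was passing None.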