Spaces:
Runtime error
Runtime error
File size: 14,311 Bytes
1bdaecc 9544646 4b9c9b6 9544646 5d16b15 9544646 4b9c9b6 c6b2f05 4b9c9b6 c6b2f05 4b9c9b6 7fb90e2 4b9c9b6 c6b2f05 4b9c9b6 c6b2f05 9544646 4b9c9b6 730ca01 4b9c9b6 7fb90e2 4b9c9b6 69177fb 7fb90e2 4b9c9b6 9544646 69177fb 9544646 69177fb 5d16b15 69177fb 4b9c9b6 69177fb 4b9c9b6 69177fb 4b9c9b6 9544646 69177fb 9544646 69177fb 5d16b15 69177fb 4b9c9b6 69177fb 4b9c9b6 69177fb 4b9c9b6 730ca01 4b9c9b6 69177fb 4b9c9b6 69177fb 7fb90e2 11d1f29 7fb90e2 c6b2f05 631c491 730ca01 7fb90e2 11d1f29 c6b2f05 40ba287 7fb90e2 76bab0f 7fb90e2 730ca01 11d1f29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 |
import gradio as gr
import json, time, torch
from transformers import BartTokenizer, BartForConditionalGeneration, AutoModel, AutoTokenizer
from webshop_lite import dict_to_fake_html
from predict_help import (
Page, convert_dict_to_actions, convert_html_to_text,
parse_results_amz, parse_item_page_amz,
parse_results_ws, parse_item_page_ws,
parse_results_ebay, parse_item_page_ebay,
WEBSHOP_URL, WEBSHOP_SESSION
)
# Environment names accepted by run_episode (compared after lower-casing).
ENVIRONMENTS = ['amazon', 'webshop', 'ebay']

# Choice-model checkpoint. Known alternatives:
#   IL+RL: 'webshop/il-rl-choice-bert-image_1'
#   IL:    'webshop/il-choice-bert-image_0'
BERT_MODEL_PATH = 'webshop/il-choice-bert-image_0'

# --- Imitation-learning models, loaded once at import time ---

# Seq2seq model that turns a shopping goal into a search query.
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
bart_model = BartForConditionalGeneration.from_pretrained('webshop/il_search_bart')

# Choice model that scores the valid click actions. Truncating from the left
# keeps the end of long observations when they exceed the tokenizer limit.
bert_tokenizer = AutoTokenizer.from_pretrained(
    'bert-base-uncased',
    truncation_side='left',
)
bert_tokenizer.add_tokens(
    ['[button]', '[button_]', '[clicked button]', '[clicked button_]'],
    special_tokens=True,
)
# NOTE(review): trust_remote_code=True executes model code fetched from the
# Hub checkpoint; only keep it for checkpoints you trust.
bert_model = AutoModel.from_pretrained(BERT_MODEL_PATH, trust_remote_code=True)
def process_str(s):
    """
    Normalize text before feeding it to the BERT choice model.

    Lower-cases the text, drops single and double quotes, trims surrounding
    whitespace, and restores the ``[SEP]`` marker to upper case (lower-casing
    would otherwise have turned it into ``[sep]``).
    """
    without_quotes = s.lower().translate(str.maketrans('', '', '"\''))
    return without_quotes.strip().replace('[sep]', '[SEP]')
def process_goal(state):
    """
    Extract the bare shopping instruction from a search-page observation.

    The observation looks like
    ``"Amazon Shopping Game\\nInstruction:<goal>\\n[button] search [button_]"``.
    This function lower-cases it, removes quotes, strips the header and the
    trailing search button, and drops any ``", and price lower than ..."``
    clause.

    Both the WebShop closing tag ``[button_]`` and the plain ``[button]``
    variant are stripped, so the button text never leaks into the query
    generator's input.

    Args:
        state: text observation of the search page.

    Returns:
        The cleaned, lower-cased goal string.
    """
    state = state.lower().replace('"', '').replace("'", "")
    state = state.replace('amazon shopping game\ninstruction:', '')
    for button_suffix in ('\n[button] search [button_]', '\n[button] search [button]'):
        state = state.replace(button_suffix, '')
    state = state.strip()
    # partition()[0] is the whole string when the clause is absent.
    return state.partition(', and price lower than')[0].strip()
def data_collator(batch):
    """
    Collate encoded samples into one tensor batch for the BERT choice model.

    Each sample is a dict holding a single state encoding plus a list of
    action encodings. The action encodings of all samples are flattened into
    one list. State and action token tensors are cut to the longest
    attention-mask length seen in the batch, which drops shared trailing
    padding.

    Args:
        batch: list of sample dicts with keys 'state_input_ids',
            'state_attention_mask', 'action_input_ids',
            'action_attention_mask', 'sizes', 'labels', 'images'.

    Returns:
        dict of torch tensors with the same keys.
    """
    state_ids = [sample['state_input_ids'] for sample in batch]
    state_mask = [sample['state_attention_mask'] for sample in batch]
    action_ids = [row for sample in batch for row in sample['action_input_ids']]
    action_mask = [row for sample in batch for row in sample['action_attention_mask']]
    sizes = [sample['sizes'] for sample in batch]
    labels = [sample['labels'] for sample in batch]
    images = [sample['images'] for sample in batch]

    # Longest real (unpadded) length across the batch.
    state_len = max(map(sum, state_mask))
    action_len = max(map(sum, action_mask))

    collated = {}
    collated['state_input_ids'] = torch.tensor(state_ids)[:, :state_len]
    collated['state_attention_mask'] = torch.tensor(state_mask)[:, :state_len]
    collated['action_input_ids'] = torch.tensor(action_ids)[:, :action_len]
    collated['action_attention_mask'] = torch.tensor(action_mask)[:, :action_len]
    collated['sizes'] = torch.tensor(sizes)
    collated['images'] = torch.tensor(images)
    collated['labels'] = torch.tensor(labels)
    return collated
def bart_predict(input):
    """
    Generate a search query for a goal string with the BART model.

    Runs beam search (5 beams, 5 returned sequences, max length 512) and
    returns the first decoded sequence with special tokens removed.
    """
    token_ids = bart_tokenizer(input)['input_ids']
    # Wrap in a list to add the batch dimension (batch of one).
    model_input = torch.tensor([token_ids])
    beams = bart_model.generate(
        model_input,
        max_length=512,
        num_return_sequences=5,
        num_beams=5,
    )
    decoded = bart_tokenizer.batch_decode(beams.tolist(), skip_special_tokens=True)
    return decoded[0]
def bert_predict(obs, info, softmax=True):
    """
    Pick one of the valid click actions with the BERT choice model.

    Args:
        obs: current text observation.
        info: dict with 'valid' (list of action strings, the first of which
            must be a click action) and 'image_feat' (an object with
            ``.tolist()``, e.g. a tensor -- TODO confirm for parsed pages).
        softmax: if True, sample an action from the softmax over the logits;
            otherwise take the arg-max.

    Returns:
        The chosen action string from ``info['valid']``.
    """
    candidates = info['valid']
    assert candidates[0].startswith('click[')

    tokenizer_kwargs = {'max_length': 512, 'truncation': True, 'padding': 'max_length'}
    state_enc = bert_tokenizer(process_str(obs), **tokenizer_kwargs)
    action_enc = bert_tokenizer([process_str(act) for act in candidates], **tokenizer_kwargs)

    sample = {
        'state_input_ids': state_enc['input_ids'],
        'state_attention_mask': state_enc['attention_mask'],
        'action_input_ids': action_enc['input_ids'],
        'action_attention_mask': action_enc['attention_mask'],
        'sizes': len(candidates),
        'images': info['image_feat'].tolist(),
        'labels': 0,
    }
    logits = bert_model(**data_collator([sample])).logits[0]

    if not softmax:
        return candidates[logits.argmax(0).item()]
    probs = torch.nn.functional.softmax(logits, dim=0)
    return candidates[torch.multinomial(probs, 1)[0].item()]
def _truncate_text(text, limit=100):
    """Return ``text`` cut to ``limit`` characters, adding "..." only if it was cut."""
    return text[:limit] + "..." if len(text) > limit else text


def get_return_value(env, asin, options, search_terms, page_num, product):
    """
    Build the three Gradio outputs that describe the product the agent chose.

    Args:
        env: 'webshop', 'ebay', or anything else (treated as Amazon); selects
            how the product URL is built.
        asin: product identifier used in the URL.
        options: dict of selected option name -> value.
        search_terms: query string (only used for the webshop URL).
        page_num: results page number (only used for the webshop URL).
        product: parsed product dict. 'asin', 'Title', 'Description',
            'BulletPoints' and 'MainImage' are read when present.

    Returns:
        A tuple ``(product_reduced, selected_options, html)``:
        - ``product_reduced``: the asin/Title/Description fields, with
          'BulletPoints' renamed to 'Features'. Long texts are cut to
          100 characters plus "...".
        - ``selected_options``: ``options``, or "None Selected" when empty.
        - ``html``: a small page with the product image and link. Every
          interpolated value is HTML-escaped. The webshop URL embeds
          ``json.dumps`` output whose double quotes would otherwise terminate
          the ``href`` attribute.
    """
    from html import escape

    # Determine the product URL for the environment.
    if env == 'webshop':
        query_str = "+".join(search_terms.split())
        options_str = json.dumps(options)
        asin_url = (
            f'{WEBSHOP_URL}/item_page/{WEBSHOP_SESSION}/'
            f'{asin}/{query_str}/{page_num}/{options_str}'
        )
    elif env == 'ebay':
        asin_url = f"https://www.ebay.com/itm/{asin}"
    else:
        asin_url = f"https://www.amazon.com/dp/{asin}"

    # Keep only the fields shown to the user; missing fields are skipped.
    wanted = ("asin", "Title", "Description", "BulletPoints")
    product_reduced = {k: v for k, v in product.items() if k in wanted}
    if "Description" in product_reduced:
        product_reduced["Description"] = _truncate_text(product_reduced["Description"])
    if "BulletPoints" in product_reduced:
        product_reduced["Features"] = _truncate_text(product_reduced.pop("BulletPoints"))

    # HTML with an optional image and a link to the product.
    safe_url = escape(asin_url)
    parts = ["<!DOCTYPE html><html><head><title>Chosen Product</title></head><body>"]
    main_image = product.get("MainImage")
    if main_image:
        parts.append(f'Product Image:<img src="{escape(str(main_image))}" height="50px" /><br>')
    parts.append(
        "Link to Product:\n"
        f'<a href="{safe_url}" style="color:blue;text-decoration:underline;" target="_blank">{safe_url}</a>\n'
        "</body></html>"
    )
    html_page = "".join(parts)

    selected_options = options if options else "None Selected"
    return product_reduced, selected_options, html_page
def predict(obs, info):
    """
    Predict the next action from a WebShop-style observation and info dict.

    If the first valid action is a click, the BERT choice model picks among
    the valid actions. Otherwise the page is the search page, and the BART
    model generates the query for a ``search[...]`` action.
    """
    on_search_page = not info['valid'][0].startswith('click[')
    if on_search_page:
        query = bart_predict(process_goal(obs))
        return "search[" + query + "]"
    return bert_predict(obs, info)
def run_episode(goal, env, verbose=True, max_steps=50):
    """
    Let the agent shop on a live site until it buys a product or hits the step cap.

    Each step does the following:
    1. Predict an action from the current text observation.
    2. Map (previous page type, action) to the next page type.
    3. Fetch and parse that page. Results pages are cached per
       (query, page number); item pages are cached per asin.
    4. Render the page back into a WebShop-style text observation plus the
       valid actions for the next prediction.

    Args:
        goal: natural-language shopping instruction.
        env: environment name, case-insensitive; must be in ENVIRONMENTS.
        verbose: print each action and timing information.
        max_steps: index of the last step. At most ``max_steps + 1`` actions
            are predicted before the current product is returned. The default
            keeps the original cap of 51 actions.

    Returns:
        The ``(product, options, html)`` triple produced by get_return_value.

    Raises:
        ValueError: ``env`` is not in ENVIRONMENTS.
        RuntimeError: the step cap was reached before any product page was parsed.
        Exception: the agent chose an action that cannot be mapped to a page.
    """
    env = env.lower()
    if env not in ENVIRONMENTS:
        # Fail fast. Previously this only printed, and the episode later
        # crashed with `data` unbound because no parser matched.
        raise ValueError(f"[ERROR] Environment {env} not recognized")

    # Search-page observation in the WebShop text format
    # ("[button] ... [button_]"), which process_goal knows how to strip.
    search_obs = "Amazon Shopping Game\nInstruction:" + goal + "\n[button] search [button_]"
    results_parsers = {
        'amazon': parse_results_amz,
        'webshop': parse_results_ws,
        'ebay': parse_results_ebay,
    }

    obs = search_obs
    info = {'valid': ['search[stuff]'], 'image_feat': torch.zeros(512)}
    product_map = {}            # asin -> parsed item page
    title_to_asin_map = {}      # result title -> asin
    search_results_cache = {}   # (search_terms, page_num) -> parsed results
    visited_asins, clicked_options = set(), set()
    sub_page_type, page_type, page_num = None, None, None
    search_terms, prod_title, asin = None, None, None
    options = {}

    for _ in range(max_steps + 1):
        action = predict(obs, info)
        if verbose:
            print("====")
            print(action)

        # Previous Page Type, Action -> Next Page Type
        action_content = action[action.find("[") + 1:action.find("]")]
        prev_page_type = page_type
        if action.startswith('search['):
            page_type = Page.RESULTS
            search_terms = action_content
            page_num = 1
        elif action.startswith('click['):
            if action.startswith('click[item -'):
                prod_title = action_content[len("item -"):].strip()
                if prod_title not in title_to_asin_map:
                    raise Exception("Product to click not found")
                asin = title_to_asin_map[prod_title]
                page_type = Page.ITEM_PAGE
                visited_asins.add(asin)
            elif any(x.value in action for x in [Page.DESC, Page.FEATURES, Page.REVIEWS]):
                page_type = Page.SUB_PAGE
                sub_page_type = Page(action_content.lower())
            elif action == 'click[< prev]':
                if sub_page_type is not None:
                    page_type, sub_page_type = Page.ITEM_PAGE, None
                elif prev_page_type == Page.ITEM_PAGE:
                    page_type = Page.RESULTS
                    options, clicked_options = {}, set()
                elif prev_page_type == Page.RESULTS and page_num > 1:
                    page_type = Page.RESULTS
                    page_num -= 1
            elif action == 'click[next >]':
                page_type = Page.RESULTS
                page_num += 1
            elif action.lower() == 'click[back to search]':
                page_type = Page.SEARCH
            elif action == 'click[buy now]':
                return get_return_value(env, asin, options, search_terms, page_num, product_map[asin])
            elif prev_page_type == Page.ITEM_PAGE:
                # Otherwise the click must be one of the product's option values.
                for opt_name, opt_values in product_map[asin]["options"].items():
                    if action_content in opt_values:
                        options[opt_name] = action_content
                        page_type = Page.ITEM_PAGE
                        clicked_options.add(action_content)
                        break
                else:
                    raise Exception("Unrecognized action: " + action)
        else:
            raise Exception("Unrecognized action: " + action)

        if verbose:
            print(f"Parsing {page_type.value} page...")

        # URL -> Real HTML -> Dict of Info
        if page_type == Page.RESULTS:
            # Key on the page number too; otherwise "next >" would keep
            # showing the cached first page.
            cache_key = (search_terms, page_num)
            if cache_key in search_results_cache:
                data = search_results_cache[cache_key]
                if verbose:
                    print(f"Loading cached results page {page_num} for \"{search_terms}\"")
            else:
                begin = time.time()
                data = results_parsers[env](search_terms, page_num, verbose)
                end = time.time()
                if verbose:
                    print(f"Parsing search results took {end-begin} seconds")
                search_results_cache[cache_key] = data
            for d in data:
                title_to_asin_map[d['Title']] = d['asin']
        elif page_type == Page.ITEM_PAGE or page_type == Page.SUB_PAGE:
            if asin in product_map:
                if verbose:
                    print("Loading cached item page for", asin)
                data = product_map[asin]
            else:
                begin = time.time()
                if env == 'amazon':
                    data = parse_item_page_amz(asin, verbose)
                elif env == 'webshop':
                    data = parse_item_page_ws(asin, search_terms, page_num, options, verbose)
                else:  # 'ebay' -- env was validated above
                    data = parse_item_page_ebay(asin, verbose)
                end = time.time()
                if verbose:
                    print("Parsing item page took", end-begin, "seconds")
                product_map[asin] = data
        elif page_type == Page.SEARCH:
            if verbose:
                print("Executing search")
            obs = search_obs
            info = {'valid': ['search[stuff]'], 'image_feat': torch.zeros(512)}
            continue
        else:
            raise Exception(f"Page of type `{page_type}` not found")

        # Dict of Info -> Fake HTML -> Text Observation
        begin = time.time()
        html_str = dict_to_fake_html(data, page_type, asin, sub_page_type, options, product_map, goal)
        obs = convert_html_to_text(html_str, simple=False, clicked_options=clicked_options, visited_asins=visited_asins)
        end = time.time()
        if verbose:
            print("[Page Info -> WebShop HTML -> Observation] took", end-begin, "seconds")

        # Dict of Info -> Valid Action State (Info)
        begin = time.time()
        prod_arg = product_map if page_type == Page.ITEM_PAGE else data
        info = convert_dict_to_actions(page_type, prod_arg, asin, page_num)
        end = time.time()
        if verbose:
            print("Extracting available actions took", end-begin, "seconds")

    # Step cap reached (also when the last step was a search). Previously a
    # search on the final step skipped the cap and could return None.
    if asin not in product_map:
        raise RuntimeError(f"Agent did not open any product within {max_steps + 1} steps")
    return get_return_value(env, asin, options, search_terms, page_num, product_map[asin])
# Gradio demo: a free-text goal and an environment choice go in; the chosen
# product, its selected options and a link page come out (run_episode).
# NOTE(review): gr.inputs / gr.outputs are the legacy component namespaces
# (deprecated in Gradio 3.x, removed in 4.x) -- confirm the pinned gradio
# version still provides them.
_EXAMPLES = [
    ["I want to find a gold floor lamp with a glass shade and a nickel finish that i can use for my living room, and price lower than 270.00 dollars", "Amazon"],
    ["I need some cute heart-shaped glittery cupcake picks as a gift to bring to a baby shower", "Amazon"],
    ["I want to buy ballet shoes which have rubber sole in grey suede color and a size of 6", "Amazon"],
    ["I would like a 7 piece king comforter set decorated with flowers and is machine washable", "Amazon"],
    ["I'm trying to find white bluetooth speakers that are not only water resistant but also come with stereo sound", "eBay"],
    ["find me the soy free 3.5 ounce 4-pack of dang thai rice chips, and make sure they are the aged cheddar flavor. i also need the ones in the resealable bags", "eBay"],
    ["I am looking for a milk chocolate of 1 pound size in a single pack for valentine day", "eBay"],
    ["I'm looking for a mini pc intel core desktop computer which supports with windows 11", "eBay"],
]
_ARTICLE = "<p style='padding-top:15px;text-align:center;'>To learn more about this project, check out the <a href='https://webshop-pnlp.github.io/' target='_blank'>project page</a>!</p>"
_DESCRIPTION = "<p style='text-align:center;'>Sim-to-real transfer of agent trained on WebShop to search a desired product on Amazon from any natural language query!</p>"

demo = gr.Interface(
    fn=run_episode,
    inputs=[
        gr.inputs.Textbox(lines=7, label="Input Text"),
        gr.inputs.Radio(['Amazon', 'eBay'], type="value", default="Amazon", label='Environment'),
    ],
    outputs=[
        gr.outputs.JSON(label="Selected Product"),
        gr.outputs.JSON(label="Selected Options"),
        gr.outputs.HTML(),
    ],
    examples=_EXAMPLES,
    title="WebShop",
    article=_ARTICLE,
    description=_DESCRIPTION,
)
demo.launch(inline=False)
|