sayakpaul's picture
sayakpaul HF Staff
Upload check_loading.py with huggingface_hub
ac3fa58 verified
raw
history blame contribute delete
267 Bytes
import webdataset as wds
import io
import torch
def check_loading(tar_path="output.tar", key="prompt_embeds.pt"):
    """Iterate a WebDataset tar shard and print each sample's keys and tensor shape.

    Each sample's ``key`` entry is treated as the raw bytes of an object written
    with ``torch.save``. The object is deserialized and its ``.shape`` is printed.

    Args:
        tar_path: Location of the tar shard, passed straight to ``wds.WebDataset``.
        key: Name of the sample field that holds the serialized tensor bytes.

    Raises:
        KeyError: If a sample has no ``key`` entry.
    """
    dataset = wds.WebDataset(tar_path)
    for sample in dataset:
        print(sample.keys())
        # Without a decoder the field comes back as bytes, so wrap it in a
        # file-like object for torch.load.
        # weights_only=True limits unpickling to tensors and plain containers.
        # A bare torch.load runs arbitrary pickle code, which is unsafe if the
        # tar comes from an untrusted source.
        prompt_embeddings = torch.load(io.BytesIO(sample[key]), weights_only=True)
        print(prompt_embeddings.shape)


if __name__ == "__main__":
    check_loading()