File size: 267 Bytes
ac3fa58 |
1 2 3 4 5 6 7 8 9 10 |
import webdataset as wds
import io
import torch
# Inspect a WebDataset shard: for every sample, print the sample's keys and
# the shape of the tensor stored under the "prompt_embeds.pt" entry.
ds = wds.WebDataset("output.tar")
for s in ds:
    print(s.keys())
    # No .decode() stage is configured, so the entry is read as raw bytes
    # (presumably produced by torch.save); wrap them in a file-like buffer
    # so torch.load can deserialize them.
    prompt_embeds_bytes = s["prompt_embeds.pt"]
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered archive cannot run arbitrary code during load.
    # map_location="cpu" lets tensors that were saved from a GPU load on a
    # CPU-only machine; only the shape is inspected here.
    prompt_embeddings = torch.load(
        io.BytesIO(prompt_embeds_bytes),
        map_location="cpu",
        weights_only=True,
    )
    print(prompt_embeddings.shape)