"""Print the keys and prompt-embedding shape of every sample in ``output.tar``."""

import io

import torch
import webdataset as wds

# Iterate the shard sample by sample; each sample is indexed by its file-name
# key below (e.g. "prompt_embeds.pt").
ds = wds.WebDataset("output.tar")
for sample in ds:
    print(sample.keys())

    # The entry is wrapped in BytesIO because torch.load expects a path or a
    # file-like object rather than a raw byte string.
    prompt_embeds_bytes = sample["prompt_embeds.pt"]

    # map_location="cpu" lets this run on a machine without the device the
    # embeddings were saved from; only the shape is inspected, so the device
    # does not matter here.
    # NOTE(review): torch.load unpickles its input -- only run this on shards
    # from a trusted source, or pass weights_only=True if the installed torch
    # version supports it.
    prompt_embeddings = torch.load(
        io.BytesIO(prompt_embeds_bytes),
        map_location="cpu",
    )

    # Presumably a tensor; anything without a ``.shape`` attribute raises here.
    print(prompt_embeddings.shape)