File size: 267 Bytes
ac3fa58
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
import webdataset as wds 
import io
import torch

def load_embeddings(raw: bytes) -> torch.Tensor:
    """Deserialize an object written with ``torch.save`` from its raw bytes.

    Args:
        raw: The serialized payload, e.g. one member of a WebDataset sample.

    Returns:
        The deserialized object, placed on the CPU. Presumably a tensor here;
        TODO confirm against the writer of the shard.
    """
    # ``torch.load`` unpickles its input. ``weights_only=True`` restricts it to
    # tensors and plain containers, so a tampered shard cannot execute code.
    # It also behaves the same on torch versions whose default differs.
    # ``map_location="cpu"`` lets tensors saved from a GPU load on CPU-only hosts.
    return torch.load(io.BytesIO(raw), map_location="cpu", weights_only=True)


def main(path: str = "output.tar", key: str = "prompt_embeds.pt") -> None:
    """Print each sample's keys and the shape of its ``key`` member.

    Args:
        path: Path or URL of the WebDataset tar shard to read.
        key: Name of the sample entry holding the ``torch.save``-serialized
            embeddings.
    """
    for sample in wds.WebDataset(path):
        print(sample.keys())
        raw = sample.get(key)
        if raw is None:
            # Report and continue instead of aborting the whole scan with KeyError.
            print(f"{sample.get('__key__', '?')}: no {key!r} entry, skipping")
            continue
        print(load_embeddings(raw).shape)


if __name__ == "__main__":
    import sys

    # Optional first CLI argument overrides the shard path.
    main(sys.argv[1] if len(sys.argv) > 1 else "output.tar")