Spaces:
Build error
Build error
File size: 3,342 Bytes
daf0288 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import json
import os
from pathlib import Path
from typing import Any, Callable, Literal, Optional, Union

import jsonlines
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset

from src.utils import bbox_augmentation_resize
class PubTables(Dataset):
    """PubTables-1M-Structure dataset.

    Sample names are taken from the ``.xml`` files found in
    ``root_dir/<split>``. For each sample the image is read from
    ``root_dir/images/<name>.jpg`` and, for the ``bbox``/``cell`` label types,
    the word annotations from ``root_dir/words/<name>_words.json``.

    What ``__getitem__`` returns depends on ``label_type``:

    * ``"image"``: the image, passed through ``transform`` when one is set.
    * contains ``"bbox"``: a dict with keys ``"filename"``, ``"image"`` and
      ``"bbox"``; ``"bbox"`` is one flat list holding the output of
      ``bbox_augmentation_resize`` for every valid word box.
    * contains ``"cell"``: a tuple ``(crops, records)`` with one image crop
      and one ``{"filename", "bbox_id", "cell"}`` dict per valid word.

    Args:
        root_dir: Dataset root containing the split, ``images`` and ``words``
            directories.
        label_type: Selects the kind of sample returned (see above).
        split: Name of the sub-directory listing the samples of this split.
        transform: Optional callable applied to the image (or to every crop
            for the ``cell`` label type). ``None`` leaves images untouched.
        cell_limit: Only words whose position in the annotation file is below
            this value are considered for the ``cell`` label type.
    """

    def __init__(
        self,
        root_dir: Union[Path, str],
        label_type: Literal["image", "cell", "bbox"],
        split: Literal["train", "val", "test"],
        transform: Optional[Callable[[Any], Any]] = None,
        cell_limit: int = 100,
    ) -> None:
        super().__init__()
        self.root_dir = Path(root_dir)
        self.split = split
        self.label_type = label_type
        self.transform = transform
        self.cell_limit = cell_limit

        # os.listdir returns entries in arbitrary order, so sort to keep the
        # index -> sample mapping reproducible. Entries that are not .xml
        # annotations (e.g. OS metadata files) are not samples and are skipped.
        self.image_list = sorted(
            Path(entry).stem
            for entry in os.listdir(self.root_dir / self.split)
            if entry.endswith(".xml")
        )

    def __len__(self) -> int:
        """Return the number of samples in the split."""
        return len(self.image_list)

    def _load_words(self, name: str) -> list:
        """Load the word annotations (a JSON array) of sample *name*."""
        words_path = os.path.join(self.root_dir, "words", name + "_words.json")
        with open(words_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def __getitem__(self, index: int) -> Any:
        """Return the sample at *index* in the form selected by ``label_type``.

        Raises:
            ValueError: If ``label_type`` is none of the supported kinds.
        """
        name = self.image_list[index]
        img = Image.open(os.path.join(self.root_dir, "images", name + ".jpg"))

        if self.label_type == "image":
            return self.transform(img) if self.transform else img

        if "bbox" in self.label_type:
            img_size = img.size  # PIL reports (width, height)
            width, height = img_size
            if self.transform:
                img = self.transform(img)
            # The target size is the last axis of the transformed image
            # (presumably the width of a (..., H, W) tensor -- confirm against
            # the transforms in use). A PIL image has no ``shape``, so fall
            # back to its width instead of failing when no tensor is produced.
            shape = getattr(img, "shape", None)
            tgt_size = shape[-1] if shape is not None else img.size[0]

            coords = []
            for word in self._load_words(name):
                if "bbox" not in word:
                    continue
                box = word["bbox"]
                # Skip degenerate boxes: both extents must be strictly positive.
                if not (box[2] > box[0] and box[3] > box[1]):
                    continue
                # Clamp the box to the original image before rescaling it.
                clamped = [
                    min(max(box[0], 0), width),
                    min(max(box[1], 0), height),
                    min(max(box[2], 0), width),
                    min(max(box[3], 0), height),
                ]
                coords.extend(
                    bbox_augmentation_resize(clamped, img_size, tgt_size)
                )
            return {"filename": name, "image": img, "bbox": coords}

        if "cell" in self.label_type:
            width, height = img.size
            # cell_limit is compared with the word's position in the file, so
            # words rejected by the other checks still count towards it.
            cells = [
                (word["bbox"], word["text"])
                for position, word in enumerate(self._load_words(name))
                if position < self.cell_limit
                and "bbox" in word
                and word["bbox"][0] < word["bbox"][2]
                and word["bbox"][1] < word["bbox"][3]
                and word["bbox"][0] >= 0
                and word["bbox"][1] >= 0
                and word["bbox"][2] < width
                and word["bbox"][3] < height
            ]
            crops = [img.crop(box) for box, _ in cells]
            if self.transform:
                # Same policy as the other label types: no transform means the
                # raw PIL crops are returned.
                crops = [self.transform(crop) for crop in crops]
            records = [
                {"filename": name, "bbox_id": cell_id, "cell": text}
                for cell_id, (_, text) in enumerate(cells)
            ]
            return crops, records

        raise ValueError(f"Unsupported label_type: {self.label_type!r}")
|