File size: 3,342 Bytes
daf0288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
from typing import Any, Literal, Union
from pathlib import Path
import jsonlines
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import numpy as np
import os
import json

from src.utils import bbox_augmentation_resize


class PubTables(Dataset):
    """PubTables-1M-Structure dataset.

    Files are looked up relative to ``root_dir``:

    * ``<root_dir>/<split>/``                  - listed to build the sample names
      (everything before the first ``".xml"`` in each entry name).
    * ``<root_dir>/images/<name>.jpg``         - the image of a sample.
    * ``<root_dir>/words/<name>_words.json``   - word annotations, read in the
      ``"bbox"`` and ``"cell"`` modes as a list of dicts with ``"bbox"`` and
      (for ``"cell"``) ``"text"`` entries.

    What ``__getitem__`` returns depends on ``label_type``:

    * ``"image"``: the (optionally transformed) image.
    * ``"bbox"``:  ``{"filename", "image", "bbox"}`` where ``"bbox"`` is one flat
      list of the values returned by ``bbox_augmentation_resize`` for every
      kept word box. This mode reads ``.shape[-1]`` of the transformed image,
      so a ``transform`` producing an object with ``.shape`` is required.
    * ``"cell"``:  ``(crops, metadata)`` - one image crop per kept word box and
      a parallel list of ``{"filename", "bbox_id", "cell"}`` dicts.

    Args:
        root_dir: Dataset root directory.
        label_type: Selects the kind of sample produced (see above).
        split: Name of the sub-directory of ``root_dir`` that is listed to
            enumerate the samples.
        transform: Optional callable applied to the image (or to every crop in
            ``"cell"`` mode).
        cell_limit: In ``"cell"`` mode only annotations whose position in the
            words file is below this value are considered.
    """

    def __init__(
        self,
        root_dir: Union[Path, str],
        label_type: Literal["image", "cell", "bbox"],
        split: Literal["train", "val", "test"],
        transform: Any = None,
        cell_limit: int = 100,
    ) -> None:
        super().__init__()

        self.root_dir = Path(root_dir)
        self.split = split
        self.label_type = label_type
        self.transform = transform
        self.cell_limit = cell_limit

        # os.listdir() yields entries in arbitrary, filesystem-dependent order;
        # sorting keeps the index -> sample mapping stable across runs/machines.
        self.image_list = sorted(
            entry.split(".xml")[0] for entry in os.listdir(self.root_dir / self.split)
        )

    def __len__(self) -> int:
        """Return the number of samples found in the split directory."""
        return len(self.image_list)

    def _load_words(self, name: str) -> Any:
        """Return the parsed content of ``words/<name>_words.json``."""
        words_path = os.path.join(self.root_dir, "words", name + "_words.json")
        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(words_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def __getitem__(self, index: int) -> Any:
        """Build the sample at ``index`` according to ``self.label_type``.

        Raises:
            ValueError: If ``label_type`` matches none of the supported modes.
        """
        name = self.image_list[index]
        img = Image.open(os.path.join(self.root_dir, "images", name + ".jpg"))

        if self.label_type == "image":
            if self.transform:
                img = self.transform(img)
            return img

        if "bbox" in self.label_type:
            # Size of the image before the transform (PIL ``Image.size``).
            img_size = img.size
            if self.transform:
                img = self.transform(img)
            # Last dimension of the transformed image; needs a ``.shape``
            # attribute, which the untransformed PIL image does not have.
            tgt_size = img.shape[-1]

            bboxes = []
            for word in self._load_words(name):
                if "bbox" not in word:
                    continue
                box = word["bbox"]
                # Skip boxes with non-positive width or height.
                if not (box[2] > box[0] and box[3] > box[1]):
                    continue
                # Clamp the corners into the original image extent.
                clamped = [
                    min(max(box[0], 0), img_size[0]),
                    min(max(box[1], 0), img_size[1]),
                    min(max(box[2], 0), img_size[0]),
                    min(max(box[3], 0), img_size[1]),
                ]
                # The helper's return values are flattened into a single list.
                bboxes.extend(bbox_augmentation_resize(clamped, img_size, tgt_size))

            return {"filename": name, "image": img, "bbox": bboxes}

        if "cell" in self.label_type:
            img_size = img.size

            # Keep only well-formed boxes lying strictly inside the image, and
            # only from the first ``cell_limit`` entries of the words file.
            bboxes_texts = [
                (word["bbox"], word["text"])
                for idx, word in enumerate(self._load_words(name))
                if idx < self.cell_limit
                and "bbox" in word
                and word["bbox"][0] < word["bbox"][2]
                and word["bbox"][1] < word["bbox"][3]
                and word["bbox"][0] >= 0
                and word["bbox"][1] >= 0
                and word["bbox"][2] < img_size[0]
                and word["bbox"][3] < img_size[1]
            ]

            img_bboxes = []
            for box, _text in bboxes_texts:
                crop = img.crop(box)
                # ``transform`` defaults to None; return the raw crop then,
                # mirroring the guard used in the "image" mode.
                img_bboxes.append(self.transform(crop) if self.transform else crop)

            text_bboxes = [
                {"filename": name, "bbox_id": i, "cell": text}
                for i, (_box, text) in enumerate(bboxes_texts)
            ]
            return img_bboxes, text_bboxes

        # Fail loudly instead of silently returning None for an unknown mode.
        raise ValueError(f"Unsupported label_type: {self.label_type!r}")