File size: 2,427 Bytes
cb3a670
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import json
import os
from typing import Dict, List, Sequence


def load_eurorad_dataset(
    dataset_path: str,
    section: str = "any",
    as_dict: bool = False,
    filter_by_caption: List[str] = [
        "xray",
        "x-ray",
        "x ray",
        "ray",
        "xr",
        "radiograph",
        "radiogram",
        "plain film",
    ],
) -> List[Dict] | Dict[str, Dict]:
    """
    Load a dataset from a JSON file.

    Args:
        dataset_path (str): Path to the JSON dataset file.
        section (str, optional): Section of the dataset to load. Defaults to "any".
        as_dict (bool, optional): Whether to return data as dict. Defaults to False.
        filter_by_caption (List[str], optional): List of strings to filter cases by caption content. Defaults to [].

    Returns:
        List[Dict] | Dict[str, Dict]: The loaded dataset as a list of dictionaries or dict if as_dict=True.

    Raises:
        FileNotFoundError: If dataset_path does not exist
        json.JSONDecodeError: If file is not valid JSON
    """

    with open(dataset_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    if filter_by_caption:
        filtered_data = {}
        for case_id, case in data.items():
            if any(
                any(x in subfig["caption"].lower() for x in filter_by_caption)
                for figure in case["figures"]
                for subfig in figure["subfigures"]
            ) or any(x in case["image_finding"].lower() for x in filter_by_caption):
                filtered_data[case_id] = case
        data = filtered_data

    if section != "any":
        section = section.strip().lower()
        if not as_dict:
            data = [
                item for item in data.values() if item.get("section", "").strip().lower() == section
            ]
        else:
            data = {
                k: v for k, v in data.items() if v.get("section", "").strip().lower() == section
            }

    elif not as_dict:
        data = list(data.values())

    return data


def save_dataset(dataset: Dict | List[Dict], dataset_path: str):
    """
    Save a dataset to a JSON file.

    Args:
        dataset (Dict | List[Dict]): The dataset to save as a dictionary or list of dictionaries.
        dataset_path (str): Path where the JSON dataset file will be saved.
    """
    with open(dataset_path, "w", encoding="utf-8") as file:
        json.dump(dataset, file)