File size: 4,238 Bytes
5caedb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import logging
from typing import Any, Dict

import numpy as np
import pandas as pd
import torch

from llm_studio.src.datasets.text_causal_language_modeling_ds import (
    CustomDataset as TextCausalLanguageModelingCustomDataset,
)
from llm_studio.src.utils.exceptions import LLMDataException

logger = logging.getLogger(__name__)


class CustomDataset(TextCausalLanguageModelingCustomDataset):
    """Causal-LM dataset specialized for classification.

    Extends the text causal language modeling dataset by parsing the answer
    column(s) into integer class labels, validating them against
    ``cfg.dataset.num_classes``, and attaching a ``class_label`` entry to each
    sample.
    """

    def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """Build the dataset and validate the integer labels.

        Args:
            df: Input DataFrame; must contain the configured answer column(s).
            cfg: Experiment config; reads ``cfg.dataset.answer_column``,
                ``cfg.dataset.num_classes`` and ``cfg.dataset.parent_id_column``.
            mode: Dataset split, e.g. "train" (forwarded to the parent class).

        Raises:
            LLMDataException: if labels are non-int, negative, exceed the
                configured number of classes, or a parent ID column is set.
        """
        super().__init__(df=df, cfg=cfg, mode=mode)
        # Fail fast if any answer value cannot be cast to int.
        check_for_non_int_answers(cfg, df)
        self.answers_int = df[cfg.dataset.answer_column].astype(int).values
        max_value = np.max(self.answers_int)
        min_value = np.min(self.answers_int)

        # Multiclass (num_classes > 1): labels must lie in [0, num_classes).
        if 1 < cfg.dataset.num_classes <= max_value:
            raise LLMDataException(
                "Number of classes is smaller than max label "
                f"{max_value}. Please increase the setting accordingly."
            )
        # Binary (num_classes == 1): only labels 0 and 1 are valid.
        elif cfg.dataset.num_classes == 1 and max_value > 1:
            raise LLMDataException(
                "For binary classification, max label should be 1 but is "
                f"{max_value}."
            )
        if min_value < 0:
            raise LLMDataException(
                "Labels should be non-negative but min label is " f"{min_value}."
            )
        # Non-contiguous labels (gaps, or not starting at 0) are tolerated but
        # likely a data issue, so only warn.
        if min_value != 0 or max_value != np.unique(self.answers_int).size - 1:
            logger.warning(
                "Labels should start at 0 and be continuous but are "
                f"{sorted(np.unique(self.answers_int))}."
            )

        # NOTE: the config stores the literal string "None" as the unset
        # sentinel for this column.
        if cfg.dataset.parent_id_column != "None":
            raise LLMDataException(
                "Parent ID column is not supported for classification datasets."
            )

    def __getitem__(self, idx: int) -> Dict:
        """Return the parent sample augmented with its integer class label."""
        sample = super().__getitem__(idx)
        sample["class_label"] = self.answers_int[idx]
        return sample

    def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict:
        """Convert raw logits into probabilities, predictions and text.

        Adds ``probabilities``, ``predictions`` and ``predicted_text`` to
        ``output`` before delegating to the parent implementation.
        """
        # Cast up front; logits may arrive in half precision.
        output["logits"] = output["logits"].float()

        # Softmax for mutually-exclusive classes, sigmoid otherwise
        # (e.g. BCE-style losses / multilabel setups).
        if cfg.training.loss_function == "CrossEntropyLoss":
            output["probabilities"] = torch.softmax(output["logits"], dim=-1)
        else:
            output["probabilities"] = torch.sigmoid(output["logits"])

        if len(cfg.dataset.answer_column) == 1:
            if cfg.dataset.num_classes == 1:
                # Binary: threshold the single probability at 0.5.
                output["predictions"] = (output["probabilities"] > 0.5).long()
            else:
                # Multiclass: pick the highest-probability class.
                output["predictions"] = output["probabilities"].argmax(
                    dim=-1, keepdim=True
                )
        else:
            # Multiple answer columns: independent per-label thresholding.
            output["predictions"] = (output["probabilities"] > 0.5).long()

        # Serialize per-class probabilities as comma-joined strings, one row
        # per sample, rounded to 3 decimals.
        preds = []
        for col in np.arange(output["probabilities"].shape[1]):
            preds.append(
                np.round(output["probabilities"][:, col].cpu().numpy(), 3).astype(str)
            )
        preds = [",".join(pred) for pred in zip(*preds)]
        output["predicted_text"] = preds
        return super().postprocess_output(cfg, df, output)

    def clean_output(self, output, cfg):
        """No-op: classification outputs need no text cleaning."""
        return output

    @classmethod
    def sanity_check(cls, df: pd.DataFrame, cfg: Any, mode: str = "train"):
        """Validate that all answer columns exist, are complete, and are int-castable."""

        for answer_col in cfg.dataset.answer_column:
            assert answer_col in df.columns, (
                f"Answer column {answer_col} not found in the " f"{mode} DataFrame."
            )
            assert df.shape[0] == df[answer_col].dropna().shape[0], (
                f"The {mode} DataFrame"
                f" column {answer_col}"
                " contains missing values."
            )

        check_for_non_int_answers(cfg, df)


def check_for_non_int_answers(cfg, df):
    """Raise ``LLMDataException`` if any answer column holds non-int values.

    Scans every configured answer column and collects values that cannot be
    cast to ``int``; the exception message shows up to five offending samples.
    """
    bad_values: list = [
        value
        for column in cfg.dataset.answer_column
        for value in df[column].values
        if not is_castable_to_int(value)
    ]
    if bad_values:
        raise LLMDataException(
            f"Column {cfg.dataset.answer_column} contains non int items. "
            f"Sample values: {bad_values[:5]}."
        )


def is_castable_to_int(s):
    """Return True if ``s`` can be converted to ``int``, else False.

    Fix: also catch ``TypeError`` — ``int(None)`` / ``int([])`` raise
    TypeError, not ValueError, and previously crashed the caller instead of
    being reported as non-int values.
    """
    try:
        int(s)
        return True
    except (ValueError, TypeError):
        return False