from collections import defaultdict
from itertools import chain
from typing import Optional, Tuple

import torch
import torch.nn.functional as F

from marker.postprocessors.t5 import T5ForTokenClassification, byt5_tokenize
from marker.settings import settings


def get_batch_size():
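    """Pick the batch size for the editor model: the configured
    EDITOR_BATCH_SIZE override if set, otherwise a larger default
    on CUDA than on CPU/MPS."""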
    if settings.EDITOR_BATCH_SIZE is not None:
        return settings.EDITOR_BATCH_SIZE
    elif settings.TORCH_DEVICE_MODEL == "cuda":
        return 12
    return 6


def load_editing_model():
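    """Load the ByT5 token-classification editor model, or return None
    when the editor is disabled in settings.

    The label map assigns one action per byte token: keep it ("equal"),
    drop it ("delete"), or insert a newline ("newline-1") or space
    ("space-1") before it.
    """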
    if not settings.ENABLE_EDITOR_MODEL:
        return None

    model = T5ForTokenClassification.from_pretrained(
        settings.EDITOR_MODEL_NAME,
        torch_dtype=settings.MODEL_DTYPE,
    ).to(settings.TORCH_DEVICE_MODEL)
    model.eval()

    model.config.label2id = {
        "equal": 0,
        "delete": 1,
        "newline-1": 2,
        "space-1": 3,
    }
    model.config.id2label = {v: k for k, v in model.config.label2id.items()}
    return model


def edit_full_text(text: str, model: Optional[T5ForTokenClassification], batch_multiplier: int = 1) -> Tuple[str, dict]:
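    """Run the editor model over `text` and apply its per-byte edits.

    Returns the edited text and a dict counting how often each edit
    label was applied. When no model is supplied, the text is returned
    unchanged.
    """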
    if not model:
        return text, {}

    batch_size = get_batch_size() * batch_multiplier
    tokenized = byt5_tokenize(text, settings.EDITOR_MAX_LENGTH)
    input_ids = tokenized["input_ids"]
    char_token_lengths = tokenized["char_token_lengths"]
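    # byt5_tokenize splits the byte-level tokens into EDITOR_MAX_LENGTH-sized
    # chunks; char_token_lengths records how many byte tokens each character
    # of `text` produced, so predictions can be mapped back to characters.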

    # Run model
    token_masks = []
    for i in range(0, len(input_ids), batch_size):
        batch_input_ids = torch.tensor(input_ids[i: i + batch_size], device=model.device)
        batch_attention_mask = torch.tensor(tokenized["attention_mask"][i: i + batch_size], device=model.device)
        with torch.inference_mode():
            predictions = model(batch_input_ids, attention_mask=batch_attention_mask)

        logits = predictions.logits.cpu()

        # If the max probability is below a threshold, assume the prediction is bad.
        # We want to be conservative and not over-edit the text.
        probs = F.softmax(logits, dim=-1)
        max_prob = torch.max(probs, dim=-1)
        cutoff_prob = max_prob.values < settings.EDITOR_CUTOFF_THRESH
        labels = logits.argmax(-1)
        labels[cutoff_prob] = model.config.label2id["equal"]
        labels = labels.squeeze().tolist()
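        # squeeze() drops the batch dimension when the batch holds a single
        # row; re-wrap it so chain.from_iterable sees a list of rows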
        if len(labels) == settings.EDITOR_MAX_LENGTH:
            labels = [labels]
        labels = list(chain.from_iterable(labels))
        token_masks.extend(labels)

    # Flatten the batched token ids back into one per-byte sequence
    flat_input_ids = list(chain.from_iterable(input_ids))

    # Strip the special pad (0) and EOS (1) tokens; keep the unknown token (2), although it should never appear
    assert len(token_masks) == len(flat_input_ids)
    token_masks = [mask for mask, token in zip(token_masks, flat_input_ids) if token >= 2]

    assert len(token_masks) == len(text.encode("utf-8"))

    edit_stats = defaultdict(int)
    out_text = []
    start = 0
    for i, char in enumerate(text):
        char_token_length = char_token_lengths[i]
        masks = token_masks[start: start + char_token_length]
        labels = [model.config.id2label[mask] for mask in masks]
        if all(label == "delete" for label in labels):
            # Only honor deletions of whitespace; keep any other character to stay conservative
            if char.strip():
                out_text.append(char)
            else:
                edit_stats["delete"] += 1
        elif labels[0] == "newline-1":
            out_text.append("\n")
            out_text.append(char)
            edit_stats["newline-1"] += 1
        elif labels[0] == "space-1":
            out_text.append(" ")
            out_text.append(char)
            edit_stats["space-1"] += 1
        else:
            out_text.append(char)
            edit_stats["equal"] += 1

        start += char_token_length

    out_text = "".join(out_text)
    return out_text, edit_stats
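

# Minimal usage sketch (the sample text is hypothetical; assumes
# ENABLE_EDITOR_MODEL is set so load_editing_model() returns a model
# rather than None):
#
#   model = load_editing_model()
#   edited, stats = edit_full_text("Thetext has amissing space.", model)
#   print(edited)        # text with predicted space/newline insertions applied
#   print(dict(stats))   # per-label edit counts, e.g. {"equal": 25, "space-1": 2}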