from collections import defaultdict
from itertools import chain
from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F

from marker.postprocessors.t5 import T5ForTokenClassification, byt5_tokenize
from marker.settings import settings


def get_batch_size():
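    """Pick a batch size for the editor model, preferring an explicit setting over the device default."""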
    if settings.EDITOR_BATCH_SIZE is not None:
        return settings.EDITOR_BATCH_SIZE
    elif settings.TORCH_DEVICE_MODEL == "cuda":
        return 12
    return 6


def load_editing_model():
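    """Load the T5 token-classification editor model, or return None if the editor is disabled."""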
    if not settings.ENABLE_EDITOR_MODEL:
        return None

    model = T5ForTokenClassification.from_pretrained(
        settings.EDITOR_MODEL_NAME,
        torch_dtype=settings.MODEL_DTYPE,
    ).to(settings.TORCH_DEVICE_MODEL)
    model.eval()

    # Label set for the token classification head; ids must match the training setup
    model.config.label2id = {
        "equal": 0,
        "delete": 1,
        "newline-1": 2,
        "space-1": 3,
    }
    model.config.id2label = {v: k for k, v in model.config.label2id.items()}

    return model


def edit_full_text(text: str, model: Optional[T5ForTokenClassification], batch_multiplier: int = 1) -> Tuple[str, Dict[str, int]]:
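    """Run the editor model over ``text`` and apply high-confidence token-level edits.

    Returns the edited text and a dict of edit counts keyed by label.
    """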
    if not model:
        return text, {}

    batch_size = get_batch_size() * batch_multiplier

    tokenized = byt5_tokenize(text, settings.EDITOR_MAX_LENGTH)
    input_ids = tokenized["input_ids"]
    char_token_lengths = tokenized["char_token_lengths"]

    # Run the model over the tokenized text in batches
    token_masks = []
    for i in range(0, len(input_ids), batch_size):
        batch_input_ids = torch.tensor(input_ids[i: i + batch_size], device=model.device)
        batch_attention_mask = torch.tensor(tokenized["attention_mask"][i: i + batch_size], device=model.device)

        with torch.inference_mode():
            predictions = model(batch_input_ids, attention_mask=batch_attention_mask)

        logits = predictions.logits.cpu()

        # If the max probability is below a threshold, assume the prediction is unreliable
        # and fall back to "equal"; we want to be conservative and not over-edit the text.
        probs = F.softmax(logits, dim=-1)
        max_prob = torch.max(probs, dim=-1)
        cutoff_prob = max_prob.values < settings.EDITOR_CUTOFF_THRESH
        labels = logits.argmax(-1)
        labels[cutoff_prob] = model.config.label2id["equal"]

        labels = labels.squeeze().tolist()
        # A single full-length sequence squeezes down to a flat list; re-wrap it
        if len(labels) == settings.EDITOR_MAX_LENGTH:
            labels = [labels]
        labels = list(chain.from_iterable(labels))
        token_masks.extend(labels)

    # Flat list of byte-level tokens for the whole text
    flat_input_ids = list(chain.from_iterable(input_ids))

    # Strip the special tokens 0 and 1; keep the unknown token, although it should never appear
    assert len(token_masks) == len(flat_input_ids)
    token_masks = [mask for mask, token in zip(token_masks, flat_input_ids) if token >= 2]

    # After stripping, there is exactly one mask per utf-8 byte of the original text
    assert len(token_masks) == len(text.encode("utf-8"))
    edit_stats = defaultdict(int)
    out_text = []
    start = 0
    for i, char in enumerate(text):
        char_token_length = char_token_lengths[i]
        masks = token_masks[start: start + char_token_length]
        labels = [model.config.id2label[mask] for mask in masks]

        if all(l == "delete" for l in labels):
            # Only honor delete predictions on whitespace; ignore them elsewhere
            if char.strip():
                out_text.append(char)
            else:
                edit_stats["delete"] += 1
        elif labels[0] == "newline-1":
            out_text.append("\n")
            out_text.append(char)
            edit_stats["newline-1"] += 1
        elif labels[0] == "space-1":
            out_text.append(" ")
            out_text.append(char)
            edit_stats["space-1"] += 1
        else:
            out_text.append(char)
            edit_stats["equal"] += 1

        start += char_token_length

    out_text = "".join(out_text)
    return out_text, edit_stats
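
# Example usage (a minimal sketch; assumes settings.ENABLE_EDITOR_MODEL is True and
# that the model named by settings.EDITOR_MODEL_NAME is available locally or on the hub):
#
#   model = load_editing_model()
#   edited_text, edit_stats = edit_full_text("Som e te xt with OCR artifacts", model)
#   print(edited_text, dict(edit_stats))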