# Source: TextClassifier/src/bert.py
# (ivanovot, commit b83e315 "init")
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
class Bert(nn.Module):
    """Text encoder wrapping a Hugging Face transformer and its tokenizer.

    Calling the module on raw text tokenizes it, runs the transformer, and
    returns the final-layer hidden state at sequence position 0 for every
    input. For BERT-style tokenizers that position is typically the [CLS]
    token (assumption: this depends on the checkpoint's tokenizer).

    Args:
        model_name: Checkpoint identifier or local path accepted by
            ``AutoModel.from_pretrained`` / ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, model_name):
        super().__init__()
        # Model and tokenizer are loaded from the same checkpoint so their
        # vocabularies agree.
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def forward(self, texts):
        """Encode ``texts`` into one vector per input.

        Args:
            texts: Whatever the tokenizer accepts (a string or a list of
                strings); it is passed through unchanged.

        Returns:
            ``last_hidden_state[:, 0]`` of the model output: a tensor with
            one row per input text.
        """
        batch = self.tokenizer(
            texts,
            return_tensors='pt',
            truncation=True,  # clip sequences to the tokenizer's max length
            padding=True,     # pad to the longest sequence in the batch
        )
        # The encoded tensors must live on the same device as the weights.
        batch = batch.to(self.model.device)
        hidden = self.model(**batch).last_hidden_state
        return hidden[:, 0]