# Author: KeivanR
# Revision: 6a1e686
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel, AutoConfig
class QwenClassifier(PreTrainedModel):
    """Sequence classifier: a pretrained ``AutoModel`` backbone plus a linear head.

    The backbone's ``last_hidden_state`` is mean-pooled over the sequence
    dimension (ignoring padding when an attention mask is given) and passed to
    ``self.classifier``, which outputs ``config.num_labels`` logits.

    ``self.loss_fn`` starts as ``None``. It must be assigned by the caller (any
    callable ``loss_fn(logits, labels)``) before ``forward`` is called with
    ``labels``.
    """

    def __init__(self, config):
        """Build the backbone and classification head.

        Args:
            config: Hugging Face config. It must carry ``model_name`` (the
                checkpoint passed to ``AutoModel.from_pretrained``) and
                ``num_labels`` (the output size of the linear head).
        """
        super().__init__(config)
        # NOTE(review): backbone weights are loaded here, inside __init__. When
        # the model is built through from_pretrained, the checkpoint's state
        # dict is loaded on top afterwards. Confirm this double load is intended.
        self.qwen_model = AutoModel.from_pretrained(config.model_name)  # Load Qwen model
        self.classifier = nn.Linear(self.qwen_model.config.hidden_size, config.num_labels)
        # Left unset on purpose; presumably assigned externally (e.g. a torch
        # loss module) before training. forward() checks for it.
        self.loss_fn = None

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Compute classification logits, plus the loss when labels are given.

        Args:
            input_ids: Token ids passed straight to the backbone.
            attention_mask: Optional mask passed to the backbone. Nonzero marks
                real tokens and zero marks padding. When given, padding
                positions are excluded from the mean pooling.
            labels: Optional targets. When given, ``self.loss_fn(logits, labels)``
                is computed.

        Returns:
            ``logits`` when ``labels`` is None, otherwise the tuple ``(loss, logits)``.

        Raises:
            ValueError: if ``labels`` is given but ``self.loss_fn`` was never set.
        """
        outputs = self.qwen_model(input_ids=input_ids, attention_mask=attention_mask)
        # Presumably shaped (batch, seq, hidden); dim 1 is the axis pooled over.
        hidden = outputs.last_hidden_state
        if attention_mask is None:
            pooled = hidden.mean(dim=1)
        else:
            # Masked mean: a plain mean over dim 1 would average padding
            # positions into the sentence representation.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            # clamp guards against division by zero for all-padding rows.
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        logits = self.classifier(pooled)
        if labels is None:
            return logits
        if self.loss_fn is None:
            raise ValueError(
                "QwenClassifier.loss_fn is not set; assign a loss callable "
                "before calling forward() with labels."
            )
        loss = self.loss_fn(logits, labels)
        return loss, logits

    @classmethod
    def from_pretrained(cls, model_name, *model_args, **kwargs):
        """Load a QwenClassifier from ``model_name`` (hub id or local path).

        The config is read from ``model_name``, or taken from a ``config``
        keyword when one is supplied. Either way it is tagged with
        ``config.model_name = model_name`` so that ``__init__`` knows which
        backbone to load. Any other positional or keyword arguments are
        forwarded to ``PreTrainedModel.from_pretrained``.
        """
        config = kwargs.pop("config", None)
        if config is None or isinstance(config, str):
            # A string config is treated as a path/id to load the config from.
            config = AutoConfig.from_pretrained(config or model_name)
        config.model_name = model_name  # Store model name
        return super().from_pretrained(model_name, *model_args, config=config, **kwargs)