# Spaces:
# Sleeping
# Sleeping
import torch.nn as nn | |
from transformers import AutoModel, PreTrainedModel, AutoConfig | |
class QwenClassifier(PreTrainedModel):
    """Sequence classifier: a pretrained backbone plus a linear head on mean-pooled states.

    The backbone is loaded with ``AutoModel.from_pretrained(config.model_name)``,
    so ``config`` must carry a ``model_name`` attribute (``from_pretrained`` below
    sets it). The head maps the backbone's ``hidden_size`` to ``config.num_labels``
    and ``forward`` returns raw logits (no sigmoid/softmax is applied).

    Attributes:
        qwen_model: Backbone module returned by ``AutoModel.from_pretrained``.
        classifier: ``nn.Linear(hidden_size, num_labels)`` classification head.
        loss_fn: Callable ``loss_fn(logits, labels)``. ``None`` by default; it
            must be assigned before calling ``forward`` with ``labels``.
    """

    def __init__(self, config):
        super().__init__(config)
        # NOTE(review): the backbone weights are fetched here during construction,
        # so from_pretrained() builds the model with them before loading its own
        # checkpoint; confirm that double load is intended.
        self.qwen_model = AutoModel.from_pretrained(config.model_name)
        self.classifier = nn.Linear(
            self.qwen_model.config.hidden_size, config.num_labels
        )
        # No default loss: the caller picks one matching the label format
        # (e.g. nn.CrossEntropyLoss() or nn.BCEWithLogitsLoss()) and assigns it.
        self.loss_fn = None

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the backbone, mean-pool over tokens, and classify.

        Args:
            input_ids: Token ids passed straight to the backbone.
            attention_mask: Optional mask passed to the backbone. When given,
                only its non-zero positions contribute to the pooled mean.
                Assumed 2-D ``(batch, seq_len)`` — TODO confirm for all callers.
            labels: Optional targets. When given, ``self.loss_fn(logits, labels)``
                is computed.

        Returns:
            ``(loss, logits)`` when ``labels`` is given, otherwise ``logits``.

        Raises:
            ValueError: if ``labels`` is given but ``self.loss_fn`` is ``None``.
        """
        outputs = self.qwen_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self._mean_pool(outputs.last_hidden_state, attention_mask)
        # Raw, unnormalised scores: pair with a logits-based loss
        # (CrossEntropy / BCEWithLogits) instead of applying sigmoid here.
        logits = self.classifier(pooled)
        if labels is None:
            return logits
        if self.loss_fn is None:
            raise ValueError(
                "labels were passed but loss_fn is None; assign a loss first, "
                "e.g. model.loss_fn = nn.CrossEntropyLoss()"
            )
        loss = self.loss_fn(logits, labels)
        return loss, logits

    @staticmethod
    def _mean_pool(hidden, attention_mask):
        """Average ``hidden`` over dim 1, excluding positions where the mask is 0."""
        if attention_mask is None:
            return hidden.mean(dim=1)
        # Add a trailing axis so the mask broadcasts across the hidden dimension.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # Count real tokens in the mask's own (integer) dtype so the count is
        # exact; clamp to 1 so an all-padding row yields zeros instead of NaN.
        counts = attention_mask.sum(dim=1, keepdim=True).clamp(min=1).to(hidden.dtype)
        return summed / counts

    @classmethod
    def from_pretrained(cls, model_name):
        """Build the classifier from a backbone checkpoint name or local path.

        Loads that checkpoint's config, records ``model_name`` on it so
        ``__init__`` knows which backbone to load, then defers to
        ``PreTrainedModel.from_pretrained`` with that config.
        """
        config = AutoConfig.from_pretrained(model_name)
        config.model_name = model_name  # read back by __init__
        return super().from_pretrained(model_name, config=config)