Spaces:
Sleeping
Sleeping
File size: 1,071 Bytes
6a1e686 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel, AutoConfig
class QwenClassifier(PreTrainedModel):
    """Sequence classifier: a pretrained transformer backbone plus a linear head.

    The backbone is loaded with ``AutoModel.from_pretrained(config.model_name)``.
    Its final hidden states are mean-pooled over non-padding positions, and a
    single ``nn.Linear`` maps the pooled vector to ``config.num_labels`` raw
    logits.

    ``loss_fn`` may be assigned after construction to any callable taking
    ``(logits, labels)``. When it is left as ``None``, a default loss is chosen
    by ``_default_loss``.
    """

    def __init__(self, config):
        super().__init__(config)
        # config.model_name is attached by from_pretrained below; it names the
        # backbone checkpoint (presumably a Qwen model, per the class name).
        self.qwen_model = AutoModel.from_pretrained(config.model_name)
        self.classifier = nn.Linear(
            self.qwen_model.config.hidden_size, config.num_labels
        )
        # Optional custom loss; None means "use _default_loss".
        self.loss_fn = None

    @staticmethod
    def _mean_pool(hidden, attention_mask):
        """Average ``hidden`` over dim 1, counting only positions where the mask is nonzero.

        With ``attention_mask=None`` every position is averaged.
        """
        if attention_mask is None:
            return hidden.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # Count real tokens per row in the mask's own dtype (exact for integer
        # masks), then clamp so an all-zero row does not divide by zero.
        counts = attention_mask.sum(dim=1, keepdim=True).clamp(min=1)
        return summed / counts.to(hidden.dtype)

    def _default_loss(self, logits, labels):
        """Compute a loss when no custom ``loss_fn`` has been assigned.

        Follows the problem-type dispatch of Hugging Face's
        ``*ForSequenceClassification`` heads. ``config.problem_type`` is used
        when set; otherwise the problem type is inferred:

        - one output -> regression (MSE)
        - integer labels -> single-label classification (cross-entropy)
        - anything else -> multi-label classification (BCE with logits)

        Raises:
            ValueError: if ``config.problem_type`` is set to an unknown value.
        """
        num_labels = self.classifier.out_features
        problem_type = getattr(self.config, "problem_type", None)
        if problem_type is None:
            if num_labels == 1:
                problem_type = "regression"
            elif not labels.is_floating_point():
                problem_type = "single_label_classification"
            else:
                problem_type = "multi_label_classification"

        if problem_type == "regression":
            if num_labels == 1:
                return nn.functional.mse_loss(
                    logits.squeeze(-1), labels.view(-1).to(logits.dtype)
                )
            return nn.functional.mse_loss(logits, labels.to(logits.dtype))
        if problem_type == "single_label_classification":
            return nn.functional.cross_entropy(
                logits.view(-1, num_labels), labels.view(-1)
            )
        if problem_type == "multi_label_classification":
            return nn.functional.binary_cross_entropy_with_logits(
                logits, labels.to(logits.dtype)
            )
        raise ValueError(f"Unsupported problem_type: {problem_type!r}")

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the backbone and the classification head.

        Args:
            input_ids: token ids, passed straight to the backbone.
            attention_mask: backbone attention mask. Its nonzero entries also
                select the positions included in the mean pool. ``None``
                averages over every position.
            labels: optional targets. When given, a loss is computed with
                ``self.loss_fn``, or with ``_default_loss`` if that is ``None``.

        Returns:
            ``(loss, logits)`` if ``labels`` is not ``None``, otherwise ``logits``.
            Logits are raw (no sigmoid/softmax), so pair them with a
            logits-based loss such as ``nn.BCEWithLogitsLoss``.
        """
        outputs = self.qwen_model(input_ids=input_ids, attention_mask=attention_mask)
        # Assumes last_hidden_state is (batch, seq, hidden), the HF convention.
        pooled = self._mean_pool(outputs.last_hidden_state, attention_mask)
        logits = self.classifier(pooled)
        if labels is None:
            return logits
        loss_fn = self.loss_fn if self.loss_fn is not None else self._default_loss
        return loss_fn(logits, labels), logits

    @classmethod
    def from_pretrained(cls, model_name, *model_args, **kwargs):
        """Load a classifier whose config (and backbone) come from ``model_name``.

        ``model_name`` is recorded on the config as ``model_name`` so that
        ``__init__`` knows which backbone to load. A caller-supplied
        ``config=`` is used instead of loading one; it only receives
        ``model_name`` if it does not already have one.

        All other positional and keyword arguments are forwarded to
        ``PreTrainedModel.from_pretrained``.
        """
        config = kwargs.pop("config", None)
        if config is None:
            config = AutoConfig.from_pretrained(model_name)
            config.model_name = model_name
        elif getattr(config, "model_name", None) is None:
            config.model_name = model_name
        return super().from_pretrained(
            model_name, *model_args, config=config, **kwargs
        )
|