import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel


class QwenClassifier(PreTrainedModel):
    """Sequence classifier: a Qwen backbone with a linear head over pooled hidden states.

    The backbone checkpoint name is read from ``config.model_name`` (stored by
    :meth:`from_pretrained`); the head maps the pooled representation to
    ``config.num_labels`` logits.
    """

    def __init__(self, config):
        super().__init__(config)
        # Load the Qwen backbone named on the config (set in from_pretrained).
        self.qwen_model = AutoModel.from_pretrained(config.model_name)
        self.classifier = nn.Linear(self.qwen_model.config.hidden_size, config.num_labels)
        # BUG FIX: this was None, so forward(..., labels=...) raised
        # "TypeError: 'NoneType' object is not callable". Default to the
        # standard multi-class classification loss over raw logits.
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the backbone, pool, and classify.

        Args:
            input_ids: token id tensor, (batch, seq).
            attention_mask: optional 0/1 mask, (batch, seq). When given,
                padding positions are excluded from the mean pooling.
            labels: optional class-index targets for loss computation.

        Returns:
            ``logits`` of shape (batch, num_labels), or ``(loss, logits)``
            when ``labels`` is provided.
        """
        outputs = self.qwen_model(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state  # (batch, seq, hidden)
        if attention_mask is not None:
            # BUG FIX: plain .mean(dim=1) averaged padding tokens into the
            # representation. Use a masked mean over real tokens only.
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        else:
            pooled = hidden.mean(dim=1)
        logits = self.classifier(pooled)
        if labels is not None:
            loss = self.loss_fn(logits, labels)
            return loss, logits
        return logits

    @classmethod
    def from_pretrained(cls, model_name, **kwargs):
        """Alternate constructor: build from a hub checkpoint name.

        Stores ``model_name`` on the config so ``__init__`` knows which
        backbone to load, then defers to ``PreTrainedModel.from_pretrained``.

        NOTE(review): the parent call makes ``cls(config)`` load the backbone
        again and then restores the *base-model* checkpoint, which has no
        weights for ``classifier`` — presumably loaded with
        ``ignore_mismatched_sizes``/strict=False semantics; confirm this is
        the intended initialization path.
        """
        config = AutoConfig.from_pretrained(model_name)
        config.model_name = model_name  # consumed by __init__ to pick the backbone
        return super().from_pretrained(model_name, config=config, **kwargs)