import torch.nn as nn
from transformers import AutoModel, PreTrainedModel, AutoConfig

class QwenClassifier(PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # Backbone: the pretrained Qwen model named in config.model_name
        self.qwen_model = AutoModel.from_pretrained(config.model_name)

        # Linear classification head over the pooled hidden states
        self.classifier = nn.Linear(self.qwen_model.config.hidden_size, config.num_labels)
        # Assumption: multi-label targets (suggested by the commented-out sigmoid in forward),
        # hence BCEWithLogitsLoss; swap in nn.CrossEntropyLoss() for single-label classification.
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.qwen_model(input_ids=input_ids, attention_mask=attention_mask)
        # Mean-pool the token representations into one vector per sequence
        pooled = outputs.last_hidden_state.mean(dim=1)
        logits = self.classifier(pooled)
        #logits = nn.functional.sigmoid(logits)

        if labels is not None:
            loss = self.loss_fn(logits, labels)
            return loss, logits
        return logits

    @classmethod
    def from_pretrained(cls, model_name):
        config = AutoConfig.from_pretrained(model_name)
        config.model_name = model_name  # Store the backbone name so __init__ knows what to load
        return super().from_pretrained(model_name, config=config)
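
# The block below is a minimal usage sketch, not part of the original file. It assumes
# "Qwen/Qwen2-0.5B" as a stand-in checkpoint name and the matching AutoTokenizer; swap in
# whichever Qwen checkpoint the project actually uses.
if __name__ == "__main__":
    import torch
    from transformers import AutoTokenizer

    model_name = "Qwen/Qwen2-0.5B"  # assumed checkpoint name for illustration only
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = QwenClassifier.from_pretrained(model_name)

    batch = tokenizer(["an example sentence"], return_tensors="pt", padding=True)
    with torch.no_grad():
        logits = model(batch["input_ids"], batch["attention_mask"])
    print(logits.shape)  # -> torch.Size([1, config.num_labels])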