"""
2024-11-10 15:29:50
"""
import torch
import torch.nn as nn


class PatchEmbedding(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_dim, hidden_dim):
        super(PatchEmbedding, self).__init__()
        # Note: embed_dim is accepted for interface compatibility but unused;
        # the patch projection width is hidden_dim.
        self.patch_embed = nn.Conv2d(in_channels, hidden_dim, kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):
        # (batch_size, hidden_dim, H/P, W/P) -> (batch_size, num_patches, hidden_dim)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        return x
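
# Shape sketch (illustrative values, not from the original file): assuming
# img_size=224, patch_size=16, in_channels=3, hidden_dim=512, a batch of
# (B, 3, 224, 224) images becomes (B, 196, 512) token sequences:
#   embed = PatchEmbedding(224, 16, 3, embed_dim=512, hidden_dim=512)
#   tokens = embed(torch.randn(2, 3, 224, 224))  # torch.Size([2, 196, 512])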

class PositionalEncoding(nn.Module):
    def __init__(self, num_patches, embed_dim, hidden_dim):
        super(PositionalEncoding, self).__init__()
        # Learned positional embedding, broadcast over the batch dimension.
        # As in PatchEmbedding, embed_dim is accepted but unused; the encoding
        # width is hidden_dim.
        self.positional_encoding = nn.Parameter(torch.randn(1, num_patches, hidden_dim))

    def forward(self, x):
        return x + self.positional_encoding


class TransformerLayer(nn.Module):
    def __init__(self, hidden_dim, num_heads, mlp_dim, dropout_rate):
        super(TransformerLayer, self).__init__()
        # batch_first=True so attention accepts (batch, num_patches, hidden_dim)
        # inputs, matching the shape produced by PatchEmbedding; without it,
        # nn.MultiheadAttention would treat the batch dimension as the sequence.
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout_rate, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(mlp_dim, hidden_dim),
            nn.Dropout(dropout_rate)
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        # Post-norm residual blocks: self-attention, then the MLP.
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        x = self.norm2(x + self.mlp(x))
        return x
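
# Shape check sketch (assumed values): a TransformerLayer is shape-preserving,
#   layer = TransformerLayer(hidden_dim=512, num_heads=8, mlp_dim=2048, dropout_rate=0.1)
#   out = layer(torch.randn(2, 196, 512))  # torch.Size([2, 196, 512])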

# EvoViTModel: the full Vision Transformer assembled from the components above.
class EvoViTModel(nn.Module):
    def __init__(self, img_size, patch_size, in_channels, embed_dim, num_classes, hidden_dim):
        super(EvoViTModel, self).__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim, hidden_dim)
        self.position_encoding = PositionalEncoding(self.patch_embed.num_patches, embed_dim, hidden_dim)
        self.sigmoid = nn.Sigmoid()
        # Placeholder for dynamically generated init:
        # Transformer layer initialization (generated; every layer fixes
        # hidden_dim at 512, so the constructor should be called with
        # hidden_dim=512):
        self.transformer_layer_0 = TransformerLayer(num_heads=8, mlp_dim=2048, hidden_dim=512, dropout_rate=0.20362387412323335)
        self.transformer_layer_1 = TransformerLayer(num_heads=8, mlp_dim=3072, hidden_dim=512, dropout_rate=0.29859399476669696)
        self.transformer_layer_2 = TransformerLayer(num_heads=16, mlp_dim=4096, hidden_dim=512, dropout_rate=0.24029622136332746)
        self.transformer_layer_3 = TransformerLayer(num_heads=8, mlp_dim=2048, hidden_dim=512, dropout_rate=0.22640265738407994)
        self.transformer_layer_4 = TransformerLayer(num_heads=16, mlp_dim=3072, hidden_dim=512, dropout_rate=0.2969787366320388)
        self.transformer_layer_5 = TransformerLayer(num_heads=16, mlp_dim=2048, hidden_dim=512, dropout_rate=0.11264741089870321)
        self.transformer_layer_6 = TransformerLayer(num_heads=8, mlp_dim=4096, hidden_dim=512, dropout_rate=0.25324312813345734)
        self.transformer_layer_7 = TransformerLayer(num_heads=8, mlp_dim=2048, hidden_dim=512, dropout_rate=0.17729069086242882)
        self.transformer_layer_8 = TransformerLayer(num_heads=8, mlp_dim=2048, hidden_dim=512, dropout_rate=0.2531553780827078)
        self.transformer_layer_9 = TransformerLayer(num_heads=16, mlp_dim=2048, hidden_dim=512, dropout_rate=0.17372554665581236)
        self.transformer_layer_10 = TransformerLayer(num_heads=16, mlp_dim=3072, hidden_dim=512, dropout_rate=0.25217233180956183)
        self.transformer_layer_11 = TransformerLayer(num_heads=8, mlp_dim=4096, hidden_dim=512, dropout_rate=0.24459590331387862)
        self.transformer_layer_12 = TransformerLayer(num_heads=8, mlp_dim=2048, hidden_dim=512, dropout_rate=0.17589263405869232)
        # Use the constructor arguments rather than the hardcoded 512/48 of the
        # generated code (the generated layers still assume hidden_dim == 512).
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        # Cast inputs to the parameter dtype so the model also accepts, e.g.,
        # half- or double-precision tensors.
        expected_dtype = self.patch_embed.patch_embed.weight.dtype
        if x.dtype != expected_dtype:
            x = x.to(expected_dtype)

        x = self.patch_embed(x)
        x = self.position_encoding(x)
        # Placeholder for dynamically generated forward pass:
        x = self.transformer_layer_0(x)
        x = self.transformer_layer_1(x)
        x = self.transformer_layer_2(x)
        x = self.transformer_layer_3(x)
        x = self.transformer_layer_4(x)
        x = self.transformer_layer_5(x)
        x = self.transformer_layer_6(x)
        x = self.transformer_layer_7(x)
        x = self.transformer_layer_8(x)
        x = self.transformer_layer_9(x)
        x = self.transformer_layer_10(x)
        x = self.transformer_layer_11(x)
        x = self.transformer_layer_12(x)
        # No class token is prepended, so the first patch token serves as the
        # classification representation. Raw logits are returned; the
        # commented-out sigmoid suggests probabilities are computed outside
        # the model.
        x = self.classifier(x[:, 0])
        #probs = self.sigmoid(x)
        #return probs
        return x
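

# Minimal smoke test (a sketch: the img_size/patch_size/in_channels values
# below are assumptions, not taken from the original training setup; only
# hidden_dim=512 and the 48-way head are implied by the generated layers).
if __name__ == "__main__":
    model = EvoViTModel(img_size=224, patch_size=16, in_channels=3,
                        embed_dim=512, num_classes=48, hidden_dim=512)
    model.eval()  # disable dropout for a deterministic shape check
    dummy = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([2, 48])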