ApsidalSolid4 committed on
Commit
3fe982e
·
verified ·
1 Parent(s): a627ffe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -26
app.py CHANGED
@@ -69,33 +69,35 @@ class TextClassifier:
69
  self.initialize_model()
70
 
71
  def initialize_model(self):
72
- """Initialize the model and tokenizer."""
73
- logger.info("Initializing model and tokenizer...")
74
-
75
- # Download and save tokenizer files locally
76
- local_tokenizer_path = "tokenizer"
77
- if not os.path.exists(local_tokenizer_path):
78
- AutoTokenizer.from_pretrained(self.model_name).save_pretrained(local_tokenizer_path)
79
-
80
- # Load from local files
81
- self.tokenizer = AutoTokenizer.from_pretrained(local_tokenizer_path)
82
-
83
- # First initialize the base model
84
- self.model = AutoModelForSequenceClassification.from_pretrained(
85
- self.model_name,
86
- num_labels=2
87
- ).to(self.device)
 
 
 
 
 
 
 
 
 
 
 
88
 
89
- # Look for model file in the same directory as the code
90
- model_path = "model.pt" # Your model file should be uploaded as model.pt
91
- if os.path.exists(model_path):
92
- logger.info(f"Loading custom model from {model_path}")
93
- checkpoint = torch.load(model_path, map_location=self.device)
94
- self.model.load_state_dict(checkpoint['model_state_dict'])
95
- else:
96
- logger.warning("Custom model file not found. Using base model.")
97
-
98
- self.model.eval()
99
 
100
  def predict_with_sentence_scores(self, text: str) -> Dict:
101
  """Predict with sentence-level granularity using overlapping windows."""
 
69
  self.initialize_model()
70
 
71
  def initialize_model(self):
72
+ """Initialize the model and tokenizer."""
73
+ logger.info("Initializing model and tokenizer...")
74
+
75
+ from transformers import DebertaV2TokenizerFast
76
+
77
+ # Try to load tokenizer directly from the Hub
78
+ self.tokenizer = DebertaV2TokenizerFast.from_pretrained(
79
+ self.model_name,
80
+ model_max_length=MAX_LENGTH,
81
+ use_fast=False,
82
+ from_slow=True
83
+ )
84
+
85
+ # Initialize the model as before
86
+ self.model = AutoModelForSequenceClassification.from_pretrained(
87
+ self.model_name,
88
+ num_labels=2
89
+ ).to(self.device)
90
+
91
+ # Your existing model loading code
92
+ model_path = "model.pt"
93
+ if os.path.exists(model_path):
94
+ logger.info(f"Loading custom model from {model_path}")
95
+ checkpoint = torch.load(model_path, map_location=self.device)
96
+ self.model.load_state_dict(checkpoint['model_state_dict'])
97
+ else:
98
+ logger.warning("Custom model file not found. Using base model.")
99
 
100
+ self.model.eval()
 
 
 
 
 
 
 
 
 
101
 
102
  def predict_with_sentence_scores(self, text: str) -> Dict:
103
  """Predict with sentence-level granularity using overlapping windows."""