pentarosarium commited on
Commit
412ee33
·
1 Parent(s): 2a0d401
Files changed (1) hide show
  1. app.py +100 -5
app.py CHANGED
@@ -44,13 +44,108 @@ class ProcessControl:
44
  class EventDetector:
45
  def __init__(self):
46
  self.model_name = "google/mt5-small"
47
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
 
 
 
48
  self.model = None
49
  self.finbert = None
50
  self.roberta = None
51
  self.finbert_tone = None
52
  self.control = ProcessControl()
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def get_sentiment_label(self, result):
55
  """Helper method for sentiment classification"""
56
  label = result['label'].lower()
@@ -72,9 +167,9 @@ class EventDetector:
72
 
73
  try:
74
  inputs = [truncated_text]
75
- finbert_result = self.finbert(inputs, truncation=True, max_length=512)[0]
76
- roberta_result = self.roberta(inputs, truncation=True, max_length=512)[0]
77
- finbert_tone_result = self.finbert_tone(inputs, truncation=True, max_length=512)[0]
78
 
79
  results = [
80
  self.get_sentiment_label(finbert_result),
@@ -182,7 +277,7 @@ def create_interface():
182
  control = ProcessControl()
183
 
184
  with gr.Blocks(theme=gr.themes.Soft()) as app:
185
- gr.Markdown("# AI-анализ мониторинга новостей v.1.13")
186
 
187
  with gr.Row():
188
  file_input = gr.File(
 
44
  class EventDetector:
45
  def __init__(self):
46
  self.model_name = "google/mt5-small"
47
+ # Initialize tokenizer with legacy=True to suppress warning
48
+ self.tokenizer = AutoTokenizer.from_pretrained(
49
+ self.model_name,
50
+ legacy=True
51
+ )
52
  self.model = None
53
  self.finbert = None
54
  self.roberta = None
55
  self.finbert_tone = None
56
  self.control = ProcessControl()
57
 
58
+ @spaces.GPU
59
+ def initialize_models(self):
60
+ """Initialize all models with GPU support"""
61
+ try:
62
+ device = "cuda" if torch.cuda.is_available() else "cpu"
63
+ logger.info(f"Initializing models on device: {device}")
64
+
65
+ # Initialize MT5 model
66
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name).to(device)
67
+
68
+ # Initialize sentiment analysis pipelines
69
+ self.finbert = pipeline(
70
+ "sentiment-analysis",
71
+ model="ProsusAI/finbert",
72
+ device=device,
73
+ truncation=True,
74
+ max_length=512
75
+ )
76
+
77
+ self.roberta = pipeline(
78
+ "sentiment-analysis",
79
+ model="cardiffnlp/twitter-roberta-base-sentiment",
80
+ device=device,
81
+ truncation=True,
82
+ max_length=512
83
+ )
84
+
85
+ self.finbert_tone = pipeline(
86
+ "sentiment-analysis",
87
+ model="yiyanghkust/finbert-tone",
88
+ device=device,
89
+ truncation=True,
90
+ max_length=512
91
+ )
92
+
93
+ logger.info("All models initialized successfully")
94
+ return True
95
+
96
+ except Exception as e:
97
+ logger.error(f"Model initialization error: {str(e)}")
98
+ return False
99
+
100
+ @spaces.GPU
101
+ def detect_events(self, text, entity):
102
+ if not text or not entity:
103
+ return "Нет", "Invalid input"
104
+
105
+ try:
106
+ # Check if models are initialized
107
+ if self.model is None:
108
+ if not self.initialize_models():
109
+ return "Нет", "Model initialization failed"
110
+
111
+ # Truncate input text
112
+ text = text[:500]
113
+
114
+ prompt = f"""<s>Analyze the following news about {entity}:
115
+ Text: {text}
116
+ Task: Identify the main event type and provide a brief summary.</s>"""
117
+
118
+ inputs = self.tokenizer(
119
+ prompt,
120
+ return_tensors="pt",
121
+ padding=True,
122
+ truncation=True,
123
+ max_length=512
124
+ ).to(self.model.device)
125
+
126
+ outputs = self.model.generate(
127
+ **inputs,
128
+ max_length=300,
129
+ num_return_sequences=1,
130
+ pad_token_id=self.tokenizer.pad_token_id,
131
+ eos_token_id=self.tokenizer.eos_token_id
132
+ )
133
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
134
+
135
+ event_type = "Нет"
136
+ if any(term in text.lower() for term in ['отчет', 'выручка', 'прибыль', 'ebitda']):
137
+ event_type = "Отчетность"
138
+ elif any(term in text.lower() for term in ['облигаци', 'купон', 'дефолт']):
139
+ event_type = "РЦБ"
140
+ elif any(term in text.lower() for term in ['суд', 'иск', 'арбитраж']):
141
+ event_type = "Суд"
142
+
143
+ return event_type, response
144
+
145
+ except Exception as e:
146
+ logger.error(f"Event detection error: {str(e)}")
147
+ return "Нет", f"Error: {str(e)}"
148
+
149
  def get_sentiment_label(self, result):
150
  """Helper method for sentiment classification"""
151
  label = result['label'].lower()
 
167
 
168
  try:
169
  inputs = [truncated_text]
170
+ finbert_result = self.finbert(inputs)[0]
171
+ roberta_result = self.roberta(inputs)[0]
172
+ finbert_tone_result = self.finbert_tone(inputs)[0]
173
 
174
  results = [
175
  self.get_sentiment_label(finbert_result),
 
277
  control = ProcessControl()
278
 
279
  with gr.Blocks(theme=gr.themes.Soft()) as app:
280
+ gr.Markdown("# AI-а��ализ мониторинга новостей v.1.14")
281
 
282
  with gr.Row():
283
  file_input = gr.File(