import streamlit as st
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch
from datetime import datetime
import io
import base64
from typing import Dict, List, Set, Tuple
from rapidfuzz import fuzz, process
from collections import defaultdict
from tqdm import tqdm
import spacy
import torch.nn.functional as F


class NewsProcessor:
    def __init__(self, similarity_threshold=0.75, time_threshold=24):
        # Prefer the Russian pipeline; fall back to English if it is not installed
        try:
            self.nlp = spacy.load("ru_core_news_sm")
        except OSError:
            self.nlp = spacy.load("en_core_web_sm")

        import pymorphy2
        self.morph = pymorphy2.MorphAnalyzer()

        self.tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')
        self.model = AutoModel.from_pretrained('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')

        self.similarity_threshold = similarity_threshold
        self.time_threshold = time_threshold

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, ignoring padding positions via the attention mask."""
        token_embeddings = model_output[0]
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def encode_text(self, text):
        """Return an L2-normalized sentence embedding for a single text."""
        # Convert text to string and handle NaN values
        if pd.isna(text):
            text = ""
        else:
            text = str(text)

        encoded_input = self.tokenizer(text, padding=True, truncation=True, max_length=512, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        sentence_embeddings = self.mean_pooling(model_output, encoded_input['attention_mask'])
        return F.normalize(sentence_embeddings[0], p=2, dim=0).numpy()
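
    # Note: because encode_text() L2-normalizes its output, the plain dot product of two
    # embeddings in process_news() below is exactly their cosine similarity, so
    # similarity_threshold can be read as a cosine-similarity cutoff.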

    def get_company_variants(self, company_name: str) -> Set[str]:
        """Generate morphological variants of a company name."""
        if pd.isna(company_name):
            return set()

        # Clean the company name
        name = str(company_name).strip('"\'').strip()
        name = name.split(',')[0].strip()  # Take first part before comma

        variants = set()
        variants.add(name.lower())

        # Split into words and keep significant parts
        words = [w for w in name.split() if len(w) >= 3]

        # Generate morphological variants for each significant word
        for word in words:
            parsed = self.morph.parse(word)[0]
            lexeme = parsed.lexeme
            variants.update(v.word.lower() for v in lexeme)

        # Add combinations of consecutive words
        if len(words) > 1:
            for i in range(len(words) - 1):
                variants.add(f"{words[i]} {words[i+1]}".lower())

        return variants
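
    # Illustrative sketch (not exhaustive): for an input like 'Газпром нефть, ПАО' the
    # method keeps the part before the comma and, via pymorphy2, collects inflected forms
    # such as 'газпром', 'газпрома', 'нефть', 'нефти', ... plus the lowered full name and
    # the bigram 'газпром нефть'. The exact set depends on the pymorphy2 dictionaries installed.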

    def is_company_main_subject(self, title: str, text: str, company_name: str, threshold_score: float = 0.5) -> Tuple[bool, float]:
        """
        Enhanced company-subject detection using title and text, with Russian language support.
        Returns (is_main_subject, relevance_score).
        """
        if pd.isna(text) or pd.isna(company_name):
            return False, 0.0

        # Ensure we have strings
        title = str(title) if not pd.isna(title) else ""
        text = str(text) if not pd.isna(text) else ""

        # Get company name variants
        company_variants = self.get_company_variants(company_name)
        if not company_variants:
            return False, 0.0

        # Initialize scoring components
        title_score = 0.0
        first_para_score = 0.0
        subject_score = 0.0
        frequency_score = 0.0

        # Process title (weight: 0.4)
        title_doc = self.nlp(title.lower())
        title_text = title_doc.text
        for variant in company_variants:
            if variant in title_text:
                title_score = 0.4
                # Check if the company is the grammatical subject of the title
                for token in title_doc:
                    if variant in token.text and token.dep_ in ['nsubj', 'nsubjpass']:
                        title_score = 0.4
                        break
                break

        # Process main text
        doc = self.nlp(text.lower())
        paragraphs = [p.strip() for p in text.split('\n') if p.strip()]
        first_para = paragraphs[0] if paragraphs else ""

        # Check first paragraph (weight: 0.2)
        for variant in company_variants:
            if variant in first_para.lower():
                first_para_score = 0.2
                break

        # Analyze subject position and frequency
        company_mentions = 0
        subject_mentions = 0
        other_company_indicators = {
            'компания', 'корпорация', 'фирма', 'банк', 'группа', 'холдинг',
            'организация', 'предприятие', 'производитель', 'ао', 'оао', 'пао', 'нк', 'гк',
            'ооо', 'лк', 'фк', 'акб', 'ук', 'зао', 'ак'
        }
        other_companies = 0

        # Analyze each sentence
        for sent in doc.sents:
            sent_text = sent.text.lower()

            # Count company mentions and subject positions
            company_in_sent = False
            for variant in company_variants:
                if variant in sent_text:
                    company_mentions += 1
                    company_in_sent = True
                    # Check subject position
                    for token in sent:
                        if variant in token.text and token.dep_ in ['nsubj', 'nsubjpass']:
                            subject_mentions += 1

            # Count sentences that mention other companies but not this one
            if company_in_sent:
                continue
            for indicator in other_company_indicators:
                if indicator in sent_text:
                    other_companies += 1
                    break

        # Calculate subject score (weight: 0.2)
        subject_score = min(0.2, (subject_mentions / max(1, company_mentions)) * 0.2)

        # Calculate frequency score (weight: 0.2)
        if company_mentions > 0:
            company_ratio = company_mentions / (company_mentions + other_companies + 1)
            frequency_score = min(0.2, company_ratio * 0.2)

        # Calculate final score
        final_score = title_score + first_para_score + subject_score + frequency_score

        # Apply penalties
        if other_companies > 5:  # Too many other companies mentioned
            final_score *= 0.5

        # Check if the company is just part of a list
        list_indicators = {'среди', 'включая', 'такие как', 'в том числе', 'и другие', 'а также'}
        for indicator in list_indicators:
            if indicator in text.lower():
                final_score *= 0.7

        return final_score >= threshold_score, final_score
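
    # Minimal usage sketch (assumes a NewsProcessor instance `proc` and plain strings;
    # the weights are those used above: title 0.4, first paragraph 0.2, subject 0.2, frequency 0.2):
    #   is_main, score = proc.is_company_main_subject(title_str, body_str, "Газпром нефть")
    #   if is_main:
    #       ...  # treat the article as primarily about this company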

    def process_news(self, df: pd.DataFrame, progress_bar=None):
        # Ensure the DataFrame is not empty
        if df.empty:
            return pd.DataFrame(columns=['cluster_id', 'datetime', 'company', 'relevance_score', 'text', 'cluster_size'])

        df = df.copy()  # Make a copy to preserve original indices
        clusters = []
        processed = set()

        for idx in df.index:  # Iterate over original indices
            if idx in processed:
                continue

            row1 = df.loc[idx]
            cluster = [idx]  # Store original index
            processed.add(idx)

            if progress_bar:
                progress_bar.progress(len(processed) / len(df))

            # Only compare against other rows when this row has text to embed;
            # otherwise the row forms a singleton cluster.
            if not pd.isna(row1['text']):
                text1_embedding = self.encode_text(row1['text'])

                for other_idx in df.index:  # Iterate over original indices
                    if other_idx in processed:
                        continue

                    row2 = df.loc[other_idx]
                    if pd.isna(row2['text']):
                        continue

                    time_diff = pd.to_datetime(row1['datetime']) - pd.to_datetime(row2['datetime'])
                    if abs(time_diff.total_seconds() / 3600) > self.time_threshold:
                        continue

                    text2_embedding = self.encode_text(row2['text'])
                    # Embeddings are unit-length, so the dot product is the cosine similarity
                    similarity = np.dot(text1_embedding, text2_embedding)

                    if similarity >= self.similarity_threshold:
                        cluster.append(other_idx)
                        processed.add(other_idx)

            clusters.append(cluster)

        # Create result DataFrame preserving original indices
        result_data = []
        for cluster_id, cluster_indices in enumerate(clusters, 1):
            for idx in cluster_indices:
                result_data.append({
                    'cluster_id': cluster_id,
                    'datetime': df.loc[idx, 'datetime'],
                    'company': df.loc[idx, 'company'],
                    'text': df.loc[idx, 'text'],
                    'cluster_size': len(cluster_indices)
                })

        result_df = pd.DataFrame(result_data, index=sum(clusters, []))  # Use original indices
        return result_df
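

# Minimal end-to-end sketch of NewsProcessor outside Streamlit (illustrative only; the
# column names match what main() prepares, and this helper is never called by the app).
def _example_newsprocessor_usage():
    sample = pd.DataFrame({
        'company': ['Газпром', 'Газпром', 'Сбербанк'],
        'datetime': ['2024-01-01 10:00', '2024-01-01 11:30', '2024-01-02 09:00'],
        'title': ['t1', 't2', 't3'],
        'text': ['Газпром увеличил добычу газа.', 'Добыча газа Газпрома выросла.', 'Сбербанк отчитался о прибыли.'],
    })
    processor = NewsProcessor(similarity_threshold=0.75, time_threshold=24)
    clustered = processor.process_news(sample)
    # Rows that landed in the same cluster share a cluster_id and have cluster_size > 1
    return clustered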


class NewsDeduplicator:
    def __init__(self, fuzzy_threshold=85):
        self.fuzzy_threshold = fuzzy_threshold

    def deduplicate(self, df: pd.DataFrame, progress_bar=None) -> pd.DataFrame:
        seen_texts: List[str] = []
        text_to_companies: Dict[str, Set[str]] = defaultdict(set)
        indices_to_keep: Set[int] = set()

        for position, (idx, row) in enumerate(tqdm(df.iterrows(), total=len(df))):
            text = str(row['text']) if not pd.isna(row['text']) else ""
            company = str(row['company']) if not pd.isna(row['company']) else ""

            if not text:
                indices_to_keep.add(idx)
                continue

            if seen_texts:
                result = process.extractOne(
                    text,
                    seen_texts,
                    scorer=fuzz.ratio,
                    score_cutoff=self.fuzzy_threshold
                )
                match = result[0] if result else None
            else:
                match = None

            if match:
                # Near-duplicate: keep the first occurrence, remember the extra company
                text_to_companies[match].add(company)
            else:
                seen_texts.append(text)
                text_to_companies[text].add(company)
                indices_to_keep.add(idx)

            if progress_bar:
                progress_bar.progress((position + 1) / len(df))

        # Select kept rows by label (not position) so the original index is preserved
        dedup_df = df.loc[sorted(indices_to_keep)].copy()
        for idx in indices_to_keep:
            text = str(df.loc[idx, 'text'])
            if text not in text_to_companies:
                continue  # rows kept only because their text was empty
            companies = sorted(text_to_companies[text])
            dedup_df.at[idx, 'company'] = ' | '.join(companies)
            dedup_df.at[idx, 'company_count'] = len(companies)
            dedup_df.at[idx, 'duplicate_count'] = len(text_to_companies[text])

        return dedup_df.sort_values('datetime')
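

# Minimal usage sketch for NewsDeduplicator (illustrative only; never called by the app).
# Near-identical texts collapse to a single row whose 'company' column lists every
# company that carried the duplicate, joined with ' | '.
def _example_deduplicator_usage():
    sample = pd.DataFrame({
        'company': ['Газпром', 'Новатэк'],
        'datetime': ['2024-01-01 10:00', '2024-01-01 10:05'],
        'title': ['t1', 't1'],
        'text': ['Цены на газ выросли на 5%.', 'Цены на газ выросли на 5 %.'],
    })
    deduplicator = NewsDeduplicator(fuzzy_threshold=85)
    return deduplicator.deduplicate(sample)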


def create_download_link(df: pd.DataFrame, filename: str) -> str:
    excel_buffer = io.BytesIO()
    with pd.ExcelWriter(excel_buffer, engine='openpyxl') as writer:
        df.to_excel(writer, index=False)
    excel_buffer.seek(0)
    b64 = base64.b64encode(excel_buffer.read()).decode()
    return f'<a href="data:application/vnd.openxmlformats-officedocument.spreadsheetml.sheet;base64,{b64}" download="{filename}">Download {filename}</a>'
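

# Note: writing the workbook with engine='openpyxl' assumes the openpyxl package is
# installed alongside pandas. The returned string is raw HTML, which is why main()
# renders it through st.markdown(..., unsafe_allow_html=True).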


def main():
    st.title("News clustering v.1.23 + print debug")
    st.write("Upload Excel file with columns: company, datetime, text")

    uploaded_file = st.file_uploader("Choose Excel file", type=['xlsx'])

    if uploaded_file:
        try:
            # Read all columns from the original sheet
            df_original = pd.read_excel(uploaded_file, sheet_name='Публикации')
            st.write("Available columns:", df_original.columns.tolist())

            # Map the required columns by position (assumes the sheet layout:
            # company in column 0, datetime in column 3, title in column 5, text in column 6)
            text_column = df_original.columns[6]
            title_column = df_original.columns[5]
            datetime_column = df_original.columns[3]
            company_column = df_original.columns[0]

            df = df_original[[company_column, datetime_column, title_column, text_column]].copy()
            df.columns = ['company', 'datetime', 'title', 'text']

            st.success(f'Loaded {len(df)} records')
            st.dataframe(df.head())

            col1, col2 = st.columns(2)
            with col1:
                fuzzy_threshold = st.slider("Fuzzy Match Threshold", 30, 100, 50)
            with col2:
                similarity_threshold = st.slider("Similarity Threshold", 0.5, 1.0, 0.75)
                time_threshold = st.slider("Time Threshold (hours)", 1, 72, 24)

            if st.button("Process News"):
                progress_bar = st.progress(0)
                try:
                    # Step 1: Deduplicate
                    deduplicator = NewsDeduplicator(fuzzy_threshold)
                    dedup_df = deduplicator.deduplicate(df, progress_bar)
                    # Preserve all columns from the original DataFrame
                    dedup_df_full = df_original.loc[dedup_df.index].copy()

                    st.write("\nDeduplication Results:")
                    st.write(f"Original indices: {df.index.tolist()}")
                    st.write(f"Dedup indices: {dedup_df.index.tolist()}")
                    st.write("Sample from dedup_df:")
                    st.write(dedup_df[['company', 'text']].head())

                    # Step 2: Cluster deduplicated news
                    processor = NewsProcessor(similarity_threshold, time_threshold)
                    result_df = processor.process_news(dedup_df, progress_bar)

                    st.write("\nClustering Results:")
                    st.write(f"Result df indices: {result_df.index.tolist()}")

                    # Display cluster information
                    if len(result_df) > 0:
                        st.write("\nCluster Details:")
                        for cluster_id in result_df['cluster_id'].unique():
                            cluster_mask = result_df['cluster_id'] == cluster_id
                            if sum(cluster_mask) > 1:  # Only show multi-item clusters
                                cluster_indices = result_df[cluster_mask].index.tolist()
                                st.write(f"\nCluster {cluster_id}:")
                                st.write(f"Indices: {cluster_indices}")
                                # Show texts for verification
                                for idx in cluster_indices:
                                    text_length = len(str(dedup_df.loc[idx, 'text']))
                                    st.write(f"Index {idx} - Length {text_length}:")
                                    st.write(str(dedup_df.loc[idx, 'text'])[:100] + '...')

                    # For each multi-item cluster, keep only the row with the longest text
                    indices_to_delete = set()
                    if len(result_df) > 0:
                        for cluster_id in result_df['cluster_id'].unique():
                            cluster_mask = result_df['cluster_id'] == cluster_id
                            if sum(cluster_mask) > 1:
                                cluster_indices = result_df[cluster_mask].index.tolist()
                                text_lengths = dedup_df.loc[cluster_indices, 'text'].fillna('').str.len()
                                longest_text_idx = text_lengths.idxmax()
                                indices_to_delete.update(set(cluster_indices) - {longest_text_idx})

                    st.write("\nDeletion Summary:")
                    st.write(f"Indices to delete: {sorted(list(indices_to_delete))}")

                    # Create final DataFrame
                    declustered_df = dedup_df_full.copy()
                    if indices_to_delete:
                        declustered_df = declustered_df.drop(index=list(indices_to_delete))
                    st.write(f"Final indices kept: {sorted(declustered_df.index.tolist())}")

                    # Print statistics
                    st.success(f"""
                    Processing results:
                    - Original rows: {len(df_original)}
                    - After deduplication: {len(dedup_df_full)}
                    - Multi-item clusters found: {len(result_df[result_df['cluster_size'] > 1]['cluster_id'].unique()) if len(result_df) > 0 else 0}
                    - Rows removed from clusters: {len(indices_to_delete)}
                    - Final rows kept: {len(declustered_df)}
                    """)

                    # Download buttons for all results
                    st.subheader("Download Results")
                    st.markdown(create_download_link(dedup_df_full, "deduplicated_news.xlsx"), unsafe_allow_html=True)
                    st.markdown(create_download_link(result_df, "clustered_news.xlsx"), unsafe_allow_html=True)
                    st.markdown(create_download_link(declustered_df, "declustered_news.xlsx"), unsafe_allow_html=True)

                    # Show clusters info
                    if len(result_df) > 0:
                        st.subheader("Largest Clusters")
                        largest_clusters = result_df[result_df['cluster_size'] > 1].sort_values(
                            ['cluster_size', 'cluster_id', 'datetime'],
                            ascending=[False, True, True]
                        )
                        st.dataframe(largest_clusters)

                except Exception as e:
                    st.error(f"Error: {str(e)}")
                    import traceback
                    st.error(traceback.format_exc())
                finally:
                    progress_bar.empty()

        except Exception as e:
            st.error(f"Error reading file: {str(e)}")
            import traceback
            st.error(traceback.format_exc())


if __name__ == "__main__":
    main()