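"""Streamlit app for news classification and extractive question answering.

Run locally with (assuming this file is saved as app.py):

    streamlit run app.py
"""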
import streamlit as st
import pandas as pd
import re
import io
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer

# Download NLTK resources (quiet=True avoids log spam on Streamlit reruns);
# newer NLTK releases also need 'punkt_tab' for word_tokenize.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
nltk.download('stopwords', quiet=True)
nltk.download('wordnet', quiet=True)

# Initialize lemmatizer and stopwords
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words('english'))

# Load the fine-tuned model and build both pipelines once; st.cache_resource
# keeps Streamlit from reloading the weights on every rerun of the script.
@st.cache_resource
def load_pipelines():
    model_name = "TAgroup5/news-classification-model"
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
    # The fine-tuned model has a sequence-classification head, so it cannot back
    # the question-answering pipeline; a stock extractive QA checkpoint is used
    # instead (an assumed substitute, not part of the original fine-tune).
    qa = pipeline("question-answering", model="distilbert-base-cased-distilled-squad")
    return classifier, qa

text_classification_pipeline, qa_pipeline = load_pipelines()
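# For reference, the text-classification pipeline returns, per input string,
# [{'label': '<model-specific category>', 'score': <float>}].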

# Streamlit App
st.title("News Classification and Q&A")

## ====================== Component 1: News Classification ====================== ##
st.header("Classify News Articles")
st.markdown("Upload a CSV file with a 'content' column to classify news into categories.")

uploaded_file = st.file_uploader("Choose a CSV file", type="csv")
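# Illustrative input shape (hypothetical rows, not real data):
# content
# "Markets rose sharply after the central bank held rates steady."
# "The home team clinched the title with a late goal."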

if uploaded_file is not None:
    try:
        df = pd.read_csv(uploaded_file, encoding="utf-8")
    except UnicodeDecodeError:
        # Rewind the buffer (a failed read advances it) and retry with a
        # permissive legacy encoding.
        uploaded_file.seek(0)
        df = pd.read_csv(uploaded_file, encoding="ISO-8859-1")

    if 'content' not in df.columns:
        st.error("Error: The uploaded CSV must contain a 'content' column.")
    else:
        st.write("Preview of uploaded data:")
        st.dataframe(df.head())

        # Preprocessing function
        def preprocess_text(text):
            text = text.lower()  # Convert to lowercase
            text = re.sub(r'[^a-z\s]', '', text)  # Remove special characters & numbers
            tokens = word_tokenize(text)  # Tokenization
            tokens = [word for word in tokens if word not in stop_words]  # Remove stopwords
            tokens = [lemmatizer.lemmatize(word) for word in tokens]  # Lemmatization
            return " ".join(tokens)
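        # Example (assumed behavior): preprocess_text("Stocks rallied 5% on Tuesday!")
        # returns "stock rallied tuesday"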

        # Apply preprocessing and classification; fill missing content so
        # preprocess_text always receives a string.
        df['processed_content'] = df['content'].fillna("").astype(str).apply(preprocess_text)
        # truncation=True keeps long articles within the model's input limit.
        df['class'] = df['processed_content'].apply(
            lambda x: text_classification_pipeline(x, truncation=True)[0]['label'] if x.strip() else "Unknown"
        )

        # Show results
        st.write("Classification Results:")
        st.dataframe(df[['content', 'class']])

        # Provide CSV download
        output = io.BytesIO()
        df.to_csv(output, index=False, encoding="utf-8-sig")
        st.download_button(label="Download classified news", data=output.getvalue(), file_name="output.csv", mime="text/csv")

## ====================== Component 2: Q&A ====================== ##
st.header("Ask a Question About the News")
st.markdown("Enter a question and provide a news article to get an answer.")

question = st.text_input("Ask a question:")
context = st.text_area("Provide the news article or content for the Q&A:", height=150)

if question and context.strip():
    result = qa_pipeline(question=question, context=context)

    # qa_pipeline returns {'score': float, 'start': int, 'end': int, 'answer': str};
    # show the answer only if one was extracted.
    if result.get('answer'):
        st.write("Answer:", result['answer'])
    else:
        st.write("No answer found in the provided content.")