zoya23 commited on
Commit
50a4735
Β·
verified Β·
1 Parent(s): 34cb346

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -15
app.py CHANGED
@@ -1,12 +1,13 @@
1
  import streamlit as st
2
- from datasets import load_dataset
 
 
3
  from langchain.chains import LLMChain
4
- from langchain.llms import HuggingFaceHub
5
  from langchain.prompts import PromptTemplate
6
- from langchain.prompts.few_shot import FewShotChatMessagePromptTemplate
7
- from langchain.prompts.example_selector import LengthBasedExampleSelector
8
 
9
- # Load dataset (small subset)
10
  @st.cache_data
11
  def load_examples():
12
  dataset = load_dataset("knkarthick/dialogsum", split="train[:5]") # Take only 5 for speed
@@ -20,10 +21,12 @@ def load_examples():
20
 
21
  examples = load_examples()
22
 
23
- # Set up the HuggingFaceHub model (T5)
24
- llm = HuggingFaceHub(repo_id="google/pegasus-xsum", model_kwargs={"temperature": 0.7})
 
 
25
 
26
- # Few-shot prompt template
27
  example_prompt = FewShotChatMessagePromptTemplate.from_examples(
28
  examples=examples,
29
  example_selector=LengthBasedExampleSelector(examples=examples, max_length=1000),
@@ -34,24 +37,24 @@ example_prompt = FewShotChatMessagePromptTemplate.from_examples(
34
 
35
  # Streamlit UI
36
  st.title("πŸ’¬ Dialogue Summarizer using Few-Shot Prompt + T5 (via Langchain)")
 
37
  input_text = st.text_area("πŸ“ Paste your conversation:")
38
 
39
  if st.button("Generate Summary"):
40
  if input_text.strip():
41
- # Create prompt using FewShotChatMessagePromptTemplate
42
  messages = example_prompt.format_messages(input=input_text)
43
 
44
  with st.expander("πŸ“‹ Generated Prompt"):
45
  for msg in messages:
46
  st.markdown(f"**{msg.type.upper()}**:\n```\n{msg.content}\n```")
47
 
48
- # Create the prompt chain
49
- prompt_template = PromptTemplate(input_variables=["input"], template="{input}")
50
- chain = LLMChain(llm=llm, prompt=prompt_template)
51
 
52
- # Get the summary from the model
53
- summary = chain.run(input_text)
54
  st.success("βœ… Summary:")
55
- st.write(summary)
56
  else:
57
  st.warning("Please enter some text.")
 
1
  import streamlit as st
2
+ from langchain.prompts import FewShotChatMessagePromptTemplate
3
+ from langchain.prompts.example_selector import LengthBasedExampleSelector
4
+ from langchain_huggingface import HuggingFaceEndpoint, HuggingFacePipeline
5
  from langchain.chains import LLMChain
 
6
  from langchain.prompts import PromptTemplate
7
+ from datasets import load_dataset
8
+ from transformers import pipeline
9
 
10
+ # Load dataset (using knkarthick/dialogsum as an example)
11
  @st.cache_data
12
  def load_examples():
13
  dataset = load_dataset("knkarthick/dialogsum", split="train[:5]") # Take only 5 for speed
 
21
 
22
  examples = load_examples()
23
 
24
# Remote summarization model served through the Hugging Face Inference API.
# Swap the URL to point at any other hosted text2text model.
_ENDPOINT_URL = "https://api-inference.huggingface.co/models/t5-small"  # or any model you like
hf_endpoint = HuggingFaceEndpoint(endpoint_url=_ENDPOINT_URL)
28
 
29
+ # Create FewShotChatMessagePromptTemplate
30
  example_prompt = FewShotChatMessagePromptTemplate.from_examples(
31
  examples=examples,
32
  example_selector=LengthBasedExampleSelector(examples=examples, max_length=1000),
 
37
 
38
  # Streamlit UI
39
  st.title("πŸ’¬ Dialogue Summarizer using Few-Shot Prompt + T5 (via Langchain)")
40
+
41
  input_text = st.text_area("πŸ“ Paste your conversation:")
42
 
43
if st.button("Generate Summary"):
    if input_text.strip():
        # Render the few-shot prompt for the user's conversation.
        messages = example_prompt.format_messages(input=input_text)

        # Show the fully formatted prompt so the user can inspect what is
        # actually sent to the model.
        with st.expander("πŸ“‹ Generated Prompt"):
            for msg in messages:
                st.markdown(f"**{msg.type.upper()}**:\n```\n{msg.content}\n```")

        # Query the hosted model through the endpoint LLM directly.
        # NOTE(fix): the previous code built
        #   HuggingFacePipeline(pipeline="summarization", model=hf_endpoint)
        # but HuggingFacePipeline wraps a *local* transformers pipeline object
        # (its `pipeline` kwarg takes a pipeline instance, not a task string)
        # and cannot wrap a HuggingFaceEndpoint. It also indexed the result
        # like a raw transformers pipeline output; LangChain LLMs return a
        # plain string, so `.invoke()` on the endpoint is the supported path.
        summary = hf_endpoint.invoke(messages[0].content)

        st.success("βœ… Summary:")
        st.write(summary)
    else:
        # Empty or whitespace-only input: prompt the user instead of calling
        # the model with nothing to summarize.
        st.warning("Please enter some text.")