DrishtiSharma committed on
Commit fc64c4c · verified · 1 Parent(s): e90d440

Update app.py

Files changed (1)
  1. app.py +92 -73
app.py CHANGED
@@ -1,11 +1,27 @@
 import sys
 import os
 import re
+import shutil
 import time
-import tempfile
 import streamlit as st
 import nltk
+import tempfile
 
+# Set up temporary directory for NLTK resources
+nltk_data_path = os.path.join(tempfile.gettempdir(), "nltk_data")
+os.makedirs(nltk_data_path, exist_ok=True)
+nltk.data.path = [nltk_data_path]  # Force NLTK to use only the temp directory
+
+# Force clean download of 'punkt'
+try:
+    print("Ensuring NLTK 'punkt' resource is downloaded...")
+    if not os.path.exists(os.path.join(nltk_data_path, "tokenizers/punkt")):
+        nltk.download("punkt", download_dir=nltk_data_path)
+except Exception as e:
+    print(f"Error downloading NLTK 'punkt': {e}")
+    raise e
+
+sys.path.append(os.path.abspath("."))
 from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferMemory
 from langchain.llms import OpenAI
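Side note on the NLTK change above: pointing nltk.data.path at a temp directory and downloading "punkt" there keeps the resource out of locations that may not be writable on the host. A minimal standalone sketch of the same idea, with a verification step that is not part of the commit (nltk.data.find raises LookupError if the resource is missing):

import os
import tempfile
import nltk

nltk_data_path = os.path.join(tempfile.gettempdir(), "nltk_data")
os.makedirs(nltk_data_path, exist_ok=True)
nltk.data.path = [nltk_data_path]                    # search only the temp directory
nltk.download("punkt", download_dir=nltk_data_path)  # skipped if already up to date
nltk.data.find("tokenizers/punkt")                   # sanity check: raises LookupError if absent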
@@ -15,52 +31,62 @@ from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.text_splitter import NLTKTextSplitter
 from patent_downloader import PatentDownloader
 
-# Download NLTK resources
-nltk.download("punkt", quiet=True)
+PERSISTED_DIRECTORY = tempfile.mkdtemp()
 
-#fetch API key
+# Fetch API key securely from the environment
 OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
 if not OPENAI_API_KEY:
-    st.error("Critical Error: OpenAI API key not found in environment variables. Please configure it.")
+    st.error("Critical Error: OpenAI API key not found in the environment variables. Please configure it.")
     st.stop()
 
+def check_poppler_installed():
+    if not shutil.which("pdfinfo"):
+        raise EnvironmentError(
+            "Poppler is not installed or not in PATH. Install 'poppler-utils' for PDF processing."
+        )
 
-def extract_patent_number(url):
-    """Extracts patent number from a Google patent link."""
-    pattern = r"/patent/([A-Z]{2}\d+)"
-    match = re.search(pattern, url)
-    return match.group(1) if match else None
-
-
-def download_pdf(patent_number):
-    """Downloads patent PDF using a temporary directory."""
-    try:
-        with tempfile.TemporaryDirectory() as temp_dir:
-            patent_downloader = PatentDownloader(verbose=True)
-            output_path = patent_downloader.download(patents=patent_number, output_path=temp_dir)
-            return output_path[0]
-    except Exception as e:
-        st.error(f"Failed to download patent PDF: {e}")
-        return None
-
+check_poppler_installed()
 
 def load_docs(document_path):
-    """Loads and splits PDF documents into chunks."""
     try:
-        loader = UnstructuredPDFLoader(document_path)
+        loader = UnstructuredPDFLoader(
+            document_path,
+            mode="elements",
+            strategy="fast",
+            ocr_languages=None
+        )
         documents = loader.load()
         text_splitter = NLTKTextSplitter(chunk_size=1000)
         return text_splitter.split_documents(documents)
     except Exception as e:
-        st.error(f"Failed to process PDF: {e}")
-        return []
+        st.error(f"Failed to load and process PDF: {e}")
+        st.stop()
+
+def already_indexed(vectordb, file_name):
+    indexed_sources = set(
+        x["source"] for x in vectordb.get(include=["metadatas"])["metadatas"]
+    )
+    return file_name in indexed_sources
 
+def load_chain(file_name=None):
+    loaded_patent = st.session_state.get("LOADED_PATENT")
 
-def load_chain(docs):
-    """Creates a conversational retrieval chain using in-memory ChromaDB."""
-    vectordb = Chroma.from_documents(
-        docs, HuggingFaceEmbeddings(), persist_directory=None
+    vectordb = Chroma(
+        persist_directory=PERSISTED_DIRECTORY,
+        embedding_function=HuggingFaceEmbeddings(),
     )
+    if loaded_patent == file_name or already_indexed(vectordb, file_name):
+        st.write("✅ Already indexed.")
+    else:
+        vectordb.delete_collection()
+        docs = load_docs(file_name)
+        st.write("🔍 Number of Documents: ", len(docs))
+
+        vectordb = Chroma.from_documents(
+            docs, HuggingFaceEmbeddings(), persist_directory=PERSISTED_DIRECTORY
+        )
+        vectordb.persist()
+    st.session_state["LOADED_PATENT"] = file_name
 
     memory = ConversationBufferMemory(
         memory_key="chat_history",
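The new already_indexed() helper above relies on the loader recording each chunk's originating file path in its "source" metadata, so the set of sources already in the vector store tells load_chain() whether a file was indexed before. A small illustrative sketch of that decision, using a plain dict shaped like the store's get() output rather than a live Chroma instance (the file names are hypothetical):

# Hypothetical shape of vectordb.get(include=["metadatas"]) for two indexed chunks.
store_dump = {"metadatas": [{"source": "/tmp/US9876543.pdf"}, {"source": "/tmp/US9876543.pdf"}]}

def already_indexed_like(dump, file_name):
    # Same membership test as the committed helper, applied to the dict above.
    indexed_sources = {m["source"] for m in dump["metadatas"]}
    return file_name in indexed_sources

print(already_indexed_like(store_dump, "/tmp/US9876543.pdf"))  # True  -> skip re-indexing
print(already_indexed_like(store_dump, "/tmp/US1234567.pdf"))  # False -> rebuild the collection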
@@ -68,7 +94,6 @@ def load_chain(docs):
         input_key="question",
         output_key="answer",
     )
-
     return ConversationalRetrievalChain.from_llm(
         OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY),
         vectordb.as_retriever(search_kwargs={"k": 3}),
@@ -76,8 +101,20 @@ def load_chain(docs):
         memory=memory,
     )
 
+def extract_patent_number(url):
+    pattern = r"/patent/([A-Z]{2}\d+)"
+    match = re.search(pattern, url)
+    return match.group(1) if match else None
+
+def download_pdf(patent_number):
+    try:
+        patent_downloader = PatentDownloader(verbose=True)
+        output_path = patent_downloader.download(patents=patent_number, output_path=tempfile.gettempdir())
+        return output_path[0]
+    except Exception as e:
+        st.error(f"Failed to download patent PDF: {e}")
+        st.stop()
 
-# Streamlit UI
 if __name__ == "__main__":
     st.set_page_config(
         page_title="Patent Chat: Google Patents Chat Demo",
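For reference, the relocated extract_patent_number() above captures only the country code plus the numeric part of the publication number, so any kind-code suffix (e.g. "B2") is dropped. A quick check with a made-up URL (not from the commit):

import re

pattern = r"/patent/([A-Z]{2}\d+)"
url = "https://patents.google.com/patent/US9876543B2/en"  # hypothetical example link
match = re.search(pattern, url)
print(match.group(1) if match else None)                  # -> US9876543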
@@ -85,10 +122,8 @@ if __name__ == "__main__":
         layout="wide",
         initial_sidebar_state="expanded",
     )
-
     st.header("📖 Patent Chat: Google Patents Chat Demo")
 
-    # Input for Google Patent Link
     patent_link = st.text_input("Enter Google Patent Link:", key="PATENT_LINK")
 
     if not patent_link:
@@ -100,64 +135,48 @@ if __name__ == "__main__":
         st.error("Invalid patent link format. Please provide a valid Google patent link.")
         st.stop()
 
-    st.write(f"🔍 Patent Number: **{patent_number}**")
+    st.write(f"Patent number: **{patent_number}**")
 
-    # Download or Upload PDF
-    st.write("📥 Downloading patent PDF...")
-    pdf_path = None
-
-    try:
+    pdf_path = os.path.join(tempfile.gettempdir(), f"{patent_number}.pdf")
+    if os.path.isfile(pdf_path):
+        st.write("✅ File already downloaded.")
+    else:
+        st.write("📥 Downloading patent file...")
         pdf_path = download_pdf(patent_number)
-    except Exception:
-        st.error("Automatic download failed. Please upload the PDF manually below.")
-
-    if not pdf_path:
-        uploaded_file = st.file_uploader("Upload the patent PDF file:", type="pdf")
-        if uploaded_file:
-            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp_file:
-                tmp_file.write(uploaded_file.read())
-                pdf_path = tmp_file.name
-            st.success("✅ PDF successfully uploaded.")
-        else:
-            st.stop()
-
-    # Load and Process PDF
-    st.write("🔄 Processing document...")
-    docs = load_docs(pdf_path)
-
-    if not docs:
-        st.error("No content found in the PDF. Exiting...")
-        st.stop()
+        st.write(f"✅ File downloaded: {pdf_path}")
 
-    chain = load_chain(docs)
+    st.write("🔄 Loading document into the system...")
+    chain = load_chain(pdf_path)
     st.success("🚀 Document successfully loaded! You can now start asking questions.")
 
-    # Initialize chat history
     if "messages" not in st.session_state:
         st.session_state["messages"] = [
             {"role": "assistant", "content": "Hello! How can I assist you with this patent?"}
         ]
 
-    # Display chat history
     for message in st.session_state.messages:
         with st.chat_message(message["role"]):
             st.markdown(message["content"])
 
-    # Handle User Input
     if user_input := st.chat_input("What is your question?"):
         st.session_state.messages.append({"role": "user", "content": user_input})
-
        with st.chat_message("user"):
            st.markdown(user_input)
 
        with st.chat_message("assistant"):
            message_placeholder = st.empty()
-            with st.spinner("Generating response..."):
-                try:
-                    assistant_response = chain({"question": user_input})
-                    full_response = assistant_response.get("answer", "I'm sorry, I couldn't generate a response.")
-                except Exception as e:
-                    full_response = f"An error occurred: {e}"
-            message_placeholder.markdown(full_response)
+            full_response = ""
+
+            with st.spinner("Generating response..."):
+                try:
+                    assistant_response = chain({"question": user_input})
+                    for chunk in assistant_response["answer"].split():
+                        full_response += chunk + " "
+                        time.sleep(0.05)
+                        message_placeholder.markdown(full_response + "▌")
+                except Exception as e:
+                    full_response = f"An error occurred: {e}"
+                finally:
+                    message_placeholder.markdown(full_response)
 
        st.session_state.messages.append({"role": "assistant", "content": full_response})
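The last hunk replaces the single markdown() call with a word-by-word render loop, which fakes streaming for a chain that only returns a complete answer. The pattern in isolation, as a minimal Streamlit sketch (the sample answer text is made up; st.empty() gives one slot that is rewritten in place):

import time
import streamlit as st

placeholder = st.empty()                  # single slot that gets rewritten on each pass
answer = "A short example answer rendered word by word."
shown = ""
for word in answer.split():
    shown += word + " "
    time.sleep(0.05)                      # pacing so the "typing" effect is visible
    placeholder.markdown(shown + "▌")     # trailing block cursor while rendering
placeholder.markdown(shown)               # final text without the cursor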