terapyon committed on
Commit
e6cff8e
·
1 Parent(s): 87e40dc

added base script

Browse files
Files changed (5) hide show
  1. .gitignore +3 -0
  2. app.py +31 -0
  3. constraints.txt +102 -0
  4. requirements.txt +5 -0
  5. store.py +42 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ venv/*
2
+ __pycache__
3
+ podcast-*.txt
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from langchain.chains import RetrievalQA
3
+ from langchain.embeddings import OpenAIEmbeddings
4
+ from langchain.llms import OpenAI
5
+ from langchain.vectorstores import Chroma
6
+
7
+
8
+ PERSIST_DIR_NAME = "podcast-75"
9
+
10
+
11
def get_retrieval_qa() -> RetrievalQA:
    """Build a RetrievalQA chain backed by the persisted Chroma store.

    Opens the Chroma database at ``PERSIST_DIR_NAME`` with OpenAI
    embeddings and wraps it in a "stuff"-type RetrievalQA chain.
    """
    vector_store = Chroma(
        persist_directory=PERSIST_DIR_NAME,
        embedding_function=OpenAIEmbeddings(),
    )
    return RetrievalQA.from_chain_type(
        llm=OpenAI(),
        chain_type="stuff",
        retriever=vector_store.as_retriever(),
    )
18
+
19
+
20
def main(query: str):
    """Answer *query* via the retrieval-augmented QA chain.

    Builds a fresh chain per call and returns only the "result" field
    of the chain's response.
    """
    response = get_retrieval_qa()(query)
    return response["result"]
24
+
25
+
26
# Wire the QA function into a minimal Gradio UI and start the server.
query_box = gr.Textbox(label="query")
pyhack_qa = gr.Interface(
    fn=main,
    inputs=[query_box],
    outputs="text",
)
pyhack_qa.launch()
constraints.txt ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.1.0
2
+ aiohttp==3.8.4
3
+ aiosignal==1.3.1
4
+ altair==5.0.1
5
+ anyio==3.7.0
6
+ async-timeout==4.0.2
7
+ attrs==23.1.0
8
+ backoff==2.2.1
9
+ certifi==2023.5.7
10
+ charset-normalizer==3.1.0
11
+ chromadb==0.3.26
12
+ click==8.1.3
13
+ clickhouse-connect==0.6.3
14
+ coloredlogs==15.0.1
15
+ contourpy==1.1.0
16
+ cycler==0.11.0
17
+ dataclasses-json==0.5.8
18
+ duckdb==0.8.1
19
+ exceptiongroup==1.1.1
20
+ fastapi==0.98.0
21
+ ffmpy==0.3.0
22
+ filelock==3.12.2
23
+ flatbuffers==23.5.26
24
+ fonttools==4.40.0
25
+ frozenlist==1.3.3
26
+ fsspec==2023.6.0
27
+ gradio==3.35.2
28
+ gradio_client==0.2.7
29
+ greenlet==2.0.2
30
+ h11==0.14.0
31
+ hnswlib==0.7.0
32
+ httpcore==0.17.2
33
+ httptools==0.5.0
34
+ httpx==0.24.1
35
+ huggingface-hub==0.15.1
36
+ humanfriendly==10.0
37
+ idna==3.4
38
+ Jinja2==3.1.2
39
+ jsonschema==4.17.3
40
+ kiwisolver==1.4.4
41
+ langchain==0.0.209
42
+ langchainplus-sdk==0.0.16
43
+ linkify-it-py==2.0.2
44
+ lz4==4.3.2
45
+ markdown-it-py==2.2.0
46
+ MarkupSafe==2.1.3
47
+ marshmallow==3.19.0
48
+ marshmallow-enum==1.5.1
49
+ matplotlib==3.7.1
50
+ mdit-py-plugins==0.3.3
51
+ mdurl==0.1.2
52
+ monotonic==1.6
53
+ mpmath==1.3.0
54
+ multidict==6.0.4
55
+ mypy-extensions==1.0.0
56
+ numexpr==2.8.4
57
+ numpy==1.25.0
58
+ onnxruntime==1.15.1
59
+ openai==0.27.8
60
+ openapi-schema-pydantic==1.2.4
61
+ orjson==3.9.1
62
+ overrides==7.3.1
63
+ packaging==23.1
64
+ pandas==2.0.2
65
+ Pillow==9.5.0
66
+ posthog==3.0.1
67
+ protobuf==4.23.3
68
+ pulsar-client==3.2.0
69
+ pydantic==1.10.9
70
+ pydub==0.25.1
71
+ Pygments==2.15.1
72
+ pyparsing==3.1.0
73
+ pyrsistent==0.19.3
74
+ python-dateutil==2.8.2
75
+ python-dotenv==1.0.0
76
+ python-multipart==0.0.6
77
+ pytz==2023.3
78
+ PyYAML==6.0
79
+ regex==2023.6.3
80
+ requests==2.31.0
81
+ semantic-version==2.10.0
82
+ six==1.16.0
83
+ sniffio==1.3.0
84
+ SQLAlchemy==2.0.16
85
+ starlette==0.27.0
86
+ sympy==1.12
87
+ tenacity==8.2.2
88
+ tiktoken==0.4.0
89
+ tokenizers==0.13.3
90
+ toolz==0.12.0
91
+ tqdm==4.65.0
92
+ typing-inspect==0.9.0
93
+ typing_extensions==4.6.3
94
+ tzdata==2023.3
95
+ uc-micro-py==1.0.2
96
+ urllib3==2.0.3
97
+ uvicorn==0.22.0
98
+ uvloop==0.17.0
99
+ watchfiles==0.19.0
100
+ websockets==11.0.3
101
+ yarl==1.9.2
102
+ zstandard==0.21.0
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ langchain
2
+ openai
3
+ chromadb
4
+ tiktoken
5
+ gradio
store.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.document_loaders import TextLoader
2
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
3
+ from langchain.embeddings import OpenAIEmbeddings
4
+ from langchain.vectorstores import Chroma
5
+
6
+
7
+ CHUNK_SIZE = 500
8
+
9
+
10
def get_documents(filename: str):
    """Load *filename* with TextLoader and return the resulting documents."""
    return TextLoader(filename).load()
14
+
15
+
16
def get_text_chunk(docs):
    """Split *docs* into non-overlapping chunks of at most CHUNK_SIZE chars."""
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE,
        chunk_overlap=0,
    )
    return splitter.split_documents(docs)
20
+
21
+
22
def store(texts, dir_name):
    """Embed *texts* with OpenAI embeddings and persist them to a Chroma
    database rooted at *dir_name*."""
    chroma_db = Chroma.from_documents(
        texts,
        OpenAIEmbeddings(),
        persist_directory=dir_name,
    )
    chroma_db.persist()
26
+
27
+
28
def main(filename: str, dir_name: str):
    """Load, chunk, embed, and persist the text file at *filename* into
    a Chroma store at *dir_name*."""
    chunks = get_text_chunk(get_documents(filename))
    store(chunks, dir_name)
32
+
33
+
34
if __name__ == "__main__":
    import sys

    # The script needs exactly two CLI arguments (argv[0] is the script
    # name): the source text file and the Chroma persist directory.
    args = sys.argv
    if len(args) != 3:
        # Fixed: the old message claimed only one argument was needed,
        # but main() requires both a filename and a directory name.
        print(
            "Usage: store.py <text-filename> <persist-dir-name>",
            file=sys.stderr,
        )
        sys.exit(1)
    else:
        filename = args[1]
        dir_name = args[2]
        main(filename, dir_name)