|
import os

import gradio as gr
import numpy as np
from datasets import load_dataset, Dataset
from huggingface_hub import InferenceClient
from langchain_community.embeddings import HuggingFaceEmbeddings
from openai import OpenAI
from sklearn.neighbors import NearestNeighbors
|
DEFAULT_QUESTION = "Ask me anything in the context of persona-driven prompt generation..."
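
# Retrieval-augmented QA over the persona-driven-prompt-generator dataset:
# dataset rows are embedded once at startup, the rows nearest to each question
# are retrieved, and an LLM (Hugging Face serverless by default, OpenAI
# optionally) answers from that context. Assumes HF_TOKEN (and OPENAI_API_KEY
# on the OpenAI path) are set in the environment before launch.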
# Hard-coded defaults; MODEL_PROVIDER selects which client is constructed below.
os.environ['OPENAI_BASE'] = "https://api.openai.com/v1"
os.environ['OPENAI_MODEL'] = "gpt-4"
os.environ['MODEL_PROVIDER'] = "huggingface"
model_provider = os.environ.get("MODEL_PROVIDER")
if model_provider.lower() == "openai":
    MODEL_NAME = os.environ['OPENAI_MODEL']
    client = OpenAI(
        base_url=os.environ.get("OPENAI_BASE"),
        api_key=os.environ.get("OPENAI_API_KEY")
    )
else:
    MODEL_NAME = "meta-llama/Llama-3.3-70B-Instruct"
    hf_client = InferenceClient(
        model=MODEL_NAME,
        api_key=os.environ.get("HF_TOKEN")
    )
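
# Note: huggingface_hub's InferenceClient exposes an OpenAI-compatible
# chat.completions.create(...) surface, so both providers can be called with
# the same message/response shape in generate_response below.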
dataset = load_dataset('tosin2013/persona-driven-prompt-generator', streaming=True)
dataset = Dataset.from_list(list(dataset['train']))
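
# streaming=True avoids downloading the full dataset up front, but
# Dataset.from_list(list(...)) then materializes the entire train split in
# memory anyway, so the retrieval code below can use plain list indexing.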
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
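
# all-MiniLM-L6-v2 is a small sentence-transformer that produces
# 384-dimensional embeddings and runs locally, so retrieval needs no API key.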
# Embed every dataset "input" once at startup.
texts = dataset['input']
text_embeddings = embeddings.embed_documents(texts)
nn = NearestNeighbors(n_neighbors=5, metric='cosine')
nn.fit(np.array(text_embeddings))
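
# Brute-force cosine search over all row embeddings, which is adequate at this
# dataset's scale; a vector store (e.g. FAISS) would be the usual swap-in for
# larger corpora.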
def get_relevant_documents(query, k=5):
    """
    Retrieves the k most relevant documents to the query.
    """
    query_embedding = embeddings.embed_query(query)
    distances, indices = nn.kneighbors([query_embedding], n_neighbors=k)
    relevant_docs = [texts[i] for i in indices[0]]
    return relevant_docs
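
# Example (hypothetical query): get_relevant_documents("How do personas shape
# prompts?", k=3) returns the three dataset inputs whose embeddings have the
# smallest cosine distance to the query embedding.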
def generate_response(question, history):
    try:
        print(f"\n[LOG] Received question: {question}")

        # Retrieve supporting context for the question.
        relevant_docs = get_relevant_documents(question, k=3)
        print(f"[LOG] Retrieved {len(relevant_docs)} relevant documents")

        # Fold the retrieved rows into a single grounded prompt.
        context = "\n".join(relevant_docs)
        prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
        print(f"[LOG] Generated prompt: {prompt[:200]}...")

        if model_provider.lower() == "huggingface":
            print(f"[LOG] Using Hugging Face model (serverless): {MODEL_NAME}")
            messages = [
                {
                    "role": "system",
                    "content": "You are a helpful AI assistant. Answer the question based on the provided context."
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ]
            completion = hf_client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                max_tokens=500
            )
            response = completion.choices[0].message.content
            print(f"[LOG] Hugging Face response: {response[:200]}...")
        elif model_provider.lower() == "openai":
            print(f"[LOG] Using OpenAI model: {os.environ.get('OPENAI_MODEL')}")
            completion = client.chat.completions.create(
                model=os.environ.get("OPENAI_MODEL"),
                messages=[
                    {"role": "system", "content": "You are a helpful assistant. Answer the question based on the provided context."},
                    {"role": "user", "content": prompt},
                ],
            )
            response = completion.choices[0].message.content
            print(f"[LOG] OpenAI response: {response[:200]}...")
        else:
            # Guard against an unrecognized provider leaving `response` unbound.
            raise ValueError(f"Unsupported MODEL_PROVIDER: {model_provider}")

        history.append((question, response))
        return history
    except Exception as e:
        error_msg = f"Error generating response: {str(e)}"
        print(f"[ERROR] {error_msg}")
        history.append((question, error_msg))
        return history
with gr.Blocks() as demo:
    gr.Markdown(f"""
## Persona-Driven Prompt Generator QA Agent
**Current Model:** {MODEL_NAME}

The Custom Prompt Generator is a Python application that leverages Large Language Models (LLMs) and the LiteLLM library to dynamically generate personas, fetch knowledge sources, resolve conflicts, and produce tailored prompts. It is designed to assist with various software development tasks by providing context-aware prompts based on user input and predefined personas.

Sample questions:
1. What are the key components of an effective persona for prompt generation?
2. How can I create a persona that generates creative writing prompts?
3. What are the main features of the persona generator?

Related repository: [persona-driven-prompt-generator](https://github.com/tosin2013/persona-driven-prompt-generator)
""")
    with gr.Row():
        chatbot = gr.Chatbot(label="Chat History")

    with gr.Row():
        question = gr.Textbox(
            value=DEFAULT_QUESTION,
            label="Your Question",
            placeholder=DEFAULT_QUESTION
        )
    with gr.Row():
        submit_btn = gr.Button("Submit")
        clear_btn = gr.Button("Clear")

    submit_btn.click(
        generate_response,
        inputs=[question, chatbot],
        outputs=[chatbot]
    )

    # Clearing resets the chat history (None) and empties the question box ("").
    clear_btn.click(
        lambda: (None, ""),
        inputs=[],
        outputs=[chatbot, question]
    )
if __name__ == "__main__":
    demo.launch()
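
# To run locally (assuming this file is saved as app.py; the name is not fixed
# by the source):
#   HF_TOKEN=<token> python app.py
# Gradio serves the UI on http://127.0.0.1:7860 by default.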
|