# Spaces: Sleeping -- status text captured from the hosting page along with the file; not part of the program.
import os
import time
from typing import Tuple
from uuid import uuid4

import openai
import pandas as pd
import streamlit as st
# Read the OpenAI API key from the environment; os.getenv returns None when
# OPENAI_API_KEY is unset, in which case API calls made later will fail.
openai.api_key = os.getenv("OPENAI_API_KEY")
# Per-session identifier, created lazily and cached in Streamlit session state.
def get_session_id():
    """Return this session's unique ID, generating it on first use.

    The ID is a random UUID4 string stored under ``session_id`` in
    ``st.session_state``; later calls return the stored value.
    """
    if 'session_id' not in st.session_state:
        st.session_state.session_id = str(uuid4())
    return st.session_state.session_id
# Predefined few-shot examples; each entry has 'Problem', 'Rationale' and 'Answer' keys.
EXAMPLES = [
    {
        'Problem': 'What is deductive reasoning?',
        'Rationale': 'Deductive reasoning starts from general premises to arrive at a specific conclusion.',
        'Answer': 'It involves deriving specific conclusions from general premises.',
    },
    {
        'Problem': 'What is inductive reasoning?',
        'Rationale': 'Inductive reasoning involves drawing generalizations based on specific observations.',
        'Answer': 'It involves forming general rules from specific examples.',
    },
    {
        'Problem': 'Explain abductive reasoning.',
        'Rationale': 'Abductive reasoning finds the most likely explanation for incomplete observations.',
        'Answer': 'It involves finding the best possible explanation.',
    },
]
# STaR (Self-Taught Reasoner) algorithm implementation
class SelfTaughtReasoner:
    """Run a simplified STaR loop on top of the OpenAI completion API.

    For every dataset row the model is asked for a rationale and an answer.
    When the answer does not match the expected one, the model is asked again
    with the expected answer supplied as a hint ("rationalization"). Results
    are collected in ``generated_data`` and ``rationalized_data``.

    Fine-tuning is only simulated (see ``fine_tune_model``); completions are
    always requested from ``model_engine``.

    NOTE(review): ``openai.Completion.create`` is the legacy completions
    interface -- confirm it is available in the installed ``openai`` version.
    """

    # Column layout shared by both result tables.
    _RESULT_COLUMNS = ['Problem', 'Rationale', 'Answer', 'Is_Correct']

    def __init__(self, model_engine="text-davinci-003"):
        """Create a reasoner seeded with the predefined few-shot examples.

        Args:
            model_engine: Engine name passed to the completion API.
        """
        self.model_engine = model_engine
        # Copy the list so add_prompt_example() never grows the module-level EXAMPLES.
        self.prompt_examples = list(EXAMPLES)
        self.iterations = 0
        self.generated_data = pd.DataFrame(columns=self._RESULT_COLUMNS)
        self.rationalized_data = pd.DataFrame(columns=self._RESULT_COLUMNS)
        self.fine_tuned_model = None  # Name of the (simulated) fine-tuned model

    def add_prompt_example(self, problem: str, rationale: str, answer: str):
        """Append one example to the few-shot examples used in prompts."""
        self.prompt_examples.append({
            'Problem': problem,
            'Rationale': rationale,
            'Answer': answer,
        })

    def construct_prompt(self, problem: str, include_answer: bool = False, answer: str = "") -> str:
        """Build the few-shot prompt for ``problem``.

        Args:
            problem: The problem the model should reason about.
            include_answer: When true, add ``answer`` to the prompt as a hint.
            answer: Hint text; only used when ``include_answer`` is true.

        Returns:
            All few-shot examples followed by the problem (and the optional
            hint), ending with ``"Rationale:"`` for the model to continue.
        """
        parts = [
            f"Problem: {example['Problem']}\n"
            f"Rationale: {example['Rationale']}\n"
            f"Answer: {example['Answer']}\n\n"
            for example in self.prompt_examples
        ]
        parts.append(f"Problem: {problem}\n")
        if include_answer:
            parts.append(f"Answer (as hint): {answer}\n")
        parts.append("Rationale:")
        return "".join(parts)

    def _complete(self, prompt, max_tokens, temperature, stop):
        """Request one completion and return its stripped text."""
        response = openai.Completion.create(
            engine=self.model_engine,
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=stop,
        )
        return response.choices[0].text.strip()

    def _reason(self, problem, include_answer=False, answer=""):
        """Generate a rationale, then an answer conditioned on that rationale.

        Any exception raised by the API calls is shown with ``st.error`` and
        reported to the caller as a pair of empty strings.
        """
        prompt = self.construct_prompt(problem, include_answer=include_answer, answer=answer)
        try:
            rationale = self._complete(
                prompt,
                max_tokens=150,
                temperature=0.7,
                stop=["\n\n", "Problem:", "Answer:"],
            )
            # Second call: extend the prompt with the rationale and ask for the answer.
            generated_answer = self._complete(
                f"{prompt} {rationale}\nAnswer:",
                max_tokens=10,
                temperature=0,
                stop=["\n", "\n\n", "Problem:"],
            )
        except Exception as e:
            st.error(f"β Error generating rationale and answer: {e}")
            return "", ""
        return rationale, generated_answer

    def generate_rationale_and_answer(self, problem: str) -> Tuple[str, str]:
        """Generate a rationale and an answer for ``problem``.

        Returns:
            ``(rationale, answer)``, or ``("", "")`` if the API call failed.
        """
        return self._reason(problem)

    def rationalize(self, problem: str, correct_answer: str) -> Tuple[str, str]:
        """Regenerate a rationale and answer with the correct answer as a hint.

        Returns:
            ``(rationale, answer)``, or ``("", "")`` if the API call failed.
        """
        return self._reason(problem, include_answer=True, answer=str(correct_answer))

    @staticmethod
    def _answers_match(generated, expected) -> bool:
        """Compare two answers ignoring case and surrounding whitespace.

        Both sides go through ``str()`` so non-string expected answers (for
        example numbers read from a CSV) do not raise.
        """
        return str(generated).strip().lower() == str(expected).strip().lower()

    @staticmethod
    def _append_row(frame, row):
        """Return ``frame`` with ``row`` (a dict) added as its last row."""
        new_row = pd.DataFrame([row], columns=frame.columns)
        if frame.empty:
            # Skip concatenating with an empty frame.
            return new_row
        return pd.concat([frame, new_row], ignore_index=True)

    def fine_tune_model(self):
        """Simulate fine-tuning on the generated rationales.

        No training happens: the method sleeps for one second, stores a
        session-specific model name in ``fine_tuned_model`` and reports it.
        """
        time.sleep(1)  # Simulate time taken for fine-tuning
        self.fine_tuned_model = f"{self.model_engine}-fine-tuned-{get_session_id()}"
        st.success(f"β Model fine-tuned: {self.fine_tuned_model}")

    def run_iteration(self, dataset: pd.DataFrame):
        """Run one STaR iteration over ``dataset``.

        Args:
            dataset: Frame with 'Problem' and 'Answer' columns.

        Every first attempt is recorded in ``generated_data``. Incorrect
        attempts are retried via ``rationalize``; retries that then match the
        expected answer are recorded in ``rationalized_data``.
        """
        st.write(f"### Iteration {self.iterations + 1}")
        progress_bar = st.progress(0)
        total = len(dataset)
        # Track progress by position: index labels need not be 0..n-1 integers.
        for position, (_, row) in enumerate(dataset.iterrows(), start=1):
            problem = row['Problem']
            correct_answer = row['Answer']
            # Generate rationale and answer
            rationale, answer = self.generate_rationale_and_answer(problem)
            is_correct = self._answers_match(answer, correct_answer)
            # Record the generated data
            self.generated_data = self._append_row(self.generated_data, {
                'Problem': problem,
                'Rationale': rationale,
                'Answer': answer,
                'Is_Correct': is_correct,
            })
            # If incorrect, perform rationalization
            if not is_correct:
                rationale, answer = self.rationalize(problem, correct_answer)
                is_correct = self._answers_match(answer, correct_answer)
                if is_correct:
                    self.rationalized_data = self._append_row(self.rationalized_data, {
                        'Problem': problem,
                        'Rationale': rationale,
                        'Answer': answer,
                        'Is_Correct': is_correct,
                    })
            progress_bar.progress(position / total)
        # Fine-tune the model on correct rationales
        st.write("π Fine-tuning the model on correct rationales...")
        self.fine_tune_model()
        self.iterations += 1
# Streamlit app
def _parse_manual_dataset(text):
    """Parse 'Problem | Answer' lines into a DataFrame; lines without '|' are skipped."""
    rows = []
    for line in text.strip().split('\n'):
        if '|' in line:
            problem, answer = line.split('|', 1)
            rows.append({'Problem': problem.strip(), 'Answer': answer.strip()})
    return pd.DataFrame(rows, columns=['Problem', 'Answer'])


def main():
    """Render the STaR demo: examples, dataset input, the STaR run and a test prompt.

    The ``SelfTaughtReasoner`` and the loaded dataset are kept in
    ``st.session_state`` so they survive Streamlit reruns.

    NOTE(review): the original indentation was lost; Step 1 is assumed to live
    in the left column, the example list in the right column, and Steps 2-4 at
    full width -- confirm against the intended layout.
    """
    st.title("π€ Self-Taught Reasoner (STaR) Demonstration")
    # Initialize the Self-Taught Reasoner once per session
    if 'star' not in st.session_state:
        st.session_state.star = SelfTaughtReasoner()
    star = st.session_state.star
    # Wide format layout
    col1, col2 = st.columns([1, 2])  # Column widths: col1 for input, col2 for display

    # Step 1: Few-Shot Prompt Examples
    with col1:
        st.header("Step 1: Add Few-Shot Prompt Examples")
        st.write("Choose an example from the dropdown or input your own.")
        # Select by index and format the label separately; parsing the index
        # back out of the label ("Example 1: ...") fails on the trailing colon.
        example_idx = st.selectbox(
            "Select a predefined example",
            range(len(EXAMPLES)),
            format_func=lambda i: f"Example {i + 1}: {EXAMPLES[i]['Problem']}",
        )
        selected = EXAMPLES[example_idx]
        # Keys include the example index so the fields are re-prefilled when
        # a different example is selected.
        # NOTE(review): newer Streamlit releases reject text_area heights
        # below 68 px, hence 68 rather than 50 -- confirm for the pinned version.
        problem_text = st.text_area(
            "Problem", value=selected['Problem'], height=68,
            key=f"example_problem_{example_idx}",
        )
        rationale_text = st.text_area(
            "Rationale", value=selected['Rationale'], height=100,
            key=f"example_rationale_{example_idx}",
        )
        answer_text = st.text_input(
            "Answer", value=selected['Answer'],
            key=f"example_answer_{example_idx}",
        )
        if st.button("Add Example"):
            star.add_prompt_example(problem_text, rationale_text, answer_text)
            st.success("Example added successfully!")

    with col2:
        # Display current prompt examples
        if star.prompt_examples:
            st.subheader("Current Prompt Examples:")
            for idx, example in enumerate(star.prompt_examples):
                st.write(f"**Example {idx + 1}:**")
                st.write(f"Problem: {example['Problem']}")
                st.write(f"Rationale: {example['Rationale']}")
                st.write(f"Answer: {example['Answer']}")

    # Step 2: Input Dataset
    st.header("Step 2: Input Dataset")
    dataset_input_method = st.radio("How would you like to input the dataset?", ("Manual Entry", "Upload CSV"))
    if dataset_input_method == "Manual Entry":
        dataset_problems = st.text_area("Enter problems and answers in the format 'Problem | Answer', one per line.", height=200)
        if st.button("Submit Dataset"):
            parsed = _parse_manual_dataset(dataset_problems)
            if parsed.empty:
                st.warning("No valid 'Problem | Answer' lines found.")
            else:
                st.session_state.dataset = parsed
                st.success("Dataset loaded.")
    else:
        uploaded_file = st.file_uploader("Upload a CSV file with 'Problem' and 'Answer' columns.", type=['csv'])
        if uploaded_file is not None:
            uploaded = pd.read_csv(uploaded_file)
            missing = {'Problem', 'Answer'} - set(uploaded.columns)
            if missing:
                st.error(f"CSV is missing required column(s): {', '.join(sorted(missing))}")
            else:
                st.session_state.dataset = uploaded
                st.success("Dataset loaded.")
    if 'dataset' in st.session_state:
        st.subheader("Current Dataset:")
        st.dataframe(st.session_state.dataset.head())

    # Step 3: Run STaR Process
    st.header("Step 3: Run STaR Process")
    num_iterations = st.number_input("Number of Iterations to Run:", min_value=1, max_value=10, value=1)
    if st.button("Run STaR"):
        if 'dataset' not in st.session_state:
            # Without this guard the run fails on the missing session key.
            st.warning("Please load a dataset in Step 2 before running STaR.")
        else:
            for _ in range(int(num_iterations)):
                star.run_iteration(st.session_state.dataset)
            st.header("Results")
            st.subheader("Generated Data")
            st.dataframe(star.generated_data)
            st.subheader("Rationalized Data")
            st.dataframe(star.rationalized_data)
            st.write("The model has been fine-tuned iteratively.")

    # Step 4: Test the Fine-Tuned Model
    st.header("Step 4: Test the Fine-Tuned Model")
    test_problem = st.text_area("Enter a new problem to solve:", height=100)
    if st.button("Solve Problem"):
        if not test_problem:
            st.warning("Please enter a problem to solve.")
        else:
            rationale, answer = star.generate_rationale_and_answer(test_problem)
            st.subheader("Rationale:")
            st.write(rationale)
            st.subheader("Answer:")
            st.write(answer)

    # Footer with custom HTML/JS component
    st.markdown("---")
    st.write("Developed as a demonstration of the STaR method with enhanced Streamlit capabilities.")
    st.components.v1.html("""
    <div style="text-align: center; margin-top: 20px;">
    <h3>π Boost Your AI Reasoning with STaR! π</h3>
    </div>
    """)
# Run the app only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()