# advanced / app.py — YourBench Gradio configuration demo
# Author: Alina Lozovskaya — first commit d1ed69b
# (page metadata "raw · history · blame · 5.28 kB" from the Hugging Face file
# viewer, preserved here as a comment so the module remains valid Python)
import os
import time
import pathlib
import threading
import shutil
import gradio as gr
import yaml
import io
from loguru import logger
from yourbench.pipeline import run_pipeline
# Directory that receives user-uploaded source documents; created eagerly at
# import time so the FileExplorer widget below has a valid root to browse.
UPLOAD_DIRECTORY = pathlib.Path("/app/uploaded_files")
UPLOAD_DIRECTORY.mkdir(parents=True, exist_ok=True)

# Path the generated pipeline configuration is written to and later read
# back from when the pipeline is launched.
CONFIG_PATH = pathlib.Path("/app/yourbench_config.yml")

# In-memory buffer that accumulates loguru output from the "yourbench"
# package for live display in the UI's log panel.
yourbench_log_stream = io.StringIO()
def custom_log_handler(message):
    """Loguru sink: append one formatted log record to the in-memory buffer.

    Loguru delivers each record to a callable sink already terminated with a
    newline, so the message is written as-is; appending another "\\n" (as the
    original code did) produced a blank line after every record in the UI.
    """
    yourbench_log_stream.write(message)
def get_log_content():
    """Return the full captured log text for display in the UI log panel.

    Uses ``getvalue()`` instead of ``seek(0)`` + ``read()``: ``getvalue()``
    does not move the stream's shared position, so a concurrent write from
    the pipeline thread cannot land at offset 0 and overwrite earlier log
    lines. The leftover debug ``print(len(content))`` is removed.
    """
    return yourbench_log_stream.getvalue()
# Route loguru records originating from the "yourbench" package into the
# in-memory buffer; the UI timer polls that buffer for display.
logger.add(custom_log_handler, filter="yourbench")
def start_task():
    """Launch the pipeline in a background daemon thread and return at once.

    The original implementation called ``join()`` immediately after
    ``start()``, which blocked the Gradio click handler until the pipeline
    finished — defeating the point of the background thread (the UI timer is
    meant to poll the log buffer while the pipeline runs). The thread object
    is returned so a caller that *does* want to wait can join it explicitly.
    """
    task_thread = threading.Thread(target=run_pipeline, args=(CONFIG_PATH,), daemon=True)
    task_thread.start()
    return task_thread
def generate_config(
    hf_token,
    hf_org,
    model_name,
    provider,
    base_url,
    api_key,
    max_concurrent_requests,
    ingestion_source,
    ingestion_output,
    run_ingestion,
    summarization_source,
    summarization_output,
    run_summarization
):
    """Assemble a YourBench YAML configuration string from the UI inputs.

    The result has three top-level sections: Hugging Face credentials, a
    single-entry model list, and the pipeline stage settings (ingestion and
    summarization). Returns the YAML text produced by ``yaml.dump``.
    """
    hf_section = dict(token=hf_token, private=True, hf_organization=hf_org)

    model_entry = dict(
        model_name=model_name,
        provider=provider,
        base_url=base_url,
        api_key=api_key,
        max_concurrent_requests=max_concurrent_requests,
    )

    pipeline_section = dict(
        ingestion=dict(
            source_documents_dir=ingestion_source,
            output_dir=ingestion_output,
            run=run_ingestion,
        ),
        summarization=dict(
            source_dataset_name=summarization_source,
            output_dataset_name=summarization_output,
            run=run_summarization,
        ),
    )

    settings = dict(
        hf_configuration=hf_section,
        model_list=[model_entry],
        pipeline=pipeline_section,
    )
    return yaml.dump(settings, default_flow_style=False)
def save_config(yaml_text):
    """Persist the edited YAML text to CONFIG_PATH and return a UI status.

    Writes with an explicit UTF-8 encoding so the config round-trips
    identically on platforms whose default locale encoding differs (the
    original ``open(..., "w")`` used the platform default).
    """
    CONFIG_PATH.write_text(yaml_text, encoding="utf-8")
    return "✅ Config saved as config.yaml!"
def save_files(files: list[str]):
    """Move each uploaded temporary file into UPLOAD_DIRECTORY.

    ``files`` is the list of temp-file paths Gradio hands to an upload
    callback; each one is moved (not copied) into the upload directory under
    its original basename. Returns a human-readable status line listing the
    destination paths.
    """
    stored = []
    for source in map(pathlib.Path, files):
        destination = UPLOAD_DIRECTORY / source.name
        shutil.move(str(source), str(destination))
        stored.append(str(destination))
    return f"Files have been successfully saved to: {', '.join(stored)}"
def start_youbench():
    """Run the full YourBench pipeline synchronously with the saved config.

    NOTE(review): the name looks like a typo for "start_yourbench"; kept
    unchanged because it is part of the public interface. Currently unused
    by the UI below, which wires the start button to ``start_task`` instead.
    """
    run_pipeline(CONFIG_PATH, debug=False)
# ---------------------------------------------------------------------------
# Gradio UI: one tab per configuration section, a file-upload tab, and a run
# tab whose timer polls the in-memory log buffer twice per second.
# ---------------------------------------------------------------------------
app = gr.Blocks()
with app:
    gr.Markdown("## YourBench Configuration")

    # Hugging Face credentials embedded into the generated config.
    with gr.Tab("HF Configuration"):
        hf_token = gr.Textbox(label="HF Token")
        hf_org = gr.Textbox(label="HF Organization")

    # Single model entry for the config's model_list.
    with gr.Tab("Model Settings"):
        model_name = gr.Textbox(label="Model Name")
        provider = gr.Dropdown(["openrouter", "openai", "huggingface"], value="huggingface", label="Provider")
        base_url = gr.Textbox(label="Base URL")
        api_key = gr.Textbox(label="API Key")
        max_concurrent_requests = gr.Dropdown([8, 16, 32], value=16, label="Max Concurrent Requests")

    # Per-stage pipeline settings (ingestion and summarization).
    with gr.Tab("Pipeline Stages"):
        ingestion_source = gr.Textbox(label="Ingestion Source Directory")
        ingestion_output = gr.Textbox(label="Ingestion Output Directory")
        run_ingestion = gr.Checkbox(label="Run Ingestion", value=False)
        summarization_source = gr.Textbox(label="Summarization Source Dataset")
        summarization_output = gr.Textbox(label="Summarization Output Dataset")
        run_summarization = gr.Checkbox(label="Run Summarization", value=False)

    # Preview the generated YAML, then persist it to CONFIG_PATH.
    with gr.Tab("Config"):
        config_output = gr.Code(label="Generated Config", language="yaml")
        preview_button = gr.Button("Generate Config")
        save_button = gr.Button("Save Config")
        preview_button.click(generate_config,
                             inputs=[hf_token, hf_org, model_name, provider, base_url, api_key,
                                     max_concurrent_requests, ingestion_source, ingestion_output,
                                     run_ingestion, summarization_source, summarization_output, run_summarization],
                             outputs=config_output)
        save_button.click(save_config, inputs=[config_output], outputs=[gr.Textbox(label="Save Status")])

    # Upload source documents into UPLOAD_DIRECTORY and browse what is there.
    with gr.Tab("Files"):
        file_input = gr.File(label="Upload text files", file_count="multiple", file_types=[".txt", ".md", ".html"])
        file_explorer = gr.FileExplorer(root_dir=UPLOAD_DIRECTORY, interactive=False, label="Current Files")
        output = gr.Textbox(label="Log")
        file_input.upload(save_files, file_input, output)

    # Launch the pipeline and stream captured loguru output into the panel;
    # the 0.5 s timer re-reads the whole buffer on every tick.
    with gr.Tab("Run Generation"):
        log_output = gr.Code(label="Log Output", language=None, lines=20, interactive=False)
        start_button = gr.Button("Start Long-Running Task")
        timer = gr.Timer(0.5, active=True)
        timer.tick(get_log_content, outputs=log_output)
        start_button.click(start_task)

app.launch()