import gradio as gr
from theme import fast_rtc_theme
import torch
import json
import uuid
import os
import time
import pytz
from datetime import datetime
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
from threading import Thread
from huggingface_hub import CommitScheduler
from pathlib import Path
import spaces

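# Install a recent libstdc++ at container startup (likely required by a
# compiled dependency in this Space's base image).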
os.system("apt-get update && apt-get install -y libstdc++6")

# Load HF token from the environment
token = os.environ["HF_TOKEN"]

# Load Model and Tokenizer
model_id = "large-traversaal/Mantra-14B"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=token,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
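# Token id(s) that terminate generation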
terminators = [tokenizer.eos_token_id]

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Set up logging and schedule periodic commits of the log folder to a Hugging Face dataset repo via CommitScheduler.
log_folder = Path("logs")
log_folder.mkdir(parents=True, exist_ok=True)
log_file = log_folder / f"chat_log_{uuid.uuid4()}.json"

scheduler = CommitScheduler(
    repo_id="DrishtiSharma/mantra-14b-user-interaction-log",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=0.01,  # commit interval in minutes, so this pushes roughly every 0.6 seconds
    token=token,
)

# Set timezone for logging timestamps
timezone = pytz.timezone("UTC")

@spaces.GPU(duration=60)
def chat(message, history, temperature, do_sample, max_tokens, top_p):
    start_time = time.time()
    timestamp = datetime.now(timezone).strftime("%Y-%m-%d %H:%M:%S %Z")
    
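    # Rebuild the conversation as a list of {"role": ..., "content": ...} messages,
    # the format expected by tokenizer.apply_chat_template below.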
    conversation_history = []
    for item in history:
        conversation_history.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            conversation_history.append({"role": "assistant", "content": item[1]})
    conversation_history.append({"role": "user", "content": message})
    
    messages = tokenizer.apply_chat_template(conversation_history, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=70.0, skip_prompt=True, skip_special_tokens=True
    )

    # Define generation parameters
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_p=top_p,  
        eos_token_id=terminators,
    )

    # Disable sampling if temperature is zero (deterministic generation)
    if temperature == 0:
        generate_kwargs["do_sample"] = False

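    # Run generate() on a background thread; TextIteratorStreamer hands the
    # decoded tokens to this thread as they are produced, so the loop below
    # can yield partial text to the UI token by token.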
    generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
    generation_thread.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    # Calculate total response time
    response_time = round(time.time() - start_time, 2)

    # Prepare log entry for the interaction
    log_data = {
        "timestamp": timestamp,
        "input": message,
        "output": partial_text,
        "response_time": response_time,
        "temperature": temperature,
        "do_sample": do_sample,
        "max_tokens": max_tokens,
        "top_p": top_p  
    }
    
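    # Append under the scheduler's lock so a log write never races with an
    # in-progress push of the log folder; each entry is one JSON object per line (JSONL).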
    with scheduler.lock:
        with log_file.open("a", encoding="utf-8") as f:
            f.write(json.dumps(log_data, ensure_ascii=False) + "\n")

# Function to clear chat history
def clear_chat():
    return [], []

# Function to export chat history as a downloadable file
def export_chat(history):
    if not history:
        return None  # No chat history to export

    file_path = "chat_history.txt"
    with open(file_path, "w", encoding="utf-8") as f:
        for msg in history:
            # A pending bot reply can be None; export it as an empty string
            f.write(f"User: {msg[0]}\nBot: {msg[1] or ''}\n")
    return file_path


# Gradio UI
with gr.Blocks(theme=fast_rtc_theme) as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("#### ⚙️🛠 Configure Settings")
            temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.1, label="Temperature", interactive=True)
            do_sample = gr.Checkbox(label="Sampling", value=True, interactive=True)
            max_tokens = gr.Slider(minimum=128, maximum=4096, step=1, value=1024, label="max_new_tokens", interactive=True)
            top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.1, label="top_p", interactive=True)
            
        
        with gr.Column(scale=3):
            gr.Markdown("# **Chat With Phi-4-Hindi** 💬 ")
            
            chat_interface = gr.ChatInterface(
                fn=chat,
                # Example prompts: a translation request, a Hindi math word
                # problem, and an English multiple-choice question
                examples=[
                    ["What is the English translation of: 'इस मॉडल को हिंदी और अंग्रेजी डेटा पर प्रशिक्षित किया गया था'?"],
                    ["टिम अपने 3 बच्चों को ट्रिक या ट्रीटिंग के लिए ले जाता है। वे 4 घंटे बाहर रहते हैं। हर घंटे वे x घरों में जाते हैं। हर घर में हर बच्चे को 3 ट्रीट मिलते हैं। उसके बच्चों को कुल 180 ट्रीट मिलते हैं। अज्ञात चर x का मान क्या है?"],
                    ["How do you play fetch? A) Throw the object for the dog to bring back to you. B) Get the object and bring it back to the dog."]
                ],
                additional_inputs=[temperature, do_sample, max_tokens, top_p],
                stop_btn="⏹ Stop",
                description="Phi-4-Hindi is a bilingual instruction-tuned LLM for Hindi and English, trained on a mixed dataset of 485K Hindi-English samples.",
            )

            with gr.Row():
                clear_btn = gr.Button("🧹 Clear Chat", variant="primary")
                export_btn = gr.Button("📥 Export Chat", variant="primary")

            # File component that receives the exported chat history for download
            export_file = gr.File(label="Exported Chat")

            # Connect the buttons to their handlers (clear and export chat)
            clear_btn.click(
                fn=clear_chat,
                outputs=[chat_interface.chatbot, chat_interface.chatbot_value]
            )

            export_btn.click(fn=export_chat, inputs=[chat_interface.chatbot], outputs=[export_file])

demo.launch()