File size: 5,331 Bytes
2af06c8
038b068
 
 
 
 
 
 
 
 
 
 
2af06c8
038b068
2af06c8
038b068
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2af06c8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import gradio as gr
from smolagents import HfApiModel, CodeAgent, Tool
from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, load_tool, tool
from huggingface_hub import login
from llama_index.retrievers.bm25 import BM25Retriever
import spaces
import torch
from transformers.models.speecht5.number_normalizer import EnglishNumberNormalizer
from string import punctuation
import re
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed

# Prefer the first CUDA device; fall back to CPU when no GPU is visible.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Parler-TTS checkpoint used for speech synthesis.
# Alternative (currently unused): "parler-tts/parler-tts-large-v1"
repo_id = "parler-tts/parler-tts-mini-v1"

# Model, tokenizer and feature extractor all come from the same checkpoint.
tts_model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id)
tts_model = tts_model.to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)

# Sampling rate reported by the feature extractor; returned alongside generated audio.
SAMPLE_RATE = feature_extractor.sampling_rate
# Seed passed to set_seed() right before each TTS generation.
SEED = 42

# Used by preprocess() to rewrite numbers in text before synthesis.
number_normalizer = EnglishNumberNormalizer()

# Runs of capital letters, optionally dotted (e.g. "GPU", "U.S.A"), that the
# TTS model should read letter by letter. Compiled once at import time.
_ABBREVIATION_RE = re.compile(r'\b[A-Z][A-Z\.]+\b')


def preprocess(text):
    """Normalise *text* so the Parler-TTS model reads it naturally.

    Steps:
      1. Rewrite numbers with ``number_normalizer`` and strip whitespace.
      2. Replace hyphens with spaces.
      3. Make sure the text ends with punctuation (appends ``"."`` otherwise).
      4. Spell out all-caps abbreviations letter by letter
         (``"GPU"`` -> ``"G P U"``, ``"U.S.A"`` -> ``"U S A"``).

    Args:
        text: Raw text to be spoken.

    Returns:
        The normalised text, or ``""`` when nothing but whitespace/hyphens
        remains (previously this raised ``IndexError``).
    """
    text = number_normalizer(text).strip()
    # Strip again: a leading/trailing hyphen would otherwise leave a space
    # that ends up before the appended full stop.
    text = text.replace("-", " ").strip()
    if not text:
        return text
    if text[-1] not in punctuation:
        text = f"{text}."

    def _spell_out(match):
        # "U.S.A" -> "USA" -> "U S A"
        return " ".join(match.group(0).replace(".", ""))

    # One word-boundary-aware pass: only whole abbreviation tokens are
    # rewritten, never substrings inside other words, independent of order.
    return _ABBREVIATION_RE.sub(_spell_out, text)

@spaces.GPU
def gen_tts(text, description):
    """Synthesise speech for *text* in the voice described by *description*.

    Args:
        text: Content to speak; passed through ``preprocess`` first.
        description: Natural-language voice description that conditions the
            Parler-TTS model.

    Returns:
        ``(SAMPLE_RATE, audio)`` where ``audio`` is the generated waveform
        moved to CPU as a squeezed NumPy array.
    """
    # The description goes in as the main input; the text to speak is the prompt.
    description_tokens = tokenizer(description.strip(), return_tensors="pt").to(device)
    prompt_tokens = tokenizer(preprocess(text), return_tensors="pt").to(device)

    # Reseed right before sampling so the same inputs yield the same draw.
    set_seed(SEED)
    waveform = tts_model.generate(
        input_ids=description_tokens.input_ids,
        attention_mask=description_tokens.attention_mask,
        prompt_input_ids=prompt_tokens.input_ids,
        prompt_attention_mask=prompt_tokens.attention_mask,
        do_sample=True,
        temperature=1.0,
    )

    audio_arr = waveform.cpu().numpy().squeeze()
    return SAMPLE_RATE, audio_arr

class RetrieverTool(Tool):
    """smolagents tool that retrieves documents from a persisted BM25 index.

    The index is loaded once, at construction time, with
    ``BM25Retriever.from_persist_dir``.
    """

    name = "retriever"
    # NOTE(review): this LLM-facing text says "semantic search" over
    # "transformers documentation", but retrieval is BM25 (keyword-based) and
    # the app loads "./ml_notes_index". Consider rewording; left unchanged here
    # because it shapes how the agent uses the tool.
    description = "Uses semantic search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. Ask the question as an human would, with simple explanation. The underlying index is BM25.",
        }
    }
    output_type = "string"

    def __init__(self, path, **kwargs):
        """Load the BM25 retriever persisted under *path*.

        Args:
            path: Directory previously written by a llama-index BM25 retriever.
            **kwargs: Forwarded to ``smolagents.Tool``.
        """
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_persist_dir(path)

    def forward(self, query: str) -> str:
        """Return the retrieved documents for *query* as one numbered text block.

        Raises:
            TypeError: If *query* is not a string. (An explicit raise rather
                than ``assert``, which is stripped when Python runs with ``-O``.)
        """
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        docs = self.retriever.retrieve(query)

        return "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {i} =====\n{doc.text}"
            for i, doc in enumerate(docs)
        )

# Directory holding the persisted BM25 index that the retriever tool loads.
path = "./ml_notes_index"

# Single LLM backend shared by both agents below.
model = HfApiModel(
    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',
    temperature=0.5,
    max_tokens=4086,  # NOTE(review): possibly intended 4096 — confirm.
    custom_role_conversions=None,
)

retriever_tool = RetrieverTool(path)

# Answering agent: may consult the retriever for up to 4 steps.
agent = CodeAgent(model=model, tools=[retriever_tool], max_steps=4, verbosity_level=2)

# Rephrasing agent used by greet(): no tools, a single step.
summarization_agent = CodeAgent(model=model, tools=[], max_steps=1, verbosity_level=2)


def greet(question):
    """Answer *question* in text and speech.

    Pipeline: ``agent`` (retrieval-backed answer) -> ``summarization_agent``
    (rephrased for speech) -> ``gen_tts`` (Parler-TTS audio).

    Args:
        question: The user's question from the Gradio textbox.

    Returns:
        ``(answer_text, (sample_rate, audio_array))`` for the Textbox and Audio
        outputs, or ``("", None)`` when the question is blank.
    """
    if not question or not question.strip():
        # Don't spend LLM and GPU time on an empty submission.
        return "", None

    agent_output = agent.run(question)
    # NOTE(review): smolagents' run() returns whatever the agent passed to
    # final_answer, which need not be a str. gen_tts/preprocess require text,
    # so coerce it here.
    result = str(
        summarization_agent.run(
            f"Rephrase the following output since it will be passed to a Text-To-Speech Model: {agent_output}"
        )
    )

    # Voice description that conditions the Parler-TTS speaker style.
    description = "Laura's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
    sample_rate, audio = gen_tts(result, description)

    return result, (sample_rate, audio)

# login()

# Custom stylesheet passed to gr.Blocks(css=...) for the app page.
# NOTE(review): it only styles "#share-btn-container", but no component with
# that elem_id is created in the visible UI code — possibly left over from a
# template; confirm before removing.
css = """
        #share-btn-container {
            display: flex;
            padding-left: 0.5rem !important;
            padding-right: 0.5rem !important;
            background-color: #000000;
            justify-content: center;
            align-items: center;
            border-radius: 9999px !important; 
            width: 13rem;
            margin-top: 10px;
            margin-left: auto;
            flex: unset !important;
        }
"""

# Title banner shown at the top of the page.
_HEADER_HTML = """
            <div style="text-align: center; max-width: 700px; margin: 0 auto;">
              <div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
                <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
                  ML Professor with Voice 🗣️
                </h1>
              </div>
            </div>
        """

# Layout: question and button on the left, text and spoken answer on the right.
with gr.Blocks(css=css) as block:
    gr.HTML(_HEADER_HTML)

    with gr.Row():
        # Left column: user input.
        with gr.Column():
            input_text = gr.Textbox(label="Your Question", lines=2)
            run_button = gr.Button(value="Ask Question", variant="primary")
        # Right column: outputs populated by greet().
        with gr.Column():
            text_output = gr.Textbox(label="Answer", lines=4)
            audio_out = gr.Audio(label="Voice Answer", type="numpy")

    # greet returns (answer_text, (sample_rate, audio)), one value per output.
    run_button.click(
        fn=greet,
        inputs=[input_text],
        outputs=[text_output, audio_out],
    )

# Enable request queuing, then start the server (share=True requests a share link).
block.queue()
block.launch(share=True)