|
import html |
|
import json |
|
from time import sleep |
|
|
|
import gradio as gr |
|
from omagent_core.clients.devices.app.callback import AppCallback |
|
from omagent_core.clients.devices.app.input import AppInput |
|
from omagent_core.clients.devices.app.schemas import ContentStatus, MessageType |
|
from omagent_core.engine.automator.task_handler import TaskHandler |
|
from omagent_core.engine.http.models.workflow_status import terminal_status |
|
from omagent_core.engine.workflow.conductor_workflow import ConductorWorkflow |
|
from omagent_core.services.connectors.redis import RedisConnector |
|
from omagent_core.utils.build import build_from_file |
|
from omagent_core.utils.container import container |
|
from omagent_core.utils.logger import logging |
|
from omagent_core.utils.registry import registry |
|
|
|
# Import registered modules up front. Presumably this lets registry-decorated
# workers/components register themselves before use — confirm against
# omagent_core.utils.registry.
registry.import_module()

# Register the Redis connector under the name WebpageClient later looks up via
# container.get_connector("redis_stream_client").
container.register_connector(name="redis_stream_client", connector=RedisConnector)

# Register the app-side callback and input implementations with the container.
# NOTE(review): presumably consumed by workers that exchange messages with this
# web client over the Redis streams — verify.
container.register_callback(callback=AppCallback)

container.register_input(input=AppInput)
|
|
|
|
|
class WebpageClient:
    """Gradio web front end for OmAgent conductor workflows.

    Two modes are available:

    * ``start_interactor``: a chat UI. User input is published to the
      ``<workflow_id>_input`` Redis stream. Replies are read from
      ``<workflow_id>_output`` and progress notes from
      ``<workflow_id>_running``.
    * ``start_processor``: an upload UI on port 7861. Uploaded files are
      published to the ``image_process`` Redis stream, and the processor
      workflow is polled until it reaches a terminal status.

    Args:
        interactor: Workflow started for chat sessions.
        processor: Workflow started for file processing.
        config_path: Path passed to ``build_from_file`` to build the worker
            configuration.
        workers: Workers handed to ``TaskHandler``. Defaults to a fresh empty
            list per client.
    """

    def __init__(
        self,
        interactor: ConductorWorkflow = None,
        processor: ConductorWorkflow = None,
        config_path: str = "./config",
        workers: list = None,
    ) -> None:
        self._interactor = interactor
        self._processor = processor
        self._config_path = config_path
        # A literal ``[]`` default would be one list shared by every client
        # created without ``workers``; allocate a fresh one instead.
        self._workers = workers if workers is not None else []
        self._workflow_instance_id = None
        self._worker_config = build_from_file(self._config_path)
        self._task_to_domain = {}
        # Text accumulated from INCOMPLETE output chunks of a streamed answer.
        self._incomplete_message = ""
        self._custom_css = """
            #OmAgent {
                height: 100vh !important;
                max-height: calc(100vh - 190px) !important;
                overflow-y: auto;
            }

            .running-message {
                margin: 0;
                padding: 2px 4px;
                white-space: pre-wrap;
                word-wrap: break-word;
                font-family: inherit;
            }

            /* Remove the background and border of the message box */
            .message-wrap {
                background: none !important;
                border: none !important;
                padding: 0 !important;
                margin: 0 !important;
            }

            /* Remove the bubble style of the running message */
            .message:has(.running-message) {
                background: none !important;
                border: none !important;
                padding: 0 !important;
                box-shadow: none !important;
            }
        """

    def _redis_client(self):
        """Return the raw client of the registered ``redis_stream_client``."""
        return container.get_connector("redis_stream_client")._client

    def _build_chat_interface(self, add_fn, bot_fn):
        """Build the chat layout shared by both UI modes.

        Args:
            add_fn: Submit handler ``(history, message) -> (history, textbox)``.
            bot_fn: Generator ``(history) -> history`` that renders replies.

        Returns:
            The ``gr.Blocks`` app; it has not been launched yet.
        """
        with gr.Blocks(title="OmAgent", css=self._custom_css) as chat_interface:
            chatbot = gr.Chatbot(
                elem_id="OmAgent",
                bubble_full_width=False,
                type="messages",
                height="100%",
            )
            chat_input = gr.MultimodalTextbox(
                interactive=True,
                file_count="multiple",
                placeholder="Enter message or upload file...",
                show_label=False,
            )
            chat_msg = chat_input.submit(
                add_fn, [chatbot, chat_input], [chatbot, chat_input]
            )
            bot_msg = chat_msg.then(bot_fn, chatbot, chatbot, api_name="bot_response")
            # The submit handler disables the textbox; re-enable it once the
            # bot handler has finished.
            bot_msg.then(
                lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input]
            )
        return chat_interface

    def start_interactor(self):
        """Start the worker processes, then build and launch the chat UI.

        On Ctrl+C, any in-flight interactor workflow is terminated and the
        ``KeyboardInterrupt`` is re-raised.
        """
        self._task_handler_interactor = TaskHandler(
            worker_config=self._worker_config,
            workers=self._workers,
            task_to_domain=self._task_to_domain,
        )
        self._task_handler_interactor.start_processes()
        try:
            chat_interface = self._build_chat_interface(self.add_message, self.bot)
            chat_interface.launch()
        except KeyboardInterrupt:
            logging.info("\nDetected Ctrl+C, stopping workflow...")
            if self._workflow_instance_id is not None:
                self._interactor._executor.terminate(
                    workflow_id=self._workflow_instance_id
                )
            raise

    def stop_interactor(self):
        """Stop the worker processes started by ``start_interactor``."""
        self._task_handler_interactor.stop_processes()

    def start_processor(self):
        """Start the worker processes, then launch the upload UI on port 7861.

        On Ctrl+C, any in-flight processor workflow is terminated and the
        ``KeyboardInterrupt`` is re-raised.
        """
        self._task_handler_processor = TaskHandler(
            worker_config=self._worker_config,
            workers=self._workers,
            task_to_domain=self._task_to_domain,
        )
        self._task_handler_processor.start_processes()
        try:
            chat_interface = self._build_chat_interface(
                self.add_processor_message, self.processor_bot
            )
            chat_interface.launch(server_port=7861)
        except KeyboardInterrupt:
            logging.info("\nDetected Ctrl+C, stopping workflow...")
            if self._workflow_instance_id is not None:
                self._processor._executor.terminate(
                    workflow_id=self._workflow_instance_id
                )
            raise

    def stop_processor(self):
        """Stop the worker processes started by ``start_processor``."""
        self._task_handler_processor.stop_processes()

    def add_message(self, history, message):
        """Submit handler for the chat UI.

        Starts an interactor workflow if none is running. Echoes the submitted
        files and text into ``history`` and publishes them as one user message
        on the ``<workflow_id>_input`` Redis stream.

        Args:
            history: Chatbot history in Gradio "messages" format.
            message: MultimodalTextbox value with ``"files"`` and ``"text"``.

        Returns:
            The updated history, plus a cleared and disabled textbox.
        """
        if self._workflow_instance_id is None:
            self._workflow_instance_id = self._interactor.start_workflow_with_input(
                workflow_input={}, task_to_domain=self._task_to_domain
            )
        contents = []
        for path in message["files"]:
            history.append({"role": "user", "content": {"path": path}})
            contents.append({"type": "image_url", "data": path})
        if message["text"] is not None:
            history.append({"role": "user", "content": message["text"]})
            contents.append({"type": "text", "data": message["text"]})
        result = {
            "agent_id": self._workflow_instance_id,
            "messages": [{"role": "user", "content": contents}],
            "kwargs": {},
        }
        self._redis_client().xadd(
            f"{self._workflow_instance_id}_input",
            {"payload": json.dumps(result, ensure_ascii=False)},
        )
        return history, gr.MultimodalTextbox(value=None, interactive=False)

    def add_processor_message(self, history, message):
        """Submit handler for the upload UI.

        Starts a processor workflow if none is running. Echoes the uploaded
        files into ``history`` and publishes them to the ``image_process``
        Redis stream, each tagged with its position as ``resource_id``. Any
        typed text is ignored.

        Returns:
            The updated history, plus a cleared and disabled textbox.
        """
        if self._workflow_instance_id is None:
            self._workflow_instance_id = self._processor.start_workflow_with_input(
                workflow_input={}, task_to_domain=self._task_to_domain
            )
        image_items = []
        for idx, path in enumerate(message["files"]):
            history.append({"role": "user", "content": {"path": path}})
            image_items.append(
                {"type": "image_url", "resource_id": str(idx), "data": str(path)}
            )
        result = {"content": image_items}
        self._redis_client().xadd(
            "image_process", {"payload": json.dumps(result, ensure_ascii=False)}
        )
        return history, gr.MultimodalTextbox(value=None, interactive=False)

    def bot(self, history: list):
        """Stream the interactor workflow's output into the chat history.

        Polls two Redis streams for the current workflow and yields ``history``
        after each rendered entry. Each entry is acknowledged after it is
        yielded.

        * ``<workflow_id>_running``: progress notes, rendered as HTML-escaped
          ``<pre class="running-message">`` blocks.
        * ``<workflow_id>_output``: agent replies. Text chunks marked
          ``ContentStatus.INCOMPLETE`` are accumulated into one assistant
          entry, which is updated in place until the closing chunk arrives.

        Returns after an output message with ``interaction_type == 1`` or with
        content status ``END_ANSWER``. ``END_ANSWER`` also clears the workflow
        id, so the next submission starts a new workflow.
        """
        stream_name = f"{self._workflow_instance_id}_output"
        consumer_name = f"{self._workflow_instance_id}_agent"
        group_name = "omappagent"
        running_stream_name = f"{self._workflow_instance_id}_running"
        self._check_redis_stream_exist(stream_name, group_name)
        self._check_redis_stream_exist(running_stream_name, group_name)
        # Index in ``history`` of the entry receiving streamed text, or None
        # when no streamed answer is open in this call.
        stream_index = None
        while True:
            for message_id, payload_data in self._read_payloads(
                group_name, consumer_name, running_stream_name
            ):
                # str() guards against non-string values, which html.escape rejects.
                progress = html.escape(str(payload_data.get("progress", "")))
                text = html.escape(str(payload_data.get("message", "")))
                history.append(
                    {
                        "role": "assistant",
                        "content": f'<pre class="running-message">{progress}: {text}</pre>',
                    }
                )
                yield history
                self._redis_client().xack(running_stream_name, group_name, message_id)

            finish_flag = False
            for message_id, payload_data in self._read_payloads(
                group_name, consumer_name, stream_name
            ):
                stream_index = self._render_output(history, payload_data, stream_index)
                yield history
                self._redis_client().xack(stream_name, group_name, message_id)
                if payload_data.get("interaction_type") == 1:
                    finish_flag = True
                if payload_data.get("content_status") == ContentStatus.END_ANSWER.value:
                    self._workflow_instance_id = None
                    finish_flag = True

            if finish_flag:
                break
            sleep(0.01)

    def _render_output(self, history, payload_data, stream_index):
        """Apply one output-stream payload to ``history`` in place.

        Images and self-contained text messages are appended as new assistant
        entries. INCOMPLETE text chunks, and the chunk that closes them, are
        accumulated in ``self._incomplete_message``. The accumulated text is
        shown in a single entry at ``stream_index``, so previous assistant
        entries (for example progress notes) are never overwritten.

        Returns:
            The updated ``stream_index``: None once the stream is closed.
        """
        message_item = payload_data["message"]
        if message_item["type"] == MessageType.IMAGE_URL.value:
            history.append(
                {"role": "assistant", "content": {"path": message_item["content"]}}
            )
            return stream_index

        incomplete = (
            payload_data.get("content_status") == ContentStatus.INCOMPLETE.value
        )
        if not incomplete and self._incomplete_message == "":
            # A complete text message that is not part of any stream.
            history.append({"role": "assistant", "content": message_item["content"]})
            return stream_index

        self._incomplete_message += message_item["content"]
        if stream_index is None:
            history.append({"role": "assistant", "content": self._incomplete_message})
            stream_index = len(history) - 1
        else:
            history[stream_index]["content"] = self._incomplete_message
        if not incomplete:
            # Closing chunk: the streamed entry is now final.
            self._incomplete_message = ""
            stream_index = None
        return stream_index

    def processor_bot(self, history: list):
        """Report progress of the processor workflow until it finishes.

        Yields ``history`` with a "processing..." note, then polls the
        workflow status. Once the status is terminal, it appends "completed",
        yields again and clears the workflow id.
        """
        history.append({"role": "assistant", "content": "processing..."})
        yield history
        while True:
            status = self._processor.get_workflow(
                workflow_id=self._workflow_instance_id
            ).status
            if status in terminal_status:
                history.append({"role": "assistant", "content": "completed"})
                yield history
                self._workflow_instance_id = None
                break
            # NOTE(review): this issues a status request roughly every 10 ms;
            # consider a longer interval if the conductor server is remote.
            sleep(0.01)

    def _read_payloads(self, group_name: str, consumer_name: str, stream_name: str):
        """Read new entries from ``stream_name`` and decode their payloads.

        Entries with a missing or undecodable payload can never be processed.
        They are acknowledged right away, so they do not linger in the
        group's pending list, and are then dropped.

        Returns:
            A list of ``(message_id, payload_dict)`` for the valid entries.
        """
        payloads = []
        for _stream, message_list in self._get_redis_stream_message(
            group_name, consumer_name, stream_name
        ):
            for message_id, message in message_list:
                payload_data = self._get_message_payload(message)
                if payload_data is None:
                    self._redis_client().xack(stream_name, group_name, message_id)
                    continue
                payloads.append((message_id, payload_data))
        return payloads

    def _get_redis_stream_message(
        self, group_name: str, consumer_name: str, stream_name: str
    ):
        """Read at most one new entry from ``stream_name`` as ``consumer_name``.

        Returns:
            The ``xreadgroup`` result, with each entry's field names and values
            decoded from UTF-8 bytes to ``str``.
        """
        messages = self._redis_client().xreadgroup(
            group_name, consumer_name, {stream_name: ">"}, count=1
        )
        return [
            (
                stream,
                [
                    (
                        message_id,
                        {
                            k.decode("utf-8"): v.decode("utf-8")
                            for k, v in message.items()
                        },
                    )
                    for message_id, message in message_list
                ],
            )
            for stream, message_list in messages
        ]

    def _check_redis_stream_exist(self, stream_name: str, group_name: str):
        """Create ``group_name`` on ``stream_name``, creating the stream if needed.

        Errors are logged at debug level and ignored. They are expected when
        the group already exists.
        """
        try:
            self._redis_client().xgroup_create(
                stream_name, group_name, id="0", mkstream=True
            )
        except Exception as e:
            logging.debug(f"Consumer group may already exist: {e}")

    def _get_message_payload(self, message: dict):
        """Decode the JSON object stored in ``message["payload"]``.

        Returns:
            The decoded dict, or None (after logging an error) when the payload
            is missing, is not valid JSON, or is not a JSON object.
        """
        logging.info(f"Received running message: {message}")
        payload = message.get("payload")
        if not payload:
            logging.error("Payload is empty")
            return None
        try:
            payload_data = json.loads(payload)
        except json.JSONDecodeError as e:
            logging.error(f"Payload is not a valid JSON: {e}")
            return None
        if not isinstance(payload_data, dict):
            # Callers access fields with .get(); reject arrays and scalars here.
            logging.error(f"Payload is not a JSON object: {payload_data!r}")
            return None
        return payload_data
|
|