# Source: Hugging Face Space "hehe" — app.py (commit b36c7d9, verified upload)
# app.py (Slight modification)
from config import Flask, pipeline_dict, Response, convHandler, get_user_id
from application.chat_inference import ChatInference
from flask import render_template, request, make_response
from application.utils.image_captioning import ImageCaptioning
from application.utils.text_to_speech import generate_tts # Import
# Templates and static assets live under application/ instead of Flask's
# default top-level folders.
app = Flask(__name__, template_folder='application/templates', static_folder='application/static')
chat_inference = ChatInference()  # shared inference client used by /completions
image_captioning = ImageCaptioning()  # NOTE(review): instantiated but never used in this file — confirm other modules rely on it
@app.route('/')
def home():
    """Serve the landing page and tag the browser with a user_id cookie."""
    uid = get_user_id()
    resp = make_response(render_template('index.html'))
    # Persist the id so later API calls can be attributed to this user.
    resp.set_cookie('user_id', uid)
    return resp
@app.route('/completions', methods=['POST'])
def completeions():  # NOTE(review): name is a typo of "completions"; kept so the Flask endpoint name (url_for) stays stable
    """Forward a chat-completion request to the model named in the JSON body.

    The payload must contain a "model" key matching one of the configured
    models; it is augmented with that model's base_url and type before being
    handed to ChatInference.chat.

    Returns:
        The chat_inference.chat response (possibly a stream), or
        ("Invalid JSON body", 400) / ("Model Not Found", 404) on bad input.
    """
    user_id = get_user_id()
    # silent=True: return None instead of raising when the body is missing
    # or is not valid JSON, so we can answer with a clean 400.
    data = request.get_json(silent=True)
    if not data:
        return "Invalid JSON body", 400
    models = pipeline_dict['api']['models']
    if data.get('model') not in models:
        return "Model Not Found", 404
    model_info = models[data['model']]
    data.update({
        "base_url": model_info['api_url'],
        "type": model_info['type'],
    })
    return chat_inference.chat(data=data, handle_stream=pipeline_dict['handle_stream'], user=user_id)
@app.route('/convs')
def get_conv():
    """Return the conversation list for the requesting user."""
    return convHandler.get_conv(get_user_id())
@app.route('/create', methods=['POST'])
def create_conv():
    """Create a new conversation, optionally seeded with a system prompt."""
    uid = get_user_id()
    prompt = request.json.get('system_prompt', '')
    return convHandler.create_conv(ip=uid, sysPrompt=prompt)
@app.route('/fetch', methods=['POST'])
def fetch():
    """Fetch a single conversation by its convId for the requesting user."""
    uid = get_user_id()
    # NOTE(review): a missing "convId" key passes None through — presumably
    # handled inside convHandler.fetch_conv; verify.
    return convHandler.fetch_conv(convId=request.json.get('convId'), ip=uid)
@app.route('/models')
def models():
    """List the names of every configured model."""
    return [*pipeline_dict['api']['models']]
@app.route('/tts')
def tts():
    """Stream WAV audio synthesised from the `text` query parameter."""
    text = request.args.get('text')
    if not text:
        return "No text provided", 400
    # generate_tts returns the audio payload; wrap it so the client receives
    # it with the correct content type.
    return Response(generate_tts(text), mimetype="audio/wav")
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 — presumably the Hugging Face
    # Spaces convention for the container-exposed port; confirm deployment.
    app.run(host="0.0.0.0", port=7860, debug=False)