# LRPG / create_world / creator_gradio.py
# Author: pizb
# Commit message: "Upload folder using huggingface_hub"
# Commit: 0b75c79 (verified)
from .generator import generate_chain
from .utils import load_yaml, load_txt, save_json
from .prompt import create_custom_world_prompt, create_scenario_prompt, create_storyline_prompt
import ast
from langchain_core.output_parsers import StrOutputParser
def create_custom_world(topic, world_story, save=False):
    """Build a world summary from a user-chosen topic and world description.

    The user picks ``topic`` and writes ``world_story``; the LLM chain turns
    them into a description of the game's overall setting and rules.

    Example inputs (translated from the original Korean notes):
        topic: Harry Potter
        world_story: The Harry Potter world is a world of wizards. People fly
            around on broomsticks, and there are wondrous magical creatures
            such as griffins and phoenixes. Fight against the forces of
            darkness.

    The prompt template is ``create_custom_world_prompt`` from ``.prompt``
    (the original notes refer to it as ``config['create_custom_world_prompt']``
    in ``prompt.yaml``).

    Args:
        topic: Name/theme of the world; passed to the prompt and used as the
            prefix of the saved filename.
        world_story: Free-form world description, passed to the prompt as
            ``context``.
        save: If truthy, also write ``{'topic': ..., 'world_summary': ...}``
            to ``<topic>_world.json`` via ``save_json``.

    Returns:
        A ``(topic, world_summary)`` tuple, where ``world_summary`` is whatever
        ``generate_chain`` returns with its default parser.
    """
    prompt_variable = {
        'topic': topic,
        'context': world_story,
        # "Korean" -- presumably the output language requested by the template.
        'language': '한국어',
    }
    world_summary = generate_chain(create_custom_world_prompt, prompt_variable)
    if save:
        # NOTE(review): `topic` is used verbatim in the filename; if it comes
        # from end-user input, path separators in it could escape the cwd.
        save_json(f'{topic}_world.json', {'topic': topic, 'world_summary': world_summary})
    return topic, world_summary
def create_scenario(topic, context, output_count=5, save=False, max_retries=10):
    """Generate ``output_count`` scenarios from a world summary via the LLM.

    Based on the given world summary (``context``), the chain is asked for
    ``output_count`` scenarios. Its raw text reply, obtained through a
    ``StrOutputParser``, is parsed as a Python literal with
    ``ast.literal_eval``. A failed call or an unparseable reply is retried.

    The prompt template is ``create_scenario_prompt`` from ``.prompt`` (the
    original notes refer to it as ``config['create_scenario_prompt']`` in
    ``prompt.yaml``).

    Args:
        topic: Name of the world; passed to the prompt and used as the prefix
            of the saved filename.
        context: World summary the scenarios are based on.
        output_count: Number of scenarios requested from the model.
        save: If truthy, write the prompt variables plus the parsed result
            (under ``'scenario'``) to ``<topic>_scenario.json`` via
            ``save_json``.
        max_retries: Maximum number of generation attempts. ``None`` retries
            without limit (the historical behavior).

    Returns:
        The object parsed from the model's reply. Its shape (presumably a list
        of scenarios) is not validated here.

    Raises:
        RuntimeError: If no attempt succeeded within ``max_retries`` attempts;
            the last underlying error is chained as ``__cause__``.
    """
    prompt_variable = {
        'topic': topic,
        'output_count': output_count,
        'context': context,
        # "Korean" -- presumably the output language requested by the template.
        'language': '한국어',
    }

    # The previous version made one extra, discarded generate_chain call here
    # before the loop; it only cost an LLM round-trip and has been removed.
    attempt = 0
    while True:
        attempt += 1
        try:
            raw_text = generate_chain(create_scenario_prompt,
                                      prompt_variable,
                                      parser=StrOutputParser())
            scenario = ast.literal_eval(raw_text)
        except Exception as err:
            # A failed chain call or a reply that is not a valid Python literal
            # is usually transient with an LLM, so try again -- but not forever,
            # and never swallow KeyboardInterrupt/SystemExit.
            if max_retries is not None and attempt >= max_retries:
                raise RuntimeError(
                    f'create_scenario: no parseable reply after {attempt} attempts'
                ) from err
        else:
            break

    # Saving stays outside the retry loop so an I/O error surfaces instead of
    # silently triggering another generation.
    if save:
        prompt_variable['scenario'] = scenario
        save_json(str(topic) + '_scenario.json', prompt_variable)
    return scenario
def create_storyline(topic, context, output_count=5, save=False, max_retries=10):
    """Expand a scenario into a detailed game storyline via the LLM.

    Based on the given scenario (``context``), the chain is asked for a
    detailed storyline as a Python list with ``output_count`` elements. The
    raw text reply, obtained through a ``StrOutputParser``, is parsed with
    ``ast.literal_eval``. A reply is accepted only if it is non-empty and every
    element is a dict with non-empty ``'title'`` and ``'story'`` values.
    Otherwise, or if the call fails, the request is retried.

    The prompt template is ``create_storyline_prompt`` from ``.prompt`` (the
    original notes refer to it as ``config['create_storyline_prompt']`` in
    ``prompt.yaml``).

    Args:
        topic: Name of the world; passed to the prompt and used as the prefix
            of the saved filename.
        context: The scenario to expand.
        output_count: Number of storyline elements requested from the model.
        save: If truthy, write the prompt variables plus ``'scenario'`` (the
            context) and ``'story_line'`` (the parsed list) to
            ``<topic>story_line.json`` via ``save_json``.
        max_retries: Maximum number of generation attempts. ``None`` retries
            without limit (the historical behavior).

    Returns:
        The parsed storyline sequence.

    Raises:
        RuntimeError: If no valid reply was obtained within ``max_retries``
            attempts; the last underlying error is chained as ``__cause__``.
    """
    prompt_variable = {
        'topic': topic,
        'output_count': output_count,
        'context': context,
        # "Korean" -- presumably the output language requested by the template.
        'language': '한국어',
    }

    attempt = 0
    while True:
        attempt += 1
        try:
            raw_text = generate_chain(create_storyline_prompt,
                                      prompt_variable,
                                      parser=StrOutputParser())
            storyline_list = ast.literal_eval(raw_text)
            # Previously only element 0 was checked, so a malformed later entry
            # slipped through; validate every entry instead.
            if not storyline_list or not all(
                isinstance(entry, dict) and entry.get('title') and entry.get('story')
                for entry in storyline_list
            ):
                raise ValueError("reply lacks non-empty 'title'/'story' in every element")
        except Exception as err:
            # Chain failure, unparseable reply, or failed validation: retry a
            # bounded number of times without swallowing KeyboardInterrupt.
            if max_retries is not None and attempt >= max_retries:
                raise RuntimeError(
                    f'create_storyline: no valid reply after {attempt} attempts'
                ) from err
        else:
            break

    # Saving stays outside the retry loop so an I/O error surfaces instead of
    # silently triggering another generation.
    if save:
        prompt_variable['scenario'] = context
        prompt_variable['story_line'] = storyline_list
        # NOTE(review): unlike '_world.json' / '_scenario.json' there is no '_'
        # separator here; the name is kept unchanged in case readers rely on it.
        save_json(str(topic) + 'story_line.json', prompt_variable)
    return storyline_list