# LRPG / create_world / creator.py
# Uploaded by pizb ("Upload folder using huggingface_hub"), commit 0b75c79 (verified).
from .generator import generate_chain
from .utils import load_yaml, load_txt, save_json
import ast
from langchain_core.output_parsers import StrOutputParser
def create_custom_world(prompt, language='한국어', save=False):
    """Ask the user for a world topic and description, then build a world summary.

    The prompt template is normally ``config['create_custom_world_prompt']``
    loaded from ``prompt.yaml``. The user chooses the topic and types the
    world story on stdin; the chain turns that into the overall game world
    setting and the game rules.

    Example user input:
        topic: Harry Potter
        world_story: The Harry Potter world is a world of wizards. People
            fly around on broomsticks, and there are wondrous magical
            creatures such as griffins and phoenixes. Fight against the
            forces of darkness.

    Args:
        prompt: Prompt template forwarded to ``generate_chain``.
        language: Value for the ``language`` prompt variable
            (``'한국어'`` = Korean by default).
        save: When truthy, also write ``{topic}_world.json`` containing the
            topic and the generated summary.

    Returns:
        ``(topic, world_summary)`` where ``world_summary`` is whatever
        ``generate_chain`` returns with its default parser.
    """
    # User-facing prompts are in Korean:
    # "Tell me the world topic, e.g. wizard world, space war: "
    topic = input("세계관 주제를 알려 주세요 ex)마법사 세계, 우주 전쟁: ")
    # "Describe the world in detail and introduce its rules: "
    world_story = input("구체적인 세계관 설명과 룰을 소개하세요: ")

    prompt_variable = {
        'topic': topic,
        'context': world_story,
        'language': language,
    }
    world_summary = generate_chain(prompt, prompt_variable)

    if save:
        save_json(topic + '_world.json', {'topic': topic, 'world_summary': world_summary})
    return topic, world_summary
def create_scenario(topic, context, prompt=None, output_count=5, language='한국어',
                    save=False, max_retries=5):
    """Generate ``output_count`` game scenarios from a world summary.

    The prompt template is normally ``config['create_scenario_prompt']``
    from ``prompt.yaml``. The model's raw text output is parsed with
    ``ast.literal_eval`` (which evaluates literals only, never code). If the
    output is not a valid Python literal, the chain is called again, up to
    ``max_retries`` attempts in total.

    Args:
        topic: World topic; also used as the prefix of the saved file name.
        context: World summary the scenarios are based on.
        prompt: Prompt template forwarded to ``generate_chain``. Required.
            It is the third positional parameter so that calls like
            ``create_scenario(topic, summary, config['create_scenario_prompt'],
            output_count=1)`` bind correctly.
        output_count: Number of scenarios requested from the model.
        language: Value for the ``language`` prompt variable
            (``'한국어'`` = Korean by default).
        save: When truthy, write ``{topic}_scenario.json`` containing the
            prompt variables plus the parsed ``scenario``.
        max_retries: Maximum number of generation attempts; ``None`` retries
            indefinitely.

    Returns:
        The parsed literal. It is presumably a list of scenarios; the shape of
        the model output is not validated here.

    Raises:
        ValueError: If ``prompt`` is missing, or no attempt produced a
            parseable literal within ``max_retries`` attempts.
        Any exception raised by ``generate_chain`` itself is propagated
        rather than silently retried.
    """
    if prompt is None:
        raise ValueError("create_scenario() requires a prompt template, "
                         "e.g. config['create_scenario_prompt']")

    prompt_variable = {
        'topic': topic,
        'output_count': output_count,
        'context': context,
        'language': language,
    }

    attempt = 0
    last_error = None
    while max_retries is None or attempt < max_retries:
        attempt += 1
        raw_output = generate_chain(prompt, prompt_variable, parser=StrOutputParser())
        try:
            scenario = ast.literal_eval(raw_output)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as err:
            # Malformed model output: remember why and ask the model again.
            last_error = err
            continue
        break
    else:
        # Only reached when the retry budget ran out without a successful parse.
        raise ValueError(
            f"could not parse scenario output after {attempt} attempt(s)"
        ) from last_error

    # Saving happens outside the retry loop so an I/O error is not mistaken
    # for bad model output (which would trigger another generation).
    if save:
        prompt_variable['scenario'] = scenario
        save_json(str(topic) + '_scenario.json', prompt_variable)
    return scenario
def create_storyline(topic, context, prompt=None, output_count=5, language='한국어',
                     save=False, max_retries=5):
    """Generate a detailed game storyline from a scenario.

    The prompt template is normally ``config['create_storyline_prompt']``
    from ``prompt.yaml``. The model is asked for a Python list with
    ``output_count`` elements. Its raw text output is parsed with
    ``ast.literal_eval`` (which evaluates literals only, never code). If the
    output is not a valid Python literal, the chain is called again, up to
    ``max_retries`` attempts in total.

    Args:
        topic: World topic; also used as the prefix of the saved file name.
        context: The scenario the storyline is built from.
        prompt: Prompt template forwarded to ``generate_chain``. Required.
            It is the third positional parameter so that calls like
            ``create_storyline(topic, scenario, config['create_storyline_prompt'])``
            bind correctly.
        output_count: Number of storyline elements requested from the model.
        language: Value for the ``language`` prompt variable
            (``'한국어'`` = Korean by default).
        save: When truthy, write ``{topic}story_line.json`` containing the
            prompt variables plus ``scenario`` and ``story_line``.
        max_retries: Maximum number of generation attempts; ``None`` retries
            indefinitely.

    Returns:
        The parsed literal. It is presumably a list; its length and contents
        are not validated here.

    Raises:
        ValueError: If ``prompt`` is missing, or no attempt produced a
            parseable literal within ``max_retries`` attempts.
        Any exception raised by ``generate_chain`` itself is propagated
        rather than silently retried.
    """
    if prompt is None:
        raise ValueError("create_storyline() requires a prompt template, "
                         "e.g. config['create_storyline_prompt']")

    prompt_variable = {
        'topic': topic,
        'output_count': output_count,
        'context': context,
        'language': language,
    }

    attempt = 0
    last_error = None
    while max_retries is None or attempt < max_retries:
        attempt += 1
        raw_output = generate_chain(prompt, prompt_variable, parser=StrOutputParser())
        try:
            storyline_list = ast.literal_eval(raw_output)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as err:
            # Malformed model output: remember why and ask the model again.
            last_error = err
            continue
        break
    else:
        # Only reached when the retry budget ran out without a successful parse.
        raise ValueError(
            f"could not parse storyline output after {attempt} attempt(s)"
        ) from last_error

    # Saving happens outside the retry loop so an I/O error is not mistaken
    # for bad model output (which would trigger another generation).
    if save:
        prompt_variable['scenario'] = context
        prompt_variable['story_line'] = storyline_list
        # NOTE(review): the file name has no '_' separator (unlike
        # '_world.json' / '_scenario.json'); kept as-is in case other code
        # loads '{topic}story_line.json'.
        save_json(str(topic) + 'story_line.json', prompt_variable)
    return storyline_list
if __name__ == '__main__':
    # Demo pipeline: world -> scenario(s) -> storyline for the first scenario.
    config = load_yaml(path='prompt.yaml')
    create_new_world = True
    save = True

    # Either build a fresh world interactively or fall back to the bundled
    # Harry Potter summary.
    if not create_new_world:
        topic = 'harry potter'
        world_summary = load_txt('dummy/world_summary.txt')
    else:
        topic, world_summary = create_custom_world(
            config['create_custom_world_prompt'], save=save)
    print(world_summary)

    scenario = create_scenario(
        topic,
        world_summary,
        config['create_scenario_prompt'],
        output_count=1,
        save=save,
    )
    print(scenario)

    storyline = create_storyline(
        topic,
        scenario[0],
        config['create_storyline_prompt'],
        save=save,
    )
    print(storyline)