# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch


class PromptExtendAnnotator:
    """Annotator that rewrites a text prompt through a prompt expander.

    The expander backend is chosen by ``cfg['MODE']``:

    * ``"dashscope"``  -> ``DashScopePromptExpander``
    * ``"local_qwen"`` -> ``QwenPromptExpander`` (the default)

    Both classes come from ``wan.utils.prompt_extend``.
    """

    def __init__(self, cfg, device=None):
        """Read the configuration and build the prompt expander.

        Args:
            cfg: Mapping-like object exposing ``get``. Keys read here, with
                their defaults: ``MODE`` ("local_qwen"), ``MODEL_NAME``
                ("Qwen2.5_3B"), ``IS_VL`` (False), ``SYSTEM_PROMPT`` (None).
            device: Target device. ``None`` selects "cuda" when
                ``torch.cuda.is_available()`` and "cpu" otherwise. Values
                that are not already a ``torch.device`` (e.g. "cuda:0") are
                converted with ``torch.device(...)``.

        Raises:
            NotImplementedError: If ``MODE`` is neither "dashscope" nor
                "local_qwen".
        """
        # Imported here rather than at module top, so importing this module
        # does not by itself import ``wan``.
        from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander

        self.mode = cfg.get('MODE', "local_qwen")
        self.model_name = cfg.get('MODEL_NAME', "Qwen2.5_3B")
        self.is_vl = cfg.get('IS_VL', False)
        self.system_prompt = cfg.get('SYSTEM_PROMPT', None)

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif not isinstance(device, torch.device):
            # ``.type`` / ``.index`` below only exist on torch.device objects,
            # so normalise strings and other device specifiers first.
            device = torch.device(device)
        self.device = device

        # ``index`` is None for a CUDA device created without an explicit
        # ordinal; in that case (and for non-CUDA devices) fall back to 0.
        self.device_id = self.device.index if self.device.type == 'cuda' else None
        rank = self.device_id if self.device_id is not None else 0

        if self.mode == "dashscope":
            self.prompt_expander = DashScopePromptExpander(
                model_name=self.model_name, is_vl=self.is_vl)
        elif self.mode == "local_qwen":
            self.prompt_expander = QwenPromptExpander(
                model_name=self.model_name,
                is_vl=self.is_vl,
                device=rank)
        else:
            raise NotImplementedError(f"Unsupport prompt_extend_method: {self.mode}")

    def forward(self, prompt, system_prompt=None, seed=-1):
        """Expand ``prompt`` and return the resulting text.

        Args:
            prompt: Prompt text handed to the expander.
            system_prompt: Per-call system prompt. When ``None``, the
                ``SYSTEM_PROMPT`` value read from the configuration is used.
            seed: Seed forwarded unchanged to the expander.

        Returns:
            ``output.prompt`` from the expander when its ``status`` is truthy;
            otherwise the original ``prompt`` (the failure message is
            printed).
        """
        if system_prompt is None:
            system_prompt = self.system_prompt
        output = self.prompt_expander(prompt, system_prompt=system_prompt, seed=seed)
        if not output.status:
            # Best effort: report the failure and keep the caller's prompt.
            print(f"Extending prompt failed: {output.message}")
            return prompt
        return output.prompt