import os
from typing import Optional

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.models.mistral import MistralConfig, MistralForCausalLM

from .ar_warp.ar_warper import GARDiffusionLM
from .cdcd.ar_warper import CDCDGARRobertaForDiffusionLM
from .cdcd.positionwise_warper_model import (
    PositionwiseCDCDRobertaConfig,
    PositionwiseCDCDRobertaForDiffusionLM,
)
from .cdcd.tokenwise_warper_model import TokenwiseCDCDRobertaForDiffusionLM
from .cdcd.warper_model import CDCDRobertaConfig, CDCDRobertaForDiffusionLM
from .confidence_tracker.confidence_tracker_model import (
    ConfidenceTrackerRobertaDiffusionLM,
)
from .llama.configuration_llama import LlamaDiffusionConfig
from .llama.modeling_llama import LlamaForDiffusionLM, LlamaForSeq2SeqLM
from .mistral.configuration_mistral import (
    CDCDMistralDiffusionConfig,
    MistralDiffusionConfig,
)
from .mistral.modeling_mistral import (
    CDCDMistralForDiffusionLM,
    MistralForDiffusionLM,
    MistralForSeq2SeqLM,
)
from .mixins.modeling_mixin import CDCDDiffusionModelMixin
from .roberta.configuration_roberta import RobertaDiffusionConfig
from .roberta.modeling_roberta import RobertaForDiffusionLM


def model_config_helper(
    model_name_or_path: str,
    use_model: str = "cdcd",
    is_diffusion: bool = True,
    conditional_generation: Optional[str] = None,
):
    """Resolve a (config class, model class) pair from the checkpoint name and flags."""
    if "llama" in model_name_or_path.lower():
        if conditional_generation == "seq2seq" and not is_diffusion:
            return LlamaDiffusionConfig, LlamaForSeq2SeqLM
        return LlamaDiffusionConfig, LlamaForDiffusionLM
    if "mistral" in model_name_or_path.lower():
        if conditional_generation == "seq2seq" and not is_diffusion:
            return MistralDiffusionConfig, MistralForSeq2SeqLM
        if conditional_generation is None and not is_diffusion:
            return MistralConfig, MistralForCausalLM
        if use_model == "cdcd":
            return CDCDMistralDiffusionConfig, CDCDMistralForDiffusionLM
        return MistralDiffusionConfig, MistralForDiffusionLM
    if "roberta" in model_name_or_path:
        if use_model == "cdcd":
            return CDCDRobertaConfig, CDCDRobertaForDiffusionLM
        if use_model == "tokenwise_cdcd":
            return CDCDRobertaConfig, TokenwiseCDCDRobertaForDiffusionLM
        if use_model == "positionwise_cdcd":
            return PositionwiseCDCDRobertaConfig, PositionwiseCDCDRobertaForDiffusionLM
        if use_model == "confidence":
            return RobertaDiffusionConfig, ConfidenceTrackerRobertaDiffusionLM
        if use_model == "cdcdgar":
            return CDCDRobertaConfig, CDCDGARRobertaForDiffusionLM
        print(
            f"Using RobertaDiffusionConfig and RobertaForDiffusionLM for {model_name_or_path}"
        )
        return RobertaDiffusionConfig, RobertaForDiffusionLM
    # default to mistral
    if use_model == "cdcd":
        print(
            f"Using CDCDMistralDiffusionConfig and CDCDMistralForDiffusionLM for {model_name_or_path}"
        )
        return CDCDMistralDiffusionConfig, CDCDMistralForDiffusionLM
    print(
        f"Using MistralDiffusionConfig and MistralForDiffusionLM for {model_name_or_path}"
    )
    return MistralDiffusionConfig, MistralForDiffusionLM
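
# A minimal dispatch sketch (the checkpoint name below is illustrative; resolution
# is purely substring matching on the path plus the use_model/is_diffusion flags):
#
#   cfg_cls, model_cls = model_config_helper(
#       "mistralai/Mistral-7B-v0.1", use_model="cdcd", is_diffusion=True
#   )
#   assert cfg_cls is CDCDMistralDiffusionConfig
#   assert model_cls is CDCDMistralForDiffusionLM
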
def is_cdcd_check(model):
    """True if the model is any of the CDCD (learned noise schedule) variants."""
    return isinstance(
        model,
        (
            CDCDDiffusionModelMixin,
            CDCDMistralForDiffusionLM,
            CDCDRobertaForDiffusionLM,
            TokenwiseCDCDRobertaForDiffusionLM,
            PositionwiseCDCDRobertaForDiffusionLM,
            GARDiffusionLM,
            CDCDGARRobertaForDiffusionLM,
        ),
    )


def is_tokenwise_cdcd_check(model):
    """True if the model warps the noise schedule per token (or per position)."""
    return isinstance(
        model,
        (TokenwiseCDCDRobertaForDiffusionLM, PositionwiseCDCDRobertaForDiffusionLM),
    )


def freeze(module):
    """Disable gradients for every parameter in `module`."""
    for param in module.parameters():
        param.requires_grad = False


def get_torch_dtype(training_args):
    torch_dtype = torch.float32
    if training_args.bf16:
        torch_dtype = torch.bfloat16
    elif training_args.fp16:
        torch_dtype = torch.float16
    return torch_dtype


def load_model(model_args, data_args, training_args, diffusion_args, logger):
    """Build the tokenizer and diffusion model described by the parsed argument dataclasses."""
    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    cfg_cls, model_cls = model_config_helper(
        model_args.model_name_or_path,
        use_model=model_args.use_model,
        is_diffusion=diffusion_args.num_diffusion_steps > 0,
        conditional_generation=data_args.conditional_generation,
    )
    config = cfg_cls.from_pretrained(
        model_args.model_name_or_path,
        self_condition=diffusion_args.self_condition,
        self_condition_zeros_after_softmax=diffusion_args.self_condition_zeros_after_softmax,
        deepmind_conditional=diffusion_args.deepmind_conditional,
        classifier_free_simplex_inputs=diffusion_args.classifier_free_simplex_inputs,
        classifier_free_uncond_input=diffusion_args.classifier_free_uncond_input,
        self_condition_mlp_projection=diffusion_args.self_condition_mlp_projection,
        self_condition_mix_before_weights=diffusion_args.self_condition_mix_before_weights,
        self_condition_mix_logits_before_weights=diffusion_args.self_condition_mix_logits_before_weights,
        empty_token_be_mask=diffusion_args.empty_token_be_mask,
        is_causal=model_args.is_causal,
        mask_padding_in_loss=training_args.mask_padding_in_loss,
        padding_side=model_args.tokenizer_padding_side,
        token=os.environ.get("HF_TOKEN", None),
        **config_kwargs,
    )

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "padding_side": model_args.tokenizer_padding_side,
        "use_auth_token": True if model_args.use_auth_token else None,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            token=os.environ.get("HF_TOKEN", None),
            **tokenizer_kwargs,
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            token=os.environ.get("HF_TOKEN", None),
            **tokenizer_kwargs,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    try:
        tokenizer.add_eos_token = True
    except AttributeError:  # roberta tokenizers do not have this attribute
        pass

    if model_args.model_name_or_path and not model_args.from_scratch:
        model = model_cls.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
            torch_dtype=get_torch_dtype(training_args),
            token=os.environ.get("HF_TOKEN", None),
            attn_implementation="flash_attention_2"
            if model_args.use_flash_attention2
            else "eager",
        ).to("cuda")
        if model_args.freeze_embedding:
            # freeze the embedding parameters; assigning `requires_grad` on the
            # module itself would not touch its parameters
            freeze(model.get_input_embeddings())
        if model_args.freeze_model:
            freeze(model)
    else:
        logger.warning("Training new model from scratch")
        model = model_cls._from_config(config)
        model.init_weights()

    if tokenizer.pad_token_id is None:  # explicit None check: id 0 is a valid token
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    vocab_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > vocab_size:
        model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id

    # if peft, apply it here
    if model_args.use_lora:
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=model_args.lora_rank,
            lora_alpha=model_args.lora_alpha,
            lora_dropout=model_args.lora_dropout,
        )
        # we just peft the internal model.
        # a little hacky: remove the task type wrapper class
        # TODO: does this cook anything?
        model.model = get_peft_model(model.model, peft_config).base_model

    # apply liger monkey patching
    if model_args.use_liger_kernel:
        from liger_kernel.transformers import apply_liger_kernel_to_mistral

        apply_liger_kernel_to_mistral()

    return tokenizer, model
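
# A quick sanity check after load_model, e.g. to confirm LoRA/freezing took effect
# (a sketch; it assumes the usual HF argument dataclasses are already parsed):
#
#   tokenizer, model = load_model(model_args, data_args, training_args, diffusion_args, logger)
#   trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
#   total = sum(p.numel() for p in model.parameters())
#   print(f"trainable params: {trainable} / {total}")
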
def load_classifier(classifier_model_name_or_path: str):
    """Load a frozen sequence-classification model in eval mode."""
    tokenizer = AutoTokenizer.from_pretrained(classifier_model_name_or_path)
    model = AutoModelForSequenceClassification.from_pretrained(
        classifier_model_name_or_path,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    ).eval()
    model.gradient_checkpointing_enable()
    # NOTE: for quick testing (reduce vram req)
    # model.model.layers = torch.nn.ModuleList([model.model.layers[0]])
    freeze(model)
    # from liger_kernel.transformers import apply_liger_kernel_to_mistral
    # apply_liger_kernel_to_mistral()
    return tokenizer, model


def check_tokenizer_equal(tokenizer1, tokenizer2):
    """Assert that two tokenizers are interchangeable: same class, vocab, special tokens, and decoding."""
    # check class
    assert tokenizer1.__class__ is tokenizer2.__class__
    # check vocab size
    assert tokenizer1.vocab_size == tokenizer2.vocab_size
    # check special tokens size
    assert len(tokenizer1.special_tokens_map) == len(tokenizer2.special_tokens_map)
    # check special tokens
    for special_token in ("bos", "eos", "unk", "pad"):
        attr = f"{special_token}_token_id"
        assert getattr(tokenizer1, attr) == getattr(tokenizer2, attr)
    # full decoding check
    for i in range(tokenizer1.vocab_size + len(tokenizer1.special_tokens_map)):
        decoded1 = tokenizer1.decode([i])
        decoded2 = tokenizer2.decode([i])
        assert decoded1 == decoded2
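
# Example pairing of the diffusion tokenizer with a frozen classifier (a sketch;
# the checkpoint path is hypothetical). check_tokenizer_equal guards against
# silently mixing vocabularies between the two models:
#
#   clf_tokenizer, clf_model = load_classifier("org/classifier-checkpoint")
#   check_tokenizer_equal(tokenizer, clf_tokenizer)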