|
|
|
|
|
|
|
|
|
|
|
|
|
"""
Main model for using MAGNeT. This will combine all the required components
and provide easy access to the generation API.
"""
|
import typing as tp |
|
import torch |
|
|
|
from .genmodel import BaseGenModel |
|
from .loaders import load_compression_model, load_lm_model_magnet |
|
|
|
|
|
class MAGNeT(BaseGenModel):
    """MAGNeT main model with convenient generation API.

    Args:
        See MusicGen class.
    """

    def __init__(self, **kwargs):
        """Initialize the model and install the default generation parameters.

        Args:
            **kwargs: Forwarded unchanged to ``BaseGenModel.__init__``.
        """
        super().__init__(**kwargs)

        # The duration is read from the language model's dataset config
        # rather than being passed in by the caller.
        self.duration = self.lm.cfg.dataset.segment_duration
        self.set_generation_params()

    @staticmethod
    def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=None) -> 'MAGNeT':
        """Return pretrained model, we provide six models:
        - facebook/magnet-small-10secs (300M), text to music, 10-second audio samples.
          # see: https://huggingface.co/facebook/magnet-small-10secs
        - facebook/magnet-medium-10secs (1.5B), text to music, 10-second audio samples.
          # see: https://huggingface.co/facebook/magnet-medium-10secs
        - facebook/magnet-small-30secs (300M), text to music, 30-second audio samples.
          # see: https://huggingface.co/facebook/magnet-small-30secs
        - facebook/magnet-medium-30secs (1.5B), text to music, 30-second audio samples.
          # see: https://huggingface.co/facebook/magnet-medium-30secs
        - facebook/audio-magnet-small (300M), text to sound-effect (10-second samples).
          # see: https://huggingface.co/facebook/audio-magnet-small
        - facebook/audio-magnet-medium (1.5B), text to sound-effect (10-second samples).
          # see: https://huggingface.co/facebook/audio-magnet-medium

        Args:
            name (str): Identifier passed to both the compression-model and the
                language-model loaders.
            device (optional): Device passed to both loaders. When None, 'cuda' is
                used if ``torch.cuda.device_count()`` is non-zero, otherwise 'cpu'.

        Returns:
            MAGNeT: A model wrapping the loaded compression model and language model.
        """
        if device is None:
            device = 'cuda' if torch.cuda.device_count() else 'cpu'

        compression_model = load_compression_model(name, device=device)
        # The LM loader is given the compression model's frame rate, truncated to an int.
        lm = load_lm_model_magnet(
            name,
            compression_model_frame_rate=int(compression_model.frame_rate),
            device=device,
        )

        conditioners = lm.condition_provider.conditioners
        if 'self_wav' in conditioners:
            # NOTE(review): presumably makes the waveform conditioner match the
            # generated length at eval time -- confirm against the conditioner class.
            conditioners['self_wav'].match_len_on_eval = True

        return MAGNeT(name=name, compression_model=compression_model, lm=lm)

    def set_generation_params(self, use_sampling: bool = True, top_k: int = 0,
                              top_p: float = 0.9, temperature: float = 3.0,
                              max_cfg_coef: float = 10.0, min_cfg_coef: float = 1.0,
                              decoding_steps: tp.Optional[tp.List[int]] = None,
                              span_arrangement: str = 'nonoverlap') -> None:
        """Set the generation parameters for MAGNeT.

        The values are stored in ``self.generation_params``, replacing any
        previously stored parameters.

        Args:
            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
            top_k (int, optional): top_k used for sampling. Defaults to 0.
            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.
            temperature (float, optional): Initial softmax temperature parameter
                (stored under the ``'temp'`` key). Defaults to 3.0.
            max_cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 10.0.
            min_cfg_coef (float, optional): End coefficient of classifier free guidance annealing. Defaults to 1.0.
            decoding_steps (list of n_q ints, optional): The number of iterative decoding steps,
                for each of the n_q RVQ codebooks. Defaults to [20, 10, 10, 10] when None.
            span_arrangement (str, optional): Use either non-overlapping spans ('nonoverlap')
                or overlapping spans ('stride1') in the masking scheme. Defaults to 'nonoverlap'.
        """
        # A None sentinel avoids a mutable list in the signature; the effective
        # default values are unchanged.
        if decoding_steps is None:
            decoding_steps = [20, 10, 10, 10]

        self.generation_params = {
            'use_sampling': use_sampling,
            'temp': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'max_cfg_coef': max_cfg_coef,
            'min_cfg_coef': min_cfg_coef,
            # Coerce each entry to int and store a fresh list, so the caller's
            # list is never aliased.
            'decoding_steps': [int(s) for s in decoding_steps],
            'span_arrangement': span_arrangement,
        }
|
|