diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..6bdba7be20bf3cf107c88e5269c6bf7f25b730ec
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,39 @@
+# Image that runs the Streamlit app (app.py); the app port 8501 is exposed.
+FROM misakiminato/cuda-python:cu12.0.0-py3.8.16-devel-ubuntu18.04
+WORKDIR /app
+
+# Install the pinned PyTorch stack before copying any project file: this layer
+# depends on nothing in the build context, so editing requirements.txt or
+# packages.txt no longer forces the large torch download to be repeated.
+RUN pip install --no-cache-dir torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
+
+# Install the system packages named in packages.txt (xargs -r skips apt-get
+# when the list is empty). DEBIAN_FRONTEND is set inline so package
+# configuration cannot prompt during the build, without persisting in the image.
+COPY ./packages.txt /app/packages.txt
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive xargs -r -a /app/packages.txt apt-get install -y && rm -rf /var/lib/apt/lists/*
+
+# Install the app's Python dependencies.
+COPY ./requirements.txt /app/requirements.txt
+RUN pip3 install --no-cache-dir -r /app/requirements.txt
+
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+
+# Switch to the "user" user
+USER user
+
+# Set home to the user's home directory
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+# Set the working directory to the app directory inside the user's home
+WORKDIR $HOME/app
+
+# Copy the current directory contents into the container at $HOME/app setting the owner to the user
+COPY --chown=user . $HOME/app
+
+EXPOSE 8501
+# Exec form so streamlit is started directly instead of through `/bin/sh -c`
+# and therefore receives the stop signal sent by `docker stop`.
+CMD ["streamlit", "run", "app.py", "--server.maxUploadSize", "1024", "--server.enableWebsocketCompression=false", "--server.enableXsrfProtection=false"]
diff --git a/Dockerfile.back b/Dockerfile.back
new file mode 100644
index 0000000000000000000000000000000000000000..66f1ddf79bc65bbab100a3aacd1b8b97eb6ee2d2
--- /dev/null
+++ b/Dockerfile.back
@@ -0,0 +1,48 @@
+# Earlier variant of the app image, built directly on the NVIDIA CUDA base image.
+FROM nvidia/cuda:12.0.0-base-ubuntu20.04
+ARG DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1
+
+WORKDIR /app
+# Install python, git & ffmpeg
+RUN apt-get update && apt-get install --no-install-recommends -y \
+    build-essential \
+    python3.8=3.8.10* \
+    python3-pip \
+    git \
+    ffmpeg \
+    && apt-get clean && rm -rf /var/lib/apt/lists/*
+COPY ./requirements.txt /app/requirements.txt
+COPY ./packages.txt /app/packages.txt
+# --no-cache-dir on every pip step keeps pip's download cache out of the layers.
+RUN pip install --no-cache-dir --upgrade pip
+RUN pip install --no-cache-dir pyproject-toml
+RUN pip install --no-cache-dir torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
+# Install the extra system packages named in packages.txt (skipped when empty).
+RUN apt-get update && xargs -r -a /app/packages.txt apt-get install -y && rm -rf /var/lib/apt/lists/*
+RUN pip3 install --no-cache-dir -r /app/requirements.txt
+RUN pip3 install --no-cache-dir numba==0.56.3
+# Build pyworld from source instead of installing a prebuilt wheel.
+RUN pip install --no-cache-dir --no-binary :all: pyworld
+
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+
+# Switch to the "user" user so the app does not run as root
+USER user
+
+# HOME is not defined at build time unless declared here; without it, $HOME
+# below expands to nothing and the app lands in /app instead of /home/user/app.
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+# Set the working directory to the app directory inside the user's home
+WORKDIR $HOME/app
+
+# Copy the current directory contents into the container at $HOME/app setting the owner to the user
+COPY --chown=user . $HOME/app
+
+EXPOSE 8501
+# Only one CMD takes effect per image, so the earlier `CMD nvidia-smi -l` never
+# ran and has been dropped. Exec form starts streamlit without a `/bin/sh -c` wrapper.
+CMD ["streamlit", "run", "app.py", "--server.maxUploadSize", "1024", "--server.enableWebsocketCompression=false", "--server.enableXsrfProtection=false"]
diff --git a/Dockerfile.latestbackup b/Dockerfile.latestbackup
new file mode 100644
index 0000000000000000000000000000000000000000..af25edbc8c0f0ab8d964b5856767c75495862b1a
--- /dev/null
+++ b/Dockerfile.latestbackup
@@ -0,0 +1,54 @@
+# Later variant of the CUDA-based app image that also installs the Python headers.
+FROM nvidia/cuda:12.0.0-base-ubuntu20.04
+ARG DEBIAN_FRONTEND=noninteractive
+# NOTE(review): the paths below name python3.8 while apt installs python3.9 -
+# confirm which interpreter pip3 actually targets on this base image.
+ENV PYTHONUNBUFFERED=1 \
+    PYTHON_INCLUDE=/usr/include/python3.8 \
+    PYTHON_LIB=/usr/lib/x86_64-linux-gnu/libpython3.8.so
+
+WORKDIR /app
+# Install python (with headers), git & ffmpeg in a single layer
+RUN apt-get update && apt-get install --no-install-recommends -y \
+    build-essential \
+    python3.9 \
+    python3-dev \
+    python3-pip \
+    git \
+    ffmpeg \
+    && apt-get clean && rm -rf /var/lib/apt/lists/*
+# Fail the build early if the interpreter or its config helper is missing.
+RUN which python3 && which python3-config
+COPY ./requirements.txt /app/requirements.txt
+COPY ./packages.txt /app/packages.txt
+RUN pip install --no-cache-dir torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
+# Install the extra system packages named in packages.txt (skipped when empty).
+RUN apt-get update && xargs -r -a /app/packages.txt apt-get install -y && rm -rf /var/lib/apt/lists/*
+RUN pip3 install --no-cache-dir -r /app/requirements.txt
+RUN pip3 install --no-cache-dir numba==0.56.3
+RUN pip install --no-cache-dir --upgrade pip setuptools wheel
+# Build pyworld from source instead of installing a prebuilt wheel.
+RUN pip3 install --no-cache-dir --no-binary :all: pyworld
+RUN pip install --no-cache-dir soundfile
+
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+
+# Switch to the "user" user so the app does not run as root
+USER user
+
+# HOME is not defined at build time unless declared here; without it, $HOME
+# below expands to nothing and the app lands in /app instead of /home/user/app.
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+# Set the working directory to the app directory inside the user's home
+WORKDIR $HOME/app
+
+# Copy the current directory contents into the container at $HOME/app setting the owner to the user
+COPY --chown=user . $HOME/app
+
+EXPOSE 8501
+# Only one CMD takes effect per image, so the earlier `CMD nvidia-smi -l` never
+# ran and has been dropped. Exec form starts streamlit without a `/bin/sh -c` wrapper.
+CMD ["streamlit", "run", "app.py", "--server.maxUploadSize", "1024", "--server.enableWebsocketCompression=false", "--server.enableXsrfProtection=false"]
\ No newline at end of file
diff --git a/LICENSE.md b/LICENSE.md
new file mode 100644
index 0000000000000000000000000000000000000000..0ad25db4bd1d86c452db3f9602ccdbe172438f52
--- /dev/null
+++ b/LICENSE.md
@@ -0,0 +1,661 @@
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc.
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+software and other kinds of works, specifically designed to ensure
+cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+our General Public Licenses are intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+with two steps: (1) assert copyright on the software, and (2) offer
+you this License which gives you legal permission to copy, distribute
+and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+improvements made in alternate versions of the program, if they
+receive widespread use, become available for other developers to
+incorporate. Many developers of free software are heartened and
+encouraged by the resulting cooperation. However, in the case of
+software used on network servers, this result may fail to come about.
+The GNU General Public License permits making a modified version and
+letting the public access it on a server without ever releasing its
+source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ensure that, in such cases, the modified source code becomes available
+to the community. It requires the operator of a network server to
+provide the source code of the modified version running there to the
+users of that server. Therefore, public use of a modified version, on
+a publicly accessible server, gives the public access to the source
+code of the modified version.
+
+ An older license, called the Affero General Public License and
+published by Affero, was designed to accomplish similar goals. This is
+a different license, not a version of the Affero GPL, but Affero has
+released a new version of the Affero GPL which permits relicensing under
+this license.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+Program, your modified version must prominently offer all users
+interacting with it remotely through a computer network (if your version
+supports such interaction) an opportunity to receive the Corresponding
+Source of your version by providing access to the Corresponding Source
+from a network server at no charge, through some standard or customary
+means of facilitating copying of software. This Corresponding Source
+shall include the Corresponding Source for any work covered by version 3
+of the GNU General Public License that is incorporated pursuant to the
+following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the work with which it is combined will remain governed by version
+3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU Affero General Public License from time to time. Such new versions
+will be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU Affero General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU Affero General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU Affero General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year>  <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published
+ by the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source. For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code. There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<https://www.gnu.org/licenses/>.
diff --git a/README.md b/README.md
index 4d797e1dd44058d82277d7d011aa4ef59a5a01b7..65ca11661c2db5e302bf32c36e8a6fe9da7d034c 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,10 @@
---
-title: Diffsvc Test
-emoji: 🔥
-colorFrom: yellow
-colorTo: red
-sdk: gradio
-sdk_version: 3.16.2
-app_file: app.py
+title: DiffSVC Inference
+emoji: 🎙
+colorFrom: red
+colorTo: orange
+sdk: docker
+app_port: 8501
pinned: false
+duplicated_from: DIFF-SVCModel/Inference
---
-
-Check out the configuration reference at https://huggingface.co./docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..d898aab686c0c2271b8d9c8217aa30197eaa6a83
--- /dev/null
+++ b/app.py
@@ -0,0 +1,106 @@
+import gradio as gr
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import json
+import os
+import tempfile
+import shutil
+import requests
+from pathlib import Path
+global ckpt_temp_file
+global audio_temp_file
+global config_temp_file
+###################################################
+from utils.hparams import hparams
+from preprocessing.data_gen_utils import get_pitch_parselmouth,get_pitch_crepe
+import numpy as np
+import matplotlib.pyplot as plt
+import IPython.display as ipd
+import utils
+import librosa
+import torchcrepe
+from infer import *
+import logging
+from infer_tools.infer_tool import *
+import io
+
+spk_dict = {
+ "雷电将军": {"model_name": './models/genshin/raiden.ckpt', "config_name": './models/genshin/config.yaml'}
+}
+
+project_name = "Unnamed"
+model_path = spk_dict['雷电将军']['model_name']
+config_path= spk_dict['雷电将军']['config_name']
+hubert_gpu = False
+svc_model = Svc(project_name, config_path, hubert_gpu, model_path)
+
+def vc_fn(sid, audio_record, audio_upload, tran, pndm_speedup=20):
+ print(sid)
+ if audio_upload is not None:
+ audio_path = audio_upload
+ elif audio_record is not None:
+ audio_path = audio_record
+ else:
+ return "你需要上传wav文件或使用网页内置的录音!", None
+
+ tran = int(tran)
+ pndm_speedup = int(pndm_speedup)
+ print('model loaded')
+ # demoaudio, sr = librosa.load(audio_path)
+ key = tran # 音高调整,支持正负(半音)
+ # 加速倍数
+ pndm_speedup = 20
+ wav_gen='queeeeee.wav'
+
+ # Show the spinner and run the run_clip function inside the 'with' block
+ f0_tst, f0_pred, audio = run_clip(svc_model, file_path=audio_path, key=key, acc=pndm_speedup, use_crepe=True, use_pe=True, thre=0.05,
+ use_gt_mel=False, add_noise_step=500, project_name=project_name, out_path=wav_gen)
+
+ return "Success", (hparams['audio_sample_rate'], audio)
+
+
+# Gradio UI: a Blocks app with two tabs - "Basic" (the conversion form wired to
+# vc_fn) and a usage-notes tab - followed by the server launch.
+app = gr.Blocks()
+with app:
+ with gr.Tabs():
+ with gr.TabItem("Basic"):
+ # Introductory Markdown shown at the top of the conversion tab.
+ gr.Markdown(value="""
+ 本模型为sovits_f0(含AI猫雷2.0音色),支持**60s以内**的**无伴奏**wav、mp3(单声道)格式,或使用**网页内置**的录音(二选一)
+
+ 转换效果取决于源音频语气、节奏是否与目标音色相近,以及音域是否超出目标音色音域范围
+
+ 猫雷音色低音音域效果不佳,如转换男声歌声,建议变调升 **6-10key**
+
+ 该模型的 [github仓库链接](https://github.com/innnky/so-vits-svc),如果想自己制作并训练模型可以访问这个 [github仓库](https://github.com/IceKyrin/sovits_guide)
+ """)
+ # Voice selector with a single choice; its value is passed to vc_fn as `sid`.
+ speaker_id = gr.Dropdown(label="音色", choices=['雷电将军'], value="雷电将军")
+ # Two alternative audio sources; both hand vc_fn a file path (type="filepath").
+ # NOTE(review): both components use elem_id="audio_inputs" - confirm the duplicate id is intended.
+ record_input = gr.Audio(source="microphone", label="录制你的声音", type="filepath", elem_id="audio_inputs")
+ upload_input = gr.Audio(source="upload", label="上传音频(长度小于60秒)", type="filepath",
+ elem_id="audio_inputs")
+ # Numeric inputs passed to vc_fn as `tran` (semitone shift) and `pndm_speedup`.
+ vc_transform = gr.Number(label="变调(整数,可以正负,半音数量,升高八度就是12)", value=0)
+ vc_speedup = gr.Number(label="加速倍数", value=20)
+ vc_submit = gr.Button("转换", variant="primary")
+ out_audio = gr.Audio(label="Output Audio")
+ # NOTE(review): this text says the output message is the average pitch deviation,
+ # but vc_fn returns the fixed string "Success" (or an error message).
+ gr.Markdown(value="""
+ 输出信息为音高平均偏差半音数量,体现转换音频的跑调情况(一般平均小于0.5个半音)
+ """)
+ out_message = gr.Textbox(label="Output")
+ # NOTE(review): this text describes an f0 curve, but the image component below is
+ # commented out and no image is among the click outputs.
+ gr.Markdown(value="""f0曲线可以直观的显示跑调情况,蓝色为输入音高,橙色为合成音频的音高
+ 若**只看见橙色**,说明蓝色曲线被覆盖,转换效果较好
+ """)
+ # f0_image = gr.Image(label="f0曲线")
+ # Clicking the button calls vc_fn with the five inputs in positional order and
+ # routes its (message, audio) result to the textbox and the audio player.
+ vc_submit.click(vc_fn, [speaker_id, record_input, upload_input, vc_transform, vc_speedup],
+ [out_message, out_audio])
+ # Second tab: static usage notes.
+ with gr.TabItem("使用说明"):
+ gr.Markdown(value="""
+ 0、合集:https://github.com/IceKyrin/sovits_guide/blob/main/README.md
+ 1、仅支持sovit_f0(sovits2.0)模型
+ 2、自行下载hubert-soft-0d54a1f4.pt改名为hubert.pt放置于pth文件夹下(已经下好了)
+ https://github.com/bshall/hubert/releases/tag/v0.1
+ 3、pth文件夹下放置sovits2.0的模型
+ 4、与模型配套的xxx.json,需有speaker项——人物列表
+ 5、放无伴奏的音频、或网页内置录音,不要放奇奇怪怪的格式
+ 6、仅供交流使用,不对用户行为负责
+ """)
+
+ # launch() is called without arguments, so Gradio's default host/port apply.
+ # NOTE(review): the README front matter declares `sdk: docker` with `app_port: 8501`
+ # and the Dockerfile CMD is `streamlit run app.py`, while this file is a Gradio
+ # app - confirm the entry point and port match the container setup.
+ app.launch()
diff --git a/app.py.back2 b/app.py.back2
new file mode 100644
index 0000000000000000000000000000000000000000..2c960e3a027a2abdaddcc798c8e021f101f8d14f
--- /dev/null
+++ b/app.py.back2
@@ -0,0 +1,123 @@
+import streamlit as st
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import json
+import os
+import tempfile
+import shutil
+import requests
+from pathlib import Path
+temp_dir = os.path.expanduser("~/app")
+global ckpt_temp_file
+global audio_temp_file
+global config_temp_file
+###################################################
+from utils.hparams import hparams
+from preprocessing.data_gen_utils import get_pitch_parselmouth,get_pitch_crepe
+import numpy as np
+import matplotlib.pyplot as plt
+import IPython.display as ipd
+import utils
+import librosa
+import torchcrepe
+from infer import *
+import logging
+from infer_tools.infer_tool import *
+import io
+
+clip_completed = False
+def render_audio(ckpt_temp_file, config_temp_file, audio_temp_file, title):
+ logging.getLogger('numba').setLevel(logging.WARNING)
+ title = int(title)
+ project_name = "Unnamed"
+ model_path = ckpt_temp_file
+ config_path= config_temp_file
+ hubert_gpu=True
+ svc_model = Svc(project_name,config_path,hubert_gpu, model_path)
+ print('model loaded')
+ wav_fn = audio_temp_file
+ demoaudio, sr = librosa.load(wav_fn)
+ key = title # 音高调整,支持正负(半音)
+ # 加速倍数
+ pndm_speedup = 20
+ wav_gen='queeeeee.wav'#直接改后缀可以保存不同格式音频,如flac可无损压缩
+
+ # Show the spinner and run the run_clip function inside the 'with' block
+ with st.spinner("Rendering Audio..."):
+ f0_tst, f0_pred, audio = run_clip(svc_model,file_path=wav_fn, key=key, acc=pndm_speedup, use_crepe=True, use_pe=True, thre=0.05,
+ use_gt_mel=False, add_noise_step=500,project_name=project_name,out_path=wav_gen)
+ clip_completed = True
+ if clip_completed:
+ # If the 'run_clip' function has completed, use the st.audio function to show an audio player for the file stored in the 'wav_gen' variable
+ st.audio(wav_gen)
+
+#######################################################
+st.set_page_config(
+ page_title="DiffSVC Render",
+ page_icon="🧊",
+ initial_sidebar_state="expanded",
+)
+############
+st.title('DIFF-SVC Render')
+
+###CKPT LOADER
+with tempfile.TemporaryDirectory(dir=os.path.expanduser("~/app")) as temp_dir:
+ ckpt = st.file_uploader("Choose your CKPT", type= 'ckpt')
+ # Check if user uploaded a CKPT file
+ if ckpt is not None:
+ #TEMP FUNCTION
+ with tempfile.NamedTemporaryFile(mode="wb", suffix='.ckpt', delete=False) as temp:
+ # Get the file contents as bytes
+ bytes_data = ckpt.getvalue()
+ # Write the bytes to the temporary file
+ temp.write(bytes_data)
+ ckpt_temp_file = temp.name
+ # Print the temporary file name
+ print(temp.name)
+
+ # Display the file path
+ if "ckpt_temp_file" in locals():
+ st.success("File saved to: {}".format(ckpt_temp_file))
+
+ # File uploader
+ config = st.file_uploader("Choose your config", type= 'yaml')
+
+ # Check if user uploaded a config file
+ if config is not None:
+ #TEMP FUNCTION
+ with tempfile.NamedTemporaryFile(mode="wb", suffix='.yaml', delete=False) as temp:
+ # Get the file contents as bytes
+ bytes_data = config.getvalue()
+ # Write the bytes to the temporary file
+ temp.write(bytes_data)
+ config_temp_file = temp.name
+ # Print the temporary file name
+ print(temp.name)
+
+ # Display the file path
+ if "config_temp_file" in locals():
+ st.success("File saved to: {}".format(config_temp_file))
+
+ audio = st.file_uploader("Choose your audio", type=["wav", "mp3"])
+
+ # Check if user uploaded an audio file
+ if audio is not None:
+ #TEMP FUNCTION
+ with tempfile.NamedTemporaryFile(mode="wb", suffix='.wav', delete=False) as temp:
+ # Get the file contents as bytes
+ bytes_data = audio.getvalue()
+ # Write the bytes to the temporary file
+ temp.write(bytes_data)
+ audio_temp_file = temp.name
+ # Print the temporary file name
+ print(temp.name)
+
+# Display the file path
+ if "audio_temp_file" in locals():
+ st.success("File saved to: {}".format(audio_temp_file))
+# Add a text input for the title with a default value of 0
+title = st.text_input("Key", value="0")
+# Add a button to start the rendering process
+if st.button("Render audio"):
+ render_audio(ckpt_temp_file, config_temp_file, audio_temp_file, title)
\ No newline at end of file
diff --git a/app.pyback b/app.pyback
new file mode 100644
index 0000000000000000000000000000000000000000..028481c47efd6e4ae06c1ae1054783176aedc6e0
--- /dev/null
+++ b/app.pyback
@@ -0,0 +1,123 @@
+import streamlit as st
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+import json
+import os
+import tempfile
+import shutil
+import requests
+from pathlib import Path
+temp_dir = os.path.expanduser("/~app")
+global ckpt_temp_file
+global audio_temp_file
+global config_temp_file
+###################################################
+from utils.hparams import hparams
+from preprocessing.data_gen_utils import get_pitch_parselmouth,get_pitch_crepe
+import numpy as np
+import matplotlib.pyplot as plt
+import IPython.display as ipd
+import utils
+import librosa
+import torchcrepe
+from infer import *
+import logging
+from infer_tools.infer_tool import *
+import io
+
+clip_completed = False
+def render_audio(ckpt_temp_file, config_temp_file, audio_temp_file, title):
+ logging.getLogger('numba').setLevel(logging.WARNING)
+ title = int(title)
+ project_name = "Unnamed"
+ model_path = ckpt_temp_file
+ config_path= config_temp_file
+ hubert_gpu=True
+ svc_model = Svc(project_name,config_path,hubert_gpu, model_path)
+ print('model loaded')
+ wav_fn = audio_temp_file
+ demoaudio, sr = librosa.load(wav_fn)
+ key = title # 音高调整,支持正负(半音)
+ # 加速倍数
+ pndm_speedup = 20
+ wav_gen='queeeeee.wav'#直接改后缀可以保存不同格式音频,如flac可无损压缩
+
+ # Show the spinner and run the run_clip function inside the 'with' block
+ with st.spinner("Rendering Audio..."):
+ f0_tst, f0_pred, audio = run_clip(svc_model,file_path=wav_fn, key=key, acc=pndm_speedup, use_crepe=True, use_pe=True, thre=0.05,
+ use_gt_mel=False, add_noise_step=500,project_name=project_name,out_path=wav_gen)
+ clip_completed = True
+ if clip_completed:
+ # If the 'run_clip' function has completed, use the st.audio function to show an audio player for the file stored in the 'wav_gen' variable
+ st.audio(wav_gen)
+
+#######################################################
+st.set_page_config(
+ page_title="DiffSVC Render",
+ page_icon="🧊",
+ initial_sidebar_state="expanded",
+)
+############
+st.title('DIFF-SVC Render')
+
+###CKPT LOADER
+with tempfile.TemporaryDirectory(dir=os.path.expanduser("/~app")) as temp_dir:
+ ckpt = st.file_uploader("Choose your CKPT", type= 'ckpt')
+ # Check if user uploaded a CKPT file
+ if ckpt is not None:
+ #TEMP FUNCTION
+ with tempfile.NamedTemporaryFile(mode="wb", suffix='.ckpt', delete=False, dir=temp_dir) as temp:
+ # Get the file contents as bytes
+ bytes_data = ckpt.getvalue()
+ # Write the bytes to the temporary file
+ temp.write(bytes_data)
+ ckpt_temp_file = temp.name
+ # Print the temporary file name
+ print(temp.name)
+
+ # Display the file path
+ if "ckpt_temp_file" in locals():
+ st.success("File saved to: {}".format(ckpt_temp_file))
+
+ # File uploader
+ config = st.file_uploader("Choose your config", type= 'yaml')
+
+ # Check if user uploaded a config file
+ if config is not None:
+ #TEMP FUNCTION
+ with tempfile.NamedTemporaryFile(mode="w", suffix='.yaml', delete=False, dir=temp_dir) as temp:
+ # Get the file contents as bytes
+ bytes_data = config.getvalue()
+ # Write the bytes to the temporary file
+ temp.write(bytes_data)
+ config_temp_file = temp.name
+ # Print the temporary file name
+ print(temp.name)
+
+ # Display the file path
+ if "config_temp_file" in locals():
+ st.success("File saved to: {}".format(config_temp_file))
+
+ audio = st.file_uploader("Choose your audio", type=["wav", "mp3"])
+
+ # Check if user uploaded an audio file
+ if audio is not None:
+ #TEMP FUNCTION
+ with tempfile.NamedTemporaryFile(mode="wb", suffix='.wav', delete=False, dir=temp_dir) as temp:
+ # Get the file contents as bytes
+ bytes_data = audio.getvalue()
+ # Write the bytes to the temporary file
+ temp.write(bytes_data)
+ audio_temp_file = temp.name
+ # Print the temporary file name
+ print(temp.name)
+
+# Display the file path
+ if "audio_temp_file" in locals():
+ st.success("File saved to: {}".format(audio_temp_file))
+# Add a text input for the title with a default value of 0
+title = st.text_input("Key", value="0")
+# Add a button to start the rendering process
+if st.button("Render audio"):
+ render_audio(ckpt_temp_file, config_temp_file, audio_temp_file, title)
\ No newline at end of file
diff --git a/batch.py b/batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..07b283e6de56d70c3dac2883830ab84f132ec4c5
--- /dev/null
+++ b/batch.py
@@ -0,0 +1,43 @@
+import soundfile
+
+from infer_tools import infer_tool
+from infer_tools.infer_tool import Svc
+
+
+def run_clip(svc_model, key, acc, use_pe, use_crepe, thre, use_gt_mel, add_noise_step, project_name='', f_name=None,
+ file_path=None, out_path=None):
+ raw_audio_path = f_name
+ infer_tool.format_wav(raw_audio_path)
+ _f0_tst, _f0_pred, _audio = svc_model.infer(raw_audio_path, key=key, acc=acc, singer=True, use_pe=use_pe,
+ use_crepe=use_crepe,
+ thre=thre, use_gt_mel=use_gt_mel, add_noise_step=add_noise_step)
+ out_path = f'./singer_data/{f_name.split("/")[-1]}'
+ soundfile.write(out_path, _audio, 44100, 'PCM_16')
+
+
+if __name__ == '__main__':
+ # 工程文件夹名,训练时用的那个
+ project_name = "firefox"
+ model_path = f'./checkpoints/{project_name}/clean_model_ckpt_steps_100000.ckpt'
+ config_path = f'./checkpoints/{project_name}/config.yaml'
+
+ # 支持多个wav/ogg文件,放在raw文件夹下,带扩展名
+ file_names = infer_tool.get_end_file("./batch", "wav")
+ trans = [-6] # 音高调整,支持正负(半音),数量与上一行对应,不足的自动按第一个移调参数补齐
+ # 加速倍数
+ accelerate = 50
+ hubert_gpu = True
+ cut_time = 30
+
+ # 下面不动
+ infer_tool.mkdir(["./batch", "./singer_data"])
+ infer_tool.fill_a_to_b(trans, file_names)
+
+ model = Svc(project_name, config_path, hubert_gpu, model_path)
+ count = 0
+ for f_name, tran in zip(file_names, trans):
+ print(f_name)
+ run_clip(model, key=tran, acc=accelerate, use_crepe=False, thre=0.05, use_pe=False, use_gt_mel=False,
+ add_noise_step=500, f_name=f_name, project_name=project_name)
+ count += 1
+ print(f"process:{round(count * 100 / len(file_names), 2)}%")
diff --git a/checkpoints/0102_xiaoma_pe/config.yaml b/checkpoints/0102_xiaoma_pe/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..69a88444205377d48573d53bb4fb500860976588
--- /dev/null
+++ b/checkpoints/0102_xiaoma_pe/config.yaml
@@ -0,0 +1,172 @@
+accumulate_grad_batches: 1
+audio_num_mel_bins: 80
+audio_sample_rate: 24000
+base_config:
+- configs/tts/lj/fs2.yaml
+binarization_args:
+ shuffle: false
+ with_align: true
+ with_f0: true
+ with_f0cwt: true
+ with_spk_embed: true
+ with_txt: true
+ with_wav: false
+binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer
+binary_data_dir: data/binary/xiaoma1022_24k_128hop
+check_val_every_n_epoch: 10
+clip_grad_norm: 1
+cwt_add_f0_loss: false
+cwt_hidden_size: 128
+cwt_layers: 2
+cwt_loss: l1
+cwt_std_scale: 0.8
+debug: false
+dec_ffn_kernel_size: 9
+dec_layers: 4
+decoder_type: fft
+dict_dir: ''
+dropout: 0.1
+ds_workers: 4
+dur_enc_hidden_stride_kernel:
+- 0,2,3
+- 0,2,3
+- 0,1,3
+dur_loss: mse
+dur_predictor_kernel: 3
+dur_predictor_layers: 2
+enc_ffn_kernel_size: 9
+enc_layers: 4
+encoder_K: 8
+encoder_type: fft
+endless_ds: true
+ffn_act: gelu
+ffn_padding: SAME
+fft_size: 512
+fmax: 12000
+fmin: 30
+gen_dir_name: ''
+hidden_size: 256
+hop_size: 128
+infer: false
+lambda_commit: 0.25
+lambda_energy: 0.1
+lambda_f0: 1.0
+lambda_ph_dur: 1.0
+lambda_sent_dur: 1.0
+lambda_uv: 1.0
+lambda_word_dur: 1.0
+load_ckpt: ''
+log_interval: 100
+loud_norm: false
+lr: 2.0
+max_epochs: 1000
+max_eval_sentences: 1
+max_eval_tokens: 60000
+max_frames: 5000
+max_input_tokens: 1550
+max_sentences: 100000
+max_tokens: 20000
+max_updates: 60000
+mel_loss: l1
+mel_vmax: 1.5
+mel_vmin: -6
+min_level_db: -120
+norm_type: gn
+num_ckpt_keep: 3
+num_heads: 2
+num_sanity_val_steps: 5
+num_spk: 1
+num_test_samples: 20
+num_valid_plots: 10
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+out_wav_norm: false
+pitch_ar: false
+pitch_enc_hidden_stride_kernel:
+- 0,2,5
+- 0,2,5
+- 0,2,5
+pitch_extractor_conv_layers: 2
+pitch_loss: l1
+pitch_norm: log
+pitch_type: frame
+pre_align_args:
+ allow_no_txt: false
+ denoise: false
+ forced_align: mfa
+ txt_processor: en
+ use_sox: false
+ use_tone: true
+pre_align_cls: data_gen.tts.lj.pre_align.LJPreAlign
+predictor_dropout: 0.5
+predictor_grad: 0.1
+predictor_hidden: -1
+predictor_kernel: 5
+predictor_layers: 2
+prenet_dropout: 0.5
+prenet_hidden_size: 256
+pretrain_fs_ckpt: ''
+processed_data_dir: data/processed/ljspeech
+profile_infer: false
+raw_data_dir: data/raw/LJSpeech-1.1
+ref_norm_layer: bn
+reset_phone_dict: true
+save_best: false
+save_ckpt: true
+save_codes:
+- configs
+- modules
+- tasks
+- utils
+- usr
+save_f0: false
+save_gt: false
+seed: 1234
+sort_by_len: true
+stop_token_weight: 5.0
+task_cls: tasks.tts.pe.PitchExtractionTask
+test_ids:
+- 68
+- 70
+- 74
+- 87
+- 110
+- 172
+- 190
+- 215
+- 231
+- 294
+- 316
+- 324
+- 402
+- 422
+- 485
+- 500
+- 505
+- 508
+- 509
+- 519
+test_input_dir: ''
+test_num: 523
+test_set_name: test
+train_set_name: train
+use_denoise: false
+use_energy_embed: false
+use_gt_dur: false
+use_gt_f0: false
+use_pitch_embed: true
+use_pos_embed: true
+use_spk_embed: false
+use_spk_id: false
+use_split_spk_id: false
+use_uv: true
+use_var_enc: false
+val_check_interval: 2000
+valid_num: 348
+valid_set_name: valid
+vocoder: pwg
+vocoder_ckpt: ''
+warmup_updates: 2000
+weight_decay: 0
+win_size: 512
+work_dir: checkpoints/0102_xiaoma_pe
diff --git a/checkpoints/0102_xiaoma_pe/model_ckpt_steps_60000.ckpt b/checkpoints/0102_xiaoma_pe/model_ckpt_steps_60000.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..fb8ffa57bddeef3b1def9ebf3e12311f0a1ea799
--- /dev/null
+++ b/checkpoints/0102_xiaoma_pe/model_ckpt_steps_60000.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1863f12324e43783089ab933edeeb969106b851e30d71019ebbaa9b82099d82a
+size 39141959
diff --git a/checkpoints/0109_hifigan_bigpopcs_hop128/config.yaml b/checkpoints/0109_hifigan_bigpopcs_hop128/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..95fc5414ba1aff1bad8284ebfba52f5636b4d76d
--- /dev/null
+++ b/checkpoints/0109_hifigan_bigpopcs_hop128/config.yaml
@@ -0,0 +1,241 @@
+accumulate_grad_batches: 1
+adam_b1: 0.8
+adam_b2: 0.99
+amp: false
+audio_num_mel_bins: 80
+audio_sample_rate: 24000
+aux_context_window: 0
+#base_config:
+#- egs/egs_bases/singing/pwg.yaml
+#- egs/egs_bases/tts/vocoder/hifigan.yaml
+binarization_args:
+ reset_phone_dict: true
+ reset_word_dict: true
+ shuffle: false
+ trim_eos_bos: false
+ trim_sil: false
+ with_align: false
+ with_f0: true
+ with_f0cwt: false
+ with_linear: false
+ with_spk_embed: false
+ with_spk_id: true
+ with_txt: false
+ with_wav: true
+ with_word: false
+binarizer_cls: data_gen.tts.singing.binarize.SingingBinarizer
+binary_data_dir: data/binary/big_popcs_24k_hop128
+check_val_every_n_epoch: 10
+clip_grad_norm: 1
+clip_grad_value: 0
+datasets: []
+debug: false
+dec_ffn_kernel_size: 9
+dec_layers: 4
+dict_dir: ''
+disc_start_steps: 40000
+discriminator_grad_norm: 1
+discriminator_optimizer_params:
+ eps: 1.0e-06
+ lr: 0.0002
+ weight_decay: 0.0
+discriminator_params:
+ bias: true
+ conv_channels: 64
+ in_channels: 1
+ kernel_size: 3
+ layers: 10
+ nonlinear_activation: LeakyReLU
+ nonlinear_activation_params:
+ negative_slope: 0.2
+ out_channels: 1
+ use_weight_norm: true
+discriminator_scheduler_params:
+ gamma: 0.999
+ step_size: 600
+dropout: 0.1
+ds_workers: 1
+enc_ffn_kernel_size: 9
+enc_layers: 4
+endless_ds: true
+ffn_act: gelu
+ffn_padding: SAME
+fft_size: 512
+fmax: 12000
+fmin: 30
+frames_multiple: 1
+gen_dir_name: ''
+generator_grad_norm: 10
+generator_optimizer_params:
+ eps: 1.0e-06
+ lr: 0.0002
+ weight_decay: 0.0
+generator_params:
+ aux_channels: 80
+ dropout: 0.0
+ gate_channels: 128
+ in_channels: 1
+ kernel_size: 3
+ layers: 30
+ out_channels: 1
+ residual_channels: 64
+ skip_channels: 64
+ stacks: 3
+ upsample_net: ConvInUpsampleNetwork
+ upsample_params:
+ upsample_scales:
+ - 2
+ - 4
+ - 4
+ - 4
+ use_nsf: false
+ use_pitch_embed: true
+ use_weight_norm: true
+generator_scheduler_params:
+ gamma: 0.999
+ step_size: 600
+griffin_lim_iters: 60
+hidden_size: 256
+hop_size: 128
+infer: false
+lambda_adv: 1.0
+lambda_cdisc: 4.0
+lambda_energy: 0.0
+lambda_f0: 0.0
+lambda_mel: 5.0
+lambda_mel_adv: 1.0
+lambda_ph_dur: 0.0
+lambda_sent_dur: 0.0
+lambda_uv: 0.0
+lambda_word_dur: 0.0
+load_ckpt: ''
+loud_norm: false
+lr: 2.0
+max_epochs: 1000
+max_frames: 2400
+max_input_tokens: 1550
+max_samples: 8192
+max_sentences: 20
+max_tokens: 24000
+max_updates: 3000000
+max_valid_sentences: 1
+max_valid_tokens: 60000
+mel_loss: ssim:0.5|l1:0.5
+mel_vmax: 1.5
+mel_vmin: -6
+min_frames: 0
+min_level_db: -120
+num_ckpt_keep: 3
+num_heads: 2
+num_mels: 80
+num_sanity_val_steps: 5
+num_spk: 100
+num_test_samples: 0
+num_valid_plots: 10
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+out_wav_norm: false
+pitch_extractor: parselmouth
+pitch_type: frame
+pre_align_args:
+ allow_no_txt: false
+ denoise: false
+ sox_resample: true
+ sox_to_wav: false
+ trim_sil: false
+ txt_processor: zh
+ use_tone: false
+pre_align_cls: data_gen.tts.singing.pre_align.SingingPreAlign
+predictor_grad: 0.0
+print_nan_grads: false
+processed_data_dir: ''
+profile_infer: false
+raw_data_dir: ''
+ref_level_db: 20
+rename_tmux: true
+rerun_gen: true
+resblock: '1'
+resblock_dilation_sizes:
+- - 1
+ - 3
+ - 5
+- - 1
+ - 3
+ - 5
+- - 1
+ - 3
+ - 5
+resblock_kernel_sizes:
+- 3
+- 7
+- 11
+resume_from_checkpoint: 0
+save_best: true
+save_codes: []
+save_f0: true
+save_gt: true
+scheduler: rsqrt
+seed: 1234
+sort_by_len: true
+stft_loss_params:
+ fft_sizes:
+ - 1024
+ - 2048
+ - 512
+ hop_sizes:
+ - 120
+ - 240
+ - 50
+ win_lengths:
+ - 600
+ - 1200
+ - 240
+ window: hann_window
+task_cls: tasks.vocoder.hifigan.HifiGanTask
+tb_log_interval: 100
+test_ids: []
+test_input_dir: ''
+test_num: 50
+test_prefixes: []
+test_set_name: test
+train_set_name: train
+train_sets: ''
+upsample_initial_channel: 512
+upsample_kernel_sizes:
+- 16
+- 16
+- 4
+- 4
+upsample_rates:
+- 8
+- 4
+- 2
+- 2
+use_cdisc: false
+use_cond_disc: false
+use_fm_loss: false
+use_gt_dur: true
+use_gt_f0: true
+use_mel_loss: true
+use_ms_stft: false
+use_pitch_embed: true
+use_ref_enc: true
+use_spec_disc: false
+use_spk_embed: false
+use_spk_id: false
+use_split_spk_id: false
+val_check_interval: 2000
+valid_infer_interval: 10000
+valid_monitor_key: val_loss
+valid_monitor_mode: min
+valid_set_name: valid
+vocoder: pwg
+vocoder_ckpt: ''
+vocoder_denoise_c: 0.0
+warmup_updates: 8000
+weight_decay: 0
+win_length: null
+win_size: 512
+window: hann
+word_size: 3000
+work_dir: checkpoints/0109_hifigan_bigpopcs_hop128
diff --git a/checkpoints/0109_hifigan_bigpopcs_hop128/model_ckpt_steps_1512000.ckpt b/checkpoints/0109_hifigan_bigpopcs_hop128/model_ckpt_steps_1512000.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..ed55eaa98f86e3e22f4eb4e8115f254745cea155
--- /dev/null
+++ b/checkpoints/0109_hifigan_bigpopcs_hop128/model_ckpt_steps_1512000.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1cb68f3ce0c46ba0a8b6d49718f1fffdf5bd7bcab769a986fd2fd129835cc1d1
+size 55827436
diff --git a/checkpoints/Unnamed/config.yaml b/checkpoints/Unnamed/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e49935b8b7ff2a83ba1008f3617497e8c3c6da53
--- /dev/null
+++ b/checkpoints/Unnamed/config.yaml
@@ -0,0 +1,445 @@
+K_step: 1000
+accumulate_grad_batches: 1
+audio_num_mel_bins: 128
+audio_sample_rate: 44100
+binarization_args:
+ shuffle: false
+ with_align: true
+ with_f0: true
+ with_hubert: true
+ with_spk_embed: false
+ with_wav: false
+binarizer_cls: preprocessing.SVCpre.SVCBinarizer
+binary_data_dir: data/binary/Unnamed
+check_val_every_n_epoch: 10
+choose_test_manually: false
+clip_grad_norm: 1
+config_path: training/config_nsf.yaml
+content_cond_steps: []
+cwt_add_f0_loss: false
+cwt_hidden_size: 128
+cwt_layers: 2
+cwt_loss: l1
+cwt_std_scale: 0.8
+datasets:
+- opencpop
+debug: false
+dec_ffn_kernel_size: 9
+dec_layers: 4
+decay_steps: 60000
+decoder_type: fft
+dict_dir: ''
+diff_decoder_type: wavenet
+diff_loss_type: l2
+dilation_cycle_length: 4
+dropout: 0.1
+ds_workers: 4
+dur_enc_hidden_stride_kernel:
+- 0,2,3
+- 0,2,3
+- 0,1,3
+dur_loss: mse
+dur_predictor_kernel: 3
+dur_predictor_layers: 5
+enc_ffn_kernel_size: 9
+enc_layers: 4
+encoder_K: 8
+encoder_type: fft
+endless_ds: false
+f0_bin: 256
+f0_max: 1100.0
+f0_min: 40.0
+ffn_act: gelu
+ffn_padding: SAME
+fft_size: 2048
+fmax: 16000
+fmin: 40
+fs2_ckpt: ''
+gaussian_start: true
+gen_dir_name: ''
+gen_tgt_spk_id: -1
+hidden_size: 256
+hop_size: 512
+hubert_gpu: true
+hubert_path: checkpoints/hubert/hubert_soft.pt
+infer: false
+keep_bins: 128
+lambda_commit: 0.25
+lambda_energy: 0.0
+lambda_f0: 1.0
+lambda_ph_dur: 0.3
+lambda_sent_dur: 1.0
+lambda_uv: 1.0
+lambda_word_dur: 1.0
+load_ckpt: ''
+log_interval: 100
+loud_norm: false
+lr: 0.0008
+max_beta: 0.02
+max_epochs: 3000
+max_eval_sentences: 1
+max_eval_tokens: 60000
+max_frames: 42000
+max_input_tokens: 60000
+max_sentences: 14
+max_tokens: 128000
+max_updates: 1000000
+mel_loss: ssim:0.5|l1:0.5
+mel_vmax: 1.5
+mel_vmin: -6.0
+min_level_db: -120
+no_fs2: true
+norm_type: gn
+num_ckpt_keep: 10
+num_heads: 2
+num_sanity_val_steps: 1
+num_spk: 1
+num_test_samples: 0
+num_valid_plots: 10
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+out_wav_norm: false
+pe_ckpt: checkpoints/0102_xiaoma_pe/model_ckpt_steps_60000.ckpt
+pe_enable: false
+perform_enhance: true
+pitch_ar: false
+pitch_enc_hidden_stride_kernel:
+- 0,2,5
+- 0,2,5
+- 0,2,5
+pitch_extractor: parselmouth
+pitch_loss: l2
+pitch_norm: log
+pitch_type: frame
+pndm_speedup: 10
+pre_align_args:
+ allow_no_txt: false
+ denoise: false
+ forced_align: mfa
+ txt_processor: zh_g2pM
+ use_sox: true
+ use_tone: false
+pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
+predictor_dropout: 0.5
+predictor_grad: 0.1
+predictor_hidden: -1
+predictor_kernel: 5
+predictor_layers: 5
+prenet_dropout: 0.5
+prenet_hidden_size: 256
+pretrain_fs_ckpt: ''
+processed_data_dir: xxx
+profile_infer: false
+raw_data_dir: data/raw/Unnamed
+ref_norm_layer: bn
+rel_pos: true
+reset_phone_dict: true
+residual_channels: 384
+residual_layers: 20
+save_best: true
+save_ckpt: true
+save_codes:
+- configs
+- modules
+- src
+- utils
+save_f0: true
+save_gt: false
+schedule_type: linear
+seed: 1234
+sort_by_len: true
+speaker_id: Unnamed
+spec_max:
+- 0.47615352272987366
+- 0.6125704050064087
+- 0.7518845796585083
+- 0.900716245174408
+- 0.8935521841049194
+- 0.9057011604309082
+- 0.9648348689079285
+- 0.9044283032417297
+- 0.9109272360801697
+- 0.9744535088539124
+- 0.9476388692855835
+- 0.9883336424827576
+- 1.0821290016174316
+- 1.046391248703003
+- 0.9829667806625366
+- 1.0163493156433105
+- 0.9825412631034851
+- 1.0021960735321045
+- 1.052114725112915
+- 1.128888726234436
+- 1.186057209968567
+- 1.112004280090332
+- 1.1282787322998047
+- 1.051572322845459
+- 1.1104764938354492
+- 1.176831603050232
+- 1.13348388671875
+- 1.1075258255004883
+- 1.1696264743804932
+- 1.0231049060821533
+- 0.9303848743438721
+- 1.1257890462875366
+- 1.1610286235809326
+- 1.0335885286331177
+- 1.0645352602005005
+- 1.0619306564331055
+- 1.1310148239135742
+- 1.1191954612731934
+- 1.1307402849197388
+- 1.2094721794128418
+- 1.2683185338974
+- 1.1212272644042969
+- 1.1781182289123535
+- 1.1501952409744263
+- 0.9884514808654785
+- 0.9226155281066895
+- 0.9469702839851379
+- 1.023751139640808
+- 1.1348609924316406
+- 1.087107539176941
+- 0.9899962544441223
+- 1.061837077140808
+- 1.0341650247573853
+- 0.9019684195518494
+- 0.7986546158790588
+- 0.7983465194702148
+- 0.7755436301231384
+- 0.701917290687561
+- 0.7639197707176208
+- 0.7503461837768555
+- 0.6701087951660156
+- 0.5326520800590515
+- 0.6320568323135376
+- 0.4748716950416565
+- 0.41016310453414917
+- 0.4754445552825928
+- 0.4267503023147583
+- 0.391481876373291
+- 0.3118276298046112
+- 0.3193877339363098
+- 0.3111794888973236
+- 0.3342774212360382
+- 0.1353837102651596
+- 0.16596835851669312
+- 0.1730986088514328
+- 0.2325316220521927
+- 0.17107760906219482
+- 0.10877621918916702
+- 0.2612082064151764
+- 0.11200784891843796
+- 0.14075303077697754
+- 0.07312829792499542
+- -0.011712555773556232
+- 0.1741427332162857
+- 0.19782507419586182
+- 0.03305494412779808
+- 0.004054426681250334
+- 0.1011907309293747
+- 0.1317272037267685
+- 0.014256341382861137
+- 0.019952761009335518
+- -0.1253873109817505
+- -0.14854255318641663
+- -0.14063480496406555
+- -0.1331133395433426
+- -0.28339776396751404
+- -0.38559386134147644
+- -0.2798943519592285
+- -0.19351321458816528
+- -0.23238061368465424
+- -0.2850213944911957
+- -0.20320385694503784
+- -0.24087588489055634
+- -0.15823237597942352
+- -0.13949760794639587
+- -0.19627133011817932
+- -0.1920071393251419
+- -0.19384469091892242
+- -0.22403620183467865
+- -0.18197931349277496
+- -0.28423866629600525
+- -0.26859334111213684
+- -0.3213472068309784
+- -0.3303631842136383
+- -0.3835512697696686
+- -0.3256210386753082
+- -0.3938714265823364
+- -0.4373253881931305
+- -0.4146285951137543
+- -0.4861420691013336
+- -0.4018196761608124
+- -0.46770456433296204
+- -0.4100344479084015
+- -0.5364681482315063
+- -0.5802102088928223
+- -0.5856970548629761
+- -0.47378262877464294
+- -0.36258620023727417
+spec_min:
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+spk_cond_steps: []
+stop_token_weight: 5.0
+task_cls: training.task.SVC_task.SVCTask
+test_ids: []
+test_input_dir: ''
+test_num: 0
+test_prefixes:
+- test
+test_set_name: test
+timesteps: 1000
+train_set_name: train
+use_crepe: false
+use_denoise: false
+use_energy_embed: false
+use_gt_dur: false
+use_gt_f0: false
+use_midi: false
+use_nsf: true
+use_pitch_embed: true
+use_pos_embed: true
+use_spk_embed: false
+use_spk_id: false
+use_split_spk_id: false
+use_uv: false
+use_var_enc: false
+use_vec: false
+val_check_interval: 1000
+valid_num: 0
+valid_set_name: valid
+vocoder: network.vocoders.nsf_hifigan.NsfHifiGAN
+vocoder_ckpt: checkpoints/nsf_hifigan/model
+warmup_updates: 2000
+wav2spec_eps: 1e-6
+weight_decay: 0
+win_size: 2048
+work_dir: checkpoints/Unnamed
diff --git a/checkpoints/Unnamed/config_nsf.yaml b/checkpoints/Unnamed/config_nsf.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0f07e15dbd2382be4a77521411e52b9214e5444f
--- /dev/null
+++ b/checkpoints/Unnamed/config_nsf.yaml
@@ -0,0 +1,445 @@
+K_step: 1000
+accumulate_grad_batches: 1
+audio_num_mel_bins: 128
+audio_sample_rate: 44100
+binarization_args:
+ shuffle: false
+ with_align: true
+ with_f0: true
+ with_hubert: true
+ with_spk_embed: false
+ with_wav: false
+binarizer_cls: preprocessing.SVCpre.SVCBinarizer
+binary_data_dir: data/binary/Unnamed
+check_val_every_n_epoch: 10
+choose_test_manually: false
+clip_grad_norm: 1
+config_path: training/config_nsf.yaml
+content_cond_steps: []
+cwt_add_f0_loss: false
+cwt_hidden_size: 128
+cwt_layers: 2
+cwt_loss: l1
+cwt_std_scale: 0.8
+datasets:
+- opencpop
+debug: false
+dec_ffn_kernel_size: 9
+dec_layers: 4
+decay_steps: 20000
+decoder_type: fft
+dict_dir: ''
+diff_decoder_type: wavenet
+diff_loss_type: l2
+dilation_cycle_length: 4
+dropout: 0.1
+ds_workers: 4
+dur_enc_hidden_stride_kernel:
+- 0,2,3
+- 0,2,3
+- 0,1,3
+dur_loss: mse
+dur_predictor_kernel: 3
+dur_predictor_layers: 5
+enc_ffn_kernel_size: 9
+enc_layers: 4
+encoder_K: 8
+encoder_type: fft
+endless_ds: false
+f0_bin: 256
+f0_max: 1100.0
+f0_min: 40.0
+ffn_act: gelu
+ffn_padding: SAME
+fft_size: 2048
+fmax: 16000
+fmin: 40
+fs2_ckpt: ''
+gaussian_start: true
+gen_dir_name: ''
+gen_tgt_spk_id: -1
+hidden_size: 256
+hop_size: 512
+hubert_gpu: true
+hubert_path: checkpoints/hubert/hubert_soft.pt
+infer: false
+keep_bins: 128
+lambda_commit: 0.25
+lambda_energy: 0.0
+lambda_f0: 1.0
+lambda_ph_dur: 0.3
+lambda_sent_dur: 1.0
+lambda_uv: 1.0
+lambda_word_dur: 1.0
+load_ckpt: pretrain/nehito_ckpt_steps_1000000.ckpt
+log_interval: 100
+loud_norm: false
+lr: 5.0e-05
+max_beta: 0.02
+max_epochs: 3000
+max_eval_sentences: 1
+max_eval_tokens: 60000
+max_frames: 42000
+max_input_tokens: 60000
+max_sentences: 12
+max_tokens: 128000
+max_updates: 1000000
+mel_loss: ssim:0.5|l1:0.5
+mel_vmax: 1.5
+mel_vmin: -6.0
+min_level_db: -120
+no_fs2: true
+norm_type: gn
+num_ckpt_keep: 10
+num_heads: 2
+num_sanity_val_steps: 1
+num_spk: 1
+num_test_samples: 0
+num_valid_plots: 10
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+out_wav_norm: false
+pe_ckpt: checkpoints/0102_xiaoma_pe/model_ckpt_steps_60000.ckpt
+pe_enable: false
+perform_enhance: true
+pitch_ar: false
+pitch_enc_hidden_stride_kernel:
+- 0,2,5
+- 0,2,5
+- 0,2,5
+pitch_extractor: parselmouth
+pitch_loss: l2
+pitch_norm: log
+pitch_type: frame
+pndm_speedup: 10
+pre_align_args:
+ allow_no_txt: false
+ denoise: false
+ forced_align: mfa
+ txt_processor: zh_g2pM
+ use_sox: true
+ use_tone: false
+pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
+predictor_dropout: 0.5
+predictor_grad: 0.1
+predictor_hidden: -1
+predictor_kernel: 5
+predictor_layers: 5
+prenet_dropout: 0.5
+prenet_hidden_size: 256
+pretrain_fs_ckpt: ''
+processed_data_dir: xxx
+profile_infer: false
+raw_data_dir: data/raw/Unnamed
+ref_norm_layer: bn
+rel_pos: true
+reset_phone_dict: true
+residual_channels: 384
+residual_layers: 20
+save_best: false
+save_ckpt: true
+save_codes:
+- configs
+- modules
+- src
+- utils
+save_f0: true
+save_gt: false
+schedule_type: linear
+seed: 1234
+sort_by_len: true
+speaker_id: Unnamed
+spec_max:
+- -0.4884430170059204
+- 0.004534448496997356
+- 0.5684943795204163
+- 0.6527385115623474
+- 0.659079372882843
+- 0.7416915893554688
+- 0.844637930393219
+- 0.806076169013977
+- 0.7238750457763672
+- 0.9744535088539124
+- 0.9476388692855835
+- 0.9883336424827576
+- 1.0821290016174316
+- 1.046391248703003
+- 0.9829667806625366
+- 1.0163493156433105
+- 0.9825412631034851
+- 0.9834834337234497
+- 0.9811502695083618
+- 1.128888726234436
+- 1.186057209968567
+- 1.112004280090332
+- 1.1282787322998047
+- 1.051572322845459
+- 1.0510444641113281
+- 1.0110565423965454
+- 0.9236567616462708
+- 0.8036720156669617
+- 0.8383486270904541
+- 0.7735869288444519
+- 0.9303848743438721
+- 1.1257890462875366
+- 1.1610286235809326
+- 1.0335885286331177
+- 1.0645352602005005
+- 1.0619306564331055
+- 1.1310148239135742
+- 1.1191954612731934
+- 1.1307402849197388
+- 0.8837698698043823
+- 1.1153966188430786
+- 1.1045044660568237
+- 1.0479614734649658
+- 0.9491603374481201
+- 0.9858523011207581
+- 0.9226155281066895
+- 0.9469702839851379
+- 0.8791896104812622
+- 0.997624933719635
+- 0.9068642854690552
+- 0.9575618505477905
+- 0.8551340699195862
+- 0.8397778272628784
+- 0.8908605575561523
+- 0.7986546158790588
+- 0.7983465194702148
+- 0.6965265274047852
+- 0.640673041343689
+- 0.6690735220909119
+- 0.5631484985351562
+- 0.48587048053741455
+- 0.5326520800590515
+- 0.4286036193370819
+- 0.35252484679222107
+- 0.3290073573589325
+- 0.4754445552825928
+- 0.3632410168647766
+- 0.391481876373291
+- 0.20288512110710144
+- 0.18305960297584534
+- 0.1539602279663086
+- 0.03451670706272125
+- -0.16881510615348816
+- -0.02030198462307453
+- 0.10024689882993698
+- -0.023952053859829903
+- 0.05635542422533035
+- 0.10877621918916702
+- 0.006155031267553568
+- 0.07318088412284851
+- 0.14075303077697754
+- 0.057870157063007355
+- -0.0520513579249382
+- 0.1741427332162857
+- -0.11464552581310272
+- 0.03305494412779808
+- -0.06897418200969696
+- -0.12598733603954315
+- -0.09894973039627075
+- -0.2817802429199219
+- -0.0825519785284996
+- -0.3040400445461273
+- -0.4998124837875366
+- -0.36957985162734985
+- -0.5409602522850037
+- -0.49879470467567444
+- -0.713716983795166
+- -0.6545754671096802
+- -0.6425778865814209
+- -0.6178902387619019
+- -0.47356730699539185
+- -0.6165243983268738
+- -0.5841533541679382
+- -0.5759448409080505
+- -0.5498068332672119
+- -0.4661938548088074
+- -0.5811225771903992
+- -0.614664614200592
+- -0.3902229070663452
+- -0.7037366032600403
+- -0.7260795831680298
+- -0.7540019750595093
+- -0.8360528945922852
+- -0.8374698758125305
+- -0.8328713178634644
+- -0.9081047177314758
+- -0.9679695963859558
+- -0.9587443470954895
+- -1.0706337690353394
+- -0.9818469285964966
+- -0.8360191583633423
+- -0.9938981533050537
+- -1.0823708772659302
+- -1.0617167949676514
+- -1.1093820333480835
+- -1.1300138235092163
+- -1.2141350507736206
+- -1.3147293329238892
+spec_min:
+- -4.473258972167969
+- -4.244492530822754
+- -4.390527725219727
+- -4.209497928619385
+- -4.446024417877197
+- -4.3960185050964355
+- -4.164802551269531
+- -4.5063300132751465
+- -4.608232021331787
+- -4.251623630523682
+- -4.4799604415893555
+- -4.733210563659668
+- -4.411860466003418
+- -4.609100818634033
+- -4.726972579956055
+- -4.428761959075928
+- -4.487612247467041
+- -4.525552749633789
+- -4.480506896972656
+- -4.589383125305176
+- -4.608384132385254
+- -4.385376453399658
+- -4.816161632537842
+- -4.8706955909729
+- -4.848956108093262
+- -4.431278705596924
+- -4.999994277954102
+- -4.818373203277588
+- -4.527368068695068
+- -4.872085094451904
+- -4.894851207733154
+- -4.511948585510254
+- -4.534575939178467
+- -4.57792854309082
+- -4.444681644439697
+- -4.628803253173828
+- -4.74341344833374
+- -4.85427713394165
+- -4.723776817321777
+- -4.7166008949279785
+- -4.749168395996094
+- -4.67240047454834
+- -4.590690612792969
+- -4.576009750366211
+- -4.542308330535889
+- -4.890907287597656
+- -4.613001823425293
+- -4.494126796722412
+- -4.474257946014404
+- -4.574635028839111
+- -4.4817585945129395
+- -4.651009559631348
+- -4.478254795074463
+- -4.523812770843506
+- -4.546536922454834
+- -4.535660266876221
+- -4.470296859741211
+- -4.577486991882324
+- -4.541748046875
+- -4.428532123565674
+- -4.461862564086914
+- -4.489077091217041
+- -4.515830039978027
+- -4.395663738250732
+- -4.439975738525391
+- -4.4290876388549805
+- -4.397741794586182
+- -4.478252410888672
+- -4.399686336517334
+- -4.45617151260376
+- -4.434477806091309
+- -4.442898750305176
+- -4.5840277671813965
+- -4.537542819976807
+- -4.492046356201172
+- -4.534677505493164
+- -4.477104187011719
+- -4.511618614196777
+- -4.387601375579834
+- -4.499236106872559
+- -4.3717169761657715
+- -4.4242024421691895
+- -4.4055657386779785
+- -4.429355144500732
+- -4.4636993408203125
+- -4.508528232574463
+- -4.515079498291016
+- -4.426190376281738
+- -4.433525085449219
+- -4.4200215339660645
+- -4.421280860900879
+- -4.400143623352051
+- -4.419166088104248
+- -4.429825305938721
+- -4.436781406402588
+- -4.51550817489624
+- -4.518474578857422
+- -4.495880603790283
+- -4.483924865722656
+- -4.409562587738037
+- -4.3811845779418945
+- -4.411908149719238
+- -4.427165985107422
+- -4.396549701690674
+- -4.340637683868408
+- -4.405435085296631
+- -4.367630481719971
+- -4.419083595275879
+- -4.389026165008545
+- -4.371067047119141
+- -4.370710372924805
+- -4.3755269050598145
+- -4.39500093460083
+- -4.451773166656494
+- -4.365351676940918
+- -4.348028182983398
+- -4.408270359039307
+- -4.390385627746582
+- -4.347931861877441
+- -4.378237247467041
+- -4.426717758178711
+- -4.364233493804932
+- -4.371546745300293
+- -4.402477264404297
+- -4.430750846862793
+- -4.404538154602051
+- -4.384459018707275
+- -4.401677131652832
+spk_cond_steps: []
+stop_token_weight: 5.0
+task_cls: training.task.SVC_task.SVCTask
+test_ids: []
+test_input_dir: ''
+test_num: 0
+test_prefixes:
+- test
+test_set_name: test
+timesteps: 1000
+train_set_name: train
+use_crepe: false
+use_denoise: false
+use_energy_embed: false
+use_gt_dur: false
+use_gt_f0: false
+use_midi: false
+use_nsf: true
+use_pitch_embed: true
+use_pos_embed: true
+use_spk_embed: false
+use_spk_id: false
+use_split_spk_id: false
+use_uv: false
+use_var_enc: false
+use_vec: false
+val_check_interval: 1000
+valid_num: 0
+valid_set_name: valid
+vocoder: network.vocoders.nsf_hifigan.NsfHifiGAN
+vocoder_ckpt: checkpoints/nsf_hifigan/model
+warmup_updates: 2000
+wav2spec_eps: 1e-6
+weight_decay: 0
+win_size: 2048
+work_dir: checkpoints/HokoHifi
diff --git a/checkpoints/Unnamed/lightning_logs/lastest/hparams.yaml b/checkpoints/Unnamed/lightning_logs/lastest/hparams.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0967ef424bce6791893e9a57bb952f80fd536e93
--- /dev/null
+++ b/checkpoints/Unnamed/lightning_logs/lastest/hparams.yaml
@@ -0,0 +1 @@
+{}
diff --git a/checkpoints/Unnamed/model_ckpt_steps_192000.ckpt b/checkpoints/Unnamed/model_ckpt_steps_192000.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..c758264b63b667b86aa648c71426f07993c9d55e
--- /dev/null
+++ b/checkpoints/Unnamed/model_ckpt_steps_192000.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c441462923580893a6170dd00126084be0a20b387b1c4fb1860755acd36c881b
+size 391390823
diff --git a/checkpoints/hubert/hubert_soft.pt b/checkpoints/hubert/hubert_soft.pt
new file mode 100644
index 0000000000000000000000000000000000000000..5ccd36b11dc124c97a0b73fa5f39eed8d1a6f27a
--- /dev/null
+++ b/checkpoints/hubert/hubert_soft.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e82e7d079df05fe3aa535f6f7d42d309bdae1d2a53324e2b2386c56721f4f649
+size 378435957
diff --git a/checkpoints/nsf_hifigan/NOTICE.txt b/checkpoints/nsf_hifigan/NOTICE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..228fc663c20c3166dc16dca0b3b94dca38a489b8
--- /dev/null
+++ b/checkpoints/nsf_hifigan/NOTICE.txt
@@ -0,0 +1,74 @@
+--- DiffSinger Community Vocoder ---
+
+ARCHITECTURE: NSF-HiFiGAN
+RELEASE DATE: 2022-12-11
+
+HYPER PARAMETERS:
+ - 44100 sample rate
+ - 128 mel bins
+ - 512 hop size
+ - 2048 window size
+ - fmin at 40Hz
+ - fmax at 16000Hz
+
+
+NOTICE:
+
+All model weights in the [DiffSinger Community Vocoder Project](https://openvpi.github.io/vocoders/), including
+model weights in this directory, are provided by the [OpenVPI Team](https://github.com/openvpi/), under the
+[Attribution-NonCommercial-ShareAlike 4.0 International](https://creativecommons.org/licenses/by-nc-sa/4.0/) license.
+
+
+ACKNOWLEDGEMENTS:
+
+Training data of this vocoder is provided and permitted by the following organizations, societies and individuals:
+
+孙飒 https://www.qfssr.cn
+赤松_Akamatsu https://www.zhibin.club
+乐威 https://www.zhibin.club
+伯添 https://space.bilibili.com/24087011
+雲宇光 https://space.bilibili.com/660675050
+橙子言 https://space.bilibili.com/318486464
+人衣大人 https://space.bilibili.com/2270344
+玖蝶 https://space.bilibili.com/676771003
+Yuuko
+白夜零BYL https://space.bilibili.com/1605040503
+嗷天 https://space.bilibili.com/5675252
+洛泠羽 https://space.bilibili.com/347373318
+灰条纹的灰猫君 https://space.bilibili.com/2083633
+幽寂 https://space.bilibili.com/478860
+恶魔王女 https://space.bilibili.com/2475098
+AlexYHX 芮晴
+绮萱 https://y.qq.com/n/ryqq/singer/003HjD6H4aZn1K
+诗芸 https://y.qq.com/n/ryqq/singer/0005NInj142zm0
+汐蕾 https://y.qq.com/n/ryqq/singer/0023cWMH1Bq1PJ
+1262917464
+炜阳
+叶卡yolka
+幸の夏 https://space.bilibili.com/1017297686
+暮色未量 https://space.bilibili.com/272904686
+晓寞sama https://space.bilibili.com/3463394
+没头绪的节操君
+串串BunC https://space.bilibili.com/95817834
+落雨 https://space.bilibili.com/1292427
+长尾巴的翎艾 https://space.bilibili.com/1638666
+声闻计划 https://space.bilibili.com/392812269
+唐家大小姐 http://5sing.kugou.com/palmusic/default.html
+不伊子
+
+Training machines are provided by:
+
+花儿不哭 https://space.bilibili.com/5760446
+
+
+TERMS OF REDISTRIBUTIONS:
+
+1. Do not sell this vocoder, or charge any fees from redistributing it, as prohibited by
+ the license.
+2. Include a copy of the CC BY-NC-SA 4.0 license, or a link referring to it.
+3. Include a copy of this notice, or any other notices informing that this vocoder is
+ provided by the OpenVPI Team, that this vocoder is licensed under CC BY-NC-SA 4.0, and
+ with a complete acknowledgement list as shown above.
+4. If you fine-tuned or modified the weights, leave a notice about what has been changed.
+5. (Optional) Leave a link to the official release page of the vocoder, and tell users
+ that other versions and future updates of this vocoder can be obtained from the website.
diff --git a/checkpoints/nsf_hifigan/NOTICE.zh-CN.txt b/checkpoints/nsf_hifigan/NOTICE.zh-CN.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b206a0bd1d3b80feb66c52a7452856021d06805a
--- /dev/null
+++ b/checkpoints/nsf_hifigan/NOTICE.zh-CN.txt
@@ -0,0 +1,72 @@
+--- DiffSinger 社区声码器 ---
+
+架构:NSF-HiFiGAN
+发布日期:2022-12-11
+
+超参数:
+ - 44100 sample rate
+ - 128 mel bins
+ - 512 hop size
+ - 2048 window size
+ - fmin at 40Hz
+ - fmax at 16000Hz
+
+
+注意事项:
+
+[DiffSinger 社区声码器企划](https://openvpi.github.io/vocoders/) 中的所有模型权重,
+包括此目录下的模型权重,均由 [OpenVPI Team](https://github.com/openvpi/) 提供,并基于
+[Attribution-NonCommercial-ShareAlike 4.0 International](https://creativecommons.org/licenses/by-nc-sa/4.0/)
+进行许可。
+
+
+致谢:
+
+此声码器的训练数据由以下组织、社团和个人提供并许可:
+
+孙飒 https://www.qfssr.cn
+赤松_Akamatsu https://www.zhibin.club
+乐威 https://www.zhibin.club
+伯添 https://space.bilibili.com/24087011
+雲宇光 https://space.bilibili.com/660675050
+橙子言 https://space.bilibili.com/318486464
+人衣大人 https://space.bilibili.com/2270344
+玖蝶 https://space.bilibili.com/676771003
+Yuuko
+白夜零BYL https://space.bilibili.com/1605040503
+嗷天 https://space.bilibili.com/5675252
+洛泠羽 https://space.bilibili.com/347373318
+灰条纹的灰猫君 https://space.bilibili.com/2083633
+幽寂 https://space.bilibili.com/478860
+恶魔王女 https://space.bilibili.com/2475098
+AlexYHX 芮晴
+绮萱 https://y.qq.com/n/ryqq/singer/003HjD6H4aZn1K
+诗芸 https://y.qq.com/n/ryqq/singer/0005NInj142zm0
+汐蕾 https://y.qq.com/n/ryqq/singer/0023cWMH1Bq1PJ
+1262917464
+炜阳
+叶卡yolka
+幸の夏 https://space.bilibili.com/1017297686
+暮色未量 https://space.bilibili.com/272904686
+晓寞sama https://space.bilibili.com/3463394
+没头绪的节操君
+串串BunC https://space.bilibili.com/95817834
+落雨 https://space.bilibili.com/1292427
+长尾巴的翎艾 https://space.bilibili.com/1638666
+声闻计划 https://space.bilibili.com/392812269
+唐家大小姐 http://5sing.kugou.com/palmusic/default.html
+不伊子
+
+训练算力的提供者如下:
+
+花儿不哭 https://space.bilibili.com/5760446
+
+
+二次分发条款:
+
+1. 请勿售卖此声码器或从其二次分发过程中收取任何费用,因为此类行为受到许可证的禁止。
+2. 请在二次分发文件中包含一份 CC BY-NC-SA 4.0 许可证的副本或指向该许可证的链接。
+3. 请在二次分发文件中包含这份声明,或以其他形式声明此声码器由 OpenVPI Team 提供并基于 CC BY-NC-SA 4.0 许可,
+ 并附带上述完整的致谢名单。
+4. 如果您微调或修改了权重,请留下一份关于其受到了何种修改的说明。
+5.(可选)留下一份指向此声码器的官方发布页面的链接,并告知使用者可从该网站获取此声码器的其他版本和未来的更新。
diff --git a/checkpoints/nsf_hifigan/config.json b/checkpoints/nsf_hifigan/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..79821fb218253a51b8bcbaa2eaed539a79c78d32
--- /dev/null
+++ b/checkpoints/nsf_hifigan/config.json
@@ -0,0 +1,38 @@
+{
+ "resblock": "1",
+ "num_gpus": 4,
+ "batch_size": 10,
+ "learning_rate": 0.0002,
+ "adam_b1": 0.8,
+ "adam_b2": 0.99,
+ "lr_decay": 0.999,
+ "seed": 1234,
+
+ "upsample_rates": [ 8, 8, 2, 2, 2],
+ "upsample_kernel_sizes": [16,16, 4, 4, 4],
+ "upsample_initial_channel": 512,
+ "resblock_kernel_sizes": [3,7,11],
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+ "discriminator_periods": [3, 5, 7, 11, 17, 23, 37],
+
+ "segment_size": 16384,
+ "num_mels": 128,
+ "num_freq": 1025,
+ "n_fft" : 2048,
+ "hop_size": 512,
+ "win_size": 2048,
+
+ "sampling_rate": 44100,
+
+ "fmin": 40,
+ "fmax": 16000,
+ "fmax_for_loss": null,
+
+ "num_workers": 16,
+
+ "dist_config": {
+ "dist_backend": "nccl",
+ "dist_url": "tcp://localhost:54321",
+ "world_size": 1
+ }
+}
diff --git a/ckpt.jpg b/ckpt.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..589c02413686da308443a1b03b5f19e1c13d6d47
Binary files /dev/null and b/ckpt.jpg differ
diff --git a/config.yaml b/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d49f0ecb2e118cc255d97e6077eb8e62046f4a05
--- /dev/null
+++ b/config.yaml
@@ -0,0 +1,349 @@
+K_step: 1000
+accumulate_grad_batches: 1
+audio_num_mel_bins: 80
+audio_sample_rate: 24000
+binarization_args:
+ shuffle: false
+ with_align: true
+ with_f0: true
+ with_hubert: true
+ with_spk_embed: false
+ with_wav: false
+binarizer_cls: preprocessing.SVCpre.SVCBinarizer
+binary_data_dir: data/binary/atri
+check_val_every_n_epoch: 10
+choose_test_manually: false
+clip_grad_norm: 1
+config_path: training/config.yaml
+content_cond_steps: []
+cwt_add_f0_loss: false
+cwt_hidden_size: 128
+cwt_layers: 2
+cwt_loss: l1
+cwt_std_scale: 0.8
+datasets:
+- opencpop
+debug: false
+dec_ffn_kernel_size: 9
+dec_layers: 4
+decay_steps: 30000
+decoder_type: fft
+dict_dir: ''
+diff_decoder_type: wavenet
+diff_loss_type: l2
+dilation_cycle_length: 4
+dropout: 0.1
+ds_workers: 4
+dur_enc_hidden_stride_kernel:
+- 0,2,3
+- 0,2,3
+- 0,1,3
+dur_loss: mse
+dur_predictor_kernel: 3
+dur_predictor_layers: 5
+enc_ffn_kernel_size: 9
+enc_layers: 4
+encoder_K: 8
+encoder_type: fft
+endless_ds: False
+f0_bin: 256
+f0_max: 1100.0
+f0_min: 50.0
+ffn_act: gelu
+ffn_padding: SAME
+fft_size: 512
+fmax: 12000
+fmin: 30
+fs2_ckpt: ''
+gaussian_start: true
+gen_dir_name: ''
+gen_tgt_spk_id: -1
+hidden_size: 256
+hop_size: 128
+hubert_gpu: true
+hubert_path: checkpoints/hubert/hubert_soft.pt
+infer: false
+keep_bins: 80
+lambda_commit: 0.25
+lambda_energy: 0.0
+lambda_f0: 1.0
+lambda_ph_dur: 0.3
+lambda_sent_dur: 1.0
+lambda_uv: 1.0
+lambda_word_dur: 1.0
+load_ckpt: ''
+log_interval: 100
+loud_norm: false
+lr: 5.0e-05
+max_beta: 0.02
+max_epochs: 3000
+max_eval_sentences: 1
+max_eval_tokens: 60000
+max_frames: 42000
+max_input_tokens: 60000
+max_sentences: 24
+max_tokens: 128000
+max_updates: 1000000
+mel_loss: ssim:0.5|l1:0.5
+mel_vmax: 1.5
+mel_vmin: -6.0
+min_level_db: -120
+norm_type: gn
+num_ckpt_keep: 10
+num_heads: 2
+num_sanity_val_steps: 1
+num_spk: 1
+num_test_samples: 0
+num_valid_plots: 10
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+out_wav_norm: false
+pe_ckpt: checkpoints/0102_xiaoma_pe/model_ckpt_steps_60000.ckpt
+pe_enable: false
+perform_enhance: true
+pitch_ar: false
+pitch_enc_hidden_stride_kernel:
+- 0,2,5
+- 0,2,5
+- 0,2,5
+pitch_extractor: parselmouth
+pitch_loss: l2
+pitch_norm: log
+pitch_type: frame
+pndm_speedup: 10
+pre_align_args:
+ allow_no_txt: false
+ denoise: false
+ forced_align: mfa
+ txt_processor: zh_g2pM
+ use_sox: true
+ use_tone: false
+pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
+predictor_dropout: 0.5
+predictor_grad: 0.1
+predictor_hidden: -1
+predictor_kernel: 5
+predictor_layers: 5
+prenet_dropout: 0.5
+prenet_hidden_size: 256
+pretrain_fs_ckpt: pretrain/nyaru/model_ckpt_steps_60000.ckpt
+processed_data_dir: xxx
+profile_infer: false
+raw_data_dir: data/raw/atri
+ref_norm_layer: bn
+rel_pos: true
+reset_phone_dict: true
+residual_channels: 256
+residual_layers: 20
+save_best: false
+save_ckpt: true
+save_codes:
+- configs
+- modules
+- src
+- utils
+save_f0: true
+save_gt: false
+schedule_type: linear
+seed: 1234
+sort_by_len: true
+speaker_id: atri
+spec_max:
+- 0.2987259328365326
+- 0.29721200466156006
+- 0.23978209495544434
+- 0.208412766456604
+- 0.25777050852775574
+- 0.2514476478099823
+- 0.1129382848739624
+- 0.03415697440505028
+- 0.09860049188137054
+- 0.10637332499027252
+- 0.13287633657455444
+- 0.19744250178337097
+- 0.10040587931871414
+- 0.13735432922840118
+- 0.15107455849647522
+- 0.17196381092071533
+- 0.08298977464437485
+- 0.0632769986987114
+- 0.02723858878016472
+- -0.001819317927584052
+- -0.029565516859292984
+- -0.023574354127049446
+- -0.01633293740451336
+- 0.07143621146678925
+- 0.021580500528216362
+- 0.07257916033267975
+- -0.024349519982933998
+- -0.06165708228945732
+- -0.10486568510532379
+- -0.1363687664270401
+- -0.13333871960639954
+- -0.13955898582935333
+- -0.16613495349884033
+- -0.17636367678642273
+- -0.2786925733089447
+- -0.22967253625392914
+- -0.31897130608558655
+- -0.18007366359233856
+- -0.29366692900657654
+- -0.2871025800704956
+- -0.36748355627059937
+- -0.46071451902389526
+- -0.5464922189712524
+- -0.5719417333602905
+- -0.6020897626876831
+- -0.6239874958992004
+- -0.5653440952301025
+- -0.6508013606071472
+- -0.628247857093811
+- -0.6809687614440918
+- -0.569259762763977
+- -0.5423558354377747
+- -0.5811785459518433
+- -0.5359002351760864
+- -0.6565515398979187
+- -0.7143737077713013
+- -0.8502675890922546
+- -0.7979224920272827
+- -0.7110578417778015
+- -0.763409435749054
+- -0.7984790802001953
+- -0.6927220821380615
+- -0.658117413520813
+- -0.7486468553543091
+- -0.5949879884719849
+- -0.7494576573371887
+- -0.7400822639465332
+- -0.6822793483734131
+- -0.7773582339286804
+- -0.661201536655426
+- -0.791329026222229
+- -0.8982341885566711
+- -0.8736728429794312
+- -0.7701027393341064
+- -0.8490535616874695
+- -0.7479292154312134
+- -0.9320166110992432
+- -1.2862414121627808
+- -2.8936190605163574
+- -2.924229860305786
+spec_min:
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -5.999454021453857
+- -5.8822431564331055
+- -5.892064571380615
+- -5.882402420043945
+- -5.786972522735596
+- -5.746835231781006
+- -5.8594512939453125
+- -5.7389445304870605
+- -5.718059539794922
+- -5.779720306396484
+- -5.801984786987305
+- -6.0
+- -6.0
+spk_cond_steps: []
+stop_token_weight: 5.0
+task_cls: training.task.SVC_task.SVCTask
+test_ids: []
+test_input_dir: ''
+test_num: 0
+test_prefixes:
+- test
+test_set_name: test
+timesteps: 1000
+train_set_name: train
+use_crepe: true
+use_denoise: false
+use_energy_embed: false
+use_gt_dur: false
+use_gt_f0: false
+use_midi: false
+use_nsf: true
+use_pitch_embed: true
+use_pos_embed: true
+use_spk_embed: false
+use_spk_id: false
+use_split_spk_id: false
+use_uv: false
+use_var_enc: false
+use_vec: false
+val_check_interval: 2000
+valid_num: 0
+valid_set_name: valid
+vocoder: network.vocoders.hifigan.HifiGAN
+vocoder_ckpt: checkpoints/0109_hifigan_bigpopcs_hop128
+warmup_updates: 2000
+wav2spec_eps: 1e-6
+weight_decay: 0
+win_size: 512
+work_dir: checkpoints/atri
+no_fs2: false
\ No newline at end of file
diff --git a/doc/train_and_inference.markdown b/doc/train_and_inference.markdown
new file mode 100644
index 0000000000000000000000000000000000000000..eed0d6e9ad470db1d3fa8d8410d5c51fb026a19f
--- /dev/null
+++ b/doc/train_and_inference.markdown
@@ -0,0 +1,210 @@
+# Diff-SVC(train/inference by yourself)
+## 0.环境配置
+>注意:requirements文件已更新,目前分为3个版本,可自行选择使用。\
+1. requirements.txt 是此仓库测试的原始完整环境,Torch1.12.1+cu113,可选择直接pip 或删除其中与pytorch有关的项目(torch/torchvision)后再pip,并使用自己的torch环境
+```
+pip install -r requirements.txt
+```
+>2. (推荐)requirements_short.txt 是上述环境的手动整理版,不含torch本体,也可以直接
+```
+pip install -r requirements_short.txt
+```
+>3. 根目录下有一份@三千整理的依赖列表requirements.png,是在某品牌云服务器上跑通的,不过此torch版本已不兼容目前版本代码,但是其他部分版本可以参考,十分感谢
+
+## 1.推理
+>使用根目录下的inference.ipynb进行推理或使用经过作者适配的@小狼的infer.py\
+在第一个block中修改如下参数:
+```
+config_path='checkpoints压缩包中config.yaml的位置'
+如'./checkpoints/nyaru/config.yaml'
+config和checkpoints是一一对应的,请不要使用其他config
+
+project_name='这个项目的名称'
+如'nyaru'
+
+model_path='ckpt文件的全路径'
+如'./checkpoints/nyaru/model_ckpt_steps_112000.ckpt'
+
+hubert_gpu=True
+推理时是否使用gpu推理hubert(模型中的一个模块),不影响模型的其他部分
+目前版本已大幅减小hubert的gpu占用,在1060 6G显存下可完整推理,不需要关闭了。
+另外现已支持长音频自动切片功能(ipynb和infer.py均可),超过30s的音频将自动在静音处切片处理,感谢@小狼的代码
+
+```
+### 可调节参数:
+```
+wav_fn='xxx.wav'#传入音频的路径,默认在项目根目录中
+
+use_crepe=True
+#crepe是一个F0算法,效果好但速度慢,改成False会使用效果稍逊于crepe但较快的parselmouth算法
+
+thre=0.05
+#crepe的噪声过滤阈值,源音频干净可适当调大,噪音多就保持这个数值或者调小,前面改成False后这个参数不起作用
+
+pndm_speedup=20
+#推理加速算法倍数,默认是1000步,这里填成10就是只使用100步合成,是一个中规中矩的数值,这个数值可以高到50倍(20步合成)没有明显质量损失,再大可能会有可观的质量损失,注意如果下方开启了use_gt_mel, 应保证这个数值小于add_noise_step,并尽量让其能够整除
+
+key=0
+#变调参数,默认为0(不是1!!),将源音频的音高升高key个半音后合成,如男声转女声,可填入8或者12等(12就是升高一整个8度)
+
+use_pe=True
+#梅尔谱合成音频时使用的F0提取算法,如果改成False将使用源音频的F0\
+这里填True和False合成会略有差异,通常是True会好些,但也不尽然,对合成速度几乎无影响\
+(无论key填什么 这里都是可以自由选择的,不影响)\
+44.1kHz下不支持此功能,会自动关闭,开着也不报错就是了
+
+use_gt_mel=False
+#这个选项类似于AI画图的图生图功能,如果打开,产生的音频将是输入声音与目标说话人声音的混合,混合比例由下一个参数确定
+注意!!!:这个参数如果改成True,请确保key填成0,不支持变调
+
+add_noise_step=500
+#与上个参数有关,控制两种声音的比例,填入1是完全的源声线,填入1000是完全的目标声线,能听出来是两者均等混合的数值大约在300附近(并不是线性的,另外这个参数如果调的很小,可以把pndm加速倍率调低,增加合成质量)
+
+wav_gen='yyy.wav'#输出音频的路径,默认在项目根目录中,可通过改变扩展名更改保存文件类型
+```
+如果使用infer.py,修改方式类似,需要修改__name__=='__main__'中的部分,然后在根目录中执行\
+python infer.py\
+这种方式需要将原音频放入raw中并在results中查找结果
+## 2.数据预处理与训练
+### 2.1 准备数据
+>目前支持wav格式和ogg格式的音频数据,采样率最好高于24kHz,程序会自动处理采样率和声道问题。采样率不可低于16kHz(一般不会的)\
+音频需要切片为5-15s为宜的短音频,长度没有具体要求,但不宜过长过短。音频需要为纯目标人干声,不可以有背景音乐和其他人声音,最好也不要有过重的混响等。若经过去伴奏等处理,请尽量保证处理后的音频质量。\
+目前仅支持单人训练,总时长尽量保证在3h或以上,不需要额外任何标注,将音频文件放在下述raw_data_dir下即可,这个目录下的结构可以自由定义,程序会自主找到所需文件。
+
+### 2.2 修改超参数配置
+>首先请备份一份config.yaml(此文件对应24kHz声码器, 44.1kHz声码器请使用config_nsf.yaml),然后修改它\
+可能会用到的参数如下(以工程名为nyaru为例):
+```
+K_step: 1000
+#diffusion过程总的step,建议不要修改
+
+binary_data_dir: data/binary/nyaru
+预处理后数据的存放地址:需要将后缀改成工程名字
+
+config_path: training/config.yaml
+你要使用的这份yaml自身的地址,由于预处理过程中会写入数据,所以这个地址务必修改成将要存放这份yaml文件的完整路径
+
+choose_test_manually: false
+手动选择测试集,默认关闭,自动随机抽取5条音频作为测试集。
+如果改为true,请在test_prefixes:中填入测试数据的文件名前缀,程序会将以对应前缀开头的文件作为测试集
+这是个列表,可以填多个前缀,如:
+test_prefixes:
+- test
+- aaaa
+- 5012
+- speaker1024
+重要:测试集*不可以*为空,为了不产生意外影响,建议尽量不要手动选择测试集
+
+endless_ds: False
+如果你的数据集过小,每个epoch时间很短,请将此项打开,将把正常的1000epoch作为一个epoch计算
+
+hubert_path: checkpoints/hubert/hubert.pt
+hubert模型的存放地址,确保这个路径是对的,一般解压checkpoints包之后就是这个路径不需要改,现已使用torch版本推理
+hubert_gpu: True
+是否在预处理时使用gpu运行hubert(模型的一个模块),关闭后使用cpu,但耗时会显著增加。另外模型训练完推理时hubert是否用gpu是在inference中单独控制的,不受此处影响。目前hubert改为torch版后已经可以做到在1060 6G显存gpu上进行预处理,与直接推理1分钟内的音频不超出显存限制,一般不需要关了。
+
+lr: 0.0008
+#初始的学习率:这个数字对应于88的batchsize,如果batchsize更小,可以调低这个数值一些
+
+decay_steps: 20000
+每20000步学习率衰减为原来的一半,如果batchsize比较小,请调大这个数值
+
+#对于30-40左右的batchsize,推荐lr=0.0004,decay_steps=40000
+
+max_frames: 42000
+max_input_tokens: 6000
+max_sentences: 88
+max_tokens: 128000
+#batchsize是由这几个参数动态算出来的,如果不太清楚具体含义,可以只改动max_sentences这个参数,填入batchsize的最大限制值,以免炸显存
+
+pe_ckpt: checkpoints/0102_xiaoma_pe/model_ckpt_steps_60000.ckpt
+#pe模型路径,确保这个文件存在,具体作用参考inference部分
+
+raw_data_dir: data/raw/nyaru
+#存放预处理前原始数据的位置,请将原始wav数据放在这个目录下,内部文件结构无所谓,会自动解构
+
+residual_channels: 384
+residual_layers: 20
+#控制核心网络规模的一组参数,越大参数越多炼的越慢,但效果不一定会变好,大一点的数据集可以把第一个改成512。这个可以自行实验效果,不过不了解的话尽量不动。
+
+speaker_id: nyaru
+#训练的说话人名字,目前只支持单说话人,请在这里填写(只是观赏作用,没有实际意义的参数)
+
+use_crepe: true
+#在数据预处理中使用crepe提取F0,追求效果请打开,追求速度可以关闭
+
+val_check_interval: 2000
+#每2000steps推理测试集并保存ckpt
+
+vocoder_ckpt: checkpoints/0109_hifigan_bigpopcs_hop128
+#24kHz下为对应声码器的目录, 44.1kHz下为对应声码器的文件名, 注意不要填错
+
+work_dir: checkpoints/nyaru
+#修改后缀为工程名(也可以删掉或完全留空自动生成,但别乱填)
+no_fs2: true
+#对网络encoder的精简,能缩减模型体积,加快训练,且并未发现有对网络表现损害的直接证据。默认打开
+
+```
+>其他的参数如果你不知道它是做什么的,请不要修改,即使你看着名称可能以为你知道它是做什么的。
+
+### 2.3 数据预处理
+在diff-svc的目录下执行以下命令:\
+#windows
+```
+set PYTHONPATH=.
+set CUDA_VISIBLE_DEVICES=0
+python preprocessing/binarize.py --config training/config.yaml
+```
+#linux
+```
+export PYTHONPATH=.
+CUDA_VISIBLE_DEVICES=0 python preprocessing/binarize.py --config training/config.yaml
+```
+对于预处理,@小狼准备了一份可以分段处理hubert和其他特征的代码,如果正常处理显存不足,可以先python ./network/hubert/hubert_model.py
+然后再运行正常的指令,能够识别提前处理好的hubert特征
+### 2.4 训练
+#windows
+```
+set CUDA_VISIBLE_DEVICES=0
+python run.py --config training/config.yaml --exp_name nyaru --reset
+```
+#linux
+```
+CUDA_VISIBLE_DEVICES=0 python run.py --config training/config.yaml --exp_name nyaru --reset
+```
+>需要将exp_name改为你的工程名,并修改config路径,请确保和预处理使用的是同一个config文件\
+*重要* :训练完成后,若之前不是在本地数据预处理,除了需要下载对应的ckpt文件,也需要将config文件下载下来,作为推理时使用的config,不可以使用本地之前上传上去那份。因为预处理时会向config文件中写入内容。推理时要保持使用的config和预处理使用的config是同一份。
+
+
+### 2.5 可能出现的问题:
+>2.5.1 'Upsample' object has no attribute 'recompute_scale_factor'\
+此问题发现于cuda11.3对应的torch中,若出现此问题,请通过合适的方法(如ide自动跳转等)找到你的python依赖包中的torch.nn.modules.upsampling.py文件(如conda环境中为conda目录\envs\环境目录\Lib\site-packages\torch\nn\modules\upsampling.py),修改其153-154行
+```
+return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners,recompute_scale_factor=self.recompute_scale_factor)
+```
+>改为
+```
+return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
+# recompute_scale_factor=self.recompute_scale_factor)
+```
+>2.5.2 no module named 'utils'\
+请在你的运行环境(如colab笔记本)中以如下方式设置:
+```
+import os
+os.environ['PYTHONPATH']='.'
+!CUDA_VISIBLE_DEVICES=0 python preprocessing/binarize.py --config training/config.yaml
+```
+注意一定要在项目文件夹的根目录中执行
+>2.5.3 cannot load library 'libsndfile.so'\
+可能会在linux环境中遇到的错误,请执行以下指令
+```
+apt-get install libsndfile1 -y
+```
+>2.5.4 cannot import name 'consume_prefix_in_state_dict_if_present'\
+torch版本过低,请更换高版本torch
+
+>2.5.5 预处理数据过慢\
+检查是否在配置中开启了use_crepe,将其关闭可显著提升速度。\
+检查配置中hubert_gpu是否开启。
+
+如有其他问题,请加入QQ频道或discord频道询问。
diff --git a/flask_api.py b/flask_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..eaecd0b7305218ad61d0a5a23d44f3312b538f0b
--- /dev/null
+++ b/flask_api.py
@@ -0,0 +1,54 @@
+import io
+import logging
+
+import librosa
+import soundfile
+from flask import Flask, request, send_file
+from flask_cors import CORS
+
+from infer_tools.infer_tool import Svc
+from utils.hparams import hparams
+
+app = Flask(__name__)
+
+CORS(app)
+
+logging.getLogger('numba').setLevel(logging.WARNING)
+
+
+@app.route("/voiceChangeModel", methods=["POST"])
+def voice_change_model():
+ request_form = request.form
+ wave_file = request.files.get("sample", None)
+ # 变调信息
+ f_pitch_change = float(request_form.get("fPitchChange", 0))
+ # DAW所需的采样率
+ daw_sample = int(float(request_form.get("sampleRate", 0)))
+ speaker_id = int(float(request_form.get("sSpeakId", 0)))
+ # http获得wav文件并转换
+ input_wav_path = io.BytesIO(wave_file.read())
+ # 模型推理
+ _f0_tst, _f0_pred, _audio = model.infer(input_wav_path, key=f_pitch_change, acc=accelerate, use_pe=False,
+ use_crepe=False)
+ tar_audio = librosa.resample(_audio, hparams["audio_sample_rate"], daw_sample)
+ # 返回音频
+ out_wav_path = io.BytesIO()
+ soundfile.write(out_wav_path, tar_audio, daw_sample, format="wav")
+ out_wav_path.seek(0)
+ return send_file(out_wav_path, download_name="temp.wav", as_attachment=True)
+
+
+if __name__ == '__main__':
+ # 工程文件夹名,训练时用的那个
+ project_name = "firefox"
+ model_path = f'./checkpoints/{project_name}/model_ckpt_steps_188000.ckpt'
+ config_path = f'./checkpoints/{project_name}/config.yaml'
+
+ # 加速倍数
+ accelerate = 50
+ hubert_gpu = True
+
+ model = Svc(project_name, config_path, hubert_gpu, model_path)
+
+ # 此处与vst插件对应,不建议更改
+ app.run(port=6842, host="0.0.0.0", debug=False, threaded=False)
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..a671ed05af4248a13fcf9225ce21133fc766ae01
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,98 @@
+import io
+import time
+from pathlib import Path
+
+import librosa
+import numpy as np
+import soundfile
+
+from infer_tools import infer_tool
+from infer_tools import slicer
+from infer_tools.infer_tool import Svc
+from utils.hparams import hparams
+
+# Slice cache loaded from a JSON temp file at import time; run_clip looks entries up by the MD5 of
+# the loaded audio and writes the file back whenever it adds a new entry.
+chunks_dict = infer_tool.read_temp("./infer_tools/new_chunks_temp.json")
+
+
+def run_clip(svc_model, key, acc, use_pe, use_crepe, thre, use_gt_mel, add_noise_step, project_name='', f_name=None,
+ file_path=None, out_path=None, slice_db=-40,**kwargs):
+ print(f'code version:2022-12-04')
+ use_pe = use_pe if hparams['audio_sample_rate'] == 24000 else False
+ if file_path is None:
+ raw_audio_path = f"./raw/{f_name}"
+ clean_name = f_name[:-4]
+ else:
+ raw_audio_path = file_path
+ clean_name = str(Path(file_path).name)[:-4]
+ infer_tool.format_wav(raw_audio_path)
+ wav_path = Path(raw_audio_path).with_suffix('.wav')
+ global chunks_dict
+ audio, sr = librosa.load(wav_path, mono=True,sr=None)
+ wav_hash = infer_tool.get_md5(audio)
+ if wav_hash in chunks_dict.keys():
+ print("load chunks from temp")
+ chunks = chunks_dict[wav_hash]["chunks"]
+ else:
+ chunks = slicer.cut(wav_path, db_thresh=slice_db)
+ chunks_dict[wav_hash] = {"chunks": chunks, "time": int(time.time())}
+ infer_tool.write_temp("./infer_tools/new_chunks_temp.json", chunks_dict)
+ audio_data, audio_sr = slicer.chunks2audio(wav_path, chunks)
+
+ count = 0
+ f0_tst = []
+ f0_pred = []
+ audio = []
+ for (slice_tag, data) in audio_data:
+ print(f'#=====segment start, {round(len(data) / audio_sr, 3)}s======')
+ length = int(np.ceil(len(data) / audio_sr * hparams['audio_sample_rate']))
+ raw_path = io.BytesIO()
+ soundfile.write(raw_path, data, audio_sr, format="wav")
+ if hparams['debug']:
+ print(np.mean(data), np.var(data))
+ raw_path.seek(0)
+ if slice_tag:
+ print('jump empty segment')
+ _f0_tst, _f0_pred, _audio = (
+ np.zeros(int(np.ceil(length / hparams['hop_size']))), np.zeros(int(np.ceil(length / hparams['hop_size']))),
+ np.zeros(length))
+ else:
+ _f0_tst, _f0_pred, _audio = svc_model.infer(raw_path, key=key, acc=acc, use_pe=use_pe, use_crepe=use_crepe,
+ thre=thre, use_gt_mel=use_gt_mel, add_noise_step=add_noise_step)
+ fix_audio = np.zeros(length)
+ fix_audio[:] = np.mean(_audio)
+ fix_audio[:len(_audio)] = _audio[0 if len(_audio) 50 * 1024 * 1024:
+ f_name = file_name.split("/")[-1]
+ print(f"clean {f_name}")
+ for wav_hash in list(data_dict.keys()):
+ if int(time.time()) - int(data_dict[wav_hash]["time"]) > 14 * 24 * 3600:
+ del data_dict[wav_hash]
+ except Exception as e:
+ print(e)
+ print(f"{file_name} error,auto rebuild file")
+ data_dict = {"info": "temp_dict"}
+ return data_dict
+
+
+# F0 cache loaded from a JSON temp file at import time; Svc's pitch extraction keys its entries by
+# the MD5 of the waveform and rewrites the file whenever it adds a new entry.
+f0_dict = read_temp("./infer_tools/f0_temp.json")
+
+
+def write_temp(file_name, data):
+ with open(file_name, "w") as f:
+ f.write(json.dumps(data))
+
+
+def timeit(func):
+ def run(*args, **kwargs):
+ t = time.time()
+ res = func(*args, **kwargs)
+ print('executing \'%s\' costed %.3fs' % (func.__name__, time.time() - t))
+ return res
+
+ return run
+
+
+def format_wav(audio_path):
+ if Path(audio_path).suffix=='.wav':
+ return
+ raw_audio, raw_sample_rate = librosa.load(audio_path, mono=True,sr=None)
+ soundfile.write(Path(audio_path).with_suffix(".wav"), raw_audio, raw_sample_rate)
+
+
+def fill_a_to_b(a, b):
+ if len(a) < len(b):
+ for _ in range(0, len(b) - len(a)):
+ a.append(a[0])
+
+
+def get_end_file(dir_path, end):
+ file_lists = []
+ for root, dirs, files in os.walk(dir_path):
+ files = [f for f in files if f[0] != '.']
+ dirs[:] = [d for d in dirs if d[0] != '.']
+ for f_file in files:
+ if f_file.endswith(end):
+ file_lists.append(os.path.join(root, f_file).replace("\\", "/"))
+ return file_lists
+
+
+def mkdir(paths: list):
+ for path in paths:
+ if not os.path.exists(path):
+ os.mkdir(path)
+
+
+def get_md5(content):
+ return hashlib.new("md5", content).hexdigest()
+
+
+class Svc:
+ # Inference wrapper around a GaussianDiffusion model: __init__ builds the model, the HuBERT encoder,
+ # the pitch extractor and the vocoder from hparams; infer() converts one input wav and returns
+ # (f0_gt, f0_pred, wav_pred).
+ def __init__(self, project_name, config_name, hubert_gpu, model_path):
+ # project_name: passed to set_hparams as exp_name; pre() also uses it as item name / spk_id.
+ # config_name:  passed to set_hparams as the config to load.
+ # hubert_gpu:   written to hparams['hubert_gpu'] before self.hubert is created.
+ # model_path:   checkpoint that load_ckpt() loads into the diffusion model.
+ self.project_name = project_name
+ # Factories for the denoiser network; the entry is selected by hparams['diff_decoder_type'] below.
+ self.DIFF_DECODERS = {
+ 'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
+ 'fft': lambda hp: FFT(
+ hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']),
+ }
+
+ self.model_path = model_path
+ # Use CUDA when torch reports it as available, otherwise the CPU.
+ self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # presumably fills the module-level `hparams` that everything below reads — TODO confirm in set_hparams
+ self._ = set_hparams(config=config_name, exp_name=self.project_name, infer=True,
+ reset=True,
+ hparams_str='',
+ print_hparams=False)
+
+ self.mel_bins = hparams['audio_num_mel_bins']
+ self.model = GaussianDiffusion(
+ phone_encoder=Hubertencoder(hparams['hubert_path']),
+ out_dims=self.mel_bins, denoise_fn=self.DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
+ timesteps=hparams['timesteps'],
+ K_step=hparams['K_step'],
+ loss_type=hparams['diff_loss_type'],
+ spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
+ )
+ self.load_ckpt()
+ self.model.to(self.dev)
+ # NOTE(review): the Hubertencoder passed as phone_encoder above was built before this assignment;
+ # only the second instance (self.hubert) is created after hparams['hubert_gpu'] is set.
+ hparams['hubert_gpu'] = hubert_gpu
+ self.hubert = Hubertencoder(hparams['hubert_path'])
+ # Pitch extractor: weights come from hparams['pe_ckpt'], then it is switched to eval mode.
+ self.pe = PitchExtractor().to(self.dev)
+ utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
+ self.pe.eval()
+ self.vocoder = get_vocoder_cls(hparams)()
+
+ def load_ckpt(self, model_name='model', force=True, strict=True):
+ # Load the checkpoint at self.model_path into self.model via utils.load_ckpt.
+ utils.load_ckpt(self.model, self.model_path, model_name, force, strict)
+
+ def infer(self, in_path, key, acc, use_pe=True, use_crepe=True, thre=0.05, singer=False, **kwargs):
+ # Convert one input and return after_infer()'s (f0_gt, f0_pred, wav_pred).
+ # in_path: wav path or BytesIO, forwarded to pre(); key: pitch shift, applied as key / 12 to f0;
+ # acc: stored as hparams['pndm_speedup'] by pre(); use_pe: take f0_pred from the pitch extractor
+ # instead of the model output; use_crepe / thre: forwarded to pre(); singer: forwarded to
+ # after_infer(); **kwargs: forwarded to the model call.
+ batch = self.pre(in_path, acc, use_crepe, thre)
+ # NOTE(review): processed_input2batch in this file emits neither 'spk_embed' nor 'spk_ids', so
+ # this looks like it is always None — confirm.
+ spk_embed = batch.get('spk_embed') if not hparams['use_spk_id'] else batch.get('spk_ids')
+ hubert = batch['hubert']
+ ref_mels = batch["mels"]
+ energy=batch['energy']
+ mel2ph = batch['mel2ph']
+ # f0 appears to be on a log2 scale (it is compared against log2(f0_max) on the next line), which
+ # would make key / 12 a shift of `key` semitones — TODO confirm against norm_interp_f0.
+ batch['f0'] = batch['f0'] + (key / 12)
+ # Values pushed above log2(f0_max) are zeroed.
+ batch['f0'][batch['f0']>np.log2(hparams['f0_max'])]=0
+ f0 = batch['f0']
+ uv = batch['uv']
+ # The model call is wrapped in a local function only so that @timeit can report its duration.
+ @timeit
+ def diff_infer():
+ outputs = self.model(
+ hubert.to(self.dev), spk_embed=spk_embed, mel2ph=mel2ph.to(self.dev), f0=f0.to(self.dev), uv=uv.to(self.dev),energy=energy.to(self.dev),
+ ref_mels=ref_mels.to(self.dev),
+ infer=True, **kwargs)
+ return outputs
+ outputs=diff_infer()
+ batch['outputs'] = self.model.out2mel(outputs['mel_out'])
+ batch['mel2ph_pred'] = outputs['mel2ph']
+ batch['f0_gt'] = denorm_f0(batch['f0'], batch['uv'], hparams)
+ # f0_pred comes either from the pitch extractor run on the generated mel, or from the model's
+ # own 'f0_denorm' output (None when that key is absent).
+ if use_pe:
+ batch['f0_pred'] = self.pe(outputs['mel_out'])['f0_denorm_pred'].detach()
+ else:
+ batch['f0_pred'] = outputs.get('f0_denorm')
+ return self.after_infer(batch, singer, in_path)
+
+ @timeit
+ def after_infer(self, prediction, singer, in_path):
+ # Post-process a batch: move tensors to numpy, drop padded frames, run the vocoder.
+ # Returns (f0_gt, f0_pred, wav_pred).
+ # Tensor values of `prediction` are replaced in place by numpy arrays.
+ for k, v in prediction.items():
+ if type(v) is torch.Tensor:
+ prediction[k] = v.cpu().numpy()
+
+ # remove paddings
+ # A frame counts as non-padding when the absolute sum over its last axis is > 0.
+ mel_gt = prediction["mels"]
+ mel_gt_mask = np.abs(mel_gt).sum(-1) > 0
+
+ mel_pred = prediction["outputs"]
+ mel_pred_mask = np.abs(mel_pred).sum(-1) > 0
+ mel_pred = mel_pred[mel_pred_mask]
+ mel_pred = np.clip(mel_pred, hparams['mel_vmin'], hparams['mel_vmax'])
+
+ f0_gt = prediction.get("f0_gt")
+ f0_pred = prediction.get("f0_pred")
+ if f0_pred is not None:
+ f0_gt = f0_gt[mel_gt_mask]
+ # Trim f0_pred to the mask length before indexing with the mask.
+ if len(f0_pred) > len(mel_pred_mask):
+ f0_pred = f0_pred[:len(mel_pred_mask)]
+ f0_pred = f0_pred[mel_pred_mask]
+ # Short-circuit expression used as a statement: empties the CUDA cache only when CUDA is available.
+ torch.cuda.is_available() and torch.cuda.empty_cache()
+
+ # NOTE(review): this branch calls str methods on in_path, so it cannot work when in_path is a BytesIO.
+ if singer:
+ data_path = in_path.replace("batch", "singer_data")
+ mel_path = data_path[:-4] + "_mel.npy"
+ f0_path = data_path[:-4] + "_f0.npy"
+ np.save(mel_path, mel_pred)
+ np.save(f0_path, f0_pred)
+ wav_pred = self.vocoder.spec2wav(mel_pred, f0=f0_pred)
+ return f0_gt, f0_pred, wav_pred
+
+ def temporary_dict2processed_input(self, item_name, temp_dict, use_crepe=True, thre=0.05):
+ '''
+ process data in temporary_dicts
+
+ Builds `processed_input` from temp_dict['wav_fn']: mel, duration and length, plus f0/pitch,
+ hubert features and mel2ph alignment when the matching hparams['binarization_args'] flags
+ are set. Returns the merged dict.
+ '''
+
+ binarization_args = hparams['binarization_args']
+
+ @timeit
+ def get_pitch(wav, mel):
+ # get ground truth f0 by self.get_pitch_algorithm
+ # Results are cached in the module-level f0_dict under keys derived from the MD5 of `wav`,
+ # and the cache file is rewritten after every miss. Crepe and the other extractor use
+ # separate key suffixes, so their cached values never mix.
+ global f0_dict
+ if use_crepe:
+ md5 = get_md5(wav)
+ if f"{md5}_gt" in f0_dict.keys():
+ print("load temp crepe f0")
+ gt_f0 = np.array(f0_dict[f"{md5}_gt"]["f0"])
+ coarse_f0 = np.array(f0_dict[f"{md5}_coarse"]["f0"])
+ else:
+ torch.cuda.is_available() and torch.cuda.empty_cache()
+ gt_f0, coarse_f0 = get_pitch_crepe(wav, mel, hparams, thre)
+ f0_dict[f"{md5}_gt"] = {"f0": gt_f0.tolist(), "time": int(time.time())}
+ f0_dict[f"{md5}_coarse"] = {"f0": coarse_f0.tolist(), "time": int(time.time())}
+ write_temp("./infer_tools/f0_temp.json", f0_dict)
+ else:
+ md5 = get_md5(wav)
+ if f"{md5}_gt_harvest" in f0_dict.keys():
+ print("load temp harvest f0")
+ gt_f0 = np.array(f0_dict[f"{md5}_gt_harvest"]["f0"])
+ coarse_f0 = np.array(f0_dict[f"{md5}_coarse_harvest"]["f0"])
+ else:
+ gt_f0, coarse_f0 = get_pitch_world(wav, mel, hparams)
+ f0_dict[f"{md5}_gt_harvest"] = {"f0": gt_f0.tolist(), "time": int(time.time())}
+ f0_dict[f"{md5}_coarse_harvest"] = {"f0": coarse_f0.tolist(), "time": int(time.time())}
+ write_temp("./infer_tools/f0_temp.json", f0_dict)
+ processed_input['f0'] = gt_f0
+ processed_input['pitch'] = coarse_f0
+
+ def get_align(mel, phone_encoded):
+ # Spread the encoder frames evenly over the mel frames: ph_durs is the (float) number of mel
+ # frames per encoder frame, and encoder frame i_ph labels its span with the 1-based id i_ph + 1.
+ mel2ph = np.zeros([mel.shape[0]], int)
+ start_frame = 0
+ ph_durs = mel.shape[0] / phone_encoded.shape[0]
+ if hparams['debug']:
+ print(mel.shape, phone_encoded.shape, mel.shape[0] / phone_encoded.shape[0])
+ for i_ph in range(phone_encoded.shape[0]):
+ end_frame = int(i_ph * ph_durs + ph_durs + 0.5)
+ mel2ph[start_frame:end_frame + 1] = i_ph + 1
+ start_frame = end_frame + 1
+
+ processed_input['mel2ph'] = mel2ph
+
+ # hparams['vocoder'] may be a registered name or a dotted path; for the latter only the last
+ # component is used as the VOCODERS key.
+ if hparams['vocoder'] in VOCODERS:
+ wav, mel = VOCODERS[hparams['vocoder']].wav2spec(temp_dict['wav_fn'])
+ else:
+ wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(temp_dict['wav_fn'])
+ processed_input = {
+ 'item_name': item_name, 'mel': mel,
+ 'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0]
+ }
+ processed_input = {**temp_dict, **processed_input} # merge two dicts
+
+ if binarization_args['with_f0']:
+ get_pitch(wav, mel)
+ if binarization_args['with_hubert']:
+ st = time.time()
+ hubert_encoded = processed_input['hubert'] = self.hubert.encode(temp_dict['wav_fn'])
+ et = time.time()
+ dev = 'cuda' if hparams['hubert_gpu'] and torch.cuda.is_available() else 'cpu'
+ print(f'hubert (on {dev}) time used {et - st}')
+
+ # NOTE(review): hubert_encoded is only bound in the with_hubert branch above; with_align without
+ # with_hubert would raise NameError here.
+ if binarization_args['with_align']:
+ get_align(mel, hubert_encoded)
+ return processed_input
+
+ def pre(self, wav_fn, accelerate, use_crepe=True, thre=0.05):
+ # Turn one input (path string or BytesIO) into a single-item batch for the model.
+ # Side effect: stores `accelerate` in hparams['pndm_speedup'].
+ if isinstance(wav_fn, BytesIO):
+ item_name = self.project_name
+ else:
+ # NOTE(review): expects a '/'-separated str whose file name contains a '.'; a name without an
+ # extension raises IndexError on the [-2] below.
+ song_info = wav_fn.split('/')
+ item_name = song_info[-1].split('.')[-2]
+ temp_dict = {'wav_fn': wav_fn, 'spk_id': self.project_name}
+
+ temp_dict = self.temporary_dict2processed_input(item_name, temp_dict, use_crepe, thre)
+ hparams['pndm_speedup'] = accelerate
+ batch = processed_input2batch([getitem(temp_dict)])
+ return batch
+
+
+def getitem(item):
+ max_frames = hparams['max_frames']
+ spec = torch.Tensor(item['mel'])[:max_frames]
+ energy = (spec.exp() ** 2).sum(-1).sqrt()
+ mel2ph = torch.LongTensor(item['mel2ph'])[:max_frames] if 'mel2ph' in item else None
+ f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
+ hubert = torch.Tensor(item['hubert'][:hparams['max_input_tokens']])
+ pitch = torch.LongTensor(item.get("pitch"))[:max_frames]
+ sample = {
+ "item_name": item['item_name'],
+ "hubert": hubert,
+ "mel": spec,
+ "pitch": pitch,
+ "energy": energy,
+ "f0": f0,
+ "uv": uv,
+ "mel2ph": mel2ph,
+ "mel_nonpadding": spec.abs().sum(-1) > 0,
+ }
+ return sample
+
+
+def processed_input2batch(samples):
+ '''
+ Args:
+ samples: one batch of processed_input
+ NOTE:
+ the batch size is controlled by hparams['max_sentences']
+ '''
+ if len(samples) == 0:
+ return {}
+ item_names = [s['item_name'] for s in samples]
+ hubert = utils.collate_2d([s['hubert'] for s in samples], 0.0)
+ f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
+ pitch = utils.collate_1d([s['pitch'] for s in samples])
+ uv = utils.collate_1d([s['uv'] for s in samples])
+ energy = utils.collate_1d([s['energy'] for s in samples], 0.0)
+ mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \
+ if samples[0]['mel2ph'] is not None else None
+ mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
+ mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])
+
+ batch = {
+ 'item_name': item_names,
+ 'nsamples': len(samples),
+ 'hubert': hubert,
+ 'mels': mels,
+ 'mel_lengths': mel_lengths,
+ 'mel2ph': mel2ph,
+ 'energy': energy,
+ 'pitch': pitch,
+ 'f0': f0,
+ 'uv': uv,
+ }
+ return batch
diff --git a/infer_tools/new_chunks_temp.json b/infer_tools/new_chunks_temp.json
new file mode 100644
index 0000000000000000000000000000000000000000..f71eeddf2c8c881b7fa6895e75a1a3e25ab1043d
--- /dev/null
+++ b/infer_tools/new_chunks_temp.json
@@ -0,0 +1 @@
+{"info": "temp_dict", "accd68639783a1819e41702c4c1bf2e7": {"chunks": {"0": {"slice": false, "split_time": "0,607727"}}, "time": 1672781849}, "28b718f4ef116ca8c4d2279dfc0bd161": {"chunks": {"0": {"slice": false, "split_time": "0,607727"}}, "time": 1672758446}, "3c68f6ef87cdbea1be9b66e78bcd1c62": {"chunks": {"0": {"slice": true, "split_time": "0,20115"}, "1": {"slice": false, "split_time": "20115,152363"}, "2": {"slice": true, "split_time": "152363,163441"}, "3": {"slice": false, "split_time": "163441,347184"}, "4": {"slice": true, "split_time": "347184,351976"}, "5": {"slice": false, "split_time": "351976,438356"}, "6": {"slice": true, "split_time": "438356,499095"}}, "time": 1673478071}, "dd17f428601bccf6dd3c82c6d6daaaba": {"chunks": {"0": {"slice": true, "split_time": "0,24155"}, "1": {"slice": false, "split_time": "24155,323983"}, "2": {"slice": true, "split_time": "323983,352800"}}, "time": 1672758814}, "7d915edda42b3c65471bb0f86ba2a57c": {"chunks": {"0": {"slice": false, "split_time": "0,1417197"}, "1": {"slice": true, "split_time": "1417197,1426665"}, "2": {"slice": false, "split_time": "1426665,1736746"}, "3": {"slice": true, "split_time": "1736746,1743374"}, "4": {"slice": false, "split_time": "1743374,2042438"}, "5": {"slice": true, "split_time": "2042438,2050710"}, "6": {"slice": false, "split_time": "2050710,2508864"}, "7": {"slice": true, "split_time": "2508864,2515696"}, "8": {"slice": false, "split_time": "2515696,2682383"}}, "time": 1672772091}, "3bad51a34c1a940a31de387147e10c0a": {"chunks": {"0": {"slice": false, "split_time": "0,318420"}, "1": {"slice": true, "split_time": "318420,325186"}, "2": {"slice": false, "split_time": "325186,1581611"}, "3": {"slice": true, "split_time": "1581611,1594730"}}, "time": 1672772297}, "3ffa2635bed65fb96d8476803560fd1a": {"chunks": {"0": {"slice": true, "split_time": "0,668952"}, "1": {"slice": false, "split_time": "668952,894598"}, "2": {"slice": true, "split_time": "894598,917516"}, "3": {"slice": false, 
"split_time": "917516,1140044"}, "4": {"slice": true, "split_time": "1140044,1159951"}, "5": {"slice": false, "split_time": "1159951,1725379"}, "6": {"slice": true, "split_time": "1725379,1756813"}, "7": {"slice": false, "split_time": "1756813,3302878"}, "8": {"slice": true, "split_time": "3302878,3361575"}, "9": {"slice": false, "split_time": "3361575,3582626"}, "10": {"slice": true, "split_time": "3582626,3609059"}, "11": {"slice": false, "split_time": "3609059,3844622"}, "12": {"slice": true, "split_time": "3844622,3861910"}, "13": {"slice": false, "split_time": "3861910,4440673"}, "14": {"slice": true, "split_time": "4440673,4468405"}, "15": {"slice": false, "split_time": "4468405,5108832"}, "16": {"slice": true, "split_time": "5108832,5129497"}, "17": {"slice": false, "split_time": "5129497,6675968"}, "18": {"slice": true, "split_time": "6675968,9217857"}}, "time": 1673261015}, "5fba9ddc6c7223c2907d9ff3169e23f1": {"chunks": {"0": {"slice": true, "split_time": "0,53579"}, "1": {"slice": false, "split_time": "53579,332754"}, "2": {"slice": true, "split_time": "332754,414277"}, "3": {"slice": false, "split_time": "414277,781044"}, "4": {"slice": true, "split_time": "781044,816837"}, "5": {"slice": false, "split_time": "816837,1549835"}, "6": {"slice": true, "split_time": "1549835,1557319"}, "7": {"slice": false, "split_time": "1557319,4476357"}, "8": {"slice": true, "split_time": "4476357,4503731"}, "9": {"slice": false, "split_time": "4503731,5209666"}, "10": {"slice": true, "split_time": "5209666,5213405"}, "11": {"slice": false, "split_time": "5213405,6021387"}, "12": {"slice": true, "split_time": "6021387,6063481"}, "13": {"slice": false, "split_time": "6063481,6491872"}, "14": {"slice": true, "split_time": "6491872,6602234"}, "15": {"slice": false, "split_time": "6602234,7311975"}, "16": {"slice": true, "split_time": "7311975,7328295"}, "17": {"slice": false, "split_time": "7328295,8137067"}, "18": {"slice": true, "split_time": "8137067,8142127"}, "19": 
{"slice": false, "split_time": "8142127,10821823"}, "20": {"slice": true, "split_time": "10821823,11066085"}, "21": {"slice": false, "split_time": "11066085,11148757"}, "22": {"slice": true, "split_time": "11148757,11459584"}}, "time": 1673446113}, "a8b37529b910527ef9b414d4fa485973": {"chunks": {"0": {"slice": true, "split_time": "0,74804"}, "1": {"slice": false, "split_time": "74804,545959"}, "2": {"slice": true, "split_time": "545959,550194"}, "3": {"slice": false, "split_time": "550194,1040392"}, "4": {"slice": true, "split_time": "1040392,1045223"}, "5": {"slice": false, "split_time": "1045223,2552370"}, "6": {"slice": true, "split_time": "2552370,2825551"}, "7": {"slice": false, "split_time": "2825551,3280352"}, "8": {"slice": true, "split_time": "3280352,3284153"}, "9": {"slice": false, "split_time": "3284153,4005079"}, "10": {"slice": true, "split_time": "4005079,4049271"}, "11": {"slice": false, "split_time": "4049271,4279046"}, "12": {"slice": true, "split_time": "4279046,4323261"}, "13": {"slice": false, "split_time": "4323261,5505905"}, "14": {"slice": true, "split_time": "5505905,5535670"}, "15": {"slice": false, "split_time": "5535670,5770626"}, "16": {"slice": true, "split_time": "5770626,5829727"}, "17": {"slice": false, "split_time": "5829727,7771861"}, "18": {"slice": true, "split_time": "7771861,8040098"}, "19": {"slice": false, "split_time": "8040098,8536352"}, "20": {"slice": true, "split_time": "8536352,9186304"}}, "time": 1673040810}, "7e44079101dd10bc262ee28c4b881f25": {"chunks": {"0": {"slice": true, "split_time": "0,135941"}, "1": {"slice": false, "split_time": "135941,787855"}, "2": {"slice": true, "split_time": "787855,999653"}, "3": {"slice": false, "split_time": "999653,2981684"}, "4": {"slice": true, "split_time": "2981684,3116588"}, "5": {"slice": false, "split_time": "3116588,3339320"}, "6": {"slice": true, "split_time": "3339320,3382142"}, "7": {"slice": false, "split_time": "3382142,5345853"}, "8": {"slice": true, "split_time": 
"5345853,5365617"}, "9": {"slice": false, "split_time": "5365617,6391431"}, "10": {"slice": true, "split_time": "6391431,6527569"}, "11": {"slice": false, "split_time": "6527569,7791684"}, "12": {"slice": true, "split_time": "7791684,7794073"}, "13": {"slice": false, "split_time": "7794073,8749857"}, "14": {"slice": true, "split_time": "8749857,9232506"}, "15": {"slice": false, "split_time": "9232506,11714376"}, "16": {"slice": true, "split_time": "11714376,11785553"}, "17": {"slice": false, "split_time": "11785553,12282894"}, "18": {"slice": true, "split_time": "12282894,12861440"}}, "time": 1673260417}, "89ccd0cc779c17deb9dfd54e700a0bcd": {"chunks": {"0": {"slice": true, "split_time": "0,15387"}, "1": {"slice": false, "split_time": "15387,348933"}, "2": {"slice": true, "split_time": "348933,352886"}, "3": {"slice": false, "split_time": "352886,675429"}, "4": {"slice": true, "split_time": "675429,678891"}, "5": {"slice": false, "split_time": "678891,1048502"}, "6": {"slice": true, "split_time": "1048502,1056469"}, "7": {"slice": false, "split_time": "1056469,1337989"}, "8": {"slice": true, "split_time": "1337989,1381376"}}, "time": 1673351834}, "46fa6dbb81bfe1252ebdefa0884d0e6d": {"chunks": {"0": {"slice": true, "split_time": "0,185799"}, "1": {"slice": false, "split_time": "185799,502227"}, "2": {"slice": true, "split_time": "502227,848473"}, "3": {"slice": false, "split_time": "848473,2347749"}, "4": {"slice": true, "split_time": "2347749,2348783"}, "5": {"slice": false, "split_time": "2348783,3429977"}, "6": {"slice": true, "split_time": "3429977,3890042"}, "7": {"slice": false, "split_time": "3890042,5375055"}, "8": {"slice": true, "split_time": "5375055,5402621"}, "9": {"slice": false, "split_time": "5402621,6462260"}, "10": {"slice": true, "split_time": "6462260,6923811"}, "11": {"slice": false, "split_time": "6923811,8507888"}, "12": {"slice": true, "split_time": "8507888,9252348"}, "13": {"slice": false, "split_time": "9252348,10599478"}, "14": {"slice": 
true, "split_time": "10599478,10628316"}, "15": {"slice": false, "split_time": "10628316,10692608"}}, "time": 1673446761}, "2b0ffc08fb6fb0f29df83ae8cc0c0eb8": {"chunks": {"0": {"slice": true, "split_time": "0,1208342"}, "1": {"slice": false, "split_time": "1208342,1917473"}, "2": {"slice": true, "split_time": "1917473,1984255"}, "3": {"slice": false, "split_time": "1984255,4413108"}, "4": {"slice": true, "split_time": "4413108,4882745"}, "5": {"slice": false, "split_time": "4882745,8072787"}, "6": {"slice": true, "split_time": "8072787,9419776"}}, "time": 1673388059}, "57440513a8cfffbc7c1ffb790b9723f2": {"chunks": {"0": {"slice": true, "split_time": "0,1199064"}, "1": {"slice": false, "split_time": "1199064,1431015"}, "2": {"slice": true, "split_time": "1431015,1452361"}, "3": {"slice": false, "split_time": "1452361,1898308"}, "4": {"slice": true, "split_time": "1898308,1904596"}, "5": {"slice": false, "split_time": "1904596,2468360"}, "6": {"slice": true, "split_time": "2468360,2479240"}, "7": {"slice": false, "split_time": "2479240,3593434"}, "8": {"slice": true, "split_time": "3593434,3629186"}, "9": {"slice": false, "split_time": "3629186,4074575"}, "10": {"slice": true, "split_time": "4074575,4078966"}, "11": {"slice": false, "split_time": "4078966,4307991"}, "12": {"slice": true, "split_time": "4307991,4316241"}, "13": {"slice": false, "split_time": "4316241,5986540"}, "14": {"slice": true, "split_time": "5986540,5990713"}, "15": {"slice": false, "split_time": "5990713,6243455"}, "16": {"slice": true, "split_time": "6243455,6491661"}, "17": {"slice": false, "split_time": "6491661,6719744"}, "18": {"slice": true, "split_time": "6719744,6739748"}, "19": {"slice": false, "split_time": "6739748,7021570"}, "20": {"slice": true, "split_time": "7021570,7069439"}, "21": {"slice": false, "split_time": "7069439,7363743"}, "22": {"slice": true, "split_time": "7363743,7376200"}, "23": {"slice": false, "split_time": "7376200,9402054"}, "24": {"slice": true, "split_time": 
"9402054,9403936"}, "25": {"slice": false, "split_time": "9403936,9697617"}, "26": {"slice": true, "split_time": "9697617,9878400"}}, "time": 1673468608}, "1d7d07c16652e4b7300bdaf5d38128b9": {"chunks": {"0": {"slice": true, "split_time": "0,261052"}, "1": {"slice": false, "split_time": "261052,993078"}, "2": {"slice": true, "split_time": "993078,1091294"}, "3": {"slice": false, "split_time": "1091294,2303042"}, "4": {"slice": true, "split_time": "2303042,2304533"}, "5": {"slice": false, "split_time": "2304533,3543442"}, "6": {"slice": true, "split_time": "3543442,3656782"}, "7": {"slice": false, "split_time": "3656782,4337660"}, "8": {"slice": true, "split_time": "4337660,4406921"}, "9": {"slice": false, "split_time": "4406921,6132779"}, "10": {"slice": true, "split_time": "6132779,6147466"}, "11": {"slice": false, "split_time": "6147466,6855506"}, "12": {"slice": true, "split_time": "6855506,7413094"}}, "time": 1673466029}, "dc6af7b6086c4a6f7f6ee206024988f3": {"chunks": {"0": {"slice": true, "split_time": "0,1492137"}, "1": {"slice": false, "split_time": "1492137,2132538"}, "2": {"slice": true, "split_time": "2132538,2140949"}, "3": {"slice": false, "split_time": "2140949,2362096"}, "4": {"slice": true, "split_time": "2362096,2377313"}, "5": {"slice": false, "split_time": "2377313,3190602"}, "6": {"slice": true, "split_time": "3190602,3236563"}, "7": {"slice": false, "split_time": "3236563,3795750"}, "8": {"slice": true, "split_time": "3795750,3798646"}, "9": {"slice": false, "split_time": "3798646,4037700"}, "10": {"slice": true, "split_time": "4037700,4057671"}, "11": {"slice": false, "split_time": "4057671,4828341"}, "12": {"slice": true, "split_time": "4828341,4839270"}, "13": {"slice": false, "split_time": "4839270,5069319"}, "14": {"slice": true, "split_time": "5069319,5070998"}, "15": {"slice": false, "split_time": "5070998,5722290"}, "16": {"slice": true, "split_time": "5722290,5778275"}, "17": {"slice": false, "split_time": "5778275,6346775"}, "18": 
{"slice": true, "split_time": "6346775,6354400"}, "19": {"slice": false, "split_time": "6354400,6965771"}, "20": {"slice": true, "split_time": "6965771,7361821"}, "21": {"slice": false, "split_time": "7361821,7649206"}, "22": {"slice": true, "split_time": "7649206,7681433"}, "23": {"slice": false, "split_time": "7681433,7921525"}, "24": {"slice": true, "split_time": "7921525,7934007"}, "25": {"slice": false, "split_time": "7934007,8183308"}, "26": {"slice": true, "split_time": "8183308,8238306"}, "27": {"slice": false, "split_time": "8238306,10906572"}, "28": {"slice": true, "split_time": "10906572,10942779"}}, "time": 1673469959}, "010900becfb5574ab3a333f941577c2c": {"chunks": {"0": {"slice": true, "split_time": "0,1083101"}, "1": {"slice": false, "split_time": "1083101,1421801"}, "2": {"slice": true, "split_time": "1421801,1453133"}, "3": {"slice": false, "split_time": "1453133,1794206"}, "4": {"slice": true, "split_time": "1794206,1800891"}, "5": {"slice": false, "split_time": "1800891,2115485"}, "6": {"slice": true, "split_time": "2115485,2152253"}, "7": {"slice": false, "split_time": "2152253,2957806"}, "8": {"slice": true, "split_time": "2957806,2967552"}, "9": {"slice": false, "split_time": "2967552,3967440"}, "10": {"slice": true, "split_time": "3967440,4012384"}, "11": {"slice": false, "split_time": "4012384,4335874"}, "12": {"slice": true, "split_time": "4335874,4363031"}, "13": {"slice": false, "split_time": "4363031,4694987"}, "14": {"slice": true, "split_time": "4694987,4715486"}, "15": {"slice": false, "split_time": "4715486,5030719"}, "16": {"slice": true, "split_time": "5030719,5065596"}, "17": {"slice": false, "split_time": "5065596,6874746"}, "18": {"slice": true, "split_time": "6874746,7885915"}, "19": {"slice": false, "split_time": "7885915,9628060"}, "20": {"slice": true, "split_time": "9628060,10032061"}}, "time": 1673470410}, "5bc32d122ced9dd6d8cf0186fff6466a": {"chunks": {"0": {"slice": true, "split_time": "0,814909"}, "1": {"slice": false, 
"split_time": "814909,3487605"}, "2": {"slice": true, "split_time": "3487605,3496613"}, "3": {"slice": false, "split_time": "3496613,6186104"}, "4": {"slice": true, "split_time": "6186104,6826505"}, "5": {"slice": false, "split_time": "6826505,9979380"}, "6": {"slice": true, "split_time": "9979380,13711090"}}, "time": 1673474828}}
\ No newline at end of file
diff --git a/infer_tools/slicer.py b/infer_tools/slicer.py
new file mode 100644
index 0000000000000000000000000000000000000000..35a888b906e7df8634cfdcec914f650c6cefd26a
--- /dev/null
+++ b/infer_tools/slicer.py
@@ -0,0 +1,158 @@
+import time
+
+import numpy as np
+import torch
+import torchaudio
+from scipy.ndimage import maximum_filter1d, uniform_filter1d
+
+
+def timeit(func):
+    """Decorator that prints how long each call to ``func`` took.
+
+    The wrapper forwards all positional and keyword arguments, returns the
+    wrapped function's result unchanged, and prints one line of the form
+    ``executing '<name>' costed <seconds>s`` after a successful call. Nothing
+    is printed when ``func`` raises; the exception propagates as-is.
+
+    Args:
+        func: the callable to time.
+
+    Returns:
+        A wrapper with the same call signature and metadata as ``func``.
+    """
+    import functools  # local import so this block needs no file-level change
+
+    # functools.wraps keeps __name__/__doc__ of the decorated function intact
+    @functools.wraps(func)
+    def run(*args, **kwargs):
+        # perf_counter is monotonic, so the reading cannot go negative or jump
+        # when the system clock is adjusted during the call
+        started = time.perf_counter()
+        res = func(*args, **kwargs)
+        print('executing \'%s\' costed %.3fs' % (func.__name__, time.perf_counter() - started))
+        return res
+
+    return run
+
+
+# @timeit
+def _window_maximum(arr, win_sz):
+    """Sliding-window maximum of ``arr`` with window length ``win_sz``.
+
+    The filter output is trimmed to ``arr.shape[0] - win_sz + 1`` entries
+    starting at index ``win_sz // 2``, which drops the border positions where
+    the window would extend past the ends of ``arr``.
+    """
+    offset = win_sz // 2
+    kept = arr.shape[0] - win_sz + 1
+    running_max = maximum_filter1d(arr, size=win_sz)
+    return running_max[offset: offset + kept]
+
+
+# @timeit
+def _window_rms(arr, win_sz):
+    """Sliding-window RMS of ``arr`` about the local mean.
+
+    For every window of length ``win_sz`` this computes
+    ``sqrt(mean(x**2) - mean(x)**2)``, i.e. the RMS after removing the
+    window's own mean. The result is trimmed to ``arr.shape[0] - win_sz + 1``
+    entries starting at index ``win_sz // 2`` so that border positions are
+    dropped.
+
+    Args:
+        arr: 1-D numpy array of samples.
+        win_sz: window length in samples.
+
+    Returns:
+        numpy array of non-negative RMS values, one per retained window.
+    """
+    mean_of_squares = uniform_filter1d(np.power(arr, 2), win_sz)
+    square_of_mean = np.power(uniform_filter1d(arr, win_sz), 2)
+    # On (near-)constant input the two terms are almost equal and rounding can
+    # make the difference slightly negative; clamp at zero so sqrt never
+    # yields NaN (a NaN here would corrupt the argmin-based split search).
+    filtered = np.sqrt(np.maximum(mean_of_squares - square_of_mean, 0))
+    return filtered[win_sz // 2: win_sz // 2 + arr.shape[0] - win_sz + 1]
+
+
+def level2db(levels, eps=1e-12):
+    """Convert linear amplitude levels to decibels.
+
+    Levels are first clipped to the range ``[eps, 1]`` (so the result lies
+    between ``20 * log10(eps)`` and 0 dB), then mapped with ``20 * log10``.
+    """
+    bounded = np.clip(levels, a_min=eps, a_max=1)
+    return 20 * np.log10(bounded)
+
+
+def _apply_slice(audio, begin, end):
+    """Return ``audio[begin:end]`` taken along the last of at most two axes.
+
+    Arrays with more than one dimension are sliced on their second axis
+    (all rows kept); one-dimensional arrays are sliced directly.
+    """
+    span = slice(begin, end)
+    has_leading_axis = len(audio.shape) > 1
+    return audio[:, span] if has_leading_axis else audio[span]
+
+
+class Slicer:
+    """Splits a mono waveform into alternating silent and non-silent chunks.
+
+    Silence is detected by thresholding the sliding-window peak level (in dB)
+    of the mean-removed signal; the exact cut points are then refined with a
+    short-window RMS search followed by a minimum-amplitude search.
+    """
+
+    def __init__(self,
+                 sr: int,
+                 db_threshold: float = -40,
+                 min_length: int = 5000,
+                 win_l: int = 300,
+                 win_s: int = 20,
+                 max_silence_kept: int = 500):
+        """Store the threshold and convert all durations to sample counts.
+
+        Durations are scaled by ``sr / 1000``, so they are presumably given in
+        milliseconds with ``sr`` in Hz.
+
+        Args:
+            sr: sample rate of the audio that will be sliced.
+            db_threshold: windows whose peak level is below this value (dB)
+                are treated as silent.
+            min_length: minimum length between two kept silence tags; audio
+                no longer than this is never sliced.
+            win_l: length of the large window used for peak detection.
+            win_s: length of the small window used to refine cut points.
+            max_silence_kept: upper bound on the span searched for a cut
+                point at each edge of a silent region.
+
+        Raises:
+            ValueError: if ``min_length >= win_l >= win_s`` or
+                ``max_silence_kept >= win_s`` does not hold (after conversion
+                to samples).
+        """
+        self.db_threshold = db_threshold
+        self.min_samples = round(sr * min_length / 1000)
+        self.win_ln = round(sr * win_l / 1000)
+        self.win_sn = round(sr * win_s / 1000)
+        self.max_silence = round(sr * max_silence_kept / 1000)
+        if not self.min_samples >= self.win_ln >= self.win_sn:
+            raise ValueError('The following condition must be satisfied: min_length >= win_l >= win_s')
+        if not self.max_silence >= self.win_sn:
+            raise ValueError('The following condition must be satisfied: max_silence_kept >= win_s')
+
+    @timeit
+    def slice(self, audio):
+        """Compute the chunk table for ``audio``.
+
+        Args:
+            audio: 1-D numpy array of samples (indexed along axis 0).
+
+        Returns:
+            dict mapping "0", "1", ... (in time order) to
+            ``{"slice": bool, "split_time": "start,end"}`` where start/end are
+            sample indices. ``"slice": True`` marks a silent region and
+            ``False`` a region that is kept. Audio that is too short, or in
+            which no silence is found, yields a single un-sliced chunk
+            covering everything.
+        """
+        samples = audio
+        if samples.shape[0] <= self.min_samples:
+            return {"0": {"slice": False, "split_time": f"0,{len(audio)}"}}
+        # get absolute amplitudes
+        abs_amp = np.abs(samples - np.mean(samples))
+        # calculate local maximum with large window
+        win_max_db = level2db(_window_maximum(abs_amp, win_sz=self.win_ln))
+        # (start, end) sample positions of every accepted silent region
+        sil_tags = []
+        # [left, right) is the run of consecutive below-threshold windows
+        # currently being tracked; left == right means no open run
+        left = right = 0
+        while right < win_max_db.shape[0]:
+            if win_max_db[right] < self.db_threshold:
+                # still silent: extend the current run
+                right += 1
+            elif left == right:
+                # loud window and no open run: move both pointers forward
+                left += 1
+                right += 1
+            else:
+                # a silent run [left, right) just ended; locate its cut points
+                if left == 0:
+                    # silence starts at the very beginning of the audio
+                    split_loc_l = left
+                else:
+                    # search at most max_silence samples (or half the run)
+                    # from the left edge for the quietest small window, then
+                    # the quietest sample inside that window
+                    sil_left_n = min(self.max_silence, (right + self.win_ln - left) // 2)
+                    rms_db_left = level2db(_window_rms(samples[left: left + sil_left_n], win_sz=self.win_sn))
+                    split_win_l = left + np.argmin(rms_db_left)
+                    split_loc_l = split_win_l + np.argmin(abs_amp[split_win_l: split_win_l + self.win_sn])
+                # drop this run when it lies closer than min_samples to the
+                # previous silence tag, unless it reaches the last window
+                if len(sil_tags) != 0 and split_loc_l - sil_tags[-1][1] < self.min_samples and right < win_max_db.shape[
+                    0] - 1:
+                    right += 1
+                    left = right
+                    continue
+                if right == win_max_db.shape[0] - 1:
+                    # run touches the last window: cut at its far end
+                    split_loc_r = right + self.win_ln
+                else:
+                    # mirror of the left-edge search, on the right edge
+                    sil_right_n = min(self.max_silence, (right + self.win_ln - left) // 2)
+                    rms_db_right = level2db(_window_rms(samples[right + self.win_ln - sil_right_n: right + self.win_ln],
+                                                        win_sz=self.win_sn))
+                    split_win_r = right + self.win_ln - sil_right_n + np.argmin(rms_db_right)
+                    split_loc_r = split_win_r + np.argmin(abs_amp[split_win_r: split_win_r + self.win_sn])
+                sil_tags.append((split_loc_l, split_loc_r))
+                right += 1
+                left = right
+        # a silent run was still open when the scan ended: it extends to the
+        # end of the audio
+        if left != right:
+            sil_left_n = min(self.max_silence, (right + self.win_ln - left) // 2)
+            rms_db_left = level2db(_window_rms(samples[left: left + sil_left_n], win_sz=self.win_sn))
+            split_win_l = left + np.argmin(rms_db_left)
+            split_loc_l = split_win_l + np.argmin(abs_amp[split_win_l: split_win_l + self.win_sn])
+            sil_tags.append((split_loc_l, samples.shape[0]))
+        if len(sil_tags) == 0:
+            return {"0": {"slice": False, "split_time": f"0,{len(audio)}"}}
+        else:
+            chunks = []
+            # the first silent region does not start at the beginning: add the leading voiced chunk
+            if sil_tags[0][0]:
+                chunks.append({"slice": False, "split_time": f"0,{sil_tags[0][0]}"})
+            for i in range(0, len(sil_tags)):
+                # mark the voiced chunk between two silent regions (skipped for the first region)
+                if i:
+                    chunks.append({"slice": False, "split_time": f"{sil_tags[i - 1][1]},{sil_tags[i][0]}"})
+                # mark every silent region
+                chunks.append({"slice": True, "split_time": f"{sil_tags[i][0]},{sil_tags[i][1]}"})
+            # the last silent region does not reach the end: add the trailing chunk
+            if sil_tags[-1][1] != len(audio):
+                chunks.append({"slice": False, "split_time": f"{sil_tags[-1][1]},{len(audio)}"})
+            # re-key the ordered chunk list by its string index
+            chunk_dict = {}
+            for i in range(len(chunks)):
+                chunk_dict[str(i)] = chunks[i]
+            return chunk_dict
+
+
+def cut(audio_path, db_thresh=-30, min_len=5000, win_l=300, win_s=20, max_sil_kept=500):
+    """Load an audio file and compute its silence-based chunk table.
+
+    Multi-channel audio is averaged down to a single channel before slicing.
+
+    Args:
+        audio_path: path of the audio file, passed to ``torchaudio.load``.
+        db_thresh: silence threshold forwarded to ``Slicer`` as ``db_threshold``.
+        min_len: forwarded to ``Slicer`` as ``min_length``.
+        win_l: forwarded to ``Slicer`` as ``win_l``.
+        win_s: forwarded to ``Slicer`` as ``win_s``.
+        max_sil_kept: forwarded to ``Slicer`` as ``max_silence_kept``.
+
+    Returns:
+        The chunk dict produced by ``Slicer.slice`` for the mono waveform.
+    """
+    audio, sr = torchaudio.load(audio_path)
+    # torchaudio.load returns a (channels, frames) tensor by default, so the
+    # channel count is shape[0]; the earlier check read shape[1] (frame count).
+    if len(audio.shape) == 2 and audio.shape[0] >= 2:
+        audio = torch.mean(audio, dim=0, keepdim=True)
+    # keep the single remaining channel as a 1-D numpy array
+    audio = audio.cpu().numpy()[0]
+
+    slicer = Slicer(
+        sr=sr,
+        db_threshold=db_thresh,
+        min_length=min_len,
+        win_l=win_l,
+        win_s=win_s,
+        max_silence_kept=max_sil_kept
+    )
+    return slicer.slice(audio)
+
+
+def chunks2audio(audio_path, chunks):
+    """Cut an audio file into the segments described by a chunk table.
+
+    Multi-channel audio is averaged down to a single channel first.
+
+    Args:
+        audio_path: path of the audio file, passed to ``torchaudio.load``.
+        chunks: mapping whose values are dicts with a ``"slice"`` flag and a
+            ``"split_time"`` string of the form ``"start,end"`` (sample
+            indices); anything accepted by ``dict()`` works.
+
+    Returns:
+        ``(result, sr)`` where ``result`` is a list of
+        ``(slice_flag, samples)`` tuples in the mapping's iteration order and
+        ``sr`` is the sample rate reported by ``torchaudio.load``.
+    """
+    chunks = dict(chunks)
+    audio, sr = torchaudio.load(audio_path)
+    # torchaudio.load returns a (channels, frames) tensor by default, so the
+    # channel count is shape[0]; the earlier check read shape[1] (frame count).
+    if len(audio.shape) == 2 and audio.shape[0] >= 2:
+        audio = torch.mean(audio, dim=0, keepdim=True)
+    # keep the single remaining channel as a 1-D numpy array
+    audio = audio.cpu().numpy()[0]
+    result = []
+    for entry in chunks.values():
+        begin, end = entry["split_time"].split(",")[:2]
+        result.append((entry["slice"], audio[int(begin):int(end)]))
+    return result, sr
+
+
diff --git a/inference.ipynb b/inference.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..129c5ee984fddc9b505b3a9bd3b065ef2ba988c5
--- /dev/null
+++ b/inference.ipynb
@@ -0,0 +1,245 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "e:\\Software\\Anaconda3\\envs\\diffsvc\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "| load 'model' from './checkpoints/nyaru/model_ckpt_steps_112000.ckpt'.\n",
+ "| load 'model' from 'checkpoints/0102_xiaoma_pe/model_ckpt_steps_60000.ckpt'.\n",
+ "| load HifiGAN: checkpoints/0109_hifigan_bigpopcs_hop128\\model_ckpt_steps_1512000.pth\n",
+ "| Loaded model parameters from checkpoints/0109_hifigan_bigpopcs_hop128\\model_ckpt_steps_1512000.pth.\n",
+ "| HifiGAN device: cuda.\n",
+ "model loaded\n"
+ ]
+ }
+ ],
+ "source": [
+ "from utils.hparams import hparams\n",
+ "from preprocessing.data_gen_utils import get_pitch_parselmouth,get_pitch_crepe\n",
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "import IPython.display as ipd\n",
+ "import utils\n",
+ "import librosa\n",
+ "import torchcrepe\n",
+ "from infer import *\n",
+ "import logging\n",
+ "from infer_tools.infer_tool import *\n",
+ "\n",
+ "logging.getLogger('numba').setLevel(logging.WARNING)\n",
+ "\n",
+ "# 工程文件夹名,训练时用的那个\n",
+ "project_name = \"nyaru\"\n",
+ "model_path = f'./checkpoints/{project_name}/model_ckpt_steps_112000.ckpt'\n",
+ "config_path=f'./checkpoints/{project_name}/config.yaml'\n",
+ "hubert_gpu=True\n",
+ "svc_model = Svc(project_name,config_path,hubert_gpu, model_path)\n",
+ "print('model loaded')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "load chunks from temp\n",
+ "#=====segment start, 0.46s======\n",
+ "jump empty segment\n",
+ "#=====segment start, 6.702s======\n",
+ "load temp crepe f0\n",
+ "executing 'get_pitch' costed 0.066s\n",
+ "hubert (on cpu) time used 0.6847963333129883\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "sample time step: 100%|██████████| 50/50 [00:02<00:00, 21.95it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "executing 'diff_infer' costed 2.310s\n",
+ "executing 'after_infer' costed 1.167s\n",
+ "#=====segment start, 8.831s======\n",
+ "load temp crepe f0\n",
+ "executing 'get_pitch' costed 0.063s\n",
+ "hubert (on cpu) time used 0.8832910060882568\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "sample time step: 100%|██████████| 50/50 [00:02<00:00, 18.36it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "executing 'diff_infer' costed 2.749s\n",
+ "executing 'after_infer' costed 1.894s\n",
+ "#=====segment start, 5.265s======\n",
+ "load temp crepe f0\n",
+ "executing 'get_pitch' costed 0.065s\n",
+ "hubert (on cpu) time used 0.5448079109191895\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "sample time step: 100%|██████████| 50/50 [00:01<00:00, 28.39it/s]\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "executing 'diff_infer' costed 1.780s\n",
+ "executing 'after_infer' costed 1.038s\n",
+ "#=====segment start, 1.377s======\n",
+ "jump empty segment\n"
+ ]
+ }
+ ],
+ "source": [
+ "wav_fn='raw/test_input.wav'#支持多数音频格式,无需手动转为wav\n",
+ "demoaudio, sr = librosa.load(wav_fn)\n",
+ "key = 0 # 音高调整,支持正负(半音)\n",
+ "# 加速倍数\n",
+ "pndm_speedup = 20\n",
+ "wav_gen='test_output.wav'#直接改后缀可以保存不同格式音频,如flac可无损压缩\n",
+ "f0_tst, f0_pred, audio = run_clip(svc_model,file_path=wav_fn, key=key, acc=pndm_speedup, use_crepe=True, use_pe=True, thre=0.05,\n",
+ " use_gt_mel=False, add_noise_step=500,project_name=project_name,out_path=wav_gen)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "ipd.display(ipd.Audio(demoaudio, rate=sr))\n",
+ "ipd.display(ipd.Audio(audio, rate=hparams['audio_sample_rate'], normalize=False))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "#f0_gen,_=get_pitch_crepe(*vocoder.wav2spec(wav_gen),hparams,threshold=0.05)\n",
+ "%matplotlib inline\n",
+ "f0_gen,_=get_pitch_parselmouth(*svc_model.vocoder.wav2spec(wav_gen),hparams)\n",
+ "f0_tst[f0_tst==0]=np.nan#ground truth f0\n",
+ "f0_pred[f0_pred==0]=np.nan#f0 pe predicted\n",
+ "f0_gen[f0_gen==0]=np.nan#f0 generated\n",
+ "fig=plt.figure(figsize=[15,5])\n",
+ "plt.plot(np.arange(0,len(f0_tst)),f0_tst,color='black')\n",
+ "plt.plot(np.arange(0,len(f0_pred)),f0_pred,color='orange')\n",
+ "plt.plot(np.arange(0,len(f0_gen)),f0_gen,color='red')\n",
+ "plt.axhline(librosa.note_to_hz('C4'),ls=\":\",c=\"blue\")\n",
+ "plt.axhline(librosa.note_to_hz('G4'),ls=\":\",c=\"green\")\n",
+ "plt.axhline(librosa.note_to_hz('C5'),ls=\":\",c=\"orange\")\n",
+ "plt.axhline(librosa.note_to_hz('F#5'),ls=\":\",c=\"red\")\n",
+ "#plt.axhline(librosa.note_to_hz('A#5'),ls=\":\",c=\"black\")\n",
+ "plt.show()"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3.8.13 ('diffsvc')",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.13"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "5cf89e54348a1bdadbb0ca2d227dcc30cc7e2d47cc75a8605923523671b5b7c7"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/models/genshin/__init__.py b/models/genshin/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/genshin/config.yaml b/models/genshin/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0fbb2e59828810a8fbc7f1c1119702cef1973931
--- /dev/null
+++ b/models/genshin/config.yaml
@@ -0,0 +1,445 @@
+K_step: 1000
+accumulate_grad_batches: 1
+audio_num_mel_bins: 128
+audio_sample_rate: 44100
+binarization_args:
+ shuffle: false
+ with_align: true
+ with_f0: true
+ with_hubert: true
+ with_spk_embed: false
+ with_wav: false
+binarizer_cls: preprocessing.SVCpre.SVCBinarizer
+binary_data_dir: data/binary/raiden
+check_val_every_n_epoch: 10
+choose_test_manually: false
+clip_grad_norm: 1
+config_path: training/config_nsf.yaml
+content_cond_steps: []
+cwt_add_f0_loss: false
+cwt_hidden_size: 128
+cwt_layers: 2
+cwt_loss: l1
+cwt_std_scale: 0.8
+datasets:
+- opencpop
+debug: false
+dec_ffn_kernel_size: 9
+dec_layers: 4
+decay_steps: 50000
+decoder_type: fft
+dict_dir: ''
+diff_decoder_type: wavenet
+diff_loss_type: l2
+dilation_cycle_length: 4
+dropout: 0.1
+ds_workers: 4
+dur_enc_hidden_stride_kernel:
+- 0,2,3
+- 0,2,3
+- 0,1,3
+dur_loss: mse
+dur_predictor_kernel: 3
+dur_predictor_layers: 5
+enc_ffn_kernel_size: 9
+enc_layers: 4
+encoder_K: 8
+encoder_type: fft
+endless_ds: false
+f0_bin: 256
+f0_max: 1100.0
+f0_min: 40.0
+ffn_act: gelu
+ffn_padding: SAME
+fft_size: 2048
+fmax: 16000
+fmin: 40
+fs2_ckpt: ''
+gaussian_start: true
+gen_dir_name: ''
+gen_tgt_spk_id: -1
+hidden_size: 256
+hop_size: 512
+hubert_gpu: true
+hubert_path: checkpoints/hubert/hubert_soft.pt
+infer: false
+keep_bins: 128
+lambda_commit: 0.25
+lambda_energy: 0.0
+lambda_f0: 1.0
+lambda_ph_dur: 0.3
+lambda_sent_dur: 1.0
+lambda_uv: 1.0
+lambda_word_dur: 1.0
+load_ckpt: ''
+log_interval: 100
+loud_norm: false
+lr: 0.0012
+max_beta: 0.02
+max_epochs: 3000
+max_eval_sentences: 1
+max_eval_tokens: 60000
+max_frames: 42000
+max_input_tokens: 60000
+max_sentences: 48
+max_tokens: 128000
+max_updates: 260000
+mel_loss: ssim:0.5|l1:0.5
+mel_vmax: 1.5
+mel_vmin: -6.0
+min_level_db: -120
+no_fs2: true
+norm_type: gn
+num_ckpt_keep: 10
+num_heads: 2
+num_sanity_val_steps: 1
+num_spk: 1
+num_test_samples: 0
+num_valid_plots: 10
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+out_wav_norm: false
+pe_ckpt: checkpoints/0102_xiaoma_pe/model_ckpt_steps_60000.ckpt
+pe_enable: false
+perform_enhance: true
+pitch_ar: false
+pitch_enc_hidden_stride_kernel:
+- 0,2,5
+- 0,2,5
+- 0,2,5
+pitch_extractor: parselmouth
+pitch_loss: l2
+pitch_norm: log
+pitch_type: frame
+pndm_speedup: 10
+pre_align_args:
+ allow_no_txt: false
+ denoise: false
+ forced_align: mfa
+ txt_processor: zh_g2pM
+ use_sox: true
+ use_tone: false
+pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
+predictor_dropout: 0.5
+predictor_grad: 0.1
+predictor_hidden: -1
+predictor_kernel: 5
+predictor_layers: 5
+prenet_dropout: 0.5
+prenet_hidden_size: 256
+pretrain_fs_ckpt: ''
+processed_data_dir: xxx
+profile_infer: false
+raw_data_dir: data/raw/raiden
+ref_norm_layer: bn
+rel_pos: true
+reset_phone_dict: true
+residual_channels: 384
+residual_layers: 20
+save_best: false
+save_ckpt: true
+save_codes:
+- configs
+- modules
+- src
+- utils
+save_f0: true
+save_gt: false
+schedule_type: linear
+seed: 1234
+sort_by_len: true
+speaker_id: raiden
+spec_max:
+- -0.4759584963321686
+- -0.04242899641394615
+- 0.2820039689540863
+- 0.6635098457336426
+- 0.7846556901931763
+- 0.9242268800735474
+- 1.0596446990966797
+- 0.9890199303627014
+- 0.8979427218437195
+- 0.8635445237159729
+- 0.8591453433036804
+- 0.6987467408180237
+- 0.717823326587677
+- 0.7350517511367798
+- 0.6464754939079285
+- 0.6164345145225525
+- 0.4986744523048401
+- 0.3543139100074768
+- 0.2876613438129425
+- 0.3467520773410797
+- 0.27083638310432434
+- 0.4002445936203003
+- 0.4222544729709625
+- 0.4776916205883026
+- 0.5299767255783081
+- 0.6194124817848206
+- 0.5802494287490845
+- 0.6222044229507446
+- 0.6124054193496704
+- 0.6688933968544006
+- 0.7368689179420471
+- 0.7275264859199524
+- 0.7448640465736389
+- 0.5364857912063599
+- 0.5867365002632141
+- 0.48127084970474243
+- 0.48270556330680847
+- 0.45863038301467896
+- 0.3647041916847229
+- 0.3207940459251404
+- 0.3352258801460266
+- 0.2846287190914154
+- 0.2674693465232849
+- 0.3205840587615967
+- 0.36074087023735046
+- 0.40593528747558594
+- 0.266417920589447
+- 0.22159256041049957
+- 0.19403372704982758
+- 0.29326388239860535
+- 0.29472747445106506
+- 0.3801038861274719
+- 0.3864395320415497
+- 0.285959392786026
+- 0.22213149070739746
+- 0.19549456238746643
+- 0.22238962352275848
+- 0.15776650607585907
+- 0.23960433900356293
+- 0.3050341308116913
+- 0.258531779050827
+- 0.19573383033275604
+- 0.2259710431098938
+- 0.2110864669084549
+- 0.24603094160556793
+- 0.05981471762061119
+- 0.17803697288036346
+- 0.17807669937610626
+- 0.18952223658561707
+- 0.053435735404491425
+- 0.1157914251089096
+- 0.026514273136854172
+- 0.16326436400413513
+- 0.22839383780956268
+- 0.08631942421197891
+- 0.23315998911857605
+- 0.162082701921463
+- 0.1533375382423401
+- 0.12668564915657043
+- 0.08644244074821472
+- 0.02113831788301468
+- 0.3463006615638733
+- 0.22695019841194153
+- 0.14443305134773254
+- 0.20211298763751984
+- 0.07295431941747665
+- -0.007622874341905117
+- -0.02703588828444481
+- -0.06394484639167786
+- -0.09371187537908554
+- -0.005024001933634281
+- -0.013857427053153515
+- -0.1372852921485901
+- -0.10361140221357346
+- -0.12665916979312897
+- -0.20724378526210785
+- -0.2142055779695511
+- 0.19261693954467773
+- 0.08877495676279068
+- -0.21178629994392395
+- -0.18947778642177582
+- -0.2520659863948822
+- -0.22880172729492188
+- -0.14105628430843353
+- -0.05572707951068878
+- -0.26297450065612793
+- -0.412873774766922
+- -0.35086822509765625
+- -0.36454305052757263
+- -0.4511951208114624
+- -0.3738810122013092
+- -0.4491288959980011
+- -0.5304299592971802
+- -0.4495029151439667
+- -0.4343602657318115
+- -0.34189438819885254
+- -0.5748119950294495
+- -0.6650391817092896
+- -0.6537092924118042
+- -0.8302493691444397
+- -0.7430973649024963
+- -0.9170511364936829
+- -1.0624077320098877
+- -1.242630958557129
+- -1.163043737411499
+- -1.0178996324539185
+- -1.4756207466125488
+- -1.5275251865386963
+spec_min:
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.9844536781311035
+- -4.999994277954102
+- -4.999994277954102
+spk_cond_steps: []
+stop_token_weight: 5.0
+task_cls: training.task.SVC_task.SVCTask
+test_ids: []
+test_input_dir: ''
+test_num: 0
+test_prefixes:
+- test
+test_set_name: test
+timesteps: 1000
+train_set_name: train
+use_crepe: true
+use_denoise: false
+use_energy_embed: false
+use_gt_dur: false
+use_gt_f0: false
+use_midi: false
+use_nsf: true
+use_pitch_embed: true
+use_pos_embed: true
+use_spk_embed: false
+use_spk_id: false
+use_split_spk_id: false
+use_uv: false
+use_var_enc: false
+use_vec: false
+val_check_interval: 2000
+valid_num: 0
+valid_set_name: valid
+vocoder: network.vocoders.nsf_hifigan.NsfHifiGAN
+vocoder_ckpt: checkpoints/nsf_hifigan/model
+warmup_updates: 2000
+wav2spec_eps: 1e-6
+weight_decay: 0
+win_size: 2048
+work_dir: checkpoints/raiden
diff --git a/models/genshin/raiden.ckpt b/models/genshin/raiden.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..e8f78b9b6af0d026eda2e0b206b04ce3f16eaeeb
--- /dev/null
+++ b/models/genshin/raiden.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1da7f6c45f8651131386c946f1bd90010b1fd568792d8d70f7025ea3b9a9dda2
+size 134965595
diff --git a/modules/commons/__pycache__/common_layers.cpython-38.pyc b/modules/commons/__pycache__/common_layers.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cf94aa9e8449eb39488f124f7743df1da2ded023
Binary files /dev/null and b/modules/commons/__pycache__/common_layers.cpython-38.pyc differ
diff --git a/modules/commons/__pycache__/espnet_positional_embedding.cpython-38.pyc b/modules/commons/__pycache__/espnet_positional_embedding.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8b1b81e35f49db1b199633befefadfce70c38f68
Binary files /dev/null and b/modules/commons/__pycache__/espnet_positional_embedding.cpython-38.pyc differ
diff --git a/modules/commons/__pycache__/ssim.cpython-38.pyc b/modules/commons/__pycache__/ssim.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8fc7255a8b79296ddb4fcf41fad1710de2538a07
Binary files /dev/null and b/modules/commons/__pycache__/ssim.cpython-38.pyc differ
diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..192997cee5b265525b244e1f421b41dab9874b99
--- /dev/null
+++ b/modules/commons/common_layers.py
@@ -0,0 +1,671 @@
+import math
+import torch
+from torch import nn
+from torch.nn import Parameter
+import torch.onnx.operators
+import torch.nn.functional as F
+import utils
+
+
+class Reshape(nn.Module):
+ def __init__(self, *args):
+ super(Reshape, self).__init__()
+ self.shape = args
+
+ def forward(self, x):
+ return x.view(self.shape)
+
+
+class Permute(nn.Module):
+ def __init__(self, *args):
+ super(Permute, self).__init__()
+ self.args = args
+
+ def forward(self, x):
+ return x.permute(self.args)
+
+
+class LinearNorm(torch.nn.Module):
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
+ super(LinearNorm, self).__init__()
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.linear_layer.weight,
+ gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, x):
+ return self.linear_layer(x)
+
+
+class ConvNorm(torch.nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
+ padding=None, dilation=1, bias=True, w_init_gain='linear'):
+ super(ConvNorm, self).__init__()
+ if padding is None:
+ assert (kernel_size % 2 == 1)
+ padding = int(dilation * (kernel_size - 1) / 2)
+
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
+ kernel_size=kernel_size, stride=stride,
+ padding=padding, dilation=dilation,
+ bias=bias)
+
+ torch.nn.init.xavier_uniform_(
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
+
+ def forward(self, signal):
+ conv_signal = self.conv(signal)
+ return conv_signal
+
+
+def Embedding(num_embeddings, embedding_dim, padding_idx=None):
+ m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
+ nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
+ if padding_idx is not None:
+ nn.init.constant_(m.weight[padding_idx], 0)
+ return m
+
+
+def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
+ if not export and torch.cuda.is_available():
+ try:
+ from apex.normalization import FusedLayerNorm
+ return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
+ except ImportError:
+ pass
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
+
+
+def Linear(in_features, out_features, bias=True):
+ m = nn.Linear(in_features, out_features, bias)
+ nn.init.xavier_uniform_(m.weight)
+ if bias:
+ nn.init.constant_(m.bias, 0.)
+ return m
+
+
+class SinusoidalPositionalEmbedding(nn.Module):
+ """This module produces sinusoidal positional embeddings of any length.
+
+ Padding symbols are ignored.
+ """
+
+ def __init__(self, embedding_dim, padding_idx, init_size=1024):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.padding_idx = padding_idx
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ init_size,
+ embedding_dim,
+ padding_idx,
+ )
+ self.register_buffer('_float_tensor', torch.FloatTensor(1))
+
+ @staticmethod
+ def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
+ """Build sinusoidal embeddings.
+
+ This matches the implementation in tensor2tensor, but differs slightly
+ from the description in Section 3.5 of "Attention Is All You Need".
+ """
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
+ emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
+ if embedding_dim % 2 == 1:
+ # zero pad
+ emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
+ if padding_idx is not None:
+ emb[padding_idx, :] = 0
+ return emb
+
+ def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs):
+ """Input is expected to be of size [bsz x seqlen]."""
+ bsz, seq_len = input.shape[:2]
+ max_pos = self.padding_idx + 1 + seq_len
+ if self.weights is None or max_pos > self.weights.size(0):
+ # recompute/expand embeddings if needed
+ self.weights = SinusoidalPositionalEmbedding.get_embedding(
+ max_pos,
+ self.embedding_dim,
+ self.padding_idx,
+ )
+ self.weights = self.weights.to(self._float_tensor)
+
+ if incremental_state is not None:
+ # positions is the same for every token when decoding a single step
+ pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
+ return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
+
+ positions = utils.make_positions(input, self.padding_idx) if positions is None else positions
+ return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()
+
+ def max_positions(self):
+ """Maximum number of supported positions."""
+ return int(1e5) # an arbitrary large number
+
+
+class ConvTBC(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, padding=0):
+ super(ConvTBC, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = kernel_size
+ self.padding = padding
+
+ self.weight = torch.nn.Parameter(torch.Tensor(
+ self.kernel_size, in_channels, out_channels))
+ self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
+
+ def forward(self, input):
+ return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding)
+
+
+class MultiheadAttention(nn.Module):
+ def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True,
+ add_bias_kv=False, add_zero_attn=False, self_attention=False,
+ encoder_decoder_attention=False):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.kdim = kdim if kdim is not None else embed_dim
+ self.vdim = vdim if vdim is not None else embed_dim
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
+
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim ** -0.5
+
+ self.self_attention = self_attention
+ self.encoder_decoder_attention = encoder_decoder_attention
+
+ assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \
+ 'value to be of the same size'
+
+ if self.qkv_same_dim:
+ self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
+ else:
+ self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
+ self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
+ self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
+
+ if bias:
+ self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
+ else:
+ self.register_parameter('in_proj_bias', None)
+
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ if add_bias_kv:
+ self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
+ self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
+ else:
+ self.bias_k = self.bias_v = None
+
+ self.add_zero_attn = add_zero_attn
+
+ self.reset_parameters()
+
+ self.enable_torch_version = False
+ if hasattr(F, "multi_head_attention_forward"):
+ self.enable_torch_version = True
+ else:
+ self.enable_torch_version = False
+ self.last_attn_probs = None
+
+ def reset_parameters(self):
+ if self.qkv_same_dim:
+ nn.init.xavier_uniform_(self.in_proj_weight)
+ else:
+ nn.init.xavier_uniform_(self.k_proj_weight)
+ nn.init.xavier_uniform_(self.v_proj_weight)
+ nn.init.xavier_uniform_(self.q_proj_weight)
+
+ nn.init.xavier_uniform_(self.out_proj.weight)
+ if self.in_proj_bias is not None:
+ nn.init.constant_(self.in_proj_bias, 0.)
+ nn.init.constant_(self.out_proj.bias, 0.)
+ if self.bias_k is not None:
+ nn.init.xavier_normal_(self.bias_k)
+ if self.bias_v is not None:
+ nn.init.xavier_normal_(self.bias_v)
+
+ def forward(
+ self,
+ query, key, value,
+ key_padding_mask=None,
+ incremental_state=None,
+ need_weights=True,
+ static_kv=False,
+ attn_mask=None,
+ before_softmax=False,
+ need_head_weights=False,
+ enc_dec_attn_constraint_mask=None,
+ reset_attn_weight=None
+ ):
+ """Input shape: Time x Batch x Channel
+
+ Args:
+ key_padding_mask (ByteTensor, optional): mask to exclude
+ keys that are pads, of shape `(batch, src_len)`, where
+ padding elements are indicated by 1s.
+ need_weights (bool, optional): return the attention weights,
+ averaged over heads (default: False).
+ attn_mask (ByteTensor, optional): typically used to
+ implement causal attention, where the mask prevents the
+ attention from looking forward in time (default: None).
+ before_softmax (bool, optional): return the raw attention
+ weights and values before the attention softmax.
+ need_head_weights (bool, optional): return the attention
+ weights for each head. Implies *need_weights*. Default:
+ return the average attention weights over all heads.
+ """
+ if need_head_weights:
+ need_weights = True
+
+ tgt_len, bsz, embed_dim = query.size()
+ assert embed_dim == self.embed_dim
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
+
+ if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None:
+ if self.qkv_same_dim:
+ return F.multi_head_attention_forward(query, key, value,
+ self.embed_dim, self.num_heads,
+ self.in_proj_weight,
+ self.in_proj_bias, self.bias_k, self.bias_v,
+ self.add_zero_attn, self.dropout,
+ self.out_proj.weight, self.out_proj.bias,
+ self.training, key_padding_mask, need_weights,
+ attn_mask)
+ else:
+ return F.multi_head_attention_forward(query, key, value,
+ self.embed_dim, self.num_heads,
+ torch.empty([0]),
+ self.in_proj_bias, self.bias_k, self.bias_v,
+ self.add_zero_attn, self.dropout,
+ self.out_proj.weight, self.out_proj.bias,
+ self.training, key_padding_mask, need_weights,
+ attn_mask, use_separate_proj_weight=True,
+ q_proj_weight=self.q_proj_weight,
+ k_proj_weight=self.k_proj_weight,
+ v_proj_weight=self.v_proj_weight)
+
+ if incremental_state is not None:
+ print('Not implemented error.')
+ exit()
+ else:
+ saved_state = None
+
+ if self.self_attention:
+ # self-attention
+ q, k, v = self.in_proj_qkv(query)
+ elif self.encoder_decoder_attention:
+ # encoder-decoder attention
+ q = self.in_proj_q(query)
+ if key is None:
+ assert value is None
+ k = v = None
+ else:
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(key)
+
+ else:
+ q = self.in_proj_q(query)
+ k = self.in_proj_k(key)
+ v = self.in_proj_v(value)
+ q *= self.scaling
+
+ if self.bias_k is not None:
+ assert self.bias_v is not None
+ k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
+ v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1)
+
+ q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if k is not None:
+ k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+ if v is not None:
+ v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
+
+ if saved_state is not None:
+ print('Not implemented error.')
+ exit()
+
+ src_len = k.size(1)
+
+ # This is part of a workaround to get around fork/join parallelism
+ # not supporting Optional types.
+ if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]):
+ key_padding_mask = None
+
+ if key_padding_mask is not None:
+ assert key_padding_mask.size(0) == bsz
+ assert key_padding_mask.size(1) == src_len
+
+ if self.add_zero_attn:
+ src_len += 1
+ k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
+ v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
+ if attn_mask is not None:
+ attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
+ if key_padding_mask is not None:
+ key_padding_mask = torch.cat(
+ [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1)
+
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
+
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
+
+ if attn_mask is not None:
+ if len(attn_mask.shape) == 2:
+ attn_mask = attn_mask.unsqueeze(0)
+ elif len(attn_mask.shape) == 3:
+ attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape(
+ bsz * self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights + attn_mask
+
+ if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ enc_dec_attn_constraint_mask.unsqueeze(2).bool(),
+ -1e9,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ if key_padding_mask is not None:
+ # don't attend to padding symbols
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attn_weights = attn_weights.masked_fill(
+ key_padding_mask.unsqueeze(1).unsqueeze(2),
+ -1e9,
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
+ attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+ if before_softmax:
+ return attn_weights, v
+
+ attn_weights_float = utils.softmax(attn_weights, dim=-1)
+ attn_weights = attn_weights_float.type_as(attn_weights)
+ attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training)
+
+ if reset_attn_weight is not None:
+ if reset_attn_weight:
+ self.last_attn_probs = attn_probs.detach()
+ else:
+ assert self.last_attn_probs is not None
+ attn_probs = self.last_attn_probs
+ attn = torch.bmm(attn_probs, v)
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
+ attn = self.out_proj(attn)
+
+ if need_weights:
+ attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0)
+ if not need_head_weights:
+ # average attention weights over heads
+ attn_weights = attn_weights.mean(dim=0)
+ else:
+ attn_weights = None
+
+ return attn, (attn_weights, attn_logits)
+
+ def in_proj_qkv(self, query):
+ return self._in_proj(query).chunk(3, dim=-1)
+
+ def in_proj_q(self, query):
+ if self.qkv_same_dim:
+ return self._in_proj(query, end=self.embed_dim)
+ else:
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[:self.embed_dim]
+ return F.linear(query, self.q_proj_weight, bias)
+
+ def in_proj_k(self, key):
+ if self.qkv_same_dim:
+ return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)
+ else:
+ weight = self.k_proj_weight
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[self.embed_dim:2 * self.embed_dim]
+ return F.linear(key, weight, bias)
+
+ def in_proj_v(self, value):
+ if self.qkv_same_dim:
+ return self._in_proj(value, start=2 * self.embed_dim)
+ else:
+ weight = self.v_proj_weight
+ bias = self.in_proj_bias
+ if bias is not None:
+ bias = bias[2 * self.embed_dim:]
+ return F.linear(value, weight, bias)
+
+ def _in_proj(self, input, start=0, end=None):
+ weight = self.in_proj_weight
+ bias = self.in_proj_bias
+ weight = weight[start:end, :]
+ if bias is not None:
+ bias = bias[start:end]
+ return F.linear(input, weight, bias)
+
+
+ def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz):
+ return attn_weights
+
+
+class Swish(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, i):
+ result = i * torch.sigmoid(i)
+ ctx.save_for_backward(i)
+ return result
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ i = ctx.saved_variables[0]
+ sigmoid_i = torch.sigmoid(i)
+ return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
+
+
+class CustomSwish(nn.Module):
+ def forward(self, input_tensor):
+ return Swish.apply(input_tensor)
+
+class Mish(nn.Module):
+ def forward(self, x):
+ return x * torch.tanh(F.softplus(x))
+
+class TransformerFFNLayer(nn.Module):
+ def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.dropout = dropout
+ self.act = act
+ if padding == 'SAME':
+ self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2)
+ elif padding == 'LEFT':
+ self.ffn_1 = nn.Sequential(
+ nn.ConstantPad1d((kernel_size - 1, 0), 0.0),
+ nn.Conv1d(hidden_size, filter_size, kernel_size)
+ )
+ self.ffn_2 = Linear(filter_size, hidden_size)
+ if self.act == 'swish':
+ self.swish_fn = CustomSwish()
+
+ def forward(self, x, incremental_state=None):
+ # x: T x B x C
+ if incremental_state is not None:
+ assert incremental_state is None, 'Nar-generation does not allow this.'
+ exit(1)
+
+ x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1)
+ x = x * self.kernel_size ** -0.5
+
+ if incremental_state is not None:
+ x = x[-1:]
+ if self.act == 'gelu':
+ x = F.gelu(x)
+ if self.act == 'relu':
+ x = F.relu(x)
+ if self.act == 'swish':
+ x = self.swish_fn(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = self.ffn_2(x)
+ return x
+
+
+class BatchNorm1dTBC(nn.Module):
+ def __init__(self, c):
+ super(BatchNorm1dTBC, self).__init__()
+ self.bn = nn.BatchNorm1d(c)
+
+ def forward(self, x):
+ """
+
+ :param x: [T, B, C]
+ :return: [T, B, C]
+ """
+ x = x.permute(1, 2, 0) # [B, C, T]
+ x = self.bn(x) # [B, C, T]
+ x = x.permute(2, 0, 1) # [T, B, C]
+ return x
+
+
+class EncSALayer(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
+ relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.num_heads = num_heads
+ if num_heads > 0:
+ if norm == 'ln':
+ self.layer_norm1 = LayerNorm(c)
+ elif norm == 'bn':
+ self.layer_norm1 = BatchNorm1dTBC(c)
+ self.self_attn = MultiheadAttention(
+ self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False,
+ )
+ if norm == 'ln':
+ self.layer_norm2 = LayerNorm(c)
+ elif norm == 'bn':
+ self.layer_norm2 = BatchNorm1dTBC(c)
+ self.ffn = TransformerFFNLayer(
+ c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act)
+
+ def forward(self, x, encoder_padding_mask=None, **kwargs):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ if self.num_heads > 0:
+ residual = x
+ x = self.layer_norm1(x)
+ x, _, = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=encoder_padding_mask
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+
+ residual = x
+ x = self.layer_norm2(x)
+ x = self.ffn(x)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None]
+ return x
+
+
+class DecSALayer(nn.Module):
+ def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9, act='gelu'):
+ super().__init__()
+ self.c = c
+ self.dropout = dropout
+ self.layer_norm1 = LayerNorm(c)
+ self.self_attn = MultiheadAttention(
+ c, num_heads, self_attention=True, dropout=attention_dropout, bias=False
+ )
+ self.layer_norm2 = LayerNorm(c)
+ self.encoder_attn = MultiheadAttention(
+ c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False,
+ )
+ self.layer_norm3 = LayerNorm(c)
+ self.ffn = TransformerFFNLayer(
+ c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act)
+
+ def forward(
+ self,
+ x,
+ encoder_out=None,
+ encoder_padding_mask=None,
+ incremental_state=None,
+ self_attn_mask=None,
+ self_attn_padding_mask=None,
+ attn_out=None,
+ reset_attn_weight=None,
+ **kwargs,
+ ):
+ layer_norm_training = kwargs.get('layer_norm_training', None)
+ if layer_norm_training is not None:
+ self.layer_norm1.training = layer_norm_training
+ self.layer_norm2.training = layer_norm_training
+ self.layer_norm3.training = layer_norm_training
+ residual = x
+ x = self.layer_norm1(x)
+ x, _ = self.self_attn(
+ query=x,
+ key=x,
+ value=x,
+ key_padding_mask=self_attn_padding_mask,
+ incremental_state=incremental_state,
+ attn_mask=self_attn_mask
+ )
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+
+ residual = x
+ x = self.layer_norm2(x)
+ if encoder_out is not None:
+ x, attn = self.encoder_attn(
+ query=x,
+ key=encoder_out,
+ value=encoder_out,
+ key_padding_mask=encoder_padding_mask,
+ incremental_state=incremental_state,
+ static_kv=True,
+ enc_dec_attn_constraint_mask=None, #utils.get_incremental_state(self, incremental_state, 'enc_dec_attn_constraint_mask'),
+ reset_attn_weight=reset_attn_weight
+ )
+ attn_logits = attn[1]
+ else:
+ assert attn_out is not None
+ x = self.encoder_attn.in_proj_v(attn_out.transpose(0, 1))
+ attn_logits = None
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+
+ residual = x
+ x = self.layer_norm3(x)
+ x = self.ffn(x, incremental_state=incremental_state)
+ x = F.dropout(x, self.dropout, training=self.training)
+ x = residual + x
+ # if len(attn_logits.size()) > 3:
+ # indices = attn_logits.softmax(-1).max(-1).values.sum(-1).argmax(-1)
+ # attn_logits = attn_logits.gather(1,
+ # indices[:, None, None, None].repeat(1, 1, attn_logits.size(-2), attn_logits.size(-1))).squeeze(1)
+ return x, attn_logits
diff --git a/modules/commons/espnet_positional_embedding.py b/modules/commons/espnet_positional_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..74decb6ab300951490ae08a4b93041a0542b5bb7
--- /dev/null
+++ b/modules/commons/espnet_positional_embedding.py
@@ -0,0 +1,113 @@
+import math
+import torch
+
+
+class PositionalEncoding(torch.nn.Module):
+ """Positional encoding.
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ reverse (bool): Whether to reverse the input position.
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
+ """Construct an PositionalEncoding object."""
+ super(PositionalEncoding, self).__init__()
+ self.d_model = d_model
+ self.reverse = reverse
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+ def extend_pe(self, x):
+ """Reset the positional encodings."""
+ if self.pe is not None:
+ if self.pe.size(1) >= x.size(1):
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ pe = torch.zeros(x.size(1), self.d_model)
+ if self.reverse:
+ position = torch.arange(
+ x.size(1) - 1, -1, -1.0, dtype=torch.float32
+ ).unsqueeze(1)
+ else:
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32)
+ * -(math.log(10000.0) / self.d_model)
+ )
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self, x: torch.Tensor):
+ """Add positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ """
+ self.extend_pe(x)
+ x = x * self.xscale + self.pe[:, : x.size(1)]
+ return self.dropout(x)
+
+
+class ScaledPositionalEncoding(PositionalEncoding):
+ """Scaled positional encoding module.
+ See Sec. 3.2 https://arxiv.org/abs/1809.08895
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000):
+ """Initialize class."""
+ super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
+ self.alpha = torch.nn.Parameter(torch.tensor(1.0))
+
+ def reset_parameters(self):
+ """Reset parameters."""
+ self.alpha.data = torch.tensor(1.0)
+
+ def forward(self, x):
+ """Add positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ """
+ self.extend_pe(x)
+ x = x + self.alpha * self.pe[:, : x.size(1)]
+ return self.dropout(x)
+
+
+class RelPositionalEncoding(PositionalEncoding):
+ """Relative positional encoding module.
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000):
+ """Initialize class."""
+ super().__init__(d_model, dropout_rate, max_len, reverse=True)
+
+ def forward(self, x):
+ """Compute positional encoding.
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns:
+ torch.Tensor: Encoded tensor (batch, time, `*`).
+ torch.Tensor: Positional embedding tensor (1, time, `*`).
+ """
+ self.extend_pe(x)
+ x = x * self.xscale
+ pos_emb = self.pe[:, : x.size(1)]
+ return self.dropout(x) + self.dropout(pos_emb)
\ No newline at end of file
diff --git a/modules/commons/ssim.py b/modules/commons/ssim.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d0241f267ef58b24979e022b05f2a9adf768826
--- /dev/null
+++ b/modules/commons/ssim.py
@@ -0,0 +1,391 @@
+# '''
+# https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py
+# '''
+#
+# import torch
+# import torch.jit
+# import torch.nn.functional as F
+#
+#
+# @torch.jit.script
+# def create_window(window_size: int, sigma: float, channel: int):
+# '''
+# Create 1-D gauss kernel
+# :param window_size: the size of gauss kernel
+# :param sigma: sigma of normal distribution
+# :param channel: input channel
+# :return: 1D kernel
+# '''
+# coords = torch.arange(window_size, dtype=torch.float)
+# coords -= window_size // 2
+#
+# g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
+# g /= g.sum()
+#
+# g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1)
+# return g
+#
+#
+# @torch.jit.script
+# def _gaussian_filter(x, window_1d, use_padding: bool):
+# '''
+# Blur input with 1-D kernel
+# :param x: batch of tensors to be blured
+# :param window_1d: 1-D gauss kernel
+# :param use_padding: padding image before conv
+# :return: blurred tensors
+# '''
+# C = x.shape[1]
+# padding = 0
+# if use_padding:
+# window_size = window_1d.shape[3]
+# padding = window_size // 2
+# out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C)
+# out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C)
+# return out
+#
+#
+# @torch.jit.script
+# def ssim(X, Y, window, data_range: float, use_padding: bool = False):
+# '''
+# Calculate ssim index for X and Y
+# :param X: images [B, C, H, N_bins]
+# :param Y: images [B, C, H, N_bins]
+# :param window: 1-D gauss kernel
+# :param data_range: value range of input images. (usually 1.0 or 255)
+# :param use_padding: padding image before conv
+# :return:
+# '''
+#
+# K1 = 0.01
+# K2 = 0.03
+# compensation = 1.0
+#
+# C1 = (K1 * data_range) ** 2
+# C2 = (K2 * data_range) ** 2
+#
+# mu1 = _gaussian_filter(X, window, use_padding)
+# mu2 = _gaussian_filter(Y, window, use_padding)
+# sigma1_sq = _gaussian_filter(X * X, window, use_padding)
+# sigma2_sq = _gaussian_filter(Y * Y, window, use_padding)
+# sigma12 = _gaussian_filter(X * Y, window, use_padding)
+#
+# mu1_sq = mu1.pow(2)
+# mu2_sq = mu2.pow(2)
+# mu1_mu2 = mu1 * mu2
+#
+# sigma1_sq = compensation * (sigma1_sq - mu1_sq)
+# sigma2_sq = compensation * (sigma2_sq - mu2_sq)
+# sigma12 = compensation * (sigma12 - mu1_mu2)
+#
+# cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
+# # Fixed the issue that the negative value of cs_map caused ms_ssim to output Nan.
+# cs_map = cs_map.clamp_min(0.)
+# ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map
+#
+# ssim_val = ssim_map.mean(dim=(1, 2, 3)) # reduce along CHW
+# cs = cs_map.mean(dim=(1, 2, 3))
+#
+# return ssim_val, cs
+#
+#
+# @torch.jit.script
+# def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8):
+# '''
+# interface of ms-ssim
+# :param X: a batch of images, (N,C,H,W)
+# :param Y: a batch of images, (N,C,H,W)
+# :param window: 1-D gauss kernel
+# :param data_range: value range of input images. (usually 1.0 or 255)
+# :param weights: weights for different levels
+# :param use_padding: padding image before conv
+# :param eps: use for avoid grad nan.
+# :return:
+# '''
+# levels = weights.shape[0]
+# cs_vals = []
+# ssim_vals = []
+# for _ in range(levels):
+# ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding)
+# # Works around an issue: when c = a ** b and a is 0, c.backward() makes a.grad become inf.
+# ssim_val = ssim_val.clamp_min(eps)
+# cs = cs.clamp_min(eps)
+# cs_vals.append(cs)
+#
+# ssim_vals.append(ssim_val)
+# padding = (X.shape[2] % 2, X.shape[3] % 2)
+# X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding)
+# Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding)
+#
+# cs_vals = torch.stack(cs_vals, dim=0)
+# ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0)
+# return ms_ssim_val
+#
+#
+# class SSIM(torch.jit.ScriptModule):
+# __constants__ = ['data_range', 'use_padding']
+#
+# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False):
+# '''
+# :param window_size: the size of gauss kernel
+# :param window_sigma: sigma of normal distribution
+# :param data_range: value range of input images. (usually 1.0 or 255)
+# :param channel: input channels (default: 3)
+# :param use_padding: padding image before conv
+# '''
+# super().__init__()
+# assert window_size % 2 == 1, 'Window size must be odd.'
+# window = create_window(window_size, window_sigma, channel)
+# self.register_buffer('window', window)
+# self.data_range = data_range
+# self.use_padding = use_padding
+#
+# @torch.jit.script_method
+# def forward(self, X, Y):
+# r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding)
+# return r[0]
+#
+#
+# class MS_SSIM(torch.jit.ScriptModule):
+# __constants__ = ['data_range', 'use_padding', 'eps']
+#
+# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None,
+# levels=None, eps=1e-8):
+# '''
+# class for ms-ssim
+# :param window_size: the size of gauss kernel
+# :param window_sigma: sigma of normal distribution
+# :param data_range: value range of input images. (usually 1.0 or 255)
+# :param channel: input channels
+# :param use_padding: padding image before conv
+# :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
+# :param levels: number of downsampling
+# :param eps: Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf.
+# '''
+# super().__init__()
+# assert window_size % 2 == 1, 'Window size must be odd.'
+# self.data_range = data_range
+# self.use_padding = use_padding
+# self.eps = eps
+#
+# window = create_window(window_size, window_sigma, channel)
+# self.register_buffer('window', window)
+#
+# if weights is None:
+# weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
+# weights = torch.tensor(weights, dtype=torch.float)
+#
+# if levels is not None:
+# weights = weights[:levels]
+# weights = weights / weights.sum()
+#
+# self.register_buffer('weights', weights)
+#
+# @torch.jit.script_method
+# def forward(self, X, Y):
+# return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
+# use_padding=self.use_padding, eps=self.eps)
+#
+#
+# if __name__ == '__main__':
+# print('Simple Test')
+# im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
+# img1 = im / 255
+# img2 = img1 * 0.5
+#
+# losser = SSIM(data_range=1.).cuda()
+# loss = losser(img1, img2).mean()
+#
+# losser2 = MS_SSIM(data_range=1.).cuda()
+# loss2 = losser2(img1, img2).mean()
+#
+# print(loss.item())
+# print(loss2.item())
+#
+# if __name__ == '__main__':
+# print('Training Test')
+# import cv2
+# import torch.optim
+# import numpy as np
+# import imageio
+# import time
+#
+# out_test_video = False
+# # Avoid writing a GIF directly, as it gets very large; write an mkv file first and convert it to GIF with ffmpeg.
+# video_use_gif = False
+#
+# im = cv2.imread('test_img1.jpg', 1)
+# t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
+#
+# if out_test_video:
+# if video_use_gif:
+# fps = 0.5
+# out_wh = (im.shape[1] // 2, im.shape[0] // 2)
+# suffix = '.gif'
+# else:
+# fps = 5
+# out_wh = (im.shape[1], im.shape[0])
+# suffix = '.mkv'
+# video_last_time = time.perf_counter()
+# video = imageio.get_writer('ssim_test' + suffix, fps=fps)
+#
+# # Test SSIM
+# print('Training SSIM')
+# rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
+# rand_im.requires_grad = True
+# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
+# losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
+# ssim_score = 0
+# while ssim_score < 0.999:
+# optim.zero_grad()
+# loss = losser(rand_im, t_im)
+# (-loss).sum().backward()
+# ssim_score = loss.item()
+# optim.step()
+# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
+# r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
+#
+# if out_test_video:
+# if time.perf_counter() - video_last_time > 1. / fps:
+# video_last_time = time.perf_counter()
+# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
+# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
+# if isinstance(out_frame, cv2.UMat):
+# out_frame = out_frame.get()
+# video.append_data(out_frame)
+#
+# cv2.imshow('ssim', r_im)
+# cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
+# cv2.waitKey(1)
+#
+# if out_test_video:
+# video.close()
+#
+# # Test MS_SSIM
+# if out_test_video:
+# if video_use_gif:
+# fps = 0.5
+# out_wh = (im.shape[1] // 2, im.shape[0] // 2)
+# suffix = '.gif'
+# else:
+# fps = 5
+# out_wh = (im.shape[1], im.shape[0])
+# suffix = '.mkv'
+# video_last_time = time.perf_counter()
+# video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
+#
+# print('Training MS_SSIM')
+# rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
+# rand_im.requires_grad = True
+# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
+# losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda()
+# ssim_score = 0
+# while ssim_score < 0.999:
+# optim.zero_grad()
+# loss = losser(rand_im, t_im)
+# (-loss).sum().backward()
+# ssim_score = loss.item()
+# optim.step()
+# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
+# r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
+#
+# if out_test_video:
+# if time.perf_counter() - video_last_time > 1. / fps:
+# video_last_time = time.perf_counter()
+# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
+# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
+# if isinstance(out_frame, cv2.UMat):
+# out_frame = out_frame.get()
+# video.append_data(out_frame)
+#
+# cv2.imshow('ms_ssim', r_im)
+# cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score)
+# cv2.waitKey(1)
+#
+# if out_test_video:
+# video.close()
+
+"""
+Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim
+"""
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Variable
+import numpy as np
+from math import exp
+
+
+def gaussian(window_size, sigma):
+    """Return a normalised 1-D Gaussian kernel.
+
+    Args:
+        window_size (int): Number of taps; the peak sits at index
+            ``window_size // 2``.
+        sigma (float): Standard deviation of the Gaussian.
+    Returns:
+        torch.Tensor: 1-D tensor of length ``window_size`` divided by its own sum.
+    """
+    center = window_size // 2
+    denom = float(2 * sigma ** 2)
+    taps = []
+    for idx in range(window_size):
+        taps.append(exp(-(idx - center) ** 2 / denom))
+    kernel = torch.Tensor(taps)
+    return kernel / kernel.sum()
+
+
+def create_window(window_size, channel):
+    """Build the 2-D Gaussian window used as a per-channel convolution kernel.
+
+    A 1-D Gaussian (sigma fixed at 1.5) is multiplied with its own transpose
+    to give a ``window_size x window_size`` kernel, which is then replicated
+    once per channel.
+
+    Args:
+        window_size (int): Side length of the square window.
+        channel (int): Number of copies of the kernel to produce.
+    Returns:
+        torch.Tensor: Contiguous tensor of shape
+            (channel, 1, window_size, window_size).
+    """
+    column = gaussian(window_size, 1.5).unsqueeze(1)
+    kernel_2d = column.mm(column.t()).float()
+    kernel_4d = kernel_2d.unsqueeze(0).unsqueeze(0)
+    tiled = kernel_4d.expand(channel, 1, window_size, window_size).contiguous()
+    return Variable(tiled)
+
+
+def _ssim(img1, img2, window, window_size, channel, size_average=True):
+    """Compute the SSIM map between two image batches and reduce it.
+
+    Local means, variances and the covariance are estimated by convolving
+    with ``window`` using ``groups=channel`` and ``window_size // 2`` padding.
+
+    Args:
+        img1, img2 (torch.Tensor): Inputs accepted by ``F.conv2d``
+            (assumed (batch, channel, H, W) -- confirm against callers).
+        window (torch.Tensor): Convolution kernel, one filter per channel.
+        window_size (int): Side length of ``window``; only used for padding.
+        channel (int): Number of channels, passed as ``groups``.
+        size_average (bool): If True return the mean over every element,
+            otherwise average over dimension 1 only.
+    Returns:
+        torch.Tensor: Scalar tensor when ``size_average`` is True, else the
+            SSIM map averaged over dimension 1.
+    """
+    half = window_size // 2
+
+    def blur(t):
+        # Windowed local average, computed independently per channel.
+        return F.conv2d(t, window, padding=half, groups=channel)
+
+    mu1, mu2 = blur(img1), blur(img2)
+    mu1_sq, mu2_sq, mu1_mu2 = mu1.pow(2), mu2.pow(2), mu1 * mu2
+
+    sigma1_sq = blur(img1 * img1) - mu1_sq
+    sigma2_sq = blur(img2 * img2) - mu2_sq
+    sigma12 = blur(img1 * img2) - mu1_mu2
+
+    # Stabilising constants (presumably chosen for inputs in [0, 1] -- confirm).
+    C1, C2 = 0.01 ** 2, 0.03 ** 2
+
+    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
+    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
+    ssim_map = numerator / denominator
+
+    return ssim_map.mean() if size_average else ssim_map.mean(1)
+
+
+class SSIM(torch.nn.Module):
+    """Module wrapper around ``_ssim`` that caches its Gaussian window.
+
+    The window is a plain attribute (not a registered buffer), so it does not
+    follow ``.to()`` / ``.cuda()`` calls on the module. ``forward`` therefore
+    rebuilds it whenever it does not match the incoming tensor.
+
+    Args:
+        window_size (int): Side length of the Gaussian window.
+        size_average (bool): Forwarded to ``_ssim``; True returns a scalar mean.
+    """
+
+    def __init__(self, window_size=11, size_average=True):
+        super(SSIM, self).__init__()
+        self.window_size = window_size
+        self.size_average = size_average
+        self.channel = 1
+        self.window = create_window(window_size, self.channel)
+
+    def forward(self, img1, img2):
+        """Return the SSIM between ``img1`` and ``img2``.
+
+        Args:
+            img1, img2 (torch.Tensor): 4-D inputs; the channel count is read
+                from dimension 1 of ``img1``.
+        Returns:
+            torch.Tensor: Result of ``_ssim`` for the configured reduction.
+        """
+        (_, channel, _, _) = img1.size()
+
+        cached = self.window
+        # Compare dtype and device explicitly: the type() string used before
+        # does not include the CUDA device index, so a window cached on one
+        # GPU could be reused for an input living on another.
+        reusable = (
+            channel == self.channel
+            and cached.dtype == img1.dtype
+            and cached.device == img1.device
+        )
+        if reusable:
+            window = cached
+        else:
+            window = create_window(self.window_size, channel).to(
+                device=img1.device, dtype=img1.dtype)
+            self.window = window
+            self.channel = channel
+
+        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
+
+
+# Module-level cache for the window used by ssim(). It is rebuilt whenever it
+# no longer matches the current input.
+window = None
+
+
+def ssim(img1, img2, window_size=11, size_average=True):
+    """Functional SSIM between two image batches.
+
+    The Gaussian window is cached in the module-level ``window`` variable.
+    The cache used to be filled on the first call only, so a later call with
+    a different channel count, window size, dtype or device received a stale
+    kernel. It is now refreshed whenever any of those differ.
+
+    Args:
+        img1, img2 (torch.Tensor): 4-D inputs; the channel count is read from
+            dimension 1 of ``img1``.
+        window_size (int): Side length of the Gaussian window.
+        size_average (bool): Forwarded to ``_ssim``; True returns a scalar mean.
+    Returns:
+        torch.Tensor: Result of ``_ssim`` for the requested reduction.
+    """
+    (_, channel, _, _) = img1.size()
+    global window
+    expected_shape = (channel, 1, window_size, window_size)
+    stale = (
+        window is None
+        or tuple(window.shape) != expected_shape
+        or window.dtype != img1.dtype
+        or window.device != img1.device
+    )
+    if stale:
+        window = create_window(window_size, channel).to(
+            device=img1.device, dtype=img1.dtype)
+    return _ssim(img1, img2, window, window_size, channel, size_average)
diff --git a/modules/fastspeech/__pycache__/fs2.cpython-38.pyc b/modules/fastspeech/__pycache__/fs2.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..dee5fcdd0a76202567f83557c7acee622e5b6e6e
Binary files /dev/null and b/modules/fastspeech/__pycache__/fs2.cpython-38.pyc differ
diff --git a/modules/fastspeech/__pycache__/pe.cpython-38.pyc b/modules/fastspeech/__pycache__/pe.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6775c82edef051b7b6ddb3840c487526e026b4d9
Binary files /dev/null and b/modules/fastspeech/__pycache__/pe.cpython-38.pyc differ
diff --git a/modules/fastspeech/__pycache__/tts_modules.cpython-38.pyc b/modules/fastspeech/__pycache__/tts_modules.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce853baf7c08eabd8a9da96f476a2331baa75881
Binary files /dev/null and b/modules/fastspeech/__pycache__/tts_modules.cpython-38.pyc differ
diff --git a/modules/fastspeech/fs2.py b/modules/fastspeech/fs2.py
new file mode 100644
index 0000000000000000000000000000000000000000..085448a8dd44cdd8a8e106e9bf1b983fae29cb55
--- /dev/null
+++ b/modules/fastspeech/fs2.py
@@ -0,0 +1,255 @@
+from modules.commons.common_layers import *
+from modules.commons.common_layers import Embedding
+from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \
+ EnergyPredictor, FastspeechEncoder
+from utils.cwt import cwt2f0
+from utils.hparams import hparams
+from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
+
+def _build_fft_encoder(hp):
+    """Create a ``FastspeechEncoder`` from an hparams-style mapping."""
+    return FastspeechEncoder(
+        hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'],
+        num_heads=hp['num_heads'])
+
+
+def _build_fft_decoder(hp):
+    """Create a ``FastspeechDecoder`` from an hparams-style mapping."""
+    return FastspeechDecoder(
+        hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads'])
+
+
+# Registries mapping an encoder/decoder type name to a factory that takes the
+# hparams mapping and returns the constructed module.
+FS_ENCODERS = {'fft': _build_fft_encoder}
+
+FS_DECODERS = {'fft': _build_fft_decoder}
+
+
+class FastSpeech2(nn.Module):
+ """FastSpeech2-style module driven by a `hubert` feature tensor.
+
+ In this version the duration predictor, length regulator, speaker-embedding
+ layers and energy predictor are commented out in `__init__`. Pitch and energy
+ embeddings are looked up from the `f0` / `uv` / `energy` values passed to
+ `forward`.
+ """
+ def __init__(self, dictionary, out_dims=None):
+ """Build the sub-modules selected by the global `hparams`.
+
+ :param dictionary: accepted but not stored (the assignment below is commented out)
+ :param out_dims: output size of `mel_out`; hparams['audio_num_mel_bins'] is used when None
+ """
+ super().__init__()
+ # self.dictionary = dictionary
+ self.padding_idx = 0
+ # The condition parses as `(not hparams['no_fs2']) if 'no_fs2' in hparams.keys() else True`:
+ # it holds when the key is absent or its value is falsy.
+ if not hparams['no_fs2'] if 'no_fs2' in hparams.keys() else True:
+ self.enc_layers = hparams['enc_layers']
+ self.dec_layers = hparams['dec_layers']
+ self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams)
+ self.decoder = FS_DECODERS[hparams['decoder_type']](hparams)
+ self.hidden_size = hparams['hidden_size']
+ # self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size)
+ self.out_dims = out_dims
+ if out_dims is None:
+ self.out_dims = hparams['audio_num_mel_bins']
+ self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True)
+ #=========not used===========
+ # if hparams['use_spk_id']:
+ # self.spk_embed_proj = Embedding(hparams['num_spk'] + 1, self.hidden_size)
+ # if hparams['use_split_spk_id']:
+ # self.spk_embed_f0 = Embedding(hparams['num_spk'] + 1, self.hidden_size)
+ # self.spk_embed_dur = Embedding(hparams['num_spk'] + 1, self.hidden_size)
+ # elif hparams['use_spk_embed']:
+ # self.spk_embed_proj = Linear(256, self.hidden_size, bias=True)
+ # Predictor width: hparams['predictor_hidden'] when positive, otherwise hidden_size.
+ predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size
+ # self.dur_predictor = DurationPredictor(
+ # self.hidden_size,
+ # n_chans=predictor_hidden,
+ # n_layers=hparams['dur_predictor_layers'],
+ # dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'],
+ # kernel_size=hparams['dur_predictor_kernel'])
+ # self.length_regulator = LengthRegulator()
+ if hparams['use_pitch_embed']:
+ # Embedding table with 300 entries, indexed by the coarse pitch computed in add_pitch.
+ self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx)
+ if hparams['pitch_type'] == 'cwt':
+ h = hparams['cwt_hidden_size']
+ cwt_out_dims = 10
+ if hparams['use_uv']:
+ cwt_out_dims = cwt_out_dims + 1
+ self.cwt_predictor = nn.Sequential(
+ nn.Linear(self.hidden_size, h),
+ PitchPredictor(
+ h,
+ n_chans=predictor_hidden,
+ n_layers=hparams['predictor_layers'],
+ dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims,
+ padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']))
+ self.cwt_stats_layers = nn.Sequential(
+ nn.Linear(self.hidden_size, h), nn.ReLU(),
+ nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2)
+ )
+ else:
+ self.pitch_predictor = PitchPredictor(
+ self.hidden_size,
+ n_chans=predictor_hidden,
+ n_layers=hparams['predictor_layers'],
+ dropout_rate=hparams['predictor_dropout'],
+ odim=2 if hparams['pitch_type'] == 'frame' else 1,
+ padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
+ if hparams['use_energy_embed']:
+ # Embedding table with 256 entries, indexed by the quantised energy computed in add_energy.
+ self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx)
+ # self.energy_predictor = EnergyPredictor(
+ # self.hidden_size,
+ # n_chans=predictor_hidden,
+ # n_layers=hparams['predictor_layers'],
+ # dropout_rate=hparams['predictor_dropout'], odim=1,
+ # padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
+
+ # def build_embedding(self, dictionary, embed_dim):
+ # num_embeddings = len(dictionary)
+ # emb = Embedding(num_embeddings, embed_dim, self.padding_idx)
+ # return emb
+
+ def forward(self, hubert, mel2ph=None, spk_embed=None,
+ ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=True,
+ spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
+ """Encode `hubert`, expand it with `mel2ph`, add pitch/energy embeddings and optionally decode.
+
+ :param hubert: input features; a position whose last-dimension values are all 0 is treated as padding
+ :param mel2ph: index tensor used to gather encoder frames along dim 1, with 0 marking padding.
+ It defaults to None but is indexed unconditionally below, so it must be provided.
+ :param spk_embed: speaker embedding or id, only read when hparams enable speaker conditioning
+ :param ref_mels: not referenced in this method
+ :param f0, uv: pitch inputs forwarded to `add_pitch` when hparams['use_pitch_embed'] is set
+ :param energy: forwarded to `add_energy` when hparams['use_energy_embed'] is set
+ :param skip_decoder: when True, `ret` is returned without running the decoder
+ :param infer: forwarded to `run_decoder`
+ :return: dict holding 'mel2ph', 'decoder_inp', the entries written by `add_pitch` /
+ `add_energy`, and 'mel_out' when the decoder runs
+ """
+ ret = {}
+ if not hparams['no_fs2'] if 'no_fs2' in hparams.keys() else True:
+ encoder_out =self.encoder(hubert) # [B, T, C]
+ else:
+ encoder_out =hubert
+ # True where at least one feature of the frame is non-zero.
+ src_nonpadding = (hubert!=0).any(-1)[:,:,None]
+
+ # add ref style embed
+ # Not implemented
+ # variance encoder
+ var_embed = 0
+
+ # encoder_out_dur denotes encoder outputs for duration predictor
+ # in speech adaptation, duration predictor use old speaker embedding
+ # NOTE(review): self.spk_embed_proj, self.spk_embed_dur and self.spk_embed_f0 are only created in
+ # the commented-out part of __init__, so the two speaker branches below would raise AttributeError
+ # unless a subclass defines them -- confirm these hparams are always disabled.
+ if hparams['use_spk_embed']:
+ spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
+ elif hparams['use_spk_id']:
+ spk_embed_id = spk_embed
+ if spk_embed_dur_id is None:
+ spk_embed_dur_id = spk_embed_id
+ if spk_embed_f0_id is None:
+ spk_embed_f0_id = spk_embed_id
+ spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
+ spk_embed_dur = spk_embed_f0 = spk_embed
+ if hparams['use_split_spk_id']:
+ spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
+ spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
+ else:
+ spk_embed_dur = spk_embed_f0 = spk_embed = 0
+
+ # add dur
+ # dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
+
+ # mel2ph = self.add_dur(dur_inp, mel2ph, hubert, ret)
+ ret['mel2ph'] = mel2ph
+
+ # Prepend one zero frame along dim 1 so that index 0 in mel2ph gathers zeros.
+ decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
+
+ # Repeat the indices across the feature dimension so they can drive torch.gather.
+ mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
+ decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
+
+ tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
+
+ # add pitch and energy embed
+ pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
+ if hparams['use_pitch_embed']:
+ pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
+ decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
+ if hparams['use_energy_embed']:
+ decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
+
+ ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
+ if not hparams['no_fs2'] if 'no_fs2' in hparams.keys() else True:
+ if skip_decoder:
+ return ret
+ ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
+
+ return ret
+
+ def add_dur(self, dur_input, mel2ph, hubert, ret):
+ """Predict durations and, when `mel2ph` is None, derive it with the length regulator.
+
+ NOTE(review): self.dur_predictor and self.length_regulator are only created in commented-out
+ code in __init__, and the only call to this method in forward is commented out as well.
+ """
+ src_padding = (hubert==0).all(-1)
+ # Scale the gradient flowing back through dur_input by hparams['predictor_grad'].
+ dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())
+ if mel2ph is None:
+ dur, xs = self.dur_predictor.inference(dur_input, src_padding)
+ ret['dur'] = xs
+ ret['dur_choice'] = dur
+ mel2ph = self.length_regulator(dur, src_padding).detach()
+ else:
+ ret['dur'] = self.dur_predictor(dur_input, src_padding)
+ ret['mel2ph'] = mel2ph
+ return mel2ph
+
+ def run_decoder(self, decoder_inp, tgt_nonpadding, ret, infer, **kwargs):
+ """Run the decoder and the output projection, then zero the padded frames.
+
+ `ret`, `infer` and `kwargs` are accepted but not used in the body.
+ """
+ x = decoder_inp # [B, T, H]
+ x = self.decoder(x)
+ x = self.mel_out(x)
+ return x * tgt_nonpadding
+
+ def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
+ """Convert a CWT spectrum to f0, pad it to the length of `mel2ph`, and normalise it."""
+ f0 = cwt2f0(cwt_spec, mean, std, hparams['cwt_scales'])
+ # Repeat the last f0 frame until dim 1 matches mel2ph; nothing is appended when it is already long enough.
+ f0 = torch.cat(
+ [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1)
+ f0_norm = norm_f0(f0, None, hparams)
+ return f0_norm
+
+ def out2mel(self, out):
+ """Identity mapping: returns `out` unchanged."""
+ return out
+
+ def add_pitch(self,decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
+ """Turn the supplied `f0` / `uv` into a pitch embedding.
+
+ Writes ret['f0_denorm'] and ret['pitch_pred'] and returns `self.pitch_embed(pitch)`.
+ `encoder_out` is only referenced by the commented-out code below.
+ """
+ # if hparams['pitch_type'] == 'ph':
+ # pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach())
+ # pitch_padding = (encoder_out.sum().abs() == 0)
+ # ret['pitch_pred'] = pitch_pred = self.pitch_predictor(pitch_pred_inp)
+ # if f0 is None:
+ # f0 = pitch_pred[:, :, 0]
+ # ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding)
+ # pitch = f0_to_coarse(f0_denorm) # start from 0 [B, T_txt]
+ # pitch = F.pad(pitch, [1, 0])
+ # pitch = torch.gather(pitch, 1, mel2ph) # [B, T_mel]
+ # pitch_embedding = pitch_embed(pitch)
+ # return pitch_embedding
+
+ # NOTE(review): the rescaled decoder_inp is not read again by any active line of this method.
+ decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
+
+ pitch_padding = (mel2ph == 0)
+
+ # if hparams['pitch_type'] == 'cwt':
+ # # NOTE: this part of script is *isolated* from other scripts, which means
+ # # it may not be compatible with the current version.
+ # pass
+ # # pitch_padding = None
+ # # ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp)
+ # # stats_out = self.cwt_stats_layers(encoder_out[:, 0, :]) # [B, 2]
+ # # mean = ret['f0_mean'] = stats_out[:, 0]
+ # # std = ret['f0_std'] = stats_out[:, 1]
+ # # cwt_spec = cwt_out[:, :, :10]
+ # # if f0 is None:
+ # # std = std * hparams['cwt_std_scale']
+ # # f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
+ # # if hparams['use_uv']:
+ # # assert cwt_out.shape[-1] == 11
+ # # uv = cwt_out[:, :, -1] > 0
+ # elif hparams['pitch_ar']:
+ # ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if is_training else None)
+ # if f0 is None:
+ # f0 = pitch_pred[:, :, 0]
+ # else:
+ #ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
+ # if f0 is None:
+ # f0 = pitch_pred[:, :, 0]
+ # if hparams['use_uv'] and uv is None:
+ # uv = pitch_pred[:, :, 1] > 0
+ ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
+ # pitch_padding was assigned a tensor above, so this check always passes here.
+ # The assignment writes zeros into the caller's f0 tensor in place.
+ if pitch_padding is not None:
+ f0[pitch_padding] = 0
+
+ pitch = f0_to_coarse(f0_denorm,hparams) # start from 0
+ ret['pitch_pred']=pitch.unsqueeze(-1)
+ # print(ret['pitch_pred'].shape)
+ # print(pitch.shape)
+ pitch_embedding = self.pitch_embed(pitch)
+ return pitch_embedding
+
+ def add_energy(self,decoder_inp, energy, ret):
+ """Quantise the supplied `energy` and return its embedding.
+
+ ret['energy_pred'] stores the `energy` argument as given; the energy predictor is commented out.
+ """
+ # NOTE(review): the rescaled decoder_inp is not read again by any active line of this method.
+ decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
+ ret['energy_pred'] = energy#energy_pred = self.energy_predictor(decoder_inp)[:, :, 0]
+ # if energy is None:
+ # energy = energy_pred
+ # Floor-divide energy * 256 by 4, cap at 255 and cast to long for use as an embedding index.
+ energy = torch.clamp(energy * 256 // 4, max=255).long() # energy_to_coarse
+ energy_embedding = self.energy_embed(energy)
+ return energy_embedding
+
+ @staticmethod
+ def mel_norm(x):
+ """Affine map sending -5.5 to -1 and 0.8 to 1 (inverse of `mel_denorm`)."""
+ return (x + 5.5) / (6.3 / 2) - 1
+
+ @staticmethod
+ def mel_denorm(x):
+ """Affine map sending -1 to -5.5 and 1 to 0.8 (inverse of `mel_norm`)."""
+ return (x + 1) * (6.3 / 2) - 5.5
diff --git a/modules/fastspeech/pe.py b/modules/fastspeech/pe.py
new file mode 100644
index 0000000000000000000000000000000000000000..da0d46e3446bbf45d8ee3682edcaf0d8d64dcdfb
--- /dev/null
+++ b/modules/fastspeech/pe.py
@@ -0,0 +1,149 @@
+from modules.commons.common_layers import *
+from utils.hparams import hparams
+from modules.fastspeech.tts_modules import PitchPredictor
+from utils.pitch_utils import denorm_f0
+
+
+class Prenet(nn.Module):
+ """Stack of `n_layers` Conv1d -> ReLU -> BatchNorm1d blocks followed by a linear projection."""
+ def __init__(self, in_dim=80, out_dim=256, kernel=5, n_layers=3, strides=None):
+ """
+ :param in_dim: input channels of the first convolution
+ :param out_dim: channels of every convolution output and of `out_proj`
+ :param kernel: convolution kernel size; padding is `kernel // 2`
+ :param n_layers: number of convolution blocks
+ :param strides: per-layer stride list; defaults to 1 for every layer
+ """
+ super(Prenet, self).__init__()
+ padding = kernel // 2
+ self.layers = []
+ self.strides = strides if strides is not None else [1] * n_layers
+ for l in range(n_layers):
+ self.layers.append(nn.Sequential(
+ nn.Conv1d(in_dim, out_dim, kernel_size=kernel, padding=padding, stride=self.strides[l]),
+ nn.ReLU(),
+ nn.BatchNorm1d(out_dim)
+ ))
+ in_dim = out_dim
+ # Wrap the plain list in a ModuleList so the layers are registered as sub-modules.
+ self.layers = nn.ModuleList(self.layers)
+ self.out_proj = nn.Linear(out_dim, out_dim)
+
+ def forward(self, x):
+ """
+ Run the convolution stack on `x`, masking padded frames after every layer.
+
+ :param x: [B, T, 80]
+ :return: [L, B, T, H], [B, T, H]
+ """
+ # padding_mask = x.abs().sum(-1).eq(0).data # [B, T]
+ # A frame is treated as padding when the absolute values over its last dimension sum to 0.
+ padding_mask = x.abs().sum(-1).eq(0).detach()
+ nonpadding_mask_TB = 1 - padding_mask.float()[:, None, :] # [B, 1, T]
+ x = x.transpose(1, 2)
+ hiddens = []
+ for i, l in enumerate(self.layers):
+ # Subsample the mask with the layer's stride so it keeps matching the conv output length
+ # (this holds for odd kernel sizes with padding = kernel // 2).
+ nonpadding_mask_TB = nonpadding_mask_TB[:, :, ::self.strides[i]]
+ x = l(x) * nonpadding_mask_TB
+ # NOTE(review): confirm this append is meant to run once per layer; the [L, ...] shape
+ # comments below imply one entry per layer.
+ hiddens.append(x)
+ hiddens = torch.stack(hiddens, 0) # [L, B, H, T]
+ hiddens = hiddens.transpose(2, 3) # [L, B, T, H]
+ x = self.out_proj(x.transpose(1, 2)) # [B, T, H]
+ x = x * nonpadding_mask_TB.transpose(1, 2)
+ return hiddens, x
+
+
+class ConvBlock(nn.Module):
+    """Convolution followed by optional normalisation, ReLU and dropout.
+
+    ``norm`` selects the normalisation: ``'bn'`` (BatchNorm1d), ``'in'``
+    (InstanceNorm1d), ``'gn'`` (GroupNorm with ``n_chans // 16`` groups),
+    ``'ln'`` (LayerNorm over channels) or ``'wn'`` (weight normalisation of
+    the convolution). Any other string, e.g. ``'none'``, disables
+    normalisation.
+
+    Fixes relative to the previous version: the ``'ln'`` branch built
+    ``LayerNorm(n_chans // 16, n_chans)`` (the GroupNorm arguments), and
+    ``forward`` compared ``self.norm`` with ``'ln'`` after it had already been
+    replaced by a module, so the channels-last path never ran.
+    """
+
+    def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0):
+        """
+        :param idim: input channels
+        :param n_chans: output channels
+        :param kernel_size: convolution kernel size
+        :param stride: convolution stride
+        :param norm: normalisation mode, see the class docstring
+        :param dropout: dropout probability applied after the ReLU
+        """
+        super().__init__()
+        self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride)
+        # Remember the requested mode: self.norm is replaced by a module
+        # below, so forward() cannot recover the mode from it.
+        self.norm_type = norm
+        self.norm = norm
+        if norm == 'bn':
+            self.norm = nn.BatchNorm1d(n_chans)
+        elif norm == 'in':
+            self.norm = nn.InstanceNorm1d(n_chans, affine=True)
+        elif norm == 'gn':
+            self.norm = nn.GroupNorm(n_chans // 16, n_chans)
+        elif norm == 'ln':
+            # Normalise over the channel axis; forward() moves channels last first.
+            self.norm = LayerNorm(n_chans)
+        elif norm == 'wn':
+            self.conv = torch.nn.utils.weight_norm(self.conv.conv)
+        self.dropout = nn.Dropout(dropout)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        """
+        Apply conv -> norm (if any) -> ReLU -> dropout.
+
+        :param x: [B, C, T]
+        :return: [B, C, T]
+        """
+        x = self.conv(x)
+        # self.norm is still a string for 'wn', 'none' and unknown modes;
+        # in those cases there is no normalisation module to apply.
+        if not isinstance(self.norm, str):
+            if self.norm_type == 'ln':
+                # LayerNorm works on the last axis, so put channels there.
+                x = self.norm(x.transpose(1, 2)).transpose(1, 2)
+            else:
+                x = self.norm(x)
+        x = self.relu(x)
+        x = self.dropout(x)
+        return x
+
+
+class ConvStacks(nn.Module):
+ """Linear input projection, a stack of `ConvBlock`s (optionally residual) and a linear output projection."""
+ def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn',
+ dropout=0, strides=None, res=True):
+ """
+ :param idim: input feature size
+ :param n_layers: number of ConvBlocks
+ :param n_chans: channel count used by every ConvBlock
+ :param odim: output feature size
+ :param kernel_size: kernel size passed to every ConvBlock
+ :param norm: normalisation mode passed to every ConvBlock
+ :param dropout: dropout probability passed to every ConvBlock
+ :param strides: per-layer stride list (must have `n_layers` entries); defaults to all 1
+ :param res: when true, each block's output is added to its input
+ """
+ super().__init__()
+ self.conv = torch.nn.ModuleList()
+ self.kernel_size = kernel_size
+ self.res = res
+ self.in_proj = Linear(idim, n_chans)
+ if strides is None:
+ strides = [1] * n_layers
+ else:
+ assert len(strides) == n_layers
+ for idx in range(n_layers):
+ self.conv.append(ConvBlock(
+ n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout))
+ self.out_proj = Linear(n_chans, odim)
+
+ def forward(self, x, return_hiddens=False):
+ """
+ Project, run the conv stack and project again.
+
+ :param x: [B, T, H]
+ :param return_hiddens: when true, also return the stacked per-layer activations
+ :return: [B, T, H]
+ """
+ x = self.in_proj(x)
+ x = x.transpose(1, -1) # (B, idim, Tmax)
+ hiddens = []
+ for f in self.conv:
+ x_ = f(x)
+ # Parsed as `(x + x_) if self.res else x_`.
+ # NOTE(review): the residual sum needs f(x) to keep the shape of x, so strides other than 1
+ # presumably cannot be combined with res=True -- confirm.
+ x = x + x_ if self.res else x_ # (B, C, Tmax)
+ # NOTE(review): confirm this append runs once per layer, as the [B, L, C, T] comment below implies.
+ hiddens.append(x)
+ x = x.transpose(1, -1)
+ x = self.out_proj(x) # (B, Tmax, H)
+ if return_hiddens:
+ hiddens = torch.stack(hiddens, 1) # [B, L, C, T]
+ return x, hiddens
+ return x
+
+
+class PitchExtractor(nn.Module):
+    """Predict pitch from a mel input: Prenet -> optional ConvStacks -> PitchPredictor."""
+
+    def __init__(self, n_mel_bins=80, conv_layers=2):
+        """
+        :param n_mel_bins: feature size of the mel input given to the Prenet
+        :param conv_layers: number of ConvStacks layers; the encoder is skipped when this is 0 or less
+        """
+        super().__init__()
+        hidden = hparams['hidden_size']
+        predictor_hidden = hparams['predictor_hidden']
+        self.hidden_size = hidden
+        # Fall back to the model width when no positive predictor width is configured.
+        self.predictor_hidden = predictor_hidden if predictor_hidden > 0 else hidden
+        self.conv_layers = conv_layers
+
+        self.mel_prenet = Prenet(n_mel_bins, hidden, strides=[1, 1, 1])
+        if conv_layers > 0:
+            self.mel_encoder = ConvStacks(
+                idim=hidden, n_chans=hidden, odim=hidden, n_layers=conv_layers)
+        self.pitch_predictor = PitchPredictor(
+            hidden, n_chans=self.predictor_hidden,
+            n_layers=5, dropout_rate=0.1, odim=2,
+            padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])
+
+    def forward(self, mel_input=None):
+        """
+        :param mel_input: mel features; positions whose absolute values sum to 0 over the last
+            dimension are passed to `denorm_f0` as padding
+        :return: dict with 'pitch_pred' (raw predictor output) and 'f0_denorm_pred'
+        """
+        # The Prenet returns (per-layer hiddens, projected output); only the latter is used.
+        _, features = self.mel_prenet(mel_input)
+        if self.conv_layers > 0:
+            features = self.mel_encoder(features)
+
+        pitch_pred = self.pitch_predictor(features)
+        pitch_padding = mel_input.abs().sum(-1) == 0
+
+        # Channel 1 is used as the unvoiced flag only for frame-level pitch; the extra
+        # hparams['use_uv'] condition is disabled.
+        if hparams['pitch_type'] == 'frame':
+            uv = pitch_pred[:, :, 1] > 0
+        else:
+            uv = None
+        f0_denorm = denorm_f0(pitch_pred[:, :, 0], uv, hparams, pitch_padding=pitch_padding)
+        return {'pitch_pred': pitch_pred, 'f0_denorm_pred': f0_denorm}
\ No newline at end of file
diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fdd417d3cc354a2d65f905326f8fe053dea6d97
--- /dev/null
+++ b/modules/fastspeech/tts_modules.py
@@ -0,0 +1,364 @@
+import logging
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from modules.commons.espnet_positional_embedding import RelPositionalEncoding
+from modules.commons.common_layers import SinusoidalPositionalEmbedding, Linear, EncSALayer, DecSALayer, BatchNorm1dTBC
+from utils.hparams import hparams
+
+DEFAULT_MAX_SOURCE_POSITIONS = 2000
+DEFAULT_MAX_TARGET_POSITIONS = 2000
+
+
class TransformerEncoderLayer(nn.Module):
    """Thin wrapper that configures an ``EncSALayer`` from ``hparams``."""

    def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'):
        """
        :param hidden_size: model width passed to ``EncSALayer``
        :param dropout: rate used for both ``dropout`` and ``relu_dropout``
        :param kernel_size: FFN kernel size; defaults to
            ``hparams['enc_ffn_kernel_size']`` when ``None``
        :param num_heads: number of attention heads
        :param norm: normalization selector forwarded to ``EncSALayer``
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.num_heads = num_heads
        if kernel_size is None:
            kernel_size = hparams['enc_ffn_kernel_size']
        self.op = EncSALayer(
            hidden_size,
            num_heads,
            dropout=dropout,
            attention_dropout=0.0,
            relu_dropout=dropout,
            kernel_size=kernel_size,
            padding=hparams['ffn_padding'],
            norm=norm,
            act=hparams['ffn_act'],
        )

    def forward(self, x, **kwargs):
        """Delegate to the wrapped layer, forwarding all keyword arguments."""
        return self.op(x, **kwargs)
+
+
+######################
+# fastspeech modules
+######################
class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization that can act on axis 1 instead of the last axis.

    :param int nout: size of the normalized dimension
    :param int dim: ``-1`` normalizes the last axis directly; any other value
        makes ``forward`` swap axis 1 with the last axis around the call
    """

    def __init__(self, nout, dim=-1):
        """Build the underlying ``torch.nn.LayerNorm`` with ``eps=1e-12``."""
        super().__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Normalize ``x``.

        :param torch.Tensor x: input tensor
        :return: normalized tensor with the same layout as ``x``
        :rtype: torch.Tensor
        """
        if self.dim != -1:
            # Move the channel axis last, normalize, then move it back.
            swapped = x.transpose(1, -1)
            return super().forward(swapped).transpose(1, -1)
        return super().forward(x)
+
+
class DurationPredictor(torch.nn.Module):
    """Duration predictor module.
    This is a module of duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
    The duration predictor predicts a duration of each frame in log domain from the hidden embeddings of encoder.
    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf
    Note:
        The calculation domain of outputs is different between in `forward` and in `inference`. In `forward`,
        the outputs are calculated in log domain but in `inference`, those are calculated in linear domain.
    """

    def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'):
        """Initialize duration predictor module.
        Args:
            idim (int): Input dimension.
            n_layers (int, optional): Number of convolutional layers.
            n_chans (int, optional): Number of channels of convolutional layers.
            kernel_size (int, optional): Kernel size of convolutional layers.
            dropout_rate (float, optional): Dropout rate.
            offset (float, optional): Offset value to avoid nan in log domain.
            padding (str, optional): 'SAME' pads both sides symmetrically; any
                other value pads only on the left by ``kernel_size - 1``.
        Raises:
            ValueError: If ``hparams['dur_loss']`` is not one of
                'mse', 'huber', 'mog' or 'crf'.
        """
        super(DurationPredictor, self).__init__()
        self.offset = offset
        self.conv = torch.nn.ModuleList()
        self.kernel_size = kernel_size
        self.padding = padding
        for idx in range(n_layers):
            in_chans = idim if idx == 0 else n_chans
            self.conv += [torch.nn.Sequential(
                torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2)
                                       if padding == 'SAME'
                                       else (kernel_size - 1, 0), 0),
                torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
                torch.nn.ReLU(),
                LayerNorm(n_chans, dim=1),
                torch.nn.Dropout(dropout_rate)
            )]
        dur_loss = hparams['dur_loss']
        if dur_loss in ['mse', 'huber']:
            odims = 1
        elif dur_loss == 'mog':
            odims = 15
        elif dur_loss == 'crf':
            odims = 32
            from torchcrf import CRF
            self.crf = CRF(odims, batch_first=True)
        else:
            # Previously this fell through and failed later with an unbound ``odims``.
            raise ValueError("Unsupported hparams['dur_loss']: {!r}".format(dur_loss))
        self.linear = torch.nn.Linear(n_chans, odims)

    def _forward(self, xs, x_masks=None, is_inference=False):
        """Shared implementation of ``forward`` and ``inference``.

        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (Tensor, optional): Padding mask (B, Tmax); masked
                positions are zeroed after every layer. ``None`` disables masking.
            is_inference (bool): When true, return ``(durations, raw_output)``.
        """
        xs = xs.transpose(1, -1)  # (B, idim, Tmax)
        # Compute the keep-mask once; it stays None when no mask was given so the
        # documented default ``x_masks=None`` no longer crashes.
        nonpadding = None if x_masks is None else 1 - x_masks.float()
        for f in self.conv:
            xs = f(xs)  # (B, C, Tmax)
            if nonpadding is not None:
                xs = xs * nonpadding[:, None, :]

        xs = self.linear(xs.transpose(1, -1))  # [B, T, C]
        if nonpadding is not None:
            xs = xs * nonpadding[:, :, None]  # (B, T, C)
        if is_inference:
            return self.out2dur(xs), xs
        if hparams['dur_loss'] in ['mse']:
            xs = xs.squeeze(-1)  # (B, Tmax)
        return xs

    def out2dur(self, xs):
        """Convert raw predictor output into integer durations.

        Args:
            xs (Tensor): Output of the final linear layer.
        Returns:
            LongTensor: Non-negative durations on the same device as ``xs``.
        Raises:
            NotImplementedError: For ``dur_loss`` values without a decoding
                rule here ('mog', 'huber', or anything else).
        """
        dur_loss = hparams['dur_loss']
        if dur_loss in ['mse']:
            # NOTE: calculate in log domain
            xs = xs.squeeze(-1)  # (B, Tmax)
            # Clamp to avoid negative durations after removing the offset.
            return torch.clamp(torch.round(xs.exp() - self.offset), min=0).long()
        if dur_loss == 'crf':
            # Keep the result on the input's device instead of forcing CUDA.
            return torch.tensor(self.crf.decode(xs), dtype=torch.long, device=xs.device)
        # The original returned the exception class (or hit an unbound local)
        # instead of raising.
        raise NotImplementedError(
            "out2dur is not implemented for dur_loss={!r}".format(dur_loss))

    def forward(self, xs, x_masks=None):
        """Calculate forward propagation.
        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
        Returns:
            Tensor: Batch of predicted durations in log domain (B, Tmax).
        """
        return self._forward(xs, x_masks, False)

    def inference(self, xs, x_masks=None):
        """Inference duration.
        Args:
            xs (Tensor): Batch of input sequences (B, Tmax, idim).
            x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
        Returns:
            tuple: ``(durations, raw_output)`` where ``durations`` is a
            LongTensor in linear domain (B, Tmax).
        """
        return self._forward(xs, x_masks, True)
+
+
class LengthRegulator(torch.nn.Module):
    """Expand per-token durations into a frame-to-token index map."""

    def __init__(self, pad_value=0.0):
        super().__init__()
        self.pad_value = pad_value

    def forward(self, dur, dur_padding=None, alpha=1.0):
        """
        Worked example without the batch axis:
            dur = [2, 2, 3]
            token ids = 1, 2, 3; span ends = [2, 4, 7]; span starts = [0, 2, 4]
            frame p belongs to token i when start_i <= p < end_i, giving
            mel2ph = [1, 1, 2, 2, 3, 3, 3]
        Frames not covered by any token get 0.

        :param dur: Batch of durations of each frame (B, T_txt)
        :param dur_padding: Batch of padding of each frame (B, T_txt); padded
            tokens get duration 0
        :param alpha: duration rescale coefficient, must be positive
        :return:
            mel2ph (B, T_speech)
        """
        assert alpha > 0
        scaled = torch.round(dur.float() * alpha).long()
        if dur_padding is not None:
            scaled = scaled * (1 - dur_padding.long())

        n_tokens = scaled.shape[1]
        token_ids = torch.arange(1, n_tokens + 1)[None, :, None].to(scaled.device)
        span_end = torch.cumsum(scaled, 1)
        # Shift right by one (drop the last entry) to get each span's start.
        span_start = F.pad(span_end, [1, -1], mode='constant', value=0)

        # The output length is the longest total duration in the batch.
        frame_pos = torch.arange(scaled.sum(-1).max())[None, None].to(scaled.device)
        after_start = frame_pos >= span_start[:, :, None]
        before_end = frame_pos < span_end[:, :, None]
        membership = after_start & before_end
        return (token_ids * membership.long()).sum(1)
+
+
class PitchPredictor(torch.nn.Module):
    """Stack of padded 1-D conv blocks followed by a linear projection."""

    def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5,
                 dropout_rate=0.1, padding='SAME'):
        """Initialize pitch predictor module.
        Args:
            idim (int): Input dimension.
            n_layers (int, optional): Number of convolutional layers.
            n_chans (int, optional): Number of channels of convolutional layers.
            odim (int, optional): Output dimension of the final linear layer.
            kernel_size (int, optional): Kernel size of convolutional layers.
            dropout_rate (float, optional): Dropout rate.
            padding (str, optional): 'SAME' pads both sides symmetrically; any
                other value pads only on the left by ``kernel_size - 1``.
        """
        super().__init__()
        self.conv = torch.nn.ModuleList()
        self.kernel_size = kernel_size
        self.padding = padding

        if padding == 'SAME':
            half = (kernel_size - 1) // 2
            pad_widths = (half, half)
        else:
            pad_widths = (kernel_size - 1, 0)

        for layer_idx in range(n_layers):
            in_chans = n_chans if layer_idx > 0 else idim
            block = torch.nn.Sequential(
                torch.nn.ConstantPad1d(pad_widths, 0),
                torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0),
                torch.nn.ReLU(),
                LayerNorm(n_chans, dim=1),
                torch.nn.Dropout(dropout_rate),
            )
            self.conv.append(block)
        self.linear = torch.nn.Linear(n_chans, odim)
        self.embed_positions = SinusoidalPositionalEmbedding(idim, 0, init_size=4096)
        # Learnable scale applied to the positional embedding.
        self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))

    def forward(self, xs):
        """
        :param xs: [B, T, H]
        :return: output of the final linear layer, i.e. last axis of size ``odim``
        """
        # Positions are derived from the first feature channel of the input.
        pos = self.embed_positions(xs[..., 0])
        xs = xs + self.pos_embed_alpha * pos
        hidden = xs.transpose(1, -1)  # (B, idim, Tmax)
        for block in self.conv:
            hidden = block(hidden)  # (B, C, Tmax)
        # NOTE: calculate in log domain
        return self.linear(hidden.transpose(1, -1))
+
+
class EnergyPredictor(PitchPredictor):
    """Same architecture as ``PitchPredictor``; exists as a distinct name only."""
+
+
def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
    """Count how many frames map to each token.

    :param mel2ph: 2-D integer tensor of 1-based token indices per frame;
        index 0 is counted separately and then discarded
    :param T_txt: number of tokens
    :param max_dur: optional upper bound applied to every duration
    :return: tensor of shape (B, T_txt) with per-token frame counts
    """
    batch_size, _ = mel2ph.shape
    # Slot 0 collects frames with index 0; it is dropped below.
    counts = mel2ph.new_zeros(batch_size, T_txt + 1)
    counts = counts.scatter_add(1, mel2ph, torch.ones_like(mel2ph))
    dur = counts[:, 1:]
    if max_dur is None:
        return dur
    return dur.clamp(max=max_dur)
+
+
class FFTBlocks(nn.Module):
    """A stack of ``TransformerEncoderLayer`` blocks with optional positional
    embedding on the input and optional normalization on the output."""

    def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None, num_heads=2,
                 use_pos_embed=True, use_last_norm=True, norm='ln', use_pos_embed_alpha=True):
        """
        :param hidden_size: model width
        :param num_layers: number of encoder layers
        :param ffn_kernel_size: kernel size forwarded to every layer
        :param dropout: dropout rate; ``hparams['dropout']`` when ``None``
        :param num_heads: attention heads per layer
        :param use_pos_embed: add a sinusoidal positional embedding in ``forward``
        :param use_last_norm: apply a final normalization layer
        :param norm: ``'ln'`` or ``'bn'``, selects the final normalization
        :param use_pos_embed_alpha: make the positional scale learnable
            (otherwise it is the constant 1)
        """
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        embed_dim = hidden_size
        self.dropout = hparams['dropout'] if dropout is None else dropout
        self.use_pos_embed = use_pos_embed
        self.use_last_norm = use_last_norm
        if use_pos_embed:
            self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS
            self.padding_idx = 0
            if use_pos_embed_alpha:
                self.pos_embed_alpha = nn.Parameter(torch.Tensor([1]))
            else:
                self.pos_embed_alpha = 1
            self.embed_positions = SinusoidalPositionalEmbedding(
                embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
            )

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(self.hidden_size, self.dropout,
                                    kernel_size=ffn_kernel_size, num_heads=num_heads)
            for _ in range(self.num_layers)
        ])

        if not self.use_last_norm:
            self.layer_norm = None
        elif norm == 'ln':
            self.layer_norm = nn.LayerNorm(embed_dim)
        elif norm == 'bn':
            self.layer_norm = BatchNorm1dTBC(embed_dim)
        # NOTE: with use_last_norm=True and any other ``norm`` value no
        # ``layer_norm`` attribute is created, exactly as in the original.

    def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False):
        """
        :param x: [B, T, C]
        :param padding_mask: [B, T]; derived from all-zero frames when ``None``
        :param attn_mask: forwarded unchanged to every layer
        :param return_hiddens: return the stacked per-layer outputs instead of
            the final (normalized) output
        :return: [B, T, C] or [L, B, T, C]
        """
        if padding_mask is None:
            padding_mask = x.abs().sum(-1).eq(0).detach()
        keep_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None]  # [T, B, 1]

        if self.use_pos_embed:
            pos = self.embed_positions(x[..., 0])
            x = x + self.pos_embed_alpha * pos
            x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C, zeroing padded frames
        x = x.transpose(0, 1) * keep_TB
        per_layer = []
        for block in self.layers:
            x = block(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * keep_TB
            per_layer.append(x)
        if self.use_last_norm:
            x = self.layer_norm(x) * keep_TB

        if return_hiddens:
            # The stacked hiddens are the pre-norm layer outputs.
            stacked = torch.stack(per_layer, 0)  # [L, T, B, C]
            return stacked.transpose(1, 2)  # [L, B, T, C]
        return x.transpose(0, 1)  # [B, T, C]
+
+
class FastspeechEncoder(FFTBlocks):
    '''
    compared to FFTBlocks:
    - input is [B, T, H], not [B, T, C]
    - supports "relative" positional encoding
    '''

    def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=2):
        """Every ``None`` argument is filled in from ``hparams``."""
        if hidden_size is None:
            hidden_size = hparams['hidden_size']
        if kernel_size is None:
            kernel_size = hparams['enc_ffn_kernel_size']
        if num_layers is None:
            # NOTE(review): the default reads 'dec_layers', not an encoder-specific
            # key; kept as-is — confirm whether this is intended.
            num_layers = hparams['dec_layers']
        # The parent's own positional embedding is disabled; this class adds
        # its own in forward_embedding (use_pos_embed_alpha for compatibility).
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads,
                         use_pos_embed=False)
        self.embed_scale = math.sqrt(hidden_size)
        self.padding_idx = 0
        if hparams.get('rel_pos'):
            self.embed_positions = RelPositionalEncoding(hidden_size, dropout_rate=0.0)
        else:
            self.embed_positions = SinusoidalPositionalEmbedding(
                hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS,
            )

    def forward(self, hubert):
        """
        :param hubert: [B, T, H]
        :return: the output of ``FFTBlocks.forward`` for the embedded input,
            with frames that are zero in every feature treated as padding
        """
        encoder_padding_mask = (hubert == 0).all(-1)
        embedded = self.forward_embedding(hubert)  # [B, T, H]
        return super().forward(embedded, encoder_padding_mask)

    def forward_embedding(self, hubert):
        """Scale the features, optionally add positions, then apply dropout."""
        x = self.embed_scale * hubert
        if hparams['use_pos_embed']:
            # NOTE(review): the full feature tensor is handed to embed_positions
            # here, unlike FFTBlocks which passes x[..., 0] — confirm this matches
            # what the configured positional module expects.
            x = x + self.embed_positions(hubert)
        return F.dropout(x, p=self.dropout, training=self.training)
+
+
class FastspeechDecoder(FFTBlocks):
    """``FFTBlocks`` whose unspecified sizes default to decoder ``hparams``."""

    def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
        """Fill every ``None`` argument from ``hparams`` and build the stack."""
        if num_heads is None:
            num_heads = hparams['num_heads']
        if hidden_size is None:
            hidden_size = hparams['hidden_size']
        if kernel_size is None:
            kernel_size = hparams['dec_ffn_kernel_size']
        if num_layers is None:
            num_layers = hparams['dec_layers']
        super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
+
diff --git a/modules/hifigan/__pycache__/hifigan.cpython-38.pyc b/modules/hifigan/__pycache__/hifigan.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2e8491c1afef1e7e079b27962e8cb06b6efea0b4
Binary files /dev/null and b/modules/hifigan/__pycache__/hifigan.cpython-38.pyc differ
diff --git a/modules/hifigan/hifigan.py b/modules/hifigan/hifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae7e61f56b00d60bcc49a18ece3edbe54746f7ea
--- /dev/null
+++ b/modules/hifigan/hifigan.py
@@ -0,0 +1,365 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+from modules.parallel_wavegan.layers import UpsampleNetwork, ConvInUpsampleNetwork
+from modules.parallel_wavegan.models.source import SourceModuleHnNSF
+import numpy as np
+
+LRELU_SLOPE = 0.1
+
+
def init_weights(m, mean=0.0, std=0.01):
    """Re-draw ``m.weight`` from N(mean, std) when ``m``'s class name contains "Conv".

    Modules of any other class are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
+
+
def apply_weight_norm(m):
    """Wrap ``m`` with weight normalization when its class name contains "Conv"."""
    if "Conv" in m.__class__.__name__:
        weight_norm(m)
+
+
def get_padding(kernel_size, dilation=1):
    """Return ``(kernel_size * dilation - dilation) / 2`` truncated to an int.

    This is the padding used throughout this file for stride-1 dilated convs.
    """
    span = kernel_size * dilation - dilation
    return int(span / 2)
+
+
class ResBlock1(torch.nn.Module):
    """Residual block made of three (dilated conv, plain conv) pairs."""

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        """
        :param h: configuration object, stored on the module
        :param channels: channel count, identical for input and output
        :param kernel_size: kernel size of every convolution
        :param dilation: the first three entries give the dilation of the
            first convolution in each pair
        """
        super().__init__()
        self.h = h

        def make_conv(rate):
            # Stride-1 conv, weight-normalized, padded via get_padding.
            return weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=rate,
                                      padding=get_padding(kernel_size, rate)))

        self.convs1 = nn.ModuleList([make_conv(dilation[i]) for i in range(3)])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        """Apply the three residual pairs in sequence."""
        for dilated, plain in zip(self.convs1, self.convs2):
            branch = dilated(F.leaky_relu(x, LRELU_SLOPE))
            branch = plain(F.leaky_relu(branch, LRELU_SLOPE))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from every convolution in the block."""
        for conv in list(self.convs1) + list(self.convs2):
            remove_weight_norm(conv)
+
+
class ResBlock2(torch.nn.Module):
    """Lighter residual block: two dilated convolutions, one per residual step."""

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        """
        :param h: configuration object, stored on the module
        :param channels: channel count, identical for input and output
        :param kernel_size: kernel size of both convolutions
        :param dilation: the first two entries give the dilation rates
        """
        super().__init__()
        self.h = h
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[i],
                               padding=get_padding(kernel_size, dilation[i])))
            for i in range(2)
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        """Apply each convolution as a residual update."""
        for conv in self.convs:
            branch = conv(F.leaky_relu(x, LRELU_SLOPE))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalization from both convolutions."""
        for conv in self.convs:
            remove_weight_norm(conv)
+
+
class Conv1d1x1(Conv1d):
    """``Conv1d`` fixed to kernel size 1, no padding and dilation 1."""

    def __init__(self, in_channels, out_channels, bias):
        """
        :param in_channels: number of input channels
        :param out_channels: number of output channels
        :param bias: whether the convolution has a bias term
        """
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=1,
            padding=0,
            dilation=1,
            bias=bias,
        )
+
+
class HifiGanGenerator(torch.nn.Module):
    """HiFi-GAN style generator with an optional f0-driven harmonic source.

    The input goes through ``conv_pre``, a sequence of transposed-conv
    upsamplers each followed by a bank of residual blocks, and ``conv_post``.
    """

    def __init__(self, h, c_out=1):
        """
        :param h: mapping holding the keys read below (``upsample_rates``,
            ``upsample_kernel_sizes``, ``upsample_initial_channel``,
            ``resblock``, ``resblock_kernel_sizes``, ``resblock_dilation_sizes``,
            ``use_pitch_embed``, ``audio_sample_rate``)
        :param c_out: number of output channels of ``conv_post``
        """
        super().__init__()
        self.h = h
        upsample_rates = h['upsample_rates']
        self.num_kernels = len(h['resblock_kernel_sizes'])
        self.num_upsamples = len(upsample_rates)
        use_pitch = h['use_pitch_embed']

        if use_pitch:
            self.harmonic_num = 8
            self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
            self.m_source = SourceModuleHnNSF(
                sampling_rate=h['audio_sample_rate'],
                harmonic_num=self.harmonic_num)
            self.noise_convs = nn.ModuleList()
        # The input channel count is fixed at 80.
        self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3))
        resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, h['upsample_kernel_sizes'])):
            # Channel width halves at every upsampling stage.
            c_cur = h['upsample_initial_channel'] // (2 ** (i + 1))
            self.ups.append(weight_norm(
                ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2)))
            if not use_pitch:
                continue
            if i + 1 < len(upsample_rates):
                # Stride equals the product of the remaining upsample rates.
                stride_f0 = np.prod(upsample_rates[i + 1:])
                self.noise_convs.append(Conv1d(
                    1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
            else:
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h['upsample_initial_channel'] // (2 ** (i + 1))
            for k, d in zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes']):
                self.resblocks.append(resblock(h, ch, k, d))

        # ``ch`` is the width of the last upsampling stage.
        self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, f0=None):
        """
        :param x: input features for ``conv_pre``
        :param f0: optional pitch contour; when given, the harmonic source is
            injected after every upsampling stage
        :return: ``tanh`` of the ``conv_post`` output
        """
        use_source = f0 is not None
        if use_source:
            # harmonic-source signal, noise-source signal, uv flag
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2)

        x = self.conv_pre(x)
        for stage in range(self.num_upsamples):
            x = self.ups[stage](F.leaky_relu(x, LRELU_SLOPE))
            if use_source:
                x = x + self.noise_convs[stage](har_source)
            # Average the outputs of this stage's residual blocks.
            first = stage * self.num_kernels
            acc = None
            for offset in range(self.num_kernels):
                y = self.resblocks[first + offset](x)
                acc = y if acc is None else acc + y
            x = acc / self.num_kernels
        # Final activation uses leaky_relu's default slope (LRELU_SLOPE is not passed).
        x = self.conv_post(F.leaky_relu(x))
        return torch.tanh(x)

    def remove_weight_norm(self):
        """Strip weight normalization from all upsamplers, resblocks and pre/post convs."""
        print('Removing weight norm...')
        for up in self.ups:
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
+
+
class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds the 1-D signal into a 2-D (time/period, period)
    map and scores it with a stack of 2-D convolutions."""

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1):
        """
        :param period: fold width used in ``forward``
        :param kernel_size: kernel height of every conv in ``convs``
        :param stride: vertical stride of the first four convs
        :param use_spectral_norm: use spectral norm instead of weight norm
        :param use_cond: condition on a mel input; forces two input channels
        :param c_in: input channels (ignored when ``use_cond`` is set)
        """
        super().__init__()
        self.use_cond = use_cond
        if use_cond:
            from utils.hparams import hparams
            t = hparams['hop_size']
            self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
            c_in = 2

        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        # NOTE(review): paddings below are computed for kernel size 5 regardless
        # of ``kernel_size``; kept exactly as in the original.
        widths = [c_in, 32, 128, 512, 1024]
        strided = [
            norm_f(Conv2d(n_in, n_out, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0)))
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ]
        last = norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0)))
        self.convs = nn.ModuleList(strided + [last])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x, mel):
        """
        :param x: signal tensor with three axes (batch, channels, time)
        :param mel: conditioning input, only used when ``use_cond`` is set
        :return: ``(flattened score, list of feature maps)``
        """
        if self.use_cond:
            x = torch.cat([self.cond_net(mel), x], 1)

        # 1d to 2d: reflect-pad the tail so the length divides the period.
        b, c, t = x.shape
        remainder = t % self.period
        if remainder != 0:
            n_pad = self.period - remainder
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        fmap = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
+
+
class MultiPeriodDiscriminator(torch.nn.Module):
    """Bundle of ``DiscriminatorP`` instances with periods 2, 3, 5, 7 and 11."""

    def __init__(self, use_cond=False, c_in=1):
        super().__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(period, use_cond=use_cond, c_in=c_in)
            for period in (2, 3, 5, 7, 11)
        ])

    def forward(self, y, y_hat, mel=None):
        """Score the real signal ``y`` and the generated signal ``y_hat``.

        :return: ``(real scores, generated scores, real feature maps,
            generated feature maps)``, one list entry per discriminator
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for disc in self.discriminators:
            score_r, fmap_r = disc(y, mel)
            score_g, fmap_g = disc(y_hat, mel)
            real_scores.append(score_r)
            real_fmaps.append(fmap_r)
            fake_scores.append(score_g)
            fake_fmaps.append(fmap_g)
        return real_scores, fake_scores, real_fmaps, fake_fmaps
+
+
class DiscriminatorS(torch.nn.Module):
    """Scale discriminator: a stack of (grouped) 1-D convolutions over the signal."""

    def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1):
        """
        :param use_spectral_norm: use spectral norm instead of weight norm
        :param use_cond: condition on a mel input; forces two input channels
        :param upsample_rates: factors whose product sets the stride of the
            conditioning transposed conv (only read when ``use_cond`` is set)
        :param c_in: input channels (ignored when ``use_cond`` is set)
        """
        super().__init__()
        self.use_cond = use_cond
        if use_cond:
            t = np.prod(upsample_rates)
            self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
            c_in = 2
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        # (in_channels, out_channels, kernel, stride, groups, padding)
        layer_specs = [
            (c_in, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList([
            norm_f(Conv1d(n_in, n_out, kernel, stride, groups=groups, padding=pad))
            for n_in, n_out, kernel, stride, groups, pad in layer_specs
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x, mel):
        """
        :param x: signal tensor
        :param mel: conditioning input, only used when ``use_cond`` is set
        :return: ``(flattened score, list of feature maps)``
        """
        if self.use_cond:
            x = torch.cat([self.cond_net(mel), x], 1)
        fmap = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
+
+
class MultiScaleDiscriminator(torch.nn.Module):
    """Three ``DiscriminatorS`` instances applied to progressively
    average-pooled versions of the signal."""

    def __init__(self, use_cond=False, c_in=1):
        super().__init__()
        from utils.hparams import hparams
        hop_size = hparams['hop_size']
        # (use_spectral_norm, divisor applied to hop_size for the last upsample rate)
        variants = [(True, 16), (False, 32), (False, 64)]
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=spectral, use_cond=use_cond,
                           upsample_rates=[4, 4, hop_size // divisor],
                           c_in=c_in)
            for spectral, divisor in variants
        ])
        self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=1) for _ in range(2)])

    def forward(self, y, y_hat, mel=None):
        """Score the real signal ``y`` and the generated signal ``y_hat``.

        Both signals are pooled once more before each discriminator after the
        first; ``mel`` is passed to every discriminator unchanged.

        :return: ``(real scores, generated scores, real feature maps,
            generated feature maps)``, one list entry per discriminator
        """
        real_scores, fake_scores = [], []
        real_fmaps, fake_fmaps = [], []
        for idx, disc in enumerate(self.discriminators):
            if idx > 0:
                pool = self.meanpools[idx - 1]
                y = pool(y)
                y_hat = pool(y_hat)
            score_r, fmap_r = disc(y, mel)
            score_g, fmap_g = disc(y_hat, mel)
            real_scores.append(score_r)
            real_fmaps.append(fmap_r)
            fake_scores.append(score_g)
            fake_fmaps.append(fmap_g)
        return real_scores, fake_scores, real_fmaps, fake_fmaps
+
+
def feature_loss(fmap_r, fmap_g):
    """Feature-matching loss.

    Sums the mean absolute difference of every paired feature map across all
    discriminators and doubles the total. Returns the integer 0 when there
    are no feature maps.
    """
    total = sum(
        (torch.mean(torch.abs(real - fake))
         for real_maps, fake_maps in zip(fmap_r, fmap_g)
         for real, fake in zip(real_maps, fake_maps)),
        0,
    )
    return total * 2
+
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss.

    :return: ``(real_loss, generated_loss)``; the real term is the mean of
        ``(1 - d_real)^2`` and the generated term the mean of ``d_gen^2``,
        both averaged over ``len(disc_real_outputs)``
    """
    real_total = 0
    fake_total = 0
    for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs):
        real_total = real_total + torch.mean((1 - real_out) ** 2)
        fake_total = fake_total + torch.mean(fake_out ** 2)
    n_discs = len(disc_real_outputs)
    return real_total / n_discs, fake_total / n_discs
+
+
def cond_discriminator_loss(outputs):
    """Average over discriminators of the mean squared output."""
    total = 0
    for out in outputs:
        total = total + torch.mean(out ** 2)
    return total / len(outputs)
+
+
def generator_loss(disc_outputs):
    """Least-squares generator loss: average over discriminators of the mean
    of ``(1 - d_gen)^2``."""
    total = 0
    for out in disc_outputs:
        total = total + torch.mean((1 - out) ** 2)
    return total / len(disc_outputs)
diff --git a/modules/hifigan/mel_utils.py b/modules/hifigan/mel_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..06e0f7d4d16fa3e4aefc8949347455f5a6e938da
--- /dev/null
+++ b/modules/hifigan/mel_utils.py
@@ -0,0 +1,80 @@
+import numpy as np
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+from scipy.io.wavfile import read
+
+MAX_WAV_VALUE = 32768.0
+
+
def load_wav(full_path):
    """Read a WAV file with ``scipy.io.wavfile.read``.

    :return: ``(data, sampling_rate)`` — note the order is swapped relative
        to what scipy returns
    """
    loaded = read(full_path)
    return loaded[1], loaded[0]
+
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress ``x`` (NumPy): ``log(max(x, clip_val) * C)``."""
    floored = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(floored * C)
+
+
def dynamic_range_decompression(x, C=1):
    """Invert the NumPy log compression: ``exp(x) / C``."""
    expanded = np.exp(x)
    return expanded / C
+
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress ``x`` (torch): ``log(clamp(x, min=clip_val) * C)``."""
    floored = torch.clamp(x, min=clip_val)
    return torch.log(floored * C)
+
+
def dynamic_range_decompression_torch(x, C=1):
    """Invert the torch log compression: ``exp(x) / C``."""
    expanded = torch.exp(x)
    return expanded / C
+
+
def spectral_normalize_torch(magnitudes):
    """Apply ``dynamic_range_compression_torch`` with its default arguments."""
    return dynamic_range_compression_torch(magnitudes)
+
+
def spectral_de_normalize_torch(magnitudes):
    """Apply ``dynamic_range_decompression_torch`` with its default arguments."""
    return dynamic_range_decompression_torch(magnitudes)
+
+
+mel_basis = {}
+hann_window = {}
+
+
def mel_spectrogram(y, hparams, center=False, complex=False):
    """Compute a log-mel spectrogram (or the raw STFT) of a batch of waveforms.

    Reads ``fft_size``, ``audio_num_mel_bins``, ``audio_sample_rate``,
    ``hop_size``, ``win_size``, ``fmin`` and ``fmax`` from ``hparams``.
    The mel filterbank and the Hann window are cached in the module-level
    ``mel_basis`` / ``hann_window`` dicts.

    :param y: waveform tensor; values are clamped to [-1, 1] before analysis
    :param hparams: mapping with the keys listed above
    :param center: forwarded to ``torch.stft``
    :param complex: when true, return the STFT as real/imaginary pairs with
        the frame axis moved before the frequency axis; otherwise return the
        log-compressed mel spectrogram
    :return: the mel spectrogram, or the transposed real-view STFT
    """
    n_fft = hparams['fft_size']
    num_mels = hparams['audio_num_mel_bins']
    sampling_rate = hparams['audio_sample_rate']
    hop_size = hparams['hop_size']
    win_size = hparams['win_size']
    fmin = hparams['fmin']
    fmax = hparams['fmax']
    y = y.clamp(min=-1., max=1.)

    # Cache keys cover every value the cached tensor depends on. The original
    # tested ``fmax`` against keys of the form "<fmax>_<device>", so the cache
    # never hit and both tensors were rebuilt on every call.
    device_tag = str(y.device)
    mel_key = '_'.join(str(v) for v in (sampling_rate, n_fft, num_mels, fmin, fmax, device_tag))
    window_key = str(win_size) + '_' + device_tag
    if mel_key not in mel_basis:
        # Keyword arguments work on both old and new librosa releases
        # (positional arguments are rejected by newer ones).
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(y.device)
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_size).to(y.device)

    # Reflect-pad manually so the frame count matches hop_size with center=False.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)

    # Request a complex result explicitly (the implicit real-valued return is
    # deprecated) and view it as (..., 2) real/imag pairs as before.
    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[window_key],
                      center=center, pad_mode='reflect', normalized=False, onesided=True,
                      return_complex=True)
    spec = torch.view_as_real(spec)

    if not complex:
        # Magnitude with a small epsilon for numerical stability of sqrt.
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
        spec = torch.matmul(mel_basis[mel_key], spec)
        spec = spectral_normalize_torch(spec)
    else:
        spec = spec.transpose(1, 2)  # [B, T, n_fft, 2]
    return spec
diff --git a/modules/nsf_hifigan/__pycache__/env.cpython-38.pyc b/modules/nsf_hifigan/__pycache__/env.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e221d3323729f3a79720201e7e1601fba9cdbf32
Binary files /dev/null and b/modules/nsf_hifigan/__pycache__/env.cpython-38.pyc differ
diff --git a/modules/nsf_hifigan/__pycache__/models.cpython-38.pyc b/modules/nsf_hifigan/__pycache__/models.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d79fb29e052aac1f1d9961e2bf745385947aa88
Binary files /dev/null and b/modules/nsf_hifigan/__pycache__/models.cpython-38.pyc differ
diff --git a/modules/nsf_hifigan/__pycache__/nvSTFT.cpython-38.pyc b/modules/nsf_hifigan/__pycache__/nvSTFT.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0e19b55b614866765e2de2048952b21780cab49f
Binary files /dev/null and b/modules/nsf_hifigan/__pycache__/nvSTFT.cpython-38.pyc differ
diff --git a/modules/nsf_hifigan/__pycache__/utils.cpython-38.pyc b/modules/nsf_hifigan/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..39e546b658232b77a1250b0e54fee8a7a56e070f
Binary files /dev/null and b/modules/nsf_hifigan/__pycache__/utils.cpython-38.pyc differ
diff --git a/modules/nsf_hifigan/env.py b/modules/nsf_hifigan/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..02a2739a3eb25764527a14b46347ee386b0241b0
--- /dev/null
+++ b/modules/nsf_hifigan/env.py
@@ -0,0 +1,15 @@
+import os
+import shutil
+
+
class AttrDict(dict):
    """Dictionary whose items are also reachable as attributes.

    The instance ``__dict__`` is the dictionary itself, so attribute writes
    and item writes are the same operation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self
+
+
def build_env(config, config_name, path):
    """Copy the config file into ``path`` under the name ``config_name``.

    Nothing happens when ``config`` already equals the target path; otherwise
    ``path`` is created if needed and the file is copied.
    """
    target = os.path.join(path, config_name)
    if config == target:
        return
    os.makedirs(path, exist_ok=True)
    shutil.copyfile(config, target)
\ No newline at end of file
diff --git a/modules/nsf_hifigan/models.py b/modules/nsf_hifigan/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe065e2cab257e48a49ed22e53e52035c0d38b80
--- /dev/null
+++ b/modules/nsf_hifigan/models.py
@@ -0,0 +1,549 @@
+import os
+import json
+from .env import AttrDict
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+from .utils import init_weights, get_padding
+
+LRELU_SLOPE = 0.1
+
+def load_model(model_path, device='cuda'):
+ config_file = os.path.join(os.path.split(model_path)[0], 'config.json')
+ with open(config_file) as f:
+ data = f.read()
+
+ global h
+ json_config = json.loads(data)
+ h = AttrDict(json_config)
+
+ generator = Generator(h).to(device)
+
+ cp_dict = torch.load(model_path, map_location='cpu')
+ generator.load_state_dict(cp_dict['generator'])
+ generator.eval()
+ generator.remove_weight_norm()
+ del cp_dict
+ return generator, h
+
+
+class ResBlock1(torch.nn.Module):
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super(ResBlock1, self).__init__()
+ self.h = h
+ self.convs1 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2])))
+ ])
+ self.convs1.apply(init_weights)
+
+ self.convs2 = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+ padding=get_padding(kernel_size, 1)))
+ ])
+ self.convs2.apply(init_weights)
+
+ def forward(self, x):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c1(xt)
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs1:
+ remove_weight_norm(l)
+ for l in self.convs2:
+ remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+ super(ResBlock2, self).__init__()
+ self.h = h
+ self.convs = nn.ModuleList([
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]))),
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1])))
+ ])
+ self.convs.apply(init_weights)
+
+ def forward(self, x):
+ for c in self.convs:
+ xt = F.leaky_relu(x, LRELU_SLOPE)
+ xt = c(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for l in self.convs:
+ remove_weight_norm(l)
+
+
+class Generator(torch.nn.Module):
+ def __init__(self, h):
+ super(Generator, self).__init__()
+ self.h = h
+ self.num_kernels = len(h.resblock_kernel_sizes)
+ self.num_upsamples = len(h.upsample_rates)
+ self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
+ resblock = ResBlock1 if h.resblock == '1' else ResBlock2
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+ self.ups.append(weight_norm(
+ ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
+ k, u, padding=(k-u)//2)))
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = h.upsample_initial_channel//(2**(i+1))
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+ self.resblocks.append(resblock(h, ch, k, d))
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x):
+ x = self.conv_pre(x)
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ x = self.ups[i](x)
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i*self.num_kernels+j](x)
+ else:
+ xs += self.resblocks[i*self.num_kernels+j](x)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+class SineGen(torch.nn.Module):
+ """ Definition of sine generator
+ SineGen(samp_rate, harmonic_num = 0,
+ sine_amp = 0.1, noise_std = 0.003,
+ voiced_threshold = 0,
+ flag_for_pulse=False)
+ samp_rate: sampling rate in Hz
+ harmonic_num: number of harmonic overtones (default 0)
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
+ noise_std: std of Gaussian noise (default 0.003)
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
+ Note: when flag_for_pulse is True, the first time step of a voiced
+ segment is always sin(np.pi) or cos(0)
+ """
+
+ def __init__(self, samp_rate, harmonic_num=0,
+ sine_amp=0.1, noise_std=0.003,
+ voiced_threshold=0,
+ flag_for_pulse=False):
+ super(SineGen, self).__init__()
+ self.sine_amp = sine_amp
+ self.noise_std = noise_std
+ self.harmonic_num = harmonic_num
+ self.dim = self.harmonic_num + 1
+ self.sampling_rate = samp_rate
+ self.voiced_threshold = voiced_threshold
+ self.flag_for_pulse = flag_for_pulse
+
+ def _f02uv(self, f0):
+ # generate uv signal
+ uv = torch.ones_like(f0)
+ uv = uv * (f0 > self.voiced_threshold)
+ return uv
+
+ def _f02sine(self, f0_values):
+ """ f0_values: (batchsize, length, dim)
+ where dim indicates fundamental tone and overtones
+ """
+ # convert to F0 in rad. The interger part n can be ignored
+ # because 2 * np.pi * n doesn't affect phase
+ rad_values = (f0_values / self.sampling_rate) % 1
+
+ # initial phase noise (no noise for fundamental component)
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
+ device=f0_values.device)
+ rand_ini[:, 0] = 0
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+ # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
+ if not self.flag_for_pulse:
+ # for normal case
+
+ # To prevent torch.cumsum numerical overflow,
+ # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
+ # Buffer tmp_over_one_idx indicates the time step to add -1.
+ # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
+ tmp_over_one = torch.cumsum(rad_values, 1) % 1
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] -
+ tmp_over_one[:, :-1, :]) < 0
+ cumsum_shift = torch.zeros_like(rad_values)
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+ sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1)
+ * 2 * np.pi)
+ else:
+ # If necessary, make sure that the first time step of every
+ # voiced segments is sin(pi) or cos(0)
+ # This is used for pulse-train generation
+
+ # identify the last time step in unvoiced segments
+ uv = self._f02uv(f0_values)
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
+ uv_1[:, -1, :] = 1
+ u_loc = (uv < 1) * (uv_1 > 0)
+
+ # get the instantanouse phase
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
+ # different batch needs to be processed differently
+ for idx in range(f0_values.shape[0]):
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+ # stores the accumulation of i.phase within
+ # each voiced segments
+ tmp_cumsum[idx, :, :] = 0
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
+ # within the previous voiced segment.
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+ # get the sines
+ sines = torch.cos(i_phase * 2 * np.pi)
+ return sines
+
+ def forward(self, f0):
+ """ sine_tensor, uv = forward(f0)
+ input F0: tensor(batchsize=1, length, dim=1)
+ f0 for unvoiced steps should be 0
+ output sine_tensor: tensor(batchsize=1, length, dim)
+ output uv: tensor(batchsize=1, length, 1)
+ """
+ with torch.no_grad():
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
+ device=f0.device)
+ # fundamental component
+ f0_buf[:, :, 0] = f0[:, :, 0]
+ for idx in np.arange(self.harmonic_num):
+ # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
+
+ # generate sine waveforms
+ sine_waves = self._f02sine(f0_buf) * self.sine_amp
+
+ # generate uv signal
+ # uv = torch.ones(f0.shape)
+ # uv = uv * (f0 > self.voiced_threshold)
+ uv = self._f02uv(f0)
+
+ # noise: for unvoiced should be similar to sine_amp
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
+ # . for voiced regions is self.noise_std
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+ noise = noise_amp * torch.randn_like(sine_waves)
+
+ # first: set the unvoiced part to 0 by uv
+ # then: additive noise
+ sine_waves = sine_waves * uv + noise
+ return sine_waves, uv, noise
+class SourceModuleHnNSF(torch.nn.Module):
+ """ SourceModule for hn-nsf
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0)
+ sampling_rate: sampling_rate in Hz
+ harmonic_num: number of harmonic above F0 (default: 0)
+ sine_amp: amplitude of sine source signal (default: 0.1)
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
+ note that amplitude of noise in unvoiced is decided
+ by sine_amp
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ uv (batchsize, length, 1)
+ """
+
+ def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0):
+ super(SourceModuleHnNSF, self).__init__()
+
+ self.sine_amp = sine_amp
+ self.noise_std = add_noise_std
+
+ # to produce sine waveforms
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
+ sine_amp, add_noise_std, voiced_threshod)
+
+ # to merge source harmonics into a single excitation
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+ self.l_tanh = torch.nn.Tanh()
+
+ def forward(self, x):
+ """
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ """
+ # source for harmonic branch
+ sine_wavs, uv, _ = self.l_sin_gen(x)
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+ # source for noise branch, in the same shape as uv
+ noise = torch.randn_like(uv) * self.sine_amp / 3
+ return sine_merge, noise, uv
+
+class Generator(torch.nn.Module):
+ def __init__(self, h):
+ super(Generator, self).__init__()
+ self.h = h
+ self.num_kernels = len(h.resblock_kernel_sizes)
+ self.num_upsamples = len(h.upsample_rates)
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h.upsample_rates))
+ self.m_source = SourceModuleHnNSF(
+ sampling_rate=h.sampling_rate,
+ harmonic_num=8)
+ self.noise_convs = nn.ModuleList()
+ self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
+ resblock = ResBlock1 if h.resblock == '1' else ResBlock2
+
+ self.ups = nn.ModuleList()
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+ c_cur = h.upsample_initial_channel // (2 ** (i + 1))
+ self.ups.append(weight_norm(
+ ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
+ k, u, padding=(k-u)//2)))
+ if i + 1 < len(h.upsample_rates):#
+ stride_f0 = np.prod(h.upsample_rates[i + 1:])
+ self.noise_convs.append(Conv1d(
+ 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
+ else:
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = h.upsample_initial_channel//(2**(i+1))
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+ self.resblocks.append(resblock(h, ch, k, d))
+
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+ self.ups.apply(init_weights)
+ self.conv_post.apply(init_weights)
+
+ def forward(self, x,f0):
+ # print(1,x.shape,f0.shape,f0[:, None].shape)
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)#bs,n,t
+ # print(2,f0.shape)
+ har_source, noi_source, uv = self.m_source(f0)
+ har_source = har_source.transpose(1, 2)
+ x = self.conv_pre(x)
+ # print(124,x.shape,har_source.shape)
+ for i in range(self.num_upsamples):
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ # print(3,x.shape)
+ x = self.ups[i](x)
+ x_source = self.noise_convs[i](har_source)
+ # print(4,x_source.shape,har_source.shape,x.shape)
+ x = x + x_source
+ xs = None
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i*self.num_kernels+j](x)
+ else:
+ xs += self.resblocks[i*self.num_kernels+j](x)
+ x = xs / self.num_kernels
+ x = F.leaky_relu(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ print('Removing weight norm...')
+ for l in self.ups:
+ remove_weight_norm(l)
+ for l in self.resblocks:
+ l.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+
+class DiscriminatorP(torch.nn.Module):
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+ super(DiscriminatorP, self).__init__()
+ self.period = period
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ self.convs = nn.ModuleList([
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+ ])
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+ def forward(self, x):
+ fmap = []
+
+ # 1d to 2d
+ b, c, t = x.shape
+ if t % self.period != 0: # pad first
+ n_pad = self.period - (t % self.period)
+ x = F.pad(x, (0, n_pad), "reflect")
+ t = t + n_pad
+ x = x.view(b, c, t // self.period, self.period)
+
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+ def __init__(self, periods=None):
+ super(MultiPeriodDiscriminator, self).__init__()
+ self.periods = periods if periods is not None else [2, 3, 5, 7, 11]
+ self.discriminators = nn.ModuleList()
+ for period in self.periods:
+ self.discriminators.append(DiscriminatorP(period))
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for i, d in enumerate(self.discriminators):
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+class DiscriminatorS(torch.nn.Module):
+ def __init__(self, use_spectral_norm=False):
+ super(DiscriminatorS, self).__init__()
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+ self.convs = nn.ModuleList([
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+ ])
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+ def forward(self, x):
+ fmap = []
+ for l in self.convs:
+ x = l(x)
+ x = F.leaky_relu(x, LRELU_SLOPE)
+ fmap.append(x)
+ x = self.conv_post(x)
+ fmap.append(x)
+ x = torch.flatten(x, 1, -1)
+
+ return x, fmap
+
+
+class MultiScaleDiscriminator(torch.nn.Module):
+ def __init__(self):
+ super(MultiScaleDiscriminator, self).__init__()
+ self.discriminators = nn.ModuleList([
+ DiscriminatorS(use_spectral_norm=True),
+ DiscriminatorS(),
+ DiscriminatorS(),
+ ])
+ self.meanpools = nn.ModuleList([
+ AvgPool1d(4, 2, padding=2),
+ AvgPool1d(4, 2, padding=2)
+ ])
+
+ def forward(self, y, y_hat):
+ y_d_rs = []
+ y_d_gs = []
+ fmap_rs = []
+ fmap_gs = []
+ for i, d in enumerate(self.discriminators):
+ if i != 0:
+ y = self.meanpools[i-1](y)
+ y_hat = self.meanpools[i-1](y_hat)
+ y_d_r, fmap_r = d(y)
+ y_d_g, fmap_g = d(y_hat)
+ y_d_rs.append(y_d_r)
+ fmap_rs.append(fmap_r)
+ y_d_gs.append(y_d_g)
+ fmap_gs.append(fmap_g)
+
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+def feature_loss(fmap_r, fmap_g):
+ loss = 0
+ for dr, dg in zip(fmap_r, fmap_g):
+ for rl, gl in zip(dr, dg):
+ loss += torch.mean(torch.abs(rl - gl))
+
+ return loss*2
+
+
+def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+ loss = 0
+ r_losses = []
+ g_losses = []
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+ r_loss = torch.mean((1-dr)**2)
+ g_loss = torch.mean(dg**2)
+ loss += (r_loss + g_loss)
+ r_losses.append(r_loss.item())
+ g_losses.append(g_loss.item())
+
+ return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs):
+ loss = 0
+ gen_losses = []
+ for dg in disc_outputs:
+ l = torch.mean((1-dg)**2)
+ gen_losses.append(l)
+ loss += l
+
+ return loss, gen_losses
\ No newline at end of file
diff --git a/modules/nsf_hifigan/nvSTFT.py b/modules/nsf_hifigan/nvSTFT.py
new file mode 100644
index 0000000000000000000000000000000000000000..35635c844ea1ae6258112f0ba92e417e81a22642
--- /dev/null
+++ b/modules/nsf_hifigan/nvSTFT.py
@@ -0,0 +1,111 @@
+import math
+import os
+os.environ["LRU_CACHE_CAPACITY"] = "3"
+import random
+import torch
+import torch.utils.data
+import numpy as np
+import librosa
+from librosa.util import normalize
+from librosa.filters import mel as librosa_mel_fn
+from scipy.io.wavfile import read
+import soundfile as sf
+
+def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
+ sampling_rate = None
+ try:
+ data, sampling_rate = sf.read(full_path, always_2d=True)# than soundfile.
+ except Exception as ex:
+ print(f"'{full_path}' failed to load.\nException:")
+ print(ex)
+ if return_empty_on_exception:
+ return [], sampling_rate or target_sr or 48000
+ else:
+ raise Exception(ex)
+
+ if len(data.shape) > 1:
+ data = data[:, 0]
+ assert len(data) > 2# check duration of audio file is > 2 samples (because otherwise the slice operation was on the wrong dimension)
+
+ if np.issubdtype(data.dtype, np.integer): # if audio data is type int
+ max_mag = -np.iinfo(data.dtype).min # maximum magnitude = min possible value of intXX
+ else: # if audio data is type fp32
+ max_mag = max(np.amax(data), -np.amin(data))
+ max_mag = (2**31)+1 if max_mag > (2**15) else ((2**15)+1 if max_mag > 1.01 else 1.0) # data should be either 16-bit INT, 32-bit INT or [-1 to 1] float32
+
+ data = torch.FloatTensor(data.astype(np.float32))/max_mag
+
+ if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception:# resample will crash with inf/NaN inputs. return_empty_on_exception will return empty arr instead of except
+ return [], sampling_rate or target_sr or 48000
+ if target_sr is not None and sampling_rate != target_sr:
+ data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sampling_rate, target_sr=target_sr))
+ sampling_rate = target_sr
+
+ return data, sampling_rate
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+def dynamic_range_decompression(x, C=1):
+ return np.exp(x) / C
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+def dynamic_range_decompression_torch(x, C=1):
+ return torch.exp(x) / C
+
+class STFT():
+ def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
+ self.target_sr = sr
+
+ self.n_mels = n_mels
+ self.n_fft = n_fft
+ self.win_size = win_size
+ self.hop_length = hop_length
+ self.fmin = fmin
+ self.fmax = fmax
+ self.clip_val = clip_val
+ self.mel_basis = {}
+ self.hann_window = {}
+
+ def get_mel(self, y, center=False):
+ sampling_rate = self.target_sr
+ n_mels = self.n_mels
+ n_fft = self.n_fft
+ win_size = self.win_size
+ hop_length = self.hop_length
+ fmin = self.fmin
+ fmax = self.fmax
+ clip_val = self.clip_val
+
+ if torch.min(y) < -1.:
+ print('min value is ', torch.min(y))
+ if torch.max(y) > 1.:
+ print('max value is ', torch.max(y))
+
+ if fmax not in self.mel_basis:
+ mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
+ self.mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+ self.hann_window[str(y.device)] = torch.hann_window(self.win_size).to(y.device)
+
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_length)/2), int((n_fft-hop_length)/2)), mode='reflect')
+ y = y.squeeze(1)
+
+ spec = torch.stft(y, n_fft, hop_length=hop_length, win_length=win_size, window=self.hann_window[str(y.device)],
+ center=center, pad_mode='reflect', normalized=False, onesided=True)
+ # print(111,spec)
+ spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
+ # print(222,spec)
+ spec = torch.matmul(self.mel_basis[str(fmax)+'_'+str(y.device)], spec)
+ # print(333,spec)
+ spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
+ # print(444,spec)
+ return spec
+
+ def __call__(self, audiopath):
+ audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
+ spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
+ return spect
+
+stft = STFT()
\ No newline at end of file
diff --git a/modules/nsf_hifigan/utils.py b/modules/nsf_hifigan/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ad5a37507987bd6e1200cb9241edf97989e3ead
--- /dev/null
+++ b/modules/nsf_hifigan/utils.py
@@ -0,0 +1,67 @@
+import glob
+import os
+import matplotlib
+import torch
+from torch.nn.utils import weight_norm
+matplotlib.use("Agg")
+import matplotlib.pylab as plt
+
+
+def plot_spectrogram(spectrogram):
+ fig, ax = plt.subplots(figsize=(10, 2))
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+ interpolation='none')
+ plt.colorbar(im, ax=ax)
+
+ fig.canvas.draw()
+ plt.close()
+
+ return fig
+
+
+def init_weights(m, mean=0.0, std=0.01):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ m.weight.data.normal_(mean, std)
+
+
+def apply_weight_norm(m):
+ classname = m.__class__.__name__
+ if classname.find("Conv") != -1:
+ weight_norm(m)
+
+
+def get_padding(kernel_size, dilation=1):
+ return int((kernel_size*dilation - dilation)/2)
+
+
+def load_checkpoint(filepath, device):
+ assert os.path.isfile(filepath)
+ print("Loading '{}'".format(filepath))
+ checkpoint_dict = torch.load(filepath, map_location=device)
+ print("Complete.")
+ return checkpoint_dict
+
+
+def save_checkpoint(filepath, obj):
+ print("Saving checkpoint to {}".format(filepath))
+ torch.save(obj, filepath)
+ print("Complete.")
+
+
+def del_old_checkpoints(cp_dir, prefix, n_models=2):
+ pattern = os.path.join(cp_dir, prefix + '????????')
+ cp_list = glob.glob(pattern) # get checkpoint paths
+ cp_list = sorted(cp_list)# sort by iter
+ if len(cp_list) > n_models: # if more than n_models models are found
+ for cp in cp_list[:-n_models]:# delete the oldest models other than lastest n_models
+ open(cp, 'w').close()# empty file contents
+ os.unlink(cp)# delete file (move to trash when using Colab)
+
+
+def scan_checkpoint(cp_dir, prefix):
+ pattern = os.path.join(cp_dir, prefix + '????????')
+ cp_list = glob.glob(pattern)
+ if len(cp_list) == 0:
+ return None
+ return sorted(cp_list)[-1]
\ No newline at end of file
diff --git a/modules/parallel_wavegan/__init__.py b/modules/parallel_wavegan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/parallel_wavegan/__pycache__/__init__.cpython-38.pyc b/modules/parallel_wavegan/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6a72dc078a6888fa6ecc164fddba4000a1021029
Binary files /dev/null and b/modules/parallel_wavegan/__pycache__/__init__.cpython-38.pyc differ
diff --git a/modules/parallel_wavegan/layers/__init__.py b/modules/parallel_wavegan/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e477f51116a3157781b1aefefbaf32fe4d4bd1f0
--- /dev/null
+++ b/modules/parallel_wavegan/layers/__init__.py
@@ -0,0 +1,5 @@
+from .causal_conv import * # NOQA
+from .pqmf import * # NOQA
+from .residual_block import * # NOQA
+from modules.parallel_wavegan.layers.residual_stack import * # NOQA
+from .upsample import * # NOQA
diff --git a/modules/parallel_wavegan/layers/__pycache__/__init__.cpython-38.pyc b/modules/parallel_wavegan/layers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86f5d78d83d5daf5c56a45255dbaa381804cbcc2
Binary files /dev/null and b/modules/parallel_wavegan/layers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/modules/parallel_wavegan/layers/__pycache__/causal_conv.cpython-38.pyc b/modules/parallel_wavegan/layers/__pycache__/causal_conv.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6ab116b8671479cd5fa56a9553544a1924a5fb9
Binary files /dev/null and b/modules/parallel_wavegan/layers/__pycache__/causal_conv.cpython-38.pyc differ
diff --git a/modules/parallel_wavegan/layers/__pycache__/pqmf.cpython-38.pyc b/modules/parallel_wavegan/layers/__pycache__/pqmf.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..291742627bc50217be5d191aa0b9838c4e961c6b
Binary files /dev/null and b/modules/parallel_wavegan/layers/__pycache__/pqmf.cpython-38.pyc differ
diff --git a/modules/parallel_wavegan/layers/__pycache__/residual_block.cpython-38.pyc b/modules/parallel_wavegan/layers/__pycache__/residual_block.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7d3b43a4af9460e84063fecd2496273512d6428f
Binary files /dev/null and b/modules/parallel_wavegan/layers/__pycache__/residual_block.cpython-38.pyc differ
diff --git a/modules/parallel_wavegan/layers/__pycache__/residual_stack.cpython-38.pyc b/modules/parallel_wavegan/layers/__pycache__/residual_stack.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f88f24a4aeb61a2ffdbb8c4c17832d5d8cfcdf45
Binary files /dev/null and b/modules/parallel_wavegan/layers/__pycache__/residual_stack.cpython-38.pyc differ
diff --git a/modules/parallel_wavegan/layers/__pycache__/upsample.cpython-38.pyc b/modules/parallel_wavegan/layers/__pycache__/upsample.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7942fd2ecd2b6a54b524cb71217440ebcf21581
Binary files /dev/null and b/modules/parallel_wavegan/layers/__pycache__/upsample.cpython-38.pyc differ
diff --git a/modules/parallel_wavegan/layers/causal_conv.py b/modules/parallel_wavegan/layers/causal_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..fca77daf65f234e6fbe355ed148fc8f0ee85038a
--- /dev/null
+++ b/modules/parallel_wavegan/layers/causal_conv.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""Causal convolusion layer modules."""
+
+
+import torch
+
+
+class CausalConv1d(torch.nn.Module):
+ """CausalConv1d module with customized initialization."""
+
+ def __init__(self, in_channels, out_channels, kernel_size,
+ dilation=1, bias=True, pad="ConstantPad1d", pad_params={"value": 0.0}):
+ """Initialize CausalConv1d module."""
+ super(CausalConv1d, self).__init__()
+ self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params)
+ self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size,
+ dilation=dilation, bias=bias)
+
+ def forward(self, x):
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T).
+
+ Returns:
+ Tensor: Output tensor (B, out_channels, T).
+
+ """
+ return self.conv(self.pad(x))[:, :, :x.size(2)]
+
+
+class CausalConvTranspose1d(torch.nn.Module):
+ """CausalConvTranspose1d module with customized initialization."""
+
+ def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
+ """Initialize CausalConvTranspose1d module."""
+ super(CausalConvTranspose1d, self).__init__()
+ self.deconv = torch.nn.ConvTranspose1d(
+ in_channels, out_channels, kernel_size, stride, bias=bias)
+ self.stride = stride
+
+ def forward(self, x):
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input tensor (B, in_channels, T_in).
+
+ Returns:
+ Tensor: Output tensor (B, out_channels, T_out).
+
+ """
+ return self.deconv(x)[:, :, :-self.stride]
diff --git a/modules/parallel_wavegan/layers/pqmf.py b/modules/parallel_wavegan/layers/pqmf.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac21074fd32a370a099fa2facb62cfd3253d7579
--- /dev/null
+++ b/modules/parallel_wavegan/layers/pqmf.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""Pseudo QMF modules."""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from scipy.signal import kaiser
+
+
+def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0):
+ """Design prototype filter for PQMF.
+
+ This method is based on `A Kaiser window approach for the design of prototype
+ filters of cosine modulated filterbanks`_.
+
+ Args:
+ taps (int): The number of filter taps.
+ cutoff_ratio (float): Cut-off frequency ratio.
+ beta (float): Beta coefficient for kaiser window.
+
+ Returns:
+ ndarray: Impluse response of prototype filter (taps + 1,).
+
+ .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
+ https://ieeexplore.ieee.org/abstract/document/681427
+
+ """
+ # check the arguments are valid
+ assert taps % 2 == 0, "The number of taps mush be even number."
+ assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0."
+
+ # make initial filter
+ omega_c = np.pi * cutoff_ratio
+ with np.errstate(invalid='ignore'):
+ h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) \
+ / (np.pi * (np.arange(taps + 1) - 0.5 * taps))
+ h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
+
+ # apply kaiser window
+ w = kaiser(taps + 1, beta)
+ h = h_i * w
+
+ return h
+
+
+class PQMF(torch.nn.Module):
+ """PQMF module.
+
+ This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
+
+ .. _`Near-perfect-reconstruction pseudo-QMF banks`:
+ https://ieeexplore.ieee.org/document/258122
+
+ """
+
+ def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0):
+ """Initilize PQMF module.
+
+ Args:
+ subbands (int): The number of subbands.
+ taps (int): The number of filter taps.
+ cutoff_ratio (float): Cut-off frequency ratio.
+ beta (float): Beta coefficient for kaiser window.
+
+ """
+ super(PQMF, self).__init__()
+
+ # define filter coefficient
+ h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
+ h_analysis = np.zeros((subbands, len(h_proto)))
+ h_synthesis = np.zeros((subbands, len(h_proto)))
+ for k in range(subbands):
+ h_analysis[k] = 2 * h_proto * np.cos(
+ (2 * k + 1) * (np.pi / (2 * subbands)) *
+ (np.arange(taps + 1) - ((taps - 1) / 2)) +
+ (-1) ** k * np.pi / 4)
+ h_synthesis[k] = 2 * h_proto * np.cos(
+ (2 * k + 1) * (np.pi / (2 * subbands)) *
+ (np.arange(taps + 1) - ((taps - 1) / 2)) -
+ (-1) ** k * np.pi / 4)
+
+ # convert to tensor
+ analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
+ synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
+
+ # register coefficients as beffer
+ self.register_buffer("analysis_filter", analysis_filter)
+ self.register_buffer("synthesis_filter", synthesis_filter)
+
+ # filter for downsampling & upsampling
+ updown_filter = torch.zeros((subbands, subbands, subbands)).float()
+ for k in range(subbands):
+ updown_filter[k, k, 0] = 1.0
+ self.register_buffer("updown_filter", updown_filter)
+ self.subbands = subbands
+
+ # keep padding info
+ self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
+
+ def analysis(self, x):
+ """Analysis with PQMF.
+
+ Args:
+ x (Tensor): Input tensor (B, 1, T).
+
+ Returns:
+ Tensor: Output tensor (B, subbands, T // subbands).
+
+ """
+ x = F.conv1d(self.pad_fn(x), self.analysis_filter)
+ return F.conv1d(x, self.updown_filter, stride=self.subbands)
+
+ def synthesis(self, x):
+ """Synthesis with PQMF.
+
+ Args:
+ x (Tensor): Input tensor (B, subbands, T // subbands).
+
+ Returns:
+ Tensor: Output tensor (B, 1, T).
+
+ """
+ x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands)
+ return F.conv1d(self.pad_fn(x), self.synthesis_filter)
diff --git a/modules/parallel_wavegan/layers/residual_block.py b/modules/parallel_wavegan/layers/residual_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a267a86c1fa521c2824addf9dda304c43f1ff1f
--- /dev/null
+++ b/modules/parallel_wavegan/layers/residual_block.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+
+"""Residual block module in WaveNet.
+
+This code is modified from https://github.com/r9y9/wavenet_vocoder.
+
+"""
+
+import math
+
+import torch
+import torch.nn.functional as F
+
+
class Conv1d(torch.nn.Conv1d):
    """``torch.nn.Conv1d`` with a customized parameter initialization."""

    def __init__(self, *args, **kwargs):
        """Pass every argument straight through to ``torch.nn.Conv1d``."""
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        """Draw the weight from a Kaiming-normal (ReLU) init and zero the bias."""
        torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
        bias = self.bias
        if bias is None:
            # Layer was built with bias=False; nothing more to initialize.
            return
        torch.nn.init.constant_(bias, 0.0)
+
+
class Conv1d1x1(Conv1d):
    """Kernel-size-1 convolution that reuses the customized ``Conv1d`` init."""

    def __init__(self, in_channels, out_channels, bias):
        """Build the 1x1 convolution.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bias (bool): Whether to add a bias parameter.

        """
        pointwise = dict(kernel_size=1, padding=0, dilation=1, bias=bias)
        super().__init__(in_channels, out_channels, **pointwise)
+
+
class ResidualBlock(torch.nn.Module):
    """Gated residual block (WaveNet style) with optional local conditioning."""

    def __init__(self,
                 kernel_size=3,
                 residual_channels=64,
                 gate_channels=128,
                 skip_channels=64,
                 aux_channels=80,
                 dropout=0.0,
                 dilation=1,
                 bias=True,
                 use_causal_conv=False
                 ):
        """Build the layers of the block.

        Args:
            kernel_size (int): Kernel size of the dilated convolution.
            residual_channels (int): Channels of the residual path.
            gate_channels (int): Output channels of the dilated convolution;
                they are split in half for the tanh / sigmoid gate.
            skip_channels (int): Channels of the skip output.
            aux_channels (int): Channels of the local conditioning input.
                A value <= 0 disables the conditioning projection.
            dropout (float): Dropout probability applied to the block input.
            dilation (int): Dilation factor.
            bias (bool): Whether convolutions carry a bias parameter.
            use_causal_conv (bool): Whether to use causal convolution.

        """
        super().__init__()
        self.dropout = dropout
        self.use_causal_conv = use_causal_conv

        if use_causal_conv:
            # Full receptive field is padded; forward() trims the extra tail.
            padding = (kernel_size - 1) * dilation
        else:
            assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
            padding = (kernel_size - 1) // 2 * dilation

        # Dilated convolution producing both gate halves.
        self.conv = Conv1d(residual_channels, gate_channels, kernel_size,
                           padding=padding, dilation=dilation, bias=bias)

        # Projection of the local conditioning features (optional).
        self.conv1x1_aux = (
            Conv1d1x1(aux_channels, gate_channels, bias=False)
            if aux_channels > 0 else None
        )

        # After gating only half of the gate channels remain.
        gated_channels = gate_channels // 2
        self.conv1x1_out = Conv1d1x1(gated_channels, residual_channels, bias=bias)
        self.conv1x1_skip = Conv1d1x1(gated_channels, skip_channels, bias=bias)

    def forward(self, x, c):
        """Run the block.

        Args:
            x (Tensor): Input tensor (B, residual_channels, T).
            c (Tensor): Local conditioning tensor (B, aux_channels, T), or None.

        Returns:
            Tensor: Output for the residual connection (B, residual_channels, T).
            Tensor: Output for the skip connection (B, skip_channels, T).

        """
        block_input = x
        hidden = self.conv(F.dropout(x, p=self.dropout, training=self.training))

        if self.use_causal_conv:
            # Drop the trailing frames introduced by the causal padding.
            hidden = hidden[:, :, :block_input.size(-1)]

        # Split channels into the tanh half and the sigmoid half.
        channel_dim = 1
        tanh_in, sigmoid_in = hidden.split(hidden.size(channel_dim) // 2, dim=channel_dim)

        if c is not None:
            assert self.conv1x1_aux is not None
            cond = self.conv1x1_aux(c)
            cond_tanh, cond_sigmoid = cond.split(cond.size(channel_dim) // 2, dim=channel_dim)
            tanh_in = tanh_in + cond_tanh
            sigmoid_in = sigmoid_in + cond_sigmoid

        gated = torch.tanh(tanh_in) * torch.sigmoid(sigmoid_in)

        skip = self.conv1x1_skip(gated)
        # Sum with the block input, rescaled by sqrt(0.5).
        residual_out = (self.conv1x1_out(gated) + block_input) * math.sqrt(0.5)

        return residual_out, skip
diff --git a/modules/parallel_wavegan/layers/residual_stack.py b/modules/parallel_wavegan/layers/residual_stack.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e07c8803ad348dd923f6b7c0f7aff14aab9cf78
--- /dev/null
+++ b/modules/parallel_wavegan/layers/residual_stack.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""Residual stack module in MelGAN."""
+
+import torch
+
+from . import CausalConv1d
+
+
class ResidualStack(torch.nn.Module):
    """MelGAN residual stack: a dilated conv branch plus a 1x1 skip branch."""

    def __init__(self,
                 kernel_size=3,
                 channels=32,
                 dilation=1,
                 bias=True,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d",
                 pad_params={},
                 use_causal_conv=False,
                 ):
        """Build the residual stack.

        Args:
            kernel_size (int): Kernel size of the dilated convolution.
            channels (int): Number of channels of every convolution.
            dilation (int): Dilation factor.
            bias (bool): Whether convolutions carry a bias parameter.
            nonlinear_activation (str): Name of an activation class in ``torch.nn``.
            nonlinear_activation_params (dict): Keyword arguments for the activation.
            pad (str): Name of a padding class in ``torch.nn``.
            pad_params (dict): Keyword arguments for the padding class.
            use_causal_conv (bool): Whether to use causal convolution.

        """
        super().__init__()

        def _activation():
            # A new activation instance is created for every position.
            return getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)

        if use_causal_conv:
            # The causal convolution takes care of its own padding.
            branch = [
                _activation(),
                CausalConv1d(channels, channels, kernel_size, dilation=dilation,
                             bias=bias, pad=pad, pad_params=pad_params),
                _activation(),
                torch.nn.Conv1d(channels, channels, 1, bias=bias),
            ]
        else:
            assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
            half_span = (kernel_size - 1) // 2 * dilation
            branch = [
                _activation(),
                getattr(torch.nn, pad)(half_span, **pad_params),
                torch.nn.Conv1d(channels, channels, kernel_size, dilation=dilation, bias=bias),
                _activation(),
                torch.nn.Conv1d(channels, channels, 1, bias=bias),
            ]
        self.stack = torch.nn.Sequential(*branch)

        # Extra 1x1 convolution applied on the skip connection.
        self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias)

    def forward(self, c):
        """Add the stack output to the skip-layer output.

        Args:
            c (Tensor): Input tensor (B, channels, T).

        Returns:
            Tensor: Output tensor (B, channels, T).

        """
        stacked = self.stack(c)
        skipped = self.skip_layer(c)
        return stacked + skipped
diff --git a/modules/parallel_wavegan/layers/tf_layers.py b/modules/parallel_wavegan/layers/tf_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0f46bd755c161cda2ac904fe37f3f3c6357a88d
--- /dev/null
+++ b/modules/parallel_wavegan/layers/tf_layers.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 MINH ANH (@dathudeptrai)
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""Tensorflow Layer modules complatible with pytorch."""
+
+import tensorflow as tf
+
+
class TFReflectionPad1d(tf.keras.layers.Layer):
    """Keras layer that reflect-pads the second axis of a 4-D tensor."""

    def __init__(self, padding_size):
        """Store the padding width.

        Args:
            padding_size (int): Number of frames added on each side.

        """
        super().__init__()
        self.padding_size = padding_size

    @tf.function
    def call(self, x):
        """Pad the input symmetrically along axis 1.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Padded tensor (B, T + 2 * padding_size, 1, C).

        """
        width = self.padding_size
        # Only axis 1 is padded; every other axis is left untouched.
        paddings = [[0, 0], [width, width], [0, 0], [0, 0]]
        return tf.pad(x, paddings, "REFLECT")
+
+
class TFConvTranspose1d(tf.keras.layers.Layer):
    """1-D transposed convolution built from a Keras ``Conv2DTranspose``."""

    def __init__(self, channels, kernel_size, stride, padding):
        """Create the wrapped ``Conv2DTranspose`` layer.

        Args:
            channels (int): Number of output channels (filters).
            kernel_size (int): Kernel size along the first spatial axis.
            stride (int): Stride along the first spatial axis.
            padding (str): Padding type ("same" or "valid").

        """
        super().__init__()
        # The second spatial axis uses kernel size 1 and stride 1.
        conv_kwargs = dict(
            filters=channels,
            kernel_size=(kernel_size, 1),
            strides=(stride, 1),
            padding=padding,
        )
        self.conv1d_transpose = tf.keras.layers.Conv2DTranspose(**conv_kwargs)

    @tf.function
    def call(self, x):
        """Apply the transposed convolution.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T', 1, C').

        """
        return self.conv1d_transpose(x)
+
+
class TFResidualStack(tf.keras.layers.Layer):
    """Tensorflow ResidualStack module."""

    def __init__(self,
                 kernel_size,
                 channels,
                 dilation,
                 bias,
                 nonlinear_activation,
                 nonlinear_activation_params,
                 padding,
                 ):
        """Initialize TFResidualStack module.

        Args:
            kernel_size (int): Kernel size of the dilated convolution (odd).
            channels (int): Number of channels.
            dilation (int): Dilation size.
            bias (bool): Whether to add bias parameter in convolution layers.
            nonlinear_activation (str): Activation layer name in ``tf.keras.layers``.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            padding (str): Padding type ("same" or "valid"). Accepted for
                interface compatibility; this layer does not read it because
                the dilated convolution is explicitly reflect-padded and run
                with "valid" padding.

        """
        super(TFResidualStack, self).__init__()
        # Pad by half the dilated kernel span so the "valid" convolution keeps
        # the time length and the result can be added to the shortcut. This
        # equals ``dilation`` for kernel_size == 3, the only case the previous
        # hard-coded padding handled correctly.
        reflection_padding = (kernel_size - 1) // 2 * dilation
        self.block = [
            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
            TFReflectionPad1d(reflection_padding),
            tf.keras.layers.Conv2D(
                filters=channels,
                kernel_size=(kernel_size, 1),
                dilation_rate=(dilation, 1),
                use_bias=bias,
                padding="valid",
            ),
            getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params),
            tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)
        ]
        self.shortcut = tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias)

    @tf.function
    def call(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor (B, T, 1, C).

        Returns:
            Tensor: Output tensor (B, T, 1, C).

        """
        _x = tf.identity(x)
        for layer in self.block:
            _x = layer(_x)
        shortcut = self.shortcut(x)
        return shortcut + _x
diff --git a/modules/parallel_wavegan/layers/upsample.py b/modules/parallel_wavegan/layers/upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..18c6397c420a81fadc5320e3a48f3249534decd8
--- /dev/null
+++ b/modules/parallel_wavegan/layers/upsample.py
@@ -0,0 +1,183 @@
+# -*- coding: utf-8 -*-
+
+"""Upsampling module.
+
+This code is modified from https://github.com/r9y9/wavenet_vocoder.
+
+"""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from . import Conv1d
+
+
class Stretch2d(torch.nn.Module):
    """Rescale the last two axes of a 4-D tensor by fixed factors."""

    def __init__(self, x_scale, y_scale, mode="nearest"):
        """Store the scaling configuration.

        Args:
            x_scale (int): Scaling factor of the last axis (time axis in spectrogram).
            y_scale (int): Scaling factor of the second-to-last axis
                (frequency axis in spectrogram).
            mode (str): Interpolation mode passed to ``F.interpolate``.

        """
        super().__init__()
        self.x_scale, self.y_scale = x_scale, y_scale
        self.mode = mode

    def forward(self, x):
        """Interpolate the input.

        Args:
            x (Tensor): Input tensor (B, C, F, T).

        Returns:
            Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).

        """
        # F.interpolate expects the factors in axis order: (F axis, T axis).
        factors = (self.y_scale, self.x_scale)
        return F.interpolate(x, scale_factor=factors, mode=self.mode)
+
+
class Conv2d(torch.nn.Conv2d):
    """``torch.nn.Conv2d`` initialized as an averaging filter."""

    def __init__(self, *args, **kwargs):
        """Pass every argument straight through to ``torch.nn.Conv2d``."""
        super().__init__(*args, **kwargs)

    def reset_parameters(self):
        """Fill the weight with 1 / (kernel height * kernel width); zero the bias."""
        uniform_value = 1. / np.prod(self.kernel_size)
        self.weight.data.fill_(uniform_value)
        bias = self.bias
        if bias is None:
            # Layer was built with bias=False; nothing more to initialize.
            return
        torch.nn.init.constant_(bias, 0.0)
+
+
class UpsampleNetwork(torch.nn.Module):
    """Chain of (stretch, conv[, activation]) stages, one per upsampling scale."""

    def __init__(self,
                 upsample_scales,
                 nonlinear_activation=None,
                 nonlinear_activation_params={},
                 interpolate_mode="nearest",
                 freq_axis_kernel_size=1,
                 use_causal_conv=False,
                 ):
        """Build the upsampling stages.

        Args:
            upsample_scales (list): Upsampling scale of each stage.
            nonlinear_activation (str): Name of an activation class in
                ``torch.nn``; None means no activation after the convolution.
            nonlinear_activation_params (dict): Keyword arguments for the activation.
            interpolate_mode (str): Interpolation mode of the stretch layers.
            freq_axis_kernel_size (int): Kernel size along the frequency axis (odd).
            use_causal_conv (bool): Whether to use causal convolution.

        """
        super().__init__()
        self.use_causal_conv = use_causal_conv
        self.up_layers = torch.nn.ModuleList()
        for scale in upsample_scales:
            # 1) interpolation along the time axis only
            self.up_layers.append(Stretch2d(scale, 1, interpolate_mode))

            # 2) convolution over the stretched features
            assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size."
            freq_pad = (freq_axis_kernel_size - 1) // 2
            # The causal variant pads twice as much along time; forward()
            # cuts the surplus frames off at the end.
            time_pad = scale * 2 if use_causal_conv else scale
            self.up_layers.append(
                Conv2d(1, 1,
                       kernel_size=(freq_axis_kernel_size, scale * 2 + 1),
                       padding=(freq_pad, time_pad),
                       bias=False)
            )

            # 3) optional nonlinearity
            if nonlinear_activation is not None:
                activation_cls = getattr(torch.nn, nonlinear_activation)
                self.up_layers.append(activation_cls(**nonlinear_activation_params))

    def forward(self, c):
        """Upsample the conditioning features.

        Args:
            c : Input tensor (B, C, T).

        Returns:
            Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales).

        """
        hidden = c.unsqueeze(1)  # (B, 1, C, T)
        for layer in self.up_layers:
            result = layer(hidden)
            if self.use_causal_conv and isinstance(layer, Conv2d):
                # Keep only as many frames as the convolution received.
                result = result[..., :hidden.size(-1)]
            hidden = result
        return hidden.squeeze(1)  # (B, C, T')
+
+
class ConvInUpsampleNetwork(torch.nn.Module):
    """Context-window convolution followed by an ``UpsampleNetwork``."""

    def __init__(self,
                 upsample_scales,
                 nonlinear_activation=None,
                 nonlinear_activation_params={},
                 interpolate_mode="nearest",
                 freq_axis_kernel_size=1,
                 aux_channels=80,
                 aux_context_window=0,
                 use_causal_conv=False
                 ):
        """Build the input convolution and the upsampling network.

        Args:
            upsample_scales (list): List of upsampling scales.
            nonlinear_activation (str): Activation function name.
            nonlinear_activation_params (dict): Arguments for the activation function.
            interpolate_mode (str): Interpolation mode.
            freq_axis_kernel_size (int): Kernel size along the frequency axis.
            aux_channels (int): Number of channels of the input convolution.
            aux_context_window (int): Context window size of the input convolution.
            use_causal_conv (bool): Whether to use causal structure.

        """
        super().__init__()
        self.aux_context_window = aux_context_window
        # Trimming in forward() only makes sense for a non-empty window.
        self.use_causal_conv = use_causal_conv and aux_context_window > 0

        # Wide kernel to capture context in the conditioning features.
        if use_causal_conv:
            kernel_size = aux_context_window + 1
        else:
            kernel_size = 2 * aux_context_window + 1
        # NOTE(kan-bayashi): Here do not use padding because the input is already padded
        self.conv_in = Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False)

        upsample_kwargs = dict(
            upsample_scales=upsample_scales,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            use_causal_conv=use_causal_conv,
        )
        self.upsample = UpsampleNetwork(**upsample_kwargs)

    def forward(self, c):
        """Convolve over the context window, then upsample.

        Args:
            c : Input tensor (B, C, T').

        Returns:
            Tensor: Upsampled tensor (B, C, T),
                where T = (T' - aux_context_window * 2) * prod(upsample_scales).

        Note:
            The length of inputs considers the context window size.

        """
        features = self.conv_in(c)
        if self.use_causal_conv:
            # Drop the last ``aux_context_window`` frames in causal mode.
            features = features[:, :, :-self.aux_context_window]
        return self.upsample(features)
diff --git a/modules/parallel_wavegan/losses/__init__.py b/modules/parallel_wavegan/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b03080a907cb5cb4b316ceb74866ddbc406b33bf
--- /dev/null
+++ b/modules/parallel_wavegan/losses/__init__.py
@@ -0,0 +1 @@
+from .stft_loss import * # NOQA
diff --git a/modules/parallel_wavegan/losses/stft_loss.py b/modules/parallel_wavegan/losses/stft_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d2aa21ad30ba094c406366e652067462f49cd2
--- /dev/null
+++ b/modules/parallel_wavegan/losses/stft_loss.py
@@ -0,0 +1,153 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""STFT-based Loss modules."""
+
+import torch
+import torch.nn.functional as F
+
+
def stft(x, fft_size, hop_size, win_length, window):
    """Perform STFT and convert to magnitude spectrogram.

    Args:
        x (Tensor): Input signal tensor (B, T).
        fft_size (int): FFT size.
        hop_size (int): Hop size.
        win_length (int): Window length.
        window (Tensor): Window tensor of length ``win_length``.

    Returns:
        Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).

    """
    # Request the complex output explicitly: calling torch.stft on a real
    # input without ``return_complex`` is deprecated and is rejected by
    # recent PyTorch releases.
    x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
    real = x_stft.real
    imag = x_stft.imag

    # NOTE(kan-bayashi): clamp is needed to avoid nan or inf
    return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1)
+
+
class SpectralConvergengeLoss(torch.nn.Module):
    """Frobenius norm of the magnitude error relative to the target norm."""

    def __init__(self):
        """Set up the (parameter-free) loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Compute the spectral convergence loss.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Spectral convergence loss value.

        """
        error_norm = torch.norm(y_mag - x_mag, p="fro")
        target_norm = torch.norm(y_mag, p="fro")
        return error_norm / target_norm
+
+
class LogSTFTMagnitudeLoss(torch.nn.Module):
    """Mean absolute difference between log-magnitude spectrograms."""

    def __init__(self):
        """Set up the (parameter-free) loss module."""
        super().__init__()

    def forward(self, x_mag, y_mag):
        """Compute the log STFT magnitude loss.

        Args:
            x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
            y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).

        Returns:
            Tensor: Log STFT magnitude loss value.

        """
        log_target = torch.log(y_mag)
        log_predicted = torch.log(x_mag)
        return F.l1_loss(log_target, log_predicted)
+
+
class STFTLoss(torch.nn.Module):
    """STFT loss module (single resolution)."""

    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"):
        """Initialize STFT loss module.

        Args:
            fft_size (int): FFT size.
            shift_size (int): Hop size between frames.
            win_length (int): Window length.
            window (str): Name of a window function in ``torch``
                (e.g. "hann_window"); it is called with ``win_length``.

        """
        super(STFTLoss, self).__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        # Register the window as a buffer so that ``.to(device)`` / ``.cuda()``
        # move it together with the module; as a plain attribute it stayed on
        # the CPU. ``persistent=False`` keeps it out of the state_dict, so the
        # set of saved keys is unchanged.
        self.register_buffer("window", getattr(torch, window)(win_length), persistent=False)
        self.spectral_convergenge_loss = SpectralConvergengeLoss()
        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()

    def forward(self, x, y):
        """Calculate forward propagation.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Spectral convergence loss value.
            Tensor: Log STFT magnitude loss value.

        """
        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)

        return sc_loss, mag_loss
+
+
class MultiResolutionSTFTLoss(torch.nn.Module):
    """Average of several ``STFTLoss`` modules with different resolutions."""

    def __init__(self,
                 fft_sizes=[1024, 2048, 512],
                 hop_sizes=[120, 240, 50],
                 win_lengths=[600, 1200, 240],
                 window="hann_window"):
        """Create one ``STFTLoss`` per resolution.

        Args:
            fft_sizes (list): List of FFT sizes.
            hop_sizes (list): List of hop sizes.
            win_lengths (list): List of window lengths.
            window (str): Window function type.

        """
        super().__init__()
        # The three lists describe the resolutions position by position.
        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
        self.stft_losses = torch.nn.ModuleList()
        for resolution in zip(fft_sizes, hop_sizes, win_lengths):
            self.stft_losses.append(STFTLoss(*resolution, window))

    def forward(self, x, y):
        """Average both loss terms over all resolutions.

        Args:
            x (Tensor): Predicted signal (B, T).
            y (Tensor): Groundtruth signal (B, T).

        Returns:
            Tensor: Multi resolution spectral convergence loss value.
            Tensor: Multi resolution log STFT magnitude loss value.

        """
        sc_total, mag_total = 0.0, 0.0
        for loss_fn in self.stft_losses:
            sc_part, mag_part = loss_fn(x, y)
            sc_total += sc_part
            mag_total += mag_part
        n_resolutions = len(self.stft_losses)
        sc_total /= n_resolutions
        mag_total /= n_resolutions

        return sc_total, mag_total
diff --git a/modules/parallel_wavegan/models/__init__.py b/modules/parallel_wavegan/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4803ba6b2a0afc8022e756ae5b3f4c7403c3c1bd
--- /dev/null
+++ b/modules/parallel_wavegan/models/__init__.py
@@ -0,0 +1,2 @@
+from .melgan import * # NOQA
+from .parallel_wavegan import * # NOQA
diff --git a/modules/parallel_wavegan/models/__pycache__/__init__.cpython-38.pyc b/modules/parallel_wavegan/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..79b6711d154a174aaad515b78a793617933b0181
Binary files /dev/null and b/modules/parallel_wavegan/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/modules/parallel_wavegan/models/__pycache__/melgan.cpython-38.pyc b/modules/parallel_wavegan/models/__pycache__/melgan.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b72851f34a348346b89619928b640c57d3a8fe1
Binary files /dev/null and b/modules/parallel_wavegan/models/__pycache__/melgan.cpython-38.pyc differ
diff --git a/modules/parallel_wavegan/models/__pycache__/parallel_wavegan.cpython-38.pyc b/modules/parallel_wavegan/models/__pycache__/parallel_wavegan.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c979d6fcb27ebf7fd73becb75f311f55f27851c
Binary files /dev/null and b/modules/parallel_wavegan/models/__pycache__/parallel_wavegan.cpython-38.pyc differ
diff --git a/modules/parallel_wavegan/models/__pycache__/source.cpython-38.pyc b/modules/parallel_wavegan/models/__pycache__/source.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..684d62eb5adb24bae290bac8ba022da850a287f4
Binary files /dev/null and b/modules/parallel_wavegan/models/__pycache__/source.cpython-38.pyc differ
diff --git a/modules/parallel_wavegan/models/melgan.py b/modules/parallel_wavegan/models/melgan.py
new file mode 100644
index 0000000000000000000000000000000000000000..e021ae4817a8c1c97338e61b00b230c881836fd8
--- /dev/null
+++ b/modules/parallel_wavegan/models/melgan.py
@@ -0,0 +1,427 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""MelGAN Modules."""
+
+import logging
+
+import numpy as np
+import torch
+
+from modules.parallel_wavegan.layers import CausalConv1d
+from modules.parallel_wavegan.layers import CausalConvTranspose1d
+from modules.parallel_wavegan.layers import ResidualStack
+
+
class MelGANGenerator(torch.nn.Module):
    """MelGAN generator module."""

    def __init__(self,
                 in_channels=80,
                 out_channels=1,
                 kernel_size=7,
                 channels=512,
                 bias=True,
                 upsample_scales=[8, 8, 2, 2],
                 stack_kernel_size=3,
                 stacks=3,
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d",
                 pad_params={},
                 use_final_nonlinear_activation=True,
                 use_weight_norm=True,
                 use_causal_conv=False,
                 ):
        """Initialize MelGANGenerator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of initial and final conv layer.
            channels (int): Initial number of channels for conv layer.
            bias (bool): Whether to add bias parameter in convolution layers.
            upsample_scales (list): List of upsampling scales.
            stack_kernel_size (int): Kernel size of dilated conv layers in residual stack.
            stacks (int): Number of stacks in a single residual stack.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.
            use_final_nonlinear_activation (bool): Whether to append a Tanh after
                the final convolution.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.
            use_causal_conv (bool): Whether to use causal convolution.

        """
        super(MelGANGenerator, self).__init__()

        # check hyper parameters is valid
        assert channels >= np.prod(upsample_scales)
        assert channels % (2 ** len(upsample_scales)) == 0
        if not use_causal_conv:
            assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."

        # add initial layer
        layers = []
        if not use_causal_conv:
            layers += [
                getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
                torch.nn.Conv1d(in_channels, channels, kernel_size, bias=bias),
            ]
        else:
            layers += [
                CausalConv1d(in_channels, channels, kernel_size,
                             bias=bias, pad=pad, pad_params=pad_params),
            ]

        # Channel width after the most recent stage. Tracking it explicitly
        # (instead of reusing the loop index after the loop) keeps the final
        # layer well-defined even when ``upsample_scales`` is empty.
        out_chs = channels
        for i, upsample_scale in enumerate(upsample_scales):
            # every upsampling stage halves the channel width
            in_chs = channels // (2 ** i)
            out_chs = channels // (2 ** (i + 1))

            # add upsampling layer
            layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)]
            if not use_causal_conv:
                layers += [
                    torch.nn.ConvTranspose1d(
                        in_chs,
                        out_chs,
                        upsample_scale * 2,
                        stride=upsample_scale,
                        padding=upsample_scale // 2 + upsample_scale % 2,
                        output_padding=upsample_scale % 2,
                        bias=bias,
                    )
                ]
            else:
                layers += [
                    CausalConvTranspose1d(
                        in_chs,
                        out_chs,
                        upsample_scale * 2,
                        stride=upsample_scale,
                        bias=bias,
                    )
                ]

            # add residual stack; dilation grows as stack_kernel_size ** j
            for j in range(stacks):
                layers += [
                    ResidualStack(
                        kernel_size=stack_kernel_size,
                        channels=out_chs,
                        dilation=stack_kernel_size ** j,
                        bias=bias,
                        nonlinear_activation=nonlinear_activation,
                        nonlinear_activation_params=nonlinear_activation_params,
                        pad=pad,
                        pad_params=pad_params,
                        use_causal_conv=use_causal_conv,
                    )
                ]

        # add final layer
        layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)]
        if not use_causal_conv:
            layers += [
                getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
                torch.nn.Conv1d(out_chs, out_channels, kernel_size, bias=bias),
            ]
        else:
            layers += [
                CausalConv1d(out_chs, out_channels, kernel_size,
                             bias=bias, pad=pad, pad_params=pad_params),
            ]
        if use_final_nonlinear_activation:
            layers += [torch.nn.Tanh()]

        # define the model as a single function
        self.melgan = torch.nn.Sequential(*layers)

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # reset parameters
        self.reset_parameters()

    def forward(self, c):
        """Calculate forward propagation.

        Args:
            c (Tensor): Input tensor (B, channels, T).

        Returns:
            Tensor: Output tensor (B, 1, T * prod(upsample_scales)).

        """
        return self.melgan(c)

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return
            # Log only after the removal actually succeeded.
            logging.debug(f"Weight norm is removed from {m}.")

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        """Reset parameters.

        This initialization follows official implementation manner.
        https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py

        """
        def _reset_parameters(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                m.weight.data.normal_(0.0, 0.02)
                logging.debug(f"Reset parameters in {m}.")

        self.apply(_reset_parameters)
+
+
class MelGANDiscriminator(torch.nn.Module):
    """Single-scale MelGAN discriminator returning every layer's output."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_sizes=[5, 3],
                 channels=16,
                 max_downsample_channels=1024,
                 bias=True,
                 downsample_scales=[4, 4, 4, 4],
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d",
                 pad_params={},
                 ):
        """Build the discriminator layers.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_sizes (list): Two odd kernel sizes. Their product is the
                kernel size of the first conv layer; the first and second
                values are the kernel sizes of the last two layers.
                E.g. [5, 3] gives 15 for the first layer and 5, 3 for the
                last two layers.
            channels (int): Initial number of channels for conv layer.
            max_downsample_channels (int): Upper bound on the channel count.
            bias (bool): Whether to add bias parameter in convolution layers.
            downsample_scales (list): List of downsampling scales.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name used before the first conv layer.
            pad_params (dict): Hyperparameters for padding function.

        """
        super().__init__()
        self.layers = torch.nn.ModuleList()

        # Exactly two odd kernel sizes are required.
        assert len(kernel_sizes) == 2
        assert kernel_sizes[0] % 2 == 1
        assert kernel_sizes[1] % 2 == 1

        def _activation():
            # A new activation instance is created for every layer.
            return getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)

        # First layer: padding + wide convolution + activation.
        first_kernel = np.prod(kernel_sizes)
        self.layers.append(
            torch.nn.Sequential(
                getattr(torch.nn, pad)((first_kernel - 1) // 2, **pad_params),
                torch.nn.Conv1d(in_channels, channels, first_kernel, bias=bias),
                _activation(),
            )
        )

        # Downsampling layers: strided grouped convolutions, widening the
        # channels by the scale up to ``max_downsample_channels``.
        width = channels
        for scale in downsample_scales:
            next_width = min(width * scale, max_downsample_channels)
            self.layers.append(
                torch.nn.Sequential(
                    torch.nn.Conv1d(
                        width, next_width,
                        kernel_size=scale * 10 + 1,
                        stride=scale,
                        padding=scale * 5,
                        groups=width // 4,
                        bias=bias,
                    ),
                    _activation(),
                )
            )
            width = next_width

        # Final two layers use the individual kernel sizes.
        last_width = min(width * 2, max_downsample_channels)
        self.layers.append(
            torch.nn.Sequential(
                torch.nn.Conv1d(
                    width, last_width, kernel_sizes[0],
                    padding=(kernel_sizes[0] - 1) // 2,
                    bias=bias,
                ),
                _activation(),
            )
        )
        self.layers.append(
            torch.nn.Conv1d(
                last_width, out_channels, kernel_sizes[1],
                padding=(kernel_sizes[1] - 1) // 2,
                bias=bias,
            )
        )

    def forward(self, x):
        """Pass the input through every layer and collect all outputs.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List: List of output tensors of each layer.

        """
        collected = []
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden)
            collected.append(hidden)

        return collected
+
+
class MelGANMultiScaleDiscriminator(torch.nn.Module):
    """MelGAN multi-scale discriminator module."""

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 scales=3,
                 downsample_pooling="AvgPool1d",
                 # follow the official implementation setting
                 downsample_pooling_params={
                     "kernel_size": 4,
                     "stride": 2,
                     "padding": 1,
                     "count_include_pad": False,
                 },
                 kernel_sizes=[5, 3],
                 channels=16,
                 max_downsample_channels=1024,
                 bias=True,
                 downsample_scales=[4, 4, 4, 4],
                 nonlinear_activation="LeakyReLU",
                 nonlinear_activation_params={"negative_slope": 0.2},
                 pad="ReflectionPad1d",
                 pad_params={},
                 use_weight_norm=True,
                 ):
        """Initialize MelGAN multi-scale discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            scales (int): Number of discriminators; each one after the first
                receives the input pooled once more.
            downsample_pooling (str): Pooling module name for downsampling of the inputs.
            downsample_pooling_params (dict): Parameters for the above pooling module.
            kernel_sizes (list): List of two kernel sizes. The product will be used for the
                first conv layer, and the first and the second kernel sizes will be used
                for the last two layers.
            channels (int): Initial number of channels for conv layer.
            max_downsample_channels (int): Maximum number of channels for downsampling layers.
            bias (bool): Whether to add bias parameter in convolution layers.
            downsample_scales (list): List of downsampling scales.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (dict): Hyperparameters for activation function.
            pad (str): Padding function module name before dilated convolution layer.
            pad_params (dict): Hyperparameters for padding function.
            use_weight_norm (bool): Whether to use weight norm.
                If set to true, it will be applied to all of the conv layers.

        """
        super(MelGANMultiScaleDiscriminator, self).__init__()
        self.discriminators = torch.nn.ModuleList()

        # add discriminators
        for _ in range(scales):
            self.discriminators += [
                MelGANDiscriminator(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_sizes=kernel_sizes,
                    channels=channels,
                    max_downsample_channels=max_downsample_channels,
                    bias=bias,
                    downsample_scales=downsample_scales,
                    nonlinear_activation=nonlinear_activation,
                    nonlinear_activation_params=nonlinear_activation_params,
                    pad=pad,
                    pad_params=pad_params,
                )
            ]
        self.pooling = getattr(torch.nn, downsample_pooling)(**downsample_pooling_params)

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # reset parameters
        self.reset_parameters()

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List: List of list of each discriminator outputs, which consists of each layer output tensors.

        """
        outs = []
        for f in self.discriminators:
            outs += [f(x)]
            # the next discriminator sees a pooled version of the signal
            x = self.pooling(x)

        return outs

    def remove_weight_norm(self):
        """Remove weight normalization module from all of the layers."""
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return
            # Log only after the removal actually succeeded.
            logging.debug(f"Weight norm is removed from {m}.")

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""
        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        """Reset parameters.

        This initialization follows official implementation manner.
        https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py

        """
        def _reset_parameters(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
                m.weight.data.normal_(0.0, 0.02)
                logging.debug(f"Reset parameters in {m}.")

        self.apply(_reset_parameters)
diff --git a/modules/parallel_wavegan/models/parallel_wavegan.py b/modules/parallel_wavegan/models/parallel_wavegan.py
new file mode 100644
index 0000000000000000000000000000000000000000..c63b59f67aa48342179415c1d1beac68574a5498
--- /dev/null
+++ b/modules/parallel_wavegan/models/parallel_wavegan.py
@@ -0,0 +1,434 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""Parallel WaveGAN Modules."""
+
+import logging
+import math
+
+import torch
+from torch import nn
+
+from modules.parallel_wavegan.layers import Conv1d
+from modules.parallel_wavegan.layers import Conv1d1x1
+from modules.parallel_wavegan.layers import ResidualBlock
+from modules.parallel_wavegan.layers import upsample
+from modules.parallel_wavegan import models
+
+
+class ParallelWaveGANGenerator(torch.nn.Module):
+ """Parallel WaveGAN Generator module."""
+
+ def __init__(self,
+ in_channels=1,
+ out_channels=1,
+ kernel_size=3,
+ layers=30,
+ stacks=3,
+ residual_channels=64,
+ gate_channels=128,
+ skip_channels=64,
+ aux_channels=80,
+ aux_context_window=2,
+ dropout=0.0,
+ bias=True,
+ use_weight_norm=True,
+ use_causal_conv=False,
+ upsample_conditional_features=True,
+ upsample_net="ConvInUpsampleNetwork",
+ upsample_params={"upsample_scales": [4, 4, 4, 4]},
+ use_pitch_embed=False,
+ ):
+ """Initialize Parallel WaveGAN Generator module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ kernel_size (int): Kernel size of dilated convolution.
+ layers (int): Number of residual block layers.
+ stacks (int): Number of stacks i.e., dilation cycles.
+ residual_channels (int): Number of channels in residual conv.
+ gate_channels (int): Number of channels in gated conv.
+ skip_channels (int): Number of channels in skip conv.
+ aux_channels (int): Number of channels for auxiliary feature conv.
+ aux_context_window (int): Context window size for auxiliary feature.
+ dropout (float): Dropout rate. 0.0 means no dropout applied.
+ bias (bool): Whether to use bias parameter in conv layer.
+ use_weight_norm (bool): Whether to use weight norm.
+ If set to true, it will be applied to all of the conv layers.
+ use_causal_conv (bool): Whether to use causal structure.
+ upsample_conditional_features (bool): Whether to use upsampling network.
+ upsample_net (str): Upsampling network architecture.
+ upsample_params (dict): Upsampling network parameters.
+
+ """
+ super(ParallelWaveGANGenerator, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.aux_channels = aux_channels
+ self.layers = layers
+ self.stacks = stacks
+ self.kernel_size = kernel_size
+
+ # check the number of layers and stacks
+ assert layers % stacks == 0
+ layers_per_stack = layers // stacks
+
+ # define first convolution
+ self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
+
+ # define conv + upsampling network
+ if upsample_conditional_features:
+ upsample_params.update({
+ "use_causal_conv": use_causal_conv,
+ })
+ if upsample_net == "MelGANGenerator":
+ assert aux_context_window == 0
+ upsample_params.update({
+ "use_weight_norm": False, # not to apply twice
+ "use_final_nonlinear_activation": False,
+ })
+ self.upsample_net = getattr(models, upsample_net)(**upsample_params)
+ else:
+ if upsample_net == "ConvInUpsampleNetwork":
+ upsample_params.update({
+ "aux_channels": aux_channels,
+ "aux_context_window": aux_context_window,
+ })
+ self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
+ else:
+ self.upsample_net = None
+
+ # define residual blocks
+ self.conv_layers = torch.nn.ModuleList()
+ for layer in range(layers):
+ dilation = 2 ** (layer % layers_per_stack)
+ conv = ResidualBlock(
+ kernel_size=kernel_size,
+ residual_channels=residual_channels,
+ gate_channels=gate_channels,
+ skip_channels=skip_channels,
+ aux_channels=aux_channels,
+ dilation=dilation,
+ dropout=dropout,
+ bias=bias,
+ use_causal_conv=use_causal_conv,
+ )
+ self.conv_layers += [conv]
+
+ # define output layers
+ self.last_conv_layers = torch.nn.ModuleList([
+ torch.nn.ReLU(inplace=True),
+ Conv1d1x1(skip_channels, skip_channels, bias=True),
+ torch.nn.ReLU(inplace=True),
+ Conv1d1x1(skip_channels, out_channels, bias=True),
+ ])
+
+ self.use_pitch_embed = use_pitch_embed
+ if use_pitch_embed:
+ self.pitch_embed = nn.Embedding(300, aux_channels, 0)
+ self.c_proj = nn.Linear(2 * aux_channels, aux_channels)
+
+ # apply weight norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ def forward(self, x, c=None, pitch=None, **kwargs):
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, C_in, T).
+ c (Tensor): Local conditioning auxiliary features (B, C ,T').
+ pitch (Tensor): Local conditioning pitch (B, T').
+
+ Returns:
+ Tensor: Output tensor (B, C_out, T)
+
+ """
+ # perform upsampling
+ if c is not None and self.upsample_net is not None:
+ if self.use_pitch_embed:
+ p = self.pitch_embed(pitch)
+ c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2)
+ c = self.upsample_net(c)
+ assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1))
+
+ # encode to hidden representation
+ x = self.first_conv(x)
+ skips = 0
+ for f in self.conv_layers:
+ x, h = f(x, c)
+ skips += h
+ skips *= math.sqrt(1.0 / len(self.conv_layers))
+
+ # apply final layers
+ x = skips
+ for f in self.last_conv_layers:
+ x = f(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+ def _remove_weight_norm(m):
+ try:
+ logging.debug(f"Weight norm is removed from {m}.")
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+ def _apply_weight_norm(m):
+ if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ @staticmethod
+ def _get_receptive_field_size(layers, stacks, kernel_size,
+ dilation=lambda x: 2 ** x):
+ assert layers % stacks == 0
+ layers_per_cycle = layers // stacks
+ dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
+ return (kernel_size - 1) * sum(dilations) + 1
+
+ @property
+ def receptive_field_size(self):
+ """Return receptive field size."""
+ return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
+
+
+class ParallelWaveGANDiscriminator(torch.nn.Module):
+ """Parallel WaveGAN Discriminator module."""
+
+ def __init__(self,
+ in_channels=1,
+ out_channels=1,
+ kernel_size=3,
+ layers=10,
+ conv_channels=64,
+ dilation_factor=1,
+ nonlinear_activation="LeakyReLU",
+ nonlinear_activation_params={"negative_slope": 0.2},
+ bias=True,
+ use_weight_norm=True,
+ ):
+ """Initialize Parallel WaveGAN Discriminator module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ kernel_size (int): Number of output channels.
+ layers (int): Number of conv layers.
+ conv_channels (int): Number of chnn layers.
+ dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
+ the dilation will be 2, 4, 8, ..., and so on.
+ nonlinear_activation (str): Nonlinear function after each conv.
+ nonlinear_activation_params (dict): Nonlinear function parameters
+ bias (bool): Whether to use bias parameter in conv.
+ use_weight_norm (bool) Whether to use weight norm.
+ If set to true, it will be applied to all of the conv layers.
+
+ """
+ super(ParallelWaveGANDiscriminator, self).__init__()
+ assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
+ assert dilation_factor > 0, "Dilation factor must be > 0."
+ self.conv_layers = torch.nn.ModuleList()
+ conv_in_channels = in_channels
+ for i in range(layers - 1):
+ if i == 0:
+ dilation = 1
+ else:
+ dilation = i if dilation_factor == 1 else dilation_factor ** i
+ conv_in_channels = conv_channels
+ padding = (kernel_size - 1) // 2 * dilation
+ conv_layer = [
+ Conv1d(conv_in_channels, conv_channels,
+ kernel_size=kernel_size, padding=padding,
+ dilation=dilation, bias=bias),
+ getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
+ ]
+ self.conv_layers += conv_layer
+ padding = (kernel_size - 1) // 2
+ last_conv_layer = Conv1d(
+ conv_in_channels, out_channels,
+ kernel_size=kernel_size, padding=padding, bias=bias)
+ self.conv_layers += [last_conv_layer]
+
+ # apply weight norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ def forward(self, x):
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+ Tensor: Output tensor (B, 1, T)
+
+ """
+ for f in self.conv_layers:
+ x = f(x)
+ return x
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+ def _apply_weight_norm(m):
+ if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+ def _remove_weight_norm(m):
+ try:
+ logging.debug(f"Weight norm is removed from {m}.")
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
+
+
+class ResidualParallelWaveGANDiscriminator(torch.nn.Module):
+ """Parallel WaveGAN Discriminator module."""
+
+ def __init__(self,
+ in_channels=1,
+ out_channels=1,
+ kernel_size=3,
+ layers=30,
+ stacks=3,
+ residual_channels=64,
+ gate_channels=128,
+ skip_channels=64,
+ dropout=0.0,
+ bias=True,
+ use_weight_norm=True,
+ use_causal_conv=False,
+ nonlinear_activation="LeakyReLU",
+ nonlinear_activation_params={"negative_slope": 0.2},
+ ):
+ """Initialize Parallel WaveGAN Discriminator module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ kernel_size (int): Kernel size of dilated convolution.
+ layers (int): Number of residual block layers.
+ stacks (int): Number of stacks i.e., dilation cycles.
+ residual_channels (int): Number of channels in residual conv.
+ gate_channels (int): Number of channels in gated conv.
+ skip_channels (int): Number of channels in skip conv.
+ dropout (float): Dropout rate. 0.0 means no dropout applied.
+ bias (bool): Whether to use bias parameter in conv.
+ use_weight_norm (bool): Whether to use weight norm.
+ If set to true, it will be applied to all of the conv layers.
+ use_causal_conv (bool): Whether to use causal structure.
+ nonlinear_activation_params (dict): Nonlinear function parameters
+
+ """
+ super(ResidualParallelWaveGANDiscriminator, self).__init__()
+ assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.layers = layers
+ self.stacks = stacks
+ self.kernel_size = kernel_size
+
+ # check the number of layers and stacks
+ assert layers % stacks == 0
+ layers_per_stack = layers // stacks
+
+ # define first convolution
+ self.first_conv = torch.nn.Sequential(
+ Conv1d1x1(in_channels, residual_channels, bias=True),
+ getattr(torch.nn, nonlinear_activation)(
+ inplace=True, **nonlinear_activation_params),
+ )
+
+ # define residual blocks
+ self.conv_layers = torch.nn.ModuleList()
+ for layer in range(layers):
+ dilation = 2 ** (layer % layers_per_stack)
+ conv = ResidualBlock(
+ kernel_size=kernel_size,
+ residual_channels=residual_channels,
+ gate_channels=gate_channels,
+ skip_channels=skip_channels,
+ aux_channels=-1,
+ dilation=dilation,
+ dropout=dropout,
+ bias=bias,
+ use_causal_conv=use_causal_conv,
+ )
+ self.conv_layers += [conv]
+
+ # define output layers
+ self.last_conv_layers = torch.nn.ModuleList([
+ getattr(torch.nn, nonlinear_activation)(
+ inplace=True, **nonlinear_activation_params),
+ Conv1d1x1(skip_channels, skip_channels, bias=True),
+ getattr(torch.nn, nonlinear_activation)(
+ inplace=True, **nonlinear_activation_params),
+ Conv1d1x1(skip_channels, out_channels, bias=True),
+ ])
+
+ # apply weight norm
+ if use_weight_norm:
+ self.apply_weight_norm()
+
+ def forward(self, x):
+ """Calculate forward propagation.
+
+ Args:
+ x (Tensor): Input noise signal (B, 1, T).
+
+ Returns:
+ Tensor: Output tensor (B, 1, T)
+
+ """
+ x = self.first_conv(x)
+
+ skips = 0
+ for f in self.conv_layers:
+ x, h = f(x, None)
+ skips += h
+ skips *= math.sqrt(1.0 / len(self.conv_layers))
+
+ # apply final layers
+ x = skips
+ for f in self.last_conv_layers:
+ x = f(x)
+ return x
+
+ def apply_weight_norm(self):
+ """Apply weight normalization module from all of the layers."""
+ def _apply_weight_norm(m):
+ if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
+ torch.nn.utils.weight_norm(m)
+ logging.debug(f"Weight norm is applied to {m}.")
+
+ self.apply(_apply_weight_norm)
+
+ def remove_weight_norm(self):
+ """Remove weight normalization module from all of the layers."""
+ def _remove_weight_norm(m):
+ try:
+ logging.debug(f"Weight norm is removed from {m}.")
+ torch.nn.utils.remove_weight_norm(m)
+ except ValueError: # this module didn't have weight norm
+ return
+
+ self.apply(_remove_weight_norm)
diff --git a/modules/parallel_wavegan/models/source.py b/modules/parallel_wavegan/models/source.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2a006e53c0e2194036fd08ea9d6ed4d9a10d6cf
--- /dev/null
+++ b/modules/parallel_wavegan/models/source.py
@@ -0,0 +1,538 @@
+import torch
+import numpy as np
+import sys
+import torch.nn.functional as torch_nn_func
+
+
+class SineGen(torch.nn.Module):
+    """ Definition of sine generator
+    SineGen(samp_rate, harmonic_num = 0,
+            sine_amp = 0.1, noise_std = 0.003,
+            voiced_threshold = 0,
+            flag_for_pulse=False)
+
+    samp_rate: sampling rate in Hz
+    harmonic_num: number of harmonic overtones (default 0)
+    sine_amp: amplitude of sine-waveform (default 0.1)
+    noise_std: std of Gaussian noise (default 0.003)
+    voiced_threshold: F0 threshold for U/V classification (default 0)
+    flag_for_pulse: this SineGen is used inside PulseGen (default False)
+
+    Note: when flag_for_pulse is True, the first time step of a voiced
+    segment is always sin(np.pi) or cos(0)
+
+    NOTE(review): this module defines a class named SineGen a second time
+    further down, so the module-level name ends up bound to that later
+    definition rather than to this one.
+    """
+
+    def __init__(self, samp_rate, harmonic_num=0,
+                 sine_amp=0.1, noise_std=0.003,
+                 voiced_threshold=0,
+                 flag_for_pulse=False):
+        super(SineGen, self).__init__()
+        self.sine_amp = sine_amp
+        self.noise_std = noise_std
+        self.harmonic_num = harmonic_num
+        # one channel for the fundamental plus one per overtone
+        self.dim = self.harmonic_num + 1
+        self.sampling_rate = samp_rate
+        self.voiced_threshold = voiced_threshold
+        self.flag_for_pulse = flag_for_pulse
+
+    def _f02uv(self, f0):
+        """Return a mask shaped like f0: 1 where f0 > voiced_threshold, else 0."""
+        # generate uv signal
+        uv = torch.ones_like(f0)
+        uv = uv * (f0 > self.voiced_threshold)
+        return uv
+
+    def _f02sine(self, f0_values):
+        """ f0_values: (batchsize, length, dim)
+            where dim indicates fundamental tone and overtones
+
+            Returns the unit-amplitude sine (or, when flag_for_pulse is True,
+            cosine) waveforms with the same shape as f0_values.
+        """
+        # convert to F0 in rad. The integer part n can be ignored
+        # because 2 * np.pi * n doesn't affect phase
+        rad_values = (f0_values / self.sampling_rate) % 1
+
+        # initial phase noise (no noise for fundamental component)
+        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
+                              device=f0_values.device)
+        rand_ini[:, 0] = 0
+        # the random offset is added to the first time step only
+        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
+        if not self.flag_for_pulse:
+            # for normal case
+
+            # To prevent torch.cumsum numerical overflow,
+            # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
+            # Buffer tmp_over_one_idx indicates the time step to add -1.
+            # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
+            tmp_over_one = torch.cumsum(rad_values, 1) % 1
+            tmp_over_one_idx = (tmp_over_one[:, 1:, :] -
+                                tmp_over_one[:, :-1, :]) < 0
+            cumsum_shift = torch.zeros_like(rad_values)
+            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+            sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1)
+                              * 2 * np.pi)
+        else:
+            # If necessary, make sure that the first time step of every
+            # voiced segments is sin(pi) or cos(0)
+            # This is used for pulse-train generation
+
+            # identify the last time step in unvoiced segments
+            uv = self._f02uv(f0_values)
+            uv_1 = torch.roll(uv, shifts=-1, dims=1)
+            uv_1[:, -1, :] = 1
+            u_loc = (uv < 1) * (uv_1 > 0)
+
+            # get the instantaneous phase
+            tmp_cumsum = torch.cumsum(rad_values, dim=1)
+            # different batch needs to be processed differently
+            for idx in range(f0_values.shape[0]):
+                # the mask is taken from channel 0 only
+                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+                # stores the accumulation of i.phase within
+                # each voiced segments
+                tmp_cumsum[idx, :, :] = 0
+                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+            # rad_values - tmp_cumsum: remove the accumulation of i.phase
+            # within the previous voiced segment.
+            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+            # get the sines
+            sines = torch.cos(i_phase * 2 * np.pi)
+        return sines
+
+    def forward(self, f0):
+        """ sine_tensor, uv, noise = forward(f0)
+        input F0: tensor(batchsize=1, length, dim=1)
+                  f0 for unvoiced steps should be 0
+        output sine_tensor: tensor(batchsize=1, length, dim)
+        output uv: tensor(batchsize=1, length, 1)
+        output noise: the additive noise that was mixed into sine_tensor
+        """
+        with torch.no_grad():
+            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim,
+                                 device=f0.device)
+            # fundamental component
+            f0_buf[:, :, 0] = f0[:, :, 0]
+            for idx in np.arange(self.harmonic_num):
+                # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
+                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
+
+            # generate sine waveforms
+            sine_waves = self._f02sine(f0_buf) * self.sine_amp
+
+            # generate uv signal
+            # uv = torch.ones(f0.shape)
+            # uv = uv * (f0 > self.voiced_threshold)
+            uv = self._f02uv(f0)
+
+            # noise: for unvoiced should be similar to sine_amp
+            #        std = self.sine_amp/3 -> max value ~ self.sine_amp
+            # .       for voiced regions is self.noise_std
+            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+            noise = noise_amp * torch.randn_like(sine_waves)
+
+            # first: set the unvoiced part to 0 by uv
+            # then: additive noise
+            sine_waves = sine_waves * uv + noise
+        return sine_waves, uv, noise
+
+
+class PulseGen(torch.nn.Module):
+ """ Definition of Pulse train generator
+
+ There are many ways to implement pulse generator.
+ Here, PulseGen is based on SinGen. For a perfect
+ """
+ def __init__(self, samp_rate, pulse_amp = 0.1,
+ noise_std = 0.003, voiced_threshold = 0):
+ super(PulseGen, self).__init__()
+ self.pulse_amp = pulse_amp
+ self.sampling_rate = samp_rate
+ self.voiced_threshold = voiced_threshold
+ self.noise_std = noise_std
+ self.l_sinegen = SineGen(self.sampling_rate, harmonic_num=0, \
+ sine_amp=self.pulse_amp, noise_std=0, \
+ voiced_threshold=self.voiced_threshold, \
+ flag_for_pulse=True)
+
+ def forward(self, f0):
+ """ Pulse train generator
+ pulse_train, uv = forward(f0)
+ input F0: tensor(batchsize=1, length, dim=1)
+ f0 for unvoiced steps should be 0
+ output pulse_train: tensor(batchsize=1, length, dim)
+ output uv: tensor(batchsize=1, length, 1)
+
+ Note: self.l_sine doesn't make sure that the initial phase of
+ a voiced segment is np.pi, the first pulse in a voiced segment
+ may not be at the first time step within a voiced segment
+ """
+ with torch.no_grad():
+ sine_wav, uv, noise = self.l_sinegen(f0)
+
+ # sine without additive noise
+ pure_sine = sine_wav - noise
+
+ # step t corresponds to a pulse if
+ # sine[t] > sine[t+1] & sine[t] > sine[t-1]
+ # & sine[t-1], sine[t+1], and sine[t] are voiced
+ # or
+ # sine[t] is voiced, sine[t-1] is unvoiced
+ # we use torch.roll to simulate sine[t+1] and sine[t-1]
+ sine_1 = torch.roll(pure_sine, shifts=1, dims=1)
+ uv_1 = torch.roll(uv, shifts=1, dims=1)
+ uv_1[:, 0, :] = 0
+ sine_2 = torch.roll(pure_sine, shifts=-1, dims=1)
+ uv_2 = torch.roll(uv, shifts=-1, dims=1)
+ uv_2[:, -1, :] = 0
+
+ loc = (pure_sine > sine_1) * (pure_sine > sine_2) \
+ * (uv_1 > 0) * (uv_2 > 0) * (uv > 0) \
+ + (uv_1 < 1) * (uv > 0)
+
+ # pulse train without noise
+ pulse_train = pure_sine * loc
+
+ # additive noise to pulse train
+ # note that noise from sinegen is zero in voiced regions
+ pulse_noise = torch.randn_like(pure_sine) * self.noise_std
+
+ # with additive noise on pulse, and unvoiced regions
+ pulse_train += pulse_noise * loc + pulse_noise * (1 - uv)
+ return pulse_train, sine_wav, uv, pulse_noise
+
+
+class SignalsConv1d(torch.nn.Module):
+ """ Filtering input signal with time invariant filter
+ Note: FIRFilter conducted filtering given fixed FIR weight
+ SignalsConv1d convolves two signals
+ Note: this is based on torch.nn.functional.conv1d
+
+ """
+
+ def __init__(self):
+ super(SignalsConv1d, self).__init__()
+
+ def forward(self, signal, system_ir):
+ """ output = forward(signal, system_ir)
+
+ signal: (batchsize, length1, dim)
+ system_ir: (length2, dim)
+
+ output: (batchsize, length1, dim)
+ """
+ if signal.shape[-1] != system_ir.shape[-1]:
+ print("Error: SignalsConv1d expects shape:")
+ print("signal (batchsize, length1, dim)")
+ print("system_id (batchsize, length2, dim)")
+ print("But received signal: {:s}".format(str(signal.shape)))
+ print(" system_ir: {:s}".format(str(system_ir.shape)))
+ sys.exit(1)
+ padding_length = system_ir.shape[0] - 1
+ groups = signal.shape[-1]
+
+ # pad signal on the left
+ signal_pad = torch_nn_func.pad(signal.permute(0, 2, 1), \
+ (padding_length, 0))
+ # prepare system impulse response as (dim, 1, length2)
+ # also flip the impulse response
+ ir = torch.flip(system_ir.unsqueeze(1).permute(2, 1, 0), \
+ dims=[2])
+ # convolute
+ output = torch_nn_func.conv1d(signal_pad, ir, groups=groups)
+ return output.permute(0, 2, 1)
+
+
+class CyclicNoiseGen_v1(torch.nn.Module):
+ """ CyclicnoiseGen_v1
+ Cyclic noise with a single parameter of beta.
+ Pytorch v1 implementation assumes f_t is also fixed
+ """
+
+ def __init__(self, samp_rate,
+ noise_std=0.003, voiced_threshold=0):
+ super(CyclicNoiseGen_v1, self).__init__()
+ self.samp_rate = samp_rate
+ self.noise_std = noise_std
+ self.voiced_threshold = voiced_threshold
+
+ self.l_pulse = PulseGen(samp_rate, pulse_amp=1.0,
+ noise_std=noise_std,
+ voiced_threshold=voiced_threshold)
+ self.l_conv = SignalsConv1d()
+
+ def noise_decay(self, beta, f0mean):
+ """ decayed_noise = noise_decay(beta, f0mean)
+ decayed_noise = n[t]exp(-t * f_mean / beta / samp_rate)
+
+ beta: (dim=1) or (batchsize=1, 1, dim=1)
+ f0mean (batchsize=1, 1, dim=1)
+
+ decayed_noise (batchsize=1, length, dim=1)
+ """
+ with torch.no_grad():
+ # exp(-1.0 n / T) < 0.01 => n > -log(0.01)*T = 4.60*T
+ # truncate the noise when decayed by -40 dB
+ length = 4.6 * self.samp_rate / f0mean
+ length = length.int()
+ time_idx = torch.arange(0, length, device=beta.device)
+ time_idx = time_idx.unsqueeze(0).unsqueeze(2)
+ time_idx = time_idx.repeat(beta.shape[0], 1, beta.shape[2])
+
+ noise = torch.randn(time_idx.shape, device=beta.device)
+
+ # due to Pytorch implementation, use f0_mean as the f0 factor
+ decay = torch.exp(-time_idx * f0mean / beta / self.samp_rate)
+ return noise * self.noise_std * decay
+
+ def forward(self, f0s, beta):
+ """ Producde cyclic-noise
+ """
+ # pulse train
+ pulse_train, sine_wav, uv, noise = self.l_pulse(f0s)
+ pure_pulse = pulse_train - noise
+
+ # decayed_noise (length, dim=1)
+ if (uv < 1).all():
+ # all unvoiced
+ cyc_noise = torch.zeros_like(sine_wav)
+ else:
+ f0mean = f0s[uv > 0].mean()
+
+ decayed_noise = self.noise_decay(beta, f0mean)[0, :, :]
+ # convolute
+ cyc_noise = self.l_conv(pure_pulse, decayed_noise)
+
+ # add noise in invoiced segments
+ cyc_noise = cyc_noise + noise * (1.0 - uv)
+ return cyc_noise, pulse_train, sine_wav, uv, noise
+
+
+class SineGen(torch.nn.Module):
+ """ Definition of sine generator
+ SineGen(samp_rate, harmonic_num = 0,
+ sine_amp = 0.1, noise_std = 0.003,
+ voiced_threshold = 0,
+ flag_for_pulse=False)
+
+ samp_rate: sampling rate in Hz
+ harmonic_num: number of harmonic overtones (default 0)
+ sine_amp: amplitude of sine-wavefrom (default 0.1)
+ noise_std: std of Gaussian noise (default 0.003)
+ voiced_thoreshold: F0 threshold for U/V classification (default 0)
+ flag_for_pulse: this SinGen is used inside PulseGen (default False)
+
+ Note: when flag_for_pulse is True, the first time step of a voiced
+ segment is always sin(np.pi) or cos(0)
+ """
+
+ def __init__(self, samp_rate, harmonic_num=0,
+ sine_amp=0.1, noise_std=0.003,
+ voiced_threshold=0,
+ flag_for_pulse=False):
+ super(SineGen, self).__init__()
+ self.sine_amp = sine_amp
+ self.noise_std = noise_std
+ self.harmonic_num = harmonic_num
+ self.dim = self.harmonic_num + 1
+ self.sampling_rate = samp_rate
+ self.voiced_threshold = voiced_threshold
+ self.flag_for_pulse = flag_for_pulse
+
+ def _f02uv(self, f0):
+ # generate uv signal
+ uv = torch.ones_like(f0)
+ uv = uv * (f0 > self.voiced_threshold)
+ return uv
+
+ def _f02sine(self, f0_values):
+ """ f0_values: (batchsize, length, dim)
+ where dim indicates fundamental tone and overtones
+ """
+ # convert to F0 in rad. The interger part n can be ignored
+ # because 2 * np.pi * n doesn't affect phase
+ rad_values = (f0_values / self.sampling_rate) % 1
+
+ # initial phase noise (no noise for fundamental component)
+ rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \
+ device=f0_values.device)
+ rand_ini[:, 0] = 0
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+ # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
+ if not self.flag_for_pulse:
+ # for normal case
+
+ # To prevent torch.cumsum numerical overflow,
+ # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
+ # Buffer tmp_over_one_idx indicates the time step to add -1.
+ # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
+ tmp_over_one = torch.cumsum(rad_values, 1) % 1
+ tmp_over_one_idx = (tmp_over_one[:, 1:, :] -
+ tmp_over_one[:, :-1, :]) < 0
+ cumsum_shift = torch.zeros_like(rad_values)
+ cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+
+ sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1)
+ * 2 * np.pi)
+ else:
+ # If necessary, make sure that the first time step of every
+ # voiced segments is sin(pi) or cos(0)
+ # This is used for pulse-train generation
+
+ # identify the last time step in unvoiced segments
+ uv = self._f02uv(f0_values)
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
+ uv_1[:, -1, :] = 1
+ u_loc = (uv < 1) * (uv_1 > 0)
+
+ # get the instantanouse phase
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
+ # different batch needs to be processed differently
+ for idx in range(f0_values.shape[0]):
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+ # stores the accumulation of i.phase within
+ # each voiced segments
+ tmp_cumsum[idx, :, :] = 0
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
+ # within the previous voiced segment.
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+ # get the sines
+ sines = torch.cos(i_phase * 2 * np.pi)
+ return sines
+
+ def forward(self, f0):
+ """ sine_tensor, uv = forward(f0)
+ input F0: tensor(batchsize=1, length, dim=1)
+ f0 for unvoiced steps should be 0
+ output sine_tensor: tensor(batchsize=1, length, dim)
+ output uv: tensor(batchsize=1, length, 1)
+ """
+ with torch.no_grad():
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, \
+ device=f0.device)
+ # fundamental component
+ f0_buf[:, :, 0] = f0[:, :, 0]
+ for idx in np.arange(self.harmonic_num):
+ # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
+ f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)
+
+ # generate sine waveforms
+ sine_waves = self._f02sine(f0_buf) * self.sine_amp
+
+ # generate uv signal
+ # uv = torch.ones(f0.shape)
+ # uv = uv * (f0 > self.voiced_threshold)
+ uv = self._f02uv(f0)
+
+ # noise: for unvoiced should be similar to sine_amp
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
+ # . for voiced regions is self.noise_std
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+ noise = noise_amp * torch.randn_like(sine_waves)
+
+ # first: set the unvoiced part to 0 by uv
+ # then: additive noise
+ sine_waves = sine_waves * uv + noise
+ return sine_waves, uv, noise
+
+
+class SourceModuleCycNoise_v1(torch.nn.Module):
+ """ SourceModuleCycNoise_v1
+ SourceModule(sampling_rate, noise_std=0.003, voiced_threshod=0)
+ sampling_rate: sampling_rate in Hz
+
+ noise_std: std of Gaussian noise (default: 0.003)
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
+
+ cyc, noise, uv = SourceModuleCycNoise_v1(F0_upsampled, beta)
+ F0_upsampled (batchsize, length, 1)
+ beta (1)
+ cyc (batchsize, length, 1)
+ noise (batchsize, length, 1)
+ uv (batchsize, length, 1)
+ """
+
+ def __init__(self, sampling_rate, noise_std=0.003, voiced_threshod=0):
+ super(SourceModuleCycNoise_v1, self).__init__()
+ self.sampling_rate = sampling_rate
+ self.noise_std = noise_std
+ self.l_cyc_gen = CyclicNoiseGen_v1(sampling_rate, noise_std,
+ voiced_threshod)
+
+ def forward(self, f0_upsamped, beta):
+ """
+ cyc, noise, uv = SourceModuleCycNoise_v1(F0, beta)
+ F0_upsampled (batchsize, length, 1)
+ beta (1)
+ cyc (batchsize, length, 1)
+ noise (batchsize, length, 1)
+ uv (batchsize, length, 1)
+ """
+ # source for harmonic branch
+ cyc, pulse, sine, uv, add_noi = self.l_cyc_gen(f0_upsamped, beta)
+
+ # source for noise branch, in the same shape as uv
+ noise = torch.randn_like(uv) * self.noise_std / 3
+ return cyc, noise, uv
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+ """ SourceModule for hn-nsf
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0)
+ sampling_rate: sampling_rate in Hz
+ harmonic_num: number of harmonic above F0 (default: 0)
+ sine_amp: amplitude of sine source signal (default: 0.1)
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
+ note that amplitude of noise in unvoiced is decided
+ by sine_amp
+ voiced_threshold: threhold to set U/V given F0 (default: 0)
+
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ uv (batchsize, length, 1)
+ """
+
+ def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
+ add_noise_std=0.003, voiced_threshod=0):
+ super(SourceModuleHnNSF, self).__init__()
+
+ self.sine_amp = sine_amp
+ self.noise_std = add_noise_std
+
+ # to produce sine waveforms
+ self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
+ sine_amp, add_noise_std, voiced_threshod)
+
+ # to merge source harmonics into a single excitation
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+ self.l_tanh = torch.nn.Tanh()
+
+ def forward(self, x):
+ """
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+ F0_sampled (batchsize, length, 1)
+ Sine_source (batchsize, length, 1)
+ noise_source (batchsize, length 1)
+ """
+ # source for harmonic branch
+ sine_wavs, uv, _ = self.l_sin_gen(x)
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+ # source for noise branch, in the same shape as uv
+ noise = torch.randn_like(uv) * self.sine_amp / 3
+ return sine_merge, noise, uv
+
+
+if __name__ == '__main__':
+ source = SourceModuleCycNoise_v1(24000)
+ x = torch.randn(16, 25600, 1)
+
+
diff --git a/modules/parallel_wavegan/optimizers/__init__.py b/modules/parallel_wavegan/optimizers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0e0c5932838281e912079e5784d84d43444a61a
--- /dev/null
+++ b/modules/parallel_wavegan/optimizers/__init__.py
@@ -0,0 +1,2 @@
+from torch.optim import * # NOQA
+from .radam import * # NOQA
diff --git a/modules/parallel_wavegan/optimizers/radam.py b/modules/parallel_wavegan/optimizers/radam.py
new file mode 100644
index 0000000000000000000000000000000000000000..e805d7e34921bee436e1e7fd9e1f753c7609186b
--- /dev/null
+++ b/modules/parallel_wavegan/optimizers/radam.py
@@ -0,0 +1,91 @@
+# -*- coding: utf-8 -*-
+
+"""RAdam optimizer.
+
+This code is derived from https://github.com/LiyuanLucasLiu/RAdam.
+"""
+
+import math
+import torch
+
+from torch.optim.optimizer import Optimizer
+
+
+class RAdam(Optimizer):
+    """Rectified Adam optimizer.
+
+    Keeps Adam-style first/second moment estimates per parameter. While the
+    approximated SMA length ``N_sma`` is below 5 the update falls back to a
+    bias-corrected momentum step; otherwise the rectified adaptive step is used.
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+        """Initialize RAdam optimizer.
+
+        Args:
+            params: Iterable of parameters or parameter-group dicts.
+            lr (float): Learning rate.
+            betas (tuple): Decay rates for the first and second moment estimates.
+            eps (float): Term added to the denominator of the adaptive step.
+            weight_decay (float): Weight decay factor (applied scaled by lr).
+
+        """
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        self.buffer = self._new_buffer()
+        super(RAdam, self).__init__(params, defaults)
+
+    @staticmethod
+    def _new_buffer():
+        """Return an empty 10-slot cache of ``[key, N_sma, step_size]`` entries."""
+        return [[None, None, None] for _ in range(10)]
+
+    def __setstate__(self, state):
+        """Set state."""
+        super(RAdam, self).__setstate__(state)
+        # NOTE(review): the base Optimizer pickles only defaults/state/param_groups,
+        # so ``buffer`` is presumably absent after unpickling; recreate it then.
+        if not hasattr(self, 'buffer'):
+            self.buffer = self._new_buffer()
+
+    def step(self, closure=None):
+        """Run one step.
+
+        Args:
+            closure (callable, optional): Re-evaluates the model and returns the loss.
+
+        Returns:
+            The value returned by ``closure``, or None when no closure is given.
+
+        Raises:
+            RuntimeError: If a parameter has a sparse gradient.
+
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            beta1, beta2 = group['betas']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data.float()
+                if grad.is_sparse:
+                    raise RuntimeError('RAdam does not support sparse gradients')
+
+                # all arithmetic is done on a float32 view/copy and copied back at the end
+                p_data_fp32 = p.data.float()
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+
+                # keyword forms (alpha=/value=) replace the deprecated
+                # positional-scalar overloads; the arithmetic is unchanged
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+
+                state['step'] += 1
+                step = state['step']
+
+                # The cached values depend on the betas as well as on the step, so
+                # both are part of the key; otherwise param groups with different
+                # betas would reuse each other's N_sma / step_size.
+                cache_key = (step, beta1, beta2)
+                buffered = self.buffer[int(step % 10)]
+                if buffered[0] == cache_key:
+                    N_sma, step_size = buffered[1], buffered[2]
+                else:
+                    buffered[0] = cache_key
+                    beta2_t = beta2 ** step
+                    N_sma_max = 2 / (1 - beta2) - 1
+                    N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)
+                    buffered[1] = N_sma
+
+                    # more conservative since it's an approximated value
+                    if N_sma >= 5:
+                        step_size = math.sqrt(
+                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** step)  # NOQA
+                    else:
+                        step_size = 1.0 / (1 - beta1 ** step)
+                    buffered[2] = step_size
+
+                if group['weight_decay'] != 0:
+                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
+
+                # more conservative since it's an approximated value
+                if N_sma >= 5:
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
+                else:
+                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
+
+                p.data.copy_(p_data_fp32)
+
+        return loss
diff --git a/modules/parallel_wavegan/stft_loss.py b/modules/parallel_wavegan/stft_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..229e6c777dc9ec7f710842d1e648dba1189ec8b4
--- /dev/null
+++ b/modules/parallel_wavegan/stft_loss.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""STFT-based Loss modules."""
+import librosa
+import torch
+
+from modules.parallel_wavegan.losses import LogSTFTMagnitudeLoss, SpectralConvergengeLoss, stft
+
+
+class STFTLoss(torch.nn.Module):
+    """STFT loss module."""
+
+    def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window",
+                 use_mel_loss=False):
+        """Initialize STFT loss module.
+
+        Args:
+            fft_size (int): FFT size passed to ``stft``.
+            shift_size (int): Hop size passed to ``stft``.
+            win_length (int): Window length.
+            window (str): Name of a window factory looked up on ``torch``.
+            use_mel_loss (bool): If True, project magnitudes onto a mel basis first.
+
+        """
+        super(STFTLoss, self).__init__()
+        self.fft_size = fft_size
+        self.shift_size = shift_size
+        self.win_length = win_length
+        self.window = getattr(torch, window)(win_length)
+        self.spectral_convergenge_loss = SpectralConvergengeLoss()
+        self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
+        self.use_mel_loss = use_mel_loss
+        # built lazily on the first forward call that needs it
+        self.mel_basis = None
+
+    def forward(self, x, y):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Predicted signal (B, T).
+            y (Tensor): Groundtruth signal (B, T).
+
+        Returns:
+            Tensor: Spectral convergence loss value.
+            Tensor: Log STFT magnitude loss value.
+
+        """
+        x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
+        y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window)
+        if self.use_mel_loss:
+            if self.mel_basis is None:
+                # keyword arguments work on every librosa release; the positional
+                # form is rejected by librosa >= 0.10
+                mel = librosa.filters.mel(sr=22050, n_fft=self.fft_size, n_mels=80)
+                self.mel_basis = torch.from_numpy(mel).T
+            # follow the input's device instead of assuming CUDA is available
+            if self.mel_basis.device != x_mag.device:
+                self.mel_basis = self.mel_basis.to(x_mag.device)
+            x_mag = x_mag @ self.mel_basis
+            y_mag = y_mag @ self.mel_basis
+
+        sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
+        mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
+
+        return sc_loss, mag_loss
+
+
+class MultiResolutionSTFTLoss(torch.nn.Module):
+    """Average of several STFT losses computed at different resolutions."""
+
+    def __init__(self,
+                 fft_sizes=[1024, 2048, 512],
+                 hop_sizes=[120, 240, 50],
+                 win_lengths=[600, 1200, 240],
+                 window="hann_window",
+                 use_mel_loss=False):
+        """Create one ``STFTLoss`` per (fft size, hop size, window length) triple.
+
+        Args:
+            fft_sizes (list): List of FFT sizes.
+            hop_sizes (list): List of hop sizes.
+            win_lengths (list): List of window lengths.
+            window (str): Window function type.
+            use_mel_loss (bool): Forwarded to every ``STFTLoss``.
+
+        """
+        super().__init__()
+        assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
+        self.stft_losses = torch.nn.ModuleList()
+        for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes, win_lengths):
+            resolution_loss = STFTLoss(fft_size, hop_size, win_length, window, use_mel_loss)
+            self.stft_losses.append(resolution_loss)
+
+    def forward(self, x, y):
+        """Evaluate every resolution and return the mean of each loss term.
+
+        Args:
+            x (Tensor): Predicted signal (B, T).
+            y (Tensor): Groundtruth signal (B, T).
+
+        Returns:
+            Tensor: Multi resolution spectral convergence loss value.
+            Tensor: Multi resolution log STFT magnitude loss value.
+
+        """
+        per_resolution = [criterion(x, y) for criterion in self.stft_losses]
+        count = len(self.stft_losses)
+
+        sc_loss = sum((sc for sc, _ in per_resolution), 0.0) / count
+        mag_loss = sum((mag for _, mag in per_resolution), 0.0) / count
+
+        return sc_loss, mag_loss
diff --git a/modules/parallel_wavegan/utils/__init__.py b/modules/parallel_wavegan/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8fa95a020706b5412c3959fbf6e5980019c0d5f
--- /dev/null
+++ b/modules/parallel_wavegan/utils/__init__.py
@@ -0,0 +1 @@
+from .utils import * # NOQA
diff --git a/modules/parallel_wavegan/utils/__pycache__/__init__.cpython-38.pyc b/modules/parallel_wavegan/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..856a1203a49a5f131e2e1f110f701ee988cc1b5f
Binary files /dev/null and b/modules/parallel_wavegan/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/modules/parallel_wavegan/utils/__pycache__/utils.cpython-38.pyc b/modules/parallel_wavegan/utils/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1f3736354ac834dab8860599354f3a0d2726b613
Binary files /dev/null and b/modules/parallel_wavegan/utils/__pycache__/utils.cpython-38.pyc differ
diff --git a/modules/parallel_wavegan/utils/utils.py b/modules/parallel_wavegan/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a30e803723d224d9825a753baf4cd91c94c1677
--- /dev/null
+++ b/modules/parallel_wavegan/utils/utils.py
@@ -0,0 +1,169 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""Utility functions."""
+
+import fnmatch
+import logging
+import os
+import sys
+
+import h5py
+import numpy as np
+
+
+def find_files(root_dir, query="*.wav", include_root_dir=True):
+    """Find files recursively.
+
+    Args:
+        root_dir (str): Root root_dir to find.
+        query (str): Query to find (fnmatch pattern applied to file names).
+        include_root_dir (bool): If False, root_dir name is not included.
+
+    Returns:
+        list: List of found filenames.
+
+    """
+    files = []
+    for root, _, filenames in os.walk(root_dir, followlinks=True):
+        for filename in fnmatch.filter(filenames, query):
+            files.append(os.path.join(root, filename))
+    if not include_root_dir:
+        # relpath removes only the leading root_dir; str.replace also removed
+        # later occurrences of "<root_dir>/" and assumed "/" as the separator
+        files = [os.path.relpath(file_, root_dir) for file_ in files]
+
+    return files
+
+
+def read_hdf5(hdf5_name, hdf5_path):
+    """Read hdf5 dataset.
+
+    Args:
+        hdf5_name (str): Filename of hdf5 file.
+        hdf5_path (str): Dataset name in hdf5 file.
+
+    Return:
+        any: Dataset values.
+
+    Note:
+        Logs an error and calls ``sys.exit(1)`` when the file or the dataset
+        does not exist.
+
+    """
+    if not os.path.exists(hdf5_name):
+        logging.error(f"There is no such a hdf5 file ({hdf5_name}).")
+        sys.exit(1)
+
+    # the context manager also closes the file on the early-exit path below,
+    # which previously left the handle open
+    with h5py.File(hdf5_name, "r") as hdf5_file:
+        if hdf5_path not in hdf5_file:
+            logging.error(f"There is no such a data in hdf5 file. ({hdf5_path})")
+            sys.exit(1)
+
+        # [()] reads the whole dataset into memory before the file is closed
+        return hdf5_file[hdf5_path][()]
+
+
+def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True):
+    """Write dataset to hdf5.
+
+    Args:
+        hdf5_name (str): Hdf5 dataset filename.
+        hdf5_path (str): Dataset path in hdf5.
+        write_data (ndarray): Data to write.
+        is_overwrite (bool): Whether to overwrite dataset.
+
+    Note:
+        When the dataset already exists and ``is_overwrite`` is False, an error
+        is logged and ``sys.exit(1)`` is called.
+
+    """
+    # convert to numpy array
+    write_data = np.array(write_data)
+
+    # make sure the target folder exists (exist_ok avoids a check-then-create race)
+    folder_name, _ = os.path.split(hdf5_name)
+    if len(folder_name) != 0:
+        os.makedirs(folder_name, exist_ok=True)
+
+    # an existing file is opened for update, a new one is created
+    already_exists = os.path.exists(hdf5_name)
+    mode = "r+" if already_exists else "w"
+
+    # the context manager closes the file even if writing raises
+    with h5py.File(hdf5_name, mode) as hdf5_file:
+        if already_exists and hdf5_path in hdf5_file:
+            if is_overwrite:
+                logging.warning("Dataset in hdf5 file already exists. "
+                                "recreate dataset in hdf5.")
+                del hdf5_file[hdf5_path]
+            else:
+                logging.error("Dataset in hdf5 file already exists. "
+                              "if you want to overwrite, please set is_overwrite = True.")
+                sys.exit(1)
+
+        # write data to hdf5
+        hdf5_file.create_dataset(hdf5_path, data=write_data)
+        hdf5_file.flush()
+
+
+class HDF5ScpLoader(object):
+    """Loader class for a feats.scp file of hdf5 file.
+
+    Examples:
+        key1 /some/path/a.h5:feats
+        key2 /some/path/b.h5:feats
+        key3 /some/path/c.h5:feats
+        key4 /some/path/d.h5:feats
+        ...
+        >>> loader = HDF5ScpLoader("hdf5.scp")
+        >>> array = loader["key1"]
+
+        key1 /some/path/a.h5
+        key2 /some/path/b.h5
+        key3 /some/path/c.h5
+        key4 /some/path/d.h5
+        ...
+        >>> loader = HDF5ScpLoader("hdf5.scp", "feats")
+        >>> array = loader["key1"]
+
+    """
+
+    def __init__(self, feats_scp, default_hdf5_path="feats"):
+        """Initialize HDF5 scp loader.
+
+        Args:
+            feats_scp (str): Kaldi-style feats.scp file with hdf5 format.
+            default_hdf5_path (str): Path in hdf5 file. If the scp contain the info, not used.
+
+        """
+        self.default_hdf5_path = default_hdf5_path
+        self.data = {}
+        with open(feats_scp, encoding='utf-8') as f:
+            for line in f:
+                line = line.strip()
+                # blank lines used to crash the two-value unpack below
+                if not line:
+                    continue
+                # split on the first whitespace run only, so the value may contain spaces
+                key, value = line.split(maxsplit=1)
+                self.data[key] = value
+
+    def get_path(self, key):
+        """Get hdf5 file path for a given key."""
+        return self.data[key]
+
+    def __getitem__(self, key):
+        """Get ndarray for a given key."""
+        p = self.data[key]
+        if ":" in p:
+            # only the last colon separates the dataset name, so file paths
+            # that contain colons themselves still resolve to two arguments
+            return read_hdf5(*p.rsplit(":", 1))
+        return read_hdf5(p, self.default_hdf5_path)
+
+    def __len__(self):
+        """Return the length of the scp file."""
+        return len(self.data)
+
+    def __iter__(self):
+        """Return the iterator of the scp file."""
+        return iter(self.data)
+
+    def keys(self):
+        """Return the keys of the scp file."""
+        return self.data.keys()
diff --git a/network/diff/__pycache__/candidate_decoder.cpython-38.pyc b/network/diff/__pycache__/candidate_decoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..23ffbe728d4ac88fd1970a2f0d6d517ee2d86000
Binary files /dev/null and b/network/diff/__pycache__/candidate_decoder.cpython-38.pyc differ
diff --git a/network/diff/__pycache__/diffusion.cpython-38.pyc b/network/diff/__pycache__/diffusion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..cfaeb03a6cda242ad7e94ef2a5a757000f702308
Binary files /dev/null and b/network/diff/__pycache__/diffusion.cpython-38.pyc differ
diff --git a/network/diff/__pycache__/net.cpython-38.pyc b/network/diff/__pycache__/net.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9e44334bf72eae2fedc62291a535a5c5ac4f0be7
Binary files /dev/null and b/network/diff/__pycache__/net.cpython-38.pyc differ
diff --git a/network/diff/candidate_decoder.py b/network/diff/candidate_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..bccb47aad0285cab1e7aaca759294c8a1270849c
--- /dev/null
+++ b/network/diff/candidate_decoder.py
@@ -0,0 +1,98 @@
+from modules.fastspeech.tts_modules import FastspeechDecoder
+# from modules.fastspeech.fast_tacotron import DecoderRNN
+# from modules.fastspeech.speedy_speech.speedy_speech import ConvBlocks
+# from modules.fastspeech.conformer.conformer import ConformerDecoder
+import torch
+from torch.nn import functional as F
+import torch.nn as nn
+import math
+from utils.hparams import hparams
+from modules.commons.common_layers import Mish
+Linear = nn.Linear
+
+class SinusoidalPosEmb(nn.Module):
+    """Sinusoidal embedding: sines then cosines over ``dim // 2`` frequencies."""
+
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        """Embed each entry of ``x`` (indexed as a 1-D batch) into ``2 * (dim // 2)`` values."""
+        half_dim = self.dim // 2
+        # geometric frequency ladder from 1 down to 1/10000
+        log_step = math.log(10000) / (half_dim - 1)
+        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
+        angles = x[:, None] * freqs[None, :]
+        return torch.cat((angles.sin(), angles.cos()), dim=-1)
+
+
+def Conv1d(*args, **kwargs):
+    """Build an ``nn.Conv1d`` whose weight is re-initialised with Kaiming-normal."""
+    conv = nn.Conv1d(*args, **kwargs)
+    nn.init.kaiming_normal_(conv.weight)
+    return conv
+
+
+# FastspeechDecoder subclass used as a denoiser: it mixes the noisy spectrogram,
+# the condition and a diffusion-step embedding, runs the inherited layer stack,
+# and projects the result to 80 output channels.
+class FFT(FastspeechDecoder): # unused, because DiffSinger only uses FastspeechEncoder
+ # NOTE: this part of script is *isolated* from other scripts, which means
+ # it may not be compatible with the current version.
+
+ def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None):
+ # The four arguments are forwarded to FastspeechDecoder; all other sizes are read from hparams.
+ super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads)
+ dim = hparams['residual_channels']
+ self.input_projection = Conv1d(hparams['audio_num_mel_bins'], dim, 1)
+ self.diffusion_embedding = SinusoidalPosEmb(dim)
+ self.mlp = nn.Sequential(
+ nn.Linear(dim, dim * 4),
+ Mish(),
+ nn.Linear(dim * 4, dim)
+ )
+ # NOTE(review): the output width is fixed at 80 while the input projection reads
+ # hparams['audio_num_mel_bins'] - confirm the two are meant to agree.
+ self.get_mel_out = Linear(hparams['hidden_size'], 80, bias=True)
+ # Input width matches the concatenation built in forward(): projected spec (dim),
+ # cond (must therefore have hidden_size channels) and the step embedding (dim).
+ self.get_decode_inp = Linear(hparams['hidden_size'] + dim + dim,
+ hparams['hidden_size']) # hs + dim + 80 -> hs
+
+ def forward(self, spec, diffusion_step, cond, padding_mask=None, attn_mask=None, return_hiddens=False):
+ """
+ :param spec: [B, 1, 80, T]
+ :param diffusion_step: [B, 1]
+ :param cond: [B, M, T]
+ :param padding_mask: optional mask; when None it is derived from all-zero positions of the decoder input
+ :param attn_mask: passed through to every layer
+ :param return_hiddens: stack the per-layer outputs instead of using the last one
+ :return: [B, 1, 80, T] on the default (return_hiddens=False) path
+ """
+ x = spec[:, 0]
+ x = self.input_projection(x).permute([0, 2, 1]) # [B, T, residual_channel]
+ diffusion_step = self.diffusion_embedding(diffusion_step)
+ diffusion_step = self.mlp(diffusion_step) # [B, dim]
+ cond = cond.permute([0, 2, 1]) # [B, T, M]
+
+ # broadcast the step embedding over the time axis
+ seq_len = cond.shape[1] # [T_mel]
+ time_embed = diffusion_step[:, None, :] # [B, 1, dim]
+ time_embed = time_embed.repeat([1, seq_len, 1]) # # [B, T, dim]
+
+ decoder_inp = torch.cat([x, cond, time_embed], dim=-1) # [B, T, dim + H + dim]
+ decoder_inp = self.get_decode_inp(decoder_inp) # [B, T, H]
+ x = decoder_inp
+
+ # From here on the code relies on attributes provided by the parent class
+ # (use_pos_embed, pos_embed_alpha, embed_positions, dropout, layers,
+ # use_last_norm, layer_norm), none of which are defined in this file.
+ '''
+ Required x: [B, T, C]
+ :return: [B, T, C] or [L, B, T, C]
+ '''
+ padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask
+ nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1]
+ if self.use_pos_embed:
+ positions = self.pos_embed_alpha * self.embed_positions(x[..., 0])
+ x = x + positions
+ x = F.dropout(x, p=self.dropout, training=self.training)
+ # B x T x C -> T x B x C
+ x = x.transpose(0, 1) * nonpadding_mask_TB
+ hiddens = []
+ for layer in self.layers:
+ x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB
+ hiddens.append(x)
+ if self.use_last_norm:
+ x = self.layer_norm(x) * nonpadding_mask_TB
+ if return_hiddens:
+ x = torch.stack(hiddens, 0) # [L, T, B, C]
+ x = x.transpose(1, 2) # [L, B, T, C]
+ else:
+ x = x.transpose(0, 1) # [B, T, C]
+
+ # NOTE(review): permute([0, 2, 1]) names three dims, so the 4-D return_hiddens
+ # result looks like it would fail here - confirm before using that path.
+ x = self.get_mel_out(x).permute([0, 2, 1]) # [B, 80, T]
+ return x[:, None, :, :]
\ No newline at end of file
diff --git a/network/diff/diffusion.py b/network/diff/diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ff05212929a970974a59735735d704be83ecd3c
--- /dev/null
+++ b/network/diff/diffusion.py
@@ -0,0 +1,332 @@
+from collections import deque
+from functools import partial
+from inspect import isfunction
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from tqdm import tqdm
+
+from modules.fastspeech.fs2 import FastSpeech2
+# from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
+from utils.hparams import hparams
+from training.train_pipeline import Batch2Loss
+
+
+def exists(x):
+    """Return True for every value except ``None``."""
+    return not (x is None)
+
+
+def default(val, d):
+    """Return ``val`` unless it is None; otherwise the fallback ``d``.
+
+    A fallback that is a plain Python function is called and its result returned.
+    """
+    if val is not None:
+        return val
+    if isfunction(d):
+        return d()
+    return d
+
+
+# gaussian diffusion trainer class
+
+def extract(a, t, x_shape):
+    """Gather ``a`` at indices ``t`` and reshape to ``(len(t), 1, ..., 1)``.
+
+    The number of trailing singleton axes is ``len(x_shape) - 1``.
+    """
+    batch, *_ = t.shape
+    picked = a.gather(-1, t)
+    singleton_axes = (1,) * (len(x_shape) - 1)
+    return picked.reshape(batch, *singleton_axes)
+
+
+def noise_like(shape, device, repeat=False):
+    """Draw standard-normal noise of ``shape`` on ``device``.
+
+    With ``repeat=True`` a single sample of shape ``(1, *shape[1:])`` is drawn
+    and tiled along the first axis, so every batch entry gets the same noise.
+    """
+    if repeat:
+        single = torch.randn((1, *shape[1:]), device=device)
+        tiling = (1,) * (len(shape) - 1)
+        return single.repeat(shape[0], *tiling)
+    return torch.randn(shape, device=device)
+
+
+def linear_beta_schedule(timesteps, max_beta=hparams.get('max_beta', 0.01)):
+    """Linear schedule: ``timesteps`` evenly spaced betas from 1e-4 to ``max_beta``.
+
+    The default ``max_beta`` is read from hparams once, when this function is defined.
+    """
+    first_beta = 1e-4
+    return np.linspace(first_beta, max_beta, timesteps)
+
+
+def cosine_beta_schedule(timesteps, s=0.008):
+    """Cosine schedule.
+
+    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
+
+    Returns ``timesteps`` betas, each clipped to the range [0, 0.999].
+    """
+    steps = timesteps + 1
+    grid = np.linspace(0, steps, steps)
+    phase = ((grid / steps) + s) / (1 + s) * np.pi * 0.5
+    alphas_cumprod = np.cos(phase) ** 2
+    # normalise so the cumulative product starts at 1
+    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+    ratios = alphas_cumprod[1:] / alphas_cumprod[:-1]
+    return np.clip(1 - ratios, a_min=0, a_max=0.999)
+
+
+# Schedule name -> function producing the beta array for a given number of timesteps.
+# GaussianDiffusion indexes this table with hparams['schedule_type'].
+beta_schedule = {
+ "cosine": cosine_beta_schedule,
+ "linear": linear_beta_schedule,
+}
+
+
+class GaussianDiffusion(nn.Module):
+    """Gaussian diffusion model conditioned on the FastSpeech2 encoder output.
+
+    ``__init__`` derives all schedule coefficients from the betas and registers
+    them as buffers. ``forward`` either hands the training batch to
+    ``Batch2Loss.module4`` or runs the reverse sampling loop (plain ancestral
+    sampling, or PLMS when ``hparams['pndm_speedup'] > 1``).
+    """
+
+    def __init__(self, phone_encoder, out_dims, denoise_fn,
+                 timesteps=1000, K_step=1000, loss_type=hparams.get('diff_loss_type', 'l1'), betas=None, spec_min=None,
+                 spec_max=None):
+        """Set up the encoder, the denoiser and the schedule buffers.
+
+        Args:
+            phone_encoder: Forwarded to ``FastSpeech2``.
+            out_dims: Forwarded to ``FastSpeech2``; also stored as ``mel_bins``.
+            denoise_fn: Callable invoked as ``denoise_fn(x, t, cond=cond)``.
+            timesteps (int): Schedule length, used only when ``betas`` is None.
+            K_step (int): Number of steps sampled from / trained on.
+            loss_type (str): 'l1' or 'l2'; the default is read from hparams at definition time.
+            betas: Optional explicit beta schedule (tensor or array).
+            spec_min, spec_max: Per-bin bounds used by ``norm_spec`` / ``denorm_spec``.
+        """
+        super().__init__()
+        self.denoise_fn = denoise_fn
+        # if hparams.get('use_midi') is not None and hparams['use_midi']:
+        #     self.fs2 = FastSpeech2MIDI(phone_encoder, out_dims)
+        # else:
+        self.fs2 = FastSpeech2(phone_encoder, out_dims)
+        self.mel_bins = out_dims
+
+        if exists(betas):
+            betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
+        else:
+            if 'schedule_type' in hparams.keys():
+                betas = beta_schedule[hparams['schedule_type']](timesteps)
+            else:
+                betas = cosine_beta_schedule(timesteps)
+
+        alphas = 1. - betas
+        alphas_cumprod = np.cumprod(alphas, axis=0)
+        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+        timesteps, = betas.shape
+        self.num_timesteps = int(timesteps)
+        self.K_step = K_step
+        self.loss_type = loss_type
+
+        # history of noise predictions used by the PLMS sampler
+        self.noise_list = deque(maxlen=4)
+
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+
+        self.register_buffer('betas', to_torch(betas))
+        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+        # calculations for diffusion q(x_t | x_{t-1}) and others
+        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+        # calculations for posterior q(x_{t-1} | x_t, x_0)
+        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+        self.register_buffer('posterior_variance', to_torch(posterior_variance))
+        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+        self.register_buffer('posterior_mean_coef1', to_torch(
+            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
+        self.register_buffer('posterior_mean_coef2', to_torch(
+            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
+
+        self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']])
+        self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']])
+
+    def q_mean_variance(self, x_start, t):
+        """Return mean, variance and log-variance of q(x_t | x_0) at step ``t``."""
+        mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+        variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
+        log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+        return mean, variance, log_variance
+
+    def predict_start_from_noise(self, x_t, t, noise):
+        """Recover the x_0 estimate from ``x_t`` and a noise prediction."""
+        return (
+                extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+                extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+        )
+
+    def q_posterior(self, x_start, x_t, t):
+        """Return mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
+        posterior_mean = (
+                extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+                extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+        )
+        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+        return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+    def p_mean_variance(self, x, t, cond, clip_denoised: bool):
+        """Predict the noise, reconstruct x_0 (optionally clamped to [-1, 1]) and return the posterior."""
+        noise_pred = self.denoise_fn(x, t, cond=cond)
+        x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
+
+        if clip_denoised:
+            x_recon.clamp_(-1., 1.)
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
+        """Draw x_{t-1} from the model posterior given ``x`` at step ``t``."""
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.no_grad()
+    def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
+        """
+        Use the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
+
+        ``interval`` is the stride between sampled steps. The method appends to
+        ``self.noise_list``, so the caller resets that deque before a new run.
+        ``clip_denoised`` and ``repeat_noise`` are accepted but not used here.
+        """
+
+        def get_x_pred(x, noise_t, t):
+            a_t = extract(self.alphas_cumprod, t, x.shape)
+            a_prev = extract(self.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape)
+            a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
+
+            x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
+            x_pred = x + x_delta
+
+            return x_pred
+
+        noise_list = self.noise_list
+        noise_pred = self.denoise_fn(x, t, cond=cond)
+
+        if len(noise_list) == 0:
+            x_pred = get_x_pred(x, noise_pred, t)
+            # Element-wise clamp keeps t_prev a tensor for any batch size. Python's
+            # max() on a tensor fails for batch > 1 and could return the int 0.
+            t_prev = torch.clamp(t - interval, min=0)
+            noise_pred_prev = self.denoise_fn(x_pred, t_prev, cond=cond)
+            noise_pred_prime = (noise_pred + noise_pred_prev) / 2
+        elif len(noise_list) == 1:
+            noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
+        elif len(noise_list) == 2:
+            noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
+        elif len(noise_list) >= 3:
+            noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24
+
+        x_prev = get_x_pred(x, noise_pred_prime, t)
+        noise_list.append(noise_pred)
+
+        return x_prev
+
+    def q_sample(self, x_start, t, noise=None):
+        """Diffuse ``x_start`` forward to step ``t`` (fresh Gaussian noise unless ``noise`` is given)."""
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (
+                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+        )
+
+    def p_losses(self, x_start, t, cond, noise=None, nonpadding=None):
+        """Noise-prediction loss at step ``t``.
+
+        Raises:
+            NotImplementedError: If ``self.loss_type`` is neither 'l1' nor 'l2'.
+        """
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        x_recon = self.denoise_fn(x_noisy, t, cond)
+
+        if self.loss_type == 'l1':
+            if nonpadding is not None:
+                loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean()
+            else:
+                # print('are you sure w/o nonpadding?')
+                loss = (noise - x_recon).abs().mean()
+
+        elif self.loss_type == 'l2':
+            loss = F.mse_loss(noise, x_recon)
+        else:
+            raise NotImplementedError()
+
+        return loss
+
+    def forward(self, hubert, mel2ph=None, spk_embed=None,
+                ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
+        '''
+        conditioning diffusion, use fastspeech2 encoder output as the condition
+
+        Training (``infer=False``) delegates to ``Batch2Loss.module4``. Inference
+        starts from pure noise, or from the noised ground-truth mel when
+        ``kwargs['use_gt_mel']`` is set (``kwargs['add_noise_step']`` then gives
+        the start step), and writes the denormalised result to ``ret['mel_out']``.
+        Returns the ``ret`` dict produced by the fs2 encoder.
+        '''
+        ret = self.fs2(hubert, mel2ph, spk_embed, None, f0, uv, energy,
+                       skip_decoder=True, infer=infer, **kwargs)
+        cond = ret['decoder_inp'].transpose(1, 2)
+        b, *_, device = *hubert.shape, hubert.device
+
+        if not infer:
+            Batch2Loss.module4(
+                self.p_losses,
+                self.norm_spec(ref_mels), cond, ret, self.K_step, b, device
+            )
+        else:
+            # (A disabled variant that started sampling from the fs2 mel output
+            # used to sit here as an inert string literal.)
+            if kwargs.get('use_gt_mel'):
+                t = kwargs['add_noise_step']
+                print('===>using ground truth mel as start, please make sure parameter "key==0" !')
+                fs2_mels = ref_mels
+                fs2_mels = self.norm_spec(fs2_mels)
+                fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
+                x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
+            else:
+                t = self.K_step
+                shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
+                x = torch.randn(shape, device=device)
+            if hparams.get('pndm_speedup') and hparams['pndm_speedup'] > 1:
+                # fresh PLMS history for every sampling run
+                self.noise_list = deque(maxlen=4)
+                iteration_interval = hparams['pndm_speedup']
+                for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step',
+                              total=t // iteration_interval):
+                    x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), iteration_interval,
+                                           cond)
+            else:
+                for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
+                    x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
+            x = x[:, 0].transpose(1, 2)
+            if mel2ph is not None:  # for singing
+                ret['mel_out'] = self.denorm_spec(x) * ((mel2ph > 0).float()[:, :, None])
+            else:
+                ret['mel_out'] = self.denorm_spec(x)
+        return ret
+
+    def norm_spec(self, x):
+        """Map values from [spec_min, spec_max] to [-1, 1]."""
+        return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
+
+    def denorm_spec(self, x):
+        """Inverse of ``norm_spec``."""
+        return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
+
+    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
+        """Delegate to the fs2 module's ``cwt2f0_norm``."""
+        return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
+
+    def out2mel(self, x):
+        """Identity; returns ``x`` unchanged."""
+        return x
+
+
+class OfflineGaussianDiffusion(GaussianDiffusion):
+    """Diffusion variant whose ``ref_mels`` carries two entries.
+
+    Index 0 is the mel used as the training target; index 1 is the mel that is
+    noised to obtain the sampling start point at inference time.
+    """
+
+    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
+                ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
+        """Return the fs2 ``ret`` dict extended with ``diff_loss`` (training) or ``mel_out`` (inference)."""
+        batch, *_, device = *txt_tokens.shape, txt_tokens.device
+
+        # the encoder always runs in inference mode and receives the ref_mels pair as given
+        ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
+                       skip_decoder=True, infer=True, **kwargs)
+        cond = ret['decoder_inp'].transpose(1, 2)
+        target_mels, fs2_mels = ref_mels[0], ref_mels[1]
+
+        if not infer:
+            t = torch.randint(0, self.K_step, (batch,), device=device).long()
+            x = self.norm_spec(target_mels).transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
+            ret['diff_loss'] = self.p_losses(x, t, cond)
+            return ret
+
+        t = self.K_step
+        start = self.norm_spec(fs2_mels).transpose(1, 2)[:, None, :, :]
+        x = self.q_sample(x_start=start, t=torch.tensor([t - 1], device=device).long())
+
+        if hparams.get('gaussian_start'):
+            print('===> gaussion start.')
+            shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
+            x = torch.randn(shape, device=device)
+
+        for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
+            step = torch.full((batch,), i, device=device, dtype=torch.long)
+            x = self.p_sample(x, step, cond)
+
+        ret['mel_out'] = self.denorm_spec(x[:, 0].transpose(1, 2))
+        return ret
\ No newline at end of file
diff --git a/network/diff/net.py b/network/diff/net.py
new file mode 100644
index 0000000000000000000000000000000000000000..df46b54cbd545b99e13b7f3802b7fc18f388e0c8
--- /dev/null
+++ b/network/diff/net.py
@@ -0,0 +1,135 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from math import sqrt
+
+from utils.hparams import hparams
+from modules.commons.common_layers import Mish
+
+Linear = nn.Linear
+ConvTranspose2d = nn.ConvTranspose2d
+
+
+class AttrDict(dict):
+    """Dict whose items are also reachable as attributes (``__dict__`` is the dict itself)."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.__dict__ = self
+
+    def override(self, attrs):
+        """Merge ``attrs`` into this dict and return ``self``.
+
+        ``attrs`` may be a dict, a list/tuple/set of such values (applied one
+        by one), or None (no-op). Anything else raises NotImplementedError.
+        """
+        if attrs is None:
+            return self
+        if isinstance(attrs, dict):
+            self.__dict__.update(**attrs)
+            return self
+        if isinstance(attrs, (list, tuple, set)):
+            for item in attrs:
+                self.override(item)
+            return self
+        raise NotImplementedError
+
+
+class SinusoidalPosEmb(nn.Module):
+    """Sinusoidal embedding of a 1-D batch of values (sin half followed by cos half)."""
+
+    def __init__(self, dim):
+        super().__init__()
+        self.dim = dim
+
+    def forward(self, x):
+        """Return ``cat(sin(x * f), cos(x * f))`` for ``dim // 2`` frequencies ``f``."""
+        half_dim = self.dim // 2
+        # exponent spacing so the frequencies run from 1 down to 1/10000
+        decay = math.log(10000) / (half_dim - 1)
+        frequencies = torch.exp(torch.arange(half_dim, device=x.device) * -decay)
+        phase = x[:, None] * frequencies[None, :]
+        return torch.cat((phase.sin(), phase.cos()), dim=-1)
+
+
+def Conv1d(*args, **kwargs):
+    """Create an ``nn.Conv1d`` and apply Kaiming-normal initialisation to its weight."""
+    conv_layer = nn.Conv1d(*args, **kwargs)
+    nn.init.kaiming_normal_(conv_layer.weight)
+    return conv_layer
+
+
+@torch.jit.script
+def silu(x):
+    """SiLU activation: the input multiplied by its sigmoid."""
+    gate = torch.sigmoid(x)
+    return x * gate
+
+
+class ResidualBlock(nn.Module):
+    """Gated dilated-convolution block that emits a residual output and a skip output."""
+
+    def __init__(self, encoder_hidden, residual_channels, dilation):
+        super().__init__()
+        doubled = 2 * residual_channels
+        self.dilated_conv = Conv1d(residual_channels, doubled, 3, padding=dilation, dilation=dilation)
+        self.diffusion_projection = Linear(residual_channels, residual_channels)
+        self.conditioner_projection = Conv1d(encoder_hidden, doubled, 1)
+        self.output_projection = Conv1d(residual_channels, doubled, 1)
+
+    def forward(self, x, conditioner, diffusion_step):
+        """Return ``((x + residual) / sqrt(2), skip)``.
+
+        The projected diffusion step is added to ``x`` (broadcast over the last
+        axis), the projected conditioner is added after the dilated convolution,
+        and the result is split along dim 1 into a sigmoid gate and a tanh filter.
+        """
+        step_bias = self.diffusion_projection(diffusion_step).unsqueeze(-1)
+        cond_bias = self.conditioner_projection(conditioner)
+
+        hidden = self.dilated_conv(x + step_bias) + cond_bias
+
+        # torch.chunk is kept; torch.split was noted as an alternative that avoids onnx::Slice
+        gate, filt = torch.chunk(hidden, 2, dim=1)
+        gated = torch.sigmoid(gate) * torch.tanh(filt)
+
+        residual, skip = torch.chunk(self.output_projection(gated), 2, dim=1)
+        return (x + residual) / sqrt(2.0), skip
+
+class DiffNet(nn.Module):
+    """Denoiser built from a stack of ``ResidualBlock`` layers with summed skip outputs."""
+
+    def __init__(self, in_dims=80):
+        super().__init__()
+        # model sizes are read from hparams and kept on the module as ``params``
+        self.params = params = AttrDict(
+            encoder_hidden=hparams['hidden_size'],
+            residual_layers=hparams['residual_layers'],
+            residual_channels=hparams['residual_channels'],
+            dilation_cycle_length=hparams['dilation_cycle_length'],
+        )
+        channels = params.residual_channels
+        self.input_projection = Conv1d(in_dims, channels, 1)
+        self.diffusion_embedding = SinusoidalPosEmb(channels)
+        self.mlp = nn.Sequential(
+            nn.Linear(channels, channels * 4),
+            Mish(),
+            nn.Linear(channels * 4, channels)
+        )
+        # dilation doubles per layer and wraps every dilation_cycle_length layers
+        blocks = []
+        for index in range(params.residual_layers):
+            dilation = 2 ** (index % params.dilation_cycle_length)
+            blocks.append(ResidualBlock(params.encoder_hidden, channels, dilation))
+        self.residual_layers = nn.ModuleList(blocks)
+        self.skip_projection = Conv1d(channels, channels, 1)
+        self.output_projection = Conv1d(channels, in_dims, 1)
+        # the output projection's weight starts at zero
+        nn.init.zeros_(self.output_projection.weight)
+
+    def forward(self, spec, diffusion_step, cond):
+        """
+
+        :param spec: [B, 1, M, T]
+        :param diffusion_step: [B, 1]
+        :param cond: [B, M, T]
+        :return: output projection of the merged skip connections, with a
+            singleton axis inserted at dim 1
+        """
+        hidden = F.relu(self.input_projection(spec[:, 0]))  # [B, residual_channel, T]
+
+        step_embedding = self.mlp(self.diffusion_embedding(diffusion_step))
+
+        skips = []
+        for block in self.residual_layers:
+            hidden, skip_out = block(hidden, cond, step_embedding)
+            skips.append(skip_out)
+
+        merged = torch.sum(torch.stack(skips), dim=0) / sqrt(len(self.residual_layers))
+        merged = F.relu(self.skip_projection(merged))
+        out = self.output_projection(merged)  # [B, 80, T]
+        return out[:, None, :, :]
diff --git a/network/hubert/Hifi.txt b/network/hubert/Hifi.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4fc0b8e986f454fa2a5a9b878f1b89d56dd6a1b8
--- /dev/null
+++ b/network/hubert/Hifi.txt
@@ -0,0 +1 @@
+fuck2023arianitesshitcelebrities
\ No newline at end of file
diff --git a/network/hubert/__pycache__/hubert_model.cpython-38.pyc b/network/hubert/__pycache__/hubert_model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..877626aac6c044efc0b362a07328a14903be4ae7
Binary files /dev/null and b/network/hubert/__pycache__/hubert_model.cpython-38.pyc differ
diff --git a/network/hubert/__pycache__/vec_model.cpython-38.pyc b/network/hubert/__pycache__/vec_model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..13cb994a010702d266eb9925006e4971a7092732
Binary files /dev/null and b/network/hubert/__pycache__/vec_model.cpython-38.pyc differ
diff --git a/network/hubert/hubert_model.py b/network/hubert/hubert_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..54379f3d47383810c545c53e042d806734e9f8de
--- /dev/null
+++ b/network/hubert/hubert_model.py
@@ -0,0 +1,276 @@
+import copy
+import os
+import random
+from typing import Optional, Tuple
+
+import librosa
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as t_func
+from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
+
+from utils import hparams
+
+
+class Hubert(nn.Module):
+ def __init__(self, num_label_embeddings: int = 100, mask: bool = True):
+ super().__init__()
+ self._mask = mask
+ self.feature_extractor = FeatureExtractor()
+ self.feature_projection = FeatureProjection()
+ self.positional_embedding = PositionalConvEmbedding()
+ self.norm = nn.LayerNorm(768)
+ self.dropout = nn.Dropout(0.1)
+ self.encoder = TransformerEncoder(
+ nn.TransformerEncoderLayer(
+ 768, 12, 3072, activation="gelu", batch_first=True
+ ),
+ 12,
+ )
+ self.proj = nn.Linear(768, 256)
+
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(768).uniform_())
+ self.label_embedding = nn.Embedding(num_label_embeddings, 256)
+
+ def mask(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ mask = None
+ if self.training and self._mask:
+ mask = _compute_mask((x.size(0), x.size(1)), 0.8, 10, x.device, 2)
+ x[mask] = self.masked_spec_embed.to(x.dtype)
+ return x, mask
+
+ def encode(
+ self, x: torch.Tensor, layer: Optional[int] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ x = self.feature_extractor(x)
+ x = self.feature_projection(x.transpose(1, 2))
+ x, mask = self.mask(x)
+ x = x + self.positional_embedding(x)
+ x = self.dropout(self.norm(x))
+ x = self.encoder(x, output_layer=layer)
+ return x, mask
+
+ def logits(self, x: torch.Tensor) -> torch.Tensor:
+ logits = torch.cosine_similarity(
+ x.unsqueeze(2),
+ self.label_embedding.weight.unsqueeze(0).unsqueeze(0),
+ dim=-1,
+ )
+ return logits / 0.1
+
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ x, mask = self.encode(x)
+ x = self.proj(x)
+ logits = self.logits(x)
+ return logits, mask
+
+
+class HubertSoft(Hubert):
+ def __init__(self):
+ super().__init__()
+
+ # @torch.inference_mode()
+ def units(self, wav: torch.Tensor) -> torch.Tensor:
+ wav = torch.nn.functional.pad(wav, ((400 - 320) // 2, (400 - 320) // 2))
+ x, _ = self.encode(wav)
+ return self.proj(x)
+
+ def forward(self, wav: torch.Tensor):
+ return self.units(wav)
+
+
+class FeatureExtractor(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv0 = nn.Conv1d(1, 512, 10, 5, bias=False)
+ self.norm0 = nn.GroupNorm(512, 512)
+ self.conv1 = nn.Conv1d(512, 512, 3, 2, bias=False)
+ self.conv2 = nn.Conv1d(512, 512, 3, 2, bias=False)
+ self.conv3 = nn.Conv1d(512, 512, 3, 2, bias=False)
+ self.conv4 = nn.Conv1d(512, 512, 3, 2, bias=False)
+ self.conv5 = nn.Conv1d(512, 512, 2, 2, bias=False)
+ self.conv6 = nn.Conv1d(512, 512, 2, 2, bias=False)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = t_func.gelu(self.norm0(self.conv0(x)))
+ x = t_func.gelu(self.conv1(x))
+ x = t_func.gelu(self.conv2(x))
+ x = t_func.gelu(self.conv3(x))
+ x = t_func.gelu(self.conv4(x))
+ x = t_func.gelu(self.conv5(x))
+ x = t_func.gelu(self.conv6(x))
+ return x
+
+
+class FeatureProjection(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.norm = nn.LayerNorm(512)
+ self.projection = nn.Linear(512, 768)
+ self.dropout = nn.Dropout(0.1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.norm(x)
+ x = self.projection(x)
+ x = self.dropout(x)
+ return x
+
+
+class PositionalConvEmbedding(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ 768,
+ 768,
+ kernel_size=128,
+ padding=128 // 2,
+ groups=16,
+ )
+ self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv(x.transpose(1, 2))
+ x = t_func.gelu(x[:, :, :-1])
+ return x.transpose(1, 2)
+
+
+class TransformerEncoder(nn.Module):
+ def __init__(
+ self, encoder_layer: nn.TransformerEncoderLayer, num_layers: int
+ ) -> None:
+ super(TransformerEncoder, self).__init__()
+ self.layers = nn.ModuleList(
+ [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
+ )
+ self.num_layers = num_layers
+
+ def forward(
+ self,
+ src: torch.Tensor,
+ mask: torch.Tensor = None,
+ src_key_padding_mask: torch.Tensor = None,
+ output_layer: Optional[int] = None,
+ ) -> torch.Tensor:
+ output = src
+ for layer in self.layers[:output_layer]:
+ output = layer(
+ output, src_mask=mask, src_key_padding_mask=src_key_padding_mask
+ )
+ return output
+
+
+def _compute_mask(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ device: torch.device,
+ min_masks: int = 0,
+) -> torch.Tensor:
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
+ )
+
+ # compute number of masked spans in batch
+ num_masked_spans = int(mask_prob * sequence_length / mask_length + random.random())
+ num_masked_spans = max(num_masked_spans, min_masks)
+
+ # make sure num masked indices <= sequence_length
+ if num_masked_spans * mask_length > sequence_length:
+ num_masked_spans = sequence_length // mask_length
+
+ # SpecAugment mask to fill
+ mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
+
+ # uniform distribution to sample from, make sure that offset samples are < sequence_length
+ uniform_dist = torch.ones(
+ (batch_size, sequence_length - (mask_length - 1)), device=device
+ )
+
+ # get random indices to mask
+ mask_indices = torch.multinomial(uniform_dist, num_masked_spans)
+
+ # expand masked indices to masked spans
+ mask_indices = (
+ mask_indices.unsqueeze(dim=-1)
+ .expand((batch_size, num_masked_spans, mask_length))
+ .reshape(batch_size, num_masked_spans * mask_length)
+ )
+ offsets = (
+ torch.arange(mask_length, device=device)[None, None, :]
+ .expand((batch_size, num_masked_spans, mask_length))
+ .reshape(batch_size, num_masked_spans * mask_length)
+ )
+ mask_idxs = mask_indices + offsets
+
+ # scatter indices to mask
+ mask = mask.scatter(1, mask_idxs, True)
+
+ return mask
+
+
+def hubert_soft(
+ path: str
+) -> HubertSoft:
+ r"""HuBERT-Soft from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
+ Args:
+ path (str): path of a pretrained model
+ """
+ dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ hubert = HubertSoft()
+ checkpoint = torch.load(path)
+ consume_prefix_in_state_dict_if_present(checkpoint, "module.")
+ hubert.load_state_dict(checkpoint)
+ hubert.eval().to(dev)
+ return hubert
+
+
+def get_units(hbt_soft, raw_wav_path, dev=torch.device('cuda')):
+ wav, sr = librosa.load(raw_wav_path, sr=None)
+ assert (sr >= 16000)
+ if len(wav.shape) > 1:
+ wav = librosa.to_mono(wav)
+ if sr != 16000:
+ wav16 = librosa.resample(wav, sr, 16000)
+ else:
+ wav16 = wav
+ dev = torch.device("cuda" if (dev == torch.device('cuda') and torch.cuda.is_available()) else "cpu")
+ torch.cuda.is_available() and torch.cuda.empty_cache()
+ with torch.inference_mode():
+ units = hbt_soft.units(torch.FloatTensor(wav16.astype(float)).unsqueeze(0).unsqueeze(0).to(dev))
+ return units
+
+
+def get_end_file(dir_path, end):
+ file_list = []
+ for root, dirs, files in os.walk(dir_path):
+ files = [f for f in files if f[0] != '.']
+ dirs[:] = [d for d in dirs if d[0] != '.']
+ for f_file in files:
+ if f_file.endswith(end):
+ file_list.append(os.path.join(root, f_file).replace("\\", "/"))
+ return file_list
+
+
+if __name__ == '__main__':
+ from pathlib import Path
+
+ dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # hubert的模型路径
+ hbt_model = hubert_soft(str(list(Path(hparams['hubert_path']).home().rglob('*.pt'))[0]))
+ # 这个不用改,自动在根目录下所有wav的同文件夹生成其对应的npy
+ file_lists = list(Path(hparams['raw_data_dir']).rglob('*.wav'))
+ nums = len(file_lists)
+ count = 0
+ for wav_path in file_lists:
+ npy_path = wav_path.with_suffix(".npy")
+ npy_content = get_units(hbt_model, wav_path).cpu().numpy()[0]
+ np.save(str(npy_path), npy_content)
+ count += 1
+ print(f"hubert process:{round(count * 100 / nums, 2)}%")
diff --git a/network/hubert/vec_model.py b/network/hubert/vec_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee4b7a152ab9f299d43d6f519564caca25d836a5
--- /dev/null
+++ b/network/hubert/vec_model.py
@@ -0,0 +1,60 @@
+from pathlib import Path
+
+import librosa
+import numpy as np
+import torch
+
+
+
+def load_model(vec_path):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print("load model(s) from {}".format(vec_path))
+ from fairseq import checkpoint_utils
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+ [vec_path],
+ suffix="",
+ )
+ model = models[0]
+ model = model.to(device)
+ model.eval()
+ return model
+
+
+def get_vec_units(con_model, audio_path, dev):
+ audio, sampling_rate = librosa.load(audio_path)
+ if len(audio.shape) > 1:
+ audio = librosa.to_mono(audio.transpose(1, 0))
+ if sampling_rate != 16000:
+ audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
+
+ feats = torch.from_numpy(audio).float()
+ if feats.dim() == 2: # double channels
+ feats = feats.mean(-1)
+ assert feats.dim() == 1, feats.dim()
+ feats = feats.view(1, -1)
+ padding_mask = torch.BoolTensor(feats.shape).fill_(False)
+ inputs = {
+ "source": feats.to(dev),
+ "padding_mask": padding_mask.to(dev),
+ "output_layer": 9, # layer 9
+ }
+ with torch.no_grad():
+ logits = con_model.extract_features(**inputs)
+ feats = con_model.final_proj(logits[0])
+ return feats
+
+
+if __name__ == '__main__':
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model_path = "../../checkpoints/checkpoint_best_legacy_500.pt" # checkpoint_best_legacy_500.pt
+ vec_model = load_model(model_path)
+ # 这个不用改,自动在根目录下所有wav的同文件夹生成其对应的npy
+ file_lists = list(Path("../../data/vecfox").rglob('*.wav'))
+ nums = len(file_lists)
+ count = 0
+ for wav_path in file_lists:
+ npy_path = wav_path.with_suffix(".npy")
+ npy_content = get_vec_units(vec_model, str(wav_path), device).cpu().numpy()[0]
+ np.save(str(npy_path), npy_content)
+ count += 1
+ print(f"hubert process:{round(count * 100 / nums, 2)}%")
diff --git a/network/vocoders/__init__.py b/network/vocoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6631bafa406a3e3add4903f3e7a11957d416a78f
--- /dev/null
+++ b/network/vocoders/__init__.py
@@ -0,0 +1,2 @@
+from network.vocoders import hifigan
+from network.vocoders import nsf_hifigan
diff --git a/network/vocoders/__pycache__/__init__.cpython-38.pyc b/network/vocoders/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34f3ccdc52637b18428ae7335d6e1063f5e12fda
Binary files /dev/null and b/network/vocoders/__pycache__/__init__.cpython-38.pyc differ
diff --git a/network/vocoders/__pycache__/base_vocoder.cpython-38.pyc b/network/vocoders/__pycache__/base_vocoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bf6210242ee7ca268e80f2c2e55f5c04119ce306
Binary files /dev/null and b/network/vocoders/__pycache__/base_vocoder.cpython-38.pyc differ
diff --git a/network/vocoders/__pycache__/hifigan.cpython-38.pyc b/network/vocoders/__pycache__/hifigan.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..fdacbba6320cbc032beb7b70c742db8c0c98d3b0
Binary files /dev/null and b/network/vocoders/__pycache__/hifigan.cpython-38.pyc differ
diff --git a/network/vocoders/__pycache__/nsf_hifigan.cpython-38.pyc b/network/vocoders/__pycache__/nsf_hifigan.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2260b176a210d5fec3269ea0d3f1b618b5a77753
Binary files /dev/null and b/network/vocoders/__pycache__/nsf_hifigan.cpython-38.pyc differ
diff --git a/network/vocoders/__pycache__/pwg.cpython-38.pyc b/network/vocoders/__pycache__/pwg.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..def508d0e8f28d332075ffb6098bfacf79f612fa
Binary files /dev/null and b/network/vocoders/__pycache__/pwg.cpython-38.pyc differ
diff --git a/network/vocoders/__pycache__/vocoder_utils.cpython-38.pyc b/network/vocoders/__pycache__/vocoder_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5d1324a12266ff142ca0e9a727d5bec5be8ae1f
Binary files /dev/null and b/network/vocoders/__pycache__/vocoder_utils.cpython-38.pyc differ
diff --git a/network/vocoders/base_vocoder.py b/network/vocoders/base_vocoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe49a9e4f790ecdc5e76d60a23f96602b59fc48d
--- /dev/null
+++ b/network/vocoders/base_vocoder.py
@@ -0,0 +1,39 @@
+import importlib
+VOCODERS = {}
+
+
+def register_vocoder(cls):
+ VOCODERS[cls.__name__.lower()] = cls
+ VOCODERS[cls.__name__] = cls
+ return cls
+
+
+def get_vocoder_cls(hparams):
+ if hparams['vocoder'] in VOCODERS:
+ return VOCODERS[hparams['vocoder']]
+ else:
+ vocoder_cls = hparams['vocoder']
+ pkg = ".".join(vocoder_cls.split(".")[:-1])
+ cls_name = vocoder_cls.split(".")[-1]
+ vocoder_cls = getattr(importlib.import_module(pkg), cls_name)
+ return vocoder_cls
+
+
+class BaseVocoder:
+ def spec2wav(self, mel):
+ """
+
+ :param mel: [T, 80]
+ :return: wav: [T']
+ """
+
+ raise NotImplementedError
+
+ @staticmethod
+ def wav2spec(wav_fn):
+ """
+
+ :param wav_fn: str
+ :return: wav, mel: [T, 80]
+ """
+ raise NotImplementedError
diff --git a/network/vocoders/hifigan.py b/network/vocoders/hifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..8838b9316e9c5608dba799381c52276710136d4a
--- /dev/null
+++ b/network/vocoders/hifigan.py
@@ -0,0 +1,83 @@
+import glob
+import json
+import os
+import re
+
+import librosa
+import torch
+
+import utils
+from modules.hifigan.hifigan import HifiGanGenerator
+from utils.hparams import hparams, set_hparams
+from network.vocoders.base_vocoder import register_vocoder
+from network.vocoders.pwg import PWG
+from network.vocoders.vocoder_utils import denoise
+
+
+def load_model(config_path, file_path):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ ext = os.path.splitext(file_path)[-1]
+ if ext == '.pth':
+ if '.yaml' in config_path:
+ config = set_hparams(config_path, global_hparams=False)
+ elif '.json' in config_path:
+ config = json.load(open(config_path, 'r', encoding='utf-8'))
+ model = torch.load(file_path, map_location="cpu")
+ elif ext == '.ckpt':
+ ckpt_dict = torch.load(file_path, map_location="cpu")
+ if '.yaml' in config_path:
+ config = set_hparams(config_path, global_hparams=False)
+ state = ckpt_dict["state_dict"]["model_gen"]
+ elif '.json' in config_path:
+ config = json.load(open(config_path, 'r', encoding='utf-8'))
+ state = ckpt_dict["generator"]
+ model = HifiGanGenerator(config)
+ model.load_state_dict(state, strict=True)
+ model.remove_weight_norm()
+ model = model.eval().to(device)
+ print(f"| Loaded model parameters from {file_path}.")
+ print(f"| HifiGAN device: {device}.")
+ return model, config, device
+
+
+# Module-level counter initialised to zero; nothing else in this file reads or updates it.
+total_time = 0
+
+
+@register_vocoder
+class HifiGAN(PWG):
+ def __init__(self):
+ base_dir = hparams['vocoder_ckpt']
+ config_path = f'{base_dir}/config.yaml'
+ if os.path.exists(config_path):
+ file_path = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.*'), key=
+ lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).*', x.replace('\\','/'))[0]))[-1]
+ print('| load HifiGAN: ', file_path)
+ self.model, self.config, self.device = load_model(config_path=config_path, file_path=file_path)
+ else:
+ config_path = f'{base_dir}/config.json'
+ ckpt = f'{base_dir}/generator_v1'
+ if os.path.exists(config_path):
+ self.model, self.config, self.device = load_model(config_path=config_path, file_path=file_path)
+
+ def spec2wav(self, mel, **kwargs):
+ device = self.device
+ with torch.no_grad():
+ c = torch.FloatTensor(mel).unsqueeze(0).transpose(2, 1).to(device)
+ with utils.Timer('hifigan', print_time=hparams['profile_infer']):
+ f0 = kwargs.get('f0')
+ if f0 is not None and hparams.get('use_nsf'):
+ f0 = torch.FloatTensor(f0[None, :]).to(device)
+ y = self.model(c, f0).view(-1)
+ else:
+ y = self.model(c).view(-1)
+ wav_out = y.cpu().numpy()
+ if hparams.get('vocoder_denoise_c', 0.0) > 0:
+ wav_out = denoise(wav_out, v=hparams['vocoder_denoise_c'])
+ return wav_out
+
+ # @staticmethod
+ # def wav2spec(wav_fn, **kwargs):
+ # wav, _ = librosa.core.load(wav_fn, sr=hparams['audio_sample_rate'])
+ # wav_torch = torch.FloatTensor(wav)[None, :]
+ # mel = mel_spectrogram(wav_torch, hparams).numpy()[0]
+ # return wav, mel.T
diff --git a/network/vocoders/nsf_hifigan.py b/network/vocoders/nsf_hifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..93975546a7acff64279b3fc84b4edd0a7d292714
--- /dev/null
+++ b/network/vocoders/nsf_hifigan.py
@@ -0,0 +1,92 @@
+import os
+import torch
+from modules.nsf_hifigan.models import load_model, Generator
+from modules.nsf_hifigan.nvSTFT import load_wav_to_torch, STFT
+from utils.hparams import hparams
+from network.vocoders.base_vocoder import BaseVocoder, register_vocoder
+
+@register_vocoder
+class NsfHifiGAN(BaseVocoder):
+ def __init__(self, device=None):
+ if device is None:
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ self.device = device
+ model_path = hparams['vocoder_ckpt']
+ if os.path.exists(model_path):
+ print('| Load HifiGAN: ', model_path)
+ self.model, self.h = load_model(model_path, device=self.device)
+ else:
+ print('Error: HifiGAN model file is not found!')
+
+ def spec2wav_torch(self, mel, **kwargs): # mel: [B, T, bins]
+ if self.h.sampling_rate != hparams['audio_sample_rate']:
+ print('Mismatch parameters: hparams[\'audio_sample_rate\']=',hparams['audio_sample_rate'],'!=',self.h.sampling_rate,'(vocoder)')
+ if self.h.num_mels != hparams['audio_num_mel_bins']:
+ print('Mismatch parameters: hparams[\'audio_num_mel_bins\']=',hparams['audio_num_mel_bins'],'!=',self.h.num_mels,'(vocoder)')
+ if self.h.n_fft != hparams['fft_size']:
+ print('Mismatch parameters: hparams[\'fft_size\']=',hparams['fft_size'],'!=',self.h.n_fft,'(vocoder)')
+ if self.h.win_size != hparams['win_size']:
+ print('Mismatch parameters: hparams[\'win_size\']=',hparams['win_size'],'!=',self.h.win_size,'(vocoder)')
+ if self.h.hop_size != hparams['hop_size']:
+ print('Mismatch parameters: hparams[\'hop_size\']=',hparams['hop_size'],'!=',self.h.hop_size,'(vocoder)')
+ if self.h.fmin != hparams['fmin']:
+ print('Mismatch parameters: hparams[\'fmin\']=',hparams['fmin'],'!=',self.h.fmin,'(vocoder)')
+ if self.h.fmax != hparams['fmax']:
+ print('Mismatch parameters: hparams[\'fmax\']=',hparams['fmax'],'!=',self.h.fmax,'(vocoder)')
+ with torch.no_grad():
+ c = mel.transpose(2, 1) #[B, T, bins]
+ #log10 to log mel
+ c = 2.30259 * c
+ f0 = kwargs.get('f0') #[B, T]
+ if f0 is not None and hparams.get('use_nsf'):
+ y = self.model(c, f0).view(-1)
+ else:
+ y = self.model(c).view(-1)
+ return y
+
+ def spec2wav(self, mel, **kwargs):
+ if self.h.sampling_rate != hparams['audio_sample_rate']:
+ print('Mismatch parameters: hparams[\'audio_sample_rate\']=',hparams['audio_sample_rate'],'!=',self.h.sampling_rate,'(vocoder)')
+ if self.h.num_mels != hparams['audio_num_mel_bins']:
+ print('Mismatch parameters: hparams[\'audio_num_mel_bins\']=',hparams['audio_num_mel_bins'],'!=',self.h.num_mels,'(vocoder)')
+ if self.h.n_fft != hparams['fft_size']:
+ print('Mismatch parameters: hparams[\'fft_size\']=',hparams['fft_size'],'!=',self.h.n_fft,'(vocoder)')
+ if self.h.win_size != hparams['win_size']:
+ print('Mismatch parameters: hparams[\'win_size\']=',hparams['win_size'],'!=',self.h.win_size,'(vocoder)')
+ if self.h.hop_size != hparams['hop_size']:
+ print('Mismatch parameters: hparams[\'hop_size\']=',hparams['hop_size'],'!=',self.h.hop_size,'(vocoder)')
+ if self.h.fmin != hparams['fmin']:
+ print('Mismatch parameters: hparams[\'fmin\']=',hparams['fmin'],'!=',self.h.fmin,'(vocoder)')
+ if self.h.fmax != hparams['fmax']:
+ print('Mismatch parameters: hparams[\'fmax\']=',hparams['fmax'],'!=',self.h.fmax,'(vocoder)')
+ with torch.no_grad():
+ c = torch.FloatTensor(mel).unsqueeze(0).transpose(2, 1).to(self.device)
+ #log10 to log mel
+ c = 2.30259 * c
+ f0 = kwargs.get('f0')
+ if f0 is not None and hparams.get('use_nsf'):
+ f0 = torch.FloatTensor(f0[None, :]).to(self.device)
+ y = self.model(c, f0).view(-1)
+ else:
+ y = self.model(c).view(-1)
+ wav_out = y.cpu().numpy()
+ return wav_out
+
+ @staticmethod
+ def wav2spec(inp_path, device=None):
+ if device is None:
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ sampling_rate = hparams['audio_sample_rate']
+ num_mels = hparams['audio_num_mel_bins']
+ n_fft = hparams['fft_size']
+ win_size =hparams['win_size']
+ hop_size = hparams['hop_size']
+ fmin = hparams['fmin']
+ fmax = hparams['fmax']
+ stft = STFT(sampling_rate, num_mels, n_fft, win_size, hop_size, fmin, fmax)
+ with torch.no_grad():
+ wav_torch, _ = load_wav_to_torch(inp_path, target_sr=stft.target_sr)
+ mel_torch = stft.get_mel(wav_torch.unsqueeze(0).to(device)).squeeze(0).T
+ #log mel to log10 mel
+ mel_torch = 0.434294 * mel_torch
+ return wav_torch.cpu().numpy(), mel_torch.cpu().numpy()
\ No newline at end of file
diff --git a/network/vocoders/pwg.py b/network/vocoders/pwg.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf2de16f271b66c308c604e52e9ab89242d5663e
--- /dev/null
+++ b/network/vocoders/pwg.py
@@ -0,0 +1,137 @@
+import glob
+import re
+import librosa
+import torch
+import yaml
+from sklearn.preprocessing import StandardScaler
+from torch import nn
+from modules.parallel_wavegan.models import ParallelWaveGANGenerator
+from modules.parallel_wavegan.utils import read_hdf5
+from utils.hparams import hparams
+from utils.pitch_utils import f0_to_coarse
+from network.vocoders.base_vocoder import BaseVocoder, register_vocoder
+import numpy as np
+
+
+def load_pwg_model(config_path, checkpoint_path, stats_path):
+ # load config
+ with open(config_path, encoding='utf-8') as f:
+ config = yaml.load(f, Loader=yaml.Loader)
+
+ # setup
+ if torch.cuda.is_available():
+ device = torch.device("cuda")
+ else:
+ device = torch.device("cpu")
+ model = ParallelWaveGANGenerator(**config["generator_params"])
+
+ ckpt_dict = torch.load(checkpoint_path, map_location="cpu")
+ if 'state_dict' not in ckpt_dict: # official vocoder
+ model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"]["generator"])
+ scaler = StandardScaler()
+ if config["format"] == "hdf5":
+ scaler.mean_ = read_hdf5(stats_path, "mean")
+ scaler.scale_ = read_hdf5(stats_path, "scale")
+ elif config["format"] == "npy":
+ scaler.mean_ = np.load(stats_path)[0]
+ scaler.scale_ = np.load(stats_path)[1]
+ else:
+ raise ValueError("support only hdf5 or npy format.")
+ else: # custom PWG vocoder
+ fake_task = nn.Module()
+ fake_task.model_gen = model
+ fake_task.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["state_dict"], strict=False)
+ scaler = None
+
+ model.remove_weight_norm()
+ model = model.eval().to(device)
+ print(f"| Loaded model parameters from {checkpoint_path}.")
+ print(f"| PWG device: {device}.")
+ return model, scaler, config, device
+
+
+@register_vocoder
+class PWG(BaseVocoder):
+ def __init__(self):
+ if hparams['vocoder_ckpt'] == '': # load LJSpeech PWG pretrained model
+ base_dir = 'wavegan_pretrained'
+ ckpts = glob.glob(f'{base_dir}/checkpoint-*steps.pkl')
+ ckpt = sorted(ckpts, key=
+ lambda x: int(re.findall(f'{base_dir}/checkpoint-(\d+)steps.pkl', x)[0]))[-1]
+ config_path = f'{base_dir}/config.yaml'
+ print('| load PWG: ', ckpt)
+ self.model, self.scaler, self.config, self.device = load_pwg_model(
+ config_path=config_path,
+ checkpoint_path=ckpt,
+ stats_path=f'{base_dir}/stats.h5',
+ )
+ else:
+ base_dir = hparams['vocoder_ckpt']
+ print(base_dir)
+ config_path = f'{base_dir}/config.yaml'
+ ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key=
+ lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
+ print('| load PWG: ', ckpt)
+ self.scaler = None
+ self.model, _, self.config, self.device = load_pwg_model(
+ config_path=config_path,
+ checkpoint_path=ckpt,
+ stats_path=f'{base_dir}/stats.h5',
+ )
+
+ def spec2wav(self, mel, **kwargs):
+ # start generation
+ config = self.config
+ device = self.device
+ pad_size = (config["generator_params"]["aux_context_window"],
+ config["generator_params"]["aux_context_window"])
+ c = mel
+ if self.scaler is not None:
+ c = self.scaler.transform(c)
+
+ with torch.no_grad():
+ z = torch.randn(1, 1, c.shape[0] * config["hop_size"]).to(device)
+ c = np.pad(c, (pad_size, (0, 0)), "edge")
+ c = torch.FloatTensor(c).unsqueeze(0).transpose(2, 1).to(device)
+ p = kwargs.get('f0')
+ if p is not None:
+ p = f0_to_coarse(p)
+ p = np.pad(p, (pad_size,), "edge")
+ p = torch.LongTensor(p[None, :]).to(device)
+ y = self.model(z, c, p).view(-1)
+ wav_out = y.cpu().numpy()
+ return wav_out
+
+ @staticmethod
+ def wav2spec(wav_fn, return_linear=False):
+ from preprocessing.data_gen_utils import process_utterance
+ res = process_utterance(
+ wav_fn, fft_size=hparams['fft_size'],
+ hop_size=hparams['hop_size'],
+ win_length=hparams['win_size'],
+ num_mels=hparams['audio_num_mel_bins'],
+ fmin=hparams['fmin'],
+ fmax=hparams['fmax'],
+ sample_rate=hparams['audio_sample_rate'],
+ loud_norm=hparams['loud_norm'],
+ min_level_db=hparams['min_level_db'],
+ return_linear=return_linear, vocoder='pwg', eps=float(hparams.get('wav2spec_eps', 1e-10)))
+ if return_linear:
+ return res[0], res[1].T, res[2].T # [T, 80], [T, n_fft]
+ else:
+ return res[0], res[1].T
+
+ @staticmethod
+ def wav2mfcc(wav_fn):
+ fft_size = hparams['fft_size']
+ hop_size = hparams['hop_size']
+ win_length = hparams['win_size']
+ sample_rate = hparams['audio_sample_rate']
+ wav, _ = librosa.core.load(wav_fn, sr=sample_rate)
+ mfcc = librosa.feature.mfcc(y=wav, sr=sample_rate, n_mfcc=13,
+ n_fft=fft_size, hop_length=hop_size,
+ win_length=win_length, pad_mode="constant", power=1.0)
+ mfcc_delta = librosa.feature.delta(mfcc, order=1)
+ mfcc_delta_delta = librosa.feature.delta(mfcc, order=2)
+ mfcc = np.concatenate([mfcc, mfcc_delta, mfcc_delta_delta]).T
+ return mfcc
diff --git a/network/vocoders/vocoder_utils.py b/network/vocoders/vocoder_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..db5d5ca1765928e4b047db04435a8a39b52592ca
--- /dev/null
+++ b/network/vocoders/vocoder_utils.py
@@ -0,0 +1,15 @@
+import librosa
+
+from utils.hparams import hparams
+import numpy as np
+
+
+def denoise(wav, v=0.1):
+ spec = librosa.stft(y=wav, n_fft=hparams['fft_size'], hop_length=hparams['hop_size'],
+ win_length=hparams['win_size'], pad_mode='constant')
+ spec_m = np.abs(spec)
+ spec_m = np.clip(spec_m - v, a_min=0, a_max=None)
+ spec_a = np.angle(spec)
+
+ return librosa.istft(spec_m * np.exp(1j * spec_a), hop_length=hparams['hop_size'],
+ win_length=hparams['win_size'])
diff --git a/packages.txt b/packages.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3f9e57bd15883069dcd71837fe76a4171b50ad77
--- /dev/null
+++ b/packages.txt
@@ -0,0 +1 @@
+libsndfile1-dev
\ No newline at end of file
diff --git a/preprocessing/SVCpre.py b/preprocessing/SVCpre.py
new file mode 100644
index 0000000000000000000000000000000000000000..2faa0737fb5d61f6bdb4ac1fb959711c50311d0e
--- /dev/null
+++ b/preprocessing/SVCpre.py
@@ -0,0 +1,63 @@
+'''
+
+ item: one piece of data
+ item_name: data id
+ wavfn: wave file path
+ txt: lyrics
+ ph: phoneme
+ tgfn: text grid file path (unused)
+ spk: dataset name
+ wdb: word boundary
+ ph_durs: phoneme durations
+ midi: pitch as midi notes
+ midi_dur: midi duration
+ is_slur: keep singing upon note changes
+'''
+
+
+from copy import deepcopy
+
+import logging
+
+from preprocessing.process_pipeline import File2Batch
+from utils.hparams import hparams
+from preprocessing.base_binarizer import BaseBinarizer
+
+SVCSINGING_ITEM_ATTRIBUTES = ['wav_fn', 'spk_id']
+class SVCBinarizer(BaseBinarizer):
+ def __init__(self, item_attributes=SVCSINGING_ITEM_ATTRIBUTES):
+ super().__init__(item_attributes)
+ print('spkers: ', set(item['spk_id'] for item in self.items.values()))
+ self.item_names = sorted(list(self.items.keys()))
+ self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)
+ # self._valid_item_names=[]
+
+ def split_train_test_set(self, item_names):
+ item_names = deepcopy(item_names)
+ if hparams['choose_test_manually']:
+ test_item_names = [x for x in item_names if any([x.startswith(ts) for ts in hparams['test_prefixes']])]
+ else:
+ test_item_names = item_names[-5:]
+ train_item_names = [x for x in item_names if x not in set(test_item_names)]
+ logging.info("train {}".format(len(train_item_names)))
+ logging.info("test {}".format(len(test_item_names)))
+ return train_item_names, test_item_names
+
+ @property
+ def train_item_names(self):
+ return self._train_item_names
+
+ @property
+ def valid_item_names(self):
+ return self._test_item_names
+
+ @property
+ def test_item_names(self):
+ return self._test_item_names
+
+ def load_meta_data(self):
+ self.items = File2Batch.file2temporary_dict()
+
+ def _phone_encoder(self):
+ from preprocessing.hubertinfer import Hubertencoder
+ return Hubertencoder(hparams['hubert_path'])
\ No newline at end of file
diff --git a/preprocessing/__pycache__/SVCpre.cpython-38.pyc b/preprocessing/__pycache__/SVCpre.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..19cea41fd198bdd01d2234f1c6ad698aaa29ec7f
Binary files /dev/null and b/preprocessing/__pycache__/SVCpre.cpython-38.pyc differ
diff --git a/preprocessing/__pycache__/base_binarizer.cpython-38.pyc b/preprocessing/__pycache__/base_binarizer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d6de402f6ee5607ab3388fb88f5ea13271ce4cbb
Binary files /dev/null and b/preprocessing/__pycache__/base_binarizer.cpython-38.pyc differ
diff --git a/preprocessing/__pycache__/data_gen_utils.cpython-38.pyc b/preprocessing/__pycache__/data_gen_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..43e5877963f70434da67019bfd6eb3be1783592d
Binary files /dev/null and b/preprocessing/__pycache__/data_gen_utils.cpython-38.pyc differ
diff --git a/preprocessing/__pycache__/hubertinfer.cpython-38.pyc b/preprocessing/__pycache__/hubertinfer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0af487a6e5a0b61b9a9ac2954201302891054d4e
Binary files /dev/null and b/preprocessing/__pycache__/hubertinfer.cpython-38.pyc differ
diff --git a/preprocessing/__pycache__/process_pipeline.cpython-38.pyc b/preprocessing/__pycache__/process_pipeline.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ac59fa750d10bcaeda1676bdda471a427a6242a8
Binary files /dev/null and b/preprocessing/__pycache__/process_pipeline.cpython-38.pyc differ
diff --git a/preprocessing/base_binarizer.py b/preprocessing/base_binarizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f070584c60c091c3fa4bc1da377733698b0165b6
--- /dev/null
+++ b/preprocessing/base_binarizer.py
@@ -0,0 +1,237 @@
+import os
+from webbrowser import get
+os.environ["OMP_NUM_THREADS"] = "1"
+import yaml
+from utils.multiprocess_utils import chunked_multiprocess_run
+import random
+import json
+# from resemblyzer import VoiceEncoder
+from tqdm import tqdm
+from preprocessing.data_gen_utils import get_mel2ph, get_pitch_parselmouth, build_phone_encoder,get_pitch_crepe
+from utils.hparams import set_hparams, hparams
+import numpy as np
+from utils.indexed_datasets import IndexedDatasetBuilder
+
+
+class BinarizationError(Exception):
+    '''Raised when a single item cannot be binarized (e.g. missing alignment, empty f0, NaN CWT).'''
+
+BASE_ITEM_ATTRIBUTES = ['txt', 'ph', 'wav_fn', 'tg_fn', 'spk_id']
+
+class BaseBinarizer:
+ '''
+ Base class for data processing.
+ 1. *process* and *process_data_split*:
+ process entire data, generate the train-test split (support parallel processing);
+ 2. *process_item*:
+ process singe piece of data;
+ 3. *get_pitch*:
+ infer the pitch using some algorithm;
+ 4. *get_align*:
+ get the alignment using 'mel2ph' format (see https://arxiv.org/abs/1905.09263).
+ 5. phoneme encoder, voice encoder, etc.
+
+ Subclasses should define:
+ 1. *load_metadata*:
+ how to read multiple datasets from files;
+ 2. *train_item_names*, *valid_item_names*, *test_item_names*:
+ how to split the dataset;
+ 3. load_ph_set:
+ the phoneme set.
+ '''
+    def __init__(self, item_attributes=BASE_ITEM_ATTRIBUTES):
+        '''
+            Load the metadata and prepare the (optionally shuffled) list of item names.
+
+            :param item_attributes: names an item dict is allowed to use as keys.
+            :raises BinarizationError: if load_meta_data() left self.items empty.
+        '''
+        self.binarization_args = hparams['binarization_args']
+        # self.pre_align_args = hparams['pre_align_args']
+
+        self.items = {}
+        # every item in self.items has some attributes
+        self.item_attributes = item_attributes
+
+        self.load_meta_data()
+        if not self.items:
+            # Previously this surfaced as a bare IndexError from `list(...)[0]`.
+            raise BinarizationError("No items were loaded by load_meta_data()")
+        # Sanity check: the keys of an item dict may only come from the given attribute list.
+        # Only the first item is inspected, as before.
+        first_item = next(iter(self.items.values()))
+        assert all(attr in self.item_attributes for attr in first_item.keys())
+        self.item_names = sorted(self.items.keys())
+
+        if self.binarization_args['shuffle']:
+            # Fixed seed keeps the shuffled order reproducible between runs.
+            random.seed(1234)
+            random.shuffle(self.item_names)
+
+        # set default get_pitch algorithm
+        if hparams['use_crepe']:
+            self.get_pitch_algorithm = get_pitch_crepe
+        else:
+            self.get_pitch_algorithm = get_pitch_parselmouth
+
+    def load_meta_data(self):
+        '''Hook called by __init__, which reads self.items afterwards; subclasses must override.'''
+        raise NotImplementedError
+
+    @property
+    def train_item_names(self):
+        '''Item names of the training split; subclasses must override.'''
+        raise NotImplementedError
+
+    @property
+    def valid_item_names(self):
+        '''Item names of the validation split; subclasses must override.'''
+        raise NotImplementedError
+
+    @property
+    def test_item_names(self):
+        '''Item names of the test split; subclasses must override.'''
+        raise NotImplementedError
+
+    def build_spk_map(self):
+        '''
+            Map every distinct 'spk_id' found in self.items to an index, assigned in sorted order.
+            Asserts that the number of speakers does not exceed hparams['num_spk'].
+        '''
+        speakers = {self.items[name]['spk_id'] for name in self.item_names}
+        spk_map = {spk: idx for idx, spk in enumerate(sorted(speakers))}
+        n_spk = len(spk_map)
+        assert n_spk == 0 or n_spk <= hparams['num_spk'], n_spk
+        return spk_map
+
+    def item_name2spk_id(self, item_name):
+        '''Look up the index that self.spk_map assigns to the speaker of `item_name`.'''
+        speaker = self.items[item_name]['spk_id']
+        return self.spk_map[speaker]
+
+    def _phone_encoder(self):
+        '''
+            Build the encoder that process() stores in self.phone_encoder (a hubert encoder).
+            Subclasses must override this.
+
+            The phone_set.json based implementation that used to follow the `raise`
+            could never execute and has been removed.
+        '''
+        raise NotImplementedError
+
+
+
+    def load_ph_set(self, ph_set):
+        '''Hook for providing the phoneme set (see the class docstring); subclasses must override.'''
+        raise NotImplementedError
+
+    def meta_data_iterator(self, prefix):
+        '''
+            Yield (item_name, meta_data) pairs for one split.
+            'valid' and 'test' select those splits; any other prefix selects the training split.
+        '''
+        if prefix not in ('valid', 'test'):
+            selected = self.train_item_names
+        elif prefix == 'valid':
+            selected = self.valid_item_names
+        else:
+            selected = self.test_item_names
+        for name in selected:
+            yield name, self.items[name]
+
+    def process(self):
+        '''
+            Run the whole binarization: create hparams['binary_data_dir'], write spk_map.json,
+            build the phone encoder, then process the 'valid', 'test' and 'train' splits in that order.
+        '''
+        os.makedirs(hparams['binary_data_dir'], exist_ok=True)
+        self.spk_map = self.build_spk_map()
+        print("| spk_map: ", self.spk_map)
+        spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
+        # Context manager so the file is flushed and closed deterministically.
+        with open(spk_map_fn, 'w', encoding='utf-8') as f:
+            json.dump(self.spk_map, f)
+
+        self.phone_encoder = self._phone_encoder()
+        for split in ('valid', 'test', 'train'):
+            self.process_data_split(split)
+
+    def process_data_split(self, prefix):
+        '''
+            Binarize one split ('valid', 'test' or 'train') into hparams['binary_data_dir'].
+
+            Every item is run through process_item(); items for which it returns None are skipped.
+            Writes the indexed dataset and '{prefix}_lengths.npy'. For the 'train' split the
+            element-wise min/max of the per-item 'spec_min'/'spec_max' are also written back
+            into the YAML file at hparams['config_path'] (the file is rewritten in place).
+
+            The former f0 mean/std export was dropped: its input list was never filled,
+            so that branch could not run.
+        '''
+        data_dir = hparams['binary_data_dir']
+        builder = IndexedDatasetBuilder(f'{data_dir}/{prefix}')
+        lengths = []
+        total_sec = 0
+        # if self.binarization_args['with_spk_embed']:
+        #     voice_encoder = VoiceEncoder().cuda()
+
+        args = [[item_name, meta_data, self.binarization_args]
+                for item_name, meta_data in self.meta_data_iterator(prefix)]
+        spec_min = []
+        spec_max = []
+        # code for single cpu processing; items are visited last-to-first, as before
+        for a in tqdm(reversed(args), total=len(args)):
+            item = self.process_item(*a)
+            if item is None:
+                continue
+            spec_min.append(item['spec_min'])
+            spec_max.append(item['spec_max'])
+            # item['spk_embe'] = voice_encoder.embed_utterance(item['wav']) \
+            #     if self.binardization_args['with_spk_embed'] else None
+            if not self.binarization_args['with_wav'] and 'wav' in item:
+                if hparams['debug']:
+                    print("del wav")
+                del item['wav']
+            if hparams['debug']:
+                print(item)
+            builder.add_item(item)
+            lengths.append(item['len'])
+            total_sec += item['sec']
+        if prefix == 'train':
+            if len(spec_max) == 0:
+                # np.max/np.min raise ValueError on an empty list; keep the config untouched instead.
+                print("| Warning: no train item was processed, spec_min/spec_max are left unchanged.")
+            else:
+                spec_max = np.max(spec_max, 0)
+                spec_min = np.min(spec_min, 0)
+                print(spec_max.shape)
+                with open(hparams['config_path'], encoding='utf-8') as f:
+                    _hparams = yaml.safe_load(f)
+                _hparams['spec_max'] = spec_max.tolist()
+                _hparams['spec_min'] = spec_min.tolist()
+                with open(hparams['config_path'], 'w', encoding='utf-8') as f:
+                    yaml.safe_dump(_hparams, f)
+        builder.finalize()
+        np.save(f'{data_dir}/{prefix}_lengths.npy', lengths)
+        print(f"| {prefix} total duration: {total_sec:.3f}s")
+
+    def process_item(self, item_name, meta_data, binarization_args):
+        '''Turn one metadata entry into a processed item by delegating to File2Batch.'''
+        # Imported inside the method, as before.
+        from preprocessing.process_pipeline import File2Batch
+        to_processed_input = File2Batch.temporary_dict2processed_input
+        return to_processed_input(item_name, meta_data, self.phone_encoder, binarization_args)
+
+    def get_align(self, meta_data, mel, phone_encoded, res):
+        '''Alignment hook (see the class docstring); not implemented in the base class.'''
+        raise NotImplementedError
+
+    def get_align_from_textgrid(self, meta_data, mel, phone_encoded, res):
+        '''
+            Disabled: does nothing and returns None; `res` is left untouched.
+
+            NOTE: this part of script is *isolated* from other scripts, which means
+            it may not be compatible with the current version.
+            The TextGrid implementation that used to follow the unconditional `return`
+            (get_mel2ph on meta_data['tg_fn'] / meta_data['ph'], filling res['mel2ph'] and
+            res['dur']) was unreachable and has been removed.
+        '''
+        return None
+
+    def get_f0cwt(self, f0, res):
+        '''
+            Disabled: does nothing and returns None; `res` is left untouched.
+
+            NOTE: this part of script is *isolated* from other scripts, which means
+            it may not be compatible with the current version.
+            The CWT implementation that used to follow the unconditional `return`
+            (utils.cwt based, filling res['cwt_spec'], res['cwt_scales'], res['f0_mean'] and
+            res['f0_std']) was unreachable and has been removed.
+        '''
+        return None
+
+
+if __name__ == "__main__":
+ set_hparams()
+ BaseBinarizer().process()
diff --git a/preprocessing/binarize.py b/preprocessing/binarize.py
new file mode 100644
index 0000000000000000000000000000000000000000..df3bff078132d5c1e031af449855fe9c2ba998a1
--- /dev/null
+++ b/preprocessing/binarize.py
@@ -0,0 +1,20 @@
+import os
+
+os.environ["OMP_NUM_THREADS"] = "1"
+
+import importlib
+from utils.hparams import set_hparams, hparams
+
+
+def binarize():
+    """
+    Import the binarizer class named by hparams['binarizer_cls'] (a dotted
+    'package.module.ClassName' path), instantiate it and run its process() method.
+    """
+    dotted_path = hparams.get("binarizer_cls", 'basics.base_binarizer.BaseBinarizer')
+    # Everything before the last dot is the module path, the rest is the class name.
+    pkg, _, cls_name = dotted_path.rpartition(".")
+    module = importlib.import_module(pkg)
+    binarizer_cls = getattr(module, cls_name)
+    print("| Binarizer: ", binarizer_cls)
+    binarizer_cls().process()
+
+
+if __name__ == '__main__':
+ set_hparams()
+ binarize()
diff --git a/preprocessing/data_gen_utils.py b/preprocessing/data_gen_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..68b9f981144329698da77a463222a73618e6007b
--- /dev/null
+++ b/preprocessing/data_gen_utils.py
@@ -0,0 +1,426 @@
+from io import BytesIO
+import json
+import os
+import re
+import struct
+import warnings
+from collections import OrderedDict
+
+import librosa
+import numpy as np
+import parselmouth
+import pyloudnorm as pyln
+import resampy
+import torch
+import torchcrepe
+import webrtcvad
+from scipy.ndimage.morphology import binary_dilation
+from skimage.transform import resize
+import pyworld as world
+
+from utils import audio
+from utils.pitch_utils import f0_to_coarse
+from utils.text_encoder import TokenTextEncoder
+
+warnings.filterwarnings("ignore")
+PUNCS = '!,.?;:'
+
+int16_max = (2 ** 15) - 1
+
+
+def trim_long_silences(path, sr=None, return_raw_wav=False, norm=True, vad_max_silence_length=12):
+    """
+    Load an audio file and mask out long silences found by WebRTC voice activity detection.
+
+    :param path: audio file (anything librosa.core.load accepts)
+    :param sr: sample rate passed to librosa.core.load
+    :param return_raw_wav: if True, return the untrimmed waveform together with the mask
+    :param norm: if True, loudness-normalize to -20.0 with pyloudnorm and peak-limit to 1.0
+    :param vad_max_silence_length: Maximum number of consecutive silent frames a segment can have.
+    :return: (wav, audio_mask, sr); wav is wav_raw[audio_mask] unless return_raw_wav is set
+    """
+
+    ## Voice Activation Detection
+    # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
+    # This sets the granularity of the VAD. Should not need to be changed.
+    sampling_rate = 16000
+    wav_raw, sr = librosa.core.load(path, sr=sr)
+
+    if norm:
+        meter = pyln.Meter(sr)  # create BS.1770 meter
+        loudness = meter.integrated_loudness(wav_raw)
+        wav_raw = pyln.normalize.loudness(wav_raw, loudness, -20.0)
+        if np.abs(wav_raw).max() > 1.0:
+            wav_raw = wav_raw / np.abs(wav_raw).max()
+
+    # Keyword arguments: librosa >= 0.10 no longer accepts the sample rates positionally.
+    wav = librosa.resample(wav_raw, orig_sr=sr, target_sr=sampling_rate, res_type='kaiser_best')
+
+    vad_window_length = 30  # In milliseconds
+    # Number of frames to average together when performing the moving average smoothing.
+    # The larger this value, the larger the VAD variations must be to not get smoothed out.
+    vad_moving_average_width = 8
+
+    # Compute the voice detection window size
+    samples_per_window = (vad_window_length * sampling_rate) // 1000
+
+    # Trim the end of the audio to have a multiple of the window size
+    wav = wav[:len(wav) - (len(wav) % samples_per_window)]
+
+    # Convert the float waveform to 16-bit mono PCM
+    pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
+
+    # Perform voice activation detection
+    voice_flags = []
+    vad = webrtcvad.Vad(mode=3)
+    for window_start in range(0, len(wav), samples_per_window):
+        window_end = window_start + samples_per_window
+        # Two bytes per 16-bit sample, hence the "* 2" on the byte offsets.
+        voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
+                                         sample_rate=sampling_rate))
+    voice_flags = np.array(voice_flags)
+
+    # Smooth the voice detection with a moving average
+    def moving_average(array, width):
+        array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
+        ret = np.cumsum(array_padded, dtype=float)
+        ret[width:] = ret[width:] - ret[:-width]
+        return ret[width - 1:] / width
+
+    audio_mask = moving_average(voice_flags, vad_moving_average_width)
+    # The `np.bool` alias was removed in NumPy 1.24; the builtin is the same dtype.
+    audio_mask = np.round(audio_mask).astype(bool)
+
+    # Dilate the voiced regions
+    audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
+    audio_mask = np.repeat(audio_mask, samples_per_window)
+    # Stretch the 16 kHz mask to the length of the waveform at its original rate.
+    audio_mask = resize(audio_mask, (len(wav_raw),)) > 0
+    if return_raw_wav:
+        return wav_raw, audio_mask, sr
+    return wav_raw[audio_mask], audio_mask, sr
+
+
+def process_utterance(wav_path,
+                      fft_size=1024,
+                      hop_size=256,
+                      win_length=1024,
+                      window="hann",
+                      num_mels=80,
+                      fmin=80,
+                      fmax=7600,
+                      eps=1e-6,
+                      sample_rate=22050,
+                      loud_norm=False,
+                      min_level_db=-100,
+                      return_linear=False,
+                      trim_long_sil=False, vocoder='pwg'):
+    """
+    Compute a log10 mel spectrogram (and optionally a normalized linear spectrogram).
+
+    :param wav_path: a str path or BytesIO to load, or an already loaded waveform
+    :param fmin: lower mel edge; -1 means 0
+    :param fmax: upper mel edge; -1 means sample_rate / 2
+    :param loud_norm: loudness-normalize to -22.0 with pyloudnorm and peak-limit to 1
+    :param return_linear: additionally return the dB-scaled, normalized linear spectrogram
+    :param trim_long_sil: load through trim_long_silences instead of librosa.core.load
+    :param vocoder: only 'pwg' is accepted; anything else fails the assertion
+    :return: (wav, mel) or (wav, mel, spc); mel is (n_mel_bins, T)
+    """
+    if isinstance(wav_path, str) or isinstance(wav_path, BytesIO):
+        if trim_long_sil:
+            wav, _, _ = trim_long_silences(wav_path, sample_rate)
+        else:
+            wav, _ = librosa.core.load(wav_path, sr=sample_rate)
+    else:
+        wav = wav_path
+    if loud_norm:
+        meter = pyln.Meter(sample_rate)  # create BS.1770 meter
+        loudness = meter.integrated_loudness(wav)
+        wav = pyln.normalize.loudness(wav, loudness, -22.0)
+        if np.abs(wav).max() > 1:
+            wav = wav / np.abs(wav).max()
+
+    # get amplitude spectrogram
+    x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
+                          win_length=win_length, window=window, pad_mode="constant")
+    spc = np.abs(x_stft)  # (n_bins, T)
+
+    # get mel basis
+    fmin = 0 if fmin == -1 else fmin
+    fmax = sample_rate / 2 if fmax == -1 else fmax
+    # Keyword arguments: librosa >= 0.10 made every parameter of filters.mel keyword-only.
+    mel_basis = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=num_mels, fmin=fmin, fmax=fmax)
+    mel = mel_basis @ spc
+
+    if vocoder == 'pwg':
+        mel = np.log10(np.maximum(eps, mel))  # (n_mel_bins, T)
+    else:
+        assert False, f'"{vocoder}" is not in ["pwg"].'
+
+    # Pad, then cut the waveform to exactly one hop per mel frame.
+    l_pad, r_pad = audio.librosa_pad_lr(wav, fft_size, hop_size, 1)
+    wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
+    wav = wav[:mel.shape[1] * hop_size]
+
+    if not return_linear:
+        return wav, mel
+    else:
+        spc = audio.amp_to_db(spc)
+        spc = audio.normalize(spc, {'min_level_db': min_level_db})
+        return wav, mel, spc
+
+
+def get_pitch_parselmouth(wav_data, mel, hparams):
+    """
+    Kept for callers that still use the old name; the pitch is extracted by get_pitch_world.
+
+    :param wav_data: [T]
+    :param mel: [T, 80]
+    :param hparams:
+    :return: the result of get_pitch_world(wav_data, mel, hparams)
+    """
+    # The parselmouth `to_pitch_ac` implementation that lived here was already fully
+    # commented out; the WORLD based extractor replaces it.
+    result = get_pitch_world(wav_data, mel, hparams)
+    return result
+
+
+def get_pitch_world(wav_data, mel, hparams):
+    """
+    Extract f0 with pyworld's harvest and fit it to the number of mel frames.
+
+    :param wav_data: [T]
+    :param mel: [T, 80]
+    :param hparams: reads 'hop_size', 'audio_sample_rate', 'f0_min' and 'f0_max'
+    :return: (f0, pitch_coarse), f0 having len(mel) entries
+    """
+    # Frame period handed to harvest, in milliseconds.
+    time_step = 1000 * hparams['hop_size'] / hparams['audio_sample_rate']
+    f0_min = hparams['f0_min']
+    f0_max = hparams['f0_max']
+
+    f0, _ = world.harvest(wav_data.astype(np.double), hparams['audio_sample_rate'], f0_min, f0_max, time_step)
+
+    # Make f0 exactly as long as mel.
+    len_diff = len(mel) - len(f0)
+    if len_diff > 0:
+        # f0 is too short: zero-pad both ends, the extra frame (if any) goes in front.
+        pad_len = (len_diff + 1) // 2
+        f0 = np.pad(f0, [[pad_len, len_diff - pad_len]])
+    elif len_diff < 0:
+        # f0 is too long: drop frames from both ends, the extra frame (if any) from the front.
+        # The previous code sliced twice here and returned too few frames when f0 was
+        # two or more frames longer than mel.
+        lead = (1 - len_diff) // 2
+        f0 = f0[lead:lead + len(mel)]
+    pitch_coarse = f0_to_coarse(f0, hparams)
+    return f0, pitch_coarse
+
+
+def get_pitch_crepe(wav_data, mel, hparams, threshold=0.05):
+    """
+    Extract f0 with torchcrepe and interpolate it onto the mel frame times.
+
+    :param wav_data: waveform sampled at hparams['audio_sample_rate']
+    :param mel: only len(mel) is used, as the number of output frames
+    :param hparams: reads 'audio_sample_rate', 'hop_size', 'f0_min' and 'f0_max'
+    :param threshold: periodicity threshold passed to torchcrepe.threshold.At
+    :return: (f0, pitch_coarse)
+    """
+    # Use the GPU when there is one; a hard-coded "cuda" device fails on CPU-only hosts.
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # crepe only supports a 16 kHz sample rate, so resample first
+    wav16k = resampy.resample(wav_data, hparams['audio_sample_rate'], 16000)
+    wav16k_torch = torch.FloatTensor(wav16k).unsqueeze(0).to(device)
+
+    # frequency range
+    f0_min = hparams['f0_min']
+    f0_max = hparams['f0_max']
+
+    # after resampling, analyse f0 with hopsize=80, i.e. one frame every 5 ms
+    f0, pd = torchcrepe.predict(wav16k_torch, 16000, 80, f0_min, f0_max, pad=True, model='full', batch_size=1024,
+                                device=device, return_periodicity=True)
+
+    # filter, drop silence and apply the uv threshold (see the torchcrepe readme)
+    pd = torchcrepe.filter.median(pd, 3)
+    pd = torchcrepe.threshold.Silence(-60.)(pd, wav16k_torch, 16000, 80)
+    f0 = torchcrepe.threshold.At(threshold)(f0, pd)
+    f0 = torchcrepe.filter.mean(f0, 3)
+
+    # turn nan frequencies (the uv part) into zero frequencies
+    f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0)
+
+    # drop the zero frequencies and interpolate linearly
+    # squeeze(1) keeps the index 1-D even when exactly one frame is voiced; a bare
+    # squeeze() collapsed it to 0-d, which np.interp below cannot use as sample points.
+    nzindex = torch.nonzero(f0[0]).squeeze(1)
+    f0 = torch.index_select(f0[0], dim=0, index=nzindex).cpu().numpy()
+    time_org = 0.005 * nzindex.cpu().numpy()
+    time_frame = np.arange(len(mel)) * hparams['hop_size'] / hparams['audio_sample_rate']
+    if f0.shape[0] == 0:
+        # NOTE(review): this branch yields a torch tensor while the other yields a numpy
+        # array; kept as is because callers may rely on it - confirm.
+        f0 = torch.FloatTensor(time_frame.shape[0]).fill_(0)
+        print('f0 all zero!')
+    else:
+        f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
+    pitch_coarse = f0_to_coarse(f0, hparams)
+    return f0, pitch_coarse
+
+
+def remove_empty_lines(text):
+    """
+    Strip every line and drop all lines that are empty afterwards.
+
+    :param text: non-empty list of strings
+    :return: new list of stripped, non-empty strings
+    """
+    assert (len(text) > 0)
+    assert (isinstance(text, list))
+    # list.remove("") only deleted the first empty entry; filter them all out.
+    return [line for line in (t.strip() for t in text) if line != ""]
+
+
+class TextGrid(object):
+    """
+    Line-oriented parser for TextGrid files; only IntervalTier tiers are supported.
+    All extracted values are kept as the strings found in the file, except `size` (int).
+    """
+
+    def __init__(self, text):
+        """
+        :param text: list of the file's lines
+        :raises ValueError: if a line does not match the expected format or the file ends early
+        :raises NotImplementedError: if a tier is not an IntervalTier
+        """
+        text = remove_empty_lines(text)
+        self.text = text
+        self.line_count = 0
+        self._get_type()
+        self._get_time_intval()
+        self._get_size()
+        self.tier_list = []
+        self._get_item_list()
+
+    def _extract_pattern(self, pattern, inc):
+        """
+        Parameters
+        ----------
+        pattern : regex to extract pattern
+        inc : increment of line count after extraction
+        Returns
+        -------
+        group : extracted info
+        """
+        # A truncated file used to escape as a bare IndexError; report it as a format error too.
+        if self.line_count >= len(self.text):
+            raise ValueError("File format error: unexpected end of file at line %d" % self.line_count)
+        line = self.text[self.line_count]
+        matched = re.match(pattern, line)
+        if matched is None:
+            raise ValueError("File format error at line %d:%s" % (self.line_count, line))
+        self.line_count += inc
+        return matched.group(1)
+
+    def _get_type(self):
+        # inc=2 also skips the line that follows the file type line.
+        self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2)
+
+    def _get_time_intval(self):
+        self.xmin = self._extract_pattern(r"xmin = (.*)", 1)
+        self.xmax = self._extract_pattern(r"xmax = (.*)", 2)
+
+    def _get_size(self):
+        self.size = int(self._extract_pattern(r"size = (.*)", 2))
+
+    def _get_item_list(self):
+        """Only supports IntervalTier currently"""
+        for _ in range(self.size):
+            tier = OrderedDict()
+            item_list = []
+            tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1)
+            tier_class = self._extract_pattern(r"class = \"(.*)\"", 1)
+            if tier_class != "IntervalTier":
+                raise NotImplementedError("Only IntervalTier class is supported currently")
+            tier_name = self._extract_pattern(r"name = \"(.*)\"", 1)
+            tier_xmin = self._extract_pattern(r"xmin = (.*)", 1)
+            tier_xmax = self._extract_pattern(r"xmax = (.*)", 1)
+            tier_size = self._extract_pattern(r"intervals: size = (.*)", 1)
+            for _ in range(int(tier_size)):
+                item = OrderedDict()
+                item["idx"] = self._extract_pattern(r"intervals \[(.*)\]", 1)
+                item["xmin"] = self._extract_pattern(r"xmin = (.*)", 1)
+                item["xmax"] = self._extract_pattern(r"xmax = (.*)", 1)
+                item["text"] = self._extract_pattern(r"text = \"(.*)\"", 1)
+                item_list.append(item)
+            tier["idx"] = tier_idx
+            tier["class"] = tier_class
+            tier["name"] = tier_name
+            tier["xmin"] = tier_xmin
+            tier["xmax"] = tier_xmax
+            tier["size"] = tier_size
+            tier["items"] = item_list
+            self.tier_list.append(tier)
+
+    def toJson(self):
+        """Serialize the parsed header fields and tiers to an indented JSON string."""
+        _json = OrderedDict()
+        _json["file_type"] = self.file_type
+        _json["xmin"] = self.xmin
+        _json["xmax"] = self.xmax
+        _json["size"] = self.size
+        _json["tiers"] = self.tier_list
+        return json.dumps(_json, ensure_ascii=False, indent=2)
+
+
+def get_mel2ph(tg_fn, ph, mel, hparams):
+    """
+    Build the frame-to-phoneme alignment from a TextGrid file.
+
+    :param tg_fn: path of the TextGrid file; its last tier is used
+    :param ph: phonemes separated by single spaces
+    :param mel: only mel.shape[0] is used, as the number of frames
+    :param hparams: reads 'audio_sample_rate' and 'hop_size'
+    :return: (mel2ph, dur); mel2ph holds the 1-based phoneme index per frame,
+             dur the number of frames assigned to each phoneme
+    """
+    ph_list = ph.split(" ")
+    with open(tg_fn, "r", encoding='utf-8') as f:
+        tg = f.readlines()
+    tg = remove_empty_lines(tg)
+    tg = TextGrid(tg)
+    tg = json.loads(tg.toJson())
+    # split[i] is the start time of phoneme i; -1 marks "not assigned yet".
+    # The `np.float` / `np.int` aliases were removed in NumPy 1.24; the builtins are equivalent.
+    split = np.ones(len(ph_list) + 1, float) * -1
+    tg_idx = 0
+    ph_idx = 0
+    tg_align = [x for x in tg['tiers'][-1]['items']]
+    tg_align_ = []
+    # Normalize silence labels to '' and merge consecutive silent intervals.
+    for x in tg_align:
+        x['xmin'] = float(x['xmin'])
+        x['xmax'] = float(x['xmax'])
+        if x['text'] in ['sil', 'sp', '', 'SIL', 'PUNC']:
+            x['text'] = ''
+            if len(tg_align_) > 0 and tg_align_[-1]['text'] == '':
+                tg_align_[-1]['xmax'] = x['xmax']
+                continue
+        tg_align_.append(x)
+    tg_align = tg_align_
+    tg_len = len([x for x in tg_align if x['text'] != ''])
+    ph_len = len([x for x in ph_list if not is_sil_phoneme(x)])
+    assert tg_len == ph_len, (tg_len, ph_len, tg_align, ph_list, tg_fn)
+    # Walk the TextGrid intervals and the phoneme list in lockstep.
+    while tg_idx < len(tg_align) or ph_idx < len(ph_list):
+        if tg_idx == len(tg_align) and is_sil_phoneme(ph_list[ph_idx]):
+            split[ph_idx] = 1e8
+            ph_idx += 1
+            continue
+        x = tg_align[tg_idx]
+        if x['text'] == '' and ph_idx == len(ph_list):
+            tg_idx += 1
+            continue
+        assert ph_idx < len(ph_list), (tg_len, ph_len, tg_align, ph_list, tg_fn)
+        ph = ph_list[ph_idx]
+        if x['text'] == '' and not is_sil_phoneme(ph):
+            assert False, (ph_list, tg_align)
+        if x['text'] != '' and is_sil_phoneme(ph):
+            ph_idx += 1
+        else:
+            assert (x['text'] == '' and is_sil_phoneme(ph)) \
+                   or x['text'].lower() == ph.lower() \
+                   or x['text'].lower() == 'sil', (x['text'], ph)
+            split[ph_idx] = x['xmin']
+            if ph_idx > 0 and split[ph_idx - 1] == -1 and is_sil_phoneme(ph_list[ph_idx - 1]):
+                split[ph_idx - 1] = split[ph_idx]
+            ph_idx += 1
+            tg_idx += 1
+    assert tg_idx == len(tg_align), (tg_idx, [x['text'] for x in tg_align])
+    assert ph_idx >= len(ph_list) - 1, (ph_idx, ph_list, len(ph_list), [x['text'] for x in tg_align], tg_fn)
+    mel2ph = np.zeros([mel.shape[0]], int)
+    split[0] = 0
+    split[-1] = 1e8
+    for i in range(len(split) - 1):
+        assert split[i] != -1 and split[i] <= split[i + 1], (split[:-1],)
+    # Convert the boundaries from seconds to frame indices (rounded to nearest).
+    split = [int(s * hparams['audio_sample_rate'] / hparams['hop_size'] + 0.5) for s in split]
+    for ph_idx in range(len(ph_list)):
+        mel2ph[split[ph_idx]:split[ph_idx + 1]] = ph_idx + 1
+    mel2ph_torch = torch.from_numpy(mel2ph)
+    T_t = len(ph_list)
+    # Count the frames per phoneme index; slot 0 collects unassigned frames and is dropped.
+    dur = mel2ph_torch.new_zeros([T_t + 1]).scatter_add(0, mel2ph_torch, torch.ones_like(mel2ph_torch))
+    dur = dur[1:].numpy()
+    return mel2ph, dur
+
+
+def build_phone_encoder(data_dir):
+    """
+    Build a TokenTextEncoder from the phoneme list stored in `<data_dir>/phone_set.json`.
+
+    :param data_dir: directory containing phone_set.json
+    :return: TokenTextEncoder(None, vocab_list=<loaded list>, replace_oov=',')
+    """
+    phone_list_file = os.path.join(data_dir, 'phone_set.json')
+    # Context manager so the file handle is closed deterministically.
+    with open(phone_list_file, encoding='utf-8') as f:
+        phone_list = json.load(f)
+    return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
+
+
+def is_sil_phoneme(p):
+    """
+    Return True when `p` does not start with an alphabetic character.
+
+    An empty token (what `"a  b".split(" ")` produces for a doubled space) counts as
+    silence instead of raising IndexError.
+    """
+    # p[:1] is '' for an empty token and ''.isalpha() is False.
+    return not p[:1].isalpha()
diff --git a/preprocessing/hubertinfer.py b/preprocessing/hubertinfer.py
new file mode 100644
index 0000000000000000000000000000000000000000..15077cb66c31cde45e19ea39f1e6eef2f51b60f9
--- /dev/null
+++ b/preprocessing/hubertinfer.py
@@ -0,0 +1,42 @@
+import os.path
+from io import BytesIO
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from network.hubert.hubert_model import hubert_soft, get_units
+from network.hubert.vec_model import load_model, get_vec_units
+from utils.hparams import hparams
+
+
+class Hubertencoder():
+    """
+    Wraps either the soft-HuBERT model or, when hparams['use_vec'] is true, the model
+    returned by network.hubert.vec_model.load_model, and encodes audio into units.
+    """
+
+    def __init__(self, pt_path='.checkpoints/hubert/hubert_soft.pt'):
+        """
+        :param pt_path: checkpoint location. In the hubert branch only its parent directory
+                        matters: the first '*.pt' found below it (recursively) is loaded.
+                        In the vec branch it is ignored in favour of a fixed path.
+        :raises FileNotFoundError: if the hubert branch finds no '*.pt' file
+        """
+        # encode() reads hparams['use_vec'] later, so the default is written back into hparams.
+        if 'use_vec' not in hparams.keys():
+            hparams['use_vec'] = False
+        if hparams['use_vec']:
+            pt_path = ".checkpoints/vec/checkpoint_best_legacy_500.pt"
+            # Fall back to the CPU when CUDA is missing, like the hubert branch below.
+            self.dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            self.hbt_model = load_model(pt_path)
+        else:
+            search_dir = Path(pt_path).parent
+            candidates = list(search_dir.rglob('*.pt'))
+            if not candidates:
+                # Used to surface as a bare IndexError from `[0]`.
+                raise FileNotFoundError(f"No '*.pt' checkpoint found under {search_dir}")
+            pt_path = candidates[0]
+            if 'hubert_gpu' in hparams.keys():
+                self.use_gpu = hparams['hubert_gpu']
+            else:
+                self.use_gpu = True
+            self.dev = torch.device("cuda" if self.use_gpu and torch.cuda.is_available() else "cpu")
+            self.hbt_model = hubert_soft(str(pt_path)).to(self.dev)
+
+    def encode(self, wav_path):
+        """
+        Encode one audio input into units.
+
+        :param wav_path: a file path or a BytesIO. For a path, a sibling '.npy' file with
+                         the same stem is loaded instead of running the model, if it exists.
+        :return: numpy array of units
+        """
+        if isinstance(wav_path, BytesIO):
+            # No cache file for in-memory audio; os.path.exists("") is False.
+            npy_path = ""
+            wav_path.seek(0)
+        else:
+            npy_path = Path(wav_path).with_suffix('.npy')
+        if os.path.exists(npy_path):
+            units = np.load(str(npy_path))
+        elif hparams['use_vec']:
+            units = get_vec_units(self.hbt_model, wav_path, self.dev).cpu().numpy()[0]
+        else:
+            units = get_units(self.hbt_model, wav_path, self.dev).cpu().numpy()[0]
+        return units  # [T,256]
diff --git a/preprocessing/process_pipeline.py b/preprocessing/process_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..12dd11b1621bd321c46443cf4de277974db390bd
--- /dev/null
+++ b/preprocessing/process_pipeline.py
@@ -0,0 +1,199 @@
+'''
+ file -> temporary_dict -> processed_input -> batch
+'''
+from utils.hparams import hparams
+from network.vocoders.base_vocoder import VOCODERS
+import numpy as np
+import traceback
+from pathlib import Path
+from .data_gen_utils import get_pitch_parselmouth,get_pitch_crepe
+from .base_binarizer import BinarizationError
+import torch
+import utils
+
+class File2Batch:
+ '''
+ pipeline: file -> temporary_dict -> processed_input -> batch
+ '''
+
+    @staticmethod
+    def file2temporary_dict():
+        '''
+            read from file, store data in temporary dicts
+
+            Collects every *.wav and then every *.ogg file below hparams['raw_data_dir']
+            (recursively). Returns {str(path): {'wav_fn': str(path), 'spk_id': hparams['speaker_id']}}.
+        '''
+        raw_data_dir = Path(hparams['raw_data_dir'])
+        # The transcription/MIDI handling and the pitch-shift augmentation that used to sit
+        # here were fully commented out; only the audio paths and the speaker id are recorded.
+        audio_files = [*raw_data_dir.rglob("*.wav"), *raw_data_dir.rglob("*.ogg")]
+        return {
+            str(audio_file): {'wav_fn': str(audio_file), 'spk_id': hparams['speaker_id']}
+            for audio_file in audio_files
+        }
+
+    @staticmethod
+    def temporary_dict2processed_input(item_name, temp_dict, encoder, binarization_args):
+        '''
+            process data in temporary_dicts
+
+            Computes wav/mel with the vocoder named by hparams['vocoder'] and, depending on
+            binarization_args['with_f0' / 'with_hubert' / 'with_align'], adds f0/pitch,
+            hubert units and mel2ph.
+
+            :param item_name: stored under 'item_name'
+            :param temp_dict: item metadata; 'wav_fn' is read, all keys are merged into the result
+            :param encoder: object whose encode(wav_fn) returns the hubert units
+            :param binarization_args: dict of the with_* switches
+            :return: the processed dict, or None if any optional step raised (the item is skipped)
+        '''
+        def get_pitch(wav, mel):
+            # get ground truth f0 with the algorithm selected by hparams['use_crepe']
+            if hparams['use_crepe']:
+                gt_f0, gt_pitch_coarse = get_pitch_crepe(wav, mel, hparams)
+            else:
+                gt_f0, gt_pitch_coarse = get_pitch_parselmouth(wav, mel, hparams)
+            if sum(gt_f0) == 0:
+                raise BinarizationError("Empty **gt** f0")
+            processed_input['f0'] = gt_f0
+            processed_input['pitch'] = gt_pitch_coarse
+
+        def get_align(meta_data, mel, phone_encoded):
+            # Spread the mel frames evenly over the encoded units (1-based unit index per frame).
+            # The former hop_size / audio_sample_rate defaults were unused and only forced
+            # two extra hparams lookups, so they were dropped.
+            mel2ph = np.zeros([mel.shape[0]], int)
+            start_frame = 0
+            ph_durs = mel.shape[0] / phone_encoded.shape[0]
+            if hparams['debug']:
+                print(mel.shape, phone_encoded.shape, mel.shape[0] / phone_encoded.shape[0])
+            for i_ph in range(phone_encoded.shape[0]):
+                end_frame = int(i_ph * ph_durs + ph_durs + 0.5)
+                mel2ph[start_frame:end_frame + 1] = i_ph + 1
+                start_frame = end_frame + 1
+
+            processed_input['mel2ph'] = mel2ph
+
+        if hparams['vocoder'] in VOCODERS:
+            wav, mel = VOCODERS[hparams['vocoder']].wav2spec(temp_dict['wav_fn'])
+        else:
+            # Fall back to the last component of a dotted vocoder name.
+            wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(temp_dict['wav_fn'])
+        processed_input = {
+            'item_name': item_name, 'mel': mel, 'wav': wav,
+            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0]
+        }
+        processed_input = {**temp_dict, **processed_input}  # merge two dicts
+        processed_input['spec_min'] = np.min(mel, axis=0)
+        processed_input['spec_max'] = np.max(mel, axis=0)
+        hubert_encoded = None
+        try:
+            if binarization_args['with_f0']:
+                get_pitch(wav, mel)
+            if binarization_args['with_hubert']:
+                try:
+                    hubert_encoded = processed_input['hubert'] = encoder.encode(temp_dict['wav_fn'])
+                except Exception as err:
+                    # `except Exception` (not a bare except) lets KeyboardInterrupt/SystemExit through.
+                    traceback.print_exc()
+                    raise Exception("hubert encode error") from err
+            if binarization_args['with_align']:
+                if hubert_encoded is None:
+                    # Previously an opaque NameError on `hubert_encoded`; the item is still skipped.
+                    raise BinarizationError("'with_align' requires 'with_hubert'")
+                get_align(temp_dict, mel, hubert_encoded)
+        except Exception as e:
+            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {temp_dict['wav_fn']}")
+            return None
+        return processed_input
+
+    @staticmethod
+    def processed_input2batch(samples):
+        '''
+            Collate a list of processed inputs into one batch dict.
+
+            Args:
+                samples: one batch of processed_input
+            NOTE:
+                the batch size is controlled by hparams['max_sentences']
+            Returns:
+                {} for an empty list, otherwise a dict with the keys 'id', 'item_name',
+                'nsamples', 'hubert', 'mels', 'mel_lengths', 'mel2ph', 'energy', 'pitch',
+                'f0' and 'uv'.
+        '''
+        if not len(samples):
+            return {}
+
+        def field(key):
+            # One value per sample, in sample order.
+            return [sample[key] for sample in samples]
+
+        sample_ids = torch.LongTensor(field('id'))
+        item_names = field('item_name')
+        hubert = utils.collate_2d(field('hubert'), 0.0)
+        f0 = utils.collate_1d(field('f0'), 0.0)
+        pitch = utils.collate_1d(field('pitch'))
+        uv = utils.collate_1d(field('uv'))
+        energy = utils.collate_1d(field('energy'), 0.0)
+        # mel2ph is only collated when the first sample carries one.
+        if samples[0]['mel2ph'] is not None:
+            mel2ph = utils.collate_1d(field('mel2ph'), 0.0)
+        else:
+            mel2ph = None
+        mels = utils.collate_2d(field('mel'), 0.0)
+        # Computed as before, although it is not part of the returned batch.
+        hubert_lengths = torch.LongTensor([units.shape[0] for units in field('hubert')])
+        mel_lengths = torch.LongTensor([mel.shape[0] for mel in field('mel')])
+
+        # The speaker-embedding, cwt and MIDI fields were already disabled (commented out)
+        # and are not part of the batch.
+        batch = dict()
+        batch['id'] = sample_ids
+        batch['item_name'] = item_names
+        batch['nsamples'] = len(samples)
+        batch['hubert'] = hubert
+        batch['mels'] = mels
+        batch['mel_lengths'] = mel_lengths
+        batch['mel2ph'] = mel2ph
+        batch['energy'] = energy
+        batch['pitch'] = pitch
+        batch['f0'] = f0
+        batch['uv'] = uv
+        return batch
\ No newline at end of file
diff --git a/raw/test_input.wav b/raw/test_input.wav
new file mode 100644
index 0000000000000000000000000000000000000000..b04148a941b60c6b9ff8db693c83b1ff8d07ffa1
Binary files /dev/null and b/raw/test_input.wav differ
diff --git a/requirements.png b/requirements.png
new file mode 100644
index 0000000000000000000000000000000000000000..7c4e35bfb664f6a997c8cf69c0e536c01653d655
Binary files /dev/null and b/requirements.png differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..528770ad9209c839bbe213ac1eb092d417096ffc
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,27 @@
+gradio
+torchcrepe==0.0.17
+tornado==6.2
+tqdm==4.64.1
+traitlets==5.5.0
+praat-parselmouth==0.4.1
+scikit-image
+ipython
+ipykernel
+pyloudnorm
+webrtcvad-wheels
+h5py
+einops
+pycwt
+librosa
+scikit-image
+ipython
+ipykernel
+pyloudnorm
+webrtcvad-wheels
+h5py
+einops
+pycwt
+torchmetrics==0.5
+pytorch_lightning==1.3.3
+cython
+pyworld
\ No newline at end of file
diff --git a/requirements.txtback b/requirements.txtback
new file mode 100644
index 0000000000000000000000000000000000000000..3be46ffd2a92427062099a82cfd091b81688f3a4
--- /dev/null
+++ b/requirements.txtback
@@ -0,0 +1,25 @@
+torchcrepe==0.0.17
+tornado==6.2
+tqdm==4.64.1
+traitlets==5.5.0
+praat-parselmouth==0.4.1
+scikit-image
+ipython
+ipykernel
+pyloudnorm
+webrtcvad-wheels
+h5py
+einops
+pycwt
+librosa
+scikit-image
+ipython
+ipykernel
+pyloudnorm
+webrtcvad-wheels
+h5py
+einops
+pycwt
+torchmetrics==0.5
+pytorch_lightning==1.3.3
+streamlit
\ No newline at end of file
diff --git a/requirements.txtback2 b/requirements.txtback2
new file mode 100644
index 0000000000000000000000000000000000000000..365d82e5261124803985856114b94984eca540e1
--- /dev/null
+++ b/requirements.txtback2
@@ -0,0 +1,99 @@
+absl-py==1.3.0
+aiohttp==3.8.3
+aiosignal==1.2.0
+appdirs==1.4.4
+asttokens==2.0.8
+async-timeout==4.0.2
+attrs==22.1.0
+audioread==3.0.0
+backcall==0.2.0
+cachetools==5.2.0
+cffi==1.15.1
+charset-normalizer==2.1.1
+colorama==0.4.6
+contourpy==1.0.5
+cycler==0.11.0
+debugpy==1.6.3
+decorator==5.1.1
+einops==0.5.0
+entrypoints==0.4
+executing==1.1.1
+fonttools==4.38.0
+frozenlist==1.3.1
+fsspec==2022.10.0
+future==0.18.2
+google-auth==2.13.0
+google-auth-oauthlib==0.4.6
+grpcio==1.50.0
+h5py==3.7.0
+idna==3.4
+imageio==2.22.2
+importlib-metadata==5.0.0
+ipykernel==6.16.2
+ipython==8.5.0
+jedi==0.18.1
+joblib==1.2.0
+jupyter_client==7.4.4
+jupyter_core==4.11.2
+kiwisolver==1.4.4
+librosa==0.9.1
+llvmlite==0.39.1
+Markdown==3.4.1
+MarkupSafe==2.1.1
+matplotlib==3.6.1
+matplotlib-inline==0.1.6
+multidict==6.0.2
+nest-asyncio==1.5.6
+networkx==2.8.7
+numba==0.56.3
+numpy==1.23.4
+oauthlib==3.2.2
+packaging==21.3
+parso==0.8.3
+pickleshare==0.7.5
+Pillow==9.2.0
+pooch==1.6.0
+praat-parselmouth==0.4.1
+prompt-toolkit==3.0.31
+protobuf==3.19.6
+psutil==5.9.3
+pure-eval==0.2.2
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pycparser==2.21
+pycwt==0.3.0a22
+pyDeprecate==0.3.0
+Pygments==2.13.0
+pyloudnorm==0.1.0
+pyparsing==3.0.9
+python-dateutil==2.8.2
+pytorch-lightning==1.3.3
+PyWavelets==1.4.1
+PyYAML==5.4.1
+pyzmq==24.0.1
+requests==2.28.1
+requests-oauthlib==1.3.1
+resampy==0.4.2
+rsa==4.9
+scikit-image==0.19.3
+scikit-learn==1.1.3
+scipy==1.9.3
+six==1.16.0
+soundfile==0.11.0
+stack-data==0.5.1
+threadpoolctl==3.1.0
+tifffile==2022.10.10
+tornado==6.2
+tqdm==4.64.1
+traitlets==5.5.0
+typing_extensions==4.4.0
+urllib3==1.26.12
+wcwidth==0.2.5
+webrtcvad-wheels
+Werkzeug==2.2.2
+wincertstore==0.2
+yarl==1.8.1
+zipp==3.10.0
+streamlit
+torchcrepe==0.0.17
+torchmetrics==0.5.0
\ No newline at end of file
diff --git a/requirements_short.txt b/requirements_short.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fa3f3a016f4bd5b76f2c2c3971607ad6234f6da0
--- /dev/null
+++ b/requirements_short.txt
@@ -0,0 +1,14 @@
+torchcrepe
+praat-parselmouth==0.4.1
+scikit-image
+ipython
+ipykernel
+pyloudnorm
+webrtcvad
+h5py
+einops
+pycwt
+torchmetrics==0.5
+pytorch_lightning==1.3.3
+pyworld
+librosa
\ No newline at end of file
diff --git a/run.py b/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..40d119f4be35d5631e283dca414bb6f407ea71b5
--- /dev/null
+++ b/run.py
@@ -0,0 +1,16 @@
+import importlib
+import os
+from utils.hparams import set_hparams, hparams
+set_hparams(print_hparams=False)
+
+def run_task():
+ assert hparams['task_cls'] != ''
+ pkg = ".".join(hparams["task_cls"].split(".")[:-1])
+ cls_name = hparams["task_cls"].split(".")[-1]
+ task_cls = getattr(importlib.import_module(pkg), cls_name)
+ task_cls.start()
+
+
+if __name__ == '__main__':
+ run_task()
+
diff --git a/simplify.py b/simplify.py
new file mode 100644
index 0000000000000000000000000000000000000000..75c187cc77e41b26a4bfd6bbfe3dad85e434ea59
--- /dev/null
+++ b/simplify.py
@@ -0,0 +1,28 @@
+from argparse import ArgumentParser
+
+import torch
+
+
+def simplify_pth(pth_name, project_name):
+ model_path = f'./checkpoints/{project_name}'
+ checkpoint_dict = torch.load(f'{model_path}/{pth_name}')
+ torch.save({'epoch': checkpoint_dict['epoch'],
+ 'state_dict': checkpoint_dict['state_dict'],
+ 'global_step': None,
+ 'checkpoint_callback_best': None,
+ 'optimizer_states': None,
+ 'lr_schedulers': None
+ }, f'./clean_{pth_name}')
+
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument('--proj', type=str)
+ parser.add_argument('--steps', type=str)
+ args = parser.parse_args()
+ model_name = f"model_ckpt_steps_{args.steps}.ckpt"
+ simplify_pth(model_name, args.proj)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/training/__pycache__/train_pipeline.cpython-38.pyc b/training/__pycache__/train_pipeline.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3d3bb1542d0d0d46272bf572feabe0ac6b796e59
Binary files /dev/null and b/training/__pycache__/train_pipeline.cpython-38.pyc differ
diff --git a/training/config.yaml b/training/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d49f0ecb2e118cc255d97e6077eb8e62046f4a05
--- /dev/null
+++ b/training/config.yaml
@@ -0,0 +1,349 @@
+K_step: 1000
+accumulate_grad_batches: 1
+audio_num_mel_bins: 80
+audio_sample_rate: 24000
+binarization_args:
+ shuffle: false
+ with_align: true
+ with_f0: true
+ with_hubert: true
+ with_spk_embed: false
+ with_wav: false
+binarizer_cls: preprocessing.SVCpre.SVCBinarizer
+binary_data_dir: data/binary/atri
+check_val_every_n_epoch: 10
+choose_test_manually: false
+clip_grad_norm: 1
+config_path: training/config.yaml
+content_cond_steps: []
+cwt_add_f0_loss: false
+cwt_hidden_size: 128
+cwt_layers: 2
+cwt_loss: l1
+cwt_std_scale: 0.8
+datasets:
+- opencpop
+debug: false
+dec_ffn_kernel_size: 9
+dec_layers: 4
+decay_steps: 30000
+decoder_type: fft
+dict_dir: ''
+diff_decoder_type: wavenet
+diff_loss_type: l2
+dilation_cycle_length: 4
+dropout: 0.1
+ds_workers: 4
+dur_enc_hidden_stride_kernel:
+- 0,2,3
+- 0,2,3
+- 0,1,3
+dur_loss: mse
+dur_predictor_kernel: 3
+dur_predictor_layers: 5
+enc_ffn_kernel_size: 9
+enc_layers: 4
+encoder_K: 8
+encoder_type: fft
+endless_ds: False
+f0_bin: 256
+f0_max: 1100.0
+f0_min: 50.0
+ffn_act: gelu
+ffn_padding: SAME
+fft_size: 512
+fmax: 12000
+fmin: 30
+fs2_ckpt: ''
+gaussian_start: true
+gen_dir_name: ''
+gen_tgt_spk_id: -1
+hidden_size: 256
+hop_size: 128
+hubert_gpu: true
+hubert_path: checkpoints/hubert/hubert_soft.pt
+infer: false
+keep_bins: 80
+lambda_commit: 0.25
+lambda_energy: 0.0
+lambda_f0: 1.0
+lambda_ph_dur: 0.3
+lambda_sent_dur: 1.0
+lambda_uv: 1.0
+lambda_word_dur: 1.0
+load_ckpt: ''
+log_interval: 100
+loud_norm: false
+lr: 5.0e-05
+max_beta: 0.02
+max_epochs: 3000
+max_eval_sentences: 1
+max_eval_tokens: 60000
+max_frames: 42000
+max_input_tokens: 60000
+max_sentences: 24
+max_tokens: 128000
+max_updates: 1000000
+mel_loss: ssim:0.5|l1:0.5
+mel_vmax: 1.5
+mel_vmin: -6.0
+min_level_db: -120
+norm_type: gn
+num_ckpt_keep: 10
+num_heads: 2
+num_sanity_val_steps: 1
+num_spk: 1
+num_test_samples: 0
+num_valid_plots: 10
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+out_wav_norm: false
+pe_ckpt: checkpoints/0102_xiaoma_pe/model_ckpt_steps_60000.ckpt
+pe_enable: false
+perform_enhance: true
+pitch_ar: false
+pitch_enc_hidden_stride_kernel:
+- 0,2,5
+- 0,2,5
+- 0,2,5
+pitch_extractor: parselmouth
+pitch_loss: l2
+pitch_norm: log
+pitch_type: frame
+pndm_speedup: 10
+pre_align_args:
+ allow_no_txt: false
+ denoise: false
+ forced_align: mfa
+ txt_processor: zh_g2pM
+ use_sox: true
+ use_tone: false
+pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
+predictor_dropout: 0.5
+predictor_grad: 0.1
+predictor_hidden: -1
+predictor_kernel: 5
+predictor_layers: 5
+prenet_dropout: 0.5
+prenet_hidden_size: 256
+pretrain_fs_ckpt: pretrain/nyaru/model_ckpt_steps_60000.ckpt
+processed_data_dir: xxx
+profile_infer: false
+raw_data_dir: data/raw/atri
+ref_norm_layer: bn
+rel_pos: true
+reset_phone_dict: true
+residual_channels: 256
+residual_layers: 20
+save_best: false
+save_ckpt: true
+save_codes:
+- configs
+- modules
+- src
+- utils
+save_f0: true
+save_gt: false
+schedule_type: linear
+seed: 1234
+sort_by_len: true
+speaker_id: atri
+spec_max:
+- 0.2987259328365326
+- 0.29721200466156006
+- 0.23978209495544434
+- 0.208412766456604
+- 0.25777050852775574
+- 0.2514476478099823
+- 0.1129382848739624
+- 0.03415697440505028
+- 0.09860049188137054
+- 0.10637332499027252
+- 0.13287633657455444
+- 0.19744250178337097
+- 0.10040587931871414
+- 0.13735432922840118
+- 0.15107455849647522
+- 0.17196381092071533
+- 0.08298977464437485
+- 0.0632769986987114
+- 0.02723858878016472
+- -0.001819317927584052
+- -0.029565516859292984
+- -0.023574354127049446
+- -0.01633293740451336
+- 0.07143621146678925
+- 0.021580500528216362
+- 0.07257916033267975
+- -0.024349519982933998
+- -0.06165708228945732
+- -0.10486568510532379
+- -0.1363687664270401
+- -0.13333871960639954
+- -0.13955898582935333
+- -0.16613495349884033
+- -0.17636367678642273
+- -0.2786925733089447
+- -0.22967253625392914
+- -0.31897130608558655
+- -0.18007366359233856
+- -0.29366692900657654
+- -0.2871025800704956
+- -0.36748355627059937
+- -0.46071451902389526
+- -0.5464922189712524
+- -0.5719417333602905
+- -0.6020897626876831
+- -0.6239874958992004
+- -0.5653440952301025
+- -0.6508013606071472
+- -0.628247857093811
+- -0.6809687614440918
+- -0.569259762763977
+- -0.5423558354377747
+- -0.5811785459518433
+- -0.5359002351760864
+- -0.6565515398979187
+- -0.7143737077713013
+- -0.8502675890922546
+- -0.7979224920272827
+- -0.7110578417778015
+- -0.763409435749054
+- -0.7984790802001953
+- -0.6927220821380615
+- -0.658117413520813
+- -0.7486468553543091
+- -0.5949879884719849
+- -0.7494576573371887
+- -0.7400822639465332
+- -0.6822793483734131
+- -0.7773582339286804
+- -0.661201536655426
+- -0.791329026222229
+- -0.8982341885566711
+- -0.8736728429794312
+- -0.7701027393341064
+- -0.8490535616874695
+- -0.7479292154312134
+- -0.9320166110992432
+- -1.2862414121627808
+- -2.8936190605163574
+- -2.924229860305786
+spec_min:
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -6.0
+- -5.999454021453857
+- -5.8822431564331055
+- -5.892064571380615
+- -5.882402420043945
+- -5.786972522735596
+- -5.746835231781006
+- -5.8594512939453125
+- -5.7389445304870605
+- -5.718059539794922
+- -5.779720306396484
+- -5.801984786987305
+- -6.0
+- -6.0
+spk_cond_steps: []
+stop_token_weight: 5.0
+task_cls: training.task.SVC_task.SVCTask
+test_ids: []
+test_input_dir: ''
+test_num: 0
+test_prefixes:
+- test
+test_set_name: test
+timesteps: 1000
+train_set_name: train
+use_crepe: true
+use_denoise: false
+use_energy_embed: false
+use_gt_dur: false
+use_gt_f0: false
+use_midi: false
+use_nsf: true
+use_pitch_embed: true
+use_pos_embed: true
+use_spk_embed: false
+use_spk_id: false
+use_split_spk_id: false
+use_uv: false
+use_var_enc: false
+use_vec: false
+val_check_interval: 2000
+valid_num: 0
+valid_set_name: valid
+vocoder: network.vocoders.hifigan.HifiGAN
+vocoder_ckpt: checkpoints/0109_hifigan_bigpopcs_hop128
+warmup_updates: 2000
+wav2spec_eps: 1e-6
+weight_decay: 0
+win_size: 512
+work_dir: checkpoints/atri
+no_fs2: false
\ No newline at end of file
diff --git a/training/config_nsf.yaml b/training/config_nsf.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..37beccc14e0b8978e065a79bc7319c2c057bf1e5
--- /dev/null
+++ b/training/config_nsf.yaml
@@ -0,0 +1,445 @@
+K_step: 1000
+accumulate_grad_batches: 1
+audio_num_mel_bins: 128
+audio_sample_rate: 44100
+binarization_args:
+ shuffle: false
+ with_align: true
+ with_f0: true
+ with_hubert: true
+ with_spk_embed: false
+ with_wav: false
+binarizer_cls: preprocessing.SVCpre.SVCBinarizer
+binary_data_dir: data/binary/Meiko
+check_val_every_n_epoch: 10
+choose_test_manually: false
+clip_grad_norm: 1
+config_path: training/config_nsf.yaml
+content_cond_steps: []
+cwt_add_f0_loss: false
+cwt_hidden_size: 128
+cwt_layers: 2
+cwt_loss: l1
+cwt_std_scale: 0.8
+datasets:
+- opencpop
+debug: false
+dec_ffn_kernel_size: 9
+dec_layers: 4
+decay_steps: 50000
+decoder_type: fft
+dict_dir: ''
+diff_decoder_type: wavenet
+diff_loss_type: l2
+dilation_cycle_length: 4
+dropout: 0.1
+ds_workers: 4
+dur_enc_hidden_stride_kernel:
+- 0,2,3
+- 0,2,3
+- 0,1,3
+dur_loss: mse
+dur_predictor_kernel: 3
+dur_predictor_layers: 5
+enc_ffn_kernel_size: 9
+enc_layers: 4
+encoder_K: 8
+encoder_type: fft
+endless_ds: false
+f0_bin: 256
+f0_max: 1100.0
+f0_min: 40.0
+ffn_act: gelu
+ffn_padding: SAME
+fft_size: 2048
+fmax: 16000
+fmin: 40
+fs2_ckpt: ''
+gaussian_start: true
+gen_dir_name: ''
+gen_tgt_spk_id: -1
+hidden_size: 256
+hop_size: 512
+hubert_gpu: true
+hubert_path: checkpoints/hubert/hubert_soft.pt
+infer: false
+keep_bins: 128
+lambda_commit: 0.25
+lambda_energy: 0.0
+lambda_f0: 1.0
+lambda_ph_dur: 0.3
+lambda_sent_dur: 1.0
+lambda_uv: 1.0
+lambda_word_dur: 1.0
+load_ckpt: ''
+log_interval: 100
+loud_norm: false
+lr: 0.0006
+max_beta: 0.02
+max_epochs: 3000
+max_eval_sentences: 1
+max_eval_tokens: 60000
+max_frames: 42000
+max_input_tokens: 60000
+max_sentences: 11
+max_tokens: 128000
+max_updates: 1000000
+mel_loss: ssim:0.5|l1:0.5
+mel_vmax: 1.5
+mel_vmin: -6.0
+min_level_db: -120
+no_fs2: true
+norm_type: gn
+num_ckpt_keep: 10
+num_heads: 2
+num_sanity_val_steps: 1
+num_spk: 1
+num_test_samples: 0
+num_valid_plots: 10
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+out_wav_norm: false
+pe_ckpt: checkpoints/0102_xiaoma_pe/model_ckpt_steps_60000.ckpt
+pe_enable: false
+perform_enhance: true
+pitch_ar: false
+pitch_enc_hidden_stride_kernel:
+- 0,2,5
+- 0,2,5
+- 0,2,5
+pitch_extractor: parselmouth
+pitch_loss: l2
+pitch_norm: log
+pitch_type: frame
+pndm_speedup: 10
+pre_align_args:
+ allow_no_txt: false
+ denoise: false
+ forced_align: mfa
+ txt_processor: zh_g2pM
+ use_sox: true
+ use_tone: false
+pre_align_cls: data_gen.singing.pre_align.SingingPreAlign
+predictor_dropout: 0.5
+predictor_grad: 0.1
+predictor_hidden: -1
+predictor_kernel: 5
+predictor_layers: 5
+prenet_dropout: 0.5
+prenet_hidden_size: 256
+pretrain_fs_ckpt: ''
+processed_data_dir: xxx
+profile_infer: false
+raw_data_dir: data/raw/Meiko
+ref_norm_layer: bn
+rel_pos: true
+reset_phone_dict: true
+residual_channels: 384
+residual_layers: 20
+save_best: true
+save_ckpt: true
+save_codes:
+- configs
+- modules
+- src
+- utils
+save_f0: true
+save_gt: false
+schedule_type: linear
+seed: 1234
+sort_by_len: true
+speaker_id: Meiko
+spec_max:
+- 0.11616316437721252
+- 0.009597139433026314
+- 0.28568679094314575
+- 0.5713539123535156
+- 0.6507775187492371
+- 0.6846900582313538
+- 0.7684511542320251
+- 0.7574314475059509
+- 0.7267094254493713
+- 0.8298212289810181
+- 0.6814215183258057
+- 0.7774385213851929
+- 0.7883802056312561
+- 0.7771736979484558
+- 0.7607403993606567
+- 0.8505979180335999
+- 0.7654092311859131
+- 0.7792922258377075
+- 0.814899206161499
+- 0.8058286905288696
+- 0.839918851852417
+- 0.8406909108161926
+- 0.8339935541152954
+- 0.9287465810775757
+- 0.8166532516479492
+- 0.8449192047119141
+- 0.7643511891365051
+- 0.8175668716430664
+- 1.0239852666854858
+- 0.920753002166748
+- 0.8153243660926819
+- 0.7587951421737671
+- 0.7698416113853455
+- 0.7247377634048462
+- 0.6954795122146606
+- 0.6807010173797607
+- 0.8715915679931641
+- 0.8993064761161804
+- 0.90997314453125
+- 0.7913641333580017
+- 0.7065826058387756
+- 0.6068118810653687
+- 0.6278789639472961
+- 0.6242763996124268
+- 0.5978773236274719
+- 0.651780366897583
+- 0.7780635952949524
+- 0.7565146684646606
+- 0.5729265213012695
+- 0.5707721710205078
+- 0.5281876921653748
+- 0.5579817891120911
+- 0.6407540440559387
+- 0.7233482003211975
+- 0.5677092671394348
+- 0.40926626324653625
+- 0.4460923373699188
+- 0.4058813750743866
+- 0.4390961229801178
+- 0.5553078055381775
+- 0.5349165201187134
+- 0.43830350041389465
+- 0.4032619595527649
+- 0.3253237009048462
+- 0.30613574385643005
+- 0.44174280762672424
+- 0.3622792959213257
+- 0.45337533950805664
+- 0.3313130736351013
+- 0.36956584453582764
+- 0.4998202919960022
+- 0.42133796215057373
+- 0.28050243854522705
+- 0.26571735739707947
+- 0.20871540904045105
+- 0.3416949510574341
+- 0.3328045904636383
+- 0.332925409078598
+- 0.3000032603740692
+- 0.08743463456630707
+- 0.20726755261421204
+- 0.1583203673362732
+- 0.13275942206382751
+- 0.066913902759552
+- 0.1054723709821701
+- -0.08983375877141953
+- -0.12505969405174255
+- -0.03509913384914398
+- -0.11556489020586014
+- -0.2324075847864151
+- -0.06187695264816284
+- 0.020108096301555634
+- -0.009129349142313004
+- -0.044059865176677704
+- 0.0343453511595726
+- 0.030609752982854843
+- 0.11592991650104523
+- 0.04611678794026375
+- 0.016514429822564125
+- -0.10608740150928497
+- -0.18119606375694275
+- -0.0764162689447403
+- -0.005786585621535778
+- -0.16699059307575226
+- -0.1254500299692154
+- -0.09370455145835876
+- 0.015143157914280891
+- 0.07289116084575653
+- -0.006812357809394598
+- -0.0280735082924366
+- -0.0021705669350922108
+- -0.1115487739443779
+- -0.2423458993434906
+- -0.116642065346241
+- -0.1487213373184204
+- -0.16707029938697815
+- -0.25437667965888977
+- -0.32499101758003235
+- -0.2704009413719177
+- -0.29621294140815735
+- -0.42674311995506287
+- -0.4650932848453522
+- -0.5842434763908386
+- -0.6859109401702881
+- -0.9532108902931213
+- -0.9863560199737549
+- -1.220953106880188
+- -1.3163429498672485
+spec_min:
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.999994277954102
+- -4.942144870758057
+- -4.772783279418945
+- -4.7206244468688965
+- -4.5759992599487305
+- -4.509932518005371
+spk_cond_steps: []
+stop_token_weight: 5.0
+task_cls: training.task.SVC_task.SVCTask
+test_ids: []
+test_input_dir: ''
+test_num: 0
+test_prefixes:
+- test
+test_set_name: test
+timesteps: 1000
+train_set_name: train
+use_crepe: false
+use_denoise: false
+use_energy_embed: false
+use_gt_dur: false
+use_gt_f0: false
+use_midi: false
+use_nsf: true
+use_pitch_embed: true
+use_pos_embed: true
+use_spk_embed: false
+use_spk_id: false
+use_split_spk_id: false
+use_uv: false
+use_var_enc: false
+use_vec: false
+val_check_interval: 1000
+valid_num: 0
+valid_set_name: valid
+vocoder: network.vocoders.nsf_hifigan.NsfHifiGAN
+vocoder_ckpt: checkpoints/nsf_hifigan/model
+warmup_updates: 2000
+wav2spec_eps: 1e-6
+weight_decay: 0
+win_size: 2048
+work_dir: checkpoints/Meiko
diff --git a/training/dataset/__pycache__/base_dataset.cpython-38.pyc b/training/dataset/__pycache__/base_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eb751fa101b07c244ec4327c8daa71c767c84889
Binary files /dev/null and b/training/dataset/__pycache__/base_dataset.cpython-38.pyc differ
diff --git a/training/dataset/__pycache__/fs2_utils.cpython-38.pyc b/training/dataset/__pycache__/fs2_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..95fde02cfa553720d97934e331342e0a8d2c8442
Binary files /dev/null and b/training/dataset/__pycache__/fs2_utils.cpython-38.pyc differ
diff --git a/training/dataset/base_dataset.py b/training/dataset/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7d96e0fab43788757bd67c1737b9eab77b98c4d
--- /dev/null
+++ b/training/dataset/base_dataset.py
@@ -0,0 +1,66 @@
+import torch
+from utils.hparams import hparams
+import numpy as np
+import os
+
+class BaseDataset(torch.utils.data.Dataset):
+ '''
+ Base class for datasets.
+ 1. *ordered_indices*:
+ if self.shuffle == True, shuffle the indices;
+ if self.sort_by_len == True, sort data by length;
+ 2. *sizes*:
+ clipped length if "max_frames" is set;
+ 3. *num_tokens*:
+ unclipped length.
+
+ Subclasses should define:
+ 1. *collate*:
+ take the longest data, pad other data to the same length;
+ 2. *__getitem__*:
+ the index function.
+ '''
+ def __init__(self, shuffle):
+ super().__init__()
+ self.hparams = hparams
+ self.shuffle = shuffle
+ self.sort_by_len = hparams['sort_by_len']
+ self.sizes = None
+
+ @property
+ def _sizes(self):
+ return self.sizes
+
+ def __getitem__(self, index):
+ raise NotImplementedError
+
+ def collater(self, samples):
+ raise NotImplementedError
+
+ def __len__(self):
+ return len(self._sizes)
+
+ def num_tokens(self, index):
+ return self.size(index)
+
+ def size(self, index):
+ """Return an example's size as a float or tuple. This value is used when
+ filtering a dataset with ``--max-positions``."""
+ size = min(self._sizes[index], hparams['max_frames'])
+ return size
+
+ def ordered_indices(self):
+ """Return an ordered list of indices. Batches will be constructed based
+ on this order."""
+ if self.shuffle:
+ indices = np.random.permutation(len(self))
+ if self.sort_by_len:
+ indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')]
+ # 先random, 然后稳定排序, 保证排序后同长度的数据顺序是依照random permutation的 (被其随机打乱).
+ else:
+ indices = np.arange(len(self))
+ return indices
+
+ @property
+ def num_workers(self):
+ return int(os.getenv('NUM_WORKERS', hparams['ds_workers']))
diff --git a/training/dataset/fs2_utils.py b/training/dataset/fs2_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..23833faa9b2858f86d31dadc77930cb78b456509
--- /dev/null
+++ b/training/dataset/fs2_utils.py
@@ -0,0 +1,178 @@
+import matplotlib
+
+matplotlib.use('Agg')
+
+import glob
+import importlib
+from utils.cwt import get_lf0_cwt
+import os
+import torch.optim
+import torch.utils.data
+from utils.indexed_datasets import IndexedDataset
+from utils.pitch_utils import norm_interp_f0
+import numpy as np
+from training.dataset.base_dataset import BaseDataset
+import torch
+import torch.optim
+import torch.utils.data
+import utils
+import torch.distributions
+from utils.hparams import hparams
+
+
+class FastSpeechDataset(BaseDataset):
+ def __init__(self, prefix, shuffle=False):
+ super().__init__(shuffle)
+ self.data_dir = hparams['binary_data_dir']
+ self.prefix = prefix
+ self.hparams = hparams
+ self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
+ self.indexed_ds = None
+ # self.name2spk_id={}
+
+ # pitch stats
+ f0_stats_fn = f'{self.data_dir}/train_f0s_mean_std.npy'
+ if os.path.exists(f0_stats_fn):
+ hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = np.load(f0_stats_fn)
+ hparams['f0_mean'] = float(hparams['f0_mean'])
+ hparams['f0_std'] = float(hparams['f0_std'])
+ else:
+ hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = None, None
+
+ if prefix == 'test':
+ if hparams['test_input_dir'] != '':
+ self.indexed_ds, self.sizes = self.load_test_inputs(hparams['test_input_dir'])
+ else:
+ if hparams['num_test_samples'] > 0:
+ self.avail_idxs = list(range(hparams['num_test_samples'])) + hparams['test_ids']
+ self.sizes = [self.sizes[i] for i in self.avail_idxs]
+
+ if hparams['pitch_type'] == 'cwt':
+ _, hparams['cwt_scales'] = get_lf0_cwt(np.ones(10))
+
+ def _get_item(self, index):
+ if hasattr(self, 'avail_idxs') and self.avail_idxs is not None:
+ index = self.avail_idxs[index]
+ if self.indexed_ds is None:
+ self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
+ return self.indexed_ds[index]
+
+ def __getitem__(self, index):
+ hparams = self.hparams
+ item = self._get_item(index)
+ max_frames = hparams['max_frames']
+ spec = torch.Tensor(item['mel'])[:max_frames]
+ energy = (spec.exp() ** 2).sum(-1).sqrt()
+ mel2ph = torch.LongTensor(item['mel2ph'])[:max_frames] if 'mel2ph' in item else None
+ f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
+ #phone = torch.LongTensor(item['phone'][:hparams['max_input_tokens']])
+ hubert=torch.Tensor(item['hubert'][:hparams['max_input_tokens']])
+ pitch = torch.LongTensor(item.get("pitch"))[:max_frames]
+ # print(item.keys(), item['mel'].shape, spec.shape)
+ sample = {
+ "id": index,
+ "item_name": item['item_name'],
+ # "text": item['txt'],
+ # "txt_token": phone,
+ "hubert":hubert,
+ "mel": spec,
+ "pitch": pitch,
+ "energy": energy,
+ "f0": f0,
+ "uv": uv,
+ "mel2ph": mel2ph,
+ "mel_nonpadding": spec.abs().sum(-1) > 0,
+ }
+ if self.hparams['use_spk_embed']:
+ sample["spk_embed"] = torch.Tensor(item['spk_embed'])
+ if self.hparams['use_spk_id']:
+ sample["spk_id"] = item['spk_id']
+ # sample['spk_id'] = 0
+ # for key in self.name2spk_id.keys():
+ # if key in item['item_name']:
+ # sample['spk_id'] = self.name2spk_id[key]
+ # break
+ #======not used==========
+ # if self.hparams['pitch_type'] == 'cwt':
+ # cwt_spec = torch.Tensor(item['cwt_spec'])[:max_frames]
+ # f0_mean = item.get('f0_mean', item.get('cwt_mean'))
+ # f0_std = item.get('f0_std', item.get('cwt_std'))
+ # sample.update({"cwt_spec": cwt_spec, "f0_mean": f0_mean, "f0_std": f0_std})
+ # elif self.hparams['pitch_type'] == 'ph':
+ # f0_phlevel_sum = torch.zeros_like(phone).float().scatter_add(0, mel2ph - 1, f0)
+ # f0_phlevel_num = torch.zeros_like(phone).float().scatter_add(
+ # 0, mel2ph - 1, torch.ones_like(f0)).clamp_min(1)
+ # sample["f0_ph"] = f0_phlevel_sum / f0_phlevel_num
+ return sample
+
+ def collater(self, samples):
+ if len(samples) == 0:
+ return {}
+ id = torch.LongTensor([s['id'] for s in samples])
+ item_names = [s['item_name'] for s in samples]
+ text = [s['text'] for s in samples]
+ txt_tokens = utils.collate_1d([s['txt_token'] for s in samples], 0)
+ f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
+ pitch = utils.collate_1d([s['pitch'] for s in samples],1)
+ uv = utils.collate_1d([s['uv'] for s in samples])
+ energy = utils.collate_1d([s['energy'] for s in samples], 0.0)
+ mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \
+ if samples[0]['mel2ph'] is not None else None
+ mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
+ txt_lengths = torch.LongTensor([s['txt_token'].numel() for s in samples])
+ mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])
+
+ batch = {
+ 'id': id,
+ 'item_name': item_names,
+ 'nsamples': len(samples),
+ 'text': text,
+ 'txt_tokens': txt_tokens,
+ 'txt_lengths': txt_lengths,
+ 'mels': mels,
+ 'mel_lengths': mel_lengths,
+ 'mel2ph': mel2ph,
+ 'energy': energy,
+ 'pitch': pitch,
+ 'f0': f0,
+ 'uv': uv,
+ }
+
+ if self.hparams['use_spk_embed']:
+ spk_embed = torch.stack([s['spk_embed'] for s in samples])
+ batch['spk_embed'] = spk_embed
+ if self.hparams['use_spk_id']:
+ spk_ids = torch.LongTensor([s['spk_id'] for s in samples])
+ batch['spk_ids'] = spk_ids
+ if self.hparams['pitch_type'] == 'cwt':
+ cwt_spec = utils.collate_2d([s['cwt_spec'] for s in samples])
+ f0_mean = torch.Tensor([s['f0_mean'] for s in samples])
+ f0_std = torch.Tensor([s['f0_std'] for s in samples])
+ batch.update({'cwt_spec': cwt_spec, 'f0_mean': f0_mean, 'f0_std': f0_std})
+ elif self.hparams['pitch_type'] == 'ph':
+ batch['f0'] = utils.collate_1d([s['f0_ph'] for s in samples])
+
+ return batch
+
+ def load_test_inputs(self, test_input_dir, spk_id=0):
+ inp_wav_paths = glob.glob(f'{test_input_dir}/*.wav') + glob.glob(f'{test_input_dir}/*.mp3')
+ sizes = []
+ items = []
+
+ binarizer_cls = hparams.get("binarizer_cls", 'basics.base_binarizer.BaseBinarizer')
+ pkg = ".".join(binarizer_cls.split(".")[:-1])
+ cls_name = binarizer_cls.split(".")[-1]
+ binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
+ binarization_args = hparams['binarization_args']
+ from preprocessing.hubertinfer import Hubertencoder
+ for wav_fn in inp_wav_paths:
+ item_name = os.path.basename(wav_fn)
+ ph = txt = tg_fn = ''
+ wav_fn = wav_fn
+ encoder = Hubertencoder(hparams['hubert_path'])
+
+ item = binarizer_cls.process_item(item_name, {'wav_fn':wav_fn}, encoder, binarization_args)
+ print(item)
+ items.append(item)
+ sizes.append(item['len'])
+ return items, sizes
diff --git a/training/from huggingface_hub import Repository.py b/training/from huggingface_hub import Repository.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c3c8ded414d2af662eb49404e608a8a15462e9a
--- /dev/null
+++ b/training/from huggingface_hub import Repository.py
@@ -0,0 +1,2 @@
+# One-off helper script: instantiates a huggingface_hub `Repository` with
+# local_dir="w2v2" and clone_from="facebook/wav2vec2-large-960h-lv60".
+# NOTE(review): presumably this clones that Hub model repository into ./w2v2
+# as a side effect of construction -- confirm against the huggingface_hub
+# documentation for the installed version.
+from huggingface_hub import Repository
+repo = Repository(local_dir="w2v2", clone_from="facebook/wav2vec2-large-960h-lv60")
\ No newline at end of file
diff --git a/training/pe.py b/training/pe.py
new file mode 100644
index 0000000000000000000000000000000000000000..584a518649ba8465ad0c7690b51cce9762592da5
--- /dev/null
+++ b/training/pe.py
@@ -0,0 +1,155 @@
+import matplotlib
+matplotlib.use('Agg')
+
+import torch
+import numpy as np
+import os
+
+from training.dataset.base_dataset import BaseDataset
+from training.task.fs2 import FastSpeech2Task
+from modules.fastspeech.pe import PitchExtractor
+import utils
+from utils.indexed_datasets import IndexedDataset
+from utils.hparams import hparams
+from utils.plot import f0_to_figure
+from utils.pitch_utils import norm_interp_f0, denorm_f0
+
+
+class PeDataset(BaseDataset):
+ def __init__(self, prefix, shuffle=False):
+ super().__init__(shuffle)
+ self.data_dir = hparams['binary_data_dir']
+ self.prefix = prefix
+ self.hparams = hparams
+ self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
+ self.indexed_ds = None
+
+ # pitch stats
+ f0_stats_fn = f'{self.data_dir}/train_f0s_mean_std.npy'
+ if os.path.exists(f0_stats_fn):
+ hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = np.load(f0_stats_fn)
+ hparams['f0_mean'] = float(hparams['f0_mean'])
+ hparams['f0_std'] = float(hparams['f0_std'])
+ else:
+ hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = None, None
+
+ if prefix == 'test':
+ if hparams['num_test_samples'] > 0:
+ self.avail_idxs = list(range(hparams['num_test_samples'])) + hparams['test_ids']
+ self.sizes = [self.sizes[i] for i in self.avail_idxs]
+
+ def _get_item(self, index):
+ if hasattr(self, 'avail_idxs') and self.avail_idxs is not None:
+ index = self.avail_idxs[index]
+ if self.indexed_ds is None:
+ self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
+ return self.indexed_ds[index]
+
+ def __getitem__(self, index):
+ hparams = self.hparams
+ item = self._get_item(index)
+ max_frames = hparams['max_frames']
+ spec = torch.Tensor(item['mel'])[:max_frames]
+ # mel2ph = torch.LongTensor(item['mel2ph'])[:max_frames] if 'mel2ph' in item else None
+ f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
+ pitch = torch.LongTensor(item.get("pitch"))[:max_frames]
+ # print(item.keys(), item['mel'].shape, spec.shape)
+ sample = {
+ "id": index,
+ "item_name": item['item_name'],
+ "text": item['txt'],
+ "mel": spec,
+ "pitch": pitch,
+ "f0": f0,
+ "uv": uv,
+ # "mel2ph": mel2ph,
+ # "mel_nonpadding": spec.abs().sum(-1) > 0,
+ }
+ return sample
+
+ def collater(self, samples):
+ if len(samples) == 0:
+ return {}
+ id = torch.LongTensor([s['id'] for s in samples])
+ item_names = [s['item_name'] for s in samples]
+ text = [s['text'] for s in samples]
+ f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
+ pitch = utils.collate_1d([s['pitch'] for s in samples])
+ uv = utils.collate_1d([s['uv'] for s in samples])
+ mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
+ mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])
+ # mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \
+ # if samples[0]['mel2ph'] is not None else None
+ # mel_nonpaddings = utils.collate_1d([s['mel_nonpadding'].float() for s in samples], 0.0)
+
+ batch = {
+ 'id': id,
+ 'item_name': item_names,
+ 'nsamples': len(samples),
+ 'text': text,
+ 'mels': mels,
+ 'mel_lengths': mel_lengths,
+ 'pitch': pitch,
+ # 'mel2ph': mel2ph,
+ # 'mel_nonpaddings': mel_nonpaddings,
+ 'f0': f0,
+ 'uv': uv,
+ }
+ return batch
+
+
+class PitchExtractionTask(FastSpeech2Task):
+ def __init__(self):
+ super().__init__()
+ self.dataset_cls = PeDataset
+
+ def build_tts_model(self):
+ self.model = PitchExtractor(conv_layers=hparams['pitch_extractor_conv_layers'])
+
+ # def build_scheduler(self, optimizer):
+ # return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)
+ def _training_step(self, sample, batch_idx, _):
+ loss_output = self.run_model(self.model, sample)
+ total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad])
+ loss_output['batch_size'] = sample['mels'].size()[0]
+ return total_loss, loss_output
+
+ def validation_step(self, sample, batch_idx):
+ outputs = {}
+ outputs['losses'] = {}
+ outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=True)
+ outputs['total_loss'] = sum(outputs['losses'].values())
+ outputs['nsamples'] = sample['nsamples']
+ outputs = utils.tensors_to_scalars(outputs)
+ if batch_idx < hparams['num_valid_plots']:
+ self.plot_pitch(batch_idx, model_out, sample)
+ return outputs
+
+ def run_model(self, model, sample, return_output=False, infer=False):
+ f0 = sample['f0']
+ uv = sample['uv']
+ output = model(sample['mels'])
+ losses = {}
+ self.add_pitch_loss(output, sample, losses)
+ if not return_output:
+ return losses
+ else:
+ return losses, output
+
+ def plot_pitch(self, batch_idx, model_out, sample):
+ gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
+ self.logger.experiment.add_figure(
+ f'f0_{batch_idx}',
+ f0_to_figure(gt_f0[0], None, model_out['f0_denorm_pred'][0]),
+ self.global_step)
+
+ def add_pitch_loss(self, output, sample, losses):
+ # mel2ph = sample['mel2ph'] # [B, T_s]
+ mel = sample['mels']
+ f0 = sample['f0']
+ uv = sample['uv']
+ # nonpadding = (mel2ph != 0).float() if hparams['pitch_type'] == 'frame' \
+ # else (sample['txt_tokens'] != 0).float()
+ nonpadding = (mel.abs().sum(-1) > 0).float() # sample['mel_nonpaddings']
+ # print(nonpadding[0][-8:], nonpadding.shape)
+ self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding)
\ No newline at end of file
diff --git a/training/rgsdgsd.py b/training/rgsdgsd.py
new file mode 100644
index 0000000000000000000000000000000000000000..527fd3fc836518096a691c37a2d9a9a578441413
--- /dev/null
+++ b/training/rgsdgsd.py
@@ -0,0 +1,107 @@
+import pygame
+import mido
+from gtts import gTTS
+from pygame.locals import MOUSEBUTTONDOWN, KEYDOWN, K_RETURN
+import time
+
+# Initialize the click count and the time of the last click
+click_count = 0
+last_click_time = 0
+
+# Initialize Pygame and create a window for the MIDI editor
+pygame.init()
+screen = pygame.display.set_mode((640, 480))
+
+class MidiNote:
+ def __init__(self, note, velocity, time, x, y, width, height):
+ self.note = note
+ self.velocity = velocity
+ self.time = time
+ self.x = x
+ self.y = y
+ self.width = width
+ self.height = height
+
+
+# Create a new MIDI file
+mid = mido.MidiFile()
+
+# Create a new MIDI track
+track = mido.MidiTrack()
+
+# Add the MIDI track to the file
+mid.tracks.append(track)
+
+# Create a function to add lyrics to a specific MIDI note
+def add_lyrics(note, lyrics):
+ # Add the lyrics to the MIDI note
+ note.text = lyrics
+ # Update the MIDI file with the new lyrics
+ mid.save("song.mid")
+
+# Create a function to get the MIDI note that was clicked on
+def get_clicked_note(pos):
+ # Iterate through the MIDI notes in the file
+ for track in mid.tracks:
+ for note in track:
+ if isinstance(note, mido.Message):
+ # Check if the mouse position is within the bounds of the MIDI note
+ if pos[0] > note.x and pos[0] < note.x + note.width:
+ if pos[1] > note.y and pos[1] < note.y + note.height:
+ return note
+ return None
+
+# Create a function to convert the lyrics from Japanese to speech
+def speak_lyrics(lyrics):
+ tts = gTTS(lyrics, lang='ja')
+ tts.save('lyrics.mp3')
+ pygame.mixer.music.load('lyrics.mp3')
+ pygame.mixer.music.play()
+
+
+# Main loop
+while True:
+ for event in pygame.event.get():
+ if event.type == MOUSEBUTTONDOWN:
+ # Increment the click count
+ click_count += 1
+ # Check if the user double-clicked on a MIDI note
+ if time.time() - last_click_time < 0.5:
+ # Get the MIDI note that was clicked on
+ note = get_clicked_note(event.pos)
+ # Add the lyrics to the MIDI note
+ add_lyrics(note, lyrics)
+ # Reset the click count
+ click_count = 0
+ # Update the time of the last click
+ last_click_time = time.time()
+ if event.type == KEYDOWN:
+ if event.key == K_RETURN:
+ # Get the lyrics from the input field
+ lyrics = input_field.get_text()
+ # Convert the lyrics to speech and play them
+ speak_lyrics(lyrics)
+ # If the click count is not reset, it means that the user has single-clicked
+ if click_count == 1:
+ # Get the position of the single click
+ pos = pygame.mouse.get_pos()
+ # Create a new MIDI note with the specified position and length
+ note = MidiNote(60, 64, 0, event.pos[0], event.pos[1], 100, 100)
+ note.x = pos[0]
+ note.y = pos[1]
+ note.width = 100
+ note.height = 100
+ # Add the MIDI note to the track
+ track.append(note)
+ mid.save("song.mid")
+ # Reset the click count
+ click_count = 0
+ lyrics = ""
+
+ input_field = pygame.font.Font(None, 32).render(lyrics, True, (0, 0, 0))
+
+ # Display the input field on the window
+ screen.blit(input_field, (10, 10))
+ pygame.display.flip()
+
+
diff --git a/training/song.mid b/training/song.mid
new file mode 100644
index 0000000000000000000000000000000000000000..2e11450ce1dede01b45d22381102e424236fc679
Binary files /dev/null and b/training/song.mid differ
diff --git a/training/task/SVC_task.py b/training/task/SVC_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..56c66751af4a279adeb76089b45d0ab93b29e6ad
--- /dev/null
+++ b/training/task/SVC_task.py
@@ -0,0 +1,223 @@
+import torch
+
+import utils
+from utils.hparams import hparams
+from network.diff.net import DiffNet
+from network.diff.diffusion import GaussianDiffusion, OfflineGaussianDiffusion
+from training.task.fs2 import FastSpeech2Task
+from network.vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
+from modules.fastspeech.tts_modules import mel2ph_to_dur
+
+from network.diff.candidate_decoder import FFT
+from utils.pitch_utils import denorm_f0
+from training.dataset.fs2_utils import FastSpeechDataset
+
+import numpy as np
+import os
+import torch.nn.functional as F
+
+DIFF_DECODERS = {
+ 'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
+ 'fft': lambda hp: FFT(
+ hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']),
+}
+
+
+class SVCDataset(FastSpeechDataset):
+ def collater(self, samples):
+ from preprocessing.process_pipeline import File2Batch
+ return File2Batch.processed_input2batch(samples)
+
+
+class SVCTask(FastSpeech2Task):
+ def __init__(self):
+ super(SVCTask, self).__init__()
+ self.dataset_cls = SVCDataset
+ self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
+
+ def build_tts_model(self):
+ # import torch
+ # from tqdm import tqdm
+ # v_min = torch.ones([80]) * 100
+ # v_max = torch.ones([80]) * -100
+ # for i, ds in enumerate(tqdm(self.dataset_cls('train'))):
+ # v_max = torch.max(torch.max(ds['mel'].reshape(-1, 80), 0)[0], v_max)
+ # v_min = torch.min(torch.min(ds['mel'].reshape(-1, 80), 0)[0], v_min)
+ # if i % 100 == 0:
+ # print(i, v_min, v_max)
+ # print('final', v_min, v_max)
+ mel_bins = hparams['audio_num_mel_bins']
+ self.model = GaussianDiffusion(
+ phone_encoder=self.phone_encoder,
+ out_dims=mel_bins, denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
+ timesteps=hparams['timesteps'],
+ K_step=hparams['K_step'],
+ loss_type=hparams['diff_loss_type'],
+ spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
+ )
+
+
+ def build_optimizer(self, model):
+ self.optimizer = optimizer = torch.optim.AdamW(
+ filter(lambda p: p.requires_grad, model.parameters()),
+ lr=hparams['lr'],
+ betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
+ weight_decay=hparams['weight_decay'])
+ return optimizer
+
+ def run_model(self, model, sample, return_output=False, infer=False):
+ '''
+ steps:
+ 1. run the full model, calc the main loss
+ 2. calculate loss for dur_predictor, pitch_predictor, energy_predictor
+ '''
+ hubert = sample['hubert'] # [B, T_t,H]
+ target = sample['mels'] # [B, T_s, 80]
+ mel2ph = sample['mel2ph'] # [B, T_s]
+ f0 = sample['f0']
+ uv = sample['uv']
+ energy = sample['energy']
+
+ spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+ if hparams['pitch_type'] == 'cwt':
+ # NOTE: this part of script is *isolated* from other scripts, which means
+ # it may not be compatible with the current version.
+ pass
+ # cwt_spec = sample[f'cwt_spec']
+ # f0_mean = sample['f0_mean']
+ # f0_std = sample['f0_std']
+ # sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph)
+
+ # output == ret
+ # model == src.diff.diffusion.GaussianDiffusion
+ output = model(hubert, mel2ph=mel2ph, spk_embed=spk_embed,
+ ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer)
+
+ losses = {}
+ if 'diff_loss' in output:
+ losses['mel'] = output['diff_loss']
+ #self.add_dur_loss(output['dur'], mel2ph, txt_tokens, sample['word_boundary'], losses=losses)
+ # if hparams['use_pitch_embed']:
+ # self.add_pitch_loss(output, sample, losses)
+ # if hparams['use_energy_embed']:
+ # self.add_energy_loss(output['energy_pred'], energy, losses)
+ if not return_output:
+ return losses
+ else:
+ return losses, output
+
+ def _training_step(self, sample, batch_idx, _):
+ log_outputs = self.run_model(self.model, sample)
+ total_loss = sum([v for v in log_outputs.values() if isinstance(v, torch.Tensor) and v.requires_grad])
+ log_outputs['batch_size'] = sample['hubert'].size()[0]
+ log_outputs['lr'] = self.scheduler.get_lr()[0]
+ return total_loss, log_outputs
+
+ def build_scheduler(self, optimizer):
+ return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5)
+
+ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx):
+ if optimizer is None:
+ return
+ optimizer.step()
+ optimizer.zero_grad()
+ if self.scheduler is not None:
+ self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])
+
+ def validation_step(self, sample, batch_idx):
+ outputs = {}
+ hubert = sample['hubert'] # [B, T_t]
+
+ target = sample['mels'] # [B, T_s, 80]
+ energy = sample['energy']
+ # fs2_mel = sample['fs2_mels']
+ spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+ mel2ph = sample['mel2ph']
+
+ outputs['losses'] = {}
+
+ outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False)
+
+ outputs['total_loss'] = sum(outputs['losses'].values())
+ outputs['nsamples'] = sample['nsamples']
+ outputs = utils.tensors_to_scalars(outputs)
+ if batch_idx < hparams['num_valid_plots']:
+ model_out = self.model(
+ hubert, spk_embed=spk_embed, mel2ph=mel2ph, f0=sample['f0'], uv=sample['uv'], energy=energy, ref_mels=None, infer=True
+ )
+
+ if hparams.get('pe_enable') is not None and hparams['pe_enable']:
+ gt_f0 = self.pe(sample['mels'])['f0_denorm_pred'] # pe predict from GT mel
+ pred_f0 = self.pe(model_out['mel_out'])['f0_denorm_pred'] # pe predict from Pred mel
+ else:
+ gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
+ pred_f0 = model_out.get('f0_denorm')
+ self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'], is_mel=True, gt_f0=gt_f0, f0=pred_f0)
+ self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'], name=f'diffmel_{batch_idx}')
+ #self.plot_mel(batch_idx, sample['mels'], model_out['fs2_mel'], name=f'fs2mel_{batch_idx}')
+ if hparams['use_pitch_embed']:
+ self.plot_pitch(batch_idx, sample, model_out)
+ return outputs
+
+ def add_dur_loss(self, dur_pred, mel2ph, txt_tokens, wdb, losses=None):
+ """
+ the effect of each loss component:
+ hparams['dur_loss'] : align each phoneme
+ hparams['lambda_word_dur']: align each word
+ hparams['lambda_sent_dur']: align each sentence
+
+ :param dur_pred: [B, T], float, log scale
+ :param mel2ph: [B, T]
+ :param txt_tokens: [B, T]
+ :param losses:
+ :return:
+ """
+ B, T = txt_tokens.shape
+ nonpadding = (txt_tokens != 0).float()
+ dur_gt = mel2ph_to_dur(mel2ph, T).float() * nonpadding
+ is_sil = torch.zeros_like(txt_tokens).bool()
+ for p in self.sil_ph:
+ is_sil = is_sil | (txt_tokens == self.phone_encoder.encode(p)[0])
+ is_sil = is_sil.float() # [B, T_txt]
+
+ # phone duration loss
+ if hparams['dur_loss'] == 'mse':
+ losses['pdur'] = F.mse_loss(dur_pred, (dur_gt + 1).log(), reduction='none')
+ losses['pdur'] = (losses['pdur'] * nonpadding).sum() / nonpadding.sum()
+ losses['pdur'] = losses['pdur'] * hparams['lambda_ph_dur']
+ dur_pred = (dur_pred.exp() - 1).clamp(min=0)
+ else:
+ raise NotImplementedError
+
+ # use linear scale for sent and word duration
+ if hparams['lambda_word_dur'] > 0:
+ #idx = F.pad(wdb.cumsum(axis=1), (1, 0))[:, :-1]
+ idx = wdb.cumsum(axis=1)
+ # word_dur_g = dur_gt.new_zeros([B, idx.max() + 1]).scatter_(1, idx, midi_dur) # midi_dur can be implied by add gt-ph_dur
+ word_dur_p = dur_pred.new_zeros([B, idx.max() + 1]).scatter_add(1, idx, dur_pred)
+ word_dur_g = dur_gt.new_zeros([B, idx.max() + 1]).scatter_add(1, idx, dur_gt)
+ wdur_loss = F.mse_loss((word_dur_p + 1).log(), (word_dur_g + 1).log(), reduction='none')
+ word_nonpadding = (word_dur_g > 0).float()
+ wdur_loss = (wdur_loss * word_nonpadding).sum() / word_nonpadding.sum()
+ losses['wdur'] = wdur_loss * hparams['lambda_word_dur']
+ if hparams['lambda_sent_dur'] > 0:
+ sent_dur_p = dur_pred.sum(-1)
+ sent_dur_g = dur_gt.sum(-1)
+ sdur_loss = F.mse_loss((sent_dur_p + 1).log(), (sent_dur_g + 1).log(), reduction='mean')
+ losses['sdur'] = sdur_loss.mean() * hparams['lambda_sent_dur']
+
+ ############
+ # validation plots
+ ############
+ def plot_wav(self, batch_idx, gt_wav, wav_out, is_mel=False, gt_f0=None, f0=None, name=None):
+ gt_wav = gt_wav[0].cpu().numpy()
+ wav_out = wav_out[0].cpu().numpy()
+ gt_f0 = gt_f0[0].cpu().numpy()
+ f0 = f0[0].cpu().numpy()
+ if is_mel:
+ gt_wav = self.vocoder.spec2wav(gt_wav, f0=gt_f0)
+ wav_out = self.vocoder.spec2wav(wav_out, f0=f0)
+ self.logger.experiment.add_audio(f'gt_{batch_idx}', gt_wav, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step)
+ self.logger.experiment.add_audio(f'wav_{batch_idx}', wav_out, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step)
+
+
diff --git a/training/task/__pycache__/SVC_task.cpython-38.pyc b/training/task/__pycache__/SVC_task.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6765b70454177192d8f4d3da60bf3bedc96b2ce9
Binary files /dev/null and b/training/task/__pycache__/SVC_task.cpython-38.pyc differ
diff --git a/training/task/__pycache__/base_task.cpython-38.pyc b/training/task/__pycache__/base_task.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..080fb85834185c5f92f7c71f7c94ac576bd272c9
Binary files /dev/null and b/training/task/__pycache__/base_task.cpython-38.pyc differ
diff --git a/training/task/__pycache__/fs2.cpython-38.pyc b/training/task/__pycache__/fs2.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f430f42481655afd814db339415b6962a6da1650
Binary files /dev/null and b/training/task/__pycache__/fs2.cpython-38.pyc differ
diff --git a/training/task/__pycache__/tts.cpython-38.pyc b/training/task/__pycache__/tts.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4afe7b2dd3854a00ca49544b163578c5d4f49781
Binary files /dev/null and b/training/task/__pycache__/tts.cpython-38.pyc differ
diff --git a/training/task/base_task.py b/training/task/base_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..369bd18da4aeceaf45da0f913be6bc3a1948aa4b
--- /dev/null
+++ b/training/task/base_task.py
@@ -0,0 +1,337 @@
+from datetime import datetime
+import shutil
+
+import matplotlib
+
+matplotlib.use('Agg')
+
+from utils.hparams import hparams, set_hparams
+import random
+import sys
+import numpy as np
+import torch.distributed as dist
+from pytorch_lightning.loggers import TensorBoardLogger
+from utils.pl_utils import LatestModelCheckpoint, BaseTrainer, data_loader, DDP
+from torch import nn
+import torch.utils.data
+import utils
+import logging
+import os
+
+torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system'))
+
+log_format = '%(asctime)s %(message)s'
+logging.basicConfig(stream=sys.stdout, level=logging.INFO,
+ format=log_format, datefmt='%m/%d %I:%M:%S %p')
+
+class BaseTask(nn.Module):
+ '''
+ Base class for training tasks.
+ 1. *load_ckpt*:
+ load checkpoint;
+ 2. *training_step*:
+ record and log the loss;
+ 3. *optimizer_step*:
+ run backwards step;
+ 4. *start*:
+ load training configs, backup code, log to tensorboard, start training;
+ 5. *configure_ddp* and *init_ddp_connection*:
+ start parallel training.
+
+ Subclasses should define:
+ 1. *build_model*, *build_optimizer*, *build_scheduler*:
+ how to build the model, the optimizer and the training scheduler;
+ 2. *_training_step*:
+ one training step of the model;
+ 3. *validation_end* and *_validation_end*:
+ postprocess the validation output.
+ '''
+ def __init__(self, *args, **kwargs):
+ # dataset configs
+ super(BaseTask, self).__init__(*args, **kwargs)
+ self.current_epoch = 0
+ self.global_step = 0
+ self.loaded_optimizer_states_dict = {}
+ self.trainer = None
+ self.logger = None
+ self.on_gpu = False
+ self.use_dp = False
+ self.use_ddp = False
+ self.example_input_array = None
+
+ self.max_tokens = hparams['max_tokens']
+ self.max_sentences = hparams['max_sentences']
+ self.max_eval_tokens = hparams['max_eval_tokens']
+ if self.max_eval_tokens == -1:
+ hparams['max_eval_tokens'] = self.max_eval_tokens = self.max_tokens
+ self.max_eval_sentences = hparams['max_eval_sentences']
+ if self.max_eval_sentences == -1:
+ hparams['max_eval_sentences'] = self.max_eval_sentences = self.max_sentences
+
+ self.model = None
+ self.training_losses_meter = None
+
+ ###########
+ # Training, validation and testing
+ ###########
+ def build_model(self):
+ raise NotImplementedError
+
+ def load_ckpt(self, ckpt_base_dir, current_model_name=None, model_name='model', force=True, strict=True):
+ # This function is updated on 2021.12.13
+ if current_model_name is None:
+ current_model_name = model_name
+ utils.load_ckpt(self.__getattr__(current_model_name), ckpt_base_dir, current_model_name, force, strict)
+
+ def on_epoch_start(self):
+ self.training_losses_meter = {'total_loss': utils.AvgrageMeter()}
+
+ def _training_step(self, sample, batch_idx, optimizer_idx):
+ """
+
+ :param sample:
+ :param batch_idx:
+ :return: total loss: torch.Tensor, loss_log: dict
+ """
+ raise NotImplementedError
+
+ def training_step(self, sample, batch_idx, optimizer_idx=-1):
+ loss_ret = self._training_step(sample, batch_idx, optimizer_idx)
+ self.opt_idx = optimizer_idx
+ if loss_ret is None:
+ return {'loss': None}
+ total_loss, log_outputs = loss_ret
+ log_outputs = utils.tensors_to_scalars(log_outputs)
+ for k, v in log_outputs.items():
+ if k not in self.training_losses_meter:
+ self.training_losses_meter[k] = utils.AvgrageMeter()
+ if not np.isnan(v):
+ self.training_losses_meter[k].update(v)
+ self.training_losses_meter['total_loss'].update(total_loss.item())
+
+ try:
+ log_outputs['lr'] = self.scheduler.get_lr()
+ if isinstance(log_outputs['lr'], list):
+ log_outputs['lr'] = log_outputs['lr'][0]
+ except:
+ pass
+
+ # log_outputs['all_loss'] = total_loss.item()
+ progress_bar_log = log_outputs
+ tb_log = {f'tr/{k}': v for k, v in log_outputs.items()}
+ return {
+ 'loss': total_loss,
+ 'progress_bar': progress_bar_log,
+ 'log': tb_log
+ }
+
+ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx):
+ optimizer.step()
+ optimizer.zero_grad()
+ if self.scheduler is not None:
+ self.scheduler.step(self.global_step // hparams['accumulate_grad_batches'])
+
+ def on_epoch_end(self):
+ loss_outputs = {k: round(v.avg, 4) for k, v in self.training_losses_meter.items()}
+ print(f"\n==============\n "
+ f"Epoch {self.current_epoch} ended. Steps: {self.global_step}. {loss_outputs}"
+ f"\n==============\n")
+
+ def validation_step(self, sample, batch_idx):
+ """
+
+ :param sample:
+ :param batch_idx:
+ :return: output: dict
+ """
+ raise NotImplementedError
+
+ def _validation_end(self, outputs):
+ """
+
+ :param outputs:
+ :return: loss_output: dict
+ """
+ raise NotImplementedError
+
+ def validation_end(self, outputs):
+ loss_output = self._validation_end(outputs)
+ print(f"\n==============\n "
+ f"valid results: {loss_output}"
+ f"\n==============\n")
+ return {
+ 'log': {f'val/{k}': v for k, v in loss_output.items()},
+ 'val_loss': loss_output['total_loss']
+ }
+
+ def build_scheduler(self, optimizer):
+ raise NotImplementedError
+
+ def build_optimizer(self, model):
+ raise NotImplementedError
+
+ def configure_optimizers(self):
+ optm = self.build_optimizer(self.model)
+ self.scheduler = self.build_scheduler(optm)
+ return [optm]
+
+ def test_start(self):
+ pass
+
+ def test_step(self, sample, batch_idx):
+ return self.validation_step(sample, batch_idx)
+
+ def test_end(self, outputs):
+ return self.validation_end(outputs)
+
+ ###########
+ # Running configuration
+ ###########
+
+ @classmethod
+ def start(cls):
+ set_hparams()
+ os.environ['MASTER_PORT'] = str(random.randint(15000, 30000))
+ random.seed(hparams['seed'])
+ np.random.seed(hparams['seed'])
+ task = cls()
+ work_dir = hparams['work_dir']
+ trainer = BaseTrainer(checkpoint_callback=LatestModelCheckpoint(
+ filepath=work_dir,
+ verbose=True,
+ monitor='val_loss',
+ mode='min',
+ num_ckpt_keep=hparams['num_ckpt_keep'],
+ save_best=hparams['save_best'],
+ period=1 if hparams['save_ckpt'] else 100000
+ ),
+ logger=TensorBoardLogger(
+ save_dir=work_dir,
+ name='lightning_logs',
+ version='lastest'
+ ),
+ gradient_clip_val=hparams['clip_grad_norm'],
+ val_check_interval=hparams['val_check_interval'],
+ row_log_interval=hparams['log_interval'],
+ max_updates=hparams['max_updates'],
+ num_sanity_val_steps=hparams['num_sanity_val_steps'] if not hparams[
+ 'validate'] else 10000,
+ accumulate_grad_batches=hparams['accumulate_grad_batches'])
+ if not hparams['infer']: # train
+ # copy_code = input(f'{hparams["save_codes"]} code backup? y/n: ') == 'y'
+ # copy_code = True # backup code every time
+ # if copy_code:
+ # t = datetime.now().strftime('%Y%m%d%H%M%S')
+ # code_dir = f'{work_dir}/codes/{t}'
+ # # TODO: test filesystem calls
+ # os.makedirs(code_dir, exist_ok=True)
+ # # subprocess.check_call(f'mkdir "{code_dir}"', shell=True)
+ # for c in hparams['save_codes']:
+ # shutil.copytree(c, code_dir, dirs_exist_ok=True)
+ # # subprocess.check_call(f'xcopy "{c}" "{code_dir}/" /s /e /y', shell=True)
+ # print(f"| Copied codes to {code_dir}.")
+ trainer.checkpoint_callback.task = task
+ trainer.fit(task)
+ else:
+ trainer.test(task)
+
+ def configure_ddp(self, model, device_ids):
+ model = DDP(
+ model,
+ device_ids=device_ids,
+ find_unused_parameters=True
+ )
+ if dist.get_rank() != 0 and not hparams['debug']:
+ sys.stdout = open(os.devnull, "w")
+ sys.stderr = open(os.devnull, "w")
+ random.seed(hparams['seed'])
+ np.random.seed(hparams['seed'])
+ return model
+
+ def training_end(self, *args, **kwargs):
+ return None
+
+ def init_ddp_connection(self, proc_rank, world_size):
+ set_hparams(print_hparams=False)
+ # guarantees unique ports across jobs from same grid search
+ default_port = 12910
+ # if user gave a port number, use that one instead
+ try:
+ default_port = os.environ['MASTER_PORT']
+ except Exception:
+ os.environ['MASTER_PORT'] = str(default_port)
+
+ # figure out the root node addr
+ root_node = '127.0.0.2'
+ root_node = self.trainer.resolve_root_node_address(root_node)
+ os.environ['MASTER_ADDR'] = root_node
+ dist.init_process_group('nccl', rank=proc_rank, world_size=world_size)
+
+ @data_loader
+ def train_dataloader(self):
+ return None
+
+ @data_loader
+ def test_dataloader(self):
+ return None
+
+ @data_loader
+ def val_dataloader(self):
+ return None
+
+ def on_load_checkpoint(self, checkpoint):
+ pass
+
+ def on_save_checkpoint(self, checkpoint):
+ pass
+
+ def on_sanity_check_start(self):
+ pass
+
+ def on_train_start(self):
+ pass
+
+ def on_train_end(self):
+ pass
+
+ def on_batch_start(self, batch):
+ pass
+
+ def on_batch_end(self):
+ pass
+
+ def on_pre_performance_check(self):
+ pass
+
+ def on_post_performance_check(self):
+ pass
+
+ def on_before_zero_grad(self, optimizer):
+ pass
+
+ def on_after_backward(self):
+ pass
+
+ def backward(self, loss, optimizer):
+ loss.backward()
+
+ def grad_norm(self, norm_type):
+ results = {}
+ total_norm = 0
+ for name, p in self.named_parameters():
+ if p.requires_grad:
+ try:
+ param_norm = p.grad.data.norm(norm_type)
+ total_norm += param_norm ** norm_type
+ norm = param_norm ** (1 / norm_type)
+
+ grad = round(norm.data.cpu().numpy().flatten()[0], 3)
+ results['grad_{}_norm_{}'.format(norm_type, name)] = grad
+ except Exception:
+ # this param had no grad
+ pass
+
+ total_norm = total_norm ** (1. / norm_type)
+ grad = round(total_norm.data.cpu().numpy().flatten()[0], 3)
+ results['grad_{}_norm_total'.format(norm_type)] = grad
+ return results
diff --git a/training/task/fs2.py b/training/task/fs2.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e1618297cf869721c283baffa2fd7bda1c89020
--- /dev/null
+++ b/training/task/fs2.py
@@ -0,0 +1,539 @@
+import matplotlib
+
+matplotlib.use('Agg')
+
+from utils import audio
+import matplotlib.pyplot as plt
+from preprocessing.data_gen_utils import get_pitch_parselmouth
+from training.dataset.fs2_utils import FastSpeechDataset
+from utils.cwt import cwt2f0
+from utils.pl_utils import data_loader
+import os
+from multiprocessing.pool import Pool
+from tqdm import tqdm
+from modules.fastspeech.tts_modules import mel2ph_to_dur
+from utils.hparams import hparams
+from utils.plot import spec_to_figure, dur_to_figure, f0_to_figure
+from utils.pitch_utils import denorm_f0
+from modules.fastspeech.fs2 import FastSpeech2
+from training.task.tts import TtsTask
+import torch
+import torch.optim
+import torch.utils.data
+import torch.nn.functional as F
+import utils
+import torch.distributions
+import numpy as np
+from modules.commons.ssim import ssim
+
+class FastSpeech2Task(TtsTask):
+ def __init__(self):
+ super(FastSpeech2Task, self).__init__()
+ self.dataset_cls = FastSpeechDataset
+ self.mse_loss_fn = torch.nn.MSELoss()
+ mel_losses = hparams['mel_loss'].split("|")
+ self.loss_and_lambda = {}
+ for i, l in enumerate(mel_losses):
+ if l == '':
+ continue
+ if ':' in l:
+ l, lbd = l.split(":")
+ lbd = float(lbd)
+ else:
+ lbd = 1.0
+ self.loss_and_lambda[l] = lbd
+ print("| Mel losses:", self.loss_and_lambda)
+ #self.sil_ph = self.phone_encoder.sil_phonemes()
+
+ @data_loader
+ def train_dataloader(self):
+ train_dataset = self.dataset_cls(hparams['train_set_name'], shuffle=True)
+ return self.build_dataloader(train_dataset, True, self.max_tokens, self.max_sentences,
+ endless=hparams['endless_ds'])
+
+ @data_loader
+ def val_dataloader(self):
+ valid_dataset = self.dataset_cls(hparams['valid_set_name'], shuffle=False)
+ return self.build_dataloader(valid_dataset, False, self.max_eval_tokens, self.max_eval_sentences)
+
+ @data_loader
+ def test_dataloader(self):
+ test_dataset = self.dataset_cls(hparams['test_set_name'], shuffle=False)
+ return self.build_dataloader(test_dataset, False, self.max_eval_tokens,
+ self.max_eval_sentences, batch_by_size=False)
+
+ def build_tts_model(self):
+ '''
+ rewrite
+ '''
+ return
+ # self.model = FastSpeech2(self.phone_encoder)
+
+ def build_model(self):
+ self.build_tts_model()
+ if hparams['load_ckpt'] != '':
+ self.load_ckpt(hparams['load_ckpt'], strict=True)
+ utils.print_arch(self.model)
+ return self.model
+
+ def _training_step(self, sample, batch_idx, _):
+ '''
+ Stub marked "rewrite" by the author: returns None immediately and never runs the model.
+ The commented-out lines after the return are kept as a reference body
+ (presumably for an overriding implementation - confirm).
+ '''
+ return
+ # loss_output = self.run_model(self.model, sample)
+ # total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad])
+ # loss_output['batch_size'] = sample['txt_tokens'].size()[0]
+ # return total_loss, loss_output
+
+ def validation_step(self, sample, batch_idx):
+ '''
+ Stub marked "rewrite" by the author: returns None immediately.
+ The commented-out lines after the return show the intended shape of the result
+ (a dict with 'losses', 'total_loss' and 'nsamples', which is what _validation_end reads).
+ '''
+ return
+ # outputs = {}
+ # outputs['losses'] = {}
+ # outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True)
+ # outputs['total_loss'] = sum(outputs['losses'].values())
+ # outputs['nsamples'] = sample['nsamples']
+ # mel_out = self.model.out2mel(model_out['mel_out'])
+ # outputs = utils.tensors_to_scalars(outputs)
+ # if batch_idx < hparams['num_valid_plots']:
+ # self.plot_mel(batch_idx, sample['mels'], mel_out)
+ # self.plot_dur(batch_idx, sample, model_out)
+ # if hparams['use_pitch_embed']:
+ # self.plot_pitch(batch_idx, sample, model_out)
+ # return outputs
+
+ def _validation_end(self, outputs):
+ all_losses_meter = {
+ 'total_loss': utils.AvgrageMeter(),
+ }
+ for output in outputs:
+ n = output['nsamples']
+ for k, v in output['losses'].items():
+ if k not in all_losses_meter:
+ all_losses_meter[k] = utils.AvgrageMeter()
+ all_losses_meter[k].update(v, n)
+ all_losses_meter['total_loss'].update(output['total_loss'], n)
+ return {k: round(v.avg, 4) for k, v in all_losses_meter.items()}
+
+ def run_model(self, model, sample, return_output=False):
+ '''
+ Stub marked "rewrite" by the author: the bare return below exits with None
+ before any of the remaining statements can run.
+ '''
+ return
+ # NOTE(review): everything from here to the end of the method is unreachable
+ # because of the return above; it reads as a reference implementation that
+ # gathers inputs from `sample`, calls `model`, and fills a `losses` dict.
+ txt_tokens = sample['txt_tokens'] # [B, T_t]
+ target = sample['mels'] # [B, T_s, 80]
+ mel2ph = sample['mel2ph'] # [B, T_s]
+ f0 = sample['f0']
+ uv = sample['uv']
+ energy = sample['energy']
+ spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+ if hparams['pitch_type'] == 'cwt':
+ cwt_spec = sample[f'cwt_spec']
+ f0_mean = sample['f0_mean']
+ f0_std = sample['f0_std']
+ sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph)
+
+ output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed,
+ ref_mels=target, f0=f0, uv=uv, energy=energy, infer=False)
+
+ losses = {}
+ self.add_mel_loss(output['mel_out'], target, losses)
+ self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses)
+ if hparams['use_pitch_embed']:
+ self.add_pitch_loss(output, sample, losses)
+ if hparams['use_energy_embed']:
+ self.add_energy_loss(output['energy_pred'], energy, losses)
+ if not return_output:
+ return losses
+ else:
+ return losses, output
+
+ ############
+ # losses
+ ############
+ def add_mel_loss(self, mel_out, target, losses, postfix='', mel_mix_loss=None):
+ if mel_mix_loss is None:
+ for loss_name, lbd in self.loss_and_lambda.items():
+ if 'l1' == loss_name:
+ l = self.l1_loss(mel_out, target)
+ elif 'mse' == loss_name:
+ raise NotImplementedError
+ elif 'ssim' == loss_name:
+ l = self.ssim_loss(mel_out, target)
+ elif 'gdl' == loss_name:
+ raise NotImplementedError
+ losses[f'{loss_name}{postfix}'] = l * lbd
+ else:
+ raise NotImplementedError
+
+ def l1_loss(self, decoder_output, target):
+ # decoder_output : B x T x n_mel
+ # target : B x T x n_mel
+ l1_loss = F.l1_loss(decoder_output, target, reduction='none')
+ weights = self.weights_nonzero_speech(target)
+ l1_loss = (l1_loss * weights).sum() / weights.sum()
+ return l1_loss
+
+ def ssim_loss(self, decoder_output, target, bias=6.0):
+ # decoder_output : B x T x n_mel
+ # target : B x T x n_mel
+ assert decoder_output.shape == target.shape
+ weights = self.weights_nonzero_speech(target)
+ decoder_output = decoder_output[:, None] + bias
+ target = target[:, None] + bias
+ ssim_loss = 1 - ssim(decoder_output, target, size_average=False)
+ ssim_loss = (ssim_loss * weights).sum() / weights.sum()
+ return ssim_loss
+
+ def add_dur_loss(self, dur_pred, mel2ph, txt_tokens, losses=None):
+ """
+
+ :param dur_pred: [B, T], float, log scale
+ :param mel2ph: [B, T]
+ :param txt_tokens: [B, T]
+ :param losses:
+ :return:
+ """
+ B, T = txt_tokens.shape
+ nonpadding = (txt_tokens != 0).float()
+ dur_gt = mel2ph_to_dur(mel2ph, T).float() * nonpadding
+ is_sil = torch.zeros_like(txt_tokens).bool()
+ for p in self.sil_ph:
+ is_sil = is_sil | (txt_tokens == self.phone_encoder.encode(p)[0])
+ is_sil = is_sil.float() # [B, T_txt]
+
+ # phone duration loss
+ if hparams['dur_loss'] == 'mse':
+ losses['pdur'] = F.mse_loss(dur_pred, (dur_gt + 1).log(), reduction='none')
+ losses['pdur'] = (losses['pdur'] * nonpadding).sum() / nonpadding.sum()
+ dur_pred = (dur_pred.exp() - 1).clamp(min=0)
+ elif hparams['dur_loss'] == 'mog':
+ return NotImplementedError
+ elif hparams['dur_loss'] == 'crf':
+ losses['pdur'] = -self.model.dur_predictor.crf(
+ dur_pred, dur_gt.long().clamp(min=0, max=31), mask=nonpadding > 0, reduction='mean')
+ losses['pdur'] = losses['pdur'] * hparams['lambda_ph_dur']
+
+ # use linear scale for sent and word duration
+ if hparams['lambda_word_dur'] > 0:
+ word_id = (is_sil.cumsum(-1) * (1 - is_sil)).long()
+ word_dur_p = dur_pred.new_zeros([B, word_id.max() + 1]).scatter_add(1, word_id, dur_pred)[:, 1:]
+ word_dur_g = dur_gt.new_zeros([B, word_id.max() + 1]).scatter_add(1, word_id, dur_gt)[:, 1:]
+ wdur_loss = F.mse_loss((word_dur_p + 1).log(), (word_dur_g + 1).log(), reduction='none')
+ word_nonpadding = (word_dur_g > 0).float()
+ wdur_loss = (wdur_loss * word_nonpadding).sum() / word_nonpadding.sum()
+ losses['wdur'] = wdur_loss * hparams['lambda_word_dur']
+ if hparams['lambda_sent_dur'] > 0:
+ sent_dur_p = dur_pred.sum(-1)
+ sent_dur_g = dur_gt.sum(-1)
+ sdur_loss = F.mse_loss((sent_dur_p + 1).log(), (sent_dur_g + 1).log(), reduction='mean')
+ losses['sdur'] = sdur_loss.mean() * hparams['lambda_sent_dur']
+
+ def add_pitch_loss(self, output, sample, losses):
+ # Adds pitch-related losses to `losses` in place; the branch taken depends on
+ # hparams['pitch_type'] ('ph', 'cwt' or 'frame'). Nothing is returned.
+ if hparams['pitch_type'] == 'ph':
+ # 'ph' case: masked l1/mse on output['pitch_pred'][:, :, 0] against sample['f0'],
+ # with the mask derived from non-zero txt_tokens; then return early.
+ nonpadding = (sample['txt_tokens'] != 0).float()
+ pitch_loss_fn = F.l1_loss if hparams['pitch_loss'] == 'l1' else F.mse_loss
+ losses['f0'] = (pitch_loss_fn(output['pitch_pred'][:, :, 0], sample['f0'],
+ reduction='none') * nonpadding).sum() \
+ / nonpadding.sum() * hparams['lambda_f0']
+ return
+ mel2ph = sample['mel2ph'] # [B, T_s]
+ f0 = sample['f0']
+ uv = sample['uv']
+ # positions where mel2ph is 0 are treated as padding
+ nonpadding = (mel2ph != 0).float()
+ if hparams['pitch_type'] == 'cwt':
+ cwt_spec = sample[f'cwt_spec']
+ f0_mean = sample['f0_mean']
+ f0_std = sample['f0_std']
+ # the first 10 channels of output['cwt'] are compared with cwt_spec
+ cwt_pred = output['cwt'][:, :, :10]
+ f0_mean_pred = output['f0_mean']
+ f0_std_pred = output['f0_std']
+ losses['C'] = self.cwt_loss(cwt_pred, cwt_spec) * hparams['lambda_f0']
+ if hparams['use_uv']:
+ # with use_uv an 11th channel is required and used as the uv logit
+ assert output['cwt'].shape[-1] == 11
+ uv_pred = output['cwt'][:, :, -1]
+ losses['uv'] = (F.binary_cross_entropy_with_logits(uv_pred, uv, reduction='none') * nonpadding) \
+ .sum() / nonpadding.sum() * hparams['lambda_uv']
+ losses['f0_mean'] = F.l1_loss(f0_mean_pred, f0_mean) * hparams['lambda_f0']
+ losses['f0_std'] = F.l1_loss(f0_std_pred, f0_std) * hparams['lambda_f0']
+ if hparams['cwt_add_f0_loss']:
+ f0_cwt_ = self.model.cwt2f0_norm(cwt_pred, f0_mean_pred, f0_std_pred, mel2ph)
+ self.add_f0_loss(f0_cwt_[:, :, None], f0, uv, losses, nonpadding=nonpadding)
+ elif hparams['pitch_type'] == 'frame':
+ self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding)
+
+ def add_f0_loss(self, p_pred, f0, uv, losses, nonpadding):
+ assert p_pred[..., 0].shape == f0.shape
+ if hparams['use_uv']:
+ assert p_pred[..., 1].shape == uv.shape
+ losses['uv'] = (F.binary_cross_entropy_with_logits(
+ p_pred[:, :, 1], uv, reduction='none') * nonpadding).sum() \
+ / nonpadding.sum() * hparams['lambda_uv']
+ nonpadding = nonpadding * (uv == 0).float()
+
+ f0_pred = p_pred[:, :, 0]
+ if hparams['pitch_loss'] in ['l1', 'l2']:
+ pitch_loss_fn = F.l1_loss if hparams['pitch_loss'] == 'l1' else F.mse_loss
+ losses['f0'] = (pitch_loss_fn(f0_pred, f0, reduction='none') * nonpadding).sum() \
+ / nonpadding.sum() * hparams['lambda_f0']
+ elif hparams['pitch_loss'] == 'ssim':
+ return NotImplementedError
+
+ def cwt_loss(self, cwt_p, cwt_g):
+ if hparams['cwt_loss'] == 'l1':
+ return F.l1_loss(cwt_p, cwt_g)
+ if hparams['cwt_loss'] == 'l2':
+ return F.mse_loss(cwt_p, cwt_g)
+ if hparams['cwt_loss'] == 'ssim':
+ return self.ssim_loss(cwt_p, cwt_g, 20)
+
+ def add_energy_loss(self, energy_pred, energy, losses):
+ nonpadding = (energy != 0).float()
+ loss = (F.mse_loss(energy_pred, energy, reduction='none') * nonpadding).sum() / nonpadding.sum()
+ loss = loss * hparams['lambda_energy']
+ losses['e'] = loss
+
+
+ ############
+ # validation plots
+ ############
+ def plot_mel(self, batch_idx, spec, spec_out, name=None):
+ spec_cat = torch.cat([spec, spec_out], -1)
+ name = f'mel_{batch_idx}' if name is None else name
+ vmin = hparams['mel_vmin']
+ vmax = hparams['mel_vmax']
+ self.logger.experiment.add_figure(name, spec_to_figure(spec_cat[0], vmin, vmax), self.global_step)
+
+ def plot_dur(self, batch_idx, sample, model_out):
+ T_txt = sample['txt_tokens'].shape[1]
+ dur_gt = mel2ph_to_dur(sample['mel2ph'], T_txt)[0]
+ dur_pred = self.model.dur_predictor.out2dur(model_out['dur']).float()
+ txt = self.phone_encoder.decode(sample['txt_tokens'][0].cpu().numpy())
+ txt = txt.split(" ")
+ self.logger.experiment.add_figure(
+ f'dur_{batch_idx}', dur_to_figure(dur_gt, dur_pred, txt), self.global_step)
+
+ def plot_pitch(self, batch_idx, sample, model_out):
+ # Logs f0 figures (and, for 'cwt', a cwt spectrogram figure) for the first item
+ # of the batch; the branch taken depends on hparams['pitch_type'].
+ f0 = sample['f0']
+ if hparams['pitch_type'] == 'ph':
+ # 'ph' case: both target and prediction go through expand_f0_ph with mel2ph
+ # before plotting; then return early.
+ mel2ph = sample['mel2ph']
+ f0 = self.expand_f0_ph(f0, mel2ph)
+ f0_pred = self.expand_f0_ph(model_out['pitch_pred'][:, :, 0], mel2ph)
+ self.logger.experiment.add_figure(
+ f'f0_{batch_idx}', f0_to_figure(f0[0], None, f0_pred[0]), self.global_step)
+ return
+ f0 = denorm_f0(f0, sample['uv'], hparams)
+ if hparams['pitch_type'] == 'cwt':
+ # cwt
+ cwt_out = model_out['cwt']
+ # first 10 channels are plotted next to the target cwt_spec
+ cwt_spec = cwt_out[:, :, :10]
+ cwt = torch.cat([cwt_spec, sample['cwt_spec']], -1)
+ self.logger.experiment.add_figure(f'cwt_{batch_idx}', spec_to_figure(cwt[0]), self.global_step)
+ # f0
+ f0_pred = cwt2f0(cwt_spec, model_out['f0_mean'], model_out['f0_std'], hparams['cwt_scales'])
+ if hparams['use_uv']:
+ # an 11th channel is required; where it is > 0 the predicted f0 is set to 0
+ assert cwt_out.shape[-1] == 11
+ uv_pred = cwt_out[:, :, -1] > 0
+ f0_pred[uv_pred > 0] = 0
+ # NOTE(review): sample['f0_cwt'] is only assigned in run_model's unreachable
+ # body within this file - confirm it is populated by the caller.
+ f0_cwt = denorm_f0(sample['f0_cwt'], sample['uv'], hparams)
+ self.logger.experiment.add_figure(
+ f'f0_{batch_idx}', f0_to_figure(f0[0], f0_cwt[0], f0_pred[0]), self.global_step)
+ elif hparams['pitch_type'] == 'frame':
+ # f0
+ #uv_pred = model_out['pitch_pred'][:, :, 0] > 0
+ pitch_pred = denorm_f0(model_out['pitch_pred'][:, :, 0], sample['uv'], hparams)
+ self.logger.experiment.add_figure(
+ f'f0_{batch_idx}', f0_to_figure(f0[0], None, pitch_pred[0]), self.global_step)
+
+ ############
+ # infer
+ ############
+ def test_step(self, sample, batch_idx):
+ spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+ hubert = sample['hubert']
+ mel2ph, uv, f0 = None, None, None
+ ref_mels = None
+ if hparams['profile_infer']:
+ pass
+ else:
+ # if hparams['use_gt_dur']:
+ mel2ph = sample['mel2ph']
+ #if hparams['use_gt_f0']:
+ f0 = sample['f0']
+ uv = sample['uv']
+ #print('Here using gt f0!!')
+ if hparams.get('use_midi') is not None and hparams['use_midi']:
+ outputs = self.model(
+ hubert, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, ref_mels=ref_mels, infer=True)
+ else:
+ outputs = self.model(
+ hubert, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, ref_mels=ref_mels, infer=True)
+ sample['outputs'] = self.model.out2mel(outputs['mel_out'])
+ sample['mel2ph_pred'] = outputs['mel2ph']
+ if hparams.get('pe_enable') is not None and hparams['pe_enable']:
+ sample['f0'] = self.pe(sample['mels'])['f0_denorm_pred'] # pe predict from GT mel
+ sample['f0_pred'] = self.pe(sample['outputs'])['f0_denorm_pred'] # pe predict from Pred mel
+ else:
+ sample['f0'] = denorm_f0(sample['f0'], sample['uv'], hparams)
+ sample['f0_pred'] = outputs.get('f0_denorm')
+ return self.after_infer(sample)
+
+ def after_infer(self, predictions):
+ # Post-processes one batch of predictions: strips padding, runs the vocoder and
+ # (unless hparams['profile_infer'] is set) queues save_result jobs on a worker pool.
+ # Always returns an empty dict.
+ if self.saving_result_pool is None and not hparams['profile_infer']:
+ # lazily create the pool; size is N_PROC (default: cpu count), capped at 16
+ self.saving_result_pool = Pool(min(int(os.getenv('N_PROC', os.cpu_count())), 16))
+ self.saving_results_futures = []
+ predictions = utils.unpack_dict_to_list(predictions)
+ t = tqdm(predictions)
+ for num_predictions, prediction in enumerate(t):
+ # move every tensor field to a numpy array on the CPU
+ for k, v in prediction.items():
+ if type(v) is torch.Tensor:
+ prediction[k] = v.cpu().numpy()
+
+ item_name = prediction.get('item_name')
+ #text = prediction.get('text').replace(":", "%3A")[:80]
+
+ # remove paddings
+ # frames whose absolute values sum to 0 are treated as padding
+ mel_gt = prediction["mels"]
+ mel_gt_mask = np.abs(mel_gt).sum(-1) > 0
+ mel_gt = mel_gt[mel_gt_mask]
+ mel2ph_gt = prediction.get("mel2ph")
+ mel2ph_gt = mel2ph_gt[mel_gt_mask] if mel2ph_gt is not None else None
+ mel_pred = prediction["outputs"]
+ mel_pred_mask = np.abs(mel_pred).sum(-1) > 0
+ mel_pred = mel_pred[mel_pred_mask]
+ mel_gt = np.clip(mel_gt, hparams['mel_vmin'], hparams['mel_vmax'])
+ mel_pred = np.clip(mel_pred, hparams['mel_vmin'], hparams['mel_vmax'])
+
+ mel2ph_pred = prediction.get("mel2ph_pred")
+ if mel2ph_pred is not None:
+ if len(mel2ph_pred) > len(mel_pred_mask):
+ mel2ph_pred = mel2ph_pred[:len(mel_pred_mask)]
+ mel2ph_pred = mel2ph_pred[mel_pred_mask]
+
+ f0_gt = prediction.get("f0")
+ # NOTE(review): the predicted f0 is deliberately replaced by the ground-truth f0
+ # here (the original lookup is commented out on the same line).
+ f0_pred = f0_gt#prediction.get("f0_pred")
+ if f0_pred is not None:
+ f0_gt = f0_gt[mel_gt_mask]
+ if len(f0_pred) > len(mel_pred_mask):
+ f0_pred = f0_pred[:len(mel_pred_mask)]
+ f0_pred = f0_pred[mel_pred_mask]
+ # text and phoneme strings are not available in this variant; both stay None
+ text=None
+ str_phs = None
+ # if self.phone_encoder is not None and 'txt_tokens' in prediction:
+ # str_phs = self.phone_encoder.decode(prediction['txt_tokens'], strip_padding=True)
+ # def resize2d(source, target_len):
+ # source[source<0.001] = np.nan
+ # target = np.interp(np.linspace(0, len(source)-1, num=target_len,endpoint=True), np.arange(0, len(source)), source)
+ # return np.nan_to_num(target)
+ # def resize3d(source, target_len):
+ # newsource=[]
+ # for i in range(source.shape[1]):
+ # newsource.append(resize2d(source[:,i],target_len))
+ # return np.array(newsource).transpose()
+ # print(mel_pred.shape)
+ # print(f0_pred.shape)
+ # mel_pred=resize3d(mel_pred,int(mel_pred.shape[0]/44100*24000))
+ # f0_pred=resize2d(f0_pred,int(f0_pred.shape[0]/44100*24000))
+ # print(mel_pred.shape)
+ # print(f0_pred.shape)
+ gen_dir = os.path.join(hparams['work_dir'],
+ f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')
+ wav_pred = self.vocoder.spec2wav(mel_pred, f0=f0_pred)
+ if not hparams['profile_infer']:
+ os.makedirs(gen_dir, exist_ok=True)
+ os.makedirs(f'{gen_dir}/wavs', exist_ok=True)
+ os.makedirs(f'{gen_dir}/plot', exist_ok=True)
+ os.makedirs(os.path.join(hparams['work_dir'], 'P_mels_npy'), exist_ok=True)
+ os.makedirs(os.path.join(hparams['work_dir'], 'G_mels_npy'), exist_ok=True)
+ # queue saving of the predicted ('P') result on the worker pool
+ self.saving_results_futures.append(
+ self.saving_result_pool.apply_async(self.save_result, args=[
+ wav_pred, mel_pred, 'P', item_name, text, gen_dir, str_phs, mel2ph_pred, f0_gt, f0_pred]))
+
+ if mel_gt is not None and hparams['save_gt']:
+ # also vocode and queue the ground-truth ('G') result
+ wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0_gt)
+ self.saving_results_futures.append(
+ self.saving_result_pool.apply_async(self.save_result, args=[
+ wav_gt, mel_gt, 'G', item_name, text, gen_dir, str_phs, mel2ph_gt, f0_gt, f0_pred]))
+ if hparams['save_f0']:
+ # NOTE(review): this plot reads wav_gt, which is only assigned in the
+ # save_gt branch above - confirm the nesting in the unmangled file.
+ import matplotlib.pyplot as plt
+ # f0_pred_, _ = get_pitch(wav_pred, mel_pred, hparams)
+ f0_pred_ = f0_pred
+ f0_gt_, _ = get_pitch_parselmouth(wav_gt, mel_gt, hparams)
+ fig = plt.figure()
+ plt.plot(f0_pred_, label=r'$f0_P$')
+ plt.plot(f0_gt_, label=r'$f0_G$')
+ if hparams.get('pe_enable') is not None and hparams['pe_enable']:
+ # f0_midi = prediction.get("f0_midi")
+ # f0_midi = f0_midi[mel_gt_mask]
+ # plt.plot(f0_midi, label=r'$f0_M$')
+ pass
+ plt.legend()
+ plt.tight_layout()
+ plt.savefig(f'{gen_dir}/plot/[F0][{item_name}]{text}.png', format='png')
+ plt.close(fig)
+
+ t.set_description(
+ f"Pred_shape: {mel_pred.shape}, gt_shape: {mel_gt.shape}")
+ else:
+ # profiling mode: nothing is saved, only the generated audio duration is accumulated
+ if 'gen_wav_time' not in self.stats:
+ self.stats['gen_wav_time'] = 0
+ self.stats['gen_wav_time'] += len(wav_pred) / hparams['audio_sample_rate']
+ print('gen_wav_time: ', self.stats['gen_wav_time'])
+
+ return {}
+
+ @staticmethod
+ def save_result(wav_out, mel, prefix, item_name, text, gen_dir, str_phs=None, mel2ph=None, gt_f0=None, pred_f0=None):
+ # Writes three artefacts for one item: the mel as .npy under
+ # <work_dir>/<prefix>_mels_npy, the waveform under <gen_dir>/wavs, and a
+ # spectrogram plot (with f0 overlay) under <gen_dir>/plot.
+ # '/' in the item name is replaced so it can be used as a file name
+ item_name = item_name.replace('/', '-')
+ base_fn = f'[{item_name}][{prefix}]'
+
+ if text is not None:
+ base_fn += text
+ base_fn += ('-' + hparams['exp_name'])
+ np.save(os.path.join(hparams['work_dir'], f'{prefix}_mels_npy', item_name), mel)
+ # NOTE(review): the sample rate is hard-coded to 24000; the hparams lookup is commented out
+ audio.save_wav(wav_out, f'{gen_dir}/wavs/{base_fn}.wav', 24000,#hparams['audio_sample_rate'],
+ norm=hparams['out_wav_norm'])
+ fig = plt.figure(figsize=(14, 10))
+ spec_vmin = hparams['mel_vmin']
+ spec_vmax = hparams['mel_vmax']
+ heatmap = plt.pcolor(mel.T, vmin=spec_vmin, vmax=spec_vmax)
+ fig.colorbar(heatmap)
+ if hparams.get('pe_enable') is not None and hparams['pe_enable']:
+ # rescale f0 linearly from the 100..800 range onto 0..80 and zero out
+ # non-positive entries (presumably to overlay on the mel-bin axis - confirm)
+ gt_f0 = (gt_f0 - 100) / (800 - 100) * 80 * (gt_f0 > 0)
+ pred_f0 = (pred_f0 - 100) / (800 - 100) * 80 * (pred_f0 > 0)
+ plt.plot(pred_f0, c='white', linewidth=1, alpha=0.6)
+ plt.plot(gt_f0, c='red', linewidth=1, alpha=0.6)
+ else:
+ # without the pitch extractor, f0 is re-estimated from the waveform
+ f0, _ = get_pitch_parselmouth(wav_out, mel, hparams)
+ f0 = (f0 - 100) / (800 - 100) * 80 * (f0 > 0)
+ plt.plot(f0, c='white', linewidth=1, alpha=0.6)
+ if mel2ph is not None and str_phs is not None:
+ # annotate token boundaries: cumulative durations give the x positions
+ decoded_txt = str_phs.split(" ")
+ dur = mel2ph_to_dur(torch.LongTensor(mel2ph)[None, :], len(decoded_txt))[0].numpy()
+ dur = [0] + list(np.cumsum(dur))
+ for i in range(len(dur) - 1):
+ # labels cycle through 20 vertical offsets so neighbours do not overlap
+ shift = (i % 20) + 1
+ plt.text(dur[i], shift, decoded_txt[i])
+ plt.hlines(shift, dur[i], dur[i + 1], colors='b' if decoded_txt[i] != '|' else 'black')
+ plt.vlines(dur[i], 0, 5, colors='b' if decoded_txt[i] != '|' else 'black',
+ alpha=1, linewidth=1)
+ plt.tight_layout()
+ plt.savefig(f'{gen_dir}/plot/{base_fn}.png', format='png', dpi=1000)
+ plt.close(fig)
+
+ ##############
+ # utils
+ ##############
+ @staticmethod
+ def expand_f0_ph(f0, mel2ph):
+ f0 = denorm_f0(f0, None, hparams)
+ f0 = F.pad(f0, [1, 0])
+ f0 = torch.gather(f0, 1, mel2ph) # [B, T_mel]
+ return f0
+
+
+# Script entry point: when run directly, call the task's start() classmethod.
+if __name__ == '__main__':
+ FastSpeech2Task.start()
diff --git a/training/task/tts.py b/training/task/tts.py
new file mode 100644
index 0000000000000000000000000000000000000000..acf360a510a8cfc290da8f08269eda051cc77d0f
--- /dev/null
+++ b/training/task/tts.py
@@ -0,0 +1,131 @@
+from multiprocessing.pool import Pool
+
+import matplotlib
+
+from utils.pl_utils import data_loader
+from utils.training_utils import RSQRTSchedule
+from network.vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
+from modules.fastspeech.pe import PitchExtractor
+
+matplotlib.use('Agg')
+import os
+import numpy as np
+from tqdm import tqdm
+import torch.distributed as dist
+
+from training.task.base_task import BaseTask
+from utils.hparams import hparams
+from utils.text_encoder import TokenTextEncoder
+import json
+from preprocessing.hubertinfer import Hubertencoder
+import torch
+import torch.optim
+import torch.utils.data
+import utils
+
+
+
+class TtsTask(BaseTask):
+ def __init__(self, *args, **kwargs):
+ self.vocoder = None
+ self.phone_encoder = Hubertencoder(hparams['hubert_path'])
+ # self.padding_idx = self.phone_encoder.pad()
+ # self.eos_idx = self.phone_encoder.eos()
+ # self.seg_idx = self.phone_encoder.seg()
+ self.saving_result_pool = None
+ self.saving_results_futures = None
+ self.stats = {}
+ super().__init__(*args, **kwargs)
+
+ def build_scheduler(self, optimizer):
+ return RSQRTSchedule(optimizer)
+
+ def build_optimizer(self, model):
+ self.optimizer = optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=hparams['lr'])
+ return optimizer
+
+ def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None,
+ required_batch_size_multiple=-1, endless=False, batch_by_size=True):
+ devices_cnt = torch.cuda.device_count()
+ if devices_cnt == 0:
+ devices_cnt = 1
+ if required_batch_size_multiple == -1:
+ required_batch_size_multiple = devices_cnt
+
+ def shuffle_batches(batches):
+ np.random.shuffle(batches)
+ return batches
+
+ if max_tokens is not None:
+ max_tokens *= devices_cnt
+ if max_sentences is not None:
+ max_sentences *= devices_cnt
+ indices = dataset.ordered_indices()
+ if batch_by_size:
+ batch_sampler = utils.batch_by_size(
+ indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
+ required_batch_size_multiple=required_batch_size_multiple,
+ )
+ else:
+ batch_sampler = []
+ for i in range(0, len(indices), max_sentences):
+ batch_sampler.append(indices[i:i + max_sentences])
+
+ if shuffle:
+ batches = shuffle_batches(list(batch_sampler))
+ if endless:
+ batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))]
+ else:
+ batches = batch_sampler
+ if endless:
+ batches = [b for _ in range(1000) for b in batches]
+ num_workers = dataset.num_workers
+ if self.trainer.use_ddp:
+ num_replicas = dist.get_world_size()
+ rank = dist.get_rank()
+ batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0]
+ return torch.utils.data.DataLoader(dataset,
+ collate_fn=dataset.collater,
+ batch_sampler=batches,
+ num_workers=num_workers,
+ pin_memory=False)
+
+ # def build_phone_encoder(self, data_dir):
+ # phone_list_file = os.path.join(data_dir, 'phone_set.json')
+
+ # phone_list = json.load(open(phone_list_file, encoding='utf-8'))
+ # return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
+
+ def build_optimizer(self, model):
+ self.optimizer = optimizer = torch.optim.AdamW(
+ model.parameters(),
+ lr=hparams['lr'])
+ return optimizer
+
+ def test_start(self):
+ self.saving_result_pool = Pool(8)
+ self.saving_results_futures = []
+ self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
+ if hparams.get('pe_enable') is not None and hparams['pe_enable']:
+ self.pe = PitchExtractor().cuda()
+ utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
+ self.pe.eval()
+ def test_end(self, outputs):
+ self.saving_result_pool.close()
+ [f.get() for f in tqdm(self.saving_results_futures)]
+ self.saving_result_pool.join()
+ return {}
+
+ ##########
+ # utils
+ ##########
+ def weights_nonzero_speech(self, target):
+ # target : B x T x mel
+ # Assign weight 1.0 to all labels except for padding (id=0).
+ dim = target.size(-1)
+ return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
+
+# Script entry point: when run directly, call the task's start() classmethod.
+if __name__ == '__main__':
+ TtsTask.start()
diff --git a/training/train_pipeline.py b/training/train_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7f9f99f64ed0dcff0001034a6932a2c623ce706
--- /dev/null
+++ b/training/train_pipeline.py
@@ -0,0 +1,238 @@
+from utils.hparams import hparams
+import torch
+from torch.nn import functional as F
+from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0
+
+class Batch2Loss:
+ '''
+ pipeline: batch -> insert1 -> module1 -> insert2 -> module2 -> insert3 -> module3 -> insert4 -> module4 -> loss
+ '''
+
+ @staticmethod
+ def insert1(pitch_midi, midi_dur, is_slur, # variables
+ midi_embed, midi_dur_layer, is_slur_embed): # modules
+ '''
+ add embeddings for midi, midi_dur, slur
+ '''
+ midi_embedding = midi_embed(pitch_midi)
+ midi_dur_embedding, slur_embedding = 0, 0
+ if midi_dur is not None:
+ midi_dur_embedding = midi_dur_layer(midi_dur[:, :, None]) # [B, T, 1] -> [B, T, H]
+ if is_slur is not None:
+ slur_embedding = is_slur_embed(is_slur)
+ return midi_embedding, midi_dur_embedding, slur_embedding
+
+ @staticmethod
+ def module1(fs2_encoder, # modules
+ txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding): # variables
+ '''
+ get *encoder_out* == fs2_encoder(*txt_tokens*, some embeddings)
+ '''
+ encoder_out = fs2_encoder(txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding)
+ return encoder_out
+
+ @staticmethod
+ def insert2(encoder_out, spk_embed_id, spk_embed_dur_id, spk_embed_f0_id, src_nonpadding, # variables
+ spk_embed_proj): # modules
+ '''
+ 1. add embeddings for pspk, spk_dur, sk_f0
+ 2. get *dur_inp* ~= *encoder_out* + *spk_embed_dur*
+ '''
+ # add ref style embed
+ # Not implemented
+ # variance encoder
+ var_embed = 0
+
+ # encoder_out_dur denotes encoder outputs for duration predictor
+ # in speech adaptation, duration predictor use old speaker embedding
+ if hparams['use_spk_embed']:
+ spk_embed_dur = spk_embed_f0 = spk_embed = spk_embed_proj(spk_embed_id)[:, None, :]
+ elif hparams['use_spk_id']:
+ if spk_embed_dur_id is None:
+ spk_embed_dur_id = spk_embed_id
+ if spk_embed_f0_id is None:
+ spk_embed_f0_id = spk_embed_id
+ spk_embed = spk_embed_proj(spk_embed_id)[:, None, :]
+ spk_embed_dur = spk_embed_f0 = spk_embed
+ if hparams['use_split_spk_id']:
+ spk_embed_dur = spk_embed_dur(spk_embed_dur_id)[:, None, :]
+ spk_embed_f0 = spk_embed_f0(spk_embed_f0_id)[:, None, :]
+ else:
+ spk_embed_dur = spk_embed_f0 = spk_embed = 0
+
+ # add dur
+ dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
+ return var_embed, spk_embed, spk_embed_dur, spk_embed_f0, dur_inp
+
+ @staticmethod
+ def module2(dur_predictor, length_regulator, # modules
+ dur_input, mel2ph, txt_tokens, all_vowel_tokens, ret, midi_dur=None): # variables
+ '''
+ 1. get *dur* ~= dur_predictor(*dur_inp*)
+ 2. (mel2ph is None): get *mel2ph* ~= length_regulator(*dur*)
+
+ Writes ret['dur'] and ret['mel2ph'] (plus ret['dur_choice'] when mel2ph is None)
+ and returns mel2ph.
+ '''
+ src_padding = (txt_tokens == 0)
+ # scale the gradient flowing back into dur_input by hparams['predictor_grad']
+ dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach())
+
+ if mel2ph is None:
+ dur, xs = dur_predictor.inference(dur_input, src_padding)
+ ret['dur'] = xs
+ dur = xs.squeeze(-1).exp() - 1.0
+ # NOTE(review): midi_dur defaults to None but is indexed below whenever a token
+ # is in all_vowel_tokens - callers on this path must pass it.
+ for i in range(len(dur)):
+ for j in range(len(dur[i])):
+ if txt_tokens[i,j] in all_vowel_tokens:
+ if j < len(dur[i])-1 and txt_tokens[i,j+1] not in all_vowel_tokens:
+ # vowel followed by a non-vowel: the pair shares midi_dur[i,j]; if the
+ # predicted non-vowel duration exceeds it, the vowel gets 0 and the
+ # non-vowel takes the whole midi_dur[i,j]
+ dur[i,j] = midi_dur[i,j] - dur[i,j+1]
+ if dur[i,j] < 0:
+ dur[i,j] = 0
+ dur[i,j+1] = midi_dur[i,j]
+ else:
+ dur[i,j]=midi_dur[i,j]
+ dur[:,0] = dur[:,0] + 0.5
+ # round the cumulative sums, then difference them to get integer durations
+ dur_acc = F.pad(torch.round(torch.cumsum(dur, axis=1)), (1,0))
+ dur = torch.clamp(dur_acc[:,1:]-dur_acc[:,:-1], min=0).long()
+ ret['dur_choice'] = dur
+ mel2ph = length_regulator(dur, src_padding).detach()
+ else:
+ ret['dur'] = dur_predictor(dur_input, src_padding)
+ ret['mel2ph'] = mel2ph
+
+ return mel2ph
+
+ @staticmethod
+ def insert3(encoder_out, mel2ph, var_embed, spk_embed_f0, src_nonpadding, tgt_nonpadding): # variables
+ '''
+ 1. get *decoder_inp* ~= gather *encoder_out* according to *mel2ph*
+ 2. get *pitch_inp* ~= *decoder_inp* + *spk_embed_f0*
+ 3. get *pitch_inp_ph* ~= *encoder_out* + *spk_embed_f0*
+ '''
+ decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
+ mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
+ decoder_inp = decoder_inp_origin = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
+
+ pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
+ pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
+ return decoder_inp, pitch_inp, pitch_inp_ph
+
+ @staticmethod
+ def module3(pitch_predictor, pitch_embed, energy_predictor, energy_embed, # modules
+ pitch_inp, pitch_inp_ph, f0, uv, energy, mel2ph, is_training, ret): # variables
+ '''
+ 1. get *ret['pitch_pred']*, *ret['energy_pred']* ~= pitch_predictor(*pitch_inp*), energy_predictor(*pitch_inp*)
+ 2. get *pitch_embedding* ~= pitch_embed(f0_to_coarse(denorm_f0(*f0* or *pitch_pred*))
+ 3. get *energy_embedding* ~= energy_embed(energy_to_coarse(*energy* or *energy_pred*))
+
+ Returns (pitch_embedding, energy_embedding); each is the integer 0 when the
+ corresponding hparams switch ('use_pitch_embed' / 'use_energy_embed') is off.
+ '''
+ def add_pitch(decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
+ # Predicts pitch (stored in ret['pitch_pred']), de-normalises f0 into
+ # ret['f0_denorm'] and returns the embedding of the coarse pitch.
+ if hparams['pitch_type'] == 'ph':
+ # phoneme-level pitch: predict from encoder_out, then expand to frames via mel2ph
+ pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach())
+ # NOTE(review): .sum() has no dim argument, so this is a single scalar
+ # boolean, not a per-position mask - confirm that is intended.
+ pitch_padding = (encoder_out.sum().abs() == 0)
+ ret['pitch_pred'] = pitch_pred = pitch_predictor(pitch_pred_inp)
+ if f0 is None:
+ f0 = pitch_pred[:, :, 0]
+ ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding)
+ pitch = f0_to_coarse(f0_denorm) # start from 0 [B, T_txt]
+ pitch = F.pad(pitch, [1, 0])
+ pitch = torch.gather(pitch, 1, mel2ph) # [B, T_mel]
+ pitch_embedding = pitch_embed(pitch)
+ return pitch_embedding
+
+ # scale the gradient flowing back into decoder_inp by hparams['predictor_grad']
+ decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
+
+ pitch_padding = (mel2ph == 0)
+
+ if hparams['pitch_type'] == 'cwt':
+ # NOTE: this part of script is *isolated* from other scripts, which means
+ # it may not be compatible with the current version.
+ pass
+ # pitch_padding = None
+ # ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp)
+ # stats_out = self.cwt_stats_layers(encoder_out[:, 0, :]) # [B, 2]
+ # mean = ret['f0_mean'] = stats_out[:, 0]
+ # std = ret['f0_std'] = stats_out[:, 1]
+ # cwt_spec = cwt_out[:, :, :10]
+ # if f0 is None:
+ # std = std * hparams['cwt_std_scale']
+ # f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
+ # if hparams['use_uv']:
+ # assert cwt_out.shape[-1] == 11
+ # uv = cwt_out[:, :, -1] > 0
+ elif hparams['pitch_ar']:
+ # the predictor receives the target f0 only while training
+ ret['pitch_pred'] = pitch_pred = pitch_predictor(decoder_inp, f0 if is_training else None)
+ if f0 is None:
+ f0 = pitch_pred[:, :, 0]
+ else:
+ ret['pitch_pred'] = pitch_pred = pitch_predictor(decoder_inp)
+ if f0 is None:
+ f0 = pitch_pred[:, :, 0]
+ if hparams['use_uv'] and uv is None:
+ uv = pitch_pred[:, :, 1] > 0
+ ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
+ if pitch_padding is not None:
+ # in-place write: this also modifies the f0 tensor the caller passed in
+ f0[pitch_padding] = 0
+
+ pitch = f0_to_coarse(f0_denorm) # start from 0
+ pitch_embedding = pitch_embed(pitch)
+ return pitch_embedding
+
+ def add_energy(decoder_inp, energy, ret):
+ # Predicts energy (stored in ret['energy_pred']) and returns the embedding of
+ # the quantised energy (target if given, otherwise the prediction).
+ decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
+ ret['energy_pred'] = energy_pred = energy_predictor(decoder_inp)[:, :, 0]
+ if energy is None:
+ energy = energy_pred
+ energy = torch.clamp(energy * 256 // 4, max=255).long() # energy_to_coarse
+ energy_embedding = energy_embed(energy)
+ return energy_embedding
+
+ # add pitch and energy embed
+ nframes = mel2ph.size(1)
+
+ pitch_embedding = 0
+ if hparams['use_pitch_embed']:
+ # f0 / uv are padded with their last value, or truncated, to exactly nframes
+ if f0 is not None:
+ delta_l = nframes - f0.size(1)
+ if delta_l > 0:
+ f0 = torch.cat((f0,torch.FloatTensor([[x[-1]] * delta_l for x in f0]).to(f0.device)),1)
+ f0 = f0[:,:nframes]
+ if uv is not None:
+ delta_l = nframes - uv.size(1)
+ if delta_l > 0:
+ uv = torch.cat((uv,torch.FloatTensor([[x[-1]] * delta_l for x in uv]).to(uv.device)),1)
+ uv = uv[:,:nframes]
+ pitch_embedding = add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
+
+ energy_embedding = 0
+ if hparams['use_energy_embed']:
+ # energy gets the same pad-or-truncate treatment as f0 / uv
+ if energy is not None:
+ delta_l = nframes - energy.size(1)
+ if delta_l > 0:
+ energy = torch.cat((energy,torch.FloatTensor([[x[-1]] * delta_l for x in energy]).to(energy.device)),1)
+ energy = energy[:,:nframes]
+ energy_embedding = add_energy(pitch_inp, energy, ret)
+
+ return pitch_embedding, energy_embedding
+
+ @staticmethod
+ def insert4(decoder_inp, pitch_embedding, energy_embedding, spk_embed, ret, tgt_nonpadding):
+ '''
+ *decoder_inp* ~= *decoder_inp* + embeddings for spk, pitch, energy
+ '''
+ ret['decoder_inp'] = decoder_inp = (decoder_inp + pitch_embedding + energy_embedding + spk_embed) * tgt_nonpadding
+ return decoder_inp
+
+ @staticmethod
+ def module4(diff_main_loss, # modules
+ norm_spec, decoder_inp_t, ret, K_step, batch_size, device): # variables
+ '''
+ training diffusion using spec as input and decoder_inp as condition.
+
+ Args:
+ norm_spec: (normalized) spec
+ decoder_inp_t: (transposed) decoder_inp
+ Returns:
+ ret['diff_loss']
+ '''
+ t = torch.randint(0, K_step, (batch_size,), device=device).long()
+ norm_spec = norm_spec.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
+ ret['diff_loss'] = diff_main_loss(norm_spec, t, cond=decoder_inp_t)
+ # nonpadding = (mel2ph != 0).float()
+ # ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding)
diff --git a/trans_key.py b/trans_key.py
new file mode 100644
index 0000000000000000000000000000000000000000..c803a6acdbaa065cb75ce0a935b023780ab37026
--- /dev/null
+++ b/trans_key.py
@@ -0,0 +1,61 @@
+# The twelve note names of one octave in ascending semitone order (sharps only);
+# move_key() looks notes up in this list by index.
+head_list = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]
+
+
+def trans_f0_seq(feature_pit, transform):
+    """Multiply ``feature_pit`` by ``2 ** (transform / 12)`` and round to one decimal place."""
+    ratio = 2 ** (transform / 12)
+    shifted = feature_pit * ratio
+    return round(shifted, 1)
+
+
+def move_key(raw_data, mv_key):
+    """
+    Shift a note name such as ``"C#4"`` by ``mv_key`` semitones.
+
+    The last character of ``raw_data`` is read as the octave number and the
+    rest as a name from ``head_list``. Moving past either end of the list
+    wraps around and adjusts the octave by one per wrap.
+    """
+    name = raw_data[:-1]
+    octave = int(raw_data[-1])
+    # divmod floors toward -inf, which matches adding/subtracting 12 until the
+    # index falls inside 0..11.
+    octave_delta, position = divmod(head_list.index(name) + mv_key, 12)
+    return head_list[position] + str(octave + octave_delta)
+
+
+def trans_key(raw_data, key):
+    """
+    Transpose every entry of ``raw_data`` by ``key`` semitones, in place.
+
+    For each entry, the space-separated ``"note_seq"`` has every token other
+    than ``"rest"`` passed through ``move_key``, and every number in the
+    space-separated ``"f0_seq"`` is passed through ``trans_f0_seq``. The same
+    ``raw_data`` object is returned.
+    """
+    for entry in raw_data:
+        entry["note_seq"] = " ".join(
+            [note if note == "rest" else move_key(note, key)
+             for note in entry["note_seq"].split(" ")]
+        )
+
+        # Parse the whole sequence to floats before transposing any value.
+        freqs = [float(token) for token in entry["f0_seq"].split(" ")]
+        entry["f0_seq"] = " ".join([str(trans_f0_seq(freq, key)) for freq in freqs])
+    return raw_data
+
+
+# Script section: read "result.txt", transpose the note names found in the
+# fourth "|"-separated field of every line by `key` semitones, and write the
+# lines to "raw.txt".
+key = -6
+# Both files are managed by a single `with`, so the output handle is closed
+# even when reading or transposing raises. The open order (output first, then
+# input) is unchanged.
+with open("raw.txt", "w", encoding='utf-8') as f_w, \
+        open("result.txt", "r", encoding='utf-8') as f:
+    raw_data = f.readlines()
+    for raw in raw_data:
+        raw_list = raw.split("|")
+        new_note_seq_list = []
+        for note_seq in raw_list[3].split(" "):
+            if note_seq != "rest":
+                # Keep only the part before the first "/" (no-op when absent).
+                note_seq = note_seq.split("/")[0]
+                new_note_seq = move_key(note_seq, key)
+                new_note_seq_list.append(new_note_seq)
+            else:
+                new_note_seq_list.append(note_seq)
+        raw_list[3] = " ".join(new_note_seq_list)
+        f_w.write("|".join(raw_list))
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..edd05b1cbcf86d489ce395ab90e50587c7bef4c6
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,250 @@
+import glob
+import logging
+import re
+import time
+from collections import defaultdict
+import os
+import sys
+import shutil
+import types
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+from torch import nn
+
+
+def tensors_to_scalars(metrics):
+    """
+    Return a copy of ``metrics`` in which every ``torch.Tensor`` value is
+    replaced by ``value.item()`` and every value whose type is exactly ``dict``
+    is converted recursively. Other values are copied unchanged.
+    """
+    converted = {}
+    for name, value in metrics.items():
+        if isinstance(value, torch.Tensor):
+            value = value.item()
+        if type(value) is dict:
+            value = tensors_to_scalars(value)
+        converted[name] = value
+    return converted
+
+
+class AvgrageMeter(object):
+    """Running weighted mean exposed through ``sum``, ``cnt`` and ``avg``."""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        """Set ``avg``, ``sum`` and ``cnt`` back to 0."""
+        self.avg = self.sum = self.cnt = 0
+
+    def update(self, val, n=1):
+        """Add ``val`` with weight ``n`` and recompute ``avg`` as ``sum / cnt``."""
+        weighted = val * n
+        self.sum += weighted
+        self.cnt += n
+        self.avg = self.sum / self.cnt
+
+
+def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
+    """
+    Convert a list of 1d tensors into a padded 2d tensor.
+
+    The row width is ``max_len`` when given, otherwise the longest input.
+    Unused cells hold ``pad_idx``; ``left_pad`` places each sequence at the end
+    of its row instead of the start. With ``shift_right`` each sequence is
+    written one position later, its last element dropped and ``shift_id``
+    stored in the first position.
+    """
+    width = max_len if max_len is not None else max(v.size(0) for v in values)
+    padded = values[0].new(len(values), width).fill_(pad_idx)
+
+    for row_idx, seq in enumerate(values):
+        row = padded[row_idx]
+        target = row[width - len(seq):] if left_pad else row[:len(seq)]
+        assert target.numel() == seq.numel()
+        if not shift_right:
+            target.copy_(seq)
+            continue
+        target[1:] = seq[:-1]
+        target[0] = shift_id
+    return padded
+
+
+def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
+    """
+    Convert a list of 2d tensors into a padded 3d tensor.
+
+    The padded length along axis 1 of the result is ``max_len`` when given,
+    otherwise the longest first dimension among the inputs. The last dimension
+    is taken from ``values[0]``. Unused cells hold ``pad_idx``. With
+    ``shift_right`` each input is written one row later and its last row is
+    dropped.
+    """
+    length = max_len if max_len is not None else max(v.size(0) for v in values)
+    padded = values[0].new(len(values), length, values[0].shape[1]).fill_(pad_idx)
+
+    for row_idx, seq in enumerate(values):
+        slot = padded[row_idx]
+        target = slot[length - len(seq):] if left_pad else slot[:len(seq)]
+        assert target.numel() == seq.numel()
+        if shift_right:
+            target[1:] = seq[:-1]
+        else:
+            target.copy_(seq)
+    return padded
+
+
+def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+    """
+    Return 1 when a non-empty ``batch`` already holds ``max_sentences`` items
+    or ``num_tokens`` exceeds ``max_tokens``; return 0 otherwise (an empty
+    batch always gives 0).
+    """
+    if not len(batch):
+        return 0
+    full = len(batch) == max_sentences or num_tokens > max_tokens
+    return 1 if full else 0
+
+
+def batch_by_size(
+        indices, num_tokens_fn, max_tokens=None, max_sentences=None,
+        required_batch_size_multiple=1, distributed=False
+):
+    """
+    Build mini-batches of indices bucketed by size and return them as a list
+    of lists. Batches may contain sequences of different lengths.
+
+    Args:
+        indices (List[int]): ordered list of dataset indices; a generator is
+            first materialised into a numpy int64 array
+        num_tokens_fn (callable): function that returns the number of tokens at
+            a given index
+        max_tokens (int, optional): max number of tokens in each batch
+            (default: None, treated as ``sys.maxsize``).
+        max_sentences (int, optional): max number of sentences in each
+            batch (default: None, treated as ``sys.maxsize``).
+        required_batch_size_multiple (int, optional): require batch size to
+            be a multiple of N (default: 1).
+        distributed: accepted but not referenced in this function body.
+
+    Raises:
+        AssertionError: when a single sample is longer than ``max_tokens``.
+    """
+    max_tokens = max_tokens if max_tokens is not None else sys.maxsize
+    max_sentences = max_sentences if max_sentences is not None else sys.maxsize
+    bsz_mult = required_batch_size_multiple
+
+    if isinstance(indices, types.GeneratorType):
+        indices = np.fromiter(indices, dtype=np.int64, count=-1)
+
+    # sample_len: longest sample in the batch being built;
+    # sample_lens: lengths of the samples currently in that batch.
+    sample_len = 0
+    sample_lens = []
+    batch = []
+    batches = []
+    for i in range(len(indices)):
+        idx = indices[i]
+        num_tokens = num_tokens_fn(idx)
+        sample_lens.append(num_tokens)
+        sample_len = max(sample_len, num_tokens)
+        assert sample_len <= max_tokens, (
+            "sentence at index {} of size {} exceeds max_tokens "
+            "limit of {}!".format(idx, sample_len, max_tokens)
+        )
+        # Token count if idx joined the batch and every sample were padded
+        # to the longest one.
+        num_tokens = (len(batch) + 1) * sample_len
+
+        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
+            # Emit the largest prefix whose length is a multiple of bsz_mult;
+            # when the batch is shorter than bsz_mult, emit all of it.
+            mod_len = max(
+                bsz_mult * (len(batch) // bsz_mult),
+                len(batch) % bsz_mult,
+            )
+            batches.append(batch[:mod_len])
+            # Carry the remainder (and its lengths) over to the next batch.
+            batch = batch[mod_len:]
+            sample_lens = sample_lens[mod_len:]
+            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
+        batch.append(idx)
+    # Flush whatever is left after the last index.
+    if len(batch) > 0:
+        batches.append(batch)
+    return batches
+
+
+def make_positions(tensor, padding_idx):
+    """Replace non-padding symbols with their position numbers.
+
+    Position numbers begin at padding_idx+1. Padding symbols are ignored
+    (they map to ``padding_idx``).
+    """
+    # The series of casts and type-conversions here are carefully
+    # balanced to both work with ONNX export and XLA. In particular XLA
+    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
+    # how to handle the dtype kwarg in cumsum.
+    non_pad = tensor.ne(padding_idx).int()
+    running = torch.cumsum(non_pad, dim=1).type_as(non_pad)
+    positions = (running * non_pad).long()
+    return positions + padding_idx
+
+
+def softmax(x, dim):
+    """Softmax of ``x`` along ``dim`` via ``F.softmax`` with ``dtype=torch.float32``."""
+    probs = F.softmax(x, dim=dim, dtype=torch.float32)
+    return probs
+
+
+def unpack_dict_to_list(samples):
+    """
+    Split a batched dict into a list with one dict per batch element.
+
+    The batch size is ``samples['outputs'].size(0)``. For batch index ``i``
+    the per-item dict holds ``value[i]`` for every entry that can be indexed
+    that way; entries that cannot (scalars, too-short sequences, ...) are left
+    out of that item.
+
+    Args:
+        samples: dict with at least an ``'outputs'`` entry exposing ``size(0)``.
+    Returns:
+        list of dicts, one per batch element.
+    """
+    batch_size = samples.get('outputs').size(0)
+    unpacked = []
+    for i in range(batch_size):
+        item = {}
+        for k, v in samples.items():
+            try:
+                item[k] = v[i]
+            except Exception:
+                # Best-effort skip, kept from the original. Narrowed from a bare
+                # ``except:`` so KeyboardInterrupt/SystemExit still propagate.
+                continue
+        unpacked.append(item)
+    return unpacked
+
+
+def load_ckpt(cur_model, ckpt_base_dir, prefix_in_ckpt='model', force=True, strict=True):
+    """
+    Load the ``<prefix_in_ckpt>.*`` entries of a checkpoint into ``cur_model``.
+
+    Args:
+        cur_model: module that receives the weights through ``load_state_dict``.
+        ckpt_base_dir: either a checkpoint file, or a directory searched for
+            ``model_ckpt_steps_<N>.ckpt``; the file with the largest ``N`` wins.
+        prefix_in_ckpt: key prefix to select from the checkpoint's
+            ``"state_dict"``; the prefix and the dot after it are stripped.
+        force: when no checkpoint is found, raise ``AssertionError`` if true,
+            otherwise only print the message.
+        strict: forwarded to ``load_state_dict``. When false, entries whose
+            shape differs from the model's parameter are dropped first.
+
+    Raises:
+        AssertionError: no checkpoint found and ``force`` is true.
+    """
+    if os.path.isfile(ckpt_base_dir):
+        base_dir = os.path.dirname(ckpt_base_dir)
+        checkpoint_path = [ckpt_base_dir]
+    else:
+        base_dir = ckpt_base_dir
+        # Match on the file name only. The directory is no longer spliced into
+        # the pattern, so regex metacharacters or backslashes in the path
+        # cannot break the step extraction.
+        step_re = re.compile(r'model_ckpt_steps_(\d+)\.ckpt$')
+        stepped = []
+        for path in glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'):
+            match = step_re.search(os.path.basename(path))
+            if match is not None:  # ignore files whose step part is not numeric
+                stepped.append((int(match.group(1)), path))
+        checkpoint_path = [path for _, path in sorted(stepped, key=lambda pair: pair[0])]
+    if len(checkpoint_path) > 0:
+        checkpoint_path = checkpoint_path[-1]
+        state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
+        state_dict = {k[len(prefix_in_ckpt) + 1:]: v for k, v in state_dict.items()
+                      if k.startswith(f'{prefix_in_ckpt}.')}
+        if not strict:
+            cur_model_state_dict = cur_model.state_dict()
+            unmatched_keys = []
+            for key, param in state_dict.items():
+                if key in cur_model_state_dict:
+                    new_param = cur_model_state_dict[key]
+                    if new_param.shape != param.shape:
+                        unmatched_keys.append(key)
+                        print("| Unmatched keys: ", key, new_param.shape, param.shape)
+            for key in unmatched_keys:
+                del state_dict[key]
+        cur_model.load_state_dict(state_dict, strict=strict)
+        print(f"| load '{prefix_in_ckpt}' from '{checkpoint_path}'.")
+    else:
+        e_msg = f"| ckpt not found in {base_dir}."
+        if force:
+            # Raised explicitly so the check also fires under ``python -O``.
+            raise AssertionError(e_msg)
+        print(e_msg)
+
+
+def remove_padding(x, padding_idx=0):
+    """
+    Drop padding entries from a 1-D or 2-D array; ``None`` is passed through.
+
+    For 1-D input the elements equal to ``padding_idx`` are removed. For 2-D
+    input ([T, H]) the rows whose sum of absolute values equals
+    ``padding_idx`` are removed.
+    """
+    if x is None:
+        return None
+    rank = len(x.shape)
+    assert rank in [1, 2]
+    if rank == 2:  # [T, H]
+        keep = np.abs(x).sum(-1) != padding_idx
+    else:  # [T]
+        keep = x != padding_idx
+    return x[keep]
+
+
+class Timer:
+    """
+    Context manager that adds the wall-clock time spent inside the ``with``
+    block to a per-name total kept in ``Timer.timer_map``.
+
+    Args:
+        name: key under which the elapsed time is accumulated.
+        print_time: when true, print the name and its running total on exit.
+    """
+    # Accumulated seconds per timer name; a class attribute shared by all instances.
+    timer_map = {}
+
+    def __init__(self, name, print_time=False):
+        if name not in Timer.timer_map:
+            Timer.timer_map[name] = 0
+        self.name = name
+        self.print_time = print_time
+
+    def __enter__(self):
+        self.t = time.time()
+        # Return the instance so ``with Timer(...) as t`` binds something usable
+        # (the original returned None).
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        Timer.timer_map[self.name] += time.time() - self.t
+        if self.print_time:
+            print(self.name, Timer.timer_map[self.name])
+
+
+def print_arch(model, model_name='model'):
+    """Report the trainable-parameter count of ``model`` by delegating to ``num_params``.
+
+    Printing of the architecture itself is commented out below.
+    """
+    #print(f"| {model_name} Arch: ", model)
+    num_params(model, model_name=model_name)
+
+
+def num_params(model, print_out=True, model_name="model"):
+    """
+    Count the elements of all parameters of ``model`` that have
+    ``requires_grad`` set, in millions.
+
+    When ``print_out`` is true the figure is also printed, labelled with
+    ``model_name``. The count (in millions) is returned.
+    """
+    trainable = [p for p in model.parameters() if p.requires_grad]
+    millions = sum([np.prod(p.size()) for p in trainable]) / 1_000_000
+    if print_out:
+        print(f'| {model_name} Trainable Parameters: %.3fM' % millions)
+    return millions
diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c746ba9e45443d03816de9aa2b8a1059f9815a3c
Binary files /dev/null and b/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/utils/__pycache__/audio.cpython-38.pyc b/utils/__pycache__/audio.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8afcff881dbb45eceea2a526c4a576105bcfc12c
Binary files /dev/null and b/utils/__pycache__/audio.cpython-38.pyc differ
diff --git a/utils/__pycache__/cwt.cpython-38.pyc b/utils/__pycache__/cwt.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b3e274b26e8911ffb3b8f54a20ec8ff49f6c1495
Binary files /dev/null and b/utils/__pycache__/cwt.cpython-38.pyc differ
diff --git a/utils/__pycache__/hparams.cpython-38.pyc b/utils/__pycache__/hparams.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8641c0974b9c8464ca765b4208ba46da46ea7126
Binary files /dev/null and b/utils/__pycache__/hparams.cpython-38.pyc differ
diff --git a/utils/__pycache__/indexed_datasets.cpython-38.pyc b/utils/__pycache__/indexed_datasets.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b5af5e4925ba78cc44093198d69b905eb23b1cfd
Binary files /dev/null and b/utils/__pycache__/indexed_datasets.cpython-38.pyc differ
diff --git a/utils/__pycache__/multiprocess_utils.cpython-38.pyc b/utils/__pycache__/multiprocess_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..6dd62f9fd09bbe88a4c7eb52ef8a23c38a50cd4f
Binary files /dev/null and b/utils/__pycache__/multiprocess_utils.cpython-38.pyc differ
diff --git a/utils/__pycache__/pitch_utils.cpython-38.pyc b/utils/__pycache__/pitch_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3aff420fe73cfa093f2008de847293d166ac72a8
Binary files /dev/null and b/utils/__pycache__/pitch_utils.cpython-38.pyc differ
diff --git a/utils/__pycache__/pl_utils.cpython-38.pyc b/utils/__pycache__/pl_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f96abece2fced97b71c8ec51558130d0adee8678
Binary files /dev/null and b/utils/__pycache__/pl_utils.cpython-38.pyc differ
diff --git a/utils/__pycache__/plot.cpython-38.pyc b/utils/__pycache__/plot.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..be844867804ea2f6ddcf5d8e0c59aeb343d20aea
Binary files /dev/null and b/utils/__pycache__/plot.cpython-38.pyc differ
diff --git a/utils/__pycache__/text_encoder.cpython-38.pyc b/utils/__pycache__/text_encoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7dec50c4d7671deb03a4eba6969bcd909ef747b7
Binary files /dev/null and b/utils/__pycache__/text_encoder.cpython-38.pyc differ
diff --git a/utils/__pycache__/training_utils.cpython-38.pyc b/utils/__pycache__/training_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..290a0b8eb2bc967a3e9c87a564069db722a75810
Binary files /dev/null and b/utils/__pycache__/training_utils.cpython-38.pyc differ
diff --git a/utils/audio.py b/utils/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..aba7ab926cf793d085bbdc70c97f376001183fe1
--- /dev/null
+++ b/utils/audio.py
@@ -0,0 +1,56 @@
+import subprocess
+import matplotlib
+
+matplotlib.use('Agg')
+import librosa
+import librosa.filters
+import numpy as np
+from scipy import signal
+from scipy.io import wavfile
+
+
+def save_wav(wav, path, sr, norm=False):
+    """
+    Write ``wav`` to ``path`` as 16-bit PCM at sample rate ``sr``.
+
+    Args:
+        wav: numpy array of samples; it is scaled by 32767 and cast to int16.
+        path: destination passed to ``scipy.io.wavfile.write``.
+        sr: sample rate passed to ``scipy.io.wavfile.write``.
+        norm: when true, divide by the peak absolute value before scaling.
+            An all-zero signal is left as is instead of being divided by zero.
+
+    The caller's array is never modified.
+    """
+    if norm:
+        peak = np.abs(wav).max()
+        if peak > 0:
+            wav = wav / peak
+    # Out-of-place multiply: the original ``wav *= 32767`` scaled the caller's
+    # array whenever ``norm`` was false.
+    scaled = wav * 32767
+    # proposed by @dsmiller
+    wavfile.write(path, sr, scaled.astype(np.int16))
+
+
+def get_hop_size(hparams):
+    """
+    Return ``hparams['hop_size']``; when that is ``None``, derive it as
+    ``int(frame_shift_ms / 1000 * audio_sample_rate)`` from ``hparams``
+    (``frame_shift_ms`` must then not be ``None``).
+    """
+    hop = hparams['hop_size']
+    if hop is not None:
+        return hop
+    assert hparams['frame_shift_ms'] is not None
+    return int(hparams['frame_shift_ms'] / 1000 * hparams['audio_sample_rate'])
+
+
+###########################################################################################
+def _stft(y, hparams):
+    """Run ``librosa.stft`` on ``y``.
+
+    ``n_fft`` and ``win_length`` are read from ``hparams['fft_size']`` and
+    ``hparams['win_size']``, the hop length comes from ``get_hop_size``, and
+    the padding mode is ``'constant'``.
+    """
+    return librosa.stft(y=y, n_fft=hparams['fft_size'], hop_length=get_hop_size(hparams),
+                        win_length=hparams['win_size'], pad_mode='constant')
+
+
+def _istft(y, hparams):
+    """Run ``librosa.istft`` on ``y`` with the hop length from ``get_hop_size``
+    and the window length from ``hparams['win_size']``."""
+    return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams['win_size'])
+
+
+def librosa_pad_lr(x, fsize, fshift, pad_sides=1):
+    '''Compute right padding (final frame) or both-sides padding (first and final frames).
+
+    The total padding brings ``x.shape[0]`` up to the next multiple of
+    ``fshift`` (a full ``fshift`` when it already is one). Returns a
+    ``(left, right)`` pair: everything on the right for ``pad_sides == 1``,
+    split in half (extra sample on the right) for ``pad_sides == 2``.
+    ``fsize`` is not used.
+    '''
+    assert pad_sides in (1, 2)
+    # return int(fsize // 2)
+    length = x.shape[0]
+    total = (length // fshift + 1) * fshift - length
+    if pad_sides != 1:
+        half = total // 2
+        return half, half + total % 2
+    return 0, total
+
+
+# Conversions
+def amp_to_db(x):
+    """Return ``20 * log10(x)`` after flooring ``x`` at 1e-5."""
+    floored = np.maximum(1e-5, x)
+    return 20 * np.log10(floored)
+
+
+def normalize(S, hparams):
+    """Return ``(S - min_level_db) / -min_level_db`` with ``min_level_db`` read from ``hparams``."""
+    floor_db = hparams['min_level_db']
+    return (S - floor_db) / -floor_db
diff --git a/utils/cwt.py b/utils/cwt.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a08461b9e422aac614438e6240b7355b8e4bb2c
--- /dev/null
+++ b/utils/cwt.py
@@ -0,0 +1,146 @@
+import librosa
+import numpy as np
+from pycwt import wavelet
+from scipy.interpolate import interp1d
+
+
+def load_wav(wav_file, sr):
+    """Load ``wav_file`` with ``librosa.load`` at sample rate ``sr`` with
+    ``mono=True`` and return only the samples (the second value librosa
+    returns is discarded)."""
+    wav, _ = librosa.load(wav_file, sr=sr, mono=True)
+    return wav
+
+
+def convert_continuos_f0(f0):
+    '''Convert an f0 sequence into a continuous one.
+
+    Zero frames are filled in: leading and trailing zeros take the first and
+    last non-zero value, inner zeros are linearly interpolated. The input is
+    copied, not modified.
+
+    Args:
+        f0 (ndarray): original f0 sequence with the shape (T)
+    Return:
+        (ndarray): float32 flags, 1 where the original f0 is non-zero, else 0
+        (ndarray): continuous f0 with the shape (T); the unchanged copy when
+            every value is 0
+    '''
+    f0 = np.copy(f0)
+    uv = np.float32(f0 != 0)
+
+    if (f0 == 0).all():
+        print("| all of the f0 values are 0.")
+        return uv, f0
+
+    # First and last non-zero values, and where they sit.
+    nonzero_values = f0[f0 != 0]
+    head_value, tail_value = nonzero_values[0], nonzero_values[-1]
+    head_idx = np.where(f0 == head_value)[0][0]
+    tail_idx = np.where(f0 == tail_value)[0][-1]
+
+    # Extend them outwards so the interpolation covers the whole range.
+    f0[:head_idx] = head_value
+    f0[tail_idx:] = tail_value
+
+    anchors = np.where(f0 != 0)[0]
+    interpolate = interp1d(anchors, f0[anchors])
+    return uv, interpolate(np.arange(0, f0.shape[0]))
+
+
+def get_cont_lf0(f0, frame_period=5.0):
+    """
+    Return ``(uv, log_f0)``: the two outputs of ``convert_continuos_f0(f0)``,
+    with the natural log applied to the second one.
+
+    ``frame_period`` is only referenced by the commented-out low-pass filter
+    and has no effect.
+    """
+    uv, cont_f0 = convert_continuos_f0(f0)
+    # cont_f0_lpf = low_pass_filter(cont_f0_lpf, int(1.0 / (frame_period * 0.001)), cutoff=20)
+    return uv, np.log(cont_f0)
+
+
+def get_lf0_cwt(lf0):
+    '''
+    Continuous wavelet transform of ``lf0`` with a Mexican-hat mother wavelet.
+
+    input:
+        signal of shape (N)
+    output:
+        real part of the transform, transposed (the transform itself has
+        shape (J + 1, len(lf0)) per the note below), and the scales returned
+        by ``wavelet.cwt``
+    '''
+    mother = wavelet.MexicanHat()
+    step = 0.005           # dt
+    scale_spacing = 1      # dj
+    smallest = step * 2    # s0
+    n_scales_minus_1 = 9   # J
+
+    coeffs, scales, _, _, _, _ = wavelet.cwt(
+        np.squeeze(lf0), step, scale_spacing, smallest, n_scales_minus_1, mother)
+    # Wavelet.shape => (J + 1, len(lf0))
+    return np.real(coeffs).T, scales
+
+
+def norm_scale(Wavelet_lf0):
+    """
+    Standardise ``Wavelet_lf0`` along axis 0.
+
+    Returns:
+        (normalised array, mean, std) where ``mean`` and ``std`` are computed
+        over axis 0 and carry a leading singleton axis.
+    """
+    # The zero-filled array the original allocated here was overwritten right
+    # away, so it is omitted.
+    mean = Wavelet_lf0.mean(0)[None, :]
+    std = Wavelet_lf0.std(0)[None, :]
+    Wavelet_lf0_norm = (Wavelet_lf0 - mean) / std
+    return Wavelet_lf0_norm, mean, std
+
+
+def normalize_cwt_lf0(f0, mean, std):
+    """
+    Chain ``get_cont_lf0`` -> standardise with ``mean``/``std`` ->
+    ``get_lf0_cwt`` -> ``norm_scale`` and return the normalised wavelet
+    coefficients only.
+    """
+    _, cont_lf0 = get_cont_lf0(f0)
+    standardised = (cont_lf0 - mean) / std
+    coeffs, _ = get_lf0_cwt(standardised)
+    coeffs_norm, _, _ = norm_scale(coeffs)
+
+    return coeffs_norm
+
+
+def get_lf0_cwt_norm(f0s, mean, std):
+    """
+    Apply the pipeline ``get_cont_lf0`` -> standardise with ``mean``/``std``
+    -> ``get_lf0_cwt`` -> ``norm_scale`` to every sequence in ``f0s``.
+
+    Returns:
+        Four lists with one entry per input sequence: the normalised wavelet
+        coefficients, the scales, and the mean and std from ``norm_scale``.
+    """
+    # Only the four returned lists are accumulated; the original also filled
+    # four more lists that were never returned.
+    Wavelet_lf0s_norm = []
+    scaless = []
+    means = []
+    stds = []
+    for f0 in f0s:
+        _, cont_lf0_lpf = get_cont_lf0(f0)
+        cont_lf0_lpf_norm = (cont_lf0_lpf - mean) / std
+
+        Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm)
+        Wavelet_lf0_norm, mean_scale, std_scale = norm_scale(Wavelet_lf0)
+
+        Wavelet_lf0s_norm.append(Wavelet_lf0_norm)
+        scaless.append(scales)
+        means.append(mean_scale)
+        stds.append(std_scale)
+
+    return Wavelet_lf0s_norm, scaless, means, stds
+
+
+def inverse_cwt_torch(Wavelet_lf0, scales):
+    """
+    Torch version of the inverse transform: weight the last axis of
+    ``Wavelet_lf0`` by ``(i + 1 + 2.5) ** -2.5`` for ``i`` in
+    ``range(len(scales))``, sum over it, then standardise along the new last
+    axis with ``mean``/``std``.
+    """
+    import torch
+    index = torch.arange(0, len(scales)).float().to(Wavelet_lf0.device)[None, None, :]
+    weights = (index + 1 + 2.5) ** (-2.5)
+    summed = (Wavelet_lf0 * weights).sum(-1)
+    centred = summed - summed.mean(-1, keepdim=True)
+    return centred / summed.std(-1, keepdim=True)
+
+
+def inverse_cwt(Wavelet_lf0, scales):
+    """
+    Numpy version of the inverse transform: weight the last axis of
+    ``Wavelet_lf0`` by ``(i + 1 + 2.5) ** -2.5`` for ``i`` in
+    ``range(len(scales))``, sum over it, then standardise along the new last
+    axis with ``mean``/``std``.
+    """
+    index = np.arange(0, len(scales))[None, None, :]
+    weights = (index + 1 + 2.5) ** (-2.5)
+    summed = (Wavelet_lf0 * weights).sum(-1)
+    centred = summed - summed.mean(-1, keepdims=True)
+    return centred / summed.std(-1, keepdims=True)
+
+
+def cwt2f0(cwt_spec, mean, std, cwt_scales):
+    """
+    Turn wavelet coefficients back into f0.
+
+    ``cwt_spec`` must be 3-D and ``mean``/``std`` 1-D. The coefficients go
+    through ``inverse_cwt_torch`` (tensor input) or ``inverse_cwt``
+    (otherwise), are rescaled by ``std``/``mean`` (indexed ``[:, None]``) and
+    exponentiated.
+    """
+    assert len(mean.shape) == 1 and len(std.shape) == 1 and len(cwt_spec.shape) == 3
+    import torch
+    is_tensor = isinstance(cwt_spec, torch.Tensor)
+    inverse = inverse_cwt_torch if is_tensor else inverse_cwt
+    lf0 = inverse(cwt_spec, cwt_scales) * std[:, None] + mean[:, None]
+    return lf0.exp() if is_tensor else np.exp(lf0)  # [B, T]
diff --git a/utils/hparams.py b/utils/hparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d5e6552d88c4609343f968239bfce1a1c177c8b
--- /dev/null
+++ b/utils/hparams.py
@@ -0,0 +1,131 @@
+import argparse
+import os
+import yaml
+
+global_print_hparams = True
+hparams = {}
+
+
+class Args:
+    """Attribute bag: every keyword argument becomes an attribute of the instance."""
+
+    def __init__(self, **kwargs):
+        for name in kwargs:
+            self.__setattr__(name, kwargs[name])
+
+
+def override_config(old_config: dict, new_config: dict):
+    """
+    Merge ``new_config`` into ``old_config`` in place.
+
+    A dict value whose key already exists in ``old_config`` is merged
+    recursively; every other value replaces (or creates) the entry.
+    """
+    for key, value in new_config.items():
+        if not (isinstance(value, dict) and key in old_config):
+            old_config[key] = value
+            continue
+        override_config(old_config[key], value)
+
+
+def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True,reset=True,infer=True):
+    '''
+    Load hparams from multiple sources:
+    1. config chain (i.e. first load base_config, then load config);
+    2. unless reset is true, overlay the (auto-saved) complete config file
+       ('config.yaml' in the work dir) which contains all settings and does
+       not rely on base_config;
+    3. load from argument --hparams or hparams_str, as temporary modification.
+
+    When ``config`` is empty the settings are parsed from the command line
+    instead of the keyword arguments. The resulting dict is returned and,
+    when ``global_hparams`` is true, also copied into the module-level
+    ``hparams``.
+    '''
+    if config == '':
+        # No explicit config: take everything from the command line
+        # (unknown arguments are ignored).
+        parser = argparse.ArgumentParser(description='neural music')
+        parser.add_argument('--config', type=str, default='',
+                            help='location of the data corpus')
+        parser.add_argument('--exp_name', type=str, default='', help='exp_name')
+        parser.add_argument('--hparams', type=str, default='',
+                            help='location of the data corpus')
+        parser.add_argument('--infer', action='store_true', help='infer')
+        parser.add_argument('--validate', action='store_true', help='validate')
+        parser.add_argument('--reset', action='store_true', help='reset hparams')
+        parser.add_argument('--debug', action='store_true', help='debug')
+        args, unknown = parser.parse_known_args()
+    else:
+        args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
+                    infer=infer, validate=False, reset=reset, debug=False)
+    args_work_dir = ''
+    if args.exp_name != '':
+        args.work_dir = args.exp_name
+        args_work_dir = f'checkpoints/{args.work_dir}'
+
+    # config_chains: files in the order they finished loading;
+    # loaded_config: files already visited.
+    config_chains = []
+    loaded_config = set()
+
+    def load_config(config_fn):  # deep first
+        # Load one YAML file, first resolving its 'base_config' entries
+        # recursively; the file's own values override those of its bases.
+        with open(config_fn, encoding='utf-8') as f:
+            hparams_ = yaml.safe_load(f)
+        loaded_config.add(config_fn)
+        if 'base_config' in hparams_:
+            ret_hparams = {}
+            if not isinstance(hparams_['base_config'], list):
+                hparams_['base_config'] = [hparams_['base_config']]
+            for c in hparams_['base_config']:
+                if c not in loaded_config:
+                    # Paths starting with '.' are relative to the including file.
+                    if c.startswith('.'):
+                        c = f'{os.path.dirname(config_fn)}/{c}'
+                        c = os.path.normpath(c)
+                    override_config(ret_hparams, load_config(c))
+            override_config(ret_hparams, hparams_)
+        else:
+            ret_hparams = hparams_
+        config_chains.append(config_fn)
+        return ret_hparams
+
+    global hparams
+    assert args.config != '' or args_work_dir != ''
+    saved_hparams = {}
+    # NOTE(review): args_work_dir is either '' or 'checkpoints/<exp_name>' with
+    # a non-empty name, so this comparison looks always true — confirm intent.
+    if args_work_dir != 'checkpoints/':
+        ckpt_config_path = f'{args_work_dir}/config.yaml'
+        if os.path.exists(ckpt_config_path):
+            try:
+                with open(ckpt_config_path, encoding='utf-8') as f:
+                    saved_hparams.update(yaml.safe_load(f))
+            except:
+                # NOTE(review): any failure reading the saved config is
+                # silently ignored (bare except) — confirm this is intended.
+                pass
+        if args.config == '':
+            args.config = ckpt_config_path
+
+    hparams_ = {}
+
+    hparams_.update(load_config(args.config))
+
+    if not args.reset:
+        hparams_.update(saved_hparams)
+    hparams_['work_dir'] = args_work_dir
+
+    if args.hparams != "":
+        # Overrides have the form "k1=v1,k2=v2".
+        for new_hparam in args.hparams.split(","):
+            k, v = new_hparam.split("=")
+            # NOTE(review): eval() runs on text that may come from the command
+            # line; only use with trusted input.
+            if k not in hparams_:
+                hparams_[k] = eval(v)
+            if v in ['True', 'False'] or type(hparams_[k]) == bool:
+                hparams_[k] = eval(v)
+            else:
+                # Cast the text to the type of the existing value.
+                hparams_[k] = type(hparams_[k])(v)
+
+    # Save the merged config to the work dir when training (not infer) and it
+    # does not exist yet or a reset was requested.
+    if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
+        os.makedirs(hparams_['work_dir'], exist_ok=True)
+        with open(ckpt_config_path, 'w', encoding='utf-8') as f:
+            yaml.safe_dump(hparams_, f)
+
+    hparams_['infer'] = args.infer
+    hparams_['debug'] = args.debug
+    hparams_['validate'] = args.validate
+    global global_print_hparams
+    if global_hparams:
+        hparams.clear()
+        hparams.update(hparams_)
+
+    # Printed at most once per process: the flag is cleared after the first print.
+    if print_hparams and global_print_hparams and global_hparams:
+        print('| Hparams chains: ', config_chains)
+        print('| Hparams: ')
+        for i, (k, v) in enumerate(sorted(hparams_.items())):
+            print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
+        print("")
+        global_print_hparams = False
+    # print(hparams_.keys())
+    # Fall back to the requested experiment name when the configs set none.
+    if hparams.get('exp_name') is None:
+        hparams['exp_name'] = args.exp_name
+    if hparams_.get('exp_name') is None:
+        hparams_['exp_name'] = args.exp_name
+    return hparams_
diff --git a/utils/indexed_datasets.py b/utils/indexed_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..e15632be30d6296a3c9aa80a1f351058003698b3
--- /dev/null
+++ b/utils/indexed_datasets.py
@@ -0,0 +1,71 @@
+import pickle
+from copy import deepcopy
+
+import numpy as np
+
+
+class IndexedDataset:
+    """
+    Random-access reader for a ``<path>.data`` file of concatenated pickles
+    indexed by the byte offsets stored in ``<path>.idx``.
+
+    Args:
+        path: common prefix of the ``.idx`` and ``.data`` files.
+        num_cache: how many recently read items to keep in memory
+            (0 disables caching).
+
+    NOTE(review): items are decoded with ``pickle.loads`` and the index is
+    loaded with ``allow_pickle=True``; only open files from a trusted source.
+    """
+
+    def __init__(self, path, num_cache=1):
+        super().__init__()
+        self.path = path
+        # Set before anything can fail so __del__ always finds the attribute.
+        self.data_file = None
+        self.data_offsets = np.load(f"{path}.idx", allow_pickle=True).item()['offsets']
+        self.data_file = open(f"{path}.data", 'rb', buffering=-1)
+        # Most recently read first; entries are (index, item) pairs.
+        self.cache = []
+        self.num_cache = num_cache
+
+    def check_index(self, i):
+        """Raise ``IndexError`` unless ``0 <= i < len(self)``."""
+        if i < 0 or i >= len(self.data_offsets) - 1:
+            raise IndexError('index out of range')
+
+    def __del__(self):
+        if self.data_file:
+            self.data_file.close()
+
+    def __getitem__(self, i):
+        """Return item ``i``, from the cache when present, else from disk."""
+        self.check_index(i)
+        if self.num_cache > 0:
+            for c in self.cache:
+                if c[0] == i:
+                    return c[1]
+        self.data_file.seek(self.data_offsets[i])
+        b = self.data_file.read(self.data_offsets[i + 1] - self.data_offsets[i])
+        item = pickle.loads(b)
+        if self.num_cache > 0:
+            # Keep the newest ``num_cache`` entries. The original sliced with
+            # ``[:-1]``, which could never grow the cache beyond one entry.
+            self.cache = [(i, deepcopy(item))] + self.cache[:self.num_cache - 1]
+        return item
+
+    def __len__(self):
+        return len(self.data_offsets) - 1
+
+
+class IndexedDatasetBuilder:
+    """
+    Writer for the ``<path>.data`` / ``<path>.idx`` pair read by
+    ``IndexedDataset``: items are pickled back to back into the data file and
+    their byte offsets are saved to the index file by ``finalize``.
+    """
+
+    def __init__(self, path):
+        self.path = path
+        self.out_file = open(f"{path}.data", 'wb')
+        # byte_offsets[k] is where item k starts; the last entry is the file size.
+        self.byte_offsets = [0]
+
+    def add_item(self, item):
+        """Pickle ``item``, append it to the data file and record its end offset."""
+        s = pickle.dumps(item)
+        n_written = self.out_file.write(s)
+        self.byte_offsets.append(self.byte_offsets[-1] + n_written)
+
+    def finalize(self):
+        """Close the data file and write the offsets to ``<path>.idx``."""
+        self.out_file.close()
+        # A file object (not a path) is passed so numpy does not append
+        # ".npy"; the ``with`` closes the handle the original left open.
+        with open(f"{self.path}.idx", 'wb') as idx_file:
+            np.save(idx_file, {'offsets': self.byte_offsets})
+
+
+
+
+if __name__ == "__main__":
+    # Round-trip demo: write random items with the builder, then read random
+    # indices back and compare the "a" arrays.
+    import random
+    from tqdm import tqdm
+    ds_path = '/tmp/indexed_ds_example'
+    size = 100
+    items = [{name: np.random.normal(size=[10000, 10]) for name in ("a", "b")}
+             for i in range(size)]
+    builder = IndexedDatasetBuilder(ds_path)
+    for i in tqdm(range(size)):
+        builder.add_item(items[i])
+    builder.finalize()
+    ds = IndexedDataset(ds_path)
+    for i in tqdm(range(10000)):
+        idx = random.randint(0, size - 1)
+        expected = items[idx]['a']
+        assert (ds[idx]['a'] == expected).all()
diff --git a/utils/multiprocess_utils.py b/utils/multiprocess_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..24876c4ca777f09d1c1e1b75674cd7aaf37a75a6
--- /dev/null
+++ b/utils/multiprocess_utils.py
@@ -0,0 +1,47 @@
+import os
+import traceback
+from multiprocessing import Queue, Process
+
+
+def chunked_worker(worker_id, map_func, args, results_queue=None, init_ctx_func=None):
+    """
+    Run ``map_func`` over the ``(job_idx, arg)`` pairs in ``args`` and put
+    ``(job_idx, result)`` on ``results_queue`` for each.
+
+    When ``init_ctx_func`` is given, ``init_ctx_func(worker_id)`` is called
+    once and a non-``None`` result is passed to every call as ``ctx=``. If a
+    job raises, the traceback is printed and ``(job_idx, None)`` is queued.
+    """
+    ctx = None if init_ctx_func is None else init_ctx_func(worker_id)
+    extra = {} if ctx is None else {'ctx': ctx}
+    for job_idx, arg in args:
+        try:
+            outcome = map_func(*arg, **extra)
+            results_queue.put((job_idx, outcome))
+        except:
+            traceback.print_exc()
+            results_queue.put((job_idx, None))
+
+
+def chunked_multiprocess_run(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, q_max_size=1000):
+    """
+    Generator that runs ``map_func`` over ``args`` in daemon worker processes
+    (each running ``chunked_worker``) and yields one result per job.
+
+    Args:
+        map_func: called as ``map_func(*arg)`` by the workers.
+        args: sequence of argument tuples, one per job.
+        num_workers: process count; defaults to the ``N_PROC`` environment
+            variable, else ``os.cpu_count()``.
+        ordered: when true every worker gets its own queue and results are
+            yielded in job order; otherwise one queue is shared.
+        init_ctx_func: forwarded to ``chunked_worker``.
+        q_max_size: total queue capacity (split across workers when ordered).
+
+    Yields:
+        The result of each job; ``chunked_worker`` reports ``None`` for a job
+        that raised.
+    """
+    # Pair every job with its index so results can be matched back.
+    args = zip(range(len(args)), args)
+    args = list(args)
+    n_jobs = len(args)
+    if num_workers is None:
+        num_workers = int(os.getenv('N_PROC', os.cpu_count()))
+    results_queues = []
+    if ordered:
+        for i in range(num_workers):
+            results_queues.append(Queue(maxsize=q_max_size // num_workers))
+    else:
+        results_queue = Queue(maxsize=q_max_size)
+        for i in range(num_workers):
+            results_queues.append(results_queue)
+    workers = []
+    for i in range(num_workers):
+        # Worker i takes jobs i, i + num_workers, i + 2 * num_workers, ...
+        args_worker = args[i::num_workers]
+        p = Process(target=chunked_worker, args=(
+            i, map_func, args_worker, results_queues[i], init_ctx_func), daemon=True)
+        workers.append(p)
+        p.start()
+    for n_finished in range(n_jobs):
+        # Job n was given to worker n % num_workers, so reading the queues
+        # round-robin returns the jobs in order.
+        results_queue = results_queues[n_finished % num_workers]
+        job_idx, res = results_queue.get()
+        assert job_idx == n_finished or not ordered, (job_idx, n_finished)
+        yield res
+    for w in workers:
+        w.join()
+        w.close()
diff --git a/utils/pitch_utils.py b/utils/pitch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1767810e600c8f82e821ff4fc0a164daddaf7af4
--- /dev/null
+++ b/utils/pitch_utils.py
@@ -0,0 +1,76 @@
+#########
+# world
+##########
+import librosa
+import numpy as np
+import torch
+
+# gamma = 0
+# mcepInput = 3 # 0 for dB, 3 for magnitude
+# alpha = 0.45
+# en_floor = 10 ** (-80 / 20)
+# FFT_SIZE = 2048
+
+
+
+
+def f0_to_coarse(f0, hparams):
+    """
+    Quantise ``f0`` into integer bins on a mel-style scale
+    (``1127 * log(1 + f / 700)``).
+
+    Values with a positive mel are mapped linearly so that ``f0_min`` lands on
+    1 and ``f0_max`` on ``f0_bin - 1``; the result is then clamped to
+    ``[1, f0_bin - 1]`` and rounded. Works on torch tensors and numpy arrays.
+    ``f0_bin``, ``f0_max`` and ``f0_min`` are read from ``hparams``.
+    """
+    f0_bin = hparams['f0_bin']
+    f0_max = hparams['f0_max']
+    f0_min = hparams['f0_min']
+    is_torch = isinstance(f0, torch.Tensor)
+    mel_lo = 1127 * np.log(1 + f0_min / 700)
+    mel_hi = 1127 * np.log(1 + f0_max / 700)
+    if is_torch:
+        f0_mel = 1127 * (1 + f0 / 700).log()
+    else:
+        f0_mel = 1127 * np.log(1 + f0 / 700)
+    f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - mel_lo) * (f0_bin - 2) / (mel_hi - mel_lo) + 1
+
+    # Clamp into [1, f0_bin - 1].
+    f0_mel[f0_mel <= 1] = 1
+    f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
+    if is_torch:
+        f0_coarse = (f0_mel + 0.5).long()
+    else:
+        f0_coarse = np.rint(f0_mel).astype(int)
+    assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
+    return f0_coarse
+
+
+def norm_f0(f0, uv, hparams):
+    """
+    Normalise ``f0`` according to ``hparams['pitch_norm']``: ``'standard'``
+    subtracts ``f0_mean`` and divides by ``f0_std``; ``'log'`` takes log2.
+    When ``uv`` is given and ``hparams['use_uv']`` is truthy, positions where
+    ``uv > 0`` are then set to 0.
+    """
+    is_torch = isinstance(f0, torch.Tensor)
+    mode = hparams['pitch_norm']
+    if mode == 'standard':
+        f0 = (f0 - hparams['f0_mean']) / hparams['f0_std']
+    if mode == 'log':
+        f0 = torch.log2(f0) if is_torch else np.log2(f0)
+    apply_uv = uv is not None and hparams['use_uv']
+    if apply_uv:
+        f0[uv > 0] = 0
+    return f0
+
+
+def norm_interp_f0(f0, hparams):
+    """
+    Normalise ``f0`` with ``norm_f0`` and fill the frames where the input is
+    exactly 0 by linear interpolation between the other frames (or with 0
+    when every frame is 0).
+
+    Returns:
+        ``(f0, uv)`` as ``torch.FloatTensor``s, where ``uv`` marks the frames
+        that were 0 in the input. ``f0`` is moved back to the input's device
+        when the input was a tensor.
+    """
+    is_torch = isinstance(f0, torch.Tensor)
+    if is_torch:
+        device = f0.device
+        f0 = f0.data.cpu().numpy()
+    uv = f0 == 0
+    f0 = norm_f0(f0, uv, hparams)
+    n_zero = sum(uv)
+    if n_zero == len(f0):
+        f0[uv] = 0
+    elif n_zero > 0:
+        f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
+    uv = torch.FloatTensor(uv)
+    f0 = torch.FloatTensor(f0)
+    if is_torch:
+        f0 = f0.to(device)
+    return f0, uv
+
+
+def denorm_f0(f0, uv, hparams, pitch_padding=None, min=None, max=None):
+    """
+    Undo the normalisation selected by ``hparams['pitch_norm']``
+    (``'standard'``: multiply by ``f0_std`` and add ``f0_mean``; ``'log'``:
+    ``2 ** f0``), then in this order: clamp to ``min`` / ``max`` when given,
+    zero the positions where ``uv > 0`` (if ``uv`` is given and
+    ``hparams['use_uv']`` is truthy), and zero the positions selected by
+    ``pitch_padding`` when given.
+    """
+    mode = hparams['pitch_norm']
+    if mode == 'standard':
+        f0 = f0 * hparams['f0_std'] + hparams['f0_mean']
+    if mode == 'log':
+        f0 = 2 ** f0
+    if min is not None:
+        f0 = f0.clamp(min=min)
+    if max is not None:
+        f0 = f0.clamp(max=max)
+    zero_unvoiced = uv is not None and hparams['use_uv']
+    if zero_unvoiced:
+        f0[uv > 0] = 0
+    if pitch_padding is not None:
+        f0[pitch_padding] = 0
+    return f0
diff --git a/utils/pl_utils.py b/utils/pl_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f375637e83f07b3406e788026da2971d510540e7
--- /dev/null
+++ b/utils/pl_utils.py
@@ -0,0 +1,1625 @@
+import matplotlib
+from torch.nn import DataParallel
+from torch.nn.parallel import DistributedDataParallel
+
+matplotlib.use('Agg')
+import glob
+import itertools
+import subprocess
+import threading
+import traceback
+
+from pytorch_lightning.callbacks import GradientAccumulationScheduler
+from pytorch_lightning.callbacks import ModelCheckpoint
+
+from functools import wraps
+from torch.cuda._utils import _get_device_index
+import numpy as np
+import torch.optim
+import torch.utils.data
+import copy
+import logging
+import os
+import re
+import sys
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import tqdm
+from torch.optim.optimizer import Optimizer
+
+
+def get_a_var(obj): # pragma: no cover
+ if isinstance(obj, torch.Tensor):
+ return obj
+
+ if isinstance(obj, list) or isinstance(obj, tuple):
+ for result in map(get_a_var, obj):
+ if isinstance(result, torch.Tensor):
+ return result
+ if isinstance(obj, dict):
+ for result in map(get_a_var, obj.items()):
+ if isinstance(result, torch.Tensor):
+ return result
+ return None
+
+
+def data_loader(fn):
+ """
+ Decorator to make any fx with this use the lazy property
+ :param fn:
+ :return:
+ """
+
+ wraps(fn)
+ attr_name = '_lazy_' + fn.__name__
+
+ def _get_data_loader(self):
+ try:
+ value = getattr(self, attr_name)
+ except AttributeError:
+ try:
+ value = fn(self) # Lazy evaluation, done only once.
+ if (
+ value is not None and
+ not isinstance(value, list) and
+ fn.__name__ in ['test_dataloader', 'val_dataloader']
+ ):
+ value = [value]
+ except AttributeError as e:
+ # Guard against AttributeError suppression. (Issue #142)
+ traceback.print_exc()
+ error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
+ raise RuntimeError(error) from e
+ setattr(self, attr_name, value) # Memoize evaluation.
+ return value
+
+ return _get_data_loader
+
+
+def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):  # pragma: no cover
+    r"""Applies each `module` in :attr:`modules` in parallel on arguments
+    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword)
+    on each of :attr:`devices`.
+
+    Instead of calling the module itself, each worker calls the module's
+    ``training_step``, ``test_step`` or ``validation_step`` depending on the
+    module's ``training`` and ``testing`` flags.
+
+    Args:
+        modules (Module): modules to be parallelized
+        inputs (tensor): inputs to the modules
+        kwargs_tup (tuple of dict, optional): keyword arguments, one dict per
+            module; defaults to empty dicts
+        devices (list of int or torch.device): CUDA devices
+
+    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
+    :attr:`devices` (if given) should all have same length. Moreover, each
+    element of :attr:`inputs` can either be a single object as the only argument
+    to a module, or a collection of positional arguments.
+
+    Returns:
+        list: one step output per input, in input order.
+
+    Raises:
+        Exception: the first (by input index) exception raised inside a worker
+            is re-raised in the calling thread.
+    """
+    assert len(modules) == len(inputs)
+    if kwargs_tup is not None:
+        assert len(modules) == len(kwargs_tup)
+    else:
+        kwargs_tup = ({},) * len(modules)
+    if devices is not None:
+        assert len(modules) == len(devices)
+    else:
+        devices = [None] * len(modules)
+    devices = list(map(lambda x: _get_device_index(x, True), devices))
+    # `results` is shared between worker threads; all writes go through `lock`
+    lock = threading.Lock()
+    results = {}
+    # captured here so every worker thread runs with the caller's grad mode
+    grad_enabled = torch.is_grad_enabled()
+
+    def _worker(i, module, input, kwargs, device=None):
+        torch.set_grad_enabled(grad_enabled)
+        if device is None:
+            # fall back to the device of the first tensor found in the input
+            device = get_a_var(input).get_device()
+        try:
+            with torch.cuda.device(device):
+                # this also avoids accidental slicing of `input` if it is a Tensor
+                if not isinstance(input, (list, tuple)):
+                    input = (input,)
+
+                # ---------------
+                # CHANGE
+                if module.training:
+                    output = module.training_step(*input, **kwargs)
+
+                elif module.testing:
+                    output = module.test_step(*input, **kwargs)
+
+                else:
+                    output = module.validation_step(*input, **kwargs)
+                # ---------------
+
+            with lock:
+                results[i] = output
+        except Exception as e:
+            # store the exception; it is re-raised by the caller after all workers finish
+            with lock:
+                results[i] = e
+
+    # make sure each module knows what training state it's in...
+    # fixes weird bug where copies are out of sync
+    root_m = modules[0]
+    for m in modules[1:]:
+        m.training = root_m.training
+        m.testing = root_m.testing
+
+    if len(modules) > 1:
+        # one thread per module; all are started before any is joined
+        threads = [threading.Thread(target=_worker,
+                                    args=(i, module, input, kwargs, device))
+                   for i, (module, input, kwargs, device) in
+                   enumerate(zip(modules, inputs, kwargs_tup, devices))]
+
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
+    else:
+        # single module: run inline in the calling thread
+        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
+
+    outputs = []
+    for i in range(len(inputs)):
+        output = results[i]
+        if isinstance(output, Exception):
+            raise output
+        outputs.append(output)
+    return outputs
+
+
+def _find_tensors(obj): # pragma: no cover
+ r"""
+ Recursively find all tensors contained in the specified object.
+ """
+ if isinstance(obj, torch.Tensor):
+ return [obj]
+ if isinstance(obj, (list, tuple)):
+ return itertools.chain(*map(_find_tensors, obj))
+ if isinstance(obj, dict):
+ return itertools.chain(*map(_find_tensors, obj.values()))
+ return []
+
+
+class DDP(DistributedDataParallel):
+    """
+    Override the forward call in lightning so it goes to training and validation step respectively
+
+    ``forward`` dispatches to the wrapped module's ``training_step``,
+    ``test_step`` or ``validation_step`` (chosen from the module's
+    ``training`` / ``testing`` flags) instead of calling the module directly.
+    """
+
+    def parallel_apply(self, replicas, inputs, kwargs):
+        # route through the module-level parallel_apply, which performs the step dispatch
+        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
+
+    def forward(self, *inputs, **kwargs):  # pragma: no cover
+        # NOTE(review): relies on DistributedDataParallel internals
+        # (_sync_params, _module_copies, reducer) -- confirm they exist in the pinned torch version
+        self._sync_params()
+        if self.device_ids:
+            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+            if len(self.device_ids) == 1:
+                # --------------
+                # LIGHTNING MOD
+                # --------------
+                # normal
+                # output = self.module(*inputs[0], **kwargs[0])
+                # lightning
+                if self.module.training:
+                    output = self.module.training_step(*inputs[0], **kwargs[0])
+                elif self.module.testing:
+                    output = self.module.test_step(*inputs[0], **kwargs[0])
+                else:
+                    output = self.module.validation_step(*inputs[0], **kwargs[0])
+            else:
+                outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs)
+                output = self.gather(outputs, self.output_device)
+        else:
+            # normal
+            # no device ids: plain module call, without step dispatch
+            output = self.module(*inputs, **kwargs)
+
+        if torch.is_grad_enabled():
+            # We'll return the output object verbatim since it is a freeform
+            # object. We need to find any tensors in this object, though,
+            # because we need to figure out which parameters were used during
+            # this forward pass, to ensure we short circuit reduction for any
+            # unused parameters. Only if `find_unused_parameters` is set.
+            if self.find_unused_parameters:
+                self.reducer.prepare_for_backward(list(_find_tensors(output)))
+            else:
+                self.reducer.prepare_for_backward([])
+        return output
+
+
+class DP(DataParallel):
+ """
+ Override the forward call in lightning so it goes to training and validation step respectively
+ """
+
+ def forward(self, *inputs, **kwargs):
+ if not self.device_ids:
+ return self.module(*inputs, **kwargs)
+
+ for t in itertools.chain(self.module.parameters(), self.module.buffers()):
+ if t.device != self.src_device_obj:
+ raise RuntimeError("module must have its parameters and buffers "
+ "on device {} (device_ids[0]) but found one of "
+ "them on device: {}".format(self.src_device_obj, t.device))
+
+ inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
+ if len(self.device_ids) == 1:
+ # lightning
+ if self.module.training:
+ return self.module.training_step(*inputs[0], **kwargs[0])
+ elif self.module.testing:
+ return self.module.test_step(*inputs[0], **kwargs[0])
+ else:
+ return self.module.validation_step(*inputs[0], **kwargs[0])
+
+ replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
+ outputs = self.parallel_apply(replicas, inputs, kwargs)
+ return self.gather(outputs, self.output_device)
+
+ def parallel_apply(self, replicas, inputs, kwargs):
+ return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
+
+
+class GradientAccumulationScheduler:
+ def __init__(self, scheduling: dict):
+ if scheduling == {}: # empty dict error
+ raise TypeError("Empty dict cannot be interpreted correct")
+
+ for key in scheduling.keys():
+ if not isinstance(key, int) or not isinstance(scheduling[key], int):
+ raise TypeError("All epoches and accumulation factor must be integers")
+
+ minimal_epoch = min(scheduling.keys())
+ if minimal_epoch < 1:
+ msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
+ raise IndexError(msg)
+ elif minimal_epoch != 1: # if user didnt define first epoch accumulation factor
+ scheduling.update({1: 1})
+
+ self.scheduling = scheduling
+ self.epochs = sorted(scheduling.keys())
+
+ def on_epoch_begin(self, epoch, trainer):
+ epoch += 1 # indexing epochs from 1
+ for i in reversed(range(len(self.epochs))):
+ if epoch >= self.epochs[i]:
+ trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
+ break
+
+
+class LatestModelCheckpoint(ModelCheckpoint):
+ def __init__(self, filepath, monitor='val_loss', verbose=0, num_ckpt_keep=5,
+ save_weights_only=False, mode='auto', period=1, prefix='model', save_best=True):
+ super(ModelCheckpoint, self).__init__()
+ self.monitor = monitor
+ self.verbose = verbose
+ self.filepath = filepath
+ os.makedirs(filepath, exist_ok=True)
+ self.num_ckpt_keep = num_ckpt_keep
+ self.save_best = save_best
+ self.save_weights_only = save_weights_only
+ self.period = period
+ self.epochs_since_last_check = 0
+ self.prefix = prefix
+ self.best_k_models = {}
+ # {filename: monitor}
+ self.kth_best_model = ''
+ self.save_top_k = 1
+ self.task = None
+ if mode == 'min':
+ self.monitor_op = np.less
+ self.best = np.Inf
+ self.mode = 'min'
+ elif mode == 'max':
+ self.monitor_op = np.greater
+ self.best = -np.Inf
+ self.mode = 'max'
+ else:
+ if 'acc' in self.monitor or self.monitor.startswith('fmeasure'):
+ self.monitor_op = np.greater
+ self.best = -np.Inf
+ self.mode = 'max'
+ else:
+ self.monitor_op = np.less
+ self.best = np.Inf
+ self.mode = 'min'
+ if os.path.exists(f'{self.filepath}/best_valid.npy'):
+ self.best = np.load(f'{self.filepath}/best_valid.npy')[0]
+
+ def get_all_ckpts(self):
+ return sorted(glob.glob(f'{self.filepath}/{self.prefix}_ckpt_steps_*.ckpt'),
+ key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0]))
+
+ def on_epoch_end(self, epoch, logs=None):
+ logs = logs or {}
+ self.epochs_since_last_check += 1
+ best_filepath = f'{self.filepath}/{self.prefix}_ckpt_best.pt'
+ if self.epochs_since_last_check >= self.period:
+ self.epochs_since_last_check = 0
+ filepath = f'{self.filepath}/{self.prefix}_ckpt_steps_{self.task.global_step}.ckpt'
+ if self.verbose > 0:
+ logging.info(f'Epoch {epoch:05d}@{self.task.global_step}: saving model to {filepath}')
+ self._save_model(filepath)
+ for old_ckpt in self.get_all_ckpts()[self.num_ckpt_keep:]:
+ # TODO: test filesystem calls
+ os.remove(old_ckpt)
+ # subprocess.check_call(f'del "{old_ckpt}"', shell=True)
+ if self.verbose > 0:
+ logging.info(f'Delete ckpt: {os.path.basename(old_ckpt)}')
+ current = logs.get(self.monitor)
+ if current is not None and self.save_best:
+ if self.monitor_op(current, self.best):
+ self.best = current
+ if self.verbose > 0:
+ logging.info(
+ f'Epoch {epoch:05d}@{self.task.global_step}: {self.monitor} reached'
+ f' {current:0.5f} (best {self.best:0.5f}), saving model to'
+ f' {best_filepath} as top 1')
+ self._save_model(best_filepath)
+ np.save(f'{self.filepath}/best_valid.npy', [self.best])
+
+ def _save_model(self,path):
+ return self.save_function(path)
+
+
+
+class BaseTrainer:
+    def __init__(
+            self,
+            logger=True,
+            checkpoint_callback=True,
+            default_save_path=None,
+            gradient_clip_val=0,
+            process_position=0,
+            gpus=-1,
+            log_gpu_memory=None,
+            show_progress_bar=True,
+            track_grad_norm=-1,
+            check_val_every_n_epoch=1,
+            accumulate_grad_batches=1,
+            max_updates=1000,
+            min_epochs=1,
+            val_check_interval=1.0,
+            log_save_interval=100,
+            row_log_interval=10,
+            print_nan_grads=False,
+            weights_summary='full',
+            num_sanity_val_steps=5,
+            resume_from_checkpoint=None,
+    ):
+        """Store the training configuration and initialize trainer state.
+
+        :param logger: logger object; its ``rank`` attribute is set to 0 here
+        :param checkpoint_callback: checkpoint callback object; it receives
+            ``save_function`` and must expose ``filepath``
+        :param gpus: only used for a preliminary ``on_gpu`` value that is
+            overwritten below; the devices actually used come from the
+            ``CUDA_VISIBLE_DEVICES`` environment variable
+        :param accumulate_grad_batches: int or ``{epoch: factor}`` dict
+        :param val_check_interval: stored as given; interpreted when the
+            training dataloader is initialized
+        :param resume_from_checkpoint: optional checkpoint path to restore from
+
+        The remaining parameters are stored on the instance under the same name.
+
+        NOTE(review): the ``True`` defaults of ``logger`` and
+        ``checkpoint_callback`` cannot work, because attributes are assigned on
+        both objects below -- callers have to pass real objects.
+        """
+        self.log_gpu_memory = log_gpu_memory
+        self.gradient_clip_val = gradient_clip_val
+        self.check_val_every_n_epoch = check_val_every_n_epoch
+        self.track_grad_norm = track_grad_norm
+        # preliminary value; overwritten once CUDA_VISIBLE_DEVICES has been parsed
+        self.on_gpu = True if (gpus and torch.cuda.is_available()) else False
+        self.process_position = process_position
+        self.weights_summary = weights_summary
+        self.max_updates = max_updates
+        self.min_epochs = min_epochs
+        self.num_sanity_val_steps = num_sanity_val_steps
+        self.print_nan_grads = print_nan_grads
+        self.resume_from_checkpoint = resume_from_checkpoint
+        self.default_save_path = default_save_path
+
+        # training bookkeeping
+        self.total_batch_idx = 0
+        self.running_loss = []
+        self.avg_loss = 0
+        self.batch_idx = 0
+        self.tqdm_metrics = {}
+        self.callback_metrics = {}
+        self.num_val_batches = 0
+        self.num_training_batches = 0
+        self.num_test_batches = 0
+        self.get_train_dataloader = None
+        self.get_test_dataloaders = None
+        self.get_val_dataloaders = None
+        self.is_iterable_train_dataloader = False
+
+        # training state
+        self.model = None
+        self.testing = False
+        self.disable_validation = False
+        self.lr_schedulers = []
+        self.optimizers = None
+        self.global_step = 0
+        self.current_epoch = 0
+        self.total_batches = 0
+
+        # configure checkpoint callback
+        # the callback saves through this trainer's save_checkpoint
+        self.checkpoint_callback = checkpoint_callback
+        self.checkpoint_callback.save_function = self.save_checkpoint
+        self.weights_save_path = self.checkpoint_callback.filepath
+
+        # accumulated grads
+        self.configure_accumulated_gradients(accumulate_grad_batches)
+
+        # allow int, string and gpu list
+        # device ids are taken from CUDA_VISIBLE_DEVICES, not from the `gpus` argument
+        self.data_parallel_device_ids = [
+            int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != '']
+        if len(self.data_parallel_device_ids) == 0:
+            self.root_gpu = None
+            self.on_gpu = False
+        else:
+            self.root_gpu = self.data_parallel_device_ids[0]
+            self.on_gpu = True
+
+        # distributed backend choice
+        self.use_ddp = False
+        self.use_dp = False
+        self.single_gpu = False
+        # 'ddp' whenever at least one device id is listed, otherwise 'dp'
+        self.distributed_backend = 'ddp' if self.num_gpus > 0 else 'dp'
+        self.set_distributed_mode(self.distributed_backend)
+
+        self.proc_rank = 0
+        self.world_size = 1
+        self.node_rank = 0
+
+        # can't init progress bar here because starting a new process
+        # means the progress_bar won't survive pickling
+        self.show_progress_bar = show_progress_bar
+
+        # logging
+        self.log_save_interval = log_save_interval
+        self.val_check_interval = val_check_interval
+        self.logger = logger
+        self.logger.rank = 0
+        self.row_log_interval = row_log_interval
+
+ @property
+ def num_gpus(self):
+ gpus = self.data_parallel_device_ids
+ if gpus is None:
+ return 0
+ else:
+ return len(gpus)
+
+    @property
+    def data_parallel(self):
+        """Truthy when either the DP or the DDP mode is enabled."""
+        return self.use_dp or self.use_ddp
+
+ def get_model(self):
+ is_dp_module = isinstance(self.model, (DDP, DP))
+ model = self.model.module if is_dp_module else self.model
+ return model
+
+ # -----------------------------
+ # MODEL TRAINING
+ # -----------------------------
+ def fit(self, model):
+ if self.use_ddp:
+ mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,))
+ else:
+ model.model = model.build_model()
+ if not self.testing:
+ self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
+ if self.use_dp:
+ model.cuda(self.root_gpu)
+ model = DP(model, device_ids=self.data_parallel_device_ids)
+ elif self.single_gpu:
+ model.cuda(self.root_gpu)
+ self.run_pretrain_routine(model)
+ return 1
+
+ def init_optimizers(self, optimizers):
+
+ # single optimizer
+ if isinstance(optimizers, Optimizer):
+ return [optimizers], []
+
+ # two lists
+ elif len(optimizers) == 2 and isinstance(optimizers[0], list):
+ optimizers, lr_schedulers = optimizers
+ return optimizers, lr_schedulers
+
+ # single list or tuple
+ elif isinstance(optimizers, list) or isinstance(optimizers, tuple):
+ return optimizers, []
+
+    def run_pretrain_routine(self, model):
+        """Sanity check a few things before starting actual training.
+
+        Links trainer, logger and dataloaders to the model, restores weights,
+        then either runs the test evaluation (when ``self.testing``) or runs
+        an optional validation sanity check followed by ``self.train()``.
+
+        :param model: the model, possibly wrapped for data-parallel execution
+        """
+        # when wrapped for (distributed) data parallel, hooks live on the inner module
+        ref_model = model
+        if self.data_parallel:
+            ref_model = model.module
+
+        # give model convenience properties
+        ref_model.trainer = self
+
+        # set local properties on the model
+        self.copy_trainer_model_properties(ref_model)
+
+        # link up experiment object
+        if self.logger is not None:
+            ref_model.logger = self.logger
+            self.logger.save()
+
+        if self.use_ddp:
+            dist.barrier()
+
+        # set up checkpoint callback
+        # self.configure_checkpoint_callback()
+
+        # transfer data loaders from model
+        self.get_dataloaders(ref_model)
+
+        # track model now.
+        # if cluster resets state, the model will update with the saved weights
+        self.model = model
+
+        # restore training and model before hpc call
+        self.restore_weights(model)
+
+        # when testing requested only run test and return
+        if self.testing:
+            self.run_evaluation(test=True)
+            return
+
+        # check if we should run validation during training
+        self.disable_validation = self.num_val_batches == 0
+
+        # run tiny validation (if validation defined)
+        # to make sure program won't crash during val
+        # NOTE: both hooks fire before the sanity check, even when it is skipped
+        ref_model.on_sanity_check_start()
+        ref_model.on_train_start()
+        if not self.disable_validation and self.num_sanity_val_steps > 0:
+            # init progress bars for validation sanity check
+            pbar = tqdm.tqdm(desc='Validation sanity check',
+                             total=self.num_sanity_val_steps * len(self.get_val_dataloaders()),
+                             leave=False, position=2 * self.process_position,
+                             disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch')
+            self.main_progress_bar = pbar
+            # dummy validation progress bar
+            self.val_progress_bar = tqdm.tqdm(disable=True)
+
+            self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing)
+
+            # close progress bars
+            self.main_progress_bar.close()
+            self.val_progress_bar.close()
+
+        # init progress bar
+        pbar = tqdm.tqdm(leave=True, position=2 * self.process_position,
+                         disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch',
+                         file=sys.stdout)
+        self.main_progress_bar = pbar
+
+        # clear cache before training
+        if self.on_gpu:
+            torch.cuda.empty_cache()
+
+        # CORE TRAINING LOOP
+        self.train()
+
+    def test(self, model):
+        """Switch the trainer to testing mode and run ``fit`` on ``model``.
+
+        The ``testing`` flag is left set afterwards.
+        """
+        self.testing = True
+        self.fit(model)
+
+ @property
+ def training_tqdm_dict(self):
+ tqdm_dict = {
+ 'step': '{}'.format(self.global_step),
+ }
+ tqdm_dict.update(self.tqdm_metrics)
+ return tqdm_dict
+
+ # --------------------
+ # restore ckpt
+ # --------------------
+ def restore_weights(self, model):
+ """
+ To restore weights we have two cases.
+ First, attempt to restore hpc weights. If successful, don't restore
+ other weights.
+
+ Otherwise, try to restore actual weights
+ :param model:
+ :return:
+ """
+ # clear cache before restore
+ if self.on_gpu:
+ torch.cuda.empty_cache()
+
+ if self.resume_from_checkpoint is not None:
+ self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu)
+ else:
+ # restore weights if same exp version
+ self.restore_state_if_checkpoint_exists(model)
+
+ # wait for all models to restore weights
+ if self.use_ddp:
+ # wait for all processes to catch up
+ dist.barrier()
+
+ # clear cache after restore
+ if self.on_gpu:
+ torch.cuda.empty_cache()
+
+ def restore_state_if_checkpoint_exists(self, model):
+ did_restore = False
+
+ # do nothing if there's not dir or callback
+ no_ckpt_callback = (self.checkpoint_callback is None) or (not self.checkpoint_callback)
+ if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath):
+ return did_restore
+
+ # restore trainer state and model if there is a weight for this experiment
+ last_steps = -1
+ last_ckpt_name = None
+
+ # find last epoch
+ checkpoints = os.listdir(self.checkpoint_callback.filepath)
+ for name in checkpoints:
+ if '.ckpt' in name and not name.endswith('part'):
+ if 'steps_' in name:
+ steps = name.split('steps_')[1]
+ steps = int(re.sub('[^0-9]', '', steps))
+
+ if steps > last_steps:
+ last_steps = steps
+ last_ckpt_name = name
+
+ # restore last checkpoint
+ if last_ckpt_name is not None:
+ last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name)
+ self.restore(last_ckpt_path, self.on_gpu)
+ logging.info(f'model and trainer restored from checkpoint: {last_ckpt_path}')
+ did_restore = True
+
+ return did_restore
+
+ def restore(self, checkpoint_path, on_gpu):
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+ # load model state
+ model = self.get_model()
+
+ # load the state_dict on the model automatically
+ model.load_state_dict(checkpoint['state_dict'], strict=False)
+ if on_gpu:
+ model.cuda(self.root_gpu)
+ # load training state (affects trainer only)
+ self.restore_training_state(checkpoint)
+ model.global_step = self.global_step
+ del checkpoint
+
+ try:
+ if dist.is_initialized() and dist.get_rank() > 0:
+ return
+ except Exception as e:
+ print(e)
+ return
+
+ def restore_training_state(self, checkpoint):
+ """
+ Restore trainer state.
+ Model will get its change to update
+ :param checkpoint:
+ :return:
+ """
+ if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
+ # return allowing checkpoints with meta information (global_step, etc)
+ self.checkpoint_callback.best = checkpoint['checkpoint_callback_best']
+
+ self.global_step = checkpoint['global_step']
+ self.current_epoch = checkpoint['epoch']
+
+ if self.testing:
+ return
+
+ # restore the optimizers
+ optimizer_states = checkpoint['optimizer_states']
+ for optimizer, opt_state in zip(self.optimizers, optimizer_states):
+ if optimizer is None:
+ return
+ optimizer.load_state_dict(opt_state)
+
+ # move optimizer to GPU 1 weight at a time
+ # avoids OOM
+ if self.root_gpu is not None:
+ for state in optimizer.state.values():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor):
+ state[k] = v.cuda(self.root_gpu)
+
+ # restore the lr schedulers
+ lr_schedulers = checkpoint['lr_schedulers']
+ for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
+ scheduler.load_state_dict(lrs_state)
+
+ # --------------------
+ # MODEL SAVE CHECKPOINT
+ # --------------------
+ def _atomic_save(self, checkpoint, filepath):
+ """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.
+
+ This will create a temporary checkpoint with a suffix of ``.part``, then copy it to the final location once
+ saving is finished.
+
+ Args:
+ checkpoint (object): The object to save.
+ Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save``
+ accepts.
+ filepath (str|pathlib.Path): The path to which the checkpoint will be saved.
+ This points to the file that the checkpoint will be stored in.
+ """
+ tmp_path = str(filepath) + ".part"
+ torch.save(checkpoint, tmp_path)
+ os.replace(tmp_path, filepath)
+
+ def save_checkpoint(self, filepath):
+ checkpoint = self.dump_checkpoint()
+ self._atomic_save(checkpoint, filepath)
+
+ def dump_checkpoint(self):
+
+ checkpoint = {
+ 'epoch': self.current_epoch,
+ 'global_step': self.global_step
+ }
+
+ if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
+ checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best
+
+ # save optimizers
+ optimizer_states = []
+ for i, optimizer in enumerate(self.optimizers):
+ if optimizer is not None:
+ optimizer_states.append(optimizer.state_dict())
+
+ checkpoint['optimizer_states'] = optimizer_states
+
+ # save lr schedulers
+ lr_schedulers = []
+ for i, scheduler in enumerate(self.lr_schedulers):
+ lr_schedulers.append(scheduler.state_dict())
+
+ checkpoint['lr_schedulers'] = lr_schedulers
+
+ # add the hparams and state_dict from the model
+ model = self.get_model()
+ checkpoint['state_dict'] = model.state_dict()
+ # give the model a chance to add a few things
+ model.on_save_checkpoint(checkpoint)
+
+ return checkpoint
+
+ def copy_trainer_model_properties(self, model):
+ if isinstance(model, DP):
+ ref_model = model.module
+ elif isinstance(model, DDP):
+ ref_model = model.module
+ else:
+ ref_model = model
+
+ for m in [model, ref_model]:
+ m.trainer = self
+ m.on_gpu = self.on_gpu
+ m.use_dp = self.use_dp
+ m.use_ddp = self.use_ddp
+ m.testing = self.testing
+ m.single_gpu = self.single_gpu
+
+ def transfer_batch_to_gpu(self, batch, gpu_id):
+ # base case: object can be directly moved using `cuda` or `to`
+ if callable(getattr(batch, 'cuda', None)):
+ return batch.cuda(gpu_id, non_blocking=True)
+
+ elif callable(getattr(batch, 'to', None)):
+ return batch.to(torch.device('cuda', gpu_id), non_blocking=True)
+
+ # when list
+ elif isinstance(batch, list):
+ for i, x in enumerate(batch):
+ batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
+ return batch
+
+ # when tuple
+ elif isinstance(batch, tuple):
+ batch = list(batch)
+ for i, x in enumerate(batch):
+ batch[i] = self.transfer_batch_to_gpu(x, gpu_id)
+ return tuple(batch)
+
+ # when dict
+ elif isinstance(batch, dict):
+ for k, v in batch.items():
+ batch[k] = self.transfer_batch_to_gpu(v, gpu_id)
+
+ return batch
+
+ # nothing matches, return the value as is without transform
+ return batch
+
+ def set_distributed_mode(self, distributed_backend):
+ # skip for CPU
+ if self.num_gpus == 0:
+ return
+
+ # single GPU case
+ # in single gpu case we allow ddp so we can train on multiple
+ # nodes, 1 gpu per node
+ elif self.num_gpus == 1:
+ self.single_gpu = True
+ self.use_dp = False
+ self.use_ddp = False
+ self.root_gpu = 0
+ self.data_parallel_device_ids = [0]
+ else:
+ if distributed_backend is not None:
+ self.use_dp = distributed_backend == 'dp'
+ self.use_ddp = distributed_backend == 'ddp'
+ elif distributed_backend is None:
+ self.use_dp = True
+ self.use_ddp = False
+
+ logging.info(f'gpu available: {torch.cuda.is_available()}, used: {self.on_gpu}')
+
+    def ddp_train(self, gpu_idx, model):
+        """
+        Entry point of one spawned DDP worker process.
+
+        Sets up rank bookkeeping, initializes the DDP connection, builds the
+        model and its optimizers, moves the model to ``gpu_idx``, wraps it via
+        ``model.configure_ddp`` and continues with ``run_pretrain_routine``.
+
+        :param gpu_idx: index of this process / the GPU it uses
+        :param model: task object providing ``init_ddp_connection``,
+            ``build_model``, ``configure_optimizers`` and ``configure_ddp``
+        :return: None
+        """
+        # otherwise default to node rank 0
+        self.node_rank = 0
+
+        # show progressbar only on progress_rank 0
+        self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_idx == 0
+
+        # determine which process we are and world size
+        if self.use_ddp:
+            self.proc_rank = self.node_rank * self.num_gpus + gpu_idx
+            self.world_size = self.num_gpus
+
+        # let the exp know the rank to avoid overwriting logs
+        if self.logger is not None:
+            self.logger.rank = self.proc_rank
+
+        # set up server using proc 0's ip address
+        # try to init for 20 times at max in case ports are taken
+        # where to store ip_table
+        model.trainer = self
+        model.init_ddp_connection(self.proc_rank, self.world_size)
+
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        model.model = model.build_model()
+        if not self.testing:
+            self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers())
+
+        # MODEL
+        # copy model to each gpu
+        if self.distributed_backend == 'ddp':
+            torch.cuda.set_device(gpu_idx)
+        model.cuda(gpu_idx)
+
+        # set model properties before going into wrapper
+        self.copy_trainer_model_properties(model)
+
+        # override root GPU
+        self.root_gpu = gpu_idx
+
+        if self.distributed_backend == 'ddp':
+            device_ids = [gpu_idx]
+        else:
+            device_ids = None
+
+        # allow user to configure ddp
+        # from here on `model` is whatever configure_ddp returned
+        model = model.configure_ddp(model, device_ids)
+
+        # continue training routine
+        self.run_pretrain_routine(model)
+
+ def resolve_root_node_address(self, root_node):
+ if '[' in root_node:
+ name = root_node.split('[')[0]
+ number = root_node.split(',')[0]
+ if '-' in number:
+ number = number.split('-')[0]
+
+ number = re.sub('[^0-9]', '', number)
+ root_node = name + number
+
+ return root_node
+
+ def log_metrics(self, metrics, grad_norm_dic, step=None):
+ """Logs the metric dict passed in.
+
+ :param metrics:
+ :param grad_norm_dic:
+ """
+ # added metrics by Lightning for convenience
+ metrics['epoch'] = self.current_epoch
+
+ # add norms
+ metrics.update(grad_norm_dic)
+
+ # turn all tensors to scalars
+ scalar_metrics = self.metrics_to_scalars(metrics)
+
+ step = step if step is not None else self.global_step
+ # log actual metrics
+ if self.proc_rank == 0 and self.logger is not None:
+ self.logger.log_metrics(scalar_metrics, step=step)
+ self.logger.save()
+
+ def add_tqdm_metrics(self, metrics):
+ for k, v in metrics.items():
+ if type(v) is torch.Tensor:
+ v = v.item()
+
+ self.tqdm_metrics[k] = v
+
+ def metrics_to_scalars(self, metrics):
+ new_metrics = {}
+ for k, v in metrics.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+
+ if type(v) is dict:
+ v = self.metrics_to_scalars(v)
+
+ new_metrics[k] = v
+
+ return new_metrics
+
+    def process_output(self, output, train=False):
+        """Reduces output according to the training mode.
+
+        Separates loss from logging and tqdm metrics
+
+        :param output: dict returned by a training/validation/test step; the
+            keys ``progress_bar``, ``log``, ``hiddens`` and ``loss`` are
+            treated specially, every other key becomes a callback metric
+        :param train: when True, the loss is extracted and, in DP mode, the
+            metrics and the loss are averaged over GPUs
+        :return: tuple ``(loss, progress_bar_metrics, log_metrics,
+            callback_metrics, hiddens)``; ``loss`` is None unless ``train``
+        :raises RuntimeError: if ``train`` is True and no loss can be found
+        """
+        # ---------------
+        # EXTRACT CALLBACK KEYS
+        # ---------------
+        # all keys not progress_bar or log are candidates for callbacks
+        callback_metrics = {}
+        for k, v in output.items():
+            if k not in ['progress_bar', 'log', 'hiddens']:
+                callback_metrics[k] = v
+
+        if train and self.use_dp:
+            num_gpus = self.num_gpus
+            callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus)
+
+        # NOTE(review): .item() requires single-element tensors -- confirm that
+        # steps never return multi-element tensors under other keys
+        for k, v in callback_metrics.items():
+            if isinstance(v, torch.Tensor):
+                callback_metrics[k] = v.item()
+
+        # ---------------
+        # EXTRACT PROGRESS BAR KEYS
+        # ---------------
+        # any failure here (e.g. a missing key) falls back to an empty dict
+        try:
+            progress_output = output['progress_bar']
+
+            # reduce progress metrics for tqdm when using dp
+            if train and self.use_dp:
+                num_gpus = self.num_gpus
+                progress_output = self.reduce_distributed_output(progress_output, num_gpus)
+
+            progress_bar_metrics = progress_output
+        except Exception:
+            progress_bar_metrics = {}
+
+        # ---------------
+        # EXTRACT LOGGING KEYS
+        # ---------------
+        # extract metrics to log to experiment
+        # any failure here (e.g. a missing key) falls back to an empty dict
+        try:
+            log_output = output['log']
+
+            # reduce progress metrics for tqdm when using dp
+            if train and self.use_dp:
+                num_gpus = self.num_gpus
+                log_output = self.reduce_distributed_output(log_output, num_gpus)
+
+            log_metrics = log_output
+        except Exception:
+            log_metrics = {}
+
+        # ---------------
+        # EXTRACT LOSS
+        # ---------------
+        # if output dict doesn't have the keyword loss
+        # then assume the output=loss if scalar
+        loss = None
+        if train:
+            try:
+                loss = output['loss']
+            except Exception:
+                if type(output) is torch.Tensor:
+                    loss = output
+                else:
+                    raise RuntimeError(
+                        'No `loss` value in the dictionary returned from `model.training_step()`.'
+                    )
+
+            # when using dp need to reduce the loss
+            if self.use_dp:
+                loss = self.reduce_distributed_output(loss, self.num_gpus)
+
+        # ---------------
+        # EXTRACT HIDDEN
+        # ---------------
+        hiddens = output.get('hiddens')
+
+        # use every metric passed in as a candidate for callback
+        callback_metrics.update(progress_bar_metrics)
+        callback_metrics.update(log_metrics)
+
+        # convert tensors to numpy
+        for k, v in callback_metrics.items():
+            if isinstance(v, torch.Tensor):
+                callback_metrics[k] = v.item()
+
+        return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
+
+ def reduce_distributed_output(self, output, num_gpus):
+ if num_gpus <= 1:
+ return output
+
+ # when using DP, we get one output per gpu
+ # average outputs and return
+ if type(output) is torch.Tensor:
+ return output.mean()
+
+ for k, v in output.items():
+ # recurse on nested dics
+ if isinstance(output[k], dict):
+ output[k] = self.reduce_distributed_output(output[k], num_gpus)
+
+ # do nothing when there's a scalar
+ elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0:
+ pass
+
+ # reduce only metrics that have the same number of gpus
+ elif output[k].size(0) == num_gpus:
+ reduced = torch.mean(output[k])
+ output[k] = reduced
+ return output
+
+ def clip_gradients(self):
+ if self.gradient_clip_val > 0:
+ model = self.get_model()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val)
+
+ def print_nan_gradients(self):
+ model = self.get_model()
+ for param in model.parameters():
+ if (param.grad is not None) and torch.isnan(param.grad.float()).any():
+ logging.info(param, param.grad)
+
+ def configure_accumulated_gradients(self, accumulate_grad_batches):
+     """Create ``self.accumulation_scheduler`` from an int or a dict schedule.
+
+     :param accumulate_grad_batches: a dict passed to the scheduler as-is, or an
+         int that is wrapped as ``{1: value}``
+     :raises TypeError: for any other type
+     """
+     self.accumulate_grad_batches = None
+
+     if isinstance(accumulate_grad_batches, dict):
+         schedule = accumulate_grad_batches
+     elif isinstance(accumulate_grad_batches, int):
+         schedule = {1: accumulate_grad_batches}
+     else:
+         raise TypeError("Gradient accumulation supports only int and dict types")
+
+     self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
+
+ def get_dataloaders(self, model):
+ """Bind the model's dataloader getters to the trainer.
+
+ Calls ``init_test_dataloader`` when ``self.testing`` is set, otherwise
+ ``init_train_dataloader`` and ``init_val_dataloader``. The DDP section
+ calls ``dist.barrier()`` and invokes the bound getters once more.
+
+ :param model: object providing the ``*_dataloader`` methods read by the ``init_*`` helpers
+ """
+ if not self.testing:
+ self.init_train_dataloader(model)
+ self.init_val_dataloader(model)
+ else:
+ self.init_test_dataloader(model)
+
+ # NOTE(review): `dist` is presumably torch.distributed - confirm against the imports
+ if self.use_ddp:
+ dist.barrier()
+ if not self.testing:
+ self.get_train_dataloader()
+ self.get_val_dataloaders()
+ else:
+ self.get_test_dataloaders()
+
+ def init_train_dataloader(self, model):
+     """Bind ``model.train_dataloader`` and derive the batch bookkeeping from it.
+
+     Sets ``num_training_batches`` (``inf`` and ``is_iterable_train_dataloader = True``
+     when the loader is not a ``DataLoader``) and ``val_check_batch`` (the int
+     ``val_check_interval`` itself, or that fraction of the epoch, at least 1).
+     """
+     self.fisrt_epoch = True
+     self.get_train_dataloader = model.train_dataloader
+
+     has_length = isinstance(self.get_train_dataloader(), torch.utils.data.DataLoader)
+     if not has_length:
+         self.num_training_batches = float('inf')
+         self.is_iterable_train_dataloader = True
+     else:
+         self.num_training_batches = int(len(self.get_train_dataloader()))
+
+     if isinstance(self.val_check_interval, int):
+         # an int interval is already a batch count
+         self.val_check_batch = self.val_check_interval
+         return
+
+     # otherwise it is a fraction of the epoch
+     self._percent_range_check('val_check_interval')
+     batches_between_checks = int(self.num_training_batches * self.val_check_interval)
+     self.val_check_batch = max(1, batches_between_checks)
+
+ def init_val_dataloader(self, model):
+     """Bind ``model.val_dataloader`` and count the validation batches.
+
+     ``num_val_batches`` stays 0 when the getter returns ``None``; it is the summed
+     length of all loaders when the first one is a ``DataLoader``, else ``inf``.
+     """
+     self.get_val_dataloaders = model.val_dataloader
+     self.num_val_batches = 0
+
+     if self.get_val_dataloaders() is None:
+         return
+
+     first_loader = self.get_val_dataloaders()[0]
+     if not isinstance(first_loader, torch.utils.data.DataLoader):
+         self.num_val_batches = float('inf')
+     else:
+         total = sum(map(len, self.get_val_dataloaders()))
+         self.num_val_batches = int(total)
+
+ def init_test_dataloader(self, model):
+     """Bind ``model.test_dataloader`` and count the test batches.
+
+     ``num_test_batches`` is 0 when the getter returns ``None``; it is the summed
+     length of all loaders when the first one is a ``DataLoader``, else ``inf``.
+     """
+     self.get_test_dataloaders = model.test_dataloader
+     # reset first so a stale or missing count never survives a call that finds
+     # no test loaders (same behavior as init_val_dataloader)
+     self.num_test_batches = 0
+     if self.get_test_dataloaders() is not None:
+         if isinstance(self.get_test_dataloaders()[0], torch.utils.data.DataLoader):
+             self.num_test_batches = sum(len(dataloader) for dataloader in self.get_test_dataloaders())
+             self.num_test_batches = int(self.num_test_batches)
+         else:
+             self.num_test_batches = float('inf')
+
+ def evaluate(self, model, dataloaders, max_batches, test=False):
+ """Run evaluation code.
+
+ Puts ``model`` in eval mode with gradient tracking disabled, feeds the batches of
+ every dataloader to ``self.evaluation_forward`` and hands the collected outputs to
+ ``test_end`` (when ``test`` is set) or ``validation_end`` of ``self.get_model()``.
+ Train mode and gradient tracking are switched back on before returning.
+
+ :param model: PT model
+ :param dataloaders: list of PT dataloaders
+ :param max_batches: Scalar; batches whose index reaches this value are not evaluated
+ :param test: boolean; selects the test hooks and progress bar instead of the validation ones
+ :return: the value returned by ``test_end`` / ``validation_end``
+ """
+ # enable eval mode
+ model.zero_grad()
+ model.eval()
+
+ # copy properties for forward overrides
+ self.copy_trainer_model_properties(model)
+
+ # disable gradients to save memory
+ torch.set_grad_enabled(False)
+
+ if test:
+ self.get_model().test_start()
+ # bookkeeping
+ outputs = []
+
+ # run evaluation over every dataloader
+ for dataloader_idx, dataloader in enumerate(dataloaders):
+ dl_outputs = []
+ for batch_idx, batch in enumerate(dataloader):
+
+ # empty batches are skipped; they still advance batch_idx
+ if batch is None: # pragma: no cover
+ continue
+
+ # stop short when on fast_dev_run (sets max_batch=1)
+ if batch_idx >= max_batches:
+ break
+
+ # -----------------
+ # RUN EVALUATION STEP
+ # -----------------
+ output = self.evaluation_forward(model,
+ batch,
+ batch_idx,
+ dataloader_idx,
+ test)
+
+ # track outputs for collation
+ dl_outputs.append(output)
+
+ # batch done
+ if test:
+ self.test_progress_bar.update(1)
+ else:
+ self.val_progress_bar.update(1)
+ outputs.append(dl_outputs)
+
+ # with a single dataloader don't pass an array
+ if len(dataloaders) == 1:
+ outputs = outputs[0]
+
+ # give model a chance to do something with the outputs (and method defined)
+ # note: `model` is rebound here, so the end hook and train() below act on self.get_model()
+ model = self.get_model()
+ if test:
+ eval_results_ = model.test_end(outputs)
+ else:
+ eval_results_ = model.validation_end(outputs)
+ eval_results = eval_results_
+
+ # enable train mode again
+ model.train()
+
+ # re-enable gradient tracking for training
+ torch.set_grad_enabled(True)
+
+ return eval_results
+
+ def run_evaluation(self, test=False):
+ """Run a validation pass, or a test pass when ``test`` is set, and publish its metrics.
+
+ Creates the progress bar, calls ``self.evaluate`` on ``self.model``, routes the
+ processed results to the progress bar, the logger and ``self.callback_metrics``,
+ fires the performance-check hooks and calls the checkpoint callback's
+ ``on_epoch_end`` (see the condition at the end).
+
+ :param test: boolean; use the test dataloaders and ``num_test_batches`` instead of
+ the validation ones
+ """
+ # fetch the model and fire the pre-check hook
+ model = self.get_model()
+ model.on_pre_performance_check()
+
+ # select dataloaders
+ if test:
+ dataloaders = self.get_test_dataloaders()
+ max_batches = self.num_test_batches
+ else:
+ # val
+ dataloaders = self.get_val_dataloaders()
+ max_batches = self.num_val_batches
+
+ # init validation or test progress bar
+ # main progress bar will already be closed when testing so initial position is free
+ position = 2 * self.process_position + (not test)
+ desc = 'Testing' if test else 'Validating'
+ pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
+ disable=not self.show_progress_bar, dynamic_ncols=True,
+ unit='batch', file=sys.stdout)
+ # stored as self.test_progress_bar or self.val_progress_bar
+ setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
+
+ # run evaluation
+ eval_results = self.evaluate(self.model,
+ dataloaders,
+ max_batches,
+ test)
+ if eval_results is not None:
+ # the first and last items of the processed tuple are not needed here
+ _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
+ eval_results)
+
+ # add metrics to prog bar
+ self.add_tqdm_metrics(prog_bar_metrics)
+
+ # log metrics
+ self.log_metrics(log_metrics, {})
+
+ # track metrics for callbacks
+ self.callback_metrics.update(callback_metrics)
+
+ # hook
+ model.on_post_performance_check()
+
+ # add model specific metrics
+ tqdm_metrics = self.training_tqdm_dict
+ if not test:
+ self.main_progress_bar.set_postfix(**tqdm_metrics)
+
+ # close progress bar
+ if test:
+ self.test_progress_bar.close()
+ else:
+ self.val_progress_bar.close()
+
+ # model checkpointing (rank 0 only, never while testing)
+ if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
+ self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
+ logs=self.callback_metrics)
+
+ def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
+     """Run one test/validation step for ``batch`` and return its output.
+
+     ``dataloader_idx`` is forwarded only when more than one test (or validation)
+     dataloader exists. Under DP/DDP the wrapped ``model`` is called directly;
+     otherwise ``test_step`` / ``validation_step`` is called, after moving the batch
+     to the root GPU when ``self.single_gpu`` is set.
+     """
+     # make dataloader_idx arg in validation_step optional
+     step_args = [batch, batch_idx]
+     get_loaders = self.get_test_dataloaders if test else self.get_val_dataloaders
+     if len(get_loaders()) > 1:
+         step_args.append(dataloader_idx)
+
+     # handle DP, DDP forward
+     if self.use_ddp or self.use_dp:
+         return model(*step_args)
+
+     # single GPU: put inputs on gpu manually
+     if self.single_gpu:
+         device_ids = self.data_parallel_device_ids
+         root_gpu = device_ids[0] if isinstance(device_ids, list) else 0
+         step_args[0] = self.transfer_batch_to_gpu(batch, root_gpu)
+
+     # CPU path, also used by the single-GPU case after the transfer
+     step_fn = model.test_step if test else model.validation_step
+     return step_fn(*step_args)
+
+ def train(self):
+ """Run the epoch loop, starting at ``self.current_epoch``.
+
+ Each epoch updates the epoch counters on the trainer and the model, computes
+ ``self.total_batches`` for the progress bar, notifies the accumulation scheduler,
+ calls ``self.run_training_epoch()`` and steps the LR schedulers. Afterwards the
+ main progress bar is closed, ``model.on_train_end()`` is called and the logger,
+ when present, is finalized with "success".
+ """
+ model = self.get_model()
+ # run all epochs (the loop bound is a fixed 1000000)
+ for epoch in range(self.current_epoch, 1000000):
+ # set seed for distributed sampler (enables shuffling for each epoch)
+ if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
+ self.get_train_dataloader().sampler.set_epoch(epoch)
+
+ # get model
+ model = self.get_model()
+
+ # update training progress in trainer and model
+ model.current_epoch = epoch
+ self.current_epoch = epoch
+
+ total_val_batches = 0
+ if not self.disable_validation:
+ # val can be checked multiple times in epoch
+ is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
+ val_checks_per_epoch = self.num_training_batches // self.val_check_batch
+ val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
+ total_val_batches = self.num_val_batches * val_checks_per_epoch
+
+ # total batches includes multiple val checks
+ self.total_batches = self.num_training_batches + total_val_batches
+ self.batch_loss_value = 0 # accumulated grads
+
+ # NOTE(review): num_iterations is assigned here but not read again in this method
+ if self.is_iterable_train_dataloader:
+ # for iterable train loader, the progress bar never ends
+ num_iterations = None
+ else:
+ num_iterations = self.total_batches
+
+ # update the progress bar description (empty for an iterable train loader)
+ desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
+ self.main_progress_bar.set_description(desc)
+
+ # changing gradient according accumulation_scheduler
+ self.accumulation_scheduler.on_epoch_begin(epoch, self)
+
+ # -----------------
+ # RUN TNG EPOCH
+ # -----------------
+ self.run_training_epoch()
+
+ # update LR schedulers
+ if self.lr_schedulers is not None:
+ for lr_scheduler in self.lr_schedulers:
+ lr_scheduler.step(epoch=self.current_epoch)
+
+ self.main_progress_bar.close()
+
+ model.on_train_end()
+
+ if self.logger is not None:
+ self.logger.finalize("success")
+
+ def run_training_epoch(self):
+ """Iterate the train dataloader once, running one training batch per step.
+
+ Per batch: runs ``self.run_training_batch``, optionally runs evaluation, saves and
+ logs metrics at their configured intervals and advances ``self.global_step`` and
+ ``self.total_batch_idx``. The ``on_epoch_start`` / ``on_epoch_end`` model hooks are
+ called when implemented.
+
+ Note: once ``self.global_step`` exceeds ``self.max_updates`` this method prints a
+ message and calls ``exit()``, which ends the process rather than returning.
+ """
+ # before epoch hook
+ if self.is_function_implemented('on_epoch_start'):
+ model = self.get_model()
+ model.on_epoch_start()
+
+ # run epoch
+ for batch_idx, batch in enumerate(self.get_train_dataloader()):
+ # stop epoch if we limited the number of training batches
+ if batch_idx >= self.num_training_batches:
+ break
+
+ self.batch_idx = batch_idx
+
+ model = self.get_model()
+ model.global_step = self.global_step
+
+ # ---------------
+ # RUN TRAIN STEP
+ # ---------------
+ output = self.run_training_batch(batch, batch_idx)
+ batch_result, grad_norm_dic, batch_step_metrics = output
+
+ # when returning -1 from train_step, we end epoch early
+ early_stop_epoch = batch_result == -1
+
+ # ---------------
+ # RUN VAL STEP
+ # ---------------
+ # `fisrt_epoch` (sic) is set to True by init_train_dataloader and cleared right
+ # below, so it only suppresses the check for the first batch seen after that call
+ should_check_val = (
+ not self.disable_validation and self.global_step % self.val_check_batch == 0 and not self.fisrt_epoch)
+ self.fisrt_epoch = False
+
+ if should_check_val:
+ self.run_evaluation(test=self.testing)
+
+ # when logs should be saved
+ should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
+ if should_save_log:
+ if self.proc_rank == 0 and self.logger is not None:
+ self.logger.save()
+
+ # when metrics should be logged
+ should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
+ if should_log_metrics:
+ # logs user requested information to logger
+ self.log_metrics(batch_step_metrics, grad_norm_dic)
+
+ self.global_step += 1
+ self.total_batch_idx += 1
+
+ # end epoch early
+ # stop when the flag is changed or we've gone past the amount
+ # requested in the batches
+ if early_stop_epoch:
+ break
+ # hard stop: terminates the whole process, not just this epoch
+ if self.global_step > self.max_updates:
+ print("| Training end..")
+ exit()
+
+ # epoch end hook
+ if self.is_function_implemented('on_epoch_end'):
+ model = self.get_model()
+ model.on_epoch_end()
+
+ def run_training_batch(self, batch, batch_idx):
+ """Run forward/backward for one batch and step the optimizers when due.
+
+ For every optimizer a closure runs ``self.training_forward``, scales the loss by
+ ``self.accumulate_grad_batches`` and calls ``model.backward``. The optimizer step is
+ taken when ``(self.batch_idx + 1)`` is a multiple of ``self.accumulate_grad_batches``.
+
+ :param batch: the batch to train on; ``None`` is skipped
+ :param batch_idx: index of the batch within the epoch
+ :return: ``(status, grad_norm_dic, log_metrics)``; ``status`` is ``-1`` when
+ ``on_batch_start`` returned ``-1`` and ``0`` otherwise
+ """
+ # track grad norms
+ grad_norm_dic = {}
+
+ # track all metrics for callbacks
+ all_callback_metrics = []
+
+ # track metrics to log
+ all_log_metrics = []
+
+ if batch is None:
+ return 0, grad_norm_dic, {}
+
+ # hook
+ if self.is_function_implemented('on_batch_start'):
+ model_ref = self.get_model()
+ response = model_ref.on_batch_start(batch)
+
+ if response == -1:
+ return -1, grad_norm_dic, {}
+
+ # the batch is processed as a single split
+ splits = [batch]
+ self.hiddens = None
+ for split_idx, split_batch in enumerate(splits):
+ self.split_idx = split_idx
+
+ # call training_step once per optimizer
+ for opt_idx, optimizer in enumerate(self.optimizers):
+ if optimizer is None:
+ continue
+ # make sure only the gradients of the current optimizer's parameters are calculated
+ # in the training step to prevent dangling gradients in multiple-optimizer setup.
+ if len(self.optimizers) > 1:
+ for param in self.get_model().parameters():
+ param.requires_grad = False
+ for group in optimizer.param_groups:
+ for param in group['params']:
+ param.requires_grad = True
+
+ # wrap the forward step in a closure so second order methods work
+ def optimizer_closure():
+ # forward pass
+ output = self.training_forward(
+ split_batch, batch_idx, opt_idx, self.hiddens)
+
+ # unpack the 5 items returned by training_forward
+ closure_loss = output[0]
+ progress_bar_metrics = output[1]
+ log_metrics = output[2]
+ callback_metrics = output[3]
+ self.hiddens = output[4]
+ if closure_loss is None:
+ return None
+
+ # accumulate loss
+ # (if accumulate_grad_batches = 1 no effect)
+ closure_loss = closure_loss / self.accumulate_grad_batches
+
+ # backward pass
+ model_ref = self.get_model()
+ if closure_loss.requires_grad:
+ model_ref.backward(closure_loss, optimizer)
+
+ # track metrics for callbacks
+ all_callback_metrics.append(callback_metrics)
+
+ # track progress bar metrics
+ self.add_tqdm_metrics(progress_bar_metrics)
+ all_log_metrics.append(log_metrics)
+
+ # insert after step hook
+ if self.is_function_implemented('on_after_backward'):
+ model_ref = self.get_model()
+ model_ref.on_after_backward()
+
+ return closure_loss
+
+ # calculate loss
+ loss = optimizer_closure()
+ if loss is None:
+ continue
+
+ # nan grads
+ if self.print_nan_grads:
+ self.print_nan_gradients()
+
+ # track total loss for logging (avoid mem leaks)
+ self.batch_loss_value += loss.item()
+
+ # gradient update with accumulated gradients
+ # (uses self.batch_idx, which run_training_epoch sets before calling this method)
+ if (self.batch_idx + 1) % self.accumulate_grad_batches == 0:
+
+ # track gradient norms when requested
+ if batch_idx % self.row_log_interval == 0:
+ if self.track_grad_norm > 0:
+ model = self.get_model()
+ grad_norm_dic = model.grad_norm(
+ self.track_grad_norm)
+
+ # clip gradients
+ self.clip_gradients()
+
+ # calls .step(), .zero_grad()
+ # override function to modify this behavior
+ model = self.get_model()
+ model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx)
+
+ # calculate running loss for display (mean of the last 100 entries)
+ self.running_loss.append(self.batch_loss_value)
+ self.batch_loss_value = 0
+ self.avg_loss = np.mean(self.running_loss[-100:])
+
+ # activate batch end hook
+ if self.is_function_implemented('on_batch_end'):
+ model = self.get_model()
+ model.on_batch_end()
+
+ # update progress bar
+ self.main_progress_bar.update(1)
+ self.main_progress_bar.set_postfix(**self.training_tqdm_dict)
+
+ # collapse all metrics into one dict (later entries win on duplicate keys)
+ all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
+
+ # track all metrics for callbacks
+ self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()})
+
+ return 0, grad_norm_dic, all_log_metrics
+
+ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
+     """
+     Handle forward for each training case (distributed, single gpu, etc...)
+     :param batch: the batch handed to ``training_step``
+     :param batch_idx: index of the batch
+     :param opt_idx: index of the optimizer the step is run for
+     :param hiddens: accepted for interface compatibility; not read by this method
+     :return: the result of ``self.process_output(output, train=True)``
+     """
+     # ---------------
+     # FORWARD
+     # ---------------
+     step_args = [batch, batch_idx, opt_idx]
+
+     if self.use_ddp or self.use_dp:
+         # distributed forward goes through the wrapped model
+         output = self.model(*step_args)
+     else:
+         if self.single_gpu:
+             # single GPU forward: move a shallow copy of the batch to the device
+             device_ids = self.data_parallel_device_ids
+             gpu_id = device_ids[0] if isinstance(device_ids, list) else 0
+             step_args[0] = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id)
+         # single GPU and CPU forward both call training_step directly
+         output = self.model.training_step(*step_args)
+
+     # allow any mode to define training_end
+     end_output = self.get_model().training_end(output)
+     if end_output is not None:
+         output = end_output
+
+     # format and reduce outputs accordingly
+     return self.process_output(output, train=True)
+
+ # ---------------
+ # Utils
+ # ---------------
+ def is_function_implemented(self, f_name):
+     """Return True when the current model has a callable attribute named ``f_name``."""
+     candidate = getattr(self.get_model(), f_name, None)
+     return callable(candidate)
+
+ def _percent_range_check(self, name):
+     """Raise ValueError unless the attribute called ``name`` lies within [0.0, 1.0]."""
+     value = getattr(self, name)
+     hint = (
+         " If you want to disable validation set `val_percent_check` to 0.0 instead."
+         if name == "val_check_interval" else ""
+     )
+     # the message is built before the range test, exactly like the check it guards
+     message = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}." + hint
+
+     if 0. <= value <= 1.:
+         return
+     raise ValueError(message)
diff --git a/utils/plot.py b/utils/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdca62a8cd80869c707890cd9febd39966cd3658
--- /dev/null
+++ b/utils/plot.py
@@ -0,0 +1,56 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+# Line colours for the F0 curves drawn by spec_f0_to_figure; the i-th curve uses entry i.
+LINE_COLORS = ['w', 'r', 'y', 'cyan', 'm', 'b', 'lime']
+
+
+def spec_to_figure(spec, vmin=None, vmax=None):
+    """Draw a spectrogram as a pseudocolor plot and return the figure.
+
+    :param spec: 2-D ``torch.Tensor`` or array; its transpose is what gets drawn
+    :param vmin: lower colour limit forwarded to ``plt.pcolor``
+    :param vmax: upper colour limit forwarded to ``plt.pcolor``
+    :return: the created matplotlib figure (12 x 6 inches)
+    """
+    if isinstance(spec, torch.Tensor):
+        # detach first: .numpy() refuses tensors that still require grad
+        spec = spec.detach().cpu().numpy()
+    fig = plt.figure(figsize=(12, 6))
+    plt.pcolor(spec.T, vmin=vmin, vmax=vmax)
+    return fig
+
+
+def spec_f0_to_figure(spec, f0s, figsize=None):
+    """Plot a spectrogram with one or more F0 curves drawn on top of it.
+
+    :param spec: 2-D spectrogram (``torch.Tensor`` or array); its transpose is drawn
+    :param f0s: dict mapping a legend label to an F0 curve (``torch.Tensor`` or array)
+    :param figsize: optional figure size; defaults to ``(12, 6)``
+    :return: the created matplotlib figure
+    """
+    max_y = spec.shape[1]
+    if isinstance(spec, torch.Tensor):
+        spec = spec.detach().cpu().numpy()
+    # Convert each curve on its own, so tensor curves also work next to an array
+    # spectrogram, then divide by 10.
+    # NOTE(review): the /10 presumably maps F0 onto the spectrogram's bin axis - confirm
+    f0s = {
+        k: (f0.detach().cpu().numpy() if isinstance(f0, torch.Tensor) else f0) / 10
+        for k, f0 in f0s.items()
+    }
+    fig = plt.figure(figsize=(12, 6) if figsize is None else figsize)
+    plt.pcolor(spec.T)
+    for i, (k, f0) in enumerate(f0s.items()):
+        # cycle through the palette instead of failing when there are more curves than colours
+        color = LINE_COLORS[i % len(LINE_COLORS)]
+        # curves are clipped to the height of the plot
+        plt.plot(f0.clip(0, max_y), label=k, c=color, linewidth=1, alpha=0.8)
+    plt.legend()
+    return fig
+
+
+def dur_to_figure(dur_gt, dur_pred, txt):
+    """Compare ground-truth and predicted durations as boundary lines with labels.
+
+    Durations are accumulated into boundary positions. Ground truth is drawn in
+    blue on rows 0-10, the prediction in red on rows 10-20, and ``txt[i]`` labels
+    the i-th boundary of both.
+
+    :param dur_gt: tensor of ground-truth durations
+    :param dur_pred: tensor of predicted durations
+    :param txt: sequence of labels, one per duration
+    :return: the created matplotlib figure
+    """
+    gt_bounds = dur_gt.long().cpu().numpy()
+    pred_bounds = dur_pred.long().cpu().numpy()
+    gt_bounds = np.cumsum(gt_bounds)
+    pred_bounds = np.cumsum(pred_bounds)
+
+    fig = plt.figure(figsize=(12, 6))
+    n_bounds = len(gt_bounds)
+    for idx in range(n_bounds):
+        # stagger the labels over 8 rows so neighbours do not overlap
+        row = 1 + idx % 8
+        plt.text(gt_bounds[idx], row, txt[idx])
+        plt.text(pred_bounds[idx], 10 + row, txt[idx])
+        plt.vlines(gt_bounds[idx], 0, 10, colors='b')  # blue is gt
+        plt.vlines(pred_bounds[idx], 10, 20, colors='r')  # red is pred
+    return fig
+
+
+def f0_to_figure(f0_gt, f0_cwt=None, f0_pred=None):
+    """Plot the ground-truth F0 curve and, when given, the CWT and predicted curves.
+
+    :param f0_gt: tensor with the ground-truth F0 (red, label 'gt')
+    :param f0_cwt: optional tensor (blue, label 'cwt')
+    :param f0_pred: optional tensor (green, label 'pred')
+    :return: the created matplotlib figure
+    """
+    fig = plt.figure()
+    # detach before .numpy(): it raises for tensors that still require grad
+    f0_gt = f0_gt.detach().cpu().numpy()
+    plt.plot(f0_gt, color='r', label='gt')
+    if f0_cwt is not None:
+        f0_cwt = f0_cwt.detach().cpu().numpy()
+        plt.plot(f0_cwt, color='b', label='cwt')
+    if f0_pred is not None:
+        f0_pred = f0_pred.detach().cpu().numpy()
+        plt.plot(f0_pred, color='green', label='pred')
+    plt.legend()
+    return fig
diff --git a/utils/text_encoder.py b/utils/text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9e0758abc7b4e1f452481cba9715df08ceab543
--- /dev/null
+++ b/utils/text_encoder.py
@@ -0,0 +1,304 @@
+import re
+import six
+from six.moves import range # pylint: disable=redefined-builtin
+
+# Reserved special tokens. They must be distinct strings: the ids below are derived
+# with list.index(), so identical placeholders would all collapse onto id 0.
+PAD = "<pad>"
+EOS = "<EOS>"
+UNK = "<UNK>"
+SEG = "|"
+RESERVED_TOKENS = [PAD, EOS, UNK]
+NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
+PAD_ID = RESERVED_TOKENS.index(PAD)  # Normally 0
+EOS_ID = RESERVED_TOKENS.index(EOS)  # Normally 1
+UNK_ID = RESERVED_TOKENS.index(UNK)  # Normally 2
+
+# Byte-string form of every reserved token, indexed by reserved id.
+if six.PY2:
+    RESERVED_TOKENS_BYTES = RESERVED_TOKENS
+else:
+    RESERVED_TOKENS_BYTES = [bytes(token, "ascii") for token in RESERVED_TOKENS]
+
+# Regular expression for unescaping token strings.
+# '\u' is converted to '_'
+# '\\' is converted to '\'
+# '\213;' is converted to unichr(213)
+# NOTE(review): neither name below is referenced elsewhere in this module.
+_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
+_ESCAPE_CHARS = set(u"\\_u;0123456789")
+
+
+def strip_ids(ids, ids_to_strip):
+    """Return ``ids`` as a new list without its trailing run of ``ids_to_strip`` members."""
+    kept = list(ids)
+    end = len(kept)
+    # walk back from the tail until an id that must be kept is found
+    while end > 0 and kept[end - 1] in ids_to_strip:
+        end -= 1
+    del kept[end:]
+    return kept
+
+
+class TextEncoder(object):
+    """Base class mapping between int ids and human readable strings."""
+
+    def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
+        self._num_reserved_ids = num_reserved_ids
+
+    @property
+    def num_reserved_ids(self):
+        """Number of ids at the start of the id range that are reserved."""
+        return self._num_reserved_ids
+
+    def encode(self, s):
+        """Turn a whitespace-separated string of integers into a list of ids.
+
+        Every integer is shifted up by ``num_reserved_ids`` so that the ids fall in
+        the range [num_reserved_ids, vocab_size); ids [0, num_reserved_ids) are
+        reserved. EOS is not appended.
+
+        Args:
+          s: human-readable string to be converted.
+
+        Returns:
+          ids: list of integers
+        """
+        ids = []
+        for word in s.split():
+            ids.append(int(word) + self._num_reserved_ids)
+        return ids
+
+    def decode(self, ids, strip_extraneous=False):
+        """Turn a sequence of int ids into a space-joined human-readable string.
+
+        EOS is not expected in ids.
+
+        Args:
+          ids: list of integers to be converted.
+          strip_extraneous: bool, whether to strip off extraneous tokens
+            (EOS and PAD).
+
+        Returns:
+          s: human-readable string.
+        """
+        if strip_extraneous:
+            reserved = list(range(self._num_reserved_ids or 0))
+            ids = strip_ids(ids, reserved)
+        pieces = self.decode_list(ids)
+        return " ".join(pieces)
+
+    def decode_list(self, ids):
+        """Turn a sequence of int ids into the string form of each id.
+
+        Reserved ids map to their reserved token; every other id maps to the
+        decimal string of the id minus ``num_reserved_ids``. Useful for showing
+        individual input/output ids in a human readable format.
+
+        Args:
+          ids: list of integers to be converted.
+
+        Returns:
+          strs: list of human-readable string.
+        """
+        decoded = []
+        for id_ in ids:
+            is_reserved = 0 <= id_ < self._num_reserved_ids
+            if is_reserved:
+                decoded.append(RESERVED_TOKENS[int(id_)])
+            else:
+                decoded.append(id_ - self._num_reserved_ids)
+        return list(map(str, decoded))
+
+    @property
+    def vocab_size(self):
+        """Size of the vocabulary; subclasses must provide it."""
+        raise NotImplementedError()
+
+
+class ByteTextEncoder(TextEncoder):
+    """Encodes each byte to an id. For 8-bit strings only."""
+
+    def encode(self, s):
+        """Return one id per UTF-8 byte of ``s``, shifted past the reserved ids."""
+        offset = self._num_reserved_ids
+        if not six.PY2:
+            # Python3: explicitly convert to UTF-8
+            return [byte + offset for byte in s.encode("utf-8")]
+        if isinstance(s, unicode):
+            s = s.encode("utf-8")
+        return [ord(char) + offset for char in s]
+
+    def _ids_to_byte_strings(self, ids):
+        """Map each id to its reserved byte token or to the single byte it encodes."""
+        offset = self._num_reserved_ids
+        to_byte = six.int2byte
+        pieces = []
+        for id_ in ids:
+            if 0 <= id_ < offset:
+                pieces.append(RESERVED_TOKENS_BYTES[int(id_)])
+            else:
+                pieces.append(to_byte(id_ - offset))
+        return pieces
+
+    def decode(self, ids, strip_extraneous=False):
+        """Join the bytes for ``ids`` into a string, optionally dropping trailing reserved ids."""
+        if strip_extraneous:
+            reserved = list(range(self._num_reserved_ids or 0))
+            ids = strip_ids(ids, reserved)
+        pieces = self._ids_to_byte_strings(ids)
+        if six.PY2:
+            return "".join(pieces)
+        # Python3: join byte arrays and then decode string
+        joined = b"".join(pieces)
+        return joined.decode("utf-8", "replace")
+
+    def decode_list(self, ids):
+        """Return the per-id byte strings without joining or decoding them."""
+        return self._ids_to_byte_strings(ids)
+
+    @property
+    def vocab_size(self):
+        """256 byte values plus the reserved ids."""
+        return 2**8 + self._num_reserved_ids
+
+
+class ByteTextEncoderWithEos(ByteTextEncoder):
+    """Byte-level encoder whose ``encode`` output always ends with the EOS id."""
+
+    def encode(self, s):
+        """Encode ``s`` byte by byte, then append ``EOS_ID``."""
+        byte_ids = super(ByteTextEncoderWithEos, self).encode(s)
+        return byte_ids + [EOS_ID]
+
+
+class TokenTextEncoder(TextEncoder):
+    """Encoder based on a user-supplied vocabulary (file or list)."""
+
+    def __init__(self,
+                 vocab_filename,
+                 reverse=False,
+                 vocab_list=None,
+                 replace_oov=None,
+                 num_reserved_ids=NUM_RESERVED_TOKENS):
+        """Initialize from a file or list, one token per line.
+
+        Handling of reserved tokens works as follows:
+        - When initializing from a list, we add reserved tokens to the vocab.
+        - When initializing from a file, we do not add reserved tokens to the vocab.
+        - When saving vocab files, we save reserved tokens to the file.
+
+        Args:
+          vocab_filename: If not None, the full filename to read vocab from. If this
+             is not None, then vocab_list should be None.
+          reverse: Boolean indicating if tokens should be reversed during encoding
+             and decoding.
+          vocab_list: If not None, a list of elements of the vocabulary. If this is
+             not None, then vocab_filename should be None.
+          replace_oov: If not None, every out-of-vocabulary token seen when
+             encoding will be replaced by this string (which must be in vocab).
+          num_reserved_ids: Number of IDs to save for reserved tokens like <EOS>.
+
+        Raises:
+          KeyError: if PAD, EOS or UNK is missing from the resulting vocabulary.
+        """
+        super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids)
+        self._reverse = reverse
+        self._replace_oov = replace_oov
+        if vocab_filename:
+            self._init_vocab_from_file(vocab_filename)
+        else:
+            assert vocab_list is not None
+            self._init_vocab_from_list(vocab_list)
+        self.pad_index = self._token_to_id[PAD]
+        self.eos_index = self._token_to_id[EOS]
+        self.unk_index = self._token_to_id[UNK]
+        # the segment token is optional; fall back to EOS when the vocab lacks it
+        self.seg_index = self._token_to_id[SEG] if SEG in self._token_to_id else self.eos_index
+
+    def encode(self, s):
+        """Converts a space-separated string of tokens to a list of ids.
+
+        Raises KeyError for an unknown token unless ``replace_oov`` was given.
+        """
+        sentence = s
+        tokens = sentence.strip().split()
+        if self._replace_oov is not None:
+            tokens = [t if t in self._token_to_id else self._replace_oov
+                      for t in tokens]
+        ret = [self._token_to_id[tok] for tok in tokens]
+        return ret[::-1] if self._reverse else ret
+
+    def decode(self, ids, strip_eos=False, strip_padding=False):
+        """Convert ids back to a space-joined token string.
+
+        Args:
+          ids: sequence of integer ids.
+          strip_eos: if True, cut ``ids`` at the first EOS id.
+          strip_padding: if True, cut ``ids`` at the first PAD id (applied first).
+        """
+        if strip_padding and self.pad() in list(ids):
+            pad_pos = list(ids).index(self.pad())
+            ids = ids[:pad_pos]
+        if strip_eos and self.eos() in list(ids):
+            eos_pos = list(ids).index(self.eos())
+            ids = ids[:eos_pos]
+        return " ".join(self.decode_list(ids))
+
+    def decode_list(self, ids):
+        """Return the token for every id (in reverse order when ``reverse`` was set)."""
+        seq = reversed(ids) if self._reverse else ids
+        return [self._safe_id_to_token(i) for i in seq]
+
+    @property
+    def vocab_size(self):
+        return len(self._id_to_token)
+
+    def __len__(self):
+        return self.vocab_size
+
+    def _safe_id_to_token(self, idx):
+        """Look up a token, falling back to the placeholder ``ID_<idx>`` for unknown ids."""
+        return self._id_to_token.get(idx, "ID_%d" % idx)
+
+    def _init_vocab_from_file(self, filename):
+        """Load vocab from a file.
+
+        Args:
+          filename: The file to load vocabulary from.
+        """
+        # NOTE(review): the file is read with the platform default encoding
+        with open(filename) as f:
+            tokens = [token.strip() for token in f.readlines()]
+
+        def token_gen():
+            for token in tokens:
+                yield token
+
+        self._init_vocab(token_gen(), add_reserved_tokens=False)
+
+    def _init_vocab_from_list(self, vocab_list):
+        """Initialize tokens from a list of tokens.
+
+        It is ok if reserved tokens appear in the vocab list. They will be
+        removed. The set of tokens in vocab_list should be unique.
+
+        Args:
+          vocab_list: A list of tokens.
+        """
+        def token_gen():
+            for token in vocab_list:
+                if token not in RESERVED_TOKENS:
+                    yield token
+
+        self._init_vocab(token_gen())
+
+    def _init_vocab(self, token_generator, add_reserved_tokens=True):
+        """Initialize vocabulary with tokens from token_generator."""
+
+        self._id_to_token = {}
+        non_reserved_start_index = 0
+
+        if add_reserved_tokens:
+            self._id_to_token.update(enumerate(RESERVED_TOKENS))
+            non_reserved_start_index = len(RESERVED_TOKENS)
+
+        self._id_to_token.update(
+            enumerate(token_generator, start=non_reserved_start_index))
+
+        # _token_to_id is the reverse of _id_to_token
+        self._token_to_id = dict((v, k)
+                                 for k, v in six.iteritems(self._id_to_token))
+
+    def pad(self):
+        """Id of the PAD token."""
+        return self.pad_index
+
+    def eos(self):
+        """Id of the EOS token."""
+        return self.eos_index
+
+    def unk(self):
+        """Id of the UNK token."""
+        return self.unk_index
+
+    def seg(self):
+        """Id of the SEG token, or of EOS when the vocabulary has no SEG token."""
+        return self.seg_index
+
+    def store_to_file(self, filename):
+        """Write vocab file to disk.
+
+        Vocab files have one token per line. The file ends in a newline. Reserved
+        tokens are written to the vocab file as well.
+
+        Args:
+          filename: Full path of the file to store the vocab to.
+        """
+        with open(filename, "w") as f:
+            for i in range(len(self._id_to_token)):
+                f.write(self._id_to_token[i] + "\n")
+
+    def sil_phonemes(self):
+        """Return the tokens whose first character is not alphabetic.
+
+        Empty tokens (e.g. from a blank line in a vocab file) are skipped instead
+        of raising IndexError on ``p[0]``.
+        """
+        return [p for p in self._id_to_token.values() if p and not p[0].isalpha()]
diff --git a/utils/text_norm.py b/utils/text_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0973cebc91e0525aeb6657e70012a1d37b5e6ff
--- /dev/null
+++ b/utils/text_norm.py
@@ -0,0 +1,790 @@
+# coding=utf-8
+# Authors:
+# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
+# 2019.9 Jiayu DU
+#
+# requirements:
+# - python 3.X
+# notes: python 2.X WILL fail or produce misleading results
+
+import sys, os, argparse, codecs, string, re
+
+# ================================================================================ #
+# basic constant
+# ================================================================================ #
+# Digits 0-9, indexed by value: plain forms, then the "big" (capital) forms in
+# simplified and traditional script.
+CHINESE_DIGIS = u'零一二三四五六七八九'
+BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖'
+BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖'
+# NOTE(review): the two SMALLER_BIG_* tables are not referenced in the visible
+# part of this file; create_system() reads SMALLER_CHINESE_NUMERING_UNITS_* instead.
+SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万'
+SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬'
+# Units from 10^8 upwards; the power given to each one depends on the numbering type.
+LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载'
+LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載'
+# Units 10^1 .. 10^4.
+SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万'
+SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬'
+
+# Alternative spellings attached to digits 0, 1 and 2 by create_system().
+ZERO_ALT = u'〇'
+ONE_ALT = u'幺'
+TWO_ALTS = [u'两', u'兩']
+
+# Math words as [simplified, traditional] pairs.
+POSITIVE = [u'正', u'正']
+NEGATIVE = [u'负', u'負']
+POINT = [u'点', u'點']
+# PLUS = [u'加', u'加']
+# SIL = [u'杠', u'槓']
+
+# Chinese numbering system types
+NUMBERING_TYPES = ['low', 'mid', 'high']
+
+# Regex alternations. CURRENCY_UNITS and COM_QUANTIFIERS are concatenated into the
+# patterns built in NSWNormalizer.normalize(); CURRENCY_NAMES is not referenced in
+# the visible part of this file.
+CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \
+ '里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)'
+CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)'
+COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \
+ '砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \
+ '针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \
+ '毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \
+ '盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \
+ '纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)'
+
+# Punctuation tables are based on the Zhon project (https://github.com/tsroten/zhon.git).
+# Every character is spelled as a \u escape on purpose: the full-width / half-width
+# CJK forms (U+FF01.., U+FF61..) are easily folded to their ASCII look-alikes by
+# text tooling, which both breaks the single-quoted literal (it then contains a
+# bare ' and \) and stops the real CJK punctuation from being recognised.
+# Sentence-final marks: full-width '!' and '?', half-width and full-size ideographic stop.
+CHINESE_PUNC_STOP = u'\uff01\uff1f\uff61\u3002'
+# Everything else, in the same order as the original table.
+CHINESE_PUNC_NON_STOP = (
+    # full-width forms of  " # $ % & ' ( ) * + , -
+    u'\uff02\uff03\uff04\uff05\uff06\uff07\uff08\uff09\uff0a\uff0b\uff0c\uff0d'
+    # full-width forms of  / : ; < = > @ [ \ ] ^ _
+    u'\uff0f\uff1a\uff1b\uff1c\uff1d\uff1e\uff20\uff3b\uff3c\uff3d\uff3e\uff3f'
+    # full-width ` { | } ~, white parentheses, half-width corner brackets and comma
+    u'\uff40\uff5b\uff5c\uff5d\uff5e\uff5f\uff60\uff62\uff63\uff64'
+    # CJK symbols and punctuation block (ideographic comma, brackets, dashes, ...)
+    u'\u3001\u3003\u300a\u300b\u300c\u300d\u300e\u300f\u3010\u3011'
+    u'\u3014\u3015\u3016\u3017\u3018\u3019\u301a\u301b\u301c\u301d\u301e\u301f'
+    u'\u3030\u303e\u303f'
+    # general punctuation: dashes, curly quotes, ellipsis, hyphenation point, wavy low line
+    u'\u2013\u2014\u2018\u2019\u201b\u201c\u201d\u201e\u201f\u2026\u2027\ufe4f')
+CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
+
+
+# ================================================================================ #
+# basic class
+# ================================================================================ #
+class ChineseChar(object):
+    """
+    A Chinese character.
+    Each character has a simplified and a traditional form,
+    e.g. simplified = '负', traditional = '負'.
+    Either form can be chosen when text is rendered.
+    """
+
+    def __init__(self, simplified, traditional):
+        self.simplified = simplified
+        self.traditional = traditional
+
+    def __str__(self):
+        # Prefer the simplified form, then the traditional one. `__str__` has to
+        # return a str, so an instance with neither form renders as '' instead of
+        # returning None (which makes str()/print() raise TypeError).
+        return self.simplified or self.traditional or ''
+
+    def __repr__(self):
+        return self.__str__()
+
+
+class ChineseNumberUnit(ChineseChar):
+    """
+    A Chinese numeral unit (place-value) character.
+    On top of the simplified/traditional forms every unit carries
+    "big" (capital) forms, e.g. '陆' and '陸'.
+    """
+
+    def __init__(self, power, simplified, traditional, big_s, big_t):
+        super().__init__(simplified, traditional)
+        self.power, self.big_s, self.big_t = power, big_s, big_t
+
+    def __str__(self):
+        return '10^{}'.format(self.power)
+
+    @classmethod
+    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
+        """
+        Build the unit at position `index` of a unit table.
+
+        `value` is indexed as (simplified, traditional). Small units get the
+        powers 1, 2, 3, ...; larger units get a power that depends on
+        `numbering_type`. Raises ValueError for an unknown numbering type.
+        """
+        # Pick the power of ten and which entry of `value` serves as big_s.
+        if small_unit:
+            power, big_s_at = index + 1, 1
+        elif numbering_type == NUMBERING_TYPES[0]:
+            power, big_s_at = index + 8, 0
+        elif numbering_type == NUMBERING_TYPES[1]:
+            power, big_s_at = (index + 2) * 4, 0
+        elif numbering_type == NUMBERING_TYPES[2]:
+            power, big_s_at = pow(2, index + 3), 0
+        else:
+            raise ValueError(
+                'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type))
+
+        return ChineseNumberUnit(power=power,
+                                 simplified=value[0], traditional=value[1],
+                                 big_s=value[big_s_at], big_t=value[1])
+
+
+class ChineseNumberDigit(ChineseChar):
+    """
+    A Chinese digit character: its numeric value plus the plain, "big"
+    (capital) and optional alternative written forms.
+    """
+
+    def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
+        super().__init__(simplified, traditional)
+        self.value = value
+        self.big_s, self.big_t = big_s, big_t
+        self.alt_s, self.alt_t = alt_s, alt_t
+
+    def __str__(self):
+        return str(self.value)
+
+    @classmethod
+    def create(cls, i, v):
+        """Build the digit with value `i` from `v` = (simplified, traditional, big_s, big_t)."""
+        simplified, traditional, big_s, big_t = v[0], v[1], v[2], v[3]
+        return ChineseNumberDigit(i, simplified, traditional, big_s, big_t)
+
+
+class ChineseMath(ChineseChar):
+    """
+    A Chinese math word (sign or decimal point) together with the symbol it
+    stands for and an optional callable implementing it.
+    """
+
+    def __init__(self, simplified, traditional, symbol, expression=None):
+        super().__init__(simplified, traditional)
+        self.symbol, self.expression = symbol, expression
+        # Math words have no separate capital form: reuse the plain ones.
+        self.big_s, self.big_t = simplified, traditional
+
+
+# Short aliases used by the conversion helpers below.
+CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
+
+
+class NumberSystem(object):
+ """
+ Chinese number system.
+ Plain attribute container: create_system() attaches `units`, `digits`
+ and `math` to an instance.
+ """
+ pass
+
+
+class MathSymbol(object):
+    """
+    Math symbols (traditional / simplified) used by the Chinese number system, e.g.
+    positive = ['正', '正']
+    negative = ['负', '負']
+    point = ['点', '點']
+    """
+
+    def __init__(self, positive, negative, point):
+        self.positive, self.negative, self.point = positive, negative, point
+
+    def __iter__(self):
+        # Iterate over the stored symbols in attribute-creation order.
+        yield from self.__dict__.values()
+
+
+# class OtherSymbol(object):
+# """
+# 其他符号
+# """
+#
+# def __init__(self, sil):
+# self.sil = sil
+#
+# def __iter__(self):
+# for v in self.__dict__.values():
+# yield v
+
+
+# ================================================================================ #
+# basic utils
+# ================================================================================ #
+def create_system(numbering_type=NUMBERING_TYPES[1]):
+    """
+    Create the Chinese number system for `numbering_type` (default: mid).
+
+    NUMBERING_TYPES = ['low', 'mid', 'high'] selects how the large units grow:
+        low:  '兆' = '亿' * '十' = $10^{9}$,  '京' = '兆' * '十', etc.
+        mid:  '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
+        high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
+
+    Returns a NumberSystem carrying `units`, `digits` and `math`.
+    """
+
+    # units of '亿' and larger
+    big_units = []
+    for position, forms in enumerate(zip(
+            LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL)):
+        big_units.append(CNU.create(position, forms, numbering_type, False))
+
+    # units '十, 百, 千, 万'
+    small_units = []
+    for position, forms in enumerate(zip(
+            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL)):
+        small_units.append(CNU.create(position, forms, small_unit=True))
+
+    # digits 0-9; 0, 1 and 2 additionally get their alternative spellings
+    digit_forms = zip(CHINESE_DIGIS, CHINESE_DIGIS,
+                      BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL)
+    digits = [CND.create(number, forms) for number, forms in enumerate(digit_forms)]
+    zero, one, two = digits[0], digits[1], digits[2]
+    zero.alt_s = zero.alt_t = ZERO_ALT
+    one.alt_s = one.alt_t = ONE_ALT
+    two.alt_s, two.alt_t = TWO_ALTS[0], TWO_ALTS[1]
+
+    # assemble the system; small units come first in `units`
+    system = NumberSystem()
+    system.units = small_units + big_units
+    system.digits = digits
+    system.math = MathSymbol(
+        CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x),
+        CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x),
+        CM(POINT[0], POINT[1], '.', lambda x, y: float(str(x) + '.' + str(y))))
+    return system
+
+
+def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
+ """
+ Convert a Chinese numeral string into a decimal number string.
+
+ The string is split at the decimal-point word (simplified or traditional),
+ the integer part is evaluated through units and digits, and the decimal part
+ is read digit by digit. Returns 'INT.DEC' when there is a decimal part and
+ 'INT' otherwise (always a str).
+ """
+ def get_symbol(char, system):
+ # Look the character up among units first, then digits, then math symbols.
+ # Falls off the end (returns None) when nothing matches.
+ for u in system.units:
+ if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
+ return u
+ for d in system.digits:
+ if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
+ return d
+ for m in system.math:
+ if char in [m.traditional, m.simplified]:
+ return m
+
+ def string2symbols(chinese_string, system):
+ # Split into integer / decimal parts at the first point word that occurs,
+ # then map every character to its symbol. The two-way unpacking means a
+ # string containing that point word more than once raises ValueError.
+ int_string, dec_string = chinese_string, ''
+ for p in [system.math.point.simplified, system.math.point.traditional]:
+ if p in chinese_string:
+ int_string, dec_string = chinese_string.split(p)
+ break
+ return [get_symbol(c, system) for c in int_string], \
+ [get_symbol(c, system) for c in dec_string]
+
+ def correct_symbols(integer_symbols, system):
+ """
+ Make implicit parts of the integer symbols explicit, e.g.
+ 一百八 becomes 一百八十
+ 一亿一千三百万 becomes 一亿 一千万 三百万
+ """
+
+ # A leading unit of power 1 has an implied digit one in front of it.
+ if integer_symbols and isinstance(integer_symbols[0], CNU):
+ if integer_symbols[0].power == 1:
+ integer_symbols = [system.digits[1]] + integer_symbols
+
+ # A trailing digit right after a unit implies the next lower unit.
+ if len(integer_symbols) > 1:
+ if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
+ integer_symbols.append(
+ CNU(integer_symbols[-2].power - 1, None, None, None, None))
+
+ result = []
+ unit_count = 0
+ for s in integer_symbols:
+ if isinstance(s, CND):
+ result.append(s)
+ unit_count = 0
+ elif isinstance(s, CNU):
+ current_unit = CNU(s.power, None, None, None, None)
+ unit_count += 1
+
+ if unit_count == 1:
+ result.append(current_unit)
+ elif unit_count > 1:
+ # Second unit in a row: it is not appended itself; its power is
+ # added to every already-emitted unit of lower power.
+ for i in range(len(result)):
+ if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power:
+ result[-i - 1] = CNU(result[-i - 1].power +
+ current_unit.power, None, None, None, None)
+ return result
+
+ def compute_value(integer_symbols):
+ """
+ Compute the value.
+ When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
+ e.g. '两千万' = 2000 * 10000 not 2000 + 10000
+ """
+ # `value` holds one partial term per unit seen; the last slot is the term
+ # currently being built.
+ value = [0]
+ last_power = 0
+ for s in integer_symbols:
+ if isinstance(s, CND):
+ value[-1] = s.value
+ elif isinstance(s, CNU):
+ value[-1] *= pow(10, s.power)
+ if s.power > last_power:
+ value[:-1] = list(map(lambda v: v *
+ pow(10, s.power), value[:-1]))
+ last_power = s.power
+ value.append(0)
+ return sum(value)
+
+ system = create_system(numbering_type)
+ int_part, dec_part = string2symbols(chinese_string, system)
+ int_part = correct_symbols(int_part, system)
+ int_str = str(compute_value(int_part))
+ # Decimal digits are read one by one; `d.value` assumes every decimal
+ # character resolved to a digit symbol.
+ dec_str = ''.join([str(d.value) for d in dec_part])
+ if dec_part:
+ return '{0}.{1}'.format(int_str, dec_str)
+ else:
+ return int_str
+
+
+def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False,
+ traditional=False, alt_zero=False, alt_one=False, alt_two=True,
+ use_zeros=True, use_units=True):
+ """
+ Convert a decimal number string into Chinese numerals.
+
+ number_string: digits with at most one '.'; more than one raises ValueError.
+ numbering_type: passed to create_system().
+ big / traditional: select which written form of each symbol is emitted.
+ alt_zero / alt_one: replace the zero / one character by its alternative form.
+ alt_two: use the alternative "two" in front of units (see below).
+ use_units: when False (or for a one-character integer part) the integer part
+ is read digit by digit instead of with units.
+ NOTE(review): `use_zeros` is accepted here but never forwarded to get_value(),
+ whose own `use_zeros` therefore always keeps its default - confirm intent.
+ """
+ def get_value(value_string, use_zeros=True):
+ # Recursively turn a digit string into a list of digit / unit symbols.
+
+ striped_string = value_string.lstrip('0')
+
+ # record nothing if all zeros
+ if not striped_string:
+ return []
+
+ # record one digits
+ elif len(striped_string) == 1:
+ # Leading zeros were dropped: emit a zero symbol before the digit.
+ if use_zeros and len(value_string) != len(striped_string):
+ return [system.digits[0], system.digits[int(striped_string)]]
+ else:
+ return [system.digits[int(striped_string)]]
+
+ # recursively record multiple digits
+ else:
+ # Largest unit whose power is below the number of significant digits;
+ # split there and recurse on both halves.
+ result_unit = next(u for u in reversed(
+ system.units) if u.power < len(striped_string))
+ result_string = value_string[:-result_unit.power]
+ return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:])
+
+ system = create_system(numbering_type)
+
+ int_dec = number_string.split('.')
+ if len(int_dec) == 1:
+ int_string = int_dec[0]
+ dec_string = ""
+ elif len(int_dec) == 2:
+ int_string = int_dec[0]
+ dec_string = int_dec[1]
+ else:
+ raise ValueError(
+ "invalid input num string with more than one dot: {}".format(number_string))
+
+ if use_units and len(int_string) > 1:
+ result_symbols = get_value(int_string)
+ else:
+ result_symbols = [system.digits[int(c)] for c in int_string]
+ dec_symbols = [system.digits[int(c)] for c in dec_string]
+ if dec_string:
+ result_symbols += [system.math.point] + dec_symbols
+
+ if alt_two:
+ # A digit whose plain forms are the alternative spellings of "two".
+ liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t,
+ system.digits[2].big_s, system.digits[2].big_t)
+ for i, v in enumerate(result_symbols):
+ if isinstance(v, CND) and v.value == 2:
+ next_symbol = result_symbols[i +
+ 1] if i < len(result_symbols) - 1 else None
+ previous_symbol = result_symbols[i - 1] if i > 0 else None
+ # Swap in the alternative only when a unit follows, the 2 is first or
+ # preceded by a unit, and neither neighbouring unit has power 1.
+ if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
+ if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)):
+ result_symbols[i] = liang
+
+ # if big is True, '两' will not be used and `alt_two` has no impact on output
+ if big:
+ attr_name = 'big_'
+ if traditional:
+ attr_name += 't'
+ else:
+ attr_name += 's'
+ else:
+ if traditional:
+ attr_name = 'traditional'
+ else:
+ attr_name = 'simplified'
+
+ result = ''.join([getattr(s, attr_name) for s in result_symbols])
+
+ # if not use_zeros:
+ # result = result.strip(getattr(system.digits[0], attr_name))
+
+ # The alternative zero / one always come from the `alt_s` attribute.
+ if alt_zero:
+ result = result.replace(
+ getattr(system.digits[0], attr_name), system.digits[0].alt_s)
+
+ if alt_one:
+ result = result.replace(
+ getattr(system.digits[1], attr_name), system.digits[1].alt_s)
+
+ # A result starting with the point word gets a zero in front; this returns
+ # immediately, skipping the check below.
+ for i, p in enumerate(POINT):
+ if result.startswith(p):
+ return CHINESE_DIGIS[0] + result
+
+ # ^10, 11, .., 19
+ # Drop a leading "one" when the second character is the tens unit.
+ if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
+ SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
+ result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
+ result = result[1:]
+
+ return result
+
+
+# ================================================================================ #
+# different types of rewriters
+# ================================================================================ #
+class Cardinal:
+ """
+ CARDINAL rewriter: converts between a number string (`cardinal`) and its
+ Chinese numeral text (`chntext`).
+ """
+
+ def __init__(self, cardinal=None, chntext=None):
+ self.cardinal = cardinal
+ self.chntext = chntext
+
+ def chntext2cardinal(self):
+ # Chinese numeral text -> number string, using chn2num() defaults.
+ return chn2num(self.chntext)
+
+ def cardinal2chntext(self):
+ # Number string -> Chinese numeral text, using num2chn() defaults.
+ return num2chn(self.cardinal)
+
+
+class Digit:
+ """
+ DIGIT rewriter: reads a digit string one digit at a time.
+ """
+
+ def __init__(self, digit=None, chntext=None):
+ self.digit = digit
+ self.chntext = chntext
+
+ # def chntext2digit(self):
+ # return chn2num(self.chntext)
+
+ def digit2chntext(self):
+ # use_units=False: no place-value units; alt_two=False: keep the plain "two".
+ return num2chn(self.digit, alt_two=False, use_units=False)
+
+
+class TelePhone:
+    """
+    TELEPHONE rewriter: reads a phone number digit by digit.
+    """
+
+    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+        self.telephone = telephone
+        self.raw_chntext = raw_chntext
+        self.chntext = chntext
+
+    def telephone2chntext(self, fixed=False):
+        """
+        Convert `self.telephone` to Chinese digits and return the result.
+
+        fixed=True treats the number as a landline and splits it at '-';
+        otherwise leading/trailing '+' is stripped and the number is split at
+        whitespace. Both `raw_chntext` and `chntext` are stored on the instance.
+        """
+        if fixed:
+            segments = self.telephone.split('-')
+        else:
+            segments = self.telephone.strip('+').split()
+
+        spoken = [num2chn(segment, alt_two=False, use_units=False) for segment in segments]
+
+        # NOTE(review): the join separator and the replace() target are both empty
+        # strings, so raw_chntext and chntext end up identical. The names suggest a
+        # separator marker literal once stood here - confirm against the upstream file.
+        self.raw_chntext = ''.join(spoken)
+        self.chntext = self.raw_chntext.replace('', '')
+        return self.chntext
+
+
+class Fraction:
+    """
+    FRACTION rewriter: converts between 'numerator/denominator' and the Chinese
+    reading 'denominator 分之 numerator'.
+    """
+
+    def __init__(self, fraction=None, chntext=None):
+        self.fraction = fraction
+        self.chntext = chntext
+
+    def chntext2fraction(self):
+        # The Chinese reading names the denominator first.
+        denominator, numerator = self.chntext.split('分之')
+        pieces = [chn2num(numerator), chn2num(denominator)]
+        return '/'.join(pieces)
+
+    def fraction2chntext(self):
+        numerator, denominator = self.fraction.split('/')
+        pieces = [num2chn(denominator), num2chn(numerator)]
+        return '分之'.join(pieces)
+
+
+class Date:
+    """
+    DATE rewriter: reads a date written with Arabic digits and the markers
+    '年' (year), '月' (month) and a trailing day marker in Chinese.
+    """
+
+    def __init__(self, date=None, chntext=None):
+        self.date = date
+        self.chntext = chntext
+
+    def date2chntext(self):
+        """
+        Convert `self.date` to Chinese text, store it in `self.chntext` and return it.
+
+        The year is read digit by digit (Digit), month and day as cardinals
+        (Cardinal). Every part is optional; the last character of the day part is
+        kept as its marker.
+        """
+        date = self.date
+        try:
+            year, other = date.strip().split('年', 1)
+            year = Digit(digit=year).digit2chntext() + '年'
+        except ValueError:
+            # No year marker: the whole string is month/day material.
+            other = date
+            year = ''
+        if other:
+            try:
+                month, day = other.strip().split('月', 1)
+                month = Cardinal(cardinal=month).cardinal2chntext() + '月'
+            except ValueError:
+                # No month marker: whatever follows the year is the day part.
+                # This must be `other`, not the full `date` - otherwise an
+                # already-converted year is fed into the day conversion again.
+                day = other
+                month = ''
+            if day:
+                # day[-1] is the day marker, the rest is the number.
+                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
+        else:
+            month = ''
+            day = ''
+        chntext = year + month + day
+        self.chntext = chntext
+        return self.chntext
+
+
+class Money:
+    """
+    MONEY rewriter: reads every number inside a money expression as a Chinese
+    cardinal and leaves the surrounding unit characters untouched.
+    """
+
+    def __init__(self, money=None, chntext=None):
+        self.money = money
+        self.chntext = chntext
+
+    def money2chntext(self):
+        """
+        Convert the numbers in `self.money`, store the text in `self.chntext`
+        and return it. A string without digits is returned unchanged.
+        """
+        pattern = re.compile(r'(\d+(\.\d+)?)')
+        # Substitute each match exactly where it was found. Replacing by value
+        # (str.replace on the whole string) corrupts later numbers that contain
+        # an earlier one, e.g. the '1' of '1块12' also hit the '12'.
+        self.chntext = pattern.sub(
+            lambda match: Cardinal(cardinal=match.group(1)).cardinal2chntext(),
+            self.money)
+        return self.chntext
+
+
+class Percentage:
+    """
+    PERCENTAGE rewriter: converts between 'N%' and the Chinese reading
+    '百分之' + N.
+    """
+
+    def __init__(self, percentage=None, chntext=None):
+        self.percentage = percentage
+        self.chntext = chntext
+
+    def chntext2percentage(self):
+        """Return the number string for `self.chntext` followed by '%'."""
+        text = self.chntext.strip()
+        prefix = '百分之'
+        # Remove the prefix as a whole. str.strip('百分之') treats its argument as
+        # a set of characters and strips them from BOTH ends, which also eats
+        # digits such as the trailing '百' of '百分之一百'.
+        if text.startswith(prefix):
+            text = text[len(prefix):]
+        return chn2num(text) + '%'
+
+    def percentage2chntext(self):
+        """Return '百分之' followed by the Chinese reading of `self.percentage`."""
+        return '百分之' + num2chn(self.percentage.strip().strip('%'))
+
+
+# ================================================================================ #
+# NSW Normalizer
+# ================================================================================ #
+class NSWNormalizer:
+    """
+    Non-standard-word normalizer: rewrites dates, money, phone numbers,
+    fractions, percentages and plain numbers in `raw_text` as Chinese text.
+    """
+
+    def __init__(self, raw_text):
+        # '^' / '$' sentinels give the \D-anchored patterns below something to
+        # match at both ends of the text; normalize() removes them again.
+        self.raw_text = '^' + raw_text + '$'
+        self.norm_text = ''
+
+    def _particular(self):
+        """Turn a '二' standing between two Latin-letter runs back into '2' (e.g. O2O, B2C)."""
+        text = self.norm_text
+        pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('particular')
+            for matcher in matchers:
+                text = text.replace(matcher[0], matcher[1] + '2' + matcher[2], 1)
+        self.norm_text = text
+        return self.norm_text
+
+    def normalize(self, remove_punc=True):
+        """
+        Run all rewrite passes and return the normalized text.
+
+        The passes run from the most specific pattern to the most general one;
+        each pass consumes the digits it handles, so the order matters. With
+        `remove_punc` every Chinese and ASCII punctuation character is replaced
+        by a space at the end.
+        """
+        text = self.raw_text
+
+        # dates
+        # (the date group is entirely optional, so empty matches occur; they are
+        # replaced by the empty conversion of an empty date, i.e. nothing changes)
+        pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('date')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+        # money
+        pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('money')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
+
+        # landline / mobile phone numbers
+        # mobile
+        # http://www.jihaoba.com/news/show/13680
+        # China Mobile:  139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
+        # China Unicom:  130, 131, 132, 156, 155, 186, 185, 176
+        # China Telecom: 133, 153, 189, 180, 181, 177
+        pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('telephone')
+            for matcher in matchers:
+                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
+        # landline
+        pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('fixed telephone')
+            for matcher in matchers:
+                text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1)
+
+        # fractions
+        pattern = re.compile(r"(\d+/\d+)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('fraction')
+            for matcher in matchers:
+                text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
+
+        # percentages
+        # Fold the full-width percent sign (U+FF05) to ASCII first so the pattern
+        # below sees it. Written as an escape: with two identical ASCII arguments
+        # this call is a no-op and full-width percentages are read as plain numbers.
+        text = text.replace('\uff05', '%')
+        pattern = re.compile(r"(\d+(\.\d+)?%)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('percentage')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1)
+
+        # number + quantifier (only the number is rewritten)
+        pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('cardinal+quantifier')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+        # serial numbers: runs of 4 to 32 digits are read digit by digit
+        pattern = re.compile(r"(\d{4,32})")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('digit')
+            for matcher in matchers:
+                text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+        # remaining plain numbers
+        pattern = re.compile(r"(\d+(\.\d+)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('cardinal')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+        self.norm_text = text
+        self._particular()
+
+        # drop the sentinels added in __init__
+        text = self.norm_text.lstrip('^').rstrip('$')
+        if remove_punc:
+            # Punctuations removal
+            old_chars = CHINESE_PUNC_LIST + string.punctuation  # includes all CN and EN punctuations
+            new_chars = ' ' * len(old_chars)
+            del_chars = ''
+            text = text.translate(str.maketrans(old_chars, new_chars, del_chars))
+        return text
+
+
+def nsw_test_case(raw_text):
+ """Print `raw_text` (prefixed 'I:') and its normalized form (prefixed 'O:'), then an empty line."""
+ print('I:' + raw_text)
+ print('O:' + NSWNormalizer(raw_text).normalize())
+ print('')
+
+
+def nsw_test():
+    """Run nsw_test_case() over a fixed list of sample sentences, in order."""
+    samples = (
+        '固话:0595-23865596或23880880。',
+        '固话:0595-23865596或23880880。',
+        '手机:+86 19859213959或15659451527。',
+        '分数:32477/76391。',
+        '百分数:80.03%。',
+        '编号:31520181154418。',
+        '纯数:2983.07克或12345.60米。',
+        '日期:1999年2月20日或09年3月15号。',
+        '金钱:12块5,34.5元,20.1万',
+        '特殊:O2O或B2C。',
+        '3456万吨',
+        '2938个',
+        '938',
+        '今天吃了115个小笼包231个馒头',
+        '有62%的概率',
+    )
+    for sample in samples:
+        nsw_test_case(sample)
+
+
+if __name__ == '__main__':
+    # Command-line entry point: normalize a UTF-8 text file line by line.
+    # nsw_test()
+
+    p = argparse.ArgumentParser()
+    p.add_argument('ifile', help='input filename, assume utf-8 encoding')
+    p.add_argument('ofile', help='output filename')
+    p.add_argument('--to_upper', action='store_true', help='convert to upper case')
+    p.add_argument('--to_lower', action='store_true', help='convert to lower case')
+    p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.")
+    p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines')
+    args = p.parse_args()
+
+    # The two case options exclude each other. Check once, before any file is
+    # opened or truncated, and leave through sys.exit (the bare `exit` helper is
+    # only installed by the `site` module).
+    if args.to_upper and args.to_lower:
+        sys.stderr.write('text norm: to_upper OR to_lower?')
+        sys.exit(1)
+
+    n = 0
+    # `with` closes both files even when a line fails to normalize.
+    with codecs.open(args.ifile, 'r', 'utf8') as ifile, \
+            codecs.open(args.ofile, 'w+', 'utf8') as ofile:
+        for l in ifile:
+            key = ''
+            text = ''
+            keyed = False
+            if args.has_key:
+                cols = l.split(maxsplit=1)
+                if cols:
+                    keyed = True
+                    key = cols[0]
+                    if len(cols) == 2:
+                        text = cols[1]
+                    else:
+                        # Key without text: keep the line break so the next
+                        # record does not get glued onto this one.
+                        text = '\n' if l.endswith('\n') else ''
+                else:
+                    # Blank line: there is no key to split off, pass it through.
+                    text = l
+            else:
+                text = l
+
+            # cases
+            if args.to_upper:
+                text = text.upper()
+            if args.to_lower:
+                text = text.lower()
+
+            # NSW(Non-Standard-Word) normalization
+            text = NSWNormalizer(text).normalize()
+
+            if keyed:
+                ofile.write(key + '\t' + text)
+            else:
+                ofile.write(text)
+
+            n += 1
+            if n % args.log_interval == 0:
+                sys.stderr.write("text norm: {} lines done.\n".format(n))
+
+    sys.stderr.write("text norm: {} lines done in total.\n".format(n))
diff --git a/utils/training_utils.py b/utils/training_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..409b15388790b1aadb24632313bdd1f41b4b06ac
--- /dev/null
+++ b/utils/training_utils.py
@@ -0,0 +1,27 @@
+from utils.hparams import hparams
+
+
+class RSQRTSchedule(object):
+    """
+    Learning-rate schedule with linear warm-up followed by reciprocal
+    square-root decay, scaled by hidden_size ** -0.5 and floored at 1e-7.
+
+    Reads 'lr', 'warmup_updates' and 'hidden_size' from `hparams` and writes the
+    current rate into every param group of the wrapped optimizer.
+    """
+
+    def __init__(self, optimizer):
+        super().__init__()
+        self.optimizer = optimizer
+        self.constant_lr = hparams['lr']
+        self.warmup_updates = hparams['warmup_updates']
+        self.hidden_size = hparams['hidden_size']
+        self.lr = hparams['lr']
+        for group in optimizer.param_groups:
+            group['lr'] = self.lr
+        # Apply the schedule for update 0 right away.
+        self.step(0)
+
+    def step(self, num_updates):
+        """Set and return the learning rate for update number `num_updates`."""
+        warmup_factor = min(num_updates / self.warmup_updates, 1.0)
+        decay_factor = max(self.warmup_updates, num_updates) ** -0.5
+        hidden_factor = self.hidden_size ** -0.5
+        scheduled = self.constant_lr * warmup_factor * decay_factor * hidden_factor
+        self.lr = max(scheduled, 1e-7)
+        for group in self.optimizer.param_groups:
+            group['lr'] = self.lr
+        return self.lr
+
+    def get_lr(self):
+        """Return the rate currently stored in the optimizer's first param group."""
+        first_group = self.optimizer.param_groups[0]
+        return first_group['lr']
diff --git "a/\351\242\204\345\244\204\347\220\206.bat" "b/\351\242\204\345\244\204\347\220\206.bat"
new file mode 100644
index 0000000000000000000000000000000000000000..5ba8f329852f75bab00abb07fa41d9d9662b75a1
--- /dev/null
+++ "b/\351\242\204\345\244\204\347\220\206.bat"
@@ -0,0 +1,4 @@
+:: Windows launcher for the preprocessing step: runs preprocessing/binarize.py
+:: with training/config.yaml. Presumably meant to be started from the repository
+:: root, since all paths are relative.
+:: Put the current directory on Python's module search path.
+set PYTHONPATH=.
+:: Make only CUDA device 0 visible to the process.
+set CUDA_VISIBLE_DEVICES=0
+python preprocessing/binarize.py --config training/config.yaml
+:: Wait for a key press so the console output stays readable.
+pause
\ No newline at end of file