diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..38cd7b06f6129ad9ccf0e9dd75161cb8c5582370
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,10 @@
+# from .gitignore
+modules/yt_tmp.wav
+**/venv/
+**/__pycache__/
+**/outputs/
+**/models/
+
+**/.idea
+**/.git
+**/.github
diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
new file mode 100644
index 0000000000000000000000000000000000000000..68e0d8d6d61d3f17b03a9c6354d408ae57fb5eda
--- /dev/null
+++ b/.github/FUNDING.yml
@@ -0,0 +1,13 @@
+# These are supported funding model platforms
+
+github: []
+patreon: # Replace with a single Patreon username
+open_collective: # Replace with a single Open Collective username
+ko_fi: jhj0517
+tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
+community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
+liberapay: # Replace with a single Liberapay username
+issuehunt: # Replace with a single IssueHunt username
+otechie: # Replace with a single Otechie username
+lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
+custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 0000000000000000000000000000000000000000..bfb3b970454e3c6e501bf1594beed2da824b527e
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,11 @@
+---
+name: Bug report
+about: Create a report to help us improve
+title: ''
+labels: bug
+assignees: jhj0517
+
+---
+
+**Which OS are you using?**
+ - OS: [e.g. iOS or Windows.. If you are using Google Colab, just Colab.]
diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md
new file mode 100644
index 0000000000000000000000000000000000000000..39b2a2ebf8084430fa7b7b7cc28aba5ee03b90d4
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.md
@@ -0,0 +1,10 @@
+---
+name: Feature request
+about: Any feature you want
+title: ''
+labels: enhancement
+assignees: jhj0517
+
+---
+
+
diff --git a/.github/ISSUE_TEMPLATE/hallucination.md b/.github/ISSUE_TEMPLATE/hallucination.md
new file mode 100644
index 0000000000000000000000000000000000000000..5d5ee0ff0291cab993c350d50def4f2ab648be91
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/hallucination.md
@@ -0,0 +1,12 @@
+---
+name: Hallucination
+about: Whisper hallucinations. ( Repeating certain words or subtitles starting too
+ early, etc. )
+title: ''
+labels: hallucination
+assignees: jhj0517
+
+---
+
+**Download URL for sample audio**
+- Please upload download URL for sample audio file so I can test with some settings for better result. You can use https://easyupload.io/ or any other service to share.
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
new file mode 100644
index 0000000000000000000000000000000000000000..1537a083b40e194162d9d5ed7327c2ee773b8def
--- /dev/null
+++ b/.github/pull_request_template.md
@@ -0,0 +1,5 @@
+## Related issues / PRs. Summarize issues.
+- #
+
+## Summarize Changes
+1.
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3d00327dd35e2dad36b1c71e13a7f987680303f2
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,101 @@
+name: CI
+
+on:
+ workflow_dispatch:
+
+ push:
+ branches:
+ - master
+ - intel-gpu
+ pull_request:
+ branches:
+ - master
+ - intel-gpu
+
+jobs:
+ test:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python: ["3.10", "3.11", "3.12"]
+
+ env:
+ DEEPL_API_KEY: ${{ secrets.DEEPL_API_KEY }}
+
+ steps:
+ - name: Clean up space for action
+ run: rm -rf /opt/hostedtoolcache
+
+ - uses: actions/checkout@v4
+ - name: Setup Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python }}
+
+ - name: Install git and ffmpeg
+ run: sudo apt-get update && sudo apt-get install -y git ffmpeg
+
+ - name: Install dependencies
+ run: pip install -r requirements.txt pytest jiwer
+
+ - name: Run test
+ run: python -m pytest -rs tests
+
+ test-backend:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python: ["3.10", "3.11", "3.12"]
+
+ env:
+ DEEPL_API_KEY: ${{ secrets.DEEPL_API_KEY }}
+ TEST_ENV: true
+
+ steps:
+ - name: Clean up space for action
+ run: rm -rf /opt/hostedtoolcache
+
+ - uses: actions/checkout@v4
+ - name: Setup Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python }}
+
+ - name: Install git and ffmpeg
+ run: sudo apt-get update && sudo apt-get install -y git ffmpeg
+
+ - name: Install dependencies
+ run: pip install -r backend/requirements-backend.txt pytest pytest-asyncio jiwer
+
+ - name: Run test
+ run: python -m pytest -rs backend/tests
+
+ test-shell-script:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python: [ "3.10", "3.11", "3.12" ]
+
+ steps:
+ - name: Clean up space for action
+ run: rm -rf /opt/hostedtoolcache
+
+ - uses: actions/checkout@v4
+ - name: Setup Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python }}
+
+ - name: Install git and ffmpeg
+ run: sudo apt-get update && sudo apt-get install -y git ffmpeg
+
+ - name: Execute Install.sh
+ run: |
+ chmod +x ./Install.sh
+ ./Install.sh
+
+ - name: Execute start-webui.sh
+ run: |
+ chmod +x ./start-webui.sh
+ timeout 60s ./start-webui.sh || true
+
diff --git a/.github/workflows/publish-docker.yml b/.github/workflows/publish-docker.yml
new file mode 100644
index 0000000000000000000000000000000000000000..b8fb1725f0a70ed7c4d90202485268b09c719590
--- /dev/null
+++ b/.github/workflows/publish-docker.yml
@@ -0,0 +1,73 @@
+name: Publish to Docker Hub
+
+on:
+ push:
+ branches:
+ - master
+
+jobs:
+ build-and-push-webui:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Clean up space for action
+ run: rm -rf /opt/hostedtoolcache
+
+ - name: Log in to Docker Hub
+ uses: docker/login-action@v2
+ with:
+ username: ${{ secrets.DOCKER_USERNAME }}
+ password: ${{ secrets.DOCKER_PASSWORD }}
+
+ - name: Checkout repository
+ uses: actions/checkout@v3
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ - name: Set up QEMU
+ uses: docker/setup-qemu-action@v3
+
+ - name: Build and push Docker image
+ uses: docker/build-push-action@v5
+ with:
+ context: .
+ file: ./Dockerfile
+ push: true
+ tags: ${{ secrets.DOCKER_USERNAME }}/whisper-webui:latest
+
+ - name: Log out of Docker Hub
+ run: docker logout
+
+ build-and-push-backend:
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Clean up space for action
+ run: rm -rf /opt/hostedtoolcache
+
+ - name: Log in to Docker Hub
+ uses: docker/login-action@v2
+ with:
+ username: ${{ secrets.DOCKER_USERNAME }}
+ password: ${{ secrets.DOCKER_PASSWORD }}
+
+ - name: Checkout repository
+ uses: actions/checkout@v3
+
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v3
+
+ - name: Set up QEMU
+ uses: docker/setup-qemu-action@v3
+
+ - name: Build and push Docker image
+ uses: docker/build-push-action@v5
+ with:
+ context: .
+ file: ./backend/Dockerfile
+ push: true
+ tags: ${{ secrets.DOCKER_USERNAME }}/whisper-webui-backend:latest
+
+ - name: Log out of Docker Hub
+ run: docker logout
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..8f2fbaa2e6ddc9842e4a884c1cb7dcbff312b558
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,13 @@
+*.wav
+*.png
+*.mp4
+*.mp3
+**/.env
+**/.idea/
+**/.pytest_cache/
+**/venv/
+**/__pycache__/
+outputs/
+models/
+modules/yt_tmp.wav
+configs/default_parameters.yaml
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..8d1c9841994f5bd7a43f309c0ca0cc9f7d1037bb
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,34 @@
+FROM debian:bookworm-slim AS builder
+
+RUN apt-get update && \
+ apt-get install -y curl git python3 python3-pip python3-venv && \
+ rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* && \
+ mkdir -p /Whisper-WebUI
+
+WORKDIR /Whisper-WebUI
+
+COPY requirements.txt .
+
+RUN python3 -m venv venv && \
+ . venv/bin/activate && \
+ pip install -U -r requirements.txt
+
+
+FROM debian:bookworm-slim AS runtime
+
+RUN apt-get update && \
+ apt-get install -y curl ffmpeg python3 && \
+ rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
+
+WORKDIR /Whisper-WebUI
+
+COPY . .
+COPY --from=builder /Whisper-WebUI/venv /Whisper-WebUI/venv
+
+VOLUME [ "/Whisper-WebUI/models" ]
+VOLUME [ "/Whisper-WebUI/outputs" ]
+
+ENV PATH="/Whisper-WebUI/venv/bin:$PATH"
+ENV LD_LIBRARY_PATH=/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cublas/lib:/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cudnn/lib
+
+ENTRYPOINT [ "python", "app.py" ]
diff --git a/Install.bat b/Install.bat
new file mode 100644
index 0000000000000000000000000000000000000000..26fa34c23349e616414e72c2a666c66df3acf84c
--- /dev/null
+++ b/Install.bat
@@ -0,0 +1,21 @@
+@echo off
+
+if not exist "%~dp0\venv\Scripts" (
+ echo Creating venv...
+ python -m venv venv
+)
+echo checked the venv folder. now installing requirements..
+
+call "%~dp0\venv\scripts\activate"
+
+python -m pip install -U pip
+pip install -r requirements.txt
+
+if errorlevel 1 (
+ echo.
+ echo Requirements installation failed. please remove venv folder and run install.bat again.
+) else (
+ echo.
+ echo Requirements installed successfully.
+)
+pause
\ No newline at end of file
diff --git a/Install.sh b/Install.sh
new file mode 100644
index 0000000000000000000000000000000000000000..e3430bfcf5933d54530f4cf9e89790368474407d
--- /dev/null
+++ b/Install.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+if [ ! -d "venv" ]; then
+ echo "Creating virtual environment..."
+ python -m venv venv
+fi
+
+source venv/bin/activate
+
+python -m pip install -U pip
+pip install -r requirements.txt && echo "Requirements installed successfully." || {
+ echo ""
+ echo "Requirements installation failed. Please remove the venv folder and run the script again."
+ deactivate
+ exit 1
+}
+
+deactivate
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..a35c2189609170339953d8e36edc5f1cd04c0f24
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2023 jhj0517
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index 90214b07954c96f40a6429f7e91df12bed7f8583..5b9cd261d1c38977057f7596bdf5986764d0e58c 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,134 @@
----
-title: Whisper WebUI
-emoji: 📈
-colorFrom: green
-colorTo: green
-sdk: gradio
-sdk_version: 5.13.1
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co./docs/hub/spaces-config-reference
+# Whisper-WebUI
+A Gradio-based browser interface for [Whisper](https://github.com/openai/whisper). You can use it as an Easy Subtitle Generator!
+
+
+
+
+
+## Notebook
+If you wish to try this on Colab, you can do it in [here](https://colab.research.google.com/github/jhj0517/Whisper-WebUI/blob/master/notebook/whisper-webui.ipynb)!
+
+# Feature
+- Select the Whisper implementation you want to use between :
+ - [openai/whisper](https://github.com/openai/whisper)
+ - [SYSTRAN/faster-whisper](https://github.com/SYSTRAN/faster-whisper) (used by default)
+ - [Vaibhavs10/insanely-fast-whisper](https://github.com/Vaibhavs10/insanely-fast-whisper)
+- Generate subtitles from various sources, including :
+ - Files
+ - Youtube
+ - Microphone
+- Currently supported subtitle formats :
+ - SRT
+ - WebVTT
+ - txt ( only text file without timeline )
+- Speech to Text Translation
+ - From other languages to English. ( This is Whisper's end-to-end speech-to-text translation feature )
+- Text to Text Translation
+ - Translate subtitle files using Facebook NLLB models
+ - Translate subtitle files using DeepL API
+- Pre-processing audio input with [Silero VAD](https://github.com/snakers4/silero-vad).
+- Pre-processing audio input to separate BGM with [UVR](https://github.com/Anjok07/ultimatevocalremovergui).
+- Post-processing with speaker diarization using the [pyannote](https://huggingface.co./pyannote/speaker-diarization-3.1) model.
+ - To download the pyannote model, you need to have a Huggingface token and manually accept their terms in the pages below.
+ 1. https://huggingface.co./pyannote/speaker-diarization-3.1
+ 2. https://huggingface.co./pyannote/segmentation-3.0
+
+### Pipeline Diagram
+
+
+# Installation and Running
+
+- ## Running with Pinokio
+
+The app is able to run with [Pinokio](https://github.com/pinokiocomputer/pinokio).
+
+1. Install [Pinokio Software](https://program.pinokio.computer/#/?id=install).
+2. Open the software and search for Whisper-WebUI and install it.
+3. Start the Whisper-WebUI and connect to the `http://localhost:7860`.
+
+- ## Running with Docker
+
+1. Install and launch [Docker-Desktop](https://www.docker.com/products/docker-desktop/).
+
+2. Git clone the repository
+
+```sh
+git clone https://github.com/jhj0517/Whisper-WebUI.git
+```
+
+3. Build the image ( Image is about 7GB~ )
+
+```sh
+docker compose build
+```
+
+4. Run the container
+
+```sh
+docker compose up
+```
+
+5. Connect to the WebUI with your browser at `http://localhost:7860`
+
+If needed, update the [`docker-compose.yaml`](https://github.com/jhj0517/Whisper-WebUI/blob/master/docker-compose.yaml) to match your environment.
+
+- ## Run Locally
+
+### Prerequisite
+To run this WebUI, you need to have `git`, `3.10 <= python <= 3.12`, `FFmpeg`.
+And if you're not using an Nvida GPU, or using a different `CUDA` version than 12.4, edit the [`requirements.txt`](https://github.com/jhj0517/Whisper-WebUI/blob/master/requirements.txt) to match your environment.
+
+Please follow the links below to install the necessary software:
+- git : [https://git-scm.com/downloads](https://git-scm.com/downloads)
+- python : [https://www.python.org/downloads/](https://www.python.org/downloads/) **`3.10 ~ 3.12` is recommended.**
+- FFmpeg : [https://ffmpeg.org/download.html](https://ffmpeg.org/download.html)
+- CUDA : [https://developer.nvidia.com/cuda-downloads](https://developer.nvidia.com/cuda-downloads)
+
+After installing FFmpeg, **make sure to add the `FFmpeg/bin` folder to your system PATH!**
+
+### Installation Using the Script Files
+
+1. git clone this repository
+```shell
+git clone https://github.com/jhj0517/Whisper-WebUI.git
+```
+2. Run `install.bat` or `install.sh` to install dependencies. (It will create a `venv` directory and install dependencies there.)
+3. Start WebUI with `start-webui.bat` or `start-webui.sh` (It will run `python app.py` after activating the venv)
+
+And you can also run the project with command line arguments if you like to, see [wiki](https://github.com/jhj0517/Whisper-WebUI/wiki/Command-Line-Arguments) for a guide to arguments.
+
+# VRAM Usages
+This project is integrated with [faster-whisper](https://github.com/guillaumekln/faster-whisper) by default for better VRAM usage and transcription speed.
+
+According to faster-whisper, the efficiency of the optimized whisper model is as follows:
+| Implementation | Precision | Beam size | Time | Max. GPU memory | Max. CPU memory |
+|-------------------|-----------|-----------|-------|-----------------|-----------------|
+| openai/whisper | fp16 | 5 | 4m30s | 11325MB | 9439MB |
+| faster-whisper | fp16 | 5 | 54s | 4755MB | 3244MB |
+
+If you want to use an implementation other than faster-whisper, use `--whisper_type` arg and the repository name.
+Read [wiki](https://github.com/jhj0517/Whisper-WebUI/wiki/Command-Line-Arguments) for more info about CLI args.
+
+If you want to use a fine-tuned model, manually place the models in `models/Whisper/` corresponding to the implementation.
+
+Alternatively, if you enter the huggingface repo id (e.g, [deepdml/faster-whisper-large-v3-turbo-ct2](https://huggingface.co./deepdml/faster-whisper-large-v3-turbo-ct2)) in the "Model" dropdown, it will be automatically downloaded in the directory.
+
+
+
+# REST API
+If you're interested in deploying this app as a REST API, please check out [/backend](https://github.com/jhj0517/Whisper-WebUI/tree/master/backend).
+
+## TODO🗓
+
+- [x] Add DeepL API translation
+- [x] Add NLLB Model translation
+- [x] Integrate with faster-whisper
+- [x] Integrate with insanely-fast-whisper
+- [x] Integrate with whisperX ( Only speaker diarization part )
+- [x] Add background music separation pre-processing with [UVR](https://github.com/Anjok07/ultimatevocalremovergui)
+- [x] Add fast api script
+- [ ] Add CLI usages
+- [ ] Support real-time transcription for microphone
+
+### Translation 🌐
+Any PRs that translate the language into [translation.yaml](https://github.com/jhj0517/Whisper-WebUI/blob/master/configs/translation.yaml) would be greatly appreciated!
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..a81390d504a71cc5c437972b757ed1eee78d33a7
--- /dev/null
+++ b/app.py
@@ -0,0 +1,368 @@
+import os
+import argparse
+import gradio as gr
+from gradio_i18n import Translate, gettext as _
+import yaml
+
+from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR,
+ INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
+ UVR_MODELS_DIR, I18N_YAML_PATH)
+from modules.utils.files_manager import load_yaml, MEDIA_EXTENSION
+from modules.whisper.whisper_factory import WhisperFactory
+from modules.translation.nllb_inference import NLLBInference
+from modules.ui.htmls import *
+from modules.utils.cli_manager import str2bool
+from modules.utils.youtube_manager import get_ytmetas
+from modules.translation.deepl_api import DeepLAPI
+from modules.whisper.data_classes import *
+
+
+class App:
+ def __init__(self, args):
+ self.args = args
+ self.app = gr.Blocks(css=CSS, theme=self.args.theme, delete_cache=(60, 3600))
+ self.whisper_inf = WhisperFactory.create_whisper_inference(
+ whisper_type=self.args.whisper_type,
+ whisper_model_dir=self.args.whisper_model_dir,
+ faster_whisper_model_dir=self.args.faster_whisper_model_dir,
+ insanely_fast_whisper_model_dir=self.args.insanely_fast_whisper_model_dir,
+ uvr_model_dir=self.args.uvr_model_dir,
+ output_dir=self.args.output_dir,
+ )
+ self.nllb_inf = NLLBInference(
+ model_dir=self.args.nllb_model_dir,
+ output_dir=os.path.join(self.args.output_dir, "translations")
+ )
+ self.deepl_api = DeepLAPI(
+ output_dir=os.path.join(self.args.output_dir, "translations")
+ )
+ self.i18n = load_yaml(I18N_YAML_PATH)
+ self.default_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
+ print(f"Use \"{self.args.whisper_type}\" implementation\n"
+ f"Device \"{self.whisper_inf.device}\" is detected")
+
+ def create_pipeline_inputs(self):
+ whisper_params = self.default_params["whisper"]
+ vad_params = self.default_params["vad"]
+ diarization_params = self.default_params["diarization"]
+ uvr_params = self.default_params["bgm_separation"]
+
+ with gr.Row():
+ dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value=whisper_params["model_size"],
+ label=_("Model"), allow_custom_value=True)
+ dd_lang = gr.Dropdown(choices=self.whisper_inf.available_langs + [AUTOMATIC_DETECTION],
+ value=AUTOMATIC_DETECTION if whisper_params["lang"] == AUTOMATIC_DETECTION.unwrap()
+ else whisper_params["lang"], label=_("Language"))
+ dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value=whisper_params["file_format"], label=_("File Format"))
+ with gr.Row():
+ cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label=_("Translate to English?"),
+ interactive=True)
+ with gr.Row():
+ cb_timestamp = gr.Checkbox(value=whisper_params["add_timestamp"],
+ label=_("Add a timestamp to the end of the filename"),
+ interactive=True)
+
+ with gr.Accordion(_("Advanced Parameters"), open=False):
+ whisper_inputs = WhisperParams.to_gradio_inputs(defaults=whisper_params, only_advanced=True,
+ whisper_type=self.args.whisper_type,
+ available_compute_types=self.whisper_inf.available_compute_types,
+ compute_type=self.whisper_inf.current_compute_type)
+
+ with gr.Accordion(_("Background Music Remover Filter"), open=False):
+ uvr_inputs = BGMSeparationParams.to_gradio_input(defaults=uvr_params,
+ available_models=self.whisper_inf.music_separator.available_models,
+ available_devices=self.whisper_inf.music_separator.available_devices,
+ device=self.whisper_inf.music_separator.device)
+
+ with gr.Accordion(_("Voice Detection Filter"), open=False):
+ vad_inputs = VadParams.to_gradio_inputs(defaults=vad_params)
+
+ with gr.Accordion(_("Diarization"), open=False):
+ diarization_inputs = DiarizationParams.to_gradio_inputs(defaults=diarization_params,
+ available_devices=self.whisper_inf.diarizer.available_device,
+ device=self.whisper_inf.diarizer.device)
+
+ pipeline_inputs = [dd_model, dd_lang, cb_translate] + whisper_inputs + vad_inputs + diarization_inputs + uvr_inputs
+
+ return (
+ pipeline_inputs,
+ dd_file_format,
+ cb_timestamp
+ )
+
+ def launch(self):
+ translation_params = self.default_params["translation"]
+ deepl_params = translation_params["deepl"]
+ nllb_params = translation_params["nllb"]
+ uvr_params = self.default_params["bgm_separation"]
+
+ with self.app:
+ lang = gr.Radio(choices=list(self.i18n.keys()),
+ label=_("Language"), interactive=True,
+ visible=False, # Set it by development purpose.
+ )
+ with Translate(I18N_YAML_PATH):
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown(MARKDOWN, elem_id="md_project")
+ with gr.Tabs():
+ with gr.TabItem(_("File")): # tab1
+ with gr.Column():
+ input_file = gr.Files(type="filepath", label=_("Upload File here"), file_types=MEDIA_EXTENSION)
+ tb_input_folder = gr.Textbox(label="Input Folder Path (Optional)",
+ info="Optional: Specify the folder path where the input files are located, if you prefer to use local files instead of uploading them."
+ " Leave this field empty if you do not wish to use a local path.",
+ visible=self.args.colab,
+ value="")
+ cb_include_subdirectory = gr.Checkbox(label="Include Subdirectory Files",
+ info="When using Input Folder Path above, whether to include all files in the subdirectory or not.",
+ visible=self.args.colab,
+ value=False)
+ cb_save_same_dir = gr.Checkbox(label="Save outputs at same directory",
+ info="When using Input Folder Path above, whether to save output in the same directory as inputs or not, in addition to the original"
+ " output directory.",
+ visible=self.args.colab,
+ value=True)
+ pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
+
+ with gr.Row():
+ btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
+ with gr.Row():
+ tb_indicator = gr.Textbox(label=_("Output"), scale=5)
+ files_subtitles = gr.Files(label=_("Downloadable output file"), scale=3, interactive=False)
+ btn_openfolder = gr.Button('📂', scale=1)
+
+ params = [input_file, tb_input_folder, cb_include_subdirectory, cb_save_same_dir,
+ dd_file_format, cb_timestamp]
+ params = params + pipeline_params
+ btn_run.click(fn=self.whisper_inf.transcribe_file,
+ inputs=params,
+ outputs=[tb_indicator, files_subtitles])
+ btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
+
+ with gr.TabItem(_("Youtube")): # tab2
+ with gr.Row():
+ tb_youtubelink = gr.Textbox(label=_("Youtube Link"))
+ with gr.Row(equal_height=True):
+ with gr.Column():
+ img_thumbnail = gr.Image(label=_("Youtube Thumbnail"))
+ with gr.Column():
+ tb_title = gr.Label(label=_("Youtube Title"))
+ tb_description = gr.Textbox(label=_("Youtube Description"), max_lines=15)
+
+ pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
+
+ with gr.Row():
+ btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
+ with gr.Row():
+ tb_indicator = gr.Textbox(label=_("Output"), scale=5)
+ files_subtitles = gr.Files(label=_("Downloadable output file"), scale=3)
+ btn_openfolder = gr.Button('📂', scale=1)
+
+ params = [tb_youtubelink, dd_file_format, cb_timestamp]
+
+ btn_run.click(fn=self.whisper_inf.transcribe_youtube,
+ inputs=params + pipeline_params,
+ outputs=[tb_indicator, files_subtitles])
+ tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
+ outputs=[img_thumbnail, tb_title, tb_description])
+ btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
+
+ with gr.TabItem(_("Mic")): # tab3
+ with gr.Row():
+ mic_input = gr.Microphone(label=_("Record with Mic"), type="filepath", interactive=True,
+ show_download_button=True)
+
+ pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
+
+ with gr.Row():
+ btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
+ with gr.Row():
+ tb_indicator = gr.Textbox(label=_("Output"), scale=5)
+ files_subtitles = gr.Files(label=_("Downloadable output file"), scale=3)
+ btn_openfolder = gr.Button('📂', scale=1)
+
+ params = [mic_input, dd_file_format, cb_timestamp]
+
+ btn_run.click(fn=self.whisper_inf.transcribe_mic,
+ inputs=params + pipeline_params,
+ outputs=[tb_indicator, files_subtitles])
+ btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
+
+ with gr.TabItem(_("T2T Translation")): # tab 4
+ with gr.Row():
+ file_subs = gr.Files(type="filepath", label=_("Upload Subtitle Files to translate here"))
+
+ with gr.TabItem(_("DeepL API")): # sub tab1
+ with gr.Row():
+ tb_api_key = gr.Textbox(label=_("Your Auth Key (API KEY)"),
+ value=deepl_params["api_key"])
+ with gr.Row():
+ dd_source_lang = gr.Dropdown(label=_("Source Language"),
+ value=AUTOMATIC_DETECTION if deepl_params["source_lang"] == AUTOMATIC_DETECTION.unwrap()
+ else deepl_params["source_lang"],
+ choices=list(self.deepl_api.available_source_langs.keys()))
+ dd_target_lang = gr.Dropdown(label=_("Target Language"),
+ value=deepl_params["target_lang"],
+ choices=list(self.deepl_api.available_target_langs.keys()))
+ with gr.Row():
+ cb_is_pro = gr.Checkbox(label=_("Pro User?"), value=deepl_params["is_pro"])
+ with gr.Row():
+ cb_timestamp = gr.Checkbox(value=translation_params["add_timestamp"],
+ label=_("Add a timestamp to the end of the filename"),
+ interactive=True)
+ with gr.Row():
+ btn_run = gr.Button(_("TRANSLATE SUBTITLE FILE"), variant="primary")
+ with gr.Row():
+ tb_indicator = gr.Textbox(label=_("Output"), scale=5)
+ files_subtitles = gr.Files(label=_("Downloadable output file"), scale=3)
+ btn_openfolder = gr.Button('📂', scale=1)
+
+ btn_run.click(fn=self.deepl_api.translate_deepl,
+ inputs=[tb_api_key, file_subs, dd_source_lang, dd_target_lang,
+ cb_is_pro, cb_timestamp],
+ outputs=[tb_indicator, files_subtitles])
+
+ btn_openfolder.click(
+ fn=lambda: self.open_folder(os.path.join(self.args.output_dir, "translations")),
+ inputs=None,
+ outputs=None)
+
+ with gr.TabItem(_("NLLB")): # sub tab2
+ with gr.Row():
+ dd_model_size = gr.Dropdown(label=_("Model"), value=nllb_params["model_size"],
+ choices=self.nllb_inf.available_models)
+ dd_source_lang = gr.Dropdown(label=_("Source Language"),
+ value=nllb_params["source_lang"],
+ choices=self.nllb_inf.available_source_langs)
+ dd_target_lang = gr.Dropdown(label=_("Target Language"),
+ value=nllb_params["target_lang"],
+ choices=self.nllb_inf.available_target_langs)
+ with gr.Row():
+ nb_max_length = gr.Number(label="Max Length Per Line", value=nllb_params["max_length"],
+ precision=0)
+ with gr.Row():
+ cb_timestamp = gr.Checkbox(value=translation_params["add_timestamp"],
+ label=_("Add a timestamp to the end of the filename"),
+ interactive=True)
+ with gr.Row():
+ btn_run = gr.Button(_("TRANSLATE SUBTITLE FILE"), variant="primary")
+ with gr.Row():
+ tb_indicator = gr.Textbox(label=_("Output"), scale=5)
+ files_subtitles = gr.Files(label=_("Downloadable output file"), scale=3)
+ btn_openfolder = gr.Button('📂', scale=1)
+ with gr.Column():
+ md_vram_table = gr.HTML(NLLB_VRAM_TABLE, elem_id="md_nllb_vram_table")
+
+ btn_run.click(fn=self.nllb_inf.translate_file,
+ inputs=[file_subs, dd_model_size, dd_source_lang, dd_target_lang,
+ nb_max_length, cb_timestamp],
+ outputs=[tb_indicator, files_subtitles])
+
+ btn_openfolder.click(
+ fn=lambda: self.open_folder(os.path.join(self.args.output_dir, "translations")),
+ inputs=None,
+ outputs=None)
+
+ with gr.TabItem(_("BGM Separation")):
+ files_audio = gr.Files(type="filepath", label=_("Upload Audio Files to separate background music"))
+ dd_uvr_device = gr.Dropdown(label=_("Device"), value=self.whisper_inf.music_separator.device,
+ choices=self.whisper_inf.music_separator.available_devices)
+ dd_uvr_model_size = gr.Dropdown(label=_("Model"), value=uvr_params["uvr_model_size"],
+ choices=self.whisper_inf.music_separator.available_models)
+ nb_uvr_segment_size = gr.Number(label="Segment Size", value=uvr_params["segment_size"],
+ precision=0)
+ cb_uvr_save_file = gr.Checkbox(label=_("Save separated files to output"),
+ value=True, visible=False)
+ btn_run = gr.Button(_("SEPARATE BACKGROUND MUSIC"), variant="primary")
+ with gr.Column():
+ with gr.Row():
+ ad_instrumental = gr.Audio(label=_("Instrumental"), scale=8)
+ btn_open_instrumental_folder = gr.Button('📂', scale=1)
+ with gr.Row():
+ ad_vocals = gr.Audio(label=_("Vocals"), scale=8)
+ btn_open_vocals_folder = gr.Button('📂', scale=1)
+
+ btn_run.click(fn=self.whisper_inf.music_separator.separate_files,
+ inputs=[files_audio, dd_uvr_model_size, dd_uvr_device, nb_uvr_segment_size,
+ cb_uvr_save_file],
+ outputs=[ad_instrumental, ad_vocals])
+ btn_open_instrumental_folder.click(inputs=None,
+ outputs=None,
+ fn=lambda: self.open_folder(os.path.join(
+ self.args.output_dir, "UVR", "instrumental"
+ )))
+ btn_open_vocals_folder.click(inputs=None,
+ outputs=None,
+ fn=lambda: self.open_folder(os.path.join(
+ self.args.output_dir, "UVR", "vocals"
+ )))
+
+ # Launch the app with optional gradio settings
+ args = self.args
+ self.app.queue(
+ api_open=args.api_open
+ ).launch(
+ share=args.share,
+ server_name=args.server_name,
+ server_port=args.server_port,
+ auth=(args.username, args.password) if args.username and args.password else None,
+ root_path=args.root_path,
+ inbrowser=args.inbrowser,
+ ssl_verify=args.ssl_verify,
+ ssl_keyfile=args.ssl_keyfile,
+ ssl_keyfile_password=args.ssl_keyfile_password,
+ ssl_certfile=args.ssl_certfile,
+ allowed_paths=eval(args.allowed_paths) if args.allowed_paths else None
+ )
+
+ @staticmethod
+ def open_folder(folder_path: str):
+ if os.path.exists(folder_path):
+ os.system(f"start {folder_path}")
+ else:
+ os.makedirs(folder_path, exist_ok=True)
+ print(f"The directory path {folder_path} has newly created.")
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--whisper_type', type=str, default=WhisperImpl.FASTER_WHISPER.value,
+ choices=[item.value for item in WhisperImpl],
+ help='A type of the whisper implementation (Github repo name)')
+parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio share value')
+parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
+parser.add_argument('--server_port', type=int, default=None, help='Gradio server port')
+parser.add_argument('--root_path', type=str, default=None, help='Gradio root path')
+parser.add_argument('--username', type=str, default=None, help='Gradio authentication username')
+parser.add_argument('--password', type=str, default=None, help='Gradio authentication password')
+parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme')
+parser.add_argument('--colab', type=str2bool, default=False, nargs='?', const=True, help='Is colab user or not')
+parser.add_argument('--api_open', type=str2bool, default=False, nargs='?', const=True,
+ help='Enable api or not in Gradio')
+parser.add_argument('--allowed_paths', type=str, default=None, help='Gradio allowed paths')
+parser.add_argument('--inbrowser', type=str2bool, default=True, nargs='?', const=True,
+ help='Whether to automatically start Gradio app or not')
+parser.add_argument('--ssl_verify', type=str2bool, default=True, nargs='?', const=True,
+ help='Whether to verify SSL or not')
+parser.add_argument('--ssl_keyfile', type=str, default=None, help='SSL Key file location')
+parser.add_argument('--ssl_keyfile_password', type=str, default=None, help='SSL Key file password')
+parser.add_argument('--ssl_certfile', type=str, default=None, help='SSL cert file location')
+parser.add_argument('--whisper_model_dir', type=str, default=WHISPER_MODELS_DIR,
+ help='Directory path of the whisper model')
+parser.add_argument('--faster_whisper_model_dir', type=str, default=FASTER_WHISPER_MODELS_DIR,
+ help='Directory path of the faster-whisper model')
+parser.add_argument('--insanely_fast_whisper_model_dir', type=str,
+ default=INSANELY_FAST_WHISPER_MODELS_DIR,
+ help='Directory path of the insanely-fast-whisper model')
+parser.add_argument('--diarization_model_dir', type=str, default=DIARIZATION_MODELS_DIR,
+ help='Directory path of the diarization model')
+parser.add_argument('--nllb_model_dir', type=str, default=NLLB_MODELS_DIR,
+ help='Directory path of the Facebook NLLB model')
+parser.add_argument('--uvr_model_dir', type=str, default=UVR_MODELS_DIR,
+ help='Directory path of the UVR model')
+parser.add_argument('--output_dir', type=str, default=OUTPUT_DIR, help='Directory path of the outputs')
+_args = parser.parse_args()
+
+if __name__ == "__main__":
+ app = App(args=_args)
+ app.launch()
diff --git a/backend/Dockerfile b/backend/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..164a8083d32e1c2a2c2fa6af21145b62ce6010e3
--- /dev/null
+++ b/backend/Dockerfile
@@ -0,0 +1,36 @@
+FROM debian:bookworm-slim AS builder
+
+RUN apt-get update && \
+ apt-get install -y curl git python3 python3-pip python3-venv && \
+ rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* && \
+ mkdir -p /Whisper-WebUI
+
+WORKDIR /Whisper-WebUI
+
+COPY backend/ backend/
+COPY requirements.txt requirements.txt
+
+RUN python3 -m venv venv && \
+ . venv/bin/activate && \
+ pip install -U -r backend/requirements-backend.txt
+
+
+FROM debian:bookworm-slim AS runtime
+
+RUN apt-get update && \
+ apt-get install -y curl ffmpeg python3 && \
+ rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*
+
+WORKDIR /Whisper-WebUI
+
+COPY . .
+COPY --from=builder /Whisper-WebUI/venv /Whisper-WebUI/venv
+
+VOLUME [ "/Whisper-WebUI/models" ]
+VOLUME [ "/Whisper-WebUI/outputs" ]
+VOLUME [ "/Whisper-WebUI/backend" ]
+
+ENV PATH="/Whisper-WebUI/venv/bin:$PATH"
+ENV LD_LIBRARY_PATH=/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cublas/lib:/Whisper-WebUI/venv/lib64/python3.11/site-packages/nvidia/cudnn/lib
+
+ENTRYPOINT ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "8000"]
diff --git a/backend/README.md b/backend/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5cc639583c0dbe2900710963e3bef68687f6a2aa
--- /dev/null
+++ b/backend/README.md
@@ -0,0 +1,110 @@
+# Whisper-WebUI REST API
+REST API for Whisper-WebUI. Documentation is auto-generated upon deploying the app.
+
[Swagger UI](https://github.com/swagger-api/swagger-ui) is available at `app/docs` or root URL with redirection. [Redoc](https://github.com/Redocly/redoc) is available at `app/redoc`.
+
+# Setup and Installation
+
+Installation assumes that you are in the root directory of Whisper-WebUI
+
+1. Create `.env` in `backend/configs/.env`
+```
+HF_TOKEN="YOUR_HF_TOKEN FOR DIARIZATION MODEL (READ PERMISSION)"
+DB_URL="sqlite:///backend/records.db"
+```
+`HF_TOKEN` is used to download diarization model, `DB_URL` indicates where your db file is located. It is stored in `backend/` by default.
+
+2. Install dependency
+```
+pip install -r backend/requirements-backend.txt
+```
+
+3. Deploy the server with `uvicorn` or whatever.
+```
+uvicorn backend.main:app --host 0.0.0.0 --port 8000
+```
+
+### Deploy with your domain name
+You can deploy the server with your domain name by setting up a reverse proxy with Nginx.
+
+1. Install Nginx if you don't already have it.
+- Linux : https://nginx.org/en/docs/install.html
+- Windows : https://nginx.org/en/docs/windows.html
+
+2. Edit [`nginx.conf`](https://github.com/jhj0517/Whisper-WebUI/blob/master/backend/nginx/nginx.conf) for your domain name.
+https://github.com/jhj0517/Whisper-WebUI/blob/895cafe400944396ad8be5b1cc793b54fecc8bbe/backend/nginx/nginx.conf#L12
+
+3. Add an A type record of your public IPv4 address in your domain provider. (you can get it by searching "What is my IP" in Google)
+
+4. Open a terminal and go to the location of [`nginx.conf`](https://github.com/jhj0517/Whisper-WebUI/blob/master/backend/nginx/nginx.conf), then start the nginx server, so that you can manage nginx-related logs there.
+```shell
+cd backend/nginx
+nginx -c "/path/to/Whisper-WebUI/backend/nginx/nginx.conf"
+```
+
+5. Open another terminal in the root project location `/Whisper-WebUI`, and deploy the app with `uvicorn` or whatever. Now the app will be available at your domain.
+```shell
+uvicorn backend.main:app --host 0.0.0.0 --port 8000
+```
+
+6. When you turn off nginx, you can use `nginx -s stop`.
+```shell
+cd backend/nginx
+nginx -s stop -c "/path/to/Whisper-WebUI/backend/nginx/nginx.conf"
+```
+
+
+## Configuration
+You can set some server configurations in [config.yaml](https://github.com/jhj0517/Whisper-WebUI/blob/master/backend/configs/config.yaml).
+
For example, initial model size for Whisper or the cleanup frequency and TTL for cached files.
+
If the endpoint generates and saves the file, all output files are stored in the `cache` directory, e.g. separated vocal/instrument files for `/bgm-separation` are saved in `cache` directory.
+
+## Docker
+The Dockerfile should be built when you're in the root directory of Whisper-WebUI.
+
+1. git clone this repository
+```
+git clone https://github.com/jhj0517/Whisper-WebUI.git
+```
+2. Mount volume paths with your local paths in `docker-compose.yaml`
+https://github.com/jhj0517/Whisper-WebUI/blob/1dd708ec3844dbf0c1f77de9ef5764e883dd4c78/backend/docker-compose.yaml#L12-L15
+3. Build the image
+```
+docker compose -f backend/docker-compose.yaml build
+```
+4. Run the container
+```
+docker compose -f backend/docker-compose.yaml up
+```
+
+5. Then you can read docs at `localhost:8000` (default port is set to `8000` in `docker-compose.yaml`) and run your own tests.
+
+
+# Architecture
+
+
+
+The response can be obtained through [the polling API](https://docs.oracle.com/en/cloud/saas/marketing/responsys-develop/API/REST/Async/asyncApi-v1.3-requests-requestId-get.htm).
+Each task is stored in the DB whenever the task is queued or updated by the process.
+
+When the client first sends the `POST` request, the server returns an `identifier` to the client that can be used to track the status of the task. The task status is updated by the processes, and once the task is completed, the client can finally obtain the result.
+
+The client needs to implement manual API polling to do this, this is the example for the python client:
+```python
+def wait_for_task_completion(identifier: str,
+ max_attempts: int = 20,
+ frequency: int = 3) -> httpx.Response:
+ """
+ Polls the task status every `frequency` until it is completed, failed, or the `max_attempts` are reached.
+ """
+ attempts = 0
+ while attempts < max_attempts:
+ task = fetch_task(identifier)
+ status = task.json()["status"]
+ if status == "COMPLETED":
+ return task["result"]
+ if status == "FAILED":
+ raise Exception("Task polling failed")
+ time.sleep(frequency)
+ attempts += 1
+ return None
+```
diff --git a/backend/__init__.py b/backend/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backend/cache/cached_files_are_generated_here b/backend/cache/cached_files_are_generated_here
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backend/common/audio.py b/backend/common/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..afaf91049dbc86d2e5f003b02f1e903ed83711ad
--- /dev/null
+++ b/backend/common/audio.py
@@ -0,0 +1,36 @@
+from io import BytesIO
+import numpy as np
+import httpx
+import faster_whisper
+from pydantic import BaseModel
+from fastapi import (
+ HTTPException,
+ UploadFile,
+)
+from typing import Annotated, Any, BinaryIO, Literal, Generator, Union, Optional, List, Tuple
+
+
+class AudioInfo(BaseModel):
+ duration: float
+
+
+async def read_audio(
+ file: Optional[UploadFile] = None,
+ file_url: Optional[str] = None
+):
+ """Read audio from "UploadFile". This resamples sampling rates to 16000."""
+ if (file and file_url) or (not file and not file_url):
+ raise HTTPException(status_code=400, detail="Provide only one of file or file_url")
+
+ if file:
+ file_content = await file.read()
+ elif file_url:
+ async with httpx.AsyncClient() as client:
+ file_response = await client.get(file_url)
+ if file_response.status_code != 200:
+ raise HTTPException(status_code=422, detail="Could not download the file")
+ file_content = file_response.content
+ file_bytes = BytesIO(file_content)
+ audio = faster_whisper.audio.decode_audio(file_bytes)
+ duration = len(audio) / 16000
+ return audio, AudioInfo(duration=duration)
diff --git a/backend/common/cache_manager.py b/backend/common/cache_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..dac4fa47fe21a9487406477ab12de6041d70b0ea
--- /dev/null
+++ b/backend/common/cache_manager.py
@@ -0,0 +1,21 @@
+import time
+import os
+from typing import Optional
+
+from modules.utils.paths import BACKEND_CACHE_DIR
+
+
+def cleanup_old_files(cache_dir: str = BACKEND_CACHE_DIR, ttl: int = 60):
+ now = time.time()
+ place_holder_name = "cached_files_are_generated_here"
+ for root, dirs, files in os.walk(cache_dir):
+ for filename in files:
+ if filename == place_holder_name:
+ continue
+ filepath = os.path.join(root, filename)
+ if now - os.path.getmtime(filepath) > ttl:
+ try:
+ os.remove(filepath)
+ except Exception as e:
+ print(f"Error removing {filepath}")
+ raise
diff --git a/backend/common/compresser.py b/backend/common/compresser.py
new file mode 100644
index 0000000000000000000000000000000000000000..c63ac5f3b5f4c8c0079ccfb28674f043a577bc2b
--- /dev/null
+++ b/backend/common/compresser.py
@@ -0,0 +1,58 @@
+import os
+import zipfile
+from typing import List, Optional
+import hashlib
+
+
+def compress_files(file_paths: List[str], output_zip_path: str) -> str:
+ """
+ Compress multiple files into a single zip file.
+
+ Args:
+ file_paths (List[str]): List of paths to files to be compressed.
+ output_zip (str): Path and name of the output zip file.
+
+ Raises:
+ FileNotFoundError: If any of the input files doesn't exist.
+ """
+ os.makedirs(os.path.dirname(output_zip_path), exist_ok=True)
+ compression = zipfile.ZIP_DEFLATED
+
+ with zipfile.ZipFile(output_zip_path, 'w', compression=compression) as zipf:
+ for file_path in file_paths:
+ if not os.path.exists(file_path):
+ raise FileNotFoundError(f"File not found: {file_path}")
+
+ file_name = os.path.basename(file_path)
+ zipf.write(file_path, file_name)
+ return output_zip_path
+
+
+def get_file_hash(file_path: str) -> str:
+ """Generate the hash of a file using the specified hashing algorithm. It generates hash by content not path. """
+ hash_func = hashlib.new("sha256")
+ try:
+ with open(file_path, 'rb') as f:
+ for chunk in iter(lambda: f.read(4096), b""):
+ hash_func.update(chunk)
+ return hash_func.hexdigest()
+ except FileNotFoundError:
+ return f"File not found: {file_path}"
+ except Exception as e:
+ return f"An error occurred: {str(e)}"
+
+
+def find_file_by_hash(dir_path: str, hash_str: str) -> Optional[str]:
+ """Get file path from the directory based on its hash"""
+ if not os.path.exists(dir_path) and os.path.isdir(dir_path):
+ raise ValueError(f"Directory {dir_path} does not exist")
+
+ files = [os.path.join(dir_path, f) for f in os.listdir(dir_path) if os.path.isfile(os.path.join(dir_path, f))]
+
+ for f in files:
+ f_hash = get_file_hash(f)
+ if hash_str == f_hash:
+ return f
+ return None
+
+
diff --git a/backend/common/config_loader.py b/backend/common/config_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0aebbaa67d811a608148391bf33ab3f9a9ecdc7
--- /dev/null
+++ b/backend/common/config_loader.py
@@ -0,0 +1,25 @@
+from dotenv import load_dotenv
+import os
+from modules.utils.paths import SERVER_CONFIG_PATH, SERVER_DOTENV_PATH
+from modules.utils.files_manager import load_yaml, save_yaml
+
+import functools
+
+
+@functools.lru_cache
+def load_server_config(config_path: str = SERVER_CONFIG_PATH) -> dict:
+ if os.getenv("TEST_ENV", "false").lower() == "true":
+ server_config = load_yaml(config_path)
+ server_config["whisper"]["model_size"] = "tiny"
+ server_config["whisper"]["compute_type"] = "float32"
+ save_yaml(server_config, config_path)
+
+ return load_yaml(config_path)
+
+
+@functools.lru_cache
+def read_env(key: str, default: str = None, dotenv_path: str = SERVER_DOTENV_PATH):
+ load_dotenv(dotenv_path)
+ value = os.getenv(key, default)
+ return value
+
diff --git a/backend/common/models.py b/backend/common/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f23f312844ce21516c535264f5208b4b4dbcbe4
--- /dev/null
+++ b/backend/common/models.py
@@ -0,0 +1,14 @@
+from pydantic import BaseModel, Field, validator
+from typing import List, Any, Optional
+from backend.db.task.models import TaskStatus, ResultType, TaskType
+
+
+class QueueResponse(BaseModel):
+ identifier: str = Field(..., description="Unique identifier for the queued task that can be used for tracking")
+ status: TaskStatus = Field(..., description="Current status of the task")
+ message: str = Field(..., description="Message providing additional information about the task")
+
+
+class Response(BaseModel):
+ identifier: str
+ message: str
diff --git a/backend/configs/config.yaml b/backend/configs/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..13b9a02d767c05919f840645d53d4ab03f6d9ab0
--- /dev/null
+++ b/backend/configs/config.yaml
@@ -0,0 +1,23 @@
+whisper:
+ # Default implementation is faster-whisper. This indicates model name within `models\Whisper\faster-whisper`
+ model_size: large-v2
+ # Compute type. 'float16' for CUDA, 'float32' for CPU.
+ compute_type: float16
+
+bgm_separation:
+ # UVR model sizes between ["UVR-MDX-NET-Inst_HQ_4", "UVR-MDX-NET-Inst_3"]
+ model_size: UVR-MDX-NET-Inst_HQ_4
+ # Whether to offload the model after the inference. Should be true if your setup has a VRAM less than <16GB
+ enable_offload: true
+ # Device to load BGM separation model
+ device: cuda
+
+# Settings that apply to the `cache' directory. The output files for `/bgm-separation` are stored in the `cache' directory,
+# (You can check out the actual generated files by testing `/bgm-separation`.)
+# You can adjust the TTL/cleanup frequency of the files in the `cache' directory here.
+cache:
+ # TTL (Time-To-Live) in seconds, defaults to 10 minutes
+ ttl: 600
+ # Clean up frequency in seconds, defaults to 1 minutes
+ frequency: 60
+
diff --git a/backend/db/__init__.py b/backend/db/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backend/db/db_instance.py b/backend/db/db_instance.py
new file mode 100644
index 0000000000000000000000000000000000000000..182d990e61c781094bbfbad0d619b66752acad33
--- /dev/null
+++ b/backend/db/db_instance.py
@@ -0,0 +1,42 @@
+import functools
+import os
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+from functools import wraps
+from sqlalchemy.exc import SQLAlchemyError
+from fastapi import HTTPException
+from sqlmodel import SQLModel
+from dotenv import load_dotenv
+
+from backend.common.config_loader import read_env
+
+
+@functools.lru_cache
+def init_db():
+ db_url = read_env("DB_URL", "sqlite:///backend/records.db")
+ engine = create_engine(db_url, connect_args={"check_same_thread": False})
+ SQLModel.metadata.create_all(engine)
+ return sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+
+def get_db_session():
+ db_instance = init_db()
+ return db_instance()
+
+
+def handle_database_errors(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ session = None
+ try:
+ session = get_db_session()
+ kwargs['session'] = session
+
+ return func(*args, **kwargs)
+ except Exception as e:
+ print(f"Database error has occurred: {e}")
+ raise
+ finally:
+ if session:
+ session.close()
+ return wrapper
diff --git a/backend/db/task/__init__.py b/backend/db/task/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backend/db/task/dao.py b/backend/db/task/dao.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e78c9fbe6201feb4509b4ebaf06d5d97c0e5291
--- /dev/null
+++ b/backend/db/task/dao.py
@@ -0,0 +1,94 @@
+from typing import Dict, Any
+from sqlalchemy.orm import Session
+from fastapi import Depends
+
+from ..db_instance import handle_database_errors, get_db_session
+from .models import Task, TasksResult, TaskStatus
+
+
+@handle_database_errors
+def add_task_to_db(
+ session,
+ status=TaskStatus.QUEUED,
+ task_type=None,
+ language=None,
+ task_params=None,
+ file_name=None,
+ url=None,
+ audio_duration=None,
+):
+ """
+ Add task to the db
+ """
+ task = Task(
+ status=status,
+ language=language,
+ file_name=file_name,
+ url=url,
+ task_type=task_type,
+ task_params=task_params,
+ audio_duration=audio_duration,
+ )
+ session.add(task)
+ session.commit()
+ return task.uuid
+
+
+@handle_database_errors
+def update_task_status_in_db(
+ identifier: str,
+ update_data: Dict[str, Any],
+ session: Session,
+):
+ """
+ Update task status and attributes in the database.
+
+ Args:
+ identifier (str): Identifier of the task to be updated.
+ update_data (Dict[str, Any]): Dictionary containing the attributes to update along with their new values.
+ session (Session, optional): Database session. Defaults to Depends(get_db_session).
+
+ Returns:
+ None
+ """
+ task = session.query(Task).filter_by(uuid=identifier).first()
+ if task:
+ for key, value in update_data.items():
+ setattr(task, key, value)
+ session.commit()
+
+
+@handle_database_errors
+def get_task_status_from_db(
+ identifier: str, session: Session
+):
+ """Retrieve task status from db"""
+ task = session.query(Task).filter(Task.uuid == identifier).first()
+ if task:
+ return task
+ else:
+ return None
+
+
+@handle_database_errors
+def get_all_tasks_status_from_db(session: Session):
+ """Get all tasks from db"""
+ columns = [Task.uuid, Task.status, Task.task_type]
+ query = session.query(*columns)
+ tasks = [task for task in query]
+ return TasksResult(tasks=tasks)
+
+
+@handle_database_errors
+def delete_task_from_db(identifier: str, session: Session):
+ """Delete task from db"""
+ task = session.query(Task).filter(Task.uuid == identifier).first()
+
+ if task:
+ # If the task exists, delete it from the database
+ session.delete(task)
+ session.commit()
+ return True
+ else:
+ # If the task does not exist, return False
+ return False
diff --git a/backend/db/task/models.py b/backend/db/task/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..80825bcd1f784b8fc11832c90de400fe3a4df53a
--- /dev/null
+++ b/backend/db/task/models.py
@@ -0,0 +1,174 @@
+# Ported from https://github.com/pavelzbornik/whisperX-FastAPI/blob/main/app/models.py
+
+from enum import Enum
+from pydantic import BaseModel
+from typing import Optional, List
+from uuid import uuid4
+from datetime import datetime
+from sqlalchemy.types import Enum as SQLAlchemyEnum
+from typing import Any
+from sqlmodel import SQLModel, Field, JSON, Column
+
+
+class ResultType(str, Enum):
+ JSON = "json"
+ FILEPATH = "filepath"
+
+
+class TaskStatus(str, Enum):
+ PENDING = "pending"
+ IN_PROGRESS = "in_progress"
+ COMPLETED = "completed"
+ FAILED = "failed"
+ CANCELLED = "cancelled"
+ QUEUED = "queued"
+ PAUSED = "paused"
+ RETRYING = "retrying"
+
+ def __str__(self):
+ return self.value
+
+
+class TaskType(str, Enum):
+ TRANSCRIPTION = "transcription"
+ VAD = "vad"
+ BGM_SEPARATION = "bgm_separation"
+
+ def __str__(self):
+ return self.value
+
+
+class TaskStatusResponse(BaseModel):
+ """`TaskStatusResponse` is a wrapper class that hides sensitive information from `Task`"""
+ identifier: str = Field(..., description="Unique identifier for the queued task that can be used for tracking")
+ status: TaskStatus = Field(..., description="Current status of the task")
+ task_type: Optional[TaskType] = Field(
+ default=None,
+ description="Type/category of the task"
+ )
+ result_type: Optional[ResultType] = Field(
+ default=ResultType.JSON,
+ description="Result type whether it's a filepath or JSON"
+ )
+ result: Optional[Any] = Field(
+ default=None,
+ description="JSON data representing the result of the task"
+ )
+ task_params: Optional[dict] = Field(
+ default=None,
+ description="Parameters of the task"
+ )
+ error: Optional[str] = Field(
+ default=None,
+ description="Error message, if any, associated with the task"
+ )
+ duration: Optional[float] = Field(
+ default=None,
+ description="Duration of the task execution"
+ )
+
+
+class Task(SQLModel, table=True):
+ """
+ Table to store tasks information.
+
+ Attributes:
+ - id: Unique identifier for each task (Primary Key).
+ - uuid: Universally unique identifier for each task.
+ - status: Current status of the task.
+ - result: JSON data representing the result of the task.
+ - result_type: Type of the data whether it is normal JSON data or filepath.
+ - file_name: Name of the file associated with the task.
+ - task_type: Type/category of the task.
+ - duration: Duration of the task execution.
+ - error: Error message, if any, associated with the task.
+ - created_at: Date and time of creation.
+ - updated_at: Date and time of last update.
+ """
+
+ __tablename__ = "tasks"
+
+ id: Optional[int] = Field(
+ default=None,
+ primary_key=True,
+ description="Unique identifier for each task (Primary Key)"
+ )
+ uuid: str = Field(
+ default_factory=lambda: str(uuid4()),
+ description="Universally unique identifier for each task"
+ )
+ status: Optional[TaskStatus] = Field(
+ default=None,
+ sa_column=Field(sa_column=SQLAlchemyEnum(TaskStatus)),
+ description="Current status of the task",
+ )
+ result: Optional[dict] = Field(
+ default_factory=dict,
+ sa_column=Column(JSON),
+ description="JSON data representing the result of the task"
+ )
+ result_type: Optional[ResultType] = Field(
+ default=ResultType.JSON,
+ sa_column=Field(sa_column=SQLAlchemyEnum(ResultType)),
+ description="Result type whether it's a filepath or JSON"
+ )
+ file_name: Optional[str] = Field(
+ default=None,
+ description="Name of the file associated with the task"
+ )
+ url: Optional[str] = Field(
+ default=None,
+ description="URL of the file associated with the task"
+ )
+ audio_duration: Optional[float] = Field(
+ default=None,
+ description="Duration of the audio in seconds"
+ )
+ language: Optional[str] = Field(
+ default=None,
+ description="Language of the file associated with the task"
+ )
+ task_type: Optional[TaskType] = Field(
+ default=None,
+ sa_column=Field(sa_column=SQLAlchemyEnum(TaskType)),
+ description="Type/category of the task"
+ )
+ task_params: Optional[dict] = Field(
+ default_factory=dict,
+ sa_column=Column(JSON),
+ description="Parameters of the task"
+ )
+ duration: Optional[float] = Field(
+ default=None,
+ description="Duration of the task execution"
+ )
+ error: Optional[str] = Field(
+ default=None,
+ description="Error message, if any, associated with the task"
+ )
+ created_at: datetime = Field(
+ default_factory=datetime.utcnow,
+ description="Date and time of creation"
+ )
+ updated_at: datetime = Field(
+ default_factory=datetime.utcnow,
+ sa_column_kwargs={"onupdate": datetime.utcnow},
+ description="Date and time of last update"
+ )
+
+ def to_response(self) -> "TaskStatusResponse":
+ return TaskStatusResponse(
+ identifier=self.uuid,
+ status=self.status,
+ task_type=self.task_type,
+ result_type=self.result_type,
+ result=self.result,
+ task_params=self.task_params,
+ error=self.error,
+ duration=self.duration
+ )
+
+
+class TasksResult(BaseModel):
+ tasks: List[Task]
+
diff --git a/backend/docker-compose.yaml b/backend/docker-compose.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4ba05c2aefb2ccbad44a95470040f6a3e1c52d84
--- /dev/null
+++ b/backend/docker-compose.yaml
@@ -0,0 +1,33 @@
+services:
+ app:
+ build:
+ dockerfile: backend/Dockerfile
+ context: ..
+ image: jhj0517/whisper-webui-backend:latest
+
+ volumes:
+ # You can mount the container's volume paths to directory paths on your local machine.
+ # Models will be stored in the `./models' directory on your machine.
+ # Similarly, all output files will be stored in the `./outputs` directory.
+ # The DB file is saved in /Whisper-WebUI/backend/records.db unless you edit it in /Whisper-WebUI/backend/configs/.env
+ - ./models:/Whisper-WebUI/models
+ - ./outputs:/Whisper-WebUI/outputs
+ - ./backend:/Whisper-WebUI/backend
+
+ ports:
+ - "8000:8000"
+
+ stdin_open: true
+ tty: true
+
+ entrypoint: ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "8000"]
+
+ # If you're not using Nvidia GPU, Update device to match yours.
+ # See more info at : https://docs.docker.com/compose/compose-file/deploy/#driver
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: all
+ capabilities: [ gpu ]
\ No newline at end of file
diff --git a/backend/main.py b/backend/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7f3a2d6ab3f2876aa405542226b62fc7fe94627
--- /dev/null
+++ b/backend/main.py
@@ -0,0 +1,92 @@
+from contextlib import asynccontextmanager
+from fastapi import (
+ FastAPI,
+)
+from fastapi.responses import RedirectResponse
+from fastapi.middleware.cors import CORSMiddleware
+import os
+import time
+import threading
+
+from backend.db.db_instance import init_db
+from backend.routers.transcription.router import transcription_router, get_pipeline
+from backend.routers.vad.router import get_vad_model, vad_router
+from backend.routers.bgm_separation.router import get_bgm_separation_inferencer, bgm_separation_router
+from backend.routers.task.router import task_router
+from backend.common.config_loader import read_env, load_server_config
+from backend.common.cache_manager import cleanup_old_files
+from modules.utils.paths import SERVER_CONFIG_PATH, BACKEND_CACHE_DIR
+
+
+def clean_cache_thread(ttl: int, frequency: int) -> threading.Thread:
+ def clean_cache(_ttl: int, _frequency: int):
+ while True:
+ cleanup_old_files(cache_dir=BACKEND_CACHE_DIR, ttl=_ttl)
+ time.sleep(_frequency)
+
+ return threading.Thread(
+ target=clean_cache,
+ args=(ttl, frequency),
+ daemon=True
+ )
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+ # Basic setup initialization
+ server_config = load_server_config()
+ read_env("DB_URL") # Place .env file into /configs/.env
+ init_db()
+
+ # Inferencer initialization
+ transcription_pipeline = get_pipeline()
+ vad_inferencer = get_vad_model()
+ bgm_separation_inferencer = get_bgm_separation_inferencer()
+
+ # Thread initialization
+ cache_thread = clean_cache_thread(server_config["cache"]["ttl"], server_config["cache"]["frequency"])
+ cache_thread.start()
+
+ yield
+
+ # Release VRAM when server shutdown
+ transcription_pipeline = None
+ vad_inferencer = None
+ bgm_separation_inferencer = None
+
+
+app = FastAPI(
+ title="Whisper-WebUI-Backend",
+ description=f"""
+ REST API for Whisper-WebUI. Swagger UI is available via /docs or root URL with redirection. Redoc is available via /redoc.
+ """,
+ version="0.0.1",
+ lifespan=lifespan,
+ openapi_tags=[
+ {
+ "name": "BGM Separation",
+ "description": "Cached files for /bgm-separation are generated in the `backend/cache` directory,"
+ " you can set TLL for these files in `backend/configs/config.yaml`."
+ }
+ ]
+)
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["GET", "POST", "PUT", "PATCH", "OPTIONS"], # Disable DELETE
+ allow_headers=["*"],
+)
+app.include_router(transcription_router)
+app.include_router(vad_router)
+app.include_router(bgm_separation_router)
+app.include_router(task_router)
+
+
+@app.get("/", response_class=RedirectResponse, include_in_schema=False)
+async def index():
+ """
+ Redirect to the documentation. Defaults to Swagger UI.
+ You can also check the /redoc with redoc style: https://github.com/Redocly/redoc
+ """
+ return "/docs"
diff --git a/backend/nginx/logs/logs_are_generated_here b/backend/nginx/logs/logs_are_generated_here
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backend/nginx/nginx.conf b/backend/nginx/nginx.conf
new file mode 100644
index 0000000000000000000000000000000000000000..fe5d61b535ced6fe8c82fde73c908ffa94a5d88b
--- /dev/null
+++ b/backend/nginx/nginx.conf
@@ -0,0 +1,23 @@
+worker_processes 1;
+
+events {
+ worker_connections 1024;
+}
+
+http {
+ server {
+ listen 80;
+ client_max_body_size 4G;
+
+ server_name your-own-domain-name.com;
+
+ location / {
+ proxy_pass http://127.0.0.1:8000;
+ proxy_set_header Host $host;
+ proxy_set_header X-Real-IP $remote_addr;
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+ proxy_set_header X-Forwarded-Proto $scheme;
+ }
+ }
+}
+
diff --git a/backend/nginx/temp/temps_are_generated_here b/backend/nginx/temp/temps_are_generated_here
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backend/requirements-backend.txt b/backend/requirements-backend.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1d8a87232b40cc3dc1b4410d9b0987c572cd3094
--- /dev/null
+++ b/backend/requirements-backend.txt
@@ -0,0 +1,13 @@
+# Whisper-WebUI dependencies
+-r ../requirements.txt
+
+# Backend dependencies
+python-dotenv
+uvicorn
+SQLAlchemy
+sqlmodel
+pydantic
+
+# Test dependencies
+# pytest
+# pytest-asyncio
\ No newline at end of file
diff --git a/backend/routers/__init__.py b/backend/routers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backend/routers/bgm_separation/__init__.py b/backend/routers/bgm_separation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backend/routers/bgm_separation/models.py b/backend/routers/bgm_separation/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..4353ae895b13f77276863dca2527ea1825b46b5e
--- /dev/null
+++ b/backend/routers/bgm_separation/models.py
@@ -0,0 +1,6 @@
+from pydantic import BaseModel, Field
+
+
+class BGMSeparationResult(BaseModel):
+ instrumental_hash: str = Field(..., description="Instrumental file hash")
+ vocal_hash: str = Field(..., description="Vocal file hash")
diff --git a/backend/routers/bgm_separation/router.py b/backend/routers/bgm_separation/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..0886b1c19b1845e9561a48d263a081c5a59a4703
--- /dev/null
+++ b/backend/routers/bgm_separation/router.py
@@ -0,0 +1,119 @@
+import functools
+import numpy as np
+from fastapi import (
+ File,
+ UploadFile,
+)
+import gradio as gr
+from fastapi import APIRouter, BackgroundTasks, Depends, Response, status
+from fastapi.responses import FileResponse
+from typing import List, Dict, Tuple
+from datetime import datetime
+import os
+
+from modules.whisper.data_classes import *
+from modules.uvr.music_separator import MusicSeparator
+from modules.utils.paths import BACKEND_CACHE_DIR
+from backend.common.audio import read_audio
+from backend.common.models import QueueResponse
+from backend.common.config_loader import load_server_config
+from backend.common.compresser import get_file_hash, find_file_by_hash
+from backend.db.task.models import TaskStatus, TaskType, ResultType
+from backend.db.task.dao import add_task_to_db, update_task_status_in_db
+from .models import BGMSeparationResult
+
+
+bgm_separation_router = APIRouter(prefix="/bgm-separation", tags=["BGM Separation"])
+
+
+@functools.lru_cache
+def get_bgm_separation_inferencer() -> 'MusicSeparator':
+ config = load_server_config()["bgm_separation"]
+ inferencer = MusicSeparator(
+ output_dir=os.path.join(BACKEND_CACHE_DIR, "UVR")
+ )
+ inferencer.update_model(
+ model_name=config["model_size"],
+ device=config["device"]
+ )
+ return inferencer
+
+
+def run_bgm_separation(
+ audio: np.ndarray,
+ params: BGMSeparationParams,
+ identifier: str,
+) -> Tuple[np.ndarray, np.ndarray]:
+ update_task_status_in_db(
+ identifier=identifier,
+ update_data={
+ "uuid": identifier,
+ "status": TaskStatus.IN_PROGRESS,
+ "updated_at": datetime.utcnow()
+ }
+ )
+
+ start_time = datetime.utcnow()
+ instrumental, vocal, filepaths = get_bgm_separation_inferencer().separate(
+ audio=audio,
+ model_name=params.uvr_model_size,
+ device=params.uvr_device,
+ segment_size=params.segment_size,
+ save_file=True,
+ progress=gr.Progress()
+ )
+ instrumental_path, vocal_path = filepaths
+ elapsed_time = (datetime.utcnow() - start_time).total_seconds()
+
+ update_task_status_in_db(
+ identifier=identifier,
+ update_data={
+ "uuid": identifier,
+ "status": TaskStatus.COMPLETED,
+ "result": BGMSeparationResult(
+ instrumental_hash=get_file_hash(instrumental_path),
+ vocal_hash=get_file_hash(vocal_path)
+ ).model_dump(),
+ "result_type": ResultType.FILEPATH,
+ "updated_at": datetime.utcnow(),
+ "duration": elapsed_time
+ }
+ )
+ return instrumental, vocal
+
+
+@bgm_separation_router.post(
+ "/",
+ response_model=QueueResponse,
+ status_code=status.HTTP_201_CREATED,
+ summary="Separate Background BGM abd vocal",
+ description="Separate background music and vocal from an uploaded audio or video file.",
+)
+async def bgm_separation(
+ background_tasks: BackgroundTasks,
+ file: UploadFile = File(..., description="Audio or video file to separate background music."),
+ params: BGMSeparationParams = Depends()
+) -> QueueResponse:
+ if not isinstance(file, np.ndarray):
+ audio, info = await read_audio(file=file)
+ else:
+ audio, info = file, None
+
+ identifier = add_task_to_db(
+ status=TaskStatus.QUEUED,
+ file_name=file.filename,
+ audio_duration=info.duration if info else None,
+ task_type=TaskType.BGM_SEPARATION,
+ task_params=params.model_dump(),
+ )
+
+ background_tasks.add_task(
+ run_bgm_separation,
+ audio=audio,
+ params=params,
+ identifier=identifier
+ )
+
+ return QueueResponse(identifier=identifier, status=TaskStatus.QUEUED, message="BGM Separation task has queued")
+
+
diff --git a/backend/routers/task/__init__.py b/backend/routers/task/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backend/routers/task/router.py b/backend/routers/task/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d3046725da651492d51c2a8ba3eeaec033e9320
--- /dev/null
+++ b/backend/routers/task/router.py
@@ -0,0 +1,130 @@
+from fastapi import APIRouter, Depends, HTTPException, status
+from fastapi.responses import FileResponse
+from sqlalchemy.orm import Session
+import os
+
+from backend.db.db_instance import get_db_session
+from backend.db.task.dao import (
+ get_task_status_from_db,
+ get_all_tasks_status_from_db,
+ delete_task_from_db,
+)
+from backend.db.task.models import (
+ TasksResult,
+ Task,
+ TaskStatusResponse,
+ TaskType
+)
+from backend.common.models import (
+ Response,
+)
+from backend.common.compresser import compress_files, find_file_by_hash
+from modules.utils.paths import BACKEND_CACHE_DIR
+
+task_router = APIRouter(prefix="/task", tags=["Tasks"])
+
+
+@task_router.get(
+ "/{identifier}",
+ response_model=TaskStatusResponse,
+ status_code=status.HTTP_200_OK,
+ summary="Retrieve Task by Identifier",
+ description="Retrieve the specific task by its identifier.",
+)
+async def get_task(
+ identifier: str,
+ session: Session = Depends(get_db_session),
+) -> TaskStatusResponse:
+ """
+ Retrieve the specific task by its identifier.
+ """
+ task = get_task_status_from_db(identifier=identifier, session=session)
+
+ if task is not None:
+ return task.to_response()
+ else:
+ raise HTTPException(status_code=404, detail="Identifier not found")
+
+
+@task_router.get(
+ "/file/{identifier}",
+ status_code=status.HTTP_200_OK,
+ summary="Retrieve FileResponse Task by Identifier",
+ description="Retrieve the file response task by its identifier. You can use this endpoint if you need to download"
+ " The file as a response",
+)
+async def get_file_task(
+ identifier: str,
+ session: Session = Depends(get_db_session),
+) -> FileResponse:
+ """
+ Retrieve the downloadable file response of a specific task by its identifier.
+ Compressed by ZIP basically.
+ """
+ task = get_task_status_from_db(identifier=identifier, session=session)
+
+ if task is not None:
+ if task.task_type == TaskType.BGM_SEPARATION:
+ output_zip_path = os.path.join(BACKEND_CACHE_DIR, f"{identifier}_bgm_separation.zip")
+ instrumental_path = find_file_by_hash(
+ os.path.join(BACKEND_CACHE_DIR, "UVR", "instrumental"),
+ task.result["instrumental_hash"]
+ )
+ vocal_path = find_file_by_hash(
+ os.path.join(BACKEND_CACHE_DIR, "UVR", "vocals"),
+ task.result["vocal_hash"]
+ )
+
+ output_zip_path = compress_files(
+ [instrumental_path, vocal_path],
+ output_zip_path
+ )
+ return FileResponse(
+ path=output_zip_path,
+ status_code=200,
+ filename=output_zip_path,
+ media_type="application/zip"
+ )
+ else:
+ raise HTTPException(status_code=404, detail=f"File download is only supported for bgm separation."
+ f" The given type is {task.task_type}")
+ else:
+ raise HTTPException(status_code=404, detail="Identifier not found")
+
+
+# Delete method, commented by default because this endpoint is likely to require special permissions
+# @task_router.delete(
+# "/{identifier}",
+# response_model=Response,
+# status_code=status.HTTP_200_OK,
+# summary="Delete Task by Identifier",
+# description="Delete a task from the system using its identifier.",
+# )
+async def delete_task(
+ identifier: str,
+ session: Session = Depends(get_db_session),
+) -> Response:
+ """
+ Delete a task by its identifier.
+ """
+ if delete_task_from_db(identifier, session):
+ return Response(identifier=identifier, message="Task deleted")
+ else:
+ raise HTTPException(status_code=404, detail="Task not found")
+
+
+# Get All method, commented by default because this endpoint is likely to require special permissions
+# @task_router.get(
+# "/all",
+# response_model=TasksResult,
+# status_code=status.HTTP_200_OK,
+# summary="Retrieve All Task Statuses",
+# description="Retrieve the statuses of all tasks available in the system.",
+# )
+async def get_all_tasks_status(
+ session: Session = Depends(get_db_session),
+) -> TasksResult:
+ """
+ Retrieve all tasks.
+ """
+ return get_all_tasks_status_from_db(session=session)
\ No newline at end of file
diff --git a/backend/routers/transcription/__init__.py b/backend/routers/transcription/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backend/routers/transcription/router.py b/backend/routers/transcription/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..b03b2df0146928f4ac2ff28acf780f8fccf64927
--- /dev/null
+++ b/backend/routers/transcription/router.py
@@ -0,0 +1,123 @@
+import functools
+import uuid
+import numpy as np
+from fastapi import (
+ File,
+ UploadFile,
+)
+import gradio as gr
+from fastapi import APIRouter, BackgroundTasks, Depends, Response, status
+from typing import List, Dict
+from sqlalchemy.orm import Session
+from datetime import datetime
+from modules.whisper.data_classes import *
+from modules.utils.paths import BACKEND_CACHE_DIR
+from modules.whisper.faster_whisper_inference import FasterWhisperInference
+from backend.common.audio import read_audio
+from backend.common.models import QueueResponse
+from backend.common.config_loader import load_server_config
+from backend.db.task.dao import (
+ add_task_to_db,
+ get_db_session,
+ update_task_status_in_db
+)
+from backend.db.task.models import TaskStatus, TaskType
+
+transcription_router = APIRouter(prefix="/transcription", tags=["Transcription"])
+
+
+@functools.lru_cache
+def get_pipeline() -> 'FasterWhisperInference':
+ config = load_server_config()["whisper"]
+ inferencer = FasterWhisperInference(
+ output_dir=BACKEND_CACHE_DIR
+ )
+ inferencer.update_model(
+ model_size=config["model_size"],
+ compute_type=config["compute_type"]
+ )
+ return inferencer
+
+
+def run_transcription(
+ audio: np.ndarray,
+ params: TranscriptionPipelineParams,
+ identifier: str,
+) -> List[Segment]:
+ update_task_status_in_db(
+ identifier=identifier,
+ update_data={
+ "uuid": identifier,
+ "status": TaskStatus.IN_PROGRESS,
+ "updated_at": datetime.utcnow()
+ },
+ )
+
+ segments, elapsed_time = get_pipeline().run(
+ audio,
+ gr.Progress(),
+ "SRT",
+ False,
+ *params.to_list()
+ )
+ segments = [seg.model_dump() for seg in segments]
+
+ update_task_status_in_db(
+ identifier=identifier,
+ update_data={
+ "uuid": identifier,
+ "status": TaskStatus.COMPLETED,
+ "result": segments,
+ "updated_at": datetime.utcnow(),
+ "duration": elapsed_time
+ },
+ )
+ return segments
+
+
+@transcription_router.post(
+ "/",
+ response_model=QueueResponse,
+ status_code=status.HTTP_201_CREATED,
+ summary="Transcribe Audio",
+ description="Process the provided audio or video file to generate a transcription.",
+)
+async def transcription(
+ background_tasks: BackgroundTasks,
+ file: UploadFile = File(..., description="Audio or video file to transcribe."),
+ whisper_params: WhisperParams = Depends(),
+ vad_params: VadParams = Depends(),
+ bgm_separation_params: BGMSeparationParams = Depends(),
+ diarization_params: DiarizationParams = Depends(),
+) -> QueueResponse:
+ if not isinstance(file, np.ndarray):
+ audio, info = await read_audio(file=file)
+ else:
+ audio, info = file, None
+
+ params = TranscriptionPipelineParams(
+ whisper=whisper_params,
+ vad=vad_params,
+ bgm_separation=bgm_separation_params,
+ diarization=diarization_params
+ )
+
+ identifier = add_task_to_db(
+ status=TaskStatus.QUEUED,
+ file_name=file.filename,
+ audio_duration=info.duration if info else None,
+ language=params.whisper.lang,
+ task_type=TaskType.TRANSCRIPTION,
+ task_params=params.to_dict(),
+ )
+
+ background_tasks.add_task(
+ run_transcription,
+ audio=audio,
+ params=params,
+ identifier=identifier,
+ )
+
+ return QueueResponse(identifier=identifier, status=TaskStatus.QUEUED, message="Transcription task has queued")
+
+
diff --git a/backend/routers/vad/__init__.py b/backend/routers/vad/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backend/routers/vad/router.py b/backend/routers/vad/router.py
new file mode 100644
index 0000000000000000000000000000000000000000..3beacf4ce0227dd863576193ac16df88365a10f5
--- /dev/null
+++ b/backend/routers/vad/router.py
@@ -0,0 +1,101 @@
+import functools
+import numpy as np
+from faster_whisper.vad import VadOptions
+from fastapi import (
+ File,
+ UploadFile,
+)
+from fastapi import APIRouter, BackgroundTasks, Depends, Response, status
+from typing import List, Dict
+from datetime import datetime
+
+from modules.vad.silero_vad import SileroVAD
+from modules.whisper.data_classes import VadParams
+from backend.common.audio import read_audio
+from backend.common.models import QueueResponse
+from backend.db.task.dao import add_task_to_db, update_task_status_in_db
+from backend.db.task.models import TaskStatus, TaskType
+
+vad_router = APIRouter(prefix="/vad", tags=["Voice Activity Detection"])
+
+
+@functools.lru_cache
+def get_vad_model() -> SileroVAD:
+ inferencer = SileroVAD()
+ inferencer.update_model()
+ return inferencer
+
+
+def run_vad(
+ audio: np.ndarray,
+ params: VadOptions,
+ identifier: str,
+) -> List[Dict]:
+ update_task_status_in_db(
+ identifier=identifier,
+ update_data={
+ "uuid": identifier,
+ "status": TaskStatus.IN_PROGRESS,
+ "updated_at": datetime.utcnow()
+ }
+ )
+
+ start_time = datetime.utcnow()
+ audio, speech_chunks = get_vad_model().run(
+ audio=audio,
+ vad_parameters=params
+ )
+ elapsed_time = (datetime.utcnow() - start_time).total_seconds()
+
+ update_task_status_in_db(
+ identifier=identifier,
+ update_data={
+ "uuid": identifier,
+ "status": TaskStatus.COMPLETED,
+ "updated_at": datetime.utcnow(),
+ "result": speech_chunks,
+ "duration": elapsed_time
+ }
+ )
+
+ return speech_chunks
+
+
+@vad_router.post(
+ "/",
+ response_model=QueueResponse,
+ status_code=status.HTTP_201_CREATED,
+ summary="Voice Activity Detection",
+ description="Detect voice parts in the provided audio or video file to generate a timeline of speech segments.",
+)
+async def vad(
+ background_tasks: BackgroundTasks,
+ file: UploadFile = File(..., description="Audio or video file to detect voices."),
+ params: VadParams = Depends()
+) -> QueueResponse:
+ if not isinstance(file, np.ndarray):
+ audio, info = await read_audio(file=file)
+ else:
+ audio, info = file, None
+
+ vad_options = VadOptions(
+ threshold=params.threshold,
+ min_speech_duration_ms=params.min_speech_duration_ms,
+ max_speech_duration_s=params.max_speech_duration_s,
+ min_silence_duration_ms=params.min_silence_duration_ms,
+ speech_pad_ms=params.speech_pad_ms
+ )
+
+ identifier = add_task_to_db(
+ status=TaskStatus.QUEUED,
+ file_name=file.filename,
+ audio_duration=info.duration if info else None,
+ task_type=TaskType.VAD,
+ task_params=params.model_dump(),
+ )
+
+ background_tasks.add_task(run_vad, audio=audio, params=vad_options, identifier=identifier)
+
+ return QueueResponse(identifier=identifier, status=TaskStatus.QUEUED, message="VAD task has queued")
+
+
diff --git a/backend/tests/__init__.py b/backend/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/backend/tests/test_backend_bgm_separation.py b/backend/tests/test_backend_bgm_separation.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0dc86bd1c2d33c4e5ab70a07616034d71740b95
--- /dev/null
+++ b/backend/tests/test_backend_bgm_separation.py
@@ -0,0 +1,59 @@
+import pytest
+from fastapi import UploadFile
+from io import BytesIO
+import os
+import torch
+
+from backend.db.task.models import TaskStatus
+from backend.tests.test_task_status import wait_for_task_completion, fetch_file_response
+from backend.tests.test_backend_config import (
+ get_client, setup_test_file, get_upload_file_instance, calculate_wer,
+ TEST_BGM_SEPARATION_PARAMS, TEST_ANSWER, TEST_BGM_SEPARATION_OUTPUT_PATH
+)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip the test because CUDA is not available")
+@pytest.mark.parametrize(
+ "bgm_separation_params",
+ [
+ TEST_BGM_SEPARATION_PARAMS
+ ]
+)
+def test_transcription_endpoint(
+ get_upload_file_instance,
+ bgm_separation_params: dict
+):
+ client = get_client()
+ file_content = BytesIO(get_upload_file_instance.file.read())
+ get_upload_file_instance.file.seek(0)
+
+ response = client.post(
+ "/bgm-separation",
+ files={"file": (get_upload_file_instance.filename, file_content, "audio/mpeg")},
+ params=bgm_separation_params
+ )
+
+ assert response.status_code == 201
+ assert response.json()["status"] == TaskStatus.QUEUED
+ task_identifier = response.json()["identifier"]
+ assert isinstance(task_identifier, str) and task_identifier
+
+ completed_task = wait_for_task_completion(
+ identifier=task_identifier
+ )
+
+ assert completed_task is not None, f"Task with identifier {task_identifier} did not complete within the " \
+ f"expected time."
+
+ result = completed_task.json()["result"]
+ assert "instrumental_hash" in result and result["instrumental_hash"]
+ assert "vocal_hash" in result and result["vocal_hash"]
+
+ file_response = fetch_file_response(task_identifier)
+ assert file_response.status_code == 200, f"Fetching File Response has failed. Response is: {file_response}"
+
+ with open(TEST_BGM_SEPARATION_OUTPUT_PATH, "wb") as file:
+ file.write(file_response.content)
+
+ assert os.path.exists(TEST_BGM_SEPARATION_OUTPUT_PATH)
+
diff --git a/backend/tests/test_backend_config.py b/backend/tests/test_backend_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e19d505de1fa00647df045cc18b121ac6e3d635
--- /dev/null
+++ b/backend/tests/test_backend_config.py
@@ -0,0 +1,67 @@
+import functools
+from fastapi import FastAPI, UploadFile
+from fastapi.testclient import TestClient
+from starlette.datastructures import UploadFile as StarletteUploadFile
+from io import BytesIO
+import os
+import requests
+import pytest
+import yaml
+import jiwer
+
+from backend.main import app
+from modules.whisper.data_classes import *
+from modules.utils.paths import *
+from modules.utils.files_manager import load_yaml, save_yaml
+
+TEST_PIPELINE_PARAMS = {**WhisperParams(model_size="tiny", compute_type="float32").model_dump(exclude_none=True),
+ **VadParams().model_dump(exclude_none=True),
+ **BGMSeparationParams().model_dump(exclude_none=True),
+ **DiarizationParams().model_dump(exclude_none=True)}
+TEST_VAD_PARAMS = VadParams().model_dump()
+TEST_BGM_SEPARATION_PARAMS = BGMSeparationParams().model_dump()
+TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
+TEST_FILE_PATH = os.path.join(WEBUI_DIR, "backend", "tests", "jfk.wav")
+TEST_BGM_SEPARATION_OUTPUT_PATH = os.path.join(WEBUI_DIR, "backend", "tests", "separated_audio.zip")
+TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country"
+TEST_WHISPER_MODEL = "tiny"
+TEST_COMPUTE_TYPE = "float32"
+
+
+@pytest.fixture(autouse=True)
+def setup_test_file():
+ @functools.lru_cache
+ def download_file(url=TEST_FILE_DOWNLOAD_URL, file_path=TEST_FILE_PATH):
+ if os.path.exists(file_path):
+ return
+
+ if not os.path.exists(os.path.dirname(file_path)):
+ os.makedirs(os.path.dirname(file_path))
+
+ response = requests.get(url)
+
+ with open(file_path, "wb") as file:
+ file.write(response.content)
+
+ print(f"File downloaded to: {file_path}")
+
+ download_file(TEST_FILE_DOWNLOAD_URL, TEST_FILE_PATH)
+
+
+@pytest.fixture
+@functools.lru_cache
+def get_upload_file_instance(filepath: str = TEST_FILE_PATH) -> UploadFile:
+ with open(filepath, "rb") as f:
+ file_contents = BytesIO(f.read())
+ filename = os.path.basename(filepath)
+ upload_file = StarletteUploadFile(file=file_contents, filename=filename)
+ return upload_file
+
+
+@functools.lru_cache
+def get_client(app: FastAPI = app):
+ return TestClient(app)
+
+
+def calculate_wer(answer, prediction):
+ return jiwer.wer(answer, prediction)
diff --git a/backend/tests/test_backend_transcription.py b/backend/tests/test_backend_transcription.py
new file mode 100644
index 0000000000000000000000000000000000000000..bda2e346ff8f905e7a6ec71e8386c39d93d72e23
--- /dev/null
+++ b/backend/tests/test_backend_transcription.py
@@ -0,0 +1,50 @@
+import pytest
+from fastapi import UploadFile
+from io import BytesIO
+
+from backend.db.task.models import TaskStatus
+from backend.tests.test_task_status import wait_for_task_completion
+from backend.tests.test_backend_config import (
+ get_client, setup_test_file, get_upload_file_instance, calculate_wer,
+ TEST_PIPELINE_PARAMS, TEST_ANSWER
+)
+
+
+@pytest.mark.parametrize(
+ "pipeline_params",
+ [
+ TEST_PIPELINE_PARAMS
+ ]
+)
+def test_transcription_endpoint(
+ get_upload_file_instance,
+ pipeline_params: dict
+):
+ client = get_client()
+ file_content = BytesIO(get_upload_file_instance.file.read())
+ get_upload_file_instance.file.seek(0)
+
+ response = client.post(
+ "/transcription",
+ files={"file": (get_upload_file_instance.filename, file_content, "audio/mpeg")},
+ params=pipeline_params
+ )
+
+ assert response.status_code == 201
+ assert response.json()["status"] == TaskStatus.QUEUED
+ task_identifier = response.json()["identifier"]
+ assert isinstance(task_identifier, str) and task_identifier
+
+ completed_task = wait_for_task_completion(
+ identifier=task_identifier
+ )
+
+ assert completed_task is not None, f"Task with identifier {task_identifier} did not complete within the " \
+ f"expected time."
+
+ result = completed_task.json()["result"]
+ assert result, "Transcription text is empty"
+
+ wer = calculate_wer(TEST_ANSWER, result[0]["text"].strip().replace(",", "").replace(".", ""))
+ assert wer < 0.1, f"WER is too high, it's {wer}"
+
diff --git a/backend/tests/test_backend_vad.py b/backend/tests/test_backend_vad.py
new file mode 100644
index 0000000000000000000000000000000000000000..92fe5c546b725a0ededd8b6edc58aee79dfd0496
--- /dev/null
+++ b/backend/tests/test_backend_vad.py
@@ -0,0 +1,47 @@
+import pytest
+from fastapi import UploadFile
+from io import BytesIO
+
+from backend.db.task.models import TaskStatus
+from backend.tests.test_task_status import wait_for_task_completion
+from backend.tests.test_backend_config import (
+ get_client, setup_test_file, get_upload_file_instance, calculate_wer,
+ TEST_VAD_PARAMS, TEST_ANSWER
+)
+
+
+@pytest.mark.parametrize(
+ "vad_params",
+ [
+ TEST_VAD_PARAMS
+ ]
+)
+def test_transcription_endpoint(
+ get_upload_file_instance,
+ vad_params: dict
+):
+ client = get_client()
+ file_content = BytesIO(get_upload_file_instance.file.read())
+ get_upload_file_instance.file.seek(0)
+
+ response = client.post(
+ "/vad",
+ files={"file": (get_upload_file_instance.filename, file_content, "audio/mpeg")},
+ params=vad_params
+ )
+
+ assert response.status_code == 201
+ assert response.json()["status"] == TaskStatus.QUEUED
+ task_identifier = response.json()["identifier"]
+ assert isinstance(task_identifier, str) and task_identifier
+
+ completed_task = wait_for_task_completion(
+ identifier=task_identifier
+ )
+
+ assert completed_task is not None, f"Task with identifier {task_identifier} did not complete within the " \
+ f"expected time."
+
+ result = completed_task.json()["result"]
+ assert result and "start" in result[0] and "end" in result[0]
+
diff --git a/backend/tests/test_task_status.py b/backend/tests/test_task_status.py
new file mode 100644
index 0000000000000000000000000000000000000000..95fa91daf31b1cdd5592b3988feabe9fbb88f38e
--- /dev/null
+++ b/backend/tests/test_task_status.py
@@ -0,0 +1,56 @@
+import time
+import pytest
+from typing import Optional, Union
+import httpx
+
+from backend.db.task.models import TaskStatus, Task
+from backend.tests.test_backend_config import get_client
+
+
+def fetch_task(identifier: str):
+ """Get task status"""
+ client = get_client()
+ response = client.get(
+ f"/task/{identifier}"
+ )
+ if response.status_code == 200:
+ return response
+ return None
+
+
+def fetch_file_response(identifier: str):
+ """Get task status"""
+ client = get_client()
+ response = client.get(
+ f"/task/file/{identifier}"
+ )
+ if response.status_code == 200:
+ return response
+ return None
+
+
+def wait_for_task_completion(identifier: str,
+ max_attempts: int = 20,
+ frequency: int = 3) -> httpx.Response:
+ """
+ Polls the task status until it is completed, failed, or the maximum attempts are reached.
+
+ Args:
+ identifier (str): The unique identifier of the task to monitor.
+ max_attempts (int): The maximum number of polling attempts..
+ frequency (int): The time (in seconds) to wait between polling attempts.
+
+ Returns:
+ bool: Returns json if the task completes successfully within the allowed attempts.
+ """
+ attempts = 0
+ while attempts < max_attempts:
+ task = fetch_task(identifier)
+ status = task.json()["status"]
+ if status == TaskStatus.COMPLETED:
+ return task
+ if status == TaskStatus.FAILED:
+ raise Exception("Task polling failed")
+ time.sleep(frequency)
+ attempts += 1
+ return None
diff --git a/configs/default_parameters.yaml b/configs/default_parameters.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c563383426ef5783177b9cd0886018ddf153b7cf
--- /dev/null
+++ b/configs/default_parameters.yaml
@@ -0,0 +1,65 @@
+whisper:
+ model_size: "large-v2"
+ file_format: "SRT"
+ lang: "Automatic Detection"
+ is_translate: false
+ beam_size: 5
+ log_prob_threshold: -1
+ no_speech_threshold: 0.6
+ best_of: 5
+ patience: 1
+ condition_on_previous_text: true
+ prompt_reset_on_temperature: 0.5
+ initial_prompt: null
+ temperature: 0
+ compression_ratio_threshold: 2.4
+ chunk_length: 30
+ batch_size: 24
+ length_penalty: 1
+ repetition_penalty: 1
+ no_repeat_ngram_size: 0
+ prefix: null
+ suppress_blank: true
+ suppress_tokens: "[-1]"
+ max_initial_timestamp: 1
+ word_timestamps: false
+ prepend_punctuations: "\"'“¿([{-"
+ append_punctuations: "\"'.。,,!!??::”)]}、"
+ max_new_tokens: null
+ hallucination_silence_threshold: null
+ hotwords: null
+ language_detection_threshold: 0.5
+ language_detection_segments: 1
+ add_timestamp: false
+
+vad:
+ vad_filter: false
+ threshold: 0.5
+ min_speech_duration_ms: 250
+ max_speech_duration_s: 9999
+ min_silence_duration_ms: 1000
+ speech_pad_ms: 2000
+
+diarization:
+ is_diarize: false
+ hf_token: ""
+
+bgm_separation:
+ is_separate_bgm: false
+ uvr_model_size: "UVR-MDX-NET-Inst_HQ_4"
+ segment_size: 256
+ save_file: false
+ enable_offload: true
+
+translation:
+ deepl:
+ api_key: ""
+ is_pro: false
+ source_lang: "Automatic Detection"
+ target_lang: "English"
+ nllb:
+ model_size: "facebook/nllb-200-1.3B"
+ source_lang: null
+ target_lang: null
+ max_length: 200
+ add_timestamp: false
diff --git a/configs/translation.yaml b/configs/translation.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..340edd316cfcdbe7d936936178a45f781001c7e2
--- /dev/null
+++ b/configs/translation.yaml
@@ -0,0 +1,506 @@
+en: # English
+ Language: Language
+ File: File
+ Youtube: Youtube
+ Mic: Mic
+ T2T Translation: T2T Translation
+ BGM Separation: BGM Separation
+ GENERATE SUBTITLE FILE: GENERATE SUBTITLE FILE
+ Output: Output
+ Downloadable output file: Downloadable output file
+ Upload File here: Upload File here
+ Model: Model
+ Automatic Detection: Automatic Detection
+ File Format: File Format
+ Translate to English?: Translate to English? (Whisper's End-To-End Speech-To-Text translation feature)
+ Add a timestamp to the end of the filename: Add a timestamp to the end of the filename
+ Advanced Parameters: Advanced Parameters
+ Background Music Remover Filter: Background Music Remover Filter
+ Enabling this will remove background music: Enabling this will remove background music by submodel before transcribing
+ Enable Background Music Remover Filter: Enable Background Music Remover Filter
+ Save separated files to output: Save separated files to output
+ Offload sub model after removing background music: Offload sub model after removing background music
+ Voice Detection Filter: Voice Detection Filter
+ Enable this to transcribe only detected voice: Enable this to transcribe only detected voice parts by submodel.
+ Enable Silero VAD Filter: Enable Silero VAD Filter
+ Diarization: Diarization
+ Enable Diarization: Enable Diarization
+ HuggingFace Token: HuggingFace Token
+ This is only needed the first time you download the model: This is only needed the first time you download the model. If you already have models, you don't need to enter. To download the model, you must manually go to "https://huggingface.co./pyannote/speaker-diarization-3.1" and "https://huggingface.co./pyannote/segmentation-3.0" and agree to their requirement.
+ Device: Device
+ Youtube Link: Youtube Link
+ Youtube Thumbnail: Youtube Thumbnail
+ Youtube Title: Youtube Title
+ Youtube Description: Youtube Description
+ Record with Mic: Record with Mic
+ Upload Subtitle Files to translate here: Upload Subtitle Files to translate here
+ Your Auth Key (API KEY): Your Auth Key (API KEY)
+ Source Language: Source Language
+ Target Language: Target Language
+ Pro User?: Pro User?
+ TRANSLATE SUBTITLE FILE: TRANSLATE SUBTITLE FILE
+ Upload Audio Files to separate background music: Upload Audio Files to separate background music
+ Instrumental: Instrumental
+ Vocals: Vocals
+ SEPARATE BACKGROUND MUSIC: SEPARATE BACKGROUND MUSIC
+
+ko: # Korean
+ Language: 언어
+ File: 파일
+ Youtube: 유튜브
+ Mic: 마이크
+ T2T Translation: T2T 자막 번역
+ BGM Separation: 배경 음악 분리
+ GENERATE SUBTITLE FILE: 자막 파일 생성
+ Output: 결과물
+ Downloadable output file: 결과물 파일 다운로드
+ Upload File here: 파일을 업로드 하세요
+ Model: 모델
+ Automatic Detection: 자동 감지
+ File Format: 파일 형식
+ Translate to English?: 영어로 번역합니까? (위스퍼 모델 자체 번역 기능)
+ Add a timestamp to the end of the filename: 파일 이름 끝에 타임스태프 붙이기
+ Advanced Parameters: 고급 변수
+ Background Music Remover Filter: 배경 음악 제거 필터
+ Enabling this will remove background music: 받아쓰기 이전에 먼저 배경 음악 제거용 서브 모델을 활성화 합니다.
+ Enable Background Music Remover Filter: 배경 음악 제거 필터 활성화
+ Save separated files to output: 분리된 배경 음악 & 음성 파일 따로 출력 폴더에 저장
+ Offload sub model after removing background music: 배경 음악 제거 후 서브 모델을 비활성화 합니다. (VRAM 이 부족할 시 체크하세요.)
+ Voice Detection Filter: 목소리 감지 필터
+ Enable this to transcribe only detected voice: 서브 모델에 의해 목소리라고 판단된 부분만 받아쓰기를 진행합니다.
+ Enable Silero VAD Filter: Silero VAD 필터 활성화
+ Diarization: 화자 구분
+ Enable Diarization: 화자 구분 활성화
+ HuggingFace Token: 허깅페이스 토큰
+ This is only needed the first time you download the model: 모델을 처음 다운받을 때만 토큰이 필요합니다. 이미 다운로드 받으신 상태라면 입력하지 않아도 됩니다. 모델을 다운 받기 위해선 "https://huggingface.co./pyannote/speaker-diarization-3.1" 와 "https://huggingface.co./pyannote/segmentation-3.0" 에서 먼저 사용 지침에 동의하셔야 합니다.
+ Device: 디바이스
+ Youtube Link: 유튜브 링크
+ Youtube Thumbnail: 유튜브 썸네일
+ Youtube Title: 유튜브 제목
+ Youtube Description: 유튜브 설명
+ Record with Mic: 마이크로 녹음하세요
+ Upload Subtitle Files to translate here: 번역할 자막 파일을 업로드 하세요
+ Your Auth Key (API KEY): DeepL API 키
+ Source Language: 원본 언어
+ Target Language: 대상 언어
+ Pro User?: Pro 버전 사용자
+ TRANSLATE SUBTITLE FILE: 자막 파일 번역
+ Upload Audio Files to separate background music: 배경 음악을 분리할 오디오 파일을 업로드 하세요
+ Instrumental: 악기
+ Vocals: 보컬
+ SEPARATE BACKGROUND MUSIC: 배경 음악 분리
+
+ja: # Japanese
+ Language: 言語
+ File: File
+ Youtube: Youtube
+ Mic: Mic
+ T2T Translation: T2T Translation
+ BGM Separation: BGM Separation
+ GENERATE SUBTITLE FILE: GENERATE SUBTITLE FILE
+ Output: Output
+ Downloadable output file: Downloadable output file
+ Upload File here: Upload File here
+ Model: Model
+ Automatic Detection: Automatic Detection
+ File Format: File Format
+ Translate to English?: Translate to English?
+ Add a timestamp to the end of the filename: Add a timestamp to the end of the filename
+ Advanced Parameters: Advanced Parameters
+ Background Music Remover Filter: Background Music Remover Filter
+ Enabling this will remove background music: Enabling this will remove background music by submodel before transcribing
+ Enable Background Music Remover Filter: Enable Background Music Remover Filter
+ Save separated files to output: Save separated files to output
+ Offload sub model after removing background music: Offload sub model after removing background music
+ Voice Detection Filter: Voice Detection Filter
+ Enable this to transcribe only detected voice: Enable this to transcribe only detected voice parts by submodel.
+ Enable Silero VAD Filter: Enable Silero VAD Filter
+ Diarization: Diarization
+ Enable Diarization: Enable Diarization
+ HuggingFace Token: HuggingFace Token
+ This is only needed the first time you download the model: This is only needed the first time you download the model. If you already have models, you don't need to enter. To download the model, you must manually go to "https://huggingface.co./pyannote/speaker-diarization-3.1" and "https://huggingface.co./pyannote/segmentation-3.0" and agree to their requirement.
+ Device: Device
+ Youtube Link: Youtube Link
+ Youtube Thumbnail: Youtube Thumbnail
+ Youtube Title: Youtube Title
+ Youtube Description: Youtube Description
+ Record with Mic: Record with Mic
+ Upload Subtitle Files to translate here: Upload Subtitle Files to translate here
+ Your Auth Key (API KEY): Your Auth Key (API KEY)
+ Source Language: Source Language
+ Target Language: Target Language
+ Pro User?: Pro User?
+ TRANSLATE SUBTITLE FILE: TRANSLATE SUBTITLE FILE
+ Upload Audio Files to separate background music: Upload Audio Files to separate background music
+ Instrumental: Instrumental
+ Vocals: Vocals
+ SEPARATE BACKGROUND MUSIC: SEPARATE BACKGROUND MUSIC
+
+es: # Spanish
+ Language: Idioma
+ File: File
+ Youtube: Youtube
+ Mic: Mic
+ T2T Translation: T2T Translation
+ BGM Separation: BGM Separation
+ GENERATE SUBTITLE FILE: GENERATE SUBTITLE FILE
+ Output: Output
+ Downloadable output file: Downloadable output file
+ Upload File here: Upload File here
+ Model: Model
+ Automatic Detection: Automatic Detection
+ File Format: File Format
+ Translate to English?: Translate to English?
+ Add a timestamp to the end of the filename: Add a timestamp to the end of the filename
+ Advanced Parameters: Advanced Parameters
+ Background Music Remover Filter: Background Music Remover Filter
+ Enabling this will remove background music: Enabling this will remove background music by submodel before transcribing
+ Enable Background Music Remover Filter: Enable Background Music Remover Filter
+ Save separated files to output: Save separated files to output
+ Offload sub model after removing background music: Offload sub model after removing background music
+ Voice Detection Filter: Voice Detection Filter
+ Enable this to transcribe only detected voice: Enable this to transcribe only detected voice parts by submodel.
+ Enable Silero VAD Filter: Enable Silero VAD Filter
+ Diarization: Diarization
+ Enable Diarization: Enable Diarization
+ HuggingFace Token: HuggingFace Token
+ This is only needed the first time you download the model: This is only needed the first time you download the model. If you already have models, you don't need to enter. To download the model, you must manually go to "https://huggingface.co./pyannote/speaker-diarization-3.1" and "https://huggingface.co./pyannote/segmentation-3.0" and agree to their requirement.
+ Device: Device
+ Youtube Link: Youtube Link
+ Youtube Thumbnail: Youtube Thumbnail
+ Youtube Title: Youtube Title
+ Youtube Description: Youtube Description
+ Record with Mic: Record with Mic
+ Upload Subtitle Files to translate here: Upload Subtitle Files to translate here
+ Your Auth Key (API KEY): Your Auth Key (API KEY)
+ Source Language: Source Language
+ Target Language: Target Language
+ Pro User?: Pro User?
+ TRANSLATE SUBTITLE FILE: TRANSLATE SUBTITLE FILE
+ Upload Audio Files to separate background music: Upload Audio Files to separate background music
+ Instrumental: Instrumental
+ Vocals: Vocals
+ SEPARATE BACKGROUND MUSIC: SEPARATE BACKGROUND MUSIC
+
+fr: # French
+ Language: Langue
+ File: Fichier
+ Youtube: Youtube
+ Mic: Microphone
+ T2T Translation: Traduction T2T
+ BGM Separation: Séparation de la musique de fond
+ GENERATE SUBTITLE FILE: GÉNÉRER UN FICHIER DE SOUS-TITRES
+ Output: Sortie
+ Downloadable output file: Fichier de sortie téléchargeable
+ Upload File here: Téléverser un fichier ici
+ Model: Modèle
+ Automatic Detection: Détection automatique
+ File Format: Format de fichier
+ Translate to English?: Traduire en anglais ?
+ Add a timestamp to the end of the filename: Ajouter un timestamp à la fin du nom de fichier
+ Advanced Parameters: Paramètres avancés
+ Background Music Remover Filter: Filtre de suppression de musique de fond
+ Enabling this will remove background music: L'activation supprimera la musique de fond via un sous-modèle avant la transcription
+ Enable Background Music Remover Filter: Activer le filtre de suppression de musique de fond
+ Save separated files to output: Sauvegarder les fichiers séparés dans la sortie
+ Offload sub model after removing background music: Décharger le sous-modèle après avoir supprimé la musique de fond
+ Voice Detection Filter: Filtre de détection vocale
+ Enable this to transcribe only detected voice: Activer pour transcrire uniquement la voix détectée
+ Enable Silero VAD Filter: Activer le filtre Silero VAD
+ Diarization: Diarisation
+ Enable Diarization: Activer la diarisation
+ HuggingFace Token: Token HuggingFace
+ This is only needed the first time you download the model: Le token n'est nécessaire que lors du premier téléchargement du modèle. Si vous avez déjà les modèles, vous n'avez pas besoin de l'entrer. Pour télécharger le modèle, vous devez aller manuellement sur "https://huggingface.co./pyannote/speaker-diarization-3.1" et "https://huggingface.co./pyannote/segmentation-3.0" et accepter leurs conditions.
+ Device: Périphérique
+ Youtube Link: Lien Youtube
+ Youtube Thumbnail: Miniature Youtube
+ Youtube Title: Titre Youtube
+ Youtube Description: Description Youtube
+ Record with Mic: Enregistrer avec le microphone
+ Upload Subtitle Files to translate here: Téléverser ici les fichiers de sous-titres à traduire
+ Your Auth Key (API KEY): Votre clé d'authentification (API KEY)
+ Source Language: Langue source
+ Target Language: Langue cible
+ Pro User?: Utilisateur Pro ?
+ TRANSLATE SUBTITLE FILE: TRADUIRE LE FICHIER DE SOUS-TITRES
+ Upload Audio Files to separate background music: Téléverser les fichiers audio pour séparer la musique de fond
+ Instrumental: Instrumental
+ Vocals: Voix
+ SEPARATE BACKGROUND MUSIC: SÉPARER LA MUSIQUE DE FOND
+
+de: # Deutsch
+ Language: Sprache
+ File: Datei
+ Youtube: Youtube
+ Mic: Mikrofon
+ T2T Translation: Text zu Text Übersetzung
+ BGM Separation: Hintergrundmusik herausfiltern
+ GENERATE SUBTITLE FILE: UNTERTITELDATEI ERSTELLEN
+ Output: Ausgabe
+ Downloadable output file: Herunterladbare Ausgabedatei
+ Upload File here: Datei hier hochladen
+ Model: Modell
+ Automatic Detection: Automatische Erkennung
+ File Format: Dateiformat
+ Translate to English?: Ins Englische übersetzen?
+ Add a timestamp to the end of the filename: Zeitstempel an den Dateinamen anhängen
+ Advanced Parameters: Erweiterte Parameter
+ Background Music Remover Filter: Hintergrundmusik-Entfernungsfilter
+ Enabling this will remove background music: Aktivierung entfernt Hintergrundmusik mithilfe des Submodells vor der Transkription
+ Enable Background Music Remover Filter: Hintergrundmusik-Entfernungsfilter aktivieren
+ Save separated files to output: Getrennte Dateien in der Ausgabe speichern
+ Offload sub model after removing background music: Submodell nach Entfernung der Hintergrundmusik entladen
+ Voice Detection Filter: Sprachfilter
+ Enable this to transcribe only detected voice: Aktivieren, um nur erkannte Sprachsegmente mithilfe des Submodells zu transkribieren.
+ Enable Silero VAD Filter: Silero VAD-Filter aktivieren
+ Diarization: Diarisierung
+ Enable Diarization: Diarisierung aktivieren
+ HuggingFace Token: HuggingFace-Token
+ This is only needed the first time you download the model: Dies ist nur erforderlich, wenn Sie das Modell zum ersten Mal herunterladen. Falls Sie die Modelle bereits besitzen, ist dies nicht nötig. Um das Modell herunterzuladen, müssen Sie manuell die Seiten „https://huggingface.co./pyannote/speaker-diarization-3.1“ und „https://huggingface.co./pyannote/segmentation-3.0“ besuchen und deren Anforderungen zustimmen.
+ Device: Gerät
+ Youtube Link: Youtube-Link
+ Youtube Thumbnail: Youtube-Thumbnail
+ Youtube Title: Youtube-Titel
+ Youtube Description: Youtube-Beschreibung
+ Record with Mic: Mit Mikrofon aufnehmen
+ Upload Subtitle Files to translate here: Untertiteldateien zur Übersetzung hier hochladen
+ Your Auth Key (API KEY): Ihr Authentifizierungsschlüssel (API-Schlüssel)
+ Source Language: Ausgangssprache
+ Target Language: Zielsprache
+ Pro User?: Pro-Benutzer?
+ TRANSLATE SUBTITLE FILE: UNTERTITELDATEI ÜBERSETZEN
+ Upload Audio Files to separate background music: Audiodateien hochladen, um Hintergrundmusik zu trennen
+ Instrumental: Instrumental
+ Vocals: Gesang
+ SEPARATE BACKGROUND MUSIC: HINTERGRUNDMUSIK TRENNEN
+
+
+zh: # Chinese
+ Language: 语言
+ File: 文件
+ Youtube: Youtube
+ Mic: 麦克风
+ T2T Translation: 文字翻译
+ BGM Separation: 背景音乐分离
+ GENERATE SUBTITLE FILE: 生成字幕文件
+ Output: 输出
+ Downloadable output file: 生成的字幕文件
+ Upload File here: 上传文件
+ Model: 模型选择
+ Automatic Detection: 自动检测语言
+ File Format: 字幕文件格式
+ Translate to English?: 将字幕翻译成英语
+ Add a timestamp to the end of the filename: 在文件名末尾加入时间戳
+ Advanced Parameters: 高级参数
+ Background Music Remover Filter: 背景音乐消除设置
+ Enabling this will remove background music: 启用此功能将会在进行转录前去除背景音乐
+ Enable Background Music Remover Filter: 去除背景音乐
+ Save separated files to output: 导出分离出的音频文件
+ Offload sub model after removing background music: 移除背景音乐后卸载子模型
+ Voice Detection Filter: 话音检测设置
+ Enable this to transcribe only detected voice: 启用此功能将仅转录检测到的语音部分
+ Enable Silero VAD Filter: 启用 Silero 语音活动检测 (VAD)
+ Diarization: 说话人分离设置
+ Enable Diarization: 进行说话人分离处理
+ HuggingFace Token: HuggingFace令牌
+ This is only needed the first time you download the model: 当且仅当您首次下载模型时才需要在此输入令牌,如果您已有模型则无需输入。 要下载模型,您必须先在"https://huggingface.co./pyannote/speaker-diarization-3.1" 与 "https://huggingface.co./pyannote/segmentation-3.0" 上同意他们的使用条款。
+ Device: 设备
+ Youtube Link: Youtube视频链接
+ Youtube Thumbnail: Youtube视频封面
+ Youtube Title: Youtube视频标题
+ Youtube Description: Youtube视频简介
+ Record with Mic: 通过麦克风录制
+ Upload Subtitle Files to translate here: 上传待翻译的字幕文件
+ Your Auth Key (API KEY): 填写您的认证令牌(API KEY)
+ Source Language: 原语言
+ Target Language: 目标语言
+ Pro User?: 您是否为专业版用户?
+ TRANSLATE SUBTITLE FILE: 翻译字幕文件
+ Upload Audio Files to separate background music: 上传待分离背景音乐的音频文件
+ Instrumental: 器乐声
+ Vocals: 人声
+ SEPARATE BACKGROUND MUSIC: 分离背景音乐与人声
+
+uk: # Ukrainian
+ Language: Мова
+ File: Файл
+ Youtube: Youtube
+ Mic: Мікрофон
+ T2T Translation: T2T Переклад
+ BGM Separation: Розділення фонової музики
+ GENERATE SUBTITLE FILE: СТВОРИТИ ФАЙЛ СУБТИТРІВ
+ Output: Результат
+ Downloadable output file: Завантажуваний файл результату
+ Upload File here: Завантажте файл тут
+ Model: Модель
+ Automatic Detection: Автоматичне визначення
+ File Format: Формат файлу
+ Translate to English?: Перекласти на англійську?
+ Add a timestamp to the end of the filename: Додати мітку часу до кінця імені файлу
+ Advanced Parameters: Розширені параметри
+ Background Music Remover Filter: Фільтр видалення фонової музики
+ Enabling this will remove background music: Увімкнення цього видалить фонову музику за допомогою підмоделі перед транскрипцією
+ Enable Background Music Remover Filter: Увімкнути фільтр видалення фонової музики
+ Save separated files to output: Зберегти розділені файли до вихідної папки
+ Offload sub model after removing background music: Вивантажити підмодель після видалення фонової музики
+ Voice Detection Filter: Фільтр розпізнавання голосу
+ Enable this to transcribe only detected voice: Увімкніть це, щоб транскрибувати лише розпізнані голосові частини за допомогою підмоделі
+ Enable Silero VAD Filter: Увімкнути фільтр Silero VAD
+ Diarization: Діаризація
+ Enable Diarization: Увімкнути діаризацію
+ HuggingFace Token: Токен HuggingFace
+ This is only needed the first time you download the model: Це потрібно лише при першому завантаженні моделі. Якщо у вас вже є моделі, вводити не потрібно. Щоб завантажити модель, потрібно вручну перейти на "https://huggingface.co./pyannote/speaker-diarization-3.1" та "https://huggingface.co./pyannote/segmentation-3.0" і погодитися з їхніми вимогами.
+ Device: Пристрій
+ Youtube Link: Посилання на Youtube
+ Youtube Thumbnail: Ескіз Youtube
+ Youtube Title: Назва Youtube
+ Youtube Description: Опис Youtube
+ Record with Mic: Записати з мікрофона
+ Upload Subtitle Files to translate here: Завантажте файли субтитрів для перекладу тут
+ Your Auth Key (API KEY): Ваш ключ авторизації (API KEY)
+ Source Language: Мова джерела
+ Target Language: Мова перекладу
+ Pro User?: Професійний користувач?
+ TRANSLATE SUBTITLE FILE: ПЕРЕКЛАСТИ ФАЙЛ СУБТИТРІВ
+ Upload Audio Files to separate background music: Завантажте аудіофайли для розділення фонової музики
+ Instrumental: Інструментал
+ Vocals: Вокал
+ SEPARATE BACKGROUND MUSIC: РОЗДІЛИТИ ФОНОВУ МУЗИКУ
+
+ru: # Russian
+ Language: Язык
+ File: Файл
+ Youtube: Youtube
+ Mic: Микрофон
+ T2T Translation: Перевод T2T
+ BGM Separation: Разделение фоновой музыки
+ GENERATE SUBTITLE FILE: СГЕНЕРИРОВАТЬ ФАЙЛ СУБТИТРОВ
+ Output: Результат
+ Downloadable output file: Загружаемый файл результата
+ Upload File here: Загрузите файл здесь
+ Model: Модель
+ Automatic Detection: Автоматическое определение
+ File Format: Формат файла
+ Translate to English?: Перевести на английский?
+ Add a timestamp to the end of the filename: Добавить метку времени в конец имени файла
+ Advanced Parameters: Расширенные параметры
+ Background Music Remover Filter: Фильтр удаления фоновой музыки
+ Enabling this will remove background music: Включение этого удалит фоновую музыку с помощью подмодели перед транскрипцией
+ Enable Background Music Remover Filter: Включить фильтр удаления фоновой музыки
+ Save separated files to output: Сохранить разделенные файлы в выходную папку
+ Offload sub model after removing background music: Выгрузить подмодель после удаления фоновой музыки
+ Voice Detection Filter: Фильтр обнаружения голоса
+ Enable this to transcribe only detected voice: Включите это, чтобы транскрибировать только обнаруженные голосовые части с помощью подмодели
+ Enable Silero VAD Filter: Включить фильтр Silero VAD
+ Diarization: Диаризация
+ Enable Diarization: Включить диаризацию
+ HuggingFace Token: Токен HuggingFace
+ This is only needed the first time you download the model: Это нужно только при первом скачивании модели. Если у вас уже есть модели, вводить не нужно. Чтобы скачать модель, нужно вручную перейти на "https://huggingface.co./pyannote/speaker-diarization-3.1" и "https://huggingface.co./pyannote/segmentation-3.0" и согласиться с их требованиями.
+ Device: Устройство
+ Youtube Link: Ссылка на Youtube
+ Youtube Thumbnail: Миниатюра Youtube
+ Youtube Title: Название Youtube
+ Youtube Description: Описание Youtube
+ Record with Mic: Записать с микрофона
+ Upload Subtitle Files to translate here: Загрузите файлы субтитров для перевода здесь
+ Your Auth Key (API KEY): Ваш Auth Key (API KEY)
+ Source Language: Исходный язык
+ Target Language: Целевой язык
+ Pro User?: Профессиональный пользователь?
+ TRANSLATE SUBTITLE FILE: ПЕРЕВЕСТИ ФАЙЛ СУБТИТРОВ
+ Upload Audio Files to separate background music: Загрузите аудиофайлы для разделения фоновой музыки
+ Instrumental: Инструментал
+ Vocals: Вокал
+ SEPARATE BACKGROUND MUSIC: РАЗДЕЛИТЬ ФОНОВУЮ МУЗЫКУ
+
+tr: # Turkish
+ Language: Dil
+ File: Dosya
+ Youtube: Youtube
+ Mic: Mikrofon
+ T2T Translation: T2T Çeviri
+ BGM Separation: Arka Plan Müziği Ayırma
+ GENERATE SUBTITLE FILE: ALTYAZI DOSYASI OLUŞTUR
+ Output: Çıktı
+ Downloadable output file: İndirilebilir çıktı dosyası
+ Upload File here: Dosya Yükle
+ Model: Model
+ Automatic Detection: Otomatik Algılama
+ File Format: Dosya Formatı
+ Translate to English?: İngilizceye Çevir?
+ Add a timestamp to the end of the filename: Dosya adının sonuna zaman damgası ekle
+ Advanced Parameters: Gelişmiş Parametreler
+ Background Music Remover Filter: Arka Plan Müziği Kaldırma Filtresi
+ Enabling this will remove background music: Bunu etkinleştirmek, arka plan müziğini alt model tarafından transkripsiyondan önce kaldıracaktır
+ Enable Background Music Remover Filter: Arka Plan Müziği Kaldırma Filtresini Etkinleştir
+ Save separated files to output: Ayrılmış dosyaları çıktıya kaydet
+ Offload sub model after removing background music: Arka plan müziği kaldırıldıktan sonra alt modeli devre dışı bırak
+ Voice Detection Filter: Ses Algılama Filtresi
+ Enable this to transcribe only detected voice: Bunu etkinleştirerek yalnızca alt model tarafından algılanan ses kısımlarını transkribe et
+ Enable Silero VAD Filter: Silero VAD Filtresini Etkinleştir
+ Diarization: Konuşmacı Ayrımı
+ Enable Diarization: Konuşmacı Ayrımını Etkinleştir
+ HuggingFace Token: HuggingFace Anahtarı
+ This is only needed the first time you download the model: Bu, modeli ilk kez indirirken gereklidir. Zaten modelleriniz varsa girmenize gerek yok. Modeli indirmek için "https://huggingface.co./pyannote/speaker-diarization-3.1" ve "https://huggingface.co./pyannote/segmentation-3.0" adreslerine gidip gereksinimlerini kabul etmeniz gerekiyor
+ Device: Cihaz
+ Youtube Link: Youtube Bağlantısı
+ Youtube Thumbnail: Youtube Küçük Resmi
+ Youtube Title: Youtube Başlığı
+ Youtube Description: Youtube Açıklaması
+ Record with Mic: Mikrofonla Kaydet
+ Upload Subtitle Files to translate here: Çeviri için altyazı dosyalarını buraya yükle
+ Your Auth Key (API KEY): Yetki Anahtarınız (API ANAHTARI)
+ Source Language: Kaynak Dil
+ Target Language: Hedef Dil
+ Pro User?: Pro Kullanıcı?
+ TRANSLATE SUBTITLE FILE: ALTYAZI DOSYASINI ÇEVİR
+ Upload Audio Files to separate background music: Arka plan müziğini ayırmak için ses dosyalarını yükle
+ Instrumental: Enstrümantal
+ Vocals: Vokal
+ SEPARATE BACKGROUND MUSIC: ARKA PLAN MÜZİĞİNİ AYIR
+
+eu: # Basque
+ Language: Hizkuntza
+ File: Fitxategia
+ Youtube: Youtube
+ Mic: Mik
+ T2T Translation: T2T Itzulketak
+ BGM Separation: BGM Bereizpena
+ GENERATE SUBTITLE FILE: SORTU AZPITITULU-FITXATEGIA
+ Output: Irteera
+ Downloadable output file: Irteera fitxategia deskargagarria
+ Upload File here: Igo fitxategia hemen
+ Model: Eredua
+ Automatic Detection: Detekzio Automatikoa
+ File Format: Fitxategi Formatua
+ Translate to English?: Itzuli ingelesera?
+ Add a timestamp to the end of the filename: Gehitu denbora marka fitxategi amaieran
+ Advanced Parameters: Aukera Aurreratuak
+ Background Music Remover Filter: Atzeko Musika Ezabatzeko Filtroa
+ Enabling this will remove background music: Hau aktibatuz, transkribatu aurretik atzeko musika ezabatuko zaio azpieredu baten bidez
+ Enable Background Music Remover Filter: Aktibatu Atzeko Musika Ezabatzeko Filtroa
+ Save separated files to output: Gorde fitxategi bereiziak irteeran
+ Offload sub model after removing background music: Atzeko musika kendu ondoren azpieredua memoriatik kendu
+ Voice Detection Filter: Ahots Detekzio Filtroa
+ Enable this to transcribe only detected voice: Aktibatu hau azpieredu batekin soilik detektatutako ahots zatiak transkribatzeko.
+ Enable Silero VAD Filter: Aktibatu Silero VAD Filtroa
+ Diarization: Diarizazioa
+ Enable Diarization: Aktibatu Diarizazioa
+ HuggingFace Token: HuggingFace Tokena
+ This is only needed the first time you download the model: Soilik beharrezkoa eredua deskargatzen den lehen aldian. Jadanik eredu hauek deskargatuta, ez duzu berriz ere sartu behar. Eredua deskargatzeko, joan hurrengo esteketara "https://huggingface.co./pyannote/speaker-diarization-3.1" eta "https://huggingface.co./pyannote/segmentation-3.0" eta onartu haien erabilpen-baldintzak.
+ Device: Gailua
+ Youtube Link: Youtube Esteka
+ Youtube Thumbnail: Youtube Irudia
+ Youtube Title: Youtube Izenburua
+ Youtube Description: Youtube Deskribapena
+ Record with Mic: Mik. bidez grabatu
+ Upload Subtitle Files to translate here: Igo hemen azpititulu fitxategia itzulketarako
+ Your Auth Key (API KEY): Zure Auth Giltza (API KEY)
+ Source Language: Jatorrizko Hizkuntza
+ Target Language: Helburuko Hizkuntza
+ Pro User?: Pro Erabiltzailea?
+ TRANSLATE SUBTITLE FILE: ITZULI AZPITITULU FITXATEGIA
+ Upload Audio Files to separate background music: Igo Audio Fitxategiak atzeko musika bereizteko
+ Instrumental: Instrumentala
+ Vocals: Ahotsak
+ SEPARATE BACKGROUND MUSIC: BEREIZI ATZEKO MUSIKA
\ No newline at end of file
diff --git a/docker-compose.yaml b/docker-compose.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..42746f69ba18b10522240164575e9796e1c896d0
--- /dev/null
+++ b/docker-compose.yaml
@@ -0,0 +1,29 @@
+services:
+ app:
+ build: .
+ image: jhj0517/whisper-webui:latest
+
+ volumes:
+ # You can mount the container's volume paths to directory paths on your local machine.
+ # Models will be stored in the `./models' directory on your machine.
+ # Similarly, all output files will be stored in the `./outputs` directory.
+ - ./models:/Whisper-WebUI/models
+ - ./outputs:/Whisper-WebUI/outputs
+
+ ports:
+ - "7860:7860"
+
+ stdin_open: true
+ tty: true
+
+ entrypoint: ["python", "app.py", "--server_port", "7860", "--server_name", "0.0.0.0",]
+
+ # If you're not using nvidia GPU, Update device to match yours.
+ # See more info at : https://docs.docker.com/compose/compose-file/deploy/#driver
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: all
+ capabilities: [ gpu ]
\ No newline at end of file
diff --git a/models/Diarization/diarization_models_will_be_saved_here b/models/Diarization/diarization_models_will_be_saved_here
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/NLLB/nllb_models_will_be_saved_here b/models/NLLB/nllb_models_will_be_saved_here
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/UVR/uvr_models_will_be_saved_here b/models/UVR/uvr_models_will_be_saved_here
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/Whisper/faster-whisper/faster_whisper_models_will_be_saved_here b/models/Whisper/faster-whisper/faster_whisper_models_will_be_saved_here
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/Whisper/insanely-fast-whisper/insanely_fast_whisper_models_will_be_saved_here b/models/Whisper/insanely-fast-whisper/insanely_fast_whisper_models_will_be_saved_here
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/Whisper/whisper_models_will_be_saved_here b/models/Whisper/whisper_models_will_be_saved_here
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/models/models_will_be_saved_here b/models/models_will_be_saved_here
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/__init__.py b/modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/diarize/__init__.py b/modules/diarize/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/diarize/audio_loader.py b/modules/diarize/audio_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..368dc20029a322870cc96d1b4e2942e79fd5ac22
--- /dev/null
+++ b/modules/diarize/audio_loader.py
@@ -0,0 +1,180 @@
+# Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/audio.py
+
+import os
+import subprocess
+from functools import lru_cache
+from typing import Optional, Union
+from scipy.io.wavfile import write
+import tempfile
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+def exact_div(x, y):
+ assert x % y == 0
+ return x // y
+
+# hard-coded audio hyperparameters
+SAMPLE_RATE = 16000
+N_FFT = 400
+HOP_LENGTH = 160
+CHUNK_LENGTH = 30
+N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
+N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
+
+N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
+FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
+TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
+
+
+def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray:
+ """
+ Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary.
+
+ Parameters
+ ----------
+ file: Union[str, np.ndarray]
+ The audio file to open or a numpy array containing the audio data.
+
+ sr: int
+ The sample rate to resample the audio if necessary.
+
+ Returns
+ -------
+ A NumPy array containing the audio waveform, in float32 dtype.
+ """
+ if isinstance(file, np.ndarray):
+ if file.dtype != np.float32:
+ file = file.astype(np.float32)
+ if file.ndim > 1:
+ file = np.mean(file, axis=1)
+
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
+ write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16))
+ temp_file_path = temp_file.name
+ temp_file.close()
+ else:
+ temp_file_path = file
+
+ try:
+ cmd = [
+ "ffmpeg",
+ "-nostdin",
+ "-threads",
+ "0",
+ "-i",
+ temp_file_path,
+ "-f",
+ "s16le",
+ "-ac",
+ "1",
+ "-acodec",
+ "pcm_s16le",
+ "-ar",
+ str(sr),
+ "-",
+ ]
+ out = subprocess.run(cmd, capture_output=True, check=True).stdout
+ except subprocess.CalledProcessError as e:
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+ finally:
+ if isinstance(file, np.ndarray):
+ os.remove(temp_file_path)
+
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
+
+
+
+def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
+ """
+ Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
+ """
+ if torch.is_tensor(array):
+ if array.shape[axis] > length:
+ array = array.index_select(
+ dim=axis, index=torch.arange(length, device=array.device)
+ )
+
+ if array.shape[axis] < length:
+ pad_widths = [(0, 0)] * array.ndim
+ pad_widths[axis] = (0, length - array.shape[axis])
+ array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
+ else:
+ if array.shape[axis] > length:
+ array = array.take(indices=range(length), axis=axis)
+
+ if array.shape[axis] < length:
+ pad_widths = [(0, 0)] * array.ndim
+ pad_widths[axis] = (0, length - array.shape[axis])
+ array = np.pad(array, pad_widths)
+
+ return array
+
+
+@lru_cache(maxsize=None)
+def mel_filters(device, n_mels: int) -> torch.Tensor:
+ """
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
+ Allows decoupling librosa dependency; saved using:
+
+ np.savez_compressed(
+ "mel_filters.npz",
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
+ )
+ """
+ assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}"
+ with np.load(
+ os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
+ ) as f:
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
+
+
+def log_mel_spectrogram(
+ audio: Union[str, np.ndarray, torch.Tensor],
+ n_mels: int,
+ padding: int = 0,
+ device: Optional[Union[str, torch.device]] = None,
+):
+ """
+ Compute the log-Mel spectrogram of
+
+ Parameters
+ ----------
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
+ The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
+
+ n_mels: int
+ The number of Mel-frequency filters, only 80 is supported
+
+ padding: int
+ Number of zero samples to pad to the right
+
+ device: Optional[Union[str, torch.device]]
+ If given, the audio tensor is moved to this device before STFT
+
+ Returns
+ -------
+ torch.Tensor, shape = (80, n_frames)
+ A Tensor that contains the Mel spectrogram
+ """
+ if not torch.is_tensor(audio):
+ if isinstance(audio, str):
+ audio = load_audio(audio)
+ audio = torch.from_numpy(audio)
+
+ if device is not None:
+ audio = audio.to(device)
+ if padding > 0:
+ audio = F.pad(audio, (0, padding))
+ window = torch.hann_window(N_FFT).to(audio.device)
+ stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
+ magnitudes = stft[..., :-1].abs() ** 2
+
+ filters = mel_filters(audio.device, n_mels)
+ mel_spec = filters @ magnitudes
+
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
+ log_spec = (log_spec + 4.0) / 4.0
+ return log_spec
\ No newline at end of file
diff --git a/modules/diarize/diarize_pipeline.py b/modules/diarize/diarize_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ab9947b7e8f5314ce784e7e285733bb69f526b6
--- /dev/null
+++ b/modules/diarize/diarize_pipeline.py
@@ -0,0 +1,98 @@
+# Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/diarize.py
+
+import numpy as np
+import pandas as pd
+import os
+from pyannote.audio import Pipeline
+from typing import Optional, Union
+import torch
+
+from modules.whisper.data_classes import *
+from modules.utils.paths import DIARIZATION_MODELS_DIR
+from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
+
+
+class DiarizationPipeline:
+ def __init__(
+ self,
+ model_name="pyannote/speaker-diarization-3.1",
+ cache_dir: str = DIARIZATION_MODELS_DIR,
+ use_auth_token=None,
+ device: Optional[Union[str, torch.device]] = "cpu",
+ ):
+ if isinstance(device, str):
+ device = torch.device(device)
+ self.model = Pipeline.from_pretrained(
+ model_name,
+ use_auth_token=use_auth_token,
+ cache_dir=cache_dir
+ ).to(device)
+
+ def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None):
+ if isinstance(audio, str):
+ audio = load_audio(audio)
+ audio_data = {
+ 'waveform': torch.from_numpy(audio[None, :]),
+ 'sample_rate': SAMPLE_RATE
+ }
+ segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers)
+ diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
+ diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
+ diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
+ return diarize_df
+
+
+def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
+ transcript_segments = transcript_result["segments"]
+ if transcript_segments and isinstance(transcript_segments[0], Segment):
+ transcript_segments = [seg.model_dump() for seg in transcript_segments]
+ for seg in transcript_segments:
+ # assign speaker to segment (if any)
+ diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
+ seg['start'])
+ diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])
+
+ intersected = diarize_df[diarize_df["intersection"] > 0]
+
+ speaker = None
+ if len(intersected) > 0:
+ # Choosing most strong intersection
+ speaker = intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
+ elif fill_nearest:
+ # Otherwise choosing closest
+ speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]
+
+ if speaker is not None:
+ seg["speaker"] = speaker
+
+ # assign speaker to words
+ if 'words' in seg and seg['words'] is not None:
+ for word in seg['words']:
+ if 'start' in word:
+ diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
+ diarize_df['start'], word['start'])
+ diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'],
+ word['start'])
+
+ intersected = diarize_df[diarize_df["intersection"] > 0]
+
+ word_speaker = None
+ if len(intersected) > 0:
+ # Choosing most strong intersection
+ word_speaker = \
+ intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
+ elif fill_nearest:
+ # Otherwise choosing closest
+ word_speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]
+
+ if word_speaker is not None:
+ word["speaker"] = word_speaker
+
+ return {"segments": transcript_segments}
+
+
+class DiarizationSegment:
+ def __init__(self, start, end, speaker=None):
+ self.start = start
+ self.end = end
+ self.speaker = speaker
diff --git a/modules/diarize/diarizer.py b/modules/diarize/diarizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..45293315c6c6e8e82a3b2b145b8ebbd7b9a1b757
--- /dev/null
+++ b/modules/diarize/diarizer.py
@@ -0,0 +1,153 @@
+import os
+import torch
+from typing import List, Union, BinaryIO, Optional, Tuple
+import numpy as np
+import time
+import logging
+import gc
+
+from modules.utils.paths import DIARIZATION_MODELS_DIR
+from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
+from modules.diarize.audio_loader import load_audio
+from modules.whisper.data_classes import *
+
+
+class Diarizer:
+ def __init__(self,
+ model_dir: str = DIARIZATION_MODELS_DIR
+ ):
+ self.device = self.get_device()
+ self.available_device = self.get_available_device()
+ self.compute_type = "float16"
+ self.model_dir = model_dir
+ os.makedirs(self.model_dir, exist_ok=True)
+ self.pipe = None
+
+ def run(self,
+ audio: Union[str, BinaryIO, np.ndarray],
+ transcribed_result: List[Segment],
+ use_auth_token: str,
+ device: Optional[str] = None
+ ) -> Tuple[List[Segment], float]:
+ """
+ Diarize transcribed result as a post-processing
+
+ Parameters
+ ----------
+ audio: Union[str, BinaryIO, np.ndarray]
+ Audio input. This can be file path or binary type.
+ transcribed_result: List[Segment]
+ transcribed result through whisper.
+ use_auth_token: str
+ Huggingface token with READ permission. This is only needed the first time you download the model.
+ You must manually go to the website https://huggingface.co./pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
+ device: Optional[str]
+ Device for diarization.
+
+ Returns
+ ----------
+ segments_result: List[Segment]
+ list of Segment that includes start, end timestamps and transcribed text
+ elapsed_time: float
+ elapsed time for running
+ """
+ start_time = time.time()
+
+ if device is None:
+ device = self.device
+
+ if device != self.device or self.pipe is None:
+ self.update_pipe(
+ device=device,
+ use_auth_token=use_auth_token
+ )
+
+ audio = load_audio(audio)
+
+ diarization_segments = self.pipe(audio)
+ diarized_result = assign_word_speakers(
+ diarization_segments,
+ {"segments": transcribed_result}
+ )
+
+ segments_result = []
+ for segment in diarized_result["segments"]:
+ speaker = "None"
+ if "speaker" in segment:
+ speaker = segment["speaker"]
+ diarized_text = speaker + "|" + segment["text"].strip()
+ segments_result.append(Segment(
+ start=segment["start"],
+ end=segment["end"],
+ text=diarized_text
+ ))
+
+ elapsed_time = time.time() - start_time
+ return segments_result, elapsed_time
+
+ def update_pipe(self,
+ use_auth_token: Optional[str] = None,
+ device: Optional[str] = None,
+ ):
+ """
+ Set pipeline for diarization
+
+ Parameters
+ ----------
+ use_auth_token: str
+ Huggingface token with READ permission. This is only needed the first time you download the model.
+ You must manually go to the website https://huggingface.co./pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
+ device: str
+ Device for diarization.
+ """
+ if device is None:
+ device = self.get_device()
+ self.device = device
+
+ os.makedirs(self.model_dir, exist_ok=True)
+
+ if (not os.listdir(self.model_dir) and
+ not use_auth_token):
+ print(
+ "\nFailed to diarize. You need huggingface token and agree to their requirements to download the diarization model.\n"
+ "Go to \"https://huggingface.co./pyannote/speaker-diarization-3.1\" and follow their instructions to download the model.\n"
+ )
+ return
+
+ logger = logging.getLogger("speechbrain.utils.train_logger")
+ # Disable redundant torchvision warning message
+ logger.disabled = True
+ self.pipe = DiarizationPipeline(
+ use_auth_token=use_auth_token,
+ device=device,
+ cache_dir=self.model_dir
+ )
+ logger.disabled = False
+
+ def offload(self):
+ """Offload the model and free up the memory"""
+ if self.pipe is not None:
+ del self.pipe
+ self.pipe = None
+ if self.device == "cuda":
+ torch.cuda.empty_cache()
+ torch.cuda.reset_max_memory_allocated()
+ gc.collect()
+
+ @staticmethod
+ def get_device():
+ if torch.cuda.is_available():
+ return "cuda"
+ elif torch.backends.mps.is_available():
+ return "mps"
+ else:
+ return "cpu"
+
+ @staticmethod
+ def get_available_device():
+ devices = ["cpu"]
+ if torch.cuda.is_available():
+ devices.append("cuda")
+ elif torch.backends.mps.is_available():
+ devices.append("mps")
+ return devices
\ No newline at end of file
diff --git a/modules/translation/__init__.py b/modules/translation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/translation/deepl_api.py b/modules/translation/deepl_api.py
new file mode 100644
index 0000000000000000000000000000000000000000..08adc3f96c20b3abc358a9087259d4ecdc4b63a9
--- /dev/null
+++ b/modules/translation/deepl_api.py
@@ -0,0 +1,217 @@
+import requests
+import time
+import os
+from datetime import datetime
+import gradio as gr
+
+from modules.utils.paths import TRANSLATION_OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH
+from modules.utils.constants import AUTOMATIC_DETECTION
+from modules.utils.subtitle_manager import *
+from modules.utils.files_manager import load_yaml, save_yaml
+
+"""
+This is written with reference to the DeepL API documentation.
+If you want to know the information of the DeepL API, see here: https://www.deepl.com/docs-api/documents
+"""
+
+DEEPL_AVAILABLE_TARGET_LANGS = {
+ 'Bulgarian': 'BG',
+ 'Czech': 'CS',
+ 'Danish': 'DA',
+ 'German': 'DE',
+ 'Greek': 'EL',
+ 'English': 'EN',
+ 'English (British)': 'EN-GB',
+ 'English (American)': 'EN-US',
+ 'Spanish': 'ES',
+ 'Estonian': 'ET',
+ 'Finnish': 'FI',
+ 'French': 'FR',
+ 'Hungarian': 'HU',
+ 'Indonesian': 'ID',
+ 'Italian': 'IT',
+ 'Japanese': 'JA',
+ 'Korean': 'KO',
+ 'Lithuanian': 'LT',
+ 'Latvian': 'LV',
+ 'Norwegian (Bokmål)': 'NB',
+ 'Dutch': 'NL',
+ 'Polish': 'PL',
+ 'Portuguese': 'PT',
+ 'Portuguese (Brazilian)': 'PT-BR',
+ 'Portuguese (all Portuguese varieties excluding Brazilian Portuguese)': 'PT-PT',
+ 'Romanian': 'RO',
+ 'Russian': 'RU',
+ 'Slovak': 'SK',
+ 'Slovenian': 'SL',
+ 'Swedish': 'SV',
+ 'Turkish': 'TR',
+ 'Ukrainian': 'UK',
+ 'Chinese (simplified)': 'ZH'
+}
+
+DEEPL_AVAILABLE_SOURCE_LANGS = {
+ AUTOMATIC_DETECTION: None,
+ 'Bulgarian': 'BG',
+ 'Czech': 'CS',
+ 'Danish': 'DA',
+ 'German': 'DE',
+ 'Greek': 'EL',
+ 'English': 'EN',
+ 'Spanish': 'ES',
+ 'Estonian': 'ET',
+ 'Finnish': 'FI',
+ 'French': 'FR',
+ 'Hungarian': 'HU',
+ 'Indonesian': 'ID',
+ 'Italian': 'IT',
+ 'Japanese': 'JA',
+ 'Korean': 'KO',
+ 'Lithuanian': 'LT',
+ 'Latvian': 'LV',
+ 'Norwegian (Bokmål)': 'NB',
+ 'Dutch': 'NL',
+ 'Polish': 'PL',
+ 'Portuguese (all Portuguese varieties mixed)': 'PT',
+ 'Romanian': 'RO',
+ 'Russian': 'RU',
+ 'Slovak': 'SK',
+ 'Slovenian': 'SL',
+ 'Swedish': 'SV',
+ 'Turkish': 'TR',
+ 'Ukrainian': 'UK',
+ 'Chinese': 'ZH'
+}
+
+
+class DeepLAPI:
+ def __init__(self,
+ output_dir: str = TRANSLATION_OUTPUT_DIR
+ ):
+ self.api_interval = 1
+ self.max_text_batch_size = 50
+ self.available_target_langs = DEEPL_AVAILABLE_TARGET_LANGS
+ self.available_source_langs = DEEPL_AVAILABLE_SOURCE_LANGS
+ self.output_dir = output_dir
+
+ def translate_deepl(self,
+ auth_key: str,
+ fileobjs: list,
+ source_lang: str,
+ target_lang: str,
+ is_pro: bool = False,
+ add_timestamp: bool = True,
+ progress=gr.Progress()) -> list:
+ """
+ Translate subtitle files using DeepL API
+ Parameters
+ ----------
+ auth_key: str
+ API Key for DeepL from gr.Textbox()
+ fileobjs: list
+ List of files to transcribe from gr.Files()
+ source_lang: str
+ Source language of the file to transcribe from gr.Dropdown()
+ target_lang: str
+ Target language of the file to transcribe from gr.Dropdown()
+ is_pro: str
+ Boolean value that is about pro user or not from gr.Checkbox().
+ add_timestamp: bool
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
+ progress: gr.Progress
+ Indicator to show progress directly in gradio.
+
+ Returns
+ ----------
+ A List of
+ String to return to gr.Textbox()
+ Files to return to gr.Files()
+ """
+ if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString):
+ fileobjs = [fileobj.name for fileobj in fileobjs]
+
+ self.cache_parameters(
+ api_key=auth_key,
+ is_pro=is_pro,
+ source_lang=source_lang,
+ target_lang=target_lang,
+ add_timestamp=add_timestamp
+ )
+
+ files_info = {}
+ for file_path in fileobjs:
+ file_name, file_ext = os.path.splitext(os.path.basename(file_path))
+ writer = get_writer(file_ext, self.output_dir)
+ segments = writer.to_segments(file_path)
+
+ batch_size = self.max_text_batch_size
+ for batch_start in range(0, len(segments), batch_size):
+ progress(batch_start / len(segments), desc="Translating..")
+ sentences_to_translate = [seg.text for seg in segments[batch_start:batch_start+batch_size]]
+ translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
+ target_lang, is_pro)
+ for i, translated_text in enumerate(translated_texts):
+ segments[batch_start + i].text = translated_text["text"]
+
+ subtitle, output_path = generate_file(
+ output_dir=self.output_dir,
+ output_file_name=file_name,
+ output_format=file_ext,
+ result=segments,
+ add_timestamp=add_timestamp
+ )
+
+ files_info[file_name] = {"subtitle": subtitle, "path": output_path}
+
+ total_result = ''
+ for file_name, info in files_info.items():
+ total_result += '------------------------------------\n'
+ total_result += f'{file_name}\n\n'
+ total_result += f'{info["subtitle"]}'
+ gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
+
+ output_file_paths = [item["path"] for key, item in files_info.items()]
+ return [gr_str, output_file_paths]
+
+ def request_deepl_translate(self,
+ auth_key: str,
+ text: list,
+ source_lang: str,
+ target_lang: str,
+ is_pro: bool = False):
+ """Request API response to DeepL server"""
+ if source_lang not in list(DEEPL_AVAILABLE_SOURCE_LANGS.keys()):
+ raise ValueError(f"Source language {source_lang} is not supported."
+ f"Use one of {list(DEEPL_AVAILABLE_SOURCE_LANGS.keys())}")
+ if target_lang not in list(DEEPL_AVAILABLE_TARGET_LANGS.keys()):
+ raise ValueError(f"Target language {target_lang} is not supported."
+ f"Use one of {list(DEEPL_AVAILABLE_TARGET_LANGS.keys())}")
+
+ url = 'https://api.deepl.com/v2/translate' if is_pro else 'https://api-free.deepl.com/v2/translate'
+ headers = {
+ 'Authorization': f'DeepL-Auth-Key {auth_key}'
+ }
+ data = {
+ 'text': text,
+ 'source_lang': DEEPL_AVAILABLE_SOURCE_LANGS[source_lang],
+ 'target_lang': DEEPL_AVAILABLE_TARGET_LANGS[target_lang]
+ }
+ response = requests.post(url, headers=headers, data=data).json()
+ time.sleep(self.api_interval)
+ return response["translations"]
+
+ @staticmethod
+ def cache_parameters(api_key: str,
+ is_pro: bool,
+ source_lang: str,
+ target_lang: str,
+ add_timestamp: bool):
+ cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
+ cached_params["translation"]["deepl"] = {
+ "api_key": api_key,
+ "is_pro": is_pro,
+ "source_lang": source_lang,
+ "target_lang": target_lang
+ }
+ cached_params["translation"]["add_timestamp"] = add_timestamp
+ save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH)
diff --git a/modules/translation/nllb_inference.py b/modules/translation/nllb_inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..80e2528ff6b5aaec8cbef8487954ae53f58afa08
--- /dev/null
+++ b/modules/translation/nllb_inference.py
@@ -0,0 +1,290 @@
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+import gradio as gr
+import os
+
+from modules.utils.paths import TRANSLATION_OUTPUT_DIR, NLLB_MODELS_DIR
+from modules.translation.translation_base import TranslationBase
+
+
+class NLLBInference(TranslationBase):
+ def __init__(self,
+ model_dir: str = NLLB_MODELS_DIR,
+ output_dir: str = TRANSLATION_OUTPUT_DIR
+ ):
+ super().__init__(
+ model_dir=model_dir,
+ output_dir=output_dir
+ )
+ self.tokenizer = None
+ self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
+ self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys())
+ self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
+ self.pipeline = None
+
+ def translate(self,
+ text: str,
+ max_length: int
+ ):
+ result = self.pipeline(
+ text,
+ max_length=max_length
+ )
+ return result[0]["translation_text"]
+
+ def update_model(self,
+ model_size: str,
+ src_lang: str,
+ tgt_lang: str,
+ progress: gr.Progress = gr.Progress()
+ ):
+ def validate_language(lang: str) -> str:
+ if lang in NLLB_AVAILABLE_LANGS:
+ return NLLB_AVAILABLE_LANGS[lang]
+ elif lang not in NLLB_AVAILABLE_LANGS.values():
+ raise ValueError(f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
+ return lang
+
+ src_lang = validate_language(src_lang)
+ tgt_lang = validate_language(tgt_lang)
+
+ if model_size != self.current_model_size or self.model is None:
+ print("\nInitializing NLLB Model..\n")
+ progress(0, desc="Initializing NLLB Model..")
+ self.current_model_size = model_size
+ local_files_only = self.is_model_exists(self.current_model_size)
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
+ cache_dir=self.model_dir,
+ local_files_only=local_files_only)
+ self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
+ cache_dir=os.path.join(self.model_dir, "tokenizers"),
+ local_files_only=local_files_only)
+
+ self.pipeline = pipeline("translation",
+ model=self.model,
+ tokenizer=self.tokenizer,
+ src_lang=src_lang,
+ tgt_lang=tgt_lang,
+ device=self.device)
+
+ def is_model_exists(self,
+ model_size: str):
+ """Check if model exists or not (Only facebook model)"""
+ prefix = "models--facebook--"
+ _id, model_size_name = model_size.split("/")
+ model_dir_name = prefix + model_size_name
+ model_dir_path = os.path.join(self.model_dir, model_dir_name)
+ if os.path.exists(model_dir_path) and os.listdir(model_dir_path):
+ return True
+ for model_dir_name in os.listdir(self.model_dir):
+ if (model_size in model_dir_name or model_size_name in model_dir_name) and \
+ os.listdir(os.path.join(self.model_dir, model_dir_name)):
+ return True
+ return False
+
+
+NLLB_AVAILABLE_LANGS = {
+ "Acehnese (Arabic script)": "ace_Arab",
+ "Acehnese (Latin script)": "ace_Latn",
+ "Mesopotamian Arabic": "acm_Arab",
+ "Ta’izzi-Adeni Arabic": "acq_Arab",
+ "Tunisian Arabic": "aeb_Arab",
+ "Afrikaans": "afr_Latn",
+ "South Levantine Arabic": "ajp_Arab",
+ "Akan": "aka_Latn",
+ "Amharic": "amh_Ethi",
+ "North Levantine Arabic": "apc_Arab",
+ "Modern Standard Arabic": "arb_Arab",
+ "Modern Standard Arabic (Romanized)": "arb_Latn",
+ "Najdi Arabic": "ars_Arab",
+ "Moroccan Arabic": "ary_Arab",
+ "Egyptian Arabic": "arz_Arab",
+ "Assamese": "asm_Beng",
+ "Asturian": "ast_Latn",
+ "Awadhi": "awa_Deva",
+ "Central Aymara": "ayr_Latn",
+ "South Azerbaijani": "azb_Arab",
+ "North Azerbaijani": "azj_Latn",
+ "Bashkir": "bak_Cyrl",
+ "Bambara": "bam_Latn",
+ "Balinese": "ban_Latn",
+ "Belarusian": "bel_Cyrl",
+ "Bemba": "bem_Latn",
+ "Bengali": "ben_Beng",
+ "Bhojpuri": "bho_Deva",
+ "Banjar (Arabic script)": "bjn_Arab",
+ "Banjar (Latin script)": "bjn_Latn",
+ "Standard Tibetan": "bod_Tibt",
+ "Bosnian": "bos_Latn",
+ "Buginese": "bug_Latn",
+ "Bulgarian": "bul_Cyrl",
+ "Catalan": "cat_Latn",
+ "Cebuano": "ceb_Latn",
+ "Czech": "ces_Latn",
+ "Chokwe": "cjk_Latn",
+ "Central Kurdish": "ckb_Arab",
+ "Crimean Tatar": "crh_Latn",
+ "Welsh": "cym_Latn",
+ "Danish": "dan_Latn",
+ "German": "deu_Latn",
+ "Southwestern Dinka": "dik_Latn",
+ "Dyula": "dyu_Latn",
+ "Dzongkha": "dzo_Tibt",
+ "Greek": "ell_Grek",
+ "English": "eng_Latn",
+ "Esperanto": "epo_Latn",
+ "Estonian": "est_Latn",
+ "Basque": "eus_Latn",
+ "Ewe": "ewe_Latn",
+ "Faroese": "fao_Latn",
+ "Fijian": "fij_Latn",
+ "Finnish": "fin_Latn",
+ "Fon": "fon_Latn",
+ "French": "fra_Latn",
+ "Friulian": "fur_Latn",
+ "Nigerian Fulfulde": "fuv_Latn",
+ "Scottish Gaelic": "gla_Latn",
+ "Irish": "gle_Latn",
+ "Galician": "glg_Latn",
+ "Guarani": "grn_Latn",
+ "Gujarati": "guj_Gujr",
+ "Haitian Creole": "hat_Latn",
+ "Hausa": "hau_Latn",
+ "Hebrew": "heb_Hebr",
+ "Hindi": "hin_Deva",
+ "Chhattisgarhi": "hne_Deva",
+ "Croatian": "hrv_Latn",
+ "Hungarian": "hun_Latn",
+ "Armenian": "hye_Armn",
+ "Igbo": "ibo_Latn",
+ "Ilocano": "ilo_Latn",
+ "Indonesian": "ind_Latn",
+ "Icelandic": "isl_Latn",
+ "Italian": "ita_Latn",
+ "Javanese": "jav_Latn",
+ "Japanese": "jpn_Jpan",
+ "Kabyle": "kab_Latn",
+ "Jingpho": "kac_Latn",
+ "Kamba": "kam_Latn",
+ "Kannada": "kan_Knda",
+ "Kashmiri (Arabic script)": "kas_Arab",
+ "Kashmiri (Devanagari script)": "kas_Deva",
+ "Georgian": "kat_Geor",
+ "Central Kanuri (Arabic script)": "knc_Arab",
+ "Central Kanuri (Latin script)": "knc_Latn",
+ "Kazakh": "kaz_Cyrl",
+ "Kabiyè": "kbp_Latn",
+ "Kabuverdianu": "kea_Latn",
+ "Khmer": "khm_Khmr",
+ "Kikuyu": "kik_Latn",
+ "Kinyarwanda": "kin_Latn",
+ "Kyrgyz": "kir_Cyrl",
+ "Kimbundu": "kmb_Latn",
+ "Northern Kurdish": "kmr_Latn",
+ "Kikongo": "kon_Latn",
+ "Korean": "kor_Hang",
+ "Lao": "lao_Laoo",
+ "Ligurian": "lij_Latn",
+ "Limburgish": "lim_Latn",
+ "Lingala": "lin_Latn",
+ "Lithuanian": "lit_Latn",
+ "Lombard": "lmo_Latn",
+ "Latgalian": "ltg_Latn",
+ "Luxembourgish": "ltz_Latn",
+ "Luba-Kasai": "lua_Latn",
+ "Ganda": "lug_Latn",
+ "Luo": "luo_Latn",
+ "Mizo": "lus_Latn",
+ "Standard Latvian": "lvs_Latn",
+ "Magahi": "mag_Deva",
+ "Maithili": "mai_Deva",
+ "Malayalam": "mal_Mlym",
+ "Marathi": "mar_Deva",
+ "Minangkabau (Arabic script)": "min_Arab",
+ "Minangkabau (Latin script)": "min_Latn",
+ "Macedonian": "mkd_Cyrl",
+ "Plateau Malagasy": "plt_Latn",
+ "Maltese": "mlt_Latn",
+ "Meitei (Bengali script)": "mni_Beng",
+ "Halh Mongolian": "khk_Cyrl",
+ "Mossi": "mos_Latn",
+ "Maori": "mri_Latn",
+ "Burmese": "mya_Mymr",
+ "Dutch": "nld_Latn",
+ "Norwegian Nynorsk": "nno_Latn",
+ "Norwegian Bokmål": "nob_Latn",
+ "Nepali": "npi_Deva",
+ "Northern Sotho": "nso_Latn",
+ "Nuer": "nus_Latn",
+ "Nyanja": "nya_Latn",
+ "Occitan": "oci_Latn",
+ "West Central Oromo": "gaz_Latn",
+ "Odia": "ory_Orya",
+ "Pangasinan": "pag_Latn",
+ "Eastern Panjabi": "pan_Guru",
+ "Papiamento": "pap_Latn",
+ "Western Persian": "pes_Arab",
+ "Polish": "pol_Latn",
+ "Portuguese": "por_Latn",
+ "Dari": "prs_Arab",
+ "Southern Pashto": "pbt_Arab",
+ "Ayacucho Quechua": "quy_Latn",
+ "Romanian": "ron_Latn",
+ "Rundi": "run_Latn",
+ "Russian": "rus_Cyrl",
+ "Sango": "sag_Latn",
+ "Sanskrit": "san_Deva",
+ "Santali": "sat_Olck",
+ "Sicilian": "scn_Latn",
+ "Shan": "shn_Mymr",
+ "Sinhala": "sin_Sinh",
+ "Slovak": "slk_Latn",
+ "Slovenian": "slv_Latn",
+ "Samoan": "smo_Latn",
+ "Shona": "sna_Latn",
+ "Sindhi": "snd_Arab",
+ "Somali": "som_Latn",
+ "Southern Sotho": "sot_Latn",
+ "Spanish": "spa_Latn",
+ "Tosk Albanian": "als_Latn",
+ "Sardinian": "srd_Latn",
+ "Serbian": "srp_Cyrl",
+ "Swati": "ssw_Latn",
+ "Sundanese": "sun_Latn",
+ "Swedish": "swe_Latn",
+ "Swahili": "swh_Latn",
+ "Silesian": "szl_Latn",
+ "Tamil": "tam_Taml",
+ "Tatar": "tat_Cyrl",
+ "Telugu": "tel_Telu",
+ "Tajik": "tgk_Cyrl",
+ "Tagalog": "tgl_Latn",
+ "Thai": "tha_Thai",
+ "Tigrinya": "tir_Ethi",
+ "Tamasheq (Latin script)": "taq_Latn",
+ "Tamasheq (Tifinagh script)": "taq_Tfng",
+ "Tok Pisin": "tpi_Latn",
+ "Tswana": "tsn_Latn",
+ "Tsonga": "tso_Latn",
+ "Turkmen": "tuk_Latn",
+ "Tumbuka": "tum_Latn",
+ "Turkish": "tur_Latn",
+ "Twi": "twi_Latn",
+ "Central Atlas Tamazight": "tzm_Tfng",
+ "Uyghur": "uig_Arab",
+ "Ukrainian": "ukr_Cyrl",
+ "Umbundu": "umb_Latn",
+ "Urdu": "urd_Arab",
+ "Northern Uzbek": "uzn_Latn",
+ "Venetian": "vec_Latn",
+ "Vietnamese": "vie_Latn",
+ "Waray": "war_Latn",
+ "Wolof": "wol_Latn",
+ "Xhosa": "xho_Latn",
+ "Eastern Yiddish": "ydd_Hebr",
+ "Yoruba": "yor_Latn",
+ "Yue Chinese": "yue_Hant",
+ "Chinese (Simplified)": "zho_Hans",
+ "Chinese (Traditional)": "zho_Hant",
+ "Standard Malay": "zsm_Latn",
+ "Zulu": "zul_Latn",
+}
diff --git a/modules/translation/translation_base.py b/modules/translation/translation_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b64afab64733330e7a84802a705e6e5fb490f99
--- /dev/null
+++ b/modules/translation/translation_base.py
@@ -0,0 +1,185 @@
+import os
+import torch
+import gradio as gr
+from abc import ABC, abstractmethod
+import gc
+from typing import List
+from datetime import datetime
+
+import modules.translation.nllb_inference as nllb
+from modules.whisper.data_classes import *
+from modules.utils.subtitle_manager import *
+from modules.utils.files_manager import load_yaml, save_yaml
+from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
+
+
+class TranslationBase(ABC):
+ def __init__(self,
+ model_dir: str = NLLB_MODELS_DIR,
+ output_dir: str = TRANSLATION_OUTPUT_DIR
+ ):
+ super().__init__()
+ self.model = None
+ self.model_dir = model_dir
+ self.output_dir = output_dir
+ os.makedirs(self.model_dir, exist_ok=True)
+ os.makedirs(self.output_dir, exist_ok=True)
+ self.current_model_size = None
+ self.device = self.get_device()
+
+ @abstractmethod
+ def translate(self,
+ text: str,
+ max_length: int
+ ):
+ pass
+
+ @abstractmethod
+ def update_model(self,
+ model_size: str,
+ src_lang: str,
+ tgt_lang: str,
+ progress: gr.Progress = gr.Progress()
+ ):
+ pass
+
+ def translate_file(self,
+ fileobjs: list,
+ model_size: str,
+ src_lang: str,
+ tgt_lang: str,
+ max_length: int = 200,
+ add_timestamp: bool = True,
+ progress=gr.Progress()) -> list:
+ """
+ Translate subtitle file from source language to target language
+
+ Parameters
+ ----------
+ fileobjs: list
+ List of files to transcribe from gr.Files()
+ model_size: str
+ Whisper model size from gr.Dropdown()
+ src_lang: str
+ Source language of the file to translate from gr.Dropdown()
+ tgt_lang: str
+ Target language of the file to translate from gr.Dropdown()
+ max_length: int
+ Max length per line to translate
+ add_timestamp: bool
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
+ progress: gr.Progress
+ Indicator to show progress directly in gradio.
+ I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
+
+ Returns
+ ----------
+ A List of
+ String to return to gr.Textbox()
+ Files to return to gr.Files()
+ """
+ try:
+ if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString):
+ fileobjs = [file.name for file in fileobjs]
+
+ self.cache_parameters(model_size=model_size,
+ src_lang=src_lang,
+ tgt_lang=tgt_lang,
+ max_length=max_length,
+ add_timestamp=add_timestamp)
+
+ self.update_model(model_size=model_size,
+ src_lang=src_lang,
+ tgt_lang=tgt_lang,
+ progress=progress)
+
+ files_info = {}
+ for fileobj in fileobjs:
+ file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
+ writer = get_writer(file_ext, self.output_dir)
+ segments = writer.to_segments(fileobj)
+ for i, segment in enumerate(segments):
+ progress(i / len(segments), desc="Translating..")
+ translated_text = self.translate(segment.text, max_length=max_length)
+ segment.text = translated_text
+
+ subtitle, file_path = generate_file(
+ output_dir=self.output_dir,
+ output_file_name=file_name,
+ output_format=file_ext,
+ result=segments,
+ add_timestamp=add_timestamp
+ )
+
+ files_info[file_name] = {"subtitle": subtitle, "path": file_path}
+
+ total_result = ''
+ for file_name, info in files_info.items():
+ total_result += '------------------------------------\n'
+ total_result += f'{file_name}\n\n'
+ total_result += f'{info["subtitle"]}'
+ gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
+
+ output_file_paths = [item["path"] for key, item in files_info.items()]
+ return [gr_str, output_file_paths]
+
+ except Exception as e:
+ print(f"Error translating file: {e}")
+ raise
+ finally:
+ self.release_cuda_memory()
+
+ def offload(self):
+ """Offload the model and free up the memory"""
+ if self.model is not None:
+ del self.model
+ self.model = None
+ if self.device == "cuda":
+ self.release_cuda_memory()
+ gc.collect()
+
+ @staticmethod
+ def get_device():
+ if torch.cuda.is_available():
+ return "cuda"
+ elif torch.backends.mps.is_available():
+ return "mps"
+ else:
+ return "cpu"
+
+ @staticmethod
+ def release_cuda_memory():
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.reset_max_memory_allocated()
+
+ @staticmethod
+ def remove_input_files(file_paths: List[str]):
+ if not file_paths:
+ return
+
+ for file_path in file_paths:
+ if file_path and os.path.exists(file_path):
+ os.remove(file_path)
+
+ @staticmethod
+ def cache_parameters(model_size: str,
+ src_lang: str,
+ tgt_lang: str,
+ max_length: int,
+ add_timestamp: bool):
+ def validate_lang(lang: str):
+ if lang in list(nllb.NLLB_AVAILABLE_LANGS.values()):
+ flipped = {value: key for key, value in nllb.NLLB_AVAILABLE_LANGS.items()}
+ return flipped[lang]
+ return lang
+
+ cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
+ cached_params["translation"]["nllb"] = {
+ "model_size": model_size,
+ "source_lang": validate_lang(src_lang),
+ "target_lang": validate_lang(tgt_lang),
+ "max_length": max_length,
+ }
+ cached_params["translation"]["add_timestamp"] = add_timestamp
+ save_yaml(cached_params, DEFAULT_PARAMETERS_CONFIG_PATH)
diff --git a/modules/ui/__init__.py b/modules/ui/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/ui/htmls.py b/modules/ui/htmls.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc8bce9241a08e9b19028f98067bd05d07fbf8f4
--- /dev/null
+++ b/modules/ui/htmls.py
@@ -0,0 +1,97 @@
+CSS = """
+.bmc-button {
+ padding: 2px 5px;
+ border-radius: 5px;
+ background-color: #FF813F;
+ color: white;
+ box-shadow: 0px 1px 2px rgba(0, 0, 0, 0.3);
+ text-decoration: none;
+ display: inline-block;
+ font-size: 20px;
+ margin: 2px;
+ cursor: pointer;
+ -webkit-transition: background-color 0.3s ease;
+ -ms-transition: background-color 0.3s ease;
+ transition: background-color 0.3s ease;
+}
+.bmc-button:hover,
+.bmc-button:active,
+.bmc-button:focus {
+ background-color: #FF5633;
+}
+.markdown {
+ margin-bottom: 0;
+ padding-bottom: 0;
+}
+.tabs {
+ margin-top: 0;
+ padding-top: 0;
+}
+
+#md_project a {
+ color: black;
+ text-decoration: none;
+}
+#md_project a:hover {
+ text-decoration: underline;
+}
+"""
+
+MARKDOWN = """
+### [Whisper-WebUI](https://github.com/jhj0517/Whsiper-WebUI)
+"""
+
+
+NLLB_VRAM_TABLE = """
+
+
+
Model name | +Required VRAM | +
---|---|
nllb-200-3.3B | +~16GB | +
nllb-200-1.3B | +~8GB | +
nllb-200-distilled-600M | +~4GB | +
Note: Be mindful of your VRAM! The table above provides an approximate VRAM usage for each model.
+