Use faster whisper directly instead of whisperX #243

Merged · 5 commits · Oct 25, 2024
README.md (5 changes: 2 additions & 3 deletions)
@@ -29,15 +29,14 @@

#
Speaker Diarization pipeline based on OpenAI Whisper
I'd like to thank [@m-bain](https://github.com/m-bain) for batched Whisper inference and [@mu4farooqi](https://github.com/mu4farooqi) for the punctuation realignment algorithm.

<img src="https://github.blog/wp-content/uploads/2020/09/github-stars-logo_Color.png" alt="drawing" width="25"/> **Please star the project on GitHub (see the top-right corner) if you appreciate my contribution to the community!**

## What is it
This repository combines Whisper ASR capabilities with Voice Activity Detection (VAD) and Speaker Embedding to identify the speaker for each sentence in the transcription generated by Whisper. First, the vocals are extracted from the audio to increase the speaker embedding accuracy, then the transcription is generated using Whisper, then the timestamps are corrected and aligned using WhisperX to help minimize diarization error due to time shift. The audio is then passed into MarbleNet for VAD and segmentation to exclude silences, TitaNet is then used to extract speaker embeddings to identify the speaker for each segment, the result is then associated with the timestamps generated by WhisperX to detect the speaker for each word based on timestamps and then realigned using punctuation models to compensate for minor time shifts.
This repository combines Whisper ASR capabilities with Voice Activity Detection (VAD) and Speaker Embedding to identify the speaker for each sentence in the transcription generated by Whisper. First, the vocals are extracted from the audio to increase speaker embedding accuracy; the transcription is then generated using Whisper, and the timestamps are corrected and aligned using `ctc-forced-aligner` to minimize diarization error due to time shift. The audio is then passed into MarbleNet for VAD and segmentation to exclude silences, and TitaNet is used to extract speaker embeddings that identify the speaker of each segment. The result is associated with the timestamps generated by `ctc-forced-aligner` to detect the speaker for each word based on timestamps, then realigned using punctuation models to compensate for minor time shifts.
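
To make the flow above concrete, here is a minimal sketch of the pipeline stages. Every helper name below is a hypothetical placeholder for one of the steps described, not a function exported by this repository:

```python
# Hypothetical end-to-end sketch of the diarization pipeline described above.
# All helpers are placeholder names, one per pipeline stage.

def diarize(audio_path: str) -> list[dict]:
    vocals = separate_vocals(audio_path)               # demucs: isolate vocals
    segments, language = transcribe(vocals)            # faster-whisper ASR
    words = align_words(segments, vocals, language)    # ctc-forced-aligner timestamps
    speech_regions = detect_speech(vocals)             # MarbleNet VAD / segmentation
    speakers = embed_speakers(vocals, speech_regions)  # TitaNet speaker embeddings
    word_speakers = assign_speakers(words, speakers)   # word -> speaker by timestamp
    return realign_with_punctuation(word_speakers)     # compensate minor time shifts
```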


WhisperX and NeMo parameters are coded into diarize.py and helpers.py, I will add the CLI arguments to change them later
Whisper and NeMo parameters are hard-coded in diarize.py and helpers.py; I will add CLI arguments to change them later.
## Installation
Python >= `3.10` is needed; `3.9` will work, but you'll need to install the requirements manually one by one.

Whisper_Transcription_+_NeMo_Diarization.ipynb (182 changes: 134 additions & 48 deletions)
@@ -29,9 +29,9 @@
},
"outputs": [],
"source": [
"!pip install git+https://github.com/m-bain/whisperX.git@78dcfaab51005aa703ee21375f81ed31bc248560\n",
"!pip install git+https://github.com/SYSTRAN/faster-whisper.git ctranslate2==4.4.0\n",
"!pip install \"nemo-toolkit[asr]>=2.dev\"\n",
"!pip install --no-deps git+https://github.com/facebookresearch/demucs#egg=demucs\n",
"!pip install git+https://github.com/MahmoudAshraf97/demucs.git\n",
"!pip install git+https://github.com/oliverguhr/deepmultilingualpunctuation.git\n",
"!pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git"
]
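
As a quick sanity check that the faster-whisper install works, a minimal transcription call might look like this (a sketch; the model size and audio path are placeholders):

```python
from faster_whisper import WhisperModel

# "medium" and "audio.wav" are placeholders; compute_type="float16" needs a GPU.
model = WhisperModel("medium", device="cuda", compute_type="float16")

segments, info = model.transcribe("audio.wav")
print(info.language, info.language_probability)

for segment in segments:  # segments is a generator; iterating runs the inference
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")
```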
@@ -56,7 +56,7 @@
"import re\n",
"import logging\n",
"import nltk\n",
"from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE\n",
"import faster_whisper\n",
"from ctc_forced_aligner import (\n",
" load_alignment_model,\n",
" generate_emissions,\n",
@@ -99,6 +99,127 @@
" \"sk\",\n",
" \"sl\",\n",
"]\n",
"\n",
"LANGUAGES = {\n",
" \"en\": \"english\",\n",
" \"zh\": \"chinese\",\n",
" \"de\": \"german\",\n",
" \"es\": \"spanish\",\n",
" \"ru\": \"russian\",\n",
" \"ko\": \"korean\",\n",
" \"fr\": \"french\",\n",
" \"ja\": \"japanese\",\n",
" \"pt\": \"portuguese\",\n",
" \"tr\": \"turkish\",\n",
" \"pl\": \"polish\",\n",
" \"ca\": \"catalan\",\n",
" \"nl\": \"dutch\",\n",
" \"ar\": \"arabic\",\n",
" \"sv\": \"swedish\",\n",
" \"it\": \"italian\",\n",
" \"id\": \"indonesian\",\n",
" \"hi\": \"hindi\",\n",
" \"fi\": \"finnish\",\n",
" \"vi\": \"vietnamese\",\n",
" \"he\": \"hebrew\",\n",
" \"uk\": \"ukrainian\",\n",
" \"el\": \"greek\",\n",
" \"ms\": \"malay\",\n",
" \"cs\": \"czech\",\n",
" \"ro\": \"romanian\",\n",
" \"da\": \"danish\",\n",
" \"hu\": \"hungarian\",\n",
" \"ta\": \"tamil\",\n",
" \"no\": \"norwegian\",\n",
" \"th\": \"thai\",\n",
" \"ur\": \"urdu\",\n",
" \"hr\": \"croatian\",\n",
" \"bg\": \"bulgarian\",\n",
" \"lt\": \"lithuanian\",\n",
" \"la\": \"latin\",\n",
" \"mi\": \"maori\",\n",
" \"ml\": \"malayalam\",\n",
" \"cy\": \"welsh\",\n",
" \"sk\": \"slovak\",\n",
" \"te\": \"telugu\",\n",
" \"fa\": \"persian\",\n",
" \"lv\": \"latvian\",\n",
" \"bn\": \"bengali\",\n",
" \"sr\": \"serbian\",\n",
" \"az\": \"azerbaijani\",\n",
" \"sl\": \"slovenian\",\n",
" \"kn\": \"kannada\",\n",
" \"et\": \"estonian\",\n",
" \"mk\": \"macedonian\",\n",
" \"br\": \"breton\",\n",
" \"eu\": \"basque\",\n",
" \"is\": \"icelandic\",\n",
" \"hy\": \"armenian\",\n",
" \"ne\": \"nepali\",\n",
" \"mn\": \"mongolian\",\n",
" \"bs\": \"bosnian\",\n",
" \"kk\": \"kazakh\",\n",
" \"sq\": \"albanian\",\n",
" \"sw\": \"swahili\",\n",
" \"gl\": \"galician\",\n",
" \"mr\": \"marathi\",\n",
" \"pa\": \"punjabi\",\n",
" \"si\": \"sinhala\",\n",
" \"km\": \"khmer\",\n",
" \"sn\": \"shona\",\n",
" \"yo\": \"yoruba\",\n",
" \"so\": \"somali\",\n",
" \"af\": \"afrikaans\",\n",
" \"oc\": \"occitan\",\n",
" \"ka\": \"georgian\",\n",
" \"be\": \"belarusian\",\n",
" \"tg\": \"tajik\",\n",
" \"sd\": \"sindhi\",\n",
" \"gu\": \"gujarati\",\n",
" \"am\": \"amharic\",\n",
" \"yi\": \"yiddish\",\n",
" \"lo\": \"lao\",\n",
" \"uz\": \"uzbek\",\n",
" \"fo\": \"faroese\",\n",
" \"ht\": \"haitian creole\",\n",
" \"ps\": \"pashto\",\n",
" \"tk\": \"turkmen\",\n",
" \"nn\": \"nynorsk\",\n",
" \"mt\": \"maltese\",\n",
" \"sa\": \"sanskrit\",\n",
" \"lb\": \"luxembourgish\",\n",
" \"my\": \"myanmar\",\n",
" \"bo\": \"tibetan\",\n",
" \"tl\": \"tagalog\",\n",
" \"mg\": \"malagasy\",\n",
" \"as\": \"assamese\",\n",
" \"tt\": \"tatar\",\n",
" \"haw\": \"hawaiian\",\n",
" \"ln\": \"lingala\",\n",
" \"ha\": \"hausa\",\n",
" \"ba\": \"bashkir\",\n",
" \"jw\": \"javanese\",\n",
" \"su\": \"sundanese\",\n",
" \"yue\": \"cantonese\",\n",
"}\n",
"\n",
"# language code lookup by name, with a few language aliases\n",
"TO_LANGUAGE_CODE = {\n",
" **{language: code for code, language in LANGUAGES.items()},\n",
" \"burmese\": \"my\",\n",
" \"valencian\": \"ca\",\n",
" \"flemish\": \"nl\",\n",
" \"haitian\": \"ht\",\n",
" \"letzeburgesch\": \"lb\",\n",
" \"pushto\": \"ps\",\n",
" \"panjabi\": \"pa\",\n",
" \"moldavian\": \"ro\",\n",
" \"moldovan\": \"ro\",\n",
" \"sinhalese\": \"si\",\n",
" \"castilian\": \"es\",\n",
"}\n",
"\n",
"\n",
"langs_to_iso = {\n",
" \"af\": \"afr\",\n",
" \"am\": \"amh\",\n",
@@ -564,32 +685,7 @@
" f\"{model_name} is an English-only model but received '{language}'; using English instead.\"\n",
" )\n",
" language = \"en\"\n",
" return language\n",
"\n",
"\n",
"def transcribe_batched(\n",
" audio_file: str,\n",
" language: str,\n",
" batch_size: int,\n",
" model_name: str,\n",
" compute_dtype: str,\n",
" suppress_numerals: bool,\n",
" device: str,\n",
"):\n",
" import whisperx\n",
"\n",
" # Faster Whisper batched\n",
" whisper_model = whisperx.load_model(\n",
" model_name,\n",
" device,\n",
" compute_type=compute_dtype,\n",
" asr_options={\"suppress_numerals\": suppress_numerals},\n",
" )\n",
" audio = whisperx.load_audio(audio_file)\n",
" result = whisper_model.transcribe(audio, language=language, batch_size=batch_size)\n",
" del whisper_model\n",
" torch.cuda.empty_cache()\n",
" return result[\"segments\"], result[\"language\"], audio"
" return language"
]
},
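
With `transcribe_batched` (and its whisperx dependency) removed above, the notebook presumably calls faster-whisper directly. Here is a sketch of an equivalent batched call, assuming the `BatchedInferencePipeline` batched API that faster-whisper exposed around the time of this PR; all argument values are placeholders:

```python
import torch
import faster_whisper

# Sketch: load the CTranslate2 Whisper model and transcribe in batches,
# mirroring what the removed transcribe_batched helper did via whisperx.
whisper_model = faster_whisper.WhisperModel(
    "medium", device="cuda", compute_type="float16"
)
whisper_pipeline = faster_whisper.BatchedInferencePipeline(whisper_model)

# decode_audio resamples to 16 kHz mono float32, the format the model expects.
audio_waveform = faster_whisper.decode_audio("audio.wav")

segments, info = whisper_pipeline.transcribe(audio_waveform, batch_size=8)
whisper_results = list(segments)  # the generator is lazy; listing it runs inference

del whisper_model, whisper_pipeline
torch.cuda.empty_cache()
```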
{
@@ -693,7 +789,7 @@
"id": "UYg9VWb22Tz8"
},
"source": [
"## Transcriping audio using Whisper and realligning timestamps using Wav2Vec2\n",
"## Transcriping audio using Whisper and realligning timestamps using Forced Alignment\n",
"---\n",
"This code uses two different open-source models to transcribe speech and perform forced alignment on the resulting transcription.\n",
"\n",
@@ -732,15 +828,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Aligning the transcription with the original audio using Wav2Vec2\n",
"## Aligning the transcription with the original audio using Forced Alignment\n",
"---\n",
"The second model used is called wav2vec2, which is a large-scale neural network that is designed to learn representations of speech that are useful for a variety of speech processing tasks, including speech recognition and alignment.\n",
"\n",
"The code loads the wav2vec2 alignment model and uses it to align the transcription segments with the original audio signal contained in the vocal_target file. This process involves finding the exact timestamps in the audio signal where each segment was spoken and aligning the text accordingly.\n",
"Forced alignment aims to to align the transcription segments with the original audio signal contained in the vocal_target file. This process involves finding the exact timestamps in the audio signal where each segment was spoken and aligning the text accordingly.\n",
"\n",
"By combining the outputs of the two models, the code produces a fully aligned transcription of the speech contained in the vocal_target file. This aligned transcription can be useful for a variety of speech processing tasks, such as speaker diarization, sentiment analysis, and language identification.\n",
"\n",
"If there's no Wav2Vec2 model available for your language, word timestamps generated by whisper will be used instead."
"By combining the outputs of the two models, the code produces a fully aligned transcription of the speech contained in the vocal_target file. This aligned transcription can be useful for a variety of speech processing tasks, such as speaker diarization, sentiment analysis, and language identification."
]
},
{
@@ -754,11 +846,7 @@
" dtype=torch.float16 if device == \"cuda\" else torch.float32,\n",
")\n",
"\n",
"audio_waveform = (\n",
" torch.from_numpy(audio_waveform)\n",
" .to(alignment_model.dtype)\n",
" .to(alignment_model.device)\n",
")\n",
"audio_waveform = audio_waveform.to(alignment_model.dtype).to(alignment_model.device)\n",
"\n",
"emissions, stride = generate_emissions(\n",
" alignment_model, audio_waveform, batch_size=batch_size\n",
@@ -767,12 +855,10 @@
"del alignment_model\n",
"torch.cuda.empty_cache()\n",
"\n",
"full_transcript = \"\".join(segment[\"text\"] for segment in whisper_results)\n",
"\n",
"tokens_starred, text_starred = preprocess_text(\n",
" full_transcript,\n",
" romanize=True,\n",
" language=langs_to_iso[language],\n",
" language=langs_to_iso[info.language],\n",
")\n",
"\n",
"segments, scores, blank_token = get_alignments(\n",
@@ -903,13 +989,13 @@
},
"outputs": [],
"source": [
"if language in punct_model_langs:\n",
"if info.language in punct_model_langs:\n",
" # restoring punctuation in the transcript to help realign the sentences\n",
" punct_model = PunctuationModel(model=\"kredor/punctuate-all\")\n",
"\n",
" words_list = list(map(lambda x: x[\"word\"], wsm))\n",
"\n",
" labled_words = punct_model.predict(words_list,chunk_size=230)\n",
" labled_words = punct_model.predict(words_list, chunk_size=230)\n",
"\n",
" ending_puncts = \".?!\"\n",
" model_puncts = \".,;:!?\"\n",
@@ -931,7 +1017,7 @@
"\n",
"else:\n",
" logging.warning(\n",
" f\"Punctuation restoration is not available for {language} language. Using the original punctuation.\"\n",
" f\"Punctuation restoration is not available for {info.language} language. Using the original punctuation.\"\n",
" )\n",
"\n",
"wsm = get_realigned_ws_mapping_with_punctuation(wsm)\n",