From 35e39fd848488d59aac94f7b03a183e952b5198d Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Thu, 10 Oct 2024 20:01:43 +0300 Subject: [PATCH 1/5] initial commit --- README.md | 4 +- diarize.py | 63 +++++++++++++------- diarize_parallel.py | 67 ++++++++++++++------- helpers.py | 125 +++++++++++++++++++++++++++++++++++++-- requirements.txt | 3 +- transcription_helpers.py | 24 ++++---- 6 files changed, 222 insertions(+), 64 deletions(-) diff --git a/README.md b/README.md index c813eeb..ea3d56a 100644 --- a/README.md +++ b/README.md @@ -34,10 +34,10 @@ I'd like to thank [@m-bain](https://github.com/m-bain) for Batched Whisper Infer drawing **Please, star the project on github (see top-right corner) if you appreciate my contribution to the community!** ## What is it -This repository combines Whisper ASR capabilities with Voice Activity Detection (VAD) and Speaker Embedding to identify the speaker for each sentence in the transcription generated by Whisper. First, the vocals are extracted from the audio to increase the speaker embedding accuracy, then the transcription is generated using Whisper, then the timestamps are corrected and aligned using WhisperX to help minimize diarization error due to time shift. The audio is then passed into MarbleNet for VAD and segmentation to exclude silences, TitaNet is then used to extract speaker embeddings to identify the speaker for each segment, the result is then associated with the timestamps generated by WhisperX to detect the speaker for each word based on timestamps and then realigned using punctuation models to compensate for minor time shifts. +This repository combines Whisper ASR capabilities with Voice Activity Detection (VAD) and Speaker Embedding to identify the speaker for each sentence in the transcription generated by Whisper. First, the vocals are extracted from the audio to improve speaker embedding accuracy. The transcription is then generated using Whisper, and its timestamps are corrected and aligned using `ctc-forced-aligner` to help minimize diarization error due to time shift. The audio is then passed into MarbleNet for VAD and segmentation to exclude silences, and TitaNet is used to extract speaker embeddings that identify the speaker for each segment. The result is associated with the word timestamps generated by `ctc-forced-aligner` to detect the speaker for each word, and finally realigned using punctuation models to compensate for minor time shifts. -WhisperX and NeMo parameters are coded into diarize.py and helpers.py, I will add the CLI arguments to change them later +Whisper and NeMo parameters are coded into diarize.py and helpers.py; I will add CLI arguments to change them later. ## Installation Python >= `3.10` is needed, `3.9` will work but you'll need to manually install the requirements one by one. 
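For orientation, here is a minimal sketch of the transcribe-then-force-align flow that the diarize.py changes below wire up: faster-whisper produces the text, and ctc-forced-aligner recovers word-level timestamps from it. The audio path, model size, batch size, and the hard-coded "eng" language code are illustrative assumptions rather than values from the patch, and the final post-processing calls follow ctc-forced-aligner's documented usage as mirrored in diarize.py.

```python
import faster_whisper
import torch
from ctc_forced_aligner import (
    generate_emissions,
    get_alignments,
    get_spans,
    load_alignment_model,
    postprocess_results,
    preprocess_text,
)

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Transcribe with faster-whisper. Word timestamps are not requested here,
#    because the forced aligner recovers them afterwards.
model = faster_whisper.WhisperModel(
    "large-v2", device=device, compute_type="float16" if device == "cuda" else "int8"
)
pipeline = faster_whisper.BatchedInferencePipeline(model)
audio = faster_whisper.decode_audio("audio.wav")  # placeholder input file
segments, info = pipeline.transcribe(audio, batch_size=8, without_timestamps=True)
full_transcript = "".join(segment.text for segment in segments)

# Free GPU memory before loading the alignment model.
del model, pipeline
torch.cuda.empty_cache()

# 2) Force-align the transcript against the audio to get word-level timestamps.
alignment_model, alignment_tokenizer = load_alignment_model(
    device, dtype=torch.float16 if device == "cuda" else torch.float32
)
emissions, stride = generate_emissions(
    alignment_model,
    torch.as_tensor(audio).to(alignment_model.dtype).to(alignment_model.device),
    batch_size=8,
)
tokens_starred, text_starred = preprocess_text(
    full_transcript,
    romanize=True,
    language="eng",  # ISO-639-3 code; the repo maps info.language via langs_to_iso
)
segments, scores, blank_token = get_alignments(
    emissions, tokens_starred, alignment_tokenizer
)
spans = get_spans(tokens_starred, segments, blank_token)
word_timestamps = postprocess_results(text_starred, spans, stride, scores)
```

Transcribing without word timestamps and recovering them with the aligner appears to be what lets this patch drop the WhisperX dependency entirely.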
diff --git a/diarize.py b/diarize.py index 2ba24af..0863d03 100644 --- a/diarize.py +++ b/diarize.py @@ -3,6 +3,7 @@ import os import re +import faster_whisper import torch import torchaudio from ctc_forced_aligner import ( @@ -19,6 +20,7 @@ from helpers import ( cleanup, create_config, + find_numeral_symbol_tokens, get_realigned_ws_mapping_with_punctuation, get_sentences_speaker_mapping, get_speaker_aware_transcript, @@ -29,7 +31,6 @@ whisper_langs, write_srt, ) -from transcription_helpers import transcribe_batched mtypes = {"cpu": "int8", "cuda": "float16"} @@ -114,15 +115,42 @@ # Transcribe the audio file -whisper_results, language, audio_waveform = transcribe_batched( - vocal_target, - language, - args.batch_size, - args.model_name, - mtypes[args.device], - args.suppress_numerals, - args.device, +whisper_model = faster_whisper.WhisperModel( + args.model_name, device=args.device, compute_type=mtypes[args.device] ) +whisper_pipeline = faster_whisper.BatchedInferencePipeline(whisper_model) +audio_waveform = faster_whisper.decode_audio(vocal_target) + +if args.batch_size > 0: + transcript_segments, info = whisper_pipeline.transcribe( + audio_waveform, + language, + suppress_tokens=( + find_numeral_symbol_tokens(whisper_model.hf_tokenizer) + if args.suppress_numerals + else [-1] + ), + batch_size=args.batch_size, + without_timestamps=True, + ) +else: + transcript_segments, info = whisper_model.transcribe( + audio_waveform, + language, + suppress_tokens=( + find_numeral_symbol_tokens(whisper_model.hf_tokenizer) + if args.suppress_numerals + else [-1] + ), + without_timestamps=True, + vad_filter=True, + ) + +full_transcript = "".join(segment.text for segment in transcript_segments) + +# clear gpu vram +del whisper_model, whisper_pipeline +torch.cuda.empty_cache() # Forced Alignment alignment_model, alignment_tokenizer = load_alignment_model( @@ -130,24 +158,19 @@ dtype=torch.float16 if args.device == "cuda" else torch.float32, ) -audio_waveform = ( - torch.from_numpy(audio_waveform) - .to(alignment_model.dtype) - .to(alignment_model.device) -) emissions, stride = generate_emissions( - alignment_model, audio_waveform, batch_size=args.batch_size + alignment_model, + audio_waveform.to(alignment_model.dtype).to(alignment_model.device), + batch_size=args.batch_size, ) del alignment_model torch.cuda.empty_cache() -full_transcript = "".join(segment["text"] for segment in whisper_results) - tokens_starred, text_starred = preprocess_text( full_transcript, romanize=True, - language=langs_to_iso[language], + language=langs_to_iso[info.language], ) segments, scores, blank_token = get_alignments( @@ -194,7 +217,7 @@ wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start") -if language in punct_model_langs: +if info.language in punct_model_langs: # restoring punctuation in the transcript to help realign the sentences punct_model = PunctuationModel(model="kredor/punctuate-all") @@ -222,7 +245,7 @@ else: logging.warning( - f"Punctuation restoration is not available for {language} language. Using the original punctuation." + f"Punctuation restoration is not available for {info.language} language. Using the original punctuation." 
) wsm = get_realigned_ws_mapping_with_punctuation(wsm) diff --git a/diarize_parallel.py b/diarize_parallel.py index 47d1579..0277889 100644 --- a/diarize_parallel.py +++ b/diarize_parallel.py @@ -4,6 +4,7 @@ import re import subprocess +import faster_whisper import torch from ctc_forced_aligner import ( generate_emissions, @@ -17,6 +18,7 @@ from helpers import ( cleanup, + find_numeral_symbol_tokens, get_realigned_ws_mapping_with_punctuation, get_sentences_speaker_mapping, get_speaker_aware_transcript, @@ -27,7 +29,6 @@ whisper_langs, write_srt, ) -from transcription_helpers import transcribe_batched mtypes = {"cpu": "int8", "cuda": "float16"} @@ -115,15 +116,44 @@ stderr=subprocess.PIPE, ) # Transcribe the audio file -whisper_results, language, audio_waveform = transcribe_batched( - vocal_target, - language, - args.batch_size, - args.model_name, - mtypes[args.device], - args.suppress_numerals, - args.device, -) + +whisper_model = faster_whisper.WhisperModel( + args.model_name, device=args.device, compute_type=mtypes[args.device] +) +whisper_pipeline = faster_whisper.BatchedInferencePipeline(whisper_model) +audio_waveform = faster_whisper.decode_audio(vocal_target) + +if args.batch_size > 0: + transcript_segments, info = whisper_pipeline.transcribe( + audio_waveform, + language, + suppress_tokens=( + find_numeral_symbol_tokens(whisper_model.hf_tokenizer) + if args.suppress_numerals + else [-1] + ), + batch_size=args.batch_size, + without_timestamps=True, + ) +else: + transcript_segments, info = whisper_model.transcribe( + audio_waveform, + language, + suppress_tokens=( + find_numeral_symbol_tokens(whisper_model.hf_tokenizer) + if args.suppress_numerals + else [-1] + ), + without_timestamps=True, + vad_filter=True, + ) + +full_transcript = "".join(segment.text for segment in transcript_segments) + +# clear gpu vram +del whisper_model, whisper_pipeline +torch.cuda.empty_cache() + # Forced Alignment alignment_model, alignment_tokenizer = load_alignment_model( @@ -131,24 +161,19 @@ dtype=torch.float16 if args.device == "cuda" else torch.float32, ) -audio_waveform = ( - torch.from_numpy(audio_waveform) - .to(alignment_model.dtype) - .to(alignment_model.device) -) emissions, stride = generate_emissions( - alignment_model, audio_waveform, batch_size=args.batch_size + alignment_model, + audio_waveform.to(alignment_model.dtype).to(alignment_model.device), + batch_size=args.batch_size, ) del alignment_model torch.cuda.empty_cache() -full_transcript = "".join(segment["text"] for segment in whisper_results) - tokens_starred, text_starred = preprocess_text( full_transcript, romanize=True, - language=langs_to_iso[language], + language=langs_to_iso[info.language], ) segments, scores, blank_token = get_alignments( @@ -184,7 +209,7 @@ wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start") -if language in punct_model_langs: +if info.language in punct_model_langs: # restoring punctuation in the transcript to help realign the sentences punct_model = PunctuationModel(model="kredor/punctuate-all") @@ -212,7 +237,7 @@ else: logging.warning( - f"Punctuation restoration is not available for {language} language. Using the original punctuation." + f"Punctuation restoration is not available for {info.language} language. Using the original punctuation." 
) wsm = get_realigned_ws_mapping_with_punctuation(wsm) diff --git a/helpers.py b/helpers.py index 7e34d14..01c58fe 100644 --- a/helpers.py +++ b/helpers.py @@ -1,13 +1,10 @@ import json -import logging import os import shutil import nltk import wget from omegaconf import OmegaConf -from whisperx.alignment import DEFAULT_ALIGN_MODELS_HF, DEFAULT_ALIGN_MODELS_TORCH -from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE punct_model_langs = [ "en", @@ -23,9 +20,125 @@ "sk", "sl", ] -wav2vec2_langs = list(DEFAULT_ALIGN_MODELS_TORCH.keys()) + list( - DEFAULT_ALIGN_MODELS_HF.keys() -) + +LANGUAGES = { + "en": "english", + "zh": "chinese", + "de": "german", + "es": "spanish", + "ru": "russian", + "ko": "korean", + "fr": "french", + "ja": "japanese", + "pt": "portuguese", + "tr": "turkish", + "pl": "polish", + "ca": "catalan", + "nl": "dutch", + "ar": "arabic", + "sv": "swedish", + "it": "italian", + "id": "indonesian", + "hi": "hindi", + "fi": "finnish", + "vi": "vietnamese", + "he": "hebrew", + "uk": "ukrainian", + "el": "greek", + "ms": "malay", + "cs": "czech", + "ro": "romanian", + "da": "danish", + "hu": "hungarian", + "ta": "tamil", + "no": "norwegian", + "th": "thai", + "ur": "urdu", + "hr": "croatian", + "bg": "bulgarian", + "lt": "lithuanian", + "la": "latin", + "mi": "maori", + "ml": "malayalam", + "cy": "welsh", + "sk": "slovak", + "te": "telugu", + "fa": "persian", + "lv": "latvian", + "bn": "bengali", + "sr": "serbian", + "az": "azerbaijani", + "sl": "slovenian", + "kn": "kannada", + "et": "estonian", + "mk": "macedonian", + "br": "breton", + "eu": "basque", + "is": "icelandic", + "hy": "armenian", + "ne": "nepali", + "mn": "mongolian", + "bs": "bosnian", + "kk": "kazakh", + "sq": "albanian", + "sw": "swahili", + "gl": "galician", + "mr": "marathi", + "pa": "punjabi", + "si": "sinhala", + "km": "khmer", + "sn": "shona", + "yo": "yoruba", + "so": "somali", + "af": "afrikaans", + "oc": "occitan", + "ka": "georgian", + "be": "belarusian", + "tg": "tajik", + "sd": "sindhi", + "gu": "gujarati", + "am": "amharic", + "yi": "yiddish", + "lo": "lao", + "uz": "uzbek", + "fo": "faroese", + "ht": "haitian creole", + "ps": "pashto", + "tk": "turkmen", + "nn": "nynorsk", + "mt": "maltese", + "sa": "sanskrit", + "lb": "luxembourgish", + "my": "myanmar", + "bo": "tibetan", + "tl": "tagalog", + "mg": "malagasy", + "as": "assamese", + "tt": "tatar", + "haw": "hawaiian", + "ln": "lingala", + "ha": "hausa", + "ba": "bashkir", + "jw": "javanese", + "su": "sundanese", + "yue": "cantonese", +} + +# language code lookup by name, with a few language aliases +TO_LANGUAGE_CODE = { + **{language: code for code, language in LANGUAGES.items()}, + "burmese": "my", + "valencian": "ca", + "flemish": "nl", + "haitian": "ht", + "letzeburgesch": "lb", + "pushto": "ps", + "panjabi": "pa", + "moldavian": "ro", + "moldovan": "ro", + "sinhalese": "si", + "castilian": "es", +} whisper_langs = sorted(LANGUAGES.keys()) + sorted( [k.title() for k in TO_LANGUAGE_CODE.keys()] diff --git a/requirements.txt b/requirements.txt index 551be3f..ccf47ea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ wget nemo_toolkit[asr]==2.0.0rc0 -git+https://github.com/m-bain/whisperX.git@78dcfaab51005aa703ee21375f81ed31bc248560 +nltk +git+https://github.com/MahmoudAshraf97/faster-whisper.git@same_vad git+https://github.com/MahmoudAshraf97/demucs.git git+https://github.com/oliverguhr/deepmultilingualpunctuation.git git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git diff --git a/transcription_helpers.py 
b/transcription_helpers.py index 2e6e1b6..fb05c9a 100644 --- a/transcription_helpers.py +++ b/transcription_helpers.py @@ -8,15 +8,18 @@ def transcribe( compute_dtype: str, suppress_numerals: bool, device: str, + batch_size: int, ): - from faster_whisper import WhisperModel + import faster_whisper - from helpers import find_numeral_symbol_tokens, wav2vec2_langs + from helpers import find_numeral_symbol_tokens # Faster Whisper non-batched # Run on GPU with FP16 - whisper_model = WhisperModel(model_name, device=device, compute_type=compute_dtype) - + whisper_model = faster_whisper.WhisperModel( + model_name, device=device, compute_type=compute_dtype + ) + whisper_pipeline = faster_whisper.BatchedInferencePipeline(whisper_model) # or run on GPU with INT8 # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16") # or run on CPU with INT8 @@ -25,20 +28,13 @@ def transcribe( if suppress_numerals: numeral_symbol_tokens = find_numeral_symbol_tokens(whisper_model.hf_tokenizer) else: - numeral_symbol_tokens = None - - if language is not None and language in wav2vec2_langs: - word_timestamps = False - else: - word_timestamps = True + numeral_symbol_tokens = [-1] - segments, info = whisper_model.transcribe( + segments, info = whisper_pipeline.transcribe( audio_file, language=language, - beam_size=5, - word_timestamps=word_timestamps, # TODO: disable this if the language is supported by wav2vec2 suppress_tokens=numeral_symbol_tokens, - vad_filter=True, + batch_size=batch_size, ) whisper_results = [] for segment in segments: From f1aeb6ca3ec8d4c70c5ecbfa754c97eb8fb23789 Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Fri, 25 Oct 2024 11:17:07 +0300 Subject: [PATCH 2/5] switch to main faster whisper repo --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ccf47ea..a93fa06 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ wget nemo_toolkit[asr]==2.0.0rc0 nltk -git+https://github.com/MahmoudAshraf97/faster-whisper.git@same_vad +git+https://github.com/SYSTRAN/faster-whisper.git git+https://github.com/MahmoudAshraf97/demucs.git git+https://github.com/oliverguhr/deepmultilingualpunctuation.git git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git From 56bd834ac9f4fe10dc48ba41c45b814f02203101 Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Fri, 25 Oct 2024 11:17:43 +0300 Subject: [PATCH 3/5] cleanup and change notebook code --- ...per_Transcription_+_NeMo_Diarization.ipynb | 170 +++++++++++++----- diarize.py | 29 +-- diarize_parallel.py | 33 ++-- helpers.py | 98 +--------- nemo_process.py | 1 + 5 files changed, 171 insertions(+), 160 deletions(-) diff --git a/Whisper_Transcription_+_NeMo_Diarization.ipynb b/Whisper_Transcription_+_NeMo_Diarization.ipynb index cffbef2..f2ac775 100644 --- a/Whisper_Transcription_+_NeMo_Diarization.ipynb +++ b/Whisper_Transcription_+_NeMo_Diarization.ipynb @@ -29,9 +29,9 @@ }, "outputs": [], "source": [ - "!pip install git+https://github.com/m-bain/whisperX.git@78dcfaab51005aa703ee21375f81ed31bc248560\n", + "!pip install git+https://github.com/SYSTRAN/faster-whisper.git ctranslate2==4.4.0\n", "!pip install \"nemo-toolkit[asr]>=2.dev\"\n", - "!pip install --no-deps git+https://github.com/facebookresearch/demucs#egg=demucs\n", + "!pip install git+https://github.com/MahmoudAshraf97/demucs.git\n", "!pip install git+https://github.com/oliverguhr/deepmultilingualpunctuation.git\n", "!pip install 
git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git" ] @@ -56,7 +56,7 @@ "import re\n", "import logging\n", "import nltk\n", - "from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE\n", + "import faster_whisper\n", "from ctc_forced_aligner import (\n", " load_alignment_model,\n", " generate_emissions,\n", @@ -99,6 +99,127 @@ " \"sk\",\n", " \"sl\",\n", "]\n", + "\n", + "LANGUAGES = {\n", + " \"en\": \"english\",\n", + " \"zh\": \"chinese\",\n", + " \"de\": \"german\",\n", + " \"es\": \"spanish\",\n", + " \"ru\": \"russian\",\n", + " \"ko\": \"korean\",\n", + " \"fr\": \"french\",\n", + " \"ja\": \"japanese\",\n", + " \"pt\": \"portuguese\",\n", + " \"tr\": \"turkish\",\n", + " \"pl\": \"polish\",\n", + " \"ca\": \"catalan\",\n", + " \"nl\": \"dutch\",\n", + " \"ar\": \"arabic\",\n", + " \"sv\": \"swedish\",\n", + " \"it\": \"italian\",\n", + " \"id\": \"indonesian\",\n", + " \"hi\": \"hindi\",\n", + " \"fi\": \"finnish\",\n", + " \"vi\": \"vietnamese\",\n", + " \"he\": \"hebrew\",\n", + " \"uk\": \"ukrainian\",\n", + " \"el\": \"greek\",\n", + " \"ms\": \"malay\",\n", + " \"cs\": \"czech\",\n", + " \"ro\": \"romanian\",\n", + " \"da\": \"danish\",\n", + " \"hu\": \"hungarian\",\n", + " \"ta\": \"tamil\",\n", + " \"no\": \"norwegian\",\n", + " \"th\": \"thai\",\n", + " \"ur\": \"urdu\",\n", + " \"hr\": \"croatian\",\n", + " \"bg\": \"bulgarian\",\n", + " \"lt\": \"lithuanian\",\n", + " \"la\": \"latin\",\n", + " \"mi\": \"maori\",\n", + " \"ml\": \"malayalam\",\n", + " \"cy\": \"welsh\",\n", + " \"sk\": \"slovak\",\n", + " \"te\": \"telugu\",\n", + " \"fa\": \"persian\",\n", + " \"lv\": \"latvian\",\n", + " \"bn\": \"bengali\",\n", + " \"sr\": \"serbian\",\n", + " \"az\": \"azerbaijani\",\n", + " \"sl\": \"slovenian\",\n", + " \"kn\": \"kannada\",\n", + " \"et\": \"estonian\",\n", + " \"mk\": \"macedonian\",\n", + " \"br\": \"breton\",\n", + " \"eu\": \"basque\",\n", + " \"is\": \"icelandic\",\n", + " \"hy\": \"armenian\",\n", + " \"ne\": \"nepali\",\n", + " \"mn\": \"mongolian\",\n", + " \"bs\": \"bosnian\",\n", + " \"kk\": \"kazakh\",\n", + " \"sq\": \"albanian\",\n", + " \"sw\": \"swahili\",\n", + " \"gl\": \"galician\",\n", + " \"mr\": \"marathi\",\n", + " \"pa\": \"punjabi\",\n", + " \"si\": \"sinhala\",\n", + " \"km\": \"khmer\",\n", + " \"sn\": \"shona\",\n", + " \"yo\": \"yoruba\",\n", + " \"so\": \"somali\",\n", + " \"af\": \"afrikaans\",\n", + " \"oc\": \"occitan\",\n", + " \"ka\": \"georgian\",\n", + " \"be\": \"belarusian\",\n", + " \"tg\": \"tajik\",\n", + " \"sd\": \"sindhi\",\n", + " \"gu\": \"gujarati\",\n", + " \"am\": \"amharic\",\n", + " \"yi\": \"yiddish\",\n", + " \"lo\": \"lao\",\n", + " \"uz\": \"uzbek\",\n", + " \"fo\": \"faroese\",\n", + " \"ht\": \"haitian creole\",\n", + " \"ps\": \"pashto\",\n", + " \"tk\": \"turkmen\",\n", + " \"nn\": \"nynorsk\",\n", + " \"mt\": \"maltese\",\n", + " \"sa\": \"sanskrit\",\n", + " \"lb\": \"luxembourgish\",\n", + " \"my\": \"myanmar\",\n", + " \"bo\": \"tibetan\",\n", + " \"tl\": \"tagalog\",\n", + " \"mg\": \"malagasy\",\n", + " \"as\": \"assamese\",\n", + " \"tt\": \"tatar\",\n", + " \"haw\": \"hawaiian\",\n", + " \"ln\": \"lingala\",\n", + " \"ha\": \"hausa\",\n", + " \"ba\": \"bashkir\",\n", + " \"jw\": \"javanese\",\n", + " \"su\": \"sundanese\",\n", + " \"yue\": \"cantonese\",\n", + "}\n", + "\n", + "# language code lookup by name, with a few language aliases\n", + "TO_LANGUAGE_CODE = {\n", + " **{language: code for code, language in LANGUAGES.items()},\n", + " \"burmese\": \"my\",\n", + " \"valencian\": 
\"ca\",\n", + " \"flemish\": \"nl\",\n", + " \"haitian\": \"ht\",\n", + " \"letzeburgesch\": \"lb\",\n", + " \"pushto\": \"ps\",\n", + " \"panjabi\": \"pa\",\n", + " \"moldavian\": \"ro\",\n", + " \"moldovan\": \"ro\",\n", + " \"sinhalese\": \"si\",\n", + " \"castilian\": \"es\",\n", + "}\n", + "\n", + "\n", "langs_to_iso = {\n", " \"af\": \"afr\",\n", " \"am\": \"amh\",\n", @@ -564,32 +685,7 @@ " f\"{model_name} is an English-only model but received '{language}'; using English instead.\"\n", " )\n", " language = \"en\"\n", - " return language\n", - "\n", - "\n", - "def transcribe_batched(\n", - " audio_file: str,\n", - " language: str,\n", - " batch_size: int,\n", - " model_name: str,\n", - " compute_dtype: str,\n", - " suppress_numerals: bool,\n", - " device: str,\n", - "):\n", - " import whisperx\n", - "\n", - " # Faster Whisper batched\n", - " whisper_model = whisperx.load_model(\n", - " model_name,\n", - " device,\n", - " compute_type=compute_dtype,\n", - " asr_options={\"suppress_numerals\": suppress_numerals},\n", - " )\n", - " audio = whisperx.load_audio(audio_file)\n", - " result = whisper_model.transcribe(audio, language=language, batch_size=batch_size)\n", - " del whisper_model\n", - " torch.cuda.empty_cache()\n", - " return result[\"segments\"], result[\"language\"], audio" + " return language" ] }, { @@ -754,11 +850,7 @@ " dtype=torch.float16 if device == \"cuda\" else torch.float32,\n", ")\n", "\n", - "audio_waveform = (\n", - " torch.from_numpy(audio_waveform)\n", - " .to(alignment_model.dtype)\n", - " .to(alignment_model.device)\n", - ")\n", + "audio_waveform = audio_waveform.to(alignment_model.dtype).to(alignment_model.device)\n", "\n", "emissions, stride = generate_emissions(\n", " alignment_model, audio_waveform, batch_size=batch_size\n", @@ -767,12 +859,10 @@ "del alignment_model\n", "torch.cuda.empty_cache()\n", "\n", - "full_transcript = \"\".join(segment[\"text\"] for segment in whisper_results)\n", - "\n", "tokens_starred, text_starred = preprocess_text(\n", " full_transcript,\n", " romanize=True,\n", - " language=langs_to_iso[language],\n", + " language=langs_to_iso[info.language],\n", ")\n", "\n", "segments, scores, blank_token = get_alignments(\n", @@ -903,13 +993,13 @@ }, "outputs": [], "source": [ - "if language in punct_model_langs:\n", + "if info.language in punct_model_langs:\n", " # restoring punctuation in the transcript to help realign the sentences\n", " punct_model = PunctuationModel(model=\"kredor/punctuate-all\")\n", "\n", " words_list = list(map(lambda x: x[\"word\"], wsm))\n", "\n", - " labled_words = punct_model.predict(words_list,chunk_size=230)\n", + " labled_words = punct_model.predict(words_list, chunk_size=230)\n", "\n", " ending_puncts = \".?!\"\n", " model_puncts = \".,;:!?\"\n", @@ -931,7 +1021,7 @@ "\n", "else:\n", " logging.warning(\n", - " f\"Punctuation restoration is not available for {language} language. Using the original punctuation.\"\n", + " f\"Punctuation restoration is not available for {info.language} language. 
Using the original punctuation.\"\n", " )\n", "\n", "wsm = get_realigned_ws_mapping_with_punctuation(wsm)\n", diff --git a/diarize.py b/diarize.py index 0863d03..9be939f 100644 --- a/diarize.py +++ b/diarize.py @@ -6,6 +6,7 @@ import faster_whisper import torch import torchaudio + from ctc_forced_aligner import ( generate_emissions, get_alignments, @@ -69,7 +70,8 @@ type=int, dest="batch_size", default=8, - help="Batch size for batched inference, reduce if you run out of memory, set to 0 for non-batched inference", + help="Batch size for batched inference, reduce if you run out of memory, " + "set to 0 for original whisper longform inference", ) parser.add_argument( @@ -94,12 +96,13 @@ # Isolate vocals from the rest of the audio return_code = os.system( - f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{args.audio}" -o "temp_outputs"' + f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{args.audio}" -o temp_outputs' ) if return_code != 0: logging.warning( - "Source splitting failed, using original audio file. Use --no-stem argument to disable it." + "Source splitting failed, using original audio file. " + "Use --no-stem argument to disable it." ) vocal_target = args.audio else: @@ -120,16 +123,17 @@ ) whisper_pipeline = faster_whisper.BatchedInferencePipeline(whisper_model) audio_waveform = faster_whisper.decode_audio(vocal_target) +suppress_tokens = ( + find_numeral_symbol_tokens(whisper_model.hf_tokenizer) + if args.suppress_numerals + else [-1] +) if args.batch_size > 0: transcript_segments, info = whisper_pipeline.transcribe( audio_waveform, language, - suppress_tokens=( - find_numeral_symbol_tokens(whisper_model.hf_tokenizer) - if args.suppress_numerals - else [-1] - ), + suppress_tokens=suppress_tokens, batch_size=args.batch_size, without_timestamps=True, ) @@ -137,11 +141,7 @@ transcript_segments, info = whisper_model.transcribe( audio_waveform, language, - suppress_tokens=( - find_numeral_symbol_tokens(whisper_model.hf_tokenizer) - if args.suppress_numerals - else [-1] - ), + suppress_tokens=suppress_tokens, without_timestamps=True, vad_filter=True, ) @@ -245,7 +245,8 @@ else: logging.warning( - f"Punctuation restoration is not available for {info.language} language. Using the original punctuation." + f"Punctuation restoration is not available for {info.language} language." + " Using the original punctuation." 
) wsm = get_realigned_ws_mapping_with_punctuation(wsm) diff --git a/diarize_parallel.py b/diarize_parallel.py index 0277889..1fe589b 100644 --- a/diarize_parallel.py +++ b/diarize_parallel.py @@ -6,6 +6,7 @@ import faster_whisper import torch + from ctc_forced_aligner import ( generate_emissions, get_alignments, @@ -58,7 +59,7 @@ parser.add_argument( "--whisper-model", dest="model_name", - default="medium.en", + default="large-v2", help="name of the Whisper model to use", ) @@ -66,8 +67,9 @@ "--batch-size", type=int, dest="batch_size", - default=8, - help="Batch size for batched inference, reduce if you run out of memory, set to 0 for non-batched inference", + default=4, + help="Batch size for batched inference, reduce if you run out of memory, " + "set to 0 for original whisper longform inference", ) parser.add_argument( @@ -92,12 +94,13 @@ # Isolate vocals from the rest of the audio return_code = os.system( - f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{args.audio}" -o "temp_outputs"' + f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{args.audio}" -o temp_outputs' ) if return_code != 0: logging.warning( - "Source splitting failed, using original audio file. Use --no-stem argument to disable it." + "Source splitting failed, using original audio file. " + "Use --no-stem argument to disable it." ) vocal_target = args.audio else: @@ -122,16 +125,17 @@ ) whisper_pipeline = faster_whisper.BatchedInferencePipeline(whisper_model) audio_waveform = faster_whisper.decode_audio(vocal_target) +suppress_tokens = ( + find_numeral_symbol_tokens(whisper_model.hf_tokenizer) + if args.suppress_numerals + else [-1] +) if args.batch_size > 0: transcript_segments, info = whisper_pipeline.transcribe( audio_waveform, language, - suppress_tokens=( - find_numeral_symbol_tokens(whisper_model.hf_tokenizer) - if args.suppress_numerals - else [-1] - ), + suppress_tokens=suppress_tokens, batch_size=args.batch_size, without_timestamps=True, ) @@ -139,11 +143,7 @@ transcript_segments, info = whisper_model.transcribe( audio_waveform, language, - suppress_tokens=( - find_numeral_symbol_tokens(whisper_model.hf_tokenizer) - if args.suppress_numerals - else [-1] - ), + suppress_tokens=suppress_tokens, without_timestamps=True, vad_filter=True, ) @@ -237,7 +237,8 @@ else: logging.warning( - f"Punctuation restoration is not available for {info.language} language. Using the original punctuation." + f"Punctuation restoration is not available for {info.language} language." + " Using the original punctuation." 
) wsm = get_realigned_ws_mapping_with_punctuation(wsm) diff --git a/helpers.py b/helpers.py index 01c58fe..df36fb2 100644 --- a/helpers.py +++ b/helpers.py @@ -4,6 +4,7 @@ import nltk import wget + from omegaconf import OmegaConf punct_model_langs = [ @@ -145,108 +146,59 @@ ) langs_to_iso = { - "aa": "aar", - "ab": "abk", - "ae": "ave", "af": "afr", - "ak": "aka", "am": "amh", - "an": "arg", "ar": "ara", "as": "asm", - "av": "ava", - "ay": "aym", "az": "aze", "ba": "bak", "be": "bel", "bg": "bul", - "bh": "bih", - "bi": "bis", - "bm": "bam", "bn": "ben", "bo": "tib", "br": "bre", "bs": "bos", "ca": "cat", - "ce": "che", - "ch": "cha", - "co": "cos", - "cr": "cre", "cs": "cze", - "cu": "chu", - "cv": "chv", "cy": "wel", "da": "dan", "de": "ger", - "dv": "div", - "dz": "dzo", - "ee": "ewe", "el": "gre", "en": "eng", - "eo": "epo", "es": "spa", "et": "est", "eu": "baq", "fa": "per", - "ff": "ful", "fi": "fin", - "fj": "fij", "fo": "fao", "fr": "fre", - "fy": "fry", - "ga": "gle", - "gd": "gla", "gl": "glg", - "gn": "grn", "gu": "guj", - "gv": "glv", "ha": "hau", + "haw": "haw", "he": "heb", "hi": "hin", - "ho": "hmo", "hr": "hrv", "ht": "hat", "hu": "hun", "hy": "arm", - "hz": "her", - "ia": "ina", "id": "ind", - "ie": "ile", - "ig": "ibo", - "ii": "iii", - "ik": "ipk", - "io": "ido", "is": "ice", "it": "ita", - "iu": "iku", "ja": "jpn", - "jv": "jav", + "jw": "jav", "ka": "geo", - "kg": "kon", - "ki": "kik", - "kj": "kua", "kk": "kaz", - "kl": "kal", "km": "khm", "kn": "kan", "ko": "kor", - "kr": "kau", - "ks": "kas", - "ku": "kur", - "kv": "kom", - "kw": "cor", - "ky": "kir", "la": "lat", "lb": "ltz", - "lg": "lug", - "li": "lim", "ln": "lin", "lo": "lao", "lt": "lit", - "lu": "lub", "lv": "lav", "mg": "mlg", - "mh": "mah", "mi": "mao", "mk": "mac", "ml": "mal", @@ -255,48 +207,26 @@ "ms": "may", "mt": "mlt", "my": "bur", - "na": "nau", - "nb": "nob", - "nd": "nde", "ne": "nep", - "ng": "ndo", "nl": "dut", "nn": "nno", "no": "nor", - "nr": "nbl", - "nv": "nav", - "ny": "nya", "oc": "oci", - "oj": "oji", - "om": "orm", - "or": "ori", - "os": "oss", "pa": "pan", - "pi": "pli", "pl": "pol", "ps": "pus", "pt": "por", - "qu": "que", - "rm": "roh", - "rn": "run", "ro": "rum", "ru": "rus", - "rw": "kin", "sa": "san", - "sc": "srd", "sd": "snd", - "se": "sme", - "sg": "sag", "si": "sin", "sk": "slo", "sl": "slv", - "sm": "smo", "sn": "sna", "so": "som", "sq": "alb", "sr": "srp", - "ss": "ssw", - "st": "sot", "su": "sun", "sv": "swe", "sw": "swa", @@ -304,36 +234,23 @@ "te": "tel", "tg": "tgk", "th": "tha", - "ti": "tir", "tk": "tuk", "tl": "tgl", - "tn": "tsn", - "to": "ton", "tr": "tur", - "ts": "tso", "tt": "tat", - "tw": "twi", - "ty": "tah", - "ug": "uig", "uk": "ukr", "ur": "urd", "uz": "uzb", - "ve": "ven", "vi": "vie", - "vo": "vol", - "wa": "wln", - "wo": "wol", - "xh": "xho", "yi": "yid", "yo": "yor", - "za": "zha", + "yue": "yue", "zh": "chi", - "zu": "zul", } def create_config(output_dir): - DOMAIN_TYPE = "telephonic" # Can be meeting, telephonic, or general based on domain type of the audio file + DOMAIN_TYPE = "telephonic" CONFIG_LOCAL_DIRECTORY = "nemo_msdd_configs" CONFIG_FILE_NAME = f"diar_infer_{DOMAIN_TYPE}.yaml" MODEL_CONFIG_PATH = os.path.join(CONFIG_LOCAL_DIRECTORY, CONFIG_FILE_NAME) @@ -669,12 +586,13 @@ def cleanup(path: str): # remove directory and all its content shutil.rmtree(path) else: - raise ValueError("Path {} is not a file or dir.".format(path)) + raise ValueError(f"Path {path} is not a file or dir.") def process_language_arg(language: str, model_name: str): """ - 
Process the language argument to make sure it's valid and convert language names to language codes. + Process the language argument to make sure it's valid + and convert language names to language codes. """ if language is not None: language = language.lower() diff --git a/nemo_process.py b/nemo_process.py index 7d02c57..8b3e87a 100644 --- a/nemo_process.py +++ b/nemo_process.py @@ -2,6 +2,7 @@ import os import torch + from nemo.collections.asr.models.msdd_models import NeuralDiarizer from pydub import AudioSegment From dfa8d1a35b891ed44043edca0a79af4f397766fe Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Fri, 25 Oct 2024 11:24:50 +0300 Subject: [PATCH 4/5] delete unnecessary files --- transcription_helpers.py | 70 ---------------------------------------- 1 file changed, 70 deletions(-) delete mode 100644 transcription_helpers.py diff --git a/transcription_helpers.py b/transcription_helpers.py deleted file mode 100644 index fb05c9a..0000000 --- a/transcription_helpers.py +++ /dev/null @@ -1,70 +0,0 @@ -import torch - - -def transcribe( - audio_file: str, - language: str, - model_name: str, - compute_dtype: str, - suppress_numerals: bool, - device: str, - batch_size: int, -): - import faster_whisper - - from helpers import find_numeral_symbol_tokens - - # Faster Whisper non-batched - # Run on GPU with FP16 - whisper_model = faster_whisper.WhisperModel( - model_name, device=device, compute_type=compute_dtype - ) - whisper_pipeline = faster_whisper.BatchedInferencePipeline(whisper_model) - # or run on GPU with INT8 - # model = WhisperModel(model_size, device="cuda", compute_type="int8_float16") - # or run on CPU with INT8 - # model = WhisperModel(model_size, device="cpu", compute_type="int8") - - if suppress_numerals: - numeral_symbol_tokens = find_numeral_symbol_tokens(whisper_model.hf_tokenizer) - else: - numeral_symbol_tokens = [-1] - - segments, info = whisper_pipeline.transcribe( - audio_file, - language=language, - suppress_tokens=numeral_symbol_tokens, - batch_size=batch_size, - ) - whisper_results = [] - for segment in segments: - whisper_results.append(segment._asdict()) - # clear gpu vram - del whisper_model - torch.cuda.empty_cache() - return whisper_results, info.language - - -def transcribe_batched( - audio_file: str, - language: str, - batch_size: int, - model_name: str, - compute_dtype: str, - suppress_numerals: bool, - device: str, -): - import whisperx - - # Faster Whisper batched - whisper_model = whisperx.load_model( - model_name, - device, - compute_type=compute_dtype, - asr_options={"suppress_numerals": suppress_numerals}, - ) - audio = whisperx.load_audio(audio_file) - result = whisper_model.transcribe(audio, language=language, batch_size=batch_size) - del whisper_model - torch.cuda.empty_cache() - return result["segments"], result["language"], audio From d10fba5edec60cdbcec2f7ad5cce8d3f958f1ed6 Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Fri, 25 Oct 2024 11:28:38 +0300 Subject: [PATCH 5/5] Update readme and notebook [skip ci] --- README.md | 1 - Whisper_Transcription_+_NeMo_Diarization.ipynb | 12 ++++-------- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index ea3d56a..8dd5b20 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,6 @@ # Speaker Diarization pipeline based on OpenAI Whisper -I'd like to thank [@m-bain](https://github.com/m-bain) for Batched Whisper Inference, [@mu4farooqi](https://github.com/mu4farooqi) for punctuation realignment algorithm drawing **Please, star the project on github (see top-right 
corner) if you appreciate my contribution to the community!** diff --git a/Whisper_Transcription_+_NeMo_Diarization.ipynb b/Whisper_Transcription_+_NeMo_Diarization.ipynb index f2ac775..552d8c8 100644 --- a/Whisper_Transcription_+_NeMo_Diarization.ipynb +++ b/Whisper_Transcription_+_NeMo_Diarization.ipynb @@ -789,7 +789,7 @@ "id": "UYg9VWb22Tz8" }, "source": [ - "## Transcriping audio using Whisper and realligning timestamps using Wav2Vec2\n", + "## Transcribing audio using Whisper and realigning timestamps using Forced Alignment\n", "---\n", "This code uses two different open-source models to transcribe speech and perform forced alignment on the resulting transcription.\n", "\n", @@ -828,15 +828,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Aligning the transcription with the original audio using Wav2Vec2\n", + "## Aligning the transcription with the original audio using Forced Alignment\n", "---\n", - "The second model used is called wav2vec2, which is a large-scale neural network that is designed to learn representations of speech that are useful for a variety of speech processing tasks, including speech recognition and alignment.\n", + "Forced alignment aims to align the transcription segments with the original audio signal contained in the vocal_target file. This process involves finding the exact timestamps in the audio signal where each segment was spoken and aligning the text accordingly.\n", "\n", - "The code loads the wav2vec2 alignment model and uses it to align the transcription segments with the original audio signal contained in the vocal_target file. This process involves finding the exact timestamps in the audio signal where each segment was spoken and aligning the text accordingly.\n", "\n", - "By combining the outputs of the two models, the code produces a fully aligned transcription of the speech contained in the vocal_target file. This aligned transcription can be useful for a variety of speech processing tasks, such as speaker diarization, sentiment analysis, and language identification.\n", "\n", - "If there's no Wav2Vec2 model available for your language, word timestamps generated by whisper will be used instead." + "By combining the outputs of the two models, the code produces a fully aligned transcription of the speech contained in the vocal_target file. This aligned transcription can be useful for a variety of speech processing tasks, such as speaker diarization, sentiment analysis, and language identification." ] }, {