Skip to content

Commit

Permalink
cleanup and change notebook code
Browse files Browse the repository at this point in the history
  • Loading branch information
MahmoudAshraf97 committed Oct 25, 2024
1 parent f1aeb6c commit 56bd834
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 160 deletions.
170 changes: 130 additions & 40 deletions Whisper_Transcription_+_NeMo_Diarization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
},
"outputs": [],
"source": [
"!pip install git+https://github.com/m-bain/whisperX.git@78dcfaab51005aa703ee21375f81ed31bc248560\n",
"!pip install git+https://github.com/SYSTRAN/faster-whisper.git ctranslate2==4.4.0\n",
"!pip install \"nemo-toolkit[asr]>=2.dev\"\n",
"!pip install --no-deps git+https://github.com/facebookresearch/demucs#egg=demucs\n",
"!pip install git+https://github.com/MahmoudAshraf97/demucs.git\n",
"!pip install git+https://github.com/oliverguhr/deepmultilingualpunctuation.git\n",
"!pip install git+https://github.com/MahmoudAshraf97/ctc-forced-aligner.git"
]
Expand All @@ -56,7 +56,7 @@
"import re\n",
"import logging\n",
"import nltk\n",
"from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE\n",
"import faster_whisper\n",
"from ctc_forced_aligner import (\n",
" load_alignment_model,\n",
" generate_emissions,\n",
Expand Down Expand Up @@ -99,6 +99,127 @@
" \"sk\",\n",
" \"sl\",\n",
"]\n",
"\n",
"LANGUAGES = {\n",
" \"en\": \"english\",\n",
" \"zh\": \"chinese\",\n",
" \"de\": \"german\",\n",
" \"es\": \"spanish\",\n",
" \"ru\": \"russian\",\n",
" \"ko\": \"korean\",\n",
" \"fr\": \"french\",\n",
" \"ja\": \"japanese\",\n",
" \"pt\": \"portuguese\",\n",
" \"tr\": \"turkish\",\n",
" \"pl\": \"polish\",\n",
" \"ca\": \"catalan\",\n",
" \"nl\": \"dutch\",\n",
" \"ar\": \"arabic\",\n",
" \"sv\": \"swedish\",\n",
" \"it\": \"italian\",\n",
" \"id\": \"indonesian\",\n",
" \"hi\": \"hindi\",\n",
" \"fi\": \"finnish\",\n",
" \"vi\": \"vietnamese\",\n",
" \"he\": \"hebrew\",\n",
" \"uk\": \"ukrainian\",\n",
" \"el\": \"greek\",\n",
" \"ms\": \"malay\",\n",
" \"cs\": \"czech\",\n",
" \"ro\": \"romanian\",\n",
" \"da\": \"danish\",\n",
" \"hu\": \"hungarian\",\n",
" \"ta\": \"tamil\",\n",
" \"no\": \"norwegian\",\n",
" \"th\": \"thai\",\n",
" \"ur\": \"urdu\",\n",
" \"hr\": \"croatian\",\n",
" \"bg\": \"bulgarian\",\n",
" \"lt\": \"lithuanian\",\n",
" \"la\": \"latin\",\n",
" \"mi\": \"maori\",\n",
" \"ml\": \"malayalam\",\n",
" \"cy\": \"welsh\",\n",
" \"sk\": \"slovak\",\n",
" \"te\": \"telugu\",\n",
" \"fa\": \"persian\",\n",
" \"lv\": \"latvian\",\n",
" \"bn\": \"bengali\",\n",
" \"sr\": \"serbian\",\n",
" \"az\": \"azerbaijani\",\n",
" \"sl\": \"slovenian\",\n",
" \"kn\": \"kannada\",\n",
" \"et\": \"estonian\",\n",
" \"mk\": \"macedonian\",\n",
" \"br\": \"breton\",\n",
" \"eu\": \"basque\",\n",
" \"is\": \"icelandic\",\n",
" \"hy\": \"armenian\",\n",
" \"ne\": \"nepali\",\n",
" \"mn\": \"mongolian\",\n",
" \"bs\": \"bosnian\",\n",
" \"kk\": \"kazakh\",\n",
" \"sq\": \"albanian\",\n",
" \"sw\": \"swahili\",\n",
" \"gl\": \"galician\",\n",
" \"mr\": \"marathi\",\n",
" \"pa\": \"punjabi\",\n",
" \"si\": \"sinhala\",\n",
" \"km\": \"khmer\",\n",
" \"sn\": \"shona\",\n",
" \"yo\": \"yoruba\",\n",
" \"so\": \"somali\",\n",
" \"af\": \"afrikaans\",\n",
" \"oc\": \"occitan\",\n",
" \"ka\": \"georgian\",\n",
" \"be\": \"belarusian\",\n",
" \"tg\": \"tajik\",\n",
" \"sd\": \"sindhi\",\n",
" \"gu\": \"gujarati\",\n",
" \"am\": \"amharic\",\n",
" \"yi\": \"yiddish\",\n",
" \"lo\": \"lao\",\n",
" \"uz\": \"uzbek\",\n",
" \"fo\": \"faroese\",\n",
" \"ht\": \"haitian creole\",\n",
" \"ps\": \"pashto\",\n",
" \"tk\": \"turkmen\",\n",
" \"nn\": \"nynorsk\",\n",
" \"mt\": \"maltese\",\n",
" \"sa\": \"sanskrit\",\n",
" \"lb\": \"luxembourgish\",\n",
" \"my\": \"myanmar\",\n",
" \"bo\": \"tibetan\",\n",
" \"tl\": \"tagalog\",\n",
" \"mg\": \"malagasy\",\n",
" \"as\": \"assamese\",\n",
" \"tt\": \"tatar\",\n",
" \"haw\": \"hawaiian\",\n",
" \"ln\": \"lingala\",\n",
" \"ha\": \"hausa\",\n",
" \"ba\": \"bashkir\",\n",
" \"jw\": \"javanese\",\n",
" \"su\": \"sundanese\",\n",
" \"yue\": \"cantonese\",\n",
"}\n",
"\n",
"# language code lookup by name, with a few language aliases\n",
"TO_LANGUAGE_CODE = {\n",
" **{language: code for code, language in LANGUAGES.items()},\n",
" \"burmese\": \"my\",\n",
" \"valencian\": \"ca\",\n",
" \"flemish\": \"nl\",\n",
" \"haitian\": \"ht\",\n",
" \"letzeburgesch\": \"lb\",\n",
" \"pushto\": \"ps\",\n",
" \"panjabi\": \"pa\",\n",
" \"moldavian\": \"ro\",\n",
" \"moldovan\": \"ro\",\n",
" \"sinhalese\": \"si\",\n",
" \"castilian\": \"es\",\n",
"}\n",
"\n",
"\n",
"langs_to_iso = {\n",
" \"af\": \"afr\",\n",
" \"am\": \"amh\",\n",
Expand Down Expand Up @@ -564,32 +685,7 @@
" f\"{model_name} is an English-only model but received '{language}'; using English instead.\"\n",
" )\n",
" language = \"en\"\n",
" return language\n",
"\n",
"\n",
"def transcribe_batched(\n",
" audio_file: str,\n",
" language: str,\n",
" batch_size: int,\n",
" model_name: str,\n",
" compute_dtype: str,\n",
" suppress_numerals: bool,\n",
" device: str,\n",
"):\n",
" import whisperx\n",
"\n",
" # Faster Whisper batched\n",
" whisper_model = whisperx.load_model(\n",
" model_name,\n",
" device,\n",
" compute_type=compute_dtype,\n",
" asr_options={\"suppress_numerals\": suppress_numerals},\n",
" )\n",
" audio = whisperx.load_audio(audio_file)\n",
" result = whisper_model.transcribe(audio, language=language, batch_size=batch_size)\n",
" del whisper_model\n",
" torch.cuda.empty_cache()\n",
" return result[\"segments\"], result[\"language\"], audio"
" return language"
]
},
{
Expand Down Expand Up @@ -754,11 +850,7 @@
" dtype=torch.float16 if device == \"cuda\" else torch.float32,\n",
")\n",
"\n",
"audio_waveform = (\n",
" torch.from_numpy(audio_waveform)\n",
" .to(alignment_model.dtype)\n",
" .to(alignment_model.device)\n",
")\n",
"audio_waveform = audio_waveform.to(alignment_model.dtype).to(alignment_model.device)\n",
"\n",
"emissions, stride = generate_emissions(\n",
" alignment_model, audio_waveform, batch_size=batch_size\n",
Expand All @@ -767,12 +859,10 @@
"del alignment_model\n",
"torch.cuda.empty_cache()\n",
"\n",
"full_transcript = \"\".join(segment[\"text\"] for segment in whisper_results)\n",
"\n",
"tokens_starred, text_starred = preprocess_text(\n",
" full_transcript,\n",
" romanize=True,\n",
" language=langs_to_iso[language],\n",
" language=langs_to_iso[info.language],\n",
")\n",
"\n",
"segments, scores, blank_token = get_alignments(\n",
Expand Down Expand Up @@ -903,13 +993,13 @@
},
"outputs": [],
"source": [
"if language in punct_model_langs:\n",
"if info.language in punct_model_langs:\n",
" # restoring punctuation in the transcript to help realign the sentences\n",
" punct_model = PunctuationModel(model=\"kredor/punctuate-all\")\n",
"\n",
" words_list = list(map(lambda x: x[\"word\"], wsm))\n",
"\n",
" labled_words = punct_model.predict(words_list,chunk_size=230)\n",
" labled_words = punct_model.predict(words_list, chunk_size=230)\n",
"\n",
" ending_puncts = \".?!\"\n",
" model_puncts = \".,;:!?\"\n",
Expand All @@ -931,7 +1021,7 @@
"\n",
"else:\n",
" logging.warning(\n",
" f\"Punctuation restoration is not available for {language} language. Using the original punctuation.\"\n",
" f\"Punctuation restoration is not available for {info.language} language. Using the original punctuation.\"\n",
" )\n",
"\n",
"wsm = get_realigned_ws_mapping_with_punctuation(wsm)\n",
Expand Down
29 changes: 15 additions & 14 deletions diarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import faster_whisper
import torch
import torchaudio

from ctc_forced_aligner import (
generate_emissions,
get_alignments,
Expand Down Expand Up @@ -69,7 +70,8 @@
type=int,
dest="batch_size",
default=8,
help="Batch size for batched inference, reduce if you run out of memory, set to 0 for non-batched inference",
help="Batch size for batched inference, reduce if you run out of memory, "
"set to 0 for original whisper longform inference",
)

parser.add_argument(
Expand All @@ -94,12 +96,13 @@
# Isolate vocals from the rest of the audio

return_code = os.system(
f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{args.audio}" -o "temp_outputs"'
f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{args.audio}" -o temp_outputs'
)

if return_code != 0:
logging.warning(
"Source splitting failed, using original audio file. Use --no-stem argument to disable it."
"Source splitting failed, using original audio file. "
"Use --no-stem argument to disable it."
)
vocal_target = args.audio
else:
Expand All @@ -120,28 +123,25 @@
)
whisper_pipeline = faster_whisper.BatchedInferencePipeline(whisper_model)
audio_waveform = faster_whisper.decode_audio(vocal_target)
suppress_tokens = (
find_numeral_symbol_tokens(whisper_model.hf_tokenizer)
if args.suppress_numerals
else [-1]
)

if args.batch_size > 0:
transcript_segments, info = whisper_pipeline.transcribe(
audio_waveform,
language,
suppress_tokens=(
find_numeral_symbol_tokens(whisper_model.hf_tokenizer)
if args.suppress_numerals
else [-1]
),
suppress_tokens=suppress_tokens,
batch_size=args.batch_size,
without_timestamps=True,
)
else:
transcript_segments, info = whisper_model.transcribe(
audio_waveform,
language,
suppress_tokens=(
find_numeral_symbol_tokens(whisper_model.hf_tokenizer)
if args.suppress_numerals
else [-1]
),
suppress_tokens=suppress_tokens,
without_timestamps=True,
vad_filter=True,
)
Expand Down Expand Up @@ -245,7 +245,8 @@

else:
logging.warning(
f"Punctuation restoration is not available for {info.language} language. Using the original punctuation."
f"Punctuation restoration is not available for {info.language} language."
" Using the original punctuation."
)

wsm = get_realigned_ws_mapping_with_punctuation(wsm)
Expand Down
Loading

0 comments on commit 56bd834

Please sign in to comment.