Mirror of https://github.com/jlengrand/tldw.git, synced 2026-03-10 08:51:17 +00:00
diarize cli wip
diarize.py (259 changed lines)
@@ -1,31 +1,10 @@
# import whisper
from faster_whisper import WhisperModel
#!/usr/bin/env python3
import datetime
import subprocess
import gradio as gr
from pathlib import Path
import pandas as pd
import re
import time
import os
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import silhouette_score

from pytube import YouTube
import yt_dlp
import json
import torch
import pyannote.audio
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
from pyannote.audio import Audio
from pyannote.core import Segment

from gpuinfo import GPUInfo

import wave
import contextlib
from transformers import pipeline
import psutil

whisper_models = ["small", "medium", "small.en", "medium.en"]
source_languages = {
@@ -37,48 +16,44 @@ source_languages = {
    "ko": "Korean",
    "fr": "French"
}

source_language_list = [key[0] for key in source_languages.items()]

embedding_model = PretrainedSpeakerEmbedding(
    # "speechbrain/spkrec-ecapa-voxceleb",
    "pyannote/embedding",
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
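# The embedding dimensionality must match the chosen model: "pyannote/embedding" produces
# 512-dim vectors, while "speechbrain/spkrec-ecapa-voxceleb" produces 192-dim vectors
# (see the embedding_size parameter of speaker_diarize below).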

def convert_time(secs):
    return datetime.timedelta(seconds=round(secs))


# Download video .m4a and info.json
def get_youtube(video_url):
    # yt = YouTube(video_url)
    # abs_video_path = yt.streams.filter(progressive=True, file_extension='mp4').order_by('resolution').desc().first().download()
    return "lex.m4a"

    ydl_opts = {
        'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
    }
    import yt_dlp

    ydl_opts = {'format': 'bestaudio[ext=m4a]'}

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(video_url, download=False)
        abs_video_path = ydl.prepare_filename(info)
        with open(abs_video_path.replace('m4a', 'info.json'), 'w') as outfile:
            json.dump(info, outfile, indent=2)
        ydl.process_info(info)

    print("Success download video")
    print(abs_video_path)
    print("Success download", video_url, "to", abs_video_path)
    return abs_video_path


def speech_to_text(video_file_path, selected_source_lang, whisper_model, num_speakers):
    """
    # Transcribe youtube link using OpenAI Whisper
    1. Using OpenAI's Whisper model to separate audio into segments and generate transcripts.
    2. Generating speaker embeddings for each segment.
    3. Applying agglomerative clustering on the embeddings to identify the speaker for each segment.

# Convert video .m4a into .wav
def convert_to_wav(video_file_path):
    out_path = video_file_path.replace("m4a", "wav")
    if os.path.exists(out_path):
        print("wav file already exists:", out_path)
        return out_path

    try:
        print("starting conversion to wav")
        os.system(f'ffmpeg -i "{video_file_path}" -ar 16000 -ac 1 -c:a pcm_s16le "{out_path}"')
        print("conversion to wav ready:", out_path)
    except Exception as e:
        raise RuntimeError("Error converting.")

    Speech Recognition is based on models from OpenAI Whisper https://github.com/openai/whisper
    Speaker diarization model and pipeline from https://github.com/pyannote/pyannote-audio
    """

    # model = whisper.load_model(whisper_model)
    print('loading whisper model..')
    return out_path

# Transcribe .wav into .segments.json
def speech_to_text(video_file_path, selected_source_lang='en', whisper_model='small.en'):
    print('loading faster_whisper model:', whisper_model)
    from faster_whisper import WhisperModel
    model = WhisperModel(whisper_model, device="cuda", compute_type="float16")
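    # float16 inference assumes an NVIDIA GPU is available; on CPU-only machines,
    # faster-whisper is typically loaded with device="cpu" and compute_type="int8" instead.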
    time_start = time.time()
    if video_file_path is None:
@@ -88,19 +63,16 @@ def speech_to_text(video_file_path, selected_source_lang, whisper_model, num_spe
    try:
        # Read and convert youtube video
        _, file_ending = os.path.splitext(f'{video_file_path}')
        print(f'file ending is {file_ending}')
        audio_file = video_file_path.replace(file_ending, ".wav")
        print("starting conversion to wav")
        os.system(f'ffmpeg -i "{video_file_path}" -ar 16000 -ac 1 -c:a pcm_s16le "{audio_file}"')
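        # -ar 16000 -ac 1 -c:a pcm_s16le yields the 16 kHz mono 16-bit PCM WAV that
        # Whisper works with and that the wave-based duration check below can read.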
        out_file = video_file_path.replace(file_ending, ".segments.json")
        if os.path.exists(out_file):
            print("segments file already exists:", out_file)
            with open(out_file) as f:
                segments = json.load(f)
            return segments

        # Get duration
        with contextlib.closing(wave.open(audio_file, 'r')) as f:
            frames = f.getnframes()
            rate = f.getframerate()
            duration = frames / float(rate)
        print(f"conversion to wav ready, duration of audio file: {duration}")

        # Transcribe audio
        print('starting transcription...')
        options = dict(language=selected_source_lang, beam_size=5, best_of=5)
        transcribe_options = dict(task="transcribe", **options)
        segments_raw, info = model.transcribe(audio_file, **transcribe_options)
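        # beam_size sets the beam-search width and best_of the number of candidates kept
        # when sampling during decoding; 5 matches faster-whisper's own defaults.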
@@ -117,10 +89,51 @@ def speech_to_text(video_file_path, selected_source_lang, whisper_model, num_spe
            segments.append(chunk)
            i += 1
        print("transcribe audio done with fast whisper")
    except Exception as e:
        raise RuntimeError("Error converting video to audio")

    with open(out_file, 'w') as f:
        f.write(json.dumps(segments, indent=2))

    except Exception as e:
        raise RuntimeError("Error transcribing.")

    return segments

# embedding_model: "pyannote/embedding", embedding_size: 512
def speaker_diarize(video_file_path, segments, embedding_model="speechbrain/spkrec-ecapa-voxceleb", embedding_size=192, num_speakers=0):
    """
    1. Generating speaker embeddings for each segment.
    2. Applying agglomerative clustering on the embeddings to identify the speaker for each segment.
    """
    try:
        # Load embedding model
        from pyannote.audio import Audio
        from pyannote.core import Segment

        from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
        embedding_model = PretrainedSpeakerEmbedding(embedding_model, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

        import numpy as np
        import pandas as pd
        from sklearn.cluster import AgglomerativeClustering
        from sklearn.metrics import silhouette_score

        _, file_ending = os.path.splitext(f'{video_file_path}')
        audio_file = video_file_path.replace(file_ending, ".wav")
        out_file = video_file_path.replace(file_ending, ".diarize.json")
        if os.path.exists(out_file):
            print("segments file already exists:", out_file)
            with open(out_file) as f:
                segments = json.load(f)
            return segments

        # Get duration
        import wave
        with contextlib.closing(wave.open(audio_file, 'r')) as f:
            frames = f.getnframes()
            rate = f.getframerate()
            duration = frames / float(rate)
        print(f"duration of audio file: {duration}")

        # Create embedding
        def segment_embedding(segment):
            audio = Audio()
@@ -131,7 +144,7 @@ def speech_to_text(video_file_path, selected_source_lang, whisper_model, num_spe
            waveform, sample_rate = audio.crop(audio_file, clip)
            return embedding_model(waveform[None])
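        # segment_embedding crops a single segment from the WAV and feeds it to the
        # speaker-embedding model; waveform[None] adds the batch dimension the model expects.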

        embeddings = np.zeros(shape=(len(segments), 512))
        embeddings = np.zeros(shape=(len(segments), embedding_size))
        for i, segment in enumerate(segments):
            embeddings[i] = segment_embedding(segment)
        embeddings = np.nan_to_num(embeddings)
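        # Illustrative sketch only (the actual clustering code sits in the unchanged context
        # elided from this hunk): with num_speakers == 0, a common approach is to pick the
        # cluster count that maximizes the silhouette score, e.g.:
        #     best_k, best_score = 2, -1.0
        #     for k in range(2, min(10, len(segments)) + 1):
        #         labels_k = AgglomerativeClustering(k).fit_predict(embeddings)
        #         score = silhouette_score(embeddings, labels_k, metric='euclidean')
        #         if score > best_score:
        #             best_k, best_score = k, score
        #     labels = AgglomerativeClustering(best_k).fit_predict(embeddings)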
@@ -156,7 +169,13 @@ def speech_to_text(video_file_path, selected_source_lang, whisper_model, num_spe
        for i in range(len(segments)):
            segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)

        with open(out_file, 'w') as f:
            f.write(json.dumps(segments, indent=2))

        # Make output
        def convert_time(secs):
            return datetime.timedelta(seconds=round(secs))

        objects = {
            'Start': [],
            'End': [],
@@ -176,107 +195,21 @@ def speech_to_text(video_file_path, selected_source_lang, whisper_model, num_spe
            objects['End'].append(str(convert_time(segments[i - 1]["end"])))
            objects['Text'].append(text)

        time_end = time.time()
        time_diff = time_end - time_start
        memory = psutil.virtual_memory()
        gpu_utilization, gpu_memory = GPUInfo.gpu_usage()
        gpu_utilization = gpu_utilization[0] if len(gpu_utilization) > 0 else 0
        gpu_memory = gpu_memory[0] if len(gpu_memory) > 0 else 0
        system_info = f"""
        *Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB.*
        *Processing time: {time_diff:.5} seconds.*
        *GPU Utilization: {gpu_utilization}%, GPU Memory: {gpu_memory}MiB.*
        """
        save_path = "output/transcript_result.csv"
        save_path = video_file_path.replace(file_ending, ".csv")
        df_results = pd.DataFrame(objects)
        df_results.to_csv(save_path)
        return df_results, system_info, save_path
        return df_results, save_path

    except Exception as e:
        raise RuntimeError("Error Running inference with local model", e)


def main(youtube_url: str):
    video_path = get_youtube(youtube_url)
    convert_to_wav(video_path)
    segments = speech_to_text(video_path)
    df_results, save_path = speaker_diarize(video_path, segments)
    print("diarize complete:", save_path)


# ---- Gradio Layout -----
# Inspiration from https://huggingface.co/spaces/RASMUS/Whisper-youtube-crosslingual-subtitles
video_in = gr.Video(label="Video file", mirror_webcam=False)
youtube_url_in = gr.Textbox(label="Youtube url", lines=1, interactive=True)
df_init = pd.DataFrame(columns=['Start', 'End', 'Speaker', 'Text'])
memory = psutil.virtual_memory()
selected_source_lang = gr.Dropdown(choices=source_language_list, type="value", value="en", label="Spoken language in video", interactive=True)
selected_whisper_model = gr.Dropdown(choices=whisper_models, type="value", value="small.en", label="Selected Whisper model", interactive=True)
number_speakers = gr.Number(precision=0, value=0, label="Input number of speakers for better results. If value=0, the model will automatically find the best number of speakers", interactive=True)
system_info = gr.Markdown(f"*Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB*")
download_transcript = gr.File(label="Download transcript")
transcription_df = gr.DataFrame(value=df_init, label="Transcription dataframe", row_count=(0, "dynamic"), max_rows=10, wrap=True, overflow_row_behaviour='paginate')
title = "Whisper speaker diarization"
demo = gr.Blocks(title=title)
demo.encrypt = False


with demo:
    gr.Markdown('''
        <div>
        <h1 style='text-align: center'>Whisper speaker diarization</h1>
        This space uses Whisper models from <a href='https://github.com/openai/whisper' target='_blank'><b>OpenAI</b></a> with <a href='https://github.com/guillaumekln/faster-whisper' target='_blank'><b>CTranslate2</b></a>, a fast inference engine for Transformer models, to recognize speech (about 4 times faster than the original OpenAI model at the same accuracy),
        and the ECAPA-TDNN model from <a href='https://github.com/speechbrain/speechbrain' target='_blank'><b>SpeechBrain</b></a> to encode and classify speakers.
        </div>
        ''')

    with gr.Row():
        gr.Markdown('''
        ### Transcribe youtube link using OpenAI Whisper
        ##### 1. Using OpenAI's Whisper model to separate audio into segments and generate transcripts.
        ##### 2. Generating speaker embeddings for each segment.
        ##### 3. Applying agglomerative clustering on the embeddings to identify the speaker for each segment.
        ''')

    with gr.Row():
        gr.Markdown('''
        ### You can test with the following examples:
        ''')
    examples = gr.Examples(examples=
        ["https://www.youtube.com/watch?v=j7BfEzAFuYc&t=32s",
         "https://www.youtube.com/watch?v=-UX0X45sYe4",
         "https://www.youtube.com/watch?v=7minSgqi-Gw"],
        label="Examples", inputs=[youtube_url_in])


    with gr.Row():
        with gr.Column():
            youtube_url_in.render()
            download_youtube_btn = gr.Button("Download Youtube video")
            download_youtube_btn.click(get_youtube, [youtube_url_in], [video_in])
            print(video_in)


    with gr.Row():
        with gr.Column():
            video_in.render()
        with gr.Column():
            gr.Markdown('''
            ##### Here you can start the transcription process.
            ##### Please select the source language for transcription.
            ##### You can select a range of assumed numbers of speakers.
            ''')
            selected_source_lang.render()
            selected_whisper_model.render()
            number_speakers.render()
            transcribe_btn = gr.Button("Transcribe audio and diarization")
            transcribe_btn.click(speech_to_text,
                                 [video_in, selected_source_lang, selected_whisper_model, number_speakers],
                                 [transcription_df, system_info, download_transcript]
                                 )

    with gr.Row():
        gr.Markdown('''
        ##### Here you will get the transcription output
        ##### ''')


    with gr.Row():
        with gr.Column():
            download_transcript.render()
            transcription_df.render()
            system_info.render()

demo.launch(debug=True, server_port=8888, share=True)

if __name__ == "__main__":
    import fire
    fire.Fire(main)
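# Example CLI invocation via python-fire (the URL is one of the examples above):
#     python diarize.py --youtube_url="https://www.youtube.com/watch?v=j7BfEzAFuYc"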