mirror of
https://github.com/jlengrand/tldw.git
synced 2026-03-10 08:51:17 +00:00
Breaking app.py but fixing summarizer.py ....
This commit is contained in:
426
HF/app.py
426
HF/app.py
@@ -1,14 +1,22 @@
|
||||
#!/usr/bin/env python3
|
||||
import gradio as gr
|
||||
import argparse, configparser, datetime, json, logging, os, platform, requests, shutil, subprocess, sys, time, unicodedata
|
||||
import argparse
|
||||
import configparser
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import requests
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import unicodedata
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
import contextlib
|
||||
import ffmpeg
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
import yt_dlp
|
||||
|
||||
|
||||
#######
|
||||
# Function Sections
|
||||
#
|
||||
@@ -349,7 +357,7 @@ def process_local_file(file_path):
|
||||
# Video Download/Handling
|
||||
#
|
||||
|
||||
def process_url(input_path, num_speakers=2, whisper_model="small.en", offset=0, api_name=None, api_key=None, vad_filter=False, download_video_flag=False,custom_prompt=None, demo_mode=True):
|
||||
def process_url(input_path, num_speakers=2, whisper_model="small.en", custom_prompt=None, offset=0, api_name=None, api_key=None, vad_filter=False, download_video_flag=False, demo_mode=False):
|
||||
if demo_mode:
|
||||
api_name = "huggingface"
|
||||
api_key = os.environ.get(HF_TOKEN)
|
||||
@@ -646,115 +654,115 @@ def speech_to_text(audio_file_path, selected_source_lang='en', whisper_model='sm
|
||||
# TODO: https://huggingface.co/pyannote/speaker-diarization-3.1
|
||||
# embedding_model = "pyannote/embedding", embedding_size=512
|
||||
# embedding_model = "speechbrain/spkrec-ecapa-voxceleb", embedding_size=192
|
||||
def speaker_diarize(video_file_path, segments, embedding_model = "pyannote/embedding", embedding_size=512, num_speakers=0):
    """
    Assign a speaker label to every transcript segment and export the result.

    1. Generating speaker embeddings for each segments.
    2. Applying agglomerative clustering on the embeddings to identify the speaker for each segment.

    Args:
        video_file_path: Path of the source media file. The audio is read from
            the file with the same base name and a ``.wav`` extension; the
            ``.diarize.json`` and ``.csv`` outputs are written next to it.
        segments: Sequence of dicts providing ``"start"``, ``"end"`` and
            ``"text"``. Each dict is mutated in place and gains a
            ``"speaker"`` key of the form ``"SPEAKER <n>"``.
        embedding_model: Model identifier handed to
            ``PretrainedSpeakerEmbedding``.
        embedding_size: Width of one embedding vector; must match what the
            chosen model returns.
        num_speakers: Number of clusters to use. ``0`` means "choose the
            count between 2 and 10 with the best silhouette score".

    Returns:
        A ``(df_results, save_path)`` tuple: the per-speaker-turn DataFrame
        (columns ``Start``, ``End``, ``Speaker``, ``Text``) and the path of
        the CSV file it was written to.

    Raises:
        RuntimeError: If any step fails; the original exception is attached
            both as the second argument and as ``__cause__``.
    """
    try:
        # Imported locally so the function does not depend on whether the
        # module bound the name ``datetime`` to the module or to the class
        # (``datetime.timedelta`` fails in the latter case).
        from datetime import timedelta

        from pyannote.audio import Audio
        from pyannote.core import Segment
        from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
        import numpy as np
        import pandas as pd
        from sklearn.cluster import AgglomerativeClustering
        from sklearn.metrics import silhouette_score
        import tqdm
        import wave

        if not segments:
            # Fail early with a clear message instead of deep inside the
            # clustering code (and before loading the embedding model).
            raise ValueError("No segments to diarize")

        embedder = PretrainedSpeakerEmbedding(
            embedding_model,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))

        # Derive sibling file names from the base path. Using str.replace on
        # the extension would rewrite every occurrence of that text in the
        # path, and would mangle the path entirely when there is no extension.
        base_path, _ = os.path.splitext(f'{video_file_path}')
        audio_file = base_path + ".wav"
        out_file = base_path + ".diarize.json"

        logging.debug("getting duration of audio file")
        with contextlib.closing(wave.open(audio_file,'r')) as f:
            frames = f.getnframes()
            rate = f.getframerate()
            duration = frames / float(rate)
        logging.debug("duration of audio file obtained")
        print(f"duration of audio file: {duration}")

        # One loader is enough; it is reused for every segment.
        audio = Audio()

        def segment_embedding(segment):
            """Return the speaker embedding for a single transcript segment."""
            logging.debug("Creating embedding")
            start = segment["start"]
            end = segment["end"]

            # Enforcing a minimum segment length
            if end - start < 0.3:
                padding = 0.3 - (end - start)
                start -= padding / 2
                end += padding / 2
                print('Padded segment because it was too short:',segment)

            # Padding a segment near the beginning can push the start below zero.
            start = max(0.0, start)
            # Whisper overshoots the end timestamp in the last segment
            end = min(duration, end)
            # clip audio and embed
            clip = Segment(start, end)
            waveform, _sample_rate = audio.crop(audio_file, clip)
            return embedder(waveform[None])

        embeddings = np.zeros(shape=(len(segments), embedding_size))
        for i, segment in enumerate(tqdm.tqdm(segments)):
            embeddings[i] = segment_embedding(segment)
        embeddings = np.nan_to_num(embeddings)
        print(f'Embedding shape: {embeddings.shape}')

        if num_speakers == 0:
            # Find the best number of speakers
            score_num_speakers = {}

            # A separate loop name keeps the ``num_speakers`` argument intact.
            for candidate in range(2, 10+1):
                clustering = AgglomerativeClustering(candidate).fit(embeddings)
                score = silhouette_score(embeddings, clustering.labels_, metric='euclidean')
                score_num_speakers[candidate] = score
            best_num_speaker = max(score_num_speakers, key=lambda x:score_num_speakers[x])
            print(f"The best number of speakers: {best_num_speaker} with {score_num_speakers[best_num_speaker]} score")
        else:
            best_num_speaker = num_speakers

        # Assign speaker label
        clustering = AgglomerativeClustering(best_num_speaker).fit(embeddings)
        labels = clustering.labels_
        for segment, label in zip(segments, labels):
            segment["speaker"] = 'SPEAKER ' + str(label + 1)

        with open(out_file,'w') as f:
            json.dump(segments, f, indent=2)

        # Make CSV output
        def convert_time(secs):
            """Convert a number of seconds to a timedelta rounded to whole seconds."""
            return timedelta(seconds=round(secs))

        objects = {
            'Start' : [],
            'End': [],
            'Speaker': [],
            'Text': []
        }
        text = ''
        for i, segment in enumerate(segments):
            # A new row starts whenever the speaker changes.
            if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
                objects['Start'].append(str(convert_time(segment["start"])))
                objects['Speaker'].append(segment["speaker"])
                if i != 0:
                    # Close the previous speaker's row.
                    objects['End'].append(str(convert_time(segments[i - 1]["end"])))
                    objects['Text'].append(text)
                    text = ''
            text += segment["text"] + ' '
        # Close the final row with the end of the LAST segment
        # (the previous code used the second-to-last one).
        objects['End'].append(str(convert_time(segments[-1]["end"])))
        objects['Text'].append(text)

        save_path = base_path + ".csv"
        df_results = pd.DataFrame(objects)
        df_results.to_csv(save_path)
        return df_results, save_path

    except Exception as e:
        raise RuntimeError("Error Running inference with local model", e) from e
|
||||
# def speaker_diarize(video_file_path, segments, embedding_model = "pyannote/embedding", embedding_size=512, num_speakers=0):
|
||||
# """
|
||||
# 1. Generating speaker embeddings for each segments.
|
||||
# 2. Applying agglomerative clustering on the embeddings to identify the speaker for each segment.
|
||||
# """
|
||||
# try:
|
||||
# from pyannote.audio import Audio
|
||||
# from pyannote.core import Segment
|
||||
# from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
|
||||
# import numpy as np
|
||||
# import pandas as pd
|
||||
# from sklearn.cluster import AgglomerativeClustering
|
||||
# from sklearn.metrics import silhouette_score
|
||||
# import tqdm
|
||||
# import wave
|
||||
#
|
||||
# embedding_model = PretrainedSpeakerEmbedding( embedding_model, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
||||
#
|
||||
#
|
||||
# _,file_ending = os.path.splitext(f'{video_file_path}')
|
||||
# audio_file = video_file_path.replace(file_ending, ".wav")
|
||||
# out_file = video_file_path.replace(file_ending, ".diarize.json")
|
||||
#
|
||||
# logging.debug("getting duration of audio file")
|
||||
# with contextlib.closing(wave.open(audio_file,'r')) as f:
|
||||
# frames = f.getnframes()
|
||||
# rate = f.getframerate()
|
||||
# duration = frames / float(rate)
|
||||
# logging.debug("duration of audio file obtained")
|
||||
# print(f"duration of audio file: {duration}")
|
||||
#
|
||||
# def segment_embedding(segment):
|
||||
# logging.debug("Creating embedding")
|
||||
# audio = Audio()
|
||||
# start = segment["start"]
|
||||
# end = segment["end"]
|
||||
#
|
||||
# # Enforcing a minimum segment length
|
||||
# if end-start < 0.3:
|
||||
# padding = 0.3-(end-start)
|
||||
# start -= padding/2
|
||||
# end += padding/2
|
||||
# print('Padded segment because it was too short:',segment)
|
||||
#
|
||||
# # Whisper overshoots the end timestamp in the last segment
|
||||
# end = min(duration, end)
|
||||
# # clip audio and embed
|
||||
# clip = Segment(start, end)
|
||||
# waveform, sample_rate = audio.crop(audio_file, clip)
|
||||
# return embedding_model(waveform[None])
|
||||
#
|
||||
# embeddings = np.zeros(shape=(len(segments), embedding_size))
|
||||
# for i, segment in enumerate(tqdm.tqdm(segments)):
|
||||
# embeddings[i] = segment_embedding(segment)
|
||||
# embeddings = np.nan_to_num(embeddings)
|
||||
# print(f'Embedding shape: {embeddings.shape}')
|
||||
#
|
||||
# if num_speakers == 0:
|
||||
# # Find the best number of speakers
|
||||
# score_num_speakers = {}
|
||||
#
|
||||
# for num_speakers in range(2, 10+1):
|
||||
# clustering = AgglomerativeClustering(num_speakers).fit(embeddings)
|
||||
# score = silhouette_score(embeddings, clustering.labels_, metric='euclidean')
|
||||
# score_num_speakers[num_speakers] = score
|
||||
# best_num_speaker = max(score_num_speakers, key=lambda x:score_num_speakers[x])
|
||||
# print(f"The best number of speakers: {best_num_speaker} with {score_num_speakers[best_num_speaker]} score")
|
||||
# else:
|
||||
# best_num_speaker = num_speakers
|
||||
#
|
||||
# # Assign speaker label
|
||||
# clustering = AgglomerativeClustering(best_num_speaker).fit(embeddings)
|
||||
# labels = clustering.labels_
|
||||
# for i in range(len(segments)):
|
||||
# segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)
|
||||
#
|
||||
# with open(out_file,'w') as f:
|
||||
# f.write(json.dumps(segments, indent=2))
|
||||
#
|
||||
# # Make CSV output
|
||||
# def convert_time(secs):
|
||||
# return datetime.timedelta(seconds=round(secs))
|
||||
#
|
||||
# objects = {
|
||||
# 'Start' : [],
|
||||
# 'End': [],
|
||||
# 'Speaker': [],
|
||||
# 'Text': []
|
||||
# }
|
||||
# text = ''
|
||||
# for (i, segment) in enumerate(segments):
|
||||
# if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
|
||||
# objects['Start'].append(str(convert_time(segment["start"])))
|
||||
# objects['Speaker'].append(segment["speaker"])
|
||||
# if i != 0:
|
||||
# objects['End'].append(str(convert_time(segments[i - 1]["end"])))
|
||||
# objects['Text'].append(text)
|
||||
# text = ''
|
||||
# text += segment["text"] + ' '
|
||||
# objects['End'].append(str(convert_time(segments[i - 1]["end"])))
|
||||
# objects['Text'].append(text)
|
||||
#
|
||||
# save_path = video_file_path.replace(file_ending, ".csv")
|
||||
# df_results = pd.DataFrame(objects)
|
||||
# df_results.to_csv(save_path)
|
||||
# return df_results, save_path
|
||||
#
|
||||
# except Exception as e:
|
||||
# raise RuntimeError("Error Running inference with local model", e)
|
||||
#
|
||||
#
|
||||
####################################################################################################################################
|
||||
@@ -777,12 +785,12 @@ def extract_text_from_segments(segments):
|
||||
|
||||
|
||||
|
||||
def summarize_with_openai(api_key, file_path, model):
|
||||
def summarize_with_openai(api_key, file_path, model, custom_prompt):
|
||||
try:
|
||||
logging.debug("openai: Loading json data for summarization")
|
||||
with open(file_path, 'r') as file:
|
||||
segments = json.load(file)
|
||||
|
||||
|
||||
logging.debug("openai: Extracting text from the segments")
|
||||
text = extract_text_from_segments(segments)
|
||||
|
||||
@@ -790,9 +798,14 @@ def summarize_with_openai(api_key, file_path, model):
|
||||
'Authorization': f'Bearer {api_key}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
# headers = {
|
||||
# 'Authorization': f'Bearer {api_key}',
|
||||
# 'Content-Type': 'application/json'
|
||||
# }
|
||||
|
||||
logging.debug(f"openai: API Key is: {api_key}")
|
||||
logging.debug("openai: Preparing data + prompt for submittal")
|
||||
openai_prompt = f"{text} \n\n\n\n{prompt_text}"
|
||||
openai_prompt = f"{text} \n\n\n\n{custom_prompt}"
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": [
|
||||
@@ -810,7 +823,7 @@ def summarize_with_openai(api_key, file_path, model):
|
||||
}
|
||||
logging.debug("openai: Posting request")
|
||||
response = requests.post('https://api.openai.com/v1/chat/completions', headers=headers, json=data)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
summary = response.json()['choices'][0]['message']['content'].strip()
|
||||
logging.debug("openai: Summarization successful")
|
||||
@@ -826,13 +839,12 @@ def summarize_with_openai(api_key, file_path, model):
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def summarize_with_claude(api_key, file_path, model):
|
||||
def summarize_with_claude(api_key, file_path, model, custom_prompt):
|
||||
try:
|
||||
logging.debug("anthropic: Loading JSON data")
|
||||
with open(file_path, 'r') as file:
|
||||
segments = json.load(file)
|
||||
|
||||
|
||||
logging.debug("anthropic: Extracting text from the segments file")
|
||||
text = extract_text_from_segments(segments)
|
||||
|
||||
@@ -841,16 +853,17 @@ def summarize_with_claude(api_key, file_path, model):
|
||||
'anthropic-version': '2023-06-01',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
logging.debug("anthropic: Prepping data + prompt for submittal")
|
||||
|
||||
anthropic_prompt = custom_prompt
|
||||
logging.debug("anthropic: Prompt is {anthropic_prompt}")
|
||||
user_message = {
|
||||
"role": "user",
|
||||
"content": f"{text} \n\n\n\n{prompt_text}"
|
||||
"content": f"{text} \n\n\n\n{anthropic_prompt}"
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"max_tokens": 4096, # max _possible_ tokens to return
|
||||
"max_tokens": 4096, # max _possible_ tokens to return
|
||||
"messages": [user_message],
|
||||
"stop_sequences": ["\n\nHuman:"],
|
||||
"temperature": 0.7,
|
||||
@@ -862,17 +875,17 @@ def summarize_with_claude(api_key, file_path, model):
|
||||
"stream": False,
|
||||
"system": "You are a professional summarizer."
|
||||
}
|
||||
|
||||
|
||||
logging.debug("anthropic: Posting request to API")
|
||||
response = requests.post('https://api.anthropic.com/v1/messages', headers=headers, json=data)
|
||||
|
||||
|
||||
# Check if the status code indicates success
|
||||
if response.status_code == 200:
|
||||
logging.debug("anthropic: Post submittal successful")
|
||||
response_data = response.json()
|
||||
try:
|
||||
summary = response_data['content'][0]['text'].strip()
|
||||
logging.debug("anthropic: Summarization succesful")
|
||||
logging.debug("anthropic: Summarization successful")
|
||||
print("Summary processed successfully.")
|
||||
return summary
|
||||
except (IndexError, KeyError) as e:
|
||||
@@ -894,9 +907,8 @@ def summarize_with_claude(api_key, file_path, model):
|
||||
return None
|
||||
|
||||
|
||||
|
||||
# Summarize with Cohere
|
||||
def summarize_with_cohere(api_key, file_path, model):
|
||||
def summarize_with_cohere(api_key, file_path, model, custom_prompt):
|
||||
try:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logging.debug("cohere: Loading JSON data")
|
||||
@@ -912,7 +924,9 @@ def summarize_with_cohere(api_key, file_path, model):
|
||||
'Authorization': f'Bearer {api_key}'
|
||||
}
|
||||
|
||||
cohere_prompt = f"{text} \n\n\n\n{prompt_text}"
|
||||
cohere_prompt = f"{text} \n\n\n\n{custom_prompt}"
|
||||
logging.debug("cohere: Prompt being sent is {cohere_prompt}")
|
||||
|
||||
data = {
|
||||
"chat_history": [
|
||||
{"role": "USER", "message": cohere_prompt}
|
||||
@@ -938,7 +952,7 @@ def summarize_with_cohere(api_key, file_path, model):
|
||||
logging.error("Expected data not found in API response.")
|
||||
return "Expected data not found in API response."
|
||||
else:
|
||||
logging.error(f"cohere: API request failed with status code {response.status_code}: {resposne.text}")
|
||||
logging.error(f"cohere: API request failed with status code {response.status_code}: {response.text}")
|
||||
print(f"Failed to process summary, status code {response.status_code}: {response.text}")
|
||||
return f"cohere: API request failed: {response.text}"
|
||||
|
||||
@@ -947,9 +961,8 @@ def summarize_with_cohere(api_key, file_path, model):
|
||||
return f"cohere: Error occurred while processing summary with Cohere: {str(e)}"
|
||||
|
||||
|
||||
|
||||
# https://console.groq.com/docs/quickstart
|
||||
def summarize_with_groq(api_key, file_path, model):
|
||||
def summarize_with_groq(api_key, file_path, model, custom_prompt):
|
||||
try:
|
||||
logging.debug("groq: Loading JSON data")
|
||||
with open(file_path, 'r') as file:
|
||||
@@ -963,7 +976,9 @@ def summarize_with_groq(api_key, file_path, model):
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
groq_prompt = f"{text} \n\n\n\n{prompt_text}"
|
||||
groq_prompt = f"{text} \n\n\n\n{custom_prompt}"
|
||||
logging.debug("groq: Prompt being sent is {groq_prompt}")
|
||||
|
||||
data = {
|
||||
"messages": [
|
||||
{
|
||||
@@ -1003,7 +1018,7 @@ def summarize_with_groq(api_key, file_path, model):
|
||||
#
|
||||
# Local Summarization
|
||||
|
||||
def summarize_with_llama(api_url, file_path, token):
|
||||
def summarize_with_llama(api_url, file_path, token, custom_prompt):
|
||||
try:
|
||||
logging.debug("llama: Loading JSON data")
|
||||
with open(file_path, 'r') as file:
|
||||
@@ -1016,17 +1031,17 @@ def summarize_with_llama(api_url, file_path, token):
|
||||
'accept': 'application/json',
|
||||
'content-type': 'application/json',
|
||||
}
|
||||
if len(token)>5:
|
||||
if len(token) > 5:
|
||||
headers['Authorization'] = f'Bearer {token}'
|
||||
|
||||
llama_prompt = f"{text} \n\n\n\n{custom_prompt}"
|
||||
logging.debug("llama: Prompt being sent is {llama_prompt}")
|
||||
|
||||
llama_prompt = f"{text} \n\n\n\n{prompt_text}"
|
||||
logging.debug(f"llama: Complete prompt is: {llama_prompt}")
|
||||
data = {
|
||||
"prompt": llama_prompt
|
||||
}
|
||||
|
||||
#logging.debug(f"llama: Submitting request to API endpoint {llama_prompt}")
|
||||
logging.debug("llama: Submitting request to API endpoint")
|
||||
print("llama: Submitting request to API endpoint")
|
||||
response = requests.post(api_url, headers=headers, json=data)
|
||||
response_data = response.json()
|
||||
@@ -1048,9 +1063,8 @@ def summarize_with_llama(api_url, file_path, token):
|
||||
return f"llama: Error occurred while processing summary with llama: {str(e)}"
|
||||
|
||||
|
||||
|
||||
# https://lite.koboldai.net/koboldcpp_api#/api%2Fv1/post_api_v1_generate
|
||||
def summarize_with_kobold(api_url, file_path):
|
||||
def summarize_with_kobold(api_url, file_path, custom_prompt):
|
||||
try:
|
||||
logging.debug("kobold: Loading JSON data")
|
||||
with open(file_path, 'r') as file:
|
||||
@@ -1063,9 +1077,11 @@ def summarize_with_kobold(api_url, file_path):
|
||||
'accept': 'application/json',
|
||||
'content-type': 'application/json',
|
||||
}
|
||||
|
||||
kobold_prompt = f"{text} \n\n\n\n{custom_prompt}"
|
||||
logging.debug("kobold: Prompt being sent is {kobold_prompt}")
|
||||
|
||||
# FIXME
|
||||
kobold_prompt = f"{text} \n\n\n\n{prompt_text}"
|
||||
logging.debug(kobold_prompt)
|
||||
# Values literally c/p from the api docs....
|
||||
data = {
|
||||
"max_context_length": 8096,
|
||||
@@ -1097,9 +1113,8 @@ def summarize_with_kobold(api_url, file_path):
|
||||
return f"kobold: Error occurred while processing summary with kobold: {str(e)}"
|
||||
|
||||
|
||||
|
||||
# https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API
|
||||
def summarize_with_oobabooga(api_url, file_path):
|
||||
def summarize_with_oobabooga(api_url, file_path, custom_prompt):
|
||||
try:
|
||||
logging.debug("ooba: Loading JSON data")
|
||||
with open(file_path, 'r') as file:
|
||||
@@ -1114,14 +1129,15 @@ def summarize_with_oobabooga(api_url, file_path):
|
||||
'content-type': 'application/json',
|
||||
}
|
||||
|
||||
#prompt_text = "I like to eat cake and bake cakes. I am a baker. I work in a french bakery baking cakes. It is a fun job. I have been baking cakes for ten years. I also bake lots of other baked goods, but cakes are my favorite."
|
||||
#prompt_text += f"\n\n{text}" # Uncomment this line if you want to include the text variable
|
||||
ooba_prompt = f"{text}\n\n\n\n{prompt_text}"
|
||||
# prompt_text = "I like to eat cake and bake cakes. I am a baker. I work in a French bakery baking cakes. It is a fun job. I have been baking cakes for ten years. I also bake lots of other baked goods, but cakes are my favorite."
|
||||
# prompt_text += f"\n\n{text}" # Uncomment this line if you want to include the text variable
|
||||
ooba_prompt = "{text}\n\n\n\n{custom_prompt}"
|
||||
logging.debug("ooba: Prompt being sent is {ooba_prompt}")
|
||||
|
||||
data = {
|
||||
data = {
|
||||
"mode": "chat",
|
||||
"character": "Example",
|
||||
"messages": [{"role": "user", "content": prompt_text}]
|
||||
"messages": [{"role": "user", "content": ooba_prompt}]
|
||||
}
|
||||
|
||||
logging.debug("ooba: Submitting request to API endpoint")
|
||||
@@ -1350,7 +1366,7 @@ def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model=
|
||||
logging.debug("MAIN: Video downloaded successfully")
|
||||
logging.debug("MAIN: Converting video file to WAV...")
|
||||
audio_file = convert_to_wav(video_path, offset)
|
||||
logging.debug("MAIN: Audio file converted succesfully")
|
||||
logging.debug("MAIN: Audio file converted successfully")
|
||||
else:
|
||||
if os.path.exists(path):
|
||||
logging.debug("MAIN: Local file path detected")
|
||||
@@ -1370,85 +1386,69 @@ def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model=
|
||||
results.append(transcription_result)
|
||||
logging.info(f"Transcription complete: {audio_file}")
|
||||
|
||||
if path.startswith('http'):
|
||||
# Delete the downloaded video file
|
||||
os.remove(video_path)
|
||||
logging.info(f"Deleted downloaded video file: {video_path}")
|
||||
|
||||
# Perform summarization based on the specified API
|
||||
if api_name and api_key:
|
||||
logging.debug(f"MAIN: Summarization being performed by {api_name}")
|
||||
json_file_path = audio_file.replace('.wav', '.segments.json')
|
||||
if api_name.lower() == 'openai':
|
||||
api_key = openai_api_key
|
||||
try:
|
||||
logging.debug(f"MAIN: trying to summarize with openAI")
|
||||
api_key = openai_api_key
|
||||
logging.debug(f"OpenAI: OpenAI API Key: {api_key}")
|
||||
summary = summarize_with_openai(api_key, json_file_path, openai_model)
|
||||
logging.debug(f"MAIN: trying to summarize with openAI")
|
||||
summary = summarize_with_openai(api_key, json_file_path, openai_model, custom_prompt)
|
||||
except requests.exceptions.ConnectionError:
|
||||
r.status_code = "Connection: "
|
||||
elif api_name.lower() == 'anthropic':
|
||||
requests.status_code = "Connection: "
|
||||
elif api_name.lower() == "anthropic":
|
||||
api_key = anthropic_api_key
|
||||
try:
|
||||
logging.debug("MAIN: Trying to summarize with anthropic")
|
||||
api_key = anthropic_api_key
|
||||
logging.debug(f"Anthropic: Anthropic API Key: {api_key}")
|
||||
summary = summarize_with_claude(api_key, json_file_path, anthropic_model)
|
||||
logging.debug(f"MAIN: Trying to summarize with anthropic")
|
||||
summary = summarize_with_claude(api_key, json_file_path, anthropic_model, custom_prompt)
|
||||
except requests.exceptions.ConnectionError:
|
||||
r.status_code = "Connection: "
|
||||
elif api_name.lower() == 'cohere':
|
||||
requests.status_code = "Connection: "
|
||||
elif api_name.lower() == "cohere":
|
||||
api_key = cohere_api_key
|
||||
try:
|
||||
logging.debug("Main: Trying to summarize with cohere")
|
||||
api_key = cohere_api_key
|
||||
logging.debug(f"Cohere: Cohere API Key: {api_key}")
|
||||
summary = summarize_with_cohere(api_key, json_file_path, cohere_model)
|
||||
logging.debug(f"MAIN: Trying to summarize with cohere")
|
||||
summary = summarize_with_cohere(api_key, json_file_path, cohere_model, custom_prompt)
|
||||
except requests.exceptions.ConnectionError:
|
||||
r.status_code = "Connection: "
|
||||
elif api_name.lower() == 'groq':
|
||||
requests.status_code = "Connection: "
|
||||
elif api_name.lower() == "groq":
|
||||
api_key = groq_api_key
|
||||
try:
|
||||
logging.debug("Main: Trying to summarize with Groq")
|
||||
api_key = groq_api_key
|
||||
logging.debug(f"Groq: Groq API Key: {api_key}")
|
||||
summary = summarize_with_groq(api_key, json_file_path, groq_model)
|
||||
logging.debug(f"MAIN: Trying to summarize with Groq")
|
||||
summary = summarize_with_groq(api_key, json_file_path, groq_model, custom_prompt)
|
||||
except requests.exceptions.ConnectionError:
|
||||
r.status_code = "Connection: "
|
||||
elif api_name.lower() == 'llama':
|
||||
requests.status_code = "Connection: "
|
||||
elif api_name.lower() == "llama":
|
||||
token = llama_api_key
|
||||
llama_ip = llama_api_IP
|
||||
try:
|
||||
logging.debug("Main: Trying to summarize with Llama.cpp")
|
||||
token = llama_api_key
|
||||
logging.debug(f"Llama.cpp: Llama.cpp API Key: {api_key}")
|
||||
llama_ip = llama_api_IP
|
||||
logging.debug(f"Llama.cpp: Llama.cpp API IP:Port : {llama_ip}")
|
||||
summary = summarize_with_llama(llama_ip, json_file_path, token)
|
||||
logging.debug(f"MAIN: Trying to summarize with Llama.cpp")
|
||||
summary = summarize_with_llama(llama_ip, json_file_path, token, custom_prompt)
|
||||
except requests.exceptions.ConnectionError:
|
||||
r.status_code = "Connection: "
|
||||
elif api_name.lower() == 'kobold':
|
||||
requests.status_code = "Connection: "
|
||||
elif api_name.lower() == "kobold":
|
||||
token = kobold_api_key
|
||||
kobold_ip = kobold_api_IP
|
||||
try:
|
||||
logging.debug("Main: Trying to summarize with kobold.cpp")
|
||||
token = kobold_api_key
|
||||
logging.debug(f"kobold.cpp: Kobold.cpp API Key: {api_key}")
|
||||
kobold_ip = kobold_api_IP
|
||||
logging.debug(f"kobold.cpp: Kobold.cpp API IP:Port : {kobold_api_IP}")
|
||||
summary = summarize_with_kobold(kobold_ip, json_file_path)
|
||||
logging.debug(f"MAIN: Trying to summarize with kobold.cpp")
|
||||
summary = summarize_with_kobold(kobold_ip, json_file_path, custom_prompt)
|
||||
except requests.exceptions.ConnectionError:
|
||||
r.status_code = "Connection: "
|
||||
elif api_name.lower() == 'ooba':
|
||||
requests.status_code = "Connection: "
|
||||
elif api_name.lower() == "ooba":
|
||||
token = ooba_api_key
|
||||
ooba_ip = ooba_api_IP
|
||||
try:
|
||||
logging.debug("Main: Trying to summarize with oobabooga")
|
||||
token = ooba_api_key
|
||||
logging.debug(f"oobabooga: ooba API Key: {api_key}")
|
||||
ooba_ip = ooba_api_IP
|
||||
logging.debug(f"oobabooga: ooba API IP:Port : {ooba_ip}")
|
||||
summary = summarize_with_oobabooga(ooba_ip, json_file_path)
|
||||
logging.debug(f"MAIN: Trying to summarize with oobabooga")
|
||||
summary = summarize_with_oobabooga(ooba_ip, json_file_path, custom_prompt)
|
||||
except requests.exceptions.ConnectionError:
|
||||
r.status_code = "Connection: "
|
||||
if api_name.lower() == 'huggingface':
|
||||
requests.status_code = "Connection: "
|
||||
elif api_name.lower() == "huggingface":
|
||||
api_key = huggingface_api_key
|
||||
try:
|
||||
logging.debug("MAIN: Trying to summarize with huggingface")
|
||||
api_key = huggingface_api_key
|
||||
logging.debug(f"huggingface: huggingface API Key: {api_key}")
|
||||
summarize_with_huggingface(api_key, json_file_path)
|
||||
logging.debug(f"MAIN: Trying to summarize with huggingface")
|
||||
summarize_with_huggingface(api_key, json_file_path, custom_prompt)
|
||||
except requests.exceptions.ConnectionError:
|
||||
r.status_code = "Connection: "
|
||||
requests.status_code = "Connection: "
|
||||
|
||||
else:
|
||||
logging.warning(f"Unsupported API: {api_name}")
|
||||
|
||||
517
summarize.py
517
summarize.py
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user