mirror of
https://github.com/jlengrand/tldw.git
synced 2026-03-10 08:51:17 +00:00
Breaking app.py but fixing summarizer.py ....
This commit is contained in:
426
HF/app.py
426
HF/app.py
@@ -1,14 +1,22 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import gradio as gr
|
import argparse
|
||||||
import argparse, configparser, datetime, json, logging, os, platform, requests, shutil, subprocess, sys, time, unicodedata
|
import configparser
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import requests
|
||||||
|
import shutil
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import unicodedata
|
||||||
import zipfile
|
import zipfile
|
||||||
from datetime import datetime
|
|
||||||
import contextlib
|
import gradio as gr
|
||||||
import ffmpeg
|
|
||||||
import torch
|
import torch
|
||||||
import yt_dlp
|
import yt_dlp
|
||||||
|
|
||||||
|
|
||||||
#######
|
#######
|
||||||
# Function Sections
|
# Function Sections
|
||||||
#
|
#
|
||||||
@@ -349,7 +357,7 @@ def process_local_file(file_path):
|
|||||||
# Video Download/Handling
|
# Video Download/Handling
|
||||||
#
|
#
|
||||||
|
|
||||||
def process_url(input_path, num_speakers=2, whisper_model="small.en", offset=0, api_name=None, api_key=None, vad_filter=False, download_video_flag=False,custom_prompt=None, demo_mode=True):
|
def process_url(input_path, num_speakers=2, whisper_model="small.en", custom_prompt=None, offset=0, api_name=None, api_key=None, vad_filter=False, download_video_flag=False, demo_mode=False):
|
||||||
if demo_mode:
|
if demo_mode:
|
||||||
api_name = "huggingface"
|
api_name = "huggingface"
|
||||||
api_key = os.environ.get(HF_TOKEN)
|
api_key = os.environ.get(HF_TOKEN)
|
||||||
@@ -646,115 +654,115 @@ def speech_to_text(audio_file_path, selected_source_lang='en', whisper_model='sm
|
|||||||
# TODO: https://huggingface.co/pyannote/speaker-diarization-3.1
|
# TODO: https://huggingface.co/pyannote/speaker-diarization-3.1
|
||||||
# embedding_model = "pyannote/embedding", embedding_size=512
|
# embedding_model = "pyannote/embedding", embedding_size=512
|
||||||
# embedding_model = "speechbrain/spkrec-ecapa-voxceleb", embedding_size=192
|
# embedding_model = "speechbrain/spkrec-ecapa-voxceleb", embedding_size=192
|
||||||
def speaker_diarize(video_file_path, segments, embedding_model = "pyannote/embedding", embedding_size=512, num_speakers=0):
|
# def speaker_diarize(video_file_path, segments, embedding_model = "pyannote/embedding", embedding_size=512, num_speakers=0):
|
||||||
"""
|
# """
|
||||||
1. Generating speaker embeddings for each segments.
|
# 1. Generating speaker embeddings for each segments.
|
||||||
2. Applying agglomerative clustering on the embeddings to identify the speaker for each segment.
|
# 2. Applying agglomerative clustering on the embeddings to identify the speaker for each segment.
|
||||||
"""
|
# """
|
||||||
try:
|
# try:
|
||||||
from pyannote.audio import Audio
|
# from pyannote.audio import Audio
|
||||||
from pyannote.core import Segment
|
# from pyannote.core import Segment
|
||||||
from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
|
# from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding
|
||||||
import numpy as np
|
# import numpy as np
|
||||||
import pandas as pd
|
# import pandas as pd
|
||||||
from sklearn.cluster import AgglomerativeClustering
|
# from sklearn.cluster import AgglomerativeClustering
|
||||||
from sklearn.metrics import silhouette_score
|
# from sklearn.metrics import silhouette_score
|
||||||
import tqdm
|
# import tqdm
|
||||||
import wave
|
# import wave
|
||||||
|
#
|
||||||
embedding_model = PretrainedSpeakerEmbedding( embedding_model, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
# embedding_model = PretrainedSpeakerEmbedding( embedding_model, device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
|
||||||
|
#
|
||||||
|
#
|
||||||
_,file_ending = os.path.splitext(f'{video_file_path}')
|
# _,file_ending = os.path.splitext(f'{video_file_path}')
|
||||||
audio_file = video_file_path.replace(file_ending, ".wav")
|
# audio_file = video_file_path.replace(file_ending, ".wav")
|
||||||
out_file = video_file_path.replace(file_ending, ".diarize.json")
|
# out_file = video_file_path.replace(file_ending, ".diarize.json")
|
||||||
|
#
|
||||||
logging.debug("getting duration of audio file")
|
# logging.debug("getting duration of audio file")
|
||||||
with contextlib.closing(wave.open(audio_file,'r')) as f:
|
# with contextlib.closing(wave.open(audio_file,'r')) as f:
|
||||||
frames = f.getnframes()
|
# frames = f.getnframes()
|
||||||
rate = f.getframerate()
|
# rate = f.getframerate()
|
||||||
duration = frames / float(rate)
|
# duration = frames / float(rate)
|
||||||
logging.debug("duration of audio file obtained")
|
# logging.debug("duration of audio file obtained")
|
||||||
print(f"duration of audio file: {duration}")
|
# print(f"duration of audio file: {duration}")
|
||||||
|
#
|
||||||
def segment_embedding(segment):
|
# def segment_embedding(segment):
|
||||||
logging.debug("Creating embedding")
|
# logging.debug("Creating embedding")
|
||||||
audio = Audio()
|
# audio = Audio()
|
||||||
start = segment["start"]
|
# start = segment["start"]
|
||||||
end = segment["end"]
|
# end = segment["end"]
|
||||||
|
#
|
||||||
# Enforcing a minimum segment length
|
# # Enforcing a minimum segment length
|
||||||
if end-start < 0.3:
|
# if end-start < 0.3:
|
||||||
padding = 0.3-(end-start)
|
# padding = 0.3-(end-start)
|
||||||
start -= padding/2
|
# start -= padding/2
|
||||||
end += padding/2
|
# end += padding/2
|
||||||
print('Padded segment because it was too short:',segment)
|
# print('Padded segment because it was too short:',segment)
|
||||||
|
#
|
||||||
# Whisper overshoots the end timestamp in the last segment
|
# # Whisper overshoots the end timestamp in the last segment
|
||||||
end = min(duration, end)
|
# end = min(duration, end)
|
||||||
# clip audio and embed
|
# # clip audio and embed
|
||||||
clip = Segment(start, end)
|
# clip = Segment(start, end)
|
||||||
waveform, sample_rate = audio.crop(audio_file, clip)
|
# waveform, sample_rate = audio.crop(audio_file, clip)
|
||||||
return embedding_model(waveform[None])
|
# return embedding_model(waveform[None])
|
||||||
|
#
|
||||||
embeddings = np.zeros(shape=(len(segments), embedding_size))
|
# embeddings = np.zeros(shape=(len(segments), embedding_size))
|
||||||
for i, segment in enumerate(tqdm.tqdm(segments)):
|
# for i, segment in enumerate(tqdm.tqdm(segments)):
|
||||||
embeddings[i] = segment_embedding(segment)
|
# embeddings[i] = segment_embedding(segment)
|
||||||
embeddings = np.nan_to_num(embeddings)
|
# embeddings = np.nan_to_num(embeddings)
|
||||||
print(f'Embedding shape: {embeddings.shape}')
|
# print(f'Embedding shape: {embeddings.shape}')
|
||||||
|
#
|
||||||
if num_speakers == 0:
|
# if num_speakers == 0:
|
||||||
# Find the best number of speakers
|
# # Find the best number of speakers
|
||||||
score_num_speakers = {}
|
# score_num_speakers = {}
|
||||||
|
#
|
||||||
for num_speakers in range(2, 10+1):
|
# for num_speakers in range(2, 10+1):
|
||||||
clustering = AgglomerativeClustering(num_speakers).fit(embeddings)
|
# clustering = AgglomerativeClustering(num_speakers).fit(embeddings)
|
||||||
score = silhouette_score(embeddings, clustering.labels_, metric='euclidean')
|
# score = silhouette_score(embeddings, clustering.labels_, metric='euclidean')
|
||||||
score_num_speakers[num_speakers] = score
|
# score_num_speakers[num_speakers] = score
|
||||||
best_num_speaker = max(score_num_speakers, key=lambda x:score_num_speakers[x])
|
# best_num_speaker = max(score_num_speakers, key=lambda x:score_num_speakers[x])
|
||||||
print(f"The best number of speakers: {best_num_speaker} with {score_num_speakers[best_num_speaker]} score")
|
# print(f"The best number of speakers: {best_num_speaker} with {score_num_speakers[best_num_speaker]} score")
|
||||||
else:
|
# else:
|
||||||
best_num_speaker = num_speakers
|
# best_num_speaker = num_speakers
|
||||||
|
#
|
||||||
# Assign speaker label
|
# # Assign speaker label
|
||||||
clustering = AgglomerativeClustering(best_num_speaker).fit(embeddings)
|
# clustering = AgglomerativeClustering(best_num_speaker).fit(embeddings)
|
||||||
labels = clustering.labels_
|
# labels = clustering.labels_
|
||||||
for i in range(len(segments)):
|
# for i in range(len(segments)):
|
||||||
segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)
|
# segments[i]["speaker"] = 'SPEAKER ' + str(labels[i] + 1)
|
||||||
|
#
|
||||||
with open(out_file,'w') as f:
|
# with open(out_file,'w') as f:
|
||||||
f.write(json.dumps(segments, indent=2))
|
# f.write(json.dumps(segments, indent=2))
|
||||||
|
#
|
||||||
# Make CSV output
|
# # Make CSV output
|
||||||
def convert_time(secs):
|
# def convert_time(secs):
|
||||||
return datetime.timedelta(seconds=round(secs))
|
# return datetime.timedelta(seconds=round(secs))
|
||||||
|
#
|
||||||
objects = {
|
# objects = {
|
||||||
'Start' : [],
|
# 'Start' : [],
|
||||||
'End': [],
|
# 'End': [],
|
||||||
'Speaker': [],
|
# 'Speaker': [],
|
||||||
'Text': []
|
# 'Text': []
|
||||||
}
|
# }
|
||||||
text = ''
|
# text = ''
|
||||||
for (i, segment) in enumerate(segments):
|
# for (i, segment) in enumerate(segments):
|
||||||
if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
|
# if i == 0 or segments[i - 1]["speaker"] != segment["speaker"]:
|
||||||
objects['Start'].append(str(convert_time(segment["start"])))
|
# objects['Start'].append(str(convert_time(segment["start"])))
|
||||||
objects['Speaker'].append(segment["speaker"])
|
# objects['Speaker'].append(segment["speaker"])
|
||||||
if i != 0:
|
# if i != 0:
|
||||||
objects['End'].append(str(convert_time(segments[i - 1]["end"])))
|
# objects['End'].append(str(convert_time(segments[i - 1]["end"])))
|
||||||
objects['Text'].append(text)
|
# objects['Text'].append(text)
|
||||||
text = ''
|
# text = ''
|
||||||
text += segment["text"] + ' '
|
# text += segment["text"] + ' '
|
||||||
objects['End'].append(str(convert_time(segments[i - 1]["end"])))
|
# objects['End'].append(str(convert_time(segments[i - 1]["end"])))
|
||||||
objects['Text'].append(text)
|
# objects['Text'].append(text)
|
||||||
|
#
|
||||||
save_path = video_file_path.replace(file_ending, ".csv")
|
# save_path = video_file_path.replace(file_ending, ".csv")
|
||||||
df_results = pd.DataFrame(objects)
|
# df_results = pd.DataFrame(objects)
|
||||||
df_results.to_csv(save_path)
|
# df_results.to_csv(save_path)
|
||||||
return df_results, save_path
|
# return df_results, save_path
|
||||||
|
#
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
raise RuntimeError("Error Running inference with local model", e)
|
# raise RuntimeError("Error Running inference with local model", e)
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
####################################################################################################################################
|
####################################################################################################################################
|
||||||
@@ -777,12 +785,12 @@ def extract_text_from_segments(segments):
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
def summarize_with_openai(api_key, file_path, model):
|
def summarize_with_openai(api_key, file_path, model, custom_prompt):
|
||||||
try:
|
try:
|
||||||
logging.debug("openai: Loading json data for summarization")
|
logging.debug("openai: Loading json data for summarization")
|
||||||
with open(file_path, 'r') as file:
|
with open(file_path, 'r') as file:
|
||||||
segments = json.load(file)
|
segments = json.load(file)
|
||||||
|
|
||||||
logging.debug("openai: Extracting text from the segments")
|
logging.debug("openai: Extracting text from the segments")
|
||||||
text = extract_text_from_segments(segments)
|
text = extract_text_from_segments(segments)
|
||||||
|
|
||||||
@@ -790,9 +798,14 @@ def summarize_with_openai(api_key, file_path, model):
|
|||||||
'Authorization': f'Bearer {api_key}',
|
'Authorization': f'Bearer {api_key}',
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
}
|
}
|
||||||
|
# headers = {
|
||||||
|
# 'Authorization': f'Bearer {api_key}',
|
||||||
|
# 'Content-Type': 'application/json'
|
||||||
|
# }
|
||||||
|
|
||||||
|
logging.debug(f"openai: API Key is: {api_key}")
|
||||||
logging.debug("openai: Preparing data + prompt for submittal")
|
logging.debug("openai: Preparing data + prompt for submittal")
|
||||||
openai_prompt = f"{text} \n\n\n\n{prompt_text}"
|
openai_prompt = f"{text} \n\n\n\n{custom_prompt}"
|
||||||
data = {
|
data = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": [
|
"messages": [
|
||||||
@@ -810,7 +823,7 @@ def summarize_with_openai(api_key, file_path, model):
|
|||||||
}
|
}
|
||||||
logging.debug("openai: Posting request")
|
logging.debug("openai: Posting request")
|
||||||
response = requests.post('https://api.openai.com/v1/chat/completions', headers=headers, json=data)
|
response = requests.post('https://api.openai.com/v1/chat/completions', headers=headers, json=data)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
summary = response.json()['choices'][0]['message']['content'].strip()
|
summary = response.json()['choices'][0]['message']['content'].strip()
|
||||||
logging.debug("openai: Summarization successful")
|
logging.debug("openai: Summarization successful")
|
||||||
@@ -826,13 +839,12 @@ def summarize_with_openai(api_key, file_path, model):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_with_claude(api_key, file_path, model, custom_prompt):
|
||||||
def summarize_with_claude(api_key, file_path, model):
|
|
||||||
try:
|
try:
|
||||||
logging.debug("anthropic: Loading JSON data")
|
logging.debug("anthropic: Loading JSON data")
|
||||||
with open(file_path, 'r') as file:
|
with open(file_path, 'r') as file:
|
||||||
segments = json.load(file)
|
segments = json.load(file)
|
||||||
|
|
||||||
logging.debug("anthropic: Extracting text from the segments file")
|
logging.debug("anthropic: Extracting text from the segments file")
|
||||||
text = extract_text_from_segments(segments)
|
text = extract_text_from_segments(segments)
|
||||||
|
|
||||||
@@ -841,16 +853,17 @@ def summarize_with_claude(api_key, file_path, model):
|
|||||||
'anthropic-version': '2023-06-01',
|
'anthropic-version': '2023-06-01',
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.debug("anthropic: Prepping data + prompt for submittal")
|
anthropic_prompt = custom_prompt
|
||||||
|
logging.debug("anthropic: Prompt is {anthropic_prompt}")
|
||||||
user_message = {
|
user_message = {
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"{text} \n\n\n\n{prompt_text}"
|
"content": f"{text} \n\n\n\n{anthropic_prompt}"
|
||||||
}
|
}
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"max_tokens": 4096, # max _possible_ tokens to return
|
"max_tokens": 4096, # max _possible_ tokens to return
|
||||||
"messages": [user_message],
|
"messages": [user_message],
|
||||||
"stop_sequences": ["\n\nHuman:"],
|
"stop_sequences": ["\n\nHuman:"],
|
||||||
"temperature": 0.7,
|
"temperature": 0.7,
|
||||||
@@ -862,17 +875,17 @@ def summarize_with_claude(api_key, file_path, model):
|
|||||||
"stream": False,
|
"stream": False,
|
||||||
"system": "You are a professional summarizer."
|
"system": "You are a professional summarizer."
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.debug("anthropic: Posting request to API")
|
logging.debug("anthropic: Posting request to API")
|
||||||
response = requests.post('https://api.anthropic.com/v1/messages', headers=headers, json=data)
|
response = requests.post('https://api.anthropic.com/v1/messages', headers=headers, json=data)
|
||||||
|
|
||||||
# Check if the status code indicates success
|
# Check if the status code indicates success
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
logging.debug("anthropic: Post submittal successful")
|
logging.debug("anthropic: Post submittal successful")
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
try:
|
try:
|
||||||
summary = response_data['content'][0]['text'].strip()
|
summary = response_data['content'][0]['text'].strip()
|
||||||
logging.debug("anthropic: Summarization succesful")
|
logging.debug("anthropic: Summarization successful")
|
||||||
print("Summary processed successfully.")
|
print("Summary processed successfully.")
|
||||||
return summary
|
return summary
|
||||||
except (IndexError, KeyError) as e:
|
except (IndexError, KeyError) as e:
|
||||||
@@ -894,9 +907,8 @@ def summarize_with_claude(api_key, file_path, model):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Summarize with Cohere
|
# Summarize with Cohere
|
||||||
def summarize_with_cohere(api_key, file_path, model):
|
def summarize_with_cohere(api_key, file_path, model, custom_prompt):
|
||||||
try:
|
try:
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
logging.debug("cohere: Loading JSON data")
|
logging.debug("cohere: Loading JSON data")
|
||||||
@@ -912,7 +924,9 @@ def summarize_with_cohere(api_key, file_path, model):
|
|||||||
'Authorization': f'Bearer {api_key}'
|
'Authorization': f'Bearer {api_key}'
|
||||||
}
|
}
|
||||||
|
|
||||||
cohere_prompt = f"{text} \n\n\n\n{prompt_text}"
|
cohere_prompt = f"{text} \n\n\n\n{custom_prompt}"
|
||||||
|
logging.debug("cohere: Prompt being sent is {cohere_prompt}")
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"chat_history": [
|
"chat_history": [
|
||||||
{"role": "USER", "message": cohere_prompt}
|
{"role": "USER", "message": cohere_prompt}
|
||||||
@@ -938,7 +952,7 @@ def summarize_with_cohere(api_key, file_path, model):
|
|||||||
logging.error("Expected data not found in API response.")
|
logging.error("Expected data not found in API response.")
|
||||||
return "Expected data not found in API response."
|
return "Expected data not found in API response."
|
||||||
else:
|
else:
|
||||||
logging.error(f"cohere: API request failed with status code {response.status_code}: {resposne.text}")
|
logging.error(f"cohere: API request failed with status code {response.status_code}: {response.text}")
|
||||||
print(f"Failed to process summary, status code {response.status_code}: {response.text}")
|
print(f"Failed to process summary, status code {response.status_code}: {response.text}")
|
||||||
return f"cohere: API request failed: {response.text}"
|
return f"cohere: API request failed: {response.text}"
|
||||||
|
|
||||||
@@ -947,9 +961,8 @@ def summarize_with_cohere(api_key, file_path, model):
|
|||||||
return f"cohere: Error occurred while processing summary with Cohere: {str(e)}"
|
return f"cohere: Error occurred while processing summary with Cohere: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# https://console.groq.com/docs/quickstart
|
# https://console.groq.com/docs/quickstart
|
||||||
def summarize_with_groq(api_key, file_path, model):
|
def summarize_with_groq(api_key, file_path, model, custom_prompt):
|
||||||
try:
|
try:
|
||||||
logging.debug("groq: Loading JSON data")
|
logging.debug("groq: Loading JSON data")
|
||||||
with open(file_path, 'r') as file:
|
with open(file_path, 'r') as file:
|
||||||
@@ -963,7 +976,9 @@ def summarize_with_groq(api_key, file_path, model):
|
|||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
}
|
}
|
||||||
|
|
||||||
groq_prompt = f"{text} \n\n\n\n{prompt_text}"
|
groq_prompt = f"{text} \n\n\n\n{custom_prompt}"
|
||||||
|
logging.debug("groq: Prompt being sent is {groq_prompt}")
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
@@ -1003,7 +1018,7 @@ def summarize_with_groq(api_key, file_path, model):
|
|||||||
#
|
#
|
||||||
# Local Summarization
|
# Local Summarization
|
||||||
|
|
||||||
def summarize_with_llama(api_url, file_path, token):
|
def summarize_with_llama(api_url, file_path, token, custom_prompt):
|
||||||
try:
|
try:
|
||||||
logging.debug("llama: Loading JSON data")
|
logging.debug("llama: Loading JSON data")
|
||||||
with open(file_path, 'r') as file:
|
with open(file_path, 'r') as file:
|
||||||
@@ -1016,17 +1031,17 @@ def summarize_with_llama(api_url, file_path, token):
|
|||||||
'accept': 'application/json',
|
'accept': 'application/json',
|
||||||
'content-type': 'application/json',
|
'content-type': 'application/json',
|
||||||
}
|
}
|
||||||
if len(token)>5:
|
if len(token) > 5:
|
||||||
headers['Authorization'] = f'Bearer {token}'
|
headers['Authorization'] = f'Bearer {token}'
|
||||||
|
|
||||||
|
llama_prompt = f"{text} \n\n\n\n{custom_prompt}"
|
||||||
|
logging.debug("llama: Prompt being sent is {llama_prompt}")
|
||||||
|
|
||||||
llama_prompt = f"{text} \n\n\n\n{prompt_text}"
|
|
||||||
logging.debug(f"llama: Complete prompt is: {llama_prompt}")
|
|
||||||
data = {
|
data = {
|
||||||
"prompt": llama_prompt
|
"prompt": llama_prompt
|
||||||
}
|
}
|
||||||
|
|
||||||
#logging.debug(f"llama: Submitting request to API endpoint {llama_prompt}")
|
logging.debug("llama: Submitting request to API endpoint")
|
||||||
print("llama: Submitting request to API endpoint")
|
print("llama: Submitting request to API endpoint")
|
||||||
response = requests.post(api_url, headers=headers, json=data)
|
response = requests.post(api_url, headers=headers, json=data)
|
||||||
response_data = response.json()
|
response_data = response.json()
|
||||||
@@ -1048,9 +1063,8 @@ def summarize_with_llama(api_url, file_path, token):
|
|||||||
return f"llama: Error occurred while processing summary with llama: {str(e)}"
|
return f"llama: Error occurred while processing summary with llama: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# https://lite.koboldai.net/koboldcpp_api#/api%2Fv1/post_api_v1_generate
|
# https://lite.koboldai.net/koboldcpp_api#/api%2Fv1/post_api_v1_generate
|
||||||
def summarize_with_kobold(api_url, file_path):
|
def summarize_with_kobold(api_url, file_path, custom_prompt):
|
||||||
try:
|
try:
|
||||||
logging.debug("kobold: Loading JSON data")
|
logging.debug("kobold: Loading JSON data")
|
||||||
with open(file_path, 'r') as file:
|
with open(file_path, 'r') as file:
|
||||||
@@ -1063,9 +1077,11 @@ def summarize_with_kobold(api_url, file_path):
|
|||||||
'accept': 'application/json',
|
'accept': 'application/json',
|
||||||
'content-type': 'application/json',
|
'content-type': 'application/json',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kobold_prompt = f"{text} \n\n\n\n{custom_prompt}"
|
||||||
|
logging.debug("kobold: Prompt being sent is {kobold_prompt}")
|
||||||
|
|
||||||
# FIXME
|
# FIXME
|
||||||
kobold_prompt = f"{text} \n\n\n\n{prompt_text}"
|
|
||||||
logging.debug(kobold_prompt)
|
|
||||||
# Values literally c/p from the api docs....
|
# Values literally c/p from the api docs....
|
||||||
data = {
|
data = {
|
||||||
"max_context_length": 8096,
|
"max_context_length": 8096,
|
||||||
@@ -1097,9 +1113,8 @@ def summarize_with_kobold(api_url, file_path):
|
|||||||
return f"kobold: Error occurred while processing summary with kobold: {str(e)}"
|
return f"kobold: Error occurred while processing summary with kobold: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API
|
# https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API
|
||||||
def summarize_with_oobabooga(api_url, file_path):
|
def summarize_with_oobabooga(api_url, file_path, custom_prompt):
|
||||||
try:
|
try:
|
||||||
logging.debug("ooba: Loading JSON data")
|
logging.debug("ooba: Loading JSON data")
|
||||||
with open(file_path, 'r') as file:
|
with open(file_path, 'r') as file:
|
||||||
@@ -1114,14 +1129,15 @@ def summarize_with_oobabooga(api_url, file_path):
|
|||||||
'content-type': 'application/json',
|
'content-type': 'application/json',
|
||||||
}
|
}
|
||||||
|
|
||||||
#prompt_text = "I like to eat cake and bake cakes. I am a baker. I work in a french bakery baking cakes. It is a fun job. I have been baking cakes for ten years. I also bake lots of other baked goods, but cakes are my favorite."
|
# prompt_text = "I like to eat cake and bake cakes. I am a baker. I work in a French bakery baking cakes. It is a fun job. I have been baking cakes for ten years. I also bake lots of other baked goods, but cakes are my favorite."
|
||||||
#prompt_text += f"\n\n{text}" # Uncomment this line if you want to include the text variable
|
# prompt_text += f"\n\n{text}" # Uncomment this line if you want to include the text variable
|
||||||
ooba_prompt = f"{text}\n\n\n\n{prompt_text}"
|
ooba_prompt = "{text}\n\n\n\n{custom_prompt}"
|
||||||
|
logging.debug("ooba: Prompt being sent is {ooba_prompt}")
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"mode": "chat",
|
"mode": "chat",
|
||||||
"character": "Example",
|
"character": "Example",
|
||||||
"messages": [{"role": "user", "content": prompt_text}]
|
"messages": [{"role": "user", "content": ooba_prompt}]
|
||||||
}
|
}
|
||||||
|
|
||||||
logging.debug("ooba: Submitting request to API endpoint")
|
logging.debug("ooba: Submitting request to API endpoint")
|
||||||
@@ -1350,7 +1366,7 @@ def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model=
|
|||||||
logging.debug("MAIN: Video downloaded successfully")
|
logging.debug("MAIN: Video downloaded successfully")
|
||||||
logging.debug("MAIN: Converting video file to WAV...")
|
logging.debug("MAIN: Converting video file to WAV...")
|
||||||
audio_file = convert_to_wav(video_path, offset)
|
audio_file = convert_to_wav(video_path, offset)
|
||||||
logging.debug("MAIN: Audio file converted succesfully")
|
logging.debug("MAIN: Audio file converted successfully")
|
||||||
else:
|
else:
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
logging.debug("MAIN: Local file path detected")
|
logging.debug("MAIN: Local file path detected")
|
||||||
@@ -1370,85 +1386,69 @@ def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model=
|
|||||||
results.append(transcription_result)
|
results.append(transcription_result)
|
||||||
logging.info(f"Transcription complete: {audio_file}")
|
logging.info(f"Transcription complete: {audio_file}")
|
||||||
|
|
||||||
if path.startswith('http'):
|
|
||||||
# Delete the downloaded video file
|
|
||||||
os.remove(video_path)
|
|
||||||
logging.info(f"Deleted downloaded video file: {video_path}")
|
|
||||||
|
|
||||||
# Perform summarization based on the specified API
|
# Perform summarization based on the specified API
|
||||||
if api_name and api_key:
|
if api_name and api_key:
|
||||||
logging.debug(f"MAIN: Summarization being performed by {api_name}")
|
logging.debug(f"MAIN: Summarization being performed by {api_name}")
|
||||||
json_file_path = audio_file.replace('.wav', '.segments.json')
|
json_file_path = audio_file.replace('.wav', '.segments.json')
|
||||||
if api_name.lower() == 'openai':
|
if api_name.lower() == 'openai':
|
||||||
|
api_key = openai_api_key
|
||||||
try:
|
try:
|
||||||
logging.debug(f"MAIN: trying to summarize with openAI")
|
logging.debug(f"MAIN: trying to summarize with openAI")
|
||||||
api_key = openai_api_key
|
summary = summarize_with_openai(api_key, json_file_path, openai_model, custom_prompt)
|
||||||
logging.debug(f"OpenAI: OpenAI API Key: {api_key}")
|
|
||||||
summary = summarize_with_openai(api_key, json_file_path, openai_model)
|
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
r.status_code = "Connection: "
|
requests.status_code = "Connection: "
|
||||||
elif api_name.lower() == 'anthropic':
|
elif api_name.lower() == "anthropic":
|
||||||
|
api_key = anthropic_api_key
|
||||||
try:
|
try:
|
||||||
logging.debug("MAIN: Trying to summarize with anthropic")
|
logging.debug(f"MAIN: Trying to summarize with anthropic")
|
||||||
api_key = anthropic_api_key
|
summary = summarize_with_claude(api_key, json_file_path, anthropic_model, custom_prompt)
|
||||||
logging.debug(f"Anthropic: Anthropic API Key: {api_key}")
|
|
||||||
summary = summarize_with_claude(api_key, json_file_path, anthropic_model)
|
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
r.status_code = "Connection: "
|
requests.status_code = "Connection: "
|
||||||
elif api_name.lower() == 'cohere':
|
elif api_name.lower() == "cohere":
|
||||||
|
api_key = cohere_api_key
|
||||||
try:
|
try:
|
||||||
logging.debug("Main: Trying to summarize with cohere")
|
logging.debug(f"MAIN: Trying to summarize with cohere")
|
||||||
api_key = cohere_api_key
|
summary = summarize_with_cohere(api_key, json_file_path, cohere_model, custom_prompt)
|
||||||
logging.debug(f"Cohere: Cohere API Key: {api_key}")
|
|
||||||
summary = summarize_with_cohere(api_key, json_file_path, cohere_model)
|
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
r.status_code = "Connection: "
|
requests.status_code = "Connection: "
|
||||||
elif api_name.lower() == 'groq':
|
elif api_name.lower() == "groq":
|
||||||
|
api_key = groq_api_key
|
||||||
try:
|
try:
|
||||||
logging.debug("Main: Trying to summarize with Groq")
|
logging.debug(f"MAIN: Trying to summarize with Groq")
|
||||||
api_key = groq_api_key
|
summary = summarize_with_groq(api_key, json_file_path, groq_model, custom_prompt)
|
||||||
logging.debug(f"Groq: Groq API Key: {api_key}")
|
|
||||||
summary = summarize_with_groq(api_key, json_file_path, groq_model)
|
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
r.status_code = "Connection: "
|
requests.status_code = "Connection: "
|
||||||
elif api_name.lower() == 'llama':
|
elif api_name.lower() == "llama":
|
||||||
|
token = llama_api_key
|
||||||
|
llama_ip = llama_api_IP
|
||||||
try:
|
try:
|
||||||
logging.debug("Main: Trying to summarize with Llama.cpp")
|
logging.debug(f"MAIN: Trying to summarize with Llama.cpp")
|
||||||
token = llama_api_key
|
summary = summarize_with_llama(llama_ip, json_file_path, token, custom_prompt)
|
||||||
logging.debug(f"Llama.cpp: Llama.cpp API Key: {api_key}")
|
|
||||||
llama_ip = llama_api_IP
|
|
||||||
logging.debug(f"Llama.cpp: Llama.cpp API IP:Port : {llama_ip}")
|
|
||||||
summary = summarize_with_llama(llama_ip, json_file_path, token)
|
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
r.status_code = "Connection: "
|
requests.status_code = "Connection: "
|
||||||
elif api_name.lower() == 'kobold':
|
elif api_name.lower() == "kobold":
|
||||||
|
token = kobold_api_key
|
||||||
|
kobold_ip = kobold_api_IP
|
||||||
try:
|
try:
|
||||||
logging.debug("Main: Trying to summarize with kobold.cpp")
|
logging.debug(f"MAIN: Trying to summarize with kobold.cpp")
|
||||||
token = kobold_api_key
|
summary = summarize_with_kobold(kobold_ip, json_file_path, custom_prompt)
|
||||||
logging.debug(f"kobold.cpp: Kobold.cpp API Key: {api_key}")
|
|
||||||
kobold_ip = kobold_api_IP
|
|
||||||
logging.debug(f"kobold.cpp: Kobold.cpp API IP:Port : {kobold_api_IP}")
|
|
||||||
summary = summarize_with_kobold(kobold_ip, json_file_path)
|
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
r.status_code = "Connection: "
|
requests.status_code = "Connection: "
|
||||||
elif api_name.lower() == 'ooba':
|
elif api_name.lower() == "ooba":
|
||||||
|
token = ooba_api_key
|
||||||
|
ooba_ip = ooba_api_IP
|
||||||
try:
|
try:
|
||||||
logging.debug("Main: Trying to summarize with oobabooga")
|
logging.debug(f"MAIN: Trying to summarize with oobabooga")
|
||||||
token = ooba_api_key
|
summary = summarize_with_oobabooga(ooba_ip, json_file_path, custom_prompt)
|
||||||
logging.debug(f"oobabooga: ooba API Key: {api_key}")
|
|
||||||
ooba_ip = ooba_api_IP
|
|
||||||
logging.debug(f"oobabooga: ooba API IP:Port : {ooba_ip}")
|
|
||||||
summary = summarize_with_oobabooga(ooba_ip, json_file_path)
|
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
r.status_code = "Connection: "
|
requests.status_code = "Connection: "
|
||||||
if api_name.lower() == 'huggingface':
|
elif api_name.lower() == "huggingface":
|
||||||
|
api_key = huggingface_api_key
|
||||||
try:
|
try:
|
||||||
logging.debug("MAIN: Trying to summarize with huggingface")
|
logging.debug(f"MAIN: Trying to summarize with huggingface")
|
||||||
api_key = huggingface_api_key
|
summarize_with_huggingface(api_key, json_file_path, custom_prompt)
|
||||||
logging.debug(f"huggingface: huggingface API Key: {api_key}")
|
|
||||||
summarize_with_huggingface(api_key, json_file_path)
|
|
||||||
except requests.exceptions.ConnectionError:
|
except requests.exceptions.ConnectionError:
|
||||||
r.status_code = "Connection: "
|
requests.status_code = "Connection: "
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logging.warning(f"Unsupported API: {api_name}")
|
logging.warning(f"Unsupported API: {api_name}")
|
||||||
|
|||||||
517
summarize.py
517
summarize.py
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user