Update summarize.py

This commit is contained in:
Robert
2024-05-13 22:36:53 -07:00
parent d740d2b136
commit bb282cb6cc

View File

@@ -1383,62 +1383,62 @@ def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model=
logging.debug(f"MAIN: Summarization being performed by {api_name}")
json_file_path = audio_file.replace('.wav', '.segments.json')
if api_name.lower() == 'openai':
api_key = openai_api_key
openai_api_key = api_key if api_key else config.get('API', 'openai_api_key', fallback=None)
try:
logging.debug(f"MAIN: trying to summarize with openAI")
summary = summarize_with_openai(api_key, json_file_path, openai_model, custom_prompt)
summary = summarize_with_openai(openai_api_key, json_file_path, openai_model, custom_prompt)
except requests.exceptions.ConnectionError:
requests.status_code = "Connection: "
elif api_name.lower() == "anthropic":
api_key = anthropic_api_key
anthropic_api_key = api_key if api_key else config.get('API', 'anthropic_api_key', fallback=None)
try:
logging.debug(f"MAIN: Trying to summarize with anthropic")
summary = summarize_with_claude(api_key, json_file_path, anthropic_model, custom_prompt)
summary = summarize_with_claude(anthropic_api_key, json_file_path, anthropic_model, custom_prompt)
except requests.exceptions.ConnectionError:
requests.status_code = "Connection: "
elif api_name.lower() == "cohere":
api_key = cohere_api_key
cohere_api_key = api_key if api_key else config.get('API', 'cohere_api_key', fallback=None)
try:
logging.debug(f"MAIN: Trying to summarize with cohere")
summary = summarize_with_cohere(api_key, json_file_path, cohere_model, custom_prompt)
summary = summarize_with_cohere(cohere_api_key, json_file_path, cohere_model, custom_prompt)
except requests.exceptions.ConnectionError:
requests.status_code = "Connection: "
elif api_name.lower() == "groq":
api_key = groq_api_key
groq_api_key = api_key if api_key else config.get('API', 'groq_api_key', fallback=None)
try:
logging.debug(f"MAIN: Trying to summarize with Groq")
summary = summarize_with_groq(api_key, json_file_path, groq_model, custom_prompt)
summary = summarize_with_groq(groq_api_key, json_file_path, groq_model, custom_prompt)
except requests.exceptions.ConnectionError:
requests.status_code = "Connection: "
elif api_name.lower() == "llama":
token = llama_api_key
llama_token = api_key if api_key else config.get('API', 'llama_api_key', fallback=None)
llama_ip = llama_api_IP
try:
logging.debug(f"MAIN: Trying to summarize with Llama.cpp")
summary = summarize_with_llama(llama_ip, json_file_path, token, custom_prompt)
summary = summarize_with_llama(llama_ip, json_file_path, llama_token, custom_prompt)
except requests.exceptions.ConnectionError:
requests.status_code = "Connection: "
elif api_name.lower() == "kobold":
token = kobold_api_key
kobold_token = api_key if api_key else config.get('API', 'kobold_api_key', fallback=None)
kobold_ip = kobold_api_IP
try:
logging.debug(f"MAIN: Trying to summarize with kobold.cpp")
summary = summarize_with_kobold(kobold_ip, json_file_path, custom_prompt)
summary = summarize_with_kobold(kobold_ip, json_file_path, kobold_token, custom_prompt)
except requests.exceptions.ConnectionError:
requests.status_code = "Connection: "
elif api_name.lower() == "ooba":
token = ooba_api_key
ooba_token = api_key if api_key else config.get('API', 'ooba_api_key', fallback=None)
ooba_ip = ooba_api_IP
try:
logging.debug(f"MAIN: Trying to summarize with oobabooga")
summary = summarize_with_oobabooga(ooba_ip, json_file_path, custom_prompt)
summary = summarize_with_oobabooga(ooba_ip, json_file_path, ooba_token, custom_prompt)
except requests.exceptions.ConnectionError:
requests.status_code = "Connection: "
elif api_name.lower() == "huggingface":
api_key = huggingface_api_key
huggingface_api_key = api_key if api_key else config.get('API', 'huggingface_api_key', fallback=None)
try:
logging.debug(f"MAIN: Trying to summarize with huggingface")
summarize_with_huggingface(api_key, json_file_path, custom_prompt)
summarize_with_huggingface(huggingface_api_key, json_file_path, custom_prompt)
except requests.exceptions.ConnectionError:
requests.status_code = "Connection: "