mirror of https://github.com/jlengrand/tldw.git
synced 2026-03-10 08:51:17 +00:00
API Selection working
BIN  .gitignore  vendored
Binary file not shown.
8  .idea/.gitignore  generated  vendored  Normal file
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
6  .idea/inspectionProfiles/profiles_settings.xml  generated  Normal file
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
7  .idea/misc.xml  generated  Normal file
@@ -0,0 +1,7 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="Black">
+    <option name="sdkName" value="Python 3.12 (tldw)" />
+  </component>
+  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (tldw)" project-jdk-type="Python SDK" />
+</project>
8  .idea/modules.xml  generated  Normal file
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/tldw.iml" filepath="$PROJECT_DIR$/.idea/tldw.iml" />
+    </modules>
+  </component>
+</project>
14  .idea/tldw.iml  generated  Normal file
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<module type="PYTHON_MODULE" version="4">
+  <component name="NewModuleRootManager">
+    <content url="file://$MODULE_DIR$">
+      <excludeFolder url="file://$MODULE_DIR$/venv" />
+    </content>
+    <orderEntry type="inheritedJdk" />
+    <orderEntry type="sourceFolder" forTests="false" />
+  </component>
+  <component name="PyDocumentationSettings">
+    <option name="format" value="PLAIN" />
+    <option name="myDocStringFormat" value="Plain" />
+  </component>
+</module>
6  .idea/vcs.xml  generated  Normal file
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>
@@ -1282,10 +1282,9 @@ def launch_ui(demo_mode=False):

     iface = gr.Interface(
-        # fn=lambda url, num_speakers, whisper_model, offset, api_name, api_key: process_url(url, num_speakers, whisper_model, offset, api_name=api_name, api_key=api_key, demo_mode=demo_mode),
-        # fn=lambda url, num_speakers, whisper_model, offset, api_name, custom_prompt, api_key: process_url(url, num_speakers, whisper_model, offset, api_name, api_key, demo_mode=demo_mode),
-        fn=lambda *args: process_url(*args, demo_mode=demo_mode),
+        fn=lambda url, num_speakers, whisper_model, offset, api_name, api_key: process_url(url, num_speakers, whisper_model, offset, api_name, api_key, demo_mode=demo_mode),
+        # fn=lambda *args: process_url(*args, demo_mode=demo_mode),
         inputs=inputs,
         outputs=[
             gr.components.Textbox(label="Transcription", value=lambda: "", max_lines=10),
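For orientation, here is a minimal, self-contained sketch of the pattern the new fn= line uses: a lambda that maps the Gradio input fields positionally onto process_url while capturing demo_mode from the enclosing scope. The input widgets and the process_url stub below are illustrative stand-ins, not the real definitions from summarize.py.

import gradio as gr

def process_url(url, num_speakers, whisper_model, offset, api_name=None, api_key=None, demo_mode=False):
    # Stand-in for the real pipeline; returns a placeholder "transcription".
    return f"(demo_mode={demo_mode}) would process {url} with {whisper_model} via {api_name or 'no API'}"

def launch_ui(demo_mode=False):
    # Illustrative subset of the real input list.
    inputs = [
        gr.Textbox(label="URL"),
        gr.Number(label="Number of Speakers", value=2),
        gr.Textbox(label="Whisper Model", value="small.en"),
        gr.Number(label="Offset", value=0),
        gr.Dropdown(label="API Name", choices=["openai", "anthropic", "cohere", "groq"]),
        gr.Textbox(label="API Key"),
    ]
    iface = gr.Interface(
        # Each UI value is forwarded positionally; demo_mode is baked in via the closure.
        fn=lambda url, num_speakers, whisper_model, offset, api_name, api_key:
            process_url(url, num_speakers, whisper_model, offset, api_name, api_key, demo_mode=demo_mode),
        inputs=inputs,
        outputs=[gr.Textbox(label="Transcription", max_lines=10)],
    )
    iface.launch()

if __name__ == "__main__":
    launch_ui(demo_mode=True)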
@@ -1532,7 +1531,8 @@ if __name__ == "__main__":
         launch_ui(demo_mode=args.demo_mode)

     try:
-        results = main(args.input_path, api_name=args.api_name, api_key=args.api_key, num_speakers=args.num_speakers, whisper_model=args.whisper_model, offset=args.offset, vad_filter=args.vad_filter, download_video_flag=args.video)
+        results = main(input_path, api_name=api_name, api_key=api_key, num_speakers=num_speakers, whisper_model="small.en", offset=offset, vad_filter=vad_filter, download_video_flag=download_video_flag)
+        results = main(args.input_path, api_name=args.api_name, api_key=args.api_key, num_speakers=args.num_speakers, whisper_model="small.en", offset=args.offset, vad_filter=args.vad_filter, download_video_flag=args.video)
         logging.info('Transcription process completed.')
     except Exception as e:
         logging.error('An error occurred during the transcription process.')
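As a hedged usage sketch of the same entry point called programmatically (assuming summarize.py is importable and that a video URL is a valid input_path, per the keyword arguments visible in the hunk above; the URL and key are placeholders):

from summarize import main  # assumes summarize.py is on the import path

results = main(
    "https://www.youtube.com/watch?v=<video_id>",  # placeholder input_path
    api_name="openai",
    api_key="<your_openai_api_key>",
    num_speakers=2,
    whisper_model="small.en",
    offset=0,
    vad_filter=False,
    download_video_flag=False,
)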
@@ -1,5 +1,5 @@
 [API]
-anthropic_api_key = <anthropic_api_key
+anthropic_api_key = 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
 anthropic_model = claude-3-sonnet-20240229
 cohere_api_key = <your_cohere_api_key>
 cohere_model = command-r-plus
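These keys are presumably read back at runtime. A minimal sketch using Python's configparser; the config file name used here is an assumption, not something this diff confirms:

import configparser

config = configparser.ConfigParser()
config.read("config.txt")  # assumed file name

anthropic_api_key = config.get("API", "anthropic_api_key", fallback=None)
anthropic_model = config.get("API", "anthropic_model", fallback="claude-3-sonnet-20240229")
cohere_api_key = config.get("API", "cohere_api_key", fallback=None)
cohere_model = config.get("API", "cohere_model", fallback="command-r-plus")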
@@ -26,8 +26,10 @@ fire==0.6.0
 flatbuffers==24.3.25
 fonttools==4.51.0
 fsspec==2024.3.1
-gradio==4.29.0
-gradio_client==0.16.1
+gradio
+#gradio==4.29.0
+gradio_client
+#gradio_client==0.16.1
 h11==0.14.0
 httpcore==1.0.5
 httptools==0.6.1
@@ -89,8 +91,6 @@ timm==0.9.16
 tokenizers==0.15.2
 tomlkit==0.12.0
 toolz==0.12.1
-torch==2.2.2+cu121
-torchaudio==2.2.2+cu121
 torchvision==0.17.2
 tqdm==4.66.2
 transformers==4.39.3
@@ -101,5 +101,12 @@ ujson==5.9.0
 urllib3==2.2.1
 uvicorn==0.29.0
 watchfiles==0.21.0
-websockets==11.0.3
-yt-dlp==2024.4.9
+websockets
+#websockets==11.0.3
+yt-dlp
+#yt-dlp==2024.4.9
+--extra-index-url https://download.pytorch.org/whl/cu113
+torch
+torchaudio
+#torch==2.2.2+cu121
+#torchaudio==2.2.2+cu121
50  summarize.py
@@ -341,6 +341,7 @@ def process_local_file(file_path):
 #

 def process_url(input_path, num_speakers=2, whisper_model="small.en", custom_prompt=None, offset=0, api_name=None, api_key=None, vad_filter=False, download_video_flag=False, demo_mode=False):
+    custom_prompt = ""
     if demo_mode:
         # api_name = "<demo_mode_api>"
         # api_key = "<demo_mode_key>"
@@ -766,7 +767,7 @@ def extract_text_from_segments(segments):



-def summarize_with_openai(api_key, file_path, model):
+def summarize_with_openai(api_key, file_path, model, custom_prompt):
     try:
         logging.debug("openai: Loading json data for summarization")
         with open(file_path, 'r') as file:
@@ -779,7 +780,12 @@ def summarize_with_openai(api_key, file_path, model):
         'Authorization': f'Bearer {api_key}',
         'Content-Type': 'application/json'
     }
+    # headers = {
+    #     'Authorization': f'Bearer {api_key}',
+    #     'Content-Type': 'application/json'
+    # }

     logging.debug(f"openai: API Key is: {api_key}")
     logging.debug("openai: Preparing data + prompt for submittal")
+    openai_prompt = f"{text} \n\n\n\n{custom_prompt}"
     data = {
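The added lines thread custom_prompt into the request body. Below is a self-contained sketch of what the updated function plausibly does end to end; the OpenAI chat-completions endpoint and the segments-style JSON layout are assumptions beyond what the hunk itself shows:

import json
import logging
import requests

def summarize_with_openai(api_key, file_path, model, custom_prompt):
    try:
        logging.debug("openai: Loading json data for summarization")
        with open(file_path, 'r') as file:
            segments = json.load(file)
        text = " ".join(segment.get("text", "") for segment in segments)  # assumed segment layout

        headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json'
        }
        logging.debug("openai: Preparing data + prompt for submittal")
        openai_prompt = f"{text} \n\n\n\n{custom_prompt}"
        data = {
            "model": model,
            "messages": [{"role": "user", "content": openai_prompt}],
        }
        response = requests.post("https://api.openai.com/v1/chat/completions",
                                 headers=headers, json=data)
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except Exception as e:
        logging.error(f"openai: summarization failed: {e}")
        return None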
@@ -816,7 +822,7 @@ def summarize_with_openai(api_key, file_path, model):



-def summarize_with_claude(api_key, file_path, model):
+def summarize_with_claude(api_key, file_path, model, custom_prompt):
     try:
         logging.debug("anthropic: Loading JSON data")
         with open(file_path, 'r') as file:
@@ -886,7 +892,7 @@ def summarize_with_claude(api_key, file_path, model):


 # Summarize with Cohere
-def summarize_with_cohere(api_key, file_path, model):
+def summarize_with_cohere(api_key, file_path, model, custom_prompt):
     try:
         logging.basicConfig(level=logging.DEBUG)
         logging.debug("cohere: Loading JSON data")
@@ -930,7 +936,7 @@ def summarize_with_cohere(api_key, file_path, model):
             logging.error("Expected data not found in API response.")
             return "Expected data not found in API response."
     else:
-        logging.error(f"cohere: API request failed with status code {response.status_code}: {resposne.text}")
+        logging.error(f"cohere: API request failed with status code {response.status_code}: {response.text}")
         print(f"Failed to process summary, status code {response.status_code}: {response.text}")
         return f"cohere: API request failed: {response.text}"

@@ -941,7 +947,7 @@ def summarize_with_cohere(api_key, file_path, model):


 # https://console.groq.com/docs/quickstart
-def summarize_with_groq(api_key, file_path, model):
+def summarize_with_groq(api_key, file_path, model, custom_prompt):
     try:
         logging.debug("groq: Loading JSON data")
         with open(file_path, 'r') as file:
@@ -997,7 +1003,7 @@ def summarize_with_groq(api_key, file_path, model):
 #
 # Local Summarization

-def summarize_with_llama(api_url, file_path, token):
+def summarize_with_llama(api_url, file_path, token, custom_prompt):
     try:
         logging.debug("llama: Loading JSON data")
         with open(file_path, 'r') as file:
@@ -1045,7 +1051,7 @@ def summarize_with_llama(api_url, file_path, token):


 # https://lite.koboldai.net/koboldcpp_api#/api%2Fv1/post_api_v1_generate
-def summarize_with_kobold(api_url, file_path):
+def summarize_with_kobold(api_url, file_path, custom_prompt):
     try:
         logging.debug("kobold: Loading JSON data")
         with open(file_path, 'r') as file:
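Since summarize_with_kobold now takes custom_prompt as well, here is a hedged sketch of how such a call could hit the koboldcpp /api/v1/generate endpoint linked above. The payload fields and response shape follow the public koboldcpp API docs; the segments-style JSON layout and treating api_url as a bare server base (e.g. http://localhost:5001) are assumptions:

import json
import logging
import requests

def summarize_with_kobold(api_url, file_path, custom_prompt):
    try:
        logging.debug("kobold: Loading JSON data")
        with open(file_path, 'r') as file:
            segments = json.load(file)
        text = " ".join(segment.get("text", "") for segment in segments)  # assumed segment layout

        # koboldcpp's /api/v1/generate takes a plain prompt and returns results[0].text
        payload = {
            "prompt": f"{text}\n\n{custom_prompt}",
            "max_length": 512,
            "temperature": 0.7,
        }
        response = requests.post(f"{api_url}/api/v1/generate", json=payload)
        response.raise_for_status()
        return response.json()["results"][0]["text"]
    except Exception as e:
        logging.error(f"kobold: summarization failed: {e}")
        return None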
@@ -1096,7 +1102,7 @@ def summarize_with_kobold(api_url, file_path):


 # https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API
-def summarize_with_oobabooga(api_url, file_path):
+def summarize_with_oobabooga(api_url, file_path, custom_prompt):
     try:
         logging.debug("ooba: Loading JSON data")
         with open(file_path, 'r') as file:
@@ -1336,61 +1342,61 @@ def main(input_path, api_name=None, api_key=None, num_speakers=2, whisper_model=
             api_key = openai_api_key
             try:
                 logging.debug(f"MAIN: trying to summarize with openAI")
-                summary = summarize_with_openai(api_key, json_file_path, openai_model)
+                summary = summarize_with_openai(api_key, json_file_path, openai_model, custom_prompt)
             except requests.exceptions.ConnectionError:
-                r.status_code = "Connection: "
+                requests.status_code = "Connection: "
         elif api_name.lower() == "anthropic":
             api_key = anthropic_api_key
             try:
                 logging.debug(f"MAIN: Trying to summarize with anthropic")
-                summary = summarize_with_claude(api_key, json_file_path, anthropic_model)
+                summary = summarize_with_claude(api_key, json_file_path, anthropic_model, custom_prompt)
             except requests.exceptions.ConnectionError:
-                r.status_code = "Connection: "
+                requests.status_code = "Connection: "
         elif api_name.lower() == "cohere":
             api_key = cohere_api_key
             try:
                 logging.debug(f"MAIN: Trying to summarize with cohere")
                 summary = summarize_with_cohere(api_key, json_file_path, cohere_model)
             except requests.exceptions.ConnectionError:
-                r.status_code = "Connection: "
+                requests.status_code = "Connection: "
         elif api_name.lower() == "groq":
             api_key = groq_api_key
             try:
                 logging.debug(f"MAIN: Trying to summarize with Groq")
                 summary = summarize_with_groq(api_key, json_file_path, groq_model)
             except requests.exceptions.ConnectionError:
-                r.status_code = "Connection: "
+                requests.status_code = "Connection: "
         elif api_name.lower() == "llama":
             token = llama_api_key
             llama_ip = llama_api_IP
             try:
                 logging.debug(f"MAIN: Trying to summarize with Llama.cpp")
-                summary = summarize_with_llama(llama_ip, json_file_path, token)
+                summary = summarize_with_llama(llama_ip, json_file_path, token, custom_prompt)
             except requests.exceptions.ConnectionError:
-                r.status_code = "Connection: "
+                requests.status_code = "Connection: "
         elif api_name.lower() == "kobold":
             token = kobold_api_key
             kobold_ip = kobold_api_IP
             try:
                 logging.debug(f"MAIN: Trying to summarize with kobold.cpp")
-                summary = summarize_with_kobold(kobold_ip, json_file_path)
+                summary = summarize_with_kobold(kobold_ip, json_file_path, custom_prompt)
             except requests.exceptions.ConnectionError:
-                r.status_code = "Connection: "
+                requests.status_code = "Connection: "
         elif api_name.lower() == "ooba":
             token = ooba_api_key
             ooba_ip = ooba_api_IP
             try:
                 logging.debug(f"MAIN: Trying to summarize with oobabooga")
-                summary = summarize_with_oobabooga(ooba_ip, json_file_path)
+                summary = summarize_with_oobabooga(ooba_ip, json_file_path, custom_prompt)
             except requests.exceptions.ConnectionError:
-                r.status_code = "Connection: "
-        if api_name.lower() == "huggingface":
+                requests.status_code = "Connection: "
+        elif api_name.lower() == "huggingface":
             api_key = huggingface_api_key
             try:
                 logging.debug(f"MAIN: Trying to summarize with huggingface")
                 summarize_with_huggingface(api_key, json_file_path)
             except requests.exceptions.ConnectionError:
-                r.status_code = "Connection: "
+                requests.status_code = "Connection: "

         else:
             logging.warning(f"Unsupported API: {api_name}")