mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Refa: OpenAI whisper-1 (#9552)
### What problem does this PR solve? Refactor OpenAI to enable audio parsing. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Refactoring
This commit is contained in:
@ -41,6 +41,9 @@ def set_dialog():
|
|||||||
return get_data_error_result(message="Dialog name can't be empty.")
|
return get_data_error_result(message="Dialog name can't be empty.")
|
||||||
if len(name.encode("utf-8")) > 255:
|
if len(name.encode("utf-8")) > 255:
|
||||||
return get_data_error_result(message=f"Dialog name length is {len(name)} which is larger than 255")
|
return get_data_error_result(message=f"Dialog name length is {len(name)} which is larger than 255")
|
||||||
|
|
||||||
|
if DialogService.get_or_none(tenant_id=current_user.id, name=name):
|
||||||
|
return get_data_error_result(message=f"Duplicated Dialog name {name}.")
|
||||||
description = req.get("description", "A helpful dialog")
|
description = req.get("description", "A helpful dialog")
|
||||||
icon = req.get("icon", "")
|
icon = req.get("icon", "")
|
||||||
top_n = req.get("top_n", 6)
|
top_n = req.get("top_n", 6)
|
||||||
|
|||||||
@ -505,6 +505,24 @@
|
|||||||
"tags": "RE-RANK,4k",
|
"tags": "RE-RANK,4k",
|
||||||
"max_tokens": 4000,
|
"max_tokens": 4000,
|
||||||
"model_type": "rerank"
|
"model_type": "rerank"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"llm_name": "qwen-audio-asr",
|
||||||
|
"tags": "SPEECH2TEXT,8k",
|
||||||
|
"max_tokens": 8000,
|
||||||
|
"model_type": "speech2text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"llm_name": "qwen-audio-asr-latest",
|
||||||
|
"tags": "SPEECH2TEXT,8k",
|
||||||
|
"max_tokens": 8000,
|
||||||
|
"model_type": "speech2text"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"llm_name": "qwen-audio-asr-1204",
|
||||||
|
"tags": "SPEECH2TEXT,8k",
|
||||||
|
"max_tokens": 8000,
|
||||||
|
"model_type": "speech2text"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|||||||
@ -14,31 +14,48 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
|
import tempfile
|
||||||
|
|
||||||
from api.db import LLMType
|
from api.db import LLMType
|
||||||
from rag.nlp import rag_tokenizer
|
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from rag.nlp import tokenize
|
from rag.nlp import rag_tokenizer, tokenize
|
||||||
|
|
||||||
|
|
||||||
def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
|
def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs):
|
||||||
doc = {
|
doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))}
|
||||||
"docnm_kwd": filename,
|
|
||||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
|
||||||
}
|
|
||||||
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
|
doc["title_sm_tks"] = rag_tokenizer.fine_grained_tokenize(doc["title_tks"])
|
||||||
|
|
||||||
# is it English
|
# is it English
|
||||||
eng = lang.lower() == "english" # is_english(sections)
|
eng = lang.lower() == "english" # is_english(sections)
|
||||||
try:
|
try:
|
||||||
|
_, ext = os.path.splitext(filename)
|
||||||
|
if not ext:
|
||||||
|
raise RuntimeError("No extension detected.")
|
||||||
|
|
||||||
|
if ext not in [".da", ".wave", ".wav", ".mp3", ".wav", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma", ".realaudio", ".vqf", ".oggvorbis", ".aac", ".ape"]:
|
||||||
|
raise RuntimeError(f"Extension {ext} is not supported yet.")
|
||||||
|
|
||||||
|
tmp_path = ""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmpf:
|
||||||
|
tmpf.write(binary)
|
||||||
|
tmpf.flush()
|
||||||
|
tmp_path = os.path.abspath(tmpf.name)
|
||||||
|
|
||||||
callback(0.1, "USE Sequence2Txt LLM to transcription the audio")
|
callback(0.1, "USE Sequence2Txt LLM to transcription the audio")
|
||||||
seq2txt_mdl = LLMBundle(tenant_id, LLMType.SPEECH2TEXT, lang=lang)
|
seq2txt_mdl = LLMBundle(tenant_id, LLMType.SPEECH2TEXT, lang=lang)
|
||||||
ans = seq2txt_mdl.transcription(binary)
|
ans = seq2txt_mdl.transcription(tmp_path)
|
||||||
callback(0.8, "Sequence2Txt LLM respond: %s ..." % ans[:32])
|
callback(0.8, "Sequence2Txt LLM respond: %s ..." % ans[:32])
|
||||||
|
|
||||||
tokenize(doc, ans, eng)
|
tokenize(doc, ans, eng)
|
||||||
return [doc]
|
return [doc]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
callback(prog=-1, msg=str(e))
|
callback(prog=-1, msg=str(e))
|
||||||
|
finally:
|
||||||
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
|
try:
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
return []
|
return []
|
||||||
|
|||||||
@ -35,8 +35,9 @@ class Base(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def transcription(self, audio, **kwargs):
|
def transcription(self, audio_path, **kwargs):
|
||||||
transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio, response_format="text")
|
audio_file = open(audio_path, "rb")
|
||||||
|
transcription = self.client.audio.transcriptions.create(model=self.model_name, file=audio_file)
|
||||||
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
|
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
|
||||||
|
|
||||||
def audio2base64(self, audio):
|
def audio2base64(self, audio):
|
||||||
@ -50,7 +51,7 @@ class Base(ABC):
|
|||||||
class GPTSeq2txt(Base):
|
class GPTSeq2txt(Base):
|
||||||
_FACTORY_NAME = "OpenAI"
|
_FACTORY_NAME = "OpenAI"
|
||||||
|
|
||||||
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1"):
|
def __init__(self, key, model_name="whisper-1", base_url="https://api.openai.com/v1", **kwargs):
|
||||||
if not base_url:
|
if not base_url:
|
||||||
base_url = "https://api.openai.com/v1"
|
base_url = "https://api.openai.com/v1"
|
||||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||||
@ -60,27 +61,38 @@ class GPTSeq2txt(Base):
|
|||||||
class QWenSeq2txt(Base):
|
class QWenSeq2txt(Base):
|
||||||
_FACTORY_NAME = "Tongyi-Qianwen"
|
_FACTORY_NAME = "Tongyi-Qianwen"
|
||||||
|
|
||||||
def __init__(self, key, model_name="paraformer-realtime-8k-v1", **kwargs):
|
def __init__(self, key, model_name="qwen-audio-asr", **kwargs):
|
||||||
import dashscope
|
import dashscope
|
||||||
|
|
||||||
dashscope.api_key = key
|
dashscope.api_key = key
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
def transcription(self, audio, format):
|
def transcription(self, audio_path):
|
||||||
from http import HTTPStatus
|
if "paraformer" in self.model_name or "sensevoice" in self.model_name:
|
||||||
|
return f"**ERROR**: model {self.model_name} is not suppported yet.", 0
|
||||||
|
|
||||||
from dashscope.audio.asr import Recognition
|
from dashscope import MultiModalConversation
|
||||||
|
|
||||||
recognition = Recognition(model=self.model_name, format=format, sample_rate=16000, callback=None)
|
audio_path = f"file://{audio_path}"
|
||||||
result = recognition.call(audio)
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"audio": audio_path}],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
ans = ""
|
response = None
|
||||||
if result.status_code == HTTPStatus.OK:
|
full_content = ""
|
||||||
for sentence in result.get_sentence():
|
try:
|
||||||
ans += sentence.text.decode("utf-8") + "\n"
|
response = MultiModalConversation.call(model="qwen-audio-asr", messages=messages, result_format="message", stream=True)
|
||||||
return ans, num_tokens_from_string(ans)
|
for response in response:
|
||||||
|
try:
|
||||||
return "**ERROR**: " + result.message, 0
|
full_content += response["output"]["choices"][0]["message"].content[0]["text"]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return full_content, num_tokens_from_string(full_content)
|
||||||
|
except Exception as e:
|
||||||
|
return "**ERROR**: " + str(e), 0
|
||||||
|
|
||||||
|
|
||||||
class AzureSeq2txt(Base):
|
class AzureSeq2txt(Base):
|
||||||
@ -212,6 +224,7 @@ class GiteeSeq2txt(Base):
|
|||||||
self.client = OpenAI(api_key=key, base_url=base_url)
|
self.client = OpenAI(api_key=key, base_url=base_url)
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
|
|
||||||
class DeepInfraSeq2txt(Base):
|
class DeepInfraSeq2txt(Base):
|
||||||
_FACTORY_NAME = "DeepInfra"
|
_FACTORY_NAME = "DeepInfra"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user