diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 73f83a5da..beb24f065 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -803,6 +803,12 @@ "tags": "TEXT EMBEDDING", "max_tokens": 512, "model_type": "embedding" + }, + { + "llm_name": "glm-asr", + "tags": "SPEECH2TEXT", + "max_tokens": 4096, + "model_type": "speech2text" } ] }, @@ -5140,4 +5146,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/rag/llm/sequence2txt_model.py b/rag/llm/sequence2txt_model.py index c43a0141a..c66adada4 100644 --- a/rag/llm/sequence2txt_model.py +++ b/rag/llm/sequence2txt_model.py @@ -234,8 +234,8 @@ class DeepInfraSeq2txt(Base): self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name - - + + class CometAPISeq2txt(Base): _FACTORY_NAME = "CometAPI" @@ -244,7 +244,8 @@ class CometAPISeq2txt(Base): base_url = "https://api.cometapi.com/v1" self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name - + + class DeerAPISeq2txt(Base): _FACTORY_NAME = "DeerAPI" @@ -253,3 +254,44 @@ class DeerAPISeq2txt(Base): base_url = "https://api.deerapi.com/v1" self.client = OpenAI(api_key=key, base_url=base_url) self.model_name = model_name + + +class ZhipuSeq2txt(Base): + _FACTORY_NAME = "ZHIPU-AI" + + def __init__(self, key, model_name="glm-asr", base_url="https://open.bigmodel.cn/api/paas/v4", **kwargs): + if not base_url: + base_url = "https://open.bigmodel.cn/api/paas/v4" + self.base_url = base_url + self.api_key = key + self.model_name = model_name + self.gen_conf = kwargs.get("gen_conf", {}) + self.stream = kwargs.get("stream", False) + + def transcription(self, audio_path): + payload = { + "model": self.model_name, + "temperature": str(self.gen_conf.get("temperature", 0.75)) or "0.75", + "stream": self.stream, + } + + headers = {"Authorization": f"Bearer {self.api_key}"} + with open(audio_path, "rb") as audio_file: + files = {"file": audio_file} + + try: + response = requests.post( + url=f"{self.base_url}/audio/transcriptions", + data=payload, + files=files, + headers=headers, + ) + body = response.json() + if response.status_code == 200: + full_content = body["text"] + return full_content, num_tokens_from_string(full_content) + else: + error = body["error"] + return f"**ERROR**: code: {error['code']}, message: {error['message']}", 0 + except Exception as e: + return "**ERROR**: " + str(e), 0