mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
support sequence2txt and tts model in Xinference (#2696)
### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
committed by
GitHub
parent
c552a02e7f
commit
a3ab5ba9ac
@ -195,7 +195,7 @@ class LLMBundle(object):
|
|||||||
self.llm_name = llm_name
|
self.llm_name = llm_name
|
||||||
self.mdl = TenantLLMService.model_instance(
|
self.mdl = TenantLLMService.model_instance(
|
||||||
tenant_id, llm_type, llm_name, lang=lang)
|
tenant_id, llm_type, llm_name, lang=lang)
|
||||||
assert self.mdl, "Can't find mole for {}/{}/{}".format(
|
assert self.mdl, "Can't find model for {}/{}/{}".format(
|
||||||
tenant_id, llm_type, llm_name)
|
tenant_id, llm_type, llm_name)
|
||||||
self.max_length = 8192
|
self.max_length = 8192
|
||||||
for lm in LLMService.query(llm_name=llm_name):
|
for lm in LLMService.query(llm_name=llm_name):
|
||||||
|
|||||||
@ -47,10 +47,9 @@ EmbeddingModel = {
|
|||||||
"Replicate": ReplicateEmbed,
|
"Replicate": ReplicateEmbed,
|
||||||
"BaiduYiyan": BaiduYiyanEmbed,
|
"BaiduYiyan": BaiduYiyanEmbed,
|
||||||
"Voyage AI": VoyageEmbed,
|
"Voyage AI": VoyageEmbed,
|
||||||
"HuggingFace":HuggingFaceEmbed,
|
"HuggingFace": HuggingFaceEmbed,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
CvModel = {
|
CvModel = {
|
||||||
"OpenAI": GptV4,
|
"OpenAI": GptV4,
|
||||||
"Azure-OpenAI": AzureGptV4,
|
"Azure-OpenAI": AzureGptV4,
|
||||||
@ -64,14 +63,13 @@ CvModel = {
|
|||||||
"LocalAI": LocalAICV,
|
"LocalAI": LocalAICV,
|
||||||
"NVIDIA": NvidiaCV,
|
"NVIDIA": NvidiaCV,
|
||||||
"LM-Studio": LmStudioCV,
|
"LM-Studio": LmStudioCV,
|
||||||
"StepFun":StepFunCV,
|
"StepFun": StepFunCV,
|
||||||
"OpenAI-API-Compatible": OpenAI_APICV,
|
"OpenAI-API-Compatible": OpenAI_APICV,
|
||||||
"TogetherAI": TogetherAICV,
|
"TogetherAI": TogetherAICV,
|
||||||
"01.AI": YiCV,
|
"01.AI": YiCV,
|
||||||
"Tencent Hunyuan": HunyuanCV
|
"Tencent Hunyuan": HunyuanCV
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
ChatModel = {
|
ChatModel = {
|
||||||
"OpenAI": GptTurbo,
|
"OpenAI": GptTurbo,
|
||||||
"Azure-OpenAI": AzureChat,
|
"Azure-OpenAI": AzureChat,
|
||||||
@ -99,7 +97,7 @@ ChatModel = {
|
|||||||
"LeptonAI": LeptonAIChat,
|
"LeptonAI": LeptonAIChat,
|
||||||
"TogetherAI": TogetherAIChat,
|
"TogetherAI": TogetherAIChat,
|
||||||
"PerfXCloud": PerfXCloudChat,
|
"PerfXCloud": PerfXCloudChat,
|
||||||
"Upstage":UpstageChat,
|
"Upstage": UpstageChat,
|
||||||
"novita.ai": NovitaAIChat,
|
"novita.ai": NovitaAIChat,
|
||||||
"SILICONFLOW": SILICONFLOWChat,
|
"SILICONFLOW": SILICONFLOWChat,
|
||||||
"01.AI": YiChat,
|
"01.AI": YiChat,
|
||||||
@ -111,7 +109,6 @@ ChatModel = {
|
|||||||
"Google Cloud": GoogleChat,
|
"Google Cloud": GoogleChat,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
RerankModel = {
|
RerankModel = {
|
||||||
"BAAI": DefaultRerank,
|
"BAAI": DefaultRerank,
|
||||||
"Jina": JinaRerank,
|
"Jina": JinaRerank,
|
||||||
@ -127,11 +124,9 @@ RerankModel = {
|
|||||||
"Voyage AI": VoyageRerank
|
"Voyage AI": VoyageRerank
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Seq2txtModel = {
|
Seq2txtModel = {
|
||||||
"OpenAI": GPTSeq2txt,
|
"OpenAI": GPTSeq2txt,
|
||||||
"Tongyi-Qianwen": QWenSeq2txt,
|
"Tongyi-Qianwen": QWenSeq2txt,
|
||||||
"Ollama": OllamaSeq2txt,
|
|
||||||
"Azure-OpenAI": AzureSeq2txt,
|
"Azure-OpenAI": AzureSeq2txt,
|
||||||
"Xinference": XinferenceSeq2txt,
|
"Xinference": XinferenceSeq2txt,
|
||||||
"Tencent Cloud": TencentCloudSeq2txt
|
"Tencent Cloud": TencentCloudSeq2txt
|
||||||
@ -140,6 +135,7 @@ Seq2txtModel = {
|
|||||||
TTSModel = {
|
TTSModel = {
|
||||||
"Fish Audio": FishAudioTTS,
|
"Fish Audio": FishAudioTTS,
|
||||||
"Tongyi-Qianwen": QwenTTS,
|
"Tongyi-Qianwen": QwenTTS,
|
||||||
"OpenAI":OpenAITTS,
|
"OpenAI": OpenAITTS,
|
||||||
"XunFei Spark":SparkTTS
|
"XunFei Spark": SparkTTS,
|
||||||
}
|
"Xinference": XinferenceTTS,
|
||||||
|
}
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import requests
|
||||||
from openai.lib.azure import AzureOpenAI
|
from openai.lib.azure import AzureOpenAI
|
||||||
from zhipuai import ZhipuAI
|
from zhipuai import ZhipuAI
|
||||||
import io
|
import io
|
||||||
@ -25,6 +26,7 @@ from rag.utils import num_tokens_from_string
|
|||||||
import base64
|
import base64
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
class Base(ABC):
|
class Base(ABC):
|
||||||
def __init__(self, key, model_name):
|
def __init__(self, key, model_name):
|
||||||
pass
|
pass
|
||||||
@ -36,8 +38,8 @@ class Base(ABC):
|
|||||||
response_format="text"
|
response_format="text"
|
||||||
)
|
)
|
||||||
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
|
return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())
|
||||||
|
|
||||||
def audio2base64(self,audio):
|
def audio2base64(self, audio):
|
||||||
if isinstance(audio, bytes):
|
if isinstance(audio, bytes):
|
||||||
return base64.b64encode(audio).decode("utf-8")
|
return base64.b64encode(audio).decode("utf-8")
|
||||||
if isinstance(audio, io.BytesIO):
|
if isinstance(audio, io.BytesIO):
|
||||||
@ -77,13 +79,6 @@ class QWenSeq2txt(Base):
|
|||||||
return "**ERROR**: " + result.message, 0
|
return "**ERROR**: " + result.message, 0
|
||||||
|
|
||||||
|
|
||||||
class OllamaSeq2txt(Base):
|
|
||||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
|
||||||
self.client = Client(host=kwargs["base_url"])
|
|
||||||
self.model_name = model_name
|
|
||||||
self.lang = lang
|
|
||||||
|
|
||||||
|
|
||||||
class AzureSeq2txt(Base):
|
class AzureSeq2txt(Base):
|
||||||
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
def __init__(self, key, model_name, lang="Chinese", **kwargs):
|
||||||
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
|
self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
|
||||||
@ -92,16 +87,53 @@ class AzureSeq2txt(Base):
|
|||||||
|
|
||||||
|
|
||||||
class XinferenceSeq2txt(Base):
|
class XinferenceSeq2txt(Base):
|
||||||
def __init__(self, key, model_name="", base_url=""):
|
def __init__(self,key,model_name="whisper-small",**kwargs):
|
||||||
if base_url.split("/")[-1] != "v1":
|
self.base_url = kwargs.get('base_url', None)
|
||||||
base_url = os.path.join(base_url, "v1")
|
|
||||||
self.client = OpenAI(api_key="xxx", base_url=base_url)
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
|
||||||
|
def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
|
||||||
|
if isinstance(audio, str):
|
||||||
|
audio_file = open(audio, 'rb')
|
||||||
|
audio_data = audio_file.read()
|
||||||
|
audio_file_name = audio.split("/")[-1]
|
||||||
|
else:
|
||||||
|
audio_data = audio
|
||||||
|
audio_file_name = "audio.wav"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"language": language,
|
||||||
|
"prompt": prompt,
|
||||||
|
"response_format": response_format,
|
||||||
|
"temperature": temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
files = {
|
||||||
|
"file": (audio_file_name, audio_data, 'audio/wav')
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.base_url}/v1/audio/transcriptions",
|
||||||
|
files=files,
|
||||||
|
data=payload
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
if 'text' in result:
|
||||||
|
transcription_text = result['text'].strip()
|
||||||
|
return transcription_text, num_tokens_from_string(transcription_text)
|
||||||
|
else:
|
||||||
|
return "**ERROR**: Failed to retrieve transcription.", 0
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
return f"**ERROR**: {str(e)}", 0
|
||||||
|
|
||||||
|
|
||||||
class TencentCloudSeq2txt(Base):
|
class TencentCloudSeq2txt(Base):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
|
self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
|
||||||
):
|
):
|
||||||
from tencentcloud.common import credential
|
from tencentcloud.common import credential
|
||||||
from tencentcloud.asr.v20190614 import asr_client
|
from tencentcloud.asr.v20190614 import asr_client
|
||||||
|
|||||||
@ -297,3 +297,36 @@ class SparkTTS:
|
|||||||
break
|
break
|
||||||
status_code = 1
|
status_code = 1
|
||||||
yield audio_chunk
|
yield audio_chunk
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class XinferenceTTS:
|
||||||
|
def __init__(self, key, model_name, **kwargs):
|
||||||
|
self.base_url = kwargs.get("base_url", None)
|
||||||
|
self.model_name = model_name
|
||||||
|
self.headers = {
|
||||||
|
"accept": "application/json",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
def tts(self, text, voice="中文女", stream=True):
|
||||||
|
payload = {
|
||||||
|
"model": self.model_name,
|
||||||
|
"input": text,
|
||||||
|
"voice": voice
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.base_url}/v1/audio/speech",
|
||||||
|
headers=self.headers,
|
||||||
|
json=payload,
|
||||||
|
stream=stream
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise Exception(f"**Error**: {response.status_code}, {response.text}")
|
||||||
|
|
||||||
|
for chunk in response.iter_content(chunk_size=1024):
|
||||||
|
if chunk:
|
||||||
|
yield chunk
|
||||||
|
|||||||
@ -53,6 +53,26 @@ const OllamaModal = ({
|
|||||||
const url =
|
const url =
|
||||||
llmFactoryToUrlMap[llmFactory as LlmFactory] ||
|
llmFactoryToUrlMap[llmFactory as LlmFactory] ||
|
||||||
'https://github.com/infiniflow/ragflow/blob/main/docs/guides/deploy_local_llm.mdx';
|
'https://github.com/infiniflow/ragflow/blob/main/docs/guides/deploy_local_llm.mdx';
|
||||||
|
const optionsMap = {
|
||||||
|
HuggingFace: [{ value: 'embedding', label: 'embedding' }],
|
||||||
|
Xinference: [
|
||||||
|
{ value: 'chat', label: 'chat' },
|
||||||
|
{ value: 'embedding', label: 'embedding' },
|
||||||
|
{ value: 'rerank', label: 'rerank' },
|
||||||
|
{ value: 'image2text', label: 'image2text' },
|
||||||
|
{ value: 'speech2text', label: 'sequence2text' },
|
||||||
|
{ value: 'tts', label: 'tts' },
|
||||||
|
],
|
||||||
|
Default: [
|
||||||
|
{ value: 'chat', label: 'chat' },
|
||||||
|
{ value: 'embedding', label: 'embedding' },
|
||||||
|
{ value: 'rerank', label: 'rerank' },
|
||||||
|
{ value: 'image2text', label: 'image2text' },
|
||||||
|
],
|
||||||
|
};
|
||||||
|
const getOptions = (factory: string) => {
|
||||||
|
return optionsMap[factory as keyof typeof optionsMap] || optionsMap.Default;
|
||||||
|
};
|
||||||
return (
|
return (
|
||||||
<Modal
|
<Modal
|
||||||
title={t('addLlmTitle', { name: llmFactory })}
|
title={t('addLlmTitle', { name: llmFactory })}
|
||||||
@ -85,18 +105,11 @@ const OllamaModal = ({
|
|||||||
rules={[{ required: true, message: t('modelTypeMessage') }]}
|
rules={[{ required: true, message: t('modelTypeMessage') }]}
|
||||||
>
|
>
|
||||||
<Select placeholder={t('modelTypeMessage')}>
|
<Select placeholder={t('modelTypeMessage')}>
|
||||||
{llmFactory === 'HuggingFace' ? (
|
{getOptions(llmFactory).map((option) => (
|
||||||
<Option value="embedding">embedding</Option>
|
<Option key={option.value} value={option.value}>
|
||||||
) : (
|
{option.label}
|
||||||
<>
|
</Option>
|
||||||
<Option value="chat">chat</Option>
|
))}
|
||||||
<Option value="embedding">embedding</Option>
|
|
||||||
<Option value="rerank">rerank</Option>
|
|
||||||
<Option value="image2text">image2text</Option>
|
|
||||||
<Option value="audio2text">audio2text</Option>
|
|
||||||
<Option value="text2andio">text2andio</Option>
|
|
||||||
</>
|
|
||||||
)}
|
|
||||||
</Select>
|
</Select>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
<Form.Item<FieldType>
|
<Form.Item<FieldType>
|
||||||
|
|||||||
Reference in New Issue
Block a user