Support sequence2txt and TTS models in Xinference (#2696)

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
Authored by JobSmithManipulation on 2024-10-08 10:43:18 +08:00, committed by GitHub
commit a3ab5ba9ac (parent c552a02e7f)
5 changed files with 112 additions and 38 deletions


@@ -195,7 +195,7 @@ class LLMBundle(object):
         self.llm_name = llm_name
         self.mdl = TenantLLMService.model_instance(
             tenant_id, llm_type, llm_name, lang=lang)
-        assert self.mdl, "Can't find mole for {}/{}/{}".format(
+        assert self.mdl, "Can't find model for {}/{}/{}".format(
             tenant_id, llm_type, llm_name)
         self.max_length = 8192
         for lm in LLMService.query(llm_name=llm_name):
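This hunk only fixes a typo in the assertion message ("mole" to "model"); the assert fires when TenantLLMService.model_instance cannot resolve a provider/model pair. A minimal caller sketch, with the caveat that the LLMType.SPEECH2TEXT member and the "<model>@<factory>" name format are assumptions for illustration, not part of this diff:

```python
# Hypothetical caller sketch -- the enum member and llm_name format are assumed.
from api.db import LLMType
from api.db.services.llm_service import LLMBundle

tenant_id = "tenant-0"  # placeholder
bundle = LLMBundle(tenant_id, LLMType.SPEECH2TEXT,
                   llm_name="whisper-small@Xinference")
# bundle.mdl is the resolved model instance (e.g. XinferenceSeq2txt);
# the assert above raises if the lookup returned None.
```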


@@ -47,10 +47,9 @@ EmbeddingModel = {
     "Replicate": ReplicateEmbed,
     "BaiduYiyan": BaiduYiyanEmbed,
     "Voyage AI": VoyageEmbed,
-    "HuggingFace":HuggingFaceEmbed,
+    "HuggingFace": HuggingFaceEmbed,
 }
-
 CvModel = {
     "OpenAI": GptV4,
     "Azure-OpenAI": AzureGptV4,
@@ -64,14 +63,13 @@ CvModel = {
     "LocalAI": LocalAICV,
     "NVIDIA": NvidiaCV,
     "LM-Studio": LmStudioCV,
-    "StepFun":StepFunCV,
+    "StepFun": StepFunCV,
     "OpenAI-API-Compatible": OpenAI_APICV,
     "TogetherAI": TogetherAICV,
     "01.AI": YiCV,
     "Tencent Hunyuan": HunyuanCV
 }
-
 ChatModel = {
     "OpenAI": GptTurbo,
     "Azure-OpenAI": AzureChat,
@@ -99,7 +97,7 @@ ChatModel = {
     "LeptonAI": LeptonAIChat,
     "TogetherAI": TogetherAIChat,
     "PerfXCloud": PerfXCloudChat,
-    "Upstage":UpstageChat,
+    "Upstage": UpstageChat,
     "novita.ai": NovitaAIChat,
     "SILICONFLOW": SILICONFLOWChat,
     "01.AI": YiChat,
@@ -111,7 +109,6 @@ ChatModel = {
     "Google Cloud": GoogleChat,
 }
-
 RerankModel = {
     "BAAI": DefaultRerank,
     "Jina": JinaRerank,
@@ -127,11 +124,9 @@ RerankModel = {
     "Voyage AI": VoyageRerank
 }
-
 Seq2txtModel = {
     "OpenAI": GPTSeq2txt,
     "Tongyi-Qianwen": QWenSeq2txt,
-    "Ollama": OllamaSeq2txt,
     "Azure-OpenAI": AzureSeq2txt,
     "Xinference": XinferenceSeq2txt,
     "Tencent Cloud": TencentCloudSeq2txt
@@ -140,6 +135,7 @@ Seq2txtModel = {
 TTSModel = {
     "Fish Audio": FishAudioTTS,
     "Tongyi-Qianwen": QwenTTS,
-    "OpenAI":OpenAITTS,
-    "XunFei Spark":SparkTTS
-}
+    "OpenAI": OpenAITTS,
+    "XunFei Spark": SparkTTS,
+    "Xinference": XinferenceTTS,
+}
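These dicts act as factory registries: the service layer maps a provider name to its implementation class and instantiates it with the tenant's credentials. A minimal sketch of that dispatch, assuming a locally running Xinference server; the base_url and model names are placeholders:

```python
# Registry dispatch sketch -- base_url and model names are placeholders.
from rag.llm import Seq2txtModel, TTSModel

stt = Seq2txtModel["Xinference"](key="not-used",
                                 model_name="whisper-small",
                                 base_url="http://localhost:9997")
tts = TTSModel["Xinference"](key="not-used",
                             model_name="CosyVoice-300M-SFT",
                             base_url="http://localhost:9997")
```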


@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import requests
 from openai.lib.azure import AzureOpenAI
 from zhipuai import ZhipuAI
 import io
@@ -25,6 +26,7 @@ from rag.utils import num_tokens_from_string
 import base64
 import re
+

 class Base(ABC):
     def __init__(self, key, model_name):
         pass
@@ -36,8 +38,8 @@ class Base(ABC):
             response_format="text"
         )
         return transcription.text.strip(), num_tokens_from_string(transcription.text.strip())

-    def audio2base64(self,audio):
+    def audio2base64(self, audio):
         if isinstance(audio, bytes):
             return base64.b64encode(audio).decode("utf-8")
         if isinstance(audio, io.BytesIO):
@@ -77,13 +79,6 @@ class QWenSeq2txt(Base):
         return "**ERROR**: " + result.message, 0


-class OllamaSeq2txt(Base):
-    def __init__(self, key, model_name, lang="Chinese", **kwargs):
-        self.client = Client(host=kwargs["base_url"])
-        self.model_name = model_name
-        self.lang = lang
-
-
 class AzureSeq2txt(Base):
     def __init__(self, key, model_name, lang="Chinese", **kwargs):
         self.client = AzureOpenAI(api_key=key, azure_endpoint=kwargs["base_url"], api_version="2024-02-01")
@@ -92,16 +87,53 @@ class AzureSeq2txt(Base):
 class XinferenceSeq2txt(Base):
-    def __init__(self, key, model_name="", base_url=""):
-        if base_url.split("/")[-1] != "v1":
-            base_url = os.path.join(base_url, "v1")
-        self.client = OpenAI(api_key="xxx", base_url=base_url)
+    def __init__(self, key, model_name="whisper-small", **kwargs):
+        self.base_url = kwargs.get('base_url', None)
         self.model_name = model_name

+    def transcription(self, audio, language="zh", prompt=None, response_format="json", temperature=0.7):
+        if isinstance(audio, str):
+            audio_file = open(audio, 'rb')
+            audio_data = audio_file.read()
+            audio_file_name = audio.split("/")[-1]
+        else:
+            audio_data = audio
+            audio_file_name = "audio.wav"
+
+        payload = {
+            "model": self.model_name,
+            "language": language,
+            "prompt": prompt,
+            "response_format": response_format,
+            "temperature": temperature
+        }
+        files = {
+            "file": (audio_file_name, audio_data, 'audio/wav')
+        }
+
+        try:
+            response = requests.post(
+                f"{self.base_url}/v1/audio/transcriptions",
+                files=files,
+                data=payload
+            )
+            response.raise_for_status()
+            result = response.json()
+
+            if 'text' in result:
+                transcription_text = result['text'].strip()
+                return transcription_text, num_tokens_from_string(transcription_text)
+            else:
+                return "**ERROR**: Failed to retrieve transcription.", 0
+
+        except requests.exceptions.RequestException as e:
+            return f"**ERROR**: {str(e)}", 0
+
+
 class TencentCloudSeq2txt(Base):
     def __init__(
         self, key, model_name="16k_zh", base_url="https://asr.tencentcloudapi.com"
     ):
         from tencentcloud.common import credential
         from tencentcloud.asr.v20190614 import asr_client
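The rewritten XinferenceSeq2txt drops the OpenAI client wrapper and posts multipart form data directly to the server's OpenAI-compatible /v1/audio/transcriptions endpoint; it accepts either a file path or raw bytes and returns (text, token_count), or an "**ERROR**" string on failure. A usage sketch for the class above; the URL and file names are placeholders:

```python
# Usage sketch -- base_url and file names are placeholders.
stt = XinferenceSeq2txt(key="not-used", model_name="whisper-small",
                        base_url="http://localhost:9997")

# Path form: the file is read and its basename becomes the upload name.
text, tokens = stt.transcription("meeting.wav", language="zh")

# Bytes form: the payload is uploaded under the fixed name "audio.wav".
with open("meeting.wav", "rb") as f:
    text, tokens = stt.transcription(f.read())
```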


@@ -297,3 +297,36 @@ class SparkTTS:
                     break
                 status_code = 1
                 yield audio_chunk
+
+
+class XinferenceTTS:
+    def __init__(self, key, model_name, **kwargs):
+        self.base_url = kwargs.get("base_url", None)
+        self.model_name = model_name
+        self.headers = {
+            "accept": "application/json",
+            "Content-Type": "application/json"
+        }
+
+    def tts(self, text, voice="中文女", stream=True):
+        payload = {
+            "model": self.model_name,
+            "input": text,
+            "voice": voice
+        }
+
+        response = requests.post(
+            f"{self.base_url}/v1/audio/speech",
+            headers=self.headers,
+            json=payload,
+            stream=stream
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"**Error**: {response.status_code}, {response.text}")
+
+        for chunk in response.iter_content(chunk_size=1024):
+            if chunk:
+                yield chunk
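XinferenceTTS.tts is a generator: it streams the synthesized audio back in 1 KB chunks, so callers can write or forward audio incrementally instead of buffering the whole response. A sketch, assuming a placeholder model name; the container format of the output depends on the model served:

```python
# Streaming sketch -- model name and output file are placeholders.
tts = XinferenceTTS(key="not-used", model_name="CosyVoice-300M-SFT",
                    base_url="http://localhost:9997")

with open("speech.mp3", "wb") as out:
    for chunk in tts.tts("Hello, world", voice="中文女"):
        out.write(chunk)
```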


@@ -53,6 +53,26 @@ const OllamaModal = ({
   const url =
     llmFactoryToUrlMap[llmFactory as LlmFactory] ||
     'https://github.com/infiniflow/ragflow/blob/main/docs/guides/deploy_local_llm.mdx';
+  const optionsMap = {
+    HuggingFace: [{ value: 'embedding', label: 'embedding' }],
+    Xinference: [
+      { value: 'chat', label: 'chat' },
+      { value: 'embedding', label: 'embedding' },
+      { value: 'rerank', label: 'rerank' },
+      { value: 'image2text', label: 'image2text' },
+      { value: 'speech2text', label: 'sequence2text' },
+      { value: 'tts', label: 'tts' },
+    ],
+    Default: [
+      { value: 'chat', label: 'chat' },
+      { value: 'embedding', label: 'embedding' },
+      { value: 'rerank', label: 'rerank' },
+      { value: 'image2text', label: 'image2text' },
+    ],
+  };
+  const getOptions = (factory: string) => {
+    return optionsMap[factory as keyof typeof optionsMap] || optionsMap.Default;
+  };
   return (
     <Modal
       title={t('addLlmTitle', { name: llmFactory })}
@@ -85,18 +105,11 @@ const OllamaModal = ({
         rules={[{ required: true, message: t('modelTypeMessage') }]}
       >
         <Select placeholder={t('modelTypeMessage')}>
-          {llmFactory === 'HuggingFace' ? (
-            <Option value="embedding">embedding</Option>
-          ) : (
-            <>
-              <Option value="chat">chat</Option>
-              <Option value="embedding">embedding</Option>
-              <Option value="rerank">rerank</Option>
-              <Option value="image2text">image2text</Option>
-              <Option value="audio2text">audio2text</Option>
-              <Option value="text2andio">text2andio</Option>
-            </>
-          )}
+          {getOptions(llmFactory).map((option) => (
+            <Option key={option.value} value={option.value}>
+              {option.label}
+            </Option>
+          ))}
         </Select>
       </Form.Item>
       <Form.Item<FieldType>