mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Compare commits
6 Commits
d1e172171f
...
2ffe6f7439
| Author | SHA1 | Date | |
|---|---|---|---|
| 2ffe6f7439 | |||
| e3987e21b9 | |||
| a713f54732 | |||
| 519f03097e | |||
| 299c655e39 | |||
| b8c0fb4572 |
@ -41,6 +41,7 @@ class MessageParam(ComponentParamBase):
|
||||
self.content = []
|
||||
self.stream = True
|
||||
self.output_format = None # default output format
|
||||
self.auto_play = False
|
||||
self.outputs = {
|
||||
"content": {
|
||||
"type": "str"
|
||||
|
||||
@ -14,9 +14,11 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
import tempfile
|
||||
from quart import Response, request
|
||||
from api.apps import current_user, login_required
|
||||
from api.db.db_models import APIToken
|
||||
@ -248,6 +250,64 @@ async def completion():
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
@manager.route("/sequence2txt", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def sequence2txt():
|
||||
req = await request.form
|
||||
stream_mode = req.get("stream", "false").lower() == "true"
|
||||
files = await request.files
|
||||
if "file" not in files:
|
||||
return get_data_error_result(message="Missing 'file' in multipart form-data")
|
||||
|
||||
uploaded = files["file"]
|
||||
|
||||
ALLOWED_EXTS = {
|
||||
".wav", ".mp3", ".m4a", ".aac",
|
||||
".flac", ".ogg", ".webm",
|
||||
".opus", ".wma"
|
||||
}
|
||||
|
||||
filename = uploaded.filename or ""
|
||||
suffix = os.path.splitext(filename)[-1].lower()
|
||||
if suffix not in ALLOWED_EXTS:
|
||||
return get_data_error_result(message=
|
||||
f"Unsupported audio format: {suffix}. "
|
||||
f"Allowed: {', '.join(sorted(ALLOWED_EXTS))}"
|
||||
)
|
||||
fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
|
||||
os.close(fd)
|
||||
await uploaded.save(temp_audio_path)
|
||||
|
||||
tenants = TenantService.get_info_by(current_user.id)
|
||||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
asr_id = tenants[0]["asr_id"]
|
||||
if not asr_id:
|
||||
return get_data_error_result(message="No default ASR model is set")
|
||||
|
||||
asr_mdl=LLMBundle(tenants[0]["tenant_id"], LLMType.SPEECH2TEXT, asr_id)
|
||||
if not stream_mode:
|
||||
text = asr_mdl.transcription(temp_audio_path)
|
||||
try:
|
||||
os.remove(temp_audio_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to remove temp audio file: {str(e)}")
|
||||
return get_json_result(data={"text": text})
|
||||
async def event_stream():
|
||||
try:
|
||||
for evt in asr_mdl.stream_transcription(temp_audio_path):
|
||||
yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
err = {"event": "error", "text": str(e)}
|
||||
yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
|
||||
finally:
|
||||
try:
|
||||
os.remove(temp_audio_path)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to remove temp audio file: {str(e)}")
|
||||
|
||||
return Response(event_stream(), content_type="text/event-stream")
|
||||
|
||||
@manager.route("/tts", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
|
||||
@ -77,7 +77,8 @@ async def convert():
|
||||
doc = DocumentService.insert({
|
||||
"id": get_uuid(),
|
||||
"kb_id": kb.id,
|
||||
"parser_id": FileService.get_parser(file.type, file.name, kb.parser_id),
|
||||
"parser_id": kb.parser_id,
|
||||
"pipeline_id": kb.pipeline_id,
|
||||
"parser_config": kb.parser_config,
|
||||
"created_by": current_user.id,
|
||||
"type": file.type,
|
||||
|
||||
@ -185,6 +185,66 @@ class LLMBundle(LLM4Tenant):
|
||||
|
||||
return txt
|
||||
|
||||
def stream_transcription(self, audio):
|
||||
mdl = self.mdl
|
||||
supports_stream = hasattr(mdl, "stream_transcription") and callable(getattr(mdl, "stream_transcription"))
|
||||
if supports_stream:
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(
|
||||
trace_context=self.trace_context,
|
||||
name="stream_transcription",
|
||||
metadata={"model": self.llm_name}
|
||||
)
|
||||
final_text = ""
|
||||
used_tokens = 0
|
||||
|
||||
try:
|
||||
for evt in mdl.stream_transcription(audio):
|
||||
if evt.get("event") == "final":
|
||||
final_text = evt.get("text", "")
|
||||
|
||||
yield evt
|
||||
|
||||
except Exception as e:
|
||||
err = {"event": "error", "text": str(e)}
|
||||
yield err
|
||||
final_text = final_text or ""
|
||||
finally:
|
||||
if final_text:
|
||||
used_tokens = num_tokens_from_string(final_text)
|
||||
TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens)
|
||||
|
||||
if self.langfuse:
|
||||
generation.update(
|
||||
output={"output": final_text},
|
||||
usage_details={"total_tokens": used_tokens}
|
||||
)
|
||||
generation.end()
|
||||
|
||||
return
|
||||
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name})
|
||||
full_text, used_tokens = mdl.transcription(audio)
|
||||
if not TenantLLMService.increase_usage(
|
||||
self.tenant_id, self.llm_type, used_tokens
|
||||
):
|
||||
logging.error(
|
||||
f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}"
|
||||
)
|
||||
if self.langfuse:
|
||||
generation.update(
|
||||
output={"output": full_text},
|
||||
usage_details={"total_tokens": used_tokens}
|
||||
)
|
||||
generation.end()
|
||||
|
||||
yield {
|
||||
"event": "final",
|
||||
"text": full_text,
|
||||
"streaming": False
|
||||
}
|
||||
|
||||
def tts(self, text: str) -> Generator[bytes, None, None]:
|
||||
if self.langfuse:
|
||||
generation = self.langfuse.start_generation(trace_context=self.trace_context, name="tts", input={"text": text})
|
||||
|
||||
@ -714,19 +714,13 @@
|
||||
"model_type": "rerank"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen-audio-asr",
|
||||
"llm_name": "qwen3-asr-flash",
|
||||
"tags": "SPEECH2TEXT,8k",
|
||||
"max_tokens": 8000,
|
||||
"model_type": "speech2text"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen-audio-asr-latest",
|
||||
"tags": "SPEECH2TEXT,8k",
|
||||
"max_tokens": 8000,
|
||||
"model_type": "speech2text"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen-audio-asr-1204",
|
||||
"llm_name": "qwen3-asr-flash-2025-09-08",
|
||||
"tags": "SPEECH2TEXT,8k",
|
||||
"max_tokens": 8000,
|
||||
"model_type": "speech2text"
|
||||
@ -1232,39 +1226,14 @@
|
||||
{
|
||||
"name": "MiniMax",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"tags": "LLM",
|
||||
"status": "1",
|
||||
"rank": "810",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "abab6.5-chat",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "abab6.5s-chat",
|
||||
"tags": "LLM,CHAT,245k",
|
||||
"max_tokens": 245760,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "abab6.5t-chat",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "abab6.5g-chat",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "abab5.5s-chat",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"llm_name": "MiniMax-M2",
|
||||
"tags": "LLM,CHAT,200k",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "chat"
|
||||
}
|
||||
]
|
||||
|
||||
@ -72,7 +72,7 @@ services:
|
||||
infinity:
|
||||
profiles:
|
||||
- infinity
|
||||
image: infiniflow/infinity:v0.6.7
|
||||
image: infiniflow/infinity:v0.6.8
|
||||
volumes:
|
||||
- infinity_data:/var/infinity
|
||||
- ./infinity_conf.toml:/infinity_conf.toml
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
[general]
|
||||
version = "0.6.7"
|
||||
version = "0.6.8"
|
||||
time_zone = "utc-8"
|
||||
|
||||
[network]
|
||||
|
||||
@ -19,48 +19,60 @@ Upgrading RAGFlow in itself will *not* remove your uploaded/historical data. How
|
||||
|
||||
To upgrade RAGFlow, you must upgrade **both** your code **and** your Docker image:
|
||||
|
||||
1. Clone the repo
|
||||
1. Stop the server
|
||||
|
||||
```bash
|
||||
git clone https://github.com/infiniflow/ragflow.git
|
||||
docker compose -f docker/docker-compose.yml down
|
||||
```
|
||||
|
||||
2. Update **ragflow/docker/.env**:
|
||||
2. Update the local code
|
||||
|
||||
```bash
|
||||
git pull
|
||||
```
|
||||
|
||||
3. Update **ragflow/docker/.env**:
|
||||
|
||||
```bash
|
||||
RAGFLOW_IMAGE=infiniflow/ragflow:nightly
|
||||
```
|
||||
|
||||
3. Update RAGFlow image and restart RAGFlow:
|
||||
4. Update RAGFlow image and restart RAGFlow:
|
||||
|
||||
```bash
|
||||
docker compose -f docker/docker-compose.yml pull
|
||||
docker compose -f docker/docker-compose.yml up -d
|
||||
```
|
||||
|
||||
## Upgrade RAGFlow to the most recent, officially published release
|
||||
## Upgrade RAGFlow to given release
|
||||
|
||||
To upgrade RAGFlow, you must upgrade **both** your code **and** your Docker image:
|
||||
|
||||
1. Clone the repo
|
||||
1. Stop the server
|
||||
|
||||
```bash
|
||||
git clone https://github.com/infiniflow/ragflow.git
|
||||
docker compose -f docker/docker-compose.yml down
|
||||
```
|
||||
|
||||
2. Switch to the latest, officially published release, e.g., `v0.22.1`:
|
||||
2. Update the local code
|
||||
|
||||
```bash
|
||||
git pull
|
||||
```
|
||||
|
||||
3. Switch to the latest, officially published release, e.g., `v0.22.1`:
|
||||
|
||||
```bash
|
||||
git checkout -f v0.22.1
|
||||
```
|
||||
|
||||
3. Update **ragflow/docker/.env**:
|
||||
4. Update **ragflow/docker/.env**:
|
||||
|
||||
```bash
|
||||
RAGFLOW_IMAGE=infiniflow/ragflow:v0.22.1
|
||||
```
|
||||
|
||||
4. Update the RAGFlow image and restart RAGFlow:
|
||||
5. Update the RAGFlow image and restart RAGFlow:
|
||||
|
||||
```bash
|
||||
docker compose -f docker/docker-compose.yml pull
|
||||
|
||||
@ -96,7 +96,7 @@ ragflow:
|
||||
infinity:
|
||||
image:
|
||||
repository: infiniflow/infinity
|
||||
tag: v0.6.7
|
||||
tag: v0.6.8
|
||||
pullPolicy: IfNotPresent
|
||||
pullSecrets: []
|
||||
storage:
|
||||
|
||||
@ -49,7 +49,7 @@ dependencies = [
|
||||
"html-text==0.6.2",
|
||||
"httpx[socks]>=0.28.1,<0.29.0",
|
||||
"huggingface-hub>=0.25.0,<0.26.0",
|
||||
"infinity-sdk==0.6.7",
|
||||
"infinity-sdk==0.6.8",
|
||||
"infinity-emb>=0.0.66,<0.0.67",
|
||||
"itsdangerous==2.1.2",
|
||||
"json-repair==0.35.0",
|
||||
@ -152,7 +152,9 @@ dependencies = [
|
||||
"moodlepy>=0.23.0",
|
||||
"pypandoc>=1.16",
|
||||
"pyobvector==0.2.18",
|
||||
"exceptiongroup>=1.3.0,<2.0.0"
|
||||
"exceptiongroup>=1.3.0,<2.0.0",
|
||||
"ffmpeg-python>=0.2.0",
|
||||
"imageio-ffmpeg>=0.6.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
@ -51,6 +51,7 @@ class SupportedLiteLLMProvider(StrEnum):
|
||||
AI_302 = "302.AI"
|
||||
JiekouAI = "Jiekou.AI"
|
||||
ZHIPU_AI = "ZHIPU-AI"
|
||||
MiniMax = "MiniMax"
|
||||
|
||||
|
||||
FACTORY_DEFAULT_BASE_URL = {
|
||||
@ -73,6 +74,7 @@ FACTORY_DEFAULT_BASE_URL = {
|
||||
SupportedLiteLLMProvider.Anthropic: "https://api.anthropic.com/",
|
||||
SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
|
||||
SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
|
||||
SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1",
|
||||
}
|
||||
|
||||
|
||||
@ -105,6 +107,7 @@ LITELLM_PROVIDER_PREFIX = {
|
||||
SupportedLiteLLMProvider.AI_302: "openai/",
|
||||
SupportedLiteLLMProvider.JiekouAI: "openai/",
|
||||
SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
|
||||
SupportedLiteLLMProvider.MiniMax: "openai/",
|
||||
}
|
||||
|
||||
ChatModel = globals().get("ChatModel", {})
|
||||
|
||||
@ -28,7 +28,6 @@ from urllib.parse import urljoin
|
||||
import json_repair
|
||||
import litellm
|
||||
import openai
|
||||
import requests
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
from strenum import StrEnum
|
||||
@ -1015,86 +1014,6 @@ class VolcEngineChat(Base):
|
||||
super().__init__(ark_api_key, model_name, base_url, **kwargs)
|
||||
|
||||
|
||||
class MiniMaxChat(Base):
|
||||
_FACTORY_NAME = "MiniMax"
|
||||
|
||||
def __init__(self, key, model_name, base_url="https://api.minimax.chat/v1/text/chatcompletion_v2", **kwargs):
|
||||
super().__init__(key, model_name, base_url=base_url, **kwargs)
|
||||
|
||||
if not base_url:
|
||||
base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
|
||||
self.base_url = base_url
|
||||
self.model_name = model_name
|
||||
self.api_key = key
|
||||
|
||||
def _clean_conf(self, gen_conf):
|
||||
for k in list(gen_conf.keys()):
|
||||
if k not in ["temperature", "top_p", "max_tokens"]:
|
||||
del gen_conf[k]
|
||||
return gen_conf
|
||||
|
||||
def _chat(self, history, gen_conf):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = json.dumps({"model": self.model_name, "messages": history, **gen_conf})
|
||||
response = requests.request("POST", url=self.base_url, headers=headers, data=payload)
|
||||
response = response.json()
|
||||
ans = response["choices"][0]["message"]["content"].strip()
|
||||
if response["choices"][0]["finish_reason"] == "length":
|
||||
if is_chinese(ans):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
for k in list(gen_conf.keys()):
|
||||
if k not in ["temperature", "top_p", "max_tokens"]:
|
||||
del gen_conf[k]
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = json.dumps(
|
||||
{
|
||||
"model": self.model_name,
|
||||
"messages": history,
|
||||
"stream": True,
|
||||
**gen_conf,
|
||||
}
|
||||
)
|
||||
response = requests.request(
|
||||
"POST",
|
||||
url=self.base_url,
|
||||
headers=headers,
|
||||
data=payload,
|
||||
)
|
||||
for resp in response.text.split("\n\n")[:-1]:
|
||||
resp = json.loads(resp[6:])
|
||||
text = ""
|
||||
if "choices" in resp and "delta" in resp["choices"][0]:
|
||||
text = resp["choices"][0]["delta"]["content"]
|
||||
ans = text
|
||||
tol = total_token_count_from_response(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(text)
|
||||
else:
|
||||
total_tokens = tol
|
||||
yield ans
|
||||
|
||||
except Exception as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
yield total_tokens
|
||||
|
||||
|
||||
class MistralChat(Base):
|
||||
_FACTORY_NAME = "Mistral"
|
||||
|
||||
@ -1642,6 +1561,7 @@ class LiteLLMBase(ABC):
|
||||
"302.AI",
|
||||
"Jiekou.AI",
|
||||
"ZHIPU-AI",
|
||||
"MiniMax",
|
||||
]
|
||||
|
||||
def __init__(self, key, model_name, base_url=None, **kwargs):
|
||||
|
||||
@ -19,6 +19,7 @@ import json
|
||||
import os
|
||||
import re
|
||||
from abc import ABC
|
||||
import tempfile
|
||||
|
||||
import requests
|
||||
from openai import OpenAI
|
||||
@ -68,32 +69,80 @@ class QWenSeq2txt(Base):
|
||||
self.model_name = model_name
|
||||
|
||||
def transcription(self, audio_path):
|
||||
if "paraformer" in self.model_name or "sensevoice" in self.model_name:
|
||||
return f"**ERROR**: model {self.model_name} is not suppported yet.", 0
|
||||
import dashscope
|
||||
|
||||
from dashscope import MultiModalConversation
|
||||
if audio_path.startswith("http"):
|
||||
audio_input = audio_path
|
||||
else:
|
||||
audio_input = f"file://{audio_path}"
|
||||
|
||||
audio_path = f"file://{audio_path}"
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"text": ""}]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"audio": audio_path}],
|
||||
"content": [{"audio": audio_input}]
|
||||
}
|
||||
]
|
||||
|
||||
response = None
|
||||
full_content = ""
|
||||
try:
|
||||
response = MultiModalConversation.call(model="qwen-audio-asr", messages=messages, result_format="message", stream=True)
|
||||
for response in response:
|
||||
try:
|
||||
full_content += response["output"]["choices"][0]["message"].content[0]["text"]
|
||||
except Exception:
|
||||
pass
|
||||
return full_content, num_tokens_from_string(full_content)
|
||||
except Exception as e:
|
||||
return "**ERROR**: " + str(e), 0
|
||||
resp = dashscope.MultiModalConversation.call(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
result_format="message",
|
||||
asr_options={
|
||||
"enable_lid": True,
|
||||
"enable_itn": False
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
text = resp["output"]["choices"][0]["message"].content[0]["text"]
|
||||
except Exception as e:
|
||||
text = "**ERROR**: " + str(e)
|
||||
return text, num_tokens_from_string(text)
|
||||
|
||||
def stream_transcription(self, audio_path):
|
||||
import dashscope
|
||||
|
||||
if audio_path.startswith("http"):
|
||||
audio_input = audio_path
|
||||
else:
|
||||
audio_input = f"file://{audio_path}"
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"text": ""}]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"audio": audio_input}]
|
||||
}
|
||||
]
|
||||
|
||||
stream = dashscope.MultiModalConversation.call(
|
||||
model=self.model_name,
|
||||
messages=messages,
|
||||
result_format="message",
|
||||
stream=True,
|
||||
asr_options={
|
||||
"enable_lid": True,
|
||||
"enable_itn": False
|
||||
}
|
||||
)
|
||||
|
||||
full = ""
|
||||
for chunk in stream:
|
||||
try:
|
||||
piece = chunk["output"]["choices"][0]["message"].content[0]["text"]
|
||||
full = piece
|
||||
yield {"event": "delta", "text": piece}
|
||||
except Exception as e:
|
||||
yield {"event": "error", "text": str(e)}
|
||||
|
||||
yield {"event": "final", "text": full}
|
||||
|
||||
class AzureSeq2txt(Base):
|
||||
_FACTORY_NAME = "Azure-OpenAI"
|
||||
@ -268,6 +317,27 @@ class ZhipuSeq2txt(Base):
|
||||
self.gen_conf = kwargs.get("gen_conf", {})
|
||||
self.stream = kwargs.get("stream", False)
|
||||
|
||||
def _convert_to_wav(self, input_path):
|
||||
ext = os.path.splitext(input_path)[1].lower()
|
||||
if ext in [".wav", ".mp3"]:
|
||||
return input_path
|
||||
fd, out_path = tempfile.mkstemp(suffix=".wav")
|
||||
os.close(fd)
|
||||
try:
|
||||
import ffmpeg
|
||||
import imageio_ffmpeg as ffmpeg_exe
|
||||
ffmpeg_path = ffmpeg_exe.get_ffmpeg_exe()
|
||||
(
|
||||
ffmpeg
|
||||
.input(input_path)
|
||||
.output(out_path, ar=16000, ac=1)
|
||||
.overwrite_output()
|
||||
.run(cmd=ffmpeg_path,quiet=True)
|
||||
)
|
||||
return out_path
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"audio convert failed: {e}")
|
||||
|
||||
def transcription(self, audio_path):
|
||||
payload = {
|
||||
"model": self.model_name,
|
||||
@ -276,7 +346,9 @@ class ZhipuSeq2txt(Base):
|
||||
}
|
||||
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
with open(audio_path, "rb") as audio_file:
|
||||
converted = self._convert_to_wav(audio_path)
|
||||
|
||||
with open(converted, "rb") as audio_file:
|
||||
files = {"file": audio_file}
|
||||
|
||||
try:
|
||||
|
||||
@ -19,7 +19,6 @@ import random
|
||||
from collections import Counter
|
||||
|
||||
from common.token_utils import num_tokens_from_string
|
||||
from . import rag_tokenizer
|
||||
import re
|
||||
import copy
|
||||
import roman_numbers as r
|
||||
@ -29,6 +28,8 @@ from PIL import Image
|
||||
|
||||
import chardet
|
||||
|
||||
__all__ = ['rag_tokenizer']
|
||||
|
||||
all_codecs = [
|
||||
'utf-8', 'gb2312', 'gbk', 'utf_16', 'ascii', 'big5', 'big5hkscs',
|
||||
'cp037', 'cp273', 'cp424', 'cp437',
|
||||
@ -265,6 +266,7 @@ def is_chinese(text):
|
||||
|
||||
|
||||
def tokenize(d, txt, eng):
|
||||
from . import rag_tokenizer
|
||||
d["content_with_weight"] = txt
|
||||
t = re.sub(r"</?(table|td|caption|tr|th)( [^<>]{0,12})?>", " ", txt)
|
||||
d["content_ltks"] = rag_tokenizer.tokenize(t)
|
||||
@ -362,6 +364,7 @@ def attach_media_context(chunks, table_context_size=0, image_context_size=0):
|
||||
Best-effort ordering: if positional info exists on any chunk, use it to
|
||||
order chunks before collecting context; otherwise keep original order.
|
||||
"""
|
||||
from . import rag_tokenizer
|
||||
if not chunks or (table_context_size <= 0 and image_context_size <= 0):
|
||||
return chunks
|
||||
|
||||
|
||||
@ -14,455 +14,23 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import copy
|
||||
import datrie
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
from hanziconv import HanziConv
|
||||
from nltk import word_tokenize
|
||||
from nltk.stem import PorterStemmer, WordNetLemmatizer
|
||||
from common.file_utils import get_project_base_directory
|
||||
import infinity.rag_tokenizer
|
||||
from common import settings
|
||||
|
||||
|
||||
class RagTokenizer:
|
||||
def key_(self, line):
|
||||
return str(line.lower().encode("utf-8"))[2:-1]
|
||||
|
||||
def rkey_(self, line):
|
||||
return str(("DD" + (line[::-1].lower())).encode("utf-8"))[2:-1]
|
||||
|
||||
def _load_dict(self, fnm):
|
||||
logging.info(f"[HUQIE]:Build trie from {fnm}")
|
||||
try:
|
||||
of = open(fnm, "r", encoding="utf-8")
|
||||
while True:
|
||||
line = of.readline()
|
||||
if not line:
|
||||
break
|
||||
line = re.sub(r"[\r\n]+", "", line)
|
||||
line = re.split(r"[ \t]", line)
|
||||
k = self.key_(line[0])
|
||||
F = int(math.log(float(line[1]) / self.DENOMINATOR) + 0.5)
|
||||
if k not in self.trie_ or self.trie_[k][0] < F:
|
||||
self.trie_[self.key_(line[0])] = (F, line[2])
|
||||
self.trie_[self.rkey_(line[0])] = 1
|
||||
|
||||
dict_file_cache = fnm + ".trie"
|
||||
logging.info(f"[HUQIE]:Build trie cache to {dict_file_cache}")
|
||||
self.trie_.save(dict_file_cache)
|
||||
of.close()
|
||||
except Exception:
|
||||
logging.exception(f"[HUQIE]:Build trie {fnm} failed")
|
||||
|
||||
def __init__(self, debug=False):
|
||||
self.DEBUG = debug
|
||||
self.DENOMINATOR = 1000000
|
||||
self.DIR_ = os.path.join(get_project_base_directory(), "rag/res", "huqie")
|
||||
|
||||
self.stemmer = PorterStemmer()
|
||||
self.lemmatizer = WordNetLemmatizer()
|
||||
|
||||
self.SPLIT_CHAR = r"([ ,\.<>/?;:'\[\]\\`!@#$%^&*\(\)\{\}\|_+=《》,。?、;‘’:“”【】~!¥%……()——-]+|[a-zA-Z0-9,\.-]+)"
|
||||
|
||||
trie_file_name = self.DIR_ + ".txt.trie"
|
||||
# check if trie file existence
|
||||
if os.path.exists(trie_file_name):
|
||||
try:
|
||||
# load trie from file
|
||||
self.trie_ = datrie.Trie.load(trie_file_name)
|
||||
return
|
||||
except Exception:
|
||||
# fail to load trie from file, build default trie
|
||||
logging.exception(f"[HUQIE]:Fail to load trie file {trie_file_name}, build the default trie file")
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
else:
|
||||
# file not exist, build default trie
|
||||
logging.info(f"[HUQIE]:Trie file {trie_file_name} not found, build the default trie file")
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
|
||||
# load data from dict file and save to trie file
|
||||
self._load_dict(self.DIR_ + ".txt")
|
||||
|
||||
def load_user_dict(self, fnm):
|
||||
try:
|
||||
self.trie_ = datrie.Trie.load(fnm + ".trie")
|
||||
return
|
||||
except Exception:
|
||||
self.trie_ = datrie.Trie(string.printable)
|
||||
self._load_dict(fnm)
|
||||
|
||||
def add_user_dict(self, fnm):
|
||||
self._load_dict(fnm)
|
||||
|
||||
def _strQ2B(self, ustring):
|
||||
"""Convert full-width characters to half-width characters"""
|
||||
rstring = ""
|
||||
for uchar in ustring:
|
||||
inside_code = ord(uchar)
|
||||
if inside_code == 0x3000:
|
||||
inside_code = 0x0020
|
||||
else:
|
||||
inside_code -= 0xFEE0
|
||||
if inside_code < 0x0020 or inside_code > 0x7E: # After the conversion, if it's not a half-width character, return the original character.
|
||||
rstring += uchar
|
||||
else:
|
||||
rstring += chr(inside_code)
|
||||
return rstring
|
||||
|
||||
def _tradi2simp(self, line):
|
||||
return HanziConv.toSimplified(line)
|
||||
|
||||
def dfs_(self, chars, s, preTks, tkslist, _depth=0, _memo=None):
|
||||
if _memo is None:
|
||||
_memo = {}
|
||||
MAX_DEPTH = 10
|
||||
if _depth > MAX_DEPTH:
|
||||
if s < len(chars):
|
||||
copy_pretks = copy.deepcopy(preTks)
|
||||
remaining = "".join(chars[s:])
|
||||
copy_pretks.append((remaining, (-12, "")))
|
||||
tkslist.append(copy_pretks)
|
||||
return s
|
||||
|
||||
state_key = (s, tuple(tk[0] for tk in preTks)) if preTks else (s, None)
|
||||
if state_key in _memo:
|
||||
return _memo[state_key]
|
||||
|
||||
res = s
|
||||
if s >= len(chars):
|
||||
tkslist.append(preTks)
|
||||
_memo[state_key] = s
|
||||
return s
|
||||
if s < len(chars) - 4:
|
||||
is_repetitive = True
|
||||
char_to_check = chars[s]
|
||||
for i in range(1, 5):
|
||||
if s + i >= len(chars) or chars[s + i] != char_to_check:
|
||||
is_repetitive = False
|
||||
break
|
||||
if is_repetitive:
|
||||
end = s
|
||||
while end < len(chars) and chars[end] == char_to_check:
|
||||
end += 1
|
||||
mid = s + min(10, end - s)
|
||||
t = "".join(chars[s:mid])
|
||||
k = self.key_(t)
|
||||
copy_pretks = copy.deepcopy(preTks)
|
||||
if k in self.trie_:
|
||||
copy_pretks.append((t, self.trie_[k]))
|
||||
else:
|
||||
copy_pretks.append((t, (-12, "")))
|
||||
next_res = self.dfs_(chars, mid, copy_pretks, tkslist, _depth + 1, _memo)
|
||||
res = max(res, next_res)
|
||||
_memo[state_key] = res
|
||||
return res
|
||||
|
||||
S = s + 1
|
||||
if s + 2 <= len(chars):
|
||||
t1 = "".join(chars[s : s + 1])
|
||||
t2 = "".join(chars[s : s + 2])
|
||||
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)):
|
||||
S = s + 2
|
||||
if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1:
|
||||
t1 = preTks[-1][0] + "".join(chars[s : s + 1])
|
||||
if self.trie_.has_keys_with_prefix(self.key_(t1)):
|
||||
S = s + 2
|
||||
|
||||
for e in range(S, len(chars) + 1):
|
||||
t = "".join(chars[s:e])
|
||||
k = self.key_(t)
|
||||
if e > s + 1 and not self.trie_.has_keys_with_prefix(k):
|
||||
break
|
||||
if k in self.trie_:
|
||||
pretks = copy.deepcopy(preTks)
|
||||
pretks.append((t, self.trie_[k]))
|
||||
res = max(res, self.dfs_(chars, e, pretks, tkslist, _depth + 1, _memo))
|
||||
|
||||
if res > s:
|
||||
_memo[state_key] = res
|
||||
return res
|
||||
|
||||
t = "".join(chars[s : s + 1])
|
||||
k = self.key_(t)
|
||||
copy_pretks = copy.deepcopy(preTks)
|
||||
if k in self.trie_:
|
||||
copy_pretks.append((t, self.trie_[k]))
|
||||
else:
|
||||
copy_pretks.append((t, (-12, "")))
|
||||
result = self.dfs_(chars, s + 1, copy_pretks, tkslist, _depth + 1, _memo)
|
||||
_memo[state_key] = result
|
||||
return result
|
||||
|
||||
def freq(self, tk):
|
||||
k = self.key_(tk)
|
||||
if k not in self.trie_:
|
||||
return 0
|
||||
return int(math.exp(self.trie_[k][0]) * self.DENOMINATOR + 0.5)
|
||||
|
||||
def tag(self, tk):
|
||||
k = self.key_(tk)
|
||||
if k not in self.trie_:
|
||||
return ""
|
||||
return self.trie_[k][1]
|
||||
|
||||
def score_(self, tfts):
|
||||
B = 30
|
||||
F, L, tks = 0, 0, []
|
||||
for tk, (freq, tag) in tfts:
|
||||
F += freq
|
||||
L += 0 if len(tk) < 2 else 1
|
||||
tks.append(tk)
|
||||
# F /= len(tks)
|
||||
L /= len(tks)
|
||||
logging.debug("[SC] {} {} {} {} {}".format(tks, len(tks), L, F, B / len(tks) + L + F))
|
||||
return tks, B / len(tks) + L + F
|
||||
|
||||
def _sort_tokens(self, tkslist):
|
||||
res = []
|
||||
for tfts in tkslist:
|
||||
tks, s = self.score_(tfts)
|
||||
res.append((tks, s))
|
||||
return sorted(res, key=lambda x: x[1], reverse=True)
|
||||
|
||||
def merge_(self, tks):
|
||||
# if split chars is part of token
|
||||
res = []
|
||||
tks = re.sub(r"[ ]+", " ", tks).split()
|
||||
s = 0
|
||||
while True:
|
||||
if s >= len(tks):
|
||||
break
|
||||
E = s + 1
|
||||
for e in range(s + 2, min(len(tks) + 2, s + 6)):
|
||||
tk = "".join(tks[s:e])
|
||||
if re.search(self.SPLIT_CHAR, tk) and self.freq(tk):
|
||||
E = e
|
||||
res.append("".join(tks[s:E]))
|
||||
s = E
|
||||
|
||||
return " ".join(res)
|
||||
|
||||
def _max_forward(self, line):
|
||||
res = []
|
||||
s = 0
|
||||
while s < len(line):
|
||||
e = s + 1
|
||||
t = line[s:e]
|
||||
while e < len(line) and self.trie_.has_keys_with_prefix(self.key_(t)):
|
||||
e += 1
|
||||
t = line[s:e]
|
||||
|
||||
while e - 1 > s and self.key_(t) not in self.trie_:
|
||||
e -= 1
|
||||
t = line[s:e]
|
||||
|
||||
if self.key_(t) in self.trie_:
|
||||
res.append((t, self.trie_[self.key_(t)]))
|
||||
else:
|
||||
res.append((t, (0, "")))
|
||||
|
||||
s = e
|
||||
|
||||
return self.score_(res)
|
||||
|
||||
def _max_backward(self, line):
|
||||
res = []
|
||||
s = len(line) - 1
|
||||
while s >= 0:
|
||||
e = s + 1
|
||||
t = line[s:e]
|
||||
while s > 0 and self.trie_.has_keys_with_prefix(self.rkey_(t)):
|
||||
s -= 1
|
||||
t = line[s:e]
|
||||
|
||||
while s + 1 < e and self.key_(t) not in self.trie_:
|
||||
s += 1
|
||||
t = line[s:e]
|
||||
|
||||
if self.key_(t) in self.trie_:
|
||||
res.append((t, self.trie_[self.key_(t)]))
|
||||
else:
|
||||
res.append((t, (0, "")))
|
||||
|
||||
s -= 1
|
||||
|
||||
return self.score_(res[::-1])
|
||||
|
||||
def english_normalize_(self, tks):
|
||||
return [self.stemmer.stem(self.lemmatizer.lemmatize(t)) if re.match(r"[a-zA-Z_-]+$", t) else t for t in tks]
|
||||
|
||||
def _split_by_lang(self, line):
|
||||
txt_lang_pairs = []
|
||||
arr = re.split(self.SPLIT_CHAR, line)
|
||||
for a in arr:
|
||||
if not a:
|
||||
continue
|
||||
s = 0
|
||||
e = s + 1
|
||||
zh = is_chinese(a[s])
|
||||
while e < len(a):
|
||||
_zh = is_chinese(a[e])
|
||||
if _zh == zh:
|
||||
e += 1
|
||||
continue
|
||||
txt_lang_pairs.append((a[s:e], zh))
|
||||
s = e
|
||||
e = s + 1
|
||||
zh = _zh
|
||||
if s >= len(a):
|
||||
continue
|
||||
txt_lang_pairs.append((a[s:e], zh))
|
||||
return txt_lang_pairs
|
||||
class RagTokenizer(infinity.rag_tokenizer.RagTokenizer):
|
||||
|
||||
def tokenize(self, line: str) -> str:
|
||||
if settings.DOC_ENGINE_INFINITY:
|
||||
return line
|
||||
line = re.sub(r"\W+", " ", line)
|
||||
line = self._strQ2B(line).lower()
|
||||
line = self._tradi2simp(line)
|
||||
|
||||
arr = self._split_by_lang(line)
|
||||
res = []
|
||||
for L, lang in arr:
|
||||
if not lang:
|
||||
res.extend([self.stemmer.stem(self.lemmatizer.lemmatize(t)) for t in word_tokenize(L)])
|
||||
continue
|
||||
if len(L) < 2 or re.match(r"[a-z\.-]+$", L) or re.match(r"[0-9\.-]+$", L):
|
||||
res.append(L)
|
||||
continue
|
||||
|
||||
# use maxforward for the first time
|
||||
tks, s = self._max_forward(L)
|
||||
tks1, s1 = self._max_backward(L)
|
||||
if self.DEBUG:
|
||||
logging.debug("[FW] {} {}".format(tks, s))
|
||||
logging.debug("[BW] {} {}".format(tks1, s1))
|
||||
|
||||
i, j, _i, _j = 0, 0, 0, 0
|
||||
same = 0
|
||||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||
same += 1
|
||||
if same > 0:
|
||||
res.append(" ".join(tks[j : j + same]))
|
||||
_i = i + same
|
||||
_j = j + same
|
||||
j = _j + 1
|
||||
i = _i + 1
|
||||
|
||||
while i < len(tks1) and j < len(tks):
|
||||
tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
|
||||
if tk1 != tk:
|
||||
if len(tk1) > len(tk):
|
||||
j += 1
|
||||
else:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if tks1[i] != tks[j]:
|
||||
i += 1
|
||||
j += 1
|
||||
continue
|
||||
# backward tokens from_i to i are different from forward tokens from _j to j.
|
||||
tkslist = []
|
||||
self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
|
||||
res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
|
||||
|
||||
same = 1
|
||||
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
|
||||
same += 1
|
||||
res.append(" ".join(tks[j : j + same]))
|
||||
_i = i + same
|
||||
_j = j + same
|
||||
j = _j + 1
|
||||
i = _i + 1
|
||||
|
||||
if _i < len(tks1):
|
||||
assert _j < len(tks)
|
||||
assert "".join(tks1[_i:]) == "".join(tks[_j:])
|
||||
tkslist = []
|
||||
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
|
||||
res.append(" ".join(self._sort_tokens(tkslist)[0][0]))
|
||||
|
||||
res = " ".join(res)
|
||||
logging.debug("[TKS] {}".format(self.merge_(res)))
|
||||
return self.merge_(res)
|
||||
else:
|
||||
return super().tokenize(line)
|
||||
|
||||
def fine_grained_tokenize(self, tks: str) -> str:
|
||||
if settings.DOC_ENGINE_INFINITY:
|
||||
return tks
|
||||
tks = tks.split()
|
||||
zh_num = len([1 for c in tks if c and is_chinese(c[0])])
|
||||
if zh_num < len(tks) * 0.2:
|
||||
res = []
|
||||
for tk in tks:
|
||||
res.extend(tk.split("/"))
|
||||
return " ".join(res)
|
||||
|
||||
res = []
|
||||
for tk in tks:
|
||||
if len(tk) < 3 or re.match(r"[0-9,\.-]+$", tk):
|
||||
res.append(tk)
|
||||
continue
|
||||
tkslist = []
|
||||
if len(tk) > 10:
|
||||
tkslist.append(tk)
|
||||
else:
|
||||
self.dfs_(tk, 0, [], tkslist)
|
||||
if len(tkslist) < 2:
|
||||
res.append(tk)
|
||||
continue
|
||||
stk = self._sort_tokens(tkslist)[1][0]
|
||||
if len(stk) == len(tk):
|
||||
stk = tk
|
||||
else:
|
||||
if re.match(r"[a-z\.-]+$", tk):
|
||||
for t in stk:
|
||||
if len(t) < 3:
|
||||
stk = tk
|
||||
break
|
||||
else:
|
||||
stk = " ".join(stk)
|
||||
else:
|
||||
stk = " ".join(stk)
|
||||
|
||||
res.append(stk)
|
||||
|
||||
return " ".join(self.english_normalize_(res))
|
||||
|
||||
|
||||
def is_chinese(s):
|
||||
if s >= "\u4e00" and s <= "\u9fa5":
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_number(s):
|
||||
if s >= "\u0030" and s <= "\u0039":
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_alphabet(s):
|
||||
if ("\u0041" <= s <= "\u005a") or ("\u0061" <= s <= "\u007a"):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def naive_qie(txt):
|
||||
tks = []
|
||||
for t in txt.split():
|
||||
if tks and re.match(r".*[a-zA-Z]$", tks[-1]) and re.match(r".*[a-zA-Z]$", t):
|
||||
tks.append(" ")
|
||||
tks.append(t)
|
||||
return tks
|
||||
else:
|
||||
return super().fine_grained_tokenize(tks)
|
||||
|
||||
|
||||
tokenizer = RagTokenizer()
|
||||
@ -470,40 +38,5 @@ tokenize = tokenizer.tokenize
|
||||
fine_grained_tokenize = tokenizer.fine_grained_tokenize
|
||||
tag = tokenizer.tag
|
||||
freq = tokenizer.freq
|
||||
load_user_dict = tokenizer.load_user_dict
|
||||
add_user_dict = tokenizer.add_user_dict
|
||||
tradi2simp = tokenizer._tradi2simp
|
||||
strQ2B = tokenizer._strQ2B
|
||||
|
||||
if __name__ == "__main__":
|
||||
tknzr = RagTokenizer(debug=True)
|
||||
# huqie.add_user_dict("/tmp/tmp.new.tks.dict")
|
||||
texts = [
|
||||
"over_the_past.pdf",
|
||||
"哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈哈",
|
||||
"公开征求意见稿提出,境外投资者可使用自有人民币或外汇投资。使用外汇投资的,可通过债券持有人在香港人民币业务清算行及香港地区经批准可进入境内银行间外汇市场进行交易的境外人民币业务参加行(以下统称香港结算行)办理外汇资金兑换。香港结算行由此所产生的头寸可到境内银行间外汇市场平盘。使用外汇投资的,在其投资的债券到期或卖出后,原则上应兑换回外汇。",
|
||||
"多校划片就是一个小区对应多个小学初中,让买了学区房的家庭也不确定到底能上哪个学校。目的是通过这种方式为学区房降温,把就近入学落到实处。南京市长江大桥",
|
||||
"实际上当时他们已经将业务中心偏移到安全部门和针对政府企业的部门 Scripts are compiled and cached aaaaaaaaa",
|
||||
"虽然我不怎么玩",
|
||||
"蓝月亮如何在外资夹击中生存,那是全宇宙最有意思的",
|
||||
"涡轮增压发动机num最大功率,不像别的共享买车锁电子化的手段,我们接过来是否有意义,黄黄爱美食,不过,今天阿奇要讲到的这家农贸市场,说实话,还真蛮有特色的!不仅环境好,还打出了",
|
||||
"这周日你去吗?这周日你有空吗?",
|
||||
"Unity3D开发经验 测试开发工程师 c++双11双11 985 211 ",
|
||||
"数据分析项目经理|数据分析挖掘|数据分析方向|商品数据分析|搜索数据分析 sql python hive tableau Cocos2d-",
|
||||
]
|
||||
for text in texts:
|
||||
print(text)
|
||||
tks1 = tknzr.tokenize(text)
|
||||
tks2 = tknzr.fine_grained_tokenize(tks1)
|
||||
print(tks1)
|
||||
print(tks2)
|
||||
if len(sys.argv) < 2:
|
||||
sys.exit()
|
||||
tknzr.load_user_dict(sys.argv[1])
|
||||
of = open(sys.argv[2], "r")
|
||||
while True:
|
||||
line = of.readline()
|
||||
if not line:
|
||||
break
|
||||
print(tknzr.tokenize(line))
|
||||
of.close()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { useDeleteMessage, useFeedback } from '@/hooks/chat-hooks';
|
||||
import { useSetModalState } from '@/hooks/common-hooks';
|
||||
import { IRemoveMessageById, useSpeechWithSse } from '@/hooks/logic-hooks';
|
||||
import { useDeleteMessage, useFeedback } from '@/hooks/use-chat-request';
|
||||
import { IFeedbackRequestBody } from '@/interfaces/request/chat';
|
||||
import { hexStringToUint8Array } from '@/utils/common-util';
|
||||
import { SpeechPlayer } from 'openai-speech-stream-player';
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import { useDeleteMessage, useFeedback } from '@/hooks/chat-hooks';
|
||||
// import { useDeleteMessage, useFeedback } from '@/hooks/chat-hooks';
|
||||
import { useSetModalState } from '@/hooks/common-hooks';
|
||||
import { IRemoveMessageById, useSpeechWithSse } from '@/hooks/logic-hooks';
|
||||
import { useDeleteMessage, useFeedback } from '@/hooks/use-chat-request';
|
||||
import { IFeedbackRequestBody } from '@/interfaces/request/chat';
|
||||
import { hexStringToUint8Array } from '@/utils/common-util';
|
||||
import { SpeechPlayer } from 'openai-speech-stream-player';
|
||||
|
||||
@ -1,642 +0,0 @@
|
||||
import { ChatSearchParams } from '@/constants/chat';
|
||||
import {
|
||||
IClientConversation,
|
||||
IConversation,
|
||||
IDialog,
|
||||
IStats,
|
||||
IToken,
|
||||
} from '@/interfaces/database/chat';
|
||||
import {
|
||||
IAskRequestBody,
|
||||
IFeedbackRequestBody,
|
||||
} from '@/interfaces/request/chat';
|
||||
import i18n from '@/locales/config';
|
||||
import { useGetSharedChatSearchParams } from '@/pages/next-chats/hooks/use-send-shared-message';
|
||||
import chatService from '@/services/chat-service';
|
||||
import {
|
||||
buildMessageListWithUuid,
|
||||
getConversationId,
|
||||
isConversationIdExist,
|
||||
} from '@/utils/chat';
|
||||
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
|
||||
import { message } from 'antd';
|
||||
import dayjs, { Dayjs } from 'dayjs';
|
||||
import { has, set } from 'lodash';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { history, useSearchParams } from 'umi';
|
||||
|
||||
//#region logic
|
||||
|
||||
export const useClickDialogCard = () => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const [_, setSearchParams] = useSearchParams();
|
||||
|
||||
const newQueryParameters: URLSearchParams = useMemo(() => {
|
||||
return new URLSearchParams();
|
||||
}, []);
|
||||
|
||||
const handleClickDialog = useCallback(
|
||||
(dialogId: string) => {
|
||||
newQueryParameters.set(ChatSearchParams.DialogId, dialogId);
|
||||
// newQueryParameters.set(
|
||||
// ChatSearchParams.ConversationId,
|
||||
// EmptyConversationId,
|
||||
// );
|
||||
setSearchParams(newQueryParameters);
|
||||
},
|
||||
[newQueryParameters, setSearchParams],
|
||||
);
|
||||
|
||||
return { handleClickDialog };
|
||||
};
|
||||
|
||||
export const useClickConversationCard = () => {
|
||||
const [currentQueryParameters, setSearchParams] = useSearchParams();
|
||||
const newQueryParameters: URLSearchParams = useMemo(
|
||||
() => new URLSearchParams(currentQueryParameters.toString()),
|
||||
[currentQueryParameters],
|
||||
);
|
||||
|
||||
const handleClickConversation = useCallback(
|
||||
(conversationId: string, isNew: string) => {
|
||||
newQueryParameters.set(ChatSearchParams.ConversationId, conversationId);
|
||||
newQueryParameters.set(ChatSearchParams.isNew, isNew);
|
||||
setSearchParams(newQueryParameters);
|
||||
},
|
||||
[setSearchParams, newQueryParameters],
|
||||
);
|
||||
|
||||
return { handleClickConversation };
|
||||
};
|
||||
|
||||
export const useGetChatSearchParams = () => {
|
||||
const [currentQueryParameters] = useSearchParams();
|
||||
|
||||
return {
|
||||
dialogId: currentQueryParameters.get(ChatSearchParams.DialogId) || '',
|
||||
conversationId:
|
||||
currentQueryParameters.get(ChatSearchParams.ConversationId) || '',
|
||||
isNew: currentQueryParameters.get(ChatSearchParams.isNew) || '',
|
||||
};
|
||||
};
|
||||
|
||||
//#endregion
|
||||
|
||||
//#region dialog
|
||||
|
||||
export const useFetchNextDialogList = (pureFetch = false) => {
|
||||
const { handleClickDialog } = useClickDialogCard();
|
||||
const { dialogId } = useGetChatSearchParams();
|
||||
|
||||
const {
|
||||
data,
|
||||
isFetching: loading,
|
||||
refetch,
|
||||
} = useQuery<IDialog[]>({
|
||||
queryKey: ['fetchDialogList'],
|
||||
initialData: [],
|
||||
gcTime: 0,
|
||||
refetchOnWindowFocus: false,
|
||||
queryFn: async (...params) => {
|
||||
console.log('🚀 ~ queryFn: ~ params:', params);
|
||||
const { data } = await chatService.listDialog();
|
||||
|
||||
if (data.code === 0) {
|
||||
const list: IDialog[] = data.data;
|
||||
if (!pureFetch) {
|
||||
if (list.length > 0) {
|
||||
if (list.every((x) => x.id !== dialogId)) {
|
||||
handleClickDialog(data.data[0].id);
|
||||
}
|
||||
} else {
|
||||
history.push('/chat');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return data?.data ?? [];
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, refetch };
|
||||
};
|
||||
|
||||
export const useFetchChatAppList = () => {
|
||||
const {
|
||||
data,
|
||||
isFetching: loading,
|
||||
refetch,
|
||||
} = useQuery<IDialog[]>({
|
||||
queryKey: ['fetchChatAppList'],
|
||||
initialData: [],
|
||||
gcTime: 0,
|
||||
refetchOnWindowFocus: false,
|
||||
queryFn: async () => {
|
||||
const { data } = await chatService.listDialog();
|
||||
|
||||
return data?.data ?? [];
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, refetch };
|
||||
};
|
||||
|
||||
export const useSetNextDialog = () => {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: ['setDialog'],
|
||||
mutationFn: async (params: IDialog) => {
|
||||
const { data } = await chatService.setDialog(params);
|
||||
if (data.code === 0) {
|
||||
queryClient.invalidateQueries({
|
||||
exact: false,
|
||||
queryKey: ['fetchDialogList'],
|
||||
});
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: ['fetchDialog'],
|
||||
});
|
||||
message.success(
|
||||
i18n.t(`message.${params.dialog_id ? 'modified' : 'created'}`),
|
||||
);
|
||||
}
|
||||
return data?.code;
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, setDialog: mutateAsync };
|
||||
};
|
||||
|
||||
export const useFetchNextDialog = () => {
|
||||
const { dialogId } = useGetChatSearchParams();
|
||||
|
||||
const {
|
||||
data,
|
||||
isFetching: loading,
|
||||
refetch,
|
||||
} = useQuery<IDialog>({
|
||||
queryKey: ['fetchDialog', dialogId],
|
||||
gcTime: 0,
|
||||
initialData: {} as IDialog,
|
||||
enabled: !!dialogId,
|
||||
refetchOnWindowFocus: false,
|
||||
queryFn: async () => {
|
||||
const { data } = await chatService.getDialog({ dialogId });
|
||||
|
||||
return data?.data ?? ({} as IDialog);
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, refetch };
|
||||
};
|
||||
|
||||
export const useFetchManualDialog = () => {
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: ['fetchManualDialog'],
|
||||
gcTime: 0,
|
||||
mutationFn: async (dialogId: string) => {
|
||||
const { data } = await chatService.getDialog({ dialogId });
|
||||
|
||||
return data;
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, fetchDialog: mutateAsync };
|
||||
};
|
||||
|
||||
export const useRemoveNextDialog = () => {
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: ['removeDialog'],
|
||||
mutationFn: async (dialogIds: string[]) => {
|
||||
const { data } = await chatService.removeDialog({ dialogIds });
|
||||
if (data.code === 0) {
|
||||
queryClient.invalidateQueries({ queryKey: ['fetchDialogList'] });
|
||||
|
||||
message.success(i18n.t('message.deleted'));
|
||||
}
|
||||
return data.code;
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, removeDialog: mutateAsync };
|
||||
};
|
||||
|
||||
//#endregion
|
||||
|
||||
//#region conversation
|
||||
|
||||
export const useFetchNextConversationList = () => {
|
||||
const { dialogId } = useGetChatSearchParams();
|
||||
const { handleClickConversation } = useClickConversationCard();
|
||||
const {
|
||||
data,
|
||||
isFetching: loading,
|
||||
refetch,
|
||||
} = useQuery<IConversation[]>({
|
||||
queryKey: ['fetchConversationList', dialogId],
|
||||
initialData: [],
|
||||
gcTime: 0,
|
||||
refetchOnWindowFocus: false,
|
||||
enabled: !!dialogId,
|
||||
queryFn: async () => {
|
||||
const { data } = await chatService.listConversation({ dialogId });
|
||||
if (data.code === 0) {
|
||||
if (data.data.length > 0) {
|
||||
handleClickConversation(data.data[0].id, '');
|
||||
} else {
|
||||
handleClickConversation('', '');
|
||||
}
|
||||
}
|
||||
return data?.data;
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, refetch };
|
||||
};
|
||||
|
||||
export const useFetchNextConversation = () => {
|
||||
const { isNew, conversationId } = useGetChatSearchParams();
|
||||
const { sharedId } = useGetSharedChatSearchParams();
|
||||
const {
|
||||
data,
|
||||
isFetching: loading,
|
||||
refetch,
|
||||
} = useQuery<IClientConversation>({
|
||||
queryKey: ['fetchConversation', conversationId],
|
||||
initialData: {} as IClientConversation,
|
||||
// enabled: isConversationIdExist(conversationId),
|
||||
gcTime: 0,
|
||||
refetchOnWindowFocus: false,
|
||||
queryFn: async () => {
|
||||
if (
|
||||
isNew !== 'true' &&
|
||||
isConversationIdExist(sharedId || conversationId)
|
||||
) {
|
||||
const { data } = await chatService.getConversation({
|
||||
conversationId: conversationId || sharedId,
|
||||
});
|
||||
|
||||
const conversation = data?.data ?? {};
|
||||
|
||||
const messageList = buildMessageListWithUuid(conversation?.message);
|
||||
|
||||
return { ...conversation, message: messageList };
|
||||
}
|
||||
return { message: [] };
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, refetch };
|
||||
};
|
||||
|
||||
export const useFetchNextConversationSSE = () => {
|
||||
const { isNew } = useGetChatSearchParams();
|
||||
const { sharedId } = useGetSharedChatSearchParams();
|
||||
const {
|
||||
data,
|
||||
isFetching: loading,
|
||||
refetch,
|
||||
} = useQuery<IClientConversation>({
|
||||
queryKey: ['fetchConversationSSE', sharedId],
|
||||
initialData: {} as IClientConversation,
|
||||
gcTime: 0,
|
||||
refetchOnWindowFocus: false,
|
||||
queryFn: async () => {
|
||||
if (isNew !== 'true' && isConversationIdExist(sharedId || '')) {
|
||||
if (!sharedId) return {};
|
||||
const { data } = await chatService.getConversationSSE({}, sharedId);
|
||||
const conversation = data?.data ?? {};
|
||||
const messageList = buildMessageListWithUuid(conversation?.message);
|
||||
return { ...conversation, message: messageList };
|
||||
}
|
||||
return { message: [] };
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, refetch };
|
||||
};
|
||||
|
||||
export const useFetchManualConversation = () => {
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: ['fetchManualConversation'],
|
||||
gcTime: 0,
|
||||
mutationFn: async (conversationId: string) => {
|
||||
const { data } = await chatService.getConversation({ conversationId });
|
||||
|
||||
return data;
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, fetchConversation: mutateAsync };
|
||||
};
|
||||
|
||||
export const useUpdateNextConversation = () => {
|
||||
const queryClient = useQueryClient();
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: ['updateConversation'],
|
||||
mutationFn: async (params: Record<string, any>) => {
|
||||
const { data } = await chatService.setConversation({
|
||||
...params,
|
||||
conversation_id: params.conversation_id
|
||||
? params.conversation_id
|
||||
: getConversationId(),
|
||||
});
|
||||
if (data.code === 0) {
|
||||
queryClient.invalidateQueries({ queryKey: ['fetchConversationList'] });
|
||||
message.success(i18n.t(`message.modified`));
|
||||
}
|
||||
return data;
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, updateConversation: mutateAsync };
|
||||
};
|
||||
|
||||
export const useRemoveNextConversation = () => {
|
||||
const queryClient = useQueryClient();
|
||||
const { dialogId } = useGetChatSearchParams();
|
||||
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: ['removeConversation'],
|
||||
mutationFn: async (conversationIds: string[]) => {
|
||||
const { data } = await chatService.removeConversation({
|
||||
conversationIds,
|
||||
dialogId,
|
||||
});
|
||||
if (data.code === 0) {
|
||||
queryClient.invalidateQueries({ queryKey: ['fetchConversationList'] });
|
||||
}
|
||||
return data.code;
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, removeConversation: mutateAsync };
|
||||
};
|
||||
|
||||
export const useDeleteMessage = () => {
|
||||
const { conversationId } = useGetChatSearchParams();
|
||||
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: ['deleteMessage'],
|
||||
mutationFn: async (messageId: string) => {
|
||||
const { data } = await chatService.deleteMessage({
|
||||
messageId,
|
||||
conversationId,
|
||||
});
|
||||
|
||||
if (data.code === 0) {
|
||||
message.success(i18n.t(`message.deleted`));
|
||||
}
|
||||
|
||||
return data.code;
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, deleteMessage: mutateAsync };
|
||||
};
|
||||
|
||||
export const useFeedback = () => {
|
||||
const { conversationId } = useGetChatSearchParams();
|
||||
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: ['feedback'],
|
||||
mutationFn: async (params: IFeedbackRequestBody) => {
|
||||
const { data } = await chatService.thumbup({
|
||||
...params,
|
||||
conversationId,
|
||||
});
|
||||
if (data.code === 0) {
|
||||
message.success(i18n.t(`message.operated`));
|
||||
}
|
||||
return data.code;
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, feedback: mutateAsync };
|
||||
};
|
||||
|
||||
//#endregion
|
||||
|
||||
// #region API provided for external calls
|
||||
|
||||
export const useCreateNextToken = () => {
|
||||
const queryClient = useQueryClient();
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: ['createToken'],
|
||||
mutationFn: async (params: Record<string, any>) => {
|
||||
const { data } = await chatService.createToken(params);
|
||||
if (data.code === 0) {
|
||||
queryClient.invalidateQueries({ queryKey: ['fetchTokenList'] });
|
||||
}
|
||||
return data?.data ?? [];
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, createToken: mutateAsync };
|
||||
};
|
||||
|
||||
export const useFetchTokenList = (params: Record<string, any>) => {
|
||||
const {
|
||||
data,
|
||||
isFetching: loading,
|
||||
refetch,
|
||||
} = useQuery<IToken[]>({
|
||||
queryKey: ['fetchTokenList', params],
|
||||
initialData: [],
|
||||
gcTime: 0,
|
||||
queryFn: async () => {
|
||||
const { data } = await chatService.listToken(params);
|
||||
|
||||
return data?.data ?? [];
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, refetch };
|
||||
};
|
||||
|
||||
export const useRemoveNextToken = () => {
|
||||
const queryClient = useQueryClient();
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: ['removeToken'],
|
||||
mutationFn: async (params: {
|
||||
tenantId: string;
|
||||
dialogId?: string;
|
||||
tokens: string[];
|
||||
}) => {
|
||||
const { data } = await chatService.removeToken(params);
|
||||
if (data.code === 0) {
|
||||
queryClient.invalidateQueries({ queryKey: ['fetchTokenList'] });
|
||||
}
|
||||
return data?.data ?? [];
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, removeToken: mutateAsync };
|
||||
};
|
||||
|
||||
type RangeValue = [Dayjs | null, Dayjs | null] | null;
|
||||
|
||||
const getDay = (date?: Dayjs) => date?.format('YYYY-MM-DD');
|
||||
|
||||
export const useFetchNextStats = () => {
|
||||
const [pickerValue, setPickerValue] = useState<RangeValue>([
|
||||
dayjs().subtract(7, 'day'),
|
||||
dayjs(),
|
||||
]);
|
||||
const { data, isFetching: loading } = useQuery<IStats>({
|
||||
queryKey: ['fetchStats', pickerValue],
|
||||
initialData: {} as IStats,
|
||||
gcTime: 0,
|
||||
queryFn: async () => {
|
||||
if (Array.isArray(pickerValue) && pickerValue[0]) {
|
||||
const { data } = await chatService.getStats({
|
||||
fromDate: getDay(pickerValue[0]),
|
||||
toDate: getDay(pickerValue[1] ?? dayjs()),
|
||||
});
|
||||
return data?.data ?? {};
|
||||
}
|
||||
return {};
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, pickerValue, setPickerValue };
|
||||
};
|
||||
|
||||
//#endregion
|
||||
|
||||
//#region shared chat
|
||||
|
||||
export const useCreateNextSharedConversation = () => {
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: ['createSharedConversation'],
|
||||
mutationFn: async (userId?: string) => {
|
||||
const { data } = await chatService.createExternalConversation({ userId });
|
||||
|
||||
return data;
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, createSharedConversation: mutateAsync };
|
||||
};
|
||||
|
||||
// deprecated
|
||||
export const useFetchNextSharedConversation = (
|
||||
conversationId?: string | null,
|
||||
) => {
|
||||
const { data, isPending: loading } = useQuery({
|
||||
queryKey: ['fetchSharedConversation'],
|
||||
enabled: !!conversationId,
|
||||
queryFn: async () => {
|
||||
if (!conversationId) {
|
||||
return {};
|
||||
}
|
||||
const { data } = await chatService.getExternalConversation(
|
||||
null,
|
||||
conversationId,
|
||||
);
|
||||
|
||||
const messageList = buildMessageListWithUuid(data?.data?.message);
|
||||
|
||||
set(data, 'data.message', messageList);
|
||||
|
||||
return data;
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading };
|
||||
};
|
||||
|
||||
//#endregion
|
||||
|
||||
//#region search page
|
||||
|
||||
export const useFetchMindMap = () => {
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: ['fetchMindMap'],
|
||||
gcTime: 0,
|
||||
mutationFn: async (params: IAskRequestBody) => {
|
||||
try {
|
||||
const ret = await chatService.getMindMap(params);
|
||||
return ret?.data?.data ?? {};
|
||||
} catch (error: any) {
|
||||
if (has(error, 'message')) {
|
||||
message.error(error.message);
|
||||
}
|
||||
|
||||
return [];
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, fetchMindMap: mutateAsync };
|
||||
};
|
||||
|
||||
export const useFetchRelatedQuestions = () => {
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: ['fetchRelatedQuestions'],
|
||||
gcTime: 0,
|
||||
mutationFn: async (question: string): Promise<string[]> => {
|
||||
const { data } = await chatService.getRelatedQuestions({ question });
|
||||
|
||||
return data?.data ?? [];
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, fetchRelatedQuestions: mutateAsync };
|
||||
};
|
||||
//#endregion
|
||||
@ -7,7 +7,11 @@ import {
|
||||
IDialog,
|
||||
IExternalChatInfo,
|
||||
} from '@/interfaces/database/chat';
|
||||
import { IAskRequestBody } from '@/interfaces/request/chat';
|
||||
import {
|
||||
IAskRequestBody,
|
||||
IFeedbackRequestBody,
|
||||
} from '@/interfaces/request/chat';
|
||||
import i18n from '@/locales/config';
|
||||
import { useGetSharedChatSearchParams } from '@/pages/next-chats/hooks/use-send-shared-message';
|
||||
import { isConversationIdExist } from '@/pages/next-chats/utils';
|
||||
import chatService from '@/services/next-chat-service';
|
||||
@ -39,6 +43,9 @@ export const enum ChatApiAction {
|
||||
FetchRelatedQuestions = 'fetchRelatedQuestions',
|
||||
UploadAndParse = 'upload_and_parse',
|
||||
FetchExternalChatInfo = 'fetchExternalChatInfo',
|
||||
Feedback = 'feedback',
|
||||
CreateSharedConversation = 'createSharedConversation',
|
||||
FetchConversationSse = 'fetchConversationSSE',
|
||||
}
|
||||
|
||||
export const useGetChatSearchParams = () => {
|
||||
@ -397,6 +404,30 @@ export const useDeleteMessage = () => {
|
||||
return { data, loading, deleteMessage: mutateAsync };
|
||||
};
|
||||
|
||||
export const useFeedback = () => {
|
||||
const { conversationId } = useGetChatSearchParams();
|
||||
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: [ChatApiAction.Feedback],
|
||||
mutationFn: async (params: IFeedbackRequestBody) => {
|
||||
const { data } = await chatService.thumbup({
|
||||
...params,
|
||||
conversationId,
|
||||
});
|
||||
if (data.code === 0) {
|
||||
message.success(i18n.t(`message.operated`));
|
||||
}
|
||||
return data.code;
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, feedback: mutateAsync };
|
||||
};
|
||||
|
||||
type UploadParameters = Parameters<NonNullable<FileUploadProps['onUpload']>>;
|
||||
|
||||
type X = {
|
||||
@ -532,3 +563,47 @@ export const useFetchRelatedQuestions = () => {
|
||||
return { data, loading, fetchRelatedQuestions: mutateAsync };
|
||||
};
|
||||
//#endregion
|
||||
|
||||
export const useCreateNextSharedConversation = () => {
|
||||
const {
|
||||
data,
|
||||
isPending: loading,
|
||||
mutateAsync,
|
||||
} = useMutation({
|
||||
mutationKey: [ChatApiAction.CreateSharedConversation],
|
||||
mutationFn: async (userId?: string) => {
|
||||
const { data } = await chatService.createExternalConversation({ userId });
|
||||
|
||||
return data;
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, createSharedConversation: mutateAsync };
|
||||
};
|
||||
|
||||
export const useFetchNextConversationSSE = () => {
|
||||
const { isNew } = useGetChatSearchParams();
|
||||
const { sharedId } = useGetSharedChatSearchParams();
|
||||
const {
|
||||
data,
|
||||
isFetching: loading,
|
||||
refetch,
|
||||
} = useQuery<IClientConversation>({
|
||||
queryKey: [ChatApiAction.FetchConversationSse, sharedId],
|
||||
initialData: {} as IClientConversation,
|
||||
gcTime: 0,
|
||||
refetchOnWindowFocus: false,
|
||||
queryFn: async () => {
|
||||
if (isNew !== 'true' && isConversationIdExist(sharedId || '')) {
|
||||
if (!sharedId) return {};
|
||||
const { data } = await chatService.getConversationSSE(sharedId);
|
||||
const conversation = data?.data ?? {};
|
||||
const messageList = buildMessageListWithUuid(conversation?.message);
|
||||
return { ...conversation, message: messageList };
|
||||
}
|
||||
return { message: [] };
|
||||
},
|
||||
});
|
||||
|
||||
return { data, loading, refetch };
|
||||
};
|
||||
|
||||
@ -628,6 +628,10 @@ export default {
|
||||
'Für chinesische Benutzer ist keine Eingabe erforderlich oder verwenden Sie https://dashscope.aliyuncs.com/compatible-mode/v1. Für internationale Benutzer verwenden Sie https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder:
|
||||
'(Nur für internationale Benutzer, bitte Hinweis beachten)',
|
||||
minimaxBaseUrlTip:
|
||||
'Nur für internationale Nutzer: https://api.minimax.io/v1 verwenden.',
|
||||
minimaxBaseUrlPlaceholder:
|
||||
'(Nur für internationale Benutzer, https://api.minimax.io/v1 eintragen)',
|
||||
modify: 'Ändern',
|
||||
systemModelSettings: 'Standardmodelle festlegen',
|
||||
chatModel: 'Chat-Modell',
|
||||
|
||||
@ -858,6 +858,10 @@ Example: Virtual Hosted Style`,
|
||||
tongyiBaseUrlTip:
|
||||
'For Chinese users, no need to fill in or use https://dashscope.aliyuncs.com/compatible-mode/v1. For international users, use https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder: '(International users only, please see tip)',
|
||||
minimaxBaseUrlTip:
|
||||
'International users only: use https://api.minimax.io/v1',
|
||||
minimaxBaseUrlPlaceholder:
|
||||
'(International users only, fill in https://api.minimax.io/v1)',
|
||||
modify: 'Modify',
|
||||
systemModelSettings: 'Set default models',
|
||||
chatModel: 'LLM',
|
||||
|
||||
@ -344,6 +344,10 @@ export default {
|
||||
'Para usuarios chinos, no es necesario rellenar o usar https://dashscope.aliyuncs.com/compatible-mode/v1. Para usuarios internacionales, usar https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder:
|
||||
'(Solo para usuarios internacionales, por favor ver consejo)',
|
||||
minimaxBaseUrlTip:
|
||||
'Solo usuarios internacionales: utilice https://api.minimax.io/v1.',
|
||||
minimaxBaseUrlPlaceholder:
|
||||
'(Solo usuarios internacionales, ingrese https://api.minimax.io/v1)',
|
||||
modify: 'Modificar',
|
||||
systemModelSettings: 'Establecer modelos predeterminados',
|
||||
chatModel: 'Modelo de chat',
|
||||
|
||||
@ -526,6 +526,10 @@ export default {
|
||||
'Pour les utilisateurs chinois, pas besoin de remplir ou utiliser https://dashscope.aliyuncs.com/compatible-mode/v1. Pour les utilisateurs internationaux, utilisez https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder:
|
||||
"(Utilisateurs internationaux uniquement, veuillez consulter l'astuce)",
|
||||
minimaxBaseUrlTip:
|
||||
'Utilisateurs internationaux uniquement : utilisez https://api.minimax.io/v1.',
|
||||
minimaxBaseUrlPlaceholder:
|
||||
'(Utilisateurs internationaux uniquement, renseignez https://api.minimax.io/v1)',
|
||||
modify: 'Modifier',
|
||||
systemModelSettings: 'Définir les modèles par défaut',
|
||||
chatModel: 'Modèle de chat',
|
||||
|
||||
@ -516,6 +516,10 @@ export default {
|
||||
'Untuk pengguna Tiongkok, tidak perlu diisi atau gunakan https://dashscope.aliyuncs.com/compatible-mode/v1. Untuk pengguna internasional, gunakan https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder:
|
||||
'(Hanya untuk pengguna internasional, silakan lihat tip)',
|
||||
minimaxBaseUrlTip:
|
||||
'Hanya untuk pengguna internasional: gunakan https://api.minimax.io/v1.',
|
||||
minimaxBaseUrlPlaceholder:
|
||||
'(Hanya untuk pengguna internasional, isi https://api.minimax.io/v1)',
|
||||
modify: 'Ubah',
|
||||
systemModelSettings: 'Tetapkan model default',
|
||||
chatModel: 'Model Obrolan',
|
||||
|
||||
@ -557,6 +557,10 @@ export default {
|
||||
tongyiBaseUrlTip:
|
||||
'中国ユーザーの場合、記入不要または https://dashscope.aliyuncs.com/compatible-mode/v1 を使用してください。国際ユーザーは https://dashscope-intl.aliyuncs.com/compatible-mode/v1 を使用してください',
|
||||
tongyiBaseUrlPlaceholder: '(国際ユーザーのみ、ヒントをご覧ください)',
|
||||
minimaxBaseUrlTip:
|
||||
'国際ユーザーのみ:https://api.minimax.io/v1 を使用してください。',
|
||||
minimaxBaseUrlPlaceholder:
|
||||
'(国際ユーザーのみ、https://api.minimax.io/v1 を入力してください)',
|
||||
modify: '変更',
|
||||
systemModelSettings: 'デフォルトモデルを設定する',
|
||||
chatModel: 'チャットモデル',
|
||||
|
||||
@ -508,6 +508,10 @@ export default {
|
||||
'Para usuários chineses, não é necessário preencher ou usar https://dashscope.aliyuncs.com/compatible-mode/v1. Para usuários internacionais, use https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder:
|
||||
'(Apenas para usuários internacionais, consulte a dica)',
|
||||
minimaxBaseUrlTip:
|
||||
'Somente usuários internacionais: use https://api.minimax.io/v1.',
|
||||
minimaxBaseUrlPlaceholder:
|
||||
'(Somente para usuários internacionais, preencha https://api.minimax.io/v1)',
|
||||
modify: 'Modificar',
|
||||
systemModelSettings: 'Definir modelos padrão',
|
||||
chatModel: 'Modelo de chat',
|
||||
|
||||
@ -846,6 +846,10 @@ export default {
|
||||
'Для китайских пользователей не нужно заполнять или используйте https://dashscope.aliyuncs.com/compatible-mode/v1. Для международных пользователей используйте https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder:
|
||||
'(Только для международных пользователей, см. подсказку)',
|
||||
minimaxBaseUrlTip:
|
||||
'Только для международных пользователей: используйте https://api.minimax.io/v1.',
|
||||
minimaxBaseUrlPlaceholder:
|
||||
'(Только для международных пользователей, введите https://api.minimax.io/v1)',
|
||||
modify: 'Изменить',
|
||||
systemModelSettings: 'Установить модели по умолчанию',
|
||||
chatModel: 'LLM',
|
||||
|
||||
@ -558,6 +558,10 @@ export default {
|
||||
baseUrl: 'Base-Url',
|
||||
baseUrlTip:
|
||||
'Nếu khóa API của bạn từ OpenAI, chỉ cần bỏ qua nó. Bất kỳ nhà cung cấp trung gian nào khác sẽ cung cấp URL cơ sở này với khóa API.',
|
||||
minimaxBaseUrlTip:
|
||||
'Chỉ người dùng quốc tế: dùng https://api.minimax.io/v1.',
|
||||
minimaxBaseUrlPlaceholder:
|
||||
'(Chỉ dành cho người dùng quốc tế, điền https://api.minimax.io/v1)',
|
||||
modify: 'Sửa đổi',
|
||||
systemModelSettings: 'Đặt mô hình mặc định',
|
||||
chatModel: 'Mô hình trò chuyện',
|
||||
|
||||
@ -596,6 +596,8 @@ export default {
|
||||
tongyiBaseUrlTip:
|
||||
'中國用戶無需填寫或使用 https://dashscope.aliyuncs.com/compatible-mode/v1。國際用戶請使用 https://dashscope-intl.aliyuncs.com/compatible-mode/v1',
|
||||
tongyiBaseUrlPlaceholder: '(僅國際用戶,請參閱提示)',
|
||||
minimaxBaseUrlTip: '僅國際用戶:使用 https://api.minimax.io/v1。',
|
||||
minimaxBaseUrlPlaceholder: '(僅國際用戶填寫 https://api.minimax.io/v1)',
|
||||
modify: '修改',
|
||||
systemModelSettings: '設定預設模型',
|
||||
chatModel: '聊天模型',
|
||||
|
||||
@ -813,6 +813,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
|
||||
tongyiBaseUrlTip:
|
||||
'对于中国用户,不需要填写或使用 https://dashscope.aliyuncs.com/compatible-mode/v1。对于国际用户,使用 https://dashscope-intl.aliyuncs.com/compatible-mode/v1。',
|
||||
tongyiBaseUrlPlaceholder: '(仅国际用户需要)',
|
||||
minimaxBaseUrlTip: '仅国际用户:使用 https://api.minimax.io/v1。',
|
||||
minimaxBaseUrlPlaceholder: '(仅国际用户填写 https://api.minimax.io/v1)',
|
||||
modify: '修改',
|
||||
systemModelSettings: '设置默认模型',
|
||||
chatModel: 'LLM',
|
||||
|
||||
@ -1,10 +1,10 @@
|
||||
import { MessageType, SharedFrom } from '@/constants/chat';
|
||||
import { useCreateNextSharedConversation } from '@/hooks/chat-hooks';
|
||||
import {
|
||||
useHandleMessageInputChange,
|
||||
useSelectDerivedMessages,
|
||||
useSendMessageWithSse,
|
||||
} from '@/hooks/logic-hooks';
|
||||
import { useCreateNextSharedConversation } from '@/hooks/use-chat-request';
|
||||
import { Message } from '@/interfaces/database/chat';
|
||||
import { message } from 'antd';
|
||||
import { get } from 'lodash';
|
||||
|
||||
@ -5,9 +5,11 @@ import PdfSheet from '@/components/pdf-drawer';
|
||||
import { useClickDrawer } from '@/components/pdf-drawer/hooks';
|
||||
import { useSyncThemeFromParams } from '@/components/theme-provider';
|
||||
import { MessageType, SharedFrom } from '@/constants/chat';
|
||||
import { useFetchNextConversationSSE } from '@/hooks/chat-hooks';
|
||||
import { useFetchFlowSSE } from '@/hooks/flow-hooks';
|
||||
import { useFetchExternalChatInfo } from '@/hooks/use-chat-request';
|
||||
import {
|
||||
useFetchExternalChatInfo,
|
||||
useFetchNextConversationSSE,
|
||||
} from '@/hooks/use-chat-request';
|
||||
import i18n from '@/locales/config';
|
||||
import { buildMessageUuidWithRole } from '@/utils/chat';
|
||||
import React, { forwardRef, useMemo } from 'react';
|
||||
|
||||
@ -12,8 +12,8 @@ import { ResponsePostType } from '@/interfaces/database/base';
|
||||
import { IAnswer } from '@/interfaces/database/chat';
|
||||
import { ITestingResult } from '@/interfaces/database/knowledge';
|
||||
import { IAskRequestBody } from '@/interfaces/request/chat';
|
||||
import chatService from '@/services/chat-service';
|
||||
import kbService from '@/services/knowledge-service';
|
||||
import chatService from '@/services/next-chat-service';
|
||||
import searchService from '@/services/search-service';
|
||||
import api from '@/utils/api';
|
||||
import { useMutation } from '@tanstack/react-query';
|
||||
|
||||
@ -34,6 +34,7 @@ const modelsWithBaseUrl = [
|
||||
LLMFactory.OpenAI,
|
||||
LLMFactory.AzureOpenAI,
|
||||
LLMFactory.TongYiQianWen,
|
||||
LLMFactory.MiniMax,
|
||||
];
|
||||
|
||||
const ApiKeyModal = ({
|
||||
@ -109,7 +110,16 @@ const ApiKeyModal = ({
|
||||
name="base_url"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel className="text-sm font-medium text-text-primary">
|
||||
<FormLabel
|
||||
className="text-sm font-medium text-text-primary"
|
||||
tooltip={
|
||||
llmFactory === LLMFactory.MiniMax
|
||||
? t('minimaxBaseUrlTip')
|
||||
: llmFactory === LLMFactory.TongYiQianWen
|
||||
? t('tongyiBaseUrlTip')
|
||||
: t('baseUrlTip')
|
||||
}
|
||||
>
|
||||
{t('baseUrl')}
|
||||
</FormLabel>
|
||||
<FormControl>
|
||||
@ -118,7 +128,9 @@ const ApiKeyModal = ({
|
||||
placeholder={
|
||||
llmFactory === LLMFactory.TongYiQianWen
|
||||
? t('tongyiBaseUrlPlaceholder')
|
||||
: 'https://api.openai.com/v1'
|
||||
: llmFactory === LLMFactory.MiniMax
|
||||
? t('minimaxBaseUrlPlaceholder')
|
||||
: 'https://api.openai.com/v1'
|
||||
}
|
||||
onKeyDown={handleKeyDown}
|
||||
className="w-full"
|
||||
|
||||
@ -1,133 +0,0 @@
|
||||
import api from '@/utils/api';
|
||||
import registerServer from '@/utils/register-server';
|
||||
import request from '@/utils/request';
|
||||
|
||||
const {
|
||||
getDialog,
|
||||
setDialog,
|
||||
listDialog,
|
||||
removeDialog,
|
||||
getConversation,
|
||||
getConversationSSE,
|
||||
setConversation,
|
||||
completeConversation,
|
||||
listConversation,
|
||||
removeConversation,
|
||||
createToken,
|
||||
listToken,
|
||||
removeToken,
|
||||
getStats,
|
||||
createExternalConversation,
|
||||
getExternalConversation,
|
||||
completeExternalConversation,
|
||||
uploadAndParseExternal,
|
||||
deleteMessage,
|
||||
thumbup,
|
||||
tts,
|
||||
ask,
|
||||
mindmap,
|
||||
getRelatedQuestions,
|
||||
} = api;
|
||||
|
||||
const methods = {
|
||||
getDialog: {
|
||||
url: getDialog,
|
||||
method: 'get',
|
||||
},
|
||||
setDialog: {
|
||||
url: setDialog,
|
||||
method: 'post',
|
||||
},
|
||||
removeDialog: {
|
||||
url: removeDialog,
|
||||
method: 'post',
|
||||
},
|
||||
listDialog: {
|
||||
url: listDialog,
|
||||
method: 'get',
|
||||
},
|
||||
listConversation: {
|
||||
url: listConversation,
|
||||
method: 'get',
|
||||
},
|
||||
getConversation: {
|
||||
url: getConversation,
|
||||
method: 'get',
|
||||
},
|
||||
getConversationSSE: {
|
||||
url: getConversationSSE,
|
||||
method: 'get',
|
||||
},
|
||||
setConversation: {
|
||||
url: setConversation,
|
||||
method: 'post',
|
||||
},
|
||||
completeConversation: {
|
||||
url: completeConversation,
|
||||
method: 'post',
|
||||
},
|
||||
removeConversation: {
|
||||
url: removeConversation,
|
||||
method: 'post',
|
||||
},
|
||||
createToken: {
|
||||
url: createToken,
|
||||
method: 'post',
|
||||
},
|
||||
listToken: {
|
||||
url: listToken,
|
||||
method: 'get',
|
||||
},
|
||||
removeToken: {
|
||||
url: removeToken,
|
||||
method: 'post',
|
||||
},
|
||||
getStats: {
|
||||
url: getStats,
|
||||
method: 'get',
|
||||
},
|
||||
createExternalConversation: {
|
||||
url: createExternalConversation,
|
||||
method: 'get',
|
||||
},
|
||||
getExternalConversation: {
|
||||
url: getExternalConversation,
|
||||
method: 'get',
|
||||
},
|
||||
completeExternalConversation: {
|
||||
url: completeExternalConversation,
|
||||
method: 'post',
|
||||
},
|
||||
uploadAndParseExternal: {
|
||||
url: uploadAndParseExternal,
|
||||
method: 'post',
|
||||
},
|
||||
deleteMessage: {
|
||||
url: deleteMessage,
|
||||
method: 'post',
|
||||
},
|
||||
thumbup: {
|
||||
url: thumbup,
|
||||
method: 'post',
|
||||
},
|
||||
tts: {
|
||||
url: tts,
|
||||
method: 'post',
|
||||
},
|
||||
ask: {
|
||||
url: ask,
|
||||
method: 'post',
|
||||
},
|
||||
getMindMap: {
|
||||
url: mindmap,
|
||||
method: 'post',
|
||||
},
|
||||
getRelatedQuestions: {
|
||||
url: getRelatedQuestions,
|
||||
method: 'post',
|
||||
},
|
||||
} as const;
|
||||
|
||||
const chatService = registerServer<keyof typeof methods>(methods, request);
|
||||
|
||||
export default chatService;
|
||||
@ -119,7 +119,8 @@ export default {
|
||||
listDialog: `${api_host}/dialog/list`,
|
||||
setConversation: `${api_host}/conversation/set`,
|
||||
getConversation: `${api_host}/conversation/get`,
|
||||
getConversationSSE: `${api_host}/conversation/getsse`,
|
||||
getConversationSSE: (dialogId: string) =>
|
||||
`${api_host}/conversation/getsse/${dialogId}`,
|
||||
listConversation: `${api_host}/conversation/list`,
|
||||
removeConversation: `${api_host}/conversation/rm`,
|
||||
completeConversation: `${api_host}/conversation/completion`,
|
||||
|
||||
Reference in New Issue
Block a user