Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-08 20:42:30 +08:00

Compare commits

11 Commits
| SHA1 |
|---|
| 09a3854ed8 |
| 43f51baa96 |
| 5a2011e687 |
| 7dd9ce0b5f |
| b66881a371 |
| 4d7934061e |
| 660fa8888b |
| 3285f09c92 |
| 51ec708c58 |
| 9b8971a9de |
| 6546f86b4e |
```diff
@@ -78,12 +78,12 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
 # A modern version of cargo is needed for the latest version of the Rust compiler.
 RUN apt update && apt install -y curl build-essential \
     && if [ "$NEED_MIRROR" == "1" ]; then \
-    # Use TUNA mirrors for rustup/rust dist files
+    # Use TUNA mirrors for rustup/rust dist files \
     export RUSTUP_DIST_SERVER="https://mirrors.tuna.tsinghua.edu.cn/rustup"; \
     export RUSTUP_UPDATE_ROOT="https://mirrors.tuna.tsinghua.edu.cn/rustup/rustup"; \
     echo "Using TUNA mirrors for Rustup."; \
     fi; \
-    # Force curl to use HTTP/1.1
+    # Force curl to use HTTP/1.1 \
     curl --proto '=https' --tlsv1.2 --http1.1 -sSf https://sh.rustup.rs | bash -s -- -y --profile minimal \
     && echo 'export PATH="/root/.cargo/bin:${PATH}"' >> /root/.bashrc


@@ -478,7 +478,7 @@ class Canvas(Graph):
                 })
         await _run_batch(idx, to)
         to = len(self.path)
-        # post processing of components invocation
+        # post-processing of components invocation
         for i in range(idx, to):
            cpn = self.get_component(self.path[i])
            cpn_obj = self.get_component_obj(self.path[i])
```
```diff
@@ -75,7 +75,7 @@ class YahooFinance(ToolBase, ABC):
     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
     def _invoke(self, **kwargs):
         if self.check_if_canceled("YahooFinance processing"):
-            return
+            return None

         if not kwargs.get("stock_code"):
             self.set_output("report", "")

@@ -84,33 +84,33 @@ class YahooFinance(ToolBase, ABC):
         last_e = ""
         for _ in range(self._param.max_retries+1):
             if self.check_if_canceled("YahooFinance processing"):
-                return
+                return None

-            yohoo_res = []
+            yahoo_res = []
             try:
                 msft = yf.Ticker(kwargs["stock_code"])
                 if self.check_if_canceled("YahooFinance processing"):
-                    return
+                    return None

                 if self._param.info:
-                    yohoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
+                    yahoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
                 if self._param.history:
-                    yohoo_res.append("# History:\n" + msft.history().to_markdown() + "\n")
+                    yahoo_res.append("# History:\n" + msft.history().to_markdown() + "\n")
                 if self._param.financials:
-                    yohoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n")
+                    yahoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n")
                 if self._param.balance_sheet:
-                    yohoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n")
-                    yohoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n")
+                    yahoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n")
+                    yahoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n")
                 if self._param.cash_flow_statement:
-                    yohoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n")
-                    yohoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n")
+                    yahoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n")
+                    yahoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n")
                 if self._param.news:
-                    yohoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n")
-                self.set_output("report", "\n\n".join(yohoo_res))
+                    yahoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n")
+                self.set_output("report", "\n\n".join(yahoo_res))
                 return self.output("report")
             except Exception as e:
                 if self.check_if_canceled("YahooFinance processing"):
-                    return
+                    return None

                 last_e = e
                 logging.exception(f"YahooFinance error: {e}")
```
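Beyond the `yohoo`→`yahoo` spelling fix, the tool keeps the retry-with-cancellation shape intact. As a minimal sketch of that pattern (the names `fetch`, `max_retries`, and `check_if_canceled` are stand-ins, not the actual ToolBase API):

```python
import logging

def fetch_with_retries(fetch, max_retries, check_if_canceled):
    """Retry a flaky fetch, checking for upstream cancellation between attempts."""
    last_e = None
    for _ in range(max_retries + 1):
        # Bail out quietly if the surrounding job was cancelled.
        if check_if_canceled():
            return None
        try:
            return fetch()
        except Exception as e:
            last_e = e
            logging.exception("fetch error: %s", e)
    if last_e is not None:
        raise RuntimeError(f"Fetch error: {last_e}")
    return None
```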
```diff
@@ -180,7 +180,7 @@ def login_user(user, remember=False, duration=None, force=False, fresh=True):
     user's `is_active` property is ``False``, they will not be logged in
     unless `force` is ``True``.

-    This will return ``True`` if the log in attempt succeeds, and ``False`` if
+    This will return ``True`` if the login attempt succeeds, and ``False`` if
     it fails (i.e. because the user is inactive).

     :param user: The user object to log in.

@@ -23,7 +23,7 @@ from quart import Response, request
 from api.apps import current_user, login_required
 from api.db.db_models import APIToken
 from api.db.services.conversation_service import ConversationService, structure_answer
-from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
+from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
 from api.db.services.llm_service import LLMBundle
 from api.db.services.search_service import SearchService
 from api.db.services.tenant_llm_service import TenantLLMService

@@ -218,10 +218,10 @@ async def completion():
         dia.llm_setting = chat_model_config

     is_embedded = bool(chat_model_id)
-    def stream():
+    async def stream():
         nonlocal dia, msg, req, conv
         try:
-            for ans in chat(dia, msg, True, **req):
+            async for ans in async_chat(dia, msg, True, **req):
                 ans = structure_answer(conv, ans, message_id, conv.id)
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
             if not is_embedded:
```
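The inner `stream()` becomes an async generator so the Quart route can feed server-sent events without blocking the event loop. A minimal, self-contained sketch of that shape (the route name and token source are hypothetical; the real handler streams structured answers from `async_chat`):

```python
import json
from quart import Quart, Response

app = Quart(__name__)

@app.route("/sse-demo")
async def sse_demo():
    async def stream():
        # Stand-in for `async for ans in async_chat(...)`.
        for tok in ("Hello", ", ", "world"):
            yield "data:" + json.dumps({"code": 0, "data": tok}, ensure_ascii=False) + "\n\n"

    # Quart accepts an async generator as a streaming response body.
    return Response(stream(), mimetype="text/event-stream")
```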
```diff
@@ -241,7 +241,7 @@ async def completion():

     else:
         answer = None
-        for ans in chat(dia, msg, **req):
+        async for ans in async_chat(dia, msg, **req):
             answer = structure_answer(conv, ans, message_id, conv.id)
             if not is_embedded:
                 ConversationService.update_by_id(conv.id, conv.to_dict())

@@ -406,10 +406,10 @@ async def ask_about():
     if search_app:
         search_config = search_app.get("search_config", {})

-    def stream():
+    async def stream():
         nonlocal req, uid
         try:
-            for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
+            async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
             yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"

@@ -34,8 +34,9 @@ async def set_api_key():
     if not all([secret_key, public_key, host]):
         return get_error_data_result(message="Missing required fields")

+    current_user_id = current_user.id
     langfuse_keys = dict(
-        tenant_id=current_user.id,
+        tenant_id=current_user_id,
         secret_key=secret_key,
         public_key=public_key,
         host=host,

@@ -45,23 +46,24 @@ async def set_api_key():
     if not langfuse.auth_check():
         return get_error_data_result(message="Invalid Langfuse keys")

-    langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
+    langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
     with DB.atomic():
         try:
             if not langfuse_entry:
                 TenantLangfuseService.save(**langfuse_keys)
             else:
-                TenantLangfuseService.update_by_tenant(tenant_id=current_user.id, langfuse_keys=langfuse_keys)
+                TenantLangfuseService.update_by_tenant(tenant_id=current_user_id, langfuse_keys=langfuse_keys)
             return get_json_result(data=langfuse_keys)
         except Exception as e:
-            server_error_response(e)
+            return server_error_response(e)


 @manager.route("/api_key", methods=["GET"])  # noqa: F821
 @login_required
 @validate_request()
 def get_api_key():
-    langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user.id)
+    current_user_id = current_user.id
+    langfuse_entry = TenantLangfuseService.filter_by_tenant_with_info(tenant_id=current_user_id)
     if not langfuse_entry:
         return get_json_result(message="Have not record any Langfuse keys.")

```
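The Langfuse handlers now read `current_user.id` once into a local `current_user_id` and reuse it. A plausible reading (not stated in the commit) is that `current_user` is a request-scoped proxy, so snapshotting avoids repeated proxy resolution and pins one consistent id across the filter/update/save calls. A toy illustration, with `UserProxy` standing in for the framework's proxy:

```python
class UserProxy:
    """Stand-in for a request-scoped current_user proxy; each `.id`
    access goes through property resolution, as with quart/werkzeug proxies."""

    def __init__(self, user_id: str):
        self._user_id = user_id
        self.lookups = 0

    @property
    def id(self) -> str:
        self.lookups += 1  # a real proxy resolves the request context here
        return self._user_id

current_user = UserProxy("tenant-42")

# Snapshot once, reuse everywhere in the handler.
current_user_id = current_user.id
for _ in range(3):
    assert current_user_id == "tenant-42"
print(current_user.lookups)  # 1
```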
```diff
@@ -72,7 +74,7 @@ def get_api_key():
     except langfuse.api.core.api_error.ApiError as api_err:
         return get_json_result(message=f"Error from Langfuse: {api_err}")
     except Exception as e:
-        server_error_response(e)
+        return server_error_response(e)

     langfuse_entry["project_id"] = langfuse.api.projects.get().dict()["data"][0]["id"]
     langfuse_entry["project_name"] = langfuse.api.projects.get().dict()["data"][0]["name"]

@@ -84,7 +86,8 @@ def get_api_key():
 @login_required
 @validate_request()
 def delete_api_key():
-    langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user.id)
+    current_user_id = current_user.id
+    langfuse_entry = TenantLangfuseService.filter_by_tenant(tenant_id=current_user_id)
     if not langfuse_entry:
         return get_json_result(message="Have not record any Langfuse keys.")

@@ -93,4 +96,4 @@ def delete_api_key():
             TenantLangfuseService.delete_model(langfuse_entry)
             return get_json_result(data=True)
         except Exception as e:
-            server_error_response(e)
+            return server_error_response(e)

@@ -74,7 +74,7 @@ async def set_api_key():
             assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
             mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
             try:
-                m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
+                m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
                 if m.find("**ERROR**") >= 0:
                     raise Exception(m)
                 chat_passed = True

@@ -217,7 +217,7 @@ async def add_llm():
             **extra,
         )
         try:
-            m, tc = mdl.chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
+            m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
             if not tc and m.find("**ERROR**:") >= 0:
                 raise Exception(m)
         except Exception as e:
```
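Both hunks replace a blocking `mdl.chat(...)` with `await mdl.async_chat(...)` inside `async def` route handlers; a blocking call there would stall the event loop for the whole LLM round trip. A minimal sketch of the smoke test these endpoints run, with a dummy model standing in for the real ChatModel classes:

```python
import asyncio

class DummyChatModel:
    """Stand-in for the repo's ChatModel classes; only async_chat()'s
    shape matters here (it returns answer text plus a token count)."""

    async def async_chat(self, system, messages, gen_conf):
        await asyncio.sleep(0)  # the real code awaits the provider's HTTP call
        return "Hi there!", 3

async def smoke_test(mdl) -> bool:
    m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello!"}], {"temperature": 0.9})
    # The services above treat an "**ERROR**" marker in the text as failure.
    return "**ERROR**" not in m

print(asyncio.run(smoke_test(DummyChatModel())))
```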
```diff
@@ -552,7 +552,7 @@ def list_docs(dataset_id, tenant_id):
     create_time_from = int(q.get("create_time_from", 0))
     create_time_to = int(q.get("create_time_to", 0))

-    # map run status (accept text or numeric) - align with API parameter
+    # map run status (text or numeric) - align with API parameter
     run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}
     run_status_converted = [run_status_text_to_numeric.get(v, v) for v in run_status]

```
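The `dict.get(v, v)` idiom here lets the endpoint accept either textual or numeric run-status values in one lookup: known names map to their codes, and anything else (including already-numeric strings) passes through unchanged.

```python
run_status_text_to_numeric = {"UNSTART": "0", "RUNNING": "1", "CANCEL": "2", "DONE": "3", "FAIL": "4"}

def normalize(values):
    # Text names map to numeric codes; numeric (or unknown) values
    # fall through unchanged via the default argument.
    return [run_status_text_to_numeric.get(v, v) for v in values]

print(normalize(["RUNNING", "3", "DONE"]))  # ['1', '3', '3']
```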
```diff
@@ -890,7 +890,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
         type: string
         required: false
         default: ""
-        description: Chunk Id.
+        description: Chunk id.
       - in: header
         name: Authorization
         type: string

@@ -26,9 +26,10 @@ from api.db.db_models import APIToken
 from api.db.services.api_service import API4ConversationService
 from api.db.services.canvas_service import UserCanvasService, completion_openai
 from api.db.services.canvas_service import completion as agent_completion
-from api.db.services.conversation_service import ConversationService, iframe_completion
-from api.db.services.conversation_service import completion as rag_completion
-from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter
+from api.db.services.conversation_service import ConversationService
+from api.db.services.conversation_service import async_iframe_completion as iframe_completion
+from api.db.services.conversation_service import async_completion as rag_completion
+from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap, meta_filter
 from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle

@@ -141,7 +142,7 @@ async def chat_completion(tenant_id, chat_id):
         return resp
     else:
         answer = None
-        for ans in rag_completion(tenant_id, chat_id, **req):
+        async for ans in rag_completion(tenant_id, chat_id, **req):
             answer = ans
             break
         return get_result(data=answer)

@@ -245,7 +246,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
     # The value for the usage field on all chunks except for the last one will be null.
     # The usage field on the last chunk contains token usage statistics for the entire request.
     # The choices field on the last chunk will always be an empty array [].
-    def streamed_response_generator(chat_id, dia, msg):
+    async def streamed_response_generator(chat_id, dia, msg):
         token_used = 0
         answer_cache = ""
         reasoning_cache = ""

@@ -274,7 +275,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
         }

         try:
-            for ans in chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
+            async for ans in async_chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
                 last_ans = ans
                 answer = ans["answer"]

@@ -342,7 +343,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
         return resp
     else:
         answer = None
-        for ans in chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
+        async for ans in async_chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
             # focus answer content only
             answer = ans
             break

@@ -733,10 +734,10 @@ async def ask_about(tenant_id):
         return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
     uid = tenant_id

-    def stream():
+    async def stream():
         nonlocal req, uid
         try:
-            for ans in ask(req["question"], req["kb_ids"], uid):
+            async for ans in async_ask(req["question"], req["kb_ids"], uid):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
             yield "data:" + json.dumps(

@@ -827,7 +828,7 @@ async def chatbot_completions(dialog_id):
         resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
         return resp

-    for answer in iframe_completion(dialog_id, **req):
+    async for answer in iframe_completion(dialog_id, **req):
         return get_result(data=answer)


@@ -918,10 +919,10 @@ async def ask_about_embedded():
     if search_app := SearchService.get_detail(search_id):
         search_config = search_app.get("search_config", {})

-    def stream():
+    async def stream():
         nonlocal req, uid
         try:
-            for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
+            async for ans in async_ask(req["question"], req["kb_ids"], uid, search_config=search_config):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
             yield "data:" + json.dumps(

@@ -19,7 +19,7 @@ from common.constants import StatusEnum
 from api.db.db_models import Conversation, DB
 from api.db.services.api_service import API4ConversationService
 from api.db.services.common_service import CommonService
-from api.db.services.dialog_service import DialogService, chat
+from api.db.services.dialog_service import DialogService, async_chat
 from common.misc_utils import get_uuid
 import json


@@ -89,8 +89,7 @@ def structure_answer(conv, ans, message_id, session_id):
         conv.reference[-1] = reference
     return ans

-
-def completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
+async def async_completion(tenant_id, chat_id, question, name="New session", session_id=None, stream=True, **kwargs):
     assert name, "`name` can not be empty."
     dia = DialogService.query(id=chat_id, tenant_id=tenant_id, status=StatusEnum.VALID.value)
     assert dia, "You do not own the chat."

@@ -112,7 +111,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None
                 "reference": {},
                 "audio_binary": None,
                 "id": None,
                 "session_id": session_id
             }},
             ensure_ascii=False) + "\n\n"
         yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

@@ -148,7 +147,7 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None

     if stream:
         try:
-            for ans in chat(dia, msg, True, **kwargs):
+            async for ans in async_chat(dia, msg, True, **kwargs):
                 ans = structure_answer(conv, ans, message_id, session_id)
                 yield "data:" + json.dumps({"code": 0, "data": ans}, ensure_ascii=False) + "\n\n"
             ConversationService.update_by_id(conv.id, conv.to_dict())

@@ -160,14 +159,13 @@ def completion(tenant_id, chat_id, question, name="New session", session_id=None

     else:
         answer = None
-        for ans in chat(dia, msg, False, **kwargs):
+        async for ans in async_chat(dia, msg, False, **kwargs):
             answer = structure_answer(conv, ans, message_id, session_id)
             ConversationService.update_by_id(conv.id, conv.to_dict())
             break
         yield answer

-
-def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
+async def async_iframe_completion(dialog_id, question, session_id=None, stream=True, **kwargs):
     e, dia = DialogService.get_by_id(dialog_id)
     assert e, "Dialog not found"
     if not session_id:

@@ -222,7 +220,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg

     if stream:
         try:
-            for ans in chat(dia, msg, True, **kwargs):
+            async for ans in async_chat(dia, msg, True, **kwargs):
                 ans = structure_answer(conv, ans, message_id, session_id)
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
                                            ensure_ascii=False) + "\n\n"

@@ -235,7 +233,7 @@ def iframe_completion(dialog_id, question, session_id=None, stream=True, **kwarg

     else:
         answer = None
-        for ans in chat(dia, msg, False, **kwargs):
+        async for ans in async_chat(dia, msg, False, **kwargs):
             answer = structure_answer(conv, ans, message_id, session_id)
             API4ConversationService.append_message(conv.id, conv.to_dict())
             break

@@ -178,7 +178,8 @@ class DialogService(CommonService):
             offset += limit
         return res

-def chat_solo(dialog, messages, stream=True):
+
+async def async_chat_solo(dialog, messages, stream=True):
     attachments = ""
     if "files" in messages[-1]:
         attachments = "\n\n".join(FileService.get_files(messages[-1]["files"]))

@@ -197,7 +198,8 @@ def chat_solo(dialog, messages, stream=True):
     if stream:
         last_ans = ""
         delta_ans = ""
-        for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
+        answer = ""
+        async for ans in chat_mdl.async_chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
             answer = ans
             delta_ans = ans[len(last_ans):]
             if num_tokens_from_string(delta_ans) < 16:
```
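Besides switching to `async for` over `async_chat_streamly`, this hunk initializes `answer = ""` before the loop: if the stream yields nothing, later code that reads `answer` would otherwise hit an unbound local. The surrounding logic also batches deltas until they reach roughly 16 tokens before emitting. A runnable sketch of that batching, with `num_tokens` as a crude stand-in for the real tokenizer and a plain list in place of the async stream:

```python
def num_tokens(s: str) -> int:
    return len(s.split())  # stand-in for num_tokens_from_string

def batch_deltas(chunks, min_tokens: int = 16):
    """Re-emit cumulative answers only when the new delta is big enough."""
    answer = ""    # initialized up front so an empty stream is still safe
    last_ans = ""
    for ans in chunks:  # the real code iterates an async stream
        answer = ans
        delta = ans[len(last_ans):]
        if num_tokens(delta) < min_tokens:
            continue
        last_ans = answer
        yield delta
    if answer != last_ans:
        yield answer[len(last_ans):]  # flush whatever is left at the end

stream = ["one", "one two", "one two " + " ".join(str(i) for i in range(20))]
print(list(batch_deltas(stream)))
```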
```diff
@@ -208,7 +210,7 @@ def chat_solo(dialog, messages, stream=True):
         if delta_ans:
             yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
     else:
-        answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
+        answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
         user_content = msg[-1].get("content", "[content not available]")
         logging.debug("User: {}|Assistant: {}".format(user_content, answer))
         yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()}

@@ -347,13 +349,12 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
             return []
     return list(doc_ids)

-
-def chat(dialog, messages, stream=True, **kwargs):
+async def async_chat(dialog, messages, stream=True, **kwargs):
     assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
     if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
-        for ans in chat_solo(dialog, messages, stream):
+        async for ans in async_chat_solo(dialog, messages, stream):
             yield ans
-        return None
+        return

     chat_start_ts = timer()

```
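The `return None` → `return` change is not cosmetic: once `chat` becomes the async generator `async_chat`, Python rejects `return` with any value, including an explicit `None`, at compile time ("'return' with value in async generator"). A bare `return` is the only legal way to end the stream early:

```python
import asyncio

async def agen(stop_early: bool):
    yield 1
    if stop_early:
        return  # bare return is legal; `return None` here would be a SyntaxError
    yield 2

async def main():
    print([x async for x in agen(True)])   # [1]
    print([x async for x in agen(False)])  # [1, 2]

asyncio.run(main())
```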
```diff
@@ -400,7 +401,7 @@ def chat(dialog, messages, stream=True, **kwargs):
             ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
             if ans:
                 yield ans
-                return None
+                return

     for p in prompt_config["parameters"]:
         if p["key"] == "knowledge":

@@ -508,7 +509,8 @@ def chat(dialog, messages, stream=True, **kwargs):
             empty_res = prompt_config["empty_response"]
             yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions),
                    "audio_binary": tts(tts_mdl, empty_res)}
-            return {"answer": prompt_config["empty_response"], "reference": kbinfos}
+            yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
+            return

     kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
     gen_conf = dialog.llm_setting
```
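The second hunk fixes a latent bug as well as the async constraint: in the old sync generator, `return {...}` only stored the dict in `StopIteration.value`, which plain `for` loops silently discard, so the empty-response payload never reached callers; and in an async generator the statement would not even compile. Yielding the payload and then bare-returning actually delivers it:

```python
def old_style():
    yield "streamed part"
    return {"answer": "empty response"}  # legal in a sync generator, but...

print(list(old_style()))  # ['streamed part'] -- the dict is lost to callers

def new_style():
    yield "streamed part"
    yield {"answer": "empty response"}  # now it reaches the caller
    return

print(list(new_style()))  # ['streamed part', {'answer': 'empty response'}]
```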
```diff
@@ -612,7 +614,7 @@ def chat(dialog, messages, stream=True, **kwargs):
     if stream:
         last_ans = ""
         answer = ""
-        for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
+        async for ans in chat_mdl.async_chat_streamly(prompt + prompt4citation, msg[1:], gen_conf):
             if thought:
                 ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
             answer = ans

@@ -626,19 +628,19 @@ def chat(dialog, messages, stream=True, **kwargs):
             yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
         yield decorate_answer(thought + answer)
     else:
-        answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf)
+        answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf)
         user_content = msg[-1].get("content", "[content not available]")
         logging.debug("User: {}|Assistant: {}".format(user_content, answer))
         res = decorate_answer(answer)
         res["audio_binary"] = tts(tts_mdl, answer)
         yield res

-    return None
+    return


 def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
     sys_prompt = """
 You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
 Ensure that:
 1. Field names should not start with a digit. If any field name starts with a digit, use double quotes around it.
 2. Write only the SQL, no explanations or additional text.

@@ -805,8 +807,7 @@ def tts(tts_mdl, text):
         return None
     return binascii.hexlify(bin).decode("utf-8")

-
-def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
+async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
     doc_ids = search_config.get("doc_ids", [])
     rerank_mdl = None
     kb_ids = search_config.get("kb_ids", kb_ids)

@@ -880,7 +881,7 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
             return {"answer": answer, "reference": refs}

         answer = ""
-        for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
+        async for ans in chat_mdl.async_chat_streamly(sys_prompt, msg, {"temperature": 0.1}):
             answer = ans
             yield {"answer": answer, "reference": {}}
         yield decorate_answer(answer)

@@ -25,14 +25,17 @@ Provides functionality for evaluating RAG system performance including:
 - Configuration recommendations
 """

+import asyncio
 import logging
+import queue
+import threading
 from typing import List, Dict, Any, Optional, Tuple
 from datetime import datetime
 from timeit import default_timer as timer

 from api.db.db_models import EvaluationDataset, EvaluationCase, EvaluationRun, EvaluationResult
 from api.db.services.common_service import CommonService
-from api.db.services.dialog_service import DialogService, chat
+from api.db.services.dialog_service import DialogService
 from common.misc_utils import get_uuid
 from common.time_utils import current_timestamp
 from common.constants import StatusEnum

@@ -40,24 +43,24 @@ from common.constants import StatusEnum

 class EvaluationService(CommonService):
     """Service for managing RAG evaluations"""

     model = EvaluationDataset

     # ==================== Dataset Management ====================

     @classmethod
     def create_dataset(cls, name: str, description: str, kb_ids: List[str],
                        tenant_id: str, user_id: str) -> Tuple[bool, str]:
         """
         Create a new evaluation dataset.

         Args:
             name: Dataset name
             description: Dataset description
             kb_ids: List of knowledge base IDs to evaluate against
             tenant_id: Tenant ID
             user_id: User ID who creates the dataset

         Returns:
             (success, dataset_id or error_message)
         """

@@ -74,15 +77,15 @@ class EvaluationService(CommonService):
                 "update_time": current_timestamp(),
                 "status": StatusEnum.VALID.value
             }

             if not EvaluationDataset.create(**dataset):
                 return False, "Failed to create dataset"

             return True, dataset_id
         except Exception as e:
             logging.error(f"Error creating evaluation dataset: {e}")
             return False, str(e)

     @classmethod
     def get_dataset(cls, dataset_id: str) -> Optional[Dict[str, Any]]:
         """Get dataset by ID"""

@@ -94,9 +97,9 @@ class EvaluationService(CommonService):
         except Exception as e:
             logging.error(f"Error getting dataset {dataset_id}: {e}")
             return None

     @classmethod
     def list_datasets(cls, tenant_id: str, user_id: str,
                       page: int = 1, page_size: int = 20) -> Dict[str, Any]:
         """List datasets for a tenant"""
         try:

@@ -104,10 +107,10 @@ class EvaluationService(CommonService):
                 (EvaluationDataset.tenant_id == tenant_id) &
                 (EvaluationDataset.status == StatusEnum.VALID.value)
             ).order_by(EvaluationDataset.create_time.desc())

             total = query.count()
             datasets = query.paginate(page, page_size)

             return {
                 "total": total,
                 "datasets": [d.to_dict() for d in datasets]

@@ -115,7 +118,7 @@ class EvaluationService(CommonService):
         except Exception as e:
             logging.error(f"Error listing datasets: {e}")
             return {"total": 0, "datasets": []}

     @classmethod
     def update_dataset(cls, dataset_id: str, **kwargs) -> bool:
         """Update dataset"""

@@ -127,7 +130,7 @@ class EvaluationService(CommonService):
         except Exception as e:
             logging.error(f"Error updating dataset {dataset_id}: {e}")
             return False

     @classmethod
     def delete_dataset(cls, dataset_id: str) -> bool:
         """Soft delete dataset"""

@@ -139,18 +142,18 @@ class EvaluationService(CommonService):
         except Exception as e:
             logging.error(f"Error deleting dataset {dataset_id}: {e}")
             return False

     # ==================== Test Case Management ====================

     @classmethod
     def add_test_case(cls, dataset_id: str, question: str,
                       reference_answer: Optional[str] = None,
                       relevant_doc_ids: Optional[List[str]] = None,
                       relevant_chunk_ids: Optional[List[str]] = None,
                       metadata: Optional[Dict[str, Any]] = None) -> Tuple[bool, str]:
         """
         Add a test case to a dataset.

         Args:
             dataset_id: Dataset ID
             question: Test question

@@ -158,7 +161,7 @@ class EvaluationService(CommonService):
             relevant_doc_ids: Optional list of relevant document IDs
             relevant_chunk_ids: Optional list of relevant chunk IDs
             metadata: Optional additional metadata

         Returns:
             (success, case_id or error_message)
         """

@@ -174,15 +177,15 @@ class EvaluationService(CommonService):
                 "metadata": metadata,
                 "create_time": current_timestamp()
             }

             if not EvaluationCase.create(**case):
                 return False, "Failed to create test case"

             return True, case_id
         except Exception as e:
             logging.error(f"Error adding test case: {e}")
             return False, str(e)

     @classmethod
     def get_test_cases(cls, dataset_id: str) -> List[Dict[str, Any]]:
         """Get all test cases for a dataset"""

@@ -190,12 +193,12 @@ class EvaluationService(CommonService):
             cases = EvaluationCase.select().where(
                 EvaluationCase.dataset_id == dataset_id
             ).order_by(EvaluationCase.create_time)

             return [c.to_dict() for c in cases]
         except Exception as e:
             logging.error(f"Error getting test cases for dataset {dataset_id}: {e}")
             return []

     @classmethod
     def delete_test_case(cls, case_id: str) -> bool:
         """Delete a test case"""

@@ -206,22 +209,22 @@ class EvaluationService(CommonService):
         except Exception as e:
             logging.error(f"Error deleting test case {case_id}: {e}")
             return False

     @classmethod
     def import_test_cases(cls, dataset_id: str, cases: List[Dict[str, Any]]) -> Tuple[int, int]:
         """
         Bulk import test cases from a list.

         Args:
             dataset_id: Dataset ID
             cases: List of test case dictionaries

         Returns:
             (success_count, failure_count)
         """
         success_count = 0
         failure_count = 0

         for case_data in cases:
             success, _ = cls.add_test_case(
                 dataset_id=dataset_id,

@@ -231,28 +234,28 @@ class EvaluationService(CommonService):
                 relevant_chunk_ids=case_data.get("relevant_chunk_ids"),
                 metadata=case_data.get("metadata")
             )

             if success:
                 success_count += 1
             else:
                 failure_count += 1

         return success_count, failure_count

     # ==================== Evaluation Execution ====================

     @classmethod
     def start_evaluation(cls, dataset_id: str, dialog_id: str,
                          user_id: str, name: Optional[str] = None) -> Tuple[bool, str]:
         """
         Start an evaluation run.

         Args:
             dataset_id: Dataset ID
             dialog_id: Dialog configuration to evaluate
             user_id: User ID who starts the run
             name: Optional run name

         Returns:
             (success, run_id or error_message)
         """

@@ -261,12 +264,12 @@ class EvaluationService(CommonService):
             success, dialog = DialogService.get_by_id(dialog_id)
             if not success:
                 return False, "Dialog not found"

             # Create evaluation run
             run_id = get_uuid()
             if not name:
                 name = f"Evaluation Run {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"

             run = {
                 "id": run_id,
                 "dataset_id": dataset_id,
```
|
|||||||
"create_time": current_timestamp(),
|
"create_time": current_timestamp(),
|
||||||
"complete_time": None
|
"complete_time": None
|
||||||
}
|
}
|
||||||
|
|
||||||
if not EvaluationRun.create(**run):
|
if not EvaluationRun.create(**run):
|
||||||
return False, "Failed to create evaluation run"
|
return False, "Failed to create evaluation run"
|
||||||
|
|
||||||
# Execute evaluation asynchronously (in production, use task queue)
|
# Execute evaluation asynchronously (in production, use task queue)
|
||||||
# For now, we'll execute synchronously
|
# For now, we'll execute synchronously
|
||||||
cls._execute_evaluation(run_id, dataset_id, dialog)
|
cls._execute_evaluation(run_id, dataset_id, dialog)
|
||||||
|
|
||||||
return True, run_id
|
return True, run_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error starting evaluation: {e}")
|
logging.error(f"Error starting evaluation: {e}")
|
||||||
return False, str(e)
|
return False, str(e)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any):
|
def _execute_evaluation(cls, run_id: str, dataset_id: str, dialog: Any):
|
||||||
"""
|
"""
|
||||||
Execute evaluation for all test cases.
|
Execute evaluation for all test cases.
|
||||||
|
|
||||||
This method runs the RAG pipeline for each test case and computes metrics.
|
This method runs the RAG pipeline for each test case and computes metrics.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Get all test cases
|
# Get all test cases
|
||||||
test_cases = cls.get_test_cases(dataset_id)
|
test_cases = cls.get_test_cases(dataset_id)
|
||||||
|
|
||||||
if not test_cases:
|
if not test_cases:
|
||||||
EvaluationRun.update(
|
EvaluationRun.update(
|
||||||
status="FAILED",
|
status="FAILED",
|
||||||
complete_time=current_timestamp()
|
complete_time=current_timestamp()
|
||||||
).where(EvaluationRun.id == run_id).execute()
|
).where(EvaluationRun.id == run_id).execute()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Execute each test case
|
# Execute each test case
|
||||||
results = []
|
results = []
|
||||||
for case in test_cases:
|
for case in test_cases:
|
||||||
result = cls._evaluate_single_case(run_id, case, dialog)
|
result = cls._evaluate_single_case(run_id, case, dialog)
|
||||||
if result:
|
if result:
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
# Compute summary metrics
|
# Compute summary metrics
|
||||||
metrics_summary = cls._compute_summary_metrics(results)
|
metrics_summary = cls._compute_summary_metrics(results)
|
||||||
|
|
||||||
# Update run status
|
# Update run status
|
||||||
EvaluationRun.update(
|
EvaluationRun.update(
|
||||||
status="COMPLETED",
|
status="COMPLETED",
|
||||||
metrics_summary=metrics_summary,
|
metrics_summary=metrics_summary,
|
||||||
complete_time=current_timestamp()
|
complete_time=current_timestamp()
|
||||||
).where(EvaluationRun.id == run_id).execute()
|
).where(EvaluationRun.id == run_id).execute()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error executing evaluation {run_id}: {e}")
|
logging.error(f"Error executing evaluation {run_id}: {e}")
|
||||||
EvaluationRun.update(
|
EvaluationRun.update(
|
||||||
status="FAILED",
|
status="FAILED",
|
||||||
complete_time=current_timestamp()
|
complete_time=current_timestamp()
|
||||||
).where(EvaluationRun.id == run_id).execute()
|
).where(EvaluationRun.id == run_id).execute()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
|
def _evaluate_single_case(cls, run_id: str, case: Dict[str, Any],
|
||||||
dialog: Any) -> Optional[Dict[str, Any]]:
|
dialog: Any) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Evaluate a single test case.
|
Evaluate a single test case.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
run_id: Evaluation run ID
|
run_id: Evaluation run ID
|
||||||
case: Test case dictionary
|
case: Test case dictionary
|
||||||
dialog: Dialog configuration
|
dialog: Dialog configuration
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Result dictionary or None if failed
|
Result dictionary or None if failed
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Prepare messages
|
# Prepare messages
|
||||||
messages = [{"role": "user", "content": case["question"]}]
|
messages = [{"role": "user", "content": case["question"]}]
|
||||||
|
|
||||||
# Execute RAG pipeline
|
# Execute RAG pipeline
|
||||||
start_time = timer()
|
start_time = timer()
|
||||||
answer = ""
|
answer = ""
|
||||||
retrieved_chunks = []
|
retrieved_chunks = []
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_from_async_gen(async_gen):
|
||||||
|
result_queue: queue.Queue = queue.Queue()
|
||||||
|
|
||||||
|
def runner():
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
async def consume():
|
||||||
|
try:
|
||||||
|
async for item in async_gen:
|
||||||
|
result_queue.put(item)
|
||||||
|
except Exception as e:
|
||||||
|
result_queue.put(e)
|
||||||
|
finally:
|
||||||
|
result_queue.put(StopIteration)
|
||||||
|
|
||||||
|
loop.run_until_complete(consume())
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
threading.Thread(target=runner, daemon=True).start()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
item = result_queue.get()
|
||||||
|
if item is StopIteration:
|
||||||
|
break
|
||||||
|
if isinstance(item, Exception):
|
||||||
|
raise item
|
||||||
|
yield item
|
||||||
|
|
||||||
|
|
||||||
|
def chat(dialog, messages, stream=True, **kwargs):
|
||||||
|
from api.db.services.dialog_service import async_chat
|
||||||
|
|
||||||
|
return _sync_from_async_gen(async_chat(dialog, messages, stream=stream, **kwargs))
|
||||||
|
|
||||||
for ans in chat(dialog, messages, stream=False):
|
for ans in chat(dialog, messages, stream=False):
|
||||||
if isinstance(ans, dict):
|
if isinstance(ans, dict):
|
||||||
answer = ans.get("answer", "")
|
answer = ans.get("answer", "")
|
||||||
retrieved_chunks = ans.get("reference", {}).get("chunks", [])
|
retrieved_chunks = ans.get("reference", {}).get("chunks", [])
|
||||||
break
|
break
|
||||||
|
|
||||||
execution_time = timer() - start_time
|
execution_time = timer() - start_time
|
||||||
|
|
||||||
# Compute metrics
|
# Compute metrics
|
||||||
metrics = cls._compute_metrics(
|
metrics = cls._compute_metrics(
|
||||||
question=case["question"],
|
question=case["question"],
|
||||||
@ -374,7 +413,7 @@ class EvaluationService(CommonService):
|
|||||||
relevant_chunk_ids=case.get("relevant_chunk_ids"),
|
relevant_chunk_ids=case.get("relevant_chunk_ids"),
|
||||||
dialog=dialog
|
dialog=dialog
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save result
|
# Save result
|
||||||
result_id = get_uuid()
|
result_id = get_uuid()
|
||||||
result = {
|
result = {
|
||||||
@ -388,14 +427,14 @@ class EvaluationService(CommonService):
|
|||||||
"token_usage": None, # TODO: Track token usage
|
"token_usage": None, # TODO: Track token usage
|
||||||
"create_time": current_timestamp()
|
"create_time": current_timestamp()
|
||||||
}
|
}
|
||||||
|
|
||||||
EvaluationResult.create(**result)
|
EvaluationResult.create(**result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error evaluating case {case.get('id')}: {e}")
|
logging.error(f"Error evaluating case {case.get('id')}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _compute_metrics(cls, question: str, generated_answer: str,
|
def _compute_metrics(cls, question: str, generated_answer: str,
|
||||||
reference_answer: Optional[str],
|
reference_answer: Optional[str],
|
||||||
@ -404,69 +443,69 @@ class EvaluationService(CommonService):
|
|||||||
dialog: Any) -> Dict[str, float]:
|
dialog: Any) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Compute evaluation metrics for a single test case.
|
Compute evaluation metrics for a single test case.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary of metric names to values
|
Dictionary of metric names to values
|
||||||
"""
|
"""
|
||||||
metrics = {}
|
metrics = {}
|
||||||
|
|
||||||
# Retrieval metrics (if ground truth chunks provided)
|
# Retrieval metrics (if ground truth chunks provided)
|
||||||
if relevant_chunk_ids:
|
if relevant_chunk_ids:
|
||||||
retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks]
|
retrieved_ids = [c.get("chunk_id") for c in retrieved_chunks]
|
||||||
metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids))
|
metrics.update(cls._compute_retrieval_metrics(retrieved_ids, relevant_chunk_ids))
|
||||||
|
|
||||||
# Generation metrics
|
# Generation metrics
|
||||||
if generated_answer:
|
if generated_answer:
|
||||||
# Basic metrics
|
# Basic metrics
|
||||||
metrics["answer_length"] = len(generated_answer)
|
metrics["answer_length"] = len(generated_answer)
|
||||||
metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0
|
metrics["has_answer"] = 1.0 if generated_answer.strip() else 0.0
|
||||||
|
|
||||||
# TODO: Implement advanced metrics using LLM-as-judge
|
# TODO: Implement advanced metrics using LLM-as-judge
|
||||||
# - Faithfulness (hallucination detection)
|
# - Faithfulness (hallucination detection)
|
||||||
# - Answer relevance
|
# - Answer relevance
|
||||||
# - Context relevance
|
# - Context relevance
|
||||||
# - Semantic similarity (if reference answer provided)
|
# - Semantic similarity (if reference answer provided)
|
||||||
|
|
||||||
return metrics
|
return metrics
|
||||||
|
|
||||||
    @classmethod
    def _compute_retrieval_metrics(cls, retrieved_ids: List[str],
                                   relevant_ids: List[str]) -> Dict[str, float]:
        """
        Compute retrieval metrics.

        Args:
            retrieved_ids: List of retrieved chunk IDs
            relevant_ids: List of relevant chunk IDs (ground truth)

        Returns:
            Dictionary of retrieval metrics
        """
        if not relevant_ids:
            return {}

        retrieved_set = set(retrieved_ids)
        relevant_set = set(relevant_ids)

        # Precision: proportion of retrieved that are relevant
        precision = len(retrieved_set & relevant_set) / len(retrieved_set) if retrieved_set else 0.0

        # Recall: proportion of relevant that were retrieved
        recall = len(retrieved_set & relevant_set) / len(relevant_set) if relevant_set else 0.0

        # F1 score
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0

        # Hit rate: whether any relevant chunk was retrieved
        hit_rate = 1.0 if (retrieved_set & relevant_set) else 0.0

        # MRR (Mean Reciprocal Rank): position of first relevant chunk
        mrr = 0.0
        for i, chunk_id in enumerate(retrieved_ids, 1):
            if chunk_id in relevant_set:
                mrr = 1.0 / i
                break

        return {
            "precision": precision,
            "recall": recall,
@ -474,45 +513,45 @@ class EvaluationService(CommonService):
            "hit_rate": hit_rate,
            "mrr": mrr
        }

    @classmethod
    def _compute_summary_metrics(cls, results: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Compute summary metrics across all test cases.

        Args:
            results: List of result dictionaries

        Returns:
            Summary metrics dictionary
        """
        if not results:
            return {}

        # Aggregate metrics
        metric_sums = {}
        metric_counts = {}

        for result in results:
            metrics = result.get("metrics", {})
            for key, value in metrics.items():
                if isinstance(value, (int, float)):
                    metric_sums[key] = metric_sums.get(key, 0) + value
                    metric_counts[key] = metric_counts.get(key, 0) + 1

        # Compute averages
        summary = {
            "total_cases": len(results),
            "avg_execution_time": sum(r.get("execution_time", 0) for r in results) / len(results)
        }

        for key in metric_sums:
            summary[f"avg_{key}"] = metric_sums[key] / metric_counts[key]

        return summary

    # ==================== Results & Analysis ====================

    @classmethod
    def get_run_results(cls, run_id: str) -> Dict[str, Any]:
        """Get results for an evaluation run"""
@ -520,11 +559,11 @@ class EvaluationService(CommonService):
            run = EvaluationRun.get_by_id(run_id)
            if not run:
                return {}

            results = EvaluationResult.select().where(
                EvaluationResult.run_id == run_id
            ).order_by(EvaluationResult.create_time)

            return {
                "run": run.to_dict(),
                "results": [r.to_dict() for r in results]
@ -532,15 +571,15 @@ class EvaluationService(CommonService):
        except Exception as e:
            logging.error(f"Error getting run results {run_id}: {e}")
            return {}

    @classmethod
    def get_recommendations(cls, run_id: str) -> List[Dict[str, Any]]:
        """
        Analyze evaluation results and provide configuration recommendations.

        Args:
            run_id: Evaluation run ID

        Returns:
            List of recommendation dictionaries
        """
@ -548,10 +587,10 @@ class EvaluationService(CommonService):
            run = EvaluationRun.get_by_id(run_id)
            if not run or not run.metrics_summary:
                return []

            metrics = run.metrics_summary
            recommendations = []

            # Low precision: retrieving irrelevant chunks
            if metrics.get("avg_precision", 1.0) < 0.7:
                recommendations.append({
@ -564,7 +603,7 @@ class EvaluationService(CommonService):
                        "Reduce top_k to return fewer chunks"
                    ]
                })

            # Low recall: missing relevant chunks
            if metrics.get("avg_recall", 1.0) < 0.7:
                recommendations.append({
@ -578,7 +617,7 @@ class EvaluationService(CommonService):
                        "Check chunk size - may be too large or too small"
                    ]
                })

            # Slow response time
            if metrics.get("avg_execution_time", 0) > 5.0:
                recommendations.append({
@ -591,7 +630,7 @@ class EvaluationService(CommonService):
                        "Consider caching frequently asked questions"
                    ]
                })

            return recommendations
        except Exception as e:
            logging.error(f"Error generating recommendations for run {run_id}: {e}")
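To make the retrieval metrics in `_compute_retrieval_metrics` concrete, here is a minimal worked example (illustration only, not part of the diff) that reproduces the same precision, recall, F1, hit-rate, and MRR arithmetic on a toy retrieval; the chunk IDs are made up:

```python
# Toy data: ranked retriever output vs. ground truth (hypothetical IDs).
retrieved_ids = ["c1", "c2", "c3", "c4"]
relevant_ids = ["c2", "c7"]

retrieved_set, relevant_set = set(retrieved_ids), set(relevant_ids)
hits = retrieved_set & relevant_set                 # {"c2"}
precision = len(hits) / len(retrieved_set)          # 1/4 = 0.25
recall = len(hits) / len(relevant_set)              # 1/2 = 0.50
f1 = 2 * precision * recall / (precision + recall)  # ~0.333
hit_rate = 1.0 if hits else 0.0                     # 1.0
# First relevant chunk ("c2") sits at rank 2, so MRR = 1/2.
mrr = next((1.0 / i for i, cid in enumerate(retrieved_ids, 1) if cid in relevant_set), 0.0)
print(precision, recall, round(f1, 3), hit_rate, mrr)
```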
@ -16,15 +16,17 @@
import asyncio
import inspect
import logging
import queue
import re
import threading
from common.token_utils import num_tokens_from_string
from functools import partial
from typing import Generator
from common.constants import LLMType
from api.db.db_models import LLM
from api.db.services.common_service import CommonService
from api.db.services.tenant_llm_service import LLM4Tenant, TenantLLMService
from common.constants import LLMType
from common.token_utils import num_tokens_from_string


class LLMService(CommonService):
@ -33,6 +35,7 @@ class LLMService(CommonService):

    def get_init_tenant_llm(user_id):
        from common import settings

        tenant_llm = []

        model_configs = {
@ -193,7 +196,7 @@ class LLMBundle(LLM4Tenant):
            generation = self.langfuse.start_generation(
                trace_context=self.trace_context,
                name="stream_transcription",
                metadata={"model": self.llm_name}
                metadata={"model": self.llm_name},
            )
        final_text = ""
        used_tokens = 0
@ -217,32 +220,34 @@ class LLMBundle(LLM4Tenant):
            if self.langfuse:
                generation.update(
                    output={"output": final_text},
                    usage_details={"total_tokens": used_tokens}
                    usage_details={"total_tokens": used_tokens},
                )
                generation.end()

            return

        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="stream_transcription", metadata={"model": self.llm_name})
            generation = self.langfuse.start_generation(
                trace_context=self.trace_context,
                name="stream_transcription",
                metadata={"model": self.llm_name},
            )

        full_text, used_tokens = mdl.transcription(audio)
        if not TenantLLMService.increase_usage(
            self.tenant_id, self.llm_type, used_tokens
        ):
            logging.error(
                f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}"
            )
        full_text, used_tokens = mdl.transcription(audio)
        if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens):
            logging.error(f"LLMBundle.stream_transcription can't update token usage for {self.tenant_id}/SEQUENCE2TXT used_tokens: {used_tokens}")

        if self.langfuse:
            generation.update(
                output={"output": full_text},
                usage_details={"total_tokens": used_tokens}
                usage_details={"total_tokens": used_tokens},
            )
            generation.end()

        yield {
            "event": "final",
            "text": full_text,
            "streaming": False
            "streaming": False,
        }

    def tts(self, text: str) -> Generator[bytes, None, None]:
@ -289,61 +294,79 @@ class LLMBundle(LLM4Tenant):
            return kwargs
        else:
            return {k: v for k, v in kwargs.items() if k in allowed_params}

    def _run_coroutine_sync(self, coro):
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(coro)

        result_queue: queue.Queue = queue.Queue()

        def runner():
            try:
                result_queue.put((True, asyncio.run(coro)))
            except Exception as e:
                result_queue.put((False, e))

        thread = threading.Thread(target=runner, daemon=True)
        thread.start()
        thread.join()

        success, value = result_queue.get_nowait()
        if success:
            return value
        raise value

    def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})

        chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
        if self.is_tools and self.mdl.is_tools:
            chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)

        use_kwargs = self._clean_param(chat_partial, **kwargs)
        txt, used_tokens = chat_partial(**use_kwargs)
        txt = self._remove_reasoning_content(txt)

        if not self.verbose_tool_use:
            txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

        if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
            logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))

        if self.langfuse:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt
        return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs))

    def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs):
        result_queue: queue.Queue = queue.Queue()

        def runner():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

            async def consume():
                try:
                    async for item in async_gen_fn(*args, **kwargs):
                        result_queue.put(item)
                except Exception as e:
                    result_queue.put(e)
                finally:
                    result_queue.put(StopIteration)

            loop.run_until_complete(consume())
            loop.close()

        threading.Thread(target=runner, daemon=True).start()

        while True:
            item = result_queue.get()
            if item is StopIteration:
                break
            if isinstance(item, Exception):
                raise item
            yield item
    def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

        ans = ""
        chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
        total_tokens = 0
        if self.is_tools and self.mdl.is_tools:
            chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
        use_kwargs = self._clean_param(chat_partial, **kwargs)
        for txt in chat_partial(**use_kwargs):
        for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs):
            if isinstance(txt, int):
                total_tokens = txt
                if self.langfuse:
                    generation.update(output={"output": ans})
                    generation.end()
                break

            if txt.endswith("</think>"):
                ans = ans[: -len("</think>")]
                ans = txt[: -len("</think>")]
                continue

            if not self.verbose_tool_use:
                txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

            ans += txt
            # concatenation has been done in async_chat_streamly
            ans = txt
            yield ans

        if total_tokens > 0:
            if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
                logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
    def _bridge_sync_stream(self, gen):
        loop = asyncio.get_running_loop()
        queue: asyncio.Queue = asyncio.Queue()
@ -352,7 +375,7 @@ class LLMBundle(LLM4Tenant):
            try:
                for item in gen:
                    loop.call_soon_threadsafe(queue.put_nowait, item)
            except Exception as e:  # pragma: no cover
            except Exception as e:
                loop.call_soon_threadsafe(queue.put_nowait, e)
            finally:
                loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
@ -361,18 +384,27 @@ class LLMBundle(LLM4Tenant):
        return queue
    async def async_chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
        chat_partial = partial(self.mdl.chat, system, history, gen_conf, **kwargs)
        if self.is_tools and self.mdl.is_tools and hasattr(self.mdl, "chat_with_tools"):
            chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf, **kwargs)
        if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_with_tools"):
            base_fn = self.mdl.async_chat_with_tools
        elif hasattr(self.mdl, "async_chat"):
            base_fn = self.mdl.async_chat
        else:
            raise RuntimeError(f"Model {self.mdl} does not implement async_chat or async_chat_with_tools")

        generation = None
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat", model=self.llm_name, input={"system": system, "history": history})

        chat_partial = partial(base_fn, system, history, gen_conf)
        use_kwargs = self._clean_param(chat_partial, **kwargs)

        if hasattr(self.mdl, "async_chat_with_tools") and self.is_tools and self.mdl.is_tools:
            txt, used_tokens = await self.mdl.async_chat_with_tools(system, history, gen_conf, **use_kwargs)
        elif hasattr(self.mdl, "async_chat"):
            txt, used_tokens = await self.mdl.async_chat(system, history, gen_conf, **use_kwargs)
        else:
            txt, used_tokens = await asyncio.to_thread(chat_partial, **use_kwargs)
        try:
            txt, used_tokens = await chat_partial(**use_kwargs)
        except Exception as e:
            if generation:
                generation.update(output={"error": str(e)})
                generation.end()
            raise

        txt = self._remove_reasoning_content(txt)
        if not self.verbose_tool_use:
@ -381,49 +413,51 @@ class LLMBundle(LLM4Tenant):
        if used_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
            logging.error("LLMBundle.async_chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))

        if generation:
            generation.update(output={"output": txt}, usage_details={"total_tokens": used_tokens})
            generation.end()

        return txt
    async def async_chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
        total_tokens = 0
        ans = ""
        if self.is_tools and self.mdl.is_tools:
        if self.is_tools and getattr(self.mdl, "is_tools", False) and hasattr(self.mdl, "async_chat_streamly_with_tools"):
            stream_fn = getattr(self.mdl, "async_chat_streamly_with_tools", None)
        else:
        elif hasattr(self.mdl, "async_chat_streamly"):
            stream_fn = getattr(self.mdl, "async_chat_streamly", None)
        else:
            raise RuntimeError(f"Model {self.mdl} does not implement async_chat_streamly or async_chat_streamly_with_tools")

        generation = None
        if self.langfuse:
            generation = self.langfuse.start_generation(trace_context=self.trace_context, name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})

        if stream_fn:
            chat_partial = partial(stream_fn, system, history, gen_conf)
            use_kwargs = self._clean_param(chat_partial, **kwargs)
            async for txt in chat_partial(**use_kwargs):
                if isinstance(txt, int):
                    total_tokens = txt
                    break
            try:
                async for txt in chat_partial(**use_kwargs):
                    if isinstance(txt, int):
                        total_tokens = txt
                        break

                    if txt.endswith("</think>"):
                        ans = ans[: -len("</think>")]

                    if not self.verbose_tool_use:
                        txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

                    ans += txt
                    yield ans
            except Exception as e:
                if generation:
                    generation.update(output={"error": str(e)})
                    generation.end()
                raise
            if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
                logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
            if generation:
                generation.update(output={"output": ans}, usage_details={"total_tokens": total_tokens})
                generation.end()
            return

        chat_partial = partial(self.mdl.chat_streamly_with_tools if (self.is_tools and self.mdl.is_tools) else self.mdl.chat_streamly, system, history, gen_conf)
        use_kwargs = self._clean_param(chat_partial, **kwargs)
        queue = self._bridge_sync_stream(chat_partial(**use_kwargs))
        while True:
            item = await queue.get()
            if item is StopAsyncIteration:
                break
            if isinstance(item, Exception):
                raise item
            if isinstance(item, int):
                total_tokens = item
                break
            yield item

        if total_tokens and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, total_tokens, self.llm_name):
            logging.error("LLMBundle.async_chat_streamly can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, total_tokens))
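Both `_run_coroutine_sync` and `_sync_from_async_stream` above use the same bridge: run the async work on a private event loop in a daemon thread and hand results back through a queue, with a sentinel marking the end of the stream and exceptions re-raised in the caller. A self-contained sketch of the pattern (illustration only; the function name is mine, not the class API):

```python
import asyncio
import queue
import threading


def iterate_async_gen_sync(async_gen):
    """Consume an async generator from plain synchronous code."""
    q: queue.Queue = queue.Queue()

    def runner():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        async def consume():
            try:
                async for item in async_gen:
                    q.put(item)
            except Exception as e:
                q.put(e)  # forward the error to the sync consumer
            finally:
                q.put(StopIteration)  # sentinel: stream is done

        loop.run_until_complete(consume())
        loop.close()

    threading.Thread(target=runner, daemon=True).start()
    while True:
        item = q.get()
        if item is StopIteration:
            break
        if isinstance(item, Exception):
            raise item
        yield item


async def ticks(n):
    for i in range(n):
        await asyncio.sleep(0)
        yield i


print(list(iterate_async_gen_sync(ticks(3))))  # [0, 1, 2]
```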
@ -126,7 +126,7 @@ OnyxConfluence:
    def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
        """credential_json - the current json credentials
        Returns a tuple
        1. The up to date credentials
        1. The up-to-date credentials
        2. True if the credentials were updated

        This method is intended to be used within a distributed lock.
@ -179,8 +179,8 @@ OnyxConfluence:
            credential_json["confluence_refresh_token"],
        )

        # store the new credentials to redis and to the db thru the provider
        # store the new credentials to redis and to the db through the provider
        # redis: we use a 5 min TTL because we are given a 10 minute grace period
        # redis: we use a 5 min TTL because we are given a 10-minute grace period
        # when keys are rotated. it's easier to expire the cached credentials
        # reasonably frequently rather than trying to handle strong synchronization
        # between the db and redis everywhere the credentials might be updated
@ -690,7 +690,7 @@ OnyxConfluence:
    ) -> Iterator[dict[str, Any]]:
        """
        This function will paginate through the top level query first, then
        paginate through all of the expansions.
        paginate through all the expansions.
        """

    def _traverse_and_update(data: dict | list) -> None:
@ -863,7 +863,7 @@ def get_user_email_from_username__server(
        # For now, we'll just return None and log a warning. This means
        # we will keep retrying to get the email every group sync.
        email = None
        # We may want to just return a string that indicates failure so we dont
        # We may want to just return a string that indicates failure so we don't
        # keep retrying
        # email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
        _USER_EMAIL_CACHE[user_name] = email
@ -912,7 +912,7 @@ def extract_text_from_confluence_html(
    confluence_object: dict[str, Any],
    fetched_titles: set[str],
) -> str:
    """Parse a Confluence html page and replace the 'user Id' by the real
    """Parse a Confluence html page and replace the 'user id' by the real
    User Display Name

    Args:
@ -33,7 +33,7 @@ def _convert_message_to_document(
    metadata: dict[str, str | list[str]] = {}
    semantic_substring = ""

    # Only messages from TextChannels will make it here but we have to check for it anyways
    # Only messages from TextChannels will make it here, but we have to check for it anyway
    if isinstance(message.channel, TextChannel) and (channel_name := message.channel.name):
        metadata["Channel"] = channel_name
        semantic_substring += f" in Channel: #{channel_name}"
@ -176,7 +176,7 @@ def _manage_async_retrieval(
    # parse requested_start_date_string to datetime
    pull_date: datetime | None = datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(tzinfo=timezone.utc) if requested_start_date_string else None

    # Set start_time to the later of start and pull_date, or whichever is provided
    # Set start_time to the most recent of start and pull_date, or whichever is provided
    start_time = max(filter(None, [start, pull_date])) if start or pull_date else None

    end_time: datetime | None = end
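The `start_time` expression above leans on `filter(None, ...)` to drop missing values before `max` picks the most recent datetime; a quick sketch of the behavior (toy values):

```python
from datetime import datetime, timezone

start = datetime(2024, 1, 1, tzinfo=timezone.utc)
pull_date = None  # e.g. no requested_start_date_string was supplied

# max() only sees the non-None candidates; if both are None, the
# conditional makes the whole expression fall back to None.
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
print(start_time)  # 2024-01-01 00:00:00+00:00
```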
@ -76,7 +76,7 @@ ALL_ACCEPTED_FILE_EXTENSIONS = ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS + ACCEPTED_DO

MAX_RETRIEVER_EMAILS = 20
CHUNK_SIZE_BUFFER = 64  # extra bytes past the limit to read
# This is not a standard valid unicode char, it is used by the docs advanced API to
# This is not a standard valid Unicode char, it is used by the docs advanced API to
# represent smart chips (elements like dates and doc links).
SMART_CHIP_CHAR = "\ue907"
WEB_VIEW_LINK_KEY = "webViewLink"
@ -141,7 +141,7 @@ def crawl_folders_for_files(
    # Only mark a folder as done if it was fully traversed without errors
    # This usually indicates that the owner of the folder was impersonated.
    # In cases where this never happens, most likely the folder owner is
    # not part of the google workspace in question (or for oauth, the authenticated
    # not part of the Google Workspace in question (or for oauth, the authenticated
    # user doesn't own the folder)
    if found_files:
        update_traversed_ids_func(parent_id)
@ -232,7 +232,7 @@ def get_files_in_shared_drive(
    **kwargs,
):
    # If we found any files, mark this drive as traversed. When a user has access to a drive,
    # they have access to all the files in the drive. Also not a huge deal if we re-traverse
    # they have access to all the files in the drive. Also, not a huge deal if we re-traverse
    # empty drives.
    # NOTE: ^^ the above is not actually true due to folder restrictions:
    # https://support.google.com/a/users/answer/12380484?hl=en
@ -22,7 +22,7 @@ class GDriveMimeType(str, Enum):
    MARKDOWN = "text/markdown"


# These correspond to The major stages of retrieval for google drive.
# These correspond to the major stages of retrieval for Google Drive.
# The stages for the oauth flow are:
# get_all_files_for_oauth(),
# get_all_drive_ids(),
@ -117,7 +117,7 @@ class GoogleDriveCheckpoint(ConnectorCheckpoint):

class RetrievedDriveFile(BaseModel):
    """
    Describes a file that has been retrieved from google drive.
    Describes a file that has been retrieved from Google Drive.
    user_email is the email of the user that the file was retrieved
    by impersonating. If an error worthy of being reported is encountered,
    error should be set and later propagated as a ConnectorFailure.
@ -29,8 +29,8 @@ class GmailService(Resource):

class RefreshableDriveObject:
    """
    Running Google drive service retrieval functions
    Running Google Drive service retrieval functions
    involves accessing methods of the service object (ie. files().list())
    involves accessing methods of the service object (i.e. files().list())
    which can raise a RefreshError if the access token is expired.
    This class is a wrapper that propagates the ability to refresh the access token
    and retry the final retrieval function until execute() is called.
@ -120,7 +120,7 @@ def format_document_soup(
    # table is standard HTML element
    if e.name == "table":
        in_table = True
    # tr is for rows
    # TR is for rows
    elif e.name == "tr" and in_table:
        text += "\n"
    # td for data cell, th for header
@ -395,8 +395,7 @@ class AttachmentProcessingResult(BaseModel):


class IndexingHeartbeatInterface(ABC):
    """Defines a callback interface to be passed to
    to run_indexing_entrypoint."""
    """Defines a callback interface to be passed to run_indexing_entrypoint."""

    @abstractmethod
    def should_stop(self) -> bool:
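For context, a minimal implementation of the interface whose docstring was just fixed might look like the following sketch; only the `should_stop` method shown in the diff is modeled, and the subclass name is hypothetical:

```python
from abc import ABC, abstractmethod


class IndexingHeartbeatInterface(ABC):
    """Defines a callback interface to be passed to run_indexing_entrypoint."""

    @abstractmethod
    def should_stop(self) -> bool:
        ...


class FlagHeartbeat(IndexingHeartbeatInterface):
    """Hypothetical implementation: another component flips `stopped` to abort indexing."""

    def __init__(self):
        self.stopped = False

    def should_stop(self) -> bool:
        return self.stopped
```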
@ -80,7 +80,7 @@ _TZ_OFFSET_PATTERN = re.compile(r"([+-])(\d{2})(:?)(\d{2})$")


class JiraConnector(CheckpointedConnectorWithPermSync, SlimConnectorWithPermSync):
    """Retrieve Jira issues and emit them as markdown documents."""
    """Retrieve Jira issues and emit them as Markdown documents."""

    def __init__(
        self,
@ -54,8 +54,8 @@ class ExternalAccess:
    A helper function that returns an *empty* set of external user-emails and group-ids, and sets `is_public` to `False`.
    This effectively makes the document in question "private" or inaccessible to anyone else.

    This is especially helpful to use when you are performing permission-syncing, and some document's permissions aren't able
    This is especially helpful to use when you are performing permission-syncing, and some document's permissions can't
    to be determined (for whatever reason). Setting its `ExternalAccess` to "private" is a feasible fallback.
    be determined (for whatever reason). Setting its `ExternalAccess` to "private" is a feasible fallback.
    """

    return cls(
@ -61,7 +61,7 @@ def clean_markdown_block(text):
        str: Cleaned text with Markdown code block syntax removed, and stripped of surrounding whitespace

    """
    # Remove opening ```markdown tag with optional whitespace and newlines
    # Remove opening ```Markdown tag with optional whitespace and newlines
    # Matches: optional whitespace + ```markdown + optional whitespace + optional newline
    text = re.sub(r'^\s*```markdown\s*\n?', '', text)
@ -51,7 +51,7 @@ We use vision information to resolve problems as human being.
```bash
python deepdoc/vision/t_ocr.py --inputs=path_to_images_or_pdfs --output_dir=path_to_store_result
```
The inputs could be directory to images or PDF, or a image or PDF.
The inputs could be a directory of images or PDFs, or a single image or PDF.
You can look into the folder 'path_to_store_result', which has images demonstrating the positions of the results
and txt files containing the OCR text.
<div align="center" style="margin-top:20px;margin-bottom:20px;">
@ -78,7 +78,7 @@ We use vision information to resolve problems as human being.
```bash
python deepdoc/vision/t_recognizer.py --inputs=path_to_images_or_pdfs --threshold=0.2 --mode=layout --output_dir=path_to_store_result
```
The inputs could be directory to images or PDF, or a image or PDF.
The inputs could be a directory of images or PDFs, or a single image or PDF.
You can look into the folder 'path_to_store_result', which has images demonstrating the detection results, as follows:
<div align="center" style="margin-top:20px;margin-bottom:20px;">
<img src="https://github.com/infiniflow/ragflow/assets/12318111/07e0f625-9b28-43d0-9fbb-5bf586cd286f" width="1000"/>
@ -41,7 +41,7 @@ class RAGFlowExcelParser:

        try:
            file_like_object.seek(0)
            df = pd.read_csv(file_like_object)
            df = pd.read_csv(file_like_object, on_bad_lines='skip')
            return RAGFlowExcelParser._dataframe_to_workbook(df)

        except Exception as e_csv:
@ -164,7 +164,7 @@ class RAGFlowExcelParser:
        except Exception as e:
            logging.warning(f"Parse spreadsheet error: {e}, trying to interpret as CSV file")
            file_like_object.seek(0)
            df = pd.read_csv(file_like_object)
            df = pd.read_csv(file_like_object, on_bad_lines='skip')
            df = df.replace(r"^\s*$", "", regex=True)
            return df.to_markdown(index=False)
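The `on_bad_lines='skip'` argument added at both call sites tells pandas to drop malformed rows instead of raising a parser error; a minimal illustration (assuming pandas 1.3 or later, where the parameter was introduced):

```python
import io

import pandas as pd

csv_data = "a,b\n1,2\n3,4,5\n6,7\n"  # the row "3,4,5" has one field too many

df = pd.read_csv(io.StringIO(csv_data), on_bad_lines="skip")
print(df)  # only the rows (1, 2) and (6, 7) survive; the bad row is skipped
```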
@ -151,7 +151,7 @@ class RAGFlowHtmlParser:
        block_content = []
        current_content = ""
        table_info_list = []
        lask_block_id = None
        last_block_id = None
        for item in parser_result:
            content = item.get("content")
            tag_name = item.get("tag_name")
@ -160,11 +160,11 @@ class RAGFlowHtmlParser:
            if block_id:
                if title_flag:
                    content = f"{TITLE_TAGS[tag_name]} {content}"
                if lask_block_id != block_id:
                if last_block_id != block_id:
                    if lask_block_id is not None:
                    if last_block_id is not None:
                        block_content.append(current_content)
                    current_content = content
                    lask_block_id = block_id
                    last_block_id = block_id
                else:
                    current_content += (" " if current_content else "") + content
            else:
@ -582,7 +582,7 @@ class OCR:
        self.crop_image_res_index = 0

    def get_rotate_crop_image(self, img, points):
        '''
        """
        img_height, img_width = img.shape[0:2]
        left = int(np.min(points[:, 0]))
        right = int(np.max(points[:, 0]))
@ -591,7 +591,7 @@ class OCR:
        img_crop = img[top:bottom, left:right, :].copy()
        points[:, 0] = points[:, 0] - left
        points[:, 1] = points[:, 1] - top
        '''
        """
        assert len(points) == 4, "shape of points must be 4*2"
        img_crop_width = int(
            max(
@ -67,10 +67,10 @@ class DBPostProcess:
            [[1, 1], [1, 1]])

    def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        """
        _bitmap: single map with shape (1, H, W),
        whose values are binarized as {0, 1}
        '''
        """

        bitmap = _bitmap
        height, width = bitmap.shape
@ -114,10 +114,10 @@ class DBPostProcess:
        return boxes, scores

    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
        '''
        """
        _bitmap: single map with shape (1, H, W),
        whose values are binarized as {0, 1}
        '''
        """

        bitmap = _bitmap
        height, width = bitmap.shape
@ -192,9 +192,9 @@ class DBPostProcess:
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box):
        '''
        """
        box_score_fast: use bbox mean score as the mean score
        '''
        """
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
@ -209,9 +209,9 @@ class DBPostProcess:
        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]

    def box_score_slow(self, bitmap, contour):
        '''
        """
        box_score_slow: use polyon mean score as the mean score
        box_score_slow: use polygon mean score as the mean score
        '''
        """
        h, w = bitmap.shape[:2]
        contour = contour.copy()
        contour = np.reshape(contour, (-1, 2))
@ -155,7 +155,7 @@ class TableStructureRecognizer(Recognizer):
        while i < len(boxes):
            if TableStructureRecognizer.is_caption(boxes[i]):
                if is_english:
                    cap + " "
                    cap += " "
                cap += boxes[i]["text"]
                boxes.pop(i)
                i -= 1
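`box_score_fast` above scores a candidate text box by averaging the probability map over its bounding rectangle. A rough numpy-only sketch of the idea (simplified: the real method also masks to the polygon via OpenCV, while this version uses only the axis-aligned rectangle):

```python
import numpy as np


def box_score_sketch(bitmap: np.ndarray, box: np.ndarray) -> float:
    """Mean score over the axis-aligned bounding rectangle of a 4x2 point box."""
    h, w = bitmap.shape[:2]
    xmin = int(np.clip(np.floor(box[:, 0].min()), 0, w - 1))
    xmax = int(np.clip(np.ceil(box[:, 0].max()), 0, w - 1))
    ymin = int(np.clip(np.floor(box[:, 1].min()), 0, h - 1))
    ymax = int(np.clip(np.ceil(box[:, 1].max()), 0, h - 1))
    return float(bitmap[ymin:ymax + 1, xmin:xmax + 1].mean())


pred = np.random.rand(32, 32).astype(np.float32)  # fake probability map
pts = np.array([[4, 4], [12, 4], [12, 10], [4, 10]], dtype=np.float32)
print(box_score_sketch(pred, pts))
```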
@ -25,9 +25,9 @@ services:
      # - --no-transport-streamable-http-enabled # Disable Streamable HTTP transport (/mcp endpoint)
      # - --no-json-response # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP)

      # Example configration to start Admin server:
      # Example configuration to start Admin server:
      # command:
    command:
      # - --enable-adminserver
      - --enable-adminserver
    ports:
      - ${SVR_WEB_HTTP_PORT}:80
      - ${SVR_WEB_HTTPS_PORT}:443
@ -74,9 +74,9 @@ services:
      # - --no-transport-streamable-http-enabled # Disable Streamable HTTP transport (/mcp endpoint)
      # - --no-json-response # Disable JSON response mode in Streamable HTTP transport (instead of SSE over HTTP)

      # Example configration to start Admin server:
      # Example configuration to start Admin server:
      # command:
    command:
      # - --enable-adminserver
      - --enable-adminserver
    ports:
      - ${SVR_WEB_HTTP_PORT}:80
      - ${SVR_WEB_HTTPS_PORT}:443
@ -151,7 +151,7 @@ See [Build a RAGFlow Docker image](./develop/build_docker_image.mdx).

### Cannot access https://huggingface.co

A locally deployed RAGflow downloads OCR models from [Huggingface website](https://huggingface.co) by default. If your machine is unable to access this site, the following error occurs and PDF parsing fails:
A locally deployed RAGFlow downloads OCR models from [Huggingface website](https://huggingface.co) by default. If your machine is unable to access this site, the following error occurs and PDF parsing fails:

```
FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/hub/models--InfiniFlow--deepdoc/snapshots/be0c1e50eef6047b412d1800aa89aba4d275f997/ocr.res'
```
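If the Hugging Face site is unreachable, a common workaround is to route downloads through a mirror by setting `HF_ENDPOINT` before any download starts. Treat the mirror URL below as an assumption about your network, not part of this diff; the repo ID matches the model path in the error above:

```python
import os

# Assumption: hf-mirror.com is reachable from your machine; the variable must be
# set before huggingface_hub issues any request.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from huggingface_hub import snapshot_download  # noqa: E402

snapshot_download(repo_id="InfiniFlow/deepdoc", local_dir="./deepdoc_models")
```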
@ -45,13 +45,13 @@ Click the light bulb icon above the *current* dialogue and scroll down the popup


| Item name | Description |
| ----------------- | --------------------------------------------------------------------------------------------- |
| ----------------- |-----------------------------------------------------------------------------------------------|
| Total | Total time spent on this conversation round, including chunk retrieval and answer generation. |
| Check LLM | Time to validate the specified LLM. |
| Create retriever | Time to create a chunk retriever. |
| Bind embedding | Time to initialize an embedding model instance. |
| Bind LLM | Time to initialize an LLM instance. |
| Tune question | Time to optimize the user query using the context of the mult-turn conversation. |
| Tune question | Time to optimize the user query using the context of the multi-turn conversation. |
| Bind reranker | Time to initialize a reranker model instance for chunk retrieval. |
| Generate keywords | Time to extract keywords from the user query. |
| Retrieval | Time to retrieve the chunks. |
@ -37,7 +37,7 @@ Please note that rerank models are essential in certain scenarios. There is alwa
| Create retriever | Time to create a chunk retriever. |
| Bind embedding | Time to initialize an embedding model instance. |
| Bind LLM | Time to initialize an LLM instance. |
| Tune question | Time to optimize the user query using the context of the mult-turn conversation. |
| Tune question | Time to optimize the user query using the context of the multi-turn conversation. |
| Bind reranker | Time to initialize a reranker model instance for chunk retrieval. |
| Generate keywords | Time to extract keywords from the user query. |
| Retrieval | Time to retrieve the chunks. |
@ -8,7 +8,7 @@ slug: /manage_users_and_services

The Admin CLI and Admin Service form a client-server architectural suite for RAGflow system administration. The Admin CLI serves as an interactive command-line interface that receives instructions and displays execution results from the Admin Service in real-time. This duo enables real-time monitoring of system operational status, supporting visibility into RAGflow Server services and dependent components including MySQL, Elasticsearch, Redis, and MinIO. In administrator mode, they provide user management capabilities that allow viewing users and performing critical operations—such as user creation, password updates, activation status changes, and comprehensive user data deletion—even when corresponding web interface functionalities are disabled.
The Admin CLI and Admin Service form a client-server architectural suite for RAGFlow system administration. The Admin CLI serves as an interactive command-line interface that receives instructions and displays execution results from the Admin Service in real-time. This duo enables real-time monitoring of system operational status, supporting visibility into RAGFlow Server services and dependent components including MySQL, Elasticsearch, Redis, and MinIO. In administrator mode, they provide user management capabilities that allow viewing users and performing critical operations—such as user creation, password updates, activation status changes, and comprehensive user data deletion—even when corresponding web interface functionalities are disabled.
@ -305,7 +305,7 @@ With the Ollama service running, open a new terminal and run `./ollama pull <mod

</TabItem>
</Tabs>

### 4. Configure RAGflow
### 4. Configure RAGFlow

To enable IPEX-LLM accelerated Ollama in RAGFlow, you must also complete the configurations in RAGFlow. The steps are identical to those outlined in the *Deploy a local model using Ollama* section:
@ -4013,7 +4013,7 @@ Failure:

**DELETE** `/api/v1/agents/{agent_id}/sessions`

Deletes sessions of a agent by ID.
Deletes sessions of an agent by ID.

#### Request
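As a quick illustration of the endpoint above, the sketch below deletes two sessions over HTTP. It is a hedged example, not taken verbatim from the docs: the `localhost:9380` base URL is an assumption, and the `ids` body field mirrors the `Agent.delete_sessions(ids=...)` signature documented further down.

```python
# Hedged sketch: call DELETE /api/v1/agents/{agent_id}/sessions with requests.
import requests

AGENT_ID = "<AGENT_ID>"     # placeholder
API_KEY = "<YOUR_API_KEY>"  # placeholder

resp = requests.delete(
    f"http://localhost:9380/api/v1/agents/{AGENT_ID}/sessions",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={"ids": ["<SESSION_ID_1>", "<SESSION_ID_2>"]},  # assumed body shape
)
print(resp.status_code, resp.json())
```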
@ -4072,7 +4072,7 @@ Failure:

Generates five to ten alternative question strings from the user's original query to retrieve more relevant search results.

This operation requires a `Bearer Login Token`, which typically expires with in 24 hours. You can find the it in the Request Headers in your browser easily as shown below:
This operation requires a `Bearer Login Token`, which typically expires with in 24 hours. You can find it in the Request Headers in your browser easily as shown below:

![access token](https://raw.githubusercontent.com/infiniflow/ragflow-docs/blob/main/images/token.png)
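For orientation, sending that login token from code looks roughly like the sketch below. It is illustrative only: the exact route for this operation is not shown in this hunk, so the path used here is a hypothetical placeholder; the documented requirement is only the `Authorization: Bearer <login token>` header, using the token copied from the browser rather than a regular API key.

```python
# Hedged sketch: authenticate with the short-lived Bearer Login Token.
import requests

LOGIN_TOKEN = "<LOGIN_TOKEN_FROM_REQUEST_HEADERS>"  # typically expires within 24 hours

resp = requests.post(
    "http://localhost:9380/v1/some_endpoint",  # hypothetical path, see note above
    headers={"Authorization": f"Bearer {LOGIN_TOKEN}"},
    json={"question": "What is RAGFlow?"},
)
print(resp.json())
```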
@ -1740,7 +1740,7 @@ for session in sessions:

Agent.delete_sessions(ids: list[str] = None)
```

Deletes sessions of a agent by ID.
Deletes sessions of an agent by ID.

#### Parameters
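A short usage sketch of the SDK call documented above. The surrounding setup follows the usual `ragflow_sdk` pattern, but the identifiers are placeholders and the behaviour noted in the comment is an assumption.

```python
# Hedged sketch: delete selected sessions of an agent via the Python SDK.
from ragflow_sdk import RAGFlow

rag_object = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://localhost:9380")
agent = rag_object.list_agents(id="<AGENT_ID>")[0]

# ids defaults to None in the signature above; passing explicit ids deletes
# only those sessions (treating None as "delete all" is an assumption here).
agent.delete_sessions(ids=["<SESSION_ID>"])
```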
@ -5,6 +5,7 @@

# requires-python = ">=3.10"
# dependencies = [
# "nltk",
# "huggingface-hub"
# ]
# ///
@ -14,9 +14,9 @@

# limitations under the License.
#

'''
"""
The example is about CRUD operations (Create, Read, Update, Delete) on a dataset.
'''
"""

from ragflow_sdk import RAGFlow
import sys
@ -92,6 +92,6 @@ def get_metadata(cls) -> LLMToolMetadata:

The `get_metadata` method is a `classmethod`. It will provide the description of this tool to LLM.

The fields starts with `display` can use a special notation: `$t:xxx`, which will use the i18n mechanism in the RAGFlow frontend, getting text from the `llmTools` category. The frontend will display what you put here if you don't use this notation.
The fields start with `display` can use a special notation: `$t:xxx`, which will use the i18n mechanism in the RAGFlow frontend, getting text from the `llmTools` category. The frontend will display what you put here if you don't use this notation.

Now our tool is ready. You can select it in the `Generate` component and try it out.
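To make the `display`/`$t:xxx` convention concrete, here is a hedged sketch of such a `classmethod`. The exact `LLMToolMetadata` fields are illustrative assumptions, not the library's contract; only the `$t:` lookup into the frontend's `llmTools` i18n category is taken from the text above.

```python
# Hedged sketch: a get_metadata classmethod using the $t:xxx display notation.
# Field names below are assumptions for illustration purposes only.
from plugin.llm_tool_plugin import LLMToolMetadata, LLMToolPlugin


class BadCalculatorPlugin(LLMToolPlugin):
    _version_ = "1.0.0"

    @classmethod
    def get_metadata(cls) -> LLMToolMetadata:
        return LLMToolMetadata(
            name="bad_calculator",                        # assumed field
            displayName="$t:bad_calculator.name",         # resolved via llmTools i18n
            description="Adds two numbers, plus 100.",    # plain text shown to the LLM
            displayDescription="$t:bad_calculator.description",
        )
```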
@ -5,7 +5,7 @@ from plugin.llm_tool_plugin import LLMToolMetadata, LLMToolPlugin

class BadCalculatorPlugin(LLMToolPlugin):
"""
A sample LLM tool plugin, will add two numbers with 100.
It only present for demo purpose. Do not use it in production.
It only presents for demo purpose. Do not use it in production.
"""
_version_ = "1.0.0"
@ -70,7 +70,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,

"""
Supported file formats are docx, pdf, txt.
Since a book is long and not all the parts are useful, if it's a PDF,
please setup the page ranges for every book in order eliminate negative effects and save elapsed computing time.
please set up the page ranges for every book in order eliminate negative effects and save elapsed computing time.
"""
parser_config = kwargs.get(
"parser_config", {
@ -143,13 +143,14 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,

elif re.search(r"\.doc$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
with BytesIO(binary) as binary:
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [(line, "") for line in sections if line]
remove_contents_table(sections, eng=is_english(
random_choices([t for t, _ in sections], k=200)))
callback(0.8, "Finish parsing.")

else:
raise NotImplementedError(

@ -201,12 +201,23 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,

elif re.search(r"\.doc$", filename, re.IGNORECASE):
callback(0.1, "Start to parse.")
binary = BytesIO(binary)
doc_parsed = parser.from_buffer(binary)
sections = doc_parsed['content'].split('\n')
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
try:
from tika import parser as tika_parser
except Exception as e:
callback(0.8, f"tika not available: {e}. Unsupported .doc parsing.")
logging.warning(f"tika not available: {e}. Unsupported .doc parsing for {filename}.")
return []

binary = BytesIO(binary)
doc_parsed = tika_parser.from_buffer(binary)
if doc_parsed.get('content', None) is not None:
sections = doc_parsed['content'].split('\n')
sections = [s for s in sections if s]
callback(0.8, "Finish parsing.")
else:
callback(0.8, f"tika.parser got empty content from {filename}.")
logging.warning(f"tika.parser got empty content from {filename}.")
return []
else:
raise NotImplementedError(
"file type not supported yet(doc, docx, pdf, txt supported)")

@ -313,7 +313,7 @@ def mdQuestionLevel(s):

def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", callback=None, **kwargs):
"""
Excel and csv(txt) format files are supported.
If the file is in excel format, there should be 2 column question and answer without header.
If the file is in Excel format, there should be 2 column question and answer without header.
And question column is ahead of answer column.
And it's O.K if it has multiple sheets as long as the columns are rightly composed.
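To illustrate the expected layout, the sketch below writes a Q&A Excel file with two header-less columns, question first. `pandas` and `openpyxl` are assumptions here; any tool producing this layout works.

```python
# Hedged sketch: build a two-column, header-less Q&A workbook.
import pandas as pd

rows = [
    ("What is RAGFlow?", "An open-source RAG engine."),
    ("Which formats does the Q&A parser accept?", "Excel and csv(txt)."),
]
pd.DataFrame(rows).to_excel("qa.xlsx", header=False, index=False)
```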

@ -37,7 +37,7 @@ def beAdoc(d, q, a, eng, row_num=-1):

def chunk(filename, binary=None, lang="Chinese", callback=None, **kwargs):
"""
Excel and csv(txt) format files are supported.
If the file is in excel format, there should be 2 column content and tags without header.
If the file is in Excel format, there should be 2 column content and tags without header.
And content column is ahead of tags column.
And it's O.K if it has multiple sheets as long as the columns are rightly composed.

@ -15,9 +15,8 @@

import json
import logging
import random
from copy import deepcopy, copy
from copy import deepcopy

import trio
import xxhash

from agent.component.llm import LLMParam, LLM
@ -38,13 +37,13 @@ class ExtractorParam(ProcessParamBase, LLMParam):

class Extractor(ProcessBase, LLM):
component_name = "Extractor"

def _build_TOC(self, docs):
async def _build_TOC(self, docs):
self.callback(message="Start to generate table of content ...")
self.callback(0.2,message="Start to generate table of content ...")
docs = sorted(docs, key=lambda d:(
d.get("page_num_int", 0)[0] if isinstance(d.get("page_num_int", 0), list) else d.get("page_num_int", 0),
d.get("top_int", 0)[0] if isinstance(d.get("top_int", 0), list) else d.get("top_int", 0)
))
toc: list[dict] = trio.run(run_toc_from_text, [d["text"] for d in docs], self.chat_mdl)
toc = await run_toc_from_text([d["text"] for d in docs], self.chat_mdl)
logging.info("------------ T O C -------------\n"+json.dumps(toc, ensure_ascii=False, indent=' '))
ii = 0
while ii < len(toc):
@ -61,7 +60,8 @@ class Extractor(ProcessBase, LLM):

ii += 1

if toc:
d = copy.deepcopy(docs[-1])
d = deepcopy(docs[-1])
d["doc_id"] = self._canvas._doc_id
d["content_with_weight"] = json.dumps(toc, ensure_ascii=False)
d["toc_kwd"] = "toc"
d["available_int"] = 0

@ -85,11 +85,14 @@ class Extractor(ProcessBase, LLM):

if chunks:
if self._param.field_name == "toc":
toc = self._build_TOC(chunks)
for ck in chunks:
ck["doc_id"] = self._canvas._doc_id
ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
toc =await self._build_TOC(chunks)
chunks.append(toc)
self.set_output("chunks", chunks)
return

prog = 0
for i, ck in enumerate(chunks):
args[chunks_key] = ck["text"]
@ -125,7 +125,7 @@ class Splitter(ProcessBase):

{
"text": RAGFlowPdfParser.remove_tag(c),
"image": img,
"positions": [[pos[0][-1]+1, *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)]
"positions": [[pos[0][-1], *pos[1:]] for pos in RAGFlowPdfParser.extract_positions(c)]
}
for c, img in zip(chunks, images) if c.strip()
]
@ -52,6 +52,8 @@ class SupportedLiteLLMProvider(StrEnum):

JiekouAI = "Jiekou.AI"
ZHIPU_AI = "ZHIPU-AI"
MiniMax = "MiniMax"
DeerAPI = "DeerAPI"
GPUStack = "GPUStack"


FACTORY_DEFAULT_BASE_URL = {
@ -75,6 +77,7 @@ FACTORY_DEFAULT_BASE_URL = {

SupportedLiteLLMProvider.JiekouAI: "https://api.jiekou.ai/openai",
SupportedLiteLLMProvider.ZHIPU_AI: "https://open.bigmodel.cn/api/paas/v4",
SupportedLiteLLMProvider.MiniMax: "https://api.minimaxi.com/v1",
SupportedLiteLLMProvider.DeerAPI: "https://api.deerapi.com/v1",
}


@ -108,6 +111,8 @@ LITELLM_PROVIDER_PREFIX = {

SupportedLiteLLMProvider.JiekouAI: "openai/",
SupportedLiteLLMProvider.ZHIPU_AI: "openai/",
SupportedLiteLLMProvider.MiniMax: "openai/",
SupportedLiteLLMProvider.DeerAPI: "openai/",
SupportedLiteLLMProvider.GPUStack: "openai/",
}

ChatModel = globals().get("ChatModel", {})
@ -19,7 +19,6 @@ import logging

import os
import random
import re
import threading
import time
from abc import ABC
from copy import deepcopy
@ -78,11 +77,9 @@ class Base(ABC):

self.toolcall_sessions = {}

def _get_delay(self):
"""Calculate retry delay time"""
return self.base_delay * random.uniform(10, 150)

def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower()

keywords_mapping = [

@ -139,89 +136,7 @@ class Base(ABC):

return gen_conf

def _bridge_sync_stream(self, gen):
async def _async_chat_streamly(self, history, gen_conf, **kwargs):
"""Run a sync generator in a thread and yield asynchronously."""
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()

def worker():
try:
for item in gen:
loop.call_soon_threadsafe(queue.put_nowait, item)
except Exception as exc: # pragma: no cover - defensive
loop.call_soon_threadsafe(queue.put_nowait, exc)
finally:
loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)

threading.Thread(target=worker, daemon=True).start()
return queue

def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using _chat_streamly")

final_ans = ""
tol_token = 0
for delta, tol in self._chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"):
continue
final_ans += delta
tol_token = tol

if len(final_ans.strip()) == 0:
final_ans = "**ERROR**: Empty response from reasoning model"

return final_ans.strip(), tol_token

if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}

response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)

if not response.choices or not response.choices[0].message or not response.choices[0].message.content:
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)

def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False

if kwargs.get("stop") or "stop" in gen_conf:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
else:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)

for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = resp.choices[0].delta.content

tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(resp.choices[0].delta.content)

if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol

async def _async_chat_stream(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False

@ -265,13 +180,19 @@ class Base(ABC):

gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
async for delta_ans, tol in self._async_chat_stream(history, gen_conf, **kwargs):
ans = delta_ans
total_tokens += tol
yield delta_ans
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
for attempt in range(self.max_retries + 1):
try:
async for delta_ans, tol in self._async_chat_streamly(history, gen_conf, **kwargs):
ans = delta_ans
total_tokens += tol
yield ans
except Exception as e:
e = await self._exceptions_async(e, attempt)
if e:
yield e
yield total_tokens
return

yield total_tokens

@ -307,7 +228,7 @@ class Base(ABC):

logging.error(f"sync base giving up: {msg}")
|
logging.error(f"sync base giving up: {msg}")
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
async def _exceptions_async(self, e, attempt) -> str | None:
|
async def _exceptions_async(self, e, attempt):
|
||||||
logging.exception("OpenAI async completion")
|
logging.exception("OpenAI async completion")
|
||||||
error_code = self._classify_error(e)
|
error_code = self._classify_error(e)
|
||||||
if attempt == self.max_retries:
|
if attempt == self.max_retries:
|
||||||
@ -357,61 +278,6 @@ class Base(ABC):

self.toolcall_session = toolcall_session
self.tools = tools

def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})

ans = ""
tk_count = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
tk_count += total_token_count_from_response(response)
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")

if not hasattr(response.choices[0].message, "tool_calls") or not response.choices[0].message.tool_calls:
if hasattr(response.choices[0].message, "reasoning_content") and response.choices[0].message.reasoning_content:
ans += "<think>" + response.choices[0].message.reasoning_content + "</think>"

ans += response.choices[0].message.content
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)

return ans, tk_count

for tool_call in response.choices[0].message.tool_calls:
logging.info(f"Response {tool_call=}")
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
ans += self._verbose_tool_use(name, {}, str(e))

logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = self._chat(history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, tk_count

assert False, "Shouldn't be here."

async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":

@ -466,140 +332,6 @@ class Base(ABC):

assert False, "Shouldn't be here."

def chat(self, system, history, gen_conf={}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)

# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
try:
return self._chat(history, gen_conf, **kwargs)
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."

def _wrap_toolcall_message(self, stream):
final_tool_calls = {}

for chunk in stream:
for tool_call in chunk.choices[0].delta.tool_calls or []:
index = tool_call.index

if index not in final_tool_calls:
final_tool_calls[index] = tool_call

final_tool_calls[index].function.arguments += tool_call.function.arguments

return final_tool_calls

def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})

total_tokens = 0
hist = deepcopy(history)
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds + 1):
reasoning_start = False
logging.info(f"{tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
final_tool_calls = {}
answer = ""
for resp in response:
if resp.choices[0].delta.tool_calls:
for tool_call in resp.choices[0].delta.tool_calls or []:
index = tool_call.index

if index not in final_tool_calls:
if not tool_call.function.arguments:
tool_call.function.arguments = ""
final_tool_calls[index] = tool_call
else:
final_tool_calls[index].function.arguments += tool_call.function.arguments if tool_call.function.arguments else ""
continue

if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")

if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""

if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
yield ans
else:
reasoning_start = False
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content

tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol

finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
yield self._length_stop("")

if answer:
yield total_tokens
return

for tool_call in final_tool_calls.values():
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))

logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
continue
tol = total_token_count_from_response(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens = tol
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content

yield total_tokens
return

except Exception as e:
e = self._exceptions(e, attempt)
if e:
yield e
yield total_tokens
return

assert False, "Shouldn't be here."

async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
@ -715,9 +447,10 @@ class Base(ABC):

logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwq") >= 0:
logging.info(f"[INFO] {self.model_name} detected as reasoning model, using async_chat_streamly")

final_ans = ""
tol_token = 0
async for delta, tol in self._async_chat_stream(history, gen_conf, with_reasoning=False, **kwargs):
async for delta, tol in self._async_chat_streamly(history, gen_conf, with_reasoning=False, **kwargs):
if delta.startswith("<think>") or delta.endswith("</think>"):
continue
final_ans += delta
@ -754,57 +487,6 @@ class Base(ABC):

return e, 0
assert False, "Shouldn't be here."

def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
yield delta_ans
total_tokens += tol
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)

yield total_tokens

def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""

def count_tokens(text):
"""Calculate token count for text"""
# Simple calculation: 1 token per ASCII character
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
total = 0
for char in text:
if ord(char) < 128: # ASCII characters
total += 1
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
total += 2
return total

# Calculate total tokens for all messages
total_tokens = 0
for message in history:
content = message.get("content", "")
# Calculate content tokens
content_tokens = count_tokens(content)
# Add role marker token overhead
role_tokens = 4
total_tokens += content_tokens + role_tokens

# Apply 1.2x buffer ratio
total_tokens_with_buffer = int(total_tokens * 1.2)

if total_tokens_with_buffer <= 8192:
ctx_size = 8192
else:
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
ctx_size = ctx_multiplier * 8192

return ctx_size


class GptTurbo(Base):
_FACTORY_NAME = "OpenAI"
@ -1504,16 +1186,6 @@ class GoogleChat(Base):

yield total_tokens


class GPUStackChat(Base):
_FACTORY_NAME = "GPUStack"

def __init__(self, key=None, model_name="", base_url="", **kwargs):
if not base_url:
raise ValueError("Local llm url cannot be None")
base_url = urljoin(base_url, "v1")
super().__init__(key, model_name, base_url, **kwargs)


class TokenPonyChat(Base):
_FACTORY_NAME = "TokenPony"

@ -1523,15 +1195,6 @@ class TokenPonyChat(Base):

super().__init__(key, model_name, base_url, **kwargs)


class DeerAPIChat(Base):
_FACTORY_NAME = "DeerAPI"

def __init__(self, key, model_name, base_url="https://api.deerapi.com/v1", **kwargs):
if not base_url:
base_url = "https://api.deerapi.com/v1"
super().__init__(key, model_name, base_url, **kwargs)


class LiteLLMBase(ABC):
_FACTORY_NAME = [
"Tongyi-Qianwen",
@ -1562,6 +1225,8 @@ class LiteLLMBase(ABC):

"Jiekou.AI",
"ZHIPU-AI",
"MiniMax",
"DeerAPI",
"GPUStack",
]

def __init__(self, key, model_name, base_url=None, **kwargs):
@ -1589,11 +1254,9 @@ class LiteLLMBase(ABC):

self.provider_order = json.loads(key).get("provider_order", "")

def _get_delay(self):
"""Calculate retry delay time"""
return self.base_delay * random.uniform(10, 150)

def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower()

keywords_mapping = [
@ -1619,72 +1282,6 @@ class LiteLLMBase(ABC):

del gen_conf["max_tokens"]
return gen_conf

def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwen3") >= 0:
kwargs["extra_body"] = {"enable_thinking": False}

completion_args = self._construct_completion_args(history=history, stream=False, tools=False, **gen_conf)
response = litellm.completion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)
# response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
ans = self._length_stop(ans)

return ans, total_token_count_from_response(response)

def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
gen_conf = self._clean_conf(gen_conf)
reasoning_start = False

completion_args = self._construct_completion_args(history=history, stream=True, tools=False, **gen_conf)
stop = kwargs.get("stop")
if stop:
completion_args["stop"] = stop
response = litellm.completion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)

for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue

delta = resp.choices[0].delta
if not hasattr(delta, "content") or delta.content is None:
delta.content = ""

if kwargs.get("with_reasoning", True) and hasattr(delta, "reasoning_content") and delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = delta.content

tol = total_token_count_from_response(resp)
if not tol:
tol = num_tokens_from_string(delta.content)

finish_reason = resp.choices[0].finish_reason if hasattr(resp.choices[0], "finish_reason") else ""
if finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN

yield ans, tol

async def async_chat(self, system, history, gen_conf, **kwargs):
hist = list(history) if history else []
if system:
@ -1795,22 +1392,7 @@ class LiteLLMBase(ABC):

def _should_retry(self, error_code: str) -> bool:
return error_code in self._retryable_errors

def _exceptions(self, e, attempt) -> str | None:
async def _exceptions_async(self, e, attempt):
logging.exception("OpenAI chat_with_tools")
# Classify the error
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES

if self._should_retry(error_code):
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
return None

return f"{ERROR_PREFIX}: {error_code} - {str(e)}"

async def _exceptions_async(self, e, attempt) -> str | None:
logging.exception("LiteLLMBase async completion")
error_code = self._classify_error(e)
if attempt == self.max_retries:
@ -1859,71 +1441,7 @@ class LiteLLMBase(ABC):

self.toolcall_session = toolcall_session
self.tools = tools

def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
completion_args = {
"model": self.model_name,
"messages": history,
"api_key": self.api_key,
"num_retries": self.max_retries,
**kwargs,
}
if stream:
completion_args.update(
{
"stream": stream,
}
)
if tools and self.tools:
completion_args.update(
{
"tools": self.tools,
"tool_choice": "auto",
}
)
if self.provider in FACTORY_DEFAULT_BASE_URL:
completion_args.update({"api_base": self.base_url})
elif self.provider == SupportedLiteLLMProvider.Bedrock:
completion_args.pop("api_key", None)
completion_args.pop("api_base", None)
completion_args.update(
{
"aws_access_key_id": self.bedrock_ak,
"aws_secret_access_key": self.bedrock_sk,
"aws_region_name": self.bedrock_region,
}
)

if self.provider == SupportedLiteLLMProvider.OpenRouter:
if self.provider_order:

def _to_order_list(x):
if x is None:
return []
if isinstance(x, str):
return [s.strip() for s in x.split(",") if s.strip()]
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []

extra_body = {}
provider_cfg = {}
provider_order = _to_order_list(self.provider_order)
provider_cfg["order"] = provider_order
provider_cfg["allow_fallbacks"] = False
extra_body["provider"] = provider_cfg
completion_args.update({"extra_body": extra_body})

# Ollama deployments commonly sit behind a reverse proxy that enforces
# Bearer auth. Ensure the Authorization header is set when an API key
# is provided, while respecting any user-supplied headers. #11350
extra_headers = deepcopy(completion_args.get("extra_headers") or {})
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
extra_headers["Authorization"] = f"Bearer {self.api_key}"
if extra_headers:
completion_args["extra_headers"] = extra_headers
return completion_args

def chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
@ -1931,16 +1449,14 @@ class LiteLLMBase(ABC):

ans = ""
tk_count = 0
hist = deepcopy(history)

# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = deepcopy(hist) # deepcopy is required here
history = deepcopy(hist)
try:
for _ in range(self.max_rounds + 1):
logging.info(f"{self.tools=}")

completion_args = self._construct_completion_args(history=history, stream=False, tools=True, **gen_conf)
response = litellm.completion(
response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
@ -1966,7 +1482,7 @@ class LiteLLMBase(ABC):

name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session.tool_call(name, args)
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
@ -1977,49 +1493,19 @@ class LiteLLMBase(ABC):

logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})

response, token_count = self._chat(history, gen_conf)
response, token_count = await self.async_chat("", history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count

except Exception as e:
e = self._exceptions(e, attempt)
e = await self._exceptions_async(e, attempt)
if e:
return e, tk_count

assert False, "Shouldn't be here."

def chat(self, system, history, gen_conf={}, **kwargs):
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)

# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
try:
response = self._chat(history, gen_conf, **kwargs)
return response
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, 0
assert False, "Shouldn't be here."

def _wrap_toolcall_message(self, stream):
final_tool_calls = {}

for chunk in stream:
for tool_call in chunk.choices[0].delta.tool_calls or []:
index = tool_call.index

if index not in final_tool_calls:
final_tool_calls[index] = tool_call

final_tool_calls[index].function.arguments += tool_call.function.arguments

return final_tool_calls

def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
@ -2028,16 +1514,15 @@ class LiteLLMBase(ABC):

total_tokens = 0
hist = deepcopy(history)

# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
history = deepcopy(hist) # deepcopy is required here
history = deepcopy(hist)
try:
for _ in range(self.max_rounds + 1):
reasoning_start = False
logging.info(f"{tools=}")

completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = litellm.completion(
response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
@ -2046,7 +1531,7 @@ class LiteLLMBase(ABC):

final_tool_calls = {}
answer = ""

for resp in response:
async for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue

@ -2082,7 +1567,7 @@ class LiteLLMBase(ABC):

if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens += tol
total_tokens = tol

finish_reason = getattr(resp.choices[0], "finish_reason", "")
if finish_reason == "length":
@ -2097,31 +1582,25 @@ class LiteLLMBase(ABC):

try:
args = json_repair.loads(tool_call.function.arguments)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = self.toolcall_session.tool_call(name, args)
tool_response = await asyncio.to_thread(self.toolcall_session.tool_call, name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM tool call response: {tool_call}")
history.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": f"Tool call error: \n{tool_call}\nException:\n{str(e)}",
}
)
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))

logging.warning(f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})

completion_args = self._construct_completion_args(history=history, stream=True, tools=True, **gen_conf)
response = litellm.completion(
response = await litellm.acompletion(
**completion_args,
drop_params=True,
timeout=self.timeout,
)

for resp in response:
async for resp in response:
if not hasattr(resp, "choices") or not resp.choices:
continue
delta = resp.choices[0].delta
@ -2131,14 +1610,14 @@ class LiteLLMBase(ABC):

if not tol:
total_tokens += num_tokens_from_string(delta.content)
else:
total_tokens += tol
total_tokens = tol
yield delta.content

yield total_tokens
return

except Exception as e:
e = self._exceptions(e, attempt)
e = await self._exceptions_async(e, attempt)
if e:
yield e
yield total_tokens

@ -2146,53 +1625,71 @@ class LiteLLMBase(ABC):

assert False, "Shouldn't be here."

def chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
try:
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
yield delta_ans
total_tokens += tol
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)

yield total_tokens

def _calculate_dynamic_ctx(self, history):
"""Calculate dynamic context window size"""

def count_tokens(text):
"""Calculate token count for text"""
# Simple calculation: 1 token per ASCII character
# 2 tokens for non-ASCII characters (Chinese, Japanese, Korean, etc.)
total = 0
for char in text:
if ord(char) < 128: # ASCII characters
total += 1
else: # Non-ASCII characters (Chinese, Japanese, Korean, etc.)
total += 2
return total

# Calculate total tokens for all messages
total_tokens = 0
for message in history:
content = message.get("content", "")
# Calculate content tokens
content_tokens = count_tokens(content)
# Add role marker token overhead
role_tokens = 4
total_tokens += content_tokens + role_tokens

# Apply 1.2x buffer ratio
total_tokens_with_buffer = int(total_tokens * 1.2)

if total_tokens_with_buffer <= 8192:
ctx_size = 8192
else:

def _construct_completion_args(self, history, stream: bool, tools: bool, **kwargs):
completion_args = {
"model": self.model_name,
"messages": history,
"api_key": self.api_key,
"num_retries": self.max_retries,
**kwargs,
}
if stream:
completion_args.update(
{
"stream": stream,
}
)
if tools and self.tools:
completion_args.update(
{
"tools": self.tools,
"tool_choice": "auto",
}
)
if self.provider in FACTORY_DEFAULT_BASE_URL:
completion_args.update({"api_base": self.base_url})
elif self.provider == SupportedLiteLLMProvider.Bedrock:
completion_args.pop("api_key", None)
completion_args.pop("api_base", None)
completion_args.update(
{
"aws_access_key_id": self.bedrock_ak,
"aws_secret_access_key": self.bedrock_sk,
"aws_region_name": self.bedrock_region,
}
)
elif self.provider == SupportedLiteLLMProvider.OpenRouter:
if self.provider_order:

def _to_order_list(x):
if x is None:
return []
if isinstance(x, str):
return [s.strip() for s in x.split(",") if s.strip()]
if isinstance(x, (list, tuple)):
return [str(s).strip() for s in x if str(s).strip()]
return []

extra_body = {}
provider_cfg = {}
provider_order = _to_order_list(self.provider_order)
provider_cfg["order"] = provider_order
provider_cfg["allow_fallbacks"] = False
extra_body["provider"] = provider_cfg
completion_args.update({"extra_body": extra_body})
elif self.provider == SupportedLiteLLMProvider.GPUStack:
completion_args.update(
{
"api_base": self.base_url,
}
)

# Ollama deployments commonly sit behind a reverse proxy that enforces
# Bearer auth. Ensure the Authorization header is set when an API key
# is provided, while respecting any user-supplied headers. #11350
extra_headers = deepcopy(completion_args.get("extra_headers") or {})
if self.provider == SupportedLiteLLMProvider.Ollama and self.api_key and "Authorization" not in extra_headers:
extra_headers["Authorization"] = f"Bearer {self.api_key}"
if extra_headers:
completion_args["extra_headers"] = extra_headers
return completion_args
|
|
||||||
ctx_multiplier = (total_tokens_with_buffer // 8192) + 1
|
|
||||||
ctx_size = ctx_multiplier * 8192
|
|
||||||
|
|
||||||
return ctx_size
|
|
||||||
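
In the OpenRouter branch, `_to_order_list` normalizes whatever shape `provider_order` arrives in. A standalone sketch of the same normalization (input values invented for illustration):

    def to_order_list(x):
        # Mirrors the helper in the patch: accept None, a comma-separated
        # string, or a list/tuple, and return a clean list of provider names.
        if x is None:
            return []
        if isinstance(x, str):
            return [s.strip() for s in x.split(",") if s.strip()]
        if isinstance(x, (list, tuple)):
            return [str(s).strip() for s in x if str(s).strip()]
        return []

    assert to_order_list("openai, anthropic,") == ["openai", "anthropic"]
    assert to_order_list(["groq", " deepseek "]) == ["groq", "deepseek"]
    assert to_order_list(None) == []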

@@ -592,7 +592,8 @@ async def run_dataflow(task: dict):
         ck["docnm_kwd"] = task["name"]
         ck["create_time"] = str(datetime.now()).replace("T", " ")[:19]
         ck["create_timestamp_flt"] = datetime.now().timestamp()
-        ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
+        if not ck.get("id"):
+            ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
         if "questions" in ck:
             if "question_tks" not in ck:
                 ck["question_kwd"] = ck["questions"].split("\n")
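
The new guard keeps any chunk id assigned upstream; only chunks without one get the deterministic content hash. A small sketch of that idempotent assignment (xxhash is the library the patch itself uses; the sample data is invented):

    import xxhash

    def ensure_chunk_id(ck: dict) -> dict:
        # Same derivation as the patch: hash of chunk text plus document id,
        # applied only when no id was set upstream.
        if not ck.get("id"):
            ck["id"] = xxhash.xxh64((ck["text"] + str(ck["doc_id"])).encode("utf-8")).hexdigest()
        return ck

    ck = {"text": "hello", "doc_id": 42}
    first = ensure_chunk_id(dict(ck))["id"]
    second = ensure_chunk_id(dict(ck))["id"]
    assert first == second  # deterministic across runs
    assert ensure_chunk_id({"text": "x", "doc_id": 1, "id": "kept"})["id"] == "kept"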

@@ -122,15 +122,15 @@ async def create_container(name: str, language: SupportLanguage) -> bool:
     logger.info(f"Sandbox config:\n\t {create_args}")

     try:
-        returncode, _, stderr = await async_run_command(*create_args, timeout=10)
-        if returncode != 0:
+        return_code, _, stderr = await async_run_command(*create_args, timeout=10)
+        if return_code != 0:
             logger.error(f"❌ Container creation failed {name}: {stderr}")
             return False

         if language == SupportLanguage.NODEJS:
             copy_cmd = ["docker", "exec", name, "bash", "-c", "cp -a /app/node_modules /workspace/"]
-            returncode, _, stderr = await async_run_command(*copy_cmd, timeout=10)
-            if returncode != 0:
+            return_code, _, stderr = await async_run_command(*copy_cmd, timeout=10)
+            if return_code != 0:
                 logger.error(f"❌ Failed to prepare dependencies for {name}: {stderr}")
                 return False

@@ -185,7 +185,7 @@ async def allocate_container_blocking(language: SupportLanguage, timeout=10) ->
 async def container_is_running(name: str) -> bool:
     """Asynchronously check the container status"""
     try:
-        returncode, stdout, _ = await async_run_command("docker", "inspect", "-f", "{{.State.Running}}", name, timeout=2)
-        return returncode == 0 and stdout.strip() == "true"
+        return_code, stdout, _ = await async_run_command("docker", "inspect", "-f", "{{.State.Running}}", name, timeout=2)
+        return return_code == 0 and stdout.strip() == "true"
     except Exception:
         return False
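
`async_run_command` is called throughout with the shape `(return_code, stdout, stderr)`. Its implementation is not shown in this diff; a minimal sketch of the presumed shape, assuming it wraps asyncio's subprocess API:

    import asyncio

    async def async_run_command(*args, timeout: float = 10):
        # Presumed helper shape (not part of the diff): run a command and
        # return (return_code, stdout, stderr) as the call sites expect.
        proc = await asyncio.create_subprocess_exec(
            *args,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        try:
            out, err = await asyncio.wait_for(proc.communicate(), timeout=timeout)
        except asyncio.TimeoutError:
            proc.kill()
            await proc.wait()
            return -1, "", "command timed out"
        return proc.returncode, out.decode(), err.decode()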

web/src/assets/svg/home-icon/memory.svg (new file, 16 lines)
@@ -0,0 +1,16 @@
+<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
+<g id="memory dark">
+<g id="lucide/hard-drive">
+<path id="Vector" d="M22 12H2M22 12V18C22 18.5304 21.7893 19.0391 21.4142 19.4142C21.0391 19.7893 20.5304 20 20 20H4C3.46957 20 2.96086 19.7893 2.58579 19.4142C2.21071 19.0391 2 18.5304 2 18V12M22 12L18.55 5.11C18.3844 4.77679 18.1292 4.49637 17.813 4.30028C17.4967 4.10419 17.1321 4.0002 16.76 4H7.24C6.86792 4.0002 6.50326 4.10419 6.18704 4.30028C5.87083 4.49637 5.61558 4.77679 5.45 5.11L2 12" stroke="url(#paint0_linear_1100_4836)" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
+</g>
+<g id="lucide/hard-drive_2">
+<path id="Vector_2" d="M6 16H6.01M10 16H10.01" stroke="#00BEB4" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
+</g>
+</g>
+<defs>
+<linearGradient id="paint0_linear_1100_4836" x1="12.5556" y1="4" x2="12.5556" y2="20" gradientUnits="userSpaceOnUse">
+<stop stop-color="white"/>
+<stop offset="1" stop-color="#666666"/>
+</linearGradient>
+</defs>
+</svg>
@@ -1,5 +1,5 @@
 import CopyToClipboard from '@/components/copy-to-clipboard';
-import HightLightMarkdown from '@/components/highlight-markdown';
+import HighLightMarkdown from '@/components/highlight-markdown';
 import { SharedFrom } from '@/constants/chat';
 import { useTranslate } from '@/hooks/common-hooks';
 import { IModalProps } from '@/interfaces/common';
@@ -111,7 +111,7 @@ const EmbedModal = ({
           />
         </Form.Item>
       </div>
-      <HightLightMarkdown>{text}</HightLightMarkdown>
+      <HighLightMarkdown>{text}</HighLightMarkdown>
     </Card>
   ),
 },

@@ -1,6 +1,13 @@
 import { zodResolver } from '@hookform/resolvers/zod';
-import { forwardRef, useEffect, useImperativeHandle, useMemo } from 'react';
+import {
+  forwardRef,
+  useEffect,
+  useImperativeHandle,
+  useMemo,
+  useState,
+} from 'react';
 import {
+  ControllerRenderProps,
   DefaultValues,
   FieldValues,
   SubmitHandler,
@@ -26,6 +33,7 @@ import { Textarea } from '@/components/ui/textarea';
 import { cn } from '@/lib/utils';
 import { t } from 'i18next';
 import { Loader } from 'lucide-react';
+import { MultiSelect, MultiSelectOptionType } from './ui/multi-select';

 // Field type enumeration
 export enum FormFieldType {
@@ -35,14 +43,17 @@ export enum FormFieldType {
   Number = 'number',
   Textarea = 'textarea',
   Select = 'select',
+  MultiSelect = 'multi-select',
   Checkbox = 'checkbox',
   Tag = 'tag',
+  Custom = 'custom',
 }

 // Field configuration interface
 export interface FormFieldConfig {
   name: string;
   label: string;
+  hideLabel?: boolean;
   type: FormFieldType;
   hidden?: boolean;
   required?: boolean;
@@ -57,7 +68,7 @@ export interface FormFieldConfig {
     max?: number;
     message?: string;
   };
-  render?: (fieldProps: any) => React.ReactNode;
+  render?: (fieldProps: ControllerRenderProps) => React.ReactNode;
   horizontal?: boolean;
   onChange?: (value: any) => void;
   tooltip?: React.ReactNode;
@@ -78,10 +89,10 @@ interface DynamicFormProps<T extends FieldValues> {
   className?: string;
   children?: React.ReactNode;
   defaultValues?: DefaultValues<T>;
-  onFieldUpdate?: (
-    fieldName: string,
-    updatedField: Partial<FormFieldConfig>,
-  ) => void;
+  // onFieldUpdate?: (
+  //   fieldName: string,
+  //   updatedField: Partial<FormFieldConfig>,
+  // ) => void;
   labelClassName?: string;
 }

@@ -92,6 +103,10 @@ export interface DynamicFormRef {
   reset: (values?: any) => void;
   watch: (field: string, callback: (value: any) => void) => () => void;
   updateFieldType: (fieldName: string, newType: FormFieldType) => void;
+  onFieldUpdate: (
+    fieldName: string,
+    newFieldProperties: Partial<FormFieldConfig>,
+  ) => void;
 }

 // Generate Zod validation schema based on field configurations
@@ -110,6 +125,14 @@ const generateSchema = (fields: FormFieldConfig[]): ZodSchema<any> => {
       case FormFieldType.Email:
         fieldSchema = z.string().email('Please enter a valid email address');
         break;
+      case FormFieldType.MultiSelect:
+        fieldSchema = z.array(z.string()).optional();
+        if (field.required) {
+          fieldSchema = z.array(z.string()).min(1, {
+            message: `${field.label} is required`,
+          });
+        }
+        break;
       case FormFieldType.Number:
         fieldSchema = z.coerce.number();
         if (field.validation?.min !== undefined) {
@@ -275,7 +298,10 @@ const generateDefaultValues = <T extends FieldValues>(
       defaultValues[field.name] = field.defaultValue;
     } else if (field.type === FormFieldType.Checkbox) {
       defaultValues[field.name] = false;
-    } else if (field.type === FormFieldType.Tag) {
+    } else if (
+      field.type === FormFieldType.Tag ||
+      field.type === FormFieldType.MultiSelect
+    ) {
       defaultValues[field.name] = [];
     } else {
       defaultValues[field.name] = '';
@@ -291,17 +317,21 @@ const DynamicForm = {
   Root: forwardRef(
     <T extends FieldValues>(
       {
-        fields,
+        fields: originFields,
         onSubmit,
         className = '',
         children,
         defaultValues: formDefaultValues = {} as DefaultValues<T>,
-        onFieldUpdate,
+        // onFieldUpdate,
         labelClassName,
       }: DynamicFormProps<T>,
       ref: React.Ref<any>,
     ) => {
       // Generate validation schema and default values
+      const [fields, setFields] = useState(originFields);
+      useMemo(() => {
+        setFields(originFields);
+      }, [originFields]);
       const schema = useMemo(() => generateSchema(fields), [fields]);

       const defaultValues = useMemo(() => {
@@ -406,43 +436,54 @@ const DynamicForm = {
       }, [fields, form]);

       // Expose form methods via ref
-      useImperativeHandle(ref, () => ({
-        submit: () => form.handleSubmit(onSubmit)(),
-        getValues: () => form.getValues(),
-        reset: (values?: T) => {
-          if (values) {
-            form.reset(values);
-          } else {
-            form.reset();
-          }
-        },
-        setError: form.setError,
-        clearErrors: form.clearErrors,
-        trigger: form.trigger,
-        watch: (field: string, callback: (value: any) => void) => {
-          const { unsubscribe } = form.watch((values: any) => {
-            if (values && values[field] !== undefined) {
-              callback(values[field]);
-            }
-          });
-          return unsubscribe;
-        },
-
-        onFieldUpdate: (
-          fieldName: string,
-          updatedField: Partial<FormFieldConfig>,
-        ) => {
-          setTimeout(() => {
-            if (onFieldUpdate) {
-              onFieldUpdate(fieldName, updatedField);
-            } else {
-              console.warn(
-                'onFieldUpdate prop is not provided. Cannot update field type.',
-              );
-            }
-          }, 0);
-        },
-      }));
+      useImperativeHandle(
+        ref,
+        () => ({
+          submit: () => form.handleSubmit(onSubmit)(),
+          getValues: () => form.getValues(),
+          reset: (values?: T) => {
+            if (values) {
+              form.reset(values);
+            } else {
+              form.reset();
+            }
+          },
+          setError: form.setError,
+          clearErrors: form.clearErrors,
+          trigger: form.trigger,
+          watch: (field: string, callback: (value: any) => void) => {
+            const { unsubscribe } = form.watch((values: any) => {
+              if (values && values[field] !== undefined) {
+                callback(values[field]);
+              }
+            });
+            return unsubscribe;
+          },
+
+          onFieldUpdate: (
+            fieldName: string,
+            updatedField: Partial<FormFieldConfig>,
+          ) => {
+            setFields((prevFields: any) =>
+              prevFields.map((field: any) =>
+                field.name === fieldName
+                  ? { ...field, ...updatedField }
+                  : field,
+              ),
+            );
+            // setTimeout(() => {
+            //   if (onFieldUpdate) {
+            //     onFieldUpdate(fieldName, updatedField);
+            //   } else {
+            //     console.warn(
+            //       'onFieldUpdate prop is not provided. Cannot update field type.',
+            //     );
+            //   }
+            // }, 0);
+          },
+        }),
+        [form],
+      );

       useEffect(() => {
         if (formDefaultValues && Object.keys(formDefaultValues).length > 0) {
@@ -459,6 +500,9 @@ const DynamicForm = {
       // Render form fields
       const renderField = (field: FormFieldConfig) => {
         if (field.render) {
+          if (field.type === FormFieldType.Custom && field.hideLabel) {
+            return <div className="w-full">{field.render({})}</div>;
+          }
           return (
             <RAGFlowFormItem
               name={field.name}
@@ -549,6 +593,43 @@ const DynamicForm = {
           </RAGFlowFormItem>
         );

+      case FormFieldType.MultiSelect:
+        return (
+          <RAGFlowFormItem
+            name={field.name}
+            label={field.label}
+            required={field.required}
+            horizontal={field.horizontal}
+            tooltip={field.tooltip}
+            labelClassName={labelClassName || field.labelClassName}
+          >
+            {(fieldProps) => {
+              console.log('multi select value', fieldProps);
+              const finalFieldProps = {
+                ...fieldProps,
+                onValueChange: (value: string[]) => {
+                  if (fieldProps.onChange) {
+                    fieldProps.onChange(value);
+                  }
+                  field.onChange?.(value);
+                },
+              };
+              return (
+                <MultiSelect
+                  variant="inverted"
+                  maxCount={100}
+                  {...finalFieldProps}
+                  // onValueChange={(data) => {
+                  //   console.log(data);
+                  //   field.onChange?.(data);
+                  // }}
+                  options={field.options as MultiSelectOptionType[]}
+                />
+              );
+            }}
+          </RAGFlowFormItem>
+        );
+
       case FormFieldType.Checkbox:
         return (
           <FormField

@@ -1,5 +1,5 @@
 import CopyToClipboard from '@/components/copy-to-clipboard';
-import HightLightMarkdown from '@/components/highlight-markdown';
+import HighLightMarkdown from '@/components/highlight-markdown';
 import { SelectWithSearch } from '@/components/originui/select-with-search';
 import {
   Dialog,
@@ -277,7 +277,7 @@ function EmbedDialog({
         <div className="max-h-[350px] overflow-auto">
           <span>{t('embedCode', { keyPrefix: 'search' })}</span>
           <div className="max-h-full overflow-y-auto">
-            <HightLightMarkdown>{text}</HightLightMarkdown>
+            <HighLightMarkdown>{text}</HighLightMarkdown>
           </div>
         </div>
         <div className=" font-medium mt-4 mb-1">

@@ -11,23 +11,33 @@ export enum EmptyCardType {
   Dataset = 'dataset',
   Chat = 'chat',
   Search = 'search',
+  Memory = 'memory',
 }

 export const EmptyCardData = {
   [EmptyCardType.Agent]: {
     icon: <HomeIcon name="agents" width={'24'} />,
     title: t('empty.agentTitle'),
+    notFound: t('empty.notFoundAgent'),
   },
   [EmptyCardType.Dataset]: {
     icon: <HomeIcon name="datasets" width={'24'} />,
     title: t('empty.datasetTitle'),
+    notFound: t('empty.notFoundDataset'),
   },
   [EmptyCardType.Chat]: {
     icon: <HomeIcon name="chats" width={'24'} />,
     title: t('empty.chatTitle'),
+    notFound: t('empty.notFoundChat'),
   },
   [EmptyCardType.Search]: {
     icon: <HomeIcon name="searches" width={'24'} />,
     title: t('empty.searchTitle'),
+    notFound: t('empty.notFoundSearch'),
+  },
+  [EmptyCardType.Memory]: {
+    icon: <HomeIcon name="memory" width={'24'} />,
+    title: t('empty.memoryTitle'),
+    notFound: t('empty.notFoundMemory'),
   },
 };

@@ -76,9 +76,10 @@ export const EmptyAppCard = (props: {
   onClick?: () => void;
   showIcon?: boolean;
   className?: string;
+  isSearch?: boolean;
   size?: 'small' | 'large';
 }) => {
-  const { type, showIcon, className } = props;
+  const { type, showIcon, className, isSearch } = props;
   let defaultClass = '';
   let style = {};
   switch (props.size) {
@@ -95,19 +96,29 @@ export const EmptyAppCard = (props: {
       break;
   }
   return (
-    <div className=" cursor-pointer " onClick={props.onClick}>
+    <div
+      className=" cursor-pointer "
+      onClick={isSearch ? undefined : props.onClick}
+    >
       <EmptyCard
         icon={showIcon ? EmptyCardData[type].icon : undefined}
-        title={EmptyCardData[type].title}
+        title={
+          isSearch ? EmptyCardData[type].notFound : EmptyCardData[type].title
+        }
         className={className}
         style={style}
         // description={EmptyCardData[type].description}
       >
-        <div
-          className={cn(defaultClass, 'flex items-center justify-start w-full')}
-        >
-          <Plus size={24} />
-        </div>
+        {!isSearch && (
+          <div
+            className={cn(
+              defaultClass,
+              'flex items-center justify-start w-full',
+            )}
+          >
+            <Plus size={24} />
+          </div>
+        )}
       </EmptyCard>
     </div>
   );

@@ -16,7 +16,7 @@ import { preprocessLaTeX } from '@/utils/chat';
 import { useIsDarkTheme } from '../theme-provider';
 import styles from './index.less';

-const HightLightMarkdown = ({
+const HighLightMarkdown = ({
   children,
 }: {
   children: string | null | undefined;
@@ -56,4 +56,4 @@ const HightLightMarkdown = ({
   );
 };

-export default HightLightMarkdown;
+export default HighLightMarkdown;

@@ -3,10 +3,7 @@
 // Inspired by react-hot-toast library
 import * as React from 'react';

-import type {
-  ToastActionElement,
-  ToastProps,
-} from '@/registry/default/ui/toast';
+import type { ToastActionElement, ToastProps } from '@/components/ui/toast';

 const TOAST_LIMIT = 1;
 const TOAST_REMOVE_DELAY = 1000000;

@@ -244,7 +244,7 @@ export interface JsonEditorOptions {
   timestampFormat?: string;

   /**
-   * If true, unicode characters are escaped. false by default.
+   * If true, Unicode characters are escaped. false by default.
   */
   escapeUnicode?: boolean;


@@ -9,13 +9,19 @@ export type LLMFormFieldProps = {
   name?: string;
 };

-export function LLMFormField({ options, name }: LLMFormFieldProps) {
-  const { t } = useTranslation();
-
+export const useModelOptions = () => {
   const modelOptions = useComposeLlmOptionsByModelTypes([
     LlmModelType.Chat,
     LlmModelType.Image2text,
   ]);
+  return {
+    modelOptions,
+  };
+};
+
+export function LLMFormField({ options, name }: LLMFormFieldProps) {
+  const { t } = useTranslation();
+  const { modelOptions } = useModelOptions();

   return (
     <RAGFlowFormItem name={name || 'llm_id'} label={t('chat.model')}>

@@ -1,6 +1,6 @@
 import { IModalProps } from '@/interfaces/common';
 import { IFeedbackRequestBody } from '@/interfaces/request/chat';
-import HightLightMarkdown from './highlight-markdown';
+import HighLightMarkdown from './highlight-markdown';
 import SvgIcon from './svg-icon';
 import { Dialog, DialogContent, DialogHeader, DialogTitle } from './ui/dialog';

@@ -25,7 +25,7 @@ export function PromptDialog({
         </DialogTitle>
       </DialogHeader>
       <section className="max-h-[80vh] overflow-auto">
-        <HightLightMarkdown>{prompt}</HightLightMarkdown>
+        <HighLightMarkdown>{prompt}</HighLightMarkdown>
       </section>
     </DialogContent>
   </Dialog>

@@ -53,14 +53,16 @@ export function RAGFlowFormItem({
           {label}
         </FormLabel>
       )}
-      <FormControl>
-        {typeof children === 'function'
-          ? children(field)
-          : isValidElement(children)
-            ? cloneElement(children, { ...field })
-            : children}
-      </FormControl>
-      <FormMessage />
+      <div className="w-full flex flex-col">
+        <FormControl>
+          {typeof children === 'function'
+            ? children(field)
+            : isValidElement(children)
+              ? cloneElement(children, { ...field })
+              : children}
+        </FormControl>
+        <FormMessage />
+      </div>
     </FormItem>
   )}
 />

@@ -126,3 +126,53 @@ export const IconMap = {
   [LLMFactory.JiekouAI]: 'jiekouai',
   [LLMFactory.Builtin]: 'builtin',
 };
+
+export const APIMapUrl = {
+  [LLMFactory.OpenAI]: 'https://platform.openai.com/api-keys',
+  [LLMFactory.Anthropic]: 'https://console.anthropic.com/settings/keys',
+  [LLMFactory.Gemini]: 'https://aistudio.google.com/app/apikey',
+  [LLMFactory.DeepSeek]: 'https://platform.deepseek.com/api_keys',
+  [LLMFactory.Moonshot]: 'https://platform.moonshot.cn/console/api-keys',
+  [LLMFactory.TongYiQianWen]: 'https://dashscope.console.aliyun.com/apiKey',
+  [LLMFactory.ZhipuAI]: 'https://open.bigmodel.cn/usercenter/apikeys',
+  [LLMFactory.XAI]: 'https://x.ai/api/',
+  [LLMFactory.HuggingFace]: 'https://huggingface.co/settings/tokens',
+  [LLMFactory.Mistral]: 'https://console.mistral.ai/api-keys/',
+  [LLMFactory.Cohere]: 'https://dashboard.cohere.com/api-keys',
+  [LLMFactory.BaiduYiYan]: 'https://wenxin.baidu.com/user/key',
+  [LLMFactory.Meituan]: 'https://longcat.chat/platform/api_keys',
+  [LLMFactory.Bedrock]:
+    'https://us-east-2.console.aws.amazon.com/bedrock/home#/api-keys',
+  [LLMFactory.AzureOpenAI]:
+    'https://portal.azure.com/#create/Microsoft.CognitiveServicesOpenAI',
+  [LLMFactory.OpenRouter]: 'https://openrouter.ai/keys',
+  [LLMFactory.XunFeiSpark]: 'https://console.xfyun.cn/services/cbm',
+  [LLMFactory.MiniMax]:
+    'https://platform.minimaxi.com/user-center/basic-information',
+  [LLMFactory.Groq]: 'https://console.groq.com/keys',
+  [LLMFactory.NVIDIA]: 'https://build.nvidia.com/settings/api-keys',
+  [LLMFactory.SILICONFLOW]: 'https://cloud.siliconflow.cn/account/ak',
+  [LLMFactory.Replicate]: 'https://replicate.com/account/api-tokens',
+  [LLMFactory.VolcEngine]: 'https://console.volcengine.com/ark',
+  [LLMFactory.Jina]: 'https://jina.ai/embeddings/',
+  [LLMFactory.TencentHunYuan]:
+    'https://console.cloud.tencent.com/hunyuan/api-key',
+  [LLMFactory.TencentCloud]: 'https://console.cloud.tencent.com/cam/capi',
+  [LLMFactory.ModelScope]: 'https://modelscope.cn/my/myaccesstoken',
+  [LLMFactory.GoogleCloud]: 'https://console.cloud.google.com/apis/credentials',
+  [LLMFactory.FishAudio]: 'https://fish.audio/app/api-keys/',
+  [LLMFactory.GiteeAI]:
+    'https://ai.gitee.com/hhxzgrjn/dashboard/settings/tokens',
+  [LLMFactory.StepFun]: 'https://platform.stepfun.com/interface-key',
+  [LLMFactory.BaiChuan]: 'https://platform.baichuan-ai.com/console/apikey',
+  [LLMFactory.PPIO]: 'https://ppio.com/settings/key-management',
+  [LLMFactory.VoyageAI]: 'https://dash.voyageai.com/api-keys',
+  [LLMFactory.TogetherAI]: 'https://api.together.xyz/settings/api-keys',
+  [LLMFactory.NovitaAI]: 'https://novita.ai/dashboard/key',
+  [LLMFactory.Upstage]: 'https://console.upstage.ai/api-keys',
+  [LLMFactory.CometAPI]: 'https://api.cometapi.com/console/token',
+  [LLMFactory.Ai302]: 'https://302.ai/apis/list',
+  [LLMFactory.DeerAPI]: 'https://api.deerapi.com/token',
+  [LLMFactory.TokenPony]: 'https://www.tokenpony.cn/#/user/keys',
+  [LLMFactory.DeepInfra]: 'https://deepinfra.com/dash/api_keys',
+};

@@ -1,6 +1,7 @@
 import { Authorization } from '@/constants/authorization';
 import { MessageType } from '@/constants/chat';
 import { LanguageTranslationMap } from '@/constants/common';
+import { Pagination } from '@/interfaces/common';
 import { ResponseType } from '@/interfaces/database/base';
 import {
   IAnswer,
@@ -12,7 +13,7 @@ import { IKnowledgeFile } from '@/interfaces/database/knowledge';
 import api from '@/utils/api';
 import { getAuthorization } from '@/utils/authorization-util';
 import { buildMessageUuid } from '@/utils/chat';
-import { PaginationProps, message } from 'antd';
+import { message } from 'antd';
 import { FormInstance } from 'antd/lib';
 import axios from 'axios';
 import { EventSourceParserStream } from 'eventsource-parser/stream';
@@ -71,8 +72,8 @@ export const useGetPaginationWithRouter = () => {
     size: pageSize,
   } = useSetPaginationParams();

-  const onPageChange: PaginationProps['onChange'] = useCallback(
-    (pageNumber: number, pageSize: number) => {
+  const onPageChange: Pagination['onChange'] = useCallback(
+    (pageNumber: number, pageSize?: number) => {
       setPaginationParams(pageNumber, pageSize);
     },
     [setPaginationParams],
@@ -88,7 +89,7 @@ export const useGetPaginationWithRouter = () => {
     [setPaginationParams, pageSize],
   );

-  const pagination: PaginationProps = useMemo(() => {
+  const pagination: Pagination = useMemo(() => {
     return {
       showQuickJumper: true,
       total: 0,
@@ -97,7 +98,7 @@ export const useGetPaginationWithRouter = () => {
       pageSize: pageSize,
       pageSizeOptions: [1, 2, 10, 20, 50, 100],
       onChange: onPageChange,
-      showTotal: (total) => `${t('total')} ${total}`,
+      showTotal: (total: number) => `${t('total')} ${total}`,
     };
   }, [t, onPageChange, page, pageSize]);

@@ -109,7 +110,7 @@ export const useGetPaginationWithRouter = () => {

 export const useHandleSearchChange = () => {
   const [searchString, setSearchString] = useState('');
-  const { setPagination } = useGetPaginationWithRouter();
+  const { pagination, setPagination } = useGetPaginationWithRouter();
   const handleInputChange = useCallback(
     (e: React.ChangeEvent<HTMLInputElement | HTMLTextAreaElement>) => {
       const value = e.target.value;
@@ -119,21 +120,21 @@ export const useHandleSearchChange = () => {
     [setPagination],
   );

-  return { handleInputChange, searchString };
+  return { handleInputChange, searchString, pagination, setPagination };
 };

 export const useGetPagination = () => {
   const [pagination, setPagination] = useState({ page: 1, pageSize: 10 });
   const { t } = useTranslate('common');

-  const onPageChange: PaginationProps['onChange'] = useCallback(
+  const onPageChange: Pagination['onChange'] = useCallback(
     (pageNumber: number, pageSize: number) => {
       setPagination({ page: pageNumber, pageSize });
     },
     [],
   );

-  const currentPagination: PaginationProps = useMemo(() => {
+  const currentPagination: Pagination = useMemo(() => {
     return {
       showQuickJumper: true,
       total: 0,
@@ -142,7 +143,7 @@ export const useGetPagination = () => {
       pageSize: pagination.pageSize,
       pageSizeOptions: [1, 2, 10, 20, 50, 100],
       onChange: onPageChange,
-      showTotal: (total) => `${t('total')} ${total}`,
+      showTotal: (total: number) => `${t('total')} ${total}`,
     };
   }, [t, onPageChange, pagination]);


@@ -25,6 +25,17 @@ export const useNavigatePage = () => {
     [navigate],
   );

+  const navigateToMemoryList = useCallback(
+    ({ isCreate = false }: { isCreate?: boolean }) => {
+      if (isCreate) {
+        navigate(Routes.Memories + '?isCreate=true');
+      } else {
+        navigate(Routes.Memories);
+      }
+    },
+    [navigate],
+  );
+
   const navigateToDataset = useCallback(
     (id: string) => () => {
       // navigate(`${Routes.DatasetBase}${Routes.DataSetOverview}/${id}`);
@@ -105,6 +116,12 @@ export const useNavigatePage = () => {
     },
     [navigate],
   );
+  const navigateToMemory = useCallback(
+    (id: string) => () => {
+      navigate(`${Routes.Memory}${Routes.MemoryMessage}/${id}`);
+    },
+    [navigate],
+  );

   const navigateToChunkParsedResult = useCallback(
     (id: string, knowledgeId?: string) => () => {
@@ -196,5 +213,7 @@ export const useNavigatePage = () => {
     navigateToDataflowResult,
     navigateToDataFile,
     navigateToDataSourceDetail,
+    navigateToMemory,
+    navigateToMemoryList,
   };
 };

@@ -2,6 +2,7 @@ export interface Pagination {
   current: number;
   pageSize: number;
   total: number;
+  onChange?: (page: number, pageSize: number) => void;
 }

 export interface BaseState {

@@ -99,6 +99,29 @@ export default {
     search: 'Search',
     welcome: 'Welcome to',
     dataset: 'Dataset',
+    Memories: 'Memory',
+  },
+  memory: {
+    memory: 'Memory',
+    createMemory: 'Create Memory',
+    name: 'Name',
+    memoryNamePlaceholder: 'memory name',
+    memoryType: 'Memory type',
+    embeddingModel: 'Embedding model',
+    selectModel: 'Select model',
+    llm: 'LLM',
+  },
+  memoryDetail: {
+    messages: {
+      sessionId: 'Session ID',
+      agent: 'Agent',
+      type: 'Type',
+      validDate: 'Valid date',
+      forgetAt: 'Forget at',
+      source: 'Source',
+      enable: 'Enable',
+      action: 'Action',
+    },
   },
   knowledgeList: {
     welcome: 'Welcome back',
@@ -2044,14 +2067,21 @@ Important structured information may include: names, dates, locations, events, k
     delFilesContent: 'Selected {{count}} files',
     delChat: 'Delete chat',
     delMember: 'Delete member',
+    delMemory: 'Delete memory',
   },

   empty: {
     noMCP: 'No MCP servers available',
     agentTitle: 'No agent app created yet',
+    notFoundAgent: 'Agent app not found',
     datasetTitle: 'No dataset created yet',
+    notFoundDataset: 'Dataset not found',
     chatTitle: 'No chat app created yet',
+    notFoundChat: 'Chat app not found',
     searchTitle: 'No search app created yet',
+    notFoundSearch: 'Search app not found',
+    memoryTitle: 'No memory created yet',
+    notFoundMemory: 'Memory not found',
     addNow: 'Add Now',
   },


@@ -1900,9 +1900,15 @@ Tokenizer 会根据所选方式将内容存储为对应的数据结构。`,
     empty: {
       noMCP: '暂无 MCP 服务器可用',
       agentTitle: '尚未创建智能体',
+      notFoundAgent: '未查询到智能体',
       datasetTitle: '尚未创建数据集',
+      notFoundDataset: '未查询到数据集',
       chatTitle: '尚未创建聊天应用',
+      notFoundChat: '未查询到聊天应用',
       searchTitle: '尚未创建搜索应用',
+      notFoundSearch: '未查询到搜索应用',
+      memoryTitle: '尚未创建记忆',
+      notFoundMemory: '未查询到记忆',
       addNow: '立即添加',
     },
   },

@@ -1,4 +1,4 @@
-import HightLightMarkdown from '@/components/highlight-markdown';
+import HighLightMarkdown from '@/components/highlight-markdown';
 import {
   Timeline,
   TimelineContent,
@@ -327,9 +327,9 @@ export const WorkFlowTimeline = ({
           <AccordionContent>
             <div className="space-y-2">
               <div className="w-full h-[200px] break-words overflow-auto scrollbar-auto p-2 bg-muted">
-                <HightLightMarkdown>
+                <HighLightMarkdown>
                   {x.data.thoughts || ''}
-                </HightLightMarkdown>
+                </HighLightMarkdown>
               </div>
             </div>
           </AccordionContent>

@@ -81,19 +81,20 @@ export default function Agents() {
   }, [isCreate, showCreatingModal, searchUrl, setSearchUrl]);
   return (
     <>
-      {(!data?.length || data?.length <= 0) && (
+      {(!data?.length || data?.length <= 0) && !searchString && (
         <div className="flex w-full items-center justify-center h-[calc(100vh-164px)]">
           <EmptyAppCard
             showIcon
             size="large"
             className="w-[480px] p-14"
+            isSearch={!!searchString}
             type={EmptyCardType.Agent}
             onClick={() => showCreatingModal()}
           />
         </div>
       )}
       <section className="flex flex-col w-full flex-1">
-        {!!data?.length && (
+        {(!!data?.length || searchString) && (
           <>
             <div className="px-8 pt-8 ">
               <ListFilterBar
@@ -138,6 +139,18 @@ export default function Agents() {
               </DropdownMenu>
             </ListFilterBar>
           </div>
+          {(!data?.length || data?.length <= 0) && searchString && (
+            <div className="flex w-full items-center justify-center h-[calc(100vh-164px)]">
+              <EmptyAppCard
+                showIcon
+                size="large"
+                className="w-[480px] p-14"
+                isSearch={!!searchString}
+                type={EmptyCardType.Agent}
+                onClick={() => showCreatingModal()}
+              />
+            </div>
+          )}
           <div className="flex-1 overflow-auto">
             <CardContainer className="max-h-[calc(100dvh-280px)] overflow-auto px-8">
               {data.map((x) => {

@@ -12,7 +12,7 @@ import { Switch } from '@/components/ui/switch';
 import { useTranslate } from '@/hooks/common-hooks';
 import { cn } from '@/lib/utils';
 import { useMemo, useState } from 'react';
-import { useFormContext } from 'react-hook-form';
+import { FieldValues, useFormContext } from 'react-hook-form';
 import {
   useHandleKbEmbedding,
   useHasParsedDocument,
@@ -65,17 +65,59 @@ export function ChunkMethodItem(props: IProps) {
     />
   );
 }
-export function EmbeddingModelItem({ line = 1, isEdit }: IProps) {
+
+export const EmbeddingSelect = ({
+  isEdit,
+  field,
+  name,
+}: {
+  isEdit: boolean;
+  field: FieldValues;
+  name?: string;
+}) => {
   const { t } = useTranslate('knowledgeConfiguration');
   const form = useFormContext();
   const embeddingModelOptions = useSelectEmbeddingModelOptions();
   const { handleChange } = useHandleKbEmbedding();
   const disabled = useHasParsedDocument(isEdit);
   const oldValue = useMemo(() => {
-    const embdStr = form.getValues('embd_id');
+    const embdStr = form.getValues(name || 'embd_id');
     return embdStr || '';
   }, [form]);
   const [loading, setLoading] = useState(false);
+  return (
+    <Spin
+      spinning={loading}
+      className={cn(' rounded-lg after:bg-bg-base', {
+        'opacity-20': loading,
+      })}
+    >
+      <SelectWithSearch
+        onChange={async (value) => {
+          field.onChange(value);
+          if (isEdit && disabled) {
+            setLoading(true);
+            const res = await handleChange({
+              embed_id: value,
+              callback: field.onChange,
+            });
+            if (res.code !== 0) {
+              field.onChange(oldValue);
+            }
+            setLoading(false);
+          }
+        }}
+        value={field.value}
+        options={embeddingModelOptions}
+        placeholder={t('embeddingModelPlaceholder')}
+      />
+    </Spin>
+  );
+};
+
+export function EmbeddingModelItem({ line = 1, isEdit }: IProps) {
+  const { t } = useTranslate('knowledgeConfiguration');
+  const form = useFormContext();
   return (
     <>
       <FormField
@@ -102,33 +144,10 @@ export function EmbeddingModelItem({ line = 1, isEdit }: IProps) {
         className={cn('text-muted-foreground', { 'w-3/4': line === 1 })}
       >
         <FormControl>
-          <Spin
-            spinning={loading}
-            className={cn(' rounded-lg after:bg-bg-base', {
-              'opacity-20': loading,
-            })}
-          >
-            <SelectWithSearch
-              onChange={async (value) => {
-                field.onChange(value);
-                if (isEdit && disabled) {
-                  setLoading(true);
-                  const res = await handleChange({
-                    embed_id: value,
-                    callback: field.onChange,
-                  });
-                  if (res.code !== 0) {
-                    field.onChange(oldValue);
-                  }
-                  setLoading(false);
-                }
-              }}
-              value={field.value}
-              options={embeddingModelOptions}
-              placeholder={t('embeddingModelPlaceholder')}
-              triggerClassName="!bg-bg-base"
-            />
-          </Spin>
+          <EmbeddingSelect
+            isEdit={!!isEdit}
+            field={field}
+          ></EmbeddingSelect>
         </FormControl>
       </div>
     </div>

@@ -70,18 +70,19 @@ export default function Datasets() {
   return (
     <>
       <section className="py-4 flex-1 flex flex-col">
-        {(!kbs?.length || kbs?.length <= 0) && (
+        {(!kbs?.length || kbs?.length <= 0) && !searchString && (
           <div className="flex w-full items-center justify-center h-[calc(100vh-164px)]">
             <EmptyAppCard
               showIcon
               size="large"
               className="w-[480px] p-14"
+              isSearch={!!searchString}
               type={EmptyCardType.Dataset}
               onClick={() => showModal()}
             />
           </div>
         )}
-        {!!kbs?.length && (
+        {(!!kbs?.length || searchString) && (
           <>
             <ListFilterBar
               title={t('header.dataset')}
@@ -98,6 +99,18 @@ export default function Datasets() {
               {t('knowledgeList.createKnowledgeBase')}
             </Button>
           </ListFilterBar>
+          {(!kbs?.length || kbs?.length <= 0) && searchString && (
+            <div className="flex w-full items-center justify-center h-[calc(100vh-164px)]">
+              <EmptyAppCard
+                showIcon
+                size="large"
+                className="w-[480px] p-14"
+                isSearch={!!searchString}
+                type={EmptyCardType.Dataset}
+                onClick={() => showModal()}
+              />
+            </div>
+          )}
           <div className="flex-1">
             <CardContainer className="max-h-[calc(100dvh-280px)] overflow-auto px-8">
               {kbs.map((dataset) => {

web/src/pages/memories/add-or-edit-modal.tsx (new file, 75 lines)
@@ -0,0 +1,75 @@
+import { DynamicForm, DynamicFormRef } from '@/components/dynamic-form';
+import { useModelOptions } from '@/components/llm-setting-items/llm-form-field';
+import { HomeIcon } from '@/components/svg-icon';
+import { Modal } from '@/components/ui/modal/modal';
+import { t } from 'i18next';
+import { useCallback, useEffect, useState } from 'react';
+import { createMemoryFields } from './constants';
+import { IMemory } from './interface';
+
+type IProps = {
+  open: boolean;
+  onClose: () => void;
+  onSubmit?: (data: any) => void;
+  initialMemory: IMemory;
+  loading?: boolean;
+};
+export const AddOrEditModal = (props: IProps) => {
+  const { open, onClose, onSubmit, initialMemory } = props;
+  // const [fields, setFields] = useState<FormFieldConfig[]>(createMemoryFields);
+  // const formRef = useRef<DynamicFormRef>(null);
+  const [formInstance, setFormInstance] = useState<DynamicFormRef | null>(null);
+
+  const formCallbackRef = useCallback((node: DynamicFormRef | null) => {
+    if (node) {
+      // formRef.current = node;
+      setFormInstance(node);
+    }
+  }, []);
+  const { modelOptions } = useModelOptions();
+
+  useEffect(() => {
+    if (initialMemory && initialMemory.id) {
+      formInstance?.onFieldUpdate('memory_type', { hidden: true });
+      formInstance?.onFieldUpdate('embedding', { hidden: true });
+      formInstance?.onFieldUpdate('llm', { hidden: true });
+    } else {
+      formInstance?.onFieldUpdate('llm', { options: modelOptions as any });
+    }
+  }, [modelOptions, formInstance, initialMemory]);
+
+  return (
+    <Modal
+      open={open}
+      onOpenChange={onClose}
+      className="!w-[480px]"
+      title={
+        <div className="flex flex-col">
+          <div>
+            <HomeIcon name="memory" width={'24'} />
+          </div>
+          {t('memory.createMemory')}
+        </div>
+      }
+      showfooter={false}
+      confirmLoading={props.loading}
+    >
+      <DynamicForm.Root
+        ref={formCallbackRef}
+        fields={createMemoryFields}
+        onSubmit={() => {}}
+        defaultValues={initialMemory}
+      >
+        <div className="flex justify-end gap-2 pb-5">
+          <DynamicForm.CancelButton handleCancel={onClose} />
+          <DynamicForm.SavingButton
+            submitLoading={false}
+            submitFunc={(data) => {
+              onSubmit?.(data);
+            }}
+          />
+        </div>
+      </DynamicForm.Root>
+    </Modal>
+  );
+};

web/src/pages/memories/constants/index.tsx (new file, 41 lines)
@@ -0,0 +1,41 @@
+import { FormFieldConfig, FormFieldType } from '@/components/dynamic-form';
+import { EmbeddingSelect } from '@/pages/dataset/dataset-setting/configuration/common-item';
+import { t } from 'i18next';
+
+export const createMemoryFields = [
+  {
+    name: 'memory_name',
+    label: t('memory.name'),
+    placeholder: t('memory.memoryNamePlaceholder'),
+    required: true,
+  },
+  {
+    name: 'memory_type',
+    label: t('memory.memoryType'),
+    type: FormFieldType.MultiSelect,
+    placeholder: t('memory.descriptionPlaceholder'),
+    options: [
+      { label: 'Raw', value: 'raw' },
+      { label: 'Semantic', value: 'semantic' },
+      { label: 'Episodic', value: 'episodic' },
+      { label: 'Procedural', value: 'procedural' },
+    ],
+    required: true,
+  },
+  {
+    name: 'embedding',
+    label: t('memory.embeddingModel'),
+    placeholder: t('memory.selectModel'),
+    required: true,
+    // hideLabel: true,
+    // type: 'custom',
+    render: (field) => <EmbeddingSelect field={field} isEdit={false} />,
+  },
+  {
+    name: 'llm',
+    label: t('memory.llm'),
+    placeholder: t('memory.selectModel'),
+    required: true,
+    type: FormFieldType.Select,
+  },
+] as FormFieldConfig[];
288  web/src/pages/memories/hooks.ts  Normal file
@@ -0,0 +1,288 @@
// src/pages/memories/hooks.ts

import message from '@/components/ui/message';
import { useSetModalState } from '@/hooks/common-hooks';
import { useHandleSearchChange } from '@/hooks/logic-hooks';
import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks';
import memoryService, { updateMemoryById } from '@/services/memory-service';
import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
import { useDebounce } from 'ahooks';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useParams, useSearchParams } from 'umi';
import {
  CreateMemoryResponse,
  DeleteMemoryProps,
  DeleteMemoryResponse,
  ICreateMemoryProps,
  IMemory,
  IMemoryAppDetailProps,
  MemoryDetailResponse,
  MemoryListResponse,
} from './interface';

export const useCreateMemory = () => {
  const { t } = useTranslation();

  const {
    data,
    isError,
    mutateAsync: createMemoryMutation,
  } = useMutation<CreateMemoryResponse, Error, ICreateMemoryProps>({
    mutationKey: ['createMemory'],
    mutationFn: async (props) => {
      const { data: response } = await memoryService.createMemory(props);
      if (response.code !== 0) {
        throw new Error(response.message || 'Failed to create memory');
      }
      return response.data;
    },
    onSuccess: () => {
      message.success(t('message.created'));
    },
    onError: (error) => {
      message.error(t('message.error', { error: error.message }));
    },
  });

  const createMemory = useCallback(
    (props: ICreateMemoryProps) => {
      return createMemoryMutation(props);
    },
    [createMemoryMutation],
  );

  return { data, isError, createMemory };
};

export const useFetchMemoryList = () => {
  const { handleInputChange, searchString, pagination, setPagination } =
    useHandleSearchChange();
  const debouncedSearchString = useDebounce(searchString, { wait: 500 });
  const { data, isLoading, isError, refetch } = useQuery<
    MemoryListResponse,
    Error
  >({
    queryKey: [
      'memoryList',
      {
        debouncedSearchString,
        ...pagination,
      },
    ],
    queryFn: async () => {
      const { data: response } = await memoryService.getMemoryList(
        {
          params: {
            keywords: debouncedSearchString,
            page_size: pagination.pageSize,
            page: pagination.current,
          },
          data: {},
        },
        true,
      );
      if (response.code !== 0) {
        throw new Error(response.message || 'Failed to fetch memory list');
      }
      return response;
    },
  });

  // const setMemoryListParams = (newParams: MemoryListParams) => {
  //   setMemoryParams((prevParams) => ({
  //     ...prevParams,
  //     ...newParams,
  //   }));
  // };

  return {
    data,
    isLoading,
    isError,
    pagination,
    searchString,
    handleInputChange,
    setPagination,
    refetch,
  };
};

export const useFetchMemoryDetail = (tenantId?: string) => {
  const { id } = useParams();

  const [memoryParams] = useSearchParams();
  const shared_id = memoryParams.get('shared_id');
  const memoryId = id || shared_id;
  let param: { id: string | null; tenant_id?: string } = {
    id: memoryId,
  };
  if (shared_id) {
    param = {
      id: memoryId,
      tenant_id: tenantId,
    };
  }
  const fetchMemoryDetailFunc = shared_id
    ? memoryService.getMemoryDetailShare
    : memoryService.getMemoryDetail;

  const { data, isLoading, isError } = useQuery<MemoryDetailResponse, Error>({
    queryKey: ['memoryDetail', memoryId],
    enabled: !shared_id || !!tenantId,
    queryFn: async () => {
      const { data: response } = await fetchMemoryDetailFunc(param);
      if (response.code !== 0) {
        throw new Error(response.message || 'Failed to fetch memory detail');
      }
      return response;
    },
  });

  return { data: data?.data, isLoading, isError };
};

export const useDeleteMemory = () => {
  const { t } = useTranslation();
  const queryClient = useQueryClient();
  const {
    data,
    isError,
    mutateAsync: deleteMemoryMutation,
  } = useMutation<DeleteMemoryResponse, Error, DeleteMemoryProps>({
    mutationKey: ['deleteMemory'],
    mutationFn: async (props) => {
      const { data: response } = await memoryService.deleteMemory(props);
      if (response.code !== 0) {
        throw new Error(response.message || 'Failed to delete memory');
      }

      queryClient.invalidateQueries({ queryKey: ['memoryList'] });
      return response;
    },
    onSuccess: () => {
      message.success(t('message.deleted'));
    },
    onError: (error) => {
      message.error(t('message.error', { error: error.message }));
    },
  });

  const deleteMemory = useCallback(
    (props: DeleteMemoryProps) => {
      return deleteMemoryMutation(props);
    },
    [deleteMemoryMutation],
  );

  return { data, isError, deleteMemory };
};

export const useUpdateMemory = () => {
  const { t } = useTranslation();
  const queryClient = useQueryClient();
  const {
    data,
    isError,
    mutateAsync: updateMemoryMutation,
  } = useMutation<any, Error, IMemoryAppDetailProps>({
    mutationKey: ['updateMemory'],
    mutationFn: async (formData) => {
      const { data: response } = await updateMemoryById(formData.id, formData);
      if (response.code !== 0) {
        throw new Error(response.message || 'Failed to update memory');
      }
      return response.data;
    },
    onSuccess: (data, variables) => {
      message.success(t('message.updated'));
      queryClient.invalidateQueries({
        queryKey: ['memoryDetail', variables.id],
      });
    },
    onError: (error) => {
      message.error(t('message.error', { error: error.message }));
    },
  });

  const updateMemory = useCallback(
    (formData: IMemoryAppDetailProps) => {
      return updateMemoryMutation(formData);
    },
    [updateMemoryMutation],
  );

  return { data, isError, updateMemory };
};

export const useRenameMemory = () => {
  const [memory, setMemory] = useState<IMemory>({} as IMemory);
  const { navigateToMemory } = useNavigatePage();
  const {
    visible: openCreateModal,
    hideModal: hideChatRenameModal,
    showModal: showChatRenameModal,
  } = useSetModalState();
  const { updateMemory } = useUpdateMemory();
  const { createMemory } = useCreateMemory();
  const [loading, setLoading] = useState(false);

  const handleShowChatRenameModal = useCallback(
    (record?: IMemory) => {
      if (record) {
        setMemory(record);
      }
      showChatRenameModal();
    },
    [showChatRenameModal],
  );

  const handleHideModal = useCallback(() => {
    hideChatRenameModal();
    setMemory({} as IMemory);
  }, [hideChatRenameModal]);

  const onMemoryRenameOk = useCallback(
    async (data: ICreateMemoryProps, callBack?: () => void) => {
      let res;
      setLoading(true);
      if (memory?.id) {
        try {
          // const response = await memoryService.getMemoryDetail({
          //   id: memory?.id,
          // });
          // const detail = response.data?.data;
          // console.log('detail-->', detail);

          // eslint-disable-next-line @typescript-eslint/no-unused-vars
          // const { id, created_by, update_time, ...memoryDataTemp } = detail;
          res = await updateMemory({
            // ...memoryDataTemp,
            name: data.memory_name,
            id: memory?.id,
          } as unknown as IMemoryAppDetailProps);
        } catch (e) {
          console.error('error', e);
        }
      } else {
        res = await createMemory(data);
      }
      if (res && !memory?.id) {
        navigateToMemory(res?.id)();
      }
      callBack?.();
      setLoading(false);
      handleHideModal();
    },
    [memory, createMemory, handleHideModal, navigateToMemory, updateMemory],
  );
  return {
    memoryRenameLoading: loading,
    initialMemory: memory,
    onMemoryRenameOk,
    openCreateModal,
    hideMemoryModal: handleHideModal,
    showMemoryRenameModal: handleShowChatRenameModal,
  };
};
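One design note on useDeleteMemory above: the 'memoryList' cache is invalidated inside mutationFn, which works only because a failed request throws before reaching that line. A sketch of the more conventional invalidate-in-onSuccess shape, assuming the same memoryService API and TanStack Query v5 (useDeleteMemoryVariant is a hypothetical name, not part of this commit):

import { useMutation, useQueryClient } from '@tanstack/react-query';
import memoryService from '@/services/memory-service';

// Conventional shape: run the invalidation in onSuccess, so the cache refresh
// is explicitly tied to the success path rather than to control flow inside
// mutationFn. Behavior is equivalent to the hook above.
export const useDeleteMemoryVariant = () => {
  const queryClient = useQueryClient();
  return useMutation({
    mutationFn: async (props: { memory_id: string }) => {
      const { data: response } = await memoryService.deleteMemory(props);
      if (response.code !== 0) {
        throw new Error(response.message || 'Failed to delete memory');
      }
      return response;
    },
    onSuccess: () => {
      queryClient.invalidateQueries({ queryKey: ['memoryList'] });
    },
  });
};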
163  web/src/pages/memories/index.tsx  Normal file
@@ -0,0 +1,163 @@
import { CardContainer } from '@/components/card-container';
import { EmptyCardType } from '@/components/empty/constant';
import { EmptyAppCard } from '@/components/empty/empty';
import ListFilterBar from '@/components/list-filter-bar';
import { Button } from '@/components/ui/button';
import { RAGFlowPagination } from '@/components/ui/ragflow-pagination';
import { useTranslate } from '@/hooks/common-hooks';
import { pick } from 'lodash';
import { Plus } from 'lucide-react';
import { useCallback, useEffect } from 'react';
import { useSearchParams } from 'umi';
import { AddOrEditModal } from './add-or-edit-modal';
import { useFetchMemoryList, useRenameMemory } from './hooks';
import { ICreateMemoryProps } from './interface';
import { MemoryCard } from './memory-card';

export default function MemoryList() {
  // const { data } = useFetchFlowList();
  const { t } = useTranslate('memory');
  // const [isEdit, setIsEdit] = useState(false);
  const {
    data: list,
    pagination,
    searchString,
    handleInputChange,
    setPagination,
    refetch: refetchList,
  } = useFetchMemoryList();

  const {
    openCreateModal,
    showMemoryRenameModal,
    hideMemoryModal,
    memoryRenameLoading,
    onMemoryRenameOk,
    initialMemory,
  } = useRenameMemory();

  const onMemoryConfirm = (data: ICreateMemoryProps) => {
    onMemoryRenameOk(data, () => {
      refetchList();
    });
  };
  const openCreateModalFun = useCallback(() => {
    // setIsEdit(false);
    showMemoryRenameModal();
  }, [showMemoryRenameModal]);
  const handlePageChange = useCallback(
    (page: number, pageSize?: number) => {
      setPagination({ page, pageSize });
    },
    [setPagination],
  );

  const [searchUrl, setMemoryUrl] = useSearchParams();
  const isCreate = searchUrl.get('isCreate') === 'true';
  useEffect(() => {
    if (isCreate) {
      openCreateModalFun();
      searchUrl.delete('isCreate');
      setMemoryUrl(searchUrl);
    }
  }, [isCreate, openCreateModalFun, searchUrl, setMemoryUrl]);

  return (
    <section className="w-full h-full flex flex-col">
      {(!list?.data?.memory_list?.length ||
        list?.data?.memory_list?.length <= 0) &&
        !searchString && (
          <div className="flex w-full items-center justify-center h-[calc(100vh-164px)]">
            <EmptyAppCard
              showIcon
              size="large"
              className="w-[480px] p-14"
              isSearch={!!searchString}
              type={EmptyCardType.Memory}
              onClick={() => openCreateModalFun()}
            />
          </div>
        )}
      {(!!list?.data?.memory_list?.length || searchString) && (
        <>
          <div className="px-8 pt-8">
            <ListFilterBar
              icon="memory"
              title={t('memory')}
              showFilter={false}
              onSearchChange={handleInputChange}
              searchString={searchString}
            >
              <Button
                variant={'default'}
                onClick={() => {
                  openCreateModalFun();
                }}
              >
                <Plus className="mr-2 h-4 w-4" />
                {t('createMemory')}
              </Button>
            </ListFilterBar>
          </div>
          {(!list?.data?.memory_list?.length ||
            list?.data?.memory_list?.length <= 0) &&
            searchString && (
              <div className="flex w-full items-center justify-center h-[calc(100vh-164px)]">
                <EmptyAppCard
                  showIcon
                  size="large"
                  className="w-[480px] p-14"
                  isSearch={!!searchString}
                  type={EmptyCardType.Memory}
                  onClick={() => openCreateModalFun()}
                />
              </div>
            )}
          <div className="flex-1">
            <CardContainer className="max-h-[calc(100dvh-280px)] overflow-auto px-8">
              {list?.data.memory_list.map((x) => {
                return (
                  <MemoryCard
                    key={x.id}
                    data={x}
                    showMemoryRenameModal={() => {
                      showMemoryRenameModal(x);
                    }}
                  ></MemoryCard>
                );
              })}
            </CardContainer>
          </div>
          {list?.data.total && list?.data.total > 0 && (
            <div className="px-8 mb-4">
              <RAGFlowPagination
                {...pick(pagination, 'current', 'pageSize')}
                // total={pagination.total}
                total={list?.data.total}
                onChange={handlePageChange}
              />
            </div>
          )}
        </>
      )}
      {/* {openCreateModal && (
        <RenameDialog
          hideModal={hideMemoryRenameModal}
          onOk={onMemoryRenameConfirm}
          initialName={initialMemoryName}
          loading={searchRenameLoading}
          title={<HomeIcon name="memory" width={'24'} />}
        ></RenameDialog>
      )} */}
      {openCreateModal && (
        <AddOrEditModal
          initialMemory={initialMemory}
          open={openCreateModal}
          loading={memoryRenameLoading}
          onClose={hideMemoryModal}
          onSubmit={onMemoryConfirm}
        />
      )}
    </section>
  );
}
121  web/src/pages/memories/interface.ts  Normal file
@@ -0,0 +1,121 @@
export interface ICreateMemoryProps {
  memory_name: string;
  memory_type: Array<string>;
  embedding: string;
  llm: string;
}

export interface CreateMemoryResponse {
  id: string;
  name: string;
  description: string;
}

export interface MemoryListParams {
  keywords?: string;
  parser_id?: string;
  page?: number;
  page_size?: number;
  orderby?: string;
  desc?: boolean;
  owner_ids?: string;
}
export type MemoryType = 'raw' | 'semantic' | 'episodic' | 'procedural';
export type StorageType = 'table' | 'graph';
export type Permissions = 'me' | 'team';
export type ForgettingPolicy = 'fifo' | 'lru';

export interface IMemory {
  id: string;
  name: string;
  avatar: string;
  tenant_id: string;
  owner_name: string;
  memory_type: MemoryType[];
  storage_type: StorageType;
  embedding: string;
  llm: string;
  permissions: Permissions;
  description: string;
  memory_size: number;
  forgetting_policy: ForgettingPolicy;
  temperature: string;
  system_prompt: string;
  user_prompt: string;
}
export interface MemoryListResponse {
  code: number;
  data: {
    memory_list: Array<IMemory>;
    total: number;
  };
  message: string;
}

export interface DeleteMemoryProps {
  memory_id: string;
}

export interface DeleteMemoryResponse {
  code: number;
  data: boolean;
  message: string;
}

export interface IllmSettingProps {
  llm_id: string;
  parameter: string;
  temperature?: number;
  top_p?: number;
  frequency_penalty?: number;
  presence_penalty?: number;
}
interface IllmSettingEnableProps {
  temperatureEnabled?: boolean;
  topPEnabled?: boolean;
  presencePenaltyEnabled?: boolean;
  frequencyPenaltyEnabled?: boolean;
}
export interface IMemoryAppDetailProps {
  avatar: any;
  created_by: string;
  description: string;
  id: string;
  name: string;
  memory_config: {
    cross_languages: string[];
    doc_ids: string[];
    chat_id: string;
    highlight: boolean;
    kb_ids: string[];
    keyword: boolean;
    query_mindmap: boolean;
    related_memory: boolean;
    rerank_id: string;
    use_rerank?: boolean;
    similarity_threshold: number;
    summary: boolean;
    llm_setting: IllmSettingProps & IllmSettingEnableProps;
    top_k: number;
    use_kg: boolean;
    vector_similarity_weight: number;
    web_memory: boolean;
    chat_settingcross_languages: string[];
    meta_data_filter?: {
      method: string;
      manual: { key: string; op: string; value: string }[];
    };
  };
  tenant_id: string;
  update_time: number;
}

export interface MemoryDetailResponse {
  code: number;
  data: IMemoryAppDetailProps;
  message: string;
}

// export type IUpdateMemoryProps = Omit<IMemoryAppDetailProps, 'id'> & {
//   id: string;
// };
32  web/src/pages/memories/memory-card.tsx  Normal file
@@ -0,0 +1,32 @@
import { HomeCard } from '@/components/home-card';
import { MoreButton } from '@/components/more-button';
import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks';
import { IMemory } from './interface';
import { MemoryDropdown } from './memory-dropdown';

interface IProps {
  data: IMemory;
  showMemoryRenameModal: (data: IMemory) => void;
}
export function MemoryCard({ data, showMemoryRenameModal }: IProps) {
  const { navigateToMemory } = useNavigatePage();

  return (
    <HomeCard
      data={{
        name: data?.name,
        avatar: data?.avatar,
        description: data?.description,
      }}
      moreDropdown={
        <MemoryDropdown
          dataset={data}
          showMemoryRenameModal={showMemoryRenameModal}
        >
          <MoreButton></MoreButton>
        </MemoryDropdown>
      }
      onClick={navigateToMemory(data?.id)}
    />
  );
}
74  web/src/pages/memories/memory-dropdown.tsx  Normal file
@@ -0,0 +1,74 @@
import {
  ConfirmDeleteDialog,
  ConfirmDeleteDialogNode,
} from '@/components/confirm-delete-dialog';
import {
  DropdownMenu,
  DropdownMenuContent,
  DropdownMenuItem,
  DropdownMenuSeparator,
  DropdownMenuTrigger,
} from '@/components/ui/dropdown-menu';
import { PenLine, Trash2 } from 'lucide-react';
import { MouseEventHandler, PropsWithChildren, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { useDeleteMemory } from './hooks';
import { IMemory } from './interface';

export function MemoryDropdown({
  children,
  dataset,
  showMemoryRenameModal,
}: PropsWithChildren & {
  dataset: IMemory;
  showMemoryRenameModal: (dataset: IMemory) => void;
}) {
  const { t } = useTranslation();
  const { deleteMemory } = useDeleteMemory();
  const handleShowChatRenameModal: MouseEventHandler<HTMLDivElement> =
    useCallback(
      (e) => {
        e.stopPropagation();
        showMemoryRenameModal(dataset);
      },
      [dataset, showMemoryRenameModal],
    );
  const handleDelete: MouseEventHandler<HTMLDivElement> = useCallback(() => {
    deleteMemory({ memory_id: dataset.id });
  }, [dataset.id, deleteMemory]);

  return (
    <DropdownMenu>
      <DropdownMenuTrigger asChild>{children}</DropdownMenuTrigger>
      <DropdownMenuContent>
        <DropdownMenuItem onClick={handleShowChatRenameModal}>
          {t('common.rename')} <PenLine />
        </DropdownMenuItem>
        <DropdownMenuSeparator />
        <ConfirmDeleteDialog
          onOk={handleDelete}
          title={t('deleteModal.delMemory')}
          content={{
            node: (
              <ConfirmDeleteDialogNode
                avatar={{ avatar: dataset.avatar, name: dataset.name }}
                name={dataset.name}
              />
            ),
          }}
        >
          <DropdownMenuItem
            className="text-state-error"
            onSelect={(e) => {
              // keep the dropdown open so the nested confirm dialog can show
              e.preventDefault();
            }}
            onClick={(e) => {
              e.stopPropagation();
            }}
          >
            {t('common.delete')} <Trash2 />
          </DropdownMenuItem>
        </ConfirmDeleteDialog>
      </DropdownMenuContent>
    </DropdownMenu>
  );
}
3  web/src/pages/memory/constant.tsx  Normal file
@@ -0,0 +1,3 @@
export enum MemoryApiAction {
  FetchMemoryDetail = 'fetchMemoryDetail',
}
59  web/src/pages/memory/hooks/use-memory-messages.ts  Normal file
@@ -0,0 +1,59 @@
import { useHandleSearchChange } from '@/hooks/logic-hooks';
import { getMemoryDetailById } from '@/services/memory-service';
import { useQuery } from '@tanstack/react-query';
import { useParams, useSearchParams } from 'umi';
import { MemoryApiAction } from '../constant';
import { IMessageTableProps } from '../memory-message/interface';

export const useFetchMemoryMessageList = (props?: {
  refreshCount?: number;
}) => {
  const { refreshCount } = props || {};
  const { id } = useParams();
  const [searchParams] = useSearchParams();
  const memoryBaseId = searchParams.get('id') || id;
  const { handleInputChange, searchString, pagination, setPagination } =
    useHandleSearchChange();

  let queryKey: (MemoryApiAction | number)[] = [
    MemoryApiAction.FetchMemoryDetail,
  ];
  if (typeof refreshCount === 'number') {
    queryKey = [MemoryApiAction.FetchMemoryDetail, refreshCount];
  }

  const { data, isFetching: loading } = useQuery<IMessageTableProps>({
    queryKey: [...queryKey, searchString, pagination],
    initialData: {} as IMessageTableProps,
    gcTime: 0,
    queryFn: async () => {
      if (memoryBaseId) {
        const { data } = await getMemoryDetailById(memoryBaseId as string, {
          // filter: {
          //   agent_id: '',
          // },
          keyword: searchString,
          page: pagination.current,
          page_size: pagination.pageSize,
        });
        // setPagination({
        //   page: data?.page ?? 1,
        //   pageSize: data?.page_size ?? 10,
        //   total: data?.total ?? 0,
        // });
        return data?.data ?? {};
      } else {
        return {};
      }
    },
  });

  return {
    data,
    loading,
    handleInputChange,
    searchString,
    pagination,
    setPagination,
  };
};
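The query above seeds an empty IMessageTableProps through initialData and then relies on gcTime: 0 to drop that placeholder from the cache. If the project is on TanStack Query v5, placeholderData covers the same need without writing to the cache at all; a sketch, with fetchMessages, page and keyword as hypothetical stand-ins:

import { keepPreviousData, useQuery } from '@tanstack/react-query';

// Hypothetical fetcher standing in for getMemoryDetailById.
declare function fetchMessages(page: number, keyword: string): Promise<unknown>;

// placeholderData is shown while loading but never persisted, so no
// gcTime: 0 workaround is needed; keepPreviousData additionally keeps the
// previous page's rows on screen while the next page fetches.
export function useMessagesPage(page: number, keyword: string) {
  return useQuery({
    queryKey: ['memoryMessages', page, keyword],
    queryFn: () => fetchMessages(page, keyword),
    placeholderData: keepPreviousData,
  });
}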
59  web/src/pages/memory/hooks/use-memory-setting.ts  Normal file
@@ -0,0 +1,59 @@
import { useHandleSearchChange } from '@/hooks/logic-hooks';
import { IMemory } from '@/pages/memories/interface';
import { getMemoryDetailById } from '@/services/memory-service';
import { useQuery } from '@tanstack/react-query';
import { useParams, useSearchParams } from 'umi';
import { MemoryApiAction } from '../constant';

export const useFetchMemoryBaseConfiguration = (props?: {
  refreshCount?: number;
}) => {
  const { refreshCount } = props || {};
  const { id } = useParams();
  const [searchParams] = useSearchParams();
  const memoryBaseId = searchParams.get('id') || id;
  const { handleInputChange, searchString, pagination, setPagination } =
    useHandleSearchChange();

  let queryKey: (MemoryApiAction | number)[] = [
    MemoryApiAction.FetchMemoryDetail,
  ];
  if (typeof refreshCount === 'number') {
    queryKey = [MemoryApiAction.FetchMemoryDetail, refreshCount];
  }

  const { data, isFetching: loading } = useQuery<IMemory>({
    queryKey: [...queryKey, searchString, pagination],
    initialData: {} as IMemory,
    gcTime: 0,
    queryFn: async () => {
      if (memoryBaseId) {
        const { data } = await getMemoryDetailById(memoryBaseId as string, {
          // filter: {
          //   agent_id: '',
          // },
          keyword: searchString,
          page: pagination.current,
          page_size: pagination.pageSize,
        });
        // setPagination({
        //   page: data?.page ?? 1,
        //   pageSize: data?.page_size ?? 10,
        //   total: data?.total ?? 0,
        // });
        return data?.data ?? {};
      } else {
        return {};
      }
    },
  });

  return {
    data,
    loading,
    handleInputChange,
    searchString,
    pagination,
    setPagination,
  };
};
17  web/src/pages/memory/index.tsx  Normal file
@@ -0,0 +1,17 @@
import Spotlight from '@/components/spotlight';
import { Outlet } from 'umi';
import { SideBar } from './sidebar';

export default function MemoryWrapper() {
  return (
    <section className="flex h-full flex-col w-full">
      <div className="flex flex-1 min-h-0">
        <SideBar></SideBar>
        <div className="relative flex-1 overflow-auto border-[0.5px] border-border-button p-5 rounded-md mr-5 mb-5">
          <Spotlight />
          <Outlet />
        </div>
      </div>
    </section>
  );
}
48  web/src/pages/memory/memory-message/index.tsx  Normal file
@@ -0,0 +1,48 @@
import ListFilterBar from '@/components/list-filter-bar';
import { t } from 'i18next';
import { useFetchMemoryMessageList } from '../hooks/use-memory-messages';
import { MemoryTable } from './message-table';

export default function MemoryMessage() {
  const {
    searchString,
    // documents,
    data,
    pagination,
    handleInputChange,
    setPagination,
    // filterValue,
    // handleFilterSubmit,
    loading,
  } = useFetchMemoryMessageList();
  return (
    <div className="flex flex-col gap-2">
      <ListFilterBar
        title="Dataset"
        onSearchChange={handleInputChange}
        searchString={searchString}
        // value={filterValue}
        // onChange={handleFilterSubmit}
        // onOpenChange={onOpenChange}
        // filters={filters}
        leftPanel={
          <div className="items-start">
            <div className="pb-1">{t('knowledgeDetails.subbarFiles')}</div>
            <div className="text-text-secondary text-sm">
              {t('knowledgeDetails.datasetDescription')}
            </div>
          </div>
        }
      ></ListFilterBar>
      <MemoryTable
        messages={data?.messages?.message_list ?? []}
        pagination={pagination}
        setPagination={setPagination}
        total={data?.messages?.total ?? 0}
        // rowSelection={rowSelection}
        // setRowSelection={setRowSelection}
        // loading={loading}
      ></MemoryTable>
    </div>
  );
}
19  web/src/pages/memory/memory-message/interface.ts  Normal file
@@ -0,0 +1,19 @@
export interface IMessageInfo {
  message_id: number;
  message_type: 'semantic' | 'raw' | 'procedural';
  source_id: string | '-';
  id: string;
  user_id: string;
  agent_id: string;
  agent_name: string;
  session_id: string;
  valid_at: string;
  invalid_at: string;
  forget_at: string;
  status: boolean;
}

export interface IMessageTableProps {
  messages: { message_list: Array<IMessageInfo>; total: number };
  storage_type: string;
}
225  web/src/pages/memory/memory-message/message-table.tsx  Normal file
@@ -0,0 +1,225 @@
import {
  ColumnDef,
  ColumnFiltersState,
  SortingState,
  VisibilityState,
  flexRender,
  getCoreRowModel,
  getFilteredRowModel,
  getPaginationRowModel,
  getSortedRowModel,
  useReactTable,
} from '@tanstack/react-table';
import * as React from 'react';

import { EmptyType } from '@/components/empty/constant';
import Empty from '@/components/empty/empty';
import { Button } from '@/components/ui/button';
import { RAGFlowPagination } from '@/components/ui/ragflow-pagination';
import { Switch } from '@/components/ui/switch';
import {
  Table,
  TableBody,
  TableCell,
  TableHead,
  TableHeader,
  TableRow,
} from '@/components/ui/table';
import { Pagination } from '@/interfaces/common';
import { t } from 'i18next';
import { pick } from 'lodash';
import { Eraser, TextSelect } from 'lucide-react';
import { useMemo } from 'react';
import { IMessageInfo } from './interface';

export type MemoryTableProps = {
  messages: Array<IMessageInfo>;
  total: number;
  pagination: Pagination;
  setPagination: (params: { page: number; pageSize: number }) => void;
};

export function MemoryTable({
  messages,
  total,
  pagination,
  setPagination,
}: MemoryTableProps) {
  const [sorting, setSorting] = React.useState<SortingState>([]);
  const [columnFilters, setColumnFilters] = React.useState<ColumnFiltersState>(
    [],
  );
  const [columnVisibility, setColumnVisibility] =
    React.useState<VisibilityState>({});

  // Define columns for the memory table
  const columns: ColumnDef<IMessageInfo>[] = useMemo(
    () => [
      {
        accessorKey: 'session_id',
        header: () => <span>{t('memoryDetail.messages.sessionId')}</span>,
        cell: ({ row }) => (
          <div className="text-sm font-medium">
            {row.getValue('session_id')}
          </div>
        ),
      },
      {
        accessorKey: 'agent_name',
        header: () => <span>{t('memoryDetail.messages.agent')}</span>,
        cell: ({ row }) => (
          <div className="text-sm font-medium">
            {row.getValue('agent_name')}
          </div>
        ),
      },
      {
        accessorKey: 'message_type',
        header: () => <span>{t('memoryDetail.messages.type')}</span>,
        cell: ({ row }) => (
          <div className="text-sm font-medium capitalize">
            {row.getValue('message_type')}
          </div>
        ),
      },
      {
        accessorKey: 'valid_at',
        header: () => <span>{t('memoryDetail.messages.validDate')}</span>,
        cell: ({ row }) => (
          <div className="text-sm">{row.getValue('valid_at')}</div>
        ),
      },
      {
        accessorKey: 'forget_at',
        header: () => <span>{t('memoryDetail.messages.forgetAt')}</span>,
        cell: ({ row }) => (
          <div className="text-sm">{row.getValue('forget_at')}</div>
        ),
      },
      {
        accessorKey: 'source_id',
        header: () => <span>{t('memoryDetail.messages.source')}</span>,
        cell: ({ row }) => (
          <div className="text-sm">{row.getValue('source_id')}</div>
        ),
      },
      {
        accessorKey: 'status',
        header: () => <span>{t('memoryDetail.messages.enable')}</span>,
        cell: ({ row }) => {
          const isEnabled = row.getValue('status') as boolean;
          return (
            <div className="flex items-center">
              <Switch defaultChecked={isEnabled} onChange={() => {}} />
            </div>
          );
        },
      },
      {
        accessorKey: 'action',
        header: () => <span>{t('memoryDetail.messages.action')}</span>,
        meta: {
          cellClassName: 'w-12',
        },
        cell: () => (
          <div className="flex opacity-0 group-hover:opacity-100">
            <Button variant={'ghost'} className="bg-transparent">
              <TextSelect />
            </Button>
            <Button
              variant={'delete'}
              className="bg-transparent"
              aria-label="Edit"
            >
              <Eraser />
            </Button>
          </div>
        ),
      },
    ],
    [],
  );

  const currentPagination = useMemo(() => {
    return {
      pageIndex: (pagination.current || 1) - 1,
      pageSize: pagination.pageSize || 10,
    };
  }, [pagination]);

  const table = useReactTable({
    data: messages,
    columns,
    onSortingChange: setSorting,
    onColumnFiltersChange: setColumnFilters,
    getCoreRowModel: getCoreRowModel(),
    getPaginationRowModel: getPaginationRowModel(),
    getSortedRowModel: getSortedRowModel(),
    getFilteredRowModel: getFilteredRowModel(),
    onColumnVisibilityChange: setColumnVisibility,
    manualPagination: true,
    state: {
      sorting,
      columnFilters,
      columnVisibility,
      pagination: currentPagination,
    },
    rowCount: total,
  });

  return (
    <div className="w-full">
      <Table rootClassName="max-h-[calc(100vh-222px)]">
        <TableHeader>
          {table.getHeaderGroups().map((headerGroup) => (
            <TableRow key={headerGroup.id}>
              {headerGroup.headers.map((header) => (
                <TableHead key={header.id}>
                  {header.isPlaceholder
                    ? null
                    : flexRender(
                        header.column.columnDef.header,
                        header.getContext(),
                      )}
                </TableHead>
              ))}
            </TableRow>
          ))}
        </TableHeader>
        <TableBody className="relative">
          {table.getRowModel().rows?.length ? (
            table.getRowModel().rows.map((row) => (
              <TableRow
                key={row.id}
                data-state={row.getIsSelected() && 'selected'}
                className="group"
              >
                {row.getVisibleCells().map((cell) => (
                  <TableCell key={cell.id}>
                    {flexRender(cell.column.columnDef.cell, cell.getContext())}
                  </TableCell>
                ))}
              </TableRow>
            ))
          ) : (
            <TableRow>
              <TableCell colSpan={columns.length} className="h-24 text-center">
                <Empty type={EmptyType.Data} />
              </TableCell>
            </TableRow>
          )}
        </TableBody>
      </Table>

      <div className="flex items-center justify-end py-4 absolute bottom-3 right-3">
        <RAGFlowPagination
          {...pick(pagination, 'current', 'pageSize')}
          total={total}
          onChange={(page, pageSize) => {
            setPagination({ page, pageSize });
          }}
        />
      </div>
    </div>
  );
}
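Because MemoryTable sets manualPagination: true and rowCount, the table treats `messages` as the already-fetched server page and derives the page count from rowCount; getPaginationRowModel() only matters for client-side paging. A stripped-down sketch of the server-side configuration (assumes TanStack Table v8 with the rowCount option; usePagedTable and its parameters are hypothetical names):

import {
  getCoreRowModel,
  useReactTable,
  type ColumnDef,
} from '@tanstack/react-table';

// Minimal server-side pagination setup: `rows` holds only the current page,
// `total` comes from the server, and the table computes getPageCount() from
// rowCount instead of slicing data itself.
export function usePagedTable<Row>(
  rows: Row[],
  columns: ColumnDef<Row>[],
  total: number,
  pageIndex: number,
  pageSize: number,
) {
  return useReactTable({
    data: rows,
    columns,
    getCoreRowModel: getCoreRowModel(),
    manualPagination: true, // paging happens on the server
    rowCount: total, // drives page-count math
    state: { pagination: { pageIndex, pageSize } },
  });
}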
13  web/src/pages/memory/memory-setting/index.tsx  Normal file
@@ -0,0 +1,13 @@
export default function MemorySetting() {
  return (
    <div className="flex flex-col gap-2">
      <div className="flex items-center gap-2">
        <div className="h-4 w-4 rounded-full bg-text-secondary">11</div>
        <div className="h-4 w-4 rounded-full bg-text-secondary">11</div>
      </div>
      <div className="flex items-center gap-2">
        <div className="h-4 w-4 rounded-full bg-text">setting</div>
      </div>
    </div>
  );
}
17  web/src/pages/memory/sidebar/hooks.tsx  Normal file
@@ -0,0 +1,17 @@
import { Routes } from '@/routes';
import { useCallback } from 'react';
import { useNavigate, useParams } from 'umi';

export const useHandleMenuClick = () => {
  const navigate = useNavigate();
  const { id } = useParams();

  const handleMenuClick = useCallback(
    (key: Routes) => () => {
      navigate(`${Routes.Memory}${key}/${id}`);
    },
    [id, navigate],
  );

  return { handleMenuClick };
};
88  web/src/pages/memory/sidebar/index.tsx  Normal file
@@ -0,0 +1,88 @@
import { RAGFlowAvatar } from '@/components/ragflow-avatar';
import { Button } from '@/components/ui/button';
import { useSecondPathName } from '@/hooks/route-hook';
import { cn, formatBytes } from '@/lib/utils';
import { Routes } from '@/routes';
import { formatPureDate } from '@/utils/date';
import { Banknote, Logs } from 'lucide-react';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useFetchMemoryBaseConfiguration } from '../hooks/use-memory-setting';
import { useHandleMenuClick } from './hooks';

type PropType = {
  refreshCount?: number;
};

export function SideBar({ refreshCount }: PropType) {
  const pathName = useSecondPathName();
  const { handleMenuClick } = useHandleMenuClick();
  // refreshCount keeps the avatar image in the top-left corner in sync after updates
  const { data } = useFetchMemoryBaseConfiguration({ refreshCount });
  const { t } = useTranslation();

  const items = useMemo(() => {
    const list = [
      {
        icon: <Logs className="size-4" />,
        label: t(`knowledgeDetails.overview`),
        key: Routes.MemoryMessage,
      },
      {
        icon: <Banknote className="size-4" />,
        label: t(`knowledgeDetails.configuration`),
        key: Routes.MemorySetting,
      },
    ];
    return list;
  }, [t]);

  return (
    <aside className="relative p-5 space-y-8">
      <div className="flex gap-2.5 max-w-[200px] items-center">
        <RAGFlowAvatar
          avatar={data.avatar}
          name={data.name}
          className="size-16"
        ></RAGFlowAvatar>
        <div className="text-text-secondary text-xs space-y-1 overflow-hidden">
          <h3 className="text-lg font-semibold line-clamp-1 text-text-primary text-ellipsis overflow-hidden">
            {data.name}
          </h3>
          <div className="flex justify-between">
            <span>
              {data.doc_num} {t('knowledgeDetails.files')}
            </span>
            <span>{formatBytes(data.size)}</span>
          </div>
          <div>
            {t('knowledgeDetails.created')} {formatPureDate(data.create_time)}
          </div>
        </div>
      </div>

      <div className="w-[200px] flex flex-col gap-5">
        {items.map((item, itemIdx) => {
          const active = '/' + pathName === item.key;
          return (
            <Button
              key={itemIdx}
              variant={active ? 'secondary' : 'ghost'}
              className={cn(
                'w-full justify-start gap-2.5 px-3 relative h-10 text-text-secondary',
                {
                  'bg-bg-card': active,
                  'text-text-primary': active,
                },
              )}
              onClick={handleMenuClick(item.key)}
            >
              {item.icon}
              <span>{item.label}</span>
            </Button>
          );
        })}
      </div>
    </aside>
  );
}
@@ -50,18 +50,19 @@ export default function ChatList() {
 
   return (
     <section className="flex flex-col w-full flex-1">
-      {data.dialogs?.length <= 0 && (
+      {data.dialogs?.length <= 0 && !searchString && (
         <div className="flex w-full items-center justify-center h-[calc(100vh-164px)]">
           <EmptyAppCard
             showIcon
             size="large"
             className="w-[480px] p-14"
+            isSearch={!!searchString}
             type={EmptyCardType.Chat}
             onClick={() => handleShowCreateModal()}
           />
         </div>
       )}
-      {data.dialogs?.length > 0 && (
+      {(data.dialogs?.length > 0 || searchString) && (
         <>
           <div className="px-8 pt-8">
             <ListFilterBar
@@ -76,6 +77,18 @@ export default function ChatList() {
             </Button>
           </ListFilterBar>
         </div>
+        {data.dialogs?.length <= 0 && searchString && (
+          <div className="flex w-full items-center justify-center h-[calc(100vh-164px)]">
+            <EmptyAppCard
+              showIcon
+              size="large"
+              className="w-[480px] p-14"
+              isSearch={!!searchString}
+              type={EmptyCardType.Chat}
+              onClick={() => handleShowCreateModal()}
+            />
+          </div>
+        )}
         <div className="flex-1 overflow-auto">
           <CardContainer className="max-h-[calc(100dvh-280px)] overflow-auto px-8">
             {data.dialogs.map((x) => {
@@ -1,4 +1,4 @@
-import HightLightMarkdown from '@/components/highlight-markdown';
+import HighLightMarkdown from '@/components/highlight-markdown';
 import message from '@/components/ui/message';
 import { Modal } from '@/components/ui/modal/modal';
 import { RAGFlowSelect } from '@/components/ui/select';
@@ -102,7 +102,7 @@ const EmbedAppModal = (props: IEmbedAppModalProps) => {
       </label>
       {/* <div className=" border rounded-lg"> */}
       {/* <pre className="text-sm whitespace-pre-wrap">{text}</pre> */}
-      <HightLightMarkdown>{text}</HightLightMarkdown>
+      <HighLightMarkdown>{text}</HighLightMarkdown>
       {/* </div> */}
     </div>
 
@@ -1,48 +0,0 @@
-import Markdown from 'react-markdown';
-import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
-import rehypeKatex from 'rehype-katex';
-import rehypeRaw from 'rehype-raw';
-import remarkGfm from 'remark-gfm';
-import remarkMath from 'remark-math';
-
-import 'katex/dist/katex.min.css'; // `rehype-katex` does not import the CSS for you
-
-import { preprocessLaTeX } from '@/utils/chat';
-
-const HightLightMarkdown = ({
-  children,
-}: {
-  children: string | null | undefined;
-}) => {
-  return (
-    <Markdown
-      remarkPlugins={[remarkGfm, remarkMath]}
-      rehypePlugins={[rehypeRaw, rehypeKatex]}
-      className="text-text-primary text-sm"
-      components={
-        {
-          code(props: any) {
-            const { children, className, ...rest } = props;
-            const match = /language-(\w+)/.exec(className || '');
-            return match ? (
-              <SyntaxHighlighter {...rest} PreTag="div" language={match[1]}>
-                {String(children).replace(/\n$/, '')}
-              </SyntaxHighlighter>
-            ) : (
-              <code
-                {...rest}
-                className={`${className} pt-1 px-2 pb-2 m-0 whitespace-break-spaces rounded text-text-primary text-sm`}
-              >
-                {children}
-              </code>
-            );
-          },
-        } as any
-      }
-    >
-      {children ? preprocessLaTeX(children) : children}
-    </Markdown>
-  );
-};
-
-export default HightLightMarkdown;
@@ -1,5 +1,6 @@
 import { EmptyType } from '@/components/empty/constant';
 import Empty from '@/components/empty/empty';
+import HighLightMarkdown from '@/components/highlight-markdown';
 import { FileIcon } from '@/components/icon-font';
 import { ImageWithPopover } from '@/components/image';
 import { Input } from '@/components/originui/input';
@@ -20,7 +21,6 @@ import { Dispatch, SetStateAction, useEffect, useState } from 'react';
 import { useTranslation } from 'react-i18next';
 import { ISearchAppDetailProps } from '../next-searches/hooks';
 import PdfDrawer from './document-preview-modal';
-import HightLightMarkdown from './highlight-markdown';
 import { ISearchReturnProps } from './hooks';
 import './index.less';
 import MarkdownContent from './markdown-content';
@@ -217,9 +217,9 @@ export default function SearchingView({
         </PopoverTrigger>
         <PopoverContent className="text-text-primary !w-full max-w-lg ">
           <div className="max-h-96 overflow-auto scrollbar-thin">
-            <HightLightMarkdown>
+            <HighLightMarkdown>
               {chunk.content_with_weight}
-            </HightLightMarkdown>
+            </HighLightMarkdown>
           </div>
         </PopoverContent>
       </Popover>
@@ -2,9 +2,11 @@
 
 import message from '@/components/ui/message';
 import { useSetModalState } from '@/hooks/common-hooks';
+import { useHandleSearchChange } from '@/hooks/logic-hooks';
 import { useNavigatePage } from '@/hooks/logic-hooks/navigate-hooks';
-import searchService from '@/services/search-service';
+import searchService, { searchServiceNext } from '@/services/search-service';
 import { useMutation, useQuery, useQueryClient } from '@tanstack/react-query';
+import { useDebounce } from 'ahooks';
 import { useCallback, useState } from 'react';
 import { useTranslation } from 'react-i18next';
 import { useParams, useSearchParams } from 'umi';
@@ -84,21 +86,34 @@ interface SearchListResponse {
   message: string;
 }
 
-export const useFetchSearchList = (params?: SearchListParams) => {
-  const [searchParams, setSearchParams] = useState<SearchListParams>({
-    page: 1,
-    page_size: 50,
-    ...params,
-  });
+export const useFetchSearchList = () => {
+  const { handleInputChange, searchString, pagination, setPagination } =
+    useHandleSearchChange();
+  const debouncedSearchString = useDebounce(searchString, { wait: 500 });
+
   const { data, isLoading, isError, refetch } = useQuery<
     SearchListResponse,
     Error
   >({
-    queryKey: ['searchList', searchParams],
+    queryKey: [
+      'searchList',
+      {
+        debouncedSearchString,
+        ...pagination,
+      },
+    ],
     queryFn: async () => {
-      const { data: response } =
-        await searchService.getSearchList(searchParams);
+      const { data: response } = await searchServiceNext.getSearchList(
+        {
+          params: {
+            keywords: debouncedSearchString,
+            page_size: pagination.pageSize,
+            page: pagination.current,
+          },
+          data: {},
+        },
+        true,
+      );
       if (response.code !== 0) {
         throw new Error(response.message || 'Failed to fetch search list');
       }
@@ -106,19 +121,14 @@ export const useFetchSearchList = (params?: SearchListParams) => {
     },
   });
 
-  const setSearchListParams = (newParams: SearchListParams) => {
-    setSearchParams((prevParams) => ({
-      ...prevParams,
-      ...newParams,
-    }));
-  };
-
   return {
     data,
     isLoading,
     isError,
-    searchParams,
-    setSearchListParams,
+    pagination,
+    searchString,
+    handleInputChange,
+    setPagination,
    refetch,
  };
 };
@@ -7,6 +7,7 @@ import { RenameDialog } from '@/components/rename-dialog';
 import { Button } from '@/components/ui/button';
 import { RAGFlowPagination } from '@/components/ui/ragflow-pagination';
 import { useTranslate } from '@/hooks/common-hooks';
+import { pick } from 'lodash';
 import { Plus } from 'lucide-react';
 import { useCallback, useEffect } from 'react';
 import { useSearchParams } from 'umi';
@@ -19,10 +20,13 @@ export default function SearchList() {
   // const [isEdit, setIsEdit] = useState(false);
   const {
     data: list,
-    searchParams,
-    setSearchListParams,
+    pagination,
+    searchString,
+    handleInputChange,
+    setPagination,
     refetch: refetchList,
   } = useFetchSearchList();
 
   const {
     openCreateModal,
     showSearchRenameModal,
@@ -32,9 +36,9 @@ export default function SearchList() {
     initialSearchName,
   } = useRenameSearch();
 
-  const handleSearchChange = (value: string) => {
-    console.log(value);
-  };
+  // const handleSearchChange = (value: string) => {
+  //   console.log(value);
+  // };
   const onSearchRenameConfirm = (name: string) => {
     onSearchRenameOk(name, () => {
       refetchList();
@@ -44,10 +48,12 @@ export default function SearchList() {
     // setIsEdit(false);
     showSearchRenameModal();
   }, [showSearchRenameModal]);
-  const handlePageChange = (page: number, pageSize: number) => {
-    // setIsEdit(false);
-    setSearchListParams({ ...searchParams, page, page_size: pageSize });
-  };
+  const handlePageChange = useCallback(
+    (page: number, pageSize?: number) => {
+      setPagination({ page, pageSize });
+    },
+    [setPagination],
+  );
 
   const [searchUrl, setSearchUrl] = useSearchParams();
   const isCreate = searchUrl.get('isCreate') === 'true';
@@ -62,25 +68,28 @@ export default function SearchList() {
   return (
     <section className="w-full h-full flex flex-col">
       {(!list?.data?.search_apps?.length ||
-        list?.data?.search_apps?.length <= 0) && (
-        <div className="flex w-full items-center justify-center h-[calc(100vh-164px)]">
-          <EmptyAppCard
-            showIcon
-            size="large"
-            className="w-[480px] p-14"
-            type={EmptyCardType.Search}
-            onClick={() => openCreateModalFun()}
-          />
-        </div>
-      )}
-      {!!list?.data?.search_apps?.length && (
+        list?.data?.search_apps?.length <= 0) &&
+        !searchString && (
+          <div className="flex w-full items-center justify-center h-[calc(100vh-164px)]">
+            <EmptyAppCard
+              showIcon
+              size="large"
+              className="w-[480px] p-14"
+              type={EmptyCardType.Search}
+              isSearch={!!searchString}
+              onClick={() => openCreateModalFun()}
+            />
+          </div>
+        )}
+      {(!!list?.data?.search_apps?.length || searchString) && (
         <>
           <div className="px-8 pt-8">
             <ListFilterBar
               icon="searches"
               title={t('searchApps')}
|
||||||
showFilter={false}
|
showFilter={false}
|
||||||
onSearchChange={(e) => handleSearchChange(e.target.value)}
|
searchString={searchString}
|
||||||
|
onSearchChange={handleInputChange}
|
||||||
>
|
>
|
||||||
<Button
|
<Button
|
||||||
variant={'default'}
|
variant={'default'}
|
||||||
@ -93,6 +102,20 @@ export default function SearchList() {
|
|||||||
</Button>
|
</Button>
|
||||||
</ListFilterBar>
|
</ListFilterBar>
|
||||||
</div>
|
</div>
|
||||||
|
{(!list?.data?.search_apps?.length ||
|
||||||
|
list?.data?.search_apps?.length <= 0) &&
|
||||||
|
searchString && (
|
||||||
|
<div className="flex w-full items-center justify-center h-[calc(100vh-164px)]">
|
||||||
|
<EmptyAppCard
|
||||||
|
showIcon
|
||||||
|
size="large"
|
||||||
|
className="w-[480px] p-14"
|
||||||
|
type={EmptyCardType.Search}
|
||||||
|
isSearch={!!searchString}
|
||||||
|
onClick={() => openCreateModalFun()}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
<div className="flex-1">
|
<div className="flex-1">
|
||||||
<CardContainer className="max-h-[calc(100dvh-280px)] overflow-auto px-8">
|
<CardContainer className="max-h-[calc(100dvh-280px)] overflow-auto px-8">
|
||||||
{list?.data.search_apps.map((x) => {
|
{list?.data.search_apps.map((x) => {
|
||||||
@ -111,8 +134,8 @@ export default function SearchList() {
|
|||||||
{list?.data.total && list?.data.total > 0 && (
|
{list?.data.total && list?.data.total > 0 && (
|
||||||
<div className="px-8 mb-4">
|
<div className="px-8 mb-4">
|
||||||
<RAGFlowPagination
|
<RAGFlowPagination
|
||||||
current={searchParams.page}
|
{...pick(pagination, 'current', 'pageSize')}
|
||||||
pageSize={searchParams.page_size}
|
// total={pagination.total}
|
||||||
total={list?.data.total}
|
total={list?.data.total}
|
||||||
onChange={handlePageChange}
|
onChange={handlePageChange}
|
||||||
/>
|
/>
|
||||||
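
Two small techniques in this hunk are worth noting: the page-change handler is wrapped in useCallback so the pagination component receives a stable function identity, and lodash's pick forwards exactly the current/pageSize props the component expects. A minimal sketch of that pairing, assuming lodash is available (the hook name and default values are illustrative):

import { pick } from 'lodash';
import { useCallback, useState } from 'react';

// Illustrative pagination state shaped like the one used above.
export function usePagination(initial = { current: 1, pageSize: 50 }) {
  const [pagination, setPagination] = useState(initial);

  // Stable identity: children receiving onChange as a prop won't
  // re-render just because the parent re-rendered.
  const handlePageChange = useCallback(
    (page: number, pageSize?: number) =>
      setPagination((prev) => ({
        current: page,
        pageSize: pageSize ?? prev.pageSize,
      })),
    [],
  );

  // pick() narrows the state object to exactly the props to spread.
  const paginationProps = pick(pagination, 'current', 'pageSize');

  return { pagination, paginationProps, handlePageChange };
}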

@@ -1,9 +1,10 @@
-import { useEffect, useMemo } from 'react';
+import { useCallback, useEffect, useMemo, useState } from 'react';
 import { ControllerRenderProps, useFormContext } from 'react-hook-form';
 
 import { Checkbox } from '@/components/ui/checkbox';
 import { Input } from '@/components/ui/input';
 import { cn } from '@/lib/utils';
+import { debounce } from 'lodash';
 
 /* ---------------- Token Field ---------------- */
 
@@ -48,15 +49,15 @@ type ConfluenceIndexingMode = 'everything' | 'space' | 'page';
 export type ConfluenceIndexingModeFieldProps = ControllerRenderProps;
 
 export const ConfluenceIndexingModeField = (
-  fieldProps: ConfluenceIndexingModeFieldProps,
+  fieldProps: ControllerRenderProps,
 ) => {
   const { value, onChange, disabled } = fieldProps;
+  const [mode, setMode] = useState<ConfluenceIndexingMode>(
+    value || 'everything',
+  );
   const { watch, setValue } = useFormContext();
 
-  const mode = useMemo<ConfluenceIndexingMode>(
-    () => (value as ConfluenceIndexingMode) || 'everything',
-    [value],
-  );
+  useEffect(() => setMode(value), [value]);
 
   const spaceValue = watch('config.space');
   const pageIdValue = watch('config.page_id');
@@ -66,27 +67,40 @@ export const ConfluenceIndexingModeField = (
     if (!value) onChange('everything');
   }, [value, onChange]);
 
-  const handleModeChange = (nextMode?: string) => {
-    const normalized = (nextMode || 'everything') as ConfluenceIndexingMode;
-    onChange(normalized);
+  const handleModeChange = useCallback(
+    (nextMode?: string) => {
+      let normalized: ConfluenceIndexingMode = 'everything';
+      if (nextMode) {
+        normalized = nextMode as ConfluenceIndexingMode;
+        setMode(normalized);
+        onChange(normalized);
+      } else {
+        setMode(mode);
+        normalized = mode;
+        onChange(mode);
+        // onChange(mode);
+      }
+      if (normalized === 'everything') {
+        setValue('config.space', '');
+        setValue('config.page_id', '');
+        setValue('config.index_recursively', false);
+      } else if (normalized === 'space') {
+        setValue('config.page_id', '');
+        setValue('config.index_recursively', false);
+      } else if (normalized === 'page') {
+        setValue('config.space', '');
+      }
+    },
+    [mode, onChange, setValue],
+  );
 
-    if (normalized === 'everything') {
-      setValue('config.space', '', { shouldDirty: true, shouldTouch: true });
-      setValue('config.page_id', '', { shouldDirty: true, shouldTouch: true });
-      setValue('config.index_recursively', false, {
-        shouldDirty: true,
-        shouldTouch: true,
-      });
-    } else if (normalized === 'space') {
-      setValue('config.page_id', '', { shouldDirty: true, shouldTouch: true });
-      setValue('config.index_recursively', false, {
-        shouldDirty: true,
-        shouldTouch: true,
-      });
-    } else if (normalized === 'page') {
-      setValue('config.space', '', { shouldDirty: true, shouldTouch: true });
-    }
-  };
+  const debouncedHandleChange = useMemo(
+    () =>
+      debounce(() => {
+        handleModeChange();
+      }, 300),
+    [handleModeChange],
+  );
 
   return (
     <div className="w-full rounded-lg border border-border-button bg-bg-card p-4 space-y-4">
@@ -127,12 +141,11 @@ export const ConfluenceIndexingModeField = (
             <Input
              className="w-full"
               value={spaceValue ?? ''}
-              onChange={(e) =>
-                setValue('config.space', e.target.value, {
-                  shouldDirty: true,
-                  shouldTouch: true,
-                })
-              }
+              onChange={(e) => {
+                const value = e.target.value;
+                setValue('config.space', value);
+                debouncedHandleChange();
+              }}
               placeholder="e.g. KB"
               disabled={disabled}
             />
@@ -148,12 +161,10 @@ export const ConfluenceIndexingModeField = (
             <Input
               className="w-full"
               value={pageIdValue ?? ''}
-              onChange={(e) =>
-                setValue('config.page_id', e.target.value, {
-                  shouldDirty: true,
-                  shouldTouch: true,
-                })
-              }
+              onChange={(e) => {
+                setValue('config.page_id', e.target.value);
+                debouncedHandleChange();
+              }}
               placeholder="e.g. 123456"
               disabled={disabled}
             />
@@ -164,12 +175,10 @@ export const ConfluenceIndexingModeField = (
             <div className="flex items-center gap-2 pt-2">
               <Checkbox
                 checked={Boolean(indexRecursively)}
-                onCheckedChange={(checked) =>
-                  setValue('config.index_recursively', Boolean(checked), {
-                    shouldDirty: true,
-                    shouldTouch: true,
-                  })
-                }
+                onCheckedChange={(checked) => {
+                  setValue('config.index_recursively', Boolean(checked));
+                  debouncedHandleChange();
+                }}
                 disabled={disabled}
               />
               <span className="text-sm text-text-secondary">
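
The refactor funnels every field edit through a single debounced commit: plain setValue calls update the form immediately, while the mode-dependent clearing in handleModeChange runs once, 300 ms after the last change, through a lodash debounce wrapper memoized with useMemo. A standalone sketch of that memoized-debounce idiom, with a hypothetical useDebouncedCommit name and an unmount cleanup that the diff itself does not show:

import { debounce } from 'lodash';
import { useEffect, useMemo } from 'react';

// commit stands in for handleModeChange above; wait mirrors the 300 ms used there.
export function useDebouncedCommit(commit: () => void, wait = 300) {
  // Memoize so re-renders don't recreate the wrapper and reset its pending
  // timer; the wrapper is rebuilt only when the callback identity changes.
  const debounced = useMemo(() => debounce(commit, wait), [commit, wait]);

  // Assumption beyond the diff: cancel any pending invocation on unmount
  // so it cannot fire against unmounted state.
  useEffect(() => () => debounced.cancel(), [debounced]);

  return debounced;
}

Memoizing on the callback identity is also why the diff wraps handleModeChange in useCallback: without it, every render would produce a new debounced function and drop pending calls.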

@@ -1,6 +1,7 @@
 import { FormFieldType } from '@/components/dynamic-form';
 import SvgIcon from '@/components/svg-icon';
 import { t } from 'i18next';
+import { ControllerRenderProps } from 'react-hook-form';
 import { ConfluenceIndexingModeField } from './component/confluence-token-field';
 import GmailTokenField from './component/gmail-token-field';
 import GoogleDriveTokenField from './component/google-drive-token-field';
@@ -237,7 +238,9 @@ export const DataSourceFormFields = {
       required: false,
       horizontal: true,
       labelClassName: 'self-start pt-4',
-      render: (fieldProps) => <ConfluenceIndexingModeField {...fieldProps} />,
+      render: (fieldProps: ControllerRenderProps) => (
+        <ConfluenceIndexingModeField {...fieldProps} />
+      ),
     },
     {
       label: 'Space Key',
@@ -598,6 +601,7 @@ export const DataSourceFormDefaultValues = {
       confluence_username: '',
       confluence_access_token: '',
     },
+    index_mode: 'everything',
   },
 },
 [DataSourceKey.GOOGLE_DRIVE]: {
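
Typing the render callback as ControllerRenderProps keeps the {...fieldProps} spread type-checked instead of implicitly any. A minimal sketch of a typed field-config entry (the FieldConfig shape here is hypothetical; the repo's actual FormFieldConfig may differ):

import { ReactNode } from 'react';
import { ControllerRenderProps } from 'react-hook-form';

// Hypothetical config shape mirroring the typed render callback above.
type FieldConfig = {
  label: string;
  render?: (fieldProps: ControllerRenderProps) => ReactNode;
};

const indexingModeField: FieldConfig = {
  label: 'Indexing Mode',
  // fieldProps carries value/onChange/onBlur/name/ref from the form
  // Controller, so spreading it wires a custom field into the form.
  render: (fieldProps) => <input {...fieldProps} />,
};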

@@ -136,7 +136,7 @@ const SourceDetailPage = () => {
     ...customFields,
   ] as FormFieldConfig[];
 
-  const neweFields = fields.map((field) => {
+  const newFields = fields.map((field) => {
     return {
       ...field,
       horizontal: true,
@@ -145,7 +145,7 @@ const SourceDetailPage = () => {
       },
     };
   });
-  setFields(neweFields);
+  setFields(newFields);
 
   const defultValueTemp = {
     ...(DataSourceFormDefaultValues[
Some files were not shown because too many files have changed in this diff.