mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-04 03:25:30 +08:00
Compare commits
24 Commits
4ec6a4e493
...
main_backu
| Author | SHA1 | Date | |
|---|---|---|---|
| 427e0540ca | |||
| 8f16fac898 | |||
| ea00478a21 | |||
| 906f19e863 | |||
| 667dc5467e | |||
| 977962fdfe | |||
| 1e9374a373 | |||
| 9b52ba8061 | |||
| 44671ea413 | |||
| c81421d340 | |||
| ee93a80e91 | |||
| 5c981978c1 | |||
| 7fef285af5 | |||
| b1efb905e5 | |||
| 6400bf87ba | |||
| f239bc02d3 | |||
| 5776fa73a7 | |||
| fc6af1998b | |||
| 0588fe79b9 | |||
| f545265f93 | |||
| c987d33649 | |||
| d72debf0db | |||
| c33134ea2c | |||
| 17b8bb62b6 |
29
.github/workflows/tests.yml
vendored
29
.github/workflows/tests.yml
vendored
@ -197,37 +197,38 @@ jobs:
|
||||
echo -e "COMPOSE_PROFILES=\${COMPOSE_PROFILES},tei-cpu" >> docker/.env
|
||||
echo -e "TEI_MODEL=BAAI/bge-small-en-v1.5" >> docker/.env
|
||||
echo -e "RAGFLOW_IMAGE=${RAGFLOW_IMAGE}" >> docker/.env
|
||||
sed -i '1i DOC_ENGINE=infinity' docker/.env
|
||||
echo "HOST_ADDRESS=http://host.docker.internal:${SVR_HTTP_PORT}" >> ${GITHUB_ENV}
|
||||
|
||||
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
|
||||
uv sync --python 3.12 --only-group test --no-default-groups --frozen && uv pip install sdk/python --group test
|
||||
|
||||
- name: Run sdk tests against Elasticsearch
|
||||
- name: Run sdk tests against Infinity
|
||||
run: |
|
||||
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
|
||||
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
|
||||
echo "Waiting for service to be available..."
|
||||
sleep 5
|
||||
done
|
||||
source .venv/bin/activate && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log
|
||||
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log
|
||||
|
||||
- name: Run frontend api tests against Elasticsearch
|
||||
- name: Run frontend api tests against Infinity
|
||||
run: |
|
||||
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
|
||||
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
|
||||
echo "Waiting for service to be available..."
|
||||
sleep 5
|
||||
done
|
||||
source .venv/bin/activate && pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee es_api_test.log
|
||||
|
||||
- name: Run http api tests against Elasticsearch
|
||||
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee infinity_api_test.log
|
||||
|
||||
- name: Run http api tests against Infinity
|
||||
run: |
|
||||
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
|
||||
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
|
||||
echo "Waiting for service to be available..."
|
||||
sleep 5
|
||||
done
|
||||
source .venv/bin/activate && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log
|
||||
source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log
|
||||
|
||||
- name: Stop ragflow:nightly
|
||||
if: always() # always run this step even if previous steps failed
|
||||
@ -237,35 +238,35 @@ jobs:
|
||||
|
||||
- name: Start ragflow:nightly
|
||||
run: |
|
||||
sed -i '1i DOC_ENGINE=infinity' docker/.env
|
||||
sed -i '1i DOC_ENGINE=elasticsearch' docker/.env
|
||||
sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
|
||||
|
||||
- name: Run sdk tests against Infinity
|
||||
- name: Run sdk tests against Elasticsearch
|
||||
run: |
|
||||
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
|
||||
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
|
||||
echo "Waiting for service to be available..."
|
||||
sleep 5
|
||||
done
|
||||
source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log
|
||||
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log
|
||||
|
||||
- name: Run frontend api tests against Infinity
|
||||
- name: Run frontend api tests against Elasticsearch
|
||||
run: |
|
||||
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
|
||||
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
|
||||
echo "Waiting for service to be available..."
|
||||
sleep 5
|
||||
done
|
||||
source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee infinity_api_test.log
|
||||
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee es_api_test.log
|
||||
|
||||
- name: Run http api tests against Infinity
|
||||
- name: Run http api tests against Elasticsearch
|
||||
run: |
|
||||
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
|
||||
until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
|
||||
echo "Waiting for service to be available..."
|
||||
sleep 5
|
||||
done
|
||||
source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log
|
||||
source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log
|
||||
|
||||
- name: Stop ragflow:nightly
|
||||
if: always() # always run this step even if previous steps failed
|
||||
|
||||
@ -192,6 +192,7 @@ COPY pyproject.toml uv.lock ./
|
||||
COPY mcp mcp
|
||||
COPY plugin plugin
|
||||
COPY common common
|
||||
COPY memory memory
|
||||
|
||||
COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template
|
||||
COPY docker/entrypoint.sh ./
|
||||
|
||||
@ -29,6 +29,11 @@ from common.versions import get_ragflow_version
|
||||
admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')
|
||||
|
||||
|
||||
@admin_bp.route('/ping', methods=['GET'])
|
||||
def ping():
|
||||
return success_response('PONG')
|
||||
|
||||
|
||||
@admin_bp.route('/login', methods=['POST'])
|
||||
def login():
|
||||
if not request.json:
|
||||
|
||||
@ -86,8 +86,9 @@ class Agent(LLM, ToolBase):
|
||||
self.tools = {}
|
||||
for idx, cpn in enumerate(self._param.tools):
|
||||
cpn = self._load_tool_obj(cpn)
|
||||
name = cpn.get_meta()["function"]["name"]
|
||||
self.tools[f"{name}_{idx}"] = cpn
|
||||
original_name = cpn.get_meta()["function"]["name"]
|
||||
indexed_name = f"{original_name}_{idx}"
|
||||
self.tools[indexed_name] = cpn
|
||||
|
||||
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id,
|
||||
max_retries=self._param.max_retries,
|
||||
@ -95,7 +96,12 @@ class Agent(LLM, ToolBase):
|
||||
max_rounds=self._param.max_rounds,
|
||||
verbose_tool_use=True
|
||||
)
|
||||
self.tool_meta = [v.get_meta() for _,v in self.tools.items()]
|
||||
self.tool_meta = []
|
||||
for indexed_name, tool_obj in self.tools.items():
|
||||
original_meta = tool_obj.get_meta()
|
||||
indexed_meta = deepcopy(original_meta)
|
||||
indexed_meta["function"]["name"] = indexed_name
|
||||
self.tool_meta.append(indexed_meta)
|
||||
|
||||
for mcp in self._param.mcp:
|
||||
_, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
|
||||
@ -109,7 +115,8 @@ class Agent(LLM, ToolBase):
|
||||
|
||||
def _load_tool_obj(self, cpn: dict) -> object:
|
||||
from agent.component import component_class
|
||||
param = component_class(cpn["component_name"] + "Param")()
|
||||
tool_name = cpn["component_name"]
|
||||
param = component_class(tool_name + "Param")()
|
||||
param.update(cpn["params"])
|
||||
try:
|
||||
param.check()
|
||||
@ -277,19 +284,15 @@ class Agent(LLM, ToolBase):
|
||||
else:
|
||||
user_request = history[-1]["content"]
|
||||
|
||||
def build_task_desc(prompt: str, user_request: str, tool_metas: list[dict], user_defined_prompt: dict | None = None) -> str:
|
||||
def build_task_desc(prompt: str, user_request: str, user_defined_prompt: dict | None = None) -> str:
|
||||
"""Build a minimal task_desc by concatenating prompt, query, and tool schemas."""
|
||||
user_defined_prompt = user_defined_prompt or {}
|
||||
|
||||
tools_json = json.dumps(tool_metas, ensure_ascii=False, indent=2)
|
||||
|
||||
task_desc = (
|
||||
"### Agent Prompt\n"
|
||||
f"{prompt}\n\n"
|
||||
"### User Request\n"
|
||||
f"{user_request}\n\n"
|
||||
"### Tools (schemas)\n"
|
||||
f"{tools_json}\n"
|
||||
)
|
||||
|
||||
if user_defined_prompt:
|
||||
@ -368,7 +371,7 @@ class Agent(LLM, ToolBase):
|
||||
hist.append({"role": "user", "content": content})
|
||||
|
||||
st = timer()
|
||||
task_desc = build_task_desc(prompt, user_request, tool_metas, user_defined_prompt)
|
||||
task_desc = build_task_desc(prompt, user_request, user_defined_prompt)
|
||||
self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
|
||||
for _ in range(self._param.max_rounds + 1):
|
||||
if self.check_if_canceled("Agent streaming"):
|
||||
|
||||
@ -56,7 +56,6 @@ class LLMParam(ComponentParamBase):
|
||||
self.check_nonnegative_number(int(self.max_tokens), "[Agent] Max tokens")
|
||||
self.check_decimal_float(float(self.top_p), "[Agent] Top P")
|
||||
self.check_empty(self.llm_id, "[Agent] LLM")
|
||||
self.check_empty(self.sys_prompt, "[Agent] System prompt")
|
||||
self.check_empty(self.prompts, "[Agent] User prompt")
|
||||
|
||||
def gen_conf(self):
|
||||
|
||||
@ -113,6 +113,10 @@ class LoopItem(ComponentBase, ABC):
|
||||
return len(var) == 0
|
||||
elif operator == "not empty":
|
||||
return len(var) > 0
|
||||
elif var is None:
|
||||
if operator == "empty":
|
||||
return True
|
||||
return False
|
||||
|
||||
raise Exception(f"Invalid operator: {operator}")
|
||||
|
||||
|
||||
@ -25,10 +25,12 @@ from api.db.services.document_service import DocumentService
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.joint_services.memory_message_service import query_message
|
||||
from common import settings
|
||||
from common.connection_utils import timeout
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.generator import cross_languages, kb_prompt
|
||||
from rag.prompts.generator import cross_languages, kb_prompt, memory_prompt
|
||||
|
||||
|
||||
class RetrievalParam(ToolParamBase):
|
||||
@ -57,6 +59,7 @@ class RetrievalParam(ToolParamBase):
|
||||
self.top_n = 8
|
||||
self.top_k = 1024
|
||||
self.kb_ids = []
|
||||
self.memory_ids = []
|
||||
self.kb_vars = []
|
||||
self.rerank_id = ""
|
||||
self.empty_response = ""
|
||||
@ -81,15 +84,7 @@ class RetrievalParam(ToolParamBase):
|
||||
class Retrieval(ToolBase, ABC):
|
||||
component_name = "Retrieval"
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
async def _invoke_async(self, **kwargs):
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
return
|
||||
|
||||
async def _retrieve_kb(self, query_text: str):
|
||||
kb_ids: list[str] = []
|
||||
for id in self._param.kb_ids:
|
||||
if id.find("@") < 0:
|
||||
@ -124,12 +119,12 @@ class Retrieval(ToolBase, ABC):
|
||||
if self._param.rerank_id:
|
||||
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
|
||||
|
||||
vars = self.get_input_elements_from_text(kwargs["query"])
|
||||
vars = {k:o["value"] for k,o in vars.items()}
|
||||
query = self.string_format(kwargs["query"], vars)
|
||||
vars = self.get_input_elements_from_text(query_text)
|
||||
vars = {k: o["value"] for k, o in vars.items()}
|
||||
query = self.string_format(query_text, vars)
|
||||
|
||||
doc_ids=[]
|
||||
if self._param.meta_data_filter!={}:
|
||||
doc_ids = []
|
||||
if self._param.meta_data_filter != {}:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
|
||||
def _resolve_manual_filter(flt: dict) -> dict:
|
||||
@ -198,18 +193,20 @@ class Retrieval(ToolBase, ABC):
|
||||
|
||||
if self._param.toc_enhance:
|
||||
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
|
||||
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
|
||||
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
|
||||
chat_mdl, self._param.top_n)
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
if cks:
|
||||
kbinfos["chunks"] = cks
|
||||
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
|
||||
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
|
||||
[kb.tenant_id for kb in kbs])
|
||||
if self._param.use_kg:
|
||||
ck = settings.kg_retriever.retrieval(query,
|
||||
[kb.tenant_id for kb in kbs],
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
|
||||
[kb.tenant_id for kb in kbs],
|
||||
kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
if ck["content_with_weight"]:
|
||||
@ -218,7 +215,8 @@ class Retrieval(ToolBase, ABC):
|
||||
kbinfos = {"chunks": [], "doc_aggs": []}
|
||||
|
||||
if self._param.use_kg and kbs:
|
||||
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
||||
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
|
||||
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
if ck["content_with_weight"]:
|
||||
@ -248,6 +246,54 @@ class Retrieval(ToolBase, ABC):
|
||||
|
||||
return form_cnt
|
||||
|
||||
async def _retrieve_memory(self, query_text: str):
|
||||
memory_ids: list[str] = [memory_id for memory_id in self._param.memory_ids]
|
||||
memory_list = MemoryService.get_by_ids(memory_ids)
|
||||
if not memory_list:
|
||||
raise Exception("No memory is selected.")
|
||||
|
||||
embd_names = list({memory.embd_id for memory in memory_list})
|
||||
assert len(embd_names) == 1, "Memory use different embedding models."
|
||||
|
||||
vars = self.get_input_elements_from_text(query_text)
|
||||
vars = {k: o["value"] for k, o in vars.items()}
|
||||
query = self.string_format(query_text, vars)
|
||||
# query message
|
||||
message_list = query_message({"memory_id": memory_ids}, {
|
||||
"query": query,
|
||||
"similarity_threshold": self._param.similarity_threshold,
|
||||
"keywords_similarity_weight": self._param.keywords_similarity_weight,
|
||||
"top_n": self._param.top_n
|
||||
})
|
||||
print(f"found {len(message_list)} messages.")
|
||||
|
||||
if not message_list:
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
return
|
||||
formated_content = "\n".join(memory_prompt(message_list, 200000))
|
||||
|
||||
# set formalized_content output
|
||||
self.set_output("formalized_content", formated_content)
|
||||
print(f"formated_content {formated_content}")
|
||||
return formated_content
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
async def _invoke_async(self, **kwargs):
|
||||
if self.check_if_canceled("Retrieval processing"):
|
||||
return
|
||||
print(f"debug retrieval, query is {kwargs.get('query')}.", flush=True)
|
||||
if not kwargs.get("query"):
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
return
|
||||
|
||||
if self._param.kb_ids:
|
||||
return await self._retrieve_kb(kwargs["query"])
|
||||
elif self._param.memory_ids:
|
||||
return await self._retrieve_memory(kwargs["query"])
|
||||
else:
|
||||
self.set_output("formalized_content", self._param.empty_response)
|
||||
return
|
||||
|
||||
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
|
||||
def _invoke(self, **kwargs):
|
||||
return asyncio.run(self._invoke_async(**kwargs))
|
||||
|
||||
@ -192,7 +192,7 @@ async def rerun():
|
||||
if 0 < doc["progress"] < 1:
|
||||
return get_data_error_result(message=f"`{doc['name']}` is processing...")
|
||||
|
||||
if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
|
||||
if settings.docStoreConn.index_exist(search.index_name(current_user.id), doc["kb_id"]):
|
||||
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
|
||||
doc["progress_msg"] = ""
|
||||
doc["chunk_num"] = 0
|
||||
|
||||
@ -447,6 +447,26 @@ async def metadata_update():
|
||||
return get_json_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})
|
||||
|
||||
|
||||
@manager.route("/update_metadata_setting", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_id", "metadata")
|
||||
async def update_metadata_setting():
|
||||
req = await get_request_json()
|
||||
if not DocumentService.accessible(req["doc_id"], current_user.id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
|
||||
e, doc = DocumentService.get_by_id(req["doc_id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
DocumentService.update_parser_config(doc.id, {"metadata": req["metadata"]})
|
||||
e, doc = DocumentService.get_by_id(doc.id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Document not found!")
|
||||
|
||||
return get_json_result(data=doc.to_dict())
|
||||
|
||||
|
||||
@manager.route("/thumbnails", methods=["GET"]) # noqa: F821
|
||||
# @login_required
|
||||
def thumbnails():
|
||||
@ -564,7 +584,7 @@ async def run():
|
||||
DocumentService.update_by_id(id, info)
|
||||
if req.get("delete", False):
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)
|
||||
|
||||
if str(req["run"]) == TaskStatus.RUNNING.value:
|
||||
@ -615,7 +635,7 @@ async def rename():
|
||||
"title_tks": title_tks,
|
||||
"title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
|
||||
}
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.update(
|
||||
{"doc_id": req["doc_id"]},
|
||||
es_body,
|
||||
@ -696,7 +716,7 @@ async def change_parser():
|
||||
tenant_id = DocumentService.get_tenant_id(req["doc_id"])
|
||||
if not tenant_id:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
|
||||
if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
|
||||
settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
|
||||
return None
|
||||
|
||||
|
||||
@ -39,9 +39,9 @@ from api.utils.api_utils import get_json_result
|
||||
from rag.nlp import search
|
||||
from api.constants import DATASET_NAME_LIMIT
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
|
||||
from common import settings
|
||||
from common.doc_store.doc_store_base import OrderByExpr
|
||||
from api.apps import login_required, current_user
|
||||
|
||||
|
||||
@ -285,7 +285,7 @@ async def rm():
|
||||
message="Database error (Knowledgebase removal)!")
|
||||
for kb in kbs:
|
||||
settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
|
||||
settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
|
||||
settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
|
||||
if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
|
||||
settings.STORAGE_IMPL.remove_bucket(kb.id)
|
||||
return get_json_result(data=True)
|
||||
@ -386,7 +386,7 @@ def knowledge_graph(kb_id):
|
||||
}
|
||||
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
|
||||
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
|
||||
return get_json_result(data=obj)
|
||||
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
|
||||
if not len(sres.ids):
|
||||
@ -858,11 +858,11 @@ async def check_embedding():
|
||||
index_nm = search.index_name(tenant_id)
|
||||
|
||||
res0 = docStoreConn.search(
|
||||
selectFields=[], highlightFields=[],
|
||||
select_fields=[], highlight_fields=[],
|
||||
condition={"kb_id": kb_id, "available_int": 1},
|
||||
matchExprs=[], orderBy=OrderByExpr(),
|
||||
match_expressions=[], order_by=OrderByExpr(),
|
||||
offset=0, limit=1,
|
||||
indexNames=index_nm, knowledgebaseIds=[kb_id]
|
||||
index_names=index_nm, knowledgebase_ids=[kb_id]
|
||||
)
|
||||
total = docStoreConn.get_total(res0)
|
||||
if total <= 0:
|
||||
@ -874,14 +874,14 @@ async def check_embedding():
|
||||
|
||||
for off in offsets:
|
||||
res1 = docStoreConn.search(
|
||||
selectFields=list(base_fields),
|
||||
highlightFields=[],
|
||||
select_fields=list(base_fields),
|
||||
highlight_fields=[],
|
||||
condition={"kb_id": kb_id, "available_int": 1},
|
||||
matchExprs=[], orderBy=OrderByExpr(),
|
||||
match_expressions=[], order_by=OrderByExpr(),
|
||||
offset=off, limit=1,
|
||||
indexNames=index_nm, knowledgebaseIds=[kb_id]
|
||||
index_names=index_nm, knowledgebase_ids=[kb_id]
|
||||
)
|
||||
ids = docStoreConn.get_chunk_ids(res1)
|
||||
ids = docStoreConn.get_doc_ids(res1)
|
||||
if not ids:
|
||||
continue
|
||||
|
||||
|
||||
@ -20,10 +20,12 @@ from api.apps import login_required, current_user
|
||||
from api.db import TenantPermission
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
|
||||
not_allowed_parameters
|
||||
from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
|
||||
from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
|
||||
from memory.services.messages import MessageService
|
||||
from common.constants import MemoryType, RetCode, ForgettingPolicy
|
||||
|
||||
|
||||
@ -57,7 +59,6 @@ async def create_memory():
|
||||
|
||||
if res:
|
||||
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
|
||||
|
||||
else:
|
||||
return get_json_result(message=memory, code=RetCode.SERVER_ERROR)
|
||||
|
||||
@ -124,7 +125,7 @@ async def update_memory(memory_id):
|
||||
return get_json_result(message=True, data=memory_dict)
|
||||
|
||||
try:
|
||||
MemoryService.update_memory(memory_id, to_update)
|
||||
MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update)
|
||||
updated_memory = MemoryService.get_by_memory_id(memory_id)
|
||||
return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory))
|
||||
|
||||
@ -133,7 +134,7 @@ async def update_memory(memory_id):
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
|
||||
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
async def delete_memory(memory_id):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
@ -141,13 +142,14 @@ async def delete_memory(memory_id):
|
||||
return get_json_result(message=True, code=RetCode.NOT_FOUND)
|
||||
try:
|
||||
MemoryService.delete_memory(memory_id)
|
||||
MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
|
||||
return get_json_result(message=True)
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("", methods=["GET"]) # noqa: F821
|
||||
@manager.route("", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def list_memory():
|
||||
args = request.args
|
||||
@ -183,3 +185,26 @@ async def get_memory_config(memory_id):
|
||||
if not memory:
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
|
||||
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
|
||||
|
||||
|
||||
@manager.route("/<memory_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_memory_detail(memory_id):
|
||||
args = request.args
|
||||
agent_ids = args.getlist("agent_id")
|
||||
keywords = args.get("keywords", "")
|
||||
keywords = keywords.strip()
|
||||
page = int(args.get("page", 1))
|
||||
page_size = int(args.get("page_size", 50))
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
|
||||
messages = MessageService.list_message(
|
||||
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
|
||||
agent_name_mapping = {}
|
||||
if messages["message_list"]:
|
||||
agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
|
||||
agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
|
||||
for message in messages["message_list"]:
|
||||
message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
|
||||
return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True)
|
||||
|
||||
169
api/apps/messages_app.py
Normal file
169
api/apps/messages_app.py
Normal file
@ -0,0 +1,169 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from quart import request
|
||||
from api.apps import login_required
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from common.time_utils import current_timestamp, timestamp_to_date
|
||||
|
||||
from memory.services.messages import MessageService
|
||||
from api.db.joint_services import memory_message_service
|
||||
from api.db.joint_services.memory_message_service import query_message
|
||||
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
|
||||
from common.constants import RetCode
|
||||
|
||||
|
||||
@manager.route("", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response")
|
||||
async def add_message():
|
||||
|
||||
req = await get_request_json()
|
||||
memory_ids = req["memory_id"]
|
||||
agent_id = req["agent_id"]
|
||||
session_id = req["session_id"]
|
||||
user_id = req["user_id"] if req.get("user_id") else ""
|
||||
user_input = req["user_input"]
|
||||
agent_response = req["agent_response"]
|
||||
|
||||
res = []
|
||||
for memory_id in memory_ids:
|
||||
success, msg = await memory_message_service.save_to_memory(
|
||||
memory_id,
|
||||
{
|
||||
"user_id": user_id,
|
||||
"agent_id": agent_id,
|
||||
"session_id": session_id,
|
||||
"user_input": user_input,
|
||||
"agent_response": agent_response
|
||||
}
|
||||
)
|
||||
res.append({
|
||||
"memory_id": memory_id,
|
||||
"success": success,
|
||||
"message": msg
|
||||
})
|
||||
|
||||
if all([r["success"] for r in res]):
|
||||
return get_json_result(message="Successfully added to memories.")
|
||||
|
||||
return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add.", data=res)
|
||||
|
||||
|
||||
@manager.route("/<memory_id>:<message_id>", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
async def forget_message(memory_id: str, message_id: int):
|
||||
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
|
||||
|
||||
forget_time = timestamp_to_date(current_timestamp())
|
||||
update_succeed = MessageService.update_message(
|
||||
{"memory_id": memory_id, "message_id": int(message_id)},
|
||||
{"forget_at": forget_time},
|
||||
memory.tenant_id, memory_id)
|
||||
if update_succeed:
|
||||
return get_json_result(message=update_succeed)
|
||||
else:
|
||||
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.")
|
||||
|
||||
|
||||
@manager.route("/<memory_id>:<message_id>", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("status")
|
||||
async def update_message(memory_id: str, message_id: int):
|
||||
req = await get_request_json()
|
||||
status = req["status"]
|
||||
if not isinstance(status, bool):
|
||||
return get_error_argument_result("Status must be a boolean.")
|
||||
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
|
||||
|
||||
update_succeed = MessageService.update_message({"memory_id": memory_id, "message_id": int(message_id)}, {"status": status}, memory.tenant_id, memory_id)
|
||||
if update_succeed:
|
||||
return get_json_result(message=update_succeed)
|
||||
else:
|
||||
return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.")
|
||||
|
||||
|
||||
@manager.route("/search", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def search_message():
|
||||
args = request.args
|
||||
print(args, flush=True)
|
||||
empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)]
|
||||
if empty_fields:
|
||||
return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.")
|
||||
|
||||
memory_ids = args.getlist("memory_id")
|
||||
query = args.get("query")
|
||||
similarity_threshold = float(args.get("similarity_threshold", 0.2))
|
||||
keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7))
|
||||
top_n = int(args.get("top_n", 5))
|
||||
agent_id = args.get("agent_id", "")
|
||||
session_id = args.get("session_id", "")
|
||||
|
||||
filter_dict = {
|
||||
"memory_id": memory_ids,
|
||||
"agent_id": agent_id,
|
||||
"session_id": session_id
|
||||
}
|
||||
params = {
|
||||
"query": query,
|
||||
"similarity_threshold": similarity_threshold,
|
||||
"keywords_similarity_weight": keywords_similarity_weight,
|
||||
"top_n": top_n
|
||||
}
|
||||
res = query_message(filter_dict, params)
|
||||
return get_json_result(message=True, data=res)
|
||||
|
||||
|
||||
@manager.route("", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_messages():
|
||||
args = request.args
|
||||
memory_ids = args.getlist("memory_id")
|
||||
agent_id = args.get("agent_id", "")
|
||||
session_id = args.get("session_id", "")
|
||||
limit = int(args.get("limit", 10))
|
||||
if not memory_ids:
|
||||
return get_error_argument_result("memory_ids is required.")
|
||||
memory_list = MemoryService.get_by_ids(memory_ids)
|
||||
uids = [memory.tenant_id for memory in memory_list]
|
||||
res = MessageService.get_recent_messages(
|
||||
uids,
|
||||
memory_ids,
|
||||
agent_id,
|
||||
session_id,
|
||||
limit
|
||||
)
|
||||
return get_json_result(message=True, data=res)
|
||||
|
||||
|
||||
@manager.route("/<memory_id>:<message_id>/content", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_message_content(memory_id:str, message_id: int):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
|
||||
|
||||
res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id)
|
||||
if res:
|
||||
return get_json_result(message=True, data=res)
|
||||
else:
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=f"Message '{message_id}' in memory '{memory_id}' not found.")
|
||||
@ -495,7 +495,7 @@ def knowledge_graph(tenant_id, dataset_id):
|
||||
}
|
||||
|
||||
obj = {"graph": {}, "mind_map": {}}
|
||||
if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
|
||||
if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
|
||||
return get_result(data=obj)
|
||||
sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
|
||||
if not len(sres.ids):
|
||||
|
||||
@ -1080,7 +1080,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
|
||||
res["chunks"].append(final_chunk)
|
||||
_ = Chunk(**final_chunk)
|
||||
|
||||
elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
|
||||
elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
|
||||
sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
|
||||
res["total"] = sres.total
|
||||
for id in sres.ids:
|
||||
|
||||
@ -30,6 +30,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
|
||||
from common.constants import LLMType
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common import settings
|
||||
@ -169,6 +170,8 @@ def init_web_data():
|
||||
# init_superuser()
|
||||
|
||||
add_graph_templates()
|
||||
init_message_id_sequence()
|
||||
init_memory_size_cache()
|
||||
logging.info("init web data success:{}".format(time.time() - start_time))
|
||||
|
||||
|
||||
|
||||
233
api/db/joint_services/memory_message_service.py
Normal file
233
api/db/joint_services/memory_message_service.py
Normal file
@ -0,0 +1,233 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
|
||||
from common.constants import MemoryType, LLMType
|
||||
from common.doc_store.doc_store_base import FusionExpr
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.utils.memory_utils import get_memory_type_human
|
||||
from memory.services.messages import MessageService
|
||||
from memory.services.query import MsgTextQuery, get_vector
|
||||
from memory.utils.prompt_util import PromptAssembler
|
||||
from memory.utils.msg_util import get_json_result_from_llm_response
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
async def save_to_memory(memory_id: str, message_dict: dict):
|
||||
"""
|
||||
:param memory_id:
|
||||
:param message_dict: {
|
||||
"user_id": str,
|
||||
"agent_id": str,
|
||||
"session_id": str,
|
||||
"user_input": str,
|
||||
"agent_response": str
|
||||
}
|
||||
"""
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
return False, f"Memory '{memory_id}' not found."
|
||||
|
||||
tenant_id = memory.tenant_id
|
||||
extracted_content = await extract_by_llm(
|
||||
tenant_id,
|
||||
memory.llm_id,
|
||||
{"temperature": memory.temperature},
|
||||
get_memory_type_human(memory.memory_type),
|
||||
message_dict.get("user_input", ""),
|
||||
message_dict.get("agent_response", "")
|
||||
) if memory.memory_type != MemoryType.RAW.value else [] # if only RAW, no need to extract
|
||||
raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
|
||||
message_list = [{
|
||||
"message_id": raw_message_id,
|
||||
"message_type": MemoryType.RAW.name.lower(),
|
||||
"source_id": 0,
|
||||
"memory_id": memory_id,
|
||||
"user_id": "",
|
||||
"agent_id": message_dict["agent_id"],
|
||||
"session_id": message_dict["session_id"],
|
||||
"content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
|
||||
"valid_at": timestamp_to_date(current_timestamp()),
|
||||
"invalid_at": None,
|
||||
"forget_at": None,
|
||||
"status": True
|
||||
}, *[{
|
||||
"message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
|
||||
"message_type": content["message_type"],
|
||||
"source_id": raw_message_id,
|
||||
"memory_id": memory_id,
|
||||
"user_id": "",
|
||||
"agent_id": message_dict["agent_id"],
|
||||
"session_id": message_dict["session_id"],
|
||||
"content": content["content"],
|
||||
"valid_at": content["valid_at"],
|
||||
"invalid_at": content["invalid_at"] if content["invalid_at"] else None,
|
||||
"forget_at": None,
|
||||
"status": True
|
||||
} for content in extracted_content]]
|
||||
embedding_model = LLMBundle(tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
|
||||
vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
|
||||
for idx, msg in enumerate(message_list):
|
||||
msg["content_embed"] = vector_list[idx]
|
||||
vector_dimension = len(vector_list[0])
|
||||
if not MessageService.has_index(tenant_id, memory_id):
|
||||
created = MessageService.create_index(tenant_id, memory_id, vector_size=vector_dimension)
|
||||
if not created:
|
||||
return False, "Failed to create message index."
|
||||
|
||||
new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
|
||||
current_memory_size = get_memory_size_cache(memory_id, tenant_id)
|
||||
if new_msg_size + current_memory_size > memory.memory_size:
|
||||
size_to_delete = current_memory_size + new_msg_size - memory.memory_size
|
||||
if memory.forgetting_policy == "fifo":
|
||||
message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory_id, tenant_id, size_to_delete)
|
||||
MessageService.delete_message({"message_id": message_ids_to_delete}, tenant_id, memory_id)
|
||||
decrease_memory_size_cache(memory_id, tenant_id, delete_size)
|
||||
else:
|
||||
return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
|
||||
fail_cases = MessageService.insert_message(message_list, tenant_id, memory_id)
|
||||
if fail_cases:
|
||||
return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)
|
||||
|
||||
increase_memory_size_cache(memory_id, tenant_id, new_msg_size)
|
||||
return True, "Message saved successfully."
|
||||
|
||||
|
||||
async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
|
||||
agent_response: str, system_prompt: str = "", user_prompt: str="") -> List[dict]:
|
||||
llm_type = TenantLLMService.llm_id2llm_type(llm_id)
|
||||
if not llm_type:
|
||||
raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
|
||||
if not system_prompt:
|
||||
system_prompt = PromptAssembler.assemble_system_prompt({"memory_type": memory_type})
|
||||
conversation_content = f"User Input: {user_input}\nAgent Response: {agent_response}"
|
||||
conversation_time = timestamp_to_date(current_timestamp())
|
||||
user_prompts = []
|
||||
if user_prompt:
|
||||
user_prompts.append({"role": "user", "content": user_prompt})
|
||||
user_prompts.append({"role": "user", "content": f"Conversation: {conversation_content}\nConversation Time: {conversation_time}\nCurrent Time: {conversation_time}"})
|
||||
else:
|
||||
user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
|
||||
llm = LLMBundle(tenant_id, llm_type, llm_id)
|
||||
res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
|
||||
res_json = get_json_result_from_llm_response(res)
|
||||
return [{
|
||||
"content": extracted_content["content"],
|
||||
"valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]),
|
||||
"invalid_at": format_iso_8601_to_ymd_hms(extracted_content["invalid_at"]) if extracted_content.get("invalid_at") else "",
|
||||
"message_type": message_type
|
||||
} for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
|
||||
|
||||
|
||||
def query_message(filter_dict: dict, params: dict):
|
||||
"""
|
||||
:param filter_dict: {
|
||||
"memory_id": List[str],
|
||||
"agent_id": optional
|
||||
"session_id": optional
|
||||
}
|
||||
:param params: {
|
||||
"query": question str,
|
||||
"similarity_threshold": float,
|
||||
"keywords_similarity_weight": float,
|
||||
"top_n": int
|
||||
}
|
||||
"""
|
||||
memory_ids = filter_dict["memory_id"]
|
||||
memory_list = MemoryService.get_by_ids(memory_ids)
|
||||
if not memory_list:
|
||||
return []
|
||||
|
||||
condition_dict = {k: v for k, v in filter_dict.items() if v}
|
||||
uids = [memory.tenant_id for memory in memory_list]
|
||||
|
||||
question = params["query"]
|
||||
question = question.strip()
|
||||
memory = memory_list[0]
|
||||
embd_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
|
||||
match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"])
|
||||
match_text, _ = MsgTextQuery().question(question, min_match=0.3)
|
||||
keywords_similarity_weight = params.get("keywords_similarity_weight", 0.7)
|
||||
fusion_expr = FusionExpr("weighted_sum", params["top_n"], {"weights": ",".join([str(keywords_similarity_weight), str(1 - keywords_similarity_weight)])})
|
||||
|
||||
return MessageService.search_message(memory_ids, condition_dict, uids, [match_text, match_dense, fusion_expr], params["top_n"])
|
||||
|
||||
|
||||
def init_message_id_sequence():
|
||||
message_id_redis_key = "id_generator:memory"
|
||||
if REDIS_CONN.exist(message_id_redis_key):
|
||||
current_max_id = REDIS_CONN.get(message_id_redis_key)
|
||||
logging.info(f"No need to init message_id sequence, current max id is {current_max_id}.")
|
||||
else:
|
||||
max_id = 1
|
||||
exist_memory_list = MemoryService.get_all_memory()
|
||||
if not exist_memory_list:
|
||||
REDIS_CONN.set(message_id_redis_key, max_id)
|
||||
else:
|
||||
max_id = MessageService.get_max_message_id(
|
||||
uid_list=[m.tenant_id for m in exist_memory_list],
|
||||
memory_ids=[m.id for m in exist_memory_list]
|
||||
)
|
||||
REDIS_CONN.set(message_id_redis_key, max_id)
|
||||
logging.info(f"Init message_id sequence done, current max id is {max_id}.")
|
||||
|
||||
|
||||
def get_memory_size_cache(memory_id: str, uid: str):
|
||||
redis_key = f"memory_{memory_id}"
|
||||
if REDIS_CONN.exists(redis_key):
|
||||
return REDIS_CONN.get(redis_key)
|
||||
else:
|
||||
memory_size_map = MessageService.calculate_memory_size(
|
||||
[memory_id],
|
||||
[uid]
|
||||
)
|
||||
memory_size = memory_size_map.get(memory_id, 0)
|
||||
set_memory_size_cache(memory_id, memory_size)
|
||||
return memory_size
|
||||
|
||||
|
||||
def set_memory_size_cache(memory_id: str, size: int):
|
||||
redis_key = f"memory_{memory_id}"
|
||||
return REDIS_CONN.set(redis_key, size)
|
||||
|
||||
|
||||
def increase_memory_size_cache(memory_id: str, uid: str, size: int):
|
||||
current_value = get_memory_size_cache(memory_id, uid)
|
||||
return set_memory_size_cache(memory_id, current_value + size)
|
||||
|
||||
|
||||
def decrease_memory_size_cache(memory_id: str, uid: str, size: int):
|
||||
current_value = get_memory_size_cache(memory_id, uid)
|
||||
return set_memory_size_cache(memory_id, max(current_value - size, 0))
|
||||
|
||||
|
||||
def init_memory_size_cache():
|
||||
memory_list = MemoryService.get_all_memory()
|
||||
if not memory_list:
|
||||
logging.info("No memory found, no need to init memory size.")
|
||||
else:
|
||||
memory_size_map = MessageService.calculate_memory_size(
|
||||
memory_ids=[m.id for m in memory_list],
|
||||
uid_list=[m.tenant_id for m in memory_list],
|
||||
)
|
||||
for memory in memory_list:
|
||||
memory_size = memory_size_map.get(memory.id, 0)
|
||||
set_memory_size_cache(memory.id, memory_size)
|
||||
logging.info("Memory size cache init done.")
|
||||
@ -34,6 +34,8 @@ from api.db.services.task_service import TaskService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from api.db.services.user_service import TenantService, UserService, UserTenantService
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from memory.services.messages import MessageService
|
||||
from rag.nlp import search
|
||||
from common.constants import ActiveEnum
|
||||
from common import settings
|
||||
@ -200,7 +202,16 @@ def delete_user_data(user_id: str) -> dict:
|
||||
done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
|
||||
langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
|
||||
done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
|
||||
# step1.3 delete own tenant
|
||||
# step1.3 delete memory and messages
|
||||
user_memory = MemoryService.get_by_tenant_id(tenant_id)
|
||||
if user_memory:
|
||||
for memory in user_memory:
|
||||
if MessageService.has_index(tenant_id, memory.id):
|
||||
MessageService.delete_index(tenant_id, memory.id)
|
||||
done_msg += " Deleted memory index."
|
||||
memory_delete_res = MemoryService.delete_by_ids([m.id for m in user_memory])
|
||||
done_msg += f"Deleted {memory_delete_res} memory datasets."
|
||||
# step1.4 delete own tenant
|
||||
tenant_delete_res = TenantService.delete_by_id(tenant_id)
|
||||
done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
|
||||
# step2 delete user-tenant relation
|
||||
|
||||
@ -123,6 +123,19 @@ class UserCanvasService(CommonService):
|
||||
logging.exception(e)
|
||||
return False, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_basic_info_by_canvas_ids(cls, canvas_id):
|
||||
fields = [
|
||||
cls.model.id,
|
||||
cls.model.avatar,
|
||||
cls.model.user_id,
|
||||
cls.model.title,
|
||||
cls.model.permission,
|
||||
cls.model.canvas_category
|
||||
]
|
||||
return cls.model.select(*fields).where(cls.model.id.in_(canvas_id)).dicts()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
||||
|
||||
@ -33,12 +33,13 @@ from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTena
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from common.metadata_utils import dedupe_list
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import current_timestamp, get_format_time
|
||||
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from common.doc_store.doc_store_base import OrderByExpr
|
||||
from common import settings
|
||||
|
||||
|
||||
@ -345,7 +346,7 @@ class DocumentService(CommonService):
|
||||
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
|
||||
page * page_size, page_size, search.index_name(tenant_id),
|
||||
[doc.kb_id])
|
||||
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
|
||||
chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
|
||||
if not chunk_ids:
|
||||
break
|
||||
all_chunk_ids.extend(chunk_ids)
|
||||
@ -696,10 +697,14 @@ class DocumentService(CommonService):
|
||||
for k,v in r.meta_fields.items():
|
||||
if k not in meta:
|
||||
meta[k] = {}
|
||||
v = str(v)
|
||||
if v not in meta[k]:
|
||||
meta[k][v] = []
|
||||
meta[k][v].append(doc_id)
|
||||
if not isinstance(v, list):
|
||||
v = [v]
|
||||
for vv in v:
|
||||
if vv not in meta[k]:
|
||||
if isinstance(vv, list) or isinstance(vv, dict):
|
||||
continue
|
||||
meta[k][vv] = []
|
||||
meta[k][vv].append(doc_id)
|
||||
return meta
|
||||
|
||||
@classmethod
|
||||
@ -797,7 +802,10 @@ class DocumentService(CommonService):
|
||||
match_provided = "match" in upd
|
||||
if isinstance(meta[key], list):
|
||||
if not match_provided:
|
||||
meta[key] = new_value
|
||||
if isinstance(new_value, list):
|
||||
meta[key] = dedupe_list(new_value)
|
||||
else:
|
||||
meta[key] = new_value
|
||||
changed = True
|
||||
else:
|
||||
match_value = upd.get("match")
|
||||
@ -810,7 +818,7 @@ class DocumentService(CommonService):
|
||||
else:
|
||||
new_list.append(item)
|
||||
if replaced:
|
||||
meta[key] = new_list
|
||||
meta[key] = dedupe_list(new_list)
|
||||
changed = True
|
||||
else:
|
||||
if not match_provided:
|
||||
@ -1230,8 +1238,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
d["q_%d_vec" % len(v)] = v
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
if try_create_idx:
|
||||
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0]))
|
||||
if not settings.docStoreConn.index_exist(idxnm, kb_id):
|
||||
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]))
|
||||
try_create_idx = False
|
||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
|
||||
@ -411,8 +411,6 @@ class KnowledgebaseService(CommonService):
|
||||
ok, _t = TenantService.get_by_id(tenant_id)
|
||||
if not ok:
|
||||
return False, get_data_error_result(message="Tenant not found.")
|
||||
if kwargs.get("parser_config") and isinstance(kwargs["parser_config"], dict) and not kwargs["parser_config"].get("llm_id"):
|
||||
kwargs["parser_config"]["llm_id"] = _t.llm_id
|
||||
|
||||
# Build payload
|
||||
kb_id = get_uuid()
|
||||
@ -427,6 +425,7 @@ class KnowledgebaseService(CommonService):
|
||||
|
||||
# Update parser_config (always override with validated default/merged config)
|
||||
payload["parser_config"] = get_parser_config(parser_id, kwargs.get("parser_config"))
|
||||
payload["parser_config"]["llm_id"] = _t.llm_id
|
||||
|
||||
return True, payload
|
||||
|
||||
|
||||
@ -15,7 +15,6 @@
|
||||
#
|
||||
from typing import List
|
||||
|
||||
from api.apps import current_user
|
||||
from api.db.db_models import DB, Memory, User
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.common_service import CommonService
|
||||
@ -23,6 +22,7 @@ from api.utils.memory_utils import calculate_memory_type
|
||||
from api.constants import MEMORY_NAME_LIMIT
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import get_format_time, current_timestamp
|
||||
from memory.utils.prompt_util import PromptAssembler
|
||||
|
||||
|
||||
class MemoryService(CommonService):
|
||||
@ -34,6 +34,17 @@ class MemoryService(CommonService):
|
||||
def get_by_memory_id(cls, memory_id: str):
|
||||
return cls.model.select().where(cls.model.id == memory_id).first()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_id(cls, tenant_id: str):
|
||||
return cls.model.select().where(cls.model.tenant_id == tenant_id)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_memory(cls):
|
||||
memory_list = cls.model.select()
|
||||
return list(memory_list)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_with_owner_name_by_id(cls, memory_id: str):
|
||||
@ -53,7 +64,9 @@ class MemoryService(CommonService):
|
||||
cls.model.forgetting_policy,
|
||||
cls.model.temperature,
|
||||
cls.model.system_prompt,
|
||||
cls.model.user_prompt
|
||||
cls.model.user_prompt,
|
||||
cls.model.create_date,
|
||||
cls.model.create_time
|
||||
]
|
||||
memory = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
|
||||
cls.model.id == memory_id
|
||||
@ -72,7 +85,9 @@ class MemoryService(CommonService):
|
||||
cls.model.memory_type,
|
||||
cls.model.storage_type,
|
||||
cls.model.permissions,
|
||||
cls.model.description
|
||||
cls.model.description,
|
||||
cls.model.create_time,
|
||||
cls.model.create_date
|
||||
]
|
||||
memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))
|
||||
if filter_dict.get("tenant_id"):
|
||||
@ -102,6 +117,8 @@ class MemoryService(CommonService):
|
||||
if len(memory_name) > MEMORY_NAME_LIMIT:
|
||||
return False, f"Memory name {memory_name} exceeds limit of {MEMORY_NAME_LIMIT}."
|
||||
|
||||
timestamp = current_timestamp()
|
||||
format_time = get_format_time()
|
||||
# build create dict
|
||||
memory_info = {
|
||||
"id": get_uuid(),
|
||||
@ -110,10 +127,11 @@ class MemoryService(CommonService):
|
||||
"tenant_id": tenant_id,
|
||||
"embd_id": embd_id,
|
||||
"llm_id": llm_id,
|
||||
"create_time": current_timestamp(),
|
||||
"create_date": get_format_time(),
|
||||
"update_time": current_timestamp(),
|
||||
"update_date": get_format_time(),
|
||||
"system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}),
|
||||
"create_time": timestamp,
|
||||
"create_date": format_time,
|
||||
"update_time": timestamp,
|
||||
"update_date": format_time,
|
||||
}
|
||||
obj = cls.model(**memory_info).save(force_insert=True)
|
||||
|
||||
@ -126,7 +144,7 @@ class MemoryService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_memory(cls, memory_id: str, update_dict: dict):
|
||||
def update_memory(cls, tenant_id: str, memory_id: str, update_dict: dict):
|
||||
if not update_dict:
|
||||
return 0
|
||||
if "temperature" in update_dict and isinstance(update_dict["temperature"], str):
|
||||
@ -135,7 +153,7 @@ class MemoryService(CommonService):
|
||||
update_dict["name"] = duplicate_name(
|
||||
cls.query,
|
||||
name=update_dict["name"],
|
||||
tenant_id=current_user.id
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
update_dict.update({
|
||||
"update_time": current_timestamp(),
|
||||
|
||||
0
common/doc_store/__init__.py
Normal file
0
common/doc_store/__init__.py
Normal file
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
@ -22,7 +21,6 @@ DEFAULT_MATCH_VECTOR_TOPN = 10
|
||||
DEFAULT_MATCH_SPARSE_TOPN = 10
|
||||
VEC = list | np.ndarray
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseVector:
|
||||
indices: list[int]
|
||||
@ -55,14 +53,13 @@ class SparseVector:
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
class MatchTextExpr(ABC):
|
||||
class MatchTextExpr:
|
||||
def __init__(
|
||||
self,
|
||||
fields: list[str],
|
||||
matching_text: str,
|
||||
topn: int,
|
||||
extra_options: dict = dict(),
|
||||
extra_options: dict | None = None,
|
||||
):
|
||||
self.fields = fields
|
||||
self.matching_text = matching_text
|
||||
@ -70,7 +67,7 @@ class MatchTextExpr(ABC):
|
||||
self.extra_options = extra_options
|
||||
|
||||
|
||||
class MatchDenseExpr(ABC):
|
||||
class MatchDenseExpr:
|
||||
def __init__(
|
||||
self,
|
||||
vector_column_name: str,
|
||||
@ -78,7 +75,7 @@ class MatchDenseExpr(ABC):
|
||||
embedding_data_type: str,
|
||||
distance_type: str,
|
||||
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
|
||||
extra_options: dict = dict(),
|
||||
extra_options: dict | None = None,
|
||||
):
|
||||
self.vector_column_name = vector_column_name
|
||||
self.embedding_data = embedding_data
|
||||
@ -88,7 +85,7 @@ class MatchDenseExpr(ABC):
|
||||
self.extra_options = extra_options
|
||||
|
||||
|
||||
class MatchSparseExpr(ABC):
|
||||
class MatchSparseExpr:
|
||||
def __init__(
|
||||
self,
|
||||
vector_column_name: str,
|
||||
@ -104,7 +101,7 @@ class MatchSparseExpr(ABC):
|
||||
self.opt_params = opt_params
|
||||
|
||||
|
||||
class MatchTensorExpr(ABC):
|
||||
class MatchTensorExpr:
|
||||
def __init__(
|
||||
self,
|
||||
column_name: str,
|
||||
@ -120,7 +117,7 @@ class MatchTensorExpr(ABC):
|
||||
self.extra_option = extra_option
|
||||
|
||||
|
||||
class FusionExpr(ABC):
|
||||
class FusionExpr:
|
||||
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
|
||||
self.method = method
|
||||
self.topn = topn
|
||||
@ -129,7 +126,8 @@ class FusionExpr(ABC):
|
||||
|
||||
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
|
||||
|
||||
class OrderByExpr(ABC):
|
||||
|
||||
class OrderByExpr:
|
||||
def __init__(self):
|
||||
self.fields = list()
|
||||
def asc(self, field: str):
|
||||
@ -141,13 +139,14 @@ class OrderByExpr(ABC):
|
||||
def fields(self):
|
||||
return self.fields
|
||||
|
||||
|
||||
class DocStoreConnection(ABC):
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def dbType(self) -> str:
|
||||
def db_type(self) -> str:
|
||||
"""
|
||||
Return the type of the database.
|
||||
"""
|
||||
@ -165,21 +164,21 @@ class DocStoreConnection(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
|
||||
"""
|
||||
Create an index with given name
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
def delete_idx(self, index_name: str, dataset_id: str):
|
||||
"""
|
||||
Delete an index with given name
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
||||
def index_exist(self, index_name: str, dataset_id: str) -> bool:
|
||||
"""
|
||||
Check if an index with given name exists
|
||||
"""
|
||||
@ -191,16 +190,16 @@ class DocStoreConnection(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, selectFields: list[str],
|
||||
highlightFields: list[str],
|
||||
self, select_fields: list[str],
|
||||
highlight_fields: list[str],
|
||||
condition: dict,
|
||||
matchExprs: list[MatchExpr],
|
||||
orderBy: OrderByExpr,
|
||||
match_expressions: list[MatchExpr],
|
||||
order_by: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
indexNames: str|list[str],
|
||||
knowledgebaseIds: list[str],
|
||||
aggFields: list[str] = [],
|
||||
index_names: str|list[str],
|
||||
dataset_ids: list[str],
|
||||
agg_fields: list[str] | None = None,
|
||||
rank_feature: dict | None = None
|
||||
):
|
||||
"""
|
||||
@ -209,28 +208,28 @@ class DocStoreConnection(ABC):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
|
||||
def get(self, data_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
|
||||
"""
|
||||
Get single chunk with given id
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
|
||||
def insert(self, rows: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
|
||||
"""
|
||||
Update or insert a bulk of rows
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
|
||||
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
|
||||
"""
|
||||
Update rows with given conjunctive equivalent filtering condition
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
||||
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
|
||||
"""
|
||||
Delete rows with given conjunctive equivalent filtering condition
|
||||
"""
|
||||
@ -245,7 +244,7 @@ class DocStoreConnection(ABC):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get_chunk_ids(self, res):
|
||||
def get_doc_ids(self, res):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
@ -253,18 +252,18 @@ class DocStoreConnection(ABC):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get_highlight(self, res, keywords: list[str], fieldnm: str):
|
||||
def get_highlight(self, res, keywords: list[str], field_name: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get_aggregation(self, res, fieldnm: str):
|
||||
def get_aggregation(self, res, field_name: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
@abstractmethod
|
||||
def sql(sql: str, fetch_size: int, format: str):
|
||||
def sql(self, sql: str, fetch_size: int, format: str):
|
||||
"""
|
||||
Run the sql generated by text-to-sql
|
||||
"""
|
||||
326
common/doc_store/es_conn_base.py
Normal file
326
common/doc_store/es_conn_base.py
Normal file
@ -0,0 +1,326 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
|
||||
from elasticsearch import Elasticsearch, NotFoundError
|
||||
from elasticsearch_dsl import Index
|
||||
from elastic_transport import ConnectionTimeout
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common.misc_utils import convert_bytes
|
||||
from common.doc_store.doc_store_base import DocStoreConnection, OrderByExpr, MatchExpr
|
||||
from rag.nlp import is_english, rag_tokenizer
|
||||
from common import settings
|
||||
|
||||
ATTEMPT_TIME = 2
|
||||
|
||||
|
||||
class ESConnectionBase(DocStoreConnection):
|
||||
def __init__(self, mapping_file_name: str="mapping.json", logger_name: str='ragflow.es_conn'):
|
||||
self.logger = logging.getLogger(logger_name)
|
||||
|
||||
self.info = {}
|
||||
self.logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
if self._connect():
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
|
||||
time.sleep(5)
|
||||
|
||||
if not self.es.ping():
|
||||
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
|
||||
self.logger.error(msg)
|
||||
raise Exception(msg)
|
||||
v = self.info.get("version", {"number": "8.11.3"})
|
||||
v = v["number"].split(".")[0]
|
||||
if int(v) < 8:
|
||||
msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
|
||||
self.logger.error(msg)
|
||||
raise Exception(msg)
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", mapping_file_name)
|
||||
if not os.path.exists(fp_mapping):
|
||||
msg = f"Elasticsearch mapping file not found at {fp_mapping}"
|
||||
self.logger.error(msg)
|
||||
raise Exception(msg)
|
||||
self.mapping = json.load(open(fp_mapping, "r"))
|
||||
self.logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
|
||||
|
||||
def _connect(self):
|
||||
self.es = Elasticsearch(
|
||||
settings.ES["hosts"].split(","),
|
||||
basic_auth=(settings.ES["username"], settings.ES[
|
||||
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
|
||||
verify_certs= settings.ES.get("verify_certs", False),
|
||||
timeout=600 )
|
||||
if self.es:
|
||||
self.info = self.es.info()
|
||||
return True
|
||||
return False
|
||||
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
def db_type(self) -> str:
|
||||
return "elasticsearch"
|
||||
|
||||
def health(self) -> dict:
|
||||
health_dict = dict(self.es.cluster.health())
|
||||
health_dict["type"] = "elasticsearch"
|
||||
return health_dict
|
||||
|
||||
def get_cluster_stats(self):
|
||||
"""
|
||||
curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.
|
||||
"""
|
||||
raw_stats = self.es.cluster.stats()
|
||||
self.logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
|
||||
try:
|
||||
res = {
|
||||
'cluster_name': raw_stats['cluster_name'],
|
||||
'status': raw_stats['status']
|
||||
}
|
||||
indices_status = raw_stats['indices']
|
||||
res.update({
|
||||
'indices': indices_status['count'],
|
||||
'indices_shards': indices_status['shards']['total']
|
||||
})
|
||||
doc_info = indices_status['docs']
|
||||
res.update({
|
||||
'docs': doc_info['count'],
|
||||
'docs_deleted': doc_info['deleted']
|
||||
})
|
||||
store_info = indices_status['store']
|
||||
res.update({
|
||||
'store_size': convert_bytes(store_info['size_in_bytes']),
|
||||
'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes'])
|
||||
})
|
||||
mappings_info = indices_status['mappings']
|
||||
res.update({
|
||||
'mappings_fields': mappings_info['total_field_count'],
|
||||
'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'],
|
||||
'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes'])
|
||||
})
|
||||
node_info = raw_stats['nodes']
|
||||
res.update({
|
||||
'nodes': node_info['count']['total'],
|
||||
'nodes_version': node_info['versions'],
|
||||
'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']),
|
||||
'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']),
|
||||
'os_mem_used_percent': node_info['os']['mem']['used_percent'],
|
||||
'jvm_versions': node_info['jvm']['versions'][0]['vm_version'],
|
||||
'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']),
|
||||
'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes'])
|
||||
})
|
||||
return res
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"ESConnection.get_cluster_stats: {e}")
|
||||
return None
|
||||
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
|
||||
if self.index_exist(index_name, dataset_id):
|
||||
return True
|
||||
try:
|
||||
from elasticsearch.client import IndicesClient
|
||||
return IndicesClient(self.es).create(index=index_name,
|
||||
settings=self.mapping["settings"],
|
||||
mappings=self.mapping["mappings"])
|
||||
except Exception:
|
||||
self.logger.exception("ESConnection.createIndex error %s" % index_name)
|
||||
|
||||
def delete_idx(self, index_name: str, dataset_id: str):
|
||||
if len(dataset_id) > 0:
|
||||
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
|
||||
return
|
||||
try:
|
||||
self.es.indices.delete(index=index_name, allow_no_indices=True)
|
||||
except NotFoundError:
|
||||
pass
|
||||
except Exception:
|
||||
self.logger.exception("ESConnection.deleteIdx error %s" % index_name)
|
||||
|
||||
def index_exist(self, index_name: str, dataset_id: str = None) -> bool:
|
||||
s = Index(index_name, self.es)
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
return s.exists()
|
||||
except ConnectionTimeout:
|
||||
self.logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
self.logger.exception(e)
|
||||
break
|
||||
return False
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
def get(self, doc_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = self.es.get(index=index_name,
|
||||
id=doc_id, source=True, )
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
doc = res["_source"]
|
||||
doc["id"] = doc_id
|
||||
return doc
|
||||
except NotFoundError:
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.exception(f"ESConnection.get({doc_id}) got exception")
|
||||
raise e
|
||||
self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
|
||||
raise Exception("ESConnection.get timeout.")
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, select_fields: list[str],
|
||||
highlight_fields: list[str],
|
||||
condition: dict,
|
||||
match_expressions: list[MatchExpr],
|
||||
order_by: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
index_names: str | list[str],
|
||||
dataset_ids: list[str],
|
||||
agg_fields: list[str] | None = None,
|
||||
rank_feature: dict | None = None
|
||||
):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, documents: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def get_total(self, res):
|
||||
if isinstance(res["hits"]["total"], type({})):
|
||||
return res["hits"]["total"]["value"]
|
||||
return res["hits"]["total"]
|
||||
|
||||
def get_doc_ids(self, res):
|
||||
return [d["_id"] for d in res["hits"]["hits"]]
|
||||
|
||||
def _get_source(self, res):
|
||||
rr = []
|
||||
for d in res["hits"]["hits"]:
|
||||
d["_source"]["id"] = d["_id"]
|
||||
d["_source"]["_score"] = d["_score"]
|
||||
rr.append(d["_source"])
|
||||
return rr
|
||||
|
||||
@abstractmethod
|
||||
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
def get_highlight(self, res, keywords: list[str], field_name: str):
|
||||
ans = {}
|
||||
for d in res["hits"]["hits"]:
|
||||
highlights = d.get("highlight")
|
||||
if not highlights:
|
||||
continue
|
||||
txt = "...".join([a for a in list(highlights.items())[0][1]])
|
||||
if not is_english(txt.split()):
|
||||
ans[d["_id"]] = txt
|
||||
continue
|
||||
|
||||
txt = d["_source"][field_name]
|
||||
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
|
||||
txt_list = []
|
||||
for t in re.split(r"[.?!;\n]", txt):
|
||||
for w in keywords:
|
||||
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
|
||||
flags=re.IGNORECASE | re.MULTILINE)
|
||||
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
|
||||
continue
|
||||
txt_list.append(t)
|
||||
ans[d["_id"]] = "...".join(txt_list) if txt_list else "...".join([a for a in list(highlights.items())[0][1]])
|
||||
|
||||
return ans
|
||||
|
||||
def get_aggregation(self, res, field_name: str):
|
||||
agg_field = "aggs_" + field_name
|
||||
if "aggregations" not in res or agg_field not in res["aggregations"]:
|
||||
return list()
|
||||
buckets = res["aggregations"][agg_field]["buckets"]
|
||||
return [(b["key"], b["doc_count"]) for b in buckets]
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
|
||||
def sql(self, sql: str, fetch_size: int, format: str):
|
||||
self.logger.debug(f"ESConnection.sql get sql: {sql}")
|
||||
sql = re.sub(r"[ `]+", " ", sql)
|
||||
sql = sql.replace("%", "")
|
||||
replaces = []
|
||||
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
|
||||
fld, v = r.group(1), r.group(3)
|
||||
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
|
||||
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
|
||||
replaces.append(
|
||||
("{}{}'{}'".format(
|
||||
r.group(1),
|
||||
r.group(2),
|
||||
r.group(3)),
|
||||
match))
|
||||
|
||||
for p, r in replaces:
|
||||
sql = sql.replace(p, r, 1)
|
||||
self.logger.debug(f"ESConnection.sql to es: {sql}")
|
||||
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
|
||||
request_timeout="2s")
|
||||
return res
|
||||
except ConnectionTimeout:
|
||||
self.logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
self.logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
|
||||
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
|
||||
self.logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
|
||||
return None
|
||||
451
common/doc_store/infinity_conn_base.py
Normal file
451
common/doc_store/infinity_conn_base.py
Normal file
@ -0,0 +1,451 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
|
||||
import infinity
|
||||
from infinity.common import ConflictType
|
||||
from infinity.index import IndexInfo, IndexType
|
||||
from infinity.connection_pool import ConnectionPool
|
||||
from infinity.errors import ErrorCode
|
||||
import pandas as pd
|
||||
from common.file_utils import get_project_base_directory
|
||||
from rag.nlp import is_english
|
||||
from common import settings
|
||||
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr
|
||||
|
||||
|
||||
class InfinityConnectionBase(DocStoreConnection):
|
||||
def __init__(self, mapping_file_name: str="infinity_mapping.json", logger_name: str="ragflow.infinity_conn"):
|
||||
self.dbName = settings.INFINITY.get("db_name", "default_db")
|
||||
self.mapping_file_name = mapping_file_name
|
||||
self.logger = logging.getLogger(logger_name)
|
||||
infinity_uri = settings.INFINITY["uri"]
|
||||
if ":" in infinity_uri:
|
||||
host, port = infinity_uri.split(":")
|
||||
infinity_uri = infinity.common.NetworkAddress(host, int(port))
|
||||
self.connPool = None
|
||||
self.logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
|
||||
for _ in range(24):
|
||||
try:
|
||||
conn_pool = ConnectionPool(infinity_uri, max_size=4)
|
||||
inf_conn = conn_pool.get_conn()
|
||||
res = inf_conn.show_current_node()
|
||||
if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
|
||||
self._migrate_db(inf_conn)
|
||||
self.connPool = conn_pool
|
||||
conn_pool.release_conn(inf_conn)
|
||||
break
|
||||
conn_pool.release_conn(inf_conn)
|
||||
self.logger.warning(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
|
||||
time.sleep(5)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
|
||||
time.sleep(5)
|
||||
if self.connPool is None:
|
||||
msg = f"Infinity {infinity_uri} is unhealthy in 120s."
|
||||
self.logger.error(msg)
|
||||
raise Exception(msg)
|
||||
self.logger.info(f"Infinity {infinity_uri} is healthy.")
|
||||
|
||||
def _migrate_db(self, inf_conn):
|
||||
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
|
||||
if not os.path.exists(fp_mapping):
|
||||
raise Exception(f"Mapping file not found at {fp_mapping}")
|
||||
schema = json.load(open(fp_mapping))
|
||||
table_names = inf_db.list_tables().table_names
|
||||
for table_name in table_names:
|
||||
inf_table = inf_db.get_table(table_name)
|
||||
index_names = inf_table.list_indexes().index_names
|
||||
if "q_vec_idx" not in index_names:
|
||||
# Skip tables not created by me
|
||||
continue
|
||||
column_names = inf_table.show_columns()["name"]
|
||||
column_names = set(column_names)
|
||||
for field_name, field_info in schema.items():
|
||||
if field_name in column_names:
|
||||
continue
|
||||
res = inf_table.add_columns({field_name: field_info})
|
||||
assert res.error_code == infinity.ErrorCode.OK
|
||||
self.logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
|
||||
if field_info["type"] != "varchar" or "analyzer" not in field_info:
|
||||
continue
|
||||
analyzers = field_info["analyzer"]
|
||||
if isinstance(analyzers, str):
|
||||
analyzers = [analyzers]
|
||||
for analyzer in analyzers:
|
||||
inf_table.create_index(
|
||||
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
|
||||
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
|
||||
"""
|
||||
Dataframe and fields convert
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def field_keyword(field_name: str):
|
||||
# judge keyword or not, such as "*_kwd" tag-like columns.
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
|
||||
# rm _kwd, _tks, _sm_tks, _with_weight suffix in field name.
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def convert_matching_field(field_weight_str: str) -> str:
|
||||
# convert matching field to
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@staticmethod
|
||||
def list2str(lst: str | list, sep: str = " ") -> str:
|
||||
if isinstance(lst, str):
|
||||
return lst
|
||||
return sep.join(lst)
|
||||
|
||||
def equivalent_condition_to_str(self, condition: dict, table_instance=None) -> str | None:
|
||||
assert "_id" not in condition
|
||||
columns = {}
|
||||
if table_instance:
|
||||
for n, ty, de, _ in table_instance.show_columns().rows():
|
||||
columns[n] = (ty, de)
|
||||
|
||||
def exists(cln):
|
||||
nonlocal columns
|
||||
assert cln in columns, f"'{cln}' should be in '{columns}'."
|
||||
ty, de = columns[cln]
|
||||
if ty.lower().find("cha"):
|
||||
if not de:
|
||||
de = ""
|
||||
return f" {cln}!='{de}' "
|
||||
return f"{cln}!={de}"
|
||||
|
||||
cond = list()
|
||||
for k, v in condition.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
continue
|
||||
if self.field_keyword(k):
|
||||
if isinstance(v, list):
|
||||
inCond = list()
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
item = item.replace("'", "''")
|
||||
inCond.append(f"filter_fulltext('{self.convert_matching_field(k)}', '{item}')")
|
||||
if inCond:
|
||||
strInCond = " or ".join(inCond)
|
||||
strInCond = f"({strInCond})"
|
||||
cond.append(strInCond)
|
||||
else:
|
||||
cond.append(f"filter_fulltext('{self.convert_matching_field(k)}', '{v}')")
|
||||
elif isinstance(v, list):
|
||||
inCond = list()
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
item = item.replace("'", "''")
|
||||
inCond.append(f"'{item}'")
|
||||
else:
|
||||
inCond.append(str(item))
|
||||
if inCond:
|
||||
strInCond = ", ".join(inCond)
|
||||
strInCond = f"{k} IN ({strInCond})"
|
||||
cond.append(strInCond)
|
||||
elif k == "must_not":
|
||||
if isinstance(v, dict):
|
||||
for kk, vv in v.items():
|
||||
if kk == "exists":
|
||||
cond.append("NOT (%s)" % exists(vv))
|
||||
elif isinstance(v, str):
|
||||
cond.append(f"{k}='{v}'")
|
||||
elif k == "exists":
|
||||
cond.append(exists(v))
|
||||
else:
|
||||
cond.append(f"{k}={str(v)}")
|
||||
return " AND ".join(cond) if cond else "1=1"
|
||||
|
||||
@staticmethod
|
||||
def concat_dataframes(df_list: list[pd.DataFrame], select_fields: list[str]) -> pd.DataFrame:
|
||||
df_list2 = [df for df in df_list if not df.empty]
|
||||
if df_list2:
|
||||
return pd.concat(df_list2, axis=0).reset_index(drop=True)
|
||||
|
||||
schema = []
|
||||
for field_name in select_fields:
|
||||
if field_name == "score()": # Workaround: fix schema is changed to score()
|
||||
schema.append("SCORE")
|
||||
elif field_name == "similarity()": # Workaround: fix schema is changed to similarity()
|
||||
schema.append("SIMILARITY")
|
||||
else:
|
||||
schema.append(field_name)
|
||||
return pd.DataFrame(columns=schema)
|
||||
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
def db_type(self) -> str:
|
||||
return "infinity"
|
||||
|
||||
def health(self) -> dict:
|
||||
"""
|
||||
Return the health status of the database.
|
||||
"""
|
||||
inf_conn = self.connPool.get_conn()
|
||||
res = inf_conn.show_current_node()
|
||||
self.connPool.release_conn(inf_conn)
|
||||
res2 = {
|
||||
"type": "infinity",
|
||||
"status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
|
||||
"error": res.error_msg,
|
||||
}
|
||||
return res2
|
||||
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
|
||||
table_name = f"{index_name}_{dataset_id}"
|
||||
inf_conn = self.connPool.get_conn()
|
||||
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
|
||||
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
|
||||
if not os.path.exists(fp_mapping):
|
||||
raise Exception(f"Mapping file not found at {fp_mapping}")
|
||||
schema = json.load(open(fp_mapping))
|
||||
vector_name = f"q_{vector_size}_vec"
|
||||
schema[vector_name] = {"type": f"vector,{vector_size},float"}
|
||||
inf_table = inf_db.create_table(
|
||||
table_name,
|
||||
schema,
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
inf_table.create_index(
|
||||
"q_vec_idx",
|
||||
IndexInfo(
|
||||
vector_name,
|
||||
IndexType.Hnsw,
|
||||
{
|
||||
"M": "16",
|
||||
"ef_construction": "50",
|
||||
"metric": "cosine",
|
||||
"encode": "lvq",
|
||||
},
|
||||
),
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
for field_name, field_info in schema.items():
|
||||
if field_info["type"] != "varchar" or "analyzer" not in field_info:
|
||||
continue
|
||||
analyzers = field_info["analyzer"]
|
||||
if isinstance(analyzers, str):
|
||||
analyzers = [analyzers]
|
||||
for analyzer in analyzers:
|
||||
inf_table.create_index(
|
||||
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
|
||||
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
self.logger.info(f"INFINITY created table {table_name}, vector size {vector_size}")
|
||||
return True
|
||||
|
||||
def delete_idx(self, index_name: str, dataset_id: str):
|
||||
table_name = f"{index_name}_{dataset_id}"
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
db_instance.drop_table(table_name, ConflictType.Ignore)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
self.logger.info(f"INFINITY dropped table {table_name}")
|
||||
|
||||
def index_exist(self, index_name: str, dataset_id: str) -> bool:
|
||||
table_name = f"{index_name}_{dataset_id}"
|
||||
try:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
_ = db_instance.get_table(table_name)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.warning(f"INFINITY indexExist {str(e)}")
|
||||
return False
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
select_fields: list[str],
|
||||
highlight_fields: list[str],
|
||||
condition: dict,
|
||||
match_expressions: list[MatchExpr],
|
||||
order_by: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
index_names: str | list[str],
|
||||
dataset_ids: list[str],
|
||||
agg_fields: list[str] | None = None,
|
||||
rank_feature: dict | None = None,
|
||||
) -> tuple[pd.DataFrame, int]:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get(self, doc_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, documents: list[dict], index_name: str, dataset_ids: str = None) -> list[str]:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
table_name = f"{index_name}_{dataset_id}"
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except Exception:
|
||||
self.logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.")
|
||||
return 0
|
||||
filter = self.equivalent_condition_to_str(condition, table_instance)
|
||||
self.logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
|
||||
res = table_instance.delete(filter)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
return res.deleted_rows
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
|
||||
if isinstance(res, tuple):
|
||||
return res[1]
|
||||
return len(res)
|
||||
|
||||
def get_doc_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
return list(res["id"])
|
||||
|
||||
@abstractmethod
|
||||
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], field_name: str):
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
ans = {}
|
||||
num_rows = len(res)
|
||||
column_id = res["id"]
|
||||
if field_name not in res:
|
||||
return {}
|
||||
for i in range(num_rows):
|
||||
id = column_id[i]
|
||||
txt = res[field_name][i]
|
||||
if re.search(r"<em>[^<>]+</em>", txt, flags=re.IGNORECASE | re.MULTILINE):
|
||||
ans[id] = txt
|
||||
continue
|
||||
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
|
||||
txt_list = []
|
||||
for t in re.split(r"[.?!;\n]", txt):
|
||||
if is_english([t]):
|
||||
for w in keywords:
|
||||
t = re.sub(
|
||||
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
|
||||
r"\1<em>\2</em>\3",
|
||||
t,
|
||||
flags=re.IGNORECASE | re.MULTILINE,
|
||||
)
|
||||
else:
|
||||
for w in sorted(keywords, key=len, reverse=True):
|
||||
t = re.sub(
|
||||
re.escape(w),
|
||||
f"<em>{w}</em>",
|
||||
t,
|
||||
flags=re.IGNORECASE | re.MULTILINE,
|
||||
)
|
||||
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
|
||||
continue
|
||||
txt_list.append(t)
|
||||
if txt_list:
|
||||
ans[id] = "...".join(txt_list)
|
||||
else:
|
||||
ans[id] = txt
|
||||
return ans
|
||||
|
||||
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, field_name: str):
|
||||
"""
|
||||
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
|
||||
"""
|
||||
from collections import Counter
|
||||
|
||||
# Extract DataFrame from result
|
||||
if isinstance(res, tuple):
|
||||
df, _ = res
|
||||
else:
|
||||
df = res
|
||||
|
||||
if df.empty or field_name not in df.columns:
|
||||
return []
|
||||
|
||||
# Aggregate tag counts
|
||||
tag_counter = Counter()
|
||||
|
||||
for value in df[field_name]:
|
||||
if pd.isna(value) or not value:
|
||||
continue
|
||||
|
||||
# Handle different tag formats
|
||||
if isinstance(value, str):
|
||||
# Split by ### for tag_kwd field or comma for other formats
|
||||
if field_name == "tag_kwd" and "###" in value:
|
||||
tags = [tag.strip() for tag in value.split("###") if tag.strip()]
|
||||
else:
|
||||
# Try comma separation as fallback
|
||||
tags = [tag.strip() for tag in value.split(",") if tag.strip()]
|
||||
|
||||
for tag in tags:
|
||||
if tag: # Only count non-empty tags
|
||||
tag_counter[tag] += 1
|
||||
elif isinstance(value, list):
|
||||
# Handle list format
|
||||
for tag in value:
|
||||
if tag and isinstance(tag, str):
|
||||
tag_counter[tag.strip()] += 1
|
||||
|
||||
# Return as list of [tag, count] pairs, sorted by count descending
|
||||
return [[tag, count] for tag, count in tag_counter.most_common()]
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
|
||||
def sql(self, sql: str, fetch_size: int, format: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
@ -44,21 +44,27 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
|
||||
def filter_out(v2docs, operator, value):
|
||||
ids = []
|
||||
for input, docids in v2docs.items():
|
||||
|
||||
if operator in ["=", "≠", ">", "<", "≥", "≤"]:
|
||||
try:
|
||||
if isinstance(input, list):
|
||||
input = input[0]
|
||||
input = float(input)
|
||||
value = float(value)
|
||||
except Exception:
|
||||
input = str(input)
|
||||
value = str(value)
|
||||
pass
|
||||
if isinstance(input, str):
|
||||
input = input.lower()
|
||||
if isinstance(value, str):
|
||||
value = value.lower()
|
||||
|
||||
for conds in [
|
||||
(operator == "contains", str(value).lower() in str(input).lower()),
|
||||
(operator == "not contains", str(value).lower() not in str(input).lower()),
|
||||
(operator == "in", str(input).lower() in str(value).lower()),
|
||||
(operator == "not in", str(input).lower() not in str(value).lower()),
|
||||
(operator == "start with", str(input).lower().startswith(str(value).lower())),
|
||||
(operator == "end with", str(input).lower().endswith(str(value).lower())),
|
||||
(operator == "contains", input in value if not isinstance(input, list) else all([i in value for i in input])),
|
||||
(operator == "not contains", input not in value if not isinstance(input, list) else all([i not in value for i in input])),
|
||||
(operator == "in", input in value if not isinstance(input, list) else all([i in value for i in input])),
|
||||
(operator == "not in", input not in value if not isinstance(input, list) else all([i not in value for i in input])),
|
||||
(operator == "start with", str(input).lower().startswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).startswith(str(value).lower())),
|
||||
(operator == "end with", str(input).lower().endswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).endswith(str(value).lower())),
|
||||
(operator == "empty", not input),
|
||||
(operator == "not empty", input),
|
||||
(operator == "=", input == value),
|
||||
@ -145,6 +151,18 @@ async def apply_meta_data_filter(
|
||||
return doc_ids
|
||||
|
||||
|
||||
def dedupe_list(values: list) -> list:
|
||||
seen = set()
|
||||
deduped = []
|
||||
for item in values:
|
||||
key = str(item)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(item)
|
||||
return deduped
|
||||
|
||||
|
||||
def update_metadata_to(metadata, meta):
|
||||
if not meta:
|
||||
return metadata
|
||||
@ -156,11 +174,13 @@ def update_metadata_to(metadata, meta):
|
||||
return metadata
|
||||
if not isinstance(meta, dict):
|
||||
return metadata
|
||||
|
||||
for k, v in meta.items():
|
||||
if isinstance(v, list):
|
||||
v = [vv for vv in v if isinstance(vv, str)]
|
||||
if not v:
|
||||
continue
|
||||
v = dedupe_list(v)
|
||||
if not isinstance(v, list) and not isinstance(v, str):
|
||||
continue
|
||||
if k not in metadata:
|
||||
@ -171,6 +191,7 @@ def update_metadata_to(metadata, meta):
|
||||
metadata[k].extend(v)
|
||||
else:
|
||||
metadata[k].append(v)
|
||||
metadata[k] = dedupe_list(metadata[k])
|
||||
else:
|
||||
metadata[k] = v
|
||||
|
||||
@ -202,4 +223,4 @@ def metadata_schema(metadata: list|None) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
json_schema["additionalProperties"] = False
|
||||
return json_schema
|
||||
return json_schema
|
||||
|
||||
72
common/query_base.py
Normal file
72
common/query_base.py
Normal file
@ -0,0 +1,72 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class QueryBase(ABC):
|
||||
|
||||
@staticmethod
|
||||
def is_chinese(line):
|
||||
arr = re.split(r"[ \t]+", line)
|
||||
if len(arr) <= 3:
|
||||
return True
|
||||
e = 0
|
||||
for t in arr:
|
||||
if not re.match(r"[a-zA-Z]+$", t):
|
||||
e += 1
|
||||
return e * 1.0 / len(arr) >= 0.7
|
||||
|
||||
@staticmethod
|
||||
def sub_special_char(line):
|
||||
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
|
||||
|
||||
@staticmethod
|
||||
def rmWWW(txt):
|
||||
patts = [
|
||||
(
|
||||
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
|
||||
"",
|
||||
),
|
||||
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
||||
(
|
||||
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
|
||||
" ")
|
||||
]
|
||||
otxt = txt
|
||||
for r, p in patts:
|
||||
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
||||
if not txt:
|
||||
txt = otxt
|
||||
return txt
|
||||
|
||||
@staticmethod
|
||||
def add_space_between_eng_zh(txt):
|
||||
# (ENG/ENG+NUM) + ZH
|
||||
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
|
||||
# ENG + ZH
|
||||
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
|
||||
# ZH + (ENG/ENG+NUM)
|
||||
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
|
||||
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
|
||||
return txt
|
||||
|
||||
@abstractmethod
|
||||
def question(self, text, tbl, min_match):
|
||||
"""
|
||||
Returns a query object based on the input text, table, and minimum match criteria.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
@ -39,6 +39,9 @@ from rag.utils.oss_conn import RAGFlowOSS
|
||||
|
||||
from rag.nlp import search
|
||||
|
||||
import memory.utils.es_conn as memory_es_conn
|
||||
import memory.utils.infinity_conn as memory_infinity_conn
|
||||
|
||||
LLM = None
|
||||
LLM_FACTORY = None
|
||||
LLM_BASE_URL = None
|
||||
@ -76,9 +79,11 @@ FEISHU_OAUTH = None
|
||||
OAUTH_CONFIG = None
|
||||
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
|
||||
DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity")
|
||||
MSG_ENGINE = DOC_ENGINE
|
||||
|
||||
|
||||
docStoreConn = None
|
||||
msgStoreConn = None
|
||||
|
||||
retriever = None
|
||||
kg_retriever = None
|
||||
@ -256,6 +261,15 @@ def init_settings():
|
||||
else:
|
||||
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
|
||||
|
||||
global MSG_ENGINE, msgStoreConn
|
||||
MSG_ENGINE = DOC_ENGINE # use the same engine for message store
|
||||
if MSG_ENGINE == "elasticsearch":
|
||||
ES = get_base_config("es", {})
|
||||
msgStoreConn = memory_es_conn.ESConnection()
|
||||
elif MSG_ENGINE == "infinity":
|
||||
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
|
||||
msgStoreConn = memory_infinity_conn.InfinityConnection()
|
||||
|
||||
global AZURE, S3, MINIO, OSS, GCS
|
||||
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
|
||||
AZURE = get_base_config("azure", {})
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
||||
def current_timestamp():
|
||||
@ -123,4 +124,31 @@ def delta_seconds(date_string: str):
|
||||
3600.0 # If current time is 2024-01-01 13:00:00
|
||||
"""
|
||||
dt = datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
|
||||
return (datetime.datetime.now() - dt).total_seconds()
|
||||
return (datetime.datetime.now() - dt).total_seconds()
|
||||
|
||||
|
||||
def format_iso_8601_to_ymd_hms(time_str: str) -> str:
|
||||
"""
|
||||
Convert ISO 8601 formatted string to "YYYY-MM-DD HH:MM:SS" format.
|
||||
|
||||
Args:
|
||||
time_str: ISO 8601 date string (e.g. "2024-01-01T12:00:00Z")
|
||||
|
||||
Returns:
|
||||
str: Date string in "YYYY-MM-DD HH:MM:SS" format
|
||||
|
||||
Example:
|
||||
>>> format_iso_8601_to_ymd_hms("2024-01-01T12:00:00Z")
|
||||
'2024-01-01 12:00:00'
|
||||
"""
|
||||
from dateutil import parser
|
||||
|
||||
try:
|
||||
if parser.isoparse(time_str):
|
||||
dt = datetime.datetime.fromisoformat(time_str.replace("Z", "+00:00"))
|
||||
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
return time_str
|
||||
except Exception as e:
|
||||
logging.error(str(e))
|
||||
return time_str
|
||||
|
||||
@ -1258,6 +1258,12 @@
|
||||
"status": "1",
|
||||
"rank": "810",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "MiniMax-M2.1",
|
||||
"tags": "LLM,CHAT,200k",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "MiniMax-M2",
|
||||
"tags": "LLM,CHAT,200k",
|
||||
|
||||
19
conf/message_infinity_mapping.json
Normal file
19
conf/message_infinity_mapping.json
Normal file
@ -0,0 +1,19 @@
|
||||
{
|
||||
"id": {"type": "varchar", "default": ""},
|
||||
"message_id": {"type": "integer", "default": 0},
|
||||
"message_type_kwd": {"type": "varchar", "default": ""},
|
||||
"source_id": {"type": "integer", "default": 0},
|
||||
"memory_id": {"type": "varchar", "default": ""},
|
||||
"user_id": {"type": "varchar", "default": ""},
|
||||
"agent_id": {"type": "varchar", "default": ""},
|
||||
"session_id": {"type": "varchar", "default": ""},
|
||||
"valid_at": {"type": "varchar", "default": ""},
|
||||
"valid_at_flt": {"type": "float", "default": 0.0},
|
||||
"invalid_at": {"type": "varchar", "default": ""},
|
||||
"invalid_at_flt": {"type": "float", "default": 0.0},
|
||||
"forget_at": {"type": "varchar", "default": ""},
|
||||
"forget_at_flt": {"type": "float", "default": 0.0},
|
||||
"status_int": {"type": "integer", "default": 1},
|
||||
"zone_id": {"type": "integer", "default": 0},
|
||||
"content": {"type": "varchar", "default": "", "analyzer": ["rag-coarse", "rag-fine"], "comment": "content_ltks"}
|
||||
}
|
||||
@ -1447,6 +1447,7 @@ class VisionParser(RAGFlowPdfParser):
|
||||
def __init__(self, vision_model, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.vision_model = vision_model
|
||||
self.outlines = []
|
||||
|
||||
def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None):
|
||||
try:
|
||||
|
||||
@ -72,7 +72,7 @@ services:
|
||||
infinity:
|
||||
profiles:
|
||||
- infinity
|
||||
image: infiniflow/infinity:v0.6.11
|
||||
image: infiniflow/infinity:v0.6.13
|
||||
volumes:
|
||||
- infinity_data:/var/infinity
|
||||
- ./infinity_conf.toml:/infinity_conf.toml
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
[general]
|
||||
version = "0.6.11"
|
||||
version = "0.6.13"
|
||||
time_zone = "utc-8"
|
||||
|
||||
[network]
|
||||
|
||||
90
docs/guides/agent/agent_component_reference/http.md
Normal file
90
docs/guides/agent/agent_component_reference/http.md
Normal file
@ -0,0 +1,90 @@
|
||||
---
|
||||
sidebar_position: 30
|
||||
slug: /http_request_component
|
||||
---
|
||||
|
||||
# HTTP request component
|
||||
|
||||
A component that calls remote services.
|
||||
|
||||
---
|
||||
|
||||
An **HTTP request** component lets you access remote APIs or services by providing a URL and an HTTP method, and then receive the response. You can customize headers, parameters, proxies, and timeout settings, and use common methods like GET and POST. It’s useful for exchanging data with external systems in a workflow.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- An accessible remote API or service.
|
||||
- Add a Token or credentials to the request header, if the target service requires authentication.
|
||||
|
||||
## Configurations
|
||||
|
||||
### Url
|
||||
|
||||
*Required*. The complete request address, for example: http://api.example.com/data.
|
||||
|
||||
### Method
|
||||
|
||||
The HTTP request method to select. Available options:
|
||||
|
||||
- GET
|
||||
- POST
|
||||
- PUT
|
||||
|
||||
### Timeout
|
||||
|
||||
The maximum waiting time for the request, in seconds. Defaults to `60`.
|
||||
|
||||
### Headers
|
||||
|
||||
Custom HTTP headers can be set here, for example:
|
||||
|
||||
```http
|
||||
{
|
||||
"Accept": "application/json",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive"
|
||||
}
|
||||
```
|
||||
|
||||
### Proxy
|
||||
|
||||
Optional. The proxy server address to use for this request.
|
||||
|
||||
### Clean HTML
|
||||
|
||||
`Boolean`: Whether to remove HTML tags from the returned results and keep plain text only.
|
||||
|
||||
### Parameter
|
||||
|
||||
*Optional*. Parameters to send with the HTTP request. Supports key-value pairs:
|
||||
|
||||
- To assign a value using a dynamic system variable, set it as Variable.
|
||||
- To override these dynamic values under certain conditions and use a fixed static value instead, Value is the appropriate choice.
|
||||
|
||||
|
||||
:::tip NOTE
|
||||
- For GET requests, these parameters are appended to the end of the URL.
|
||||
- For POST/PUT requests, they are sent as the request body.
|
||||
:::
|
||||
|
||||
#### Example setting
|
||||
|
||||

|
||||
|
||||
#### Example response
|
||||
|
||||
```html
|
||||
{ "args": { "App": "RAGFlow", "Query": "How to do?", "Userid": "241ed25a8e1011f0b979424ebc5b108b" }, "headers": { "Accept": "/", "Accept-Encoding": "gzip, deflate, br, zstd", "Cache-Control": "no-cache", "Host": "httpbin.org", "User-Agent": "python-requests/2.32.2", "X-Amzn-Trace-Id": "Root=1-68c9210c-5aab9088580c130a2f065523" }, "origin": "185.36.193.38", "url": "https://httpbin.org/get?Userid=241ed25a8e1011f0b979424ebc5b108b&App=RAGFlow&Query=How+to+do%3F" }
|
||||
```
|
||||
|
||||
### Output
|
||||
|
||||
The global variable name for the output of the HTTP request component, which can be referenced by other components in the workflow.
|
||||
|
||||
- `Result`: `string` The response returned by the remote service.
|
||||
|
||||
## Example
|
||||
|
||||
This is a usage example: a workflow sends a GET request from the **Begin** component to `https://httpbin.org/get` via the **HTTP Request_0** component, passes parameters to the server, and finally outputs the result through the **Message_0** component.
|
||||
|
||||

|
||||
@ -24,11 +24,11 @@ from common.misc_utils import get_uuid
|
||||
from graphrag.query_analyze_prompt import PROMPTS
|
||||
from graphrag.utils import get_entity_type2samples, get_llm_cache, set_llm_cache, get_relation
|
||||
from common.token_utils import num_tokens_from_string
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
|
||||
from rag.nlp.search import Dealer, index_name
|
||||
from common.float_utils import get_float
|
||||
from common import settings
|
||||
from common.doc_store.doc_store_base import OrderByExpr
|
||||
|
||||
|
||||
class KGSearch(Dealer):
|
||||
|
||||
@ -26,9 +26,9 @@ from networkx.readwrite import json_graph
|
||||
from common.misc_utils import get_uuid
|
||||
from common.connection_utils import timeout
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from common import settings
|
||||
from common.doc_store.doc_store_base import OrderByExpr
|
||||
|
||||
GRAPH_FIELD_SEP = "<SEP>"
|
||||
|
||||
|
||||
@ -96,7 +96,7 @@ ragflow:
|
||||
infinity:
|
||||
image:
|
||||
repository: infiniflow/infinity
|
||||
tag: v0.6.11
|
||||
tag: v0.6.13
|
||||
pullPolicy: IfNotPresent
|
||||
pullSecrets: []
|
||||
storage:
|
||||
|
||||
0
memory/__init__.py
Normal file
0
memory/__init__.py
Normal file
0
memory/services/__init__.py
Normal file
0
memory/services/__init__.py
Normal file
240
memory/services/messages.py
Normal file
240
memory/services/messages.py
Normal file
@ -0,0 +1,240 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
from common import settings
|
||||
from common.doc_store.doc_store_base import OrderByExpr, MatchExpr
|
||||
|
||||
|
||||
def index_name(uid: str): return f"memory_{uid}"
|
||||
|
||||
|
||||
class MessageService:
|
||||
|
||||
@classmethod
|
||||
def has_index(cls, uid: str, memory_id: str):
|
||||
index = index_name(uid)
|
||||
return settings.msgStoreConn.index_exist(index, memory_id)
|
||||
|
||||
@classmethod
|
||||
def create_index(cls, uid: str, memory_id: str, vector_size: int):
|
||||
index = index_name(uid)
|
||||
return settings.msgStoreConn.create_idx(index, memory_id, vector_size)
|
||||
|
||||
@classmethod
|
||||
def delete_index(cls, uid: str, memory_id: str):
|
||||
index = index_name(uid)
|
||||
return settings.msgStoreConn.delete_idx(index, memory_id)
|
||||
|
||||
@classmethod
|
||||
def insert_message(cls, messages: List[dict], uid: str, memory_id: str):
|
||||
index = index_name(uid)
|
||||
[m.update({
|
||||
"id": f'{memory_id}_{m["message_id"]}',
|
||||
"status": 1 if m["status"] else 0
|
||||
}) for m in messages]
|
||||
return settings.msgStoreConn.insert(messages, index, memory_id)
|
||||
|
||||
@classmethod
|
||||
def update_message(cls, condition: dict, update_dict: dict, uid: str, memory_id: str):
|
||||
index = index_name(uid)
|
||||
if "status" in update_dict:
|
||||
update_dict["status"] = 1 if update_dict["status"] else 0
|
||||
return settings.msgStoreConn.update(condition, update_dict, index, memory_id)
|
||||
|
||||
@classmethod
|
||||
def delete_message(cls, condition: dict, uid: str, memory_id: str):
|
||||
index = index_name(uid)
|
||||
return settings.msgStoreConn.delete(condition, index, memory_id)
|
||||
|
||||
@classmethod
|
||||
def list_message(cls, uid: str, memory_id: str, agent_ids: List[str]=None, keywords: str=None, page: int=1, page_size: int=50):
|
||||
index = index_name(uid)
|
||||
filter_dict = {}
|
||||
if agent_ids:
|
||||
filter_dict["agent_id"] = agent_ids
|
||||
if keywords:
|
||||
filter_dict["session_id"] = keywords
|
||||
order_by = OrderByExpr()
|
||||
order_by.desc("valid_at")
|
||||
res = settings.msgStoreConn.search(
|
||||
select_fields=[
|
||||
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
|
||||
"invalid_at", "forget_at", "status"
|
||||
],
|
||||
highlight_fields=[],
|
||||
condition=filter_dict,
|
||||
match_expressions=[], order_by=order_by,
|
||||
offset=(page-1)*page_size, limit=page_size,
|
||||
index_names=index, memory_ids=[memory_id], agg_fields=[], hide_forgotten=False
|
||||
)
|
||||
total_count = settings.msgStoreConn.get_total(res)
|
||||
doc_mapping = settings.msgStoreConn.get_fields(res, [
|
||||
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id",
|
||||
"valid_at", "invalid_at", "forget_at", "status"
|
||||
])
|
||||
return {
|
||||
"message_list": list(doc_mapping.values()),
|
||||
"total_count": total_count
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_recent_messages(cls, uid_list: List[str], memory_ids: List[str], agent_id: str, session_id: str, limit: int):
|
||||
index_names = [index_name(uid) for uid in uid_list]
|
||||
condition_dict = {
|
||||
"agent_id": agent_id,
|
||||
"session_id": session_id
|
||||
}
|
||||
order_by = OrderByExpr()
|
||||
order_by.desc("valid_at")
|
||||
res = settings.msgStoreConn.search(
|
||||
select_fields=[
|
||||
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
|
||||
"invalid_at", "forget_at", "status", "content"
|
||||
],
|
||||
highlight_fields=[],
|
||||
condition=condition_dict,
|
||||
match_expressions=[], order_by=order_by,
|
||||
offset=0, limit=limit,
|
||||
index_names=index_names, memory_ids=memory_ids, agg_fields=[]
|
||||
)
|
||||
doc_mapping = settings.msgStoreConn.get_fields(res, [
|
||||
"message_id", "message_type", "source_id", "memory_id","user_id", "agent_id", "session_id",
|
||||
"valid_at", "invalid_at", "forget_at", "status", "content"
|
||||
])
|
||||
return list(doc_mapping.values())
|
||||
|
||||
@classmethod
|
||||
def search_message(cls, memory_ids: List[str], condition_dict: dict, uid_list: List[str], match_expressions:list[MatchExpr], top_n: int):
|
||||
index_names = [index_name(uid) for uid in uid_list]
|
||||
# filter only valid messages by default
|
||||
if "status" not in condition_dict:
|
||||
condition_dict["status"] = 1
|
||||
|
||||
order_by = OrderByExpr()
|
||||
order_by.desc("valid_at")
|
||||
res = settings.msgStoreConn.search(
|
||||
select_fields=[
|
||||
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id",
|
||||
"valid_at",
|
||||
"invalid_at", "forget_at", "status", "content"
|
||||
],
|
||||
highlight_fields=[],
|
||||
condition=condition_dict,
|
||||
match_expressions=match_expressions,
|
||||
order_by=order_by,
|
||||
offset=0, limit=top_n,
|
||||
index_names=index_names, memory_ids=memory_ids, agg_fields=[]
|
||||
)
|
||||
docs = settings.msgStoreConn.get_fields(res, [
|
||||
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
|
||||
"invalid_at", "forget_at", "status", "content"
|
||||
])
|
||||
return list(docs.values())
|
||||
|
||||
@staticmethod
|
||||
def calculate_message_size(message: dict):
|
||||
return sys.getsizeof(message["content"]) + sys.getsizeof(message["content_embed"][0]) * len(message["content_embed"])
|
||||
|
||||
@classmethod
|
||||
def calculate_memory_size(cls, memory_ids: List[str], uid_list: List[str]):
|
||||
index_names = [index_name(uid) for uid in uid_list]
|
||||
order_by = OrderByExpr()
|
||||
order_by.desc("valid_at")
|
||||
|
||||
res = settings.msgStoreConn.search(
|
||||
select_fields=["memory_id", "content", "content_embed"],
|
||||
highlight_fields=[],
|
||||
condition={},
|
||||
match_expressions=[],
|
||||
order_by=order_by,
|
||||
offset=0, limit=2000*len(memory_ids),
|
||||
index_names=index_names, memory_ids=memory_ids, agg_fields=[], hide_forgotten=False
|
||||
)
|
||||
docs = settings.msgStoreConn.get_fields(res, ["memory_id", "content", "content_embed"])
|
||||
size_dict = {}
|
||||
for doc in docs.values():
|
||||
if size_dict.get(doc["memory_id"]):
|
||||
size_dict[doc["memory_id"]] += cls.calculate_message_size(doc)
|
||||
else:
|
||||
size_dict[doc["memory_id"]] = cls.calculate_message_size(doc)
|
||||
return size_dict
|
||||
|
||||
@classmethod
|
||||
def pick_messages_to_delete_by_fifo(cls, memory_id: str, uid: str, size_to_delete: int):
|
||||
select_fields = ["message_id", "content", "content_embed"]
|
||||
_index_name = index_name(uid)
|
||||
res = settings.msgStoreConn.get_forgotten_messages(select_fields, _index_name, memory_id)
|
||||
message_list = settings.msgStoreConn.get_fields(res, select_fields)
|
||||
current_size = 0
|
||||
ids_to_remove = []
|
||||
for message in message_list:
|
||||
if current_size < size_to_delete:
|
||||
current_size += cls.calculate_message_size(message)
|
||||
ids_to_remove.append(message["message_id"])
|
||||
else:
|
||||
return ids_to_remove, current_size
|
||||
if current_size >= size_to_delete:
|
||||
return ids_to_remove, current_size
|
||||
|
||||
order_by = OrderByExpr()
|
||||
order_by.asc("valid_at")
|
||||
res = settings.msgStoreConn.search(
|
||||
select_fields=["memory_id", "content", "content_embed"],
|
||||
highlight_fields=[],
|
||||
condition={},
|
||||
match_expressions=[],
|
||||
order_by=order_by,
|
||||
offset=0, limit=2000,
|
||||
index_names=[_index_name], memory_ids=[memory_id], agg_fields=[]
|
||||
)
|
||||
docs = settings.msgStoreConn.get_fields(res, select_fields)
|
||||
for doc in docs.values():
|
||||
if current_size < size_to_delete:
|
||||
current_size += cls.calculate_message_size(doc)
|
||||
ids_to_remove.append(doc["memory_id"])
|
||||
else:
|
||||
return ids_to_remove, current_size
|
||||
return ids_to_remove, current_size
|
||||
|
||||
@classmethod
|
||||
def get_by_message_id(cls, memory_id: str, message_id: int, uid: str):
|
||||
index = index_name(uid)
|
||||
doc_id = f'{memory_id}_{message_id}'
|
||||
return settings.msgStoreConn.get(doc_id, index, [memory_id])
|
||||
|
||||
@classmethod
|
||||
def get_max_message_id(cls, uid_list: List[str], memory_ids: List[str]):
|
||||
order_by = OrderByExpr()
|
||||
order_by.desc("message_id")
|
||||
index_names = [index_name(uid) for uid in uid_list]
|
||||
res = settings.msgStoreConn.search(
|
||||
select_fields=["message_id"],
|
||||
highlight_fields=[],
|
||||
condition={},
|
||||
match_expressions=[],
|
||||
order_by=order_by,
|
||||
offset=0, limit=1,
|
||||
index_names=index_names, memory_ids=memory_ids,
|
||||
agg_fields=[], hide_forgotten=False
|
||||
)
|
||||
docs = settings.msgStoreConn.get_fields(res, ["message_id"])
|
||||
if not docs:
|
||||
return 1
|
||||
else:
|
||||
latest_msg = list(docs.values())[0]
|
||||
return int(latest_msg["message_id"])
|
||||
185
memory/services/query.py
Normal file
185
memory/services/query.py
Normal file
@ -0,0 +1,185 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
import logging
|
||||
import json
|
||||
import numpy as np
|
||||
from common.query_base import QueryBase
|
||||
from common.doc_store.doc_store_base import MatchDenseExpr, MatchTextExpr
|
||||
from common.float_utils import get_float
|
||||
from rag.nlp import rag_tokenizer, term_weight, synonym
|
||||
|
||||
|
||||
def get_vector(txt, emb_mdl, topk=10, similarity=0.1):
|
||||
if isinstance(similarity, str) and len(similarity) > 0:
|
||||
try:
|
||||
similarity = float(similarity)
|
||||
except Exception as e:
|
||||
logging.warning(f"Convert similarity '{similarity}' to float failed: {e}. Using default 0.1")
|
||||
similarity = 0.1
|
||||
qv, _ = emb_mdl.encode_queries(txt)
|
||||
shape = np.array(qv).shape
|
||||
if len(shape) > 1:
|
||||
raise Exception(
|
||||
f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
|
||||
embedding_data = [get_float(v) for v in qv]
|
||||
vector_column_name = f"q_{len(embedding_data)}_vec"
|
||||
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
|
||||
|
||||
|
||||
class MsgTextQuery(QueryBase):
|
||||
|
||||
def __init__(self):
|
||||
self.tw = term_weight.Dealer()
|
||||
self.syn = synonym.Dealer()
|
||||
self.query_fields = [
|
||||
"content"
|
||||
]
|
||||
|
||||
def question(self, txt, tbl="messages", min_match: float=0.6):
|
||||
original_query = txt
|
||||
txt = MsgTextQuery.add_space_between_eng_zh(txt)
|
||||
txt = re.sub(
|
||||
r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+",
|
||||
" ",
|
||||
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
|
||||
).strip()
|
||||
otxt = txt
|
||||
txt = MsgTextQuery.rmWWW(txt)
|
||||
|
||||
if not self.is_chinese(txt):
|
||||
txt = self.rmWWW(txt)
|
||||
tks = rag_tokenizer.tokenize(txt).split()
|
||||
keywords = [t for t in tks if t]
|
||||
tks_w = self.tw.weights(tks, preprocess=False)
|
||||
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
|
||||
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
|
||||
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
|
||||
tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
|
||||
syns = []
|
||||
for tk, w in tks_w[:256]:
|
||||
syn = self.syn.lookup(tk)
|
||||
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
|
||||
keywords.extend(syn)
|
||||
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
|
||||
syns.append(" ".join(syn))
|
||||
|
||||
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
|
||||
tk and not re.match(r"[.^+\(\)-]", tk)]
|
||||
for i in range(1, len(tks_w)):
|
||||
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
|
||||
if not left or not right:
|
||||
continue
|
||||
q.append(
|
||||
'"%s %s"^%.4f'
|
||||
% (
|
||||
tks_w[i - 1][0],
|
||||
tks_w[i][0],
|
||||
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
|
||||
)
|
||||
)
|
||||
if not q:
|
||||
q.append(txt)
|
||||
query = " ".join(q)
|
||||
return MatchTextExpr(
|
||||
self.query_fields, query, 100, {"original_query": original_query}
|
||||
), keywords
|
||||
|
||||
def need_fine_grained_tokenize(tk):
|
||||
if len(tk) < 3:
|
||||
return False
|
||||
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
|
||||
return False
|
||||
return True
|
||||
|
||||
txt = self.rmWWW(txt)
|
||||
qs, keywords = [], []
|
||||
for tt in self.tw.split(txt)[:256]: # .split():
|
||||
if not tt:
|
||||
continue
|
||||
keywords.append(tt)
|
||||
twts = self.tw.weights([tt])
|
||||
syns = self.syn.lookup(tt)
|
||||
if syns and len(keywords) < 32:
|
||||
keywords.extend(syns)
|
||||
logging.debug(json.dumps(twts, ensure_ascii=False))
|
||||
tms = []
|
||||
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
||||
sm = (
|
||||
rag_tokenizer.fine_grained_tokenize(tk).split()
|
||||
if need_fine_grained_tokenize(tk)
|
||||
else []
|
||||
)
|
||||
sm = [
|
||||
re.sub(
|
||||
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
|
||||
"",
|
||||
m,
|
||||
)
|
||||
for m in sm
|
||||
]
|
||||
sm = [self.sub_special_char(m) for m in sm if len(m) > 1]
|
||||
sm = [m for m in sm if len(m) > 1]
|
||||
|
||||
if len(keywords) < 32:
|
||||
keywords.append(re.sub(r"[ \\\"']+", "", tk))
|
||||
keywords.extend(sm)
|
||||
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [self.sub_special_char(s) for s in tk_syns]
|
||||
if len(keywords) < 32:
|
||||
keywords.extend([s for s in tk_syns if s])
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
||||
|
||||
if len(keywords) >= 32:
|
||||
break
|
||||
|
||||
tk = self.sub_special_char(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
|
||||
if sm:
|
||||
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
|
||||
if tk.strip():
|
||||
tms.append((tk, w))
|
||||
|
||||
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
||||
|
||||
if len(twts) > 1:
|
||||
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
|
||||
|
||||
syns = " OR ".join(
|
||||
[
|
||||
'"%s"'
|
||||
% rag_tokenizer.tokenize(self.sub_special_char(s))
|
||||
for s in syns
|
||||
]
|
||||
)
|
||||
if syns and tms:
|
||||
tms = f"({tms})^5 OR ({syns})^0.7"
|
||||
|
||||
qs.append(tms)
|
||||
|
||||
if qs:
|
||||
query = " OR ".join([f"({t})" for t in qs if t])
|
||||
if not query:
|
||||
query = otxt
|
||||
return MatchTextExpr(
|
||||
self.query_fields, query, 100, {"minimum_should_match": min_match, "original_query": original_query}
|
||||
), keywords
|
||||
return None, keywords
|
||||
0
memory/utils/__init__.py
Normal file
0
memory/utils/__init__.py
Normal file
494
memory/utils/es_conn.py
Normal file
494
memory/utils/es_conn.py
Normal file
@ -0,0 +1,494 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
|
||||
import copy
|
||||
from elasticsearch import NotFoundError
|
||||
from elasticsearch_dsl import UpdateByQuery, Q, Search
|
||||
from elastic_transport import ConnectionTimeout
|
||||
from common.decorator import singleton
|
||||
from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
|
||||
from common.doc_store.es_conn_base import ESConnectionBase
|
||||
from common.float_utils import get_float
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
|
||||
ATTEMPT_TIME = 2
|
||||
|
||||
|
||||
@singleton
|
||||
class ESConnection(ESConnectionBase):
|
||||
|
||||
@staticmethod
|
||||
def convert_field_name(field_name: str) -> str:
|
||||
match field_name:
|
||||
case "message_type":
|
||||
return "message_type_kwd"
|
||||
case "status":
|
||||
return "status_int"
|
||||
case "content":
|
||||
return "content_ltks"
|
||||
case _:
|
||||
return field_name
|
||||
|
||||
@staticmethod
|
||||
def map_message_to_es_fields(message: dict) -> dict:
|
||||
"""
|
||||
Map message dictionary fields to Elasticsearch document/Infinity fields.
|
||||
|
||||
:param message: A dictionary containing message details.
|
||||
:return: A dictionary formatted for Elasticsearch/Infinity indexing.
|
||||
"""
|
||||
storage_doc = {
|
||||
"id": message.get("id"),
|
||||
"message_id": message["message_id"],
|
||||
"message_type_kwd": message["message_type"],
|
||||
"source_id": message["source_id"],
|
||||
"memory_id": message["memory_id"],
|
||||
"user_id": message["user_id"],
|
||||
"agent_id": message["agent_id"],
|
||||
"session_id": message["session_id"],
|
||||
"valid_at": message["valid_at"],
|
||||
"invalid_at": message["invalid_at"],
|
||||
"forget_at": message["forget_at"],
|
||||
"status_int": 1 if message["status"] else 0,
|
||||
"zone_id": message.get("zone_id", 0),
|
||||
"content_ltks": message["content"],
|
||||
f"q_{len(message['content_embed'])}_vec": message["content_embed"],
|
||||
}
|
||||
return storage_doc
|
||||
|
||||
@staticmethod
|
||||
def get_message_from_es_doc(doc: dict) -> dict:
|
||||
"""
|
||||
Convert an Elasticsearch/Infinity document back to a message dictionary.
|
||||
|
||||
:param doc: A dictionary representing the Elasticsearch/Infinity document.
|
||||
:return: A dictionary formatted as a message.
|
||||
"""
|
||||
embd_field_name = next((key for key in doc.keys() if re.match(r"q_\d+_vec", key)), None)
|
||||
message = {
|
||||
"message_id": doc["message_id"],
|
||||
"message_type": doc["message_type_kwd"],
|
||||
"source_id": doc["source_id"] if doc["source_id"] else None,
|
||||
"memory_id": doc["memory_id"],
|
||||
"user_id": doc.get("user_id", ""),
|
||||
"agent_id": doc["agent_id"],
|
||||
"session_id": doc["session_id"],
|
||||
"zone_id": doc.get("zone_id", 0),
|
||||
"valid_at": doc["valid_at"],
|
||||
"invalid_at": doc.get("invalid_at", "-"),
|
||||
"forget_at": doc.get("forget_at", "-"),
|
||||
"status": bool(int(doc["status_int"])),
|
||||
"content": doc.get("content_ltks", ""),
|
||||
"content_embed": doc.get(embd_field_name, []) if embd_field_name else [],
|
||||
}
|
||||
if doc.get("id"):
|
||||
message["id"] = doc["id"]
|
||||
return message
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
def search(
|
||||
self, select_fields: list[str],
|
||||
highlight_fields: list[str],
|
||||
condition: dict,
|
||||
match_expressions: list[MatchExpr],
|
||||
order_by: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
index_names: str | list[str],
|
||||
memory_ids: list[str],
|
||||
agg_fields: list[str] | None = None,
|
||||
rank_feature: dict | None = None,
|
||||
hide_forgotten: bool = True
|
||||
):
|
||||
"""
|
||||
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
|
||||
"""
|
||||
if isinstance(index_names, str):
|
||||
index_names = index_names.split(",")
|
||||
assert isinstance(index_names, list) and len(index_names) > 0
|
||||
assert "_id" not in condition
|
||||
bool_query = Q("bool", must=[], must_not=[])
|
||||
if hide_forgotten:
|
||||
# filter not forget
|
||||
bool_query.must_not.append(Q("exists", field="forget_at"))
|
||||
|
||||
condition["memory_id"] = memory_ids
|
||||
for k, v in condition.items():
|
||||
if k == "session_id" and v:
|
||||
bool_query.filter.append(Q("query_string", **{"query": f"*{v}*", "fields": ["session_id"], "analyze_wildcard": True}))
|
||||
continue
|
||||
if not v:
|
||||
continue
|
||||
if isinstance(v, list):
|
||||
bool_query.filter.append(Q("terms", **{k: v}))
|
||||
elif isinstance(v, str) or isinstance(v, int):
|
||||
bool_query.filter.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception(
|
||||
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
||||
s = Search()
|
||||
vector_similarity_weight = 0.5
|
||||
for m in match_expressions:
|
||||
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
|
||||
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1],
|
||||
MatchDenseExpr) and isinstance(
|
||||
match_expressions[2], FusionExpr)
|
||||
weights = m.fusion_params["weights"]
|
||||
vector_similarity_weight = get_float(weights.split(",")[1])
|
||||
for m in match_expressions:
|
||||
if isinstance(m, MatchTextExpr):
|
||||
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
|
||||
if isinstance(minimum_should_match, float):
|
||||
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
|
||||
bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f) for f in m.fields],
|
||||
type="best_fields", query=m.matching_text,
|
||||
minimum_should_match=minimum_should_match,
|
||||
boost=1))
|
||||
bool_query.boost = 1.0 - vector_similarity_weight
|
||||
|
||||
elif isinstance(m, MatchDenseExpr):
|
||||
assert (bool_query is not None)
|
||||
similarity = 0.0
|
||||
if "similarity" in m.extra_options:
|
||||
similarity = m.extra_options["similarity"]
|
||||
s = s.knn(self.convert_field_name(m.vector_column_name),
|
||||
m.topn,
|
||||
m.topn * 2,
|
||||
query_vector=list(m.embedding_data),
|
||||
filter=bool_query.to_dict(),
|
||||
similarity=similarity,
|
||||
)
|
||||
|
||||
if bool_query and rank_feature:
|
||||
for fld, sc in rank_feature.items():
|
||||
if fld != PAGERANK_FLD:
|
||||
fld = f"{TAG_FLD}.{fld}"
|
||||
bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
|
||||
|
||||
if bool_query:
|
||||
s = s.query(bool_query)
|
||||
for field in highlight_fields:
|
||||
s = s.highlight(field)
|
||||
|
||||
if order_by:
|
||||
orders = list()
|
||||
for field, order in order_by.fields:
|
||||
order = "asc" if order == 0 else "desc"
|
||||
if field.endswith("_int") or field.endswith("_flt"):
|
||||
order_info = {"order": order, "unmapped_type": "float"}
|
||||
else:
|
||||
order_info = {"order": order, "unmapped_type": "text"}
|
||||
orders.append({field: order_info})
|
||||
s = s.sort(*orders)
|
||||
|
||||
if agg_fields:
|
||||
for fld in agg_fields:
|
||||
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
|
||||
|
||||
if limit > 0:
|
||||
s = s[offset:offset + limit]
|
||||
q = s.to_dict()
|
||||
self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))
|
||||
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
#print(json.dumps(q, ensure_ascii=False))
|
||||
res = self.es.search(index=index_names,
|
||||
body=q,
|
||||
timeout="600s",
|
||||
# search_type="dfs_query_then_fetch",
|
||||
track_total_hits=True,
|
||||
_source=True)
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
|
||||
return res
|
||||
except ConnectionTimeout:
|
||||
self.logger.exception("ES request timeout")
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e))
|
||||
raise e
|
||||
|
||||
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
|
||||
raise Exception("ESConnection.search timeout.")
|
||||
|
||||
def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=2000):
|
||||
bool_query = Q("bool", must_not=[])
|
||||
bool_query.must_not.append(Q("term", forget_at=None))
|
||||
bool_query.filter.append(Q("term", memory_id=memory_id))
|
||||
# from old to new
|
||||
order_by = OrderByExpr()
|
||||
order_by.asc("forget_at")
|
||||
# build search
|
||||
s = Search()
|
||||
s = s.query(bool_query)
|
||||
s = s.sort(order_by)
|
||||
s = s[:limit]
|
||||
q = s.to_dict()
|
||||
# search
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True)
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
self.logger.debug(f"ESConnection.search {str(index_name)} res: " + str(res))
|
||||
return res
|
||||
except ConnectionTimeout:
|
||||
self.logger.exception("ES request timeout")
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
self.logger.exception(f"ESConnection.search {str(index_name)} query: " + str(q) + str(e))
|
||||
raise e
|
||||
|
||||
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
|
||||
raise Exception("ESConnection.search timeout.")
|
||||
|
||||
def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = self.es.get(index=index_name,
|
||||
id=doc_id, source=True, )
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
message = res["_source"]
|
||||
message["id"] = doc_id
|
||||
return self.get_message_from_es_doc(message)
|
||||
except NotFoundError:
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.exception(f"ESConnection.get({doc_id}) got exception")
|
||||
raise e
|
||||
self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
|
||||
raise Exception("ESConnection.get timeout.")
|
||||
|
||||
def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
|
||||
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
|
||||
operations = []
|
||||
for d in documents:
|
||||
assert "_id" not in d
|
||||
assert "id" in d
|
||||
d_copy_raw = copy.deepcopy(d)
|
||||
d_copy = self.map_message_to_es_fields(d_copy_raw)
|
||||
d_copy["memory_id"] = memory_id
|
||||
meta_id = d_copy.pop("id", "")
|
||||
operations.append(
|
||||
{"index": {"_index": index_name, "_id": meta_id}})
|
||||
operations.append(d_copy)
|
||||
res = []
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = []
|
||||
r = self.es.bulk(index=index_name, operations=operations,
|
||||
refresh=False, timeout="60s")
|
||||
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
||||
return res
|
||||
|
||||
for item in r["items"]:
|
||||
for action in ["create", "delete", "index", "update"]:
|
||||
if action in item and "error" in item[action]:
|
||||
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
|
||||
return res
|
||||
except ConnectionTimeout:
|
||||
self.logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
res.append(str(e))
|
||||
self.logger.warning("ESConnection.insert got exception: " + str(e))
|
||||
|
||||
return res
|
||||
|
||||
def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
|
||||
doc = copy.deepcopy(new_value)
|
||||
update_dict = {self.convert_field_name(k): v for k, v in doc.items()}
|
||||
update_dict.pop("id", None)
|
||||
condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
|
||||
condition_dict["memory_id"] = memory_id
|
||||
if "id" in condition_dict and isinstance(condition_dict["id"], str):
|
||||
# update specific single document
|
||||
message_id = condition_dict["id"]
|
||||
for i in range(ATTEMPT_TIME):
|
||||
for k in update_dict.keys():
|
||||
if "feas" != k.split("_")[-1]:
|
||||
continue
|
||||
try:
|
||||
self.es.update(index=index_name, id=message_id, script=f"ctx._source.remove(\"{k}\");")
|
||||
except Exception:
|
||||
self.logger.exception(f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
|
||||
try:
|
||||
self.es.update(index=index_name, id=message_id, doc=update_dict)
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.exception(
|
||||
f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
|
||||
break
|
||||
return False
|
||||
|
||||
# update unspecific maybe-multiple documents
|
||||
bool_query = Q("bool")
|
||||
for k, v in condition_dict.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
continue
|
||||
if k == "exists":
|
||||
bool_query.filter.append(Q("exists", field=v))
|
||||
continue
|
||||
if isinstance(v, list):
|
||||
bool_query.filter.append(Q("terms", **{k: v}))
|
||||
elif isinstance(v, str) or isinstance(v, int):
|
||||
bool_query.filter.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception(
|
||||
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
||||
scripts = []
|
||||
params = {}
|
||||
for k, v in update_dict.items():
|
||||
if k == "remove":
|
||||
if isinstance(v, str):
|
||||
scripts.append(f"ctx._source.remove('{v}');")
|
||||
if isinstance(v, dict):
|
||||
for kk, vv in v.items():
|
||||
scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
|
||||
params[f"p_{kk}"] = vv
|
||||
continue
|
||||
if k == "add":
|
||||
if isinstance(v, dict):
|
||||
for kk, vv in v.items():
|
||||
scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
|
||||
params[f"pp_{kk}"] = vv.strip()
|
||||
continue
|
||||
if (not isinstance(k, str) or not v) and k != "status_int":
|
||||
continue
|
||||
if isinstance(v, str):
|
||||
v = re.sub(r"(['\n\r]|\\.)", " ", v)
|
||||
params[f"pp_{k}"] = v
|
||||
scripts.append(f"ctx._source.{k}=params.pp_{k};")
|
||||
elif isinstance(v, int) or isinstance(v, float):
|
||||
scripts.append(f"ctx._source.{k}={v};")
|
||||
elif isinstance(v, list):
|
||||
scripts.append(f"ctx._source.{k}=params.pp_{k};")
|
||||
params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False)
|
||||
else:
|
||||
raise Exception(
|
||||
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
|
||||
ubq = UpdateByQuery(
|
||||
index=index_name).using(
|
||||
self.es).query(bool_query)
|
||||
ubq = ubq.script(source="".join(scripts), params=params)
|
||||
ubq = ubq.params(refresh=True)
|
||||
ubq = ubq.params(slices=5)
|
||||
ubq = ubq.params(conflicts="proceed")
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
_ = ubq.execute()
|
||||
return True
|
||||
except ConnectionTimeout:
|
||||
self.logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
|
||||
break
|
||||
return False
|
||||
|
||||
def delete(self, condition: dict, index_name: str, memory_id: str) -> int:
|
||||
assert "_id" not in condition
|
||||
condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
|
||||
condition_dict["memory_id"] = memory_id
|
||||
if "id" in condition_dict:
|
||||
message_ids = condition_dict["id"]
|
||||
if not isinstance(message_ids, list):
|
||||
message_ids = [message_ids]
|
||||
if not message_ids: # when message_ids is empty, delete all
|
||||
qry = Q("match_all")
|
||||
else:
|
||||
qry = Q("ids", values=message_ids)
|
||||
else:
|
||||
qry = Q("bool")
|
||||
for k, v in condition_dict.items():
|
||||
if k == "exists":
|
||||
qry.filter.append(Q("exists", field=v))
|
||||
|
||||
elif k == "must_not":
|
||||
if isinstance(v, dict):
|
||||
for kk, vv in v.items():
|
||||
if kk == "exists":
|
||||
qry.must_not.append(Q("exists", field=vv))
|
||||
|
||||
elif isinstance(v, list):
|
||||
qry.must.append(Q("terms", **{k: v}))
|
||||
elif isinstance(v, str) or isinstance(v, int):
|
||||
qry.must.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception("Condition value must be int, str or list.")
|
||||
self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = self.es.delete_by_query(
|
||||
index=index_name,
|
||||
body=Search().query(qry).to_dict(),
|
||||
refresh=True)
|
||||
return res["deleted"]
|
||||
except ConnectionTimeout:
|
||||
self.logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
self.logger.warning("ESConnection.delete got exception: " + str(e))
|
||||
if re.search(r"(not_found)", str(e), re.IGNORECASE):
|
||||
return 0
|
||||
return 0
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
res_fields = {}
|
||||
if not fields:
|
||||
return {}
|
||||
for doc in self._get_source(res):
|
||||
message = self.get_message_from_es_doc(doc)
|
||||
m = {}
|
||||
for n, v in message.items():
|
||||
if n not in fields:
|
||||
continue
|
||||
if isinstance(v, list):
|
||||
m[n] = v
|
||||
continue
|
||||
if n in ["message_id", "source_id", "valid_at", "invalid_at", "forget_at", "status"] and isinstance(v, (int, float, bool)):
|
||||
m[n] = v
|
||||
continue
|
||||
if not isinstance(v, str):
|
||||
m[n] = str(v)
|
||||
else:
|
||||
m[n] = v
|
||||
|
||||
if m:
|
||||
res_fields[doc["id"]] = m
|
||||
return res_fields
|
||||
467
memory/utils/infinity_conn.py
Normal file
467
memory/utils/infinity_conn.py
Normal file
@ -0,0 +1,467 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import re
|
||||
import json
|
||||
import copy
|
||||
from infinity.common import InfinityException, SortType
|
||||
from infinity.errors import ErrorCode
|
||||
|
||||
from common.decorator import singleton
|
||||
import pandas as pd
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from common.doc_store.doc_store_base import MatchExpr, MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr
|
||||
from common.doc_store.infinity_conn_base import InfinityConnectionBase
|
||||
from common.time_utils import date_string_to_timestamp
|
||||
|
||||
|
||||
@singleton
|
||||
class InfinityConnection(InfinityConnectionBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mapping_file_name = "message_infinity_mapping.json"
|
||||
|
||||
"""
|
||||
Dataframe and fields convert
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def field_keyword(field_name: str):
|
||||
# no keywords right now
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def convert_message_field_to_infinity(field_name: str):
|
||||
match field_name:
|
||||
case "message_type":
|
||||
return "message_type_kwd"
|
||||
case "status":
|
||||
return "status_int"
|
||||
case _:
|
||||
return field_name
|
||||
|
||||
@staticmethod
|
||||
def convert_infinity_field_to_message(field_name: str):
|
||||
if field_name.startswith("message_type"):
|
||||
return "message_type"
|
||||
if field_name.startswith("status"):
|
||||
return "status"
|
||||
if re.match(r"q_\d+_vec", field_name):
|
||||
return "content_embed"
|
||||
return field_name
|
||||
|
||||
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
|
||||
return list({self.convert_message_field_to_infinity(f) for f in output_fields})
|
||||
|
||||
@staticmethod
|
||||
def convert_matching_field(field_weight_str: str) -> str:
|
||||
tokens = field_weight_str.split("^")
|
||||
field = tokens[0]
|
||||
if field == "content":
|
||||
field = "content@ft_contentm_rag_fine"
|
||||
tokens[0] = field
|
||||
return "^".join(tokens)
|
||||
|
||||
@staticmethod
|
||||
def convert_condition_and_order_field(field_name: str):
|
||||
match field_name:
|
||||
case "message_type":
|
||||
return "message_type_kwd"
|
||||
case "status":
|
||||
return "status_int"
|
||||
case "valid_at":
|
||||
return "valid_at_flt"
|
||||
case "invalid_at":
|
||||
return "invalid_at_flt"
|
||||
case "forget_at":
|
||||
return "forget_at_flt"
|
||||
case _:
|
||||
return field_name
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
def search(
|
||||
self,
|
||||
select_fields: list[str],
|
||||
highlight_fields: list[str],
|
||||
condition: dict,
|
||||
match_expressions: list[MatchExpr],
|
||||
order_by: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
index_names: str | list[str],
|
||||
memory_ids: list[str],
|
||||
agg_fields: list[str] | None = None,
|
||||
rank_feature: dict | None = None,
|
||||
hide_forgotten: bool = True,
|
||||
) -> tuple[pd.DataFrame, int]:
|
||||
"""
|
||||
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
|
||||
"""
|
||||
if isinstance(index_names, str):
|
||||
index_names = index_names.split(",")
|
||||
assert isinstance(index_names, list) and len(index_names) > 0
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
df_list = list()
|
||||
table_list = list()
|
||||
if hide_forgotten:
|
||||
condition.update({"must_not": {"exists": "forget_at_flt"}})
|
||||
output = select_fields.copy()
|
||||
output = self.convert_select_fields(output)
|
||||
if agg_fields is None:
|
||||
agg_fields = []
|
||||
for essential_field in ["id"] + agg_fields:
|
||||
if essential_field not in output:
|
||||
output.append(essential_field)
|
||||
score_func = ""
|
||||
score_column = ""
|
||||
for matchExpr in match_expressions:
|
||||
if isinstance(matchExpr, MatchTextExpr):
|
||||
score_func = "score()"
|
||||
score_column = "SCORE"
|
||||
break
|
||||
if not score_func:
|
||||
for matchExpr in match_expressions:
|
||||
if isinstance(matchExpr, MatchDenseExpr):
|
||||
score_func = "similarity()"
|
||||
score_column = "SIMILARITY"
|
||||
break
|
||||
if match_expressions:
|
||||
if score_func not in output:
|
||||
output.append(score_func)
|
||||
if PAGERANK_FLD not in output:
|
||||
output.append(PAGERANK_FLD)
|
||||
output = [f for f in output if f != "_score"]
|
||||
if limit <= 0:
|
||||
# ElasticSearch default limit is 10000
|
||||
limit = 10000
|
||||
|
||||
# Prepare expressions common to all tables
|
||||
filter_cond = None
|
||||
filter_fulltext = ""
|
||||
if condition:
|
||||
condition_dict = {self.convert_condition_and_order_field(k): v for k, v in condition.items()}
|
||||
table_found = False
|
||||
for indexName in index_names:
|
||||
for mem_id in memory_ids:
|
||||
table_name = f"{indexName}_{mem_id}"
|
||||
try:
|
||||
filter_cond = self.equivalent_condition_to_str(condition_dict, db_instance.get_table(table_name))
|
||||
table_found = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
if table_found:
|
||||
break
|
||||
if not table_found:
|
||||
self.logger.error(f"No valid tables found for indexNames {index_names} and memoryIds {memory_ids}")
|
||||
return pd.DataFrame(), 0
|
||||
|
||||
for matchExpr in match_expressions:
|
||||
if isinstance(matchExpr, MatchTextExpr):
|
||||
if filter_cond and "filter" not in matchExpr.extra_options:
|
||||
matchExpr.extra_options.update({"filter": filter_cond})
|
||||
matchExpr.fields = [self.convert_matching_field(field) for field in matchExpr.fields]
|
||||
fields = ",".join(matchExpr.fields)
|
||||
filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
|
||||
if filter_cond:
|
||||
filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
|
||||
minimum_should_match = matchExpr.extra_options.get("minimum_should_match", 0.0)
|
||||
if isinstance(minimum_should_match, float):
|
||||
str_minimum_should_match = str(int(minimum_should_match * 100)) + "%"
|
||||
matchExpr.extra_options["minimum_should_match"] = str_minimum_should_match
|
||||
|
||||
# Add rank_feature support
|
||||
if rank_feature and "rank_features" not in matchExpr.extra_options:
|
||||
# Convert rank_feature dict to Infinity's rank_features string format
|
||||
# Format: "field^feature_name^weight,field^feature_name^weight"
|
||||
rank_features_list = []
|
||||
for feature_name, weight in rank_feature.items():
|
||||
# Use TAG_FLD as the field containing rank features
|
||||
rank_features_list.append(f"{TAG_FLD}^{feature_name}^{weight}")
|
||||
if rank_features_list:
|
||||
matchExpr.extra_options["rank_features"] = ",".join(rank_features_list)
|
||||
|
||||
for k, v in matchExpr.extra_options.items():
|
||||
if not isinstance(v, str):
|
||||
matchExpr.extra_options[k] = str(v)
|
||||
self.logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
elif isinstance(matchExpr, MatchDenseExpr):
|
||||
if filter_fulltext and "filter" not in matchExpr.extra_options:
|
||||
matchExpr.extra_options.update({"filter": filter_fulltext})
|
||||
for k, v in matchExpr.extra_options.items():
|
||||
if not isinstance(v, str):
|
||||
matchExpr.extra_options[k] = str(v)
|
||||
similarity = matchExpr.extra_options.get("similarity")
|
||||
if similarity:
|
||||
matchExpr.extra_options["threshold"] = similarity
|
||||
del matchExpr.extra_options["similarity"]
|
||||
self.logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
elif isinstance(matchExpr, FusionExpr):
|
||||
self.logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
|
||||
order_by_expr_list = list()
|
||||
if order_by.fields:
|
||||
for order_field in order_by.fields:
|
||||
order_field_name = self.convert_condition_and_order_field(order_field[0])
|
||||
if order_field[1] == 0:
|
||||
order_by_expr_list.append((order_field_name, SortType.Asc))
|
||||
else:
|
||||
order_by_expr_list.append((order_field_name, SortType.Desc))
|
||||
|
||||
total_hits_count = 0
|
||||
# Scatter search tables and gather the results
|
||||
for indexName in index_names:
|
||||
for memory_id in memory_ids:
|
||||
table_name = f"{indexName}_{memory_id}"
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except Exception:
|
||||
continue
|
||||
table_list.append(table_name)
|
||||
builder = table_instance.output(output)
|
||||
if len(match_expressions) > 0:
|
||||
for matchExpr in match_expressions:
|
||||
if isinstance(matchExpr, MatchTextExpr):
|
||||
fields = ",".join(matchExpr.fields)
|
||||
builder = builder.match_text(
|
||||
fields,
|
||||
matchExpr.matching_text,
|
||||
matchExpr.topn,
|
||||
matchExpr.extra_options.copy(),
|
||||
)
|
||||
elif isinstance(matchExpr, MatchDenseExpr):
|
||||
builder = builder.match_dense(
|
||||
matchExpr.vector_column_name,
|
||||
matchExpr.embedding_data,
|
||||
matchExpr.embedding_data_type,
|
||||
matchExpr.distance_type,
|
||||
matchExpr.topn,
|
||||
matchExpr.extra_options.copy(),
|
||||
)
|
||||
elif isinstance(matchExpr, FusionExpr):
|
||||
builder = builder.fusion(matchExpr.method, matchExpr.topn, matchExpr.fusion_params)
|
||||
else:
|
||||
if filter_cond and len(filter_cond) > 0:
|
||||
builder.filter(filter_cond)
|
||||
if order_by.fields:
|
||||
builder.sort(order_by_expr_list)
|
||||
builder.offset(offset).limit(limit)
|
||||
mem_res, extra_result = builder.option({"total_hits_count": True}).to_df()
|
||||
if extra_result:
|
||||
total_hits_count += int(extra_result["total_hits_count"])
|
||||
self.logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(mem_res)}")
|
||||
df_list.append(mem_res)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
res = self.concat_dataframes(df_list, output)
|
||||
if match_expressions:
|
||||
res["_score"] = res[score_column] + res[PAGERANK_FLD]
|
||||
res = res.sort_values(by="_score", ascending=False).reset_index(drop=True)
|
||||
res = res.head(limit)
|
||||
self.logger.debug(f"INFINITY search final result: {str(res)}")
|
||||
return res, total_hits_count
|
||||
|
||||
def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int=2000):
|
||||
condition = {"memory_id": memory_id, "exists": "forget_at_flt"}
|
||||
order_by = OrderByExpr()
|
||||
order_by.asc("forget_at_flt")
|
||||
# query
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
table_name = f"{index_name}_{memory_id}"
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
output_fields = [self.convert_message_field_to_infinity(f) for f in select_fields]
|
||||
builder = table_instance.output(output_fields)
|
||||
filter_cond = self.equivalent_condition_to_str(condition, db_instance.get_table(table_name))
|
||||
builder.filter(filter_cond)
|
||||
order_by_expr_list = list()
|
||||
if order_by.fields:
|
||||
for order_field in order_by.fields:
|
||||
order_field_name = self.convert_condition_and_order_field(order_field[0])
|
||||
if order_field[1] == 0:
|
||||
order_by_expr_list.append((order_field_name, SortType.Asc))
|
||||
else:
|
||||
order_by_expr_list.append((order_field_name, SortType.Desc))
|
||||
builder.sort(order_by_expr_list)
|
||||
builder.offset(0).limit(limit)
|
||||
mem_res, _ = builder.option({"total_hits_count": True}).to_df()
|
||||
res = self.concat_dataframes(mem_res, output_fields)
|
||||
res.head(limit)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
return res
|
||||
|
||||
def get(self, message_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
df_list = list()
|
||||
assert isinstance(memory_ids, list)
|
||||
table_list = list()
|
||||
for memoryId in memory_ids:
|
||||
table_name = f"{index_name}_{memoryId}"
|
||||
table_list.append(table_name)
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except Exception:
|
||||
self.logger.warning(f"Table not found: {table_name}, this memory isn't created in Infinity. Maybe it is created in other document engine.")
|
||||
continue
|
||||
mem_res, _ = table_instance.output(["*"]).filter(f"id = '{message_id}'").to_df()
|
||||
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(mem_res)}")
|
||||
df_list.append(mem_res)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
res = self.concat_dataframes(df_list, ["id"])
|
||||
fields = set(res.columns.tolist())
|
||||
res_fields = self.get_fields(res, list(fields))
|
||||
return res_fields.get(message_id, None)
|
||||
|
||||
def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
|
||||
if not documents:
|
||||
return []
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
table_name = f"{index_name}_{memory_id}"
|
||||
vector_size = int(len(documents[0]["content_embed"]))
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except InfinityException as e:
|
||||
# src/common/status.cppm, kTableNotExist = 3022
|
||||
if e.error_code != ErrorCode.TABLE_NOT_EXIST:
|
||||
raise
|
||||
if vector_size == 0:
|
||||
raise ValueError("Cannot infer vector size from documents")
|
||||
self.create_idx(index_name, memory_id, vector_size)
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
|
||||
# embedding fields can't have a default value....
|
||||
embedding_columns = []
|
||||
table_columns = table_instance.show_columns().rows()
|
||||
for n, ty, _, _ in table_columns:
|
||||
r = re.search(r"Embedding\([a-z]+,([0-9]+)\)", ty)
|
||||
if not r:
|
||||
continue
|
||||
embedding_columns.append((n, int(r.group(1))))
|
||||
|
||||
docs = copy.deepcopy(documents)
|
||||
for d in docs:
|
||||
assert "_id" not in d
|
||||
assert "id" in d
|
||||
for k, v in list(d.items()):
|
||||
field_name = self.convert_message_field_to_infinity(k)
|
||||
if field_name in ["valid_at", "invalid_at", "forget_at"]:
|
||||
d[f"{field_name}_flt"] = date_string_to_timestamp(v) if v else 0
|
||||
if v is None:
|
||||
d[field_name] = ""
|
||||
elif self.field_keyword(k):
|
||||
if isinstance(v, list):
|
||||
d[k] = "###".join(v)
|
||||
else:
|
||||
d[k] = v
|
||||
elif k == "memory_id":
|
||||
if isinstance(d[k], list):
|
||||
d[k] = d[k][0] # since d[k] is a list, but we need a str
|
||||
elif field_name == "content_embed":
|
||||
d[f"q_{vector_size}_vec"] = d["content_embed"]
|
||||
d.pop("content_embed")
|
||||
else:
|
||||
d[field_name] = v
|
||||
if k != field_name:
|
||||
d.pop(k)
|
||||
|
||||
for n, vs in embedding_columns:
|
||||
if n in d:
|
||||
continue
|
||||
d[n] = [0] * vs
|
||||
ids = ["'{}'".format(d["id"]) for d in docs]
|
||||
str_ids = ", ".join(ids)
|
||||
str_filter = f"id IN ({str_ids})"
|
||||
table_instance.delete(str_filter)
|
||||
table_instance.insert(docs)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
self.logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
|
||||
return []
|
||||
|
||||
def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
table_name = f"{index_name}_{memory_id}"
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
|
||||
columns = {}
|
||||
if table_instance:
|
||||
for n, ty, de, _ in table_instance.show_columns().rows():
|
||||
columns[n] = (ty, de)
|
||||
condition_dict = {self.convert_condition_and_order_field(k): v for k, v in condition.items()}
|
||||
filter = self.equivalent_condition_to_str(condition_dict, table_instance)
|
||||
update_dict = {self.convert_message_field_to_infinity(k): v for k, v in new_value.items()}
|
||||
date_floats = {}
|
||||
for k, v in update_dict.items():
|
||||
if k in ["valid_at", "invalid_at", "forget_at"]:
|
||||
date_floats[f"{k}_flt"] = date_string_to_timestamp(v) if v else 0
|
||||
elif self.field_keyword(k):
|
||||
if isinstance(v, list):
|
||||
update_dict[k] = "###".join(v)
|
||||
else:
|
||||
update_dict[k] = v
|
||||
elif k == "memory_id":
|
||||
if isinstance(update_dict[k], list):
|
||||
update_dict[k] = update_dict[k][0] # since d[k] is a list, but we need a str
|
||||
else:
|
||||
update_dict[k] = v
|
||||
if date_floats:
|
||||
update_dict.update(date_floats)
|
||||
|
||||
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
|
||||
table_instance.update(filter, update_dict)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
return True
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
if not fields:
|
||||
return {}
|
||||
fields_all = fields.copy()
|
||||
fields_all.append("id")
|
||||
fields_all = {self.convert_message_field_to_infinity(f) for f in fields_all}
|
||||
|
||||
column_map = {col.lower(): col for col in res.columns}
|
||||
matched_columns = {column_map[col.lower()]: col for col in fields_all if col.lower() in column_map}
|
||||
none_columns = [col for col in fields_all if col.lower() not in column_map]
|
||||
|
||||
res2 = res[matched_columns.keys()]
|
||||
res2 = res2.rename(columns=matched_columns)
|
||||
res2.drop_duplicates(subset=["id"], inplace=True)
|
||||
|
||||
for column in list(res2.columns):
|
||||
k = column.lower()
|
||||
if self.field_keyword(k):
|
||||
res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
|
||||
else:
|
||||
pass
|
||||
for column in ["content"]:
|
||||
if column in res2:
|
||||
del res2[column]
|
||||
for column in none_columns:
|
||||
res2[column] = None
|
||||
|
||||
res_dict = res2.set_index("id").to_dict(orient="index")
|
||||
return {_id: {self.convert_infinity_field_to_message(k): v for k, v in doc.items()} for _id, doc in res_dict.items()}
|
||||
37
memory/utils/msg_util.py
Normal file
37
memory/utils/msg_util.py
Normal file
@ -0,0 +1,37 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
|
||||
|
||||
def get_json_result_from_llm_response(response_str: str) -> dict:
|
||||
"""
|
||||
Parse the LLM response string to extract JSON content.
|
||||
The function looks for the first and last curly braces to identify the JSON part.
|
||||
If parsing fails, it returns an empty dictionary.
|
||||
|
||||
:param response_str: The response string from the LLM.
|
||||
:return: A dictionary parsed from the JSON content in the response.
|
||||
"""
|
||||
try:
|
||||
clean_str = response_str.strip()
|
||||
if clean_str.startswith('```json'):
|
||||
clean_str = clean_str[7:] # Remove the starting ```json
|
||||
if clean_str.endswith('```'):
|
||||
clean_str = clean_str[:-3] # Remove the ending ```
|
||||
|
||||
return json.loads(clean_str.strip())
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
return {}
|
||||
201
memory/utils/prompt_util.py
Normal file
201
memory/utils/prompt_util.py
Normal file
@ -0,0 +1,201 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import Optional, List
|
||||
|
||||
from common.constants import MemoryType
|
||||
from common.time_utils import current_timestamp
|
||||
|
||||
class PromptAssembler:
|
||||
|
||||
SYSTEM_BASE_TEMPLATE = """**Memory Extraction Specialist**
|
||||
You are an expert at analyzing conversations to extract structured memory.
|
||||
|
||||
{type_specific_instructions}
|
||||
|
||||
|
||||
**OUTPUT REQUIREMENTS:**
|
||||
1. Output MUST be valid JSON
|
||||
2. Follow the specified output format exactly
|
||||
3. Each extracted item MUST have: content, valid_at, invalid_at
|
||||
4. Timestamps in {timestamp_format} format
|
||||
5. Only extract memory types specified above
|
||||
6. Maximum {max_items} items per type
|
||||
"""
|
||||
|
||||
TYPE_INSTRUCTIONS = {
|
||||
MemoryType.SEMANTIC.name.lower(): """
|
||||
**EXTRACT SEMANTIC KNOWLEDGE:**
|
||||
- Universal facts, definitions, concepts, relationships
|
||||
- Time-invariant, generally true information
|
||||
- Examples: "The capital of France is Paris", "Water boils at 100°C"
|
||||
|
||||
**Timestamp Rules for Semantic Knowledge:**
|
||||
- valid_at: When the fact became true (e.g., law enactment, discovery)
|
||||
- invalid_at: When it becomes false (e.g., repeal, disproven) or empty if still true
|
||||
- Default: valid_at = conversation time, invalid_at = "" for timeless facts
|
||||
""",
|
||||
|
||||
MemoryType.EPISODIC.name.lower(): """
|
||||
**EXTRACT EPISODIC KNOWLEDGE:**
|
||||
- Specific experiences, events, personal stories
|
||||
- Time-bound, person-specific, contextual
|
||||
- Examples: "Yesterday I fixed the bug", "User reported issue last week"
|
||||
|
||||
**Timestamp Rules for Episodic Knowledge:**
|
||||
- valid_at: Event start/occurrence time
|
||||
- invalid_at: Event end time or empty if instantaneous
|
||||
- Extract explicit times: "at 3 PM", "last Monday", "from X to Y"
|
||||
""",
|
||||
|
||||
MemoryType.PROCEDURAL.name.lower(): """
|
||||
**EXTRACT PROCEDURAL KNOWLEDGE:**
|
||||
- Processes, methods, step-by-step instructions
|
||||
- Goal-oriented, actionable, often includes conditions
|
||||
- Examples: "To reset password, click...", "Debugging steps: 1)..."
|
||||
|
||||
**Timestamp Rules for Procedural Knowledge:**
|
||||
- valid_at: When procedure becomes valid/effective
|
||||
- invalid_at: When it expires/becomes obsolete or empty if current
|
||||
- For version-specific: use release dates
|
||||
- For best practices: invalid_at = ""
|
||||
"""
|
||||
}
|
||||
|
||||
OUTPUT_TEMPLATES = {
|
||||
MemoryType.SEMANTIC.name.lower(): """
|
||||
"semantic": [
|
||||
{
|
||||
"content": "Clear factual statement",
|
||||
"valid_at": "timestamp or empty",
|
||||
"invalid_at": "timestamp or empty"
|
||||
}
|
||||
]
|
||||
""",
|
||||
|
||||
MemoryType.EPISODIC.name.lower(): """
|
||||
"episodic": [
|
||||
{
|
||||
"content": "Narrative event description",
|
||||
"valid_at": "event start timestamp",
|
||||
"invalid_at": "event end timestamp or empty"
|
||||
}
|
||||
]
|
||||
""",
|
||||
|
||||
MemoryType.PROCEDURAL.name.lower(): """
|
||||
"procedural": [
|
||||
{
|
||||
"content": "Actionable instructions",
|
||||
"valid_at": "procedure effective timestamp",
|
||||
"invalid_at": "procedure expiration timestamp or empty"
|
||||
}
|
||||
]
|
||||
"""
|
||||
}
|
||||
|
||||
BASE_USER_PROMPT = """
|
||||
**CONVERSATION:**
|
||||
{conversation}
|
||||
|
||||
**CONVERSATION TIME:** {conversation_time}
|
||||
**CURRENT TIME:** {current_time}
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def assemble_system_prompt(cls, config: dict) -> str:
|
||||
types_to_extract = cls._get_types_to_extract(config["memory_type"])
|
||||
|
||||
type_instructions = cls._generate_type_instructions(types_to_extract)
|
||||
|
||||
output_format = cls._generate_output_format(types_to_extract)
|
||||
|
||||
full_prompt = cls.SYSTEM_BASE_TEMPLATE.format(
|
||||
type_specific_instructions=type_instructions,
|
||||
timestamp_format=config.get("timestamp_format", "ISO 8601"),
|
||||
max_items=config.get("max_items_per_type", 5)
|
||||
)
|
||||
|
||||
full_prompt += f"\n**REQUIRED OUTPUT FORMAT (JSON):**\n```json\n{{\n{output_format}\n}}\n```\n"
|
||||
|
||||
examples = cls._generate_examples(types_to_extract)
|
||||
if examples:
|
||||
full_prompt += f"\n**EXAMPLES:**\n{examples}\n"
|
||||
|
||||
return full_prompt
|
||||
|
||||
@staticmethod
|
||||
def _get_types_to_extract(requested_types: List[str]) -> List[str]:
|
||||
types = set()
|
||||
for rt in requested_types:
|
||||
if rt in [e.name.lower() for e in MemoryType] and rt != MemoryType.RAW.name.lower():
|
||||
types.add(rt)
|
||||
return list(types)
|
||||
|
||||
@classmethod
|
||||
def _generate_type_instructions(cls, types_to_extract: List[str]) -> str:
|
||||
target_types = set(types_to_extract)
|
||||
instructions = [cls.TYPE_INSTRUCTIONS[mt] for mt in target_types]
|
||||
return "\n".join(instructions)
|
||||
|
||||
@classmethod
|
||||
def _generate_output_format(cls, types_to_extract: List[str]) -> str:
|
||||
target_types = set(types_to_extract)
|
||||
output_parts = [cls.OUTPUT_TEMPLATES[mt] for mt in target_types]
|
||||
return ",\n".join(output_parts)
|
||||
|
||||
@staticmethod
|
||||
def _generate_examples(types_to_extract: list[str]) -> str:
|
||||
examples = []
|
||||
|
||||
if MemoryType.SEMANTIC.name.lower() in types_to_extract:
|
||||
examples.append("""
|
||||
**Semantic Example:**
|
||||
Input: "Python lists are mutable and support various operations."
|
||||
Output: {"semantic": [{"content": "Python lists are mutable data structures", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]}
|
||||
""")
|
||||
|
||||
if MemoryType.EPISODIC.name.lower() in types_to_extract:
|
||||
examples.append("""
|
||||
**Episodic Example:**
|
||||
Input: "I deployed the new feature yesterday afternoon."
|
||||
Output: {"episodic": [{"content": "User deployed new feature", "valid_at": "2024-01-14T14:00:00", "invalid_at": "2024-01-14T18:00:00"}]}
|
||||
""")
|
||||
|
||||
if MemoryType.PROCEDURAL.name.lower() in types_to_extract:
|
||||
examples.append("""
|
||||
**Procedural Example:**
|
||||
Input: "To debug API errors: 1) Check logs 2) Verify endpoints 3) Test connectivity."
|
||||
Output: {"procedural": [{"content": "API error debugging: 1. Check logs 2. Verify endpoints 3. Test connectivity", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]}
|
||||
""")
|
||||
|
||||
return "\n".join(examples)
|
||||
|
||||
@classmethod
|
||||
def assemble_user_prompt(
|
||||
cls,
|
||||
conversation: str,
|
||||
conversation_time: Optional[str] = None,
|
||||
current_time: Optional[str] = None
|
||||
) -> str:
|
||||
return cls.BASE_USER_PROMPT.format(
|
||||
conversation=conversation,
|
||||
conversation_time=conversation_time or "Not specified",
|
||||
current_time=current_time or current_timestamp(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_raw_user_prompt(cls):
|
||||
return cls.BASE_USER_PROMPT
|
||||
@ -46,7 +46,7 @@ dependencies = [
|
||||
"groq==0.9.0",
|
||||
"grpcio-status==1.67.1",
|
||||
"html-text==0.6.2",
|
||||
"infinity-sdk==0.6.11",
|
||||
"infinity-sdk==0.6.13",
|
||||
"infinity-emb>=0.0.66,<0.0.67",
|
||||
"jira==3.10.5",
|
||||
"json-repair==0.35.0",
|
||||
|
||||
@ -77,9 +77,9 @@ class Benchmark:
|
||||
def init_index(self, vector_size: int):
|
||||
if self.initialized_index:
|
||||
return
|
||||
if settings.docStoreConn.indexExist(self.index_name, self.kb_id):
|
||||
settings.docStoreConn.deleteIdx(self.index_name, self.kb_id)
|
||||
settings.docStoreConn.createIdx(self.index_name, self.kb_id, vector_size)
|
||||
if settings.docStoreConn.index_exist(self.index_name, self.kb_id):
|
||||
settings.docStoreConn.delete_idx(self.index_name, self.kb_id)
|
||||
settings.docStoreConn.create_idx(self.index_name, self.kb_id, vector_size)
|
||||
self.initialized_index = True
|
||||
|
||||
def ms_marco_index(self, file_path, index_name):
|
||||
|
||||
@ -37,7 +37,6 @@ from rag.app.naive import Docx
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.flow.parser.schema import ParserFromUpstream
|
||||
from rag.llm.cv_model import Base as VLM
|
||||
from rag.nlp import attach_media_context
|
||||
from rag.utils.base64_image import image2id
|
||||
|
||||
|
||||
@ -86,8 +85,6 @@ class ParserParam(ProcessParamBase):
|
||||
"pdf",
|
||||
],
|
||||
"output_format": "json",
|
||||
"table_context_size": 0,
|
||||
"image_context_size": 0,
|
||||
},
|
||||
"spreadsheet": {
|
||||
"parse_method": "deepdoc", # deepdoc/tcadp_parser
|
||||
@ -97,8 +94,6 @@ class ParserParam(ProcessParamBase):
|
||||
"xlsx",
|
||||
"csv",
|
||||
],
|
||||
"table_context_size": 0,
|
||||
"image_context_size": 0,
|
||||
},
|
||||
"word": {
|
||||
"suffix": [
|
||||
@ -106,14 +101,10 @@ class ParserParam(ProcessParamBase):
|
||||
"docx",
|
||||
],
|
||||
"output_format": "json",
|
||||
"table_context_size": 0,
|
||||
"image_context_size": 0,
|
||||
},
|
||||
"text&markdown": {
|
||||
"suffix": ["md", "markdown", "mdx", "txt"],
|
||||
"output_format": "json",
|
||||
"table_context_size": 0,
|
||||
"image_context_size": 0,
|
||||
},
|
||||
"slides": {
|
||||
"parse_method": "deepdoc", # deepdoc/tcadp_parser
|
||||
@ -122,8 +113,6 @@ class ParserParam(ProcessParamBase):
|
||||
"ppt",
|
||||
],
|
||||
"output_format": "json",
|
||||
"table_context_size": 0,
|
||||
"image_context_size": 0,
|
||||
},
|
||||
"image": {
|
||||
"parse_method": "ocr",
|
||||
@ -357,11 +346,6 @@ class Parser(ProcessBase):
|
||||
elif layout == "table":
|
||||
b["doc_type_kwd"] = "table"
|
||||
|
||||
table_ctx = conf.get("table_context_size", 0) or 0
|
||||
image_ctx = conf.get("image_context_size", 0) or 0
|
||||
if table_ctx or image_ctx:
|
||||
bboxes = attach_media_context(bboxes, table_ctx, image_ctx)
|
||||
|
||||
if conf.get("output_format") == "json":
|
||||
self.set_output("json", bboxes)
|
||||
if conf.get("output_format") == "markdown":
|
||||
@ -436,11 +420,6 @@ class Parser(ProcessBase):
|
||||
if table:
|
||||
result.append({"text": table, "doc_type_kwd": "table"})
|
||||
|
||||
table_ctx = conf.get("table_context_size", 0) or 0
|
||||
image_ctx = conf.get("image_context_size", 0) or 0
|
||||
if table_ctx or image_ctx:
|
||||
result = attach_media_context(result, table_ctx, image_ctx)
|
||||
|
||||
self.set_output("json", result)
|
||||
|
||||
elif output_format == "markdown":
|
||||
@ -476,11 +455,6 @@ class Parser(ProcessBase):
|
||||
sections = [{"text": section[0], "image": section[1]} for section in sections if section]
|
||||
sections.extend([{"text": tb, "image": None, "doc_type_kwd": "table"} for ((_, tb), _) in tbls])
|
||||
|
||||
table_ctx = conf.get("table_context_size", 0) or 0
|
||||
image_ctx = conf.get("image_context_size", 0) or 0
|
||||
if table_ctx or image_ctx:
|
||||
sections = attach_media_context(sections, table_ctx, image_ctx)
|
||||
|
||||
self.set_output("json", sections)
|
||||
elif conf.get("output_format") == "markdown":
|
||||
markdown_text = docx_parser.to_markdown(name, binary=blob)
|
||||
@ -536,11 +510,6 @@ class Parser(ProcessBase):
|
||||
if table:
|
||||
result.append({"text": table, "doc_type_kwd": "table"})
|
||||
|
||||
table_ctx = conf.get("table_context_size", 0) or 0
|
||||
image_ctx = conf.get("image_context_size", 0) or 0
|
||||
if table_ctx or image_ctx:
|
||||
result = attach_media_context(result, table_ctx, image_ctx)
|
||||
|
||||
self.set_output("json", result)
|
||||
else:
|
||||
# Default DeepDOC parser (supports .pptx format)
|
||||
@ -554,10 +523,6 @@ class Parser(ProcessBase):
|
||||
# json
|
||||
assert conf.get("output_format") == "json", "have to be json for ppt"
|
||||
if conf.get("output_format") == "json":
|
||||
table_ctx = conf.get("table_context_size", 0) or 0
|
||||
image_ctx = conf.get("image_context_size", 0) or 0
|
||||
if table_ctx or image_ctx:
|
||||
sections = attach_media_context(sections, table_ctx, image_ctx)
|
||||
self.set_output("json", sections)
|
||||
|
||||
def _markdown(self, name, blob):
|
||||
@ -597,11 +562,6 @@ class Parser(ProcessBase):
|
||||
|
||||
json_results.append(json_result)
|
||||
|
||||
table_ctx = conf.get("table_context_size", 0) or 0
|
||||
image_ctx = conf.get("image_context_size", 0) or 0
|
||||
if table_ctx or image_ctx:
|
||||
json_results = attach_media_context(json_results, table_ctx, image_ctx)
|
||||
|
||||
self.set_output("json", json_results)
|
||||
else:
|
||||
self.set_output("text", "\n".join([section_text for section_text, _ in sections]))
|
||||
|
||||
@ -23,7 +23,7 @@ from rag.utils.base64_image import id2image, image2id
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
from rag.flow.base import ProcessBase, ProcessParamBase
|
||||
from rag.flow.splitter.schema import SplitterFromUpstream
|
||||
from rag.nlp import naive_merge, naive_merge_with_images
|
||||
from rag.nlp import attach_media_context, naive_merge, naive_merge_with_images
|
||||
from common import settings
|
||||
|
||||
|
||||
@ -34,11 +34,15 @@ class SplitterParam(ProcessParamBase):
|
||||
self.delimiters = ["\n"]
|
||||
self.overlapped_percent = 0
|
||||
self.children_delimiters = []
|
||||
self.table_context_size = 0
|
||||
self.image_context_size = 0
|
||||
|
||||
def check(self):
|
||||
self.check_empty(self.delimiters, "Delimiters.")
|
||||
self.check_positive_integer(self.chunk_token_size, "Chunk token size.")
|
||||
self.check_decimal_float(self.overlapped_percent, "Overlapped percentage: [0, 1)")
|
||||
self.check_nonnegative_number(self.table_context_size, "Table context size.")
|
||||
self.check_nonnegative_number(self.image_context_size, "Image context size.")
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
|
||||
return {}
|
||||
@ -103,8 +107,18 @@ class Splitter(ProcessBase):
|
||||
return
|
||||
|
||||
# json
|
||||
json_result = from_upstream.json_result or []
|
||||
if self._param.table_context_size or self._param.image_context_size:
|
||||
for ck in json_result:
|
||||
if "image" not in ck and ck.get("img_id") and not (isinstance(ck.get("text"), str) and ck.get("text").strip()):
|
||||
ck["image"] = True
|
||||
attach_media_context(json_result, self._param.table_context_size, self._param.image_context_size)
|
||||
for ck in json_result:
|
||||
if ck.get("image") is True:
|
||||
del ck["image"]
|
||||
|
||||
sections, section_images = [], []
|
||||
for o in from_upstream.json_result or []:
|
||||
for o in json_result:
|
||||
sections.append((o.get("text", ""), o.get("position_tag", "")))
|
||||
section_images.append(id2image(o.get("img_id"), partial(settings.STORAGE_IMPL.get, tenant_id=self._canvas._tenant_id)))
|
||||
|
||||
|
||||
@ -105,6 +105,9 @@ class Tokenizer(ProcessBase):
|
||||
|
||||
async def _invoke(self, **kwargs):
|
||||
try:
|
||||
chunks = kwargs.get("chunks")
|
||||
kwargs["chunks"] = [c for c in chunks if c is not None]
|
||||
|
||||
from_upstream = TokenizerFromUpstream.model_validate(kwargs)
|
||||
except Exception as e:
|
||||
self.set_output("_ERROR", f"Input error: {str(e)}")
|
||||
|
||||
@ -348,7 +348,8 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
|
||||
d["doc_type_kwd"] = "table"
|
||||
if img:
|
||||
d["image"] = img
|
||||
d["doc_type_kwd"] = "image"
|
||||
if d["content_with_weight"].find("<tr>") < 0:
|
||||
d["doc_type_kwd"] = "image"
|
||||
if poss:
|
||||
add_positions(d, poss)
|
||||
res.append(d)
|
||||
@ -361,7 +362,8 @@ def tokenize_table(tbls, doc, eng, batch_size=10):
|
||||
d["doc_type_kwd"] = "table"
|
||||
if img:
|
||||
d["image"] = img
|
||||
d["doc_type_kwd"] = "image"
|
||||
if d["content_with_weight"].find("<tr>") < 0:
|
||||
d["doc_type_kwd"] = "image"
|
||||
add_positions(d, poss)
|
||||
res.append(d)
|
||||
return res
|
||||
|
||||
@ -19,11 +19,12 @@ import json
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
from rag.utils.doc_store_conn import MatchTextExpr
|
||||
from common.query_base import QueryBase
|
||||
from common.doc_store.doc_store_base import MatchTextExpr
|
||||
from rag.nlp import rag_tokenizer, term_weight, synonym
|
||||
|
||||
|
||||
class FulltextQueryer:
|
||||
class FulltextQueryer(QueryBase):
|
||||
def __init__(self):
|
||||
self.tw = term_weight.Dealer()
|
||||
self.syn = synonym.Dealer()
|
||||
@ -37,64 +38,19 @@ class FulltextQueryer:
|
||||
"content_sm_ltks",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def sub_special_char(line):
|
||||
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
|
||||
|
||||
@staticmethod
|
||||
def is_chinese(line):
|
||||
arr = re.split(r"[ \t]+", line)
|
||||
if len(arr) <= 3:
|
||||
return True
|
||||
e = 0
|
||||
for t in arr:
|
||||
if not re.match(r"[a-zA-Z]+$", t):
|
||||
e += 1
|
||||
return e * 1.0 / len(arr) >= 0.7
|
||||
|
||||
@staticmethod
|
||||
def rmWWW(txt):
|
||||
patts = [
|
||||
(
|
||||
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
|
||||
"",
|
||||
),
|
||||
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
||||
(
|
||||
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
|
||||
" ")
|
||||
]
|
||||
otxt = txt
|
||||
for r, p in patts:
|
||||
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
||||
if not txt:
|
||||
txt = otxt
|
||||
return txt
|
||||
|
||||
@staticmethod
|
||||
def add_space_between_eng_zh(txt):
|
||||
# (ENG/ENG+NUM) + ZH
|
||||
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
|
||||
# ENG + ZH
|
||||
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
|
||||
# ZH + (ENG/ENG+NUM)
|
||||
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
|
||||
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
|
||||
return txt
|
||||
|
||||
def question(self, txt, tbl="qa", min_match: float = 0.6):
|
||||
original_query = txt
|
||||
txt = FulltextQueryer.add_space_between_eng_zh(txt)
|
||||
txt = self.add_space_between_eng_zh(txt)
|
||||
txt = re.sub(
|
||||
r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+",
|
||||
" ",
|
||||
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
|
||||
).strip()
|
||||
otxt = txt
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
txt = self.rmWWW(txt)
|
||||
|
||||
if not self.is_chinese(txt):
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
txt = self.rmWWW(txt)
|
||||
tks = rag_tokenizer.tokenize(txt).split()
|
||||
keywords = [t for t in tks if t]
|
||||
tks_w = self.tw.weights(tks, preprocess=False)
|
||||
@ -138,7 +94,7 @@ class FulltextQueryer:
|
||||
return False
|
||||
return True
|
||||
|
||||
txt = FulltextQueryer.rmWWW(txt)
|
||||
txt = self.rmWWW(txt)
|
||||
qs, keywords = [], []
|
||||
for tt in self.tw.split(txt)[:256]: # .split():
|
||||
if not tt:
|
||||
@ -164,7 +120,7 @@ class FulltextQueryer:
|
||||
)
|
||||
for m in sm
|
||||
]
|
||||
sm = [FulltextQueryer.sub_special_char(m) for m in sm if len(m) > 1]
|
||||
sm = [self.sub_special_char(m) for m in sm if len(m) > 1]
|
||||
sm = [m for m in sm if len(m) > 1]
|
||||
|
||||
if len(keywords) < 32:
|
||||
@ -172,7 +128,7 @@ class FulltextQueryer:
|
||||
keywords.extend(sm)
|
||||
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
|
||||
tk_syns = [self.sub_special_char(s) for s in tk_syns]
|
||||
if len(keywords) < 32:
|
||||
keywords.extend([s for s in tk_syns if s])
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
@ -181,7 +137,7 @@ class FulltextQueryer:
|
||||
if len(keywords) >= 32:
|
||||
break
|
||||
|
||||
tk = FulltextQueryer.sub_special_char(tk)
|
||||
tk = self.sub_special_char(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
@ -199,7 +155,7 @@ class FulltextQueryer:
|
||||
syns = " OR ".join(
|
||||
[
|
||||
'"%s"'
|
||||
% rag_tokenizer.tokenize(FulltextQueryer.sub_special_char(s))
|
||||
% rag_tokenizer.tokenize(self.sub_special_char(s))
|
||||
for s in syns
|
||||
]
|
||||
)
|
||||
@ -264,10 +220,10 @@ class FulltextQueryer:
|
||||
keywords = [f'"{k.strip()}"' for k in keywords]
|
||||
for tk, w in sorted(tks_w, key=lambda x: x[1] * -1)[:keywords_topn]:
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [FulltextQueryer.sub_special_char(s) for s in tk_syns]
|
||||
tk_syns = [self.sub_special_char(s) for s in tk_syns]
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
||||
tk = FulltextQueryer.sub_special_char(tk)
|
||||
tk = self.sub_special_char(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
|
||||
@ -24,7 +24,7 @@ from dataclasses import dataclass
|
||||
from rag.prompts.generator import relevant_chunks_with_toc
|
||||
from rag.nlp import rag_tokenizer, query
|
||||
import numpy as np
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchDenseExpr, FusionExpr, OrderByExpr
|
||||
from common.doc_store.doc_store_base import MatchDenseExpr, FusionExpr, OrderByExpr, DocStoreConnection
|
||||
from common.string_utils import remove_redundant_spaces
|
||||
from common.float_utils import get_float
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
@ -155,7 +155,7 @@ class Dealer:
|
||||
kwds.add(kk)
|
||||
|
||||
logging.debug(f"TOTAL: {total}")
|
||||
ids = self.dataStore.get_chunk_ids(res)
|
||||
ids = self.dataStore.get_doc_ids(res)
|
||||
keywords = list(kwds)
|
||||
highlight = self.dataStore.get_highlight(res, keywords, "content_with_weight")
|
||||
aggs = self.dataStore.get_aggregation(res, "docnm_kwd")
|
||||
@ -545,7 +545,7 @@ class Dealer:
|
||||
return res
|
||||
|
||||
def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000):
|
||||
if not self.dataStore.indexExist(index_name(tenant_id), kb_ids[0]):
|
||||
if not self.dataStore.index_exist(index_name(tenant_id), kb_ids[0]):
|
||||
return []
|
||||
res = self.dataStore.search([], [], {}, [], OrderByExpr(), 0, 0, index_name(tenant_id), kb_ids, ["tag_kwd"])
|
||||
return self.dataStore.get_aggregation(res, "tag_kwd")
|
||||
|
||||
@ -136,6 +136,19 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False):
|
||||
return knowledges
|
||||
|
||||
|
||||
def memory_prompt(message_list, max_tokens):
|
||||
used_token_count = 0
|
||||
content_list = []
|
||||
for message in message_list:
|
||||
current_content_tokens = num_tokens_from_string(message["content"])
|
||||
if used_token_count + current_content_tokens > max_tokens * 0.97:
|
||||
logging.warning(f"Not all the retrieval into prompt: {len(content_list)}/{len(message_list)}")
|
||||
break
|
||||
content_list.append(message["content"])
|
||||
used_token_count += current_content_tokens
|
||||
return content_list
|
||||
|
||||
|
||||
CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
|
||||
CITATION_PLUS_TEMPLATE = load_prompt("citation_plus")
|
||||
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
|
||||
@ -326,7 +339,7 @@ def tool_schema(tools_description: list[dict], complete_task=False):
|
||||
}
|
||||
for idx, tool in enumerate(tools_description):
|
||||
name = tool["function"]["name"]
|
||||
desc[f"{name}_{idx}"] = tool
|
||||
desc[name] = tool
|
||||
|
||||
return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])
|
||||
|
||||
|
||||
@ -94,7 +94,7 @@ This content will NOT be shown to the user.
|
||||
## Step 2: Structured Reflection (MANDATORY before `complete_task`)
|
||||
|
||||
### Context
|
||||
- Goal: {{ task_analysis }}
|
||||
- Goal: Reflect on the current task based on the full conversation context
|
||||
- Executed tool calls so far (if any): reflect from conversation history
|
||||
|
||||
### Task Complexity Assessment
|
||||
|
||||
@ -395,9 +395,9 @@ async def build_chunks(task, progress_callback):
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
metadata = {}
|
||||
for ck in cks:
|
||||
metadata = update_metadata_to(metadata, ck["metadata_obj"])
|
||||
del ck["metadata_obj"]
|
||||
for doc in docs:
|
||||
metadata = update_metadata_to(metadata, doc["metadata_obj"])
|
||||
del doc["metadata_obj"]
|
||||
if metadata:
|
||||
e, doc = DocumentService.get_by_id(task["doc_id"])
|
||||
if e:
|
||||
@ -506,7 +506,7 @@ def build_TOC(task, docs, progress_callback):
|
||||
|
||||
def init_kb(row, vector_size: int):
|
||||
idxnm = search.index_name(row["tenant_id"])
|
||||
return settings.docStoreConn.createIdx(idxnm, row.get("kb_id", ""), vector_size)
|
||||
return settings.docStoreConn.create_idx(idxnm, row.get("kb_id", ""), vector_size)
|
||||
|
||||
|
||||
async def embedding(docs, mdl, parser_config=None, callback=None):
|
||||
|
||||
@ -82,7 +82,7 @@ def id2image(image_id:str|None, storage_get_func: partial):
|
||||
return
|
||||
bkt, nm = image_id.split("-")
|
||||
try:
|
||||
blob = storage_get_func(bucket=bkt, filename=nm)
|
||||
blob = storage_get_func(bucket=bkt, fnm=nm)
|
||||
if not blob:
|
||||
return
|
||||
return Image.open(BytesIO(blob))
|
||||
|
||||
@ -14,194 +14,92 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
|
||||
import copy
|
||||
from elasticsearch import Elasticsearch, NotFoundError
|
||||
from elasticsearch_dsl import UpdateByQuery, Q, Search, Index
|
||||
from elasticsearch_dsl import UpdateByQuery, Q, Search
|
||||
from elastic_transport import ConnectionTimeout
|
||||
from common.decorator import singleton
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common.misc_utils import convert_bytes
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
||||
FusionExpr
|
||||
from rag.nlp import is_english, rag_tokenizer
|
||||
from common.doc_store.doc_store_base import MatchTextExpr, OrderByExpr, MatchExpr, MatchDenseExpr, FusionExpr
|
||||
from common.doc_store.es_conn_base import ESConnectionBase
|
||||
from common.float_utils import get_float
|
||||
from common import settings
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
|
||||
ATTEMPT_TIME = 2
|
||||
|
||||
logger = logging.getLogger('ragflow.es_conn')
|
||||
|
||||
|
||||
@singleton
|
||||
class ESConnection(DocStoreConnection):
|
||||
def __init__(self):
|
||||
self.info = {}
|
||||
logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
if self._connect():
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
|
||||
time.sleep(5)
|
||||
|
||||
if not self.es.ping():
|
||||
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
v = self.info.get("version", {"number": "8.11.3"})
|
||||
v = v["number"].split(".")[0]
|
||||
if int(v) < 8:
|
||||
msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", "mapping.json")
|
||||
if not os.path.exists(fp_mapping):
|
||||
msg = f"Elasticsearch mapping file not found at {fp_mapping}"
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
self.mapping = json.load(open(fp_mapping, "r"))
|
||||
logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
|
||||
|
||||
def _connect(self):
|
||||
self.es = Elasticsearch(
|
||||
settings.ES["hosts"].split(","),
|
||||
basic_auth=(settings.ES["username"], settings.ES[
|
||||
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
|
||||
verify_certs= settings.ES.get("verify_certs", False),
|
||||
timeout=600 )
|
||||
if self.es:
|
||||
self.info = self.es.info()
|
||||
return True
|
||||
return False
|
||||
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
def dbType(self) -> str:
|
||||
return "elasticsearch"
|
||||
|
||||
def health(self) -> dict:
|
||||
health_dict = dict(self.es.cluster.health())
|
||||
health_dict["type"] = "elasticsearch"
|
||||
return health_dict
|
||||
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
if self.indexExist(indexName, knowledgebaseId):
|
||||
return True
|
||||
try:
|
||||
from elasticsearch.client import IndicesClient
|
||||
return IndicesClient(self.es).create(index=indexName,
|
||||
settings=self.mapping["settings"],
|
||||
mappings=self.mapping["mappings"])
|
||||
except Exception:
|
||||
logger.exception("ESConnection.createIndex error %s" % (indexName))
|
||||
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
if len(knowledgebaseId) > 0:
|
||||
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
|
||||
return
|
||||
try:
|
||||
self.es.indices.delete(index=indexName, allow_no_indices=True)
|
||||
except NotFoundError:
|
||||
pass
|
||||
except Exception:
|
||||
logger.exception("ESConnection.deleteIdx error %s" % (indexName))
|
||||
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
|
||||
s = Index(indexName, self.es)
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
return s.exists()
|
||||
except ConnectionTimeout:
|
||||
logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
break
|
||||
return False
|
||||
class ESConnection(ESConnectionBase):
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
def search(
|
||||
self, selectFields: list[str],
|
||||
highlightFields: list[str],
|
||||
self, select_fields: list[str],
|
||||
highlight_fields: list[str],
|
||||
condition: dict,
|
||||
matchExprs: list[MatchExpr],
|
||||
orderBy: OrderByExpr,
|
||||
match_expressions: list[MatchExpr],
|
||||
order_by: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
indexNames: str | list[str],
|
||||
knowledgebaseIds: list[str],
|
||||
aggFields: list[str] = [],
|
||||
index_names: str | list[str],
|
||||
knowledgebase_ids: list[str],
|
||||
agg_fields: list[str] | None = None,
|
||||
rank_feature: dict | None = None
|
||||
):
|
||||
"""
|
||||
Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
|
||||
"""
|
||||
if isinstance(indexNames, str):
|
||||
indexNames = indexNames.split(",")
|
||||
assert isinstance(indexNames, list) and len(indexNames) > 0
|
||||
if isinstance(index_names, str):
|
||||
index_names = index_names.split(",")
|
||||
assert isinstance(index_names, list) and len(index_names) > 0
|
||||
assert "_id" not in condition
|
||||
|
||||
bqry = Q("bool", must=[])
|
||||
condition["kb_id"] = knowledgebaseIds
|
||||
bool_query = Q("bool", must=[])
|
||||
condition["kb_id"] = knowledgebase_ids
|
||||
for k, v in condition.items():
|
||||
if k == "available_int":
|
||||
if v == 0:
|
||||
bqry.filter.append(Q("range", available_int={"lt": 1}))
|
||||
bool_query.filter.append(Q("range", available_int={"lt": 1}))
|
||||
else:
|
||||
bqry.filter.append(
|
||||
bool_query.filter.append(
|
||||
Q("bool", must_not=Q("range", available_int={"lt": 1})))
|
||||
continue
|
||||
if not v:
|
||||
continue
|
||||
if isinstance(v, list):
|
||||
bqry.filter.append(Q("terms", **{k: v}))
|
||||
bool_query.filter.append(Q("terms", **{k: v}))
|
||||
elif isinstance(v, str) or isinstance(v, int):
|
||||
bqry.filter.append(Q("term", **{k: v}))
|
||||
bool_query.filter.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception(
|
||||
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
||||
|
||||
s = Search()
|
||||
vector_similarity_weight = 0.5
|
||||
for m in matchExprs:
|
||||
for m in match_expressions:
|
||||
if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
|
||||
assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(matchExprs[1],
|
||||
MatchDenseExpr) and isinstance(
|
||||
matchExprs[2], FusionExpr)
|
||||
assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) and isinstance(match_expressions[1],
|
||||
MatchDenseExpr) and isinstance(
|
||||
match_expressions[2], FusionExpr)
|
||||
weights = m.fusion_params["weights"]
|
||||
vector_similarity_weight = get_float(weights.split(",")[1])
|
||||
for m in matchExprs:
|
||||
for m in match_expressions:
|
||||
if isinstance(m, MatchTextExpr):
|
||||
minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
|
||||
if isinstance(minimum_should_match, float):
|
||||
minimum_should_match = str(int(minimum_should_match * 100)) + "%"
|
||||
bqry.must.append(Q("query_string", fields=m.fields,
|
||||
bool_query.must.append(Q("query_string", fields=m.fields,
|
||||
type="best_fields", query=m.matching_text,
|
||||
minimum_should_match=minimum_should_match,
|
||||
boost=1))
|
||||
bqry.boost = 1.0 - vector_similarity_weight
|
||||
bool_query.boost = 1.0 - vector_similarity_weight
|
||||
|
||||
elif isinstance(m, MatchDenseExpr):
|
||||
assert (bqry is not None)
|
||||
assert (bool_query is not None)
|
||||
similarity = 0.0
|
||||
if "similarity" in m.extra_options:
|
||||
similarity = m.extra_options["similarity"]
|
||||
@ -209,24 +107,24 @@ class ESConnection(DocStoreConnection):
|
||||
m.topn,
|
||||
m.topn * 2,
|
||||
query_vector=list(m.embedding_data),
|
||||
filter=bqry.to_dict(),
|
||||
filter=bool_query.to_dict(),
|
||||
similarity=similarity,
|
||||
)
|
||||
|
||||
if bqry and rank_feature:
|
||||
if bool_query and rank_feature:
|
||||
for fld, sc in rank_feature.items():
|
||||
if fld != PAGERANK_FLD:
|
||||
fld = f"{TAG_FLD}.{fld}"
|
||||
bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
|
||||
bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))
|
||||
|
||||
if bqry:
|
||||
s = s.query(bqry)
|
||||
for field in highlightFields:
|
||||
if bool_query:
|
||||
s = s.query(bool_query)
|
||||
for field in highlight_fields:
|
||||
s = s.highlight(field)
|
||||
|
||||
if orderBy:
|
||||
if order_by:
|
||||
orders = list()
|
||||
for field, order in orderBy.fields:
|
||||
for field, order in order_by.fields:
|
||||
order = "asc" if order == 0 else "desc"
|
||||
if field in ["page_num_int", "top_int"]:
|
||||
order_info = {"order": order, "unmapped_type": "float",
|
||||
@ -237,19 +135,19 @@ class ESConnection(DocStoreConnection):
|
||||
order_info = {"order": order, "unmapped_type": "text"}
|
||||
orders.append({field: order_info})
|
||||
s = s.sort(*orders)
|
||||
|
||||
for fld in aggFields:
|
||||
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
|
||||
if agg_fields:
|
||||
for fld in agg_fields:
|
||||
s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)
|
||||
|
||||
if limit > 0:
|
||||
s = s[offset:offset + limit]
|
||||
q = s.to_dict()
|
||||
logger.debug(f"ESConnection.search {str(indexNames)} query: " + json.dumps(q))
|
||||
self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))
|
||||
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
#print(json.dumps(q, ensure_ascii=False))
|
||||
res = self.es.search(index=indexNames,
|
||||
res = self.es.search(index=index_names,
|
||||
body=q,
|
||||
timeout="600s",
|
||||
# search_type="dfs_query_then_fetch",
|
||||
@ -257,55 +155,37 @@ class ESConnection(DocStoreConnection):
|
||||
_source=True)
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
logger.debug(f"ESConnection.search {str(indexNames)} res: " + str(res))
|
||||
self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
|
||||
return res
|
||||
except ConnectionTimeout:
|
||||
logger.exception("ES request timeout")
|
||||
self.logger.exception("ES request timeout")
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.exception(f"ESConnection.search {str(indexNames)} query: " + str(q) + str(e))
|
||||
self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e))
|
||||
raise e
|
||||
|
||||
logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
|
||||
self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
|
||||
raise Exception("ESConnection.search timeout.")
|
||||
|
||||
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = self.es.get(index=(indexName),
|
||||
id=chunkId, source=True, )
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
chunk = res["_source"]
|
||||
chunk["id"] = chunkId
|
||||
return chunk
|
||||
except NotFoundError:
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception(f"ESConnection.get({chunkId}) got exception")
|
||||
raise e
|
||||
logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
|
||||
raise Exception("ESConnection.get timeout.")
|
||||
|
||||
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
|
||||
def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]:
|
||||
# Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
|
||||
operations = []
|
||||
for d in documents:
|
||||
assert "_id" not in d
|
||||
assert "id" in d
|
||||
d_copy = copy.deepcopy(d)
|
||||
d_copy["kb_id"] = knowledgebaseId
|
||||
d_copy["kb_id"] = knowledgebase_id
|
||||
meta_id = d_copy.pop("id", "")
|
||||
operations.append(
|
||||
{"index": {"_index": indexName, "_id": meta_id}})
|
||||
{"index": {"_index": index_name, "_id": meta_id}})
|
||||
operations.append(d_copy)
|
||||
|
||||
res = []
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = []
|
||||
r = self.es.bulk(index=(indexName), operations=operations,
|
||||
r = self.es.bulk(index=index_name, operations=operations,
|
||||
refresh=False, timeout="60s")
|
||||
if re.search(r"False", str(r["errors"]), re.IGNORECASE):
|
||||
return res
|
||||
@ -316,58 +196,58 @@ class ESConnection(DocStoreConnection):
|
||||
res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
|
||||
return res
|
||||
except ConnectionTimeout:
|
||||
logger.exception("ES request timeout")
|
||||
self.logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
res.append(str(e))
|
||||
logger.warning("ESConnection.insert got exception: " + str(e))
|
||||
self.logger.warning("ESConnection.insert got exception: " + str(e))
|
||||
|
||||
return res
|
||||
|
||||
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
|
||||
doc = copy.deepcopy(newValue)
|
||||
def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool:
|
||||
doc = copy.deepcopy(new_value)
|
||||
doc.pop("id", None)
|
||||
condition["kb_id"] = knowledgebaseId
|
||||
condition["kb_id"] = knowledgebase_id
|
||||
if "id" in condition and isinstance(condition["id"], str):
|
||||
# update specific single document
|
||||
chunkId = condition["id"]
|
||||
chunk_id = condition["id"]
|
||||
for i in range(ATTEMPT_TIME):
|
||||
for k in doc.keys():
|
||||
if "feas" != k.split("_")[-1]:
|
||||
continue
|
||||
try:
|
||||
self.es.update(index=indexName, id=chunkId, script=f"ctx._source.remove(\"{k}\");")
|
||||
self.es.update(index=index_name, id=chunk_id, script=f"ctx._source.remove(\"{k}\");")
|
||||
except Exception:
|
||||
logger.exception(f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
|
||||
self.logger.exception(f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
|
||||
try:
|
||||
self.es.update(index=indexName, id=chunkId, doc=doc)
|
||||
self.es.update(index=index_name, id=chunk_id, doc=doc)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"ESConnection.update(index={indexName}, id={chunkId}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: "+str(e))
|
||||
self.logger.exception(
|
||||
f"ESConnection.update(index={index_name}, id={chunk_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
|
||||
break
|
||||
return False
|
||||
|
||||
# update unspecific maybe-multiple documents
|
||||
bqry = Q("bool")
|
||||
bool_query = Q("bool")
|
||||
for k, v in condition.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
continue
|
||||
if k == "exists":
|
||||
bqry.filter.append(Q("exists", field=v))
|
||||
bool_query.filter.append(Q("exists", field=v))
|
||||
continue
|
||||
if isinstance(v, list):
|
||||
bqry.filter.append(Q("terms", **{k: v}))
|
||||
bool_query.filter.append(Q("terms", **{k: v}))
|
||||
elif isinstance(v, str) or isinstance(v, int):
|
||||
bqry.filter.append(Q("term", **{k: v}))
|
||||
bool_query.filter.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception(
|
||||
f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
|
||||
scripts = []
|
||||
params = {}
|
||||
for k, v in newValue.items():
|
||||
for k, v in new_value.items():
|
||||
if k == "remove":
|
||||
if isinstance(v, str):
|
||||
scripts.append(f"ctx._source.remove('{v}');")
|
||||
@ -397,8 +277,8 @@ class ESConnection(DocStoreConnection):
|
||||
raise Exception(
|
||||
f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
|
||||
ubq = UpdateByQuery(
|
||||
index=indexName).using(
|
||||
self.es).query(bqry)
|
||||
index=index_name).using(
|
||||
self.es).query(bool_query)
|
||||
ubq = ubq.script(source="".join(scripts), params=params)
|
||||
ubq = ubq.params(refresh=True)
|
||||
ubq = ubq.params(slices=5)
|
||||
@ -409,19 +289,18 @@ class ESConnection(DocStoreConnection):
|
||||
_ = ubq.execute()
|
||||
return True
|
||||
except ConnectionTimeout:
|
||||
logger.exception("ES request timeout")
|
||||
self.logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
|
||||
self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
|
||||
break
|
||||
return False
|
||||
|
||||
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
||||
qry = None
|
||||
def delete(self, condition: dict, index_name: str, knowledgebase_id: str) -> int:
|
||||
assert "_id" not in condition
|
||||
condition["kb_id"] = knowledgebaseId
|
||||
condition["kb_id"] = knowledgebase_id
|
||||
if "id" in condition:
|
||||
chunk_ids = condition["id"]
|
||||
if not isinstance(chunk_ids, list):
|
||||
@ -448,21 +327,21 @@ class ESConnection(DocStoreConnection):
|
||||
qry.must.append(Q("term", **{k: v}))
|
||||
else:
|
||||
raise Exception("Condition value must be int, str or list.")
|
||||
logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
|
||||
self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = self.es.delete_by_query(
|
||||
index=indexName,
|
||||
index=index_name,
|
||||
body=Search().query(qry).to_dict(),
|
||||
refresh=True)
|
||||
return res["deleted"]
|
||||
except ConnectionTimeout:
|
||||
logger.exception("ES request timeout")
|
||||
self.logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning("ESConnection.delete got exception: " + str(e))
|
||||
self.logger.warning("ESConnection.delete got exception: " + str(e))
|
||||
if re.search(r"(not_found)", str(e), re.IGNORECASE):
|
||||
return 0
|
||||
return 0
|
||||
@ -471,27 +350,11 @@ class ESConnection(DocStoreConnection):
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def get_total(self, res):
|
||||
if isinstance(res["hits"]["total"], type({})):
|
||||
return res["hits"]["total"]["value"]
|
||||
return res["hits"]["total"]
|
||||
|
||||
def get_chunk_ids(self, res):
|
||||
return [d["_id"] for d in res["hits"]["hits"]]
|
||||
|
||||
def __getSource(self, res):
|
||||
rr = []
|
||||
for d in res["hits"]["hits"]:
|
||||
d["_source"]["id"] = d["_id"]
|
||||
d["_source"]["_score"] = d["_score"]
|
||||
rr.append(d["_source"])
|
||||
return rr
|
||||
|
||||
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
res_fields = {}
|
||||
if not fields:
|
||||
return {}
|
||||
for d in self.__getSource(res):
|
||||
for d in self._get_source(res):
|
||||
m = {n: d.get(n) for n in fields if d.get(n) is not None}
|
||||
for n, v in m.items():
|
||||
if isinstance(v, list):
|
||||
@ -508,124 +371,3 @@ class ESConnection(DocStoreConnection):
|
||||
if m:
|
||||
res_fields[d["id"]] = m
|
||||
return res_fields
|
||||
|
||||
def get_highlight(self, res, keywords: list[str], fieldnm: str):
|
||||
ans = {}
|
||||
for d in res["hits"]["hits"]:
|
||||
hlts = d.get("highlight")
|
||||
if not hlts:
|
||||
continue
|
||||
txt = "...".join([a for a in list(hlts.items())[0][1]])
|
||||
if not is_english(txt.split()):
|
||||
ans[d["_id"]] = txt
|
||||
continue
|
||||
|
||||
txt = d["_source"][fieldnm]
|
||||
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
|
||||
txts = []
|
||||
for t in re.split(r"[.?!;\n]", txt):
|
||||
for w in keywords:
|
||||
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
|
||||
flags=re.IGNORECASE | re.MULTILINE)
|
||||
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
|
||||
continue
|
||||
txts.append(t)
|
||||
ans[d["_id"]] = "...".join(txts) if txts else "...".join([a for a in list(hlts.items())[0][1]])
|
||||
|
||||
return ans
|
||||
|
||||
def get_aggregation(self, res, fieldnm: str):
|
||||
agg_field = "aggs_" + fieldnm
|
||||
if "aggregations" not in res or agg_field not in res["aggregations"]:
|
||||
return list()
|
||||
bkts = res["aggregations"][agg_field]["buckets"]
|
||||
return [(b["key"], b["doc_count"]) for b in bkts]
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
|
||||
def sql(self, sql: str, fetch_size: int, format: str):
|
||||
logger.debug(f"ESConnection.sql get sql: {sql}")
|
||||
sql = re.sub(r"[ `]+", " ", sql)
|
||||
sql = sql.replace("%", "")
|
||||
replaces = []
|
||||
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
|
||||
fld, v = r.group(1), r.group(3)
|
||||
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
|
||||
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
|
||||
replaces.append(
|
||||
("{}{}'{}'".format(
|
||||
r.group(1),
|
||||
r.group(2),
|
||||
r.group(3)),
|
||||
match))
|
||||
|
||||
for p, r in replaces:
|
||||
sql = sql.replace(p, r, 1)
|
||||
logger.debug(f"ESConnection.sql to es: {sql}")
|
||||
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
|
||||
request_timeout="2s")
|
||||
return res
|
||||
except ConnectionTimeout:
|
||||
logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
|
||||
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
|
||||
logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
|
||||
return None
|
||||
|
||||
def get_cluster_stats(self):
|
||||
"""
|
||||
curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.
|
||||
"""
|
||||
raw_stats = self.es.cluster.stats()
|
||||
logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
|
||||
try:
|
||||
res = {
|
||||
'cluster_name': raw_stats['cluster_name'],
|
||||
'status': raw_stats['status']
|
||||
}
|
||||
indices_status = raw_stats['indices']
|
||||
res.update({
|
||||
'indices': indices_status['count'],
|
||||
'indices_shards': indices_status['shards']['total']
|
||||
})
|
||||
doc_info = indices_status['docs']
|
||||
res.update({
|
||||
'docs': doc_info['count'],
|
||||
'docs_deleted': doc_info['deleted']
|
||||
})
|
||||
store_info = indices_status['store']
|
||||
res.update({
|
||||
'store_size': convert_bytes(store_info['size_in_bytes']),
|
||||
'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes'])
|
||||
})
|
||||
mappings_info = indices_status['mappings']
|
||||
res.update({
|
||||
'mappings_fields': mappings_info['total_field_count'],
|
||||
'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'],
|
||||
'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes'])
|
||||
})
|
||||
node_info = raw_stats['nodes']
|
||||
res.update({
|
||||
'nodes': node_info['count']['total'],
|
||||
'nodes_version': node_info['versions'],
|
||||
'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']),
|
||||
'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']),
|
||||
'os_mem_used_percent': node_info['os']['mem']['used_percent'],
|
||||
'jvm_versions': node_info['jvm']['versions'][0]['vm_version'],
|
||||
'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']),
|
||||
'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes'])
|
||||
})
|
||||
return res
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"ESConnection.get_cluster_stats: {e}")
|
||||
return None
|
||||
|
||||
@ -14,365 +14,125 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import copy
|
||||
import infinity
|
||||
from infinity.common import ConflictType, InfinityException, SortType
|
||||
from infinity.index import IndexInfo, IndexType
|
||||
from infinity.connection_pool import ConnectionPool
|
||||
from infinity.common import InfinityException, SortType
|
||||
from infinity.errors import ErrorCode
|
||||
from common.decorator import singleton
|
||||
import pandas as pd
|
||||
from common.file_utils import get_project_base_directory
|
||||
from rag.nlp import is_english
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from common import settings
|
||||
from rag.utils.doc_store_conn import (
|
||||
DocStoreConnection,
|
||||
MatchExpr,
|
||||
MatchTextExpr,
|
||||
MatchDenseExpr,
|
||||
FusionExpr,
|
||||
OrderByExpr,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("ragflow.infinity_conn")
|
||||
|
||||
|
||||
def field_keyword(field_name: str):
|
||||
# Treat "*_kwd" tag-like columns as keyword lists except knowledge_graph_kwd; source_id is also keyword-like.
|
||||
if field_name == "source_id" or (field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd", "question_kwd"]):
|
||||
return True
|
||||
return False
|
||||
|
||||
def convert_select_fields(output_fields: list[str]) -> list[str]:
|
||||
for i, field in enumerate(output_fields):
|
||||
if field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
|
||||
output_fields[i] = "docnm"
|
||||
elif field in ["important_kwd", "important_tks"]:
|
||||
output_fields[i] = "important_keywords"
|
||||
elif field in ["question_kwd", "question_tks"]:
|
||||
output_fields[i] = "questions"
|
||||
elif field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
|
||||
output_fields[i] = "content"
|
||||
elif field in ["authors_tks", "authors_sm_tks"]:
|
||||
output_fields[i] = "authors"
|
||||
return list(set(output_fields))
|
||||
|
||||
def convert_matching_field(field_weightstr: str) -> str:
|
||||
tokens = field_weightstr.split("^")
|
||||
field = tokens[0]
|
||||
if field == "docnm_kwd" or field == "title_tks":
|
||||
field = "docnm@ft_docnm_rag_coarse"
|
||||
elif field == "title_sm_tks":
|
||||
field = "docnm@ft_docnm_rag_fine"
|
||||
elif field == "important_kwd":
|
||||
field = "important_keywords@ft_important_keywords_rag_coarse"
|
||||
elif field == "important_tks":
|
||||
field = "important_keywords@ft_important_keywords_rag_fine"
|
||||
elif field == "question_kwd":
|
||||
field = "questions@ft_questions_rag_coarse"
|
||||
elif field == "question_tks":
|
||||
field = "questions@ft_questions_rag_fine"
|
||||
elif field == "content_with_weight" or field == "content_ltks":
|
||||
field = "content@ft_content_rag_coarse"
|
||||
elif field == "content_sm_ltks":
|
||||
field = "content@ft_content_rag_fine"
|
||||
elif field == "authors_tks":
|
||||
field = "authors@ft_authors_rag_coarse"
|
||||
elif field == "authors_sm_tks":
|
||||
field = "authors@ft_authors_rag_fine"
|
||||
tokens[0] = field
|
||||
return "^".join(tokens)
|
||||
|
||||
def list2str(lst: str|list, sep: str = " ") -> str:
|
||||
if isinstance(lst, str):
|
||||
return lst
|
||||
return sep.join(lst)
|
||||
|
||||
|
||||
def equivalent_condition_to_str(condition: dict, table_instance=None) -> str | None:
|
||||
assert "_id" not in condition
|
||||
clmns = {}
|
||||
if table_instance:
|
||||
for n, ty, de, _ in table_instance.show_columns().rows():
|
||||
clmns[n] = (ty, de)
|
||||
|
||||
def exists(cln):
|
||||
nonlocal clmns
|
||||
assert cln in clmns, f"'{cln}' should be in '{clmns}'."
|
||||
ty, de = clmns[cln]
|
||||
if ty.lower().find("cha"):
|
||||
if not de:
|
||||
de = ""
|
||||
return f" {cln}!='{de}' "
|
||||
return f"{cln}!={de}"
|
||||
|
||||
cond = list()
|
||||
for k, v in condition.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
continue
|
||||
if field_keyword(k):
|
||||
if isinstance(v, list):
|
||||
inCond = list()
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
item = item.replace("'", "''")
|
||||
inCond.append(f"filter_fulltext('{convert_matching_field(k)}', '{item}')")
|
||||
if inCond:
|
||||
strInCond = " or ".join(inCond)
|
||||
strInCond = f"({strInCond})"
|
||||
cond.append(strInCond)
|
||||
else:
|
||||
cond.append(f"filter_fulltext('{convert_matching_field(k)}', '{v}')")
|
||||
elif isinstance(v, list):
|
||||
inCond = list()
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
item = item.replace("'", "''")
|
||||
inCond.append(f"'{item}'")
|
||||
else:
|
||||
inCond.append(str(item))
|
||||
if inCond:
|
||||
strInCond = ", ".join(inCond)
|
||||
strInCond = f"{k} IN ({strInCond})"
|
||||
cond.append(strInCond)
|
||||
elif k == "must_not":
|
||||
if isinstance(v, dict):
|
||||
for kk, vv in v.items():
|
||||
if kk == "exists":
|
||||
cond.append("NOT (%s)" % exists(vv))
|
||||
elif isinstance(v, str):
|
||||
cond.append(f"{k}='{v}'")
|
||||
elif k == "exists":
|
||||
cond.append(exists(v))
|
||||
else:
|
||||
cond.append(f"{k}={str(v)}")
|
||||
return " AND ".join(cond) if cond else "1=1"
|
||||
|
||||
|
||||
def concat_dataframes(df_list: list[pd.DataFrame], selectFields: list[str]) -> pd.DataFrame:
|
||||
df_list2 = [df for df in df_list if not df.empty]
|
||||
if df_list2:
|
||||
return pd.concat(df_list2, axis=0).reset_index(drop=True)
|
||||
|
||||
schema = []
|
||||
for field_name in selectFields:
|
||||
if field_name == "score()": # Workaround: fix schema is changed to score()
|
||||
schema.append("SCORE")
|
||||
elif field_name == "similarity()": # Workaround: fix schema is changed to similarity()
|
||||
schema.append("SIMILARITY")
|
||||
else:
|
||||
schema.append(field_name)
|
||||
return pd.DataFrame(columns=schema)
|
||||
from common.doc_store.doc_store_base import MatchExpr, MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr
|
||||
from common.doc_store.infinity_conn_base import InfinityConnectionBase
|
||||
|
||||
|
||||
@singleton
|
||||
class InfinityConnection(DocStoreConnection):
|
||||
def __init__(self):
|
||||
self.dbName = settings.INFINITY.get("db_name", "default_db")
|
||||
infinity_uri = settings.INFINITY["uri"]
|
||||
if ":" in infinity_uri:
|
||||
host, port = infinity_uri.split(":")
|
||||
infinity_uri = infinity.common.NetworkAddress(host, int(port))
|
||||
self.connPool = None
|
||||
logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
|
||||
for _ in range(24):
|
||||
try:
|
||||
connPool = ConnectionPool(infinity_uri, max_size=4)
|
||||
inf_conn = connPool.get_conn()
|
||||
res = inf_conn.show_current_node()
|
||||
if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
|
||||
self._migrate_db(inf_conn)
|
||||
self.connPool = connPool
|
||||
connPool.release_conn(inf_conn)
|
||||
break
|
||||
connPool.release_conn(inf_conn)
|
||||
logger.warn(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
|
||||
time.sleep(5)
|
||||
except Exception as e:
|
||||
logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
|
||||
time.sleep(5)
|
||||
if self.connPool is None:
|
||||
msg = f"Infinity {infinity_uri} is unhealthy in 120s."
|
||||
logger.error(msg)
|
||||
raise Exception(msg)
|
||||
logger.info(f"Infinity {infinity_uri} is healthy.")
|
||||
|
||||
def _migrate_db(self, inf_conn):
|
||||
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json")
|
||||
if not os.path.exists(fp_mapping):
|
||||
raise Exception(f"Mapping file not found at {fp_mapping}")
|
||||
schema = json.load(open(fp_mapping))
|
||||
table_names = inf_db.list_tables().table_names
|
||||
for table_name in table_names:
|
||||
inf_table = inf_db.get_table(table_name)
|
||||
index_names = inf_table.list_indexes().index_names
|
||||
if "q_vec_idx" not in index_names:
|
||||
# Skip tables not created by me
|
||||
continue
|
||||
column_names = inf_table.show_columns()["name"]
|
||||
column_names = set(column_names)
|
||||
for field_name, field_info in schema.items():
|
||||
if field_name in column_names:
|
||||
continue
|
||||
res = inf_table.add_columns({field_name: field_info})
|
||||
assert res.error_code == infinity.ErrorCode.OK
|
||||
logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
|
||||
if field_info["type"] != "varchar" or "analyzer" not in field_info:
|
||||
continue
|
||||
analyzers = field_info["analyzer"]
|
||||
if isinstance(analyzers, str):
|
||||
analyzers = [analyzers]
|
||||
for analyzer in analyzers:
|
||||
inf_table.create_index(
|
||||
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
|
||||
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
class InfinityConnection(InfinityConnectionBase):
|
||||
|
||||
"""
|
||||
Database operations
|
||||
Dataframe and fields convert
|
||||
"""
|
||||
|
||||
def dbType(self) -> str:
|
||||
return "infinity"
|
||||
|
||||
def health(self) -> dict:
|
||||
"""
|
||||
Return the health status of the database.
|
||||
"""
|
||||
inf_conn = self.connPool.get_conn()
|
||||
res = inf_conn.show_current_node()
|
||||
self.connPool.release_conn(inf_conn)
|
||||
res2 = {
|
||||
"type": "infinity",
|
||||
"status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
|
||||
"error": res.error_msg,
|
||||
}
|
||||
return res2
|
||||
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
inf_conn = self.connPool.get_conn()
|
||||
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
|
||||
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", "infinity_mapping.json")
|
||||
if not os.path.exists(fp_mapping):
|
||||
raise Exception(f"Mapping file not found at {fp_mapping}")
|
||||
schema = json.load(open(fp_mapping))
|
||||
vector_name = f"q_{vectorSize}_vec"
|
||||
schema[vector_name] = {"type": f"vector,{vectorSize},float"}
|
||||
inf_table = inf_db.create_table(
|
||||
table_name,
|
||||
schema,
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
inf_table.create_index(
|
||||
"q_vec_idx",
|
||||
IndexInfo(
|
||||
vector_name,
|
||||
IndexType.Hnsw,
|
||||
{
|
||||
"M": "16",
|
||||
"ef_construction": "50",
|
||||
"metric": "cosine",
|
||||
"encode": "lvq",
|
||||
},
|
||||
),
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
for field_name, field_info in schema.items():
|
||||
if field_info["type"] != "varchar" or "analyzer" not in field_info:
|
||||
continue
|
||||
analyzers = field_info["analyzer"]
|
||||
if isinstance(analyzers, str):
|
||||
analyzers = [analyzers]
|
||||
for analyzer in analyzers:
|
||||
inf_table.create_index(
|
||||
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
|
||||
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
logger.info(f"INFINITY created table {table_name}, vector size {vectorSize}")
|
||||
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
db_instance.drop_table(table_name, ConflictType.Ignore)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
logger.info(f"INFINITY dropped table {table_name}")
|
||||
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
try:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
_ = db_instance.get_table(table_name)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
@staticmethod
|
||||
def field_keyword(field_name: str):
|
||||
# Treat "*_kwd" tag-like columns as keyword lists except knowledge_graph_kwd; source_id is also keyword-like.
|
||||
if field_name == "source_id" or (
|
||||
field_name.endswith("_kwd") and field_name not in ["knowledge_graph_kwd", "docnm_kwd", "important_kwd",
|
||||
"question_kwd"]):
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"INFINITY indexExist {str(e)}")
|
||||
return False
|
||||
|
||||
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
|
||||
for i, field in enumerate(output_fields):
|
||||
if field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
|
||||
output_fields[i] = "docnm"
|
||||
elif field in ["important_kwd", "important_tks"]:
|
||||
output_fields[i] = "important_keywords"
|
||||
elif field in ["question_kwd", "question_tks"]:
|
||||
output_fields[i] = "questions"
|
||||
elif field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
|
||||
output_fields[i] = "content"
|
||||
elif field in ["authors_tks", "authors_sm_tks"]:
|
||||
output_fields[i] = "authors"
|
||||
return list(set(output_fields))
|
||||
|
||||
@staticmethod
|
||||
def convert_matching_field(field_weight_str: str) -> str:
|
||||
tokens = field_weight_str.split("^")
|
||||
field = tokens[0]
|
||||
if field == "docnm_kwd" or field == "title_tks":
|
||||
field = "docnm@ft_docnm_rag_coarse"
|
||||
elif field == "title_sm_tks":
|
||||
field = "docnm@ft_docnm_rag_fine"
|
||||
elif field == "important_kwd":
|
||||
field = "important_keywords@ft_important_keywords_rag_coarse"
|
||||
elif field == "important_tks":
|
||||
field = "important_keywords@ft_important_keywords_rag_fine"
|
||||
elif field == "question_kwd":
|
||||
field = "questions@ft_questions_rag_coarse"
|
||||
elif field == "question_tks":
|
||||
field = "questions@ft_questions_rag_fine"
|
||||
elif field == "content_with_weight" or field == "content_ltks":
|
||||
field = "content@ft_content_rag_coarse"
|
||||
elif field == "content_sm_ltks":
|
||||
field = "content@ft_content_rag_fine"
|
||||
elif field == "authors_tks":
|
||||
field = "authors@ft_authors_rag_coarse"
|
||||
elif field == "authors_sm_tks":
|
||||
field = "authors@ft_authors_rag_fine"
|
||||
tokens[0] = field
|
||||
return "^".join(tokens)
|
||||
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
def search(
|
||||
self,
|
||||
selectFields: list[str],
|
||||
highlightFields: list[str],
|
||||
select_fields: list[str],
|
||||
highlight_fields: list[str],
|
||||
condition: dict,
|
||||
matchExprs: list[MatchExpr],
|
||||
orderBy: OrderByExpr,
|
||||
match_expressions: list[MatchExpr],
|
||||
order_by: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
indexNames: str | list[str],
|
||||
knowledgebaseIds: list[str],
|
||||
aggFields: list[str] = [],
|
||||
index_names: str | list[str],
|
||||
knowledgebase_ids: list[str],
|
||||
agg_fields: list[str] | None = None,
|
||||
rank_feature: dict | None = None,
|
||||
) -> tuple[pd.DataFrame, int]:
|
||||
"""
|
||||
BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
|
||||
"""
|
||||
if isinstance(indexNames, str):
|
||||
indexNames = indexNames.split(",")
|
||||
assert isinstance(indexNames, list) and len(indexNames) > 0
|
||||
if isinstance(index_names, str):
|
||||
index_names = index_names.split(",")
|
||||
assert isinstance(index_names, list) and len(index_names) > 0
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
df_list = list()
|
||||
table_list = list()
|
||||
output = selectFields.copy()
|
||||
output = convert_select_fields(output)
|
||||
for essential_field in ["id"] + aggFields:
|
||||
output = select_fields.copy()
|
||||
output = self.convert_select_fields(output)
|
||||
if agg_fields is None:
|
||||
agg_fields = []
|
||||
for essential_field in ["id"] + agg_fields:
|
||||
if essential_field not in output:
|
||||
output.append(essential_field)
|
||||
score_func = ""
|
||||
score_column = ""
|
||||
for matchExpr in matchExprs:
|
||||
for matchExpr in match_expressions:
|
||||
if isinstance(matchExpr, MatchTextExpr):
|
||||
score_func = "score()"
|
||||
score_column = "SCORE"
|
||||
break
|
||||
if not score_func:
|
||||
for matchExpr in matchExprs:
|
||||
for matchExpr in match_expressions:
|
||||
if isinstance(matchExpr, MatchDenseExpr):
|
||||
score_func = "similarity()"
|
||||
score_column = "SIMILARITY"
|
||||
break
|
||||
if matchExprs:
|
||||
if match_expressions:
|
||||
if score_func not in output:
|
||||
output.append(score_func)
|
||||
if PAGERANK_FLD not in output:
|
||||
@ -387,11 +147,11 @@ class InfinityConnection(DocStoreConnection):
|
||||
filter_fulltext = ""
|
||||
if condition:
|
||||
table_found = False
|
||||
for indexName in indexNames:
|
||||
for kb_id in knowledgebaseIds:
|
||||
for indexName in index_names:
|
||||
for kb_id in knowledgebase_ids:
|
||||
table_name = f"{indexName}_{kb_id}"
|
||||
try:
|
||||
filter_cond = equivalent_condition_to_str(condition, db_instance.get_table(table_name))
|
||||
filter_cond = self.equivalent_condition_to_str(condition, db_instance.get_table(table_name))
|
||||
table_found = True
|
||||
break
|
||||
except Exception:
|
||||
@ -399,14 +159,14 @@ class InfinityConnection(DocStoreConnection):
|
||||
if table_found:
|
||||
break
|
||||
if not table_found:
|
||||
logger.error(f"No valid tables found for indexNames {indexNames} and knowledgebaseIds {knowledgebaseIds}")
|
||||
self.logger.error(f"No valid tables found for indexNames {index_names} and knowledgebaseIds {knowledgebase_ids}")
|
||||
return pd.DataFrame(), 0
|
||||
|
||||
for matchExpr in matchExprs:
|
||||
for matchExpr in match_expressions:
|
||||
if isinstance(matchExpr, MatchTextExpr):
|
||||
if filter_cond and "filter" not in matchExpr.extra_options:
|
||||
matchExpr.extra_options.update({"filter": filter_cond})
|
||||
matchExpr.fields = [convert_matching_field(field) for field in matchExpr.fields]
|
||||
matchExpr.fields = [self.convert_matching_field(field) for field in matchExpr.fields]
|
||||
fields = ",".join(matchExpr.fields)
|
||||
filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
|
||||
if filter_cond:
|
||||
@ -430,7 +190,7 @@ class InfinityConnection(DocStoreConnection):
|
||||
for k, v in matchExpr.extra_options.items():
|
||||
if not isinstance(v, str):
|
||||
matchExpr.extra_options[k] = str(v)
|
||||
logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
self.logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
elif isinstance(matchExpr, MatchDenseExpr):
|
||||
if filter_fulltext and "filter" not in matchExpr.extra_options:
|
||||
matchExpr.extra_options.update({"filter": filter_fulltext})
|
||||
@ -441,16 +201,16 @@ class InfinityConnection(DocStoreConnection):
|
||||
if similarity:
|
||||
matchExpr.extra_options["threshold"] = similarity
|
||||
del matchExpr.extra_options["similarity"]
|
||||
logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
self.logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
elif isinstance(matchExpr, FusionExpr):
|
||||
if matchExpr.method == "weighted_sum":
|
||||
# The default is "minmax" which gives a zero score for the last doc.
|
||||
matchExpr.fusion_params["normalize"] = "atan"
|
||||
logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
self.logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")
|
||||
|
||||
order_by_expr_list = list()
|
||||
if orderBy.fields:
|
||||
for order_field in orderBy.fields:
|
||||
if order_by.fields:
|
||||
for order_field in order_by.fields:
|
||||
if order_field[1] == 0:
|
||||
order_by_expr_list.append((order_field[0], SortType.Asc))
|
||||
else:
|
||||
@ -458,8 +218,8 @@ class InfinityConnection(DocStoreConnection):
|
||||
|
||||
total_hits_count = 0
|
||||
# Scatter search tables and gather the results
|
||||
for indexName in indexNames:
|
||||
for knowledgebaseId in knowledgebaseIds:
|
||||
for indexName in index_names:
|
||||
for knowledgebaseId in knowledgebase_ids:
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
@ -467,8 +227,8 @@ class InfinityConnection(DocStoreConnection):
|
||||
continue
|
||||
table_list.append(table_name)
|
||||
builder = table_instance.output(output)
|
||||
if len(matchExprs) > 0:
|
||||
for matchExpr in matchExprs:
|
||||
if len(match_expressions) > 0:
|
||||
for matchExpr in match_expressions:
|
||||
if isinstance(matchExpr, MatchTextExpr):
|
||||
fields = ",".join(matchExpr.fields)
|
||||
builder = builder.match_text(
|
||||
@ -491,53 +251,52 @@ class InfinityConnection(DocStoreConnection):
|
||||
else:
|
||||
if filter_cond and len(filter_cond) > 0:
|
||||
builder.filter(filter_cond)
|
||||
if orderBy.fields:
|
||||
if order_by.fields:
|
||||
builder.sort(order_by_expr_list)
|
||||
builder.offset(offset).limit(limit)
|
||||
kb_res, extra_result = builder.option({"total_hits_count": True}).to_df()
|
||||
if extra_result:
|
||||
total_hits_count += int(extra_result["total_hits_count"])
|
||||
logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
|
||||
self.logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(kb_res)}")
|
||||
df_list.append(kb_res)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
res = concat_dataframes(df_list, output)
|
||||
if matchExprs:
|
||||
res = self.concat_dataframes(df_list, output)
|
||||
if match_expressions:
|
||||
res["_score"] = res[score_column] + res[PAGERANK_FLD]
|
||||
res = res.sort_values(by="_score", ascending=False).reset_index(drop=True)
|
||||
res = res.head(limit)
|
||||
logger.debug(f"INFINITY search final result: {str(res)}")
|
||||
self.logger.debug(f"INFINITY search final result: {str(res)}")
|
||||
return res, total_hits_count
|
||||
|
||||
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
|
||||
def get(self, chunk_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
df_list = list()
|
||||
assert isinstance(knowledgebaseIds, list)
|
||||
assert isinstance(knowledgebase_ids, list)
|
||||
table_list = list()
|
||||
for knowledgebaseId in knowledgebaseIds:
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
for knowledgebaseId in knowledgebase_ids:
|
||||
table_name = f"{index_name}_{knowledgebaseId}"
|
||||
table_list.append(table_name)
|
||||
table_instance = None
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except Exception:
|
||||
logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
|
||||
self.logger.warning(f"Table not found: {table_name}, this dataset isn't created in Infinity. Maybe it is created in other document engine.")
|
||||
continue
|
||||
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunkId}'").to_df()
|
||||
logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
|
||||
kb_res, _ = table_instance.output(["*"]).filter(f"id = '{chunk_id}'").to_df()
|
||||
self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(kb_res)}")
|
||||
df_list.append(kb_res)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
res = concat_dataframes(df_list, ["id"])
|
||||
res = self.concat_dataframes(df_list, ["id"])
|
||||
fields = set(res.columns.tolist())
|
||||
for field in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "question_kwd", "question_tks","content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks"]:
|
||||
fields.add(field)
|
||||
res_fields = self.get_fields(res, list(fields))
|
||||
return res_fields.get(chunkId, None)
|
||||
return res_fields.get(chunk_id, None)
|
||||
|
||||
def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
|
||||
def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
table_name = f"{index_name}_{knowledgebase_id}"
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except InfinityException as e:
|
||||
@ -553,7 +312,7 @@ class InfinityConnection(DocStoreConnection):
|
||||
break
|
||||
if vector_size == 0:
|
||||
raise ValueError("Cannot infer vector size from documents")
|
||||
self.createIdx(indexName, knowledgebaseId, vector_size)
|
||||
self.create_idx(index_name, knowledgebase_id, vector_size)
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
|
||||
# embedding fields can't have a default value....
|
||||
@ -574,12 +333,12 @@ class InfinityConnection(DocStoreConnection):
|
||||
d["docnm"] = v
|
||||
elif k == "title_kwd":
|
||||
if not d.get("docnm_kwd"):
|
||||
d["docnm"] = list2str(v)
|
||||
d["docnm"] = self.list2str(v)
|
||||
elif k == "title_sm_tks":
|
||||
if not d.get("docnm_kwd"):
|
||||
d["docnm"] = list2str(v)
|
||||
d["docnm"] = self.list2str(v)
|
||||
elif k == "important_kwd":
|
||||
d["important_keywords"] = list2str(v)
|
||||
d["important_keywords"] = self.list2str(v)
|
||||
elif k == "important_tks":
|
||||
if not d.get("important_kwd"):
|
||||
d["important_keywords"] = v
|
||||
@ -597,11 +356,11 @@ class InfinityConnection(DocStoreConnection):
|
||||
if not d.get("authors_tks"):
|
||||
d["authors"] = v
|
||||
elif k == "question_kwd":
|
||||
d["questions"] = list2str(v, "\n")
|
||||
d["questions"] = self.list2str(v, "\n")
|
||||
elif k == "question_tks":
|
||||
if not d.get("question_kwd"):
|
||||
d["questions"] = list2str(v)
|
||||
elif field_keyword(k):
|
||||
d["questions"] = self.list2str(v)
|
||||
elif self.field_keyword(k):
|
||||
if isinstance(v, list):
|
||||
d[k] = "###".join(v)
|
||||
else:
|
||||
@ -637,15 +396,15 @@ class InfinityConnection(DocStoreConnection):
|
||||
# logger.info(f"InfinityConnection.insert {json.dumps(documents)}")
|
||||
table_instance.insert(docs)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
|
||||
self.logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
|
||||
return []
|
||||
|
||||
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
|
||||
def update(self, condition: dict, new_value: dict, index_name: str, knowledgebase_id: str) -> bool:
|
||||
# if 'position_int' in newValue:
|
||||
# logger.info(f"update position_int: {newValue['position_int']}")
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
table_name = f"{index_name}_{knowledgebase_id}"
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
# if "exists" in condition:
|
||||
# del condition["exists"]
|
||||
@ -654,57 +413,57 @@ class InfinityConnection(DocStoreConnection):
|
||||
if table_instance:
|
||||
for n, ty, de, _ in table_instance.show_columns().rows():
|
||||
clmns[n] = (ty, de)
|
||||
filter = equivalent_condition_to_str(condition, table_instance)
|
||||
filter = self.equivalent_condition_to_str(condition, table_instance)
|
||||
removeValue = {}
|
||||
for k, v in list(newValue.items()):
|
||||
for k, v in list(new_value.items()):
|
||||
if k == "docnm_kwd":
|
||||
newValue["docnm"] = list2str(v)
|
||||
new_value["docnm"] = self.list2str(v)
|
||||
elif k == "title_kwd":
|
||||
if not newValue.get("docnm_kwd"):
|
||||
newValue["docnm"] = list2str(v)
|
||||
if not new_value.get("docnm_kwd"):
|
||||
new_value["docnm"] = self.list2str(v)
|
||||
elif k == "title_sm_tks":
|
||||
if not newValue.get("docnm_kwd"):
|
||||
newValue["docnm"] = v
|
||||
if not new_value.get("docnm_kwd"):
|
||||
new_value["docnm"] = v
|
||||
elif k == "important_kwd":
|
||||
newValue["important_keywords"] = list2str(v)
|
||||
new_value["important_keywords"] = self.list2str(v)
|
||||
elif k == "important_tks":
|
||||
if not newValue.get("important_kwd"):
|
||||
newValue["important_keywords"] = v
|
||||
if not new_value.get("important_kwd"):
|
||||
new_value["important_keywords"] = v
|
||||
elif k == "content_with_weight":
|
||||
newValue["content"] = v
|
||||
new_value["content"] = v
|
||||
elif k == "content_ltks":
|
||||
if not newValue.get("content_with_weight"):
|
||||
newValue["content"] = v
|
||||
if not new_value.get("content_with_weight"):
|
||||
new_value["content"] = v
|
||||
elif k == "content_sm_ltks":
|
||||
if not newValue.get("content_with_weight"):
|
||||
newValue["content"] = v
|
||||
if not new_value.get("content_with_weight"):
|
||||
new_value["content"] = v
|
||||
elif k == "authors_tks":
|
||||
newValue["authors"] = v
|
||||
new_value["authors"] = v
|
||||
elif k == "authors_sm_tks":
|
||||
if not newValue.get("authors_tks"):
|
||||
newValue["authors"] = v
|
||||
if not new_value.get("authors_tks"):
|
||||
new_value["authors"] = v
|
||||
elif k == "question_kwd":
|
||||
newValue["questions"] = "\n".join(v)
|
||||
new_value["questions"] = "\n".join(v)
|
||||
elif k == "question_tks":
|
||||
if not newValue.get("question_kwd"):
|
||||
newValue["questions"] = list2str(v)
|
||||
elif field_keyword(k):
|
||||
if not new_value.get("question_kwd"):
|
||||
new_value["questions"] = self.list2str(v)
|
||||
elif self.field_keyword(k):
|
||||
if isinstance(v, list):
|
||||
newValue[k] = "###".join(v)
|
||||
new_value[k] = "###".join(v)
|
||||
else:
|
||||
newValue[k] = v
|
||||
new_value[k] = v
|
||||
elif re.search(r"_feas$", k):
|
||||
newValue[k] = json.dumps(v)
|
||||
new_value[k] = json.dumps(v)
|
||||
elif k == "kb_id":
|
||||
if isinstance(newValue[k], list):
|
||||
newValue[k] = newValue[k][0] # since d[k] is a list, but we need a str
|
||||
if isinstance(new_value[k], list):
|
||||
new_value[k] = new_value[k][0] # since d[k] is a list, but we need a str
|
||||
elif k == "position_int":
|
||||
assert isinstance(v, list)
|
||||
arr = [num for row in v for num in row]
|
||||
newValue[k] = "_".join(f"{num:08x}" for num in arr)
|
||||
new_value[k] = "_".join(f"{num:08x}" for num in arr)
|
||||
elif k in ["page_num_int", "top_int"]:
|
||||
assert isinstance(v, list)
|
||||
newValue[k] = "_".join(f"{num:08x}" for num in v)
|
||||
new_value[k] = "_".join(f"{num:08x}" for num in v)
|
||||
elif k == "remove":
|
||||
if isinstance(v, str):
|
||||
assert v in clmns, f"'{v}' should be in '{clmns}'."
|
||||
@ -712,22 +471,22 @@ class InfinityConnection(DocStoreConnection):
|
||||
if ty.lower().find("cha"):
|
||||
if not de:
|
||||
de = ""
|
||||
newValue[v] = de
|
||||
new_value[v] = de
|
||||
else:
|
||||
for kk, vv in v.items():
|
||||
removeValue[kk] = vv
|
||||
del newValue[k]
|
||||
del new_value[k]
|
||||
else:
|
||||
newValue[k] = v
|
||||
new_value[k] = v
|
||||
for k in ["docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", "question_kwd", "question_tks"]:
|
||||
if k in newValue:
|
||||
del newValue[k]
|
||||
if k in new_value:
|
||||
del new_value[k]
|
||||
|
||||
remove_opt = {} # "[k,new_value]": [id_to_update, ...]
|
||||
if removeValue:
|
||||
col_to_remove = list(removeValue.keys())
|
||||
row_to_opt = table_instance.output(col_to_remove + ["id"]).filter(filter).to_df()
|
||||
logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
|
||||
self.logger.debug(f"INFINITY search table {str(table_name)}, filter {filter}, result: {str(row_to_opt[0])}")
|
||||
row_to_opt = self.get_fields(row_to_opt, col_to_remove)
|
||||
for id, old_v in row_to_opt.items():
|
||||
for k, remove_v in removeValue.items():
|
||||
@ -740,78 +499,53 @@ class InfinityConnection(DocStoreConnection):
|
||||
else:
|
||||
remove_opt[kv_key].append(id)
|
||||
|
||||
logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {newValue}.")
|
||||
self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
|
||||
for update_kv, ids in remove_opt.items():
|
||||
k, v = json.loads(update_kv)
|
||||
table_instance.update(filter + " AND id in ({0})".format(",".join([f"'{id}'" for id in ids])), {k: "###".join(v)})
|
||||
|
||||
table_instance.update(filter, newValue)
|
||||
table_instance.update(filter, new_value)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
return True
|
||||
|
||||
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
table_name = f"{indexName}_{knowledgebaseId}"
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except Exception:
|
||||
logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.")
|
||||
return 0
|
||||
filter = equivalent_condition_to_str(condition, table_instance)
|
||||
logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
|
||||
res = table_instance.delete(filter)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
return res.deleted_rows
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
|
||||
if isinstance(res, tuple):
|
||||
return res[1]
|
||||
return len(res)
|
||||
|
||||
def get_chunk_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
return list(res["id"])
|
||||
|
||||
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
if not fields:
|
||||
return {}
|
||||
fieldsAll = fields.copy()
|
||||
fieldsAll.append("id")
|
||||
fieldsAll = set(fieldsAll)
|
||||
fields_all = fields.copy()
|
||||
fields_all.append("id")
|
||||
fields_all = set(fields_all)
|
||||
if "docnm" in res.columns:
|
||||
for field in ["docnm_kwd", "title_tks", "title_sm_tks"]:
|
||||
if field in fieldsAll:
|
||||
if field in fields_all:
|
||||
res[field] = res["docnm"]
|
||||
if "important_keywords" in res.columns:
|
||||
if "important_kwd" in fieldsAll:
|
||||
if "important_kwd" in fields_all:
|
||||
res["important_kwd"] = res["important_keywords"].apply(lambda v: v.split())
|
||||
if "important_tks" in fieldsAll:
|
||||
if "important_tks" in fields_all:
|
||||
res["important_tks"] = res["important_keywords"]
|
||||
if "questions" in res.columns:
|
||||
if "question_kwd" in fieldsAll:
|
||||
if "question_kwd" in fields_all:
|
||||
res["question_kwd"] = res["questions"].apply(lambda v: v.splitlines())
|
||||
if "question_tks" in fieldsAll:
|
||||
if "question_tks" in fields_all:
|
||||
res["question_tks"] = res["questions"]
|
||||
if "content" in res.columns:
|
||||
for field in ["content_with_weight", "content_ltks", "content_sm_ltks"]:
|
||||
if field in fieldsAll:
|
||||
if field in fields_all:
|
||||
res[field] = res["content"]
|
||||
if "authors" in res.columns:
|
||||
for field in ["authors_tks", "authors_sm_tks"]:
|
||||
if field in fieldsAll:
|
||||
if field in fields_all:
|
||||
res[field] = res["authors"]
|
||||
|
||||
column_map = {col.lower(): col for col in res.columns}
|
||||
matched_columns = {column_map[col.lower()]: col for col in fieldsAll if col.lower() in column_map}
|
||||
none_columns = [col for col in fieldsAll if col.lower() not in column_map]
|
||||
matched_columns = {column_map[col.lower()]: col for col in fields_all if col.lower() in column_map}
|
||||
none_columns = [col for col in fields_all if col.lower() not in column_map]
|
||||
|
||||
res2 = res[matched_columns.keys()]
|
||||
res2 = res2.rename(columns=matched_columns)
|
||||
@ -819,7 +553,7 @@ class InfinityConnection(DocStoreConnection):
|
||||
|
||||
for column in list(res2.columns):
|
||||
k = column.lower()
|
||||
if field_keyword(k):
|
||||
if self.field_keyword(k):
|
||||
res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
|
||||
elif re.search(r"_feas$", k):
|
||||
res2[column] = res2[column].apply(lambda v: json.loads(v) if v else {})
|
||||
@ -844,95 +578,3 @@ class InfinityConnection(DocStoreConnection):
|
||||
res2[column] = None
|
||||
|
||||
return res2.set_index("id").to_dict(orient="index")
|
||||
|
||||
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], fieldnm: str):
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
ans = {}
|
||||
num_rows = len(res)
|
||||
column_id = res["id"]
|
||||
if fieldnm not in res:
|
||||
return {}
|
||||
for i in range(num_rows):
|
||||
id = column_id[i]
|
||||
txt = res[fieldnm][i]
|
||||
if re.search(r"<em>[^<>]+</em>", txt, flags=re.IGNORECASE | re.MULTILINE):
|
||||
ans[id] = txt
|
||||
continue
|
||||
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
|
||||
txts = []
|
||||
for t in re.split(r"[.?!;\n]", txt):
|
||||
if is_english([t]):
|
||||
for w in keywords:
|
||||
t = re.sub(
|
||||
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
|
||||
r"\1<em>\2</em>\3",
|
||||
t,
|
||||
flags=re.IGNORECASE | re.MULTILINE,
|
||||
)
|
||||
else:
|
||||
for w in sorted(keywords, key=len, reverse=True):
|
||||
t = re.sub(
|
||||
re.escape(w),
|
||||
f"<em>{w}</em>",
|
||||
t,
|
||||
flags=re.IGNORECASE | re.MULTILINE,
|
||||
)
|
||||
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
|
||||
continue
|
||||
txts.append(t)
|
||||
if txts:
|
||||
ans[id] = "...".join(txts)
|
||||
else:
|
||||
ans[id] = txt
|
||||
return ans
|
||||
|
||||
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fieldnm: str):
|
||||
"""
|
||||
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
|
||||
"""
|
||||
from collections import Counter
|
||||
|
||||
# Extract DataFrame from result
|
||||
if isinstance(res, tuple):
|
||||
df, _ = res
|
||||
else:
|
||||
df = res
|
||||
|
||||
if df.empty or fieldnm not in df.columns:
|
||||
return []
|
||||
|
||||
# Aggregate tag counts
|
||||
tag_counter = Counter()
|
||||
|
||||
for value in df[fieldnm]:
|
||||
if pd.isna(value) or not value:
|
||||
continue
|
||||
|
||||
# Handle different tag formats
|
||||
if isinstance(value, str):
|
||||
# Split by ### for tag_kwd field or comma for other formats
|
||||
if fieldnm == "tag_kwd" and "###" in value:
|
||||
tags = [tag.strip() for tag in value.split("###") if tag.strip()]
|
||||
else:
|
||||
# Try comma separation as fallback
|
||||
tags = [tag.strip() for tag in value.split(",") if tag.strip()]
|
||||
|
||||
for tag in tags:
|
||||
if tag: # Only count non-empty tags
|
||||
tag_counter[tag] += 1
|
||||
elif isinstance(value, list):
|
||||
# Handle list format
|
||||
for tag in value:
|
||||
if tag and isinstance(tag, str):
|
||||
tag_counter[tag.strip()] += 1
|
||||
|
||||
# Return as list of [tag, count] pairs, sorted by count descending
|
||||
return [[tag, count] for tag, count in tag_counter.most_common()]
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
|
||||
def sql(sql: str, fetch_size: int, format: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@ -37,9 +37,8 @@ from common import settings
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from common.decorator import singleton
|
||||
from common.float_utils import get_float
|
||||
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, MatchDenseExpr
|
||||
from rag.nlp import rag_tokenizer
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \
|
||||
MatchDenseExpr
|
||||
|
||||
ATTEMPT_TIME = 2
|
||||
OB_QUERY_TIMEOUT = int(os.environ.get("OB_QUERY_TIMEOUT", "100_000_000"))
|
||||
@ -497,7 +496,7 @@ class OBConnection(DocStoreConnection):
|
||||
Database operations
|
||||
"""
|
||||
|
||||
def dbType(self) -> str:
|
||||
def db_type(self) -> str:
|
||||
return "oceanbase"
|
||||
|
||||
def health(self) -> dict:
|
||||
@ -553,7 +552,7 @@ class OBConnection(DocStoreConnection):
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
vector_field_name = f"q_{vectorSize}_vec"
|
||||
vector_index_name = f"{vector_field_name}_idx"
|
||||
|
||||
@ -604,7 +603,7 @@ class OBConnection(DocStoreConnection):
|
||||
# always refresh metadata to make sure it contains the latest table structure
|
||||
self.client.refresh_metadata([indexName])
|
||||
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
def delete_idx(self, indexName: str, knowledgebaseId: str):
|
||||
if len(knowledgebaseId) > 0:
|
||||
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
|
||||
return
|
||||
@ -615,7 +614,7 @@ class OBConnection(DocStoreConnection):
|
||||
except Exception as e:
|
||||
raise Exception(f"OBConnection.deleteIndex error: {str(e)}")
|
||||
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
|
||||
def index_exist(self, indexName: str, knowledgebaseId: str = None) -> bool:
|
||||
return self._check_table_exists_cached(indexName)
|
||||
|
||||
def _get_count(self, table_name: str, filter_list: list[str] = None) -> int:
|
||||
@ -1500,7 +1499,7 @@ class OBConnection(DocStoreConnection):
|
||||
def get_total(self, res) -> int:
|
||||
return res.total
|
||||
|
||||
def get_chunk_ids(self, res) -> list[str]:
|
||||
def get_doc_ids(self, res) -> list[str]:
|
||||
return [row["id"] for row in res.chunks]
|
||||
|
||||
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
|
||||
@ -26,8 +26,7 @@ from opensearchpy import UpdateByQuery, Q, Search, Index
|
||||
from opensearchpy import ConnectionTimeout
|
||||
from common.decorator import singleton
|
||||
from common.file_utils import get_project_base_directory
|
||||
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \
|
||||
FusionExpr
|
||||
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
|
||||
from rag.nlp import is_english, rag_tokenizer
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
from common import settings
|
||||
@ -79,7 +78,7 @@ class OSConnection(DocStoreConnection):
|
||||
Database operations
|
||||
"""
|
||||
|
||||
def dbType(self) -> str:
|
||||
def db_type(self) -> str:
|
||||
return "opensearch"
|
||||
|
||||
def health(self) -> dict:
|
||||
@ -91,8 +90,8 @@ class OSConnection(DocStoreConnection):
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
if self.indexExist(indexName, knowledgebaseId):
|
||||
def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
if self.index_exist(indexName, knowledgebaseId):
|
||||
return True
|
||||
try:
|
||||
from opensearchpy.client import IndicesClient
|
||||
@ -101,7 +100,7 @@ class OSConnection(DocStoreConnection):
|
||||
except Exception:
|
||||
logger.exception("OSConnection.createIndex error %s" % (indexName))
|
||||
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
def delete_idx(self, indexName: str, knowledgebaseId: str):
|
||||
if len(knowledgebaseId) > 0:
|
||||
# The index need to be alive after any kb deletion since all kb under this tenant are in one index.
|
||||
return
|
||||
@ -112,7 +111,7 @@ class OSConnection(DocStoreConnection):
|
||||
except Exception:
|
||||
logger.exception("OSConnection.deleteIdx error %s" % (indexName))
|
||||
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
|
||||
def index_exist(self, indexName: str, knowledgebaseId: str = None) -> bool:
|
||||
s = Index(indexName, self.os)
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
@ -460,7 +459,7 @@ class OSConnection(DocStoreConnection):
|
||||
return res["hits"]["total"]["value"]
|
||||
return res["hits"]["total"]
|
||||
|
||||
def get_chunk_ids(self, res):
|
||||
def get_doc_ids(self, res):
|
||||
return [d["_id"] for d in res["hits"]["hits"]]
|
||||
|
||||
def __getSource(self, res):
|
||||
|
||||
@ -272,6 +272,49 @@ class RedisDB:
|
||||
self.__open__()
|
||||
return None
|
||||
|
||||
def generate_auto_increment_id(self, key_prefix: str = "id_generator", namespace: str = "default", increment: int = 1, ensure_minimum: int | None = None) -> int:
|
||||
redis_key = f"{key_prefix}:{namespace}"
|
||||
|
||||
try:
|
||||
# Use pipeline for atomicity
|
||||
pipe = self.REDIS.pipeline()
|
||||
|
||||
# Check if key exists
|
||||
pipe.exists(redis_key)
|
||||
|
||||
# Get/Increment
|
||||
if ensure_minimum is not None:
|
||||
# Ensure minimum value
|
||||
pipe.get(redis_key)
|
||||
results = pipe.execute()
|
||||
|
||||
if results[0] == 0: # Key doesn't exist
|
||||
start_id = max(1, ensure_minimum)
|
||||
pipe.set(redis_key, start_id)
|
||||
pipe.execute()
|
||||
return start_id
|
||||
else:
|
||||
current = int(results[1])
|
||||
if current < ensure_minimum:
|
||||
pipe.set(redis_key, ensure_minimum)
|
||||
pipe.execute()
|
||||
return ensure_minimum
|
||||
|
||||
# Increment operation
|
||||
next_id = self.REDIS.incrby(redis_key, increment)
|
||||
|
||||
# If it's the first time, set a reasonable initial value
|
||||
if next_id == increment:
|
||||
self.REDIS.set(redis_key, 1 + increment)
|
||||
return 1 + increment
|
||||
|
||||
return next_id
|
||||
|
||||
except Exception as e:
|
||||
logging.warning("RedisDB.generate_auto_increment_id got exception: " + str(e))
|
||||
self.__open__()
|
||||
return -1
|
||||
|
||||
def transaction(self, key, value, exp=3600):
|
||||
try:
|
||||
pipeline = self.REDIS.pipeline(transaction=True)
|
||||
|
||||
@ -32,8 +32,8 @@ def add_memory_func(request, WebApiAuth):
|
||||
payload = {
|
||||
"name": f"test_memory_{i}",
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
|
||||
"llm_id": "ZHIPU-AI@glm-4-flash"
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res = create_memory(WebApiAuth, payload)
|
||||
memory_ids.append(res["data"]["id"])
|
||||
|
||||
@ -49,8 +49,8 @@ class TestMemoryCreate:
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
|
||||
"llm_id": "ZHIPU-AI@glm-4-flash"
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res = create_memory(WebApiAuth, payload)
|
||||
assert res["code"] == 0, res
|
||||
@ -72,8 +72,8 @@ class TestMemoryCreate:
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
|
||||
"llm_id": "ZHIPU-AI@glm-4-flash"
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res = create_memory(WebApiAuth, payload)
|
||||
assert res["message"] == expected_message, res
|
||||
@ -84,8 +84,8 @@ class TestMemoryCreate:
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["something"],
|
||||
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
|
||||
"llm_id": "ZHIPU-AI@glm-4-flash"
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res = create_memory(WebApiAuth, payload)
|
||||
assert res["message"] == f"Memory type '{ {'something'} }' is not supported.", res
|
||||
@ -96,8 +96,8 @@ class TestMemoryCreate:
|
||||
payload = {
|
||||
"name": name,
|
||||
"memory_type": ["raw"] + random.choices(["semantic", "episodic", "procedural"], k=random.randint(0, 3)),
|
||||
"embd_id": "SILICONFLOW@BAAI/bge-large-zh-v1.5",
|
||||
"llm_id": "ZHIPU-AI@glm-4-flash"
|
||||
"embd_id": "BAAI/bge-large-zh-v1.5@SILICONFLOW",
|
||||
"llm_id": "glm-4-flash@ZHIPU-AI"
|
||||
}
|
||||
res1 = create_memory(WebApiAuth, payload)
|
||||
assert res1["code"] == 0, res1
|
||||
|
||||
@ -101,7 +101,7 @@ class TestMemoryUpdate:
|
||||
@pytest.mark.p1
|
||||
def test_llm(self, WebApiAuth, add_memory_func):
|
||||
memory_ids = add_memory_func
|
||||
llm_id = "ZHIPU-AI@glm-4"
|
||||
llm_id = "glm-4@ZHIPU-AI"
|
||||
payload = {"llm_id": llm_id}
|
||||
res = update_memory(WebApiAuth, memory_ids[0], payload)
|
||||
assert res["code"] == 0, res
|
||||
|
||||
8
uv.lock
generated
8
uv.lock
generated
@ -3051,7 +3051,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "infinity-sdk"
|
||||
version = "0.6.11"
|
||||
version = "0.6.13"
|
||||
source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" }
|
||||
dependencies = [
|
||||
{ name = "datrie" },
|
||||
@ -3068,9 +3068,9 @@ dependencies = [
|
||||
{ name = "sqlglot", extra = ["rs"] },
|
||||
{ name = "thrift" },
|
||||
]
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/6e/6d/294b4c8fb36f874c92576107fab22da6d64f567fd3e24a312d7bcba5f17a/infinity_sdk-0.6.11.tar.gz", hash = "sha256:f78acd5439c3837715ab308c49be04b416bdaa42a5f4fb840682639ee39d435f", size = 29518792, upload-time = "2025-12-08T05:58:35.167Z" }
|
||||
sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/03/de/56fdc0fa962d5a8e0aa68d16f5321b2d88d79fceb7d0d6cfdde338b65d05/infinity_sdk-0.6.13.tar.gz", hash = "sha256:faf7bc23de7fa549a3842753eddad54ae551ada9df4fff25421658a7fa6fa8c2", size = 29518902, upload-time = "2025-12-24T10:00:01.483Z" }
|
||||
wheels = [
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/18/cf/d7d1bc584c8f7d1cd8f75c39067179b98ca4bcbe5a86c61e7dbc2b8e692d/infinity_sdk-0.6.11-py3-none-any.whl", hash = "sha256:ec5ac3e710f29db4b875d3e24e20e391ad64270a5e5d189295cc91c362af74d1", size = 29737403, upload-time = "2025-12-08T05:54:58.798Z" },
|
||||
{ url = "https://pypi.tuna.tsinghua.edu.cn/packages/f4/a0/8f1e134fdf4ca8bebac7b62caace1816953bb5ffc720d9f0004246c8c38d/infinity_sdk-0.6.13-py3-none-any.whl", hash = "sha256:c08a523d2c27e9a7e6e88be640970530b4661a67c3e9dc3e1aa89533a822fd78", size = 29737403, upload-time = "2025-12-24T09:56:16.93Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -6229,7 +6229,7 @@ requires-dist = [
|
||||
{ name = "grpcio-status", specifier = "==1.67.1" },
|
||||
{ name = "html-text", specifier = "==0.6.2" },
|
||||
{ name = "infinity-emb", specifier = ">=0.0.66,<0.0.67" },
|
||||
{ name = "infinity-sdk", specifier = "==0.6.11" },
|
||||
{ name = "infinity-sdk", specifier = "==0.6.13" },
|
||||
{ name = "jira", specifier = "==3.10.5" },
|
||||
{ name = "json-repair", specifier = "==0.35.0" },
|
||||
{ name = "langfuse", specifier = ">=2.60.0" },
|
||||
|
||||
@ -18,7 +18,9 @@ import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-reques
|
||||
import { IModalProps } from '@/interfaces/common';
|
||||
import { IParserConfig } from '@/interfaces/database/document';
|
||||
import { IChangeParserConfigRequestBody } from '@/interfaces/request/document';
|
||||
import { MetadataType } from '@/pages/dataset/components/metedata/hooks/use-manage-modal';
|
||||
import {
|
||||
AutoMetadata,
|
||||
ChunkMethodItem,
|
||||
EnableTocToggle,
|
||||
ImageContextWindow,
|
||||
@ -86,6 +88,7 @@ export function ChunkMethodDialog({
|
||||
visible,
|
||||
parserConfig,
|
||||
loading,
|
||||
documentId,
|
||||
}: IProps) {
|
||||
const { t } = useTranslation();
|
||||
|
||||
@ -142,6 +145,18 @@ export function ChunkMethodDialog({
|
||||
pages: z
|
||||
.array(z.object({ from: z.coerce.number(), to: z.coerce.number() }))
|
||||
.optional(),
|
||||
metadata: z
|
||||
.array(
|
||||
z
|
||||
.object({
|
||||
key: z.string().optional(),
|
||||
description: z.string().optional(),
|
||||
enum: z.array(z.string().optional()).optional(),
|
||||
})
|
||||
.optional(),
|
||||
)
|
||||
.optional(),
|
||||
enable_metadata: z.boolean().optional(),
|
||||
}),
|
||||
})
|
||||
.superRefine((data, ctx) => {
|
||||
@ -373,6 +388,10 @@ export function ChunkMethodDialog({
|
||||
)}
|
||||
{showAutoKeywords(selectedTag) && (
|
||||
<>
|
||||
<AutoMetadata
|
||||
type={MetadataType.SingleFileSetting}
|
||||
otherData={{ documentId }}
|
||||
/>
|
||||
<AutoKeywordsFormField></AutoKeywordsFormField>
|
||||
<AutoQuestionsFormField></AutoQuestionsFormField>
|
||||
</>
|
||||
|
||||
@ -36,9 +36,11 @@ export function useDefaultParserValues() {
|
||||
// },
|
||||
entity_types: [],
|
||||
pages: [],
|
||||
metadata: [],
|
||||
enable_metadata: false,
|
||||
};
|
||||
|
||||
return defaultParserValues;
|
||||
return defaultParserValues as IParserConfig;
|
||||
}, [t]);
|
||||
|
||||
return defaultParserValues;
|
||||
|
||||
@ -35,6 +35,7 @@ import { cn } from '@/lib/utils';
|
||||
import { t } from 'i18next';
|
||||
import { Loader } from 'lucide-react';
|
||||
import { MultiSelect, MultiSelectOptionType } from './ui/multi-select';
|
||||
import { Switch } from './ui/switch';
|
||||
|
||||
// Field type enumeration
|
||||
export enum FormFieldType {
|
||||
@ -46,6 +47,7 @@ export enum FormFieldType {
|
||||
Select = 'select',
|
||||
MultiSelect = 'multi-select',
|
||||
Checkbox = 'checkbox',
|
||||
Switch = 'switch',
|
||||
Tag = 'tag',
|
||||
Custom = 'custom',
|
||||
}
|
||||
@ -154,6 +156,7 @@ export const generateSchema = (fields: FormFieldConfig[]): ZodSchema<any> => {
|
||||
}
|
||||
break;
|
||||
case FormFieldType.Checkbox:
|
||||
case FormFieldType.Switch:
|
||||
fieldSchema = z.boolean();
|
||||
break;
|
||||
case FormFieldType.Tag:
|
||||
@ -193,6 +196,8 @@ export const generateSchema = (fields: FormFieldConfig[]): ZodSchema<any> => {
|
||||
if (
|
||||
field.type !== FormFieldType.Number &&
|
||||
field.type !== FormFieldType.Checkbox &&
|
||||
field.type !== FormFieldType.Switch &&
|
||||
field.type !== FormFieldType.Custom &&
|
||||
field.type !== FormFieldType.Tag &&
|
||||
field.required
|
||||
) {
|
||||
@ -289,7 +294,10 @@ const generateDefaultValues = <T extends FieldValues>(
|
||||
const lastKey = keys[keys.length - 1];
|
||||
if (field.defaultValue !== undefined) {
|
||||
current[lastKey] = field.defaultValue;
|
||||
} else if (field.type === FormFieldType.Checkbox) {
|
||||
} else if (
|
||||
field.type === FormFieldType.Checkbox ||
|
||||
field.type === FormFieldType.Switch
|
||||
) {
|
||||
current[lastKey] = false;
|
||||
} else if (field.type === FormFieldType.Tag) {
|
||||
current[lastKey] = [];
|
||||
@ -299,7 +307,10 @@ const generateDefaultValues = <T extends FieldValues>(
|
||||
} else {
|
||||
if (field.defaultValue !== undefined) {
|
||||
defaultValues[field.name] = field.defaultValue;
|
||||
} else if (field.type === FormFieldType.Checkbox) {
|
||||
} else if (
|
||||
field.type === FormFieldType.Checkbox ||
|
||||
field.type === FormFieldType.Switch
|
||||
) {
|
||||
defaultValues[field.name] = false;
|
||||
} else if (
|
||||
field.type === FormFieldType.Tag ||
|
||||
@ -502,6 +513,32 @@ export const RenderField = ({
|
||||
)}
|
||||
/>
|
||||
);
|
||||
case FormFieldType.Switch:
|
||||
return (
|
||||
<RAGFlowFormItem
|
||||
{...field}
|
||||
labelClassName={labelClassName || field.labelClassName}
|
||||
>
|
||||
{(fieldProps) => {
|
||||
const finalFieldProps = field.onChange
|
||||
? {
|
||||
...fieldProps,
|
||||
onChange: (checked: boolean) => {
|
||||
fieldProps.onChange(checked);
|
||||
field.onChange?.(checked);
|
||||
},
|
||||
}
|
||||
: fieldProps;
|
||||
return (
|
||||
<Switch
|
||||
checked={finalFieldProps.value as boolean}
|
||||
onCheckedChange={(checked) => finalFieldProps.onChange(checked)}
|
||||
disabled={field.disabled}
|
||||
/>
|
||||
);
|
||||
}}
|
||||
</RAGFlowFormItem>
|
||||
);
|
||||
|
||||
case FormFieldType.Tag:
|
||||
return (
|
||||
|
||||
@ -15,6 +15,7 @@ import { Progress } from '@/components/ui/progress';
|
||||
import { useControllableState } from '@/hooks/use-controllable-state';
|
||||
import { cn, formatBytes } from '@/lib/utils';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from './ui/tooltip';
|
||||
|
||||
function isFileWithPreview(file: File): file is File & { preview: string } {
|
||||
return 'preview' in file && typeof file.preview === 'string';
|
||||
@ -58,10 +59,17 @@ function FileCard({ file, progress, onRemove }: FileCardProps) {
|
||||
</div>
|
||||
<div className="flex flex-col flex-1 gap-2 overflow-hidden">
|
||||
<div className="flex flex-col gap-px">
|
||||
<p className="line-clamp-1 text-sm font-medium text-foreground/80 text-ellipsis">
|
||||
{file.name}
|
||||
</p>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<p className=" w-fit line-clamp-1 text-sm font-medium text-foreground/80 text-ellipsis truncate max-w-[370px]">
|
||||
{file.name}
|
||||
</p>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className="border border-border-button">
|
||||
{file.name}
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
<p className="text-xs text-text-secondary">
|
||||
{formatBytes(file.size)}
|
||||
</p>
|
||||
</div>
|
||||
@ -311,7 +319,7 @@ export function FileUploader(props: FileUploaderProps) {
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col gap-px">
|
||||
<p className="font-medium text-text-secondary">
|
||||
<p className="font-medium text-text-secondary ">
|
||||
{title || t('knowledgeDetails.uploadTitle')}
|
||||
</p>
|
||||
<p className="text-sm text-text-disabled">
|
||||
|
||||
@ -31,14 +31,16 @@ const handleCheckChange = ({
|
||||
(value: string) => value !== item.id.toString(),
|
||||
);
|
||||
|
||||
const newValue = {
|
||||
...currentValue,
|
||||
[parentId]: newParentValues,
|
||||
};
|
||||
const newValue = newParentValues?.length
|
||||
? {
|
||||
...currentValue,
|
||||
[parentId]: newParentValues,
|
||||
}
|
||||
: { ...currentValue };
|
||||
|
||||
if (newValue[parentId].length === 0) {
|
||||
delete newValue[parentId];
|
||||
}
|
||||
// if (newValue[parentId].length === 0) {
|
||||
// delete newValue[parentId];
|
||||
// }
|
||||
|
||||
return field.onChange(newValue);
|
||||
} else {
|
||||
@ -66,20 +68,31 @@ const FilterItem = memo(
|
||||
}) => {
|
||||
return (
|
||||
<div
|
||||
className={`flex items-center justify-between text-text-primary text-xs ${level > 0 ? 'ml-4' : ''}`}
|
||||
className={`flex items-center justify-between text-text-primary text-xs ${level > 0 ? 'ml-1' : ''}`}
|
||||
>
|
||||
<FormItem className="flex flex-row space-x-3 space-y-0 items-center">
|
||||
<FormItem className="flex flex-row space-x-3 space-y-0 items-center ">
|
||||
<FormControl>
|
||||
<Checkbox
|
||||
checked={field.value?.includes(item.id.toString())}
|
||||
onCheckedChange={(checked: boolean) =>
|
||||
handleCheckChange({ checked, field, item })
|
||||
}
|
||||
/>
|
||||
<div className="flex space-x-3">
|
||||
<Checkbox
|
||||
checked={field.value?.includes(item.id.toString())}
|
||||
onCheckedChange={(checked: boolean) =>
|
||||
handleCheckChange({ checked, field, item })
|
||||
}
|
||||
// className="hidden group-hover:block"
|
||||
/>
|
||||
<FormLabel
|
||||
onClick={() =>
|
||||
handleCheckChange({
|
||||
checked: !field.value?.includes(item.id.toString()),
|
||||
field,
|
||||
item,
|
||||
})
|
||||
}
|
||||
>
|
||||
{item.label}
|
||||
</FormLabel>
|
||||
</div>
|
||||
</FormControl>
|
||||
<FormLabel onClick={(e) => e.stopPropagation()}>
|
||||
{item.label}
|
||||
</FormLabel>
|
||||
</FormItem>
|
||||
{item.count !== undefined && (
|
||||
<span className="text-sm">{item.count}</span>
|
||||
@ -107,11 +120,11 @@ export const FilterField = memo(
|
||||
<FormField
|
||||
key={item.id}
|
||||
control={form.control}
|
||||
name={parent.field as string}
|
||||
name={parent.field?.toString() as string}
|
||||
render={({ field }) => {
|
||||
if (hasNestedList) {
|
||||
return (
|
||||
<div className={`flex flex-col gap-2 ${level > 0 ? 'ml-4' : ''}`}>
|
||||
<div className={`flex flex-col gap-2 ${level > 0 ? 'ml-1' : ''}`}>
|
||||
<div
|
||||
className="flex items-center justify-between cursor-pointer"
|
||||
onClick={() => {
|
||||
@ -138,23 +151,6 @@ export const FilterField = memo(
|
||||
}}
|
||||
level={level + 1}
|
||||
/>
|
||||
// <FilterItem key={child.id} item={child} field={child.field} level={level+1} />
|
||||
// <div
|
||||
// className="flex flex-row space-x-3 space-y-0 items-center"
|
||||
// key={child.id}
|
||||
// >
|
||||
// <FormControl>
|
||||
// <Checkbox
|
||||
// checked={field.value?.includes(child.id.toString())}
|
||||
// onCheckedChange={(checked) =>
|
||||
// handleCheckChange({ checked, field, item: child })
|
||||
// }
|
||||
// />
|
||||
// </FormControl>
|
||||
// <FormLabel onClick={(e) => e.stopPropagation()}>
|
||||
// {child.label}
|
||||
// </FormLabel>
|
||||
// </div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
|
||||
@ -11,8 +11,8 @@ import {
|
||||
useMemo,
|
||||
useState,
|
||||
} from 'react';
|
||||
import { useForm } from 'react-hook-form';
|
||||
import { ZodArray, ZodString, z } from 'zod';
|
||||
import { FieldPath, useForm } from 'react-hook-form';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { Button } from '@/components/ui/button';
|
||||
|
||||
@ -71,34 +71,37 @@ function CheckboxFormMultiple({
|
||||
}, {});
|
||||
}, [resolvedFilters]);
|
||||
|
||||
// const FormSchema = useMemo(() => {
|
||||
// if (resolvedFilters.length === 0) {
|
||||
// return z.object({});
|
||||
// }
|
||||
|
||||
// return z.object(
|
||||
// resolvedFilters.reduce<
|
||||
// Record<
|
||||
// string,
|
||||
// ZodArray<ZodString, 'many'> | z.ZodObject<any> | z.ZodOptional<any>
|
||||
// >
|
||||
// >((pre, cur) => {
|
||||
// const hasNested = cur.list?.some(
|
||||
// (item) => item.list && item.list.length > 0,
|
||||
// );
|
||||
|
||||
// if (hasNested) {
|
||||
// pre[cur.field] = z
|
||||
// .record(z.string(), z.array(z.string().optional()).optional())
|
||||
// .optional();
|
||||
// } else {
|
||||
// pre[cur.field] = z.array(z.string().optional()).optional();
|
||||
// }
|
||||
|
||||
// return pre;
|
||||
// }, {}),
|
||||
// );
|
||||
// }, [resolvedFilters]);
|
||||
const FormSchema = useMemo(() => {
|
||||
if (resolvedFilters.length === 0) {
|
||||
return z.object({});
|
||||
}
|
||||
|
||||
return z.object(
|
||||
resolvedFilters.reduce<
|
||||
Record<
|
||||
string,
|
||||
ZodArray<ZodString, 'many'> | z.ZodObject<any> | z.ZodOptional<any>
|
||||
>
|
||||
>((pre, cur) => {
|
||||
const hasNested = cur.list?.some(
|
||||
(item) => item.list && item.list.length > 0,
|
||||
);
|
||||
|
||||
if (hasNested) {
|
||||
pre[cur.field] = z
|
||||
.record(z.string(), z.array(z.string().optional()).optional())
|
||||
.optional();
|
||||
} else {
|
||||
pre[cur.field] = z.array(z.string().optional()).optional();
|
||||
}
|
||||
|
||||
return pre;
|
||||
}, {}),
|
||||
);
|
||||
}, [resolvedFilters]);
|
||||
return z.object({});
|
||||
}, []);
|
||||
|
||||
const form = useForm<z.infer<typeof FormSchema>>({
|
||||
resolver: resolvedFilters.length > 0 ? zodResolver(FormSchema) : undefined,
|
||||
@ -178,7 +181,9 @@ function CheckboxFormMultiple({
|
||||
<FormField
|
||||
key={x.field}
|
||||
control={form.control}
|
||||
name={x.field}
|
||||
name={
|
||||
x.field.toString() as FieldPath<z.infer<typeof FormSchema>>
|
||||
}
|
||||
render={() => (
|
||||
<FormItem className="space-y-4">
|
||||
<div>
|
||||
@ -186,19 +191,20 @@ function CheckboxFormMultiple({
|
||||
{x.label}
|
||||
</FormLabel>
|
||||
</div>
|
||||
{x.list.map((item) => {
|
||||
return (
|
||||
<FilterField
|
||||
key={item.id}
|
||||
item={{ ...item }}
|
||||
parent={{
|
||||
...x,
|
||||
id: x.field,
|
||||
// field: `${x.field}${item.field ? '.' + item.field : ''}`,
|
||||
}}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{x.list?.length &&
|
||||
x.list.map((item) => {
|
||||
return (
|
||||
<FilterField
|
||||
key={item.id}
|
||||
item={{ ...item }}
|
||||
parent={{
|
||||
...x,
|
||||
id: x.field,
|
||||
// field: `${x.field}${item.field ? '.' + item.field : ''}`,
|
||||
}}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
|
||||
@ -14,6 +14,7 @@ import {
|
||||
IDocumentMetaRequestBody,
|
||||
} from '@/interfaces/request/document';
|
||||
import i18n from '@/locales/config';
|
||||
import { EMPTY_METADATA_FIELD } from '@/pages/dataset/dataset/use-select-filters';
|
||||
import kbService, { listDocument } from '@/services/knowledge-service';
|
||||
import api, { api_host } from '@/utils/api';
|
||||
import { buildChunkHighlights } from '@/utils/document-util';
|
||||
@ -114,6 +115,20 @@ export const useFetchDocumentList = () => {
|
||||
refetchInterval: isLoop ? 5000 : false,
|
||||
enabled: !!knowledgeId || !!id,
|
||||
queryFn: async () => {
|
||||
let run = [] as any;
|
||||
let returnEmptyMetadata = false;
|
||||
if (filterValue.run && Array.isArray(filterValue.run)) {
|
||||
run = [...(filterValue.run as string[])];
|
||||
const returnEmptyMetadataIndex = run.findIndex(
|
||||
(r: string) => r === EMPTY_METADATA_FIELD,
|
||||
);
|
||||
if (returnEmptyMetadataIndex > -1) {
|
||||
returnEmptyMetadata = true;
|
||||
run.splice(returnEmptyMetadataIndex, 1);
|
||||
}
|
||||
} else {
|
||||
run = filterValue.run;
|
||||
}
|
||||
const ret = await listDocument(
|
||||
{
|
||||
kb_id: knowledgeId || id,
|
||||
@ -123,7 +138,8 @@ export const useFetchDocumentList = () => {
|
||||
},
|
||||
{
|
||||
suffix: filterValue.type as string[],
|
||||
run_status: filterValue.run as string[],
|
||||
run_status: run as string[],
|
||||
return_empty_metadata: returnEmptyMetadata,
|
||||
metadata: filterValue.metadata as Record<string, string[]>,
|
||||
},
|
||||
);
|
||||
|
||||
@ -43,6 +43,18 @@ export interface IParserConfig {
|
||||
task_page_size?: number;
|
||||
raptor?: Raptor;
|
||||
graphrag?: GraphRag;
|
||||
image_context_window?: number;
|
||||
mineru_parse_method?: 'auto' | 'txt' | 'ocr';
|
||||
mineru_formula_enable?: boolean;
|
||||
mineru_table_enable?: boolean;
|
||||
mineru_lang?: string;
|
||||
entity_types?: string[];
|
||||
metadata?: Array<{
|
||||
key?: string;
|
||||
description?: string;
|
||||
enum?: string[];
|
||||
}>;
|
||||
enable_metadata?: boolean;
|
||||
}
|
||||
|
||||
interface Raptor {
|
||||
|
||||
@ -33,5 +33,6 @@ export interface IFetchKnowledgeListRequestParams {
|
||||
export interface IFetchDocumentListRequestBody {
|
||||
suffix?: string[];
|
||||
run_status?: string[];
|
||||
return_empty_metadata?: boolean;
|
||||
metadata?: Record<string, string[]>;
|
||||
}
|
||||
|
||||
@ -176,6 +176,15 @@ Procedural Memory: Learned skills, habits, and automated procedures.`,
|
||||
},
|
||||
knowledgeDetails: {
|
||||
metadata: {
|
||||
descriptionTip:
|
||||
'Provide descriptions or examples to guide LLM extract values for this field. If left empty, it will rely on the field name.',
|
||||
restrictTDefinedValuesTip:
|
||||
'Enum Mode: Restricts LLM extraction to match preset values only. Define values below.',
|
||||
valueExists:
|
||||
'Value already exists. Confirm to merge duplicates and combine all associated files.',
|
||||
fieldNameExists:
|
||||
'Field name already exists. Confirm to merge duplicates and combine all associated files.',
|
||||
fieldExists: 'Field already exists.',
|
||||
fieldSetting: 'Field settings',
|
||||
changesAffectNewParses: 'Changes affect new parses only.',
|
||||
editMetadataForDataset: 'View and edit metadata for ',
|
||||
@ -185,12 +194,25 @@ Procedural Memory: Learned skills, habits, and automated procedures.`,
|
||||
manageMetadata: 'Manage metadata',
|
||||
metadata: 'Metadata',
|
||||
values: 'Values',
|
||||
value: 'Value',
|
||||
action: 'Action',
|
||||
field: 'Field',
|
||||
description: 'Description',
|
||||
fieldName: 'Field name',
|
||||
editMetadata: 'Edit metadata',
|
||||
deleteWarn: 'This {{field}} will be removed from all associated files',
|
||||
deleteManageFieldAllWarn:
|
||||
'This field and all its corresponding values will be deleted from all associated files.',
|
||||
deleteManageValueAllWarn:
|
||||
'This value will be deleted from from all associated files.',
|
||||
deleteManageFieldSingleWarn:
|
||||
'This field and all its corresponding values will be deleted from this files.',
|
||||
deleteManageValueSingleWarn:
|
||||
'This value will be deleted from this files.',
|
||||
deleteSettingFieldWarn: `This field will be deleted; existing metadata won't be affected.`,
|
||||
deleteSettingValueWarn: `This value will be deleted; existing metadata won't be affected.`,
|
||||
},
|
||||
emptyMetadata: 'No metadata',
|
||||
metadataField: 'Metadata field',
|
||||
systemAttribute: 'System attribute',
|
||||
localUpload: 'Local upload',
|
||||
@ -334,9 +356,9 @@ Procedural Memory: Learned skills, habits, and automated procedures.`,
|
||||
html4excel: 'Excel to HTML',
|
||||
html4excelTip: `Use with the General chunking method. When disabled, spreadsheets (XLSX or XLS(Excel 97-2003)) in the knowledge base will be parsed into key-value pairs. When enabled, they will be parsed into HTML tables, splitting every 12 rows if the original table has more than 12 rows. See https://ragflow.io/docs/dev/enable_excel2html for details.`,
|
||||
autoKeywords: 'Auto-keyword',
|
||||
autoKeywordsTip: `Automatically extract N keywords for each chunk to increase their ranking for queries containing those keywords. Be aware that extra tokens will be consumed by the chat model specified in 'System model settings'. You can check or update the added keywords for a chunk from the chunk list. For details, see https://ragflow.io/docs/dev/autokeyword_autoquestion.`,
|
||||
autoKeywordsTip: `Automatically extract N keywords for each chunk to increase their ranking for queries containing those keywords. Be aware that extra tokens will be consumed by the indexing model specified in 'Configuration'. You can check or update the added keywords for a chunk from the chunk list. For details, see https://ragflow.io/docs/dev/autokeyword_autoquestion.`,
|
||||
autoQuestions: 'Auto-question',
|
||||
autoQuestionsTip: `Automatically extract N questions for each chunk to increase their ranking for queries containing those questions. You can check or update the added questions for a chunk from the chunk list. This feature will not disrupt the chunking process if an error occurs, except that it may add an empty result to the original chunk. Be aware that extra tokens will be consumed by the LLM specified in 'System model settings'. For details, see https://ragflow.io/docs/dev/autokeyword_autoquestion.`,
|
||||
autoQuestionsTip: `Automatically extract N questions for each chunk to increase their ranking for queries containing those questions. You can check or update the added questions for a chunk from the chunk list. This feature will not disrupt the chunking process if an error occurs, except that it may add an empty result to the original chunk. Be aware that extra tokens will be consumed by the indexing model specified in 'Configuration'. For details, see https://ragflow.io/docs/dev/autokeyword_autoquestion.`,
|
||||
redo: 'Do you want to clear the existing {{chunkNum}} chunks?',
|
||||
setMetaData: 'Set meta data',
|
||||
pleaseInputJson: 'Please enter JSON',
|
||||
@ -1607,6 +1629,8 @@ Example: Virtual Hosted Style`,
|
||||
notEmpty: 'Not empty',
|
||||
in: 'In',
|
||||
notIn: 'Not in',
|
||||
is: 'Is',
|
||||
isNot: 'Is not',
|
||||
},
|
||||
switchLogicOperatorOptions: {
|
||||
and: 'AND',
|
||||
@ -1868,6 +1892,8 @@ This process aggregates variables from multiple branches into a single variable
|
||||
beginInputTip:
|
||||
'By defining input parameters, this content can be accessed by other components in subsequent processes.',
|
||||
query: 'Query variables',
|
||||
switchPromptMessage:
|
||||
'The prompt words will change. Please confirm whether you want to discard the existing prompt words?',
|
||||
queryRequired: 'Query is required',
|
||||
queryTip: 'Select the variable you want to use',
|
||||
agent: 'Agent',
|
||||
|
||||
@ -168,6 +168,10 @@ export default {
|
||||
},
|
||||
knowledgeDetails: {
|
||||
metadata: {
|
||||
descriptionTip:
|
||||
'提供描述或示例来指导大语言模型为此字段提取值。如果留空,将依赖字段名称。',
|
||||
restrictTDefinedValuesTip:
|
||||
'枚举模式:限制大语言模型仅提取预设值。在下方定义值。',
|
||||
fieldSetting: '字段设置',
|
||||
changesAffectNewParses: '更改仅影响新解析。',
|
||||
editMetadataForDataset: '查看和编辑元数据于 ',
|
||||
@ -177,12 +181,25 @@ export default {
|
||||
manageMetadata: '管理元数据',
|
||||
metadata: '元数据',
|
||||
values: '值',
|
||||
value: '值',
|
||||
action: '操作',
|
||||
field: '字段',
|
||||
description: '描述',
|
||||
fieldName: '字段名',
|
||||
editMetadata: '编辑元数据',
|
||||
valueExists: '值已存在。确认合并重复项并组合所有关联文件。',
|
||||
fieldNameExists: '字段名已存在。确认合并重复项并组合所有关联文件。',
|
||||
fieldExists: '字段名已存在。',
|
||||
deleteWarn: '此 {{field}} 将从所有关联文件中移除',
|
||||
deleteManageFieldAllWarn:
|
||||
'此字段及其所有对应值将从所有关联的文件中删除。',
|
||||
deleteManageValueAllWarn: '此值将从所有关联的文件中删除。',
|
||||
deleteManageFieldSingleWarn: '此字段及其所有对应值将从此文件中删除。',
|
||||
deleteManageValueSingleWarn: '此值将从此文件中删除。',
|
||||
deleteSettingFieldWarn: `此字段将被删除;现有元数据不会受到影响。`,
|
||||
deleteSettingValueWarn: `此值将被删除;现有元数据不会受到影响。`,
|
||||
},
|
||||
emptyMetadata: '无元数据',
|
||||
localUpload: '本地上传',
|
||||
fileSize: '文件大小',
|
||||
fileType: '文件类型',
|
||||
@ -311,9 +328,9 @@ export default {
|
||||
html4excel: '表格转HTML',
|
||||
html4excelTip: `与 General 切片方法配合使用。未开启状态下,表格文件(XLSX、XLS(Excel 97-2003))会按行解析为键值对。开启后,表格文件会被解析为 HTML 表格。若原始表格超过 12 行,系统会自动按每 12 行拆分为多个 HTML 表格。欲了解更多详情,请参阅 https://ragflow.io/docs/dev/enable_excel2html。`,
|
||||
autoKeywords: '自动关键词提取',
|
||||
autoKeywordsTip: `自动为每个文本块中提取 N 个关键词,用以提升查询精度。请注意:该功能采用“系统模型设置”中设置的默认聊天模型提取关键词,因此也会产生更多 Token 消耗。另外,你也可以手动更新生成的关键词。详情请见 https://ragflow.io/docs/dev/autokeyword_autoquestion。`,
|
||||
autoKeywordsTip: `自动为每个文本块中提取 N 个关键词,用以提升查询精度。请注意:该功能采用在“配置”中指定的索引模型提取关键词,因此也会产生更多 Token 消耗。另外,你也可以手动更新生成的关键词。详情请见 https://ragflow.io/docs/dev/autokeyword_autoquestion。`,
|
||||
autoQuestions: '自动问题提取',
|
||||
autoQuestionsTip: `利用“系统模型设置”中设置的 chat model 对知识库的每个文本块提取 N 个问题以提高其排名得分。请注意,开启后将消耗额外的 token。您可以在块列表中查看、编辑结果。如果自动问题提取发生错误,不会妨碍整个分块过程,只会将空结果添加到原始文本块。详情请见 https://ragflow.io/docs/dev/autokeyword_autoquestion。`,
|
||||
autoQuestionsTip: `利用在“配置”中指定的索引模型 对知识库的每个文本块提取 N 个问题以提高其排名得分。请注意,开启后将消耗额外的 token。您可以在块列表中查看、编辑结果。如果自动问题提取发生错误,不会妨碍整个分块过程,只会将空结果添加到原始文本块。详情请见 https://ragflow.io/docs/dev/autokeyword_autoquestion。`,
|
||||
redo: '是否清空已有 {{chunkNum}}个 chunk?',
|
||||
setMetaData: '设置元数据',
|
||||
pleaseInputJson: '请输入JSON',
|
||||
@ -1505,6 +1522,8 @@ General:实体和关系提取提示来自 GitHub - microsoft/graphrag:基于
|
||||
endWith: '结束是',
|
||||
empty: '为空',
|
||||
notEmpty: '不为空',
|
||||
is: '是',
|
||||
isNot: '不是',
|
||||
},
|
||||
switchLogicOperatorOptions: {
|
||||
and: '与',
|
||||
|
||||
@ -4,6 +4,7 @@ import { useSendAgentMessage } from './use-send-agent-message';
|
||||
|
||||
import { FileUploadProps } from '@/components/file-upload';
|
||||
import { NextMessageInput } from '@/components/message-input/next';
|
||||
import MarkdownContent from '@/components/next-markdown-content';
|
||||
import MessageItem from '@/components/next-message-item';
|
||||
import PdfSheet from '@/components/pdf-drawer';
|
||||
import { useClickDrawer } from '@/components/pdf-drawer/hooks';
|
||||
@ -102,8 +103,10 @@ function AgentChatBox() {
|
||||
{message.role === MessageType.Assistant &&
|
||||
derivedMessages.length - 1 !== i && (
|
||||
<div>
|
||||
<div>{message?.data?.tips}</div>
|
||||
|
||||
<MarkdownContent
|
||||
content={message?.data?.tips}
|
||||
loading={false}
|
||||
></MarkdownContent>
|
||||
<div>
|
||||
{buildInputList(message)?.map((item) => item.value)}
|
||||
</div>
|
||||
|
||||
@ -206,6 +206,7 @@ export const initialSplitterValues = {
|
||||
chunk_token_size: 512,
|
||||
overlapped_percent: 0,
|
||||
delimiters: [{ value: '\n' }],
|
||||
image_table_context_window: 0,
|
||||
};
|
||||
|
||||
export enum Hierarchy {
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import MarkdownContent from '@/components/next-markdown-content';
|
||||
import { ButtonLoading } from '@/components/ui/button';
|
||||
import {
|
||||
Form,
|
||||
@ -234,7 +235,14 @@ const DebugContent = ({
|
||||
return (
|
||||
<>
|
||||
<section>
|
||||
{message?.data?.tips && <div className="mb-2">{message.data.tips}</div>}
|
||||
{message?.data?.tips && (
|
||||
<div className="mb-2">
|
||||
<MarkdownContent
|
||||
content={message?.data?.tips}
|
||||
loading={false}
|
||||
></MarkdownContent>
|
||||
</div>
|
||||
)}
|
||||
<Form {...form}>
|
||||
<form onSubmit={form.handleSubmit(onSubmit)} className="space-y-4">
|
||||
{parameters.map((x, idx) => {
|
||||
|
||||
@ -22,6 +22,7 @@ const outputList = buildOutputList(initialSplitterValues.outputs);
|
||||
|
||||
export const FormSchema = z.object({
|
||||
chunk_token_size: z.number(),
|
||||
image_table_context_window: z.number(),
|
||||
delimiters: z.array(
|
||||
z.object({
|
||||
value: z.string().optional(),
|
||||
@ -74,6 +75,13 @@ const SplitterForm = ({ node }: INextOperatorForm) => {
|
||||
min={0}
|
||||
label={t('flow.overlappedPercent')}
|
||||
></SliderInputFormField>
|
||||
<SliderInputFormField
|
||||
name="image_table_context_window"
|
||||
max={256}
|
||||
min={0}
|
||||
label={t('knowledgeConfiguration.imageTableContextWindow')}
|
||||
tooltip={t('knowledgeConfiguration.imageTableContextWindowTip')}
|
||||
></SliderInputFormField>
|
||||
<section>
|
||||
<span className="mb-2 inline-block">{t('flow.delimiters')}</span>
|
||||
<div className="space-y-4">
|
||||
|
||||
@ -466,7 +466,7 @@ const useGraphStore = create<RFState>()(
|
||||
}
|
||||
},
|
||||
updateSwitchFormData: (source, sourceHandle, target, isConnecting) => {
|
||||
const { updateNodeForm, edges } = get();
|
||||
const { updateNodeForm, edges, getOperatorTypeFromId } = get();
|
||||
if (sourceHandle) {
|
||||
// A handle will connect to multiple downstream nodes
|
||||
let currentHandleTargets = edges
|
||||
@ -474,7 +474,8 @@ const useGraphStore = create<RFState>()(
|
||||
(x) =>
|
||||
x.source === source &&
|
||||
x.sourceHandle === sourceHandle &&
|
||||
typeof x.target === 'string',
|
||||
typeof x.target === 'string' &&
|
||||
getOperatorTypeFromId(x.target) !== Operator.Placeholder,
|
||||
)
|
||||
.map((x) => x.target);
|
||||
|
||||
|
||||
@ -289,10 +289,14 @@ function transformParserParams(params: ParserFormSchemaType) {
|
||||
}
|
||||
|
||||
function transformSplitterParams(params: SplitterFormSchemaType) {
|
||||
const { image_table_context_window, ...rest } = params;
|
||||
const imageTableContextWindow = Number(image_table_context_window || 0);
|
||||
return {
|
||||
...params,
|
||||
...rest,
|
||||
overlapped_percent: Number(params.overlapped_percent) / 100,
|
||||
delimiters: transformObjectArrayToPureArray(params.delimiters, 'value'),
|
||||
table_context_size: imageTableContextWindow,
|
||||
image_context_size: imageTableContextWindow,
|
||||
|
||||
// Unset children delimiters if this option is not enabled
|
||||
children_delimiters: params.enable_children
|
||||
|
||||
@ -1,11 +1,15 @@
|
||||
import message from '@/components/ui/message';
|
||||
import { useSetModalState } from '@/hooks/common-hooks';
|
||||
import { useSetDocumentMeta } from '@/hooks/use-document-request';
|
||||
import {
|
||||
DocumentApiAction,
|
||||
useSetDocumentMeta,
|
||||
} from '@/hooks/use-document-request';
|
||||
import kbService, {
|
||||
getMetaDataService,
|
||||
updateMetaData,
|
||||
} from '@/services/knowledge-service';
|
||||
import { useQuery } from '@tanstack/react-query';
|
||||
import { useQuery, useQueryClient } from '@tanstack/react-query';
|
||||
import { TFunction } from 'i18next';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useParams } from 'umi';
|
||||
@ -16,12 +20,43 @@ import {
|
||||
IMetaDataTableData,
|
||||
MetadataOperations,
|
||||
ShowManageMetadataModalProps,
|
||||
} from './interface';
|
||||
} from '../interface';
|
||||
export enum MetadataType {
|
||||
Manage = 1,
|
||||
UpdateSingle = 2,
|
||||
Setting = 3,
|
||||
SingleFileSetting = 4,
|
||||
}
|
||||
|
||||
export const MetadataDeleteMap = (
|
||||
t: TFunction<'translation', undefined>,
|
||||
): Record<
|
||||
MetadataType,
|
||||
{ title: string; warnFieldText: string; warnValueText: string }
|
||||
> => {
|
||||
return {
|
||||
[MetadataType.Manage]: {
|
||||
title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.metadata'),
|
||||
warnFieldText: t('knowledgeDetails.metadata.deleteManageFieldAllWarn'),
|
||||
warnValueText: t('knowledgeDetails.metadata.deleteManageValueAllWarn'),
|
||||
},
|
||||
[MetadataType.Setting]: {
|
||||
title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.metadata'),
|
||||
warnFieldText: t('knowledgeDetails.metadata.deleteSettingFieldWarn'),
|
||||
warnValueText: t('knowledgeDetails.metadata.deleteSettingValueWarn'),
|
||||
},
|
||||
[MetadataType.UpdateSingle]: {
|
||||
title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.metadata'),
|
||||
warnFieldText: t('knowledgeDetails.metadata.deleteManageFieldSingleWarn'),
|
||||
warnValueText: t('knowledgeDetails.metadata.deleteManageValueSingleWarn'),
|
||||
},
|
||||
[MetadataType.SingleFileSetting]: {
|
||||
title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.metadata'),
|
||||
warnFieldText: t('knowledgeDetails.metadata.deleteSettingFieldWarn'),
|
||||
warnValueText: t('knowledgeDetails.metadata.deleteSettingValueWarn'),
|
||||
},
|
||||
};
|
||||
};
|
||||
export const util = {
|
||||
changeToMetaDataTableData(data: IMetaDataReturnType): IMetaDataTableData[] {
|
||||
return Object.entries(data).map(([key, value]) => {
|
||||
@ -39,10 +74,21 @@ export const util = {
|
||||
data: Record<string, string | string[]>,
|
||||
): IMetaDataTableData[] {
|
||||
return Object.entries(data).map(([key, value]) => {
|
||||
let thisValue = [] as string[];
|
||||
if (value && Array.isArray(value)) {
|
||||
thisValue = value;
|
||||
} else if (value && typeof value === 'string') {
|
||||
thisValue = [value];
|
||||
} else if (value && typeof value === 'object') {
|
||||
thisValue = [JSON.stringify(value)];
|
||||
} else if (value) {
|
||||
thisValue = [value.toString()];
|
||||
}
|
||||
|
||||
return {
|
||||
field: key,
|
||||
description: '',
|
||||
values: value,
|
||||
values: thisValue,
|
||||
} as IMetaDataTableData;
|
||||
});
|
||||
},
|
||||
@ -100,12 +146,42 @@ export const useMetadataOperations = () => {
|
||||
}));
|
||||
}, []);
|
||||
|
||||
// const addUpdateValue = useCallback(
|
||||
// (key: string, value: string | string[]) => {
|
||||
// setOperations((prev) => ({
|
||||
// ...prev,
|
||||
// updates: [...prev.updates, { key, value }],
|
||||
// }));
|
||||
// },
|
||||
// [],
|
||||
// );
|
||||
const addUpdateValue = useCallback(
|
||||
(key: string, value: string | string[]) => {
|
||||
setOperations((prev) => ({
|
||||
...prev,
|
||||
updates: [...prev.updates, { key, value }],
|
||||
}));
|
||||
(key: string, originalValue: string, newValue: string) => {
|
||||
setOperations((prev) => {
|
||||
const existsIndex = prev.updates.findIndex(
|
||||
(update) => update.key === key && update.match === originalValue,
|
||||
);
|
||||
|
||||
if (existsIndex > -1) {
|
||||
const updatedUpdates = [...prev.updates];
|
||||
updatedUpdates[existsIndex] = {
|
||||
key,
|
||||
match: originalValue,
|
||||
value: newValue,
|
||||
};
|
||||
return {
|
||||
...prev,
|
||||
updates: updatedUpdates,
|
||||
};
|
||||
}
|
||||
return {
|
||||
...prev,
|
||||
updates: [
|
||||
...prev.updates,
|
||||
{ key, match: originalValue, value: newValue },
|
||||
],
|
||||
};
|
||||
});
|
||||
},
|
||||
[],
|
||||
);
|
||||
@ -191,9 +267,14 @@ export const useManageMetaDataModal = (
|
||||
const { data, loading } = useFetchMetaDataManageData(type);
|
||||
|
||||
const [tableData, setTableData] = useState<IMetaDataTableData[]>(metaData);
|
||||
|
||||
const { operations, addDeleteRow, addDeleteValue, addUpdateValue } =
|
||||
useMetadataOperations();
|
||||
const queryClient = useQueryClient();
|
||||
const {
|
||||
operations,
|
||||
addDeleteRow,
|
||||
addDeleteValue,
|
||||
addUpdateValue,
|
||||
resetOperations,
|
||||
} = useMetadataOperations();
|
||||
|
||||
const { setDocumentMeta } = useSetDocumentMeta();
|
||||
|
||||
@ -259,11 +340,15 @@ export const useManageMetaDataModal = (
|
||||
data: operations,
|
||||
});
|
||||
if (res.code === 0) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: [DocumentApiAction.FetchDocumentList],
|
||||
});
|
||||
resetOperations();
|
||||
message.success(t('message.operated'));
|
||||
callback();
|
||||
}
|
||||
},
|
||||
[operations, id, t],
|
||||
[operations, id, t, queryClient, resetOperations],
|
||||
);
|
||||
|
||||
const handleSaveUpdateSingle = useCallback(
|
||||
@ -297,7 +382,26 @@ export const useManageMetaDataModal = (
|
||||
|
||||
return data;
|
||||
},
|
||||
[tableData, id],
|
||||
[tableData, id, t],
|
||||
);
|
||||
|
||||
const handleSaveSingleFileSettings = useCallback(
|
||||
async (callback: () => void) => {
|
||||
const data = util.tableDataToMetaDataSettingJSON(tableData);
|
||||
if (otherData?.documentId) {
|
||||
const { data: res } = await kbService.documentUpdateMetaData({
|
||||
doc_id: otherData.documentId,
|
||||
metadata: data,
|
||||
});
|
||||
if (res.code === 0) {
|
||||
message.success(t('message.operated'));
|
||||
callback?.();
|
||||
}
|
||||
}
|
||||
|
||||
return data;
|
||||
},
|
||||
[tableData, t, otherData],
|
||||
);
|
||||
|
||||
const handleSave = useCallback(
|
||||
@ -311,12 +415,20 @@ export const useManageMetaDataModal = (
|
||||
break;
|
||||
case MetadataType.Setting:
|
||||
return handleSaveSettings(callback);
|
||||
case MetadataType.SingleFileSetting:
|
||||
return handleSaveSingleFileSettings(callback);
|
||||
default:
|
||||
handleSaveManage(callback);
|
||||
break;
|
||||
}
|
||||
},
|
||||
[handleSaveManage, type, handleSaveUpdateSingle, handleSaveSettings],
|
||||
[
|
||||
handleSaveManage,
|
||||
type,
|
||||
handleSaveUpdateSingle,
|
||||
handleSaveSettings,
|
||||
handleSaveSingleFileSettings,
|
||||
],
|
||||
);
|
||||
|
||||
return {
|
||||
@ -371,11 +483,3 @@ export const useManageMetadata = () => {
|
||||
config,
|
||||
};
|
||||
};
|
||||
|
||||
export const useManageValues = () => {
|
||||
const [updateValues, setUpdateValues] = useState<{
|
||||
field: string;
|
||||
values: string[];
|
||||
} | null>(null);
|
||||
return { updateValues, setUpdateValues };
|
||||
};
|
||||
@ -0,0 +1,208 @@
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { MetadataDeleteMap, MetadataType } from '../hooks/use-manage-modal';
|
||||
import { IManageValuesProps, IMetaDataTableData } from '../interface';
|
||||
|
||||
export const useManageValues = (props: IManageValuesProps) => {
|
||||
const {
|
||||
data,
|
||||
|
||||
isShowValueSwitch,
|
||||
hideModal,
|
||||
onSave,
|
||||
addUpdateValue,
|
||||
addDeleteValue,
|
||||
existsKeys,
|
||||
type,
|
||||
} = props;
|
||||
const { t } = useTranslation();
|
||||
const [metaData, setMetaData] = useState(data);
|
||||
const [valueError, setValueError] = useState<Record<string, string>>({
|
||||
field: '',
|
||||
values: '',
|
||||
});
|
||||
const [deleteDialogContent, setDeleteDialogContent] = useState({
|
||||
visible: false,
|
||||
title: '',
|
||||
name: '',
|
||||
warnText: '',
|
||||
onOk: () => {},
|
||||
onCancel: () => {},
|
||||
});
|
||||
const hideDeleteModal = () => {
|
||||
setDeleteDialogContent({
|
||||
visible: false,
|
||||
title: '',
|
||||
name: '',
|
||||
warnText: '',
|
||||
onOk: () => {},
|
||||
onCancel: () => {},
|
||||
});
|
||||
};
|
||||
|
||||
// Use functional update to avoid closure issues
|
||||
const handleChange = useCallback(
|
||||
(field: string, value: any) => {
|
||||
if (field === 'field' && existsKeys.includes(value)) {
|
||||
setValueError((prev) => {
|
||||
return {
|
||||
...prev,
|
||||
field:
|
||||
type === MetadataType.Setting
|
||||
? t('knowledgeDetails.metadata.fieldExists')
|
||||
: t('knowledgeDetails.metadata.fieldNameExists'),
|
||||
};
|
||||
});
|
||||
} else if (field === 'field' && !existsKeys.includes(value)) {
|
||||
setValueError((prev) => {
|
||||
return {
|
||||
...prev,
|
||||
field: '',
|
||||
};
|
||||
});
|
||||
}
|
||||
setMetaData((prev) => ({
|
||||
...prev,
|
||||
[field]: value,
|
||||
}));
|
||||
},
|
||||
[existsKeys, type, t],
|
||||
);
|
||||
|
||||
// Maintain separate state for each input box
|
||||
const [tempValues, setTempValues] = useState<string[]>([...data.values]);
|
||||
|
||||
useEffect(() => {
|
||||
setTempValues([...data.values]);
|
||||
setMetaData(data);
|
||||
}, [data]);
|
||||
|
||||
const handleHideModal = useCallback(() => {
|
||||
hideModal();
|
||||
setMetaData({} as IMetaDataTableData);
|
||||
}, [hideModal]);
|
||||
|
||||
const handleSave = useCallback(() => {
|
||||
if (type === MetadataType.Setting && valueError.field) {
|
||||
return;
|
||||
}
|
||||
if (!metaData.restrictDefinedValues && isShowValueSwitch) {
|
||||
const newMetaData = { ...metaData, values: [] };
|
||||
onSave(newMetaData);
|
||||
} else {
|
||||
onSave(metaData);
|
||||
}
|
||||
handleHideModal();
|
||||
}, [metaData, onSave, handleHideModal, isShowValueSwitch, type, valueError]);
|
||||
|
||||
// Handle value changes, only update temporary state
|
||||
const handleValueChange = useCallback(
|
||||
(index: number, value: string) => {
|
||||
setTempValues((prev) => {
|
||||
if (prev.includes(value)) {
|
||||
setValueError((prev) => {
|
||||
return {
|
||||
...prev,
|
||||
values: t('knowledgeDetails.metadata.valueExists'),
|
||||
};
|
||||
});
|
||||
} else {
|
||||
setValueError((prev) => {
|
||||
return {
|
||||
...prev,
|
||||
values: '',
|
||||
};
|
||||
});
|
||||
}
|
||||
const newValues = [...prev];
|
||||
newValues[index] = value;
|
||||
|
||||
return newValues;
|
||||
});
|
||||
},
|
||||
[t],
|
||||
);
|
||||
|
||||
// Handle blur event, synchronize to main state
|
||||
const handleValueBlur = useCallback(() => {
|
||||
// addUpdateValue(metaData.field, [...new Set([...tempValues])]);
|
||||
tempValues.forEach((newValue, index) => {
|
||||
if (index < data.values.length) {
|
||||
const originalValue = data.values[index];
|
||||
if (originalValue !== newValue) {
|
||||
addUpdateValue(metaData.field, originalValue, newValue);
|
||||
}
|
||||
} else {
|
||||
if (newValue) {
|
||||
addUpdateValue(metaData.field, '', newValue);
|
||||
}
|
||||
}
|
||||
});
|
||||
handleChange('values', [...new Set([...tempValues])]);
|
||||
}, [handleChange, tempValues, metaData, data, addUpdateValue]);
|
||||
|
||||
// Handle delete operation
|
||||
const handleDelete = useCallback(
|
||||
(index: number) => {
|
||||
setTempValues((prev) => {
|
||||
const newTempValues = [...prev];
|
||||
addDeleteValue(metaData.field, newTempValues[index]);
|
||||
newTempValues.splice(index, 1);
|
||||
return newTempValues;
|
||||
});
|
||||
|
||||
// Synchronize to main state
|
||||
setMetaData((prev) => {
|
||||
const newMetaDataValues = [...prev.values];
|
||||
newMetaDataValues.splice(index, 1);
|
||||
return {
|
||||
...prev,
|
||||
values: newMetaDataValues,
|
||||
};
|
||||
});
|
||||
},
|
||||
[addDeleteValue, metaData],
|
||||
);
|
||||
|
||||
const showDeleteModal = (item: string, callback: () => void) => {
|
||||
setDeleteDialogContent({
|
||||
visible: true,
|
||||
title: t('common.delete') + ' ' + t('knowledgeDetails.metadata.value'),
|
||||
name: item,
|
||||
warnText: MetadataDeleteMap(t)[type as MetadataType].warnValueText,
|
||||
onOk: () => {
|
||||
hideDeleteModal();
|
||||
callback();
|
||||
},
|
||||
onCancel: () => {
|
||||
hideDeleteModal();
|
||||
},
|
||||
});
|
||||
};
|
||||
|
||||
// Handle adding new value
|
||||
const handleAddValue = useCallback(() => {
|
||||
setTempValues((prev) => [...new Set([...prev, ''])]);
|
||||
|
||||
// Synchronize to main state
|
||||
setMetaData((prev) => ({
|
||||
...prev,
|
||||
values: [...new Set([...prev.values, ''])],
|
||||
}));
|
||||
}, []);
|
||||
|
||||
return {
|
||||
metaData,
|
||||
tempValues,
|
||||
valueError,
|
||||
deleteDialogContent,
|
||||
handleChange,
|
||||
handleValueChange,
|
||||
handleValueBlur,
|
||||
handleDelete,
|
||||
handleAddValue,
|
||||
showDeleteModal,
|
||||
handleSave,
|
||||
handleHideModal,
|
||||
};
|
||||
};
|
||||
@ -39,6 +39,7 @@ export type IManageModalProps = {
|
||||
|
||||
export interface IManageValuesProps {
|
||||
title: ReactNode;
|
||||
existsKeys: string[];
|
||||
visible: boolean;
|
||||
isEditField?: boolean;
|
||||
isAddValue?: boolean;
|
||||
@ -46,9 +47,14 @@ export interface IManageValuesProps {
|
||||
isShowValueSwitch?: boolean;
|
||||
isVerticalShowValue?: boolean;
|
||||
data: IMetaDataTableData;
|
||||
type: MetadataType;
|
||||
hideModal: () => void;
|
||||
onSave: (data: IMetaDataTableData) => void;
|
||||
addUpdateValue: (key: string, value: string | string[]) => void;
|
||||
addUpdateValue: (
|
||||
key: string,
|
||||
originalValue: string,
|
||||
newValue: string,
|
||||
) => void;
|
||||
addDeleteValue: (key: string, value: string) => void;
|
||||
}
|
||||
|
||||
@ -59,7 +65,8 @@ interface DeleteOperation {
|
||||
|
||||
interface UpdateOperation {
|
||||
key: string;
|
||||
value: string | string[];
|
||||
match: string;
|
||||
value: string;
|
||||
}
|
||||
|
||||
export interface MetadataOperations {
|
||||
|
||||
@ -25,11 +25,16 @@ import {
|
||||
useReactTable,
|
||||
} from '@tanstack/react-table';
|
||||
import { Plus, Settings, Trash2 } from 'lucide-react';
|
||||
import { useCallback, useMemo, useState } from 'react';
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { MetadataType, useManageMetaDataModal } from './hook';
|
||||
import {
|
||||
MetadataDeleteMap,
|
||||
MetadataType,
|
||||
useManageMetaDataModal,
|
||||
} from './hooks/use-manage-modal';
|
||||
import { IManageModalProps, IMetaDataTableData } from './interface';
|
||||
import { ManageValuesModal } from './manage-values-modal';
|
||||
|
||||
export const ManageMetadataModal = (props: IManageModalProps) => {
|
||||
const {
|
||||
title,
|
||||
@ -54,6 +59,7 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
|
||||
values: [],
|
||||
});
|
||||
|
||||
const [currentValueIndex, setCurrentValueIndex] = useState<number>(0);
|
||||
const [deleteDialogContent, setDeleteDialogContent] = useState({
|
||||
visible: false,
|
||||
title: '',
|
||||
@ -94,12 +100,12 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
|
||||
description: '',
|
||||
values: [],
|
||||
});
|
||||
// setCurrentValueIndex(tableData.length || 0);
|
||||
setCurrentValueIndex(tableData.length || 0);
|
||||
showManageValuesModal();
|
||||
};
|
||||
const handleEditValueRow = useCallback(
|
||||
(data: IMetaDataTableData) => {
|
||||
// setCurrentValueIndex(index);
|
||||
(data: IMetaDataTableData, index: number) => {
|
||||
setCurrentValueIndex(index);
|
||||
setValueData(data);
|
||||
showManageValuesModal();
|
||||
},
|
||||
@ -133,7 +139,8 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
|
||||
const values = row.getValue('values') as Array<string>;
|
||||
return (
|
||||
<div className="flex items-center gap-1">
|
||||
{values.length > 0 &&
|
||||
{Array.isArray(values) &&
|
||||
values.length > 0 &&
|
||||
values
|
||||
.filter((value: string, index: number) => index < 2)
|
||||
?.map((value: string) => {
|
||||
@ -153,10 +160,28 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
|
||||
variant={'delete'}
|
||||
className="p-0 bg-transparent"
|
||||
onClick={() => {
|
||||
handleDeleteSingleValue(
|
||||
row.getValue('field'),
|
||||
value,
|
||||
);
|
||||
setDeleteDialogContent({
|
||||
visible: true,
|
||||
title:
|
||||
t('common.delete') +
|
||||
' ' +
|
||||
t('knowledgeDetails.metadata.value'),
|
||||
name: value,
|
||||
warnText:
|
||||
MetadataDeleteMap(t)[
|
||||
metadataType as MetadataType
|
||||
].warnValueText,
|
||||
onOk: () => {
|
||||
hideDeleteModal();
|
||||
handleDeleteSingleValue(
|
||||
row.getValue('field'),
|
||||
value,
|
||||
);
|
||||
},
|
||||
onCancel: () => {
|
||||
hideDeleteModal();
|
||||
},
|
||||
});
|
||||
}}
|
||||
>
|
||||
<Trash2 />
|
||||
@ -166,7 +191,7 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
|
||||
</Button>
|
||||
);
|
||||
})}
|
||||
{values.length > 2 && (
|
||||
{Array.isArray(values) && values.length > 2 && (
|
||||
<div className="text-text-secondary self-end">...</div>
|
||||
)}
|
||||
</div>
|
||||
@ -185,7 +210,7 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
|
||||
variant={'ghost'}
|
||||
className="bg-transparent px-1 py-0"
|
||||
onClick={() => {
|
||||
handleEditValueRow(row.original);
|
||||
handleEditValueRow(row.original, row.index);
|
||||
}}
|
||||
>
|
||||
<Settings />
|
||||
@ -197,11 +222,14 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
|
||||
setDeleteDialogContent({
|
||||
visible: true,
|
||||
title:
|
||||
t('common.delete') +
|
||||
' ' +
|
||||
t('knowledgeDetails.metadata.metadata'),
|
||||
// t('common.delete') +
|
||||
// ' ' +
|
||||
// t('knowledgeDetails.metadata.metadata')
|
||||
MetadataDeleteMap(t)[metadataType as MetadataType].title,
|
||||
name: row.getValue('field'),
|
||||
warnText: t('knowledgeDetails.metadata.deleteWarn'),
|
||||
warnText:
|
||||
MetadataDeleteMap(t)[metadataType as MetadataType]
|
||||
.warnFieldText,
|
||||
onOk: () => {
|
||||
hideDeleteModal();
|
||||
handleDeleteSingleRow(row.getValue('field'));
|
||||
@ -240,15 +268,29 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
|
||||
getFilteredRowModel: getFilteredRowModel(),
|
||||
manualPagination: true,
|
||||
});
|
||||
|
||||
const [shouldSave, setShouldSave] = useState(false);
|
||||
const handleSaveValues = (data: IMetaDataTableData) => {
|
||||
setTableData((prev) => {
|
||||
//If the keys are the same, they need to be merged.
|
||||
const fieldMap = new Map<string, any>();
|
||||
let newData;
|
||||
if (currentValueIndex >= prev.length) {
|
||||
// Add operation
|
||||
newData = [...prev, data];
|
||||
} else {
|
||||
// Edit operation
|
||||
newData = prev.map((item, index) => {
|
||||
if (index === currentValueIndex) {
|
||||
return data;
|
||||
}
|
||||
return item;
|
||||
});
|
||||
}
|
||||
|
||||
prev.forEach((item) => {
|
||||
// Deduplicate by field and merge values
|
||||
const fieldMap = new Map<string, IMetaDataTableData>();
|
||||
newData.forEach((item) => {
|
||||
if (fieldMap.has(item.field)) {
|
||||
const existingItem = fieldMap.get(item.field);
|
||||
// Merge values if field exists
|
||||
const existingItem = fieldMap.get(item.field)!;
|
||||
const mergedValues = [
|
||||
...new Set([...existingItem.values, ...item.values]),
|
||||
];
|
||||
@ -258,20 +300,26 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
|
||||
}
|
||||
});
|
||||
|
||||
if (fieldMap.has(data.field)) {
|
||||
const existingItem = fieldMap.get(data.field);
|
||||
const mergedValues = [
|
||||
...new Set([...existingItem.values, ...data.values]),
|
||||
];
|
||||
fieldMap.set(data.field, { ...existingItem, values: mergedValues });
|
||||
} else {
|
||||
fieldMap.set(data.field, data);
|
||||
}
|
||||
|
||||
return Array.from(fieldMap.values());
|
||||
});
|
||||
setShouldSave(true);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
if (shouldSave) {
|
||||
const timer = setTimeout(() => {
|
||||
handleSave({ callback: () => {} });
|
||||
setShouldSave(false);
|
||||
}, 0);
|
||||
|
||||
return () => clearTimeout(timer);
|
||||
}
|
||||
}, [tableData, shouldSave, handleSave]);
|
||||
|
||||
const existsKeys = useMemo(() => {
|
||||
return tableData.map((item) => item.field);
|
||||
}, [tableData]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Modal
|
||||
@ -352,11 +400,14 @@ export const ManageMetadataModal = (props: IManageModalProps) => {
|
||||
<ManageValuesModal
|
||||
title={
|
||||
<div>
|
||||
{metadataType === MetadataType.Setting
|
||||
{metadataType === MetadataType.Setting ||
|
||||
metadataType === MetadataType.SingleFileSetting
|
||||
? t('knowledgeDetails.metadata.fieldSetting')
|
||||
: t('knowledgeDetails.metadata.editMetadata')}
|
||||
</div>
|
||||
}
|
||||
type={metadataType}
|
||||
existsKeys={existsKeys}
|
||||
visible={manageValuesVisible}
|
||||
hideModal={hideManageValuesModal}
|
||||
data={valueData}
|
||||
|
||||
@ -1,13 +1,19 @@
|
||||
import {
|
||||
ConfirmDeleteDialog,
|
||||
ConfirmDeleteDialogNode,
|
||||
} from '@/components/confirm-delete-dialog';
|
||||
import EditTag from '@/components/edit-tag';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { FormLabel } from '@/components/ui/form';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { Modal } from '@/components/ui/modal/modal';
|
||||
import { Switch } from '@/components/ui/switch';
|
||||
import { Textarea } from '@/components/ui/textarea';
|
||||
import { Plus, Trash2 } from 'lucide-react';
|
||||
import { memo, useCallback, useEffect, useState } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IManageValuesProps, IMetaDataTableData } from './interface';
|
||||
import { useManageValues } from './hooks/use-manage-values-modal';
|
||||
import { IManageValuesProps } from './interface';
|
||||
|
||||
// Create a separate input component, wrapped with memo to avoid unnecessary re-renders
|
||||
const ValueInputItem = memo(
|
||||
@ -52,102 +58,29 @@ const ValueInputItem = memo(
|
||||
export const ManageValuesModal = (props: IManageValuesProps) => {
|
||||
const {
|
||||
title,
|
||||
data,
|
||||
isEditField,
|
||||
visible,
|
||||
isAddValue,
|
||||
isShowDescription,
|
||||
isShowValueSwitch,
|
||||
isVerticalShowValue,
|
||||
hideModal,
|
||||
onSave,
|
||||
addUpdateValue,
|
||||
addDeleteValue,
|
||||
} = props;
|
||||
const [metaData, setMetaData] = useState(data);
|
||||
const {
|
||||
metaData,
|
||||
tempValues,
|
||||
valueError,
|
||||
deleteDialogContent,
|
||||
handleChange,
|
||||
handleValueChange,
|
||||
handleValueBlur,
|
||||
handleDelete,
|
||||
handleAddValue,
|
||||
showDeleteModal,
|
||||
handleSave,
|
||||
handleHideModal,
|
||||
} = useManageValues(props);
|
||||
const { t } = useTranslation();
|
||||
|
||||
// Use functional update to avoid closure issues
|
||||
const handleChange = useCallback((field: string, value: any) => {
|
||||
setMetaData((prev) => ({
|
||||
...prev,
|
||||
[field]: value,
|
||||
}));
|
||||
}, []);
|
||||
|
||||
// Maintain separate state for each input box
|
||||
const [tempValues, setTempValues] = useState<string[]>([...data.values]);
|
||||
|
||||
useEffect(() => {
|
||||
setTempValues([...data.values]);
|
||||
setMetaData(data);
|
||||
}, [data]);
|
||||
|
||||
const handleHideModal = useCallback(() => {
|
||||
hideModal();
|
||||
setMetaData({} as IMetaDataTableData);
|
||||
}, [hideModal]);
|
||||
|
||||
const handleSave = useCallback(() => {
|
||||
if (!metaData.restrictDefinedValues && isShowValueSwitch) {
|
||||
const newMetaData = { ...metaData, values: [] };
|
||||
onSave(newMetaData);
|
||||
} else {
|
||||
onSave(metaData);
|
||||
}
|
||||
handleHideModal();
|
||||
}, [metaData, onSave, handleHideModal, isShowValueSwitch]);
|
||||
|
||||
// Handle value changes, only update temporary state
|
||||
const handleValueChange = useCallback((index: number, value: string) => {
|
||||
setTempValues((prev) => {
|
||||
const newValues = [...prev];
|
||||
newValues[index] = value;
|
||||
|
||||
return newValues;
|
||||
});
|
||||
}, []);
|
||||
|
||||
// Handle blur event, synchronize to main state
|
||||
const handleValueBlur = useCallback(() => {
|
||||
addUpdateValue(metaData.field, [...new Set([...tempValues])]);
|
||||
handleChange('values', [...new Set([...tempValues])]);
|
||||
}, [handleChange, tempValues, metaData, addUpdateValue]);
|
||||
|
||||
// Handle delete operation
|
||||
const handleDelete = useCallback(
|
||||
(index: number) => {
|
||||
setTempValues((prev) => {
|
||||
const newTempValues = [...prev];
|
||||
addDeleteValue(metaData.field, newTempValues[index]);
|
||||
newTempValues.splice(index, 1);
|
||||
return newTempValues;
|
||||
});
|
||||
|
||||
// Synchronize to main state
|
||||
setMetaData((prev) => {
|
||||
const newMetaDataValues = [...prev.values];
|
||||
newMetaDataValues.splice(index, 1);
|
||||
return {
|
||||
...prev,
|
||||
values: newMetaDataValues,
|
||||
};
|
||||
});
|
||||
},
|
||||
[addDeleteValue, metaData],
|
||||
);
|
||||
|
||||
// Handle adding new value
|
||||
const handleAddValue = useCallback(() => {
|
||||
setTempValues((prev) => [...new Set([...prev, ''])]);
|
||||
|
||||
// Synchronize to main state
|
||||
setMetaData((prev) => ({
|
||||
...prev,
|
||||
values: [...new Set([...prev.values, ''])],
|
||||
}));
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={title}
|
||||
@ -172,15 +105,24 @@ export const ManageValuesModal = (props: IManageValuesProps) => {
|
||||
<Input
|
||||
value={metaData.field}
|
||||
onChange={(e) => {
|
||||
handleChange('field', e.target?.value || '');
|
||||
const value = e.target?.value || '';
|
||||
if (/^[a-zA-Z_]*$/.test(value)) {
|
||||
handleChange('field', value);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<div className="text-state-error text-sm">{valueError.field}</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{isShowDescription && (
|
||||
<div className="flex flex-col gap-2">
|
||||
<div>{t('knowledgeDetails.metadata.description')}</div>
|
||||
<FormLabel
|
||||
className="text-text-primary text-base"
|
||||
tooltip={t('knowledgeDetails.metadata.descriptionTip')}
|
||||
>
|
||||
{t('knowledgeDetails.metadata.description')}
|
||||
</FormLabel>
|
||||
<div>
|
||||
<Textarea
|
||||
value={metaData.description}
|
||||
@ -193,7 +135,12 @@ export const ManageValuesModal = (props: IManageValuesProps) => {
|
||||
)}
|
||||
{isShowValueSwitch && (
|
||||
<div className="flex flex-col gap-2">
|
||||
<div>{t('knowledgeDetails.metadata.restrictDefinedValues')}</div>
|
||||
<FormLabel
|
||||
className="text-text-primary text-base"
|
||||
tooltip={t('knowledgeDetails.metadata.restrictTDefinedValuesTip')}
|
||||
>
|
||||
{t('knowledgeDetails.metadata.restrictDefinedValues')}
|
||||
</FormLabel>
|
||||
<div>
|
||||
<Switch
|
||||
checked={metaData.restrictDefinedValues || false}
|
||||
@ -230,7 +177,11 @@ export const ManageValuesModal = (props: IManageValuesProps) => {
|
||||
item={item}
|
||||
index={index}
|
||||
onValueChange={handleValueChange}
|
||||
onDelete={handleDelete}
|
||||
onDelete={(idx: number) => {
|
||||
showDeleteModal(item, () => {
|
||||
handleDelete(idx);
|
||||
});
|
||||
}}
|
||||
onBlur={handleValueBlur}
|
||||
/>
|
||||
);
|
||||
@ -240,11 +191,41 @@ export const ManageValuesModal = (props: IManageValuesProps) => {
|
||||
{!isVerticalShowValue && (
|
||||
<EditTag
|
||||
value={metaData.values}
|
||||
onChange={(value) => handleChange('values', value)}
|
||||
onChange={(value) => {
|
||||
// find deleted value
|
||||
const item = metaData.values.find(
|
||||
(item) => !value.includes(item),
|
||||
);
|
||||
if (item) {
|
||||
showDeleteModal(item, () => {
|
||||
// handleDelete(idx);
|
||||
handleChange('values', value);
|
||||
});
|
||||
} else {
|
||||
handleChange('values', value);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
<div className="text-state-error text-sm">{valueError.values}</div>
|
||||
</div>
|
||||
)}
|
||||
{deleteDialogContent.visible && (
|
||||
<ConfirmDeleteDialog
|
||||
open={deleteDialogContent.visible}
|
||||
onCancel={deleteDialogContent.onCancel}
|
||||
onOk={deleteDialogContent.onOk}
|
||||
title={deleteDialogContent.title}
|
||||
content={{
|
||||
node: (
|
||||
<ConfirmDeleteDialogNode
|
||||
name={deleteDialogContent.name}
|
||||
warnText={deleteDialogContent.warnText}
|
||||
/>
|
||||
),
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</Modal>
|
||||
);
|
||||
|
||||
@ -35,7 +35,8 @@ import {
|
||||
MetadataType,
|
||||
useManageMetadata,
|
||||
util,
|
||||
} from '../../components/metedata/hook';
|
||||
} from '../../components/metedata/hooks/use-manage-modal';
|
||||
import { IMetaDataReturnJSONSettings } from '../../components/metedata/interface';
|
||||
import { ManageMetadataModal } from '../../components/metedata/manage-modal';
|
||||
import {
|
||||
useHandleKbEmbedding,
|
||||
@ -359,7 +360,13 @@ export function OverlappedPercent() {
|
||||
);
|
||||
}
|
||||
|
||||
export function AutoMetadata() {
|
||||
export function AutoMetadata({
|
||||
type = MetadataType.Setting,
|
||||
otherData,
|
||||
}: {
|
||||
type?: MetadataType;
|
||||
otherData?: Record<string, any>;
|
||||
}) {
|
||||
// get metadata field
|
||||
const form = useFormContext();
|
||||
const {
|
||||
@ -369,6 +376,7 @@ export function AutoMetadata() {
|
||||
tableData,
|
||||
config: metadataConfig,
|
||||
} = useManageMetadata();
|
||||
|
||||
const autoMetadataField: FormFieldConfig = {
|
||||
name: 'parser_config.enable_metadata',
|
||||
label: t('knowledgeConfiguration.autoMetadata'),
|
||||
@ -379,6 +387,7 @@ export function AutoMetadata() {
|
||||
render: (fieldProps: ControllerRenderProps) => (
|
||||
<div className="flex items-center justify-between">
|
||||
<Button
|
||||
type="button"
|
||||
variant="ghost"
|
||||
onClick={() => {
|
||||
const metadata = form.getValues('parser_config.metadata');
|
||||
@ -387,7 +396,8 @@ export function AutoMetadata() {
|
||||
showManageMetadataModal({
|
||||
metadata: tableMetaData,
|
||||
isCanAdd: true,
|
||||
type: MetadataType.Setting,
|
||||
type: type,
|
||||
record: otherData,
|
||||
});
|
||||
}}
|
||||
>
|
||||
@ -403,6 +413,10 @@ export function AutoMetadata() {
|
||||
</div>
|
||||
),
|
||||
};
|
||||
|
||||
const handleSaveMetadata = (data?: IMetaDataReturnJSONSettings) => {
|
||||
form.setValue('parser_config.metadata', data || []);
|
||||
};
|
||||
return (
|
||||
<>
|
||||
<RenderField field={autoMetadataField} />
|
||||
@ -431,8 +445,8 @@ export function AutoMetadata() {
|
||||
isShowDescription={true}
|
||||
isShowValueSwitch={true}
|
||||
isVerticalShowValue={false}
|
||||
success={(data) => {
|
||||
form.setValue('parser_config.metadata', data || []);
|
||||
success={(data?: IMetaDataReturnJSONSettings) => {
|
||||
handleSaveMetadata(data);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
@ -96,7 +96,7 @@ export const formSchema = z
|
||||
)
|
||||
.optional(),
|
||||
enable_metadata: z.boolean().optional(),
|
||||
llm_id: z.string().optional(),
|
||||
llm_id: z.string().min(1, { message: 'Indexing model is required' }),
|
||||
})
|
||||
.optional(),
|
||||
pagerank: z.number(),
|
||||
|
||||
@ -16,7 +16,10 @@ import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-reques
|
||||
import { Pen, Upload } from 'lucide-react';
|
||||
import { useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { MetadataType, useManageMetadata } from '../components/metedata/hook';
|
||||
import {
|
||||
MetadataType,
|
||||
useManageMetadata,
|
||||
} from '../components/metedata/hooks/use-manage-modal';
|
||||
import { ManageMetadataModal } from '../components/metedata/manage-modal';
|
||||
import { DatasetTable } from './dataset-table';
|
||||
import Generate from './generate-button/generate';
|
||||
|
||||
@ -16,7 +16,10 @@ import { formatDate } from '@/utils/date';
|
||||
import { ColumnDef } from '@tanstack/table-core';
|
||||
import { ArrowUpDown, MonitorUp } from 'lucide-react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { MetadataType, util } from '../components/metedata/hook';
|
||||
import {
|
||||
MetadataType,
|
||||
util,
|
||||
} from '../components/metedata/hooks/use-manage-modal';
|
||||
import { ShowManageMetadataModalProps } from '../components/metedata/interface';
|
||||
import { DatasetActionCell } from './dataset-action-cell';
|
||||
import { ParsingStatusCell } from './parsing-status-cell';
|
||||
|
||||
@ -1,8 +1,13 @@
|
||||
import { FilterCollection } from '@/components/list-filter-bar/interface';
|
||||
import {
|
||||
FilterCollection,
|
||||
FilterType,
|
||||
} from '@/components/list-filter-bar/interface';
|
||||
import { useTranslate } from '@/hooks/common-hooks';
|
||||
import { useGetDocumentFilter } from '@/hooks/use-document-request';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const EMPTY_METADATA_FIELD = 'empty_metadata';
|
||||
|
||||
export function useSelectDatasetFilters() {
|
||||
const { t } = useTranslate('knowledgeDetails');
|
||||
const { filter, onOpenChange } = useGetDocumentFilter();
|
||||
@ -17,34 +22,52 @@ export function useSelectDatasetFilters() {
|
||||
}
|
||||
}, [filter.suffix]);
|
||||
const fileStatus = useMemo(() => {
|
||||
let list = [] as FilterType[];
|
||||
if (filter.run_status) {
|
||||
return Object.keys(filter.run_status).map((x) => ({
|
||||
list = Object.keys(filter.run_status).map((x) => ({
|
||||
id: x,
|
||||
label: t(`runningStatus${x}`),
|
||||
count: filter.run_status[x as unknown as number],
|
||||
}));
|
||||
}
|
||||
}, [filter.run_status, t]);
|
||||
if (filter.metadata) {
|
||||
const emptyMetadata = filter.metadata?.empty_metadata;
|
||||
if (emptyMetadata) {
|
||||
list.push({
|
||||
id: EMPTY_METADATA_FIELD,
|
||||
label: t('emptyMetadata'),
|
||||
count: emptyMetadata.true,
|
||||
});
|
||||
}
|
||||
}
|
||||
return list;
|
||||
}, [filter.run_status, filter.metadata, t]);
|
||||
const metaDataList = useMemo(() => {
|
||||
if (filter.metadata) {
|
||||
return Object.keys(filter.metadata).map((x) => ({
|
||||
id: x.toString(),
|
||||
field: x.toString(),
|
||||
label: x.toString(),
|
||||
list: Object.keys(filter.metadata[x]).map((y) => ({
|
||||
id: y.toString(),
|
||||
field: y.toString(),
|
||||
label: y.toString(),
|
||||
value: [y],
|
||||
count: filter.metadata[x][y],
|
||||
})),
|
||||
count: Object.keys(filter.metadata[x]).reduce(
|
||||
(acc, cur) => acc + filter.metadata[x][cur],
|
||||
0,
|
||||
),
|
||||
}));
|
||||
const list = Object.keys(filter.metadata)
|
||||
?.filter((m) => m !== EMPTY_METADATA_FIELD)
|
||||
?.map((x) => {
|
||||
return {
|
||||
id: x.toString(),
|
||||
field: x.toString(),
|
||||
label: x.toString(),
|
||||
list: Object.keys(filter.metadata[x]).map((y) => ({
|
||||
id: y.toString(),
|
||||
field: y.toString(),
|
||||
label: y.toString(),
|
||||
value: [y],
|
||||
count: filter.metadata[x][y],
|
||||
})),
|
||||
count: Object.keys(filter.metadata[x]).reduce(
|
||||
(acc, cur) => acc + filter.metadata[x][cur],
|
||||
0,
|
||||
),
|
||||
};
|
||||
});
|
||||
return list;
|
||||
}
|
||||
}, [filter.metadata]);
|
||||
|
||||
const filters: FilterCollection[] = useMemo(() => {
|
||||
return [
|
||||
{ field: 'type', label: 'File Type', list: fileTypes },
|
||||
|
||||
@ -1,6 +1,11 @@
|
||||
import FileStatusBadge from '@/components/file-status-badge';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Modal } from '@/components/ui/modal/modal';
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipTrigger,
|
||||
} from '@/components/ui/tooltip';
|
||||
import { RunningStatusMap } from '@/constants/knowledge';
|
||||
import { useTranslate } from '@/hooks/common-hooks';
|
||||
import React, { useMemo } from 'react';
|
||||
@ -40,7 +45,14 @@ const InfoItem: React.FC<{
|
||||
return (
|
||||
<div className={`flex flex-col mb-4 ${className}`}>
|
||||
<span className="text-text-secondary text-sm">{label}</span>
|
||||
<span className="text-text-primary mt-1">{value}</span>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<span className="text-text-primary mt-1 truncate max-w-[200px]">
|
||||
{value}
|
||||
</span>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>{value}</TooltipContent>
|
||||
</Tooltip>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@ -70,9 +82,7 @@ const ProcessLogModal: React.FC<ProcessLogModalProps> = ({
|
||||
}) => {
|
||||
const { t } = useTranslate('knowledgeDetails');
|
||||
const blackKeyList = [''];
|
||||
console.log('logInfo', initData);
|
||||
const logInfo = useMemo(() => {
|
||||
console.log('logInfo', initData);
|
||||
return initData;
|
||||
}, [initData]);
|
||||
|
||||
|
||||
@ -17,6 +17,7 @@ import {
|
||||
} from '@/components/ui/form';
|
||||
import { Input } from '@/components/ui/input';
|
||||
import { FormLayout } from '@/constants/form';
|
||||
import { useFetchTenantInfo } from '@/hooks/use-user-setting-request';
|
||||
import { IModalProps } from '@/interfaces/common';
|
||||
import { zodResolver } from '@hookform/resolvers/zod';
|
||||
import { useEffect } from 'react';
|
||||
@ -33,6 +34,7 @@ const FormId = 'dataset-creating-form';
|
||||
|
||||
export function InputForm({ onOk }: IModalProps<any>) {
|
||||
const { t } = useTranslation();
|
||||
const { data: tenantInfo } = useFetchTenantInfo();
|
||||
|
||||
const FormSchema = z
|
||||
.object({
|
||||
@ -80,7 +82,7 @@ export function InputForm({ onOk }: IModalProps<any>) {
|
||||
name: '',
|
||||
parseType: 1,
|
||||
parser_id: '',
|
||||
embd_id: '',
|
||||
embd_id: tenantInfo?.embd_id,
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@ -48,10 +48,10 @@ const {
|
||||
traceRaptor,
|
||||
check_embedding,
|
||||
kbUpdateMetaData,
|
||||
documentUpdateMetaData,
|
||||
} = api;
|
||||
|
||||
const methods = {
|
||||
// 知识库管理
|
||||
createKb: {
|
||||
url: create_kb,
|
||||
method: 'post',
|
||||
@ -220,6 +220,10 @@ const methods = {
|
||||
url: kbUpdateMetaData,
|
||||
method: 'post',
|
||||
},
|
||||
documentUpdateMetaData: {
|
||||
url: documentUpdateMetaData,
|
||||
method: 'post',
|
||||
},
|
||||
// getMetaData: {
|
||||
// url: getMetaData,
|
||||
// method: 'get',
|
||||
@ -263,7 +267,7 @@ export const documentFilter = (kb_id: string) =>
|
||||
export const getMetaDataService = ({ kb_id }: { kb_id: string }) =>
|
||||
request.post(api.getMetaData, { data: { kb_id } });
|
||||
export const updateMetaData = ({ kb_id, data }: { kb_id: string; data: any }) =>
|
||||
request.post(api.updateMetaData, { data: { kb_id, data } });
|
||||
request.post(api.updateMetaData, { data: { kb_id, ...data } });
|
||||
|
||||
export const listDataPipelineLogDocument = (
|
||||
params?: IFetchKnowledgeListRequestParams,
|
||||
|
||||
@ -80,6 +80,7 @@ export default {
|
||||
getMetaData: `${api_host}/document/metadata/summary`,
|
||||
updateMetaData: `${api_host}/document/metadata/update`,
|
||||
kbUpdateMetaData: `${api_host}/kb/update_metadata_setting`,
|
||||
documentUpdateMetaData: `${api_host}/document/update_metadata_setting`,
|
||||
|
||||
// tags
|
||||
listTag: (knowledgeId: string) => `${api_host}/kb/${knowledgeId}/tags`,
|
||||
|
||||
Reference in New Issue
Block a user