diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 1da832e7b..4dcdba809 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -22,13 +22,13 @@ from abc import ABC from agent.tools.base import ToolParamBase, ToolBase, ToolMeta from common.constants import LLMType from api.db.services.document_service import DocumentService -from api.db.services.dialog_service import meta_filter +from common.metadata_utils import apply_meta_data_filter from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from common import settings from common.connection_utils import timeout from rag.app.tag import label_question -from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter +from rag.prompts.generator import cross_languages, kb_prompt class RetrievalParam(ToolParamBase): @@ -131,54 +131,48 @@ class Retrieval(ToolBase, ABC): doc_ids=[] if self._param.meta_data_filter!={}: metas = DocumentService.get_meta_by_kbs(kb_ids) - if self._param.meta_data_filter.get("method") == "auto": + + def _resolve_manual_filter(flt: dict) -> dict: + pat = re.compile(self.variable_ref_patt) + s = flt.get("value", "") + out_parts = [] + last = 0 + + for m in pat.finditer(s): + out_parts.append(s[last:m.start()]) + key = m.group(1) + v = self._canvas.get_variable_value(key) + if v is None: + rep = "" + elif isinstance(v, partial): + buf = [] + for chunk in v(): + buf.append(chunk) + rep = "".join(buf) + elif isinstance(v, str): + rep = v + else: + rep = json.dumps(v, ensure_ascii=False) + + out_parts.append(rep) + last = m.end() + + out_parts.append(s[last:]) + flt["value"] = "".join(out_parts) + return flt + + chat_mdl = None + if self._param.meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT) - filters: dict = await gen_meta_filter(chat_mdl, metas, query) - doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not doc_ids: - doc_ids = None - elif self._param.meta_data_filter.get("method") == "semi_auto": - selected_keys = self._param.meta_data_filter.get("semi_auto", []) - if selected_keys: - filtered_metas = {key: metas[key] for key in selected_keys if key in metas} - if filtered_metas: - chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT) - filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, query) - doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not doc_ids: - doc_ids = None - elif self._param.meta_data_filter.get("method") == "manual": - filters = self._param.meta_data_filter["manual"] - for flt in filters: - pat = re.compile(self.variable_ref_patt) - s = flt["value"] - out_parts = [] - last = 0 - for m in pat.finditer(s): - out_parts.append(s[last:m.start()]) - key = m.group(1) - v = self._canvas.get_variable_value(key) - if v is None: - rep = "" - elif isinstance(v, partial): - buf = [] - for chunk in v(): - buf.append(chunk) - rep = "".join(buf) - elif isinstance(v, str): - rep = v - else: - rep = json.dumps(v, ensure_ascii=False) - - out_parts.append(rep) - last = m.end() - - out_parts.append(s[last:]) - flt["value"] = "".join(out_parts) - doc_ids.extend(meta_filter(metas, filters, self._param.meta_data_filter.get("logic", "and"))) - if filters and not doc_ids: - doc_ids = ["-999"] + doc_ids = await apply_meta_data_filter( + self._param.meta_data_filter, + metas, + query, + chat_mdl, + doc_ids, + _resolve_manual_filter if self._param.meta_data_filter.get("method") == "manual" else None, + ) if self._param.cross_languages: query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages) diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index af6bb6617..09ff864b5 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -21,10 +21,10 @@ import re import xxhash from quart import request -from api.db.services.dialog_service import meta_filter from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle +from common.metadata_utils import apply_meta_data_filter from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \ @@ -32,7 +32,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, server_e from rag.app.qa import beAdoc, rmPrefix from rag.app.tag import label_question from rag.nlp import rag_tokenizer, search -from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction +from rag.prompts.generator import cross_languages, keyword_extraction from common.string_utils import remove_redundant_spaces from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD from common import settings @@ -317,54 +317,21 @@ async def retrieval_test(): local_doc_ids = list(doc_ids) if doc_ids else [] tenant_ids = [] + meta_data_filter = {} + chat_mdl = None if req.get("search_id", ""): search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) meta_data_filter = search_config.get("meta_data_filter", {}) - metas = DocumentService.get_meta_by_kbs(kb_ids) - if meta_data_filter.get("method") == "auto": + if meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) - filters: dict = await gen_meta_filter(chat_mdl, metas, question) - local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not local_doc_ids: - local_doc_ids = None - elif meta_data_filter.get("method") == "semi_auto": - selected_keys = meta_data_filter.get("semi_auto", []) - if selected_keys: - filtered_metas = {key: metas[key] for key in selected_keys if key in metas} - if filtered_metas: - chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) - filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question) - local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not local_doc_ids: - local_doc_ids = None - elif meta_data_filter.get("method") == "manual": - local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) - if meta_data_filter["manual"] and not local_doc_ids: - local_doc_ids = ["-999"] else: - meta_data_filter = req.get("meta_data_filter") - if meta_data_filter: - metas = DocumentService.get_meta_by_kbs(kb_ids) - if meta_data_filter.get("method") == "auto": - chat_mdl = LLMBundle(user_id, LLMType.CHAT) - filters: dict = await gen_meta_filter(chat_mdl, metas, question) - local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not local_doc_ids: - local_doc_ids = None - elif meta_data_filter.get("method") == "semi_auto": - selected_keys = meta_data_filter.get("semi_auto", []) - if selected_keys: - filtered_metas = {key: metas[key] for key in selected_keys if key in metas} - if filtered_metas: - chat_mdl = LLMBundle(user_id, LLMType.CHAT) - filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question) - local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not local_doc_ids: - local_doc_ids = None - elif meta_data_filter.get("method") == "manual": - local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) - if meta_data_filter["manual"] and not local_doc_ids: - local_doc_ids = ["-999"] + meta_data_filter = req.get("meta_data_filter") or {} + if meta_data_filter.get("method") in ["auto", "semi_auto"]: + chat_mdl = LLMBundle(user_id, LLMType.CHAT) + + if meta_data_filter: + metas = DocumentService.get_meta_by_kbs(kb_ids) + local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids) tenants = UserTenantService.query(user_id=user_id) for kb_id in kb_ids: diff --git a/api/apps/document_app.py b/api/apps/document_app.py index 1a987cd39..ba1f6e7db 100644 --- a/api/apps/document_app.py +++ b/api/apps/document_app.py @@ -27,7 +27,7 @@ from api.db import VALID_FILE_TYPES, FileType from api.db.db_models import Task from api.db.services import duplicate_name from api.db.services.document_service import DocumentService, doc_upload_and_parse -from api.db.services.dialog_service import meta_filter, convert_conditions +from common.metadata_utils import meta_filter, convert_conditions from api.db.services.file2document_service import File2DocumentService from api.db.services.file_service import FileService from api.db.services.knowledgebase_service import KnowledgebaseService diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index 9665754eb..7a11688dd 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -20,9 +20,9 @@ from quart import jsonify from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle +from common.metadata_utils import meta_filter, convert_conditions from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request from rag.app.tag import label_question -from api.db.services.dialog_service import meta_filter, convert_conditions from common.constants import RetCode, LLMType from common import settings diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index 19fabdcff..a5c120d31 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -35,7 +35,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of -from api.db.services.dialog_service import meta_filter, convert_conditions +from common.metadata_utils import meta_filter, convert_conditions from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \ get_request_json from rag.app.qa import beAdoc, rmPrefix diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index d4db3cb56..34ae6b1ca 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -28,10 +28,11 @@ from api.db.services.canvas_service import completion as agent_completion from api.db.services.conversation_service import ConversationService from api.db.services.conversation_service import async_iframe_completion as iframe_completion from api.db.services.conversation_service import async_completion as rag_completion -from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap, meta_filter +from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle +from common.metadata_utils import apply_meta_data_filter from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService from common.misc_utils import get_uuid @@ -39,7 +40,7 @@ from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_ get_result, get_request_json, server_error_response, token_required, validate_request from rag.app.tag import label_question from rag.prompts.template import load_prompt -from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format +from rag.prompts.generator import cross_languages, keyword_extraction, chunks_format from common.constants import RetCode, LLMType, StatusEnum from common import settings @@ -974,54 +975,21 @@ async def retrieval_test_embedded(): tenant_ids = [] _question = question + meta_data_filter = {} + chat_mdl = None if req.get("search_id", ""): search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) meta_data_filter = search_config.get("meta_data_filter", {}) - metas = DocumentService.get_meta_by_kbs(kb_ids) - if meta_data_filter.get("method") == "auto": + if meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) - filters: dict = await gen_meta_filter(chat_mdl, metas, _question) - local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not local_doc_ids: - local_doc_ids = None - elif meta_data_filter.get("method") == "semi_auto": - selected_keys = meta_data_filter.get("semi_auto", []) - if selected_keys: - filtered_metas = {key: metas[key] for key in selected_keys if key in metas} - if filtered_metas: - chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) - filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, _question) - local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not local_doc_ids: - local_doc_ids = None - elif meta_data_filter.get("method") == "manual": - local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) - if meta_data_filter["manual"] and not local_doc_ids: - local_doc_ids = ["-999"] else: - meta_data_filter = req.get("meta_data_filter") - if meta_data_filter: - metas = DocumentService.get_meta_by_kbs(kb_ids) - if meta_data_filter.get("method") == "auto": - chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) - filters: dict = await gen_meta_filter(chat_mdl, metas, question) - local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not local_doc_ids: - local_doc_ids = None - elif meta_data_filter.get("method") == "semi_auto": - selected_keys = meta_data_filter.get("semi_auto", []) - if selected_keys: - filtered_metas = {key: metas[key] for key in selected_keys if key in metas} - if filtered_metas: - chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) - filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question) - local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not local_doc_ids: - local_doc_ids = None - elif meta_data_filter.get("method") == "manual": - local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) - if meta_data_filter["manual"] and not local_doc_ids: - local_doc_ids = ["-999"] + meta_data_filter = req.get("meta_data_filter") or {} + if meta_data_filter.get("method") in ["auto", "semi_auto"]: + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) + + if meta_data_filter: + metas = DocumentService.get_meta_by_kbs(kb_ids) + local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, _question, chat_mdl, local_doc_ids) tenants = UserTenantService.query(user_id=tenant_id) for kb_id in kb_ids: diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 0fded53f6..e956b0a5b 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -32,6 +32,7 @@ from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.langfuse_service import TenantLangfuseService from api.db.services.llm_service import LLMBundle +from common.metadata_utils import apply_meta_data_filter from api.db.services.tenant_llm_service import TenantLLMService from common.time_utils import current_timestamp, datetime_format from graphrag.general.mind_map_extractor import MindMapExtractor @@ -39,7 +40,7 @@ from rag.app.resume import forbidden_select_fields4resume from rag.app.tag import label_question from rag.nlp.search import index_name from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \ - gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY + PROMPT_JINJA_ENV, ASK_SUMMARY from common.token_utils import num_tokens_from_string from rag.utils.tavily_conn import Tavily from common.string_utils import remove_redundant_spaces @@ -277,77 +278,6 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): return answer, idx -def convert_conditions(metadata_condition): - if metadata_condition is None: - metadata_condition = {} - op_mapping = { - "is": "=", - "not is": "≠" - } - return [ - { - "op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]), - "key": cond["name"], - "value": cond["value"] - } - for cond in metadata_condition.get("conditions", []) - ] - - -def meta_filter(metas: dict, filters: list[dict], logic: str = "and"): - doc_ids = set([]) - - def filter_out(v2docs, operator, value): - ids = [] - for input, docids in v2docs.items(): - if operator in ["=", "≠", ">", "<", "≥", "≤"]: - try: - input = float(input) - value = float(value) - except Exception: - input = str(input) - value = str(value) - - for conds in [ - (operator == "contains", str(value).lower() in str(input).lower()), - (operator == "not contains", str(value).lower() not in str(input).lower()), - (operator == "in", str(input).lower() in str(value).lower()), - (operator == "not in", str(input).lower() not in str(value).lower()), - (operator == "start with", str(input).lower().startswith(str(value).lower())), - (operator == "end with", str(input).lower().endswith(str(value).lower())), - (operator == "empty", not input), - (operator == "not empty", input), - (operator == "=", input == value), - (operator == "≠", input != value), - (operator == ">", input > value), - (operator == "<", input < value), - (operator == "≥", input >= value), - (operator == "≤", input <= value), - ]: - try: - if all(conds): - ids.extend(docids) - break - except Exception: - pass - return ids - - for k, v2docs in metas.items(): - for f in filters: - if k != f["key"]: - continue - ids = filter_out(v2docs, f["op"], f["value"]) - if not doc_ids: - doc_ids = set(ids) - else: - if logic == "and": - doc_ids = doc_ids & set(ids) - else: - doc_ids = doc_ids | set(ids) - if not doc_ids: - return [] - return list(doc_ids) - async def async_chat(dialog, messages, stream=True, **kwargs): assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"): @@ -420,25 +350,13 @@ async def async_chat(dialog, messages, stream=True, **kwargs): if dialog.meta_data_filter: metas = DocumentService.get_meta_by_kbs(dialog.kb_ids) - if dialog.meta_data_filter.get("method") == "auto": - filters: dict = await gen_meta_filter(chat_mdl, metas, questions[-1]) - attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not attachments: - attachments = None - elif dialog.meta_data_filter.get("method") == "semi_auto": - selected_keys = dialog.meta_data_filter.get("semi_auto", []) - if selected_keys: - filtered_metas = {key: metas[key] for key in selected_keys if key in metas} - if filtered_metas: - filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, questions[-1]) - attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not attachments: - attachments = None - elif dialog.meta_data_filter.get("method") == "manual": - conds = dialog.meta_data_filter["manual"] - attachments.extend(meta_filter(metas, conds, dialog.meta_data_filter.get("logic", "and"))) - if conds and not attachments: - attachments = ["-999"] + attachments = await apply_meta_data_filter( + dialog.meta_data_filter, + metas, + questions[-1], + chat_mdl, + attachments, + ) if prompt_config.get("keyword", False): questions[-1] += await keyword_extraction(chat_mdl, questions[-1]) @@ -838,24 +756,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf if meta_data_filter: metas = DocumentService.get_meta_by_kbs(kb_ids) - if meta_data_filter.get("method") == "auto": - filters: dict = await gen_meta_filter(chat_mdl, metas, question) - doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not doc_ids: - doc_ids = None - elif meta_data_filter.get("method") == "semi_auto": - selected_keys = meta_data_filter.get("semi_auto", []) - if selected_keys: - filtered_metas = {key: metas[key] for key in selected_keys if key in metas} - if filtered_metas: - filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question) - doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not doc_ids: - doc_ids = None - elif meta_data_filter.get("method") == "manual": - doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) - if meta_data_filter["manual"] and not doc_ids: - doc_ids = ["-999"] + doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids) kbinfos = retriever.retrieval( question=question, @@ -922,24 +823,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}): if meta_data_filter: metas = DocumentService.get_meta_by_kbs(kb_ids) - if meta_data_filter.get("method") == "auto": - filters: dict = await gen_meta_filter(chat_mdl, metas, question) - doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not doc_ids: - doc_ids = None - elif meta_data_filter.get("method") == "semi_auto": - selected_keys = meta_data_filter.get("semi_auto", []) - if selected_keys: - filtered_metas = {key: metas[key] for key in selected_keys if key in metas} - if filtered_metas: - filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question) - doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) - if not doc_ids: - doc_ids = None - elif meta_data_filter.get("method") == "manual": - doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and"))) - if meta_data_filter["manual"] and not doc_ids: - doc_ids = ["-999"] + doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids) ranks = settings.retriever.retrieval( question=question, diff --git a/common/metadata_utils.py b/common/metadata_utils.py new file mode 100644 index 000000000..957ed3ece --- /dev/null +++ b/common/metadata_utils.py @@ -0,0 +1,142 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Any, Callable + +from rag.prompts.generator import gen_meta_filter + + +def convert_conditions(metadata_condition): + if metadata_condition is None: + metadata_condition = {} + op_mapping = { + "is": "=", + "not is": "≠" + } + return [ + { + "op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]), + "key": cond["name"], + "value": cond["value"] + } + for cond in metadata_condition.get("conditions", []) + ] + + +def meta_filter(metas: dict, filters: list[dict], logic: str = "and"): + doc_ids = set([]) + + def filter_out(v2docs, operator, value): + ids = [] + for input, docids in v2docs.items(): + if operator in ["=", "≠", ">", "<", "≥", "≤"]: + try: + input = float(input) + value = float(value) + except Exception: + input = str(input) + value = str(value) + + for conds in [ + (operator == "contains", str(value).lower() in str(input).lower()), + (operator == "not contains", str(value).lower() not in str(input).lower()), + (operator == "in", str(input).lower() in str(value).lower()), + (operator == "not in", str(input).lower() not in str(value).lower()), + (operator == "start with", str(input).lower().startswith(str(value).lower())), + (operator == "end with", str(input).lower().endswith(str(value).lower())), + (operator == "empty", not input), + (operator == "not empty", input), + (operator == "=", input == value), + (operator == "≠", input != value), + (operator == ">", input > value), + (operator == "<", input < value), + (operator == "≥", input >= value), + (operator == "≤", input <= value), + ]: + try: + if all(conds): + ids.extend(docids) + break + except Exception: + pass + return ids + + for k, v2docs in metas.items(): + for f in filters: + if k != f["key"]: + continue + ids = filter_out(v2docs, f["op"], f["value"]) + if not doc_ids: + doc_ids = set(ids) + else: + if logic == "and": + doc_ids = doc_ids & set(ids) + else: + doc_ids = doc_ids | set(ids) + if not doc_ids: + return [] + return list(doc_ids) + + +async def apply_meta_data_filter( + meta_data_filter: dict | None, + metas: dict, + question: str, + chat_mdl: Any = None, + base_doc_ids: list[str] | None = None, + manual_value_resolver: Callable[[dict], dict] | None = None, +) -> list[str] | None: + """ + Apply metadata filtering rules and return the filtered doc_ids. + + meta_data_filter supports three modes: + - auto: generate filter conditions via LLM (gen_meta_filter) + - semi_auto: generate conditions using selected metadata keys only + - manual: directly filter based on provided conditions + + Returns: + list of doc_ids, ["-999"] when manual filters yield no result, or None + when auto/semi_auto filters return empty. + """ + doc_ids = list(base_doc_ids) if base_doc_ids else [] + + if not meta_data_filter: + return doc_ids + + method = meta_data_filter.get("method") + + if method == "auto": + filters: dict = await gen_meta_filter(chat_mdl, metas, question) + doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not doc_ids: + return None + elif method == "semi_auto": + selected_keys = meta_data_filter.get("semi_auto", []) + if selected_keys: + filtered_metas = {key: metas[key] for key in selected_keys if key in metas} + if filtered_metas: + filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question) + doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) + if not doc_ids: + return None + elif method == "manual": + filters = meta_data_filter.get("manual", []) + if manual_value_resolver: + filters = [manual_value_resolver(flt) for flt in filters] + doc_ids.extend(meta_filter(metas, filters, meta_data_filter.get("logic", "and"))) + if filters and not doc_ids: + doc_ids = ["-999"] + + return doc_ids diff --git a/web/src/hooks/use-knowledge-request.ts b/web/src/hooks/use-knowledge-request.ts index daaa28715..7a89ef525 100644 --- a/web/src/hooks/use-knowledge-request.ts +++ b/web/src/hooks/use-knowledge-request.ts @@ -75,6 +75,7 @@ export const useTestRetrieval = () => { page, size: pageSize, doc_ids: filterValue.doc_ids, + highlight: true, }; }, [filterValue, knowledgeBaseId, page, pageSize, values]); diff --git a/web/src/pages/next-search/search-view.tsx b/web/src/pages/next-search/search-view.tsx index 5720896fd..890e28072 100644 --- a/web/src/pages/next-search/search-view.tsx +++ b/web/src/pages/next-search/search-view.tsx @@ -209,7 +209,11 @@ export default function SearchingView({