mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-19 20:16:49 +08:00
Refa: refactor metadata filter (#11907)
### What problem does this PR solve? Refactor metadata filter. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Refactoring --------- Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
@ -22,13 +22,13 @@ from abc import ABC
|
||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||
from common.constants import LLMType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.dialog_service import meta_filter
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common import settings
|
||||
from common.connection_utils import timeout
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
|
||||
from rag.prompts.generator import cross_languages, kb_prompt
|
||||
|
||||
|
||||
class RetrievalParam(ToolParamBase):
|
||||
@ -131,27 +131,10 @@ class Retrieval(ToolBase, ABC):
|
||||
doc_ids=[]
|
||||
if self._param.meta_data_filter!={}:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if self._param.meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, query)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif self._param.meta_data_filter.get("method") == "semi_auto":
|
||||
selected_keys = self._param.meta_data_filter.get("semi_auto", [])
|
||||
if selected_keys:
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, query)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif self._param.meta_data_filter.get("method") == "manual":
|
||||
filters = self._param.meta_data_filter["manual"]
|
||||
for flt in filters:
|
||||
|
||||
def _resolve_manual_filter(flt: dict) -> dict:
|
||||
pat = re.compile(self.variable_ref_patt)
|
||||
s = flt["value"]
|
||||
s = flt.get("value", "")
|
||||
out_parts = []
|
||||
last = 0
|
||||
|
||||
@ -176,9 +159,20 @@ class Retrieval(ToolBase, ABC):
|
||||
|
||||
out_parts.append(s[last:])
|
||||
flt["value"] = "".join(out_parts)
|
||||
doc_ids.extend(meta_filter(metas, filters, self._param.meta_data_filter.get("logic", "and")))
|
||||
if filters and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
return flt
|
||||
|
||||
chat_mdl = None
|
||||
if self._param.meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
|
||||
|
||||
doc_ids = await apply_meta_data_filter(
|
||||
self._param.meta_data_filter,
|
||||
metas,
|
||||
query,
|
||||
chat_mdl,
|
||||
doc_ids,
|
||||
_resolve_manual_filter if self._param.meta_data_filter.get("method") == "manual" else None,
|
||||
)
|
||||
|
||||
if self._param.cross_languages:
|
||||
query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
|
||||
|
||||
@ -21,10 +21,10 @@ import re
|
||||
import xxhash
|
||||
from quart import request
|
||||
|
||||
from api.db.services.dialog_service import meta_filter
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||
@ -32,7 +32,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, server_e
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
|
||||
from rag.prompts.generator import cross_languages, keyword_extraction
|
||||
from common.string_utils import remove_redundant_spaces
|
||||
from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
|
||||
from common import settings
|
||||
@ -317,54 +317,21 @@ async def retrieval_test():
|
||||
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||
tenant_ids = []
|
||||
|
||||
meta_data_filter = {}
|
||||
chat_mdl = None
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
elif meta_data_filter.get("method") == "semi_auto":
|
||||
selected_keys = meta_data_filter.get("semi_auto", [])
|
||||
if selected_keys:
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not local_doc_ids:
|
||||
local_doc_ids = ["-999"]
|
||||
else:
|
||||
meta_data_filter = req.get("meta_data_filter")
|
||||
meta_data_filter = req.get("meta_data_filter") or {}
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
|
||||
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
elif meta_data_filter.get("method") == "semi_auto":
|
||||
selected_keys = meta_data_filter.get("semi_auto", [])
|
||||
if selected_keys:
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not local_doc_ids:
|
||||
local_doc_ids = ["-999"]
|
||||
local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids)
|
||||
|
||||
tenants = UserTenantService.query(user_id=user_id)
|
||||
for kb_id in kb_ids:
|
||||
|
||||
@ -27,7 +27,7 @@ from api.db import VALID_FILE_TYPES, FileType
|
||||
from api.db.db_models import Task
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
from common.metadata_utils import meta_filter, convert_conditions
|
||||
from api.db.services.file2document_service import File2DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
|
||||
@ -20,9 +20,9 @@ from quart import jsonify
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import meta_filter, convert_conditions
|
||||
from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request
|
||||
from rag.app.tag import label_question
|
||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
from common.constants import RetCode, LLMType
|
||||
from common import settings
|
||||
|
||||
|
||||
@ -35,7 +35,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of
|
||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
||||
from common.metadata_utils import meta_filter, convert_conditions
|
||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
|
||||
get_request_json
|
||||
from rag.app.qa import beAdoc, rmPrefix
|
||||
|
||||
@ -28,10 +28,11 @@ from api.db.services.canvas_service import completion as agent_completion
|
||||
from api.db.services.conversation_service import ConversationService
|
||||
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
|
||||
from api.db.services.conversation_service import async_completion as rag_completion
|
||||
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap, meta_filter
|
||||
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
@ -39,7 +40,7 @@ from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_
|
||||
get_result, get_request_json, server_error_response, token_required, validate_request
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.template import load_prompt
|
||||
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
|
||||
from rag.prompts.generator import cross_languages, keyword_extraction, chunks_format
|
||||
from common.constants import RetCode, LLMType, StatusEnum
|
||||
from common import settings
|
||||
|
||||
@ -974,54 +975,21 @@ async def retrieval_test_embedded():
|
||||
tenant_ids = []
|
||||
_question = question
|
||||
|
||||
meta_data_filter = {}
|
||||
chat_mdl = None
|
||||
if req.get("search_id", ""):
|
||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, _question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
elif meta_data_filter.get("method") == "semi_auto":
|
||||
selected_keys = meta_data_filter.get("semi_auto", [])
|
||||
if selected_keys:
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, _question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not local_doc_ids:
|
||||
local_doc_ids = ["-999"]
|
||||
else:
|
||||
meta_data_filter = req.get("meta_data_filter")
|
||||
meta_data_filter = req.get("meta_data_filter") or {}
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
||||
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
elif meta_data_filter.get("method") == "semi_auto":
|
||||
selected_keys = meta_data_filter.get("semi_auto", [])
|
||||
if selected_keys:
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not local_doc_ids:
|
||||
local_doc_ids = ["-999"]
|
||||
local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, _question, chat_mdl, local_doc_ids)
|
||||
|
||||
tenants = UserTenantService.query(user_id=tenant_id)
|
||||
for kb_id in kb_ids:
|
||||
|
||||
@ -32,6 +32,7 @@ from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.langfuse_service import TenantLangfuseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from common.time_utils import current_timestamp, datetime_format
|
||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
@ -39,7 +40,7 @@ from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp.search import index_name
|
||||
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
|
||||
gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
|
||||
PROMPT_JINJA_ENV, ASK_SUMMARY
|
||||
from common.token_utils import num_tokens_from_string
|
||||
from rag.utils.tavily_conn import Tavily
|
||||
from common.string_utils import remove_redundant_spaces
|
||||
@ -277,77 +278,6 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
|
||||
return answer, idx
|
||||
|
||||
|
||||
def convert_conditions(metadata_condition):
|
||||
if metadata_condition is None:
|
||||
metadata_condition = {}
|
||||
op_mapping = {
|
||||
"is": "=",
|
||||
"not is": "≠"
|
||||
}
|
||||
return [
|
||||
{
|
||||
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
|
||||
"key": cond["name"],
|
||||
"value": cond["value"]
|
||||
}
|
||||
for cond in metadata_condition.get("conditions", [])
|
||||
]
|
||||
|
||||
|
||||
def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
|
||||
doc_ids = set([])
|
||||
|
||||
def filter_out(v2docs, operator, value):
|
||||
ids = []
|
||||
for input, docids in v2docs.items():
|
||||
if operator in ["=", "≠", ">", "<", "≥", "≤"]:
|
||||
try:
|
||||
input = float(input)
|
||||
value = float(value)
|
||||
except Exception:
|
||||
input = str(input)
|
||||
value = str(value)
|
||||
|
||||
for conds in [
|
||||
(operator == "contains", str(value).lower() in str(input).lower()),
|
||||
(operator == "not contains", str(value).lower() not in str(input).lower()),
|
||||
(operator == "in", str(input).lower() in str(value).lower()),
|
||||
(operator == "not in", str(input).lower() not in str(value).lower()),
|
||||
(operator == "start with", str(input).lower().startswith(str(value).lower())),
|
||||
(operator == "end with", str(input).lower().endswith(str(value).lower())),
|
||||
(operator == "empty", not input),
|
||||
(operator == "not empty", input),
|
||||
(operator == "=", input == value),
|
||||
(operator == "≠", input != value),
|
||||
(operator == ">", input > value),
|
||||
(operator == "<", input < value),
|
||||
(operator == "≥", input >= value),
|
||||
(operator == "≤", input <= value),
|
||||
]:
|
||||
try:
|
||||
if all(conds):
|
||||
ids.extend(docids)
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
return ids
|
||||
|
||||
for k, v2docs in metas.items():
|
||||
for f in filters:
|
||||
if k != f["key"]:
|
||||
continue
|
||||
ids = filter_out(v2docs, f["op"], f["value"])
|
||||
if not doc_ids:
|
||||
doc_ids = set(ids)
|
||||
else:
|
||||
if logic == "and":
|
||||
doc_ids = doc_ids & set(ids)
|
||||
else:
|
||||
doc_ids = doc_ids | set(ids)
|
||||
if not doc_ids:
|
||||
return []
|
||||
return list(doc_ids)
|
||||
|
||||
async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
||||
@ -420,25 +350,13 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
|
||||
if dialog.meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
|
||||
if dialog.meta_data_filter.get("method") == "auto":
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, questions[-1])
|
||||
attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not attachments:
|
||||
attachments = None
|
||||
elif dialog.meta_data_filter.get("method") == "semi_auto":
|
||||
selected_keys = dialog.meta_data_filter.get("semi_auto", [])
|
||||
if selected_keys:
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, questions[-1])
|
||||
attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not attachments:
|
||||
attachments = None
|
||||
elif dialog.meta_data_filter.get("method") == "manual":
|
||||
conds = dialog.meta_data_filter["manual"]
|
||||
attachments.extend(meta_filter(metas, conds, dialog.meta_data_filter.get("logic", "and")))
|
||||
if conds and not attachments:
|
||||
attachments = ["-999"]
|
||||
attachments = await apply_meta_data_filter(
|
||||
dialog.meta_data_filter,
|
||||
metas,
|
||||
questions[-1],
|
||||
chat_mdl,
|
||||
attachments,
|
||||
)
|
||||
|
||||
if prompt_config.get("keyword", False):
|
||||
questions[-1] += await keyword_extraction(chat_mdl, questions[-1])
|
||||
@ -838,24 +756,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
|
||||
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "semi_auto":
|
||||
selected_keys = meta_data_filter.get("semi_auto", [])
|
||||
if selected_keys:
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
|
||||
|
||||
kbinfos = retriever.retrieval(
|
||||
question=question,
|
||||
@ -922,24 +823,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "semi_auto":
|
||||
selected_keys = meta_data_filter.get("semi_auto", [])
|
||||
if selected_keys:
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
elif meta_data_filter.get("method") == "manual":
|
||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
||||
if meta_data_filter["manual"] and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
|
||||
|
||||
ranks = settings.retriever.retrieval(
|
||||
question=question,
|
||||
|
||||
142
common/metadata_utils.py
Normal file
142
common/metadata_utils.py
Normal file
@ -0,0 +1,142 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import Any, Callable
|
||||
|
||||
from rag.prompts.generator import gen_meta_filter
|
||||
|
||||
|
||||
def convert_conditions(metadata_condition):
|
||||
if metadata_condition is None:
|
||||
metadata_condition = {}
|
||||
op_mapping = {
|
||||
"is": "=",
|
||||
"not is": "≠"
|
||||
}
|
||||
return [
|
||||
{
|
||||
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
|
||||
"key": cond["name"],
|
||||
"value": cond["value"]
|
||||
}
|
||||
for cond in metadata_condition.get("conditions", [])
|
||||
]
|
||||
|
||||
|
||||
def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
|
||||
doc_ids = set([])
|
||||
|
||||
def filter_out(v2docs, operator, value):
|
||||
ids = []
|
||||
for input, docids in v2docs.items():
|
||||
if operator in ["=", "≠", ">", "<", "≥", "≤"]:
|
||||
try:
|
||||
input = float(input)
|
||||
value = float(value)
|
||||
except Exception:
|
||||
input = str(input)
|
||||
value = str(value)
|
||||
|
||||
for conds in [
|
||||
(operator == "contains", str(value).lower() in str(input).lower()),
|
||||
(operator == "not contains", str(value).lower() not in str(input).lower()),
|
||||
(operator == "in", str(input).lower() in str(value).lower()),
|
||||
(operator == "not in", str(input).lower() not in str(value).lower()),
|
||||
(operator == "start with", str(input).lower().startswith(str(value).lower())),
|
||||
(operator == "end with", str(input).lower().endswith(str(value).lower())),
|
||||
(operator == "empty", not input),
|
||||
(operator == "not empty", input),
|
||||
(operator == "=", input == value),
|
||||
(operator == "≠", input != value),
|
||||
(operator == ">", input > value),
|
||||
(operator == "<", input < value),
|
||||
(operator == "≥", input >= value),
|
||||
(operator == "≤", input <= value),
|
||||
]:
|
||||
try:
|
||||
if all(conds):
|
||||
ids.extend(docids)
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
return ids
|
||||
|
||||
for k, v2docs in metas.items():
|
||||
for f in filters:
|
||||
if k != f["key"]:
|
||||
continue
|
||||
ids = filter_out(v2docs, f["op"], f["value"])
|
||||
if not doc_ids:
|
||||
doc_ids = set(ids)
|
||||
else:
|
||||
if logic == "and":
|
||||
doc_ids = doc_ids & set(ids)
|
||||
else:
|
||||
doc_ids = doc_ids | set(ids)
|
||||
if not doc_ids:
|
||||
return []
|
||||
return list(doc_ids)
|
||||
|
||||
|
||||
async def apply_meta_data_filter(
|
||||
meta_data_filter: dict | None,
|
||||
metas: dict,
|
||||
question: str,
|
||||
chat_mdl: Any = None,
|
||||
base_doc_ids: list[str] | None = None,
|
||||
manual_value_resolver: Callable[[dict], dict] | None = None,
|
||||
) -> list[str] | None:
|
||||
"""
|
||||
Apply metadata filtering rules and return the filtered doc_ids.
|
||||
|
||||
meta_data_filter supports three modes:
|
||||
- auto: generate filter conditions via LLM (gen_meta_filter)
|
||||
- semi_auto: generate conditions using selected metadata keys only
|
||||
- manual: directly filter based on provided conditions
|
||||
|
||||
Returns:
|
||||
list of doc_ids, ["-999"] when manual filters yield no result, or None
|
||||
when auto/semi_auto filters return empty.
|
||||
"""
|
||||
doc_ids = list(base_doc_ids) if base_doc_ids else []
|
||||
|
||||
if not meta_data_filter:
|
||||
return doc_ids
|
||||
|
||||
method = meta_data_filter.get("method")
|
||||
|
||||
if method == "auto":
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
return None
|
||||
elif method == "semi_auto":
|
||||
selected_keys = meta_data_filter.get("semi_auto", [])
|
||||
if selected_keys:
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
return None
|
||||
elif method == "manual":
|
||||
filters = meta_data_filter.get("manual", [])
|
||||
if manual_value_resolver:
|
||||
filters = [manual_value_resolver(flt) for flt in filters]
|
||||
doc_ids.extend(meta_filter(metas, filters, meta_data_filter.get("logic", "and")))
|
||||
if filters and not doc_ids:
|
||||
doc_ids = ["-999"]
|
||||
|
||||
return doc_ids
|
||||
@ -75,6 +75,7 @@ export const useTestRetrieval = () => {
|
||||
page,
|
||||
size: pageSize,
|
||||
doc_ids: filterValue.doc_ids,
|
||||
highlight: true,
|
||||
};
|
||||
}, [filterValue, knowledgeBaseId, page, pageSize, values]);
|
||||
|
||||
|
||||
@ -209,7 +209,11 @@ export default function SearchingView({
|
||||
<div
|
||||
dangerouslySetInnerHTML={{
|
||||
__html: DOMPurify.sanitize(
|
||||
`${chunk.highlight}...`,
|
||||
`${
|
||||
chunk.highlight ??
|
||||
chunk.content_with_weight ??
|
||||
''
|
||||
}...`,
|
||||
),
|
||||
}}
|
||||
className="text-sm text-text-primary mb-1"
|
||||
|
||||
Reference in New Issue
Block a user