mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-19 20:16:49 +08:00
Refa: refactor metadata filter (#11907)
### What problem does this PR solve?

Refactor metadata filter.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Refactoring

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
@ -22,13 +22,13 @@ from abc import ABC
|
|||||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||||
from common.constants import LLMType
|
from common.constants import LLMType
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.dialog_service import meta_filter
|
from common.metadata_utils import apply_meta_data_filter
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from common import settings
|
from common import settings
|
||||||
from common.connection_utils import timeout
|
from common.connection_utils import timeout
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.prompts.generator import cross_languages, kb_prompt, gen_meta_filter
|
from rag.prompts.generator import cross_languages, kb_prompt
|
||||||
|
|
||||||
|
|
||||||
class RetrievalParam(ToolParamBase):
|
class RetrievalParam(ToolParamBase):
|
||||||
@ -131,27 +131,10 @@ class Retrieval(ToolBase, ABC):
|
|||||||
doc_ids=[]
|
doc_ids=[]
|
||||||
if self._param.meta_data_filter!={}:
|
if self._param.meta_data_filter!={}:
|
||||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||||
if self._param.meta_data_filter.get("method") == "auto":
|
|
||||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
|
def _resolve_manual_filter(flt: dict) -> dict:
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, metas, query)
|
|
||||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not doc_ids:
|
|
||||||
doc_ids = None
|
|
||||||
elif self._param.meta_data_filter.get("method") == "semi_auto":
|
|
||||||
selected_keys = self._param.meta_data_filter.get("semi_auto", [])
|
|
||||||
if selected_keys:
|
|
||||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
|
||||||
if filtered_metas:
|
|
||||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
|
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, query)
|
|
||||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not doc_ids:
|
|
||||||
doc_ids = None
|
|
||||||
elif self._param.meta_data_filter.get("method") == "manual":
|
|
||||||
filters = self._param.meta_data_filter["manual"]
|
|
||||||
for flt in filters:
|
|
||||||
pat = re.compile(self.variable_ref_patt)
|
pat = re.compile(self.variable_ref_patt)
|
||||||
s = flt["value"]
|
s = flt.get("value", "")
|
||||||
out_parts = []
|
out_parts = []
|
||||||
last = 0
|
last = 0
|
||||||
|
|
||||||
@ -176,9 +159,20 @@ class Retrieval(ToolBase, ABC):
|
|||||||
|
|
||||||
out_parts.append(s[last:])
|
out_parts.append(s[last:])
|
||||||
flt["value"] = "".join(out_parts)
|
flt["value"] = "".join(out_parts)
|
||||||
doc_ids.extend(meta_filter(metas, filters, self._param.meta_data_filter.get("logic", "and")))
|
return flt
|
||||||
if filters and not doc_ids:
|
|
||||||
doc_ids = ["-999"]
|
chat_mdl = None
|
||||||
|
if self._param.meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||||
|
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
|
||||||
|
|
||||||
|
doc_ids = await apply_meta_data_filter(
|
||||||
|
self._param.meta_data_filter,
|
||||||
|
metas,
|
||||||
|
query,
|
||||||
|
chat_mdl,
|
||||||
|
doc_ids,
|
||||||
|
_resolve_manual_filter if self._param.meta_data_filter.get("method") == "manual" else None,
|
||||||
|
)
|
||||||
|
|
||||||
if self._param.cross_languages:
|
if self._param.cross_languages:
|
||||||
query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
|
query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
|
||||||
|
|||||||
@ -21,10 +21,10 @@ import re
|
|||||||
import xxhash
|
import xxhash
|
||||||
from quart import request
|
from quart import request
|
||||||
|
|
||||||
from api.db.services.dialog_service import meta_filter
|
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
|
from common.metadata_utils import apply_meta_data_filter
|
||||||
from api.db.services.search_service import SearchService
|
from api.db.services.search_service import SearchService
|
||||||
from api.db.services.user_service import UserTenantService
|
from api.db.services.user_service import UserTenantService
|
||||||
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request, \
|
||||||
@ -32,7 +32,7 @@ from api.utils.api_utils import get_data_error_result, get_json_result, server_e
|
|||||||
from rag.app.qa import beAdoc, rmPrefix
|
from rag.app.qa import beAdoc, rmPrefix
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.nlp import rag_tokenizer, search
|
from rag.nlp import rag_tokenizer, search
|
||||||
from rag.prompts.generator import gen_meta_filter, cross_languages, keyword_extraction
|
from rag.prompts.generator import cross_languages, keyword_extraction
|
||||||
from common.string_utils import remove_redundant_spaces
|
from common.string_utils import remove_redundant_spaces
|
||||||
from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
|
from common.constants import RetCode, LLMType, ParserType, PAGERANK_FLD
|
||||||
from common import settings
|
from common import settings
|
||||||
@ -317,54 +317,21 @@ async def retrieval_test():
|
|||||||
local_doc_ids = list(doc_ids) if doc_ids else []
|
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||||
tenant_ids = []
|
tenant_ids = []
|
||||||
|
|
||||||
|
meta_data_filter = {}
|
||||||
|
chat_mdl = None
|
||||||
if req.get("search_id", ""):
|
if req.get("search_id", ""):
|
||||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||||
if meta_data_filter.get("method") == "auto":
|
|
||||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
|
||||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not local_doc_ids:
|
|
||||||
local_doc_ids = None
|
|
||||||
elif meta_data_filter.get("method") == "semi_auto":
|
|
||||||
selected_keys = meta_data_filter.get("semi_auto", [])
|
|
||||||
if selected_keys:
|
|
||||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
|
||||||
if filtered_metas:
|
|
||||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
|
||||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not local_doc_ids:
|
|
||||||
local_doc_ids = None
|
|
||||||
elif meta_data_filter.get("method") == "manual":
|
|
||||||
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
|
||||||
if meta_data_filter["manual"] and not local_doc_ids:
|
|
||||||
local_doc_ids = ["-999"]
|
|
||||||
else:
|
else:
|
||||||
meta_data_filter = req.get("meta_data_filter")
|
meta_data_filter = req.get("meta_data_filter") or {}
|
||||||
|
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||||
|
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
|
||||||
|
|
||||||
if meta_data_filter:
|
if meta_data_filter:
|
||||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||||
if meta_data_filter.get("method") == "auto":
|
local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, local_doc_ids)
|
||||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
|
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
|
||||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not local_doc_ids:
|
|
||||||
local_doc_ids = None
|
|
||||||
elif meta_data_filter.get("method") == "semi_auto":
|
|
||||||
selected_keys = meta_data_filter.get("semi_auto", [])
|
|
||||||
if selected_keys:
|
|
||||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
|
||||||
if filtered_metas:
|
|
||||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
|
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
|
||||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not local_doc_ids:
|
|
||||||
local_doc_ids = None
|
|
||||||
elif meta_data_filter.get("method") == "manual":
|
|
||||||
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
|
||||||
if meta_data_filter["manual"] and not local_doc_ids:
|
|
||||||
local_doc_ids = ["-999"]
|
|
||||||
|
|
||||||
tenants = UserTenantService.query(user_id=user_id)
|
tenants = UserTenantService.query(user_id=user_id)
|
||||||
for kb_id in kb_ids:
|
for kb_id in kb_ids:
|
||||||
|
|||||||
@ -27,7 +27,7 @@ from api.db import VALID_FILE_TYPES, FileType
|
|||||||
from api.db.db_models import Task
|
from api.db.db_models import Task
|
||||||
from api.db.services import duplicate_name
|
from api.db.services import duplicate_name
|
||||||
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
from api.db.services.document_service import DocumentService, doc_upload_and_parse
|
||||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
from common.metadata_utils import meta_filter, convert_conditions
|
||||||
from api.db.services.file2document_service import File2DocumentService
|
from api.db.services.file2document_service import File2DocumentService
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
|
|||||||
@ -20,9 +20,9 @@ from quart import jsonify
|
|||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
|
from common.metadata_utils import meta_filter, convert_conditions
|
||||||
from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request
|
from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
|
||||||
from common.constants import RetCode, LLMType
|
from common.constants import RetCode, LLMType
|
||||||
from common import settings
|
from common import settings
|
||||||
|
|
||||||
|
|||||||
@ -35,7 +35,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
|||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of
|
from api.db.services.task_service import TaskService, queue_tasks, cancel_all_task_of
|
||||||
from api.db.services.dialog_service import meta_filter, convert_conditions
|
from common.metadata_utils import meta_filter, convert_conditions
|
||||||
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
|
from api.utils.api_utils import check_duplicate_ids, construct_json_result, get_error_data_result, get_parser_config, get_result, server_error_response, token_required, \
|
||||||
get_request_json
|
get_request_json
|
||||||
from rag.app.qa import beAdoc, rmPrefix
|
from rag.app.qa import beAdoc, rmPrefix
|
||||||
|
|||||||
@ -28,10 +28,11 @@ from api.db.services.canvas_service import completion as agent_completion
|
|||||||
from api.db.services.conversation_service import ConversationService
|
from api.db.services.conversation_service import ConversationService
|
||||||
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
|
from api.db.services.conversation_service import async_iframe_completion as iframe_completion
|
||||||
from api.db.services.conversation_service import async_completion as rag_completion
|
from api.db.services.conversation_service import async_completion as rag_completion
|
||||||
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap, meta_filter
|
from api.db.services.dialog_service import DialogService, async_ask, async_chat, gen_mindmap
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
|
from common.metadata_utils import apply_meta_data_filter
|
||||||
from api.db.services.search_service import SearchService
|
from api.db.services.search_service import SearchService
|
||||||
from api.db.services.user_service import UserTenantService
|
from api.db.services.user_service import UserTenantService
|
||||||
from common.misc_utils import get_uuid
|
from common.misc_utils import get_uuid
|
||||||
@ -39,7 +40,7 @@ from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_
|
|||||||
get_result, get_request_json, server_error_response, token_required, validate_request
|
get_result, get_request_json, server_error_response, token_required, validate_request
|
||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.prompts.template import load_prompt
|
from rag.prompts.template import load_prompt
|
||||||
from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extraction, chunks_format
|
from rag.prompts.generator import cross_languages, keyword_extraction, chunks_format
|
||||||
from common.constants import RetCode, LLMType, StatusEnum
|
from common.constants import RetCode, LLMType, StatusEnum
|
||||||
from common import settings
|
from common import settings
|
||||||
|
|
||||||
@ -974,54 +975,21 @@ async def retrieval_test_embedded():
|
|||||||
tenant_ids = []
|
tenant_ids = []
|
||||||
_question = question
|
_question = question
|
||||||
|
|
||||||
|
meta_data_filter = {}
|
||||||
|
chat_mdl = None
|
||||||
if req.get("search_id", ""):
|
if req.get("search_id", ""):
|
||||||
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||||
if meta_data_filter.get("method") == "auto":
|
|
||||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, metas, _question)
|
|
||||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not local_doc_ids:
|
|
||||||
local_doc_ids = None
|
|
||||||
elif meta_data_filter.get("method") == "semi_auto":
|
|
||||||
selected_keys = meta_data_filter.get("semi_auto", [])
|
|
||||||
if selected_keys:
|
|
||||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
|
||||||
if filtered_metas:
|
|
||||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, _question)
|
|
||||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not local_doc_ids:
|
|
||||||
local_doc_ids = None
|
|
||||||
elif meta_data_filter.get("method") == "manual":
|
|
||||||
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
|
||||||
if meta_data_filter["manual"] and not local_doc_ids:
|
|
||||||
local_doc_ids = ["-999"]
|
|
||||||
else:
|
else:
|
||||||
meta_data_filter = req.get("meta_data_filter")
|
meta_data_filter = req.get("meta_data_filter") or {}
|
||||||
|
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||||
|
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
||||||
|
|
||||||
if meta_data_filter:
|
if meta_data_filter:
|
||||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||||
if meta_data_filter.get("method") == "auto":
|
local_doc_ids = await apply_meta_data_filter(meta_data_filter, metas, _question, chat_mdl, local_doc_ids)
|
||||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
|
||||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not local_doc_ids:
|
|
||||||
local_doc_ids = None
|
|
||||||
elif meta_data_filter.get("method") == "semi_auto":
|
|
||||||
selected_keys = meta_data_filter.get("semi_auto", [])
|
|
||||||
if selected_keys:
|
|
||||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
|
||||||
if filtered_metas:
|
|
||||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
|
||||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not local_doc_ids:
|
|
||||||
local_doc_ids = None
|
|
||||||
elif meta_data_filter.get("method") == "manual":
|
|
||||||
local_doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
|
||||||
if meta_data_filter["manual"] and not local_doc_ids:
|
|
||||||
local_doc_ids = ["-999"]
|
|
||||||
|
|
||||||
tenants = UserTenantService.query(user_id=tenant_id)
|
tenants = UserTenantService.query(user_id=tenant_id)
|
||||||
for kb_id in kb_ids:
|
for kb_id in kb_ids:
|
||||||
|
|||||||
@ -32,6 +32,7 @@ from api.db.services.document_service import DocumentService
|
|||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.langfuse_service import TenantLangfuseService
|
from api.db.services.langfuse_service import TenantLangfuseService
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
|
from common.metadata_utils import apply_meta_data_filter
|
||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
from common.time_utils import current_timestamp, datetime_format
|
from common.time_utils import current_timestamp, datetime_format
|
||||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||||
@ -39,7 +40,7 @@ from rag.app.resume import forbidden_select_fields4resume
|
|||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.nlp.search import index_name
|
from rag.nlp.search import index_name
|
||||||
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
|
from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \
|
||||||
gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY
|
PROMPT_JINJA_ENV, ASK_SUMMARY
|
||||||
from common.token_utils import num_tokens_from_string
|
from common.token_utils import num_tokens_from_string
|
||||||
from rag.utils.tavily_conn import Tavily
|
from rag.utils.tavily_conn import Tavily
|
||||||
from common.string_utils import remove_redundant_spaces
|
from common.string_utils import remove_redundant_spaces
|
||||||
@ -277,77 +278,6 @@ def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set):
|
|||||||
return answer, idx
|
return answer, idx
|
||||||
|
|
||||||
|
|
||||||
def convert_conditions(metadata_condition):
|
|
||||||
if metadata_condition is None:
|
|
||||||
metadata_condition = {}
|
|
||||||
op_mapping = {
|
|
||||||
"is": "=",
|
|
||||||
"not is": "≠"
|
|
||||||
}
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
|
|
||||||
"key": cond["name"],
|
|
||||||
"value": cond["value"]
|
|
||||||
}
|
|
||||||
for cond in metadata_condition.get("conditions", [])
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
|
|
||||||
doc_ids = set([])
|
|
||||||
|
|
||||||
def filter_out(v2docs, operator, value):
|
|
||||||
ids = []
|
|
||||||
for input, docids in v2docs.items():
|
|
||||||
if operator in ["=", "≠", ">", "<", "≥", "≤"]:
|
|
||||||
try:
|
|
||||||
input = float(input)
|
|
||||||
value = float(value)
|
|
||||||
except Exception:
|
|
||||||
input = str(input)
|
|
||||||
value = str(value)
|
|
||||||
|
|
||||||
for conds in [
|
|
||||||
(operator == "contains", str(value).lower() in str(input).lower()),
|
|
||||||
(operator == "not contains", str(value).lower() not in str(input).lower()),
|
|
||||||
(operator == "in", str(input).lower() in str(value).lower()),
|
|
||||||
(operator == "not in", str(input).lower() not in str(value).lower()),
|
|
||||||
(operator == "start with", str(input).lower().startswith(str(value).lower())),
|
|
||||||
(operator == "end with", str(input).lower().endswith(str(value).lower())),
|
|
||||||
(operator == "empty", not input),
|
|
||||||
(operator == "not empty", input),
|
|
||||||
(operator == "=", input == value),
|
|
||||||
(operator == "≠", input != value),
|
|
||||||
(operator == ">", input > value),
|
|
||||||
(operator == "<", input < value),
|
|
||||||
(operator == "≥", input >= value),
|
|
||||||
(operator == "≤", input <= value),
|
|
||||||
]:
|
|
||||||
try:
|
|
||||||
if all(conds):
|
|
||||||
ids.extend(docids)
|
|
||||||
break
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return ids
|
|
||||||
|
|
||||||
for k, v2docs in metas.items():
|
|
||||||
for f in filters:
|
|
||||||
if k != f["key"]:
|
|
||||||
continue
|
|
||||||
ids = filter_out(v2docs, f["op"], f["value"])
|
|
||||||
if not doc_ids:
|
|
||||||
doc_ids = set(ids)
|
|
||||||
else:
|
|
||||||
if logic == "and":
|
|
||||||
doc_ids = doc_ids & set(ids)
|
|
||||||
else:
|
|
||||||
doc_ids = doc_ids | set(ids)
|
|
||||||
if not doc_ids:
|
|
||||||
return []
|
|
||||||
return list(doc_ids)
|
|
||||||
|
|
||||||
async def async_chat(dialog, messages, stream=True, **kwargs):
|
async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||||
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"):
|
||||||
@ -420,25 +350,13 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
|||||||
|
|
||||||
if dialog.meta_data_filter:
|
if dialog.meta_data_filter:
|
||||||
metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
|
metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
|
||||||
if dialog.meta_data_filter.get("method") == "auto":
|
attachments = await apply_meta_data_filter(
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, metas, questions[-1])
|
dialog.meta_data_filter,
|
||||||
attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
metas,
|
||||||
if not attachments:
|
questions[-1],
|
||||||
attachments = None
|
chat_mdl,
|
||||||
elif dialog.meta_data_filter.get("method") == "semi_auto":
|
attachments,
|
||||||
selected_keys = dialog.meta_data_filter.get("semi_auto", [])
|
)
|
||||||
if selected_keys:
|
|
||||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
|
||||||
if filtered_metas:
|
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, questions[-1])
|
|
||||||
attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not attachments:
|
|
||||||
attachments = None
|
|
||||||
elif dialog.meta_data_filter.get("method") == "manual":
|
|
||||||
conds = dialog.meta_data_filter["manual"]
|
|
||||||
attachments.extend(meta_filter(metas, conds, dialog.meta_data_filter.get("logic", "and")))
|
|
||||||
if conds and not attachments:
|
|
||||||
attachments = ["-999"]
|
|
||||||
|
|
||||||
if prompt_config.get("keyword", False):
|
if prompt_config.get("keyword", False):
|
||||||
questions[-1] += await keyword_extraction(chat_mdl, questions[-1])
|
questions[-1] += await keyword_extraction(chat_mdl, questions[-1])
|
||||||
@ -838,24 +756,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
|
|||||||
|
|
||||||
if meta_data_filter:
|
if meta_data_filter:
|
||||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||||
if meta_data_filter.get("method") == "auto":
|
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
|
||||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not doc_ids:
|
|
||||||
doc_ids = None
|
|
||||||
elif meta_data_filter.get("method") == "semi_auto":
|
|
||||||
selected_keys = meta_data_filter.get("semi_auto", [])
|
|
||||||
if selected_keys:
|
|
||||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
|
||||||
if filtered_metas:
|
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
|
||||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not doc_ids:
|
|
||||||
doc_ids = None
|
|
||||||
elif meta_data_filter.get("method") == "manual":
|
|
||||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
|
||||||
if meta_data_filter["manual"] and not doc_ids:
|
|
||||||
doc_ids = ["-999"]
|
|
||||||
|
|
||||||
kbinfos = retriever.retrieval(
|
kbinfos = retriever.retrieval(
|
||||||
question=question,
|
question=question,
|
||||||
@ -922,24 +823,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
|||||||
|
|
||||||
if meta_data_filter:
|
if meta_data_filter:
|
||||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||||
if meta_data_filter.get("method") == "auto":
|
doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids)
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
|
||||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not doc_ids:
|
|
||||||
doc_ids = None
|
|
||||||
elif meta_data_filter.get("method") == "semi_auto":
|
|
||||||
selected_keys = meta_data_filter.get("semi_auto", [])
|
|
||||||
if selected_keys:
|
|
||||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
|
||||||
if filtered_metas:
|
|
||||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
|
||||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
|
||||||
if not doc_ids:
|
|
||||||
doc_ids = None
|
|
||||||
elif meta_data_filter.get("method") == "manual":
|
|
||||||
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"], meta_data_filter.get("logic", "and")))
|
|
||||||
if meta_data_filter["manual"] and not doc_ids:
|
|
||||||
doc_ids = ["-999"]
|
|
||||||
|
|
||||||
ranks = settings.retriever.retrieval(
|
ranks = settings.retriever.retrieval(
|
||||||
question=question,
|
question=question,
|
||||||
|
|||||||
142
common/metadata_utils.py
Normal file
142
common/metadata_utils.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
from rag.prompts.generator import gen_meta_filter
|
||||||
|
|
||||||
|
|
||||||
|
def convert_conditions(metadata_condition):
    """
    Translate an API-style metadata condition into the filter list used by meta_filter.

    Each entry of ``metadata_condition["conditions"]`` must provide
    ``comparison_operator``, ``name`` and ``value``. The operator aliases
    ``"is"`` and ``"not is"`` are rewritten to ``"="`` and ``"≠"``; every other
    operator is passed through unchanged.

    Args:
        metadata_condition: dict with an optional ``"conditions"`` list, or None.

    Returns:
        A list of ``{"op": ..., "key": ..., "value": ...}`` dicts. The list is
        empty when ``metadata_condition`` is None, or when ``"conditions"`` is
        missing, None or empty.

    Raises:
        KeyError: if a condition lacks one of the three required fields.
    """
    if metadata_condition is None:
        metadata_condition = {}
    op_mapping = {
        "is": "=",
        "not is": "≠",
    }
    # `or []` also covers an explicit null ("conditions": None); a plain
    # .get(..., []) default would hand None to the loop and raise TypeError.
    conditions = metadata_condition.get("conditions") or []
    return [
        {
            "op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]),
            "key": cond["name"],
            "value": cond["value"],
        }
        for cond in conditions
    ]
|
||||||
|
|
||||||
|
|
||||||
|
def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
    """
    Select the document ids whose metadata satisfy ``filters``.

    Args:
        metas: mapping of metadata key -> {metadata value -> iterable of doc ids}.
        filters: conditions shaped like ``{"key": ..., "op": ..., "value": ...}``.
            Supported ``op`` values: ``contains``, ``not contains``, ``in``,
            ``not in``, ``start with``, ``end with`` (case-insensitive text
            tests), ``empty``, ``not empty`` (truthiness of the stored value),
            and ``=``, ``≠``, ``>``, ``<``, ``≥``, ``≤``. An unknown operator
            matches nothing.
        logic: ``"and"`` intersects the per-condition matches; any other value
            unions them.

    Returns:
        A list of matching doc ids in unspecified order. Conditions whose key
        is not present in ``metas`` are skipped; if no condition is applied at
        all the result is an empty list.
    """
    # Case-insensitive text tests; both operands arrive already lower-cased.
    text_tests = {
        "contains": lambda stored, wanted: wanted in stored,
        "not contains": lambda stored, wanted: wanted not in stored,
        "in": lambda stored, wanted: stored in wanted,
        "not in": lambda stored, wanted: stored not in wanted,
        "start with": lambda stored, wanted: stored.startswith(wanted),
        "end with": lambda stored, wanted: stored.endswith(wanted),
    }
    # Ordering/equality tests; operands are coerced to a common type first.
    comparisons = {
        "=": lambda left, right: left == right,
        "≠": lambda left, right: left != right,
        ">": lambda left, right: left > right,
        "<": lambda left, right: left < right,
        "≥": lambda left, right: left >= right,
        "≤": lambda left, right: left <= right,
    }

    def is_match(stored, op, wanted) -> bool:
        # Only the requested operator is evaluated, so operands that are not
        # mutually orderable (e.g. str vs None) cannot raise for operators
        # that never compare them.
        if op == "empty":
            return not stored
        if op == "not empty":
            return bool(stored)
        if op in text_tests:
            return text_tests[op](str(stored).lower(), str(wanted).lower())
        if op in comparisons:
            # Compare numerically when both sides parse as numbers, otherwise
            # as strings. The coercion is local to this pair, so one stored
            # value never affects how the next one is compared.
            try:
                left, right = float(stored), float(wanted)
            except (TypeError, ValueError, OverflowError):
                left, right = str(stored), str(wanted)
            return comparisons[op](left, right)
        return False

    def filter_out(v2docs, op, wanted):
        ids = []
        for stored, docids in v2docs.items():
            if is_match(stored, op, wanted):
                ids.extend(docids)
        return ids

    doc_ids = set()
    applied = False  # distinguishes "nothing applied yet" from "matched nothing"
    for k, v2docs in metas.items():
        for f in filters:
            if k != f["key"]:
                continue
            ids = set(filter_out(v2docs, f["op"], f["value"]))
            if not applied:
                doc_ids = ids
                applied = True
            elif logic == "and":
                doc_ids &= ids
            else:
                doc_ids |= ids
            # An empty intersection can never recover; an empty union can.
            if logic == "and" and not doc_ids:
                return []
    return list(doc_ids)
|
||||||
|
|
||||||
|
|
||||||
|
async def apply_meta_data_filter(
|
||||||
|
meta_data_filter: dict | None,
|
||||||
|
metas: dict,
|
||||||
|
question: str,
|
||||||
|
chat_mdl: Any = None,
|
||||||
|
base_doc_ids: list[str] | None = None,
|
||||||
|
manual_value_resolver: Callable[[dict], dict] | None = None,
|
||||||
|
) -> list[str] | None:
|
||||||
|
"""
|
||||||
|
Apply metadata filtering rules and return the filtered doc_ids.
|
||||||
|
|
||||||
|
meta_data_filter supports three modes:
|
||||||
|
- auto: generate filter conditions via LLM (gen_meta_filter)
|
||||||
|
- semi_auto: generate conditions using selected metadata keys only
|
||||||
|
- manual: directly filter based on provided conditions
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of doc_ids, ["-999"] when manual filters yield no result, or None
|
||||||
|
when auto/semi_auto filters return empty.
|
||||||
|
"""
|
||||||
|
doc_ids = list(base_doc_ids) if base_doc_ids else []
|
||||||
|
|
||||||
|
if not meta_data_filter:
|
||||||
|
return doc_ids
|
||||||
|
|
||||||
|
method = meta_data_filter.get("method")
|
||||||
|
|
||||||
|
if method == "auto":
|
||||||
|
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
||||||
|
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||||
|
if not doc_ids:
|
||||||
|
return None
|
||||||
|
elif method == "semi_auto":
|
||||||
|
selected_keys = meta_data_filter.get("semi_auto", [])
|
||||||
|
if selected_keys:
|
||||||
|
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||||
|
if filtered_metas:
|
||||||
|
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||||
|
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||||
|
if not doc_ids:
|
||||||
|
return None
|
||||||
|
elif method == "manual":
|
||||||
|
filters = meta_data_filter.get("manual", [])
|
||||||
|
if manual_value_resolver:
|
||||||
|
filters = [manual_value_resolver(flt) for flt in filters]
|
||||||
|
doc_ids.extend(meta_filter(metas, filters, meta_data_filter.get("logic", "and")))
|
||||||
|
if filters and not doc_ids:
|
||||||
|
doc_ids = ["-999"]
|
||||||
|
|
||||||
|
return doc_ids
|
||||||
@ -75,6 +75,7 @@ export const useTestRetrieval = () => {
|
|||||||
page,
|
page,
|
||||||
size: pageSize,
|
size: pageSize,
|
||||||
doc_ids: filterValue.doc_ids,
|
doc_ids: filterValue.doc_ids,
|
||||||
|
highlight: true,
|
||||||
};
|
};
|
||||||
}, [filterValue, knowledgeBaseId, page, pageSize, values]);
|
}, [filterValue, knowledgeBaseId, page, pageSize, values]);
|
||||||
|
|
||||||
|
|||||||
@ -209,7 +209,11 @@ export default function SearchingView({
|
|||||||
<div
|
<div
|
||||||
dangerouslySetInnerHTML={{
|
dangerouslySetInnerHTML={{
|
||||||
__html: DOMPurify.sanitize(
|
__html: DOMPurify.sanitize(
|
||||||
`${chunk.highlight}...`,
|
`${
|
||||||
|
chunk.highlight ??
|
||||||
|
chunk.content_with_weight ??
|
||||||
|
''
|
||||||
|
}...`,
|
||||||
),
|
),
|
||||||
}}
|
}}
|
||||||
className="text-sm text-text-primary mb-1"
|
className="text-sm text-text-primary mb-1"
|
||||||
|
|||||||
Reference in New Issue
Block a user