Feat: add meta filter to search app. (#9554)

### What problem does this PR solve?


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-08-19 17:25:44 +08:00
committed by GitHub
parent a41a646909
commit f123587538
8 changed files with 94 additions and 154 deletions

View File

@ -22,6 +22,7 @@ from datetime import datetime
from functools import partial
from timeit import default_timer as timer
import trio
from langfuse import Langfuse
from peewee import fn
@ -36,6 +37,7 @@ from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
from api.utils import current_timestamp, datetime_format
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question
from rag.nlp.search import index_name
@ -688,28 +690,12 @@ def tts(tts_mdl, text):
def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
similarity_threshold = 0.1,
vector_similarity_weight = 0.3,
top = 1024,
doc_ids = []
rerank_id = ""
doc_ids = search_config.get("doc_ids", [])
rerank_mdl = None
if search_config:
if search_config.get("kb_ids", []):
kb_ids = search_config.get("kb_ids", [])
if search_config.get("chat_id", ""):
chat_llm_name = search_config.get("chat_id", "")
if search_config.get("similarity_threshold", 0.1):
similarity_threshold = search_config.get("similarity_threshold", 0.1)
if search_config.get("vector_similarity_weight", 0.3):
vector_similarity_weight = search_config.get("vector_similarity_weight", 0.3)
if search_config.get("top_k", 1024):
top = search_config.get("top_k", 1024)
if search_config.get("doc_ids", []):
doc_ids = search_config.get("doc_ids", [])
if search_config.get("rerank_id", ""):
rerank_id = search_config.get("rerank_id", "")
kb_ids = search_config.get("kb_ids", kb_ids)
chat_llm_name = search_config.get("chat_id", chat_llm_name)
rerank_id = search_config.get("rerank_id", "")
meta_data_filter = search_config.get("meta_data_filter")
kbs = KnowledgebaseService.get_by_ids(kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs]))
@ -724,6 +710,18 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
max_tokens = chat_mdl.max_length
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
if meta_data_filter:
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
filters = gen_meta_filter(chat_mdl, metas, question)
doc_ids.extend(meta_filter(metas, filters))
if not doc_ids:
doc_ids = None
elif meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
if not doc_ids:
doc_ids = None
kbinfos = retriever.retrieval(
question = question,
embd_mdl=embd_mdl,
@ -731,9 +729,9 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
kb_ids=kb_ids,
page=1,
page_size=12,
similarity_threshold=similarity_threshold,
vector_similarity_weight=vector_similarity_weight,
top=top,
similarity_threshold=search_config.get("similarity_threshold", 0.1),
vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
top=search_config.get("top_k", 1024),
doc_ids=doc_ids,
aggs=False,
rerank_mdl=rerank_mdl,
@ -768,3 +766,50 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
answer = ans
yield {"answer": answer, "reference": {}}
yield decorate_answer(answer)
def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
meta_data_filter = search_config.get("meta_data_filter", {})
doc_ids = search_config.get("doc_ids", [])
kb_ids = search_config.get("doc_ids", kb_ids)
rerank_id = search_config.get("rerank_id", "")
rerank_mdl = None
kbs = KnowledgebaseService.get_by_ids(kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs]))
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
if rerank_id:
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
if meta_data_filter:
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
filters = gen_meta_filter(chat_mdl, metas, question)
doc_ids.extend(meta_filter(metas, filters))
if not doc_ids:
doc_ids = None
elif meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
if not doc_ids:
doc_ids = None
ranks = settings.retrievaler.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,
kb_ids=kb_ids,
page=1,
page_size=12,
similarity_threshold=search_config.get("similarity_threshold", 0.2),
vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
top=search_config.get("top_k", 1024),
doc_ids=doc_ids,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(question, kbs),
)
mindmap = MindMapExtractor(chat_mdl)
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
return mind_map.output

View File

@ -71,6 +71,8 @@ class SearchService(CommonService):
.first()
.to_dict()
)
if not search:
return {}
return search
@classmethod