Feat: add meta filter to search app. (#9554)

### What problem does this PR solve?


### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-08-19 17:25:44 +08:00
committed by GitHub
parent a41a646909
commit f123587538
8 changed files with 94 additions and 154 deletions

View File

@ -479,7 +479,7 @@ class ComponentBase(ABC):
def get_input_elements_from_text(self, txt: str) -> dict[str, dict[str, str]]: def get_input_elements_from_text(self, txt: str) -> dict[str, dict[str, str]]:
res = {} res = {}
for r in re.finditer(self.variable_ref_patt, txt, flags=re.IGNORECASE): for r in re.finditer(self.variable_ref_patt, txt, flags=re.IGNORECASE|re.DOTALL):
exp = r.group(1) exp = r.group(1)
cpn_id, var_nm = exp.split("@") if exp.find("@")>0 else ("", exp) cpn_id, var_nm = exp.split("@") if exp.find("@")>0 else ("", exp)
res[exp] = { res[exp] = {

View File

@ -54,6 +54,8 @@ class Message(ComponentBase):
if k in kwargs: if k in kwargs:
continue continue
v = v["value"] v = v["value"]
if not v:
v = ""
ans = "" ans = ""
if isinstance(v, partial): if isinstance(v, partial):
for t in v(): for t in v():
@ -94,6 +96,8 @@ class Message(ComponentBase):
continue continue
v = self._canvas.get_variable_value(exp) v = self._canvas.get_variable_value(exp)
if not v:
v = ""
if isinstance(v, partial): if isinstance(v, partial):
cnt = "" cnt = ""
for t in v(): for t in v():

View File

@ -115,6 +115,12 @@ def getsse(canvas_id):
if not objs: if not objs:
return get_data_error_result(message='Authentication error: API key is invalid!"') return get_data_error_result(message='Authentication error: API key is invalid!"')
tenant_id = objs[0].tenant_id tenant_id = objs[0].tenant_id
if not UserCanvasService.query(user_id=tenant_id, id=canvas_id):
return get_json_result(
data=False,
message='Only owner of canvas authorized for this operation.',
code=RetCode.OPERATING_ERROR
)
e, c = UserCanvasService.get_by_id(canvas_id) e, c = UserCanvasService.get_by_id(canvas_id)
if not e or c.user_id != tenant_id: if not e or c.user_id != tenant_id:
return get_data_error_result(message="canvas not found.") return get_data_error_result(message="canvas not found.")

View File

@ -17,24 +17,18 @@ import json
import re import re
import traceback import traceback
from copy import deepcopy from copy import deepcopy
import trio
from flask import Response, request from flask import Response, request
from flask_login import current_user, login_required from flask_login import current_user, login_required
from api import settings from api import settings
from api.db import LLMType from api.db import LLMType
from api.db.db_models import APIToken from api.db.db_models import APIToken
from api.db.services.conversation_service import ConversationService, structure_answer from api.db.services.conversation_service import ConversationService, structure_answer
from api.db.services.dialog_service import DialogService, ask, chat from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.tag import label_question
from rag.prompts.prompt_template import load_prompt from rag.prompts.prompt_template import load_prompt
from rag.prompts.prompts import chunks_format from rag.prompts.prompts import chunks_format
@ -375,71 +369,12 @@ def ask_about():
@validate_request("question", "kb_ids") @validate_request("question", "kb_ids")
def mindmap(): def mindmap():
req = request.json req = request.json
search_id = req.get("search_id", "") search_id = req.get("search_id", "")
search_app = None search_app = SearchService.get_detail(search_id) if search_id else {}
search_config = {} search_config = search_app.get("search_config", {}) if search_app else {}
if search_id: kb_ids = search_config.get("kb_ids", req["kb_ids"])
search_app = SearchService.get_detail(search_id)
if search_app:
search_config = search_app.get("search_config", {})
kb_ids = req["kb_ids"] mind_map = gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config)
if search_config.get("kb_ids", []):
kb_ids = search_config.get("kb_ids", [])
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e:
return get_data_error_result(message="Knowledgebase not found!")
chat_id = ""
similarity_threshold = 0.3,
vector_similarity_weight = 0.3,
top = 1024,
doc_ids = []
rerank_id = ""
rerank_mdl = None
if search_config:
if search_config.get("chat_id", ""):
chat_id = search_config.get("chat_id", "")
if search_config.get("similarity_threshold", 0.2):
similarity_threshold = search_config.get("similarity_threshold", 0.2)
if search_config.get("vector_similarity_weight", 0.3):
vector_similarity_weight = search_config.get("vector_similarity_weight", 0.3)
if search_config.get("top_k", 1024):
top = search_config.get("top_k", 1024)
if search_config.get("doc_ids", []):
doc_ids = search_config.get("doc_ids", [])
if search_config.get("rerank_id", ""):
rerank_id = search_config.get("rerank_id", "")
tenant_id = kb.tenant_id
if search_app and search_app.get("tenant_id", ""):
tenant_id = search_app.get("tenant_id", "")
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=chat_id)
if rerank_id:
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
question = req["question"]
ranks = settings.retrievaler.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_id,
kb_ids=kb_ids,
page=1,
page_size=12,
similarity_threshold=similarity_threshold,
vector_similarity_weight=vector_similarity_weight,
top=top,
doc_ids=doc_ids,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(question, [kb]),
)
mindmap = MindMapExtractor(chat_mdl)
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
mind_map = mind_map.output
if "error" in mind_map: if "error" in mind_map:
return server_error_response(Exception(mind_map["error"])) return server_error_response(Exception(mind_map["error"]))
return get_json_result(data=mind_map) return get_json_result(data=mind_map)

View File

@ -18,7 +18,6 @@ import re
import time import time
import tiktoken import tiktoken
from flask import Response, jsonify, request from flask import Response, jsonify, request
import trio
from agent.canvas import Canvas from agent.canvas import Canvas
from api import settings from api import settings
from api.db import LLMType, StatusEnum from api.db import LLMType, StatusEnum
@ -28,14 +27,13 @@ from api.db.services.canvas_service import UserCanvasService, completionOpenAI
from api.db.services.canvas_service import completion as agent_completion from api.db.services.canvas_service import completion as agent_completion
from api.db.services.conversation_service import ConversationService, iframe_completion from api.db.services.conversation_service import ConversationService, iframe_completion
from api.db.services.conversation_service import completion as rag_completion from api.db.services.conversation_service import completion as rag_completion
from api.db.services.dialog_service import DialogService, ask, chat from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService from api.db.services.user_service import UserTenantService
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, get_result, server_error_response, token_required, validate_request from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, get_result, server_error_response, token_required, validate_request
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.prompts import chunks_format from rag.prompts import chunks_format
from rag.prompts.prompt_template import load_prompt from rag.prompts.prompt_template import load_prompt
@ -1102,63 +1100,9 @@ def mindmap():
req = request.json req = request.json
search_id = req.get("search_id", "") search_id = req.get("search_id", "")
search_config = {} search_app = SearchService.get_detail(search_id) if search_id else {}
if search_id:
if search_app := SearchService.get_detail(search_id):
search_config = search_app.get("search_config", {})
kb_ids = req["kb_ids"] mind_map = gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}))
if search_config.get("kb_ids", []):
kb_ids = search_config.get("kb_ids", [])
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e:
return get_error_data_result(message="Knowledgebase not found!")
chat_id = ""
similarity_threshold = 0.3,
vector_similarity_weight = 0.3,
top = 1024,
doc_ids = []
rerank_id = ""
rerank_mdl = None
if search_config:
if search_config.get("chat_id", ""):
chat_id = search_config.get("chat_id", "")
if search_config.get("similarity_threshold", 0.2):
similarity_threshold = search_config.get("similarity_threshold", 0.2)
if search_config.get("vector_similarity_weight", 0.3):
vector_similarity_weight = search_config.get("vector_similarity_weight", 0.3)
if search_config.get("top_k", 1024):
top = search_config.get("top_k", 1024)
if search_config.get("doc_ids", []):
doc_ids = search_config.get("doc_ids", [])
if search_config.get("rerank_id", ""):
rerank_id = search_config.get("rerank_id", "")
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=chat_id)
if rerank_id:
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
question = req["question"]
ranks = settings.retrievaler.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_id,
kb_ids=kb_ids,
page=1,
page_size=12,
similarity_threshold=similarity_threshold,
vector_similarity_weight=vector_similarity_weight,
top=top,
doc_ids=doc_ids,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(question, [kb]),
)
mindmap = MindMapExtractor(chat_mdl)
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
mind_map = mind_map.output
if "error" in mind_map: if "error" in mind_map:
return server_error_response(Exception(mind_map["error"])) return server_error_response(Exception(mind_map["error"]))
return get_json_result(data=mind_map) return get_json_result(data=mind_map)

View File

@ -22,6 +22,7 @@ from datetime import datetime
from functools import partial from functools import partial
from timeit import default_timer as timer from timeit import default_timer as timer
import trio
from langfuse import Langfuse from langfuse import Langfuse
from peewee import fn from peewee import fn
@ -36,6 +37,7 @@ from api.db.services.langfuse_service import TenantLangfuseService
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
from api.utils import current_timestamp, datetime_format from api.utils import current_timestamp, datetime_format
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.resume import forbidden_select_fields4resume from rag.app.resume import forbidden_select_fields4resume
from rag.app.tag import label_question from rag.app.tag import label_question
from rag.nlp.search import index_name from rag.nlp.search import index_name
@ -688,28 +690,12 @@ def tts(tts_mdl, text):
def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
similarity_threshold = 0.1,
vector_similarity_weight = 0.3,
top = 1024,
doc_ids = []
rerank_id = ""
rerank_mdl = None
if search_config:
if search_config.get("kb_ids", []):
kb_ids = search_config.get("kb_ids", [])
if search_config.get("chat_id", ""):
chat_llm_name = search_config.get("chat_id", "")
if search_config.get("similarity_threshold", 0.1):
similarity_threshold = search_config.get("similarity_threshold", 0.1)
if search_config.get("vector_similarity_weight", 0.3):
vector_similarity_weight = search_config.get("vector_similarity_weight", 0.3)
if search_config.get("top_k", 1024):
top = search_config.get("top_k", 1024)
if search_config.get("doc_ids", []):
doc_ids = search_config.get("doc_ids", []) doc_ids = search_config.get("doc_ids", [])
if search_config.get("rerank_id", ""): rerank_mdl = None
kb_ids = search_config.get("kb_ids", kb_ids)
chat_llm_name = search_config.get("chat_id", chat_llm_name)
rerank_id = search_config.get("rerank_id", "") rerank_id = search_config.get("rerank_id", "")
meta_data_filter = search_config.get("meta_data_filter")
kbs = KnowledgebaseService.get_by_ids(kb_ids) kbs = KnowledgebaseService.get_by_ids(kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs])) embedding_list = list(set([kb.embd_id for kb in kbs]))
@ -724,6 +710,18 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
max_tokens = chat_mdl.max_length max_tokens = chat_mdl.max_length
tenant_ids = list(set([kb.tenant_id for kb in kbs])) tenant_ids = list(set([kb.tenant_id for kb in kbs]))
if meta_data_filter:
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
filters = gen_meta_filter(chat_mdl, metas, question)
doc_ids.extend(meta_filter(metas, filters))
if not doc_ids:
doc_ids = None
elif meta_data_filter.get("method") == "manual":
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
if not doc_ids:
doc_ids = None
kbinfos = retriever.retrieval( kbinfos = retriever.retrieval(
question = question, question = question,
embd_mdl=embd_mdl, embd_mdl=embd_mdl,
@ -731,9 +729,9 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
kb_ids=kb_ids, kb_ids=kb_ids,
page=1, page=1,
page_size=12, page_size=12,
similarity_threshold=similarity_threshold, similarity_threshold=search_config.get("similarity_threshold", 0.1),
vector_similarity_weight=vector_similarity_weight, vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
top=top, top=search_config.get("top_k", 1024),
doc_ids=doc_ids, doc_ids=doc_ids,
aggs=False, aggs=False,
rerank_mdl=rerank_mdl, rerank_mdl=rerank_mdl,
@ -768,3 +766,50 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
answer = ans answer = ans
yield {"answer": answer, "reference": {}} yield {"answer": answer, "reference": {}}
yield decorate_answer(answer) yield decorate_answer(answer)
def gen_mindmap(question, kb_ids, tenant_id, search_config=None):
    """Generate a mind map for ``question`` from chunks retrieved out of knowledge bases.

    Args:
        question: Query text used for retrieval (and for automatic
            meta-filter generation when that method is configured).
        kb_ids: Knowledge-base ids to search. A non-empty
            ``search_config["kb_ids"]`` takes precedence over this value.
        tenant_id: Tenant id used to instantiate the embedding, chat and
            (optional) rerank models.
        search_config: Optional dict of search-app settings. Keys read here:
            ``kb_ids``, ``doc_ids``, ``rerank_id``, ``chat_id``,
            ``meta_data_filter``, ``similarity_threshold``,
            ``vector_similarity_weight`` and ``top_k``.

    Returns:
        The ``output`` attribute of the ``MindMapExtractor`` run, or
        ``{"error": "Knowledgebase not found!"}`` when none of the requested
        knowledge bases exists (callers already check for an ``"error"`` key).
    """
    # Avoid a shared mutable default; treat None/empty config the same way.
    search_config = search_config or {}

    meta_data_filter = search_config.get("meta_data_filter", {})
    # Copy so that the extend() calls below never mutate the caller's config.
    doc_ids = list(search_config.get("doc_ids") or [])
    # Fixed: this previously read the "doc_ids" key, which replaced the
    # knowledge-base ids with document ids whenever doc_ids was configured.
    kb_ids = search_config.get("kb_ids") or kb_ids
    rerank_id = search_config.get("rerank_id", "")

    kbs = KnowledgebaseService.get_by_ids(kb_ids)
    embedding_list = list({kb.embd_id for kb in kbs})
    if not embedding_list:
        # No knowledge base matched; report it instead of failing on
        # embedding_list[0] with an IndexError.
        return {"error": "Knowledgebase not found!"}
    tenant_ids = list({kb.tenant_id for kb in kbs})

    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=embedding_list[0])
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
    rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id) if rerank_id else None

    if meta_data_filter:
        metas = DocumentService.get_meta_by_kbs(kb_ids)
        method = meta_data_filter.get("method")
        if method == "auto":
            filters = gen_meta_filter(chat_mdl, metas, question)
            doc_ids.extend(meta_filter(metas, filters))
        elif method == "manual":
            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
        # Only the two recognised methods turn an empty id list into None;
        # any other method leaves doc_ids untouched, as before.
        if method in ("auto", "manual") and not doc_ids:
            doc_ids = None

    ranks = settings.retrievaler.retrieval(
        question=question,
        embd_mdl=embd_mdl,
        tenant_ids=tenant_ids,
        kb_ids=kb_ids,
        page=1,
        page_size=12,
        similarity_threshold=search_config.get("similarity_threshold", 0.2),
        vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
        top=search_config.get("top_k", 1024),
        doc_ids=doc_ids,
        aggs=False,
        rerank_mdl=rerank_mdl,
        rank_feature=label_question(question, kbs),
    )

    # The extractor instance is the callable handed to trio.run, which is
    # given the text of every retrieved chunk.
    extractor = MindMapExtractor(chat_mdl)
    mind_map = trio.run(extractor, [c["content_with_weight"] for c in ranks["chunks"]])
    return mind_map.output

View File

@ -71,6 +71,8 @@ class SearchService(CommonService):
.first() .first()
.to_dict() .to_dict()
) )
if not search:
return {}
return search return search
@classmethod @classmethod

View File

@ -6,3 +6,7 @@ proxy_set_header Connection "";
proxy_buffering off; proxy_buffering off;
proxy_read_timeout 3600s; proxy_read_timeout 3600s;
proxy_send_timeout 3600s; proxy_send_timeout 3600s;
proxy_buffer_size 1024k;
proxy_buffers 16 1024k;
proxy_busy_buffers_size 2048k;
proxy_temp_file_write_size 2048k;