Refactor: refine search app (#9536)

### What problem does this PR solve?

Refine search app.

### Type of change

- [x] Refactoring
This commit is contained in:
Yongteng Lei
2025-08-19 09:33:33 +08:00
committed by GitHub
parent dad97869b6
commit 188c0f614b
4 changed files with 192 additions and 15 deletions

View File

@ -29,6 +29,7 @@ from api.db.services.conversation_service import ConversationService, structure_
from api.db.services.dialog_service import DialogService, ask, chat
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api.db.services.search_service import SearchService
from api.db.services.tenant_llm_service import TenantLLMService
from api.db.services.user_service import TenantService, UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
@ -344,10 +345,18 @@ def ask_about():
req = request.json
uid = current_user.id
search_id = req.get("search_id", "")
search_app = None
search_config = {}
if search_id:
search_app = SearchService.get_detail(search_id)
if search_app:
search_config = search_app.get("search_config", {})
def stream():
nonlocal req, uid
try:
for ans in ask(req["question"], req["kb_ids"], uid):
for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
@ -366,15 +375,68 @@ def ask_about():
@validate_request("question", "kb_ids")
def mindmap():
req = request.json
search_id = req.get("search_id", "")
search_app = None
search_config = {}
if search_id:
search_app = SearchService.get_detail(search_id)
if search_app:
search_config = search_app.get("search_config", {})
kb_ids = req["kb_ids"]
if search_config.get("kb_ids", []):
kb_ids = search_config.get("kb_ids", [])
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e:
return get_data_error_result(message="Knowledgebase not found!")
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
# Retrieval defaults; each may be overridden by the search app's config below.
# NOTE: these must be plain scalars — the originals carried stray trailing
# commas (e.g. `similarity_threshold = 0.3,`) that turned them into 1-tuples,
# which were then passed verbatim to retrieval() whenever search_config was
# empty.
chat_id = ""
similarity_threshold = 0.3
vector_similarity_weight = 0.3
top = 1024
doc_ids = []
rerank_id = ""
rerank_mdl = None
# Apply per-search-app overrides on top of the defaults above.
# NOTE(review): each guard re-reads the key and skips falsy stored values
# (0, "", []); when the key is absent, the guard's own get() default (e.g.
# 0.2 for similarity_threshold, 0.3 for vector_similarity_weight) is truthy,
# so the preceding default is silently replaced by that get() default —
# confirm this asymmetry is intended.
if search_config:
if search_config.get("chat_id", ""):
chat_id = search_config.get("chat_id", "")
if search_config.get("similarity_threshold", 0.2):
similarity_threshold = search_config.get("similarity_threshold", 0.2)
if search_config.get("vector_similarity_weight", 0.3):
vector_similarity_weight = search_config.get("vector_similarity_weight", 0.3)
# "top_k" in the config maps to the local `top` retrieval cap.
if search_config.get("top_k", 1024):
top = search_config.get("top_k", 1024)
if search_config.get("doc_ids", []):
doc_ids = search_config.get("doc_ids", [])
if search_config.get("rerank_id", ""):
rerank_id = search_config.get("rerank_id", "")
tenant_id = kb.tenant_id
if search_app and search_app.get("tenant_id", ""):
tenant_id = search_app.get("tenant_id", "")
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=chat_id)
if rerank_id:
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
question = req["question"]
ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12, 0.3, 0.3, aggs=False, rank_feature=label_question(question, [kb]))
ranks = settings.retrievaler.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_id,
kb_ids=kb_ids,
page=1,
page_size=12,
similarity_threshold=similarity_threshold,
vector_similarity_weight=vector_similarity_weight,
top=top,
doc_ids=doc_ids,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(question, [kb]),
)
mindmap = MindMapExtractor(chat_mdl)
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
mind_map = mind_map.output
@ -388,8 +450,19 @@ def mindmap():
@validate_request("question")
def related_questions():
req = request.json
search_id = req.get("search_id", "")
search_config = {}
if search_id:
if search_app := SearchService.get_detail(search_id):
search_config = search_app.get("search_config", {})
question = req["question"]
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
chat_id = search_config.get("chat_id", "")
chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id)
gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
prompt = load_prompt("related_question")
ans = chat_mdl.chat(
prompt,
@ -402,6 +475,6 @@ Related search terms:
""",
}
],
{"temperature": 0.9},
gen_conf,
)
return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])

View File

@ -902,10 +902,16 @@ def ask_about_embedded():
req = request.json
uid = objs[0].tenant_id
search_id = req.get("search_id", "")
search_config = {}
if search_id:
if search_app := SearchService.get_detail(search_id):
search_config = search_app.get("search_config", {})
def stream():
nonlocal req, uid
try:
for ans in ask(req["question"], req["kb_ids"], uid):
for ans in ask(req["question"], req["kb_ids"], uid, search_config):
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
except Exception as e:
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
@ -1021,8 +1027,19 @@ def related_questions_embedded():
tenant_id = objs[0].tenant_id
if not tenant_id:
return get_error_data_result(message="permission denined.")
search_id = req.get("search_id", "")
search_config = {}
if search_id:
if search_app := SearchService.get_detail(search_id):
search_config = search_app.get("search_config", {})
question = req["question"]
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
chat_id = search_config.get("chat_id", "")
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_id)
gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
prompt = load_prompt("related_question")
ans = chat_mdl.chat(
prompt,
@ -1035,7 +1052,7 @@ Related search terms:
""",
}
],
{"temperature": 0.9},
gen_conf,
)
return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
@ -1083,15 +1100,62 @@ def mindmap():
tenant_id = objs[0].tenant_id
req = request.json
search_id = req.get("search_id", "")
search_config = {}
if search_id:
if search_app := SearchService.get_detail(search_id):
search_config = search_app.get("search_config", {})
kb_ids = req["kb_ids"]
if search_config.get("kb_ids", []):
kb_ids = search_config.get("kb_ids", [])
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
if not e:
return get_error_data_result(message="Knowledgebase not found!")
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
# Retrieval defaults; each may be overridden by the search app's config below.
# NOTE: these must be plain scalars — the originals carried stray trailing
# commas (e.g. `similarity_threshold = 0.3,`) that turned them into 1-tuples,
# which were then passed verbatim to retrieval() whenever search_config was
# empty.
chat_id = ""
similarity_threshold = 0.3
vector_similarity_weight = 0.3
top = 1024
doc_ids = []
rerank_id = ""
rerank_mdl = None
# Apply per-search-app overrides on top of the defaults above.
# NOTE(review): each guard re-reads the key and skips falsy stored values
# (0, "", []); when the key is absent, the guard's own get() default (e.g.
# 0.2 for similarity_threshold, 0.3 for vector_similarity_weight) is truthy,
# so the preceding default is silently replaced by that get() default —
# confirm this asymmetry is intended.
if search_config:
if search_config.get("chat_id", ""):
chat_id = search_config.get("chat_id", "")
if search_config.get("similarity_threshold", 0.2):
similarity_threshold = search_config.get("similarity_threshold", 0.2)
if search_config.get("vector_similarity_weight", 0.3):
vector_similarity_weight = search_config.get("vector_similarity_weight", 0.3)
# "top_k" in the config maps to the local `top` retrieval cap.
if search_config.get("top_k", 1024):
top = search_config.get("top_k", 1024)
if search_config.get("doc_ids", []):
doc_ids = search_config.get("doc_ids", [])
if search_config.get("rerank_id", ""):
rerank_id = search_config.get("rerank_id", "")
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=chat_id)
if rerank_id:
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
question = req["question"]
ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12, 0.3, 0.3, aggs=False, rank_feature=label_question(question, [kb]))
ranks = settings.retrievaler.retrieval(
question=question,
embd_mdl=embd_mdl,
tenant_ids=tenant_id,
kb_ids=kb_ids,
page=1,
page_size=12,
similarity_threshold=similarity_threshold,
vector_similarity_weight=vector_similarity_weight,
top=top,
doc_ids=doc_ids,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(question, [kb]),
)
mindmap = MindMapExtractor(chat_mdl)
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
mind_map = mind_map.output

View File

@ -872,7 +872,7 @@ class Search(DataBaseModel):
default={
"kb_ids": [],
"doc_ids": [],
"similarity_threshold": 0.0,
"similarity_threshold": 0.2,
"vector_similarity_weight": 0.3,
"use_kg": False,
# rerank settings

View File

@ -687,7 +687,30 @@ def tts(tts_mdl, text):
return binascii.hexlify(bin).decode("utf-8")
def ask(question, kb_ids, tenant_id, chat_llm_name=None):
def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}):
# Retrieval defaults for ask(); overridden by search_config below.
# NOTE: these must be plain scalars — the originals carried stray trailing
# commas (e.g. `similarity_threshold = 0.1,`) that turned them into 1-tuples,
# which were then passed verbatim to retriever.retrieval() whenever
# search_config was empty.
similarity_threshold = 0.1
vector_similarity_weight = 0.3
top = 1024
doc_ids = []
rerank_id = ""
rerank_mdl = None
# Apply per-search-app overrides. kb_ids and chat_llm_name (both function
# parameters) can also be replaced here by the config's values.
# NOTE(review): each guard re-reads the key and skips falsy stored values;
# when the key is absent, the guard's own get() default (0.1 / 0.3 / 1024)
# is truthy and silently replaces the preceding default — confirm intended.
if search_config:
if search_config.get("kb_ids", []):
kb_ids = search_config.get("kb_ids", [])
# "chat_id" from the search app selects the chat model name.
if search_config.get("chat_id", ""):
chat_llm_name = search_config.get("chat_id", "")
if search_config.get("similarity_threshold", 0.1):
similarity_threshold = search_config.get("similarity_threshold", 0.1)
if search_config.get("vector_similarity_weight", 0.3):
vector_similarity_weight = search_config.get("vector_similarity_weight", 0.3)
# "top_k" in the config maps to the local `top` retrieval cap.
if search_config.get("top_k", 1024):
top = search_config.get("top_k", 1024)
if search_config.get("doc_ids", []):
doc_ids = search_config.get("doc_ids", [])
if search_config.get("rerank_id", ""):
rerank_id = search_config.get("rerank_id", "")
kbs = KnowledgebaseService.get_by_ids(kb_ids)
embedding_list = list(set([kb.embd_id for kb in kbs]))
@ -696,9 +719,26 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None):
embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
if rerank_id:
rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
max_tokens = chat_mdl.max_length
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
kbinfos = retriever.retrieval(
question = question,
embd_mdl=embd_mdl,
tenant_ids=tenant_ids,
kb_ids=kb_ids,
page=1,
page_size=12,
similarity_threshold=similarity_threshold,
vector_similarity_weight=vector_similarity_weight,
top=top,
doc_ids=doc_ids,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(question, kbs)
)
knowledges = kb_prompt(kbinfos, max_tokens)
prompt = """
Role: You're a smart assistant. Your name is Miss R.