Fix: KG search issue. (#12364)

### What problem does this PR solve?

Close #12347

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu
2025-12-31 14:40:27 +08:00
committed by GitHub
parent 675d18d359
commit 461c81e14a
7 changed files with 16 additions and 15 deletions

View File

@ -202,7 +202,7 @@ class Retrieval(ToolBase, ABC):
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
[kb.tenant_id for kb in kbs])
if self._param.use_kg:
ck = settings.kg_retriever.retrieval(query,
ck = await settings.kg_retriever.retrieval(query,
[kb.tenant_id for kb in kbs],
kb_ids,
embd_mdl,
@ -215,7 +215,7 @@ class Retrieval(ToolBase, ABC):
kbinfos = {"chunks": [], "doc_aggs": []}
if self._param.use_kg and kbs:
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
ck = await settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return

View File

@ -381,7 +381,7 @@ async def retrieval_test():
rank_feature=labels
)
if use_kg:
ck = settings.kg_retriever.retrieval(_question,
ck = await settings.kg_retriever.retrieval(_question,
tenant_ids,
kb_ids,
embd_mdl,

View File

@ -150,7 +150,7 @@ async def retrieval(tenant_id):
)
if use_kg:
ck = settings.kg_retriever.retrieval(question,
ck = await settings.kg_retriever.retrieval(question,
[tenant_id],
[kb_id],
embd_mdl,

View File

@ -1579,7 +1579,7 @@ async def retrieval_test(tenant_id):
if cks:
ranks["chunks"] = cks
if use_kg:
ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)

View File

@ -1116,7 +1116,7 @@ async def retrieval_test_embedded():
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
)
if use_kg:
ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
LLMBundle(kb.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ranks["chunks"].insert(0, ck)

View File

@ -421,7 +421,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if prompt_config.get("use_kg"):
ck = settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
LLMBundle(dialog.tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
from collections import defaultdict
@ -32,21 +33,21 @@ from common.doc_store.doc_store_base import OrderByExpr
class KGSearch(Dealer):
def _chat(self, llm_bdl, system, history, gen_conf):
async def _chat(self, llm_bdl, system, history, gen_conf):
response = get_llm_cache(llm_bdl.llm_name, system, history, gen_conf)
if response:
return response
response = llm_bdl.chat(system, history, gen_conf)
response = await llm_bdl.async_chat(system, history, gen_conf)
if response.find("**ERROR**") >= 0:
raise Exception(response)
set_llm_cache(llm_bdl.llm_name, system, response, history, gen_conf)
return response
def query_rewrite(self, llm, question, idxnms, kb_ids):
async def query_rewrite(self, llm, question, idxnms, kb_ids):
ty2ents = get_entity_type2samples(idxnms, kb_ids)
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})
result = await self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})
try:
keywords_data = json_repair.loads(result)
type_keywords = keywords_data.get("answer_type_keywords", [])
@ -138,7 +139,7 @@ class KGSearch(Dealer):
idxnms, kb_ids)
return self._ent_info_from_(es_res, 0)
def retrieval(self, question: str,
async def retrieval(self, question: str,
tenant_ids: str | list[str],
kb_ids: list[str],
emb_mdl,
@ -158,7 +159,7 @@ class KGSearch(Dealer):
idxnms = [index_name(tid) for tid in tenant_ids]
ty_kwds = []
try:
ty_kwds, ents = self.query_rewrite(llm, qst, [index_name(tid) for tid in tenant_ids], kb_ids)
ty_kwds, ents = await self.query_rewrite(llm, qst, [index_name(tid) for tid in tenant_ids], kb_ids)
logging.info(f"Q: {qst}, Types: {ty_kwds}, Entities: {ents}")
except Exception as e:
logging.exception(e)
@ -334,5 +335,5 @@ if __name__ == "__main__":
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
kg = KGSearch(settings.docStoreConn)
print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},
search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))
print(asyncio.run(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},
search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl)))