mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-01-01 09:39:57 +08:00
Fix: KG search issue. (#12364)
### What problem does this PR solve? Close #12347 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -202,7 +202,7 @@ class Retrieval(ToolBase, ABC):
|
|||||||
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
|
kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
|
||||||
[kb.tenant_id for kb in kbs])
|
[kb.tenant_id for kb in kbs])
|
||||||
if self._param.use_kg:
|
if self._param.use_kg:
|
||||||
ck = settings.kg_retriever.retrieval(query,
|
ck = await settings.kg_retriever.retrieval(query,
|
||||||
[kb.tenant_id for kb in kbs],
|
[kb.tenant_id for kb in kbs],
|
||||||
kb_ids,
|
kb_ids,
|
||||||
embd_mdl,
|
embd_mdl,
|
||||||
@ -215,7 +215,7 @@ class Retrieval(ToolBase, ABC):
|
|||||||
kbinfos = {"chunks": [], "doc_aggs": []}
|
kbinfos = {"chunks": [], "doc_aggs": []}
|
||||||
|
|
||||||
if self._param.use_kg and kbs:
|
if self._param.use_kg and kbs:
|
||||||
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
|
ck = await settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
|
||||||
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
||||||
if self.check_if_canceled("Retrieval processing"):
|
if self.check_if_canceled("Retrieval processing"):
|
||||||
return
|
return
|
||||||
|
|||||||
@ -381,7 +381,7 @@ async def retrieval_test():
|
|||||||
rank_feature=labels
|
rank_feature=labels
|
||||||
)
|
)
|
||||||
if use_kg:
|
if use_kg:
|
||||||
ck = settings.kg_retriever.retrieval(_question,
|
ck = await settings.kg_retriever.retrieval(_question,
|
||||||
tenant_ids,
|
tenant_ids,
|
||||||
kb_ids,
|
kb_ids,
|
||||||
embd_mdl,
|
embd_mdl,
|
||||||
|
|||||||
@ -150,7 +150,7 @@ async def retrieval(tenant_id):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if use_kg:
|
if use_kg:
|
||||||
ck = settings.kg_retriever.retrieval(question,
|
ck = await settings.kg_retriever.retrieval(question,
|
||||||
[tenant_id],
|
[tenant_id],
|
||||||
[kb_id],
|
[kb_id],
|
||||||
embd_mdl,
|
embd_mdl,
|
||||||
|
|||||||
@ -1579,7 +1579,7 @@ async def retrieval_test(tenant_id):
|
|||||||
if cks:
|
if cks:
|
||||||
ranks["chunks"] = cks
|
ranks["chunks"] = cks
|
||||||
if use_kg:
|
if use_kg:
|
||||||
ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
|
ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||||
if ck["content_with_weight"]:
|
if ck["content_with_weight"]:
|
||||||
ranks["chunks"].insert(0, ck)
|
ranks["chunks"].insert(0, ck)
|
||||||
|
|
||||||
|
|||||||
@ -1116,7 +1116,7 @@ async def retrieval_test_embedded():
|
|||||||
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||||
)
|
)
|
||||||
if use_kg:
|
if use_kg:
|
||||||
ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
|
ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
|
||||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||||
if ck["content_with_weight"]:
|
if ck["content_with_weight"]:
|
||||||
ranks["chunks"].insert(0, ck)
|
ranks["chunks"].insert(0, ck)
|
||||||
|
|||||||
@ -421,7 +421,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
|||||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||||
if prompt_config.get("use_kg"):
|
if prompt_config.get("use_kg"):
|
||||||
ck = settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
|
ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl,
|
||||||
LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
||||||
if ck["content_with_weight"]:
|
if ck["content_with_weight"]:
|
||||||
kbinfos["chunks"].insert(0, ck)
|
kbinfos["chunks"].insert(0, ck)
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
@ -32,21 +33,21 @@ from common.doc_store.doc_store_base import OrderByExpr
|
|||||||
|
|
||||||
|
|
||||||
class KGSearch(Dealer):
|
class KGSearch(Dealer):
|
||||||
def _chat(self, llm_bdl, system, history, gen_conf):
|
async def _chat(self, llm_bdl, system, history, gen_conf):
|
||||||
response = get_llm_cache(llm_bdl.llm_name, system, history, gen_conf)
|
response = get_llm_cache(llm_bdl.llm_name, system, history, gen_conf)
|
||||||
if response:
|
if response:
|
||||||
return response
|
return response
|
||||||
response = llm_bdl.chat(system, history, gen_conf)
|
response = await llm_bdl.async_chat(system, history, gen_conf)
|
||||||
if response.find("**ERROR**") >= 0:
|
if response.find("**ERROR**") >= 0:
|
||||||
raise Exception(response)
|
raise Exception(response)
|
||||||
set_llm_cache(llm_bdl.llm_name, system, response, history, gen_conf)
|
set_llm_cache(llm_bdl.llm_name, system, response, history, gen_conf)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def query_rewrite(self, llm, question, idxnms, kb_ids):
|
async def query_rewrite(self, llm, question, idxnms, kb_ids):
|
||||||
ty2ents = get_entity_type2samples(idxnms, kb_ids)
|
ty2ents = get_entity_type2samples(idxnms, kb_ids)
|
||||||
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
|
hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question,
|
||||||
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
|
TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2))
|
||||||
result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})
|
result = await self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {})
|
||||||
try:
|
try:
|
||||||
keywords_data = json_repair.loads(result)
|
keywords_data = json_repair.loads(result)
|
||||||
type_keywords = keywords_data.get("answer_type_keywords", [])
|
type_keywords = keywords_data.get("answer_type_keywords", [])
|
||||||
@ -138,7 +139,7 @@ class KGSearch(Dealer):
|
|||||||
idxnms, kb_ids)
|
idxnms, kb_ids)
|
||||||
return self._ent_info_from_(es_res, 0)
|
return self._ent_info_from_(es_res, 0)
|
||||||
|
|
||||||
def retrieval(self, question: str,
|
async def retrieval(self, question: str,
|
||||||
tenant_ids: str | list[str],
|
tenant_ids: str | list[str],
|
||||||
kb_ids: list[str],
|
kb_ids: list[str],
|
||||||
emb_mdl,
|
emb_mdl,
|
||||||
@ -158,7 +159,7 @@ class KGSearch(Dealer):
|
|||||||
idxnms = [index_name(tid) for tid in tenant_ids]
|
idxnms = [index_name(tid) for tid in tenant_ids]
|
||||||
ty_kwds = []
|
ty_kwds = []
|
||||||
try:
|
try:
|
||||||
ty_kwds, ents = self.query_rewrite(llm, qst, [index_name(tid) for tid in tenant_ids], kb_ids)
|
ty_kwds, ents = await self.query_rewrite(llm, qst, [index_name(tid) for tid in tenant_ids], kb_ids)
|
||||||
logging.info(f"Q: {qst}, Types: {ty_kwds}, Entities: {ents}")
|
logging.info(f"Q: {qst}, Types: {ty_kwds}, Entities: {ents}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(e)
|
logging.exception(e)
|
||||||
@ -334,5 +335,5 @@ if __name__ == "__main__":
|
|||||||
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
||||||
|
|
||||||
kg = KGSearch(settings.docStoreConn)
|
kg = KGSearch(settings.docStoreConn)
|
||||||
print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},
|
print(asyncio.run(kg.retrieval({"question": args.question, "kb_ids": [kb_id]},
|
||||||
search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl))
|
search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl)))
|
||||||
|
|||||||
Reference in New Issue
Block a user