diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 21df960be..f024c42fa 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -202,7 +202,7 @@ class Retrieval(ToolBase, ABC): kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs]) if self._param.use_kg: - ck = settings.kg_retriever.retrieval(query, + ck = await settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], kb_ids, embd_mdl, @@ -215,7 +215,7 @@ class Retrieval(ToolBase, ABC): kbinfos = {"chunks": [], "doc_aggs": []} if self._param.use_kg and kbs: - ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, + ck = await settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT)) if self.check_if_canceled("Retrieval processing"): return diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index f5b248fd5..1a7bed0c6 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -381,7 +381,7 @@ async def retrieval_test(): rank_feature=labels ) if use_kg: - ck = settings.kg_retriever.retrieval(_question, + ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index 7a11688dd..91f1c9a8f 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -150,7 +150,7 @@ async def retrieval(tenant_id): ) if use_kg: - ck = settings.kg_retriever.retrieval(question, + ck = await settings.kg_retriever.retrieval(question, [tenant_id], [kb_id], embd_mdl, diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index bef03d38e..db8a97b68 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -1579,7 +1579,7 @@ async def retrieval_test(tenant_id): if cks: ranks["chunks"] = cks if use_kg: - ck = settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT)) + ck = await settings.kg_retriever.retrieval(question, [k.tenant_id for k in kbs], kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index f9615e36b..e76560ccf 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -1116,7 +1116,7 @@ async def retrieval_test_embedded(): local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels ) if use_kg: - ck = settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, + ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 4bc24210b..33b50730f 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -421,7 +421,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): kbinfos["chunks"].extend(tav_res["chunks"]) kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) if prompt_config.get("use_kg"): - ck = settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, + ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: kbinfos["chunks"].insert(0, ck) diff --git a/graphrag/search.py b/graphrag/search.py index 7bb46b6b9..728588b87 100644 --- a/graphrag/search.py +++ b/graphrag/search.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import json import logging from collections import defaultdict @@ -32,21 +33,21 @@ from common.doc_store.doc_store_base import OrderByExpr class KGSearch(Dealer): - def _chat(self, llm_bdl, system, history, gen_conf): + async def _chat(self, llm_bdl, system, history, gen_conf): response = get_llm_cache(llm_bdl.llm_name, system, history, gen_conf) if response: return response - response = llm_bdl.chat(system, history, gen_conf) + response = await llm_bdl.async_chat(system, history, gen_conf) if response.find("**ERROR**") >= 0: raise Exception(response) set_llm_cache(llm_bdl.llm_name, system, response, history, gen_conf) return response - def query_rewrite(self, llm, question, idxnms, kb_ids): + async def query_rewrite(self, llm, question, idxnms, kb_ids): ty2ents = get_entity_type2samples(idxnms, kb_ids) hint_prompt = PROMPTS["minirag_query2kwd"].format(query=question, TYPE_POOL=json.dumps(ty2ents, ensure_ascii=False, indent=2)) - result = self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {}) + result = await self._chat(llm, hint_prompt, [{"role": "user", "content": "Output:"}], {}) try: keywords_data = json_repair.loads(result) type_keywords = keywords_data.get("answer_type_keywords", []) @@ -138,7 +139,7 @@ class KGSearch(Dealer): idxnms, kb_ids) return self._ent_info_from_(es_res, 0) - def retrieval(self, question: str, + async def retrieval(self, question: str, tenant_ids: str | list[str], kb_ids: list[str], emb_mdl, @@ -158,7 +159,7 @@ class KGSearch(Dealer): idxnms = [index_name(tid) for tid in tenant_ids] ty_kwds = [] try: - ty_kwds, ents = self.query_rewrite(llm, qst, [index_name(tid) for tid in tenant_ids], kb_ids) + ty_kwds, ents = await self.query_rewrite(llm, qst, [index_name(tid) for tid in tenant_ids], kb_ids) logging.info(f"Q: {qst}, Types: {ty_kwds}, Entities: {ents}") except Exception as e: logging.exception(e) @@ -334,5 +335,5 @@ if __name__ == "__main__": embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id) kg = KGSearch(settings.docStoreConn) - print(kg.retrieval({"question": args.question, "kb_ids": [kb_id]}, - search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl)) + print(asyncio.run(kg.retrieval({"question": args.question, "kb_ids": [kb_id]}, + search.index_name(kb.tenant_id), [kb_id], embed_bdl, llm_bdl)))