diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index f024c42fa..77a39b731 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -193,7 +193,7 @@ class Retrieval(ToolBase, ABC): if self._param.toc_enhance: chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT) - cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], + cks = await settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n) if self.check_if_canceled("Retrieval processing"): return diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index db8a97b68..d341cea55 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -1575,7 +1575,7 @@ async def retrieval_test(tenant_id): ) if toc_enhance: chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) - cks = settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size) + cks = await settings.retriever.retrieval_by_toc(question, ranks["chunks"], tenant_ids, chat_mdl, size) if cks: ranks["chunks"] = cks if use_kg: diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 33b50730f..83f1bb4fa 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -411,7 +411,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): rank_feature=label_question(" ".join(questions), kbs), ) if prompt_config.get("toc_enhance"): - cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n) + cks = await retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n) if cks: kbinfos["chunks"] = cks kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids) diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 01f55c9ef..b10dc8572 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import asyncio import json import logging import re @@ -589,7 +588,7 @@ class Dealer: key=lambda x: x[1] * -1)[:topn_tags] return {a.replace(".", "_"): max(1, c) for a, c in tag_fea} - def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6): + async def retrieval_by_toc(self, query: str, chunks: list[dict], tenant_ids: list[str], chat_mdl, topn: int = 6): if not chunks: return [] idx_nms = [index_name(tid) for tid in tenant_ids] @@ -614,7 +613,7 @@ class Dealer: if not toc: return chunks - ids = asyncio.run(relevant_chunks_with_toc(query, toc, chat_mdl, topn * 2)) + ids = await relevant_chunks_with_toc(query, toc, chat_mdl, topn * 2) if not ids: return chunks