diff --git a/agent/component/retrieval.py b/agent/component/retrieval.py index 10717250d..5df3335ec 100644 --- a/agent/component/retrieval.py +++ b/agent/component/retrieval.py @@ -65,19 +65,25 @@ class Retrieval(ComponentBase, ABC): embd_nms = list(set([kb.embd_id for kb in kbs])) assert len(embd_nms) == 1, "Knowledge bases use different embedding models." - embd_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, embd_nms[0]) - self._canvas.set_embedding_model(embd_nms[0]) + embd_mdl = None + if embd_nms: + embd_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, embd_nms[0]) + self._canvas.set_embedding_model(embd_nms[0]) rerank_mdl = None if self._param.rerank_id: rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id) - kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids, + if kbs: + kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids, 1, self._param.top_n, self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight, aggs=False, rerank_mdl=rerank_mdl, rank_feature=label_question(query, kbs)) - if self._param.use_kg: + else: + kbinfos = {"chunks": [], "doc_aggs": []} + + if self._param.use_kg and kbs: ck = settings.kg_retrievaler.retrieval(query, [kbs[0].tenant_id], self._param.kb_ids,