mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: Support passing knowledge base id as variable in retrieval component (#7088)
### What problem does this PR solve? Fix #6600 Hello, I have the same business requirement as #6600. My use case is: We have many departments (> 20 now and increasing), and each department has its own knowledge base. Because the agent workflow is the same, so I want to change the knowledge base on the fly, instead of creating agents for every department. It now looks like this:  Knowledge bases can be selected from the dropdown, and passed through the variables in the table. All selected knowledge bases are used for retrieval. ### Type of change - [ ] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) - [ ] Documentation Update - [ ] Refactoring - [ ] Performance Improvement - [ ] Other (please describe):
This commit is contained in:
@ -41,6 +41,7 @@ class RetrievalParam(ComponentParamBase):
|
||||
self.top_n = 8
|
||||
self.top_k = 1024
|
||||
self.kb_ids = []
|
||||
self.kb_vars = []
|
||||
self.rerank_id = ""
|
||||
self.empty_response = ""
|
||||
self.tavily_api_key = ""
|
||||
@ -58,7 +59,22 @@ class Retrieval(ComponentBase, ABC):
|
||||
def _run(self, history, **kwargs):
|
||||
query = self.get_input()
|
||||
query = str(query["content"][0]) if "content" in query else ""
|
||||
kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids)
|
||||
|
||||
kb_ids: list[str] = self._param.kb_ids or []
|
||||
|
||||
kb_vars = self._fetch_outputs_from(self._param.kb_vars)
|
||||
|
||||
if len(kb_vars) > 0:
|
||||
for kb_var in kb_vars:
|
||||
if len(kb_var) == 1:
|
||||
kb_ids.append(str(kb_var["content"][0]))
|
||||
else:
|
||||
for v in kb_var.to_dict("records"):
|
||||
kb_ids.append(v["content"])
|
||||
|
||||
filtered_kb_ids: list[str] = [kb_id for kb_id in kb_ids if kb_id]
|
||||
|
||||
kbs = KnowledgebaseService.get_by_ids(filtered_kb_ids)
|
||||
if not kbs:
|
||||
return Retrieval.be_output("")
|
||||
|
||||
@ -75,7 +91,7 @@ class Retrieval(ComponentBase, ABC):
|
||||
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
|
||||
|
||||
if kbs:
|
||||
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
|
||||
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, filtered_kb_ids,
|
||||
1, self._param.top_n,
|
||||
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
|
||||
aggs=False, rerank_mdl=rerank_mdl,
|
||||
@ -86,7 +102,7 @@ class Retrieval(ComponentBase, ABC):
|
||||
if self._param.use_kg and kbs:
|
||||
ck = settings.kg_retrievaler.retrieval(query,
|
||||
[kbs[0].tenant_id],
|
||||
self._param.kb_ids,
|
||||
filtered_kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
|
||||
Reference in New Issue
Block a user