mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: Support passing knowledge base id as variable in retrieval component (#7088)
### What problem does this PR solve? Fix #6600 Hello, I have the same business requirement as #6600. My use case is: We have many departments (> 20 now and increasing), and each department has its own knowledge base. Because the agent workflow is the same, so I want to change the knowledge base on the fly, instead of creating agents for every department. It now looks like this:  Knowledge bases can be selected from the dropdown, and passed through the variables in the table. All selected knowledge bases are used for retrieval. ### Type of change - [ ] Bug Fix (non-breaking change which fixes an issue) - [x] New Feature (non-breaking change which adds functionality) - [ ] Documentation Update - [ ] Refactoring - [ ] Performance Improvement - [ ] Other (please describe):
This commit is contained in:
@ -19,7 +19,7 @@ import json
|
||||
import os
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Tuple, Union
|
||||
from typing import Any, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
@ -462,6 +462,33 @@ class ComponentBase(ABC):
|
||||
def set_output(self, v):
|
||||
setattr(self._param, self._param.output_var_name, v)
|
||||
|
||||
def _fetch_outputs_from(self, sources: list[dict[str, Any]]) -> list[pd.DataFrame]:
|
||||
outs = []
|
||||
for q in sources:
|
||||
if q.get("component_id"):
|
||||
if q["component_id"].split("@")[0].lower().find("begin") >= 0:
|
||||
cpn_id, key = q["component_id"].split("@")
|
||||
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
|
||||
if p["key"] == key:
|
||||
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
|
||||
break
|
||||
else:
|
||||
assert False, f"Can't find parameter '{key}' for {cpn_id}"
|
||||
continue
|
||||
|
||||
if q["component_id"].lower().find("answer") == 0:
|
||||
txt = []
|
||||
for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]:
|
||||
txt.append(f"{r.upper()}:{c}")
|
||||
txt = "\n".join(txt)
|
||||
outs.append(pd.DataFrame([{"content": txt}]))
|
||||
continue
|
||||
|
||||
outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
|
||||
elif q.get("value"):
|
||||
outs.append(pd.DataFrame([{"content": q["value"]}]))
|
||||
return outs
|
||||
|
||||
def get_input(self):
|
||||
if self._param.debug_inputs:
|
||||
return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs if v.get("value")])
|
||||
@ -475,37 +502,24 @@ class ComponentBase(ABC):
|
||||
|
||||
if self._param.query:
|
||||
self._param.inputs = []
|
||||
outs = []
|
||||
for q in self._param.query:
|
||||
if q.get("component_id"):
|
||||
if q["component_id"].split("@")[0].lower().find("begin") >= 0:
|
||||
cpn_id, key = q["component_id"].split("@")
|
||||
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
|
||||
if p["key"] == key:
|
||||
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
|
||||
self._param.inputs.append({"component_id": q["component_id"],
|
||||
"content": p.get("value", "")})
|
||||
break
|
||||
else:
|
||||
assert False, f"Can't find parameter '{key}' for {cpn_id}"
|
||||
continue
|
||||
outs = self._fetch_outputs_from(self._param.query)
|
||||
|
||||
if q["component_id"].lower().find("answer") == 0:
|
||||
txt = []
|
||||
for r, c in self._canvas.history[::-1][:self._param.message_history_window_size][::-1]:
|
||||
txt.append(f"{r.upper()}:{c}")
|
||||
txt = "\n".join(txt)
|
||||
self._param.inputs.append({"content": txt, "component_id": q["component_id"]})
|
||||
outs.append(pd.DataFrame([{"content": txt}]))
|
||||
continue
|
||||
for out in outs:
|
||||
records = out.to_dict("records")
|
||||
content: str
|
||||
|
||||
if len(records) > 1:
|
||||
content = "\n".join(
|
||||
[str(d["content"]) for d in records]
|
||||
)
|
||||
else:
|
||||
content = records[0]["content"]
|
||||
|
||||
self._param.inputs.append({
|
||||
"component_id": records[0].get("component_id"),
|
||||
"content": content
|
||||
})
|
||||
|
||||
outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
|
||||
self._param.inputs.append({"component_id": q["component_id"],
|
||||
"content": "\n".join(
|
||||
[str(d["content"]) for d in outs[-1].to_dict('records')])})
|
||||
elif q.get("value"):
|
||||
self._param.inputs.append({"component_id": None, "content": q["value"]})
|
||||
outs.append(pd.DataFrame([{"content": q["value"]}]))
|
||||
if outs:
|
||||
df = pd.concat(outs, ignore_index=True)
|
||||
if "content" in df:
|
||||
|
||||
@ -41,6 +41,7 @@ class RetrievalParam(ComponentParamBase):
|
||||
self.top_n = 8
|
||||
self.top_k = 1024
|
||||
self.kb_ids = []
|
||||
self.kb_vars = []
|
||||
self.rerank_id = ""
|
||||
self.empty_response = ""
|
||||
self.tavily_api_key = ""
|
||||
@ -58,7 +59,22 @@ class Retrieval(ComponentBase, ABC):
|
||||
def _run(self, history, **kwargs):
|
||||
query = self.get_input()
|
||||
query = str(query["content"][0]) if "content" in query else ""
|
||||
kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids)
|
||||
|
||||
kb_ids: list[str] = self._param.kb_ids or []
|
||||
|
||||
kb_vars = self._fetch_outputs_from(self._param.kb_vars)
|
||||
|
||||
if len(kb_vars) > 0:
|
||||
for kb_var in kb_vars:
|
||||
if len(kb_var) == 1:
|
||||
kb_ids.append(str(kb_var["content"][0]))
|
||||
else:
|
||||
for v in kb_var.to_dict("records"):
|
||||
kb_ids.append(v["content"])
|
||||
|
||||
filtered_kb_ids: list[str] = [kb_id for kb_id in kb_ids if kb_id]
|
||||
|
||||
kbs = KnowledgebaseService.get_by_ids(filtered_kb_ids)
|
||||
if not kbs:
|
||||
return Retrieval.be_output("")
|
||||
|
||||
@ -75,7 +91,7 @@ class Retrieval(ComponentBase, ABC):
|
||||
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
|
||||
|
||||
if kbs:
|
||||
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, self._param.kb_ids,
|
||||
kbinfos = settings.retrievaler.retrieval(query, embd_mdl, kbs[0].tenant_id, filtered_kb_ids,
|
||||
1, self._param.top_n,
|
||||
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
|
||||
aggs=False, rerank_mdl=rerank_mdl,
|
||||
@ -86,7 +102,7 @@ class Retrieval(ComponentBase, ABC):
|
||||
if self._param.use_kg and kbs:
|
||||
ck = settings.kg_retrievaler.retrieval(query,
|
||||
[kbs[0].tenant_id],
|
||||
self._param.kb_ids,
|
||||
filtered_kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
|
||||
Reference in New Issue
Block a user