mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: init dataflow. (#9791)
### What problem does this PR solve?

#9790

Close #9782

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
239
agent/canvas.py
239
agent/canvas.py
@ -29,83 +29,52 @@ from api.utils import get_uuid, hash_str2int
|
||||
from rag.prompts.prompts import chunks_format
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
class Canvas:
|
||||
class Graph:
|
||||
"""
|
||||
dsl = {
|
||||
"components": {
|
||||
"begin": {
|
||||
"obj":{
|
||||
"component_name": "Begin",
|
||||
"params": {},
|
||||
},
|
||||
"downstream": ["answer_0"],
|
||||
"upstream": [],
|
||||
},
|
||||
"retrieval_0": {
|
||||
"obj": {
|
||||
"component_name": "Retrieval",
|
||||
"params": {}
|
||||
},
|
||||
"downstream": ["generate_0"],
|
||||
"upstream": ["answer_0"],
|
||||
},
|
||||
"generate_0": {
|
||||
"obj": {
|
||||
"component_name": "Generate",
|
||||
"params": {}
|
||||
},
|
||||
"downstream": ["answer_0"],
|
||||
"upstream": ["retrieval_0"],
|
||||
}
|
||||
},
|
||||
"history": [],
|
||||
"path": ["begin"],
|
||||
"retrieval": {"chunks": [], "doc_aggs": []},
|
||||
"globals": {
|
||||
"sys.query": "",
|
||||
"sys.user_id": tenant_id,
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, dsl: str, tenant_id=None, task_id=None):
|
||||
self.path = []
|
||||
self.history = []
|
||||
self.components = {}
|
||||
self.error = ""
|
||||
self.globals = {
|
||||
"sys.query": "",
|
||||
"sys.user_id": tenant_id,
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
self.dsl = json.loads(dsl) if dsl else {
|
||||
dsl = {
|
||||
"components": {
|
||||
"begin": {
|
||||
"obj": {
|
||||
"obj":{
|
||||
"component_name": "Begin",
|
||||
"params": {
|
||||
"prologue": "Hi there!"
|
||||
}
|
||||
"params": {},
|
||||
},
|
||||
"downstream": [],
|
||||
"downstream": ["answer_0"],
|
||||
"upstream": [],
|
||||
"parent_id": ""
|
||||
},
|
||||
"retrieval_0": {
|
||||
"obj": {
|
||||
"component_name": "Retrieval",
|
||||
"params": {}
|
||||
},
|
||||
"downstream": ["generate_0"],
|
||||
"upstream": ["answer_0"],
|
||||
},
|
||||
"generate_0": {
|
||||
"obj": {
|
||||
"component_name": "Generate",
|
||||
"params": {}
|
||||
},
|
||||
"downstream": ["answer_0"],
|
||||
"upstream": ["retrieval_0"],
|
||||
}
|
||||
},
|
||||
"history": [],
|
||||
"path": [],
|
||||
"retrieval": [],
|
||||
"path": ["begin"],
|
||||
"retrieval": {"chunks": [], "doc_aggs": []},
|
||||
"globals": {
|
||||
"sys.query": "",
|
||||
"sys.user_id": "",
|
||||
"sys.user_id": tenant_id,
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, dsl: str, tenant_id=None, task_id=None):
|
||||
self.path = []
|
||||
self.components = {}
|
||||
self.error = ""
|
||||
self.dsl = json.loads(dsl)
|
||||
self._tenant_id = tenant_id
|
||||
self.task_id = task_id if task_id else get_uuid()
|
||||
self.load()
|
||||
@ -116,8 +85,6 @@ class Canvas:
|
||||
for k, cpn in self.components.items():
|
||||
cpn_nms.add(cpn["obj"]["component_name"])
|
||||
|
||||
assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
|
||||
|
||||
for k, cpn in self.components.items():
|
||||
cpn_nms.add(cpn["obj"]["component_name"])
|
||||
param = component_class(cpn["obj"]["component_name"] + "Param")()
|
||||
@ -130,27 +97,10 @@ class Canvas:
|
||||
cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
|
||||
|
||||
self.path = self.dsl["path"]
|
||||
self.history = self.dsl["history"]
|
||||
if "globals" in self.dsl:
|
||||
self.globals = self.dsl["globals"]
|
||||
else:
|
||||
self.globals = {
|
||||
"sys.query": "",
|
||||
"sys.user_id": "",
|
||||
"sys.conversation_turns": 0,
|
||||
"sys.files": []
|
||||
}
|
||||
|
||||
self.retrieval = self.dsl["retrieval"]
|
||||
self.memory = self.dsl.get("memory", [])
|
||||
|
||||
def __str__(self):
|
||||
self.dsl["path"] = self.path
|
||||
self.dsl["history"] = self.history
|
||||
self.dsl["globals"] = self.globals
|
||||
self.dsl["task_id"] = self.task_id
|
||||
self.dsl["retrieval"] = self.retrieval
|
||||
self.dsl["memory"] = self.memory
|
||||
dsl = {
|
||||
"components": {}
|
||||
}
|
||||
@ -169,14 +119,79 @@ class Canvas:
|
||||
dsl["components"][k][c] = deepcopy(cpn[c])
|
||||
return json.dumps(dsl, ensure_ascii=False)
|
||||
|
||||
def reset(self, mem=False):
|
||||
def reset(self):
    """Clear the execution path, reset every component, and drop the task logs.

    Deleting the ``"<task_id>-logs"`` key through ``REDIS_CONN`` is best
    effort: any exception raised there is logged and not propagated.
    """
    self.path = []
    # Only the component records are needed, so iterate the values directly.
    for cpn in self.components.values():
        cpn["obj"].reset()
    try:
        REDIS_CONN.delete(f"{self.task_id}-logs")
    except Exception as e:
        logging.exception(e)
|
||||
|
||||
def get_component_name(self, cid):
    """Return the display name of the graph node whose ``id`` equals ``cid``.

    Nodes are read from ``self.dsl["graph"]["nodes"]``. An empty string is
    returned when the DSL has no graph section or no node matches.
    """
    nodes = self.dsl.get("graph", {}).get("nodes", [])
    # First match wins, same as the original early-return loop.
    return next((node["data"]["name"] for node in nodes if cid == node["id"]), "")
|
||||
|
||||
def run(self, **kwargs):
    """Execute the graph.

    The base implementation only raises; a subclass is expected to
    override it.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()
|
||||
|
||||
def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
    """Return the component record stored under ``cpn_id``, or ``None`` if absent."""
    return self.components.get(cpn_id)
|
||||
|
||||
def get_component_obj(self, cpn_id) -> ComponentBase:
    """Return the ``"obj"`` entry of the component record stored under ``cpn_id``.

    An unknown ``cpn_id`` makes ``dict.get`` return ``None``, so the
    subscript then raises ``TypeError``.
    """
    return self.components.get(cpn_id)["obj"]
|
||||
|
||||
def get_component_type(self, cpn_id) -> str:
    """Return the ``component_name`` attribute of the component object for ``cpn_id``.

    An unknown ``cpn_id`` raises ``TypeError`` because the lookup yields
    ``None`` before it is subscripted.
    """
    return self.components.get(cpn_id)["obj"].component_name
|
||||
|
||||
def get_component_input_form(self, cpn_id) -> dict:
    """Return whatever ``get_input_form()`` yields for the component object under ``cpn_id``.

    An unknown ``cpn_id`` raises ``TypeError`` because the lookup yields
    ``None`` before it is subscripted.
    """
    return self.components.get(cpn_id)["obj"].get_input_form()
|
||||
|
||||
def get_tenant_id(self):
    """Return the tenant id stored on the instance as ``_tenant_id``."""
    return self._tenant_id
|
||||
|
||||
|
||||
class Canvas(Graph):
|
||||
|
||||
def __init__(self, dsl: str, tenant_id=None, task_id=None):
    """Seed the default globals, then delegate to the parent initializer.

    Args:
        dsl: DSL document as a string, forwarded unchanged to the parent.
        tenant_id: stored as the default ``"sys.user_id"`` global and
            forwarded to the parent.
        task_id: forwarded unchanged to the parent.
    """
    # Defaults are assigned before the parent initializer runs.
    # NOTE(review): presumably because the parent initializer triggers
    # load(), which may replace them — confirm.
    self.globals = {
        "sys.query": "",
        "sys.user_id": tenant_id,
        "sys.conversation_turns": 0,
        "sys.files": []
    }
    super().__init__(dsl, tenant_id, task_id)
|
||||
|
||||
def load(self):
    """Run the parent load, then pull conversation state out of ``self.dsl``.

    ``"history"`` and ``"retrieval"`` are read with plain subscripts, so a
    DSL lacking either raises ``KeyError``. ``"memory"`` is optional and
    defaults to an empty list.
    """
    super().load()
    self.history = self.dsl["history"]
    if "globals" in self.dsl:
        self.globals = self.dsl["globals"]
    else:
        # NOTE(review): this fallback sets "sys.user_id" to "" while
        # __init__ seeds it with tenant_id — confirm which is intended.
        self.globals = {
            "sys.query": "",
            "sys.user_id": "",
            "sys.conversation_turns": 0,
            "sys.files": []
        }

    self.retrieval = self.dsl["retrieval"]
    self.memory = self.dsl.get("memory", [])
|
||||
|
||||
def __str__(self):
    """Copy history, retrieval and memory back into ``self.dsl``, then return the parent's string form."""
    self.dsl["history"] = self.history
    self.dsl["retrieval"] = self.retrieval
    self.dsl["memory"] = self.memory
    return super().__str__()
|
||||
|
||||
def reset(self, mem=False):
|
||||
super().reset()
|
||||
if not mem:
|
||||
self.history = []
|
||||
self.retrieval = []
|
||||
self.memory = []
|
||||
for k, cpn in self.components.items():
|
||||
self.components[k]["obj"].reset()
|
||||
|
||||
for k in self.globals.keys():
|
||||
if isinstance(self.globals[k], str):
|
||||
@ -192,17 +207,6 @@ class Canvas:
|
||||
else:
|
||||
self.globals[k] = None
|
||||
|
||||
try:
|
||||
REDIS_CONN.delete(f"{self.task_id}-logs")
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
def get_component_name(self, cid):
    """Look up the name of the graph node identified by ``cid``.

    Searches ``self.dsl["graph"]["nodes"]`` and returns the first matching
    node's ``data.name``. Returns "" when the graph section is missing or
    no node has that id.
    """
    for node in self.dsl.get("graph", {}).get("nodes", []):
        if cid == node["id"]:
            return node["data"]["name"]
    return ""
|
||||
|
||||
def run(self, **kwargs):
|
||||
st = time.perf_counter()
|
||||
self.message_id = get_uuid()
|
||||
@ -388,18 +392,6 @@ class Canvas:
|
||||
})
|
||||
self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
|
||||
|
||||
def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
    """Fetch the component record for ``cpn_id``; ``None`` when no such component exists."""
    return self.components.get(cpn_id)
|
||||
|
||||
def get_component_obj(self, cpn_id) -> ComponentBase:
    """Fetch the object stored under ``"obj"`` in the record for ``cpn_id``.

    Raises ``TypeError`` for an unknown id, since the lookup result is
    ``None`` at the point it is subscripted.
    """
    return self.components.get(cpn_id)["obj"]
|
||||
|
||||
def get_component_type(self, cpn_id) -> str:
    """Fetch ``component_name`` from the component object registered as ``cpn_id``.

    Raises ``TypeError`` for an unknown id, since the lookup result is
    ``None`` at the point it is subscripted.
    """
    return self.components.get(cpn_id)["obj"].component_name
|
||||
|
||||
def get_component_input_form(self, cpn_id) -> dict:
    """Fetch the result of ``get_input_form()`` on the component object registered as ``cpn_id``.

    Raises ``TypeError`` for an unknown id, since the lookup result is
    ``None`` at the point it is subscripted.
    """
    return self.components.get(cpn_id)["obj"].get_input_form()
|
||||
|
||||
def is_reff(self, exp: str) -> bool:
|
||||
exp = exp.strip("{").strip("}")
|
||||
if exp.find("@") < 0:
|
||||
@ -421,9 +413,6 @@ class Canvas:
|
||||
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
|
||||
return cpn["obj"].output(var_nm)
|
||||
|
||||
def get_tenant_id(self):
    """Expose the value held in the private ``_tenant_id`` attribute."""
    return self._tenant_id
|
||||
|
||||
def get_history(self, window_size):
|
||||
convs = []
|
||||
if window_size <= 0:
|
||||
@ -438,36 +427,6 @@ class Canvas:
|
||||
def add_user_input(self, question):
    """Append ``question`` to ``self.history`` as a ``("user", question)`` tuple."""
    self.history.append(("user", question))
|
||||
|
||||
def _find_loop(self, max_loops=6):
    """Detect a repeating component pattern at the tail of the latest path.

    The last entry of ``self.path`` is reversed so the most recent
    component id comes first, then truncated at the first id that starts
    with "answer" or "iterationitem" (case-insensitive). For each candidate
    pattern length, the comma-joined path is checked for the pattern
    repeating back to back more than ``max_loops`` times.

    Args:
        max_loops: number of consecutive repetitions tolerated.

    Returns:
        ``False`` when no loop is found (including when ``self.path`` is
        empty); otherwise a string showing the pattern twice, joined with
        " => ", using the part of each id before the first ":".
    """
    # An empty path has no last entry to inspect; indexing it would raise.
    if not self.path:
        return False

    path = self.path[-1][::-1]
    if len(path) < 2:
        return False

    # Keep only the ids seen since the most recent answer / iteration item.
    for i, cpn_id in enumerate(path):
        if cpn_id.lower().startswith(("answer", "iterationitem")):
            path = path[:i]
            break

    if len(path) < 2:
        return False

    joined = ",".join(path)
    for loc in range(2, len(path) // 2):
        pat = ",".join(path[:loc])
        path_str = joined
        if len(pat) >= len(path_str):
            return False
        loop = max_loops
        # Strip the pattern (plus its trailing comma) while it keeps
        # leading the remaining string.
        while path_str.startswith(pat) and loop >= 0:
            loop -= 1
            if len(pat) + 1 >= len(path_str):
                return False
            path_str = path_str[len(pat) + 1:]
        if loop < 0:
            pat = " => ".join(p.split(":")[0] for p in path[:loc])
            return pat + " => " + pat

    return False
|
||||
|
||||
def get_prologue(self):
    """Return the ``prologue`` parameter of the component registered as ``"begin"``.

    Raises ``KeyError`` when no ``"begin"`` component is present.
    """
    return self.components["begin"]["obj"]._param.prologue
|
||||
|
||||
|
||||
Reference in New Issue
Block a user