Feat: Redesign and refactor agent module (#9113)

### What problem does this PR solve?

#9082 #6365

<u> **WARNING: it's not compatible with the older version of `Agent`
module, which means that `Agent` from older versions can not work
anymore.**</u>

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Kevin Hu
2025-07-30 19:41:09 +08:00
committed by GitHub
parent 07e37560fc
commit d9fe279dde
124 changed files with 7744 additions and 18226 deletions

View File

@ -13,14 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import base64
import json
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
import pandas as pd
from typing import Any, Union, Tuple
from agent.component import component_class
from agent.component.base import ComponentBase
from api.db.services.file_service import FileService
from api.utils import get_uuid, hash_str2int
from rag.prompts.prompts import chunks_format
from rag.utils.redis_conn import REDIS_CONN
class Canvas:
@ -35,14 +42,6 @@ class Canvas:
"downstream": ["answer_0"],
"upstream": [],
},
"answer_0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["retrieval_0"],
"upstream": ["begin", "generate_0"],
},
"retrieval_0": {
"obj": {
"component_name": "Retrieval",
@ -61,19 +60,28 @@ class Canvas:
}
},
"history": [],
"messages": [],
"reference": [],
"path": [["begin"]],
"answer": []
"path": ["begin"],
"retrieval": {"chunks": [], "doc_aggs": []},
"globals": {
"sys.query": "",
"sys.user_id": tenant_id,
"sys.conversation_turns": 0,
"sys.files": []
}
}
"""
def __init__(self, dsl: str, tenant_id=None):
def __init__(self, dsl: str, tenant_id=None, task_id=None):
self.path = []
self.history = []
self.messages = []
self.answer = []
self.components = {}
self.error = ""
self.globals = {
"sys.query": "",
"sys.user_id": tenant_id,
"sys.conversation_turns": 0,
"sys.files": []
}
self.dsl = json.loads(dsl) if dsl else {
"components": {
"begin": {
@ -89,13 +97,17 @@ class Canvas:
}
},
"history": [],
"messages": [],
"reference": [],
"path": [],
"answer": []
"retrieval": [],
"globals": {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": []
}
}
self._tenant_id = tenant_id
self._embed_id = ""
self.task_id = task_id if task_id else get_uuid()
self.load()
def load(self):
@ -105,33 +117,31 @@ class Canvas:
cpn_nms.add(cpn["obj"]["component_name"])
assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
assert "Answer" in cpn_nms, "There have to be an 'Answer' component."
for k, cpn in self.components.items():
cpn_nms.add(cpn["obj"]["component_name"])
param = component_class(cpn["obj"]["component_name"] + "Param")()
param.update(cpn["obj"]["params"])
param.check()
try:
param.check()
except Exception as e:
raise ValueError(self.get_component_name(k) + f": {e}")
cpn["obj"] = component_class(cpn["obj"]["component_name"])(self, k, param)
if cpn["obj"].component_name == "Categorize":
for _, desc in param.category_description.items():
if desc["to"] not in cpn["downstream"]:
cpn["downstream"].append(desc["to"])
self.path = self.dsl["path"]
self.history = self.dsl["history"]
self.messages = self.dsl["messages"]
self.answer = self.dsl["answer"]
self.reference = self.dsl["reference"]
self._embed_id = self.dsl.get("embed_id", "")
self.globals = self.dsl["globals"]
self.retrieval = self.dsl["retrieval"]
self.memory = self.dsl.get("memory", [])
def __str__(self):
self.dsl["path"] = self.path
self.dsl["history"] = self.history
self.dsl["messages"] = self.messages
self.dsl["answer"] = self.answer
self.dsl["reference"] = self.reference
self.dsl["embed_id"] = self._embed_id
self.dsl["globals"] = self.globals
self.dsl["task_id"] = self.task_id
self.dsl["retrieval"] = self.retrieval
self.dsl["memory"] = self.memory
dsl = {
"components": {}
}
@ -150,161 +160,245 @@ class Canvas:
dsl["components"][k][c] = deepcopy(cpn[c])
return json.dumps(dsl, ensure_ascii=False)
def reset(self):
def reset(self, mem=False):
self.path = []
self.history = []
self.messages = []
self.answer = []
self.reference = []
if not mem:
self.history = []
self.retrieval = []
self.memory = []
for k, cpn in self.components.items():
self.components[k]["obj"].reset()
self._embed_id = ""
for k in self.globals.keys():
if isinstance(self.globals[k], str):
self.globals[k] = ""
elif isinstance(self.globals[k], int):
self.globals[k] = 0
elif isinstance(self.globals[k], float):
self.globals[k] = 0
elif isinstance(self.globals[k], list):
self.globals[k] = []
elif isinstance(self.globals[k], dict):
self.globals[k] = {}
else:
self.globals[k] = None
try:
REDIS_CONN.delete(f"{self.task_id}-logs")
except Exception as e:
logging.exception(e)
def get_component_name(self, cid):
for n in self.dsl["graph"]["nodes"]:
for n in self.dsl.get("graph", {}).get("nodes", []):
if cid == n["id"]:
return n["data"]["name"]
return ""
def run(self, running_hint_text = "is running...🕞", **kwargs):
if not running_hint_text or not isinstance(running_hint_text, str):
running_hint_text = "is running...🕞"
bypass_begin = bool(kwargs.get("bypass_begin", False))
def run(self, **kwargs):
st = time.perf_counter()
self.message_id = get_uuid()
created_at = int(time.time())
self.add_user_input(kwargs.get("query"))
if self.answer:
cpn_id = self.answer[0]
self.answer.pop(0)
try:
ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
except Exception as e:
ans = ComponentBase.be_output(str(e))
self.path[-1].append(cpn_id)
if kwargs.get("stream"):
for an in ans():
yield an
else:
yield ans
return
if not self.path:
self.components["begin"]["obj"].run(self.history, **kwargs)
self.path.append(["begin"])
if bypass_begin:
cpn = self.get_component("begin")
downstream = cpn["downstream"]
self.path.append(downstream)
self.path.append([])
ran = -1
waiting = []
without_dependent_checking = []
def prepare2run(cpns):
nonlocal ran, ans
for c in cpns:
if self.path[-1] and c == self.path[-1][-1]:
continue
cpn = self.components[c]["obj"]
if cpn.component_name == "Answer":
self.answer.append(c)
for k in kwargs.keys():
if k in ["query", "user_id", "files"] and kwargs[k]:
if k == "files":
self.globals[f"sys.{k}"] = self.get_files(kwargs[k])
else:
logging.debug(f"Canvas.prepare2run: {c}")
if c not in without_dependent_checking:
cpids = cpn.get_dependent_components()
if any([cc not in self.path[-1] for cc in cpids]):
if c not in waiting:
waiting.append(c)
continue
yield "*'{}'* {}".format(self.get_component_name(c), running_hint_text)
self.globals[f"sys.{k}"] = kwargs[k]
if not self.globals["sys.conversation_turns"] :
self.globals["sys.conversation_turns"] = 0
self.globals["sys.conversation_turns"] += 1
if cpn.component_name.lower() == "iteration":
st_cpn = cpn.get_start()
assert st_cpn, "Start component not found for Iteration."
if not st_cpn["obj"].end():
cpn = st_cpn["obj"]
c = cpn._id
def decorate(event, dt):
nonlocal created_at
return {
"event": event,
#"conversation_id": "f3cc152b-24b0-4258-a1a1-7d5e9fc8a115",
"message_id": self.message_id,
"created_at": created_at,
"task_id": self.task_id,
"data": dt
}
try:
ans = cpn.run(self.history, **kwargs)
except Exception as e:
logging.exception(f"Canvas.run got exception: {e}")
self.path[-1].append(c)
ran += 1
raise e
self.path[-1].append(c)
if not self.path or self.path[-1].lower().find("userfillup") < 0:
self.path.append("begin")
self.retrieval.append({"chunks": [], "doc_aggs": []})
ran += 1
yield decorate("workflow_started", {"inputs": kwargs.get("inputs")})
self.retrieval.append({"chunks": {}, "doc_aggs": {}})
downstream = self.components[self.path[-2][-1]]["downstream"]
if not downstream and self.components[self.path[-2][-1]].get("parent_id"):
cid = self.path[-2][-1]
pid = self.components[cid]["parent_id"]
o, _ = self.components[cid]["obj"].output(allow_partial=False)
oo, _ = self.components[pid]["obj"].output(allow_partial=False)
self.components[pid]["obj"].set_output(pd.concat([oo, o], ignore_index=True).dropna())
downstream = [pid]
def _run_batch(f, t):
with ThreadPoolExecutor(max_workers=5) as executor:
thr = []
for i in range(f, t):
cpn = self.get_component_obj(self.path[i])
if cpn.component_name.lower() in ["begin", "userfillup"]:
thr.append(executor.submit(cpn.invoke, inputs=kwargs.get("inputs", {})))
else:
thr.append(executor.submit(cpn.invoke, **cpn.get_input()))
for t in thr:
t.result()
for m in prepare2run(downstream):
yield {"content": m, "running_status": True}
def _node_finished(cpn_obj):
return decorate("node_finished",{
"inputs": cpn_obj.get_input_values(),
"outputs": cpn_obj.output(),
"component_id": cpn_obj._id,
"component_name": self.get_component_name(cpn_obj._id),
"component_type": self.get_component_type(cpn_obj._id),
"error": cpn_obj.error(),
"elapsed_time": time.perf_counter() - cpn_obj.output("_created_time"),
"created_at": cpn_obj.output("_created_time"),
})
while 0 <= ran < len(self.path[-1]):
logging.debug(f"Canvas.run: {ran} {self.path}")
cpn_id = self.path[-1][ran]
cpn = self.get_component(cpn_id)
if not any([cpn["downstream"], cpn.get("parent_id"), waiting]):
def _append_path(cpn_id):
if self.path[-1] == cpn_id:
return
self.path.append(cpn_id)
def _extend_path(cpn_ids):
for cpn_id in cpn_ids:
_append_path(cpn_id)
self.error = ""
idx = len(self.path) - 1
partials = []
while idx < len(self.path):
to = len(self.path)
for i in range(idx, to):
yield decorate("node_started", {
"inputs": None, "created_at": int(time.time()),
"component_id": self.path[i],
"component_name": self.get_component_name(self.path[i]),
"component_type": self.get_component_type(self.path[i]),
})
_run_batch(idx, to)
# post processing of components invocation
for i in range(idx, to):
cpn = self.get_component(self.path[i])
if cpn["obj"].component_name.lower() == "message":
if isinstance(cpn["obj"].output("content"), partial):
_m = ""
for m in cpn["obj"].output("content")():
if not m:
continue
if m == "<think>":
yield decorate("message", {"content": "", "start_to_think": True})
elif m == "</think>":
yield decorate("message", {"content": "", "end_to_think": True})
else:
yield decorate("message", {"content": m})
_m += m
cpn["obj"].set_output("content", _m)
else:
yield decorate("message", {"content": cpn["obj"].output("content")})
yield decorate("message_end", {"reference": self.get_reference()})
while partials:
_cpn = self.get_component(partials[0])
if isinstance(_cpn["obj"].output("content"), partial):
break
yield _node_finished(_cpn["obj"])
partials.pop(0)
if cpn["obj"].error():
ex = cpn["obj"].exception_handler()
if ex and ex["comment"]:
yield decorate("message", {"content": ex["comment"]})
yield decorate("message_end", {})
if ex and ex["goto"]:
self.path.append(ex["goto"])
elif not ex or not ex["default_value"]:
self.error = cpn["obj"].error()
if cpn["obj"].component_name.lower() != "iteration":
if isinstance(cpn["obj"].output("content"), partial):
if self.error:
cpn["obj"].set_output("content", None)
yield _node_finished(cpn["obj"])
else:
partials.append(self.path[i])
else:
yield _node_finished(cpn["obj"])
if cpn["obj"].component_name.lower() == "iterationitem" and cpn["obj"].end():
iter = cpn["obj"].get_parent()
yield _node_finished(iter)
_extend_path(self.get_component(cpn["parent_id"])["downstream"])
elif cpn["obj"].component_name.lower() in ["categorize", "switch"]:
_extend_path(cpn["obj"].output("_next"))
elif cpn["obj"].component_name.lower() == "iteration":
_append_path(cpn["obj"].get_start())
elif not cpn["downstream"] and cpn["obj"].get_parent():
_append_path(cpn["obj"].get_parent().get_start())
else:
_extend_path(cpn["downstream"])
if self.error:
logging.error(f"Runtime Error: {self.error}")
break
idx = to
loop = self._find_loop()
if loop:
raise OverflowError(f"Too much loops: {loop}")
if any([self.get_component(c)["obj"].component_name.lower() == "userfillup" for c in self.path[idx:]]):
path = [c for c in self.path[idx:] if self.get_component(c)["obj"].component_name.lower() == "userfillup"]
path.extend([c for c in self.path[idx:] if self.get_component(c)["obj"].component_name.lower() != "userfillup"])
another_inputs = {}
tips = ""
for c in path:
o = self.get_component(c)["obj"]
if o.component_name.lower() == "userfillup":
another_inputs.update(o.get_input_elements())
if o.get_param("enable_tips"):
tips = o.get_param("tips")
self.path = path
yield decorate("user_inputs", {"inputs": another_inputs, "tips": tips})
return
downstream = []
if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
switch_out = cpn["obj"].output()[1].iloc[0, 0]
assert switch_out in self.components, \
"{}'s output: {} not valid.".format(cpn_id, switch_out)
downstream = [switch_out]
else:
downstream = cpn["downstream"]
self.path = self.path[:idx]
if not self.error:
yield decorate("workflow_finished",
{
"inputs": kwargs.get("inputs"),
"outputs": self.get_component_obj(self.path[-1]).output(),
"elapsed_time": time.perf_counter() - st,
"created_at": st,
})
self.history.append(("assistant", self.get_component_obj(self.path[-1]).output()))
if not downstream and cpn.get("parent_id"):
pid = cpn["parent_id"]
_, o = cpn["obj"].output(allow_partial=False)
_, oo = self.components[pid]["obj"].output(allow_partial=False)
self.components[pid]["obj"].set_output(pd.concat([oo.dropna(axis=1), o.dropna(axis=1)], ignore_index=True).dropna())
downstream = [pid]
def get_component(self, cpn_id) -> Union[None, dict[str, Any]]:
return self.components.get(cpn_id)
for m in prepare2run(downstream):
yield {"content": m, "running_status": True}
def get_component_obj(self, cpn_id) -> ComponentBase:
return self.components.get(cpn_id)["obj"]
if ran >= len(self.path[-1]) and waiting:
without_dependent_checking = waiting
waiting = []
for m in prepare2run(without_dependent_checking):
yield {"content": m, "running_status": True}
without_dependent_checking = []
ran -= 1
def get_component_type(self, cpn_id) -> str:
return self.components.get(cpn_id)["obj"].component_name
if self.answer:
cpn_id = self.answer[0]
self.answer.pop(0)
ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
self.path[-1].append(cpn_id)
if kwargs.get("stream"):
assert isinstance(ans, partial)
for an in ans():
yield an
else:
yield ans
def get_component_input_form(self, cpn_id) -> dict:
return self.components.get(cpn_id)["obj"].get_input_form()
else:
raise Exception("The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow.")
def is_reff(self, exp: str) -> bool:
exp = exp.strip("{").strip("}")
if exp.find("@") < 0:
return exp in self.globals
arr = exp.split("@")
if len(arr) != 2:
return False
if self.get_component(arr[0]) is None:
return False
return True
def get_component(self, cpn_id):
return self.components[cpn_id]
def get_variable_value(self, exp: str) -> Any:
exp = exp.strip("{").strip("}").strip(" ").strip("{").strip("}")
if exp.find("@") < 0:
return self.globals[exp]
cpn_id, var_nm = exp.split("@")
cpn = self.get_component(cpn_id)
if not cpn:
raise Exception(f"Can't find variable: '{cpn_id}@{var_nm}'")
return cpn["obj"].output(var_nm)
def get_tenant_id(self):
return self._tenant_id
@ -314,8 +408,8 @@ class Canvas:
if window_size <= 0:
return convs
for role, obj in self.history[window_size * -1:]:
if isinstance(obj, list) and obj and all([isinstance(o, dict) for o in obj]):
convs.append({"role": role, "content": '\n'.join([str(s.get("content", "")) for s in obj])})
if isinstance(obj, dict):
convs.append({"role": role, "content": obj.get("content", "")})
else:
convs.append({"role": role, "content": str(obj)})
return convs
@ -323,12 +417,6 @@ class Canvas:
def add_user_input(self, question):
self.history.append(("user", question))
def set_embedding_model(self, embed_id):
self._embed_id = embed_id
def get_embedding_model(self):
return self._embed_id
def _find_loop(self, max_loops=6):
path = self.path[-1][::-1]
if len(path) < 2:
@ -363,17 +451,75 @@ class Canvas:
return self.components["begin"]["obj"]._param.prologue
def set_global_param(self, **kwargs):
for k, v in kwargs.items():
for q in self.components["begin"]["obj"]._param.query:
if k != q["key"]:
continue
q["value"] = v
self.globals.update(kwargs)
def get_preset_param(self):
return self.components["begin"]["obj"]._param.query
return self.components["begin"]["obj"]._param.inputs
def get_component_input_elements(self, cpnnm):
return self.components[cpnnm]["obj"].get_input_elements()
def set_component_infor(self, cpn_id, infor):
self.components[cpn_id]["obj"].set_infor(infor)
def get_files(self, files: Union[None, list[dict]]) -> list[str]:
if not files:
return []
def image_to_base64(file):
return "data:{};base64,{}".format(file["mime_type"],
base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8"))
exe = ThreadPoolExecutor(max_workers=5)
threads = []
for file in files:
if file["mime_type"].find("image") >=0:
threads.append(exe.submit(image_to_base64, file))
continue
threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"]))
return [th.result() for th in threads]
def tool_use_callback(self, agent_id: str, func_name: str, params: dict, result: Any):
agent_ids = agent_id.split("-->")
agent_name = self.get_component_name(agent_ids[0])
path = agent_name if len(agent_ids) < 2 else agent_name+"-->"+"-->".join(agent_ids[1:])
try:
bin = REDIS_CONN.get(f"{self.task_id}-{self.message_id}-logs")
if bin:
obj = json.loads(bin.encode("utf-8"))
if obj[-1]["component_id"] == agent_ids[0]:
obj[-1]["trace"].append({"path": path, "tool_name": func_name, "arguments": params, "result": result})
else:
obj.append({
"component_id": agent_ids[0],
"trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result}]
})
else:
obj = [{
"component_id": agent_ids[0],
"trace": [{"path": path, "tool_name": func_name, "arguments": params, "result": result}]
}]
REDIS_CONN.set_obj(f"{self.task_id}-{self.message_id}-logs", obj, 60*10)
except Exception as e:
logging.exception(e)
def add_refernce(self, chunks: list[object], doc_infos: list[object]):
if not self.retrieval:
self.retrieval = [{"chunks": {}, "doc_aggs": {}}]
r = self.retrieval[-1]
for ck in chunks_format({"chunks": chunks}):
cid = hash_str2int(ck["id"], 100)
if cid not in r:
r["chunks"][cid] = ck
for doc in doc_infos:
if doc["doc_name"] not in r:
r["doc_aggs"][doc["doc_name"]] = doc
def get_reference(self):
if not self.retrieval:
return {"chunks": {}, "doc_aggs": {}}
return self.retrieval[-1]
def add_memory(self, user:str, assist:str, summ: str):
self.memory.append((user, assist, summ))
def get_memory(self) -> list[Tuple]:
return self.memory

View File

@ -14,123 +14,44 @@
# limitations under the License.
#
import os
import importlib
from .begin import Begin, BeginParam
from .generate import Generate, GenerateParam
from .retrieval import Retrieval, RetrievalParam
from .answer import Answer, AnswerParam
from .categorize import Categorize, CategorizeParam
from .switch import Switch, SwitchParam
from .relevant import Relevant, RelevantParam
from .message import Message, MessageParam
from .rewrite import RewriteQuestion, RewriteQuestionParam
from .keyword import KeywordExtract, KeywordExtractParam
from .concentrator import Concentrator, ConcentratorParam
from .baidu import Baidu, BaiduParam
from .duckduckgo import DuckDuckGo, DuckDuckGoParam
from .wikipedia import Wikipedia, WikipediaParam
from .pubmed import PubMed, PubMedParam
from .arxiv import ArXiv, ArXivParam
from .google import Google, GoogleParam
from .bing import Bing, BingParam
from .googlescholar import GoogleScholar, GoogleScholarParam
from .deepl import DeepL, DeepLParam
from .github import GitHub, GitHubParam
from .baidufanyi import BaiduFanyi, BaiduFanyiParam
from .qweather import QWeather, QWeatherParam
from .exesql import ExeSQL, ExeSQLParam
from .yahoofinance import YahooFinance, YahooFinanceParam
from .wencai import WenCai, WenCaiParam
from .jin10 import Jin10, Jin10Param
from .tushare import TuShare, TuShareParam
from .akshare import AkShare, AkShareParam
from .crawler import Crawler, CrawlerParam
from .invoke import Invoke, InvokeParam
from .template import Template, TemplateParam
from .email import Email, EmailParam
from .iteration import Iteration, IterationParam
from .iterationitem import IterationItem, IterationItemParam
from .code import Code, CodeParam
import inspect
from types import ModuleType
from typing import Dict, Type
_package_path = os.path.dirname(__file__)
__all_classes: Dict[str, Type] = {}
def _import_submodules() -> None:
for filename in os.listdir(_package_path): # noqa: F821
if filename.startswith("__") or not filename.endswith(".py") or filename.startswith("base"):
continue
module_name = filename[:-3]
try:
module = importlib.import_module(f".{module_name}", package=__name__)
_extract_classes_from_module(module) # noqa: F821
except ImportError as e:
print(f"Warning: Failed to import module {module_name}: {str(e)}")
def _extract_classes_from_module(module: ModuleType) -> None:
for name, obj in inspect.getmembers(module):
if (inspect.isclass(obj) and
obj.__module__ == module.__name__ and not name.startswith("_")):
__all_classes[name] = obj
globals()[name] = obj
_import_submodules()
__all__ = list(__all_classes.keys()) + ["__all_classes"]
del _package_path, _import_submodules, _extract_classes_from_module
def component_class(class_name):
m = importlib.import_module("agent.component")
c = getattr(m, class_name)
return c
try:
return getattr(m, class_name)
except Exception:
return getattr(importlib.import_module("agent.tools"), class_name)
__all__ = [
"Begin",
"BeginParam",
"Generate",
"GenerateParam",
"Retrieval",
"RetrievalParam",
"Answer",
"AnswerParam",
"Categorize",
"CategorizeParam",
"Switch",
"SwitchParam",
"Relevant",
"RelevantParam",
"Message",
"MessageParam",
"RewriteQuestion",
"RewriteQuestionParam",
"KeywordExtract",
"KeywordExtractParam",
"Concentrator",
"ConcentratorParam",
"Baidu",
"BaiduParam",
"DuckDuckGo",
"DuckDuckGoParam",
"Wikipedia",
"WikipediaParam",
"PubMed",
"PubMedParam",
"ArXiv",
"ArXivParam",
"Google",
"GoogleParam",
"Bing",
"BingParam",
"GoogleScholar",
"GoogleScholarParam",
"DeepL",
"DeepLParam",
"GitHub",
"GitHubParam",
"BaiduFanyi",
"BaiduFanyiParam",
"QWeather",
"QWeatherParam",
"ExeSQL",
"ExeSQLParam",
"YahooFinance",
"YahooFinanceParam",
"WenCai",
"WenCaiParam",
"Jin10",
"Jin10Param",
"TuShare",
"TuShareParam",
"AkShare",
"AkShareParam",
"Crawler",
"CrawlerParam",
"Invoke",
"InvokeParam",
"Iteration",
"IterationParam",
"IterationItem",
"IterationItemParam",
"Template",
"TemplateParam",
"Email",
"EmailParam",
"Code",
"CodeParam",
"component_class"
]

View File

@ -0,0 +1,332 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from functools import partial
from typing import Any
import json_repair
from agent.component.llm import LLMParam, LLM
from agent.tools.base import LLMToolPluginCallSession, ToolParamBase, ToolBase, ToolMeta
from api.db.services.llm_service import LLMBundle, TenantLLMService
from api.db.services.mcp_server_service import MCPServerService
from api.utils.api_utils import timeout
from rag.prompts import message_fit_in
from rag.prompts.prompts import next_step, COMPLETE_TASK, analyze_task, \
citation_prompt, reflect, rank_memories, kb_prompt, citation_plus, full_question
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
class AgentParam(LLMParam, ToolParamBase):
"""
Define the Agent component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "agent",
"description": "This is an agent for a specific task.",
"parameters": {
"user_prompt": {
"type": "string",
"description": "This is the order you need to send to the agent.",
"default": "",
"required": True
},
"reasoning": {
"type": "string",
"description": (
"Supervisor's reasoning for choosing the this agent. "
"Explain why this agent is being invoked and what is expected of it."
),
"required": True
},
"context": {
"type": "string",
"description": (
"All relevant background information, prior facts, decisions, "
"and state needed by the agent to solve the current query. "
"Should be as detailed and self-contained as possible."
),
"required": True
},
}
}
super().__init__()
self.function_name = "agent"
self.tools = []
self.mcp = []
self.max_rounds = 5
self.description = ""
class Agent(LLM, ToolBase):
component_name = "Agent"
def __init__(self, canvas, id, param: LLMParam):
LLM.__init__(self, canvas, id, param)
self.tools = {}
for cpn in self._param.tools:
cpn = self._load_tool_obj(cpn)
self.tools[cpn.get_meta()["function"]["name"]] = cpn
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id,
max_retries=self._param.max_retries,
retry_interval=self._param.delay_after_error,
max_rounds=self._param.max_rounds,
verbose_tool_use=True
)
self.tool_meta = [v.get_meta() for _,v in self.tools.items()]
for mcp in self._param.mcp:
_, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
for tnm, meta in mcp["tools"].items():
self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta))
self.tools[tnm] = tool_call_session
self.callback = partial(self._canvas.tool_use_callback, id)
self.toolcall_session = LLMToolPluginCallSession(self.tools, self.callback)
#self.chat_mdl.bind_tools(self.toolcall_session, self.tool_metas)
def _load_tool_obj(self, cpn: dict) -> object:
from agent.component import component_class
param = component_class(cpn["component_name"] + "Param")()
param.update(cpn["params"])
try:
param.check()
except Exception as e:
self.set_output("_ERROR", cpn["component_name"] + f" configuration error: {e}")
raise
cpn_id = f"{self._id}-->" + cpn.get("name", "").replace(" ", "_")
return component_class(cpn["component_name"])(self._canvas, cpn_id, param)
def get_meta(self) -> dict[str, Any]:
self._param.function_name= self._id.split("-->")[-1]
m = super().get_meta()
if hasattr(self._param, "user_prompt") and self._param.user_prompt:
m["function"]["parameters"]["properties"]["user_prompt"] = self._param.user_prompt
return m
def get_input_form(self) -> dict[str, dict]:
res = {}
for k, v in self.get_input_elements().items():
res[k] = {
"type": "line",
"name": v["name"]
}
for cpn in self._param.tools:
if not isinstance(cpn, LLM):
continue
res.update(cpn.get_input_form())
return res
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60))
def _invoke(self, **kwargs):
if kwargs.get("user_prompt"):
usr_pmt = ""
if kwargs.get("reasoning"):
usr_pmt += "\nREASONING:\n{}\n".format(kwargs["reasoning"])
if kwargs.get("context"):
usr_pmt += "\nCONTEXT:\n{}\n".format(kwargs["context"])
if usr_pmt:
usr_pmt += "\nQUERY:\n{}\n".format(str(kwargs["user_prompt"]))
else:
usr_pmt = str(kwargs["user_prompt"])
self._param.prompts = [{"role": "user", "content": usr_pmt}]
if not self.tools:
return LLM._invoke(self, **kwargs)
prompt, msg = self._prepare_prompt_variables()
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure:
self.set_output("content", partial(self.stream_output_with_tools, prompt, msg))
return
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
use_tools = []
ans = ""
for delta_ans, tk in self._react_with_tools_streamly(msg, use_tools):
ans += delta_ans
if ans.find("**ERROR**") >= 0:
logging.error(f"Agent._chat got error. response: {ans}")
self.set_output("_ERROR", ans)
return
self.set_output("content", ans)
if use_tools:
self.set_output("use_tools", use_tools)
return ans
def stream_output_with_tools(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer_without_toolcall = ""
use_tools = []
for delta_ans,_ in self._react_with_tools_streamly(msg, use_tools):
answer_without_toolcall += delta_ans
yield delta_ans
self.set_output("content", answer_without_toolcall)
if use_tools:
self.set_output("use_tools", use_tools)
def _gen_citations(self, text):
retrievals = self._canvas.get_reference()
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
{"role": "user", "content": text}
]):
yield delta_ans
def _react_with_tools_streamly(self, history: list[dict], use_tools):
token_count = 0
tool_metas = self.tool_meta
hist = deepcopy(history)
last_calling = ""
if len(hist) > 3:
self.callback("Multi-turn conversation optimization", {}, " running ...")
user_request = full_question(messages=history, chat_mdl=self.chat_mdl)
else:
user_request = history[-1]["content"]
def use_tool(name, args):
nonlocal hist, use_tools, token_count,last_calling,user_request
print(f"{last_calling=} == {name=}", )
# Summarize of function calling
#if all([
# isinstance(self.toolcall_session.get_tool_obj(name), Agent),
# last_calling,
# last_calling != name
#]):
# self.toolcall_session.get_tool_obj(name).add2system_prompt(f"The chat history with other agents are as following: \n" + self.get_useful_memory(user_request, str(args["user_prompt"])))
last_calling = name
tool_response = self.toolcall_session.tool_call(name, args)
use_tools.append({
"name": name,
"arguments": args,
"results": tool_response
})
# self.callback("add_memory", {}, "...")
#self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response))
return name, tool_response
def complete():
nonlocal hist
need2cite = self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
cited = False
if hist[0]["role"] == "system" and need2cite:
if len(hist) < 7:
hist[0]["content"] += citation_prompt()
cited = True
yield "", token_count
if not cited and need2cite:
self.callback("gen_citations", {}, " running ...")
_hist = hist
if len(hist) > 12:
_hist = [hist[0], hist[1], *hist[-10:]]
entire_txt = ""
for delta_ans in self._generate_streamly(_hist):
if not need2cite or cited:
yield delta_ans, 0
entire_txt += delta_ans
if not need2cite or cited:
return
for delta_ans in self._gen_citations(entire_txt):
yield delta_ans, 0
def append_user_content(hist, content):
if hist[-1]["role"] == "user":
hist[-1]["content"] += content
else:
hist.append({"role": "user", "content": content})
self.callback("analyze_task", {}, " running ...")
task_desc = analyze_task(self.chat_mdl, user_request, tool_metas)
for _ in range(self._param.max_rounds + 1):
response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc)
# self.callback("next_step", {}, str(response)[:256]+"...")
token_count += tk
hist.append({"role": "assistant", "content": response})
try:
functions = json_repair.loads(re.sub(r"```.*", "", response))
if not isinstance(functions, list):
raise TypeError(f"List should be returned, but `{functions}`")
for f in functions:
if not isinstance(f, dict):
raise TypeError(f"An object type should be returned, but `{f}`")
with ThreadPoolExecutor(max_workers=5) as executor:
thr = []
for func in functions:
name = func["name"]
args = func["arguments"]
if name == COMPLETE_TASK:
append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
for txt, tkcnt in complete():
yield txt, tkcnt
return
thr.append(executor.submit(use_tool, name, args))
reflection = reflect(self.chat_mdl, hist, [th.result() for th in thr])
append_user_content(hist, reflection)
self.callback("reflection", {}, str(reflection))
except Exception as e:
logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
e = f"\nTool call error, please correct the input parameter of response format and call it again.\n *** Exception ***\n{e}"
append_user_content(hist, str(e))
logging.warning( f"Exceed max rounds: {self._param.max_rounds}")
final_instruction = f"""
{user_request}
IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request.
Instructions:
1. SYNTHESIZE all information collected during this conversation
2. Provide a COMPLETE response using existing data - do not suggest additional research
3. Structure your response as a FINAL DELIVERABLE, not a plan
4. If information is incomplete, state what you found and provide the best analysis possible with available data
5. DO NOT mention conversation limits or suggest further steps
6. Focus on delivering VALUE with the information already gathered
Respond immediately with your final comprehensive answer.
"""
append_user_content(hist, final_instruction)
for txt, tkcnt in complete():
yield txt, tkcnt
def get_useful_memory(self, goal: str, sub_goal:str, topn=3) -> str:
# self.callback("get_useful_memory", {"topn": 3}, "...")
mems = self._canvas.get_memory()
rank = rank_memories(self.chat_mdl, goal, sub_goal, [summ for (user, assist, summ) in mems])
try:
rank = json_repair.loads(re.sub(r"```.*", "", rank))[:topn]
mems = [mems[r] for r in rank]
return "\n\n".join([f"User: {u}\nAgent: {a}" for u, a,_ in mems])
except Exception as e:
logging.exception(e)
return "Error occurred."

View File

@ -1,92 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
from abc import ABC
from functools import partial
from typing import Tuple, Union
import pandas as pd
from agent.component.base import ComponentBase, ComponentParamBase
class AnswerParam(ComponentParamBase):
"""
Define the Answer component parameters.
"""
def __init__(self):
super().__init__()
self.post_answers = []
def check(self):
return True
class Answer(ComponentBase, ABC):
component_name = "Answer"
def _run(self, history, **kwargs):
if kwargs.get("stream"):
return partial(self.stream_output)
ans = self.get_input()
if self._param.post_answers:
ans = pd.concat([ans, pd.DataFrame([{"content": random.choice(self._param.post_answers)}])], ignore_index=False)
return ans
def stream_output(self):
res = None
if hasattr(self, "exception") and self.exception:
res = {"content": str(self.exception)}
self.exception = None
yield res
self.set_output(res)
return
stream = self.get_stream_input()
if isinstance(stream, pd.DataFrame):
res = stream
answer = ""
for ii, row in stream.iterrows():
answer += row.to_dict()["content"]
yield {"content": answer}
elif stream is not None:
for st in stream():
res = st
yield st
if self._param.post_answers and res:
res["content"] += random.choice(self._param.post_answers)
yield res
if res is None:
res = {"content": ""}
self.set_output(res)
def set_exception(self, e):
self.exception = e
def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
if allow_partial:
return super.output()
for r, c in self._canvas.history[::-1]:
if r == "user":
return self._param.output_var_name, pd.DataFrame([{"content": c}])
self._param.output_var_name, pd.DataFrame([])

View File

@ -1,68 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC
import arxiv
import pandas as pd
from agent.component.base import ComponentBase, ComponentParamBase
class ArXivParam(ComponentParamBase):
"""
Define the ArXiv component parameters.
"""
def __init__(self):
super().__init__()
self.top_n = 6
self.sort_by = 'submittedDate'
def check(self):
self.check_positive_integer(self.top_n, "Top N")
self.check_valid_value(self.sort_by, "ArXiv Search Sort_by",
['submittedDate', 'lastUpdatedDate', 'relevance'])
class ArXiv(ComponentBase, ABC):
component_name = "ArXiv"
def _run(self, history, **kwargs):
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
return ArXiv.be_output("")
try:
sort_choices = {"relevance": arxiv.SortCriterion.Relevance,
"lastUpdatedDate": arxiv.SortCriterion.LastUpdatedDate,
'submittedDate': arxiv.SortCriterion.SubmittedDate}
arxiv_client = arxiv.Client()
search = arxiv.Search(
query=ans,
max_results=self._param.top_n,
sort_by=sort_choices[self._param.sort_by]
)
arxiv_res = [
{"content": 'Title: ' + i.title + '\nPdf_Url: <a href="' + i.pdf_url + '"></a> \nSummary: ' + i.summary} for
i in list(arxiv_client.results(search))]
except Exception as e:
return ArXiv.be_output("**ERROR**: " + str(e))
if not arxiv_res:
return ArXiv.be_output("")
df = pd.DataFrame(arxiv_res)
logging.debug(f"df: {str(df)}")
return df

View File

@ -1,79 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC
import pandas as pd
import requests
from bs4 import BeautifulSoup
import re
from agent.component.base import ComponentBase, ComponentParamBase
class BaiduParam(ComponentParamBase):
"""
Define the Baidu component parameters.
"""
def __init__(self):
super().__init__()
self.top_n = 10
def check(self):
self.check_positive_integer(self.top_n, "Top N")
class Baidu(ComponentBase, ABC):
component_name = "Baidu"
def _run(self, history, **kwargs):
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
return Baidu.be_output("")
try:
url = 'https://www.baidu.com/s?wd=' + ans + '&rn=' + str(self._param.top_n)
headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
'Accept-Language': 'zh-CN,zh;q=0.9,en;q=0.8',
'Connection': 'keep-alive',
}
response = requests.get(url=url, headers=headers)
# check if request success
if response.status_code == 200:
soup = BeautifulSoup(response.text, 'html.parser')
url_res = []
title_res = []
body_res = []
for item in soup.select('.result.c-container'):
# extract title
title_res.append(item.select_one('h3 a').get_text(strip=True))
url_res.append(item.select_one('h3 a')['href'])
body_res.append(item.select_one('.c-abstract').get_text(strip=True) if item.select_one('.c-abstract') else '')
baidu_res = [{"content": re.sub('<em>|</em>', '', '<a href="' + url + '">' + title + '</a> ' + body)} for
url, title, body in zip(url_res, title_res, body_res)]
del body_res, url_res, title_res
except Exception as e:
return Baidu.be_output("**ERROR**: " + str(e))
if not baidu_res:
return Baidu.be_output("")
df = pd.DataFrame(baidu_res)
logging.debug(f"df: {str(df)}")
return df

View File

@ -1,96 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import random
from abc import ABC
import requests
from agent.component.base import ComponentBase, ComponentParamBase
from hashlib import md5
class BaiduFanyiParam(ComponentParamBase):
"""
Define the BaiduFanyi component parameters.
"""
def __init__(self):
super().__init__()
self.appid = "xxx"
self.secret_key = "xxx"
self.trans_type = 'translate'
self.parameters = []
self.source_lang = 'auto'
self.target_lang = 'auto'
self.domain = 'finance'
def check(self):
self.check_empty(self.appid, "BaiduFanyi APPID")
self.check_empty(self.secret_key, "BaiduFanyi Secret Key")
self.check_valid_value(self.trans_type, "Translate type", ['translate', 'fieldtranslate'])
self.check_valid_value(self.source_lang, "Source language",
['auto', 'zh', 'en', 'yue', 'wyw', 'jp', 'kor', 'fra', 'spa', 'th', 'ara', 'ru', 'pt',
'de', 'it', 'el', 'nl', 'pl', 'bul', 'est', 'dan', 'fin', 'cs', 'rom', 'slo', 'swe',
'hu', 'cht', 'vie'])
self.check_valid_value(self.target_lang, "Target language",
['auto', 'zh', 'en', 'yue', 'wyw', 'jp', 'kor', 'fra', 'spa', 'th', 'ara', 'ru', 'pt',
'de', 'it', 'el', 'nl', 'pl', 'bul', 'est', 'dan', 'fin', 'cs', 'rom', 'slo', 'swe',
'hu', 'cht', 'vie'])
self.check_valid_value(self.domain, "Translate field",
['it', 'finance', 'machinery', 'senimed', 'novel', 'academic', 'aerospace', 'wiki',
'news', 'law', 'contract'])
class BaiduFanyi(ComponentBase, ABC):
component_name = "BaiduFanyi"
def _run(self, history, **kwargs):
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
return BaiduFanyi.be_output("")
try:
source_lang = self._param.source_lang
target_lang = self._param.target_lang
appid = self._param.appid
salt = random.randint(32768, 65536)
secret_key = self._param.secret_key
if self._param.trans_type == 'translate':
sign = md5((appid + ans + salt + secret_key).encode('utf-8')).hexdigest()
url = 'http://api.fanyi.baidu.com/api/trans/vip/translate?' + 'q=' + ans + '&from=' + source_lang + '&to=' + target_lang + '&appid=' + appid + '&salt=' + salt + '&sign=' + sign
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = requests.post(url=url, headers=headers).json()
if response.get('error_code'):
BaiduFanyi.be_output("**Error**:" + response['error_msg'])
return BaiduFanyi.be_output(response['trans_result'][0]['dst'])
elif self._param.trans_type == 'fieldtranslate':
domain = self._param.domain
sign = md5((appid + ans + salt + domain + secret_key).encode('utf-8')).hexdigest()
url = 'http://api.fanyi.baidu.com/api/trans/vip/fieldtranslate?' + 'q=' + ans + '&from=' + source_lang + '&to=' + target_lang + '&appid=' + appid + '&salt=' + salt + '&domain=' + domain + '&sign=' + sign
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = requests.post(url=url, headers=headers).json()
if response.get('error_code'):
BaiduFanyi.be_output("**Error**:" + response['error_msg'])
return BaiduFanyi.be_output(response['trans_result'][0]['dst'])
except Exception as e:
BaiduFanyi.be_output("**Error**:" + str(e))

View File

@ -13,17 +13,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import re
import time
from abc import ABC
import builtins
import json
import logging
import os
from abc import ABC
from functools import partial
from typing import Any, Tuple, Union
import logging
from typing import Any, List, Union
import pandas as pd
import trio
from agent import settings
from api.utils.api_utils import timeout
_FEEDED_DEPRECATED_PARAMS = "_feeded_deprecated_params"
_DEPRECATED_PARAMS = "_deprecated_params"
@ -33,12 +35,17 @@ _IS_RAW_CONF = "_is_raw_conf"
class ComponentParamBase(ABC):
def __init__(self):
self.output_var_name = "output"
self.infor_var_name = "infor"
self.message_history_window_size = 22
self.query = []
self.inputs = []
self.debug_inputs = []
self.inputs = {}
self.outputs = {}
self.description = ""
self.max_retries = 0
self.delay_after_error = 2.0
self.exception_method = None
self.exception_default_value = None
self.exception_comment = None
self.exception_goto = None
self.debug_inputs = {}
def set_name(self, name: str):
self._name = name
@ -110,11 +117,15 @@ class ComponentParamBase(ABC):
update_from_raw_conf = conf.get(_IS_RAW_CONF, True)
if update_from_raw_conf:
deprecated_params_set = self._get_or_init_deprecated_params_set()
feeded_deprecated_params_set = self._get_or_init_feeded_deprecated_params_set()
feeded_deprecated_params_set = (
self._get_or_init_feeded_deprecated_params_set()
)
user_feeded_params_set = self._get_or_init_user_feeded_params_set()
setattr(self, _IS_RAW_CONF, False)
else:
feeded_deprecated_params_set = self._get_or_init_feeded_deprecated_params_set(conf)
feeded_deprecated_params_set = (
self._get_or_init_feeded_deprecated_params_set(conf)
)
user_feeded_params_set = self._get_or_init_user_feeded_params_set(conf)
def _recursive_update_param(param, config, depth, prefix):
@ -150,11 +161,15 @@ class ComponentParamBase(ABC):
else:
# recursive set obj attr
sub_params = _recursive_update_param(attr, config_value, depth + 1, prefix=f"{prefix}{config_key}.")
sub_params = _recursive_update_param(
attr, config_value, depth + 1, prefix=f"{prefix}{config_key}."
)
setattr(param, config_key, sub_params)
if not allow_redundant and redundant_attrs:
raise ValueError(f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`")
raise ValueError(
f"cpn `{getattr(self, '_name', type(self))}` has redundant parameters: `{[redundant_attrs]}`"
)
return param
@ -185,7 +200,9 @@ class ComponentParamBase(ABC):
param_validation_path_prefix = home_dir + "/param_validation/"
param_name = type(self).__name__
param_validation_path = "/".join([param_validation_path_prefix, param_name + ".json"])
param_validation_path = "/".join(
[param_validation_path_prefix, param_name + ".json"]
)
validation_json = None
@ -218,7 +235,11 @@ class ComponentParamBase(ABC):
break
if not value_legal:
raise ValueError("Plase check runtime conf, {} = {} does not match user-parameter restriction".format(variable, value))
raise ValueError(
"Plase check runtime conf, {} = {} does not match user-parameter restriction".format(
variable, value
)
)
elif variable in validation_json:
self._validate_param(attr, validation_json)
@ -226,63 +247,94 @@ class ComponentParamBase(ABC):
@staticmethod
def check_string(param, descr):
if type(param).__name__ not in ["str"]:
raise ValueError(descr + " {} not supported, should be string type".format(param))
raise ValueError(
descr + " {} not supported, should be string type".format(param)
)
@staticmethod
def check_empty(param, descr):
if not param:
raise ValueError(descr + " does not support empty value.")
raise ValueError(
descr + " does not support empty value."
)
@staticmethod
def check_positive_integer(param, descr):
if type(param).__name__ not in ["int", "long"] or param <= 0:
raise ValueError(descr + " {} not supported, should be positive integer".format(param))
raise ValueError(
descr + " {} not supported, should be positive integer".format(param)
)
@staticmethod
def check_positive_number(param, descr):
if type(param).__name__ not in ["float", "int", "long"] or param <= 0:
raise ValueError(descr + " {} not supported, should be positive numeric".format(param))
raise ValueError(
descr + " {} not supported, should be positive numeric".format(param)
)
@staticmethod
def check_nonnegative_number(param, descr):
if type(param).__name__ not in ["float", "int", "long"] or param < 0:
raise ValueError(descr + " {} not supported, should be non-negative numeric".format(param))
raise ValueError(
descr
+ " {} not supported, should be non-negative numeric".format(param)
)
@staticmethod
def check_decimal_float(param, descr):
if type(param).__name__ not in ["float", "int"] or param < 0 or param > 1:
raise ValueError(descr + " {} not supported, should be a float number in range [0, 1]".format(param))
raise ValueError(
descr
+ " {} not supported, should be a float number in range [0, 1]".format(
param
)
)
@staticmethod
def check_boolean(param, descr):
if type(param).__name__ != "bool":
raise ValueError(descr + " {} not supported, should be bool type".format(param))
raise ValueError(
descr + " {} not supported, should be bool type".format(param)
)
@staticmethod
def check_open_unit_interval(param, descr):
if type(param).__name__ not in ["float"] or param <= 0 or param >= 1:
raise ValueError(descr + " should be a numeric number between 0 and 1 exclusively")
raise ValueError(
descr + " should be a numeric number between 0 and 1 exclusively"
)
@staticmethod
def check_valid_value(param, descr, valid_values):
if param not in valid_values:
raise ValueError(descr + " {} is not supported, it should be in {}".format(param, valid_values))
raise ValueError(
descr
+ " {} is not supported, it should be in {}".format(param, valid_values)
)
@staticmethod
def check_defined_type(param, descr, types):
if type(param).__name__ not in types:
raise ValueError(descr + " {} not supported, should be one of {}".format(param, types))
raise ValueError(
descr + " {} not supported, should be one of {}".format(param, types)
)
@staticmethod
def check_and_change_lower(param, valid_list, descr=""):
if type(param).__name__ != "str":
raise ValueError(descr + " {} not supported, should be one of {}".format(param, valid_list))
raise ValueError(
descr
+ " {} not supported, should be one of {}".format(param, valid_list)
)
lower_param = param.lower()
if lower_param in valid_list:
return lower_param
else:
raise ValueError(descr + " {} not supported, should be one of {}".format(param, valid_list))
raise ValueError(
descr
+ " {} not supported, should be one of {}".format(param, valid_list)
)
@staticmethod
def _greater_equal_than(value, limit):
@ -296,7 +348,11 @@ class ComponentParamBase(ABC):
def _range(value, ranges):
in_range = False
for left_limit, right_limit in ranges:
if left_limit - settings.FLOAT_ZERO <= value <= right_limit + settings.FLOAT_ZERO:
if (
left_limit - settings.FLOAT_ZERO
<= value
<= right_limit + settings.FLOAT_ZERO
):
in_range = True
break
@ -312,17 +368,24 @@ class ComponentParamBase(ABC):
def _warn_deprecated_param(self, param_name, descr):
if self._deprecated_params_set.get(param_name):
logging.warning(f"{descr} {param_name} is deprecated and ignored in this version.")
logging.warning(
f"{descr} {param_name} is deprecated and ignored in this version."
)
def _warn_to_deprecate_param(self, param_name, descr, new_param):
if self._deprecated_params_set.get(param_name):
logging.warning(f"{descr} {param_name} will be deprecated in future release; please use {new_param} instead.")
logging.warning(
f"{descr} {param_name} will be deprecated in future release; "
f"please use {new_param} instead."
)
return True
return False
class ComponentBase(ABC):
component_name: str
thread_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z:0-9_.-]+|sys\.[a-z_]+)\} *\}*"
def __str__(self):
"""
@ -331,232 +394,144 @@ class ComponentBase(ABC):
"params": {}
}
"""
out = getattr(self._param, self._param.output_var_name)
if isinstance(out, pd.DataFrame) and "chunks" in out:
del out["chunks"]
setattr(self._param, self._param.output_var_name, out)
return """{{
"component_name": "{}",
"params": {},
"output": {},
"inputs": {}
}}""".format(
self.component_name,
self._param,
json.dumps(json.loads(str(self._param)).get("output", {}), ensure_ascii=False),
json.dumps(json.loads(str(self._param)).get("inputs", []), ensure_ascii=False),
"params": {}
}}""".format(self.component_name,
self._param
)
def __init__(self, canvas, id, param: ComponentParamBase):
from agent.canvas import Canvas # Local import to avoid cyclic dependency
assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
self._canvas = canvas
self._id = id
self._param = param
self._param.check()
def get_dependent_components(self):
cpnts = set(
[
para["component_id"].split("@")[0]
for para in self._param.query
if para.get("component_id") and para["component_id"].lower().find("answer") < 0 and para["component_id"].lower().find("begin") < 0
]
)
return list(cpnts)
def run(self, history, **kwargs):
logging.debug("{}, history: {}, kwargs: {}".format(self, json.dumps(history, ensure_ascii=False), json.dumps(kwargs, ensure_ascii=False)))
self._param.debug_inputs = []
def invoke(self, **kwargs) -> dict[str, Any]:
self.set_output("_created_time", time.perf_counter())
try:
res = self._run(history, **kwargs)
self.set_output(res)
self._invoke(**kwargs)
except Exception as e:
self.set_output(pd.DataFrame([{"content": str(e)}]))
raise e
self._param.outputs["_ERROR"] = {"value": str(e)}
logging.exception(e)
self._param.debug_inputs = {}
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return self.output()
return res
def _run(self, history, **kwargs):
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
def _invoke(self, **kwargs):
raise NotImplementedError()
def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
o = getattr(self._param, self._param.output_var_name)
if not isinstance(o, partial):
if not isinstance(o, pd.DataFrame):
if isinstance(o, list):
return self._param.output_var_name, pd.DataFrame(o).dropna()
if o is None:
return self._param.output_var_name, pd.DataFrame()
return self._param.output_var_name, pd.DataFrame([{"content": str(o)}])
return self._param.output_var_name, o
def output(self, var_nm: str=None) -> Union[dict[str, Any], Any]:
if var_nm:
return self._param.outputs.get(var_nm, {}).get("value")
return {k: o.get("value") for k,o in self._param.outputs.items()}
if allow_partial or not isinstance(o, partial):
if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
return pd.DataFrame(o if isinstance(o, list) else [o]).dropna()
return self._param.output_var_name, o
def set_output(self, key: str, value: Any):
if key not in self._param.outputs:
self._param.outputs[key] = {"value": None, "type": str(type(value))}
self._param.outputs[key]["value"] = value
outs = None
for oo in o():
if not isinstance(oo, pd.DataFrame):
outs = pd.DataFrame(oo if isinstance(oo, list) else [oo]).dropna()
else:
outs = oo.dropna()
return self._param.output_var_name, outs
def error(self):
return self._param.outputs.get("_ERROR", {}).get("value")
def reset(self):
setattr(self._param, self._param.output_var_name, None)
self._param.inputs = []
for k in self._param.outputs.keys():
self._param.outputs[k]["value"] = None
for k in self._param.inputs.keys():
self._param.inputs[k]["value"] = None
self._param.debug_inputs = {}
def set_output(self, v):
setattr(self._param, self._param.output_var_name, v)
def get_input(self, key: str=None) -> Union[Any, dict[str, Any]]:
if key:
return self._param.inputs.get(key, {}).get("value")
def set_infor(self, v):
setattr(self._param, self._param.infor_var_name, v)
def _fetch_outputs_from(self, sources: list[dict[str, Any]]) -> list[pd.DataFrame]:
outs = []
for q in sources:
if q.get("component_id"):
if "@" in q["component_id"] and q["component_id"].split("@")[0].lower().find("begin") >= 0:
cpn_id, key = q["component_id"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
outs.append(pd.DataFrame([{"content": p.get("value", "")}]))
break
else:
assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue
if q["component_id"].lower().find("answer") == 0:
txt = []
for r, c in self._canvas.history[::-1][: self._param.message_history_window_size][::-1]:
txt.append(f"{r.upper()}:{c}")
txt = "\n".join(txt)
outs.append(pd.DataFrame([{"content": txt}]))
continue
outs.append(self._canvas.get_component(q["component_id"])["obj"].output(allow_partial=False)[1])
elif q.get("value"):
outs.append(pd.DataFrame([{"content": q["value"]}]))
return outs
def get_input(self):
if self._param.debug_inputs:
return pd.DataFrame([{"content": v["value"]} for v in self._param.debug_inputs if v.get("value")])
reversed_cpnts = []
if len(self._canvas.path) > 1:
reversed_cpnts.extend(self._canvas.path[-2])
reversed_cpnts.extend(self._canvas.path[-1])
up_cpns = self.get_upstream()
reversed_up_cpnts = [cpn for cpn in reversed_cpnts if cpn in up_cpns]
if self._param.query:
self._param.inputs = []
outs = self._fetch_outputs_from(self._param.query)
for out in outs:
records = out.to_dict("records")
content: str
if len(records) > 1:
content = "\n".join([str(d["content"]) for d in records])
else:
content = records[0]["content"]
self._param.inputs.append({"component_id": records[0].get("component_id"), "content": content})
if outs:
df = pd.concat(outs, ignore_index=True)
if "content" in df:
df = df.drop_duplicates(subset=["content"]).reset_index(drop=True)
return df
upstream_outs = []
for u in reversed_up_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "concentrator"]:
res = {}
for var, o in self.get_input_elements().items():
v = self.get_param(var)
if v is None:
continue
if self.component_name.lower() == "generate" and self.get_component_name(u) == "retrieval":
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
if o is not None:
o["component_id"] = u
upstream_outs.append(o)
continue
# if self.component_name.lower()!="answer" and u not in self._canvas.get_component(self._id)["upstream"]: continue
if self.component_name.lower().find("switch") < 0 and self.get_component_name(u) in ["relevant", "categorize"]:
continue
if u.lower().find("answer") >= 0:
for r, c in self._canvas.history[::-1]:
if r == "user":
upstream_outs.append(pd.DataFrame([{"content": c, "component_id": u}]))
break
break
if self.component_name.lower().find("answer") >= 0 and self.get_component_name(u) in ["relevant"]:
continue
o = self._canvas.get_component(u)["obj"].output(allow_partial=False)[1]
if o is not None:
o["component_id"] = u
upstream_outs.append(o)
break
assert upstream_outs, "Can't inference the where the component input is. Please identify whose output is this component's input."
df = pd.concat(upstream_outs, ignore_index=True)
if "content" in df:
df = df.drop_duplicates(subset=["content"]).reset_index(drop=True)
self._param.inputs = []
for _, r in df.iterrows():
self._param.inputs.append({"component_id": r["component_id"], "content": r["content"]})
return df
def get_input_elements(self):
assert self._param.query, "Please verify the input parameters first."
eles = []
for q in self._param.query:
if q.get("component_id"):
cpn_id = q["component_id"]
if cpn_id.split("@")[0].lower().find("begin") >= 0:
cpn_id, key = cpn_id.split("@")
eles.extend(self._canvas.get_component(cpn_id)["obj"]._param.query)
continue
eles.append({"name": self._canvas.get_component_name(cpn_id), "key": cpn_id})
if isinstance(v, str) and self._canvas.is_reff(v):
self.set_input_value(var, self._canvas.get_variable_value(v))
else:
eles.append({"key": q["value"], "name": q["value"], "value": q["value"]})
return eles
self.set_input_value(var, v)
res[var] = self.get_input_value(var)
return res
def get_stream_input(self):
reversed_cpnts = []
if len(self._canvas.path) > 1:
reversed_cpnts.extend(self._canvas.path[-2])
reversed_cpnts.extend(self._canvas.path[-1])
up_cpns = self.get_upstream()
reversed_up_cpnts = [cpn for cpn in reversed_cpnts if cpn in up_cpns]
def get_input_values(self) -> Union[Any, dict[str, Any]]:
if self._param.debug_inputs:
return self._param.debug_inputs
for u in reversed_up_cpnts[::-1]:
if self.get_component_name(u) in ["switch", "answer"]:
continue
return self._canvas.get_component(u)["obj"].output()[1]
return {var: self.get_input_value(var) for var, o in self.get_input_elements().items()}
@staticmethod
def be_output(v):
return pd.DataFrame([{"content": v}])
def get_input_elements_from_text(self, txt: str) -> dict[str, dict[str, str]]:
res = {}
for r in re.finditer(self.variable_ref_patt, txt, flags=re.IGNORECASE):
exp = r.group(1)
cpn_id, var_nm = exp.split("@") if exp.find("@")>0 else ("", exp)
res[exp] = {
"name": (self._canvas.get_component_name(cpn_id) +f"@{var_nm}") if cpn_id else exp,
"value": self._canvas.get_variable_value(exp),
"_retrival": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None,
"_cpn_id": cpn_id
}
return res
def get_component_name(self, cpn_id):
def get_input_elements(self) -> dict[str, Any]:
return self._param.inputs
def get_input_form(self) -> dict[str, dict]:
return self._param.get_input_form()
def set_input_value(self, key: str, value: Any) -> None:
if key not in self._param.inputs:
self._param.inputs[key] = {"value": None}
self._param.inputs[key]["value"] = value
def get_input_value(self, key: str) -> Any:
if key not in self._param.inputs:
return None
return self._param.inputs[key].get("value")
def get_component_name(self, cpn_id) -> str:
return self._canvas.get_component(cpn_id)["obj"].component_name.lower()
def debug(self, **kwargs):
return self._run([], **kwargs)
def get_param(self, name):
if hasattr(self._param, name):
return getattr(self._param, name)
def get_parent(self):
pid = self._canvas.get_component(self._id)["parent_id"]
def debug(self, **kwargs):
return self._invoke(**kwargs)
def get_parent(self) -> Union[object, None]:
pid = self._canvas.get_component(self._id).get("parent_id")
if not pid:
return
return self._canvas.get_component(pid)["obj"]
def get_upstream(self):
cpn_nms = self._canvas.get_component(self._id)["upstream"]
def get_upstream(self) -> List[str]:
cpn_nms = self._canvas.get_component(self._id)['upstream']
return cpn_nms
@staticmethod
def string_format(content: str, kv: dict[str, str]) -> str:
for n, v in kv.items():
content = re.sub(
r"\{%s\}" % re.escape(n), re.escape(v), content
)
return content
def exception_handler(self):
if not self._param.exception_method:
return
return {
"goto": self._param.exception_goto,
"comment": self._param.exception_comment,
"default_value": self._param.exception_default_value
}
def get_exception_default_value(self):
return self._param.exception_default_value

View File

@ -13,37 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from functools import partial
import pandas as pd
from agent.component.base import ComponentBase, ComponentParamBase
from agent.component.fillup import UserFillUpParam, UserFillUp
class BeginParam(ComponentParamBase):
class BeginParam(UserFillUpParam):
"""
Define the Begin component parameters.
"""
def __init__(self):
super().__init__()
self.mode = "conversational"
self.prologue = "Hi! I'm your smart assistant. What can I do for you?"
self.query = []
def check(self):
return True
self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task"])
def get_input_form(self) -> dict[str, dict]:
return getattr(self, "inputs")
class Begin(ComponentBase):
class Begin(UserFillUp):
component_name = "Begin"
def _run(self, history, **kwargs):
if kwargs.get("stream"):
return partial(self.stream_output)
return pd.DataFrame([{"content": self._param.prologue}])
def stream_output(self):
res = {"content": self._param.prologue}
yield res
self.set_output(self.be_output(res))
def _invoke(self, **kwargs):
for k, v in kwargs.get("inputs", {}).items():
if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0:
v = self._canvas.get_files([v["value"]])
else:
v = v.get("value")
self.set_output(k, v)
self.set_input_value(k, v)

View File

@ -1,84 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC
import requests
import pandas as pd
from agent.component.base import ComponentBase, ComponentParamBase
class BingParam(ComponentParamBase):
"""
Define the Bing component parameters.
"""
def __init__(self):
super().__init__()
self.top_n = 10
self.channel = "Webpages"
self.api_key = "YOUR_ACCESS_KEY"
self.country = "CN"
self.language = "en"
def check(self):
self.check_positive_integer(self.top_n, "Top N")
self.check_valid_value(self.channel, "Bing Web Search or Bing News", ["Webpages", "News"])
self.check_empty(self.api_key, "Bing subscription key")
self.check_valid_value(self.country, "Bing Country",
['AR', 'AU', 'AT', 'BE', 'BR', 'CA', 'CL', 'DK', 'FI', 'FR', 'DE', 'HK', 'IN', 'ID',
'IT', 'JP', 'KR', 'MY', 'MX', 'NL', 'NZ', 'NO', 'CN', 'PL', 'PT', 'PH', 'RU', 'SA',
'ZA', 'ES', 'SE', 'CH', 'TW', 'TR', 'GB', 'US'])
self.check_valid_value(self.language, "Bing Languages",
['ar', 'eu', 'bn', 'bg', 'ca', 'ns', 'nt', 'hr', 'cs', 'da', 'nl', 'en', 'gb', 'et',
'fi', 'fr', 'gl', 'de', 'gu', 'he', 'hi', 'hu', 'is', 'it', 'jp', 'kn', 'ko', 'lv',
'lt', 'ms', 'ml', 'mr', 'nb', 'pl', 'br', 'pt', 'pa', 'ro', 'ru', 'sr', 'sk', 'sl',
'es', 'sv', 'ta', 'te', 'th', 'tr', 'uk', 'vi'])
class Bing(ComponentBase, ABC):
component_name = "Bing"
def _run(self, history, **kwargs):
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
return Bing.be_output("")
try:
headers = {"Ocp-Apim-Subscription-Key": self._param.api_key, 'Accept-Language': self._param.language}
params = {"q": ans, "textDecorations": True, "textFormat": "HTML", "cc": self._param.country,
"answerCount": 1, "promote": self._param.channel}
if self._param.channel == "Webpages":
response = requests.get("https://api.bing.microsoft.com/v7.0/search", headers=headers, params=params)
response.raise_for_status()
search_results = response.json()
bing_res = [{"content": '<a href="' + i["url"] + '">' + i["name"] + '</a> ' + i["snippet"]} for i in
search_results["webPages"]["value"]]
elif self._param.channel == "News":
response = requests.get("https://api.bing.microsoft.com/v7.0/news/search", headers=headers,
params=params)
response.raise_for_status()
search_results = response.json()
bing_res = [{"content": '<a href="' + i["url"] + '">' + i["name"] + '</a> ' + i["description"]} for i
in search_results['news']['value']]
except Exception as e:
return Bing.be_output("**ERROR**: " + str(e))
if not bing_res:
return Bing.be_output("")
df = pd.DataFrame(bing_res)
logging.debug(f"df: {str(df)}")
return df

View File

@ -14,13 +14,18 @@
# limitations under the License.
#
import logging
import os
import re
from abc import ABC
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from agent.component import GenerateParam, Generate
from agent.component import LLMParam, LLM
from api.utils.api_utils import timeout
from rag.llm.chat_model import ERROR_PREFIX
class CategorizeParam(GenerateParam):
class CategorizeParam(LLMParam):
"""
Define the Categorize component parameters.
@ -28,10 +33,12 @@ class CategorizeParam(GenerateParam):
def __init__(self):
super().__init__()
self.category_description = {}
self.prompt = ""
self.query = "sys.query"
self.message_history_window_size = 1
self.update_prompt()
def check(self):
super().check()
self.check_positive_integer(self.message_history_window_size, "[Categorize] Message window size > 0")
self.check_empty(self.category_description, "[Categorize] Category examples")
for k, v in self.category_description.items():
if not k:
@ -39,76 +46,90 @@ class CategorizeParam(GenerateParam):
if not v.get("to"):
raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
def get_prompt(self, chat_hist):
def get_input_form(self) -> dict[str, dict]:
return {
"query": {
"type": "line",
"name": "Query"
}
}
def update_prompt(self):
cate_lines = []
for c, desc in self.category_description.items():
for line in desc.get("examples", "").split("\n"):
for line in desc.get("examples", []):
if not line:
continue
cate_lines.append("USER: {}\nCategory: {}".format(line, c))
cate_lines.append("USER: \"" + re.sub(r"\n", " ", line, flags=re.DOTALL) + "\""+c)
descriptions = []
for c, desc in self.category_description.items():
if desc.get("description"):
descriptions.append(
"\nCategory: {}\nDescription: {}".format(c, desc["description"]))
"\n------\nCategory: {}\nDescription: {}".format(c, desc["description"]))
self.prompt = """
Role: You're a text classifier.
Task: You need to categorize the users questions into {} categories, namely: {}
self.sys_prompt = """
You are an advanced classification system that categorizes user questions into specific types. Analyze the input question and classify it into ONE of the following categories:
{}
Here's description of each category:
{}
- {}
You could learn from the following examples:
{}
You could learn from the above examples.
Requirements:
- Just mention the category names, no need for any additional words.
---- Real Data ----
USER: {}\n
""".format(
len(self.category_description.keys()),
"/".join(list(self.category_description.keys())),
"\n".join(descriptions),
"\n\n- ".join(cate_lines),
chat_hist
---- Instructions ----
- Consider both explicit mentions and implied context
- Prioritize the most specific applicable category
- Return only the category name without explanations
- Use "Other" only when no other category fits
""".format(
"\n - ".join(list(self.category_description.keys())),
"\n".join(descriptions)
)
return self.prompt
if cate_lines:
self.sys_prompt += """
---- Examples ----
{}
""".format("\n".join(cate_lines))
class Categorize(Generate, ABC):
class Categorize(LLM, ABC):
component_name = "Categorize"
def _run(self, history, **kwargs):
input = self.get_input()
input = " - ".join(input["content"]) if "content" in input else ""
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
def _invoke(self, **kwargs):
msg = self._canvas.get_history(self._param.message_history_window_size)
if not msg:
msg = [{"role": "user", "content": ""}]
if kwargs.get("sys.query"):
msg[-1]["content"] = kwargs["sys.query"]
self.set_input_value("sys.query", kwargs["sys.query"])
else:
msg[-1]["content"] = self._canvas.get_variable_value(self._param.query)
self.set_input_value(self._param.query, msg[-1]["content"])
self._param.update_prompt()
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
self._canvas.set_component_infor(self._id, {"prompt":self._param.get_prompt(input),"messages": [{"role": "user", "content": "\nCategory: "}],"conf": self._param.gen_conf()})
ans = chat_mdl.chat(self._param.get_prompt(input), [{"role": "user", "content": "\nCategory: "}],
self._param.gen_conf())
logging.debug(f"input: {input}, answer: {str(ans)}")
user_prompt = """
---- Real Data ----
{}
""".format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg]))
ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
if ERROR_PREFIX in ans:
raise Exception(ans)
# Count the number of times each category appears in the answer.
category_counts = {}
for c in self._param.category_description.keys():
count = ans.lower().count(c.lower())
category_counts[c] = count
# If a category is found, return the category with the highest count.
cpn_ids = list(self._param.category_description.items())[-1][1]["to"]
max_category = list(self._param.category_description.keys())[0]
if any(category_counts.values()):
max_category = max(category_counts.items(), key=lambda x: x[1])
res = Categorize.be_output(self._param.category_description[max_category[0]]["to"])
self.set_output(res)
return res
max_category = max(category_counts.items(), key=lambda x: x[1])[0]
cpn_ids = self._param.category_description[max_category]["to"]
res = Categorize.be_output(list(self._param.category_description.items())[-1][1]["to"])
self.set_output(res)
return res
def debug(self, **kwargs):
df = self._run([], **kwargs)
cpn_id = df.iloc[0, 0]
return Categorize.be_output(self._canvas.get_component_name(cpn_id))
self.set_output("category_name", max_category)
self.set_output("_next", cpn_ids)

View File

@ -1,152 +0,0 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
from abc import ABC
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field, field_validator
from agent.component.base import ComponentBase, ComponentParamBase
from api import settings
class Language(str, Enum):
PYTHON = "python"
NODEJS = "nodejs"
class CodeExecutionRequest(BaseModel):
code_b64: str = Field(..., description="Base64 encoded code string")
language: Language = Field(default=Language.PYTHON, description="Programming language")
arguments: Optional[dict] = Field(default={}, description="Arguments")
@field_validator("code_b64")
@classmethod
def validate_base64(cls, v: str) -> str:
try:
base64.b64decode(v, validate=True)
return v
except Exception as e:
raise ValueError(f"Invalid base64 encoding: {str(e)}")
@field_validator("language", mode="before")
@classmethod
def normalize_language(cls, v) -> str:
if isinstance(v, str):
low = v.lower()
if low in ("python", "python3"):
return "python"
elif low in ("javascript", "nodejs"):
return "nodejs"
raise ValueError(f"Unsupported language: {v}")
class CodeParam(ComponentParamBase):
"""
Define the code sandbox component parameters.
"""
def __init__(self):
super().__init__()
self.lang = "python"
self.script = ""
self.arguments = []
self.address = f"http://{settings.SANDBOX_HOST}:9385/run"
self.enable_network = True
def check(self):
self.check_valid_value(self.lang, "Support languages", ["python", "python3", "nodejs", "javascript"])
self.check_defined_type(self.enable_network, "Enable network", ["bool"])
class Code(ComponentBase, ABC):
component_name = "Code"
def _run(self, history, **kwargs):
arguments = {}
for input in self._param.arguments:
if "@" in input["component_id"]:
component_id = input["component_id"].split("@")[0]
referred_component_key = input["component_id"].split("@")[1]
referred_component = self._canvas.get_component(component_id)["obj"]
for param in referred_component._param.query:
if param["key"] == referred_component_key:
if "value" in param:
arguments[input["name"]] = param["value"]
else:
referred_component = self._canvas.get_component(input["component_id"])["obj"]
referred_component_name = referred_component.component_name
referred_component_id = referred_component._id
debug_inputs = self._param.debug_inputs
if debug_inputs:
for param in debug_inputs:
if param["key"] == referred_component_id:
if "value" in param and param["name"] == input["name"]:
arguments[input["name"]] = param["value"]
else:
if referred_component_name.lower() == "answer":
arguments[input["name"]] = self._canvas.get_history(1)[0]["content"]
continue
_, out = referred_component.output(allow_partial=False)
if not out.empty:
arguments[input["name"]] = "\n".join(out["content"])
return self._execute_code(
language=self._param.lang,
code=self._param.script,
arguments=arguments,
address=self._param.address,
enable_network=self._param.enable_network,
)
def _execute_code(self, language: str, code: str, arguments: dict, address: str, enable_network: bool):
import requests
try:
code_b64 = self._encode_code(code)
code_req = CodeExecutionRequest(code_b64=code_b64, language=language, arguments=arguments).model_dump()
except Exception as e:
return Code.be_output("**Error**: construct code request error: " + str(e))
try:
resp = requests.post(url=address, json=code_req, timeout=10)
body = resp.json()
if body:
stdout = body.get("stdout")
stderr = body.get("stderr")
return Code.be_output(stdout or stderr)
else:
return Code.be_output("**Error**: There is no response from sanbox")
except Exception as e:
return Code.be_output("**Error**: Internal error in sanbox: " + str(e))
def _encode_code(self, code: str) -> str:
return base64.b64encode(code.encode("utf-8")).decode("utf-8")
def get_input_elements(self):
elements = []
for input in self._param.arguments:
cpn_id = input["component_id"]
elements.append({"key": cpn_id, "name": input["name"]})
return elements
def debug(self, **kwargs):
return self._run([], **kwargs)

View File

@ -1,66 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC
from duckduckgo_search import DDGS
import pandas as pd
from agent.component.base import ComponentBase, ComponentParamBase
class DuckDuckGoParam(ComponentParamBase):
"""
Define the DuckDuckGo component parameters.
"""
def __init__(self):
super().__init__()
self.top_n = 10
self.channel = "text"
def check(self):
self.check_positive_integer(self.top_n, "Top N")
self.check_valid_value(self.channel, "Web Search or News", ["text", "news"])
class DuckDuckGo(ComponentBase, ABC):
component_name = "DuckDuckGo"
def _run(self, history, **kwargs):
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
return DuckDuckGo.be_output("")
try:
if self._param.channel == "text":
with DDGS() as ddgs:
# {'title': '', 'href': '', 'body': ''}
duck_res = [{"content": '<a href="' + i["href"] + '">' + i["title"] + '</a> ' + i["body"]} for i
in ddgs.text(ans, max_results=self._param.top_n)]
elif self._param.channel == "news":
with DDGS() as ddgs:
# {'date': '', 'title': '', 'body': '', 'url': '', 'image': '', 'source': ''}
duck_res = [{"content": '<a href="' + i["url"] + '">' + i["title"] + '</a> ' + i["body"]} for i
in ddgs.news(ans, max_results=self._param.top_n)]
except Exception as e:
return DuckDuckGo.be_output("**ERROR**: " + str(e))
if not duck_res:
return DuckDuckGo.be_output("")
df = pd.DataFrame(duck_res)
logging.debug("df: {df}")
return df

View File

@ -1,141 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
import json
import smtplib
import logging
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.header import Header
from email.utils import formataddr
from agent.component.base import ComponentBase, ComponentParamBase
class EmailParam(ComponentParamBase):
"""
Define the Email component parameters.
"""
def __init__(self):
super().__init__()
# Fixed configuration parameters
self.smtp_server = "" # SMTP server address
self.smtp_port = 465 # SMTP port
self.email = "" # Sender email
self.password = "" # Email authorization code
self.sender_name = "" # Sender name
def check(self):
# Check required parameters
self.check_empty(self.smtp_server, "SMTP Server")
self.check_empty(self.email, "Email")
self.check_empty(self.password, "Password")
self.check_empty(self.sender_name, "Sender Name")
class Email(ComponentBase, ABC):
component_name = "Email"
def _run(self, history, **kwargs):
# Get upstream component output and parse JSON
ans = self.get_input()
content = "".join(ans["content"]) if "content" in ans else ""
if not content:
return Email.be_output("No content to send")
success = False
try:
# Parse JSON string passed from upstream
email_data = json.loads(content)
# Validate required fields
if "to_email" not in email_data:
return Email.be_output("Missing required field: to_email")
# Create email object
msg = MIMEMultipart('alternative')
# Properly handle sender name encoding
msg['From'] = formataddr((str(Header(self._param.sender_name,'utf-8')), self._param.email))
msg['To'] = email_data["to_email"]
if "cc_email" in email_data and email_data["cc_email"]:
msg['Cc'] = email_data["cc_email"]
msg['Subject'] = Header(email_data.get("subject", "No Subject"), 'utf-8').encode()
# Use content from email_data or default content
email_content = email_data.get("content", "No content provided")
# msg.attach(MIMEText(email_content, 'plain', 'utf-8'))
msg.attach(MIMEText(email_content, 'html', 'utf-8'))
# Connect to SMTP server and send
logging.info(f"Connecting to SMTP server {self._param.smtp_server}:{self._param.smtp_port}")
context = smtplib.ssl.create_default_context()
with smtplib.SMTP(self._param.smtp_server, self._param.smtp_port) as server:
server.ehlo()
server.starttls(context=context)
server.ehlo()
# Login
logging.info(f"Attempting to login with email: {self._param.email}")
server.login(self._param.email, self._param.password)
# Get all recipient list
recipients = [email_data["to_email"]]
if "cc_email" in email_data and email_data["cc_email"]:
recipients.extend(email_data["cc_email"].split(','))
# Send email
logging.info(f"Sending email to recipients: {recipients}")
try:
server.send_message(msg, self._param.email, recipients)
success = True
except Exception as e:
logging.error(f"Error during send_message: {str(e)}")
# Try alternative method
server.sendmail(self._param.email, recipients, msg.as_string())
success = True
try:
server.quit()
except Exception as e:
# Ignore errors when closing connection
logging.warning(f"Non-fatal error during connection close: {str(e)}")
if success:
return Email.be_output("Email sent successfully")
except json.JSONDecodeError:
error_msg = "Invalid JSON format in input"
logging.error(error_msg)
return Email.be_output(error_msg)
except smtplib.SMTPAuthenticationError:
error_msg = "SMTP Authentication failed. Please check your email and authorization code."
logging.error(error_msg)
return Email.be_output(f"Failed to send email: {error_msg}")
except smtplib.SMTPConnectError:
error_msg = f"Failed to connect to SMTP server {self._param.smtp_server}:{self._param.smtp_port}"
logging.error(error_msg)
return Email.be_output(f"Failed to send email: {error_msg}")
except smtplib.SMTPException as e:
error_msg = f"SMTP error occurred: {str(e)}"
logging.error(error_msg)
return Email.be_output(f"Failed to send email: {error_msg}")
except Exception as e:
error_msg = f"Unexpected error: {str(e)}"
logging.error(error_msg)
return Email.be_output(f"Failed to send email: {error_msg}")

View File

@ -1,155 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
import re
from copy import deepcopy
import pandas as pd
import pymysql
import psycopg2
from agent.component import GenerateParam, Generate
import pyodbc
import logging
class ExeSQLParam(GenerateParam):
"""
Define the ExeSQL component parameters.
"""
def __init__(self):
super().__init__()
self.db_type = "mysql"
self.database = ""
self.username = ""
self.host = ""
self.port = 3306
self.password = ""
self.loop = 3
self.top_n = 30
def check(self):
super().check()
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb', 'mssql'])
self.check_empty(self.database, "Database name")
self.check_empty(self.username, "database username")
self.check_empty(self.host, "IP Address")
self.check_positive_integer(self.port, "IP Port")
self.check_empty(self.password, "Database password")
self.check_positive_integer(self.top_n, "Number of records")
if self.database == "rag_flow":
if self.host == "ragflow-mysql":
raise ValueError("For the security reason, it dose not support database named rag_flow.")
if self.password == "infini_rag_flow":
raise ValueError("For the security reason, it dose not support database named rag_flow.")
class ExeSQL(Generate, ABC):
component_name = "ExeSQL"
def _refactor(self, ans):
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
match = re.search(r"```sql\s*(.*?)\s*```", ans, re.DOTALL)
if match:
ans = match.group(1) # Query content
return ans
else:
print("no markdown")
ans = re.sub(r'^.*?SELECT ', 'SELECT ', (ans), flags=re.IGNORECASE)
ans = re.sub(r';.*?SELECT ', '; SELECT ', ans, flags=re.IGNORECASE)
ans = re.sub(r';[^;]*$', r';', ans)
if not ans:
raise Exception("SQL statement not found!")
return ans
def _run(self, history, **kwargs):
ans = self.get_input()
ans = "".join([str(a) for a in ans["content"]]) if "content" in ans else ""
ans = self._refactor(ans)
if self._param.db_type in ["mysql", "mariadb"]:
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)
elif self._param.db_type == 'postgresql':
db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)
elif self._param.db_type == 'mssql':
conn_str = (
r'DRIVER={ODBC Driver 17 for SQL Server};'
r'SERVER=' + self._param.host + ',' + str(self._param.port) + ';'
r'DATABASE=' + self._param.database + ';'
r'UID=' + self._param.username + ';'
r'PWD=' + self._param.password
)
db = pyodbc.connect(conn_str)
try:
cursor = db.cursor()
except Exception as e:
raise Exception("Database Connection Failed! \n" + str(e))
if not hasattr(self, "_loop"):
setattr(self, "_loop", 0)
self._loop += 1
input_list = re.split(r';', ans.replace(r"\n", " "))
sql_res = []
for i in range(len(input_list)):
single_sql = input_list[i]
single_sql = single_sql.replace('```','')
while self._loop <= self._param.loop:
self._loop += 1
if not single_sql:
break
try:
cursor.execute(single_sql)
if cursor.rowcount == 0:
sql_res.append({"content": "No record in the database!"})
break
if self._param.db_type == 'mssql':
single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.top_n),
columns=[desc[0] for desc in cursor.description])
else:
single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.top_n)])
single_res.columns = [i[0] for i in cursor.description]
sql_res.append({"content": single_res.to_markdown(index=False, floatfmt=".6f")})
break
except Exception as e:
single_sql = self._regenerate_sql(single_sql, str(e), **kwargs)
single_sql = self._refactor(single_sql)
if self._loop > self._param.loop:
sql_res.append({"content": "Can't query the correct data via SQL statement."})
db.close()
if not sql_res:
return ExeSQL.be_output("")
return pd.DataFrame(sql_res)
def _regenerate_sql(self, failed_sql, error_message, **kwargs):
prompt = f'''
## You are the Repair SQL Statement Helper, please modify the original SQL statement based on the SQL query error report.
## The original SQL statement is as follows:{failed_sql}.
## The contents of the SQL query error report is as follows:{error_message}.
## Answer only the modified SQL statement. Please do not give any explanation, just answer the code.
'''
self._param.prompt = prompt
kwargs_ = deepcopy(kwargs)
kwargs_["stream"] = False
response = Generate._run(self, [], **kwargs_)
try:
regenerated_sql = response.loc[0, "content"]
return regenerated_sql
except Exception as e:
logging.error(f"Failed to regenerate SQL: {e}")
return None
def debug(self, **kwargs):
return self._run([], **kwargs)

View File

@ -13,24 +13,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
from agent.component.base import ComponentBase, ComponentParamBase
class ConcentratorParam(ComponentParamBase):
"""
Define the Concentrator component parameters.
"""
class UserFillUpParam(ComponentParamBase):
def __init__(self):
super().__init__()
self.enable_tips = True
self.tips = "Please fill up the form"
def check(self):
def check(self) -> bool:
return True
class Concentrator(ComponentBase, ABC):
component_name = "Concentrator"
class UserFillUp(ComponentBase):
component_name = "UserFillUp"
def _invoke(self, **kwargs):
for k, v in kwargs.get("inputs", {}).items():
self.set_output(k, v)
def _run(self, history, **kwargs):
return Concentrator.be_output("")

View File

@ -1,276 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import re
from functools import partial
from typing import Any
import pandas as pd
from api.db import LLMType
from api.db.services.conversation_service import structure_answer
from api.db.services.llm_service import LLMBundle
from api import settings
from agent.component.base import ComponentBase, ComponentParamBase
from plugin import GlobalPluginManager
from plugin.llm_tool_plugin import llm_tool_metadata_to_openai_tool
from rag.llm.chat_model import ToolCallSession
from rag.prompts import message_fit_in
class LLMToolPluginCallSession(ToolCallSession):
def tool_call(self, name: str, arguments: dict[str, Any]) -> str:
tool = GlobalPluginManager.get_llm_tool_by_name(name)
if tool is None:
raise ValueError(f"LLM tool {name} does not exist")
return tool().invoke(**arguments)
class GenerateParam(ComponentParamBase):
"""
Define the Generate component parameters.
"""
def __init__(self):
super().__init__()
self.llm_id = ""
self.prompt = ""
self.max_tokens = 0
self.temperature = 0
self.top_p = 0
self.presence_penalty = 0
self.frequency_penalty = 0
self.cite = True
self.parameters = []
self.llm_enabled_tools = []
def check(self):
self.check_decimal_float(self.temperature, "[Generate] Temperature")
self.check_decimal_float(self.presence_penalty, "[Generate] Presence penalty")
self.check_decimal_float(self.frequency_penalty, "[Generate] Frequency penalty")
self.check_nonnegative_number(self.max_tokens, "[Generate] Max tokens")
self.check_decimal_float(self.top_p, "[Generate] Top P")
self.check_empty(self.llm_id, "[Generate] LLM")
# self.check_defined_type(self.parameters, "Parameters", ["list"])
def gen_conf(self):
conf = {}
if self.max_tokens > 0:
conf["max_tokens"] = self.max_tokens
if self.temperature > 0:
conf["temperature"] = self.temperature
if self.top_p > 0:
conf["top_p"] = self.top_p
if self.presence_penalty > 0:
conf["presence_penalty"] = self.presence_penalty
if self.frequency_penalty > 0:
conf["frequency_penalty"] = self.frequency_penalty
return conf
class Generate(ComponentBase):
component_name = "Generate"
def get_dependent_components(self):
inputs = self.get_input_elements()
cpnts = set([i["key"] for i in inputs[1:] if i["key"].lower().find("answer") < 0 and i["key"].lower().find("begin") < 0])
return list(cpnts)
def set_cite(self, retrieval_res, answer):
if "empty_response" in retrieval_res.columns:
retrieval_res["empty_response"].fillna("", inplace=True)
chunks = json.loads(retrieval_res["chunks"][0])
answer, idx = settings.retrievaler.insert_citations(answer,
[ck["content_ltks"] for ck in chunks],
[ck["vector"] for ck in chunks],
LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
self._canvas.get_embedding_model()), tkweight=0.7,
vtweight=0.3)
doc_ids = set([])
recall_docs = []
for i in idx:
did = chunks[int(i)]["doc_id"]
if did in doc_ids:
continue
doc_ids.add(did)
recall_docs.append({"doc_id": did, "doc_name": chunks[int(i)]["docnm_kwd"]})
for c in chunks:
del c["vector"]
del c["content_ltks"]
reference = {
"chunks": chunks,
"doc_aggs": recall_docs
}
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
res = {"content": answer, "reference": reference}
res = structure_answer(None, res, "", "")
return res
def get_input_elements(self):
key_set = set([])
res = [{"key": "user", "name": "Input your question here:"}]
for r in re.finditer(r"\{([a-z]+[:@][a-z0-9_-]+)\}", self._param.prompt, flags=re.IGNORECASE):
cpn_id = r.group(1)
if cpn_id in key_set:
continue
if cpn_id.lower().find("begin@") == 0:
cpn_id, key = cpn_id.split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] != key:
continue
res.append({"key": r.group(1), "name": p["name"]})
key_set.add(r.group(1))
continue
cpn_nm = self._canvas.get_component_name(cpn_id)
if not cpn_nm:
continue
res.append({"key": cpn_id, "name": cpn_nm})
key_set.add(cpn_id)
return res
def _run(self, history, **kwargs):
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
if len(self._param.llm_enabled_tools) > 0:
tools = GlobalPluginManager.get_llm_tools_by_names(self._param.llm_enabled_tools)
chat_mdl.bind_tools(
LLMToolPluginCallSession(),
[llm_tool_metadata_to_openai_tool(t.get_metadata()) for t in tools]
)
prompt = self._param.prompt
retrieval_res = []
self._param.inputs = []
for para in self.get_input_elements()[1:]:
if para["key"].lower().find("begin@") == 0:
cpn_id, key = para["key"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
kwargs[para["key"]] = p.get("value", "")
self._param.inputs.append(
{"component_id": para["key"], "content": kwargs[para["key"]]})
break
else:
assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue
component_id = para["key"]
cpn = self._canvas.get_component(component_id)["obj"]
if cpn.component_name.lower() == "answer":
hist = self._canvas.get_history(1)
if hist:
hist = hist[0]["content"]
else:
hist = ""
kwargs[para["key"]] = hist
continue
_, out = cpn.output(allow_partial=False)
if "content" not in out.columns:
kwargs[para["key"]] = ""
else:
if cpn.component_name.lower() == "retrieval":
retrieval_res.append(out)
kwargs[para["key"]] = " - " + "\n - ".join([o if isinstance(o, str) else str(o) for o in out["content"]])
self._param.inputs.append({"component_id": para["key"], "content": kwargs[para["key"]]})
if retrieval_res:
retrieval_res = pd.concat(retrieval_res, ignore_index=True)
else:
retrieval_res = pd.DataFrame([])
for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)
if not self._param.inputs and prompt.find("{input}") >= 0:
retrieval_res = self.get_input()
input = (" - " + "\n - ".join(
[c for c in retrieval_res["content"] if isinstance(c, str)])) if "content" in retrieval_res else ""
prompt = re.sub(r"\{input\}", re.escape(input), prompt)
downstreams = self._canvas.get_component(self._id)["downstream"]
if kwargs.get("stream") and len(downstreams) == 1 and self._canvas.get_component(downstreams[0])[
"obj"].component_name.lower() == "answer":
return partial(self.stream_output, chat_mdl, prompt, retrieval_res)
if "empty_response" in retrieval_res.columns and not "".join(retrieval_res["content"]):
empty_res = "\n- ".join([str(t) for t in retrieval_res["empty_response"] if str(t)])
res = {"content": empty_res if empty_res else "Nothing found in knowledgebase!", "reference": []}
return pd.DataFrame([res])
msg = self._canvas.get_history(self._param.message_history_window_size)
if len(msg) < 1:
msg.append({"role": "user", "content": "Output: "})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2:
msg.append({"role": "user", "content": "Output: "})
ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
self._canvas.set_component_infor(self._id, {"prompt":msg[0]["content"],"messages": msg[1:],"conf": self._param.gen_conf()})
if self._param.cite and "chunks" in retrieval_res.columns:
res = self.set_cite(retrieval_res, ans)
return pd.DataFrame([res])
return Generate.be_output(ans)
def stream_output(self, chat_mdl, prompt, retrieval_res):
res = None
if "empty_response" in retrieval_res.columns and not "".join(retrieval_res["content"]):
empty_res = "\n- ".join([str(t) for t in retrieval_res["empty_response"] if str(t)])
res = {"content": empty_res if empty_res else "Nothing found in knowledgebase!", "reference": []}
yield res
self.set_output(res)
return
msg = self._canvas.get_history(self._param.message_history_window_size)
if msg and msg[0]['role'] == 'assistant':
msg.pop(0)
if len(msg) < 1:
msg.append({"role": "user", "content": "Output: "})
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
if len(msg) < 2:
msg.append({"role": "user", "content": "Output: "})
answer = ""
for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
res = {"content": ans, "reference": []}
answer = ans
yield res
if self._param.cite and "chunks" in retrieval_res.columns:
res = self.set_cite(retrieval_res, answer)
yield res
self._canvas.set_component_infor(self._id, {"prompt":msg[0]["content"],"messages": msg[1:],"conf": self._param.gen_conf()})
self.set_output(Generate.be_output(res))
def debug(self, **kwargs):
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
prompt = self._param.prompt
for para in self._param.debug_inputs:
kwargs[para["key"]] = para.get("value", "")
for n, v in kwargs.items():
prompt = re.sub(r"\{%s\}" % re.escape(n), str(v).replace("\\", " "), prompt)
u = kwargs.get("user")
ans = chat_mdl.chat(prompt, [{"role": "user", "content": u if u else "Output: "}], self._param.gen_conf())
return pd.DataFrame([ans])

View File

@ -1,61 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC
import pandas as pd
import requests
from agent.component.base import ComponentBase, ComponentParamBase
class GitHubParam(ComponentParamBase):
"""
Define the GitHub component parameters.
"""
def __init__(self):
super().__init__()
self.top_n = 10
def check(self):
self.check_positive_integer(self.top_n, "Top N")
class GitHub(ComponentBase, ABC):
component_name = "GitHub"
def _run(self, history, **kwargs):
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
return GitHub.be_output("")
try:
url = 'https://api.github.com/search/repositories?q=' + ans + '&sort=stars&order=desc&per_page=' + str(
self._param.top_n)
headers = {"Content-Type": "application/vnd.github+json", "X-GitHub-Api-Version": '2022-11-28'}
response = requests.get(url=url, headers=headers).json()
github_res = [{"content": '<a href="' + i["html_url"] + '">' + i["name"] + '</a>' + str(
i["description"]) + '\n stars:' + str(i['watchers'])} for i in response['items']]
except Exception as e:
return GitHub.be_output("**ERROR**: " + str(e))
if not github_res:
return GitHub.be_output("")
df = pd.DataFrame(github_res)
logging.debug(f"df: {df}")
return df

View File

@ -1,70 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC
import pandas as pd
from agent.component.base import ComponentBase, ComponentParamBase
from scholarly import scholarly
class GoogleScholarParam(ComponentParamBase):
"""
Define the GoogleScholar component parameters.
"""
def __init__(self):
super().__init__()
self.top_n = 6
self.sort_by = 'relevance'
self.year_low = None
self.year_high = None
self.patents = True
def check(self):
self.check_positive_integer(self.top_n, "Top N")
self.check_valid_value(self.sort_by, "GoogleScholar Sort_by", ['date', 'relevance'])
self.check_boolean(self.patents, "Whether or not to include patents, defaults to True")
class GoogleScholar(ComponentBase, ABC):
component_name = "GoogleScholar"
def _run(self, history, **kwargs):
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
return GoogleScholar.be_output("")
scholar_client = scholarly.search_pubs(ans, patents=self._param.patents, year_low=self._param.year_low,
year_high=self._param.year_high, sort_by=self._param.sort_by)
scholar_res = []
for i in range(self._param.top_n):
try:
pub = next(scholar_client)
scholar_res.append({"content": 'Title: ' + pub['bib']['title'] + '\n_Url: <a href="' + pub[
'pub_url'] + '"></a> ' + "\n author: " + ",".join(pub['bib']['author']) + '\n Abstract: ' + pub[
'bib'].get('abstract', 'no abstract')})
except StopIteration or Exception:
logging.exception("GoogleScholar")
break
if not scholar_res:
return GoogleScholar.be_output("")
df = pd.DataFrame(scholar_res)
logging.debug(f"df: {df}")
return df

View File

@ -14,9 +14,14 @@
# limitations under the License.
#
import json
import logging
import os
import re
import time
from abc import ABC
import requests
from api.utils.api_utils import timeout
from deepdoc.parser import HtmlParser
from agent.component.base import ComponentBase, ComponentParamBase
@ -48,28 +53,14 @@ class InvokeParam(ComponentParamBase):
class Invoke(ComponentBase, ABC):
component_name = "Invoke"
def _run(self, history, **kwargs):
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))
def _invoke(self, **kwargs):
args = {}
for para in self._param.variables:
if para.get("component_id"):
if '@' in para["component_id"]:
component = para["component_id"].split('@')[0]
field = para["component_id"].split('@')[1]
cpn = self._canvas.get_component(component)["obj"]
for param in cpn._param.query:
if param["key"] == field:
if "value" in param:
args[para["key"]] = param["value"]
else:
cpn = self._canvas.get_component(para["component_id"])["obj"]
if cpn.component_name.lower() == "answer":
args[para["key"]] = self._canvas.get_history(1)[0]["content"]
continue
_, out = cpn.output(allow_partial=False)
if not out.empty:
args[para["key"]] = "\n".join(out["content"])
else:
if para.get("value") is not None:
args[para["key"]] = para["value"]
else:
args[para["key"]] = self._canvas.get_variable_value(para["ref"])
url = self._param.url.strip()
if url.find("http") != 0:
@ -83,50 +74,66 @@ class Invoke(ComponentBase, ABC):
if re.sub(r"https?:?/?/?", "", self._param.proxy):
proxies = {"http": self._param.proxy, "https": self._param.proxy}
if method == 'get':
response = requests.get(url=url,
params=args,
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
if self._param.clean_html:
sections = HtmlParser()(None, response.content)
return Invoke.be_output("\n".join(sections))
last_e = ""
for _ in range(self._param.max_retries+1):
try:
if method == 'get':
response = requests.get(url=url,
params=args,
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
if self._param.clean_html:
sections = HtmlParser()(None, response.content)
self.set_output("result", "\n".join(sections))
else:
self.set_output("result", response.text)
return Invoke.be_output(response.text)
if method == 'put':
if self._param.datatype.lower() == 'json':
response = requests.put(url=url,
json=args,
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
else:
response = requests.put(url=url,
data=args,
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
if self._param.clean_html:
sections = HtmlParser()(None, response.content)
self.set_output("result", "\n".join(sections))
else:
self.set_output("result", response.text)
if method == 'put':
if self._param.datatype.lower() == 'json':
response = requests.put(url=url,
json=args,
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
else:
response = requests.put(url=url,
data=args,
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
if self._param.clean_html:
sections = HtmlParser()(None, response.content)
return Invoke.be_output("\n".join(sections))
return Invoke.be_output(response.text)
if method == 'post':
if self._param.datatype.lower() == 'json':
response = requests.post(url=url,
json=args,
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
else:
response = requests.post(url=url,
data=args,
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
if self._param.clean_html:
self.set_output("result", "\n".join(sections))
else:
self.set_output("result", response.text)
if method == 'post':
if self._param.datatype.lower() == 'json':
response = requests.post(url=url,
json=args,
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
else:
response = requests.post(url=url,
data=args,
headers=headers,
proxies=proxies,
timeout=self._param.timeout)
if self._param.clean_html:
sections = HtmlParser()(None, response.content)
return Invoke.be_output("\n".join(sections))
return Invoke.be_output(response.text)
return self.output("result")
except Exception as e:
last_e = e
logging.exception(f"Http request error: {e}")
time.sleep(self._param.delay_after_error)
if last_e:
self.set_output("_ERROR", str(last_e))
return f"Http request error: {last_e}"
assert False, self.output()

View File

@ -24,10 +24,18 @@ class IterationParam(ComponentParamBase):
def __init__(self):
super().__init__()
self.delimiter = ","
self.items_ref = ""
def get_input_form(self) -> dict[str, dict]:
return {
"items": {
"type": "json",
"name": "Items"
}
}
def check(self):
self.check_empty(self.delimiter, "Delimiter")
return True
class Iteration(ComponentBase, ABC):
@ -38,8 +46,13 @@ class Iteration(ComponentBase, ABC):
if self._canvas.get_component(cid)["obj"].component_name.lower() != "iterationitem":
continue
if self._canvas.get_component(cid)["parent_id"] == self._id:
return self._canvas.get_component(cid)
return cid
def _invoke(self, **kwargs):
arr = self._canvas.get_variable_value(self._param.items_ref)
if not isinstance(arr, list):
self.set_output("_ERROR", self._param.items_ref + " must be an array, but its type is "+str(type(arr)))
def _run(self, history, **kwargs):
return self.output(allow_partial=False)[1]

View File

@ -14,7 +14,6 @@
# limitations under the License.
#
from abc import ABC
import pandas as pd
from agent.component.base import ComponentBase, ComponentParamBase
@ -33,20 +32,49 @@ class IterationItem(ComponentBase, ABC):
super().__init__(canvas, id, param)
self._idx = 0
def _run(self, history, **kwargs):
def _invoke(self, **kwargs):
parent = self.get_parent()
ans = parent.get_input()
ans = parent._param.delimiter.join(ans["content"]) if "content" in ans else ""
ans = [a.strip() for a in ans.split(parent._param.delimiter)]
if not ans:
arr = self._canvas.get_variable_value(parent._param.items_ref)
if not isinstance(arr, list):
self._idx = -1
return pd.DataFrame()
raise Exception(parent._param.items_ref + " must be an array, but its type is "+str(type(arr)))
df = pd.DataFrame([{"content": ans[self._idx]}])
self._idx += 1
if self._idx >= len(ans):
if self._idx > 0:
self.output_collation()
if self._idx >= len(arr):
self._idx = -1
return df
return
self.set_output("item", arr[self._idx])
self.set_output("index", self._idx)
self._idx += 1
def output_collation(self):
pid = self.get_parent()._id
for cid in self._canvas.components.keys():
obj = self._canvas.get_component_obj(cid)
p = obj.get_parent()
if not p:
continue
if p._id != pid:
continue
if p.component_name.lower() in ["categorize", "message", "switch", "userfillup", "interationitem"]:
continue
for k, o in p._param.outputs.items():
if "ref" not in o:
continue
_cid, var = o["ref"].split("@")
if _cid != cid:
continue
res = p.output(k)
if not res:
res = []
res.append(obj.output(var))
p.set_output(k, res)
def end(self):
return self._idx == -1

View File

@ -1,72 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
from abc import ABC
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from agent.component import GenerateParam, Generate
class KeywordExtractParam(GenerateParam):
"""
Define the KeywordExtract component parameters.
"""
def __init__(self):
super().__init__()
self.top_n = 1
def check(self):
super().check()
self.check_positive_integer(self.top_n, "Top N")
def get_prompt(self):
self.prompt = """
- Role: You're a question analyzer.
- Requirements:
- Summarize user's question, and give top %s important keyword/phrase.
- Use comma as a delimiter to separate keywords/phrases.
- Answer format: (in language of user's question)
- keyword:
""" % self.top_n
return self.prompt
class KeywordExtract(Generate, ABC):
component_name = "KeywordExtract"
def _run(self, history, **kwargs):
query = self.get_input()
if hasattr(query, "to_dict") and "content" in query:
query = ", ".join(map(str, query["content"].dropna()))
else:
query = str(query)
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
self._canvas.set_component_infor(self._id, {"prompt":self._param.get_prompt(),"messages": [{"role": "user", "content": query}],"conf": self._param.gen_conf()})
ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": query}],
self._param.gen_conf())
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
ans = re.sub(r".*keyword:", "", ans).strip()
logging.debug(f"ans: {ans}")
return KeywordExtract.be_output(ans)
def debug(self, **kwargs):
return self._run([], **kwargs)

242
agent/component/llm.py Normal file
View File

@ -0,0 +1,242 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import os
import re
from typing import Any
import json_repair
from copy import deepcopy
from functools import partial
from api.db.services.llm_service import LLMBundle, TenantLLMService
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
from rag.prompts import message_fit_in, citation_prompt
from rag.prompts.prompts import tool_call_summary
class LLMParam(ComponentParamBase):
"""
Define the LLM component parameters.
"""
def __init__(self):
super().__init__()
self.llm_id = ""
self.sys_prompt = ""
self.prompts = [{"role": "user", "content": "{sys.query}"}]
self.max_tokens = 0
self.temperature = 0
self.top_p = 0
self.presence_penalty = 0
self.frequency_penalty = 0
self.output_structure = None
self.cite = True
self.visual_files_var = None
def check(self):
self.check_decimal_float(self.temperature, "[Agent] Temperature")
self.check_decimal_float(self.presence_penalty, "[Agent] Presence penalty")
self.check_decimal_float(self.frequency_penalty, "[Agent] Frequency penalty")
self.check_nonnegative_number(self.max_tokens, "[Agent] Max tokens")
self.check_decimal_float(self.top_p, "[Agent] Top P")
self.check_empty(self.llm_id, "[Agent] LLM")
self.check_empty(self.sys_prompt, "[Agent] System prompt")
self.check_empty(self.prompts, "[Agent] User prompt")
def gen_conf(self):
conf = {}
if self.max_tokens > 0:
conf["max_tokens"] = self.max_tokens
if self.temperature > 0:
conf["temperature"] = self.temperature
if self.top_p > 0:
conf["top_p"] = self.top_p
if self.presence_penalty > 0:
conf["presence_penalty"] = self.presence_penalty
if self.frequency_penalty > 0:
conf["frequency_penalty"] = self.frequency_penalty
return conf
class LLM(ComponentBase):
component_name = "LLM"
def __init__(self, canvas, id, param: ComponentParamBase):
super().__init__(canvas, id, param)
self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id),
self._param.llm_id, max_retries=self._param.max_retries,
retry_interval=self._param.delay_after_error
)
self.imgs = []
def get_input_form(self) -> dict[str, dict]:
res = {}
for k, v in self.get_input_elements().items():
res[k] = {
"type": "line",
"name": v["name"]
}
return res
def get_input_elements(self) -> dict[str, Any]:
res = self.get_input_elements_from_text(self._param.sys_prompt)
for prompt in self._param.prompts:
d = self.get_input_elements_from_text(prompt["content"])
res.update(d)
return res
def set_debug_inputs(self, inputs: dict[str, dict]):
self._param.debug_inputs = inputs
def add2system_prompt(self, txt):
self._param.sys_prompt += txt
def _prepare_prompt_variables(self):
if self._param.visual_files_var:
self.imgs = self._canvas.get_variable_value(self._param.visual_files_var)
if not self.imgs:
self.imgs = []
self.imgs = [img for img in self.imgs if img[:len("data:image/")] == "data:image/"]
args = {}
vars = self.get_input_elements() if not self._param.debug_inputs else self._param.debug_inputs
prompt = self._param.sys_prompt
for k, o in vars.items():
args[k] = o["value"]
if not isinstance(args[k], str):
try:
args[k] = json.dumps(args[k], ensure_ascii=False)
except Exception:
args[k] = str(args[k])
self.set_input_value(k, args[k])
msg = self._canvas.get_history(self._param.message_history_window_size)[:-1]
msg.extend(deepcopy(self._param.prompts))
prompt = self.string_format(prompt, args)
for m in msg:
m["content"] = self.string_format(m["content"], args)
if self._canvas.get_reference()["chunks"]:
prompt += citation_prompt()
return prompt, msg
def _generate(self, msg:list[dict], **kwargs) -> str:
if not self.imgs:
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
def _generate_streamly(self, msg:list[dict], **kwargs) -> str:
ans = ""
last_idx = 0
endswith_think = False
def delta(txt):
nonlocal ans, last_idx, endswith_think
delta_ans = txt[last_idx:]
ans = txt
if delta_ans.find("<think>") == 0:
last_idx += len("<think>")
return "<think>"
elif delta_ans.find("<think>") > 0:
delta_ans = txt[last_idx:last_idx+delta_ans.find("<think>")]
last_idx += delta_ans.find("<think>")
return delta_ans
elif delta_ans.endswith("</think>"):
endswith_think = True
elif endswith_think:
endswith_think = False
return "</think>"
last_idx = len(ans)
if ans.endswith("</think>"):
last_idx -= len("</think>")
return re.sub(r"(<think>|</think>)", "", delta_ans)
if not self.imgs:
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs):
yield delta(txt)
else:
for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
yield delta(txt)
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
def _invoke(self, **kwargs):
def clean_formated_answer(ans: str) -> str:
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL)
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
prompt, msg = self._prepare_prompt_variables()
error = ""
if self._param.output_structure:
prompt += "\nThe output MUST follow this JSON format:\n"+json.dumps(self._param.output_structure, ensure_ascii=False, indent=2)
prompt += "\nRedundant information is FORBIDDEN."
for _ in range(self._param.max_retries+1):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
error = ""
ans = self._generate(msg)
msg.pop(0)
if ans.find("**ERROR**") >= 0:
logging.error(f"LLM response error: {ans}")
error = ans
continue
try:
self.set_output("structured_content", json_repair.loads(clean_formated_answer(ans)))
return
except Exception:
msg.append({"role": "user", "content": "The answer can't not be parsed as JSON"})
error = "The answer can't not be parsed as JSON"
if error:
self.set_output("_ERROR", error)
return
downstreams = self._canvas.get_component(self._id)["downstream"] if self._canvas.get_component(self._id) else []
if any([self._canvas.get_component_obj(cid).component_name.lower()=="message" for cid in downstreams]) and not self._param.output_structure:
self.set_output("content", partial(self._stream_output, prompt, msg))
return
for _ in range(self._param.max_retries+1):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
error = ""
ans = self._generate(msg)
msg.pop(0)
if ans.find("**ERROR**") >= 0:
logging.error(f"LLM response error: {ans}")
error = ans
continue
self.set_output("content", ans)
break
if error:
self.set_output("_ERROR", error)
if self.get_exception_default_value():
self.set_output("content", self.get_exception_default_value())
def _stream_output(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
answer = ""
for ans in self._generate_streamly(msg):
yield ans
answer += ans
self.set_output("content", answer)
def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str):
summ = tool_call_summary(self.chat_mdl, func_name, params, results)
logging.info(f"[MEMORY]: {summ}")
self._canvas.add_memory(user, assist, summ)

View File

@ -13,43 +13,132 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import os
import random
from abc import ABC
import re
from functools import partial
from typing import Any
from agent.component.base import ComponentBase, ComponentParamBase
from jinja2 import Template as Jinja2Template
from api.utils.api_utils import timeout
class MessageParam(ComponentParamBase):
"""
Define the Message component parameters.
"""
def __init__(self):
super().__init__()
self.messages = []
self.content = []
self.stream = True
self.outputs = {
"content": {
"type": "str"
}
}
def check(self):
self.check_empty(self.messages, "[Message]")
self.check_empty(self.content, "[Message] Content")
self.check_boolean(self.stream, "[Message] stream")
return True
class Message(ComponentBase, ABC):
class Message(ComponentBase):
component_name = "Message"
def _run(self, history, **kwargs):
if kwargs.get("stream"):
return partial(self.stream_output)
def get_kwargs(self, script:str, kwargs:dict = {}, delimeter:str=None) -> tuple[str, dict[str, str | list | Any]]:
for k,v in self.get_input_elements_from_text(script).items():
if k in kwargs:
continue
v = v["value"]
ans = ""
if isinstance(v, partial):
for t in v():
ans += t
elif isinstance(v, list) and delimeter:
ans = delimeter.join([str(vv) for vv in v])
elif not isinstance(v, str):
try:
ans = json.dumps(v, ensure_ascii=False)
except Exception:
pass
else:
ans = v
if not ans:
ans = ""
kwargs[k] = ans
self.set_input_value(k, ans)
res = Message.be_output(random.choice(self._param.messages))
self.set_output(res)
return res
_kwargs = {}
for n, v in kwargs.items():
_n = re.sub("[@:.]", "_", n)
script = re.sub(r"\{%s\}" % re.escape(n), _n, script)
_kwargs[_n] = v
return script, _kwargs
def stream_output(self):
res = None
if self._param.messages:
res = {"content": random.choice(self._param.messages)}
yield res
def _stream(self, rand_cnt:str):
s = 0
all_content = ""
cache = {}
for r in re.finditer(self.variable_ref_patt, rand_cnt, flags=re.DOTALL):
all_content += rand_cnt[s: r.start()]
yield rand_cnt[s: r.start()]
s = r.end()
exp = r.group(1)
if exp in cache:
yield cache[exp]
all_content += cache[exp]
continue
self.set_output(res)
v = self._canvas.get_variable_value(exp)
if isinstance(v, partial):
cnt = ""
for t in v():
all_content += t
cnt += t
yield t
continue
elif not isinstance(v, str):
try:
v = json.dumps(v, ensure_ascii=False, indent=2)
except Exception:
v = str(v)
yield v
all_content += v
cache[exp] = v
if s < len(rand_cnt):
all_content += rand_cnt[s: ]
yield rand_cnt[s: ]
self.set_output("content", all_content)
def _is_jinjia2(self, content:str) -> bool:
patt = [
r"\{%.*%\}", "{{", "}}"
]
return any([re.search(p, content) for p in patt])
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
def _invoke(self, **kwargs):
rand_cnt = random.choice(self._param.content)
if self._param.stream and not self._is_jinjia2(rand_cnt):
self.set_output("content", partial(self._stream, rand_cnt))
return
rand_cnt, kwargs = self.get_kwargs(rand_cnt, kwargs)
template = Jinja2Template(rand_cnt)
try:
content = template.render(kwargs)
except Exception:
pass
for n, v in kwargs.items():
content = re.sub(n, v, content)
self.set_output("content", content)

View File

@ -1,69 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC
from Bio import Entrez
import re
import pandas as pd
import xml.etree.ElementTree as ET
from agent.component.base import ComponentBase, ComponentParamBase
class PubMedParam(ComponentParamBase):
"""
Define the PubMed component parameters.
"""
def __init__(self):
super().__init__()
self.top_n = 5
self.email = "A.N.Other@example.com"
def check(self):
self.check_positive_integer(self.top_n, "Top N")
class PubMed(ComponentBase, ABC):
component_name = "PubMed"
def _run(self, history, **kwargs):
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
return PubMed.be_output("")
try:
Entrez.email = self._param.email
pubmedids = Entrez.read(Entrez.esearch(db='pubmed', retmax=self._param.top_n, term=ans))['IdList']
pubmedcnt = ET.fromstring(re.sub(r'<(/?)b>|<(/?)i>', '', Entrez.efetch(db='pubmed', id=",".join(pubmedids),
retmode="xml").read().decode(
"utf-8")))
pubmed_res = [{"content": 'Title:' + child.find("MedlineCitation").find("Article").find(
"ArticleTitle").text + '\nUrl:<a href=" https://pubmed.ncbi.nlm.nih.gov/' + child.find(
"MedlineCitation").find("PMID").text + '">' + '</a>\n' + 'Abstract:' + (
child.find("MedlineCitation").find("Article").find("Abstract").find(
"AbstractText").text if child.find("MedlineCitation").find(
"Article").find("Abstract") else "No abstract available")} for child in
pubmedcnt.findall("PubmedArticle")]
except Exception as e:
return PubMed.be_output("**ERROR**: " + str(e))
if not pubmed_res:
return PubMed.be_output("")
df = pd.DataFrame(pubmed_res)
logging.debug(f"df: {df}")
return df

View File

@ -1,83 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from agent.component import GenerateParam, Generate
from rag.utils import num_tokens_from_string, encoder
class RelevantParam(GenerateParam):
"""
Define the Relevant component parameters.
"""
def __init__(self):
super().__init__()
self.prompt = ""
self.yes = ""
self.no = ""
def check(self):
super().check()
self.check_empty(self.yes, "[Relevant] 'Yes'")
self.check_empty(self.no, "[Relevant] 'No'")
def get_prompt(self):
self.prompt = """
You are a grader assessing relevance of a retrieved document to a user question.
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question.
No other words needed except 'yes' or 'no'.
"""
return self.prompt
class Relevant(Generate, ABC):
component_name = "Relevant"
def _run(self, history, **kwargs):
q = ""
for r, c in self._canvas.history[::-1]:
if r == "user":
q = c
break
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
return Relevant.be_output(self._param.no)
ans = "Documents: \n" + ans
ans = f"Question: {q}\n" + ans
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
if num_tokens_from_string(ans) >= chat_mdl.max_length - 4:
ans = encoder.decode(encoder.encode(ans)[:chat_mdl.max_length - 4])
ans = chat_mdl.chat(self._param.get_prompt(), [{"role": "user", "content": ans}],
self._param.gen_conf())
logging.debug(ans)
if ans.lower().find("yes") >= 0:
return Relevant.be_output(self._param.yes)
if ans.lower().find("no") >= 0:
return Relevant.be_output(self._param.no)
assert False, f"Relevant component got: {ans}"
def debug(self, **kwargs):
return self._run([], **kwargs)

View File

@ -1,135 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import re
from abc import ABC
import pandas as pd
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
from agent.component.base import ComponentBase, ComponentParamBase
from rag.app.tag import label_question
from rag.prompts import kb_prompt
from rag.utils.tavily_conn import Tavily
class RetrievalParam(ComponentParamBase):
"""
Define the Retrieval component parameters.
"""
def __init__(self):
super().__init__()
self.similarity_threshold = 0.2
self.keywords_similarity_weight = 0.5
self.top_n = 8
self.top_k = 1024
self.kb_ids = []
self.kb_vars = []
self.rerank_id = ""
self.empty_response = ""
self.tavily_api_key = ""
self.use_kg = False
def check(self):
self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
self.check_decimal_float(self.keywords_similarity_weight, "[Retrieval] Keyword similarity weight")
self.check_positive_number(self.top_n, "[Retrieval] Top N")
class Retrieval(ComponentBase, ABC):
component_name = "Retrieval"
def _run(self, history, **kwargs):
query = self.get_input()
query = str(query["content"][0]) if "content" in query else ""
query = re.split(r"(USER:|ASSISTANT:)", query)[-1]
kb_ids: list[str] = self._param.kb_ids or []
kb_vars = self._fetch_outputs_from(self._param.kb_vars)
if len(kb_vars) > 0:
for kb_var in kb_vars:
if len(kb_var) == 1:
kb_var_value = str(kb_var["content"][0])
for v in kb_var_value.split(","):
kb_ids.append(v)
else:
for v in kb_var.to_dict("records"):
kb_ids.append(v["content"])
filtered_kb_ids: list[str] = [kb_id for kb_id in kb_ids if kb_id]
kbs = KnowledgebaseService.get_by_ids(filtered_kb_ids)
if not kbs:
return Retrieval.be_output("")
embd_nms = list(set([kb.embd_id for kb in kbs]))
assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
embd_mdl = None
if embd_nms:
embd_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, embd_nms[0])
self._canvas.set_embedding_model(embd_nms[0])
rerank_mdl = None
if self._param.rerank_id:
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
if kbs:
query = re.sub(r"^user[:\s]*", "", query, flags=re.IGNORECASE)
kbinfos = settings.retrievaler.retrieval(
query,
embd_mdl,
[kb.tenant_id for kb in kbs],
filtered_kb_ids,
1,
self._param.top_n,
self._param.similarity_threshold,
1 - self._param.keywords_similarity_weight,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(query, kbs),
)
else:
kbinfos = {"chunks": [], "doc_aggs": []}
if self._param.use_kg and kbs:
ck = settings.kg_retrievaler.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
if self._param.tavily_api_key:
tav = Tavily(self._param.tavily_api_key)
tav_res = tav.retrieve_chunks(query)
kbinfos["chunks"].extend(tav_res["chunks"])
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
if not kbinfos["chunks"]:
df = Retrieval.be_output("")
if self._param.empty_response and self._param.empty_response.strip():
df["empty_response"] = self._param.empty_response
return df
df = pd.DataFrame({"content": kb_prompt(kbinfos, 200000), "chunks": json.dumps(kbinfos["chunks"])})
logging.debug("{} {}".format(query, df))
return df.dropna()

View File

@ -1,94 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
from agent.component import GenerateParam, Generate
from rag.prompts import full_question
class RewriteQuestionParam(GenerateParam):
"""
Define the QuestionRewrite component parameters.
"""
def __init__(self):
super().__init__()
self.temperature = 0.9
self.prompt = ""
self.language = ""
def check(self):
super().check()
class RewriteQuestion(Generate, ABC):
component_name = "RewriteQuestion"
def _run(self, history, **kwargs):
hist = self._canvas.get_history(self._param.message_history_window_size)
query = self.get_input()
query = str(query["content"][0]) if "content" in query else ""
messages = [h for h in hist if h["role"]!="system"]
if messages[-1]["role"] != "user":
messages.append({"role": "user", "content": query})
ans = full_question(self._canvas.get_tenant_id(), self._param.llm_id, messages, self.gen_lang(self._param.language))
self._canvas.history.pop()
self._canvas.history.append(("user", ans))
return RewriteQuestion.be_output(ans)
@staticmethod
def gen_lang(language):
# convert code lang to language word for the prompt
language_dict = {'af': 'Afrikaans', 'ak': 'Akan', 'sq': 'Albanian', 'ws': 'Samoan', 'am': 'Amharic',
'ar': 'Arabic', 'hy': 'Armenian', 'az': 'Azerbaijani', 'eu': 'Basque', 'be': 'Belarusian',
'bem': 'Bemba', 'bn': 'Bengali', 'bh': 'Bihari',
'xx-bork': 'Bork', 'bs': 'Bosnian', 'br': 'Breton', 'bg': 'Bulgarian', 'bt': 'Bhutani',
'km': 'Cambodian', 'ca': 'Catalan', 'chr': 'Cherokee', 'ny': 'Chichewa', 'zh-cn': 'Chinese',
'zh-tw': 'Chinese', 'co': 'Corsican',
'hr': 'Croatian', 'cs': 'Czech', 'da': 'Danish', 'nl': 'Dutch', 'xx-elmer': 'Elmer',
'en': 'English', 'eo': 'Esperanto', 'et': 'Estonian', 'ee': 'Ewe', 'fo': 'Faroese',
'tl': 'Filipino', 'fi': 'Finnish', 'fr': 'French',
'fy': 'Frisian', 'gaa': 'Ga', 'gl': 'Galician', 'ka': 'Georgian', 'de': 'German',
'el': 'Greek', 'kl': 'Greenlandic', 'gn': 'Guarani', 'gu': 'Gujarati', 'xx-hacker': 'Hacker',
'ht': 'Haitian Creole', 'ha': 'Hausa', 'haw': 'Hawaiian',
'iw': 'Hebrew', 'hi': 'Hindi', 'hu': 'Hungarian', 'is': 'Icelandic', 'ig': 'Igbo',
'id': 'Indonesian', 'ia': 'Interlingua', 'ga': 'Irish', 'it': 'Italian', 'ja': 'Japanese',
'jw': 'Javanese', 'kn': 'Kannada', 'kk': 'Kazakh', 'rw': 'Kinyarwanda',
'rn': 'Kirundi', 'xx-klingon': 'Klingon', 'kg': 'Kongo', 'ko': 'Korean', 'kri': 'Krio',
'ku': 'Kurdish', 'ckb': 'Kurdish (Sorani)', 'ky': 'Kyrgyz', 'lo': 'Laothian', 'la': 'Latin',
'lv': 'Latvian', 'ln': 'Lingala', 'lt': 'Lithuanian',
'loz': 'Lozi', 'lg': 'Luganda', 'ach': 'Luo', 'mk': 'Macedonian', 'mg': 'Malagasy',
'ms': 'Malay', 'ml': 'Malayalam', 'mt': 'Maltese', 'mv': 'Maldivian', 'mi': 'Maori',
'mr': 'Marathi', 'mfe': 'Mauritian Creole', 'mo': 'Moldavian', 'mn': 'Mongolian',
'sr-me': 'Montenegrin', 'my': 'Burmese', 'ne': 'Nepali', 'pcm': 'Nigerian Pidgin',
'nso': 'Northern Sotho', 'no': 'Norwegian', 'nn': 'Norwegian Nynorsk', 'oc': 'Occitan',
'or': 'Oriya', 'om': 'Oromo', 'ps': 'Pashto', 'fa': 'Persian',
'xx-pirate': 'Pirate', 'pl': 'Polish', 'pt': 'Portuguese', 'pt-br': 'Portuguese (Brazilian)',
'pt-pt': 'Portuguese (Portugal)', 'pa': 'Punjabi', 'qu': 'Quechua', 'ro': 'Romanian',
'rm': 'Romansh', 'nyn': 'Runyankole', 'ru': 'Russian', 'gd': 'Scots Gaelic',
'sr': 'Serbian', 'sh': 'Serbo-Croatian', 'st': 'Sesotho', 'tn': 'Setswana',
'crs': 'Seychellois Creole', 'sn': 'Shona', 'sd': 'Sindhi', 'si': 'Sinhalese', 'sk': 'Slovak',
'sl': 'Slovenian', 'so': 'Somali', 'es': 'Spanish', 'es-419': 'Spanish (Latin America)',
'su': 'Sundanese',
'sw': 'Swahili', 'sv': 'Swedish', 'tg': 'Tajik', 'ta': 'Tamil', 'tt': 'Tatar', 'te': 'Telugu',
'th': 'Thai', 'ti': 'Tigrinya', 'to': 'Tongan', 'lua': 'Tshiluba', 'tum': 'Tumbuka',
'tr': 'Turkish', 'tk': 'Turkmen', 'tw': 'Twi',
'ug': 'Uyghur', 'uk': 'Ukrainian', 'ur': 'Urdu', 'uz': 'Uzbek', 'vu': 'Vanuatu',
'vi': 'Vietnamese', 'cy': 'Welsh', 'wo': 'Wolof', 'xh': 'Xhosa', 'yi': 'Yiddish',
'yo': 'Yoruba', 'zu': 'Zulu'}
if language in language_dict:
return language_dict[language]
else:
return ""

View File

@ -0,0 +1,98 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
from abc import ABC
from jinja2 import Template as Jinja2Template
from agent.component.base import ComponentParamBase
from api.utils.api_utils import timeout
from .message import Message
class StringTransformParam(ComponentParamBase):
"""
Define the code sandbox component parameters.
"""
def __init__(self):
super().__init__()
self.method = "split"
self.script = ""
self.split_ref = ""
self.delimiters = [","]
self.outputs = {"result": {"value": "", "type": "string"}}
def check(self):
self.check_valid_value(self.method, "Support method", ["split", "merge"])
self.check_empty(self.delimiters, "delimiters")
class StringTransform(Message, ABC):
component_name = "StringTransform"
def get_input_form(self) -> dict[str, dict]:
if self._param.method == "split":
return {
"line": {
"name": "String",
"type": "line"
}
}
return {k: {
"name": o["name"],
"type": "line"
} for k, o in self.get_input_elements_from_text(self._param.script).items()}
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
def _invoke(self, **kwargs):
if self._param.method == "split":
self._split(kwargs.get("line"))
else:
self._merge(kwargs)
def _split(self, line:str|None = None):
var = self._canvas.get_variable_value(self._param.split_ref) if not line else line
if not var:
var = ""
assert isinstance(var, str), "The input variable is not a string: {}".format(type(var))
self.set_input_value(self._param.split_ref, var)
res = []
for i,s in enumerate(re.split(r"(%s)"%("|".join([re.escape(d) for d in self._param.delimiters])), var, flags=re.DOTALL)):
if i % 2 == 1:
continue
res.append(s)
self.set_output("result", res)
def _merge(self, kwargs:dict[str, str] = {}):
script = self._param.script
script, kwargs = self.get_kwargs(script, kwargs, self._param.delimiters[0])
if self._is_jinjia2(script):
template = Jinja2Template(script)
try:
script = template.render(kwargs)
except Exception:
pass
for k,v in kwargs.items():
if not v:
v = ""
script = re.sub(k, v, script)
self.set_output("result", script)

View File

@ -13,8 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import numbers
import os
from abc import ABC
from typing import Any
from agent.component.base import ComponentBase, ComponentParamBase
from api.utils.api_utils import timeout
class SwitchParam(ComponentParamBase):
@ -34,7 +39,7 @@ class SwitchParam(ComponentParamBase):
}
"""
self.conditions = []
self.end_cpn_id = "answer:0"
self.end_cpn_ids = []
self.operators = ['contains', 'not contains', 'start with', 'end with', 'empty', 'not empty', '=', '', '>',
'<', '', '']
@ -43,54 +48,46 @@ class SwitchParam(ComponentParamBase):
for cond in self.conditions:
if not cond["to"]:
raise ValueError("[Switch] 'To' can not be empty!")
self.check_empty(self.end_cpn_ids, "[Switch] the ELSE/Other destination can not be empty.")
def get_input_form(self) -> dict[str, dict]:
return {
"urls": {
"name": "URLs",
"type": "line"
}
}
class Switch(ComponentBase, ABC):
component_name = "Switch"
def get_dependent_components(self):
res = []
for cond in self._param.conditions:
for item in cond["items"]:
if not item["cpn_id"]:
continue
if item["cpn_id"].lower().find("begin") >= 0 or item["cpn_id"].lower().find("answer") >= 0:
continue
cid = item["cpn_id"].split("@")[0]
res.append(cid)
return list(set(res))
def _run(self, history, **kwargs):
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))
def _invoke(self, **kwargs):
for cond in self._param.conditions:
res = []
for item in cond["items"]:
if not item["cpn_id"]:
continue
cid = item["cpn_id"].split("@")[0]
if item["cpn_id"].find("@") > 0:
cpn_id, key = item["cpn_id"].split("@")
for p in self._canvas.get_component(cid)["obj"]._param.query:
if p["key"] == key:
res.append(self.process_operator(p.get("value",""), item["operator"], item.get("value", "")))
break
else:
out = self._canvas.get_component(cid)["obj"].output(allow_partial=False)[1]
cpn_input = "" if "content" not in out.columns else " ".join([str(s) for s in out["content"]])
res.append(self.process_operator(cpn_input, item["operator"], item.get("value", "")))
cpn_v = self._canvas.get_variable_value(item["cpn_id"])
self.set_input_value(item["cpn_id"], cpn_v)
operatee = item.get("value", "")
if isinstance(cpn_v, numbers.Number):
operatee = float(operatee)
res.append(self.process_operator(cpn_v, item["operator"], operatee))
if cond["logical_operator"] != "and" and any(res):
return Switch.be_output(cond["to"])
self.set_output("next", [self._canvas.get_component_name(cpn_id) for cpn_id in cond["to"]])
self.set_output("_next", cond["to"])
return
if all(res):
return Switch.be_output(cond["to"])
self.set_output("next", [self._canvas.get_component_name(cpn_id) for cpn_id in cond["to"]])
self.set_output("_next", cond["to"])
return
return Switch.be_output(self._param.end_cpn_id)
def process_operator(self, input: str, operator: str, value: str) -> bool:
if not isinstance(input, str) or not isinstance(value, str):
raise ValueError('Invalid input or value type: string')
self.set_output("next", [self._canvas.get_component_name(cpn_id) for cpn_id in self._param.end_cpn_ids])
self.set_output("_next", self._param.end_cpn_ids)
def process_operator(self, input: Any, operator: str, value: Any) -> bool:
if operator == "contains":
return True if value.lower() in input.lower() else False
elif operator == "not contains":

View File

@ -1,147 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import re
from jinja2 import StrictUndefined
from jinja2.sandbox import SandboxedEnvironment
from agent.component.base import ComponentBase, ComponentParamBase
class TemplateParam(ComponentParamBase):
"""
Define the Generate component parameters.
"""
def __init__(self):
super().__init__()
self.content = ""
self.parameters = []
def check(self):
self.check_empty(self.content, "[Template] Content")
return True
class Template(ComponentBase):
component_name = "Template"
def get_dependent_components(self):
inputs = self.get_input_elements()
cpnts = set([i["key"] for i in inputs if i["key"].lower().find("answer") < 0 and i["key"].lower().find("begin") < 0])
return list(cpnts)
def get_input_elements(self):
key_set = set([])
res = []
for r in re.finditer(r"\{([a-z]+[:@][a-z0-9_-]+)\}", self._param.content, flags=re.IGNORECASE):
cpn_id = r.group(1)
if cpn_id in key_set:
continue
if cpn_id.lower().find("begin@") == 0:
cpn_id, key = cpn_id.split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] != key:
continue
res.append({"key": r.group(1), "name": p["name"]})
key_set.add(r.group(1))
continue
cpn_nm = self._canvas.get_component_name(cpn_id)
if not cpn_nm:
continue
res.append({"key": cpn_id, "name": cpn_nm})
key_set.add(cpn_id)
return res
def _run(self, history, **kwargs):
content = self._param.content
self._param.inputs = []
for para in self.get_input_elements():
if para["key"].lower().find("begin@") == 0:
cpn_id, key = para["key"].split("@")
for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
if p["key"] == key:
value = p.get("value", "")
self.make_kwargs(para, kwargs, value)
origin_pattern = "{begin@" + key + "}"
new_pattern = "begin_" + key
content = content.replace(origin_pattern, new_pattern)
kwargs[new_pattern] = kwargs.pop(origin_pattern, "")
break
else:
assert False, f"Can't find parameter '{key}' for {cpn_id}"
continue
component_id = para["key"]
cpn = self._canvas.get_component(component_id)["obj"]
if cpn.component_name.lower() == "answer":
hist = self._canvas.get_history(1)
if hist:
hist = hist[0]["content"]
else:
hist = ""
self.make_kwargs(para, kwargs, hist)
if ":" in component_id:
origin_pattern = "{" + component_id + "}"
new_pattern = component_id.replace(":", "_")
content = content.replace(origin_pattern, new_pattern)
kwargs[new_pattern] = kwargs.pop(component_id, "")
continue
_, out = cpn.output(allow_partial=False)
result = ""
if "content" in out.columns:
result = "\n".join([o if isinstance(o, str) else str(o) for o in out["content"]])
self.make_kwargs(para, kwargs, result)
env = SandboxedEnvironment(
autoescape=True,
undefined=StrictUndefined,
)
template = env.from_string(content)
try:
content = template.render(kwargs)
except Exception:
pass
for n, v in kwargs.items():
if not isinstance(v, str):
try:
v = json.dumps(v, ensure_ascii=False)
except Exception:
pass
# Process backslashes in strings, Use Lambda function to avoid escape issues
if isinstance(v, str):
v = v.replace("\\", "\\\\")
content = re.sub(r"\{%s\}" % re.escape(n), lambda match: v, content)
content = re.sub(r"(#+)", r" \1 ", content)
return Template.be_output(content)
def make_kwargs(self, para, kwargs, value):
self._param.inputs.append({"component_id": para["key"], "content": value})
try:
value = json.loads(value)
except Exception:
pass
kwargs[para["key"]] = value

View File

@ -1,80 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from abc import ABC
import pandas as pd
import pywencai
from agent.component.base import ComponentBase, ComponentParamBase
class WenCaiParam(ComponentParamBase):
"""
Define the WenCai component parameters.
"""
def __init__(self):
super().__init__()
self.top_n = 10
self.query_type = "stock"
def check(self):
self.check_positive_integer(self.top_n, "Top N")
self.check_valid_value(self.query_type, "Query type",
['stock', 'zhishu', 'fund', 'hkstock', 'usstock', 'threeboard', 'conbond', 'insurance',
'futures', 'lccp',
'foreign_exchange'])
class WenCai(ComponentBase, ABC):
component_name = "WenCai"
def _run(self, history, **kwargs):
ans = self.get_input()
ans = ",".join(ans["content"]) if "content" in ans else ""
if not ans:
return WenCai.be_output("")
try:
wencai_res = []
res = pywencai.get(query=ans, query_type=self._param.query_type, perpage=self._param.top_n)
if isinstance(res, pd.DataFrame):
wencai_res.append({"content": res.to_markdown()})
if isinstance(res, dict):
for item in res.items():
if isinstance(item[1], list):
wencai_res.append({"content": item[0] + "\n" + pd.DataFrame(item[1]).to_markdown()})
continue
if isinstance(item[1], str):
wencai_res.append({"content": item[0] + "\n" + item[1]})
continue
if isinstance(item[1], dict):
if "meta" in item[1].keys():
continue
wencai_res.append({"content": pd.DataFrame.from_dict(item[1], orient='index').to_markdown()})
continue
if isinstance(item[1], pd.DataFrame):
if "image_url" in item[1].columns:
continue
wencai_res.append({"content": item[1].to_markdown()})
continue
wencai_res.append({"content": item[0] + "\n" + str(item[1])})
except Exception as e:
return WenCai.be_output("**ERROR**: " + str(e))
if not wencai_res:
return WenCai.be_output("")
return pd.DataFrame(wencai_res)

View File

@ -1,67 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC
import wikipedia
import pandas as pd
from agent.component.base import ComponentBase, ComponentParamBase
class WikipediaParam(ComponentParamBase):
"""
Define the Wikipedia component parameters.
"""
def __init__(self):
super().__init__()
self.top_n = 10
self.language = "en"
def check(self):
self.check_positive_integer(self.top_n, "Top N")
self.check_valid_value(self.language, "Wikipedia languages",
['af', 'pl', 'ar', 'ast', 'az', 'bg', 'nan', 'bn', 'be', 'ca', 'cs', 'cy', 'da', 'de',
'et', 'el', 'en', 'es', 'eo', 'eu', 'fa', 'fr', 'gl', 'ko', 'hy', 'hi', 'hr', 'id',
'it', 'he', 'ka', 'lld', 'la', 'lv', 'lt', 'hu', 'mk', 'arz', 'ms', 'min', 'my', 'nl',
'ja', 'nb', 'nn', 'ce', 'uz', 'pt', 'kk', 'ro', 'ru', 'ceb', 'sk', 'sl', 'sr', 'sh',
'fi', 'sv', 'ta', 'tt', 'th', 'tg', 'azb', 'tr', 'uk', 'ur', 'vi', 'war', 'zh', 'yue'])
class Wikipedia(ComponentBase, ABC):
component_name = "Wikipedia"
def _run(self, history, **kwargs):
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
return Wikipedia.be_output("")
try:
wiki_res = []
wikipedia.set_lang(self._param.language)
wiki_engine = wikipedia
for wiki_key in wiki_engine.search(ans, results=self._param.top_n):
page = wiki_engine.page(title=wiki_key, auto_suggest=False)
wiki_res.append({"content": '<a href="' + page.url + '">' + page.title + '</a> ' + page.summary})
except Exception as e:
return Wikipedia.be_output("**ERROR**: " + str(e))
if not wiki_res:
return Wikipedia.be_output("")
df = pd.DataFrame(wiki_res)
logging.debug(f"df: {df}")
return df

View File

@ -1,84 +0,0 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
from abc import ABC
import pandas as pd
from agent.component.base import ComponentBase, ComponentParamBase
import yfinance as yf
class YahooFinanceParam(ComponentParamBase):
"""
Define the YahooFinance component parameters.
"""
def __init__(self):
super().__init__()
self.info = True
self.history = False
self.count = False
self.financials = False
self.income_stmt = False
self.balance_sheet = False
self.cash_flow_statement = False
self.news = True
def check(self):
self.check_boolean(self.info, "get all stock info")
self.check_boolean(self.history, "get historical market data")
self.check_boolean(self.count, "show share count")
self.check_boolean(self.financials, "show financials")
self.check_boolean(self.income_stmt, "income statement")
self.check_boolean(self.balance_sheet, "balance sheet")
self.check_boolean(self.cash_flow_statement, "cash flow statement")
self.check_boolean(self.news, "show news")
class YahooFinance(ComponentBase, ABC):
component_name = "YahooFinance"
def _run(self, history, **kwargs):
ans = self.get_input()
ans = "".join(ans["content"]) if "content" in ans else ""
if not ans:
return YahooFinance.be_output("")
yohoo_res = []
try:
msft = yf.Ticker(ans)
if self._param.info:
yohoo_res.append({"content": "info:\n" + pd.Series(msft.info).to_markdown() + "\n"})
if self._param.history:
yohoo_res.append({"content": "history:\n" + msft.history().to_markdown() + "\n"})
if self._param.financials:
yohoo_res.append({"content": "calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n"})
if self._param.balance_sheet:
yohoo_res.append({"content": "balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n"})
yohoo_res.append(
{"content": "quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n"})
if self._param.cash_flow_statement:
yohoo_res.append({"content": "cash flow statement:\n" + msft.cashflow.to_markdown() + "\n"})
yohoo_res.append(
{"content": "quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n"})
if self._param.news:
yohoo_res.append({"content": "news:\n" + pd.DataFrame(msft.news).to_markdown() + "\n"})
except Exception:
logging.exception("YahooFinance got exception")
if not yohoo_res:
return YahooFinance.be_output("")
return pd.DataFrame(yohoo_res)

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -15,9 +15,8 @@
#
import argparse
import os
from functools import partial
from agent.canvas import Canvas
from agent.settings import DEBUG
from api import settings
if __name__ == '__main__':
parser = argparse.ArgumentParser()
@ -31,19 +30,17 @@ if __name__ == '__main__':
parser.add_argument('-m', '--stream', default=False, help="Stream output", action='store_true', required=False)
args = parser.parse_args()
settings.init_settings()
canvas = Canvas(open(args.dsl, "r").read(), args.tenant_id)
if canvas.get_prologue():
print(f"==================== Bot =====================\n> {canvas.get_prologue()}", end='')
query = ""
while True:
ans = canvas.run(stream=args.stream)
canvas.reset(True)
query = input("\n==================== User =====================\n> ")
ans = canvas.run(query=query)
print("==================== Bot =====================\n> ", end='')
if args.stream and isinstance(ans, partial):
cont = ""
for an in ans():
print(an["content"][len(cont):], end='', flush=True)
cont = an["content"]
else:
print(ans["content"])
for ans in canvas.run(query=query):
print(ans, end='\n', flush=True)
if DEBUG:
print(canvas.path)
question = input("\n==================== User =====================\n> ")
canvas.add_user_input(question)
print(canvas.path)

View File

@ -1,129 +0,0 @@
{
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["baidu:0"],
"upstream": ["begin", "message:0","message:1"]
},
"baidu:0": {
"obj": {
"component_name": "Baidu",
"params": {}
},
"downstream": ["generate:0"],
"upstream": ["answer:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the user's question based on what Baidu searched. First, please output the user's question and the content searched by Baidu, and then answer yes, no, or i don't know.Here is the user's question:{user_input}The above is the user's question.Here is what Baidu searched for:{baidu}The above is the content searched by Baidu.",
"temperature": 0.2
},
"parameters": [
{
"component_id": "answer:0",
"id": "69415446-49bf-4d4b-8ec9-ac86066f7709",
"key": "user_input"
},
{
"component_id": "baidu:0",
"id": "83363c2a-00a8-402f-a45c-ddc4097d7d8b",
"key": "baidu"
}
]
},
"downstream": ["switch:0"],
"upstream": ["baidu:0"]
},
"switch:0": {
"obj": {
"component_name": "Switch",
"params": {
"conditions": [
{
"logical_operator" : "or",
"items" : [
{"cpn_id": "generate:0", "operator": "contains", "value": "yes"},
{"cpn_id": "generate:0", "operator": "contains", "value": "yeah"}
],
"to": "message:0"
},
{
"logical_operator" : "and",
"items" : [
{"cpn_id": "generate:0", "operator": "contains", "value": "no"},
{"cpn_id": "generate:0", "operator": "not contains", "value": "yes"},
{"cpn_id": "generate:0", "operator": "not contains", "value": "know"}
],
"to": "message:1"
},
{
"logical_operator" : "",
"items" : [
{"cpn_id": "generate:0", "operator": "contains", "value": "know"}
],
"to": "message:2"
}
],
"end_cpn_id": "answer:0"
}
},
"downstream": ["message:0","message:1"],
"upstream": ["generate:0"]
},
"message:0": {
"obj": {
"component_name": "Message",
"params": {
"messages": ["YES YES YES YES YES YES YES YES YES YES YES YES"]
}
},
"upstream": ["switch:0"],
"downstream": ["answer:0"]
},
"message:1": {
"obj": {
"component_name": "Message",
"params": {
"messages": ["NO NO NO NO NO NO NO NO NO NO NO NO NO NO"]
}
},
"upstream": ["switch:0"],
"downstream": ["answer:0"]
},
"message:2": {
"obj": {
"component_name": "Message",
"params": {
"messages": ["I DON'T KNOW---------------------------"]
}
},
"upstream": ["switch:0"],
"downstream": ["answer:0"]
}
},
"history": [],
"messages": [],
"reference": {},
"path": [],
"answer": []
}

View File

@ -7,16 +7,8 @@
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["categorize:0"],
"upstream": ["begin"]
"upstream": []
},
"categorize:0": {
"obj": {
@ -26,48 +18,68 @@
"category_description": {
"product_related": {
"description": "The question is about the product usage, appearance and how it works.",
"examples": "Why it always beaming?\nHow to install it onto the wall?\nIt leaks, what to do?",
"to": "message:0"
"to": ["agent:0"]
},
"others": {
"description": "The question is not about the product usage, appearance and how it works.",
"examples": "How are you doing?\nWhat is your name?\nAre you a robot?\nWhat's the weather?\nWill it rain?",
"to": "message:1"
"to": ["message:0"]
}
}
}
},
"downstream": ["message:0","message:1"],
"upstream": ["answer:0"]
"downstream": [],
"upstream": ["begin"]
},
"message:0": {
"obj": {
"obj":{
"component_name": "Message",
"params": {
"messages": [
"Message 0!!!!!!!"
"content": [
"Sorry, I don't know. I'm an AI bot."
]
}
},
"downstream": ["answer:0"],
"downstream": [],
"upstream": ["categorize:0"]
},
"agent:0": {
"obj": {
"component_name": "Agent",
"params": {
"llm_id": "deepseek-chat",
"sys_prompt": "You are a smart researcher. You could generate proper queries to search. According to the search results, you could deside next query if the result is not enough.",
"temperature": 0.2,
"llm_enabled_tools": [
{
"component_name": "TavilySearch",
"params": {
"api_key": "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"
}
}
]
}
},
"downstream": ["message:1"],
"upstream": ["categorize:0"]
},
"message:1": {
"obj": {
"component_name": "Message",
"params": {
"messages": [
"Message 1!!!!!!!"
]
"content": ["{agent:0@content}"]
}
},
"downstream": ["answer:0"],
"upstream": ["categorize:0"]
"downstream": [],
"upstream": ["agent:0"]
}
},
"history": [],
"messages": [],
"path": [],
"reference": [],
"answer": []
}
"retrival": {"chunks": [], "doc_aggs": []},
"globals": {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": []
}
}

View File

@ -1,113 +0,0 @@
{
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["categorize:0"],
"upstream": ["begin"]
},
"categorize:0": {
"obj": {
"component_name": "Categorize",
"params": {
"llm_id": "deepseek-chat",
"category_description": {
"product_related": {
"description": "The question is about the product usage, appearance and how it works.",
"examples": "Why it always beaming?\nHow to install it onto the wall?\nIt leaks, what to do?",
"to": "concentrator:0"
},
"others": {
"description": "The question is not about the product usage, appearance and how it works.",
"examples": "How are you doing?\nWhat is your name?\nAre you a robot?\nWhat's the weather?\nWill it rain?",
"to": "concentrator:1"
}
}
}
},
"downstream": ["concentrator:0","concentrator:1"],
"upstream": ["answer:0"]
},
"concentrator:0": {
"obj": {
"component_name": "Concentrator",
"params": {}
},
"downstream": ["message:0"],
"upstream": ["categorize:0"]
},
"concentrator:1": {
"obj": {
"component_name": "Concentrator",
"params": {}
},
"downstream": ["message:1_0","message:1_1","message:1_2"],
"upstream": ["categorize:0"]
},
"message:0": {
"obj": {
"component_name": "Message",
"params": {
"messages": [
"Message 0_0!!!!!!!"
]
}
},
"downstream": ["answer:0"],
"upstream": ["concentrator:0"]
},
"message:1_0": {
"obj": {
"component_name": "Message",
"params": {
"messages": [
"Message 1_0!!!!!!!"
]
}
},
"downstream": ["answer:0"],
"upstream": ["concentrator:1"]
},
"message:1_1": {
"obj": {
"component_name": "Message",
"params": {
"messages": [
"Message 1_1!!!!!!!"
]
}
},
"downstream": ["answer:0"],
"upstream": ["concentrator:1"]
},
"message:1_2": {
"obj": {
"component_name": "Message",
"params": {
"messages": [
"Message 1_2!!!!!!!"
]
}
},
"downstream": ["answer:0"],
"upstream": ["concentrator:1"]
}
},
"history": [],
"messages": [],
"path": [],
"reference": [],
"answer": []
}

View File

@ -1,157 +0,0 @@
{
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi! How can I help you?"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["categorize:0"],
"upstream": ["begin", "generate:0", "generate:casual", "generate:answer", "generate:complain", "generate:ask_contact", "message:get_contact"]
},
"categorize:0": {
"obj": {
"component_name": "Categorize",
"params": {
"llm_id": "deepseek-chat",
"category_description": {
"product_related": {
"description": "The question is about the product usage, appearance and how it works.",
"examples": "Why it always beaming?\nHow to install it onto the wall?\nIt leaks, what to do?\nException: Can't connect to ES cluster\nHow to build the RAGFlow image from scratch",
"to": "retrieval:0"
},
"casual": {
"description": "The question is not about the product usage, appearance and how it works. Just casual chat.",
"examples": "How are you doing?\nWhat is your name?\nAre you a robot?\nWhat's the weather?\nWill it rain?",
"to": "generate:casual"
},
"complain": {
"description": "Complain even curse about the product or service you provide. But the comment is not specific enough.",
"examples": "How bad is it.\nIt's really sucks.\nDamn, for God's sake, can it be more steady?\nShit, I just can't use this shit.\nI can't stand it anymore.",
"to": "generate:complain"
},
"answer": {
"description": "This answer provide a specific contact information, like e-mail, phone number, wechat number, line number, twitter, discord, etc,.",
"examples": "My phone number is 203921\nkevinhu.hk@gmail.com\nThis is my discord number: johndowson_29384",
"to": "message:get_contact"
}
},
"message_history_window_size": 8
}
},
"downstream": ["retrieval:0", "generate:casual", "generate:complain", "message:get_contact"],
"upstream": ["answer:0"]
},
"generate:casual": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are a customer support. But the customer wants to have a casual chat with you instead of consulting about the product. Be nice, funny, enthusiasm and concern.",
"temperature": 0.9,
"message_history_window_size": 12,
"cite": false
}
},
"downstream": ["answer:0"],
"upstream": ["categorize:0"]
},
"generate:complain": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are a customer support. the Customers complain even curse about the products but not specific enough. You need to ask him/her what's the specific problem with the product. Be nice, patient and concern to soothe your customers emotions at first place.",
"temperature": 0.9,
"message_history_window_size": 12,
"cite": false
}
},
"downstream": ["answer:0"],
"upstream": ["categorize:0"]
},
"retrieval:0": {
"obj": {
"component_name": "Retrieval",
"params": {
"similarity_threshold": 0.2,
"keywords_similarity_weight": 0.3,
"top_n": 6,
"top_k": 1024,
"rerank_id": "BAAI/bge-reranker-v2-m3",
"kb_ids": ["869a236818b811ef91dffa163e197198"]
}
},
"downstream": ["relevant:0"],
"upstream": ["categorize:0"]
},
"relevant:0": {
"obj": {
"component_name": "Relevant",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.02,
"yes": "generate:answer",
"no": "generate:ask_contact"
}
},
"downstream": ["generate:answer", "generate:ask_contact"],
"upstream": ["retrieval:0"]
},
"generate:answer": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the question based on content of knowledge base. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\". Answers need to consider chat history.\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.",
"temperature": 0.02
}
},
"downstream": ["answer:0"],
"upstream": ["relevant:0"]
},
"generate:ask_contact": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are a customer support. But you can't answer to customers' question. You need to request their contact like E-mail, phone number, Wechat number, LINE number, twitter, discord, etc,. Product experts will contact them later. Please do not ask the same question twice.",
"temperature": 0.9,
"message_history_window_size": 12,
"cite": false
}
},
"downstream": ["answer:0"],
"upstream": ["relevant:0"]
},
"message:get_contact": {
"obj":{
"component_name": "Message",
"params": {
"messages": [
"Okay, I've already write this down. What else I can do for you?",
"Get it. What else I can do for you?",
"Thanks for your trust! Our expert will contact ASAP. So, anything else I can do for you?",
"Thanks! So, anything else I can do for you?"
]
}
},
"downstream": ["answer:0"],
"upstream": ["categorize:0"]
}
},
"history": [],
"messages": [],
"path": [],
"reference": [],
"answer": []
}

View File

@ -1,39 +0,0 @@
{
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there! Please enter the text you want to translate in format like: 'text you want to translate' => target language. For an example: 您好! => English"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["generate:0"],
"upstream": ["begin", "generate:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an professional interpreter.\n- Role: an professional interpreter.\n- Input format: content need to be translated => target language. \n- Answer format: => translated content in target language. \n- Examples:\n - user: 您好! => English. assistant: => How are you doing!\n - user: You look good today. => Japanese. assistant: => 今日は調子がいいですね 。\n",
"temperature": 0.5
}
},
"downstream": ["answer:0"],
"upstream": ["answer:0"]
}
},
"history": [],
"messages": [],
"reference": {},
"path": [],
"answer": []
}

View File

@ -1,39 +0,0 @@
{
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there! Please enter the text you want to translate in format like: 'text you want to translate' => target language. For an example: 您好! => English"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["generate:0"],
"upstream": ["begin", "generate:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an professional interpreter.\n- Role: an professional interpreter.\n- Input format: content need to be translated => target language. \n- Answer format: => translated content in target language. \n- Examples:\n - user: 您好! => English. assistant: => How are you doing!\n - user: You look good today. => Japanese. assistant: => 今日は調子がいいですね 。\n",
"temperature": 0.5
}
},
"downstream": ["answer:0"],
"upstream": ["answer:0"]
}
},
"history": [],
"messages": [],
"reference": {},
"path": [],
"answer": []
}

View File

@ -0,0 +1,92 @@
{
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["generate:0"],
"upstream": []
},
"generate:0": {
"obj": {
"component_name": "Agent",
"params": {
"llm_id": "deepseek-chat",
"sys_prompt": "You are an helpful research assistant. \nPlease decompose user's topic: '{sys.query}' into several meaningful sub-topics. \nThe output format MUST be an string array like: [\"sub-topic1\", \"sub-topic2\", ...]. Redundant information is forbidden.",
"temperature": 0.2,
"cite":false,
"output_structure": ["sub-topic1", "sub-topic2", "sub-topic3"]
}
},
"downstream": ["iteration:0"],
"upstream": ["begin"]
},
"iteration:0": {
"obj": {
"component_name": "Iteration",
"params": {
"items_ref": "generate:0@structured_content"
}
},
"downstream": ["message:0"],
"upstream": ["generate:0"]
},
"iterationitem:0": {
"obj": {
"component_name": "IterationItem",
"params": {}
},
"parent_id": "iteration:0",
"downstream": ["tavily:0"],
"upstream": []
},
"tavily:0": {
"obj": {
"component_name": "TavilySearch",
"params": {
"api_key": "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",
"query": "iterationitem:0@result"
}
},
"parent_id": "iteration:0",
"downstream": ["generate:1"],
"upstream": ["iterationitem:0"]
},
"generate:1": {
"obj": {
"component_name": "Agent",
"params": {
"llm_id": "deepseek-chat",
"sys_prompt": "Your goal is to provide answers based on information from the internet. \nYou must use the provided search results to find relevant online information. \nYou should never use your own knowledge to answer questions.\nPlease include relevant url sources in the end of your answers.\n\n \"{tavily:0@formalized_content}\" \nUsing the above information, answer the following question or topic: \"{iterationitem:0@result} \"\nin a detailed report — The report should focus on the answer to the question, should be well structured, informative, in depth, with facts and numbers if available, a minimum of 200 words and with markdown syntax and apa format. Write all source urls at the end of the report in apa format. You should write your report only based on the given information and nothing else.",
"temperature": 0.9,
"cite":false
}
},
"parent_id": "iteration:0",
"downstream": ["iterationitem:0"],
"upstream": ["tavily:0"]
},
"message:0": {
"obj": {
"component_name": "Message",
"params": {
"content": ["{iteration:0@generate:1}"]
}
},
"downstream": [],
"upstream": ["iteration:0"]
}
},
"history": [],
"path": [],
"retrival": {"chunks": [], "doc_aggs": []},
"globals": {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": []
}
}

View File

@ -1,62 +0,0 @@
{
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["keyword:0"],
"upstream": ["begin"]
},
"keyword:0": {
"obj": {
"component_name": "KeywordExtract",
"params": {
"llm_id": "deepseek-chat",
"prompt": "- Role: You're a question analyzer.\n - Requirements:\n - Summarize user's question, and give top %s important keyword/phrase.\n - Use comma as a delimiter to separate keywords/phrases.\n - Answer format: (in language of user's question)\n - keyword: ",
"temperature": 0.2,
"top_n": 1
}
},
"downstream": ["wikipedia:0"],
"upstream": ["answer:0"]
},
"wikipedia:0": {
"obj":{
"component_name": "Wikipedia",
"params": {
"top_n": 10
}
},
"downstream": ["generate:0"],
"upstream": ["keyword:0"]
},
"generate:1": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the question based on content from Wikipedia. When the answer from Wikipedia is incomplete, you need to output the URL link of the corresponding content as well. When all the content searched from Wikipedia is irrelevant to the question, your answer must include the sentence, \"The answer you are looking for is not found in the Wikipedia!\". Answers need to consider chat history.\n The content of Wikipedia is as follows:\n {input}\n The above is the content of Wikipedia.",
"temperature": 0.2
}
},
"downstream": ["answer:0"],
"upstream": ["wikipedia:0"]
}
},
"history": [],
"path": [],
"messages": [],
"reference": {},
"answer": []
}

View File

@ -7,16 +7,8 @@
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["retrieval:0"],
"upstream": ["begin", "generate:0"]
"upstream": []
},
"retrieval:0": {
"obj": {
@ -26,29 +18,44 @@
"keywords_similarity_weight": 0.3,
"top_n": 6,
"top_k": 1024,
"rerank_id": "BAAI/bge-reranker-v2-m3",
"kb_ids": ["869a236818b811ef91dffa163e197198"]
"rerank_id": "",
"empty_response": "Nothing found in dataset",
"kb_ids": ["1a3d1d7afb0611ef9866047c16ec874f"]
}
},
"downstream": ["generate:0"],
"upstream": ["answer:0"]
"upstream": ["begin"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"component_name": "LLM",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {input}\n The above is the knowledge base.",
"sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {retrieval:0@formalized_content}\n The above is the knowledge base.",
"temperature": 0.2
}
},
"downstream": ["answer:0"],
"downstream": ["message:0"],
"upstream": ["retrieval:0"]
},
"message:0": {
"obj": {
"component_name": "Message",
"params": {
"content": ["{generate:0@content}"]
}
},
"downstream": [],
"upstream": ["generate:0"]
}
},
"history": [],
"messages": [],
"reference": {},
"path": [],
"answer": []
"retrival": {"chunks": [], "doc_aggs": []},
"globals": {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": []
}
}

View File

@ -7,16 +7,8 @@
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["categorize:0"],
"upstream": ["begin", "generate:0", "switch:0"]
"upstream": []
},
"categorize:0": {
"obj": {
@ -26,30 +18,30 @@
"category_description": {
"product_related": {
"description": "The question is about the product usage, appearance and how it works.",
"examples": "Why it always beaming?\nHow to install it onto the wall?\nIt leaks, what to do?",
"to": "retrieval:0"
"examples": [],
"to": ["retrieval:0"]
},
"others": {
"description": "The question is not about the product usage, appearance and how it works.",
"examples": "How are you doing?\nWhat is your name?\nAre you a robot?\nWhat's the weather?\nWill it rain?",
"to": "message:0"
"examples": [],
"to": ["message:0"]
}
}
}
},
"downstream": ["retrieval:0", "message:0"],
"upstream": ["answer:0"]
"downstream": [],
"upstream": ["begin"]
},
"message:0": {
"obj":{
"component_name": "Message",
"params": {
"messages": [
"content": [
"Sorry, I don't know. I'm an AI bot."
]
}
},
"downstream": ["answer:0"],
"downstream": [],
"upstream": ["categorize:0"]
},
"retrieval:0": {
@ -60,29 +52,44 @@
"keywords_similarity_weight": 0.3,
"top_n": 6,
"top_k": 1024,
"rerank_id": "BAAI/bge-reranker-v2-m3",
"kb_ids": ["869a236818b811ef91dffa163e197198"]
"rerank_id": "",
"empty_response": "Nothing found in dataset",
"kb_ids": ["1a3d1d7afb0611ef9866047c16ec874f"]
}
},
"downstream": ["generate:0"],
"upstream": ["switch:0"]
"upstream": ["categorize:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"component_name": "Agent",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {input}\n The above is the knowledge base.",
"sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {retrieval:0@formalized_content}\n The above is the knowledge base.",
"temperature": 0.2
}
},
"downstream": ["answer:0"],
"downstream": ["message:1"],
"upstream": ["retrieval:0"]
},
"message:1": {
"obj": {
"component_name": "Message",
"params": {
"content": ["{generate:0@content}"]
}
},
"downstream": [],
"upstream": ["generate:0"]
}
},
"history": [],
"messages": [],
"reference": {},
"path": [],
"answer": []
"retrival": {"chunks": [], "doc_aggs": []},
"globals": {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": []
}
}

View File

@ -1,82 +0,0 @@
{
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["retrieval:0"],
"upstream": ["begin", "generate:0", "switch:0"]
},
"retrieval:0": {
"obj": {
"component_name": "Retrieval",
"params": {
"similarity_threshold": 0.2,
"keywords_similarity_weight": 0.3,
"top_n": 6,
"top_k": 1024,
"rerank_id": "BAAI/bge-reranker-v2-m3",
"kb_ids": ["869a236818b811ef91dffa163e197198"],
"empty_response": "Sorry, knowledge base has noting related information."
}
},
"downstream": ["relevant:0"],
"upstream": ["answer:0"]
},
"relevant:0": {
"obj": {
"component_name": "Relevant",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.02,
"yes": "generate:0",
"no": "message:0"
}
},
"downstream": ["message:0", "generate:0"],
"upstream": ["retrieval:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the question based on content of knowledge base. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\". Answers need to consider chat history.\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.",
"temperature": 0.2
}
},
"downstream": ["answer:0"],
"upstream": ["relevant:0"]
},
"message:0": {
"obj":{
"component_name": "Message",
"params": {
"messages": [
"Sorry, I don't know. Please leave your contact, our experts will contact you later. What's your e-mail/phone/wechat?",
"I'm an AI bot and not quite sure about this question. Please leave your contact, our experts will contact you later. What's your e-mail/phone/wechat?",
"Can't find answer in my knowledge base. Please leave your contact, our experts will contact you later. What's your e-mail/phone/wechat?"
]
}
},
"downstream": ["answer:0"],
"upstream": ["relevant:0"]
}
},
"history": [],
"path": [],
"messages": [],
"reference": {},
"answer": []
}

View File

@ -1,103 +0,0 @@
{
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["retrieval:0"],
"upstream": ["begin"]
},
"retrieval:0": {
"obj": {
"component_name": "Retrieval",
"params": {
"similarity_threshold": 0.2,
"keywords_similarity_weight": 0.3,
"top_n": 6,
"top_k": 1024,
"rerank_id": "BAAI/bge-reranker-v2-m3",
"kb_ids": ["21ca4e6a2c8911ef8b1e0242ac120006"],
"empty_response": "Sorry, knowledge base has noting related information."
}
},
"downstream": ["relevant:0"],
"upstream": ["answer:0"]
},
"relevant:0": {
"obj": {
"component_name": "Relevant",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.02,
"yes": "generate:0",
"no": "keyword:0"
}
},
"downstream": ["keyword:0", "generate:0"],
"upstream": ["retrieval:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the question based on content of knowledge base. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\". Answers need to consider chat history.\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.",
"temperature": 0.2
}
},
"downstream": ["answer:0"],
"upstream": ["relevant:0"]
},
"keyword:0": {
"obj": {
"component_name": "KeywordExtract",
"params": {
"llm_id": "deepseek-chat",
"prompt": "- Role: You're a question analyzer.\n - Requirements:\n - Summarize user's question, and give top %s important keyword/phrase.\n - Use comma as a delimiter to separate keywords/phrases.\n - Answer format: (in language of user's question)\n - keyword: ",
"temperature": 0.2,
"top_n": 1
}
},
"downstream": ["baidu:0"],
"upstream": ["relevant:0"]
},
"baidu:0": {
"obj":{
"component_name": "Baidu",
"params": {
"top_n": 10
}
},
"downstream": ["generate:1"],
"upstream": ["keyword:0"]
},
"generate:1": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the question based on content searched from Baidu. When the answer from a Baidu search is incomplete, you need to output the URL link of the corresponding content as well. When all the content searched from Baidu is irrelevant to the question, your answer must include the sentence, \"The answer you are looking for is not found in the Baidu search!\". Answers need to consider chat history.\n The content of Baidu search is as follows:\n {input}\n The above is the content of Baidu search.",
"temperature": 0.2
}
},
"downstream": ["answer:0"],
"upstream": ["baidu:0"]
}
},
"history": [],
"path": [],
"messages": [],
"reference": {},
"answer": []
}

View File

@ -1,79 +0,0 @@
{
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["answer:0"],
"upstream": []
},
"answer:0": {
"obj": {
"component_name": "Answer",
"params": {}
},
"downstream": ["retrieval:0"],
"upstream": ["begin", "generate:0", "switch:0"]
},
"retrieval:0": {
"obj": {
"component_name": "Retrieval",
"params": {
"similarity_threshold": 0.2,
"keywords_similarity_weight": 0.3,
"top_n": 6,
"top_k": 1024,
"rerank_id": "BAAI/bge-reranker-v2-m3",
"kb_ids": ["869a236818b811ef91dffa163e197198"],
"empty_response": "Sorry, knowledge base has noting related information."
}
},
"downstream": ["relevant:0"],
"upstream": ["answer:0", "rewrite:0"]
},
"relevant:0": {
"obj": {
"component_name": "Relevant",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.02,
"yes": "generate:0",
"no": "rewrite:0"
}
},
"downstream": ["generate:0", "rewrite:0"],
"upstream": ["retrieval:0"]
},
"generate:0": {
"obj": {
"component_name": "Generate",
"params": {
"llm_id": "deepseek-chat",
"prompt": "You are an intelligent assistant. Please answer the question based on content of knowledge base. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\". Answers need to consider chat history.\n Knowledge base content is as following:\n {input}\n The above is the content of knowledge base.",
"temperature": 0.02
}
},
"downstream": ["answer:0"],
"upstream": ["relevant:0"]
},
"rewrite:0": {
"obj":{
"component_name": "RewriteQuestion",
"params": {
"llm_id": "deepseek-chat",
"temperature": 0.8
}
},
"downstream": ["retrieval:0"],
"upstream": ["relevant:0"]
}
},
"history": [],
"messages": [],
"path": [],
"reference": [],
"answer": []
}

View File

@ -0,0 +1,55 @@
{
"components": {
"begin": {
"obj":{
"component_name": "Begin",
"params": {
"prologue": "Hi there!"
}
},
"downstream": ["tavily:0"],
"upstream": []
},
"tavily:0": {
"obj": {
"component_name": "TavilySearch",
"params": {
"api_key": "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"
}
},
"downstream": ["generate:0"],
"upstream": ["begin"]
},
"generate:0": {
"obj": {
"component_name": "LLM",
"params": {
"llm_id": "deepseek-chat",
"sys_prompt": "You are an intelligent assistant. Please summarize the content of the knowledge base to answer the question. Please list the data in the knowledge base and answer in detail. When all knowledge base content is irrelevant to the question, your answer must include the sentence \"The answer you are looking for is not found in the knowledge base!\" Answers need to consider chat history.\n Here is the knowledge base:\n {tavily:0@formalized_content}\n The above is the knowledge base.",
"temperature": 0.2
}
},
"downstream": ["message:0"],
"upstream": ["tavily:0"]
},
"message:0": {
"obj": {
"component_name": "Message",
"params": {
"content": ["{generate:0@content}"]
}
},
"downstream": [],
"upstream": ["generate:0"]
}
},
"history": [],
"path": [],
"retrival": {"chunks": [], "doc_aggs": []},
"globals": {
"sys.query": "",
"sys.user_id": "",
"sys.conversation_turns": 0,
"sys.files": []
}
}

33
agent/tools/__init__.py Normal file
View File

@ -0,0 +1,33 @@
import os
import importlib
import inspect
from types import ModuleType
from typing import Dict, Type
_package_path = os.path.dirname(__file__)
__all_classes: Dict[str, Type] = {}
def _import_submodules() -> None:
for filename in os.listdir(_package_path): # noqa: F821
if filename.startswith("__") or not filename.endswith(".py") or filename.startswith("base"):
continue
module_name = filename[:-3]
try:
module = importlib.import_module(f".{module_name}", package=__name__)
_extract_classes_from_module(module) # noqa: F821
except ImportError as e:
print(f"Warning: Failed to import module {module_name}: {str(e)}")
def _extract_classes_from_module(module: ModuleType) -> None:
for name, obj in inspect.getmembers(module):
if (inspect.isclass(obj) and
obj.__module__ == module.__name__ and not name.startswith("_")):
__all_classes[name] = obj
globals()[name] = obj
_import_submodules()
__all__ = list(__all_classes.keys()) + ["__all_classes"]
del _package_path, _import_submodules, _extract_classes_from_module

96
agent/tools/arxiv.py Normal file
View File

@ -0,0 +1,96 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from abc import ABC
import arxiv
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
from api.utils.api_utils import timeout
class ArXivParam(ToolParamBase):
"""
Define the ArXiv component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "arxiv_search",
"description": """arXiv is a free distribution service and an open-access archive for nearly 2.4 million scholarly articles in the fields of physics, mathematics, computer science, quantitative biology, quantitative finance, statistics, electrical engineering and systems science, and economics. Materials on this site are not peer-reviewed by arXiv.""",
"parameters": {
"query": {
"type": "string",
"description": "The search keywords to execute with arXiv. The keywords should be the most important words/terms(includes synonyms) from the original request.",
"default": "{sys.query}",
"required": True
}
}
}
super().__init__()
self.top_n = 12
self.sort_by = 'submittedDate'
def check(self):
self.check_positive_integer(self.top_n, "Top N")
self.check_valid_value(self.sort_by, "ArXiv Search Sort_by",
['submittedDate', 'lastUpdatedDate', 'relevance'])
def get_input_form(self) -> dict[str, dict]:
return {
"query": {
"name": "Query",
"type": "line"
}
}
class ArXiv(ToolBase, ABC):
component_name = "ArXiv"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
try:
sort_choices = {"relevance": arxiv.SortCriterion.Relevance,
"lastUpdatedDate": arxiv.SortCriterion.LastUpdatedDate,
'submittedDate': arxiv.SortCriterion.SubmittedDate}
arxiv_client = arxiv.Client()
search = arxiv.Search(
query=kwargs["query"],
max_results=self._param.top_n,
sort_by=sort_choices[self._param.sort_by]
)
self._retrieve_chunks(list(arxiv_client.results(search)),
get_title=lambda r: r.title,
get_url=lambda r: r.pdf_url,
get_content=lambda r: r.summary)
return self.output("formalized_content")
except Exception as e:
last_e = e
logging.exception(f"ArXiv error: {e}")
time.sleep(self._param.delay_after_error)
if last_e:
self.set_output("_ERROR", str(last_e))
return f"ArXiv error: {last_e}"
assert False, self.output()

167
agent/tools/base.py Normal file
View File

@ -0,0 +1,167 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import re
import time
from copy import deepcopy
from functools import partial
from typing import TypedDict, List, Any
from agent.component.base import ComponentParamBase, ComponentBase
from api.utils import hash_str2int
from rag.llm.chat_model import ToolCallSession
from rag.prompts.prompts import kb_prompt
from rag.utils.mcp_tool_call_conn import MCPToolCallSession
class ToolParameter(TypedDict):
type: str
description: str
displayDescription: str
enum: List[str]
required: bool
class ToolMeta(TypedDict):
name: str
displayName: str
description: str
displayDescription: str
parameters: dict[str, ToolParameter]
class LLMToolPluginCallSession(ToolCallSession):
def __init__(self, tools_map: dict[str, object], callback: partial):
self.tools_map = tools_map
self.callback = callback
def tool_call(self, name: str, arguments: dict[str, Any]) -> Any:
assert name in self.tools_map, f"LLM tool {name} does not exist"
self.callback(name, arguments, " running ...")
if isinstance(self.tools_map[name], MCPToolCallSession):
resp = self.tools_map[name].tool_call(name, arguments, 60)
else:
resp = self.tools_map[name].invoke(**arguments)
return resp
def get_tool_obj(self, name):
return self.tools_map[name]
class ToolParamBase(ComponentParamBase):
def __init__(self):
#self.meta:ToolMeta = None
super().__init__()
self._init_inputs()
self._init_attr_by_meta()
def _init_inputs(self):
self.inputs = {}
for k,p in self.meta["parameters"].items():
self.inputs[k] = deepcopy(p)
def _init_attr_by_meta(self):
for k,p in self.meta["parameters"].items():
if not hasattr(self, k):
setattr(self, k, p.get("default"))
def get_meta(self):
params = {}
for k, p in self.meta["parameters"].items():
params[k] = {
"type": p["type"],
"description": p["description"]
}
if "enum" in p:
params[k]["enum"] = p["enum"]
desc = self.meta["description"]
if hasattr(self, "description"):
desc = self.description
function_name = self.meta["name"]
if hasattr(self, "function_name"):
function_name = self.function_name
return {
"type": "function",
"function": {
"name": function_name,
"description": desc,
"parameters": {
"type": "object",
"properties": params,
"required": [k for k, p in self.meta["parameters"].items() if p["required"]]
}
}
}
class ToolBase(ComponentBase):
def __init__(self, canvas, id, param: ComponentParamBase):
from agent.canvas import Canvas # Local import to avoid cyclic dependency
assert isinstance(canvas, Canvas), "canvas must be an instance of Canvas"
self._canvas = canvas
self._id = id
self._param = param
self._param.check()
def get_meta(self) -> dict[str, Any]:
return self._param.get_meta()
def invoke(self, **kwargs):
self.set_output("_created_time", time.perf_counter())
try:
res = self._invoke(**kwargs)
except Exception as e:
self._param.outputs["_ERROR"] = {"value": str(e)}
logging.exception(e)
res = str(e)
self._param.debug_inputs = []
self.set_output("_elapsed_time", time.perf_counter() - self.output("_created_time"))
return res
def _retrieve_chunks(self, res_list: list, get_title, get_url, get_content, get_score=None):
chunks = []
aggs = []
for r in res_list:
content = get_content(r)
if not content:
continue
content = re.sub(r"!?\[[a-z]+\]\(data:image/png;base64,[ 0-9A-Za-z/_=+-]+\)", "", content)
content = content[:10000]
if not content:
continue
id = str(hash_str2int(content))
title = get_title(r)
url = get_url(r)
score = get_score(r) if get_score else 1
chunks.append({
"chunk_id": id,
"content": content,
"doc_id": id,
"docnm_kwd": title,
"similarity": score,
"url": url
})
aggs.append({
"doc_name": title,
"doc_id": id,
"count": 1,
"url": url
})
self._canvas.add_refernce(chunks, aggs)
self.set_output("formalized_content", "\n".join(kb_prompt({"chunks": chunks, "doc_aggs": aggs}, 200000, True)))

192
agent/tools/code_exec.py Normal file
View File

@ -0,0 +1,192 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import base64
import logging
import os
from abc import ABC
from enum import StrEnum
from typing import Optional
from pydantic import BaseModel, Field, field_validator
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from api import settings
from api.utils.api_utils import timeout
class Language(StrEnum):
PYTHON = "python"
NODEJS = "nodejs"
class CodeExecutionRequest(BaseModel):
code_b64: str = Field(..., description="Base64 encoded code string")
language: str = Field(default=Language.PYTHON.value, description="Programming language")
arguments: Optional[dict] = Field(default={}, description="Arguments")
@field_validator("code_b64")
@classmethod
def validate_base64(cls, v: str) -> str:
try:
base64.b64decode(v, validate=True)
return v
except Exception as e:
raise ValueError(f"Invalid base64 encoding: {str(e)}")
@field_validator("language", mode="before")
@classmethod
def normalize_language(cls, v) -> str:
if isinstance(v, str):
low = v.lower()
if low in ("python", "python3"):
return "python"
elif low in ("javascript", "nodejs"):
return "nodejs"
raise ValueError(f"Unsupported language: {v}")
class CodeExecParam(ToolParamBase):
"""
Define the code sandbox component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "execute_code",
"description": """
This tool has a sandbox that can execute code written in 'Python'/'Javascript'. It recieves a piece of code and return a Json string.
Here's a code example for Python(`main` function MUST be included):
def main(arg1: str, arg2: str) -> dict:
return {
"result": arg1 + arg2,
}
Here's a code example for Javascript(`main` function MUST be included and exported):
const axios = require('axios');
async function main(args) {
try {
const response = await axios.get('https://github.com/infiniflow/ragflow');
console.log('Body:', response.data);
} catch (error) {
console.error('Error:', error.message);
}
}
module.exports = { main };
""",
"parameters": {
"lang": {
"type": "string",
"description": "The programming language of this piece of code.",
"enum": ["python", "javascript"],
"required": True,
},
"script": {
"type": "string",
"description": "A piece of code in right format. There MUST be main function.",
"required": True
}
}
}
super().__init__()
self.lang = Language.PYTHON.value
self.script = "def main(arg1: str, arg2: str) -> dict: return {\"result\": arg1 + arg2}"
self.arguments = {}
self.outputs = {"result": {"value": "", "type": "string"}}
def check(self):
self.check_valid_value(self.lang, "Support languages", ["python", "python3", "nodejs", "javascript"])
self.check_empty(self.script, "Script")
def get_input_form(self) -> dict[str, dict]:
res = {}
for k, v in self.arguments.items():
res[k] = {
"type": "line",
"name": k
}
return res
class CodeExec(ToolBase, ABC):
component_name = "CodeExec"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
def _invoke(self, **kwargs):
lang = kwargs.get("lang", self._param.lang)
script = kwargs.get("script", self._param.script)
arguments = {}
for k, v in self._param.arguments.items():
if kwargs.get(k):
arguments[k] = kwargs[k]
continue
arguments[k] = self._canvas.get_variable_value(v) if v else None
self._execute_code(
language=lang,
code=script,
arguments=arguments
)
def _execute_code(self, language: str, code: str, arguments: dict):
import requests
try:
code_b64 = self._encode_code(code)
code_req = CodeExecutionRequest(code_b64=code_b64, language=language, arguments=arguments).model_dump()
except Exception as e:
self.set_output("_ERROR", "construct code request error: " + str(e))
try:
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=10)
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run", code_req, resp.status_code)
if resp.status_code != 200:
resp.raise_for_status()
body = resp.json()
if body:
stderr = body.get("stderr")
if stderr:
self.set_output("_ERROR", stderr)
return
try:
rt = eval(body.get("stdout", ""))
except Exception:
rt = body.get("stdout", "")
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run -> {rt}")
if isinstance(rt, tuple):
for i, (k, o) in enumerate(self._param.outputs.items()):
if k.find("_") == 0:
continue
o["value"] = rt[i]
elif isinstance(rt, dict):
for i, (k, o) in enumerate(self._param.outputs.items()):
if k not in rt or k.find("_") == 0:
continue
o["value"] = rt[k]
else:
for i, (k, o) in enumerate(self._param.outputs.items()):
if k.find("_") == 0:
continue
o["value"] = rt
else:
self.set_output("_ERROR", "There is no response from sandbox")
except Exception as e:
self.set_output("_ERROR", "Exception executing code: " + str(e))
return self.output()
def _encode_code(self, code: str) -> str:
return base64.b64encode(code.encode("utf-8")).decode("utf-8")

View File

@ -16,11 +16,12 @@
from abc import ABC
import asyncio
from crawl4ai import AsyncWebCrawler
from agent.component.base import ComponentBase, ComponentParamBase
from agent.tools.base import ToolParamBase, ToolBase
from api.utils.web_utils import is_valid_url
class CrawlerParam(ComponentParamBase):
class CrawlerParam(ToolParamBase):
"""
Define the Crawler component parameters.
"""
@ -34,7 +35,7 @@ class CrawlerParam(ComponentParamBase):
self.check_valid_value(self.extract_type, "Type of content from the crawler", ['html', 'markdown', 'content'])
class Crawler(ComponentBase, ABC):
class Crawler(ToolBase, ABC):
component_name = "Crawler"
def _run(self, history, **kwargs):

114
agent/tools/duckduckgo.py Normal file
View File

@ -0,0 +1,114 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from abc import ABC
from duckduckgo_search import DDGS
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
from api.utils.api_utils import timeout
class DuckDuckGoParam(ToolParamBase):
"""
Define the DuckDuckGo component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "duckduckgo_search",
"description": "DuckDuckGo is a search engine focused on privacy. It offers search capabilities for web pages, images, and provides translation services. DuckDuckGo also features a private AI chat interface, providing users with an AI assistant that prioritizes data protection.",
"parameters": {
"query": {
"type": "string",
"description": "The search keywords to execute with DuckDuckGo. The keywords should be the most important words/terms(includes synonyms) from the original request.",
"default": "{sys.query}",
"required": True
},
"channel": {
"type": "string",
"description": "default:general. The category of the search. `news` is useful for retrieving real-time updates, particularly about politics, sports, and major current events covered by mainstream media sources. `general` is for broader, more general-purpose searches that may include a wide range of sources.",
"enum": ["general", "news"],
"default": "general",
"required": False,
},
}
}
super().__init__()
self.top_n = 10
self.channel = "text"
def check(self):
self.check_positive_integer(self.top_n, "Top N")
self.check_valid_value(self.channel, "Web Search or News", ["text", "news"])
def get_input_form(self) -> dict[str, dict]:
return {
"query": {
"name": "Query",
"type": "line"
},
"channel": {
"name": "Channel",
"type": "options",
"value": "general",
"options": ["general", "news"]
}
}
class DuckDuckGo(ToolBase, ABC):
component_name = "DuckDuckGo"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
try:
if kwargs.get("topic", "general") == "general":
with DDGS() as ddgs:
# {'title': '', 'href': '', 'body': ''}
duck_res = ddgs.text(kwargs["query"], max_results=self._param.top_n)
self._retrieve_chunks(duck_res,
get_title=lambda r: r["title"],
get_url=lambda r: r.get("href", r.get("url")),
get_content=lambda r: r["body"])
self.set_output("json", duck_res)
return self.output("formalized_content")
else:
with DDGS() as ddgs:
# {'date': '', 'title': '', 'body': '', 'url': '', 'image': '', 'source': ''}
duck_res = ddgs.news(kwargs["query"], max_results=self._param.top_n)
self._retrieve_chunks(duck_res,
get_title=lambda r: r["title"],
get_url=lambda r: r.get("href", r.get("url")),
get_content=lambda r: r["body"])
self.set_output("json", duck_res)
return self.output("formalized_content")
except Exception as e:
last_e = e
logging.exception(f"DuckDuckGo error: {e}")
time.sleep(self._param.delay_after_error)
if last_e:
self.set_output("_ERROR", str(last_e))
return f"DuckDuckGo error: {last_e}"
assert False, self.output()

207
agent/tools/email.py Normal file
View File

@ -0,0 +1,207 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import time
from abc import ABC
import json
import smtplib
import logging
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
from email.header import Header
from email.utils import formataddr
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from api.utils.api_utils import timeout
class EmailParam(ToolParamBase):
"""
Define the Email component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "email",
"description": "The email is a method of electronic communication for sending and receiving information through the Internet. This tool helps users to send emails to one person or to multiple recipients with support for CC, BCC, file attachments, and markdown-to-HTML conversion.",
"parameters": {
"to_email": {
"type": "string",
"description": "The target email address.",
"default": "{sys.query}",
"required": True
},
"cc_email": {
"type": "string",
"description": "The other email addresses needs to be send to. Comma splited.",
"default": "",
"required": False
},
"content": {
"type": "string",
"description": "The content of the email.",
"default": "",
"required": False
},
"subject": {
"type": "string",
"description": "The subject/title of the email.",
"default": "",
"required": False
}
}
}
super().__init__()
# Fixed configuration parameters
self.smtp_server = "" # SMTP server address
self.smtp_port = 465 # SMTP port
self.email = "" # Sender email
self.password = "" # Email authorization code
self.sender_name = "" # Sender name
def check(self):
# Check required parameters
self.check_empty(self.smtp_server, "SMTP Server")
self.check_empty(self.email, "Email")
self.check_empty(self.password, "Password")
self.check_empty(self.sender_name, "Sender Name")
def get_input_form(self) -> dict[str, dict]:
return {
"to_email": {
"name": "To ",
"type": "line"
},
"subject": {
"name": "Subject",
"type": "line",
"optional": True
},
"cc_email": {
"name": "CC To",
"type": "line",
"optional": True
},
}
class Email(ToolBase, ABC):
component_name = "Email"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
def _invoke(self, **kwargs):
if not kwargs.get("to_email"):
self.set_output("success", False)
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
try:
# Parse JSON string passed from upstream
email_data = kwargs
# Validate required fields
if "to_email" not in email_data:
return Email.be_output("Missing required field: to_email")
# Create email object
msg = MIMEMultipart('alternative')
# Properly handle sender name encoding
msg['From'] = formataddr((str(Header(self._param.sender_name,'utf-8')), self._param.email))
msg['To'] = email_data["to_email"]
if email_data.get("cc_email"):
msg['Cc'] = email_data["cc_email"]
msg['Subject'] = Header(email_data.get("subject", "No Subject"), 'utf-8').encode()
# Use content from email_data or default content
email_content = email_data.get("content", "No content provided")
# msg.attach(MIMEText(email_content, 'plain', 'utf-8'))
msg.attach(MIMEText(email_content, 'html', 'utf-8'))
# Connect to SMTP server and send
logging.info(f"Connecting to SMTP server {self._param.smtp_server}:{self._param.smtp_port}")
context = smtplib.ssl.create_default_context()
with smtplib.SMTP(self._param.smtp_server, self._param.smtp_port) as server:
server.ehlo()
server.starttls(context=context)
server.ehlo()
# Login
logging.info(f"Attempting to login with email: {self._param.email}")
server.login(self._param.email, self._param.password)
# Get all recipient list
recipients = [email_data["to_email"]]
if email_data.get("cc_email"):
recipients.extend(email_data["cc_email"].split(','))
# Send email
logging.info(f"Sending email to recipients: {recipients}")
try:
server.send_message(msg, self._param.email, recipients)
success = True
except Exception as e:
logging.error(f"Error during send_message: {str(e)}")
# Try alternative method
server.sendmail(self._param.email, recipients, msg.as_string())
success = True
try:
server.quit()
except Exception as e:
# Ignore errors when closing connection
logging.warning(f"Non-fatal error during connection close: {str(e)}")
self.set_output("success", success)
return success
except json.JSONDecodeError:
error_msg = "Invalid JSON format in input"
logging.error(error_msg)
self.set_output("_ERROR", error_msg)
self.set_output("success", False)
return False
except smtplib.SMTPAuthenticationError:
error_msg = "SMTP Authentication failed. Please check your email and authorization code."
logging.error(error_msg)
self.set_output("_ERROR", error_msg)
self.set_output("success", False)
return False
except smtplib.SMTPConnectError:
error_msg = f"Failed to connect to SMTP server {self._param.smtp_server}:{self._param.smtp_port}"
logging.error(error_msg)
last_e = error_msg
time.sleep(self._param.delay_after_error)
except smtplib.SMTPException as e:
error_msg = f"SMTP error occurred: {str(e)}"
logging.error(error_msg)
last_e = error_msg
time.sleep(self._param.delay_after_error)
except Exception as e:
error_msg = f"Unexpected error: {str(e)}"
logging.error(error_msg)
self.set_output("_ERROR", error_msg)
self.set_output("success", False)
return False
if last_e:
self.set_output("_ERROR", str(last_e))
return False
assert False, self.output()

133
agent/tools/exesql.py Normal file
View File

@ -0,0 +1,133 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from abc import ABC
import pandas as pd
import pymysql
import psycopg2
import pyodbc
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from api.utils.api_utils import timeout
class ExeSQLParam(ToolParamBase):
"""
Define the ExeSQL component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "execute_sql",
"description": "This is a tool that can execute SQL.",
"parameters": {
"sql": {
"type": "string",
"description": "The SQL needs to be executed.",
"default": "{sys.query}",
"required": True
}
}
}
super().__init__()
self.db_type = "mysql"
self.database = ""
self.username = ""
self.host = ""
self.port = 3306
self.password = ""
self.max_records = 1024
def check(self):
self.check_valid_value(self.db_type, "Choose DB type", ['mysql', 'postgresql', 'mariadb', 'mssql'])
self.check_empty(self.database, "Database name")
self.check_empty(self.username, "database username")
self.check_empty(self.host, "IP Address")
self.check_positive_integer(self.port, "IP Port")
self.check_empty(self.password, "Database password")
self.check_positive_integer(self.max_records, "Maximum number of records")
if self.database == "rag_flow":
if self.host == "ragflow-mysql":
raise ValueError("For the security reason, it dose not support database named rag_flow.")
if self.password == "infini_rag_flow":
raise ValueError("For the security reason, it dose not support database named rag_flow.")
def get_input_form(self) -> dict[str, dict]:
return {
"sql": {
"name": "SQL",
"type": "line"
}
}
class ExeSQL(ToolBase, ABC):
component_name = "ExeSQL"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
def _invoke(self, **kwargs):
sql = kwargs.get("sql")
if not sql:
raise Exception("SQL for `ExeSQL` MUST not be empty.")
sqls = sql.split(";")
if self._param.db_type in ["mysql", "mariadb"]:
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)
elif self._param.db_type == 'postgresql':
db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=self._param.host,
port=self._param.port, password=self._param.password)
elif self._param.db_type == 'mssql':
conn_str = (
r'DRIVER={ODBC Driver 17 for SQL Server};'
r'SERVER=' + self._param.host + ',' + str(self._param.port) + ';'
r'DATABASE=' + self._param.database + ';'
r'UID=' + self._param.username + ';'
r'PWD=' + self._param.password
)
db = pyodbc.connect(conn_str)
try:
cursor = db.cursor()
except Exception as e:
raise Exception("Database Connection Failed! \n" + str(e))
sql_res = []
formalized_content = []
for single_sql in sqls:
single_sql = single_sql.replace('```','')
if not single_sql:
continue
cursor.execute(single_sql)
if cursor.rowcount == 0:
sql_res.append({"content": "No record in the database!"})
break
if self._param.db_type == 'mssql':
single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.max_records),
columns=[desc[0] for desc in cursor.description])
else:
single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.max_records)])
single_res.columns = [i[0] for i in cursor.description]
sql_res.append(single_res.to_dict(orient='records'))
formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f"))
self.set_output("json", sql_res)
self.set_output("formalized_content", "\n\n".join(formalized_content))
return self.output("formalized_content")
def debug(self, **kwargs):
return self._run([], **kwargs)

88
agent/tools/github.py Normal file
View File

@ -0,0 +1,88 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from abc import ABC
import requests
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
from api.utils.api_utils import timeout
class GitHubParam(ToolParamBase):
"""
Define the GitHub component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "github_search",
"description": """GitHub repository search is a feature that enables users to find specific repositories on the GitHub platform. This search functionality allows users to locate projects, codebases, and other content hosted on GitHub based on various criteria.""",
"parameters": {
"query": {
"type": "string",
"description": "The search keywords to execute with GitHub. The keywords should be the most important words/terms(includes synonyms) from the original request.",
"default": "{sys.query}",
"required": True
}
}
}
super().__init__()
self.top_n = 10
def check(self):
self.check_positive_integer(self.top_n, "Top N")
def get_input_form(self) -> dict[str, dict]:
return {
"query": {
"name": "Query",
"type": "line"
}
}
class GitHub(ToolBase, ABC):
component_name = "GitHub"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
try:
url = 'https://api.github.com/search/repositories?q=' + kwargs["query"] + '&sort=stars&order=desc&per_page=' + str(
self._param.top_n)
headers = {"Content-Type": "application/vnd.github+json", "X-GitHub-Api-Version": '2022-11-28'}
response = requests.get(url=url, headers=headers).json()
self._retrieve_chunks(response['items'],
get_title=lambda r: r["name"],
get_url=lambda r: r["html_url"],
get_content=lambda r: str(r["description"]) + '\n stars:' + str(r['watchers']))
self.set_output("json", response['items'])
return self.output("formalized_content")
except Exception as e:
last_e = e
logging.exception(f"GitHub error: {e}")
time.sleep(self._param.delay_after_error)
if last_e:
self.set_output("_ERROR", str(last_e))
return f"GitHub error: {last_e}"
assert False, self.output()

View File

@ -14,26 +14,52 @@
# limitations under the License.
#
import logging
import os
import time
from abc import ABC
from serpapi import GoogleSearch
import pandas as pd
from agent.component.base import ComponentBase, ComponentParamBase
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
from api.utils.api_utils import timeout
class GoogleParam(ComponentParamBase):
class GoogleParam(ToolParamBase):
"""
Define the Google component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "google_search",
"description": """Search the world's information, including webpages, images, videos and more. Google has many special features to help you find exactly what you're looking ...""",
"parameters": {
"q": {
"type": "string",
"description": "The search keywords to execute with Google. The keywords should be the most important words/terms(includes synonyms) from the original request.",
"default": "{sys.query}",
"required": True
},
"start": {
"type": "integer",
"description": "Parameter defines the result offset. It skips the given number of results. It's used for pagination. (e.g., 0 (default) is the first page of results, 10 is the 2nd page of results, 20 is the 3rd page of results, etc.). Google Local Results only accepts multiples of 20(e.g. 20 for the second page results, 40 for the third page results, etc.) as the `start` value.",
"default": "0",
"required": False,
},
"num": {
"type": "integer",
"description": "Parameter defines the maximum number of results to return. (e.g., 10 (default) returns 10 results, 40 returns 40 results, and 100 returns 100 results). The use of num may introduce latency, and/or prevent the inclusion of specialized result types. It is better to omit this parameter unless it is strictly necessary to increase the number of results per page. Results are not guaranteed to have the number of results specified in num.",
"default": "6",
"required": False,
}
}
}
super().__init__()
self.top_n = 10
self.api_key = "xxx"
self.start = 0
self.num = 6
self.api_key = ""
self.country = "cn"
self.language = "en"
def check(self):
self.check_positive_integer(self.top_n, "Top N")
self.check_empty(self.api_key, "SerpApi API key")
self.check_valid_value(self.country, "Google Country",
['af', 'al', 'dz', 'as', 'ad', 'ao', 'ai', 'aq', 'ag', 'ar', 'am', 'aw', 'au', 'at',
@ -69,28 +95,60 @@ class GoogleParam(ComponentParamBase):
'ug', 'uk', 'ur', 'uz', 'vu', 'vi', 'cy', 'wo', 'xh', 'yi', 'yo', 'zu']
)
def get_input_form(self) -> dict[str, dict]:
return {
"q": {
"name": "Query",
"type": "line"
},
"start": {
"name": "From",
"type": "integer",
"value": 0
},
"num": {
"name": "Limit",
"type": "integer",
"value": 12
}
}
class Google(ComponentBase, ABC):
class Google(ToolBase, ABC):
component_name = "Google"
def _run(self, history, **kwargs):
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
return Google.be_output("")
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
def _invoke(self, **kwargs):
if not kwargs.get("q"):
self.set_output("formalized_content", "")
return ""
try:
client = GoogleSearch(
{"engine": "google", "q": ans, "api_key": self._param.api_key, "gl": self._param.country,
"hl": self._param.language, "num": self._param.top_n})
google_res = [{"content": '<a href="' + i["link"] + '">' + i["title"] + '</a> ' + i["snippet"]} for i in
client.get_dict()["organic_results"]]
except Exception:
return Google.be_output("**ERROR**: Existing Unavailable Parameters!")
params = {
"api_key": self._param.api_key,
"engine": "google",
"q": kwargs["q"],
"google_domain": "google.com",
"gl": self._param.country,
"hl": self._param.language
}
last_e = ""
for _ in range(self._param.max_retries+1):
try:
search = GoogleSearch(params).get_dict()
self._retrieve_chunks(search["organic_results"],
get_title=lambda r: r["title"],
get_url=lambda r: r["link"],
get_content=lambda r: r.get("about_this_result", {}).get("source", {}).get("description", r["snippet"])
)
self.set_output("json", search["organic_results"])
return self.output("formalized_content")
except Exception as e:
last_e = e
logging.exception(f"Google error: {e}")
time.sleep(self._param.delay_after_error)
if not google_res:
return Google.be_output("")
if last_e:
self.set_output("_ERROR", str(last_e))
return f"Google error: {last_e}"
assert False, self.output()
df = pd.DataFrame(google_res)
logging.debug(f"df: {df}")
return df

View File

@ -0,0 +1,93 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from abc import ABC
from scholarly import scholarly
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
from api.utils.api_utils import timeout
class GoogleScholarParam(ToolParamBase):
"""
Define the GoogleScholar component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "google_scholar_search",
"description": """Google Scholar provides a simple way to broadly search for scholarly literature. From one place, you can search across many disciplines and sources: articles, theses, books, abstracts and court opinions, from academic publishers, professional societies, online repositories, universities and other web sites. Google Scholar helps you find relevant work across the world of scholarly research.""",
"parameters": {
"query": {
"type": "string",
"description": "The search keyword to execute with Google Scholar. The keywords should be the most important words/terms(includes synonyms) from the original request.",
"default": "{sys.query}",
"required": True
}
}
}
super().__init__()
self.top_n = 12
self.sort_by = 'relevance'
self.year_low = None
self.year_high = None
self.patents = True
def check(self):
self.check_positive_integer(self.top_n, "Top N")
self.check_valid_value(self.sort_by, "GoogleScholar Sort_by", ['date', 'relevance'])
self.check_boolean(self.patents, "Whether or not to include patents, defaults to True")
def get_input_form(self) -> dict[str, dict]:
return {
"query": {
"name": "Query",
"type": "line"
}
}
class GoogleScholar(ToolBase, ABC):
component_name = "GoogleScholar"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
try:
scholar_client = scholarly.search_pubs(kwargs["query"], patents=self._param.patents, year_low=self._param.year_low,
year_high=self._param.year_high, sort_by=self._param.sort_by)
self._retrieve_chunks(scholar_client,
get_title=lambda r: r['bib']['title'],
get_url=lambda r: r["pub_url"],
get_content=lambda r: "\n author: " + ",".join(r['bib']['author']) + '\n Abstract: ' + r['bib'].get('abstract', 'no abstract')
)
self.set_output("json", list(scholar_client))
return self.output("formalized_content")
except Exception as e:
last_e = e
logging.exception(f"GoogleScholar error: {e}")
time.sleep(self._param.delay_after_error)
if last_e:
self.set_output("_ERROR", str(last_e))
return f"GoogleScholar error: {last_e}"
assert False, self.output()

105
agent/tools/pubmed.py Normal file
View File

@ -0,0 +1,105 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from abc import ABC
from Bio import Entrez
import re
import xml.etree.ElementTree as ET
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
from api.utils.api_utils import timeout
class PubMedParam(ToolParamBase):
"""
Define the PubMed component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "pubmed_search",
"description": """
PubMed is an openly accessible, free database which includes primarily the MEDLINE database of references and abstracts on life sciences and biomedical topics.
In addition to MEDLINE, PubMed provides access to:
- older references from the print version of Index Medicus, back to 1951 and earlier
- references to some journals before they were indexed in Index Medicus and MEDLINE, for instance Science, BMJ, and Annals of Surgery
- very recent entries to records for an article before it is indexed with Medical Subject Headings (MeSH) and added to MEDLINE
- a collection of books available full-text and other subsets of NLM records[4]
- PMC citations
- NCBI Bookshelf
""",
"parameters": {
"query": {
"type": "string",
"description": "The search keywords to execute with PubMed. The keywords should be the most important words/terms(includes synonyms) from the original request.",
"default": "{sys.query}",
"required": True
}
}
}
super().__init__()
self.top_n = 12
self.email = "A.N.Other@example.com"
def check(self):
self.check_positive_integer(self.top_n, "Top N")
def get_input_form(self) -> dict[str, dict]:
return {
"query": {
"name": "Query",
"type": "line"
}
}
class PubMed(ToolBase, ABC):
component_name = "PubMed"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
try:
Entrez.email = self._param.email
pubmedids = Entrez.read(Entrez.esearch(db='pubmed', retmax=self._param.top_n, term=kwargs["query"]))['IdList']
pubmedcnt = ET.fromstring(re.sub(r'<(/?)b>|<(/?)i>', '', Entrez.efetch(db='pubmed', id=",".join(pubmedids),
retmode="xml").read().decode("utf-8")))
self._retrieve_chunks(pubmedcnt.findall("PubmedArticle"),
get_title=lambda child: child.find("MedlineCitation").find("Article").find("ArticleTitle").text,
get_url=lambda child: "https://pubmed.ncbi.nlm.nih.gov/" + child.find("MedlineCitation").find("PMID").text,
get_content=lambda child: child.find("MedlineCitation") \
.find("Article") \
.find("Abstract") \
.find("AbstractText").text \
if child.find("MedlineCitation")\
.find("Article").find("Abstract") \
else "No abstract available")
return self.output("formalized_content")
except Exception as e:
last_e = e
logging.exception(f"PubMed error: {e}")
time.sleep(self._param.delay_after_error)
if last_e:
self.set_output("_ERROR", str(last_e))
return f"PubMed error: {last_e}"
assert False, self.output()

161
agent/tools/retrieval.py Normal file
View File

@ -0,0 +1,161 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
from abc import ABC
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from api.db import LLMType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.llm_service import LLMBundle
from api import settings
from api.utils.api_utils import timeout
from rag.app.tag import label_question
from rag.prompts import kb_prompt
from rag.prompts.prompts import cross_languages
class RetrievalParam(ToolParamBase):
"""
Define the Retrieval component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "search_my_dateset",
"description": "This tool can be utilized for relevant content searching in the datasets.",
"parameters": {
"query": {
"type": "string",
"description": "The keywords to search the dataset. The keywords should be the most important words/terms(includes synonyms) from the original request.",
"default": "",
"required": True
}
}
}
super().__init__()
self.function_name = "search_my_dateset"
self.description = "This tool can be utilized for relevant content searching in the datasets."
self.similarity_threshold = 0.2
self.keywords_similarity_weight = 0.5
self.top_n = 8
self.top_k = 1024
self.kb_ids = []
self.kb_vars = []
self.rerank_id = ""
self.empty_response = ""
self.use_kg = False
self.cross_languages = []
def check(self):
self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
self.check_decimal_float(self.keywords_similarity_weight, "[Retrieval] Keyword similarity weight")
self.check_positive_number(self.top_n, "[Retrieval] Top N")
def get_input_form(self) -> dict[str, dict]:
return {
"query": {
"name": "Query",
"type": "line"
}
}
class Retrieval(ToolBase, ABC):
component_name = "Retrieval"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", self._param.empty_response)
kb_ids: list[str] = []
for id in self._param.kb_ids:
if id.find("@") < 0:
kb_ids.append(id)
continue
kb_nm = self._canvas.get_variable_value(id)
e, kb = KnowledgebaseService.get_by_name(kb_nm)
if not e:
raise Exception(f"Dataset({kb_nm}) does not exist.")
kb_ids.append(kb.id)
filtered_kb_ids: list[str] = list(set([kb_id for kb_id in kb_ids if kb_id]))
kbs = KnowledgebaseService.get_by_ids(filtered_kb_ids)
if not kbs:
raise Exception("No dataset is selected.")
embd_nms = list(set([kb.embd_id for kb in kbs]))
assert len(embd_nms) == 1, "Knowledge bases use different embedding models."
embd_mdl = None
if embd_nms:
embd_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, embd_nms[0])
rerank_mdl = None
if self._param.rerank_id:
rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)
query = kwargs["query"]
if self._param.cross_languages:
query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
if kbs:
query = re.sub(r"^user[:\s]*", "", query, flags=re.IGNORECASE)
kbinfos = settings.retrievaler.retrieval(
query,
embd_mdl,
[kb.tenant_id for kb in kbs],
filtered_kb_ids,
1,
self._param.top_n,
self._param.similarity_threshold,
1 - self._param.keywords_similarity_weight,
aggs=False,
rerank_mdl=rerank_mdl,
rank_feature=label_question(query, kbs),
)
if self._param.use_kg:
ck = settings.kg_retrievaler.retrieval(query,
[kb.tenant_id for kb in kbs],
kb_ids,
embd_mdl,
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
else:
kbinfos = {"chunks": [], "doc_aggs": []}
if self._param.use_kg and kbs:
ck = settings.kg_retrievaler.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if ck["content_with_weight"]:
ck["content"] = ck["content_with_weight"]
del ck["content_with_weight"]
kbinfos["chunks"].insert(0, ck)
for ck in kbinfos["chunks"]:
if "vector" in ck:
del ck["vector"]
if "content_ltks" in ck:
del ck["content_ltks"]
if not kbinfos["chunks"]:
self.set_output("formalized_content", self._param.empty_response)
return
self._canvas.add_refernce(kbinfos["chunks"], kbinfos["doc_aggs"])
form_cnt = "\n".join(kb_prompt(kbinfos, 200000, True))
self.set_output("formalized_content", form_cnt)
return form_cnt

218
agent/tools/tavily.py Normal file
View File

@ -0,0 +1,218 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from abc import ABC
from tavily import TavilyClient
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from api.utils.api_utils import timeout
class TavilySearchParam(ToolParamBase):
"""
Define the Retrieval component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "tavily_search",
"description": """
Tavily is a search engine optimized for LLMs, aimed at efficient, quick and persistent search results.
When searching:
- Start with specific query which should focus on just a single aspect.
- Number of keywords in query should be less than 5.
- Broaden search terms if needed
- Cross-reference information from multiple sources
""",
"parameters": {
"query": {
"type": "string",
"description": "The search keywords to execute with Tavily. The keywords should be the most important words/terms(includes synonyms) from the original request.",
"default": "{sys.query}",
"required": True
},
"topic": {
"type": "string",
"description": "default:general. The category of the search.news is useful for retrieving real-time updates, particularly about politics, sports, and major current events covered by mainstream media sources. general is for broader, more general-purpose searches that may include a wide range of sources.",
"enum": ["general", "news"],
"default": "general",
"required": False,
},
"include_domains": {
"type": "array",
"description": "default:[]. A list of domains only from which the search results can be included.",
"default": [],
"items": {
"type": "string",
"description": "Domain name that must be included, e.g. www.yahoo.com"
},
"required": False
},
"exclude_domains": {
"type": "array",
"description": "default:[]. A list of domains from which the search results can not be included",
"default": [],
"items": {
"type": "string",
"description": "Domain name that must be excluded, e.g. www.yahoo.com"
},
"required": False
},
}
}
super().__init__()
self.api_key = ""
self.search_depth = "basic" # basic/advanced
self.max_results = 6
self.days = 14
self.include_answer = False
self.include_raw_content = False
self.include_images = False
self.include_image_descriptions = False
def check(self):
self.check_valid_value(self.topic, "Tavily topic: should be in 'general/news'", ["general", "news"])
self.check_valid_value(self.search_depth, "Tavily search depth should be in 'basic/advanced'", ["basic", "advanced"])
self.check_positive_integer(self.max_results, "Tavily max result number should be within [1 20]")
self.check_positive_integer(self.days, "Tavily days should be greater than 1")
def get_input_form(self) -> dict[str, dict]:
return {
"query": {
"name": "Query",
"type": "line"
}
}
class TavilySearch(ToolBase, ABC):
component_name = "TavilySearch"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
self.tavily_client = TavilyClient(api_key=self._param.api_key)
last_e = None
for fld in ["search_depth", "topic", "max_results", "days", "include_answer", "include_raw_content", "include_images", "include_image_descriptions", "include_domains", "exclude_domains"]:
if fld not in kwargs:
kwargs[fld] = getattr(self._param, fld)
for _ in range(self._param.max_retries+1):
try:
kwargs["include_images"] = False
kwargs["include_raw_content"] = False
res = self.tavily_client.search(**kwargs)
self._retrieve_chunks(res["results"],
get_title=lambda r: r["title"],
get_url=lambda r: r["url"],
get_content=lambda r: r["raw_content"] if r["raw_content"] else r["content"],
get_score=lambda r: r["score"])
self.set_output("json", res["results"])
return self.output("formalized_content")
except Exception as e:
last_e = e
logging.exception(f"Tavily error: {e}")
time.sleep(self._param.delay_after_error)
if last_e:
self.set_output("_ERROR", str(last_e))
return f"Tavily error: {last_e}"
assert False, self.output()
class TavilyExtractParam(ToolParamBase):
"""
Define the Retrieval component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "tavily_extract",
"description": "Extract web page content from one or more specified URLs using Tavily Extract.",
"parameters": {
"urls": {
"type": "array",
"description": "The URLs to extract content from.",
"default": "",
"items": {
"type": "string",
"description": "The URL to extract content from, e.g. www.yahoo.com"
},
"required": True
},
"extract_depth": {
"type": "string",
"description": "The depth of the extraction process. advanced extraction retrieves more data, including tables and embedded content, with higher success but may increase latency.basic extraction costs 1 credit per 5 successful URL extractions, while advanced extraction costs 2 credits per 5 successful URL extractions.",
"enum": ["basic", "advanced"],
"default": "basic",
"required": False,
},
"format": {
"type": "string",
"description": "The format of the extracted web page content. markdown returns content in markdown format. text returns plain text and may increase latency.",
"enum": ["markdown", "text"],
"default": "markdown",
"required": False,
}
}
}
super().__init__()
self.api_key = ""
self.extract_depth = "basic" # basic/advanced
self.urls = []
self.format = "markdown"
self.include_images = False
def check(self):
self.check_valid_value(self.extract_depth, "Tavily extract depth should be in 'basic/advanced'", ["basic", "advanced"])
self.check_valid_value(self.format, "Tavily extract format should be in 'markdown/text'", ["markdown", "text"])
def get_input_form(self) -> dict[str, dict]:
return {
"urls": {
"name": "URLs",
"type": "line"
}
}
class TavilyExtract(ToolBase, ABC):
component_name = "TavilyExtract"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))
def _invoke(self, **kwargs):
self.tavily_client = TavilyClient(api_key=self._param.api_key)
last_e = None
for fld in ["urls", "extract_depth", "format"]:
if fld not in kwargs:
kwargs[fld] = getattr(self._param, fld)
if kwargs.get("urls") and isinstance(kwargs["urls"], str):
kwargs["urls"] = kwargs["urls"].split(",")
for _ in range(self._param.max_retries+1):
try:
kwargs["include_images"] = False
res = self.tavily_client.extract(**kwargs)
self.set_output("json", res["results"])
return self.output("json")
except Exception as e:
last_e = e
logging.exception(f"Tavily error: {e}")
if last_e:
self.set_output("_ERROR", str(last_e))
return f"Tavily error: {last_e}"
assert False, self.output()

111
agent/tools/wencai.py Normal file
View File

@ -0,0 +1,111 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from abc import ABC
import pandas as pd
import pywencai
from agent.tools.base import ToolParamBase, ToolMeta, ToolBase
from api.utils.api_utils import timeout
class WenCaiParam(ToolParamBase):
"""
Define the WenCai component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "iwencai",
"description": """
iwencai search: search platform is committed to providing hundreds of millions of investors with the most timely, accurate and comprehensive information, covering news, announcements, research reports, blogs, forums, Weibo, characters, etc.
robo-advisor intelligent stock selection platform: through AI technology, is committed to providing investors with intelligent stock selection, quantitative investment, main force tracking, value investment, technical analysis and other types of stock selection technologies.
fund selection platform: through AI technology, is committed to providing excellent fund, value investment, quantitative analysis and other fund selection technologies for foundation citizens.
""",
"parameters": {
"query": {
"type": "string",
"description": "The question/conditions to select stocks.",
"default": "{sys.query}",
"required": True
}
}
}
super().__init__()
self.top_n = 10
self.query_type = "stock"
def check(self):
self.check_positive_integer(self.top_n, "Top N")
self.check_valid_value(self.query_type, "Query type",
['stock', 'zhishu', 'fund', 'hkstock', 'usstock', 'threeboard', 'conbond', 'insurance',
'futures', 'lccp',
'foreign_exchange'])
def get_input_form(self) -> dict[str, dict]:
return {
"query": {
"name": "Query",
"type": "line"
}
}
class WenCai(ToolBase, ABC):
component_name = "WenCai"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("report", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
try:
wencai_res = []
res = pywencai.get(query=kwargs["query"], query_type=self._param.query_type, perpage=self._param.top_n)
if isinstance(res, pd.DataFrame):
wencai_res.append(res.to_markdown())
elif isinstance(res, dict):
for item in res.items():
if isinstance(item[1], list):
wencai_res.append(item[0] + "\n" + pd.DataFrame(item[1]).to_markdown())
elif isinstance(item[1], str):
wencai_res.append(item[0] + "\n" + item[1])
elif isinstance(item[1], dict):
if "meta" in item[1].keys():
continue
wencai_res.append(pd.DataFrame.from_dict(item[1], orient='index').to_markdown())
elif isinstance(item[1], pd.DataFrame):
if "image_url" in item[1].columns:
continue
wencai_res.append(item[1].to_markdown())
else:
wencai_res.append(item[0] + "\n" + str(item[1]))
self.set_output("report", "\n\n".join(wencai_res))
return self.output("report")
except Exception as e:
last_e = e
logging.exception(f"WenCai error: {e}")
time.sleep(self._param.delay_after_error)
if last_e:
self.set_output("_ERROR", str(last_e))
return f"WenCai error: {last_e}"
assert False, self.output()

98
agent/tools/wikipedia.py Normal file
View File

@ -0,0 +1,98 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from abc import ABC
import wikipedia
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
from api.utils.api_utils import timeout
class WikipediaParam(ToolParamBase):
"""
Define the Wikipedia component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "wikipedia_search",
"description": """A wide range of how-to and information pages are made available in wikipedia. Since 2001, it has grown rapidly to become the world's largest reference website. From Wikipedia, the free encyclopedia.""",
"parameters": {
"query": {
"type": "string",
"description": "The search keyword to execute with wikipedia. The keyword MUST be a specific subject that can match the title.",
"default": "{sys.query}",
"required": True
}
}
}
super().__init__()
self.top_n = 10
self.language = "en"
def check(self):
self.check_positive_integer(self.top_n, "Top N")
self.check_valid_value(self.language, "Wikipedia languages",
['af', 'pl', 'ar', 'ast', 'az', 'bg', 'nan', 'bn', 'be', 'ca', 'cs', 'cy', 'da', 'de',
'et', 'el', 'en', 'es', 'eo', 'eu', 'fa', 'fr', 'gl', 'ko', 'hy', 'hi', 'hr', 'id',
'it', 'he', 'ka', 'lld', 'la', 'lv', 'lt', 'hu', 'mk', 'arz', 'ms', 'min', 'my', 'nl',
'ja', 'nb', 'nn', 'ce', 'uz', 'pt', 'kk', 'ro', 'ru', 'ceb', 'sk', 'sl', 'sr', 'sh',
'fi', 'sv', 'ta', 'tt', 'th', 'tg', 'azb', 'tr', 'uk', 'ur', 'vi', 'war', 'zh', 'yue'])
def get_input_form(self) -> dict[str, dict]:
return {
"query": {
"name": "Query",
"type": "line"
}
}
class Wikipedia(ToolBase, ABC):
component_name = "Wikipedia"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
def _invoke(self, **kwargs):
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
try:
wikipedia.set_lang(self._param.language)
wiki_engine = wikipedia
pages = []
for p in wiki_engine.search(kwargs["query"], results=self._param.top_n):
try:
pages.append(wikipedia.page(p))
except Exception:
pass
self._retrieve_chunks(pages,
get_title=lambda r: r.title,
get_url=lambda r: r.url,
get_content=lambda r: r.summary)
return self.output("formalized_content")
except Exception as e:
last_e = e
logging.exception(f"Wikipedia error: {e}")
time.sleep(self._param.delay_after_error)
if last_e:
self.set_output("_ERROR", str(last_e))
return f"Wikipedia error: {last_e}"
assert False, self.output()

111
agent/tools/yahoofinance.py Normal file
View File

@ -0,0 +1,111 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import os
import time
from abc import ABC
import pandas as pd
import yfinance as yf
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
from api.utils.api_utils import timeout
class YahooFinanceParam(ToolParamBase):
"""
Define the YahooFinance component parameters.
"""
def __init__(self):
self.meta:ToolMeta = {
"name": "yahoo_finance",
"description": "The Yahoo Finance is a service that provides access to real-time and historical stock market data. It enables users to fetch various types of stock information, such as price quotes, historical prices, company profiles, and financial news. The API offers structured data, allowing developers to integrate market data into their applications and analysis tools.",
"parameters": {
"stock_code": {
"type": "string",
"description": "The stock code or company name.",
"default": "{sys.query}",
"required": True
}
}
}
super().__init__()
self.info = True
self.history = False
self.count = False
self.financials = False
self.income_stmt = False
self.balance_sheet = False
self.cash_flow_statement = False
self.news = True
def check(self):
self.check_boolean(self.info, "get all stock info")
self.check_boolean(self.history, "get historical market data")
self.check_boolean(self.count, "show share count")
self.check_boolean(self.financials, "show financials")
self.check_boolean(self.income_stmt, "income statement")
self.check_boolean(self.balance_sheet, "balance sheet")
self.check_boolean(self.cash_flow_statement, "cash flow statement")
self.check_boolean(self.news, "show news")
def get_input_form(self) -> dict[str, dict]:
return {
"stock_code": {
"name": "Stock code/Company name",
"type": "line"
}
}
class YahooFinance(ToolBase, ABC):
component_name = "YahooFinance"
@timeout(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))
def _invoke(self, **kwargs):
if not kwargs.get("stock_code"):
self.set_output("report", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
yohoo_res = []
try:
msft = yf.Ticker(kwargs["stock_code"])
if self._param.info:
yohoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
if self._param.history:
yohoo_res.append("# History:\n" + msft.history().to_markdown() + "\n")
if self._param.financials:
yohoo_res.append("# Calendar:\n" + pd.DataFrame(msft.calendar).to_markdown() + "\n")
if self._param.balance_sheet:
yohoo_res.append("# Balance sheet:\n" + msft.balance_sheet.to_markdown() + "\n")
yohoo_res.append("# Quarterly balance sheet:\n" + msft.quarterly_balance_sheet.to_markdown() + "\n")
if self._param.cash_flow_statement:
yohoo_res.append("# Cash flow statement:\n" + msft.cashflow.to_markdown() + "\n")
yohoo_res.append("# Quarterly cash flow statement:\n" + msft.quarterly_cashflow.to_markdown() + "\n")
if self._param.news:
yohoo_res.append("# News:\n" + pd.DataFrame(msft.news).to_markdown() + "\n")
self.set_output("report", "\n\n".join(yohoo_res))
return self.output("report")
except Exception as e:
last_e = e
logging.exception(f"YahooFinance error: {e}")
time.sleep(self._param.delay_after_error)
if last_e:
self.set_output("_ERROR", str(last_e))
return f"YahooFinance error: {last_e}"
assert False, self.output()