From 92135686929a336b141763f5e1985cc73c7eadcd Mon Sep 17 00:00:00 2001 From: Yongteng Lei Date: Tue, 11 Nov 2025 17:36:48 +0800 Subject: [PATCH] Feat: add mechanism to check cancellation in Agent (#10766) ### What problem does this PR solve? Add mechanism to check cancellation in Agent. ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- agent/canvas.py | 36 +++++++++++++++++++++++++++-- agent/component/agent_with_tools.py | 17 ++++++++++++++ agent/component/base.py | 14 +++++++++++ agent/component/begin.py | 6 +++++ agent/component/categorize.py | 11 +++++++++ agent/component/fillup.py | 7 ++++-- agent/component/invoke.py | 9 ++++++++ agent/component/iteration.py | 3 +++ agent/component/iterationitem.py | 10 +++++++- agent/component/llm.py | 14 ++++++++++- agent/component/message.py | 17 +++++++++++++- agent/component/string_transform.py | 10 ++++++++ agent/component/switch.py | 11 ++++++++- agent/tools/arxiv.py | 16 ++++++++++++- agent/tools/base.py | 5 +++- agent/tools/code_exec.py | 29 +++++++++++++++++++++++ agent/tools/crawler.py | 12 +++++++--- agent/tools/deepl.py | 7 ++++++ agent/tools/duckduckgo.py | 23 ++++++++++++++++++ agent/tools/email.py | 17 +++++++++++++- agent/tools/exesql.py | 23 ++++++++++++++++++ agent/tools/github.py | 13 +++++++++++ agent/tools/google.py | 13 +++++++++++ agent/tools/googlescholar.py | 13 +++++++++++ agent/tools/jin10.py | 20 ++++++++++++++++ agent/tools/pubmed.py | 17 ++++++++++++++ agent/tools/qweather.py | 19 +++++++++++++++ agent/tools/retrieval.py | 15 +++++++++++- agent/tools/searxng.py | 18 +++++++++++++++ agent/tools/tavily.py | 24 +++++++++++++++++++ agent/tools/tushare.py | 12 ++++++++++ agent/tools/wencai.py | 15 ++++++++++++ agent/tools/wikipedia.py | 12 ++++++++++ agent/tools/yahoofinance.py | 12 ++++++++++ api/apps/canvas_app.py | 9 +++++--- rag/svr/task_executor.py | 6 +++-- 36 files changed, 495 insertions(+), 20 deletions(-) diff --git a/agent/canvas.py b/agent/canvas.py index cdb9233c4..72dce5d33 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -26,7 +26,9 @@ from typing import Any, Union, Tuple from agent.component import component_class from agent.component.base import ComponentBase from api.db.services.file_service import FileService +from api.db.services.task_service import has_canceled from common.misc_utils import get_uuid, hash_str2int +from common.exceptions import TaskCanceledException from rag.prompts.generator import chunks_format from rag.utils.redis_conn import REDIS_CONN @@ -126,6 +128,7 @@ class Graph: self.components[k]["obj"].reset() try: REDIS_CONN.delete(f"{self.task_id}-logs") + REDIS_CONN.delete(f"{self.task_id}-cancel") except Exception as e: logging.exception(e) @@ -196,7 +199,7 @@ class Graph: if not rest: return root_val return self.get_variable_param_value(root_val,rest) - + def get_variable_param_value(self, obj: Any, path: str) -> Any: cur = obj if not path: @@ -215,6 +218,17 @@ class Graph: cur = getattr(cur, key, None) return cur + def is_canceled(self) -> bool: + return has_canceled(self.task_id) + + def cancel_task(self) -> bool: + try: + REDIS_CONN.set(f"{self.task_id}-cancel", "x") + except Exception as e: + logging.exception(e) + return False + return True + class Canvas(Graph): @@ -239,7 +253,7 @@ class Canvas(Graph): "sys.conversation_turns": 0, "sys.files": [] } - + self.retrieval = self.dsl["retrieval"] self.memory = self.dsl.get("memory", []) @@ -311,10 +325,20 @@ class Canvas(Graph): self.path.append("begin") self.retrieval.append({"chunks": [], 
"doc_aggs": []}) + if self.is_canceled(): + msg = f"Task {self.task_id} has been canceled before starting." + logging.info(msg) + raise TaskCanceledException(msg) + yield decorate("workflow_started", {"inputs": kwargs.get("inputs")}) self.retrieval.append({"chunks": {}, "doc_aggs": {}}) def _run_batch(f, t): + if self.is_canceled(): + msg = f"Task {self.task_id} has been canceled during batch execution." + logging.info(msg) + raise TaskCanceledException(msg) + with ThreadPoolExecutor(max_workers=5) as executor: thr = [] i = f @@ -473,6 +497,14 @@ class Canvas(Graph): "created_at": st, }) self.history.append(("assistant", self.get_component_obj(self.path[-1]).output())) + elif "Task has been canceled" in self.error: + yield decorate("workflow_finished", + { + "inputs": kwargs.get("inputs"), + "outputs": "Task has been canceled", + "elapsed_time": time.perf_counter() - st, + "created_at": st, + }) def is_reff(self, exp: str) -> bool: exp = exp.strip("{").strip("}") diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index 53c2ed9ad..98dfbc92f 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -139,6 +139,9 @@ class Agent(LLM, ToolBase): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 20*60))) def _invoke(self, **kwargs): + if self.check_if_canceled("Agent processing"): + return + if kwargs.get("user_prompt"): usr_pmt = "" if kwargs.get("reasoning"): @@ -152,6 +155,8 @@ class Agent(LLM, ToolBase): self._param.prompts = [{"role": "user", "content": usr_pmt}] if not self.tools: + if self.check_if_canceled("Agent processing"): + return return LLM._invoke(self, **kwargs) prompt, msg, user_defined_prompt = self._prepare_prompt_variables() @@ -171,6 +176,8 @@ class Agent(LLM, ToolBase): use_tools = [] ans = "" for delta_ans, tk in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt): + if self.check_if_canceled("Agent processing"): + return ans += delta_ans if ans.find("**ERROR**") >= 0: @@ -191,12 +198,16 @@ class Agent(LLM, ToolBase): answer_without_toolcall = "" use_tools = [] for delta_ans,_ in self._react_with_tools_streamly(prompt, msg, use_tools, user_defined_prompt): + if self.check_if_canceled("Agent streaming"): + return + if delta_ans.find("**ERROR**") >= 0: if self.get_exception_default_value(): self.set_output("content", self.get_exception_default_value()) yield self.get_exception_default_value() else: self.set_output("_ERROR", delta_ans) + return answer_without_toolcall += delta_ans yield delta_ans @@ -271,6 +282,8 @@ class Agent(LLM, ToolBase): st = timer() txt = "" for delta_ans in self._gen_citations(entire_txt): + if self.check_if_canceled("Agent streaming"): + return yield delta_ans, 0 txt += delta_ans @@ -286,6 +299,8 @@ class Agent(LLM, ToolBase): task_desc = analyze_task(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt) self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st) for _ in range(self._param.max_rounds + 1): + if self.check_if_canceled("Agent streaming"): + return response, tk = next_step(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt) # self.callback("next_step", {}, str(response)[:256]+"...") token_count += tk @@ -333,6 +348,8 @@ Instructions: 6. Focus on delivering VALUE with the information already gathered Respond immediately with your final comprehensive answer. 
""" + if self.check_if_canceled("Agent final instruction"): + return append_user_content(hist, final_instruction) for txt, tkcnt in complete(): diff --git a/agent/component/base.py b/agent/component/base.py index cfec4ac1f..31ad46820 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -417,6 +417,20 @@ class ComponentBase(ABC): self._param = param self._param.check() + def is_canceled(self) -> bool: + return self._canvas.is_canceled() + + def check_if_canceled(self, message: str = "") -> bool: + if self.is_canceled(): + task_id = getattr(self._canvas, 'task_id', 'unknown') + log_message = f"Task {task_id} has been canceled" + if message: + log_message += f" during {message}" + logging.info(log_message) + self.set_output("_ERROR", "Task has been canceled") + return True + return False + def invoke(self, **kwargs) -> dict[str, Any]: self.set_output("_created_time", time.perf_counter()) try: diff --git a/agent/component/begin.py b/agent/component/begin.py index 159f0f5d7..b5985bb7a 100644 --- a/agent/component/begin.py +++ b/agent/component/begin.py @@ -37,7 +37,13 @@ class Begin(UserFillUp): component_name = "Begin" def _invoke(self, **kwargs): + if self.check_if_canceled("Begin processing"): + return + for k, v in kwargs.get("inputs", {}).items(): + if self.check_if_canceled("Begin processing"): + return + if isinstance(v, dict) and v.get("type", "").lower().find("file") >=0: if v.get("optional") and v.get("value", None) is None: v = None diff --git a/agent/component/categorize.py b/agent/component/categorize.py index 3534225d9..1333889bb 100644 --- a/agent/component/categorize.py +++ b/agent/component/categorize.py @@ -98,6 +98,9 @@ class Categorize(LLM, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) def _invoke(self, **kwargs): + if self.check_if_canceled("Categorize processing"): + return + msg = self._canvas.get_history(self._param.message_history_window_size) if not msg: msg = [{"role": "user", "content": ""}] @@ -114,10 +117,18 @@ class Categorize(LLM, ABC): ---- Real Data ---- {} → """.format(" | ".join(["{}: \"{}\"".format(c["role"].upper(), re.sub(r"\n", "", c["content"], flags=re.DOTALL)) for c in msg])) + + if self.check_if_canceled("Categorize processing"): + return + ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf()) logging.info(f"input: {user_prompt}, answer: {str(ans)}") if ERROR_PREFIX in ans: raise Exception(ans) + + if self.check_if_canceled("Categorize processing"): + return + # Count the number of times each category appears in the answer. category_counts = {} for c in self._param.category_description.keys(): diff --git a/agent/component/fillup.py b/agent/component/fillup.py index 7d27280c5..7428912d4 100644 --- a/agent/component/fillup.py +++ b/agent/component/fillup.py @@ -35,6 +35,9 @@ class UserFillUp(ComponentBase): component_name = "UserFillUp" def _invoke(self, **kwargs): + if self.check_if_canceled("UserFillUp processing"): + return + if self._param.enable_tips: content = self._param.tips for k, v in self.get_input_elements_from_text(self._param.tips).items(): @@ -58,9 +61,9 @@ class UserFillUp(ComponentBase): self.set_output("tips", content) for k, v in kwargs.get("inputs", {}).items(): + if self.check_if_canceled("UserFillUp processing"): + return self.set_output(k, v) def thoughts(self) -> str: return "Waiting for your input..." 
- - diff --git a/agent/component/invoke.py b/agent/component/invoke.py index 00a39b905..61ebe2b39 100644 --- a/agent/component/invoke.py +++ b/agent/component/invoke.py @@ -56,6 +56,9 @@ class Invoke(ComponentBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))) def _invoke(self, **kwargs): + if self.check_if_canceled("Invoke processing"): + return + args = {} for para in self._param.variables: if para.get("value"): @@ -89,6 +92,9 @@ class Invoke(ComponentBase, ABC): last_e = "" for _ in range(self._param.max_retries + 1): + if self.check_if_canceled("Invoke processing"): + return + try: if method == "get": response = requests.get(url=url, params=args, headers=headers, proxies=proxies, timeout=self._param.timeout) @@ -121,6 +127,9 @@ class Invoke(ComponentBase, ABC): return self.output("result") except Exception as e: + if self.check_if_canceled("Invoke processing"): + return + last_e = e logging.exception(f"Http request error: {e}") time.sleep(self._param.delay_after_error) diff --git a/agent/component/iteration.py b/agent/component/iteration.py index a6065a281..a39147d8f 100644 --- a/agent/component/iteration.py +++ b/agent/component/iteration.py @@ -56,6 +56,9 @@ class Iteration(ComponentBase, ABC): return cid def _invoke(self, **kwargs): + if self.check_if_canceled("Iteration processing"): + return + arr = self._canvas.get_variable_value(self._param.items_ref) if not isinstance(arr, list): self.set_output("_ERROR", self._param.items_ref + " must be an array, but its type is "+str(type(arr))) diff --git a/agent/component/iterationitem.py b/agent/component/iterationitem.py index 6c4d0bae7..83713aedb 100644 --- a/agent/component/iterationitem.py +++ b/agent/component/iterationitem.py @@ -33,6 +33,9 @@ class IterationItem(ComponentBase, ABC): self._idx = 0 def _invoke(self, **kwargs): + if self.check_if_canceled("IterationItem processing"): + return + parent = self.get_parent() arr = self._canvas.get_variable_value(parent._param.items_ref) if not isinstance(arr, list): @@ -40,12 +43,17 @@ class IterationItem(ComponentBase, ABC): raise Exception(parent._param.items_ref + " must be an array, but its type is "+str(type(arr))) if self._idx > 0: + if self.check_if_canceled("IterationItem processing"): + return self.output_collation() if self._idx >= len(arr): self._idx = -1 return + if self.check_if_canceled("IterationItem processing"): + return + self.set_output("item", arr[self._idx]) self.set_output("index", self._idx) @@ -80,4 +88,4 @@ class IterationItem(ComponentBase, ABC): return self._idx == -1 def thoughts(self) -> str: - return "Next turn..." \ No newline at end of file + return "Next turn..." 
diff --git a/agent/component/llm.py b/agent/component/llm.py index c8383835b..6ce0f65a5 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -207,6 +207,9 @@ class LLM(ComponentBase): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) def _invoke(self, **kwargs): + if self.check_if_canceled("LLM processing"): + return + def clean_formated_answer(ans: str) -> str: ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) ans = re.sub(r"^.*```json", "", ans, flags=re.DOTALL) @@ -223,6 +226,9 @@ class LLM(ComponentBase): schema=json.dumps(output_structure, ensure_ascii=False, indent=2) prompt += structured_output_prompt(schema) for _ in range(self._param.max_retries+1): + if self.check_if_canceled("LLM processing"): + return + _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) error = "" ans = self._generate(msg) @@ -248,6 +254,9 @@ class LLM(ComponentBase): return for _ in range(self._param.max_retries+1): + if self.check_if_canceled("LLM processing"): + return + _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) error = "" ans = self._generate(msg) @@ -269,6 +278,9 @@ class LLM(ComponentBase): _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) answer = "" for ans in self._generate_streamly(msg): + if self.check_if_canceled("LLM streaming"): + return + if ans.find("**ERROR**") >= 0: if self.get_exception_default_value(): self.set_output("content", self.get_exception_default_value()) @@ -287,4 +299,4 @@ class LLM(ComponentBase): def thoughts(self) -> str: _, msg,_ = self._prepare_prompt_variables() - return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nI’ll figure out our best next move." \ No newline at end of file + return "⌛Give me a moment—starting from: \n\n" + re.sub(r"(User's query:|[\\]+)", '', msg[-1]['content'], flags=re.DOTALL) + "\n\nI’ll figure out our best next move." 
diff --git a/agent/component/message.py b/agent/component/message.py index 1a506325e..641198083 100644 --- a/agent/component/message.py +++ b/agent/component/message.py @@ -89,6 +89,9 @@ class Message(ComponentBase): all_content = "" cache = {} for r in re.finditer(self.variable_ref_patt, rand_cnt, flags=re.DOTALL): + if self.check_if_canceled("Message streaming"): + return + all_content += rand_cnt[s: r.start()] yield rand_cnt[s: r.start()] s = r.end() @@ -104,6 +107,9 @@ class Message(ComponentBase): if isinstance(v, partial): cnt = "" for t in v(): + if self.check_if_canceled("Message streaming"): + return + all_content += t cnt += t yield t @@ -111,7 +117,7 @@ class Message(ComponentBase): continue elif not isinstance(v, str): try: - v = json.dumps(v, ensure_ascii=False, indent=2) + v = json.dumps(v, ensure_ascii=False) except Exception: v = str(v) yield v @@ -120,6 +126,9 @@ class Message(ComponentBase): cache[exp] = v if s < len(rand_cnt): + if self.check_if_canceled("Message streaming"): + return + all_content += rand_cnt[s: ] yield rand_cnt[s: ] @@ -133,6 +142,9 @@ class Message(ComponentBase): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) def _invoke(self, **kwargs): + if self.check_if_canceled("Message processing"): + return + rand_cnt = random.choice(self._param.content) if self._param.stream and not self._is_jinjia2(rand_cnt): self.set_output("content", partial(self._stream, rand_cnt)) @@ -145,6 +157,9 @@ class Message(ComponentBase): except Exception: pass + if self.check_if_canceled("Message processing"): + return + for n, v in kwargs.items(): content = re.sub(n, v, content) diff --git a/agent/component/string_transform.py b/agent/component/string_transform.py index 08e44d2e0..444161f72 100644 --- a/agent/component/string_transform.py +++ b/agent/component/string_transform.py @@ -63,17 +63,24 @@ class StringTransform(Message, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) def _invoke(self, **kwargs): + if self.check_if_canceled("StringTransform processing"): + return + if self._param.method == "split": self._split(kwargs.get("line")) else: self._merge(kwargs) def _split(self, line:str|None = None): + if self.check_if_canceled("StringTransform split processing"): + return + var = self._canvas.get_variable_value(self._param.split_ref) if not line else line if not var: var = "" assert isinstance(var, str), "The input variable is not a string: {}".format(type(var)) self.set_input_value(self._param.split_ref, var) + res = [] for i,s in enumerate(re.split(r"(%s)"%("|".join([re.escape(d) for d in self._param.delimiters])), var, flags=re.DOTALL)): if i % 2 == 1: @@ -82,6 +89,9 @@ class StringTransform(Message, ABC): self.set_output("result", res) def _merge(self, kwargs:dict[str, str] = {}): + if self.check_if_canceled("StringTransform merge processing"): + return + script = self._param.script script, kwargs = self.get_kwargs(script, kwargs, self._param.delimiters[0]) diff --git a/agent/component/switch.py b/agent/component/switch.py index 41c25c32f..85e6cd03b 100644 --- a/agent/component/switch.py +++ b/agent/component/switch.py @@ -63,9 +63,18 @@ class Switch(ComponentBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3))) def _invoke(self, **kwargs): + if self.check_if_canceled("Switch processing"): + return + for cond in self._param.conditions: + if self.check_if_canceled("Switch processing"): + return + res = [] for item in cond["items"]: + if self.check_if_canceled("Switch processing"): + return + if not item["cpn_id"]: 
continue cpn_v = self._canvas.get_variable_value(item["cpn_id"]) @@ -128,4 +137,4 @@ class Switch(ComponentBase, ABC): raise ValueError('Not supported operator' + operator) def thoughts(self) -> str: - return "I’m weighing a few options and will pick the next step shortly." \ No newline at end of file + return "I’m weighing a few options and will pick the next step shortly." diff --git a/agent/tools/arxiv.py b/agent/tools/arxiv.py index 74a810c74..10d502c56 100644 --- a/agent/tools/arxiv.py +++ b/agent/tools/arxiv.py @@ -63,12 +63,18 @@ class ArXiv(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) def _invoke(self, **kwargs): + if self.check_if_canceled("ArXiv processing"): + return + if not kwargs.get("query"): self.set_output("formalized_content", "") return "" last_e = "" for _ in range(self._param.max_retries+1): + if self.check_if_canceled("ArXiv processing"): + return + try: sort_choices = {"relevance": arxiv.SortCriterion.Relevance, "lastUpdatedDate": arxiv.SortCriterion.LastUpdatedDate, @@ -79,12 +85,20 @@ class ArXiv(ToolBase, ABC): max_results=self._param.top_n, sort_by=sort_choices[self._param.sort_by] ) - self._retrieve_chunks(list(arxiv_client.results(search)), + results = list(arxiv_client.results(search)) + + if self.check_if_canceled("ArXiv processing"): + return + + self._retrieve_chunks(results, get_title=lambda r: r.title, get_url=lambda r: r.pdf_url, get_content=lambda r: r.summary) return self.output("formalized_content") except Exception as e: + if self.check_if_canceled("ArXiv processing"): + return + last_e = e logging.exception(f"ArXiv error: {e}") time.sleep(self._param.delay_after_error) diff --git a/agent/tools/base.py b/agent/tools/base.py index 93bde20aa..a3d569694 100644 --- a/agent/tools/base.py +++ b/agent/tools/base.py @@ -125,6 +125,9 @@ class ToolBase(ComponentBase): return self._param.get_meta() def invoke(self, **kwargs): + if self.check_if_canceled("Tool processing"): + return + self.set_output("_created_time", time.perf_counter()) try: res = self._invoke(**kwargs) @@ -170,4 +173,4 @@ class ToolBase(ComponentBase): self.set_output("formalized_content", "\n".join(kb_prompt({"chunks": chunks, "doc_aggs": aggs}, 200000, True))) def thoughts(self) -> str: - return self._canvas.get_component_name(self._id) + " is running..." \ No newline at end of file + return self._canvas.get_component_name(self._id) + " is running..." 
diff --git a/agent/tools/code_exec.py b/agent/tools/code_exec.py index 7145d8b89..adba4168e 100644 --- a/agent/tools/code_exec.py +++ b/agent/tools/code_exec.py @@ -131,10 +131,14 @@ class CodeExec(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) def _invoke(self, **kwargs): + if self.check_if_canceled("CodeExec processing"): + return + lang = kwargs.get("lang", self._param.lang) script = kwargs.get("script", self._param.script) arguments = {} for k, v in self._param.arguments.items(): + if kwargs.get(k): arguments[k] = kwargs[k] continue @@ -149,15 +153,28 @@ class CodeExec(ToolBase, ABC): def _execute_code(self, language: str, code: str, arguments: dict): import requests + if self.check_if_canceled("CodeExec execution"): + return + try: code_b64 = self._encode_code(code) code_req = CodeExecutionRequest(code_b64=code_b64, language=language, arguments=arguments).model_dump() except Exception as e: + if self.check_if_canceled("CodeExec execution"): + return + self.set_output("_ERROR", "construct code request error: " + str(e)) try: + if self.check_if_canceled("CodeExec execution"): + return "Task has been canceled" + resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:") + + if self.check_if_canceled("CodeExec execution"): + return "Task has been canceled" + if resp.status_code != 200: resp.raise_for_status() body = resp.json() @@ -173,16 +190,25 @@ class CodeExec(ToolBase, ABC): logging.info(f"http://{settings.SANDBOX_HOST}:9385/run -> {rt}") if isinstance(rt, tuple): for i, (k, o) in enumerate(self._param.outputs.items()): + if self.check_if_canceled("CodeExec execution"): + return + if k.find("_") == 0: continue o["value"] = rt[i] elif isinstance(rt, dict): for i, (k, o) in enumerate(self._param.outputs.items()): + if self.check_if_canceled("CodeExec execution"): + return + if k not in rt or k.find("_") == 0: continue o["value"] = rt[k] else: for i, (k, o) in enumerate(self._param.outputs.items()): + if self.check_if_canceled("CodeExec execution"): + return + if k.find("_") == 0: continue o["value"] = rt @@ -190,6 +216,9 @@ class CodeExec(ToolBase, ABC): self.set_output("_ERROR", "There is no response from sandbox") except Exception as e: + if self.check_if_canceled("CodeExec execution"): + return + self.set_output("_ERROR", "Exception executing code: " + str(e)) return self.output() diff --git a/agent/tools/crawler.py b/agent/tools/crawler.py index 869fae4a3..e4d049e1b 100644 --- a/agent/tools/crawler.py +++ b/agent/tools/crawler.py @@ -29,7 +29,7 @@ class CrawlerParam(ToolParamBase): super().__init__() self.proxy = None self.extract_type = "markdown" - + def check(self): self.check_valid_value(self.extract_type, "Type of content from the crawler", ['html', 'markdown', 'content']) @@ -47,18 +47,24 @@ class Crawler(ToolBase, ABC): result = asyncio.run(self.get_web(ans)) return Crawler.be_output(result) - + except Exception as e: return Crawler.be_output(f"An unexpected error occurred: {str(e)}") async def get_web(self, url): + if self.check_if_canceled("Crawler async operation"): + return + proxy = self._param.proxy if self._param.proxy else None async with AsyncWebCrawler(verbose=True, proxy=proxy) as crawler: result = await crawler.arun( url=url, bypass_cache=True ) - + + if self.check_if_canceled("Crawler async operation"): + return + if 
self._param.extract_type == 'html': return result.cleaned_html elif self._param.extract_type == 'markdown': diff --git a/agent/tools/deepl.py b/agent/tools/deepl.py index 41d12341d..dc331aafe 100644 --- a/agent/tools/deepl.py +++ b/agent/tools/deepl.py @@ -46,11 +46,16 @@ class DeepL(ComponentBase, ABC): component_name = "DeepL" def _run(self, history, **kwargs): + if self.check_if_canceled("DeepL processing"): + return ans = self.get_input() ans = " - ".join(ans["content"]) if "content" in ans else "" if not ans: return DeepL.be_output("") + if self.check_if_canceled("DeepL processing"): + return + try: translator = deepl.Translator(self._param.auth_key) result = translator.translate_text(ans, source_lang=self._param.source_lang, @@ -58,4 +63,6 @@ class DeepL(ComponentBase, ABC): return DeepL.be_output(result.text) except Exception as e: + if self.check_if_canceled("DeepL processing"): + return DeepL.be_output("**Error**:" + str(e)) diff --git a/agent/tools/duckduckgo.py b/agent/tools/duckduckgo.py index fcf5ee077..fd2ec1801 100644 --- a/agent/tools/duckduckgo.py +++ b/agent/tools/duckduckgo.py @@ -75,17 +75,30 @@ class DuckDuckGo(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) def _invoke(self, **kwargs): + if self.check_if_canceled("DuckDuckGo processing"): + return + if not kwargs.get("query"): self.set_output("formalized_content", "") return "" last_e = "" for _ in range(self._param.max_retries+1): + if self.check_if_canceled("DuckDuckGo processing"): + return + try: if kwargs.get("topic", "general") == "general": with DDGS() as ddgs: + if self.check_if_canceled("DuckDuckGo processing"): + return + # {'title': '', 'href': '', 'body': ''} duck_res = ddgs.text(kwargs["query"], max_results=self._param.top_n) + + if self.check_if_canceled("DuckDuckGo processing"): + return + self._retrieve_chunks(duck_res, get_title=lambda r: r["title"], get_url=lambda r: r.get("href", r.get("url")), @@ -94,8 +107,15 @@ class DuckDuckGo(ToolBase, ABC): return self.output("formalized_content") else: with DDGS() as ddgs: + if self.check_if_canceled("DuckDuckGo processing"): + return + # {'date': '', 'title': '', 'body': '', 'url': '', 'image': '', 'source': ''} duck_res = ddgs.news(kwargs["query"], max_results=self._param.top_n) + + if self.check_if_canceled("DuckDuckGo processing"): + return + self._retrieve_chunks(duck_res, get_title=lambda r: r["title"], get_url=lambda r: r.get("href", r.get("url")), @@ -103,6 +123,9 @@ class DuckDuckGo(ToolBase, ABC): self.set_output("json", duck_res) return self.output("formalized_content") except Exception as e: + if self.check_if_canceled("DuckDuckGo processing"): + return + last_e = e logging.exception(f"DuckDuckGo error: {e}") time.sleep(self._param.delay_after_error) diff --git a/agent/tools/email.py b/agent/tools/email.py index 42d3e2878..e19fd69c6 100644 --- a/agent/tools/email.py +++ b/agent/tools/email.py @@ -101,19 +101,27 @@ class Email(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))) def _invoke(self, **kwargs): + if self.check_if_canceled("Email processing"): + return + if not kwargs.get("to_email"): self.set_output("success", False) return "" last_e = "" for _ in range(self._param.max_retries+1): + if self.check_if_canceled("Email processing"): + return + try: # Parse JSON string passed from upstream email_data = kwargs # Validate required fields if "to_email" not in email_data: - return Email.be_output("Missing required field: to_email") + self.set_output("_ERROR", "Missing required field: to_email") 
+ self.set_output("success", False) + return False # Create email object msg = MIMEMultipart('alternative') @@ -133,6 +141,9 @@ class Email(ToolBase, ABC): # Connect to SMTP server and send logging.info(f"Connecting to SMTP server {self._param.smtp_server}:{self._param.smtp_port}") + if self.check_if_canceled("Email processing"): + return + context = smtplib.ssl.create_default_context() with smtplib.SMTP(self._param.smtp_server, self._param.smtp_port) as server: server.ehlo() @@ -149,6 +160,10 @@ class Email(ToolBase, ABC): # Send email logging.info(f"Sending email to recipients: {recipients}") + + if self.check_if_canceled("Email processing"): + return + try: server.send_message(msg, self._param.email, recipients) success = True diff --git a/agent/tools/exesql.py b/agent/tools/exesql.py index b5917e730..012b00d84 100644 --- a/agent/tools/exesql.py +++ b/agent/tools/exesql.py @@ -81,6 +81,8 @@ class ExeSQL(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))) def _invoke(self, **kwargs): + if self.check_if_canceled("ExeSQL processing"): + return def convert_decimals(obj): from decimal import Decimal @@ -96,6 +98,9 @@ class ExeSQL(ToolBase, ABC): if not sql: raise Exception("SQL for `ExeSQL` MUST not be empty.") + if self.check_if_canceled("ExeSQL processing"): + return + vars = self.get_input_elements_from_text(sql) args = {} for k, o in vars.items(): @@ -108,6 +113,9 @@ class ExeSQL(ToolBase, ABC): self.set_input_value(k, args[k]) sql = self.string_format(sql, args) + if self.check_if_canceled("ExeSQL processing"): + return + sqls = sql.split(";") if self._param.db_type in ["mysql", "mariadb"]: db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host, @@ -181,6 +189,10 @@ class ExeSQL(ToolBase, ABC): sql_res = [] formalized_content = [] for single_sql in sqls: + if self.check_if_canceled("ExeSQL processing"): + ibm_db.close(conn) + return + single_sql = single_sql.replace("```", "").strip() if not single_sql: continue @@ -190,6 +202,9 @@ class ExeSQL(ToolBase, ABC): rows = [] row = ibm_db.fetch_assoc(stmt) while row and len(rows) < self._param.max_records: + if self.check_if_canceled("ExeSQL processing"): + ibm_db.close(conn) + return rows.append(row) row = ibm_db.fetch_assoc(stmt) @@ -220,6 +235,11 @@ class ExeSQL(ToolBase, ABC): sql_res = [] formalized_content = [] for single_sql in sqls: + if self.check_if_canceled("ExeSQL processing"): + cursor.close() + db.close() + return + single_sql = single_sql.replace('```','') if not single_sql: continue @@ -244,6 +264,9 @@ class ExeSQL(ToolBase, ABC): sql_res.append(convert_decimals(single_res.to_dict(orient='records'))) formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f")) + cursor.close() + db.close() + self.set_output("json", sql_res) self.set_output("formalized_content", "\n\n".join(formalized_content)) return self.output("formalized_content") diff --git a/agent/tools/github.py b/agent/tools/github.py index 7b53f0b0b..f48ab0a2d 100644 --- a/agent/tools/github.py +++ b/agent/tools/github.py @@ -59,17 +59,27 @@ class GitHub(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) def _invoke(self, **kwargs): + if self.check_if_canceled("GitHub processing"): + return + if not kwargs.get("query"): self.set_output("formalized_content", "") return "" last_e = "" for _ in range(self._param.max_retries+1): + if self.check_if_canceled("GitHub processing"): + return + try: url = 'https://api.github.com/search/repositories?q=' + kwargs["query"] + 
'&sort=stars&order=desc&per_page=' + str( self._param.top_n) headers = {"Content-Type": "application/vnd.github+json", "X-GitHub-Api-Version": '2022-11-28'} response = requests.get(url=url, headers=headers).json() + + if self.check_if_canceled("GitHub processing"): + return + self._retrieve_chunks(response['items'], get_title=lambda r: r["name"], get_url=lambda r: r["html_url"], @@ -77,6 +87,9 @@ class GitHub(ToolBase, ABC): self.set_output("json", response['items']) return self.output("formalized_content") except Exception as e: + if self.check_if_canceled("GitHub processing"): + return + last_e = e logging.exception(f"GitHub error: {e}") time.sleep(self._param.delay_after_error) diff --git a/agent/tools/google.py b/agent/tools/google.py index 3184aaaeb..312b5a1fe 100644 --- a/agent/tools/google.py +++ b/agent/tools/google.py @@ -118,6 +118,9 @@ class Google(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) def _invoke(self, **kwargs): + if self.check_if_canceled("Google processing"): + return + if not kwargs.get("q"): self.set_output("formalized_content", "") return "" @@ -132,8 +135,15 @@ class Google(ToolBase, ABC): } last_e = "" for _ in range(self._param.max_retries+1): + if self.check_if_canceled("Google processing"): + return + try: search = GoogleSearch(params).get_dict() + + if self.check_if_canceled("Google processing"): + return + self._retrieve_chunks(search["organic_results"], get_title=lambda r: r["title"], get_url=lambda r: r["link"], @@ -142,6 +152,9 @@ class Google(ToolBase, ABC): self.set_output("json", search["organic_results"]) return self.output("formalized_content") except Exception as e: + if self.check_if_canceled("Google processing"): + return + last_e = e logging.exception(f"Google error: {e}") time.sleep(self._param.delay_after_error) diff --git a/agent/tools/googlescholar.py b/agent/tools/googlescholar.py index da7f6ef69..b5c4eb395 100644 --- a/agent/tools/googlescholar.py +++ b/agent/tools/googlescholar.py @@ -65,15 +65,25 @@ class GoogleScholar(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) def _invoke(self, **kwargs): + if self.check_if_canceled("GoogleScholar processing"): + return + if not kwargs.get("query"): self.set_output("formalized_content", "") return "" last_e = "" for _ in range(self._param.max_retries+1): + if self.check_if_canceled("GoogleScholar processing"): + return + try: scholar_client = scholarly.search_pubs(kwargs["query"], patents=self._param.patents, year_low=self._param.year_low, year_high=self._param.year_high, sort_by=self._param.sort_by) + + if self.check_if_canceled("GoogleScholar processing"): + return + self._retrieve_chunks(scholar_client, get_title=lambda r: r['bib']['title'], get_url=lambda r: r["pub_url"], @@ -82,6 +92,9 @@ class GoogleScholar(ToolBase, ABC): self.set_output("json", list(scholar_client)) return self.output("formalized_content") except Exception as e: + if self.check_if_canceled("GoogleScholar processing"): + return + last_e = e logging.exception(f"GoogleScholar error: {e}") time.sleep(self._param.delay_after_error) diff --git a/agent/tools/jin10.py b/agent/tools/jin10.py index 583a18286..b477dba81 100644 --- a/agent/tools/jin10.py +++ b/agent/tools/jin10.py @@ -50,6 +50,9 @@ class Jin10(ComponentBase, ABC): component_name = "Jin10" def _run(self, history, **kwargs): + if self.check_if_canceled("Jin10 processing"): + return + ans = self.get_input() ans = " - ".join(ans["content"]) if "content" in ans else "" if not ans: @@ -58,6 +61,9 @@ class 
Jin10(ComponentBase, ABC): jin10_res = [] headers = {'secret-key': self._param.secret_key} try: + if self.check_if_canceled("Jin10 processing"): + return + if self._param.type == "flash": params = { 'category': self._param.flash_type, @@ -69,6 +75,8 @@ class Jin10(ComponentBase, ABC): headers=headers, data=json.dumps(params)) response = response.json() for i in response['data']: + if self.check_if_canceled("Jin10 processing"): + return jin10_res.append({"content": i['data']['content']}) if self._param.type == "calendar": params = { @@ -79,6 +87,8 @@ class Jin10(ComponentBase, ABC): headers=headers, data=json.dumps(params)) response = response.json() + if self.check_if_canceled("Jin10 processing"): + return jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()}) if self._param.type == "symbols": params = { @@ -90,8 +100,12 @@ class Jin10(ComponentBase, ABC): url='https://open-data-api.jin10.com/data-api/' + self._param.symbols_datatype + '?type=' + self._param.symbols_type, headers=headers, data=json.dumps(params)) response = response.json() + if self.check_if_canceled("Jin10 processing"): + return if self._param.symbols_datatype == "symbols": for i in response['data']: + if self.check_if_canceled("Jin10 processing"): + return i['Commodity Code'] = i['c'] i['Stock Exchange'] = i['e'] i['Commodity Name'] = i['n'] @@ -99,6 +113,8 @@ class Jin10(ComponentBase, ABC): del i['c'], i['e'], i['n'], i['t'] if self._param.symbols_datatype == "quotes": for i in response['data']: + if self.check_if_canceled("Jin10 processing"): + return i['Selling Price'] = i['a'] i['Buying Price'] = i['b'] i['Commodity Code'] = i['c'] @@ -120,8 +136,12 @@ class Jin10(ComponentBase, ABC): url='https://open-data-api.jin10.com/data-api/news', headers=headers, data=json.dumps(params)) response = response.json() + if self.check_if_canceled("Jin10 processing"): + return jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()}) except Exception as e: + if self.check_if_canceled("Jin10 processing"): + return return Jin10.be_output("**ERROR**: " + str(e)) if not jin10_res: diff --git a/agent/tools/pubmed.py b/agent/tools/pubmed.py index afa171768..05c222810 100644 --- a/agent/tools/pubmed.py +++ b/agent/tools/pubmed.py @@ -71,23 +71,40 @@ class PubMed(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) def _invoke(self, **kwargs): + if self.check_if_canceled("PubMed processing"): + return + if not kwargs.get("query"): self.set_output("formalized_content", "") return "" last_e = "" for _ in range(self._param.max_retries+1): + if self.check_if_canceled("PubMed processing"): + return + try: Entrez.email = self._param.email pubmedids = Entrez.read(Entrez.esearch(db='pubmed', retmax=self._param.top_n, term=kwargs["query"]))['IdList'] + + if self.check_if_canceled("PubMed processing"): + return + pubmedcnt = ET.fromstring(re.sub(r'<(/?)b>|<(/?)i>', '', Entrez.efetch(db='pubmed', id=",".join(pubmedids), retmode="xml").read().decode("utf-8"))) + + if self.check_if_canceled("PubMed processing"): + return + self._retrieve_chunks(pubmedcnt.findall("PubmedArticle"), get_title=lambda child: child.find("MedlineCitation").find("Article").find("ArticleTitle").text, get_url=lambda child: "https://pubmed.ncbi.nlm.nih.gov/" + child.find("MedlineCitation").find("PMID").text, get_content=lambda child: self._format_pubmed_content(child),) return self.output("formalized_content") except Exception as e: + if self.check_if_canceled("PubMed processing"): + return + last_e = e 
logging.exception(f"PubMed error: {e}") time.sleep(self._param.delay_after_error) diff --git a/agent/tools/qweather.py b/agent/tools/qweather.py index 2c38a8b7e..a597c2c5b 100644 --- a/agent/tools/qweather.py +++ b/agent/tools/qweather.py @@ -58,12 +58,18 @@ class QWeather(ComponentBase, ABC): component_name = "QWeather" def _run(self, history, **kwargs): + if self.check_if_canceled("Qweather processing"): + return + ans = self.get_input() ans = "".join(ans["content"]) if "content" in ans else "" if not ans: return QWeather.be_output("") try: + if self.check_if_canceled("Qweather processing"): + return + response = requests.get( url="https://geoapi.qweather.com/v2/city/lookup?location=" + ans + "&key=" + self._param.web_apikey).json() if response["code"] == "200": @@ -71,16 +77,23 @@ class QWeather(ComponentBase, ABC): else: return QWeather.be_output("**Error**" + self._param.error_code[response["code"]]) + if self.check_if_canceled("Qweather processing"): + return + base_url = "https://api.qweather.com/v7/" if self._param.user_type == 'paid' else "https://devapi.qweather.com/v7/" if self._param.type == "weather": url = base_url + "weather/" + self._param.time_period + "?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang response = requests.get(url=url).json() + if self.check_if_canceled("Qweather processing"): + return if response["code"] == "200": if self._param.time_period == "now": return QWeather.be_output(str(response["now"])) else: qweather_res = [{"content": str(i) + "\n"} for i in response["daily"]] + if self.check_if_canceled("Qweather processing"): + return if not qweather_res: return QWeather.be_output("") @@ -92,6 +105,8 @@ class QWeather(ComponentBase, ABC): elif self._param.type == "indices": url = base_url + "indices/1d?type=0&location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang response = requests.get(url=url).json() + if self.check_if_canceled("Qweather processing"): + return if response["code"] == "200": indices_res = response["daily"][0]["date"] + "\n" + "\n".join( [i["name"] + ": " + i["category"] + ", " + i["text"] for i in response["daily"]]) @@ -103,9 +118,13 @@ class QWeather(ComponentBase, ABC): elif self._param.type == "airquality": url = base_url + "air/now?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang response = requests.get(url=url).json() + if self.check_if_canceled("Qweather processing"): + return if response["code"] == "200": return QWeather.be_output(str(response["now"])) else: return QWeather.be_output("**Error**" + self._param.error_code[response["code"]]) except Exception as e: + if self.check_if_canceled("Qweather processing"): + return return QWeather.be_output("**Error**" + str(e)) diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 77b6145ed..ab388a08e 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -82,8 +82,12 @@ class Retrieval(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) def _invoke(self, **kwargs): + if self.check_if_canceled("Retrieval processing"): + return + if not kwargs.get("query"): self.set_output("formalized_content", self._param.empty_response) + return kb_ids: list[str] = [] for id in self._param.kb_ids: @@ -122,7 +126,7 @@ class Retrieval(ToolBase, ABC): vars = self.get_input_elements_from_text(kwargs["query"]) vars = {k:o["value"] for k,o in vars.items()} query = self.string_format(kwargs["query"], vars) - + doc_ids=[] if 
self._param.meta_data_filter!={}: metas = DocumentService.get_meta_by_kbs(kb_ids) @@ -184,9 +188,14 @@ class Retrieval(ToolBase, ABC): rerank_mdl=rerank_mdl, rank_feature=label_question(query, kbs), ) + if self.check_if_canceled("Retrieval processing"): + return + if self._param.toc_enhance: chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT) cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n) + if self.check_if_canceled("Retrieval processing"): + return if cks: kbinfos["chunks"] = cks if self._param.use_kg: @@ -195,6 +204,8 @@ class Retrieval(ToolBase, ABC): kb_ids, embd_mdl, LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)) + if self.check_if_canceled("Retrieval processing"): + return if ck["content_with_weight"]: kbinfos["chunks"].insert(0, ck) else: @@ -202,6 +213,8 @@ class Retrieval(ToolBase, ABC): if self._param.use_kg and kbs: ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT)) + if self.check_if_canceled("Retrieval processing"): + return if ck["content_with_weight"]: ck["content"] = ck["content_with_weight"] del ck["content_with_weight"] diff --git a/agent/tools/searxng.py b/agent/tools/searxng.py index 44ad18bae..fdc7bea52 100644 --- a/agent/tools/searxng.py +++ b/agent/tools/searxng.py @@ -79,6 +79,9 @@ class SearXNG(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) def _invoke(self, **kwargs): + if self.check_if_canceled("SearXNG processing"): + return + # Gracefully handle try-run without inputs query = kwargs.get("query") if not query or not isinstance(query, str) or not query.strip(): @@ -93,6 +96,9 @@ class SearXNG(ToolBase, ABC): last_e = "" for _ in range(self._param.max_retries+1): + if self.check_if_canceled("SearXNG processing"): + return + try: search_params = { 'q': query, @@ -110,6 +116,9 @@ class SearXNG(ToolBase, ABC): ) response.raise_for_status() + if self.check_if_canceled("SearXNG processing"): + return + data = response.json() if not data or not isinstance(data, dict): @@ -121,6 +130,9 @@ class SearXNG(ToolBase, ABC): results = results[:self._param.top_n] + if self.check_if_canceled("SearXNG processing"): + return + self._retrieve_chunks(results, get_title=lambda r: r.get("title", ""), get_url=lambda r: r.get("url", ""), @@ -130,10 +142,16 @@ class SearXNG(ToolBase, ABC): return self.output("formalized_content") except requests.RequestException as e: + if self.check_if_canceled("SearXNG processing"): + return + last_e = f"Network error: {e}" logging.exception(f"SearXNG network error: {e}") time.sleep(self._param.delay_after_error) except Exception as e: + if self.check_if_canceled("SearXNG processing"): + return + last_e = str(e) logging.exception(f"SearXNG error: {e}") time.sleep(self._param.delay_after_error) diff --git a/agent/tools/tavily.py b/agent/tools/tavily.py index 6912c3695..1f1fa0137 100644 --- a/agent/tools/tavily.py +++ b/agent/tools/tavily.py @@ -103,6 +103,9 @@ class TavilySearch(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) def _invoke(self, **kwargs): + if self.check_if_canceled("TavilySearch processing"): + return + if not kwargs.get("query"): self.set_output("formalized_content", "") return "" @@ -113,10 +116,16 @@ class TavilySearch(ToolBase, ABC): if fld not in kwargs: kwargs[fld] = getattr(self._param, fld) for _ in range(self._param.max_retries+1): + if self.check_if_canceled("TavilySearch processing"): 
+ return + try: kwargs["include_images"] = False kwargs["include_raw_content"] = False res = self.tavily_client.search(**kwargs) + if self.check_if_canceled("TavilySearch processing"): + return + self._retrieve_chunks(res["results"], get_title=lambda r: r["title"], get_url=lambda r: r["url"], @@ -125,6 +134,9 @@ class TavilySearch(ToolBase, ABC): self.set_output("json", res["results"]) return self.output("formalized_content") except Exception as e: + if self.check_if_canceled("TavilySearch processing"): + return + last_e = e logging.exception(f"Tavily error: {e}") time.sleep(self._param.delay_after_error) @@ -201,6 +213,9 @@ class TavilyExtract(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) def _invoke(self, **kwargs): + if self.check_if_canceled("TavilyExtract processing"): + return + self.tavily_client = TavilyClient(api_key=self._param.api_key) last_e = None for fld in ["urls", "extract_depth", "format"]: @@ -209,12 +224,21 @@ class TavilyExtract(ToolBase, ABC): if kwargs.get("urls") and isinstance(kwargs["urls"], str): kwargs["urls"] = kwargs["urls"].split(",") for _ in range(self._param.max_retries+1): + if self.check_if_canceled("TavilyExtract processing"): + return + try: kwargs["include_images"] = False res = self.tavily_client.extract(**kwargs) + if self.check_if_canceled("TavilyExtract processing"): + return + self.set_output("json", res["results"]) return self.output("json") except Exception as e: + if self.check_if_canceled("TavilyExtract processing"): + return + last_e = e logging.exception(f"Tavily error: {e}") if last_e: diff --git a/agent/tools/tushare.py b/agent/tools/tushare.py index bb9d34fe9..6a0d0c2a3 100644 --- a/agent/tools/tushare.py +++ b/agent/tools/tushare.py @@ -43,12 +43,18 @@ class TuShare(ComponentBase, ABC): component_name = "TuShare" def _run(self, history, **kwargs): + if self.check_if_canceled("TuShare processing"): + return + ans = self.get_input() ans = ",".join(ans["content"]) if "content" in ans else "" if not ans: return TuShare.be_output("") try: + if self.check_if_canceled("TuShare processing"): + return + tus_res = [] params = { "api_name": "news", @@ -58,12 +64,18 @@ class TuShare(ComponentBase, ABC): } response = requests.post(url="http://api.tushare.pro", data=json.dumps(params).encode('utf-8')) response = response.json() + if self.check_if_canceled("TuShare processing"): + return if response['code'] != 0: return TuShare.be_output(response['msg']) df = pd.DataFrame(response['data']['items']) df.columns = response['data']['fields'] + if self.check_if_canceled("TuShare processing"): + return tus_res.append({"content": (df[df['content'].str.contains(self._param.keyword, case=False)]).to_markdown()}) except Exception as e: + if self.check_if_canceled("TuShare processing"): + return return TuShare.be_output("**ERROR**: " + str(e)) if not tus_res: diff --git a/agent/tools/wencai.py b/agent/tools/wencai.py index 7ddf27ac3..998e27a1d 100644 --- a/agent/tools/wencai.py +++ b/agent/tools/wencai.py @@ -70,19 +70,31 @@ class WenCai(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) def _invoke(self, **kwargs): + if self.check_if_canceled("WenCai processing"): + return + if not kwargs.get("query"): self.set_output("report", "") return "" last_e = "" for _ in range(self._param.max_retries+1): + if self.check_if_canceled("WenCai processing"): + return + try: wencai_res = [] res = pywencai.get(query=kwargs["query"], query_type=self._param.query_type, perpage=self._param.top_n) + if 
self.check_if_canceled("WenCai processing"): + return + if isinstance(res, pd.DataFrame): wencai_res.append(res.to_markdown()) elif isinstance(res, dict): for item in res.items(): + if self.check_if_canceled("WenCai processing"): + return + if isinstance(item[1], list): wencai_res.append(item[0] + "\n" + pd.DataFrame(item[1]).to_markdown()) elif isinstance(item[1], str): @@ -100,6 +112,9 @@ class WenCai(ToolBase, ABC): self.set_output("report", "\n\n".join(wencai_res)) return self.output("report") except Exception as e: + if self.check_if_canceled("WenCai processing"): + return + last_e = e logging.exception(f"WenCai error: {e}") time.sleep(self._param.delay_after_error) diff --git a/agent/tools/wikipedia.py b/agent/tools/wikipedia.py index 8dcddc9b9..8e0b9c3fe 100644 --- a/agent/tools/wikipedia.py +++ b/agent/tools/wikipedia.py @@ -66,17 +66,26 @@ class Wikipedia(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))) def _invoke(self, **kwargs): + if self.check_if_canceled("Wikipedia processing"): + return + if not kwargs.get("query"): self.set_output("formalized_content", "") return "" last_e = "" for _ in range(self._param.max_retries+1): + if self.check_if_canceled("Wikipedia processing"): + return + try: wikipedia.set_lang(self._param.language) wiki_engine = wikipedia pages = [] for p in wiki_engine.search(kwargs["query"], results=self._param.top_n): + if self.check_if_canceled("Wikipedia processing"): + return + try: pages.append(wikipedia.page(p)) except Exception: @@ -87,6 +96,9 @@ class Wikipedia(ToolBase, ABC): get_content=lambda r: r.summary) return self.output("formalized_content") except Exception as e: + if self.check_if_canceled("Wikipedia processing"): + return + last_e = e logging.exception(f"Wikipedia error: {e}") time.sleep(self._param.delay_after_error) diff --git a/agent/tools/yahoofinance.py b/agent/tools/yahoofinance.py index 3cca93f3d..324dfb643 100644 --- a/agent/tools/yahoofinance.py +++ b/agent/tools/yahoofinance.py @@ -74,15 +74,24 @@ class YahooFinance(ToolBase, ABC): @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60))) def _invoke(self, **kwargs): + if self.check_if_canceled("YahooFinance processing"): + return + if not kwargs.get("stock_code"): self.set_output("report", "") return "" last_e = "" for _ in range(self._param.max_retries+1): + if self.check_if_canceled("YahooFinance processing"): + return + yohoo_res = [] try: msft = yf.Ticker(kwargs["stock_code"]) + if self.check_if_canceled("YahooFinance processing"): + return + if self._param.info: yohoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n") if self._param.history: @@ -100,6 +109,9 @@ class YahooFinance(ToolBase, ABC): self.set_output("report", "\n\n".join(yohoo_res)) return self.output("report") except Exception as e: + if self.check_if_canceled("YahooFinance processing"): + return + last_e = e logging.exception(f"YahooFinance error: {e}") time.sleep(self._param.delay_after_error) diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index bd72e15b1..0ac2951ae 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -156,7 +156,7 @@ def run(): return get_json_result(data={"message_id": task_id}) try: - canvas = Canvas(cvs.dsl, current_user.id, req["id"]) + canvas = Canvas(cvs.dsl, current_user.id) except Exception as e: return server_error_response(e) @@ -168,8 +168,10 @@ def run(): cvs.dsl = json.loads(str(canvas)) UserCanvasService.update_by_id(req["id"], cvs.to_dict()) + except Exception as e: logging.exception(e) + 
canvas.cancel_task() yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n" resp = Response(sse(), mimetype="text/event-stream") @@ -177,6 +179,7 @@ def run(): resp.headers.add_header("Connection", "keep-alive") resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") + resp.call_on_close(lambda: canvas.cancel_task()) return resp @@ -430,7 +433,7 @@ def test_db_connect(): catalog, schema = _parse_catalog_schema(req["database"]) if not catalog: return server_error_response("For Trino, 'database' must be 'catalog.schema' or at least 'catalog'.") - + http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http" auth = None @@ -603,4 +606,4 @@ def download(): id = request.args.get("id") created_by = request.args.get("created_by") blob = FileService.get_blob(created_by, id) - return flask.make_response(blob) \ No newline at end of file + return flask.make_response(blob) diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 44f29162c..1de9c1646 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -24,7 +24,6 @@ import time import json_repair -from api.db.services.canvas_service import UserCanvasService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.pipeline_operation_log_service import PipelineOperationLogService from common.connection_utils import timeout @@ -33,7 +32,6 @@ from common.log_utils import init_root_logger from common.config_utils import show_configs from graphrag.general.index import run_graphrag_for_kb from graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache -from rag.flow.pipeline import Pipeline from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text import logging import os @@ -478,6 +476,9 @@ async def embedding(docs, mdl, parser_config=None, callback=None): async def run_dataflow(task: dict): + from api.db.services.canvas_service import UserCanvasService + from rag.flow.pipeline import Pipeline + task_start_ts = timer() dataflow_id = task["dataflow_id"] doc_id = task["doc_id"] @@ -944,6 +945,7 @@ async def do_handle_task(task): async def handle_task(): + global DONE_TASKS, FAILED_TASKS redis_msg, task = await collect() if not task:
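### Reviewer note (illustrative sketch, not part of the patch)

The mechanism this patch introduces boils down to a Redis flag keyed by task id plus a polling helper on every component. The sketch below condenses that flow; it assumes `has_canceled()` simply tests the `{task_id}-cancel` key (the in-repo helper lives in `api.db.services.task_service` and may do more), and the Redis client is a plain `redis-py` stand-in for `rag.utils.redis_conn.REDIS_CONN`.

```python
# Minimal sketch of the cancel-flag plumbing added by this patch.
# Assumption: has_canceled() only checks for the "{task_id}-cancel" key.
import logging

import redis

REDIS_CONN = redis.Redis(host="localhost", port=6379, decode_responses=True)


class TaskCanceledException(Exception):
    """Raised by the canvas when it notices its cancel flag mid-run."""


def cancel_task(task_id: str) -> bool:
    """Set the flag; running components see it on their next check."""
    try:
        REDIS_CONN.set(f"{task_id}-cancel", "x")
        return True
    except Exception:
        logging.exception("Failed to set cancel flag for task %s", task_id)
        return False


def has_canceled(task_id: str) -> bool:
    """True if a cancel flag exists for this task."""
    try:
        return bool(REDIS_CONN.exists(f"{task_id}-cancel"))
    except Exception:
        logging.exception("Failed to read cancel flag for task %s", task_id)
        return False


class GraphStub:
    """Canvas-side view: the graph polls the same flag between batches."""

    def __init__(self, task_id: str):
        self.task_id = task_id

    def is_canceled(self) -> bool:
        return has_canceled(self.task_id)


class ComponentBase:
    """Stand-in for agent/component/base.py with the new helper."""

    def __init__(self, canvas):
        self._canvas = canvas  # exposes .task_id and .is_canceled()
        self._outputs = {}

    def set_output(self, key, value):
        self._outputs[key] = value

    def check_if_canceled(self, message: str = "") -> bool:
        """Log, record an _ERROR output, and return True if the task is canceled."""
        if self._canvas.is_canceled():
            task_id = getattr(self._canvas, "task_id", "unknown")
            suffix = f" during {message}" if message else ""
            logging.info("Task %s has been canceled%s", task_id, suffix)
            self.set_output("_ERROR", "Task has been canceled")
            return True
        return False


if __name__ == "__main__":
    # Requires a local Redis; demonstrates the round trip only.
    canvas = GraphStub("demo-task")
    cpn = ComponentBase(canvas)
    cancel_task("demo-task")
    assert cpn.check_if_canceled("demo processing")
```

Components then call `check_if_canceled()` at the top of `_invoke()`, between retries, and right after slow network or LLM calls, returning early as soon as it reports True; the canvas itself raises `TaskCanceledException` when the flag is seen before starting or between batches.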
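On the API side, the patch wires cancellation to the HTTP layer in two places: the SSE generator calls `canvas.cancel_task()` when streaming fails, and `resp.call_on_close(...)` raises the flag when the client drops the connection. A simplified, self-contained Flask sketch follows; the route path and `CanvasStub` are illustrative stand-ins, not the real `api/apps/canvas_app.py`.

```python
# How the HTTP layer raises the flag, condensed from api/apps/canvas_app.py.
import json
import logging

from flask import Flask, Response

app = Flask(__name__)


class CanvasStub:
    """Illustrative stand-in exposing the two calls the endpoint relies on."""

    def __init__(self, task_id: str):
        self.task_id = task_id

    def run(self):
        yield {"content": "hello"}

    def cancel_task(self) -> bool:
        # The real implementation sets the "{task_id}-cancel" key in Redis.
        logging.info("cancel flag raised for %s", self.task_id)
        return True


@app.route("/canvas/run/<task_id>")
def run(task_id: str):
    canvas = CanvasStub(task_id)

    def sse():
        try:
            for ans in canvas.run():
                yield "data:" + json.dumps({"code": 0, "data": ans},
                                           ensure_ascii=False) + "\n\n"
        except Exception as e:
            logging.exception(e)
            canvas.cancel_task()  # stop in-flight components promptly
            yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False},
                                       ensure_ascii=False) + "\n\n"

    resp = Response(sse(), mimetype="text/event-stream")
    resp.headers.add_header("X-Accel-Buffering", "no")
    # If the client drops the SSE connection, raise the cancel flag as well.
    resp.call_on_close(lambda: canvas.cancel_task())
    return resp
```

Because the flag lives in Redis rather than in process memory, a disconnect observed by the web process remains visible to task executors that poll `has_canceled()`.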