Feat: add mechanism to check cancellation in Agent (#10766)

### What problem does this PR solve?

Add mechanism to check cancellation in Agent.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei
2025-11-11 17:36:48 +08:00
committed by GitHub
parent d81e4095de
commit 9213568692
36 changed files with 495 additions and 20 deletions

View File

@ -63,12 +63,18 @@ class ArXiv(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if self.check_if_canceled("ArXiv processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("ArXiv processing"):
return
try:
sort_choices = {"relevance": arxiv.SortCriterion.Relevance,
"lastUpdatedDate": arxiv.SortCriterion.LastUpdatedDate,
@ -79,12 +85,20 @@ class ArXiv(ToolBase, ABC):
max_results=self._param.top_n,
sort_by=sort_choices[self._param.sort_by]
)
self._retrieve_chunks(list(arxiv_client.results(search)),
results = list(arxiv_client.results(search))
if self.check_if_canceled("ArXiv processing"):
return
self._retrieve_chunks(results,
get_title=lambda r: r.title,
get_url=lambda r: r.pdf_url,
get_content=lambda r: r.summary)
return self.output("formalized_content")
except Exception as e:
if self.check_if_canceled("ArXiv processing"):
return
last_e = e
logging.exception(f"ArXiv error: {e}")
time.sleep(self._param.delay_after_error)

View File

@ -125,6 +125,9 @@ class ToolBase(ComponentBase):
return self._param.get_meta()
def invoke(self, **kwargs):
if self.check_if_canceled("Tool processing"):
return
self.set_output("_created_time", time.perf_counter())
try:
res = self._invoke(**kwargs)
@ -170,4 +173,4 @@ class ToolBase(ComponentBase):
self.set_output("formalized_content", "\n".join(kb_prompt({"chunks": chunks, "doc_aggs": aggs}, 200000, True)))
def thoughts(self) -> str:
return self._canvas.get_component_name(self._id) + " is running..."
return self._canvas.get_component_name(self._id) + " is running..."

View File

@ -131,10 +131,14 @@ class CodeExec(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("CodeExec processing"):
return
lang = kwargs.get("lang", self._param.lang)
script = kwargs.get("script", self._param.script)
arguments = {}
for k, v in self._param.arguments.items():
if kwargs.get(k):
arguments[k] = kwargs[k]
continue
@ -149,15 +153,28 @@ class CodeExec(ToolBase, ABC):
def _execute_code(self, language: str, code: str, arguments: dict):
import requests
if self.check_if_canceled("CodeExec execution"):
return
try:
code_b64 = self._encode_code(code)
code_req = CodeExecutionRequest(code_b64=code_b64, language=language, arguments=arguments).model_dump()
except Exception as e:
if self.check_if_canceled("CodeExec execution"):
return
self.set_output("_ERROR", "construct code request error: " + str(e))
try:
if self.check_if_canceled("CodeExec execution"):
return "Task has been canceled"
resp = requests.post(url=f"http://{settings.SANDBOX_HOST}:9385/run", json=code_req, timeout=int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run, code_req: {code_req}, resp.status_code {resp.status_code}:")
if self.check_if_canceled("CodeExec execution"):
return "Task has been canceled"
if resp.status_code != 200:
resp.raise_for_status()
body = resp.json()
@ -173,16 +190,25 @@ class CodeExec(ToolBase, ABC):
logging.info(f"http://{settings.SANDBOX_HOST}:9385/run -> {rt}")
if isinstance(rt, tuple):
for i, (k, o) in enumerate(self._param.outputs.items()):
if self.check_if_canceled("CodeExec execution"):
return
if k.find("_") == 0:
continue
o["value"] = rt[i]
elif isinstance(rt, dict):
for i, (k, o) in enumerate(self._param.outputs.items()):
if self.check_if_canceled("CodeExec execution"):
return
if k not in rt or k.find("_") == 0:
continue
o["value"] = rt[k]
else:
for i, (k, o) in enumerate(self._param.outputs.items()):
if self.check_if_canceled("CodeExec execution"):
return
if k.find("_") == 0:
continue
o["value"] = rt
@ -190,6 +216,9 @@ class CodeExec(ToolBase, ABC):
self.set_output("_ERROR", "There is no response from sandbox")
except Exception as e:
if self.check_if_canceled("CodeExec execution"):
return
self.set_output("_ERROR", "Exception executing code: " + str(e))
return self.output()

View File

@ -29,7 +29,7 @@ class CrawlerParam(ToolParamBase):
super().__init__()
self.proxy = None
self.extract_type = "markdown"
def check(self):
self.check_valid_value(self.extract_type, "Type of content from the crawler", ['html', 'markdown', 'content'])
@ -47,18 +47,24 @@ class Crawler(ToolBase, ABC):
result = asyncio.run(self.get_web(ans))
return Crawler.be_output(result)
except Exception as e:
return Crawler.be_output(f"An unexpected error occurred: {str(e)}")
async def get_web(self, url):
if self.check_if_canceled("Crawler async operation"):
return
proxy = self._param.proxy if self._param.proxy else None
async with AsyncWebCrawler(verbose=True, proxy=proxy) as crawler:
result = await crawler.arun(
url=url,
bypass_cache=True
)
if self.check_if_canceled("Crawler async operation"):
return
if self._param.extract_type == 'html':
return result.cleaned_html
elif self._param.extract_type == 'markdown':

View File

@ -46,11 +46,16 @@ class DeepL(ComponentBase, ABC):
component_name = "DeepL"
def _run(self, history, **kwargs):
if self.check_if_canceled("DeepL processing"):
return
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
return DeepL.be_output("")
if self.check_if_canceled("DeepL processing"):
return
try:
translator = deepl.Translator(self._param.auth_key)
result = translator.translate_text(ans, source_lang=self._param.source_lang,
@ -58,4 +63,6 @@ class DeepL(ComponentBase, ABC):
return DeepL.be_output(result.text)
except Exception as e:
if self.check_if_canceled("DeepL processing"):
return
DeepL.be_output("**Error**:" + str(e))

View File

@ -75,17 +75,30 @@ class DuckDuckGo(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if self.check_if_canceled("DuckDuckGo processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("DuckDuckGo processing"):
return
try:
if kwargs.get("topic", "general") == "general":
with DDGS() as ddgs:
if self.check_if_canceled("DuckDuckGo processing"):
return
# {'title': '', 'href': '', 'body': ''}
duck_res = ddgs.text(kwargs["query"], max_results=self._param.top_n)
if self.check_if_canceled("DuckDuckGo processing"):
return
self._retrieve_chunks(duck_res,
get_title=lambda r: r["title"],
get_url=lambda r: r.get("href", r.get("url")),
@ -94,8 +107,15 @@ class DuckDuckGo(ToolBase, ABC):
return self.output("formalized_content")
else:
with DDGS() as ddgs:
if self.check_if_canceled("DuckDuckGo processing"):
return
# {'date': '', 'title': '', 'body': '', 'url': '', 'image': '', 'source': ''}
duck_res = ddgs.news(kwargs["query"], max_results=self._param.top_n)
if self.check_if_canceled("DuckDuckGo processing"):
return
self._retrieve_chunks(duck_res,
get_title=lambda r: r["title"],
get_url=lambda r: r.get("href", r.get("url")),
@ -103,6 +123,9 @@ class DuckDuckGo(ToolBase, ABC):
self.set_output("json", duck_res)
return self.output("formalized_content")
except Exception as e:
if self.check_if_canceled("DuckDuckGo processing"):
return
last_e = e
logging.exception(f"DuckDuckGo error: {e}")
time.sleep(self._param.delay_after_error)

View File

@ -101,19 +101,27 @@ class Email(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("Email processing"):
return
if not kwargs.get("to_email"):
self.set_output("success", False)
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("Email processing"):
return
try:
# Parse JSON string passed from upstream
email_data = kwargs
# Validate required fields
if "to_email" not in email_data:
return Email.be_output("Missing required field: to_email")
self.set_output("_ERROR", "Missing required field: to_email")
self.set_output("success", False)
return False
# Create email object
msg = MIMEMultipart('alternative')
@ -133,6 +141,9 @@ class Email(ToolBase, ABC):
# Connect to SMTP server and send
logging.info(f"Connecting to SMTP server {self._param.smtp_server}:{self._param.smtp_port}")
if self.check_if_canceled("Email processing"):
return
context = smtplib.ssl.create_default_context()
with smtplib.SMTP(self._param.smtp_server, self._param.smtp_port) as server:
server.ehlo()
@ -149,6 +160,10 @@ class Email(ToolBase, ABC):
# Send email
logging.info(f"Sending email to recipients: {recipients}")
if self.check_if_canceled("Email processing"):
return
try:
server.send_message(msg, self._param.email, recipients)
success = True

View File

@ -81,6 +81,8 @@ class ExeSQL(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("ExeSQL processing"):
return
def convert_decimals(obj):
from decimal import Decimal
@ -96,6 +98,9 @@ class ExeSQL(ToolBase, ABC):
if not sql:
raise Exception("SQL for `ExeSQL` MUST not be empty.")
if self.check_if_canceled("ExeSQL processing"):
return
vars = self.get_input_elements_from_text(sql)
args = {}
for k, o in vars.items():
@ -108,6 +113,9 @@ class ExeSQL(ToolBase, ABC):
self.set_input_value(k, args[k])
sql = self.string_format(sql, args)
if self.check_if_canceled("ExeSQL processing"):
return
sqls = sql.split(";")
if self._param.db_type in ["mysql", "mariadb"]:
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
@ -181,6 +189,10 @@ class ExeSQL(ToolBase, ABC):
sql_res = []
formalized_content = []
for single_sql in sqls:
if self.check_if_canceled("ExeSQL processing"):
ibm_db.close(conn)
return
single_sql = single_sql.replace("```", "").strip()
if not single_sql:
continue
@ -190,6 +202,9 @@ class ExeSQL(ToolBase, ABC):
rows = []
row = ibm_db.fetch_assoc(stmt)
while row and len(rows) < self._param.max_records:
if self.check_if_canceled("ExeSQL processing"):
ibm_db.close(conn)
return
rows.append(row)
row = ibm_db.fetch_assoc(stmt)
@ -220,6 +235,11 @@ class ExeSQL(ToolBase, ABC):
sql_res = []
formalized_content = []
for single_sql in sqls:
if self.check_if_canceled("ExeSQL processing"):
cursor.close()
db.close()
return
single_sql = single_sql.replace('```','')
if not single_sql:
continue
@ -244,6 +264,9 @@ class ExeSQL(ToolBase, ABC):
sql_res.append(convert_decimals(single_res.to_dict(orient='records')))
formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f"))
cursor.close()
db.close()
self.set_output("json", sql_res)
self.set_output("formalized_content", "\n\n".join(formalized_content))
return self.output("formalized_content")

View File

@ -59,17 +59,27 @@ class GitHub(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if self.check_if_canceled("GitHub processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("GitHub processing"):
return
try:
url = 'https://api.github.com/search/repositories?q=' + kwargs["query"] + '&sort=stars&order=desc&per_page=' + str(
self._param.top_n)
headers = {"Content-Type": "application/vnd.github+json", "X-GitHub-Api-Version": '2022-11-28'}
response = requests.get(url=url, headers=headers).json()
if self.check_if_canceled("GitHub processing"):
return
self._retrieve_chunks(response['items'],
get_title=lambda r: r["name"],
get_url=lambda r: r["html_url"],
@ -77,6 +87,9 @@ class GitHub(ToolBase, ABC):
self.set_output("json", response['items'])
return self.output("formalized_content")
except Exception as e:
if self.check_if_canceled("GitHub processing"):
return
last_e = e
logging.exception(f"GitHub error: {e}")
time.sleep(self._param.delay_after_error)

View File

@ -118,6 +118,9 @@ class Google(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if self.check_if_canceled("Google processing"):
return
if not kwargs.get("q"):
self.set_output("formalized_content", "")
return ""
@ -132,8 +135,15 @@ class Google(ToolBase, ABC):
}
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("Google processing"):
return
try:
search = GoogleSearch(params).get_dict()
if self.check_if_canceled("Google processing"):
return
self._retrieve_chunks(search["organic_results"],
get_title=lambda r: r["title"],
get_url=lambda r: r["link"],
@ -142,6 +152,9 @@ class Google(ToolBase, ABC):
self.set_output("json", search["organic_results"])
return self.output("formalized_content")
except Exception as e:
if self.check_if_canceled("Google processing"):
return
last_e = e
logging.exception(f"Google error: {e}")
time.sleep(self._param.delay_after_error)

View File

@ -65,15 +65,25 @@ class GoogleScholar(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if self.check_if_canceled("GoogleScholar processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("GoogleScholar processing"):
return
try:
scholar_client = scholarly.search_pubs(kwargs["query"], patents=self._param.patents, year_low=self._param.year_low,
year_high=self._param.year_high, sort_by=self._param.sort_by)
if self.check_if_canceled("GoogleScholar processing"):
return
self._retrieve_chunks(scholar_client,
get_title=lambda r: r['bib']['title'],
get_url=lambda r: r["pub_url"],
@ -82,6 +92,9 @@ class GoogleScholar(ToolBase, ABC):
self.set_output("json", list(scholar_client))
return self.output("formalized_content")
except Exception as e:
if self.check_if_canceled("GoogleScholar processing"):
return
last_e = e
logging.exception(f"GoogleScholar error: {e}")
time.sleep(self._param.delay_after_error)

View File

@ -50,6 +50,9 @@ class Jin10(ComponentBase, ABC):
component_name = "Jin10"
def _run(self, history, **kwargs):
if self.check_if_canceled("Jin10 processing"):
return
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not ans:
@ -58,6 +61,9 @@ class Jin10(ComponentBase, ABC):
jin10_res = []
headers = {'secret-key': self._param.secret_key}
try:
if self.check_if_canceled("Jin10 processing"):
return
if self._param.type == "flash":
params = {
'category': self._param.flash_type,
@ -69,6 +75,8 @@ class Jin10(ComponentBase, ABC):
headers=headers, data=json.dumps(params))
response = response.json()
for i in response['data']:
if self.check_if_canceled("Jin10 processing"):
return
jin10_res.append({"content": i['data']['content']})
if self._param.type == "calendar":
params = {
@ -79,6 +87,8 @@ class Jin10(ComponentBase, ABC):
headers=headers, data=json.dumps(params))
response = response.json()
if self.check_if_canceled("Jin10 processing"):
return
jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()})
if self._param.type == "symbols":
params = {
@ -90,8 +100,12 @@ class Jin10(ComponentBase, ABC):
url='https://open-data-api.jin10.com/data-api/' + self._param.symbols_datatype + '?type=' + self._param.symbols_type,
headers=headers, data=json.dumps(params))
response = response.json()
if self.check_if_canceled("Jin10 processing"):
return
if self._param.symbols_datatype == "symbols":
for i in response['data']:
if self.check_if_canceled("Jin10 processing"):
return
i['Commodity Code'] = i['c']
i['Stock Exchange'] = i['e']
i['Commodity Name'] = i['n']
@ -99,6 +113,8 @@ class Jin10(ComponentBase, ABC):
del i['c'], i['e'], i['n'], i['t']
if self._param.symbols_datatype == "quotes":
for i in response['data']:
if self.check_if_canceled("Jin10 processing"):
return
i['Selling Price'] = i['a']
i['Buying Price'] = i['b']
i['Commodity Code'] = i['c']
@ -120,8 +136,12 @@ class Jin10(ComponentBase, ABC):
url='https://open-data-api.jin10.com/data-api/news',
headers=headers, data=json.dumps(params))
response = response.json()
if self.check_if_canceled("Jin10 processing"):
return
jin10_res.append({"content": pd.DataFrame(response['data']).to_markdown()})
except Exception as e:
if self.check_if_canceled("Jin10 processing"):
return
return Jin10.be_output("**ERROR**: " + str(e))
if not jin10_res:

View File

@ -71,23 +71,40 @@ class PubMed(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if self.check_if_canceled("PubMed processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("PubMed processing"):
return
try:
Entrez.email = self._param.email
pubmedids = Entrez.read(Entrez.esearch(db='pubmed', retmax=self._param.top_n, term=kwargs["query"]))['IdList']
if self.check_if_canceled("PubMed processing"):
return
pubmedcnt = ET.fromstring(re.sub(r'<(/?)b>|<(/?)i>', '', Entrez.efetch(db='pubmed', id=",".join(pubmedids),
retmode="xml").read().decode("utf-8")))
if self.check_if_canceled("PubMed processing"):
return
self._retrieve_chunks(pubmedcnt.findall("PubmedArticle"),
get_title=lambda child: child.find("MedlineCitation").find("Article").find("ArticleTitle").text,
get_url=lambda child: "https://pubmed.ncbi.nlm.nih.gov/" + child.find("MedlineCitation").find("PMID").text,
get_content=lambda child: self._format_pubmed_content(child),)
return self.output("formalized_content")
except Exception as e:
if self.check_if_canceled("PubMed processing"):
return
last_e = e
logging.exception(f"PubMed error: {e}")
time.sleep(self._param.delay_after_error)

View File

@ -58,12 +58,18 @@ class QWeather(ComponentBase, ABC):
component_name = "QWeather"
def _run(self, history, **kwargs):
if self.check_if_canceled("Qweather processing"):
return
ans = self.get_input()
ans = "".join(ans["content"]) if "content" in ans else ""
if not ans:
return QWeather.be_output("")
try:
if self.check_if_canceled("Qweather processing"):
return
response = requests.get(
url="https://geoapi.qweather.com/v2/city/lookup?location=" + ans + "&key=" + self._param.web_apikey).json()
if response["code"] == "200":
@ -71,16 +77,23 @@ class QWeather(ComponentBase, ABC):
else:
return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
if self.check_if_canceled("Qweather processing"):
return
base_url = "https://api.qweather.com/v7/" if self._param.user_type == 'paid' else "https://devapi.qweather.com/v7/"
if self._param.type == "weather":
url = base_url + "weather/" + self._param.time_period + "?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
response = requests.get(url=url).json()
if self.check_if_canceled("Qweather processing"):
return
if response["code"] == "200":
if self._param.time_period == "now":
return QWeather.be_output(str(response["now"]))
else:
qweather_res = [{"content": str(i) + "\n"} for i in response["daily"]]
if self.check_if_canceled("Qweather processing"):
return
if not qweather_res:
return QWeather.be_output("")
@ -92,6 +105,8 @@ class QWeather(ComponentBase, ABC):
elif self._param.type == "indices":
url = base_url + "indices/1d?type=0&location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
response = requests.get(url=url).json()
if self.check_if_canceled("Qweather processing"):
return
if response["code"] == "200":
indices_res = response["daily"][0]["date"] + "\n" + "\n".join(
[i["name"] + ": " + i["category"] + ", " + i["text"] for i in response["daily"]])
@ -103,9 +118,13 @@ class QWeather(ComponentBase, ABC):
elif self._param.type == "airquality":
url = base_url + "air/now?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang
response = requests.get(url=url).json()
if self.check_if_canceled("Qweather processing"):
return
if response["code"] == "200":
return QWeather.be_output(str(response["now"]))
else:
return QWeather.be_output("**Error**" + self._param.error_code[response["code"]])
except Exception as e:
if self.check_if_canceled("Qweather processing"):
return
return QWeather.be_output("**Error**" + str(e))

View File

@ -82,8 +82,12 @@ class Retrieval(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if self.check_if_canceled("Retrieval processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", self._param.empty_response)
return
kb_ids: list[str] = []
for id in self._param.kb_ids:
@ -122,7 +126,7 @@ class Retrieval(ToolBase, ABC):
vars = self.get_input_elements_from_text(kwargs["query"])
vars = {k:o["value"] for k,o in vars.items()}
query = self.string_format(kwargs["query"], vars)
doc_ids=[]
if self._param.meta_data_filter!={}:
metas = DocumentService.get_meta_by_kbs(kb_ids)
@ -184,9 +188,14 @@ class Retrieval(ToolBase, ABC):
rerank_mdl=rerank_mdl,
rank_feature=label_question(query, kbs),
)
if self.check_if_canceled("Retrieval processing"):
return
if self._param.toc_enhance:
chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
if self.check_if_canceled("Retrieval processing"):
return
if cks:
kbinfos["chunks"] = cks
if self._param.use_kg:
@ -195,6 +204,8 @@ class Retrieval(ToolBase, ABC):
kb_ids,
embd_mdl,
LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return
if ck["content_with_weight"]:
kbinfos["chunks"].insert(0, ck)
else:
@ -202,6 +213,8 @@ class Retrieval(ToolBase, ABC):
if self._param.use_kg and kbs:
ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
if self.check_if_canceled("Retrieval processing"):
return
if ck["content_with_weight"]:
ck["content"] = ck["content_with_weight"]
del ck["content_with_weight"]

View File

@ -79,6 +79,9 @@ class SearXNG(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if self.check_if_canceled("SearXNG processing"):
return
# Gracefully handle try-run without inputs
query = kwargs.get("query")
if not query or not isinstance(query, str) or not query.strip():
@ -93,6 +96,9 @@ class SearXNG(ToolBase, ABC):
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("SearXNG processing"):
return
try:
search_params = {
'q': query,
@ -110,6 +116,9 @@ class SearXNG(ToolBase, ABC):
)
response.raise_for_status()
if self.check_if_canceled("SearXNG processing"):
return
data = response.json()
if not data or not isinstance(data, dict):
@ -121,6 +130,9 @@ class SearXNG(ToolBase, ABC):
results = results[:self._param.top_n]
if self.check_if_canceled("SearXNG processing"):
return
self._retrieve_chunks(results,
get_title=lambda r: r.get("title", ""),
get_url=lambda r: r.get("url", ""),
@ -130,10 +142,16 @@ class SearXNG(ToolBase, ABC):
return self.output("formalized_content")
except requests.RequestException as e:
if self.check_if_canceled("SearXNG processing"):
return
last_e = f"Network error: {e}"
logging.exception(f"SearXNG network error: {e}")
time.sleep(self._param.delay_after_error)
except Exception as e:
if self.check_if_canceled("SearXNG processing"):
return
last_e = str(e)
logging.exception(f"SearXNG error: {e}")
time.sleep(self._param.delay_after_error)

View File

@ -103,6 +103,9 @@ class TavilySearch(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if self.check_if_canceled("TavilySearch processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
@ -113,10 +116,16 @@ class TavilySearch(ToolBase, ABC):
if fld not in kwargs:
kwargs[fld] = getattr(self._param, fld)
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("TavilySearch processing"):
return
try:
kwargs["include_images"] = False
kwargs["include_raw_content"] = False
res = self.tavily_client.search(**kwargs)
if self.check_if_canceled("TavilySearch processing"):
return
self._retrieve_chunks(res["results"],
get_title=lambda r: r["title"],
get_url=lambda r: r["url"],
@ -125,6 +134,9 @@ class TavilySearch(ToolBase, ABC):
self.set_output("json", res["results"])
return self.output("formalized_content")
except Exception as e:
if self.check_if_canceled("TavilySearch processing"):
return
last_e = e
logging.exception(f"Tavily error: {e}")
time.sleep(self._param.delay_after_error)
@ -201,6 +213,9 @@ class TavilyExtract(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("TavilyExtract processing"):
return
self.tavily_client = TavilyClient(api_key=self._param.api_key)
last_e = None
for fld in ["urls", "extract_depth", "format"]:
@ -209,12 +224,21 @@ class TavilyExtract(ToolBase, ABC):
if kwargs.get("urls") and isinstance(kwargs["urls"], str):
kwargs["urls"] = kwargs["urls"].split(",")
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("TavilyExtract processing"):
return
try:
kwargs["include_images"] = False
res = self.tavily_client.extract(**kwargs)
if self.check_if_canceled("TavilyExtract processing"):
return
self.set_output("json", res["results"])
return self.output("json")
except Exception as e:
if self.check_if_canceled("TavilyExtract processing"):
return
last_e = e
logging.exception(f"Tavily error: {e}")
if last_e:

View File

@ -43,12 +43,18 @@ class TuShare(ComponentBase, ABC):
component_name = "TuShare"
def _run(self, history, **kwargs):
if self.check_if_canceled("TuShare processing"):
return
ans = self.get_input()
ans = ",".join(ans["content"]) if "content" in ans else ""
if not ans:
return TuShare.be_output("")
try:
if self.check_if_canceled("TuShare processing"):
return
tus_res = []
params = {
"api_name": "news",
@ -58,12 +64,18 @@ class TuShare(ComponentBase, ABC):
}
response = requests.post(url="http://api.tushare.pro", data=json.dumps(params).encode('utf-8'))
response = response.json()
if self.check_if_canceled("TuShare processing"):
return
if response['code'] != 0:
return TuShare.be_output(response['msg'])
df = pd.DataFrame(response['data']['items'])
df.columns = response['data']['fields']
if self.check_if_canceled("TuShare processing"):
return
tus_res.append({"content": (df[df['content'].str.contains(self._param.keyword, case=False)]).to_markdown()})
except Exception as e:
if self.check_if_canceled("TuShare processing"):
return
return TuShare.be_output("**ERROR**: " + str(e))
if not tus_res:

View File

@ -70,19 +70,31 @@ class WenCai(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
def _invoke(self, **kwargs):
if self.check_if_canceled("WenCai processing"):
return
if not kwargs.get("query"):
self.set_output("report", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("WenCai processing"):
return
try:
wencai_res = []
res = pywencai.get(query=kwargs["query"], query_type=self._param.query_type, perpage=self._param.top_n)
if self.check_if_canceled("WenCai processing"):
return
if isinstance(res, pd.DataFrame):
wencai_res.append(res.to_markdown())
elif isinstance(res, dict):
for item in res.items():
if self.check_if_canceled("WenCai processing"):
return
if isinstance(item[1], list):
wencai_res.append(item[0] + "\n" + pd.DataFrame(item[1]).to_markdown())
elif isinstance(item[1], str):
@ -100,6 +112,9 @@ class WenCai(ToolBase, ABC):
self.set_output("report", "\n\n".join(wencai_res))
return self.output("report")
except Exception as e:
if self.check_if_canceled("WenCai processing"):
return
last_e = e
logging.exception(f"WenCai error: {e}")
time.sleep(self._param.delay_after_error)

View File

@ -66,17 +66,26 @@ class Wikipedia(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("Wikipedia processing"):
return
if not kwargs.get("query"):
self.set_output("formalized_content", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("Wikipedia processing"):
return
try:
wikipedia.set_lang(self._param.language)
wiki_engine = wikipedia
pages = []
for p in wiki_engine.search(kwargs["query"], results=self._param.top_n):
if self.check_if_canceled("Wikipedia processing"):
return
try:
pages.append(wikipedia.page(p))
except Exception:
@ -87,6 +96,9 @@ class Wikipedia(ToolBase, ABC):
get_content=lambda r: r.summary)
return self.output("formalized_content")
except Exception as e:
if self.check_if_canceled("Wikipedia processing"):
return
last_e = e
logging.exception(f"Wikipedia error: {e}")
time.sleep(self._param.delay_after_error)

View File

@ -74,15 +74,24 @@ class YahooFinance(ToolBase, ABC):
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
def _invoke(self, **kwargs):
if self.check_if_canceled("YahooFinance processing"):
return
if not kwargs.get("stock_code"):
self.set_output("report", "")
return ""
last_e = ""
for _ in range(self._param.max_retries+1):
if self.check_if_canceled("YahooFinance processing"):
return
yohoo_res = []
try:
msft = yf.Ticker(kwargs["stock_code"])
if self.check_if_canceled("YahooFinance processing"):
return
if self._param.info:
yohoo_res.append("# Information:\n" + pd.Series(msft.info).to_markdown() + "\n")
if self._param.history:
@ -100,6 +109,9 @@ class YahooFinance(ToolBase, ABC):
self.set_output("report", "\n\n".join(yohoo_res))
return self.output("report")
except Exception as e:
if self.check_if_canceled("YahooFinance processing"):
return
last_e = e
logging.exception(f"YahooFinance error: {e}")
time.sleep(self._param.delay_after_error)