Feat: Redesign and refactor agent module (#9113)

### What problem does this PR solve?

#9082 #6365

<u>**WARNING: this is not compatible with older versions of the `Agent`
module, which means that agents built with older versions can no longer
work.**</u>

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Kevin Hu
2025-07-30 19:41:09 +08:00
committed by GitHub
parent 07e37560fc
commit d9fe279dde
124 changed files with 7744 additions and 18226 deletions

View File

@ -69,6 +69,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
Since a book is long and not all of its parts are useful, if it's a PDF,
please set up the page ranges for every book in order to eliminate negative effects and save computing time.
"""
parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
@ -89,7 +92,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
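
The recurring change across these chunkers is that `layout_recognize` is now read from the `parser_config` dict (with a default config) instead of from the top-level `kwargs`. A minimal sketch of the behavioral difference, using a hypothetical `kwargs` value:

```python
# Hypothetical caller that nests layout_recognize inside parser_config.
kwargs = {"parser_config": {"layout_recognize": "Plain Text"}}

# Old lookup: reads the top-level kwargs, so the nested setting is silently ignored.
old_mode = kwargs.get("layout_recognize", "DeepDOC")          # -> "DeepDOC"

# New lookup: falls back to a default parser_config, then reads the nested setting.
parser_config = kwargs.get(
    "parser_config",
    {"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
new_mode = parser_config.get("layout_recognize", "DeepDOC")   # -> "Plain Text"

print(old_mode, new_mode)
```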

View File

@ -40,7 +40,7 @@ def chunk(
eng = lang.lower() == "english" # is_english(cks)
parser_config = kwargs.get(
"parser_config",
{"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"},
{"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"},
)
doc = {
"docnm_kwd": filename,

View File

@ -145,6 +145,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
"""
Supported file formats are docx, pdf, txt.
"""
parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
@ -163,7 +166,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
for txt, poss in pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)[0]:

View File

@ -45,9 +45,6 @@ class Pdf(PdfParser):
callback
)
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
# for bb in self.boxes:
# for b in bb:
# print(b)
logging.debug("OCR: {}".format(timer() - start))
start = timer()
@ -177,6 +174,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
"""
Only pdf is supported.
"""
parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
pdf_parser = None
doc = {
"docnm_kwd": filename
@ -187,7 +187,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
eng = lang.lower() == "english" # pdf_parser.is_english
if re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, tbls = pdf_parser(filename if not binary else binary,
from_page=from_page, to_page=to_page, callback=callback)
@ -222,7 +222,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
sid += 1
sec_ids.append(sid)
# print(lvl, self.boxes[i]["text"], most_level, sid)
sections = [(txt, sec_ids[i], poss)
for i, (txt, _, poss) in enumerate(sections)]

View File

@ -370,7 +370,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
is_english = lang.lower() == "english" # is_english(cks)
parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
doc = {
"docnm_kwd": filename,
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))

View File

@ -72,7 +72,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
Supported file formats are docx, pdf, excel, txt.
One file forms a chunk which maintains original text order.
"""
parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
eng = lang.lower() == "english" # is_english(cks)
if re.search(r"\.docx$", filename, re.IGNORECASE):
@ -85,7 +87,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
pdf_parser = Pdf()
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
sections, _ = pdf_parser(
filename if not binary else binary, to_page=to_page, callback=callback)

View File

@ -143,8 +143,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
Only pdf is supported.
The abstract of the paper will be sliced as an entire chunk, and will not be sliced partly.
"""
parser_config = kwargs.get(
"parser_config", {
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
if re.search(r"\.pdf$", filename, re.IGNORECASE):
if kwargs.get("parser_config", {}).get("layout_recognize", "DeepDOC") == "Plain Text":
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
pdf_parser = PlainParser()
paper = {
"title": filename,

View File

@ -18,6 +18,7 @@ import json
import logging
import os
import random
import re
import time
from abc import ABC
from copy import deepcopy
@ -31,25 +32,33 @@ from dashscope import Generation
from ollama import Client
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from strenum import StrEnum
from zhipuai import ZhipuAI
from rag.nlp import is_chinese, is_english
from rag.utils import num_tokens_from_string
# Error message constants
ERROR_PREFIX = "**ERROR**"
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
ERROR_AUTHENTICATION = "AUTH_ERROR"
ERROR_INVALID_REQUEST = "INVALID_REQUEST"
ERROR_SERVER = "SERVER_ERROR"
ERROR_TIMEOUT = "TIMEOUT"
ERROR_CONNECTION = "CONNECTION_ERROR"
ERROR_MODEL = "MODEL_ERROR"
ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
ERROR_QUOTA = "QUOTA_EXCEEDED"
ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
ERROR_GENERIC = "GENERIC_ERROR"
class LLMErrorCode(StrEnum):
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
ERROR_AUTHENTICATION = "AUTH_ERROR"
ERROR_INVALID_REQUEST = "INVALID_REQUEST"
ERROR_SERVER = "SERVER_ERROR"
ERROR_TIMEOUT = "TIMEOUT"
ERROR_CONNECTION = "CONNECTION_ERROR"
ERROR_MODEL = "MODEL_ERROR"
ERROR_MAX_ROUNDS = "ERROR_MAX_ROUNDS"
ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
ERROR_QUOTA = "QUOTA_EXCEEDED"
ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
ERROR_GENERIC = "GENERIC_ERROR"
class ReActMode(StrEnum):
FUNCTION_CALL = "function_call"
REACT = "react"
ERROR_PREFIX = "**ERROR**"
LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。"
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
@ -73,51 +82,78 @@ class Base(ABC):
def _get_delay(self):
"""Calculate retry delay time"""
return self.base_delay + random.uniform(60, 150)
return self.base_delay * random.uniform(10, 150)
def _classify_error(self, error):
"""Classify error based on error message content"""
error_str = str(error).lower()
if "rate limit" in error_str or "429" in error_str or "tpm limit" in error_str or "too many requests" in error_str or "requests per minute" in error_str:
return ERROR_RATE_LIMIT
elif "auth" in error_str or "key" in error_str or "apikey" in error_str or "401" in error_str or "forbidden" in error_str or "permission" in error_str:
return ERROR_AUTHENTICATION
elif "invalid" in error_str or "bad request" in error_str or "400" in error_str or "format" in error_str or "malformed" in error_str or "parameter" in error_str:
return ERROR_INVALID_REQUEST
elif "server" in error_str or "502" in error_str or "503" in error_str or "504" in error_str or "500" in error_str or "unavailable" in error_str:
return ERROR_SERVER
elif "timeout" in error_str or "timed out" in error_str:
return ERROR_TIMEOUT
elif "connect" in error_str or "network" in error_str or "unreachable" in error_str or "dns" in error_str:
return ERROR_CONNECTION
elif "quota" in error_str or "capacity" in error_str or "credit" in error_str or "billing" in error_str or "limit" in error_str and "rate" not in error_str:
return ERROR_QUOTA
elif "filter" in error_str or "content" in error_str or "policy" in error_str or "blocked" in error_str or "safety" in error_str or "inappropriate" in error_str:
return ERROR_CONTENT_FILTER
elif "model" in error_str or "not found" in error_str or "does not exist" in error_str or "not available" in error_str:
return ERROR_MODEL
else:
return ERROR_GENERIC
keywords_mapping = [
(["quota", "capacity", "credit", "billing", "balance", "欠费"], LLMErrorCode.ERROR_QUOTA),
(["rate limit", "429", "tpm limit", "too many requests", "requests per minute"], LLMErrorCode.ERROR_RATE_LIMIT),
(["auth", "key", "apikey", "401", "forbidden", "permission"], LLMErrorCode.ERROR_AUTHENTICATION),
(["invalid", "bad request", "400", "format", "malformed", "parameter"], LLMErrorCode.ERROR_INVALID_REQUEST),
(["server", "503", "502", "504", "500", "unavailable"], LLMErrorCode.ERROR_SERVER),
(["timeout", "timed out"], LLMErrorCode.ERROR_TIMEOUT),
(["connect", "network", "unreachable", "dns"], LLMErrorCode.ERROR_CONNECTION),
(["filter", "content", "policy", "blocked", "safety", "inappropriate"], LLMErrorCode.ERROR_CONTENT_FILTER),
(["model", "not found", "does not exist", "not available"], LLMErrorCode.ERROR_MODEL),
(["max rounds"], LLMErrorCode.ERROR_MODEL),
]
for words, code in keywords_mapping:
if re.search("({})".format("|".join(words)), error_str):
return code
return LLMErrorCode.ERROR_GENERIC
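
For reference, a self-contained sketch of the keyword-based classification this hunk introduces; the keyword list below is a trimmed, illustrative subset of the mapping in the diff:

```python
import re

# Each error message is matched against keyword groups in priority order; first hit wins.
KEYWORDS_MAPPING = [
    (["quota", "billing", "balance"], "QUOTA_EXCEEDED"),
    (["rate limit", "429", "too many requests"], "RATE_LIMIT_EXCEEDED"),
    (["timeout", "timed out"], "TIMEOUT"),
]

def classify_error(error: Exception) -> str:
    error_str = str(error).lower()
    for words, code in KEYWORDS_MAPPING:
        if re.search("({})".format("|".join(words)), error_str):
            return code
    return "GENERIC_ERROR"

print(classify_error(RuntimeError("HTTP 429: Too Many Requests")))  # RATE_LIMIT_EXCEEDED
```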
def _clean_conf(self, gen_conf):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
return gen_conf
def _chat(self, history, gen_conf):
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
def _chat(self, history, gen_conf, **kwargs):
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
if self.model_name.lower().find("qwen3") >=0:
kwargs["extra_body"] = {"enable_thinking": False}
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
return "", 0
ans = response.choices[0].message.content.strip()
if response.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
ans = self._length_stop(ans)
return ans, self.total_token_count(response)
def _chat_streamly(self, history, gen_conf, **kwargs):
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
reasoning_start = False
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = resp.choices[0].delta.content
tol = self.total_token_count(resp)
if not tol:
tol = num_tokens_from_string(resp.choices[0].delta.content)
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
def _length_stop(self, ans):
if is_chinese([ans]):
return ans + LENGTH_NOTIFICATION_CN
@ -127,22 +163,24 @@ class Base(ABC):
logging.exception("OpenAI chat_with_tools")
# Classify the error
error_code = self._classify_error(e)
if attempt == self.max_retries:
error_code = LLMErrorCode.ERROR_MAX_RETRIES
# Check if it's a rate limit error or server error and not the last attempt
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries
if should_retry:
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
else:
# For non-rate limit errors or the last attempt, return an error message
if attempt == self.max_retries:
error_code = ERROR_MAX_RETRIES
should_retry = (error_code == LLMErrorCode.ERROR_RATE_LIMIT or error_code == LLMErrorCode.ERROR_SERVER)
if not should_retry:
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
delay = self._get_delay()
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
time.sleep(delay)
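
A simplified sketch of the retry policy after this change: only rate-limit and server errors are retried, the final attempt is reported as `MAX_RETRIES_EXCEEDED`, and the delay now comes from `base_delay * random.uniform(10, 150)`. The constants and the one-line classifier here are stand-ins for illustration only:

```python
import logging
import random
import time

MAX_RETRIES = 3
BASE_DELAY = 2.0

def get_delay() -> float:
    return BASE_DELAY * random.uniform(10, 150)

def call_with_retries(fn):
    for attempt in range(MAX_RETRIES + 1):
        try:
            return fn()
        except Exception as e:
            # Stand-in classifier; the real code uses _classify_error().
            code = "RATE_LIMIT_EXCEEDED" if "429" in str(e) else "GENERIC_ERROR"
            if attempt == MAX_RETRIES:
                return f"**ERROR**: MAX_RETRIES_EXCEEDED - {e}"
            if code not in ("RATE_LIMIT_EXCEEDED", "SERVER_ERROR"):
                return f"**ERROR**: {code} - {e}"  # fail fast on non-retryable errors
            delay = get_delay()
            logging.warning(f"Error: {code}. Retrying in {delay:.2f} seconds... "
                            f"(Attempt {attempt + 1}/{MAX_RETRIES})")
            time.sleep(delay)
```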
def _verbose_tool_use(self, name, args, res):
return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
return "<tool_call>" + json.dumps({
"name": name,
"args": args,
"result": res
}, ensure_ascii=False, indent=2) + "</tool_call>"
def _append_history(self, hist, tool_call, tool_res):
hist.append(
@ -172,17 +210,14 @@ class Base(ABC):
if not (toolcall_session and tools):
return
self.is_tools = True
self.toolcall_session = toolcall_session
self.tools = tools
for tool in tools:
self.toolcall_sessions[tool["function"]["name"]] = toolcall_session
self.tools.append(tool)
def chat_with_tools(self, system: str, history: list, gen_conf: dict):
def chat_with_tools(self, system: str, history: list, gen_conf: dict={}):
gen_conf = self._clean_conf(gen_conf)
if system:
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
tk_count = 0
hist = deepcopy(history)
@ -190,8 +225,9 @@ class Base(ABC):
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds * 2):
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
for _ in range(self.max_rounds+1):
logging.info(f"{self.tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
tk_count += self.total_token_count(response)
if any([not response.choices, not response.choices[0].message]):
raise Exception(f"500 response structure error. Response: {response}")
@ -207,10 +243,11 @@ class Base(ABC):
return ans, tk_count
for tool_call in response.choices[0].message.tool_calls:
logging.info(f"Response {tool_call=}")
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_sessions[name].tool_call(name, args)
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
ans += self._verbose_tool_use(name, args, tool_response)
except Exception as e:
@ -218,13 +255,20 @@ class Base(ABC):
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
ans += self._verbose_tool_use(name, {}, str(e))
logging.warning( f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response, token_count = self._chat(history, gen_conf)
ans += response
tk_count += token_count
return ans, tk_count
except Exception as e:
e = self._exceptions(e, attempt)
if e:
return e, tk_count
assert False, "Shouldn't be here."
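
Conceptually, `chat_with_tools` now runs a round-limited loop: up to `max_rounds + 1` model calls with `tool_choice="auto"`, and once the budget is exhausted it appends an "Exceed max rounds" user message and asks for a plain, tool-free answer. A stripped-down sketch, with hypothetical `llm_step`/`run_tool` callables standing in for the OpenAI client and tool-session calls:

```python
MAX_ROUNDS = 5

def agent_loop(history, llm_step, run_tool):
    # llm_step(history) -> {"content": str, "tool_calls": [{"name": ..., "args": ...}, ...]}
    answer = ""
    for _ in range(MAX_ROUNDS + 1):
        message = llm_step(history)
        if not message.get("tool_calls"):
            return answer + message["content"]
        for call in message["tool_calls"]:
            result = run_tool(call["name"], call["args"])
            history.append({"role": "tool", "content": str(result)})
    # Budget exhausted: force a final answer without tools.
    history.append({"role": "user", "content": f"Exceed max rounds: {MAX_ROUNDS}"})
    return answer + llm_step(history)["content"]
```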
def chat(self, system, history, gen_conf):
def chat(self, system, history, gen_conf={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
@ -232,7 +276,7 @@ class Base(ABC):
# Implement exponential backoff retry strategy
for attempt in range(self.max_retries + 1):
try:
return self._chat(history, gen_conf)
return self._chat(history, gen_conf, **kwargs)
except Exception as e:
e = self._exceptions(e, attempt)
if e:
@ -253,7 +297,7 @@ class Base(ABC):
return final_tool_calls
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict={}):
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system:
@ -265,9 +309,10 @@ class Base(ABC):
for attempt in range(self.max_retries + 1):
history = hist
try:
for _ in range(self.max_rounds * 2):
for _ in range(self.max_rounds+1):
reasoning_start = False
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
logging.info(f"{tools=}")
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
final_tool_calls = {}
answer = ""
for resp in response:
@ -319,7 +364,8 @@ class Base(ABC):
name = tool_call.function.name
try:
args = json_repair.loads(tool_call.function.arguments)
tool_response = self.toolcall_session[name].tool_call(name, args)
yield self._verbose_tool_use(name, args, "Begin to call...")
tool_response = self.toolcall_session.tool_call(name, args)
history = self._append_history(history, tool_call, tool_response)
yield self._verbose_tool_use(name, args, tool_response)
except Exception as e:
@ -327,51 +373,45 @@ class Base(ABC):
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
yield self._verbose_tool_use(name, {}, str(e))
logging.warning( f"Exceed max rounds: {self.max_rounds}")
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
raise Exception("500 response structure error.")
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
continue
tol = self.total_token_count(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens += tol
answer += resp.choices[0].delta.content
yield resp.choices[0].delta.content
yield total_tokens
return
except Exception as e:
e = self._exceptions(e, attempt)
if e:
yield e
yield total_tokens
return
yield total_tokens
assert False, "Shouldn't be here."
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf: dict={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
ans = ""
total_tokens = 0
reasoning_start = False
try:
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices:
continue
if not resp.choices[0].delta.content:
resp.choices[0].delta.content = ""
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
ans = ""
if not reasoning_start:
reasoning_start = True
ans = "<think>"
ans += resp.choices[0].delta.reasoning_content + "</think>"
else:
reasoning_start = False
ans = resp.choices[0].delta.content
tol = self.total_token_count(resp)
if not tol:
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
else:
total_tokens += tol
if resp.choices[0].finish_reason == "length":
if is_chinese(ans):
ans += LENGTH_NOTIFICATION_CN
else:
ans += LENGTH_NOTIFICATION_EN
yield ans
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
yield delta_ans
total_tokens += tol
except openai.APIError as e:
yield ans + "\n**ERROR**: " + str(e)
@ -514,7 +554,7 @@ class BaiChuanChat(Base):
"top_p": gen_conf.get("top_p", 0.85),
}
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
@ -529,7 +569,7 @@ class BaiChuanChat(Base):
ans += LENGTH_NOTIFICATION_EN
return ans, self.total_token_count(response)
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
@ -614,7 +654,7 @@ class ZhipuChat(Base):
return super().chat_with_tools(system, history, gen_conf)
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
@ -626,6 +666,7 @@ class ZhipuChat(Base):
ans = ""
tk_count = 0
try:
logging.info(json.dumps(history, ensure_ascii=False, indent=2))
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
for resp in response:
if not resp.choices[0].delta.content:
@ -675,7 +716,7 @@ class OllamaChat(Base):
options[k] = gen_conf[k]
return options
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
# Calculate context size
ctx_size = self._calculate_dynamic_ctx(history)
@ -685,7 +726,7 @@ class OllamaChat(Base):
token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
return ans, token_count
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
@ -766,7 +807,7 @@ class LocalLLM(Base):
yield answer + "\n**ERROR**: " + str(e)
yield num_tokens_from_string(answer)
def chat(self, system, history, gen_conf):
def chat(self, system, history, gen_conf={}, **kwargs):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
prompt = self._prepare_prompt(system, history, gen_conf)
@ -775,7 +816,7 @@ class LocalLLM(Base):
total_tokens = next(chat_gen)
return ans, total_tokens
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
prompt = self._prepare_prompt(system, history, gen_conf)
@ -894,7 +935,7 @@ class MistralChat(Base):
del gen_conf[k]
return gen_conf
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
response = self.client.chat(model=self.model_name, messages=history, **gen_conf)
ans = response.choices[0].message.content
if response.choices[0].finish_reason == "length":
@ -904,7 +945,7 @@ class MistralChat(Base):
ans += LENGTH_NOTIFICATION_EN
return ans, self.total_token_count(response)
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
for k in list(gen_conf.keys()):
@ -913,7 +954,7 @@ class MistralChat(Base):
ans = ""
total_tokens = 0
try:
response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf)
response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf, **kwargs)
for resp in response:
if not resp.choices or not resp.choices[0].delta.content:
continue
@ -957,7 +998,7 @@ class BedrockChat(Base):
del gen_conf[k]
return gen_conf
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
hist = []
for item in history:
@ -978,7 +1019,7 @@ class BedrockChat(Base):
ans = response["output"]["message"]["content"][0]["text"]
return ans, num_tokens_from_string(ans)
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
from botocore.exceptions import ClientError
for k in list(gen_conf.keys()):
@ -1036,7 +1077,7 @@ class GeminiChat(Base):
del gen_conf[k]
return gen_conf
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
from google.generativeai.types import content_types
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
@ -1059,7 +1100,7 @@ class GeminiChat(Base):
ans = response.text
return ans, response.usage_metadata.total_token_count
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
from google.generativeai.types import content_types
gen_conf = self._clean_conf(gen_conf)
@ -1101,7 +1142,7 @@ class GroqChat(Base):
del gen_conf[k]
return gen_conf
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
for k in list(gen_conf.keys()):
@ -1229,7 +1270,7 @@ class CoHereChat(Base):
response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
)
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if system:
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
@ -1348,7 +1389,7 @@ class ReplicateChat(Base):
self.model_name = model_name
self.client = Client(api_token=key)
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:] if item["role"] != "system"])
response = self.client.run(
@ -1358,7 +1399,7 @@ class ReplicateChat(Base):
ans = "".join(response)
return ans, num_tokens_from_string(ans)
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
@ -1402,7 +1443,7 @@ class HunyuanChat(Base):
_gen_conf["TopP"] = gen_conf["top_p"]
return _gen_conf
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
from tencentcloud.hunyuan.v20230901 import models
hist = [{k.capitalize(): v for k, v in item.items()} for item in history]
@ -1413,7 +1454,7 @@ class HunyuanChat(Base):
ans = response.Choices[0].Message.Content
return ans, response.Usage.TotalTokens
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
TencentCloudSDKException,
)
@ -1504,7 +1545,7 @@ class BaiduYiyanChat(Base):
ans = response["result"]
return ans, self.total_token_count(response)
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
@ -1588,7 +1629,7 @@ class GoogleChat(Base):
del gen_conf[k]
return gen_conf
def _chat(self, history, gen_conf):
def _chat(self, history, gen_conf={}, **kwargs):
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
if "claude" in self.model_name:
response = self.client.messages.create(
@ -1626,7 +1667,7 @@ class GoogleChat(Base):
ans = response.text
return ans, response.usage_metadata.total_token_count
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if "claude" in self.model_name:
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]

File diff suppressed because it is too large.

View File

@ -89,7 +89,7 @@ class DefaultRerank(Base):
torch.cuda.empty_cache()
except Exception as e:
print(f"Error emptying cache: {e}")
log_exception(e)
def _process_batch(self, pairs, max_batch_size=None):
"""template method for subclass call"""

View File

@ -518,7 +518,8 @@ def hierarchical_merge(bull, sections, depth):
return res
def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"):
def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
if not sections:
return []
if isinstance(sections[0], type("")):
@ -534,8 +535,10 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。"):
if tnum < 8:
pos = ""
# Ensure that the length of the merged chunk does not exceed chunk_token_num
if cks[-1] == "" or tk_nums[-1] > chunk_token_num:
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
if cks:
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
if t.find(pos) < 0:
t += pos
cks.append(t)
@ -548,7 +551,10 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。"):
dels = get_delimiters(delimiter)
for sec, pos in sections:
splited_sec = re.split(r"(%s)" % dels, sec)
if num_tokens_from_string(sec) < chunk_token_num:
add_chunk(sec, pos)
continue
splited_sec = re.split(r"(%s)" % dels, sec, flags=re.DOTALL)
for sub_sec in splited_sec:
if re.match(f"^{dels}$", sub_sec):
continue
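
The `naive_merge` change above adds chunk overlap: a new chunk starts once the current one exceeds `chunk_token_num * (100 - overlapped_percent) / 100`, and the tail of the previous chunk is carried over; sections already shorter than `chunk_token_num` become chunks on their own. A toy sketch of the carry-over behavior (whitespace token counting is a simplification):

```python
def merge_with_overlap(pieces, chunk_token_num=8, overlapped_percent=20):
    chunks, tk_nums = [""], [0]
    limit = chunk_token_num * (100 - overlapped_percent) / 100.0
    for piece in pieces:
        n = len(piece.split())  # crude token count, for illustration only
        if chunks[-1] == "" or tk_nums[-1] > limit:
            # Start a new chunk, prepending the last overlapped_percent of the previous one.
            overlap = chunks[-1][int(len(chunks[-1]) * (100 - overlapped_percent) / 100.0):]
            chunks.append(overlap + piece)
            tk_nums.append(n)
        else:
            chunks[-1] += piece
            tk_nums[-1] += n
    return [c for c in chunks if c]

print(merge_with_overlap(["one two three four five ",
                          "six seven eight nine ten ",
                          "eleven twelve thirteen fourteen fifteen "]))
```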

View File

@ -384,7 +384,7 @@ class Dealer:
zero_vector = [0.0] * dim
sim_np = np.array(sim)
if doc_ids:
similarity_threshold = 0
similarity_threshold = 0
filtered_count = (sim_np >= similarity_threshold).sum()
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
for i in idx:
@ -403,7 +403,7 @@ class Dealer:
ranks["doc_aggs"][dnm]["count"] += 1
continue
break
position_int = chunk.get("position_int", [])
d = {
"chunk_id": id,

rag/prompts/__init__.py (new file)
View File

@ -0,0 +1,6 @@
from . import prompts
__all__ = [name for name in dir(prompts)
if not name.startswith('_')]
globals().update({name: getattr(prompts, name) for name in __all__})
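
The new `rag/prompts/__init__.py` re-exports every public name from `prompts.py` into the package namespace, so call sites can keep short imports. A usage sketch (assuming `kb_prompt` lives in `rag/prompts/prompts.py`, as the later diff suggests):

```python
# Both imports resolve to the same function after this change.
from rag.prompts import kb_prompt
from rag.prompts.prompts import kb_prompt as kb_prompt_direct

assert kb_prompt is kb_prompt_direct
```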

View File

@ -0,0 +1,8 @@
Your responsibility is to execute assigned tasks to a high standard. Please:
1. Carefully analyze the task requirements.
2. Develop a reasonable execution plan.
3. Execute step-by-step and document the reasoning process.
4. Provide clear and accurate results.
If difficulties are encountered, clearly state the problem and explore alternative approaches.

View File

@ -0,0 +1,20 @@
Please analyze the following task:
Task: {{ task }}
Context: {{ context }}
**Analysis Requirements:**
1. Is it just small talk? (If yes, no further plan or analysis is needed.)
2. What is the core objective of the task?
3. What is the complexity level of the task?
4. What types of specialized skills are required?
5. Does the task need to be decomposed into subtasks? (If yes, propose the subtask structure)
6. How can we tell that the task or its subtasks cannot lead to success after a few rounds of interaction?
7. What are the expected success criteria?
**Available Sub-Agents and Their Specializations:**
{{ tools_desc }}
Provide a detailed analysis of the task based on the above requirements.

View File

@ -0,0 +1,13 @@
You are an agent that adds correct citations to the text given by the user.
You are given a piece of text that was generated from the provided sources, each tagged with [ID:<ID>];
however, those sources are not yet cited in the text.
Your task is to enhance user trust by generating correct, appropriate citations for this report.
{{ example }}
<context>
{{ sources }}
</context>

View File

@ -1,46 +1,108 @@
## Citation Requirements
Based on the provided document or chat history, add citations to the input text using the format specified later.
- Use a uniform citation format such as [ID:i] [ID:j], where "i" and "j" are document IDs enclosed in square brackets. Separate multiple IDs with spaces (e.g., [ID:0] [ID:1]).
- Citation markers must be placed at the end of a sentence, separated by a space from the final punctuation (e.g., period, question mark).
- A maximum of 4 citations are allowed per sentence.
- DO NOT insert citations if the content is not from retrieved chunks.
- DO NOT use standalone Document IDs (e.g., #ID#).
- Citations MUST always follow the [ID:i] format.
- STRICTLY prohibit the use of strikethrough symbols (e.g., ~~) or any other non-standard formatting syntax.
- Any violation of the above rules — including incorrect formatting, prohibited styles, or unsupported citations — will result in no citation being added for that sentence.
# Citation Requirements:
---
## Technical Rules:
- Use format: [ID:i] or [ID:i] [ID:j] for multiple sources
- Place citations at the end of sentences, before punctuation
- Maximum 4 citations per sentence
- DO NOT cite content not from <context></context>
- DO NOT modify whitespace or original text
- STRICTLY prohibit non-standard formatting (~~, etc.)
## Example START
## What MUST Be Cited:
1. **Quantitative data**: Numbers, percentages, statistics, measurements
2. **Temporal claims**: Dates, timeframes, sequences of events
3. **Causal relationships**: Claims about cause and effect
4. **Comparative statements**: Rankings, comparisons, superlatives
5. **Technical definitions**: Specialized terms, concepts, methodologies
6. **Direct attributions**: What someone said, did, or believes
7. **Predictions/forecasts**: Future projections, trend analyses
8. **Controversial claims**: Disputed facts, minority opinions
<SYSTEM>: Here is the knowledge base:
## What Should NOT Be Cited:
- Common knowledge (e.g., "The sun rises in the east")
- Transitional phrases
- General introductions
- Your own analysis or synthesis (unless directly from source)
Document: Elon Musk Breaks Silence on Crypto, Warns Against Dogecoin ...
URL: https://blockworks.co/news/elon-musk-crypto-dogecoin
ID: 0
The Tesla co-founder advised against going all-in on dogecoin, but Elon Musk said its still his favorite crypto...
# Comprehensive Examples:
Document: Elon Musk's Dogecoin tweet sparks social media frenzy
ID: 1
Musk said he is 'willing to serve' D.O.G.E. shorthand for Dogecoin.
## Example 1: Data and Statistics
<context>
ID: 45
└── Content: The global smartphone market grew by 7.8% in Q3 2024, with Samsung holding 20.6% market share and Apple at 15.8%.
Document: Causal effect of Elon Musk tweets on Dogecoin price
ID: 2
If you think of Dogecoin — the cryptocurrency based on a meme — you cant help but also think of Elon Musk...
ID: 46
└── Content: 5G adoption reached 1.5 billion users worldwide by October 2024.
</context>
Document: Elon Musk's Tweet Ignites Dogecoin's Future In Public Services
ID: 3
The market is heating up after Elon Musk's announcement about Dogecoin. Is this a new era for crypto?...
USER: How is the smartphone market performing?
The above is the knowledge base.
ASSISTANT:
The smartphone industry is showing strong recovery. The global smartphone market grew by 7.8% in Q3 2024 [ID:45]. Samsung leads with 20.6% market share, followed by Apple at 15.8% [ID:45]. This growth is partly driven by 5G technology, which reached 1.5 billion users worldwide by October 2024 [ID:46].
<USER>: What's Elon's view on dogecoin?
## Example 2: Causal Relationships and Predictions
<context>
ID: 78
└── Content: Rising interest rates have led to a 15% decline in housing starts. Economists predict this trend will continue through 2025.
<ASSISTANT>:
Musk has consistently expressed his fondness for Dogecoin, often citing its humor and the inclusion of dogs in its branding. He has referred to it as his favorite cryptocurrency [ID:0] [ID:1].
ID: 79
└── Content: The Federal Reserve's aggressive monetary policy is the primary driver behind the real estate slowdown.
</context>
Recently, Musk has hinted at potential future roles for Dogecoin. His tweets have sparked speculation about Dogecoin's potential integration into public services [ID:3].
USER: What's happening in the real estate market?
Overall, while Musk enjoys Dogecoin and often promotes it, he also warns against over-investing in it, reflecting both his personal amusement and caution regarding its speculative nature.
ASSISTANT:
The real estate market is experiencing significant challenges. Rising interest rates have led to a 15% decline in housing starts [ID:78]. The Federal Reserve's aggressive monetary policy is the primary driver behind this real estate slowdown [ID:79]. Economists predict this trend will continue through 2025 [ID:78]. These factors suggest a prolonged period of market adjustment ahead.
## Example END
## Example 3: Technical Concepts and Comparisons
<context>
ID: 101
└── Content: Quantum computing uses quantum bits (qubits) that can exist in superposition, unlike classical bits which are either 0 or 1.
ID: 102
└── Content: Current quantum computers can solve certain problems 158 million times faster than classical supercomputers.
</context>
USER: Explain quantum computing advantages.
ASSISTANT:
Quantum computing represents a fundamental shift in computational capability. Unlike classical computers, quantum computing uses quantum bits (qubits) that can exist in superposition, rather than being limited to just 0 or 1 [ID:101]. This unique property enables extraordinary performance gains. Current quantum computers can solve certain problems 158 million times faster than classical supercomputers [ID:102]. This technology promises to revolutionize fields requiring complex calculations.
## Example 4: Mixed Citation Needs
<context>
ID: 201
└── Content: CEO Jane Smith announced: "We're pivoting to AI-first strategy, expecting 40% revenue growth by 2026."
ID: 202
└── Content: Industry analysts remain skeptical, with Morgan Stanley rating the stock as "underperform."
ID: 203
└── Content: The company's AI division currently represents only 5% of total revenue.
</context>
USER: What's the company's new strategy?
ASSISTANT:
The company is undergoing a major strategic transformation. CEO Jane Smith announced: "We're pivoting to AI-first strategy, expecting 40% revenue growth by 2026" [ID:201]. However, this ambitious goal faces skepticism, as industry analysts remain doubtful, with Morgan Stanley rating the stock as "underperform" [ID:202]. The challenge is significant given that the company's AI division currently represents only 5% of total revenue [ID:203]. Success will depend on execution of this bold vision.
## Example 5: When NOT to Cite
<context>
ID: 301
└── Content: Paris is the capital of France. The Eiffel Tower was completed in 1889.
</context>
USER: Tell me about Paris.
ASSISTANT:
Paris is the capital of France. It's known for its rich history, culture, and architecture. The Eiffel Tower was completed in 1889 [ID:301]. The city attracts millions of tourists annually. Paris remains one of the world's most visited destinations.
(Note: Only the specific date needs citation, not common knowledge about Paris)
--- Examples END ---
REMEMBER:
- Cite FACTS, not opinions or transitions
- Each citation supports the ENTIRE sentence
- When in doubt, ask: "Would a fact-checker need to verify this?"
- Place citations at sentence end, before punctuation

rag/prompts/next_step.md (new file)
View File

@ -0,0 +1,63 @@
You are an expert Planning Agent tasked with solving problems efficiently through structured plans.
Your job is:
1. Based on the task analysis, choose the right tools to execute.
2. Track progress and adapt plans (tool calls) when necessary.
3. Use `complete_task` when there is no further step to take with the tools (all necessary steps are done, or there is little hope of success).
# ========== TASK ANALYSIS =============
{{ task_analisys }}
# ========== TOOLS (JSON-Schema) ==========
You may invoke only the tools listed below.
Return a JSON array of objects, each with exactly two top-level keys:
• "name": the tool to call
• "arguments": an object whose keys/values satisfy the schema
{{ desc }}
# ========== RESPONSE FORMAT ==========
**When you need a tool**
Return ONLY the JSON (no additional keys, no commentary; end with `<|stop|>`), as in the following:
[{
"name": "<tool_name1>",
"arguments": { /* tool arguments matching its schema */ }
},{
"name": "<tool_name2>",
"arguments": { /* tool arguments matching its schema */ }
}...]<|stop|>
**When you are certain the task is solved OR no further information can be obtained**
Return ONLY:
[{
"name": "complete_task",
"arguments": { "answer": "<final answer text>" }
}]<|stop|>
<verification_steps>
Before providing a final answer:
1. Double-check all gathered information
2. Verify calculations and logic
3. Ensure answer matches exactly what was asked
4. Confirm answer format meets requirements
5. Run additional verification if confidence is not 100%
</verification_steps>
<error_handling>
If you encounter issues:
1. Try alternative approaches before giving up
2. Use different tools or combinations of tools
3. Break complex problems into simpler sub-tasks
4. Verify intermediate results frequently
5. Never return "I cannot answer" without exhausting all options
</error_handling>
⚠️ Any output that is not valid JSON or that contains extra fields will be rejected.
# ========== REASONING & REFLECTION ==========
You may think privately (not shown to the user) before producing each JSON object.
Internal guideline:
1. **Reason**: Analyse the user question; decide which tools (if any) are needed.
2. **Act**: Emit the JSON object to call the tool.
Today is {{ today }}. Remember that success in answering questions accurately is paramount - take all necessary steps to ensure your answer is correct.

View File

@ -1,8 +1,7 @@
import os
BASE_DIR = os.path.dirname(__file__)
PROMPT_DIR = os.path.join(BASE_DIR, "prompts")
PROMPT_DIR = os.path.dirname(__file__)
_loaded_prompts = {}

View File

@ -17,19 +17,25 @@ import datetime
import json
import logging
import re
from collections import defaultdict
from copy import deepcopy
from typing import Tuple
import jinja2
import json_repair
from rag.prompt_template import load_prompt
from api.utils import hash_str2int
from rag.prompts.prompt_template import load_prompt
from rag.settings import TAG_FLD
from rag.utils import encoder, num_tokens_from_string
STOP_TOKEN="<|STOP|>"
COMPLETE_TASK="complete_task"
def get_value(d, k1, k2):
return d.get(k1, d.get(k2))
def chunks_format(reference):
def get_value(d, k1, k2):
return d.get(k1, d.get(k2))
return [
{
@ -87,14 +93,16 @@ def message_fit_in(msg, max_length=4000):
return max_length, msg
def kb_prompt(kbinfos, max_tokens):
def kb_prompt(kbinfos, max_tokens, hash_id=False):
from api.db.services.document_service import DocumentService
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
knowledges = [get_value(ck, "content", "content_with_weight") for ck in kbinfos["chunks"]]
kwlg_len = len(knowledges)
used_token_count = 0
chunks_num = 0
for i, c in enumerate(knowledges):
if not c:
continue
used_token_count += num_tokens_from_string(c)
chunks_num += 1
if max_tokens * 0.97 < used_token_count:
@ -102,29 +110,30 @@ def kb_prompt(kbinfos, max_tokens):
logging.warning(f"Not all the retrieval into prompt: {len(knowledges)}/{kwlg_len}")
break
docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
docs = DocumentService.get_by_ids([get_value(ck, "doc_id", "document_id") for ck in kbinfos["chunks"][:chunks_num]])
docs = {d.id: d.meta_fields for d in docs}
doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
cnt += re.sub(r"( style=\"[^\"]+\"|</?(html|body|head|title)>|<!DOCTYPE html>)", " ", ck["content_with_weight"], flags=re.DOTALL | re.IGNORECASE)
doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt)
doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
def draw_node(k, line):
if not line:
return ""
return f"\n├── {k}: " + re.sub(r"\n+", " ", line, flags=re.DOTALL)
knowledges = []
for nm, cks_meta in doc2chunks.items():
txt = f"\nDocument: {nm} \n"
for k, v in cks_meta["meta"].items():
txt += f"{k}: {v}\n"
txt += "Relevant fragments as following:\n"
for i, chunk in enumerate(cks_meta["chunks"], 1):
txt += f"{chunk}\n"
knowledges.append(txt)
for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 100))
cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name"))
cnt += draw_node("URL", ck['url']) if "url" in ck else ""
for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items():
cnt += draw_node(k, v)
cnt += "\n└── Content:\n"
cnt += get_value(ck, "content", "content_with_weight")
knowledges.append(cnt)
return knowledges
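
For orientation, a small sketch of the tree-style entry `kb_prompt` now emits per chunk (one `ID` node with `├──` metadata branches and a `└── Content` leaf); the field values below are hypothetical:

```python
import re

def draw_node(k, line):
    if not line:
        return ""
    return f"\n├── {k}: " + re.sub(r"\n+", " ", line, flags=re.DOTALL)

entry = "\nID: 45"
entry += draw_node("Title", "smartphone_market_q3.pdf")
entry += draw_node("URL", "https://example.com/report")
entry += "\n└── Content:\n"
entry += "The global smartphone market grew by 7.8% in Q3 2024."
print(entry)
```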
CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
CITATION_PLUS_TEMPLATE = load_prompt("citation_plus")
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE = load_prompt("cross_languages_sys_prompt")
CROSS_LANGUAGES_USER_PROMPT_TEMPLATE = load_prompt("cross_languages_user_prompt")
@ -134,6 +143,13 @@ QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")
ANALYZE_TASK_SYSTEM = load_prompt("analyze_task_system")
ANALYZE_TASK_USER = load_prompt("analyze_task_user")
NEXT_STEP = load_prompt("next_step")
REFLECT = load_prompt("reflect")
SUMMARY4MEMORY = load_prompt("summary4memory")
RANK_MEMORY = load_prompt("rank_memory")
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
@ -142,6 +158,11 @@ def citation_prompt() -> str:
return template.render()
def citation_plus(sources: str) -> str:
template = PROMPT_JINJA_ENV.from_string(CITATION_PLUS_TEMPLATE)
return template.render(example=citation_prompt(), sources=sources)
def keyword_extraction(chat_mdl, content, topn=3):
template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE)
rendered_prompt = template.render(content=content, topn=topn)
@ -172,15 +193,16 @@ def question_proposal(chat_mdl, content, topn=3):
return kwd
def full_question(tenant_id, llm_id, messages, language=None):
def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
from api.db import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.llm_service import TenantLLMService
if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
else:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
if not chat_mdl:
if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
else:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
conv = []
for m in messages:
if m["role"] not in ["user", "assistant"]:
@ -200,7 +222,7 @@ def full_question(tenant_id, llm_id, messages, language=None):
language=language,
)
ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}])
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
@ -278,13 +300,116 @@ def vision_llm_figure_describe_prompt() -> str:
return template.render()
if __name__ == "__main__":
print(CITATION_PROMPT_TEMPLATE)
print(CONTENT_TAGGING_PROMPT_TEMPLATE)
print(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE)
print(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE)
print(FULL_QUESTION_PROMPT_TEMPLATE)
print(KEYWORD_PROMPT_TEMPLATE)
print(QUESTION_PROMPT_TEMPLATE)
print(VISION_LLM_DESCRIBE_PROMPT)
print(VISION_LLM_FIGURE_DESCRIBE_PROMPT)
def tool_schema(tools_description: list[dict], complete_task=False):
if not tools_description:
return ""
desc = {}
if complete_task:
desc[COMPLETE_TASK] = {
"type": "function",
"function": {
"name": COMPLETE_TASK,
"description": "When you have the final answer and are ready to complete the task, call this function with your answer",
"parameters": {
"type": "object",
"properties": {"answer":{"type":"string", "description": "The final answer to the user's question"}},
"required": ["answer"]
}
}
}
for tool in tools_description:
desc[tool["function"]["name"]] = tool
return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])
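
To make the rendered schema concrete, here is a minimal illustration of what `tool_schema` produces when `complete_task` is enabled and no other tools are registered; the description strings are abbreviated assumptions:

```python
import json

complete_task = {
    "type": "function",
    "function": {
        "name": "complete_task",
        "description": "Call this with your final answer when the task is done.",
        "parameters": {
            "type": "object",
            "properties": {"answer": {"type": "string"}},
            "required": ["answer"],
        },
    },
}
desc = {"complete_task": complete_task}
print("\n\n".join(f"## {i + 1}. {name}\n{json.dumps(spec, ensure_ascii=False, indent=4)}"
                  for i, (name, spec) in enumerate(desc.items())))
```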
def form_history(history, limit=-6):
context = ""
for h in history[limit:]:
if h["role"] == "system":
continue
role = "USER"
if h["role"].upper()!= role:
role = "AGENT"
context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content'])>2048 else '')}"
return context
def analyze_task(chat_mdl, task_name, tools_description: list[dict]):
tools_desc = tool_schema(tools_description)
context = ""
template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_USER)
kwd = chat_mdl.chat(ANALYZE_TASK_SYSTEM,[{"role": "user", "content": template.render(task=task_name, context=context, tools_desc=tools_desc)}], {})
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
if kwd.find("**ERROR**") >= 0:
return ""
return kwd
def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc):
if not tools_description:
return ""
desc = tool_schema(tools_description)
template = PROMPT_JINJA_ENV.from_string(NEXT_STEP)
user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
hist = deepcopy(history)
if hist[-1]["role"] == "user":
hist[-1]["content"] += user_prompt
else:
hist.append({"role": "user", "content": user_prompt})
json_str = chat_mdl.chat(template.render(task_analisys=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
hist[1:], stop=["<|stop|>"])
tk_cnt = num_tokens_from_string(json_str)
json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
return json_str, tk_cnt
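
The string returned by `next_step` is expected to be a JSON array of tool calls terminated by the `<|stop|>` stop sequence; downstream agent code can parse it with a tolerant parser such as `json_repair` (already used elsewhere in this diff). A hypothetical round-trip:

```python
import json_repair

# Hypothetical model output; the stop sequence itself is consumed by the LLM client.
raw = '[{"name": "web_search", "arguments": {"query": "dogecoin price"}}]'
for call in json_repair.loads(raw):
    print(call["name"], call["arguments"])
```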
def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple]):
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
goal = history[1]["content"]
template = PROMPT_JINJA_ENV.from_string(REFLECT)
user_prompt = template.render(goal=goal, tool_calls=tool_calls)
hist = deepcopy(history)
if hist[-1]["role"] == "user":
hist[-1]["content"] += user_prompt
else:
hist.append({"role": "user", "content": user_prompt})
_, msg = message_fit_in(hist, chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return """
**Observation**
{}
**Reflection**
{}
""".format(json.dumps(tool_calls, ensure_ascii=False, indent=2), ans)
def form_message(system_prompt, user_prompt):
return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}]
def tool_call_summary(chat_mdl, name: str, params: dict, result: str) -> str:
template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
system_prompt = template.render(name=name,
params=json.dumps(params, ensure_ascii=False, indent=2),
result=result)
user_prompt = "→ Summary: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str]):
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
user_prompt = " → rank: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)

View File

@ -0,0 +1,30 @@
**Task**: Sort the tool call results based on relevance to the overall goal and current sub-goal. Return ONLY a sorted list of indices (0-indexed).
**Rules**:
1. Analyze each result's contribution to both:
- The overall goal (primary priority)
- The current sub-goal (secondary priority)
2. Sort from MOST relevant (highest impact) to LEAST relevant
3. Output format: Strictly a Python-style list of integers. Example: [2, 0, 1]
🔹 Overall Goal: {{ goal }}
🔹 Sub-goal: {{ sub_goal }}
**Examples**:
🔹 Tool Response:
- index: 0
> Tokyo temperature is 78°F.
- index: 1
> Error: Authentication failed (expired API key).
- index: 2
> Available: 12 widgets in stock (max 5 per customer).
→ rank: [1,2,0]<|stop|>
**Your Turn**:
🔹 Tool Response:
{% for f in results %}
- index: {{ f.i }}
> {{ f.content }}
{% endfor %}

rag/prompts/reflect.md (new file)
View File

@ -0,0 +1,34 @@
**Context**:
- To achieve the goal: {{ goal }}.
- You have executed following tool calls:
{% for call in tool_calls %}
Tool call: `{{ call.name }}`
Results: {{ call.result }}
{% endfor %}
**Reflection Instructions:**
Analyze the current state of the overall task ({{ goal }}), then provide structured responses to the following:
## 1. Goal Achievement Status
- Does the current outcome align with the original purpose of this task phase?
- If not, what critical gaps exist?
## 2. Step Completion Check
- Which planned steps were completed? (List verified items)
- Which steps are pending/incomplete? (Specify exactly what's missing)
## 3. Information Adequacy
- Is the collected data sufficient to proceed?
- What key information is still needed? (e.g., metrics, user input, external data)
## 4. Critical Observations
- Unexpected outcomes: [Flag anomalies/errors]
- Risks/blockers: [Identify immediate obstacles]
- Accuracy concerns: [Highlight unreliable results]
## 5. Next-Step Recommendations
- Proposed immediate action: [Concrete next step]
- Alternative strategies if blocked: [Workaround solution]
- Tools/inputs required for next phase: [Specify resources]

View File

@ -0,0 +1,35 @@
**Role**: AI Assistant
**Task**: Summarize tool call responses
**Rules**:
1. Context: You've executed a tool (API/function) and received a response.
2. Condense the response into 1-2 short sentences.
3. Never omit:
- Success/error status
- Core results (e.g., data points, decisions)
- Critical constraints (e.g., limits, conditions)
4. Exclude technical details like timestamps/request IDs unless crucial.
5. Use the same language as the main content of the tool response.
**Response Template**:
"[Status] + [Key Outcome] + [Critical Constraints]"
**Examples**:
🔹 Tool Response:
{"status": "success", "temperature": 78.2, "unit": "F", "location": "Tokyo", "timestamp": 16923456}
→ Summary: "Success: Tokyo temperature is 78°F."
🔹 Tool Response:
{"error": "invalid_api_key", "message": "Authentication failed: expired key"}
→ Summary: "Error: Authentication failed (expired API key)."
🔹 Tool Response:
{"available": true, "inventory": 12, "product": "widget", "limit": "max 5 per customer"}
→ Summary: "Available: 12 widgets in stock (max 5 per customer)."
**Your Turn**:
- Tool call: {{ name }}
- Tool inputs as following:
{{ params }}
- Tool Response:
{{ result }}

View File

@ -0,0 +1,19 @@
**Task Instruction:**
You are tasked with reading and analyzing tool call results based on the following inputs: **Inputs for current call** and **Results**. Your objective is to extract relevant and helpful information for **Inputs for current call** from the **Results** and seamlessly integrate this information into the previous steps to continue reasoning for the original question.
**Guidelines:**
1. **Analyze the Results:**
- Carefully review the content of each tool call result.
- Identify factual information that is relevant to the **Inputs for current call** and can aid in the reasoning process for the original question.
2. **Extract Relevant Information:**
- Select the information from the **Results** that directly contributes to advancing the previous reasoning steps.
- Ensure that the extracted information is accurate and relevant.
- **Inputs for current call:**
{{ inputs }}
- **Results:**
{{ results }}

View File

@ -239,7 +239,17 @@ def shutdown_all_mcp_sessions():
logging.info("All MCPToolCallSession instances have been closed.")
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]:
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool|dict) -> dict[str, Any]:
if isinstance(mcp_tool, dict):
return {
"type": "function",
"function": {
"name": mcp_tool["name"],
"description": mcp_tool["description"],
"parameters": mcp_tool["inputSchema"],
},
}
return {
"type": "function",
"function": {