Mirror of https://github.com/infiniflow/ragflow.git, synced 2025-12-22 06:06:40 +08:00.
Refa: cleanup synchronous functions in agent_with_tools (#11736)

### What problem does this PR solve?

Cleanup synchronous functions in agent_with_tools.

### Type of change

- [x] Refactoring
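In outline: the old code kept synchronous implementations (`analyze_task`, `next_step`, `reflect`, `rank_memories`) and paired each with an `*_async` wrapper that offloaded it via `asyncio.to_thread`; this PR deletes the sync versions and makes the `*_async` functions coroutines in their own right, routing every model call through a new `_chat_async` helper. A minimal sketch of the two shapes, with simplified signatures (the real functions take more parameters):

```python
import asyncio

# Before: each async entry point was a thin wrapper that pushed the
# synchronous implementation onto a worker thread.
def analyze_task(chat_mdl, context):
    return chat_mdl.chat(context, [{"role": "user", "content": "Please analyze it."}])

async def analyze_task_async_old(chat_mdl, context):
    return await asyncio.to_thread(analyze_task, chat_mdl, context)

# After: the async entry point is the implementation, and every model call
# goes through a shared dispatcher that only falls back to a thread when the
# model has no native async API.
async def _chat_async(chat_mdl, system, history, **kwargs):
    chat_async = getattr(chat_mdl, "async_chat", None)
    if chat_async and asyncio.iscoroutinefunction(chat_async):
        return await chat_async(system, history, **kwargs)
    return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs)

async def analyze_task_async(chat_mdl, context):
    return await _chat_async(chat_mdl, context, [{"role": "user", "content": "Please analyze it."}])
```

The dispatcher means a backend that exposes a native `async_chat` coroutine is awaited directly, and only sync-only backends pay for a worker thread.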
```diff
@@ -343,7 +343,8 @@ def form_history(history, limit=-6):
     return context
 
 
-def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
+async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
     tools_desc = tool_schema(tools_description)
     context = ""
@@ -352,7 +353,7 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
     else:
         template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
     context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
-    kwd = chat_mdl.chat(context, [{"role": "user", "content": "Please analyze it."}])
+    kwd = await _chat_async(chat_mdl, context, [{"role": "user", "content": "Please analyze it."}])
     if isinstance(kwd, tuple):
         kwd = kwd[0]
     kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
```
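One detail worth calling out: both the old and new bodies scrub a reasoning-model preamble from the reply with the same `re.sub` call. A standalone illustration of what that regex does:

```python
import re

reply = "<think>\nThe user wants task keywords...\n</think>\nkeywords: retrieval, agents"
# With DOTALL, `.` matches newlines too, so the greedy `^.*</think>` eats
# everything from the start of the string through the last </think> tag.
cleaned = re.sub(r"^.*</think>", "", reply, flags=re.DOTALL)
print(cleaned)  # "\nkeywords: retrieval, agents"
```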
```diff
@@ -361,13 +362,17 @@ def analyze_task(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
     return kwd
 
 
-async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
-    return await asyncio.to_thread(analyze_task, chat_mdl, prompt, task_name, tools_description, user_defined_prompts)
+async def _chat_async(chat_mdl, system: str, history: list, **kwargs):
+    chat_async = getattr(chat_mdl, "async_chat", None)
+    if chat_async and asyncio.iscoroutinefunction(chat_async):
+        return await chat_async(system, history, **kwargs)
+    return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs)
 
 
-def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
+async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
     if not tools_description:
-        return ""
+        return "", 0
     desc = tool_schema(tools_description)
     template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("plan_generation", NEXT_STEP))
     user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
```
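`_chat_async` duck-types the model object: if it exposes a coroutine named `async_chat`, that is awaited directly; otherwise the blocking `chat` runs in a worker thread via `asyncio.to_thread`. A runnable sketch with two toy model classes (both invented for the demo):

```python
import asyncio

async def _chat_async(chat_mdl, system, history, **kwargs):
    # Same dispatch logic as the helper introduced in the diff above.
    chat_async = getattr(chat_mdl, "async_chat", None)
    if chat_async and asyncio.iscoroutinefunction(chat_async):
        return await chat_async(system, history, **kwargs)
    return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs)

class SyncOnlyModel:
    def chat(self, system, history, **kwargs):
        return "reply from blocking chat()"

class AsyncModel:
    def chat(self, system, history, **kwargs):
        return "unused"
    async def async_chat(self, system, history, **kwargs):
        return "reply from native async_chat()"

async def main():
    print(await _chat_async(SyncOnlyModel(), "sys", []))  # offloaded to a thread
    print(await _chat_async(AsyncModel(), "sys", []))     # awaited directly

asyncio.run(main())
```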
```diff
@@ -376,18 +381,18 @@ def next_step(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
         hist[-1]["content"] += user_prompt
     else:
         hist.append({"role": "user", "content": user_prompt})
-    json_str = chat_mdl.chat(template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
-                             hist[1:], stop=["<|stop|>"])
+    json_str = await _chat_async(
+        chat_mdl,
+        template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
+        hist[1:],
+        stop=["<|stop|>"],
+    )
     tk_cnt = num_tokens_from_string(json_str)
     json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
     return json_str, tk_cnt
 
 
-async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
-    return await asyncio.to_thread(next_step, chat_mdl, history, tools_description, task_desc, user_defined_prompts)
-
-
-def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
+async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
     tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
     goal = history[1]["content"]
     template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
```
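Note that the early return became `return "", 0`, so `next_step_async` now always yields a `(json_str, token_count)` pair and callers can unpack it unconditionally. A hypothetical caller, assuming `next_step_async` is imported from the refactored module:

```python
async def plan_step(chat_mdl, history, tools, task_desc):
    # Hypothetical caller, just to show the uniform unpacking.
    json_str, tk_cnt = await next_step_async(chat_mdl, history, tools, task_desc)
    total_tokens = tk_cnt        # safe even on the early-return path: tk_cnt == 0
    if not json_str:             # no tools registered -> nothing to plan
        return None, total_tokens
    return json_str, total_tokens
```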
```diff
@@ -398,7 +403,7 @@ def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
     else:
         hist.append({"role": "user", "content": user_prompt})
     _, msg = message_fit_in(hist, chat_mdl.max_length)
-    ans = chat_mdl.chat(msg[0]["content"], msg[1:])
+    ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:])
     ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
     return """
 **Observation**
```
```diff
@@ -429,23 +434,15 @@ def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}):
     return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
 
 
-def rank_memories(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
+async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
     template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
     system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
     user_prompt = " → rank: "
     _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
-    ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
+    ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:], stop="<|stop|>")
     return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
 
 
-async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
-    return await asyncio.to_thread(reflect, chat_mdl, history, tool_call_res, user_defined_prompts)
-
-
-async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
-    return await asyncio.to_thread(rank_memories, chat_mdl, goal, sub_goal, tool_call_summaries, user_defined_prompts)
-
-
 def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
     meta_data_structure = {}
     for key, values in meta_data.items():
```
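With the `asyncio.to_thread` wrappers gone, `reflect_async` and `rank_memories_async` are plain coroutines, so independent model round-trips can be overlapped on one event loop instead of occupying a worker thread apiece. A minimal sketch of that payoff using stand-in coroutines (not the ragflow functions):

```python
import asyncio

async def fake_model_call(label: str, delay: float) -> str:
    # Stand-in for an awaited _chat_async round-trip.
    await asyncio.sleep(delay)
    return f"{label} done"

async def main():
    # Two independent LLM calls overlap; wall time ~= max(delays), not the sum.
    results = await asyncio.gather(
        fake_model_call("reflect", 0.2),
        fake_model_call("rank_memories", 0.3),
    )
    print(results)  # ['reflect done', 'rank_memories done']

asyncio.run(main())
```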
```diff
@@ -514,7 +511,7 @@ def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
 
     The structure variable is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
 
-    The response should be in the following JSON format:
+    The response should be in the following JSON format:
     [
         {
             "structure": <structure index, "x.x.x" or None> (string),
@@ -641,8 +638,8 @@ def toc_transformer(toc_pages, chat_mdl):
 
     The `structure` is the numeric system which represents the index of the hierarchy section in the table of contents. For example, the first section has structure index 1, the first subsection has structure index 1.1, the second subsection has structure index 1.2, etc.
     The `title` is a short phrase or a several-words term.
 
-    The response should be in the following JSON format:
+
+    The response should be in the following JSON format:
     [
         {
             "structure": <structure index, "x.x.x" or None> (string),
@@ -667,7 +664,7 @@ def toc_transformer(toc_pages, chat_mdl):
     while not (if_complete == "yes"):
         prompt = f"""
 Your task is to continue the table of contents json structure, directly output the remaining part of the json structure.
-The response should be in the following JSON format:
+The response should be in the following JSON format:
 
 The raw table of contents json structure is:
 {toc_content}
```
```diff
@@ -756,7 +753,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
 
     for chunk in chunks_res:
         titles.extend(chunk.get("toc", []))
-
+
     # Filter out entries with title == -1
     prune = len(titles) > 512
     max_len = 12 if prune else 22
```