Feat: Redesign and refactor agent module (#9113)
### What problem does this PR solve?

#9082 #6365

<u>**WARNING: this is not compatible with the `Agent` module from older versions, which means `Agent` flows built on older versions no longer work.**</u>

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
415 rag/prompts/prompts.py Normal file
@@ -0,0 +1,415 @@
#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import datetime
import json
import logging
import re
from copy import deepcopy
from typing import Tuple

import jinja2
import json_repair

from api.utils import hash_str2int
from rag.prompts.prompt_template import load_prompt
from rag.settings import TAG_FLD
from rag.utils import encoder, num_tokens_from_string


STOP_TOKEN = "<|STOP|>"
COMPLETE_TASK = "complete_task"


def get_value(d, k1, k2):
    return d.get(k1, d.get(k2))


def chunks_format(reference):
    return [
        {
            "id": get_value(chunk, "chunk_id", "id"),
            "content": get_value(chunk, "content", "content_with_weight"),
            "document_id": get_value(chunk, "doc_id", "document_id"),
            "document_name": get_value(chunk, "docnm_kwd", "document_name"),
            "dataset_id": get_value(chunk, "kb_id", "dataset_id"),
            "image_id": get_value(chunk, "image_id", "img_id"),
            "positions": get_value(chunk, "positions", "position_int"),
            "url": chunk.get("url"),
            "similarity": chunk.get("similarity"),
            "vector_similarity": chunk.get("vector_similarity"),
            "term_similarity": chunk.get("term_similarity"),
            "doc_type": chunk.get("doc_type_kwd"),
        }
        for chunk in reference.get("chunks", [])
    ]
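
# Example (hypothetical reference dict): `chunks_format` normalizes legacy and
# current key names into one schema.
#
#   ref = {"chunks": [{"chunk_id": "c1", "content_with_weight": "some text",
#                      "doc_id": "d1", "docnm_kwd": "spec.pdf"}]}
#   chunks_format(ref)[0]["document_name"]  # -> "spec.pdf"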


def message_fit_in(msg, max_length=4000):
    def count():
        nonlocal msg
        tks_cnts = []
        for m in msg:
            tks_cnts.append({"role": m["role"], "count": num_tokens_from_string(m["content"])})
        total = 0
        for m in tks_cnts:
            total += m["count"]
        return total

    c = count()
    if c < max_length:
        return c, msg

    msg_ = [m for m in msg if m["role"] == "system"]
    if len(msg) > 1:
        msg_.append(msg[-1])
    msg = msg_
    c = count()
    if c < max_length:
        return c, msg

    ll = num_tokens_from_string(msg_[0]["content"])
    ll2 = num_tokens_from_string(msg_[-1]["content"])
    if ll / (ll + ll2) > 0.8:
        m = msg_[0]["content"]
        m = encoder.decode(encoder.encode(m)[: max_length - ll2])
        msg[0]["content"] = m
        return max_length, msg

    m = msg_[-1]["content"]
    m = encoder.decode(encoder.encode(m)[: max_length - ll2])
    msg[-1]["content"] = m
    return max_length, msg
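
# Example (hypothetical messages): trim a conversation to a token budget while
# keeping the system prompt and the latest turn; returns (token_count, messages).
#
#   msgs = [{"role": "system", "content": system_prompt},
#           {"role": "user", "content": very_long_question}]
#   used, fitted = message_fit_in(msgs, max_length=4000)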


def kb_prompt(kbinfos, max_tokens, hash_id=False):
    from api.db.services.document_service import DocumentService

    knowledges = [get_value(ck, "content", "content_with_weight") for ck in kbinfos["chunks"]]
    kwlg_len = len(knowledges)
    used_token_count = 0
    chunks_num = 0
    for i, c in enumerate(knowledges):
        if not c:
            continue
        used_token_count += num_tokens_from_string(c)
        chunks_num += 1
        if max_tokens * 0.97 < used_token_count:
            knowledges = knowledges[:i]
            logging.warning(f"Not all retrieved chunks fit into the prompt: {len(knowledges)}/{kwlg_len}")
            break

    docs = DocumentService.get_by_ids([get_value(ck, "doc_id", "document_id") for ck in kbinfos["chunks"][:chunks_num]])
    docs = {d.id: d.meta_fields for d in docs}

    def draw_node(k, line):
        if not line:
            return ""
        return f"\n├── {k}: " + re.sub(r"\n+", " ", line, flags=re.DOTALL)

    knowledges = []
    for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
        cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 100))
        cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name"))
        cnt += draw_node("URL", ck["url"]) if "url" in ck else ""
        for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items():
            cnt += draw_node(k, v)
        cnt += "\n└── Content:\n"
        cnt += get_value(ck, "content", "content_with_weight")
        knowledges.append(cnt)

    return knowledges
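
# Example (hypothetical retrieval result; fetching document metadata requires
# DB access): render retrieved chunks as tree-style knowledge blocks for a
# chat prompt.
#
#   kbinfos = {"chunks": [{"id": "c1", "content_with_weight": "some text",
#                          "doc_id": "d1", "docnm_kwd": "handbook.pdf"}]}
#   blocks = kb_prompt(kbinfos, max_tokens=8192)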


CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
CITATION_PLUS_TEMPLATE = load_prompt("citation_plus")
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE = load_prompt("cross_languages_sys_prompt")
CROSS_LANGUAGES_USER_PROMPT_TEMPLATE = load_prompt("cross_languages_user_prompt")
FULL_QUESTION_PROMPT_TEMPLATE = load_prompt("full_question_prompt")
KEYWORD_PROMPT_TEMPLATE = load_prompt("keyword_prompt")
QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")

ANALYZE_TASK_SYSTEM = load_prompt("analyze_task_system")
ANALYZE_TASK_USER = load_prompt("analyze_task_user")
NEXT_STEP = load_prompt("next_step")
REFLECT = load_prompt("reflect")
SUMMARY4MEMORY = load_prompt("summary4memory")
RANK_MEMORY = load_prompt("rank_memory")

PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)


def citation_prompt() -> str:
    template = PROMPT_JINJA_ENV.from_string(CITATION_PROMPT_TEMPLATE)
    return template.render()


def citation_plus(sources: str) -> str:
    template = PROMPT_JINJA_ENV.from_string(CITATION_PLUS_TEMPLATE)
    return template.render(example=citation_prompt(), sources=sources)


def keyword_extraction(chat_mdl, content, topn=3):
    template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE)
    rendered_prompt = template.render(content=content, topn=topn)

    msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd
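
# Example (assuming `chat_mdl` is a chat-capable LLMBundle): extract up to
# `topn` keywords from a text chunk; an empty string signals a model-side error.
#
#   kws = keyword_extraction(chat_mdl, "RAGFlow redesigns its agent module.", topn=3)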


def question_proposal(chat_mdl, content, topn=3):
    template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
    rendered_prompt = template.render(content=content, topn=topn)

    msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd


def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
    from api.db import LLMType
    from api.db.services.llm_service import LLMBundle, TenantLLMService

    if not chat_mdl:
        if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
            chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
        else:
            chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
    conv = []
    for m in messages:
        if m["role"] not in ["user", "assistant"]:
            continue
        conv.append("{}: {}".format(m["role"].upper(), m["content"]))
    conversation = "\n".join(conv)
    today = datetime.date.today().isoformat()
    yesterday = (datetime.date.today() - datetime.timedelta(days=1)).isoformat()
    tomorrow = (datetime.date.today() + datetime.timedelta(days=1)).isoformat()

    template = PROMPT_JINJA_ENV.from_string(FULL_QUESTION_PROMPT_TEMPLATE)
    rendered_prompt = template.render(
        today=today,
        yesterday=yesterday,
        tomorrow=tomorrow,
        conversation=conversation,
        language=language,
    )

    ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}])
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
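
# Example (hypothetical tenant/LLM IDs): rewrite the last user turn into a
# standalone question, resolving references and relative dates from context.
#
#   q = full_question(tenant_id="t1", llm_id="qwen-max", messages=[
#       {"role": "user", "content": "Who maintains RAGFlow?"},
#       {"role": "assistant", "content": "InfiniFlow."},
#       {"role": "user", "content": "When did they first release it?"},
#   ])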


def cross_languages(tenant_id, llm_id, query, languages=[]):
    from api.db import LLMType
    from api.db.services.llm_service import LLMBundle, TenantLLMService

    if llm_id and TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
        chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
    else:
        chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
    rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)

    ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    if ans.find("**ERROR**") >= 0:
        return query
    return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()])


def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
    template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE)

    for ex in examples:
        ex["tags_json"] = json.dumps(ex[TAG_FLD], indent=2, ensure_ascii=False)

    rendered_prompt = template.render(
        topn=topn,
        all_tags=all_tags,
        examples=examples,
        content=content,
    )

    msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
    _, msg = message_fit_in(msg, chat_mdl.max_length)
    kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        raise Exception(kwd)

    try:
        obj = json_repair.loads(kwd)
    except json_repair.JSONDecodeError:
        try:
            result = kwd.replace(rendered_prompt[:-1], "").replace("user", "").replace("model", "").strip()
            result = "{" + result.split("{")[1].split("}")[0] + "}"
            obj = json_repair.loads(result)
        except Exception as e:
            logging.exception(f"JSON parsing error: {result} -> {e}")
            raise e
    res = {}
    for k, v in obj.items():
        try:
            if int(v) > 0:
                res[str(k)] = int(v)
        except Exception:
            pass
    return res
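
# Example (hypothetical tags and few-shot examples; each example must carry the
# TAG_FLD field): returns a {tag: score} dict with positive integer scores,
# e.g. {"database": 8, "search": 5}.
#
#   tags = content_tagging(chat_mdl, chunk_text,
#                          all_tags=["database", "search"],
#                          examples=few_shot_examples, topn=3)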


def vision_llm_describe_prompt(page=None) -> str:
    template = PROMPT_JINJA_ENV.from_string(VISION_LLM_DESCRIBE_PROMPT)
    return template.render(page=page)


def vision_llm_figure_describe_prompt() -> str:
    template = PROMPT_JINJA_ENV.from_string(VISION_LLM_FIGURE_DESCRIBE_PROMPT)
    return template.render()


def tool_schema(tools_description: list[dict], complete_task=False):
    if not tools_description:
        return ""
    desc = {}
    if complete_task:
        desc[COMPLETE_TASK] = {
            "type": "function",
            "function": {
                "name": COMPLETE_TASK,
                "description": "When you have the final answer and are ready to complete the task, call this function with your answer",
                "parameters": {
                    "type": "object",
                    "properties": {"answer": {"type": "string", "description": "The final answer to the user's question"}},
                    "required": ["answer"],
                },
            },
        }
    for tool in tools_description:
        desc[tool["function"]["name"]] = tool

    return "\n\n".join([f"## {i + 1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])


def form_history(history, limit=-6):
    context = ""
    for h in history[limit:]:
        if h["role"] == "system":
            continue
        role = "USER"
        if h["role"].upper() != role:
            role = "AGENT"
        context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content']) > 2048 else '')}"
    return context


def analyze_task(chat_mdl, task_name, tools_description: list[dict]):
    tools_desc = tool_schema(tools_description)
    context = ""

    template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_USER)

    kwd = chat_mdl.chat(ANALYZE_TASK_SYSTEM, [{"role": "user", "content": template.render(task=task_name, context=context, tools_desc=tools_desc)}], {})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd


def next_step(chat_mdl, history: list, tools_description: list[dict], task_desc):
    if not tools_description:
        return ""
    desc = tool_schema(tools_description)
    template = PROMPT_JINJA_ENV.from_string(NEXT_STEP)
    user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
    hist = deepcopy(history)
    if hist[-1]["role"] == "user":
        hist[-1]["content"] += user_prompt
    else:
        hist.append({"role": "user", "content": user_prompt})
    json_str = chat_mdl.chat(
        template.render(task_analisys=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
        hist[1:],
        stop=["<|stop|>"],
    )
    tk_cnt = num_tokens_from_string(json_str)
    json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
    return json_str, tk_cnt
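
# Example (assuming an agent loop with `history` and `tools` already built):
# ask the model for the next tool call given the task analysis; returns the
# raw decision text plus its token count.
#
#   decision, tk_cnt = next_step(chat_mdl, history, tools, task_desc)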


def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple]):
    tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
    goal = history[1]["content"]
    template = PROMPT_JINJA_ENV.from_string(REFLECT)
    user_prompt = template.render(goal=goal, tool_calls=tool_calls)
    hist = deepcopy(history)
    if hist[-1]["role"] == "user":
        hist[-1]["content"] += user_prompt
    else:
        hist.append({"role": "user", "content": user_prompt})
    _, msg = message_fit_in(hist, chat_mdl.max_length)
    ans = chat_mdl.chat(msg[0]["content"], msg[1:])
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    return """
**Observation**
{}

**Reflection**
{}
""".format(json.dumps(tool_calls, ensure_ascii=False, indent=2), ans)


def form_message(system_prompt, user_prompt):
    return [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]


def tool_call_summary(chat_mdl, name: str, params: dict, result: str) -> str:
    template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
    system_prompt = template.render(name=name, params=json.dumps(params, ensure_ascii=False, indent=2), result=result)
    user_prompt = "→ Summary: "
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    ans = chat_mdl.chat(msg[0]["content"], msg[1:])
    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)


def rank_memories(chat_mdl, goal: str, sub_goal: str, tool_call_summaries: list[str]):
    template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
    system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i, s in enumerate(tool_call_summaries)])
    user_prompt = " → rank: "
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
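
# Example (hypothetical agent memory flow): compress each tool result with
# `tool_call_summary`, then rank the summaries against the current goal.
#
#   s = tool_call_summary(chat_mdl, "web_search", {"q": "ragflow"}, raw_result)
#   order = rank_memories(chat_mdl, goal, sub_goal, [s])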