Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-08 20:42:30 +08:00)
Feat: Redesign and refactor agent module (#9113)
### What problem does this PR solve?

#9082 #6365

**WARNING: this is not compatible with the older version of the `Agent` module, which means that an `Agent` built with an older version can no longer work.**

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
@ -69,6 +69,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
Since a book is long and not all the parts are useful, if it's a PDF,
please set up the page ranges for every book in order to eliminate negative effects and save computing time.
"""
|
||||
parser_config = kwargs.get(
|
||||
"parser_config", {
|
||||
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
||||
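A minimal usage sketch of the page-range hint in the docstring above (the file name, callback, and import path are assumptions; the signature follows the hunk header `chunk(filename, binary=None, from_page=0, to_page=100000, ...)`):

```python
from rag.app.book import chunk  # assumption: this chunker lives in rag.app.book

def progress(prog=None, msg=""):
    # simple progress callback, printed instead of reported to the task queue
    print(prog, msg)

# Only parse and chunk pages 10-60 of a long PDF to skip irrelevant parts and save compute.
chunks = chunk("a_long_textbook.pdf", from_page=10, to_page=60,
               lang="English", callback=progress)
```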
@ -89,7 +92,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
|
||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
pdf_parser = Pdf()
|
||||
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
|
||||
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
|
||||
pdf_parser = PlainParser()
|
||||
sections, tbls = pdf_parser(filename if not binary else binary,
|
||||
from_page=from_page, to_page=to_page, callback=callback)
|
||||
|
||||
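The fix above (repeated for several chunkers later in this diff) reads `layout_recognize` from the nested `parser_config` dict rather than from the top-level kwargs. A small sketch of why the two lookups differ (values are hypothetical):

```python
kwargs = {"parser_config": {"layout_recognize": "Plain Text", "chunk_token_num": 512}}
parser_config = kwargs.get("parser_config", {})

kwargs.get("layout_recognize", "DeepDOC")         # -> "DeepDOC"   (top-level key absent, default wins)
parser_config.get("layout_recognize", "DeepDOC")  # -> "Plain Text" (nested setting is honored)
```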
@ -40,7 +40,7 @@ def chunk(
|
||||
eng = lang.lower() == "english" # is_english(cks)
|
||||
parser_config = kwargs.get(
|
||||
"parser_config",
|
||||
{"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"},
|
||||
{"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"},
|
||||
)
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
|
||||
@ -145,6 +145,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
"""
|
||||
Supported file formats are docx, pdf, txt.
|
||||
"""
|
||||
parser_config = kwargs.get(
|
||||
"parser_config", {
|
||||
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
||||
@ -163,7 +166,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
|
||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
pdf_parser = Pdf()
|
||||
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
|
||||
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
|
||||
pdf_parser = PlainParser()
|
||||
for txt, poss in pdf_parser(filename if not binary else binary,
|
||||
from_page=from_page, to_page=to_page, callback=callback)[0]:
|
||||
|
||||
@ -45,9 +45,6 @@ class Pdf(PdfParser):
|
||||
callback
|
||||
)
|
||||
callback(msg="OCR finished ({:.2f}s)".format(timer() - start))
|
||||
# for bb in self.boxes:
|
||||
# for b in bb:
|
||||
# print(b)
|
||||
logging.debug("OCR: {}".format(timer() - start))
|
||||
|
||||
start = timer()
|
||||
@ -177,6 +174,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
"""
|
||||
Only pdf is supported.
|
||||
"""
|
||||
parser_config = kwargs.get(
|
||||
"parser_config", {
|
||||
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
|
||||
pdf_parser = None
|
||||
doc = {
|
||||
"docnm_kwd": filename
|
||||
@ -187,7 +187,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
eng = lang.lower() == "english" # pdf_parser.is_english
|
||||
if re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
pdf_parser = Pdf()
|
||||
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
|
||||
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
|
||||
pdf_parser = PlainParser()
|
||||
sections, tbls = pdf_parser(filename if not binary else binary,
|
||||
from_page=from_page, to_page=to_page, callback=callback)
|
||||
@ -222,7 +222,6 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
if lvl <= most_level and i > 0 and lvl != levels[i - 1]:
|
||||
sid += 1
|
||||
sec_ids.append(sid)
|
||||
# print(lvl, self.boxes[i]["text"], most_level, sid)
|
||||
|
||||
sections = [(txt, sec_ids[i], poss)
|
||||
for i, (txt, _, poss) in enumerate(sections)]
|
||||
|
||||
@ -370,7 +370,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
is_english = lang.lower() == "english" # is_english(cks)
|
||||
parser_config = kwargs.get(
|
||||
"parser_config", {
|
||||
"chunk_token_num": 128, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
|
||||
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
|
||||
doc = {
|
||||
"docnm_kwd": filename,
|
||||
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))
|
||||
|
||||
@ -72,7 +72,9 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
Supported file formats are docx, pdf, excel, txt.
|
||||
One file forms a chunk which maintains original text order.
|
||||
"""
|
||||
|
||||
parser_config = kwargs.get(
|
||||
"parser_config", {
|
||||
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
|
||||
eng = lang.lower() == "english" # is_english(cks)
|
||||
|
||||
if re.search(r"\.docx$", filename, re.IGNORECASE):
|
||||
@ -85,7 +87,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
|
||||
elif re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
pdf_parser = Pdf()
|
||||
if kwargs.get("layout_recognize", "DeepDOC") == "Plain Text":
|
||||
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
|
||||
pdf_parser = PlainParser()
|
||||
sections, _ = pdf_parser(
|
||||
filename if not binary else binary, to_page=to_page, callback=callback)
|
||||
|
||||
@ -143,8 +143,11 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
Only pdf is supported.
|
||||
The abstract of the paper will be kept as an entire chunk and will not be split.
|
||||
"""
|
||||
parser_config = kwargs.get(
|
||||
"parser_config", {
|
||||
"chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC"})
|
||||
if re.search(r"\.pdf$", filename, re.IGNORECASE):
|
||||
if kwargs.get("parser_config", {}).get("layout_recognize", "DeepDOC") == "Plain Text":
|
||||
if parser_config.get("layout_recognize", "DeepDOC") == "Plain Text":
|
||||
pdf_parser = PlainParser()
|
||||
paper = {
|
||||
"title": filename,
|
||||
|
||||
@ -18,6 +18,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
@ -31,25 +32,33 @@ from dashscope import Generation
|
||||
from ollama import Client
|
||||
from openai import OpenAI
|
||||
from openai.lib.azure import AzureOpenAI
|
||||
from strenum import StrEnum
|
||||
from zhipuai import ZhipuAI
|
||||
|
||||
from rag.nlp import is_chinese, is_english
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
||||
# Error message constants
|
||||
ERROR_PREFIX = "**ERROR**"
|
||||
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
|
||||
ERROR_AUTHENTICATION = "AUTH_ERROR"
|
||||
ERROR_INVALID_REQUEST = "INVALID_REQUEST"
|
||||
ERROR_SERVER = "SERVER_ERROR"
|
||||
ERROR_TIMEOUT = "TIMEOUT"
|
||||
ERROR_CONNECTION = "CONNECTION_ERROR"
|
||||
ERROR_MODEL = "MODEL_ERROR"
|
||||
ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
|
||||
ERROR_QUOTA = "QUOTA_EXCEEDED"
|
||||
ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
|
||||
ERROR_GENERIC = "GENERIC_ERROR"
|
||||
class LLMErrorCode(StrEnum):
|
||||
ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
|
||||
ERROR_AUTHENTICATION = "AUTH_ERROR"
|
||||
ERROR_INVALID_REQUEST = "INVALID_REQUEST"
|
||||
ERROR_SERVER = "SERVER_ERROR"
|
||||
ERROR_TIMEOUT = "TIMEOUT"
|
||||
ERROR_CONNECTION = "CONNECTION_ERROR"
|
||||
ERROR_MODEL = "MODEL_ERROR"
|
||||
ERROR_MAX_ROUNDS = "ERROR_MAX_ROUNDS"
|
||||
ERROR_CONTENT_FILTER = "CONTENT_FILTERED"
|
||||
ERROR_QUOTA = "QUOTA_EXCEEDED"
|
||||
ERROR_MAX_RETRIES = "MAX_RETRIES_EXCEEDED"
|
||||
ERROR_GENERIC = "GENERIC_ERROR"
|
||||
|
||||
|
||||
class ReActMode(StrEnum):
|
||||
FUNCTION_CALL = "function_call"
|
||||
REACT = "react"
|
||||
|
||||
ERROR_PREFIX = "**ERROR**"
|
||||
LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。"
|
||||
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
|
||||
|
||||
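Since `LLMErrorCode` is a `StrEnum` (imported from `strenum` above), its members behave as plain strings, so comparisons against the old module-level constants keep working. A small sketch:

```python
from strenum import StrEnum

class LLMErrorCode(StrEnum):
    ERROR_RATE_LIMIT = "RATE_LIMIT_EXCEEDED"
    ERROR_GENERIC = "GENERIC_ERROR"

# Members are str subclasses, so string equality and str methods behave as before:
assert LLMErrorCode.ERROR_RATE_LIMIT == "RATE_LIMIT_EXCEEDED"
assert LLMErrorCode.ERROR_RATE_LIMIT.lower() == "rate_limit_exceeded"
```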
@ -73,51 +82,78 @@ class Base(ABC):
|
||||
|
||||
def _get_delay(self):
|
||||
"""Calculate retry delay time"""
|
||||
return self.base_delay + random.uniform(60, 150)
|
||||
return self.base_delay * random.uniform(10, 150)
|
||||
|
||||
def _classify_error(self, error):
|
||||
"""Classify error based on error message content"""
|
||||
error_str = str(error).lower()
|
||||
|
||||
if "rate limit" in error_str or "429" in error_str or "tpm limit" in error_str or "too many requests" in error_str or "requests per minute" in error_str:
|
||||
return ERROR_RATE_LIMIT
|
||||
elif "auth" in error_str or "key" in error_str or "apikey" in error_str or "401" in error_str or "forbidden" in error_str or "permission" in error_str:
|
||||
return ERROR_AUTHENTICATION
|
||||
elif "invalid" in error_str or "bad request" in error_str or "400" in error_str or "format" in error_str or "malformed" in error_str or "parameter" in error_str:
|
||||
return ERROR_INVALID_REQUEST
|
||||
elif "server" in error_str or "502" in error_str or "503" in error_str or "504" in error_str or "500" in error_str or "unavailable" in error_str:
|
||||
return ERROR_SERVER
|
||||
elif "timeout" in error_str or "timed out" in error_str:
|
||||
return ERROR_TIMEOUT
|
||||
elif "connect" in error_str or "network" in error_str or "unreachable" in error_str or "dns" in error_str:
|
||||
return ERROR_CONNECTION
|
||||
elif "quota" in error_str or "capacity" in error_str or "credit" in error_str or "billing" in error_str or "limit" in error_str and "rate" not in error_str:
|
||||
return ERROR_QUOTA
|
||||
elif "filter" in error_str or "content" in error_str or "policy" in error_str or "blocked" in error_str or "safety" in error_str or "inappropriate" in error_str:
|
||||
return ERROR_CONTENT_FILTER
|
||||
elif "model" in error_str or "not found" in error_str or "does not exist" in error_str or "not available" in error_str:
|
||||
return ERROR_MODEL
|
||||
else:
|
||||
return ERROR_GENERIC
|
||||
keywords_mapping = [
|
||||
(["quota", "capacity", "credit", "billing", "balance", "欠费"], LLMErrorCode.ERROR_QUOTA),
|
||||
(["rate limit", "429", "tpm limit", "too many requests", "requests per minute"], LLMErrorCode.ERROR_RATE_LIMIT),
|
||||
(["auth", "key", "apikey", "401", "forbidden", "permission"], LLMErrorCode.ERROR_AUTHENTICATION),
|
||||
(["invalid", "bad request", "400", "format", "malformed", "parameter"], LLMErrorCode.ERROR_INVALID_REQUEST),
|
||||
(["server", "503", "502", "504", "500", "unavailable"], LLMErrorCode.ERROR_SERVER),
|
||||
(["timeout", "timed out"], LLMErrorCode.ERROR_TIMEOUT),
|
||||
(["connect", "network", "unreachable", "dns"], LLMErrorCode.ERROR_CONNECTION),
|
||||
(["filter", "content", "policy", "blocked", "safety", "inappropriate"], LLMErrorCode.ERROR_CONTENT_FILTER),
|
||||
(["model", "not found", "does not exist", "not available"], LLMErrorCode.ERROR_MODEL),
|
||||
(["max rounds"], LLMErrorCode.ERROR_MODEL),
|
||||
]
|
||||
for words, code in keywords_mapping:
|
||||
if re.search("({})".format("|".join(words)), error_str):
|
||||
return code
|
||||
|
||||
return LLMErrorCode.ERROR_GENERIC
|
||||
|
||||
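A trimmed, self-contained sketch of the first-match keyword scan used by `_classify_error` above (the two rows and the error messages are illustrative only):

```python
import re

keywords_mapping = [
    (["quota", "billing", "balance"], "QUOTA_EXCEEDED"),
    (["rate limit", "429", "too many requests"], "RATE_LIMIT_EXCEEDED"),
]

def classify(error) -> str:
    error_str = str(error).lower()
    for words, code in keywords_mapping:
        if re.search("({})".format("|".join(words)), error_str):
            return code  # first matching row wins, so ordering matters
    return "GENERIC_ERROR"

assert classify("HTTP 429: Too Many Requests") == "RATE_LIMIT_EXCEEDED"
assert classify("Your billing account has an outstanding balance") == "QUOTA_EXCEEDED"
assert classify("something else entirely") == "GENERIC_ERROR"
```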
def _clean_conf(self, gen_conf):
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
return gen_conf
|
||||
|
||||
def _chat(self, history, gen_conf):
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf)
|
||||
def _chat(self, history, gen_conf, **kwargs):
|
||||
logging.info("[HISTORY]" + json.dumps(history, ensure_ascii=False, indent=2))
|
||||
if self.model_name.lower().find("qwen3") >=0:
|
||||
kwargs["extra_body"] = {"enable_thinking": False}
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, **gen_conf, **kwargs)
|
||||
|
||||
if any([not response.choices, not response.choices[0].message, not response.choices[0].message.content]):
|
||||
return "", 0
|
||||
ans = response.choices[0].message.content.strip()
|
||||
if response.choices[0].finish_reason == "length":
|
||||
if is_chinese(ans):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
ans = self._length_stop(ans)
|
||||
return ans, self.total_token_count(response)
|
||||
|
||||
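For Qwen3-family model names, `_chat` now forwards `extra_body={"enable_thinking": False}` to the OpenAI-compatible client. A hedged sketch of the resulting request (the model name and endpoint are hypothetical; `extra_body` is the standard escape hatch of the `openai` Python client for vendor-specific options):

```python
from openai import OpenAI

client = OpenAI(api_key="sk-...", base_url="https://example.com/v1")  # hypothetical endpoint
response = client.chat.completions.create(
    model="qwen3-32b",                       # any name containing "qwen3" triggers the flag
    messages=[{"role": "user", "content": "Hi"}],
    extra_body={"enable_thinking": False},   # vendor-specific option forwarded verbatim
)
```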
def _chat_streamly(self, history, gen_conf, **kwargs):
|
||||
logging.info("[HISTORY STREAMLY]" + json.dumps(history, ensure_ascii=False, indent=4))
|
||||
reasoning_start = False
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf, stop=kwargs.get("stop"))
|
||||
for resp in response:
|
||||
if not resp.choices:
|
||||
continue
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
if kwargs.get("with_reasoning", True) and hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
|
||||
ans = ""
|
||||
if not reasoning_start:
|
||||
reasoning_start = True
|
||||
ans = "<think>"
|
||||
ans += resp.choices[0].delta.reasoning_content + "</think>"
|
||||
else:
|
||||
reasoning_start = False
|
||||
ans = resp.choices[0].delta.content
|
||||
|
||||
tol = self.total_token_count(resp)
|
||||
if not tol:
|
||||
tol = num_tokens_from_string(resp.choices[0].delta.content)
|
||||
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
if is_chinese(ans):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
yield ans, tol
|
||||
|
||||
def _length_stop(self, ans):
|
||||
if is_chinese([ans]):
|
||||
return ans + LENGTH_NOTIFICATION_CN
|
||||
@ -127,22 +163,24 @@ class Base(ABC):
|
||||
logging.exception("OpenAI chat_with_tools")
|
||||
# Classify the error
|
||||
error_code = self._classify_error(e)
|
||||
if attempt == self.max_retries:
|
||||
error_code = LLMErrorCode.ERROR_MAX_RETRIES
|
||||
|
||||
# Check if it's a rate limit error or server error and not the last attempt
|
||||
should_retry = (error_code == ERROR_RATE_LIMIT or error_code == ERROR_SERVER) and attempt < self.max_retries
|
||||
|
||||
if should_retry:
|
||||
delay = self._get_delay()
|
||||
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
|
||||
time.sleep(delay)
|
||||
else:
|
||||
# For non-rate limit errors or the last attempt, return an error message
|
||||
if attempt == self.max_retries:
|
||||
error_code = ERROR_MAX_RETRIES
|
||||
should_retry = (error_code == LLMErrorCode.ERROR_RATE_LIMIT or error_code == LLMErrorCode.ERROR_SERVER)
|
||||
if not should_retry:
|
||||
return f"{ERROR_PREFIX}: {error_code} - {str(e)}"
|
||||
|
||||
delay = self._get_delay()
|
||||
logging.warning(f"Error: {error_code}. Retrying in {delay:.2f} seconds... (Attempt {attempt + 1}/{self.max_retries})")
|
||||
time.sleep(delay)
|
||||
|
||||
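Taken together with `_get_delay`, the retry handling above only retries rate-limit and server errors, sleeps a randomized delay between attempts, and otherwise surfaces an `**ERROR**`-prefixed message. A compact, standalone sketch of that control flow (names and constants are simplified for illustration):

```python
import random
import time

MAX_RETRIES, BASE_DELAY = 3, 2.0

def call_with_retries(do_request, classify):
    for attempt in range(MAX_RETRIES + 1):
        try:
            return do_request()
        except Exception as e:
            code = classify(e)
            if attempt == MAX_RETRIES:
                code = "MAX_RETRIES_EXCEEDED"
            if code not in ("RATE_LIMIT_EXCEEDED", "SERVER_ERROR"):
                return f"**ERROR**: {code} - {e}"   # non-retryable or out of attempts: give up
            time.sleep(BASE_DELAY * random.uniform(10, 150))  # randomized backoff before retrying
```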
def _verbose_tool_use(self, name, args, res):
|
||||
return "<tool_call>" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "</tool_call>"
|
||||
return "<tool_call>" + json.dumps({
|
||||
"name": name,
|
||||
"args": args,
|
||||
"result": res
|
||||
}, ensure_ascii=False, indent=2) + "</tool_call>"
|
||||
|
||||
def _append_history(self, hist, tool_call, tool_res):
|
||||
hist.append(
|
||||
@ -172,17 +210,14 @@ class Base(ABC):
|
||||
if not (toolcall_session and tools):
|
||||
return
|
||||
self.is_tools = True
|
||||
self.toolcall_session = toolcall_session
|
||||
self.tools = tools
|
||||
|
||||
for tool in tools:
|
||||
self.toolcall_sessions[tool["function"]["name"]] = toolcall_session
|
||||
self.tools.append(tool)
|
||||
|
||||
def chat_with_tools(self, system: str, history: list, gen_conf: dict):
|
||||
def chat_with_tools(self, system: str, history: list, gen_conf: dict={}):
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
hist = deepcopy(history)
|
||||
@ -190,8 +225,9 @@ class Base(ABC):
|
||||
for attempt in range(self.max_retries + 1):
|
||||
history = hist
|
||||
try:
|
||||
for _ in range(self.max_rounds * 2):
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, **gen_conf)
|
||||
for _ in range(self.max_rounds+1):
|
||||
logging.info(f"{self.tools=}")
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, tools=self.tools, tool_choice="auto", **gen_conf)
|
||||
tk_count += self.total_token_count(response)
|
||||
if any([not response.choices, not response.choices[0].message]):
|
||||
raise Exception(f"500 response structure error. Response: {response}")
|
||||
@ -207,10 +243,11 @@ class Base(ABC):
|
||||
return ans, tk_count
|
||||
|
||||
for tool_call in response.choices[0].message.tool_calls:
|
||||
logging.info(f"Response {tool_call=}")
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
tool_response = self.toolcall_sessions[name].tool_call(name, args)
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
history = self._append_history(history, tool_call, tool_response)
|
||||
ans += self._verbose_tool_use(name, args, tool_response)
|
||||
except Exception as e:
|
||||
@ -218,13 +255,20 @@ class Base(ABC):
|
||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
||||
ans += self._verbose_tool_use(name, {}, str(e))
|
||||
|
||||
logging.warning( f"Exceed max rounds: {self.max_rounds}")
|
||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
||||
response, token_count = self._chat(history, gen_conf)
|
||||
ans += response
|
||||
tk_count += token_count
|
||||
return ans, tk_count
|
||||
except Exception as e:
|
||||
e = self._exceptions(e, attempt)
|
||||
if e:
|
||||
return e, tk_count
|
||||
|
||||
assert False, "Shouldn't be here."
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
def chat(self, system, history, gen_conf={}, **kwargs):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
@ -232,7 +276,7 @@ class Base(ABC):
|
||||
# Implement exponential backoff retry strategy
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return self._chat(history, gen_conf)
|
||||
return self._chat(history, gen_conf, **kwargs)
|
||||
except Exception as e:
|
||||
e = self._exceptions(e, attempt)
|
||||
if e:
|
||||
@ -253,7 +297,7 @@ class Base(ABC):
|
||||
|
||||
return final_tool_calls
|
||||
|
||||
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict):
|
||||
def chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict={}):
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
tools = self.tools
|
||||
if system:
|
||||
@ -265,9 +309,10 @@ class Base(ABC):
|
||||
for attempt in range(self.max_retries + 1):
|
||||
history = hist
|
||||
try:
|
||||
for _ in range(self.max_rounds * 2):
|
||||
for _ in range(self.max_rounds+1):
|
||||
reasoning_start = False
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, **gen_conf)
|
||||
logging.info(f"{tools=}")
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, tools=tools, tool_choice="auto", **gen_conf)
|
||||
final_tool_calls = {}
|
||||
answer = ""
|
||||
for resp in response:
|
||||
@ -319,7 +364,8 @@ class Base(ABC):
|
||||
name = tool_call.function.name
|
||||
try:
|
||||
args = json_repair.loads(tool_call.function.arguments)
|
||||
tool_response = self.toolcall_session[name].tool_call(name, args)
|
||||
yield self._verbose_tool_use(name, args, "Begin to call...")
|
||||
tool_response = self.toolcall_session.tool_call(name, args)
|
||||
history = self._append_history(history, tool_call, tool_response)
|
||||
yield self._verbose_tool_use(name, args, tool_response)
|
||||
except Exception as e:
|
||||
@ -327,51 +373,45 @@ class Base(ABC):
|
||||
history.append({"role": "tool", "tool_call_id": tool_call.id, "content": f"Tool call error: \n{tool_call}\nException:\n" + str(e)})
|
||||
yield self._verbose_tool_use(name, {}, str(e))
|
||||
|
||||
logging.warning( f"Exceed max rounds: {self.max_rounds}")
|
||||
history.append({"role": "user", "content": f"Exceed max rounds: {self.max_rounds}"})
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
||||
for resp in response:
|
||||
if any([not resp.choices, not resp.choices[0].delta, not hasattr(resp.choices[0].delta, "content")]):
|
||||
raise Exception("500 response structure error.")
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
continue
|
||||
tol = self.total_token_count(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
||||
else:
|
||||
total_tokens += tol
|
||||
answer += resp.choices[0].delta.content
|
||||
yield resp.choices[0].delta.content
|
||||
|
||||
yield total_tokens
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
e = self._exceptions(e, attempt)
|
||||
if e:
|
||||
yield e
|
||||
yield total_tokens
|
||||
return
|
||||
|
||||
yield total_tokens
|
||||
assert False, "Shouldn't be here."
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
def chat_streamly(self, system, history, gen_conf: dict={}, **kwargs):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
reasoning_start = False
|
||||
try:
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
||||
for resp in response:
|
||||
if not resp.choices:
|
||||
continue
|
||||
if not resp.choices[0].delta.content:
|
||||
resp.choices[0].delta.content = ""
|
||||
if hasattr(resp.choices[0].delta, "reasoning_content") and resp.choices[0].delta.reasoning_content:
|
||||
ans = ""
|
||||
if not reasoning_start:
|
||||
reasoning_start = True
|
||||
ans = "<think>"
|
||||
ans += resp.choices[0].delta.reasoning_content + "</think>"
|
||||
else:
|
||||
reasoning_start = False
|
||||
ans = resp.choices[0].delta.content
|
||||
|
||||
tol = self.total_token_count(resp)
|
||||
if not tol:
|
||||
total_tokens += num_tokens_from_string(resp.choices[0].delta.content)
|
||||
else:
|
||||
total_tokens += tol
|
||||
|
||||
if resp.choices[0].finish_reason == "length":
|
||||
if is_chinese(ans):
|
||||
ans += LENGTH_NOTIFICATION_CN
|
||||
else:
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
yield ans
|
||||
|
||||
for delta_ans, tol in self._chat_streamly(history, gen_conf, **kwargs):
|
||||
yield delta_ans
|
||||
total_tokens += tol
|
||||
except openai.APIError as e:
|
||||
yield ans + "\n**ERROR**: " + str(e)
|
||||
|
||||
@ -514,7 +554,7 @@ class BaiChuanChat(Base):
|
||||
"top_p": gen_conf.get("top_p", 0.85),
|
||||
}
|
||||
|
||||
def _chat(self, history, gen_conf):
|
||||
def _chat(self, history, gen_conf={}, **kwargs):
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
@ -529,7 +569,7 @@ class BaiChuanChat(Base):
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, self.total_token_count(response)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
@ -614,7 +654,7 @@ class ZhipuChat(Base):
|
||||
|
||||
return super().chat_with_tools(system, history, gen_conf)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
@ -626,6 +666,7 @@ class ZhipuChat(Base):
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
try:
|
||||
logging.info(json.dumps(history, ensure_ascii=False, indent=2))
|
||||
response = self.client.chat.completions.create(model=self.model_name, messages=history, stream=True, **gen_conf)
|
||||
for resp in response:
|
||||
if not resp.choices[0].delta.content:
|
||||
@ -675,7 +716,7 @@ class OllamaChat(Base):
|
||||
options[k] = gen_conf[k]
|
||||
return options
|
||||
|
||||
def _chat(self, history, gen_conf):
|
||||
def _chat(self, history, gen_conf={}, **kwargs):
|
||||
# Calculate context size
|
||||
ctx_size = self._calculate_dynamic_ctx(history)
|
||||
|
||||
@ -685,7 +726,7 @@ class OllamaChat(Base):
|
||||
token_count = response.get("eval_count", 0) + response.get("prompt_eval_count", 0)
|
||||
return ans, token_count
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
@ -766,7 +807,7 @@ class LocalLLM(Base):
|
||||
yield answer + "\n**ERROR**: " + str(e)
|
||||
yield num_tokens_from_string(answer)
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
def chat(self, system, history, gen_conf={}, **kwargs):
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
prompt = self._prepare_prompt(system, history, gen_conf)
|
||||
@ -775,7 +816,7 @@ class LocalLLM(Base):
|
||||
total_tokens = next(chat_gen)
|
||||
return ans, total_tokens
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
prompt = self._prepare_prompt(system, history, gen_conf)
|
||||
@ -894,7 +935,7 @@ class MistralChat(Base):
|
||||
del gen_conf[k]
|
||||
return gen_conf
|
||||
|
||||
def _chat(self, history, gen_conf):
|
||||
def _chat(self, history, gen_conf={}, **kwargs):
|
||||
response = self.client.chat(model=self.model_name, messages=history, **gen_conf)
|
||||
ans = response.choices[0].message.content
|
||||
if response.choices[0].finish_reason == "length":
|
||||
@ -904,7 +945,7 @@ class MistralChat(Base):
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, self.total_token_count(response)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
for k in list(gen_conf.keys()):
|
||||
@ -913,7 +954,7 @@ class MistralChat(Base):
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf)
|
||||
response = self.client.chat_stream(model=self.model_name, messages=history, **gen_conf, **kwargs)
|
||||
for resp in response:
|
||||
if not resp.choices or not resp.choices[0].delta.content:
|
||||
continue
|
||||
@ -957,7 +998,7 @@ class BedrockChat(Base):
|
||||
del gen_conf[k]
|
||||
return gen_conf
|
||||
|
||||
def _chat(self, history, gen_conf):
|
||||
def _chat(self, history, gen_conf={}, **kwargs):
|
||||
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
|
||||
hist = []
|
||||
for item in history:
|
||||
@ -978,7 +1019,7 @@ class BedrockChat(Base):
|
||||
ans = response["output"]["message"]["content"][0]["text"]
|
||||
return ans, num_tokens_from_string(ans)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
for k in list(gen_conf.keys()):
|
||||
@ -1036,7 +1077,7 @@ class GeminiChat(Base):
|
||||
del gen_conf[k]
|
||||
return gen_conf
|
||||
|
||||
def _chat(self, history, gen_conf):
|
||||
def _chat(self, history, gen_conf={}, **kwargs):
|
||||
from google.generativeai.types import content_types
|
||||
|
||||
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
|
||||
@ -1059,7 +1100,7 @@ class GeminiChat(Base):
|
||||
ans = response.text
|
||||
return ans, response.usage_metadata.total_token_count
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
from google.generativeai.types import content_types
|
||||
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
@ -1101,7 +1142,7 @@ class GroqChat(Base):
|
||||
del gen_conf[k]
|
||||
return gen_conf
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
for k in list(gen_conf.keys()):
|
||||
@ -1229,7 +1270,7 @@ class CoHereChat(Base):
|
||||
response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
|
||||
)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
@ -1348,7 +1389,7 @@ class ReplicateChat(Base):
|
||||
self.model_name = model_name
|
||||
self.client = Client(api_token=key)
|
||||
|
||||
def _chat(self, history, gen_conf):
|
||||
def _chat(self, history, gen_conf={}, **kwargs):
|
||||
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
|
||||
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:] if item["role"] != "system"])
|
||||
response = self.client.run(
|
||||
@ -1358,7 +1399,7 @@ class ReplicateChat(Base):
|
||||
ans = "".join(response)
|
||||
return ans, num_tokens_from_string(ans)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
|
||||
@ -1402,7 +1443,7 @@ class HunyuanChat(Base):
|
||||
_gen_conf["TopP"] = gen_conf["top_p"]
|
||||
return _gen_conf
|
||||
|
||||
def _chat(self, history, gen_conf):
|
||||
def _chat(self, history, gen_conf={}, **kwargs):
|
||||
from tencentcloud.hunyuan.v20230901 import models
|
||||
|
||||
hist = [{k.capitalize(): v for k, v in item.items()} for item in history]
|
||||
@ -1413,7 +1454,7 @@ class HunyuanChat(Base):
|
||||
ans = response.Choices[0].Message.Content
|
||||
return ans, response.Usage.TotalTokens
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
|
||||
TencentCloudSDKException,
|
||||
)
|
||||
@ -1504,7 +1545,7 @@ class BaiduYiyanChat(Base):
|
||||
ans = response["result"]
|
||||
return ans, self.total_token_count(response)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
@ -1588,7 +1629,7 @@ class GoogleChat(Base):
|
||||
del gen_conf[k]
|
||||
return gen_conf
|
||||
|
||||
def _chat(self, history, gen_conf):
|
||||
def _chat(self, history, gen_conf={}, **kwargs):
|
||||
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
|
||||
if "claude" in self.model_name:
|
||||
response = self.client.messages.create(
|
||||
@ -1626,7 +1667,7 @@ class GoogleChat(Base):
|
||||
ans = response.text
|
||||
return ans, response.usage_metadata.total_token_count
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
if "claude" in self.model_name:
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
|
||||
rag/llm/cv_model.py (1153 lines changed): file diff suppressed because it is too large.
@ -89,7 +89,7 @@ class DefaultRerank(Base):
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
except Exception as e:
|
||||
print(f"Error emptying cache: {e}")
|
||||
log_exception(e)
|
||||
|
||||
def _process_batch(self, pairs, max_batch_size=None):
|
||||
"""template method for subclass call"""
|
||||
|
||||
@ -518,7 +518,8 @@ def hierarchical_merge(bull, sections, depth):
|
||||
return res
|
||||
|
||||
|
||||
def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"):
|
||||
def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?", overlapped_percent=0):
|
||||
from deepdoc.parser.pdf_parser import RAGFlowPdfParser
|
||||
if not sections:
|
||||
return []
|
||||
if isinstance(sections[0], type("")):
|
||||
@ -534,8 +535,10 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"):
|
||||
if tnum < 8:
|
||||
pos = ""
|
||||
# Ensure that the length of the merged chunk does not exceed chunk_token_num
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num:
|
||||
|
||||
if cks[-1] == "" or tk_nums[-1] > chunk_token_num * (100 - overlapped_percent)/100.:
|
||||
if cks:
|
||||
overlapped = RAGFlowPdfParser.remove_tag(cks[-1])
|
||||
t = overlapped[int(len(overlapped)*(100-overlapped_percent)/100.):] + t
|
||||
if t.find(pos) < 0:
|
||||
t += pos
|
||||
cks.append(t)
|
||||
@ -548,7 +551,10 @@ def naive_merge(sections, chunk_token_num=128, delimiter="\n。;!?"):
|
||||
|
||||
dels = get_delimiters(delimiter)
|
||||
for sec, pos in sections:
|
||||
splited_sec = re.split(r"(%s)" % dels, sec)
|
||||
if num_tokens_from_string(sec) < chunk_token_num:
|
||||
add_chunk(sec, pos)
|
||||
continue
|
||||
splited_sec = re.split(r"(%s)" % dels, sec, flags=re.DOTALL)
|
||||
for sub_sec in splited_sec:
|
||||
if re.match(f"^{dels}$", sub_sec):
|
||||
continue
|
||||
|
||||
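The new `overlapped_percent` parameter does two things: a fresh chunk is started once the current one exceeds `chunk_token_num * (100 - overlapped_percent) / 100`, and the tail of the previous chunk is prepended so consecutive chunks overlap. A small numeric sketch (values are hypothetical):

```python
chunk_token_num = 128
overlapped_percent = 20

# Start a new chunk once the current one passes ~102 tokens instead of 128:
threshold = chunk_token_num * (100 - overlapped_percent) / 100.0   # 102.4

# Prepend the last 20% of the previous chunk's text to the new one:
previous = "A" * 1000
overlap_tail = previous[int(len(previous) * (100 - overlapped_percent) / 100.0):]
assert len(overlap_tail) == 200
```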
@ -384,7 +384,7 @@ class Dealer:
|
||||
zero_vector = [0.0] * dim
|
||||
sim_np = np.array(sim)
|
||||
if doc_ids:
|
||||
similarity_threshold = 0
|
||||
similarity_threshold = 0
|
||||
filtered_count = (sim_np >= similarity_threshold).sum()
|
||||
ranks["total"] = int(filtered_count) # Convert from np.int64 to Python int otherwise JSON serializable error
|
||||
for i in idx:
|
||||
@ -403,7 +403,7 @@ class Dealer:
|
||||
ranks["doc_aggs"][dnm]["count"] += 1
|
||||
continue
|
||||
break
|
||||
|
||||
|
||||
position_int = chunk.get("position_int", [])
|
||||
d = {
|
||||
"chunk_id": id,
|
||||
|
||||
rag/prompts/__init__.py (new file, 6 lines)
@ -0,0 +1,6 @@
|
||||
from . import prompts
|
||||
|
||||
__all__ = [name for name in dir(prompts)
|
||||
if not name.startswith('_')]
|
||||
|
||||
globals().update({name: getattr(prompts, name) for name in __all__})
|
||||
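The new `rag/prompts/__init__.py` above re-exports every public name from `rag.prompts.prompts`, so both import spellings resolve to the same objects. For example (assuming these helpers exist in `prompts.py`, as the hunks below suggest):

```python
from rag.prompts import kb_prompt, citation_prompt
from rag.prompts.prompts import kb_prompt as kb_prompt_direct

assert kb_prompt is kb_prompt_direct
```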
rag/prompts/analyze_task_system.md (new file, 8 lines)
@ -0,0 +1,8 @@
|
||||
Your responsibility is to execute assigned tasks to a high standard. Please:
|
||||
1. Carefully analyze the task requirements.
|
||||
2. Develop a reasonable execution plan.
|
||||
3. Execute step-by-step and document the reasoning process.
|
||||
4. Provide clear and accurate results.
|
||||
|
||||
If difficulties are encountered, clearly state the problem and explore alternative approaches.
|
||||
|
||||
rag/prompts/analyze_task_user.md (new file, 20 lines)
@ -0,0 +1,20 @@
|
||||
Please analyze the following task:
|
||||
|
||||
Task: {{ task }}
|
||||
|
||||
Context: {{ context }}
|
||||
|
||||
**Analysis Requirements:**
|
||||
1. Is it just small talk? (If yes, no further plan or analysis is needed)
|
||||
2. What is the core objective of the task?
|
||||
3. What is the complexity level of the task?
|
||||
4. What types of specialized skills are required?
|
||||
5. Does the task need to be decomposed into subtasks? (If yes, propose the subtask structure)
|
||||
6. How can we tell, after a few rounds of interaction, that the task or its subtasks cannot lead to success?
|
||||
7. What are the expected success criteria?
|
||||
|
||||
**Available Sub-Agents and Their Specializations:**
|
||||
|
||||
{{ tools_desc }}
|
||||
|
||||
Provide a detailed analysis of the task based on the above requirements.
|
||||
rag/prompts/citation_plus.md (new file, 13 lines)
@ -0,0 +1,13 @@
|
||||
You are an agent for adding correct citations to text given by the user.
|
||||
You are given a piece of text within [ID:<ID>] tags, which was generated based on the provided sources.
|
||||
However, the sources are not cited in the [ID:<ID>].
|
||||
Your task is to enhance user trust by generating correct, appropriate citations for this report.
|
||||
|
||||
{{ example }}
|
||||
|
||||
<context>
|
||||
|
||||
{{ sources }}
|
||||
|
||||
</context>
|
||||
|
||||
@ -1,46 +1,108 @@
|
||||
## Citation Requirements
|
||||
Based on the provided document or chat history, add citations to the input text using the format specified later.
|
||||
|
||||
- Use a uniform citation format such as [ID:i] [ID:j], where "i" and "j" are document IDs enclosed in square brackets. Separate multiple IDs with spaces (e.g., [ID:0] [ID:1]).
|
||||
- Citation markers must be placed at the end of a sentence, separated by a space from the final punctuation (e.g., period, question mark).
|
||||
- A maximum of 4 citations are allowed per sentence.
|
||||
- DO NOT insert citations if the content is not from retrieved chunks.
|
||||
- DO NOT use standalone Document IDs (e.g., #ID#).
|
||||
- Citations MUST always follow the [ID:i] format.
|
||||
- STRICTLY prohibit the use of strikethrough symbols (e.g., ~~) or any other non-standard formatting syntax.
|
||||
- Any violation of the above rules — including incorrect formatting, prohibited styles, or unsupported citations — will result in no citation being added for that sentence.
|
||||
# Citation Requirements:
|
||||
|
||||
---
|
||||
## Technical Rules:
|
||||
- Use format: [ID:i] or [ID:i] [ID:j] for multiple sources
|
||||
- Place citations at the end of sentences, before punctuation
|
||||
- Maximum 4 citations per sentence
|
||||
- DO NOT cite content not from <context></context>
|
||||
- DO NOT modify whitespace or original text
|
||||
- STRICTLY prohibit non-standard formatting (~~, etc.)
|
||||
|
||||
## Example START
|
||||
## What MUST Be Cited:
|
||||
1. **Quantitative data**: Numbers, percentages, statistics, measurements
|
||||
2. **Temporal claims**: Dates, timeframes, sequences of events
|
||||
3. **Causal relationships**: Claims about cause and effect
|
||||
4. **Comparative statements**: Rankings, comparisons, superlatives
|
||||
5. **Technical definitions**: Specialized terms, concepts, methodologies
|
||||
6. **Direct attributions**: What someone said, did, or believes
|
||||
7. **Predictions/forecasts**: Future projections, trend analyses
|
||||
8. **Controversial claims**: Disputed facts, minority opinions
|
||||
|
||||
<SYSTEM>: Here is the knowledge base:
|
||||
## What Should NOT Be Cited:
|
||||
- Common knowledge (e.g., "The sun rises in the east")
|
||||
- Transitional phrases
|
||||
- General introductions
|
||||
- Your own analysis or synthesis (unless directly from source)
|
||||
|
||||
Document: Elon Musk Breaks Silence on Crypto, Warns Against Dogecoin ...
|
||||
URL: https://blockworks.co/news/elon-musk-crypto-dogecoin
|
||||
ID: 0
|
||||
The Tesla co-founder advised against going all-in on dogecoin, but Elon Musk said it’s still his favorite crypto...
|
||||
# Comprehensive Examples:
|
||||
|
||||
Document: Elon Musk's Dogecoin tweet sparks social media frenzy
|
||||
ID: 1
|
||||
Musk said he is 'willing to serve' D.O.G.E. – shorthand for Dogecoin.
|
||||
## Example 1: Data and Statistics
|
||||
<context>
|
||||
ID: 45
|
||||
└── Content: The global smartphone market grew by 7.8% in Q3 2024, with Samsung holding 20.6% market share and Apple at 15.8%.
|
||||
|
||||
Document: Causal effect of Elon Musk tweets on Dogecoin price
|
||||
ID: 2
|
||||
If you think of Dogecoin — the cryptocurrency based on a meme — you can’t help but also think of Elon Musk...
|
||||
ID: 46
|
||||
└── Content: 5G adoption reached 1.5 billion users worldwide by October 2024.
|
||||
</context>
|
||||
|
||||
Document: Elon Musk's Tweet Ignites Dogecoin's Future In Public Services
|
||||
ID: 3
|
||||
The market is heating up after Elon Musk's announcement about Dogecoin. Is this a new era for crypto?...
|
||||
USER: How is the smartphone market performing?
|
||||
|
||||
The above is the knowledge base.
|
||||
ASSISTANT:
|
||||
The smartphone industry is showing strong recovery. The global smartphone market grew by 7.8% in Q3 2024 [ID:45]. Samsung leads with 20.6% market share, followed by Apple at 15.8% [ID:45]. This growth is partly driven by 5G technology, which reached 1.5 billion users worldwide by October 2024 [ID:46].
|
||||
|
||||
<USER>: What's Elon's view on dogecoin?
|
||||
## Example 2: Causal Relationships and Predictions
|
||||
<context>
|
||||
ID: 78
|
||||
└── Content: Rising interest rates have led to a 15% decline in housing starts. Economists predict this trend will continue through 2025.
|
||||
|
||||
<ASSISTANT>:
|
||||
Musk has consistently expressed his fondness for Dogecoin, often citing its humor and the inclusion of dogs in its branding. He has referred to it as his favorite cryptocurrency [ID:0] [ID:1].
|
||||
ID: 79
|
||||
└── Content: The Federal Reserve's aggressive monetary policy is the primary driver behind the real estate slowdown.
|
||||
</context>
|
||||
|
||||
Recently, Musk has hinted at potential future roles for Dogecoin. His tweets have sparked speculation about Dogecoin's potential integration into public services [ID:3].
|
||||
USER: What's happening in the real estate market?
|
||||
|
||||
Overall, while Musk enjoys Dogecoin and often promotes it, he also warns against over-investing in it, reflecting both his personal amusement and caution regarding its speculative nature.
|
||||
ASSISTANT:
|
||||
The real estate market is experiencing significant challenges. Rising interest rates have led to a 15% decline in housing starts [ID:78]. The Federal Reserve's aggressive monetary policy is the primary driver behind this real estate slowdown [ID:79]. Economists predict this trend will continue through 2025 [ID:78]. These factors suggest a prolonged period of market adjustment ahead.
|
||||
|
||||
## Example END
|
||||
## Example 3: Technical Concepts and Comparisons
|
||||
<context>
|
||||
ID: 101
|
||||
└── Content: Quantum computing uses quantum bits (qubits) that can exist in superposition, unlike classical bits which are either 0 or 1.
|
||||
|
||||
ID: 102
|
||||
└── Content: Current quantum computers can solve certain problems 158 million times faster than classical supercomputers.
|
||||
</context>
|
||||
|
||||
USER: Explain quantum computing advantages.
|
||||
|
||||
ASSISTANT:
|
||||
Quantum computing represents a fundamental shift in computational capability. Unlike classical computers, quantum computing uses quantum bits (qubits) that can exist in superposition, rather than being limited to just 0 or 1 [ID:101]. This unique property enables extraordinary performance gains. Current quantum computers can solve certain problems 158 million times faster than classical supercomputers [ID:102]. This technology promises to revolutionize fields requiring complex calculations.
|
||||
|
||||
## Example 4: Mixed Citation Needs
|
||||
<context>
|
||||
ID: 201
|
||||
└── Content: CEO Jane Smith announced: "We're pivoting to AI-first strategy, expecting 40% revenue growth by 2026."
|
||||
|
||||
ID: 202
|
||||
└── Content: Industry analysts remain skeptical, with Morgan Stanley rating the stock as "underperform."
|
||||
|
||||
ID: 203
|
||||
└── Content: The company's AI division currently represents only 5% of total revenue.
|
||||
</context>
|
||||
|
||||
USER: What's the company's new strategy?
|
||||
|
||||
ASSISTANT:
|
||||
The company is undergoing a major strategic transformation. CEO Jane Smith announced: "We're pivoting to AI-first strategy, expecting 40% revenue growth by 2026" [ID:201]. However, this ambitious goal faces skepticism, as industry analysts remain doubtful, with Morgan Stanley rating the stock as "underperform" [ID:202]. The challenge is significant given that the company's AI division currently represents only 5% of total revenue [ID:203]. Success will depend on execution of this bold vision.
|
||||
|
||||
## Example 5: When NOT to Cite
|
||||
<context>
|
||||
ID: 301
|
||||
└── Content: Paris is the capital of France. The Eiffel Tower was completed in 1889.
|
||||
</context>
|
||||
|
||||
USER: Tell me about Paris.
|
||||
|
||||
ASSISTANT:
|
||||
Paris is the capital of France. It's known for its rich history, culture, and architecture. The Eiffel Tower was completed in 1889 [ID:301]. The city attracts millions of tourists annually. Paris remains one of the world's most visited destinations.
|
||||
(Note: Only the specific date needs citation, not common knowledge about Paris)
|
||||
|
||||
--- Examples END ---
|
||||
|
||||
REMEMBER:
|
||||
- Cite FACTS, not opinions or transitions
|
||||
- Each citation supports the ENTIRE sentence
|
||||
- When in doubt, ask: "Would a fact-checker need to verify this?"
|
||||
- Place citations at sentence end, before punctuation
|
||||
rag/prompts/next_step.md (new file, 63 lines)
@ -0,0 +1,63 @@
|
||||
You are an expert Planning Agent tasked with solving problems efficiently through structured plans.
|
||||
Your job is:
|
||||
1. Based on the task analysis, choose the right tools to execute.
|
||||
2. Track progress and adapt plans (tool calls) when necessary.
|
||||
3. Use `complete_task` if there is no further step you need to take with the tools (either all necessary steps are done, or there is little hope they can be).
|
||||
|
||||
# ========== TASK ANALYSIS =============
|
||||
{{ task_analisys }}
|
||||
|
||||
|
||||
# ========== TOOLS (JSON-Schema) ==========
|
||||
You may invoke only the tools listed below.
|
||||
Return a JSON array of objects in which each item has exactly two top-level keys:
|
||||
• "name": the tool to call
|
||||
• "arguments": an object whose keys/values satisfy the schema
|
||||
|
||||
{{ desc }}
|
||||
|
||||
# ========== RESPONSE FORMAT ==========
|
||||
✦ **When you need a tool**
|
||||
Return ONLY the JSON (no additional keys, no commentary, end with `<|stop|>`), as in the following:
|
||||
[{
|
||||
"name": "<tool_name1>",
|
||||
"arguments": { /* tool arguments matching its schema */ }
|
||||
},{
|
||||
"name": "<tool_name2>",
|
||||
"arguments": { /* tool arguments matching its schema */ }
|
||||
}...]<|stop|>
|
||||
|
||||
✦ **When you are certain the task is solved OR no further information can be obtained**
|
||||
Return ONLY:
|
||||
[{
|
||||
"name": "complete_task",
|
||||
"arguments": { "answer": "<final answer text>" }
|
||||
}]<|stop|>
|
||||
|
||||
<verification_steps>
|
||||
Before providing a final answer:
|
||||
1. Double-check all gathered information
|
||||
2. Verify calculations and logic
|
||||
3. Ensure answer matches exactly what was asked
|
||||
4. Confirm answer format meets requirements
|
||||
5. Run additional verification if confidence is not 100%
|
||||
</verification_steps>
|
||||
|
||||
<error_handling>
|
||||
If you encounter issues:
|
||||
1. Try alternative approaches before giving up
|
||||
2. Use different tools or combinations of tools
|
||||
3. Break complex problems into simpler sub-tasks
|
||||
4. Verify intermediate results frequently
|
||||
5. Never return "I cannot answer" without exhausting all options
|
||||
</error_handling>
|
||||
|
||||
⚠️ Any output that is not valid JSON or that contains extra fields will be rejected.
|
||||
|
||||
# ========== REASONING & REFLECTION ==========
|
||||
You may think privately (not shown to the user) before producing each JSON object.
|
||||
Internal guideline:
|
||||
1. **Reason**: Analyse the user question; decide which tools (if any) are needed.
|
||||
2. **Act**: Emit the JSON object to call the tool.
|
||||
|
||||
Today is {{ today }}. Remember that success in answering questions accurately is paramount - take all necessary steps to ensure your answer is correct.
|
||||
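The `next_step.md` contract above means the model's reply is a JSON array of tool calls terminated by `<|stop|>`. A hedged sketch of how a caller might consume such a reply (the raw string is hypothetical; `json_repair` is the parser already used elsewhere in this PR):

```python
import json_repair

raw = '[{"name": "web_search", "arguments": {"query": "RAGFlow agent"}}]<|stop|>'
calls = json_repair.loads(raw.split("<|stop|>")[0])
for call in calls:
    print(call["name"], call["arguments"])
```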
@ -1,8 +1,7 @@
|
||||
import os
|
||||
|
||||
BASE_DIR = os.path.dirname(__file__)
|
||||
|
||||
PROMPT_DIR = os.path.join(BASE_DIR, "prompts")
|
||||
PROMPT_DIR = os.path.dirname(__file__)
|
||||
|
||||
_loaded_prompts = {}
|
||||
|
||||
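With the prompt templates now living in `rag/prompts/` next to the loader, `PROMPT_DIR` is simply the module's own directory. A sketch of a loader consistent with this layout (the body of `load_prompt` is not part of this diff, so treat it as an assumption):

```python
import os

PROMPT_DIR = os.path.dirname(__file__)
_loaded_prompts = {}

def load_prompt(name: str) -> str:
    # Assumed behavior: read and cache "<name>.md" from the prompts package directory.
    if name not in _loaded_prompts:
        with open(os.path.join(PROMPT_DIR, f"{name}.md"), "r", encoding="utf-8") as f:
            _loaded_prompts[name] = f.read()
    return _loaded_prompts[name]
```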
@ -17,19 +17,25 @@ import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Tuple
|
||||
import jinja2
|
||||
import json_repair
|
||||
|
||||
from rag.prompt_template import load_prompt
|
||||
from api.utils import hash_str2int
|
||||
from rag.prompts.prompt_template import load_prompt
|
||||
from rag.settings import TAG_FLD
|
||||
from rag.utils import encoder, num_tokens_from_string
|
||||
|
||||
|
||||
STOP_TOKEN="<|STOP|>"
|
||||
COMPLETE_TASK="complete_task"
|
||||
|
||||
|
||||
def get_value(d, k1, k2):
|
||||
return d.get(k1, d.get(k2))
|
||||
|
||||
|
||||
def chunks_format(reference):
|
||||
def get_value(d, k1, k2):
|
||||
return d.get(k1, d.get(k2))
|
||||
|
||||
return [
|
||||
{
|
||||
@ -87,14 +93,16 @@ def message_fit_in(msg, max_length=4000):
|
||||
return max_length, msg
|
||||
|
||||
|
||||
def kb_prompt(kbinfos, max_tokens):
|
||||
def kb_prompt(kbinfos, max_tokens, hash_id=False):
|
||||
from api.db.services.document_service import DocumentService
|
||||
|
||||
knowledges = [ck["content_with_weight"] for ck in kbinfos["chunks"]]
|
||||
knowledges = [get_value(ck, "content", "content_with_weight") for ck in kbinfos["chunks"]]
|
||||
kwlg_len = len(knowledges)
|
||||
used_token_count = 0
|
||||
chunks_num = 0
|
||||
for i, c in enumerate(knowledges):
|
||||
if not c:
|
||||
continue
|
||||
used_token_count += num_tokens_from_string(c)
|
||||
chunks_num += 1
|
||||
if max_tokens * 0.97 < used_token_count:
|
||||
@ -102,29 +110,30 @@ def kb_prompt(kbinfos, max_tokens):
|
||||
logging.warning(f"Not all the retrieval into prompt: {len(knowledges)}/{kwlg_len}")
|
||||
break
|
||||
|
||||
docs = DocumentService.get_by_ids([ck["doc_id"] for ck in kbinfos["chunks"][:chunks_num]])
|
||||
docs = DocumentService.get_by_ids([get_value(ck, "doc_id", "document_id") for ck in kbinfos["chunks"][:chunks_num]])
|
||||
docs = {d.id: d.meta_fields for d in docs}
|
||||
|
||||
doc2chunks = defaultdict(lambda: {"chunks": [], "meta": []})
|
||||
for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
|
||||
cnt = f"---\nID: {i}\n" + (f"URL: {ck['url']}\n" if "url" in ck else "")
|
||||
cnt += re.sub(r"( style=\"[^\"]+\"|</?(html|body|head|title)>|<!DOCTYPE html>)", " ", ck["content_with_weight"], flags=re.DOTALL | re.IGNORECASE)
|
||||
doc2chunks[ck["docnm_kwd"]]["chunks"].append(cnt)
|
||||
doc2chunks[ck["docnm_kwd"]]["meta"] = docs.get(ck["doc_id"], {})
|
||||
def draw_node(k, line):
|
||||
if not line:
|
||||
return ""
|
||||
return f"\n├── {k}: " + re.sub(r"\n+", " ", line, flags=re.DOTALL)
|
||||
|
||||
knowledges = []
|
||||
for nm, cks_meta in doc2chunks.items():
|
||||
txt = f"\nDocument: {nm} \n"
|
||||
for k, v in cks_meta["meta"].items():
|
||||
txt += f"{k}: {v}\n"
|
||||
txt += "Relevant fragments as following:\n"
|
||||
for i, chunk in enumerate(cks_meta["chunks"], 1):
|
||||
txt += f"{chunk}\n"
|
||||
knowledges.append(txt)
|
||||
for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
|
||||
cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 100))
|
||||
cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name"))
|
||||
cnt += draw_node("URL", ck['url']) if "url" in ck else ""
|
||||
for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items():
|
||||
cnt += draw_node(k, v)
|
||||
cnt += "\n└── Content:\n"
|
||||
cnt += get_value(ck, "content", "content_with_weight")
|
||||
knowledges.append(cnt)
|
||||
|
||||
return knowledges
|
||||
|
||||
|
||||
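With `draw_node`, `kb_prompt` now renders each retrieved chunk as a small tree instead of grouping fragments per document. A hedged sketch of the per-chunk string this produces (field values are hypothetical; `hash_id=True` swaps the running index for a hashed chunk id):

```python
# ID: 0
# ├── Title: product_manual.pdf
# ├── URL: https://example.com/manual        (only when the chunk carries a URL)
# ├── author: Jane Doe                       (document meta fields, one node per key)
# └── Content:
# <the chunk text>
```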
CITATION_PROMPT_TEMPLATE = load_prompt("citation_prompt")
|
||||
CITATION_PLUS_TEMPLATE = load_prompt("citation_plus")
|
||||
CONTENT_TAGGING_PROMPT_TEMPLATE = load_prompt("content_tagging_prompt")
|
||||
CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE = load_prompt("cross_languages_sys_prompt")
|
||||
CROSS_LANGUAGES_USER_PROMPT_TEMPLATE = load_prompt("cross_languages_user_prompt")
|
||||
@ -134,6 +143,13 @@ QUESTION_PROMPT_TEMPLATE = load_prompt("question_prompt")
|
||||
VISION_LLM_DESCRIBE_PROMPT = load_prompt("vision_llm_describe_prompt")
|
||||
VISION_LLM_FIGURE_DESCRIBE_PROMPT = load_prompt("vision_llm_figure_describe_prompt")
|
||||
|
||||
ANALYZE_TASK_SYSTEM = load_prompt("analyze_task_system")
|
||||
ANALYZE_TASK_USER = load_prompt("analyze_task_user")
|
||||
NEXT_STEP = load_prompt("next_step")
|
||||
REFLECT = load_prompt("reflect")
|
||||
SUMMARY4MEMORY = load_prompt("summary4memory")
|
||||
RANK_MEMORY = load_prompt("rank_memory")
|
||||
|
||||
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
|
||||
|
||||
|
||||
@ -142,6 +158,11 @@ def citation_prompt() -> str:
|
||||
return template.render()
|
||||
|
||||
|
||||
def citation_plus(sources: str) -> str:
|
||||
template = PROMPT_JINJA_ENV.from_string(CITATION_PLUS_TEMPLATE)
|
||||
return template.render(example=citation_prompt(), sources=sources)
|
||||
|
||||
|
||||
def keyword_extraction(chat_mdl, content, topn=3):
|
||||
template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE)
|
||||
rendered_prompt = template.render(content=content, topn=topn)
|
||||
@ -172,15 +193,16 @@ def question_proposal(chat_mdl, content, topn=3):
|
||||
return kwd
|
||||
|
||||
|
||||
def full_question(tenant_id, llm_id, messages, language=None):
|
||||
def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.llm_service import TenantLLMService
|
||||
|
||||
if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
||||
if not chat_mdl:
|
||||
if TenantLLMService.llm_id2llm_type(llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.IMAGE2TEXT, llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
|
||||
conv = []
|
||||
for m in messages:
|
||||
if m["role"] not in ["user", "assistant"]:
|
||||
@ -200,7 +222,7 @@ def full_question(tenant_id, llm_id, messages, language=None):
|
||||
language=language,
|
||||
)
|
||||
|
||||
ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.2})
|
||||
ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}])
|
||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
|
||||
|
||||
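`full_question` can now be handed a ready-made chat model directly and only falls back to building one from the tenant/LLM ids when `chat_mdl` is omitted. A hedged usage sketch (the ids and the `chat_mdl` object are hypothetical):

```python
from rag.prompts import full_question  # re-exported via the new __init__.py

messages = [{"role": "user", "content": "And how much is it in EUR?"}]

# Preferred in the new agent code: reuse an existing model binding.
ans = full_question(messages=messages, chat_mdl=chat_mdl)  # chat_mdl: an LLMBundle-like object

# Legacy call sites keep working unchanged:
ans = full_question(tenant_id="tenant-001", llm_id="qwen-plus@Tongyi-Qianwen", messages=messages)
```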
@ -278,13 +300,116 @@ def vision_llm_figure_describe_prompt() -> str:
|
||||
return template.render()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(CITATION_PROMPT_TEMPLATE)
|
||||
print(CONTENT_TAGGING_PROMPT_TEMPLATE)
|
||||
print(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE)
|
||||
print(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE)
|
||||
print(FULL_QUESTION_PROMPT_TEMPLATE)
|
||||
print(KEYWORD_PROMPT_TEMPLATE)
|
||||
print(QUESTION_PROMPT_TEMPLATE)
|
||||
print(VISION_LLM_DESCRIBE_PROMPT)
|
||||
print(VISION_LLM_FIGURE_DESCRIBE_PROMPT)
|
||||
def tool_schema(tools_description: list[dict], complete_task=False):
    if not tools_description:
        return ""
    desc = {}
    if complete_task:
        desc[COMPLETE_TASK] = {
            "type": "function",
            "function": {
                "name": COMPLETE_TASK,
                "description": "When you have the final answer and are ready to complete the task, call this function with your answer",
                "parameters": {
                    "type": "object",
                    "properties": {"answer": {"type": "string", "description": "The final answer to the user's question"}},
                    "required": ["answer"]
                }
            }
        }
    for tool in tools_description:
        desc[tool["function"]["name"]] = tool

    return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])
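For review purposes, a sketch of what `tool_schema` produces. The tool description below is made up; any OpenAI-style function definition works.

```python
# Illustrative OpenAI-style tool description (not part of this PR).
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

schema_text = tool_schema(tools, complete_task=True)
# schema_text is a Markdown listing with one numbered section per tool, e.g.
#   ## 1. complete_task
#   { ...the built-in completion schema above... }
#
#   ## 2. get_weather
#   { ...the JSON description passed in... }
```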


def form_history(history, limit=-6):
    context = ""
    for h in history[limit:]:
        if h["role"] == "system":
            continue
        role = "USER"
        if h["role"].upper() != role:
            role = "AGENT"
        context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content']) > 2048 else '')}"
    return context
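A small sketch of the rendered history: system turns are dropped, user turns keep the USER label, everything else becomes AGENT, and each turn is truncated to 2048 characters.

```python
history = [
    {"role": "system", "content": "You are a helpful agent."},  # skipped
    {"role": "user", "content": "Find flights to Tokyo."},
    {"role": "assistant", "content": "Sure, one moment."},
]
print(form_history(history))
# USER: Find flights to Tokyo.
# AGENT: Sure, one moment.
```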


def analyze_task(chat_mdl, task_name, tools_description: list[dict]):
    tools_desc = tool_schema(tools_description)
    context = ""

    template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_USER)

    kwd = chat_mdl.chat(ANALYZE_TASK_SYSTEM, [{"role": "user", "content": template.render(task=task_name, context=context, tools_desc=tools_desc)}], {})
    if isinstance(kwd, tuple):
        kwd = kwd[0]
    kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
    if kwd.find("**ERROR**") >= 0:
        return ""
    return kwd
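Usage is a single call; the sketch below assumes `chat_mdl` is a chat-capable LLMBundle and `tools` is the same OpenAI-style list handed to `tool_schema`. An empty string comes back when the model reports **ERROR**.

```python
# Hypothetical task text; returns the model's analysis of the task, or "" on error.
task_desc = analyze_task(chat_mdl, "Summarize this week's sales numbers", tools)
```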


def next_step(chat_mdl, history: list, tools_description: list[dict], task_desc):
    if not tools_description:
        return ""
    desc = tool_schema(tools_description)
    template = PROMPT_JINJA_ENV.from_string(NEXT_STEP)
    user_prompt = "\nWhat's the next tool to call? If ready OR IMPOSSIBLE TO BE READY, then call `complete_task`."
    hist = deepcopy(history)
    if hist[-1]["role"] == "user":
        hist[-1]["content"] += user_prompt
    else:
        hist.append({"role": "user", "content": user_prompt})
    json_str = chat_mdl.chat(template.render(task_analisys=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
                             hist[1:], stop=["<|stop|>"])
    tk_cnt = num_tokens_from_string(json_str)
    json_str = re.sub(r"^.*</think>", "", json_str, flags=re.DOTALL)
    return json_str, tk_cnt
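A sketch of consuming the result. The exact JSON layout is dictated by rag/prompts/next_step.md, which is not shown in this diff, so the parsing below is an assumption used only to illustrate the flow.

```python
import json

raw, tokens_used = next_step(chat_mdl, history, tools, task_desc)
try:
    # Assumed shape: a JSON object describing the next tool call (or complete_task).
    decision = json.loads(raw)
except json.JSONDecodeError:
    decision = None  # retry or fall back when the model strays from JSON
```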


def reflect(chat_mdl, history: list[dict], tool_call_res: list[Tuple]):
    tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
    goal = history[1]["content"]
    template = PROMPT_JINJA_ENV.from_string(REFLECT)
    user_prompt = template.render(goal=goal, tool_calls=tool_calls)
    hist = deepcopy(history)
    if hist[-1]["role"] == "user":
        hist[-1]["content"] += user_prompt
    else:
        hist.append({"role": "user", "content": user_prompt})
    _, msg = message_fit_in(hist, chat_mdl.max_length)
    ans = chat_mdl.chat(msg[0]["content"], msg[1:])
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    return """
**Observation**
{}

**Reflection**
{}
""".format(json.dumps(tool_calls, ensure_ascii=False, indent=2), ans)
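A usage sketch, assuming the tool results were collected as `(name, result)` tuples and that `history[1]` holds the user's goal (which is what `reflect` reads).

```python
tool_call_res = [
    ("get_weather", "Success: Tokyo temperature is 78°F."),
    ("check_inventory", "Available: 12 widgets in stock (max 5 per customer)."),
]
observation_and_reflection = reflect(chat_mdl, history, tool_call_res)
# Returns the **Observation** block (the tool calls as JSON) followed by the
# model-written **Reflection** text.
```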


def form_message(system_prompt, user_prompt):
    return [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]


def tool_call_summary(chat_mdl, name: str, params: dict, result: str) -> str:
    template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
    system_prompt = template.render(name=name,
                                    params=json.dumps(params, ensure_ascii=False, indent=2),
                                    result=result)
    user_prompt = "→ Summary: "
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    ans = chat_mdl.chat(msg[0]["content"], msg[1:])
    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
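A sketch of condensing a raw tool response into a one-line memory entry. The payload mirrors the examples in summary4memory.md; the tool name and params are placeholders.

```python
summary = tool_call_summary(
    chat_mdl,
    name="get_weather",
    params={"city": "Tokyo"},
    result='{"status": "success", "temperature": 78.2, "unit": "F", "location": "Tokyo"}',
)
# Expected to come back roughly as: 'Success: Tokyo temperature is 78°F.'
```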


def rank_memories(chat_mdl, goal: str, sub_goal: str, tool_call_summaries: list[str]):
    template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
    system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i, s in enumerate(tool_call_summaries)])
    user_prompt = " → rank: "
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    ans = chat_mdl.chat(msg[0]["content"], msg[1:], stop="<|stop|>")
    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
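rank_memory.md asks the model for a bare Python-style list of indices (e.g. `[1, 2, 0]`), so a caller could reorder the summaries as below; the `ast.literal_eval` parsing and the fallback are assumptions, not part of this PR.

```python
import ast

raw_rank = rank_memories(chat_mdl, goal, sub_goal, summaries)
try:
    order = ast.literal_eval(raw_rank.strip())  # e.g. [1, 2, 0]
except (ValueError, SyntaxError):
    order = list(range(len(summaries)))  # keep original order if parsing fails
ranked = [summaries[i] for i in order if isinstance(i, int) and 0 <= i < len(summaries)]
```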

30 rag/prompts/rank_memory.md Normal file
@ -0,0 +1,30 @@
**Task**: Sort the tool call results based on relevance to the overall goal and current sub-goal. Return ONLY a sorted list of indices (0-indexed).

**Rules**:
1. Analyze each result's contribution to both:
   - The overall goal (primary priority)
   - The current sub-goal (secondary priority)
2. Sort from MOST relevant (highest impact) to LEAST relevant
3. Output format: Strictly a Python-style list of integers. Example: [2, 0, 1]

🔹 Overall Goal: {{ goal }}
🔹 Sub-goal: {{ sub_goal }}

**Examples**:
🔹 Tool Response:
- index: 0
> Tokyo temperature is 78°F.
- index: 1
> Error: Authentication failed (expired API key).
- index: 2
> Available: 12 widgets in stock (max 5 per customer).

→ rank: [1,2,0]<|stop|>


**Your Turn**:
🔹 Tool Response:
{% for f in results %}
- index: {{ f.i }}
> {{ f.content }}
{% endfor %}
34 rag/prompts/reflect.md Normal file
@ -0,0 +1,34 @@
**Context**:
- To achieve the goal: {{ goal }}.
- You have executed the following tool calls:
{% for call in tool_calls %}
Tool call: `{{ call.name }}`
Results: {{ call.result }}
{% endfor %}


**Reflection Instructions:**

Analyze the current state of the overall task ({{ goal }}), then provide structured responses to the following:

## 1. Goal Achievement Status
- Does the current outcome align with the original purpose of this task phase?
- If not, what critical gaps exist?

## 2. Step Completion Check
- Which planned steps were completed? (List verified items)
- Which steps are pending/incomplete? (Specify exactly what's missing)

## 3. Information Adequacy
- Is the collected data sufficient to proceed?
- What key information is still needed? (e.g., metrics, user input, external data)

## 4. Critical Observations
- Unexpected outcomes: [Flag anomalies/errors]
- Risks/blockers: [Identify immediate obstacles]
- Accuracy concerns: [Highlight unreliable results]

## 5. Next-Step Recommendations
- Proposed immediate action: [Concrete next step]
- Alternative strategies if blocked: [Workaround solution]
- Tools/inputs required for next phase: [Specify resources]
35 rag/prompts/summary4memory.md Normal file
@ -0,0 +1,35 @@
**Role**: AI Assistant
**Task**: Summarize tool call responses
**Rules**:
1. Context: You've executed a tool (API/function) and received a response.
2. Condense the response into 1-2 short sentences.
3. Never omit:
   - Success/error status
   - Core results (e.g., data points, decisions)
   - Critical constraints (e.g., limits, conditions)
4. Exclude technical details like timestamps/request IDs unless crucial.
5. Use the same language as the main content of the tool response.

**Response Template**:
"[Status] + [Key Outcome] + [Critical Constraints]"

**Examples**:
🔹 Tool Response:
{"status": "success", "temperature": 78.2, "unit": "F", "location": "Tokyo", "timestamp": 16923456}
→ Summary: "Success: Tokyo temperature is 78°F."

🔹 Tool Response:
{"error": "invalid_api_key", "message": "Authentication failed: expired key"}
→ Summary: "Error: Authentication failed (expired API key)."

🔹 Tool Response:
{"available": true, "inventory": 12, "product": "widget", "limit": "max 5 per customer"}
→ Summary: "Available: 12 widgets in stock (max 5 per customer)."

**Your Turn**:
- Tool call: {{ name }}
- Tool inputs as follows:
{{ params }}

- Tool Response:
{{ result }}
19 rag/prompts/tool_call_summary.md Normal file
@ -0,0 +1,19 @@
**Task Instruction:**

You are tasked with reading and analyzing the tool call results based on the following inputs: **Inputs for current call** and **Results**. Your objective is to extract relevant and helpful information for the **Inputs for current call** from the **Results** and seamlessly integrate this information into the previous steps to continue reasoning for the original question.

**Guidelines:**

1. **Analyze the Results:**
   - Carefully review the content of each tool call result.
   - Identify factual information that is relevant to the **Inputs for current call** and can aid in the reasoning process for the original question.

2. **Extract Relevant Information:**
   - Select the information from the **Results** that directly contributes to advancing the previous reasoning steps.
   - Ensure that the extracted information is accurate and relevant.

- **Inputs for current call:**
{{ inputs }}

- **Results:**
{{ results }}
@ -239,7 +239,17 @@ def shutdown_all_mcp_sessions():
    logging.info("All MCPToolCallSession instances have been closed.")


def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]:
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]:
    if isinstance(mcp_tool, dict):
        return {
            "type": "function",
            "function": {
                "name": mcp_tool["name"],
                "description": mcp_tool["description"],
                "parameters": mcp_tool["inputSchema"],
            },
        }

    return {
        "type": "function",
        "function": {