Fix IDE warnings (#12281)

### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
Jin Hai
2025-12-29 12:01:18 +08:00
committed by GitHub
parent 647fb115a0
commit 01f0ced1e6
43 changed files with 817 additions and 637 deletions

View File

@ -3,4 +3,4 @@ from . import generator
__all__ = [name for name in dir(generator)
if not name.startswith('_')]
globals().update({name: getattr(generator, name) for name in __all__})
globals().update({name: getattr(generator, name) for name in __all__})

View File

@ -28,17 +28,16 @@ from rag.prompts.template import load_prompt
from common.constants import TAG_FLD
from common.token_utils import encoder, num_tokens_from_string
STOP_TOKEN="<|STOP|>"
COMPLETE_TASK="complete_task"
STOP_TOKEN = "<|STOP|>"
COMPLETE_TASK = "complete_task"
INPUT_UTILIZATION = 0.5
def get_value(d, k1, k2):
return d.get(k1, d.get(k2))
def chunks_format(reference):
return [
{
"id": get_value(chunk, "chunk_id", "id"),
@ -126,7 +125,7 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False):
for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 500))
cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name"))
cnt += draw_node("URL", ck['url']) if "url" in ck else ""
cnt += draw_node("URL", ck['url']) if "url" in ck else ""
for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items():
cnt += draw_node(k, v)
cnt += "\n└── Content:\n"
@ -173,7 +172,7 @@ ASK_SUMMARY = load_prompt("ask_summary")
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
def citation_prompt(user_defined_prompts: dict={}) -> str:
def citation_prompt(user_defined_prompts: dict = {}) -> str:
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE))
return template.render()
@ -258,9 +257,11 @@ async def cross_languages(tenant_id, llm_id, query, languages=[]):
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)
rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)
rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query,
languages=languages)
ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}],
{"temperature": 0.2})
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
if ans.find("**ERROR**") >= 0:
return query
@ -332,7 +333,8 @@ def tool_schema(tools_description: list[dict], complete_task=False):
"description": "When you have the final answer and are ready to complete the task, call this function with your answer",
"parameters": {
"type": "object",
"properties": {"answer":{"type":"string", "description": "The final answer to the user's question"}},
"properties": {
"answer": {"type": "string", "description": "The final answer to the user's question"}},
"required": ["answer"]
}
}
@ -341,7 +343,8 @@ def tool_schema(tools_description: list[dict], complete_task=False):
name = tool["function"]["name"]
desc[name] = tool
return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])
return "\n\n".join([f"## {i + 1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in
enumerate(desc.items())])
def form_history(history, limit=-6):
@ -350,14 +353,14 @@ def form_history(history, limit=-6):
if h["role"] == "system":
continue
role = "USER"
if h["role"].upper()!= role:
if h["role"].upper() != role:
role = "AGENT"
context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content'])>2048 else '')}"
context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content']) > 2048 else '')}"
return context
async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict],
user_defined_prompts: dict = {}):
tools_desc = tool_schema(tools_description)
context = ""
@ -375,7 +378,8 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
return kwd
async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
async def next_step_async(chat_mdl, history: list, tools_description: list[dict], task_desc,
user_defined_prompts: dict = {}):
if not tools_description:
return "", 0
desc = tool_schema(tools_description)
@ -396,7 +400,7 @@ async def next_step_async(chat_mdl, history:list, tools_description: list[dict],
return json_str, tk_cnt
async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict = {}):
tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
goal = history[1]["content"]
template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
@ -419,7 +423,7 @@ async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple
def form_message(system_prompt, user_prompt):
return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}]
return [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]
def structured_output_prompt(schema=None) -> str:
@ -427,27 +431,29 @@ def structured_output_prompt(schema=None) -> str:
return template.render(schema=schema)
async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict = {}) -> str:
template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
system_prompt = template.render(name=name,
params=json.dumps(params, ensure_ascii=False, indent=2),
result=result)
params=json.dumps(params, ensure_ascii=False, indent=2),
result=result)
user_prompt = "→ Summary: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
async def rank_memories_async(chat_mdl, goal: str, sub_goal: str, tool_call_summaries: list[str],
user_defined_prompts: dict = {}):
template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
system_prompt = template.render(goal=goal, sub_goal=sub_goal,
results=[{"i": i, "content": s} for i, s in enumerate(tool_call_summaries)])
user_prompt = " → rank: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], stop="<|stop|>")
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
async def gen_meta_filter(chat_mdl, meta_data: dict, query: str) -> dict:
meta_data_structure = {}
for key, values in meta_data.items():
meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values
@ -471,13 +477,13 @@ async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
return {"conditions": []}
async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
async def gen_json(system_prompt: str, user_prompt: str, chat_mdl, gen_conf=None):
from graphrag.utils import get_llm_cache, set_llm_cache
cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
if cached:
return json_repair.loads(cached)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], gen_conf=gen_conf)
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
try:
res = json_repair.loads(ans)
@ -488,10 +494,13 @@ async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None
TOC_DETECTION = load_prompt("toc_detection")
async def detect_table_of_contents(page_1024:list[str], chat_mdl):
async def detect_table_of_contents(page_1024: list[str], chat_mdl):
toc_secs = []
for i, sec in enumerate(page_1024[:22]):
ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.",
chat_mdl)
if toc_secs and not ans["exists"]:
break
toc_secs.append(sec)
@ -500,14 +509,17 @@ async def detect_table_of_contents(page_1024:list[str], chat_mdl):
TOC_EXTRACTION = load_prompt("toc_extraction")
TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")
async def extract_table_of_contents(toc_pages, chat_mdl):
if not toc_pages:
return []
return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)
return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)),
"Only JSON please.", chat_mdl)
async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
async def toc_index_extractor(toc: list[dict], content: str, chat_mdl):
tob_extractor_prompt = """
You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format.
@ -529,18 +541,21 @@ async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
If the title of the section are not in the provided pages, do not add the physical_index to it.
Directly return the final JSON structure. Do not output anything else."""
prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content
prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False,
indent=2) + '\nDocument pages:\n' + content
return await gen_json(prompt, "Only JSON please.", chat_mdl)
TOC_INDEX = load_prompt("toc_index")
async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
if not toc_arr or not sections:
return []
toc_map = {}
for i, it in enumerate(toc_arr):
k1 = (it["structure"]+it["title"]).replace(" ", "")
k1 = (it["structure"] + it["title"]).replace(" ", "")
k2 = it["title"].strip()
if k1 not in toc_map:
toc_map[k1] = []
@ -558,6 +573,7 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
toc_arr[j]["indices"].append(i)
all_pathes = []
def dfs(start, path):
nonlocal all_pathes
if start >= len(toc_arr):
@ -565,7 +581,7 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
all_pathes.append(path)
return
if not toc_arr[start]["indices"]:
dfs(start+1, path)
dfs(start + 1, path)
return
added = False
for j in toc_arr[start]["indices"]:
@ -574,12 +590,12 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
_path = deepcopy(path)
_path.append((j, start))
added = True
dfs(start+1, _path)
dfs(start + 1, _path)
if not added and path:
all_pathes.append(path)
dfs(0, [])
path = max(all_pathes, key=lambda x:len(x))
path = max(all_pathes, key=lambda x: len(x))
for it in toc_arr:
it["indices"] = []
for j, i in path:
@ -588,24 +604,24 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
i = 0
while i < len(toc_arr):
it = toc_arr[i]
it = toc_arr[i]
if it["indices"]:
i += 1
continue
if i>0 and toc_arr[i-1]["indices"]:
st_i = toc_arr[i-1]["indices"][-1]
if i > 0 and toc_arr[i - 1]["indices"]:
st_i = toc_arr[i - 1]["indices"][-1]
else:
st_i = 0
e = i + 1
while e <len(toc_arr) and not toc_arr[e]["indices"]:
while e < len(toc_arr) and not toc_arr[e]["indices"]:
e += 1
if e >= len(toc_arr):
e = len(sections)
else:
e = toc_arr[e]["indices"][0]
for j in range(st_i, min(e+1, len(sections))):
for j in range(st_i, min(e + 1, len(sections))):
ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
structure=it["structure"],
title=it["title"],
@ -656,11 +672,15 @@ async def toc_transformer(toc_pages, chat_mdl):
toc_content = "\n".join(toc_pages)
prompt = init_prompt + '\n Given table of contents\n:' + toc_content
def clean_toc(arr):
for a in arr:
a["title"] = re.sub(r"[.·….]{2,}", "", a["title"])
last_complete = await gen_json(prompt, "Only JSON please.", chat_mdl)
if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
if_complete = await check_if_toc_transformation_is_complete(toc_content,
json.dumps(last_complete, ensure_ascii=False, indent=2),
chat_mdl)
clean_toc(last_complete)
if if_complete == "yes":
return last_complete
@ -682,13 +702,17 @@ async def toc_transformer(toc_pages, chat_mdl):
break
clean_toc(new_complete)
last_complete.extend(new_complete)
if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
if_complete = await check_if_toc_transformation_is_complete(toc_content,
json.dumps(last_complete, ensure_ascii=False,
indent=2), chat_mdl)
return last_complete
TOC_LEVELS = load_prompt("assign_toc_levels")
async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
async def assign_toc_levels(toc_secs, chat_mdl, gen_conf={"temperature": 0.2}):
if not toc_secs:
return []
return await gen_json(
@ -701,12 +725,15 @@ async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2})
TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
# Generate TOC from text chunks with text llms
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
try:
ans = await gen_json(
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(
text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
chat_mdl,
gen_conf={"temperature": 0.0, "top_p": 0.9}
)
@ -743,7 +770,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
)
input_budget = 1024 if input_budget > 1024 else input_budget
input_budget = 1024 if input_budget > 1024 else input_budget
chunk_sections = split_chunks(chunks, input_budget)
titles = []
@ -798,7 +825,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
if sorted_list:
max_lvl = sorted_list[-1]
merged = []
for _ , (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
for _, (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
if prune and toc_item.get("level", "0") >= max_lvl:
continue
merged.append({
@ -812,12 +839,15 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")
async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
async def relevant_chunks_with_toc(query: str, toc: list[dict], chat_mdl, topn: int = 6):
import numpy as np
try:
ans = await gen_json(
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(),
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])),
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n" % "\n".join(
[json.dumps({"level": d["level"], "title": d["title"]}, ensure_ascii=False) for d in toc])),
chat_mdl,
gen_conf={"temperature": 0.0, "top_p": 0.9}
)
@ -828,17 +858,19 @@ async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: i
for id in ti.get("ids", []):
if id not in id2score:
id2score[id] = []
id2score[id].append(sc["score"]/5.)
id2score[id].append(sc["score"] / 5.)
for id in id2score.keys():
id2score[id] = np.mean(id2score[id])
return [(id, sc) for id, sc in list(id2score.items()) if sc>=0.3][:topn]
return [(id, sc) for id, sc in list(id2score.items()) if sc >= 0.3][:topn]
except Exception as e:
logging.exception(e)
return []
META_DATA = load_prompt("meta_data")
async def gen_metadata(chat_mdl, schema:dict, content:str):
async def gen_metadata(chat_mdl, schema: dict, content: str):
template = PROMPT_JINJA_ENV.from_string(META_DATA)
for k, desc in schema["properties"].items():
if "enum" in desc and not desc.get("enum"):
@ -849,4 +881,4 @@ async def gen_metadata(chat_mdl, schema:dict, content:str):
user_prompt = "Output: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)

View File

@ -1,6 +1,5 @@
import os
PROMPT_DIR = os.path.dirname(__file__)
_loaded_prompts = {}