Fix: tokenizer issue. (#11902)

#11786
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Kevin Hu
2025-12-11 17:38:17 +08:00
committed by GitHub
parent 22a51a3868
commit ea4a5cd665
17 changed files with 141 additions and 216 deletions

View File

@ -170,13 +170,13 @@ def citation_plus(sources: str) -> str:
return template.render(example=citation_prompt(), sources=sources)
def keyword_extraction(chat_mdl, content, topn=3):
async def keyword_extraction(chat_mdl, content, topn=3):
template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE)
rendered_prompt = template.render(content=content, topn=topn)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@ -185,13 +185,13 @@ def keyword_extraction(chat_mdl, content, topn=3):
return kwd
def question_proposal(chat_mdl, content, topn=3):
async def question_proposal(chat_mdl, content, topn=3):
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
rendered_prompt = template.render(content=content, topn=topn)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@ -200,7 +200,7 @@ def question_proposal(chat_mdl, content, topn=3):
return kwd
def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
async def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
@ -229,12 +229,12 @@ def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_
language=language,
)
ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}])
ans = await chat_mdl.async_chat(rendered_prompt, [{"role": "user", "content": "Output: "}])
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
def cross_languages(tenant_id, llm_id, query, languages=[]):
async def cross_languages(tenant_id, llm_id, query, languages=[]):
from common.constants import LLMType
from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService
@ -247,14 +247,14 @@ def cross_languages(tenant_id, llm_id, query, languages=[]):
rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)
ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
if ans.find("**ERROR**") >= 0:
return query
return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()])
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
async def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE)
for ex in examples:
@ -269,7 +269,7 @@ def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5})
kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.5})
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@ -352,7 +352,7 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
else:
template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
kwd = await _chat_async(chat_mdl, context, [{"role": "user", "content": "Please analyze it."}])
kwd = await chat_mdl.async_chat(context, [{"role": "user", "content": "Please analyze it."}])
if isinstance(kwd, tuple):
kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@ -361,14 +361,6 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
return kwd
async def _chat_async(chat_mdl, system: str, history: list, **kwargs):
chat_async = getattr(chat_mdl, "async_chat", None)
if chat_async and asyncio.iscoroutinefunction(chat_async):
return await chat_async(system, history, **kwargs)
return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs)
async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
if not tools_description:
return "", 0
@ -380,8 +372,7 @@ async def next_step_async(chat_mdl, history:list, tools_description: list[dict],
hist[-1]["content"] += user_prompt
else:
hist.append({"role": "user", "content": user_prompt})
json_str = await _chat_async(
chat_mdl,
json_str = await chat_mdl.async_chat(
template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
hist[1:],
stop=["<|stop|>"],
@ -402,7 +393,7 @@ async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple
else:
hist.append({"role": "user", "content": user_prompt})
_, msg = message_fit_in(hist, chat_mdl.max_length)
ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:])
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return """
**Observation**
@ -422,14 +413,14 @@ def structured_output_prompt(schema=None) -> str:
return template.render(schema=schema)
def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
system_prompt = template.render(name=name,
params=json.dumps(params, ensure_ascii=False, indent=2),
result=result)
user_prompt = "→ Summary: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:])
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
@ -438,11 +429,11 @@ async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summar
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
user_prompt = " → rank: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:], stop="<|stop|>")
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], stop="<|stop|>")
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
meta_data_structure = {}
for key, values in meta_data.items():
meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values
@ -453,7 +444,7 @@ def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
user_question=query
)
user_prompt = "Generate filters:"
ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}])
ans = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}])
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
try:
ans = json_repair.loads(ans)
@ -466,13 +457,13 @@ def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
return {"conditions": []}
def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
from graphrag.utils import get_llm_cache, set_llm_cache
cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
if cached:
return json_repair.loads(cached)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
try:
res = json_repair.loads(ans)
@ -483,10 +474,10 @@ def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
TOC_DETECTION = load_prompt("toc_detection")
def detect_table_of_contents(page_1024:list[str], chat_mdl):
async def detect_table_of_contents(page_1024:list[str], chat_mdl):
toc_secs = []
for i, sec in enumerate(page_1024[:22]):
ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
if toc_secs and not ans["exists"]:
break
toc_secs.append(sec)
@ -495,14 +486,14 @@ def detect_table_of_contents(page_1024:list[str], chat_mdl):
TOC_EXTRACTION = load_prompt("toc_extraction")
TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")
def extract_table_of_contents(toc_pages, chat_mdl):
async def extract_table_of_contents(toc_pages, chat_mdl):
if not toc_pages:
return []
return gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)
return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)
def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
tob_extractor_prompt = """
You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format.
@ -525,11 +516,11 @@ def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
Directly return the final JSON structure. Do not output anything else."""
prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content
return gen_json(prompt, "Only JSON please.", chat_mdl)
return await gen_json(prompt, "Only JSON please.", chat_mdl)
TOC_INDEX = load_prompt("toc_index")
def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
if not toc_arr or not sections:
return []
@ -601,7 +592,7 @@ def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
e = toc_arr[e]["indices"][0]
for j in range(st_i, min(e+1, len(sections))):
ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
structure=it["structure"],
title=it["title"],
text=sections[j]), "Only JSON please.", chat_mdl)
@ -614,7 +605,7 @@ def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
return toc_arr
def check_if_toc_transformation_is_complete(content, toc, chat_mdl):
async def check_if_toc_transformation_is_complete(content, toc, chat_mdl):
prompt = """
You are given a raw table of contents and a table of contents.
Your job is to check if the table of contents is complete.
@ -627,11 +618,11 @@ def check_if_toc_transformation_is_complete(content, toc, chat_mdl):
Directly return the final JSON structure. Do not output anything else."""
prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc
response = gen_json(prompt, "Only JSON please.", chat_mdl)
response = await gen_json(prompt, "Only JSON please.", chat_mdl)
return response['completed']
def toc_transformer(toc_pages, chat_mdl):
async def toc_transformer(toc_pages, chat_mdl):
init_prompt = """
You are given a table of contents, You job is to transform the whole table of content into a JSON format included table_of_contents.
@ -654,8 +645,8 @@ def toc_transformer(toc_pages, chat_mdl):
def clean_toc(arr):
for a in arr:
a["title"] = re.sub(r"[.·….]{2,}", "", a["title"])
last_complete = gen_json(prompt, "Only JSON please.", chat_mdl)
if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
last_complete = await gen_json(prompt, "Only JSON please.", chat_mdl)
if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
clean_toc(last_complete)
if if_complete == "yes":
return last_complete
@ -672,21 +663,21 @@ def toc_transformer(toc_pages, chat_mdl):
{json.dumps(last_complete[-24:], ensure_ascii=False, indent=2)}
Please continue the json structure, directly output the remaining part of the json structure."""
new_complete = gen_json(prompt, "Only JSON please.", chat_mdl)
new_complete = await gen_json(prompt, "Only JSON please.", chat_mdl)
if not new_complete or str(last_complete).find(str(new_complete)) >= 0:
break
clean_toc(new_complete)
last_complete.extend(new_complete)
if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
return last_complete
TOC_LEVELS = load_prompt("assign_toc_levels")
def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
if not toc_secs:
return []
return gen_json(
return await gen_json(
PROMPT_JINJA_ENV.from_string(TOC_LEVELS).render(),
str(toc_secs),
chat_mdl,
@ -699,7 +690,7 @@ TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
# Generate TOC from text chunks with text llms
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
try:
ans = gen_json(
ans = await gen_json(
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
chat_mdl,
@ -782,7 +773,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
raw_structure = [x.get("title", "") for x in filtered]
# Assign hierarchy levels using LLM
toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9})
toc_with_levels = await assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9})
if not toc_with_levels:
return []
@ -807,10 +798,10 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")
def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
import numpy as np
try:
ans = gen_json(
ans = await gen_json(
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(),
PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])),
chat_mdl,