Fix IDE warnings (#12281)
### What problem does this PR solve?

As title

### Type of change

- [x] Refactoring

---------

Signed-off-by: Jin Hai <haijin.chn@gmail.com>
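The changes are almost entirely mechanical style fixes flagged by the IDE: spaces around `=` and comparison operators, `keyword = value` spacing for annotated defaults, and wrapping of over-long lines. A before/after sketch of the recurring patterns (the `summarize` function is a made-up example, not code from this repository):

```python
# Before: spacing the IDE flags
STOP_TOKEN="<|STOP|>"                                        # missing spaces around "="
def summarize(history:list, user_defined_prompts: dict={}):  # annotation / default spacing
    if history[0]["role"].upper()!= "USER":                  # missing space before "!="
        return ""

# After: the style applied throughout this commit
STOP_TOKEN = "<|STOP|>"
def summarize(history: list, user_defined_prompts: dict = {}):
    if history[0]["role"].upper() != "USER":
        return ""
```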
@@ -3,4 +3,4 @@ from . import generator
__all__ = [name for name in dir(generator)
           if not name.startswith('_')]

-globals().update({name: getattr(generator, name) for name in __all__})
+globals().update({name: getattr(generator, name) for name in __all__})

@@ -28,17 +28,16 @@ from rag.prompts.template import load_prompt
from common.constants import TAG_FLD
from common.token_utils import encoder, num_tokens_from_string


-STOP_TOKEN="<|STOP|>"
-COMPLETE_TASK="complete_task"
+STOP_TOKEN = "<|STOP|>"
+COMPLETE_TASK = "complete_task"
INPUT_UTILIZATION = 0.5


def get_value(d, k1, k2):
    return d.get(k1, d.get(k2))


def chunks_format(reference):
-
    return [
        {
            "id": get_value(chunk, "chunk_id", "id"),
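`get_value` (introduced in the hunk above) bridges the two field-naming schemes that chunks can arrive with, e.g. "chunk_id" vs "id" or "docnm_kwd" vs "document_name". A small illustration; the sample dicts are made up:

```python
def get_value(d, k1, k2):
    return d.get(k1, d.get(k2))


# The same chunk may be keyed with internal or API-facing field names.
chunk_internal = {"chunk_id": "c42", "docnm_kwd": "report.pdf"}
chunk_api = {"id": "c42", "document_name": "report.pdf"}

assert get_value(chunk_internal, "chunk_id", "id") == "c42"
assert get_value(chunk_api, "chunk_id", "id") == "c42"
assert get_value(chunk_api, "docnm_kwd", "document_name") == "report.pdf"
```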
@@ -126,7 +125,7 @@ def kb_prompt(kbinfos, max_tokens, hash_id=False):
    for i, ck in enumerate(kbinfos["chunks"][:chunks_num]):
        cnt = "\nID: {}".format(i if not hash_id else hash_str2int(get_value(ck, "id", "chunk_id"), 500))
        cnt += draw_node("Title", get_value(ck, "docnm_kwd", "document_name"))
-        cnt += draw_node("URL", ck['url']) if "url" in ck else ""
+        cnt += draw_node("URL", ck['url']) if "url" in ck else ""
        for k, v in docs.get(get_value(ck, "doc_id", "document_id"), {}).items():
            cnt += draw_node(k, v)
        cnt += "\n└── Content:\n"
@@ -173,7 +172,7 @@ ASK_SUMMARY = load_prompt("ask_summary")
PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)


-def citation_prompt(user_defined_prompts: dict={}) -> str:
+def citation_prompt(user_defined_prompts: dict = {}) -> str:
    template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE))
    return template.render()

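`citation_prompt` (above) looks up an optional per-user override under the "citation_guidelines" key and falls back to the packaged CITATION_PROMPT_TEMPLATE, rendering it through the shared Jinja environment. A minimal sketch of that override pattern; the stand-in default string is illustrative, not the real template:

```python
import jinja2

PROMPT_JINJA_ENV = jinja2.Environment(autoescape=False, trim_blocks=True, lstrip_blocks=True)
CITATION_PROMPT_TEMPLATE = "Cite every borrowed sentence as [ID:n]."  # stand-in for the real template


def citation_prompt(user_defined_prompts: dict = {}) -> str:
    template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("citation_guidelines", CITATION_PROMPT_TEMPLATE))
    return template.render()


print(citation_prompt())                                                # packaged default
print(citation_prompt({"citation_guidelines": "Cite as (source n)."}))  # user override wins
```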
@@ -258,9 +257,11 @@ async def cross_languages(tenant_id, llm_id, query, languages=[]):
    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_id)

    rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
-    rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)
+    rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query,
+                                                                                                     languages=languages)

-    ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
+    ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}],
+                                    {"temperature": 0.2})
    ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
    if ans.find("**ERROR**") >= 0:
        return query
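`cross_languages` keeps only the text after a reasoning model's closing </think> tag before checking for errors, via the `re.sub` call above. A tiny demonstration with a made-up reply:

```python
import re

# Hypothetical reply from a reasoning-capable model: thought prefix, then the translation.
ans = "<think>Translate the query into German and French.</think>Wie funktioniert RAG? / Comment fonctionne le RAG ?"

cleaned = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)  # same pattern as in the hunk
assert cleaned == "Wie funktioniert RAG? / Comment fonctionne le RAG ?"
```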
@@ -332,7 +333,8 @@ def tool_schema(tools_description: list[dict], complete_task=False):
                "description": "When you have the final answer and are ready to complete the task, call this function with your answer",
                "parameters": {
                    "type": "object",
-                    "properties": {"answer":{"type":"string", "description": "The final answer to the user's question"}},
+                    "properties": {
+                        "answer": {"type": "string", "description": "The final answer to the user's question"}},
                    "required": ["answer"]
                }
            }
@@ -341,7 +343,8 @@ def tool_schema(tools_description: list[dict], complete_task=False):
        name = tool["function"]["name"]
        desc[name] = tool

-    return "\n\n".join([f"## {i+1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in enumerate(desc.items())])
+    return "\n\n".join([f"## {i + 1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in
+                        enumerate(desc.items())])


def form_history(history, limit=-6):
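`tool_schema` (previous two hunks) collects each tool definition under its function name and renders them as numbered Markdown sections; when `complete_task` is enabled it also injects the built-in complete_task function. A rough sketch of the output format, using a made-up `search` tool alongside the complete_task schema shown above; only the join/formatting mirrors the source:

```python
import json

# "search" is an illustrative OpenAI-style tool definition, not one from the repository.
desc = {
    "search": {"type": "function", "function": {"name": "search", "parameters": {"type": "object"}}},
    "complete_task": {
        "type": "function",
        "function": {
            "name": "complete_task",
            "description": "When you have the final answer and are ready to complete the task, call this function with your answer",
            "parameters": {
                "type": "object",
                "properties": {
                    "answer": {"type": "string", "description": "The final answer to the user's question"}},
                "required": ["answer"],
            },
        },
    },
}

rendered = "\n\n".join([f"## {i + 1}. {fnm}\n{json.dumps(des, ensure_ascii=False, indent=4)}" for i, (fnm, des) in
                        enumerate(desc.items())])
print(rendered)  # "## 1. search\n{...}\n\n## 2. complete_task\n{...}"
```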
@@ -350,14 +353,14 @@ def form_history(history, limit=-6):
        if h["role"] == "system":
            continue
        role = "USER"
-        if h["role"].upper()!= role:
+        if h["role"].upper() != role:
            role = "AGENT"
-        context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content'])>2048 else '')}"
+        context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content']) > 2048 else '')}"
    return context


-
-async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict], user_defined_prompts: dict={}):
+async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: list[dict],
+                             user_defined_prompts: dict = {}):
    tools_desc = tool_schema(tools_description)
    context = ""

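`form_history` (shown above) flattens recent turns into a plain-text transcript: system turns are skipped, non-user turns are labelled AGENT, and each message is capped at 2048 characters. A self-contained sketch; the loop header and the `context` initialisation are not part of the hunk and are assumed here:

```python
def form_history(history, limit=-6):
    context = ""                  # assumed initialisation (not shown in the hunk)
    for h in history[limit:]:     # assumed loop header (not shown in the hunk)
        if h["role"] == "system":
            continue
        role = "USER"
        if h["role"].upper() != role:
            role = "AGENT"
        context += f"\n{role}: {h['content'][:2048] + ('...' if len(h['content']) > 2048 else '')}"
    return context


history = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "What does the report cover?"},
    {"role": "assistant", "content": "Q3 revenue and churn."},
]
print(form_history(history))
# USER: What does the report cover?
# AGENT: Q3 revenue and churn.
```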
@@ -375,7 +378,8 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
    return kwd


-async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
+async def next_step_async(chat_mdl, history: list, tools_description: list[dict], task_desc,
+                          user_defined_prompts: dict = {}):
    if not tools_description:
        return "", 0
    desc = tool_schema(tools_description)
@@ -396,7 +400,7 @@ async def next_step_async(chat_mdl, history:list, tools_description: list[dict],
    return json_str, tk_cnt


-async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict={}):
+async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple], user_defined_prompts: dict = {}):
    tool_calls = [{"name": p[0], "result": p[1]} for p in tool_call_res]
    goal = history[1]["content"]
    template = PROMPT_JINJA_ENV.from_string(user_defined_prompts.get("reflection", REFLECT))
@@ -419,7 +423,7 @@ async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple


def form_message(system_prompt, user_prompt):
-    return [{"role": "system", "content": system_prompt},{"role": "user", "content": user_prompt}]
+    return [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}]


def structured_output_prompt(schema=None) -> str:
@@ -427,27 +431,29 @@ def structured_output_prompt(schema=None) -> str:
    return template.render(schema=schema)


-async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
+async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict = {}) -> str:
    template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
    system_prompt = template.render(name=name,
-                                    params=json.dumps(params, ensure_ascii=False, indent=2),
-                                    result=result)
+                                    params=json.dumps(params, ensure_ascii=False, indent=2),
+                                    result=result)
    user_prompt = "→ Summary: "
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)


-async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summaries: list[str], user_defined_prompts: dict={}):
+async def rank_memories_async(chat_mdl, goal: str, sub_goal: str, tool_call_summaries: list[str],
+                              user_defined_prompts: dict = {}):
    template = PROMPT_JINJA_ENV.from_string(RANK_MEMORY)
-    system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
+    system_prompt = template.render(goal=goal, sub_goal=sub_goal,
+                                    results=[{"i": i, "content": s} for i, s in enumerate(tool_call_summaries)])
    user_prompt = " → rank: "
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], stop="<|stop|>")
    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)


-async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
+async def gen_meta_filter(chat_mdl, meta_data: dict, query: str) -> dict:
    meta_data_structure = {}
    for key, values in meta_data.items():
        meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values
@@ -471,13 +477,13 @@ async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
        return {"conditions": []}


-async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
+async def gen_json(system_prompt: str, user_prompt: str, chat_mdl, gen_conf=None):
    from graphrag.utils import get_llm_cache, set_llm_cache
    cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
    if cached:
        return json_repair.loads(cached)
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
-    ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
+    ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], gen_conf=gen_conf)
    ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
    try:
        res = json_repair.loads(ans)
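`gen_json` is the shared helper for structured LLM output: it first consults the graphrag LLM cache, fits the prompt to the model's context window, then strips any </think> prefix and Markdown code fences before handing the text to json_repair. A sketch of just that scrubbing-and-repair step on a made-up model answer:

```python
import re

import json_repair

# Hypothetical raw model output: reasoning prefix plus a fenced, slightly broken JSON blob.
raw = "<think>Filtering by year.</think>```json\n{\"conditions\": [{\"key\": \"year\", \"value\": 2024},]}\n```"

# Same pattern as gen_json: drop the </think> prefix and the ```json fences.
cleaned = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", raw, flags=re.DOTALL)
res = json_repair.loads(cleaned)  # json_repair is meant to tolerate minor breakage such as the trailing comma
print(res)  # expected: {'conditions': [{'key': 'year', 'value': 2024}]}
```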
@@ -488,10 +494,13 @@ async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None


TOC_DETECTION = load_prompt("toc_detection")
-async def detect_table_of_contents(page_1024:list[str], chat_mdl):
+
+
+async def detect_table_of_contents(page_1024: list[str], chat_mdl):
    toc_secs = []
    for i, sec in enumerate(page_1024[:22]):
-        ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
+        ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.",
+                             chat_mdl)
        if toc_secs and not ans["exists"]:
            break
        toc_secs.append(sec)
@@ -500,14 +509,17 @@ async def detect_table_of_contents(page_1024:list[str], chat_mdl):

TOC_EXTRACTION = load_prompt("toc_extraction")
TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")


async def extract_table_of_contents(toc_pages, chat_mdl):
    if not toc_pages:
        return []

-    return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)
+    return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)),
+                          "Only JSON please.", chat_mdl)


-async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
+async def toc_index_extractor(toc: list[dict], content: str, chat_mdl):
    tob_extractor_prompt = """
You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format.

@@ -529,18 +541,21 @@ async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
If the title of the section are not in the provided pages, do not add the physical_index to it.
Directly return the final JSON structure. Do not output anything else."""

-    prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content
+    prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False,
+                                                                          indent=2) + '\nDocument pages:\n' + content
    return await gen_json(prompt, "Only JSON please.", chat_mdl)


TOC_INDEX = load_prompt("toc_index")


async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
    if not toc_arr or not sections:
        return []

    toc_map = {}
    for i, it in enumerate(toc_arr):
-        k1 = (it["structure"]+it["title"]).replace(" ", "")
+        k1 = (it["structure"] + it["title"]).replace(" ", "")
        k2 = it["title"].strip()
        if k1 not in toc_map:
            toc_map[k1] = []
@@ -558,6 +573,7 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
                toc_arr[j]["indices"].append(i)

    all_pathes = []
+
    def dfs(start, path):
        nonlocal all_pathes
        if start >= len(toc_arr):
@@ -565,7 +581,7 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
            all_pathes.append(path)
            return
        if not toc_arr[start]["indices"]:
-            dfs(start+1, path)
+            dfs(start + 1, path)
            return
        added = False
        for j in toc_arr[start]["indices"]:
@@ -574,12 +590,12 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat
            _path = deepcopy(path)
            _path.append((j, start))
            added = True
-            dfs(start+1, _path)
+            dfs(start + 1, _path)
        if not added and path:
            all_pathes.append(path)

    dfs(0, [])
-    path = max(all_pathes, key=lambda x:len(x))
+    path = max(all_pathes, key=lambda x: len(x))
    for it in toc_arr:
        it["indices"] = []
    for j, i in path:
@@ -588,24 +604,24 @@ async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat

    i = 0
    while i < len(toc_arr):
-        it = toc_arr[i]
+        it = toc_arr[i]
        if it["indices"]:
            i += 1
            continue

-        if i>0 and toc_arr[i-1]["indices"]:
-            st_i = toc_arr[i-1]["indices"][-1]
+        if i > 0 and toc_arr[i - 1]["indices"]:
+            st_i = toc_arr[i - 1]["indices"][-1]
        else:
            st_i = 0
        e = i + 1
-        while e <len(toc_arr) and not toc_arr[e]["indices"]:
+        while e < len(toc_arr) and not toc_arr[e]["indices"]:
            e += 1
        if e >= len(toc_arr):
            e = len(sections)
        else:
            e = toc_arr[e]["indices"][0]

-        for j in range(st_i, min(e+1, len(sections))):
+        for j in range(st_i, min(e + 1, len(sections))):
            ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
                structure=it["structure"],
                title=it["title"],
@@ -656,11 +672,15 @@ async def toc_transformer(toc_pages, chat_mdl):

    toc_content = "\n".join(toc_pages)
    prompt = init_prompt + '\n Given table of contents\n:' + toc_content

    def clean_toc(arr):
        for a in arr:
            a["title"] = re.sub(r"[.·….]{2,}", "", a["title"])

    last_complete = await gen_json(prompt, "Only JSON please.", chat_mdl)
-    if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
+    if_complete = await check_if_toc_transformation_is_complete(toc_content,
+                                                                json.dumps(last_complete, ensure_ascii=False, indent=2),
+                                                                chat_mdl)
    clean_toc(last_complete)
    if if_complete == "yes":
        return last_complete
@@ -682,13 +702,17 @@ async def toc_transformer(toc_pages, chat_mdl):
            break
        clean_toc(new_complete)
        last_complete.extend(new_complete)
-        if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
+        if_complete = await check_if_toc_transformation_is_complete(toc_content,
+                                                                    json.dumps(last_complete, ensure_ascii=False,
+                                                                               indent=2), chat_mdl)

    return last_complete


TOC_LEVELS = load_prompt("assign_toc_levels")
-async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
+
+
+async def assign_toc_levels(toc_secs, chat_mdl, gen_conf={"temperature": 0.2}):
    if not toc_secs:
        return []
    return await gen_json(
@@ -701,12 +725,15 @@ async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2})

TOC_FROM_TEXT_SYSTEM = load_prompt("toc_from_text_system")
TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")


# Generate TOC from text chunks with text llms
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
    try:
        ans = await gen_json(
            PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
-            PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
+            PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(
+                text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
            chat_mdl,
            gen_conf={"temperature": 0.0, "top_p": 0.9}
        )
@@ -743,7 +770,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
        TOC_FROM_TEXT_USER + TOC_FROM_TEXT_SYSTEM
    )

-    input_budget = 1024 if input_budget > 1024 else input_budget
+    input_budget = 1024 if input_budget > 1024 else input_budget
    chunk_sections = split_chunks(chunks, input_budget)
    titles = []

@@ -798,7 +825,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
    if sorted_list:
        max_lvl = sorted_list[-1]
    merged = []
-    for _ , (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
+    for _, (toc_item, src_item) in enumerate(zip(toc_with_levels, filtered)):
        if prune and toc_item.get("level", "0") >= max_lvl:
            continue
        merged.append({
@@ -812,12 +839,15 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):

TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")
-async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
+
+
+async def relevant_chunks_with_toc(query: str, toc: list[dict], chat_mdl, topn: int = 6):
    import numpy as np
    try:
        ans = await gen_json(
            PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(),
-            PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])),
+            PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n" % "\n".join(
+                [json.dumps({"level": d["level"], "title": d["title"]}, ensure_ascii=False) for d in toc])),
            chat_mdl,
            gen_conf={"temperature": 0.0, "top_p": 0.9}
        )
@@ -828,17 +858,19 @@ async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: i
            for id in ti.get("ids", []):
                if id not in id2score:
                    id2score[id] = []
-                id2score[id].append(sc["score"]/5.)
+                id2score[id].append(sc["score"] / 5.)
        for id in id2score.keys():
            id2score[id] = np.mean(id2score[id])
-        return [(id, sc) for id, sc in list(id2score.items()) if sc>=0.3][:topn]
+        return [(id, sc) for id, sc in list(id2score.items()) if sc >= 0.3][:topn]
    except Exception as e:
        logging.exception(e)
        return []


META_DATA = load_prompt("meta_data")
-async def gen_metadata(chat_mdl, schema:dict, content:str):
+
+
+async def gen_metadata(chat_mdl, schema: dict, content: str):
    template = PROMPT_JINJA_ENV.from_string(META_DATA)
    for k, desc in schema["properties"].items():
        if "enum" in desc and not desc.get("enum"):
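`relevant_chunks_with_toc` turns the model's 0–5 relevance scores into per-chunk means (hence the division by 5 above), keeps ids whose mean is at least 0.3, and caps the result at topn. A worked sketch of that aggregation with made-up scores; the response layout is simplified compared to the source, where the score and the id list live in separate objects:

```python
import numpy as np

# Hypothetical scored TOC entries: each carries the chunk ids it covers and a 0-5 score.
scored_titles = [
    {"ids": [0, 1], "score": 4},
    {"ids": [1], "score": 1},
    {"ids": [2], "score": 0},
]

id2score = {}
for ti in scored_titles:
    for id in ti.get("ids", []):
        id2score.setdefault(id, []).append(ti["score"] / 5.)
for id in id2score.keys():
    id2score[id] = np.mean(id2score[id])

topn = 6
print([(id, sc) for id, sc in list(id2score.items()) if sc >= 0.3][:topn])
# keeps chunks 0 (mean 0.8) and 1 (mean 0.5); chunk 2 falls below the 0.3 threshold
```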
@@ -849,4 +881,4 @@ async def gen_metadata(chat_mdl, schema:dict, content:str):
    user_prompt = "Output: "
    _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
    ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
-    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
+    return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
@@ -1,6 +1,5 @@
import os

-
PROMPT_DIR = os.path.dirname(__file__)

_loaded_prompts = {}
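The last hunk touches the prompt-loading module: PROMPT_DIR points at the directory holding the bundled templates and _loaded_prompts caches them after the first read. The loader itself is outside this hunk; a minimal sketch of how such a cached load_prompt could look (the file extension and lookup are assumptions, not the repository's actual implementation):

```python
import os

PROMPT_DIR = os.path.dirname(__file__)
_loaded_prompts = {}


def load_prompt(name: str) -> str:
    # Read each prompt template once, then serve subsequent calls from the in-memory cache.
    if name not in _loaded_prompts:
        path = os.path.join(PROMPT_DIR, f"{name}.md")  # assumed file layout
        with open(path, encoding="utf-8") as f:
            _loaded_prompts[name] = f.read()
    return _loaded_prompts[name]
```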