Fix: tokenizer issue. (#11902)

#11786
### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
Author: Kevin Hu
Committed: 2025-12-11 17:38:17 +08:00 (committed by GitHub)
Parent: 22a51a3868
Commit: ea4a5cd665
17 changed files with 141 additions and 216 deletions

View File

@@ -271,7 +271,7 @@ class Agent(LLM, ToolBase):
         last_calling = ""
         if len(hist) > 3:
             st = timer()
-            user_request = await asyncio.to_thread(full_question, messages=history, chat_mdl=self.chat_mdl)
+            user_request = await full_question(messages=history, chat_mdl=self.chat_mdl)
             self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
         else:
             user_request = history[-1]["content"]
@@ -309,7 +309,7 @@ class Agent(LLM, ToolBase):
         if len(hist) > 12:
             _hist = [hist[0], hist[1], *hist[-10:]]
         entire_txt = ""
-        async for delta_ans in self._generate_streamly_async(_hist):
+        async for delta_ans in self._generate_streamly(_hist):
             if not need2cite or cited:
                 yield delta_ans, 0
             entire_txt += delta_ans
@@ -397,7 +397,7 @@ Respond immediately with your final comprehensive answer.
         retrievals = self._canvas.get_reference()
         retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
         formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
-        async for delta_ans in self._generate_streamly_async([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
+        async for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
                                                      {"role": "user", "content": text}
                                                      ]):
             yield delta_ans

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import logging
 import os
 import re
@@ -97,7 +98,7 @@ class Categorize(LLM, ABC):
     component_name = "Categorize"

     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
-    def _invoke(self, **kwargs):
+    async def _invoke_async(self, **kwargs):
         if self.check_if_canceled("Categorize processing"):
             return
@@ -121,7 +122,7 @@ class Categorize(LLM, ABC):
         if self.check_if_canceled("Categorize processing"):
             return

-        ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
+        ans = await chat_mdl.async_chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
         logging.info(f"input: {user_prompt}, answer: {str(ans)}")
         if ERROR_PREFIX in ans:
             raise Exception(ans)
@@ -144,5 +145,9 @@ class Categorize(LLM, ABC):
         self.set_output("category_name", max_category)
         self.set_output("_next", cpn_ids)

+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
+    def _invoke(self, **kwargs):
+        return asyncio.run(self._invoke_async(**kwargs))
+
     def thoughts(self) -> str:
         return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))

View File

@@ -18,9 +18,8 @@ import json
 import logging
 import os
 import re
-import threading
 from copy import deepcopy
-from typing import Any, Generator, AsyncGenerator
+from typing import Any, AsyncGenerator
 import json_repair
 from functools import partial
 from common.constants import LLMType
@@ -168,53 +167,12 @@ class LLM(ComponentBase):
             sys_prompt = re.sub(rf"<{tag}>(.*?)</{tag}>", "", sys_prompt, flags=re.DOTALL|re.IGNORECASE)
         return pts, sys_prompt

-    def _generate(self, msg:list[dict], **kwargs) -> str:
-        if not self.imgs:
-            return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
-        return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
-
     async def _generate_async(self, msg: list[dict], **kwargs) -> str:
-        if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
-            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
-        if self.imgs and hasattr(self.chat_mdl, "async_chat"):
-            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
-        return await asyncio.to_thread(self._generate, msg, **kwargs)
-
-    def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
-        ans = ""
-        last_idx = 0
-        endswith_think = False
-
-        def delta(txt):
-            nonlocal ans, last_idx, endswith_think
-            delta_ans = txt[last_idx:]
-            ans = txt
-            if delta_ans.find("<think>") == 0:
-                last_idx += len("<think>")
-                return "<think>"
-            elif delta_ans.find("<think>") > 0:
-                delta_ans = txt[last_idx:last_idx+delta_ans.find("<think>")]
-                last_idx += delta_ans.find("<think>")
-                return delta_ans
-            elif delta_ans.endswith("</think>"):
-                endswith_think = True
-            elif endswith_think:
-                endswith_think = False
-                return "</think>"
-            last_idx = len(ans)
-            if ans.endswith("</think>"):
-                last_idx -= len("</think>")
-            return re.sub(r"(<think>|</think>)", "", delta_ans)
-
         if not self.imgs:
-            for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs):
-                yield delta(txt)
-        else:
-            for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
-                yield delta(txt)
+            return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
+        return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)

-    async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
+    async def _generate_streamly(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
         async def delta_wrapper(txt_iter):
             ans = ""
             last_idx = 0
@@ -246,36 +204,13 @@ class LLM(ComponentBase):
             async for t in txt_iter:
                 yield delta(t)

-        if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
+        if not self.imgs:
             async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
                 yield t
             return

-        if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
-            async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
-                yield t
-            return
-
-        # fallback
-        loop = asyncio.get_running_loop()
-        queue: asyncio.Queue = asyncio.Queue()
-
-        def worker():
-            try:
-                for item in self._generate_streamly(msg, **kwargs):
-                    loop.call_soon_threadsafe(queue.put_nowait, item)
-            except Exception as e:
-                loop.call_soon_threadsafe(queue.put_nowait, e)
-            finally:
-                loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)
-
-        threading.Thread(target=worker, daemon=True).start()
-        while True:
-            item = await queue.get()
-            if item is StopAsyncIteration:
-                break
-            if isinstance(item, Exception):
-                raise item
-            yield item
+        async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
+            yield t

     async def _stream_output_async(self, prompt, msg):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
@@ -407,8 +342,8 @@ class LLM(ComponentBase):
     def _invoke(self, **kwargs):
         return asyncio.run(self._invoke_async(**kwargs))

-    def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
-        summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
+    async def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
+        summ = await tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
         logging.info(f"[MEMORY]: {summ}")
         self._canvas.add_memory(user, assist, summ)
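The pattern in this file runs one way: the `hasattr(..., "async_chat")` probes, the `asyncio.to_thread` fallback, and the thread-plus-queue bridge are all dropped, and the component simply awaits the model's async API. A rough sketch of the resulting shape, using a hypothetical `FakeChatModel` in place of the real model bundle:

```python
import asyncio
from typing import AsyncGenerator


class FakeChatModel:
    """Hypothetical stand-in for the model bundle's async interface."""

    async def async_chat(self, system: str, history: list, gen_conf: dict) -> str:
        await asyncio.sleep(0)
        return "full answer"

    async def async_chat_streamly(self, system: str, history: list, gen_conf: dict) -> AsyncGenerator[str, None]:
        # Yields the accumulated text so far, which is what the component's
        # delta_wrapper expects to slice into increments.
        for cumulative in ("full ", "full answer"):
            await asyncio.sleep(0)
            yield cumulative


async def generate(chat_mdl: FakeChatModel, msg: list[dict]) -> str:
    # No hasattr() probing and no thread offload: the async API is assumed to exist.
    return await chat_mdl.async_chat(msg[0]["content"], msg[1:], {})


async def generate_streamly(chat_mdl: FakeChatModel, msg: list[dict]) -> AsyncGenerator[str, None]:
    async for txt in chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], {}):
        yield txt


async def main() -> None:
    msg = [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}]
    print(await generate(FakeChatModel(), msg))
    async for chunk in generate_streamly(FakeChatModel(), msg):
        print(chunk)


asyncio.run(main())
```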

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 from functools import partial
 import json
 import os
@@ -81,7 +82,7 @@ class Retrieval(ToolBase, ABC):
     component_name = "Retrieval"

     @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
-    def _invoke(self, **kwargs):
+    async def _invoke_async(self, **kwargs):
         if self.check_if_canceled("Retrieval processing"):
             return
@@ -132,7 +133,7 @@ class Retrieval(ToolBase, ABC):
             metas = DocumentService.get_meta_by_kbs(kb_ids)
             if self._param.meta_data_filter.get("method") == "auto":
                 chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
-                filters: dict = gen_meta_filter(chat_mdl, metas, query)
+                filters: dict = await gen_meta_filter(chat_mdl, metas, query)
                 doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
                 if not doc_ids:
                     doc_ids = None
@@ -142,7 +143,7 @@ class Retrieval(ToolBase, ABC):
                 filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
                 if filtered_metas:
                     chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
-                    filters: dict = gen_meta_filter(chat_mdl, filtered_metas, query)
+                    filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, query)
                     doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
                     if not doc_ids:
                         doc_ids = None
@@ -180,7 +181,7 @@ class Retrieval(ToolBase, ABC):
             doc_ids = ["-999"]

         if self._param.cross_languages:
-            query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
+            query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)

         if kbs:
             query = re.sub(r"^user[:\s]*", "", query, flags=re.IGNORECASE)
@@ -253,6 +254,10 @@ class Retrieval(ToolBase, ABC):
         return form_cnt

+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
+    def _invoke(self, **kwargs):
+        return asyncio.run(self._invoke_async(**kwargs))
+
     def thoughts(self) -> str:
         return """
Keywords: {}

View File

@@ -51,7 +51,7 @@ class DeepResearcher:
         """Remove Result Tags"""
         return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT)

-    def _generate_reasoning(self, msg_history):
+    async def _generate_reasoning(self, msg_history):
         """Generate reasoning steps"""
         query_think = ""
         if msg_history[-1]["role"] != "user":
@@ -59,13 +59,14 @@ class DeepResearcher:
         else:
             msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"

-        for ans in self.chat_mdl.chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
+        async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
             ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
             if not ans:
                 continue
             query_think = ans
             yield query_think
-        return query_think
+        query_think = ""
+        yield query_think

     def _extract_search_queries(self, query_think, question, step_index):
         """Extract search queries from thinking"""
@@ -143,10 +144,10 @@ class DeepResearcher:
                     if d["doc_id"] not in dids:
                         chunk_info["doc_aggs"].append(d)

-    def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
+    async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
         """Extract and summarize relevant information"""
         summary_think = ""
-        for ans in self.chat_mdl.chat_streamly(
+        async for ans in self.chat_mdl.async_chat_streamly(
             RELEVANT_EXTRACTION_PROMPT.format(
                 prev_reasoning=truncated_prev_reasoning,
                 search_query=search_query,
@@ -160,10 +161,11 @@ class DeepResearcher:
                 continue
             summary_think = ans
             yield summary_think
-        return summary_think
+        summary_think = ""
+        yield summary_think

-    def thinking(self, chunk_info: dict, question: str):
+    async def thinking(self, chunk_info: dict, question: str):
         executed_search_queries = []
         msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
         all_reasoning_steps = []
@@ -180,7 +182,7 @@ class DeepResearcher:
             # Step 1: Generate reasoning
             query_think = ""
-            for ans in self._generate_reasoning(msg_history):
+            async for ans in self._generate_reasoning(msg_history):
                 query_think = ans
                 yield {"answer": think + self._remove_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}
@@ -223,7 +225,7 @@ class DeepResearcher:
             # Step 6: Extract relevant information
             think += "\n\n"
             summary_think = ""
-            for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
+            async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
                 summary_think = ans
                 yield {"answer": think + self._remove_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}

View File

@@ -313,7 +313,7 @@ async def retrieval_test():
     langs = req.get("cross_languages", [])
     user_id = current_user.id

-    def _retrieval_sync():
+    async def _retrieval():
         local_doc_ids = list(doc_ids) if doc_ids else []
         tenant_ids = []
@@ -323,7 +323,7 @@ async def retrieval_test():
                 metas = DocumentService.get_meta_by_kbs(kb_ids)
                 if meta_data_filter.get("method") == "auto":
                     chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
-                    filters: dict = gen_meta_filter(chat_mdl, metas, question)
+                    filters: dict = await gen_meta_filter(chat_mdl, metas, question)
                     local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
                     if not local_doc_ids:
                         local_doc_ids = None
@@ -333,7 +333,7 @@ async def retrieval_test():
                     filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
                     if filtered_metas:
                         chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
-                        filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
+                        filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
                         local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
                         if not local_doc_ids:
                             local_doc_ids = None
@@ -347,7 +347,7 @@ async def retrieval_test():
                 metas = DocumentService.get_meta_by_kbs(kb_ids)
                 if meta_data_filter.get("method") == "auto":
                     chat_mdl = LLMBundle(user_id, LLMType.CHAT)
-                    filters: dict = gen_meta_filter(chat_mdl, metas, question)
+                    filters: dict = await gen_meta_filter(chat_mdl, metas, question)
                     local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
                     if not local_doc_ids:
                         local_doc_ids = None
@@ -357,7 +357,7 @@ async def retrieval_test():
                     filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
                     if filtered_metas:
                         chat_mdl = LLMBundle(user_id, LLMType.CHAT)
-                        filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
+                        filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
                         local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
                         if not local_doc_ids:
                             local_doc_ids = None
@@ -384,7 +384,7 @@ async def retrieval_test():
         _question = question
         if langs:
-            _question = cross_languages(kb.tenant_id, None, _question, langs)
+            _question = await cross_languages(kb.tenant_id, None, _question, langs)

         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
@@ -394,7 +394,7 @@ async def retrieval_test():
         if req.get("keyword", False):
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
-            _question += keyword_extraction(chat_mdl, _question)
+            _question += await keyword_extraction(chat_mdl, _question)

         labels = label_question(_question, [kb])
         ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
@@ -421,7 +421,7 @@ async def retrieval_test():
         return get_json_result(data=ranks)

     try:
-        return await asyncio.to_thread(_retrieval_sync)
+        return await _retrieval()
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
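The route body stops pushing a synchronous `_retrieval_sync` helper onto a worker thread; the helper becomes `async def _retrieval()` and is awaited in place, which is what lets `gen_meta_filter`, `cross_languages`, and `keyword_extraction` be awaited inside it. A stripped-down sketch of that before/after shape (`fetch_chunks` is a hypothetical stand-in for the retrieval work):

```python
import asyncio


async def fetch_chunks(question: str) -> list[str]:
    # Hypothetical async step standing in for the LLM and search calls.
    await asyncio.sleep(0)
    return [f"chunk for {question!r}"]


# Before: a sync helper, hopped onto a worker thread by the route.
def _retrieval_sync(question: str) -> list[str]:
    return asyncio.run(fetch_chunks(question))  # blocks a worker thread just to run a coroutine


async def handler_before(question: str) -> list[str]:
    return await asyncio.to_thread(_retrieval_sync, question)


# After: the helper is itself a coroutine and is awaited directly.
async def _retrieval(question: str) -> list[str]:
    return await fetch_chunks(question)


async def handler_after(question: str) -> list[str]:
    return await _retrieval(question)


print(asyncio.run(handler_after("what changed?")))
```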

View File

@@ -1549,11 +1549,11 @@ async def retrieval_test(tenant_id):
             rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])

         if langs:
-            question = cross_languages(kb.tenant_id, None, question, langs)
+            question = await cross_languages(kb.tenant_id, None, question, langs)

         if req.get("keyword", False):
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
-            question += keyword_extraction(chat_mdl, question)
+            question += await keyword_extraction(chat_mdl, question)

         ranks = settings.retriever.retrieval(
             question,

View File

@@ -33,6 +33,7 @@ from api.utils.web_utils import CONTENT_TYPE_MAP
 from common import settings
 from common.constants import RetCode

+
 @manager.route('/file/upload', methods=['POST'])  # noqa: F821
 @token_required
 async def upload(tenant_id):

View File

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import asyncio
 import json
 import re
 import time
@@ -44,6 +43,7 @@ from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extr
 from common.constants import RetCode, LLMType, StatusEnum
 from common import settings

+
 @manager.route("/chats/<chat_id>/sessions", methods=["POST"])  # noqa: F821
 @token_required
 async def create(tenant_id, chat_id):
@@ -969,7 +969,7 @@ async def retrieval_test_embedded():
     if not tenant_id:
         return get_error_data_result(message="permission denined.")

-    def _retrieval_sync():
+    async def _retrieval():
         local_doc_ids = list(doc_ids) if doc_ids else []
         tenant_ids = []
         _question = question
@@ -980,7 +980,7 @@ async def retrieval_test_embedded():
                 metas = DocumentService.get_meta_by_kbs(kb_ids)
                 if meta_data_filter.get("method") == "auto":
                     chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
-                    filters: dict = gen_meta_filter(chat_mdl, metas, _question)
+                    filters: dict = await gen_meta_filter(chat_mdl, metas, _question)
                     local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
                     if not local_doc_ids:
                         local_doc_ids = None
@@ -990,7 +990,7 @@ async def retrieval_test_embedded():
                     filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
                     if filtered_metas:
                         chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
-                        filters: dict = gen_meta_filter(chat_mdl, filtered_metas, _question)
+                        filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, _question)
                         local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
                         if not local_doc_ids:
                             local_doc_ids = None
@@ -1004,7 +1004,7 @@ async def retrieval_test_embedded():
                 metas = DocumentService.get_meta_by_kbs(kb_ids)
                 if meta_data_filter.get("method") == "auto":
                     chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
-                    filters: dict = gen_meta_filter(chat_mdl, metas, question)
+                    filters: dict = await gen_meta_filter(chat_mdl, metas, question)
                     local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
                     if not local_doc_ids:
                         local_doc_ids = None
@@ -1014,7 +1014,7 @@ async def retrieval_test_embedded():
                     filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
                     if filtered_metas:
                         chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
-                        filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
+                        filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
                         local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
                         if not local_doc_ids:
                             local_doc_ids = None
@@ -1038,7 +1038,7 @@ async def retrieval_test_embedded():
             return get_error_data_result(message="Knowledgebase not found!")

         if langs:
-            _question = cross_languages(kb.tenant_id, None, _question, langs)
+            _question = await cross_languages(kb.tenant_id, None, _question, langs)

         embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
@@ -1048,7 +1048,7 @@ async def retrieval_test_embedded():
         if req.get("keyword", False):
             chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
-            _question += keyword_extraction(chat_mdl, _question)
+            _question += await keyword_extraction(chat_mdl, _question)

         labels = label_question(_question, [kb])
         ranks = settings.retriever.retrieval(
@@ -1068,7 +1068,7 @@ async def retrieval_test_embedded():
         return get_json_result(data=ranks)

     try:
-        return await asyncio.to_thread(_retrieval_sync)
+        return await _retrieval()
     except Exception as e:
         if str(e).find("not_found") > 0:
             return get_json_result(data=False, message="No chunk found! Check the chunk status please!",

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import logging
 import json
 import os
@@ -76,8 +77,7 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_
         f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.")

     chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
-    msg = chat_mdl.chat(system="", history=[
-        {"role": "user", "content": "Hello!"}], gen_conf={})
+    msg = asyncio.run(chat_mdl.async_chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}))
     if msg.find("ERROR: ") == 0:
         logging.error(
             "'{}' doesn't work. {}".format(

View File

@@ -397,7 +397,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
     # try to use sql if field mapping is good to go
     if field_map:
         logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
-        ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
+        ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
         if ans:
             yield ans
             return
@@ -411,17 +411,17 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
             prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")

     if len(questions) > 1 and prompt_config.get("refine_multiturn"):
-        questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
+        questions = [await full_question(dialog.tenant_id, dialog.llm_id, messages)]
     else:
         questions = questions[-1:]

     if prompt_config.get("cross_languages"):
-        questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
+        questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]

     if dialog.meta_data_filter:
         metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
         if dialog.meta_data_filter.get("method") == "auto":
-            filters: dict = gen_meta_filter(chat_mdl, metas, questions[-1])
+            filters: dict = await gen_meta_filter(chat_mdl, metas, questions[-1])
             attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
             if not attachments:
                 attachments = None
@@ -430,7 +430,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
             if selected_keys:
                 filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
                 if filtered_metas:
-                    filters: dict = gen_meta_filter(chat_mdl, filtered_metas, questions[-1])
+                    filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, questions[-1])
                     attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
                     if not attachments:
                         attachments = None
@@ -441,7 +441,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
             attachments = ["-999"]

     if prompt_config.get("keyword", False):
-        questions[-1] += keyword_extraction(chat_mdl, questions[-1])
+        questions[-1] += await keyword_extraction(chat_mdl, questions[-1])

     refine_question_ts = timer()
@@ -469,7 +469,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
                 ),
             )
-        for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
+        async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
             if isinstance(think, str):
                 thought = think
                 knowledges = [t for t in think.split("\n") if t]
@@ -646,7 +646,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
         return

-def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
+async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
     sys_prompt = """
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
Ensure that:
@@ -664,9 +664,9 @@ Please write the SQL, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
     tried_times = 0

-    def get_table():
+    async def get_table():
         nonlocal sys_prompt, user_prompt, question, tried_times
-        sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
+        sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
         sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
         logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
         sql = re.sub(r"[\r\n]+", " ", sql.lower())
@@ -705,7 +705,7 @@ Please write the SQL, only SQL, without any other explanations or text.
         return settings.retriever.sql_retrieval(sql, format="json"), sql

     try:
-        tbl, sql = get_table()
+        tbl, sql = await get_table()
     except Exception as e:
         user_prompt = """
Table name: {};
@@ -723,7 +723,7 @@ Please write the SQL, only SQL, without any other explanations or text.
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
         try:
-            tbl, sql = get_table()
+            tbl, sql = await get_table()
         except Exception:
             return
@@ -839,7 +839,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
     if meta_data_filter:
         metas = DocumentService.get_meta_by_kbs(kb_ids)
         if meta_data_filter.get("method") == "auto":
-            filters: dict = gen_meta_filter(chat_mdl, metas, question)
+            filters: dict = await gen_meta_filter(chat_mdl, metas, question)
             doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
             if not doc_ids:
                 doc_ids = None
@@ -848,7 +848,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
         if selected_keys:
             filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
             if filtered_metas:
-                filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
+                filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
                 doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
                 if not doc_ids:
                     doc_ids = None
@@ -923,7 +923,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
     if meta_data_filter:
         metas = DocumentService.get_meta_by_kbs(kb_ids)
         if meta_data_filter.get("method") == "auto":
-            filters: dict = gen_meta_filter(chat_mdl, metas, question)
+            filters: dict = await gen_meta_filter(chat_mdl, metas, question)
             doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
             if not doc_ids:
                 doc_ids = None
@@ -932,7 +932,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
         if selected_keys:
             filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
             if filtered_metas:
-                filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
+                filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
                 doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
                 if not doc_ids:
                     doc_ids = None

View File

@@ -318,9 +318,6 @@ class LLMBundle(LLM4Tenant):
                 return value
             raise value

-    def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
-        return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs))
-
     def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs):
         result_queue: queue.Queue = queue.Queue()
@@ -350,23 +347,6 @@ class LLMBundle(LLM4Tenant):
                 raise item
             yield item

-    def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
-        ans = ""
-        for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs):
-            if isinstance(txt, int):
-                break
-
-            if txt.endswith("</think>"):
-                ans = txt[: -len("</think>")]
-                continue
-
-            if not self.verbose_tool_use:
-                txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
-
-            # cancatination has beend done in async_chat_streamly
-            ans = txt
-            yield ans
-
     def _bridge_sync_stream(self, gen):
         loop = asyncio.get_running_loop()
         queue: asyncio.Queue = asyncio.Queue()
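With the callers above converted, `LLMBundle` drops its synchronous `chat`/`chat_streamly` facades; code is expected to await `async_chat`/`async_chat_streamly` directly, while the few genuinely synchronous entry points (as in `init_superuser` and `Dealer` below) bridge with `asyncio.run`. A sketch of both call styles against a hypothetical bundle class:

```python
import asyncio


class Bundle:
    """Hypothetical stand-in for LLMBundle once the sync facades are gone."""

    async def async_chat(self, system: str, history: list, gen_conf: dict | None = None) -> str:
        await asyncio.sleep(0)
        return "answer"


async def async_caller(bundle: Bundle) -> str:
    # Async code awaits the bundle directly.
    return await bundle.async_chat("", [{"role": "user", "content": "Hello!"}], {})


def sync_caller(bundle: Bundle) -> str:
    # A synchronous entry point bridges with asyncio.run, mirroring the
    # init_superuser and Dealer call sites in this commit.
    return asyncio.run(bundle.async_chat("", [{"role": "user", "content": "Hello!"}], {}))


print(sync_caller(Bundle()))
print(asyncio.run(async_caller(Bundle())))
```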

View File

@@ -98,7 +98,7 @@ class Extractor(ProcessBase, LLM):
                     args[chunks_key] = ck["text"]
                 msg, sys_prompt = self._sys_prompt_and_msg([], args)
                 msg.insert(0, {"role": "system", "content": sys_prompt})
-                ck[self._param.field_name] = self._generate(msg)
+                ck[self._param.field_name] = await self._generate_async(msg)
                 prog += 1./len(chunks)
                 if i % (len(chunks)//100+1) == 1:
                     self.callback(prog, f"{i+1} / {len(chunks)}")
@@ -106,6 +106,6 @@ class Extractor(ProcessBase, LLM):
         else:
             msg, sys_prompt = self._sys_prompt_and_msg([], args)
             msg.insert(0, {"role": "system", "content": sys_prompt})
-            self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])
+            self.set_output("chunks", [{self._param.field_name: await self._generate_async(msg)}])

View File

@@ -33,6 +33,22 @@ class RagTokenizer(infinity.rag_tokenizer.RagTokenizer):
         return super().fine_grained_tokenize(tks)


+def is_chinese(s):
+    return infinity.rag_tokenizer.is_chinese(s)
+
+
+def is_number(s):
+    return infinity.rag_tokenizer.is_number(s)
+
+
+def is_alphabet(s):
+    return infinity.rag_tokenizer.is_alphabet(s)
+
+
+def naive_qie(txt):
+    return infinity.rag_tokenizer.naive_qie(txt)
+
+
 tokenizer = RagTokenizer()
 tokenize = tokenizer.tokenize
 fine_grained_tokenize = tokenizer.fine_grained_tokenize
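This hunk appears to be the tokenizer fix the commit title refers to: callers that use `is_chinese`, `is_number`, `is_alphabet`, and `naive_qie` as module-level functions get those names back as thin wrappers around `infinity.rag_tokenizer`. A minimal sketch of re-exposing a backend's functions at module level so old import sites keep working (the `_backend` namespace below is an illustrative stand-in, not the real module):

```python
from types import SimpleNamespace

# Illustrative stand-in for the underlying infinity.rag_tokenizer module.
_backend = SimpleNamespace(
    is_chinese=lambda s: any("\u4e00" <= ch <= "\u9fff" for ch in s),
    is_number=lambda s: s.isdigit(),
)


def is_chinese(s: str) -> bool:
    # Delegate so `from rag_tokenizer import is_chinese` keeps working for
    # callers written against the old module-level API.
    return _backend.is_chinese(s)


def is_number(s: str) -> bool:
    return _backend.is_number(s)


print(is_chinese("检索"), is_number("2024"))  # True True
```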

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import asyncio
 import json
 import logging
 import re
@@ -607,7 +608,7 @@ class Dealer:
         if not toc:
             return chunks

-        ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn*2)
+        ids = asyncio.run(relevant_chunks_with_toc(query, toc, chat_mdl, topn*2))
         if not ids:
             return chunks

View File

@ -170,13 +170,13 @@ def citation_plus(sources: str) -> str:
return template.render(example=citation_prompt(), sources=sources) return template.render(example=citation_prompt(), sources=sources)
def keyword_extraction(chat_mdl, content, topn=3): async def keyword_extraction(chat_mdl, content, topn=3):
template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE) template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE)
rendered_prompt = template.render(content=content, topn=topn) rendered_prompt = template.render(content=content, topn=topn)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length) _, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2}) kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple): if isinstance(kwd, tuple):
kwd = kwd[0] kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL) kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@ -185,13 +185,13 @@ def keyword_extraction(chat_mdl, content, topn=3):
return kwd return kwd
def question_proposal(chat_mdl, content, topn=3): async def question_proposal(chat_mdl, content, topn=3):
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE) template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
rendered_prompt = template.render(content=content, topn=topn) rendered_prompt = template.render(content=content, topn=topn)
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length) _, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2}) kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.2})
if isinstance(kwd, tuple): if isinstance(kwd, tuple):
kwd = kwd[0] kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL) kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@ -200,7 +200,7 @@ def question_proposal(chat_mdl, content, topn=3):
return kwd return kwd
def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None): async def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
from common.constants import LLMType from common.constants import LLMType
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
@ -229,12 +229,12 @@ def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_
language=language, language=language,
) )
ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}]) ans = await chat_mdl.async_chat(rendered_prompt, [{"role": "user", "content": "Output: "}])
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"] return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
def cross_languages(tenant_id, llm_id, query, languages=[]): async def cross_languages(tenant_id, llm_id, query, languages=[]):
from common.constants import LLMType from common.constants import LLMType
from api.db.services.llm_service import LLMBundle from api.db.services.llm_service import LLMBundle
from api.db.services.tenant_llm_service import TenantLLMService from api.db.services.tenant_llm_service import TenantLLMService
@ -247,14 +247,14 @@ def cross_languages(tenant_id, llm_id, query, languages=[]):
rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render() rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages) rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)
ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2}) ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
if ans.find("**ERROR**") >= 0: if ans.find("**ERROR**") >= 0:
return query return query
return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()]) return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()])
def content_tagging(chat_mdl, content, all_tags, examples, topn=3): async def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE) template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE)
for ex in examples: for ex in examples:
@ -269,7 +269,7 @@ def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
_, msg = message_fit_in(msg, chat_mdl.max_length) _, msg = message_fit_in(msg, chat_mdl.max_length)
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5}) kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.5})
if isinstance(kwd, tuple): if isinstance(kwd, tuple):
kwd = kwd[0] kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL) kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@ -352,7 +352,7 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
else: else:
template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER) template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc) context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
kwd = await _chat_async(chat_mdl, context, [{"role": "user", "content": "Please analyze it."}]) kwd = await chat_mdl.async_chat(context, [{"role": "user", "content": "Please analyze it."}])
if isinstance(kwd, tuple): if isinstance(kwd, tuple):
kwd = kwd[0] kwd = kwd[0]
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL) kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@ -361,14 +361,6 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
return kwd return kwd
async def _chat_async(chat_mdl, system: str, history: list, **kwargs):
chat_async = getattr(chat_mdl, "async_chat", None)
if chat_async and asyncio.iscoroutinefunction(chat_async):
return await chat_async(system, history, **kwargs)
return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs)
async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}): async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
if not tools_description: if not tools_description:
return "", 0 return "", 0
@ -380,8 +372,7 @@ async def next_step_async(chat_mdl, history:list, tools_description: list[dict],
hist[-1]["content"] += user_prompt hist[-1]["content"] += user_prompt
else: else:
hist.append({"role": "user", "content": user_prompt}) hist.append({"role": "user", "content": user_prompt})
json_str = await _chat_async( json_str = await chat_mdl.async_chat(
chat_mdl,
template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")), template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
hist[1:], hist[1:],
stop=["<|stop|>"], stop=["<|stop|>"],
@ -402,7 +393,7 @@ async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple
else: else:
hist.append({"role": "user", "content": user_prompt}) hist.append({"role": "user", "content": user_prompt})
_, msg = message_fit_in(hist, chat_mdl.max_length) _, msg = message_fit_in(hist, chat_mdl.max_length)
ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:]) ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
return """ return """
**Observation** **Observation**
@ -422,14 +413,14 @@ def structured_output_prompt(schema=None) -> str:
return template.render(schema=schema) return template.render(schema=schema)
def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str: async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY) template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
system_prompt = template.render(name=name, system_prompt = template.render(name=name,
params=json.dumps(params, ensure_ascii=False, indent=2), params=json.dumps(params, ensure_ascii=False, indent=2),
result=result) result=result)
user_prompt = "→ Summary: " user_prompt = "→ Summary: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length) _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:]) ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
@ -438,11 +429,11 @@ async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summar
system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)]) system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
user_prompt = " → rank: " user_prompt = " → rank: "
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length) _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:], stop="<|stop|>") ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], stop="<|stop|>")
return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict: async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
meta_data_structure = {} meta_data_structure = {}
for key, values in meta_data.items(): for key, values in meta_data.items():
meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values
@ -453,7 +444,7 @@ def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
user_question=query user_question=query
) )
user_prompt = "Generate filters:" user_prompt = "Generate filters:"
ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}]) ans = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}])
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL) ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
try: try:
ans = json_repair.loads(ans) ans = json_repair.loads(ans)
@ -466,13 +457,13 @@ def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
return {"conditions": []} return {"conditions": []}
def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None): async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
from graphrag.utils import get_llm_cache, set_llm_cache from graphrag.utils import get_llm_cache, set_llm_cache
cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf) cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
if cached: if cached:
return json_repair.loads(cached) return json_repair.loads(cached)
_, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length) _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
ans = chat_mdl.chat(msg[0]["content"], msg[1:],gen_conf=gen_conf) ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL) ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
try: try:
res = json_repair.loads(ans) res = json_repair.loads(ans)
@@ -483,10 +474,10 @@ def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
TOC_DETECTION = load_prompt("toc_detection")

-def detect_table_of_contents(page_1024:list[str], chat_mdl):
+async def detect_table_of_contents(page_1024:list[str], chat_mdl):
    toc_secs = []
    for i, sec in enumerate(page_1024[:22]):
-       ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
+       ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
        if toc_secs and not ans["exists"]:
            break
        toc_secs.append(sec)

@@ -495,14 +486,14 @@ def detect_table_of_contents(page_1024:list[str], chat_mdl):
TOC_EXTRACTION = load_prompt("toc_extraction")
TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")

-def extract_table_of_contents(toc_pages, chat_mdl):
+async def extract_table_of_contents(toc_pages, chat_mdl):
    if not toc_pages:
        return []
-   return gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)
+   return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)

-def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
+async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
    tob_extractor_prompt = """
    You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format.

@@ -525,11 +516,11 @@ def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
    Directly return the final JSON structure. Do not output anything else."""
    prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content
-   return gen_json(prompt, "Only JSON please.", chat_mdl)
+   return await gen_json(prompt, "Only JSON please.", chat_mdl)

TOC_INDEX = load_prompt("toc_index")

-def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
+async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
    if not toc_arr or not sections:
        return []

@@ -601,7 +592,7 @@ def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
e = toc_arr[e]["indices"][0] e = toc_arr[e]["indices"][0]
for j in range(st_i, min(e+1, len(sections))): for j in range(st_i, min(e+1, len(sections))):
ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render( ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
structure=it["structure"], structure=it["structure"],
title=it["title"], title=it["title"],
text=sections[j]), "Only JSON please.", chat_mdl) text=sections[j]), "Only JSON please.", chat_mdl)
@ -614,7 +605,7 @@ def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
return toc_arr return toc_arr
def check_if_toc_transformation_is_complete(content, toc, chat_mdl): async def check_if_toc_transformation_is_complete(content, toc, chat_mdl):
prompt = """ prompt = """
You are given a raw table of contents and a table of contents. You are given a raw table of contents and a table of contents.
Your job is to check if the table of contents is complete. Your job is to check if the table of contents is complete.
@ -627,11 +618,11 @@ def check_if_toc_transformation_is_complete(content, toc, chat_mdl):
Directly return the final JSON structure. Do not output anything else.""" Directly return the final JSON structure. Do not output anything else."""
prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc
response = gen_json(prompt, "Only JSON please.", chat_mdl) response = await gen_json(prompt, "Only JSON please.", chat_mdl)
return response['completed'] return response['completed']
def toc_transformer(toc_pages, chat_mdl): async def toc_transformer(toc_pages, chat_mdl):
init_prompt = """ init_prompt = """
You are given a table of contents, You job is to transform the whole table of content into a JSON format included table_of_contents. You are given a table of contents, You job is to transform the whole table of content into a JSON format included table_of_contents.
@ -654,8 +645,8 @@ def toc_transformer(toc_pages, chat_mdl):
    def clean_toc(arr):
        for a in arr:
            a["title"] = re.sub(r"[.·….]{2,}", "", a["title"])

-   last_complete = gen_json(prompt, "Only JSON please.", chat_mdl)
-   if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
+   last_complete = await gen_json(prompt, "Only JSON please.", chat_mdl)
+   if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
    clean_toc(last_complete)
    if if_complete == "yes":
        return last_complete

@@ -672,21 +663,21 @@ def toc_transformer(toc_pages, chat_mdl):
        {json.dumps(last_complete[-24:], ensure_ascii=False, indent=2)}
        Please continue the json structure, directly output the remaining part of the json structure."""

-       new_complete = gen_json(prompt, "Only JSON please.", chat_mdl)
+       new_complete = await gen_json(prompt, "Only JSON please.", chat_mdl)
        if not new_complete or str(last_complete).find(str(new_complete)) >= 0:
            break
        clean_toc(new_complete)
        last_complete.extend(new_complete)
-       if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
+       if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)

    return last_complete
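Taken together, the TOC helpers now form an awaitable pipeline. A rough sketch of how they might chain, assuming the functions shown in this file, a pre-split list of page texts, and that `toc_transformer` yields entries carrying a `"title"` field (error handling omitted):

```python
async def build_toc(page_1024: list[str], chat_mdl):
    # 1. find the pages that look like a table of contents
    toc_pages = await detect_table_of_contents(page_1024, chat_mdl)
    if not toc_pages:
        return []
    # 2. transform the raw TOC text into a structured list of entries
    toc = await toc_transformer(toc_pages, chat_mdl)
    # 3. let the LLM assign hierarchy levels to the extracted titles
    return await assign_toc_levels([t["title"] for t in toc], chat_mdl)
```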
TOC_LEVELS = load_prompt("assign_toc_levels")

-def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
+async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
    if not toc_secs:
        return []
-   return gen_json(
+   return await gen_json(
        PROMPT_JINJA_ENV.from_string(TOC_LEVELS).render(),
        str(toc_secs),
        chat_mdl,

@@ -699,7 +690,7 @@ TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
# Generate TOC from text chunks with text llms
async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
    try:
-       ans = gen_json(
+       ans = await gen_json(
            PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
            PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
            chat_mdl,

@@ -782,7 +773,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
    raw_structure = [x.get("title", "") for x in filtered]

    # Assign hierarchy levels using LLM
-   toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9})
+   toc_with_levels = await assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9})
    if not toc_with_levels:
        return []

@@ -807,10 +798,10 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")

-def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
+async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
    import numpy as np
    try:
-       ans = gen_json(
+       ans = await gen_json(
            PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(),
            PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])),
            chat_mdl,
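One consequence of the conversion: purely synchronous call sites need an event loop to drive these helpers. A minimal bridge, assuming no loop is already running in that context (the wrapper name is hypothetical):

```python
import asyncio

def relevant_chunks_sync(query: str, toc: list[dict], chat_mdl, topn: int = 6):
    # For synchronous callers only; inside async code, await the coroutine directly.
    return asyncio.run(relevant_chunks_with_toc(query, toc, chat_mdl, topn))
```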

View File

@@ -323,12 +323,7 @@ async def build_chunks(task, progress_callback):
        cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
        if not cached:
            async with chat_limiter:
-               cached = await asyncio.to_thread(
-                   keyword_extraction,
-                   chat_mdl,
-                   d["content_with_weight"],
-                   topn,
-               )
+               cached = await keyword_extraction(chat_mdl, d["content_with_weight"], topn)
            set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
        if cached:
            d["important_kwd"] = cached.split(",")
@@ -356,12 +351,7 @@ async def build_chunks(task, progress_callback):
        cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
        if not cached:
            async with chat_limiter:
-               cached = await asyncio.to_thread(
-                   question_proposal,
-                   chat_mdl,
-                   d["content_with_weight"],
-                   topn,
-               )
+               cached = await question_proposal(chat_mdl, d["content_with_weight"], topn)
            set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
        if cached:
            d["question_kwd"] = cached.split("\n")

@@ -414,8 +404,7 @@ async def build_chunks(task, progress_callback):
        if not picked_examples:
            picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
        async with chat_limiter:
-           cached = await asyncio.to_thread(
-               content_tagging,
+           cached = await content_tagging(
                chat_mdl,
                d["content_with_weight"],
                all_tags,