diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index 29bd599c6..988351cf6 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -271,7 +271,7 @@ class Agent(LLM, ToolBase): last_calling = "" if len(hist) > 3: st = timer() - user_request = await asyncio.to_thread(full_question, messages=history, chat_mdl=self.chat_mdl) + user_request = await full_question(messages=history, chat_mdl=self.chat_mdl) self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st) else: user_request = history[-1]["content"] @@ -309,7 +309,7 @@ class Agent(LLM, ToolBase): if len(hist) > 12: _hist = [hist[0], hist[1], *hist[-10:]] entire_txt = "" - async for delta_ans in self._generate_streamly_async(_hist): + async for delta_ans in self._generate_streamly(_hist): if not need2cite or cited: yield delta_ans, 0 entire_txt += delta_ans @@ -397,7 +397,7 @@ Respond immediately with your final comprehensive answer. retrievals = self._canvas.get_reference() retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())} formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True) - async for delta_ans in self._generate_streamly_async([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))}, + async for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))}, {"role": "user", "content": text} ]): yield delta_ans diff --git a/agent/component/categorize.py b/agent/component/categorize.py index 1333889bb..27cffb91c 100644 --- a/agent/component/categorize.py +++ b/agent/component/categorize.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import logging import os import re @@ -97,7 +98,7 @@ class Categorize(LLM, ABC): component_name = "Categorize" @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) - def _invoke(self, **kwargs): + async def _invoke_async(self, **kwargs): if self.check_if_canceled("Categorize processing"): return @@ -121,7 +122,7 @@ class Categorize(LLM, ABC): if self.check_if_canceled("Categorize processing"): return - ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf()) + ans = await chat_mdl.async_chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf()) logging.info(f"input: {user_prompt}, answer: {str(ans)}") if ERROR_PREFIX in ans: raise Exception(ans) @@ -144,5 +145,9 @@ class Categorize(LLM, ABC): self.set_output("category_name", max_category) self.set_output("_next", cpn_ids) + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60))) + def _invoke(self, **kwargs): + return asyncio.run(self._invoke_async(**kwargs)) + def thoughts(self) -> str: return "Which should it falls into {}? 
...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()])) diff --git a/agent/component/llm.py b/agent/component/llm.py index a437025e9..39e043aeb 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -18,9 +18,8 @@ import json import logging import os import re -import threading from copy import deepcopy -from typing import Any, Generator, AsyncGenerator +from typing import Any, AsyncGenerator import json_repair from functools import partial from common.constants import LLMType @@ -168,53 +167,12 @@ class LLM(ComponentBase): sys_prompt = re.sub(rf"<{tag}>(.*?)", "", sys_prompt, flags=re.DOTALL|re.IGNORECASE) return pts, sys_prompt - def _generate(self, msg:list[dict], **kwargs) -> str: - if not self.imgs: - return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs) - return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs) - async def _generate_async(self, msg: list[dict], **kwargs) -> str: - if not self.imgs and hasattr(self.chat_mdl, "async_chat"): - return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs) - if self.imgs and hasattr(self.chat_mdl, "async_chat"): - return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs) - return await asyncio.to_thread(self._generate, msg, **kwargs) - - def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]: - ans = "" - last_idx = 0 - endswith_think = False - def delta(txt): - nonlocal ans, last_idx, endswith_think - delta_ans = txt[last_idx:] - ans = txt - - if delta_ans.find("") == 0: - last_idx += len("") - return "" - elif delta_ans.find("") > 0: - delta_ans = txt[last_idx:last_idx+delta_ans.find("")] - last_idx += delta_ans.find("") - return delta_ans - elif delta_ans.endswith(""): - endswith_think = True - elif endswith_think: - endswith_think = False - return "" - - last_idx = len(ans) - if ans.endswith(""): - last_idx -= len("") - return re.sub(r"(|)", "", delta_ans) - if not self.imgs: - for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs): - yield delta(txt) - else: - for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs): - yield delta(txt) + return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs) + return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs) - async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]: + async def _generate_streamly(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]: async def delta_wrapper(txt_iter): ans = "" last_idx = 0 @@ -246,36 +204,13 @@ class LLM(ComponentBase): async for t in txt_iter: yield delta(t) - if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"): + if not self.imgs: async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)): yield t return - if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"): - async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)): - yield t - return - # fallback - loop = asyncio.get_running_loop() - queue: asyncio.Queue = asyncio.Queue() - - def worker(): - try: - for item in 
self._generate_streamly(msg, **kwargs): - loop.call_soon_threadsafe(queue.put_nowait, item) - except Exception as e: - loop.call_soon_threadsafe(queue.put_nowait, e) - finally: - loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration) - - threading.Thread(target=worker, daemon=True).start() - while True: - item = await queue.get() - if item is StopAsyncIteration: - break - if isinstance(item, Exception): - raise item - yield item + async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)): + yield t async def _stream_output_async(self, prompt, msg): _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97)) @@ -407,8 +342,8 @@ class LLM(ComponentBase): def _invoke(self, **kwargs): return asyncio.run(self._invoke_async(**kwargs)) - def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}): - summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt) + async def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}): + summ = await tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt) logging.info(f"[MEMORY]: {summ}") self._canvas.add_memory(user, assist, summ) diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index a0c990a81..1da832e7b 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio from functools import partial import json import os @@ -81,7 +82,7 @@ class Retrieval(ToolBase, ABC): component_name = "Retrieval" @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) - def _invoke(self, **kwargs): + async def _invoke_async(self, **kwargs): if self.check_if_canceled("Retrieval processing"): return @@ -132,7 +133,7 @@ class Retrieval(ToolBase, ABC): metas = DocumentService.get_meta_by_kbs(kb_ids) if self._param.meta_data_filter.get("method") == "auto": chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT) - filters: dict = gen_meta_filter(chat_mdl, metas, query) + filters: dict = await gen_meta_filter(chat_mdl, metas, query) doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not doc_ids: doc_ids = None @@ -142,7 +143,7 @@ class Retrieval(ToolBase, ABC): filtered_metas = {key: metas[key] for key in selected_keys if key in metas} if filtered_metas: chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT) - filters: dict = gen_meta_filter(chat_mdl, filtered_metas, query) + filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, query) doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not doc_ids: doc_ids = None @@ -180,7 +181,7 @@ class Retrieval(ToolBase, ABC): doc_ids = ["-999"] if self._param.cross_languages: - query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages) + query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages) if kbs: query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE) @@ -253,6 +254,10 @@ class Retrieval(ToolBase, ABC): return form_cnt + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12))) + def _invoke(self, **kwargs): + return asyncio.run(self._invoke_async(**kwargs)) + def thoughts(self) -> str: return """ 
Keywords: {} diff --git a/agentic_reasoning/deep_research.py b/agentic_reasoning/deep_research.py index d7121245f..20f7017f4 100644 --- a/agentic_reasoning/deep_research.py +++ b/agentic_reasoning/deep_research.py @@ -51,7 +51,7 @@ class DeepResearcher: """Remove Result Tags""" return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT) - def _generate_reasoning(self, msg_history): + async def _generate_reasoning(self, msg_history): """Generate reasoning steps""" query_think = "" if msg_history[-1]["role"] != "user": msg_history.append({"role": "user", "content": "Continues reasoning with the new information.\n"}) else: msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n" - for ans in self.chat_mdl.chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}): + async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}): ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) if not ans: continue query_think = ans yield query_think - return query_think + query_think = "" + yield query_think def _extract_search_queries(self, query_think, question, step_index): """Extract search queries from thinking""" @@ -143,10 +144,10 @@ class DeepResearcher: if d["doc_id"] not in dids: chunk_info["doc_aggs"].append(d) - def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos): + async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos): """Extract and summarize relevant information""" summary_think = "" - for ans in self.chat_mdl.chat_streamly( + async for ans in self.chat_mdl.async_chat_streamly( RELEVANT_EXTRACTION_PROMPT.format( prev_reasoning=truncated_prev_reasoning, search_query=search_query, @@ -160,10 +161,11 @@ continue summary_think = ans yield summary_think + summary_think = "" - return summary_think + yield summary_think - def thinking(self, chunk_info: dict, question: str): + async def thinking(self, chunk_info: dict, question: str): executed_search_queries = [] msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}] all_reasoning_steps = [] @@ -180,7 +182,7 @@ # Step 1: Generate reasoning query_think = "" - for ans in self._generate_reasoning(msg_history): + async for ans in self._generate_reasoning(msg_history): query_think = ans yield {"answer": think + self._remove_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None} @@ -223,7 +225,7 @@ # Step 6: Extract relevant information think += "\n\n" summary_think = "" - for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos): + async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos): summary_think = ans yield {"answer": think + self._remove_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None} diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 37cd0c7a1..af6bb6617 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -313,7 +313,7 @@ async def retrieval_test(): langs = req.get("cross_languages", []) user_id = current_user.id - def _retrieval_sync(): + async def _retrieval(): local_doc_ids = list(doc_ids) if doc_ids else [] tenant_ids = [] @@ -323,7 +323,7 @@ metas = DocumentService.get_meta_by_kbs(kb_ids) if meta_data_filter.get("method") == "auto": chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) - filters: dict = gen_meta_filter(chat_mdl, metas, question) + filters: dict = await 
gen_meta_filter(chat_mdl, metas, question) local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not local_doc_ids: local_doc_ids = None @@ -333,7 +333,7 @@ async def retrieval_test(): filtered_metas = {key: metas[key] for key in selected_keys if key in metas} if filtered_metas: chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) - filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question) + filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question) local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not local_doc_ids: local_doc_ids = None @@ -347,7 +347,7 @@ async def retrieval_test(): metas = DocumentService.get_meta_by_kbs(kb_ids) if meta_data_filter.get("method") == "auto": chat_mdl = LLMBundle(user_id, LLMType.CHAT) - filters: dict = gen_meta_filter(chat_mdl, metas, question) + filters: dict = await gen_meta_filter(chat_mdl, metas, question) local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not local_doc_ids: local_doc_ids = None @@ -357,7 +357,7 @@ async def retrieval_test(): filtered_metas = {key: metas[key] for key in selected_keys if key in metas} if filtered_metas: chat_mdl = LLMBundle(user_id, LLMType.CHAT) - filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question) + filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question) local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not local_doc_ids: local_doc_ids = None @@ -384,7 +384,7 @@ async def retrieval_test(): _question = question if langs: - _question = cross_languages(kb.tenant_id, None, _question, langs) + _question = await cross_languages(kb.tenant_id, None, _question, langs) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) @@ -394,7 +394,7 @@ async def retrieval_test(): if req.get("keyword", False): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) - _question += keyword_extraction(chat_mdl, _question) + _question += await keyword_extraction(chat_mdl, _question) labels = label_question(_question, [kb]) ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size, @@ -421,7 +421,7 @@ async def retrieval_test(): return get_json_result(data=ranks) try: - return await asyncio.to_thread(_retrieval_sync) + return await _retrieval() except Exception as e: if str(e).find("not_found") > 0: return get_json_result(data=False, message='No chunk found! 
Check the chunk status please!', diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index b65a20133..19fabdcff 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -1549,11 +1549,11 @@ async def retrieval_test(tenant_id): rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK, llm_name=req["rerank_id"]) if langs: - question = cross_languages(kb.tenant_id, None, question, langs) + question = await cross_languages(kb.tenant_id, None, question, langs) if req.get("keyword", False): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) - question += keyword_extraction(chat_mdl, question) + question += await keyword_extraction(chat_mdl, question) ranks = settings.retriever.retrieval( question, diff --git a/api/apps/sdk/files.py b/api/apps/sdk/files.py index 2e9fd6df3..8bac19ccd 100644 --- a/api/apps/sdk/files.py +++ b/api/apps/sdk/files.py @@ -33,6 +33,7 @@ from api.utils.web_utils import CONTENT_TYPE_MAP from common import settings from common.constants import RetCode + @manager.route('/file/upload', methods=['POST']) # noqa: F821 @token_required async def upload(tenant_id): diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index cb4d78f3b..d4db3cb56 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import asyncio import json import re import time @@ -44,6 +43,7 @@ from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extr from common.constants import RetCode, LLMType, StatusEnum from common import settings + @manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821 @token_required async def create(tenant_id, chat_id): @@ -969,7 +969,7 @@ async def retrieval_test_embedded(): if not tenant_id: return get_error_data_result(message="permission denined.") - def _retrieval_sync(): + async def _retrieval(): local_doc_ids = list(doc_ids) if doc_ids else [] tenant_ids = [] _question = question @@ -980,7 +980,7 @@ metas = DocumentService.get_meta_by_kbs(kb_ids) if meta_data_filter.get("method") == "auto": chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) - filters: dict = gen_meta_filter(chat_mdl, metas, _question) + filters: dict = await gen_meta_filter(chat_mdl, metas, _question) local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not local_doc_ids: local_doc_ids = None @@ -990,7 +990,7 @@ filtered_metas = {key: metas[key] for key in selected_keys if key in metas} if filtered_metas: chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) - filters: dict = gen_meta_filter(chat_mdl, filtered_metas, _question) + filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, _question) local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not local_doc_ids: local_doc_ids = None @@ -1004,7 +1004,7 @@ metas = DocumentService.get_meta_by_kbs(kb_ids) if meta_data_filter.get("method") == "auto": chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) - filters: dict = gen_meta_filter(chat_mdl, metas, question) + filters: dict = await gen_meta_filter(chat_mdl, metas, question) local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not local_doc_ids: local_doc_ids = None @@ -1014,7 +1014,7 @@ 
filtered_metas = {key: metas[key] for key in selected_keys if key in metas} if filtered_metas: chat_mdl = LLMBundle(tenant_id, LLMType.CHAT) - filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question) + filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question) local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not local_doc_ids: local_doc_ids = None @@ -1038,7 +1038,7 @@ async def retrieval_test_embedded(): return get_error_data_result(message="Knowledgebase not found!") if langs: - _question = cross_languages(kb.tenant_id, None, _question, langs) + _question = await cross_languages(kb.tenant_id, None, _question, langs) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) @@ -1048,7 +1048,7 @@ async def retrieval_test_embedded(): if req.get("keyword", False): chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT) - _question += keyword_extraction(chat_mdl, _question) + _question += await keyword_extraction(chat_mdl, _question) labels = label_question(_question, [kb]) ranks = settings.retriever.retrieval( @@ -1068,7 +1068,7 @@ async def retrieval_test_embedded(): return get_json_result(data=ranks) try: - return await asyncio.to_thread(_retrieval_sync) + return await _retrieval() except Exception as e: if str(e).find("not_found") > 0: return get_json_result(data=False, message="No chunk found! Check the chunk status please!", diff --git a/api/db/init_data.py b/api/db/init_data.py index 7454965eb..1ebc306d3 100644 --- a/api/db/init_data.py +++ b/api/db/init_data.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import logging import json import os @@ -76,8 +77,7 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_ f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.") chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"]) - msg = chat_mdl.chat(system="", history=[ - {"role": "user", "content": "Hello!"}], gen_conf={}) + msg = asyncio.run(chat_mdl.async_chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={})) if msg.find("ERROR: ") == 0: logging.error( "'{}' doesn't work. 
{}".format( diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 88f61f190..0fded53f6 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -397,7 +397,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): # try to use sql if field mapping is good to go if field_map: logging.debug("Use SQL to retrieval:{}".format(questions[-1])) - ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) + ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) if ans: yield ans return @@ -411,17 +411,17 @@ async def async_chat(dialog, messages, stream=True, **kwargs): prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ") if len(questions) > 1 and prompt_config.get("refine_multiturn"): - questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)] + questions = [await full_question(dialog.tenant_id, dialog.llm_id, messages)] else: questions = questions[-1:] if prompt_config.get("cross_languages"): - questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])] + questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])] if dialog.meta_data_filter: metas = DocumentService.get_meta_by_kbs(dialog.kb_ids) if dialog.meta_data_filter.get("method") == "auto": - filters: dict = gen_meta_filter(chat_mdl, metas, questions[-1]) + filters: dict = await gen_meta_filter(chat_mdl, metas, questions[-1]) attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not attachments: attachments = None @@ -430,7 +430,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): if selected_keys: filtered_metas = {key: metas[key] for key in selected_keys if key in metas} if filtered_metas: - filters: dict = gen_meta_filter(chat_mdl, filtered_metas, questions[-1]) + filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, questions[-1]) attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not attachments: attachments = None @@ -441,7 +441,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): attachments = ["-999"] if prompt_config.get("keyword", False): - questions[-1] += keyword_extraction(chat_mdl, questions[-1]) + questions[-1] += await keyword_extraction(chat_mdl, questions[-1]) refine_question_ts = timer() @@ -469,7 +469,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): ), ) - for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)): + async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)): if isinstance(think, str): thought = think knowledges = [t for t in think.split("\n") if t] @@ -646,7 +646,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): return -def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): +async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): sys_prompt = """ You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question. Ensure that: @@ -664,9 +664,9 @@ Please write the SQL, only SQL, without any other explanations or text. 
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question) tried_times = 0 - def get_table(): + async def get_table(): nonlocal sys_prompt, user_prompt, question, tried_times - sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06}) + sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06}) sql = re.sub(r"^.*", "", sql, flags=re.DOTALL) logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}") sql = re.sub(r"[\r\n]+", " ", sql.lower()) @@ -705,7 +705,7 @@ Please write the SQL, only SQL, without any other explanations or text. return settings.retriever.sql_retrieval(sql, format="json"), sql try: - tbl, sql = get_table() + tbl, sql = await get_table() except Exception as e: user_prompt = """ Table name: {}; @@ -723,7 +723,7 @@ Please write the SQL, only SQL, without any other explanations or text. Please correct the error and write SQL again, only SQL, without any other explanations or text. """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e) try: - tbl, sql = get_table() + tbl, sql = await get_table() except Exception: return @@ -839,7 +839,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf if meta_data_filter: metas = DocumentService.get_meta_by_kbs(kb_ids) if meta_data_filter.get("method") == "auto": - filters: dict = gen_meta_filter(chat_mdl, metas, question) + filters: dict = await gen_meta_filter(chat_mdl, metas, question) doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not doc_ids: doc_ids = None @@ -848,7 +848,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf if selected_keys: filtered_metas = {key: metas[key] for key in selected_keys if key in metas} if filtered_metas: - filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question) + filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question) doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not doc_ids: doc_ids = None @@ -923,7 +923,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}): if meta_data_filter: metas = DocumentService.get_meta_by_kbs(kb_ids) if meta_data_filter.get("method") == "auto": - filters: dict = gen_meta_filter(chat_mdl, metas, question) + filters: dict = await gen_meta_filter(chat_mdl, metas, question) doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not doc_ids: doc_ids = None @@ -932,7 +932,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}): if selected_keys: filtered_metas = {key: metas[key] for key in selected_keys if key in metas} if filtered_metas: - filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question) + filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question) doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and"))) if not doc_ids: doc_ids = None diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index e4bf64aac..e5505af88 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -318,9 +318,6 @@ class LLMBundle(LLM4Tenant): return value raise value - def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str: - return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs)) - def 
_sync_from_async_stream(self, async_gen_fn, *args, **kwargs): result_queue: queue.Queue = queue.Queue() @@ -350,23 +347,6 @@ class LLMBundle(LLM4Tenant): raise item yield item - def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs): - ans = "" - for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs): - if isinstance(txt, int): - break - - if txt.endswith("</think>"): - ans = txt[: -len("</think>")] - continue - - if not self.verbose_tool_use: - txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL) - - # cancatination has beend done in async_chat_streamly - ans = txt - yield ans - def _bridge_sync_stream(self, gen): loop = asyncio.get_running_loop() queue: asyncio.Queue = asyncio.Queue() diff --git a/rag/flow/extractor/extractor.py b/rag/flow/extractor/extractor.py index 1b97fd1ee..8061086a7 100644 --- a/rag/flow/extractor/extractor.py +++ b/rag/flow/extractor/extractor.py @@ -98,7 +98,7 @@ class Extractor(ProcessBase, LLM): args[chunks_key] = ck["text"] msg, sys_prompt = self._sys_prompt_and_msg([], args) msg.insert(0, {"role": "system", "content": sys_prompt}) - ck[self._param.field_name] = self._generate(msg) + ck[self._param.field_name] = await self._generate_async(msg) prog += 1./len(chunks) if i % (len(chunks)//100+1) == 1: self.callback(prog, f"{i+1} / {len(chunks)}") @@ -106,6 +106,6 @@ else: msg, sys_prompt = self._sys_prompt_and_msg([], args) msg.insert(0, {"role": "system", "content": sys_prompt}) - self.set_output("chunks", [{self._param.field_name: self._generate(msg)}]) + self.set_output("chunks", [{self._param.field_name: await self._generate_async(msg)}]) diff --git a/rag/nlp/rag_tokenizer.py b/rag/nlp/rag_tokenizer.py index c50e84ebc..494e1915b 100644 --- a/rag/nlp/rag_tokenizer.py +++ b/rag/nlp/rag_tokenizer.py @@ -33,6 +33,22 @@ class RagTokenizer(infinity.rag_tokenizer.RagTokenizer): return super().fine_grained_tokenize(tks) +def is_chinese(s): + return infinity.rag_tokenizer.is_chinese(s) + + +def is_number(s): + return infinity.rag_tokenizer.is_number(s) + + +def is_alphabet(s): + return infinity.rag_tokenizer.is_alphabet(s) + + +def naive_qie(txt): + return infinity.rag_tokenizer.naive_qie(txt) + + tokenizer = RagTokenizer() tokenize = tokenizer.tokenize fine_grained_tokenize = tokenizer.fine_grained_tokenize diff --git a/rag/nlp/search.py b/rag/nlp/search.py index f5dd2d4de..d2129e77f 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import asyncio import json import logging import re @@ -607,7 +608,7 @@ class Dealer: if not toc: return chunks - ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn*2) + ids = asyncio.run(relevant_chunks_with_toc(query, toc, chat_mdl, topn*2)) if not ids: return chunks diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index 523935277..621a460ad 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -170,13 +170,13 @@ def citation_plus(sources: str) -> str: return template.render(example=citation_prompt(), sources=sources) -def keyword_extraction(chat_mdl, content, topn=3): +async def keyword_extraction(chat_mdl, content, topn=3): template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE) rendered_prompt = template.render(content=content, topn=topn) msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] _, msg = message_fit_in(msg, chat_mdl.max_length) - kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2}) + kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.2}) if isinstance(kwd, tuple): kwd = kwd[0] kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL) @@ -185,13 +185,13 @@ return kwd -def question_proposal(chat_mdl, content, topn=3): +async def question_proposal(chat_mdl, content, topn=3): template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE) rendered_prompt = template.render(content=content, topn=topn) msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] _, msg = message_fit_in(msg, chat_mdl.max_length) - kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2}) + kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.2}) if isinstance(kwd, tuple): kwd = kwd[0] kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL) @@ -200,7 +200,7 @@ return kwd -def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None): +async def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None): from common.constants import LLMType from api.db.services.llm_service import LLMBundle from api.db.services.tenant_llm_service import TenantLLMService @@ -229,12 +229,12 @@ language=language, ) - ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}]) + ans = await chat_mdl.async_chat(rendered_prompt, [{"role": "user", "content": "Output: "}]) ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"] -def cross_languages(tenant_id, llm_id, query, languages=[]): +async def cross_languages(tenant_id, llm_id, query, languages=[]): from common.constants import LLMType from api.db.services.llm_service import LLMBundle from api.db.services.tenant_llm_service import TenantLLMService @@ -247,14 +247,14 @@ rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render() rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages) - ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2}) + ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], 
{"temperature": 0.2}) ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) if ans.find("**ERROR**") >= 0: return query return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()]) -def content_tagging(chat_mdl, content, all_tags, examples, topn=3): +async def content_tagging(chat_mdl, content, all_tags, examples, topn=3): template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE) for ex in examples: @@ -269,7 +269,7 @@ def content_tagging(chat_mdl, content, all_tags, examples, topn=3): msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}] _, msg = message_fit_in(msg, chat_mdl.max_length) - kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5}) + kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.5}) if isinstance(kwd, tuple): kwd = kwd[0] kwd = re.sub(r"^.*", "", kwd, flags=re.DOTALL) @@ -352,7 +352,7 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis else: template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER) context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc) - kwd = await _chat_async(chat_mdl, context, [{"role": "user", "content": "Please analyze it."}]) + kwd = await chat_mdl.async_chat(context, [{"role": "user", "content": "Please analyze it."}]) if isinstance(kwd, tuple): kwd = kwd[0] kwd = re.sub(r"^.*", "", kwd, flags=re.DOTALL) @@ -361,14 +361,6 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis return kwd -async def _chat_async(chat_mdl, system: str, history: list, **kwargs): - chat_async = getattr(chat_mdl, "async_chat", None) - if chat_async and asyncio.iscoroutinefunction(chat_async): - return await chat_async(system, history, **kwargs) - return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs) - - - async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}): if not tools_description: return "", 0 @@ -380,8 +372,7 @@ async def next_step_async(chat_mdl, history:list, tools_description: list[dict], hist[-1]["content"] += user_prompt else: hist.append({"role": "user", "content": user_prompt}) - json_str = await _chat_async( - chat_mdl, + json_str = await chat_mdl.async_chat( template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")), hist[1:], stop=["<|stop|>"], @@ -402,7 +393,7 @@ async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple else: hist.append({"role": "user", "content": user_prompt}) _, msg = message_fit_in(hist, chat_mdl.max_length) - ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:]) + ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:]) ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) return """ **Observation** @@ -422,14 +413,14 @@ def structured_output_prompt(schema=None) -> str: return template.render(schema=schema) -def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str: +async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str: template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY) system_prompt = template.render(name=name, params=json.dumps(params, ensure_ascii=False, indent=2), result=result) user_prompt = "→ Summary: " _, msg = message_fit_in(form_message(system_prompt, user_prompt), 
chat_mdl.max_length) - ans = chat_mdl.chat(msg[0]["content"], msg[1:]) + ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:]) return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) @@ -438,11 +429,11 @@ system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)]) user_prompt = " → rank: " _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length) - ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:], stop="<|stop|>") + ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], stop="<|stop|>") return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL) -def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict: +async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict: meta_data_structure = {} for key, values in meta_data.items(): meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values @@ -453,7 +444,7 @@ user_question=query ) user_prompt = "Generate filters:" - ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}]) + ans = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}]) ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL) try: ans = json_repair.loads(ans) @@ -466,13 +457,13 @@ return {"conditions": []} -def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None): +async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None): from graphrag.utils import get_llm_cache, set_llm_cache cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf) if cached: return json_repair.loads(cached) _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length) - ans = chat_mdl.chat(msg[0]["content"], msg[1:],gen_conf=gen_conf) + ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:],gen_conf=gen_conf) ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL) try: res = json_repair.loads(ans) @@ -483,10 +474,10 @@ TOC_DETECTION = load_prompt("toc_detection") -def detect_table_of_contents(page_1024:list[str], chat_mdl): +async def detect_table_of_contents(page_1024:list[str], chat_mdl): toc_secs = [] for i, sec in enumerate(page_1024[:22]): - ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl) + ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl) if toc_secs and not ans["exists"]: break toc_secs.append(sec) @@ -495,14 +486,14 @@ TOC_EXTRACTION = load_prompt("toc_extraction") TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue") -def extract_table_of_contents(toc_pages, chat_mdl): +async def extract_table_of_contents(toc_pages, chat_mdl): if not toc_pages: return [] - return gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl) + return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl) -def toc_index_extractor(toc:list[dict], content:str, chat_mdl): +async def 
toc_index_extractor(toc:list[dict], content:str, chat_mdl): tob_extractor_prompt = """ You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format. @@ -525,11 +516,11 @@ def toc_index_extractor(toc:list[dict], content:str, chat_mdl): Directly return the final JSON structure. Do not output anything else.""" prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content - return gen_json(prompt, "Only JSON please.", chat_mdl) + return await gen_json(prompt, "Only JSON please.", chat_mdl) TOC_INDEX = load_prompt("toc_index") -def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl): +async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl): if not toc_arr or not sections: return [] @@ -601,7 +592,7 @@ def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl): e = toc_arr[e]["indices"][0] for j in range(st_i, min(e+1, len(sections))): - ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render( + ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render( structure=it["structure"], title=it["title"], text=sections[j]), "Only JSON please.", chat_mdl) @@ -614,7 +605,7 @@ def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl): return toc_arr -def check_if_toc_transformation_is_complete(content, toc, chat_mdl): +async def check_if_toc_transformation_is_complete(content, toc, chat_mdl): prompt = """ You are given a raw table of contents and a table of contents. Your job is to check if the table of contents is complete. @@ -627,11 +618,11 @@ def check_if_toc_transformation_is_complete(content, toc, chat_mdl): Directly return the final JSON structure. Do not output anything else.""" prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc - response = gen_json(prompt, "Only JSON please.", chat_mdl) + response = await gen_json(prompt, "Only JSON please.", chat_mdl) return response['completed'] -def toc_transformer(toc_pages, chat_mdl): +async def toc_transformer(toc_pages, chat_mdl): init_prompt = """ You are given a table of contents, You job is to transform the whole table of content into a JSON format included table_of_contents. 
@@ -654,8 +645,8 @@ def toc_transformer(toc_pages, chat_mdl): def clean_toc(arr): for a in arr: a["title"] = re.sub(r"[.·….]{2,}", "", a["title"]) - last_complete = gen_json(prompt, "Only JSON please.", chat_mdl) - if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl) + last_complete = await gen_json(prompt, "Only JSON please.", chat_mdl) + if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl) clean_toc(last_complete) if if_complete == "yes": return last_complete @@ -672,21 +663,21 @@ def toc_transformer(toc_pages, chat_mdl): {json.dumps(last_complete[-24:], ensure_ascii=False, indent=2)} Please continue the json structure, directly output the remaining part of the json structure.""" - new_complete = gen_json(prompt, "Only JSON please.", chat_mdl) + new_complete = await gen_json(prompt, "Only JSON please.", chat_mdl) if not new_complete or str(last_complete).find(str(new_complete)) >= 0: break clean_toc(new_complete) last_complete.extend(new_complete) - if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl) + if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl) return last_complete TOC_LEVELS = load_prompt("assign_toc_levels") -def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}): +async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}): if not toc_secs: return [] - return gen_json( + return await gen_json( PROMPT_JINJA_ENV.from_string(TOC_LEVELS).render(), str(toc_secs), chat_mdl, @@ -699,7 +690,7 @@ TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user") # Generate TOC from text chunks with text llms async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None): try: - ans = gen_json( + ans = await gen_json( PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(), PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])), chat_mdl, @@ -782,7 +773,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None): raw_structure = [x.get("title", "") for x in filtered] # Assign hierarchy levels using LLM - toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9}) + toc_with_levels = await assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9}) if not toc_with_levels: return [] @@ -807,10 +798,10 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None): TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system") TOC_RELEVANCE_USER = load_prompt("toc_relevance_user") -def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6): +async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6): import numpy as np try: - ans = gen_json( + ans = await gen_json( PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(), PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])), chat_mdl, diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 0094c081c..1a0c51600 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -323,12 +323,7 @@ async def build_chunks(task, 
progress_callback): cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn}) if not cached: async with chat_limiter: - cached = await asyncio.to_thread( - keyword_extraction, - chat_mdl, - d["content_with_weight"], - topn, - ) + cached = await keyword_extraction(chat_mdl, d["content_with_weight"], topn) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn}) if cached: d["important_kwd"] = cached.split(",") @@ -356,12 +351,7 @@ async def build_chunks(task, progress_callback): cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn}) if not cached: async with chat_limiter: - cached = await asyncio.to_thread( - question_proposal, - chat_mdl, - d["content_with_weight"], - topn, - ) + cached = await question_proposal(chat_mdl, d["content_with_weight"], topn) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn}) if cached: d["question_kwd"] = cached.split("\n") @@ -414,8 +404,7 @@ async def build_chunks(task, progress_callback): if not picked_examples: picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}}) async with chat_limiter: - cached = await asyncio.to_thread( - content_tagging, + cached = await content_tagging( chat_mdl, d["content_with_weight"], all_tags,