Mirror of https://github.com/infiniflow/ragflow.git (synced 2025-12-19 03:56:42 +08:00)
Fix: tokenizer issue. (#11902)
#11786

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
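The recurring pattern in the hunks below is a move to the async call path: synchronous `chat_mdl.chat(...)` / `chat_mdl.chat_streamly(...)` calls become `await chat_mdl.async_chat(...)` / `async for ... in chat_mdl.async_chat_streamly(...)`, the prompt helpers that wrap them become coroutines, and components keep a thin synchronous `_invoke` that drives a new `_invoke_async` coroutine with `asyncio.run`. A minimal sketch of that bridge follows; the class and chat model here are illustrative stand-ins, not RAGFlow's actual code.

```python
import asyncio


class FakeChatModel:
    """Stand-in chat model exposing the async_chat signature used in the diff."""

    async def async_chat(self, system, history, gen_conf, **kwargs):
        return f"echo: {history[-1]['content']}"


class ExampleComponent:
    """Minimal sketch of the sync-to-async bridge used throughout this diff.

    ExampleComponent and FakeChatModel are assumptions for illustration only;
    the real components are Categorize, Retrieval, etc. in the hunks below.
    """

    def __init__(self, chat_mdl):
        self.chat_mdl = chat_mdl

    async def _invoke_async(self, **kwargs):
        # The real work is a coroutine and awaits the async LLM API directly.
        return await self.chat_mdl.async_chat(
            "You are a helpful assistant.",
            [{"role": "user", "content": kwargs.get("query", "")}],
            {},
        )

    def _invoke(self, **kwargs):
        # Synchronous callers keep the old entry point; it just runs the coroutine.
        return asyncio.run(self._invoke_async(**kwargs))


if __name__ == "__main__":
    print(ExampleComponent(FakeChatModel())._invoke(query="hello"))
```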
@@ -271,7 +271,7 @@ class Agent(LLM, ToolBase):
last_calling = ""
if len(hist) > 3:
st = timer()
- user_request = await asyncio.to_thread(full_question, messages=history, chat_mdl=self.chat_mdl)
+ user_request = await full_question(messages=history, chat_mdl=self.chat_mdl)
self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
else:
user_request = history[-1]["content"]

@@ -309,7 +309,7 @@ class Agent(LLM, ToolBase):
if len(hist) > 12:
_hist = [hist[0], hist[1], *hist[-10:]]
entire_txt = ""
- async for delta_ans in self._generate_streamly_async(_hist):
+ async for delta_ans in self._generate_streamly(_hist):
if not need2cite or cited:
yield delta_ans, 0
entire_txt += delta_ans

@@ -397,7 +397,7 @@ Respond immediately with your final comprehensive answer.
retrievals = self._canvas.get_reference()
retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
formated_refer = kb_prompt(retrievals, self.chat_mdl.max_length, True)
- async for delta_ans in self._generate_streamly_async([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
+ async for delta_ans in self._generate_streamly([{"role": "system", "content": citation_plus("\n\n".join(formated_refer))},
{"role": "user", "content": text}
]):
yield delta_ans
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+ import asyncio
import logging
import os
import re

@@ -97,7 +98,7 @@ class Categorize(LLM, ABC):
component_name = "Categorize"

@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
- def _invoke(self, **kwargs):
+ async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Categorize processing"):
return

@@ -121,7 +122,7 @@ class Categorize(LLM, ABC):
if self.check_if_canceled("Categorize processing"):
return

- ans = chat_mdl.chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
+ ans = await chat_mdl.async_chat(self._param.sys_prompt, [{"role": "user", "content": user_prompt}], self._param.gen_conf())
logging.info(f"input: {user_prompt}, answer: {str(ans)}")
if ERROR_PREFIX in ans:
raise Exception(ans)

@@ -144,5 +145,9 @@ class Categorize(LLM, ABC):
self.set_output("category_name", max_category)
self.set_output("_next", cpn_ids)

+ @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10*60)))
+ def _invoke(self, **kwargs):
+ return asyncio.run(self._invoke_async(**kwargs))

def thoughts(self) -> str:
return "Which should it falls into {}? ...".format(",".join([f"`{c}`" for c, _ in self._param.category_description.items()]))
@@ -18,9 +18,8 @@ import json
import logging
import os
import re
- import threading
from copy import deepcopy
- from typing import Any, Generator, AsyncGenerator
+ from typing import Any, AsyncGenerator
import json_repair
from functools import partial
from common.constants import LLMType

@@ -168,53 +167,12 @@ class LLM(ComponentBase):
sys_prompt = re.sub(rf"<{tag}>(.*?)</{tag}>", "", sys_prompt, flags=re.DOTALL|re.IGNORECASE)
return pts, sys_prompt

- def _generate(self, msg:list[dict], **kwargs) -> str:
- if not self.imgs:
- return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
- return self.chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)

async def _generate_async(self, msg: list[dict], **kwargs) -> str:
- if not self.imgs and hasattr(self.chat_mdl, "async_chat"):
- return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
- if self.imgs and hasattr(self.chat_mdl, "async_chat"):
- return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
- return await asyncio.to_thread(self._generate, msg, **kwargs)

- def _generate_streamly(self, msg:list[dict], **kwargs) -> Generator[str, None, None]:
- ans = ""
- last_idx = 0
- endswith_think = False
- def delta(txt):
- nonlocal ans, last_idx, endswith_think
- delta_ans = txt[last_idx:]
- ans = txt

- if delta_ans.find("<think>") == 0:
- last_idx += len("<think>")
- return "<think>"
- elif delta_ans.find("<think>") > 0:
- delta_ans = txt[last_idx:last_idx+delta_ans.find("<think>")]
- last_idx += delta_ans.find("<think>")
- return delta_ans
- elif delta_ans.endswith("</think>"):
- endswith_think = True
- elif endswith_think:
- endswith_think = False
- return "</think>"

- last_idx = len(ans)
- if ans.endswith("</think>"):
- last_idx -= len("</think>")
- return re.sub(r"(<think>|</think>)", "", delta_ans)

if not self.imgs:
- for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs):
- yield delta(txt)
+ return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)
+ return await self.chat_mdl.async_chat(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)
- else:
- for txt in self.chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs):
- yield delta(txt)

- async def _generate_streamly_async(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
+ async def _generate_streamly(self, msg: list[dict], **kwargs) -> AsyncGenerator[str, None]:
async def delta_wrapper(txt_iter):
ans = ""
last_idx = 0

@@ -246,36 +204,13 @@ class LLM(ComponentBase):
async for t in txt_iter:
yield delta(t)

- if not self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
+ if not self.imgs:
async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **kwargs)):
yield t
return
- if self.imgs and hasattr(self.chat_mdl, "async_chat_streamly"):
- async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
- yield t
- return

- # fallback
- loop = asyncio.get_running_loop()
- queue: asyncio.Queue = asyncio.Queue()
+ async for t in delta_wrapper(self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), images=self.imgs, **kwargs)):
+ yield t

- def worker():
- try:
- for item in self._generate_streamly(msg, **kwargs):
- loop.call_soon_threadsafe(queue.put_nowait, item)
- except Exception as e:
- loop.call_soon_threadsafe(queue.put_nowait, e)
- finally:
- loop.call_soon_threadsafe(queue.put_nowait, StopAsyncIteration)

- threading.Thread(target=worker, daemon=True).start()
- while True:
- item = await queue.get()
- if item is StopAsyncIteration:
- break
- if isinstance(item, Exception):
- raise item
- yield item

async def _stream_output_async(self, prompt, msg):
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))

@@ -407,8 +342,8 @@ class LLM(ComponentBase):
def _invoke(self, **kwargs):
return asyncio.run(self._invoke_async(**kwargs))

- def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
- summ = tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
+ async def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
+ summ = await tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
logging.info(f"[MEMORY]: {summ}")
self._canvas.add_memory(user, assist, summ)
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+ import asyncio
from functools import partial
import json
import os

@@ -81,7 +82,7 @@ class Retrieval(ToolBase, ABC):
component_name = "Retrieval"

@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
- def _invoke(self, **kwargs):
+ async def _invoke_async(self, **kwargs):
if self.check_if_canceled("Retrieval processing"):
return

@@ -132,7 +133,7 @@ class Retrieval(ToolBase, ABC):
metas = DocumentService.get_meta_by_kbs(kb_ids)
if self._param.meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
- filters: dict = gen_meta_filter(chat_mdl, metas, query)
+ filters: dict = await gen_meta_filter(chat_mdl, metas, query)
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
doc_ids = None

@@ -142,7 +143,7 @@ class Retrieval(ToolBase, ABC):
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
if filtered_metas:
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT)
- filters: dict = gen_meta_filter(chat_mdl, filtered_metas, query)
+ filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, query)
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
doc_ids = None

@@ -180,7 +181,7 @@ class Retrieval(ToolBase, ABC):
doc_ids = ["-999"]

if self._param.cross_languages:
- query = cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)
+ query = await cross_languages(kbs[0].tenant_id, None, query, self._param.cross_languages)

if kbs:
query = re.sub(r"^user[::\s]*", "", query, flags=re.IGNORECASE)

@@ -253,6 +254,10 @@ class Retrieval(ToolBase, ABC):

return form_cnt

+ @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
+ def _invoke(self, **kwargs):
+ return asyncio.run(self._invoke_async(**kwargs))

def thoughts(self) -> str:
return """
Keywords: {}
@@ -51,7 +51,7 @@ class DeepResearcher:
"""Remove Result Tags"""
return DeepResearcher._remove_tags(text, BEGIN_SEARCH_RESULT, END_SEARCH_RESULT)

- def _generate_reasoning(self, msg_history):
+ async def _generate_reasoning(self, msg_history):
"""Generate reasoning steps"""
query_think = ""
if msg_history[-1]["role"] != "user":

@@ -59,13 +59,14 @@ class DeepResearcher:
else:
msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"

- for ans in self.chat_mdl.chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
+ async for ans in self.chat_mdl.async_chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
if not ans:
continue
query_think = ans
yield query_think
- return query_think
+ query_think = ""
+ yield query_think

def _extract_search_queries(self, query_think, question, step_index):
"""Extract search queries from thinking"""

@@ -143,10 +144,10 @@ class DeepResearcher:
if d["doc_id"] not in dids:
chunk_info["doc_aggs"].append(d)

- def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
+ async def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
"""Extract and summarize relevant information"""
summary_think = ""
- for ans in self.chat_mdl.chat_streamly(
+ async for ans in self.chat_mdl.async_chat_streamly(
RELEVANT_EXTRACTION_PROMPT.format(
prev_reasoning=truncated_prev_reasoning,
search_query=search_query,

@@ -160,10 +161,11 @@ class DeepResearcher:
continue
summary_think = ans
yield summary_think
+ summary_think = ""

- return summary_think
+ yield summary_think

- def thinking(self, chunk_info: dict, question: str):
+ async def thinking(self, chunk_info: dict, question: str):
executed_search_queries = []
msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
all_reasoning_steps = []

@@ -180,7 +182,7 @@ class DeepResearcher:

# Step 1: Generate reasoning
query_think = ""
- for ans in self._generate_reasoning(msg_history):
+ async for ans in self._generate_reasoning(msg_history):
query_think = ans
yield {"answer": think + self._remove_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}

@@ -223,7 +225,7 @@ class DeepResearcher:
# Step 6: Extract relevant information
think += "\n\n"
summary_think = ""
- for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
+ async for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
summary_think = ans
yield {"answer": think + self._remove_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
@@ -313,7 +313,7 @@ async def retrieval_test():
langs = req.get("cross_languages", [])
user_id = current_user.id

- def _retrieval_sync():
+ async def _retrieval():
local_doc_ids = list(doc_ids) if doc_ids else []
tenant_ids = []

@@ -323,7 +323,7 @@ async def retrieval_test():
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
- filters: dict = gen_meta_filter(chat_mdl, metas, question)
+ filters: dict = await gen_meta_filter(chat_mdl, metas, question)
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not local_doc_ids:
local_doc_ids = None

@@ -333,7 +333,7 @@ async def retrieval_test():
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
if filtered_metas:
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
- filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
+ filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not local_doc_ids:
local_doc_ids = None

@@ -347,7 +347,7 @@ async def retrieval_test():
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
- filters: dict = gen_meta_filter(chat_mdl, metas, question)
+ filters: dict = await gen_meta_filter(chat_mdl, metas, question)
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not local_doc_ids:
local_doc_ids = None

@@ -357,7 +357,7 @@ async def retrieval_test():
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
if filtered_metas:
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
- filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
+ filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not local_doc_ids:
local_doc_ids = None

@@ -384,7 +384,7 @@ async def retrieval_test():

_question = question
if langs:
- _question = cross_languages(kb.tenant_id, None, _question, langs)
+ _question = await cross_languages(kb.tenant_id, None, _question, langs)

embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

@@ -394,7 +394,7 @@ async def retrieval_test():

if req.get("keyword", False):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
- _question += keyword_extraction(chat_mdl, _question)
+ _question += await keyword_extraction(chat_mdl, _question)

labels = label_question(_question, [kb])
ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,

@@ -421,7 +421,7 @@ async def retrieval_test():
return get_json_result(data=ranks)

try:
- return await asyncio.to_thread(_retrieval_sync)
+ return await _retrieval()
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
@@ -1549,11 +1549,11 @@ async def retrieval_test(tenant_id):
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])

if langs:
- question = cross_languages(kb.tenant_id, None, question, langs)
+ question = await cross_languages(kb.tenant_id, None, question, langs)

if req.get("keyword", False):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
- question += keyword_extraction(chat_mdl, question)
+ question += await keyword_extraction(chat_mdl, question)

ranks = settings.retriever.retrieval(
question,
@@ -33,6 +33,7 @@ from api.utils.web_utils import CONTENT_TYPE_MAP
from common import settings
from common.constants import RetCode


@manager.route('/file/upload', methods=['POST']) # noqa: F821
@token_required
async def upload(tenant_id):
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
- import asyncio
import json
import re
import time

@@ -44,6 +43,7 @@ from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extr
from common.constants import RetCode, LLMType, StatusEnum
from common import settings


@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@token_required
async def create(tenant_id, chat_id):

@@ -969,7 +969,7 @@ async def retrieval_test_embedded():
if not tenant_id:
return get_error_data_result(message="permission denined.")

- def _retrieval_sync():
+ async def _retrieval():
local_doc_ids = list(doc_ids) if doc_ids else []
tenant_ids = []
_question = question

@@ -980,7 +980,7 @@ async def retrieval_test_embedded():
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
- filters: dict = gen_meta_filter(chat_mdl, metas, _question)
+ filters: dict = await gen_meta_filter(chat_mdl, metas, _question)
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not local_doc_ids:
local_doc_ids = None

@@ -990,7 +990,7 @@ async def retrieval_test_embedded():
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
if filtered_metas:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
- filters: dict = gen_meta_filter(chat_mdl, filtered_metas, _question)
+ filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, _question)
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not local_doc_ids:
local_doc_ids = None

@@ -1004,7 +1004,7 @@ async def retrieval_test_embedded():
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
- filters: dict = gen_meta_filter(chat_mdl, metas, question)
+ filters: dict = await gen_meta_filter(chat_mdl, metas, question)
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not local_doc_ids:
local_doc_ids = None

@@ -1014,7 +1014,7 @@ async def retrieval_test_embedded():
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
if filtered_metas:
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
- filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
+ filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not local_doc_ids:
local_doc_ids = None

@@ -1038,7 +1038,7 @@ async def retrieval_test_embedded():
return get_error_data_result(message="Knowledgebase not found!")

if langs:
- _question = cross_languages(kb.tenant_id, None, _question, langs)
+ _question = await cross_languages(kb.tenant_id, None, _question, langs)

embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)

@@ -1048,7 +1048,7 @@ async def retrieval_test_embedded():

if req.get("keyword", False):
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
- _question += keyword_extraction(chat_mdl, _question)
+ _question += await keyword_extraction(chat_mdl, _question)

labels = label_question(_question, [kb])
ranks = settings.retriever.retrieval(

@@ -1068,7 +1068,7 @@ async def retrieval_test_embedded():
return get_json_result(data=ranks)

try:
- return await asyncio.to_thread(_retrieval_sync)
+ return await _retrieval()
except Exception as e:
if str(e).find("not_found") > 0:
return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+ import asyncio
import logging
import json
import os

@@ -76,8 +77,7 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_
f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.")

chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
- msg = chat_mdl.chat(system="", history=[
- {"role": "user", "content": "Hello!"}], gen_conf={})
+ msg = asyncio.run(chat_mdl.async_chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}))
if msg.find("ERROR: ") == 0:
logging.error(
"'{}' doesn't work. {}".format(
@@ -397,7 +397,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
# try to use sql if field mapping is good to go
if field_map:
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
- ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
+ ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
if ans:
yield ans
return

@@ -411,17 +411,17 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")

if len(questions) > 1 and prompt_config.get("refine_multiturn"):
- questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
+ questions = [await full_question(dialog.tenant_id, dialog.llm_id, messages)]
else:
questions = questions[-1:]

if prompt_config.get("cross_languages"):
- questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
+ questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]

if dialog.meta_data_filter:
metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
if dialog.meta_data_filter.get("method") == "auto":
- filters: dict = gen_meta_filter(chat_mdl, metas, questions[-1])
+ filters: dict = await gen_meta_filter(chat_mdl, metas, questions[-1])
attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not attachments:
attachments = None

@@ -430,7 +430,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
if selected_keys:
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
if filtered_metas:
- filters: dict = gen_meta_filter(chat_mdl, filtered_metas, questions[-1])
+ filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, questions[-1])
attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not attachments:
attachments = None

@@ -441,7 +441,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
attachments = ["-999"]

if prompt_config.get("keyword", False):
- questions[-1] += keyword_extraction(chat_mdl, questions[-1])
+ questions[-1] += await keyword_extraction(chat_mdl, questions[-1])

refine_question_ts = timer()

@@ -469,7 +469,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
),
)

- for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
+ async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
if isinstance(think, str):
thought = think
knowledges = [t for t in think.split("\n") if t]

@@ -646,7 +646,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
return


- def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
+ async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
sys_prompt = """
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
Ensure that:

@@ -664,9 +664,9 @@ Please write the SQL, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
tried_times = 0

- def get_table():
+ async def get_table():
nonlocal sys_prompt, user_prompt, question, tried_times
- sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
+ sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
sql = re.sub(r"[\r\n]+", " ", sql.lower())

@@ -705,7 +705,7 @@ Please write the SQL, only SQL, without any other explanations or text.
return settings.retriever.sql_retrieval(sql, format="json"), sql

try:
- tbl, sql = get_table()
+ tbl, sql = await get_table()
except Exception as e:
user_prompt = """
Table name: {};

@@ -723,7 +723,7 @@ Please write the SQL, only SQL, without any other explanations or text.
Please correct the error and write SQL again, only SQL, without any other explanations or text.
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
try:
- tbl, sql = get_table()
+ tbl, sql = await get_table()
except Exception:
return

@@ -839,7 +839,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
if meta_data_filter:
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
- filters: dict = gen_meta_filter(chat_mdl, metas, question)
+ filters: dict = await gen_meta_filter(chat_mdl, metas, question)
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
doc_ids = None

@@ -848,7 +848,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
if selected_keys:
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
if filtered_metas:
- filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
+ filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
doc_ids = None

@@ -923,7 +923,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
if meta_data_filter:
metas = DocumentService.get_meta_by_kbs(kb_ids)
if meta_data_filter.get("method") == "auto":
- filters: dict = gen_meta_filter(chat_mdl, metas, question)
+ filters: dict = await gen_meta_filter(chat_mdl, metas, question)
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
doc_ids = None

@@ -932,7 +932,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
if selected_keys:
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
if filtered_metas:
- filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
+ filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
if not doc_ids:
doc_ids = None
@@ -318,9 +318,6 @@ class LLMBundle(LLM4Tenant):
return value
raise value

- def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
- return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs))

def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs):
result_queue: queue.Queue = queue.Queue()

@@ -350,23 +347,6 @@ class LLMBundle(LLM4Tenant):
raise item
yield item

- def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
- ans = ""
- for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs):
- if isinstance(txt, int):
- break

- if txt.endswith("</think>"):
- ans = txt[: -len("</think>")]
- continue

- if not self.verbose_tool_use:
- txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)

- # cancatination has beend done in async_chat_streamly
- ans = txt
- yield ans

def _bridge_sync_stream(self, gen):
loop = asyncio.get_running_loop()
queue: asyncio.Queue = asyncio.Queue()
@@ -98,7 +98,7 @@ class Extractor(ProcessBase, LLM):
args[chunks_key] = ck["text"]
msg, sys_prompt = self._sys_prompt_and_msg([], args)
msg.insert(0, {"role": "system", "content": sys_prompt})
- ck[self._param.field_name] = self._generate(msg)
+ ck[self._param.field_name] = await self._generate_async(msg)
prog += 1./len(chunks)
if i % (len(chunks)//100+1) == 1:
self.callback(prog, f"{i+1} / {len(chunks)}")

@@ -106,6 +106,6 @@ class Extractor(ProcessBase, LLM):
else:
msg, sys_prompt = self._sys_prompt_and_msg([], args)
msg.insert(0, {"role": "system", "content": sys_prompt})
- self.set_output("chunks", [{self._param.field_name: self._generate(msg)}])
+ self.set_output("chunks", [{self._param.field_name: await self._generate_async(msg)}])
@@ -33,6 +33,22 @@ class RagTokenizer(infinity.rag_tokenizer.RagTokenizer):
return super().fine_grained_tokenize(tks)


+ def is_chinese(s):
+ return infinity.rag_tokenizer.is_chinese(s)


+ def is_number(s):
+ return infinity.rag_tokenizer.is_number(s)


+ def is_alphabet(s):
+ return infinity.rag_tokenizer.is_alphabet(s)


+ def naive_qie(txt):
+ return infinity.rag_tokenizer.naive_qie(txt)


tokenizer = RagTokenizer()
tokenize = tokenizer.tokenize
fine_grained_tokenize = tokenizer.fine_grained_tokenize
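The hunk above is the tokenizer fix itself: module-level helpers `is_chinese`, `is_number`, `is_alphabet`, and `naive_qie` are re-exported as thin wrappers around `infinity.rag_tokenizer`, so callers that import them from this module keep working. A usage sketch under assumptions (the import path and the behavior noted in the comments are inferred, not shown in the diff):

```python
# Illustrative only: the module path below is an assumption.
from rag.nlp import rag_tokenizer

print(rag_tokenizer.tokenize("RAGFlow 检索增强生成"))  # delegates to the infinity tokenizer
print(rag_tokenizer.is_chinese("检"))   # expected True for a CJK character
print(rag_tokenizer.is_number("4"))     # expected True for a digit character
print(rag_tokenizer.is_alphabet("a"))   # expected True for a Latin letter
```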
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+ import asyncio
import json
import logging
import re

@@ -607,7 +608,7 @@ class Dealer:
if not toc:
return chunks

- ids = relevant_chunks_with_toc(query, toc, chat_mdl, topn*2)
+ ids = asyncio.run(relevant_chunks_with_toc(query, toc, chat_mdl, topn*2))
if not ids:
return chunks
@ -170,13 +170,13 @@ def citation_plus(sources: str) -> str:
|
|||||||
return template.render(example=citation_prompt(), sources=sources)
|
return template.render(example=citation_prompt(), sources=sources)
|
||||||
|
|
||||||
|
|
||||||
def keyword_extraction(chat_mdl, content, topn=3):
|
async def keyword_extraction(chat_mdl, content, topn=3):
|
||||||
template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE)
|
template = PROMPT_JINJA_ENV.from_string(KEYWORD_PROMPT_TEMPLATE)
|
||||||
rendered_prompt = template.render(content=content, topn=topn)
|
rendered_prompt = template.render(content=content, topn=topn)
|
||||||
|
|
||||||
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
||||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||||
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
|
kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.2})
|
||||||
if isinstance(kwd, tuple):
|
if isinstance(kwd, tuple):
|
||||||
kwd = kwd[0]
|
kwd = kwd[0]
|
||||||
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
||||||
@ -185,13 +185,13 @@ def keyword_extraction(chat_mdl, content, topn=3):
|
|||||||
return kwd
|
return kwd
|
||||||
|
|
||||||
|
|
||||||
def question_proposal(chat_mdl, content, topn=3):
|
async def question_proposal(chat_mdl, content, topn=3):
|
||||||
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
|
template = PROMPT_JINJA_ENV.from_string(QUESTION_PROMPT_TEMPLATE)
|
||||||
rendered_prompt = template.render(content=content, topn=topn)
|
rendered_prompt = template.render(content=content, topn=topn)
|
||||||
|
|
||||||
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
||||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||||
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.2})
|
kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.2})
|
||||||
if isinstance(kwd, tuple):
|
if isinstance(kwd, tuple):
|
||||||
kwd = kwd[0]
|
kwd = kwd[0]
|
||||||
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
||||||
@ -200,7 +200,7 @@ def question_proposal(chat_mdl, content, topn=3):
|
|||||||
return kwd
|
return kwd
|
||||||
|
|
||||||
|
|
||||||
def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
|
async def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_mdl=None):
|
||||||
from common.constants import LLMType
|
from common.constants import LLMType
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
@ -229,12 +229,12 @@ def full_question(tenant_id=None, llm_id=None, messages=[], language=None, chat_
|
|||||||
language=language,
|
language=language,
|
||||||
)
|
)
|
||||||
|
|
||||||
ans = chat_mdl.chat(rendered_prompt, [{"role": "user", "content": "Output: "}])
|
ans = await chat_mdl.async_chat(rendered_prompt, [{"role": "user", "content": "Output: "}])
|
||||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||||
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
|
return ans if ans.find("**ERROR**") < 0 else messages[-1]["content"]
|
||||||
|
|
||||||
|
|
||||||
def cross_languages(tenant_id, llm_id, query, languages=[]):
|
async def cross_languages(tenant_id, llm_id, query, languages=[]):
|
||||||
from common.constants import LLMType
|
from common.constants import LLMType
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.tenant_llm_service import TenantLLMService
|
from api.db.services.tenant_llm_service import TenantLLMService
|
||||||
@ -247,14 +247,14 @@ def cross_languages(tenant_id, llm_id, query, languages=[]):
|
|||||||
rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
|
rendered_sys_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_SYS_PROMPT_TEMPLATE).render()
|
||||||
rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)
|
rendered_user_prompt = PROMPT_JINJA_ENV.from_string(CROSS_LANGUAGES_USER_PROMPT_TEMPLATE).render(query=query, languages=languages)
|
||||||
|
|
||||||
ans = chat_mdl.chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
|
ans = await chat_mdl.async_chat(rendered_sys_prompt, [{"role": "user", "content": rendered_user_prompt}], {"temperature": 0.2})
|
||||||
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
|
||||||
if ans.find("**ERROR**") >= 0:
|
if ans.find("**ERROR**") >= 0:
|
||||||
return query
|
return query
|
||||||
return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()])
|
return "\n".join([a for a in re.sub(r"(^Output:|\n+)", "", ans, flags=re.DOTALL).split("===") if a.strip()])
|
||||||
|
|
||||||
|
|
||||||
def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
|
async def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
|
||||||
template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE)
|
template = PROMPT_JINJA_ENV.from_string(CONTENT_TAGGING_PROMPT_TEMPLATE)
|
||||||
|
|
||||||
for ex in examples:
|
for ex in examples:
|
||||||
@ -269,7 +269,7 @@ def content_tagging(chat_mdl, content, all_tags, examples, topn=3):
|
|||||||
|
|
||||||
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
msg = [{"role": "system", "content": rendered_prompt}, {"role": "user", "content": "Output: "}]
|
||||||
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
_, msg = message_fit_in(msg, chat_mdl.max_length)
|
||||||
kwd = chat_mdl.chat(rendered_prompt, msg[1:], {"temperature": 0.5})
|
kwd = await chat_mdl.async_chat(rendered_prompt, msg[1:], {"temperature": 0.5})
|
||||||
if isinstance(kwd, tuple):
|
if isinstance(kwd, tuple):
|
||||||
kwd = kwd[0]
|
kwd = kwd[0]
|
||||||
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
|
||||||
@@ -352,7 +352,7 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
     else:
         template = PROMPT_JINJA_ENV.from_string(ANALYZE_TASK_SYSTEM + "\n\n" + ANALYZE_TASK_USER)
         context = template.render(task=task_name, context=context, agent_prompt=prompt, tools_desc=tools_desc)
-    kwd = await _chat_async(chat_mdl, context, [{"role": "user", "content": "Please analyze it."}])
+    kwd = await chat_mdl.async_chat(context, [{"role": "user", "content": "Please analyze it."}])
     if isinstance(kwd, tuple):
         kwd = kwd[0]
     kwd = re.sub(r"^.*</think>", "", kwd, flags=re.DOTALL)
@@ -361,14 +361,6 @@ async def analyze_task_async(chat_mdl, prompt, task_name, tools_description: lis
     return kwd


-async def _chat_async(chat_mdl, system: str, history: list, **kwargs):
-    chat_async = getattr(chat_mdl, "async_chat", None)
-    if chat_async and asyncio.iscoroutinefunction(chat_async):
-        return await chat_async(system, history, **kwargs)
-    return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs)
-
-
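With the _chat_async shim deleted, callers in this module assume every model binding exposes an awaitable async_chat. For reference, the fallback being removed looked essentially like this (reproduced from the deleted lines above; purely illustrative, no longer part of the module):

import asyncio

async def chat_compat(chat_mdl, system, history, **kwargs):
    # Prefer a native coroutine if the binding provides one...
    chat_async = getattr(chat_mdl, "async_chat", None)
    if chat_async and asyncio.iscoroutinefunction(chat_async):
        return await chat_async(system, history, **kwargs)
    # ...otherwise push the blocking .chat call onto a worker thread.
    return await asyncio.to_thread(chat_mdl.chat, system, history, **kwargs)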
 async def next_step_async(chat_mdl, history:list, tools_description: list[dict], task_desc, user_defined_prompts: dict={}):
     if not tools_description:
         return "", 0
@@ -380,8 +372,7 @@ async def next_step_async(chat_mdl, history:list, tools_description: list[dict],
         hist[-1]["content"] += user_prompt
     else:
         hist.append({"role": "user", "content": user_prompt})
-    json_str = await _chat_async(
-        chat_mdl,
+    json_str = await chat_mdl.async_chat(
         template.render(task_analysis=task_desc, desc=desc, today=datetime.datetime.now().strftime("%Y-%m-%d")),
         hist[1:],
         stop=["<|stop|>"],
@@ -402,7 +393,7 @@ async def reflect_async(chat_mdl, history: list[dict], tool_call_res: list[Tuple
     else:
         hist.append({"role": "user", "content": user_prompt})
     _, msg = message_fit_in(hist, chat_mdl.max_length)
-    ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:])
+    ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
     ans = re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
     return """
 **Observation**
@@ -422,14 +413,14 @@ def structured_output_prompt(schema=None) -> str:
     return template.render(schema=schema)


-def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
+async def tool_call_summary(chat_mdl, name: str, params: dict, result: str, user_defined_prompts: dict={}) -> str:
     template = PROMPT_JINJA_ENV.from_string(SUMMARY4MEMORY)
     system_prompt = template.render(name=name,
                                     params=json.dumps(params, ensure_ascii=False, indent=2),
                                     result=result)
     user_prompt = "→ Summary: "
     _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
-    ans = chat_mdl.chat(msg[0]["content"], msg[1:])
+    ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:])
     return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)
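tool_call_summary and rank_memories_async keep the existing form_message / message_fit_in trimming and only change the final call to await chat_mdl.async_chat(...). A rough sketch of that flow with a crude stand-in for the fitting step (the real message_fit_in behaviour is more involved; the truncation here is only an assumption for illustration):

async def ask_once(chat_mdl, system_prompt: str, user_prompt: str):
    # Two-message payload, trimmed to the model's context budget, then awaited.
    msg = [{"role": "system", "content": system_prompt},
           {"role": "user", "content": user_prompt}]
    budget = getattr(chat_mdl, "max_length", 32768)
    msg[0]["content"] = msg[0]["content"][:budget]  # crude stand-in for message_fit_in
    return await chat_mdl.async_chat(msg[0]["content"], msg[1:])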
@@ -438,11 +429,11 @@ async def rank_memories_async(chat_mdl, goal:str, sub_goal:str, tool_call_summar
     system_prompt = template.render(goal=goal, sub_goal=sub_goal, results=[{"i": i, "content": s} for i,s in enumerate(tool_call_summaries)])
     user_prompt = " → rank: "
     _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
-    ans = await _chat_async(chat_mdl, msg[0]["content"], msg[1:], stop="<|stop|>")
+    ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:], stop="<|stop|>")
     return re.sub(r"^.*</think>", "", ans, flags=re.DOTALL)


-def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
+async def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
     meta_data_structure = {}
     for key, values in meta_data.items():
         meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values
@@ -453,7 +444,7 @@ def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
         user_question=query
     )
     user_prompt = "Generate filters:"
-    ans = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}])
+    ans = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}])
     ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
     try:
         ans = json_repair.loads(ans)
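Before parsing, gen_meta_filter strips any reasoning block and the markdown code fence the model may wrap around its JSON, then hands the remainder to json_repair. A small illustration with an invented model answer:

import re
import json_repair

raw_ans = "```json\n{\"conditions\": [{\"field\": \"author\", \"value\": \"Smith\"}]}\n```"
# Drop any reasoning block and the surrounding ```json fence before parsing.
cleaned = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", raw_ans, flags=re.DOTALL)
filters = json_repair.loads(cleaned)
print(filters["conditions"][0]["field"])  # -> "author"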
@@ -466,13 +457,13 @@ def gen_meta_filter(chat_mdl, meta_data:dict, query: str) -> dict:
     return {"conditions": []}


-def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
+async def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):
     from graphrag.utils import get_llm_cache, set_llm_cache
     cached = get_llm_cache(chat_mdl.llm_name, system_prompt, user_prompt, gen_conf)
     if cached:
         return json_repair.loads(cached)
     _, msg = message_fit_in(form_message(system_prompt, user_prompt), chat_mdl.max_length)
-    ans = chat_mdl.chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
+    ans = await chat_mdl.async_chat(msg[0]["content"], msg[1:],gen_conf=gen_conf)
     ans = re.sub(r"(^.*</think>|```json\n|```\n*$)", "", ans, flags=re.DOTALL)
     try:
         res = json_repair.loads(ans)
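gen_json is now a coroutine, so every TOC helper below has to await it; the get_llm_cache short-circuit still returns the parsed cache hit without calling the model. A minimal caller sketch, with gen_json passed in as a parameter to avoid guessing its import path, and a prompt invented for illustration:

async def extract_page_flags(gen_json, chat_mdl, page_text: str) -> bool:
    # `gen_json` is the now-async helper from this hunk; `chat_mdl` is any
    # binding that exposes async_chat. The prompt here is a placeholder and
    # assumes the model answers with an object carrying an "exists" flag.
    ans = await gen_json("Does this page contain a table of contents? "
                         "Answer as JSON: {\"exists\": true|false}.\n" + page_text,
                         "Only JSON please.", chat_mdl)
    return bool(ans.get("exists", False))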
@@ -483,10 +474,10 @@ def gen_json(system_prompt:str, user_prompt:str, chat_mdl, gen_conf = None):


 TOC_DETECTION = load_prompt("toc_detection")
-def detect_table_of_contents(page_1024:list[str], chat_mdl):
+async def detect_table_of_contents(page_1024:list[str], chat_mdl):
     toc_secs = []
     for i, sec in enumerate(page_1024[:22]):
-        ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
+        ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_DETECTION).render(page_txt=sec), "Only JSON please.", chat_mdl)
         if toc_secs and not ans["exists"]:
             break
         toc_secs.append(sec)
@@ -495,14 +486,14 @@ def detect_table_of_contents(page_1024:list[str], chat_mdl):


 TOC_EXTRACTION = load_prompt("toc_extraction")
 TOC_EXTRACTION_CONTINUE = load_prompt("toc_extraction_continue")
-def extract_table_of_contents(toc_pages, chat_mdl):
+async def extract_table_of_contents(toc_pages, chat_mdl):
     if not toc_pages:
         return []

-    return gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)
+    return await gen_json(PROMPT_JINJA_ENV.from_string(TOC_EXTRACTION).render(toc_page="\n".join(toc_pages)), "Only JSON please.", chat_mdl)


-def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
+async def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
     tob_extractor_prompt = """
     You are given a table of contents in a json format and several pages of a document, your job is to add the physical_index to the table of contents in the json format.
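Both TOC helpers are now coroutines, so a caller that used to chain them synchronously has to await each step. A sketch with the two functions passed in as parameters (so the sketch does not have to guess their import path):

async def build_toc(pages, chat_mdl, detect, extract):
    # `detect` / `extract` are the coroutines changed above
    # (detect_table_of_contents / extract_table_of_contents).
    toc_pages = await detect(pages, chat_mdl)
    if not toc_pages:
        return []
    return await extract(toc_pages, chat_mdl)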
@@ -525,11 +516,11 @@ def toc_index_extractor(toc:list[dict], content:str, chat_mdl):
     Directly return the final JSON structure. Do not output anything else."""

     prompt = tob_extractor_prompt + '\nTable of contents:\n' + json.dumps(toc, ensure_ascii=False, indent=2) + '\nDocument pages:\n' + content
-    return gen_json(prompt, "Only JSON please.", chat_mdl)
+    return await gen_json(prompt, "Only JSON please.", chat_mdl)


 TOC_INDEX = load_prompt("toc_index")
-def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
+async def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
     if not toc_arr or not sections:
         return []
@@ -601,7 +592,7 @@ def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
             e = toc_arr[e]["indices"][0]

         for j in range(st_i, min(e+1, len(sections))):
-            ans = gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
+            ans = await gen_json(PROMPT_JINJA_ENV.from_string(TOC_INDEX).render(
                 structure=it["structure"],
                 title=it["title"],
                 text=sections[j]), "Only JSON please.", chat_mdl)
@@ -614,7 +605,7 @@ def table_of_contents_index(toc_arr: list[dict], sections: list[str], chat_mdl):
     return toc_arr


-def check_if_toc_transformation_is_complete(content, toc, chat_mdl):
+async def check_if_toc_transformation_is_complete(content, toc, chat_mdl):
     prompt = """
     You are given a raw table of contents and a table of contents.
     Your job is to check if the table of contents is complete.
@@ -627,11 +618,11 @@ def check_if_toc_transformation_is_complete(content, toc, chat_mdl):
     Directly return the final JSON structure. Do not output anything else."""

     prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc
-    response = gen_json(prompt, "Only JSON please.", chat_mdl)
+    response = await gen_json(prompt, "Only JSON please.", chat_mdl)
     return response['completed']


-def toc_transformer(toc_pages, chat_mdl):
+async def toc_transformer(toc_pages, chat_mdl):
     init_prompt = """
     You are given a table of contents, You job is to transform the whole table of content into a JSON format included table_of_contents.
@@ -654,8 +645,8 @@ def toc_transformer(toc_pages, chat_mdl):
     def clean_toc(arr):
         for a in arr:
             a["title"] = re.sub(r"[.·….]{2,}", "", a["title"])
-    last_complete = gen_json(prompt, "Only JSON please.", chat_mdl)
-    if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
+    last_complete = await gen_json(prompt, "Only JSON please.", chat_mdl)
+    if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
     clean_toc(last_complete)
     if if_complete == "yes":
         return last_complete
@@ -672,21 +663,21 @@ def toc_transformer(toc_pages, chat_mdl):
     {json.dumps(last_complete[-24:], ensure_ascii=False, indent=2)}

     Please continue the json structure, directly output the remaining part of the json structure."""
-        new_complete = gen_json(prompt, "Only JSON please.", chat_mdl)
+        new_complete = await gen_json(prompt, "Only JSON please.", chat_mdl)
         if not new_complete or str(last_complete).find(str(new_complete)) >= 0:
             break
         clean_toc(new_complete)
         last_complete.extend(new_complete)
-        if_complete = check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)
+        if_complete = await check_if_toc_transformation_is_complete(toc_content, json.dumps(last_complete, ensure_ascii=False, indent=2), chat_mdl)

     return last_complete


 TOC_LEVELS = load_prompt("assign_toc_levels")
-def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
+async def assign_toc_levels(toc_secs, chat_mdl, gen_conf = {"temperature": 0.2}):
     if not toc_secs:
         return []
-    return gen_json(
+    return await gen_json(
         PROMPT_JINJA_ENV.from_string(TOC_LEVELS).render(),
         str(toc_secs),
         chat_mdl,
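The continuation loop in toc_transformer keeps asking the model for the remaining JSON until the completeness check answers "yes" or no new entries arrive, and every round is now awaited. Stripped of the TOC-specific prompts, the control flow is roughly the following generic sketch; `ask` and `is_complete` are caller-supplied coroutines, and the max_rounds bound is an added safeguard that is not in the patch:

async def complete_json_list(ask, is_complete, max_rounds: int = 8):
    # `ask(tail)` returns the next slice of items (tail is the last entries
    # already collected, or None for the first request); `is_complete(items)`
    # answers "yes"/"no" for the accumulated list.
    items = await ask(None)
    for _ in range(max_rounds):
        if await is_complete(items) == "yes":
            break
        more = await ask(items[-24:])
        # Stop when nothing new comes back or the model repeats itself.
        if not more or str(items).find(str(more)) >= 0:
            break
        items.extend(more)
    return items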
@@ -699,7 +690,7 @@ TOC_FROM_TEXT_USER = load_prompt("toc_from_text_user")
 # Generate TOC from text chunks with text llms
 async def gen_toc_from_text(txt_info: dict, chat_mdl, callback=None):
     try:
-        ans = gen_json(
+        ans = await gen_json(
             PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_SYSTEM).render(),
             PROMPT_JINJA_ENV.from_string(TOC_FROM_TEXT_USER).render(text="\n".join([json.dumps(d, ensure_ascii=False) for d in txt_info["chunks"]])),
             chat_mdl,
@@ -782,7 +773,7 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):
     raw_structure = [x.get("title", "") for x in filtered]

     # Assign hierarchy levels using LLM
-    toc_with_levels = assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9})
+    toc_with_levels = await assign_toc_levels(raw_structure, chat_mdl, {"temperature": 0.0, "top_p": 0.9})
     if not toc_with_levels:
         return []
@@ -807,10 +798,10 @@ async def run_toc_from_text(chunks, chat_mdl, callback=None):

 TOC_RELEVANCE_SYSTEM = load_prompt("toc_relevance_system")
 TOC_RELEVANCE_USER = load_prompt("toc_relevance_user")
-def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
+async def relevant_chunks_with_toc(query: str, toc:list[dict], chat_mdl, topn: int=6):
     import numpy as np
     try:
-        ans = gen_json(
+        ans = await gen_json(
             PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_SYSTEM).render(),
             PROMPT_JINJA_ENV.from_string(TOC_RELEVANCE_USER).render(query=query, toc_json="[\n%s\n]\n"%"\n".join([json.dumps({"level": d["level"], "title":d["title"]}, ensure_ascii=False) for d in toc])),
             chat_mdl,
@@ -323,12 +323,7 @@ async def build_chunks(task, progress_callback):
         cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "keywords", {"topn": topn})
         if not cached:
             async with chat_limiter:
-                cached = await asyncio.to_thread(
-                    keyword_extraction,
-                    chat_mdl,
-                    d["content_with_weight"],
-                    topn,
-                )
+                cached = await keyword_extraction(chat_mdl, d["content_with_weight"], topn)
             set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn})
         if cached:
             d["important_kwd"] = cached.split(",")
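In build_chunks the asyncio.to_thread wrapper is no longer needed: the prompt helpers are coroutines themselves, so they are awaited directly while the chat_limiter is held, and the LLM cache still short-circuits repeat calls. A sketch of that shape; get_llm_cache, set_llm_cache, chat_limiter and keyword_extraction are the project's own helpers and are referenced here without their imports:

async def keywords_for_chunk(chat_mdl, chunk_text: str, topn: int = 5):
    # Check the cache first, otherwise hold the concurrency limiter and await
    # the now-async keyword extraction helper directly.
    cached = get_llm_cache(chat_mdl.llm_name, chunk_text, "keywords", {"topn": topn})
    if not cached:
        async with chat_limiter:
            cached = await keyword_extraction(chat_mdl, chunk_text, topn)
        set_llm_cache(chat_mdl.llm_name, chunk_text, cached, "keywords", {"topn": topn})
    return cached.split(",") if cached else []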
@@ -356,12 +351,7 @@ async def build_chunks(task, progress_callback):
         cached = get_llm_cache(chat_mdl.llm_name, d["content_with_weight"], "question", {"topn": topn})
         if not cached:
             async with chat_limiter:
-                cached = await asyncio.to_thread(
-                    question_proposal,
-                    chat_mdl,
-                    d["content_with_weight"],
-                    topn,
-                )
+                cached = await question_proposal(chat_mdl, d["content_with_weight"], topn)
             set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "question", {"topn": topn})
         if cached:
             d["question_kwd"] = cached.split("\n")
@@ -414,8 +404,7 @@ async def build_chunks(task, progress_callback):
         if not picked_examples:
             picked_examples.append({"content": "This is an example", TAG_FLD: {'example': 1}})
         async with chat_limiter:
-            cached = await asyncio.to_thread(
-                content_tagging,
+            cached = await content_tagging(
                 chat_mdl,
                 d["content_with_weight"],
                 all_tags,