mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-30 00:32:30 +08:00
Fix: tokenizer issue. (#11902)
#11786 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -313,7 +313,7 @@ async def retrieval_test():
|
||||
langs = req.get("cross_languages", [])
|
||||
user_id = current_user.id
|
||||
|
||||
def _retrieval_sync():
|
||||
async def _retrieval():
|
||||
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||
tenant_ids = []
|
||||
|
||||
@ -323,7 +323,7 @@ async def retrieval_test():
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
@ -333,7 +333,7 @@ async def retrieval_test():
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
@ -347,7 +347,7 @@ async def retrieval_test():
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
@ -357,7 +357,7 @@ async def retrieval_test():
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
chat_mdl = LLMBundle(user_id, LLMType.CHAT)
|
||||
filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
@ -384,7 +384,7 @@ async def retrieval_test():
|
||||
|
||||
_question = question
|
||||
if langs:
|
||||
_question = cross_languages(kb.tenant_id, None, _question, langs)
|
||||
_question = await cross_languages(kb.tenant_id, None, _question, langs)
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
|
||||
@ -394,7 +394,7 @@ async def retrieval_test():
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
_question += keyword_extraction(chat_mdl, _question)
|
||||
_question += await keyword_extraction(chat_mdl, _question)
|
||||
|
||||
labels = label_question(_question, [kb])
|
||||
ranks = settings.retriever.retrieval(_question, embd_mdl, tenant_ids, kb_ids, page, size,
|
||||
@ -421,7 +421,7 @@ async def retrieval_test():
|
||||
return get_json_result(data=ranks)
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_retrieval_sync)
|
||||
return await _retrieval()
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
||||
|
||||
@ -1549,11 +1549,11 @@ async def retrieval_test(tenant_id):
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, LLMType.RERANK, llm_name=req["rerank_id"])
|
||||
|
||||
if langs:
|
||||
question = cross_languages(kb.tenant_id, None, question, langs)
|
||||
question = await cross_languages(kb.tenant_id, None, question, langs)
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
question += keyword_extraction(chat_mdl, question)
|
||||
question += await keyword_extraction(chat_mdl, question)
|
||||
|
||||
ranks = settings.retriever.retrieval(
|
||||
question,
|
||||
|
||||
@ -33,6 +33,7 @@ from api.utils.web_utils import CONTENT_TYPE_MAP
|
||||
from common import settings
|
||||
from common.constants import RetCode
|
||||
|
||||
|
||||
@manager.route('/file/upload', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
async def upload(tenant_id):
|
||||
|
||||
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
@ -44,6 +43,7 @@ from rag.prompts.generator import cross_languages, gen_meta_filter, keyword_extr
|
||||
from common.constants import RetCode, LLMType, StatusEnum
|
||||
from common import settings
|
||||
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def create(tenant_id, chat_id):
|
||||
@ -969,7 +969,7 @@ async def retrieval_test_embedded():
|
||||
if not tenant_id:
|
||||
return get_error_data_result(message="permission denined.")
|
||||
|
||||
def _retrieval_sync():
|
||||
async def _retrieval():
|
||||
local_doc_ids = list(doc_ids) if doc_ids else []
|
||||
tenant_ids = []
|
||||
_question = question
|
||||
@ -980,7 +980,7 @@ async def retrieval_test_embedded():
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, _question)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, _question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
@ -990,7 +990,7 @@ async def retrieval_test_embedded():
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||
filters: dict = gen_meta_filter(chat_mdl, filtered_metas, _question)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, _question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
@ -1004,7 +1004,7 @@ async def retrieval_test_embedded():
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
@ -1014,7 +1014,7 @@ async def retrieval_test_embedded():
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
|
||||
filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
local_doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not local_doc_ids:
|
||||
local_doc_ids = None
|
||||
@ -1038,7 +1038,7 @@ async def retrieval_test_embedded():
|
||||
return get_error_data_result(message="Knowledgebase not found!")
|
||||
|
||||
if langs:
|
||||
_question = cross_languages(kb.tenant_id, None, _question, langs)
|
||||
_question = await cross_languages(kb.tenant_id, None, _question, langs)
|
||||
|
||||
embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id)
|
||||
|
||||
@ -1048,7 +1048,7 @@ async def retrieval_test_embedded():
|
||||
|
||||
if req.get("keyword", False):
|
||||
chat_mdl = LLMBundle(kb.tenant_id, LLMType.CHAT)
|
||||
_question += keyword_extraction(chat_mdl, _question)
|
||||
_question += await keyword_extraction(chat_mdl, _question)
|
||||
|
||||
labels = label_question(_question, [kb])
|
||||
ranks = settings.retriever.retrieval(
|
||||
@ -1068,7 +1068,7 @@ async def retrieval_test_embedded():
|
||||
return get_json_result(data=ranks)
|
||||
|
||||
try:
|
||||
return await asyncio.to_thread(_retrieval_sync)
|
||||
return await _retrieval()
|
||||
except Exception as e:
|
||||
if str(e).find("not_found") > 0:
|
||||
return get_json_result(data=False, message="No chunk found! Check the chunk status please!",
|
||||
|
||||
@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
@ -76,8 +77,7 @@ def init_superuser(nickname=DEFAULT_SUPERUSER_NICKNAME, email=DEFAULT_SUPERUSER_
|
||||
f"Super user initialized. email: {email},A default password has been set; changing the password after login is strongly recommended.")
|
||||
|
||||
chat_mdl = LLMBundle(tenant["id"], LLMType.CHAT, tenant["llm_id"])
|
||||
msg = chat_mdl.chat(system="", history=[
|
||||
{"role": "user", "content": "Hello!"}], gen_conf={})
|
||||
msg = asyncio.run(chat_mdl.async_chat(system="", history=[{"role": "user", "content": "Hello!"}], gen_conf={}))
|
||||
if msg.find("ERROR: ") == 0:
|
||||
logging.error(
|
||||
"'{}' doesn't work. {}".format(
|
||||
|
||||
@ -397,7 +397,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
# try to use sql if field mapping is good to go
|
||||
if field_map:
|
||||
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
|
||||
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
||||
ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids)
|
||||
if ans:
|
||||
yield ans
|
||||
return
|
||||
@ -411,17 +411,17 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ")
|
||||
|
||||
if len(questions) > 1 and prompt_config.get("refine_multiturn"):
|
||||
questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
|
||||
questions = [await full_question(dialog.tenant_id, dialog.llm_id, messages)]
|
||||
else:
|
||||
questions = questions[-1:]
|
||||
|
||||
if prompt_config.get("cross_languages"):
|
||||
questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
|
||||
questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])]
|
||||
|
||||
if dialog.meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(dialog.kb_ids)
|
||||
if dialog.meta_data_filter.get("method") == "auto":
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, questions[-1])
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, questions[-1])
|
||||
attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not attachments:
|
||||
attachments = None
|
||||
@ -430,7 +430,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
if selected_keys:
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
filters: dict = gen_meta_filter(chat_mdl, filtered_metas, questions[-1])
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, questions[-1])
|
||||
attachments.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not attachments:
|
||||
attachments = None
|
||||
@ -441,7 +441,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
attachments = ["-999"]
|
||||
|
||||
if prompt_config.get("keyword", False):
|
||||
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
||||
questions[-1] += await keyword_extraction(chat_mdl, questions[-1])
|
||||
|
||||
refine_question_ts = timer()
|
||||
|
||||
@ -469,7 +469,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
),
|
||||
)
|
||||
|
||||
for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
|
||||
async for think in reasoner.thinking(kbinfos, attachments_ + " ".join(questions)):
|
||||
if isinstance(think, str):
|
||||
thought = think
|
||||
knowledges = [t for t in think.split("\n") if t]
|
||||
@ -646,7 +646,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
|
||||
return
|
||||
|
||||
|
||||
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
||||
async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None):
|
||||
sys_prompt = """
|
||||
You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question.
|
||||
Ensure that:
|
||||
@ -664,9 +664,9 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question)
|
||||
tried_times = 0
|
||||
|
||||
def get_table():
|
||||
async def get_table():
|
||||
nonlocal sys_prompt, user_prompt, question, tried_times
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
|
||||
sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06})
|
||||
sql = re.sub(r"^.*</think>", "", sql, flags=re.DOTALL)
|
||||
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
|
||||
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
||||
@ -705,7 +705,7 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
return settings.retriever.sql_retrieval(sql, format="json"), sql
|
||||
|
||||
try:
|
||||
tbl, sql = get_table()
|
||||
tbl, sql = await get_table()
|
||||
except Exception as e:
|
||||
user_prompt = """
|
||||
Table name: {};
|
||||
@ -723,7 +723,7 @@ Please write the SQL, only SQL, without any other explanations or text.
|
||||
Please correct the error and write SQL again, only SQL, without any other explanations or text.
|
||||
""".format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, e)
|
||||
try:
|
||||
tbl, sql = get_table()
|
||||
tbl, sql = await get_table()
|
||||
except Exception:
|
||||
return
|
||||
|
||||
@ -839,7 +839,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
@ -848,7 +848,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
|
||||
if selected_keys:
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
@ -923,7 +923,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
if meta_data_filter:
|
||||
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||
if meta_data_filter.get("method") == "auto":
|
||||
filters: dict = gen_meta_filter(chat_mdl, metas, question)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
@ -932,7 +932,7 @@ async def gen_mindmap(question, kb_ids, tenant_id, search_config={}):
|
||||
if selected_keys:
|
||||
filtered_metas = {key: metas[key] for key in selected_keys if key in metas}
|
||||
if filtered_metas:
|
||||
filters: dict = gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question)
|
||||
doc_ids.extend(meta_filter(metas, filters["conditions"], filters.get("logic", "and")))
|
||||
if not doc_ids:
|
||||
doc_ids = None
|
||||
|
||||
@ -318,9 +318,6 @@ class LLMBundle(LLM4Tenant):
|
||||
return value
|
||||
raise value
|
||||
|
||||
def chat(self, system: str, history: list, gen_conf: dict = {}, **kwargs) -> str:
|
||||
return self._run_coroutine_sync(self.async_chat(system, history, gen_conf, **kwargs))
|
||||
|
||||
def _sync_from_async_stream(self, async_gen_fn, *args, **kwargs):
|
||||
result_queue: queue.Queue = queue.Queue()
|
||||
|
||||
@ -350,23 +347,6 @@ class LLMBundle(LLM4Tenant):
|
||||
raise item
|
||||
yield item
|
||||
|
||||
def chat_streamly(self, system: str, history: list, gen_conf: dict = {}, **kwargs):
|
||||
ans = ""
|
||||
for txt in self._sync_from_async_stream(self.async_chat_streamly, system, history, gen_conf, **kwargs):
|
||||
if isinstance(txt, int):
|
||||
break
|
||||
|
||||
if txt.endswith("</think>"):
|
||||
ans = txt[: -len("</think>")]
|
||||
continue
|
||||
|
||||
if not self.verbose_tool_use:
|
||||
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
|
||||
|
||||
# cancatination has beend done in async_chat_streamly
|
||||
ans = txt
|
||||
yield ans
|
||||
|
||||
def _bridge_sync_stream(self, gen):
|
||||
loop = asyncio.get_running_loop()
|
||||
queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
Reference in New Issue
Block a user