# # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import asyncio import binascii import logging import re import time from copy import deepcopy from datetime import datetime from functools import partial from timeit import default_timer as timer from langfuse import Langfuse from peewee import fn from api.db.services.file_service import FileService from common.constants import LLMType, ParserType, StatusEnum from api.db.db_models import DB, Dialog from api.db.services.common_service import CommonService from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.langfuse_service import TenantLangfuseService from api.db.services.llm_service import LLMBundle from common.metadata_utils import apply_meta_data_filter from api.db.services.tenant_llm_service import TenantLLMService from common.time_utils import current_timestamp, datetime_format from graphrag.general.mind_map_extractor import MindMapExtractor from rag.advanced_rag import DeepResearcher from rag.app.tag import label_question from rag.nlp.search import index_name from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \ PROMPT_JINJA_ENV, ASK_SUMMARY from common.token_utils import num_tokens_from_string from rag.utils.tavily_conn import Tavily from common.string_utils import remove_redundant_spaces from common import settings class DialogService(CommonService): model = Dialog @classmethod def save(cls, **kwargs): """Save a new record to database. This method creates a new record in the database with the provided field values, forcing an insert operation rather than an update. Args: **kwargs: Record field values as keyword arguments. Returns: Model instance: The created record object. """ sample_obj = cls.model(**kwargs).save(force_insert=True) return sample_obj @classmethod def update_many_by_id(cls, data_list): """Update multiple records by their IDs. This method updates multiple records in the database, identified by their IDs. It automatically updates the update_time and update_date fields for each record. Args: data_list (list): List of dictionaries containing record data to update. Each dictionary must include an 'id' field. """ with DB.atomic(): for data in data_list: data["update_time"] = current_timestamp() data["update_date"] = datetime_format(datetime.now()) cls.model.update(data).where(cls.model.id == data["id"]).execute() @classmethod @DB.connection_context() def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name): chats = cls.model.select() if id: chats = chats.where(cls.model.id == id) if name: chats = chats.where(cls.model.name == name) chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value)) if desc: chats = chats.order_by(cls.model.getter_by(orderby).desc()) else: chats = chats.order_by(cls.model.getter_by(orderby).asc()) chats = chats.paginate(page_number, items_per_page) return list(chats.dicts()) @classmethod @DB.connection_context() def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords, parser_id=None): from api.db.db_models import User fields = [ cls.model.id, cls.model.tenant_id, cls.model.name, cls.model.description, cls.model.language, cls.model.llm_id, cls.model.llm_setting, cls.model.prompt_type, cls.model.prompt_config, cls.model.similarity_threshold, cls.model.vector_similarity_weight, cls.model.top_n, cls.model.top_k, cls.model.do_refer, cls.model.rerank_id, cls.model.kb_ids, cls.model.icon, cls.model.status, User.nickname, User.avatar.alias("tenant_avatar"), cls.model.update_time, cls.model.create_time, ] if keywords: dialogs = ( cls.model.select(*fields) .join(User, on=(cls.model.tenant_id == User.id)) .where( (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value), (fn.LOWER(cls.model.name).contains(keywords.lower())), ) ) else: dialogs = ( cls.model.select(*fields) .join(User, on=(cls.model.tenant_id == User.id)) .where( (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value), ) ) if parser_id: dialogs = dialogs.where(cls.model.parser_id == parser_id) if desc: dialogs = dialogs.order_by(cls.model.getter_by(orderby).desc()) else: dialogs = dialogs.order_by(cls.model.getter_by(orderby).asc()) count = dialogs.count() if page_number and items_per_page: dialogs = dialogs.paginate(page_number, items_per_page) return list(dialogs.dicts()), count @classmethod @DB.connection_context() def get_all_dialogs_by_tenant_id(cls, tenant_id): fields = [cls.model.id] dialogs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id) dialogs.order_by(cls.model.create_time.asc()) offset, limit = 0, 100 res = [] while True: d_batch = dialogs.offset(offset).limit(limit) _temp = list(d_batch.dicts()) if not _temp: break res.extend(_temp) offset += limit return res async def async_chat_solo(dialog, messages, stream=True): attachments = "" if "files" in messages[-1]: attachments = "\n\n".join(FileService.get_files(messages[-1]["files"])) if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text": chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) else: chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) prompt_config = dialog.prompt_config tts_mdl = None if prompt_config.get("tts"): tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS) msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"] if attachments and msg: msg[-1]["content"] += attachments if stream: stream_iter = chat_mdl.async_chat_streamly_delta(prompt_config.get("system", ""), msg, dialog.llm_setting) async for kind, value, state in _stream_with_think_delta(stream_iter): if kind == "marker": flags = {"start_to_think": True} if value == "" else {"end_to_think": True} yield {"answer": "", "reference": {}, "audio_binary": None, "prompt": "", "created_at": time.time(), "final": False, **flags} continue yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "prompt": "", "created_at": time.time(), "final": False} else: answer = await chat_mdl.async_chat(prompt_config.get("system", ""), msg, dialog.llm_setting) user_content = msg[-1].get("content", "[content not available]") logging.debug("User: {}|Assistant: {}".format(user_content, answer)) yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()} def get_models(dialog): embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids) embedding_list = list(set([kb.embd_id for kb in kbs])) if len(embedding_list) > 1: raise Exception("**ERROR**: Knowledge bases use different embedding models.") if embedding_list: embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0]) if not embd_mdl: raise LookupError("Embedding model(%s) not found" % embedding_list[0]) if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text": chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) else: chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) if dialog.rerank_id: rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id) if dialog.prompt_config.get("tts"): tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS) return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl BAD_CITATION_PATTERNS = [ re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"), # (ID: 12) re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"), # [ID: 12] re.compile(r"【\s*ID\s*[: ]*\s*(\d+)\s*】"), # 【ID: 12】 re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12 ] def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): max_index = len(kbinfos["chunks"]) def safe_add(i): if 0 <= i < max_index: idx.add(i) return True return False def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0): nonlocal answer def replacement(match): try: i = int(match.group(group_index)) if safe_add(i): return f"[{repl(i)}]" except Exception: pass return match.group(0) answer = re.sub(pattern, replacement, answer, flags=flags) for pattern in BAD_CITATION_PATTERNS: find_and_replace(pattern) return answer, idx async def async_chat(dialog, messages, stream=True, **kwargs): logging.debug("Begin async_chat") assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"): async for ans in async_chat_solo(dialog, messages, stream): yield ans return chat_start_ts = timer() if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text": llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) else: llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) max_tokens = llm_model_config.get("max_tokens", 8192) check_llm_ts = timer() langfuse_tracer = None trace_context = {} langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id) if langfuse_keys: langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host) try: if langfuse.auth_check(): langfuse_tracer = langfuse trace_id = langfuse_tracer.create_trace_id() trace_context = {"trace_id": trace_id} except Exception: # Skip langfuse tracing if connection fails pass check_langfuse_tracer_ts = timer() kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog) toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools") if toolcall_session and tools: chat_mdl.bind_tools(toolcall_session, tools) bind_models_ts = timer() retriever = settings.retriever questions = [m["content"] for m in messages if m["role"] == "user"][-3:] attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else [] attachments_= "" if "doc_ids" in messages[-1]: attachments = messages[-1]["doc_ids"] if "files" in messages[-1]: attachments_ = "\n\n".join(FileService.get_files(messages[-1]["files"])) prompt_config = dialog.prompt_config field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) logging.debug(f"field_map retrieved: {field_map}") # try to use sql if field mapping is good to go if field_map: logging.debug("Use SQL to retrieval:{}".format(questions[-1])) ans = await use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) # For aggregate queries (COUNT, SUM, etc.), chunks may be empty but answer is still valid if ans and (ans.get("reference", {}).get("chunks") or ans.get("answer")): yield ans return else: logging.debug("SQL failed or returned no results, falling back to vector search") param_keys = [p["key"] for p in prompt_config.get("parameters", [])] logging.debug(f"attachments={attachments}, param_keys={param_keys}, embd_mdl={embd_mdl}") for p in prompt_config["parameters"]: if p["key"] == "knowledge": continue if p["key"] not in kwargs and not p["optional"]: raise KeyError("Miss parameter: " + p["key"]) if p["key"] not in kwargs: prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ") if len(questions) > 1 and prompt_config.get("refine_multiturn"): questions = [await full_question(dialog.tenant_id, dialog.llm_id, messages)] else: questions = questions[-1:] if prompt_config.get("cross_languages"): questions = [await cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])] if dialog.meta_data_filter: metas = DocumentService.get_meta_by_kbs(dialog.kb_ids) attachments = await apply_meta_data_filter( dialog.meta_data_filter, metas, questions[-1], chat_mdl, attachments, ) if prompt_config.get("keyword", False): questions[-1] += await keyword_extraction(chat_mdl, questions[-1]) refine_question_ts = timer() thought = "" kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} knowledges = [] if attachments is not None and "knowledge" in param_keys: logging.debug("Proceeding with retrieval") tenant_ids = list(set([kb.tenant_id for kb in kbs])) knowledges = [] if prompt_config.get("reasoning", False) or kwargs.get("reasoning"): reasoner = DeepResearcher( chat_mdl, prompt_config, partial( retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3, doc_ids=attachments, ), ) queue = asyncio.Queue() async def callback(msg:str): nonlocal queue await queue.put(msg + "
") await callback("") task = asyncio.create_task(reasoner.research(kbinfos, questions[-1], questions[-1], callback=callback)) while True: msg = await queue.get() if msg.find("") == 0: yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "start_to_think": True} elif msg.find("") == 0: yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, "end_to_think": True} break else: yield {"answer": msg, "reference": {}, "audio_binary": None, "final": False} await task else: if embd_mdl: kbinfos = await retriever.retrieval( " ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold, dialog.vector_similarity_weight, doc_ids=attachments, top=dialog.top_k, aggs=True, rerank_mdl=rerank_mdl, rank_feature=label_question(" ".join(questions), kbs), ) if prompt_config.get("toc_enhance"): cks = await retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n) if cks: kbinfos["chunks"] = cks kbinfos["chunks"] = retriever.retrieval_by_children(kbinfos["chunks"], tenant_ids) if prompt_config.get("tavily_api_key"): tav = Tavily(prompt_config["tavily_api_key"]) tav_res = tav.retrieve_chunks(" ".join(questions)) kbinfos["chunks"].extend(tav_res["chunks"]) kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) if prompt_config.get("use_kg"): ck = await settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: kbinfos["chunks"].insert(0, ck) knowledges = kb_prompt(kbinfos, max_tokens) logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges))) retrieval_ts = timer() if not knowledges and prompt_config.get("empty_response"): empty_res = prompt_config["empty_response"] yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res), "final": True} return kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges) gen_conf = dialog.llm_setting msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)+attachments_}] prompt4citation = "" if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): prompt4citation = citation_prompt() msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]) used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95)) assert len(msg) >= 2, f"message_fit_in has bug: {msg}" prompt = msg[0]["content"] if "max_tokens" in gen_conf: gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count) def decorate_answer(answer): nonlocal embd_mdl, prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer refs = [] ans = answer.split("
") think = "" if len(ans) == 2: think = ans[0] + "" answer = ans[1] if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): idx = set([]) if embd_mdl and not re.search(r"\[ID:([0-9]+)\]", answer): answer, idx = retriever.insert_citations( answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=1 - dialog.vector_similarity_weight, vtweight=dialog.vector_similarity_weight, ) else: for match in re.finditer(r"\[ID:([0-9]+)\]", answer): i = int(match.group(1)) if i < len(kbinfos["chunks"]): idx.add(i) answer, idx = repair_bad_citation_formats(answer, kbinfos, idx) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] if not recall_docs: recall_docs = kbinfos["doc_aggs"] kbinfos["doc_aggs"] = recall_docs refs = deepcopy(kbinfos) for c in refs["chunks"]: if c.get("vector"): del c["vector"] if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'" finish_chat_ts = timer() total_time_cost = (finish_chat_ts - chat_start_ts) * 1000 check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000 check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000 bind_embedding_time_cost = (bind_models_ts - check_langfuse_tracer_ts) * 1000 refine_question_time_cost = (refine_question_ts - bind_models_ts) * 1000 retrieval_time_cost = (retrieval_ts - refine_question_ts) * 1000 generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000 tk_num = num_tokens_from_string(think + answer) prompt += "\n\n### Query:\n%s" % " ".join(questions) prompt = ( f"{prompt}\n\n" "## Time elapsed:\n" f" - Total: {total_time_cost:.1f}ms\n" f" - Check LLM: {check_llm_time_cost:.1f}ms\n" f" - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n" f" - Bind models: {bind_embedding_time_cost:.1f}ms\n" f" - Query refinement(LLM): {refine_question_time_cost:.1f}ms\n" f" - Retrieval: {retrieval_time_cost:.1f}ms\n" f" - Generate answer: {generate_result_time_cost:.1f}ms\n\n" "## Token usage:\n" f" - Generated tokens(approximately): {tk_num}\n" f" - Token speed: {int(tk_num / (generate_result_time_cost / 1000.0))}/s" ) # Add a condition check to call the end method only if langfuse_tracer exists if langfuse_tracer and "langfuse_generation" in locals(): langfuse_output = "\n" + re.sub(r"^.*?(### Query:.*)", r"\1", prompt, flags=re.DOTALL) langfuse_output = {"time_elapsed:": re.sub(r"\n", " \n", langfuse_output), "created_at": time.time()} langfuse_generation.update(output=langfuse_output) langfuse_generation.end() return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()} if langfuse_tracer: langfuse_generation = langfuse_tracer.start_generation( trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg} ) if stream: stream_iter = chat_mdl.async_chat_streamly_delta(prompt + prompt4citation, msg[1:], gen_conf) last_state = None async for kind, value, state in _stream_with_think_delta(stream_iter): last_state = state if kind == "marker": flags = {"start_to_think": True} if value == "" else {"end_to_think": True} yield {"answer": "", "reference": {}, "audio_binary": None, "final": False, **flags} continue yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "final": False} full_answer = last_state.full_text if last_state else "" if full_answer: final = decorate_answer(thought + full_answer) final["final"] = True final["audio_binary"] = None final["answer"] = "" yield final else: answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf) user_content = msg[-1].get("content", "[content not available]") logging.debug("User: {}|Assistant: {}".format(user_content, answer)) res = decorate_answer(answer) res["audio_binary"] = tts(tts_mdl, answer) yield res return async def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): logging.debug(f"use_sql: Question: {question}") # Determine which document engine we're using doc_engine = "infinity" if settings.DOC_ENGINE_INFINITY else "es" # Construct the full table name # For Elasticsearch: ragflow_{tenant_id} (kb_id is in WHERE clause) # For Infinity: ragflow_{tenant_id}_{kb_id} (each KB has its own table) base_table = index_name(tenant_id) if doc_engine == "infinity" and kb_ids and len(kb_ids) == 1: # Infinity: append kb_id to table name table_name = f"{base_table}_{kb_ids[0]}" logging.debug(f"use_sql: Using Infinity table name: {table_name}") else: # Elasticsearch/OpenSearch: use base index name table_name = base_table logging.debug(f"use_sql: Using ES/OS table name: {table_name}") # Generate engine-specific SQL prompts if doc_engine == "infinity": # Build Infinity prompts with JSON extraction context json_field_names = list(field_map.keys()) sys_prompt = """You are a Database Administrator. Write SQL for a table with JSON 'chunk_data' column. JSON Extraction: json_extract_string(chunk_data, '$.FieldName') Numeric Cast: CAST(json_extract_string(chunk_data, '$.FieldName') AS INTEGER/FLOAT) NULL Check: json_extract_isnull(chunk_data, '$.FieldName') == false RULES: 1. Use EXACT field names (case-sensitive) from the list below 2. For SELECT: include doc_id, docnm, and json_extract_string() for requested fields 3. For COUNT: use COUNT(*) or COUNT(DISTINCT json_extract_string(...)) 4. Add AS alias for extracted field names 5. DO NOT select 'content' field 6. Only add NULL check (json_extract_isnull() == false) in WHERE clause when: - Question asks to "show me" or "display" specific columns - Question mentions "not null" or "excluding null" - Add NULL check for count specific column - DO NOT add NULL check for COUNT(*) queries (COUNT(*) counts all rows including nulls) 7. Output ONLY the SQL, no explanations""" user_prompt = """Table: {} Fields (EXACT case): {} {} Question: {} Write SQL using json_extract_string() with exact field names. Include doc_id, docnm for data queries. Only SQL.""".format( table_name, ", ".join(json_field_names), "\n".join([f" - {field}" for field in json_field_names]), question ) else: # Build ES/OS prompts with direct field access sys_prompt = """You are a Database Administrator. Write SQL queries. RULES: 1. Use EXACT field names from the schema below (e.g., product_tks, not product) 2. Quote field names starting with digit: "123_field" 3. Add IS NOT NULL in WHERE clause when: - Question asks to "show me" or "display" specific columns 4. Include doc_id/docnm in non-aggregate statement 5. Output ONLY the SQL, no explanations""" user_prompt = """Table: {} Available fields: {} Question: {} Write SQL using exact field names above. Include doc_id, docnm_kwd for data queries. Only SQL.""".format( table_name, "\n".join([f" - {k} ({v})" for k, v in field_map.items()]), question ) tried_times = 0 async def get_table(): nonlocal sys_prompt, user_prompt, question, tried_times sql = await chat_mdl.async_chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06}) logging.debug(f"use_sql: Raw SQL from LLM: {repr(sql[:500])}") # Remove think blocks if present (format: ...) sql = re.sub(r"\n.*?\n\s*", "", sql, flags=re.DOTALL) sql = re.sub(r"思考\n.*?\n", "", sql, flags=re.DOTALL) # Remove markdown code blocks (```sql ... ```) sql = re.sub(r"```(?:sql)?\s*", "", sql, flags=re.IGNORECASE) sql = re.sub(r"```\s*$", "", sql, flags=re.IGNORECASE) # Remove trailing semicolon that ES SQL parser doesn't like sql = sql.rstrip().rstrip(';').strip() # Add kb_id filter for ES/OS only (Infinity already has it in table name) if doc_engine != "infinity" and kb_ids: # Build kb_filter: single KB or multiple KBs with OR if len(kb_ids) == 1: kb_filter = f"kb_id = '{kb_ids[0]}'" else: kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")" if "where " not in sql.lower(): o = sql.lower().split("order by") if len(o) > 1: sql = o[0] + f" WHERE {kb_filter} order by " + o[1] else: sql += f" WHERE {kb_filter}" elif "kb_id =" not in sql.lower() and "kb_id=" not in sql.lower(): sql = re.sub(r"\bwhere\b ", f"where {kb_filter} and ", sql, flags=re.IGNORECASE) logging.debug(f"{question} get SQL(refined): {sql}") tried_times += 1 logging.debug(f"use_sql: Executing SQL retrieval (attempt {tried_times})") tbl = settings.retriever.sql_retrieval(sql, format="json") if tbl is None: logging.debug("use_sql: SQL retrieval returned None") return None, sql logging.debug(f"use_sql: SQL retrieval completed, got {len(tbl.get('rows', []))} rows") return tbl, sql try: tbl, sql = await get_table() logging.debug(f"use_sql: Initial SQL execution SUCCESS. SQL: {sql}") logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows, columns: {[c['name'] for c in tbl.get('columns', [])]}") except Exception as e: logging.warning(f"use_sql: Initial SQL execution FAILED with error: {e}") # Build retry prompt with error information if doc_engine == "infinity": # Build Infinity error retry prompt json_field_names = list(field_map.keys()) user_prompt = """ Table name: {}; JSON fields available in 'chunk_data' column (use these exact names in json_extract_string): {} Question: {} Please write the SQL using json_extract_string(chunk_data, '$.field_name') with the field names from the list above. Only SQL, no explanations. The SQL error you provided last time is as follows: {} Please correct the error and write SQL again using json_extract_string(chunk_data, '$.field_name') syntax with the correct field names. Only SQL, no explanations. """.format(table_name, "\n".join([f" - {field}" for field in json_field_names]), question, e) else: # Build ES/OS error retry prompt user_prompt = """ Table name: {}; Table of database fields are as follows (use the field names directly in SQL): {} Question are as follows: {} Please write the SQL using the exact field names above, only SQL, without any other explanations or text. The SQL error you provided last time is as follows: {} Please correct the error and write SQL again using the exact field names above, only SQL, without any other explanations or text. """.format(table_name, "\n".join([f"{k} ({v})" for k, v in field_map.items()]), question, e) try: tbl, sql = await get_table() logging.debug(f"use_sql: Retry SQL execution SUCCESS. SQL: {sql}") logging.debug(f"use_sql: Retrieved {len(tbl.get('rows', []))} rows on retry") except Exception: logging.error("use_sql: Retry SQL execution also FAILED, returning None") return if len(tbl["rows"]) == 0: logging.warning(f"use_sql: No rows returned from SQL query, returning None. SQL: {sql}") return None logging.debug(f"use_sql: Proceeding with {len(tbl['rows'])} rows to build answer") docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() == "doc_id"]) doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]]) logging.debug(f"use_sql: All columns: {[(i, c['name']) for i, c in enumerate(tbl['columns'])]}") logging.debug(f"use_sql: docid_idx={docid_idx}, doc_name_idx={doc_name_idx}") column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)] logging.debug(f"use_sql: column_idx={column_idx}") logging.debug(f"use_sql: field_map={field_map}") # Helper function to map column names to display names def map_column_name(col_name): if col_name.lower() == "count(star)": return "COUNT(*)" # First, try to extract AS alias from any expression (aggregate functions, json_extract_string, etc.) # Pattern: anything AS alias_name as_match = re.search(r'\s+AS\s+([^\s,)]+)', col_name, re.IGNORECASE) if as_match: alias = as_match.group(1).strip('"\'') # Use the alias for display name lookup if alias in field_map: display = field_map[alias] return re.sub(r"(/.*|([^()]+))", "", display) # If alias not in field_map, try to match case-insensitively for field_key, display_value in field_map.items(): if field_key.lower() == alias.lower(): return re.sub(r"(/.*|([^()]+))", "", display_value) # Return alias as-is if no mapping found return alias # Try direct mapping first (for simple column names) if col_name in field_map: display = field_map[col_name] # Clean up any suffix patterns return re.sub(r"(/.*|([^()]+))", "", display) # Try case-insensitive match for simple column names col_lower = col_name.lower() for field_key, display_value in field_map.items(): if field_key.lower() == col_lower: return re.sub(r"(/.*|([^()]+))", "", display_value) # For aggregate expressions or complex expressions without AS alias, # try to replace field names with display names result = col_name for field_name, display_name in field_map.items(): # Replace field_name with display_name in the expression result = result.replace(field_name, display_name) # Clean up any suffix patterns result = re.sub(r"(/.*|([^()]+))", "", result) return result # compose Markdown table columns = ( "|" + "|".join( [map_column_name(tbl["columns"][i]["name"]) for i in column_idx]) + ( "|Source|" if docid_idx and doc_name_idx else "|") ) line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "") # Build rows ensuring column names match values - create a dict for each row # keyed by column name to handle any SQL column order rows = [] for row_idx, r in enumerate(tbl["rows"]): row_dict = {tbl["columns"][i]["name"]: r[i] for i in range(len(tbl["columns"])) if i < len(r)} if row_idx == 0: logging.debug(f"use_sql: First row data: {row_dict}") row_values = [] for col_idx in column_idx: col_name = tbl["columns"][col_idx]["name"] value = row_dict.get(col_name, " ") row_values.append(remove_redundant_spaces(str(value)).replace("None", " ")) # Add Source column with citation marker if Source column exists if docid_idx and doc_name_idx: row_values.append(f" ##{row_idx}$$") row_str = "|" + "|".join(row_values) + "|" if re.sub(r"[ |]+", "", row_str): rows.append(row_str) if quota: rows = "\n".join(rows) else: rows = "\n".join(rows) rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows) if not docid_idx or not doc_name_idx: logging.warning(f"use_sql: SQL missing required doc_id or docnm_kwd field. docid_idx={docid_idx}, doc_name_idx={doc_name_idx}. SQL: {sql}") # For aggregate queries (COUNT, SUM, AVG, MAX, MIN, DISTINCT), fetch doc_id, docnm_kwd separately # to provide source chunks, but keep the original table format answer if re.search(r"(count|sum|avg|max|min|distinct)\s*\(", sql.lower()): # Keep original table format as answer answer = "\n".join([columns, line, rows]) # Now fetch doc_id, docnm_kwd to provide source chunks # Extract WHERE clause from the original SQL where_match = re.search(r"\bwhere\b(.+?)(?:\bgroup by\b|\border by\b|\blimit\b|$)", sql, re.IGNORECASE) if where_match: where_clause = where_match.group(1).strip() # Build a query to get doc_id and docnm_kwd with the same WHERE clause chunks_sql = f"select doc_id, docnm_kwd from {table_name} where {where_clause}" # Add LIMIT to avoid fetching too many chunks if "limit" not in chunks_sql.lower(): chunks_sql += " limit 20" logging.debug(f"use_sql: Fetching chunks with SQL: {chunks_sql}") try: chunks_tbl = settings.retriever.sql_retrieval(chunks_sql, format="json") if chunks_tbl.get("rows") and len(chunks_tbl["rows"]) > 0: # Build chunks reference - use case-insensitive matching chunks_did_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() == "doc_id"), None) chunks_dn_idx = next((i for i, c in enumerate(chunks_tbl["columns"]) if c["name"].lower() in ["docnm_kwd", "docnm"]), None) if chunks_did_idx is not None and chunks_dn_idx is not None: chunks = [{"doc_id": r[chunks_did_idx], "docnm_kwd": r[chunks_dn_idx]} for r in chunks_tbl["rows"]] # Build doc_aggs doc_aggs = {} for r in chunks_tbl["rows"]: doc_id = r[chunks_did_idx] doc_name = r[chunks_dn_idx] if doc_id not in doc_aggs: doc_aggs[doc_id] = {"doc_name": doc_name, "count": 0} doc_aggs[doc_id]["count"] += 1 doc_aggs_list = [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()] logging.debug(f"use_sql: Returning aggregate answer with {len(chunks)} chunks from {len(doc_aggs)} documents") return {"answer": answer, "reference": {"chunks": chunks, "doc_aggs": doc_aggs_list}, "prompt": sys_prompt} except Exception as e: logging.warning(f"use_sql: Failed to fetch chunks: {e}") # Fallback: return answer without chunks return {"answer": answer, "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt} # Fallback to table format for other cases return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt} docid_idx = list(docid_idx)[0] doc_name_idx = list(doc_name_idx)[0] doc_aggs = {} for r in tbl["rows"]: if r[docid_idx] not in doc_aggs: doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0} doc_aggs[r[docid_idx]]["count"] += 1 result = { "answer": "\n".join([columns, line, rows]), "reference": { "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]], "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()], }, "prompt": sys_prompt, } logging.debug(f"use_sql: Returning answer with {len(result['reference']['chunks'])} chunks from {len(doc_aggs)} documents") return result def clean_tts_text(text: str) -> str: if not text: return "" text = text.encode("utf-8", "ignore").decode("utf-8", "ignore") text = re.sub(r"[\x00-\x08\x0B-\x0C\x0E-\x1F\x7F]", "", text) emoji_pattern = re.compile( "[\U0001F600-\U0001F64F" "\U0001F300-\U0001F5FF" "\U0001F680-\U0001F6FF" "\U0001F1E0-\U0001F1FF" "\U00002700-\U000027BF" "\U0001F900-\U0001F9FF" "\U0001FA70-\U0001FAFF" "\U0001FAD0-\U0001FAFF]+", flags=re.UNICODE ) text = emoji_pattern.sub("", text) text = re.sub(r"\s+", " ", text).strip() MAX_LEN = 500 if len(text) > MAX_LEN: text = text[:MAX_LEN] return text def tts(tts_mdl, text): if not tts_mdl or not text: return None text = clean_tts_text(text) if not text: return None bin = b"" try: for chunk in tts_mdl.tts(text): bin += chunk except Exception as e: logging.error(f"TTS failed: {e}, text={text!r}") return None return binascii.hexlify(bin).decode("utf-8") class _ThinkStreamState: def __init__(self) -> None: self.full_text = "" self.last_idx = 0 self.endswith_think = False self.last_full = "" self.last_model_full = "" self.in_think = False self.buffer = "" def _next_think_delta(state: _ThinkStreamState) -> str: full_text = state.full_text if full_text == state.last_full: return "" state.last_full = full_text delta_ans = full_text[state.last_idx:] if delta_ans.find("") == 0: state.last_idx += len("") return "" if delta_ans.find("") > 0: delta_text = full_text[state.last_idx:state.last_idx + delta_ans.find("")] state.last_idx += delta_ans.find("") return delta_text if delta_ans.endswith(""): state.endswith_think = True elif state.endswith_think: state.endswith_think = False return "" state.last_idx = len(full_text) if full_text.endswith(""): state.last_idx -= len("") return re.sub(r"(|)", "", delta_ans) async def _stream_with_think_delta(stream_iter, min_tokens: int = 16): state = _ThinkStreamState() async for chunk in stream_iter: if not chunk: continue if chunk.startswith(state.last_model_full): new_part = chunk[len(state.last_model_full):] state.last_model_full = chunk else: new_part = chunk state.last_model_full += chunk if not new_part: continue state.full_text += new_part delta = _next_think_delta(state) if not delta: continue if delta in ("", ""): if delta == "" and state.in_think: continue if delta == "" and not state.in_think: continue if state.buffer: yield ("text", state.buffer, state) state.buffer = "" state.in_think = delta == "" yield ("marker", delta, state) continue state.buffer += delta if num_tokens_from_string(state.buffer) < min_tokens: continue yield ("text", state.buffer, state) state.buffer = "" if state.buffer: yield ("text", state.buffer, state) state.buffer = "" if state.endswith_think: yield ("marker", "", state) async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): doc_ids = search_config.get("doc_ids", []) rerank_mdl = None kb_ids = search_config.get("kb_ids", kb_ids) chat_llm_name = search_config.get("chat_id", chat_llm_name) rerank_id = search_config.get("rerank_id", "") meta_data_filter = search_config.get("meta_data_filter") kbs = KnowledgebaseService.get_by_ids(kb_ids) embedding_list = list(set([kb.embd_id for kb in kbs])) is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs]) retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0]) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name) if rerank_id: rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id) max_tokens = chat_mdl.max_length tenant_ids = list(set([kb.tenant_id for kb in kbs])) if meta_data_filter: metas = DocumentService.get_meta_by_kbs(kb_ids) doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids) kbinfos = await retriever.retrieval( question=question, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=kb_ids, page=1, page_size=12, similarity_threshold=search_config.get("similarity_threshold", 0.1), vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3), top=search_config.get("top_k", 1024), doc_ids=doc_ids, aggs=True, rerank_mdl=rerank_mdl, rank_feature=label_question(question, kbs) ) knowledges = kb_prompt(kbinfos, max_tokens) sys_prompt = PROMPT_JINJA_ENV.from_string(ASK_SUMMARY).render(knowledge="\n".join(knowledges)) msg = [{"role": "user", "content": question}] def decorate_answer(answer): nonlocal knowledges, kbinfos, sys_prompt answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] if not recall_docs: recall_docs = kbinfos["doc_aggs"] kbinfos["doc_aggs"] = recall_docs refs = deepcopy(kbinfos) for c in refs["chunks"]: if c.get("vector"): del c["vector"] if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'" refs["chunks"] = chunks_format(refs) return {"answer": answer, "reference": refs} stream_iter = chat_mdl.async_chat_streamly_delta(sys_prompt, msg, {"temperature": 0.1}) last_state = None async for kind, value, state in _stream_with_think_delta(stream_iter): last_state = state if kind == "marker": flags = {"start_to_think": True} if value == "" else {"end_to_think": True} yield {"answer": "", "reference": {}, "final": False, **flags} continue yield {"answer": value, "reference": {}, "final": False} full_answer = last_state.full_text if last_state else "" final = decorate_answer(full_answer) final["final"] = True final["answer"] = "" yield final async def gen_mindmap(question, kb_ids, tenant_id, search_config={}): meta_data_filter = search_config.get("meta_data_filter", {}) doc_ids = search_config.get("doc_ids", []) rerank_id = search_config.get("rerank_id", "") rerank_mdl = None kbs = KnowledgebaseService.get_by_ids(kb_ids) if not kbs: return {"error": "No KB selected"} embedding_list = list(set([kb.embd_id for kb in kbs])) tenant_ids = list(set([kb.tenant_id for kb in kbs])) embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=embedding_list[0]) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) if rerank_id: rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id) if meta_data_filter: metas = DocumentService.get_meta_by_kbs(kb_ids) doc_ids = await apply_meta_data_filter(meta_data_filter, metas, question, chat_mdl, doc_ids) ranks = await settings.retriever.retrieval( question=question, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=kb_ids, page=1, page_size=12, similarity_threshold=search_config.get("similarity_threshold", 0.2), vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3), top=search_config.get("top_k", 1024), doc_ids=doc_ids, aggs=False, rerank_mdl=rerank_mdl, rank_feature=label_question(question, kbs), ) mindmap = MindMapExtractor(chat_mdl) mind_map = await mindmap([c["content_with_weight"] for c in ranks["chunks"]]) return mind_map.output