# # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import binascii import logging import re import time from copy import deepcopy from datetime import datetime from functools import partial from timeit import default_timer as timer import trio from langfuse import Langfuse from peewee import fn from agentic_reasoning import DeepResearcher from api import settings from api.db import LLMType, ParserType, StatusEnum from api.db.db_models import DB, Dialog from api.db.services.common_service import CommonService from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.langfuse_service import TenantLangfuseService from api.db.services.llm_service import LLMBundle from api.db.services.tenant_llm_service import TenantLLMService from api.utils import current_timestamp, datetime_format from graphrag.general.mind_map_extractor import MindMapExtractor from rag.app.resume import forbidden_select_fields4resume from rag.app.tag import label_question from rag.nlp.search import index_name from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, \ gen_meta_filter, PROMPT_JINJA_ENV, ASK_SUMMARY from rag.utils import num_tokens_from_string from rag.utils.tavily_conn import Tavily from common.string_utils import remove_redundant_spaces class DialogService(CommonService): model = Dialog @classmethod def save(cls, **kwargs): """Save a new record to database. This method creates a new record in the database with the provided field values, forcing an insert operation rather than an update. Args: **kwargs: Record field values as keyword arguments. Returns: Model instance: The created record object. """ sample_obj = cls.model(**kwargs).save(force_insert=True) return sample_obj @classmethod def update_many_by_id(cls, data_list): """Update multiple records by their IDs. This method updates multiple records in the database, identified by their IDs. It automatically updates the update_time and update_date fields for each record. Args: data_list (list): List of dictionaries containing record data to update. Each dictionary must include an 'id' field. """ with DB.atomic(): for data in data_list: data["update_time"] = current_timestamp() data["update_date"] = datetime_format(datetime.now()) cls.model.update(data).where(cls.model.id == data["id"]).execute() @classmethod @DB.connection_context() def get_list(cls, tenant_id, page_number, items_per_page, orderby, desc, id, name): chats = cls.model.select() if id: chats = chats.where(cls.model.id == id) if name: chats = chats.where(cls.model.name == name) chats = chats.where((cls.model.tenant_id == tenant_id) & (cls.model.status == StatusEnum.VALID.value)) if desc: chats = chats.order_by(cls.model.getter_by(orderby).desc()) else: chats = chats.order_by(cls.model.getter_by(orderby).asc()) chats = chats.paginate(page_number, items_per_page) return list(chats.dicts()) @classmethod @DB.connection_context() def get_by_tenant_ids(cls, joined_tenant_ids, user_id, page_number, items_per_page, orderby, desc, keywords, parser_id=None): from api.db.db_models import User fields = [ cls.model.id, cls.model.tenant_id, cls.model.name, cls.model.description, cls.model.language, cls.model.llm_id, cls.model.llm_setting, cls.model.prompt_type, cls.model.prompt_config, cls.model.similarity_threshold, cls.model.vector_similarity_weight, cls.model.top_n, cls.model.top_k, cls.model.do_refer, cls.model.rerank_id, cls.model.kb_ids, cls.model.icon, cls.model.status, User.nickname, User.avatar.alias("tenant_avatar"), cls.model.update_time, cls.model.create_time, ] if keywords: dialogs = ( cls.model.select(*fields) .join(User, on=(cls.model.tenant_id == User.id)) .where( (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value), (fn.LOWER(cls.model.name).contains(keywords.lower())), ) ) else: dialogs = ( cls.model.select(*fields) .join(User, on=(cls.model.tenant_id == User.id)) .where( (cls.model.tenant_id.in_(joined_tenant_ids) | (cls.model.tenant_id == user_id)) & (cls.model.status == StatusEnum.VALID.value), ) ) if parser_id: dialogs = dialogs.where(cls.model.parser_id == parser_id) if desc: dialogs = dialogs.order_by(cls.model.getter_by(orderby).desc()) else: dialogs = dialogs.order_by(cls.model.getter_by(orderby).asc()) count = dialogs.count() if page_number and items_per_page: dialogs = dialogs.paginate(page_number, items_per_page) return list(dialogs.dicts()), count @classmethod @DB.connection_context() def get_all_dialogs_by_tenant_id(cls, tenant_id): fields = [cls.model.id] dialogs = cls.model.select(*fields).where(cls.model.tenant_id == tenant_id) dialogs.order_by(cls.model.create_time.asc()) offset, limit = 0, 100 res = [] while True: d_batch = dialogs.offset(offset).limit(limit) _temp = list(d_batch.dicts()) if not _temp: break res.extend(_temp) offset += limit return res def chat_solo(dialog, messages, stream=True): if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text": chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) else: chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) prompt_config = dialog.prompt_config tts_mdl = None if prompt_config.get("tts"): tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS) msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"] if stream: last_ans = "" delta_ans = "" for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting): answer = ans delta_ans = ans[len(last_ans):] if num_tokens_from_string(delta_ans) < 16: continue last_ans = answer yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()} delta_ans = "" if delta_ans: yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()} else: answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting) user_content = msg[-1].get("content", "[content not available]") logging.debug("User: {}|Assistant: {}".format(user_content, answer)) yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, answer), "prompt": "", "created_at": time.time()} def get_models(dialog): embd_mdl, chat_mdl, rerank_mdl, tts_mdl = None, None, None, None kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids) embedding_list = list(set([kb.embd_id for kb in kbs])) if len(embedding_list) > 1: raise Exception("**ERROR**: Knowledge bases use different embedding models.") if embedding_list: embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_list[0]) if not embd_mdl: raise LookupError("Embedding model(%s) not found" % embedding_list[0]) if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text": chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) else: chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) if dialog.rerank_id: rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id) if dialog.prompt_config.get("tts"): tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS) return kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl BAD_CITATION_PATTERNS = [ re.compile(r"\(\s*ID\s*[: ]*\s*(\d+)\s*\)"), # (ID: 12) re.compile(r"\[\s*ID\s*[: ]*\s*(\d+)\s*\]"), # [ID: 12] re.compile(r"【\s*ID\s*[: ]*\s*(\d+)\s*】"), # 【ID: 12】 re.compile(r"ref\s*(\d+)", flags=re.IGNORECASE), # ref12、REF 12 ] def repair_bad_citation_formats(answer: str, kbinfos: dict, idx: set): max_index = len(kbinfos["chunks"]) def safe_add(i): if 0 <= i < max_index: idx.add(i) return True return False def find_and_replace(pattern, group_index=1, repl=lambda i: f"ID:{i}", flags=0): nonlocal answer def replacement(match): try: i = int(match.group(group_index)) if safe_add(i): return f"[{repl(i)}]" except Exception: pass return match.group(0) answer = re.sub(pattern, replacement, answer, flags=flags) for pattern in BAD_CITATION_PATTERNS: find_and_replace(pattern) return answer, idx def convert_conditions(metadata_condition): if metadata_condition is None: metadata_condition = {} op_mapping = { "is": "=", "not is": "≠" } return [ { "op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]), "key": cond["name"], "value": cond["value"] } for cond in metadata_condition.get("conditions", []) ] def meta_filter(metas: dict, filters: list[dict]): doc_ids = set([]) def filter_out(v2docs, operator, value): ids = [] for input, docids in v2docs.items(): try: input = float(input) value = float(value) except Exception: input = str(input) value = str(value) for conds in [ (operator == "contains", str(value).lower() in str(input).lower()), (operator == "not contains", str(value).lower() not in str(input).lower()), (operator == "start with", str(input).lower().startswith(str(value).lower())), (operator == "end with", str(input).lower().endswith(str(value).lower())), (operator == "empty", not input), (operator == "not empty", input), (operator == "=", input == value), (operator == "≠", input != value), (operator == ">", input > value), (operator == "<", input < value), (operator == "≥", input >= value), (operator == "≤", input <= value), ]: try: if all(conds): ids.extend(docids) break except Exception: pass return ids for k, v2docs in metas.items(): for f in filters: if k != f["key"]: continue ids = filter_out(v2docs, f["op"], f["value"]) if not doc_ids: doc_ids = set(ids) else: doc_ids = doc_ids & set(ids) if not doc_ids: return [] return list(doc_ids) def chat(dialog, messages, stream=True, **kwargs): assert messages[-1]["role"] == "user", "The last content of this conversation is not from user." if not dialog.kb_ids and not dialog.prompt_config.get("tavily_api_key"): for ans in chat_solo(dialog, messages, stream): yield ans return chat_start_ts = timer() if TenantLLMService.llm_id2llm_type(dialog.llm_id) == "image2text": llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id) else: llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id) max_tokens = llm_model_config.get("max_tokens", 8192) check_llm_ts = timer() langfuse_tracer = None trace_context = {} langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id) if langfuse_keys: langfuse = Langfuse(public_key=langfuse_keys.public_key, secret_key=langfuse_keys.secret_key, host=langfuse_keys.host) if langfuse.auth_check(): langfuse_tracer = langfuse trace_id = langfuse_tracer.create_trace_id() trace_context = {"trace_id": trace_id} check_langfuse_tracer_ts = timer() kbs, embd_mdl, rerank_mdl, chat_mdl, tts_mdl = get_models(dialog) toolcall_session, tools = kwargs.get("toolcall_session"), kwargs.get("tools") if toolcall_session and tools: chat_mdl.bind_tools(toolcall_session, tools) bind_models_ts = timer() retriever = settings.retriever questions = [m["content"] for m in messages if m["role"] == "user"][-3:] attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else [] if "doc_ids" in messages[-1]: attachments = messages[-1]["doc_ids"] prompt_config = dialog.prompt_config field_map = KnowledgebaseService.get_field_map(dialog.kb_ids) # try to use sql if field mapping is good to go if field_map: logging.debug("Use SQL to retrieval:{}".format(questions[-1])) ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True), dialog.kb_ids) if ans: yield ans return for p in prompt_config["parameters"]: if p["key"] == "knowledge": continue if p["key"] not in kwargs and not p["optional"]: raise KeyError("Miss parameter: " + p["key"]) if p["key"] not in kwargs: prompt_config["system"] = prompt_config["system"].replace("{%s}" % p["key"], " ") if len(questions) > 1 and prompt_config.get("refine_multiturn"): questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)] else: questions = questions[-1:] if prompt_config.get("cross_languages"): questions = [cross_languages(dialog.tenant_id, dialog.llm_id, questions[0], prompt_config["cross_languages"])] if dialog.meta_data_filter: metas = DocumentService.get_meta_by_kbs(dialog.kb_ids) if dialog.meta_data_filter.get("method") == "auto": filters = gen_meta_filter(chat_mdl, metas, questions[-1]) attachments.extend(meta_filter(metas, filters)) if not attachments: attachments = None elif dialog.meta_data_filter.get("method") == "manual": attachments.extend(meta_filter(metas, dialog.meta_data_filter["manual"])) if not attachments: attachments = None if prompt_config.get("keyword", False): questions[-1] += keyword_extraction(chat_mdl, questions[-1]) refine_question_ts = timer() thought = "" kbinfos = {"total": 0, "chunks": [], "doc_aggs": []} knowledges = [] if attachments is not None and "knowledge" in [p["key"] for p in prompt_config["parameters"]]: tenant_ids = list(set([kb.tenant_id for kb in kbs])) knowledges = [] if prompt_config.get("reasoning", False): reasoner = DeepResearcher( chat_mdl, prompt_config, partial( retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3, doc_ids=attachments, ), ) for think in reasoner.thinking(kbinfos, " ".join(questions)): if isinstance(think, str): thought = think knowledges = [t for t in think.split("\n") if t] elif stream: yield think else: if embd_mdl: kbinfos = retriever.retrieval( " ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n, dialog.similarity_threshold, dialog.vector_similarity_weight, doc_ids=attachments, top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl, rank_feature=label_question(" ".join(questions), kbs), ) if prompt_config.get("toc_enhance"): cks = retriever.retrieval_by_toc(" ".join(questions), kbinfos["chunks"], tenant_ids, chat_mdl, dialog.top_n) if cks: kbinfos["chunks"] = cks if prompt_config.get("tavily_api_key"): tav = Tavily(prompt_config["tavily_api_key"]) tav_res = tav.retrieve_chunks(" ".join(questions)) kbinfos["chunks"].extend(tav_res["chunks"]) kbinfos["doc_aggs"].extend(tav_res["doc_aggs"]) if prompt_config.get("use_kg"): ck = settings.kg_retriever.retrieval(" ".join(questions), tenant_ids, dialog.kb_ids, embd_mdl, LLMBundle(dialog.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: kbinfos["chunks"].insert(0, ck) knowledges = kb_prompt(kbinfos, max_tokens) logging.debug("{}->{}".format(" ".join(questions), "\n->".join(knowledges))) retrieval_ts = timer() if not knowledges and prompt_config.get("empty_response"): empty_res = prompt_config["empty_response"] yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)} return {"answer": prompt_config["empty_response"], "reference": kbinfos} kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges) gen_conf = dialog.llm_setting msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}] prompt4citation = "" if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): prompt4citation = citation_prompt() msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])} for m in messages if m["role"] != "system"]) used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95)) assert len(msg) >= 2, f"message_fit_in has bug: {msg}" prompt = msg[0]["content"] if "max_tokens" in gen_conf: gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count) def decorate_answer(answer): nonlocal embd_mdl, prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer refs = [] ans = answer.split("") think = "" if len(ans) == 2: think = ans[0] + "" answer = ans[1] if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)): idx = set([]) if embd_mdl and not re.search(r"\[ID:([0-9]+)\]", answer): answer, idx = retriever.insert_citations( answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=1 - dialog.vector_similarity_weight, vtweight=dialog.vector_similarity_weight, ) else: for match in re.finditer(r"\[ID:([0-9]+)\]", answer): i = int(match.group(1)) if i < len(kbinfos["chunks"]): idx.add(i) answer, idx = repair_bad_citation_formats(answer, kbinfos, idx) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] if not recall_docs: recall_docs = kbinfos["doc_aggs"] kbinfos["doc_aggs"] = recall_docs refs = deepcopy(kbinfos) for c in refs["chunks"]: if c.get("vector"): del c["vector"] if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'" finish_chat_ts = timer() total_time_cost = (finish_chat_ts - chat_start_ts) * 1000 check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000 check_langfuse_tracer_cost = (check_langfuse_tracer_ts - check_llm_ts) * 1000 bind_embedding_time_cost = (bind_models_ts - check_langfuse_tracer_ts) * 1000 refine_question_time_cost = (refine_question_ts - bind_models_ts) * 1000 retrieval_time_cost = (retrieval_ts - refine_question_ts) * 1000 generate_result_time_cost = (finish_chat_ts - retrieval_ts) * 1000 tk_num = num_tokens_from_string(think + answer) prompt += "\n\n### Query:\n%s" % " ".join(questions) prompt = ( f"{prompt}\n\n" "## Time elapsed:\n" f" - Total: {total_time_cost:.1f}ms\n" f" - Check LLM: {check_llm_time_cost:.1f}ms\n" f" - Check Langfuse tracer: {check_langfuse_tracer_cost:.1f}ms\n" f" - Bind models: {bind_embedding_time_cost:.1f}ms\n" f" - Query refinement(LLM): {refine_question_time_cost:.1f}ms\n" f" - Retrieval: {retrieval_time_cost:.1f}ms\n" f" - Generate answer: {generate_result_time_cost:.1f}ms\n\n" "## Token usage:\n" f" - Generated tokens(approximately): {tk_num}\n" f" - Token speed: {int(tk_num / (generate_result_time_cost / 1000.0))}/s" ) # Add a condition check to call the end method only if langfuse_tracer exists if langfuse_tracer and "langfuse_generation" in locals(): langfuse_output = "\n" + re.sub(r"^.*?(### Query:.*)", r"\1", prompt, flags=re.DOTALL) langfuse_output = {"time_elapsed:": re.sub(r"\n", " \n", langfuse_output), "created_at": time.time()} langfuse_generation.update(output=langfuse_output) langfuse_generation.end() return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()} if langfuse_tracer: langfuse_generation = langfuse_tracer.start_generation( trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg} ) if stream: last_ans = "" answer = "" for ans in chat_mdl.chat_streamly(prompt + prompt4citation, msg[1:], gen_conf): if thought: ans = re.sub(r"^.*", "", ans, flags=re.DOTALL) answer = ans delta_ans = ans[len(last_ans):] if num_tokens_from_string(delta_ans) < 16: continue last_ans = answer yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} delta_ans = answer[len(last_ans):] if delta_ans: yield {"answer": thought + answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)} yield decorate_answer(thought + answer) else: answer = chat_mdl.chat(prompt + prompt4citation, msg[1:], gen_conf) user_content = msg[-1].get("content", "[content not available]") logging.debug("User: {}|Assistant: {}".format(user_content, answer)) res = decorate_answer(answer) res["audio_binary"] = tts(tts_mdl, answer) yield res def use_sql(question, field_map, tenant_id, chat_mdl, quota=True, kb_ids=None): sys_prompt = "You are a Database Administrator. You need to check the fields of the following tables based on the user's list of questions and write the SQL corresponding to the last question." user_prompt = """ Table name: {}; Table of database fields are as follows: {} Question are as follows: {} Please write the SQL, only SQL, without any other explanations or text. """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question) tried_times = 0 def get_table(): nonlocal sys_prompt, user_prompt, question, tried_times sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {"temperature": 0.06}) sql = re.sub(r"^.*", "", sql, flags=re.DOTALL) logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}") sql = re.sub(r"[\r\n]+", " ", sql.lower()) sql = re.sub(r".*select ", "select ", sql.lower()) sql = re.sub(r" +", " ", sql) sql = re.sub(r"([;;]|```).*", "", sql) if sql[: len("select ")] != "select ": return None, None if not re.search(r"((sum|avg|max|min)\(|group by )", sql.lower()): if sql[: len("select *")] != "select *": sql = "select doc_id,docnm_kwd," + sql[6:] else: flds = [] for k in field_map.keys(): if k in forbidden_select_fields4resume: continue if len(flds) > 11: break flds.append(k) sql = "select doc_id,docnm_kwd," + ",".join(flds) + sql[8:] if kb_ids: kb_filter = "(" + " OR ".join([f"kb_id = '{kb_id}'" for kb_id in kb_ids]) + ")" if "where" not in sql.lower(): sql += f" WHERE {kb_filter}" else: sql += f" AND {kb_filter}" logging.debug(f"{question} get SQL(refined): {sql}") tried_times += 1 return settings.retriever.sql_retrieval(sql, format="json"), sql tbl, sql = get_table() if tbl is None: return None if tbl.get("error") and tried_times <= 2: user_prompt = """ Table name: {}; Table of database fields are as follows: {} Question are as follows: {} Please write the SQL, only SQL, without any other explanations or text. The SQL error you provided last time is as follows: {} Error issued by database as follows: {} Please correct the error and write SQL again, only SQL, without any other explanations or text. """.format(index_name(tenant_id), "\n".join([f"{k}: {v}" for k, v in field_map.items()]), question, sql, tbl["error"]) tbl, sql = get_table() logging.debug("TRY it again: {}".format(sql)) logging.debug("GET table: {}".format(tbl)) if tbl.get("error") or len(tbl["rows"]) == 0: return None docid_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "doc_id"]) doc_name_idx = set([ii for ii, c in enumerate(tbl["columns"]) if c["name"] == "docnm_kwd"]) column_idx = [ii for ii in range(len(tbl["columns"])) if ii not in (docid_idx | doc_name_idx)] # compose Markdown table columns = ( "|" + "|".join( [re.sub(r"(/.*|([^()]+))", "", field_map.get(tbl["columns"][i]["name"], tbl["columns"][i]["name"])) for i in column_idx]) + ( "|Source|" if docid_idx and docid_idx else "|") ) line = "|" + "|".join(["------" for _ in range(len(column_idx))]) + ("|------|" if docid_idx and docid_idx else "") rows = ["|" + "|".join([remove_redundant_spaces(str(r[i])) for i in column_idx]).replace("None", " ") + "|" for r in tbl["rows"]] rows = [r for r in rows if re.sub(r"[ |]+", "", r)] if quota: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) else: rows = "\n".join([r + f" ##{ii}$$ |" for ii, r in enumerate(rows)]) rows = re.sub(r"T[0-9]{2}:[0-9]{2}:[0-9]{2}(\.[0-9]+Z)?\|", "|", rows) if not docid_idx or not doc_name_idx: logging.warning("SQL missing field: " + sql) return {"answer": "\n".join([columns, line, rows]), "reference": {"chunks": [], "doc_aggs": []}, "prompt": sys_prompt} docid_idx = list(docid_idx)[0] doc_name_idx = list(doc_name_idx)[0] doc_aggs = {} for r in tbl["rows"]: if r[docid_idx] not in doc_aggs: doc_aggs[r[docid_idx]] = {"doc_name": r[doc_name_idx], "count": 0} doc_aggs[r[docid_idx]]["count"] += 1 return { "answer": "\n".join([columns, line, rows]), "reference": { "chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[doc_name_idx]} for r in tbl["rows"]], "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in doc_aggs.items()], }, "prompt": sys_prompt, } def tts(tts_mdl, text): if not tts_mdl or not text: return bin = b"" for chunk in tts_mdl.tts(text): bin += chunk return binascii.hexlify(bin).decode("utf-8") def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): doc_ids = search_config.get("doc_ids", []) rerank_mdl = None kb_ids = search_config.get("kb_ids", kb_ids) chat_llm_name = search_config.get("chat_id", chat_llm_name) rerank_id = search_config.get("rerank_id", "") meta_data_filter = search_config.get("meta_data_filter") kbs = KnowledgebaseService.get_by_ids(kb_ids) embedding_list = list(set([kb.embd_id for kb in kbs])) is_knowledge_graph = all([kb.parser_id == ParserType.KG for kb in kbs]) retriever = settings.retriever if not is_knowledge_graph else settings.kg_retriever embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0]) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name) if rerank_id: rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id) max_tokens = chat_mdl.max_length tenant_ids = list(set([kb.tenant_id for kb in kbs])) if meta_data_filter: metas = DocumentService.get_meta_by_kbs(kb_ids) if meta_data_filter.get("method") == "auto": filters = gen_meta_filter(chat_mdl, metas, question) doc_ids.extend(meta_filter(metas, filters)) if not doc_ids: doc_ids = None elif meta_data_filter.get("method") == "manual": doc_ids.extend(meta_filter(metas, meta_data_filter["manual"])) if not doc_ids: doc_ids = None kbinfos = retriever.retrieval( question=question, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=kb_ids, page=1, page_size=12, similarity_threshold=search_config.get("similarity_threshold", 0.1), vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3), top=search_config.get("top_k", 1024), doc_ids=doc_ids, aggs=False, rerank_mdl=rerank_mdl, rank_feature=label_question(question, kbs) ) knowledges = kb_prompt(kbinfos, max_tokens) sys_prompt = PROMPT_JINJA_ENV.from_string(ASK_SUMMARY).render(knowledge="\n".join(knowledges)) msg = [{"role": "user", "content": question}] def decorate_answer(answer): nonlocal knowledges, kbinfos, sys_prompt answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] if not recall_docs: recall_docs = kbinfos["doc_aggs"] kbinfos["doc_aggs"] = recall_docs refs = deepcopy(kbinfos) for c in refs["chunks"]: if c.get("vector"): del c["vector"] if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0: answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'" refs["chunks"] = chunks_format(refs) return {"answer": answer, "reference": refs} answer = "" for ans in chat_mdl.chat_streamly(sys_prompt, msg, {"temperature": 0.1}): answer = ans yield {"answer": answer, "reference": {}} yield decorate_answer(answer) def gen_mindmap(question, kb_ids, tenant_id, search_config={}): meta_data_filter = search_config.get("meta_data_filter", {}) doc_ids = search_config.get("doc_ids", []) rerank_id = search_config.get("rerank_id", "") rerank_mdl = None kbs = KnowledgebaseService.get_by_ids(kb_ids) if not kbs: return {"error": "No KB selected"} embedding_list = list(set([kb.embd_id for kb in kbs])) tenant_ids = list(set([kb.tenant_id for kb in kbs])) embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=embedding_list[0]) chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) if rerank_id: rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id) if meta_data_filter: metas = DocumentService.get_meta_by_kbs(kb_ids) if meta_data_filter.get("method") == "auto": filters = gen_meta_filter(chat_mdl, metas, question) doc_ids.extend(meta_filter(metas, filters)) if not doc_ids: doc_ids = None elif meta_data_filter.get("method") == "manual": doc_ids.extend(meta_filter(metas, meta_data_filter["manual"])) if not doc_ids: doc_ids = None ranks = settings.retriever.retrieval( question=question, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=kb_ids, page=1, page_size=12, similarity_threshold=search_config.get("similarity_threshold", 0.2), vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3), top=search_config.get("top_k", 1024), doc_ids=doc_ids, aggs=False, rerank_mdl=rerank_mdl, rank_feature=label_question(question, kbs), ) mindmap = MindMapExtractor(chat_mdl) mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]]) return mind_map.output