diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py
index 534a003dc..6c59ab128 100644
--- a/api/apps/conversation_app.py
+++ b/api/apps/conversation_app.py
@@ -29,6 +29,7 @@ from api.db.services.conversation_service import ConversationService, structure_
 from api.db.services.dialog_service import DialogService, ask, chat
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
+from api.db.services.search_service import SearchService
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.user_service import TenantService, UserTenantService
 from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
@@ -344,10 +345,18 @@ def ask_about():
     req = request.json
     uid = current_user.id
 
+    search_id = req.get("search_id", "")
+    search_app = None
+    search_config = {}
+    if search_id:
+        search_app = SearchService.get_detail(search_id)
+        if search_app:
+            search_config = search_app.get("search_config", {})
+
     def stream():
         nonlocal req, uid
         try:
-            for ans in ask(req["question"], req["kb_ids"], uid):
+            for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
             yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
@@ -366,15 +375,68 @@ def ask_about():
 @validate_request("question", "kb_ids")
 def mindmap():
     req = request.json
+
+    search_id = req.get("search_id", "")
+    search_app = None
+    search_config = {}
+    if search_id:
+        search_app = SearchService.get_detail(search_id)
+        if search_app:
+            search_config = search_app.get("search_config", {})
+
     kb_ids = req["kb_ids"]
+    if search_config.get("kb_ids", []):
+        kb_ids = search_config.get("kb_ids", [])
     e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
     if not e:
         return get_data_error_result(message="Knowledgebase not found!")
-    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
-    chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
+    chat_id = ""
+    similarity_threshold = 0.3
+    vector_similarity_weight = 0.3
+    top = 1024
+    doc_ids = []
+    rerank_id = ""
+    rerank_mdl = None
+
+    if search_config:
+        if search_config.get("chat_id", ""):
+            chat_id = search_config.get("chat_id", "")
+        if search_config.get("similarity_threshold", 0.2):
+            similarity_threshold = search_config.get("similarity_threshold", 0.2)
+        if search_config.get("vector_similarity_weight", 0.3):
+            vector_similarity_weight = search_config.get("vector_similarity_weight", 0.3)
+        if search_config.get("top_k", 1024):
+            top = search_config.get("top_k", 1024)
+        if search_config.get("doc_ids", []):
+            doc_ids = search_config.get("doc_ids", [])
+        if search_config.get("rerank_id", ""):
+            rerank_id = search_config.get("rerank_id", "")
+
+    tenant_id = kb.tenant_id
+    if search_app and search_app.get("tenant_id", ""):
+        tenant_id = search_app.get("tenant_id", "")
+
+    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
+    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=chat_id)
+    if rerank_id:
+        rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
 
     question = req["question"]
-    ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12, 0.3, 0.3, aggs=False, rank_feature=label_question(question, [kb]))
+    ranks = settings.retrievaler.retrieval(
+        question=question,
+        embd_mdl=embd_mdl,
+        tenant_ids=tenant_id,
+        kb_ids=kb_ids,
+        page=1,
+        page_size=12,
+        similarity_threshold=similarity_threshold,
+        vector_similarity_weight=vector_similarity_weight,
+        top=top,
+        doc_ids=doc_ids,
+        aggs=False,
+        rerank_mdl=rerank_mdl,
+        rank_feature=label_question(question, [kb]),
+    )
     mindmap = MindMapExtractor(chat_mdl)
     mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
     mind_map = mind_map.output
@@ -388,8 +450,19 @@ def mindmap():
 @validate_request("question")
 def related_questions():
     req = request.json
+
+    search_id = req.get("search_id", "")
+    search_config = {}
+    if search_id:
+        if search_app := SearchService.get_detail(search_id):
+            search_config = search_app.get("search_config", {})
+
     question = req["question"]
-    chat_mdl = LLMBundle(current_user.id, LLMType.CHAT)
+
+    chat_id = search_config.get("chat_id", "")
+    chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, chat_id)
+
+    gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
     prompt = load_prompt("related_question")
     ans = chat_mdl.chat(
         prompt,
@@ -402,6 +475,6 @@ Related search terms:
     """,
             }
         ],
-        {"temperature": 0.9},
+        gen_conf,
     )
     return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py
index 8df1812b8..16346595f 100644
--- a/api/apps/sdk/session.py
+++ b/api/apps/sdk/session.py
@@ -902,10 +902,16 @@ def ask_about_embedded():
     req = request.json
     uid = objs[0].tenant_id
 
+    search_id = req.get("search_id", "")
+    search_config = {}
+    if search_id:
+        if search_app := SearchService.get_detail(search_id):
+            search_config = search_app.get("search_config", {})
+
     def stream():
         nonlocal req, uid
         try:
-            for ans in ask(req["question"], req["kb_ids"], uid):
+            for ans in ask(req["question"], req["kb_ids"], uid, search_config=search_config):
                 yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
         except Exception as e:
             yield "data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e), "reference": []}}, ensure_ascii=False) + "\n\n"
@@ -1021,8 +1027,19 @@ def related_questions_embedded():
     tenant_id = objs[0].tenant_id
     if not tenant_id:
         return get_error_data_result(message="permission denined.")
+
+    search_id = req.get("search_id", "")
+    search_config = {}
+    if search_id:
+        if search_app := SearchService.get_detail(search_id):
+            search_config = search_app.get("search_config", {})
+
     question = req["question"]
-    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
+
+    chat_id = search_config.get("chat_id", "")
+    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_id)
+
+    gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
     prompt = load_prompt("related_question")
     ans = chat_mdl.chat(
         prompt,
@@ -1035,7 +1052,7 @@ Related search terms:
     """,
             }
         ],
-        {"temperature": 0.9},
+        gen_conf,
     )
 
     return get_json_result(data=[re.sub(r"^[0-9]\. ", "", a) for a in ans.split("\n") if re.match(r"^[0-9]\. ", a)])
@@ -1083,15 +1100,62 @@ def mindmap():
     tenant_id = objs[0].tenant_id
 
     req = request.json
+
+    search_id = req.get("search_id", "")
+    search_config = {}
+    if search_id:
+        if search_app := SearchService.get_detail(search_id):
+            search_config = search_app.get("search_config", {})
+
     kb_ids = req["kb_ids"]
+    if search_config.get("kb_ids", []):
+        kb_ids = search_config.get("kb_ids", [])
     e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
     if not e:
         return get_error_data_result(message="Knowledgebase not found!")
-    embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
-    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT)
+    chat_id = ""
+    similarity_threshold = 0.3
+    vector_similarity_weight = 0.3
+    top = 1024
+    doc_ids = []
+    rerank_id = ""
+    rerank_mdl = None
+
+    if search_config:
+        if search_config.get("chat_id", ""):
+            chat_id = search_config.get("chat_id", "")
+        if search_config.get("similarity_threshold", 0.2):
+            similarity_threshold = search_config.get("similarity_threshold", 0.2)
+        if search_config.get("vector_similarity_weight", 0.3):
+            vector_similarity_weight = search_config.get("vector_similarity_weight", 0.3)
+        if search_config.get("top_k", 1024):
+            top = search_config.get("top_k", 1024)
+        if search_config.get("doc_ids", []):
+            doc_ids = search_config.get("doc_ids", [])
+        if search_config.get("rerank_id", ""):
+            rerank_id = search_config.get("rerank_id", "")
+
+    embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id)
+    chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=chat_id)
+    if rerank_id:
+        rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
 
     question = req["question"]
-    ranks = settings.retrievaler.retrieval(question, embd_mdl, kb.tenant_id, kb_ids, 1, 12, 0.3, 0.3, aggs=False, rank_feature=label_question(question, [kb]))
+    ranks = settings.retrievaler.retrieval(
+        question=question,
+        embd_mdl=embd_mdl,
+        tenant_ids=tenant_id,
+        kb_ids=kb_ids,
+        page=1,
+        page_size=12,
+        similarity_threshold=similarity_threshold,
+        vector_similarity_weight=vector_similarity_weight,
+        top=top,
+        doc_ids=doc_ids,
+        aggs=False,
+        rerank_mdl=rerank_mdl,
+        rank_feature=label_question(question, [kb]),
+    )
     mindmap = MindMapExtractor(chat_mdl)
     mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
     mind_map = mind_map.output
diff --git a/api/db/db_models.py b/api/db/db_models.py
index cdd946eb7..6a1291baa 100644
--- a/api/db/db_models.py
+++ b/api/db/db_models.py
@@ -872,7 +872,7 @@ class Search(DataBaseModel):
         default={
             "kb_ids": [],
             "doc_ids": [],
-            "similarity_threshold": 0.0,
+            "similarity_threshold": 0.2,
             "vector_similarity_weight": 0.3,
             "use_kg": False,
             # rerank settings
diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py
index 0f637bc70..97e508960 100644
--- a/api/db/services/dialog_service.py
+++ b/api/db/services/dialog_service.py
@@ -687,7 +687,30 @@ def tts(tts_mdl, text):
     return binascii.hexlify(bin).decode("utf-8")
 
 
-def ask(question, kb_ids, tenant_id, chat_llm_name=None):
+def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config=None):
+    similarity_threshold = 0.1
+    vector_similarity_weight = 0.3
+    top = 1024
+    doc_ids = []
+    rerank_id = ""
+    rerank_mdl = None
+
+    if search_config:
+        if search_config.get("kb_ids", []):
+            kb_ids = search_config.get("kb_ids", [])
+        if search_config.get("chat_id", ""):
+            chat_llm_name = search_config.get("chat_id", "")
+        if search_config.get("similarity_threshold", 0.1):
+            similarity_threshold = search_config.get("similarity_threshold", 0.1)
+        if search_config.get("vector_similarity_weight", 0.3):
+            vector_similarity_weight = search_config.get("vector_similarity_weight", 0.3)
+        if search_config.get("top_k", 1024):
+            top = search_config.get("top_k", 1024)
+        if search_config.get("doc_ids", []):
+            doc_ids = search_config.get("doc_ids", [])
+        if search_config.get("rerank_id", ""):
+            rerank_id = search_config.get("rerank_id", "")
+
     kbs = KnowledgebaseService.get_by_ids(kb_ids)
     embedding_list = list(set([kb.embd_id for kb in kbs]))
 
@@ -696,9 +719,26 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None):
     embd_mdl = LLMBundle(tenant_id, LLMType.EMBEDDING, embedding_list[0])
 
     chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, chat_llm_name)
+    if rerank_id:
+        rerank_mdl = LLMBundle(tenant_id, LLMType.RERANK, rerank_id)
     max_tokens = chat_mdl.max_length
     tenant_ids = list(set([kb.tenant_id for kb in kbs]))
-    kbinfos = retriever.retrieval(question, embd_mdl, tenant_ids, kb_ids, 1, 12, 0.1, 0.3, aggs=False, rank_feature=label_question(question, kbs))
+    kbinfos = retriever.retrieval(
+        question=question,
+        embd_mdl=embd_mdl,
+        tenant_ids=tenant_ids,
+        kb_ids=kb_ids,
+        page=1,
+        page_size=12,
+        similarity_threshold=similarity_threshold,
+        vector_similarity_weight=vector_similarity_weight,
+        top=top,
+        doc_ids=doc_ids,
+        aggs=False,
+        rerank_mdl=rerank_mdl,
+        rank_feature=label_question(question, kbs),
+    )
+
     knowledges = kb_prompt(kbinfos, max_tokens)
     prompt = """
 Role: You're a smart assistant. Your name is Miss R.