diff --git a/api/apps/chunk_app.py b/api/apps/chunk_app.py index 7d129539c..14ea25771 100644 --- a/api/apps/chunk_app.py +++ b/api/apps/chunk_app.py @@ -23,15 +23,18 @@ from flask_login import current_user, login_required from api import settings from api.db import LLMType, ParserType +from api.db.services.dialog_service import meta_filter from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle +from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request from rag.app.qa import beAdoc, rmPrefix from rag.app.tag import label_question from rag.nlp import rag_tokenizer, search from rag.prompts import cross_languages, keyword_extraction +from rag.prompts.prompts import gen_meta_filter from rag.settings import PAGERANK_FLD from rag.utils import rmSpace @@ -288,13 +291,26 @@ def retrieval_test(): if isinstance(kb_ids, str): kb_ids = [kb_ids] doc_ids = req.get("doc_ids", []) - similarity_threshold = float(req.get("similarity_threshold", 0.0)) - vector_similarity_weight = float(req.get("vector_similarity_weight", 0.3)) use_kg = req.get("use_kg", False) top = int(req.get("top_k", 1024)) langs = req.get("cross_languages", []) tenant_ids = [] + if req.get("search_id", ""): + search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) + meta_data_filter = search_config.get("meta_data_filter", {}) + metas = DocumentService.get_meta_by_kbs(kb_ids) + if meta_data_filter.get("method") == "auto": + chat_mdl = LLMBundle(current_user.id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) + filters = gen_meta_filter(chat_mdl, metas, question) + doc_ids.extend(meta_filter(metas, filters)) + if not doc_ids: + doc_ids = None + elif meta_data_filter.get("method") == "manual": + doc_ids.extend(meta_filter(metas, meta_data_filter["manual"])) + if not doc_ids: + doc_ids = None + try: tenants = UserTenantService.query(user_id=current_user.id) for kb_id in kb_ids: @@ -327,7 +343,9 @@ def retrieval_test(): labels = label_question(question, [kb]) ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size, - similarity_threshold, vector_similarity_weight, top, + float(req.get("similarity_threshold", 0.0)), + float(req.get("vector_similarity_weight", 0.3)), + top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels ) diff --git a/api/apps/conversation_app.py b/api/apps/conversation_app.py index 4f18f26e2..790172f77 100644 --- a/api/apps/conversation_app.py +++ b/api/apps/conversation_app.py @@ -372,7 +372,9 @@ def mindmap(): search_id = req.get("search_id", "") search_app = SearchService.get_detail(search_id) if search_id else {} search_config = search_app.get("search_config", {}) if search_app else {} - kb_ids = search_config.get("kb_ids", req["kb_ids"]) + kb_ids = search_config.get("kb_ids", []) + kb_ids.extend(req["kb_ids"]) + kb_ids = list(set(kb_ids)) mind_map = gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config) if "error" in mind_map: diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 2a9d30607..fc369f2c1 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -771,10 +771,11 @@ def ask(question, kb_ids, tenant_id, chat_llm_name=None, search_config={}): def gen_mindmap(question, kb_ids, tenant_id, search_config={}): meta_data_filter = search_config.get("meta_data_filter", {}) doc_ids = search_config.get("doc_ids", []) - kb_ids = search_config.get("doc_ids", kb_ids) rerank_id = search_config.get("rerank_id", "") rerank_mdl = None kbs = KnowledgebaseService.get_by_ids(kb_ids) + if not kbs: + return {"error": "No KB selected"} embedding_list = list(set([kb.embd_id for kb in kbs])) tenant_ids = list(set([kb.tenant_id for kb in kbs]))