diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index ca95d5ab7..bb92b0566 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -24,6 +24,7 @@ from api.db.services.llm_service import LLMBundle from api import settings from api.utils.api_utils import validate_request, build_error_result, apikey_required from rag.app.tag import label_question +from api.db.services.dialog_service import meta_filter @manager.route('/dify/retrieval', methods=['POST']) # noqa: F821 @@ -37,7 +38,10 @@ def retrieval(tenant_id): retrieval_setting = req.get("retrieval_setting", {}) similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0)) top = int(retrieval_setting.get("top_k", 1024)) - + metadata_condition = req.get("metadata_condition",{}) + metas = DocumentService.get_meta_by_kbs([kb_id]) + + doc_ids = [] try: e, kb = KnowledgebaseService.get_by_id(kb_id) @@ -45,7 +49,12 @@ def retrieval(tenant_id): return build_error_result(message="Knowledgebase not found!", code=settings.RetCode.NOT_FOUND) embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING.value, llm_name=kb.embd_id) - + print(metadata_condition) + print("after",convert_conditions(metadata_condition)) + doc_ids.extend(meta_filter(metas, convert_conditions(metadata_condition))) + print("doc_ids",doc_ids) + if not doc_ids and metadata_condition is not None: + doc_ids = ['-999'] ranks = settings.retrievaler.retrieval( question, embd_mdl, @@ -56,6 +65,7 @@ def retrieval(tenant_id): similarity_threshold=similarity_threshold, vector_similarity_weight=0.3, top=top, + doc_ids=doc_ids, rank_feature=label_question(question, [kb]) ) @@ -64,6 +74,7 @@ def retrieval(tenant_id): [tenant_id], [kb_id], embd_mdl, + doc_ids, LLMBundle(kb.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) @@ -90,3 +101,20 @@ def retrieval(tenant_id): ) logging.exception(e) return build_error_result(message=str(e), code=settings.RetCode.SERVER_ERROR) + +def convert_conditions(metadata_condition): + if metadata_condition is None: + metadata_condition = {} + op_mapping = { + "is": "=", + "not is": "≠" + } + return [ + { + "op": op_mapping.get(cond["comparison_operator"], cond["comparison_operator"]), + "key": cond["name"], + "value": cond["value"] + } + for cond in metadata_condition.get("conditions", []) +] +