diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 24bdf3516..39f80ff78 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -16,8 +16,10 @@ import json import re import time + import tiktoken from flask import Response, jsonify, request + from agent.canvas import Canvas from api import settings from api.db import LLMType, StatusEnum @@ -27,7 +29,8 @@ from api.db.services.canvas_service import UserCanvasService, completionOpenAI from api.db.services.canvas_service import completion as agent_completion from api.db.services.conversation_service import ConversationService, iframe_completion from api.db.services.conversation_service import completion as rag_completion -from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap +from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter +from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.db.services.search_service import SearchService @@ -37,7 +40,7 @@ from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_ from rag.app.tag import label_question from rag.prompts import chunks_format from rag.prompts.prompt_template import load_prompt -from rag.prompts.prompts import cross_languages, keyword_extraction +from rag.prompts.prompts import cross_languages, gen_meta_filter, keyword_extraction @manager.route("/chats//sessions", methods=["POST"]) # noqa: F821 @@ -81,10 +84,10 @@ def create_agent_session(tenant_id, agent_id): if not isinstance(cvs.dsl, str): cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) - session_id=get_uuid() + session_id = get_uuid() canvas = Canvas(cvs.dsl, tenant_id, agent_id) canvas.reset() - + cvs.dsl = json.loads(str(canvas)) conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl} API4ConversationService.save(**conv) @@ -442,26 +445,46 @@ def agents_completion_openai_compatibility(tenant_id, agent_id): def agent_completions(tenant_id, agent_id): req = request.json + ans = {} + if req.get("stream", True): + + def generate(): + for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): + if isinstance(answer, str): + try: + ans = json.loads(answer[5:]) # remove "data:" + except Exception: + continue + + if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None): + continue + + yield answer + + yield "data:[DONE]\n\n" if req.get("stream", True): - resp = Response(agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req), mimetype="text/event-stream") + resp = Response(generate(), mimetype="text/event-stream") resp.headers.add_header("Cache-control", "no-cache") resp.headers.add_header("Connection", "keep-alive") resp.headers.add_header("X-Accel-Buffering", "no") resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") return resp - result = {} + + full_content = "" for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req): try: - ans = json.loads(answer[5:]) # remove "data:" - if not result: - result = ans.copy() - else: - result["data"]["answer"] += ans["data"]["answer"] - result["data"]["reference"] = ans["data"].get("reference", []) + ans = json.loads(answer[5:]) + + if ans["event"] == "message": + full_content += ans["data"]["content"] + + if ans.get("data", {}).get("reference", None): + ans["data"]["content"] = full_content + return get_result(data=ans) except Exception as e: - return get_error_data_result(str(e)) - return result + return get_result(data=f"**ERROR**: {str(e)}") + return get_result(data=ans) @manager.route("/chats//sessions", methods=["GET"]) # noqa: F821 @@ -556,10 +579,7 @@ def list_agent_session(tenant_id, agent_id): if message_num != 0 and messages[message_num]["role"] != "user": chunk_list = [] # Add boundary and type checks to prevent KeyError - if (chunk_num < len(conv["reference"]) and - conv["reference"][chunk_num] is not None and - isinstance(conv["reference"][chunk_num], dict) and - "chunks" in conv["reference"][chunk_num]): + if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance(conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]: chunks = conv["reference"][chunk_num]["chunks"] for chunk in chunks: # Ensure chunk is a dictionary before calling get method @@ -860,15 +880,7 @@ def begin_inputs(agent_id): return get_error_data_result(f"Can't find agent by ID: {agent_id}") canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id) - return get_result( - data={ - "title": cvs.title, - "avatar": cvs.avatar, - "inputs": canvas.get_component_input_form("begin"), - "prologue": canvas.get_prologue(), - "mode": canvas.get_mode() - } - ) + return get_result(data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"), "prologue": canvas.get_prologue(), "mode": canvas.get_mode()}) @manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821 @@ -908,7 +920,7 @@ def ask_about_embedded(): return resp -@manager.route("/searchbots/retrieval_test", methods=['POST']) # noqa: F821 +@manager.route("/searchbots/retrieval_test", methods=["POST"]) # noqa: F821 @validate_request("kb_id", "question") def retrieval_test_embedded(): token = request.headers.get("Authorization").split() @@ -938,18 +950,30 @@ def retrieval_test_embedded(): if not tenant_id: return get_error_data_result(message="permission denined.") + if req.get("search_id", ""): + search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {}) + meta_data_filter = search_config.get("meta_data_filter", {}) + metas = DocumentService.get_meta_by_kbs(kb_ids) + if meta_data_filter.get("method") == "auto": + chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", "")) + filters = gen_meta_filter(chat_mdl, metas, question) + doc_ids.extend(meta_filter(metas, filters)) + if not doc_ids: + doc_ids = None + elif meta_data_filter.get("method") == "manual": + doc_ids.extend(meta_filter(metas, meta_data_filter["manual"])) + if not doc_ids: + doc_ids = None + try: tenants = UserTenantService.query(user_id=tenant_id) for kb_id in kb_ids: for tenant in tenants: - if KnowledgebaseService.query( - tenant_id=tenant.tenant_id, id=kb_id): + if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id): tenant_ids.append(tenant.tenant_id) break else: - return get_json_result( - data=False, message='Only owner of knowledgebase authorized for this operation.', - code=settings.RetCode.OPERATING_ERROR) + return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR) e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) if not e: @@ -969,17 +993,11 @@ def retrieval_test_embedded(): question += keyword_extraction(chat_mdl, question) labels = label_question(question, [kb]) - ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size, - similarity_threshold, vector_similarity_weight, top, - doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), - rank_feature=labels - ) + ranks = settings.retrievaler.retrieval( + question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels + ) if use_kg: - ck = settings.kg_retrievaler.retrieval(question, - tenant_ids, - kb_ids, - embd_mdl, - LLMBundle(kb.tenant_id, LLMType.CHAT)) + ck = settings.kg_retrievaler.retrieval(question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT)) if ck["content_with_weight"]: ranks["chunks"].insert(0, ck) @@ -990,8 +1008,7 @@ def retrieval_test_embedded(): return get_json_result(data=ranks) except Exception as e: if str(e).find("not_found") > 0: - return get_json_result(data=False, message='No chunk found! Check the chunk status please!', - code=settings.RetCode.DATA_ERROR) + return get_json_result(data=False, message="No chunk found! Check the chunk status please!", code=settings.RetCode.DATA_ERROR) return server_error_response(e) diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index e5202d8fb..578287e7e 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -135,24 +135,6 @@ class UserCanvasService(CommonService): return True -def structure_answer(conv, ans, message_id, session_id): - if not conv: - return ans - content = "" - if ans["event"] == "message": - if ans["data"].get("start_to_think") is True: - content = "" - elif ans["data"].get("end_to_think") is True: - content = "" - else: - content = ans["data"]["content"] - - reference = ans["data"].get("reference") - result = {"id": message_id, "session_id": session_id, "answer": content} - if reference: - result["reference"] = [reference] - return result - def completion(tenant_id, agent_id, session_id=None, **kwargs): query = kwargs.get("query", "") or kwargs.get("question", "") files = kwargs.get("files", []) @@ -196,14 +178,13 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs): }) txt = "" for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): - ans = structure_answer(conv, ans, message_id, session_id) - txt += ans["answer"] - if ans.get("answer") or ans.get("reference"): - yield "data:" + json.dumps({"code": 0, "data": ans}, - ensure_ascii=False) + "\n\n" + ans["session_id"] = session_id + if ans["event"] == "message": + txt += ans["data"]["content"] + yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n" conv.message.append({"role": "assistant", "content": txt, "created_at": time.time(), "id": message_id}) - conv.reference.append(canvas.get_reference()) + conv.reference = canvas.get_reference() conv.errors = canvas.error conv.dsl = str(canvas) conv = conv.to_dict() @@ -232,9 +213,9 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True except Exception as e: logging.exception(f"Agent OpenAI-Compatible completionOpenAI parse answer failed: {e}") continue - if not ans["data"]["answer"]: + if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None): continue - content_piece = ans["data"]["answer"] + content_piece = ans["data"]["content"] completion_tokens += len(tiktokenenc.encode(content_piece)) yield "data: " + json.dumps( @@ -279,9 +260,9 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True ): if isinstance(ans, str): ans = json.loads(ans[5:]) - if not ans["data"]["answer"]: + if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None): continue - all_content += ans["data"]["answer"] + all_content += ans["data"]["content"] completion_tokens = len(tiktokenenc.encode(all_content))