Revert: broken agent completion by #9631 (#9760)

### What problem does this PR solve?

Reverts the broken agent completion introduced by #9631.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
Author: Yongteng Lei
Date: 2025-08-27 17:16:55 +08:00
Committed by: GitHub
Parent: 986b9cbb1a
Commit: 6cb3e08381
2 changed files with 71 additions and 73 deletions

View File

```diff
@@ -16,8 +16,10 @@
 import json
 import re
 import time
 import tiktoken
 from flask import Response, jsonify, request
 from agent.canvas import Canvas
 from api import settings
 from api.db import LLMType, StatusEnum
@@ -27,7 +29,8 @@ from api.db.services.canvas_service import UserCanvasService, completionOpenAI
 from api.db.services.canvas_service import completion as agent_completion
 from api.db.services.conversation_service import ConversationService, iframe_completion
 from api.db.services.conversation_service import completion as rag_completion
-from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
+from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter
+from api.db.services.document_service import DocumentService
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
 from api.db.services.search_service import SearchService
@@ -37,7 +40,7 @@ from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_
 from rag.app.tag import label_question
 from rag.prompts import chunks_format
 from rag.prompts.prompt_template import load_prompt
-from rag.prompts.prompts import cross_languages, keyword_extraction
+from rag.prompts.prompts import cross_languages, gen_meta_filter, keyword_extraction
 
 
 @manager.route("/chats/<chat_id>/sessions", methods=["POST"])  # noqa: F821
@@ -81,10 +84,10 @@ def create_agent_session(tenant_id, agent_id):
     if not isinstance(cvs.dsl, str):
         cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
-    session_id=get_uuid()
+    session_id = get_uuid()
     canvas = Canvas(cvs.dsl, tenant_id, agent_id)
     canvas.reset()
     cvs.dsl = json.loads(str(canvas))
     conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, "message": [{"role": "assistant", "content": canvas.get_prologue()}], "source": "agent", "dsl": cvs.dsl}
     API4ConversationService.save(**conv)
@@ -442,26 +445,46 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
 def agent_completions(tenant_id, agent_id):
     req = request.json
+    ans = {}
+    if req.get("stream", True):
+
+        def generate():
+            for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
+                if isinstance(answer, str):
+                    try:
+                        ans = json.loads(answer[5:])  # remove "data:"
+                    except Exception:
+                        continue
+                    if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None):
+                        continue
+                yield answer
+            yield "data:[DONE]\n\n"
+
     if req.get("stream", True):
-        resp = Response(agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req), mimetype="text/event-stream")
+        resp = Response(generate(), mimetype="text/event-stream")
         resp.headers.add_header("Cache-control", "no-cache")
         resp.headers.add_header("Connection", "keep-alive")
         resp.headers.add_header("X-Accel-Buffering", "no")
         resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
         return resp
-    result = {}
+
+    full_content = ""
     for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
         try:
-            ans = json.loads(answer[5:])  # remove "data:"
-            if not result:
-                result = ans.copy()
-            else:
-                result["data"]["answer"] += ans["data"]["answer"]
-                result["data"]["reference"] = ans["data"].get("reference", [])
+            ans = json.loads(answer[5:])
+
+            if ans["event"] == "message":
+                full_content += ans["data"]["content"]
+
+            if ans.get("data", {}).get("reference", None):
+                ans["data"]["content"] = full_content
+                return get_result(data=ans)
         except Exception as e:
-            return get_error_data_result(str(e))
-    return result
+            return get_result(data=f"**ERROR**: {str(e)}")
+    return get_result(data=ans)
 
 
 @manager.route("/chats/<chat_id>/sessions", methods=["GET"])  # noqa: F821
@@ -556,10 +579,7 @@ def list_agent_session(tenant_id, agent_id):
         if message_num != 0 and messages[message_num]["role"] != "user":
             chunk_list = []
             # Add boundary and type checks to prevent KeyError
-            if (chunk_num < len(conv["reference"]) and
-                    conv["reference"][chunk_num] is not None and
-                    isinstance(conv["reference"][chunk_num], dict) and
-                    "chunks" in conv["reference"][chunk_num]):
+            if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance(conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]:
                 chunks = conv["reference"][chunk_num]["chunks"]
                 for chunk in chunks:
                     # Ensure chunk is a dictionary before calling get method
@@ -860,15 +880,7 @@ def begin_inputs(agent_id):
         return get_error_data_result(f"Can't find agent by ID: {agent_id}")
 
     canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id)
-    return get_result(
-        data={
-            "title": cvs.title,
-            "avatar": cvs.avatar,
-            "inputs": canvas.get_component_input_form("begin"),
-            "prologue": canvas.get_prologue(),
-            "mode": canvas.get_mode()
-        }
-    )
+    return get_result(data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"), "prologue": canvas.get_prologue(), "mode": canvas.get_mode()})
 
 
 @manager.route("/searchbots/ask", methods=["POST"])  # noqa: F821
@@ -908,7 +920,7 @@ def ask_about_embedded():
     return resp
 
 
-@manager.route("/searchbots/retrieval_test", methods=['POST'])  # noqa: F821
+@manager.route("/searchbots/retrieval_test", methods=["POST"])  # noqa: F821
 @validate_request("kb_id", "question")
 def retrieval_test_embedded():
     token = request.headers.get("Authorization").split()
@@ -938,18 +950,30 @@ def retrieval_test_embedded():
     if not tenant_id:
         return get_error_data_result(message="permission denined.")
 
+    if req.get("search_id", ""):
+        search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
+        meta_data_filter = search_config.get("meta_data_filter", {})
+        metas = DocumentService.get_meta_by_kbs(kb_ids)
+        if meta_data_filter.get("method") == "auto":
+            chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
+            filters = gen_meta_filter(chat_mdl, metas, question)
+            doc_ids.extend(meta_filter(metas, filters))
+            if not doc_ids:
+                doc_ids = None
+        elif meta_data_filter.get("method") == "manual":
+            doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
+            if not doc_ids:
+                doc_ids = None
+
     try:
         tenants = UserTenantService.query(user_id=tenant_id)
         for kb_id in kb_ids:
             for tenant in tenants:
-                if KnowledgebaseService.query(
-                        tenant_id=tenant.tenant_id, id=kb_id):
+                if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
                     tenant_ids.append(tenant.tenant_id)
                     break
             else:
-                return get_json_result(
-                    data=False, message='Only owner of knowledgebase authorized for this operation.',
-                    code=settings.RetCode.OPERATING_ERROR)
+                return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
 
         e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
         if not e:
@@ -969,17 +993,11 @@ def retrieval_test_embedded():
             question += keyword_extraction(chat_mdl, question)
 
         labels = label_question(question, [kb])
-        ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
-                                               similarity_threshold, vector_similarity_weight, top,
-                                               doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
-                                               rank_feature=labels
-                                               )
+        ranks = settings.retrievaler.retrieval(
+            question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
+        )
         if use_kg:
-            ck = settings.kg_retrievaler.retrieval(question,
-                                                   tenant_ids,
-                                                   kb_ids,
-                                                   embd_mdl,
-                                                   LLMBundle(kb.tenant_id, LLMType.CHAT))
+            ck = settings.kg_retrievaler.retrieval(question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
             if ck["content_with_weight"]:
                 ranks["chunks"].insert(0, ck)
@@ -990,8 +1008,7 @@ def retrieval_test_embedded():
 
         return get_json_result(data=ranks)
     except Exception as e:
         if str(e).find("not_found") > 0:
-            return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
-                                   code=settings.RetCode.DATA_ERROR)
+            return get_json_result(data=False, message="No chunk found! Check the chunk status please!", code=settings.RetCode.DATA_ERROR)
         return server_error_response(e)
```

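Note on the restored streaming behavior in the `agent_completions` changes above: the endpoint emits Server-Sent-Events-style frames, each a `data:`-prefixed JSON object with `event` and `data` fields, and closes the stream with a literal `data:[DONE]` frame. Below is a minimal client-side sketch of consuming such a stream; the base URL, endpoint path, and API key are illustrative assumptions and not part of this diff, while the `data:` prefix, the event shape, and the `[DONE]` sentinel follow the code above.

```python
# Minimal sketch (not part of the commit) of reading the SSE stream restored above.
# BASE_URL, API_KEY, AGENT_ID, and the endpoint path are assumptions for illustration.
import json
import requests

BASE_URL = "http://localhost:9380"   # assumed RAGFlow host
API_KEY = "ragflow-xxxx"             # assumed API key
AGENT_ID = "your-agent-id"           # assumed agent id

resp = requests.post(
    f"{BASE_URL}/api/v1/agents/{AGENT_ID}/completions",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={"question": "Hello", "stream": True},
    stream=True,
)

for raw in resp.iter_lines(decode_unicode=True):
    if not raw or not raw.startswith("data:"):
        continue                      # skip keep-alive blank lines
    payload = raw[5:].strip()         # drop the "data:" prefix
    if payload == "[DONE]":
        break                         # end-of-stream sentinel yielded by generate()
    event = json.loads(payload)       # {"event": ..., "data": {...}}
    if event.get("event") == "message":
        print(event["data"]["content"], end="", flush=True)
```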
View File

```diff
@@ -135,24 +135,6 @@ class UserCanvasService(CommonService):
         return True
 
 
-def structure_answer(conv, ans, message_id, session_id):
-    if not conv:
-        return ans
-
-    content = ""
-    if ans["event"] == "message":
-        if ans["data"].get("start_to_think") is True:
-            content = "<think>"
-        elif ans["data"].get("end_to_think") is True:
-            content = "</think>"
-        else:
-            content = ans["data"]["content"]
-
-    reference = ans["data"].get("reference")
-    result = {"id": message_id, "session_id": session_id, "answer": content}
-    if reference:
-        result["reference"] = [reference]
-    return result
-
-
 def completion(tenant_id, agent_id, session_id=None, **kwargs):
     query = kwargs.get("query", "") or kwargs.get("question", "")
     files = kwargs.get("files", [])
@@ -196,14 +178,13 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs):
         })
 
     txt = ""
     for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
-        ans = structure_answer(conv, ans, message_id, session_id)
-        txt += ans["answer"]
-        if ans.get("answer") or ans.get("reference"):
-            yield "data:" + json.dumps({"code": 0, "data": ans},
-                                       ensure_ascii=False) + "\n\n"
+        ans["session_id"] = session_id
+        if ans["event"] == "message":
+            txt += ans["data"]["content"]
+        yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
 
     conv.message.append({"role": "assistant", "content": txt, "created_at": time.time(), "id": message_id})
-    conv.reference.append(canvas.get_reference())
+    conv.reference = canvas.get_reference()
     conv.errors = canvas.error
     conv.dsl = str(canvas)
     conv = conv.to_dict()
@@ -232,9 +213,9 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
                 except Exception as e:
                     logging.exception(f"Agent OpenAI-Compatible completionOpenAI parse answer failed: {e}")
                     continue
-            if not ans["data"]["answer"]:
+            if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None):
                 continue
-            content_piece = ans["data"]["answer"]
+            content_piece = ans["data"]["content"]
 
             completion_tokens += len(tiktokenenc.encode(content_piece))
             yield "data: " + json.dumps(
@@ -279,9 +260,9 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
         ):
             if isinstance(ans, str):
                 ans = json.loads(ans[5:])
-            if not ans["data"]["answer"]:
+            if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None):
                 continue
-            all_content += ans["data"]["answer"]
+            all_content += ans["data"]["content"]
 
             completion_tokens = len(tiktokenenc.encode(all_content))
```
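A small aside on the `completion_tokens` accounting in the `completionOpenAI` hunks above: the counter simply accumulates the tiktoken-encoded length of each streamed piece. The sketch below reproduces that arithmetic in isolation; the encoding name is an assumption, since the diff only shows that `tiktokenenc` exposes `.encode()`, not how it is constructed.

```python
# Rough, self-contained sketch of the token counting used in completionOpenAI.
# "cl100k_base" is an assumed encoding; the commit does not show which one is used.
import tiktoken

tiktokenenc = tiktoken.get_encoding("cl100k_base")  # assumed encoding

completion_tokens = 0
for content_piece in ["Hello", ", world", "!"]:  # stand-ins for streamed message chunks
    completion_tokens += len(tiktokenenc.encode(content_piece))

print(completion_tokens)  # total tokens across all streamed pieces
```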