mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
### What problem does this PR solve? Revert broken agent completion by #9631. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@ -16,8 +16,10 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from flask import Response, jsonify, request
|
from flask import Response, jsonify, request
|
||||||
|
|
||||||
from agent.canvas import Canvas
|
from agent.canvas import Canvas
|
||||||
from api import settings
|
from api import settings
|
||||||
from api.db import LLMType, StatusEnum
|
from api.db import LLMType, StatusEnum
|
||||||
@ -27,7 +29,8 @@ from api.db.services.canvas_service import UserCanvasService, completionOpenAI
|
|||||||
from api.db.services.canvas_service import completion as agent_completion
|
from api.db.services.canvas_service import completion as agent_completion
|
||||||
from api.db.services.conversation_service import ConversationService, iframe_completion
|
from api.db.services.conversation_service import ConversationService, iframe_completion
|
||||||
from api.db.services.conversation_service import completion as rag_completion
|
from api.db.services.conversation_service import completion as rag_completion
|
||||||
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap
|
from api.db.services.dialog_service import DialogService, ask, chat, gen_mindmap, meta_filter
|
||||||
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||||
from api.db.services.llm_service import LLMBundle
|
from api.db.services.llm_service import LLMBundle
|
||||||
from api.db.services.search_service import SearchService
|
from api.db.services.search_service import SearchService
|
||||||
@ -37,7 +40,7 @@ from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_
|
|||||||
from rag.app.tag import label_question
|
from rag.app.tag import label_question
|
||||||
from rag.prompts import chunks_format
|
from rag.prompts import chunks_format
|
||||||
from rag.prompts.prompt_template import load_prompt
|
from rag.prompts.prompt_template import load_prompt
|
||||||
from rag.prompts.prompts import cross_languages, keyword_extraction
|
from rag.prompts.prompts import cross_languages, gen_meta_filter, keyword_extraction
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||||
@ -81,7 +84,7 @@ def create_agent_session(tenant_id, agent_id):
|
|||||||
if not isinstance(cvs.dsl, str):
|
if not isinstance(cvs.dsl, str):
|
||||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||||
|
|
||||||
session_id=get_uuid()
|
session_id = get_uuid()
|
||||||
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
|
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
|
||||||
canvas.reset()
|
canvas.reset()
|
||||||
|
|
||||||
@ -442,26 +445,46 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
|
|||||||
def agent_completions(tenant_id, agent_id):
|
def agent_completions(tenant_id, agent_id):
|
||||||
req = request.json
|
req = request.json
|
||||||
|
|
||||||
|
ans = {}
|
||||||
|
if req.get("stream", True):
|
||||||
|
|
||||||
|
def generate():
|
||||||
|
for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||||
|
if isinstance(answer, str):
|
||||||
|
try:
|
||||||
|
ans = json.loads(answer[5:]) # remove "data:"
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None):
|
||||||
|
continue
|
||||||
|
|
||||||
|
yield answer
|
||||||
|
|
||||||
|
yield "data:[DONE]\n\n"
|
||||||
|
|
||||||
if req.get("stream", True):
|
if req.get("stream", True):
|
||||||
resp = Response(agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req), mimetype="text/event-stream")
|
resp = Response(generate(), mimetype="text/event-stream")
|
||||||
resp.headers.add_header("Cache-control", "no-cache")
|
resp.headers.add_header("Cache-control", "no-cache")
|
||||||
resp.headers.add_header("Connection", "keep-alive")
|
resp.headers.add_header("Connection", "keep-alive")
|
||||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||||
return resp
|
return resp
|
||||||
result = {}
|
|
||||||
|
full_content = ""
|
||||||
for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||||
try:
|
try:
|
||||||
ans = json.loads(answer[5:]) # remove "data:"
|
ans = json.loads(answer[5:])
|
||||||
if not result:
|
|
||||||
result = ans.copy()
|
if ans["event"] == "message":
|
||||||
else:
|
full_content += ans["data"]["content"]
|
||||||
result["data"]["answer"] += ans["data"]["answer"]
|
|
||||||
result["data"]["reference"] = ans["data"].get("reference", [])
|
if ans.get("data", {}).get("reference", None):
|
||||||
|
ans["data"]["content"] = full_content
|
||||||
|
return get_result(data=ans)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return get_error_data_result(str(e))
|
return get_result(data=f"**ERROR**: {str(e)}")
|
||||||
return result
|
return get_result(data=ans)
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
|
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
|
||||||
@ -556,10 +579,7 @@ def list_agent_session(tenant_id, agent_id):
|
|||||||
if message_num != 0 and messages[message_num]["role"] != "user":
|
if message_num != 0 and messages[message_num]["role"] != "user":
|
||||||
chunk_list = []
|
chunk_list = []
|
||||||
# Add boundary and type checks to prevent KeyError
|
# Add boundary and type checks to prevent KeyError
|
||||||
if (chunk_num < len(conv["reference"]) and
|
if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance(conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]:
|
||||||
conv["reference"][chunk_num] is not None and
|
|
||||||
isinstance(conv["reference"][chunk_num], dict) and
|
|
||||||
"chunks" in conv["reference"][chunk_num]):
|
|
||||||
chunks = conv["reference"][chunk_num]["chunks"]
|
chunks = conv["reference"][chunk_num]["chunks"]
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
# Ensure chunk is a dictionary before calling get method
|
# Ensure chunk is a dictionary before calling get method
|
||||||
@ -860,15 +880,7 @@ def begin_inputs(agent_id):
|
|||||||
return get_error_data_result(f"Can't find agent by ID: {agent_id}")
|
return get_error_data_result(f"Can't find agent by ID: {agent_id}")
|
||||||
|
|
||||||
canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id)
|
canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id)
|
||||||
return get_result(
|
return get_result(data={"title": cvs.title, "avatar": cvs.avatar, "inputs": canvas.get_component_input_form("begin"), "prologue": canvas.get_prologue(), "mode": canvas.get_mode()})
|
||||||
data={
|
|
||||||
"title": cvs.title,
|
|
||||||
"avatar": cvs.avatar,
|
|
||||||
"inputs": canvas.get_component_input_form("begin"),
|
|
||||||
"prologue": canvas.get_prologue(),
|
|
||||||
"mode": canvas.get_mode()
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
|
@manager.route("/searchbots/ask", methods=["POST"]) # noqa: F821
|
||||||
@ -908,7 +920,7 @@ def ask_about_embedded():
|
|||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/searchbots/retrieval_test", methods=['POST']) # noqa: F821
|
@manager.route("/searchbots/retrieval_test", methods=["POST"]) # noqa: F821
|
||||||
@validate_request("kb_id", "question")
|
@validate_request("kb_id", "question")
|
||||||
def retrieval_test_embedded():
|
def retrieval_test_embedded():
|
||||||
token = request.headers.get("Authorization").split()
|
token = request.headers.get("Authorization").split()
|
||||||
@ -938,18 +950,30 @@ def retrieval_test_embedded():
|
|||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
return get_error_data_result(message="permission denined.")
|
return get_error_data_result(message="permission denined.")
|
||||||
|
|
||||||
|
if req.get("search_id", ""):
|
||||||
|
search_config = SearchService.get_detail(req.get("search_id", "")).get("search_config", {})
|
||||||
|
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||||
|
metas = DocumentService.get_meta_by_kbs(kb_ids)
|
||||||
|
if meta_data_filter.get("method") == "auto":
|
||||||
|
chat_mdl = LLMBundle(tenant_id, LLMType.CHAT, llm_name=search_config.get("chat_id", ""))
|
||||||
|
filters = gen_meta_filter(chat_mdl, metas, question)
|
||||||
|
doc_ids.extend(meta_filter(metas, filters))
|
||||||
|
if not doc_ids:
|
||||||
|
doc_ids = None
|
||||||
|
elif meta_data_filter.get("method") == "manual":
|
||||||
|
doc_ids.extend(meta_filter(metas, meta_data_filter["manual"]))
|
||||||
|
if not doc_ids:
|
||||||
|
doc_ids = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tenants = UserTenantService.query(user_id=tenant_id)
|
tenants = UserTenantService.query(user_id=tenant_id)
|
||||||
for kb_id in kb_ids:
|
for kb_id in kb_ids:
|
||||||
for tenant in tenants:
|
for tenant in tenants:
|
||||||
if KnowledgebaseService.query(
|
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||||
tenant_id=tenant.tenant_id, id=kb_id):
|
|
||||||
tenant_ids.append(tenant.tenant_id)
|
tenant_ids.append(tenant.tenant_id)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
return get_json_result(
|
return get_json_result(data=False, message="Only owner of knowledgebase authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
|
||||||
data=False, message='Only owner of knowledgebase authorized for this operation.',
|
|
||||||
code=settings.RetCode.OPERATING_ERROR)
|
|
||||||
|
|
||||||
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
||||||
if not e:
|
if not e:
|
||||||
@ -969,17 +993,11 @@ def retrieval_test_embedded():
|
|||||||
question += keyword_extraction(chat_mdl, question)
|
question += keyword_extraction(chat_mdl, question)
|
||||||
|
|
||||||
labels = label_question(question, [kb])
|
labels = label_question(question, [kb])
|
||||||
ranks = settings.retrievaler.retrieval(question, embd_mdl, tenant_ids, kb_ids, page, size,
|
ranks = settings.retrievaler.retrieval(
|
||||||
similarity_threshold, vector_similarity_weight, top,
|
question, embd_mdl, tenant_ids, kb_ids, page, size, similarity_threshold, vector_similarity_weight, top, doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||||
doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"),
|
|
||||||
rank_feature=labels
|
|
||||||
)
|
)
|
||||||
if use_kg:
|
if use_kg:
|
||||||
ck = settings.kg_retrievaler.retrieval(question,
|
ck = settings.kg_retrievaler.retrieval(question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, LLMType.CHAT))
|
||||||
tenant_ids,
|
|
||||||
kb_ids,
|
|
||||||
embd_mdl,
|
|
||||||
LLMBundle(kb.tenant_id, LLMType.CHAT))
|
|
||||||
if ck["content_with_weight"]:
|
if ck["content_with_weight"]:
|
||||||
ranks["chunks"].insert(0, ck)
|
ranks["chunks"].insert(0, ck)
|
||||||
|
|
||||||
@ -990,8 +1008,7 @@ def retrieval_test_embedded():
|
|||||||
return get_json_result(data=ranks)
|
return get_json_result(data=ranks)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if str(e).find("not_found") > 0:
|
if str(e).find("not_found") > 0:
|
||||||
return get_json_result(data=False, message='No chunk found! Check the chunk status please!',
|
return get_json_result(data=False, message="No chunk found! Check the chunk status please!", code=settings.RetCode.DATA_ERROR)
|
||||||
code=settings.RetCode.DATA_ERROR)
|
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -135,24 +135,6 @@ class UserCanvasService(CommonService):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def structure_answer(conv, ans, message_id, session_id):
|
|
||||||
if not conv:
|
|
||||||
return ans
|
|
||||||
content = ""
|
|
||||||
if ans["event"] == "message":
|
|
||||||
if ans["data"].get("start_to_think") is True:
|
|
||||||
content = "<think>"
|
|
||||||
elif ans["data"].get("end_to_think") is True:
|
|
||||||
content = "</think>"
|
|
||||||
else:
|
|
||||||
content = ans["data"]["content"]
|
|
||||||
|
|
||||||
reference = ans["data"].get("reference")
|
|
||||||
result = {"id": message_id, "session_id": session_id, "answer": content}
|
|
||||||
if reference:
|
|
||||||
result["reference"] = [reference]
|
|
||||||
return result
|
|
||||||
|
|
||||||
def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||||
query = kwargs.get("query", "") or kwargs.get("question", "")
|
query = kwargs.get("query", "") or kwargs.get("question", "")
|
||||||
files = kwargs.get("files", [])
|
files = kwargs.get("files", [])
|
||||||
@ -196,14 +178,13 @@ def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
|||||||
})
|
})
|
||||||
txt = ""
|
txt = ""
|
||||||
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||||
ans = structure_answer(conv, ans, message_id, session_id)
|
ans["session_id"] = session_id
|
||||||
txt += ans["answer"]
|
if ans["event"] == "message":
|
||||||
if ans.get("answer") or ans.get("reference"):
|
txt += ans["data"]["content"]
|
||||||
yield "data:" + json.dumps({"code": 0, "data": ans},
|
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||||
ensure_ascii=False) + "\n\n"
|
|
||||||
|
|
||||||
conv.message.append({"role": "assistant", "content": txt, "created_at": time.time(), "id": message_id})
|
conv.message.append({"role": "assistant", "content": txt, "created_at": time.time(), "id": message_id})
|
||||||
conv.reference.append(canvas.get_reference())
|
conv.reference = canvas.get_reference()
|
||||||
conv.errors = canvas.error
|
conv.errors = canvas.error
|
||||||
conv.dsl = str(canvas)
|
conv.dsl = str(canvas)
|
||||||
conv = conv.to_dict()
|
conv = conv.to_dict()
|
||||||
@ -232,9 +213,9 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception(f"Agent OpenAI-Compatible completionOpenAI parse answer failed: {e}")
|
logging.exception(f"Agent OpenAI-Compatible completionOpenAI parse answer failed: {e}")
|
||||||
continue
|
continue
|
||||||
if not ans["data"]["answer"]:
|
if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None):
|
||||||
continue
|
continue
|
||||||
content_piece = ans["data"]["answer"]
|
content_piece = ans["data"]["content"]
|
||||||
completion_tokens += len(tiktokenenc.encode(content_piece))
|
completion_tokens += len(tiktokenenc.encode(content_piece))
|
||||||
|
|
||||||
yield "data: " + json.dumps(
|
yield "data: " + json.dumps(
|
||||||
@ -279,9 +260,9 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
|
|||||||
):
|
):
|
||||||
if isinstance(ans, str):
|
if isinstance(ans, str):
|
||||||
ans = json.loads(ans[5:])
|
ans = json.loads(ans[5:])
|
||||||
if not ans["data"]["answer"]:
|
if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None):
|
||||||
continue
|
continue
|
||||||
all_content += ans["data"]["answer"]
|
all_content += ans["data"]["content"]
|
||||||
|
|
||||||
completion_tokens = len(tiktokenenc.encode(all_content))
|
completion_tokens = len(tiktokenenc.encode(all_content))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user