mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Refa: unify reference format of agent completion and OpenAI-compatible completion API (#9792)
### What problem does this PR solve?

Unify reference format of agent completion and OpenAI-compatible completion API.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Documentation Update
- [x] Refactoring
This commit is contained in:
@ -414,7 +414,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
tenant_id,
|
||||
agent_id,
|
||||
question,
|
||||
session_id=req.get("id", req.get("metadata", {}).get("id", "")),
|
||||
session_id=req.get("session_id", req.get("id", "") or req.get("metadata", {}).get("id", "")),
|
||||
stream=True,
|
||||
**req,
|
||||
),
|
||||
@ -432,7 +432,7 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
tenant_id,
|
||||
agent_id,
|
||||
question,
|
||||
session_id=req.get("id", req.get("metadata", {}).get("id", "")),
|
||||
session_id=req.get("session_id", req.get("id", "") or req.get("metadata", {}).get("id", "")),
|
||||
stream=False,
|
||||
**req,
|
||||
)
|
||||
@ -445,7 +445,6 @@ def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
def agent_completions(tenant_id, agent_id):
|
||||
req = request.json
|
||||
|
||||
ans = {}
|
||||
if req.get("stream", True):
|
||||
|
||||
def generate():
|
||||
@ -456,14 +455,13 @@ def agent_completions(tenant_id, agent_id):
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None):
|
||||
if ans.get("event") not in ["message", "message_end"]:
|
||||
continue
|
||||
|
||||
yield answer
|
||||
|
||||
yield "data:[DONE]\n\n"
|
||||
|
||||
if req.get("stream", True):
|
||||
resp = Response(generate(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
@ -472,6 +470,8 @@ def agent_completions(tenant_id, agent_id):
|
||||
return resp
|
||||
|
||||
full_content = ""
|
||||
reference = {}
|
||||
final_ans = ""
|
||||
for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
try:
|
||||
ans = json.loads(answer[5:])
|
||||
@ -480,11 +480,14 @@ def agent_completions(tenant_id, agent_id):
|
||||
full_content += ans["data"]["content"]
|
||||
|
||||
if ans.get("data", {}).get("reference", None):
|
||||
ans["data"]["content"] = full_content
|
||||
return get_result(data=ans)
|
||||
reference.update(ans["data"]["reference"])
|
||||
|
||||
final_ans = ans
|
||||
except Exception as e:
|
||||
return get_result(data=f"**ERROR**: {str(e)}")
|
||||
return get_result(data=ans)
|
||||
final_ans["data"]["content"] = full_content
|
||||
final_ans["data"]["reference"] = reference
|
||||
return get_result(data=final_ans)
|
||||
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
|
||||
|
||||
@ -213,26 +213,33 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
|
||||
except Exception as e:
|
||||
logging.exception(f"Agent OpenAI-Compatible completionOpenAI parse answer failed: {e}")
|
||||
continue
|
||||
if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None):
|
||||
if ans.get("event") not in ["message", "message_end"]:
|
||||
continue
|
||||
content_piece = ans["data"]["content"]
|
||||
|
||||
content_piece = ""
|
||||
if ans["event"] == "message":
|
||||
content_piece = ans["data"]["content"]
|
||||
|
||||
completion_tokens += len(tiktokenenc.encode(content_piece))
|
||||
|
||||
yield "data: " + json.dumps(
|
||||
get_data_openai(
|
||||
openai_data = get_data_openai(
|
||||
id=session_id or str(uuid4()),
|
||||
model=agent_id,
|
||||
content=content_piece,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
stream=True
|
||||
),
|
||||
ensure_ascii=False
|
||||
) + "\n\n"
|
||||
)
|
||||
|
||||
if ans.get("data", {}).get("reference", None):
|
||||
openai_data["choices"][0]["delta"]["reference"] = ans["data"]["reference"]
|
||||
|
||||
yield "data: " + json.dumps(openai_data, ensure_ascii=False) + "\n\n"
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
yield "data: " + json.dumps(
|
||||
get_data_openai(
|
||||
id=session_id or str(uuid4()),
|
||||
@ -250,6 +257,7 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
|
||||
else:
|
||||
try:
|
||||
all_content = ""
|
||||
reference = {}
|
||||
for ans in completion(
|
||||
tenant_id=tenant_id,
|
||||
agent_id=agent_id,
|
||||
@ -260,13 +268,18 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
|
||||
):
|
||||
if isinstance(ans, str):
|
||||
ans = json.loads(ans[5:])
|
||||
if ans.get("event") != "message" or not ans.get("data", {}).get("reference", None):
|
||||
if ans.get("event") not in ["message", "message_end"]:
|
||||
continue
|
||||
all_content += ans["data"]["content"]
|
||||
|
||||
if ans["event"] == "message":
|
||||
all_content += ans["data"]["content"]
|
||||
|
||||
if ans.get("data", {}).get("reference", None):
|
||||
reference.update(ans["data"]["reference"])
|
||||
|
||||
completion_tokens = len(tiktokenenc.encode(all_content))
|
||||
|
||||
yield get_data_openai(
|
||||
openai_data = get_data_openai(
|
||||
id=session_id or str(uuid4()),
|
||||
model=agent_id,
|
||||
prompt_tokens=prompt_tokens,
|
||||
@ -276,7 +289,12 @@ def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True
|
||||
param=None
|
||||
)
|
||||
|
||||
if reference:
|
||||
openai_data["choices"][0]["message"]["reference"] = reference
|
||||
|
||||
yield openai_data
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
yield get_data_openai(
|
||||
id=session_id or str(uuid4()),
|
||||
model=agent_id,
|
||||
|
||||
Reference in New Issue
Block a user