Files
ragflow/api/db/services/canvas_service.py
天海蒼灆 5e2c33e5b0 Fix: grow reference list (#9674)
### What problem does this PR solve?

Fix: multiple conversations caused the reference list to grow indefinitely
due to Python's mutable default argument behavior.
Explicitly initialize `reference` as an empty list when creating new sessions.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
2025-08-25 14:08:15 +08:00

308 lines
11 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import logging
import time
from uuid import uuid4
from agent.canvas import Canvas
from api.db import TenantPermission
from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation
from api.db.services.api_service import API4ConversationService
from api.db.services.common_service import CommonService
from api.utils import get_uuid
from api.utils.api_utils import get_data_openai
import tiktoken
from peewee import fn
class CanvasTemplateService(CommonService):
    """Service class whose backing model is ``CanvasTemplate``.

    All behaviour is inherited from ``CommonService``; this class only binds
    the ``model`` attribute.
    """

    model = CanvasTemplate
class UserCanvasService(CommonService):
    """Query helpers for ``UserCanvas`` rows (agents / canvases)."""

    model = UserCanvas

    @classmethod
    @DB.connection_context()
    def get_list(cls, tenant_id,
                 page_number, items_per_page, orderby, desc, id, title):
        """Return one page of canvases whose ``user_id`` equals ``tenant_id``.

        Args:
            tenant_id: Value matched against ``UserCanvas.user_id``.
            page_number: Page index passed to ``paginate``.
            items_per_page: Page size passed to ``paginate``.
            orderby: Field name handed to ``model.getter_by`` to pick the
                sort column (helper defined outside this file).
            desc: Sort descending when truthy, ascending otherwise.
            id: Optional canvas id filter; ignored when falsy.
            title: Optional exact-title filter; ignored when falsy.

        Returns:
            A list of row dicts.
        """
        agents = cls.model.select()
        if id:
            agents = agents.where(cls.model.id == id)
        if title:
            agents = agents.where(cls.model.title == title)
        agents = agents.where(cls.model.user_id == tenant_id)

        sort_field = cls.model.getter_by(orderby)
        agents = agents.order_by(sort_field.desc() if desc else sort_field.asc())

        agents = agents.paginate(page_number, items_per_page)
        return list(agents.dicts())

    @classmethod
    @DB.connection_context()
    def get_by_tenant_id(cls, pid):
        """Fetch a single canvas joined with its owner's profile.

        NOTE: despite the method name, ``pid`` is matched against the canvas
        ``id`` column, not against a tenant id.

        Args:
            pid: Canvas id to look up.

        Returns:
            ``(True, row_dict)`` on success. ``(False, None)`` when anything
            raises, including indexing an empty result; the exception is
            logged.
        """
        try:
            fields = [
                cls.model.id,
                cls.model.avatar,
                cls.model.title,
                cls.model.dsl,
                cls.model.description,
                cls.model.permission,
                cls.model.update_time,
                cls.model.user_id,
                cls.model.create_time,
                cls.model.create_date,
                cls.model.update_date,
                User.nickname,
                User.avatar.alias('tenant_avatar'),
            ]
            agents = cls.model.select(*fields) \
                .join(User, on=(cls.model.user_id == User.id)) \
                .where(cls.model.id == pid)
            return True, agents.dicts()[0]
        except Exception as e:
            logging.exception(e)
            return False, None

    @classmethod
    @DB.connection_context()
    def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
                          page_number, items_per_page,
                          orderby, desc, keywords,
                          ):
        """List canvases visible to ``user_id``, with paging and total count.

        A canvas is included when it is owned by ``user_id``, or when it is
        owned by one of ``joined_tenant_ids`` and its permission equals
        ``TenantPermission.TEAM``.

        Args:
            joined_tenant_ids: Owner ids whose TEAM-permission canvases are
                included.
            user_id: Owner id whose canvases are always included.
            page_number: Page index passed to ``paginate``.
            items_per_page: Page size passed to ``paginate``.
            orderby: Field name handed to ``model.getter_by``.
            desc: Sort descending when truthy, ascending otherwise.
            keywords: Optional case-insensitive substring filter on the title.

        Returns:
            ``(rows, count)`` where ``rows`` is the requested page as a list
            of dicts and ``count`` is the number of matches before paging.
        """
        fields = [
            cls.model.id,
            cls.model.avatar,
            cls.model.title,
            cls.model.dsl,
            cls.model.description,
            cls.model.permission,
            User.nickname,
            User.avatar.alias('tenant_avatar'),
            cls.model.update_time
        ]
        # Built once so the keyword and non-keyword paths cannot drift apart.
        visible = (
            (cls.model.user_id.in_(joined_tenant_ids)
             & (cls.model.permission == TenantPermission.TEAM.value))
            | (cls.model.user_id == user_id)
        )
        agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id))
        if keywords:
            agents = agents.where(
                visible,
                (fn.LOWER(cls.model.title).contains(keywords.lower()))
            )
        else:
            agents = agents.where(visible)

        sort_field = cls.model.getter_by(orderby)
        agents = agents.order_by(sort_field.desc() if desc else sort_field.asc())

        # Count before paginating so the total covers every match.
        count = agents.count()
        agents = agents.paginate(page_number, items_per_page)
        return list(agents.dicts()), count

    @classmethod
    @DB.connection_context()
    def accessible(cls, canvas_id, tenant_id):
        """Tell whether the user identified by ``tenant_id`` may use a canvas.

        Args:
            canvas_id: Id of the canvas to check.
            tenant_id: Id of the requesting user; it is passed as ``user_id``
                to ``UserTenantService.query``.

        Returns:
            ``False`` when the canvas cannot be loaded, or when its owner is
            neither the requesting user nor one of the tenants returned for
            that user; ``True`` otherwise.
        """
        # Imported here rather than at module level, as in the original code.
        from api.db.services.user_service import UserTenantService

        e, c = UserCanvasService.get_by_tenant_id(canvas_id)
        if not e:
            return False

        tids = [t.tenant_id for t in UserTenantService.query(user_id=tenant_id)]
        # Compare the owner with the requesting user. The previous code
        # compared it with canvas_id, which is an id of a different entity.
        if c["user_id"] != tenant_id and c["user_id"] not in tids:
            return False
        return True
def structure_answer(conv, ans, message_id, session_id):
    """Convert a raw canvas event into the payload streamed to the client.

    Args:
        conv: Conversation object; when falsy, ``ans`` is returned untouched.
        ans: Event dict with an ``"event"`` name and a ``"data"`` dict.
        message_id: Id placed in the result under ``"id"``.
        session_id: Id placed in the result under ``"session_id"``.

    Returns:
        ``ans`` itself when ``conv`` is falsy. Otherwise a dict with ``id``,
        ``session_id`` and ``answer``, plus ``reference`` (a one-element
        list) when the event data carries a truthy reference. ``answer`` is
        ``"<think>"`` / ``"</think>"`` for thinking markers, the message
        content for other ``"message"`` events, and ``""`` for any other
        event.
    """
    if not conv:
        return ans

    # Tolerate events without a data payload instead of raising KeyError.
    data = ans.get("data") or {}

    content = ""
    if ans.get("event") == "message":
        if data.get("start_to_think") is True:
            content = "<think>"
        elif data.get("end_to_think") is True:
            content = "</think>"
        else:
            # Missing or None content becomes "" so callers can safely
            # concatenate the answer onto a string.
            content = data.get("content") or ""

    result = {"id": message_id, "session_id": session_id, "answer": content}
    reference = data.get("reference")
    if reference:
        result["reference"] = [reference]
    return result
def completion(tenant_id, agent_id, session_id=None, **kwargs):
    """Run an agent canvas for one user turn and stream the answers.

    When ``session_id`` is given the stored conversation is resumed;
    otherwise a new conversation is created from the agent's DSL and saved.
    After the run finishes, the user and assistant messages, the canvas
    reference, errors and DSL are written back through
    ``API4ConversationService.append_message``.

    Args:
        tenant_id: Must equal the agent's ``user_id`` when a new session is
            created; also passed to ``Canvas``.
        agent_id: Id of the agent (``UserCanvas``) to run.
        session_id: Existing conversation id, or ``None`` to start a new one.
        **kwargs: ``query`` (or ``question``), ``files``, ``inputs`` and
            ``user_id`` are read; other keys are ignored.

    Yields:
        Server-sent-event strings of the form ``"data:<json>\\n\\n"`` for
        every event that has a non-empty answer or a reference.

    Raises:
        AssertionError: If the session or agent is not found, or the agent
            is not owned by ``tenant_id``.
    """
    query = kwargs.get("query", "") or kwargs.get("question", "")
    files = kwargs.get("files", [])
    inputs = kwargs.get("inputs", {})
    user_id = kwargs.get("user_id", "")

    # Raise explicitly rather than use `assert`, which is removed under
    # `python -O`. AssertionError is kept so existing callers see the same
    # exception type.
    if session_id:
        found, conv = API4ConversationService.get_by_id(session_id)
        if not found:
            raise AssertionError("Session not found!")
        if not conv.message:
            conv.message = []
        # A stored session may carry an empty/None reference; make sure the
        # append after the run has a list to work on.
        if not conv.reference:
            conv.reference = []
        if not isinstance(conv.dsl, str):
            conv.dsl = json.dumps(conv.dsl, ensure_ascii=False)
        canvas = Canvas(conv.dsl, tenant_id, agent_id)
    else:
        found, cvs = UserCanvasService.get_by_id(agent_id)
        if not found:
            raise AssertionError("Agent not found.")
        if cvs.user_id != tenant_id:
            raise AssertionError("You do not own the agent.")
        if not isinstance(cvs.dsl, str):
            cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
        session_id = get_uuid()
        canvas = Canvas(cvs.dsl, tenant_id, agent_id)
        canvas.reset()
        conv = {
            "id": session_id,
            "dialog_id": cvs.id,
            "user_id": user_id,
            "message": [],
            "source": "agent",
            "dsl": cvs.dsl,
            # Fresh list per session so references are never shared.
            "reference": []
        }
        API4ConversationService.save(**conv)
        conv = API4Conversation(**conv)

    message_id = str(uuid4())
    conv.message.append({
        "role": "user",
        "content": query,
        "id": message_id
    })

    txt = ""
    for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
        ans = structure_answer(conv, ans, message_id, session_id)
        txt += ans["answer"]
        # Events with neither text nor a reference are not sent to the client.
        if ans.get("answer") or ans.get("reference"):
            yield "data:" + json.dumps({"code": 0, "data": ans},
                                       ensure_ascii=False) + "\n\n"

    # Persist the completed turn together with the canvas state.
    conv.message.append({"role": "assistant", "content": txt, "created_at": time.time(), "id": message_id})
    conv.reference.append(canvas.get_reference())
    conv.errors = canvas.error
    conv.dsl = str(canvas)
    conv = conv.to_dict()
    API4ConversationService.append_message(conv["id"], conv)
def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
    """Run an agent and emit its output in OpenAI-compatible form.

    This is a generator in both modes.

    Args:
        tenant_id: Forwarded to ``completion``.
        agent_id: Forwarded to ``completion`` and reported as ``model``.
        question: User question; forwarded to ``completion`` as ``query``.
        session_id: Optional conversation id, also used as the response id.
        stream: When truthy, yield SSE strings chunk by chunk; otherwise
            yield a single response built by ``get_data_openai``.
        **kwargs: Extra options forwarded to ``completion``. ``user_id`` is
            read from here.

    Yields:
        Streaming: ``"data: <json>\\n\\n"`` strings, one per non-empty answer
        piece, followed by ``"data: [DONE]\\n\\n"``. Non-streaming: one value
        returned by ``get_data_openai`` holding the full answer. In both
        modes an exception is reported as a ``**ERROR**: ...`` response
        instead of being raised.
    """
    tiktokenenc = tiktoken.get_encoding("cl100k_base")
    prompt_tokens = len(tiktokenenc.encode(str(question)))
    user_id = kwargs.get("user_id", "")

    # `query` and `user_id` are passed explicitly below. Forwarding them a
    # second time through **kwargs would raise "got multiple values for
    # keyword argument".
    forwarded = {k: v for k, v in kwargs.items() if k not in ("query", "user_id")}

    # Compute the response id once so every chunk of one reply shares it,
    # even when no session id was supplied.
    response_id = session_id or str(uuid4())

    if stream:
        completion_tokens = 0
        try:
            for ans in completion(
                tenant_id=tenant_id,
                agent_id=agent_id,
                session_id=session_id,
                query=question,
                user_id=user_id,
                **forwarded
            ):
                if isinstance(ans, str):
                    try:
                        ans = json.loads(ans[5:])  # remove "data:"
                    except Exception as e:
                        logging.exception(f"Agent OpenAI-Compatible completionOpenAI parse answer failed: {e}")
                        continue
                if not ans["data"]["answer"]:
                    continue
                content_piece = ans["data"]["answer"]
                # Running total of tokens emitted so far.
                completion_tokens += len(tiktokenenc.encode(content_piece))
                yield "data: " + json.dumps(
                    get_data_openai(
                        id=response_id,
                        model=agent_id,
                        content=content_piece,
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                        stream=True
                    ),
                    ensure_ascii=False
                ) + "\n\n"
            yield "data: [DONE]\n\n"
        except Exception as e:
            error_text = f"**ERROR**: {str(e)}"
            yield "data: " + json.dumps(
                get_data_openai(
                    id=response_id,
                    model=agent_id,
                    content=error_text,
                    finish_reason="stop",
                    prompt_tokens=prompt_tokens,
                    completion_tokens=len(tiktokenenc.encode(error_text)),
                    stream=True
                ),
                ensure_ascii=False
            ) + "\n\n"
            yield "data: [DONE]\n\n"
    else:
        try:
            all_content = ""
            for ans in completion(
                tenant_id=tenant_id,
                agent_id=agent_id,
                session_id=session_id,
                query=question,
                user_id=user_id,
                **forwarded
            ):
                if isinstance(ans, str):
                    ans = json.loads(ans[5:])
                if not ans["data"]["answer"]:
                    continue
                all_content += ans["data"]["answer"]
            completion_tokens = len(tiktokenenc.encode(all_content))
            yield get_data_openai(
                id=response_id,
                model=agent_id,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                content=all_content,
                finish_reason="stop",
                param=None
            )
        except Exception as e:
            error_text = f"**ERROR**: {str(e)}"
            yield get_data_openai(
                id=response_id,
                model=agent_id,
                prompt_tokens=prompt_tokens,
                completion_tokens=len(tiktokenenc.encode(error_text)),
                content=error_text,
                finish_reason="stop",
                param=None
            )