Feat: add canvas_category field for UserCanvas and CanvasTemplate (#9885)

### What problem does this PR solve?

Add `canvas_category` field for UserCanvas and CanvasTemplate.

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Yongteng Lei
2025-09-03 14:55:24 +08:00
committed by GitHub
parent 5d015e48c1
commit c832e0b858
4 changed files with 49 additions and 20 deletions

View File

@ -24,7 +24,7 @@ from flask import request, Response
from flask_login import login_required, current_user from flask_login import login_required, current_user
from agent.component import LLM from agent.component import LLM
from api.db import FileType from api.db import CanvasCategory, FileType
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
from api.db.services.document_service import DocumentService from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService from api.db.services.file_service import FileService
@ -45,14 +45,14 @@ from rag.utils.redis_conn import REDIS_CONN
@manager.route('/templates', methods=['GET']) # noqa: F821 @manager.route('/templates', methods=['GET']) # noqa: F821
@login_required @login_required
def templates(): def templates():
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()]) return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.Agent)])
@manager.route('/list', methods=['GET']) # noqa: F821 @manager.route('/list', methods=['GET']) # noqa: F821
@login_required @login_required
def canvas_list(): def canvas_list():
return get_json_result(data=sorted([c.to_dict() for c in \ return get_json_result(data=sorted([c.to_dict() for c in \
UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1) UserCanvasService.query(user_id=current_user.id, canvas_category=CanvasCategory.Agent)], key=lambda x: x["update_time"]*-1)
) )
@ -79,7 +79,7 @@ def save():
req["dsl"] = json.loads(req["dsl"]) req["dsl"] = json.loads(req["dsl"])
if "id" not in req: if "id" not in req:
req["user_id"] = current_user.id req["user_id"] = current_user.id
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip()): if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=CanvasCategory.Agent):
return get_data_error_result(message=f"{req['title'].strip()} already exists.") return get_data_error_result(message=f"{req['title'].strip()} already exists.")
req["id"] = get_uuid() req["id"] = get_uuid()
if not UserCanvasService.save(**req): if not UserCanvasService.save(**req):
@ -91,7 +91,7 @@ def save():
code=RetCode.OPERATING_ERROR) code=RetCode.OPERATING_ERROR)
UserCanvasService.update_by_id(req["id"], req) UserCanvasService.update_by_id(req["id"], req)
# save version # save version
UserCanvasVersionService.insert( user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S"))) UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
UserCanvasVersionService.delete_all_versions(req["id"]) UserCanvasVersionService.delete_all_versions(req["id"])
return get_json_result(data=req) return get_json_result(data=req)
@ -395,7 +395,7 @@ def list_canvas():
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
canvas, total = UserCanvasService.get_by_tenant_ids( canvas, total = UserCanvasService.get_by_tenant_ids(
[m["tenant_id"] for m in tenants], current_user.id, page_number, [m["tenant_id"] for m in tenants], current_user.id, page_number,
items_per_page, orderby, desc, keywords) items_per_page, orderby, desc, keywords, canvas_category=CanvasCategory.Agent)
return get_json_result(data={"canvas": canvas, "total": total}) return get_json_result(data={"canvas": canvas, "total": total})
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)

View File

@ -74,8 +74,10 @@ class TaskStatus(StrEnum):
DONE = "3" DONE = "3"
FAIL = "4" FAIL = "4"
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL} VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL}
class ParserType(StrEnum): class ParserType(StrEnum):
PRESENTATION = "presentation" PRESENTATION = "presentation"
LAWS = "laws" LAWS = "laws"
@ -105,10 +107,19 @@ class CanvasType(StrEnum):
DocBot = "docbot" DocBot = "docbot"
class CanvasCategory(StrEnum):
Agent = "agent_canvas"
DataFlow = "dataflow_canvas"
VALID_CAVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow}
class MCPServerType(StrEnum): class MCPServerType(StrEnum):
SSE = "sse" SSE = "sse"
STREAMABLE_HTTP = "streamable-http" STREAMABLE_HTTP = "streamable-http"
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP} VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase" KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"

View File

@ -245,22 +245,21 @@ class JsonSerializedField(SerializedField):
class RetryingPooledMySQLDatabase(PooledMySQLDatabase): class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.max_retries = kwargs.pop('max_retries', 5) self.max_retries = kwargs.pop("max_retries", 5)
self.retry_delay = kwargs.pop('retry_delay', 1) self.retry_delay = kwargs.pop("retry_delay", 1)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def execute_sql(self, sql, params=None, commit=True): def execute_sql(self, sql, params=None, commit=True):
from peewee import OperationalError from peewee import OperationalError
for attempt in range(self.max_retries + 1): for attempt in range(self.max_retries + 1):
try: try:
return super().execute_sql(sql, params, commit) return super().execute_sql(sql, params, commit)
except OperationalError as e: except OperationalError as e:
if e.args[0] in (2013, 2006) and attempt < self.max_retries: if e.args[0] in (2013, 2006) and attempt < self.max_retries:
logging.warning( logging.warning(f"Lost connection (attempt {attempt + 1}/{self.max_retries}): {e}")
f"Lost connection (attempt {attempt+1}/{self.max_retries}): {e}"
)
self._handle_connection_loss() self._handle_connection_loss()
time.sleep(self.retry_delay * (2 ** attempt)) time.sleep(self.retry_delay * (2**attempt))
else: else:
logging.error(f"DB execution failure: {e}") logging.error(f"DB execution failure: {e}")
raise raise
@ -272,16 +271,15 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
def begin(self): def begin(self):
from peewee import OperationalError from peewee import OperationalError
for attempt in range(self.max_retries + 1): for attempt in range(self.max_retries + 1):
try: try:
return super().begin() return super().begin()
except OperationalError as e: except OperationalError as e:
if e.args[0] in (2013, 2006) and attempt < self.max_retries: if e.args[0] in (2013, 2006) and attempt < self.max_retries:
logging.warning( logging.warning(f"Lost connection during transaction (attempt {attempt + 1}/{self.max_retries})")
f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
)
self._handle_connection_loss() self._handle_connection_loss()
time.sleep(self.retry_delay * (2 ** attempt)) time.sleep(self.retry_delay * (2**attempt))
else: else:
raise raise
@ -815,6 +813,7 @@ class UserCanvas(DataBaseModel):
permission = CharField(max_length=16, null=False, help_text="me|team", default="me", index=True) permission = CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)
description = TextField(null=True, help_text="Canvas description") description = TextField(null=True, help_text="Canvas description")
canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True) canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
canvas_category = CharField(max_length=32, null=False, default="agent_canvas", help_text="Canvas category: agent_canvas|dataflow_canvas", index=True)
dsl = JSONField(null=True, default={}) dsl = JSONField(null=True, default={})
class Meta: class Meta:
@ -827,6 +826,7 @@ class CanvasTemplate(DataBaseModel):
title = JSONField(null=True, default=dict, help_text="Canvas title") title = JSONField(null=True, default=dict, help_text="Canvas title")
description = JSONField(null=True, default=dict, help_text="Canvas description") description = JSONField(null=True, default=dict, help_text="Canvas description")
canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True) canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
canvas_category = CharField(max_length=32, null=False, default="agent_canvas", help_text="Canvas category: agent_canvas|dataflow_canvas", index=True)
dsl = JSONField(null=True, default={}) dsl = JSONField(null=True, default={})
class Meta: class Meta:
@ -1029,4 +1029,12 @@ def migrate_db():
migrate(migrator.alter_column_type("canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description"))) migrate(migrator.alter_column_type("canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description")))
except Exception: except Exception:
pass pass
try:
migrate(migrator.add_column("user_canvas", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
except Exception:
pass
try:
migrate(migrator.add_column("canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
except Exception:
pass
logging.disable(logging.NOTSET) logging.disable(logging.NOTSET)

View File

@ -18,7 +18,7 @@ import logging
import time import time
from uuid import uuid4 from uuid import uuid4
from agent.canvas import Canvas from agent.canvas import Canvas
from api.db import TenantPermission from api.db import CanvasCategory, TenantPermission
from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation
from api.db.services.api_service import API4ConversationService from api.db.services.api_service import API4ConversationService
from api.db.services.common_service import CommonService from api.db.services.common_service import CommonService
@ -31,6 +31,12 @@ from peewee import fn
class CanvasTemplateService(CommonService): class CanvasTemplateService(CommonService):
model = CanvasTemplate model = CanvasTemplate
class DataFlowTemplateService(CommonService):
"""
Alias of CanvasTemplateService
"""
model = CanvasTemplate
class UserCanvasService(CommonService): class UserCanvasService(CommonService):
model = UserCanvas model = UserCanvas
@ -38,13 +44,14 @@ class UserCanvasService(CommonService):
@classmethod @classmethod
@DB.connection_context() @DB.connection_context()
def get_list(cls, tenant_id, def get_list(cls, tenant_id,
page_number, items_per_page, orderby, desc, id, title): page_number, items_per_page, orderby, desc, id, title, canvas_category=CanvasCategory.Agent):
agents = cls.model.select() agents = cls.model.select()
if id: if id:
agents = agents.where(cls.model.id == id) agents = agents.where(cls.model.id == id)
if title: if title:
agents = agents.where(cls.model.title == title) agents = agents.where(cls.model.title == title)
agents = agents.where(cls.model.user_id == tenant_id) agents = agents.where(cls.model.user_id == tenant_id)
agents = agents.where(cls.model.canvas_category == canvas_category)
if desc: if desc:
agents = agents.order_by(cls.model.getter_by(orderby).desc()) agents = agents.order_by(cls.model.getter_by(orderby).desc())
else: else:
@ -71,6 +78,7 @@ class UserCanvasService(CommonService):
cls.model.create_time, cls.model.create_time,
cls.model.create_date, cls.model.create_date,
cls.model.update_date, cls.model.update_date,
cls.model.canvas_category,
User.nickname, User.nickname,
User.avatar.alias('tenant_avatar'), User.avatar.alias('tenant_avatar'),
] ]
@ -87,7 +95,7 @@ class UserCanvasService(CommonService):
@DB.connection_context() @DB.connection_context()
def get_by_tenant_ids(cls, joined_tenant_ids, user_id, def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
page_number, items_per_page, page_number, items_per_page,
orderby, desc, keywords, orderby, desc, keywords, canvas_category=CanvasCategory.Agent,
): ):
fields = [ fields = [
cls.model.id, cls.model.id,
@ -98,7 +106,8 @@ class UserCanvasService(CommonService):
cls.model.permission, cls.model.permission,
User.nickname, User.nickname,
User.avatar.alias('tenant_avatar'), User.avatar.alias('tenant_avatar'),
cls.model.update_time cls.model.update_time,
cls.model.canvas_category,
] ]
if keywords: if keywords:
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where( agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
@ -113,6 +122,7 @@ class UserCanvasService(CommonService):
TenantPermission.TEAM.value)) | ( TenantPermission.TEAM.value)) | (
cls.model.user_id == user_id)) cls.model.user_id == user_id))
) )
agents = agents.where(cls.model.canvas_category == canvas_category)
if desc: if desc:
agents = agents.order_by(cls.model.getter_by(orderby).desc()) agents = agents.order_by(cls.model.getter_by(orderby).desc())
else: else: