mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 20:42:30 +08:00
Feat: add canvas_category field for UserCanvas and CanvasTemplate (#9885)
### What problem does this PR solve? Add `canvas_category` field for UserCanvas and CanvasTemplate. ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -24,7 +24,7 @@ from flask import request, Response
|
|||||||
from flask_login import login_required, current_user
|
from flask_login import login_required, current_user
|
||||||
|
|
||||||
from agent.component import LLM
|
from agent.component import LLM
|
||||||
from api.db import FileType
|
from api.db import CanvasCategory, FileType
|
||||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
|
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
|
||||||
from api.db.services.document_service import DocumentService
|
from api.db.services.document_service import DocumentService
|
||||||
from api.db.services.file_service import FileService
|
from api.db.services.file_service import FileService
|
||||||
@ -45,14 +45,14 @@ from rag.utils.redis_conn import REDIS_CONN
|
|||||||
@manager.route('/templates', methods=['GET']) # noqa: F821
|
@manager.route('/templates', methods=['GET']) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
def templates():
|
def templates():
|
||||||
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()])
|
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.query(canvas_category=CanvasCategory.Agent)])
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
@manager.route('/list', methods=['GET']) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
def canvas_list():
|
def canvas_list():
|
||||||
return get_json_result(data=sorted([c.to_dict() for c in \
|
return get_json_result(data=sorted([c.to_dict() for c in \
|
||||||
UserCanvasService.query(user_id=current_user.id)], key=lambda x: x["update_time"]*-1)
|
UserCanvasService.query(user_id=current_user.id, canvas_category=CanvasCategory.Agent)], key=lambda x: x["update_time"]*-1)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -79,7 +79,7 @@ def save():
|
|||||||
req["dsl"] = json.loads(req["dsl"])
|
req["dsl"] = json.loads(req["dsl"])
|
||||||
if "id" not in req:
|
if "id" not in req:
|
||||||
req["user_id"] = current_user.id
|
req["user_id"] = current_user.id
|
||||||
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip()):
|
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=CanvasCategory.Agent):
|
||||||
return get_data_error_result(message=f"{req['title'].strip()} already exists.")
|
return get_data_error_result(message=f"{req['title'].strip()} already exists.")
|
||||||
req["id"] = get_uuid()
|
req["id"] = get_uuid()
|
||||||
if not UserCanvasService.save(**req):
|
if not UserCanvasService.save(**req):
|
||||||
@ -91,7 +91,7 @@ def save():
|
|||||||
code=RetCode.OPERATING_ERROR)
|
code=RetCode.OPERATING_ERROR)
|
||||||
UserCanvasService.update_by_id(req["id"], req)
|
UserCanvasService.update_by_id(req["id"], req)
|
||||||
# save version
|
# save version
|
||||||
UserCanvasVersionService.insert( user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
|
UserCanvasVersionService.insert(user_canvas_id=req["id"], dsl=req["dsl"], title="{0}_{1}".format(req["title"], time.strftime("%Y_%m_%d_%H_%M_%S")))
|
||||||
UserCanvasVersionService.delete_all_versions(req["id"])
|
UserCanvasVersionService.delete_all_versions(req["id"])
|
||||||
return get_json_result(data=req)
|
return get_json_result(data=req)
|
||||||
|
|
||||||
@ -395,7 +395,7 @@ def list_canvas():
|
|||||||
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||||
canvas, total = UserCanvasService.get_by_tenant_ids(
|
canvas, total = UserCanvasService.get_by_tenant_ids(
|
||||||
[m["tenant_id"] for m in tenants], current_user.id, page_number,
|
[m["tenant_id"] for m in tenants], current_user.id, page_number,
|
||||||
items_per_page, orderby, desc, keywords)
|
items_per_page, orderby, desc, keywords, canvas_category=CanvasCategory.Agent)
|
||||||
return get_json_result(data={"canvas": canvas, "total": total})
|
return get_json_result(data={"canvas": canvas, "total": total})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|||||||
@ -74,8 +74,10 @@ class TaskStatus(StrEnum):
|
|||||||
DONE = "3"
|
DONE = "3"
|
||||||
FAIL = "4"
|
FAIL = "4"
|
||||||
|
|
||||||
|
|
||||||
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL}
|
VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL}
|
||||||
|
|
||||||
|
|
||||||
class ParserType(StrEnum):
|
class ParserType(StrEnum):
|
||||||
PRESENTATION = "presentation"
|
PRESENTATION = "presentation"
|
||||||
LAWS = "laws"
|
LAWS = "laws"
|
||||||
@ -105,10 +107,19 @@ class CanvasType(StrEnum):
|
|||||||
DocBot = "docbot"
|
DocBot = "docbot"
|
||||||
|
|
||||||
|
|
||||||
|
class CanvasCategory(StrEnum):
|
||||||
|
Agent = "agent_canvas"
|
||||||
|
DataFlow = "dataflow_canvas"
|
||||||
|
|
||||||
|
VALID_CAVAS_CATEGORIES = {CanvasCategory.Agent, CanvasCategory.DataFlow}
|
||||||
|
|
||||||
|
|
||||||
class MCPServerType(StrEnum):
|
class MCPServerType(StrEnum):
|
||||||
SSE = "sse"
|
SSE = "sse"
|
||||||
STREAMABLE_HTTP = "streamable-http"
|
STREAMABLE_HTTP = "streamable-http"
|
||||||
|
|
||||||
|
|
||||||
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
|
VALID_MCP_SERVER_TYPES = {MCPServerType.SSE, MCPServerType.STREAMABLE_HTTP}
|
||||||
|
|
||||||
|
|
||||||
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
|
KNOWLEDGEBASE_FOLDER_NAME=".knowledgebase"
|
||||||
|
|||||||
@ -245,22 +245,21 @@ class JsonSerializedField(SerializedField):
|
|||||||
|
|
||||||
class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.max_retries = kwargs.pop('max_retries', 5)
|
self.max_retries = kwargs.pop("max_retries", 5)
|
||||||
self.retry_delay = kwargs.pop('retry_delay', 1)
|
self.retry_delay = kwargs.pop("retry_delay", 1)
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def execute_sql(self, sql, params=None, commit=True):
|
def execute_sql(self, sql, params=None, commit=True):
|
||||||
from peewee import OperationalError
|
from peewee import OperationalError
|
||||||
|
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
try:
|
try:
|
||||||
return super().execute_sql(sql, params, commit)
|
return super().execute_sql(sql, params, commit)
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
||||||
logging.warning(
|
logging.warning(f"Lost connection (attempt {attempt + 1}/{self.max_retries}): {e}")
|
||||||
f"Lost connection (attempt {attempt+1}/{self.max_retries}): {e}"
|
|
||||||
)
|
|
||||||
self._handle_connection_loss()
|
self._handle_connection_loss()
|
||||||
time.sleep(self.retry_delay * (2 ** attempt))
|
time.sleep(self.retry_delay * (2**attempt))
|
||||||
else:
|
else:
|
||||||
logging.error(f"DB execution failure: {e}")
|
logging.error(f"DB execution failure: {e}")
|
||||||
raise
|
raise
|
||||||
@ -272,16 +271,15 @@ class RetryingPooledMySQLDatabase(PooledMySQLDatabase):
|
|||||||
|
|
||||||
def begin(self):
|
def begin(self):
|
||||||
from peewee import OperationalError
|
from peewee import OperationalError
|
||||||
|
|
||||||
for attempt in range(self.max_retries + 1):
|
for attempt in range(self.max_retries + 1):
|
||||||
try:
|
try:
|
||||||
return super().begin()
|
return super().begin()
|
||||||
except OperationalError as e:
|
except OperationalError as e:
|
||||||
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
if e.args[0] in (2013, 2006) and attempt < self.max_retries:
|
||||||
logging.warning(
|
logging.warning(f"Lost connection during transaction (attempt {attempt + 1}/{self.max_retries})")
|
||||||
f"Lost connection during transaction (attempt {attempt+1}/{self.max_retries})"
|
|
||||||
)
|
|
||||||
self._handle_connection_loss()
|
self._handle_connection_loss()
|
||||||
time.sleep(self.retry_delay * (2 ** attempt))
|
time.sleep(self.retry_delay * (2**attempt))
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -815,6 +813,7 @@ class UserCanvas(DataBaseModel):
|
|||||||
permission = CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)
|
permission = CharField(max_length=16, null=False, help_text="me|team", default="me", index=True)
|
||||||
description = TextField(null=True, help_text="Canvas description")
|
description = TextField(null=True, help_text="Canvas description")
|
||||||
canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
|
canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
|
||||||
|
canvas_category = CharField(max_length=32, null=False, default="agent_canvas", help_text="Canvas category: agent_canvas|dataflow_canvas", index=True)
|
||||||
dsl = JSONField(null=True, default={})
|
dsl = JSONField(null=True, default={})
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
@ -827,6 +826,7 @@ class CanvasTemplate(DataBaseModel):
|
|||||||
title = JSONField(null=True, default=dict, help_text="Canvas title")
|
title = JSONField(null=True, default=dict, help_text="Canvas title")
|
||||||
description = JSONField(null=True, default=dict, help_text="Canvas description")
|
description = JSONField(null=True, default=dict, help_text="Canvas description")
|
||||||
canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
|
canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True)
|
||||||
|
canvas_category = CharField(max_length=32, null=False, default="agent_canvas", help_text="Canvas category: agent_canvas|dataflow_canvas", index=True)
|
||||||
dsl = JSONField(null=True, default={})
|
dsl = JSONField(null=True, default={})
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
@ -1029,4 +1029,12 @@ def migrate_db():
|
|||||||
migrate(migrator.alter_column_type("canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description")))
|
migrate(migrator.alter_column_type("canvas_template", "description", JSONField(null=True, default=dict, help_text="Canvas description")))
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
try:
|
||||||
|
migrate(migrator.add_column("user_canvas", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
migrate(migrator.add_column("canvas_template", "canvas_category", CharField(max_length=32, null=False, default="agent_canvas", help_text="agent_canvas|dataflow_canvas", index=True)))
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
logging.disable(logging.NOTSET)
|
logging.disable(logging.NOTSET)
|
||||||
|
|||||||
@ -18,7 +18,7 @@ import logging
|
|||||||
import time
|
import time
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
from agent.canvas import Canvas
|
from agent.canvas import Canvas
|
||||||
from api.db import TenantPermission
|
from api.db import CanvasCategory, TenantPermission
|
||||||
from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation
|
from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation
|
||||||
from api.db.services.api_service import API4ConversationService
|
from api.db.services.api_service import API4ConversationService
|
||||||
from api.db.services.common_service import CommonService
|
from api.db.services.common_service import CommonService
|
||||||
@ -31,6 +31,12 @@ from peewee import fn
|
|||||||
class CanvasTemplateService(CommonService):
|
class CanvasTemplateService(CommonService):
|
||||||
model = CanvasTemplate
|
model = CanvasTemplate
|
||||||
|
|
||||||
|
class DataFlowTemplateService(CommonService):
|
||||||
|
"""
|
||||||
|
Alias of CanvasTemplateService
|
||||||
|
"""
|
||||||
|
model = CanvasTemplate
|
||||||
|
|
||||||
|
|
||||||
class UserCanvasService(CommonService):
|
class UserCanvasService(CommonService):
|
||||||
model = UserCanvas
|
model = UserCanvas
|
||||||
@ -38,13 +44,14 @@ class UserCanvasService(CommonService):
|
|||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_list(cls, tenant_id,
|
def get_list(cls, tenant_id,
|
||||||
page_number, items_per_page, orderby, desc, id, title):
|
page_number, items_per_page, orderby, desc, id, title, canvas_category=CanvasCategory.Agent):
|
||||||
agents = cls.model.select()
|
agents = cls.model.select()
|
||||||
if id:
|
if id:
|
||||||
agents = agents.where(cls.model.id == id)
|
agents = agents.where(cls.model.id == id)
|
||||||
if title:
|
if title:
|
||||||
agents = agents.where(cls.model.title == title)
|
agents = agents.where(cls.model.title == title)
|
||||||
agents = agents.where(cls.model.user_id == tenant_id)
|
agents = agents.where(cls.model.user_id == tenant_id)
|
||||||
|
agents = agents.where(cls.model.canvas_category == canvas_category)
|
||||||
if desc:
|
if desc:
|
||||||
agents = agents.order_by(cls.model.getter_by(orderby).desc())
|
agents = agents.order_by(cls.model.getter_by(orderby).desc())
|
||||||
else:
|
else:
|
||||||
@ -71,6 +78,7 @@ class UserCanvasService(CommonService):
|
|||||||
cls.model.create_time,
|
cls.model.create_time,
|
||||||
cls.model.create_date,
|
cls.model.create_date,
|
||||||
cls.model.update_date,
|
cls.model.update_date,
|
||||||
|
cls.model.canvas_category,
|
||||||
User.nickname,
|
User.nickname,
|
||||||
User.avatar.alias('tenant_avatar'),
|
User.avatar.alias('tenant_avatar'),
|
||||||
]
|
]
|
||||||
@ -87,7 +95,7 @@ class UserCanvasService(CommonService):
|
|||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
def get_by_tenant_ids(cls, joined_tenant_ids, user_id,
|
||||||
page_number, items_per_page,
|
page_number, items_per_page,
|
||||||
orderby, desc, keywords,
|
orderby, desc, keywords, canvas_category=CanvasCategory.Agent,
|
||||||
):
|
):
|
||||||
fields = [
|
fields = [
|
||||||
cls.model.id,
|
cls.model.id,
|
||||||
@ -98,7 +106,8 @@ class UserCanvasService(CommonService):
|
|||||||
cls.model.permission,
|
cls.model.permission,
|
||||||
User.nickname,
|
User.nickname,
|
||||||
User.avatar.alias('tenant_avatar'),
|
User.avatar.alias('tenant_avatar'),
|
||||||
cls.model.update_time
|
cls.model.update_time,
|
||||||
|
cls.model.canvas_category,
|
||||||
]
|
]
|
||||||
if keywords:
|
if keywords:
|
||||||
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
|
agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where(
|
||||||
@ -113,6 +122,7 @@ class UserCanvasService(CommonService):
|
|||||||
TenantPermission.TEAM.value)) | (
|
TenantPermission.TEAM.value)) | (
|
||||||
cls.model.user_id == user_id))
|
cls.model.user_id == user_id))
|
||||||
)
|
)
|
||||||
|
agents = agents.where(cls.model.canvas_category == canvas_category)
|
||||||
if desc:
|
if desc:
|
||||||
agents = agents.order_by(cls.model.getter_by(orderby).desc())
|
agents = agents.order_by(cls.model.getter_by(orderby).desc())
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user