Feat: Redesign and refactor agent module (#9113)

### What problem does this PR solve?

#9082 #6365

<u>**WARNING: this release is not compatible with older versions of the `Agent`
module, which means that agents created with older versions will no longer
work.**</u>

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Kevin Hu
2025-07-30 19:41:09 +08:00
committed by GitHub
parent 07e37560fc
commit d9fe279dde
124 changed files with 7744 additions and 18226 deletions

View File

@ -14,20 +14,35 @@
# limitations under the License.
#
import json
import traceback
import logging
import re
import sys
from functools import partial
import trio
from flask import request, Response
from flask_login import login_required, current_user
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService
from agent.component import LLM
from api.db import FileType
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
from api.db.services.document_service import DocumentService
from api.db.services.file_service import FileService
from api.db.services.user_service import TenantService
from api.db.services.user_canvas_version import UserCanvasVersionService
from api.settings import RetCode
from api.utils import get_uuid
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result
from api.utils.api_utils import get_json_result, server_error_response, validate_request, get_data_error_result, \
get_error_data_result
from agent.canvas import Canvas
from peewee import MySQLDatabase, PostgresqlDatabase
from api.db.db_models import APIToken
import time
from api.utils.file_utils import filename_type, read_potential_broken_pdf
from rag.utils.redis_conn import REDIS_CONN
@manager.route('/templates', methods=['GET']) # noqa: F821
@login_required
def templates():
@ -112,8 +127,10 @@ def getsse(canvas_id):
@login_required
def run():
req = request.json
stream = req.get("stream", True)
running_hint_text = req.get("running_hint_text", "")
query = req.get("query", "")
files = req.get("files", [])
inputs = req.get("inputs", {})
user_id = req.get("user_id", current_user.id)
e, cvs = UserCanvasService.get_by_id(req["id"])
if not e:
return get_data_error_result(message="canvas not found.")
@ -125,68 +142,29 @@ def run():
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
final_ans = {"reference": [], "content": ""}
message_id = req.get("message_id", get_uuid())
try:
canvas = Canvas(cvs.dsl, current_user.id)
if "message" in req:
canvas.messages.append({"role": "user", "content": req["message"], "id": message_id})
canvas.add_user_input(req["message"])
canvas = Canvas(cvs.dsl, current_user.id, req["id"])
except Exception as e:
return server_error_response(e)
if stream:
def sse():
nonlocal answer, cvs
try:
for ans in canvas.run(running_hint_text = running_hint_text, stream=True):
if ans.get("running_status"):
yield "data:" + json.dumps({"code": 0, "message": "",
"data": {"answer": ans["content"],
"running_status": True}},
ensure_ascii=False) + "\n\n"
continue
for k in ans.keys():
final_ans[k] = ans[k]
ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
def sse():
nonlocal canvas, user_id
try:
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
canvas.history.append(("assistant", final_ans["content"]))
if not canvas.path[-1]:
canvas.path.pop(-1)
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
except Exception as e:
cvs.dsl = json.loads(str(canvas))
if not canvas.path[-1]:
canvas.path.pop(-1)
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
traceback.print_exc()
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
except Exception as e:
logging.exception(e)
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
resp = Response(sse(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
for answer in canvas.run(running_hint_text = running_hint_text, stream=False):
if answer.get("running_status"):
continue
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
cvs.dsl = json.loads(str(canvas))
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
return get_json_result(data={"answer": final_ans["content"], "reference": final_ans.get("reference", [])})
resp = Response(sse(), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
resp.headers.add_header("Connection", "keep-alive")
resp.headers.add_header("X-Accel-Buffering", "no")
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
return resp
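
For reference, the reworked `/run` endpoint streams every canvas event as a Server-Sent Events line (`data: <json>` followed by a blank line). A minimal client sketch, assuming the canvas API is mounted under `/v1/canvas` on a local deployment; the base URL, auth header and `<canvas-id>` are placeholders, not anything prescribed by this diff:

```python
import json

import requests

BASE_URL = "http://localhost:9380/v1/canvas"      # assumed deployment URL
HEADERS = {"Authorization": "Bearer <api-key>"}   # assumed auth scheme

payload = {
    "id": "<canvas-id>",            # the agent/canvas to run
    "query": "What is RAGFlow?",    # user query
    "files": [],
    "inputs": {},
}

with requests.post(f"{BASE_URL}/run", json=payload, headers=HEADERS, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        event = json.loads(line[len("data:"):])   # each event is the dict yielded by canvas.run()
        print(event)
```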
@manager.route('/reset', methods=['POST']) # noqa: F821
@ -212,9 +190,84 @@ def reset():
return server_error_response(e)
@manager.route('/input_elements', methods=['GET']) # noqa: F821
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
def upload(canvas_id):
e, cvs = UserCanvasService.get_by_tenant_id(canvas_id)
if not e:
return get_data_error_result(message="canvas not found.")
user_id = cvs["user_id"]
def structured(filename, filetype, blob, content_type):
nonlocal user_id
if filetype == FileType.PDF.value:
blob = read_potential_broken_pdf(blob)
location = get_uuid()
FileService.put_blob(user_id, location, blob)
return {
"id": location,
"name": filename,
"size": sys.getsizeof(blob),
"extension": filename.split(".")[-1].lower(),
"mime_type": content_type,
"created_by": user_id,
"created_at": time.time(),
"preview_url": None
}
if request.args.get("url"):
from crawl4ai import (
AsyncWebCrawler,
BrowserConfig,
CrawlerRunConfig,
DefaultMarkdownGenerator,
PruningContentFilter,
CrawlResult
)
try:
url = request.args.get("url")
filename = re.sub(r"\?.*", "", url.split("/")[-1])
async def adownload():
browser_config = BrowserConfig(
headless=True,
verbose=False,
)
async with AsyncWebCrawler(config=browser_config) as crawler:
crawler_config = CrawlerRunConfig(
markdown_generator=DefaultMarkdownGenerator(
content_filter=PruningContentFilter()
),
pdf=True,
screenshot=False
)
result: CrawlResult = await crawler.arun(
url=url,
config=crawler_config
)
return result
# trio.run expects the async function itself, not an already-created coroutine
page = trio.run(adownload)
if page.pdf:
if filename.split(".")[-1].lower() != "pdf":
filename += ".pdf"
return get_json_result(data=structured(filename, "pdf", page.pdf, page.response_headers["content-type"]))
return get_json_result(data=structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id))
except Exception as e:
return server_error_response(e)
file = request.files['file']
try:
DocumentService.check_doc_health(user_id, file.filename)
return get_json_result(data=structured(file.filename, filename_type(file.filename), file.read(), file.content_type))
except Exception as e:
return server_error_response(e)
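
The new `/upload/<canvas_id>` route accepts either a multipart file or a `url` query parameter, in which case the page is crawled with crawl4ai and returned as PDF or Markdown. A usage sketch, assuming the same base URL as above and the standard `get_json_result` response envelope:

```python
import requests

BASE_URL = "http://localhost:9380/v1/canvas"   # assumed deployment URL
canvas_id = "<canvas-id>"

# 1) Upload a local file as multipart/form-data.
with open("report.pdf", "rb") as f:
    meta = requests.post(f"{BASE_URL}/upload/{canvas_id}", files={"file": f}).json()

# 2) Ask the server to crawl a page instead of uploading a file.
meta = requests.post(f"{BASE_URL}/upload/{canvas_id}",
                     params={"url": "https://example.com/article"}).json()
print(meta["data"]["id"], meta["data"]["name"])   # blob location and filename
```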
@manager.route('/input_form', methods=['GET']) # noqa: F821
@login_required
def input_elements():
def input_form():
cvs_id = request.args.get("id")
cpn_id = request.args.get("component_id")
try:
@ -227,7 +280,7 @@ def input_elements():
code=RetCode.OPERATING_ERROR)
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
return get_json_result(data=canvas.get_component_input_elements(cpn_id))
return get_json_result(data=canvas.get_component_input_form(cpn_id))
except Exception as e:
return server_error_response(e)
@ -237,8 +290,6 @@ def input_elements():
@login_required
def debug():
req = request.json
for p in req["params"]:
assert p.get("key")
try:
e, user_canvas = UserCanvasService.get_by_id(req["id"])
if not e:
@ -249,11 +300,22 @@ def debug():
code=RetCode.OPERATING_ERROR)
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id)
componant = canvas.get_component(req["component_id"])["obj"]
componant.reset()
componant._param.debug_inputs = req["params"]
df = canvas.get_component(req["component_id"])["obj"].debug()
return get_json_result(data=df.to_dict(orient="records"))
canvas.reset()
canvas.message_id = get_uuid()
component = canvas.get_component(req["component_id"])["obj"]
component.reset()
if isinstance(component, LLM):
component.set_debug_inputs(req["params"])
component.invoke(**{k: o["value"] for k,o in req["params"].items()})
outputs = component.output()
for k in outputs.keys():
if isinstance(outputs[k], partial):
txt = ""
for c in outputs[k]():
txt += c
outputs[k] = txt
return get_json_result(data=outputs)
except Exception as e:
return server_error_response(e)
@ -292,6 +354,8 @@ def test_db_connect():
return get_json_result(data="Database Connection Successful!")
except Exception as e:
return server_error_response(e)
# API: list all DSL versions of a canvas
@manager.route('/getlistversion/<canvas_id>', methods=['GET']) # noqa: F821
@login_required
@ -301,6 +365,8 @@ def getlistversion(canvas_id):
return get_json_result(data=list)
except Exception as e:
return get_data_error_result(message=f"Error getting history files: {e}")
# API: get a specific DSL version of a canvas
@manager.route('/getversion/<version_id>', methods=['GET']) # noqa: F821
@login_required
@ -312,6 +378,8 @@ def getversion( version_id):
return get_json_result(data=version.to_dict())
except Exception as e:
return get_json_result(data=f"Error getting history file: {e}")
@manager.route('/listteam', methods=['GET']) # noqa: F821
@login_required
def list_kbs():
@ -328,6 +396,8 @@ def list_kbs():
return get_json_result(data={"kbs": kbs, "total": total})
except Exception as e:
return server_error_response(e)
@manager.route('/setting', methods=['POST']) # noqa: F821
@validate_request("id", "title", "permission")
@login_required
@ -351,3 +421,46 @@ def setting():
code=RetCode.OPERATING_ERROR)
num= UserCanvasService.update_by_id(req["id"], flow)
return get_json_result(data=num)
@manager.route('/trace', methods=['GET']) # noqa: F821
def trace():
cvs_id = request.args.get("canvas_id")
msg_id = request.args.get("message_id")
try:
bin = REDIS_CONN.get(f"{cvs_id}-{msg_id}-logs")
if not bin:
return get_json_result(data={})
return get_json_result(data=json.loads(bin.encode("utf-8")))
except Exception as e:
logging.exception(e)
return server_error_response(e)
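
`/trace` just returns whatever run log the canvas wrote to Redis under the key `"{canvas_id}-{message_id}-logs"`. Fetching it is a plain GET (sketch, assumed base URL):

```python
import requests

BASE_URL = "http://localhost:9380/v1/canvas"   # assumed deployment URL
trace = requests.get(f"{BASE_URL}/trace",
                     params={"canvas_id": "<canvas-id>", "message_id": "<message-id>"}).json()
print(trace.get("data", {}))   # empty dict if no log has been written yet
```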
@manager.route('/<canvas_id>/sessions', methods=['GET']) # noqa: F821
@login_required
def sessions(canvas_id):
tenant_id = current_user.id
if not UserCanvasService.query(user_id=tenant_id, id=canvas_id):
return get_error_data_result(message=f"You don't own the agent {canvas_id}.")
user_id = request.args.get("user_id")
page_number = int(request.args.get("page", 1))
items_per_page = int(request.args.get("page_size", 30))
keywords = request.args.get("keywords")
from_date = request.args.get("from_date")
to_date = request.args.get("to_date")
orderby = request.args.get("orderby", "update_time")
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
desc = False
else:
desc = True
# dsl defaults to True in all cases except for False and false
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc,
None, user_id, include_dsl, keywords, from_date, to_date)
try:
return get_json_result(data={"total": total, "sessions": sess})
except Exception as e:
return server_error_response(e)
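
The new `GET /<canvas_id>/sessions` listing exposes the filters added to `API4ConversationService.get_list` below (keywords, date range, `dsl` toggle). A sketch, with the base URL and auth header again assumed:

```python
import requests

BASE_URL = "http://localhost:9380/v1/canvas"   # assumed deployment URL
resp = requests.get(f"{BASE_URL}/<canvas-id>/sessions",
                    params={"page": 1, "page_size": 30, "keywords": "error",
                            "from_date": "2025-07-01", "dsl": "false"},
                    headers={"Authorization": "Bearer <api-key>"})   # assumed auth scheme
print(resp.json()["data"]["total"])
```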

View File

@ -33,6 +33,7 @@ from api.db.services.user_service import UserTenantService
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, validate_request
from graphrag.general.mind_map_extractor import MindMapExtractor
from rag.app.tag import label_question
from rag.prompts.prompts import chunks_format
@manager.route("/set", methods=["POST"]) # noqa: F821
@ -90,25 +91,10 @@ def get():
else:
return get_json_result(data=False, message="Only owner of conversation authorized for this operation.", code=settings.RetCode.OPERATING_ERROR)
def get_value(d, k1, k2):
return d.get(k1, d.get(k2))
for ref in conv.reference:
if isinstance(ref, list):
continue
ref["chunks"] = [
{
"id": get_value(ck, "chunk_id", "id"),
"content": get_value(ck, "content", "content_with_weight"),
"document_id": get_value(ck, "doc_id", "document_id"),
"document_name": get_value(ck, "docnm_kwd", "document_name"),
"dataset_id": get_value(ck, "kb_id", "dataset_id"),
"image_id": get_value(ck, "image_id", "img_id"),
"positions": get_value(ck, "positions", "position_int"),
"doc_type": get_value(ck, "doc_type", "doc_type_kwd"),
}
for ck in ref.get("chunks", [])
]
ref["chunks"] = chunks_format(ref)
conv = conv.to_dict()
conv["avatar"] = avatar
@ -201,26 +187,10 @@ def completion():
if not conv.reference:
conv.reference = []
else:
def get_value(d, k1, k2):
return d.get(k1, d.get(k2))
for ref in conv.reference:
if isinstance(ref, list):
continue
ref["chunks"] = [
{
"id": get_value(ck, "chunk_id", "id"),
"content": get_value(ck, "content", "content_with_weight"),
"document_id": get_value(ck, "doc_id", "document_id"),
"document_name": get_value(ck, "docnm_kwd", "document_name"),
"dataset_id": get_value(ck, "kb_id", "dataset_id"),
"image_id": get_value(ck, "image_id", "img_id"),
"positions": get_value(ck, "positions", "position_int"),
"doc_type": get_value(ck, "doc_type_kwd", "doc_type_kwd"),
}
for ck in ref.get("chunks", [])
]
ref["chunks"] = chunks_format(ref)
if not conv.reference:
conv.reference = []
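
Both duplicated inline mappings are replaced by the shared `chunks_format` helper from `rag.prompts.prompts`. Judging from the deleted code, the helper presumably normalizes each reference chunk to a fixed schema along these lines (a sketch; the actual implementation may differ):

```python
def chunks_format(reference: dict) -> list:
    # Accept both the legacy and the new key names for each field.
    def get_value(d, k1, k2):
        return d.get(k1, d.get(k2))

    return [
        {
            "id": get_value(ck, "chunk_id", "id"),
            "content": get_value(ck, "content", "content_with_weight"),
            "document_id": get_value(ck, "doc_id", "document_id"),
            "document_name": get_value(ck, "docnm_kwd", "document_name"),
            "dataset_id": get_value(ck, "kb_id", "dataset_id"),
            "image_id": get_value(ck, "image_id", "img_id"),
            "positions": get_value(ck, "positions", "position_int"),
            "doc_type": get_value(ck, "doc_type", "doc_type_kwd"),
        }
        for ck in reference.get("chunks", [])
    ]
```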

View File

@ -51,7 +51,7 @@ def set_dialog():
vector_similarity_weight = req.get("vector_similarity_weight", 0.3)
llm_setting = req.get("llm_setting", {})
prompt_config = req["prompt_config"]
if not req.get("kb_ids", []) and not prompt_config.get("tavily_api_key") and "{knowledge}" in prompt_config['system']:
return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no knowledge base/Tavily used here.")

View File

@ -556,7 +556,7 @@ def list_agent_session(tenant_id, agent_id):
desc = True
# dsl defaults to True in all cases except for False and false
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl)
total, convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id, include_dsl)
if not convs:
return get_result(data=[])
for conv in convs:
@ -817,9 +817,6 @@ def agent_bot_completions(agent_id):
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!')
if "quote" not in req:
req["quote"] = False
if req.get("stream", True):
resp = Response(agent_completion(objs[0].tenant_id, agent_id, **req), mimetype="text/event-stream")
resp.headers.add_header("Cache-control", "no-cache")
@ -830,3 +827,23 @@ def agent_bot_completions(agent_id):
for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
return get_result(data=answer)
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
def begin_inputs(agent_id):
token = request.headers.get("Authorization").split()
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!')
token = token[1]
objs = APIToken.query(beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!')
e, cvs = UserCanvasService.get_by_id(agent_id)
if not e:
return get_error_data_result(f"Can't find agent by ID: {agent_id}")
canvas = Canvas(json.dumps(cvs.dsl), objs[0].tenant_id)
return get_result(data=canvas.get_component_input_form("begin"))

View File

@ -463,6 +463,7 @@ class DataBaseModel(BaseModel):
@DB.connection_context()
@DB.lock("init_database_tables", 60)
def init_database_tables(alter_fields=[]):
members = inspect.getmembers(sys.modules[__name__], inspect.isclass)
table_objs = []
@ -474,7 +475,7 @@ def init_database_tables(alter_fields=[]):
if not obj.table_exists():
logging.debug(f"start create table {obj.__name__}")
try:
obj.create_table()
obj.create_table(safe=True)
logging.debug(f"create table success: {obj.__name__}")
except Exception as e:
logging.exception(e)
@ -798,6 +799,7 @@ class API4Conversation(DataBaseModel):
duration = FloatField(default=0, index=True)
round = IntegerField(default=0, index=True)
thumb_up = IntegerField(default=0, index=True)
errors = TextField(null=True, help_text="errors")
class Meta:
db_table = "api_4_conversation"
@ -1009,4 +1011,8 @@ def migrate_db():
migrate(migrator.add_column("document", "suffix", CharField(max_length=32, null=False, default="", help_text="The real file extension suffix", index=True)))
except Exception:
pass
logging.disable(logging.NOTSET)
try:
migrate(migrator.add_column("api_4_conversation", "errors", TextField(null=True, help_text="errors")))
except Exception:
pass
logging.disable(logging.NOTSET)

View File

@ -154,6 +154,11 @@ def init_llm_factory():
def add_graph_templates():
dir = os.path.join(get_project_base_directory(), "agent", "templates")
CanvasTemplateService.filter_delete([1 == 1])
if not os.path.exists(dir):
logging.warning("Missing agent templates!")
return
for fnm in os.listdir(dir):
try:
cnvs = json.load(open(os.path.join(dir, fnm), "r",encoding="utf-8"))
@ -162,7 +167,7 @@ def add_graph_templates():
except Exception:
CanvasTemplateService.update_by_id(cnvs["id"], cnvs)
except Exception:
logging.exception("Add graph templates error: ")
logging.exception("Add agent templates error: ")
def init_web_data():

View File

@ -43,7 +43,9 @@ class API4ConversationService(CommonService):
@DB.connection_context()
def get_list(cls, dialog_id, tenant_id,
page_number, items_per_page,
orderby, desc, id, user_id=None, include_dsl=True):
orderby, desc, id, user_id=None, include_dsl=True, keywords="",
from_date=None, to_date=None
):
if include_dsl:
sessions = cls.model.select().where(cls.model.dialog_id == dialog_id)
else:
@ -53,13 +55,20 @@ class API4ConversationService(CommonService):
sessions = sessions.where(cls.model.id == id)
if user_id:
sessions = sessions.where(cls.model.user_id == user_id)
if keywords:
sessions = sessions.where(peewee.fn.LOWER(cls.model.message).contains(keywords.lower()))
if from_date:
sessions = sessions.where(cls.model.create_date >= from_date)
if to_date:
sessions = sessions.where(cls.model.create_date <= to_date)
if desc:
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
else:
sessions = sessions.order_by(cls.model.getter_by(orderby).asc())
count = sessions.count()
sessions = sessions.paginate(page_number, items_per_page)
return list(sessions.dicts())
return count, list(sessions.dicts())
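
Note that `get_list` now returns a `(total, rows)` tuple instead of a bare list, so every caller has to unpack it, e.g. (sketch using the new keyword and date filters; the `create_date` string format is an assumption):

```python
from api.db.services.api_service import API4ConversationService

total, rows = API4ConversationService.get_list(
    canvas_id, tenant_id,
    page_number=1, items_per_page=30,
    orderby="update_time", desc=True, id=None,
    user_id=None, include_dsl=False,
    keywords="error",                    # case-insensitive match against the message column
    from_date="2025-07-01 00:00:00",     # assumed create_date format
    to_date=None,
)
```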
@classmethod
@DB.connection_context()

View File

@ -14,6 +14,7 @@
# limitations under the License.
#
import json
import logging
import time
import traceback
from uuid import uuid4
@ -22,11 +23,12 @@ from api.db import TenantPermission
from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversation
from api.db.services.api_service import API4ConversationService
from api.db.services.common_service import CommonService
from api.db.services.conversation_service import structure_answer
from api.utils import get_uuid
from api.utils.api_utils import get_data_openai
import tiktoken
from peewee import fn
class CanvasTemplateService(CommonService):
model = CanvasTemplate
@ -79,7 +81,7 @@ class UserCanvasService(CommonService):
# obj = cls.model.query(id=pid)[0]
return True, agents.dicts()[0]
except Exception as e:
print(e)
logging.exception(e)
return False, None
@classmethod
@ -119,120 +121,58 @@ class UserCanvasService(CommonService):
count = agents.count()
agents = agents.paginate(page_number, items_per_page)
return list(agents.dicts()), count
def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
e, cvs = UserCanvasService.get_by_id(agent_id)
assert e, "Agent not found."
assert cvs.user_id == tenant_id, "You do not own the agent."
if not isinstance(cvs.dsl,str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
canvas = Canvas(cvs.dsl, tenant_id)
canvas.reset()
message_id = str(uuid4())
if not session_id:
query = canvas.get_preset_param()
if query:
for ele in query:
if not ele["optional"]:
if not kwargs.get(ele["key"]):
assert False, f"`{ele['key']}` is required"
ele["value"] = kwargs[ele["key"]]
if ele["optional"]:
if kwargs.get(ele["key"]):
ele["value"] = kwargs[ele['key']]
else:
if "value" in ele:
ele.pop("value")
cvs.dsl = json.loads(str(canvas))
def completion(tenant_id, agent_id, session_id=None, **kwargs):
query = kwargs.get("query", "")
files = kwargs.get("files", [])
inputs = kwargs.get("inputs", {})
user_id = kwargs.get("user_id", "")
if session_id:
e, conv = API4ConversationService.get_by_id(session_id)
assert e, "Session not found!"
if not conv.message:
conv.message = []
canvas = Canvas(json.dumps(conv.dsl), tenant_id, session_id)
else:
e, cvs = UserCanvasService.get_by_id(agent_id)
assert e, "Agent not found."
assert cvs.user_id == tenant_id, "You do not own the agent."
if not isinstance(cvs.dsl, str):
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
session_id=get_uuid()
canvas = Canvas(cvs.dsl, tenant_id, session_id)
conv = {
"id": session_id,
"dialog_id": cvs.id,
"user_id": kwargs.get("user_id", "") if isinstance(kwargs, dict) else "",
"message": [{"role": "assistant", "content": canvas.get_prologue(), "created_at": time.time()}],
"user_id": user_id,
"message": [],
"source": "agent",
"dsl": cvs.dsl
}
API4ConversationService.save(**conv)
conv = API4Conversation(**conv)
else:
e, conv = API4ConversationService.get_by_id(session_id)
assert e, "Session not found!"
canvas = Canvas(json.dumps(conv.dsl), tenant_id)
canvas.messages.append({"role": "user", "content": question, "id": message_id})
canvas.add_user_input(question)
if not conv.message:
conv.message = []
conv.message.append({
"role": "user",
"content": question,
"id": message_id
})
if not conv.reference:
conv.reference = []
conv.reference.append({"chunks": [], "doc_aggs": []})
kwargs_changed = False
if kwargs:
query = canvas.get_preset_param()
if query:
for ele in query:
if ele["key"] in kwargs:
if ele["value"] != kwargs[ele["key"]]:
ele["value"] = kwargs[ele["key"]]
kwargs_changed = True
if kwargs_changed:
conv.dsl = json.loads(str(canvas))
API4ConversationService.update_by_id(session_id, {"dsl": conv.dsl})
message_id = str(uuid4())
conv.message.append({
"role": "user",
"content": query,
"id": message_id
})
txt = ""
for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
ans["session_id"] = session_id
if ans["event"] == "message":
txt += ans["data"]["content"]
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
final_ans = {"reference": [], "content": ""}
if stream:
try:
for ans in canvas.run(stream=stream):
if ans.get("running_status"):
yield "data:" + json.dumps({"code": 0, "message": "",
"data": {"answer": ans["content"],
"running_status": True}},
ensure_ascii=False) + "\n\n"
continue
for k in ans.keys():
final_ans[k] = ans[k]
ans = {"answer": ans["content"], "reference": ans.get("reference", []), "param": canvas.get_preset_param()}
ans = structure_answer(conv, ans, message_id, session_id)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
ensure_ascii=False) + "\n\n"
conv.message.append({"role": "assistant", "content": txt, "created_at": time.time(), "id": message_id})
conv.reference = canvas.get_reference()
conv.errors = canvas.error
API4ConversationService.append_message(conv.id, conv.to_dict())
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "created_at": time.time(), "id": message_id})
canvas.history.append(("assistant", final_ans["content"]))
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
conv.dsl = json.loads(str(canvas))
API4ConversationService.append_message(conv.id, conv.to_dict())
except Exception as e:
traceback.print_exc()
conv.dsl = json.loads(str(canvas))
API4ConversationService.append_message(conv.id, conv.to_dict())
yield "data:" + json.dumps({"code": 500, "message": str(e),
"data": {"answer": "**ERROR**: " + str(e), "reference": []}},
ensure_ascii=False) + "\n\n"
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
else:
for answer in canvas.run(stream=False):
if answer.get("running_status"):
continue
final_ans["content"] = "\n".join(answer["content"]) if "content" in answer else ""
canvas.messages.append({"role": "assistant", "content": final_ans["content"], "id": message_id})
if final_ans.get("reference"):
canvas.reference.append(final_ans["reference"])
conv.dsl = json.loads(str(canvas))
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", []) , "param": canvas.get_preset_param()}
result = structure_answer(conv, result, message_id, session_id)
API4ConversationService.append_message(conv.id, conv.to_dict())
yield result
break
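
After the refactor `completion` is a plain generator that always yields SSE-framed JSON strings; the non-streaming branch is gone. A caller can forward the chunks verbatim or parse them, e.g. (sketch with placeholder ids):

```python
import json

from api.db.services.canvas_service import completion

answer_text = ""
for chunk in completion("<tenant-id>", "<agent-id>", session_id=None,
                        query="Summarize the uploaded report",
                        files=[], inputs={}, user_id="<user-id>"):
    event = json.loads(chunk[len("data:"):])   # strip the "data:" SSE prefix
    if event.get("event") == "message":
        answer_text += event["data"]["content"]
```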
def completionOpenAI(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
"""Main function for OpenAI-compatible completions, structured similarly to the completion function."""
tiktokenenc = tiktoken.get_encoding("cl100k_base")

View File

@ -27,7 +27,7 @@ import xxhash
from peewee import fn
from api import settings
from api.constants import IMG_BASE64_PREFIX
from api.constants import IMG_BASE64_PREFIX, FILE_NAME_LEN_LIMIT
from api.db import FileType, LLMType, ParserType, StatusEnum, TaskStatus, UserTenantRole
from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTenant, File2Document, File
from api.db.db_utils import bulk_insert_into_db
@ -100,6 +100,17 @@ class DocumentService(CommonService):
docs = docs.paginate(page_number, items_per_page)
return list(docs.dicts()), count
@classmethod
@DB.connection_context()
def check_doc_health(cls, tenant_id: str, filename):
import os
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(tenant_id) >= MAX_FILE_NUM_PER_USER:
raise RuntimeError("Exceed the maximum file number of a free user!")
if len(filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
raise RuntimeError("Exceed the maximum length of file name!")
return True
@classmethod
@DB.connection_context()
def get_by_kb_id(cls, kb_id, page_number, items_per_page,
@ -258,13 +269,13 @@ class DocumentService(CommonService):
)
if len(graph_source) > 0 and doc.id in list(graph_source.values())[0]["source_id"]:
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "source_id": doc.id},
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id), doc.kb_id)
{"remove": {"source_id": doc.id}},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.update({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["graph"]},
{"removed_kwd": "Y"},
search.index_name(tenant_id), doc.kb_id)
{"removed_kwd": "Y"},
search.index_name(tenant_id), doc.kb_id)
settings.docStoreConn.delete({"kb_id": doc.kb_id, "knowledge_graph_kwd": ["entity", "relation", "graph", "subgraph", "community_report"], "must_not": {"exists": "source_id"}},
search.index_name(tenant_id), doc.kb_id)
search.index_name(tenant_id), doc.kb_id)
except Exception:
pass
return cls.delete_by_id(doc.id)
@ -323,9 +334,9 @@ class DocumentService(CommonService):
"Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num +
token_num,
token_num,
chunk_num=Knowledgebase.chunk_num +
chunk_num).where(
chunk_num).where(
Knowledgebase.id == kb_id).execute()
return num
@ -341,9 +352,9 @@ class DocumentService(CommonService):
"Document not found which is supposed to be there")
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
token_num,
token_num,
chunk_num=Knowledgebase.chunk_num -
chunk_num
chunk_num
).where(
Knowledgebase.id == kb_id).execute()
return num
@ -356,9 +367,9 @@ class DocumentService(CommonService):
num = Knowledgebase.update(
token_num=Knowledgebase.token_num -
doc.token_num,
doc.token_num,
chunk_num=Knowledgebase.chunk_num -
doc.chunk_num,
doc.chunk_num,
doc_num=Knowledgebase.doc_num - 1
).where(
Knowledgebase.id == doc.kb_id).execute()
@ -388,7 +399,7 @@ class DocumentService(CommonService):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
@ -410,7 +421,7 @@ class DocumentService(CommonService):
docs = cls.model.select(
Knowledgebase.tenant_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.name == name, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
@ -423,7 +434,7 @@ class DocumentService(CommonService):
docs = cls.model.select(
cls.model.id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
Knowledgebase.id == cls.model.kb_id)
).join(UserTenant, on=(UserTenant.tenant_id == Knowledgebase.tenant_id)
).where(cls.model.id == doc_id, UserTenant.user_id == user_id).paginate(0, 1)
docs = docs.dicts()
@ -435,12 +446,12 @@ class DocumentService(CommonService):
@DB.connection_context()
def accessible4deletion(cls, doc_id, user_id):
docs = cls.model.select(cls.model.id
).join(
).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)
Knowledgebase.id == cls.model.kb_id)
).join(
UserTenant, on=(
(UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id))
(UserTenant.tenant_id == Knowledgebase.created_by) & (UserTenant.user_id == user_id))
).where(
cls.model.id == doc_id,
UserTenant.status == StatusEnum.VALID.value,
@ -457,7 +468,7 @@ class DocumentService(CommonService):
docs = cls.model.select(
Knowledgebase.embd_id).join(
Knowledgebase, on=(
Knowledgebase.id == cls.model.kb_id)).where(
Knowledgebase.id == cls.model.kb_id)).where(
cls.model.id == doc_id, Knowledgebase.status == StatusEnum.VALID.value)
docs = docs.dicts()
if not docs:
@ -499,7 +510,7 @@ class DocumentService(CommonService):
if not doc_id:
return
return doc_id[0]["id"]
@classmethod
@DB.connection_context()
def get_doc_ids_by_doc_names(cls, doc_names):
@ -612,7 +623,7 @@ class DocumentService(CommonService):
info = {
"process_duration": datetime.timestamp(
datetime.now()) -
d["process_begin_at"].timestamp(),
d["process_begin_at"].timestamp(),
"run": status}
if prg != 0:
info["progress"] = prg

View File

@ -14,7 +14,6 @@
# limitations under the License.
#
import logging
import os
import re
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
@ -22,7 +21,6 @@ from pathlib import Path
from flask_login import current_user
from peewee import fn
from api.constants import FILE_NAME_LEN_LIMIT
from api.db import KNOWLEDGEBASE_FOLDER_NAME, FileSource, FileType, ParserType
from api.db.db_models import DB, Document, File, File2Document, Knowledgebase
from api.db.services import duplicate_name
@ -31,6 +29,7 @@ from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.utils import get_uuid
from api.utils.file_utils import filename_type, read_potential_broken_pdf, thumbnail_img
from rag.llm.cv_model import GptV4
from rag.utils.storage_factory import STORAGE_IMPL
@ -411,12 +410,7 @@ class FileService(CommonService):
err, files = [], []
for file in file_objs:
try:
MAX_FILE_NUM_PER_USER = int(os.environ.get("MAX_FILE_NUM_PER_USER", 0))
if MAX_FILE_NUM_PER_USER > 0 and DocumentService.get_doc_count(kb.tenant_id) >= MAX_FILE_NUM_PER_USER:
raise RuntimeError("Exceed the maximum file number of a free user!")
if len(file.filename.encode("utf-8")) > FILE_NAME_LEN_LIMIT:
raise RuntimeError(f"File name must be {FILE_NAME_LEN_LIMIT} bytes or less.")
DocumentService.check_doc_health(kb.tenant_id, file.filename)
filename = duplicate_name(DocumentService.query, name=file.filename, kb_id=kb.id)
filetype = filename_type(filename)
if filetype == FileType.OTHER.value:
@ -463,6 +457,19 @@ class FileService(CommonService):
@staticmethod
def parse_docs(file_objs, user_id):
exe = ThreadPoolExecutor(max_workers=12)
threads = []
for file in file_objs:
threads.append(exe.submit(FileService.parse, file.filename, file.read(), False))
res = []
for th in threads:
res.append(th.result())
return "\n\n".join(res)
@staticmethod
def parse(filename, blob, img_base64=True, tenant_id=None):
from rag.app import audio, email, naive, picture, presentation
def dummy(prog=None, msg=""):
@ -470,19 +477,12 @@ class FileService(CommonService):
FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
parser_config = {"chunk_token_num": 16096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text"}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
for file in file_objs:
kwargs = {"lang": "English", "callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": user_id}
filetype = filename_type(file.filename)
blob = file.read()
threads.append(exe.submit(FACTORY.get(FileService.get_parser(filetype, file.filename, ""), naive).chunk, file.filename, blob, **kwargs))
res = []
for th in threads:
res.append("\n".join([ck["content_with_weight"] for ck in th.result()]))
return "\n\n".join(res)
kwargs = {"lang": "English", "callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": current_user.id if current_user else tenant_id}
file_type = filename_type(filename)
if img_base64 and file_type == FileType.VISUAL.value:
return GptV4.image2base64(blob)
cks = FACTORY.get(FileService.get_parser(filename_type(filename), filename, ""), naive).chunk(filename, blob, **kwargs)
return "\n".join([ck["content_with_weight"] for ck in cks])
@staticmethod
def get_parser(doc_type, filename, default):
@ -495,3 +495,14 @@ class FileService(CommonService):
if re.search(r"\.(eml)$", filename):
return ParserType.EMAIL.value
return default
@staticmethod
def get_blob(user_id, location):
bname = f"{user_id}-downloads"
return STORAGE_IMPL.get(bname, location)
@staticmethod
def put_blob(user_id, location, blob):
bname = f"{user_id}-downloads"
return STORAGE_IMPL.put(bname, location, blob)
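
The new blob helpers store raw bytes in a per-user `<user_id>-downloads` bucket, which is what the `/upload/<canvas_id>` route uses for uploaded or crawled files. Usage sketch:

```python
from api.db.services.file_service import FileService
from api.utils import get_uuid

location = get_uuid()
FileService.put_blob("<user-id>", location, b"%PDF-1.7 ...")   # write bytes to "<user-id>-downloads"
blob = FileService.get_blob("<user-id>", location)             # read them back later
```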

View File

@ -14,6 +14,8 @@
# limitations under the License.
#
import logging
import re
from functools import partial
from langfuse import Langfuse
@ -137,7 +139,7 @@ class TenantLLMService(CommonService):
@classmethod
@DB.connection_context()
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese"):
def model_instance(cls, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
if llm_type == LLMType.EMBEDDING.value:
if model_config["llm_factory"] not in EmbeddingModel:
@ -152,12 +154,12 @@ class TenantLLMService(CommonService):
if llm_type == LLMType.IMAGE2TEXT.value:
if model_config["llm_factory"] not in CvModel:
return
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"])
return CvModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], lang, base_url=model_config["api_base"], **kwargs)
if llm_type == LLMType.CHAT.value:
if model_config["llm_factory"] not in ChatModel:
return
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"])
return ChatModel[model_config["llm_factory"]](model_config["api_key"], model_config["llm_name"], base_url=model_config["api_base"], **kwargs)
if llm_type == LLMType.SPEECH2TEXT:
if model_config["llm_factory"] not in Seq2txtModel:
@ -221,20 +223,21 @@ class TenantLLMService(CommonService):
for llm_factory in llm_factories:
for llm in llm_factory["llm"]:
if llm_id == llm["llm_name"]:
return llm["model_type"].strip(",")[-1]
return llm["model_type"].split(",")[-1]
class LLMBundle:
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese", **kwargs):
self.tenant_id = tenant_id
self.llm_type = llm_type
self.llm_name = llm_name
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang)
self.mdl = TenantLLMService.model_instance(tenant_id, llm_type, llm_name, lang=lang, **kwargs)
assert self.mdl, "Can't find model for {}/{}/{}".format(tenant_id, llm_type, llm_name)
model_config = TenantLLMService.get_model_config(tenant_id, llm_type, llm_name)
self.max_length = model_config.get("max_tokens", 8192)
self.is_tools = model_config.get("is_tools", False)
self.verbose_tool_use = kwargs.get("verbose_tool_use")
langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=tenant_id)
if langfuse_keys:
@ -331,7 +334,7 @@ class LLMBundle:
return txt
def tts(self, text):
def tts(self, text: str) -> None:
if self.langfuse:
span = self.trace.span(name="tts", input={"text": text})
@ -359,17 +362,20 @@ class LLMBundle:
return txt[last_think_end + len("</think>") :]
def chat(self, system, history, gen_conf):
def chat(self, system: str, history: list, gen_conf: dict={}, **kwargs) -> str:
if self.langfuse:
generation = self.trace.generation(name="chat", model=self.llm_name, input={"system": system, "history": history})
chat = self.mdl.chat
chat_partial = partial(self.mdl.chat, system, history, gen_conf)
if self.is_tools and self.mdl.is_tools:
chat = self.mdl.chat_with_tools
chat_partial = partial(self.mdl.chat_with_tools, system, history, gen_conf)
txt, used_tokens = chat(system, history, gen_conf)
txt, used_tokens = chat_partial(**kwargs)
txt = self._remove_reasoning_content(txt)
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
if isinstance(txt, int) and not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, used_tokens, self.llm_name):
logging.error("LLMBundle.chat can't update token usage for {}/CHAT llm_name: {}, used_tokens: {}".format(self.tenant_id, self.llm_name, used_tokens))
@ -378,17 +384,17 @@ class LLMBundle:
return txt
def chat_streamly(self, system, history, gen_conf):
def chat_streamly(self, system: str, history: list, gen_conf: dict={}, **kwargs):
if self.langfuse:
generation = self.trace.generation(name="chat_streamly", model=self.llm_name, input={"system": system, "history": history})
ans = ""
chat_streamly = self.mdl.chat_streamly
chat_partial = partial(self.mdl.chat_streamly, system, history, gen_conf)
total_tokens = 0
if self.is_tools and self.mdl.is_tools:
chat_streamly = self.mdl.chat_streamly_with_tools
chat_partial = partial(self.mdl.chat_streamly_with_tools, system, history, gen_conf)
for txt in chat_streamly(system, history, gen_conf):
for txt in chat_partial(**kwargs):
if isinstance(txt, int):
total_tokens = txt
if self.langfuse:
@ -398,8 +404,12 @@ class LLMBundle:
if txt.endswith("</think>"):
ans = ans.rstrip("</think>")
if not self.verbose_tool_use:
txt = re.sub(r"<tool_call>.*?</tool_call>", "", txt, flags=re.DOTALL)
ans += txt
yield ans
if total_tokens > 0:
if not TenantLLMService.increase_usage(self.tenant_id, self.llm_type, txt, self.llm_name):
logging.error("LLMBundle.chat_streamly can't update token usage for {}/CHAT llm_name: {}, content: {}".format(self.tenant_id, self.llm_name, txt))

View File

@ -15,6 +15,7 @@
#
import base64
import datetime
import hashlib
import io
import json
import os
@ -405,3 +406,7 @@ def download_img(url):
def delta_seconds(date_string: str):
dt = datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
return (datetime.datetime.now() - dt).total_seconds()
def hash_str2int(line:str, mod: int=10 ** 8) -> int:
return int(hashlib.sha1(line.encode("utf-8")).hexdigest(), 16) % mod
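
`hash_str2int` yields a stable, process-independent bucket id (SHA-1 based, unlike Python's salted built-in `hash`), e.g.:

```python
bucket = hash_str2int("session-42", mod=1024)   # deterministic value in [0, 1024)
```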