diff --git a/agent/canvas.py b/agent/canvas.py
index 003c993c2..c971409e3 100644
--- a/agent/canvas.py
+++ b/agent/canvas.py
@@ -277,6 +277,14 @@ class Canvas(Graph):
         for k, cpn in self.components.items():
             self.components[k]["obj"].reset(True)
 
+        if kwargs.get("webhook_payload"):
+            for k, cpn in self.components.items():
+                if self.components[k]["obj"].component_name.lower() == "webhook":
+                    for kk, vv in kwargs["webhook_payload"].items():
+                        self.components[k]["obj"].set_output(kk, vv)
+
+            self.components[k]["obj"].reset(True)
+
         for k in kwargs.keys():
             if k in ["query", "user_id", "files"] and kwargs[k]:
                 if k == "files":
diff --git a/agent/component/llm.py b/agent/component/llm.py
index d2ed1514d..c8383835b 100644
--- a/agent/component/llm.py
+++ b/agent/component/llm.py
@@ -216,7 +216,7 @@ class LLM(ComponentBase):
         error: str = ""
         output_structure=None
         try:
-            output_structure = None#self._param.outputs['structured']
+            output_structure = self._param.outputs['structured']
         except Exception:
             pass
         if output_structure:
diff --git a/agent/component/webhook.py b/agent/component/webhook.py
new file mode 100644
index 000000000..c707d4556
--- /dev/null
+++ b/agent/component/webhook.py
@@ -0,0 +1,38 @@
+#
+#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+from agent.component.base import ComponentParamBase, ComponentBase
+
+
+class WebhookParam(ComponentParamBase):
+
+    """
+    Define the Webhook component parameters.
+    """
+    def __init__(self):
+        super().__init__()
+
+    def get_input_form(self) -> dict[str, dict]:
+        return getattr(self, "inputs")
+
+
+class Webhook(ComponentBase):
+    component_name = "Webhook"
+
+    def _invoke(self, **kwargs):
+        pass
+
+    def thoughts(self) -> str:
+        return ""
diff --git a/api/apps/connector_app.py b/api/apps/connector_app.py
index f463c7f39..d6f756fef 100644
--- a/api/apps/connector_app.py
+++ b/api/apps/connector_app.py
@@ -46,8 +46,8 @@ def set_connector():
             "status": TaskStatus.SCHEDULE
         }
         conn["status"] = TaskStatus.SCHEDULE
+        ConnectorService.save(**conn)
 
-        ConnectorService.save(**conn)
         time.sleep(1)
         e, conn = ConnectorService.get_by_id(req["id"])
 
diff --git a/api/apps/kb_app.py b/api/apps/kb_app.py
index 28c2fa31e..ac0afaa48 100644
--- a/api/apps/kb_app.py
+++ b/api/apps/kb_app.py
@@ -104,6 +104,10 @@ def update():
                 message="Duplicated knowledgebase name.")
 
         del req["kb_id"]
+        connectors = []
+        if "connectors" in req:
+            connectors = req["connectors"]
+            del req["connectors"]
         if not KnowledgebaseService.update_by_id(kb.id, req):
             return get_data_error_result()
 
@@ -120,6 +124,10 @@ def update():
             if not e:
                 return get_data_error_result(
                     message="Database error (Knowledgebase rename)!")
+        if connectors:
+            errors = Connector2KbService.link_connectors(kb.id, [conn["id"] for conn in connectors], current_user.id)
+            if errors:
+                logging.error("Link KB errors: %s", errors)
 
         kb = kb.to_dict()
         kb.update(req)
@@ -892,13 +900,3 @@ def check_embedding():
     return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
 
 
-@manager.route("/<kb_id>/link", methods=["POST"])  # noqa: F821
-@validate_request("connector_ids")
-@login_required
-def link_connector(kb_id):
-    req = request.json
-    errors = Connector2KbService.link_connectors(kb_id, req["connector_ids"], current_user.id)
-    if errors:
-        return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
-    return get_json_result(data=True)
-
diff --git a/api/apps/sdk/agent.py b/api/apps/sdk/agent.py
index d0c19de7b..208b7a1be 100644
--- a/api/apps/sdk/agent.py
+++ b/api/apps/sdk/agent.py
@@ -15,15 +15,19 @@
 #
 
 import json
+import logging
 import time
 from typing import Any, cast
+
+from agent.canvas import Canvas
+from api.db import CanvasCategory
 from api.db.services.canvas_service import UserCanvasService
 from api.db.services.user_canvas_version import UserCanvasVersionService
 from common.constants import RetCode
 from common.misc_utils import get_uuid
 from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required
 from api.utils.api_utils import get_result
-from flask import request
+from flask import request, Response
 
 
 @manager.route('/agents', methods=['GET'])  # noqa: F821
@@ -127,3 +131,49 @@ def delete_agent(tenant_id: str, agent_id: str):
     UserCanvasService.delete_by_id(agent_id)
 
     return get_json_result(data=True)
+
+
+@manager.route('/webhook/<agent_id>', methods=['POST'])  # noqa: F821
+@token_required
+def webhook(tenant_id: str, agent_id: str):
+    req = request.json
+    if not UserCanvasService.accessible(req["id"], tenant_id):
+        return get_json_result(
+            data=False, message='Only owner of canvas authorized for this operation.',
+            code=RetCode.OPERATING_ERROR)
+
+    e, cvs = UserCanvasService.get_by_id(req["id"])
+    if not e:
+        return get_data_error_result(message="canvas not found.")
+
+    if not isinstance(cvs.dsl, str):
+        cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
+
+    if cvs.canvas_category == CanvasCategory.DataFlow:
+        return get_data_error_result(message="Dataflow can not be triggered by webhook.")
+
+    try:
+        canvas = Canvas(cvs.dsl, tenant_id, agent_id)
+    except Exception as e:
+        return get_json_result(
+            data=False, message=str(e),
+            code=RetCode.EXCEPTION_ERROR)
+
+    def sse():
+        nonlocal canvas
+        try:
+            for ans in canvas.run(query=req.get("query", ""), files=req.get("files", []), user_id=req.get("user_id", tenant_id), webhook_payload=req):
+                yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
+
+            cvs.dsl = json.loads(str(canvas))
+            UserCanvasService.update_by_id(req["id"], cvs.to_dict())
+        except Exception as e:
+            logging.exception(e)
+            yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
+
+    resp = Response(sse(), mimetype="text/event-stream")
+    resp.headers.add_header("Cache-control", "no-cache")
+    resp.headers.add_header("Connection", "keep-alive")
+    resp.headers.add_header("X-Accel-Buffering", "no")
+    resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
+    return resp
diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py
index 3ba00596a..bc54dc34b 100644
--- a/api/db/services/connector_service.py
+++ b/api/db/services/connector_service.py
@@ -39,10 +39,14 @@ class ConnectorService(CommonService):
         if not task:
             if status == TaskStatus.SCHEDULE:
                 SyncLogsService.schedule(connector_id, c2k.kb_id)
+            ConnectorService.update_by_id(connector_id, {"status": status})
+            return
 
         if task.status == TaskStatus.DONE:
             if status == TaskStatus.SCHEDULE:
                 SyncLogsService.schedule(connector_id, c2k.kb_id, task.poll_range_end, total_docs_indexed=task.total_docs_indexed)
+            ConnectorService.update_by_id(connector_id, {"status": status})
+            return
 
         task = task.to_dict()
         task["status"] = status
@@ -72,16 +76,19 @@ class SyncLogsService(CommonService):
             cls.model.id,
             cls.model.connector_id,
             cls.model.kb_id,
+            cls.model.update_date,
             cls.model.poll_range_start,
             cls.model.poll_range_end,
             cls.model.new_docs_indexed,
-            cls.model.error_msg,
+            cls.model.total_docs_indexed,
+            cls.model.full_exception_trace,
             cls.model.error_count,
             Connector.name,
             Connector.source,
             Connector.tenant_id,
             Connector.timeout_secs,
             Knowledgebase.name.alias("kb_name"),
+            Knowledgebase.avatar.alias("kb_avatar"),
             cls.model.from_beginning.alias("reindex"),
             cls.model.status
         ]
@@ -128,7 +135,7 @@ class SyncLogsService(CommonService):
             logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
            return None
         reindex = "1" if reindex else "0"
-        ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL})
+        ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDULE})
         return cls.save(**{
             "id": get_uuid(), "kb_id": kb_id, "status": TaskStatus.SCHEDULE,
             "connector_id": connector_id,
@@ -145,7 +152,7 @@ class SyncLogsService(CommonService):
                 full_exception_trace=cls.model.full_exception_trace + str(e)
             ) \
             .where(cls.model.id == task.id).execute()
-        ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL})
+        ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDULE})
 
     @classmethod
     def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0):
diff --git a/web/src/locales/zh-traditional.ts b/web/src/locales/zh-traditional.ts
index 076cd4cb4..c946ea9f8 100644
--- a/web/src/locales/zh-traditional.ts
+++ b/web/src/locales/zh-traditional.ts
@@ -115,7 +115,7 @@ export default {
       similarityThreshold: '相似度閾值',
       similarityThresholdTip:
         '我們使用混合相似度得分來評估兩行文本之間的距離。它是加權關鍵詞相似度和向量餘弦相似度。如果查詢和塊之間的相似度小於此閾值,則該塊將被過濾掉。預設值設定為 0.2,也就是說,文本塊的混合相似度得分至少要 20 才會被檢索。',
-      vectorSimilarityWeight: '關鍵字相似度權重',
+      vectorSimilarityWeight: '矢量相似度權重',
       vectorSimilarityWeightTip:
         '我們使用混合相似性評分來評估兩行文本之間的距離。它是加權關鍵字相似性和矢量餘弦相似性或rerank得分(0〜1)。兩個權重的總和為1.0。',
       testText: '測試文本',
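
Usage note (illustrative, not part of the patch): a minimal client-side sketch of how the new SSE webhook endpoint added in api/apps/sdk/agent.py could be invoked. It assumes the SDK blueprint is mounted under /api/v1 and that @token_required expects a Bearer API key; BASE_URL, API_KEY, AGENT_ID and the order_no payload key are placeholders rather than values taken from this patch. The body fields mirror the handler above: "id" selects the canvas, "query"/"files"/"user_id" are forwarded to Canvas.run(), and the whole JSON body is also passed as webhook_payload, so extra keys surface as outputs on the Webhook component.

import json

import requests  # third-party HTTP client, used here only for the sketch

BASE_URL = "http://localhost:9380/api/v1"  # assumption: where the SDK blueprint is served
API_KEY = "YOUR_RAGFLOW_API_KEY"           # assumption: token checked by @token_required
AGENT_ID = "YOUR_AGENT_ID"                 # canvas id; also sent as "id" in the body

payload = {
    "id": AGENT_ID,          # the handler loads the canvas via req["id"]
    "query": "hello",        # forwarded to Canvas.run(query=...)
    "order_no": "SO-12345",  # hypothetical extra key exposed to the Webhook component via webhook_payload
}

with requests.post(
    f"{BASE_URL}/webhook/{AGENT_ID}",
    headers={"Authorization": f"Bearer {API_KEY}"},
    json=payload,
    stream=True,
) as resp:
    # The endpoint streams server-sent events; each event line starts with "data:".
    for line in resp.iter_lines(decode_unicode=True):
        if line.startswith("data:"):
            print(json.loads(line[len("data:"):]))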