mirror of
https://github.com/infiniflow/ragflow.git
synced 2025-12-08 12:32:30 +08:00
Feat: add webhook component. (#11033)
### What problem does this PR solve?

#10427

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -277,6 +277,14 @@ class Canvas(Graph):
|
|||||||
for k, cpn in self.components.items():
|
for k, cpn in self.components.items():
|
||||||
self.components[k]["obj"].reset(True)
|
self.components[k]["obj"].reset(True)
|
||||||
|
|
||||||
|
if kwargs.get("webhook_payload"):
|
||||||
|
for k, cpn in self.components.items():
|
||||||
|
if self.components[k]["obj"].component_name.lower() == "webhook":
|
||||||
|
for kk, vv in kwargs["webhook_payload"].items():
|
||||||
|
self.components[k]["obj"].set_output(kk, vv)
|
||||||
|
|
||||||
|
self.components[k]["obj"].reset(True)
|
||||||
|
|
||||||
for k in kwargs.keys():
|
for k in kwargs.keys():
|
||||||
if k in ["query", "user_id", "files"] and kwargs[k]:
|
if k in ["query", "user_id", "files"] and kwargs[k]:
|
||||||
if k == "files":
|
if k == "files":
|
||||||
|
|||||||
@ -216,7 +216,7 @@ class LLM(ComponentBase):
|
|||||||
error: str = ""
|
error: str = ""
|
||||||
output_structure=None
|
output_structure=None
|
||||||
try:
|
try:
|
||||||
output_structure = None#self._param.outputs['structured']
|
output_structure = self._param.outputs['structured']
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
if output_structure:
|
if output_structure:
|
||||||
|
|||||||
38
agent/component/webhook.py
Normal file
38
agent/component/webhook.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
#
|
||||||
|
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
from agent.component.base import ComponentParamBase, ComponentBase
|
||||||
|
|
||||||
|
|
||||||
|
class WebhookParam(ComponentParamBase):
    """
    Define the Webhook component parameters.

    The webhook is an entry-point component: it declares no fixed
    parameters of its own, and its input form is simply whatever was
    captured on the component's ``inputs`` attribute.
    """

    def __init__(self):
        super().__init__()

    def get_input_form(self) -> dict[str, dict]:
        # No static schema — expose the dynamically recorded inputs.
        return getattr(self, "inputs")
|
||||||
|
|
||||||
|
|
||||||
|
class Webhook(ComponentBase):
    # Entry-point component. Its outputs are injected from the HTTP payload
    # by the canvas before the run starts, so invocation itself does nothing.
    component_name = "Webhook"

    def _invoke(self, **kwargs):
        """No-op: outputs are pre-populated from the webhook payload."""

    def thoughts(self) -> str:
        """Return an empty string — this component produces no reasoning trace."""
        return ""
|
||||||
@ -46,8 +46,8 @@ def set_connector():
|
|||||||
"status": TaskStatus.SCHEDULE
|
"status": TaskStatus.SCHEDULE
|
||||||
}
|
}
|
||||||
conn["status"] = TaskStatus.SCHEDULE
|
conn["status"] = TaskStatus.SCHEDULE
|
||||||
|
ConnectorService.save(**conn)
|
||||||
|
|
||||||
ConnectorService.save(**conn)
|
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
e, conn = ConnectorService.get_by_id(req["id"])
|
e, conn = ConnectorService.get_by_id(req["id"])
|
||||||
|
|
||||||
|
|||||||
@ -104,6 +104,10 @@ def update():
|
|||||||
message="Duplicated knowledgebase name.")
|
message="Duplicated knowledgebase name.")
|
||||||
|
|
||||||
del req["kb_id"]
|
del req["kb_id"]
|
||||||
|
connectors = []
|
||||||
|
if "connectors" in req:
|
||||||
|
connectors = req["connectors"]
|
||||||
|
del req["connectors"]
|
||||||
if not KnowledgebaseService.update_by_id(kb.id, req):
|
if not KnowledgebaseService.update_by_id(kb.id, req):
|
||||||
return get_data_error_result()
|
return get_data_error_result()
|
||||||
|
|
||||||
@ -120,6 +124,10 @@ def update():
|
|||||||
if not e:
|
if not e:
|
||||||
return get_data_error_result(
|
return get_data_error_result(
|
||||||
message="Database error (Knowledgebase rename)!")
|
message="Database error (Knowledgebase rename)!")
|
||||||
|
if connectors:
|
||||||
|
errors = Connector2KbService.link_connectors(kb.id, [conn["id"] for conn in connectors], current_user.id)
|
||||||
|
if errors:
|
||||||
|
logging.error("Link KB errors: ", errors)
|
||||||
kb = kb.to_dict()
|
kb = kb.to_dict()
|
||||||
kb.update(req)
|
kb.update(req)
|
||||||
|
|
||||||
@ -892,13 +900,3 @@ def check_embedding():
|
|||||||
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
|
return get_json_result(code=RetCode.NOT_EFFECTIVE, message="failed", data={"summary": summary, "results": results})
|
||||||
|
|
||||||
|
|
||||||
@manager.route("/<kb_id>/link", methods=["POST"]) # noqa: F821
|
|
||||||
@validate_request("connector_ids")
|
|
||||||
@login_required
|
|
||||||
def link_connector(kb_id):
|
|
||||||
req = request.json
|
|
||||||
errors = Connector2KbService.link_connectors(kb_id, req["connector_ids"], current_user.id)
|
|
||||||
if errors:
|
|
||||||
return get_json_result(data=False, message=errors, code=RetCode.SERVER_ERROR)
|
|
||||||
return get_json_result(data=True)
|
|
||||||
|
|
||||||
|
|||||||
@ -15,15 +15,19 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from agent.canvas import Canvas
|
||||||
|
from api.db import CanvasCategory
|
||||||
from api.db.services.canvas_service import UserCanvasService
|
from api.db.services.canvas_service import UserCanvasService
|
||||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||||
from common.constants import RetCode
|
from common.constants import RetCode
|
||||||
from common.misc_utils import get_uuid
|
from common.misc_utils import get_uuid
|
||||||
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required
|
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, token_required
|
||||||
from api.utils.api_utils import get_result
|
from api.utils.api_utils import get_result
|
||||||
from flask import request
|
from flask import request, Response
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/agents', methods=['GET']) # noqa: F821
|
@manager.route('/agents', methods=['GET']) # noqa: F821
|
||||||
@ -127,3 +131,49 @@ def delete_agent(tenant_id: str, agent_id: str):
|
|||||||
|
|
||||||
UserCanvasService.delete_by_id(agent_id)
|
UserCanvasService.delete_by_id(agent_id)
|
||||||
return get_json_result(data=True)
|
return get_json_result(data=True)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/webhook/<agent_id>', methods=['POST'])  # noqa: F821
@token_required
def webhook(tenant_id: str, agent_id: str):
    """Trigger the agent identified by ``agent_id`` with the posted payload.

    The JSON body is forwarded to the canvas as ``webhook_payload`` so the
    Webhook component can expose it as its outputs; the run is streamed
    back to the caller as server-sent events.
    """
    # The payload is caller-controlled and may be empty or absent.
    req = request.json or {}
    # Authorize and load by the agent id from the URL, not a body field:
    # requiring req["id"] would raise KeyError on arbitrary payloads and
    # could save post-run state onto a different record than the one run.
    if not UserCanvasService.accessible(agent_id, tenant_id):
        return get_json_result(
            data=False, message='Only owner of canvas authorized for this operation.',
            code=RetCode.OPERATING_ERROR)

    e, cvs = UserCanvasService.get_by_id(agent_id)
    if not e:
        return get_data_error_result(message="canvas not found.")

    # Canvas expects the DSL as a JSON string.
    if not isinstance(cvs.dsl, str):
        cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)

    if cvs.canvas_category == CanvasCategory.DataFlow:
        return get_data_error_result(message="Dataflow can not be triggered by webhook.")

    try:
        canvas = Canvas(cvs.dsl, tenant_id, agent_id)
    except Exception as e:
        return get_json_result(
            data=False, message=str(e),
            code=RetCode.EXCEPTION_ERROR)

    def sse():
        nonlocal canvas
        try:
            for ans in canvas.run(query=req.get("query", ""), files=req.get("files", []), user_id=req.get("user_id", tenant_id), webhook_payload=req):
                yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"

            # Persist the post-run canvas state back onto the agent record.
            cvs.dsl = json.loads(str(canvas))
            UserCanvasService.update_by_id(agent_id, cvs.to_dict())
        except Exception as e:
            logging.exception(e)
            yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"

    resp = Response(sse(), mimetype="text/event-stream")
    resp.headers.add_header("Cache-control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    # Disable proxy buffering so events reach the client immediately.
    resp.headers.add_header("X-Accel-Buffering", "no")
    resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
    return resp
|
||||||
|
|||||||
@ -39,10 +39,14 @@ class ConnectorService(CommonService):
|
|||||||
if not task:
|
if not task:
|
||||||
if status == TaskStatus.SCHEDULE:
|
if status == TaskStatus.SCHEDULE:
|
||||||
SyncLogsService.schedule(connector_id, c2k.kb_id)
|
SyncLogsService.schedule(connector_id, c2k.kb_id)
|
||||||
|
ConnectorService.update_by_id(connector_id, {"status": status})
|
||||||
|
return
|
||||||
|
|
||||||
if task.status == TaskStatus.DONE:
|
if task.status == TaskStatus.DONE:
|
||||||
if status == TaskStatus.SCHEDULE:
|
if status == TaskStatus.SCHEDULE:
|
||||||
SyncLogsService.schedule(connector_id, c2k.kb_id, task.poll_range_end, total_docs_indexed=task.total_docs_indexed)
|
SyncLogsService.schedule(connector_id, c2k.kb_id, task.poll_range_end, total_docs_indexed=task.total_docs_indexed)
|
||||||
|
ConnectorService.update_by_id(connector_id, {"status": status})
|
||||||
|
return
|
||||||
|
|
||||||
task = task.to_dict()
|
task = task.to_dict()
|
||||||
task["status"] = status
|
task["status"] = status
|
||||||
@ -72,16 +76,19 @@ class SyncLogsService(CommonService):
|
|||||||
cls.model.id,
|
cls.model.id,
|
||||||
cls.model.connector_id,
|
cls.model.connector_id,
|
||||||
cls.model.kb_id,
|
cls.model.kb_id,
|
||||||
|
cls.model.update_date,
|
||||||
cls.model.poll_range_start,
|
cls.model.poll_range_start,
|
||||||
cls.model.poll_range_end,
|
cls.model.poll_range_end,
|
||||||
cls.model.new_docs_indexed,
|
cls.model.new_docs_indexed,
|
||||||
cls.model.error_msg,
|
cls.model.total_docs_indexed,
|
||||||
|
cls.model.full_exception_trace,
|
||||||
cls.model.error_count,
|
cls.model.error_count,
|
||||||
Connector.name,
|
Connector.name,
|
||||||
Connector.source,
|
Connector.source,
|
||||||
Connector.tenant_id,
|
Connector.tenant_id,
|
||||||
Connector.timeout_secs,
|
Connector.timeout_secs,
|
||||||
Knowledgebase.name.alias("kb_name"),
|
Knowledgebase.name.alias("kb_name"),
|
||||||
|
Knowledgebase.avatar.alias("kb_avatar"),
|
||||||
cls.model.from_beginning.alias("reindex"),
|
cls.model.from_beginning.alias("reindex"),
|
||||||
cls.model.status
|
cls.model.status
|
||||||
]
|
]
|
||||||
@ -128,7 +135,7 @@ class SyncLogsService(CommonService):
|
|||||||
logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
|
logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.")
|
||||||
return None
|
return None
|
||||||
reindex = "1" if reindex else "0"
|
reindex = "1" if reindex else "0"
|
||||||
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL})
|
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDULE})
|
||||||
return cls.save(**{
|
return cls.save(**{
|
||||||
"id": get_uuid(),
|
"id": get_uuid(),
|
||||||
"kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id,
|
"kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id,
|
||||||
@ -145,7 +152,7 @@ class SyncLogsService(CommonService):
|
|||||||
full_exception_trace=cls.model.full_exception_trace + str(e)
|
full_exception_trace=cls.model.full_exception_trace + str(e)
|
||||||
) \
|
) \
|
||||||
.where(cls.model.id == task.id).execute()
|
.where(cls.model.id == task.id).execute()
|
||||||
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDUL})
|
ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDULE})
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0):
|
def increase_docs(cls, id, min_update, max_update, doc_num, err_msg="", error_count=0):
|
||||||
|
|||||||
@ -115,7 +115,7 @@ export default {
|
|||||||
similarityThreshold: '相似度閾值',
|
similarityThreshold: '相似度閾值',
|
||||||
similarityThresholdTip:
|
similarityThresholdTip:
|
||||||
'我們使用混合相似度得分來評估兩行文本之間的距離。它是加權關鍵詞相似度和向量餘弦相似度。如果查詢和塊之間的相似度小於此閾值,則該塊將被過濾掉。預設值設定為 0.2,也就是說,文本塊的混合相似度得分至少要 20 才會被檢索。',
|
'我們使用混合相似度得分來評估兩行文本之間的距離。它是加權關鍵詞相似度和向量餘弦相似度。如果查詢和塊之間的相似度小於此閾值,則該塊將被過濾掉。預設值設定為 0.2,也就是說,文本塊的混合相似度得分至少要 20 才會被檢索。',
|
||||||
vectorSimilarityWeight: '關鍵字相似度權重',
|
vectorSimilarityWeight: '矢量相似度權重',
|
||||||
vectorSimilarityWeightTip:
|
vectorSimilarityWeightTip:
|
||||||
'我們使用混合相似性評分來評估兩行文本之間的距離。它是加權關鍵字相似性和矢量餘弦相似性或rerank得分(0〜1)。兩個權重的總和為1.0。',
|
'我們使用混合相似性評分來評估兩行文本之間的距離。它是加權關鍵字相似性和矢量餘弦相似性或rerank得分(0〜1)。兩個權重的總和為1.0。',
|
||||||
testText: '測試文本',
|
testText: '測試文本',
|
||||||
|
|||||||
Reference in New Issue
Block a user