Feat: support verify when setting LLM keys, and boost bigrams. (#12980)

#12863

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
Author: Kevin Hu (committed by GitHub)
Date: 2026-02-05 19:19:09 +08:00
Parent: bbd8ba64a1
Commit: 1262533b74
5 changed files with 175 additions and 19 deletions


@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import inspect
import json
import logging
@@ -46,6 +47,7 @@ from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
from common import settings
from api.apps import login_required, current_user
from api.db.services.canvas_service import completion as agent_completion


@manager.route('/templates', methods=['GET'])  # noqa: F821
@@ -184,6 +186,50 @@ async def run():
    return resp


@manager.route("/<canvas_id>/completion", methods=["POST"])  # noqa: F821
@login_required
async def exp_agent_completion(canvas_id):
    tenant_id = current_user.id
    req = await get_request_json()
    return_trace = bool(req.get("return_trace", False))

    async def generate():
        trace_items = []
        async for answer in agent_completion(tenant_id=tenant_id, agent_id=canvas_id, **req):
            if isinstance(answer, str):
                try:
                    ans = json.loads(answer[5:])  # remove "data:"
                except Exception:
                    continue
                event = ans.get("event")
                if event == "node_finished":
                    if return_trace:
                        data = ans.get("data", {})
                        trace_items.append(
                            {
                                "component_id": data.get("component_id"),
                                "trace": [copy.deepcopy(data)],
                            }
                        )
                        ans.setdefault("data", {})["trace"] = trace_items
                        answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
                        yield answer
                if event not in ["message", "message_end"]:
                    continue
            yield answer
        yield "data:[DONE]\n\n"

    resp = Response(generate(), mimetype="text/event-stream")
    resp.headers.add_header("Cache-control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")
    resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
    return resp
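For reference, a minimal client sketch for the new `/<canvas_id>/completion` SSE endpoint. The base URL, auth handling, and the `query` field are assumptions; only `return_trace` and the `data:`/`[DONE]` framing come from the handler above.

```python
import json

import requests  # assumed HTTP client; any SSE-capable client works

BASE_URL = "http://localhost:9380/v1/canvas"  # hypothetical deployment URL
CANVAS_ID = "your-canvas-id"

# login_required: reuse an authenticated session cookie or token here.
with requests.post(
    f"{BASE_URL}/{CANVAS_ID}/completion",
    json={"query": "Hello!", "return_trace": True},
    stream=True,
    timeout=60,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        payload = line[len("data:"):]
        if payload == "[DONE]":
            break
        event = json.loads(payload)
        if event.get("event") == "node_finished":
            # With return_trace, the accumulated trace rides along on
            # every node_finished event.
            print("trace:", event.get("data", {}).get("trace"))
```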
@manager.route('/rerun', methods=['POST']) # noqa: F821
@validate_request("id", "dsl", "component_id")
@login_required
@@ -532,20 +578,65 @@ def sessions(canvas_id):
    from_date = request.args.get("from_date")
    to_date = request.args.get("to_date")
    orderby = request.args.get("orderby", "update_time")
    exp_user_id = request.args.get("exp_user_id")
    if request.args.get("desc") == "False" or request.args.get("desc") == "false":
        desc = False
    else:
        desc = True
    if exp_user_id:
        sess = API4ConversationService.get_names(canvas_id, exp_user_id)
        return get_json_result(data={"total": len(sess), "sessions": sess})
    # dsl defaults to True in all cases except for "False" and "false"
    include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
    total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc,
                                                   None, user_id, include_dsl, keywords, from_date, to_date, exp_user_id=exp_user_id)
    try:
        return get_json_result(data={"total": total, "sessions": sess})
    except Exception as e:
        return server_error_response(e)
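A request sketch for the new filter (host and path prefix are assumptions; `exp_user_id` is the query parameter added above):

```python
import requests  # assumed client; auth omitted

# Passing exp_user_id switches the endpoint to the lightweight name-only
# listing served by API4ConversationService.get_names.
resp = requests.get(
    "http://localhost:9380/v1/canvas/<canvas_id>/sessions",  # hypothetical URL
    params={"exp_user_id": "<user_id>"},
)
print(resp.json()["data"])  # {"total": N, "sessions": [...]}
```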

@manager.route('/<canvas_id>/sessions', methods=['PUT'])  # noqa: F821
@login_required
async def set_session(canvas_id):
    req = await get_request_json()
    tenant_id = current_user.id
    e, cvs = UserCanvasService.get_by_id(canvas_id)
    assert e, "Agent not found."
    if not isinstance(cvs.dsl, str):
        cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
    session_id = get_uuid()
    canvas = Canvas(cvs.dsl, tenant_id, canvas_id=cvs.id)
    canvas.reset()
    conv = {
        "id": session_id,
        "name": req.get("name", ""),
        "dialog_id": cvs.id,
        "user_id": tenant_id,
        "exp_user_id": tenant_id,
        "message": [],
        "source": "agent",
        "dsl": cvs.dsl,
        "reference": []
    }
    API4ConversationService.save(**conv)
    return get_json_result(data=conv)
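A sketch of creating a session through the new PUT route (URL prefix and auth are assumptions; the `name` field and response shape are from the handler above):

```python
import requests  # assumed client; auth omitted

resp = requests.put(
    "http://localhost:9380/v1/canvas/<canvas_id>/sessions",  # hypothetical URL
    json={"name": "experiment-1"},
)
session = resp.json()["data"]
# The handler seeds the session with the canvas DSL and an empty message list.
print(session["id"], session["exp_user_id"])
```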

@manager.route('/<canvas_id>/sessions/<session_id>', methods=['GET'])  # noqa: F821
@login_required
def get_session(canvas_id, session_id):
    tenant_id = current_user.id
    if not UserCanvasService.accessible(canvas_id, tenant_id):
        return get_json_result(
            data=False, message='Only owner of canvas authorized for this operation.',
            code=RetCode.OPERATING_ERROR)
    e, conv = API4ConversationService.get_by_id(session_id)
    assert e, "Session not found."
    return get_json_result(data=conv.to_dict())
@manager.route('/prompts', methods=['GET']) # noqa: F821
@login_required
def prompts():


@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import logging
import json
import os
@@ -64,13 +65,17 @@ async def set_api_key():
    chat_passed, embd_passed, rerank_passed = False, False, False
    factory = req["llm_factory"]
    extra = {"provider": factory}
    timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
    msg = ""
    for llm in LLMService.query(fid=factory):
        if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
            assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
            mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
            try:
                arr, tc = await asyncio.wait_for(
                    asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
                    timeout=timeout_seconds,
                )
                if len(arr[0]) == 0:
                    raise Exception("Fail")
                embd_passed = True
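All of the verification calls in this file share one timeout pattern: a blocking SDK call is pushed onto a worker thread with `asyncio.to_thread` and bounded with `asyncio.wait_for`. A standalone sketch of that pattern, with a hypothetical `slow_probe` standing in for calls like `mdl.encode`:

```python
import asyncio
import os
import time


def slow_probe() -> bool:
    # Hypothetical stand-in for a blocking SDK call such as mdl.encode().
    time.sleep(2)
    return True


async def probe_with_timeout() -> bool:
    timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
    try:
        # to_thread keeps the event loop responsive; wait_for enforces the deadline.
        return await asyncio.wait_for(asyncio.to_thread(slow_probe), timeout=timeout_seconds)
    except asyncio.TimeoutError:
        return False


print(asyncio.run(probe_with_timeout()))  # True: the probe finishes inside the limit
```

One caveat worth knowing: on timeout only the `await` is abandoned; the worker thread keeps running to completion in the background, so the probed calls should be side-effect free.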
@@ -80,17 +85,27 @@ async def set_api_key():
            assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
            mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
            try:
                m, tc = await asyncio.wait_for(
                    mdl.async_chat(
                        None,
                        [{"role": "user", "content": "Hello! How are you doing!"}],
                        {"temperature": 0.9, "max_tokens": 50},
                    ),
                    timeout=timeout_seconds,
                )
                if m.find("**ERROR**") >= 0:
                    raise Exception(m)
                chat_passed = True
            except Exception as e:
                msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
        elif not rerank_passed and llm.model_type == LLMType.RERANK.value:
            assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
            mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
            try:
                arr, tc = await asyncio.wait_for(
                    asyncio.to_thread(mdl.similarity, "What's the weather?", ["Is it sunny today?"]),
                    timeout=timeout_seconds,
                )
                if len(arr) == 0 or tc == 0:
                    raise Exception("Fail")
                rerank_passed = True
@@ -101,6 +116,9 @@ async def set_api_key():
            msg = ""
            break
    if req.get("verify", False):
        return get_json_result(data={"message": msg, "success": len(msg.strip()) == 0})
    if msg:
        return get_data_error_result(message=msg)
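With the new `verify` flag the endpoint becomes a dry run: it returns the probe outcome immediately instead of treating a failed check as an error. A request sketch (the URL prefix and factory name are assumptions; `verify` and the response shape are from the code above):

```python
import requests  # assumed client; the authenticated session cookie is omitted

resp = requests.post(
    "http://localhost:9380/v1/llm/set_api_key",  # hypothetical URL prefix
    json={
        "llm_factory": "OpenAI",
        "api_key": "sk-...",
        "verify": True,  # report {"message": ..., "success": ...} without erroring out
    },
)
print(resp.json()["data"])  # e.g. {"message": "", "success": True}
```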
@@ -133,6 +151,7 @@ async def add_llm():
    factory = req["llm_factory"]
    api_key = req.get("api_key", "x")
    llm_name = req.get("llm_name")
    timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
    if factory not in [f.name for f in get_allowed_llm_factories()]:
        return get_data_error_result(message=f"LLM factory {factory} is not allowed")
@@ -215,7 +234,10 @@ async def add_llm():
            assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
            mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
            try:
                arr, tc = await asyncio.wait_for(
                    asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
                    timeout=timeout_seconds,
                )
                if len(arr[0]) == 0:
                    raise Exception("Fail")
            except Exception as e:
@@ -229,7 +251,14 @@ async def add_llm():
                **extra,
            )
            try:
                m, tc = await asyncio.wait_for(
                    mdl.async_chat(
                        None,
                        [{"role": "user", "content": "Hello! How are you doing!"}],
                        {"temperature": 0.9},
                    ),
                    timeout=timeout_seconds,
                )
                if not tc and m.find("**ERROR**:") >= 0:
                    raise Exception(m)
            except Exception as e:
@@ -239,7 +268,10 @@ async def add_llm():
            assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
            try:
                mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
                arr, tc = await asyncio.wait_for(
                    asyncio.to_thread(mdl.similarity, "Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]),
                    timeout=timeout_seconds,
                )
                if len(arr) == 0:
                    raise Exception("Not known.")
            except KeyError:
@@ -252,7 +284,10 @@ async def add_llm():
            mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
            try:
                image_data = test_image
                m, tc = await asyncio.wait_for(
                    asyncio.to_thread(mdl.describe, image_data),
                    timeout=timeout_seconds,
                )
                if not tc and m.find("**ERROR**:") >= 0:
                    raise Exception(m)
            except Exception as e:
@@ -261,20 +296,29 @@ async def add_llm():
            assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
            mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
            try:
                def drain_tts():
                    # Consume the audio stream; only success/failure matters here.
                    for _ in mdl.tts("Hello~ RAGFlower!"):
                        pass

                await asyncio.wait_for(
                    asyncio.to_thread(drain_tts),
                    timeout=timeout_seconds,
                )
            except RuntimeError as e:
                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
        case LLMType.OCR.value:
            assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
            try:
                mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
                ok, reason = await asyncio.wait_for(
                    asyncio.to_thread(mdl.check_available),
                    timeout=timeout_seconds,
                )
                if not ok:
                    raise RuntimeError(reason or "Model not available")
            except Exception as e:
                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
        case LLMType.SPEECH2TEXT.value:
            assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
            try:
                mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
@@ -284,6 +328,9 @@ async def add_llm():
        case _:
            raise RuntimeError(f"Unknown model type: {model_type}")
    if req.get("verify", False):
        return get_json_result(data={"message": msg, "success": len(msg.strip()) == 0})
    if msg:
        return get_data_error_result(message=msg)
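`add_llm` gets the same treatment: with `verify` set, the handler returns the probe outcome instead of proceeding. A sketch (URL, model name, and field values are assumptions; `verify` and the response shape are from the code above):

```python
import requests  # assumed client; auth omitted

resp = requests.post(
    "http://localhost:9380/v1/llm/add_llm",  # hypothetical URL prefix
    json={
        "llm_factory": "OpenAI",
        "llm_name": "gpt-4o-mini",  # hypothetical model name
        "model_type": "chat",
        "api_key": "sk-...",
        "verify": True,  # probe within LLM_TIMEOUT_SECONDS, report success/failure
    },
)
print(resp.json()["data"])  # {"message": "", "success": True} when the probe passes
```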