From 1262533b749535ad58260ff6f7eb0140e1a73978 Mon Sep 17 00:00:00 2001 From: Kevin Hu Date: Thu, 5 Feb 2026 19:19:09 +0800 Subject: [PATCH] Feat: support verify to set llm key and boost bigrams. (#12980) #12863 ### Type of change - [x] New Feature (non-breaking change which adds functionality) --- api/apps/canvas_app.py | 93 +++++++++++++++++++++++++++++++++- api/apps/llm_app.py | 71 +++++++++++++++++++++----- api/db/db_models.py | 4 ++ api/db/services/api_service.py | 17 ++++++- rag/nlp/query.py | 9 ++-- 5 files changed, 175 insertions(+), 19 deletions(-) diff --git a/api/apps/canvas_app.py b/api/apps/canvas_app.py index 228498e50..0d0b71b45 100644 --- a/api/apps/canvas_app.py +++ b/api/apps/canvas_app.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import copy import inspect import json import logging @@ -46,6 +47,7 @@ from rag.nlp import search from rag.utils.redis_conn import REDIS_CONN from common import settings from api.apps import login_required, current_user +from api.db.services.canvas_service import completion as agent_completion @manager.route('/templates', methods=['GET']) # noqa: F821 @@ -184,6 +186,50 @@ async def run(): return resp +@manager.route("//completion", methods=["POST"]) # noqa: F821 +@login_required +async def exp_agent_completion(canvas_id): + tenant_id = current_user.id + req = await get_request_json() + return_trace = bool(req.get("return_trace", False)) + async def generate(): + trace_items = [] + async for answer in agent_completion(tenant_id=tenant_id, agent_id=canvas_id, **req): + if isinstance(answer, str): + try: + ans = json.loads(answer[5:]) # remove "data:" + except Exception: + continue + + event = ans.get("event") + if event == "node_finished": + if return_trace: + data = ans.get("data", {}) + trace_items.append( + { + "component_id": data.get("component_id"), + "trace": [copy.deepcopy(data)], + } + ) + ans.setdefault("data", {})["trace"] = trace_items + answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n" + yield answer + + if event not in ["message", "message_end"]: + continue + + yield answer + + yield "data:[DONE]\n\n" + + resp = Response(generate(), mimetype="text/event-stream") + resp.headers.add_header("Cache-control", "no-cache") + resp.headers.add_header("Connection", "keep-alive") + resp.headers.add_header("X-Accel-Buffering", "no") + resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8") + return resp + + @manager.route('/rerun', methods=['POST']) # noqa: F821 @validate_request("id", "dsl", "component_id") @login_required @@ -532,20 +578,65 @@ def sessions(canvas_id): from_date = request.args.get("from_date") to_date = request.args.get("to_date") orderby = request.args.get("orderby", "update_time") + exp_user_id = request.args.get("exp_user_id") if request.args.get("desc") == "False" or request.args.get("desc") == "false": desc = False else: desc = True + + if exp_user_id: + sess = API4ConversationService.get_names(canvas_id, exp_user_id) + return get_json_result(data={"total": len(sess), "sessions": sess}) + # dsl defaults to True in all cases except for False and false include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false" total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc, - None, user_id, include_dsl, keywords, from_date, to_date) + None, user_id, include_dsl, keywords, from_date, to_date, exp_user_id=exp_user_id) try: return get_json_result(data={"total": total, "sessions": sess}) except Exception as e: return server_error_response(e) +@manager.route('//sessions', methods=['PUT']) # noqa: F821 +@login_required +async def set_session(canvas_id): + req = await get_request_json() + tenant_id = current_user.id + e, cvs = UserCanvasService.get_by_id(canvas_id) + assert e, "Agent not found." + if not isinstance(cvs.dsl, str): + cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False) + session_id=get_uuid() + canvas = Canvas(cvs.dsl, tenant_id, canvas_id, canvas_id=cvs.id) + canvas.reset() + conv = { + "id": session_id, + "name": req.get("name", ""), + "dialog_id": cvs.id, + "user_id": tenant_id, + "exp_user_id": tenant_id, + "message": [], + "source": "agent", + "dsl": cvs.dsl, + "reference": [] + } + API4ConversationService.save(**conv) + return get_json_result(data=conv) + + +@manager.route('//sessions/', methods=['GET']) # noqa: F821 +@login_required +def get_session(canvas_id, session_id): + tenant_id = current_user.id + if not UserCanvasService.accessible(canvas_id, tenant_id): + return get_json_result( + data=False, message='Only owner of canvas authorized for this operation.', + code=RetCode.OPERATING_ERROR) + conv = API4ConversationService.get_by_id(session_id) + return get_json_result(data=conv.to_dict()) + + @manager.route('/prompts', methods=['GET']) # noqa: F821 @login_required def prompts(): diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 628261340..9d2fed802 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import asyncio import logging import json import os @@ -64,13 +65,17 @@ async def set_api_key(): chat_passed, embd_passed, rerank_passed = False, False, False factory = req["llm_factory"] extra = {"provider": factory} + timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10)) msg = "" for llm in LLMService.query(fid=factory): if not embd_passed and llm.model_type == LLMType.EMBEDDING.value: assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet." mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url")) try: - arr, tc = mdl.encode(["Test if the api key is available"]) + arr, tc = await asyncio.wait_for( + asyncio.to_thread(mdl.encode, ["Test if the api key is available"]), + timeout=timeout_seconds, + ) if len(arr[0]) == 0: raise Exception("Fail") embd_passed = True @@ -80,17 +85,27 @@ async def set_api_key(): assert factory in ChatModel, f"Chat model from {factory} is not supported yet." mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra) try: - m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50}) + m, tc = await asyncio.wait_for( + mdl.async_chat( + None, + [{"role": "user", "content": "Hello! How are you doing!"}], + {"temperature": 0.9, "max_tokens": 50}, + ), + timeout=timeout_seconds, + ) if m.find("**ERROR**") >= 0: raise Exception(m) chat_passed = True except Exception as e: msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e) - elif not rerank_passed and llm.model_type == LLMType.RERANK: + elif not rerank_passed and llm.model_type == LLMType.RERANK.value: assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet." mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url")) try: - arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"]) + arr, tc = await asyncio.wait_for( + asyncio.to_thread(mdl.similarity, "What's the weather?", ["Is it sunny today?"]), + timeout=timeout_seconds, + ) if len(arr) == 0 or tc == 0: raise Exception("Fail") rerank_passed = True @@ -101,6 +116,9 @@ async def set_api_key(): msg = "" break + if req.get("verify", False): + return get_json_result(data={"message": msg, "success": len(msg.strip())==0}) + if msg: return get_data_error_result(message=msg) @@ -133,6 +151,7 @@ async def add_llm(): factory = req["llm_factory"] api_key = req.get("api_key", "x") llm_name = req.get("llm_name") + timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10)) if factory not in [f.name for f in get_allowed_llm_factories()]: return get_data_error_result(message=f"LLM factory {factory} is not allowed") @@ -215,7 +234,10 @@ async def add_llm(): assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet." mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) try: - arr, tc = mdl.encode(["Test if the api key is available"]) + arr, tc = await asyncio.wait_for( + asyncio.to_thread(mdl.encode, ["Test if the api key is available"]), + timeout=timeout_seconds, + ) if len(arr[0]) == 0: raise Exception("Fail") except Exception as e: @@ -229,7 +251,14 @@ async def add_llm(): **extra, ) try: - m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9}) + m, tc = await asyncio.wait_for( + mdl.async_chat( + None, + [{"role": "user", "content": "Hello! How are you doing!"}], + {"temperature": 0.9}, + ), + timeout=timeout_seconds, + ) if not tc and m.find("**ERROR**:") >= 0: raise Exception(m) except Exception as e: @@ -239,7 +268,10 @@ async def add_llm(): assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet." try: mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) - arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]) + arr, tc = await asyncio.wait_for( + asyncio.to_thread(mdl.similarity, "Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]), + timeout=timeout_seconds, + ) if len(arr) == 0: raise Exception("Not known.") except KeyError: @@ -252,7 +284,10 @@ async def add_llm(): mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) try: image_data = test_image - m, tc = mdl.describe(image_data) + m, tc = await asyncio.wait_for( + asyncio.to_thread(mdl.describe, image_data), + timeout=timeout_seconds, + ) if not tc and m.find("**ERROR**:") >= 0: raise Exception(m) except Exception as e: @@ -261,20 +296,29 @@ async def add_llm(): assert factory in TTSModel, f"TTS model from {factory} is not supported yet." mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) try: - for resp in mdl.tts("Hello~ RAGFlower!"): - pass + def drain_tts(): + for _ in mdl.tts("Hello~ RAGFlower!"): + pass + + await asyncio.wait_for( + asyncio.to_thread(drain_tts), + timeout=timeout_seconds, + ) except RuntimeError as e: msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) case LLMType.OCR.value: assert factory in OcrModel, f"OCR model from {factory} is not supported yet." try: mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) - ok, reason = mdl.check_available() + ok, reason = await asyncio.wait_for( + asyncio.to_thread(mdl.check_available), + timeout=timeout_seconds, + ) if not ok: raise RuntimeError(reason or "Model not available") except Exception as e: msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) - case LLMType.SPEECH2TEXT: + case LLMType.SPEECH2TEXT.value: assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet." try: mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) @@ -284,6 +328,9 @@ async def add_llm(): case _: raise RuntimeError(f"Unknown model type: {model_type}") + if req.get("verify", False): + return get_json_result(data={"message": msg, "success": len(msg.strip()) == 0}) + if msg: return get_data_error_result(message=msg) diff --git a/api/db/db_models.py b/api/db/db_models.py index 84b3b110a..ca72be210 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -991,8 +991,10 @@ class APIToken(DataBaseModel): class API4Conversation(DataBaseModel): id = CharField(max_length=32, primary_key=True) + name = CharField(max_length=255, null=True, help_text="conversation name", index=False) dialog_id = CharField(max_length=32, null=False, index=True) user_id = CharField(max_length=255, null=False, help_text="user_id", index=True) + exp_user_id = CharField(max_length=255, null=True, help_text="exp_user_id", index=True) message = JSONField(null=True) reference = JSONField(null=True, default=[]) tokens = IntegerField(default=0) @@ -1376,6 +1378,8 @@ def migrate_db(): alter_db_add_column(migrator, "tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True)) alter_db_add_column(migrator, "connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False)) alter_db_add_column(migrator, "llm_factories", "rank", IntegerField(default=0, index=False)) + alter_db_add_column(migrator, "api_4_conversation", "name", CharField(max_length=255, null=True, help_text="conversation name", index=False)) + alter_db_add_column(migrator, "api_4_conversation", "exp_user_id", CharField(max_length=255, null=True, help_text="exp_user_id", index=True)) # Migrate system_settings.value from CharField to TextField for longer sandbox configs alter_db_column_type(migrator, "system_settings", "value", TextField(null=False, help_text="Configuration value (JSON, string, etc.)")) logging.disable(logging.NOTSET) diff --git a/api/db/services/api_service.py b/api/db/services/api_service.py index aee35422b..be41dc1b6 100644 --- a/api/db/services/api_service.py +++ b/api/db/services/api_service.py @@ -48,8 +48,8 @@ class API4ConversationService(CommonService): @DB.connection_context() def get_list(cls, dialog_id, tenant_id, page_number, items_per_page, - orderby, desc, id, user_id=None, include_dsl=True, keywords="", - from_date=None, to_date=None + orderby, desc, id=None, user_id=None, include_dsl=True, keywords="", + from_date=None, to_date=None, exp_user_id=None ): if include_dsl: sessions = cls.model.select().where(cls.model.dialog_id == dialog_id) @@ -66,6 +66,8 @@ class API4ConversationService(CommonService): sessions = sessions.where(cls.model.create_date >= from_date) if to_date: sessions = sessions.where(cls.model.create_date <= to_date) + if exp_user_id: + sessions = sessions.where(cls.model.exp_user_id == exp_user_id) if desc: sessions = sessions.order_by(cls.model.getter_by(orderby).desc()) else: @@ -74,6 +76,17 @@ class API4ConversationService(CommonService): sessions = sessions.paginate(page_number, items_per_page) return count, list(sessions.dicts()) + + @classmethod + @DB.connection_context() + def get_names(cls, dialog_id, exp_user_id): + fields = [cls.model.id, cls.model.name,] + sessions = cls.model.select(*fields).where( + cls.model.dialog_id == dialog_id, + cls.model.exp_user_id == exp_user_id + ).order_by(cls.model.getter_by("create_date").desc()) + + return list(sessions.dicts()) @classmethod @DB.connection_context() diff --git a/rag/nlp/query.py b/rag/nlp/query.py index 096cfa4ce..39b6b439d 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py @@ -55,13 +55,11 @@ class FulltextQueryer(QueryBase): keywords = [t for t in tks if t] tks_w = self.tw.weights(tks, preprocess=False) tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w] - tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk] tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk] tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()] syns = [] for tk, w in tks_w[:256]: - syn = self.syn.lookup(tk) - syn = rag_tokenizer.tokenize(" ".join(syn)).split() + syn = [rag_tokenizer.tokenize(s) for s in self.syn.lookup(tk)] keywords.extend(syn) syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()] syns.append(" ".join(syn)) @@ -190,7 +188,10 @@ class FulltextQueryer(QueryBase): d = defaultdict(int) wts = self.tw.weights(tks, preprocess=False) for i, (t, c) in enumerate(wts): - d[t] += c + d[t] += c * 0.4 + if i+1 < len(wts): + _t, _c = wts[i+1] + d[t+_t] += max(c, _c) * 0.6 return d atks = to_dict(atks)