mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-02-06 10:35:06 +08:00
Feat: support verify to set llm key and boost bigrams. (#12980)
#12863 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@ -46,6 +47,7 @@ from rag.nlp import search
|
|||||||
from rag.utils.redis_conn import REDIS_CONN
|
from rag.utils.redis_conn import REDIS_CONN
|
||||||
from common import settings
|
from common import settings
|
||||||
from api.apps import login_required, current_user
|
from api.apps import login_required, current_user
|
||||||
|
from api.db.services.canvas_service import completion as agent_completion
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/templates', methods=['GET']) # noqa: F821
|
@manager.route('/templates', methods=['GET']) # noqa: F821
|
||||||
@ -184,6 +186,50 @@ async def run():
|
|||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route("/<canvas_id>/completion", methods=["POST"]) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
async def exp_agent_completion(canvas_id):
|
||||||
|
tenant_id = current_user.id
|
||||||
|
req = await get_request_json()
|
||||||
|
return_trace = bool(req.get("return_trace", False))
|
||||||
|
async def generate():
|
||||||
|
trace_items = []
|
||||||
|
async for answer in agent_completion(tenant_id=tenant_id, agent_id=canvas_id, **req):
|
||||||
|
if isinstance(answer, str):
|
||||||
|
try:
|
||||||
|
ans = json.loads(answer[5:]) # remove "data:"
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
event = ans.get("event")
|
||||||
|
if event == "node_finished":
|
||||||
|
if return_trace:
|
||||||
|
data = ans.get("data", {})
|
||||||
|
trace_items.append(
|
||||||
|
{
|
||||||
|
"component_id": data.get("component_id"),
|
||||||
|
"trace": [copy.deepcopy(data)],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ans.setdefault("data", {})["trace"] = trace_items
|
||||||
|
answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||||
|
yield answer
|
||||||
|
|
||||||
|
if event not in ["message", "message_end"]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
yield answer
|
||||||
|
|
||||||
|
yield "data:[DONE]\n\n"
|
||||||
|
|
||||||
|
resp = Response(generate(), mimetype="text/event-stream")
|
||||||
|
resp.headers.add_header("Cache-control", "no-cache")
|
||||||
|
resp.headers.add_header("Connection", "keep-alive")
|
||||||
|
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||||
|
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/rerun', methods=['POST']) # noqa: F821
|
@manager.route('/rerun', methods=['POST']) # noqa: F821
|
||||||
@validate_request("id", "dsl", "component_id")
|
@validate_request("id", "dsl", "component_id")
|
||||||
@login_required
|
@login_required
|
||||||
@ -532,20 +578,65 @@ def sessions(canvas_id):
|
|||||||
from_date = request.args.get("from_date")
|
from_date = request.args.get("from_date")
|
||||||
to_date = request.args.get("to_date")
|
to_date = request.args.get("to_date")
|
||||||
orderby = request.args.get("orderby", "update_time")
|
orderby = request.args.get("orderby", "update_time")
|
||||||
|
exp_user_id = request.args.get("exp_user_id")
|
||||||
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
||||||
desc = False
|
desc = False
|
||||||
else:
|
else:
|
||||||
desc = True
|
desc = True
|
||||||
|
|
||||||
|
if exp_user_id:
|
||||||
|
sess = API4ConversationService.get_names(canvas_id, exp_user_id)
|
||||||
|
return get_json_result(data={"total": len(sess), "sessions": sess})
|
||||||
|
|
||||||
# dsl defaults to True in all cases except for False and false
|
# dsl defaults to True in all cases except for False and false
|
||||||
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
|
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
|
||||||
total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc,
|
total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc,
|
||||||
None, user_id, include_dsl, keywords, from_date, to_date)
|
None, user_id, include_dsl, keywords, from_date, to_date, exp_user_id=exp_user_id)
|
||||||
try:
|
try:
|
||||||
return get_json_result(data={"total": total, "sessions": sess})
|
return get_json_result(data={"total": total, "sessions": sess})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return server_error_response(e)
|
return server_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/<canvas_id>/sessions', methods=['PUT']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
async def set_session(canvas_id):
|
||||||
|
req = await get_request_json()
|
||||||
|
tenant_id = current_user.id
|
||||||
|
e, cvs = UserCanvasService.get_by_id(canvas_id)
|
||||||
|
assert e, "Agent not found."
|
||||||
|
if not isinstance(cvs.dsl, str):
|
||||||
|
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||||
|
session_id=get_uuid()
|
||||||
|
canvas = Canvas(cvs.dsl, tenant_id, canvas_id, canvas_id=cvs.id)
|
||||||
|
canvas.reset()
|
||||||
|
conv = {
|
||||||
|
"id": session_id,
|
||||||
|
"name": req.get("name", ""),
|
||||||
|
"dialog_id": cvs.id,
|
||||||
|
"user_id": tenant_id,
|
||||||
|
"exp_user_id": tenant_id,
|
||||||
|
"message": [],
|
||||||
|
"source": "agent",
|
||||||
|
"dsl": cvs.dsl,
|
||||||
|
"reference": []
|
||||||
|
}
|
||||||
|
API4ConversationService.save(**conv)
|
||||||
|
return get_json_result(data=conv)
|
||||||
|
|
||||||
|
|
||||||
|
@manager.route('/<canvas_id>/sessions/<session_id>', methods=['GET']) # noqa: F821
|
||||||
|
@login_required
|
||||||
|
def get_session(canvas_id, session_id):
|
||||||
|
tenant_id = current_user.id
|
||||||
|
if not UserCanvasService.accessible(canvas_id, tenant_id):
|
||||||
|
return get_json_result(
|
||||||
|
data=False, message='Only owner of canvas authorized for this operation.',
|
||||||
|
code=RetCode.OPERATING_ERROR)
|
||||||
|
conv = API4ConversationService.get_by_id(session_id)
|
||||||
|
return get_json_result(data=conv.to_dict())
|
||||||
|
|
||||||
|
|
||||||
@manager.route('/prompts', methods=['GET']) # noqa: F821
|
@manager.route('/prompts', methods=['GET']) # noqa: F821
|
||||||
@login_required
|
@login_required
|
||||||
def prompts():
|
def prompts():
|
||||||
|
|||||||
@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@ -64,13 +65,17 @@ async def set_api_key():
|
|||||||
chat_passed, embd_passed, rerank_passed = False, False, False
|
chat_passed, embd_passed, rerank_passed = False, False, False
|
||||||
factory = req["llm_factory"]
|
factory = req["llm_factory"]
|
||||||
extra = {"provider": factory}
|
extra = {"provider": factory}
|
||||||
|
timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
|
||||||
msg = ""
|
msg = ""
|
||||||
for llm in LLMService.query(fid=factory):
|
for llm in LLMService.query(fid=factory):
|
||||||
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
|
if not embd_passed and llm.model_type == LLMType.EMBEDDING.value:
|
||||||
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
|
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
|
||||||
mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
mdl = EmbeddingModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
||||||
try:
|
try:
|
||||||
arr, tc = mdl.encode(["Test if the api key is available"])
|
arr, tc = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
|
||||||
|
timeout=timeout_seconds,
|
||||||
|
)
|
||||||
if len(arr[0]) == 0:
|
if len(arr[0]) == 0:
|
||||||
raise Exception("Fail")
|
raise Exception("Fail")
|
||||||
embd_passed = True
|
embd_passed = True
|
||||||
@ -80,17 +85,27 @@ async def set_api_key():
|
|||||||
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
|
||||||
mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
|
mdl = ChatModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"), **extra)
|
||||||
try:
|
try:
|
||||||
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9, "max_tokens": 50})
|
m, tc = await asyncio.wait_for(
|
||||||
|
mdl.async_chat(
|
||||||
|
None,
|
||||||
|
[{"role": "user", "content": "Hello! How are you doing!"}],
|
||||||
|
{"temperature": 0.9, "max_tokens": 50},
|
||||||
|
),
|
||||||
|
timeout=timeout_seconds,
|
||||||
|
)
|
||||||
if m.find("**ERROR**") >= 0:
|
if m.find("**ERROR**") >= 0:
|
||||||
raise Exception(m)
|
raise Exception(m)
|
||||||
chat_passed = True
|
chat_passed = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
|
msg += f"\nFail to access model({llm.fid}/{llm.llm_name}) using this api key." + str(e)
|
||||||
elif not rerank_passed and llm.model_type == LLMType.RERANK:
|
elif not rerank_passed and llm.model_type == LLMType.RERANK.value:
|
||||||
assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
|
assert factory in RerankModel, f"Re-rank model from {factory} is not supported yet."
|
||||||
mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
mdl = RerankModel[factory](req["api_key"], llm.llm_name, base_url=req.get("base_url"))
|
||||||
try:
|
try:
|
||||||
arr, tc = mdl.similarity("What's the weather?", ["Is it sunny today?"])
|
arr, tc = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(mdl.similarity, "What's the weather?", ["Is it sunny today?"]),
|
||||||
|
timeout=timeout_seconds,
|
||||||
|
)
|
||||||
if len(arr) == 0 or tc == 0:
|
if len(arr) == 0 or tc == 0:
|
||||||
raise Exception("Fail")
|
raise Exception("Fail")
|
||||||
rerank_passed = True
|
rerank_passed = True
|
||||||
@ -101,6 +116,9 @@ async def set_api_key():
|
|||||||
msg = ""
|
msg = ""
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if req.get("verify", False):
|
||||||
|
return get_json_result(data={"message": msg, "success": len(msg.strip())==0})
|
||||||
|
|
||||||
if msg:
|
if msg:
|
||||||
return get_data_error_result(message=msg)
|
return get_data_error_result(message=msg)
|
||||||
|
|
||||||
@ -133,6 +151,7 @@ async def add_llm():
|
|||||||
factory = req["llm_factory"]
|
factory = req["llm_factory"]
|
||||||
api_key = req.get("api_key", "x")
|
api_key = req.get("api_key", "x")
|
||||||
llm_name = req.get("llm_name")
|
llm_name = req.get("llm_name")
|
||||||
|
timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10))
|
||||||
|
|
||||||
if factory not in [f.name for f in get_allowed_llm_factories()]:
|
if factory not in [f.name for f in get_allowed_llm_factories()]:
|
||||||
return get_data_error_result(message=f"LLM factory {factory} is not allowed")
|
return get_data_error_result(message=f"LLM factory {factory} is not allowed")
|
||||||
@ -215,7 +234,10 @@ async def add_llm():
|
|||||||
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
|
assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
|
||||||
mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||||
try:
|
try:
|
||||||
arr, tc = mdl.encode(["Test if the api key is available"])
|
arr, tc = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(mdl.encode, ["Test if the api key is available"]),
|
||||||
|
timeout=timeout_seconds,
|
||||||
|
)
|
||||||
if len(arr[0]) == 0:
|
if len(arr[0]) == 0:
|
||||||
raise Exception("Fail")
|
raise Exception("Fail")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -229,7 +251,14 @@ async def add_llm():
|
|||||||
**extra,
|
**extra,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
|
m, tc = await asyncio.wait_for(
|
||||||
|
mdl.async_chat(
|
||||||
|
None,
|
||||||
|
[{"role": "user", "content": "Hello! How are you doing!"}],
|
||||||
|
{"temperature": 0.9},
|
||||||
|
),
|
||||||
|
timeout=timeout_seconds,
|
||||||
|
)
|
||||||
if not tc and m.find("**ERROR**:") >= 0:
|
if not tc and m.find("**ERROR**:") >= 0:
|
||||||
raise Exception(m)
|
raise Exception(m)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -239,7 +268,10 @@ async def add_llm():
|
|||||||
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
|
assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
|
||||||
try:
|
try:
|
||||||
mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||||
arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
|
arr, tc = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(mdl.similarity, "Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"]),
|
||||||
|
timeout=timeout_seconds,
|
||||||
|
)
|
||||||
if len(arr) == 0:
|
if len(arr) == 0:
|
||||||
raise Exception("Not known.")
|
raise Exception("Not known.")
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -252,7 +284,10 @@ async def add_llm():
|
|||||||
mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||||
try:
|
try:
|
||||||
image_data = test_image
|
image_data = test_image
|
||||||
m, tc = mdl.describe(image_data)
|
m, tc = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(mdl.describe, image_data),
|
||||||
|
timeout=timeout_seconds,
|
||||||
|
)
|
||||||
if not tc and m.find("**ERROR**:") >= 0:
|
if not tc and m.find("**ERROR**:") >= 0:
|
||||||
raise Exception(m)
|
raise Exception(m)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -261,20 +296,29 @@ async def add_llm():
|
|||||||
assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
|
assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
|
||||||
mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||||
try:
|
try:
|
||||||
for resp in mdl.tts("Hello~ RAGFlower!"):
|
def drain_tts():
|
||||||
pass
|
for _ in mdl.tts("Hello~ RAGFlower!"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(drain_tts),
|
||||||
|
timeout=timeout_seconds,
|
||||||
|
)
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||||
case LLMType.OCR.value:
|
case LLMType.OCR.value:
|
||||||
assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
|
assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
|
||||||
try:
|
try:
|
||||||
mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||||
ok, reason = mdl.check_available()
|
ok, reason = await asyncio.wait_for(
|
||||||
|
asyncio.to_thread(mdl.check_available),
|
||||||
|
timeout=timeout_seconds,
|
||||||
|
)
|
||||||
if not ok:
|
if not ok:
|
||||||
raise RuntimeError(reason or "Model not available")
|
raise RuntimeError(reason or "Model not available")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
|
||||||
case LLMType.SPEECH2TEXT:
|
case LLMType.SPEECH2TEXT.value:
|
||||||
assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
|
assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
|
||||||
try:
|
try:
|
||||||
mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
|
||||||
@ -284,6 +328,9 @@ async def add_llm():
|
|||||||
case _:
|
case _:
|
||||||
raise RuntimeError(f"Unknown model type: {model_type}")
|
raise RuntimeError(f"Unknown model type: {model_type}")
|
||||||
|
|
||||||
|
if req.get("verify", False):
|
||||||
|
return get_json_result(data={"message": msg, "success": len(msg.strip()) == 0})
|
||||||
|
|
||||||
if msg:
|
if msg:
|
||||||
return get_data_error_result(message=msg)
|
return get_data_error_result(message=msg)
|
||||||
|
|
||||||
|
|||||||
@ -991,8 +991,10 @@ class APIToken(DataBaseModel):
|
|||||||
|
|
||||||
class API4Conversation(DataBaseModel):
|
class API4Conversation(DataBaseModel):
|
||||||
id = CharField(max_length=32, primary_key=True)
|
id = CharField(max_length=32, primary_key=True)
|
||||||
|
name = CharField(max_length=255, null=True, help_text="conversation name", index=False)
|
||||||
dialog_id = CharField(max_length=32, null=False, index=True)
|
dialog_id = CharField(max_length=32, null=False, index=True)
|
||||||
user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
|
user_id = CharField(max_length=255, null=False, help_text="user_id", index=True)
|
||||||
|
exp_user_id = CharField(max_length=255, null=True, help_text="exp_user_id", index=True)
|
||||||
message = JSONField(null=True)
|
message = JSONField(null=True)
|
||||||
reference = JSONField(null=True, default=[])
|
reference = JSONField(null=True, default=[])
|
||||||
tokens = IntegerField(default=0)
|
tokens = IntegerField(default=0)
|
||||||
@ -1376,6 +1378,8 @@ def migrate_db():
|
|||||||
alter_db_add_column(migrator, "tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True))
|
alter_db_add_column(migrator, "tenant_llm", "status", CharField(max_length=1, null=False, help_text="is it validate(0: wasted, 1: validate)", default="1", index=True))
|
||||||
alter_db_add_column(migrator, "connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False))
|
alter_db_add_column(migrator, "connector2kb", "auto_parse", CharField(max_length=1, null=False, default="1", index=False))
|
||||||
alter_db_add_column(migrator, "llm_factories", "rank", IntegerField(default=0, index=False))
|
alter_db_add_column(migrator, "llm_factories", "rank", IntegerField(default=0, index=False))
|
||||||
|
alter_db_add_column(migrator, "api_4_conversation", "name", CharField(max_length=255, null=True, help_text="conversation name", index=False))
|
||||||
|
alter_db_add_column(migrator, "api_4_conversation", "exp_user_id", CharField(max_length=255, null=True, help_text="exp_user_id", index=True))
|
||||||
# Migrate system_settings.value from CharField to TextField for longer sandbox configs
|
# Migrate system_settings.value from CharField to TextField for longer sandbox configs
|
||||||
alter_db_column_type(migrator, "system_settings", "value", TextField(null=False, help_text="Configuration value (JSON, string, etc.)"))
|
alter_db_column_type(migrator, "system_settings", "value", TextField(null=False, help_text="Configuration value (JSON, string, etc.)"))
|
||||||
logging.disable(logging.NOTSET)
|
logging.disable(logging.NOTSET)
|
||||||
|
|||||||
@ -48,8 +48,8 @@ class API4ConversationService(CommonService):
|
|||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def get_list(cls, dialog_id, tenant_id,
|
def get_list(cls, dialog_id, tenant_id,
|
||||||
page_number, items_per_page,
|
page_number, items_per_page,
|
||||||
orderby, desc, id, user_id=None, include_dsl=True, keywords="",
|
orderby, desc, id=None, user_id=None, include_dsl=True, keywords="",
|
||||||
from_date=None, to_date=None
|
from_date=None, to_date=None, exp_user_id=None
|
||||||
):
|
):
|
||||||
if include_dsl:
|
if include_dsl:
|
||||||
sessions = cls.model.select().where(cls.model.dialog_id == dialog_id)
|
sessions = cls.model.select().where(cls.model.dialog_id == dialog_id)
|
||||||
@ -66,6 +66,8 @@ class API4ConversationService(CommonService):
|
|||||||
sessions = sessions.where(cls.model.create_date >= from_date)
|
sessions = sessions.where(cls.model.create_date >= from_date)
|
||||||
if to_date:
|
if to_date:
|
||||||
sessions = sessions.where(cls.model.create_date <= to_date)
|
sessions = sessions.where(cls.model.create_date <= to_date)
|
||||||
|
if exp_user_id:
|
||||||
|
sessions = sessions.where(cls.model.exp_user_id == exp_user_id)
|
||||||
if desc:
|
if desc:
|
||||||
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
|
sessions = sessions.order_by(cls.model.getter_by(orderby).desc())
|
||||||
else:
|
else:
|
||||||
@ -75,6 +77,17 @@ class API4ConversationService(CommonService):
|
|||||||
|
|
||||||
return count, list(sessions.dicts())
|
return count, list(sessions.dicts())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@DB.connection_context()
|
||||||
|
def get_names(cls, dialog_id, exp_user_id):
|
||||||
|
fields = [cls.model.id, cls.model.name,]
|
||||||
|
sessions = cls.model.select(*fields).where(
|
||||||
|
cls.model.dialog_id == dialog_id,
|
||||||
|
cls.model.exp_user_id == exp_user_id
|
||||||
|
).order_by(cls.model.getter_by("create_date").desc())
|
||||||
|
|
||||||
|
return list(sessions.dicts())
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@DB.connection_context()
|
@DB.connection_context()
|
||||||
def append_message(cls, id, conversation):
|
def append_message(cls, id, conversation):
|
||||||
|
|||||||
@ -55,13 +55,11 @@ class FulltextQueryer(QueryBase):
|
|||||||
keywords = [t for t in tks if t]
|
keywords = [t for t in tks if t]
|
||||||
tks_w = self.tw.weights(tks, preprocess=False)
|
tks_w = self.tw.weights(tks, preprocess=False)
|
||||||
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
|
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
|
||||||
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
|
|
||||||
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
|
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
|
||||||
tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
|
tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
|
||||||
syns = []
|
syns = []
|
||||||
for tk, w in tks_w[:256]:
|
for tk, w in tks_w[:256]:
|
||||||
syn = self.syn.lookup(tk)
|
syn = [rag_tokenizer.tokenize(s) for s in self.syn.lookup(tk)]
|
||||||
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
|
|
||||||
keywords.extend(syn)
|
keywords.extend(syn)
|
||||||
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
|
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
|
||||||
syns.append(" ".join(syn))
|
syns.append(" ".join(syn))
|
||||||
@ -190,7 +188,10 @@ class FulltextQueryer(QueryBase):
|
|||||||
d = defaultdict(int)
|
d = defaultdict(int)
|
||||||
wts = self.tw.weights(tks, preprocess=False)
|
wts = self.tw.weights(tks, preprocess=False)
|
||||||
for i, (t, c) in enumerate(wts):
|
for i, (t, c) in enumerate(wts):
|
||||||
d[t] += c
|
d[t] += c * 0.4
|
||||||
|
if i+1 < len(wts):
|
||||||
|
_t, _c = wts[i+1]
|
||||||
|
d[t+_t] += max(c, _c) * 0.6
|
||||||
return d
|
return d
|
||||||
|
|
||||||
atks = to_dict(atks)
|
atks = to_dict(atks)
|
||||||
|
|||||||
Reference in New Issue
Block a user