Mirror of https://github.com/infiniflow/ragflow.git (synced 2026-01-04 03:25:30 +08:00)

Compare commits: 672958a192 ... main_backu (93 commits)
Commits in this compare (93):
427e0540ca, 8f16fac898, ea00478a21, 906f19e863, 667dc5467e, 977962fdfe, 1e9374a373, 9b52ba8061, 44671ea413, c81421d340, ee93a80e91, 5c981978c1, 7fef285af5, b1efb905e5, 6400bf87ba, f239bc02d3, 5776fa73a7, fc6af1998b, 0588fe79b9, f545265f93, c987d33649, d72debf0db, c33134ea2c, 17b8bb62b6, bab6a4a219, 6c93157b14, 033029eaa1, a958ddb27a, f63f007326, b47f1afa35, 2369be7244, 00bb6fbd28, 063b06494a, b824185a3a, 8e6ddd7c1b, d1bc7ad2ee, 321474fb97, ea89e4e0c6, 9e31631d8f, 712d537d66, bd4eb19393, 02efab7c11, 8ce129bc51, d5a44e913d, 1444de981c, bd76b8ff1a, a95f22fa88, 38ac6a7c27, e5f3d5ae26, 4cbc91f2fa, 6d3d3a40ab, 51b12841d6, 993bf7c2c8, b42b5fcf65, 5d391fb1f9, 2ddfcc7cf6, 5ba51b21c9, 3ea84ad9c8, 0a5dce50fb, 6c9afd1ffb, bfef96d56e, 74adf3d59c, ba7e087aef, f911aa2997, 42f9ac997f, c7cf7aad4e, 2118bc2556, b49eb6826b, 8dd2394e93, 5aea82d9c4, 47005ebe10, 3ee47e4af7, 55c0468ac9, eeb36a5ce7, aceca266ff, d82e502a71, 0494b92371, 8683a5b1b7, 4cbe470089, 6cd1824a77, 2844700dc4, f8fd1ea7e1, 57edc215d7, 7a4044b05f, e84d5412bc, 151480dc85, 2331b3a270, 5cd1a678c8, cc9546b761, a63dcfed6f, 4dd8cdc38b, 1a4822d6be, ce161f09cc
.github/workflows/tests.yml (vendored): 29 changes
@@ -197,37 +197,38 @@ jobs:
          echo -e "COMPOSE_PROFILES=\${COMPOSE_PROFILES},tei-cpu" >> docker/.env
          echo -e "TEI_MODEL=BAAI/bge-small-en-v1.5" >> docker/.env
          echo -e "RAGFLOW_IMAGE=${RAGFLOW_IMAGE}" >> docker/.env
+         sed -i '1i DOC_ENGINE=infinity' docker/.env
          echo "HOST_ADDRESS=http://host.docker.internal:${SVR_HTTP_PORT}" >> ${GITHUB_ENV}

          sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d
          uv sync --python 3.12 --only-group test --no-default-groups --frozen && uv pip install sdk/python --group test

-     - name: Run sdk tests against Elasticsearch
+     - name: Run sdk tests against Infinity
        run: |
          export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
            echo "Waiting for service to be available..."
            sleep 5
          done
-         source .venv/bin/activate && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api
+         source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log

-     - name: Run frontend api tests against Elasticsearch
+     - name: Run frontend api tests against Infinity
        run: |
          export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
            echo "Waiting for service to be available..."
            sleep 5
          done
-         source .venv/bin/activate && pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py
+         source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee infinity_api_test.log

-     - name: Run http api tests against Elasticsearch
+     - name: Run http api tests against Infinity
        run: |
          export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
            echo "Waiting for service to be available..."
            sleep 5
          done
-         source .venv/bin/activate && pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api
+         source .venv/bin/activate && DOC_ENGINE=infinity pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log

      - name: Stop ragflow:nightly
        if: always() # always run this step even if previous steps failed

@@ -237,35 +238,35 @@ jobs:

      - name: Start ragflow:nightly
        run: |
-         sed -i '1i DOC_ENGINE=infinity' docker/.env
+         sed -i '1i DOC_ENGINE=elasticsearch' docker/.env
          sudo docker compose -f docker/docker-compose.yml -p ${GITHUB_RUN_ID} up -d

-     - name: Run sdk tests against Infinity
+     - name: Run sdk tests against Elasticsearch
        run: |
          export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
            echo "Waiting for service to be available..."
            sleep 5
          done
-         source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api
+         source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log

-     - name: Run frontend api tests against Infinity
+     - name: Run frontend api tests against Elasticsearch
        run: |
          export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
            echo "Waiting for service to be available..."
            sleep 5
          done
-         source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py
+         source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short sdk/python/test/test_frontend_api/get_email.py sdk/python/test/test_frontend_api/test_dataset.py 2>&1 | tee es_api_test.log

-     - name: Run http api tests against Infinity
+     - name: Run http api tests against Elasticsearch
        run: |
          export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
          until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS} > /dev/null; do
            echo "Waiting for service to be available..."
            sleep 5
          done
-         source .venv/bin/activate && DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api
+         source .venv/bin/activate && DOC_ENGINE=elasticsearch pytest -x -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log

      - name: Stop ragflow:nightly
        if: always() # always run this step even if previous steps failed
@@ -192,6 +192,7 @@ COPY pyproject.toml uv.lock ./
 COPY mcp mcp
 COPY plugin plugin
 COPY common common
+COPY memory memory

 COPY docker/service_conf.yaml.template ./conf/service_conf.yaml.template
 COPY docker/entrypoint.sh ./
@@ -29,6 +29,11 @@ from common.versions import get_ragflow_version
 admin_bp = Blueprint('admin', __name__, url_prefix='/api/v1/admin')


+@admin_bp.route('/ping', methods=['GET'])
+def ping():
+    return success_response('PONG')
+
+
 @admin_bp.route('/login', methods=['POST'])
 def login():
     if not request.json:
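For reference, a hedged usage example of the new health-check route (not part of the diff): the host and port below are assumptions about a local deployment, and only the route path `/api/v1/admin/ping` is taken from the hunk above.

```python
import urllib.request

# Assumes a RAGFlow API server reachable locally; adjust host and port to your setup.
with urllib.request.urlopen("http://127.0.0.1:9380/api/v1/admin/ping") as resp:
    # The endpoint is expected to answer with a success payload containing "PONG".
    print(resp.status, resp.read().decode())
```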
@@ -160,7 +160,7 @@ class Graph:
         return self._tenant_id

     def get_value_with_variable(self,value: str) -> Any:
-        pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
+        pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")
         out_parts = []
         last = 0
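A minimal standalone sketch of what the widened pattern now accepts (not part of the diff): the field name after `@` may contain hyphens. The component id and variable names below are invented for illustration.

```python
import re

# Same pattern as the "+" line above; "-" is now allowed in the field part after "@".
pat = re.compile(r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*")

# Hypothetical variable references, used only to show the match behaviour.
print(pat.findall("{Agent:One@final-answer} and {sys.query}"))
# ['Agent:One@final-answer', 'sys.query']
```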
@@ -368,8 +368,13 @@ class Canvas(Graph):

         if kwargs.get("webhook_payload"):
             for k, cpn in self.components.items():
-                if self.components[k]["obj"].component_name.lower() == "webhook":
-                    for kk, vv in kwargs["webhook_payload"].items():
+                if self.components[k]["obj"].component_name.lower() == "begin" and self.components[k]["obj"]._param.mode == "Webhook":
+                    payload = kwargs.get("webhook_payload", {})
+                    if "input" in payload:
+                        self.components[k]["obj"].set_input_value("request", payload["input"])
+                    for kk, vv in payload.items():
+                        if kk == "input":
+                            continue
                         self.components[k]["obj"].set_output(kk, vv)

         for k in kwargs.keys():
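A self-contained sketch of the routing rule this hunk introduces for a Begin component in Webhook mode (not part of the diff): the "input" key feeds the component's request input, every other key becomes an output. The `FakeBegin` class and the payload are illustrative stand-ins, not repository code.

```python
class FakeBegin:
    """Stand-in for a Begin component running in Webhook mode (illustration only)."""

    def __init__(self):
        self.inputs = {}
        self.outputs = {}

    def set_input_value(self, name, value):
        self.inputs[name] = value

    def set_output(self, name, value):
        self.outputs[name] = value


def route_webhook_payload(begin, payload: dict):
    # "input" goes to the request input; all remaining keys become outputs.
    if "input" in payload:
        begin.set_input_value("request", payload["input"])
    for key, value in payload.items():
        if key == "input":
            continue
        begin.set_output(key, value)


begin = FakeBegin()
route_webhook_payload(begin, {"input": "hello", "order_id": 42})
print(begin.inputs)   # {'request': 'hello'}
print(begin.outputs)  # {'order_id': 42}
```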
@@ -535,6 +540,8 @@ class Canvas(Graph):
                 cite = re.search(r"\[ID:[ 0-9]+\]", cpn_obj.output("content"))

                 message_end = {}
                 if cpn_obj.get_param("status"):
                     message_end["status"] = cpn_obj.get_param("status")
+                if isinstance(cpn_obj.output("attachment"), dict):
+                    message_end["attachment"] = cpn_obj.output("attachment")
                 if cite:
@@ -29,8 +29,8 @@ from api.db.services.llm_service import LLMBundle
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.mcp_server_service import MCPServerService
 from common.connection_utils import timeout
-from rag.prompts.generator import next_step_async, COMPLETE_TASK, analyze_task_async, \
-    citation_prompt, reflect_async, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
+from rag.prompts.generator import next_step_async, COMPLETE_TASK, \
+    citation_prompt, kb_prompt, citation_plus, full_question, message_fit_in, structured_output_prompt
 from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool
 from agent.component.llm import LLMParam, LLM
@@ -84,9 +84,11 @@ class Agent(LLM, ToolBase):
     def __init__(self, canvas, id, param: LLMParam):
         LLM.__init__(self, canvas, id, param)
         self.tools = {}
-        for cpn in self._param.tools:
+        for idx, cpn in enumerate(self._param.tools):
             cpn = self._load_tool_obj(cpn)
-            self.tools[cpn.get_meta()["function"]["name"]] = cpn
+            original_name = cpn.get_meta()["function"]["name"]
+            indexed_name = f"{original_name}_{idx}"
+            self.tools[indexed_name] = cpn

         self.chat_mdl = LLMBundle(self._canvas.get_tenant_id(), TenantLLMService.llm_id2llm_type(self._param.llm_id), self._param.llm_id,
                                   max_retries=self._param.max_retries,
@@ -94,7 +96,12 @@ class Agent(LLM, ToolBase):
                                   max_rounds=self._param.max_rounds,
                                   verbose_tool_use=True
                                   )
-        self.tool_meta = [v.get_meta() for _,v in self.tools.items()]
+        self.tool_meta = []
+        for indexed_name, tool_obj in self.tools.items():
+            original_meta = tool_obj.get_meta()
+            indexed_meta = deepcopy(original_meta)
+            indexed_meta["function"]["name"] = indexed_name
+            self.tool_meta.append(indexed_meta)

         for mcp in self._param.mcp:
             _, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"])
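A self-contained sketch of why the two hunks above suffix each tool name with its index (not part of the diff): duplicate tools would otherwise overwrite each other in the registry. The metadata shape mirrors the OpenAI function-call format used here; the sample tools are invented.

```python
from copy import deepcopy

# Two configured tools that happen to share the same function name.
configured_tools = [
    {"function": {"name": "search", "parameters": {}}},
    {"function": {"name": "search", "parameters": {}}},
]

tools, tool_meta = {}, []
for idx, meta in enumerate(configured_tools):
    indexed_name = f"{meta['function']['name']}_{idx}"
    tools[indexed_name] = meta                       # registry keyed by a unique name
    indexed_meta = deepcopy(meta)                    # leave the original meta untouched
    indexed_meta["function"]["name"] = indexed_name  # advertise the indexed name to the LLM
    tool_meta.append(indexed_meta)

print(list(tools))                                   # ['search_0', 'search_1']
print([m["function"]["name"] for m in tool_meta])    # ['search_0', 'search_1']
```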
@@ -108,7 +115,8 @@ class Agent(LLM, ToolBase):

     def _load_tool_obj(self, cpn: dict) -> object:
         from agent.component import component_class
-        param = component_class(cpn["component_name"] + "Param")()
+        tool_name = cpn["component_name"]
+        param = component_class(tool_name + "Param")()
         param.update(cpn["params"])
         try:
             param.check()
@@ -202,7 +210,7 @@ class Agent(LLM, ToolBase):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         use_tools = []
         ans = ""
-        async for delta_ans, _tk in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
+        async for delta_ans, _tk in self._react_with_tools_streamly_async_simple(prompt, msg, use_tools, user_defined_prompt,schema_prompt=schema_prompt):
             if self.check_if_canceled("Agent processing"):
                 return
             ans += delta_ans
@@ -246,7 +254,7 @@ class Agent(LLM, ToolBase):
         _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(self.chat_mdl.max_length * 0.97))
         answer_without_toolcall = ""
         use_tools = []
-        async for delta_ans, _ in self._react_with_tools_streamly_async(prompt, msg, use_tools, user_defined_prompt):
+        async for delta_ans, _ in self._react_with_tools_streamly_async_simple(prompt, msg, use_tools, user_defined_prompt):
             if self.check_if_canceled("Agent streaming"):
                 return
@@ -264,7 +272,7 @@ class Agent(LLM, ToolBase):
         if use_tools:
             self.set_output("use_tools", use_tools)

-    async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
+    async def _react_with_tools_streamly_async_simple(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
         token_count = 0
         tool_metas = self.tool_meta
         hist = deepcopy(history)
@@ -276,6 +284,24 @@ class Agent(LLM, ToolBase):
         else:
             user_request = history[-1]["content"]

+        def build_task_desc(prompt: str, user_request: str, user_defined_prompt: dict | None = None) -> str:
+            """Build a minimal task_desc by concatenating prompt, query, and tool schemas."""
+            user_defined_prompt = user_defined_prompt or {}
+
+            task_desc = (
+                "### Agent Prompt\n"
+                f"{prompt}\n\n"
+                "### User Request\n"
+                f"{user_request}\n\n"
+            )
+
+            if user_defined_prompt:
+                udp_json = json.dumps(user_defined_prompt, ensure_ascii=False, indent=2)
+                task_desc += "\n### User Defined Prompts\n" + udp_json + "\n"
+
+            return task_desc
+
+
         async def use_tool_async(name, args):
             nonlocal hist, use_tools, last_calling
             logging.info(f"{last_calling=} == {name=}")
@@ -286,9 +312,6 @@ class Agent(LLM, ToolBase):
                 "arguments": args,
                 "results": tool_response
             })
-            # self.callback("add_memory", {}, "...")
-            #self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt)
-
             return name, tool_response

         async def complete():
@@ -326,6 +349,21 @@ class Agent(LLM, ToolBase):

             self.callback("gen_citations", {}, txt, elapsed_time=timer()-st)

+        def build_observation(tool_call_res: list[tuple]) -> str:
+            """
+            Build a Observation from tool call results.
+            No LLM involved.
+            """
+            if not tool_call_res:
+                return ""
+
+            lines = ["Observation:"]
+            for name, result in tool_call_res:
+                lines.append(f"[{name} result]")
+                lines.append(str(result))
+
+            return "\n".join(lines)
+
         def append_user_content(hist, content):
             if hist[-1]["role"] == "user":
                 hist[-1]["content"] += content
@@ -333,7 +371,7 @@ class Agent(LLM, ToolBase):
                 hist.append({"role": "user", "content": content})

         st = timer()
-        task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
+        task_desc = build_task_desc(prompt, user_request, user_defined_prompt)
         self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
         for _ in range(self._param.max_rounds + 1):
             if self.check_if_canceled("Agent streaming"):
@@ -364,7 +402,7 @@ class Agent(LLM, ToolBase):

                 results = await asyncio.gather(*tool_tasks) if tool_tasks else []
                 st = timer()
-                reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
+                reflection = build_observation(results)
                 append_user_content(hist, reflection)
                 self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
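To make the intent of the two helpers concrete, a small self-contained sketch (not part of the diff, no LLM calls, tool results invented): the task description and the observation string are now assembled from local data instead of extra model calls (`analyze_task_async` and `reflect_async` in the removed lines).

```python
import json


def build_task_desc(prompt, user_request, user_defined_prompt=None):
    # Plain string concatenation, mirroring the helper added in the diff.
    task_desc = f"### Agent Prompt\n{prompt}\n\n### User Request\n{user_request}\n\n"
    if user_defined_prompt:
        task_desc += "\n### User Defined Prompts\n" + json.dumps(user_defined_prompt, indent=2) + "\n"
    return task_desc


def build_observation(tool_call_res):
    if not tool_call_res:
        return ""
    lines = ["Observation:"]
    for name, result in tool_call_res:
        lines.append(f"[{name} result]")
        lines.append(str(result))
    return "\n".join(lines)


print(build_task_desc("You are a helpful agent.", "What is the weather in Paris?"))
print(build_observation([("search_0", {"city": "Paris", "forecast": "sunny"})]))
```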
@@ -393,6 +431,135 @@ Respond immediately with your final comprehensive answer.
         async for txt, tkcnt in complete():
             yield txt, tkcnt

+    # async def _react_with_tools_streamly_async(self, prompt, history: list[dict], use_tools, user_defined_prompt={}, schema_prompt: str = ""):
+    #     token_count = 0
+    #     tool_metas = self.tool_meta
+    #     hist = deepcopy(history)
+    #     last_calling = ""
+    #     if len(hist) > 3:
+    #         st = timer()
+    #         user_request = await full_question(messages=history, chat_mdl=self.chat_mdl)
+    #         self.callback("Multi-turn conversation optimization", {}, user_request, elapsed_time=timer()-st)
+    #     else:
+    #         user_request = history[-1]["content"]
+
+    #     async def use_tool_async(name, args):
+    #         nonlocal hist, use_tools, last_calling
+    #         logging.info(f"{last_calling=} == {name=}")
+    #         last_calling = name
+    #         tool_response = await self.toolcall_session.tool_call_async(name, args)
+    #         use_tools.append({
+    #             "name": name,
+    #             "arguments": args,
+    #             "results": tool_response
+    #         })
+    #         # self.callback("add_memory", {}, "...")
+    #         #self.add_memory(hist[-2]["content"], hist[-1]["content"], name, args, str(tool_response), user_defined_prompt)
+
+    #         return name, tool_response
+
+    #     async def complete():
+    #         nonlocal hist
+    #         need2cite = self._param.cite and self._canvas.get_reference()["chunks"] and self._id.find("-->") < 0
+    #         if schema_prompt:
+    #             need2cite = False
+    #         cited = False
+    #         if hist and hist[0]["role"] == "system":
+    #             if schema_prompt:
+    #                 hist[0]["content"] += "\n" + schema_prompt
+    #             if need2cite and len(hist) < 7:
+    #                 hist[0]["content"] += citation_prompt()
+    #                 cited = True
+    #         yield "", token_count
+
+    #         _hist = hist
+    #         if len(hist) > 12:
+    #             _hist = [hist[0], hist[1], *hist[-10:]]
+    #         entire_txt = ""
+    #         async for delta_ans in self._generate_streamly(_hist):
+    #             if not need2cite or cited:
+    #                 yield delta_ans, 0
+    #             entire_txt += delta_ans
+    #         if not need2cite or cited:
+    #             return
+
+    #         st = timer()
+    #         txt = ""
+    #         async for delta_ans in self._gen_citations_async(entire_txt):
+    #             if self.check_if_canceled("Agent streaming"):
+    #                 return
+    #             yield delta_ans, 0
+    #             txt += delta_ans
+
+    #         self.callback("gen_citations", {}, txt, elapsed_time=timer()-st)
+
+    #     def append_user_content(hist, content):
+    #         if hist[-1]["role"] == "user":
+    #             hist[-1]["content"] += content
+    #         else:
+    #             hist.append({"role": "user", "content": content})
+
+    #     st = timer()
+    #     task_desc = await analyze_task_async(self.chat_mdl, prompt, user_request, tool_metas, user_defined_prompt)
+    #     self.callback("analyze_task", {}, task_desc, elapsed_time=timer()-st)
+    #     for _ in range(self._param.max_rounds + 1):
+    #         if self.check_if_canceled("Agent streaming"):
+    #             return
+    #         response, tk = await next_step_async(self.chat_mdl, hist, tool_metas, task_desc, user_defined_prompt)
+    #         # self.callback("next_step", {}, str(response)[:256]+"...")
+    #         token_count += tk or 0
+    #         hist.append({"role": "assistant", "content": response})
+    #         try:
+    #             functions = json_repair.loads(re.sub(r"```.*", "", response))
+    #             if not isinstance(functions, list):
+    #                 raise TypeError(f"List should be returned, but `{functions}`")
+    #             for f in functions:
+    #                 if not isinstance(f, dict):
+    #                     raise TypeError(f"An object type should be returned, but `{f}`")
+
+    #             tool_tasks = []
+    #             for func in functions:
+    #                 name = func["name"]
+    #                 args = func["arguments"]
+    #                 if name == COMPLETE_TASK:
+    #                     append_user_content(hist, f"Respond with a formal answer. FORGET(DO NOT mention) about `{COMPLETE_TASK}`. The language for the response MUST be as the same as the first user request.\n")
+    #                     async for txt, tkcnt in complete():
+    #                         yield txt, tkcnt
+    #                     return
+
+    #                 tool_tasks.append(asyncio.create_task(use_tool_async(name, args)))
+
+    #             results = await asyncio.gather(*tool_tasks) if tool_tasks else []
+    #             st = timer()
+    #             reflection = await reflect_async(self.chat_mdl, hist, results, user_defined_prompt)
+    #             append_user_content(hist, reflection)
+    #             self.callback("reflection", {}, str(reflection), elapsed_time=timer()-st)
+
+    #         except Exception as e:
+    #             logging.exception(msg=f"Wrong JSON argument format in LLM ReAct response: {e}")
+    #             e = f"\nTool call error, please correct the input parameter of response format and call it again.\n *** Exception ***\n{e}"
+    #             append_user_content(hist, str(e))
+
+    #     logging.warning( f"Exceed max rounds: {self._param.max_rounds}")
+    #     final_instruction = f"""
+    # {user_request}
+    # IMPORTANT: You have reached the conversation limit. Based on ALL the information and research you have gathered so far, please provide a DIRECT and COMPREHENSIVE final answer to the original request.
+    # Instructions:
+    # 1. SYNTHESIZE all information collected during this conversation
+    # 2. Provide a COMPLETE response using existing data - do not suggest additional research
+    # 3. Structure your response as a FINAL DELIVERABLE, not a plan
+    # 4. If information is incomplete, state what you found and provide the best analysis possible with available data
+    # 5. DO NOT mention conversation limits or suggest further steps
+    # 6. Focus on delivering VALUE with the information already gathered
+    # Respond immediately with your final comprehensive answer.
+    # """
+    #     if self.check_if_canceled("Agent final instruction"):
+    #         return
+    #     append_user_content(hist, final_instruction)
+
+    #     async for txt, tkcnt in complete():
+    #         yield txt, tkcnt
+
     async def _gen_citations_async(self, text):
         retrievals = self._canvas.get_reference()
         retrievals = {"chunks": list(retrievals["chunks"].values()), "doc_aggs": list(retrievals["doc_aggs"].values())}
@@ -361,7 +361,7 @@ class ComponentParamBase(ABC):
 class ComponentBase(ABC):
     component_name: str
     thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))
-    variable_ref_patt = r"\{* *\{([a-zA-Z_:0-9]+@[A-Za-z0-9_.]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"
+    variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*"

     def __str__(self):
         """
@@ -28,7 +28,7 @@ class BeginParam(UserFillUpParam):
         self.prologue = "Hi! I'm your smart assistant. What can I do for you?"

     def check(self):
-        self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task"])
+        self.check_valid_value(self.mode, "The 'mode' should be either `conversational` or `task`", ["conversational", "task","Webhook"])

     def get_input_form(self) -> dict[str, dict]:
         return getattr(self, "inputs")
@@ -56,7 +56,6 @@ class LLMParam(ComponentParamBase):
         self.check_nonnegative_number(int(self.max_tokens), "[Agent] Max tokens")
         self.check_decimal_float(float(self.top_p), "[Agent] Top P")
         self.check_empty(self.llm_id, "[Agent] LLM")
-        self.check_empty(self.sys_prompt, "[Agent] System prompt")
         self.check_empty(self.prompts, "[Agent] User prompt")

     def gen_conf(self):
@@ -113,6 +113,10 @@ class LoopItem(ComponentBase, ABC):
                 return len(var) == 0
             elif operator == "not empty":
                 return len(var) > 0
+        elif var is None:
+            if operator == "empty":
+                return True
+            return False

         raise Exception(f"Invalid operator: {operator}")
@@ -1,38 +0,0 @@
-#
-#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
-#
-#  Licensed under the Apache License, Version 2.0 (the "License");
-#  you may not use this file except in compliance with the License.
-#  You may obtain a copy of the License at
-#
-#      http://www.apache.org/licenses/LICENSE-2.0
-#
-#  Unless required by applicable law or agreed to in writing, software
-#  distributed under the License is distributed on an "AS IS" BASIS,
-#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#  See the License for the specific language governing permissions and
-#  limitations under the License.
-#
-from agent.component.base import ComponentParamBase, ComponentBase
-
-
-class WebhookParam(ComponentParamBase):
-
-    """
-    Define the Begin component parameters.
-    """
-    def __init__(self):
-        super().__init__()
-
-    def get_input_form(self) -> dict[str, dict]:
-        return getattr(self, "inputs")
-
-
-class Webhook(ComponentBase):
-    component_name = "Webhook"
-
-    def _invoke(self, **kwargs):
-        pass
-
-    def thoughts(self) -> str:
-        return ""
@@ -25,10 +25,12 @@ from api.db.services.document_service import DocumentService
 from common.metadata_utils import apply_meta_data_filter
 from api.db.services.knowledgebase_service import KnowledgebaseService
 from api.db.services.llm_service import LLMBundle
+from api.db.services.memory_service import MemoryService
+from api.db.joint_services.memory_message_service import query_message
 from common import settings
 from common.connection_utils import timeout
 from rag.app.tag import label_question
-from rag.prompts.generator import cross_languages, kb_prompt
+from rag.prompts.generator import cross_languages, kb_prompt, memory_prompt


 class RetrievalParam(ToolParamBase):
@@ -57,6 +59,7 @@ class RetrievalParam(ToolParamBase):
         self.top_n = 8
         self.top_k = 1024
         self.kb_ids = []
+        self.memory_ids = []
         self.kb_vars = []
         self.rerank_id = ""
         self.empty_response = ""
@@ -81,15 +84,7 @@
 class Retrieval(ToolBase, ABC):
     component_name = "Retrieval"

-    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
-    async def _invoke_async(self, **kwargs):
-        if self.check_if_canceled("Retrieval processing"):
-            return
-
-        if not kwargs.get("query"):
-            self.set_output("formalized_content", self._param.empty_response)
-            return
-
+    async def _retrieve_kb(self, query_text: str):
         kb_ids: list[str] = []
         for id in self._param.kb_ids:
             if id.find("@") < 0:
@@ -124,12 +119,12 @@ class Retrieval(ToolBase, ABC):
         if self._param.rerank_id:
             rerank_mdl = LLMBundle(kbs[0].tenant_id, LLMType.RERANK, self._param.rerank_id)

-        vars = self.get_input_elements_from_text(kwargs["query"])
-        vars = {k:o["value"] for k,o in vars.items()}
-        query = self.string_format(kwargs["query"], vars)
+        vars = self.get_input_elements_from_text(query_text)
+        vars = {k: o["value"] for k, o in vars.items()}
+        query = self.string_format(query_text, vars)

-        doc_ids=[]
-        if self._param.meta_data_filter!={}:
+        doc_ids = []
+        if self._param.meta_data_filter != {}:
             metas = DocumentService.get_meta_by_kbs(kb_ids)

             def _resolve_manual_filter(flt: dict) -> dict:
@@ -198,18 +193,20 @@ class Retrieval(ToolBase, ABC):

         if self._param.toc_enhance:
             chat_mdl = LLMBundle(self._canvas._tenant_id, LLMType.CHAT)
-            cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs], chat_mdl, self._param.top_n)
+            cks = settings.retriever.retrieval_by_toc(query, kbinfos["chunks"], [kb.tenant_id for kb in kbs],
+                                                      chat_mdl, self._param.top_n)
             if self.check_if_canceled("Retrieval processing"):
                 return
             if cks:
                 kbinfos["chunks"] = cks
-        kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"], [kb.tenant_id for kb in kbs])
+        kbinfos["chunks"] = settings.retriever.retrieval_by_children(kbinfos["chunks"],
+                                                                     [kb.tenant_id for kb in kbs])
         if self._param.use_kg:
             ck = settings.kg_retriever.retrieval(query,
-                                     [kb.tenant_id for kb in kbs],
-                                     kb_ids,
-                                     embd_mdl,
-                                     LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
+                                                 [kb.tenant_id for kb in kbs],
+                                                 kb_ids,
+                                                 embd_mdl,
+                                                 LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT))
             if self.check_if_canceled("Retrieval processing"):
                 return
             if ck["content_with_weight"]:
@@ -218,7 +215,8 @@ class Retrieval(ToolBase, ABC):
             kbinfos = {"chunks": [], "doc_aggs": []}

         if self._param.use_kg and kbs:
-            ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl, LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
+            ck = settings.kg_retriever.retrieval(query, [kb.tenant_id for kb in kbs], filtered_kb_ids, embd_mdl,
+                                                 LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
             if self.check_if_canceled("Retrieval processing"):
                 return
             if ck["content_with_weight"]:
@@ -248,6 +246,54 @@ class Retrieval(ToolBase, ABC):

         return form_cnt

+    async def _retrieve_memory(self, query_text: str):
+        memory_ids: list[str] = [memory_id for memory_id in self._param.memory_ids]
+        memory_list = MemoryService.get_by_ids(memory_ids)
+        if not memory_list:
+            raise Exception("No memory is selected.")
+
+        embd_names = list({memory.embd_id for memory in memory_list})
+        assert len(embd_names) == 1, "Memory use different embedding models."
+
+        vars = self.get_input_elements_from_text(query_text)
+        vars = {k: o["value"] for k, o in vars.items()}
+        query = self.string_format(query_text, vars)
+        # query message
+        message_list = query_message({"memory_id": memory_ids}, {
+            "query": query,
+            "similarity_threshold": self._param.similarity_threshold,
+            "keywords_similarity_weight": self._param.keywords_similarity_weight,
+            "top_n": self._param.top_n
+        })
+        print(f"found {len(message_list)} messages.")
+
+        if not message_list:
+            self.set_output("formalized_content", self._param.empty_response)
+            return
+        formated_content = "\n".join(memory_prompt(message_list, 200000))
+
+        # set formalized_content output
+        self.set_output("formalized_content", formated_content)
+        print(f"formated_content {formated_content}")
+        return formated_content
+
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
+    async def _invoke_async(self, **kwargs):
+        if self.check_if_canceled("Retrieval processing"):
+            return
+        print(f"debug retrieval, query is {kwargs.get('query')}.", flush=True)
+        if not kwargs.get("query"):
+            self.set_output("formalized_content", self._param.empty_response)
+            return
+
+        if self._param.kb_ids:
+            return await self._retrieve_kb(kwargs["query"])
+        elif self._param.memory_ids:
+            return await self._retrieve_memory(kwargs["query"])
+        else:
+            self.set_output("formalized_content", self._param.empty_response)
+            return
+
+    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 12)))
+    def _invoke(self, **kwargs):
+        return asyncio.run(self._invoke_async(**kwargs))
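A compact sketch of the dispatch rule the new `_invoke_async` introduces (not part of the diff; class name, parameters, and return strings are invented): knowledge-base ids take precedence, memory ids are the fallback, and a missing query or empty configuration short-circuits to the empty response.

```python
import asyncio


class RetrievalDispatch:
    """Illustration of the kb-vs-memory dispatch; not the real component."""

    def __init__(self, kb_ids=None, memory_ids=None, empty_response="<empty>"):
        self.kb_ids = kb_ids or []
        self.memory_ids = memory_ids or []
        self.empty_response = empty_response

    async def _retrieve_kb(self, query):
        return f"kb results for {query!r}"

    async def _retrieve_memory(self, query):
        return f"memory results for {query!r}"

    async def invoke(self, query):
        if not query:
            return self.empty_response
        if self.kb_ids:
            return await self._retrieve_kb(query)
        if self.memory_ids:
            return await self._retrieve_memory(query)
        return self.empty_response


print(asyncio.run(RetrievalDispatch(kb_ids=["kb1"]).invoke("pricing")))
print(asyncio.run(RetrievalDispatch(memory_ids=["m1"]).invoke("pricing")))
print(asyncio.run(RetrievalDispatch().invoke("pricing")))
```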
@@ -108,7 +108,7 @@ def _load_user():
     authorization = request.headers.get("Authorization")
     g.user = None
     if not authorization:
-        return
+        return None

     try:
         access_token = str(jwt.loads(authorization))
@@ -192,7 +192,7 @@ async def rerun():
         if 0 < doc["progress"] < 1:
             return get_data_error_result(message=f"`{doc['name']}` is processing...")

-        if settings.docStoreConn.indexExist(search.index_name(current_user.id), doc["kb_id"]):
+        if settings.docStoreConn.index_exist(search.index_name(current_user.id), doc["kb_id"]):
             settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
         doc["progress_msg"] = ""
         doc["chunk_num"] = 0
@@ -76,6 +76,7 @@ async def list_chunk():
                 "image_id": sres.field[id].get("img_id", ""),
                 "available_int": int(sres.field[id].get("available_int", 1)),
                 "positions": sres.field[id].get("position_int", []),
+                "doc_type_kwd": sres.field[id].get("doc_type_kwd")
             }
             assert isinstance(d["positions"], list)
             assert len(d["positions"]) == 0 or (isinstance(d["positions"][0], list) and len(d["positions"][0]) == 5)
@@ -178,8 +179,9 @@ async def set():
             # update image
             image_base64 = req.get("image_base64", None)
             if image_base64:
+                bkt, name = req.get("img_id", "-").split("-")
                 image_binary = base64.b64decode(image_base64)
-                settings.STORAGE_IMPL.put(req["doc_id"], req["chunk_id"], image_binary)
+                settings.STORAGE_IMPL.put(bkt, name, image_binary)
             return get_json_result(data=True)

         return await asyncio.to_thread(_set_sync)
@@ -250,14 +250,50 @@ async def list_docs():
     metadata_condition = req.get("metadata_condition", {}) or {}
     if metadata_condition and not isinstance(metadata_condition, dict):
         return get_data_error_result(message="metadata_condition must be an object.")
+    metadata = req.get("metadata", {}) or {}
+    if metadata and not isinstance(metadata, dict):
+        return get_data_error_result(message="metadata must be an object.")

     doc_ids_filter = None
-    if metadata_condition:
+    metas = None
+    if metadata_condition or metadata:
         metas = DocumentService.get_flatted_meta_by_kbs([kb_id])
-        doc_ids_filter = meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and"))
+
+    if metadata_condition:
+        doc_ids_filter = set(meta_filter(metas, convert_conditions(metadata_condition), metadata_condition.get("logic", "and")))
+        if metadata_condition.get("conditions") and not doc_ids_filter:
+            return get_json_result(data={"total": 0, "docs": []})
+
+    if metadata:
+        metadata_doc_ids = None
+        for key, values in metadata.items():
+            if not values:
+                continue
+            if not isinstance(values, list):
+                values = [values]
+            values = [str(v) for v in values if v is not None and str(v).strip()]
+            if not values:
+                continue
+            key_doc_ids = set()
+            for value in values:
+                key_doc_ids.update(metas.get(key, {}).get(value, []))
+            if metadata_doc_ids is None:
+                metadata_doc_ids = key_doc_ids
+            else:
+                metadata_doc_ids &= key_doc_ids
+            if not metadata_doc_ids:
+                return get_json_result(data={"total": 0, "docs": []})
+        if metadata_doc_ids is not None:
+            if doc_ids_filter is None:
+                doc_ids_filter = metadata_doc_ids
+            else:
+                doc_ids_filter &= metadata_doc_ids
+            if not doc_ids_filter:
+                return get_json_result(data={"total": 0, "docs": []})
+
+    if doc_ids_filter is not None:
+        doc_ids_filter = list(doc_ids_filter)

     try:
         docs, tol = DocumentService.get_by_kb_id(kb_id, page_number, items_per_page, orderby, desc, keywords, run_status, types, suffix, doc_ids_filter)
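A self-contained sketch of the intersection logic this hunk adds (not part of the diff; the flattened metadata index and the filters are invented): each metadata key narrows the candidate document set, and an empty intersection short-circuits to an empty result.

```python
# Flattened metadata index: key -> value -> doc ids (shape assumed from the diff).
metas = {
    "author": {"alice": ["d1", "d2"], "bob": ["d3"]},
    "year": {"2024": ["d2", "d3"]},
}


def filter_by_metadata(metas, metadata):
    doc_ids = None
    for key, values in metadata.items():
        if not isinstance(values, list):
            values = [values]
        key_doc_ids = set()
        for value in values:
            key_doc_ids.update(metas.get(key, {}).get(str(value), []))
        doc_ids = key_doc_ids if doc_ids is None else doc_ids & key_doc_ids
        if not doc_ids:
            return set()  # no document can match anymore
    return doc_ids or set()


print(filter_by_metadata(metas, {"author": ["alice"], "year": ["2024"]}))  # {'d2'}
print(filter_by_metadata(metas, {"author": ["bob"], "year": ["2024"]}))    # {'d3'}
print(filter_by_metadata(metas, {"author": ["alice"], "year": ["1999"]}))  # set()
```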
@@ -411,6 +447,26 @@ async def metadata_update():
     return get_json_result(data={"updated": updated, "matched_docs": len(target_doc_ids)})


+@manager.route("/update_metadata_setting", methods=["POST"])  # noqa: F821
+@login_required
+@validate_request("doc_id", "metadata")
+async def update_metadata_setting():
+    req = await get_request_json()
+    if not DocumentService.accessible(req["doc_id"], current_user.id):
+        return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
+
+    e, doc = DocumentService.get_by_id(req["doc_id"])
+    if not e:
+        return get_data_error_result(message="Document not found!")
+
+    DocumentService.update_parser_config(doc.id, {"metadata": req["metadata"]})
+    e, doc = DocumentService.get_by_id(doc.id)
+    if not e:
+        return get_data_error_result(message="Document not found!")
+
+    return get_json_result(data=doc.to_dict())
+
+
 @manager.route("/thumbnails", methods=["GET"])  # noqa: F821
 # @login_required
 def thumbnails():
@@ -528,7 +584,7 @@ async def run():
             DocumentService.update_by_id(id, info)
             if req.get("delete", False):
                 TaskService.filter_delete([Task.doc_id == id])
-                if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+                if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
                     settings.docStoreConn.delete({"doc_id": id}, search.index_name(tenant_id), doc.kb_id)

         if str(req["run"]) == TaskStatus.RUNNING.value:
@@ -579,7 +635,7 @@ async def rename():
                 "title_tks": title_tks,
                 "title_sm_tks": rag_tokenizer.fine_grained_tokenize(title_tks),
             }
-            if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+            if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
                 settings.docStoreConn.update(
                     {"doc_id": req["doc_id"]},
                     es_body,
@@ -660,7 +716,7 @@ async def change_parser():
         tenant_id = DocumentService.get_tenant_id(req["doc_id"])
         if not tenant_id:
             return get_data_error_result(message="Tenant not found!")
-        if settings.docStoreConn.indexExist(search.index_name(tenant_id), doc.kb_id):
+        if settings.docStoreConn.index_exist(search.index_name(tenant_id), doc.kb_id):
            settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id)
        return None
@@ -39,9 +39,9 @@ from api.utils.api_utils import get_json_result
 from rag.nlp import search
 from api.constants import DATASET_NAME_LIMIT
 from rag.utils.redis_conn import REDIS_CONN
-from rag.utils.doc_store_conn import OrderByExpr
 from common.constants import RetCode, PipelineTaskType, StatusEnum, VALID_TASK_STATUS, FileSource, LLMType, PAGERANK_FLD
 from common import settings
+from common.doc_store.doc_store_base import OrderByExpr
 from api.apps import login_required, current_user

@@ -97,6 +97,19 @@ async def update():
                                    code=RetCode.OPERATING_ERROR)

         e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
+
+        # Rename folder in FileService
+        if e and req["name"].lower() != kb.name.lower():
+            FileService.filter_update(
+                [
+                    File.tenant_id == kb.tenant_id,
+                    File.source_type == FileSource.KNOWLEDGEBASE,
+                    File.type == "folder",
+                    File.name == kb.name,
+                ],
+                {"name": req["name"]},
+            )
+
         if not e:
             return get_data_error_result(
                 message="Can't find this dataset!")
@@ -150,6 +163,21 @@ async def update():
         return server_error_response(e)


+@manager.route('/update_metadata_setting', methods=['post'])  # noqa: F821
+@login_required
+@validate_request("kb_id", "metadata")
+async def update_metadata_setting():
+    req = await get_request_json()
+    e, kb = KnowledgebaseService.get_by_id(req["kb_id"])
+    if not e:
+        return get_data_error_result(
+            message="Database error (Knowledgebase rename)!")
+    kb = kb.to_dict()
+    kb["parser_config"]["metadata"] = req["metadata"]
+    KnowledgebaseService.update_by_id(kb["id"], kb)
+    return get_json_result(data=kb)
+
+
 @manager.route('/detail', methods=['GET'])  # noqa: F821
 @login_required
 def detail():
@@ -245,13 +273,19 @@ async def rm():
                 FileService.filter_delete([File.source_type == FileSource.KNOWLEDGEBASE, File.id == f2d[0].file_id])
             File2DocumentService.delete_by_document_id(doc.id)
         FileService.filter_delete(
-            [File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
+            [
+                File.tenant_id == kbs[0].tenant_id,
+                File.source_type == FileSource.KNOWLEDGEBASE,
+                File.type == "folder",
+                File.name == kbs[0].name,
+            ]
+        )
         if not KnowledgebaseService.delete_by_id(req["kb_id"]):
             return get_data_error_result(
                 message="Database error (Knowledgebase removal)!")
         for kb in kbs:
             settings.docStoreConn.delete({"kb_id": kb.id}, search.index_name(kb.tenant_id), kb.id)
-            settings.docStoreConn.deleteIdx(search.index_name(kb.tenant_id), kb.id)
+            settings.docStoreConn.delete_idx(search.index_name(kb.tenant_id), kb.id)
             if hasattr(settings.STORAGE_IMPL, 'remove_bucket'):
                 settings.STORAGE_IMPL.remove_bucket(kb.id)
         return get_json_result(data=True)
@@ -352,7 +386,7 @@ def knowledge_graph(kb_id):
     }

     obj = {"graph": {}, "mind_map": {}}
-    if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), kb_id):
+    if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), kb_id):
         return get_json_result(data=obj)
     sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [kb_id])
     if not len(sres.ids):
@@ -824,11 +858,11 @@ async def check_embedding():
     index_nm = search.index_name(tenant_id)

     res0 = docStoreConn.search(
-        selectFields=[], highlightFields=[],
+        select_fields=[], highlight_fields=[],
         condition={"kb_id": kb_id, "available_int": 1},
-        matchExprs=[], orderBy=OrderByExpr(),
+        match_expressions=[], order_by=OrderByExpr(),
         offset=0, limit=1,
-        indexNames=index_nm, knowledgebaseIds=[kb_id]
+        index_names=index_nm, knowledgebase_ids=[kb_id]
     )
     total = docStoreConn.get_total(res0)
     if total <= 0:
@@ -840,14 +874,14 @@ async def check_embedding():

     for off in offsets:
         res1 = docStoreConn.search(
-            selectFields=list(base_fields),
-            highlightFields=[],
+            select_fields=list(base_fields),
+            highlight_fields=[],
             condition={"kb_id": kb_id, "available_int": 1},
-            matchExprs=[], orderBy=OrderByExpr(),
+            match_expressions=[], order_by=OrderByExpr(),
             offset=off, limit=1,
-            indexNames=index_nm, knowledgebaseIds=[kb_id]
+            index_names=index_nm, knowledgebase_ids=[kb_id]
         )
-        ids = docStoreConn.get_chunk_ids(res1)
+        ids = docStoreConn.get_doc_ids(res1)
         if not ids:
             continue

@@ -25,7 +25,7 @@ from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result
 from common.constants import StatusEnum, LLMType
 from api.db.db_models import TenantLLM
 from rag.utils.base64_image import test_image
-from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel
+from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel, Seq2txtModel


 @manager.route("/factories", methods=["GET"])  # noqa: F821
@manager.route("/factories", methods=["GET"]) # noqa: F821
|
||||
@ -157,7 +157,7 @@ async def add_llm():
|
||||
elif factory == "Bedrock":
|
||||
# For Bedrock, due to its special authentication method
|
||||
# Assemble bedrock_ak, bedrock_sk, bedrock_region
|
||||
api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])
|
||||
api_key = apikey_json(["auth_mode", "bedrock_ak", "bedrock_sk", "bedrock_region", "aws_role_arn"])
|
||||
|
||||
elif factory == "LocalAI":
|
||||
llm_name += "___LocalAI"
|
||||
@@ -208,70 +208,83 @@ async def add_llm():
     msg = ""
     mdl_nm = llm["llm_name"].split("___")[0]
     extra = {"provider": factory}
-    if llm["model_type"] == LLMType.EMBEDDING.value:
-        assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
-        mdl = EmbeddingModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-        try:
-            arr, tc = mdl.encode(["Test if the api key is available"])
-            if len(arr[0]) == 0:
-                raise Exception("Fail")
-        except Exception as e:
-            msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.CHAT.value:
-        assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
-        mdl = ChatModel[factory](
-            key=llm["api_key"],
-            model_name=mdl_nm,
-            base_url=llm["api_base"],
-            **extra,
-        )
-        try:
-            m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}], {"temperature": 0.9})
-            if not tc and m.find("**ERROR**:") >= 0:
-                raise Exception(m)
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.RERANK:
-        assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
-        try:
-            mdl = RerankModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-            arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
-            if len(arr) == 0:
-                raise Exception("Not known.")
-        except KeyError:
-            msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
-        assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
-        mdl = CvModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-        try:
-            image_data = test_image
-            m, tc = mdl.describe(image_data)
-            if not tc and m.find("**ERROR**:") >= 0:
-                raise Exception(m)
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.TTS:
-        assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
-        mdl = TTSModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm["api_base"])
-        try:
-            for resp in mdl.tts("Hello~ RAGFlower!"):
-                pass
-        except RuntimeError as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    elif llm["model_type"] == LLMType.OCR.value:
-        assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
-        try:
-            mdl = OcrModel[factory](key=llm["api_key"], model_name=mdl_nm, base_url=llm.get("api_base", ""))
-            ok, reason = mdl.check_available()
-            if not ok:
-                raise RuntimeError(reason or "Model not available")
-        except Exception as e:
-            msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
-    else:
-        # TODO: check other type of models
-        pass
+    model_type = llm["model_type"]
+    model_api_key = llm["api_key"]
+    model_base_url = llm.get("api_base", "")
+    match model_type:
+        case LLMType.EMBEDDING.value:
+            assert factory in EmbeddingModel, f"Embedding model from {factory} is not supported yet."
+            mdl = EmbeddingModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+            try:
+                arr, tc = mdl.encode(["Test if the api key is available"])
+                if len(arr[0]) == 0:
+                    raise Exception("Fail")
+            except Exception as e:
+                msg += f"\nFail to access embedding model({mdl_nm})." + str(e)
+        case LLMType.CHAT.value:
+            assert factory in ChatModel, f"Chat model from {factory} is not supported yet."
+            mdl = ChatModel[factory](
+                key=model_api_key,
+                model_name=mdl_nm,
+                base_url=model_base_url,
+                **extra,
+            )
+            try:
+                m, tc = await mdl.async_chat(None, [{"role": "user", "content": "Hello! How are you doing!"}],
+                                             {"temperature": 0.9})
+                if not tc and m.find("**ERROR**:") >= 0:
+                    raise Exception(m)
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+
+        case LLMType.RERANK.value:
+            assert factory in RerankModel, f"RE-rank model from {factory} is not supported yet."
+            try:
+                mdl = RerankModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+                arr, tc = mdl.similarity("Hello~ RAGFlower!", ["Hi, there!", "Ohh, my friend!"])
+                if len(arr) == 0:
+                    raise Exception("Not known.")
+            except KeyError:
+                msg += f"{factory} dose not support this model({factory}/{mdl_nm})"
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+
+        case LLMType.IMAGE2TEXT.value:
+            assert factory in CvModel, f"Image to text model from {factory} is not supported yet."
+            mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+            try:
+                image_data = test_image
+                m, tc = mdl.describe(image_data)
+                if not tc and m.find("**ERROR**:") >= 0:
+                    raise Exception(m)
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case LLMType.TTS.value:
+            assert factory in TTSModel, f"TTS model from {factory} is not supported yet."
+            mdl = TTSModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+            try:
+                for resp in mdl.tts("Hello~ RAGFlower!"):
+                    pass
+            except RuntimeError as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case LLMType.OCR.value:
+            assert factory in OcrModel, f"OCR model from {factory} is not supported yet."
+            try:
+                mdl = OcrModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+                ok, reason = mdl.check_available()
+                if not ok:
+                    raise RuntimeError(reason or "Model not available")
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case LLMType.SPEECH2TEXT:
+            assert factory in Seq2txtModel, f"Speech model from {factory} is not supported yet."
+            try:
+                mdl = Seq2txtModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url)
+                # TODO: check the availability
+            except Exception as e:
+                msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e)
+        case _:
+            raise RuntimeError(f"Unknown model type: {model_type}")

     if msg:
         return get_data_error_result(message=msg)
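A tiny self-contained illustration of the dispatch style this hunk moves to (not part of the diff; toy enum and handlers, Python 3.10+ assumed for `match`): matching on the enum's `.value` string instead of walking an if/elif chain.

```python
from enum import Enum


class ToyLLMType(Enum):
    EMBEDDING = "embedding"
    CHAT = "chat"


def validate(model_type: str) -> str:
    match model_type:
        case ToyLLMType.EMBEDDING.value:
            return "checked embedding model"
        case ToyLLMType.CHAT.value:
            return "checked chat model"
        case _:
            raise RuntimeError(f"Unknown model type: {model_type}")


print(validate("chat"))                         # checked chat model
print(validate(ToyLLMType.EMBEDDING.value))     # checked embedding model
```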
@@ -13,6 +13,8 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 #
+import asyncio
+
 from quart import Response, request
 from api.apps import current_user, login_required
@@ -106,7 +108,7 @@ async def create() -> Response:
         return get_data_error_result(message="Tenant not found.")

     mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
-    server_tools, err_message = get_mcp_tools([mcp_server], timeout)
+    server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
     if err_message:
         return get_data_error_result(err_message)
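A minimal sketch of the pattern applied throughout this file (not part of the diff; the blocking function is a stand-in for calls such as `get_mcp_tools`): pushing the blocking call onto a worker thread with `asyncio.to_thread` so the async event loop stays responsive.

```python
import asyncio
import time


def fetch_tools_blocking(server: str, timeout: float):
    # Stand-in for a blocking MCP call; sleeps to simulate network latency.
    time.sleep(0.1)
    return [f"tool-from-{server}"], None


async def handler():
    # Runs the blocking call in a thread; the event loop keeps serving other requests.
    tools, err = await asyncio.to_thread(fetch_tools_blocking, "demo-server", 5.0)
    print(tools, err)


asyncio.run(handler())
```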
@@ -158,7 +160,7 @@ async def update() -> Response:
     req["id"] = mcp_id

     mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
-    server_tools, err_message = get_mcp_tools([mcp_server], timeout)
+    server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
     if err_message:
         return get_data_error_result(err_message)
@@ -242,7 +244,7 @@ async def import_multiple() -> Response:
         headers = {"authorization_token": config["authorization_token"]} if "authorization_token" in config else {}
         variables = {k: v for k, v in config.items() if k not in {"type", "url", "headers"}}
         mcp_server = MCPServer(id=new_name, name=new_name, url=config["url"], server_type=config["type"], variables=variables, headers=headers)
-        server_tools, err_message = get_mcp_tools([mcp_server], timeout)
+        server_tools, err_message = await asyncio.to_thread(get_mcp_tools, [mcp_server], timeout)
         if err_message:
             results.append({"server": base_name, "success": False, "message": err_message})
             continue
@@ -322,9 +324,8 @@ async def list_tools() -> Response:
             tool_call_sessions.append(tool_call_session)

             try:
-                tools = tool_call_session.get_tools(timeout)
+                tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
             except Exception as e:
                 tools = []
                 return get_data_error_result(message=f"MCP list tools error: {e}")

             results[server_key] = []
@@ -340,7 +341,7 @@ async def list_tools() -> Response:
         return server_error_response(e)
     finally:
         # PERF: blocking call to close sessions — consider moving to background thread or task queue
-        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
+        await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)


 @manager.route("/test_tool", methods=["POST"])  # noqa: F821
@@ -367,10 +368,10 @@ async def test_tool() -> Response:

         tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)
         tool_call_sessions.append(tool_call_session)
-        result = tool_call_session.tool_call(tool_name, arguments, timeout)
+        result = await asyncio.to_thread(tool_call_session.tool_call, tool_name, arguments, timeout)

         # PERF: blocking call to close sessions — consider moving to background thread or task queue
-        close_multiple_mcp_toolcall_sessions(tool_call_sessions)
+        await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, tool_call_sessions)
         return get_json_result(data=result)
     except Exception as e:
         return server_error_response(e)
@@ -424,13 +425,12 @@ async def test_mcp() -> Response:
     tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables)

     try:
-        tools = tool_call_session.get_tools(timeout)
+        tools = await asyncio.to_thread(tool_call_session.get_tools, timeout)
     except Exception as e:
         tools = []
         return get_data_error_result(message=f"Test MCP error: {e}")
     finally:
         # PERF: blocking call to close sessions — consider moving to background thread or task queue
-        close_multiple_mcp_toolcall_sessions([tool_call_session])
+        await asyncio.to_thread(close_multiple_mcp_toolcall_sessions, [tool_call_session])

     for tool in tools:
         tool_dict = tool.model_dump()
@@ -20,10 +20,12 @@ from api.apps import login_required, current_user
 from api.db import TenantPermission
 from api.db.services.memory_service import MemoryService
 from api.db.services.user_service import UserTenantService
+from api.db.services.canvas_service import UserCanvasService
 from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result, \
     not_allowed_parameters
 from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human
 from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT
+from memory.services.messages import MessageService
 from common.constants import MemoryType, RetCode, ForgettingPolicy
@ -57,7 +59,6 @@ async def create_memory():
|
||||
|
||||
if res:
|
||||
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
|
||||
|
||||
else:
|
||||
return get_json_result(message=memory, code=RetCode.SERVER_ERROR)
|
||||
|
||||
@ -124,7 +125,7 @@ async def update_memory(memory_id):
|
||||
return get_json_result(message=True, data=memory_dict)
|
||||
|
||||
try:
|
||||
MemoryService.update_memory(memory_id, to_update)
|
||||
MemoryService.update_memory(current_memory.tenant_id, memory_id, to_update)
|
||||
updated_memory = MemoryService.get_by_memory_id(memory_id)
|
||||
return get_json_result(message=True, data=format_ret_data_from_memory(updated_memory))
|
||||
|
||||
@ -133,7 +134,7 @@ async def update_memory(memory_id):
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
|
||||
@manager.route("/<memory_id>", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
async def delete_memory(memory_id):
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
@ -141,13 +142,14 @@ async def delete_memory(memory_id):
|
||||
return get_json_result(message=True, code=RetCode.NOT_FOUND)
|
||||
try:
|
||||
MemoryService.delete_memory(memory_id)
|
||||
MessageService.delete_message({"memory_id": memory_id}, memory.tenant_id, memory_id)
|
||||
return get_json_result(message=True)
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
return get_json_result(message=str(e), code=RetCode.SERVER_ERROR)
|
||||
|
||||
|
||||
@manager.route("", methods=["GET"]) # noqa: F821
|
||||
@manager.route("", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def list_memory():
|
||||
args = request.args
|
||||
@ -183,3 +185,26 @@ async def get_memory_config(memory_id):
|
||||
if not memory:
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
|
||||
return get_json_result(message=True, data=format_ret_data_from_memory(memory))
|
||||
|
||||
|
||||
@manager.route("/<memory_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_memory_detail(memory_id):
|
||||
args = request.args
|
||||
agent_ids = args.getlist("agent_id")
|
||||
keywords = args.get("keywords", "")
|
||||
keywords = keywords.strip()
|
||||
page = int(args.get("page", 1))
|
||||
page_size = int(args.get("page_size", 50))
|
||||
memory = MemoryService.get_by_memory_id(memory_id)
|
||||
if not memory:
|
||||
return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")
|
||||
messages = MessageService.list_message(
|
||||
memory.tenant_id, memory_id, agent_ids, keywords, page, page_size)
|
||||
agent_name_mapping = {}
|
||||
if messages["message_list"]:
|
||||
agent_list = UserCanvasService.get_basic_info_by_canvas_ids([message["agent_id"] for message in messages["message_list"]])
|
||||
agent_name_mapping = {agent["id"]: agent["title"] for agent in agent_list}
|
||||
for message in messages["message_list"]:
|
||||
message["agent_name"] = agent_name_mapping.get(message["agent_id"], "Unknown")
|
||||
return get_json_result(data={"messages": messages, "storage_type": memory.storage_type}, message=True)
|
||||
|
||||
169  api/apps/messages_app.py  Normal file
@@ -0,0 +1,169 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from quart import request
from api.apps import login_required
from api.db.services.memory_service import MemoryService
from common.time_utils import current_timestamp, timestamp_to_date

from memory.services.messages import MessageService
from api.db.joint_services import memory_message_service
from api.db.joint_services.memory_message_service import query_message
from api.utils.api_utils import validate_request, get_request_json, get_error_argument_result, get_json_result
from common.constants import RetCode


@manager.route("", methods=["POST"])  # noqa: F821
@login_required
@validate_request("memory_id", "agent_id", "session_id", "user_input", "agent_response")
async def add_message():
    req = await get_request_json()
    memory_ids = req["memory_id"]
    agent_id = req["agent_id"]
    session_id = req["session_id"]
    user_id = req["user_id"] if req.get("user_id") else ""
    user_input = req["user_input"]
    agent_response = req["agent_response"]

    res = []
    for memory_id in memory_ids:
        success, msg = await memory_message_service.save_to_memory(
            memory_id,
            {
                "user_id": user_id,
                "agent_id": agent_id,
                "session_id": session_id,
                "user_input": user_input,
                "agent_response": agent_response
            }
        )
        res.append({
            "memory_id": memory_id,
            "success": success,
            "message": msg
        })

    if all([r["success"] for r in res]):
        return get_json_result(message="Successfully added to memories.")

    return get_json_result(code=RetCode.SERVER_ERROR, message="Some messages failed to add.", data=res)


@manager.route("/<memory_id>:<message_id>", methods=["DELETE"])  # noqa: F821
@login_required
async def forget_message(memory_id: str, message_id: int):
    memory = MemoryService.get_by_memory_id(memory_id)
    if not memory:
        return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")

    forget_time = timestamp_to_date(current_timestamp())
    update_succeed = MessageService.update_message(
        {"memory_id": memory_id, "message_id": int(message_id)},
        {"forget_at": forget_time},
        memory.tenant_id, memory_id)
    if update_succeed:
        return get_json_result(message=update_succeed)
    else:
        return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to forget message '{message_id}' in memory '{memory_id}'.")


@manager.route("/<memory_id>:<message_id>", methods=["PUT"])  # noqa: F821
@login_required
@validate_request("status")
async def update_message(memory_id: str, message_id: int):
    req = await get_request_json()
    status = req["status"]
    if not isinstance(status, bool):
        return get_error_argument_result("Status must be a boolean.")

    memory = MemoryService.get_by_memory_id(memory_id)
    if not memory:
        return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")

    update_succeed = MessageService.update_message({"memory_id": memory_id, "message_id": int(message_id)}, {"status": status}, memory.tenant_id, memory_id)
    if update_succeed:
        return get_json_result(message=update_succeed)
    else:
        return get_json_result(code=RetCode.SERVER_ERROR, message=f"Failed to set status for message '{message_id}' in memory '{memory_id}'.")


@manager.route("/search", methods=["GET"])  # noqa: F821
@login_required
async def search_message():
    args = request.args
    print(args, flush=True)
    empty_fields = [f for f in ["memory_id", "query"] if not args.get(f)]
    if empty_fields:
        return get_error_argument_result(f"{', '.join(empty_fields)} can't be empty.")

    memory_ids = args.getlist("memory_id")
    query = args.get("query")
    similarity_threshold = float(args.get("similarity_threshold", 0.2))
    keywords_similarity_weight = float(args.get("keywords_similarity_weight", 0.7))
    top_n = int(args.get("top_n", 5))
    agent_id = args.get("agent_id", "")
    session_id = args.get("session_id", "")

    filter_dict = {
        "memory_id": memory_ids,
        "agent_id": agent_id,
        "session_id": session_id
    }
    params = {
        "query": query,
        "similarity_threshold": similarity_threshold,
        "keywords_similarity_weight": keywords_similarity_weight,
        "top_n": top_n
    }
    res = query_message(filter_dict, params)
    return get_json_result(message=True, data=res)


@manager.route("", methods=["GET"])  # noqa: F821
@login_required
async def get_messages():
    args = request.args
    memory_ids = args.getlist("memory_id")
    agent_id = args.get("agent_id", "")
    session_id = args.get("session_id", "")
    limit = int(args.get("limit", 10))
    if not memory_ids:
        return get_error_argument_result("memory_ids is required.")
    memory_list = MemoryService.get_by_ids(memory_ids)
    uids = [memory.tenant_id for memory in memory_list]
    res = MessageService.get_recent_messages(
        uids,
        memory_ids,
        agent_id,
        session_id,
        limit
    )
    return get_json_result(message=True, data=res)


@manager.route("/<memory_id>:<message_id>/content", methods=["GET"])  # noqa: F821
@login_required
async def get_message_content(memory_id: str, message_id: int):
    memory = MemoryService.get_by_memory_id(memory_id)
    if not memory:
        return get_json_result(code=RetCode.NOT_FOUND, message=f"Memory '{memory_id}' not found.")

    res = MessageService.get_by_message_id(memory_id, message_id, memory.tenant_id)
    if res:
        return get_json_result(message=True, data=res)
    else:
        return get_json_result(code=RetCode.NOT_FOUND, message=f"Message '{message_id}' in memory '{memory_id}' not found.")
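For orientation, a small sketch of the request body the new add_message endpoint validates: memory_id is a list, so one user/agent exchange can be written to several memories in a single call. The IDs below are placeholders, not values taken from this change.

import json

# Illustrative payload for add_message; field names come from the
# @validate_request decorator above, the values are hypothetical.
payload = {
    "memory_id": ["mem_123", "mem_456"],
    "agent_id": "agent_789",
    "session_id": "sess_001",
    "user_input": "How do I install neovim?",
    "agent_response": "Use your package manager, for example brew install neovim.",
}
# The endpoint responds per target memory with {"memory_id", "success", "message"}.
print(json.dumps(payload, indent=2))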
@@ -14,20 +14,29 @@
# limitations under the License.
#

import asyncio
import base64
import hashlib
import hmac
import ipaddress
import json
import logging
import time
from typing import Any, cast

import jwt

from agent.canvas import Canvas
from api.db import CanvasCategory
from api.db.services.canvas_service import UserCanvasService
from api.db.services.file_service import FileService
from api.db.services.user_canvas_version import UserCanvasVersionService
from common.constants import RetCode
from common.misc_utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
from api.utils.api_utils import get_result
from quart import request, Response
from rag.utils.redis_conn import REDIS_CONN


@manager.route('/agents', methods=['GET'])  # noqa: F821
@@ -132,48 +141,785 @@ def delete_agent(tenant_id: str, agent_id: str):
    UserCanvasService.delete_by_id(agent_id)
    return get_json_result(data=True)
@manager.route("/webhook/<agent_id>", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
|
||||
@manager.route("/webhook_test/<agent_id>",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821
|
||||
async def webhook(agent_id: str):
|
||||
is_test = request.path.startswith("/api/v1/webhook_test")
|
||||
start_ts = time.time()
|
||||
|
||||
@manager.route('/webhook/<agent_id>', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
async def webhook(tenant_id: str, agent_id: str):
|
||||
req = await get_request_json()
|
||||
if not UserCanvasService.accessible(req["id"], tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
# 1. Fetch canvas by agent_id
|
||||
exists, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
if not exists:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Canvas not found."),RetCode.BAD_REQUEST
|
||||
|
||||
# 2. Check canvas category
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
return get_data_error_result(message="Dataflow can not be triggered by webhook.")
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Dataflow can not be triggered by webhook."),RetCode.BAD_REQUEST
|
||||
|
||||
# 3. Load DSL from canvas
|
||||
dsl = getattr(cvs, "dsl", None)
|
||||
if not isinstance(dsl, dict):
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Invalid DSL format."),RetCode.BAD_REQUEST
|
||||
|
||||
# 4. Check webhook configuration in DSL
|
||||
components = dsl.get("components", {})
|
||||
for k, _ in components.items():
|
||||
cpn_obj = components[k]["obj"]
|
||||
if cpn_obj["component_name"].lower() == "begin" and cpn_obj["params"]["mode"] == "Webhook":
|
||||
webhook_cfg = cpn_obj["params"]
|
||||
|
||||
if not webhook_cfg:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message="Webhook not configured for this agent."),RetCode.BAD_REQUEST
|
||||
|
||||
# 5. Validate request method against webhook_cfg.methods
|
||||
allowed_methods = webhook_cfg.get("methods", [])
|
||||
request_method = request.method.upper()
|
||||
if allowed_methods and request_method not in allowed_methods:
|
||||
return get_data_error_result(
|
||||
code=RetCode.BAD_REQUEST,message=f"HTTP method '{request_method}' not allowed for this webhook."
|
||||
),RetCode.BAD_REQUEST
|
||||
|
||||
# 6. Validate webhook security
|
||||
async def validate_webhook_security(security_cfg: dict):
|
||||
"""Validate webhook security rules based on security configuration."""
|
||||
|
||||
if not security_cfg:
|
||||
return # No security config → allowed by default
|
||||
|
||||
# 1. Validate max body size
|
||||
await _validate_max_body_size(security_cfg)
|
||||
|
||||
# 2. Validate IP whitelist
|
||||
_validate_ip_whitelist(security_cfg)
|
||||
|
||||
# # 3. Validate rate limiting
|
||||
_validate_rate_limit(security_cfg)
|
||||
|
||||
# 4. Validate authentication
|
||||
auth_type = security_cfg.get("auth_type", "none")
|
||||
|
||||
if auth_type == "none":
|
||||
return
|
||||
|
||||
if auth_type == "token":
|
||||
_validate_token_auth(security_cfg)
|
||||
|
||||
elif auth_type == "basic":
|
||||
_validate_basic_auth(security_cfg)
|
||||
|
||||
elif auth_type == "jwt":
|
||||
_validate_jwt_auth(security_cfg)
|
||||
|
||||
else:
|
||||
raise Exception(f"Unsupported auth_type: {auth_type}")
|
||||
|
||||
async def _validate_max_body_size(security_cfg):
|
||||
"""Check request size does not exceed max_body_size."""
|
||||
max_size = security_cfg.get("max_body_size")
|
||||
if not max_size:
|
||||
return
|
||||
|
||||
# Convert "10MB" → bytes
|
||||
units = {"kb": 1024, "mb": 1024**2}
|
||||
size_str = max_size.lower()
|
||||
|
||||
for suffix, factor in units.items():
|
||||
if size_str.endswith(suffix):
|
||||
limit = int(size_str.replace(suffix, "")) * factor
|
||||
break
|
||||
else:
|
||||
raise Exception("Invalid max_body_size format")
|
||||
MAX_LIMIT = 10 * 1024 * 1024 # 10MB
|
||||
if limit > MAX_LIMIT:
|
||||
raise Exception("max_body_size exceeds maximum allowed size (10MB)")
|
||||
|
||||
content_length = request.content_length or 0
|
||||
if content_length > limit:
|
||||
raise Exception(f"Request body too large: {content_length} > {limit}")
|
||||
|
||||
def _validate_ip_whitelist(security_cfg):
|
||||
"""Allow only IPs listed in ip_whitelist."""
|
||||
whitelist = security_cfg.get("ip_whitelist", [])
|
||||
if not whitelist:
|
||||
return
|
||||
|
||||
client_ip = request.remote_addr
|
||||
|
||||
|
||||
for rule in whitelist:
|
||||
if "/" in rule:
|
||||
# CIDR notation
|
||||
if ipaddress.ip_address(client_ip) in ipaddress.ip_network(rule, strict=False):
|
||||
return
|
||||
else:
|
||||
# Single IP
|
||||
if client_ip == rule:
|
||||
return
|
||||
|
||||
raise Exception(f"IP {client_ip} is not allowed by whitelist")
|
||||
|
||||
def _validate_rate_limit(security_cfg):
|
||||
"""Simple in-memory rate limiting."""
|
||||
rl = security_cfg.get("rate_limit")
|
||||
if not rl:
|
||||
return
|
||||
|
||||
limit = int(rl.get("limit", 60))
|
||||
if limit <= 0:
|
||||
raise Exception("rate_limit.limit must be > 0")
|
||||
per = rl.get("per", "minute")
|
||||
|
||||
window = {
|
||||
"second": 1,
|
||||
"minute": 60,
|
||||
"hour": 3600,
|
||||
"day": 86400,
|
||||
}.get(per)
|
||||
|
||||
if not window:
|
||||
raise Exception(f"Invalid rate_limit.per: {per}")
|
||||
|
||||
capacity = limit
|
||||
rate = limit / window
|
||||
cost = 1
|
||||
|
||||
key = f"rl:tb:{agent_id}"
|
||||
now = time.time()
|
||||
|
||||
try:
|
||||
res = REDIS_CONN.lua_token_bucket(
|
||||
keys=[key],
|
||||
args=[capacity, rate, now, cost],
|
||||
client=REDIS_CONN.REDIS,
|
||||
)
|
||||
|
||||
allowed = int(res[0])
|
||||
if allowed != 1:
|
||||
raise Exception("Too many requests (rate limit exceeded)")
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Rate limit error: {e}")
|
||||
|
||||
def _validate_token_auth(security_cfg):
|
||||
"""Validate header-based token authentication."""
|
||||
token_cfg = security_cfg.get("token",{})
|
||||
header = token_cfg.get("token_header")
|
||||
token_value = token_cfg.get("token_value")
|
||||
|
||||
provided = request.headers.get(header)
|
||||
if provided != token_value:
|
||||
raise Exception("Invalid token authentication")
|
||||
|
||||
def _validate_basic_auth(security_cfg):
|
||||
"""Validate HTTP Basic Auth credentials."""
|
||||
auth_cfg = security_cfg.get("basic_auth", {})
|
||||
username = auth_cfg.get("username")
|
||||
password = auth_cfg.get("password")
|
||||
|
||||
auth = request.authorization
|
||||
if not auth or auth.username != username or auth.password != password:
|
||||
raise Exception("Invalid Basic Auth credentials")
|
||||
|
||||
def _validate_jwt_auth(security_cfg):
|
||||
"""Validate JWT token in Authorization header."""
|
||||
jwt_cfg = security_cfg.get("jwt", {})
|
||||
secret = jwt_cfg.get("secret")
|
||||
if not secret:
|
||||
raise Exception("JWT secret not configured")
|
||||
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
raise Exception("Missing Bearer token")
|
||||
|
||||
token = auth_header[len("Bearer "):].strip()
|
||||
if not token:
|
||||
raise Exception("Empty Bearer token")
|
||||
|
||||
alg = (jwt_cfg.get("algorithm") or "HS256").upper()
|
||||
|
||||
decode_kwargs = {
|
||||
"key": secret,
|
||||
"algorithms": [alg],
|
||||
}
|
||||
options = {}
|
||||
if jwt_cfg.get("audience"):
|
||||
decode_kwargs["audience"] = jwt_cfg["audience"]
|
||||
options["verify_aud"] = True
|
||||
else:
|
||||
options["verify_aud"] = False
|
||||
|
||||
if jwt_cfg.get("issuer"):
|
||||
decode_kwargs["issuer"] = jwt_cfg["issuer"]
|
||||
options["verify_iss"] = True
|
||||
else:
|
||||
options["verify_iss"] = False
|
||||
try:
|
||||
decoded = jwt.decode(
|
||||
token,
|
||||
options=options,
|
||||
**decode_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Invalid JWT: {str(e)}")
|
||||
|
||||
raw_required_claims = jwt_cfg.get("required_claims", [])
|
||||
if isinstance(raw_required_claims, str):
|
||||
required_claims = [raw_required_claims]
|
||||
elif isinstance(raw_required_claims, (list, tuple, set)):
|
||||
required_claims = list(raw_required_claims)
|
||||
else:
|
||||
required_claims = []
|
||||
|
||||
required_claims = [
|
||||
c for c in required_claims
|
||||
if isinstance(c, str) and c.strip()
|
||||
]
|
||||
|
||||
RESERVED_CLAIMS = {"exp", "sub", "aud", "iss", "nbf", "iat"}
|
||||
for claim in required_claims:
|
||||
if claim in RESERVED_CLAIMS:
|
||||
raise Exception(f"Reserved JWT claim cannot be required: {claim}")
|
||||
|
||||
for claim in required_claims:
|
||||
if claim not in decoded:
|
||||
raise Exception(f"Missing JWT claim: {claim}")
|
||||
|
||||
return decoded
|
||||
|
||||
try:
|
||||
canvas = Canvas(cvs.dsl, tenant_id, agent_id)
|
||||
security_config=webhook_cfg.get("security", {})
|
||||
await validate_webhook_security(security_config)
|
||||
except Exception as e:
|
||||
return get_json_result(
|
||||
data=False, message=str(e),
|
||||
code=RetCode.EXCEPTION_ERROR)
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
|
||||
if not isinstance(cvs.dsl, str):
|
||||
dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
try:
|
||||
canvas = Canvas(dsl, cvs.user_id, agent_id)
|
||||
except Exception as e:
|
||||
resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e))
|
||||
resp.status_code = RetCode.BAD_REQUEST
|
||||
return resp
|
||||
|
||||
# 7. Parse request body
|
||||
async def parse_webhook_request(content_type):
|
||||
"""Parse request based on content-type and return structured data."""
|
||||
|
||||
# 1. Query
|
||||
query_data = {k: v for k, v in request.args.items()}
|
||||
|
||||
# 2. Headers
|
||||
header_data = {k: v for k, v in request.headers.items()}
|
||||
|
||||
# 3. Body
|
||||
ctype = request.headers.get("Content-Type", "").split(";")[0].strip()
|
||||
if ctype and ctype != content_type:
|
||||
raise ValueError(
|
||||
f"Invalid Content-Type: expect '{content_type}', got '{ctype}'"
|
||||
)
|
||||
|
||||
body_data: dict = {}
|
||||
|
||||
async def sse():
|
||||
nonlocal canvas
|
||||
try:
|
||||
async for ans in canvas.run(query=req.get("query", ""), files=req.get("files", []), user_id=req.get("user_id", tenant_id), webhook_payload=req):
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
if ctype == "application/json":
|
||||
body_data = await request.get_json() or {}
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(req["id"], cvs.to_dict())
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
|
||||
elif ctype == "multipart/form-data":
|
||||
nonlocal canvas
|
||||
form = await request.form
|
||||
files = await request.files
|
||||
|
||||
resp = Response(sse(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
body_data = {}
|
||||
|
||||
for key, value in form.items():
|
||||
body_data[key] = value
|
||||
|
||||
if len(files) > 10:
|
||||
raise Exception("Too many uploaded files")
|
||||
for key, file in files.items():
|
||||
desc = FileService.upload_info(
|
||||
cvs.user_id, # user
|
||||
file, # FileStorage
|
||||
None # url (None for webhook)
|
||||
)
|
||||
file_parsed= await canvas.get_files_async([desc])
|
||||
body_data[key] = file_parsed
|
||||
|
||||
elif ctype == "application/x-www-form-urlencoded":
|
||||
form = await request.form
|
||||
body_data = dict(form)
|
||||
|
||||
else:
|
||||
# text/plain / octet-stream / empty / unknown
|
||||
raw = await request.get_data()
|
||||
if raw:
|
||||
try:
|
||||
body_data = json.loads(raw.decode("utf-8"))
|
||||
except Exception:
|
||||
body_data = {}
|
||||
else:
|
||||
body_data = {}
|
||||
|
||||
except Exception:
|
||||
body_data = {}
|
||||
|
||||
return {
|
||||
"query": query_data,
|
||||
"headers": header_data,
|
||||
"body": body_data,
|
||||
"content_type": ctype,
|
||||
}
|
||||
|
||||
def extract_by_schema(data, schema, name="section"):
|
||||
"""
|
||||
Extract only fields defined in schema.
|
||||
Required fields must exist.
|
||||
Optional fields default to type-based default values.
|
||||
Type validation included.
|
||||
"""
|
||||
props = schema.get("properties", {})
|
||||
required = schema.get("required", [])
|
||||
|
||||
extracted = {}
|
||||
|
||||
for field, field_schema in props.items():
|
||||
field_type = field_schema.get("type")
|
||||
|
||||
# 1. Required field missing
|
||||
if field in required and field not in data:
|
||||
raise Exception(f"{name} missing required field: {field}")
|
||||
|
||||
# 2. Optional → default value
|
||||
if field not in data:
|
||||
extracted[field] = default_for_type(field_type)
|
||||
continue
|
||||
|
||||
raw_value = data[field]
|
||||
|
||||
# 3. Auto convert value
|
||||
try:
|
||||
value = auto_cast_value(raw_value, field_type)
|
||||
except Exception as e:
|
||||
raise Exception(f"{name}.{field} auto-cast failed: {str(e)}")
|
||||
|
||||
# 4. Type validation
|
||||
if not validate_type(value, field_type):
|
||||
raise Exception(
|
||||
f"{name}.{field} type mismatch: expected {field_type}, got {type(value).__name__}"
|
||||
)
|
||||
|
||||
extracted[field] = value
|
||||
|
||||
return extracted
|
||||
|
||||
|
||||
def default_for_type(t):
|
||||
"""Return default value for the given schema type."""
|
||||
if t == "file":
|
||||
return []
|
||||
if t == "object":
|
||||
return {}
|
||||
if t == "boolean":
|
||||
return False
|
||||
if t == "number":
|
||||
return 0
|
||||
if t == "string":
|
||||
return ""
|
||||
if t and t.startswith("array"):
|
||||
return []
|
||||
if t == "null":
|
||||
return None
|
||||
return None
|
||||
|
||||
def auto_cast_value(value, expected_type):
|
||||
"""Convert string values into schema type when possible."""
|
||||
|
||||
# Non-string values already good
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
v = value.strip()
|
||||
|
||||
# Boolean
|
||||
if expected_type == "boolean":
|
||||
if v.lower() in ["true", "1"]:
|
||||
return True
|
||||
if v.lower() in ["false", "0"]:
|
||||
return False
|
||||
raise Exception(f"Cannot convert '{value}' to boolean")
|
||||
|
||||
# Number
|
||||
if expected_type == "number":
|
||||
# integer
|
||||
if v.isdigit() or (v.startswith("-") and v[1:].isdigit()):
|
||||
return int(v)
|
||||
|
||||
# float
|
||||
try:
|
||||
return float(v)
|
||||
except Exception:
|
||||
raise Exception(f"Cannot convert '{value}' to number")
|
||||
|
||||
# Object
|
||||
if expected_type == "object":
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
else:
|
||||
raise Exception("JSON is not an object")
|
||||
except Exception:
|
||||
raise Exception(f"Cannot convert '{value}' to object")
|
||||
|
||||
# Array <T>
|
||||
if expected_type.startswith("array"):
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
if isinstance(parsed, list):
|
||||
return parsed
|
||||
else:
|
||||
raise Exception("JSON is not an array")
|
||||
except Exception:
|
||||
raise Exception(f"Cannot convert '{value}' to array")
|
||||
|
||||
# String (accept original)
|
||||
if expected_type == "string":
|
||||
return value
|
||||
|
||||
# File
|
||||
if expected_type == "file":
|
||||
return value
|
||||
# Default: do nothing
|
||||
return value
|
||||
|
||||
|
||||
def validate_type(value, t):
|
||||
"""Validate value type against schema type t."""
|
||||
if t == "file":
|
||||
return isinstance(value, list)
|
||||
|
||||
if t == "string":
|
||||
return isinstance(value, str)
|
||||
|
||||
if t == "number":
|
||||
return isinstance(value, (int, float))
|
||||
|
||||
if t == "boolean":
|
||||
return isinstance(value, bool)
|
||||
|
||||
if t == "object":
|
||||
return isinstance(value, dict)
|
||||
|
||||
# array<string> / array<number> / array<object>
|
||||
if t.startswith("array"):
|
||||
if not isinstance(value, list):
|
||||
return False
|
||||
|
||||
if "<" in t and ">" in t:
|
||||
inner = t[t.find("<") + 1 : t.find(">")]
|
||||
|
||||
# Check each element type
|
||||
for item in value:
|
||||
if not validate_type(item, inner):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
return True
|
||||
parsed = await parse_webhook_request(webhook_cfg.get("content_types"))
|
||||
SCHEMA = webhook_cfg.get("schema", {"query": {}, "headers": {}, "body": {}})
|
||||
|
||||
# Extract strictly by schema
|
||||
try:
|
||||
query_clean = extract_by_schema(parsed["query"], SCHEMA.get("query", {}), name="query")
|
||||
header_clean = extract_by_schema(parsed["headers"], SCHEMA.get("headers", {}), name="headers")
|
||||
body_clean = extract_by_schema(parsed["body"], SCHEMA.get("body", {}), name="body")
|
||||
except Exception as e:
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)),RetCode.BAD_REQUEST
|
||||
|
||||
clean_request = {
|
||||
"query": query_clean,
|
||||
"headers": header_clean,
|
||||
"body": body_clean,
|
||||
"input": parsed
|
||||
}
|
||||
|
||||
execution_mode = webhook_cfg.get("execution_mode", "Immediately")
|
||||
response_cfg = webhook_cfg.get("response", {})
|
||||
|
||||
def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600):
|
||||
key = f"webhook-trace-{agent_id}-logs"
|
||||
|
||||
raw = REDIS_CONN.get(key)
|
||||
obj = json.loads(raw) if raw else {"webhooks": {}}
|
||||
|
||||
ws = obj["webhooks"].setdefault(
|
||||
str(start_ts),
|
||||
{"start_ts": start_ts, "events": []}
|
||||
)
|
||||
|
||||
ws["events"].append({
|
||||
"ts": time.time(),
|
||||
**event
|
||||
})
|
||||
|
||||
REDIS_CONN.set_obj(key, obj, ttl)
|
||||
|
||||
if execution_mode == "Immediately":
|
||||
status = response_cfg.get("status", 200)
|
||||
try:
|
||||
status = int(status)
|
||||
except (TypeError, ValueError):
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}")),RetCode.BAD_REQUEST
|
||||
|
||||
if not (200 <= status <= 399):
|
||||
return get_data_error_result(code=RetCode.BAD_REQUEST,message=str(f"Invalid response status code: {status}, must be between 200 and 399")),RetCode.BAD_REQUEST
|
||||
|
||||
body_tpl = response_cfg.get("body_template", "")
|
||||
|
||||
def parse_body(body: str):
|
||||
if not body:
|
||||
return None, "application/json"
|
||||
|
||||
try:
|
||||
parsed = json.loads(body)
|
||||
return parsed, "application/json"
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return body, "text/plain"
|
||||
|
||||
|
||||
body, content_type = parse_body(body_tpl)
|
||||
resp = Response(
|
||||
json.dumps(body, ensure_ascii=False) if content_type == "application/json" else body,
|
||||
status=status,
|
||||
content_type=content_type,
|
||||
)
|
||||
|
||||
async def background_run():
|
||||
try:
|
||||
async for ans in canvas.run(
|
||||
query="",
|
||||
user_id=cvs.user_id,
|
||||
webhook_payload=clean_request
|
||||
):
|
||||
if is_test:
|
||||
append_webhook_trace(agent_id, start_ts, ans)
|
||||
|
||||
if is_test:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": True,
|
||||
}
|
||||
)
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(cvs.user_id, cvs.to_dict())
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("Webhook background run failed")
|
||||
if is_test:
|
||||
try:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "error",
|
||||
"message": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
)
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": False,
|
||||
}
|
||||
)
|
||||
except Exception:
|
||||
logging.exception("Failed to append webhook trace")
|
||||
|
||||
asyncio.create_task(background_run())
|
||||
return resp
|
||||
else:
|
||||
async def sse():
|
||||
nonlocal canvas
|
||||
contents: list[str] = []
|
||||
status = 200
|
||||
try:
|
||||
async for ans in canvas.run(
|
||||
query="",
|
||||
user_id=cvs.user_id,
|
||||
webhook_payload=clean_request,
|
||||
):
|
||||
if ans["event"] == "message":
|
||||
content = ans["data"]["content"]
|
||||
if ans["data"].get("start_to_think", False):
|
||||
content = "<think>"
|
||||
elif ans["data"].get("end_to_think", False):
|
||||
content = "</think>"
|
||||
if content:
|
||||
contents.append(content)
|
||||
if ans["event"] == "message_end":
|
||||
status = int(ans["data"].get("status", status))
|
||||
if is_test:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
ans
|
||||
)
|
||||
if is_test:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": True,
|
||||
}
|
||||
)
|
||||
final_content = "".join(contents)
|
||||
return {
|
||||
"message": final_content,
|
||||
"success": True,
|
||||
"code": status,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
if is_test:
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "error",
|
||||
"message": str(e),
|
||||
"error_type": type(e).__name__,
|
||||
}
|
||||
)
|
||||
append_webhook_trace(
|
||||
agent_id,
|
||||
start_ts,
|
||||
{
|
||||
"event": "finished",
|
||||
"elapsed_time": time.time() - start_ts,
|
||||
"success": False,
|
||||
}
|
||||
)
|
||||
return {"code": 400, "message": str(e),"success":False}
|
||||
|
||||
result = await sse()
|
||||
return Response(
|
||||
json.dumps(result),
|
||||
status=result["code"],
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
|
||||
@manager.route("/webhook_trace/<agent_id>", methods=["GET"]) # noqa: F821
|
||||
async def webhook_trace(agent_id: str):
|
||||
def encode_webhook_id(start_ts: str) -> str:
|
||||
WEBHOOK_ID_SECRET = "webhook_id_secret"
|
||||
sig = hmac.new(
|
||||
WEBHOOK_ID_SECRET.encode("utf-8"),
|
||||
start_ts.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
return base64.urlsafe_b64encode(sig).decode("utf-8").rstrip("=")
|
||||
|
||||
def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None:
|
||||
for ts in webhooks.keys():
|
||||
if encode_webhook_id(ts) == enc_id:
|
||||
return ts
|
||||
return None
|
||||
since_ts = request.args.get("since_ts", type=float)
|
||||
webhook_id = request.args.get("webhook_id")
|
||||
|
||||
key = f"webhook-trace-{agent_id}-logs"
|
||||
raw = REDIS_CONN.get(key)
|
||||
|
||||
if since_ts is None:
|
||||
now = time.time()
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": None,
|
||||
"events": [],
|
||||
"next_since_ts": now,
|
||||
"finished": False,
|
||||
}
|
||||
)
|
||||
|
||||
if not raw:
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": None,
|
||||
"events": [],
|
||||
"next_since_ts": since_ts,
|
||||
"finished": False,
|
||||
}
|
||||
)
|
||||
|
||||
obj = json.loads(raw)
|
||||
webhooks = obj.get("webhooks", {})
|
||||
|
||||
if webhook_id is None:
|
||||
candidates = [
|
||||
float(k) for k in webhooks.keys() if float(k) > since_ts
|
||||
]
|
||||
|
||||
if not candidates:
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": None,
|
||||
"events": [],
|
||||
"next_since_ts": since_ts,
|
||||
"finished": False,
|
||||
}
|
||||
)
|
||||
|
||||
start_ts = min(candidates)
|
||||
real_id = str(start_ts)
|
||||
webhook_id = encode_webhook_id(real_id)
|
||||
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": webhook_id,
|
||||
"events": [],
|
||||
"next_since_ts": start_ts,
|
||||
"finished": False,
|
||||
}
|
||||
)
|
||||
|
||||
real_id = decode_webhook_id(webhook_id, webhooks)
|
||||
|
||||
if not real_id:
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": webhook_id,
|
||||
"events": [],
|
||||
"next_since_ts": since_ts,
|
||||
"finished": True,
|
||||
}
|
||||
)
|
||||
|
||||
ws = webhooks.get(str(real_id))
|
||||
events = ws.get("events", [])
|
||||
new_events = [e for e in events if e.get("ts", 0) > since_ts]
|
||||
|
||||
next_ts = since_ts
|
||||
for e in new_events:
|
||||
next_ts = max(next_ts, e["ts"])
|
||||
|
||||
finished = any(e.get("event") == "finished" for e in new_events)
|
||||
|
||||
return get_json_result(
|
||||
data={
|
||||
"webhook_id": webhook_id,
|
||||
"events": new_events,
|
||||
"next_since_ts": next_ts,
|
||||
"finished": finished,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -287,7 +287,7 @@ def list_chat(tenant_id):
    chats = DialogService.get_list(tenant_id, page_number, items_per_page, orderby, desc, id, name)
    if not chats:
        return get_result(data=[])
    list_assts = []
    list_assistants = []
    key_mapping = {
        "parameters": "variables",
        "prologue": "opener",
@@ -321,5 +321,5 @@ def list_chat(tenant_id):
        del res["kb_ids"]
        res["datasets"] = kb_list
        res["avatar"] = res.pop("icon")
        list_assts.append(res)
    return get_result(data=list_assts)
        list_assistants.append(res)
    return get_result(data=list_assistants)

@@ -495,7 +495,7 @@ def knowledge_graph(tenant_id, dataset_id):
    }

    obj = {"graph": {}, "mind_map": {}}
    if not settings.docStoreConn.indexExist(search.index_name(kb.tenant_id), dataset_id):
    if not settings.docStoreConn.index_exist(search.index_name(kb.tenant_id), dataset_id):
        return get_result(data=obj)
    sres = settings.retriever.search(req, search.index_name(kb.tenant_id), [dataset_id])
    if not len(sres.ids):

@@ -1080,7 +1080,7 @@ def list_chunks(tenant_id, dataset_id, document_id):
            res["chunks"].append(final_chunk)
            _ = Chunk(**final_chunk)

    elif settings.docStoreConn.indexExist(search.index_name(tenant_id), dataset_id):
    elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id):
        sres = settings.retriever.search(query, search.index_name(tenant_id), [dataset_id], emb_mdl=None, highlight=True)
        res["total"] = sres.total
        for id in sres.ids:
@ -205,7 +205,8 @@ async def create(tenant_id):
|
||||
if not FileService.is_parent_folder_exist(pf_id):
|
||||
return get_json_result(data=False, message="Parent Folder Doesn't Exist!", code=RetCode.BAD_REQUEST)
|
||||
if FileService.query(name=req["name"], parent_id=pf_id):
|
||||
return get_json_result(data=False, message="Duplicated folder name in the same folder.", code=409)
|
||||
return get_json_result(data=False, message="Duplicated folder name in the same folder.",
|
||||
code=RetCode.CONFLICT)
|
||||
|
||||
if input_file_type == FileType.FOLDER.value:
|
||||
file_type = FileType.FOLDER.value
|
||||
@ -565,11 +566,13 @@ async def rename(tenant_id):
|
||||
|
||||
if file.type != FileType.FOLDER.value and pathlib.Path(req["name"].lower()).suffix != pathlib.Path(
|
||||
file.name.lower()).suffix:
|
||||
return get_json_result(data=False, message="The extension of file can't be changed", code=RetCode.BAD_REQUEST)
|
||||
return get_json_result(data=False, message="The extension of file can't be changed",
|
||||
code=RetCode.BAD_REQUEST)
|
||||
|
||||
for existing_file in FileService.query(name=req["name"], pf_id=file.parent_id):
|
||||
if existing_file.name == req["name"]:
|
||||
return get_json_result(data=False, message="Duplicated file name in the same folder.", code=409)
|
||||
return get_json_result(data=False, message="Duplicated file name in the same folder.",
|
||||
code=RetCode.CONFLICT)
|
||||
|
||||
if not FileService.update_by_id(req["file_id"], {"name": req["name"]}):
|
||||
return get_json_result(message="Database error (File rename)!", code=RetCode.SERVER_ERROR)
|
||||
@ -631,9 +634,10 @@ async def get(tenant_id, file_id):
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/file/download/<attachment_id>", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
async def download_attachment(tenant_id,attachment_id):
|
||||
async def download_attachment(tenant_id, attachment_id):
|
||||
try:
|
||||
ext = request.args.get("ext", "markdown")
|
||||
data = await asyncio.to_thread(settings.STORAGE_IMPL.get, tenant_id, attachment_id)
|
||||
@ -645,6 +649,7 @@ async def download_attachment(tenant_id,attachment_id):
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/file/mv', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
async def move(tenant_id):
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import copy
|
||||
import re
|
||||
import time
|
||||
|
||||
@ -32,7 +33,7 @@ from api.db.services.dialog_service import DialogService, async_ask, async_chat,
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from common.metadata_utils import apply_meta_data_filter
|
||||
from common.metadata_utils import apply_meta_data_filter, convert_conditions, meta_filter
|
||||
from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from common.misc_utils import get_uuid
|
||||
@ -128,11 +129,33 @@ async def chat_completion(tenant_id, chat_id):
|
||||
req = {"question": ""}
|
||||
if not req.get("session_id"):
|
||||
req["question"] = ""
|
||||
if not DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value):
|
||||
dia = DialogService.query(tenant_id=tenant_id, id=chat_id, status=StatusEnum.VALID.value)
|
||||
if not dia:
|
||||
return get_error_data_result(f"You don't own the chat {chat_id}")
|
||||
dia = dia[0]
|
||||
if req.get("session_id"):
|
||||
if not ConversationService.query(id=req["session_id"], dialog_id=chat_id):
|
||||
return get_error_data_result(f"You don't own the session {req['session_id']}")
|
||||
|
||||
metadata_condition = req.get("metadata_condition") or {}
|
||||
if metadata_condition and not isinstance(metadata_condition, dict):
|
||||
return get_error_data_result(message="metadata_condition must be an object.")
|
||||
|
||||
if metadata_condition and req.get("question"):
|
||||
metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
|
||||
filtered_doc_ids = meta_filter(
|
||||
metas,
|
||||
convert_conditions(metadata_condition),
|
||||
metadata_condition.get("logic", "and"),
|
||||
)
|
||||
if metadata_condition.get("conditions") and not filtered_doc_ids:
|
||||
filtered_doc_ids = ["-999"]
|
||||
|
||||
if filtered_doc_ids:
|
||||
req["doc_ids"] = ",".join(filtered_doc_ids)
|
||||
else:
|
||||
req.pop("doc_ids", None)
|
||||
|
||||
if req.get("stream", True):
|
||||
resp = Response(rag_completion(tenant_id, chat_id, **req), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
@ -195,7 +218,19 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
{"role": "user", "content": "Can you tell me how to install neovim"},
|
||||
],
|
||||
stream=stream,
|
||||
extra_body={"reference": reference}
|
||||
extra_body={
|
||||
"reference": reference,
|
||||
"metadata_condition": {
|
||||
"logic": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"name": "author",
|
||||
"comparison_operator": "is",
|
||||
"value": "bob"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if stream:
|
||||
@ -211,7 +246,11 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
"""
|
||||
req = await get_request_json()
|
||||
|
||||
need_reference = bool(req.get("reference", False))
|
||||
extra_body = req.get("extra_body") or {}
|
||||
if extra_body and not isinstance(extra_body, dict):
|
||||
return get_error_data_result("extra_body must be an object.")
|
||||
|
||||
need_reference = bool(extra_body.get("reference", False))
|
||||
|
||||
messages = req.get("messages", [])
|
||||
# To prevent empty [] input
|
||||
@ -229,6 +268,22 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
return get_error_data_result(f"You don't own the chat {chat_id}")
|
||||
dia = dia[0]
|
||||
|
||||
metadata_condition = extra_body.get("metadata_condition") or {}
|
||||
if metadata_condition and not isinstance(metadata_condition, dict):
|
||||
return get_error_data_result(message="metadata_condition must be an object.")
|
||||
|
||||
doc_ids_str = None
|
||||
if metadata_condition:
|
||||
metas = DocumentService.get_meta_by_kbs(dia.kb_ids or [])
|
||||
filtered_doc_ids = meta_filter(
|
||||
metas,
|
||||
convert_conditions(metadata_condition),
|
||||
metadata_condition.get("logic", "and"),
|
||||
)
|
||||
if metadata_condition.get("conditions") and not filtered_doc_ids:
|
||||
filtered_doc_ids = ["-999"]
|
||||
doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None
|
||||
|
||||
# Filter system and non-sense assistant messages
|
||||
msg = []
|
||||
for m in messages:
|
||||
@ -276,14 +331,17 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
}
|
||||
|
||||
try:
|
||||
async for ans in async_chat(dia, msg, True, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
||||
chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
|
||||
if doc_ids_str:
|
||||
chat_kwargs["doc_ids"] = doc_ids_str
|
||||
async for ans in async_chat(dia, msg, True, **chat_kwargs):
|
||||
last_ans = ans
|
||||
answer = ans["answer"]
|
||||
|
||||
reasoning_match = re.search(r"<think>(.*?)</think>", answer, flags=re.DOTALL)
|
||||
if reasoning_match:
|
||||
reasoning_part = reasoning_match.group(1)
|
||||
content_part = answer[reasoning_match.end():]
|
||||
content_part = answer[reasoning_match.end() :]
|
||||
else:
|
||||
reasoning_part = ""
|
||||
content_part = answer
|
||||
@ -328,8 +386,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
response["choices"][0]["delta"]["content"] = None
|
||||
response["choices"][0]["delta"]["reasoning_content"] = None
|
||||
response["choices"][0]["finish_reason"] = "stop"
|
||||
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used,
|
||||
"total_tokens": len(prompt) + token_used}
|
||||
response["usage"] = {"prompt_tokens": len(prompt), "completion_tokens": token_used, "total_tokens": len(prompt) + token_used}
|
||||
if need_reference:
|
||||
response["choices"][0]["delta"]["reference"] = chunks_format(last_ans.get("reference", []))
|
||||
response["choices"][0]["delta"]["final_content"] = last_ans.get("answer", "")
|
||||
@ -344,7 +401,10 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
return resp
|
||||
else:
|
||||
answer = None
|
||||
async for ans in async_chat(dia, msg, False, toolcall_session=toolcall_session, tools=tools, quote=need_reference):
|
||||
chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
|
||||
if doc_ids_str:
|
||||
chat_kwargs["doc_ids"] = doc_ids_str
|
||||
async for ans in async_chat(dia, msg, False, **chat_kwargs):
|
||||
# focus answer content only
|
||||
answer = ans
|
||||
break
|
||||
@ -388,7 +448,7 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
@token_required
|
||||
async def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
req = await get_request_json()
|
||||
tiktokenenc = tiktoken.get_encoding("cl100k_base")
|
||||
tiktoken_encode = tiktoken.get_encoding("cl100k_base")
|
||||
messages = req.get("messages", [])
|
||||
if not messages:
|
||||
return get_error_data_result("You must provide at least one message.")
|
||||
@ -396,7 +456,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||
|
||||
filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]]
|
||||
prompt_tokens = sum(len(tiktokenenc.encode(m["content"])) for m in filtered_messages)
|
||||
prompt_tokens = sum(len(tiktoken_encode.encode(m["content"])) for m in filtered_messages)
|
||||
if not filtered_messages:
|
||||
return jsonify(
|
||||
get_data_openai(
|
||||
@ -404,7 +464,7 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
content="No valid messages found (user or assistant).",
|
||||
finish_reason="stop",
|
||||
model=req.get("model", ""),
|
||||
completion_tokens=len(tiktokenenc.encode("No valid messages found (user or assistant).")),
|
||||
completion_tokens=len(tiktoken_encode.encode("No valid messages found (user or assistant).")),
|
||||
prompt_tokens=prompt_tokens,
|
||||
)
|
||||
)
|
||||
@ -441,15 +501,19 @@ async def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
):
|
||||
return jsonify(response)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def agent_completions(tenant_id, agent_id):
|
||||
req = await get_request_json()
|
||||
return_trace = bool(req.get("return_trace", False))
|
||||
|
||||
if req.get("stream", True):
|
||||
|
||||
async def generate():
|
||||
trace_items = []
|
||||
async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
if isinstance(answer, str):
|
||||
try:
|
||||
@ -457,7 +521,21 @@ async def agent_completions(tenant_id, agent_id):
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if ans.get("event") not in ["message", "message_end"]:
|
||||
event = ans.get("event")
|
||||
if event == "node_finished":
|
||||
if return_trace:
|
||||
data = ans.get("data", {})
|
||||
trace_items.append(
|
||||
{
|
||||
"component_id": data.get("component_id"),
|
||||
"trace": [copy.deepcopy(data)],
|
||||
}
|
||||
)
|
||||
ans.setdefault("data", {})["trace"] = trace_items
|
||||
answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
yield answer
|
||||
|
||||
if event not in ["message", "message_end"]:
|
||||
continue
|
||||
|
||||
yield answer
|
||||
@ -474,6 +552,7 @@ async def agent_completions(tenant_id, agent_id):
|
||||
full_content = ""
|
||||
reference = {}
|
||||
final_ans = ""
|
||||
trace_items = []
|
||||
async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
try:
|
||||
ans = json.loads(answer[5:])
|
||||
@ -484,11 +563,22 @@ async def agent_completions(tenant_id, agent_id):
|
||||
if ans.get("data", {}).get("reference", None):
|
||||
reference.update(ans["data"]["reference"])
|
||||
|
||||
if return_trace and ans.get("event") == "node_finished":
|
||||
data = ans.get("data", {})
|
||||
trace_items.append(
|
||||
{
|
||||
"component_id": data.get("component_id"),
|
||||
"trace": [copy.deepcopy(data)],
|
||||
}
|
||||
)
|
||||
|
||||
final_ans = ans
|
||||
except Exception as e:
|
||||
return get_result(data=f"**ERROR**: {str(e)}")
|
||||
final_ans["data"]["content"] = full_content
|
||||
final_ans["data"]["reference"] = reference
|
||||
if return_trace and final_ans:
|
||||
final_ans["data"]["trace"] = trace_items
|
||||
return get_result(data=final_ans)
|
||||
|
||||
|
||||
@ -832,6 +922,7 @@ async def chatbot_completions(dialog_id):
|
||||
async for answer in iframe_completion(dialog_id, **req):
|
||||
return get_result(data=answer)
|
||||
|
||||
return None
|
||||
|
||||
@manager.route("/chatbots/<dialog_id>/info", methods=["GET"]) # noqa: F821
|
||||
async def chatbots_inputs(dialog_id):
|
||||
@ -879,6 +970,7 @@ async def agent_bot_completions(agent_id):
|
||||
async for answer in agent_completion(objs[0].tenant_id, agent_id, **req):
|
||||
return get_result(data=answer)
|
||||
|
||||
return None
|
||||
|
||||
@manager.route("/agentbots/<agent_id>/inputs", methods=["GET"]) # noqa: F821
|
||||
async def begin_inputs(agent_id):
|
||||
|
||||
@ -660,7 +660,7 @@ def user_register(user_id, user):
|
||||
tenant_llm = get_init_tenant_llm(user_id)
|
||||
|
||||
if not UserService.save(**user):
|
||||
return
|
||||
return None
|
||||
TenantService.insert(**tenant)
|
||||
UserTenantService.insert(**usr_tenant)
|
||||
TenantLLMService.insert_many(tenant_llm)
|
||||
|
||||
@ -30,6 +30,7 @@ from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.tenant_llm_service import LLMFactoriesService, TenantLLMService
|
||||
from api.db.services.llm_service import LLMService, LLMBundle, get_init_tenant_llm
|
||||
from api.db.services.user_service import TenantService, UserTenantService
|
||||
from api.db.joint_services.memory_message_service import init_message_id_sequence, init_memory_size_cache
|
||||
from common.constants import LLMType
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common import settings
|
||||
@ -169,6 +170,8 @@ def init_web_data():
|
||||
# init_superuser()
|
||||
|
||||
add_graph_templates()
|
||||
init_message_id_sequence()
|
||||
init_memory_size_cache()
|
||||
logging.info("init web data success:{}".format(time.time() - start_time))
|
||||
|
||||
|
||||
|
||||
233
api/db/joint_services/memory_message_service.py
Normal file
233
api/db/joint_services/memory_message_service.py
Normal file
@ -0,0 +1,233 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from common.time_utils import current_timestamp, timestamp_to_date, format_iso_8601_to_ymd_hms
|
||||
from common.constants import MemoryType, LLMType
|
||||
from common.doc_store.doc_store_base import FusionExpr
|
||||
from api.db.services.memory_service import MemoryService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.utils.memory_utils import get_memory_type_human
|
||||
from memory.services.messages import MessageService
|
||||
from memory.services.query import MsgTextQuery, get_vector
|
||||
from memory.utils.prompt_util import PromptAssembler
|
||||
from memory.utils.msg_util import get_json_result_from_llm_response
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
async def save_to_memory(memory_id: str, message_dict: dict):
    """
    :param memory_id:
    :param message_dict: {
        "user_id": str,
        "agent_id": str,
        "session_id": str,
        "user_input": str,
        "agent_response": str
    }
    """
    memory = MemoryService.get_by_memory_id(memory_id)
    if not memory:
        return False, f"Memory '{memory_id}' not found."

    tenant_id = memory.tenant_id
    extracted_content = await extract_by_llm(
        tenant_id,
        memory.llm_id,
        {"temperature": memory.temperature},
        get_memory_type_human(memory.memory_type),
        message_dict.get("user_input", ""),
        message_dict.get("agent_response", "")
    ) if memory.memory_type != MemoryType.RAW.value else []  # if only RAW, no need to extract
    raw_message_id = REDIS_CONN.generate_auto_increment_id(namespace="memory")
    message_list = [{
        "message_id": raw_message_id,
        "message_type": MemoryType.RAW.name.lower(),
        "source_id": 0,
        "memory_id": memory_id,
        "user_id": "",
        "agent_id": message_dict["agent_id"],
        "session_id": message_dict["session_id"],
        "content": f"User Input: {message_dict.get('user_input')}\nAgent Response: {message_dict.get('agent_response')}",
        "valid_at": timestamp_to_date(current_timestamp()),
        "invalid_at": None,
        "forget_at": None,
        "status": True
    }, *[{
        "message_id": REDIS_CONN.generate_auto_increment_id(namespace="memory"),
        "message_type": content["message_type"],
        "source_id": raw_message_id,
        "memory_id": memory_id,
        "user_id": "",
        "agent_id": message_dict["agent_id"],
        "session_id": message_dict["session_id"],
        "content": content["content"],
        "valid_at": content["valid_at"],
        "invalid_at": content["invalid_at"] if content["invalid_at"] else None,
        "forget_at": None,
        "status": True
    } for content in extracted_content]]
    embedding_model = LLMBundle(tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
    vector_list, _ = embedding_model.encode([msg["content"] for msg in message_list])
    for idx, msg in enumerate(message_list):
        msg["content_embed"] = vector_list[idx]
    vector_dimension = len(vector_list[0])
    if not MessageService.has_index(tenant_id, memory_id):
        created = MessageService.create_index(tenant_id, memory_id, vector_size=vector_dimension)
        if not created:
            return False, "Failed to create message index."

    new_msg_size = sum([MessageService.calculate_message_size(m) for m in message_list])
    current_memory_size = get_memory_size_cache(memory_id, tenant_id)
    if new_msg_size + current_memory_size > memory.memory_size:
        size_to_delete = current_memory_size + new_msg_size - memory.memory_size
        if memory.forgetting_policy == "fifo":
            message_ids_to_delete, delete_size = MessageService.pick_messages_to_delete_by_fifo(memory_id, tenant_id, size_to_delete)
            MessageService.delete_message({"message_id": message_ids_to_delete}, tenant_id, memory_id)
            decrease_memory_size_cache(memory_id, tenant_id, delete_size)
        else:
            return False, "Failed to insert message into memory. Memory size reached limit and cannot decide which to delete."
    fail_cases = MessageService.insert_message(message_list, tenant_id, memory_id)
    if fail_cases:
        return False, "Failed to insert message into memory. Details: " + "; ".join(fail_cases)

    increase_memory_size_cache(memory_id, tenant_id, new_msg_size)
    return True, "Message saved successfully."
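
A minimal caller-side sketch of the flow above (the ids and the event-loop wiring are illustrative, not part of this PR):

import asyncio

async def persist_turn(memory_id: str, user_input: str, agent_response: str):
    # Hypothetical hook invoked after one user/agent exchange.
    ok, msg = await save_to_memory(memory_id, {
        "user_id": "",                    # currently unused by save_to_memory
        "agent_id": "agent-123",          # placeholder id
        "session_id": "session-456",      # placeholder id
        "user_input": user_input,
        "agent_response": agent_response,
    })
    if not ok:
        raise RuntimeError(f"memory write failed: {msg}")

# asyncio.run(persist_turn("mem-id", "What is RAGFlow?", "RAGFlow is an open-source RAG engine."))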


async def extract_by_llm(tenant_id: str, llm_id: str, extract_conf: dict, memory_type: List[str], user_input: str,
                         agent_response: str, system_prompt: str = "", user_prompt: str = "") -> List[dict]:
    llm_type = TenantLLMService.llm_id2llm_type(llm_id)
    if not llm_type:
        raise RuntimeError(f"Unknown type of LLM '{llm_id}'")
    if not system_prompt:
        system_prompt = PromptAssembler.assemble_system_prompt({"memory_type": memory_type})
    conversation_content = f"User Input: {user_input}\nAgent Response: {agent_response}"
    conversation_time = timestamp_to_date(current_timestamp())
    user_prompts = []
    if user_prompt:
        user_prompts.append({"role": "user", "content": user_prompt})
        user_prompts.append({"role": "user", "content": f"Conversation: {conversation_content}\nConversation Time: {conversation_time}\nCurrent Time: {conversation_time}"})
    else:
        user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)})
    llm = LLMBundle(tenant_id, llm_type, llm_id)
    res = await llm.async_chat(system_prompt, user_prompts, extract_conf)
    res_json = get_json_result_from_llm_response(res)
    return [{
        "content": extracted_content["content"],
        "valid_at": format_iso_8601_to_ymd_hms(extracted_content["valid_at"]),
        "invalid_at": format_iso_8601_to_ymd_hms(extracted_content["invalid_at"]) if extracted_content.get("invalid_at") else "",
        "message_type": message_type
    } for message_type, extracted_content_list in res_json.items() for extracted_content in extracted_content_list]
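
For orientation, the flattening above assumes the extraction LLM answers with a JSON object keyed by message type; a plausible shape, with purely illustrative keys and values, is:

res_json = {
    "semantic": [
        {"content": "User prefers concise answers", "valid_at": "2025-01-01T08:00:00Z", "invalid_at": ""},
    ],
    "episodic": [
        {"content": "User asked about pricing", "valid_at": "2025-01-01T08:00:00Z", "invalid_at": ""},
    ],
}
# Each (message_type, item) pair becomes one dict in the returned list.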


def query_message(filter_dict: dict, params: dict):
    """
    :param filter_dict: {
        "memory_id": List[str],
        "agent_id": optional
        "session_id": optional
    }
    :param params: {
        "query": question str,
        "similarity_threshold": float,
        "keywords_similarity_weight": float,
        "top_n": int
    }
    """
    memory_ids = filter_dict["memory_id"]
    memory_list = MemoryService.get_by_ids(memory_ids)
    if not memory_list:
        return []

    condition_dict = {k: v for k, v in filter_dict.items() if v}
    uids = [memory.tenant_id for memory in memory_list]

    question = params["query"]
    question = question.strip()
    memory = memory_list[0]
    embd_model = LLMBundle(memory.tenant_id, llm_type=LLMType.EMBEDDING, llm_name=memory.embd_id)
    match_dense = get_vector(question, embd_model, similarity=params["similarity_threshold"])
    match_text, _ = MsgTextQuery().question(question, min_match=0.3)
    keywords_similarity_weight = params.get("keywords_similarity_weight", 0.7)
    fusion_expr = FusionExpr("weighted_sum", params["top_n"], {"weights": ",".join([str(keywords_similarity_weight), str(1 - keywords_similarity_weight)])})

    return MessageService.search_message(memory_ids, condition_dict, uids, [match_text, match_dense, fusion_expr], params["top_n"])
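
A hedged usage sketch of the hybrid retrieval above; the ids, threshold, and weights are placeholders:

hits = query_message(
    {"memory_id": ["mem-1"], "agent_id": "agent-123", "session_id": None},  # falsy filters are dropped
    {
        "query": "What did the user say about pricing?",
        "similarity_threshold": 0.2,        # passed to the dense (vector) matcher
        "keywords_similarity_weight": 0.7,  # 0.7 keyword vs 0.3 vector in the fusion step
        "top_n": 8,
    },
)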


def init_message_id_sequence():
    message_id_redis_key = "id_generator:memory"
    if REDIS_CONN.exist(message_id_redis_key):
        current_max_id = REDIS_CONN.get(message_id_redis_key)
        logging.info(f"No need to init message_id sequence, current max id is {current_max_id}.")
    else:
        max_id = 1
        exist_memory_list = MemoryService.get_all_memory()
        if not exist_memory_list:
            REDIS_CONN.set(message_id_redis_key, max_id)
        else:
            max_id = MessageService.get_max_message_id(
                uid_list=[m.tenant_id for m in exist_memory_list],
                memory_ids=[m.id for m in exist_memory_list]
            )
            REDIS_CONN.set(message_id_redis_key, max_id)
        logging.info(f"Init message_id sequence done, current max id is {max_id}.")


def get_memory_size_cache(memory_id: str, uid: str):
    redis_key = f"memory_{memory_id}"
    if REDIS_CONN.exists(redis_key):
        return REDIS_CONN.get(redis_key)
    else:
        memory_size_map = MessageService.calculate_memory_size(
            [memory_id],
            [uid]
        )
        memory_size = memory_size_map.get(memory_id, 0)
        set_memory_size_cache(memory_id, memory_size)
        return memory_size


def set_memory_size_cache(memory_id: str, size: int):
    redis_key = f"memory_{memory_id}"
    return REDIS_CONN.set(redis_key, size)


def increase_memory_size_cache(memory_id: str, uid: str, size: int):
    current_value = get_memory_size_cache(memory_id, uid)
    return set_memory_size_cache(memory_id, current_value + size)


def decrease_memory_size_cache(memory_id: str, uid: str, size: int):
    current_value = get_memory_size_cache(memory_id, uid)
    return set_memory_size_cache(memory_id, max(current_value - size, 0))


def init_memory_size_cache():
    memory_list = MemoryService.get_all_memory()
    if not memory_list:
        logging.info("No memory found, no need to init memory size.")
    else:
        memory_size_map = MessageService.calculate_memory_size(
            memory_ids=[m.id for m in memory_list],
            uid_list=[m.tenant_id for m in memory_list],
        )
        for memory in memory_list:
            memory_size = memory_size_map.get(memory.id, 0)
            set_memory_size_cache(memory.id, memory_size)
        logging.info("Memory size cache init done.")

@@ -34,6 +34,8 @@ from api.db.services.task_service import TaskService
 from api.db.services.tenant_llm_service import TenantLLMService
 from api.db.services.user_canvas_version import UserCanvasVersionService
 from api.db.services.user_service import TenantService, UserService, UserTenantService
+from api.db.services.memory_service import MemoryService
+from memory.services.messages import MessageService
 from rag.nlp import search
 from common.constants import ActiveEnum
 from common import settings

@@ -200,7 +202,16 @@ def delete_user_data(user_id: str) -> dict:
     done_msg += f"- Deleted {llm_delete_res} tenant-LLM records.\n"
     langfuse_delete_res = TenantLangfuseService.delete_ty_tenant_id(tenant_id)
     done_msg += f"- Deleted {langfuse_delete_res} langfuse records.\n"
-    # step1.3 delete own tenant
+    # step1.3 delete memory and messages
+    user_memory = MemoryService.get_by_tenant_id(tenant_id)
+    if user_memory:
+        for memory in user_memory:
+            if MessageService.has_index(tenant_id, memory.id):
+                MessageService.delete_index(tenant_id, memory.id)
+                done_msg += " Deleted memory index."
+        memory_delete_res = MemoryService.delete_by_ids([m.id for m in user_memory])
+        done_msg += f"Deleted {memory_delete_res} memory datasets."
+    # step1.4 delete own tenant
     tenant_delete_res = TenantService.delete_by_id(tenant_id)
     done_msg += f"- Deleted {tenant_delete_res} tenant.\n"
     # step2 delete user-tenant relation

@@ -123,6 +123,19 @@ class UserCanvasService(CommonService):
             logging.exception(e)
             return False, None

+    @classmethod
+    @DB.connection_context()
+    def get_basic_info_by_canvas_ids(cls, canvas_id):
+        fields = [
+            cls.model.id,
+            cls.model.avatar,
+            cls.model.user_id,
+            cls.model.title,
+            cls.model.permission,
+            cls.model.canvas_category
+        ]
+        return cls.model.select(*fields).where(cls.model.id.in_(canvas_id)).dicts()
+
     @classmethod
     @DB.connection_context()
     def get_by_tenant_ids(cls, joined_tenant_ids, user_id,

@@ -169,10 +169,12 @@ class CommonService:
         """
         if "id" not in kwargs:
             kwargs["id"] = get_uuid()
-        kwargs["create_time"] = current_timestamp()
-        kwargs["create_date"] = datetime_format(datetime.now())
-        kwargs["update_time"] = current_timestamp()
-        kwargs["update_date"] = datetime_format(datetime.now())
+        timestamp = current_timestamp()
+        cur_datetime = datetime_format(datetime.now())
+        kwargs["create_time"] = timestamp
+        kwargs["create_date"] = cur_datetime
+        kwargs["update_time"] = timestamp
+        kwargs["update_date"] = cur_datetime
         sample_obj = cls.model(**kwargs).save(force_insert=True)
         return sample_obj

@@ -207,10 +209,14 @@ class CommonService:
             data_list (list): List of dictionaries containing record data to update.
                 Each dictionary must include an 'id' field.
         """
+        timestamp = current_timestamp()
+        cur_datetime = datetime_format(datetime.now())
+        for data in data_list:
+            data["update_time"] = timestamp
+            data["update_date"] = cur_datetime
         with DB.atomic():
             for data in data_list:
-                data["update_time"] = current_timestamp()
-                data["update_date"] = datetime_format(datetime.now())
                 cls.model.update(data).where(cls.model.id == data["id"]).execute()

     @classmethod

@@ -406,7 +406,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs):
         dialog.vector_similarity_weight,
         doc_ids=attachments,
         top=dialog.top_k,
-        aggs=False,
+        aggs=True,
         rerank_mdl=rerank_mdl,
         rank_feature=label_question(" ".join(questions), kbs),
     )

@@ -769,7 +769,7 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf
         vector_similarity_weight=search_config.get("vector_similarity_weight", 0.3),
         top=search_config.get("top_k", 1024),
         doc_ids=doc_ids,
-        aggs=False,
+        aggs=True,
         rerank_mdl=rerank_mdl,
         rank_feature=label_question(question, kbs)
     )
|
||||
|
||||
@ -33,12 +33,13 @@ from api.db.db_models import DB, Document, Knowledgebase, Task, Tenant, UserTena
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from common.metadata_utils import dedupe_list
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import current_timestamp, get_format_time
|
||||
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from common.doc_store.doc_store_base import OrderByExpr
|
||||
from common import settings
|
||||
|
||||
|
||||
@ -180,6 +181,16 @@ class DocumentService(CommonService):
|
||||
"1": 2,
|
||||
"2": 2
|
||||
}
|
||||
"metadata": {
|
||||
"key1": {
|
||||
"key1_value1": 1,
|
||||
"key1_value2": 2,
|
||||
},
|
||||
"key2": {
|
||||
"key2_value1": 2,
|
||||
"key2_value2": 1,
|
||||
},
|
||||
}
|
||||
}, total
|
||||
where "1" => RUNNING, "2" => CANCEL
|
||||
"""
|
||||
@ -200,19 +211,40 @@ class DocumentService(CommonService):
|
||||
if suffix:
|
||||
query = query.where(cls.model.suffix.in_(suffix))
|
||||
|
||||
rows = query.select(cls.model.run, cls.model.suffix)
|
||||
rows = query.select(cls.model.run, cls.model.suffix, cls.model.meta_fields)
|
||||
total = rows.count()
|
||||
|
||||
suffix_counter = {}
|
||||
run_status_counter = {}
|
||||
metadata_counter = {}
|
||||
|
||||
for row in rows:
|
||||
suffix_counter[row.suffix] = suffix_counter.get(row.suffix, 0) + 1
|
||||
run_status_counter[str(row.run)] = run_status_counter.get(str(row.run), 0) + 1
|
||||
meta_fields = row.meta_fields or {}
|
||||
if isinstance(meta_fields, str):
|
||||
try:
|
||||
meta_fields = json.loads(meta_fields)
|
||||
except Exception:
|
||||
meta_fields = {}
|
||||
if not isinstance(meta_fields, dict):
|
||||
continue
|
||||
for key, value in meta_fields.items():
|
||||
values = value if isinstance(value, list) else [value]
|
||||
for vv in values:
|
||||
if vv is None:
|
||||
continue
|
||||
if isinstance(vv, str) and not vv.strip():
|
||||
continue
|
||||
sv = str(vv)
|
||||
if key not in metadata_counter:
|
||||
metadata_counter[key] = {}
|
||||
metadata_counter[key][sv] = metadata_counter[key].get(sv, 0) + 1
|
||||
|
||||
return {
|
||||
"suffix": suffix_counter,
|
||||
"run_status": run_status_counter
|
||||
"run_status": run_status_counter,
|
||||
"metadata": metadata_counter,
|
||||
}, total
|
||||
|
||||
@classmethod
|
||||
@ -314,7 +346,7 @@ class DocumentService(CommonService):
|
||||
chunks = settings.docStoreConn.search(["img_id"], [], {"doc_id": doc.id}, [], OrderByExpr(),
|
||||
page * page_size, page_size, search.index_name(tenant_id),
|
||||
[doc.kb_id])
|
||||
chunk_ids = settings.docStoreConn.get_chunk_ids(chunks)
|
||||
chunk_ids = settings.docStoreConn.get_doc_ids(chunks)
|
||||
if not chunk_ids:
|
||||
break
|
||||
all_chunk_ids.extend(chunk_ids)
|
||||
@ -665,10 +697,14 @@ class DocumentService(CommonService):
|
||||
for k,v in r.meta_fields.items():
|
||||
if k not in meta:
|
||||
meta[k] = {}
|
||||
v = str(v)
|
||||
if v not in meta[k]:
|
||||
meta[k][v] = []
|
||||
meta[k][v].append(doc_id)
|
||||
if not isinstance(v, list):
|
||||
v = [v]
|
||||
for vv in v:
|
||||
if vv not in meta[k]:
|
||||
if isinstance(vv, list) or isinstance(vv, dict):
|
||||
continue
|
||||
meta[k][vv] = []
|
||||
meta[k][vv].append(doc_id)
|
||||
return meta
|
||||
|
||||
@classmethod
|
||||
@ -766,7 +802,10 @@ class DocumentService(CommonService):
|
||||
match_provided = "match" in upd
|
||||
if isinstance(meta[key], list):
|
||||
if not match_provided:
|
||||
meta[key] = new_value
|
||||
if isinstance(new_value, list):
|
||||
meta[key] = dedupe_list(new_value)
|
||||
else:
|
||||
meta[key] = new_value
|
||||
changed = True
|
||||
else:
|
||||
match_value = upd.get("match")
|
||||
@ -779,7 +818,7 @@ class DocumentService(CommonService):
|
||||
else:
|
||||
new_list.append(item)
|
||||
if replaced:
|
||||
meta[key] = new_list
|
||||
meta[key] = dedupe_list(new_list)
|
||||
changed = True
|
||||
else:
|
||||
if not match_provided:
|
||||
@ -1199,8 +1238,8 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
||||
d["q_%d_vec" % len(v)] = v
|
||||
for b in range(0, len(cks), es_bulk_size):
|
||||
if try_create_idx:
|
||||
if not settings.docStoreConn.indexExist(idxnm, kb_id):
|
||||
settings.docStoreConn.createIdx(idxnm, kb_id, len(vectors[0]))
|
||||
if not settings.docStoreConn.index_exist(idxnm, kb_id):
|
||||
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]))
|
||||
try_create_idx = False
|
||||
settings.docStoreConn.insert(cks[b:b + es_bulk_size], idxnm, kb_id)
|
||||
|
||||
|
||||
@ -100,7 +100,7 @@ class FileService(CommonService):
|
||||
# Returns:
|
||||
# List of dictionaries containing dataset IDs and names
|
||||
kbs = (
|
||||
cls.model.select(*[Knowledgebase.id, Knowledgebase.name])
|
||||
cls.model.select(*[Knowledgebase.id, Knowledgebase.name, File2Document.document_id])
|
||||
.join(File2Document, on=(File2Document.file_id == file_id))
|
||||
.join(Document, on=(File2Document.document_id == Document.id))
|
||||
.join(Knowledgebase, on=(Knowledgebase.id == Document.kb_id))
|
||||
@ -110,7 +110,7 @@ class FileService(CommonService):
|
||||
return []
|
||||
kbs_info_list = []
|
||||
for kb in list(kbs.dicts()):
|
||||
kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"]})
|
||||
kbs_info_list.append({"kb_id": kb["id"], "kb_name": kb["name"], "document_id": kb["document_id"]})
|
||||
return kbs_info_list
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -425,6 +425,7 @@ class KnowledgebaseService(CommonService):
|
||||
|
||||
# Update parser_config (always override with validated default/merged config)
|
||||
payload["parser_config"] = get_parser_config(parser_id, kwargs.get("parser_config"))
|
||||
payload["parser_config"]["llm_id"] = _t.llm_id
|
||||
|
||||
return True, payload
|
||||
|
||||
|
||||
@ -15,7 +15,6 @@
|
||||
#
|
||||
from typing import List
|
||||
|
||||
from api.apps import current_user
|
||||
from api.db.db_models import DB, Memory, User
|
||||
from api.db.services import duplicate_name
|
||||
from api.db.services.common_service import CommonService
|
||||
@ -23,6 +22,7 @@ from api.utils.memory_utils import calculate_memory_type
|
||||
from api.constants import MEMORY_NAME_LIMIT
|
||||
from common.misc_utils import get_uuid
|
||||
from common.time_utils import get_format_time, current_timestamp
|
||||
from memory.utils.prompt_util import PromptAssembler
|
||||
|
||||
|
||||
class MemoryService(CommonService):
|
||||
@ -34,6 +34,17 @@ class MemoryService(CommonService):
|
||||
def get_by_memory_id(cls, memory_id: str):
|
||||
return cls.model.select().where(cls.model.id == memory_id).first()
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_by_tenant_id(cls, tenant_id: str):
|
||||
return cls.model.select().where(cls.model.tenant_id == tenant_id)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_all_memory(cls):
|
||||
memory_list = cls.model.select()
|
||||
return list(memory_list)
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def get_with_owner_name_by_id(cls, memory_id: str):
|
||||
@ -53,7 +64,9 @@ class MemoryService(CommonService):
|
||||
cls.model.forgetting_policy,
|
||||
cls.model.temperature,
|
||||
cls.model.system_prompt,
|
||||
cls.model.user_prompt
|
||||
cls.model.user_prompt,
|
||||
cls.model.create_date,
|
||||
cls.model.create_time
|
||||
]
|
||||
memory = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where(
|
||||
cls.model.id == memory_id
|
||||
@ -72,7 +85,9 @@ class MemoryService(CommonService):
|
||||
cls.model.memory_type,
|
||||
cls.model.storage_type,
|
||||
cls.model.permissions,
|
||||
cls.model.description
|
||||
cls.model.description,
|
||||
cls.model.create_time,
|
||||
cls.model.create_date
|
||||
]
|
||||
memories = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id))
|
||||
if filter_dict.get("tenant_id"):
|
||||
@ -102,6 +117,8 @@ class MemoryService(CommonService):
|
||||
if len(memory_name) > MEMORY_NAME_LIMIT:
|
||||
return False, f"Memory name {memory_name} exceeds limit of {MEMORY_NAME_LIMIT}."
|
||||
|
||||
timestamp = current_timestamp()
|
||||
format_time = get_format_time()
|
||||
# build create dict
|
||||
memory_info = {
|
||||
"id": get_uuid(),
|
||||
@ -110,10 +127,11 @@ class MemoryService(CommonService):
|
||||
"tenant_id": tenant_id,
|
||||
"embd_id": embd_id,
|
||||
"llm_id": llm_id,
|
||||
"create_time": current_timestamp(),
|
||||
"create_date": get_format_time(),
|
||||
"update_time": current_timestamp(),
|
||||
"update_date": get_format_time(),
|
||||
"system_prompt": PromptAssembler.assemble_system_prompt({"memory_type": memory_type}),
|
||||
"create_time": timestamp,
|
||||
"create_date": format_time,
|
||||
"update_time": timestamp,
|
||||
"update_date": format_time,
|
||||
}
|
||||
obj = cls.model(**memory_info).save(force_insert=True)
|
||||
|
||||
@ -126,7 +144,7 @@ class MemoryService(CommonService):
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_memory(cls, memory_id: str, update_dict: dict):
|
||||
def update_memory(cls, tenant_id: str, memory_id: str, update_dict: dict):
|
||||
if not update_dict:
|
||||
return 0
|
||||
if "temperature" in update_dict and isinstance(update_dict["temperature"], str):
|
||||
@ -135,7 +153,7 @@ class MemoryService(CommonService):
|
||||
update_dict["name"] = duplicate_name(
|
||||
cls.query,
|
||||
name=update_dict["name"],
|
||||
tenant_id=current_user.id
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
update_dict.update({
|
||||
"update_time": current_timestamp(),
|
||||
|
||||
@@ -97,7 +97,7 @@ class TenantLLMService(CommonService):
         if llm_type == LLMType.EMBEDDING.value:
             mdlnm = tenant.embd_id if not llm_name else llm_name
         elif llm_type == LLMType.SPEECH2TEXT.value:
-            mdlnm = tenant.asr_id
+            mdlnm = tenant.asr_id if not llm_name else llm_name
         elif llm_type == LLMType.IMAGE2TEXT.value:
             mdlnm = tenant.img2txt_id if not llm_name else llm_name
         elif llm_type == LLMType.CHAT.value:
|
||||
|
||||
@ -163,6 +163,7 @@ def validate_request(*args, **kwargs):
|
||||
if error_arguments:
|
||||
error_string += "required argument values: {}".format(",".join(["{}={}".format(a[0], a[1]) for a in error_arguments]))
|
||||
return error_string
|
||||
return None
|
||||
|
||||
def wrapper(func):
|
||||
@wraps(func)
|
||||
@ -409,7 +410,7 @@ def get_parser_config(chunk_method, parser_config):
|
||||
if default_config is None:
|
||||
return deep_merge(base_defaults, parser_config)
|
||||
|
||||
# Ensure raptor and graphrag fields have default values if not provided
|
||||
# Ensure raptor and graph_rag fields have default values if not provided
|
||||
merged_config = deep_merge(base_defaults, default_config)
|
||||
merged_config = deep_merge(merged_config, parser_config)
|
||||
|
||||
|
||||
@ -54,6 +54,7 @@ class RetCode(IntEnum, CustomEnum):
|
||||
SERVER_ERROR = 500
|
||||
FORBIDDEN = 403
|
||||
NOT_FOUND = 404
|
||||
CONFLICT = 409
|
||||
|
||||
|
||||
class StatusEnum(Enum):
|
||||
@ -124,7 +125,11 @@ class FileSource(StrEnum):
|
||||
MOODLE = "moodle"
|
||||
DROPBOX = "dropbox"
|
||||
BOX = "box"
|
||||
R2 = "r2"
|
||||
OCI_STORAGE = "oci_storage"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
|
||||
|
||||
class PipelineTaskType(StrEnum):
|
||||
PARSE = "Parse"
|
||||
DOWNLOAD = "Download"
|
||||
|
||||
@ -56,7 +56,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
|
||||
# Validate credentials
|
||||
if self.bucket_type == BlobType.R2:
|
||||
if not all(
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"]
|
||||
):
|
||||
@ -64,15 +64,23 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
|
||||
elif self.bucket_type == BlobType.S3:
|
||||
authentication_method = credentials.get("authentication_method", "access_key")
|
||||
|
||||
if authentication_method == "access_key":
|
||||
if not all(
|
||||
credentials.get(key)
|
||||
for key in ["aws_access_key_id", "aws_secret_access_key"]
|
||||
):
|
||||
raise ConnectorMissingCredentialError("Amazon S3")
|
||||
|
||||
elif authentication_method == "iam_role":
|
||||
if not credentials.get("aws_role_arn"):
|
||||
raise ConnectorMissingCredentialError("Amazon S3 IAM role ARN is required")
|
||||
|
||||
elif authentication_method == "assume_role":
|
||||
pass
|
||||
|
||||
else:
|
||||
raise ConnectorMissingCredentialError("Unsupported S3 authentication method")
|
||||
|
||||
elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE:
|
||||
if not all(
|
||||
@ -120,55 +128,72 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
paginator = self.s3_client.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
|
||||
|
||||
batch: list[Document] = []
|
||||
# Collect all objects first to count filename occurrences
|
||||
all_objects = []
|
||||
for page in pages:
|
||||
if "Contents" not in page:
|
||||
continue
|
||||
|
||||
for obj in page["Contents"]:
|
||||
if obj["Key"].endswith("/"):
|
||||
continue
|
||||
|
||||
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
|
||||
if start < last_modified <= end:
|
||||
all_objects.append(obj)
|
||||
|
||||
# Count filename occurrences to determine which need full paths
|
||||
filename_counts: dict[str, int] = {}
|
||||
for obj in all_objects:
|
||||
file_name = os.path.basename(obj["Key"])
|
||||
filename_counts[file_name] = filename_counts.get(file_name, 0) + 1
|
||||
|
||||
if not (start < last_modified <= end):
|
||||
batch: list[Document] = []
|
||||
for obj in all_objects:
|
||||
last_modified = obj["LastModified"].replace(tzinfo=timezone.utc)
|
||||
file_name = os.path.basename(obj["Key"])
|
||||
key = obj["Key"]
|
||||
|
||||
size_bytes = extract_size_bytes(obj)
|
||||
if (
|
||||
self.size_threshold is not None
|
||||
and isinstance(size_bytes, int)
|
||||
and size_bytes > self.size_threshold
|
||||
):
|
||||
logging.warning(
|
||||
f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
|
||||
if blob is None:
|
||||
continue
|
||||
|
||||
file_name = os.path.basename(obj["Key"])
|
||||
key = obj["Key"]
|
||||
# Use full path only if filename appears multiple times
|
||||
if filename_counts.get(file_name, 0) > 1:
|
||||
relative_path = key
|
||||
if self.prefix and key.startswith(self.prefix):
|
||||
relative_path = key[len(self.prefix):]
|
||||
semantic_id = relative_path.replace('/', ' / ') if relative_path else file_name
|
||||
else:
|
||||
semantic_id = file_name
|
||||
|
||||
size_bytes = extract_size_bytes(obj)
|
||||
if (
|
||||
self.size_threshold is not None
|
||||
and isinstance(size_bytes, int)
|
||||
and size_bytes > self.size_threshold
|
||||
):
|
||||
logging.warning(
|
||||
f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
|
||||
blob=blob,
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=semantic_id,
|
||||
extension=get_file_ext(file_name),
|
||||
doc_updated_at=last_modified,
|
||||
size_bytes=size_bytes if size_bytes else 0
|
||||
)
|
||||
continue
|
||||
try:
|
||||
blob = download_object(self.s3_client, self.bucket_name, key, self.size_threshold)
|
||||
if blob is None:
|
||||
continue
|
||||
)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{key}",
|
||||
blob=blob,
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=file_name,
|
||||
extension=get_file_ext(file_name),
|
||||
doc_updated_at=last_modified,
|
||||
size_bytes=size_bytes if size_bytes else 0
|
||||
)
|
||||
)
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
except Exception:
|
||||
logging.exception(f"Error decoding object {key}")
|
||||
except Exception:
|
||||
logging.exception(f"Error decoding object {key}")
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
@ -276,4 +301,4 @@ if __name__ == "__main__":
|
||||
except ConnectorMissingCredentialError as e:
|
||||
print(f"Error: {e}")
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
print(f"An unexpected error occurred: {e}")
|
||||
|
||||
@ -83,6 +83,7 @@ _PAGE_EXPANSION_FIELDS = [
|
||||
"space",
|
||||
"metadata.labels",
|
||||
"history.lastUpdated",
|
||||
"ancestors",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -186,7 +186,7 @@ class OnyxConfluence:
|
||||
# between the db and redis everywhere the credentials might be updated
|
||||
new_credential_str = json.dumps(new_credentials)
|
||||
self.redis_client.set(
|
||||
self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL
|
||||
self.credential_key, new_credential_str, exp=self.CREDENTIAL_TTL
|
||||
)
|
||||
self._credentials_provider.set_credentials(new_credentials)
|
||||
|
||||
@ -1311,6 +1311,9 @@ class ConfluenceConnector(
|
||||
self._low_timeout_confluence_client: OnyxConfluence | None = None
|
||||
self._fetched_titles: set[str] = set()
|
||||
self.allow_images = False
|
||||
# Track document names to detect duplicates
|
||||
self._document_name_counts: dict[str, int] = {}
|
||||
self._document_name_paths: dict[str, list[str]] = {}
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
@ -1513,6 +1516,40 @@ class ConfluenceConnector(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
# Build hierarchical path for semantic identifier
|
||||
space_name = page.get("space", {}).get("name", "")
|
||||
|
||||
# Build path from ancestors
|
||||
path_parts = []
|
||||
if space_name:
|
||||
path_parts.append(space_name)
|
||||
|
||||
# Add ancestor pages to path if available
|
||||
if "ancestors" in page and page["ancestors"]:
|
||||
for ancestor in page["ancestors"]:
|
||||
ancestor_title = ancestor.get("title", "")
|
||||
if ancestor_title:
|
||||
path_parts.append(ancestor_title)
|
||||
|
||||
# Add current page title
|
||||
path_parts.append(page_title)
|
||||
|
||||
# Track page names for duplicate detection
|
||||
full_path = " / ".join(path_parts) if len(path_parts) > 1 else page_title
|
||||
|
||||
# Count occurrences of this page title
|
||||
if page_title not in self._document_name_counts:
|
||||
self._document_name_counts[page_title] = 0
|
||||
self._document_name_paths[page_title] = []
|
||||
self._document_name_counts[page_title] += 1
|
||||
self._document_name_paths[page_title].append(full_path)
|
||||
|
||||
# Use simple name if no duplicates, otherwise use full path
|
||||
if self._document_name_counts[page_title] == 1:
|
||||
semantic_identifier = page_title
|
||||
else:
|
||||
semantic_identifier = full_path
|
||||
|
||||
# Get the page content
|
||||
page_content = extract_text_from_confluence_html(
|
||||
self.confluence_client, page, self._fetched_titles
|
||||
@ -1559,11 +1596,11 @@ class ConfluenceConnector(
|
||||
return Document(
|
||||
id=page_url,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page_title,
|
||||
semantic_identifier=semantic_identifier,
|
||||
extension=".html", # Confluence pages are HTML
|
||||
blob=page_content.encode("utf-8"), # Encode page content as bytes
|
||||
size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes
|
||||
doc_updated_at=datetime_from_string(page["version"]["when"]),
|
||||
size_bytes=len(page_content.encode("utf-8")), # Calculate size in bytes
|
||||
primary_owners=primary_owners if primary_owners else None,
|
||||
metadata=metadata if metadata else None,
|
||||
)
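
The duplicate-title handling above reduces to a small rule; a standalone sketch (not the connector's exact code):

def pick_semantic_identifier(title: str, full_path: str, seen_counts: dict[str, int]) -> str:
    # First occurrence keeps the short title; later ones fall back to "Space / Ancestor / Page".
    seen_counts[title] = seen_counts.get(title, 0) + 1
    return title if seen_counts[title] == 1 else full_path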
|
||||
@ -1601,7 +1638,6 @@ class ConfluenceConnector(
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
|
||||
|
||||
# TODO(rkuo): this check is partially redundant with validate_attachment_filetype
|
||||
# and checks in convert_attachment_to_content/process_attachment
|
||||
# but doing the check here avoids an unnecessary download. Due for refactoring.
|
||||
@ -1669,6 +1705,34 @@ class ConfluenceConnector(
|
||||
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
# Build semantic identifier with space and page context
|
||||
attachment_title = attachment.get("title", object_url)
|
||||
space_name = page.get("space", {}).get("name", "")
|
||||
page_title = page.get("title", "")
|
||||
|
||||
# Create hierarchical name: Space / Page / Attachment
|
||||
attachment_path_parts = []
|
||||
if space_name:
|
||||
attachment_path_parts.append(space_name)
|
||||
if page_title:
|
||||
attachment_path_parts.append(page_title)
|
||||
attachment_path_parts.append(attachment_title)
|
||||
|
||||
full_attachment_path = " / ".join(attachment_path_parts) if len(attachment_path_parts) > 1 else attachment_title
|
||||
|
||||
# Track attachment names for duplicate detection
|
||||
if attachment_title not in self._document_name_counts:
|
||||
self._document_name_counts[attachment_title] = 0
|
||||
self._document_name_paths[attachment_title] = []
|
||||
self._document_name_counts[attachment_title] += 1
|
||||
self._document_name_paths[attachment_title].append(full_attachment_path)
|
||||
|
||||
# Use simple name if no duplicates, otherwise use full path
|
||||
if self._document_name_counts[attachment_title] == 1:
|
||||
attachment_semantic_identifier = attachment_title
|
||||
else:
|
||||
attachment_semantic_identifier = full_attachment_path
|
||||
|
||||
primary_owners: list[BasicExpertInfo] | None = None
|
||||
if "version" in attachment and "by" in attachment["version"]:
|
||||
author = attachment["version"]["by"]
|
||||
@ -1680,11 +1744,12 @@ class ConfluenceConnector(
|
||||
|
||||
extension = Path(attachment.get("title", "")).suffix or ".unknown"
|
||||
|
||||
|
||||
attachment_doc = Document(
|
||||
id=attachment_id,
|
||||
# sections=sections,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=attachment.get("title", object_url),
|
||||
semantic_identifier=attachment_semantic_identifier,
|
||||
extension=extension,
|
||||
blob=file_blob,
|
||||
size_bytes=len(file_blob),
|
||||
@ -1741,7 +1806,7 @@ class ConfluenceConnector(
|
||||
start_ts, end, self.batch_size
|
||||
)
|
||||
logging.debug(f"page_query_url: {page_query_url}")
|
||||
|
||||
|
||||
# store the next page start for confluence server, cursor for confluence cloud
|
||||
def store_next_page_url(next_page_url: str) -> None:
|
||||
checkpoint.next_page_url = next_page_url
|
||||
|
||||
@ -87,15 +87,69 @@ class DropboxConnector(LoadConnector, PollConnector):
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
# Collect all files first to count filename occurrences
|
||||
all_files = []
|
||||
self._collect_files_recursive(path, start, end, all_files)
|
||||
|
||||
# Count filename occurrences
|
||||
filename_counts: dict[str, int] = {}
|
||||
for entry, _ in all_files:
|
||||
filename_counts[entry.name] = filename_counts.get(entry.name, 0) + 1
|
||||
|
||||
# Process files in batches
|
||||
batch: list[Document] = []
|
||||
for entry, downloaded_file in all_files:
|
||||
modified_time = entry.client_modified
|
||||
if modified_time.tzinfo is None:
|
||||
modified_time = modified_time.replace(tzinfo=timezone.utc)
|
||||
else:
|
||||
modified_time = modified_time.astimezone(timezone.utc)
|
||||
|
||||
# Use full path only if filename appears multiple times
|
||||
if filename_counts.get(entry.name, 0) > 1:
|
||||
# Remove leading slash and replace slashes with ' / '
|
||||
relative_path = entry.path_display.lstrip('/')
|
||||
semantic_id = relative_path.replace('/', ' / ') if relative_path else entry.name
|
||||
else:
|
||||
semantic_id = entry.name
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"dropbox:{entry.id}",
|
||||
blob=downloaded_file,
|
||||
source=DocumentSource.DROPBOX,
|
||||
semantic_identifier=semantic_id,
|
||||
extension=get_file_ext(entry.name),
|
||||
doc_updated_at=modified_time,
|
||||
size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file),
|
||||
)
|
||||
)
|
||||
|
||||
if len(batch) == self.batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
|
||||
if batch:
|
||||
yield batch
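
The two-pass scheme above (collect, count names, then emit) can be summarised in isolation; the entries are hypothetical objects with .name and .path_display, mirroring Dropbox's FileMetadata:

def semantic_ids(entries) -> list[str]:
    counts: dict[str, int] = {}
    for e in entries:
        counts[e.name] = counts.get(e.name, 0) + 1
    out: list[str] = []
    for e in entries:
        if counts[e.name] > 1:
            out.append(e.path_display.lstrip("/").replace("/", " / ") or e.name)  # full path for duplicates
        else:
            out.append(e.name)                                                    # short name otherwise
    return out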
|
||||
|
||||
def _collect_files_recursive(
|
||||
self,
|
||||
path: str,
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
all_files: list,
|
||||
) -> None:
|
||||
"""Recursively collect all files matching time criteria."""
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox")
|
||||
|
||||
result = self.dropbox_client.files_list_folder(
|
||||
path,
|
||||
limit=self.batch_size,
|
||||
recursive=False,
|
||||
include_non_downloadable_files=False,
|
||||
)
|
||||
|
||||
while True:
|
||||
batch: list[Document] = []
|
||||
for entry in result.entries:
|
||||
if isinstance(entry, FileMetadata):
|
||||
modified_time = entry.client_modified
|
||||
@ -112,27 +166,13 @@ class DropboxConnector(LoadConnector, PollConnector):
|
||||
|
||||
try:
|
||||
downloaded_file = self._download_file(entry.path_display)
|
||||
all_files.append((entry, downloaded_file))
|
||||
except Exception:
|
||||
logger.exception(f"[Dropbox]: Error downloading file {entry.path_display}")
|
||||
continue
|
||||
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"dropbox:{entry.id}",
|
||||
blob=downloaded_file,
|
||||
source=DocumentSource.DROPBOX,
|
||||
semantic_identifier=entry.name,
|
||||
extension=get_file_ext(entry.name),
|
||||
doc_updated_at=modified_time,
|
||||
size_bytes=entry.size if getattr(entry, "size", None) is not None else len(downloaded_file),
|
||||
)
|
||||
)
|
||||
|
||||
elif isinstance(entry, FolderMetadata):
|
||||
yield from self._yield_files_recursive(entry.path_lower, start, end)
|
||||
|
||||
if batch:
|
||||
yield batch
|
||||
self._collect_files_recursive(entry.path_lower, start, end, all_files)
|
||||
|
||||
if not result.has_more:
|
||||
break
|
||||
|
||||
@ -94,6 +94,7 @@ class Document(BaseModel):
|
||||
blob: bytes
|
||||
doc_updated_at: datetime
|
||||
size_bytes: int
|
||||
primary_owners: list
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
@ -180,6 +181,7 @@ class NotionPage(BaseModel):
|
||||
archived: bool
|
||||
properties: dict[str, Any]
|
||||
url: str
|
||||
parent: Optional[dict[str, Any]] = None # Parent reference for path reconstruction
|
||||
database_name: Optional[str] = None # Only applicable to database type pages
|
||||
|
||||
|
||||
|
||||
@ -66,6 +66,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
self.indexed_pages: set[str] = set()
|
||||
self.root_page_id = root_page_id
|
||||
self.recursive_index_enabled = recursive_index_enabled or bool(root_page_id)
|
||||
self.page_path_cache: dict[str, str] = {}
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _fetch_child_blocks(self, block_id: str, cursor: Optional[str] = None) -> dict[str, Any] | None:
|
||||
@ -242,6 +243,20 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
logging.warning(f"[Notion]: Failed to download Notion file from {url}: {exc}")
|
||||
return None
|
||||
|
||||
def _append_block_id_to_name(self, name: str, block_id: Optional[str]) -> str:
|
||||
"""Append the Notion block ID to the filename while keeping the extension."""
|
||||
if not block_id:
|
||||
return name
|
||||
|
||||
path = Path(name)
|
||||
stem = path.stem or name
|
||||
suffix = path.suffix
|
||||
|
||||
if not stem:
|
||||
return name
|
||||
|
||||
return f"{stem}_{block_id}{suffix}" if suffix else f"{stem}_{block_id}"
|
||||
|
||||
def _extract_file_metadata(self, result_obj: dict[str, Any], block_id: str) -> tuple[str | None, str, str | None]:
|
||||
file_source_type = result_obj.get("type")
|
||||
file_source = result_obj.get(file_source_type, {}) if file_source_type else {}
|
||||
@ -254,6 +269,8 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
elif not name:
|
||||
name = f"notion_file_{block_id}"
|
||||
|
||||
name = self._append_block_id_to_name(name, block_id)
|
||||
|
||||
caption = self._extract_rich_text(result_obj.get("caption", [])) if "caption" in result_obj else None
|
||||
|
||||
return url, name, caption
|
||||
@ -265,6 +282,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
name: str,
|
||||
caption: Optional[str],
|
||||
page_last_edited_time: Optional[str],
|
||||
page_path: Optional[str],
|
||||
) -> Document | None:
|
||||
file_bytes = self._download_file(url)
|
||||
if file_bytes is None:
|
||||
@ -277,7 +295,8 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
extension = ".bin"
|
||||
|
||||
updated_at = datetime_from_string(page_last_edited_time) if page_last_edited_time else datetime.now(timezone.utc)
|
||||
semantic_identifier = caption or name or f"Notion file {block_id}"
|
||||
base_identifier = name or caption or (f"Notion file {block_id}" if block_id else "Notion file")
|
||||
semantic_identifier = f"{page_path} / {base_identifier}" if page_path else base_identifier
|
||||
|
||||
return Document(
|
||||
id=block_id,
|
||||
@ -289,7 +308,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
doc_updated_at=updated_at,
|
||||
)
|
||||
|
||||
def _read_blocks(self, base_block_id: str, page_last_edited_time: Optional[str] = None) -> tuple[list[NotionBlock], list[str], list[Document]]:
|
||||
def _read_blocks(self, base_block_id: str, page_last_edited_time: Optional[str] = None, page_path: Optional[str] = None) -> tuple[list[NotionBlock], list[str], list[Document]]:
|
||||
result_blocks: list[NotionBlock] = []
|
||||
child_pages: list[str] = []
|
||||
attachments: list[Document] = []
|
||||
@ -370,11 +389,14 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
name=file_name,
|
||||
caption=caption,
|
||||
page_last_edited_time=page_last_edited_time,
|
||||
page_path=page_path,
|
||||
)
|
||||
if attachment_doc:
|
||||
attachments.append(attachment_doc)
|
||||
|
||||
attachment_label = caption or file_name
|
||||
attachment_label = file_name
|
||||
if caption:
|
||||
attachment_label = f"{file_name} ({caption})"
|
||||
if attachment_label:
|
||||
cur_result_text_arr.append(f"{result_type.capitalize()}: {attachment_label}")
|
||||
|
||||
@ -383,7 +405,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
child_pages.append(result_block_id)
|
||||
else:
|
||||
logging.debug(f"[Notion]: Entering sub-block: {result_block_id}")
|
||||
subblocks, subblock_child_pages, subblock_attachments = self._read_blocks(result_block_id, page_last_edited_time)
|
||||
subblocks, subblock_child_pages, subblock_attachments = self._read_blocks(result_block_id, page_last_edited_time, page_path)
|
||||
logging.debug(f"[Notion]: Finished sub-block: {result_block_id}")
|
||||
result_blocks.extend(subblocks)
|
||||
child_pages.extend(subblock_child_pages)
|
||||
@ -423,6 +445,35 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
return None
|
||||
|
||||
def _build_page_path(self, page: NotionPage, visited: Optional[set[str]] = None) -> Optional[str]:
|
||||
"""Construct a hierarchical path for a page based on its parent chain."""
|
||||
if page.id in self.page_path_cache:
|
||||
return self.page_path_cache[page.id]
|
||||
|
||||
visited = visited or set()
|
||||
if page.id in visited:
|
||||
logging.warning(f"[Notion]: Detected cycle while building path for page {page.id}")
|
||||
return self._read_page_title(page)
|
||||
visited.add(page.id)
|
||||
|
||||
current_title = self._read_page_title(page) or f"Untitled Page {page.id}"
|
||||
|
||||
parent_info = getattr(page, "parent", None) or {}
|
||||
parent_type = parent_info.get("type")
|
||||
parent_id = parent_info.get(parent_type) if parent_type else None
|
||||
|
||||
parent_path = None
|
||||
if parent_type in {"page_id", "database_id"} and isinstance(parent_id, str):
|
||||
try:
|
||||
parent_page = self._fetch_page(parent_id)
|
||||
parent_path = self._build_page_path(parent_page, visited)
|
||||
except Exception as exc:
|
||||
logging.warning(f"[Notion]: Failed to resolve parent {parent_id} for page {page.id}: {exc}")
|
||||
|
||||
full_path = f"{parent_path} / {current_title}" if parent_path else current_title
|
||||
self.page_path_cache[page.id] = full_path
|
||||
return full_path
|
||||
|
||||
def _read_pages(self, pages: list[NotionPage], start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None) -> Generator[Document, None, None]:
|
||||
"""Reads pages for rich text content and generates Documents."""
|
||||
all_child_page_ids: list[str] = []
|
||||
@ -441,13 +492,18 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
continue
|
||||
|
||||
logging.info(f"[Notion]: Reading page with ID {page.id}, with url {page.url}")
|
||||
page_blocks, child_page_ids, attachment_docs = self._read_blocks(page.id, page.last_edited_time)
|
||||
page_path = self._build_page_path(page)
|
||||
page_blocks, child_page_ids, attachment_docs = self._read_blocks(page.id, page.last_edited_time, page_path)
|
||||
all_child_page_ids.extend(child_page_ids)
|
||||
self.indexed_pages.add(page.id)
|
||||
|
||||
raw_page_title = self._read_page_title(page)
|
||||
page_title = raw_page_title or f"Untitled Page with ID {page.id}"
|
||||
|
||||
# Append the page id to help disambiguate duplicate names
|
||||
base_identifier = page_path or page_title
|
||||
semantic_identifier = f"{base_identifier}_{page.id}" if base_identifier else page.id
|
||||
|
||||
if not page_blocks:
|
||||
if not raw_page_title:
|
||||
logging.warning(f"[Notion]: No blocks OR title found for page with ID {page.id}. Skipping.")
|
||||
@ -469,7 +525,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
joined_text = "\n".join(sec.text for sec in sections)
|
||||
blob = joined_text.encode("utf-8")
|
||||
yield Document(
|
||||
id=page.id, blob=blob, source=DocumentSource.NOTION, semantic_identifier=page_title, extension=".txt", size_bytes=len(blob), doc_updated_at=datetime_from_string(page.last_edited_time)
|
||||
id=page.id, blob=blob, source=DocumentSource.NOTION, semantic_identifier=semantic_identifier, extension=".txt", size_bytes=len(blob), doc_updated_at=datetime_from_string(page.last_edited_time)
|
||||
)
|
||||
|
||||
for attachment_doc in attachment_docs:
|
||||
@ -597,4 +653,4 @@ if __name__ == "__main__":
|
||||
document_batches = connector.load_from_state()
|
||||
for doc_batch in document_batches:
|
||||
for doc in doc_batch:
|
||||
print(doc)
|
||||
print(doc)
|
||||
@ -167,7 +167,6 @@ def get_latest_message_time(thread: ThreadType) -> datetime:
|
||||
|
||||
|
||||
def _build_doc_id(channel_id: str, thread_ts: str) -> str:
|
||||
"""构建文档ID"""
|
||||
return f"{channel_id}__{thread_ts}"
|
||||
|
||||
|
||||
@ -179,7 +178,6 @@ def thread_to_doc(
|
||||
user_cache: dict[str, BasicExpertInfo | None],
|
||||
channel_access: Any | None,
|
||||
) -> Document:
|
||||
"""将线程转换为文档"""
|
||||
channel_id = channel["id"]
|
||||
|
||||
initial_sender_expert_info = expert_info_from_slack_id(
|
||||
@ -237,7 +235,6 @@ def filter_channels(
|
||||
channels_to_connect: list[str] | None,
|
||||
regex_enabled: bool,
|
||||
) -> list[ChannelType]:
|
||||
"""过滤频道"""
|
||||
if not channels_to_connect:
|
||||
return all_channels
|
||||
|
||||
@ -381,7 +378,6 @@ def _process_message(
|
||||
[MessageType], SlackMessageFilterReason | None
|
||||
] = default_msg_filter,
|
||||
) -> ProcessedSlackMessage:
|
||||
"""处理消息"""
|
||||
thread_ts = message.get("thread_ts")
|
||||
thread_or_message_ts = thread_ts or message["ts"]
|
||||
try:
|
||||
@ -536,7 +532,6 @@ class SlackConnector(
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: Any = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
"""获取所有简化文档(带权限同步)"""
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack")
|
||||
|
||||
|
||||
@ -254,18 +254,21 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
|
||||
elif bucket_type == BlobType.S3:
|
||||
authentication_method = credentials.get("authentication_method", "access_key")
|
||||
|
||||
region_name = credentials.get("region") or None
|
||||
|
||||
if authentication_method == "access_key":
|
||||
session = boto3.Session(
|
||||
aws_access_key_id=credentials["aws_access_key_id"],
|
||||
aws_secret_access_key=credentials["aws_secret_access_key"],
|
||||
region_name=region_name,
|
||||
)
|
||||
return session.client("s3")
|
||||
return session.client("s3", region_name=region_name)
|
||||
|
||||
elif authentication_method == "iam_role":
|
||||
role_arn = credentials["aws_role_arn"]
|
||||
|
||||
def _refresh_credentials() -> dict[str, str]:
|
||||
sts_client = boto3.client("sts")
|
||||
sts_client = boto3.client("sts", region_name=credentials.get("region") or None)
|
||||
assumed_role_object = sts_client.assume_role(
|
||||
RoleArn=role_arn,
|
||||
RoleSessionName=f"onyx_blob_storage_{int(datetime.now().timestamp())}",
|
||||
@ -285,11 +288,11 @@ def create_s3_client(bucket_type: BlobType, credentials: dict[str, Any], europea
|
||||
)
|
||||
botocore_session = get_session()
|
||||
botocore_session._credentials = refreshable
|
||||
session = boto3.Session(botocore_session=botocore_session)
|
||||
return session.client("s3")
|
||||
session = boto3.Session(botocore_session=botocore_session, region_name=region_name)
|
||||
return session.client("s3", region_name=region_name)
|
||||
|
||||
elif authentication_method == "assume_role":
|
||||
return boto3.client("s3")
|
||||
return boto3.client("s3", region_name=region_name)
|
||||
|
||||
else:
|
||||
raise ValueError("Invalid authentication method for S3.")
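
A condensed sketch of the region handling this hunk introduces, using only standard boto3 calls; the credential values are placeholders:

import boto3

credentials = {"aws_access_key_id": "AKIA...", "aws_secret_access_key": "...", "region": "us-east-1"}

region_name = credentials.get("region") or None       # None lets boto3 fall back to env/config resolution
session = boto3.Session(
    aws_access_key_id=credentials["aws_access_key_id"],
    aws_secret_access_key=credentials["aws_secret_access_key"],
    region_name=region_name,
)
s3 = session.client("s3", region_name=region_name)    # explicit region takes precedence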
|
||||
|
||||
common/doc_store/__init__.py (new, empty file)
@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
@ -22,7 +21,6 @@ DEFAULT_MATCH_VECTOR_TOPN = 10
|
||||
DEFAULT_MATCH_SPARSE_TOPN = 10
|
||||
VEC = list | np.ndarray
|
||||
|
||||
|
||||
@dataclass
|
||||
class SparseVector:
|
||||
indices: list[int]
|
||||
@ -55,14 +53,13 @@ class SparseVector:
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
|
||||
|
||||
class MatchTextExpr(ABC):
|
||||
class MatchTextExpr:
|
||||
def __init__(
|
||||
self,
|
||||
fields: list[str],
|
||||
matching_text: str,
|
||||
topn: int,
|
||||
extra_options: dict = dict(),
|
||||
extra_options: dict | None = None,
|
||||
):
|
||||
self.fields = fields
|
||||
self.matching_text = matching_text
|
||||
@ -70,7 +67,7 @@ class MatchTextExpr(ABC):
|
||||
self.extra_options = extra_options
|
||||
|
||||
|
||||
class MatchDenseExpr(ABC):
|
||||
class MatchDenseExpr:
|
||||
def __init__(
|
||||
self,
|
||||
vector_column_name: str,
|
||||
@ -78,7 +75,7 @@ class MatchDenseExpr(ABC):
|
||||
embedding_data_type: str,
|
||||
distance_type: str,
|
||||
topn: int = DEFAULT_MATCH_VECTOR_TOPN,
|
||||
extra_options: dict = dict(),
|
||||
extra_options: dict | None = None,
|
||||
):
|
||||
self.vector_column_name = vector_column_name
|
||||
self.embedding_data = embedding_data
|
||||
@ -88,7 +85,7 @@ class MatchDenseExpr(ABC):
|
||||
self.extra_options = extra_options
|
||||
|
||||
|
||||
class MatchSparseExpr(ABC):
|
||||
class MatchSparseExpr:
|
||||
def __init__(
|
||||
self,
|
||||
vector_column_name: str,
|
||||
@ -104,7 +101,7 @@ class MatchSparseExpr(ABC):
|
||||
self.opt_params = opt_params
|
||||
|
||||
|
||||
class MatchTensorExpr(ABC):
|
||||
class MatchTensorExpr:
|
||||
def __init__(
|
||||
self,
|
||||
column_name: str,
|
||||
@ -120,7 +117,7 @@ class MatchTensorExpr(ABC):
|
||||
self.extra_option = extra_option
|
||||
|
||||
|
||||
class FusionExpr(ABC):
|
||||
class FusionExpr:
|
||||
def __init__(self, method: str, topn: int, fusion_params: dict | None = None):
|
||||
self.method = method
|
||||
self.topn = topn
|
||||
@ -129,7 +126,8 @@ class FusionExpr(ABC):
|
||||
|
||||
MatchExpr = MatchTextExpr | MatchDenseExpr | MatchSparseExpr | MatchTensorExpr | FusionExpr
|
||||
|
||||
class OrderByExpr(ABC):
|
||||
|
||||
class OrderByExpr:
|
||||
def __init__(self):
|
||||
self.fields = list()
|
||||
def asc(self, field: str):
|
||||
@ -141,13 +139,14 @@ class OrderByExpr(ABC):
|
||||
def fields(self):
|
||||
return self.fields
|
||||
|
||||
|
||||
class DocStoreConnection(ABC):
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def dbType(self) -> str:
|
||||
def db_type(self) -> str:
|
||||
"""
|
||||
Return the type of the database.
|
||||
"""
|
||||
@ -165,21 +164,21 @@ class DocStoreConnection(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
|
||||
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
|
||||
"""
|
||||
Create an index with given name
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def deleteIdx(self, indexName: str, knowledgebaseId: str):
|
||||
def delete_idx(self, index_name: str, dataset_id: str):
|
||||
"""
|
||||
Delete an index with given name
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def indexExist(self, indexName: str, knowledgebaseId: str) -> bool:
|
||||
def index_exist(self, index_name: str, dataset_id: str) -> bool:
|
||||
"""
|
||||
Check if an index with given name exists
|
||||
"""
|
||||
@ -191,16 +190,16 @@ class DocStoreConnection(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, selectFields: list[str],
|
||||
highlightFields: list[str],
|
||||
self, select_fields: list[str],
|
||||
highlight_fields: list[str],
|
||||
condition: dict,
|
||||
matchExprs: list[MatchExpr],
|
||||
orderBy: OrderByExpr,
|
||||
match_expressions: list[MatchExpr],
|
||||
order_by: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
indexNames: str|list[str],
|
||||
knowledgebaseIds: list[str],
|
||||
aggFields: list[str] = [],
|
||||
index_names: str|list[str],
|
||||
dataset_ids: list[str],
|
||||
agg_fields: list[str] | None = None,
|
||||
rank_feature: dict | None = None
|
||||
):
|
||||
"""
|
||||
@ -209,28 +208,28 @@ class DocStoreConnection(ABC):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
|
||||
def get(self, data_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
|
||||
"""
|
||||
Get single chunk with given id
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, rows: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
|
||||
def insert(self, rows: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
|
||||
"""
|
||||
Update or insert a bulk of rows
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
|
||||
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
|
||||
"""
|
||||
Update rows with given conjunctive equivalent filtering condition
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
|
||||
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
|
||||
"""
|
||||
Delete rows with given conjunctive equivalent filtering condition
|
||||
"""
|
||||
@ -245,7 +244,7 @@ class DocStoreConnection(ABC):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get_chunk_ids(self, res):
|
||||
def get_doc_ids(self, res):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
@ -253,18 +252,18 @@ class DocStoreConnection(ABC):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get_highlight(self, res, keywords: list[str], fieldnm: str):
|
||||
def get_highlight(self, res, keywords: list[str], field_name: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get_aggregation(self, res, fieldnm: str):
|
||||
def get_aggregation(self, res, field_name: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
@abstractmethod
|
||||
def sql(sql: str, fetch_size: int, format: str):
|
||||
def sql(self, sql: str, fetch_size: int, format: str):
|
||||
"""
|
||||
Run the sql generated by text-to-sql
|
||||
"""
|
||||
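To see what the renamed snake_case interface asks of a concrete connector, here is a minimal skeleton; the class name and return values are illustrative only, and the remaining abstract methods (get, insert, update, delete, result helpers, sql) are omitted for brevity.

```python
# Minimal sketch of a connector against the renamed interface (illustrative only).
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr


class DummyStoreConnection(DocStoreConnection):
    def db_type(self) -> str:
        return "dummy"

    def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
        return True

    def delete_idx(self, index_name: str, dataset_id: str):
        pass

    def index_exist(self, index_name: str, dataset_id: str) -> bool:
        return False

    def search(self, select_fields: list[str], highlight_fields: list[str],
               condition: dict, match_expressions: list[MatchExpr],
               order_by: OrderByExpr, offset: int, limit: int,
               index_names: str | list[str], dataset_ids: list[str],
               agg_fields: list[str] | None = None, rank_feature: dict | None = None):
        return []
    # ... the remaining abstract methods are omitted here for brevity.
```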
326
common/doc_store/es_conn_base.py
Normal file
@ -0,0 +1,326 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
|
||||
from elasticsearch import Elasticsearch, NotFoundError
|
||||
from elasticsearch_dsl import Index
|
||||
from elastic_transport import ConnectionTimeout
|
||||
from common.file_utils import get_project_base_directory
|
||||
from common.misc_utils import convert_bytes
|
||||
from common.doc_store.doc_store_base import DocStoreConnection, OrderByExpr, MatchExpr
|
||||
from rag.nlp import is_english, rag_tokenizer
|
||||
from common import settings
|
||||
|
||||
ATTEMPT_TIME = 2
|
||||
|
||||
|
||||
class ESConnectionBase(DocStoreConnection):
|
||||
def __init__(self, mapping_file_name: str="mapping.json", logger_name: str='ragflow.es_conn'):
|
||||
self.logger = logging.getLogger(logger_name)
|
||||
|
||||
self.info = {}
|
||||
self.logger.info(f"Use Elasticsearch {settings.ES['hosts']} as the doc engine.")
|
||||
for _ in range(ATTEMPT_TIME):
|
||||
try:
|
||||
if self._connect():
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.warning(f"{str(e)}. Waiting Elasticsearch {settings.ES['hosts']} to be healthy.")
|
||||
time.sleep(5)
|
||||
|
||||
if not self.es.ping():
|
||||
msg = f"Elasticsearch {settings.ES['hosts']} is unhealthy in 120s."
|
||||
self.logger.error(msg)
|
||||
raise Exception(msg)
|
||||
v = self.info.get("version", {"number": "8.11.3"})
|
||||
v = v["number"].split(".")[0]
|
||||
if int(v) < 8:
|
||||
msg = f"Elasticsearch version must be greater than or equal to 8, current version: {v}"
|
||||
self.logger.error(msg)
|
||||
raise Exception(msg)
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", mapping_file_name)
|
||||
if not os.path.exists(fp_mapping):
|
||||
msg = f"Elasticsearch mapping file not found at {fp_mapping}"
|
||||
self.logger.error(msg)
|
||||
raise Exception(msg)
|
||||
self.mapping = json.load(open(fp_mapping, "r"))
|
||||
self.logger.info(f"Elasticsearch {settings.ES['hosts']} is healthy.")
|
||||
|
||||
def _connect(self):
|
||||
self.es = Elasticsearch(
|
||||
settings.ES["hosts"].split(","),
|
||||
basic_auth=(settings.ES["username"], settings.ES[
|
||||
"password"]) if "username" in settings.ES and "password" in settings.ES else None,
|
||||
verify_certs= settings.ES.get("verify_certs", False),
|
||||
timeout=600 )
|
||||
if self.es:
|
||||
self.info = self.es.info()
|
||||
return True
|
||||
return False
|
||||
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
def db_type(self) -> str:
|
||||
return "elasticsearch"
|
||||
|
||||
def health(self) -> dict:
|
||||
health_dict = dict(self.es.cluster.health())
|
||||
health_dict["type"] = "elasticsearch"
|
||||
return health_dict
|
||||
|
||||
def get_cluster_stats(self):
|
||||
"""
|
||||
curl -XGET "http://{es_host}/_cluster/stats" -H "kbn-xsrf: reporting" to view raw stats.
|
||||
"""
|
||||
raw_stats = self.es.cluster.stats()
|
||||
self.logger.debug(f"ESConnection.get_cluster_stats: {raw_stats}")
|
||||
try:
|
||||
res = {
|
||||
'cluster_name': raw_stats['cluster_name'],
|
||||
'status': raw_stats['status']
|
||||
}
|
||||
indices_status = raw_stats['indices']
|
||||
res.update({
|
||||
'indices': indices_status['count'],
|
||||
'indices_shards': indices_status['shards']['total']
|
||||
})
|
||||
doc_info = indices_status['docs']
|
||||
res.update({
|
||||
'docs': doc_info['count'],
|
||||
'docs_deleted': doc_info['deleted']
|
||||
})
|
||||
store_info = indices_status['store']
|
||||
res.update({
|
||||
'store_size': convert_bytes(store_info['size_in_bytes']),
|
||||
'total_dataset_size': convert_bytes(store_info['total_data_set_size_in_bytes'])
|
||||
})
|
||||
mappings_info = indices_status['mappings']
|
||||
res.update({
|
||||
'mappings_fields': mappings_info['total_field_count'],
|
||||
'mappings_deduplicated_fields': mappings_info['total_deduplicated_field_count'],
|
||||
'mappings_deduplicated_size': convert_bytes(mappings_info['total_deduplicated_mapping_size_in_bytes'])
|
||||
})
|
||||
node_info = raw_stats['nodes']
|
||||
res.update({
|
||||
'nodes': node_info['count']['total'],
|
||||
'nodes_version': node_info['versions'],
|
||||
'os_mem': convert_bytes(node_info['os']['mem']['total_in_bytes']),
|
||||
'os_mem_used': convert_bytes(node_info['os']['mem']['used_in_bytes']),
|
||||
'os_mem_used_percent': node_info['os']['mem']['used_percent'],
|
||||
'jvm_versions': node_info['jvm']['versions'][0]['vm_version'],
|
||||
'jvm_heap_used': convert_bytes(node_info['jvm']['mem']['heap_used_in_bytes']),
|
||||
'jvm_heap_max': convert_bytes(node_info['jvm']['mem']['heap_max_in_bytes'])
|
||||
})
|
||||
return res
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"ESConnection.get_cluster_stats: {e}")
|
||||
return None
|
||||
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
|
||||
if self.index_exist(index_name, dataset_id):
|
||||
return True
|
||||
try:
|
||||
from elasticsearch.client import IndicesClient
|
||||
return IndicesClient(self.es).create(index=index_name,
|
||||
settings=self.mapping["settings"],
|
||||
mappings=self.mapping["mappings"])
|
||||
except Exception:
|
||||
self.logger.exception("ESConnection.createIndex error %s" % index_name)
|
||||
|
||||
def delete_idx(self, index_name: str, dataset_id: str):
|
||||
if len(dataset_id) > 0:
|
||||
# The index needs to stay alive after any KB deletion, since all KBs under this tenant share one index.
|
||||
return
|
||||
try:
|
||||
self.es.indices.delete(index=index_name, allow_no_indices=True)
|
||||
except NotFoundError:
|
||||
pass
|
||||
except Exception:
|
||||
self.logger.exception("ESConnection.deleteIdx error %s" % index_name)
|
||||
|
||||
def index_exist(self, index_name: str, dataset_id: str = None) -> bool:
|
||||
s = Index(index_name, self.es)
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
return s.exists()
|
||||
except ConnectionTimeout:
|
||||
self.logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
self.logger.exception(e)
|
||||
break
|
||||
return False
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
def get(self, doc_id: str, index_name: str, dataset_ids: list[str]) -> dict | None:
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = self.es.get(index=index_name,
|
||||
id=doc_id, source=True, )
|
||||
if str(res.get("timed_out", "")).lower() == "true":
|
||||
raise Exception("Es Timeout.")
|
||||
doc = res["_source"]
|
||||
doc["id"] = doc_id
|
||||
return doc
|
||||
except NotFoundError:
|
||||
return None
|
||||
except Exception as e:
|
||||
self.logger.exception(f"ESConnection.get({doc_id}) got exception")
|
||||
raise e
|
||||
self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
|
||||
raise Exception("ESConnection.get timeout.")
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, select_fields: list[str],
|
||||
highlight_fields: list[str],
|
||||
condition: dict,
|
||||
match_expressions: list[MatchExpr],
|
||||
order_by: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
index_names: str | list[str],
|
||||
dataset_ids: list[str],
|
||||
agg_fields: list[str] | None = None,
|
||||
rank_feature: dict | None = None
|
||||
):
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, documents: list[dict], index_name: str, dataset_id: str = None) -> list[str]:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def get_total(self, res):
|
||||
if isinstance(res["hits"]["total"], type({})):
|
||||
return res["hits"]["total"]["value"]
|
||||
return res["hits"]["total"]
|
||||
|
||||
def get_doc_ids(self, res):
|
||||
return [d["_id"] for d in res["hits"]["hits"]]
|
||||
|
||||
def _get_source(self, res):
|
||||
rr = []
|
||||
for d in res["hits"]["hits"]:
|
||||
d["_source"]["id"] = d["_id"]
|
||||
d["_source"]["_score"] = d["_score"]
|
||||
rr.append(d["_source"])
|
||||
return rr
|
||||
|
||||
@abstractmethod
|
||||
def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
def get_highlight(self, res, keywords: list[str], field_name: str):
|
||||
ans = {}
|
||||
for d in res["hits"]["hits"]:
|
||||
highlights = d.get("highlight")
|
||||
if not highlights:
|
||||
continue
|
||||
txt = "...".join([a for a in list(highlights.items())[0][1]])
|
||||
if not is_english(txt.split()):
|
||||
ans[d["_id"]] = txt
|
||||
continue
|
||||
|
||||
txt = d["_source"][field_name]
|
||||
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
|
||||
txt_list = []
|
||||
for t in re.split(r"[.?!;\n]", txt):
|
||||
for w in keywords:
|
||||
t = re.sub(r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w), r"\1<em>\2</em>\3", t,
|
||||
flags=re.IGNORECASE | re.MULTILINE)
|
||||
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
|
||||
continue
|
||||
txt_list.append(t)
|
||||
ans[d["_id"]] = "...".join(txt_list) if txt_list else "...".join([a for a in list(highlights.items())[0][1]])
|
||||
|
||||
return ans
|
||||
|
||||
def get_aggregation(self, res, field_name: str):
|
||||
agg_field = "aggs_" + field_name
|
||||
if "aggregations" not in res or agg_field not in res["aggregations"]:
|
||||
return list()
|
||||
buckets = res["aggregations"][agg_field]["buckets"]
|
||||
return [(b["key"], b["doc_count"]) for b in buckets]
|
||||
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
|
||||
def sql(self, sql: str, fetch_size: int, format: str):
|
||||
self.logger.debug(f"ESConnection.sql get sql: {sql}")
|
||||
sql = re.sub(r"[ `]+", " ", sql)
|
||||
sql = sql.replace("%", "")
|
||||
replaces = []
|
||||
for r in re.finditer(r" ([a-z_]+_l?tks)( like | ?= ?)'([^']+)'", sql):
|
||||
fld, v = r.group(1), r.group(3)
|
||||
match = " MATCH({}, '{}', 'operator=OR;minimum_should_match=30%') ".format(
|
||||
fld, rag_tokenizer.fine_grained_tokenize(rag_tokenizer.tokenize(v)))
|
||||
replaces.append(
|
||||
("{}{}'{}'".format(
|
||||
r.group(1),
|
||||
r.group(2),
|
||||
r.group(3)),
|
||||
match))
|
||||
|
||||
for p, r in replaces:
|
||||
sql = sql.replace(p, r, 1)
|
||||
self.logger.debug(f"ESConnection.sql to es: {sql}")
|
||||
|
||||
for i in range(ATTEMPT_TIME):
|
||||
try:
|
||||
res = self.es.sql.query(body={"query": sql, "fetch_size": fetch_size}, format=format,
|
||||
request_timeout="2s")
|
||||
return res
|
||||
except ConnectionTimeout:
|
||||
self.logger.exception("ES request timeout")
|
||||
time.sleep(3)
|
||||
self._connect()
|
||||
continue
|
||||
except Exception as e:
|
||||
self.logger.exception(f"ESConnection.sql got exception. SQL:\n{sql}")
|
||||
raise Exception(f"SQL error: {e}\n\nSQL: {sql}")
|
||||
self.logger.error(f"ESConnection.sql timeout for {ATTEMPT_TIME} times!")
|
||||
return None
|
||||
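As a concrete illustration of the rewrite in `sql()` above, an equality or LIKE test on a tokenized `*_tks`/`*_ltks` column is turned into a full-text `MATCH` clause before the query is sent to Elasticsearch (tokenizer output shortened here):

```python
# Illustrative only: the effect of the regex rewrite in ESConnectionBase.sql().
sql_in = "SELECT doc_id FROM ragflow_t1 WHERE content_ltks = 'machine learning'"

# After rewriting, the equality test becomes a full-text MATCH clause:
sql_out = (
    "SELECT doc_id FROM ragflow_t1 WHERE "
    "MATCH(content_ltks, 'machine learning', 'operator=OR;minimum_should_match=30%')"
)
```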
451
common/doc_store/infinity_conn_base.py
Normal file
@ -0,0 +1,451 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
|
||||
import infinity
|
||||
from infinity.common import ConflictType
|
||||
from infinity.index import IndexInfo, IndexType
|
||||
from infinity.connection_pool import ConnectionPool
|
||||
from infinity.errors import ErrorCode
|
||||
import pandas as pd
|
||||
from common.file_utils import get_project_base_directory
|
||||
from rag.nlp import is_english
|
||||
from common import settings
|
||||
from common.doc_store.doc_store_base import DocStoreConnection, MatchExpr, OrderByExpr
|
||||
|
||||
|
||||
class InfinityConnectionBase(DocStoreConnection):
|
||||
def __init__(self, mapping_file_name: str="infinity_mapping.json", logger_name: str="ragflow.infinity_conn"):
|
||||
self.dbName = settings.INFINITY.get("db_name", "default_db")
|
||||
self.mapping_file_name = mapping_file_name
|
||||
self.logger = logging.getLogger(logger_name)
|
||||
infinity_uri = settings.INFINITY["uri"]
|
||||
if ":" in infinity_uri:
|
||||
host, port = infinity_uri.split(":")
|
||||
infinity_uri = infinity.common.NetworkAddress(host, int(port))
|
||||
self.connPool = None
|
||||
self.logger.info(f"Use Infinity {infinity_uri} as the doc engine.")
|
||||
for _ in range(24):
|
||||
try:
|
||||
conn_pool = ConnectionPool(infinity_uri, max_size=4)
|
||||
inf_conn = conn_pool.get_conn()
|
||||
res = inf_conn.show_current_node()
|
||||
if res.error_code == ErrorCode.OK and res.server_status in ["started", "alive"]:
|
||||
self._migrate_db(inf_conn)
|
||||
self.connPool = conn_pool
|
||||
conn_pool.release_conn(inf_conn)
|
||||
break
|
||||
conn_pool.release_conn(inf_conn)
|
||||
self.logger.warning(f"Infinity status: {res.server_status}. Waiting Infinity {infinity_uri} to be healthy.")
|
||||
time.sleep(5)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"{str(e)}. Waiting Infinity {infinity_uri} to be healthy.")
|
||||
time.sleep(5)
|
||||
if self.connPool is None:
|
||||
msg = f"Infinity {infinity_uri} is unhealthy in 120s."
|
||||
self.logger.error(msg)
|
||||
raise Exception(msg)
|
||||
self.logger.info(f"Infinity {infinity_uri} is healthy.")
|
||||
|
||||
def _migrate_db(self, inf_conn):
|
||||
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
|
||||
if not os.path.exists(fp_mapping):
|
||||
raise Exception(f"Mapping file not found at {fp_mapping}")
|
||||
schema = json.load(open(fp_mapping))
|
||||
table_names = inf_db.list_tables().table_names
|
||||
for table_name in table_names:
|
||||
inf_table = inf_db.get_table(table_name)
|
||||
index_names = inf_table.list_indexes().index_names
|
||||
if "q_vec_idx" not in index_names:
|
||||
# Skip tables not created by me
|
||||
continue
|
||||
column_names = inf_table.show_columns()["name"]
|
||||
column_names = set(column_names)
|
||||
for field_name, field_info in schema.items():
|
||||
if field_name in column_names:
|
||||
continue
|
||||
res = inf_table.add_columns({field_name: field_info})
|
||||
assert res.error_code == infinity.ErrorCode.OK
|
||||
self.logger.info(f"INFINITY added following column to table {table_name}: {field_name} {field_info}")
|
||||
if field_info["type"] != "varchar" or "analyzer" not in field_info:
|
||||
continue
|
||||
analyzers = field_info["analyzer"]
|
||||
if isinstance(analyzers, str):
|
||||
analyzers = [analyzers]
|
||||
for analyzer in analyzers:
|
||||
inf_table.create_index(
|
||||
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
|
||||
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
|
||||
"""
|
||||
Dataframe and fields convert
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def field_keyword(field_name: str):
|
||||
# judge keyword or not, such as "*_kwd" tag-like columns.
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def convert_select_fields(self, output_fields: list[str]) -> list[str]:
|
||||
# rm _kwd, _tks, _sm_tks, _with_weight suffix in field name.
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def convert_matching_field(field_weight_str: str) -> str:
|
||||
# convert a matching field expression to the column name used for full-text matching
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@staticmethod
|
||||
def list2str(lst: str | list, sep: str = " ") -> str:
|
||||
if isinstance(lst, str):
|
||||
return lst
|
||||
return sep.join(lst)
|
||||
|
||||
def equivalent_condition_to_str(self, condition: dict, table_instance=None) -> str | None:
|
||||
assert "_id" not in condition
|
||||
columns = {}
|
||||
if table_instance:
|
||||
for n, ty, de, _ in table_instance.show_columns().rows():
|
||||
columns[n] = (ty, de)
|
||||
|
||||
def exists(cln):
|
||||
nonlocal columns
|
||||
assert cln in columns, f"'{cln}' should be in '{columns}'."
|
||||
ty, de = columns[cln]
|
||||
if ty.lower().find("cha"):
|
||||
if not de:
|
||||
de = ""
|
||||
return f" {cln}!='{de}' "
|
||||
return f"{cln}!={de}"
|
||||
|
||||
cond = list()
|
||||
for k, v in condition.items():
|
||||
if not isinstance(k, str) or not v:
|
||||
continue
|
||||
if self.field_keyword(k):
|
||||
if isinstance(v, list):
|
||||
inCond = list()
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
item = item.replace("'", "''")
|
||||
inCond.append(f"filter_fulltext('{self.convert_matching_field(k)}', '{item}')")
|
||||
if inCond:
|
||||
strInCond = " or ".join(inCond)
|
||||
strInCond = f"({strInCond})"
|
||||
cond.append(strInCond)
|
||||
else:
|
||||
cond.append(f"filter_fulltext('{self.convert_matching_field(k)}', '{v}')")
|
||||
elif isinstance(v, list):
|
||||
inCond = list()
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
item = item.replace("'", "''")
|
||||
inCond.append(f"'{item}'")
|
||||
else:
|
||||
inCond.append(str(item))
|
||||
if inCond:
|
||||
strInCond = ", ".join(inCond)
|
||||
strInCond = f"{k} IN ({strInCond})"
|
||||
cond.append(strInCond)
|
||||
elif k == "must_not":
|
||||
if isinstance(v, dict):
|
||||
for kk, vv in v.items():
|
||||
if kk == "exists":
|
||||
cond.append("NOT (%s)" % exists(vv))
|
||||
elif isinstance(v, str):
|
||||
cond.append(f"{k}='{v}'")
|
||||
elif k == "exists":
|
||||
cond.append(exists(v))
|
||||
else:
|
||||
cond.append(f"{k}={str(v)}")
|
||||
return " AND ".join(cond) if cond else "1=1"
|
||||
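To make the string-building above easier to follow, here is what the helper is expected to produce for ordinary (non-keyword) fields; the field names are made up.

```python
# Illustrative only: the filter string equivalent_condition_to_str() builds
# for plain fields (field_keyword() assumed to return False for them).
condition = {"doc_id": ["d1", "d2"], "available_int": 1}
# -> "doc_id IN ('d1', 'd2') AND available_int=1"

# An empty condition falls back to a match-all filter:
# {} -> "1=1"
```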
|
||||
@staticmethod
|
||||
def concat_dataframes(df_list: list[pd.DataFrame], select_fields: list[str]) -> pd.DataFrame:
|
||||
df_list2 = [df for df in df_list if not df.empty]
|
||||
if df_list2:
|
||||
return pd.concat(df_list2, axis=0).reset_index(drop=True)
|
||||
|
||||
schema = []
|
||||
for field_name in select_fields:
|
||||
if field_name == "score()": # Workaround: fix schema is changed to score()
|
||||
schema.append("SCORE")
|
||||
elif field_name == "similarity()": # Workaround: fix schema is changed to similarity()
|
||||
schema.append("SIMILARITY")
|
||||
else:
|
||||
schema.append(field_name)
|
||||
return pd.DataFrame(columns=schema)
|
||||
|
||||
"""
|
||||
Database operations
|
||||
"""
|
||||
|
||||
def db_type(self) -> str:
|
||||
return "infinity"
|
||||
|
||||
def health(self) -> dict:
|
||||
"""
|
||||
Return the health status of the database.
|
||||
"""
|
||||
inf_conn = self.connPool.get_conn()
|
||||
res = inf_conn.show_current_node()
|
||||
self.connPool.release_conn(inf_conn)
|
||||
res2 = {
|
||||
"type": "infinity",
|
||||
"status": "green" if res.error_code == 0 and res.server_status in ["started", "alive"] else "red",
|
||||
"error": res.error_msg,
|
||||
}
|
||||
return res2
|
||||
|
||||
"""
|
||||
Table operations
|
||||
"""
|
||||
|
||||
def create_idx(self, index_name: str, dataset_id: str, vector_size: int):
|
||||
table_name = f"{index_name}_{dataset_id}"
|
||||
inf_conn = self.connPool.get_conn()
|
||||
inf_db = inf_conn.create_database(self.dbName, ConflictType.Ignore)
|
||||
|
||||
fp_mapping = os.path.join(get_project_base_directory(), "conf", self.mapping_file_name)
|
||||
if not os.path.exists(fp_mapping):
|
||||
raise Exception(f"Mapping file not found at {fp_mapping}")
|
||||
schema = json.load(open(fp_mapping))
|
||||
vector_name = f"q_{vector_size}_vec"
|
||||
schema[vector_name] = {"type": f"vector,{vector_size},float"}
|
||||
inf_table = inf_db.create_table(
|
||||
table_name,
|
||||
schema,
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
inf_table.create_index(
|
||||
"q_vec_idx",
|
||||
IndexInfo(
|
||||
vector_name,
|
||||
IndexType.Hnsw,
|
||||
{
|
||||
"M": "16",
|
||||
"ef_construction": "50",
|
||||
"metric": "cosine",
|
||||
"encode": "lvq",
|
||||
},
|
||||
),
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
for field_name, field_info in schema.items():
|
||||
if field_info["type"] != "varchar" or "analyzer" not in field_info:
|
||||
continue
|
||||
analyzers = field_info["analyzer"]
|
||||
if isinstance(analyzers, str):
|
||||
analyzers = [analyzers]
|
||||
for analyzer in analyzers:
|
||||
inf_table.create_index(
|
||||
f"ft_{re.sub(r'[^a-zA-Z0-9]', '_', field_name)}_{re.sub(r'[^a-zA-Z0-9]', '_', analyzer)}",
|
||||
IndexInfo(field_name, IndexType.FullText, {"ANALYZER": analyzer}),
|
||||
ConflictType.Ignore,
|
||||
)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
self.logger.info(f"INFINITY created table {table_name}, vector size {vector_size}")
|
||||
return True
|
||||
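For orientation, a single call like the one below (identifiers made up, `conn` assumed to be an instance of a concrete `InfinityConnectionBase` subclass) creates one table per dataset named `<index_name>_<dataset_id>`, with an HNSW index on the vector column and one full-text index per analyzer:

```python
# Illustrative only; "conn" is an instance of a concrete InfinityConnectionBase subclass.
conn.create_idx("ragflow_tenant_1", "kb_42", 1024)
# -> table "ragflow_tenant_1_kb_42" with the columns from the mapping file,
#    a "q_1024_vec" HNSW (cosine) index, and full-text indexes per analyzer.
```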
|
||||
def delete_idx(self, index_name: str, dataset_id: str):
|
||||
table_name = f"{index_name}_{dataset_id}"
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
db_instance.drop_table(table_name, ConflictType.Ignore)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
self.logger.info(f"INFINITY dropped table {table_name}")
|
||||
|
||||
def index_exist(self, index_name: str, dataset_id: str) -> bool:
|
||||
table_name = f"{index_name}_{dataset_id}"
|
||||
try:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
_ = db_instance.get_table(table_name)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
return True
|
||||
except Exception as e:
|
||||
self.logger.warning(f"INFINITY indexExist {str(e)}")
|
||||
return False
|
||||
|
||||
"""
|
||||
CRUD operations
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
select_fields: list[str],
|
||||
highlight_fields: list[str],
|
||||
condition: dict,
|
||||
match_expressions: list[MatchExpr],
|
||||
order_by: OrderByExpr,
|
||||
offset: int,
|
||||
limit: int,
|
||||
index_names: str | list[str],
|
||||
dataset_ids: list[str],
|
||||
agg_fields: list[str] | None = None,
|
||||
rank_feature: dict | None = None,
|
||||
) -> tuple[pd.DataFrame, int]:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def get(self, doc_id: str, index_name: str, knowledgebase_ids: list[str]) -> dict | None:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, documents: list[dict], index_name: str, dataset_ids: str = None) -> list[str]:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
@abstractmethod
|
||||
def update(self, condition: dict, new_value: dict, index_name: str, dataset_id: str) -> bool:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
def delete(self, condition: dict, index_name: str, dataset_id: str) -> int:
|
||||
inf_conn = self.connPool.get_conn()
|
||||
db_instance = inf_conn.get_database(self.dbName)
|
||||
table_name = f"{index_name}_{dataset_id}"
|
||||
try:
|
||||
table_instance = db_instance.get_table(table_name)
|
||||
except Exception:
|
||||
self.logger.warning(f"Skipped deleting from table {table_name} since the table doesn't exist.")
|
||||
return 0
|
||||
filter = self.equivalent_condition_to_str(condition, table_instance)
|
||||
self.logger.debug(f"INFINITY delete table {table_name}, filter {filter}.")
|
||||
res = table_instance.delete(filter)
|
||||
self.connPool.release_conn(inf_conn)
|
||||
return res.deleted_rows
|
||||
|
||||
"""
|
||||
Helper functions for search result
|
||||
"""
|
||||
|
||||
def get_total(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> int:
|
||||
if isinstance(res, tuple):
|
||||
return res[1]
|
||||
return len(res)
|
||||
|
||||
def get_doc_ids(self, res: tuple[pd.DataFrame, int] | pd.DataFrame) -> list[str]:
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
return list(res["id"])
|
||||
|
||||
@abstractmethod
|
||||
def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
def get_highlight(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, keywords: list[str], field_name: str):
|
||||
if isinstance(res, tuple):
|
||||
res = res[0]
|
||||
ans = {}
|
||||
num_rows = len(res)
|
||||
column_id = res["id"]
|
||||
if field_name not in res:
|
||||
return {}
|
||||
for i in range(num_rows):
|
||||
id = column_id[i]
|
||||
txt = res[field_name][i]
|
||||
if re.search(r"<em>[^<>]+</em>", txt, flags=re.IGNORECASE | re.MULTILINE):
|
||||
ans[id] = txt
|
||||
continue
|
||||
txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE)
|
||||
txt_list = []
|
||||
for t in re.split(r"[.?!;\n]", txt):
|
||||
if is_english([t]):
|
||||
for w in keywords:
|
||||
t = re.sub(
|
||||
r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" % re.escape(w),
|
||||
r"\1<em>\2</em>\3",
|
||||
t,
|
||||
flags=re.IGNORECASE | re.MULTILINE,
|
||||
)
|
||||
else:
|
||||
for w in sorted(keywords, key=len, reverse=True):
|
||||
t = re.sub(
|
||||
re.escape(w),
|
||||
f"<em>{w}</em>",
|
||||
t,
|
||||
flags=re.IGNORECASE | re.MULTILINE,
|
||||
)
|
||||
if not re.search(r"<em>[^<>]+</em>", t, flags=re.IGNORECASE | re.MULTILINE):
|
||||
continue
|
||||
txt_list.append(t)
|
||||
if txt_list:
|
||||
ans[id] = "...".join(txt_list)
|
||||
else:
|
||||
ans[id] = txt
|
||||
return ans
|
||||
|
||||
def get_aggregation(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, field_name: str):
|
||||
"""
|
||||
Manual aggregation for tag fields since Infinity doesn't provide native aggregation
|
||||
"""
|
||||
from collections import Counter
|
||||
|
||||
# Extract DataFrame from result
|
||||
if isinstance(res, tuple):
|
||||
df, _ = res
|
||||
else:
|
||||
df = res
|
||||
|
||||
if df.empty or field_name not in df.columns:
|
||||
return []
|
||||
|
||||
# Aggregate tag counts
|
||||
tag_counter = Counter()
|
||||
|
||||
for value in df[field_name]:
|
||||
if pd.isna(value) or not value:
|
||||
continue
|
||||
|
||||
# Handle different tag formats
|
||||
if isinstance(value, str):
|
||||
# Split by ### for tag_kwd field or comma for other formats
|
||||
if field_name == "tag_kwd" and "###" in value:
|
||||
tags = [tag.strip() for tag in value.split("###") if tag.strip()]
|
||||
else:
|
||||
# Try comma separation as fallback
|
||||
tags = [tag.strip() for tag in value.split(",") if tag.strip()]
|
||||
|
||||
for tag in tags:
|
||||
if tag: # Only count non-empty tags
|
||||
tag_counter[tag] += 1
|
||||
elif isinstance(value, list):
|
||||
# Handle list format
|
||||
for tag in value:
|
||||
if tag and isinstance(tag, str):
|
||||
tag_counter[tag.strip()] += 1
|
||||
|
||||
# Return as list of [tag, count] pairs, sorted by count descending
|
||||
return [[tag, count] for tag, count in tag_counter.most_common()]
|
||||
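A small worked example of the manual aggregation above, assuming the `tag_kwd` column uses the `###` separator:

```python
# Illustrative only: what the manual tag aggregation returns for a small result frame.
import pandas as pd

df = pd.DataFrame({"id": ["c1", "c2"], "tag_kwd": ["ai###rag", "rag"]})
# get_aggregation((df, 2), "tag_kwd") -> [["rag", 2], ["ai", 1]]
```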
|
||||
"""
|
||||
SQL
|
||||
"""
|
||||
|
||||
def sql(self, sql: str, fetch_size: int, format: str):
|
||||
raise NotImplementedError("Not implemented")
|
||||
@ -16,7 +16,7 @@ import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
from common import settings
|
||||
import httpx
|
||||
@ -58,21 +58,34 @@ def _get_delay(backoff_factor: float, attempt: int) -> float:
|
||||
_SENSITIVE_QUERY_KEYS = {"client_secret", "secret", "code", "access_token", "refresh_token", "password", "token", "app_secret"}
|
||||
|
||||
def _redact_sensitive_url_params(url: str) -> str:
|
||||
"""
|
||||
Return a version of the URL that is safe to log.
|
||||
|
||||
We intentionally drop query parameters and userinfo to avoid leaking
|
||||
credentials or tokens via logs. Only scheme, host, port and path
|
||||
are preserved.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
if not parsed.query:
|
||||
return url
|
||||
clean_query = []
|
||||
for k, v in parse_qsl(parsed.query, keep_blank_values=True):
|
||||
if k.lower() in _SENSITIVE_QUERY_KEYS:
|
||||
clean_query.append((k, "***REDACTED***"))
|
||||
else:
|
||||
clean_query.append((k, v))
|
||||
new_query = urlencode(clean_query, doseq=True)
|
||||
redacted_url = urlunparse(parsed._replace(query=new_query))
|
||||
return redacted_url
|
||||
# Remove any potential userinfo (username:password@)
|
||||
netloc = parsed.hostname or ""
|
||||
if parsed.port:
|
||||
netloc = f"{netloc}:{parsed.port}"
|
||||
# Reconstruct URL without query, params, fragment, or userinfo.
|
||||
safe_url = urlunparse(
|
||||
(
|
||||
parsed.scheme,
|
||||
netloc,
|
||||
parsed.path,
|
||||
"", # params
|
||||
"", # query
|
||||
"", # fragment
|
||||
)
|
||||
)
|
||||
return safe_url
|
||||
except Exception:
|
||||
return url
|
||||
# If parsing fails, fall back to omitting the URL entirely.
|
||||
return "<redacted-url>"
|
||||
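Since this hunk interleaves the previous behavior (redacting individual query values) with the new one (dropping query, fragment, and userinfo entirely), here is a standalone sketch of the new behavior; it is a re-implementation for illustration, not the project's own function.

```python
# Standalone sketch of the new redaction behavior (illustration only).
from urllib.parse import urlparse, urlunparse

def redact_for_logging(url: str) -> str:
    parsed = urlparse(url)
    netloc = parsed.hostname or ""
    if parsed.port:
        netloc = f"{netloc}:{parsed.port}"
    # Keep scheme, host, port and path; drop params, query, fragment, and userinfo.
    return urlunparse((parsed.scheme, netloc, parsed.path, "", "", ""))

print(redact_for_logging("https://user:secret@idp.example.com/oauth/token?code=abc&state=x"))
# -> https://idp.example.com/oauth/token
```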
|
||||
def _is_sensitive_url(url: str) -> bool:
|
||||
"""Return True if URL is one of the configured OAuth endpoints."""
|
||||
@ -144,23 +157,28 @@ async def async_request(
|
||||
method=method, url=url, headers=headers, **kwargs
|
||||
)
|
||||
duration = time.monotonic() - start
|
||||
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url(url) else _redact_sensitive_url_params(url)
|
||||
logger.debug(
|
||||
f"async_request {method} {log_url} -> {response.status_code} in {duration:.3f}s"
|
||||
)
|
||||
if not _is_sensitive_url(url):
|
||||
log_url = _redact_sensitive_url_params(url)
|
||||
logger.debug(f"async_request {method} {log_url} -> {response.status_code} in {duration:.3f}s")
|
||||
return response
|
||||
except httpx.RequestError as exc:
|
||||
last_exc = exc
|
||||
if attempt >= retries:
|
||||
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url(url) else _redact_sensitive_url_params(url)
|
||||
if not _is_sensitive_url(url):
|
||||
log_url = _redact_sensitive_url_params(url)
|
||||
logger.warning(f"async_request exhausted retries for {method}")
|
||||
raise
|
||||
delay = _get_delay(backoff_factor, attempt)
|
||||
if not _is_sensitive_url(url):
|
||||
log_url = _redact_sensitive_url_params(url)
|
||||
logger.warning(
|
||||
f"async_request exhausted retries for {method} {log_url}"
|
||||
f"async_request attempt {attempt + 1}/{retries + 1} failed for {method}; retrying in {delay:.2f}s"
|
||||
)
|
||||
raise
|
||||
delay = _get_delay(backoff_factor, attempt)
|
||||
log_url = "<SENSITIVE ENDPOINT>" if _is_sensitive_url(url) else _redact_sensitive_url_params(url)
|
||||
# Avoid including the (potentially sensitive) URL in retry logs.
|
||||
logger.warning(
|
||||
f"async_request attempt {attempt + 1}/{retries + 1} failed for {method} {log_url}; retrying in {delay:.2f}s"
|
||||
f"async_request attempt {attempt + 1}/{retries + 1} failed for {method}; retrying in {delay:.2f}s"
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
raise last_exc # pragma: no cover
|
||||
|
||||
@ -75,9 +75,12 @@ def init_root_logger(logfile_basename: str, log_format: str = "%(asctime)-15s %(
|
||||
def log_exception(e, *args):
|
||||
logging.exception(e)
|
||||
for a in args:
|
||||
if hasattr(a, "text"):
|
||||
logging.error(a.text)
|
||||
raise Exception(a.text)
|
||||
else:
|
||||
logging.error(str(a))
|
||||
try:
|
||||
text = getattr(a, "text")
|
||||
except Exception:
|
||||
text = None
|
||||
if text is not None:
|
||||
logging.error(text)
|
||||
raise Exception(text)
|
||||
logging.error(str(a))
|
||||
raise e
|
||||
|
||||
@ -44,21 +44,27 @@ def meta_filter(metas: dict, filters: list[dict], logic: str = "and"):
|
||||
def filter_out(v2docs, operator, value):
|
||||
ids = []
|
||||
for input, docids in v2docs.items():
|
||||
|
||||
if operator in ["=", "≠", ">", "<", "≥", "≤"]:
|
||||
try:
|
||||
if isinstance(input, list):
|
||||
input = input[0]
|
||||
input = float(input)
|
||||
value = float(value)
|
||||
except Exception:
|
||||
input = str(input)
|
||||
value = str(value)
|
||||
pass
|
||||
if isinstance(input, str):
|
||||
input = input.lower()
|
||||
if isinstance(value, str):
|
||||
value = value.lower()
|
||||
|
||||
for conds in [
|
||||
(operator == "contains", str(value).lower() in str(input).lower()),
|
||||
(operator == "not contains", str(value).lower() not in str(input).lower()),
|
||||
(operator == "in", str(input).lower() in str(value).lower()),
|
||||
(operator == "not in", str(input).lower() not in str(value).lower()),
|
||||
(operator == "start with", str(input).lower().startswith(str(value).lower())),
|
||||
(operator == "end with", str(input).lower().endswith(str(value).lower())),
|
||||
(operator == "contains", input in value if not isinstance(input, list) else all([i in value for i in input])),
|
||||
(operator == "not contains", input not in value if not isinstance(input, list) else all([i not in value for i in input])),
|
||||
(operator == "in", input in value if not isinstance(input, list) else all([i in value for i in input])),
|
||||
(operator == "not in", input not in value if not isinstance(input, list) else all([i not in value for i in input])),
|
||||
(operator == "start with", str(input).lower().startswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).startswith(str(value).lower())),
|
||||
(operator == "end with", str(input).lower().endswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).endswith(str(value).lower())),
|
||||
(operator == "empty", not input),
|
||||
(operator == "not empty", input),
|
||||
(operator == "=", input == value),
|
||||
@ -145,6 +151,18 @@ async def apply_meta_data_filter(
|
||||
return doc_ids
|
||||
|
||||
|
||||
def dedupe_list(values: list) -> list:
|
||||
seen = set()
|
||||
deduped = []
|
||||
for item in values:
|
||||
key = str(item)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
deduped.append(item)
|
||||
return deduped
|
||||
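One subtlety of the helper above worth noting: keys are stringified before comparison, so numerically and textually equal values collapse into a single entry while the first occurrence keeps its original type.

```python
# Usage of dedupe_list() as defined above (illustration).
print(dedupe_list(["ai", "rag", "ai", 1, "1"]))
# -> ['ai', 'rag', 1]
```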
|
||||
|
||||
def update_metadata_to(metadata, meta):
|
||||
if not meta:
|
||||
return metadata
|
||||
@ -156,11 +174,13 @@ def update_metadata_to(metadata, meta):
|
||||
return metadata
|
||||
if not isinstance(meta, dict):
|
||||
return metadata
|
||||
|
||||
for k, v in meta.items():
|
||||
if isinstance(v, list):
|
||||
v = [vv for vv in v if isinstance(vv, str)]
|
||||
if not v:
|
||||
continue
|
||||
v = dedupe_list(v)
|
||||
if not isinstance(v, list) and not isinstance(v, str):
|
||||
continue
|
||||
if k not in metadata:
|
||||
@ -171,6 +191,7 @@ def update_metadata_to(metadata, meta):
|
||||
metadata[k].extend(v)
|
||||
else:
|
||||
metadata[k].append(v)
|
||||
metadata[k] = dedupe_list(metadata[k])
|
||||
else:
|
||||
metadata[k] = v
|
||||
|
||||
@ -202,4 +223,4 @@ def metadata_schema(metadata: list|None) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
json_schema["additionalProperties"] = False
|
||||
return json_schema
|
||||
return json_schema
|
||||
|
||||
72
common/query_base.py
Normal file
@ -0,0 +1,72 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class QueryBase(ABC):
|
||||
|
||||
@staticmethod
|
||||
def is_chinese(line):
|
||||
arr = re.split(r"[ \t]+", line)
|
||||
if len(arr) <= 3:
|
||||
return True
|
||||
e = 0
|
||||
for t in arr:
|
||||
if not re.match(r"[a-zA-Z]+$", t):
|
||||
e += 1
|
||||
return e * 1.0 / len(arr) >= 0.7
|
||||
|
||||
@staticmethod
|
||||
def sub_special_char(line):
|
||||
return re.sub(r"([:\{\}/\[\]\-\*\"\(\)\|\+~\^])", r"\\\1", line).strip()
|
||||
|
||||
@staticmethod
|
||||
def rmWWW(txt):
|
||||
patts = [
|
||||
(
|
||||
r"是*(怎么办|什么样的|哪家|一下|那家|请问|啥样|咋样了|什么时候|何时|何地|何人|是否|是不是|多少|哪里|怎么|哪儿|怎么样|如何|哪些|是啥|啥是|啊|吗|呢|吧|咋|什么|有没有|呀|谁|哪位|哪个)是*",
|
||||
"",
|
||||
),
|
||||
(r"(^| )(what|who|how|which|where|why)('re|'s)? ", " "),
|
||||
(
|
||||
r"(^| )('s|'re|is|are|were|was|do|does|did|don't|doesn't|didn't|has|have|be|there|you|me|your|my|mine|just|please|may|i|should|would|wouldn't|will|won't|done|go|for|with|so|the|a|an|by|i'm|it's|he's|she's|they|they're|you're|as|by|on|in|at|up|out|down|of|to|or|and|if) ",
|
||||
" ")
|
||||
]
|
||||
otxt = txt
|
||||
for r, p in patts:
|
||||
txt = re.sub(r, p, txt, flags=re.IGNORECASE)
|
||||
if not txt:
|
||||
txt = otxt
|
||||
return txt
|
||||
|
||||
@staticmethod
|
||||
def add_space_between_eng_zh(txt):
|
||||
# (ENG/ENG+NUM) + ZH
|
||||
txt = re.sub(r'([A-Za-z]+[0-9]+)([\u4e00-\u9fa5]+)', r'\1 \2', txt)
|
||||
# ENG + ZH
|
||||
txt = re.sub(r'([A-Za-z])([\u4e00-\u9fa5]+)', r'\1 \2', txt)
|
||||
# ZH + (ENG/ENG+NUM)
|
||||
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z]+[0-9]+)', r'\1 \2', txt)
|
||||
txt = re.sub(r'([\u4e00-\u9fa5]+)([A-Za-z])', r'\1 \2', txt)
|
||||
return txt
|
||||
|
||||
@abstractmethod
|
||||
def question(self, text, tbl, min_match):
|
||||
"""
|
||||
Returns a query object based on the input text, table, and minimum match criteria.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented")
|
||||
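A brief illustration of the static helpers collected in `QueryBase`; it assumes the repository root is on `PYTHONPATH` and only restates behavior visible in the code above.

```python
# Illustrative usage of the QueryBase static helpers (common/query_base.py).
from common.query_base import QueryBase

print(QueryBase.is_chinese("RAGFlow 是 什么"))            # True: short, mostly non-English token list
print(QueryBase.sub_special_char('title:"RAG*"'))          # escapes Lucene-style special characters
print(QueryBase.rmWWW("what is RAGFlow"))                  # strips interrogative stop words
print(QueryBase.add_space_between_eng_zh("RAGFlow支持MinerU解析"))
# -> "RAGFlow 支持 MinerU 解析"
```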
@ -39,6 +39,9 @@ from rag.utils.oss_conn import RAGFlowOSS
|
||||
|
||||
from rag.nlp import search
|
||||
|
||||
import memory.utils.es_conn as memory_es_conn
|
||||
import memory.utils.infinity_conn as memory_infinity_conn
|
||||
|
||||
LLM = None
|
||||
LLM_FACTORY = None
|
||||
LLM_BASE_URL = None
|
||||
@ -76,9 +79,11 @@ FEISHU_OAUTH = None
|
||||
OAUTH_CONFIG = None
|
||||
DOC_ENGINE = os.getenv('DOC_ENGINE', 'elasticsearch')
|
||||
DOC_ENGINE_INFINITY = (DOC_ENGINE.lower() == "infinity")
|
||||
MSG_ENGINE = DOC_ENGINE
|
||||
|
||||
|
||||
docStoreConn = None
|
||||
msgStoreConn = None
|
||||
|
||||
retriever = None
|
||||
kg_retriever = None
|
||||
@ -256,6 +261,15 @@ def init_settings():
|
||||
else:
|
||||
raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
|
||||
|
||||
global MSG_ENGINE, msgStoreConn
|
||||
MSG_ENGINE = DOC_ENGINE # use the same engine for message store
|
||||
if MSG_ENGINE == "elasticsearch":
|
||||
ES = get_base_config("es", {})
|
||||
msgStoreConn = memory_es_conn.ESConnection()
|
||||
elif MSG_ENGINE == "infinity":
|
||||
INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
|
||||
msgStoreConn = memory_infinity_conn.InfinityConnection()
|
||||
|
||||
global AZURE, S3, MINIO, OSS, GCS
|
||||
if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
|
||||
AZURE = get_base_config("azure", {})
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
|
||||
def current_timestamp():
|
||||
@ -123,4 +124,31 @@ def delta_seconds(date_string: str):
|
||||
3600.0 # If current time is 2024-01-01 13:00:00
|
||||
"""
|
||||
dt = datetime.datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
|
||||
return (datetime.datetime.now() - dt).total_seconds()
|
||||
return (datetime.datetime.now() - dt).total_seconds()
|
||||
|
||||
|
||||
def format_iso_8601_to_ymd_hms(time_str: str) -> str:
|
||||
"""
|
||||
Convert ISO 8601 formatted string to "YYYY-MM-DD HH:MM:SS" format.
|
||||
|
||||
Args:
|
||||
time_str: ISO 8601 date string (e.g. "2024-01-01T12:00:00Z")
|
||||
|
||||
Returns:
|
||||
str: Date string in "YYYY-MM-DD HH:MM:SS" format
|
||||
|
||||
Example:
|
||||
>>> format_iso_8601_to_ymd_hms("2024-01-01T12:00:00Z")
|
||||
'2024-01-01 12:00:00'
|
||||
"""
|
||||
from dateutil import parser
|
||||
|
||||
try:
|
||||
if parser.isoparse(time_str):
|
||||
dt = datetime.datetime.fromisoformat(time_str.replace("Z", "+00:00"))
|
||||
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
return time_str
|
||||
except Exception as e:
|
||||
logging.error(str(e))
|
||||
return time_str
|
||||
|
||||
@ -44,17 +44,23 @@ def total_token_count_from_response(resp):
|
||||
if resp is None:
|
||||
return 0
|
||||
|
||||
if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
|
||||
try:
|
||||
try:
|
||||
if hasattr(resp, "usage") and hasattr(resp.usage, "total_tokens"):
|
||||
return resp.usage.total_tokens
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
|
||||
try:
|
||||
try:
|
||||
if hasattr(resp, "usage_metadata") and hasattr(resp.usage_metadata, "total_tokens"):
|
||||
return resp.usage_metadata.total_tokens
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
if hasattr(resp, "meta") and hasattr(resp.meta, "billed_units") and hasattr(resp.meta.billed_units, "input_tokens"):
|
||||
return resp.meta.billed_units.input_tokens
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if isinstance(resp, dict) and 'usage' in resp and 'total_tokens' in resp['usage']:
|
||||
try:
|
||||
@ -79,4 +85,3 @@ def total_token_count_from_response(resp):
|
||||
def truncate(string: str, max_len: int) -> str:
|
||||
"""Returns truncated text if the length of text exceed max_len."""
|
||||
return encoder.decode(encoder.encode(string)[:max_len])
|
||||
|
||||
|
||||
@ -31,6 +31,7 @@
|
||||
"entity_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
|
||||
"source_id": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
|
||||
"n_hop_with_weight": {"type": "varchar", "default": ""},
|
||||
"mom_with_weight": {"type": "varchar", "default": ""},
|
||||
"removed_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
|
||||
"doc_type_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
|
||||
"toc_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
|
||||
|
||||
@ -762,6 +762,13 @@
|
||||
"status": "1",
|
||||
"rank": "940",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "glm-4.7",
|
||||
"tags": "LLM,CHAT,128K",
|
||||
"max_tokens": 128000,
|
||||
"model_type": "chat",
|
||||
"is_tools": true
|
||||
},
|
||||
{
|
||||
"llm_name": "glm-4.5",
|
||||
"tags": "LLM,CHAT,128K",
|
||||
@ -1251,6 +1258,12 @@
|
||||
"status": "1",
|
||||
"rank": "810",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "MiniMax-M2.1",
|
||||
"tags": "LLM,CHAT,200k",
|
||||
"max_tokens": 200000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "MiniMax-M2",
|
||||
"tags": "LLM,CHAT,200k",
|
||||
|
||||
19
conf/message_infinity_mapping.json
Normal file
@ -0,0 +1,19 @@
|
||||
{
|
||||
"id": {"type": "varchar", "default": ""},
|
||||
"message_id": {"type": "integer", "default": 0},
|
||||
"message_type_kwd": {"type": "varchar", "default": ""},
|
||||
"source_id": {"type": "integer", "default": 0},
|
||||
"memory_id": {"type": "varchar", "default": ""},
|
||||
"user_id": {"type": "varchar", "default": ""},
|
||||
"agent_id": {"type": "varchar", "default": ""},
|
||||
"session_id": {"type": "varchar", "default": ""},
|
||||
"valid_at": {"type": "varchar", "default": ""},
|
||||
"valid_at_flt": {"type": "float", "default": 0.0},
|
||||
"invalid_at": {"type": "varchar", "default": ""},
|
||||
"invalid_at_flt": {"type": "float", "default": 0.0},
|
||||
"forget_at": {"type": "varchar", "default": ""},
|
||||
"forget_at_flt": {"type": "float", "default": 0.0},
|
||||
"status_int": {"type": "integer", "default": 1},
|
||||
"zone_id": {"type": "integer", "default": 0},
|
||||
"content": {"type": "varchar", "default": "", "analyzer": ["rag-coarse", "rag-fine"], "comment": "content_ltks"}
|
||||
}
|
||||
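Like the doc-store mapping, this file is consumed at table-creation time. The sketch below shows how such a mapping would be loaded and extended with a vector column, mirroring `create_idx()` in the Infinity connector above; the embedding size is an assumption.

```python
# Sketch: loading an Infinity mapping file and appending a vector column,
# mirroring InfinityConnectionBase.create_idx(); the size 1024 is illustrative.
import json

with open("conf/message_infinity_mapping.json") as f:
    schema = json.load(f)

vector_size = 1024
schema[f"q_{vector_size}_vec"] = {"type": f"vector,{vector_size},float"}
print(sorted(schema.keys()))
```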
@ -18,6 +18,7 @@ from io import BytesIO
|
||||
|
||||
import pandas as pd
|
||||
from openpyxl import Workbook, load_workbook
|
||||
from PIL import Image
|
||||
|
||||
from rag.nlp import find_codec
|
||||
|
||||
@ -109,6 +110,52 @@ class RAGFlowExcelParser:
|
||||
ws.cell(row=row_num, column=col_num, value=value)
|
||||
return wb
|
||||
|
||||
@staticmethod
|
||||
def _extract_images_from_worksheet(ws, sheetname=None):
|
||||
"""
|
||||
Extract images from a worksheet and enrich them with vision-based descriptions.
|
||||
|
||||
Returns: List[dict]
|
||||
"""
|
||||
images = getattr(ws, "_images", [])
|
||||
if not images:
|
||||
return []
|
||||
|
||||
raw_items = []
|
||||
|
||||
for img in images:
|
||||
try:
|
||||
img_bytes = img._data()
|
||||
pil_img = Image.open(BytesIO(img_bytes)).convert("RGB")
|
||||
|
||||
anchor = img.anchor
|
||||
if hasattr(anchor, "_from") and hasattr(anchor, "_to"):
|
||||
r1, c1 = anchor._from.row + 1, anchor._from.col + 1
|
||||
r2, c2 = anchor._to.row + 1, anchor._to.col + 1
|
||||
if r1 == r2 and c1 == c2:
|
||||
span = "single_cell"
|
||||
else:
|
||||
span = "multi_cell"
|
||||
else:
|
||||
r1, c1 = anchor._from.row + 1, anchor._from.col + 1
|
||||
r2, c2 = r1, c1
|
||||
span = "single_cell"
|
||||
|
||||
item = {
|
||||
"sheet": sheetname or ws.title,
|
||||
"image": pil_img,
|
||||
"image_description": "",
|
||||
"row_from": r1,
|
||||
"col_from": c1,
|
||||
"row_to": r2,
|
||||
"col_to": c2,
|
||||
"span_type": span,
|
||||
}
|
||||
raw_items.append(item)
|
||||
except Exception:
|
||||
continue
|
||||
return raw_items
|
||||
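A hedged usage sketch of the new helper follows; the module path and the workbook name are assumptions, not shown in this hunk.

```python
# Hypothetical usage of _extract_images_from_worksheet (module path assumed).
from openpyxl import load_workbook
from deepdoc.parser.excel_parser import RAGFlowExcelParser  # assumed import path

wb = load_workbook("report.xlsx")  # hypothetical workbook
for ws in wb.worksheets:
    for item in RAGFlowExcelParser._extract_images_from_worksheet(ws, sheetname=ws.title):
        print(item["sheet"], item["span_type"], item["row_from"], item["col_from"])
```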
|
||||
def html(self, fnm, chunk_rows=256):
|
||||
from html import escape
|
||||
|
||||
|
||||
@ -38,8 +38,8 @@ def vision_figure_parser_figure_data_wrapper(figures_data_without_positions):
|
||||
|
||||
|
||||
def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs):
|
||||
if not tbls:
|
||||
return []
|
||||
if not sections:
|
||||
return tbls
|
||||
try:
|
||||
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
|
||||
callback(0.7, "Visual model detected. Attempting to enhance figure extraction...")
|
||||
@ -55,6 +55,31 @@ def vision_figure_parser_docx_wrapper(sections, tbls, callback=None,**kwargs):
|
||||
callback(0.8, f"Visual model error: {e}. Skipping figure parsing enhancement.")
|
||||
return tbls
|
||||
|
||||
def vision_figure_parser_figure_xlsx_wrapper(images,callback=None, **kwargs):
|
||||
tbls = []
|
||||
if not images:
|
||||
return []
|
||||
try:
|
||||
vision_model = LLMBundle(kwargs["tenant_id"], LLMType.IMAGE2TEXT)
|
||||
callback(0.2, "Visual model detected. Attempting to enhance Excel image extraction...")
|
||||
except Exception:
|
||||
vision_model = None
|
||||
if vision_model:
|
||||
figures_data = [((
|
||||
img["image"], # Image.Image
|
||||
[img["image_description"]] # description list (must be list)
|
||||
),
|
||||
[
|
||||
(0, 0, 0, 0, 0) # dummy position
|
||||
]) for img in images]
|
||||
try:
|
||||
parser = VisionFigureParser(vision_model=vision_model, figures_data=figures_data, **kwargs)
|
||||
callback(0.22, "Parsing images...")
|
||||
boosted_figures = parser(callback=callback)
|
||||
tbls.extend(boosted_figures)
|
||||
except Exception as e:
|
||||
callback(0.25, f"Excel visual model error: {e}. Skipping vision enhancement.")
|
||||
return tbls
|
||||
|
||||
def vision_figure_parser_pdf_wrapper(tbls, callback=None, **kwargs):
|
||||
if not tbls:
|
||||
|
||||
@ -511,7 +511,7 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
for output in outputs:
|
||||
match output["type"]:
|
||||
case MinerUContentType.TEXT:
|
||||
section = output["text"]
|
||||
section = output.get("text", "")
|
||||
case MinerUContentType.TABLE:
|
||||
section = output.get("table_body", "") + "\n".join(output.get("table_caption", [])) + "\n".join(
|
||||
output.get("table_footnote", []))
|
||||
@ -521,13 +521,13 @@ class MinerUParser(RAGFlowPdfParser):
|
||||
section = "".join(output.get("image_caption", [])) + "\n" + "".join(
|
||||
output.get("image_footnote", []))
|
||||
case MinerUContentType.EQUATION:
|
||||
section = output["text"]
|
||||
section = output.get("text", "")
|
||||
case MinerUContentType.CODE:
|
||||
section = output["code_body"] + "\n".join(output.get("code_caption", []))
|
||||
section = output.get("code_body", "") + "\n".join(output.get("code_caption", []))
|
||||
case MinerUContentType.LIST:
|
||||
section = "\n".join(output.get("list_items", []))
|
||||
case MinerUContentType.DISCARDED:
|
||||
pass
|
||||
continue # Skip discarded blocks entirely
|
||||
|
||||
if section and parse_method == "manual":
|
||||
sections.append((section, output["type"], self._line_tag(output)))
|
||||
|
||||
@ -1447,6 +1447,7 @@ class VisionParser(RAGFlowPdfParser):
|
||||
def __init__(self, vision_model, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.vision_model = vision_model
|
||||
self.outlines = []
|
||||
|
||||
def __images__(self, fnm, zoomin=3, page_from=0, page_to=299, callback=None):
|
||||
try:
|
||||
|
||||
@ -88,12 +88,9 @@ class RAGFlowPptParser:
|
||||
texts = []
|
||||
for shape in sorted(
|
||||
slide.shapes, key=lambda x: ((x.top if x.top is not None else 0) // 10, x.left if x.left is not None else 0)):
|
||||
try:
|
||||
txt = self.__extract(shape)
|
||||
if txt:
|
||||
texts.append(txt)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
txt = self.__extract(shape)
|
||||
if txt:
|
||||
texts.append(txt)
|
||||
txts.append("\n".join(texts))
|
||||
|
||||
return txts
|
||||
|
||||
@ -72,7 +72,7 @@ services:
|
||||
infinity:
|
||||
profiles:
|
||||
- infinity
|
||||
image: infiniflow/infinity:v0.6.11
|
||||
image: infiniflow/infinity:v0.6.13
|
||||
volumes:
|
||||
- infinity_data:/var/infinity
|
||||
- ./infinity_conf.toml:/infinity_conf.toml
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
[general]
|
||||
version = "0.6.11"
|
||||
version = "0.6.13"
|
||||
time_zone = "utc-8"
|
||||
|
||||
[network]
|
||||
|
||||
0
docker/launch_backend_service.sh
Normal file → Executable file
37
docs/faq.mdx
@ -493,18 +493,35 @@ See [here](./guides/agent/best_practices/accelerate_agent_question_answering.md)
|
||||
|
||||
### How to use MinerU to parse PDF documents?
|
||||
|
||||
MinerU PDF document parsing is available starting from v0.22.0. RAGFlow works only as a remote client to MinerU (>= 2.6.3) and does not install or execute MinerU locally. To use this feature:
|
||||
From v0.22.0 onwards, RAGFlow includes MinerU (≥ 2.6.3) as an optional PDF parser of multiple backends. Please note that RAGFlow acts only as a *remote client* for MinerU, calling the MinerU API to parse PDFs and reading the returned files. To use this feature:
|
||||
|
||||
1. Prepare a reachable MinerU API service (for example, the FastAPI server provided by MinerU).
|
||||
2. Configure RAGFlow with remote MinerU settings (environment variables or UI model provider):
|
||||
- `MINERU_APISERVER`: MinerU API endpoint, for example `http://mineru-host:8886`.
|
||||
- `MINERU_BACKEND`: MinerU backend, defaults to `pipeline` (supports `vlm-http-client`, `vlm-transformers`, `vlm-vllm-engine`, `vlm-mlx-engine`, `vlm-vllm-async-engine`, `vlm-lmdeploy-engine`).
|
||||
- `MINERU_SERVER_URL`: (optional) For `vlm-http-client`, the downstream vLLM HTTP server, for example `http://vllm-host:30000`.
|
||||
- `MINERU_OUTPUT_DIR`: (optional) Local directory to store MinerU API outputs (zip/JSON) before ingestion.
|
||||
- `MINERU_DELETE_OUTPUT`: Whether to delete temporary output when a temp dir is used (`1` deletes temp outputs; set `0` to keep).
|
||||
3. In the web UI, navigate to the **Configuration** page of your dataset. Click **Built-in** in the **Ingestion pipeline** section, select a chunking method from the **Built-in** dropdown (which supports PDF parsing), and select **MinerU** in **PDF parser**.
|
||||
4. If you use a custom ingestion pipeline instead, provide the same MinerU settings and select **MinerU** in the **Parsing method** section of the **Parser** component.
|
||||
1. Prepare a reachable MinerU API service (FastAPI server).
|
||||
2. In the **.env** file or from the **Model providers** page in the UI, configure RAGFlow as a remote client to MinerU:
|
||||
- `MINERU_APISERVER`: The MinerU API endpoint (e.g., `http://mineru-host:8886`).
|
||||
- `MINERU_BACKEND`: The MinerU backend:
|
||||
- `"pipeline"` (default)
|
||||
- `"vlm-http-client"`
|
||||
- `"vlm-transformers"`
|
||||
- `"vlm-vllm-engine"`
|
||||
- `"vlm-mlx-engine"`
|
||||
- `"vlm-vllm-async-engine"`
|
||||
- `"vlm-lmdeploy-engine"`.
|
||||
- `MINERU_SERVER_URL`: (optional) The downstream vLLM HTTP server (e.g., `http://vllm-host:30000`). Applicable when `MINERU_BACKEND` is set to `"vlm-http-client"`.
|
||||
- `MINERU_OUTPUT_DIR`: (optional) The local directory for holding the outputs of the MinerU API service (zip/JSON) before ingestion.
|
||||
- `MINERU_DELETE_OUTPUT`: Whether to delete temporary output when a temporary directory is used:
|
||||
- `1`: Delete.
|
||||
- `0`: Retain.
|
||||
3. In the web UI, navigate to your dataset's **Configuration** page and find the **Ingestion pipeline** section:
|
||||
- If you decide to use a chunking method from the **Built-in** dropdown, ensure it supports PDF parsing, then select **MinerU** from the **PDF parser** dropdown.
|
||||
- If you use a custom ingestion pipeline instead, select **MinerU** in the **PDF parser** section of the **Parser** component.
|
||||
|
||||
:::note
|
||||
All MinerU environment variables are optional. When set, these values are used to auto-provision a MinerU OCR model for the tenant on first use. To avoid auto-provisioning, skip the environment variable settings and only configure MinerU from the **Model providers** page in the UI.
|
||||
:::
|
||||
|
||||
:::caution WARNING
|
||||
Third-party visual models are marked **Experimental**, because we have not fully tested these models for the aforementioned data extraction tasks.
|
||||
:::
|
||||
---
|
||||
|
||||
### How to configure MinerU-specific settings?
|
||||
|
||||
@ -24,7 +24,7 @@ We use gVisor to isolate code execution from the host system. Please follow [the
|
||||
RAGFlow Sandbox is a secure, pluggable code execution backend. It serves as the code executor for the **Code** component. Please follow the [instructions here](https://github.com/infiniflow/ragflow/tree/main/sandbox) to install RAGFlow Sandbox.
|
||||
|
||||
:::note Docker client version
|
||||
The executor manager image now bundles Docker CLI `29.1.0` (API 1.44+). Older images shipped Docker 24.x and will fail against newer Docker daemons with `client version 1.43 is too old`. Pull the latest `infiniflow/sandbox-executor-manager:latest` or rebuild `./sandbox/executor_manager` if you encounter this error.
|
||||
The executor manager image now bundles Docker CLI `29.1.0` (API 1.44+). Older images shipped Docker 24.x and will fail against newer Docker daemons with `client version 1.43 is too old`. Pull the latest `infiniflow/sandbox-executor-manager:latest` or rebuild it in `./sandbox/executor_manager` if you encounter this error.
|
||||
:::
|
||||
|
||||
:::tip NOTE
|
||||
@ -134,7 +134,7 @@ Your executor manager image includes Docker CLI 24.x (API 1.43), but the host Do
|
||||
|
||||
**Solution**
|
||||
|
||||
Pull the latest executor manager image or rebuild it locally to upgrade the built-in Docker client:
|
||||
Pull the latest executor manager image or rebuild it in `./sandbox/executor_manager` to upgrade the built-in Docker client:
|
||||
|
||||
```bash
|
||||
docker pull infiniflow/sandbox-executor-manager:latest
|
||||
|
||||
90
docs/guides/agent/agent_component_reference/http.md
Normal file
@ -0,0 +1,90 @@
|
||||
---
|
||||
sidebar_position: 30
|
||||
slug: /http_request_component
|
||||
---
|
||||
|
||||
# HTTP request component
|
||||
|
||||
A component that calls remote services.
|
||||
|
||||
---
|
||||
|
||||
An **HTTP request** component lets you access remote APIs or services by providing a URL and an HTTP method, and then receive the response. You can customize headers, parameters, proxies, and timeout settings, and use common methods like GET and POST. It’s useful for exchanging data with external systems in a workflow.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- An accessible remote API or service.
|
||||
- A token or credentials added to the request header, if the target service requires authentication.
|
||||
|
||||
## Configurations
|
||||
|
||||
### Url
|
||||
|
||||
*Required*. The complete request address, for example `http://api.example.com/data`.
|
||||
|
||||
### Method
|
||||
|
||||
The HTTP request method to use. Available options:
|
||||
|
||||
- GET
|
||||
- POST
|
||||
- PUT
|
||||
|
||||
### Timeout
|
||||
|
||||
The maximum waiting time for the request, in seconds. Defaults to `60`.
|
||||
|
||||
### Headers
|
||||
|
||||
Custom HTTP headers can be set here, for example:
|
||||
|
||||
```json
|
||||
{
|
||||
"Accept": "application/json",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive"
|
||||
}
|
||||
```
|
||||
|
||||
### Proxy
|
||||
|
||||
*Optional*. The proxy server address to use for this request.
|
||||
|
||||
### Clean HTML
|
||||
|
||||
`Boolean`: Whether to remove HTML tags from the returned results and keep plain text only.
|
||||
|
||||
### Parameter
|
||||
|
||||
*Optional*. Parameters to send with the HTTP request. Supports key-value pairs:
|
||||
|
||||
- To assign a value from a dynamic system variable, set the pair's type to **Variable**.
|
||||
- To override a dynamic value under certain conditions and use a fixed static value instead, set the pair's type to **Value**.
|
||||
|
||||
|
||||
:::tip NOTE
|
||||
- For GET requests, these parameters are appended to the end of the URL.
|
||||
- For POST/PUT requests, they are sent as the request body.
|
||||
:::
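The behavior described in the note can be reproduced outside RAGFlow with a few lines of Python against `httpbin.org`, the same echo service used in the example below. This is only an illustration of where the key-value pairs end up, not the component's internal implementation.

```python
import requests

params = {"App": "RAGFlow", "Query": "How to do?"}

# GET: the key-value pairs are appended to the URL as a query string.
get_resp = requests.get("https://httpbin.org/get", params=params, timeout=60)
print(get_resp.json()["args"])   # echoed back from the query string

# POST: the same pairs travel in the request body instead.
post_resp = requests.post("https://httpbin.org/post", json=params, timeout=60)
print(post_resp.json()["json"])  # echoed back from the JSON body
```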
|
||||
|
||||
#### Example setting
|
||||
|
||||

|
||||
|
||||
#### Example response
|
||||
|
||||
```json
|
||||
{ "args": { "App": "RAGFlow", "Query": "How to do?", "Userid": "241ed25a8e1011f0b979424ebc5b108b" }, "headers": { "Accept": "/", "Accept-Encoding": "gzip, deflate, br, zstd", "Cache-Control": "no-cache", "Host": "httpbin.org", "User-Agent": "python-requests/2.32.2", "X-Amzn-Trace-Id": "Root=1-68c9210c-5aab9088580c130a2f065523" }, "origin": "185.36.193.38", "url": "https://httpbin.org/get?Userid=241ed25a8e1011f0b979424ebc5b108b&App=RAGFlow&Query=How+to+do%3F" }
|
||||
```
|
||||
|
||||
### Output
|
||||
|
||||
The global variable name for the output of the HTTP request component, which can be referenced by other components in the workflow.
|
||||
|
||||
- `Result`: `string` The response returned by the remote service.
|
||||
|
||||
## Example
|
||||
|
||||
In this example, the workflow starts at the **Begin** component, sends a GET request with parameters to `https://httpbin.org/get` through the **HTTP Request_0** component, and finally outputs the result through the **Message_0** component.
|
||||
|
||||

|
||||
@ -40,21 +40,31 @@ The output of a PDF parser is `json`. In the PDF parser, you select the parsing
|
||||
- A third-party visual model from a specific model provider.
|
||||
|
||||
:::danger IMPORTANT
|
||||
MinerU PDF document parsing is available starting from v0.22.0. RAGFlow supports MinerU (>= 2.6.3) as an optional PDF parser with multiple backends. RAGFlow acts only as a **remote client** for MinerU, calling the MinerU API to parse documents, reading the returned output files, and ingesting the parsed content. To use this feature:
|
||||
Starting from v0.22.0, RAGFlow includes MinerU (≥ 2.6.3) as an optional PDF parser with multiple backends. Please note that RAGFlow acts only as a *remote client* for MinerU, calling the MinerU API to parse documents and reading the returned files. To use this feature:
|
||||
:::
|
||||
|
||||
1. Prepare a reachable MinerU API service (FastAPI server).
|
||||
2. Configure RAGFlow with the remote MinerU settings (env or UI model provider):
|
||||
- `MINERU_APISERVER`: MinerU API endpoint, for example `http://mineru-host:8886`.
|
||||
- `MINERU_BACKEND`: MinerU backend, defaults to `pipeline` (supports `vlm-http-client`, `vlm-transformers`, `vlm-vllm-engine`, `vlm-mlx-engine`, `vlm-vllm-async-engine`, `vlm-lmdeploy-engine`).
|
||||
- `MINERU_SERVER_URL`: (optional) For `vlm-http-client`, the downstream vLLM HTTP server, for example `http://vllm-host:30000`.
|
||||
- `MINERU_OUTPUT_DIR`: (optional) Local directory to store MinerU API outputs (zip/JSON) before ingestion.
|
||||
- `MINERU_DELETE_OUTPUT`: Whether to delete temporary output when a temp dir is used (`1` deletes temp outputs; set `0` to keep).
|
||||
3. In the web UI, navigate to the **Configuration** page of your dataset. Click **Built-in** in the **Ingestion pipeline** section, select a chunking method from the **Built-in** dropdown, which supports PDF parsing, and select **MinerU** in **PDF parser**.
|
||||
4. If you use a custom ingestion pipeline instead, provide the same MinerU settings and select **MinerU** in the **Parsing method** section of the **Parser** component.
|
||||
2. In the **.env** file or from the **Model providers** page in the UI, configure RAGFlow as a remote client to MinerU:
|
||||
- `MINERU_APISERVER`: The MinerU API endpoint (e.g., `http://mineru-host:8886`).
|
||||
- `MINERU_BACKEND`: The MinerU backend:
|
||||
- `"pipeline"` (default)
|
||||
- `"vlm-http-client"`
|
||||
- `"vlm-transformers"`
|
||||
- `"vlm-vllm-engine"`
|
||||
- `"vlm-mlx-engine"`
|
||||
- `"vlm-vllm-async-engine"`
|
||||
- `"vlm-lmdeploy-engine"`.
|
||||
- `MINERU_SERVER_URL`: (optional) The downstream vLLM HTTP server (e.g., `http://vllm-host:30000`). Applicable when `MINERU_BACKEND` is set to `"vlm-http-client"`.
|
||||
- `MINERU_OUTPUT_DIR`: (optional) The local directory for holding the outputs of the MinerU API service (zip/JSON) before ingestion.
|
||||
- `MINERU_DELETE_OUTPUT`: Whether to delete temporary output when a temporary directory is used:
|
||||
- `1`: Delete.
|
||||
- `0`: Retain.
|
||||
3. In the web UI, navigate to your dataset's **Configuration** page and find the **Ingestion pipeline** section:
|
||||
- If you decide to use a chunking method from the **Built-in** dropdown, ensure it supports PDF parsing, then select **MinerU** from the **PDF parser** dropdown.
|
||||
- If you use a custom ingestion pipeline instead, select **MinerU** in the **PDF parser** section of the **Parser** component.
|
||||
|
||||
:::note
|
||||
All MinerU environment variables are optional. If set, RAGFlow will auto-provision a MinerU OCR model for the tenant on first use with these values. To avoid auto-provisioning, configure MinerU solely through the UI and leave the env vars unset.
|
||||
All MinerU environment variables are optional. When set, these values are used to auto-provision a MinerU OCR model for the tenant on first use. To avoid auto-provisioning, skip the environment variable settings and only configure MinerU from the **Model providers** page in the UI.
|
||||
:::
|
||||
|
||||
:::caution WARNING
|
||||
|
||||
@ -29,7 +29,7 @@ The architecture consists of isolated Docker base images for each supported lang
|
||||
- (Optional) GNU Make for simplified command-line management.
|
||||
|
||||
:::tip NOTE
|
||||
The error message `client version 1.43 is too old. Minimum supported API version is 1.44` indicates that your executor manager image's built-in Docker CLI version is lower than `29.1.0` required by the Docker daemon in use. To solve this issue, pull the latest `infiniflow/sandbox-executor-manager:latest` from Docker Hub (or rebuild `./sandbox/executor_manager`).
|
||||
The error message `client version 1.43 is too old. Minimum supported API version is 1.44` indicates that your executor manager image's built-in Docker CLI version is lower than `29.1.0` required by the Docker daemon in use. To solve this issue, pull the latest `infiniflow/sandbox-executor-manager:latest` from Docker Hub or rebuild it in `./sandbox/executor_manager`.
|
||||
:::
|
||||
|
||||
## Build Docker base images
|
||||
|
||||
@ -45,7 +45,7 @@ Google Cloud external project.
|
||||
http://localhost:9380/v1/connector/google-drive/oauth/web/callback
|
||||
```
|
||||
|
||||
### If using Docker deployment:
|
||||
- If using Docker deployment:
|
||||
|
||||
**Authorized JavaScript origin:**
|
||||
```
|
||||
@ -53,15 +53,16 @@ http://localhost:80
|
||||
```
|
||||
|
||||

|
||||
### If running from source:
|
||||
|
||||
- If running from source:
|
||||
**Authorized JavaScript origin:**
|
||||
```
|
||||
http://localhost:9222
|
||||
```
|
||||
|
||||

|
||||
5. After saving, click **Download JSON**. This file will later be
|
||||
uploaded into RAGFlow.
|
||||
|
||||
5. After saving, click **Download JSON**. This file will later be uploaded into RAGFlow.
|
||||
|
||||

|
||||
|
||||
|
||||
@ -40,21 +40,31 @@ RAGFlow isn't one-size-fits-all. It is built for flexibility and supports deeper
|
||||
- A third-party visual model from a specific model provider.
|
||||
|
||||
:::danger IMPORTANT
|
||||
MinerU PDF document parsing is available starting from v0.22.0. RAGFlow supports MinerU (>= 2.6.3) as an optional PDF parser with multiple backends. RAGFlow acts only as a **remote client** for MinerU, calling the MinerU API to parse documents, reading the returned output files, and ingesting the parsed content. To use this feature:
|
||||
|
||||
1. Prepare a reachable MinerU API service (FastAPI server).
|
||||
2. Configure RAGFlow with the remote MinerU settings (env or UI model provider):
|
||||
- `MINERU_APISERVER`: MinerU API endpoint, for example `http://mineru-host:8886`.
|
||||
- `MINERU_BACKEND`: MinerU backend, defaults to `pipeline` (supports `vlm-http-client`, `vlm-transformers`, `vlm-vllm-engine`, `vlm-mlx-engine`, `vlm-vllm-async-engine`).
|
||||
- `MINERU_SERVER_URL`: (optional) For `vlm-http-client`, the downstream vLLM HTTP server, for example `http://vllm-host:30000`.
|
||||
- `MINERU_OUTPUT_DIR`: (optional) Local directory to store MinerU API outputs (zip/JSON) before ingestion.
|
||||
- `MINERU_DELETE_OUTPUT`: Whether to delete temporary output when a temp dir is used (`1` deletes temp outputs; set `0` to keep).
|
||||
3. In the web UI, navigate to the **Configuration** page of your dataset. Click **Built-in** in the **Ingestion pipeline** section, select a chunking method from the **Built-in** dropdown, which supports PDF parsing, and select **MinerU** in **PDF parser**.
|
||||
4. If you use a custom ingestion pipeline instead, provide the same MinerU settings and select **MinerU** in the **Parsing method** section of the **Parser** component.
|
||||
Starting from v0.22.0, RAGFlow includes MinerU (≥ 2.6.3) as an optional PDF parser with multiple backends. Please note that RAGFlow acts only as a *remote client* for MinerU, calling the MinerU API to parse documents and reading the returned files. To use this feature:
|
||||
:::
|
||||
|
||||
1. Prepare a reachable MinerU API service (FastAPI server).
|
||||
2. In the **.env** file or from the **Model providers** page in the UI, configure RAGFlow as a remote client to MinerU:
|
||||
- `MINERU_APISERVER`: The MinerU API endpoint (e.g., `http://mineru-host:8886`).
|
||||
- `MINERU_BACKEND`: The MinerU backend:
|
||||
- `"pipeline"` (default)
|
||||
- `"vlm-http-client"`
|
||||
- `"vlm-transformers"`
|
||||
- `"vlm-vllm-engine"`
|
||||
- `"vlm-mlx-engine"`
|
||||
- `"vlm-vllm-async-engine"`
|
||||
- `"vlm-lmdeploy-engine"`.
|
||||
- `MINERU_SERVER_URL`: (optional) The downstream vLLM HTTP server (e.g., `http://vllm-host:30000`). Applicable when `MINERU_BACKEND` is set to `"vlm-http-client"`.
|
||||
- `MINERU_OUTPUT_DIR`: (optional) The local directory for holding the outputs of the MinerU API service (zip/JSON) before ingestion.
|
||||
- `MINERU_DELETE_OUTPUT`: Whether to delete temporary output when a temporary directory is used:
|
||||
- `1`: Delete.
|
||||
- `0`: Retain.
|
||||
3. In the web UI, navigate to your dataset's **Configuration** page and find the **Ingestion pipeline** section:
|
||||
- If you decide to use a chunking method from the **Built-in** dropdown, ensure it supports PDF parsing, then select **MinerU** from the **PDF parser** dropdown.
|
||||
- If you use a custom ingestion pipeline instead, select **MinerU** in the **PDF parser** section of the **Parser** component.
|
||||
|
||||
:::note
|
||||
All MinerU environment variables are optional. When they are set, RAGFlow will auto-create a MinerU OCR model for a tenant on first use using these values. If you do not want this auto-provisioning, configure MinerU only through the UI and leave the env vars unset.
|
||||
All MinerU environment variables are optional. When set, these values are used to auto-provision a MinerU OCR model for the tenant on first use. To avoid auto-provisioning, skip the environment variable settings and only configure MinerU from the **Model providers** page in the UI.
|
||||
:::
|
||||
|
||||
:::caution WARNING
|
||||
|
||||
@ -48,6 +48,7 @@ This API follows the same request and response format as OpenAI's API. It allows
|
||||
- `"model"`: `string`
|
||||
- `"messages"`: `object list`
|
||||
- `"stream"`: `boolean`
|
||||
- `"extra_body"`: `object` (optional)
|
||||
|
||||
##### Request example
|
||||
|
||||
@ -59,7 +60,20 @@ curl --request POST \
|
||||
--data '{
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "Say this is a test!"}],
|
||||
"stream": true
|
||||
"stream": true,
|
||||
"extra_body": {
|
||||
"reference": true,
|
||||
"metadata_condition": {
|
||||
"logic": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"name": "author",
|
||||
"comparison_operator": "is",
|
||||
"value": "bob"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
@ -74,6 +88,11 @@ curl --request POST \
|
||||
- `stream` (*Body parameter*) `boolean`
|
||||
Whether to receive the response as a stream. Set this to `false` explicitly if you prefer to receive the entire response in one go instead of as a stream.
|
||||
|
||||
- `extra_body` (*Body parameter*) `object`
|
||||
Extra request parameters:
|
||||
- `reference`: `boolean` - include reference in the final chunk (stream) or in the final message (non-stream).
|
||||
- `metadata_condition`: `object` - metadata filter conditions applied to retrieval results.
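Because this endpoint mirrors OpenAI's API, `extra_body` can also be passed through the official OpenAI Python client. The sketch below is illustrative; the base URL, chat ID, and API key are placeholders for your own deployment.

```python
from openai import OpenAI

# Placeholders: substitute your RAGFlow address, chat ID, and API key.
client = OpenAI(
    api_key="<RAGFLOW_API_KEY>",
    base_url="http://<ragflow-address>/api/v1/chats_openai/<chat_id>",
)

completion = client.chat.completions.create(
    model="model",
    messages=[{"role": "user", "content": "Say this is a test!"}],
    stream=False,
    extra_body={
        "reference": True,
        "metadata_condition": {
            "logic": "and",
            "conditions": [
                {"name": "author", "comparison_operator": "is", "value": "bob"}
            ],
        },
    },
)
print(completion.choices[0].message.content)
```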
|
||||
|
||||
#### Response
|
||||
|
||||
Stream:
|
||||
@ -2236,7 +2255,7 @@ Batch update or delete document-level metadata within a specified dataset. If bo
|
||||
- `"document_ids"`: `list[string]` *optional*
|
||||
The associated document IDs.
|
||||
- `"metadata_condition"`: `object`, *optional*
|
||||
- `"logic"`: Defines the logic relation between conditions if multiple conditions are provided. Options:
|
||||
- `"logic"`: Defines the logic relation between conditions if multiple conditions are provided. Options:
|
||||
- `"and"` (default)
|
||||
- `"or"`
|
||||
- `"conditions"`: `list[object]` *optional*
|
||||
@ -2266,7 +2285,7 @@ Batch update or delete document-level metadata within a specified dataset. If bo
|
||||
- `"deletes`: (*Body parameter*), `list[ojbect]`, *optional*
|
||||
Deletes metadata of the retrieved documents. Each object: `{ "key": string, "value": string }`.
|
||||
- `"key"`: `string` The name of the key to delete.
|
||||
- `"value"`: `string` *Optional* The value of the key to delete.
|
||||
- `"value"`: `string` *Optional* The value of the key to delete.
|
||||
- When provided, only keys with a matching value are deleted.
|
||||
- When omitted, all specified keys are deleted.
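For clarity, two illustrative `deletes` payloads are shown below; the key names and values are placeholders for your own metadata.

```python
# Remove the "author" key only from documents where its value equals "bob".
delete_one_value = [{"key": "author", "value": "bob"}]

# Remove the "author" key from the matched documents regardless of its value.
delete_whole_key = [{"key": "author"}]
```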
|
||||
|
||||
@ -2533,7 +2552,7 @@ curl --request POST \
|
||||
:::caution WARNING
|
||||
`model_type` is an *internal* parameter, serving solely as a temporary workaround for the current model-configuration design limitations.
|
||||
|
||||
Its main purpose is to let *multimodal* models (stored in the database as `"image2text"`) pass backend validation/dispatching. Be mindful that:
|
||||
Its main purpose is to let *multimodal* models (stored in the database as `"image2text"`) pass backend validation/dispatching. Be mindful that:
|
||||
|
||||
- Do *not* treat it as a stable public API.
|
||||
- It is subject to change or removal in future releases.
|
||||
@ -3185,6 +3204,7 @@ Asks a specified chat assistant a question to start an AI-powered conversation.
|
||||
- `"stream"`: `boolean`
|
||||
- `"session_id"`: `string` (optional)
|
||||
- `"user_id`: `string` (optional)
|
||||
- `"metadata_condition"`: `object` (optional)
|
||||
|
||||
##### Request example
|
||||
|
||||
@ -3207,7 +3227,17 @@ curl --request POST \
|
||||
{
|
||||
"question": "Who are you",
|
||||
"stream": true,
|
||||
"session_id":"9fa7691cb85c11ef9c5f0242ac120005"
|
||||
"session_id":"9fa7691cb85c11ef9c5f0242ac120005",
|
||||
"metadata_condition": {
|
||||
"logic": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"name": "author",
|
||||
"comparison_operator": "is",
|
||||
"value": "bob"
|
||||
}
|
||||
]
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
@ -3225,6 +3255,13 @@ curl --request POST \
|
||||
The ID of session. If it is not provided, a new session will be generated.
|
||||
- `"user_id"`: (*Body parameter*), `string`
|
||||
The optional user-defined ID. Valid *only* when no `session_id` is provided.
|
||||
- `"metadata_condition"`: (*Body parameter*), `object`
|
||||
Optional metadata filter conditions applied to retrieval results.
|
||||
- `logic`: `string`, one of `and` / `or`
|
||||
- `conditions`: `list[object]` where each condition contains:
|
||||
- `name`: `string` metadata key
|
||||
- `comparison_operator`: `string` (e.g. `is`, `not is`, `contains`, `not contains`, `start with`, `end with`, `empty`, `not empty`, `>`, `<`, `≥`, `≤`)
|
||||
- `value`: `string|number|boolean` (optional for `empty`/`not empty`)
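Two illustrative `metadata_condition` objects built from the fields above are shown below; the metadata keys (`author`, `year`, `tags`) are placeholders for your own document metadata.

```python
# All conditions must hold (logic "and").
only_bob = {
    "logic": "and",
    "conditions": [
        {"name": "author", "comparison_operator": "is", "value": "bob"},
    ],
}

# Any condition may hold (logic "or"); numeric and substring operators are also accepted.
recent_or_tagged = {
    "logic": "or",
    "conditions": [
        {"name": "year", "comparison_operator": "≥", "value": 2024},
        {"name": "tags", "comparison_operator": "contains", "value": "neovim"},
    ],
}
```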
|
||||
|
||||
#### Response
|
||||
|
||||
@ -3601,6 +3638,8 @@ Asks a specified agent a question to start an AI-powered conversation.
|
||||
[DONE]
|
||||
```
|
||||
|
||||
- You can optionally return step-by-step trace logs (see `return_trace` below).
|
||||
|
||||
:::
|
||||
|
||||
#### Request
|
||||
@ -3616,6 +3655,17 @@ Asks a specified agent a question to start an AI-powered conversation.
|
||||
- `"session_id"`: `string` (optional)
|
||||
- `"inputs"`: `object` (optional)
|
||||
- `"user_id"`: `string` (optional)
|
||||
- `"return_trace"`: `boolean` (optional, default `false`) — include execution trace logs.
|
||||
|
||||
#### Streaming events to handle
|
||||
|
||||
When `stream=true`, the server sends Server-Sent Events (SSE). Clients should handle these `event` types:
|
||||
|
||||
- `message`: streaming content from Message components.
|
||||
- `message_end`: end of a Message component; may include `reference`/`attachment`.
|
||||
- `node_finished`: a component finishes; `data.inputs/outputs/error/elapsed_time` describe the node result. If `return_trace=true`, the trace is attached inside the same `node_finished` event (`data.trace`).
|
||||
|
||||
The stream terminates with `[DONE]`.
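A minimal Python sketch of a client loop over these events follows. The URL, API key, and the exact fields read from `data` are placeholders modeled on the sample events later in this section; adjust them to the payloads your server actually returns.

```python
import json

import requests

# Placeholders: substitute your server address, agent ID, and API key.
url = "http://<ragflow-address>/api/v1/agents/<agent_id>/completions"
headers = {"Authorization": "Bearer <API_KEY>"}
payload = {"question": "how to install neovim?", "stream": True, "return_trace": True}

with requests.post(url, headers=headers, json=payload, stream=True, timeout=600) as resp:
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data:"):
            continue                      # skip keep-alives and blank lines
        body = raw[len("data:"):].strip()
        if body == "[DONE]":              # the stream terminates here
            break
        event = json.loads(body)
        kind = event.get("event")
        if kind == "message":
            print(event.get("data"), flush=True)        # streamed content
        elif kind == "message_end":
            print("-- message finished --")
        elif kind == "node_finished":
            node = event["data"]
            print(f"[{node.get('component_name')} done in {node.get('elapsed_time', 0):.2f}s]")
```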
|
||||
|
||||
:::info IMPORTANT
|
||||
You can include custom parameters in the request body, but first ensure they are defined in the [Begin](../guides/agent/agent_component_reference/begin.mdx) component.
|
||||
@ -3800,6 +3850,92 @@ data: {
|
||||
"session_id": "cd097ca083dc11f0858253708ecb6573"
|
||||
}
|
||||
|
||||
data: {
|
||||
"event": "node_finished",
|
||||
"message_id": "cecdcb0e83dc11f0858253708ecb6573",
|
||||
"created_at": 1756364483,
|
||||
"task_id": "d1f79142831f11f09cc51795b9eb07c0",
|
||||
"data": {
|
||||
"inputs": {
|
||||
"sys.query": "how to install neovim?"
|
||||
},
|
||||
"outputs": {
|
||||
"content": "xxxxxxx",
|
||||
"_created_time": 15294.0382,
|
||||
"_elapsed_time": 0.00017
|
||||
},
|
||||
"component_id": "Agent:EveryHairsChew",
|
||||
"component_name": "Agent_1",
|
||||
"component_type": "Agent",
|
||||
"error": null,
|
||||
"elapsed_time": 11.2091,
|
||||
"created_at": 15294.0382,
|
||||
"trace": [
|
||||
{
|
||||
"component_id": "begin",
|
||||
"trace": [
|
||||
{
|
||||
"inputs": {},
|
||||
"outputs": {
|
||||
"_created_time": 15257.7949,
|
||||
"_elapsed_time": 0.00070
|
||||
},
|
||||
"component_id": "begin",
|
||||
"component_name": "begin",
|
||||
"component_type": "Begin",
|
||||
"error": null,
|
||||
"elapsed_time": 0.00085,
|
||||
"created_at": 15257.7949
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"component_id": "Agent:WeakDragonsRead",
|
||||
"trace": [
|
||||
{
|
||||
"inputs": {
|
||||
"sys.query": "how to install neovim?"
|
||||
},
|
||||
"outputs": {
|
||||
"content": "xxxxxxx",
|
||||
"_created_time": 15257.7982,
|
||||
"_elapsed_time": 36.2382
|
||||
},
|
||||
"component_id": "Agent:WeakDragonsRead",
|
||||
"component_name": "Agent_0",
|
||||
"component_type": "Agent",
|
||||
"error": null,
|
||||
"elapsed_time": 36.2385,
|
||||
"created_at": 15257.7982
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"component_id": "Agent:EveryHairsChew",
|
||||
"trace": [
|
||||
{
|
||||
"inputs": {
|
||||
"sys.query": "how to install neovim?"
|
||||
},
|
||||
"outputs": {
|
||||
"content": "xxxxxxxxxxxxxxxxx",
|
||||
"_created_time": 15294.0382,
|
||||
"_elapsed_time": 0.00017
|
||||
},
|
||||
"component_id": "Agent:EveryHairsChew",
|
||||
"component_name": "Agent_1",
|
||||
"component_type": "Agent",
|
||||
"error": null,
|
||||
"elapsed_time": 11.2091,
|
||||
"created_at": 15294.0382
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"session_id": "cd097ca083dc11f0858253708ecb6573"
|
||||
}
|
||||
|
||||
data:[DONE]
|
||||
```
|
||||
|
||||
@ -3874,7 +4010,100 @@ Non-stream:
|
||||
"doc_name": "INSTALL3.md"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"trace": [
|
||||
{
|
||||
"component_id": "begin",
|
||||
"trace": [
|
||||
{
|
||||
"component_id": "begin",
|
||||
"component_name": "begin",
|
||||
"component_type": "Begin",
|
||||
"created_at": 15926.567517862,
|
||||
"elapsed_time": 0.0008189299987861887,
|
||||
"error": null,
|
||||
"inputs": {},
|
||||
"outputs": {
|
||||
"_created_time": 15926.567517862,
|
||||
"_elapsed_time": 0.0006958619997021742
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"component_id": "Agent:WeakDragonsRead",
|
||||
"trace": [
|
||||
{
|
||||
"component_id": "Agent:WeakDragonsRead",
|
||||
"component_name": "Agent_0",
|
||||
"component_type": "Agent",
|
||||
"created_at": 15926.569121755,
|
||||
"elapsed_time": 53.49016142000073,
|
||||
"error": null,
|
||||
"inputs": {
|
||||
"sys.query": "how to install neovim?"
|
||||
},
|
||||
"outputs": {
|
||||
"_created_time": 15926.569121755,
|
||||
"_elapsed_time": 53.489981256001556,
|
||||
"content": "xxxxxxxxxxxxxx",
|
||||
"use_tools": [
|
||||
{
|
||||
"arguments": {
|
||||
"query": "xxxx"
|
||||
},
|
||||
"name": "search_my_dateset",
|
||||
"results": "xxxxxxxxxxx"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"component_id": "Agent:EveryHairsChew",
|
||||
"trace": [
|
||||
{
|
||||
"component_id": "Agent:EveryHairsChew",
|
||||
"component_name": "Agent_1",
|
||||
"component_type": "Agent",
|
||||
"created_at": 15980.060569101,
|
||||
"elapsed_time": 23.61718057500002,
|
||||
"error": null,
|
||||
"inputs": {
|
||||
"sys.query": "how to install neovim?"
|
||||
},
|
||||
"outputs": {
|
||||
"_created_time": 15980.060569101,
|
||||
"_elapsed_time": 0.0003451630000199657,
|
||||
"content": "xxxxxxxxxxxx"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"component_id": "Message:SlickDingosHappen",
|
||||
"trace": [
|
||||
{
|
||||
"component_id": "Message:SlickDingosHappen",
|
||||
"component_name": "Message_0",
|
||||
"component_type": "Message",
|
||||
"created_at": 15980.061302513,
|
||||
"elapsed_time": 23.61655923699982,
|
||||
"error": null,
|
||||
"inputs": {
|
||||
"Agent:EveryHairsChew@content": "xxxxxxxxx",
|
||||
"Agent:WeakDragonsRead@content": "xxxxxxxxxxx"
|
||||
},
|
||||
"outputs": {
|
||||
"_created_time": 15980.061302513,
|
||||
"_elapsed_time": 0.0006695749998471001,
|
||||
"content": "xxxxxxxxxxx"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"event": "workflow_finished",
|
||||
"message_id": "c4692a2683d911f0858253708ecb6573",
|
||||
|
||||
@ -77,6 +77,19 @@ A complete list of models supported by RAGFlow, which will continue to expand.
|
||||
If your model is not listed here but has APIs compatible with those of OpenAI, click **OpenAI-API-Compatible** on the **Model providers** page to configure your model.
|
||||
:::
|
||||
|
||||
## Example: AI Badgr (OpenAI-compatible)
|
||||
|
||||
You can use **AI Badgr** with RAGFlow via the existing OpenAI-API-Compatible provider.
|
||||
|
||||
To configure AI Badgr:
|
||||
|
||||
- **Provider**: `OpenAI-API-Compatible`
|
||||
- **Base URL**: `https://aibadgr.com/api/v1`
|
||||
- **API Key**: your AI Badgr API key (from the AI Badgr dashboard)
|
||||
- **Model**: any AI Badgr chat or embedding model ID, as exposed by AI Badgr's OpenAI-compatible APIs
|
||||
|
||||
AI Badgr implements OpenAI-compatible endpoints for `/v1/chat/completions`, `/v1/embeddings`, and `/v1/models`, so no additional code changes in RAGFlow are required.
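As a quick sanity check of the settings above, the OpenAI Python client can list the models exposed by AI Badgr before you add them in RAGFlow; the API key below is a placeholder.

```python
from openai import OpenAI

# Same base URL as configured in RAGFlow's OpenAI-API-Compatible provider.
client = OpenAI(api_key="<AI_BADGR_API_KEY>", base_url="https://aibadgr.com/api/v1")

for model in client.models.list():
    print(model.id)  # any listed chat or embedding model ID can be entered in RAGFlow
```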
|
||||
|
||||
:::note
|
||||
The list of supported models is extracted from [this source](https://github.com/infiniflow/ragflow/blob/main/rag/llm/__init__.py) and may not be the most current. For the latest supported model list, please refer to the Python file.
|
||||
:::
|
||||
|
||||
@ -23,8 +23,8 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]:
|
||||
return [
|
||||
"http://mirrors.tuna.tsinghua.edu.cn/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb",
|
||||
"http://mirrors.tuna.tsinghua.edu.cn/ubuntu-ports/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_arm64.deb",
|
||||
"https://repo.huaweicloud.com/repository/maven/org/apache/tika/tika-server-standard/3.0.0/tika-server-standard-3.0.0.jar",
|
||||
"https://repo.huaweicloud.com/repository/maven/org/apache/tika/tika-server-standard/3.0.0/tika-server-standard-3.0.0.jar.md5",
|
||||
"https://repo.huaweicloud.com/repository/maven/org/apache/tika/tika-server-standard/3.2.3/tika-server-standard-3.2.3.jar",
|
||||
"https://repo.huaweicloud.com/repository/maven/org/apache/tika/tika-server-standard/3.2.3/tika-server-standard-3.2.3.jar.md5",
|
||||
"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
|
||||
["https://registry.npmmirror.com/-/binary/chrome-for-testing/121.0.6167.85/linux64/chrome-linux64.zip", "chrome-linux64-121-0-6167-85"],
|
||||
["https://registry.npmmirror.com/-/binary/chrome-for-testing/121.0.6167.85/linux64/chromedriver-linux64.zip", "chromedriver-linux64-121-0-6167-85"],
|
||||
@ -34,8 +34,8 @@ def get_urls(use_china_mirrors=False) -> list[Union[str, list[str]]]:
|
||||
return [
|
||||
"http://archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb",
|
||||
"http://ports.ubuntu.com/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_arm64.deb",
|
||||
"https://repo1.maven.org/maven2/org/apache/tika/tika-server-standard/3.0.0/tika-server-standard-3.0.0.jar",
|
||||
"https://repo1.maven.org/maven2/org/apache/tika/tika-server-standard/3.0.0/tika-server-standard-3.0.0.jar.md5",
|
||||
"https://repo1.maven.org/maven2/org/apache/tika/tika-server-standard/3.2.3/tika-server-standard-3.2.3.jar",
|
||||
"https://repo1.maven.org/maven2/org/apache/tika/tika-server-standard/3.2.3/tika-server-standard-3.2.3.jar.md5",
|
||||
"https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken",
|
||||
["https://storage.googleapis.com/chrome-for-testing-public/121.0.6167.85/linux64/chrome-linux64.zip", "chrome-linux64-121-0-6167-85"],
|
||||
["https://storage.googleapis.com/chrome-for-testing-public/121.0.6167.85/linux64/chromedriver-linux64.zip", "chromedriver-linux64-121-0-6167-85"],
|
||||
@ -49,10 +49,10 @@ repos = [
|
||||
]
|
||||
|
||||
|
||||
def download_model(repo_id):
|
||||
local_dir = os.path.abspath(os.path.join("huggingface.co", repo_id))
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
snapshot_download(repo_id=repo_id, local_dir=local_dir)
|
||||
def download_model(repository_id):
|
||||
local_directory = os.path.abspath(os.path.join("huggingface.co", repository_id))
|
||||
os.makedirs(local_directory, exist_ok=True)
|
||||
snapshot_download(repo_id=repository_id, local_dir=local_directory)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -132,8 +132,8 @@ class EntityResolution(Extractor):
|
||||
f"{remain_candidates_to_resolve} remain."
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Error resolving candidate batch: {e}")
|
||||
except Exception as exception:
|
||||
logging.error(f"Error resolving candidate batch: {exception}")
|
||||
|
||||
|
||||
tasks = []
|
||||
@ -251,7 +251,7 @@ class EntityResolution(Extractor):
|
||||
ans_list = []
|
||||
records = [r.strip() for r in results.split(record_delimiter)]
|
||||
for record in records:
|
||||
pattern_int = f"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
|
||||
pattern_int = fr"{re.escape(entity_index_delimiter)}(\d+){re.escape(entity_index_delimiter)}"
|
||||
match_int = re.search(pattern_int, record)
|
||||
res_int = int(str(match_int.group(1) if match_int else '0'))
|
||||
if res_int > records_length:
|
||||
|
||||
@ -71,18 +71,17 @@ class Extractor:
|
||||
_, system_msg = message_fit_in([{"role": "system", "content": system}], int(self._llm.max_length * 0.92))
|
||||
response = ""
|
||||
for attempt in range(3):
|
||||
|
||||
if task_id:
|
||||
if has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
|
||||
try:
|
||||
response = asyncio.run(self._llm.async_chat(system_msg[0]["content"], hist, conf))
|
||||
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
|
||||
if response.find("**ERROR**") >= 0:
|
||||
raise Exception(response)
|
||||
set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
|
||||
break
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
if attempt == 2:
|
||||
|
||||
@ -198,7 +198,7 @@ async def run_graphrag_for_kb(
|
||||
|
||||
for d in raw_chunks:
|
||||
content = d["content_with_weight"]
|
||||
if num_tokens_from_string(current_chunk + content) < 1024:
|
||||
if num_tokens_from_string(current_chunk + content) < 4096:
|
||||
current_chunk += content
|
||||
else:
|
||||
if current_chunk:
|
||||
|
||||
@ -78,10 +78,6 @@ class GraphExtractor(Extractor):
|
||||
hint_prompt = self._entity_extract_prompt.format(**self._context_base, input_text=content)
|
||||
|
||||
gen_conf = {}
|
||||
final_result = ""
|
||||
glean_result = ""
|
||||
if_loop_result = ""
|
||||
history = []
|
||||
logging.info(f"Start processing for {chunk_key}: {content[:25]}...")
|
||||
if self.callback:
|
||||
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
|
||||
|
||||
@ -24,11 +24,11 @@ from common.misc_utils import get_uuid
|
||||
from graphrag.query_analyze_prompt import PROMPTS
|
||||
from graphrag.utils import get_entity_type2samples, get_llm_cache, set_llm_cache, get_relation
|
||||
from common.token_utils import num_tokens_from_string
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
|
||||
from rag.nlp.search import Dealer, index_name
|
||||
from common.float_utils import get_float
|
||||
from common import settings
|
||||
from common.doc_store.doc_store_base import OrderByExpr
|
||||
|
||||
|
||||
class KGSearch(Dealer):
|
||||
|
||||
@ -26,9 +26,9 @@ from networkx.readwrite import json_graph
|
||||
from common.misc_utils import get_uuid
|
||||
from common.connection_utils import timeout
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.utils.doc_store_conn import OrderByExpr
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from common import settings
|
||||
from common.doc_store.doc_store_base import OrderByExpr
|
||||
|
||||
GRAPH_FIELD_SEP = "<SEP>"
|
||||
|
||||
|
||||
@ -96,7 +96,7 @@ ragflow:
|
||||
infinity:
|
||||
image:
|
||||
repository: infiniflow/infinity
|
||||
tag: v0.6.11
|
||||
tag: v0.6.13
|
||||
pullPolicy: IfNotPresent
|
||||
pullSecrets: []
|
||||
storage:
|
||||
|
||||
0
memory/__init__.py
Normal file
0
memory/services/__init__.py
Normal file
240
memory/services/messages.py
Normal file
@ -0,0 +1,240 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
from common import settings
|
||||
from common.doc_store.doc_store_base import OrderByExpr, MatchExpr
|
||||
|
||||
|
||||
def index_name(uid: str): return f"memory_{uid}"
|
||||
|
||||
|
||||
class MessageService:
|
||||
|
||||
@classmethod
|
||||
def has_index(cls, uid: str, memory_id: str):
|
||||
index = index_name(uid)
|
||||
return settings.msgStoreConn.index_exist(index, memory_id)
|
||||
|
||||
@classmethod
|
||||
def create_index(cls, uid: str, memory_id: str, vector_size: int):
|
||||
index = index_name(uid)
|
||||
return settings.msgStoreConn.create_idx(index, memory_id, vector_size)
|
||||
|
||||
@classmethod
|
||||
def delete_index(cls, uid: str, memory_id: str):
|
||||
index = index_name(uid)
|
||||
return settings.msgStoreConn.delete_idx(index, memory_id)
|
||||
|
||||
@classmethod
|
||||
def insert_message(cls, messages: List[dict], uid: str, memory_id: str):
|
||||
index = index_name(uid)
|
||||
[m.update({
|
||||
"id": f'{memory_id}_{m["message_id"]}',
|
||||
"status": 1 if m["status"] else 0
|
||||
}) for m in messages]
|
||||
return settings.msgStoreConn.insert(messages, index, memory_id)
|
||||
|
||||
@classmethod
|
||||
def update_message(cls, condition: dict, update_dict: dict, uid: str, memory_id: str):
|
||||
index = index_name(uid)
|
||||
if "status" in update_dict:
|
||||
update_dict["status"] = 1 if update_dict["status"] else 0
|
||||
return settings.msgStoreConn.update(condition, update_dict, index, memory_id)
|
||||
|
||||
@classmethod
|
||||
def delete_message(cls, condition: dict, uid: str, memory_id: str):
|
||||
index = index_name(uid)
|
||||
return settings.msgStoreConn.delete(condition, index, memory_id)
|
||||
|
||||
@classmethod
|
||||
def list_message(cls, uid: str, memory_id: str, agent_ids: List[str]=None, keywords: str=None, page: int=1, page_size: int=50):
|
||||
index = index_name(uid)
|
||||
filter_dict = {}
|
||||
if agent_ids:
|
||||
filter_dict["agent_id"] = agent_ids
|
||||
if keywords:
|
||||
filter_dict["session_id"] = keywords
|
||||
order_by = OrderByExpr()
|
||||
order_by.desc("valid_at")
|
||||
res = settings.msgStoreConn.search(
|
||||
select_fields=[
|
||||
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
|
||||
"invalid_at", "forget_at", "status"
|
||||
],
|
||||
highlight_fields=[],
|
||||
condition=filter_dict,
|
||||
match_expressions=[], order_by=order_by,
|
||||
offset=(page-1)*page_size, limit=page_size,
|
||||
index_names=index, memory_ids=[memory_id], agg_fields=[], hide_forgotten=False
|
||||
)
|
||||
total_count = settings.msgStoreConn.get_total(res)
|
||||
doc_mapping = settings.msgStoreConn.get_fields(res, [
|
||||
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id",
|
||||
"valid_at", "invalid_at", "forget_at", "status"
|
||||
])
|
||||
return {
|
||||
"message_list": list(doc_mapping.values()),
|
||||
"total_count": total_count
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_recent_messages(cls, uid_list: List[str], memory_ids: List[str], agent_id: str, session_id: str, limit: int):
|
||||
index_names = [index_name(uid) for uid in uid_list]
|
||||
condition_dict = {
|
||||
"agent_id": agent_id,
|
||||
"session_id": session_id
|
||||
}
|
||||
order_by = OrderByExpr()
|
||||
order_by.desc("valid_at")
|
||||
res = settings.msgStoreConn.search(
|
||||
select_fields=[
|
||||
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
|
||||
"invalid_at", "forget_at", "status", "content"
|
||||
],
|
||||
highlight_fields=[],
|
||||
condition=condition_dict,
|
||||
match_expressions=[], order_by=order_by,
|
||||
offset=0, limit=limit,
|
||||
index_names=index_names, memory_ids=memory_ids, agg_fields=[]
|
||||
)
|
||||
doc_mapping = settings.msgStoreConn.get_fields(res, [
|
||||
"message_id", "message_type", "source_id", "memory_id","user_id", "agent_id", "session_id",
|
||||
"valid_at", "invalid_at", "forget_at", "status", "content"
|
||||
])
|
||||
return list(doc_mapping.values())
|
||||
|
||||
@classmethod
|
||||
def search_message(cls, memory_ids: List[str], condition_dict: dict, uid_list: List[str], match_expressions:list[MatchExpr], top_n: int):
|
||||
index_names = [index_name(uid) for uid in uid_list]
|
||||
# filter only valid messages by default
|
||||
if "status" not in condition_dict:
|
||||
condition_dict["status"] = 1
|
||||
|
||||
order_by = OrderByExpr()
|
||||
order_by.desc("valid_at")
|
||||
res = settings.msgStoreConn.search(
|
||||
select_fields=[
|
||||
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id",
|
||||
"valid_at",
|
||||
"invalid_at", "forget_at", "status", "content"
|
||||
],
|
||||
highlight_fields=[],
|
||||
condition=condition_dict,
|
||||
match_expressions=match_expressions,
|
||||
order_by=order_by,
|
||||
offset=0, limit=top_n,
|
||||
index_names=index_names, memory_ids=memory_ids, agg_fields=[]
|
||||
)
|
||||
docs = settings.msgStoreConn.get_fields(res, [
|
||||
"message_id", "message_type", "source_id", "memory_id", "user_id", "agent_id", "session_id", "valid_at",
|
||||
"invalid_at", "forget_at", "status", "content"
|
||||
])
|
||||
return list(docs.values())
|
||||
|
||||
@staticmethod
|
||||
def calculate_message_size(message: dict):
|
||||
return sys.getsizeof(message["content"]) + sys.getsizeof(message["content_embed"][0]) * len(message["content_embed"])
|
||||
|
||||
@classmethod
|
||||
def calculate_memory_size(cls, memory_ids: List[str], uid_list: List[str]):
|
||||
index_names = [index_name(uid) for uid in uid_list]
|
||||
order_by = OrderByExpr()
|
||||
order_by.desc("valid_at")
|
||||
|
||||
res = settings.msgStoreConn.search(
|
||||
select_fields=["memory_id", "content", "content_embed"],
|
||||
highlight_fields=[],
|
||||
condition={},
|
||||
match_expressions=[],
|
||||
order_by=order_by,
|
||||
offset=0, limit=2000*len(memory_ids),
|
||||
index_names=index_names, memory_ids=memory_ids, agg_fields=[], hide_forgotten=False
|
||||
)
|
||||
docs = settings.msgStoreConn.get_fields(res, ["memory_id", "content", "content_embed"])
|
||||
size_dict = {}
|
||||
for doc in docs.values():
|
||||
if size_dict.get(doc["memory_id"]):
|
||||
size_dict[doc["memory_id"]] += cls.calculate_message_size(doc)
|
||||
else:
|
||||
size_dict[doc["memory_id"]] = cls.calculate_message_size(doc)
|
||||
return size_dict
|
||||
|
||||
@classmethod
|
||||
def pick_messages_to_delete_by_fifo(cls, memory_id: str, uid: str, size_to_delete: int):
|
||||
select_fields = ["message_id", "content", "content_embed"]
|
||||
_index_name = index_name(uid)
|
||||
res = settings.msgStoreConn.get_forgotten_messages(select_fields, _index_name, memory_id)
|
||||
message_list = settings.msgStoreConn.get_fields(res, select_fields)
|
||||
current_size = 0
|
||||
ids_to_remove = []
|
||||
for message in message_list.values():
|
||||
if current_size < size_to_delete:
|
||||
current_size += cls.calculate_message_size(message)
|
||||
ids_to_remove.append(message["message_id"])
|
||||
else:
|
||||
return ids_to_remove, current_size
|
||||
if current_size >= size_to_delete:
|
||||
return ids_to_remove, current_size
|
||||
|
||||
order_by = OrderByExpr()
|
||||
order_by.asc("valid_at")
|
||||
res = settings.msgStoreConn.search(
|
||||
select_fields=["memory_id", "content", "content_embed"],
|
||||
highlight_fields=[],
|
||||
condition={},
|
||||
match_expressions=[],
|
||||
order_by=order_by,
|
||||
offset=0, limit=2000,
|
||||
index_names=[_index_name], memory_ids=[memory_id], agg_fields=[]
|
||||
)
|
||||
docs = settings.msgStoreConn.get_fields(res, select_fields)
|
||||
for doc in docs.values():
|
||||
if current_size < size_to_delete:
|
||||
current_size += cls.calculate_message_size(doc)
|
||||
ids_to_remove.append(doc["message_id"])
|
||||
else:
|
||||
return ids_to_remove, current_size
|
||||
return ids_to_remove, current_size
|
||||
|
||||
@classmethod
|
||||
def get_by_message_id(cls, memory_id: str, message_id: int, uid: str):
|
||||
index = index_name(uid)
|
||||
doc_id = f'{memory_id}_{message_id}'
|
||||
return settings.msgStoreConn.get(doc_id, index, [memory_id])
|
||||
|
||||
@classmethod
|
||||
def get_max_message_id(cls, uid_list: List[str], memory_ids: List[str]):
|
||||
order_by = OrderByExpr()
|
||||
order_by.desc("message_id")
|
||||
index_names = [index_name(uid) for uid in uid_list]
|
||||
res = settings.msgStoreConn.search(
|
||||
select_fields=["message_id"],
|
||||
highlight_fields=[],
|
||||
condition={},
|
||||
match_expressions=[],
|
||||
order_by=order_by,
|
||||
offset=0, limit=1,
|
||||
index_names=index_names, memory_ids=memory_ids,
|
||||
agg_fields=[], hide_forgotten=False
|
||||
)
|
||||
docs = settings.msgStoreConn.get_fields(res, ["message_id"])
|
||||
if not docs:
|
||||
return 1
|
||||
else:
|
||||
latest_msg = list(docs.values())[0]
|
||||
return int(latest_msg["message_id"])
|
||||
185
memory/services/query.py
Normal file
@ -0,0 +1,185 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import re
|
||||
import logging
|
||||
import json
|
||||
import numpy as np
|
||||
from common.query_base import QueryBase
|
||||
from common.doc_store.doc_store_base import MatchDenseExpr, MatchTextExpr
|
||||
from common.float_utils import get_float
|
||||
from rag.nlp import rag_tokenizer, term_weight, synonym
|
||||
|
||||
|
||||
def get_vector(txt, emb_mdl, topk=10, similarity=0.1):
|
||||
if isinstance(similarity, str) and len(similarity) > 0:
|
||||
try:
|
||||
similarity = float(similarity)
|
||||
except Exception as e:
|
||||
logging.warning(f"Convert similarity '{similarity}' to float failed: {e}. Using default 0.1")
|
||||
similarity = 0.1
|
||||
qv, _ = emb_mdl.encode_queries(txt)
|
||||
shape = np.array(qv).shape
|
||||
if len(shape) > 1:
|
||||
raise Exception(
|
||||
f"Dealer.get_vector returned array's shape {shape} doesn't match expectation(exact one dimension).")
|
||||
embedding_data = [get_float(v) for v in qv]
|
||||
vector_column_name = f"q_{len(embedding_data)}_vec"
|
||||
return MatchDenseExpr(vector_column_name, embedding_data, 'float', 'cosine', topk, {"similarity": similarity})
|
||||
|
||||
|
||||
class MsgTextQuery(QueryBase):
|
||||
|
||||
def __init__(self):
|
||||
self.tw = term_weight.Dealer()
|
||||
self.syn = synonym.Dealer()
|
||||
self.query_fields = [
|
||||
"content"
|
||||
]
|
||||
|
||||
def question(self, txt, tbl="messages", min_match: float=0.6):
|
||||
original_query = txt
|
||||
txt = MsgTextQuery.add_space_between_eng_zh(txt)
|
||||
txt = re.sub(
|
||||
r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+",
|
||||
" ",
|
||||
rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())),
|
||||
).strip()
|
||||
otxt = txt
|
||||
txt = MsgTextQuery.rmWWW(txt)
|
||||
|
||||
if not self.is_chinese(txt):
|
||||
txt = self.rmWWW(txt)
|
||||
tks = rag_tokenizer.tokenize(txt).split()
|
||||
keywords = [t for t in tks if t]
|
||||
tks_w = self.tw.weights(tks, preprocess=False)
|
||||
tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
|
||||
tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
|
||||
tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
|
||||
tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
|
||||
syns = []
|
||||
for tk, w in tks_w[:256]:
|
||||
syn = self.syn.lookup(tk)
|
||||
syn = rag_tokenizer.tokenize(" ".join(syn)).split()
|
||||
keywords.extend(syn)
|
||||
syn = ["\"{}\"^{:.4f}".format(s, w / 4.) for s in syn if s.strip()]
|
||||
syns.append(" ".join(syn))
|
||||
|
||||
q = ["({}^{:.4f}".format(tk, w) + " {})".format(syn) for (tk, w), syn in zip(tks_w, syns) if
|
||||
tk and not re.match(r"[.^+\(\)-]", tk)]
|
||||
for i in range(1, len(tks_w)):
|
||||
left, right = tks_w[i - 1][0].strip(), tks_w[i][0].strip()
|
||||
if not left or not right:
|
||||
continue
|
||||
q.append(
|
||||
'"%s %s"^%.4f'
|
||||
% (
|
||||
tks_w[i - 1][0],
|
||||
tks_w[i][0],
|
||||
max(tks_w[i - 1][1], tks_w[i][1]) * 2,
|
||||
)
|
||||
)
|
||||
if not q:
|
||||
q.append(txt)
|
||||
query = " ".join(q)
|
||||
return MatchTextExpr(
|
||||
self.query_fields, query, 100, {"original_query": original_query}
|
||||
), keywords
|
||||
|
||||
def need_fine_grained_tokenize(tk):
|
||||
if len(tk) < 3:
|
||||
return False
|
||||
if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
|
||||
return False
|
||||
return True
|
||||
|
||||
txt = self.rmWWW(txt)
|
||||
qs, keywords = [], []
|
||||
for tt in self.tw.split(txt)[:256]: # .split():
|
||||
if not tt:
|
||||
continue
|
||||
keywords.append(tt)
|
||||
twts = self.tw.weights([tt])
|
||||
syns = self.syn.lookup(tt)
|
||||
if syns and len(keywords) < 32:
|
||||
keywords.extend(syns)
|
||||
logging.debug(json.dumps(twts, ensure_ascii=False))
|
||||
tms = []
|
||||
for tk, w in sorted(twts, key=lambda x: x[1] * -1):
|
||||
sm = (
|
||||
rag_tokenizer.fine_grained_tokenize(tk).split()
|
||||
if need_fine_grained_tokenize(tk)
|
||||
else []
|
||||
)
|
||||
sm = [
|
||||
re.sub(
|
||||
r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
|
||||
"",
|
||||
m,
|
||||
)
|
||||
for m in sm
|
||||
]
|
||||
sm = [self.sub_special_char(m) for m in sm if len(m) > 1]
|
||||
sm = [m for m in sm if len(m) > 1]
|
||||
|
||||
if len(keywords) < 32:
|
||||
keywords.append(re.sub(r"[ \\\"']+", "", tk))
|
||||
keywords.extend(sm)
|
||||
|
||||
tk_syns = self.syn.lookup(tk)
|
||||
tk_syns = [self.sub_special_char(s) for s in tk_syns]
|
||||
if len(keywords) < 32:
|
||||
keywords.extend([s for s in tk_syns if s])
|
||||
tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
|
||||
tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
|
||||
|
||||
if len(keywords) >= 32:
|
||||
break
|
||||
|
||||
tk = self.sub_special_char(tk)
|
||||
if tk.find(" ") > 0:
|
||||
tk = '"%s"' % tk
|
||||
if tk_syns:
|
||||
tk = f"({tk} OR (%s)^0.2)" % " ".join(tk_syns)
|
||||
if sm:
|
||||
tk = f'{tk} OR "%s" OR ("%s"~2)^0.5' % (" ".join(sm), " ".join(sm))
|
||||
if tk.strip():
|
||||
tms.append((tk, w))
|
||||
|
||||
tms = " ".join([f"({t})^{w}" for t, w in tms])
|
||||
|
||||
if len(twts) > 1:
|
||||
tms += ' ("%s"~2)^1.5' % rag_tokenizer.tokenize(tt)
|
||||
|
||||
syns = " OR ".join(
|
||||
[
|
||||
'"%s"'
|
||||
% rag_tokenizer.tokenize(self.sub_special_char(s))
|
||||
for s in syns
|
||||
]
|
||||
)
|
||||
if syns and tms:
|
||||
tms = f"({tms})^5 OR ({syns})^0.7"
|
||||
|
||||
qs.append(tms)
|
||||
|
||||
if qs:
|
||||
query = " OR ".join([f"({t})" for t in qs if t])
|
||||
if not query:
|
||||
query = otxt
|
||||
return MatchTextExpr(
|
||||
self.query_fields, query, 100, {"minimum_should_match": min_match, "original_query": original_query}
|
||||
), keywords
|
||||
return None, keywords
|
||||
0
memory/utils/__init__.py
Normal file
494
memory/utils/es_conn.py
Normal file
@ -0,0 +1,494 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
|
||||
import copy
|
||||
from elasticsearch import NotFoundError
|
||||
from elasticsearch_dsl import UpdateByQuery, Q, Search
|
||||
from elastic_transport import ConnectionTimeout
|
||||
from common.decorator import singleton
|
||||
from common.doc_store.doc_store_base import MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, FusionExpr
|
||||
from common.doc_store.es_conn_base import ESConnectionBase
|
||||
from common.float_utils import get_float
|
||||
from common.constants import PAGERANK_FLD, TAG_FLD
|
||||
|
||||
ATTEMPT_TIME = 2
|
||||
|
||||
|
||||
@singleton
|
||||
class ESConnection(ESConnectionBase):
|
||||
|
||||
@staticmethod
|
||||
def convert_field_name(field_name: str) -> str:
|
||||
match field_name:
|
||||
case "message_type":
|
||||
return "message_type_kwd"
|
||||
case "status":
|
||||
return "status_int"
|
||||
case "content":
|
||||
return "content_ltks"
|
||||
case _:
|
||||
return field_name
|
||||
|
||||
@staticmethod
|
||||
def map_message_to_es_fields(message: dict) -> dict:
|
||||
"""
|
||||
Map message dictionary fields to Elasticsearch document/Infinity fields.
|
||||
|
||||
:param message: A dictionary containing message details.
|
||||
:return: A dictionary formatted for Elasticsearch/Infinity indexing.
|
||||
"""
|
||||
storage_doc = {
|
||||
"id": message.get("id"),
|
||||
"message_id": message["message_id"],
|
||||
"message_type_kwd": message["message_type"],
|
||||
"source_id": message["source_id"],
|
||||
"memory_id": message["memory_id"],
|
||||
"user_id": message["user_id"],
|
||||
"agent_id": message["agent_id"],
|
||||
"session_id": message["session_id"],
|
||||
"valid_at": message["valid_at"],
|
||||
"invalid_at": message["invalid_at"],
|
||||
"forget_at": message["forget_at"],
|
||||
"status_int": 1 if message["status"] else 0,
|
||||
"zone_id": message.get("zone_id", 0),
|
||||
"content_ltks": message["content"],
|
||||
f"q_{len(message['content_embed'])}_vec": message["content_embed"],
|
||||
}
|
||||
return storage_doc
|
||||
|
||||
@staticmethod
|
||||
def get_message_from_es_doc(doc: dict) -> dict:
|
||||
"""
|
||||
Convert an Elasticsearch/Infinity document back to a message dictionary.
|
||||
|
||||
:param doc: A dictionary representing the Elasticsearch/Infinity document.
|
||||
:return: A dictionary formatted as a message.
|
||||
"""
|
||||
embd_field_name = next((key for key in doc.keys() if re.match(r"q_\d+_vec", key)), None)
|
||||
message = {
|
||||
"message_id": doc["message_id"],
|
||||
"message_type": doc["message_type_kwd"],
|
||||
"source_id": doc["source_id"] if doc["source_id"] else None,
|
||||
"memory_id": doc["memory_id"],
|
||||
"user_id": doc.get("user_id", ""),
|
||||
"agent_id": doc["agent_id"],
|
||||
"session_id": doc["session_id"],
|
||||
"zone_id": doc.get("zone_id", 0),
|
||||
"valid_at": doc["valid_at"],
|
||||
"invalid_at": doc.get("invalid_at", "-"),
|
||||
"forget_at": doc.get("forget_at", "-"),
|
||||
"status": bool(int(doc["status_int"])),
|
||||
"content": doc.get("content_ltks", ""),
|
||||
"content_embed": doc.get(embd_field_name, []) if embd_field_name else [],
|
||||
}
|
||||
if doc.get("id"):
|
||||
message["id"] = doc["id"]
|
||||
return message
|
||||

    """
    CRUD operations
    """

    def search(
            self, select_fields: list[str],
            highlight_fields: list[str],
            condition: dict,
            match_expressions: list[MatchExpr],
            order_by: OrderByExpr,
            offset: int,
            limit: int,
            index_names: str | list[str],
            memory_ids: list[str],
            agg_fields: list[str] | None = None,
            rank_feature: dict | None = None,
            hide_forgotten: bool = True
    ):
        """
        Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl.html
        """
        if isinstance(index_names, str):
            index_names = index_names.split(",")
        assert isinstance(index_names, list) and len(index_names) > 0
        assert "_id" not in condition
        bool_query = Q("bool", must=[], must_not=[])
        if hide_forgotten:
            # filter out messages that have already been forgotten (forget_at is set)
            bool_query.must_not.append(Q("exists", field="forget_at"))

        condition["memory_id"] = memory_ids
        for k, v in condition.items():
            if k == "session_id" and v:
                bool_query.filter.append(Q("query_string", **{"query": f"*{v}*", "fields": ["session_id"], "analyze_wildcard": True}))
                continue
            if not v:
                continue
            if isinstance(v, list):
                bool_query.filter.append(Q("terms", **{k: v}))
            elif isinstance(v, str) or isinstance(v, int):
                bool_query.filter.append(Q("term", **{k: v}))
            else:
                raise Exception(
                    f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
        s = Search()
        vector_similarity_weight = 0.5
        for m in match_expressions:
            if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
                assert len(match_expressions) == 3 and isinstance(match_expressions[0], MatchTextExpr) \
                    and isinstance(match_expressions[1], MatchDenseExpr) and isinstance(match_expressions[2], FusionExpr)
                weights = m.fusion_params["weights"]
                vector_similarity_weight = get_float(weights.split(",")[1])
        for m in match_expressions:
            if isinstance(m, MatchTextExpr):
                minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
                if isinstance(minimum_should_match, float):
                    minimum_should_match = str(int(minimum_should_match * 100)) + "%"
                bool_query.must.append(Q("query_string", fields=[self.convert_field_name(f) for f in m.fields],
                                         type="best_fields", query=m.matching_text,
                                         minimum_should_match=minimum_should_match,
                                         boost=1))
                bool_query.boost = 1.0 - vector_similarity_weight

            elif isinstance(m, MatchDenseExpr):
                assert (bool_query is not None)
                similarity = 0.0
                if "similarity" in m.extra_options:
                    similarity = m.extra_options["similarity"]
                s = s.knn(self.convert_field_name(m.vector_column_name),
                          m.topn,
                          m.topn * 2,
                          query_vector=list(m.embedding_data),
                          filter=bool_query.to_dict(),
                          similarity=similarity,
                          )

        if bool_query and rank_feature:
            for fld, sc in rank_feature.items():
                if fld != PAGERANK_FLD:
                    fld = f"{TAG_FLD}.{fld}"
                bool_query.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))

        if bool_query:
            s = s.query(bool_query)
        for field in highlight_fields:
            s = s.highlight(field)

        if order_by:
            orders = list()
            for field, order in order_by.fields:
                order = "asc" if order == 0 else "desc"
                if field.endswith("_int") or field.endswith("_flt"):
                    order_info = {"order": order, "unmapped_type": "float"}
                else:
                    order_info = {"order": order, "unmapped_type": "text"}
                orders.append({field: order_info})
            s = s.sort(*orders)

        if agg_fields:
            for fld in agg_fields:
                s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)

        if limit > 0:
            s = s[offset:offset + limit]
        q = s.to_dict()
        self.logger.debug(f"ESConnection.search {str(index_names)} query: " + json.dumps(q))

        for i in range(ATTEMPT_TIME):
            try:
                # print(json.dumps(q, ensure_ascii=False))
                res = self.es.search(index=index_names,
                                     body=q,
                                     timeout="600s",
                                     # search_type="dfs_query_then_fetch",
                                     track_total_hits=True,
                                     _source=True)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                self.logger.debug(f"ESConnection.search {str(index_names)} res: " + str(res))
                return res
            except ConnectionTimeout:
                self.logger.exception("ES request timeout")
                self._connect()
                continue
            except Exception as e:
                self.logger.exception(f"ESConnection.search {str(index_names)} query: " + str(q) + str(e))
                raise e

        self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
        raise Exception("ESConnection.search timeout.")
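
For reference, a hedged sketch of how a hybrid keyword-plus-vector query is expected to arrive at this method, based on the assertion above (exactly one MatchTextExpr, one MatchDenseExpr and one weighted_sum FusionExpr whose comma-separated weights supply the vector weight). Constructor argument order follows the attribute names used above; the embedding, field and index names are placeholders:

    text_m = MatchTextExpr(["content_ltks"], "reset password", 64, {"minimum_should_match": 0.3})
    dense_m = MatchDenseExpr("q_768_vec", query_embedding, "float", "cosine", 64, {"similarity": 0.2})
    fusion = FusionExpr("weighted_sum", 64, {"weights": "0.05,0.95"})
    res = es_conn.search(
        select_fields=["content", "valid_at", "status"], highlight_fields=[],
        condition={"agent_id": "a_1"}, match_expressions=[text_m, dense_m, fusion],
        order_by=OrderByExpr(), offset=0, limit=10,
        index_names="ragflow_memory", memory_ids=["mem_0"],
    )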

    def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int = 2000):
        # forgotten messages are the ones whose forget_at has been set
        bool_query = Q("bool", must=[], must_not=[])
        bool_query.must.append(Q("exists", field="forget_at"))
        bool_query.filter.append(Q("term", memory_id=memory_id))
        # from old to new
        order_by = OrderByExpr()
        order_by.asc("forget_at")
        # build search (convert OrderByExpr into ES sort clauses, the same way search() does)
        s = Search()
        s = s.query(bool_query)
        orders = [{field: {"order": "asc" if order == 0 else "desc", "unmapped_type": "text"}} for field, order in order_by.fields]
        s = s.sort(*orders)
        s = s[:limit]
        q = s.to_dict()
        # search
        for i in range(ATTEMPT_TIME):
            try:
                res = self.es.search(index=index_name, body=q, timeout="600s", track_total_hits=True, _source=True)
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                self.logger.debug(f"ESConnection.search {str(index_name)} res: " + str(res))
                return res
            except ConnectionTimeout:
                self.logger.exception("ES request timeout")
                self._connect()
                continue
            except Exception as e:
                self.logger.exception(f"ESConnection.search {str(index_name)} query: " + str(q) + str(e))
                raise e

        self.logger.error(f"ESConnection.search timeout for {ATTEMPT_TIME} times!")
        raise Exception("ESConnection.search timeout.")

    def get(self, doc_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
        for i in range(ATTEMPT_TIME):
            try:
                res = self.es.get(index=index_name,
                                  id=doc_id, source=True, )
                if str(res.get("timed_out", "")).lower() == "true":
                    raise Exception("Es Timeout.")
                message = res["_source"]
                message["id"] = doc_id
                return self.get_message_from_es_doc(message)
            except NotFoundError:
                return None
            except Exception as e:
                self.logger.exception(f"ESConnection.get({doc_id}) got exception")
                raise e
        self.logger.error(f"ESConnection.get timeout for {ATTEMPT_TIME} times!")
        raise Exception("ESConnection.get timeout.")

    def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
        # Refers to https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html
        operations = []
        for d in documents:
            assert "_id" not in d
            assert "id" in d
            d_copy_raw = copy.deepcopy(d)
            d_copy = self.map_message_to_es_fields(d_copy_raw)
            d_copy["memory_id"] = memory_id
            meta_id = d_copy.pop("id", "")
            operations.append(
                {"index": {"_index": index_name, "_id": meta_id}})
            operations.append(d_copy)
        res = []
        for _ in range(ATTEMPT_TIME):
            try:
                res = []
                r = self.es.bulk(index=index_name, operations=operations,
                                 refresh=False, timeout="60s")
                if re.search(r"False", str(r["errors"]), re.IGNORECASE):
                    return res

                for item in r["items"]:
                    for action in ["create", "delete", "index", "update"]:
                        if action in item and "error" in item[action]:
                            res.append(str(item[action]["_id"]) + ":" + str(item[action]["error"]))
                return res
            except ConnectionTimeout:
                self.logger.exception("ES request timeout")
                time.sleep(3)
                self._connect()
                continue
            except Exception as e:
                res.append(str(e))
                self.logger.warning("ESConnection.insert got exception: " + str(e))

        return res

    def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
        doc = copy.deepcopy(new_value)
        update_dict = {self.convert_field_name(k): v for k, v in doc.items()}
        update_dict.pop("id", None)
        condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
        condition_dict["memory_id"] = memory_id
        if "id" in condition_dict and isinstance(condition_dict["id"], str):
            # update a single, specific document
            message_id = condition_dict["id"]
            for i in range(ATTEMPT_TIME):
                for k in update_dict.keys():
                    if "feas" != k.split("_")[-1]:
                        continue
                    try:
                        self.es.update(index=index_name, id=message_id, script=f"ctx._source.remove(\"{k}\");")
                    except Exception:
                        self.logger.exception(f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception")
                try:
                    self.es.update(index=index_name, id=message_id, doc=update_dict)
                    return True
                except Exception as e:
                    self.logger.exception(
                        f"ESConnection.update(index={index_name}, id={message_id}, doc={json.dumps(condition, ensure_ascii=False)}) got exception: " + str(e))
                    break
            return False

        # otherwise, update every document matched by the condition
        bool_query = Q("bool")
        for k, v in condition_dict.items():
            if not isinstance(k, str) or not v:
                continue
            if k == "exists":
                bool_query.filter.append(Q("exists", field=v))
                continue
            if isinstance(v, list):
                bool_query.filter.append(Q("terms", **{k: v}))
            elif isinstance(v, str) or isinstance(v, int):
                bool_query.filter.append(Q("term", **{k: v}))
            else:
                raise Exception(
                    f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")
        scripts = []
        params = {}
        for k, v in update_dict.items():
            if k == "remove":
                if isinstance(v, str):
                    scripts.append(f"ctx._source.remove('{v}');")
                if isinstance(v, dict):
                    for kk, vv in v.items():
                        scripts.append(f"int i=ctx._source.{kk}.indexOf(params.p_{kk});ctx._source.{kk}.remove(i);")
                        params[f"p_{kk}"] = vv
                continue
            if k == "add":
                if isinstance(v, dict):
                    for kk, vv in v.items():
                        scripts.append(f"ctx._source.{kk}.add(params.pp_{kk});")
                        params[f"pp_{kk}"] = vv.strip()
                continue
            if (not isinstance(k, str) or not v) and k != "status_int":
                continue
            if isinstance(v, str):
                v = re.sub(r"(['\n\r]|\\.)", " ", v)
                params[f"pp_{k}"] = v
                scripts.append(f"ctx._source.{k}=params.pp_{k};")
            elif isinstance(v, int) or isinstance(v, float):
                scripts.append(f"ctx._source.{k}={v};")
            elif isinstance(v, list):
                scripts.append(f"ctx._source.{k}=params.pp_{k};")
                params[f"pp_{k}"] = json.dumps(v, ensure_ascii=False)
            else:
                raise Exception(
                    f"newValue `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str.")
        ubq = UpdateByQuery(
            index=index_name).using(
            self.es).query(bool_query)
        ubq = ubq.script(source="".join(scripts), params=params)
        ubq = ubq.params(refresh=True)
        ubq = ubq.params(slices=5)
        ubq = ubq.params(conflicts="proceed")
        for _ in range(ATTEMPT_TIME):
            try:
                _ = ubq.execute()
                return True
            except ConnectionTimeout:
                self.logger.exception("ES request timeout")
                time.sleep(3)
                self._connect()
                continue
            except Exception as e:
                self.logger.error("ESConnection.update got exception: " + str(e) + "\n".join(scripts))
                break
        return False

    def delete(self, condition: dict, index_name: str, memory_id: str) -> int:
        assert "_id" not in condition
        condition_dict = {self.convert_field_name(k): v for k, v in condition.items()}
        condition_dict["memory_id"] = memory_id
        if "id" in condition_dict:
            message_ids = condition_dict["id"]
            if not isinstance(message_ids, list):
                message_ids = [message_ids]
            if not message_ids:  # when message_ids is empty, delete all
                qry = Q("match_all")
            else:
                qry = Q("ids", values=message_ids)
        else:
            qry = Q("bool")
            for k, v in condition_dict.items():
                if k == "exists":
                    qry.filter.append(Q("exists", field=v))

                elif k == "must_not":
                    if isinstance(v, dict):
                        for kk, vv in v.items():
                            if kk == "exists":
                                qry.must_not.append(Q("exists", field=vv))

                elif isinstance(v, list):
                    qry.must.append(Q("terms", **{k: v}))
                elif isinstance(v, str) or isinstance(v, int):
                    qry.must.append(Q("term", **{k: v}))
                else:
                    raise Exception("Condition value must be int, str or list.")
        self.logger.debug("ESConnection.delete query: " + json.dumps(qry.to_dict()))
        for _ in range(ATTEMPT_TIME):
            try:
                res = self.es.delete_by_query(
                    index=index_name,
                    body=Search().query(qry).to_dict(),
                    refresh=True)
                return res["deleted"]
            except ConnectionTimeout:
                self.logger.exception("ES request timeout")
                time.sleep(3)
                self._connect()
                continue
            except Exception as e:
                self.logger.warning("ESConnection.delete got exception: " + str(e))
                if re.search(r"(not_found)", str(e), re.IGNORECASE):
                    return 0
        return 0

    """
    Helper functions for search result
    """

    def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
        res_fields = {}
        if not fields:
            return {}
        for doc in self._get_source(res):
            message = self.get_message_from_es_doc(doc)
            m = {}
            for n, v in message.items():
                if n not in fields:
                    continue
                if isinstance(v, list):
                    m[n] = v
                    continue
                if n in ["message_id", "source_id", "valid_at", "invalid_at", "forget_at", "status"] and isinstance(v, (int, float, bool)):
                    m[n] = v
                    continue
                if not isinstance(v, str):
                    m[n] = str(v)
                else:
                    m[n] = v

            if m:
                res_fields[doc["id"]] = m
        return res_fields

memory/utils/infinity_conn.py (new file, 467 lines)
@@ -0,0 +1,467 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import re
import json
import copy
from infinity.common import InfinityException, SortType
from infinity.errors import ErrorCode

from common.decorator import singleton
import pandas as pd
from common.constants import PAGERANK_FLD, TAG_FLD
from common.doc_store.doc_store_base import MatchExpr, MatchTextExpr, MatchDenseExpr, FusionExpr, OrderByExpr
from common.doc_store.infinity_conn_base import InfinityConnectionBase
from common.time_utils import date_string_to_timestamp


@singleton
class InfinityConnection(InfinityConnectionBase):
    def __init__(self):
        super().__init__()
        self.mapping_file_name = "message_infinity_mapping.json"

    """
    DataFrame and field conversion helpers
    """

    @staticmethod
    def field_keyword(field_name: str):
        # no keyword fields right now
        return False

    @staticmethod
    def convert_message_field_to_infinity(field_name: str):
        match field_name:
            case "message_type":
                return "message_type_kwd"
            case "status":
                return "status_int"
            case _:
                return field_name

    @staticmethod
    def convert_infinity_field_to_message(field_name: str):
        if field_name.startswith("message_type"):
            return "message_type"
        if field_name.startswith("status"):
            return "status"
        if re.match(r"q_\d+_vec", field_name):
            return "content_embed"
        return field_name

    def convert_select_fields(self, output_fields: list[str]) -> list[str]:
        return list({self.convert_message_field_to_infinity(f) for f in output_fields})

    @staticmethod
    def convert_matching_field(field_weight_str: str) -> str:
        tokens = field_weight_str.split("^")
        field = tokens[0]
        if field == "content":
            field = "content@ft_contentm_rag_fine"
        tokens[0] = field
        return "^".join(tokens)

    @staticmethod
    def convert_condition_and_order_field(field_name: str):
        match field_name:
            case "message_type":
                return "message_type_kwd"
            case "status":
                return "status_int"
            case "valid_at":
                return "valid_at_flt"
            case "invalid_at":
                return "invalid_at_flt"
            case "forget_at":
                return "forget_at_flt"
            case _:
                return field_name
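
Taken together, the converters above implement mappings like the following (shown as comments; the concrete field names are only examples):

    # convert_message_field_to_infinity:  "status"    -> "status_int",   "message_type" -> "message_type_kwd"
    # convert_condition_and_order_field:  "valid_at"  -> "valid_at_flt", "forget_at"    -> "forget_at_flt"
    # convert_infinity_field_to_message:  "q_768_vec" -> "content_embed"
    # convert_matching_field:             "content^5" -> "content@ft_contentm_rag_fine^5"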

    """
    CRUD operations
    """

    def search(
            self,
            select_fields: list[str],
            highlight_fields: list[str],
            condition: dict,
            match_expressions: list[MatchExpr],
            order_by: OrderByExpr,
            offset: int,
            limit: int,
            index_names: str | list[str],
            memory_ids: list[str],
            agg_fields: list[str] | None = None,
            rank_feature: dict | None = None,
            hide_forgotten: bool = True,
    ) -> tuple[pd.DataFrame, int]:
        """
        BUG: Infinity returns empty for a highlight field if the query string doesn't use that field.
        """
        if isinstance(index_names, str):
            index_names = index_names.split(",")
        assert isinstance(index_names, list) and len(index_names) > 0
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        df_list = list()
        table_list = list()
        if hide_forgotten:
            condition.update({"must_not": {"exists": "forget_at_flt"}})
        output = select_fields.copy()
        output = self.convert_select_fields(output)
        if agg_fields is None:
            agg_fields = []
        for essential_field in ["id"] + agg_fields:
            if essential_field not in output:
                output.append(essential_field)
        score_func = ""
        score_column = ""
        for matchExpr in match_expressions:
            if isinstance(matchExpr, MatchTextExpr):
                score_func = "score()"
                score_column = "SCORE"
                break
        if not score_func:
            for matchExpr in match_expressions:
                if isinstance(matchExpr, MatchDenseExpr):
                    score_func = "similarity()"
                    score_column = "SIMILARITY"
                    break
        if match_expressions:
            if score_func not in output:
                output.append(score_func)
            if PAGERANK_FLD not in output:
                output.append(PAGERANK_FLD)
        output = [f for f in output if f != "_score"]
        if limit <= 0:
            # ElasticSearch default limit is 10000
            limit = 10000

        # Prepare expressions common to all tables
        filter_cond = None
        filter_fulltext = ""
        if condition:
            condition_dict = {self.convert_condition_and_order_field(k): v for k, v in condition.items()}
            table_found = False
            for indexName in index_names:
                for mem_id in memory_ids:
                    table_name = f"{indexName}_{mem_id}"
                    try:
                        filter_cond = self.equivalent_condition_to_str(condition_dict, db_instance.get_table(table_name))
                        table_found = True
                        break
                    except Exception:
                        pass
                if table_found:
                    break
            if not table_found:
                self.logger.error(f"No valid tables found for indexNames {index_names} and memoryIds {memory_ids}")
                return pd.DataFrame(), 0

        for matchExpr in match_expressions:
            if isinstance(matchExpr, MatchTextExpr):
                if filter_cond and "filter" not in matchExpr.extra_options:
                    matchExpr.extra_options.update({"filter": filter_cond})
                matchExpr.fields = [self.convert_matching_field(field) for field in matchExpr.fields]
                fields = ",".join(matchExpr.fields)
                filter_fulltext = f"filter_fulltext('{fields}', '{matchExpr.matching_text}')"
                if filter_cond:
                    filter_fulltext = f"({filter_cond}) AND {filter_fulltext}"
                minimum_should_match = matchExpr.extra_options.get("minimum_should_match", 0.0)
                if isinstance(minimum_should_match, float):
                    str_minimum_should_match = str(int(minimum_should_match * 100)) + "%"
                    matchExpr.extra_options["minimum_should_match"] = str_minimum_should_match

                # Add rank_feature support
                if rank_feature and "rank_features" not in matchExpr.extra_options:
                    # Convert rank_feature dict to Infinity's rank_features string format
                    # Format: "field^feature_name^weight,field^feature_name^weight"
                    rank_features_list = []
                    for feature_name, weight in rank_feature.items():
                        # Use TAG_FLD as the field containing rank features
                        rank_features_list.append(f"{TAG_FLD}^{feature_name}^{weight}")
                    if rank_features_list:
                        matchExpr.extra_options["rank_features"] = ",".join(rank_features_list)

                for k, v in matchExpr.extra_options.items():
                    if not isinstance(v, str):
                        matchExpr.extra_options[k] = str(v)
                self.logger.debug(f"INFINITY search MatchTextExpr: {json.dumps(matchExpr.__dict__)}")
            elif isinstance(matchExpr, MatchDenseExpr):
                if filter_fulltext and "filter" not in matchExpr.extra_options:
                    matchExpr.extra_options.update({"filter": filter_fulltext})
                for k, v in matchExpr.extra_options.items():
                    if not isinstance(v, str):
                        matchExpr.extra_options[k] = str(v)
                similarity = matchExpr.extra_options.get("similarity")
                if similarity:
                    matchExpr.extra_options["threshold"] = similarity
                    del matchExpr.extra_options["similarity"]
                self.logger.debug(f"INFINITY search MatchDenseExpr: {json.dumps(matchExpr.__dict__)}")
            elif isinstance(matchExpr, FusionExpr):
                self.logger.debug(f"INFINITY search FusionExpr: {json.dumps(matchExpr.__dict__)}")

        order_by_expr_list = list()
        if order_by.fields:
            for order_field in order_by.fields:
                order_field_name = self.convert_condition_and_order_field(order_field[0])
                if order_field[1] == 0:
                    order_by_expr_list.append((order_field_name, SortType.Asc))
                else:
                    order_by_expr_list.append((order_field_name, SortType.Desc))

        total_hits_count = 0
        # Scatter search tables and gather the results
        for indexName in index_names:
            for memory_id in memory_ids:
                table_name = f"{indexName}_{memory_id}"
                try:
                    table_instance = db_instance.get_table(table_name)
                except Exception:
                    continue
                table_list.append(table_name)
                builder = table_instance.output(output)
                if len(match_expressions) > 0:
                    for matchExpr in match_expressions:
                        if isinstance(matchExpr, MatchTextExpr):
                            fields = ",".join(matchExpr.fields)
                            builder = builder.match_text(
                                fields,
                                matchExpr.matching_text,
                                matchExpr.topn,
                                matchExpr.extra_options.copy(),
                            )
                        elif isinstance(matchExpr, MatchDenseExpr):
                            builder = builder.match_dense(
                                matchExpr.vector_column_name,
                                matchExpr.embedding_data,
                                matchExpr.embedding_data_type,
                                matchExpr.distance_type,
                                matchExpr.topn,
                                matchExpr.extra_options.copy(),
                            )
                        elif isinstance(matchExpr, FusionExpr):
                            builder = builder.fusion(matchExpr.method, matchExpr.topn, matchExpr.fusion_params)
                else:
                    if filter_cond and len(filter_cond) > 0:
                        builder.filter(filter_cond)
                if order_by.fields:
                    builder.sort(order_by_expr_list)
                builder.offset(offset).limit(limit)
                mem_res, extra_result = builder.option({"total_hits_count": True}).to_df()
                if extra_result:
                    total_hits_count += int(extra_result["total_hits_count"])
                self.logger.debug(f"INFINITY search table: {str(table_name)}, result: {str(mem_res)}")
                df_list.append(mem_res)
        self.connPool.release_conn(inf_conn)
        res = self.concat_dataframes(df_list, output)
        if match_expressions:
            res["_score"] = res[score_column] + res[PAGERANK_FLD]
            res = res.sort_values(by="_score", ascending=False).reset_index(drop=True)
        res = res.head(limit)
        self.logger.debug(f"INFINITY search final result: {str(res)}")
        return res, total_hits_count
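
Unlike the Elasticsearch variant, this search scatters over one Infinity table per `{index_name}_{memory_id}` pair and returns a `(DataFrame, total_hits_count)` tuple. A rough consumption sketch, reusing the same kind of match-expression triple sketched for the Elasticsearch variant (all names are placeholders):

    res_df, total = inf_conn.search(
        ["content", "valid_at"], [], {"agent_id": "a_1"},
        [text_m, dense_m, fusion], OrderByExpr(), 0, 10,
        "ragflow_memory", ["mem_0", "mem_1"],   # scans ragflow_memory_mem_0 and ragflow_memory_mem_1
    )
    by_id = inf_conn.get_fields(res_df, ["content", "valid_at"])   # {row id: {message field: value}}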

    def get_forgotten_messages(self, select_fields: list[str], index_name: str, memory_id: str, limit: int = 2000):
        condition = {"memory_id": memory_id, "exists": "forget_at_flt"}
        order_by = OrderByExpr()
        order_by.asc("forget_at_flt")
        # query
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{index_name}_{memory_id}"
        table_instance = db_instance.get_table(table_name)
        output_fields = [self.convert_message_field_to_infinity(f) for f in select_fields]
        builder = table_instance.output(output_fields)
        filter_cond = self.equivalent_condition_to_str(condition, db_instance.get_table(table_name))
        builder.filter(filter_cond)
        order_by_expr_list = list()
        if order_by.fields:
            for order_field in order_by.fields:
                order_field_name = self.convert_condition_and_order_field(order_field[0])
                if order_field[1] == 0:
                    order_by_expr_list.append((order_field_name, SortType.Asc))
                else:
                    order_by_expr_list.append((order_field_name, SortType.Desc))
        builder.sort(order_by_expr_list)
        builder.offset(0).limit(limit)
        mem_res, _ = builder.option({"total_hits_count": True}).to_df()
        # concat_dataframes expects a list of frames (see search() above)
        res = self.concat_dataframes([mem_res], output_fields)
        res = res.head(limit)
        self.connPool.release_conn(inf_conn)
        return res

    def get(self, message_id: str, index_name: str, memory_ids: list[str]) -> dict | None:
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        df_list = list()
        assert isinstance(memory_ids, list)
        table_list = list()
        for memoryId in memory_ids:
            table_name = f"{index_name}_{memoryId}"
            table_list.append(table_name)
            try:
                table_instance = db_instance.get_table(table_name)
            except Exception:
                self.logger.warning(f"Table not found: {table_name}; this memory was not created in Infinity. It may have been created by another document engine.")
                continue
            mem_res, _ = table_instance.output(["*"]).filter(f"id = '{message_id}'").to_df()
            self.logger.debug(f"INFINITY get table: {str(table_list)}, result: {str(mem_res)}")
            df_list.append(mem_res)
        self.connPool.release_conn(inf_conn)
        res = self.concat_dataframes(df_list, ["id"])
        fields = set(res.columns.tolist())
        res_fields = self.get_fields(res, list(fields))
        return res_fields.get(message_id, None)

    def insert(self, documents: list[dict], index_name: str, memory_id: str = None) -> list[str]:
        if not documents:
            return []
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{index_name}_{memory_id}"
        vector_size = int(len(documents[0]["content_embed"]))
        try:
            table_instance = db_instance.get_table(table_name)
        except InfinityException as e:
            # src/common/status.cppm, kTableNotExist = 3022
            if e.error_code != ErrorCode.TABLE_NOT_EXIST:
                raise
            if vector_size == 0:
                raise ValueError("Cannot infer vector size from documents")
            self.create_idx(index_name, memory_id, vector_size)
            table_instance = db_instance.get_table(table_name)

        # embedding fields can't have a default value
        embedding_columns = []
        table_columns = table_instance.show_columns().rows()
        for n, ty, _, _ in table_columns:
            r = re.search(r"Embedding\([a-z]+,([0-9]+)\)", ty)
            if not r:
                continue
            embedding_columns.append((n, int(r.group(1))))

        docs = copy.deepcopy(documents)
        for d in docs:
            assert "_id" not in d
            assert "id" in d
            for k, v in list(d.items()):
                field_name = self.convert_message_field_to_infinity(k)
                if field_name in ["valid_at", "invalid_at", "forget_at"]:
                    d[f"{field_name}_flt"] = date_string_to_timestamp(v) if v else 0
                    if v is None:
                        d[field_name] = ""
                elif self.field_keyword(k):
                    if isinstance(v, list):
                        d[k] = "###".join(v)
                    else:
                        d[k] = v
                elif k == "memory_id":
                    if isinstance(d[k], list):
                        d[k] = d[k][0]  # d[k] is a list, but we need a str
                elif field_name == "content_embed":
                    d[f"q_{vector_size}_vec"] = d["content_embed"]
                    d.pop("content_embed")
                else:
                    d[field_name] = v
                    if k != field_name:
                        d.pop(k)

            for n, vs in embedding_columns:
                if n in d:
                    continue
                d[n] = [0] * vs
        ids = ["'{}'".format(d["id"]) for d in docs]
        str_ids = ", ".join(ids)
        str_filter = f"id IN ({str_ids})"
        table_instance.delete(str_filter)
        table_instance.insert(docs)
        self.connPool.release_conn(inf_conn)
        self.logger.debug(f"INFINITY inserted into {table_name} {str_ids}.")
        return []
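
insert() infers the vector column from the first document and creates the per-memory table on demand, so the documents are expected to look roughly like this (values invented for illustration):

    doc = {
        "id": "msg-0001", "message_id": 1, "message_type": "raw", "memory_id": "mem_0",
        "user_id": "u_1", "agent_id": "a_1", "session_id": "s_1", "zone_id": 0,
        "valid_at": "2024-01-15T10:00:00", "invalid_at": "", "forget_at": "", "status": True,
        "content": "hello world", "content_embed": [0.1, 0.2, 0.3],   # stored as column q_3_vec
    }
    inf_conn.insert([doc], index_name="ragflow_memory", memory_id="mem_0")   # writes to table ragflow_memory_mem_0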

    def update(self, condition: dict, new_value: dict, index_name: str, memory_id: str) -> bool:
        inf_conn = self.connPool.get_conn()
        db_instance = inf_conn.get_database(self.dbName)
        table_name = f"{index_name}_{memory_id}"
        table_instance = db_instance.get_table(table_name)

        columns = {}
        if table_instance:
            for n, ty, de, _ in table_instance.show_columns().rows():
                columns[n] = (ty, de)
        condition_dict = {self.convert_condition_and_order_field(k): v for k, v in condition.items()}
        filter = self.equivalent_condition_to_str(condition_dict, table_instance)
        update_dict = {self.convert_message_field_to_infinity(k): v for k, v in new_value.items()}
        date_floats = {}
        for k, v in update_dict.items():
            if k in ["valid_at", "invalid_at", "forget_at"]:
                date_floats[f"{k}_flt"] = date_string_to_timestamp(v) if v else 0
            elif self.field_keyword(k):
                if isinstance(v, list):
                    update_dict[k] = "###".join(v)
                else:
                    update_dict[k] = v
            elif k == "memory_id":
                if isinstance(update_dict[k], list):
                    update_dict[k] = update_dict[k][0]  # the value is a list, but we need a str
            else:
                update_dict[k] = v
        if date_floats:
            update_dict.update(date_floats)

        self.logger.debug(f"INFINITY update table {table_name}, filter {filter}, newValue {new_value}.")
        table_instance.update(filter, update_dict)
        self.connPool.release_conn(inf_conn)
        return True

    """
    Helper functions for search result
    """

    def get_fields(self, res: tuple[pd.DataFrame, int] | pd.DataFrame, fields: list[str]) -> dict[str, dict]:
        if isinstance(res, tuple):
            res = res[0]
        if not fields:
            return {}
        fields_all = fields.copy()
        fields_all.append("id")
        fields_all = {self.convert_message_field_to_infinity(f) for f in fields_all}

        column_map = {col.lower(): col for col in res.columns}
        matched_columns = {column_map[col.lower()]: col for col in fields_all if col.lower() in column_map}
        none_columns = [col for col in fields_all if col.lower() not in column_map]

        res2 = res[matched_columns.keys()]
        res2 = res2.rename(columns=matched_columns)
        res2.drop_duplicates(subset=["id"], inplace=True)

        for column in list(res2.columns):
            k = column.lower()
            if self.field_keyword(k):
                res2[column] = res2[column].apply(lambda v: [kwd for kwd in v.split("###") if kwd])
            else:
                pass
        for column in ["content"]:
            if column in res2:
                del res2[column]
        for column in none_columns:
            res2[column] = None

        res_dict = res2.set_index("id").to_dict(orient="index")
        return {_id: {self.convert_infinity_field_to_message(k): v for k, v in doc.items()} for _id, doc in res_dict.items()}

memory/utils/msg_util.py (new file, 37 lines)
@@ -0,0 +1,37 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json


def get_json_result_from_llm_response(response_str: str) -> dict:
    """
    Parse the LLM response string to extract JSON content.
    The function strips a Markdown ```json code fence, if present, and parses the remainder.
    If parsing fails, it returns an empty dictionary.

    :param response_str: The response string from the LLM.
    :return: A dictionary parsed from the JSON content in the response.
    """
    try:
        clean_str = response_str.strip()
        if clean_str.startswith('```json'):
            clean_str = clean_str[7:]  # Remove the starting ```json
        if clean_str.endswith('```'):
            clean_str = clean_str[:-3]  # Remove the ending ```

        return json.loads(clean_str.strip())
    except (ValueError, json.JSONDecodeError):
        return {}
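
For reference, the fence-stripping behaviour on a typical LLM reply (illustrative strings only):

    raw = '```json\n{"semantic": [{"content": "Paris is the capital of France"}]}\n```'
    get_json_result_from_llm_response(raw)
    # -> {"semantic": [{"content": "Paris is the capital of France"}]}
    get_json_result_from_llm_response("not json at all")
    # -> {}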

memory/utils/prompt_util.py (new file, 201 lines)
@@ -0,0 +1,201 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, List

from common.constants import MemoryType
from common.time_utils import current_timestamp


class PromptAssembler:

    SYSTEM_BASE_TEMPLATE = """**Memory Extraction Specialist**
You are an expert at analyzing conversations to extract structured memory.

{type_specific_instructions}


**OUTPUT REQUIREMENTS:**
1. Output MUST be valid JSON
2. Follow the specified output format exactly
3. Each extracted item MUST have: content, valid_at, invalid_at
4. Timestamps in {timestamp_format} format
5. Only extract memory types specified above
6. Maximum {max_items} items per type
"""

    TYPE_INSTRUCTIONS = {
        MemoryType.SEMANTIC.name.lower(): """
**EXTRACT SEMANTIC KNOWLEDGE:**
- Universal facts, definitions, concepts, relationships
- Time-invariant, generally true information
- Examples: "The capital of France is Paris", "Water boils at 100°C"

**Timestamp Rules for Semantic Knowledge:**
- valid_at: When the fact became true (e.g., law enactment, discovery)
- invalid_at: When it becomes false (e.g., repeal, disproven) or empty if still true
- Default: valid_at = conversation time, invalid_at = "" for timeless facts
""",

        MemoryType.EPISODIC.name.lower(): """
**EXTRACT EPISODIC KNOWLEDGE:**
- Specific experiences, events, personal stories
- Time-bound, person-specific, contextual
- Examples: "Yesterday I fixed the bug", "User reported issue last week"

**Timestamp Rules for Episodic Knowledge:**
- valid_at: Event start/occurrence time
- invalid_at: Event end time or empty if instantaneous
- Extract explicit times: "at 3 PM", "last Monday", "from X to Y"
""",

        MemoryType.PROCEDURAL.name.lower(): """
**EXTRACT PROCEDURAL KNOWLEDGE:**
- Processes, methods, step-by-step instructions
- Goal-oriented, actionable, often includes conditions
- Examples: "To reset password, click...", "Debugging steps: 1)..."

**Timestamp Rules for Procedural Knowledge:**
- valid_at: When procedure becomes valid/effective
- invalid_at: When it expires/becomes obsolete or empty if current
- For version-specific: use release dates
- For best practices: invalid_at = ""
"""
    }

    OUTPUT_TEMPLATES = {
        MemoryType.SEMANTIC.name.lower(): """
"semantic": [
    {
        "content": "Clear factual statement",
        "valid_at": "timestamp or empty",
        "invalid_at": "timestamp or empty"
    }
]
""",

        MemoryType.EPISODIC.name.lower(): """
"episodic": [
    {
        "content": "Narrative event description",
        "valid_at": "event start timestamp",
        "invalid_at": "event end timestamp or empty"
    }
]
""",

        MemoryType.PROCEDURAL.name.lower(): """
"procedural": [
    {
        "content": "Actionable instructions",
        "valid_at": "procedure effective timestamp",
        "invalid_at": "procedure expiration timestamp or empty"
    }
]
"""
    }

    BASE_USER_PROMPT = """
**CONVERSATION:**
{conversation}

**CONVERSATION TIME:** {conversation_time}
**CURRENT TIME:** {current_time}
"""

    @classmethod
    def assemble_system_prompt(cls, config: dict) -> str:
        types_to_extract = cls._get_types_to_extract(config["memory_type"])

        type_instructions = cls._generate_type_instructions(types_to_extract)

        output_format = cls._generate_output_format(types_to_extract)

        full_prompt = cls.SYSTEM_BASE_TEMPLATE.format(
            type_specific_instructions=type_instructions,
            timestamp_format=config.get("timestamp_format", "ISO 8601"),
            max_items=config.get("max_items_per_type", 5)
        )

        full_prompt += f"\n**REQUIRED OUTPUT FORMAT (JSON):**\n```json\n{{\n{output_format}\n}}\n```\n"

        examples = cls._generate_examples(types_to_extract)
        if examples:
            full_prompt += f"\n**EXAMPLES:**\n{examples}\n"

        return full_prompt

    @staticmethod
    def _get_types_to_extract(requested_types: List[str]) -> List[str]:
        types = set()
        for rt in requested_types:
            if rt in [e.name.lower() for e in MemoryType] and rt != MemoryType.RAW.name.lower():
                types.add(rt)
        return list(types)

    @classmethod
    def _generate_type_instructions(cls, types_to_extract: List[str]) -> str:
        target_types = set(types_to_extract)
        instructions = [cls.TYPE_INSTRUCTIONS[mt] for mt in target_types]
        return "\n".join(instructions)

    @classmethod
    def _generate_output_format(cls, types_to_extract: List[str]) -> str:
        target_types = set(types_to_extract)
        output_parts = [cls.OUTPUT_TEMPLATES[mt] for mt in target_types]
        return ",\n".join(output_parts)

    @staticmethod
    def _generate_examples(types_to_extract: list[str]) -> str:
        examples = []

        if MemoryType.SEMANTIC.name.lower() in types_to_extract:
            examples.append("""
**Semantic Example:**
Input: "Python lists are mutable and support various operations."
Output: {"semantic": [{"content": "Python lists are mutable data structures", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]}
""")

        if MemoryType.EPISODIC.name.lower() in types_to_extract:
            examples.append("""
**Episodic Example:**
Input: "I deployed the new feature yesterday afternoon."
Output: {"episodic": [{"content": "User deployed new feature", "valid_at": "2024-01-14T14:00:00", "invalid_at": "2024-01-14T18:00:00"}]}
""")

        if MemoryType.PROCEDURAL.name.lower() in types_to_extract:
            examples.append("""
**Procedural Example:**
Input: "To debug API errors: 1) Check logs 2) Verify endpoints 3) Test connectivity."
Output: {"procedural": [{"content": "API error debugging: 1. Check logs 2. Verify endpoints 3. Test connectivity", "valid_at": "2024-01-15T10:00:00", "invalid_at": ""}]}
""")

        return "\n".join(examples)

    @classmethod
    def assemble_user_prompt(
            cls,
            conversation: str,
            conversation_time: Optional[str] = None,
            current_time: Optional[str] = None
    ) -> str:
        return cls.BASE_USER_PROMPT.format(
            conversation=conversation,
            conversation_time=conversation_time or "Not specified",
            current_time=current_time or current_timestamp(),
        )

    @classmethod
    def get_raw_user_prompt(cls):
        return cls.BASE_USER_PROMPT
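
A hedged usage sketch for PromptAssembler; the config keys mirror the lookups in assemble_system_prompt() and the conversation text is invented:

    config = {"memory_type": ["semantic", "episodic"], "timestamp_format": "ISO 8601", "max_items_per_type": 5}
    system_prompt = PromptAssembler.assemble_system_prompt(config)
    user_prompt = PromptAssembler.assemble_user_prompt(
        "user: I deployed the new feature yesterday afternoon.",
        conversation_time="2024-01-15T10:00:00",
    )
    # system_prompt embeds the semantic and episodic instructions, the JSON output format and the matching examples.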

@@ -46,7 +46,7 @@ dependencies = [
    "groq==0.9.0",
    "grpcio-status==1.67.1",
    "html-text==0.6.2",
    "infinity-sdk==0.6.11",
    "infinity-sdk==0.6.13",
    "infinity-emb>=0.0.66,<0.0.67",
    "jira==3.10.5",
    "json-repair==0.35.0",

@@ -91,7 +91,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
            filename, binary=binary, from_page=from_page, to_page=to_page)
        remove_contents_table(sections, eng=is_english(
            random_choices([t for t, _ in sections], k=200)))
        tbls=vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
        tbls = vision_figure_parser_docx_wrapper(sections=sections,tbls=tbls,callback=callback,**kwargs)
        # tbls = [((None, lns), None) for lns in tbls]
        sections=[(item[0],item[1] if item[1] is not None else "") for item in sections if not isinstance(item[1], Image.Image)]
        callback(0.8, "Finish parsing.")

@@ -147,9 +147,16 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,

    elif re.search(r"\.doc$", filename, re.IGNORECASE):
        callback(0.1, "Start to parse.")
        with BytesIO(binary) as binary:
            binary = BytesIO(binary)
            doc_parsed = parser.from_buffer(binary)
        try:
            from tika import parser as tika_parser
        except Exception as e:
            callback(0.8, f"tika not available: {e}. Unsupported .doc parsing.")
            logging.warning(f"tika not available: {e}. Unsupported .doc parsing for {filename}.")
            return []

        binary = BytesIO(binary)
        doc_parsed = tika_parser.from_buffer(binary)
        if doc_parsed.get('content', None) is not None:
            sections = doc_parsed['content'].split('\n')
            sections = [(line, "") for line in sections if line]
            remove_contents_table(sections, eng=is_english(

@@ -312,7 +312,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
            tk_cnt = num_tokens_from_string(txt)
            if sec_id > -1:
                last_sid = sec_id
    tbls=vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
    tbls = vision_figure_parser_pdf_wrapper(tbls=tbls,callback=callback,**kwargs)
    res = tokenize_table(tbls, doc, eng)
    res.extend(tokenize_chunks(chunks, doc, eng, pdf_parser))
    table_ctx = max(0, int(parser_config.get("table_context_size", 0) or 0))

@@ -325,7 +325,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
        docx_parser = Docx()
        ti_list, tbls = docx_parser(filename, binary,
                                    from_page=0, to_page=10000, callback=callback)
        tbls=vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs)
        tbls = vision_figure_parser_docx_wrapper(sections=ti_list,tbls=tbls,callback=callback,**kwargs)
        res = tokenize_table(tbls, doc, eng)
        for text, image in ti_list:
            d = copy.deepcopy(doc)

@@ -651,7 +651,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000, lang="Chinese", ca
        "parser_config", {
            "chunk_token_num": 512, "delimiter": "\n!?。;!?", "layout_recognize": "DeepDOC", "analyze_hyperlink": True})

    child_deli = parser_config.get("children_delimiter", "").encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8')
    child_deli = (parser_config.get("children_delimiter") or "").encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8')
    cust_child_deli = re.findall(r"`([^`]+)`", child_deli)
    child_deli = "|".join(re.sub(r"`([^`]+)`", "", child_deli))
    if cust_child_deli:
Some files were not shown because too many files have changed in this diff.